From 4edf622fb7d92a20d491851c05ba97d0e86b1368 Mon Sep 17 00:00:00 2001 From: codethinki Date: Wed, 18 Feb 2026 20:25:03 +0100 Subject: [PATCH 01/24] ci modifications (partial) --- .github/workflows/build-test.yml | 1 - .github/workflows/check-clang-format.yml | 15 +++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/check-clang-format.yml diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 7f4c5c0..4ea8656 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -3,7 +3,6 @@ name: Build & Test on: push: { branches: ["master", "_master/add_ci"] } pull_request: { branches: ["master"] } - workflow_dispatch: permissions: { contents: read, packages: write } diff --git a/.github/workflows/check-clang-format.yml b/.github/workflows/check-clang-format.yml new file mode 100644 index 0000000..925d899 --- /dev/null +++ b/.github/workflows/check-clang-format.yml @@ -0,0 +1,15 @@ +name: Clang Format + +on: [push, pull_request] + +permissions: { contents: read } +jobs: + formatting-check: + name: Format check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Run clang-format style check for C/C++ programs. 
+ uses: jidicula/clang-format-action@v4.2.0 + with: + clang-format-version: '13' From 40a8a10176836580a342bcad57da3f511729d976 Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Wed, 18 Feb 2026 20:25:35 +0100 Subject: [PATCH 02/24] adding clang tidy and clang format files --- .clang-format | 130 ++++++++++++++++++++++++++------------------------ .clang_tidy | 107 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 175 insertions(+), 62 deletions(-) create mode 100644 .clang_tidy diff --git a/.clang-format b/.clang-format index 7ee05de..7da0a7c 100644 --- a/.clang-format +++ b/.clang-format @@ -1,88 +1,94 @@ --- -AccessModifierOffset: -1 -AlignAfterOpenBracket: AlwaysBreak +Language: Cpp +BasedOnStyle: LLVM +AccessModifierOffset: -4 +AlignAfterOpenBracket: BlockIndent AlignConsecutiveAssignments: false AlignConsecutiveDeclarations: false -AlignEscapedNewlinesLeft: true -AlignOperands: false -AlignTrailingComments: false +AlignOperands: false +AlignEscapedNewlines: Left +AllowAllArgumentsOnNextLine: false AllowAllParametersOfDeclarationOnNextLine: false -AllowShortBlocksOnASingleLine: false -AllowShortCaseLabelsOnASingleLine: false -AllowShortFunctionsOnASingleLine: Empty +AllowShortBlocksOnASingleLine: Empty +AllowShortCaseLabelsOnASingleLine: true +AllowShortFunctionsOnASingleLine: All AllowShortIfStatementsOnASingleLine: false +AllowShortLambdasOnASingleLine: All AllowShortLoopsOnASingleLine: false AlwaysBreakAfterReturnType: None -AlwaysBreakBeforeMultilineStrings: true -AlwaysBreakTemplateDeclarations: true +AlwaysBreakTemplateDeclarations: Yes BinPackArguments: false BinPackParameters: false BraceWrapping: - AfterClass: false + AfterCaseLabel: false + AfterClass: false AfterControlStatement: false - AfterEnum: false - AfterFunction: false - AfterNamespace: false - AfterObjCDeclaration: false - AfterStruct: false - AfterUnion: false - BeforeCatch: false - BeforeElse: false - IndentBraces: false -BreakBeforeBinaryOperators: None -BreakBeforeBraces: Attach 
+ AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterStruct: false + AfterUnion: false + AfterExternBlock: false + BeforeCatch: false + BeforeElse: false + BeforeLambdaBody: false + BeforeWhile: false + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: false +BreakBeforeBraces: Custom BreakBeforeTernaryOperators: true -BreakConstructorInitializersBeforeComma: false -BreakAfterJavaFieldAnnotations: false -BreakStringLiterals: false -ColumnLimit: 80 -CommentPragmas: '^ IWYU pragma:' -ConstructorInitializerAllOnOneLineOrOnePerLine: true -ConstructorInitializerIndentWidth: 4 +BreakConstructorInitializers: AfterColon +BreakInheritanceList: BeforeComma +ColumnLimit: 110 +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: false ContinuationIndentWidth: 4 -Cpp11BracedListStyle: true -DerivePointerAlignment: false -DisableFormat: false -ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ] +EmptyLineAfterAccessModifier: Never +EmptyLineBeforeAccessModifier: LogicalBlock +FixNamespaceComments: false +IncludeBlocks: Preserve IncludeCategories: - - Regex: '^<.*\.h(pp)?>' - Priority: 1 - - Regex: '^<.*' - Priority: 2 - - Regex: '.*' - Priority: 3 + - Regex: '^".*' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IncludeIsMainRegex: '([-_](test|unittest))?$' +IndentCaseBlocks: false IndentCaseLabels: true +IndentGotoLabels: true IndentPPDirectives: None -IndentWidth: 2 +IndentWidth: 4 IndentWrappedFunctionNames: false -KeepEmptyLinesAtTheStartOfBlocks: false +InsertNewlineAtEOF: true MacroBlockBegin: '' -MacroBlockEnd: '' -MaxEmptyLinesToKeep: 1 -NamespaceIndentation: None -ObjCBlockIndentWidth: 2 -ObjCSpaceAfterProperty: false -ObjCSpaceBeforeProtocolList: false -PenaltyBreakBeforeFirstCallParameter: 1 -PenaltyBreakComment: 300 -PenaltyBreakFirstLessLess: 120 -PenaltyBreakString: 1000 -PenaltyExcessCharacter: 1000000 -PenaltyReturnTypeOnItsOwnLine: 200 +MacroBlockEnd: '' 
+MaxEmptyLinesToKeep: 2 +NamespaceIndentation: Inner PointerAlignment: Left -ReflowComments: true -SortIncludes: true -SpaceAfterCStyleCast: false +RequiresClausePosition: WithPreceding +SpaceAfterCStyleCast: true +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: false SpaceBeforeAssignmentOperators: true -SpaceBeforeParens: ControlStatements +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: Never +SpaceBeforeRangeBasedForLoopColon: true +SpaceBeforeSquareBrackets: false SpaceInEmptyParentheses: false SpacesBeforeTrailingComments: 1 -SpacesInAngles: false -SpacesInContainerLiterals: true +SpacesInAngles: false +SpacesInConditionalStatement: false SpacesInCStyleCastParentheses: false SpacesInParentheses: false SpacesInSquareBrackets: false -Standard: Cpp11 -TabWidth: 8 -UseTab: Never +TabWidth: 4 +UseTab: Never +PenaltyBreakBeforeFirstCallParameter: 0 +PenaltyBreakTemplateDeclaration: 0 +PenaltyReturnTypeOnItsOwnLine: 10 ... + diff --git a/.clang_tidy b/.clang_tidy new file mode 100644 index 0000000..0264f2a --- /dev/null +++ b/.clang_tidy @@ -0,0 +1,107 @@ +Checks: + # --- 1. Enable broad categories by default --- + - 'bugprone-*' + - 'misc-*' + - 'modernize-*' + - 'performance-*' + - 'readability-*' + - 'mpi-*' + - 'openmp-*' + + # --- 2. CERT (C++ Secure Coding Standards) --- + - 'cert-err52-cpp' + - 'cert-err60-cpp' + - 'cert-err34-c' + - 'cert-err33-c' + - 'cert-str34-c' + - 'cert-mem57-cpp' + - 'cert-msc50-cpp' + - 'cert-oop57-cpp' + - 'cert-msc51-cpp' + - 'cert-dcl58-cpp' + - 'cert-flp30-c' + + # --- 3. 
C++ Core Guidelines --- + - 'cppcoreguidelines-avoid-capturing-lambda-coroutines' + - 'cppcoreguidelines-avoid-const-or-ref-data-members' + - 'cppcoreguidelines-misleading-capture-default-by-value' + - 'cppcoreguidelines-prefer-member-initializer' + - 'cppcoreguidelines-no-suspend-with-lock' + - 'cppcoreguidelines-rvalue-reference-param-not-moved' + - 'cppcoreguidelines-explicit-virtual-functions' + - 'cppcoreguidelines-slicing' + - 'cppcoreguidelines-pro-type-cstyle-cast' + - 'cppcoreguidelines-interfaces-global-init' + - 'cppcoreguidelines-pro-type-static-cast-downcast' + - 'cppcoreguidelines-narrowing-conversions' + - 'cppcoreguidelines-pro-bounds-constant-array-index' + - 'cppcoreguidelines-missing-std-forward' + - 'cppcoreguidelines-avoid-magic-numbers' + - 'cppcoreguidelines-pro-bounds-array-to-pointer-decay' + + # --- 4. High Integrity C++ --- + - 'hicpp-multiway-paths-covered' + + # --- 5. Portability --- + - 'portability-std-allocator-const' + - 'portability-simd-intrinsics' + + # --- 6. 
Exclusions (Disabling specific checks) --- + # Exclusions for 'bugprone-*' + - '-bugprone-switch-missing-default-case' + - '-bugprone-casting-through-void' + - '-bugprone-exception-escape' + - '-bugprone-tagged-union-member-count' + - '-bugprone-suspicious-stringview-data-usage' + - '-bugprone-multiple-new-in-one-expression' + - '-bugprone-incorrect-enable-shared-from-this' + - '-bugprone-misleading-setter-of-reference' + - '-bugprone-nondeterministic-pointer-iteration-order' + - '-bugprone-incorrect-enable-if' + - '-bugprone-unintended-char-ostream-output' + - '-bugprone-bool-pointer-implicit-conversion' + - '-bugprone-crtp-constructor-accessibility' + - '-bugprone-multi-level-implicit-pointer-conversion' + - '-bugprone-easily-swappable-parameters' + - '-bugprone-non-zero-enum-to-bool-conversion' + - '-bugprone-not-null-terminated-result' + - '-bugprone-standalone-empty' + + # Exclusions for 'misc-*' + - '-misc-unused-parameters' + - '-misc-misleading-identifier' + - '-misc-confusable-identifiers' + - '-misc-misleading-bidirectional' + - '-misc-header-include-cycle' + - '-misc-non-private-member-variables-in-classes' + - '-misc-redundant-expression' + + # Exclusions for 'modernize-*' + - '-modernize-use-designated-initializers' + - '-modernize-use-trailing-return-type' + - '-modernize-avoid-c-arrays' + - '-modernize-macro-to-enum' + + # Exclusions for 'performance-*' + - '-performance-noexcept-swap' + - '-performance-noexcept-destructor' + - '-performance-enum-size' + - '-performance-no-int-to-ptr' + - '-performance-avoid-endl' + + # Exclusions for 'readability-*' + - '-readability-named-parameter' + - '-readability-function-size' + - '-readability-identifier-length' + - '-readability-uppercase-literal-suffix' + - '-readability-math-missing-parentheses' + - '-readability-operators-representation' + - '-readability-ambiguous-smartptr-reset-call' + - '-readability-implicit-bool-conversion' + - '-readability-braces-around-statements' + - '-readability-qualified-auto' 
+ - '-readability-container-data-pointer' + - '-readability-avoid-unconditional-preprocessor-if' + - '-readability-function-cognitive-complexity' + - '-readability-identifier-naming' + - '-readability-enum-initial-value' From 6a8ebe17c22054aaae35ee290d1aa24ebdc2b1c7 Mon Sep 17 00:00:00 2001 From: codethinki Date: Sat, 21 Feb 2026 19:42:21 +0100 Subject: [PATCH 03/24] bumped clang format to 21 --- .github/workflows/check-clang-format.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/check-clang-format.yml b/.github/workflows/check-clang-format.yml index 925d899..94501f5 100644 --- a/.github/workflows/check-clang-format.yml +++ b/.github/workflows/check-clang-format.yml @@ -12,4 +12,4 @@ jobs: - name: Run clang-format style check for C/C++ programs. uses: jidicula/clang-format-action@v4.2.0 with: - clang-format-version: '13' + clang-format-version: '21' From d2bfc4fef3035122bc3d49db373863ac7075f848 Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Wed, 18 Feb 2026 20:25:35 +0100 Subject: [PATCH 04/24] adding clang tidy and clang format files --- .clang-format | 130 ++++++++++++++++++++++++++------------------------ .clang_tidy | 107 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 175 insertions(+), 62 deletions(-) create mode 100644 .clang_tidy diff --git a/.clang-format b/.clang-format index 7ee05de..7da0a7c 100644 --- a/.clang-format +++ b/.clang-format @@ -1,88 +1,94 @@ --- -AccessModifierOffset: -1 -AlignAfterOpenBracket: AlwaysBreak +Language: Cpp +BasedOnStyle: LLVM +AccessModifierOffset: -4 +AlignAfterOpenBracket: BlockIndent AlignConsecutiveAssignments: false AlignConsecutiveDeclarations: false -AlignEscapedNewlinesLeft: true -AlignOperands: false -AlignTrailingComments: false +AlignOperands: false +AlignEscapedNewlines: Left +AllowAllArgumentsOnNextLine: false AllowAllParametersOfDeclarationOnNextLine: false -AllowShortBlocksOnASingleLine: false -AllowShortCaseLabelsOnASingleLine: false 
-AllowShortFunctionsOnASingleLine: Empty +AllowShortBlocksOnASingleLine: Empty +AllowShortCaseLabelsOnASingleLine: true +AllowShortFunctionsOnASingleLine: All AllowShortIfStatementsOnASingleLine: false +AllowShortLambdasOnASingleLine: All AllowShortLoopsOnASingleLine: false AlwaysBreakAfterReturnType: None -AlwaysBreakBeforeMultilineStrings: true -AlwaysBreakTemplateDeclarations: true +AlwaysBreakTemplateDeclarations: Yes BinPackArguments: false BinPackParameters: false BraceWrapping: - AfterClass: false + AfterCaseLabel: false + AfterClass: false AfterControlStatement: false - AfterEnum: false - AfterFunction: false - AfterNamespace: false - AfterObjCDeclaration: false - AfterStruct: false - AfterUnion: false - BeforeCatch: false - BeforeElse: false - IndentBraces: false -BreakBeforeBinaryOperators: None -BreakBeforeBraces: Attach + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterStruct: false + AfterUnion: false + AfterExternBlock: false + BeforeCatch: false + BeforeElse: false + BeforeLambdaBody: false + BeforeWhile: false + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: false +BreakBeforeBraces: Custom BreakBeforeTernaryOperators: true -BreakConstructorInitializersBeforeComma: false -BreakAfterJavaFieldAnnotations: false -BreakStringLiterals: false -ColumnLimit: 80 -CommentPragmas: '^ IWYU pragma:' -ConstructorInitializerAllOnOneLineOrOnePerLine: true -ConstructorInitializerIndentWidth: 4 +BreakConstructorInitializers: AfterColon +BreakInheritanceList: BeforeComma +ColumnLimit: 110 +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: false ContinuationIndentWidth: 4 -Cpp11BracedListStyle: true -DerivePointerAlignment: false -DisableFormat: false -ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ] +EmptyLineAfterAccessModifier: Never +EmptyLineBeforeAccessModifier: LogicalBlock +FixNamespaceComments: false +IncludeBlocks: Preserve IncludeCategories: - - Regex: '^<.*\.h(pp)?>' - 
Priority: 1 - - Regex: '^<.*' - Priority: 2 - - Regex: '.*' - Priority: 3 + - Regex: '^".*' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IncludeIsMainRegex: '([-_](test|unittest))?$' +IndentCaseBlocks: false IndentCaseLabels: true +IndentGotoLabels: true IndentPPDirectives: None -IndentWidth: 2 +IndentWidth: 4 IndentWrappedFunctionNames: false -KeepEmptyLinesAtTheStartOfBlocks: false +InsertNewlineAtEOF: true MacroBlockBegin: '' -MacroBlockEnd: '' -MaxEmptyLinesToKeep: 1 -NamespaceIndentation: None -ObjCBlockIndentWidth: 2 -ObjCSpaceAfterProperty: false -ObjCSpaceBeforeProtocolList: false -PenaltyBreakBeforeFirstCallParameter: 1 -PenaltyBreakComment: 300 -PenaltyBreakFirstLessLess: 120 -PenaltyBreakString: 1000 -PenaltyExcessCharacter: 1000000 -PenaltyReturnTypeOnItsOwnLine: 200 +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 2 +NamespaceIndentation: Inner PointerAlignment: Left -ReflowComments: true -SortIncludes: true -SpaceAfterCStyleCast: false +RequiresClausePosition: WithPreceding +SpaceAfterCStyleCast: true +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: false SpaceBeforeAssignmentOperators: true -SpaceBeforeParens: ControlStatements +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: Never +SpaceBeforeRangeBasedForLoopColon: true +SpaceBeforeSquareBrackets: false SpaceInEmptyParentheses: false SpacesBeforeTrailingComments: 1 -SpacesInAngles: false -SpacesInContainerLiterals: true +SpacesInAngles: false +SpacesInConditionalStatement: false SpacesInCStyleCastParentheses: false SpacesInParentheses: false SpacesInSquareBrackets: false -Standard: Cpp11 -TabWidth: 8 -UseTab: Never +TabWidth: 4 +UseTab: Never +PenaltyBreakBeforeFirstCallParameter: 0 +PenaltyBreakTemplateDeclaration: 0 +PenaltyReturnTypeOnItsOwnLine: 10 ... + diff --git a/.clang_tidy b/.clang_tidy new file mode 100644 index 0000000..ec792f6 --- /dev/null +++ b/.clang_tidy @@ -0,0 +1,107 @@ +Checks: + # --- 1. 
Enable broad categories by default --- + - 'bugprone-*' + - 'misc-*' + - 'modernize-*' + - 'performance-*' + - 'readability-*' + - 'mpi-*' + - 'openmp-*' + + # --- 2. CERT (C++ Secure Coding Standards) --- + - 'cert-err52-cpp' + - 'cert-err60-cpp' + - 'cert-err34-c' + - 'cert-err33-c' + - 'cert-str34-c' + - 'cert-mem57-cpp' + - 'cert-msc50-cpp' + - 'cert-oop57-cpp' + - 'cert-msc51-cpp' + - 'cert-dcl58-cpp' + - 'cert-flp30-c' + + # --- 3. C++ Core Guidelines --- + - 'cppcoreguidelines-avoid-capturing-lambda-coroutines' + - 'cppcoreguidelines-avoid-const-or-ref-data-members' + - 'cppcoreguidelines-misleading-capture-default-by-value' + - 'cppcoreguidelines-prefer-member-initializer' + - 'cppcoreguidelines-no-suspend-with-lock' + - 'cppcoreguidelines-rvalue-reference-param-not-moved' + - 'cppcoreguidelines-explicit-virtual-functions' + - 'cppcoreguidelines-slicing' + - 'cppcoreguidelines-pro-type-cstyle-cast' + - 'cppcoreguidelines-interfaces-global-init' + - 'cppcoreguidelines-pro-type-static-cast-downcast' + - 'cppcoreguidelines-narrowing-conversions' + - 'cppcoreguidelines-pro-bounds-constant-array-index' + - 'cppcoreguidelines-missing-std-forward' + - 'cppcoreguidelines-avoid-magic-numbers' + - 'cppcoreguidelines-pro-bounds-array-to-pointer-decay' + + # --- 4. High Integrity C++ --- + - 'hicpp-multiway-paths-covered' + + # --- 5. Portability --- + - 'portability-std-allocator-const' + - 'portability-simd-intrinsics' + + # --- 6. 
Exclusions (Disabling specific checks) --- + # Exclusions for 'bugprone-*' + - '-bugprone-switch-missing-default-case' + - '-bugprone-casting-through-void' + - '-bugprone-exception-escape' + - '-bugprone-tagged-union-member-count' + - '-bugprone-suspicious-stringview-data-usage' + - '-bugprone-multiple-new-in-one-expression' + - '-bugprone-incorrect-enable-shared-from-this' + - '-bugprone-misleading-setter-of-reference' + - '-bugprone-nondeterministic-pointer-iteration-order' + - '-bugprone-incorrect-enable-if' + - '-bugprone-unintended-char-ostream-output' + - '-bugprone-bool-pointer-implicit-conversion' + - '-bugprone-crtp-constructor-accessibility' + - '-bugprone-multi-level-implicit-pointer-conversion' + - '-bugprone-easily-swappable-parameters' + - '-bugprone-non-zero-enum-to-bool-conversion' + - '-bugprone-not-null-terminated-result' + - '-bugprone-standalone-empty' + + # Exclusions for 'misc-*' + - '-misc-unused-parameters' + - '-misc-misleading-identifier' + - '-misc-confusable-identifiers' + - '-misc-misleading-bidirectional' + - '-misc-header-include-cycle' + - '-misc-non-private-member-variables-in-classes' + - '-misc-redundant-expression' + + # Exclusions for 'modernize-*' + - '-modernize-use-designated-initializers' + - '-modernize-use-trailing-return-type' + - '-modernize-avoid-c-arrays' + - '-modernize-macro-to-enum' + + # Exclusions for 'performance-*' + - '-performance-noexcept-swap' + - '-performance-noexcept-destructor' + - '-performance-enum-size' + - '-performance-no-int-to-ptr' + - '-performance-avoid-endl' + + # Exclusions for 'readability-*' + - '-readability-named-parameter' + - '-readability-function-size' + - '-readability-identifier-length' + - '-readability-uppercase-literal-suffix' + - '-readability-math-missing-parentheses' + - '-readability-operators-representation' + - '-readability-ambiguous-smartptr-reset-call' + - '-readability-implicit-bool-conversion' + - '-readability-braces-around-statements' + - '-readability-qualified-auto' 
+ - '-readability-container-data-pointer' + - '-readability-avoid-unconditional-preprocessor-if' + - '-readability-function-cognitive-complexity' + - '-readability-identifier-naming' + - '-readability-enum-initial-value' \ No newline at end of file From 9c4a6d69d3b278620350784f55f844417acd9336 Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Sun, 22 Feb 2026 16:38:36 +0100 Subject: [PATCH 05/24] added the clang format target --- CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5bfeed9..282e70f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -142,3 +142,8 @@ include(${FL_ROOT_DIR}/pkg/CMakeLists.txt) # --------------------------- Cleanup --------------------------- setup_install_targets(INSTALL_TARGETS ${INSTALLABLE_TARGETS}) + +# --------------------------- Other ------------------------------ +include(fm_target_utilities) +fm_glob_cpp(FM_CPP) +fm_add_clang_format_target(clang-format) From d95ab79dd8017fa50ce3da61e86c1f4ccce94710 Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Sun, 22 Feb 2026 16:50:52 +0100 Subject: [PATCH 06/24] moved clang format version to 19 (maybe that works) --- .github/workflows/check-clang-format.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/check-clang-format.yml b/.github/workflows/check-clang-format.yml index 94501f5..20bc3ea 100644 --- a/.github/workflows/check-clang-format.yml +++ b/.github/workflows/check-clang-format.yml @@ -8,8 +8,10 @@ jobs: name: Format check runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - name: Run clang-format style check for C/C++ programs. 
+ - uses: actions/checkout@v4 + + - name: Run clang-format style check uses: jidicula/clang-format-action@v4.2.0 with: - clang-format-version: '21' + clang-format-version: '19' + check-path: 'flashlight' From 3f19fe02f30e8a63e69edece8628c768ae6d17bd Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Sun, 22 Feb 2026 16:59:20 +0100 Subject: [PATCH 07/24] updated the check-clang-format away from jidicula bc it just didnt work --- .github/workflows/check-clang-format.yml | 29 ++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/.github/workflows/check-clang-format.yml b/.github/workflows/check-clang-format.yml index 20bc3ea..1a89c08 100644 --- a/.github/workflows/check-clang-format.yml +++ b/.github/workflows/check-clang-format.yml @@ -3,6 +3,21 @@ name: Clang Format on: [push, pull_request] permissions: { contents: read } + +# --------------------------------------------------------- +# CONFIG +# --------------------------------------------------------- +env: + CLANG_FORMAT_VERSION: "21.1.0" + CHECK_PATH: "flashlight" + FILE_EXTENSIONS: "c|cpp|h|hpp|cu" + + + + +# --------------------------------------------------------- +# JOB +# --------------------------------------------------------- jobs: formatting-check: name: Format check @@ -10,8 +25,14 @@ jobs: steps: - uses: actions/checkout@v4 + - name: Install clang-format + # Installs the exact version specified in the env block + run: pipx install clang-format==${{ env.CLANG_FORMAT_VERSION }} + - name: Run clang-format style check - uses: jidicula/clang-format-action@v4.2.0 - with: - clang-format-version: '19' - check-path: 'flashlight' + run: | + find ${{ env.CHECK_PATH }} \ + -type f \ + -regextype posix-extended \ + -regex ".*\.(${{ env.FILE_EXTENSIONS }})$" \ + | xargs -r clang-format --style=file --dry-run --Werror \ No newline at end of file From 482d739e94841213027e73b08eac80bb530a7aa6 Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Sun, 22 Feb 2026 17:02:42 +0100 Subject: 
[PATCH 08/24] added files per thread to hopefully speed up the formatting check --- .github/workflows/check-clang-format.yml | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/check-clang-format.yml b/.github/workflows/check-clang-format.yml index 1a89c08..a2609c3 100644 --- a/.github/workflows/check-clang-format.yml +++ b/.github/workflows/check-clang-format.yml @@ -11,7 +11,7 @@ env: CLANG_FORMAT_VERSION: "21.1.0" CHECK_PATH: "flashlight" FILE_EXTENSIONS: "c|cpp|h|hpp|cu" - + FILES_PER_THREAD: "40" @@ -26,7 +26,6 @@ jobs: - uses: actions/checkout@v4 - name: Install clang-format - # Installs the exact version specified in the env block run: pipx install clang-format==${{ env.CLANG_FORMAT_VERSION }} - name: Run clang-format style check @@ -35,4 +34,4 @@ jobs: -type f \ -regextype posix-extended \ -regex ".*\.(${{ env.FILE_EXTENSIONS }})$" \ - | xargs -r clang-format --style=file --dry-run --Werror \ No newline at end of file + | xargs -r -P $(nproc) -n ${{ env.FILES_PER_THREAD }} clang-format --style=file --dry-run --Werror \ No newline at end of file From 6e8ac15d8b1a38836eb2ac7b2a29fc49d0c1a1af Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Sun, 22 Feb 2026 17:04:35 +0100 Subject: [PATCH 09/24] updated the "on" to only run for pr's and pushes to master --- .github/workflows/check-clang-format.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/check-clang-format.yml b/.github/workflows/check-clang-format.yml index a2609c3..945c4ec 100644 --- a/.github/workflows/check-clang-format.yml +++ b/.github/workflows/check-clang-format.yml @@ -1,6 +1,8 @@ name: Clang Format -on: [push, pull_request] +on: + push: { branches: ["master", "_master/add_ci"] } + pull_request: { branches: ["master"] } permissions: { contents: read } From 86d412f2dcf6ef2e2b0e798d7afb5b4b5d8fbdde Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Sun, 22 Feb 2026 21:35:51 +0100 Subject: [PATCH 10/24] a few changes to 
formatting and added the format target --- .clang-format | 4 +- CMakeLists.txt | 4 +- cmake/utils/fm_target_utilities.cmake | 114 ++++++++++++++++++++------ 3 files changed, 93 insertions(+), 29 deletions(-) diff --git a/.clang-format b/.clang-format index 7da0a7c..9e9ea45 100644 --- a/.clang-format +++ b/.clang-format @@ -37,6 +37,7 @@ BraceWrapping: SplitEmptyRecord: false SplitEmptyNamespace: false BreakBeforeBraces: Custom +BreakBeforeTemplateCloser: true BreakBeforeTernaryOperators: true BreakConstructorInitializers: AfterColon BreakInheritanceList: BeforeComma @@ -89,6 +90,7 @@ TabWidth: 4 UseTab: Never PenaltyBreakBeforeFirstCallParameter: 0 PenaltyBreakTemplateDeclaration: 0 -PenaltyReturnTypeOnItsOwnLine: 10 +PenaltyReturnTypeOnItsOwnLine: 1000 +PenaltyBreakScopeResolution: 1000 ... diff --git a/CMakeLists.txt b/CMakeLists.txt index 282e70f..69935d4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -145,5 +145,5 @@ setup_install_targets(INSTALL_TARGETS ${INSTALLABLE_TARGETS}) # --------------------------- Other ------------------------------ include(fm_target_utilities) -fm_glob_cpp(FM_CPP) -fm_add_clang_format_target(clang-format) +fm_glob_cpp(FM_CPP "flashlight/*") +fm_add_clang_format_target(clang-format ${FM_CPP}) diff --git a/cmake/utils/fm_target_utilities.cmake b/cmake/utils/fm_target_utilities.cmake index ce2e437..a72071f 100644 --- a/cmake/utils/fm_target_utilities.cmake +++ b/cmake/utils/fm_target_utilities.cmake @@ -84,7 +84,7 @@ endfunction() #]] function(fm_glob_cpp OUT_VAR) - fm_glob(${OUT_VAR} ${ARGN} PATTERNS "*.cpp" "*.hpp" "*.inl") + fm_glob(${OUT_VAR} ${ARGN} PATTERNS "*.cpp" "*.hpp" "*.inl" "*.h") set(${OUT_VAR} ${${OUT_VAR}} PARENT_SCOPE) endfunction() @@ -366,58 +366,120 @@ endfunction() .. code-block:: cmake - fm_add_clang_format_target() + fm_add_clang_format_target( [OPTIONAL] ) - Creates a custom target named "format" that runs clang-format on specified files. + Creates a custom target that runs clang-format on specified files. 
+ If OPTIONAL is specified, does not error and skips target creation if clang-format is not found. + :param target_name: Name of the custom target to create + :param OPTIONAL: If specified, do not raise FATAL_ERROR if clang-format is not found :param files: List of source files to format - :type files: list of file paths - :pre: clang-format executable is available in PATH - :post: A custom target named "format" is created that formats the specified files in-place + :post: A custom target is created if found, or configuration terminates with FATAL_ERROR if not found (unless OPTIONAL) .. note:: - The format target uses ``-i`` flag to format files in-place - The ``-style=file`` flag means clang-format will look for a .clang-format configuration file - Files are formatted relative to CMAKE_SOURCE_DIR - .. warning:: - This function will fail if clang-format is not found in PATH. + .. seealso:: + - ``fm_find_clang_format(OPTIONAL)`` from fm_tool_utilities to locate clang-format optionally + +#]] +function(fm_add_clang_format_target TARGET_NAME) + cmake_parse_arguments(PARSE_ARGV 1 ARG "OPTIONAL" "" "") + + # Use a different variable name to avoid conflicts with the parsed ARG_OPTIONAL boolean + if(ARG_OPTIONAL) + set(FIND_OPTIONAL_ARG "OPTIONAL") + else() + set(FIND_OPTIONAL_ARG "") + endif() + + fm_assert_not_empty("${TARGET_NAME}" REASON "add_clang_format_target requires a target name") + + # Use ARG_UNPARSED_ARGUMENTS instead of ARGN so "OPTIONAL" isn't treated as a file + set(FILES_TO_FORMAT ${ARG_UNPARSED_ARGUMENTS}) + fm_assert_not_empty("${FILES_TO_FORMAT}" REASON "no files provided") + + include(fm_tool_utilities) - **Example usage:** + # Pass the safely stored string to the find function + fm_find_clang_format(${FIND_OPTIONAL_ARG}) + if(NOT CLANG_FORMAT_EXECUTABLE) + return() + endif() + + set(FILE_LIST_PATH "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_files.txt") + string(REPLACE ";" "\n" FILES_TO_FORMAT_STR "${FILES_TO_FORMAT}") + file(WRITE 
"${FILE_LIST_PATH}" "${FILES_TO_FORMAT_STR}\n") + + add_custom_target( + ${TARGET_NAME} + COMMAND ${CLANG_FORMAT_EXECUTABLE} -i -style=file --files=${FILE_LIST_PATH} + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} + COMMENT "Formatting all source files with clang-format..." + VERBATIM + ) +endfunction() + +#[[.rst: +.. command:: fm_add_uncrustify_target .. code-block:: cmake - # Format specific files - fm_add_clang_format_target( - src/main.cpp - src/utils.cpp - include/header.hpp - ) + fm_add_uncrustify_target( [OPTIONAL] ) + + Creates a custom target that runs uncrustify on specified files. + If OPTIONAL is specified, does not error and skips target creation if uncrustify is not found. - # Then run: cmake --build . --target format + :param target_name: Name of the custom target to create + :param OPTIONAL: If specified, do not raise FATAL_ERROR if uncrustify is not found + :param files: List of source files to format + + :post: A custom target is created if found, or configuration terminates with FATAL_ERROR if not found (unless OPTIONAL) + + .. note:: + - The format target uses ``--replace`` and ``--no-backup`` flags to format files in-place + - The ``-F`` flag is used to pass the text file containing the list of files to format + - Files are formatted relative to CMAKE_SOURCE_DIR + - Uncrustify will look for an uncrustify.cfg file in the working directory or rely on the UNCRUSTIFY_CONFIG environment variable. .. 
seealso:: - - ``fm_find_clang_format()`` from fm_tool_utilities to locate clang-format - - Create a .clang-format file in your project root to define formatting style + - ``fm_find_uncrustify(OPTIONAL)`` from fm_tool_utilities to locate uncrustify optionally #]] -function(fm_add_clang_format_target TARGET_NAME) - fm_assert_not_empty(${TARGET_NAME} REASON "add_clang_format_target requires a target name") +function(fm_add_uncrustify_target TARGET_NAME) + cmake_parse_arguments(PARSE_ARGV 1 ARG "OPTIONAL" "" "") + + # Use a different variable name to avoid being overwritten by cmake_parse_arguments + if(ARG_OPTIONAL) + set(FIND_OPTIONAL_ARG "OPTIONAL") + else() + set(FIND_OPTIONAL_ARG "") + endif() - include(fm_tool_utilities) - fm_find_clang_format() + fm_assert_not_empty("${TARGET_NAME}" REASON "add_uncrustify_target requires a target name") + + set(FILES_TO_FORMAT ${ARG_UNPARSED_ARGUMENTS}) + fm_assert_not_empty("${FILES_TO_FORMAT}" REASON "no files provided") - set(FILES_TO_FORMAT ${ARGN}) + include(fm_tool_utilities) - + fm_find_uncrustify(${FIND_OPTIONAL_ARG}) + if(NOT UNCRUSTIFY_EXECUTABLE) + return() + endif() + set(FILE_LIST_PATH "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_files.txt") + string(REPLACE ";" "\n" FILES_TO_FORMAT_STR "${FILES_TO_FORMAT}") + file(WRITE "${FILE_LIST_PATH}" "${FILES_TO_FORMAT_STR}\n") add_custom_target( ${TARGET_NAME} - COMMAND ${CLANG_FORMAT_EXECUTABLE} -i -style=file ${FILES_TO_FORMAT} + COMMAND ${UNCRUSTIFY_EXECUTABLE} --replace --no-backup -F ${FILE_LIST_PATH} WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} - COMMENT "Formatting all source files with clang-format..." + COMMENT "Formatting all source files with uncrustify..." 
VERBATIM ) endfunction() From 906c05da1a645945e381cb3779458de1f44f6404 Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Sun, 22 Feb 2026 21:39:33 +0100 Subject: [PATCH 11/24] added uncrustify to try --- CMakeLists.txt | 2 ++ cmake/utils/fm_tool_utilities.cmake | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 69935d4..f9d0ed9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -146,4 +146,6 @@ setup_install_targets(INSTALL_TARGETS ${INSTALLABLE_TARGETS}) # --------------------------- Other ------------------------------ include(fm_target_utilities) fm_glob_cpp(FM_CPP "flashlight/*") + fm_add_clang_format_target(clang-format ${FM_CPP}) +fm_add_uncrustify_target(uncrustify-format ${FM_CPP}) diff --git a/cmake/utils/fm_tool_utilities.cmake b/cmake/utils/fm_tool_utilities.cmake index 48c144f..da0e896 100644 --- a/cmake/utils/fm_tool_utilities.cmake +++ b/cmake/utils/fm_tool_utilities.cmake @@ -114,3 +114,31 @@ function(fm_find_clang_format) set(CLANG_FORMAT_EXECUTABLE ${CLANG_FORMAT_EXECUTABLE} PARENT_SCOPE) endfunction() + +#[[.rst: +.. command:: fm_find_uncrustify + + .. code-block:: cmake + + fm_find_uncrustify([OPTIONAL]) + + Locates a required clang-format executable and exports its path to the parent scope. + If OPTIONAL is specified, does not error if clang-format is not found. + + :post: UNCRUSTIFY_EXECUTABLE is set in PARENT_SCOPE with the full path to clang-format, or configuration terminates with FATAL_ERROR if not found and not OPTIONAL + + .. seealso:: + Use ``fm_add_uncrustify_target()`` from fm_target_utilities to create a format target. 
+ +#]] +function(fm_find_uncrustify) + cmake_parse_arguments(PARSE_ARGV 0 ARG "OPTIONAL" "" "") + + if(ARG_OPTIONAL) + fm_find_program(UNCRUSTIFY_EXECUTABLE clang-format OPTIONAL) + else() + fm_find_program(UNCRUSTIFY_EXECUTABLE clang-format) + endif() + + set(UNCRUSTIFY_EXECUTABLE ${UNCRUSTIFY_EXECUTABLE} PARENT_SCOPE) +endfunction() \ No newline at end of file From fbdf4452c31d5743d25d5bbe86e3032a55616270 Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Sun, 22 Feb 2026 21:43:28 +0100 Subject: [PATCH 12/24] some corrections and uncrustify file added --- cmake/utils/fm_tool_utilities.cmake | 10 +- uncrustify.cfg | 249 ++++++++++++++++++++++++++++ 2 files changed, 254 insertions(+), 5 deletions(-) create mode 100644 uncrustify.cfg diff --git a/cmake/utils/fm_tool_utilities.cmake b/cmake/utils/fm_tool_utilities.cmake index da0e896..1e0ff72 100644 --- a/cmake/utils/fm_tool_utilities.cmake +++ b/cmake/utils/fm_tool_utilities.cmake @@ -122,10 +122,10 @@ endfunction() fm_find_uncrustify([OPTIONAL]) - Locates a required clang-format executable and exports its path to the parent scope. - If OPTIONAL is specified, does not error if clang-format is not found. + Locates a required uncrustify executable and exports its path to the parent scope. + If OPTIONAL is specified, does not error if uncrustify is not found. - :post: UNCRUSTIFY_EXECUTABLE is set in PARENT_SCOPE with the full path to clang-format, or configuration terminates with FATAL_ERROR if not found and not OPTIONAL + :post: UNCRUSTIFY_EXECUTABLE is set in PARENT_SCOPE with the full path to uncrustify, or configuration terminates with FATAL_ERROR if not found and not OPTIONAL .. seealso:: Use ``fm_add_uncrustify_target()`` from fm_target_utilities to create a format target. 
@@ -135,9 +135,9 @@ function(fm_find_uncrustify) cmake_parse_arguments(PARSE_ARGV 0 ARG "OPTIONAL" "" "") if(ARG_OPTIONAL) - fm_find_program(UNCRUSTIFY_EXECUTABLE clang-format OPTIONAL) + fm_find_program(UNCRUSTIFY_EXECUTABLE uncrustify OPTIONAL) else() - fm_find_program(UNCRUSTIFY_EXECUTABLE clang-format) + fm_find_program(UNCRUSTIFY_EXECUTABLE uncrustify) endif() set(UNCRUSTIFY_EXECUTABLE ${UNCRUSTIFY_EXECUTABLE} PARENT_SCOPE) diff --git a/uncrustify.cfg b/uncrustify.cfg new file mode 100644 index 0000000..d070a12 --- /dev/null +++ b/uncrustify.cfg @@ -0,0 +1,249 @@ +newlines = auto +input_tab_size = 4 +output_tab_size = 4 +code_width = 110 +ls_code_width = true +ls_func_split_full = true +indent_columns = 4 +indent_continue = 4 +indent_with_tabs = 0 +indent_namespace = true +indent_namespace_inner_only = true +indent_class = true +indent_class_colon = true +indent_constr_colon = true +indent_switch_case = 4 +indent_case_brace = 0 +indent_label = 1 +indent_access_spec = -4 +indent_access_spec_body = false +indent_ternary_operator = 1 +indent_func_call_param = false +indent_func_def_param = false +indent_func_proto_param = false +indent_func_class_param = false +indent_func_ctor_var_param = false +indent_template_param = false +use_indent_func_call_param = false +align_func_params = false +indent_paren_close = 0 +align_keep_tabs = false +align_with_tabs = false +align_on_tabstop = false +align_var_def_span = 0 +align_assign_span = 0 +align_enum_equ_span = 0 +align_var_class_span = 0 +align_var_struct_span = 0 +align_struct_init_span = 0 +align_typedef_span = 0 +align_right_cmt_span = 0 +align_func_proto_span = 0 +align_nl_cont = 2 +align_left_shift = false +sp_arith = force +sp_assign = force +sp_cpp_lambda_assign = force +sp_cpp_lambda_square_paren = remove +sp_cpp_lambda_square_brace = force +sp_cpp_lambda_argument_list = force +sp_cpp_lambda_paren_brace = force +sp_enum_assign = force +sp_bool = force +sp_compare = force +sp_inside_paren = remove 
+sp_paren_paren = remove +sp_cparen_oparen = remove +sp_paren_brace = force +sp_after_type = force +sp_template_angle = remove +sp_before_angle = remove +sp_inside_angle = remove +sp_inside_angle_empty = remove +sp_angle_word = force +sp_angle_shift = remove +sp_permit_cpp11_shift = true +sp_before_sparen = remove +sp_inside_sparen = remove +sp_inside_for = remove +sp_sparen_brace = force +sp_special_semi = remove +sp_before_semi = remove +sp_before_semi_for = remove +sp_before_semi_for_empty = remove +sp_after_semi = force +sp_after_semi_for = force +sp_after_semi_for_empty = remove +sp_before_square = remove +sp_before_squares = remove +sp_before_vardef_square = remove +sp_inside_square = remove +sp_after_comma = force +sp_before_comma = remove +sp_paren_comma = force +sp_after_class_colon = force +sp_before_class_colon = force +sp_after_constr_colon = force +sp_before_constr_colon = force +sp_before_case_colon = remove +sp_after_operator = remove +sp_after_operator_sym = remove +sp_after_cast = force +sp_inside_paren_cast = remove +sp_cpp_cast_paren = remove +sp_sizeof_paren = remove +sp_inside_braces_enum = force +sp_inside_braces_struct = force +sp_inside_braces = force +sp_inside_braces_empty = remove +sp_type_func = force +sp_func_proto_paren = remove +sp_func_def_paren = remove +sp_inside_fparens = remove +sp_inside_fparen = remove +sp_fparen_brace = force +sp_func_call_paren = remove +sp_func_call_paren_empty = remove +sp_func_class_paren = remove +sp_func_class_paren_empty = remove +sp_return_paren = force +sp_attribute_paren = remove +sp_defined_paren = remove +sp_throw_paren = force +sp_catch_paren = remove +sp_macro = force +sp_macro_func = force +sp_else_brace = force +sp_brace_else = force +sp_catch_brace = force +sp_brace_catch = force +sp_finally_brace = force +sp_brace_finally = force +sp_try_brace = force +sp_getset_brace = force +sp_before_dc = remove +sp_after_dc = remove +sp_not = remove +sp_inv = remove +sp_addr = remove +sp_member = remove 
+sp_deref = remove +sp_sign = remove +sp_incdec = remove +sp_before_nl_cont = force +sp_cond_colon = force +sp_cond_question = force +sp_case_label = force +sp_after_for_colon = force +sp_before_for_colon = force +sp_cmt_cpp_start = force +sp_endif_cmt = force +sp_after_new = force +sp_between_new_paren = remove +sp_before_tr_cmt = force +sp_num_before_tr_cmt = 1 +sp_before_ptr_star = remove +sp_after_ptr_star = force +sp_before_unnamed_ptr_star = remove +sp_between_ptr_star = remove +sp_after_ptr_star_func = force +sp_before_ptr_star_func = remove +sp_before_byref = remove +sp_before_unnamed_byref = remove +sp_after_byref = force +sp_after_byref_func = force +sp_before_byref_func = remove +nl_collapse_empty_body = true +nl_collapse_empty_body_functions = true +nl_assign_leave_one_liners = true +nl_class_leave_one_liners = true +nl_enum_leave_one_liners = true +nl_getset_leave_one_liners = true +nl_func_leave_one_liners = true +nl_cpp_lambda_leave_one_liners = true +nl_if_leave_one_liners = false +nl_while_leave_one_liners = false +nl_do_leave_one_liners = false +nl_for_leave_one_liners = false +nl_start_of_file = remove +nl_end_of_file = force +nl_end_of_file_min = 1 +nl_assign_brace = remove +nl_fcall_brace = remove +nl_enum_brace = remove +nl_struct_brace = remove +nl_union_brace = remove +nl_if_brace = remove +nl_brace_else = remove +nl_elseif_brace = remove +nl_else_brace = remove +nl_else_if = remove +nl_try_brace = remove +nl_for_brace = remove +nl_catch_brace = remove +nl_brace_catch = remove +nl_while_brace = remove +nl_do_brace = remove +nl_brace_while = remove +nl_switch_brace = remove +nl_multi_line_cond = false +nl_multi_line_sparen_open = force +nl_multi_line_sparen_close = force +nl_namespace_brace = remove +nl_class_brace = remove +nl_enum_own_lines = force +nl_func_type_name = remove +nl_func_type_name_class = remove +nl_func_proto_type_name = remove +nl_func_paren = remove +nl_func_def_paren = remove +nl_func_call_paren = remove +nl_fdef_brace = 
remove +nl_cpp_ldef_brace = remove +nl_after_semicolon = true +nl_after_brace_open = true +nl_after_brace_close = true +nl_after_vbrace_close = true +nl_max = 3 +nl_before_access_spec = 2 +nl_after_access_spec = 0 +nl_template_class = force +nl_template_class_decl = force +nl_template_class_def = force +nl_template_func = force +nl_template_func_decl = force +nl_template_func_def = force +nl_template_var = remove +nl_func_decl_start_multi_line = true +nl_func_def_start_multi_line = true +nl_func_decl_args_multi_line = true +nl_func_def_args_multi_line = true +nl_func_decl_end_multi_line = true +nl_func_def_end_multi_line = true +nl_func_call_start_multi_line = true +nl_func_call_args_multi_line = true +nl_func_call_end_multi_line = true +pos_arith = trail +pos_assign = trail +pos_bool = trail +pos_compare = trail +pos_conditional = lead +pos_comma = trail +pos_class_comma = lead +pos_class_colon = lead +pos_constr_comma = trail +pos_constr_colon = trail +nl_constr_colon = force +nl_constr_init_args = force +mod_full_brace_do = force +mod_full_brace_for = force +mod_full_brace_if = force +mod_full_brace_while = force +mod_full_brace_using = force +mod_paren_on_return = remove +mod_paren_on_throw = remove +mod_sort_include = false +mod_case_brace = remove +mod_remove_empty_return = true +pp_indent = remove +pp_indent_at_level = false \ No newline at end of file From 1b705b82fcf3b633d5d60b7b3cf25b32064d9e25 Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Mon, 23 Feb 2026 19:59:09 +0100 Subject: [PATCH 13/24] now using uncrustify - renamed formatting action - updated docs - removed clang format target - moved codecov and readthedocs to unused --- .clang-format | 96 ------------------- .github/CODEOWNERS | 7 +- ...-clang-format.yml => check-formatting.yml} | 16 ++-- CMakeLists.txt | 1 - CONTRIBUTING.md | 4 +- .../unused/.readthedocs.yml | 0 codecov.yml => ci/unused/codecov.yml | 0 cmake/utils/fm_target_utilities.cmake | 3 +- uncrustify.cfg | 37 ++++--- 9 files 
changed, 43 insertions(+), 121 deletions(-) delete mode 100644 .clang-format rename .github/workflows/{check-clang-format.yml => check-formatting.yml} (72%) rename .readthedocs.yml => ci/unused/.readthedocs.yml (100%) rename codecov.yml => ci/unused/codecov.yml (100%) diff --git a/.clang-format b/.clang-format deleted file mode 100644 index 9e9ea45..0000000 --- a/.clang-format +++ /dev/null @@ -1,96 +0,0 @@ ---- -Language: Cpp -BasedOnStyle: LLVM -AccessModifierOffset: -4 -AlignAfterOpenBracket: BlockIndent -AlignConsecutiveAssignments: false -AlignConsecutiveDeclarations: false -AlignOperands: false -AlignEscapedNewlines: Left -AllowAllArgumentsOnNextLine: false -AllowAllParametersOfDeclarationOnNextLine: false -AllowShortBlocksOnASingleLine: Empty -AllowShortCaseLabelsOnASingleLine: true -AllowShortFunctionsOnASingleLine: All -AllowShortIfStatementsOnASingleLine: false -AllowShortLambdasOnASingleLine: All -AllowShortLoopsOnASingleLine: false -AlwaysBreakAfterReturnType: None -AlwaysBreakTemplateDeclarations: Yes -BinPackArguments: false -BinPackParameters: false -BraceWrapping: - AfterCaseLabel: false - AfterClass: false - AfterControlStatement: false - AfterEnum: false - AfterFunction: false - AfterNamespace: false - AfterStruct: false - AfterUnion: false - AfterExternBlock: false - BeforeCatch: false - BeforeElse: false - BeforeLambdaBody: false - BeforeWhile: false - SplitEmptyFunction: false - SplitEmptyRecord: false - SplitEmptyNamespace: false -BreakBeforeBraces: Custom -BreakBeforeTemplateCloser: true -BreakBeforeTernaryOperators: true -BreakConstructorInitializers: AfterColon -BreakInheritanceList: BeforeComma -ColumnLimit: 110 -CompactNamespaces: false -ConstructorInitializerAllOnOneLineOrOnePerLine: false -ContinuationIndentWidth: 4 -EmptyLineAfterAccessModifier: Never -EmptyLineBeforeAccessModifier: LogicalBlock -FixNamespaceComments: false -IncludeBlocks: Preserve -IncludeCategories: - - Regex: '^".*' - Priority: 1 - - Regex: '^<.*' - Priority: 2 - - 
Regex: '.*' - Priority: 3 -IncludeIsMainRegex: '([-_](test|unittest))?$' -IndentCaseBlocks: false -IndentCaseLabels: true -IndentGotoLabels: true -IndentPPDirectives: None -IndentWidth: 4 -IndentWrappedFunctionNames: false -InsertNewlineAtEOF: true -MacroBlockBegin: '' -MacroBlockEnd: '' -MaxEmptyLinesToKeep: 2 -NamespaceIndentation: Inner -PointerAlignment: Left -RequiresClausePosition: WithPreceding -SpaceAfterCStyleCast: true -SpaceAfterLogicalNot: false -SpaceAfterTemplateKeyword: false -SpaceBeforeAssignmentOperators: true -SpaceBeforeCtorInitializerColon: true -SpaceBeforeInheritanceColon: true -SpaceBeforeParens: Never -SpaceBeforeRangeBasedForLoopColon: true -SpaceBeforeSquareBrackets: false -SpaceInEmptyParentheses: false -SpacesBeforeTrailingComments: 1 -SpacesInAngles: false -SpacesInConditionalStatement: false -SpacesInCStyleCastParentheses: false -SpacesInParentheses: false -SpacesInSquareBrackets: false -TabWidth: 4 -UseTab: Never -PenaltyBreakBeforeFirstCallParameter: 0 -PenaltyBreakTemplateDeclaration: 0 -PenaltyReturnTypeOnItsOwnLine: 1000 -PenaltyBreakScopeResolution: 1000 -... 
- diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 53f2797..bbce316 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,5 +2,10 @@ .github/workflows/ @codethinki ci/docker/ @codethinki +CITATION @codethinki LICENSE @codethinki -.clang-format @codethinki +FLASHLIGHT_LICENSE @codethinki + +.clang_tidy @codethinki +uncrustify.cfg @codethinki + diff --git a/.github/workflows/check-clang-format.yml b/.github/workflows/check-formatting.yml similarity index 72% rename from .github/workflows/check-clang-format.yml rename to .github/workflows/check-formatting.yml index 945c4ec..5593e16 100644 --- a/.github/workflows/check-clang-format.yml +++ b/.github/workflows/check-formatting.yml @@ -1,4 +1,4 @@ -name: Clang Format +name: Uncrustify Format on: push: { branches: ["master", "_master/add_ci"] } @@ -10,13 +10,11 @@ permissions: { contents: read } # CONFIG # --------------------------------------------------------- env: - CLANG_FORMAT_VERSION: "21.1.0" + UNCRUSTIFY_CONFIG: "uncrustify.cfg" # Make sure this matches your config file's name/path CHECK_PATH: "flashlight" FILE_EXTENSIONS: "c|cpp|h|hpp|cu" FILES_PER_THREAD: "40" - - # --------------------------------------------------------- # JOB # --------------------------------------------------------- @@ -27,13 +25,15 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Install clang-format - run: pipx install clang-format==${{ env.CLANG_FORMAT_VERSION }} + - name: Install uncrustify + run: | + sudo apt-get update + sudo apt-get install -y uncrustify - - name: Run clang-format style check + - name: Run uncrustify style check run: | find ${{ env.CHECK_PATH }} \ -type f \ -regextype posix-extended \ -regex ".*\.(${{ env.FILE_EXTENSIONS }})$" \ - | xargs -r -P $(nproc) -n ${{ env.FILES_PER_THREAD }} clang-format --style=file --dry-run --Werror \ No newline at end of file + | xargs -r -P $(nproc) -n ${{ env.FILES_PER_THREAD }} uncrustify -c ${{ env.UNCRUSTIFY_CONFIG }} --check \ No newline at end of file diff 
--git a/CMakeLists.txt b/CMakeLists.txt index f9d0ed9..060360e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -147,5 +147,4 @@ setup_install_targets(INSTALL_TARGETS ${INSTALLABLE_TARGETS}) include(fm_target_utilities) fm_glob_cpp(FM_CPP "flashlight/*") -fm_add_clang_format_target(clang-format ${FM_CPP}) fm_add_uncrustify_target(uncrustify-format ${FM_CPP}) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ecb0acc..88ba7a6 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,11 +4,11 @@ flashmini is still under development; we appreciate any contributions. ## Pull Requests We actively welcome your pull requests. -1. Fork the repo and create your branch from `master`. +1. Fork the repo and create your branch from `master`. Naming style of `_master/your_feature` is preferred 2. If you've added code that should be tested, add tests. 3. If you've changed APIs, update [and build](docs/README.md) the documentation (to check correctness - don't submit built documentation). 4. Ensure the test suite passes. -5. Make sure your code lints and run `clang-format` given the provided configuration. +5. Make sure your code lints and run `uncrustify format` given the provided configuration. Alternatively you can build the `uncrustify-format` target. ## Issues We use [GitHub issues](https://github.com/flashmini/flashmini/issues) to track public bugs. When filing, a bug, please make sure your description is clear and include sufficient instructions to reproduce the issue (for instance, your OS, compiler version, and selected backend). 
diff --git a/.readthedocs.yml b/ci/unused/.readthedocs.yml similarity index 100% rename from .readthedocs.yml rename to ci/unused/.readthedocs.yml diff --git a/codecov.yml b/ci/unused/codecov.yml similarity index 100% rename from codecov.yml rename to ci/unused/codecov.yml diff --git a/cmake/utils/fm_target_utilities.cmake b/cmake/utils/fm_target_utilities.cmake index a72071f..b625a2b 100644 --- a/cmake/utils/fm_target_utilities.cmake +++ b/cmake/utils/fm_target_utilities.cmake @@ -437,6 +437,7 @@ endfunction() :param OPTIONAL: If specified, do not raise FATAL_ERROR if uncrustify is not found :param files: List of source files to format + :pre expects uncrustify.cfg in root directory :post: A custom target is created if found, or configuration terminates with FATAL_ERROR if not found (unless OPTIONAL) .. note:: @@ -477,7 +478,7 @@ function(fm_add_uncrustify_target TARGET_NAME) add_custom_target( ${TARGET_NAME} - COMMAND ${UNCRUSTIFY_EXECUTABLE} --replace --no-backup -F ${FILE_LIST_PATH} + COMMAND ${UNCRUSTIFY_EXECUTABLE} --replace --no-backup -c uncrustify.cfg -q -F ${FILE_LIST_PATH} WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} COMMENT "Formatting all source files with uncrustify..." 
VERBATIM diff --git a/uncrustify.cfg b/uncrustify.cfg index d070a12..8e5673b 100644 --- a/uncrustify.cfg +++ b/uncrustify.cfg @@ -2,7 +2,7 @@ newlines = auto input_tab_size = 4 output_tab_size = 4 code_width = 110 -ls_code_width = true +ls_code_width = false ls_func_split_full = true indent_columns = 4 indent_continue = 4 @@ -24,9 +24,14 @@ indent_func_proto_param = false indent_func_class_param = false indent_func_ctor_var_param = false indent_template_param = false -use_indent_func_call_param = false +use_indent_func_call_param = true +donot_indent_func_def_close_paren = true align_func_params = false -indent_paren_close = 0 +indent_paren_close = 2 +indent_align_paren = false +indent_paren_after_func_def = false +indent_paren_after_func_decl = false +indent_paren_after_func_call = false align_keep_tabs = false align_with_tabs = false align_on_tabstop = false @@ -46,7 +51,8 @@ sp_assign = force sp_cpp_lambda_assign = force sp_cpp_lambda_square_paren = remove sp_cpp_lambda_square_brace = force -sp_cpp_lambda_argument_list = force +sp_cpp_lambda_argument_list = remove +sp_cpp_lambda_argument_list_empty = remove sp_cpp_lambda_paren_brace = force sp_enum_assign = force sp_bool = force @@ -92,13 +98,15 @@ sp_after_cast = force sp_inside_paren_cast = remove sp_cpp_cast_paren = remove sp_sizeof_paren = remove -sp_inside_braces_enum = force -sp_inside_braces_struct = force +sp_inside_type_brace_init_lst = remove +sp_inside_braces_enum = remove +sp_inside_braces_struct = remove sp_inside_braces = force sp_inside_braces_empty = remove sp_type_func = force sp_func_proto_paren = remove sp_func_def_paren = remove +sp_func_type_paren = remove sp_inside_fparens = remove sp_inside_fparen = remove sp_fparen_brace = force @@ -223,17 +231,17 @@ nl_func_def_end_multi_line = true nl_func_call_start_multi_line = true nl_func_call_args_multi_line = true nl_func_call_end_multi_line = true -pos_arith = trail +pos_arith = lead pos_assign = trail -pos_bool = trail -pos_compare = trail 
+pos_bool = lead +pos_compare = lead pos_conditional = lead pos_comma = trail pos_class_comma = lead pos_class_colon = lead -pos_constr_comma = trail +pos_constr_comma = trail_force pos_constr_colon = trail -nl_constr_colon = force +nl_constr_colon = remove nl_constr_init_args = force mod_full_brace_do = force mod_full_brace_for = force @@ -246,4 +254,9 @@ mod_sort_include = false mod_case_brace = remove mod_remove_empty_return = true pp_indent = remove -pp_indent_at_level = false \ No newline at end of file +pp_indent_at_level = false +sp_before_ellipsis = remove +sp_type_ellipsis = remove +sp_parameter_pack_ellipsis = remove +sp_ellipsis_parameter_pack = force +sp_ptr_type_ellipsis = remove \ No newline at end of file From a3d14c0394bda47c58f3af1ff3675ecac0e73513 Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Mon, 23 Feb 2026 19:59:45 +0100 Subject: [PATCH 14/24] formatted everything with uncrustify --- flashlight/fl/autograd/Functions.cpp | 3121 +++++++++-------- flashlight/fl/autograd/Functions.h | 185 +- flashlight/fl/autograd/Utils.cpp | 5 +- flashlight/fl/autograd/Utils.h | 3 +- flashlight/fl/autograd/Variable.cpp | 335 +- flashlight/fl/autograd/Variable.h | 554 +-- .../fl/autograd/tensor/AutogradExtension.h | 285 +- .../tensor/AutogradExtensionBackends.h | 12 +- flashlight/fl/autograd/tensor/AutogradOps.cpp | 593 ++-- flashlight/fl/autograd/tensor/AutogradOps.h | 387 +- .../tensor/backend/cudnn/BatchNorm.cpp | 405 ++- .../autograd/tensor/backend/cudnn/Conv2D.cpp | 1487 ++++---- .../backend/cudnn/CudnnAutogradExtension.cpp | 36 +- .../backend/cudnn/CudnnAutogradExtension.h | 273 +- .../tensor/backend/cudnn/CudnnUtils.cpp | 541 +-- .../tensor/backend/cudnn/CudnnUtils.h | 111 +- .../autograd/tensor/backend/cudnn/Pool2D.cpp | 175 +- .../fl/autograd/tensor/backend/cudnn/RNN.cpp | 692 ++-- .../tensor/backend/onednn/BatchNorm.cpp | 459 +-- .../autograd/tensor/backend/onednn/Conv2D.cpp | 900 ++--- .../tensor/backend/onednn/DnnlUtils.cpp | 155 +- 
.../tensor/backend/onednn/DnnlUtils.h | 121 +- .../onednn/OneDnnAutogradExtension.cpp | 7 +- .../backend/onednn/OneDnnAutogradExtension.h | 245 +- .../autograd/tensor/backend/onednn/Pool2D.cpp | 397 ++- .../fl/autograd/tensor/backend/onednn/RNN.cpp | 1069 +++--- flashlight/fl/common/Defines.cpp | 33 +- flashlight/fl/common/Defines.h | 135 +- flashlight/fl/common/DevicePtr.cpp | 48 +- flashlight/fl/common/DevicePtr.h | 64 +- flashlight/fl/common/DynamicBenchmark.cpp | 37 +- flashlight/fl/common/DynamicBenchmark.h | 367 +- flashlight/fl/common/Histogram.cpp | 48 +- flashlight/fl/common/Histogram.h | 303 +- flashlight/fl/common/Logging.cpp | 335 +- flashlight/fl/common/Logging.h | 170 +- flashlight/fl/common/Plugin.cpp | 68 +- flashlight/fl/common/Plugin.h | 24 +- flashlight/fl/common/Serialization-inl.h | 253 +- flashlight/fl/common/Serialization.h | 78 +- flashlight/fl/common/Timer.cpp | 6 +- flashlight/fl/common/Timer.h | 23 +- flashlight/fl/common/Types.h | 4 +- flashlight/fl/common/Utils.cpp | 206 +- flashlight/fl/common/Utils.h | 53 +- flashlight/fl/common/WinUtility.cpp | 99 +- flashlight/fl/common/WinUtility.h | 4 +- flashlight/fl/common/stacktrace/Backward.cpp | 14 +- flashlight/fl/common/threadpool/ThreadPool.h | 177 +- .../fl/contrib/modules/AdaptiveEmbedding.cpp | 157 +- .../fl/contrib/modules/AdaptiveEmbedding.h | 71 +- .../fl/contrib/modules/AsymmetricConv1D.cpp | 128 +- .../fl/contrib/modules/AsymmetricConv1D.h | 71 +- flashlight/fl/contrib/modules/Conformer.cpp | 459 +-- flashlight/fl/contrib/modules/Conformer.h | 118 +- .../fl/contrib/modules/PositionEmbedding.cpp | 64 +- .../fl/contrib/modules/PositionEmbedding.h | 46 +- .../fl/contrib/modules/RawWavSpecAugment.cpp | 313 +- .../fl/contrib/modules/RawWavSpecAugment.h | 189 +- flashlight/fl/contrib/modules/Residual.cpp | 269 +- flashlight/fl/contrib/modules/Residual.h | 218 +- .../modules/SinusoidalPositionEmbedding.cpp | 129 +- .../modules/SinusoidalPositionEmbedding.h | 61 +- 
flashlight/fl/contrib/modules/SpecAugment.cpp | 123 +- flashlight/fl/contrib/modules/SpecAugment.h | 110 +- flashlight/fl/contrib/modules/TDSBlock.cpp | 130 +- flashlight/fl/contrib/modules/TDSBlock.h | 75 +- flashlight/fl/contrib/modules/Transformer.cpp | 329 +- flashlight/fl/contrib/modules/Transformer.h | 104 +- flashlight/fl/contrib/modules/modules.h | 2 +- flashlight/fl/dataset/BatchDataset.cpp | 116 +- flashlight/fl/dataset/BatchDataset.h | 112 +- flashlight/fl/dataset/BlobDataset.cpp | 348 +- flashlight/fl/dataset/BlobDataset.h | 297 +- flashlight/fl/dataset/ConcatDataset.cpp | 40 +- flashlight/fl/dataset/ConcatDataset.h | 57 +- flashlight/fl/dataset/Dataset.h | 112 +- flashlight/fl/dataset/DatasetIterator.h | 98 +- flashlight/fl/dataset/FileBlobDataset.cpp | 134 +- flashlight/fl/dataset/FileBlobDataset.h | 67 +- flashlight/fl/dataset/MemoryBlobDataset.cpp | 39 +- flashlight/fl/dataset/MemoryBlobDataset.h | 36 +- flashlight/fl/dataset/MergeDataset.cpp | 37 +- flashlight/fl/dataset/MergeDataset.h | 57 +- flashlight/fl/dataset/PrefetchDataset.cpp | 84 +- flashlight/fl/dataset/PrefetchDataset.h | 63 +- flashlight/fl/dataset/ResampleDataset.cpp | 54 +- flashlight/fl/dataset/ResampleDataset.h | 100 +- flashlight/fl/dataset/ShuffleDataset.cpp | 31 +- flashlight/fl/dataset/ShuffleDataset.h | 72 +- flashlight/fl/dataset/SpanDataset.cpp | 35 +- flashlight/fl/dataset/SpanDataset.h | 67 +- flashlight/fl/dataset/TensorDataset.cpp | 54 +- flashlight/fl/dataset/TensorDataset.h | 42 +- flashlight/fl/dataset/TransformDataset.cpp | 29 +- flashlight/fl/dataset/TransformDataset.h | 63 +- flashlight/fl/dataset/Utils.cpp | 295 +- flashlight/fl/dataset/Utils.h | 12 +- flashlight/fl/distributed/DistributedApi.cpp | 59 +- flashlight/fl/distributed/DistributedApi.h | 24 +- flashlight/fl/distributed/FileStore.cpp | 128 +- flashlight/fl/distributed/FileStore.h | 34 +- flashlight/fl/distributed/LRUCache.h | 112 +- .../backend/cpu/DistributedBackend.cpp | 223 +- 
.../backend/cuda/DistributedBackend.cpp | 870 ++--- .../backend/stub/DistributedBackend.cpp | 46 +- .../reducers/CoalescingReducer.cpp | 67 +- .../distributed/reducers/CoalescingReducer.h | 112 +- .../fl/distributed/reducers/InlineReducer.cpp | 8 +- .../fl/distributed/reducers/InlineReducer.h | 42 +- flashlight/fl/distributed/reducers/Reducer.h | 32 +- .../fl/examples/AdaptiveClassification.cpp | 149 +- flashlight/fl/examples/Benchmark.cpp | 232 +- flashlight/fl/examples/Classification.cpp | 102 +- .../fl/examples/DistributedTraining.cpp | 181 +- flashlight/fl/examples/LinearRegression.cpp | 72 +- flashlight/fl/examples/Mnist.cpp | 373 +- flashlight/fl/examples/Perceptron.cpp | 112 +- flashlight/fl/examples/RnnClassification.cpp | 642 ++-- flashlight/fl/examples/RnnLm.cpp | 460 +-- flashlight/fl/examples/Xor.cpp | 170 +- flashlight/fl/meter/AverageValueMeter.cpp | 60 +- flashlight/fl/meter/AverageValueMeter.h | 44 +- flashlight/fl/meter/CountMeter.cpp | 12 +- flashlight/fl/meter/CountMeter.h | 56 +- flashlight/fl/meter/EditDistanceMeter.cpp | 87 +- flashlight/fl/meter/EditDistanceMeter.h | 265 +- flashlight/fl/meter/FrameErrorMeter.cpp | 36 +- flashlight/fl/meter/FrameErrorMeter.h | 54 +- flashlight/fl/meter/MSEMeter.cpp | 24 +- flashlight/fl/meter/MSEMeter.h | 46 +- flashlight/fl/meter/TimeMeter.cpp | 64 +- flashlight/fl/meter/TimeMeter.h | 58 +- flashlight/fl/meter/TopKMeter.cpp | 49 +- flashlight/fl/meter/TopKMeter.h | 36 +- flashlight/fl/nn/DistributedUtils.cpp | 39 +- flashlight/fl/nn/DistributedUtils.h | 6 +- flashlight/fl/nn/Init.cpp | 279 +- flashlight/fl/nn/Init.h | 98 +- flashlight/fl/nn/Utils.cpp | 229 +- flashlight/fl/nn/Utils.h | 35 +- flashlight/fl/nn/modules/Activations.cpp | 88 +- flashlight/fl/nn/modules/Activations.h | 296 +- flashlight/fl/nn/modules/AdaptiveSoftMax.cpp | 212 +- flashlight/fl/nn/modules/AdaptiveSoftMax.h | 129 +- flashlight/fl/nn/modules/BatchNorm.cpp | 154 +- flashlight/fl/nn/modules/BatchNorm.h | 219 +- 
flashlight/fl/nn/modules/Container.cpp | 156 +- flashlight/fl/nn/modules/Container.h | 337 +- flashlight/fl/nn/modules/Conv2D.cpp | 310 +- flashlight/fl/nn/modules/Conv2D.h | 338 +- flashlight/fl/nn/modules/Dropout.cpp | 14 +- flashlight/fl/nn/modules/Dropout.h | 26 +- flashlight/fl/nn/modules/Embedding.cpp | 49 +- flashlight/fl/nn/modules/Embedding.h | 104 +- flashlight/fl/nn/modules/Identity.cpp | 6 +- flashlight/fl/nn/modules/Identity.h | 14 +- flashlight/fl/nn/modules/LayerNorm.cpp | 227 +- flashlight/fl/nn/modules/LayerNorm.h | 121 +- flashlight/fl/nn/modules/Linear.cpp | 105 +- flashlight/fl/nn/modules/Linear.h | 120 +- flashlight/fl/nn/modules/Loss.cpp | 282 +- flashlight/fl/nn/modules/Loss.h | 254 +- flashlight/fl/nn/modules/Module.cpp | 94 +- flashlight/fl/nn/modules/Module.h | 299 +- flashlight/fl/nn/modules/Normalize.cpp | 33 +- flashlight/fl/nn/modules/Normalize.h | 63 +- flashlight/fl/nn/modules/Padding.cpp | 22 +- flashlight/fl/nn/modules/Padding.h | 40 +- flashlight/fl/nn/modules/Pool2D.cpp | 108 +- flashlight/fl/nn/modules/Pool2D.h | 84 +- flashlight/fl/nn/modules/PrecisionCast.cpp | 29 +- flashlight/fl/nn/modules/PrecisionCast.h | 56 +- flashlight/fl/nn/modules/RNN.cpp | 214 +- flashlight/fl/nn/modules/RNN.h | 249 +- flashlight/fl/nn/modules/Reorder.cpp | 23 +- flashlight/fl/nn/modules/Reorder.h | 30 +- flashlight/fl/nn/modules/Transform.cpp | 15 +- flashlight/fl/nn/modules/Transform.h | 53 +- flashlight/fl/nn/modules/View.cpp | 12 +- flashlight/fl/nn/modules/View.h | 30 +- flashlight/fl/nn/modules/WeightNorm.cpp | 148 +- flashlight/fl/nn/modules/WeightNorm.h | 137 +- flashlight/fl/optim/AMSgradOptimizer.cpp | 98 +- flashlight/fl/optim/AMSgradOptimizer.h | 76 +- flashlight/fl/optim/AdadeltaOptimizer.cpp | 98 +- flashlight/fl/optim/AdadeltaOptimizer.h | 66 +- flashlight/fl/optim/AdagradOptimizer.cpp | 62 +- flashlight/fl/optim/AdagradOptimizer.h | 47 +- flashlight/fl/optim/AdamOptimizer.cpp | 98 +- flashlight/fl/optim/AdamOptimizer.h | 76 +- 
flashlight/fl/optim/NAGOptimizer.cpp | 83 +- flashlight/fl/optim/NAGOptimizer.h | 53 +- flashlight/fl/optim/NovogradOptimizer.cpp | 80 +- flashlight/fl/optim/NovogradOptimizer.h | 72 +- flashlight/fl/optim/Optimizers.cpp | 11 +- flashlight/fl/optim/Optimizers.h | 75 +- flashlight/fl/optim/RMSPropOptimizer.cpp | 118 +- flashlight/fl/optim/RMSPropOptimizer.h | 76 +- flashlight/fl/optim/SGDOptimizer.cpp | 94 +- flashlight/fl/optim/SGDOptimizer.h | 60 +- flashlight/fl/optim/Utils.cpp | 36 +- flashlight/fl/optim/Utils.h | 3 +- flashlight/fl/runtime/CUDADevice.cpp | 4 +- flashlight/fl/runtime/CUDADevice.h | 51 +- flashlight/fl/runtime/CUDAStream.cpp | 146 +- flashlight/fl/runtime/CUDAStream.h | 212 +- flashlight/fl/runtime/CUDAUtils.cpp | 44 +- flashlight/fl/runtime/CUDAUtils.h | 14 +- flashlight/fl/runtime/Device.cpp | 43 +- flashlight/fl/runtime/Device.h | 218 +- flashlight/fl/runtime/DeviceManager.cpp | 97 +- flashlight/fl/runtime/DeviceManager.h | 153 +- flashlight/fl/runtime/DeviceType.cpp | 20 +- flashlight/fl/runtime/DeviceType.h | 4 +- flashlight/fl/runtime/Stream.cpp | 9 +- flashlight/fl/runtime/Stream.h | 191 +- flashlight/fl/runtime/SynchronousStream.cpp | 6 +- flashlight/fl/runtime/SynchronousStream.h | 22 +- flashlight/fl/tensor/Compute.cpp | 144 +- flashlight/fl/tensor/Compute.h | 24 +- flashlight/fl/tensor/DefaultTensorType.cpp | 4 +- flashlight/fl/tensor/DefaultTensorType.h | 8 +- flashlight/fl/tensor/Index.cpp | 62 +- flashlight/fl/tensor/Index.h | 212 +- flashlight/fl/tensor/Init.cpp | 13 +- flashlight/fl/tensor/Random.cpp | 6 +- flashlight/fl/tensor/Shape.cpp | 68 +- flashlight/fl/tensor/Shape.h | 156 +- flashlight/fl/tensor/TensorAdapter.cpp | 28 +- flashlight/fl/tensor/TensorAdapter.h | 614 ++-- flashlight/fl/tensor/TensorBackend.cpp | 70 +- flashlight/fl/tensor/TensorBackend.h | 567 ++- flashlight/fl/tensor/TensorBase.cpp | 746 ++-- flashlight/fl/tensor/TensorBase.h | 1396 ++++---- flashlight/fl/tensor/TensorExtension.cpp | 73 +- 
flashlight/fl/tensor/TensorExtension.h | 128 +- flashlight/fl/tensor/Types.cpp | 66 +- flashlight/fl/tensor/Types.h | 46 +- .../fl/tensor/backend/af/AdvancedIndex.cpp | 19 +- .../fl/tensor/backend/af/AdvancedIndex.h | 15 +- .../fl/tensor/backend/af/ArrayFireBLAS.cpp | 68 +- .../fl/tensor/backend/af/ArrayFireBackend.cpp | 468 +-- .../fl/tensor/backend/af/ArrayFireBackend.h | 454 +-- .../tensor/backend/af/ArrayFireBinaryOps.cpp | 146 +- .../tensor/backend/af/ArrayFireCPUStream.cpp | 16 +- .../fl/tensor/backend/af/ArrayFireCPUStream.h | 18 +- .../tensor/backend/af/ArrayFireReductions.cpp | 580 +-- .../backend/af/ArrayFireShapeAndIndex.cpp | 249 +- .../fl/tensor/backend/af/ArrayFireTensor.cpp | 851 ++--- .../fl/tensor/backend/af/ArrayFireTensor.h | 397 +-- .../tensor/backend/af/ArrayFireUnaryOps.cpp | 90 +- flashlight/fl/tensor/backend/af/Utils.cpp | 450 +-- flashlight/fl/tensor/backend/af/Utils.h | 73 +- .../backend/af/mem/CachingMemoryManager.cpp | 660 ++-- .../backend/af/mem/CachingMemoryManager.h | 250 +- .../backend/af/mem/DefaultMemoryManager.cpp | 569 +-- .../backend/af/mem/DefaultMemoryManager.h | 177 +- .../backend/af/mem/MemoryManagerAdapter.cpp | 71 +- .../backend/af/mem/MemoryManagerAdapter.h | 261 +- .../af/mem/MemoryManagerDeviceInterface.h | 28 +- .../backend/af/mem/MemoryManagerInstaller.cpp | 418 ++- .../backend/af/mem/MemoryManagerInstaller.h | 131 +- .../fl/tensor/backend/stub/StubBackend.cpp | 354 +- .../fl/tensor/backend/stub/StubBackend.h | 449 +-- .../fl/tensor/backend/stub/StubTensor.cpp | 116 +- .../fl/tensor/backend/stub/StubTensor.h | 148 +- flashlight/fl/tensor/profile/CUDAProfile.cpp | 8 +- flashlight/fl/tensor/profile/Profile.h | 28 +- .../test/autograd/AutogradBinaryOpsTest.cpp | 501 +-- .../fl/test/autograd/AutogradConv2DTest.cpp | 488 +-- .../autograd/AutogradNormalizationTest.cpp | 872 +++-- .../test/autograd/AutogradReductionTest.cpp | 238 +- .../fl/test/autograd/AutogradRnnTest.cpp | 243 +- 
flashlight/fl/test/autograd/AutogradTest.cpp | 776 ++-- .../fl/test/autograd/AutogradTestUtils.h | 90 +- .../fl/test/autograd/AutogradUnaryOpsTest.cpp | 222 +- flashlight/fl/test/common/DevicePtrTest.cpp | 64 +- .../fl/test/common/DynamicBenchmarkTest.cpp | 226 +- flashlight/fl/test/common/HistogramTest.cpp | 203 +- flashlight/fl/test/common/LoggingTest.cpp | 208 +- .../fl/test/common/SerializationTest.cpp | 272 +- flashlight/fl/test/common/UtilsTest.cpp | 160 +- .../contrib/modules/ContribModuleTest.cpp | 743 ++-- .../modules/ContribSerializationTest.cpp | 317 +- flashlight/fl/test/dataset/DatasetTest.cpp | 1063 +++--- .../fl/test/dataset/DatasetUtilsTest.cpp | 64 +- .../test/distributed/AllReduceBenchmark.cpp | 85 +- .../fl/test/distributed/AllReduceTest.cpp | 282 +- flashlight/fl/test/meter/MeterTest.cpp | 134 +- flashlight/fl/test/nn/ModuleTest.cpp | 1781 +++++----- flashlight/fl/test/nn/NNSerializationTest.cpp | 505 +-- flashlight/fl/test/nn/NNUtilsTest.cpp | 222 +- flashlight/fl/test/optim/OptimBenchmark.cpp | 96 +- flashlight/fl/test/optim/OptimTest.cpp | 184 +- flashlight/fl/test/runtime/CUDADeviceTest.cpp | 38 +- flashlight/fl/test/runtime/CUDAStreamTest.cpp | 152 +- .../fl/test/runtime/DeviceManagerTest.cpp | 88 +- flashlight/fl/test/runtime/DeviceTest.cpp | 96 +- flashlight/fl/test/runtime/DeviceTypeTest.cpp | 16 +- flashlight/fl/test/tensor/ComputeTest.cpp | 40 +- flashlight/fl/test/tensor/IndexTest.cpp | 453 +-- flashlight/fl/test/tensor/ShapeTest.cpp | 106 +- .../fl/test/tensor/TensorAdapterTest.cpp | 32 +- flashlight/fl/test/tensor/TensorBLASTest.cpp | 350 +- flashlight/fl/test/tensor/TensorBaseTest.cpp | 980 +++--- .../fl/test/tensor/TensorBinaryOpsTest.cpp | 835 ++--- .../fl/test/tensor/TensorExtensionTest.cpp | 60 +- .../fl/test/tensor/TensorReductionTest.cpp | 682 ++-- .../fl/test/tensor/TensorUnaryOpsTest.cpp | 329 +- .../test/tensor/af/ArrayFireCPUStreamTest.cpp | 8 +- .../tensor/af/ArrayFireTensorBaseTest.cpp | 587 ++-- 
.../tensor/af/CachingMemoryManagerTest.cpp | 285 +- .../fl/test/tensor/af/MemoryFrameworkTest.cpp | 724 ++-- .../fl/test/tensor/af/MemoryInitTest.cpp | 22 +- flashlight/pkg/runtime/Runtime.cpp | 96 +- flashlight/pkg/runtime/Runtime.h | 24 +- flashlight/pkg/runtime/amp/DynamicScaler.cpp | 88 +- flashlight/pkg/runtime/amp/DynamicScaler.h | 79 +- .../pkg/runtime/common/DistributedUtils.cpp | 112 +- .../pkg/runtime/common/DistributedUtils.h | 63 +- .../pkg/runtime/common/SequentialBuilder.cpp | 1123 +++--- .../pkg/runtime/common/SequentialBuilder.h | 22 +- flashlight/pkg/runtime/common/Serializer.h | 138 +- .../pkg/runtime/plugin/ModulePlugin.cpp | 7 +- flashlight/pkg/runtime/plugin/ModulePlugin.h | 20 +- .../plugin/plugincompiler/PluginModule.cpp | 4 +- .../pkg/runtime/test/DynamicScalerTest.cpp | 65 +- .../test/common/SequentialBuilderTest.cpp | 80 +- .../runtime/test/plugin/ModulePluginTest.cpp | 48 +- .../test/plugin/test_module_plugin.cpp | 6 +- .../pkg/speech/audio/feature/Ceplifter.cpp | 40 +- .../pkg/speech/audio/feature/Ceplifter.h | 26 +- flashlight/pkg/speech/audio/feature/Dct.cpp | 19 +- flashlight/pkg/speech/audio/feature/Dct.h | 26 +- .../pkg/speech/audio/feature/Derivatives.cpp | 120 +- .../pkg/speech/audio/feature/Derivatives.h | 45 +- .../pkg/speech/audio/feature/Dither.cpp | 18 +- flashlight/pkg/speech/audio/feature/Dither.h | 26 +- .../pkg/speech/audio/feature/FeatureParams.h | 309 +- flashlight/pkg/speech/audio/feature/Mfcc.cpp | 83 +- flashlight/pkg/speech/audio/feature/Mfcc.h | 82 +- flashlight/pkg/speech/audio/feature/Mfsc.cpp | 159 +- flashlight/pkg/speech/audio/feature/Mfsc.h | 38 +- .../speech/audio/feature/PowerSpectrum.cpp | 212 +- .../pkg/speech/audio/feature/PowerSpectrum.h | 62 +- .../pkg/speech/audio/feature/PreEmphasis.cpp | 47 +- .../pkg/speech/audio/feature/PreEmphasis.h | 24 +- .../pkg/speech/audio/feature/SpeechUtils.cpp | 83 +- .../pkg/speech/audio/feature/SpeechUtils.h | 22 +- .../speech/audio/feature/TriFilterbank.cpp | 109 +- 
.../pkg/speech/audio/feature/TriFilterbank.h | 68 +- .../pkg/speech/audio/feature/Windowing.cpp | 68 +- .../pkg/speech/audio/feature/Windowing.h | 26 +- .../pkg/speech/augmentation/AdditiveNoise.cpp | 119 +- .../pkg/speech/augmentation/AdditiveNoise.h | 61 +- .../pkg/speech/augmentation/GaussianNoise.cpp | 39 +- .../pkg/speech/augmentation/GaussianNoise.h | 51 +- .../pkg/speech/augmentation/Reverberation.cpp | 100 +- .../pkg/speech/augmentation/Reverberation.h | 139 +- .../pkg/speech/augmentation/SoundEffect.cpp | 95 +- .../pkg/speech/augmentation/SoundEffect.h | 108 +- .../speech/augmentation/SoundEffectApply.cpp | 67 +- .../speech/augmentation/SoundEffectConfig.cpp | 193 +- .../speech/augmentation/SoundEffectConfig.h | 65 +- .../speech/augmentation/SoundEffectUtil.cpp | 57 +- .../pkg/speech/augmentation/SoundEffectUtil.h | 62 +- .../pkg/speech/augmentation/SoxWrapper.cpp | 308 +- .../pkg/speech/augmentation/SoxWrapper.h | 96 +- .../pkg/speech/augmentation/TimeStretch.cpp | 44 +- .../pkg/speech/augmentation/TimeStretch.h | 104 +- flashlight/pkg/speech/common/Defines.h | 98 +- flashlight/pkg/speech/common/Flags.cpp | 523 ++- flashlight/pkg/speech/common/Flags.h | 335 +- .../pkg/speech/common/ProducerConsumerQueue.h | 192 +- .../criterion/AutoSegmentationCriterion.h | 166 +- ...tionistTemporalClassificationCriterion.cpp | 179 +- ...ectionistTemporalClassificationCriterion.h | 54 +- .../pkg/speech/criterion/CriterionUtils.cpp | 115 +- .../pkg/speech/criterion/CriterionUtils.h | 149 +- flashlight/pkg/speech/criterion/Defines.h | 10 +- .../criterion/ForceAlignmentCriterion.cpp | 25 +- .../criterion/ForceAlignmentCriterion.h | 48 +- .../criterion/FullConnectionCriterion.cpp | 25 +- .../criterion/FullConnectionCriterion.h | 48 +- .../criterion/LinearSegmentationCriterion.h | 76 +- .../pkg/speech/criterion/Seq2SeqCriterion.cpp | 1205 ++++--- .../pkg/speech/criterion/Seq2SeqCriterion.h | 419 +-- .../pkg/speech/criterion/SequenceCriterion.h | 86 +- 
.../speech/criterion/TransformerCriterion.cpp | 655 ++-- .../speech/criterion/TransformerCriterion.h | 277 +- .../criterion/attention/AttentionBase.h | 150 +- .../criterion/attention/ContentAttention.cpp | 133 +- .../criterion/attention/ContentAttention.h | 66 +- .../pkg/speech/criterion/attention/Defines.h | 36 +- .../criterion/attention/LocationAttention.cpp | 288 +- .../criterion/attention/LocationAttention.h | 106 +- .../criterion/attention/MedianWindow.cpp | 120 +- .../speech/criterion/attention/MedianWindow.h | 62 +- .../attention/MultiHeadAttention.cpp | 170 +- .../criterion/attention/MultiHeadAttention.h | 58 +- .../attention/SoftPretrainWindow.cpp | 92 +- .../criterion/attention/SoftPretrainWindow.h | 79 +- .../speech/criterion/attention/SoftWindow.cpp | 82 +- .../speech/criterion/attention/SoftWindow.h | 81 +- .../speech/criterion/attention/StepWindow.cpp | 104 +- .../speech/criterion/attention/StepWindow.h | 83 +- .../pkg/speech/criterion/attention/Utils.cpp | 34 +- .../pkg/speech/criterion/attention/Utils.h | 11 +- .../speech/criterion/attention/WindowBase.cpp | 85 +- .../speech/criterion/attention/WindowBase.h | 164 +- ...tionistTemporalClassificationCriterion.cpp | 438 +-- .../criterion/backend/cpu/CriterionUtils.cpp | 78 +- .../backend/cpu/ForceAlignmentCriterion.cpp | 268 +- .../backend/cpu/FullConnectionCriterion.cpp | 165 +- ...tionistTemporalClassificationCriterion.cpp | 268 +- .../criterion/backend/cuda/CriterionUtils.cpp | 130 +- .../backend/cuda/ForceAlignmentCriterion.cpp | 291 +- .../backend/cuda/FullConnectionCriterion.cpp | 192 +- .../pkg/speech/data/FeatureTransforms.cpp | 268 +- .../pkg/speech/data/FeatureTransforms.h | 317 +- .../pkg/speech/data/ListFileDataset.cpp | 217 +- flashlight/pkg/speech/data/ListFileDataset.h | 50 +- flashlight/pkg/speech/data/Sound.cpp | 457 +-- flashlight/pkg/speech/data/Sound.h | 164 +- flashlight/pkg/speech/data/Utils.cpp | 146 +- flashlight/pkg/speech/data/Utils.h | 59 +- 
.../pkg/speech/decoder/ConvLmModule.cpp | 92 +- flashlight/pkg/speech/decoder/ConvLmModule.h | 10 +- .../pkg/speech/decoder/DecodeMaster.cpp | 524 +-- flashlight/pkg/speech/decoder/DecodeMaster.h | 306 +- flashlight/pkg/speech/decoder/DecodeUtils.cpp | 66 +- flashlight/pkg/speech/decoder/DecodeUtils.h | 27 +- flashlight/pkg/speech/decoder/Defines.h | 63 +- flashlight/pkg/speech/decoder/PlGenerator.cpp | 426 +-- flashlight/pkg/speech/decoder/PlGenerator.h | 185 +- .../pkg/speech/decoder/TranscriptionUtils.cpp | 163 +- .../pkg/speech/decoder/TranscriptionUtils.h | 146 +- flashlight/pkg/speech/runtime/Attention.cpp | 138 +- flashlight/pkg/speech/runtime/Attention.h | 8 +- flashlight/pkg/speech/runtime/Helpers.cpp | 341 +- flashlight/pkg/speech/runtime/Helpers.h | 58 +- flashlight/pkg/speech/runtime/Logger.cpp | 215 +- flashlight/pkg/speech/runtime/Logger.h | 77 +- flashlight/pkg/speech/runtime/Optimizer.cpp | 115 +- flashlight/pkg/speech/runtime/Optimizer.h | 17 +- .../pkg/speech/runtime/SpeechStatMeter.cpp | 68 +- .../pkg/speech/runtime/SpeechStatMeter.h | 46 +- .../pkg/speech/test/audio/CeplifterTest.cpp | 104 +- flashlight/pkg/speech/test/audio/DctTest.cpp | 90 +- .../pkg/speech/test/audio/DerivativesTest.cpp | 178 +- .../pkg/speech/test/audio/DitherTest.cpp | 66 +- flashlight/pkg/speech/test/audio/MfccTest.cpp | 301 +- .../pkg/speech/test/audio/PreEmphasisTest.cpp | 108 +- .../pkg/speech/test/audio/SpeechUtilsTest.cpp | 28 +- flashlight/pkg/speech/test/audio/TestUtils.h | 46 +- .../speech/test/audio/TriFilterbankTest.cpp | 106 +- .../pkg/speech/test/audio/WindowingTest.cpp | 96 +- .../test/augmentation/AdditiveNoiseTest.cpp | 119 +- .../test/augmentation/GaussianNoiseTest.cpp | 50 +- .../test/augmentation/ReverberationTest.cpp | 146 +- .../augmentation/SoundEffectConfigTest.cpp | 100 +- .../test/augmentation/SoundEffectTest.cpp | 150 +- .../test/augmentation/TimeStretchTest.cpp | 34 +- .../test/common/ProducerConsumerQueueTest.cpp | 122 +- 
.../speech/test/criterion/BenchmarkASG.cpp | 65 +- .../speech/test/criterion/BenchmarkCTC.cpp | 72 +- .../test/criterion/BenchmarkSeq2Seq.cpp | 121 +- .../pkg/speech/test/criterion/CompareASG.cpp | 222 +- .../speech/test/criterion/CriterionTest.cpp | 1587 ++++----- .../pkg/speech/test/criterion/Seq2SeqTest.cpp | 886 ++--- .../criterion/attention/AttentionTest.cpp | 381 +- .../test/criterion/attention/WindowTest.cpp | 608 ++-- .../speech/test/data/FeaturizationTest.cpp | 800 +++-- .../speech/test/data/ListFileDatasetTest.cpp | 97 +- flashlight/pkg/speech/test/data/SoundTest.cpp | 252 +- .../speech/test/decoder/ConvLmModuleTest.cpp | 172 +- .../pkg/speech/test/runtime/RuntimeTest.cpp | 176 +- .../contrib/moderngpu/include/mgpuenums.h | 52 +- .../contrib/moderngpu/include/util/static.h | 124 +- .../speech/third_party/warpctc/include/ctc.h | 37 +- .../warpctc/include/detail/cpu_ctc.h | 495 ++- .../warpctc/include/detail/ctc_helper.h | 34 +- .../warpctc/include/detail/gpu_ctc.h | 623 ++-- .../warpctc/include/detail/gpu_ctc_kernels.h | 293 +- .../warpctc/include/detail/hostdevice.h | 4 +- .../warpctc/include/detail/reduce.h | 9 +- flashlight/pkg/text/data/TextDataset.cpp | 262 +- flashlight/pkg/text/data/TextDataset.h | 61 +- .../pkg/text/test/data/TextDatasetTest.cpp | 210 +- .../pkg/vision/common/BetaDistribution.h | 232 +- flashlight/pkg/vision/criterion/Hungarian.cpp | 122 +- flashlight/pkg/vision/criterion/Hungarian.h | 69 +- .../pkg/vision/criterion/HungarianImpl.cpp | 486 +-- .../pkg/vision/criterion/HungarianImpl.h | 8 +- .../pkg/vision/criterion/SetCriterion.cpp | 475 +-- .../pkg/vision/criterion/SetCriterion.h | 145 +- .../vision/dataset/BatchTransformDataset.h | 145 +- flashlight/pkg/vision/dataset/BoxUtils.cpp | 359 +- flashlight/pkg/vision/dataset/BoxUtils.h | 43 +- flashlight/pkg/vision/dataset/Coco.cpp | 295 +- flashlight/pkg/vision/dataset/Coco.h | 116 +- .../pkg/vision/dataset/CocoTransforms.cpp | 399 +-- .../pkg/vision/dataset/CocoTransforms.h | 46 +- 
.../pkg/vision/dataset/DistributedDataset.cpp | 39 +- .../pkg/vision/dataset/DistributedDataset.h | 41 +- flashlight/pkg/vision/dataset/Imagenet.cpp | 152 +- flashlight/pkg/vision/dataset/Imagenet.h | 22 +- flashlight/pkg/vision/dataset/Jpeg.cpp | 47 +- flashlight/pkg/vision/dataset/Jpeg.h | 8 +- flashlight/pkg/vision/dataset/LoaderDataset.h | 36 +- .../pkg/vision/dataset/TransformAllDataset.h | 49 +- flashlight/pkg/vision/dataset/Transforms.cpp | 711 ++-- flashlight/pkg/vision/dataset/Transforms.h | 129 +- flashlight/pkg/vision/models/Detr.cpp | 193 +- flashlight/pkg/vision/models/Detr.h | 106 +- flashlight/pkg/vision/models/Resnet.cpp | 320 +- flashlight/pkg/vision/models/Resnet.h | 179 +- .../pkg/vision/models/Resnet50Backbone.cpp | 52 +- .../pkg/vision/models/Resnet50Backbone.h | 24 +- .../vision/models/ResnetFrozenBatchNorm.cpp | 251 +- .../pkg/vision/models/ResnetFrozenBatchNorm.h | 171 +- flashlight/pkg/vision/models/ViT.cpp | 203 +- flashlight/pkg/vision/models/ViT.h | 92 +- flashlight/pkg/vision/nn/FrozenBatchNorm.cpp | 66 +- flashlight/pkg/vision/nn/FrozenBatchNorm.h | 136 +- .../pkg/vision/nn/PositionalEmbeddingSine.cpp | 169 +- .../pkg/vision/nn/PositionalEmbeddingSine.h | 76 +- flashlight/pkg/vision/nn/Transformer.cpp | 799 +++-- flashlight/pkg/vision/nn/Transformer.h | 530 +-- .../pkg/vision/nn/VisionTransformer.cpp | 274 +- flashlight/pkg/vision/nn/VisionTransformer.h | 132 +- .../pkg/vision/tensor/VisionExtension.h | 92 +- .../vision/tensor/VisionExtensionBackends.h | 2 +- flashlight/pkg/vision/tensor/VisionOps.cpp | 104 +- flashlight/pkg/vision/tensor/VisionOps.h | 26 +- .../backend/af/ArrayFireVisionExtension.cpp | 403 ++- .../backend/af/ArrayFireVisionExtension.h | 89 +- .../vision/test/ModelSerializationTest.cpp | 109 +- .../test/PositionalEmbeddingSineTest.cpp | 26 +- .../pkg/vision/test/TransformerTest.cpp | 464 +-- flashlight/pkg/vision/test/TransformsTest.cpp | 148 +- .../vision/test/criterion/HungarianTest.cpp | 340 +- 
.../test/criterion/SetCriterionTest.cpp | 703 ++-- .../pkg/vision/test/dataset/BoxUtilsTest.cpp | 232 +- 545 files changed, 52924 insertions(+), 48408 deletions(-) diff --git a/flashlight/fl/autograd/Functions.cpp b/flashlight/fl/autograd/Functions.cpp index db0824a..3c6c530 100644 --- a/flashlight/fl/autograd/Functions.cpp +++ b/flashlight/fl/autograd/Functions.cpp @@ -24,1288 +24,1387 @@ namespace fl { namespace detail { -Tensor tileAs(const Tensor& input, const Shape& rdims) { - // Scalar tensor - if (input.ndim() == 0) { - return tile(input, rdims); - } - - Shape dims(std::vector(rdims.ndim(), 1)); - Shape idims = input.shape(); - for (int i = 0; i < rdims.ndim(); i++) { - int idimsSize = i + 1 > idims.ndim() ? 1 : idims[i]; - if (rdims[i] % idimsSize != 0) { - std::stringstream ss; - ss << "Invalid dims for tileAs for input dims " << idims - << " to output dims " << rdims; - throw std::invalid_argument(ss.str()); + Tensor tileAs(const Tensor& input, const Shape& rdims) { + // Scalar tensor + if(input.ndim() == 0) { + return tile(input, rdims); + } + + Shape dims(std::vector(rdims.ndim(), 1)); + Shape idims = input.shape(); + for(int i = 0; i < rdims.ndim(); i++) { + int idimsSize = i + 1 > idims.ndim() ? 
1 : idims[i]; + if(rdims[i] % idimsSize != 0) { + std::stringstream ss; + ss << "Invalid dims for tileAs for input dims " << idims + << " to output dims " << rdims; + throw std::invalid_argument(ss.str()); + } + dims[i] = rdims[i] / idimsSize; + } + return tile(input, dims); } - dims[i] = rdims[i] / idimsSize; - } - return tile(input, dims); -} -Tensor sumAs(const Tensor& input, const Shape& rdims) { - Shape idims = input.shape(); - auto result = input; - for (int i = 0; i < input.ndim(); i++) { - if (i + 1 > rdims.ndim() || idims[i] != rdims[i]) { - result = fl::sum(result, {i}, /* keepDims = */ true); + Tensor sumAs(const Tensor& input, const Shape& rdims) { + Shape idims = input.shape(); + auto result = input; + for(int i = 0; i < input.ndim(); i++) { + if(i + 1 > rdims.ndim() || idims[i] != rdims[i]) { + result = fl::sum(result, {i}, /* keepDims = */ true); + } + } + + return fl::reshape(result.astype(input.type()), rdims); } - } - return fl::reshape(result.astype(input.type()), rdims); -} + Shape expandedShapeFromReducedDims( + const Tensor& input, + const std::vector& axes, + bool keepDims /* = false */ + ) { + // Fast path - tensor already retained its shape + if(keepDims) { + return input.shape(); + } + // If we output a scalar, + if(input.ndim() == 0) { + return {}; + } -Shape expandedShapeFromReducedDims( - const Tensor& input, - const std::vector& axes, - bool keepDims /* = false */) { - // Fast path - tensor already retained its shape - if (keepDims) { - return input.shape(); - } - // If we output a scalar, - if (input.ndim() == 0) { - return {}; - } - - unsigned preNDims = input.ndim() + axes.size(); - Shape newShape(std::vector(preNDims, 1)); - unsigned axesIdx = 0; - unsigned inputIdx = 0; - for (unsigned i = 0; i < preNDims; ++i) { - if (i == axes[axesIdx]) { - // This dim was reduced over, leave as 1 in the new shape - axesIdx++; - } else { - // Dim wasn't reduced over - add the shape from the new tensor - newShape[i] = input.dim(inputIdx); - 
inputIdx++; + unsigned preNDims = input.ndim() + axes.size(); + Shape newShape(std::vector(preNDims, 1)); + unsigned axesIdx = 0; + unsigned inputIdx = 0; + for(unsigned i = 0; i < preNDims; ++i) { + if(i == axes[axesIdx]) { + // This dim was reduced over, leave as 1 in the new shape + axesIdx++; + } else { + // Dim wasn't reduced over - add the shape from the new tensor + newShape[i] = input.dim(inputIdx); + inputIdx++; + } + } + return newShape; } - } - return newShape; -} // TODO: remove these/use a simple template -Variable expandFromReduction( - const Variable& input, - const std::vector& axes, - bool keepDims) { - return moddims( - input, expandedShapeFromReducedDims(input.tensor(), axes, keepDims)); -} + Variable expandFromReduction( + const Variable& input, + const std::vector& axes, + bool keepDims + ) { + return moddims( + input, + expandedShapeFromReducedDims(input.tensor(), axes, keepDims) + ); + } -Tensor expandFromReduction( - const Tensor& input, - const std::vector& axes, - bool keepDims) { - auto o = expandedShapeFromReducedDims(input, axes, keepDims); - return fl::reshape( - input, expandedShapeFromReducedDims(input, axes, keepDims)); -} + Tensor expandFromReduction( + const Tensor& input, + const std::vector& axes, + bool keepDims + ) { + auto o = expandedShapeFromReducedDims(input, axes, keepDims); + return fl::reshape( + input, + expandedShapeFromReducedDims(input, axes, keepDims) + ); + } -bool areVariableTypesEqual(const Variable& a, const Variable& b) { - return a.type() == b.type(); -} + bool areVariableTypesEqual(const Variable& a, const Variable& b) { + return a.type() == b.type(); + } } // namespace detail Variable operator+(const Variable& lhs, const Variable& rhs) { - FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); - auto result = lhs.tensor() + rhs.tensor(); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable(gradOutput.tensor(), false)); - inputs[1].addGrad(Variable(gradOutput.tensor(), 
false)); - }; - return Variable(result, {lhs.withoutData(), rhs.withoutData()}, gradFunc); + FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); + auto result = lhs.tensor() + rhs.tensor(); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + inputs[0].addGrad(Variable(gradOutput.tensor(), false)); + inputs[1].addGrad(Variable(gradOutput.tensor(), false)); + }; + return Variable(result, {lhs.withoutData(), rhs.withoutData()}, gradFunc); } Variable operator+(const Variable& lhs, const double& rhsVal) { - auto result = (lhs.tensor() + rhsVal).astype(lhs.type()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable(gradOutput.tensor(), false)); - }; - return Variable(result, {lhs.withoutData()}, gradFunc); + auto result = (lhs.tensor() + rhsVal).astype(lhs.type()); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + inputs[0].addGrad(Variable(gradOutput.tensor(), false)); + }; + return Variable(result, {lhs.withoutData()}, gradFunc); } Variable operator+(const double& lhsVal, const Variable& rhs) { - return rhs + lhsVal; + return rhs + lhsVal; } Variable operator-(const Variable& lhs, const Variable& rhs) { - FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); - auto result = lhs.tensor() - rhs.tensor(); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable(gradOutput.tensor(), false)); - inputs[1].addGrad(Variable(negate(gradOutput).tensor(), false)); - }; - return Variable(result, {lhs.withoutData(), rhs.withoutData()}, gradFunc); + FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); + auto result = lhs.tensor() - rhs.tensor(); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + inputs[0].addGrad(Variable(gradOutput.tensor(), false)); + inputs[1].addGrad(Variable(negate(gradOutput).tensor(), false)); + }; + return Variable(result, {lhs.withoutData(), rhs.withoutData()}, gradFunc); } Variable operator-(const Variable& lhs, const 
double& rhsVal) { - auto result = (lhs.tensor() - rhsVal).astype(lhs.type()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable(gradOutput.tensor(), false)); - }; - return Variable(result, {lhs.withoutData()}, gradFunc); + auto result = (lhs.tensor() - rhsVal).astype(lhs.type()); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + inputs[0].addGrad(Variable(gradOutput.tensor(), false)); + }; + return Variable(result, {lhs.withoutData()}, gradFunc); } Variable operator-(const double& lhsVal, const Variable& rhs) { - auto result = (lhsVal - rhs.tensor()).astype(rhs.type()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable(negate(gradOutput).tensor(), false)); - }; - return Variable(result, {rhs.withoutData()}, gradFunc); + auto result = (lhsVal - rhs.tensor()).astype(rhs.type()); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + inputs[0].addGrad(Variable(negate(gradOutput).tensor(), false)); + }; + return Variable(result, {rhs.withoutData()}, gradFunc); } Variable operator*(const Variable& lhs, const Variable& rhs) { - FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); - auto result = lhs.tensor() * rhs.tensor(); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - if (inputs[0].isCalcGrad()) { - inputs[0].addGrad( - Variable(gradOutput.tensor() * inputs[1].tensor(), false)); - } - if (inputs[1].isCalcGrad()) { - inputs[1].addGrad( - Variable(gradOutput.tensor() * inputs[0].tensor(), false)); - } - }; - return Variable( - result, - {rhs.isCalcGrad() ? lhs : lhs.withoutData(), - lhs.isCalcGrad() ? 
rhs : rhs.withoutData()}, - gradFunc); + FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); + auto result = lhs.tensor() * rhs.tensor(); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + if(inputs[0].isCalcGrad()) { + inputs[0].addGrad( + Variable(gradOutput.tensor() * inputs[1].tensor(), false) + ); + } + if(inputs[1].isCalcGrad()) { + inputs[1].addGrad( + Variable(gradOutput.tensor() * inputs[0].tensor(), false) + ); + } + }; + return Variable( + result, + {rhs.isCalcGrad() ? lhs : lhs.withoutData(), + lhs.isCalcGrad() ? rhs : rhs.withoutData()}, + gradFunc + ); } Variable operator*(const Variable& lhs, const double& rhsVal) { - auto result = (lhs.tensor() * rhsVal).astype(lhs.type()); - auto gradFunc = - [rhsVal](std::vector& inputs, const Variable& gradOutput) { - inputs[0].addGrad(Variable(gradOutput.tensor() * rhsVal, false)); - }; - return Variable(result, {lhs.withoutData()}, gradFunc); + auto result = (lhs.tensor() * rhsVal).astype(lhs.type()); + auto gradFunc = + [rhsVal](std::vector& inputs, const Variable& gradOutput) { + inputs[0].addGrad(Variable(gradOutput.tensor() * rhsVal, false)); + }; + return Variable(result, {lhs.withoutData()}, gradFunc); } Variable operator*(const double& lhsVal, const Variable& rhs) { - return rhs * lhsVal; + return rhs * lhsVal; } Variable operator/(const Variable& lhs, const Variable& rhs) { - FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); - auto result = lhs.tensor() / rhs.tensor(); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - auto inputs1rec = reciprocal(inputs[1]); - auto gradInput0 = gradOutput * inputs1rec; - if (inputs[0].isCalcGrad()) { - inputs[0].addGrad(Variable(gradInput0.tensor(), false)); - } - if (inputs[1].isCalcGrad()) { - inputs[1].addGrad(Variable( - (gradInput0 * negate(inputs[0]) * inputs1rec).tensor(), false)); - } - }; - return Variable( - result, {rhs.isCalcGrad() ? 
lhs : lhs.withoutData(), rhs}, gradFunc); + FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); + auto result = lhs.tensor() / rhs.tensor(); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + auto inputs1rec = reciprocal(inputs[1]); + auto gradInput0 = gradOutput * inputs1rec; + if(inputs[0].isCalcGrad()) { + inputs[0].addGrad(Variable(gradInput0.tensor(), false)); + } + if(inputs[1].isCalcGrad()) { + inputs[1].addGrad( + Variable( + (gradInput0 * negate(inputs[0]) * inputs1rec).tensor(), + false + ) + ); + } + }; + return Variable( + result, + {rhs.isCalcGrad() ? lhs : lhs.withoutData(), rhs}, + gradFunc + ); } Variable operator/(const Variable& lhs, const double& rhsVal) { - auto result = (lhs.tensor() / rhsVal).astype(lhs.type()); - auto gradFunc = - [rhsVal](std::vector& inputs, const Variable& gradOutput) { - inputs[0].addGrad(Variable((gradOutput / rhsVal).tensor(), false)); - }; - return Variable(result, {lhs.withoutData()}, gradFunc); + auto result = (lhs.tensor() / rhsVal).astype(lhs.type()); + auto gradFunc = + [rhsVal](std::vector& inputs, const Variable& gradOutput) { + inputs[0].addGrad(Variable((gradOutput / rhsVal).tensor(), false)); + }; + return Variable(result, {lhs.withoutData()}, gradFunc); } Variable operator/(const double& lhsVal, const Variable& rhs) { - auto result = (lhsVal / rhs.tensor()).astype(rhs.type()); - auto gradFunc = [lhsVal]( - std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable( - (gradOutput * (-lhsVal) / (inputs[0] * inputs[0])).tensor(), false)); - }; - return Variable(result, {rhs}, gradFunc); + auto result = (lhsVal / rhs.tensor()).astype(rhs.type()); + auto gradFunc = [lhsVal]( + std::vector& inputs, + const Variable& gradOutput) { + inputs[0].addGrad( + Variable( + (gradOutput * (-lhsVal) / (inputs[0] * inputs[0])).tensor(), + false + ) + ); + }; + return Variable(result, {rhs}, gradFunc); } Variable operator>(const Variable& lhs, const Variable& rhs) { - 
FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); - auto result = lhs.tensor() > rhs.tensor(); - return Variable(result, false); + FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); + auto result = lhs.tensor() > rhs.tensor(); + return Variable(result, false); } Variable operator>(const Variable& lhs, const double& rhsVal) { - auto result = (lhs.tensor() > rhsVal).astype(lhs.type()); - return Variable(result, false); + auto result = (lhs.tensor() > rhsVal).astype(lhs.type()); + return Variable(result, false); } Variable operator>(const double& lhsVal, const Variable& rhs) { - auto result = (lhsVal > rhs.tensor()).astype(rhs.type()); - return Variable(result, false); + auto result = (lhsVal > rhs.tensor()).astype(rhs.type()); + return Variable(result, false); } Variable operator<(const Variable& lhs, const Variable& rhs) { - FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); - auto result = lhs.tensor() < rhs.tensor(); - return Variable(result, false); + FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); + auto result = lhs.tensor() < rhs.tensor(); + return Variable(result, false); } Variable operator<(const Variable& lhs, const double& rhsVal) { - auto result = (lhs.tensor() < rhsVal).astype(lhs.type()); - return Variable(result, false); + auto result = (lhs.tensor() < rhsVal).astype(lhs.type()); + return Variable(result, false); } Variable operator<(const double& lhsVal, const Variable& rhs) { - auto result = (lhsVal < rhs.tensor()).astype(rhs.type()); - return Variable(result, false); + auto result = (lhsVal < rhs.tensor()).astype(rhs.type()); + return Variable(result, false); } Variable operator>=(const Variable& lhs, const Variable& rhs) { - FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); - auto result = lhs.tensor() >= rhs.tensor(); - return Variable(result, false); + FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); + auto result = lhs.tensor() >= rhs.tensor(); + return Variable(result, false); } Variable operator>=(const Variable& lhs, const double& rhsVal) { - auto result = (lhs.tensor() >= 
rhsVal).astype(lhs.type()); - return Variable(result, false); + auto result = (lhs.tensor() >= rhsVal).astype(lhs.type()); + return Variable(result, false); } Variable operator>=(const double& lhsVal, const Variable& rhs) { - auto result = (lhsVal >= rhs.tensor()).astype(rhs.type()); - return Variable(result, false); + auto result = (lhsVal >= rhs.tensor()).astype(rhs.type()); + return Variable(result, false); } Variable operator<=(const Variable& lhs, const Variable& rhs) { - FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); - auto result = lhs.tensor() <= rhs.tensor(); - return Variable(result, false); + FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); + auto result = lhs.tensor() <= rhs.tensor(); + return Variable(result, false); } Variable operator<=(const Variable& lhs, const double& rhsVal) { - auto result = (lhs.tensor() <= rhsVal).astype(lhs.type()); - return Variable(result, false); + auto result = (lhs.tensor() <= rhsVal).astype(lhs.type()); + return Variable(result, false); } Variable operator<=(const double& lhsVal, const Variable& rhs) { - auto result = (lhsVal <= rhs.tensor()).astype(rhs.type()); - return Variable(result, false); + auto result = (lhsVal <= rhs.tensor()).astype(rhs.type()); + return Variable(result, false); } Variable operator&&(const Variable& lhs, const Variable& rhs) { - FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); - auto result = lhs.tensor() && rhs.tensor(); - return Variable(result, false); + FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); + auto result = lhs.tensor() && rhs.tensor(); + return Variable(result, false); } Variable operator!(const Variable& input) { - auto result = (!input.tensor()).astype(input.type()); - return Variable(result, false); + auto result = (!input.tensor()).astype(input.type()); + return Variable(result, false); } Variable max(const Variable& lhs, const Variable& rhs) { - FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); - auto result = fl::maximum(lhs.tensor(), rhs.tensor()); - auto gradFunc = [](std::vector& inputs, - const 
Variable& gradOutput) { - auto mask = Variable( - (inputs[0].tensor() > inputs[1].tensor()).astype(gradOutput.type()), - false); - inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false)); - inputs[1].addGrad(Variable((!mask * gradOutput).tensor(), false)); - }; - return Variable(result, {lhs, rhs}, gradFunc); + FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); + auto result = fl::maximum(lhs.tensor(), rhs.tensor()); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + auto mask = Variable( + (inputs[0].tensor() > inputs[1].tensor()).astype(gradOutput.type()), + false + ); + inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false)); + inputs[1].addGrad(Variable((!mask * gradOutput).tensor(), false)); + }; + return Variable(result, {lhs, rhs}, gradFunc); } Variable max(const Variable& lhs, const double& rhsVal) { - auto result = fl::maximum(lhs.tensor(), rhsVal).astype(lhs.type()); - auto gradFunc = - [rhsVal](std::vector& inputs, const Variable& gradOutput) { - auto mask = Variable( - (inputs[0].tensor() > rhsVal).astype(gradOutput.type()), false); - inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false)); - }; - return Variable(result, {lhs}, gradFunc); + auto result = fl::maximum(lhs.tensor(), rhsVal).astype(lhs.type()); + auto gradFunc = + [rhsVal](std::vector& inputs, const Variable& gradOutput) { + auto mask = Variable( + (inputs[0].tensor() > rhsVal).astype(gradOutput.type()), + false + ); + inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false)); + }; + return Variable(result, {lhs}, gradFunc); } Variable max(const double& lhsVal, const Variable& rhs) { - return max(rhs, lhsVal); + return max(rhs, lhsVal); } Variable min(const Variable& lhs, const Variable& rhs) { - FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); - auto result = fl::minimum(lhs.tensor(), rhs.tensor()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - auto mask = Variable( - (inputs[0].tensor() < 
inputs[1].tensor()).astype(gradOutput.type()), - false); - inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false)); - inputs[1].addGrad(Variable((!mask * gradOutput).tensor(), false)); - }; - return Variable(result, {lhs, rhs}, gradFunc); + FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); + auto result = fl::minimum(lhs.tensor(), rhs.tensor()); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + auto mask = Variable( + (inputs[0].tensor() < inputs[1].tensor()).astype(gradOutput.type()), + false + ); + inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false)); + inputs[1].addGrad(Variable((!mask * gradOutput).tensor(), false)); + }; + return Variable(result, {lhs, rhs}, gradFunc); } Variable min(const Variable& lhs, const double& rhsVal) { - auto result = fl::minimum(lhs.tensor(), rhsVal).astype(lhs.type()); - auto gradFunc = - [rhsVal](std::vector& inputs, const Variable& gradOutput) { - auto mask = Variable( - (inputs[0].tensor() < rhsVal).astype(gradOutput.type()), false); - inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false)); - }; - return Variable(result, {lhs}, gradFunc); + auto result = fl::minimum(lhs.tensor(), rhsVal).astype(lhs.type()); + auto gradFunc = + [rhsVal](std::vector& inputs, const Variable& gradOutput) { + auto mask = Variable( + (inputs[0].tensor() < rhsVal).astype(gradOutput.type()), + false + ); + inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false)); + }; + return Variable(result, {lhs}, gradFunc); } Variable min(const double& lhsVal, const Variable& rhs) { - return min(rhs, lhsVal); + return min(rhs, lhsVal); } Variable negate(const Variable& input) { - auto result = (0.0 - input.tensor()).astype(input.type()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable(negate(gradOutput).tensor(), false)); - }; - return Variable(result, {input.withoutData()}, gradFunc); + auto result = (0.0 - input.tensor()).astype(input.type()); + auto 
gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + inputs[0].addGrad(Variable(negate(gradOutput).tensor(), false)); + }; + return Variable(result, {input.withoutData()}, gradFunc); } Variable reciprocal(const Variable& input) { - auto result = 1.0 / FL_ADJUST_INPUT_TYPE(input.tensor()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - auto res = reciprocal(inputs[0]); - inputs[0].addGrad( - Variable((negate(gradOutput) * res * res).tensor(), false)); - }; - return Variable(result, {input}, gradFunc); + auto result = 1.0 / FL_ADJUST_INPUT_TYPE(input.tensor()); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + auto res = reciprocal(inputs[0]); + inputs[0].addGrad( + Variable((negate(gradOutput) * res * res).tensor(), false) + ); + }; + return Variable(result, {input}, gradFunc); } Variable exp(const Variable& input) { - auto result = fl::exp(FL_ADJUST_INPUT_TYPE(input.tensor())); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad( - Variable(gradOutput.tensor() * fl::exp(inputs[0].tensor()), false)); - }; - return Variable(result, {input}, gradFunc); + auto result = fl::exp(FL_ADJUST_INPUT_TYPE(input.tensor())); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + inputs[0].addGrad( + Variable(gradOutput.tensor() * fl::exp(inputs[0].tensor()), false) + ); + }; + return Variable(result, {input}, gradFunc); } Variable log(const Variable& input) { - auto result = fl::log(FL_ADJUST_INPUT_TYPE(input.tensor())); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad( - Variable((gradOutput.tensor() / inputs[0].tensor()), false)); - }; - return Variable(result, {input}, gradFunc); + auto result = fl::log(FL_ADJUST_INPUT_TYPE(input.tensor())); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + inputs[0].addGrad( + Variable((gradOutput.tensor() / inputs[0].tensor()), false) + ); 
+ }; + return Variable(result, {input}, gradFunc); } Variable log1p(const Variable& input) { - auto result = fl::log1p(FL_ADJUST_INPUT_TYPE(input.tensor())); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad( - Variable((gradOutput.tensor() / (1.0 + inputs[0].tensor())), false)); - }; - return Variable(result, {input}, gradFunc); + auto result = fl::log1p(FL_ADJUST_INPUT_TYPE(input.tensor())); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + inputs[0].addGrad( + Variable((gradOutput.tensor() / (1.0 + inputs[0].tensor())), false) + ); + }; + return Variable(result, {input}, gradFunc); } Variable pow(const Variable& input, double p) { - auto result = fl::power(FL_ADJUST_INPUT_TYPE(input.tensor()), p); - auto gradFunc = [p](std::vector& inputs, - const Variable& gradOutput) { - Tensor grad = - p * fl::power(inputs[0].tensor(), p - 1) * gradOutput.tensor(); - inputs[0].addGrad(Variable(grad, false)); - }; - return Variable(result, {input}, gradFunc); + auto result = fl::power(FL_ADJUST_INPUT_TYPE(input.tensor()), p); + auto gradFunc = [p](std::vector& inputs, + const Variable& gradOutput) { + Tensor grad = + p * fl::power(inputs[0].tensor(), p - 1) * gradOutput.tensor(); + inputs[0].addGrad(Variable(grad, false)); + }; + return Variable(result, {input}, gradFunc); } Variable sin(const Variable& input) { - auto result = fl::sin(input.tensor()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad( - Variable((gradOutput.tensor() * cos(inputs[0].tensor())), false)); - }; - return Variable(result, {input}, gradFunc); + auto result = fl::sin(input.tensor()); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + inputs[0].addGrad( + Variable((gradOutput.tensor() * cos(inputs[0].tensor())), false) + ); + }; + return Variable(result, {input}, gradFunc); } Variable cos(const Variable& input) { - auto result = fl::cos(input.tensor()); - auto 
gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable( - (gradOutput.tensor() * negative(sin(inputs[0].tensor()))), false)); - }; - return Variable(result, {input}, gradFunc); + auto result = fl::cos(input.tensor()); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + inputs[0].addGrad( + Variable( + (gradOutput.tensor() * negative(sin(inputs[0].tensor()))), + false + ) + ); + }; + return Variable(result, {input}, gradFunc); } Variable tanh(const Variable& input) { - auto result = fl::tanh(input.tensor()); - auto gradFunc = - [result](std::vector& inputs, const Variable& gradOutput) { - auto grad = - Variable((1.0 - result * result) * gradOutput.tensor(), false); - inputs[0].addGrad(Variable(grad.tensor(), false)); - }; - return Variable(result, {input.withoutData()}, gradFunc); + auto result = fl::tanh(input.tensor()); + auto gradFunc = + [result](std::vector& inputs, const Variable& gradOutput) { + auto grad = + Variable((1.0 - result * result) * gradOutput.tensor(), false); + inputs[0].addGrad(Variable(grad.tensor(), false)); + }; + return Variable(result, {input.withoutData()}, gradFunc); } Variable clamp(const Variable& input, const double lo, const double hi) { - auto result = fl::clip(input.tensor(), lo, hi); - auto gradFunc = [lo, hi, result]( - std::vector& inputs, - const Variable& gradOutput) { - Tensor gradMask = gradOutput.tensor(); - gradMask = fl::where((result > lo) && (result < hi), gradMask, 0); - inputs[0].addGrad(Variable(gradMask, false)); - }; - return Variable(result, {input.withoutData()}, gradFunc); + auto result = fl::clip(input.tensor(), lo, hi); + auto gradFunc = [lo, hi, result]( + std::vector& inputs, + const Variable& gradOutput) { + Tensor gradMask = gradOutput.tensor(); + gradMask = fl::where((result > lo) && (result < hi), gradMask, 0); + inputs[0].addGrad(Variable(gradMask, false)); + }; + return Variable(result, {input.withoutData()}, gradFunc); } Variable 
sqrt(const Variable& input) { - auto result = fl::sqrt(input.tensor()); - auto gradFunc = [result]( - std::vector& inputs, - const Variable& gradOutput) { - auto output = Variable(result, false); - inputs[0].addGrad(Variable((gradOutput / (2 * output)).tensor(), false)); - }; - return Variable(result, {input.withoutData()}, gradFunc); + auto result = fl::sqrt(input.tensor()); + auto gradFunc = [result]( + std::vector& inputs, + const Variable& gradOutput) { + auto output = Variable(result, false); + inputs[0].addGrad(Variable((gradOutput / (2 * output)).tensor(), false)); + }; + return Variable(result, {input.withoutData()}, gradFunc); } Variable sigmoid(const Variable& input) { - auto result = fl::sigmoid(input.tensor()); - auto gradFunc = - [result](std::vector& inputs, const Variable& gradOutput) { - auto grad = gradOutput.tensor() * result * (1 - result); - inputs[0].addGrad(Variable(grad, false)); - }; - return Variable(result, {input.withoutData()}, gradFunc); + auto result = fl::sigmoid(input.tensor()); + auto gradFunc = + [result](std::vector& inputs, const Variable& gradOutput) { + auto grad = gradOutput.tensor() * result * (1 - result); + inputs[0].addGrad(Variable(grad, false)); + }; + return Variable(result, {input.withoutData()}, gradFunc); } Variable swish(const Variable& input, double beta) { - return input * sigmoid(beta * input); + return input * sigmoid(beta * input); } Variable erf(const Variable& input) { - auto result = fl::erf(FL_ADJUST_INPUT_TYPE(input.tensor())); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - auto x = inputs[0].tensor(); - auto grad = gradOutput.tensor() * 2 / std::sqrt(M_PI) * fl::exp(-(x * x)); - inputs[0].addGrad(Variable(grad, false)); - }; - return Variable(result, {input}, gradFunc); + auto result = fl::erf(FL_ADJUST_INPUT_TYPE(input.tensor())); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + auto x = inputs[0].tensor(); + auto grad = gradOutput.tensor() * 2 / 
std::sqrt(M_PI) * fl::exp(-(x * x)); + inputs[0].addGrad(Variable(grad, false)); + }; + return Variable(result, {input}, gradFunc); } Variable transpose(const Variable& input, const Shape& dims /* = {} */) { - auto result = fl::transpose(input.tensor(), dims); - auto gradFunc = [inputDims = input.shape(), ndim = input.ndim(), dims]( - std::vector& inputs, - const Variable& gradOutput) { - Shape reverseShape = dims; - - if (dims.ndim()) { - // Reverse vec if transposing all dims (empty arg) - auto dVec = dims.get(); - std::reverse(dVec.begin(), dVec.end()); - reverseShape = Shape(dVec); - } - - for (unsigned i = 0; i < reverseShape.ndim(); ++i) { - reverseShape[dims[i]] = i; - } - - inputs[0].addGrad( - Variable(fl::transpose(gradOutput.tensor(), reverseShape), false)); - }; - return Variable(result, {input.withoutData()}, gradFunc); + auto result = fl::transpose(input.tensor(), dims); + auto gradFunc = [inputDims = input.shape(), ndim = input.ndim(), dims]( + std::vector& inputs, + const Variable& gradOutput) { + Shape reverseShape = dims; + + if(dims.ndim()) { + // Reverse vec if transposing all dims (empty arg) + auto dVec = dims.get(); + std::reverse(dVec.begin(), dVec.end()); + reverseShape = Shape(dVec); + } + + for(unsigned i = 0; i < reverseShape.ndim(); ++i) { + reverseShape[dims[i]] = i; + } + + inputs[0].addGrad( + Variable(fl::transpose(gradOutput.tensor(), reverseShape), false) + ); + }; + return Variable(result, {input.withoutData()}, gradFunc); } Variable tileAs(const Variable& input, const Shape& rdims) { - auto result = detail::tileAs(input.tensor(), rdims); - - Shape inDims = input.shape(); - auto gradFunc = [inDims]( - std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable( - sumAs(gradOutput, inDims).tensor().astype(inputs[0].type()), false)); - }; - return Variable(result, {input.withoutData()}, gradFunc); + auto result = detail::tileAs(input.tensor(), rdims); + + Shape inDims = input.shape(); + auto gradFunc = 
[inDims]( + std::vector& inputs, + const Variable& gradOutput) { + inputs[0].addGrad( + Variable( + sumAs(gradOutput, inDims).tensor().astype(inputs[0].type()), + false + ) + ); + }; + return Variable(result, {input.withoutData()}, gradFunc); } Variable tileAs(const Variable& input, const Variable& reference) { - return tileAs(input, reference.shape()); + return tileAs(input, reference.shape()); } Variable sumAs(const Variable& input, const Shape& rdims) { - auto result = detail::sumAs(FL_ADJUST_INPUT_TYPE(input.tensor()), rdims); - auto idims = input.tensor().shape(); - auto gradFunc = - [idims](std::vector& inputs, const Variable& gradOutput) { - inputs[0].addGrad(Variable(tileAs(gradOutput, idims).tensor(), false)); - }; - return Variable(result, {input.withoutData()}, gradFunc); + auto result = detail::sumAs(FL_ADJUST_INPUT_TYPE(input.tensor()), rdims); + auto idims = input.tensor().shape(); + auto gradFunc = + [idims](std::vector& inputs, const Variable& gradOutput) { + inputs[0].addGrad(Variable(tileAs(gradOutput, idims).tensor(), false)); + }; + return Variable(result, {input.withoutData()}, gradFunc); } Variable sumAs(const Variable& input, const Variable& reference) { - return sumAs(input, reference.shape()); + return sumAs(input, reference.shape()); } Variable concatenate(const std::vector& concatInputs, int dim) { - if (concatInputs.empty()) { - throw std::invalid_argument("cannot concatenate zero variables"); - } - - if (concatInputs.size() == 1) { - return concatInputs[0]; - } - // All Variables must be of the same type - fl::dtype type = concatInputs[0].type(); - for (auto& var : concatInputs) { - if (var.type() != type) { - throw std::invalid_argument( - "concatenate: all input Variables must be of the same type"); + if(concatInputs.empty()) { + throw std::invalid_argument("cannot concatenate zero variables"); } - } - // All Variables must have the same number of dims - unsigned numDims = concatInputs[0].ndim(); - for (auto& var : concatInputs) { - 
if (numDims != var.ndim()) { - throw std::invalid_argument( - "concatenate: all input Variables must " - "have the same number of dimensions"); + + if(concatInputs.size() == 1) { + return concatInputs[0]; } - } - - // All Variables must have the same size when indexed along the dim not being - // concatenated along - auto dims = concatInputs[0].shape(); - int concatSize = dims[dim]; - for (int i = 1; i < concatInputs.size(); i++) { - concatSize += concatInputs[i].dim(dim); - for (int d = 0; d < numDims; d++) { - if (dim != d && concatInputs[i].dim(d) != dims[d]) { - throw std::invalid_argument( - "mismatch in dimension not being concatenated"); - } + // All Variables must be of the same type + fl::dtype type = concatInputs[0].type(); + for(auto& var : concatInputs) { + if(var.type() != type) { + throw std::invalid_argument( + "concatenate: all input Variables must be of the same type" + ); + } } - } - dims[dim] = concatSize; - Tensor result(dims, concatInputs[0].type()); - std::vector slice(numDims, fl::span); - int start = 0; - for (const auto& input : concatInputs) { - slice[dim] = fl::range({start, start + input.dim(dim)}); - result(slice) = input.tensor(); - start += input.dim(dim); - } - - std::vector inputsNoData; - std::vector inDims; - - for (const auto& in : concatInputs) { - inputsNoData.push_back(in.withoutData()); - inDims.push_back(in.shape()); - } - - auto gradFunc = [dim, inDims, numDims]( - std::vector& inputs, - const Variable& gradOutput) { - std::vector sx(numDims, fl::span); - int s = 0; - for (size_t i = 0; i < inputs.size(); ++i) { - sx[dim] = fl::range(s, s + inDims[i][dim]); - inputs[i].addGrad(Variable(gradOutput.tensor()(sx), false)); - s += inDims[i][dim]; + // All Variables must have the same number of dims + unsigned numDims = concatInputs[0].ndim(); + for(auto& var : concatInputs) { + if(numDims != var.ndim()) { + throw std::invalid_argument( + "concatenate: all input Variables must " + "have the same number of dimensions" + ); + } + } 
+ + // All Variables must have the same size when indexed along the dim not being + // concatenated along + auto dims = concatInputs[0].shape(); + int concatSize = dims[dim]; + for(int i = 1; i < concatInputs.size(); i++) { + concatSize += concatInputs[i].dim(dim); + for(int d = 0; d < numDims; d++) { + if(dim != d && concatInputs[i].dim(d) != dims[d]) { + throw std::invalid_argument( + "mismatch in dimension not being concatenated" + ); + } + } + } + dims[dim] = concatSize; + Tensor result(dims, concatInputs[0].type()); + std::vector slice(numDims, fl::span); + int start = 0; + for(const auto& input : concatInputs) { + slice[dim] = fl::range({start, start + input.dim(dim)}); + result(slice) = input.tensor(); + start += input.dim(dim); + } + + std::vector inputsNoData; + std::vector inDims; + + for(const auto& in : concatInputs) { + inputsNoData.push_back(in.withoutData()); + inDims.push_back(in.shape()); } - }; - return Variable(result, inputsNoData, gradFunc); + auto gradFunc = [dim, inDims, numDims]( + std::vector& inputs, + const Variable& gradOutput) { + std::vector sx(numDims, fl::span); + int s = 0; + for(size_t i = 0; i < inputs.size(); ++i) { + sx[dim] = fl::range(s, s + inDims[i][dim]); + inputs[i].addGrad(Variable(gradOutput.tensor()(sx), false)); + s += inDims[i][dim]; + } + }; + + return Variable(result, inputsNoData, gradFunc); } std::vector split(const Variable& input, long splitSize, int dim) { - if (splitSize <= 0) { - throw std::invalid_argument("split size must be a positive integer"); - } - auto dimSize = input.dim(dim); - std::vector splitSizes(dimSize / splitSize, splitSize); - - if (dimSize % splitSize > 0) { - splitSizes.push_back(dimSize % splitSize); - } - return split(input, splitSizes, dim); -} - -std::vector -split(const Variable& input, const std::vector& splitSizes, int dim) { - if (dim >= input.ndim()) { - throw std::invalid_argument( - "split: passed dim is larger than the number of dimensions " - "of the input."); - } - auto 
dimSize = input.dim(dim); - auto N = splitSizes.size(); - - std::vector outputs(N); - std::vector sel(input.ndim(), fl::span); - int start = 0; - for (int i = 0; i < N; ++i) { - if (splitSizes[i] <= 0) { - throw std::invalid_argument("elements in split sizes has to be positive"); + if(splitSize <= 0) { + throw std::invalid_argument("split size must be a positive integer"); + } + auto dimSize = input.dim(dim); + std::vector splitSizes(dimSize / splitSize, splitSize); + + if(dimSize % splitSize > 0) { + splitSizes.push_back(dimSize % splitSize); } - int end = start + splitSizes[i]; - sel[dim] = fl::range(start, end); - outputs[i] = input(sel); - start = end; - } - if (start != dimSize) { - throw std::invalid_argument("sum of split sizes must match split dim"); - } - return outputs; + return split(input, splitSizes, dim); +} + +std::vector split(const Variable& input, const std::vector& splitSizes, int dim) { + if(dim >= input.ndim()) { + throw std::invalid_argument( + "split: passed dim is larger than the number of dimensions " + "of the input." 
+ ); + } + auto dimSize = input.dim(dim); + auto N = splitSizes.size(); + + std::vector outputs(N); + std::vector sel(input.ndim(), fl::span); + int start = 0; + for(int i = 0; i < N; ++i) { + if(splitSizes[i] <= 0) { + throw std::invalid_argument("elements in split sizes has to be positive"); + } + int end = start + splitSizes[i]; + sel[dim] = fl::range(start, end); + outputs[i] = input(sel); + start = end; + } + if(start != dimSize) { + throw std::invalid_argument("sum of split sizes must match split dim"); + } + return outputs; } Variable tile(const Variable& input, const Shape& dims) { - Tensor result = fl::tile(input.tensor(), dims); - Shape idims = input.shape(); - auto gradFunc = - [idims](std::vector& inputs, const Variable& gradOutput) { - inputs[0].addGrad(Variable( - sumAs(gradOutput, idims).tensor().astype(inputs[0].type()), false)); - }; - return Variable(result, {input.withoutData()}, gradFunc); + Tensor result = fl::tile(input.tensor(), dims); + Shape idims = input.shape(); + auto gradFunc = + [idims](std::vector& inputs, const Variable& gradOutput) { + inputs[0].addGrad( + Variable( + sumAs(gradOutput, idims).tensor().astype(inputs[0].type()), + false + ) + ); + }; + return Variable(result, {input.withoutData()}, gradFunc); } Variable sum( const Variable& input, const std::vector& axes, - bool keepDims /* = false*/) { - auto result = FL_ADJUST_INPUT_TYPE(input.tensor()); - result = fl::sum(result, axes, keepDims); - - Shape indims = input.shape(); - auto gradFunc = [indims, axes, keepDims]( - std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable( - detail::tileAs( - detail::expandFromReduction(gradOutput.tensor(), axes, keepDims), - indims), - false)); - }; - return Variable(result.astype(input.type()), {input.withoutData()}, gradFunc); + bool keepDims /* = false*/ +) { + auto result = FL_ADJUST_INPUT_TYPE(input.tensor()); + result = fl::sum(result, axes, keepDims); + + Shape indims = input.shape(); + auto gradFunc = 
[indims, axes, keepDims]( + std::vector& inputs, + const Variable& gradOutput) { + inputs[0].addGrad( + Variable( + detail::tileAs( + detail::expandFromReduction(gradOutput.tensor(), axes, keepDims), + indims + ), + false + ) + ); + }; + return Variable(result.astype(input.type()), {input.withoutData()}, gradFunc); } Variable mean( const Variable& input, const std::vector& axes, - bool keepDims /* = false*/) { - auto result = FL_ADJUST_INPUT_TYPE(input.tensor()); - result = mean(result, axes, keepDims); - - Shape idims = input.shape(); - auto gradFunc = [idims, axes, keepDims]( - std::vector& inputs, - const Variable& gradOutput) { - Shape odims = gradOutput.shape(); - Dim count = 1; - for (int i = 0; i < idims.ndim(); i++) { - Dim odimSize = i + 1 > odims.ndim() ? 1 : odims[i]; - count *= idims[i] / odimSize; - } - auto grad = - detail::tileAs( - detail::expandFromReduction(gradOutput.tensor(), axes, keepDims), - idims) / - count; - inputs[0].addGrad(Variable( - detail::tileAs( - detail::expandFromReduction(gradOutput.tensor(), axes, keepDims), - idims) / - count, - false)); - }; - - return Variable(result, {input.withoutData()}, gradFunc); + bool keepDims /* = false*/ +) { + auto result = FL_ADJUST_INPUT_TYPE(input.tensor()); + result = mean(result, axes, keepDims); + + Shape idims = input.shape(); + auto gradFunc = [idims, axes, keepDims]( + std::vector& inputs, + const Variable& gradOutput) { + Shape odims = gradOutput.shape(); + Dim count = 1; + for(int i = 0; i < idims.ndim(); i++) { + Dim odimSize = i + 1 > odims.ndim() ? 
1 : odims[i]; + count *= idims[i] / odimSize; + } + auto grad = + detail::tileAs( + detail::expandFromReduction(gradOutput.tensor(), axes, keepDims), + idims + ) + / count; + inputs[0].addGrad( + Variable( + detail::tileAs( + detail::expandFromReduction(gradOutput.tensor(), axes, keepDims), + idims + ) + / count, + false + ) + ); + }; + + return Variable(result, {input.withoutData()}, gradFunc); } Variable var( const Variable& in, const std::vector& axes, const bool isbiased /* = false */, - bool keepDims /* = false*/) { - Tensor input = FL_ADJUST_INPUT_TYPE(in.tensor()); - auto result = sum(input * input, axes, keepDims); - - auto avg = fl::mean(input, axes, keepDims); - auto n = 1; - for (auto ax : axes) { - n *= input.dim(ax); - } - if (!isbiased && n == 1) { - throw std::invalid_argument( - "cannot compute unbiased variance with only one sample"); - } - auto val = 1.0 / (isbiased ? n : n - 1); - result = val * (result - n * avg * avg); - - auto gradFunc = - [val, axes](std::vector& inputs, const Variable& gradOutput) { - Shape expandedDims = inputs[0].shape(); - Shape tileDims = inputs[0].shape(); - for (auto ax : axes) { - tileDims[ax] = inputs[0].dim(ax); - expandedDims[ax] = 1; - } - - inputs[0].addGrad(Variable( - ((2 * val * tileAs(moddims(gradOutput, expandedDims), tileDims)) * - (inputs[0] - - tileAs(moddims(mean(inputs[0], axes), expandedDims), tileDims))) - .tensor(), - false)); - }; - return Variable(result, {in}, gradFunc); + bool keepDims /* = false*/ +) { + Tensor input = FL_ADJUST_INPUT_TYPE(in.tensor()); + auto result = sum(input * input, axes, keepDims); + + auto avg = fl::mean(input, axes, keepDims); + auto n = 1; + for(auto ax : axes) { + n *= input.dim(ax); + } + if(!isbiased && n == 1) { + throw std::invalid_argument( + "cannot compute unbiased variance with only one sample" + ); + } + auto val = 1.0 / (isbiased ? 
n : n - 1); + result = val * (result - n * avg * avg); + + auto gradFunc = + [val, axes](std::vector& inputs, const Variable& gradOutput) { + Shape expandedDims = inputs[0].shape(); + Shape tileDims = inputs[0].shape(); + for(auto ax : axes) { + tileDims[ax] = inputs[0].dim(ax); + expandedDims[ax] = 1; + } + + inputs[0].addGrad( + Variable( + ((2 * val * tileAs(moddims(gradOutput, expandedDims), tileDims)) + * (inputs[0] + - tileAs(moddims(mean(inputs[0], axes), expandedDims), tileDims))) + .tensor(), + false + ) + ); + }; + return Variable(result, {in}, gradFunc); } Variable norm( const Variable& input, const std::vector& axes, double p /* = 2 */, - bool keepDims /* = false */) { - if (p <= 0) { - throw std::out_of_range("Lp norm: p must be > 0"); - } - auto result = fl::power(fl::abs(FL_ADJUST_INPUT_TYPE(input.tensor())), p); - result = fl::sum(result, axes, /* keepDims = */ keepDims); - - Tensor sumap = detail::expandFromReduction(result, axes, keepDims); - result = fl::power(result, 1 / p); - fl::eval(result); - - auto gradFunc = [sumap, p, axes, keepDims]( - std::vector& inputs, - const Variable& gradOutput) { - // correct, but less precise: auto gvar = Variable(fl::power(result, p - 1), - // false); - auto gvar = Variable(fl::power(sumap, 1 - 1 / p), false); - auto normGrad = - (inputs[0].tensor() * fl::pow(fl::abs(inputs[0]), p - 2).tensor() * - detail::tileAs( - detail::expandFromReduction(gradOutput.tensor(), axes, keepDims) / - gvar.tensor(), - inputs[0].shape())); - inputs[0].addGrad(Variable(normGrad, false)); - }; - return Variable(result, {input}, gradFunc); + bool keepDims /* = false */ +) { + if(p <= 0) { + throw std::out_of_range("Lp norm: p must be > 0"); + } + auto result = fl::power(fl::abs(FL_ADJUST_INPUT_TYPE(input.tensor())), p); + result = fl::sum(result, axes, /* keepDims = */ keepDims); + + Tensor sumap = detail::expandFromReduction(result, axes, keepDims); + result = fl::power(result, 1 / p); + fl::eval(result); + + auto gradFunc = 
[sumap, p, axes, keepDims]( + std::vector& inputs, + const Variable& gradOutput) { + // correct, but less precise: auto gvar = Variable(fl::power(result, p - 1), + // false); + auto gvar = Variable(fl::power(sumap, 1 - 1 / p), false); + auto normGrad = + (inputs[0].tensor() * fl::pow(fl::abs(inputs[0]), p - 2).tensor() + * detail::tileAs( + detail::expandFromReduction(gradOutput.tensor(), axes, keepDims) + / gvar.tensor(), + inputs[0].shape() + )); + inputs[0].addGrad(Variable(normGrad, false)); + }; + return Variable(result, {input}, gradFunc); } Variable normalize( const Variable& in, const std::vector& axes, double p /* = 2 */, - double eps /* = 1e-12 */) { - auto input = FL_ADJUST_INPUT_TYPE(in); - Variable norm = fl::norm(input, axes, p); - Variable invscale = max(norm, eps); - return input / tileAs(invscale, input); + double eps /* = 1e-12 */ +) { + auto input = FL_ADJUST_INPUT_TYPE(in); + Variable norm = fl::norm(input, axes, p); + Variable invscale = max(norm, eps); + return input / tileAs(invscale, input); } Variable matmul(const Variable& lhs, const Variable& rhs) { - FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); - // lhs:Input[0] -- [M, N] - // rhs:Input[1] -- [N, K] - // matmul(lhs, rhs) - // -- matmul([M, N], [N, K]) -- [M, K] - // result:gradOutput -- [M, K] - auto result = fl::matmul(lhs.tensor(), rhs.tensor()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - if (inputs[0].isCalcGrad()) { - Tensor _lhs = gradOutput.tensor(); - if (_lhs.ndim() == 1) { - _lhs = fl::reshape(_lhs, {1, _lhs.dim(0)}); - } - Tensor _rhs = inputs[1].tensor(); - if (_rhs.ndim() == 1) { - _rhs = fl::reshape(_rhs, {_rhs.dim(0), 1}); - } - - // matmulNT(gradOutput, inputs[1]) - // -- matmulNT([M, K], [N, K]) - // -- matmul([M, K], [K, N]) -- [M, K] - auto val = fl::matmul( - _lhs, - _rhs, - /* lhsProp = */ MatrixProperty::None, - /* rhsProp = */ MatrixProperty::Transpose); - inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false)); - } - 
if (inputs[1].isCalcGrad()) { - Tensor _lhs = inputs[0].tensor(); - if (_lhs.ndim() == 1) { - _lhs = fl::reshape(_lhs, {1, _lhs.dim(0)}); - } - Tensor _rhs = gradOutput.tensor(); - if (_rhs.ndim() == 1) { - _rhs = fl::reshape(_rhs, {_rhs.dim(0), 1}); - } - - // matmulTN(inputs[0], gradOutput) - // -- matmulTN([M, N], [M, K]) - // -- matmul([N, M], [M, K]) -- [N, K] - auto val = fl::matmul( - _lhs, - _rhs, - /* lhsProp = */ MatrixProperty::Transpose); - inputs[1].addGrad(Variable(detail::sumAs(val, inputs[1].shape()), false)); - } - }; - return Variable(result, {lhs, rhs}, gradFunc); + FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); + // lhs:Input[0] -- [M, N] + // rhs:Input[1] -- [N, K] + // matmul(lhs, rhs) + // -- matmul([M, N], [N, K]) -- [M, K] + // result:gradOutput -- [M, K] + auto result = fl::matmul(lhs.tensor(), rhs.tensor()); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + if(inputs[0].isCalcGrad()) { + Tensor _lhs = gradOutput.tensor(); + if(_lhs.ndim() == 1) { + _lhs = fl::reshape(_lhs, {1, _lhs.dim(0)}); + } + Tensor _rhs = inputs[1].tensor(); + if(_rhs.ndim() == 1) { + _rhs = fl::reshape(_rhs, {_rhs.dim(0), 1}); + } + + // matmulNT(gradOutput, inputs[1]) + // -- matmulNT([M, K], [N, K]) + // -- matmul([M, K], [K, N]) -- [M, K] + auto val = fl::matmul( + _lhs, + _rhs, + /* lhsProp = */ MatrixProperty::None, + /* rhsProp = */ MatrixProperty::Transpose + ); + inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false)); + } + if(inputs[1].isCalcGrad()) { + Tensor _lhs = inputs[0].tensor(); + if(_lhs.ndim() == 1) { + _lhs = fl::reshape(_lhs, {1, _lhs.dim(0)}); + } + Tensor _rhs = gradOutput.tensor(); + if(_rhs.ndim() == 1) { + _rhs = fl::reshape(_rhs, {_rhs.dim(0), 1}); + } + + // matmulTN(inputs[0], gradOutput) + // -- matmulTN([M, N], [M, K]) + // -- matmul([N, M], [M, K]) -- [N, K] + auto val = fl::matmul( + _lhs, + _rhs, + /* lhsProp = */ MatrixProperty::Transpose + ); + inputs[1].addGrad(Variable(detail::sumAs(val, 
inputs[1].shape()), false)); + } + }; + return Variable(result, {lhs, rhs}, gradFunc); } Variable matmulTN(const Variable& lhs, const Variable& rhs) { - FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); - // lhs:Input[0] -- [N, M] - // rhs:Input[1] -- [N, K] - // matmulTN(lhs, rhs) - // -- matmulTN([N, M], [N, K]) - // -- matmul([M, N], [N, K]) -- [M, K] - // result:gradOutput -- [M, K] - auto result = fl::matmul( - lhs.tensor(), rhs.tensor(), /* lhsProp = */ MatrixProperty::Transpose); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - if (inputs[0].isCalcGrad()) { - // matmulNT(inputs[1], gradOutput) - // -- matmulNT([N, K], [M, K]) - // -- matmul([N, K], [K, M]) -- [N, M] - auto val = fl::matmul( - inputs[1].tensor(), - gradOutput.tensor(), - /* lhsProp = */ MatrixProperty::None, - /* rhsProp = */ MatrixProperty::Transpose); - inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false)); - } - if (inputs[1].isCalcGrad()) { - // matmul(inputs[0], gradOutput) - // -- matmulNT([N, M], [M, K]) -- [N, K] - auto val = fl::matmul(inputs[0].tensor(), gradOutput.tensor()); - inputs[1].addGrad(Variable(detail::sumAs(val, inputs[1].shape()), false)); - } - }; - return Variable(result, {lhs, rhs}, gradFunc); + FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); + // lhs:Input[0] -- [N, M] + // rhs:Input[1] -- [N, K] + // matmulTN(lhs, rhs) + // -- matmulTN([N, M], [N, K]) + // -- matmul([M, N], [N, K]) -- [M, K] + // result:gradOutput -- [M, K] + auto result = fl::matmul( + lhs.tensor(), + rhs.tensor(), /* lhsProp = */ + MatrixProperty::Transpose + ); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + if(inputs[0].isCalcGrad()) { + // matmulNT(inputs[1], gradOutput) + // -- matmulNT([N, K], [M, K]) + // -- matmul([N, K], [K, M]) -- [N, M] + auto val = fl::matmul( + inputs[1].tensor(), + gradOutput.tensor(), + /* lhsProp = */ MatrixProperty::None, + /* rhsProp = */ MatrixProperty::Transpose + ); + 
inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false)); + } + if(inputs[1].isCalcGrad()) { + // matmul(inputs[0], gradOutput) + // -- matmulNT([N, M], [M, K]) -- [N, K] + auto val = fl::matmul(inputs[0].tensor(), gradOutput.tensor()); + inputs[1].addGrad(Variable(detail::sumAs(val, inputs[1].shape()), false)); + } + }; + return Variable(result, {lhs, rhs}, gradFunc); } Variable matmulNT(const Variable& lhs, const Variable& rhs) { - FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); - // lhs:Input[0] -- [M, N] - // rhs:Input[1] -- [K, N] - // matmulNT(lhs, rhs) - // -- matmulNT([M, N], [K, N]) - // -- matmul([M, N], [N, K]) -- [M, K] - // result:gradOutput -- [M, K] - auto result = fl::matmul( - lhs.tensor(), - rhs.tensor(), - /* lhsProp = */ MatrixProperty::None, - /* rhsProp = */ MatrixProperty::Transpose); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - if (inputs[0].isCalcGrad()) { - // matmul(gradOutput, inputs[1]) - // -- matmul([M, K], [K, N]) -- [M, N] - auto val = fl::matmul(gradOutput.tensor(), inputs[1].tensor()); - inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false)); - } - if (inputs[1].isCalcGrad()) { - // matmulTN(gradOutput, inputs[0]) - // -- matmulTN([M, K], [M, N]) - // -- matmul([K, M], [M, N]) -- [K, N] - auto val = fl::matmul( - gradOutput.tensor(), - inputs[0].tensor(), - /* lhsProp = */ MatrixProperty::Transpose); - inputs[1].addGrad(Variable(detail::sumAs(val, inputs[1].shape()), false)); - } - }; - return Variable(result, {lhs, rhs}, gradFunc); + FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); + // lhs:Input[0] -- [M, N] + // rhs:Input[1] -- [K, N] + // matmulNT(lhs, rhs) + // -- matmulNT([M, N], [K, N]) + // -- matmul([M, N], [N, K]) -- [M, K] + // result:gradOutput -- [M, K] + auto result = fl::matmul( + lhs.tensor(), + rhs.tensor(), + /* lhsProp = */ MatrixProperty::None, + /* rhsProp = */ MatrixProperty::Transpose + ); + auto gradFunc = [](std::vector& inputs, + const Variable& 
gradOutput) { + if(inputs[0].isCalcGrad()) { + // matmul(gradOutput, inputs[1]) + // -- matmul([M, K], [K, N]) -- [M, N] + auto val = fl::matmul(gradOutput.tensor(), inputs[1].tensor()); + inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false)); + } + if(inputs[1].isCalcGrad()) { + // matmulTN(gradOutput, inputs[0]) + // -- matmulTN([M, K], [M, N]) + // -- matmul([K, M], [M, N]) -- [K, N] + auto val = fl::matmul( + gradOutput.tensor(), + inputs[0].tensor(), + /* lhsProp = */ MatrixProperty::Transpose + ); + inputs[1].addGrad(Variable(detail::sumAs(val, inputs[1].shape()), false)); + } + }; + return Variable(result, {lhs, rhs}, gradFunc); } Variable abs(const Variable& input) { - auto result = fl::abs(input.tensor()); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - // Convert it into -1, 0, 1 - auto sign = fl::sign(inputs[0].tensor()); - inputs[0].addGrad(Variable((sign * gradOutput.tensor()), false)); - }; - return Variable(result, {input}, gradFunc); + auto result = fl::abs(input.tensor()); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + // Convert it into -1, 0, 1 + auto sign = fl::sign(inputs[0].tensor()); + inputs[0].addGrad(Variable((sign * gradOutput.tensor()), false)); + }; + return Variable(result, {input}, gradFunc); } Variable flat(const Variable& input) { - auto result = input.tensor().flatten(); - Shape idims = input.shape(); - auto gradFunc = - [idims](std::vector& inputs, const Variable& gradOutput) { - inputs[0].addGrad(Variable(reshape(gradOutput.tensor(), idims), false)); - }; - return Variable(result, {input.withoutData()}, gradFunc); + auto result = input.tensor().flatten(); + Shape idims = input.shape(); + auto gradFunc = + [idims](std::vector& inputs, const Variable& gradOutput) { + inputs[0].addGrad(Variable(reshape(gradOutput.tensor(), idims), false)); + }; + return Variable(result, {input.withoutData()}, gradFunc); } Variable moddims(const Variable& input, const Shape& 
dims) { - if (input.ndim() == 0) { - return input; - } - Shape inferDims = dims; - unsigned maxNDims = - std::max(input.ndim(), static_cast(dims.ndim())); - - // Check for inferred dims that are beyond the input's number of dims - for (int i = 0; i < maxNDims; ++i) { - if (i >= input.ndim() && inferDims[i] == 0) { - throw std::invalid_argument( - "moddims: tried to infer dimension " + std::to_string(i) + - " which exceeds the number of dimensions of the input."); + if(input.ndim() == 0) { + return input; + } + Shape inferDims = dims; + unsigned maxNDims = + std::max(input.ndim(), static_cast(dims.ndim())); + + // Check for inferred dims that are beyond the input's number of dims + for(int i = 0; i < maxNDims; ++i) { + if(i >= input.ndim() && inferDims[i] == 0) { + throw std::invalid_argument( + "moddims: tried to infer dimension " + std::to_string(i) + + " which exceeds the number of dimensions of the input." + ); + } } - } - // Infer any 0 dim - for (int i = 0; i < maxNDims; ++i) { - if (i < inferDims.ndim() && inferDims[i] == 0) { - inferDims[i] = input.dim(i); + // Infer any 0 dim + for(int i = 0; i < maxNDims; ++i) { + if(i < inferDims.ndim() && inferDims[i] == 0) { + inferDims[i] = input.dim(i); + } } - } - - // Infer any -1 dim - int nInfer = 0; - for (int i = 0; i < maxNDims; i++) { - if (i < inferDims.ndim() && inferDims[i] == -1) { - nInfer++; - inferDims[i] = -(input.elements() / inferDims.elements()); + + // Infer any -1 dim + int nInfer = 0; + for(int i = 0; i < maxNDims; i++) { + if(i < inferDims.ndim() && inferDims[i] == -1) { + nInfer++; + inferDims[i] = -(input.elements() / inferDims.elements()); + } } - } - if (nInfer > 1) { - throw std::invalid_argument("moddims: too many dimensions infer"); - } + if(nInfer > 1) { + throw std::invalid_argument("moddims: too many dimensions infer"); + } - if (inferDims.elements() != input.elements()) { - throw std::invalid_argument("moddims: mismatched # of elements"); - } + if(inferDims.elements() != 
input.elements()) { + throw std::invalid_argument("moddims: mismatched # of elements"); + } - auto result = fl::reshape(input.tensor(), inferDims); + auto result = fl::reshape(input.tensor(), inferDims); - Shape inDims = input.shape(); - auto gradFunc = [inDims]( - std::vector& inputs, - const Variable& gradOutput) { - inputs[0].addGrad(Variable(moddims(gradOutput, inDims).tensor(), false)); - }; - return Variable(result, {input.withoutData()}, gradFunc); + Shape inDims = input.shape(); + auto gradFunc = [inDims]( + std::vector& inputs, + const Variable& gradOutput) { + inputs[0].addGrad(Variable(moddims(gradOutput, inDims).tensor(), false)); + }; + return Variable(result, {input.withoutData()}, gradFunc); } Variable softmax(const Variable& input, const int dim) { - Tensor inputArr = FL_ADJUST_INPUT_TYPE(input.tensor()); - auto maxvals = amax(inputArr, {dim}, /* keepDims = */ true); - Shape tiledims(std::vector(input.ndim(), 1)); - tiledims[dim] = input.dim(dim); - - auto expInput = fl::exp(inputArr - fl::tile(maxvals, tiledims)); - auto result = expInput / - fl::tile(fl::sum(expInput, {dim}, /* keepDims = */ true), tiledims); - - fl::eval(result); - auto gradFunc = [dim, tiledims, result]( - std::vector& inputs, - const Variable& gradOutput) { - auto rbyg = gradOutput.tensor() * result; - auto gradSm = rbyg - - result * - fl::tile(fl::sum(rbyg, {dim}, /* keepDims = */ true), tiledims); - inputs[0].addGrad(Variable(gradSm.astype(inputs[0].type()), false)); - }; - return Variable(result, {input.withoutData()}, gradFunc); + Tensor inputArr = FL_ADJUST_INPUT_TYPE(input.tensor()); + auto maxvals = amax(inputArr, {dim}, /* keepDims = */ true); + Shape tiledims(std::vector(input.ndim(), 1)); + tiledims[dim] = input.dim(dim); + + auto expInput = fl::exp(inputArr - fl::tile(maxvals, tiledims)); + auto result = expInput + / fl::tile(fl::sum(expInput, {dim}, /* keepDims = */ true), tiledims); + + fl::eval(result); + auto gradFunc = [dim, tiledims, result]( + std::vector& 
inputs, + const Variable& gradOutput) { + auto rbyg = gradOutput.tensor() * result; + auto gradSm = rbyg + - result + * fl::tile(fl::sum(rbyg, {dim}, /* keepDims = */ true), tiledims); + inputs[0].addGrad(Variable(gradSm.astype(inputs[0].type()), false)); + }; + return Variable(result, {input.withoutData()}, gradFunc); } Variable logSoftmax(const Variable& input, const int dim) { - Tensor inputArr = FL_ADJUST_INPUT_TYPE(input.tensor()); - auto maxvals = amax(inputArr, {dim}, /* keepDims = */ true); - // TODO{fl::Tensor}{rewrite} - Shape tiledims(std::vector(input.ndim(), 1)); - tiledims[dim] = input.dim(dim); - auto result = inputArr - - fl::tile(fl::log(fl::sum( - fl::exp(inputArr - fl::tile(maxvals, tiledims)), - {dim}, - /* keepDims = */ true)) + - maxvals, - tiledims); - - fl::eval(result); - auto gradFunc = [dim, tiledims, result]( - std::vector& inputs, - const Variable& gradOutput) { - auto gradLsm = gradOutput.tensor() - - fl::exp(result) * - fl::tile( - fl::sum(gradOutput.tensor(), {dim}, /* keepDims = */ true), - tiledims); - inputs[0].addGrad(Variable(gradLsm.astype(inputs[0].type()), false)); - }; - return Variable(result, {input.withoutData()}, gradFunc); + Tensor inputArr = FL_ADJUST_INPUT_TYPE(input.tensor()); + auto maxvals = amax(inputArr, {dim}, /* keepDims = */ true); + // TODO{fl::Tensor}{rewrite} + Shape tiledims(std::vector(input.ndim(), 1)); + tiledims[dim] = input.dim(dim); + auto result = inputArr + - fl::tile( + fl::log( + fl::sum( + fl::exp(inputArr - fl::tile(maxvals, tiledims)), + {dim}, + /* keepDims = */ true + ) + ) + + maxvals, + tiledims + ); + + fl::eval(result); + auto gradFunc = [dim, tiledims, result]( + std::vector& inputs, + const Variable& gradOutput) { + auto gradLsm = gradOutput.tensor() + - fl::exp(result) + * fl::tile( + fl::sum(gradOutput.tensor(), {dim}, /* keepDims = */ true), + tiledims + ); + inputs[0].addGrad(Variable(gradLsm.astype(inputs[0].type()), false)); + }; + return Variable(result, {input.withoutData()}, 
gradFunc); } Variable binaryCrossEntropy(const Variable& inputs, const Variable& targets) { - auto targetsTyped = targets.astype(inputs.type()); - return negate( - targetsTyped * log(inputs) + (1 - targetsTyped) * log(1 - inputs)); + auto targetsTyped = targets.astype(inputs.type()); + return negate( + targetsTyped * log(inputs) + (1 - targetsTyped) * log(1 - inputs) + ); } Variable categoricalCrossEntropy( const Variable& in, const Variable& targets, ReduceMode reduction /* =ReduceMode::MEAN */, - int ignoreIndex /* = -1 */) { - auto input = FL_ADJUST_INPUT_TYPE(in); - // input -- [C, X1, X2, X3] - // target -- [X1, X2, X3, 1] - if (input.ndim() != targets.ndim() + 1) { - throw std::invalid_argument( - "dimension mismatch in categorical cross entropy: " - "target must have one fewer dimension than input"); - } - for (int i = 1; i < input.ndim(); i++) { - if (input.dim(i) != targets.dim(i - 1)) { - throw std::invalid_argument( - "dimension mismatch in categorical cross entropy"); + int ignoreIndex /* = -1 */ +) { + auto input = FL_ADJUST_INPUT_TYPE(in); + // input -- [C, X1, X2, X3] + // target -- [X1, X2, X3, 1] + if(input.ndim() != targets.ndim() + 1) { + throw std::invalid_argument( + "dimension mismatch in categorical cross entropy: " + "target must have one fewer dimension than input" + ); } - } - - int C = input.dim(0); - int X = targets.elements(); - if (fl::any( - ((targets.tensor() < 0) || (targets.tensor() >= C)) && - (targets.tensor() != ignoreIndex)) - .scalar()) { - throw std::invalid_argument( - "target contains elements out of valid range [0, num_categories) " - "in categorical cross entropy"); - } - - auto x = fl::reshape(input.tensor(), Shape({C, X})); - auto y = fl::reshape(targets.tensor(), Shape({1, X})); - - auto A = fl::arange(Shape({C, X})); - auto B = fl::tile(y, Shape({C})); - auto mask = -(A == B); // [C X] - - auto result = mask * x; - auto ignoreMask = (y == ignoreIndex).flatten(); // [X, 1] - result = fl::sum(result, {0}).flatten(); // 
[X, 1] - result(ignoreMask) = 0.; - - Tensor denominator; - if (reduction == ReduceMode::NONE) { - result = fl::reshape(result, targets.shape()); // [X1 X2 X3] - } else if (reduction == ReduceMode::MEAN) { - denominator = fl::sum((!ignoreMask).astype(fl::dtype::s32), {0}); - result = fl::sum(result, {0}) / denominator; // [1] - } else if (reduction == ReduceMode::SUM) { - result = fl::sum(result, {0}); // [1] - } else { - throw std::invalid_argument( - "unknown reduction method for categorical cross entropy"); - } - - auto inputDims = input.shape(); - auto gradFunc = [C, X, mask, ignoreMask, denominator, reduction, inputDims]( - std::vector& inputs, - const Variable& gradOutput) { - Tensor grad = gradOutput.tensor(); - if (reduction == ReduceMode::NONE) { - grad = fl::reshape(grad, {X}); - } else if (reduction == ReduceMode::MEAN) { - grad = fl::tile(grad / denominator, {X}); - } else if (reduction == ReduceMode::SUM) { - grad = fl::tile(grad, {X}); + for(int i = 1; i < input.ndim(); i++) { + if(input.dim(i) != targets.dim(i - 1)) { + throw std::invalid_argument( + "dimension mismatch in categorical cross entropy" + ); + } + } + + int C = input.dim(0); + int X = targets.elements(); + if( + fl::any( + ((targets.tensor() < 0) || (targets.tensor() >= C)) + && (targets.tensor() != ignoreIndex) + ) + .scalar() + ) { + throw std::invalid_argument( + "target contains elements out of valid range [0, num_categories) " + "in categorical cross entropy" + ); + } + + auto x = fl::reshape(input.tensor(), Shape({C, X})); + auto y = fl::reshape(targets.tensor(), Shape({1, X})); + + auto A = fl::arange(Shape({C, X})); + auto B = fl::tile(y, Shape({C})); + auto mask = -(A == B); // [C X] + + auto result = mask * x; + auto ignoreMask = (y == ignoreIndex).flatten(); // [X, 1] + result = fl::sum(result, {0}).flatten(); // [X, 1] + result(ignoreMask) = 0.; + + Tensor denominator; + if(reduction == ReduceMode::NONE) { + result = fl::reshape(result, targets.shape()); // [X1 X2 X3] + } 
else if(reduction == ReduceMode::MEAN) { + denominator = fl::sum((!ignoreMask).astype(fl::dtype::s32), {0}); + result = fl::sum(result, {0}) / denominator; // [1] + } else if(reduction == ReduceMode::SUM) { + result = fl::sum(result, {0}); // [1] + } else { + throw std::invalid_argument( + "unknown reduction method for categorical cross entropy" + ); } - // [1 X] - grad(ignoreMask) = 0.; - grad = fl::reshape(grad, {1, X}); - grad = fl::tile(grad, {C}) * mask; - inputs[0].addGrad(Variable(fl::reshape(grad, inputDims), false)); - }; - return Variable(result, {input.withoutData(), targets}, gradFunc); + auto inputDims = input.shape(); + auto gradFunc = [C, X, mask, ignoreMask, denominator, reduction, inputDims]( + std::vector& inputs, + const Variable& gradOutput) { + Tensor grad = gradOutput.tensor(); + if(reduction == ReduceMode::NONE) { + grad = fl::reshape(grad, {X}); + } else if(reduction == ReduceMode::MEAN) { + grad = fl::tile(grad / denominator, {X}); + } else if(reduction == ReduceMode::SUM) { + grad = fl::tile(grad, {X}); + } + // [1 X] + grad(ignoreMask) = 0.; + grad = fl::reshape(grad, {1, X}); + grad = fl::tile(grad, {C}) * mask; + inputs[0].addGrad(Variable(fl::reshape(grad, inputDims), false)); + }; + + return Variable(result, {input.withoutData(), targets}, gradFunc); } Variable weightedCategoricalCrossEntropy( const Variable& input, const Variable& targets, const Variable& weight, - int ignoreIndex /* = -1 */) { - // input -- [C, X1, X2, X3] - // target -- [X1, X2, X3] - if (input.ndim() < targets.ndim() - 1) { - throw std::invalid_argument( - "weightedCategoricalCrossEntropy: input must have one more than the " - "number of target dimensions minus 1"); - } - - for (int i = 1; i < targets.ndim() - 2; i++) { - if (input.dim(i) != targets.dim(i - 1)) { - throw std::invalid_argument( - "weightedCategoricalCrossEntropy: dimension mismatch in categorical cross entropy"); + int ignoreIndex /* = -1 */ +) { + // input -- [C, X1, X2, X3] + // target -- [X1, 
X2, X3] + if(input.ndim() < targets.ndim() - 1) { + throw std::invalid_argument( + "weightedCategoricalCrossEntropy: input must have one more than the " + "number of target dimensions minus 1" + ); + } + + for(int i = 1; i < targets.ndim() - 2; i++) { + if(input.dim(i) != targets.dim(i - 1)) { + throw std::invalid_argument( + "weightedCategoricalCrossEntropy: dimension mismatch in categorical cross entropy" + ); + } + } + if(weight.dim(0) != input.dim(0)) { + throw std::invalid_argument( + "weightedCategoricalCrossEntropy: dimension mismatch in categorical cross entropy" + ); } - } - if (weight.dim(0) != input.dim(0)) { - throw std::invalid_argument( - "weightedCategoricalCrossEntropy: dimension mismatch in categorical cross entropy"); - } - - int C = input.dim(0); - int X = targets.elements(); - if (fl::any((targets.tensor() < 0) || (targets.tensor() >= C)) - .scalar()) { - throw std::invalid_argument( - "weightedCategoricalCrossEntropy: target contains elements out of valid range " - "[0, num_categories) in categorical cross entropy"); - } - - auto x = fl::reshape(input.tensor(), {C, X}); - auto y = fl::reshape(targets.tensor(), {1, X}); - - auto A = fl::arange({C, X}); - auto B = fl::tile(y, {C}); - auto mask = -(A == B); // [C X] - - auto weightSum = (-mask) * fl::tile(weight.tensor(), {1, X}); - Variable denominator = {fl::sum(weightSum, {0, 1}), false}; - - auto result = mask * x; - result = result * weight.tensor(); - - auto ignoreMask = (y != ignoreIndex).astype(fl::dtype::s32); // [1, X] - result = ignoreMask * fl::sum(result, {0}, /* keepDims = */ true); // [1, X] - result = fl::sum(result, {1}, /* keepDims = */ true) / denominator.tensor(); - - auto inputDims = input.shape(); - auto gradFunc = [C, X, mask, ignoreMask, denominator, inputDims]( - std::vector& inputs, - const Variable& gradOutput) { - auto grad = gradOutput.tensor(); - grad = fl::tile(grad / denominator.tensor(), {1, X}); - - auto weightTensor = inputs[2].tensor(); - grad *= ignoreMask; - 
grad = fl::tile(grad, {C}) * mask; - grad = fl::reshape(grad, inputDims); - grad = grad * weightTensor; - inputs[0].addGrad(Variable(fl::reshape(grad, inputDims), false)); - }; - - return Variable(result, {input.withoutData(), targets, weight}, gradFunc); + + int C = input.dim(0); + int X = targets.elements(); + if( + fl::any((targets.tensor() < 0) || (targets.tensor() >= C)) + .scalar() + ) { + throw std::invalid_argument( + "weightedCategoricalCrossEntropy: target contains elements out of valid range " + "[0, num_categories) in categorical cross entropy" + ); + } + + auto x = fl::reshape(input.tensor(), {C, X}); + auto y = fl::reshape(targets.tensor(), {1, X}); + + auto A = fl::arange({C, X}); + auto B = fl::tile(y, {C}); + auto mask = -(A == B); // [C X] + + auto weightSum = (-mask) * fl::tile(weight.tensor(), {1, X}); + Variable denominator = {fl::sum(weightSum, {0, 1}), false}; + + auto result = mask * x; + result = result * weight.tensor(); + + auto ignoreMask = (y != ignoreIndex).astype(fl::dtype::s32); // [1, X] + result = ignoreMask * fl::sum(result, {0}, /* keepDims = */ true); // [1, X] + result = fl::sum(result, {1}, /* keepDims = */ true) / denominator.tensor(); + + auto inputDims = input.shape(); + auto gradFunc = [C, X, mask, ignoreMask, denominator, inputDims]( + std::vector& inputs, + const Variable& gradOutput) { + auto grad = gradOutput.tensor(); + grad = fl::tile(grad / denominator.tensor(), {1, X}); + + auto weightTensor = inputs[2].tensor(); + grad *= ignoreMask; + grad = fl::tile(grad, {C}) * mask; + grad = fl::reshape(grad, inputDims); + grad = grad * weightTensor; + inputs[0].addGrad(Variable(fl::reshape(grad, inputDims), false)); + }; + + return Variable(result, {input.withoutData(), targets, weight}, gradFunc); } Variable reorder(const Variable& input, const Shape& shape) { - auto result = fl::transpose(input.tensor(), shape); - if (!result.isContiguous()) { - result = result.asContiguousTensor(); - } - - std::vector> 
dimGrad(shape.ndim()); - for (unsigned i = 0; i < shape.ndim(); ++i) { - dimGrad[i] = {shape.dim(i), i}; - } - - std::sort(dimGrad.begin(), dimGrad.end()); - - auto gradFunc = - [dimGrad](std::vector& inputs, const Variable& gradOutput) { - Shape reordered(std::vector(dimGrad.size())); - for (unsigned i = 0; i < dimGrad.size(); ++i) { - reordered[i] = dimGrad[i].second; - } + auto result = fl::transpose(input.tensor(), shape); + if(!result.isContiguous()) { + result = result.asContiguousTensor(); + } - inputs[0].addGrad( - Variable(fl::transpose(gradOutput.tensor(), reordered), false)); - }; - return Variable(result, {input.withoutData()}, gradFunc); + std::vector> dimGrad(shape.ndim()); + for(unsigned i = 0; i < shape.ndim(); ++i) { + dimGrad[i] = {shape.dim(i), i}; + } + + std::sort(dimGrad.begin(), dimGrad.end()); + + auto gradFunc = + [dimGrad](std::vector& inputs, const Variable& gradOutput) { + Shape reordered(std::vector(dimGrad.size())); + for(unsigned i = 0; i < dimGrad.size(); ++i) { + reordered[i] = dimGrad[i].second; + } + + inputs[0].addGrad( + Variable(fl::transpose(gradOutput.tensor(), reordered), false) + ); + }; + return Variable(result, {input.withoutData()}, gradFunc); } Variable linear(const Variable& input, const Variable& weight) { - auto dummyBias = Variable(Tensor().astype(input.type()), false); - return linear(input, weight, dummyBias); + auto dummyBias = Variable(Tensor().astype(input.type()), false); + return linear(input, weight, dummyBias); } Variable linear(const Variable& in, const Variable& wt, const Variable& bs) { - FL_VARIABLE_DTYPES_MATCH_CHECK(in, wt, bs); - auto input = FL_ADJUST_INPUT_TYPE(in); - auto weight = FL_ADJUST_INPUT_TYPE(wt); - auto bias = FL_ADJUST_INPUT_TYPE(bs); - - Shape to2d({input.dim(0), input.elements() / input.dim(0)}); - auto to4d = input.shape(); - to4d[0] = weight.tensor().dim(0); - - auto output = - reshape(fl::matmul(weight.tensor(), reshape(input.tensor(), to2d)), to4d); - - auto hasBias = 
bias.elements() > 0; - if (hasBias) { - auto tiledims = output.shape(); - tiledims[0] = 1; - output = output + tile(bias.tensor(), tiledims); - } - - auto gradFunc = [hasBias]( - std::vector& inputs, - const Variable& gradOutput) { - auto& in = inputs[0]; - auto& wt = inputs[1]; - Tensor wtTensor = wt.tensor(); - Tensor gradOutputTensor = gradOutput.tensor(); - - auto nframes = in.elements() / in.dim(0); - - if (hasBias && inputs[2].isCalcGrad()) { - auto& bs = inputs[2]; - auto biasGrad = sumAs(gradOutput, bs).tensor(); - bs.addGrad(Variable(biasGrad, false)); - } - if (in.isCalcGrad()) { - Shape to2dout({wtTensor.dim(0), nframes}); - auto inGrad = - moddims(matmulTN(wt, moddims(gradOutput, to2dout)), in.shape()) - .tensor(); - in.addGrad(Variable(inGrad, false)); + FL_VARIABLE_DTYPES_MATCH_CHECK(in, wt, bs); + auto input = FL_ADJUST_INPUT_TYPE(in); + auto weight = FL_ADJUST_INPUT_TYPE(wt); + auto bias = FL_ADJUST_INPUT_TYPE(bs); + + Shape to2d({input.dim(0), input.elements() / input.dim(0)}); + auto to4d = input.shape(); + to4d[0] = weight.tensor().dim(0); + + auto output = + reshape(fl::matmul(weight.tensor(), reshape(input.tensor(), to2d)), to4d); + + auto hasBias = bias.elements() > 0; + if(hasBias) { + auto tiledims = output.shape(); + tiledims[0] = 1; + output = output + tile(bias.tensor(), tiledims); } - if (wt.isCalcGrad()) { - Shape to2din({wtTensor.dim(1), nframes}); - Shape to2dout({wtTensor.dim(0), nframes}); - auto wtGrad = - matmulNT(moddims(gradOutput, to2dout), moddims(in, to2din)).tensor(); - wt.addGrad(Variable(wtGrad, false)); + + auto gradFunc = [hasBias]( + std::vector& inputs, + const Variable& gradOutput) { + auto& in = inputs[0]; + auto& wt = inputs[1]; + Tensor wtTensor = wt.tensor(); + Tensor gradOutputTensor = gradOutput.tensor(); + + auto nframes = in.elements() / in.dim(0); + + if(hasBias && inputs[2].isCalcGrad()) { + auto& bs = inputs[2]; + auto biasGrad = sumAs(gradOutput, bs).tensor(); + bs.addGrad(Variable(biasGrad, false)); + } + 
if(in.isCalcGrad()) { + Shape to2dout({wtTensor.dim(0), nframes}); + auto inGrad = + moddims(matmulTN(wt, moddims(gradOutput, to2dout)), in.shape()) + .tensor(); + in.addGrad(Variable(inGrad, false)); + } + if(wt.isCalcGrad()) { + Shape to2din({wtTensor.dim(1), nframes}); + Shape to2dout({wtTensor.dim(0), nframes}); + auto wtGrad = + matmulNT(moddims(gradOutput, to2dout), moddims(in, to2din)).tensor(); + wt.addGrad(Variable(wtGrad, false)); + } + }; + if(hasBias) { + return Variable(output, {input, weight, bias}, gradFunc); } - }; - if (hasBias) { - return Variable(output, {input, weight, bias}, gradFunc); - } - return Variable(output, {input, weight}, gradFunc); + return Variable(output, {input, weight}, gradFunc); } Variable conv2d( @@ -1318,10 +1417,22 @@ Variable conv2d( int dx, int dy, int groups, - std::shared_ptr benchmarks) { - auto dummyBias = Variable(Tensor(input.type()), false); - return conv2d( - input, weights, dummyBias, sx, sy, px, py, dx, dy, groups, benchmarks); + std::shared_ptr benchmarks +) { + auto dummyBias = Variable(Tensor(input.type()), false); + return conv2d( + input, + weights, + dummyBias, + sx, + sy, + px, + py, + dx, + dy, + groups, + benchmarks + ); } Variable conv2d( @@ -1335,125 +1446,129 @@ Variable conv2d( int dx, int dy, int groups, - std::shared_ptr benchmarks) { - FL_VARIABLE_DTYPES_MATCH_CHECK(in, wt, bs); - - auto payload = detail::createAutogradPayload(in, wt, bs); - - bool hasBias = !bs.isEmpty(); - - auto input = FL_ADJUST_INPUT_TYPE(in); - auto weights = FL_ADJUST_INPUT_TYPE(wt); - auto bias = FL_ADJUST_INPUT_TYPE(bs); - - Tensor output = detail::conv2d( - input.tensor(), - weights.tensor(), - bias.tensor(), - sx, - sy, - px, - py, - dx, - dy, - groups, - payload); - - auto gradFunc = - [sx, sy, px, py, dx, dy, hasBias, groups, benchmarks, payload]( - std::vector& inputs, const Variable& gradOutput) { - // Create benchmarks if needed - auto& autogradExtension = - inputs[0].tensor().backend().getExtension(); - - 
std::shared_ptr dataBench; - std::shared_ptr filterBench; - std::shared_ptr biasBench; - if (benchmarks && DynamicBenchmark::getBenchmarkMode()) { - if (!benchmarks->bwdFilterBenchmark) { - benchmarks->bwdFilterBenchmark = - autogradExtension.createBenchmarkOptions(); - filterBench = benchmarks->bwdFilterBenchmark; - } - if (!benchmarks->bwdDataBenchmark) { - benchmarks->bwdDataBenchmark = - autogradExtension.createBenchmarkOptions(); - dataBench = benchmarks->bwdDataBenchmark; - } - if (!benchmarks->bwdBiasBenchmark) { - benchmarks->bwdBiasBenchmark = - autogradExtension.createBenchmarkOptions(); - biasBench = benchmarks->bwdBiasBenchmark; - } - } - - // Bias gradients - Tensor bs; - const bool computeBiasGrad = - inputs.size() > 2 && inputs[2].isCalcGrad(); - if (hasBias && computeBiasGrad) { - bs = inputs[2].tensor(); - // auto biasGrad = - // bs.backend().getExtension().conv2dBackwardBias( - // gradOutput.tensor(), bs, biasBench, payload); - - // inputs[2].addGrad(Variable(biasGrad, false)); // bias - } - - auto& in = inputs[0].tensor(); - auto& wt = inputs[1].tensor(); - - // Data (input) gradients - if (inputs[0].isCalcGrad()) { - auto dataGrad = - in.backend().getExtension().conv2dBackwardData( - gradOutput.tensor(), - in, - wt, - sx, - sy, - px, - py, - dx, - dy, - groups, - dataBench, - payload); - - inputs[0].addGrad(Variable(dataGrad, false)); // input/data - } - - // Filter (weight) and bias gradients - if (inputs[1].isCalcGrad() || computeBiasGrad) { - auto [filterGrad, biasGrad] = wt.backend() - .getExtension() - .conv2dBackwardFilterBias( - gradOutput.tensor(), - in, - wt, - bs, - sx, - sy, - px, - py, - dx, - dy, - groups, - filterBench, - biasBench, - payload); - if (inputs[1].isCalcGrad()) { - inputs[1].addGrad(Variable(filterGrad, false)); // filter/weight - } - if (computeBiasGrad) { - inputs[2].addGrad(Variable(biasGrad, false)); - } - } - }; - if (hasBias) { - return Variable(output, {input, weights, bias}, gradFunc); - } - return 
Variable(output, {input, weights}, gradFunc); + std::shared_ptr benchmarks +) { + FL_VARIABLE_DTYPES_MATCH_CHECK(in, wt, bs); + + auto payload = detail::createAutogradPayload(in, wt, bs); + + bool hasBias = !bs.isEmpty(); + + auto input = FL_ADJUST_INPUT_TYPE(in); + auto weights = FL_ADJUST_INPUT_TYPE(wt); + auto bias = FL_ADJUST_INPUT_TYPE(bs); + + Tensor output = detail::conv2d( + input.tensor(), + weights.tensor(), + bias.tensor(), + sx, + sy, + px, + py, + dx, + dy, + groups, + payload + ); + + auto gradFunc = + [sx, sy, px, py, dx, dy, hasBias, groups, benchmarks, payload]( + std::vector& inputs, const Variable& gradOutput) { + // Create benchmarks if needed + auto& autogradExtension = + inputs[0].tensor().backend().getExtension(); + + std::shared_ptr dataBench; + std::shared_ptr filterBench; + std::shared_ptr biasBench; + if(benchmarks && DynamicBenchmark::getBenchmarkMode()) { + if(!benchmarks->bwdFilterBenchmark) { + benchmarks->bwdFilterBenchmark = + autogradExtension.createBenchmarkOptions(); + filterBench = benchmarks->bwdFilterBenchmark; + } + if(!benchmarks->bwdDataBenchmark) { + benchmarks->bwdDataBenchmark = + autogradExtension.createBenchmarkOptions(); + dataBench = benchmarks->bwdDataBenchmark; + } + if(!benchmarks->bwdBiasBenchmark) { + benchmarks->bwdBiasBenchmark = + autogradExtension.createBenchmarkOptions(); + biasBench = benchmarks->bwdBiasBenchmark; + } + } + + // Bias gradients + Tensor bs; + const bool computeBiasGrad = + inputs.size() > 2 && inputs[2].isCalcGrad(); + if(hasBias && computeBiasGrad) { + bs = inputs[2].tensor(); + // auto biasGrad = + // bs.backend().getExtension().conv2dBackwardBias( + // gradOutput.tensor(), bs, biasBench, payload); + + // inputs[2].addGrad(Variable(biasGrad, false)); // bias + } + + auto& in = inputs[0].tensor(); + auto& wt = inputs[1].tensor(); + + // Data (input) gradients + if(inputs[0].isCalcGrad()) { + auto dataGrad = + in.backend().getExtension().conv2dBackwardData( + gradOutput.tensor(), + in, + 
wt, + sx, + sy, + px, + py, + dx, + dy, + groups, + dataBench, + payload + ); + + inputs[0].addGrad(Variable(dataGrad, false)); // input/data + } + + // Filter (weight) and bias gradients + if(inputs[1].isCalcGrad() || computeBiasGrad) { + auto [filterGrad, biasGrad] = wt.backend() + .getExtension() + .conv2dBackwardFilterBias( + gradOutput.tensor(), + in, + wt, + bs, + sx, + sy, + px, + py, + dx, + dy, + groups, + filterBench, + biasBench, + payload + ); + if(inputs[1].isCalcGrad()) { + inputs[1].addGrad(Variable(filterGrad, false)); // filter/weight + } + if(computeBiasGrad) { + inputs[2].addGrad(Variable(biasGrad, false)); + } + } + }; + if(hasBias) { + return Variable(output, {input, weights, bias}, gradFunc); + } + return Variable(output, {input, weights}, gradFunc); } Variable pool2d( @@ -1464,35 +1579,40 @@ Variable pool2d( int sy, int px, int py, - PoolingMode mode /* = PoolingMode::MAX */) { - auto payload = detail::createAutogradPayload(input); - Tensor output = - fl::detail::pool2d(input.tensor(), wx, wy, sx, sy, px, py, mode, payload); - - auto gradFunc = [wx, wy, sx, sy, px, py, mode, output, payload]( - std::vector& inputs, - const Variable& gradOutput) { - auto& in = inputs[0]; - if (!in.isCalcGrad()) { - return; - } - - in.addGrad(Variable( - in.tensor().backend().getExtension().pool2dBackward( - gradOutput.tensor(), - in.tensor(), - output, - wx, - wy, - sx, - sy, - px, - py, - mode, - payload), - false)); - }; - return Variable(output, {input}, gradFunc); + PoolingMode mode /* = PoolingMode::MAX */ +) { + auto payload = detail::createAutogradPayload(input); + Tensor output = + fl::detail::pool2d(input.tensor(), wx, wy, sx, sy, px, py, mode, payload); + + auto gradFunc = [wx, wy, sx, sy, px, py, mode, output, payload]( + std::vector& inputs, + const Variable& gradOutput) { + auto& in = inputs[0]; + if(!in.isCalcGrad()) { + return; + } + + in.addGrad( + Variable( + in.tensor().backend().getExtension().pool2dBackward( + gradOutput.tensor(), + 
in.tensor(), + output, + wx, + wy, + sx, + sy, + px, + py, + mode, + payload + ), + false + ) + ); + }; + return Variable(output, {input}, gradFunc); } Variable batchnorm( @@ -1504,44 +1624,46 @@ Variable batchnorm( const std::vector& axes, bool train, double momentum, - double epsilon) { - auto payload = detail::createAutogradPayload(_input, weight, bias); - auto input = FL_ADJUST_INPUT_TYPE(_input); - - Tensor saveMean, saveVar; - Tensor output = fl::detail::batchnorm( - saveMean, - saveVar, - input.tensor(), - weight.tensor(), - bias.tensor(), - runningMean.tensor(), - runningVar.tensor(), - axes, - train, - momentum, - epsilon, - payload); - - auto gradFunc = - [saveMean = std::move(saveMean), - saveVar = std::move(saveVar), - train, - axes, - epsilon, - payload](std::vector& inputs, const Variable& _gradOutput) { - auto& in = inputs[0]; - auto& wt = inputs[1]; - auto& bs = inputs[2]; - - auto gradOutput = detail::adjustInputType(_gradOutput, "batchnorm"); - - if (!in.isCalcGrad() && !wt.isCalcGrad() && !bs.isCalcGrad()) { - return; - } - - auto [gradIn, gradWt, gradBs] = - in.tensor() + double epsilon +) { + auto payload = detail::createAutogradPayload(_input, weight, bias); + auto input = FL_ADJUST_INPUT_TYPE(_input); + + Tensor saveMean, saveVar; + Tensor output = fl::detail::batchnorm( + saveMean, + saveVar, + input.tensor(), + weight.tensor(), + bias.tensor(), + runningMean.tensor(), + runningVar.tensor(), + axes, + train, + momentum, + epsilon, + payload + ); + + auto gradFunc = + [saveMean = std::move(saveMean), + saveVar = std::move(saveVar), + train, + axes, + epsilon, + payload](std::vector& inputs, const Variable& _gradOutput) { + auto& in = inputs[0]; + auto& wt = inputs[1]; + auto& bs = inputs[2]; + + auto gradOutput = detail::adjustInputType(_gradOutput, "batchnorm"); + + if(!in.isCalcGrad() && !wt.isCalcGrad() && !bs.isCalcGrad()) { + return; + } + + auto [gradIn, gradWt, gradBs] = + in.tensor() .backend() .getExtension() .batchnormBackward( @@ 
-1553,54 +1675,56 @@ Variable batchnorm( axes, train, epsilon, - payload); - - in.addGrad(Variable(gradIn.astype(in.type()), false)); - wt.addGrad(Variable(gradWt.astype(wt.type()), false)); - if (!bs.isEmpty()) { - bs.addGrad(Variable(gradBs.astype(bs.type()), false)); - } - }; - return Variable(output, {input, weight, bias}, gradFunc); + payload + ); + + in.addGrad(Variable(gradIn.astype(in.type()), false)); + wt.addGrad(Variable(gradWt.astype(wt.type()), false)); + if(!bs.isEmpty()) { + bs.addGrad(Variable(gradBs.astype(bs.type()), false)); + } + }; + return Variable(output, {input, weight, bias}, gradFunc); } Variable gatedlinearunit(const Variable& input, const int dim) { - if (dim >= input.ndim()) { - throw std::invalid_argument( - "gatedlinearunit - passed dim is great than the " - "number of dimensions of the input."); - } - - auto inDims = input.shape(); - auto inType = input.type(); - auto inSize = inDims[dim]; - if (inSize % 2 == 1) { - throw std::invalid_argument("halving dimension must be even for GLU"); - } - - std::vector fhalf(input.ndim(), fl::span); - std::vector shalf(input.ndim(), fl::span); - fhalf[dim] = fl::range(inSize / 2); - shalf[dim] = fl::range(inSize / 2, inSize); - - Tensor fhalfout = input.tensor()(fhalf); - Tensor shalfout = input.tensor()(shalf); - - // Temporary workaround for indexing bug present in ArrayFire 3.6.1. 
- fhalfout = fl::reshape(fhalfout, fhalfout.shape()); - shalfout = fl::reshape(shalfout, shalfout.shape()); - shalfout = fl::sigmoid(shalfout); - - auto gradFunc = [fhalf, shalf, fhalfout, shalfout, inDims, inType]( - std::vector& inputs, - const Variable& gradOutput) { - auto gradGlu = Tensor(inDims, inType); - gradGlu(fhalf) = shalfout * gradOutput.tensor(); - gradGlu(shalf) = - shalfout * (1.0 - shalfout) * fhalfout * gradOutput.tensor(); - inputs[0].addGrad(Variable(gradGlu, false)); - }; - return Variable(fhalfout * shalfout, {input.withoutData()}, gradFunc); + if(dim >= input.ndim()) { + throw std::invalid_argument( + "gatedlinearunit - passed dim is great than the " + "number of dimensions of the input." + ); + } + + auto inDims = input.shape(); + auto inType = input.type(); + auto inSize = inDims[dim]; + if(inSize % 2 == 1) { + throw std::invalid_argument("halving dimension must be even for GLU"); + } + + std::vector fhalf(input.ndim(), fl::span); + std::vector shalf(input.ndim(), fl::span); + fhalf[dim] = fl::range(inSize / 2); + shalf[dim] = fl::range(inSize / 2, inSize); + + Tensor fhalfout = input.tensor()(fhalf); + Tensor shalfout = input.tensor()(shalf); + + // Temporary workaround for indexing bug present in ArrayFire 3.6.1. 
+ fhalfout = fl::reshape(fhalfout, fhalfout.shape()); + shalfout = fl::reshape(shalfout, shalfout.shape()); + shalfout = fl::sigmoid(shalfout); + + auto gradFunc = [fhalf, shalf, fhalfout, shalfout, inDims, inType]( + std::vector& inputs, + const Variable& gradOutput) { + auto gradGlu = Tensor(inDims, inType); + gradGlu(fhalf) = shalfout * gradOutput.tensor(); + gradGlu(shalf) = + shalfout * (1.0 - shalfout) * fhalfout * gradOutput.tensor(); + inputs[0].addGrad(Variable(gradGlu, false)); + }; + return Variable(fhalfout * shalfout, {input.withoutData()}, gradFunc); } std::tuple rnn( @@ -1612,217 +1736,239 @@ std::tuple rnn( int numLayers, RnnMode mode, bool bidirectional, - float dropProb) { - auto payload = - detail::createAutogradPayload(input, hiddenState, cellState, weights); - - Tensor output, hiddenOut, cellStateOut; - std::tie(output, hiddenOut, cellStateOut) = detail::rnn( - input.tensor(), - hiddenState.tensor(), - cellState.tensor(), - weights.tensor(), - hiddenSize, - numLayers, - mode, - bidirectional, - dropProb, - payload); - - auto gradData = std::make_shared(); - - auto gradFunc = [output, - numLayers, - hiddenSize, - mode, - bidirectional, - dropProb, - gradData, - payload]( - std::vector& inputs, - const Variable& /* gradOutput */) { - auto& input = inputs[0]; - auto& hiddenState = inputs[1]; - auto& cellState = inputs[2]; - auto& weights = inputs[3]; - - if (!(input.isCalcGrad() || hiddenState.isCalcGrad() || - cellState.isCalcGrad() || weights.isCalcGrad())) { - return; - } - - auto [dy, dhy, dcy, dweights] = - input.tensor().backend().getExtension().rnnBackward( - input.tensor(), - hiddenState.tensor(), - cellState.tensor(), - weights.tensor(), - gradData, - output, + float dropProb +) { + auto payload = + detail::createAutogradPayload(input, hiddenState, cellState, weights); + + Tensor output, hiddenOut, cellStateOut; + std::tie(output, hiddenOut, cellStateOut) = detail::rnn( + input.tensor(), + hiddenState.tensor(), + cellState.tensor(), + 
weights.tensor(), + hiddenSize, + numLayers, + mode, + bidirectional, + dropProb, + payload + ); + + auto gradData = std::make_shared(); + + auto gradFunc = [output, numLayers, hiddenSize, mode, bidirectional, dropProb, - payload); - - input.addGrad(Variable(dy.astype(input.type()), false)); - hiddenState.addGrad(Variable(dhy.astype(hiddenState.type()), false)); - cellState.addGrad(Variable(dcy.astype(cellState.type()), false)); - weights.addGrad(Variable(dweights.astype(weights.type()), false)); - }; - - Variable dummy(Tensor(), {input, hiddenState, cellState, weights}, gradFunc); - - auto dyGradFunc = - [gradData](std::vector& inputs, const Variable& gradOutput) { - if (!inputs[0].isGradAvailable()) { - inputs[0].addGrad(Variable(Tensor(), false)); - } - gradData->dy = gradOutput.tensor().asContiguousTensor(); - }; - - auto dhyGradFunc = - [gradData](std::vector& inputs, const Variable& gradOutput) { - if (!inputs[0].isGradAvailable()) { - inputs[0].addGrad(Variable(Tensor(), false)); - } - gradData->dhy = gradOutput.tensor().asContiguousTensor(); - }; - - auto dcyGradFunc = - [gradData](std::vector& inputs, const Variable& gradOutput) { - if (!inputs[0].isGradAvailable()) { - inputs[0].addGrad(Variable(Tensor(), false)); - } - gradData->dcy = gradOutput.tensor().asContiguousTensor(); - }; - - Variable yv(output, {dummy}, dyGradFunc); // output - Variable hyv(hiddenOut, {dummy}, dhyGradFunc); // hidden state output - Variable cyv(cellStateOut, {dummy}, dcyGradFunc); // cell state output - return std::make_tuple(yv, hyv, cyv); + gradData, + payload]( + std::vector& inputs, + const Variable& /* gradOutput */) { + auto& input = inputs[0]; + auto& hiddenState = inputs[1]; + auto& cellState = inputs[2]; + auto& weights = inputs[3]; + + if( + !(input.isCalcGrad() || hiddenState.isCalcGrad() + || cellState.isCalcGrad() || weights.isCalcGrad()) + ) { + return; + } + + auto [dy, dhy, dcy, dweights] = + input.tensor().backend().getExtension().rnnBackward( + input.tensor(), 
+ hiddenState.tensor(), + cellState.tensor(), + weights.tensor(), + gradData, + output, + numLayers, + hiddenSize, + mode, + bidirectional, + dropProb, + payload + ); + + input.addGrad(Variable(dy.astype(input.type()), false)); + hiddenState.addGrad(Variable(dhy.astype(hiddenState.type()), false)); + cellState.addGrad(Variable(dcy.astype(cellState.type()), false)); + weights.addGrad(Variable(dweights.astype(weights.type()), false)); + }; + + Variable dummy(Tensor(), {input, hiddenState, cellState, weights}, gradFunc); + + auto dyGradFunc = + [gradData](std::vector& inputs, const Variable& gradOutput) { + if(!inputs[0].isGradAvailable()) { + inputs[0].addGrad(Variable(Tensor(), false)); + } + gradData->dy = gradOutput.tensor().asContiguousTensor(); + }; + + auto dhyGradFunc = + [gradData](std::vector& inputs, const Variable& gradOutput) { + if(!inputs[0].isGradAvailable()) { + inputs[0].addGrad(Variable(Tensor(), false)); + } + gradData->dhy = gradOutput.tensor().asContiguousTensor(); + }; + + auto dcyGradFunc = + [gradData](std::vector& inputs, const Variable& gradOutput) { + if(!inputs[0].isGradAvailable()) { + inputs[0].addGrad(Variable(Tensor(), false)); + } + gradData->dcy = gradOutput.tensor().asContiguousTensor(); + }; + + Variable yv(output, {dummy}, dyGradFunc); // output + Variable hyv(hiddenOut, {dummy}, dhyGradFunc); // hidden state output + Variable cyv(cellStateOut, {dummy}, dcyGradFunc); // cell state output + return std::make_tuple(yv, hyv, cyv); } Variable embedding(const Variable& input, const Variable& embeddings) { - // TODO{fl::Tensor}{4-dims} - relax this - if (input.ndim() >= 4) { - throw std::invalid_argument("embedding input must have 3 or fewer dims"); - } - - auto idxs = input.tensor().flatten(); - auto inDims = input.shape(); - std::vector rDims(input.ndim() + 1); - rDims[0] = embeddings.dim(0); - for (unsigned i = 1; i < input.ndim() + 1; i++) { - rDims[i] = inDims[i - 1]; - } - Shape resultDims(rDims); - Tensor result = 
fl::reshape(embeddings.tensor()(fl::span, idxs), resultDims); - - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - auto& w = inputs[1]; - if (!w.isCalcGrad()) { - return; + // TODO{fl::Tensor}{4-dims} - relax this + if(input.ndim() >= 4) { + throw std::invalid_argument("embedding input must have 3 or fewer dims"); } - auto ip = inputs[0].tensor().flatten(); - unsigned size = ip.elements(); - auto deltas = fl::reshape(gradOutput.tensor(), {w.dim(0), size}); - - // Sparse Tensor - auto sp = Tensor( - ip.elements(), - w.dim(1), - fl::full({size}, 1, deltas.type()), - fl::arange({size + 1}, 0, fl::dtype::s32), - ip.astype(fl::dtype::s32), - fl::StorageType::CSR); - - auto grad = transpose(fl::matmul( - sp, transpose(deltas), /* lhsProp = */ MatrixProperty::Transpose)); - w.addGrad(Variable(grad, false)); - }; - - return Variable(result, {input, embeddings}, gradFunc); + auto idxs = input.tensor().flatten(); + auto inDims = input.shape(); + std::vector rDims(input.ndim() + 1); + rDims[0] = embeddings.dim(0); + for(unsigned i = 1; i < input.ndim() + 1; i++) { + rDims[i] = inDims[i - 1]; + } + Shape resultDims(rDims); + Tensor result = fl::reshape(embeddings.tensor()(fl::span, idxs), resultDims); + + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + auto& w = inputs[1]; + if(!w.isCalcGrad()) { + return; + } + + auto ip = inputs[0].tensor().flatten(); + unsigned size = ip.elements(); + auto deltas = fl::reshape(gradOutput.tensor(), {w.dim(0), size}); + + // Sparse Tensor + auto sp = Tensor( + ip.elements(), + w.dim(1), + fl::full({size}, 1, deltas.type()), + fl::arange({size + 1}, 0, fl::dtype::s32), + ip.astype(fl::dtype::s32), + fl::StorageType::CSR + ); + + auto grad = transpose( + fl::matmul( + sp, + transpose(deltas), /* lhsProp = */ + MatrixProperty::Transpose + ) + ); + w.addGrad(Variable(grad, false)); + }; + + return Variable(result, {input, embeddings}, gradFunc); } Variable padding( const Variable& input, 
std::vector> pad, - double val) { - if (pad.size() > input.ndim()) { - throw std::invalid_argument( - "padding: number of padding dimensions exceeds number " - "of input dimensions"); - } - - Shape opDims = input.shape(); - std::vector inSeq(input.ndim(), fl::span); - for (int i = 0; i < pad.size(); ++i) { - opDims[i] += (pad[i].first + pad[i].second); - inSeq[i] = fl::range(pad[i].first, opDims[i] - pad[i].second); - } - Tensor result = fl::full(opDims, val, input.type()); - result(inSeq) = input.tensor(); - - auto gradFunc = - [inSeq](std::vector& inputs, const Variable& gradOutput) { - inputs[0].addGrad(Variable(gradOutput.tensor()(inSeq), false)); - }; - return Variable(result, {input.withoutData()}, gradFunc); + double val +) { + if(pad.size() > input.ndim()) { + throw std::invalid_argument( + "padding: number of padding dimensions exceeds number " + "of input dimensions" + ); + } + + Shape opDims = input.shape(); + std::vector inSeq(input.ndim(), fl::span); + for(int i = 0; i < pad.size(); ++i) { + opDims[i] += (pad[i].first + pad[i].second); + inSeq[i] = fl::range(pad[i].first, opDims[i] - pad[i].second); + } + Tensor result = fl::full(opDims, val, input.type()); + result(inSeq) = input.tensor(); + + auto gradFunc = + [inSeq](std::vector& inputs, const Variable& gradOutput) { + inputs[0].addGrad(Variable(gradOutput.tensor()(inSeq), false)); + }; + return Variable(result, {input.withoutData()}, gradFunc); } Variable dropout(const Variable& input, double p) { - if (p > 0.0) { - auto mask = Variable( - (fl::rand(input.shape(), input.type()) > p).astype(input.type()), false); - return 1.0 / (1.0 - p) * mask * input; - } else { - return input; - } + if(p > 0.0) { + auto mask = Variable( + (fl::rand(input.shape(), input.type()) > p).astype(input.type()), + false + ); + return 1.0 / (1.0 - p) * mask * input; + } else { + return input; + } } Variable relu(const Variable& input) { - return max(input, 0.0); + return max(input, 0.0); } Variable gelu(const Variable& in) 
{ - auto input = FL_ADJUST_INPUT_TYPE(in); - return 0.5 * input * - (1.0 + - fl::tanh(0.7978845608 * (input + 0.044715 * input * input * input))); + auto input = FL_ADJUST_INPUT_TYPE(in); + return 0.5 * input + * (1.0 + + fl::tanh(0.7978845608 * (input + 0.044715 * input * input * input))); } fl::Variable relativePositionEmbeddingRotate(const fl::Variable& input) { - if (input.ndim() != 3) { - throw std::invalid_argument( - "relativePositionEmbeddingRotate - " - "input tensor must have 3 dimensions"); - } - - auto data = input.tensor(); - int d0 = data.dim(0); - int d1 = data.dim(1); - int d2 = data.dim(2); - data = fl::concatenate( - /* axis = */ 0, data, fl::full({d1, d1, d2}, 0.0, data.type())); - data = fl::reshape(data, {(d0 + d1) * d1, 1, d2}); - data = data(fl::range(0, (d1 + d0 - 1) * d1)); - data = fl::reshape(data, {d0 + d1 - 1, d1, d2}); - auto gradFunc = [d0, d1, d2]( - std::vector& inputs, - const fl::Variable& gradOutput) { - auto gradData = gradOutput.tensor(); - gradData = fl::reshape(gradData, {(d0 + d1 - 1) * d1, 1, d2}); - gradData = fl::concatenate( - 0, gradData, fl::full({d1, 1, d2}, 0.0, gradData.type())); - gradData = reshape(gradData, {d0 + d1, d1, d2}); - gradData = Variable(gradData, false)(fl::range(0, d0)).tensor(); - inputs[0].addGrad(fl::Variable(gradData, false)); - }; - return fl::Variable(data, {input}, gradFunc); + if(input.ndim() != 3) { + throw std::invalid_argument( + "relativePositionEmbeddingRotate - " + "input tensor must have 3 dimensions" + ); + } + + auto data = input.tensor(); + int d0 = data.dim(0); + int d1 = data.dim(1); + int d2 = data.dim(2); + data = fl::concatenate( + /* axis = */ 0, + data, + fl::full({d1, d1, d2}, 0.0, data.type()) + ); + data = fl::reshape(data, {(d0 + d1) * d1, 1, d2}); + data = data(fl::range(0, (d1 + d0 - 1) * d1)); + data = fl::reshape(data, {d0 + d1 - 1, d1, d2}); + auto gradFunc = [d0, d1, d2]( + std::vector& inputs, + const fl::Variable& gradOutput) { + auto gradData = 
gradOutput.tensor(); + gradData = fl::reshape(gradData, {(d0 + d1 - 1) * d1, 1, d2}); + gradData = fl::concatenate( + 0, + gradData, + fl::full({d1, 1, d2}, 0.0, gradData.type()) + ); + gradData = reshape(gradData, {d0 + d1, d1, d2}); + gradData = Variable(gradData, false)(fl::range(0, d0)).tensor(); + inputs[0].addGrad(fl::Variable(gradData, false)); + }; + return fl::Variable(data, {input}, gradFunc); } fl::Variable multiheadAttention( @@ -1834,60 +1980,67 @@ fl::Variable multiheadAttention( const fl::Variable& padMask, const int32_t nHeads, const double pDropout, - const int32_t offset /* = 0 */) { - if (query.ndim() != 3) { - throw std::invalid_argument( - "multiheadAttention - query input tensor should be 3 dimensions: " - "Time x (nHeads * headDim) x B"); - } - if (key.ndim() != 3) { - throw std::invalid_argument( - "multiheadAttention - key input tensor should be 3 dimensions: " - "Time x (nHeads * headDim) x B"); - } - if (value.ndim() != 3) { - throw std::invalid_argument( - "multiheadAttention - value input tensor should be 3 dimensions: " - "Time x (nHeads * headDim) x B"); - } - - int32_t bsz = query.dim(2); - int32_t modelDim = query.dim(1); - int32_t headDim = modelDim / nHeads; - - auto q = moddims(query, {-1, headDim, nHeads * bsz}); - auto k = moddims(key, {-1, headDim, nHeads * bsz}); - auto v = moddims(value, {-1, headDim, nHeads * bsz}); - - q = q / std::sqrt(float(headDim)); - auto scores = matmulNT(q, k); - if (!posEmb.isEmpty()) { - int n = posEmb.dim(0) / 2 - offset; - auto pscores = - relativePositionEmbeddingRotate(matmulNT(posEmb.astype(q.type()), q)); - scores = - scores + transpose(pscores(fl::range(n, n + k.dim(0))), {1, 0, 2}); - } - if (!mask.isEmpty()) { - scores = scores + tileAs(mask.astype(scores.type()), scores); - } - if (!padMask.isEmpty()) { - if (padMask.dim(0) != query.dim(0)) { - throw std::invalid_argument( - "multiheadAttention: invalid padding mask size"); + const int32_t offset /* = 0 */ +) { + if(query.ndim() != 3) { 
+ throw std::invalid_argument( + "multiheadAttention - query input tensor should be 3 dimensions: " + "Time x (nHeads * headDim) x B" + ); } - auto padMaskTile = moddims(padMask, {1, padMask.dim(0), 1, bsz}); - padMaskTile = - tileAs(padMaskTile, {padMask.dim(0), padMask.dim(0), nHeads, bsz}); - scores = scores + - moddims(padMaskTile.astype(scores.type()), - {padMask.dim(0), padMask.dim(0), nHeads * bsz}); - } - - auto attn = dropout(softmax(scores, 1), pDropout); - auto result = matmul(attn.astype(v.type()), v); - result = moddims(result, {-1, headDim * nHeads, bsz}); - return result; + if(key.ndim() != 3) { + throw std::invalid_argument( + "multiheadAttention - key input tensor should be 3 dimensions: " + "Time x (nHeads * headDim) x B" + ); + } + if(value.ndim() != 3) { + throw std::invalid_argument( + "multiheadAttention - value input tensor should be 3 dimensions: " + "Time x (nHeads * headDim) x B" + ); + } + + int32_t bsz = query.dim(2); + int32_t modelDim = query.dim(1); + int32_t headDim = modelDim / nHeads; + + auto q = moddims(query, {-1, headDim, nHeads * bsz}); + auto k = moddims(key, {-1, headDim, nHeads * bsz}); + auto v = moddims(value, {-1, headDim, nHeads * bsz}); + + q = q / std::sqrt(float(headDim)); + auto scores = matmulNT(q, k); + if(!posEmb.isEmpty()) { + int n = posEmb.dim(0) / 2 - offset; + auto pscores = + relativePositionEmbeddingRotate(matmulNT(posEmb.astype(q.type()), q)); + scores = + scores + transpose(pscores(fl::range(n, n + k.dim(0))), {1, 0, 2}); + } + if(!mask.isEmpty()) { + scores = scores + tileAs(mask.astype(scores.type()), scores); + } + if(!padMask.isEmpty()) { + if(padMask.dim(0) != query.dim(0)) { + throw std::invalid_argument( + "multiheadAttention: invalid padding mask size" + ); + } + auto padMaskTile = moddims(padMask, {1, padMask.dim(0), 1, bsz}); + padMaskTile = + tileAs(padMaskTile, {padMask.dim(0), padMask.dim(0), nHeads, bsz}); + scores = scores + + moddims( + padMaskTile.astype(scores.type()), + 
{padMask.dim(0), padMask.dim(0), nHeads * bsz} + ); + } + + auto attn = dropout(softmax(scores, 1), pDropout); + auto result = matmul(attn.astype(v.type()), v); + result = moddims(result, {-1, headDim * nHeads, bsz}); + return result; } } // namespace fl diff --git a/flashlight/fl/autograd/Functions.h b/flashlight/fl/autograd/Functions.h index 9dbaf79..2e64e80 100644 --- a/flashlight/fl/autograd/Functions.h +++ b/flashlight/fl/autograd/Functions.h @@ -22,71 +22,75 @@ class Variable; namespace detail { -struct ConvBenchmarks; + struct ConvBenchmarks; -struct AutogradPayload; + struct AutogradPayload; -FL_API Tensor tileAs(const Tensor& input, const Shape& rdims); + FL_API Tensor tileAs(const Tensor& input, const Shape& rdims); -FL_API Tensor sumAs(const Tensor& input, const Shape& rdims); + FL_API Tensor sumAs(const Tensor& input, const Shape& rdims); /* Reshape a tensor to the dims it had before a reduction over the given axes. * - This is a no-op if keepDims is true. */ -FL_API Shape expandedShapeFromReducedDims( - const Tensor& input, - const std::vector& axes, - bool keepDims = false); - -FL_API bool areVariableTypesEqual(const Variable& a, const Variable& b); - -template -bool areVariableTypesEqual( - const Variable& a, - const Variable& b, - const Args&... args) { - return areVariableTypesEqual(a, b) && areVariableTypesEqual(a, args...) && - areVariableTypesEqual(b, args...); -} + FL_API Shape expandedShapeFromReducedDims( + const Tensor& input, + const std::vector& axes, + bool keepDims = false + ); + + FL_API bool areVariableTypesEqual(const Variable& a, const Variable& b); + + template + bool areVariableTypesEqual( + const Variable& a, + const Variable& b, + const Args&... args + ) { + return areVariableTypesEqual(a, b) && areVariableTypesEqual(a, args...) + && areVariableTypesEqual(b, args...); + } /** * Performs type conversion based on the optim level. Operations that lack * sufficient precision are automatically upcast to f32 before computation. 
* These are typically operations that require accumulations or reductions. */ -template -T adjustInputType(const T& in, const char* funcname) { - OptimLevel optimLevel = OptimMode::get().getOptimLevel(); - // Fastpath - DEFAULT mode never casts tensors - if (optimLevel == OptimLevel::DEFAULT) { - return in; - } - - T res; - auto& funcs = kOptimLevelTypeExclusionMappings.find(optimLevel)->second; - // TODO: tiny, but this lookup incurs an extra alloc from char* to string - if (funcs.find(std::string(funcname)) == funcs.end() && - optimLevel != OptimLevel::DEFAULT) { - // Not in the excluded list - cast to f16 - res = in.astype(fl::dtype::f16); - } else { - // Upcast to f32 only if we have an f16 input - otherwise, leave as is - if (in.type() == fl::dtype::f16) { - res = in.astype(fl::dtype::f32); - } else { - res = in; + template + T adjustInputType(const T& in, const char* funcname) { + OptimLevel optimLevel = OptimMode::get().getOptimLevel(); + // Fastpath - DEFAULT mode never casts tensors + if(optimLevel == OptimLevel::DEFAULT) { + return in; + } + + T res; + auto& funcs = kOptimLevelTypeExclusionMappings.find(optimLevel)->second; + // TODO: tiny, but this lookup incurs an extra alloc from char* to string + if( + funcs.find(std::string(funcname)) == funcs.end() + && optimLevel != OptimLevel::DEFAULT + ) { + // Not in the excluded list - cast to f16 + res = in.astype(fl::dtype::f16); + } else { + // Upcast to f32 only if we have an f16 input - otherwise, leave as is + if(in.type() == fl::dtype::f16) { + res = in.astype(fl::dtype::f32); + } else { + res = in; + } + } + + return res; } - } - return res; -} - -template -std::shared_ptr createAutogradPayload(H head, T... tail) { - return (head.isCalcGrad() || ... || tail.isCalcGrad()) - ? std::make_shared() - : nullptr; -} + template + std::shared_ptr createAutogradPayload(H head, T... tail) { + return (head.isCalcGrad() || ... || tail.isCalcGrad()) + ? 
std::make_shared() + : nullptr; + } } // namespace detail @@ -99,13 +103,14 @@ std::shared_ptr createAutogradPayload(H head, T... tail) { /** * Checks if a variadic number of Variables have the same types. */ -#define FL_VARIABLE_DTYPES_MATCH_CHECK(...) \ - if (!detail::areVariableTypesEqual(__VA_ARGS__)) { \ - throw std::invalid_argument( \ - std::string(__func__) + \ - " doesn't support binary " \ - "operations with Variables of different types"); \ - } +#define FL_VARIABLE_DTYPES_MATCH_CHECK(...) \ + if(!detail::areVariableTypesEqual(__VA_ARGS__)) { \ + throw std::invalid_argument( \ + std::string(__func__) + \ + " doesn't support binary " \ + "operations with Variables of different types" \ + ); \ + } /** * \defgroup autograd_functions Autograd Functions @@ -337,8 +342,7 @@ FL_API Variable tanh(const Variable& input); * \text{max} & \text{if } x_i > \text{max} * \end{cases}\end{split} \f] */ -FL_API Variable -clamp(const Variable& input, const double min, const double max); +FL_API Variable clamp(const Variable& input, const double min, const double max); /** * Computes sigmoid of each element in a Variable. @@ -447,8 +451,7 @@ FL_API Variable concatenate(const std::vector& concatInputs, int dim); * divisible, last chunk of smaller splitSize will be included. * @param dim dimension along which to split the Variable */ -FL_API std::vector -split(const Variable& input, long splitSize, int dim); +FL_API std::vector split(const Variable& input, long splitSize, int dim); /** * Splits a Variable into smaller chunks. @@ -457,8 +460,7 @@ split(const Variable& input, long splitSize, int dim); * @param splitSizes vector of integers specifying the sizes for each split * @param dim dimension along which to split the Variable */ -FL_API std::vector -split(const Variable& input, const std::vector& splitSizes, int dim); +FL_API std::vector split(const Variable& input, const std::vector& splitSizes, int dim); /** * Repeats the tensor `input` along specific dimensions. 
The number of @@ -475,15 +477,13 @@ FL_API Variable tile(const Variable& input, const Shape& dims); * applied on parameters and the results will be used in a half precision * arithmetic. */ -FL_API Variable -tile(const Variable& input, const Shape& dims, const fl::dtype precision); +FL_API Variable tile(const Variable& input, const Shape& dims, const fl::dtype precision); /** * Sums up the tensors `input` along dimensions specified in descriptor `axes`. * If `axes` has size greater than 1, reduce over all of them. */ -FL_API Variable -sum(const Variable& input, const std::vector& axes, bool keepDims = false); +FL_API Variable sum(const Variable& input, const std::vector& axes, bool keepDims = false); /** * Computes the mean of the tensor `input` along dimensions specified in @@ -493,7 +493,8 @@ sum(const Variable& input, const std::vector& axes, bool keepDims = false); FL_API Variable mean( const Variable& input, const std::vector& axes, - bool keepDims = false); + bool keepDims = false +); /** * Lp-norm computation, reduced over specified dimensions. @@ -506,7 +507,8 @@ FL_API Variable norm( const Variable& input, const std::vector& axes, double p = 2, - bool keepDims = false); + bool keepDims = false +); /** * Lp norm normalization of values across the given dimensions. @@ -520,7 +522,8 @@ FL_API Variable normalize( const Variable& input, const std::vector& axes, double p = 2, - double eps = 1e-12); + double eps = 1e-12 +); /** * Computes variance of the tensor `input` along dimensions specified in @@ -534,11 +537,12 @@ FL_API Variable normalize( * ArrayFire before v3.7.0, the reverse is true. * TODO:{fl::Tensor} -- make this behavior consistent */ -FL_API Variable -var(const Variable& input, +FL_API Variable var( + const Variable& input, const std::vector& axes, const bool isbiased = false, - bool keepDims = false); + bool keepDims = false +); /** * Conducts matrix-matrix multiplication on two Variables. 
This is a batched @@ -616,8 +620,7 @@ FL_API Variable linear(const Variable& input, const Variable& weight); * @param bias a Variable with shape [\f$K\f$] * @return a Variable with shape [\f$K\f$, \f$M\f$, \f$B_1\f$, \f$B_2\f$] */ -FL_API Variable -linear(const Variable& input, const Variable& weight, const Variable& bias); +FL_API Variable linear(const Variable& input, const Variable& weight, const Variable& bias); /** * Applies a 2D convolution over an input signal given filter weights. In the @@ -661,7 +664,8 @@ FL_API Variable conv2d( int dx = 1, int dy = 1, int groups = 1, - std::shared_ptr benchmarks = nullptr); + std::shared_ptr benchmarks = nullptr +); /** * Applies a 2D convolution over an input signal given filter weights and @@ -708,7 +712,8 @@ FL_API Variable conv2d( int dx = 1, int dy = 1, int groups = 1, - std::shared_ptr benchmarks = nullptr); + std::shared_ptr benchmarks = nullptr +); /** * Applies a 2D pooling over an input signal composed of several input planes. @@ -735,7 +740,8 @@ FL_API Variable pool2d( int sy = 1, int px = 0, int py = 0, - PoolingMode mode = PoolingMode::MAX); + PoolingMode mode = PoolingMode::MAX +); /** * Applies a softmax function on Variable `input` along dimension `dim`, so that @@ -765,8 +771,7 @@ FL_API Variable logSoftmax(const Variable& input, const int dim); * @param inputs a tensor with the predicted values * @param targets a tensor with the target values */ -FL_API Variable -binaryCrossEntropy(const Variable& inputs, const Variable& targets); +FL_API Variable binaryCrossEntropy(const Variable& inputs, const Variable& targets); /** * Computes the categorical cross entropy loss. The input is expected to @@ -802,7 +807,8 @@ FL_API Variable categoricalCrossEntropy( const Variable& input, const Variable& targets, ReduceMode reduction = ReduceMode::MEAN, - int ignoreIndex = -1); + int ignoreIndex = -1 +); /** * Computes the weighted cross entropy loss. 
The input is expected to @@ -827,7 +833,8 @@ FL_API Variable weightedCategoricalCrossEntropy( const Variable& input, const Variable& targets, const Variable& weight, - int ignoreIndex); + int ignoreIndex +); /** * The gated linear unit. @@ -871,7 +878,7 @@ FL_API Variable gatedlinearunit(const Variable& input, const int dim); * - LSTM * - GRU * @param bidirectional if `True`, becomes a bidirectional RNN, unidirectional - otherwise + otherwise * @param dropout if non-zero, introduces a `Dropout` layer on the outputs of * each RNN layer except the last one, with dropout probability equal to dropout @@ -890,7 +897,8 @@ FL_API std::tuple rnn( int numLayers, RnnMode mode, bool bidirectional, - float dropout); + float dropout +); /** * Looks up embeddings in a fixed dictionary and size. @@ -941,7 +949,8 @@ FL_API Variable batchnorm( const std::vector& axes, bool train, double momentum, - double epsilon); + double epsilon +); /** * Applies asymmetric padding on a Variable `input`. @@ -954,7 +963,8 @@ FL_API Variable batchnorm( FL_API Variable padding( const Variable& input, std::vector> pad, - double val); + double val +); /** * Applies dropout on a Variable `input`. 
@@ -1011,7 +1021,8 @@ FL_API Variable multiheadAttention( const Variable& padMask, const int32_t nHeads, const double pDropout, - const int32_t offset = 0); + const int32_t offset = 0 +); /** @} */ diff --git a/flashlight/fl/autograd/Utils.cpp b/flashlight/fl/autograd/Utils.cpp index 691b7e9..1f958c8 100644 --- a/flashlight/fl/autograd/Utils.cpp +++ b/flashlight/fl/autograd/Utils.cpp @@ -14,8 +14,9 @@ namespace fl { bool allClose( const Variable& a, const Variable& b, - double absTolerance /* = 1e-5 */) { - return allClose(a.tensor(), b.tensor(), absTolerance); + double absTolerance /* = 1e-5 */ +) { + return allClose(a.tensor(), b.tensor(), absTolerance); } } // namespace fl diff --git a/flashlight/fl/autograd/Utils.h b/flashlight/fl/autograd/Utils.h index a3e03e9..ae2d3ed 100644 --- a/flashlight/fl/autograd/Utils.h +++ b/flashlight/fl/autograd/Utils.h @@ -25,8 +25,7 @@ namespace fl { * @param absTolerance absolute tolerance allowed * */ -FL_API bool -allClose(const Variable& a, const Variable& b, double absTolerance = 1e-5); +FL_API bool allClose(const Variable& a, const Variable& b, double absTolerance = 1e-5); /** @} */ diff --git a/flashlight/fl/autograd/Variable.cpp b/flashlight/fl/autograd/Variable.cpp index 57669b8..256a11d 100644 --- a/flashlight/fl/autograd/Variable.cpp +++ b/flashlight/fl/autograd/Variable.cpp @@ -24,273 +24,282 @@ namespace fl { Variable::Variable(Tensor data, bool calcGrad) { - sharedData_->data = std::move(data); - sharedGrad_->calcGrad = calcGrad; + sharedData_->data = std::move(data); + sharedGrad_->calcGrad = calcGrad; } Variable::Variable( Tensor data, std::vector inputs, - GradFunc gradFunc) { - sharedData_->data = std::move(data); - if (std::any_of(inputs.begin(), inputs.end(), [](const Variable& input) { - return input.isCalcGrad(); - })) { - sharedGrad_->calcGrad = true; - sharedGrad_->inputs = std::move(inputs); - sharedGrad_->gradFunc = std::move(gradFunc); - } + GradFunc gradFunc +) { + sharedData_->data = std::move(data); 
+ if( + std::any_of( + inputs.begin(), + inputs.end(), + [](const Variable& input) { + return input.isCalcGrad(); + } + ) + ) { + sharedGrad_->calcGrad = true; + sharedGrad_->inputs = std::move(inputs); + sharedGrad_->gradFunc = std::move(gradFunc); + } } Variable Variable::operator()(const std::vector& indices) const { - auto result = tensor()(indices); - auto inDims = shape(); - auto inType = type(); - - auto gradFunc = [indices, inDims, inType]( - std::vector& inputs, - const Variable& gradOutput) { - if (!inputs[0].isGradAvailable()) { - auto grad = fl::full(inDims, 0.0, inType); - inputs[0].addGrad(Variable(grad, false)); - } + auto result = tensor()(indices); + auto inDims = shape(); + auto inType = type(); + + auto gradFunc = [indices, inDims, inType]( + std::vector& inputs, + const Variable& gradOutput) { + if(!inputs[0].isGradAvailable()) { + auto grad = fl::full(inDims, 0.0, inType); + inputs[0].addGrad(Variable(grad, false)); + } - auto& grad = inputs[0].grad().tensor(); - grad(indices) += gradOutput.tensor(); - }; - return Variable(result, {this->withoutData()}, gradFunc); + auto& grad = inputs[0].grad().tensor(); + grad(indices) += gradOutput.tensor(); + }; + return Variable(result, {this->withoutData()}, gradFunc); } Variable Variable::flat(const fl::Index& index) const { - auto result = tensor().flat(index); - auto inDims = shape(); - auto inType = type(); - - auto gradFunc = [index, inDims, inType]( - std::vector& inputs, - const Variable& gradOutput) { - if (!inputs[0].isGradAvailable()) { - auto grad = fl::full(inDims, 0.0, inType); - inputs[0].addGrad(Variable(grad, false)); - } - auto& grad = inputs[0].grad().tensor(); - grad.flat(index) += gradOutput.tensor(); - }; + auto result = tensor().flat(index); + auto inDims = shape(); + auto inType = type(); + + auto gradFunc = [index, inDims, inType]( + std::vector& inputs, + const Variable& gradOutput) { + if(!inputs[0].isGradAvailable()) { + auto grad = fl::full(inDims, 0.0, inType); + 
inputs[0].addGrad(Variable(grad, false)); + } + auto& grad = inputs[0].grad().tensor(); + grad.flat(index) += gradOutput.tensor(); + }; - return Variable(result, {this->withoutData()}, gradFunc); + return Variable(result, {this->withoutData()}, gradFunc); } Tensor& Variable::tensor() const { - return sharedData_->data; + return sharedData_->data; } Variable Variable::copy() const { - return Variable(sharedData_->data, sharedGrad_->calcGrad); + return Variable(sharedData_->data, sharedGrad_->calcGrad); } Variable Variable::astype(fl::dtype newType) const { - auto output = tensor().astype(newType); - auto gradFunc = [](std::vector& inputs, - const Variable& gradOutput) { - auto& input = inputs[0]; - // Cast the grad output to match the type of the input's grad - input.addGrad(Variable(gradOutput.tensor().astype(input.type()), false)); - }; - return Variable(output, {this->withoutData()}, gradFunc); + auto output = tensor().astype(newType); + auto gradFunc = [](std::vector& inputs, + const Variable& gradOutput) { + auto& input = inputs[0]; + // Cast the grad output to match the type of the input's grad + input.addGrad(Variable(gradOutput.tensor().astype(input.type()), false)); + }; + return Variable(output, {this->withoutData()}, gradFunc); } Variable& Variable::grad() const { - if (!sharedGrad_->calcGrad) { - throw std::logic_error("gradient calculation disabled for this Variable"); - } + if(!sharedGrad_->calcGrad) { + throw std::logic_error("gradient calculation disabled for this Variable"); + } - if (!sharedGrad_->grad) { - throw std::logic_error("gradient not calculated yet for this Variable"); - } + if(!sharedGrad_->grad) { + throw std::logic_error("gradient not calculated yet for this Variable"); + } - return *sharedGrad_->grad; + return *sharedGrad_->grad; } std::vector& Variable::getInputs() const { - return sharedGrad_->inputs; + return sharedGrad_->inputs; } bool Variable::isCalcGrad() const { - return sharedGrad_->calcGrad; + return sharedGrad_->calcGrad; } 
bool Variable::isGradAvailable() const { - if (!sharedGrad_->calcGrad) { - return false; - } - return sharedGrad_->grad != nullptr; + if(!sharedGrad_->calcGrad) { + return false; + } + return sharedGrad_->grad != nullptr; } Shape Variable::shape() const { - return tensor().shape(); + return tensor().shape(); } bool Variable::isEmpty() const { - return tensor().isEmpty(); + return tensor().isEmpty(); } bool Variable::isContiguous() const { - return tensor().isContiguous(); + return tensor().isContiguous(); } Variable Variable::asContiguous() const { - if (!isEmpty() && !isContiguous()) { - tensor() = tensor().asContiguousTensor(); - } - return *this; + if(!isEmpty() && !isContiguous()) { + tensor() = tensor().asContiguousTensor(); + } + return *this; } fl::dtype Variable::type() const { - return tensor().type(); + return tensor().type(); } Dim Variable::elements() const { - return tensor().elements(); + return tensor().elements(); } size_t Variable::bytes() const { - return tensor().bytes(); + return tensor().bytes(); } unsigned Variable::ndim() const { - return tensor().ndim(); + return tensor().ndim(); } Dim Variable::dim(unsigned dim) const { - return tensor().dim(dim); + return tensor().dim(dim); } void Variable::eval() const { - fl::eval(tensor()); + fl::eval(tensor()); } void Variable::zeroGrad() { - sharedGrad_->grad.reset(); + sharedGrad_->grad.reset(); } void Variable::setCalcGrad(bool calcGrad) { - sharedGrad_->calcGrad = calcGrad; - if (!calcGrad) { - sharedGrad_->gradFunc = nullptr; - sharedGrad_->inputs.clear(); - sharedGrad_->grad.reset(); - } + sharedGrad_->calcGrad = calcGrad; + if(!calcGrad) { + sharedGrad_->gradFunc = nullptr; + sharedGrad_->inputs.clear(); + sharedGrad_->grad.reset(); + } } void Variable::addGrad(const Variable& childGrad) { - if (sharedGrad_->calcGrad) { - // Ensure the type of the child grad is the same as the type of this - // Variable (and transitively, that it's the same type as an existing grad) - if (childGrad.type() != 
this->type()) { - std::stringstream ss; - ss << "Variable::addGrad: attempted to add child gradient of type " - << childGrad.type() << " to a Variable of type " << this->type() - << ". You might be performing an operation with " + if(sharedGrad_->calcGrad) { + // Ensure the type of the child grad is the same as the type of this + // Variable (and transitively, that it's the same type as an existing grad) + if(childGrad.type() != this->type()) { + std::stringstream ss; + ss << "Variable::addGrad: attempted to add child gradient of type " + << childGrad.type() << " to a Variable of type " << this->type() + << ". You might be performing an operation with " "two inputs of different types."; - throw std::invalid_argument(ss.str()); - } - if (childGrad.shape() != this->shape()) { - std::stringstream ss; - ss << "Variable::addGrad: given gradient has dimensions not equal " + throw std::invalid_argument(ss.str()); + } + if(childGrad.shape() != this->shape()) { + std::stringstream ss; + ss << "Variable::addGrad: given gradient has dimensions not equal " "to this Variable's dimensions: this variable has shape " - << this->shape() << " whereas the child gradient has dimensions " - << childGrad.shape() << std::endl; - throw std::invalid_argument(ss.str()); - } - if (sharedGrad_->grad) { - // Prevent increment of array refcount to avoid a copy - // if getting a device pointer. 
See - // https://git.io/fp9oM for more - sharedGrad_->grad = std::make_unique( - sharedGrad_->grad->tensor() + childGrad.tensor(), false); - } else { - // Copy the childGrad Variable so as to share a reference - // to the underlying childGrad.tensor() rather than copying - // the tensor into a new variable - sharedGrad_->grad = std::make_unique(childGrad); + << this->shape() << " whereas the child gradient has dimensions " + << childGrad.shape() << std::endl; + throw std::invalid_argument(ss.str()); + } + if(sharedGrad_->grad) { + // Prevent increment of array refcount to avoid a copy + // if getting a device pointer. See + // https://git.io/fp9oM for more + sharedGrad_->grad = std::make_unique( + sharedGrad_->grad->tensor() + childGrad.tensor(), + false + ); + } else { + // Copy the childGrad Variable so as to share a reference + // to the underlying childGrad.tensor() rather than copying + // the tensor into a new variable + sharedGrad_->grad = std::make_unique(childGrad); + } } - } } void Variable::registerGradHook(const GradHook& hook) { - sharedGrad_->onGradAvailable = hook; + sharedGrad_->onGradAvailable = hook; } void Variable::clearGradHook() { - sharedGrad_->onGradAvailable = nullptr; + sharedGrad_->onGradAvailable = nullptr; } void Variable::applyGradHook() { - if (sharedGrad_->onGradAvailable) { - assert(sharedGrad_->grad); - sharedGrad_->onGradAvailable(*sharedGrad_->grad); - } + if(sharedGrad_->onGradAvailable) { + assert(sharedGrad_->grad); + sharedGrad_->onGradAvailable(*sharedGrad_->grad); + } } void Variable::calcGradInputs(bool retainGraph) { - if (sharedGrad_->gradFunc) { - if (!sharedGrad_->grad) { - throw std::logic_error("gradient was not propagated to this Variable"); - } + if(sharedGrad_->gradFunc) { + if(!sharedGrad_->grad) { + throw std::logic_error("gradient was not propagated to this Variable"); + } - sharedGrad_->gradFunc(sharedGrad_->inputs, *sharedGrad_->grad); - } - if (!retainGraph) { - sharedGrad_->inputs.clear(); - } + 
sharedGrad_->gradFunc(sharedGrad_->inputs, *sharedGrad_->grad); + } + if(!retainGraph) { + sharedGrad_->inputs.clear(); + } } void Variable::backward(const Variable& grad, bool retainGraph) { - addGrad(grad); - auto dag = build(); - for (auto iter = dag.rbegin(); iter != dag.rend(); iter++) { - iter->calcGradInputs(retainGraph); - iter->applyGradHook(); - if (!retainGraph) { - *iter = Variable(); + addGrad(grad); + auto dag = build(); + for(auto iter = dag.rbegin(); iter != dag.rend(); iter++) { + iter->calcGradInputs(retainGraph); + iter->applyGradHook(); + if(!retainGraph) { + *iter = Variable(); + } } - } } void Variable::backward(bool retainGraph) { - auto ones = Variable(fl::full(shape(), 1, this->type()), false); - backward(ones, retainGraph); + auto ones = Variable(fl::full(shape(), 1, this->type()), false); + backward(ones, retainGraph); } Variable Variable::withoutData() const { - Variable other; - other.sharedGrad_ = sharedGrad_; - // Ensure the type of the underlying [but empty] Tensor data is of the same - // type and shape - other.tensor() = Tensor(shape(), this->type()); - return other; + Variable other; + other.sharedGrad_ = sharedGrad_; + // Ensure the type of the underlying [but empty] Tensor data is of the same + // type and shape + other.tensor() = Tensor(shape(), this->type()); + return other; } Variable::DAG Variable::build() const { - std::unordered_set cache; - DAG dag; - std::function recurse; - - // Topological sort - recurse = [&](const Variable& var) { - auto id = var.sharedGrad_.get(); - if (cache.find(id) != cache.end()) { - return; - } - for (const auto& input : var.getInputs()) { - recurse(input); - } - cache.insert(id); - dag.push_back(var); - }; - - recurse(*this); - return dag; + std::unordered_set cache; + DAG dag; + std::function recurse; + + // Topological sort + recurse = [&](const Variable& var) { + auto id = var.sharedGrad_.get(); + if(cache.find(id) != cache.end()) { + return; + } + for(const auto& input : var.getInputs()) { 
+ recurse(input); + } + cache.insert(id); + dag.push_back(var); + }; + + recurse(*this); + return dag; } } // namespace fl diff --git a/flashlight/fl/autograd/Variable.h b/flashlight/fl/autograd/Variable.h index c18d16c..60fb040 100644 --- a/flashlight/fl/autograd/Variable.h +++ b/flashlight/fl/autograd/Variable.h @@ -60,283 +60,283 @@ namespace fl { * \endcode */ class FL_API Variable { - public: - using GradFunc = std::function< - void(std::vector& inputs, const Variable& grad_output)>; - - using GradHook = std::function; - - /** - * Creates an empty Variable. The underlying array is empty and - * isCalcGrad() is false. - */ - Variable() = default; - - /** - * Creates a Variable which wraps the specified Tensor - * @param[in] data Tensor to be stored in the Variable - * @param[in] calcGrad specifies whether to the gradient is required for this - * Variable - */ - Variable(Tensor data, bool calcGrad); - - /** - * Creates a Variable which wraps the specified Tensor and inputs - * @param[in] data Tensor to the stored in the Variable - * @param[in] inputs a vector specifying inputs for this Variable - * @param[in] gradFunc function specifying how to calculate gradient of the - * input Variables - */ - Variable(Tensor data, std::vector inputs, GradFunc gradFunc); - - Variable operator()(const std::vector& indices) const; - - /** - * Indexing operator on a flattened Variable. - * @param[in] indices a variable number of indices. - * @return Variable storing the result after indexing operation - */ - template - Variable operator()(const Ts&... args) const { - std::vector indices{{args...}}; - return this->operator()(indices); - } - - /** - * Indexing operator on a flattened Variable. - * @param[in] index index with which to index the flattened tensor - * @return Variable storing the result after indexing operation - */ - Variable flat(const fl::Index& index) const; - - /** - * @return a reference to the underlying Flashlight Tensor. 
- */ - Tensor& tensor() const; - - /** - * Creates a copy of this variable, but detached from the computation graph. - * @return returns the cloned and detached variable. - */ - Variable copy() const; - - /** - * Creates a new variable based on the current variable whose type will be - * adjusted based on the input type. - * - * @param[in] type target data type - * - * @return returns the casted variable. - */ - Variable astype(fl::dtype type) const; - - /** - * @return a reference to the underlying gradient Variable. - */ - Variable& grad() const; - - /** - * Returns whether the gradient calculation for the Variable is enabled - */ - bool isCalcGrad() const; - - /** - * Returns whether the gradient has been calculated for the Variable - */ - bool isGradAvailable() const; - - /** - * Returns the dimension of the array wrapped by the Variable - */ - Shape shape() const; - - /** - * Returns whether the array wrapped by the Variable is empty - */ - bool isEmpty() const; - - /** - * Returns whether the array wrapped by the Variable is contiguous in memory - * in C order. - */ - bool isContiguous() const; - - /** - * Returns a Variable with contiguous array containing the same data as self - * array. - */ - Variable asContiguous() const; - - /** - * Returns the type of the `Tensor` wrapped by the Variable - * (e.g. f32 for float, f64 for double). - * - * See `fl/tensor/Types.h`. 
- */ - fl::dtype type() const; - - /** - * Returns the total number of elements stored in array wrapped by the - * Variable - */ - Dim elements() const; - - /** - * Returns the total number of bytes stored in array wrapped by the - * Variable - */ - size_t bytes() const; - - /** - * Returns the number of dimension of array wrapped by the Variable - */ - unsigned ndim() const; - - /** - * Returns the dimension of array wrapped by the Variable - */ - Dim dim(unsigned dim) const; - - /** - * Evaluates any expressions in the array wrapped by the Variable - */ - void eval() const; - - /** - * Copies the array to the host and return the pointer. - * Must eventually be freed manually via `free` or a related call. - */ - template - T* host() const { - return tensor().host(); - } - - /** - * Copies the array to the existing host pointer `ptr` - */ - template - void host(T* ptr) const { - tensor().host(ptr); - } - - /** - * Get the first element of the array as a scalar - */ - template - T scalar() const { - return tensor().scalar(); - } - - /** - * Remove the gradient stored by the Variable - */ - void zeroGrad(); - - /** - * Set whether to calculate gradient for the Variable. - */ - void setCalcGrad(bool calcGrad); - - /** - * Add the gradient `childGrad` to the Variable. - * No-op if `this->isCalcGrad()` is false. - */ - void addGrad(const Variable& childGrad); - - /** - * Registers a lambda function `hook` to be applied on the gradient w.r.t - * Variable after it is computed during backward pass - */ - void registerGradHook(const GradHook& hook); - - /** - * Clears the gradient hook stored in the variable - */ - void clearGradHook(); - - /** - * Run backward pass on the Variable. Gradient of all the inputs - * in the computation graph leading up to the Variable on which the function - * is computed. 
- * @param[in] grad gradient w.r.t to the Variable - * @param[in] retainGraph If False, clears the input Variables stored - * by the Variable - */ - void backward(const Variable& grad, bool retainGraph = false); - - /** - * Run backward pass on the Variable. Gradient of all the inputs - * in the computation graph leading up to the Variable on which the function - * is computed. Gradient w.r.t the all the elements in the variable is set - * to 1.0 - * @param[in] retainGraph If False, clears the input Variables stored - * by the Variable - */ - void backward(bool retainGraph = false); - - /** - * Returns a copy of this variable after removing its underlying array. - * The new Variable is used to store the inputs for a Variable - * which doesn't need the output. - */ - Variable withoutData() const; - - private: - using DAG = std::vector; - - /** - * Get all the inputs to this Variable - */ - std::vector& getInputs() const; - - /** - * Builds the computation graph which comprises of all the input Variables for - * which the gradient of `var` can be propagated using chain rule - */ - DAG build() const; - - /** - * Calculate the gradient of inputs. 
- * @param[in] retainGraph If False, clears the inputs stored - * by the Variable - */ - void calcGradInputs(bool retainGraph = false); - - /** - * Calls the gradient hook (if any) registered by the Variable - */ - void applyGradHook(); - - struct SharedData { - /// Array wrapped by this Variable - Tensor data; - - FL_SAVE_LOAD(data) - }; - - struct SharedGrad { - /// Whether the gradient should be computed for this Variable - bool calcGrad{false}; - /// Inputs of this Variable - std::vector inputs; - /// Gradient with respect to this Variable - std::unique_ptr grad{nullptr}; - /// Function for calculating the gradient of the input Variables - GradFunc gradFunc{nullptr}; - /// Function applied to gradient after it's computed during bwd pass - GradHook onGradAvailable{nullptr}; - - private: - FL_SAVE_LOAD(calcGrad); - }; - - std::shared_ptr sharedData_ = std::make_shared(); - std::shared_ptr sharedGrad_ = std::make_shared(); - - // NB: array only; we don't try to serialize the autograd graph - // Saving the sharedData ptr helps to avoid saving variables which share the - // same underlying tensor twice - FL_SAVE_LOAD(sharedData_, sharedGrad_) +public: + using GradFunc = std::function< + void (std::vector& inputs, const Variable& grad_output)>; + + using GradHook = std::function; + + /** + * Creates an empty Variable. The underlying array is empty and + * isCalcGrad() is false. 
+ */ + Variable() = default; + + /** + * Creates a Variable which wraps the specified Tensor + * @param[in] data Tensor to be stored in the Variable + * @param[in] calcGrad specifies whether to the gradient is required for this + * Variable + */ + Variable(Tensor data, bool calcGrad); + + /** + * Creates a Variable which wraps the specified Tensor and inputs + * @param[in] data Tensor to the stored in the Variable + * @param[in] inputs a vector specifying inputs for this Variable + * @param[in] gradFunc function specifying how to calculate gradient of the + * input Variables + */ + Variable(Tensor data, std::vector inputs, GradFunc gradFunc); + + Variable operator()(const std::vector& indices) const; + + /** + * Indexing operator on a flattened Variable. + * @param[in] indices a variable number of indices. + * @return Variable storing the result after indexing operation + */ + template + Variable operator()(const Ts&... args) const { + std::vector indices{{args...}}; + return this->operator()(indices); + } + + /** + * Indexing operator on a flattened Variable. + * @param[in] index index with which to index the flattened tensor + * @return Variable storing the result after indexing operation + */ + Variable flat(const fl::Index& index) const; + + /** + * @return a reference to the underlying Flashlight Tensor. + */ + Tensor& tensor() const; + + /** + * Creates a copy of this variable, but detached from the computation graph. + * @return returns the cloned and detached variable. + */ + Variable copy() const; + + /** + * Creates a new variable based on the current variable whose type will be + * adjusted based on the input type. + * + * @param[in] type target data type + * + * @return returns the casted variable. + */ + Variable astype(fl::dtype type) const; + + /** + * @return a reference to the underlying gradient Variable. 
+ */ + Variable& grad() const; + + /** + * Returns whether the gradient calculation for the Variable is enabled + */ + bool isCalcGrad() const; + + /** + * Returns whether the gradient has been calculated for the Variable + */ + bool isGradAvailable() const; + + /** + * Returns the dimension of the array wrapped by the Variable + */ + Shape shape() const; + + /** + * Returns whether the array wrapped by the Variable is empty + */ + bool isEmpty() const; + + /** + * Returns whether the array wrapped by the Variable is contiguous in memory + * in C order. + */ + bool isContiguous() const; + + /** + * Returns a Variable with contiguous array containing the same data as self + * array. + */ + Variable asContiguous() const; + + /** + * Returns the type of the `Tensor` wrapped by the Variable + * (e.g. f32 for float, f64 for double). + * + * See `fl/tensor/Types.h`. + */ + fl::dtype type() const; + + /** + * Returns the total number of elements stored in array wrapped by the + * Variable + */ + Dim elements() const; + + /** + * Returns the total number of bytes stored in array wrapped by the + * Variable + */ + size_t bytes() const; + + /** + * Returns the number of dimension of array wrapped by the Variable + */ + unsigned ndim() const; + + /** + * Returns the dimension of array wrapped by the Variable + */ + Dim dim(unsigned dim) const; + + /** + * Evaluates any expressions in the array wrapped by the Variable + */ + void eval() const; + + /** + * Copies the array to the host and return the pointer. + * Must eventually be freed manually via `free` or a related call. 
+ */ + template + T* host() const { + return tensor().host(); + } + + /** + * Copies the array to the existing host pointer `ptr` + */ + template + void host(T* ptr) const { + tensor().host(ptr); + } + + /** + * Get the first element of the array as a scalar + */ + template + T scalar() const { + return tensor().scalar(); + } + + /** + * Remove the gradient stored by the Variable + */ + void zeroGrad(); + + /** + * Set whether to calculate gradient for the Variable. + */ + void setCalcGrad(bool calcGrad); + + /** + * Add the gradient `childGrad` to the Variable. + * No-op if `this->isCalcGrad()` is false. + */ + void addGrad(const Variable& childGrad); + + /** + * Registers a lambda function `hook` to be applied on the gradient w.r.t + * Variable after it is computed during backward pass + */ + void registerGradHook(const GradHook& hook); + + /** + * Clears the gradient hook stored in the variable + */ + void clearGradHook(); + + /** + * Run backward pass on the Variable. Gradient of all the inputs + * in the computation graph leading up to the Variable on which the function + * is computed. + * @param[in] grad gradient w.r.t to the Variable + * @param[in] retainGraph If False, clears the input Variables stored + * by the Variable + */ + void backward(const Variable& grad, bool retainGraph = false); + + /** + * Run backward pass on the Variable. Gradient of all the inputs + * in the computation graph leading up to the Variable on which the function + * is computed. Gradient w.r.t the all the elements in the variable is set + * to 1.0 + * @param[in] retainGraph If False, clears the input Variables stored + * by the Variable + */ + void backward(bool retainGraph = false); + + /** + * Returns a copy of this variable after removing its underlying array. + * The new Variable is used to store the inputs for a Variable + * which doesn't need the output. 
+ */ + Variable withoutData() const; + +private: + using DAG = std::vector; + + /** + * Get all the inputs to this Variable + */ + std::vector& getInputs() const; + + /** + * Builds the computation graph which comprises of all the input Variables for + * which the gradient of `var` can be propagated using chain rule + */ + DAG build() const; + + /** + * Calculate the gradient of inputs. + * @param[in] retainGraph If False, clears the inputs stored + * by the Variable + */ + void calcGradInputs(bool retainGraph = false); + + /** + * Calls the gradient hook (if any) registered by the Variable + */ + void applyGradHook(); + + struct SharedData { + /// Array wrapped by this Variable + Tensor data; + + FL_SAVE_LOAD(data) + }; + + struct SharedGrad { + /// Whether the gradient should be computed for this Variable + bool calcGrad{false}; + /// Inputs of this Variable + std::vector inputs; + /// Gradient with respect to this Variable + std::unique_ptr grad{nullptr}; + /// Function for calculating the gradient of the input Variables + GradFunc gradFunc{nullptr}; + /// Function applied to gradient after it's computed during bwd pass + GradHook onGradAvailable{nullptr}; + + private: + FL_SAVE_LOAD(calcGrad); + }; + + std::shared_ptr sharedData_ = std::make_shared(); + std::shared_ptr sharedGrad_ = std::make_shared(); + + // NB: array only; we don't try to serialize the autograd graph + // Saving the sharedData ptr helps to avoid saving variables which share the + // same underlying tensor twice + FL_SAVE_LOAD(sharedData_, sharedGrad_) }; } // namespace fl diff --git a/flashlight/fl/autograd/tensor/AutogradExtension.h b/flashlight/fl/autograd/tensor/AutogradExtension.h index dc87b76..dfb96f7 100644 --- a/flashlight/fl/autograd/tensor/AutogradExtension.h +++ b/flashlight/fl/autograd/tensor/AutogradExtension.h @@ -17,13 +17,13 @@ class DynamicBenchmark; namespace detail { -struct RNNGradData; + struct RNNGradData; /* * A base type that can be used to construct autograd payloads 
- this is * arbitrary data that can persist between forward and backward operations. */ -struct AutogradPayloadData {}; + struct AutogradPayloadData {}; /** * A simple type with semantics for assigning autograd payloads. It has the @@ -37,146 +37,155 @@ struct AutogradPayloadData {}; * guarantees: * - TODO: write me -- same type of payload so can be safely downcast */ -struct AutogradPayload { - std::shared_ptr data; -}; + struct AutogradPayload { + std::shared_ptr data; + }; } // namespace detail class AutogradExtension : public TensorExtension { - public: - virtual ~AutogradExtension() = default; - - static constexpr TensorExtensionType extensionType = - TensorExtensionType::Autograd; - - virtual std::shared_ptr createBenchmarkOptions() { - return nullptr; - } - - /**************************** Forward ****************************/ - virtual Tensor conv2d( - const Tensor& input, - const Tensor& weights, - const Tensor& bias, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, - std::shared_ptr payload) = 0; - - virtual Tensor pool2d( - const Tensor& input, - const int wx, - const int wy, - const int sx, - const int sy, - const int px, - const int py, - const PoolingMode mode, - std::shared_ptr payload) = 0; - - virtual Tensor batchnorm( - Tensor& saveMean, - Tensor& saveVar, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - Tensor& runningMean, - Tensor& runningVar, - const std::vector& axes, - const bool train, - const double momentum, - const double epsilon, - std::shared_ptr payload) = 0; - - virtual std::tuple rnn( - const Tensor& input, - const Tensor& hiddenState, - const Tensor& cellState, - const Tensor& weights, - const int hiddenSize, - const int numLayers, - const RnnMode mode, - const bool bidirectional, - const float dropout, - std::shared_ptr payload) = 0; - - /**************************** Backward ****************************/ - // ]----- conv2d - virtual Tensor 
conv2dBackwardData( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& weight, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, - std::shared_ptr dataGradBenchmark, - std::shared_ptr payload) = 0; - - virtual std::pair conv2dBackwardFilterBias( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& weights, - const Tensor& bias, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, - std::shared_ptr filterBench, - std::shared_ptr biasBench, - std::shared_ptr autogradPayload) = 0; - - // ]----- pool2D - virtual Tensor pool2dBackward( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& poolOutput, - const int wx, - const int wy, - const int sx, - const int sy, - const int px, - const int py, - const PoolingMode mode, - std::shared_ptr payload) = 0; - - // ]----- batchnorm - virtual std::tuple batchnormBackward( - const Tensor& gradOutput, - const Tensor& saveMean, - const Tensor& saveVar, - const Tensor& input, - const Tensor& weight, - const std::vector& axes, - const bool train, - const float epsilon, - std::shared_ptr payload) = 0; - - // ]----- rnn - virtual std::tuple rnnBackward( - const Tensor& input, - const Tensor& hiddenState, - const Tensor& cellState, - const Tensor& weights, - const std::shared_ptr gradData, - const Tensor& output, - const int numLayers, - const int hiddenSize, - const RnnMode mode, - const bool bidirectional, - const float dropProb, - std::shared_ptr payload) = 0; +public: + virtual ~AutogradExtension() = default; + + static constexpr TensorExtensionType extensionType = + TensorExtensionType::Autograd; + + virtual std::shared_ptr createBenchmarkOptions() { + return nullptr; + } + + /**************************** Forward ****************************/ + virtual Tensor conv2d( + const Tensor& input, + const Tensor& weights, + const Tensor& bias, + const int sx, + const 
int sy, + const int px, + const int py, + const int dx, + const int dy, + const int groups, + std::shared_ptr payload + ) = 0; + + virtual Tensor pool2d( + const Tensor& input, + const int wx, + const int wy, + const int sx, + const int sy, + const int px, + const int py, + const PoolingMode mode, + std::shared_ptr payload + ) = 0; + + virtual Tensor batchnorm( + Tensor& saveMean, + Tensor& saveVar, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + Tensor& runningMean, + Tensor& runningVar, + const std::vector& axes, + const bool train, + const double momentum, + const double epsilon, + std::shared_ptr payload + ) = 0; + + virtual std::tuple rnn( + const Tensor& input, + const Tensor& hiddenState, + const Tensor& cellState, + const Tensor& weights, + const int hiddenSize, + const int numLayers, + const RnnMode mode, + const bool bidirectional, + const float dropout, + std::shared_ptr payload + ) = 0; + + /**************************** Backward ****************************/ + // ]----- conv2d + virtual Tensor conv2dBackwardData( + const Tensor& gradOutput, + const Tensor& input, + const Tensor& weight, + const int sx, + const int sy, + const int px, + const int py, + const int dx, + const int dy, + const int groups, + std::shared_ptr dataGradBenchmark, + std::shared_ptr payload + ) = 0; + + virtual std::pair conv2dBackwardFilterBias( + const Tensor& gradOutput, + const Tensor& input, + const Tensor& weights, + const Tensor& bias, + const int sx, + const int sy, + const int px, + const int py, + const int dx, + const int dy, + const int groups, + std::shared_ptr filterBench, + std::shared_ptr biasBench, + std::shared_ptr autogradPayload + ) = 0; + + // ]----- pool2D + virtual Tensor pool2dBackward( + const Tensor& gradOutput, + const Tensor& input, + const Tensor& poolOutput, + const int wx, + const int wy, + const int sx, + const int sy, + const int px, + const int py, + const PoolingMode mode, + std::shared_ptr payload + ) = 0; + + // ]----- 
batchnorm + virtual std::tuple batchnormBackward( + const Tensor& gradOutput, + const Tensor& saveMean, + const Tensor& saveVar, + const Tensor& input, + const Tensor& weight, + const std::vector& axes, + const bool train, + const float epsilon, + std::shared_ptr payload + ) = 0; + + // ]----- rnn + virtual std::tuple rnnBackward( + const Tensor& input, + const Tensor& hiddenState, + const Tensor& cellState, + const Tensor& weights, + const std::shared_ptr gradData, + const Tensor& output, + const int numLayers, + const int hiddenSize, + const RnnMode mode, + const bool bidirectional, + const float dropProb, + std::shared_ptr payload + ) = 0; }; } // namespace fl diff --git a/flashlight/fl/autograd/tensor/AutogradExtensionBackends.h b/flashlight/fl/autograd/tensor/AutogradExtensionBackends.h index 27e69e9..dfd27d6 100644 --- a/flashlight/fl/autograd/tensor/AutogradExtensionBackends.h +++ b/flashlight/fl/autograd/tensor/AutogradExtensionBackends.h @@ -14,10 +14,10 @@ * Conditionally include autograd extensions */ #if FL_USE_CUDNN - #include "flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h" +#include "flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h" #endif // FL_USE_CUDNN #if FL_USE_ONEDNN - #include "flashlight/fl/autograd/tensor/backend/onednn/OneDnnAutogradExtension.h" +#include "flashlight/fl/autograd/tensor/backend/onednn/OneDnnAutogradExtension.h" #endif // FL_USE_ONEDNN namespace fl { @@ -26,15 +26,15 @@ namespace fl { // TODO{fl::Tensor} -- improve macros based on compute envs #if FL_USE_CUDNN - #if FL_USE_ARRAYFIRE && FL_ARRAYFIRE_USE_CUDA +#if FL_USE_ARRAYFIRE && FL_ARRAYFIRE_USE_CUDA FL_REGISTER_TENSOR_EXTENSION(CudnnAutogradExtension, ArrayFire); - #endif // FL_USE_ARRAYFIRE && FL_ARRAYFIRE_USE_CUDA +#endif // FL_USE_ARRAYFIRE && FL_ARRAYFIRE_USE_CUDA #endif // FL_USE_CUDNN #if FL_USE_ONEDNN - #if FL_USE_ARRAYFIRE && (FL_ARRAYFIRE_USE_CPU || FL_ARRAYFIRE_USE_OPENCL) +#if FL_USE_ARRAYFIRE && (FL_ARRAYFIRE_USE_CPU 
|| FL_ARRAYFIRE_USE_OPENCL) FL_REGISTER_TENSOR_EXTENSION(OneDnnAutogradExtension, ArrayFire); - #endif +#endif #endif // FL_USE_ONEDNN } // namespace fl diff --git a/flashlight/fl/autograd/tensor/AutogradOps.cpp b/flashlight/fl/autograd/tensor/AutogradOps.cpp index 3cb3048..f01a9b3 100644 --- a/flashlight/fl/autograd/tensor/AutogradOps.cpp +++ b/flashlight/fl/autograd/tensor/AutogradOps.cpp @@ -21,9 +21,10 @@ Tensor conv2d( const int py, const int dx, const int dy, - const int groups) { - auto dummyBias = Tensor(input.type()); - return conv2d(input, weights, dummyBias, sx, sy, px, py, dx, dy, groups); + const int groups +) { + auto dummyBias = Tensor(input.type()); + return conv2d(input, weights, dummyBias, sx, sy, px, py, dx, dy, groups); } Tensor conv2d( @@ -36,19 +37,21 @@ Tensor conv2d( const int py, const int dx, const int dy, - const int groups) { - return detail::conv2d( - input, - weights, - bias, - sx, - sy, - px, - py, - dx, - dy, - groups, - /* payload = */ nullptr); + const int groups +) { + return detail::conv2d( + input, + weights, + bias, + sx, + sy, + px, + py, + dx, + dy, + groups, + /* payload = */ nullptr + ); } Tensor pool2d( @@ -59,9 +62,19 @@ Tensor pool2d( const int sy, const int px, const int py, - const PoolingMode mode) { - return detail::pool2d( - input, wx, wy, sx, sy, px, py, mode, /* payload = */ nullptr); + const PoolingMode mode +) { + return detail::pool2d( + input, + wx, + wy, + sx, + sy, + px, + py, + mode, /* payload = */ + nullptr + ); } Tensor batchnorm( @@ -73,22 +86,24 @@ Tensor batchnorm( const std::vector& axes, const bool train, const double momentum, - const double epsilon) { - Tensor saveMean; // empty - Tensor saveVar; // empty - return detail::batchnorm( - saveMean, - saveVar, - input, - weight, - bias, - runningMean, - runningVar, - axes, - train, - momentum, - epsilon, - /* payload = */ nullptr); + const double epsilon +) { + Tensor saveMean; // empty + Tensor saveVar; // empty + return detail::batchnorm( + saveMean, 
+ saveVar, + input, + weight, + bias, + runningMean, + runningVar, + axes, + train, + momentum, + epsilon, + /* payload = */ nullptr + ); } Tensor batchnorm( @@ -102,20 +117,22 @@ Tensor batchnorm( const std::vector& axes, const bool train, const double momentum, - const double epsilon) { - return detail::batchnorm( - saveMean, - saveVar, - input, - weight, - bias, - runningMean, - runningVar, - axes, - train, - momentum, - epsilon, - /*payload = */ nullptr); + const double epsilon +) { + return detail::batchnorm( + saveMean, + saveVar, + input, + weight, + bias, + runningMean, + runningVar, + axes, + train, + momentum, + epsilon, + /*payload = */ nullptr + ); } std::tuple rnn( @@ -127,238 +144,286 @@ std::tuple rnn( const int numLayers, const RnnMode mode, const bool bidirectional, - const float dropout) { - return detail::rnn( - input, - hiddenState, - cellState, - weights, - hiddenSize, - numLayers, - mode, - bidirectional, - dropout, - /* payload = */ nullptr); + const float dropout +) { + return detail::rnn( + input, + hiddenState, + cellState, + weights, + hiddenSize, + numLayers, + mode, + bidirectional, + dropout, + /* payload = */ nullptr + ); } namespace detail { -Tensor conv2d( - const Tensor& input, - const Tensor& weights, - const Tensor& bias, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, - std::shared_ptr payload) { - return input.backend().getExtension().conv2d( - input, weights, bias, sx, sy, px, py, dx, dy, groups, payload); -} + Tensor conv2d( + const Tensor& input, + const Tensor& weights, + const Tensor& bias, + const int sx, + const int sy, + const int px, + const int py, + const int dx, + const int dy, + const int groups, + std::shared_ptr payload + ) { + return input.backend().getExtension().conv2d( + input, + weights, + bias, + sx, + sy, + px, + py, + dx, + dy, + groups, + payload + ); + } -Tensor batchnorm( - Tensor& saveMean, - Tensor& saveVar, - const Tensor& input, - 
const Tensor& weight, - const Tensor& bias, - Tensor& runningMean, - Tensor& runningVar, - const std::vector& axes, - const bool train, - const double momentum, - const double epsilon, - std::shared_ptr payload) { - return input.backend().getExtension().batchnorm( - saveMean, - saveVar, - input, - weight, - bias, - runningMean, - runningVar, - axes, - train, - momentum, - epsilon, - payload); -} + Tensor batchnorm( + Tensor& saveMean, + Tensor& saveVar, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + Tensor& runningMean, + Tensor& runningVar, + const std::vector& axes, + const bool train, + const double momentum, + const double epsilon, + std::shared_ptr payload + ) { + return input.backend().getExtension().batchnorm( + saveMean, + saveVar, + input, + weight, + bias, + runningMean, + runningVar, + axes, + train, + momentum, + epsilon, + payload + ); + } -Tensor pool2d( - const Tensor& input, - const int wx, - const int wy, - const int sx, - const int sy, - const int px, - const int py, - const PoolingMode mode, - std::shared_ptr payload) { - return input.backend().getExtension().pool2d( - input, wx, wy, sx, sy, px, py, mode, payload); -} + Tensor pool2d( + const Tensor& input, + const int wx, + const int wy, + const int sx, + const int sy, + const int px, + const int py, + const PoolingMode mode, + std::shared_ptr payload + ) { + return input.backend().getExtension().pool2d( + input, + wx, + wy, + sx, + sy, + px, + py, + mode, + payload + ); + } -std::tuple rnn( - const Tensor& input, - const Tensor& hiddenState, - const Tensor& cellState, - const Tensor& weights, - const int hiddenSize, - const int numLayers, - const RnnMode mode, - const bool bidirectional, - const float dropout, - std::shared_ptr payload) { - return input.backend().getExtension().rnn( - input, - hiddenState, - cellState, - weights, - hiddenSize, - numLayers, - mode, - bidirectional, - dropout, - payload); -} + std::tuple rnn( + const Tensor& input, + const Tensor& 
hiddenState, + const Tensor& cellState, + const Tensor& weights, + const int hiddenSize, + const int numLayers, + const RnnMode mode, + const bool bidirectional, + const float dropout, + std::shared_ptr payload + ) { + return input.backend().getExtension().rnn( + input, + hiddenState, + cellState, + weights, + hiddenSize, + numLayers, + mode, + bidirectional, + dropout, + payload + ); + } -Tensor conv2dBackwardData( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& weight, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, - std::shared_ptr dataGradBenchmark, - std::shared_ptr payload) { - return input.backend().getExtension().conv2dBackwardData( - gradOutput, - input, - weight, - sx, - sy, - px, - py, - dx, - dy, - groups, - dataGradBenchmark, - payload); -} + Tensor conv2dBackwardData( + const Tensor& gradOutput, + const Tensor& input, + const Tensor& weight, + const int sx, + const int sy, + const int px, + const int py, + const int dx, + const int dy, + const int groups, + std::shared_ptr dataGradBenchmark, + std::shared_ptr payload + ) { + return input.backend().getExtension().conv2dBackwardData( + gradOutput, + input, + weight, + sx, + sy, + px, + py, + dx, + dy, + groups, + dataGradBenchmark, + payload + ); + } -std::pair conv2dBackwardFilterBias( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& filter, - const Tensor& bias, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, - std::shared_ptr filterGradBenchmark, - std::shared_ptr biasGradBenchmark, - std::shared_ptr payload) { - return input.backend() - .getExtension() - .conv2dBackwardFilterBias( - gradOutput, - input, - filter, - bias, - sx, - sy, - px, - py, - dx, - dy, - groups, - filterGradBenchmark, - biasGradBenchmark, - payload); -} + std::pair conv2dBackwardFilterBias( + const Tensor& gradOutput, + const Tensor& input, + const Tensor& 
filter, + const Tensor& bias, + const int sx, + const int sy, + const int px, + const int py, + const int dx, + const int dy, + const int groups, + std::shared_ptr filterGradBenchmark, + std::shared_ptr biasGradBenchmark, + std::shared_ptr payload + ) { + return input.backend() + .getExtension() + .conv2dBackwardFilterBias( + gradOutput, + input, + filter, + bias, + sx, + sy, + px, + py, + dx, + dy, + groups, + filterGradBenchmark, + biasGradBenchmark, + payload + ); + } -Tensor pool2dBackward( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& poolOutput, - const int wx, - const int wy, - const int sx, - const int sy, - const int px, - const int py, - const PoolingMode mode, - std::shared_ptr payload) { - return input.backend().getExtension().pool2dBackward( - gradOutput, input, poolOutput, wx, wy, sx, sy, px, py, mode, payload); -} + Tensor pool2dBackward( + const Tensor& gradOutput, + const Tensor& input, + const Tensor& poolOutput, + const int wx, + const int wy, + const int sx, + const int sy, + const int px, + const int py, + const PoolingMode mode, + std::shared_ptr payload + ) { + return input.backend().getExtension().pool2dBackward( + gradOutput, + input, + poolOutput, + wx, + wy, + sx, + sy, + px, + py, + mode, + payload + ); + } // Returns the gradinets with respect tot he input, hidden state cell state, and // weights respectively // Why one function for gradient of all of them? Most // implementations don't support computing separate gradients. If support for // this is added in most places, split out this function. 
-std::tuple batchnormBackward( - const Tensor& gradOutput, - const Tensor& saveMean, - const Tensor& saveVar, - const Tensor& input, - const Tensor& weight, - const std::vector& axes, - const bool train, - const float epsilon, - std::shared_ptr payload) { - return gradOutput.backend() - .getExtension() - .batchnormBackward( - gradOutput, - saveMean, - saveVar, - input, - weight, - axes, - train, - epsilon, - payload); -} + std::tuple batchnormBackward( + const Tensor& gradOutput, + const Tensor& saveMean, + const Tensor& saveVar, + const Tensor& input, + const Tensor& weight, + const std::vector& axes, + const bool train, + const float epsilon, + std::shared_ptr payload + ) { + return gradOutput.backend() + .getExtension() + .batchnormBackward( + gradOutput, + saveMean, + saveVar, + input, + weight, + axes, + train, + epsilon, + payload + ); + } -std::tuple rnnBackward( - const Tensor& input, - const Tensor& hiddenState, - const Tensor& cellState, - const Tensor& weights, - const std::shared_ptr gradData, - const Tensor& output, - const int numLayers, - const int hiddenSize, - const RnnMode mode, - const bool bidirectional, - const float dropProb, - std::shared_ptr payload) { - return input.backend().getExtension().rnnBackward( - input, - hiddenState, - cellState, - weights, - gradData, - output, - numLayers, - hiddenSize, - mode, - bidirectional, - dropProb, - payload); -} + std::tuple rnnBackward( + const Tensor& input, + const Tensor& hiddenState, + const Tensor& cellState, + const Tensor& weights, + const std::shared_ptr gradData, + const Tensor& output, + const int numLayers, + const int hiddenSize, + const RnnMode mode, + const bool bidirectional, + const float dropProb, + std::shared_ptr payload + ) { + return input.backend().getExtension().rnnBackward( + input, + hiddenState, + cellState, + weights, + gradData, + output, + numLayers, + hiddenSize, + mode, + bidirectional, + dropProb, + payload + ); + } } // namespace detail diff --git 
a/flashlight/fl/autograd/tensor/AutogradOps.h b/flashlight/fl/autograd/tensor/AutogradOps.h index 2929ef1..78a5a9b 100644 --- a/flashlight/fl/autograd/tensor/AutogradOps.h +++ b/flashlight/fl/autograd/tensor/AutogradOps.h @@ -18,12 +18,12 @@ namespace fl { class DynamicBenchmark; namespace detail { -struct AutogradPayload; + struct AutogradPayload; } /** * Applies a 2D convolution over an input signal given filter weights. In - the + the * simplest case, the output with shape [\f$X_{out}\f$, \f$Y_{out}\f$, * \f$C_{out}\f$, \f$N\f$] of the convolution with input [\f$X_{in}\f$, * \f$Y_{in}\f$, \f$C_{in}\f$, \f$N\f$] and weight [\f$K_x\f$, \f$K_y\f$, @@ -34,7 +34,7 @@ struct AutogradPayload; \text{input}(k, N_i) * \f] * @param input a Tensor with shape [\f$X_{in}\f$, \f$Y_{in}\f$, - \f$C_{in}\f$, + \f$C_{in}\f$, * \f$N\f$] * @param weights a Tensor with shape [\f$K_x\f$, \f$K_y\f$, \f$C_{in}\f$, * \f$C_{out}\f$] @@ -63,14 +63,15 @@ FL_API Tensor conv2d( const int py = 0, const int dx = 1, const int dy = 1, - const int groups = 1); + const int groups = 1 +); /** * Applies a 2D convolution over an input signal given filter weights and * biases. In the simplest case, the output with shape [\f$X_{out}\f$, * \f$Y_{out}\f$, \f$C_{out}\f$, \f$N\f$] of the convolution with input * [\f$X_{in}\f$, \f$Y_{in}\f$, \f$C_{in}\f$, \f$N\f$] and weight - [\f$K_x\f$, + [\f$K_x\f$, * \f$K_y\f$, \f$C_{in}\f$, \f$C_{out}\f$] can be precisely described as: * \f[ \text{out}(C_{out_j}, N_i) = @@ -80,7 +81,7 @@ FL_API Tensor conv2d( * \f] * @param input a Tensor with shape [\f$X_{in}\f$, \f$Y_{in}\f$, - \f$C_{in}\f$, + \f$C_{in}\f$, * \f$N\f$] * @param weights a Tensor with shape [\f$K_x\f$, \f$K_y\f$, \f$C_{in}\f$, * \f$C_{out}\f$] @@ -111,7 +112,8 @@ FL_API Tensor conv2d( const int py = 0, const int dx = 1, const int dy = 1, - const int groups = 1); + const int groups = 1 +); /** * Applies a 2D pooling over an input signal composed of several input planes. 
@@ -138,35 +140,36 @@ FL_API Tensor pool2d( const int sy = 1, const int px = 0, const int py = 0, - const PoolingMode mode = PoolingMode::MAX); + const PoolingMode mode = PoolingMode::MAX +); /** -* Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with -* additional channel dimension) as described in the paper -* [Batch Normalization: Accelerating Deep Network Training by Reducing Internal -* Covariate Shift] (https://arxiv.org/abs/1502.03167) . -* \f[ -* y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + -* \beta -* \f] -* The mean and standard-deviation are calculated per-dimension over the -* mini-batches and \f$\gamma\f$ and \f$\beta\f$ are learnable parameter vectors -* of size \f$C\f$, the input size. By default, during training this layer keeps -* running estimates of its computed mean and variance, which are then used for -* normalization during evaluation. + * Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with + * additional channel dimension) as described in the paper + * [Batch Normalization: Accelerating Deep Network Training by Reducing Internal + * Covariate Shift] (https://arxiv.org/abs/1502.03167) . + * \f[ + * y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + + * \beta + * \f] + * The mean and standard-deviation are calculated per-dimension over the + * mini-batches and \f$\gamma\f$ and \f$\beta\f$ are learnable parameter vectors + * of size \f$C\f$, the input size. By default, during training this layer keeps + * running estimates of its computed mean and variance, which are then used for + * normalization during evaluation. 
-* @param input a Tensor with size [\f$H\f$, \f$W\f$, \f$C\f$, \f$N\f$] -* @param weight a Tensor with size [\f$C\f$] for \f$\gamma\f$ -* @param bias a Tensor with size [\f$C\f$] for \f$\beta\f$ -* @param runningMean a buffer storing intermediate means during training -* @param runningVar a buffer storing intermediate variances during training -* @param axes dimensions to perform normalization on. If having size greater -* than one, reduce over all of them. -* @param train a flag indicating if running in training mode -* @param momentum value of momentum -* @param epsilon value of \f$\epsilon\f$ -* @return a Tensor with same shape as `input` -*/ + * @param input a Tensor with size [\f$H\f$, \f$W\f$, \f$C\f$, \f$N\f$] + * @param weight a Tensor with size [\f$C\f$] for \f$\gamma\f$ + * @param bias a Tensor with size [\f$C\f$] for \f$\beta\f$ + * @param runningMean a buffer storing intermediate means during training + * @param runningVar a buffer storing intermediate variances during training + * @param axes dimensions to perform normalization on. If having size greater + * than one, reduce over all of them. + * @param train a flag indicating if running in training mode + * @param momentum value of momentum + * @param epsilon value of \f$\epsilon\f$ + * @return a Tensor with same shape as `input` + */ FL_API Tensor batchnorm( const Tensor& input, const Tensor& weight, @@ -176,7 +179,8 @@ FL_API Tensor batchnorm( const std::vector& axes, const bool train, const double momentum, - const double epsilon); + const double epsilon +); FL_API Tensor batchnorm( Tensor& saveMean, @@ -189,46 +193,47 @@ FL_API Tensor batchnorm( const std::vector& axes, const bool train, const double momentum, - const double epsilon); + const double epsilon +); /** -* Applies an RNN unit to an input sequence. 
-* A general RNN operator can be expressed as following: -* \f[ - (h_t, c_t) = f_W(x_t, h_{t-1}, c_{t-1}) -* \f] -* where \f$h_t\f$, \f$c_t\f$ are the hidden/cell state at time \f$t\f$, -* \f$x_t\f$ is the input at time \f$t\f$ -* -* \note{cuDNN and oneDNN RNN weights are incompatible since the structure of -* the computation is different for each. There is no mapping between weights -* from each of those backends.} -* -* @param input Tensor of input with shape [input size, batch size, sequence -* length] -* @param hiddenState Tensor of hidden state with shape [hidden size, batch -* size, total layers] -* @param cellState [LSTM only] Tensor of cell state with same shape as -* hidden state -* @param weights Learnable parameters in the RNN unit -* @param hiddenSize number of features in the hidden state -* @param numLayers number of recurrent layers -* @param mode defines the type of RNN unit -* - RELU -* - TANH -* - LSTM -* - GRU -* @param bidirectional if `True`, becomes a bidirectional RNN, unidirectional -otherwise -* @param dropout if non-zero, introduces a `Dropout` layer on the outputs of -* each RNN layer except the last one,q with dropout probability equal to dropout + * Applies an RNN unit to an input sequence. + * A general RNN operator can be expressed as following: + * \f[ + (h_t, c_t) = f_W(x_t, h_{t-1}, c_{t-1}) + * \f] + * where \f$h_t\f$, \f$c_t\f$ are the hidden/cell state at time \f$t\f$, + * \f$x_t\f$ is the input at time \f$t\f$ + * + * \note{cuDNN and oneDNN RNN weights are incompatible since the structure of + * the computation is different for each. 
There is no mapping between weights + * from each of those backends.} + * + * @param input Tensor of input with shape [input size, batch size, sequence + * length] + * @param hiddenState Tensor of hidden state with shape [hidden size, batch + * size, total layers] + * @param cellState [LSTM only] Tensor of cell state with same shape as + * hidden state + * @param weights Learnable parameters in the RNN unit + * @param hiddenSize number of features in the hidden state + * @param numLayers number of recurrent layers + * @param mode defines the type of RNN unit + * - RELU + * - TANH + * - LSTM + * - GRU + * @param bidirectional if `True`, becomes a bidirectional RNN, unidirectional + otherwise + * @param dropout if non-zero, introduces a `Dropout` layer on the outputs of + * each RNN layer except the last one,q with dropout probability equal to dropout -* @return a tuple of three Tensors: -* - `y`: input with shape [input size, batch size, sequence length * -* directions] -* - `hiddenState`: hidden state for the current time step -* - `cellState`: cell state for the current time step -*/ + * @return a tuple of three Tensors: + * - `y`: input with shape [input size, batch size, sequence length * + * directions] + * - `hiddenState`: hidden state for the current time step + * - `cellState`: cell state for the current time step + */ FL_API std::tuple rnn( const Tensor& input, const Tensor& hiddenState, @@ -238,143 +243,153 @@ FL_API std::tuple rnn( const int numLayers, const RnnMode mode, const bool bidirectional, - const float dropout); + const float dropout +); namespace detail { -FL_API Tensor conv2d( - const Tensor& input, - const Tensor& weights, - const Tensor& bias, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, - std::shared_ptr payload); + FL_API Tensor conv2d( + const Tensor& input, + const Tensor& weights, + const Tensor& bias, + const int sx, + const int sy, + const int px, + const int py, + 
const int dx, + const int dy, + const int groups, + std::shared_ptr payload + ); -FL_API Tensor batchnorm( - Tensor& saveMean, - Tensor& saveVar, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - Tensor& runningMean, - Tensor& runningVar, - const std::vector& axes, - const bool train, - const double momentum, - const double epsilon, - std::shared_ptr payload); + FL_API Tensor batchnorm( + Tensor& saveMean, + Tensor& saveVar, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + Tensor& runningMean, + Tensor& runningVar, + const std::vector& axes, + const bool train, + const double momentum, + const double epsilon, + std::shared_ptr payload + ); -FL_API Tensor pool2d( - const Tensor& input, - const int wx, - const int wy, - const int sx, - const int sy, - const int px, - const int py, - const PoolingMode mode, - std::shared_ptr payload); + FL_API Tensor pool2d( + const Tensor& input, + const int wx, + const int wy, + const int sx, + const int sy, + const int px, + const int py, + const PoolingMode mode, + std::shared_ptr payload + ); -FL_API std::tuple rnn( - const Tensor& input, - const Tensor& hiddenState, - const Tensor& cellState, - const Tensor& weights, - const int hiddenSize, - const int numLayers, - const RnnMode mode, - const bool bidirectional, - const float dropout, - std::shared_ptr payload); + FL_API std::tuple rnn( + const Tensor& input, + const Tensor& hiddenState, + const Tensor& cellState, + const Tensor& weights, + const int hiddenSize, + const int numLayers, + const RnnMode mode, + const bool bidirectional, + const float dropout, + std::shared_ptr payload + ); // Returns the gradient with respect to the input -FL_API Tensor conv2dBackwardData( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& weight, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, - std::shared_ptr dataGradBenchmark, - std::shared_ptr payload); + FL_API Tensor 
conv2dBackwardData( + const Tensor& gradOutput, + const Tensor& input, + const Tensor& weight, + const int sx, + const int sy, + const int px, + const int py, + const int dx, + const int dy, + const int groups, + std::shared_ptr dataGradBenchmark, + std::shared_ptr payload + ); // Returns the gradient with respect to the filter and bias (if given) -FL_API std::pair conv2dBackwardFilterBias( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& weights, - const Tensor& bias, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, - std::shared_ptr filterBench, - std::shared_ptr biasBench, - std::shared_ptr payload); + FL_API std::pair conv2dBackwardFilterBias( + const Tensor& gradOutput, + const Tensor& input, + const Tensor& weights, + const Tensor& bias, + const int sx, + const int sy, + const int px, + const int py, + const int dx, + const int dy, + const int groups, + std::shared_ptr filterBench, + std::shared_ptr biasBench, + std::shared_ptr payload + ); -FL_API Tensor pool2dBackward( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& poolOutput, - const int wx, - const int wy, - const int sx, - const int sy, - const int px, - const int py, - const PoolingMode mode, - std::shared_ptr payload); + FL_API Tensor pool2dBackward( + const Tensor& gradOutput, + const Tensor& input, + const Tensor& poolOutput, + const int wx, + const int wy, + const int sx, + const int sy, + const int px, + const int py, + const PoolingMode mode, + std::shared_ptr payload + ); // Returns the gradinets with respect tot he input, weight, and bias, // respectively // Why one function for gradient of all of them? Most implementations don't // support computing separate gradients. If support for this is added in most // places, split out this function. 
-FL_API std::tuple batchnormBackward( - const Tensor& gradOutput, - const Tensor& saveMean, - const Tensor& saveVar, - const Tensor& input, - const Tensor& weight, - const std::vector& axes, - const bool train, - const float epsilon, - std::shared_ptr payload); + FL_API std::tuple batchnormBackward( + const Tensor& gradOutput, + const Tensor& saveMean, + const Tensor& saveVar, + const Tensor& input, + const Tensor& weight, + const std::vector& axes, + const bool train, + const float epsilon, + std::shared_ptr payload + ); -struct RNNGradData { - fl::Tensor dy; - fl::Tensor dhy; - fl::Tensor dcy; -}; + struct RNNGradData { + fl::Tensor dy; + fl::Tensor dhy; + fl::Tensor dcy; + }; // input gradient, hidden state gradient, cell state gradient, weights // gradient // @param[in] gradData grad output for each comp -FL_API std::tuple rnnBackward( - const Tensor& input, - const Tensor& hiddenState, - const Tensor& cellState, - const Tensor& weights, - const std::shared_ptr gradData, - const Tensor& output, - const int numLayers, - const int hiddenSize, - const RnnMode mode, - const bool bidirectional, - const float dropProb, - std::shared_ptr payload); + FL_API std::tuple rnnBackward( + const Tensor& input, + const Tensor& hiddenState, + const Tensor& cellState, + const Tensor& weights, + const std::shared_ptr gradData, + const Tensor& output, + const int numLayers, + const int hiddenSize, + const RnnMode mode, + const bool bidirectional, + const float dropProb, + std::shared_ptr payload + ); } // namespace detail diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp index 6e6aff0..1e7dec9 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp @@ -19,60 +19,66 @@ namespace fl { namespace { -void getBatchnormMetadata( - cudnnBatchNormMode_t& modeOut, - Shape& inDescDimsOut, - Shape& wtDescDimsOut, - const Tensor& input, 
- const std::vector& axes, - const bool train) { - int nfeatures = 1; - for (auto ax : axes) { - if (ax > input.ndim() - 1) { - throw std::invalid_argument( - "batchnorm - passed axes (axis value " + std::to_string(ax) + - ") exceeds the number of dimensions of the input (" + - std::to_string(input.ndim()) + ")"); - } - nfeatures *= input.dim(ax); - } + void getBatchnormMetadata( + cudnnBatchNormMode_t& modeOut, + Shape& inDescDimsOut, + Shape& wtDescDimsOut, + const Tensor& input, + const std::vector& axes, + const bool train + ) { + int nfeatures = 1; + for(auto ax : axes) { + if(ax > input.ndim() - 1) { + throw std::invalid_argument( + "batchnorm - passed axes (axis value " + std::to_string(ax) + + ") exceeds the number of dimensions of the input (" + + std::to_string(input.ndim()) + ")" + ); + } + nfeatures *= input.dim(ax); + } - auto maxAxis = *std::max_element(axes.begin(), axes.end()); - auto minAxis = *std::min_element(axes.begin(), axes.end()); + auto maxAxis = *std::max_element(axes.begin(), axes.end()); + auto minAxis = *std::min_element(axes.begin(), axes.end()); - // assuming no duplicates - bool axes_continuous = (axes.size() == (maxAxis - minAxis + 1)); - if (!axes_continuous) { - throw std::invalid_argument("unsupported axis config for cuDNN batchnorm"); - } + // assuming no duplicates + bool axes_continuous = (axes.size() == (maxAxis - minAxis + 1)); + if(!axes_continuous) { + throw std::invalid_argument("unsupported axis config for cuDNN batchnorm"); + } - if (minAxis == 0) { - modeOut = CUDNN_BATCHNORM_PER_ACTIVATION; - inDescDimsOut = Shape( - {1, - 1, - nfeatures, - static_cast(input.elements() / nfeatures)}); - wtDescDimsOut = Shape({1, 1, nfeatures}); - } else { - modeOut = CUDNN_BATCHNORM_SPATIAL; + if(minAxis == 0) { + modeOut = CUDNN_BATCHNORM_PER_ACTIVATION; + inDescDimsOut = Shape( + {1, + 1, + nfeatures, + static_cast(input.elements() / nfeatures)} + ); + wtDescDimsOut = Shape({1, 1, nfeatures}); + } else { + modeOut = 
CUDNN_BATCHNORM_SPATIAL; #if CUDNN_VERSION >= 7003 - if (train) { - modeOut = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; - } + if(train) { + modeOut = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; + } #endif - int batchsz = 1; - for (int i = maxAxis + 1; i < input.ndim(); ++i) { - batchsz *= input.dim(i); + int batchsz = 1; + for(int i = maxAxis + 1; i < input.ndim(); ++i) { + batchsz *= input.dim(i); + } + inDescDimsOut = Shape( + { + 1, + static_cast(input.elements() / (nfeatures * batchsz)), + nfeatures, + batchsz, + } + ); + wtDescDimsOut = Shape({1, 1, nfeatures}); + } } - inDescDimsOut = Shape( - {1, - static_cast(input.elements() / (nfeatures * batchsz)), - nfeatures, - batchsz}); - wtDescDimsOut = Shape({1, 1, nfeatures}); - } -} } // namespace @@ -88,99 +94,108 @@ Tensor CudnnAutogradExtension::batchnorm( const bool train, const double momentum, const double epsilon, - std::shared_ptr) { - if (input.type() == fl::dtype::f16 && weight.type() != fl::dtype::f32) { - throw std::invalid_argument( - "fl::batchnorm: non-input tensors must be of type f32"); - } - FL_TENSOR_DTYPES_MATCH_CHECK(weight, bias, runningMean, runningVar); + std::shared_ptr +) { + if(input.type() == fl::dtype::f16 && weight.type() != fl::dtype::f32) { + throw std::invalid_argument( + "fl::batchnorm: non-input tensors must be of type f32" + ); + } + FL_TENSOR_DTYPES_MATCH_CHECK(weight, bias, runningMean, runningVar); - auto output = Tensor(input.shape(), input.type()); + auto output = Tensor(input.shape(), input.type()); - cudnnBatchNormMode_t mode; - Shape inDescDims, wtDescDims; - getBatchnormMetadata(mode, inDescDims, wtDescDims, input, axes, train); + cudnnBatchNormMode_t mode; + Shape inDescDims, wtDescDims; + getBatchnormMetadata(mode, inDescDims, wtDescDims, input, axes, train); - if (!weight.isEmpty() && weight.elements() != wtDescDims.elements()) { - throw std::invalid_argument("[BatchNorm] Invalid shape for weight."); - } + if(!weight.isEmpty() && weight.elements() != wtDescDims.elements()) { + 
throw std::invalid_argument("[BatchNorm] Invalid shape for weight."); + } - if (!bias.isEmpty() && bias.elements() != wtDescDims.elements()) { - throw std::invalid_argument("[BatchNorm] Invalid shape for bias."); - } - // Weight, bias, and running mean/var arrays can't be fp16 (must be fp32) - Tensor weightArray = weight.isEmpty() - ? fl::full(wtDescDims, 1.0, fl::dtype::f32) - : weight.astype(fl::dtype::f32); - Tensor biasArray = bias.isEmpty() ? fl::full(wtDescDims, 0.0, fl::dtype::f32) - : bias.astype(fl::dtype::f32); + if(!bias.isEmpty() && bias.elements() != wtDescDims.elements()) { + throw std::invalid_argument("[BatchNorm] Invalid shape for bias."); + } + // Weight, bias, and running mean/var arrays can't be fp16 (must be fp32) + Tensor weightArray = weight.isEmpty() + ? fl::full(wtDescDims, 1.0, fl::dtype::f32) + : weight.astype(fl::dtype::f32); + Tensor biasArray = bias.isEmpty() ? fl::full(wtDescDims, 0.0, fl::dtype::f32) + : bias.astype(fl::dtype::f32); - fl::dtype scalarsType = - input.type() == fl::dtype::f16 ? fl::dtype::f32 : input.type(); + fl::dtype scalarsType = + input.type() == fl::dtype::f16 ? 
fl::dtype::f32 : input.type(); - auto inDesc = TensorDescriptor(input.type(), inDescDims); - auto wtDesc = TensorDescriptor(weightArray.type(), wtDescDims); + auto inDesc = TensorDescriptor(input.type(), inDescDims); + auto wtDesc = TensorDescriptor(weightArray.type(), wtDescDims); - { - DevicePtr inRaw(input); - DevicePtr outRaw(output); - DevicePtr wtRaw(weightArray); - DevicePtr bsRaw(biasArray); - DevicePtr runMeanRaw(runningMean); - DevicePtr runVarRaw(runningVar); - const auto& cudnnStream = getCudnnStream(); - // ensure cudnn compute stream waits on streams of input/output tensors - relativeSync( - cudnnStream, - {input, output, weightArray, biasArray, runningMean, runningVar}); + { + DevicePtr inRaw(input); + DevicePtr outRaw(output); + DevicePtr wtRaw(weightArray); + DevicePtr bsRaw(biasArray); + DevicePtr runMeanRaw(runningMean); + DevicePtr runVarRaw(runningVar); + const auto& cudnnStream = getCudnnStream(); + // ensure cudnn compute stream waits on streams of input/output tensors + relativeSync( + cudnnStream, + {input, output, weightArray, biasArray, runningMean, runningVar} + ); - if (train) { - saveMean = Tensor({wtDescDims[2]}, scalarsType); - saveVar = Tensor({wtDescDims[2]}, scalarsType); + if(train) { + saveMean = Tensor({wtDescDims[2]}, scalarsType); + saveVar = Tensor({wtDescDims[2]}, scalarsType); - DevicePtr saveMeanRaw(saveMean); - DevicePtr saveVarRaw(saveVar); - // ensure cudnn compute stream waits on streams of saveMean/Var tensors - relativeSync(cudnnStream, {saveMean, saveVar}); - CUDNN_CHECK_ERR(cudnnBatchNormalizationForwardTraining( - getCudnnHandle(), - mode, - kOne(scalarsType), - kZero(scalarsType), - inDesc.descriptor, - inRaw.get(), - inDesc.descriptor, - outRaw.get(), - wtDesc.descriptor, - wtRaw.get(), - bsRaw.get(), - momentum, - runMeanRaw.get(), - runVarRaw.get(), - epsilon, - saveMeanRaw.get(), - saveVarRaw.get())); - } else { - CUDNN_CHECK_ERR(cudnnBatchNormalizationForwardInference( - getCudnnHandle(), - mode, - 
kOne(scalarsType), - kZero(scalarsType), - inDesc.descriptor, - inRaw.get(), - inDesc.descriptor, - outRaw.get(), - wtDesc.descriptor, - wtRaw.get(), - bsRaw.get(), - runMeanRaw.get(), - runVarRaw.get(), - epsilon)); + DevicePtr saveMeanRaw(saveMean); + DevicePtr saveVarRaw(saveVar); + // ensure cudnn compute stream waits on streams of saveMean/Var tensors + relativeSync(cudnnStream, {saveMean, saveVar}); + CUDNN_CHECK_ERR( + cudnnBatchNormalizationForwardTraining( + getCudnnHandle(), + mode, + kOne(scalarsType), + kZero(scalarsType), + inDesc.descriptor, + inRaw.get(), + inDesc.descriptor, + outRaw.get(), + wtDesc.descriptor, + wtRaw.get(), + bsRaw.get(), + momentum, + runMeanRaw.get(), + runVarRaw.get(), + epsilon, + saveMeanRaw.get(), + saveVarRaw.get() + ) + ); + } else { + CUDNN_CHECK_ERR( + cudnnBatchNormalizationForwardInference( + getCudnnHandle(), + mode, + kOne(scalarsType), + kZero(scalarsType), + inDesc.descriptor, + inRaw.get(), + inDesc.descriptor, + outRaw.get(), + wtDesc.descriptor, + wtRaw.get(), + bsRaw.get(), + runMeanRaw.get(), + runVarRaw.get(), + epsilon + ) + ); + } + // ensure output stream waits on cudnn compute stream + relativeSync({output}, cudnnStream); } - // ensure output stream waits on cudnn compute stream - relativeSync({output}, cudnnStream); - } - return output; + return output; } std::tuple CudnnAutogradExtension::batchnormBackward( @@ -192,76 +207,82 @@ std::tuple CudnnAutogradExtension::batchnormBackward( const std::vector& axes, const bool train, // TODO(jacobkahn): remove this arg const float epsilon, - std::shared_ptr) { - if (!train) { - throw std::logic_error( - "can't compute batchnorm grad when train was not specified"); - } + std::shared_ptr +) { + if(!train) { + throw std::logic_error( + "can't compute batchnorm grad when train was not specified" + ); + } - cudnnBatchNormMode_t mode; - Shape inDescDims, wtDescDims; - getBatchnormMetadata(mode, inDescDims, wtDescDims, input, axes, train); + cudnnBatchNormMode_t mode; + 
Shape inDescDims, wtDescDims; + getBatchnormMetadata(mode, inDescDims, wtDescDims, input, axes, train); - auto wt = - weight.isEmpty() ? fl::full(wtDescDims, 1.0, fl::dtype::f32) : weight; + auto wt = + weight.isEmpty() ? fl::full(wtDescDims, 1.0, fl::dtype::f32) : weight; - // Weight, bias, and running mean/var arrays can't be fp16 (must be - // fp32) - auto scalarsType = - input.type() == fl::dtype::f16 ? fl::dtype::f32 : input.type(); - const void* one1 = kOne(scalarsType); - const void* zero0 = kZero(scalarsType); + // Weight, bias, and running mean/var arrays can't be fp16 (must be + // fp32) + auto scalarsType = + input.type() == fl::dtype::f16 ? fl::dtype::f32 : input.type(); + const void* one1 = kOne(scalarsType); + const void* zero0 = kZero(scalarsType); - auto iDesc = TensorDescriptor(input.type(), inDescDims); - auto wDesc = TensorDescriptor(wt.type(), wtDescDims); - // CuDNN doesn't support calculating only the gradients - // required for batchnorm - auto gradIn = Tensor(input.shape(), input.type()); - auto gradWt = Tensor(wt.shape(), wt.type()); - auto gradBs = Tensor(wt.shape(), wt.type()); - { - DevicePtr iRaw(input); - DevicePtr wRaw(wt); + auto iDesc = TensorDescriptor(input.type(), inDescDims); + auto wDesc = TensorDescriptor(wt.type(), wtDescDims); + // CuDNN doesn't support calculating only the gradients + // required for batchnorm + auto gradIn = Tensor(input.shape(), input.type()); + auto gradWt = Tensor(wt.shape(), wt.type()); + auto gradBs = Tensor(wt.shape(), wt.type()); + { + DevicePtr iRaw(input); + DevicePtr wRaw(wt); - DevicePtr gradInRaw(gradIn); - DevicePtr gradWtRaw(gradWt); - DevicePtr gradBsRaw(gradBs); + DevicePtr gradInRaw(gradIn); + DevicePtr gradWtRaw(gradWt); + DevicePtr gradBsRaw(gradBs); - DevicePtr gradOpRaw(gradOutput); + DevicePtr gradOpRaw(gradOutput); - DevicePtr saveMeanRaw(saveMean); - DevicePtr saveVarRaw(saveVar); - const auto& cudnnStream = getCudnnStream(); - // ensure cudnn compute stream waits on streams of 
input/output tensors - relativeSync( - cudnnStream, - {input, gradOutput, gradIn, wt, gradWt, gradBs, saveMean, saveVar}); + DevicePtr saveMeanRaw(saveMean); + DevicePtr saveVarRaw(saveVar); + const auto& cudnnStream = getCudnnStream(); + // ensure cudnn compute stream waits on streams of input/output tensors + relativeSync( + cudnnStream, + {input, gradOutput, gradIn, wt, gradWt, gradBs, saveMean, saveVar} + ); - CUDNN_CHECK_ERR(cudnnBatchNormalizationBackward( - getCudnnHandle(), - mode, - one1, - zero0, - one1, - zero0, - iDesc.descriptor, - iRaw.get(), - iDesc.descriptor, - gradOpRaw.get(), - iDesc.descriptor, - gradInRaw.get(), - wDesc.descriptor, - wRaw.get(), - gradWtRaw.get(), - gradBsRaw.get(), - epsilon, - saveMeanRaw.get(), - saveVarRaw.get())); - // ensure streams of gradients wait on the cudnn compute stream - relativeSync({gradIn, gradWt, gradBs}, cudnnStream); - } + CUDNN_CHECK_ERR( + cudnnBatchNormalizationBackward( + getCudnnHandle(), + mode, + one1, + zero0, + one1, + zero0, + iDesc.descriptor, + iRaw.get(), + iDesc.descriptor, + gradOpRaw.get(), + iDesc.descriptor, + gradInRaw.get(), + wDesc.descriptor, + wRaw.get(), + gradWtRaw.get(), + gradBsRaw.get(), + epsilon, + saveMeanRaw.get(), + saveVarRaw.get() + ) + ); + // ensure streams of gradients wait on the cudnn compute stream + relativeSync({gradIn, gradWt, gradBs}, cudnnStream); + } - return std::make_tuple(gradIn, gradWt, gradBs); + return std::make_tuple(gradIn, gradWt, gradBs); } } // namespace fl diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp index 84e6eb8..0102cf5 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp @@ -23,7 +23,7 @@ namespace fl { namespace { -std::unordered_map + std::unordered_map kKernelModesToCudnnMathType = { {fl::CudnnAutogradExtension::KernelMode::F32, CUDNN_DEFAULT_MATH}, 
{fl::CudnnAutogradExtension::KernelMode::F32_ALLOW_CONVERSION, @@ -31,211 +31,264 @@ std::unordered_map {fl::CudnnAutogradExtension::KernelMode::F16, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION}}; -const std::unordered_set kFwdPreferredAlgos = { - CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, - CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED}; + const std::unordered_set kFwdPreferredAlgos = { + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED}; -const std::unordered_set kBwdDataPreferredAlgos = + const std::unordered_set kBwdDataPreferredAlgos = {CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED}; -const std::unordered_set + const std::unordered_set kBwdFilterPreferredAlgos = { CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED}; -constexpr size_t kWorkspaceSizeLimitBytes = 512 * 1024 * 1024; // 512 MB -constexpr cudnnConvolutionFwdAlgo_t kFwdDefaultAlgo = - CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; -constexpr cudnnConvolutionBwdDataAlgo_t kBwdDataDefaultAlgo = - CUDNN_CONVOLUTION_BWD_DATA_ALGO_0; -constexpr cudnnConvolutionBwdFilterAlgo_t kBwdFilterDefaultAlgo = - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0; + constexpr size_t kWorkspaceSizeLimitBytes = 512 * 1024 * 1024; // 512 MB + constexpr cudnnConvolutionFwdAlgo_t kFwdDefaultAlgo = + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + constexpr cudnnConvolutionBwdDataAlgo_t kBwdDataDefaultAlgo = + CUDNN_CONVOLUTION_BWD_DATA_ALGO_0; + constexpr cudnnConvolutionBwdFilterAlgo_t kBwdFilterDefaultAlgo = + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0; // Get the algorithm which gives best performance. // Since cuDNN doesn't support memory limits, we manually choose an algorithm // which requires less than a specific workspace size. 
-template -T getBestAlgorithm( - const std::vector& algoPerfs, - const std::unordered_set& preferredAlgos, - const fl::dtype arithmeticPrecision) { - T reserved; - bool algoFound = false; - for (const auto& algoPerf : algoPerfs) { - if (algoPerf.status == CUDNN_STATUS_SUCCESS && - algoPerf.memory < kWorkspaceSizeLimitBytes) { - if (!(arithmeticPrecision == fl::dtype::f16) || - (preferredAlgos.find(algoPerf.algo) != preferredAlgos.end())) { - return algoPerf; - } else if (!algoFound) { - reserved = algoPerf; - algoFound = true; - } + template + T getBestAlgorithm( + const std::vector& algoPerfs, + const std::unordered_set& preferredAlgos, + const fl::dtype arithmeticPrecision + ) { + T reserved; + bool algoFound = false; + for(const auto& algoPerf : algoPerfs) { + if( + algoPerf.status == CUDNN_STATUS_SUCCESS + && algoPerf.memory < kWorkspaceSizeLimitBytes + ) { + if( + !(arithmeticPrecision == fl::dtype::f16) + || (preferredAlgos.find(algoPerf.algo) != preferredAlgos.end()) + ) { + return algoPerf; + } else if(!algoFound) { + reserved = algoPerf; + algoFound = true; + } + } + } + if(algoFound) { + return reserved; + } else { + throw std::runtime_error("Error while finding cuDNN Conv Algorithm."); + } } - } - if (algoFound) { - return reserved; - } else { - throw std::runtime_error("Error while finding cuDNN Conv Algorithm."); - } -} -cudnnConvolutionFwdAlgoPerf_t getFwdAlgo( - const cudnnTensorDescriptor_t& xDesc, - const cudnnFilterDescriptor_t& wDesc, - const cudnnConvolutionDescriptor_t& convDesc, - const cudnnTensorDescriptor_t& yDesc, - const fl::dtype arithmeticPrecision) { - int numFwdAlgoRequested, numFwdAlgoReturned; - - CUDNN_CHECK_ERR(cudnnGetConvolutionForwardAlgorithmMaxCount( - fl::getCudnnHandle(), &numFwdAlgoRequested)); - - std::vector fwdAlgoPerfs(numFwdAlgoRequested); - CUDNN_CHECK_ERR(cudnnGetConvolutionForwardAlgorithm_v7( - fl::getCudnnHandle(), - xDesc, - wDesc, - convDesc, - yDesc, - numFwdAlgoRequested, - &numFwdAlgoReturned, - 
fwdAlgoPerfs.data())); - - return getBestAlgorithm( - fwdAlgoPerfs, kFwdPreferredAlgos, arithmeticPrecision); -} + cudnnConvolutionFwdAlgoPerf_t getFwdAlgo( + const cudnnTensorDescriptor_t& xDesc, + const cudnnFilterDescriptor_t& wDesc, + const cudnnConvolutionDescriptor_t& convDesc, + const cudnnTensorDescriptor_t& yDesc, + const fl::dtype arithmeticPrecision + ) { + int numFwdAlgoRequested, numFwdAlgoReturned; + + CUDNN_CHECK_ERR( + cudnnGetConvolutionForwardAlgorithmMaxCount( + fl::getCudnnHandle(), + &numFwdAlgoRequested + ) + ); + + std::vector fwdAlgoPerfs(numFwdAlgoRequested); + CUDNN_CHECK_ERR( + cudnnGetConvolutionForwardAlgorithm_v7( + fl::getCudnnHandle(), + xDesc, + wDesc, + convDesc, + yDesc, + numFwdAlgoRequested, + &numFwdAlgoReturned, + fwdAlgoPerfs.data() + ) + ); + + return getBestAlgorithm( + fwdAlgoPerfs, + kFwdPreferredAlgos, + arithmeticPrecision + ); + } -cudnnConvolutionBwdDataAlgoPerf_t getBwdDataAlgo( - const cudnnTensorDescriptor_t& xDesc, - const cudnnFilterDescriptor_t& wDesc, - const cudnnConvolutionDescriptor_t& convDesc, - const cudnnTensorDescriptor_t& yDesc, - bool /* isStrided */, - const fl::dtype arithmeticPrecision) { - int numBwdDataAlgoRequested, numBwdDataAlgoReturned; - - CUDNN_CHECK_ERR(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( - fl::getCudnnHandle(), &numBwdDataAlgoRequested)); - - std::vector bwdDataAlgoPerfs( - numBwdDataAlgoRequested); - CUDNN_CHECK_ERR(cudnnGetConvolutionBackwardDataAlgorithm_v7( - fl::getCudnnHandle(), - wDesc, - yDesc, - convDesc, - xDesc, - numBwdDataAlgoRequested, - &numBwdDataAlgoReturned, - bwdDataAlgoPerfs.data())); - - auto bestAlgo = getBestAlgorithm( - bwdDataAlgoPerfs, kBwdDataPreferredAlgos, arithmeticPrecision); - - // We use a few hacks here to resolve some cuDNN bugs - // 1: blacklist BWD_DATA_ALGO_1 - // Seems to produce erroneous results on Tesla P100 GPUs. 
- // 2: blacklist FFT algorithms for strided dgrad - - // https://github.com/pytorch/pytorch/issues/16610 - bool isAlgoBlacklisted = false; + cudnnConvolutionBwdDataAlgoPerf_t getBwdDataAlgo( + const cudnnTensorDescriptor_t& xDesc, + const cudnnFilterDescriptor_t& wDesc, + const cudnnConvolutionDescriptor_t& convDesc, + const cudnnTensorDescriptor_t& yDesc, + bool /* isStrided */, + const fl::dtype arithmeticPrecision + ) { + int numBwdDataAlgoRequested, numBwdDataAlgoReturned; + + CUDNN_CHECK_ERR( + cudnnGetConvolutionBackwardDataAlgorithmMaxCount( + fl::getCudnnHandle(), + &numBwdDataAlgoRequested + ) + ); + + std::vector bwdDataAlgoPerfs( + numBwdDataAlgoRequested); + CUDNN_CHECK_ERR( + cudnnGetConvolutionBackwardDataAlgorithm_v7( + fl::getCudnnHandle(), + wDesc, + yDesc, + convDesc, + xDesc, + numBwdDataAlgoRequested, + &numBwdDataAlgoReturned, + bwdDataAlgoPerfs.data() + ) + ); + + auto bestAlgo = getBestAlgorithm( + bwdDataAlgoPerfs, + kBwdDataPreferredAlgos, + arithmeticPrecision + ); + + // We use a few hacks here to resolve some cuDNN bugs + // 1: blacklist BWD_DATA_ALGO_1 + // Seems to produce erroneous results on Tesla P100 GPUs. 
+ // 2: blacklist FFT algorithms for strided dgrad - + // https://github.com/pytorch/pytorch/issues/16610 + bool isAlgoBlacklisted = false; #ifndef FL_CUDNN_ALLOW_ALGO_1 - if (arithmeticPrecision != fl::dtype::f16 && - bestAlgo.algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_1) { - isAlgoBlacklisted = true; - } + if( + arithmeticPrecision != fl::dtype::f16 + && bestAlgo.algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 + ) { + isAlgoBlacklisted = true; + } #endif #if CUDNN_VERSION < 7500 - if (isStrided && - (bestAlgo.algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING || - bestAlgo.algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) { - isAlgoBlacklisted = true; - } + if( + isStrided + && (bestAlgo.algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING + || bestAlgo.algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT) + ) { + isAlgoBlacklisted = true; + } #endif - if (isAlgoBlacklisted) { - for (const auto& algoPerf : bwdDataAlgoPerfs) { - if (algoPerf.algo == kBwdDataDefaultAlgo) { - return algoPerf; - } + if(isAlgoBlacklisted) { + for(const auto& algoPerf : bwdDataAlgoPerfs) { + if(algoPerf.algo == kBwdDataDefaultAlgo) { + return algoPerf; + } + } + } + return bestAlgo; } - } - return bestAlgo; -} -cudnnConvolutionBwdFilterAlgoPerf_t getBwdFilterAlgo( - const cudnnTensorDescriptor_t& xDesc, - const cudnnFilterDescriptor_t& wDesc, - const cudnnConvolutionDescriptor_t& convDesc, - const cudnnTensorDescriptor_t& yDesc, - const fl::dtype arithmeticPrecision) { - int numBwdFilterAlgoRequested, numBwdFilterAlgoReturned; - - CUDNN_CHECK_ERR(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( - fl::getCudnnHandle(), &numBwdFilterAlgoRequested)); - - std::vector bwdFilterAlgoPerfs( - numBwdFilterAlgoRequested); - CUDNN_CHECK_ERR(cudnnGetConvolutionBackwardFilterAlgorithm_v7( - fl::getCudnnHandle(), - xDesc, - yDesc, - convDesc, - wDesc, - numBwdFilterAlgoRequested, - &numBwdFilterAlgoReturned, - bwdFilterAlgoPerfs.data())); - auto bestAlgo = getBestAlgorithm( - bwdFilterAlgoPerfs, 
kBwdFilterPreferredAlgos, arithmeticPrecision); - - // We use a few hacks here to resolve some cuDNN bugs - // 1: blacklist BWD_FILTER_ALGO_1 - // We do the blacklist here just to be safe as we did in BWD_DATA_ALGO_1 - bool isAlgoBlacklisted = false; + cudnnConvolutionBwdFilterAlgoPerf_t getBwdFilterAlgo( + const cudnnTensorDescriptor_t& xDesc, + const cudnnFilterDescriptor_t& wDesc, + const cudnnConvolutionDescriptor_t& convDesc, + const cudnnTensorDescriptor_t& yDesc, + const fl::dtype arithmeticPrecision + ) { + int numBwdFilterAlgoRequested, numBwdFilterAlgoReturned; + + CUDNN_CHECK_ERR( + cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( + fl::getCudnnHandle(), + &numBwdFilterAlgoRequested + ) + ); + + std::vector bwdFilterAlgoPerfs( + numBwdFilterAlgoRequested); + CUDNN_CHECK_ERR( + cudnnGetConvolutionBackwardFilterAlgorithm_v7( + fl::getCudnnHandle(), + xDesc, + yDesc, + convDesc, + wDesc, + numBwdFilterAlgoRequested, + &numBwdFilterAlgoReturned, + bwdFilterAlgoPerfs.data() + ) + ); + auto bestAlgo = getBestAlgorithm( + bwdFilterAlgoPerfs, + kBwdFilterPreferredAlgos, + arithmeticPrecision + ); + + // We use a few hacks here to resolve some cuDNN bugs + // 1: blacklist BWD_FILTER_ALGO_1 + // We do the blacklist here just to be safe as we did in BWD_DATA_ALGO_1 + bool isAlgoBlacklisted = false; #ifndef FL_CUDNN_ALLOW_ALGO_1 - if (arithmeticPrecision != fl::dtype::f16 && - bestAlgo.algo == CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1) { - isAlgoBlacklisted = true; - } + if( + arithmeticPrecision != fl::dtype::f16 + && bestAlgo.algo == CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 + ) { + isAlgoBlacklisted = true; + } #endif - if (isAlgoBlacklisted) { - for (const auto& algoPerf : bwdFilterAlgoPerfs) { - if (algoPerf.algo == kBwdFilterDefaultAlgo) { - return algoPerf; - } + if(isAlgoBlacklisted) { + for(const auto& algoPerf : bwdFilterAlgoPerfs) { + if(algoPerf.algo == kBwdFilterDefaultAlgo) { + return algoPerf; + } + } + } + return bestAlgo; } - } - return bestAlgo; -} /** * 
Sets the cudnnMathType according to a `KernelMode` value. * * @param[in] cDesc a reference to a `ConvDescriptor` for which the math type - will + will * be set. * @param[in] kernelOptions a pointer to the DynamicBenchmarkOptions for the - possible kernel modes. + possible kernel modes. */ -void setCudnnConvMathType( - ConvDescriptor& cDesc, - const std::shared_ptr< - fl::DynamicBenchmarkOptions>& - kernelOptions) { - CUDNN_CHECK_ERR(cudnnSetConvolutionMathType( - cDesc.descriptor, - kKernelModesToCudnnMathType.at(kernelOptions->currentOption()))); -} + void setCudnnConvMathType( + ConvDescriptor& cDesc, + const std::shared_ptr< + fl::DynamicBenchmarkOptions>& + kernelOptions + ) { + CUDNN_CHECK_ERR( + cudnnSetConvolutionMathType( + cDesc.descriptor, + kKernelModesToCudnnMathType.at(kernelOptions->currentOption()) + ) + ); + } -void setDefaultMathType(ConvDescriptor& cDesc, const Tensor& input) { - if (input.type() == fl::dtype::f16) { - CUDNN_CHECK_ERR(cudnnSetConvolutionMathType( - cDesc.descriptor, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION)); - } else { - CUDNN_CHECK_ERR( - cudnnSetConvolutionMathType(cDesc.descriptor, CUDNN_DEFAULT_MATH)); - } -} + void setDefaultMathType(ConvDescriptor& cDesc, const Tensor& input) { + if(input.type() == fl::dtype::f16) { + CUDNN_CHECK_ERR( + cudnnSetConvolutionMathType( + cDesc.descriptor, + CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION + ) + ); + } else { + CUDNN_CHECK_ERR( + cudnnSetConvolutionMathType(cDesc.descriptor, CUDNN_DEFAULT_MATH) + ); + } + } } // namespace @@ -250,110 +303,130 @@ Tensor CudnnAutogradExtension::conv2d( const int dx, const int dy, const int groups, - std::shared_ptr) { - if (input.ndim() != 4) { - throw std::invalid_argument( - "conv2d: expects input tensor to be 4 dimensions: " - "in WHCN ordering. 
Given tensor has " + - std::to_string(input.ndim()) + " dimensions."); - } - - auto hasBias = bias.elements() > 0; - - auto inDesc = TensorDescriptor(input); - auto wtDesc = FilterDescriptor(weights); - auto convDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups); - if (input.type() == fl::dtype::f16) { - CUDNN_CHECK_ERR(cudnnSetConvolutionMathType( - convDesc.descriptor, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION)); - } else { + std::shared_ptr +) { + if(input.ndim() != 4) { + throw std::invalid_argument( + "conv2d: expects input tensor to be 4 dimensions: " + "in WHCN ordering. Given tensor has " + + std::to_string(input.ndim()) + " dimensions." + ); + } + + auto hasBias = bias.elements() > 0; + + auto inDesc = TensorDescriptor(input); + auto wtDesc = FilterDescriptor(weights); + auto convDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups); + if(input.type() == fl::dtype::f16) { + CUDNN_CHECK_ERR( + cudnnSetConvolutionMathType( + convDesc.descriptor, + CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION + ) + ); + } else { + CUDNN_CHECK_ERR( + cudnnSetConvolutionMathType(convDesc.descriptor, CUDNN_DEFAULT_MATH) + ); + } + + std::array odims; CUDNN_CHECK_ERR( - cudnnSetConvolutionMathType(convDesc.descriptor, CUDNN_DEFAULT_MATH)); - } - - std::array odims; - CUDNN_CHECK_ERR(cudnnGetConvolutionNdForwardOutputDim( - convDesc.descriptor, - inDesc.descriptor, - wtDesc.descriptor, - 4, - odims.data())); - auto output = Tensor({odims[3], odims[2], odims[1], odims[0]}, input.type()); - auto outDesc = TensorDescriptor(output); - - auto handle = getCudnnHandle(); - const auto& cudnnStream = getCudnnStream(); - - auto fwdAlgoBestPerf = getFwdAlgo( - inDesc.descriptor, - wtDesc.descriptor, - convDesc.descriptor, - outDesc.descriptor, - input.type()); - - Tensor wspace; - - try { - wspace = - Tensor({static_cast(fwdAlgoBestPerf.memory)}, fl::dtype::b8); - } catch (const std::exception&) { - fwdAlgoBestPerf.algo = kFwdDefaultAlgo; - 
CUDNN_CHECK_ERR(cudnnGetConvolutionForwardWorkspaceSize( - handle, + cudnnGetConvolutionNdForwardOutputDim( + convDesc.descriptor, + inDesc.descriptor, + wtDesc.descriptor, + 4, + odims.data() + ) + ); + auto output = Tensor({odims[3], odims[2], odims[1], odims[0]}, input.type()); + auto outDesc = TensorDescriptor(output); + + auto handle = getCudnnHandle(); + const auto& cudnnStream = getCudnnStream(); + + auto fwdAlgoBestPerf = getFwdAlgo( inDesc.descriptor, wtDesc.descriptor, convDesc.descriptor, outDesc.descriptor, - fwdAlgoBestPerf.algo, - &fwdAlgoBestPerf.memory)); - wspace = - Tensor({static_cast(fwdAlgoBestPerf.memory)}, fl::dtype::b8); - } - { - DevicePtr inPtr(input); - DevicePtr wtPtr(weights); - DevicePtr outPtr(output); - DevicePtr wspacePtr(wspace); - // ensure cudnn compute stream waits on streams of input/output tensors - relativeSync(cudnnStream, {input, weights, wspace, output}); + input.type() + ); - auto scalarsType = - input.type() == fl::dtype::f16 ? fl::dtype::f32 : input.type(); - const void* one = kOne(scalarsType); - const void* zero = kZero(scalarsType); - CUDNN_CHECK_ERR(cudnnConvolutionForward( - handle, - one, - inDesc.descriptor, - inPtr.get(), - wtDesc.descriptor, - wtPtr.get(), - convDesc.descriptor, - fwdAlgoBestPerf.algo, - wspacePtr.get(), - fwdAlgoBestPerf.memory, - zero, - outDesc.descriptor, - outPtr.get())); - - if (hasBias) { - auto bsDesc = TensorDescriptor(bias); - DevicePtr bsPtr(bias); - // ensure cudnn compute stream waits on stream of bias tensor - relativeSync(cudnnStream, {bias}); - CUDNN_CHECK_ERR(cudnnAddTensor( - handle, - one, - bsDesc.descriptor, - bsPtr.get(), - one, - outDesc.descriptor, - outPtr.get())); + Tensor wspace; + + try { + wspace = + Tensor({static_cast(fwdAlgoBestPerf.memory)}, fl::dtype::b8); + } catch(const std::exception&) { + fwdAlgoBestPerf.algo = kFwdDefaultAlgo; + CUDNN_CHECK_ERR( + cudnnGetConvolutionForwardWorkspaceSize( + handle, + inDesc.descriptor, + wtDesc.descriptor, + 
convDesc.descriptor, + outDesc.descriptor, + fwdAlgoBestPerf.algo, + &fwdAlgoBestPerf.memory + ) + ); + wspace = + Tensor({static_cast(fwdAlgoBestPerf.memory)}, fl::dtype::b8); } - // ensure output stream waits on cudnn compute stream - relativeSync({output}, cudnnStream); - } - return output; + { + DevicePtr inPtr(input); + DevicePtr wtPtr(weights); + DevicePtr outPtr(output); + DevicePtr wspacePtr(wspace); + // ensure cudnn compute stream waits on streams of input/output tensors + relativeSync(cudnnStream, {input, weights, wspace, output}); + + auto scalarsType = + input.type() == fl::dtype::f16 ? fl::dtype::f32 : input.type(); + const void* one = kOne(scalarsType); + const void* zero = kZero(scalarsType); + CUDNN_CHECK_ERR( + cudnnConvolutionForward( + handle, + one, + inDesc.descriptor, + inPtr.get(), + wtDesc.descriptor, + wtPtr.get(), + convDesc.descriptor, + fwdAlgoBestPerf.algo, + wspacePtr.get(), + fwdAlgoBestPerf.memory, + zero, + outDesc.descriptor, + outPtr.get() + ) + ); + + if(hasBias) { + auto bsDesc = TensorDescriptor(bias); + DevicePtr bsPtr(bias); + // ensure cudnn compute stream waits on stream of bias tensor + relativeSync(cudnnStream, {bias}); + CUDNN_CHECK_ERR( + cudnnAddTensor( + handle, + one, + bsDesc.descriptor, + bsPtr.get(), + one, + outDesc.descriptor, + outPtr.get() + ) + ); + } + // ensure output stream waits on cudnn compute stream + relativeSync({output}, cudnnStream); + } + return output; } Tensor CudnnAutogradExtension::conv2dBackwardData( @@ -368,176 +441,211 @@ Tensor CudnnAutogradExtension::conv2dBackwardData( const int dy, const int groups, std::shared_ptr dataGradBenchmark, - std::shared_ptr) { - auto hndl = getCudnnHandle(); - const auto& cudnnStream = getCudnnStream(); - - auto scalarsType = - input.type() == fl::dtype::f16 ? fl::dtype::f32 : input.type(); - const void* oneg = kOne(scalarsType); - const void* zerog = kZero(scalarsType); - - // Create default descriptors assuming no casts. 
If dynamic - // benchmarking suggests input or weight casting should occur, these - // descriptors may not be used/new ones with the correct types will be - // used instead. - auto iDesc = TensorDescriptor(input); - auto wDesc = FilterDescriptor(weight); - auto cDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups); - auto oDesc = TensorDescriptor(gradOutput); - - setDefaultMathType(cDesc, input); - - // Gradients with respect to the input - auto convolutionBackwardData = - [&hndl, &cudnnStream, &dataGradBenchmark, oneg, zerog, dx, dy]( - const Tensor& inTensor, - const Tensor& wtTensor, - const Tensor& gradOutputTensor, - TensorDescriptor& iDesc, - FilterDescriptor& wDesc, - ConvDescriptor& cDesc, - TensorDescriptor& oDesc) -> Tensor { - if (dataGradBenchmark && DynamicBenchmark::getBenchmarkMode()) { - setCudnnConvMathType( - cDesc, - dataGradBenchmark->getOptions>()); - } + std::shared_ptr +) { + auto hndl = getCudnnHandle(); + const auto& cudnnStream = getCudnnStream(); - DevicePtr wPtr(wtTensor); - // ensure cudnn compute stream waits on stream of weight tensor - relativeSync(cudnnStream, {wtTensor}); - bool isStrided = (dx * dy) > 1; - auto bwdDataAlgoBestPerf = getBwdDataAlgo( - iDesc.descriptor, - wDesc.descriptor, - cDesc.descriptor, - oDesc.descriptor, - isStrided, - inTensor.type()); - - Tensor ws; - try { - ws = Tensor( - {static_cast(bwdDataAlgoBestPerf.memory)}, fl::dtype::b8); - } catch (const std::exception&) { - bwdDataAlgoBestPerf.algo = kBwdDataDefaultAlgo; - CUDNN_CHECK_ERR(cudnnGetConvolutionBackwardDataWorkspaceSize( - hndl, - wDesc.descriptor, - oDesc.descriptor, - cDesc.descriptor, - iDesc.descriptor, - bwdDataAlgoBestPerf.algo, - &bwdDataAlgoBestPerf.memory)); - ws = Tensor( - {static_cast(bwdDataAlgoBestPerf.memory)}, fl::dtype::b8); - } - - auto gradInput = Tensor(inTensor.shape(), inTensor.type()); - { - DevicePtr gradInputPtr(gradInput); - DevicePtr gradResultPtr(gradOutputTensor); - DevicePtr wsPtr(ws); - // ensure cudnn 
compute stream waits on streams of input/output tensors - relativeSync(cudnnStream, {gradOutputTensor, ws, gradInput}); - CUDNN_CHECK_ERR(cudnnConvolutionBackwardData( - hndl, - oneg, - wDesc.descriptor, - wPtr.get(), - oDesc.descriptor, - gradResultPtr.get(), - cDesc.descriptor, - bwdDataAlgoBestPerf.algo, - wsPtr.get(), - bwdDataAlgoBestPerf.memory, - zerog, - iDesc.descriptor, - gradInputPtr.get())); - } - // ensure stream of gradient waits on the cudnn compute stream - relativeSync({gradInput}, cudnnStream); - return gradInput; - }; - - Tensor dataGradOut; - - if (dataGradBenchmark && DynamicBenchmark::getBenchmarkMode()) { - KernelMode dataBwdOption = - dataGradBenchmark->getOptions>() + auto scalarsType = + input.type() == fl::dtype::f16 ? fl::dtype::f32 : input.type(); + const void* oneg = kOne(scalarsType); + const void* zerog = kZero(scalarsType); + + // Create default descriptors assuming no casts. If dynamic + // benchmarking suggests input or weight casting should occur, these + // descriptors may not be used/new ones with the correct types will be + // used instead. 
+ auto iDesc = TensorDescriptor(input); + auto wDesc = FilterDescriptor(weight); + auto cDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups); + auto oDesc = TensorDescriptor(gradOutput); + + setDefaultMathType(cDesc, input); + + // Gradients with respect to the input + auto convolutionBackwardData = + [&hndl, &cudnnStream, &dataGradBenchmark, oneg, zerog, dx, dy]( + const Tensor& inTensor, + const Tensor& wtTensor, + const Tensor& gradOutputTensor, + TensorDescriptor& iDesc, + FilterDescriptor& wDesc, + ConvDescriptor& cDesc, + TensorDescriptor& oDesc) -> Tensor { + if(dataGradBenchmark && DynamicBenchmark::getBenchmarkMode()) { + setCudnnConvMathType( + cDesc, + dataGradBenchmark->getOptions>() + ); + } + + DevicePtr wPtr(wtTensor); + // ensure cudnn compute stream waits on stream of weight tensor + relativeSync(cudnnStream, {wtTensor}); + bool isStrided = (dx * dy) > 1; + auto bwdDataAlgoBestPerf = getBwdDataAlgo( + iDesc.descriptor, + wDesc.descriptor, + cDesc.descriptor, + oDesc.descriptor, + isStrided, + inTensor.type() + ); + + Tensor ws; + try { + ws = Tensor( + {static_cast(bwdDataAlgoBestPerf.memory)}, + fl::dtype::b8 + ); + } catch(const std::exception&) { + bwdDataAlgoBestPerf.algo = kBwdDataDefaultAlgo; + CUDNN_CHECK_ERR( + cudnnGetConvolutionBackwardDataWorkspaceSize( + hndl, + wDesc.descriptor, + oDesc.descriptor, + cDesc.descriptor, + iDesc.descriptor, + bwdDataAlgoBestPerf.algo, + &bwdDataAlgoBestPerf.memory + ) + ); + ws = Tensor( + {static_cast(bwdDataAlgoBestPerf.memory)}, + fl::dtype::b8 + ); + } + + auto gradInput = Tensor(inTensor.shape(), inTensor.type()); + { + DevicePtr gradInputPtr(gradInput); + DevicePtr gradResultPtr(gradOutputTensor); + DevicePtr wsPtr(ws); + // ensure cudnn compute stream waits on streams of input/output tensors + relativeSync(cudnnStream, {gradOutputTensor, ws, gradInput}); + CUDNN_CHECK_ERR( + cudnnConvolutionBackwardData( + hndl, + oneg, + wDesc.descriptor, + wPtr.get(), + oDesc.descriptor, + 
gradResultPtr.get(), + cDesc.descriptor, + bwdDataAlgoBestPerf.algo, + wsPtr.get(), + bwdDataAlgoBestPerf.memory, + zerog, + iDesc.descriptor, + gradInputPtr.get() + ) + ); + } + // ensure stream of gradient waits on the cudnn compute stream + relativeSync({gradInput}, cudnnStream); + return gradInput; + }; + + Tensor dataGradOut; + + if(dataGradBenchmark && DynamicBenchmark::getBenchmarkMode()) { + KernelMode dataBwdOption = + dataGradBenchmark->getOptions>() ->currentOption(); - if (input.type() == fl::dtype::f16 && - dataBwdOption == CudnnAutogradExtension::KernelMode::F32 && - dataBwdOption == - CudnnAutogradExtension::KernelMode::F32_ALLOW_CONVERSION) { - // The input type of fp16, but the result of the dynamic benchmark - // is that using fp32 kernels is faster for computing bwd with fp16 - // kernels, including the cast - Tensor inTensorF32; - Tensor wtTensorF32; - Tensor gradOutputTensorF32; - dataGradBenchmark->audit( - [&input, - &inTensorF32, - &weight, - &wtTensorF32, - &gradOutput, - &gradOutputTensorF32]() { - inTensorF32 = input.astype(fl::dtype::f32); - wtTensorF32 = weight.astype(fl::dtype::f32); - gradOutputTensorF32 = gradOutput.astype(fl::dtype::f32); - }, - /* incrementCount = */ false); - - auto iDescF32 = TensorDescriptor(inTensorF32); - auto wDescF32 = FilterDescriptor(wtTensorF32); - auto cDescF32 = - ConvDescriptor(fl::dtype::f32, px, py, sx, sy, dx, dy, groups); - auto oDescF32 = TensorDescriptor(gradOutputTensorF32); - // core bwd data computation - dataGradBenchmark->audit([&dataGradOut, - &convolutionBackwardData, - &inTensorF32, - &wtTensorF32, - &gradOutputTensorF32, - &iDescF32, - &wDescF32, - &cDescF32, - &oDescF32]() { - dataGradOut = convolutionBackwardData( - inTensorF32, - wtTensorF32, - gradOutputTensorF32, - iDescF32, - wDescF32, - cDescF32, - oDescF32); - }); + if( + input.type() == fl::dtype::f16 + && dataBwdOption == CudnnAutogradExtension::KernelMode::F32 + && dataBwdOption + == 
CudnnAutogradExtension::KernelMode::F32_ALLOW_CONVERSION + ) { + // The input type of fp16, but the result of the dynamic benchmark + // is that using fp32 kernels is faster for computing bwd with fp16 + // kernels, including the cast + Tensor inTensorF32; + Tensor wtTensorF32; + Tensor gradOutputTensorF32; + dataGradBenchmark->audit( + [&input, + &inTensorF32, + &weight, + &wtTensorF32, + &gradOutput, + &gradOutputTensorF32]() { + inTensorF32 = input.astype(fl::dtype::f32); + wtTensorF32 = weight.astype(fl::dtype::f32); + gradOutputTensorF32 = gradOutput.astype(fl::dtype::f32); + }, + /* incrementCount = */ false + ); + + auto iDescF32 = TensorDescriptor(inTensorF32); + auto wDescF32 = FilterDescriptor(wtTensorF32); + auto cDescF32 = + ConvDescriptor(fl::dtype::f32, px, py, sx, sy, dx, dy, groups); + auto oDescF32 = TensorDescriptor(gradOutputTensorF32); + // core bwd data computation + dataGradBenchmark->audit( + [&dataGradOut, + &convolutionBackwardData, + &inTensorF32, + &wtTensorF32, + &gradOutputTensorF32, + &iDescF32, + &wDescF32, + &cDescF32, + &oDescF32]() { + dataGradOut = convolutionBackwardData( + inTensorF32, + wtTensorF32, + gradOutputTensorF32, + iDescF32, + wDescF32, + cDescF32, + oDescF32 + ); + } + ); + + } else { + dataGradBenchmark->audit( + [&dataGradOut, + &convolutionBackwardData, + &input, + &weight, + &gradOutput, + &iDesc, + &wDesc, + &cDesc, + &oDesc]() { + dataGradOut = convolutionBackwardData( + input, + weight, + gradOutput, + iDesc, + wDesc, + cDesc, + oDesc + ); + } + ); + } } else { - dataGradBenchmark->audit([&dataGradOut, - &convolutionBackwardData, - &input, - &weight, - &gradOutput, - &iDesc, - &wDesc, - &cDesc, - &oDesc]() { + // No benchmarking - proceed normally dataGradOut = convolutionBackwardData( - input, weight, gradOutput, iDesc, wDesc, cDesc, oDesc); - }); + input, + weight, + gradOutput, + iDesc, + wDesc, + cDesc, + oDesc + ); } - } else { - // No benchmarking - proceed normally - dataGradOut = 
convolutionBackwardData( - input, weight, gradOutput, iDesc, wDesc, cDesc, oDesc); - } - - return dataGradOut; + return dataGradOut; } std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( @@ -554,252 +662,295 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( const int groups, std::shared_ptr filterGradBenchmark, std::shared_ptr biasGradBenchmark, - std::shared_ptr) { - auto hndl = getCudnnHandle(); - const auto& cudnnStream = getCudnnStream(); - - auto scalarsType = - input.type() == fl::dtype::f16 ? fl::dtype::f32 : input.type(); - const void* oneg = kOne(scalarsType); - const void* zerog = kZero(scalarsType); - - // Create default descriptors assuming no casts. If dynamic - // benchmarking suggests input or weight casting should occur, these - // descriptors may not be used/new ones with the correct types will be - // used instead. - auto iDesc = TensorDescriptor(input); - auto wDesc = FilterDescriptor(weight); - auto cDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups); - auto oDesc = TensorDescriptor(gradOutput); - - setDefaultMathType(cDesc, input); - - // Gradients with respect to the filter - auto convolutionBackwardFilter = - [&hndl, &cudnnStream, &filterGradBenchmark, oneg, zerog]( - const Tensor& inTensor, - const Tensor& wtTensor, - const Tensor& gradOutputTensor, - TensorDescriptor& iDesc, - FilterDescriptor& wDesc, - ConvDescriptor& cDesc, - TensorDescriptor& oDesc) -> Tensor { - if (filterGradBenchmark && DynamicBenchmark::getBenchmarkMode()) { - setCudnnConvMathType( - cDesc, - filterGradBenchmark - ->getOptions>()); - } - - DevicePtr iPtr(inTensor); - // ensure cudnn compute stream waits on stream of input tensor - relativeSync(cudnnStream, {inTensor}); - auto bwdFilterAlgoBestPerf = getBwdFilterAlgo( - iDesc.descriptor, - wDesc.descriptor, - cDesc.descriptor, - oDesc.descriptor, - inTensor.type()); - - Tensor ws; - try { - ws = Tensor( - {static_cast(bwdFilterAlgoBestPerf.memory)}, - fl::dtype::b8); - } catch 
(const std::exception&) { - bwdFilterAlgoBestPerf.algo = kBwdFilterDefaultAlgo; - CUDNN_CHECK_ERR(cudnnGetConvolutionBackwardFilterWorkspaceSize( - hndl, - iDesc.descriptor, - oDesc.descriptor, - cDesc.descriptor, - wDesc.descriptor, - bwdFilterAlgoBestPerf.algo, - &bwdFilterAlgoBestPerf.memory)); - ws = Tensor( - {static_cast(bwdFilterAlgoBestPerf.memory)}, - fl::dtype::b8); - } + std::shared_ptr +) { + auto hndl = getCudnnHandle(); + const auto& cudnnStream = getCudnnStream(); - auto gradWeight = Tensor(wtTensor.shape(), wtTensor.type()); - { - DevicePtr gradWeightPtr(gradWeight); - DevicePtr gradResultPtr(gradOutputTensor); - DevicePtr wsPtr(ws); - // ensure cudnn compute stream waits on streams of input/output tensors - relativeSync(cudnnStream, {gradOutputTensor, ws, gradWeight}); - CUDNN_CHECK_ERR(cudnnConvolutionBackwardFilter( - hndl, - oneg, - iDesc.descriptor, - iPtr.get(), - oDesc.descriptor, - gradResultPtr.get(), - cDesc.descriptor, - bwdFilterAlgoBestPerf.algo, - wsPtr.get(), - bwdFilterAlgoBestPerf.memory, - zerog, - wDesc.descriptor, - gradWeightPtr.get())); - } - // ensure gradient tensor stream waits on cudnn compute stream - relativeSync({gradWeight}, cudnnStream); - return gradWeight; - }; - - Tensor filterGradOut; - - if (filterGradBenchmark && DynamicBenchmark::getBenchmarkMode()) { - KernelMode dataBwdOption = - filterGradBenchmark->getOptions>() + auto scalarsType = + input.type() == fl::dtype::f16 ? fl::dtype::f32 : input.type(); + const void* oneg = kOne(scalarsType); + const void* zerog = kZero(scalarsType); + + // Create default descriptors assuming no casts. If dynamic + // benchmarking suggests input or weight casting should occur, these + // descriptors may not be used/new ones with the correct types will be + // used instead. 
+ auto iDesc = TensorDescriptor(input); + auto wDesc = FilterDescriptor(weight); + auto cDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups); + auto oDesc = TensorDescriptor(gradOutput); + + setDefaultMathType(cDesc, input); + + // Gradients with respect to the filter + auto convolutionBackwardFilter = + [&hndl, &cudnnStream, &filterGradBenchmark, oneg, zerog]( + const Tensor& inTensor, + const Tensor& wtTensor, + const Tensor& gradOutputTensor, + TensorDescriptor& iDesc, + FilterDescriptor& wDesc, + ConvDescriptor& cDesc, + TensorDescriptor& oDesc) -> Tensor { + if(filterGradBenchmark && DynamicBenchmark::getBenchmarkMode()) { + setCudnnConvMathType( + cDesc, + filterGradBenchmark + ->getOptions>() + ); + } + + DevicePtr iPtr(inTensor); + // ensure cudnn compute stream waits on stream of input tensor + relativeSync(cudnnStream, {inTensor}); + auto bwdFilterAlgoBestPerf = getBwdFilterAlgo( + iDesc.descriptor, + wDesc.descriptor, + cDesc.descriptor, + oDesc.descriptor, + inTensor.type() + ); + + Tensor ws; + try { + ws = Tensor( + {static_cast(bwdFilterAlgoBestPerf.memory)}, + fl::dtype::b8 + ); + } catch(const std::exception&) { + bwdFilterAlgoBestPerf.algo = kBwdFilterDefaultAlgo; + CUDNN_CHECK_ERR( + cudnnGetConvolutionBackwardFilterWorkspaceSize( + hndl, + iDesc.descriptor, + oDesc.descriptor, + cDesc.descriptor, + wDesc.descriptor, + bwdFilterAlgoBestPerf.algo, + &bwdFilterAlgoBestPerf.memory + ) + ); + ws = Tensor( + {static_cast(bwdFilterAlgoBestPerf.memory)}, + fl::dtype::b8 + ); + } + + auto gradWeight = Tensor(wtTensor.shape(), wtTensor.type()); + { + DevicePtr gradWeightPtr(gradWeight); + DevicePtr gradResultPtr(gradOutputTensor); + DevicePtr wsPtr(ws); + // ensure cudnn compute stream waits on streams of input/output tensors + relativeSync(cudnnStream, {gradOutputTensor, ws, gradWeight}); + CUDNN_CHECK_ERR( + cudnnConvolutionBackwardFilter( + hndl, + oneg, + iDesc.descriptor, + iPtr.get(), + oDesc.descriptor, + gradResultPtr.get(), + 
cDesc.descriptor, + bwdFilterAlgoBestPerf.algo, + wsPtr.get(), + bwdFilterAlgoBestPerf.memory, + zerog, + wDesc.descriptor, + gradWeightPtr.get() + ) + ); + } + // ensure gradient tensor stream waits on cudnn compute stream + relativeSync({gradWeight}, cudnnStream); + return gradWeight; + }; + + Tensor filterGradOut; + + if(filterGradBenchmark && DynamicBenchmark::getBenchmarkMode()) { + KernelMode dataBwdOption = + filterGradBenchmark->getOptions>() ->currentOption(); - if (input.type() == fl::dtype::f16 && - dataBwdOption == CudnnAutogradExtension::KernelMode::F32 && - dataBwdOption == - CudnnAutogradExtension::KernelMode::F32_ALLOW_CONVERSION) { - // The input type of fp16, but the result of the dynamic benchmark is - // that using fp32 kernels is faster for computing bwd with fp16 - // kernels, including the cast - Tensor inTensorF32; - Tensor wtTensorF32; - Tensor gradOutputTensorF32; - filterGradBenchmark->audit( - [&input, - &inTensorF32, - &weight, - &wtTensorF32, - &gradOutput, - &gradOutputTensorF32]() { - inTensorF32 = input.astype(fl::dtype::f32); - wtTensorF32 = weight.astype(fl::dtype::f32); - gradOutputTensorF32 = gradOutput.astype(fl::dtype::f32); - }, - /* incrementCount = */ false); - - auto iDescF32 = TensorDescriptor(inTensorF32); - auto wDescF32 = FilterDescriptor(wtTensorF32); - auto cDescF32 = - ConvDescriptor(fl::dtype::f32, px, py, sx, sy, dx, dy, groups); - auto oDescF32 = TensorDescriptor(gradOutputTensorF32); - // core bwd data computation - filterGradBenchmark->audit([&filterGradOut, - &convolutionBackwardFilter, - &inTensorF32, - &wtTensorF32, - &gradOutputTensorF32, - &iDescF32, - &wDescF32, - &cDescF32, - &oDescF32]() { - filterGradOut = convolutionBackwardFilter( - inTensorF32, - wtTensorF32, - gradOutputTensorF32, - iDescF32, - wDescF32, - cDescF32, - oDescF32); - }); + if( + input.type() == fl::dtype::f16 + && dataBwdOption == CudnnAutogradExtension::KernelMode::F32 + && dataBwdOption + == 
CudnnAutogradExtension::KernelMode::F32_ALLOW_CONVERSION + ) { + // The input type of fp16, but the result of the dynamic benchmark is + // that using fp32 kernels is faster for computing bwd with fp16 + // kernels, including the cast + Tensor inTensorF32; + Tensor wtTensorF32; + Tensor gradOutputTensorF32; + filterGradBenchmark->audit( + [&input, + &inTensorF32, + &weight, + &wtTensorF32, + &gradOutput, + &gradOutputTensorF32]() { + inTensorF32 = input.astype(fl::dtype::f32); + wtTensorF32 = weight.astype(fl::dtype::f32); + gradOutputTensorF32 = gradOutput.astype(fl::dtype::f32); + }, + /* incrementCount = */ false + ); + + auto iDescF32 = TensorDescriptor(inTensorF32); + auto wDescF32 = FilterDescriptor(wtTensorF32); + auto cDescF32 = + ConvDescriptor(fl::dtype::f32, px, py, sx, sy, dx, dy, groups); + auto oDescF32 = TensorDescriptor(gradOutputTensorF32); + // core bwd data computation + filterGradBenchmark->audit( + [&filterGradOut, + &convolutionBackwardFilter, + &inTensorF32, + &wtTensorF32, + &gradOutputTensorF32, + &iDescF32, + &wDescF32, + &cDescF32, + &oDescF32]() { + filterGradOut = convolutionBackwardFilter( + inTensorF32, + wtTensorF32, + gradOutputTensorF32, + iDescF32, + wDescF32, + cDescF32, + oDescF32 + ); + } + ); + + } else { + filterGradBenchmark->audit( + [&filterGradOut, + &convolutionBackwardFilter, + &input, + &weight, + &gradOutput, + &iDesc, + &wDesc, + &cDesc, + &oDesc]() { + filterGradOut = convolutionBackwardFilter( + input, + weight, + gradOutput, + iDesc, + wDesc, + cDesc, + oDesc + ); + } + ); + } } else { - filterGradBenchmark->audit([&filterGradOut, - &convolutionBackwardFilter, - &input, - &weight, - &gradOutput, - &iDesc, - &wDesc, - &cDesc, - &oDesc]() { filterGradOut = convolutionBackwardFilter( - input, weight, gradOutput, iDesc, wDesc, cDesc, oDesc); - }); + input, + weight, + gradOutput, + iDesc, + wDesc, + cDesc, + oDesc + ); } - } else { - filterGradOut = convolutionBackwardFilter( - input, weight, gradOutput, iDesc, wDesc, 
cDesc, oDesc); - } - - auto convolutionBackwardBias = [&hndl, &cudnnStream, oneg, zerog]( - const Tensor& bsTensor, - const Tensor& gradOutput, - const TensorDescriptor& oDesc) -> Tensor { - auto gradBias = Tensor(bsTensor.shape(), bsTensor.type()); - { - DevicePtr gradBiasPtr(gradBias); - DevicePtr gradResultPtr(gradOutput); - // ensure cudnn compute stream waits on gradient tensor streams - relativeSync(cudnnStream, {gradOutput, gradBias}); - auto bDesc = TensorDescriptor(bsTensor); - CUDNN_CHECK_ERR(cudnnConvolutionBackwardBias( - hndl, - oneg, - oDesc.descriptor, - gradResultPtr.get(), - zerog, - bDesc.descriptor, - gradBiasPtr.get())); - } - // ensure gradient bias tensor stream waits on cudnn compute stream - relativeSync({gradBias}, cudnnStream); - return gradBias; - }; - - Tensor biasGradOut; - - if (!bias.isEmpty()) { - if (biasGradBenchmark && DynamicBenchmark::getBenchmarkMode()) { - KernelMode biasBwdOption = - biasGradBenchmark->getOptions>() - ->currentOption(); - - if (bias.type() == fl::dtype::f16 && - biasBwdOption == CudnnAutogradExtension::KernelMode::F32 && - biasBwdOption == - CudnnAutogradExtension::KernelMode::F32_ALLOW_CONVERSION) { - // The input type of fp16, but the result of the dynamic benchmark is - // that using fp32 kernels is faster for computing bwd with fp16 - // kernels, including the cast - Tensor biasF32; - Tensor gradOutputF32; - // Time cast bias and grad output if benchmarking - biasGradBenchmark->audit( - [&bias, &gradOutput, &biasF32, &gradOutputF32]() { - biasF32 = bias.astype(fl::dtype::f32); - gradOutputF32 = gradOutput.astype(fl::dtype::f32); - }, - /* incrementCount = */ false); - auto oDescF32 = TensorDescriptor(gradOutputF32); - // Perform bias gradient computation - biasGradBenchmark->audit([&biasGradOut, - &convolutionBackwardBias, - &biasF32, - &gradOutputF32, - &oDescF32]() { - biasGradOut = - convolutionBackwardBias(biasF32, gradOutputF32, oDescF32); - }); - } else { - // Grad output and bias types are already 
the same, so perform the - // computation using whatever input type is given - biasGradBenchmark->audit([&biasGradOut, - &convolutionBackwardBias, - &bias, - &gradOutput, - &oDesc]() { - biasGradOut = convolutionBackwardBias(bias, gradOutput, oDesc); - }); - } - } else { - // No benchmark; proceed normally - biasGradOut = convolutionBackwardBias(bias, gradOutput, oDesc); + auto convolutionBackwardBias = [&hndl, &cudnnStream, oneg, zerog]( + const Tensor& bsTensor, + const Tensor& gradOutput, + const TensorDescriptor& oDesc) -> Tensor { + auto gradBias = Tensor(bsTensor.shape(), bsTensor.type()); + { + DevicePtr gradBiasPtr(gradBias); + DevicePtr gradResultPtr(gradOutput); + // ensure cudnn compute stream waits on gradient tensor streams + relativeSync(cudnnStream, {gradOutput, gradBias}); + auto bDesc = TensorDescriptor(bsTensor); + CUDNN_CHECK_ERR( + cudnnConvolutionBackwardBias( + hndl, + oneg, + oDesc.descriptor, + gradResultPtr.get(), + zerog, + bDesc.descriptor, + gradBiasPtr.get() + ) + ); + } + // ensure gradient bias tensor stream waits on cudnn compute stream + relativeSync({gradBias}, cudnnStream); + return gradBias; + }; + + Tensor biasGradOut; + + if(!bias.isEmpty()) { + if(biasGradBenchmark && DynamicBenchmark::getBenchmarkMode()) { + KernelMode biasBwdOption = + biasGradBenchmark->getOptions>() + ->currentOption(); + + if( + bias.type() == fl::dtype::f16 + && biasBwdOption == CudnnAutogradExtension::KernelMode::F32 + && biasBwdOption + == CudnnAutogradExtension::KernelMode::F32_ALLOW_CONVERSION + ) { + // The input type of fp16, but the result of the dynamic benchmark is + // that using fp32 kernels is faster for computing bwd with fp16 + // kernels, including the cast + Tensor biasF32; + Tensor gradOutputF32; + // Time cast bias and grad output if benchmarking + biasGradBenchmark->audit( + [&bias, &gradOutput, &biasF32, &gradOutputF32]() { + biasF32 = bias.astype(fl::dtype::f32); + gradOutputF32 = gradOutput.astype(fl::dtype::f32); + }, + /* 
incrementCount = */ false + ); + auto oDescF32 = TensorDescriptor(gradOutputF32); + // Perform bias gradient computation + biasGradBenchmark->audit( + [&biasGradOut, + &convolutionBackwardBias, + &biasF32, + &gradOutputF32, + &oDescF32]() { + biasGradOut = + convolutionBackwardBias(biasF32, gradOutputF32, oDescF32); + } + ); + } else { + // Grad output and bias types are already the same, so perform the + // computation using whatever input type is given + biasGradBenchmark->audit( + [&biasGradOut, + &convolutionBackwardBias, + &bias, + &gradOutput, + &oDesc]() { + biasGradOut = convolutionBackwardBias(bias, gradOutput, oDesc); + } + ); + } + } else { + // No benchmark; proceed normally + biasGradOut = convolutionBackwardBias(bias, gradOutput, oDesc); + } } - } - return {filterGradOut, biasGradOut}; + return {filterGradOut, biasGradOut}; } } // namespace fl diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp index 5daf3a9..2560cd4 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp @@ -13,26 +13,28 @@ namespace fl { -std::shared_ptr -CudnnAutogradExtension::createBenchmarkOptions() { - return std::make_shared( - std::make_shared>( - std::vector( - {KernelMode::F32, - KernelMode::F32_ALLOW_CONVERSION, - KernelMode::F16}), - fl::kDynamicBenchmarkDefaultCount)); +std::shared_ptr CudnnAutogradExtension::createBenchmarkOptions() { + return std::make_shared( + std::make_shared>( + std::vector( + {KernelMode::F32, + KernelMode::F32_ALLOW_CONVERSION, + KernelMode::F16} + ), + fl::kDynamicBenchmarkDefaultCount + ) + ); } bool CudnnAutogradExtension::isDataTypeSupported(const fl::dtype& dtype) const { - switch (dtype) { - case fl::dtype::f16: - case fl::dtype::f32: - case fl::dtype::f64: - return true; - default: - return false; - } + switch(dtype) { + case 
fl::dtype::f16: + case fl::dtype::f32: + case fl::dtype::f64: + return true; + default: + return false; + } } } // namespace fl diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h index 51bc6b9..b960c30 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h +++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h @@ -16,138 +16,147 @@ namespace fl { class DynamicBenchmark; class CudnnAutogradExtension : public AutogradExtension { - // TODO(jacobkahn): implement getCudnnHandle - - public: - bool isDataTypeSupported(const fl::dtype& dtype) const override; - - enum class KernelMode { F32 = 0, F32_ALLOW_CONVERSION = 1, F16 = 2 }; - - std::shared_ptr createBenchmarkOptions() override; - - /**************************** Forward ****************************/ - Tensor conv2d( - const Tensor& input, - const Tensor& weights, - const Tensor& bias, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, - std::shared_ptr payload) override; - - Tensor pool2d( - const Tensor& input, - const int wx, - const int wy, - const int sx, - const int sy, - const int px, - const int py, - const PoolingMode mode, - std::shared_ptr payload) override; - - Tensor batchnorm( - Tensor& saveMean, - Tensor& saveVar, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - Tensor& runningMean, - Tensor& runningVar, - const std::vector& axes, - const bool train, - const double momentum, - const double epsilon, - std::shared_ptr payload) override; - - std::tuple rnn( - const Tensor& input, - const Tensor& hiddenState, - const Tensor& cellState, - const Tensor& weights, - const int hiddenSize, - const int numLayers, - const RnnMode mode, - const bool bidirectional, - const float dropout, - std::shared_ptr payload) override; - - /**************************** Backward 
****************************/ - // ]----- Convolution - Tensor conv2dBackwardData( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& weight, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, - std::shared_ptr dataGradBenchmark, - std::shared_ptr payload) override; - - std::pair conv2dBackwardFilterBias( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& weights, - const Tensor& bias, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, - std::shared_ptr filterBench, - std::shared_ptr biasBench, - std::shared_ptr autogradPayload) override; - - // ]----- pool2D - Tensor pool2dBackward( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& poolOutput, - const int wx, - const int wy, - const int sx, - const int sy, - const int px, - const int py, - const PoolingMode mode, - std::shared_ptr payload) override; - - // ]----- batchnorm - std::tuple batchnormBackward( - const Tensor& gradOutput, - const Tensor& saveMean, - const Tensor& saveVar, - const Tensor& input, - const Tensor& weight, - const std::vector& axes, - const bool train, - const float epsilon, - std::shared_ptr payload) override; - - // ]----- rnn - std::tuple rnnBackward( - const Tensor& input, - const Tensor& hiddenState, - const Tensor& cellState, - const Tensor& weights, - const std::shared_ptr gradData, - const Tensor& output, - const int numLayers, - const int hiddenSize, - const RnnMode mode, - const bool bidirectional, - const float dropProb, - std::shared_ptr payload) override; + // TODO(jacobkahn): implement getCudnnHandle + +public: + bool isDataTypeSupported(const fl::dtype& dtype) const override; + + enum class KernelMode {F32 = 0, F32_ALLOW_CONVERSION = 1, F16 = 2}; + + std::shared_ptr createBenchmarkOptions() override; + + /**************************** Forward ****************************/ + Tensor conv2d( + const 
Tensor& input, + const Tensor& weights, + const Tensor& bias, + const int sx, + const int sy, + const int px, + const int py, + const int dx, + const int dy, + const int groups, + std::shared_ptr payload + ) override; + + Tensor pool2d( + const Tensor& input, + const int wx, + const int wy, + const int sx, + const int sy, + const int px, + const int py, + const PoolingMode mode, + std::shared_ptr payload + ) override; + + Tensor batchnorm( + Tensor& saveMean, + Tensor& saveVar, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + Tensor& runningMean, + Tensor& runningVar, + const std::vector& axes, + const bool train, + const double momentum, + const double epsilon, + std::shared_ptr payload + ) override; + + std::tuple rnn( + const Tensor& input, + const Tensor& hiddenState, + const Tensor& cellState, + const Tensor& weights, + const int hiddenSize, + const int numLayers, + const RnnMode mode, + const bool bidirectional, + const float dropout, + std::shared_ptr payload + ) override; + + /**************************** Backward ****************************/ + // ]----- Convolution + Tensor conv2dBackwardData( + const Tensor& gradOutput, + const Tensor& input, + const Tensor& weight, + const int sx, + const int sy, + const int px, + const int py, + const int dx, + const int dy, + const int groups, + std::shared_ptr dataGradBenchmark, + std::shared_ptr payload + ) override; + + std::pair conv2dBackwardFilterBias( + const Tensor& gradOutput, + const Tensor& input, + const Tensor& weights, + const Tensor& bias, + const int sx, + const int sy, + const int px, + const int py, + const int dx, + const int dy, + const int groups, + std::shared_ptr filterBench, + std::shared_ptr biasBench, + std::shared_ptr autogradPayload + ) override; + + // ]----- pool2D + Tensor pool2dBackward( + const Tensor& gradOutput, + const Tensor& input, + const Tensor& poolOutput, + const int wx, + const int wy, + const int sx, + const int sy, + const int px, + const int py, + 
const PoolingMode mode, + std::shared_ptr payload + ) override; + + // ]----- batchnorm + std::tuple batchnormBackward( + const Tensor& gradOutput, + const Tensor& saveMean, + const Tensor& saveVar, + const Tensor& input, + const Tensor& weight, + const std::vector& axes, + const bool train, + const float epsilon, + std::shared_ptr payload + ) override; + + // ]----- rnn + std::tuple rnnBackward( + const Tensor& input, + const Tensor& hiddenState, + const Tensor& cellState, + const Tensor& weights, + const std::shared_ptr gradData, + const Tensor& output, + const int numLayers, + const int hiddenSize, + const RnnMode mode, + const bool bidirectional, + const float dropProb, + std::shared_ptr payload + ) override; }; } // namespace fl diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp index aeedcda..db30ea6 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp @@ -21,26 +21,26 @@ namespace { struct DeviceHandle { - cudnnHandle_t cudnnHandle; - std::shared_ptr stream; + cudnnHandle_t cudnnHandle; + std::shared_ptr stream; - explicit DeviceHandle(std::shared_ptr _stream) - : cudnnHandle(nullptr), stream(_stream) { - CUDNN_CHECK_ERR(cudnnCreate(&cudnnHandle)); - CUDNN_CHECK_ERR(cudnnSetStream(cudnnHandle, stream->handle())); - } + explicit DeviceHandle(std::shared_ptr _stream) : cudnnHandle(nullptr), + stream(_stream) { + CUDNN_CHECK_ERR(cudnnCreate(&cudnnHandle)); + CUDNN_CHECK_ERR(cudnnSetStream(cudnnHandle, stream->handle())); + } - ~DeviceHandle() { - if (cudnnHandle) { + ~DeviceHandle() { + if(cudnnHandle) { // See https://git.io/fNQnM - sometimes, at exit, the CUDA context // (or something) is already destroyed by the time a handle gets destroyed // because of an issue with the destruction order. 
#ifdef NO_CUDNN_DESTROY_HANDLE #else - CUDNN_CHECK_ERR(cudnnDestroy(cudnnHandle)); + CUDNN_CHECK_ERR(cudnnDestroy(cudnnHandle)); #endif + } } - } }; const float kFloatZero = 0.0; @@ -53,31 +53,31 @@ const double kDoubleOne = 1.0; std::unordered_map handles; const DeviceHandle& getActiveDeviceHandle() { - auto& manager = fl::DeviceManager::getInstance(); - auto& cudaDevice = - manager.getActiveDevice(fl::DeviceType::CUDA).impl(); - int id = cudaDevice.nativeId(); - // lazily initialize cuda stream for cudnn - if (!handles.contains(id)) { + auto& manager = fl::DeviceManager::getInstance(); + auto& cudaDevice = + manager.getActiveDevice(fl::DeviceType::CUDA).impl(); + int id = cudaDevice.nativeId(); + // lazily initialize cuda stream for cudnn + if(!handles.contains(id)) { #ifdef NO_CUDNN_DESTROY_HANDLE - // NOTE unmanaged so to avoid CUDA driver shut down prior to stream - // destruction. This is safe because this object is always part of a global - // map -- the resource won't be relased until program shutdown anyway. - auto stream = fl::CUDAStream::createUnmanaged(); + // NOTE unmanaged so to avoid CUDA driver shut down prior to stream + // destruction. This is safe because this object is always part of a global + // map -- the resource won't be relased until program shutdown anyway. + auto stream = fl::CUDAStream::createUnmanaged(); #else - auto stream = fl::CUDAStream::createManaged(); + auto stream = fl::CUDAStream::createManaged(); #endif - handles.emplace(id, DeviceHandle(stream)); - } - return handles.at(id); + handles.emplace(id, DeviceHandle(stream)); + } + return handles.at(id); } // See https://git.io/fp9oo for an explanation. 
#if CUDNN_VERSION < 7000 struct CudnnDropoutStruct { - float dropout; - int nstates; - void* states; + float dropout; + int nstates; + void* states; }; #endif @@ -86,123 +86,134 @@ struct CudnnDropoutStruct { namespace fl { void cudnnCheckErr(cudnnStatus_t status) { - if (status == CUDNN_STATUS_SUCCESS) { - return; - } - const char* err = cudnnGetErrorString(status); - switch (status) { - case CUDNN_STATUS_BAD_PARAM: - throw std::invalid_argument(err); - default: - throw std::runtime_error(err); - } + if(status == CUDNN_STATUS_SUCCESS) { + return; + } + const char* err = cudnnGetErrorString(status); + switch(status) { + case CUDNN_STATUS_BAD_PARAM: + throw std::invalid_argument(err); + default: + throw std::runtime_error(err); + } } cudnnDataType_t cudnnMapToType(const fl::dtype& t) { - switch (t) { - case fl::dtype::f16: - return CUDNN_DATA_HALF; - case fl::dtype::f32: - return CUDNN_DATA_FLOAT; - case fl::dtype::f64: - return CUDNN_DATA_DOUBLE; - default: - throw std::invalid_argument("unsupported data type for cuDNN"); - } + switch(t) { + case fl::dtype::f16: + return CUDNN_DATA_HALF; + case fl::dtype::f32: + return CUDNN_DATA_FLOAT; + case fl::dtype::f64: + return CUDNN_DATA_DOUBLE; + default: + throw std::invalid_argument("unsupported data type for cuDNN"); + } } cudnnPoolingMode_t cudnnMapToPoolingMode(const PoolingMode mode) { - switch (mode) { - case PoolingMode::MAX: - return CUDNN_POOLING_MAX; - case PoolingMode::AVG_INCLUDE_PADDING: - return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; - case PoolingMode::AVG_EXCLUDE_PADDING: - return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; - default: - throw std::invalid_argument("unsupported pooling mode for cuDNN"); - } + switch(mode) { + case PoolingMode::MAX: + return CUDNN_POOLING_MAX; + case PoolingMode::AVG_INCLUDE_PADDING: + return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + case PoolingMode::AVG_EXCLUDE_PADDING: + return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; + default: + throw 
std::invalid_argument("unsupported pooling mode for cuDNN"); + } } cudnnRNNMode_t cudnnMapToRNNMode(const RnnMode mode) { - switch (mode) { - case RnnMode::RELU: - return CUDNN_RNN_RELU; - case RnnMode::TANH: - return CUDNN_RNN_TANH; - case RnnMode::LSTM: - return CUDNN_LSTM; - case RnnMode::GRU: - return CUDNN_GRU; - default: - throw std::invalid_argument("unsupported RNN mode for cuDNN"); - } + switch(mode) { + case RnnMode::RELU: + return CUDNN_RNN_RELU; + case RnnMode::TANH: + return CUDNN_RNN_TANH; + case RnnMode::LSTM: + return CUDNN_LSTM; + case RnnMode::GRU: + return CUDNN_GRU; + default: + throw std::invalid_argument("unsupported RNN mode for cuDNN"); + } } TensorDescriptor::TensorDescriptor(const fl::dtype type, const Shape& flDims) { - CUDNN_CHECK_ERR(cudnnCreateTensorDescriptor(&descriptor)); - cudnnDataType_t cudnntype = cudnnMapToType(type); - - std::array dims = {1, 1, 1, 1}; - // We want, if dims exist: - // {flDims[3], flDims[2], flDims[1], flDims[0]}; - for (unsigned i = 0; i < flDims.ndim(); ++i) { - dims[3 - i] = flDims[i]; - } - - // Sets strides so array is contiguous row-major for cudnn - std::vector r_strides = {1}; - for (auto it = dims.rbegin(); it != dims.rend() - 1; ++it) { - r_strides.push_back(r_strides.back() * (*it)); - } - std::vector strides(r_strides.rbegin(), r_strides.rend()); - - CUDNN_CHECK_ERR(cudnnSetTensorNdDescriptor( - descriptor, cudnntype, dims.size(), dims.data(), strides.data())); + CUDNN_CHECK_ERR(cudnnCreateTensorDescriptor(&descriptor)); + cudnnDataType_t cudnntype = cudnnMapToType(type); + + std::array dims = {1, 1, 1, 1}; + // We want, if dims exist: + // {flDims[3], flDims[2], flDims[1], flDims[0]}; + for(unsigned i = 0; i < flDims.ndim(); ++i) { + dims[3 - i] = flDims[i]; + } + + // Sets strides so array is contiguous row-major for cudnn + std::vector r_strides = {1}; + for(auto it = dims.rbegin(); it != dims.rend() - 1; ++it) { + r_strides.push_back(r_strides.back() * (*it)); + } + std::vector 
strides(r_strides.rbegin(), r_strides.rend()); + + CUDNN_CHECK_ERR( + cudnnSetTensorNdDescriptor( + descriptor, + cudnntype, + dims.size(), + dims.data(), + strides.data() + ) + ); } TensorDescriptor::TensorDescriptor(const Tensor& input) { - CUDNN_CHECK_ERR(cudnnCreateTensorDescriptor(&descriptor)); - cudnnDataType_t cudnntype = cudnnMapToType(input.type()); - - auto flStrides = input.strides(); - auto flDims = input.shape(); - - // reverse the dims (column -> row major) and cast to int type - std::array strides = {1, 1, 1, 1}; - // {flStrides[3], flStrides[2], flStrides[1], flStrides[0]}; - for (unsigned i = 0; i < flStrides.ndim(); ++i) { - strides[3 - i] = flStrides[i]; - } - - std::array dims = {1, 1, 1, 1}; - // {flDims[3], flDims[2], flDims[1], flDims[0]}; - for (unsigned i = 0; i < flDims.ndim(); ++i) { - dims[3 - i] = flDims[i]; - } - - CUDNN_CHECK_ERR(cudnnSetTensorNdDescriptor( - descriptor /* descriptor handle */, - cudnntype /* = dataType */, - 4, - dims.data(), - strides.data())); + CUDNN_CHECK_ERR(cudnnCreateTensorDescriptor(&descriptor)); + cudnnDataType_t cudnntype = cudnnMapToType(input.type()); + + auto flStrides = input.strides(); + auto flDims = input.shape(); + + // reverse the dims (column -> row major) and cast to int type + std::array strides = {1, 1, 1, 1}; + // {flStrides[3], flStrides[2], flStrides[1], flStrides[0]}; + for(unsigned i = 0; i < flStrides.ndim(); ++i) { + strides[3 - i] = flStrides[i]; + } + + std::array dims = {1, 1, 1, 1}; + // {flDims[3], flDims[2], flDims[1], flDims[0]}; + for(unsigned i = 0; i < flDims.ndim(); ++i) { + dims[3 - i] = flDims[i]; + } + + CUDNN_CHECK_ERR( + cudnnSetTensorNdDescriptor( + descriptor /* descriptor handle */, + cudnntype /* = dataType */, + 4, + dims.data(), + strides.data() + ) + ); } TensorDescriptor::~TensorDescriptor() { - CUDNN_CHECK_ERR(cudnnDestroyTensorDescriptor(descriptor)); + CUDNN_CHECK_ERR(cudnnDestroyTensorDescriptor(descriptor)); } TensorDescriptorArray::TensorDescriptorArray( 
int size, const fl::dtype type, - const Shape& dims) { - desc_vec.reserve(size); - for (int i = 0; i < size; i++) { - desc_vec.emplace_back(type, dims); - desc_raw_vec.push_back(desc_vec.back().descriptor); - } - descriptors = desc_raw_vec.data(); + const Shape& dims +) { + desc_vec.reserve(size); + for(int i = 0; i < size; i++) { + desc_vec.emplace_back(type, dims); + desc_raw_vec.push_back(desc_vec.back().descriptor); + } + descriptors = desc_raw_vec.data(); } TensorDescriptorArray::~TensorDescriptorArray() = default; @@ -214,82 +225,109 @@ PoolingDescriptor::PoolingDescriptor( int sy, int px, int py, - PoolingMode mode) { - CUDNN_CHECK_ERR(cudnnCreatePoolingDescriptor(&descriptor)); - std::array window = {static_cast(wy), static_cast(wx)}; - std::array padding = {static_cast(py), static_cast(px)}; - std::array stride = {static_cast(sy), static_cast(sx)}; - - auto cudnnpoolingmode = cudnnMapToPoolingMode(mode); - CUDNN_CHECK_ERR(cudnnSetPoolingNdDescriptor( - descriptor, - cudnnpoolingmode, - CUDNN_PROPAGATE_NAN, - 2, - window.data(), - padding.data(), - stride.data())); + PoolingMode mode +) { + CUDNN_CHECK_ERR(cudnnCreatePoolingDescriptor(&descriptor)); + std::array window = {static_cast(wy), static_cast(wx)}; + std::array padding = {static_cast(py), static_cast(px)}; + std::array stride = {static_cast(sy), static_cast(sx)}; + + auto cudnnpoolingmode = cudnnMapToPoolingMode(mode); + CUDNN_CHECK_ERR( + cudnnSetPoolingNdDescriptor( + descriptor, + cudnnpoolingmode, + CUDNN_PROPAGATE_NAN, + 2, + window.data(), + padding.data(), + stride.data() + ) + ); } PoolingDescriptor::~PoolingDescriptor() { - CUDNN_CHECK_ERR(cudnnDestroyPoolingDescriptor(descriptor)); + CUDNN_CHECK_ERR(cudnnDestroyPoolingDescriptor(descriptor)); } FilterDescriptor::FilterDescriptor(const Tensor& input) { - CUDNN_CHECK_ERR(cudnnCreateFilterDescriptor(&descriptor)); - cudnnDataType_t cudnntype = cudnnMapToType(input.type()); - - auto flDims = input.shape(); - std::array dims = {1, 1, 1, 1}; - 
// We want, if dims exist: - // {flDims[3], flDims[2], flDims[1], flDims[0]}; - for (unsigned i = 0; i < flDims.ndim(); ++i) { - dims[3 - i] = flDims[i]; - } - - CUDNN_CHECK_ERR(cudnnSetFilterNdDescriptor( - descriptor, cudnntype, CUDNN_TENSOR_NCHW, 4, dims.data())); + CUDNN_CHECK_ERR(cudnnCreateFilterDescriptor(&descriptor)); + cudnnDataType_t cudnntype = cudnnMapToType(input.type()); + + auto flDims = input.shape(); + std::array dims = {1, 1, 1, 1}; + // We want, if dims exist: + // {flDims[3], flDims[2], flDims[1], flDims[0]}; + for(unsigned i = 0; i < flDims.ndim(); ++i) { + dims[3 - i] = flDims[i]; + } + + CUDNN_CHECK_ERR( + cudnnSetFilterNdDescriptor( + descriptor, + cudnntype, + CUDNN_TENSOR_NCHW, + 4, + dims.data() + ) + ); } FilterDescriptor::~FilterDescriptor() { - CUDNN_CHECK_ERR(cudnnDestroyFilterDescriptor(descriptor)); + CUDNN_CHECK_ERR(cudnnDestroyFilterDescriptor(descriptor)); } DropoutDescriptor::DropoutDescriptor(float drop_prob) { - CUDNN_CHECK_ERR(cudnnCreateDropoutDescriptor(&descriptor)); - auto cudnnHandle = getCudnnHandle(); - unsigned long long seed = 0; - size_t state_size; - CUDNN_CHECK_ERR(cudnnDropoutGetStatesSize(cudnnHandle, &state_size)); - auto& dropout_states = getDropoutStates(); - if (dropout_states.isEmpty()) { - dropout_states = - Tensor({static_cast(state_size)}, fl::dtype::b8); - DevicePtr statesraw(dropout_states); - CUDNN_CHECK_ERR(cudnnSetDropoutDescriptor( - descriptor, cudnnHandle, drop_prob, statesraw.get(), state_size, seed)); - } else { - DevicePtr statesraw(dropout_states); + CUDNN_CHECK_ERR(cudnnCreateDropoutDescriptor(&descriptor)); + auto cudnnHandle = getCudnnHandle(); + unsigned long long seed = 0; + size_t state_size; + CUDNN_CHECK_ERR(cudnnDropoutGetStatesSize(cudnnHandle, &state_size)); + auto& dropout_states = getDropoutStates(); + if(dropout_states.isEmpty()) { + dropout_states = + Tensor({static_cast(state_size)}, fl::dtype::b8); + DevicePtr statesraw(dropout_states); + CUDNN_CHECK_ERR( + 
cudnnSetDropoutDescriptor( + descriptor, + cudnnHandle, + drop_prob, + statesraw.get(), + state_size, + seed + ) + ); + } else { + DevicePtr statesraw(dropout_states); // See https://git.io/fp9oo for an explanation. #if CUDNN_VERSION >= 7000 - CUDNN_CHECK_ERR(cudnnRestoreDropoutDescriptor( - descriptor, cudnnHandle, drop_prob, statesraw.get(), state_size, seed)); + CUDNN_CHECK_ERR( + cudnnRestoreDropoutDescriptor( + descriptor, + cudnnHandle, + drop_prob, + statesraw.get(), + state_size, + seed + ) + ); #else - auto dropout_struct = reinterpret_cast(descriptor); - dropout_struct->dropout = drop_prob; - dropout_struct->nstates = state_size; - dropout_struct->states = statesraw.get(); + auto dropout_struct = reinterpret_cast(descriptor); + dropout_struct->dropout = drop_prob; + dropout_struct->nstates = state_size; + dropout_struct->states = statesraw.get(); #endif - } + } } DropoutDescriptor::~DropoutDescriptor() { - CUDNN_CHECK_ERR(cudnnDestroyDropoutDescriptor(descriptor)); + CUDNN_CHECK_ERR(cudnnDestroyDropoutDescriptor(descriptor)); } Tensor& DropoutDescriptor::getDropoutStates() { - thread_local Tensor dropout_states; - return dropout_states; + thread_local Tensor dropout_states; + return dropout_states; } RNNDescriptor::RNNDescriptor( @@ -298,49 +336,56 @@ RNNDescriptor::RNNDescriptor( int num_layers, RnnMode mode, bool bidirectional, - DropoutDescriptor& dropout) { - CUDNN_CHECK_ERR(cudnnCreateRNNDescriptor(&descriptor)); + DropoutDescriptor& dropout +) { + CUDNN_CHECK_ERR(cudnnCreateRNNDescriptor(&descriptor)); - auto cudnnHandle = getCudnnHandle(); + auto cudnnHandle = getCudnnHandle(); - cudnnRNNInputMode_t in_mode = CUDNN_LINEAR_INPUT; + cudnnRNNInputMode_t in_mode = CUDNN_LINEAR_INPUT; - cudnnDirectionMode_t dir = - bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; + cudnnDirectionMode_t dir = + bidirectional ? 
CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL; - cudnnRNNMode_t cell = cudnnMapToRNNMode(mode); - cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD; - cudnnDataType_t cudnntype = cudnnMapToType(type); + cudnnRNNMode_t cell = cudnnMapToRNNMode(mode); + cudnnRNNAlgo_t algo = CUDNN_RNN_ALGO_STANDARD; + cudnnDataType_t cudnntype = cudnnMapToType(type); #if CUDNN_VERSION >= 7000 && CUDNN_VERSION < 8000 - CUDNN_CHECK_ERR(cudnnSetRNNDescriptor( - cudnnHandle, - descriptor, - hidden_size, - num_layers, - dropout.descriptor, - in_mode, - dir, - cell, - algo, - cudnntype)); + CUDNN_CHECK_ERR( + cudnnSetRNNDescriptor( + cudnnHandle, + descriptor, + hidden_size, + num_layers, + dropout.descriptor, + in_mode, + dir, + cell, + algo, + cudnntype + ) + ); #else - CUDNN_CHECK_ERR(cudnnSetRNNDescriptor_v6( - cudnnHandle, - descriptor, - hidden_size, - num_layers, - dropout.descriptor, - in_mode, - dir, - cell, - algo, - cudnntype)); + CUDNN_CHECK_ERR( + cudnnSetRNNDescriptor_v6( + cudnnHandle, + descriptor, + hidden_size, + num_layers, + dropout.descriptor, + in_mode, + dir, + cell, + algo, + cudnntype + ) + ); #endif } RNNDescriptor::~RNNDescriptor() { - CUDNN_CHECK_ERR(cudnnDestroyRNNDescriptor(descriptor)); + CUDNN_CHECK_ERR(cudnnDestroyRNNDescriptor(descriptor)); } ConvDescriptor::ConvDescriptor( @@ -351,59 +396,63 @@ ConvDescriptor::ConvDescriptor( int sy, int dx, int dy, - int groups) { - CUDNN_CHECK_ERR(cudnnCreateConvolutionDescriptor(&descriptor)); - cudnnDataType_t cudnntype = cudnnMapToType(type); - std::array padding = {static_cast(py), static_cast(px)}; - std::array stride = {static_cast(sy), static_cast(sx)}; - std::array dilation = {static_cast(dy), static_cast(dx)}; - - CUDNN_CHECK_ERR(cudnnSetConvolutionNdDescriptor( - descriptor, - 2, - padding.data(), - stride.data(), - dilation.data(), - CUDNN_CROSS_CORRELATION, - cudnntype)); - - CUDNN_CHECK_ERR(cudnnSetConvolutionGroupCount(descriptor, groups)); + int groups +) { + 
CUDNN_CHECK_ERR(cudnnCreateConvolutionDescriptor(&descriptor)); + cudnnDataType_t cudnntype = cudnnMapToType(type); + std::array padding = {static_cast(py), static_cast(px)}; + std::array stride = {static_cast(sy), static_cast(sx)}; + std::array dilation = {static_cast(dy), static_cast(dx)}; + + CUDNN_CHECK_ERR( + cudnnSetConvolutionNdDescriptor( + descriptor, + 2, + padding.data(), + stride.data(), + dilation.data(), + CUDNN_CROSS_CORRELATION, + cudnntype + ) + ); + + CUDNN_CHECK_ERR(cudnnSetConvolutionGroupCount(descriptor, groups)); } ConvDescriptor::~ConvDescriptor() { - CUDNN_CHECK_ERR(cudnnDestroyConvolutionDescriptor(descriptor)); + CUDNN_CHECK_ERR(cudnnDestroyConvolutionDescriptor(descriptor)); } cudnnHandle_t getCudnnHandle() { - return getActiveDeviceHandle().cudnnHandle; + return getActiveDeviceHandle().cudnnHandle; } const CUDAStream& getCudnnStream() { - return *getActiveDeviceHandle().stream; + return *getActiveDeviceHandle().stream; } const void* kOne(const fl::dtype t) { - switch (t) { - case fl::dtype::f16: - case fl::dtype::f32: - return &kFloatOne; - case fl::dtype::f64: - return &kDoubleOne; - default: - throw std::invalid_argument("unsupported data type for cuDNN"); - } + switch(t) { + case fl::dtype::f16: + case fl::dtype::f32: + return &kFloatOne; + case fl::dtype::f64: + return &kDoubleOne; + default: + throw std::invalid_argument("unsupported data type for cuDNN"); + } } const void* kZero(const fl::dtype t) { - switch (t) { - case fl::dtype::f16: - case fl::dtype::f32: - return &kFloatZero; - case fl::dtype::f64: - return &kDoubleZero; - default: - throw std::invalid_argument("unsupported data type for cuDNN"); - } + switch(t) { + case fl::dtype::f16: + case fl::dtype::f32: + return &kFloatZero; + case fl::dtype::f64: + return &kDoubleZero; + default: + throw std::invalid_argument("unsupported data type for cuDNN"); + } } } // namespace fl diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h 
b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h index 405fe2a..fca9969 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h +++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h @@ -16,83 +16,86 @@ namespace fl { class TensorDescriptor { - public: - explicit TensorDescriptor(const Tensor& a); +public: + explicit TensorDescriptor(const Tensor& a); - TensorDescriptor(const fl::dtype type, const Shape& af_dims); + TensorDescriptor(const fl::dtype type, const Shape& af_dims); - cudnnTensorDescriptor_t descriptor; - ~TensorDescriptor(); + cudnnTensorDescriptor_t descriptor; + ~TensorDescriptor(); }; class TensorDescriptorArray { - public: - TensorDescriptorArray(int size, const fl::dtype type, const Shape& dims); +public: + TensorDescriptorArray(int size, const fl::dtype type, const Shape& dims); - cudnnTensorDescriptor_t* descriptors; - ~TensorDescriptorArray(); + cudnnTensorDescriptor_t* descriptors; + ~TensorDescriptorArray(); - private: - std::vector desc_vec; - std::vector desc_raw_vec; +private: + std::vector desc_vec; + std::vector desc_raw_vec; }; class FilterDescriptor { - public: - explicit FilterDescriptor(const Tensor& a); - cudnnFilterDescriptor_t descriptor; - ~FilterDescriptor(); +public: + explicit FilterDescriptor(const Tensor& a); + cudnnFilterDescriptor_t descriptor; + ~FilterDescriptor(); }; class ConvDescriptor { - public: - ConvDescriptor( - fl::dtype type, - int px, - int py, - int sx, - int sy, - int dx, - int dy, - int groups = 1); - cudnnConvolutionDescriptor_t descriptor; - ~ConvDescriptor(); +public: + ConvDescriptor( + fl::dtype type, + int px, + int py, + int sx, + int sy, + int dx, + int dy, + int groups = 1 + ); + cudnnConvolutionDescriptor_t descriptor; + ~ConvDescriptor(); }; class PoolingDescriptor { - public: - PoolingDescriptor( - int wx, - int wy, - int sx, - int sy, - int px, - int py, - PoolingMode mode); - cudnnPoolingDescriptor_t descriptor; - ~PoolingDescriptor(); +public: + 
PoolingDescriptor( + int wx, + int wy, + int sx, + int sy, + int px, + int py, + PoolingMode mode + ); + cudnnPoolingDescriptor_t descriptor; + ~PoolingDescriptor(); }; class DropoutDescriptor { - public: - explicit DropoutDescriptor(float drop_prob); - cudnnDropoutDescriptor_t descriptor; - ~DropoutDescriptor(); +public: + explicit DropoutDescriptor(float drop_prob); + cudnnDropoutDescriptor_t descriptor; + ~DropoutDescriptor(); - Tensor& getDropoutStates(); + Tensor& getDropoutStates(); }; class RNNDescriptor { - public: - RNNDescriptor( - fl::dtype type, - int hidden_size, - int num_layers, - RnnMode mode, - bool bidirectional, - DropoutDescriptor& dropout); - cudnnRNNDescriptor_t descriptor; - ~RNNDescriptor(); +public: + RNNDescriptor( + fl::dtype type, + int hidden_size, + int num_layers, + RnnMode mode, + bool bidirectional, + DropoutDescriptor& dropout + ); + cudnnRNNDescriptor_t descriptor; + ~RNNDescriptor(); }; #define CUDNN_CHECK_ERR(expr) ::fl::cudnnCheckErr((expr)) diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp index cde9499..f9956c1 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp @@ -23,51 +23,56 @@ Tensor CudnnAutogradExtension::pool2d( const int px, const int py, const PoolingMode mode, - std::shared_ptr) { - auto inDesc = TensorDescriptor(input); - - // init pooling descriptor - auto poolDesc = PoolingDescriptor(wx, wy, sx, sy, px, py, mode); - - // init output descriptor - auto ix = input.dim(0); - auto iy = input.ndim() < 2 ? 1 : input.dim(1); - auto ox = 1 + (ix + 2 * px - wx) / sx; - auto oy = 1 + (iy + 2 * py - wy) / sy; - - auto output = Tensor( - {ox, - oy, - input.ndim() < 3 ? 1 : input.dim(2), - input.ndim() < 4 ? 
1 : input.dim(3)}, - input.type()); - auto outDesc = TensorDescriptor(output); - { - DevicePtr inputraw(input); - DevicePtr resultraw(output); - const auto& cudnnStream = getCudnnStream(); - // ensure cudnn compute stream waits on streams of input/output tensors - relativeSync(cudnnStream, {input, output}); - - auto handle = getCudnnHandle(); - const void* one = kOne(input.type()); - const void* zero = kZero(input.type()); - - CUDNN_CHECK_ERR(cudnnPoolingForward( - handle, - poolDesc.descriptor, - one, - inDesc.descriptor, - inputraw.get(), - zero, - outDesc.descriptor, - resultraw.get())); - - // ensure output tensor stream waits on cudnn compute stream - relativeSync({output}, cudnnStream); - } - - return output; + std::shared_ptr +) { + auto inDesc = TensorDescriptor(input); + + // init pooling descriptor + auto poolDesc = PoolingDescriptor(wx, wy, sx, sy, px, py, mode); + + // init output descriptor + auto ix = input.dim(0); + auto iy = input.ndim() < 2 ? 1 : input.dim(1); + auto ox = 1 + (ix + 2 * px - wx) / sx; + auto oy = 1 + (iy + 2 * py - wy) / sy; + + auto output = Tensor( + {ox, + oy, + input.ndim() < 3 ? 1 : input.dim(2), + input.ndim() < 4 ? 
1 : input.dim(3)}, + input.type() + ); + auto outDesc = TensorDescriptor(output); + { + DevicePtr inputraw(input); + DevicePtr resultraw(output); + const auto& cudnnStream = getCudnnStream(); + // ensure cudnn compute stream waits on streams of input/output tensors + relativeSync(cudnnStream, {input, output}); + + auto handle = getCudnnHandle(); + const void* one = kOne(input.type()); + const void* zero = kZero(input.type()); + + CUDNN_CHECK_ERR( + cudnnPoolingForward( + handle, + poolDesc.descriptor, + one, + inDesc.descriptor, + inputraw.get(), + zero, + outDesc.descriptor, + resultraw.get() + ) + ); + + // ensure output tensor stream waits on cudnn compute stream + relativeSync({output}, cudnnStream); + } + + return output; } Tensor CudnnAutogradExtension::pool2dBackward( @@ -81,44 +86,48 @@ Tensor CudnnAutogradExtension::pool2dBackward( const int px, const int py, const PoolingMode mode, - std::shared_ptr) { - auto i_desc = TensorDescriptor(input); - auto o_desc = TensorDescriptor(poolOutput); - auto p_desc = PoolingDescriptor(wx, wy, sx, sy, px, py, mode); - - auto gradInput = Tensor(input.shape(), input.type()); - - auto hndl = getCudnnHandle(); - const auto& cudnnStream = getCudnnStream(); - const void* oneg = kOne(input.type()); - const void* zerog = kZero(input.type()); - - { - DevicePtr inraw(input); - DevicePtr outraw(poolOutput); - DevicePtr gradresultraw(gradOutput); - DevicePtr gradinputraw(gradInput); - // ensure cudnn compute stream waits on input/output tensor streams - relativeSync(cudnnStream, {input, poolOutput, gradOutput, gradInput}); - - CUDNN_CHECK_ERR(cudnnPoolingBackward( - hndl, - p_desc.descriptor, - oneg, - o_desc.descriptor, - outraw.get(), - o_desc.descriptor, - gradresultraw.get(), - i_desc.descriptor, - inraw.get(), - zerog, - i_desc.descriptor, - gradinputraw.get())); - // ensure gradient input tensor stream waits on cudnn compute stream - relativeSync({gradInput}, cudnnStream); - } - - return gradInput; + std::shared_ptr +) { + 
auto i_desc = TensorDescriptor(input); + auto o_desc = TensorDescriptor(poolOutput); + auto p_desc = PoolingDescriptor(wx, wy, sx, sy, px, py, mode); + + auto gradInput = Tensor(input.shape(), input.type()); + + auto hndl = getCudnnHandle(); + const auto& cudnnStream = getCudnnStream(); + const void* oneg = kOne(input.type()); + const void* zerog = kZero(input.type()); + + { + DevicePtr inraw(input); + DevicePtr outraw(poolOutput); + DevicePtr gradresultraw(gradOutput); + DevicePtr gradinputraw(gradInput); + // ensure cudnn compute stream waits on input/output tensor streams + relativeSync(cudnnStream, {input, poolOutput, gradOutput, gradInput}); + + CUDNN_CHECK_ERR( + cudnnPoolingBackward( + hndl, + p_desc.descriptor, + oneg, + o_desc.descriptor, + outraw.get(), + o_desc.descriptor, + gradresultraw.get(), + i_desc.descriptor, + inraw.get(), + zerog, + i_desc.descriptor, + gradinputraw.get() + ) + ); + // ensure gradient input tensor stream waits on cudnn compute stream + relativeSync({gradInput}, cudnnStream); + } + + return gradInput; } } // namespace fl diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp index c86b3d9..1c85956 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp @@ -15,45 +15,62 @@ namespace fl { namespace { -size_t getWorkspaceSize( - cudnnHandle_t handle, - const RNNDescriptor& rnnDesc, - const int seqLength, - const TensorDescriptorArray& xDescs) { - size_t workspaceSize; - CUDNN_CHECK_ERR(cudnnGetRNNWorkspaceSize( - handle, - rnnDesc.descriptor, - seqLength, - xDescs.descriptors, - &workspaceSize)); - return workspaceSize; -} + size_t getWorkspaceSize( + cudnnHandle_t handle, + const RNNDescriptor& rnnDesc, + const int seqLength, + const TensorDescriptorArray& xDescs + ) { + size_t workspaceSize; + CUDNN_CHECK_ERR( + cudnnGetRNNWorkspaceSize( + handle, + rnnDesc.descriptor, + seqLength, + 
xDescs.descriptors, + &workspaceSize + ) + ); + return workspaceSize; + } -size_t getReserveSize( - cudnnHandle_t handle, - const RNNDescriptor& rnnDesc, - const int seqLength, - const TensorDescriptorArray& xDescs) { - size_t reserveSize; - CUDNN_CHECK_ERR(cudnnGetRNNTrainingReserveSize( - handle, rnnDesc.descriptor, seqLength, xDescs.descriptors, &reserveSize)); - return reserveSize; -} + size_t getReserveSize( + cudnnHandle_t handle, + const RNNDescriptor& rnnDesc, + const int seqLength, + const TensorDescriptorArray& xDescs + ) { + size_t reserveSize; + CUDNN_CHECK_ERR( + cudnnGetRNNTrainingReserveSize( + handle, + rnnDesc.descriptor, + seqLength, + xDescs.descriptors, + &reserveSize + ) + ); + return reserveSize; + } -void setCudnnRnnMathType(const Tensor& input, const RNNDescriptor& rnnDesc) { - if (input.type() == fl::dtype::f16) { - CUDNN_CHECK_ERR(cudnnSetRNNMatrixMathType( - rnnDesc.descriptor, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION)); - } else { - CUDNN_CHECK_ERR( - cudnnSetRNNMatrixMathType(rnnDesc.descriptor, CUDNN_DEFAULT_MATH)); - } -} + void setCudnnRnnMathType(const Tensor& input, const RNNDescriptor& rnnDesc) { + if(input.type() == fl::dtype::f16) { + CUDNN_CHECK_ERR( + cudnnSetRNNMatrixMathType( + rnnDesc.descriptor, + CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION + ) + ); + } else { + CUDNN_CHECK_ERR( + cudnnSetRNNMatrixMathType(rnnDesc.descriptor, CUDNN_DEFAULT_MATH) + ); + } + } -struct CudnnRnnAutogradPayload : public detail::AutogradPayloadData { - Tensor reserveSpace; -}; + struct CudnnRnnAutogradPayload : public detail::AutogradPayloadData { + Tensor reserveSpace; + }; } // namespace @@ -67,141 +84,156 @@ std::tuple CudnnAutogradExtension::rnn( const RnnMode mode, const bool bidirectional, const float dropProb, - std::shared_ptr autogradPayload) { - FL_TENSOR_DTYPES_MATCH_CHECK(input, hiddenStateIn, cellStateIn, weights); - - bool train = (autogradPayload != nullptr); - auto payload = std::make_shared(); - if (train) { - autogradPayload->data = 
payload; - } - - Tensor x = input.asContiguousTensor(); - Tensor hiddenState = hiddenStateIn.asContiguousTensor(); - Tensor cellState = cellStateIn.asContiguousTensor(); - - DropoutDescriptor dropout(dropProb); - RNNDescriptor rnnDesc( - input.type(), hiddenSize, numLayers, mode, bidirectional, dropout); - setCudnnRnnMathType(input, rnnDesc); - - auto dims = x.shape(); - int inputSize = dims[0]; - int batchSize = dims.ndim() < 2 ? 1 : dims[1]; - int seqLength = dims.ndim() < 3 ? 1 : dims[2]; - - int totalLayers = numLayers * (bidirectional ? 2 : 1); - int outSize = hiddenSize * (bidirectional ? 2 : 1); - - TensorDescriptorArray xDescs( - seqLength, x.type(), {1, 1, inputSize, batchSize}); - - if (!hiddenState.isEmpty()) { - auto hxDims = hiddenState.shape(); - int hxHiddenSize = hxDims[0]; - int hxBatchSize = hiddenState.ndim() < 2 ? 1 : hxDims[1]; - int hxTotalLayers = hiddenState.ndim() < 3 ? 1 : hxDims[2]; - - if (!(hxHiddenSize == hiddenSize && hxBatchSize == batchSize && - hxTotalLayers == totalLayers)) { - throw std::invalid_argument("invalid hidden state dims for RNN"); + std::shared_ptr autogradPayload +) { + FL_TENSOR_DTYPES_MATCH_CHECK(input, hiddenStateIn, cellStateIn, weights); + + bool train = (autogradPayload != nullptr); + auto payload = std::make_shared(); + if(train) { + autogradPayload->data = payload; } - } - - if (!cellState.isEmpty() && - !(mode == RnnMode::LSTM && cellState.dim(0) == hiddenSize && - cellState.dim(1) == batchSize && cellState.dim(2) == totalLayers)) { - throw std::invalid_argument("invalid cell state dims for RNN"); - } - - Shape hDims = {1, hiddenSize, batchSize, totalLayers}; - TensorDescriptor hxDesc(x.type(), hDims); - TensorDescriptor cxDesc(x.type(), hDims); - - auto handle = getCudnnHandle(); - const auto& cudnnStream = getCudnnStream(); - - size_t paramSize; - CUDNN_CHECK_ERR(cudnnGetRNNParamsSize( - handle, - rnnDesc.descriptor, - xDescs.descriptors[0], - ¶mSize, - cudnnMapToType(weights.type()))); - if (paramSize != 
weights.bytes()) { - throw std::invalid_argument( - "invalid # of parameters or wrong input shape for RNN"); - } - FilterDescriptor wDesc(weights); - - Tensor y({outSize, batchSize, seqLength}, input.type()); - TensorDescriptorArray yDesc(seqLength, y.type(), {1, 1, outSize, batchSize}); - - Tensor hy({hiddenSize, batchSize, totalLayers}, x.type()); - TensorDescriptor hyDesc(x.type(), hDims); - - Tensor cy; - if (mode == RnnMode::LSTM) { - cy = Tensor(hy.shape(), x.type()); - } - - TensorDescriptor cyDesc(x.type(), hDims); - - size_t workspaceSize = - getWorkspaceSize(handle, rnnDesc, seqLength, xDescs); - size_t reserveSize = - getReserveSize(handle, rnnDesc, seqLength, xDescs); - - Tensor workspace({static_cast(workspaceSize)}, fl::dtype::b8); - // Space must be reused between forward and backward for cuDNN - payload->reserveSpace = - Tensor({static_cast(reserveSize)}, fl::dtype::b8); - - { - auto contiguousX = x.asContiguousTensor(); - auto contiguousWeights = weights.asContiguousTensor(); - DevicePtr xRaw(contiguousX); - DevicePtr hxRaw(hiddenState); - DevicePtr cxRaw(cellState); - DevicePtr wRaw(contiguousWeights); - DevicePtr yRaw(y); - DevicePtr hyRaw(hy); - DevicePtr cyRaw(cy); - DevicePtr workspaceRaw(workspace); - DevicePtr reserveSpaceRaw(payload->reserveSpace); - // ensure cudnn compute stream waits on input/output tensor streams - relativeSync(cudnnStream, { - contiguousX, hiddenState, cellState, contiguousWeights, y, hy, cy, - workspace, payload->reserveSpace, - }); - - CUDNN_CHECK_ERR(cudnnRNNForwardTraining( - handle, - rnnDesc.descriptor, - seqLength, - xDescs.descriptors, - xRaw.get(), - hxDesc.descriptor, - hxRaw.get(), - cxDesc.descriptor, - cxRaw.get(), - wDesc.descriptor, - wRaw.get(), - yDesc.descriptors, - yRaw.get(), - hyDesc.descriptor, - hyRaw.get(), - cyDesc.descriptor, - cyRaw.get(), - workspaceRaw.get(), - workspaceSize, - reserveSpaceRaw.get(), - reserveSize)); - } - - // ensure output tensor streams wait on cudnn compute stream - 
relativeSync({y, hy, cy}, cudnnStream); - return std::make_tuple(y, hy, cy); + + Tensor x = input.asContiguousTensor(); + Tensor hiddenState = hiddenStateIn.asContiguousTensor(); + Tensor cellState = cellStateIn.asContiguousTensor(); + + DropoutDescriptor dropout(dropProb); + RNNDescriptor rnnDesc( + input.type(), hiddenSize, numLayers, mode, bidirectional, dropout); + setCudnnRnnMathType(input, rnnDesc); + + auto dims = x.shape(); + int inputSize = dims[0]; + int batchSize = dims.ndim() < 2 ? 1 : dims[1]; + int seqLength = dims.ndim() < 3 ? 1 : dims[2]; + + int totalLayers = numLayers * (bidirectional ? 2 : 1); + int outSize = hiddenSize * (bidirectional ? 2 : 1); + + TensorDescriptorArray xDescs( + seqLength, x.type(), {1, 1, inputSize, batchSize}); + + if(!hiddenState.isEmpty()) { + auto hxDims = hiddenState.shape(); + int hxHiddenSize = hxDims[0]; + int hxBatchSize = hiddenState.ndim() < 2 ? 1 : hxDims[1]; + int hxTotalLayers = hiddenState.ndim() < 3 ? 1 : hxDims[2]; + + if( + !(hxHiddenSize == hiddenSize && hxBatchSize == batchSize + && hxTotalLayers == totalLayers) + ) { + throw std::invalid_argument("invalid hidden state dims for RNN"); + } + } + + if( + !cellState.isEmpty() + && !(mode == RnnMode::LSTM && cellState.dim(0) == hiddenSize + && cellState.dim(1) == batchSize && cellState.dim(2) == totalLayers) + ) { + throw std::invalid_argument("invalid cell state dims for RNN"); + } + + Shape hDims = {1, hiddenSize, batchSize, totalLayers}; + TensorDescriptor hxDesc(x.type(), hDims); + TensorDescriptor cxDesc(x.type(), hDims); + + auto handle = getCudnnHandle(); + const auto& cudnnStream = getCudnnStream(); + + size_t paramSize; + CUDNN_CHECK_ERR( + cudnnGetRNNParamsSize( + handle, + rnnDesc.descriptor, + xDescs.descriptors[0], + ¶mSize, + cudnnMapToType(weights.type()) + ) + ); + if(paramSize != weights.bytes()) { + throw std::invalid_argument( + "invalid # of parameters or wrong input shape for RNN" + ); + } + FilterDescriptor wDesc(weights); + + Tensor 
y({outSize, batchSize, seqLength}, input.type()); + TensorDescriptorArray yDesc(seqLength, y.type(), {1, 1, outSize, batchSize}); + + Tensor hy({hiddenSize, batchSize, totalLayers}, x.type()); + TensorDescriptor hyDesc(x.type(), hDims); + + Tensor cy; + if(mode == RnnMode::LSTM) { + cy = Tensor(hy.shape(), x.type()); + } + + TensorDescriptor cyDesc(x.type(), hDims); + + size_t workspaceSize = + getWorkspaceSize(handle, rnnDesc, seqLength, xDescs); + size_t reserveSize = + getReserveSize(handle, rnnDesc, seqLength, xDescs); + + Tensor workspace({static_cast(workspaceSize)}, fl::dtype::b8); + // Space must be reused between forward and backward for cuDNN + payload->reserveSpace = + Tensor({static_cast(reserveSize)}, fl::dtype::b8); + + { + auto contiguousX = x.asContiguousTensor(); + auto contiguousWeights = weights.asContiguousTensor(); + DevicePtr xRaw(contiguousX); + DevicePtr hxRaw(hiddenState); + DevicePtr cxRaw(cellState); + DevicePtr wRaw(contiguousWeights); + DevicePtr yRaw(y); + DevicePtr hyRaw(hy); + DevicePtr cyRaw(cy); + DevicePtr workspaceRaw(workspace); + DevicePtr reserveSpaceRaw(payload->reserveSpace); + // ensure cudnn compute stream waits on input/output tensor streams + relativeSync( + cudnnStream, + { + contiguousX, hiddenState, cellState, contiguousWeights, y, hy, cy, + workspace, payload->reserveSpace, + } + ); + + CUDNN_CHECK_ERR( + cudnnRNNForwardTraining( + handle, + rnnDesc.descriptor, + seqLength, + xDescs.descriptors, + xRaw.get(), + hxDesc.descriptor, + hxRaw.get(), + cxDesc.descriptor, + cxRaw.get(), + wDesc.descriptor, + wRaw.get(), + yDesc.descriptors, + yRaw.get(), + hyDesc.descriptor, + hyRaw.get(), + cyDesc.descriptor, + cyRaw.get(), + workspaceRaw.get(), + workspaceSize, + reserveSpaceRaw.get(), + reserveSize + ) + ); + } + + // ensure output tensor streams wait on cudnn compute stream + relativeSync({y, hy, cy}, cudnnStream); + return std::make_tuple(y, hy, cy); } std::tuple CudnnAutogradExtension::rnnBackward( @@ -216,160 
+248,174 @@ std::tuple CudnnAutogradExtension::rnnBackward( const RnnMode mode, const bool bidirectional, const float dropProb, - std::shared_ptr autogradPayload) { - if (!autogradPayload) { - throw std::invalid_argument( - "CudnnAutogradExtension::rnnBackward given null detail::AutogradPayload"); - } - auto payload = - std::static_pointer_cast(autogradPayload->data); - - auto handle = getCudnnHandle(); - const auto& cudnnStream = getCudnnStream(); - - auto x = input.asContiguousTensor(); - auto& y = output; - - auto dims = x.shape(); - int inputSize = dims[0]; - int batchSize = dims.ndim() < 2 ? 1 : dims[1]; - int seqLength = dims.ndim() < 3 ? 1 : dims[2]; - int totalLayers = numLayers * (bidirectional ? 2 : 1); - int outSize = hiddenSize * (bidirectional ? 2 : 1); - - DropoutDescriptor dropout(dropProb); - RNNDescriptor rnnDesc( - input.type(), hiddenSize, numLayers, mode, bidirectional, dropout); - setCudnnRnnMathType(input, rnnDesc); - - TensorDescriptorArray yDesc(seqLength, y.type(), {1, 1, outSize, batchSize}); - TensorDescriptorArray dyDesc(seqLength, y.type(), {1, 1, outSize, batchSize}); - - Shape hDims = {1, hiddenSize, batchSize, totalLayers}; - TensorDescriptor dhyDesc(x.type(), hDims); - TensorDescriptor dcyDesc(x.type(), hDims); - TensorDescriptor hxDesc(x.type(), hDims); - TensorDescriptor cxDesc(x.type(), hDims); - - Tensor dhx(hiddenState.shape(), hiddenState.type()); - Tensor dcx(cellState.shape(), cellState.type()); - TensorDescriptor dhxDesc(x.type(), hDims); - TensorDescriptor dcxDesc(x.type(), hDims); - - FilterDescriptor wDesc(weights); - - Tensor dx(input.shape(), input.type()); - TensorDescriptorArray dxDescs( - seqLength, dx.type(), {1, 1, inputSize, batchSize}); - - size_t workspaceSize = - getWorkspaceSize(handle, rnnDesc, seqLength, dxDescs); - Tensor workspace({static_cast(workspaceSize)}, fl::dtype::b8); - - auto& dy = gradData->dy; - if (dy.isEmpty()) { - dy = fl::full(y.shape(), 0.0, y.type()); - } - auto& dhy = gradData->dhy; - 
auto& dcy = gradData->dcy; - - DevicePtr yRaw(output); - DevicePtr workspaceRaw(workspace); - DevicePtr reserveSpaceRaw(payload->reserveSpace); - // ensure cudnn compute stream waits on input/output tensor streams - relativeSync(cudnnStream, {output, workspace, payload->reserveSpace}); - - { - DevicePtr dyRaw(dy); // Has to be set to 0 if empty - DevicePtr dhyRaw(dhy); - DevicePtr dcyRaw(dcy); - - DevicePtr wRaw(weights); - - DevicePtr hxRaw(hiddenState); - DevicePtr cxRaw(cellState); - - DevicePtr dxRaw(dx); - DevicePtr dhxRaw(dhx); - DevicePtr dcxRaw(dcx); - // ensure cudnn compute stream waits on input/output tensor streams - relativeSync( - cudnnStream, - {dy, dhy, dcy, weights, hiddenState, cellState, dx, dhx, dcx}); - - /* We need to update reserveSpace even if we just want the - * weight gradients. */ - CUDNN_CHECK_ERR(cudnnRNNBackwardData( - handle, - rnnDesc.descriptor, - seqLength, - yDesc.descriptors, - yRaw.get(), - dyDesc.descriptors, - dyRaw.get(), - dhyDesc.descriptor, - dhyRaw.get(), - dcyDesc.descriptor, - dcyRaw.get(), - wDesc.descriptor, - wRaw.get(), - hxDesc.descriptor, - hxRaw.get(), - cxDesc.descriptor, - cxRaw.get(), - dxDescs.descriptors, - dxRaw.get(), - dhxDesc.descriptor, - dhxRaw.get(), - dcxDesc.descriptor, - dcxRaw.get(), - workspaceRaw.get(), - workspaceSize, - reserveSpaceRaw.get(), - payload->reserveSpace.bytes())); - } - - if (input.type() == fl::dtype::f16) { - CUDNN_CHECK_ERR(cudnnSetRNNMatrixMathType( - rnnDesc.descriptor, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION)); - } else { - CUDNN_CHECK_ERR( - cudnnSetRNNMatrixMathType(rnnDesc.descriptor, CUDNN_DEFAULT_MATH)); - } - TensorDescriptorArray xDescs( - seqLength, x.type(), {1, 1, inputSize, batchSize}); - Tensor dw = fl::full(weights.shape(), 0, weights.type()); - - FilterDescriptor dwDesc(dw); - - { - DevicePtr xRaw(x); - DevicePtr dwRaw(dw); - DevicePtr hxRaw(hiddenState); + std::shared_ptr autogradPayload +) { + if(!autogradPayload) { + throw std::invalid_argument( + 
"CudnnAutogradExtension::rnnBackward given null detail::AutogradPayload" + ); + } + auto payload = + std::static_pointer_cast(autogradPayload->data); + + auto handle = getCudnnHandle(); + const auto& cudnnStream = getCudnnStream(); + + auto x = input.asContiguousTensor(); + auto& y = output; + + auto dims = x.shape(); + int inputSize = dims[0]; + int batchSize = dims.ndim() < 2 ? 1 : dims[1]; + int seqLength = dims.ndim() < 3 ? 1 : dims[2]; + int totalLayers = numLayers * (bidirectional ? 2 : 1); + int outSize = hiddenSize * (bidirectional ? 2 : 1); + + DropoutDescriptor dropout(dropProb); + RNNDescriptor rnnDesc( + input.type(), hiddenSize, numLayers, mode, bidirectional, dropout); + setCudnnRnnMathType(input, rnnDesc); + + TensorDescriptorArray yDesc(seqLength, y.type(), {1, 1, outSize, batchSize}); + TensorDescriptorArray dyDesc(seqLength, y.type(), {1, 1, outSize, batchSize}); + + Shape hDims = {1, hiddenSize, batchSize, totalLayers}; + TensorDescriptor dhyDesc(x.type(), hDims); + TensorDescriptor dcyDesc(x.type(), hDims); + TensorDescriptor hxDesc(x.type(), hDims); + TensorDescriptor cxDesc(x.type(), hDims); + + Tensor dhx(hiddenState.shape(), hiddenState.type()); + Tensor dcx(cellState.shape(), cellState.type()); + TensorDescriptor dhxDesc(x.type(), hDims); + TensorDescriptor dcxDesc(x.type(), hDims); + + FilterDescriptor wDesc(weights); + + Tensor dx(input.shape(), input.type()); + TensorDescriptorArray dxDescs( + seqLength, dx.type(), {1, 1, inputSize, batchSize}); + + size_t workspaceSize = + getWorkspaceSize(handle, rnnDesc, seqLength, dxDescs); + Tensor workspace({static_cast(workspaceSize)}, fl::dtype::b8); + + auto& dy = gradData->dy; + if(dy.isEmpty()) { + dy = fl::full(y.shape(), 0.0, y.type()); + } + auto& dhy = gradData->dhy; + auto& dcy = gradData->dcy; + + DevicePtr yRaw(output); + DevicePtr workspaceRaw(workspace); + DevicePtr reserveSpaceRaw(payload->reserveSpace); // ensure cudnn compute stream waits on input/output tensor streams - 
relativeSync(cudnnStream, {x, dw, hiddenState}); - - CUDNN_CHECK_ERR(cudnnRNNBackwardWeights( - handle, - rnnDesc.descriptor, - seqLength, - xDescs.descriptors, - xRaw.get(), - hxDesc.descriptor, - hxRaw.get(), - yDesc.descriptors, - yRaw.get(), - workspaceRaw.get(), - workspaceSize, - dwDesc.descriptor, - dwRaw.get(), - reserveSpaceRaw.get(), - payload->reserveSpace.bytes())); - } - - // ensure output tensor streams wait on cudnn compute stream - relativeSync({dx, dhx, dcx, dw}, cudnnStream); - return std::make_tuple(dx, dhx, dcx, dw); + relativeSync(cudnnStream, {output, workspace, payload->reserveSpace}); + + { + DevicePtr dyRaw(dy); // Has to be set to 0 if empty + DevicePtr dhyRaw(dhy); + DevicePtr dcyRaw(dcy); + + DevicePtr wRaw(weights); + + DevicePtr hxRaw(hiddenState); + DevicePtr cxRaw(cellState); + + DevicePtr dxRaw(dx); + DevicePtr dhxRaw(dhx); + DevicePtr dcxRaw(dcx); + // ensure cudnn compute stream waits on input/output tensor streams + relativeSync( + cudnnStream, + {dy, dhy, dcy, weights, hiddenState, cellState, dx, dhx, dcx} + ); + + /* We need to update reserveSpace even if we just want the + * weight gradients. 
*/ + CUDNN_CHECK_ERR( + cudnnRNNBackwardData( + handle, + rnnDesc.descriptor, + seqLength, + yDesc.descriptors, + yRaw.get(), + dyDesc.descriptors, + dyRaw.get(), + dhyDesc.descriptor, + dhyRaw.get(), + dcyDesc.descriptor, + dcyRaw.get(), + wDesc.descriptor, + wRaw.get(), + hxDesc.descriptor, + hxRaw.get(), + cxDesc.descriptor, + cxRaw.get(), + dxDescs.descriptors, + dxRaw.get(), + dhxDesc.descriptor, + dhxRaw.get(), + dcxDesc.descriptor, + dcxRaw.get(), + workspaceRaw.get(), + workspaceSize, + reserveSpaceRaw.get(), + payload->reserveSpace.bytes() + ) + ); + } + + if(input.type() == fl::dtype::f16) { + CUDNN_CHECK_ERR( + cudnnSetRNNMatrixMathType( + rnnDesc.descriptor, + CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION + ) + ); + } else { + CUDNN_CHECK_ERR( + cudnnSetRNNMatrixMathType(rnnDesc.descriptor, CUDNN_DEFAULT_MATH) + ); + } + TensorDescriptorArray xDescs( + seqLength, x.type(), {1, 1, inputSize, batchSize}); + Tensor dw = fl::full(weights.shape(), 0, weights.type()); + + FilterDescriptor dwDesc(dw); + + { + DevicePtr xRaw(x); + DevicePtr dwRaw(dw); + DevicePtr hxRaw(hiddenState); + // ensure cudnn compute stream waits on input/output tensor streams + relativeSync(cudnnStream, {x, dw, hiddenState}); + + CUDNN_CHECK_ERR( + cudnnRNNBackwardWeights( + handle, + rnnDesc.descriptor, + seqLength, + xDescs.descriptors, + xRaw.get(), + hxDesc.descriptor, + hxRaw.get(), + yDesc.descriptors, + yRaw.get(), + workspaceRaw.get(), + workspaceSize, + dwDesc.descriptor, + dwRaw.get(), + reserveSpaceRaw.get(), + payload->reserveSpace.bytes() + ) + ); + } + + // ensure output tensor streams wait on cudnn compute stream + relativeSync({dx, dhx, dcx, dw}, cudnnStream); + return std::make_tuple(dx, dhx, dcx, dw); } } // namespace fl diff --git a/flashlight/fl/autograd/tensor/backend/onednn/BatchNorm.cpp b/flashlight/fl/autograd/tensor/backend/onednn/BatchNorm.cpp index f495e24..285ec91 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/BatchNorm.cpp +++ 
b/flashlight/fl/autograd/tensor/backend/onednn/BatchNorm.cpp @@ -21,67 +21,70 @@ namespace fl { namespace { // Flashlight accept HWCN order according to docs -constexpr size_t kHIdx = 0; -constexpr size_t kWIdx = 1; -constexpr size_t kChannelSizeIdx = 2; -constexpr size_t kBatchSizeIdx = 3; - -constexpr auto formatNCHW = dnnl::memory::format_tag::nchw; -constexpr auto formatX = dnnl::memory::format_tag::x; - -int getNfeatures(const Shape& inputShape, const std::vector& axes) { - int nfeatures = 1; - for (auto ax : axes) { - nfeatures *= inputShape.dim(ax); - } - return nfeatures; -} + constexpr size_t kHIdx = 0; + constexpr size_t kWIdx = 1; + constexpr size_t kChannelSizeIdx = 2; + constexpr size_t kBatchSizeIdx = 3; + + constexpr auto formatNCHW = dnnl::memory::format_tag::nchw; + constexpr auto formatX = dnnl::memory::format_tag::x; + + int getNfeatures(const Shape& inputShape, const std::vector& axes) { + int nfeatures = 1; + for(auto ax : axes) { + nfeatures *= inputShape.dim(ax); + } + return nfeatures; + } -dnnl::memory::dims getInputOutputDims( - const int minAxis, - const int maxAxis, - const Tensor& input, - const int nfeatures) { - Shape inDescDims; - if (minAxis == 0) { - inDescDims = Shape( - {1, - 1, - nfeatures, - static_cast(input.elements() / nfeatures)}); - } else { - int batchsz = 1; - for (int i = maxAxis + 1; i < input.ndim(); ++i) { - batchsz *= input.dim(i); + dnnl::memory::dims getInputOutputDims( + const int minAxis, + const int maxAxis, + const Tensor& input, + const int nfeatures + ) { + Shape inDescDims; + if(minAxis == 0) { + inDescDims = Shape( + {1, + 1, + nfeatures, + static_cast(input.elements() / nfeatures)} + ); + } else { + int batchsz = 1; + for(int i = maxAxis + 1; i < input.ndim(); ++i) { + batchsz *= input.dim(i); + } + inDescDims = Shape( + {1, + static_cast(input.elements() / (nfeatures * batchsz)), + nfeatures, + batchsz} + ); + } + + dnnl::memory::dims inputOutputDims = { + inDescDims[kBatchSizeIdx], + 
inDescDims[kChannelSizeIdx], + inDescDims[kHIdx], + inDescDims[kWIdx]}; + + return inputOutputDims; } - inDescDims = Shape( - {1, - static_cast(input.elements() / (nfeatures * batchsz)), - nfeatures, - batchsz}); - } - - dnnl::memory::dims inputOutputDims = { - inDescDims[kBatchSizeIdx], - inDescDims[kChannelSizeIdx], - inDescDims[kHIdx], - inDescDims[kWIdx]}; - - return inputOutputDims; -} -struct OneDnnBatchNormPayload : detail::AutogradPayloadData { - dnnl::batch_normalization_forward::primitive_desc fwdPrimDesc; - Tensor weights; // combined weight and bias - Tensor bias; - dnnl::memory::dims weightsDims; - dnnl::memory::dims biasDims; - dnnl::memory::desc outputMemoryDescriptor; - dnnl::memory meanMemory; - dnnl::memory varMemory; - dnnl::memory weightsMemory; - dnnl::memory biasMemory; -}; + struct OneDnnBatchNormPayload : detail::AutogradPayloadData { + dnnl::batch_normalization_forward::primitive_desc fwdPrimDesc; + Tensor weights; // combined weight and bias + Tensor bias; + dnnl::memory::dims weightsDims; + dnnl::memory::dims biasDims; + dnnl::memory::desc outputMemoryDescriptor; + dnnl::memory meanMemory; + dnnl::memory varMemory; + dnnl::memory weightsMemory; + dnnl::memory biasMemory; + }; } // namespace @@ -97,106 +100,108 @@ Tensor OneDnnAutogradExtension::batchnorm( const bool train, const double momentum, const double epsilon, - std::shared_ptr autogradPayload) { - if (momentum != 0.) 
{ - throw std::runtime_error("OneDNN batchnorm op doesn't support momentum."); - } - if (input.type() == fl::dtype::f16) { - throw std::runtime_error("OneDNN batchnorm op - f16 inputs not supported."); - } - - auto payload = std::make_shared(); - if (train && autogradPayload) { - autogradPayload->data = payload; - } - - auto output = Tensor(input.shape(), input.type()); - int nfeatures = getNfeatures(input.shape(), axes); - - if (runningVar.isEmpty()) { - runningVar = fl::full({nfeatures}, 1., input.type()); - } - - if (runningMean.isEmpty()) { - runningMean = fl::full({nfeatures}, 0., input.type()); - } - - // Check if axes are valid - auto maxAxis = *std::max_element(axes.begin(), axes.end()); - auto minAxis = *std::min_element(axes.begin(), axes.end()); - bool axesContinuous = (axes.size() == (maxAxis - minAxis + 1)); - if (!axesContinuous) { - throw std::invalid_argument("axis array should be continuous"); - } - - auto& dnnlEngine = detail::DnnlEngine::getInstance().getEngine(); - - // Prepare combined weights - // If empty, user specifies affine to false. Both not trainable. - auto weightNonempty = - weight.isEmpty() ? fl::full({nfeatures}, 1., fl::dtype::f32) : weight; - auto biasNonempty = - bias.isEmpty() ? fl::full({nfeatures}, 0., fl::dtype::f32) : bias; - - // DNNL only accepts weight and bias as a combined input. 
- // https://git.io/JLn9X - payload->weights = weightNonempty; - payload->bias = biasNonempty; - payload->weightsDims = detail::convertToDnnlDims({nfeatures}); - payload->biasDims = detail::convertToDnnlDims({nfeatures}); - auto inputOutputDims = getInputOutputDims(minAxis, maxAxis, input, nfeatures); - - // Memory for forward - const detail::DnnlMemoryWrapper inputMemory( - input, inputOutputDims, formatNCHW); - const detail::DnnlMemoryWrapper outputMemory( - output, inputOutputDims, formatNCHW); - const detail::DnnlMemoryWrapper meanMemory( - runningMean, {runningMean.dim(0)}, formatX); - const detail::DnnlMemoryWrapper varMemory( - runningVar, {runningVar.dim(0)}, formatX); - // combined scale and shift (weight and bias) - const detail::DnnlMemoryWrapper weightsMemory( - payload->weights, payload->weightsDims, formatX); - const detail::DnnlMemoryWrapper biasMemory( - payload->bias, payload->biasDims, formatX); - payload->meanMemory = meanMemory.getMemory(); - payload->varMemory = varMemory.getMemory(); - payload->weightsMemory = weightsMemory.getMemory(); - payload->biasMemory = biasMemory.getMemory(); - // Primitives and descriptors - auto kind = train ? dnnl::prop_kind::forward_training - : dnnl::prop_kind::forward_inference; - // https://fburl.com/6latj733 - dnnl::normalization_flags flag = train - ? 
dnnl::normalization_flags::none - : dnnl::normalization_flags::use_global_stats; - flag = flag | dnnl::normalization_flags::use_scale | - dnnl::normalization_flags::use_shift; - payload->fwdPrimDesc = dnnl::batch_normalization_forward::primitive_desc( - dnnlEngine, - kind, - inputMemory.getDescriptor(), - outputMemory.getDescriptor(), - epsilon, - flag); - payload->outputMemoryDescriptor = outputMemory.getDescriptor(); - auto bn = dnnl::batch_normalization_forward(payload->fwdPrimDesc); - std::unordered_map bnFwdArgs = { - {DNNL_ARG_SRC, inputMemory.getMemory()}, - {DNNL_ARG_MEAN, meanMemory.getMemory()}, - {DNNL_ARG_VARIANCE, varMemory.getMemory()}, - {DNNL_ARG_DST, outputMemory.getMemory()}, - {DNNL_ARG_SCALE, weightsMemory.getMemory()}, - {DNNL_ARG_SHIFT, biasMemory.getMemory()}}; - - // Execute - std::vector network; - std::vector> fwdArgs = {bnFwdArgs}; - network.push_back(bn); - detail::executeNetwork(network, fwdArgs); - - return output; + std::shared_ptr autogradPayload +) { + if(momentum != 0.) 
{ + throw std::runtime_error("OneDNN batchnorm op doesn't support momentum."); + } + if(input.type() == fl::dtype::f16) { + throw std::runtime_error("OneDNN batchnorm op - f16 inputs not supported."); + } + + auto payload = std::make_shared(); + if(train && autogradPayload) { + autogradPayload->data = payload; + } + + auto output = Tensor(input.shape(), input.type()); + int nfeatures = getNfeatures(input.shape(), axes); + + if(runningVar.isEmpty()) { + runningVar = fl::full({nfeatures}, 1., input.type()); + } + + if(runningMean.isEmpty()) { + runningMean = fl::full({nfeatures}, 0., input.type()); + } + + // Check if axes are valid + auto maxAxis = *std::max_element(axes.begin(), axes.end()); + auto minAxis = *std::min_element(axes.begin(), axes.end()); + bool axesContinuous = (axes.size() == (maxAxis - minAxis + 1)); + if(!axesContinuous) { + throw std::invalid_argument("axis array should be continuous"); + } + + auto& dnnlEngine = detail::DnnlEngine::getInstance().getEngine(); + + // Prepare combined weights + // If empty, user specifies affine to false. Both not trainable. + auto weightNonempty = + weight.isEmpty() ? fl::full({nfeatures}, 1., fl::dtype::f32) : weight; + auto biasNonempty = + bias.isEmpty() ? fl::full({nfeatures}, 0., fl::dtype::f32) : bias; + + // DNNL only accepts weight and bias as a combined input. 
+ // https://git.io/JLn9X + payload->weights = weightNonempty; + payload->bias = biasNonempty; + payload->weightsDims = detail::convertToDnnlDims({nfeatures}); + payload->biasDims = detail::convertToDnnlDims({nfeatures}); + auto inputOutputDims = getInputOutputDims(minAxis, maxAxis, input, nfeatures); + + // Memory for forward + const detail::DnnlMemoryWrapper inputMemory( + input, inputOutputDims, formatNCHW); + const detail::DnnlMemoryWrapper outputMemory( + output, inputOutputDims, formatNCHW); + const detail::DnnlMemoryWrapper meanMemory( + runningMean, {runningMean.dim(0)}, formatX); + const detail::DnnlMemoryWrapper varMemory( + runningVar, {runningVar.dim(0)}, formatX); + // combined scale and shift (weight and bias) + const detail::DnnlMemoryWrapper weightsMemory( + payload->weights, payload->weightsDims, formatX); + const detail::DnnlMemoryWrapper biasMemory( + payload->bias, payload->biasDims, formatX); + payload->meanMemory = meanMemory.getMemory(); + payload->varMemory = varMemory.getMemory(); + payload->weightsMemory = weightsMemory.getMemory(); + payload->biasMemory = biasMemory.getMemory(); + // Primitives and descriptors + auto kind = train ? dnnl::prop_kind::forward_training + : dnnl::prop_kind::forward_inference; + // https://fburl.com/6latj733 + dnnl::normalization_flags flag = train + ? 
dnnl::normalization_flags::none + : dnnl::normalization_flags::use_global_stats; + flag = flag | dnnl::normalization_flags::use_scale + | dnnl::normalization_flags::use_shift; + payload->fwdPrimDesc = dnnl::batch_normalization_forward::primitive_desc( + dnnlEngine, + kind, + inputMemory.getDescriptor(), + outputMemory.getDescriptor(), + epsilon, + flag + ); + payload->outputMemoryDescriptor = outputMemory.getDescriptor(); + auto bn = dnnl::batch_normalization_forward(payload->fwdPrimDesc); + std::unordered_map bnFwdArgs = { + {DNNL_ARG_SRC, inputMemory.getMemory()}, + {DNNL_ARG_MEAN, meanMemory.getMemory()}, + {DNNL_ARG_VARIANCE, varMemory.getMemory()}, + {DNNL_ARG_DST, outputMemory.getMemory()}, + {DNNL_ARG_SCALE, weightsMemory.getMemory()}, + {DNNL_ARG_SHIFT, biasMemory.getMemory()}}; + + // Execute + std::vector network; + std::vector> fwdArgs = {bnFwdArgs}; + network.push_back(bn); + detail::executeNetwork(network, fwdArgs); + + return output; } std::tuple OneDnnAutogradExtension::batchnormBackward( @@ -208,74 +213,76 @@ std::tuple OneDnnAutogradExtension::batchnormBackward( const std::vector& axes, const bool train, const float epsilon, - std::shared_ptr autogradPayload) { - if (!autogradPayload) { - throw std::invalid_argument( - "OneDnnAutogradExtension::pool2dBackward given null detail::AutogradPayload"); - } - auto payload = - std::static_pointer_cast(autogradPayload->data); - - auto& dnnlEngine = detail::DnnlEngine::getInstance().getEngine(); - - auto maxAxis = *std::max_element(axes.begin(), axes.end()); - auto minAxis = *std::min_element(axes.begin(), axes.end()); - const bool axesContinuous = (axes.size() == (maxAxis - minAxis + 1)); - if (!axesContinuous) { - throw std::invalid_argument("axis array should be continuous"); - } - - const int nfeatures = getNfeatures(input.shape(), axes); - auto inputOutputDims = getInputOutputDims(minAxis, maxAxis, input, nfeatures); - - auto gradInput = Tensor(input.shape(), input.type()); - auto gradWeights = 
Tensor(payload->weights.shape(), payload->weights.type()); - auto gradBias = Tensor(payload->bias.shape(), payload->bias.type()); - - const detail::DnnlMemoryWrapper inputMemory( - input, inputOutputDims, formatNCHW); - - // Memory for gradient computation - const detail::DnnlMemoryWrapper gradOutputMem( - gradOutput, inputOutputDims, formatNCHW); - const detail::DnnlMemoryWrapper gradInputMem( - gradInput, inputOutputDims, formatNCHW); - const detail::DnnlMemoryWrapper gradWeightsMem( - gradWeights, payload->weightsDims, formatX); - const detail::DnnlMemoryWrapper gradBiasMem( - gradBias, payload->biasDims, formatX); - - // Primitives and descriptors - auto bwdPrimitiveDesc = dnnl::batch_normalization_backward::primitive_desc( - dnnlEngine, - dnnl::prop_kind::backward, - gradOutputMem.getDescriptor(), - payload->outputMemoryDescriptor, - gradOutputMem.getDescriptor(), - epsilon, - dnnl::normalization_flags::use_scale | - dnnl::normalization_flags::use_shift, - payload->fwdPrimDesc // hint - ); - auto bwdPrim = - std::make_shared(bwdPrimitiveDesc); - // Execute - std::vector networkBackwards; - std::vector> bwdArgs = { - {{DNNL_ARG_SRC, inputMemory.getMemory()}, - {DNNL_ARG_MEAN, payload->meanMemory}, - {DNNL_ARG_VARIANCE, payload->varMemory}, - {DNNL_ARG_SCALE, payload->weightsMemory}, - //TODO dnnl_arg_shift was here, check if something can be optimized bc it's not needed - {DNNL_ARG_DIFF_SRC, gradInputMem.getMemory()}, - {DNNL_ARG_DIFF_DST, gradOutputMem.getMemory()}, - {DNNL_ARG_DIFF_SCALE, gradWeightsMem.getMemory()}, - {DNNL_ARG_DIFF_SHIFT, gradBiasMem.getMemory()}}}; - - networkBackwards.push_back(*bwdPrim); - detail::executeNetwork(networkBackwards, bwdArgs); - - return {gradInput, gradWeights, gradBias}; + std::shared_ptr autogradPayload +) { + if(!autogradPayload) { + throw std::invalid_argument( + "OneDnnAutogradExtension::pool2dBackward given null detail::AutogradPayload" + ); + } + auto payload = + std::static_pointer_cast(autogradPayload->data); + + 
auto& dnnlEngine = detail::DnnlEngine::getInstance().getEngine(); + + auto maxAxis = *std::max_element(axes.begin(), axes.end()); + auto minAxis = *std::min_element(axes.begin(), axes.end()); + const bool axesContinuous = (axes.size() == (maxAxis - minAxis + 1)); + if(!axesContinuous) { + throw std::invalid_argument("axis array should be continuous"); + } + + const int nfeatures = getNfeatures(input.shape(), axes); + auto inputOutputDims = getInputOutputDims(minAxis, maxAxis, input, nfeatures); + + auto gradInput = Tensor(input.shape(), input.type()); + auto gradWeights = Tensor(payload->weights.shape(), payload->weights.type()); + auto gradBias = Tensor(payload->bias.shape(), payload->bias.type()); + + const detail::DnnlMemoryWrapper inputMemory( + input, inputOutputDims, formatNCHW); + + // Memory for gradient computation + const detail::DnnlMemoryWrapper gradOutputMem( + gradOutput, inputOutputDims, formatNCHW); + const detail::DnnlMemoryWrapper gradInputMem( + gradInput, inputOutputDims, formatNCHW); + const detail::DnnlMemoryWrapper gradWeightsMem( + gradWeights, payload->weightsDims, formatX); + const detail::DnnlMemoryWrapper gradBiasMem( + gradBias, payload->biasDims, formatX); + + // Primitives and descriptors + auto bwdPrimitiveDesc = dnnl::batch_normalization_backward::primitive_desc( + dnnlEngine, + dnnl::prop_kind::backward, + gradOutputMem.getDescriptor(), + payload->outputMemoryDescriptor, + gradOutputMem.getDescriptor(), + epsilon, + dnnl::normalization_flags::use_scale + | dnnl::normalization_flags::use_shift, + payload->fwdPrimDesc // hint + ); + auto bwdPrim = + std::make_shared(bwdPrimitiveDesc); + // Execute + std::vector networkBackwards; + std::vector> bwdArgs = { + {{DNNL_ARG_SRC, inputMemory.getMemory()}, + {DNNL_ARG_MEAN, payload->meanMemory}, + {DNNL_ARG_VARIANCE, payload->varMemory}, + {DNNL_ARG_SCALE, payload->weightsMemory}, + // TODO dnnl_arg_shift was here, check if something can be optimized bc it's not needed + {DNNL_ARG_DIFF_SRC, 
gradInputMem.getMemory()}, + {DNNL_ARG_DIFF_DST, gradOutputMem.getMemory()}, + {DNNL_ARG_DIFF_SCALE, gradWeightsMem.getMemory()}, + {DNNL_ARG_DIFF_SHIFT, gradBiasMem.getMemory()}}}; + + networkBackwards.push_back(*bwdPrim); + detail::executeNetwork(networkBackwards, bwdArgs); + + return {gradInput, gradWeights, gradBias}; }; } // namespace fl diff --git a/flashlight/fl/autograd/tensor/backend/onednn/Conv2D.cpp b/flashlight/fl/autograd/tensor/backend/onednn/Conv2D.cpp index d6e558f..a83b62b 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/Conv2D.cpp +++ b/flashlight/fl/autograd/tensor/backend/onednn/Conv2D.cpp @@ -23,132 +23,140 @@ namespace fl { namespace { // Input, output: WHCN; weights: WHIO -constexpr size_t kWIdx = 0; -constexpr size_t kHIdx = 1; -constexpr size_t kIOChannelSizeIdx = 2; -constexpr size_t kIOBatchSizeIdx = 3; -constexpr size_t kWeightOutputChannelSizeIdx = 3; + constexpr size_t kWIdx = 0; + constexpr size_t kHIdx = 1; + constexpr size_t kIOChannelSizeIdx = 2; + constexpr size_t kIOBatchSizeIdx = 3; + constexpr size_t kWeightOutputChannelSizeIdx = 3; // Use memory::format_tag::any for memory formatting even if convolution // inputs are shaped in a particular way. 
-constexpr auto formatAny = memory::format_tag::any; -constexpr auto formatNCHW = memory::format_tag::nchw; -constexpr auto formatBias = memory::format_tag::x; - -struct OneDnnConv2DData { - memory::dims inputDims; - memory::dims weightDims; - memory::dims outputDims; - memory::dims biasDims; - memory::dims strideDims; - memory::dims dilationDims; - memory::dims paddingDims; - // Memory descriptors - memory::desc inputMemDesc; - memory::desc outputMemDesc; - memory::desc weightMemDesc; - memory::desc biasMemDesc; - // used for creating a backward desc - convolution_forward::primitive_desc fwdPrimDesc; -}; - -OneDnnConv2DData createOneDnnConv2DData( - fl::dtype inputType, - const Shape& inputShape, - const Shape& weightsShape, - const Shape& biasShape, - const Shape& outputShape, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups) { - const dnnl::memory::data_type dataType = detail::dnnlMapToType(inputType); - const auto formatWeight = - (groups == 1) ? 
memory::format_tag::oihw : memory::format_tag::goihw; - const bool hasBias = biasShape.elements() > 0; - - OneDnnConv2DData out; - // Create memory dims - out.inputDims = detail::convertToDnnlDims( - {inputShape.dim(kIOBatchSizeIdx), - inputShape.dim(kIOChannelSizeIdx), - inputShape.dim(kHIdx), - inputShape.dim(kWIdx)}); - if (groups == 1) { - out.weightDims = detail::convertToDnnlDims( - {weightsShape.dim(kWeightOutputChannelSizeIdx), - inputShape.dim(kIOChannelSizeIdx), - weightsShape.dim(kHIdx), - weightsShape.dim(kWIdx)}); - } else { - out.weightDims = detail::convertToDnnlDims( - {groups, - weightsShape.dim(kWeightOutputChannelSizeIdx) / groups, - inputShape.dim(kIOChannelSizeIdx) / groups, - weightsShape.dim(kHIdx), - weightsShape.dim(kWIdx)}); - } - out.outputDims = detail::convertToDnnlDims( - {inputShape.dim(kIOBatchSizeIdx), - weightsShape.dim(kWeightOutputChannelSizeIdx), - outputShape.dim(kHIdx), - outputShape.dim(kWIdx)}); - out.biasDims = detail::convertToDnnlDims( - {weightsShape.dim(kWeightOutputChannelSizeIdx)}); - out.strideDims = {sy, sx}; - out.paddingDims = {py, px}; - // NB: DNNL treats a dilation of 0 as a standard convolution and indexes - // larger dilations accordingly. See https://git.io/fhAT2 for more. - out.dilationDims = {dy - 1, dx - 1}; - - // Create memory descriptors. using format::any gives the best performance - out.inputMemDesc = memory::desc({out.inputDims}, dataType, formatAny); - out.outputMemDesc = memory::desc({out.outputDims}, dataType, formatAny); - out.weightMemDesc = memory::desc({out.weightDims}, dataType, formatWeight); - out.biasMemDesc = memory::desc({out.biasDims}, dataType, formatAny); - - // - const auto forwardMode = prop_kind::forward_training; - // TODO: determine train mode/assess perf impact of always choosing training - // (primitive cache storage overhead?) - // const auto forwardMode = - // train ? 
prop_kind::forward_training : prop_kind::forward_inference; - - auto& dnnlEngine = detail::DnnlEngine::getInstance().getEngine(); - convolution_forward::primitive_desc fwdPrimitiveDescriptor; - if (hasBias) { - fwdPrimitiveDescriptor = convolution_forward::primitive_desc( - dnnlEngine, - forwardMode, - algorithm::convolution_direct, - out.inputMemDesc, - out.weightMemDesc, - out.biasMemDesc, - out.outputMemDesc, - out.strideDims, - out.dilationDims, - out.paddingDims, - out.paddingDims); - } else { - fwdPrimitiveDescriptor = convolution_forward::primitive_desc( - dnnlEngine, - forwardMode, - algorithm::convolution_direct, - out.inputMemDesc, - out.weightMemDesc, - out.outputMemDesc, - out.strideDims, - out.dilationDims, - out.paddingDims, - out.paddingDims); - } - out.fwdPrimDesc = std::move(fwdPrimitiveDescriptor); - - return out; -} + constexpr auto formatAny = memory::format_tag::any; + constexpr auto formatNCHW = memory::format_tag::nchw; + constexpr auto formatBias = memory::format_tag::x; + + struct OneDnnConv2DData { + memory::dims inputDims; + memory::dims weightDims; + memory::dims outputDims; + memory::dims biasDims; + memory::dims strideDims; + memory::dims dilationDims; + memory::dims paddingDims; + // Memory descriptors + memory::desc inputMemDesc; + memory::desc outputMemDesc; + memory::desc weightMemDesc; + memory::desc biasMemDesc; + // used for creating a backward desc + convolution_forward::primitive_desc fwdPrimDesc; + }; + + OneDnnConv2DData createOneDnnConv2DData( + fl::dtype inputType, + const Shape& inputShape, + const Shape& weightsShape, + const Shape& biasShape, + const Shape& outputShape, + const int sx, + const int sy, + const int px, + const int py, + const int dx, + const int dy, + const int groups + ) { + const dnnl::memory::data_type dataType = detail::dnnlMapToType(inputType); + const auto formatWeight = + (groups == 1) ? 
memory::format_tag::oihw : memory::format_tag::goihw; + const bool hasBias = biasShape.elements() > 0; + + OneDnnConv2DData out; + // Create memory dims + out.inputDims = detail::convertToDnnlDims( + {inputShape.dim(kIOBatchSizeIdx), + inputShape.dim(kIOChannelSizeIdx), + inputShape.dim(kHIdx), + inputShape.dim(kWIdx)} + ); + if(groups == 1) { + out.weightDims = detail::convertToDnnlDims( + {weightsShape.dim(kWeightOutputChannelSizeIdx), + inputShape.dim(kIOChannelSizeIdx), + weightsShape.dim(kHIdx), + weightsShape.dim(kWIdx)} + ); + } else { + out.weightDims = detail::convertToDnnlDims( + {groups, + weightsShape.dim(kWeightOutputChannelSizeIdx) / groups, + inputShape.dim(kIOChannelSizeIdx) / groups, + weightsShape.dim(kHIdx), + weightsShape.dim(kWIdx)} + ); + } + out.outputDims = detail::convertToDnnlDims( + {inputShape.dim(kIOBatchSizeIdx), + weightsShape.dim(kWeightOutputChannelSizeIdx), + outputShape.dim(kHIdx), + outputShape.dim(kWIdx)} + ); + out.biasDims = detail::convertToDnnlDims( + {weightsShape.dim(kWeightOutputChannelSizeIdx)} + ); + out.strideDims = {sy, sx}; + out.paddingDims = {py, px}; + // NB: DNNL treats a dilation of 0 as a standard convolution and indexes + // larger dilations accordingly. See https://git.io/fhAT2 for more. + out.dilationDims = {dy - 1, dx - 1}; + + // Create memory descriptors. using format::any gives the best performance + out.inputMemDesc = memory::desc({out.inputDims}, dataType, formatAny); + out.outputMemDesc = memory::desc({out.outputDims}, dataType, formatAny); + out.weightMemDesc = memory::desc({out.weightDims}, dataType, formatWeight); + out.biasMemDesc = memory::desc({out.biasDims}, dataType, formatAny); + + // + const auto forwardMode = prop_kind::forward_training; + // TODO: determine train mode/assess perf impact of always choosing training + // (primitive cache storage overhead?) + // const auto forwardMode = + // train ? 
prop_kind::forward_training : prop_kind::forward_inference; + + auto& dnnlEngine = detail::DnnlEngine::getInstance().getEngine(); + convolution_forward::primitive_desc fwdPrimitiveDescriptor; + if(hasBias) { + fwdPrimitiveDescriptor = convolution_forward::primitive_desc( + dnnlEngine, + forwardMode, + algorithm::convolution_direct, + out.inputMemDesc, + out.weightMemDesc, + out.biasMemDesc, + out.outputMemDesc, + out.strideDims, + out.dilationDims, + out.paddingDims, + out.paddingDims + ); + } else { + fwdPrimitiveDescriptor = convolution_forward::primitive_desc( + dnnlEngine, + forwardMode, + algorithm::convolution_direct, + out.inputMemDesc, + out.weightMemDesc, + out.outputMemDesc, + out.strideDims, + out.dilationDims, + out.paddingDims, + out.paddingDims + ); + } + out.fwdPrimDesc = std::move(fwdPrimitiveDescriptor); + + return out; + } } // namespace @@ -163,110 +171,122 @@ Tensor OneDnnAutogradExtension::conv2d( const int dx, const int dy, const int groups, - std::shared_ptr) { - if (input.type() == fl::dtype::f16) { - throw std::runtime_error("Half precision is not supported in CPU."); - } - - // flashlight input, weight, and output shapes in column-major: - // - Input is WHCN - // - Weights are WHIO - // - Output is WHCN - // Since ArrayFire is column major, getting a raw pointer (1D - // representation) of these shapes and viewing as if the representation is - // row major transposes along all axis into NCHW for the input and output - // and OIHW for the weights - auto output = Tensor( - {1 + - (input.dim(kWIdx) + (2 * px) - (1 + (weights.dim(kWIdx) - 1) * dx)) / - sx, - 1 + - (input.dim(kHIdx) + (2 * py) - (1 + (weights.dim(kHIdx) - 1) * dy)) / - sy, - weights.dim(kWeightOutputChannelSizeIdx), - input.dim(kIOBatchSizeIdx)}, - input.type()); - auto hasBias = bias.elements() > 0; - - auto formatWeight = - (groups == 1) ? 
memory::format_tag::oihw : memory::format_tag::goihw; - auto& dnnlEngine = detail::DnnlEngine::getInstance().getEngine(); - - /********************************* Forward *******************************/ - OneDnnConv2DData conv2DData = createOneDnnConv2DData( - input.type(), - input.shape(), - weights.shape(), - bias.shape(), - output.shape(), - sx, - sy, - px, - py, - dx, - dy, - groups); - - // Create memory - const detail::DnnlMemoryWrapper inputMemInit( - input, {conv2DData.inputDims}, formatNCHW); - const detail::DnnlMemoryWrapper outputMemInit( - output, {conv2DData.outputDims}, formatNCHW); - const detail::DnnlMemoryWrapper weightsMem( - weights, {conv2DData.weightDims}, formatWeight); - - // Network for execution - std::vector network; - std::vector> fwdArgs; - - // DNNL suggests checking if the layout requested for the convolution - // is different from NCHW/OIHW (even if specified), and reordering if - // necessary, since the convolution itself may request a different - // ordering - auto inputDesc = conv2DData.fwdPrimDesc.src_desc(); - auto weightsDesc = conv2DData.fwdPrimDesc.weights_desc(); - auto outputDesc = conv2DData.fwdPrimDesc.dst_desc(); - // Input - auto inputMemory = detail::dnnlAlignOrdering( - network, fwdArgs, inputMemInit.getMemory(), inputDesc); - auto weightsMemory = detail::dnnlAlignOrdering( - network, fwdArgs, weightsMem.getMemory(), weightsDesc); - // Output - adds a reorder after the conv if needed - auto outputMemory = outputMemInit.getMemory(); - if (outputMemInit.getMemory().get_desc() != outputDesc) { - outputMemory = memory(outputDesc, dnnlEngine); - } - - // Create convolution - std::shared_ptr conv; - const detail::DnnlMemoryWrapper biasMemory( - bias, conv2DData.biasDims, formatBias); - conv = std::make_shared(conv2DData.fwdPrimDesc); - - network.push_back(*conv); - - // Conv fwd args - std::unordered_map convFwdArgs = { - {DNNL_ARG_SRC, inputMemory}, - {DNNL_ARG_WEIGHTS, weightsMemory}, - {DNNL_ARG_DST, outputMemory}}; - if 
(hasBias) { - convFwdArgs[DNNL_ARG_BIAS] = biasMemory.getMemory(); - } - fwdArgs.push_back(convFwdArgs); - - // Add output reordering if needed - if (outputMemory != outputMemInit.getMemory()) { - network.push_back(dnnl::reorder(outputMemory, outputMemInit.getMemory())); - fwdArgs.push_back( - {{DNNL_ARG_FROM, outputMemory}, - {DNNL_ARG_TO, outputMemInit.getMemory()}}); - } - - // Run - detail::executeNetwork(network, fwdArgs); - - return output; + std::shared_ptr +) { + if(input.type() == fl::dtype::f16) { + throw std::runtime_error("Half precision is not supported in CPU."); + } + + // flashlight input, weight, and output shapes in column-major: + // - Input is WHCN + // - Weights are WHIO + // - Output is WHCN + // Since ArrayFire is column major, getting a raw pointer (1D + // representation) of these shapes and viewing as if the representation is + // row major transposes along all axis into NCHW for the input and output + // and OIHW for the weights + auto output = Tensor( + {1 + + (input.dim(kWIdx) + (2 * px) - (1 + (weights.dim(kWIdx) - 1) * dx)) + / sx, + 1 + + (input.dim(kHIdx) + (2 * py) - (1 + (weights.dim(kHIdx) - 1) * dy)) + / sy, + weights.dim(kWeightOutputChannelSizeIdx), + input.dim(kIOBatchSizeIdx)}, + input.type() + ); + auto hasBias = bias.elements() > 0; + + auto formatWeight = + (groups == 1) ? 
memory::format_tag::oihw : memory::format_tag::goihw; + auto& dnnlEngine = detail::DnnlEngine::getInstance().getEngine(); + + /********************************* Forward *******************************/ + OneDnnConv2DData conv2DData = createOneDnnConv2DData( + input.type(), + input.shape(), + weights.shape(), + bias.shape(), + output.shape(), + sx, + sy, + px, + py, + dx, + dy, + groups + ); + + // Create memory + const detail::DnnlMemoryWrapper inputMemInit( + input, {conv2DData.inputDims}, formatNCHW); + const detail::DnnlMemoryWrapper outputMemInit( + output, {conv2DData.outputDims}, formatNCHW); + const detail::DnnlMemoryWrapper weightsMem( + weights, {conv2DData.weightDims}, formatWeight); + + // Network for execution + std::vector network; + std::vector> fwdArgs; + + // DNNL suggests checking if the layout requested for the convolution + // is different from NCHW/OIHW (even if specified), and reordering if + // necessary, since the convolution itself may request a different + // ordering + auto inputDesc = conv2DData.fwdPrimDesc.src_desc(); + auto weightsDesc = conv2DData.fwdPrimDesc.weights_desc(); + auto outputDesc = conv2DData.fwdPrimDesc.dst_desc(); + // Input + auto inputMemory = detail::dnnlAlignOrdering( + network, + fwdArgs, + inputMemInit.getMemory(), + inputDesc + ); + auto weightsMemory = detail::dnnlAlignOrdering( + network, + fwdArgs, + weightsMem.getMemory(), + weightsDesc + ); + // Output - adds a reorder after the conv if needed + auto outputMemory = outputMemInit.getMemory(); + if(outputMemInit.getMemory().get_desc() != outputDesc) { + outputMemory = memory(outputDesc, dnnlEngine); + } + + // Create convolution + std::shared_ptr conv; + const detail::DnnlMemoryWrapper biasMemory( + bias, conv2DData.biasDims, formatBias); + conv = std::make_shared(conv2DData.fwdPrimDesc); + + network.push_back(*conv); + + // Conv fwd args + std::unordered_map convFwdArgs = { + {DNNL_ARG_SRC, inputMemory}, + {DNNL_ARG_WEIGHTS, weightsMemory}, + {DNNL_ARG_DST, 
outputMemory}}; + if(hasBias) { + convFwdArgs[DNNL_ARG_BIAS] = biasMemory.getMemory(); + } + fwdArgs.push_back(convFwdArgs); + + // Add output reordering if needed + if(outputMemory != outputMemInit.getMemory()) { + network.push_back(dnnl::reorder(outputMemory, outputMemInit.getMemory())); + fwdArgs.push_back( + {{DNNL_ARG_FROM, outputMemory}, + {DNNL_ARG_TO, outputMemInit.getMemory()}} + ); + } + + // Run + detail::executeNetwork(network, fwdArgs); + + return output; } Tensor OneDnnAutogradExtension::conv2dBackwardData( @@ -281,97 +301,104 @@ Tensor OneDnnAutogradExtension::conv2dBackwardData( const int dy, const int groups, std::shared_ptr, - std::shared_ptr) { - auto gradInput = Tensor(input.shape(), input.type()); // Result - - auto formatWeight = - (groups == 1) ? memory::format_tag::oihw : memory::format_tag::goihw; - auto& dnnlEngineBwd = detail::DnnlEngine::getInstance().getEngine(); - - Tensor bias; // dummy - OneDnnConv2DData conv2DData = createOneDnnConv2DData( - input.type(), - input.shape(), - weights.shape(), - bias.shape(), - gradOutput.shape(), // has the same shape as the Conv output - sx, - sy, - px, - py, - dx, - dy, - groups); - - // Backward descriptor - convolution_backward_data::primitive_desc bwdDataPrimitiveDesc( - dnnlEngineBwd, - algorithm::convolution_direct, - conv2DData.inputMemDesc, - conv2DData.weightMemDesc, - conv2DData.outputMemDesc, - conv2DData.strideDims, - conv2DData.dilationDims, - conv2DData.paddingDims, - conv2DData.paddingDims, - conv2DData.fwdPrimDesc); - // Primitive descriptor - auto bwdData = - std::make_shared(bwdDataPrimitiveDesc); - - // Create memory - const detail::DnnlMemoryWrapper gradOutputMemInit( - gradOutput, conv2DData.outputDims, formatNCHW); - const detail::DnnlMemoryWrapper gradInputMemInit( - gradInput, conv2DData.inputDims, formatNCHW); - const detail::DnnlMemoryWrapper weightsMemInitBwd( - weights, conv2DData.weightDims, formatWeight); - - std::vector networkBackwards; - std::vector> bwdDataArgs; - - 
// Check for reorderings - auto gradOutputDesc = bwdDataPrimitiveDesc.diff_dst_desc(); - auto weightsDesc = bwdDataPrimitiveDesc.weights_desc(); - auto gradInputDesc = bwdDataPrimitiveDesc.diff_src_desc(); - auto gradOutputMemory = detail::dnnlAlignOrdering( - networkBackwards, - bwdDataArgs, - gradOutputMemInit.getMemory(), - gradOutputDesc); - auto weightsMemoryBackwards = detail::dnnlAlignOrdering( - networkBackwards, - bwdDataArgs, - weightsMemInitBwd.getMemory(), - weightsDesc); - auto gradInputMemory = gradInputMemInit.getMemory(); - // Don't reorder the gradient until after the conv - if (gradInputMemInit.getMemory().get_desc() != gradInputDesc) { - gradInputMemory = memory(gradInputDesc, dnnlEngineBwd); - } - - // Convolution backwards - auto convBwdData = - std::make_shared(bwdDataPrimitiveDesc); - - bwdDataArgs.push_back( - {{DNNL_ARG_DIFF_SRC, gradInputMemory}, - {DNNL_ARG_WEIGHTS, weightsMemoryBackwards}, - {DNNL_ARG_DIFF_DST, gradOutputMemory}}); - networkBackwards.push_back(*convBwdData); - - // Reorder the output (which is gradInput here) if necessary - if (gradInputMemory != gradInputMemInit.getMemory()) { - networkBackwards.push_back( - dnnl::reorder(gradInputMemory, gradInputMemInit.getMemory())); - bwdDataArgs.push_back( - {{DNNL_ARG_FROM, gradInputMemory}, - {DNNL_ARG_TO, gradInputMemInit.getMemory()}}); - } - - detail::executeNetwork(networkBackwards, bwdDataArgs); + std::shared_ptr +) { + auto gradInput = Tensor(input.shape(), input.type()); // Result + + auto formatWeight = + (groups == 1) ? 
memory::format_tag::oihw : memory::format_tag::goihw; + auto& dnnlEngineBwd = detail::DnnlEngine::getInstance().getEngine(); + + Tensor bias; // dummy + OneDnnConv2DData conv2DData = createOneDnnConv2DData( + input.type(), + input.shape(), + weights.shape(), + bias.shape(), + gradOutput.shape(), // has the same shape as the Conv output + sx, + sy, + px, + py, + dx, + dy, + groups + ); + + // Backward descriptor + convolution_backward_data::primitive_desc bwdDataPrimitiveDesc( + dnnlEngineBwd, + algorithm::convolution_direct, + conv2DData.inputMemDesc, + conv2DData.weightMemDesc, + conv2DData.outputMemDesc, + conv2DData.strideDims, + conv2DData.dilationDims, + conv2DData.paddingDims, + conv2DData.paddingDims, + conv2DData.fwdPrimDesc); + // Primitive descriptor + auto bwdData = + std::make_shared(bwdDataPrimitiveDesc); + + // Create memory + const detail::DnnlMemoryWrapper gradOutputMemInit( + gradOutput, conv2DData.outputDims, formatNCHW); + const detail::DnnlMemoryWrapper gradInputMemInit( + gradInput, conv2DData.inputDims, formatNCHW); + const detail::DnnlMemoryWrapper weightsMemInitBwd( + weights, conv2DData.weightDims, formatWeight); + + std::vector networkBackwards; + std::vector> bwdDataArgs; + + // Check for reorderings + auto gradOutputDesc = bwdDataPrimitiveDesc.diff_dst_desc(); + auto weightsDesc = bwdDataPrimitiveDesc.weights_desc(); + auto gradInputDesc = bwdDataPrimitiveDesc.diff_src_desc(); + auto gradOutputMemory = detail::dnnlAlignOrdering( + networkBackwards, + bwdDataArgs, + gradOutputMemInit.getMemory(), + gradOutputDesc + ); + auto weightsMemoryBackwards = detail::dnnlAlignOrdering( + networkBackwards, + bwdDataArgs, + weightsMemInitBwd.getMemory(), + weightsDesc + ); + auto gradInputMemory = gradInputMemInit.getMemory(); + // Don't reorder the gradient until after the conv + if(gradInputMemInit.getMemory().get_desc() != gradInputDesc) { + gradInputMemory = memory(gradInputDesc, dnnlEngineBwd); + } + + // Convolution backwards + auto convBwdData 
= + std::make_shared(bwdDataPrimitiveDesc); - return gradInput; + bwdDataArgs.push_back( + {{DNNL_ARG_DIFF_SRC, gradInputMemory}, + {DNNL_ARG_WEIGHTS, weightsMemoryBackwards}, + {DNNL_ARG_DIFF_DST, gradOutputMemory}} + ); + networkBackwards.push_back(*convBwdData); + + // Reorder the output (which is gradInput here) if necessary + if(gradInputMemory != gradInputMemInit.getMemory()) { + networkBackwards.push_back( + dnnl::reorder(gradInputMemory, gradInputMemInit.getMemory()) + ); + bwdDataArgs.push_back( + {{DNNL_ARG_FROM, gradInputMemory}, + {DNNL_ARG_TO, gradInputMemInit.getMemory()}} + ); + } + + detail::executeNetwork(networkBackwards, bwdDataArgs); + + return gradInput; } std::pair OneDnnAutogradExtension::conv2dBackwardFilterBias( @@ -388,122 +415,129 @@ std::pair OneDnnAutogradExtension::conv2dBackwardFilterBias( const int groups, std::shared_ptr, std::shared_ptr, - std::shared_ptr) { - auto gradWeights = Tensor(weights.shape(), weights.type()); // Result - - auto formatWeight = - (groups == 1) ? 
memory::format_tag::oihw : memory::format_tag::goihw; - auto& dnnlEngineBwd = detail::DnnlEngine::getInstance().getEngine(); - OneDnnConv2DData conv2DData = createOneDnnConv2DData( - input.type(), - input.shape(), - weights.shape(), - bias.shape(), - gradOutput.shape(), // has the same shape as the Conv output - sx, - sy, - px, - py, - dx, - dy, - groups); - - Tensor gradBias; - bool computeBiasGrad = !bias.isEmpty() && !conv2DData.biasMemDesc.is_zero(); - if (computeBiasGrad) { - gradBias = Tensor(bias.shape(), bias.type()); - } - - // Weight backward descriptor - convolution_backward_weights::primitive_desc bwdWeightPrimitiveDesc; - if (computeBiasGrad) { - bwdWeightPrimitiveDesc = convolution_backward_weights::primitive_desc( - dnnlEngineBwd, - algorithm::convolution_direct, - conv2DData.inputMemDesc, - conv2DData.weightMemDesc, - conv2DData.biasMemDesc, - conv2DData.outputMemDesc, - conv2DData.strideDims, - conv2DData.dilationDims, - conv2DData.paddingDims, - conv2DData.paddingDims, - conv2DData.fwdPrimDesc); - } else { - bwdWeightPrimitiveDesc = convolution_backward_weights::primitive_desc( - dnnlEngineBwd, - algorithm::convolution_direct, - conv2DData.inputMemDesc, - conv2DData.weightMemDesc, - conv2DData.outputMemDesc, - conv2DData.strideDims, - conv2DData.dilationDims, - conv2DData.paddingDims, - conv2DData.paddingDims, - conv2DData.fwdPrimDesc); - } - // Weight backward primitive descriptor - auto bwdWeights = - std::make_shared(bwdWeightPrimitiveDesc); - - // Create memory - const detail::DnnlMemoryWrapper inputRawMemInitBwd( - input, conv2DData.inputDims, formatNCHW); - const detail::DnnlMemoryWrapper gradOutputMemInit( - gradOutput, conv2DData.outputDims, formatNCHW); - const detail::DnnlMemoryWrapper gradWeightsMemInit( - gradWeights, conv2DData.weightDims, formatWeight); - - std::vector networkBackwards; - std::vector> bwdWeightsArgs; - - // Check for reorderings, reorder if needed - auto inputDesc = bwdWeightPrimitiveDesc.src_desc(); - auto 
gradOutputDesc = bwdWeightPrimitiveDesc.diff_dst_desc(); - auto gradWeightsDesc = bwdWeightPrimitiveDesc.diff_weights_desc(); - auto inputMemoryBackwards = detail::dnnlAlignOrdering( - networkBackwards, - bwdWeightsArgs, - inputRawMemInitBwd.getMemory(), - inputDesc); - auto gradOutputMemory = detail::dnnlAlignOrdering( - networkBackwards, - bwdWeightsArgs, - gradOutputMemInit.getMemory(), - gradOutputDesc); - // Don't reorder the grads until after the conv bwd - auto gradWeightsMemory = gradWeightsMemInit.getMemory(); - if (gradWeightsMemInit.getMemory().get_desc() != gradWeightsDesc) { - gradWeightsMemory = memory(gradWeightsDesc, dnnlEngineBwd); - } - - // Create the convolution backward weight - std::unordered_map bwdConvWeightsArgs = { - {DNNL_ARG_SRC, inputMemoryBackwards}, - {DNNL_ARG_DIFF_WEIGHTS, gradWeightsMemory}, - {DNNL_ARG_DIFF_DST, gradOutputMemory}}; - - if (computeBiasGrad) { - const detail::DnnlMemoryWrapper gradBiasMem( - gradBias, conv2DData.biasDims, formatBias); - bwdConvWeightsArgs[DNNL_ARG_DIFF_BIAS] = gradBiasMem.getMemory(); - } else { - } - networkBackwards.push_back(*bwdWeights); - bwdWeightsArgs.push_back(bwdConvWeightsArgs); - - // Reorder weight gradients if necessary - if (gradWeightsMemory != gradWeightsMemInit.getMemory()) { - networkBackwards.push_back( - dnnl::reorder(gradWeightsMemory, gradWeightsMemInit.getMemory())); - bwdWeightsArgs.push_back( - {{DNNL_ARG_FROM, gradWeightsMemory}, - {DNNL_ARG_TO, gradWeightsMemInit.getMemory()}}); - } - - detail::executeNetwork(networkBackwards, bwdWeightsArgs); - - return {gradWeights, gradBias}; + std::shared_ptr +) { + auto gradWeights = Tensor(weights.shape(), weights.type()); // Result + + auto formatWeight = + (groups == 1) ? 
memory::format_tag::oihw : memory::format_tag::goihw; + auto& dnnlEngineBwd = detail::DnnlEngine::getInstance().getEngine(); + OneDnnConv2DData conv2DData = createOneDnnConv2DData( + input.type(), + input.shape(), + weights.shape(), + bias.shape(), + gradOutput.shape(), // has the same shape as the Conv output + sx, + sy, + px, + py, + dx, + dy, + groups + ); + + Tensor gradBias; + bool computeBiasGrad = !bias.isEmpty() && !conv2DData.biasMemDesc.is_zero(); + if(computeBiasGrad) { + gradBias = Tensor(bias.shape(), bias.type()); + } + + // Weight backward descriptor + convolution_backward_weights::primitive_desc bwdWeightPrimitiveDesc; + if(computeBiasGrad) { + bwdWeightPrimitiveDesc = convolution_backward_weights::primitive_desc( + dnnlEngineBwd, + algorithm::convolution_direct, + conv2DData.inputMemDesc, + conv2DData.weightMemDesc, + conv2DData.biasMemDesc, + conv2DData.outputMemDesc, + conv2DData.strideDims, + conv2DData.dilationDims, + conv2DData.paddingDims, + conv2DData.paddingDims, + conv2DData.fwdPrimDesc + ); + } else { + bwdWeightPrimitiveDesc = convolution_backward_weights::primitive_desc( + dnnlEngineBwd, + algorithm::convolution_direct, + conv2DData.inputMemDesc, + conv2DData.weightMemDesc, + conv2DData.outputMemDesc, + conv2DData.strideDims, + conv2DData.dilationDims, + conv2DData.paddingDims, + conv2DData.paddingDims, + conv2DData.fwdPrimDesc + ); + } + // Weight backward primitive descriptor + auto bwdWeights = + std::make_shared(bwdWeightPrimitiveDesc); + + // Create memory + const detail::DnnlMemoryWrapper inputRawMemInitBwd( + input, conv2DData.inputDims, formatNCHW); + const detail::DnnlMemoryWrapper gradOutputMemInit( + gradOutput, conv2DData.outputDims, formatNCHW); + const detail::DnnlMemoryWrapper gradWeightsMemInit( + gradWeights, conv2DData.weightDims, formatWeight); + + std::vector networkBackwards; + std::vector> bwdWeightsArgs; + + // Check for reorderings, reorder if needed + auto inputDesc = bwdWeightPrimitiveDesc.src_desc(); + auto 
gradOutputDesc = bwdWeightPrimitiveDesc.diff_dst_desc(); + auto gradWeightsDesc = bwdWeightPrimitiveDesc.diff_weights_desc(); + auto inputMemoryBackwards = detail::dnnlAlignOrdering( + networkBackwards, + bwdWeightsArgs, + inputRawMemInitBwd.getMemory(), + inputDesc + ); + auto gradOutputMemory = detail::dnnlAlignOrdering( + networkBackwards, + bwdWeightsArgs, + gradOutputMemInit.getMemory(), + gradOutputDesc + ); + // Don't reorder the grads until after the conv bwd + auto gradWeightsMemory = gradWeightsMemInit.getMemory(); + if(gradWeightsMemInit.getMemory().get_desc() != gradWeightsDesc) { + gradWeightsMemory = memory(gradWeightsDesc, dnnlEngineBwd); + } + + // Create the convolution backward weight + std::unordered_map bwdConvWeightsArgs = { + {DNNL_ARG_SRC, inputMemoryBackwards}, + {DNNL_ARG_DIFF_WEIGHTS, gradWeightsMemory}, + {DNNL_ARG_DIFF_DST, gradOutputMemory}}; + + if(computeBiasGrad) { + const detail::DnnlMemoryWrapper gradBiasMem( + gradBias, conv2DData.biasDims, formatBias); + bwdConvWeightsArgs[DNNL_ARG_DIFF_BIAS] = gradBiasMem.getMemory(); + } else {} + networkBackwards.push_back(*bwdWeights); + bwdWeightsArgs.push_back(bwdConvWeightsArgs); + + // Reorder weight gradients if necessary + if(gradWeightsMemory != gradWeightsMemInit.getMemory()) { + networkBackwards.push_back( + dnnl::reorder(gradWeightsMemory, gradWeightsMemInit.getMemory()) + ); + bwdWeightsArgs.push_back( + {{DNNL_ARG_FROM, gradWeightsMemory}, + {DNNL_ARG_TO, gradWeightsMemInit.getMemory()}} + ); + } + + detail::executeNetwork(networkBackwards, bwdWeightsArgs); + + return {gradWeights, gradBias}; } } // namespace fl diff --git a/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.cpp b/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.cpp index 5fa5530..371c46c 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.cpp +++ b/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.cpp @@ -11,7 +11,7 @@ #include #if FL_BACKEND_OPENCL - #include +#include #endif 
#include "flashlight/fl/common/Defines.h" @@ -19,140 +19,149 @@ #include "flashlight/fl/tensor/TensorBase.h" #if FL_BACKEND_OPENCL - #include "flashlight/fl/common/OpenClUtils.h" +#include "flashlight/fl/common/OpenClUtils.h" #endif namespace fl::detail { DnnlStream::DnnlStream(dnnl::engine engine) { #if FL_BACKEND_OPENCL - stream_ = dnnl::ocl_interop::make_stream(engine, fl::ocl::getQueue()); + stream_ = dnnl::ocl_interop::make_stream(engine, fl::ocl::getQueue()); #else - stream_ = dnnl::stream(engine); + stream_ = dnnl::stream(engine); #endif } dnnl::stream& DnnlStream::getStream() { - return stream_; + return stream_; } DnnlStream& DnnlStream::getInstance() { - static DnnlStream instance(DnnlEngine::getInstance().getEngine()); - return instance; + static DnnlStream instance(DnnlEngine::getInstance().getEngine()); + return instance; } DnnlEngine::DnnlEngine() { #if FL_BACKEND_OPENCL - engine_ = dnnl::ocl_interop::make_engine( - fl::ocl::getDeviceId(), fl::ocl::getContext()); + engine_ = dnnl::ocl_interop::make_engine( + fl::ocl::getDeviceId(), + fl::ocl::getContext() + ); #else - engine_ = dnnl::engine(dnnl::engine::kind::cpu, 0); + engine_ = dnnl::engine(dnnl::engine::kind::cpu, 0); #endif } dnnl::engine& DnnlEngine::getEngine() { - return engine_; + return engine_; } DnnlEngine& DnnlEngine::getInstance() { - static DnnlEngine instance; - return instance; + static DnnlEngine instance; + return instance; } dnnl::memory::dims convertToDnnlDims(const std::vector& shape) { - return dnnl::memory::dims(shape.begin(), shape.end()); + return dnnl::memory::dims(shape.begin(), shape.end()); } dnnl::memory::dims convertShapeToDnnlDims(const Shape& shape) { - return convertToDnnlDims(shape.get()); + return convertToDnnlDims(shape.get()); } DnnlMemoryWrapper::DnnlMemoryWrapper( const Tensor& tensor, dnnl::memory::dims dims, - dnnl::memory::format_tag format) { + dnnl::memory::format_tag format +) { #if FL_BACKEND_OPENCL - fl::ocl::DevicePtrOpenCl _devicePtr(tensor); - 
cl_mem* buffer = _devicePtr.getAsClMem(); - devicePtr_ = std::move(_devicePtr); + fl::ocl::DevicePtrOpenCl _devicePtr(tensor); + cl_mem* buffer = _devicePtr.getAsClMem(); + devicePtr_ = std::move(_devicePtr); #else - devicePtr_ = fl::DevicePtr(tensor); - void* buffer = devicePtr_.get(); + devicePtr_ = fl::DevicePtr(tensor); + void* buffer = devicePtr_.get(); #endif - descriptor_ = - dnnl::memory::desc({dims}, detail::dnnlMapToType(tensor.type()), format); - memory_ = dnnl::memory( - descriptor_, detail::DnnlEngine::getInstance().getEngine(), buffer); + descriptor_ = + dnnl::memory::desc({dims}, detail::dnnlMapToType(tensor.type()), format); + memory_ = dnnl::memory( + descriptor_, + detail::DnnlEngine::getInstance().getEngine(), + buffer + ); } DnnlMemoryWrapper& DnnlMemoryWrapper::operator=(DnnlMemoryWrapper&& other) { - devicePtr_ = std::move(other.devicePtr_); - memory_ = std::move(other.memory_); - descriptor_ = std::move(other.descriptor_); - return *this; + devicePtr_ = std::move(other.devicePtr_); + memory_ = std::move(other.memory_); + descriptor_ = std::move(other.descriptor_); + return *this; } dnnl::memory DnnlMemoryWrapper::getMemory() const { - return memory_; + return memory_; } dnnl::memory::desc DnnlMemoryWrapper::getDescriptor() const { - return descriptor_; + return descriptor_; } dnnl::memory dnnlAlignOrdering( std::vector& net, std::vector>& netArgs, const dnnl::memory& memory, - const dnnl::memory::desc& desc) { - auto memoryOut = memory; - if (memory.get_desc() != desc) { - // use the ordering requested by the descriptor - memoryOut = - dnnl::memory(desc, detail::DnnlEngine::getInstance().getEngine()); - net.push_back(dnnl::reorder(memory, memoryOut)); - netArgs.push_back({{DNNL_ARG_FROM, memory}, {DNNL_ARG_TO, memoryOut}}); - } - return memoryOut; + const dnnl::memory::desc& desc +) { + auto memoryOut = memory; + if(memory.get_desc() != desc) { + // use the ordering requested by the descriptor + memoryOut = + dnnl::memory(desc, 
detail::DnnlEngine::getInstance().getEngine()); + net.push_back(dnnl::reorder(memory, memoryOut)); + netArgs.push_back({{DNNL_ARG_FROM, memory}, {DNNL_ARG_TO, memoryOut}}); + } + return memoryOut; } void executeNetwork( std::vector& net, - std::vector>& netArgs) { - if (net.size() != netArgs.size()) { - throw std::invalid_argument( - "executeNetwork - given different size nets and netArgs"); - } - // TODO{fl::Tensor}{macros} -- improve this to work with other backend interop - // If on the CPU backend, there isn't a AF computation stream that facilitates - // enforcing that inputs to computation are ready; we're required to wait - // until all AF operations are done - if (FL_BACKEND_CPU) { - fl::sync(); - } - - for (size_t i = 0; i < net.size(); ++i) { - net.at(i).execute(DnnlStream::getInstance().getStream(), netArgs.at(i)); - } - - // TODO{fl::Tensor}{macros} -- improve this to work with other backend interop - if (FL_BACKEND_CPU) { - // Block the executing thread until the work is complete - DnnlStream::getInstance().getStream().wait(); - } + std::vector>& netArgs +) { + if(net.size() != netArgs.size()) { + throw std::invalid_argument( + "executeNetwork - given different size nets and netArgs" + ); + } + // TODO{fl::Tensor}{macros} -- improve this to work with other backend interop + // If on the CPU backend, there isn't a AF computation stream that facilitates + // enforcing that inputs to computation are ready; we're required to wait + // until all AF operations are done + if(FL_BACKEND_CPU) { + fl::sync(); + } + + for(size_t i = 0; i < net.size(); ++i) { + net.at(i).execute(DnnlStream::getInstance().getStream(), netArgs.at(i)); + } + + // TODO{fl::Tensor}{macros} -- improve this to work with other backend interop + if(FL_BACKEND_CPU) { + // Block the executing thread until the work is complete + DnnlStream::getInstance().getStream().wait(); + } } dnnl::algorithm dnnlMapToPoolingMode(const PoolingMode mode) { - switch (mode) { - case PoolingMode::MAX: - return 
dnnl::algorithm::pooling_max; - case PoolingMode::AVG_INCLUDE_PADDING: - return dnnl::algorithm::pooling_avg_include_padding; - case PoolingMode::AVG_EXCLUDE_PADDING: - return dnnl::algorithm::pooling_avg_exclude_padding; - default: - throw std::invalid_argument("unsupported pooling mode for cuDNN"); - } + switch(mode) { + case PoolingMode::MAX: + return dnnl::algorithm::pooling_max; + case PoolingMode::AVG_INCLUDE_PADDING: + return dnnl::algorithm::pooling_avg_include_padding; + case PoolingMode::AVG_EXCLUDE_PADDING: + return dnnl::algorithm::pooling_avg_exclude_padding; + default: + throw std::invalid_argument("unsupported pooling mode for cuDNN"); + } } } // namespace fl diff --git a/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.h b/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.h index 07be6db..5869c49 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.h +++ b/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.h @@ -25,71 +25,72 @@ namespace detail { /** * A singleton class that contains a static instance of a dnnl::stream. */ -class DnnlStream { - public: - DnnlStream(dnnl::engine engine); - ~DnnlStream() = default; + class DnnlStream { + public: + DnnlStream(dnnl::engine engine); + ~DnnlStream() = default; - /// Prohibit assignment - DnnlStream& operator=(DnnlStream const& s) = delete; + /// Prohibit assignment + DnnlStream& operator=(DnnlStream const& s) = delete; - dnnl::stream& getStream(); + dnnl::stream& getStream(); - static DnnlStream& getInstance(); + static DnnlStream& getInstance(); - private: - dnnl::stream stream_; -}; + private: + dnnl::stream stream_; + }; /** * A singleton class that contains a static instance of a dnnl::engine. 
*/ -class DnnlEngine { - public: - DnnlEngine(); - ~DnnlEngine() = default; + class DnnlEngine { + public: + DnnlEngine(); + ~DnnlEngine() = default; - /// Prohibit assignment - DnnlEngine& operator=(DnnlEngine const& e) = delete; + /// Prohibit assignment + DnnlEngine& operator=(DnnlEngine const& e) = delete; - dnnl::engine& getEngine(); + dnnl::engine& getEngine(); - static DnnlEngine& getInstance(); + static DnnlEngine& getInstance(); - private: - dnnl::engine engine_; -}; + private: + dnnl::engine engine_; + }; /** * Helper for converting a Flashlight Shape into an DNNL-compatible input * for dnnl::memory::dims. */ -dnnl::memory::dims convertToDnnlDims(const std::vector& dims); -dnnl::memory::dims convertShapeToDnnlDims(const Shape& shape); + dnnl::memory::dims convertToDnnlDims(const std::vector& dims); + dnnl::memory::dims convertShapeToDnnlDims(const Shape& shape); /** * A light wrapper around dnnl::memory that manages underlying memory lifetime * in accordance with fl::DevicePtr. 
*/ -class DnnlMemoryWrapper { - public: - DnnlMemoryWrapper( - const Tensor& tensor, - dnnl::memory::dims dims, - dnnl::memory::format_tag format); - DnnlMemoryWrapper() = default; + class DnnlMemoryWrapper { + public: + DnnlMemoryWrapper( + const Tensor& tensor, + dnnl::memory::dims dims, + dnnl::memory::format_tag format + ); + DnnlMemoryWrapper() = default; - DnnlMemoryWrapper& operator=(DnnlMemoryWrapper&& other); + DnnlMemoryWrapper& operator=(DnnlMemoryWrapper&& other); - dnnl::memory getMemory() const; + dnnl::memory getMemory() const; - dnnl::memory::desc getDescriptor() const; + dnnl::memory::desc getDescriptor() const; - private: - dnnl::memory::desc descriptor_; - dnnl::memory memory_; - fl::DevicePtr devicePtr_; -}; + private: + dnnl::memory::desc descriptor_; + dnnl::memory memory_; + fl::DevicePtr devicePtr_; + }; /** * Given some an dnnl network (a ``std::vector``), a @@ -100,11 +101,12 @@ class DnnlMemoryWrapper { * If so, adds a ``dnnl::reorder`` layer to the network, and returns a new * memory descriptor that will be properly reordered. */ -dnnl::memory dnnlAlignOrdering( - std::vector& net, - std::vector>& netArgs, - const dnnl::memory& memory, - const dnnl::memory::desc& desc); + dnnl::memory dnnlAlignOrdering( + std::vector& net, + std::vector>& netArgs, + const dnnl::memory& memory, + const dnnl::memory::desc& desc + ); /** * Executes a sequence of DNNL primitives in the default execution stream with @@ -116,32 +118,33 @@ dnnl::memory dnnlAlignOrdering( * * Blocks calling thread until the enqueued work has been completed. */ -void executeNetwork( - std::vector& net, - std::vector>& args); + void executeNetwork( + std::vector& net, + std::vector>& args + ); /** * Given a flashlight pooling mode, returns the corresponding dnnl pooling * mode. */ -dnnl::algorithm dnnlMapToPoolingMode(const PoolingMode mode); + dnnl::algorithm dnnlMapToPoolingMode(const PoolingMode mode); /** * Maps an ArrayFire array datatype into the corresponding DNNL datatype. 
* * Needs to be explicitly inlined due to a bug with DNNL. */ -inline dnnl::memory::data_type dnnlMapToType(const fl::dtype t) { - if (t == fl::dtype::f16) { - return dnnl::memory::data_type::f16; - } else if (t == fl::dtype::f32) { - return dnnl::memory::data_type::f32; - } else if (t == fl::dtype::f64) { - throw std::invalid_argument("float64 is not supported by DNNL"); - } else { - throw std::invalid_argument("data type not supported with DNNL"); - } -} + inline dnnl::memory::data_type dnnlMapToType(const fl::dtype t) { + if(t == fl::dtype::f16) { + return dnnl::memory::data_type::f16; + } else if(t == fl::dtype::f32) { + return dnnl::memory::data_type::f32; + } else if(t == fl::dtype::f64) { + throw std::invalid_argument("float64 is not supported by DNNL"); + } else { + throw std::invalid_argument("data type not supported with DNNL"); + } + } } // namespace detail } // namespace fl diff --git a/flashlight/fl/autograd/tensor/backend/onednn/OneDnnAutogradExtension.cpp b/flashlight/fl/autograd/tensor/backend/onednn/OneDnnAutogradExtension.cpp index d180fec..9b55b8b 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/OneDnnAutogradExtension.cpp +++ b/flashlight/fl/autograd/tensor/backend/onednn/OneDnnAutogradExtension.cpp @@ -10,9 +10,10 @@ namespace fl { bool OneDnnAutogradExtension::isDataTypeSupported( - const fl::dtype& dtype) const { - // fp16 computation is not supported with onednn - return dtype != fl::dtype::f16; + const fl::dtype& dtype +) const { + // fp16 computation is not supported with onednn + return dtype != fl::dtype::f16; } } // namespace fl diff --git a/flashlight/fl/autograd/tensor/backend/onednn/OneDnnAutogradExtension.h b/flashlight/fl/autograd/tensor/backend/onednn/OneDnnAutogradExtension.h index 310ecb9..01c06fa 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/OneDnnAutogradExtension.h +++ b/flashlight/fl/autograd/tensor/backend/onednn/OneDnnAutogradExtension.h @@ -12,134 +12,143 @@ namespace fl { class 
OneDnnAutogradExtension : public AutogradExtension { - // TODO(jacobkahn): implement getEngine + // TODO(jacobkahn): implement getEngine - public: - bool isDataTypeSupported(const fl::dtype& dtype) const override; +public: + bool isDataTypeSupported(const fl::dtype& dtype) const override; - /**************************** Forward ****************************/ - Tensor conv2d( - const Tensor& input, - const Tensor& weights, - const Tensor& bias, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, - std::shared_ptr payload) override; + /**************************** Forward ****************************/ + Tensor conv2d( + const Tensor& input, + const Tensor& weights, + const Tensor& bias, + const int sx, + const int sy, + const int px, + const int py, + const int dx, + const int dy, + const int groups, + std::shared_ptr payload + ) override; - Tensor pool2d( - const Tensor& input, - const int wx, - const int wy, - const int sx, - const int sy, - const int px, - const int py, - const PoolingMode mode, - std::shared_ptr payload) override; + Tensor pool2d( + const Tensor& input, + const int wx, + const int wy, + const int sx, + const int sy, + const int px, + const int py, + const PoolingMode mode, + std::shared_ptr payload + ) override; - Tensor batchnorm( - Tensor& saveMean, - Tensor& saveVar, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - Tensor& runningMean, - Tensor& runningVar, - const std::vector& axes, - const bool train, - const double momentum, - const double epsilon, - std::shared_ptr payload) override; + Tensor batchnorm( + Tensor& saveMean, + Tensor& saveVar, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + Tensor& runningMean, + Tensor& runningVar, + const std::vector& axes, + const bool train, + const double momentum, + const double epsilon, + std::shared_ptr payload + ) override; - std::tuple rnn( - const Tensor& input, - const Tensor& hiddenState, 
- const Tensor& cellState, - const Tensor& weights, - const int hiddenSize, - const int numLayers, - const RnnMode mode, - const bool bidirectional, - const float dropout, - std::shared_ptr payload) override; + std::tuple rnn( + const Tensor& input, + const Tensor& hiddenState, + const Tensor& cellState, + const Tensor& weights, + const int hiddenSize, + const int numLayers, + const RnnMode mode, + const bool bidirectional, + const float dropout, + std::shared_ptr payload + ) override; - /**************************** Backward ****************************/ - // ]----- Convolution - Tensor conv2dBackwardData( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& weight, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, - std::shared_ptr dataGradBenchmark, - std::shared_ptr payload) override; + /**************************** Backward ****************************/ + // ]----- Convolution + Tensor conv2dBackwardData( + const Tensor& gradOutput, + const Tensor& input, + const Tensor& weight, + const int sx, + const int sy, + const int px, + const int py, + const int dx, + const int dy, + const int groups, + std::shared_ptr dataGradBenchmark, + std::shared_ptr payload + ) override; - std::pair conv2dBackwardFilterBias( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& weights, - const Tensor& bias, - const int sx, - const int sy, - const int px, - const int py, - const int dx, - const int dy, - const int groups, - std::shared_ptr filterBench, - std::shared_ptr biasBench, - std::shared_ptr autogradPayload) override; + std::pair conv2dBackwardFilterBias( + const Tensor& gradOutput, + const Tensor& input, + const Tensor& weights, + const Tensor& bias, + const int sx, + const int sy, + const int px, + const int py, + const int dx, + const int dy, + const int groups, + std::shared_ptr filterBench, + std::shared_ptr biasBench, + std::shared_ptr autogradPayload + ) override; - // 
]----- pool2D - Tensor pool2dBackward( - const Tensor& gradOutput, - const Tensor& input, - const Tensor& poolOutput, - const int wx, - const int wy, - const int sx, - const int sy, - const int px, - const int py, - const PoolingMode mode, - std::shared_ptr payload) override; + // ]----- pool2D + Tensor pool2dBackward( + const Tensor& gradOutput, + const Tensor& input, + const Tensor& poolOutput, + const int wx, + const int wy, + const int sx, + const int sy, + const int px, + const int py, + const PoolingMode mode, + std::shared_ptr payload + ) override; - // ]----- batchnorm - std::tuple batchnormBackward( - const Tensor& gradOutput, - const Tensor& saveMean, - const Tensor& saveVar, - const Tensor& input, - const Tensor& weight, - const std::vector& axes, - const bool train, - const float epsilon, - std::shared_ptr payload) override; + // ]----- batchnorm + std::tuple batchnormBackward( + const Tensor& gradOutput, + const Tensor& saveMean, + const Tensor& saveVar, + const Tensor& input, + const Tensor& weight, + const std::vector& axes, + const bool train, + const float epsilon, + std::shared_ptr payload + ) override; - // ]----- rnn - std::tuple rnnBackward( - const Tensor& input, - const Tensor& hiddenState, - const Tensor& cellState, - const Tensor& weights, - const std::shared_ptr gradData, - const Tensor& output, - const int numLayers, - const int hiddenSize, - const RnnMode mode, - const bool bidirectional, - const float dropProb, - std::shared_ptr payload) override; + // ]----- rnn + std::tuple rnnBackward( + const Tensor& input, + const Tensor& hiddenState, + const Tensor& cellState, + const Tensor& weights, + const std::shared_ptr gradData, + const Tensor& output, + const int numLayers, + const int hiddenSize, + const RnnMode mode, + const bool bidirectional, + const float dropProb, + std::shared_ptr payload + ) override; }; } // namespace fl diff --git a/flashlight/fl/autograd/tensor/backend/onednn/Pool2D.cpp 
b/flashlight/fl/autograd/tensor/backend/onednn/Pool2D.cpp index bf094b6..5d976e2 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/Pool2D.cpp +++ b/flashlight/fl/autograd/tensor/backend/onednn/Pool2D.cpp @@ -22,57 +22,60 @@ namespace fl { namespace { -constexpr size_t kWIdx = 0; -constexpr size_t kHIdx = 1; -constexpr size_t kChannelSizeIdx = 2; -constexpr size_t kBatchSizeIdx = 3; + constexpr size_t kWIdx = 0; + constexpr size_t kHIdx = 1; + constexpr size_t kChannelSizeIdx = 2; + constexpr size_t kBatchSizeIdx = 3; // Use memory::format_tag::any for memory formatting even if pool // inputs are shaped in a particular way. -constexpr auto formatAny = memory::format_tag::any; -constexpr auto formatNCHW = memory::format_tag::nchw; - -struct DimsData { - memory::dims inputDims; - memory::dims outputDims; - memory::dims windowDims; - memory::dims strideDims; - std::vector paddingDims; -}; - -DimsData getDimsData( - const Shape& input, - const Shape& output, - const int wx, - const int wy, - const int sx, - const int sy, - const int px, - const int py) { - DimsData d; - d.inputDims = detail::convertToDnnlDims( - {input.dim(kBatchSizeIdx), - input.dim(kChannelSizeIdx), - input.dim(kHIdx), - input.dim(kWIdx)}); - d.outputDims = detail::convertToDnnlDims( - {input.dim(kBatchSizeIdx), - input.dim(kChannelSizeIdx), - output.dim(kHIdx), - output.dim(kWIdx)}); - d.windowDims = {wy, wx}; - d.strideDims = {sy, sx}; - d.paddingDims = {py, px}; - return d; -} + constexpr auto formatAny = memory::format_tag::any; + constexpr auto formatNCHW = memory::format_tag::nchw; + + struct DimsData { + memory::dims inputDims; + memory::dims outputDims; + memory::dims windowDims; + memory::dims strideDims; + std::vector paddingDims; + }; + + DimsData getDimsData( + const Shape& input, + const Shape& output, + const int wx, + const int wy, + const int sx, + const int sy, + const int px, + const int py + ) { + DimsData d; + d.inputDims = detail::convertToDnnlDims( + 
{input.dim(kBatchSizeIdx), + input.dim(kChannelSizeIdx), + input.dim(kHIdx), + input.dim(kWIdx)} + ); + d.outputDims = detail::convertToDnnlDims( + {input.dim(kBatchSizeIdx), + input.dim(kChannelSizeIdx), + output.dim(kHIdx), + output.dim(kWIdx)} + ); + d.windowDims = {wy, wx}; + d.strideDims = {sy, sx}; + d.paddingDims = {py, px}; + return d; + } } // namespace struct OneDnnPool2DPayload : detail::AutogradPayloadData { - memory workspace; - memory outputMemory; - DimsData dimsData; - pooling_forward::primitive_desc poolingFwdPrimDesc; + memory workspace; + memory outputMemory; + DimsData dimsData; + pooling_forward::primitive_desc poolingFwdPrimDesc; }; Tensor OneDnnAutogradExtension::pool2d( @@ -84,95 +87,104 @@ Tensor OneDnnAutogradExtension::pool2d( const int px, const int py, const PoolingMode mode, - std::shared_ptr autogradPayload) { - const bool train = (autogradPayload != nullptr); - auto payload = std::make_shared(); - if (train) { - autogradPayload->data = payload; - } - - // inputX x inputY x channels x batch - auto ix = input.dim(kWIdx); - auto iy = input.ndim() > kHIdx ? input.dim(kHIdx) : 1; - auto c = input.ndim() > kChannelSizeIdx ? input.dim(kChannelSizeIdx) : 1; - auto b = input.ndim() > kBatchSizeIdx ? 
input.dim(kBatchSizeIdx) : 1; - - auto output = Tensor( - {1 + (ix + 2 * px - wx) / sx, 1 + (iy + 2 * py - wy) / sy, c, b}, - input.type()); - - payload->dimsData = - getDimsData({ix, iy, c, b}, output.shape(), wx, wy, sx, sy, px, py); - auto& d = payload->dimsData; - auto dataType = detail::dnnlMapToType(input.type()); - - // Memory desc - auto inputMD = memory::desc({d.inputDims}, dataType, formatNCHW); - auto outputMD = memory::desc({d.outputDims}, dataType, formatAny); - - // Memory - auto& dnnlEngine = detail::DnnlEngine::getInstance().getEngine(); - const detail::DnnlMemoryWrapper inputMemInit( - input, {d.inputDims}, formatNCHW); - const detail::DnnlMemoryWrapper outputMemInit( - output, {d.outputDims}, formatNCHW); - - // Choose a mode based on whether gradients are needed - auto forwardMode = train ? prop_kind::forward : prop_kind::forward_inference; - - // Descriptors - auto poolingMode = detail::dnnlMapToPoolingMode(mode); - payload->poolingFwdPrimDesc = pooling_forward::primitive_desc( - dnnlEngine, - forwardMode, - poolingMode, - inputMD, - outputMD, - d.strideDims, - d.windowDims, - memory::dims{0, 0}, // dilation -- TODO: add to API - d.paddingDims, - d.paddingDims); - auto& primDesc = payload->poolingFwdPrimDesc; - - // Network - std::vector network; - std::vector> fwdArgs; - // Reorder if needed - auto inputDesc = primDesc.src_desc(); - auto outputDesc = primDesc.dst_desc(); - auto inputMemory = detail::dnnlAlignOrdering( - network, fwdArgs, inputMemInit.getMemory(), inputDesc); - payload->outputMemory = outputMemInit.getMemory(); - if (outputMemInit.getMemory().get_desc() != outputDesc) { - payload->outputMemory = memory(outputDesc, dnnlEngine); - } - // Workspace and layer (only training mode requires a workspace) - std::shared_ptr pooling; - std::unordered_map fwdPoolingArgs; - fwdPoolingArgs[DNNL_ARG_SRC] = inputMemory; - fwdPoolingArgs[DNNL_ARG_DST] = payload->outputMemory; - if (train) { - payload->workspace = 
memory(primDesc.workspace_desc(), dnnlEngine); - pooling = std::make_shared(primDesc); - fwdPoolingArgs[DNNL_ARG_WORKSPACE] = payload->workspace; - } else { - pooling = std::make_shared(primDesc); - } - network.push_back(*pooling); - fwdArgs.push_back(fwdPoolingArgs); - - // Add output reordering if needed - if (payload->outputMemory != outputMemInit.getMemory()) { - network.push_back( - dnnl::reorder(payload->outputMemory, outputMemInit.getMemory())); - fwdArgs.push_back( - {{DNNL_ARG_FROM, payload->outputMemory}, - {DNNL_ARG_TO, outputMemInit.getMemory()}}); - } - - detail::executeNetwork(network, fwdArgs); - return output; + std::shared_ptr autogradPayload +) { + const bool train = (autogradPayload != nullptr); + auto payload = std::make_shared(); + if(train) { + autogradPayload->data = payload; + } + + // inputX x inputY x channels x batch + auto ix = input.dim(kWIdx); + auto iy = input.ndim() > kHIdx ? input.dim(kHIdx) : 1; + auto c = input.ndim() > kChannelSizeIdx ? input.dim(kChannelSizeIdx) : 1; + auto b = input.ndim() > kBatchSizeIdx ? input.dim(kBatchSizeIdx) : 1; + + auto output = Tensor( + {1 + (ix + 2 * px - wx) / sx, 1 + (iy + 2 * py - wy) / sy, c, b}, + input.type() + ); + + payload->dimsData = + getDimsData({ix, iy, c, b}, output.shape(), wx, wy, sx, sy, px, py); + auto& d = payload->dimsData; + auto dataType = detail::dnnlMapToType(input.type()); + + // Memory desc + auto inputMD = memory::desc({d.inputDims}, dataType, formatNCHW); + auto outputMD = memory::desc({d.outputDims}, dataType, formatAny); + + // Memory + auto& dnnlEngine = detail::DnnlEngine::getInstance().getEngine(); + const detail::DnnlMemoryWrapper inputMemInit( + input, {d.inputDims}, formatNCHW); + const detail::DnnlMemoryWrapper outputMemInit( + output, {d.outputDims}, formatNCHW); + + // Choose a mode based on whether gradients are needed + auto forwardMode = train ? 
prop_kind::forward : prop_kind::forward_inference; + + // Descriptors + auto poolingMode = detail::dnnlMapToPoolingMode(mode); + payload->poolingFwdPrimDesc = pooling_forward::primitive_desc( + dnnlEngine, + forwardMode, + poolingMode, + inputMD, + outputMD, + d.strideDims, + d.windowDims, + memory::dims{0, 0}, // dilation -- TODO: add to API + d.paddingDims, + d.paddingDims + ); + auto& primDesc = payload->poolingFwdPrimDesc; + + // Network + std::vector network; + std::vector> fwdArgs; + // Reorder if needed + auto inputDesc = primDesc.src_desc(); + auto outputDesc = primDesc.dst_desc(); + auto inputMemory = detail::dnnlAlignOrdering( + network, + fwdArgs, + inputMemInit.getMemory(), + inputDesc + ); + payload->outputMemory = outputMemInit.getMemory(); + if(outputMemInit.getMemory().get_desc() != outputDesc) { + payload->outputMemory = memory(outputDesc, dnnlEngine); + } + // Workspace and layer (only training mode requires a workspace) + std::shared_ptr pooling; + std::unordered_map fwdPoolingArgs; + fwdPoolingArgs[DNNL_ARG_SRC] = inputMemory; + fwdPoolingArgs[DNNL_ARG_DST] = payload->outputMemory; + if(train) { + payload->workspace = memory(primDesc.workspace_desc(), dnnlEngine); + pooling = std::make_shared(primDesc); + fwdPoolingArgs[DNNL_ARG_WORKSPACE] = payload->workspace; + } else { + pooling = std::make_shared(primDesc); + } + network.push_back(*pooling); + fwdArgs.push_back(fwdPoolingArgs); + + // Add output reordering if needed + if(payload->outputMemory != outputMemInit.getMemory()) { + network.push_back( + dnnl::reorder(payload->outputMemory, outputMemInit.getMemory()) + ); + fwdArgs.push_back( + {{DNNL_ARG_FROM, payload->outputMemory}, + {DNNL_ARG_TO, outputMemInit.getMemory()}} + ); + } + + detail::executeNetwork(network, fwdArgs); + return output; } Tensor OneDnnAutogradExtension::pool2dBackward( @@ -186,64 +198,67 @@ Tensor OneDnnAutogradExtension::pool2dBackward( const int px, const int py, const PoolingMode mode, - std::shared_ptr 
autogradPayload) { - if (!autogradPayload) { - throw std::invalid_argument( - "OneDnnAutogradExtension::pool2dBackward given null detail::AutogradPayload"); - } - auto payload = - std::static_pointer_cast(autogradPayload->data); - - auto gradInput = Tensor(input.shape(), fl::dtype::f32); - auto& dnnlEngineBwd = detail::DnnlEngine::getInstance().getEngine(); - - DimsData& d = payload->dimsData; - auto poolingMode = detail::dnnlMapToPoolingMode(mode); - - // Memory - const detail::DnnlMemoryWrapper gradInputMemInit( - gradInput, {d.inputDims}, formatNCHW); - const detail::DnnlMemoryWrapper gradOutputMemInit( - gradOutput, {d.outputDims}, formatNCHW); - - // Descriptors - // Memory descriptors from initialized memory must be used since - // pooling_backward descriptors require an ordering - auto gradInputMD = gradInputMemInit.getMemory().get_desc(); - auto gradOutputMD = gradOutputMemInit.getMemory().get_desc(); - auto bwdPrimitiveDesc = pooling_backward::primitive_desc( - dnnlEngineBwd, - poolingMode, - gradInputMD, - gradOutputMD, - d.strideDims, - d.windowDims, - memory::dims{0, 0}, // dilation - TODO: add to API - d.paddingDims, - d.paddingDims, - payload->poolingFwdPrimDesc // hint - ); - - std::vector networkBackward; - std::vector> bwdArgs; - // Reorder output memory if required - auto gradOutputMemory = detail::dnnlAlignOrdering( - networkBackward, - bwdArgs, - gradOutputMemInit.getMemory(), - payload->outputMemory.get_desc()); - - auto poolBwd = pooling_backward(bwdPrimitiveDesc); - std::unordered_map bwdPoolingArgs = { - {DNNL_ARG_DIFF_SRC, gradInputMemInit.getMemory()}, - {DNNL_ARG_DIFF_DST, gradOutputMemory}, - {DNNL_ARG_WORKSPACE, payload->workspace}}; - bwdArgs.push_back(bwdPoolingArgs); - networkBackward.push_back(poolBwd); - - detail::executeNetwork(networkBackward, bwdArgs); - - return gradInput; + std::shared_ptr autogradPayload +) { + if(!autogradPayload) { + throw std::invalid_argument( + "OneDnnAutogradExtension::pool2dBackward given null 
detail::AutogradPayload" + ); + } + auto payload = + std::static_pointer_cast(autogradPayload->data); + + auto gradInput = Tensor(input.shape(), fl::dtype::f32); + auto& dnnlEngineBwd = detail::DnnlEngine::getInstance().getEngine(); + + DimsData& d = payload->dimsData; + auto poolingMode = detail::dnnlMapToPoolingMode(mode); + + // Memory + const detail::DnnlMemoryWrapper gradInputMemInit( + gradInput, {d.inputDims}, formatNCHW); + const detail::DnnlMemoryWrapper gradOutputMemInit( + gradOutput, {d.outputDims}, formatNCHW); + + // Descriptors + // Memory descriptors from initialized memory must be used since + // pooling_backward descriptors require an ordering + auto gradInputMD = gradInputMemInit.getMemory().get_desc(); + auto gradOutputMD = gradOutputMemInit.getMemory().get_desc(); + auto bwdPrimitiveDesc = pooling_backward::primitive_desc( + dnnlEngineBwd, + poolingMode, + gradInputMD, + gradOutputMD, + d.strideDims, + d.windowDims, + memory::dims{0, 0}, // dilation - TODO: add to API + d.paddingDims, + d.paddingDims, + payload->poolingFwdPrimDesc // hint + ); + + std::vector networkBackward; + std::vector> bwdArgs; + // Reorder output memory if required + auto gradOutputMemory = detail::dnnlAlignOrdering( + networkBackward, + bwdArgs, + gradOutputMemInit.getMemory(), + payload->outputMemory.get_desc() + ); + + auto poolBwd = pooling_backward(bwdPrimitiveDesc); + std::unordered_map bwdPoolingArgs = { + {DNNL_ARG_DIFF_SRC, gradInputMemInit.getMemory()}, + {DNNL_ARG_DIFF_DST, gradOutputMemory}, + {DNNL_ARG_WORKSPACE, payload->workspace}}; + bwdArgs.push_back(bwdPoolingArgs); + networkBackward.push_back(poolBwd); + + detail::executeNetwork(networkBackward, bwdArgs); + + return gradInput; } } // namespace fl diff --git a/flashlight/fl/autograd/tensor/backend/onednn/RNN.cpp b/flashlight/fl/autograd/tensor/backend/onednn/RNN.cpp index dd1d8a0..e2b2836 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/RNN.cpp +++ 
b/flashlight/fl/autograd/tensor/backend/onednn/RNN.cpp @@ -20,407 +20,451 @@ namespace fl { namespace { -struct ParsedWeightsAndBias { - // First layer - will be empty if inSize == hiddenSize - Tensor weightsInput1L; - Tensor weightsHidden1L; - Tensor bias1L; - // All other layers - Tensor weightsInput; - Tensor weightsHidden; - Tensor bias; -}; + struct ParsedWeightsAndBias { + // First layer - will be empty if inSize == hiddenSize + Tensor weightsInput1L; + Tensor weightsHidden1L; + Tensor bias1L; + // All other layers + Tensor weightsInput; + Tensor weightsHidden; + Tensor bias; + }; // Each gate's weights have dimensions d1 x d2 -Tensor reorderLbrGruWeights(int d1, int d2, const Tensor& weights) { - // LBR GRU requires switch the given the r, u, o gate order from cuDNN to u, - // r, o as required by oneDNN (this from empirical verification) - int weightsSize = d1 * d2; - if (weights.elements() != weightsSize * 3) { - throw std::invalid_argument( - "RNN reorderLbrGruWeights given invalid weights tensor or dims - " - "weights of size " + - std::to_string(weights.elements()) + " which should be exactly " + - std::to_string(weightsSize * 3)); - } - return fl::concatenate( - 0, - weights.flat(fl::range(weightsSize, 2 * weightsSize)), - weights.flat(fl::range(0, weightsSize)), - weights.flat(fl::range(2 * weightsSize, fl::end))); -} + Tensor reorderLbrGruWeights(int d1, int d2, const Tensor& weights) { + // LBR GRU requires switch the given the r, u, o gate order from cuDNN to u, + // r, o as required by oneDNN (this from empirical verification) + int weightsSize = d1 * d2; + if(weights.elements() != weightsSize * 3) { + throw std::invalid_argument( + "RNN reorderLbrGruWeights given invalid weights tensor or dims - " + "weights of size " + + std::to_string(weights.elements()) + " which should be exactly " + + std::to_string(weightsSize * 3) + ); + } + return fl::concatenate( + 0, + weights.flat(fl::range(weightsSize, 2 * weightsSize)), + weights.flat(fl::range(0, 
weightsSize)), + weights.flat(fl::range(2 * weightsSize, fl::end)) + ); + } /** * Converts flat cuDNN weights into the corresponding oneDNN onednn RNN weights. */ -ParsedWeightsAndBias parseWeights( - const Tensor& weights, - RnnMode mode, - int numLayers, - int directionMult, - int inSize, - int numGates, - int hiddenSize) { - ParsedWeightsAndBias out; - - // Per-layer sizes for weightsInput and weightsHidden. - // If inSize == hiddenSize, then weightsInputSize == weightsHiddenSize for all - // layers, else all but the first layer - int weightsInputSize1L = directionMult * inSize * numGates * hiddenSize; - int weightsHiddenSize = directionMult * hiddenSize * numGates * hiddenSize; - int weightsInputSize = weightsHiddenSize; - int lbrGruBias = mode == RnnMode::GRU ? 1 : 0; - int biasSize = - numLayers * directionMult * (numGates + lbrGruBias) * hiddenSize; - - bool firstLayerDifferent = inSize != hiddenSize; - // Adjusted if skipping first layer parsing - int numWeightsLayers = firstLayerDifferent ? numLayers - 1 : numLayers; - int weightsOffset = - firstLayerDifferent ? weightsInputSize1L + weightsHiddenSize : 0; - // If skipping the first layer, parse then skip over the first layer - // weights and parse the remaining layers. 
Parsing all bias layers is still - // fine since biases for each layer have the same size - if (firstLayerDifferent) { - out.weightsInput1L = weights.flat(fl::range(weightsInputSize1L)); - out.weightsHidden1L = weights.flat( - fl::range(weightsInputSize1L, weightsInputSize1L + weightsHiddenSize)); - - if (mode == RnnMode::GRU) { - out.weightsInput1L = - reorderLbrGruWeights(inSize, hiddenSize, out.weightsInput1L); - out.weightsHidden1L = - reorderLbrGruWeights(hiddenSize, hiddenSize, out.weightsHidden1L); - } - } - - auto weightsFlat = weights.flatten().astype(weights.type()); - // cuDNN RNN weights, for each layer, are arranged with a chunk of - // input-hidden weights for each layer followed by a chunk of hidden-hidden - // weights for each layer: - // {[layers x [hiddenSize, inputSize]], [layers x [hiddenSize, hiddenSize]] } - // Rearrange this to what oneDNN expects (or will reorder if not optimal), - // which is numLayers chunks of two chunks containing input-hidden and - // hidden-hidden: - // {[layers x [[hiddenSize x inSize], [hiddenSize x hiddenSize]]]} - // Note that the loop is over the total number of layers in case we'r doing a - // single-layer operation where input size and hidden size are different but - // we'll call another primitive with the output of that first layer as the - // input to the next layers - auto weightsInput = Tensor({0}, weights.type()); - auto weightsHidden = Tensor({0}, weights.type()); - Tensor weightsFlatOffset = - weightsFlat.flat(fl::range(weightsOffset, fl::end)); - // Specifically ignore the first layer's weights, so inSize == hiddenSize - for (int i = 0; i < numWeightsLayers; ++i) { - // number of input/hidden weights - // TODO: Will change for bidirectional - int chunkSize = hiddenSize * hiddenSize * numGates; - // weights per layer - int layerChunkSize = chunkSize + chunkSize; - - // Grab input-hidden weights and chunk them together - auto inputWeightsChunk = weightsFlatOffset.flat( - fl::range(layerChunkSize * i, 
layerChunkSize * i + chunkSize)); - // Grab hidden-hidden weights and chunk them together - auto inputHiddenChunk = weightsFlatOffset.flat(fl::range( - layerChunkSize * i + chunkSize, - layerChunkSize * i + chunkSize + chunkSize)); - - if (mode == RnnMode::GRU) { - inputWeightsChunk = - reorderLbrGruWeights(hiddenSize, hiddenSize, inputWeightsChunk); - inputHiddenChunk = - reorderLbrGruWeights(hiddenSize, hiddenSize, inputHiddenChunk); - } - - weightsInput = fl::concatenate(2, weightsInput, inputWeightsChunk); - weightsHidden = fl::concatenate(2, weightsHidden, inputHiddenChunk); - } - out.weightsInput = weightsInput; - out.weightsHidden = weightsHidden; - - // Reduce the weights to form biases. cuDNN uses two separate bias terms: - // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnRNNMode_t - - // oneDNN expects only one bias term. Sum together the coefficients for both - // bias terms to get a single bias term for oneDNN. The gradients for - // each term can be computed as one since the gradients with respect to - // the bias subarrays will simply be half of the computed gradient with - // oneDNN - Tensor bias(weights.type()); - int biasStartOffset = numLayers * weightsHiddenSize + - (numLayers - 1) * weightsInputSize + weightsInputSize1L; - // In vanilla RNN modes, the biases can be simply added: - // two biases for each bias in fl cuDNN with CUDNN_RNN_DOUBLE_BIAS (default) - int numBiases = 2; - // First, grab a subarray which contains only both bias terms; then add them - Tensor biasFlat = weightsFlat.flat(fl::range(biasStartOffset, fl::end)); - // Layout is: {numLayers x [numBiases x [bias shape]]} - for (int i = 0; i < numLayers; ++i) { - if (mode == RnnMode::GRU) { - int lbrGruChunkSize = hiddenSize * 6; - // In the case of the LBR GRU, there's an extra bias term which shouldn't - // be combined with the first two pairs of biases. Six chunks total. 
- // cuDNN --> oneDNN transformation for ordering: - // r1, u1, o, r2, u2, u' --> u1 + u2, r1 + r2, o, u' - int base = i * lbrGruChunkSize; - // The sum of the following tensors yields the correct bias - // u1, r1, o, u' - auto biases1 = fl::concatenate( - 0, - // u1 -- [1, 2] - biasFlat.flat( - fl::range(base + hiddenSize * 1, base + hiddenSize * 2)), - // r1 -- [0, 1] - biasFlat.flat( - fl::range(base + hiddenSize * 0, base + hiddenSize * 1)), - // o -- [2, 3] - biasFlat.flat( - fl::range(base + hiddenSize * 2, base + hiddenSize * 3)), - // 'u -- [5, 6] - biasFlat.flat( - fl::range(base + hiddenSize * 5, base + hiddenSize * 6))); - // u2, r2, 0, 0 - auto biases2 = fl::concatenate( - 0, - // u2 -- [4, 5] - biasFlat.flat( - fl::range(base + hiddenSize * 4, base + hiddenSize * 5)), - // r2 -- [3, 4] - biasFlat.flat( - fl::range(base + hiddenSize * 3, base + hiddenSize * 4)), - // zeroes to add to o and u' - fl::full({hiddenSize * 2}, 0., biasFlat.type())); - auto layerBiasCombined = biases1 + biases2; - bias = fl::concatenate(0, bias, layerBiasCombined); - } else { - // The number of bias terms in the tensor per-layer - int layerStride = biasSize / numLayers * numBiases; - auto biases1 = biasFlat(fl::range( - layerStride * i, layerStride * i + layerStride / numBiases)); - auto biases2 = biasFlat(fl::range( - layerStride * i + layerStride / numBiases, layerStride * (i + 1))); - auto layerBiasCombined = biases1 + biases2; - bias = fl::concatenate(0, bias, layerBiasCombined); - } - } - - if (firstLayerDifferent) { - out.bias1L = bias.flat(fl::range(biasSize / numLayers)); - if (numLayers > 1) { - // bias for the second --> last layer - bias = bias.flat(fl::range(biasSize / numLayers, fl::end)); + ParsedWeightsAndBias parseWeights( + const Tensor& weights, + RnnMode mode, + int numLayers, + int directionMult, + int inSize, + int numGates, + int hiddenSize + ) { + ParsedWeightsAndBias out; + + // Per-layer sizes for weightsInput and weightsHidden. 
+ // If inSize == hiddenSize, then weightsInputSize == weightsHiddenSize for all + // layers, else all but the first layer + int weightsInputSize1L = directionMult * inSize * numGates * hiddenSize; + int weightsHiddenSize = directionMult * hiddenSize * numGates * hiddenSize; + int weightsInputSize = weightsHiddenSize; + int lbrGruBias = mode == RnnMode::GRU ? 1 : 0; + int biasSize = + numLayers * directionMult * (numGates + lbrGruBias) * hiddenSize; + + bool firstLayerDifferent = inSize != hiddenSize; + // Adjusted if skipping first layer parsing + int numWeightsLayers = firstLayerDifferent ? numLayers - 1 : numLayers; + int weightsOffset = + firstLayerDifferent ? weightsInputSize1L + weightsHiddenSize : 0; + // If skipping the first layer, parse then skip over the first layer + // weights and parse the remaining layers. Parsing all bias layers is still + // fine since biases for each layer have the same size + if(firstLayerDifferent) { + out.weightsInput1L = weights.flat(fl::range(weightsInputSize1L)); + out.weightsHidden1L = weights.flat( + fl::range(weightsInputSize1L, weightsInputSize1L + weightsHiddenSize) + ); + + if(mode == RnnMode::GRU) { + out.weightsInput1L = + reorderLbrGruWeights(inSize, hiddenSize, out.weightsInput1L); + out.weightsHidden1L = + reorderLbrGruWeights(hiddenSize, hiddenSize, out.weightsHidden1L); + } + } + + auto weightsFlat = weights.flatten().astype(weights.type()); + // cuDNN RNN weights, for each layer, are arranged with a chunk of + // input-hidden weights for each layer followed by a chunk of hidden-hidden + // weights for each layer: + // {[layers x [hiddenSize, inputSize]], [layers x [hiddenSize, hiddenSize]] } + // Rearrange this to what oneDNN expects (or will reorder if not optimal), + // which is numLayers chunks of two chunks containing input-hidden and + // hidden-hidden: + // {[layers x [[hiddenSize x inSize], [hiddenSize x hiddenSize]]]} + // Note that the loop is over the total number of layers in case we'r doing a + // 
single-layer operation where input size and hidden size are different but + // we'll call another primitive with the output of that first layer as the + // input to the next layers + auto weightsInput = Tensor({0}, weights.type()); + auto weightsHidden = Tensor({0}, weights.type()); + Tensor weightsFlatOffset = + weightsFlat.flat(fl::range(weightsOffset, fl::end)); + // Specifically ignore the first layer's weights, so inSize == hiddenSize + for(int i = 0; i < numWeightsLayers; ++i) { + // number of input/hidden weights + // TODO: Will change for bidirectional + int chunkSize = hiddenSize * hiddenSize * numGates; + // weights per layer + int layerChunkSize = chunkSize + chunkSize; + + // Grab input-hidden weights and chunk them together + auto inputWeightsChunk = weightsFlatOffset.flat( + fl::range(layerChunkSize * i, layerChunkSize * i + chunkSize) + ); + // Grab hidden-hidden weights and chunk them together + auto inputHiddenChunk = weightsFlatOffset.flat( + fl::range( + layerChunkSize * i + chunkSize, + layerChunkSize * i + chunkSize + chunkSize + ) + ); + + if(mode == RnnMode::GRU) { + inputWeightsChunk = + reorderLbrGruWeights(hiddenSize, hiddenSize, inputWeightsChunk); + inputHiddenChunk = + reorderLbrGruWeights(hiddenSize, hiddenSize, inputHiddenChunk); + } + + weightsInput = fl::concatenate(2, weightsInput, inputWeightsChunk); + weightsHidden = fl::concatenate(2, weightsHidden, inputHiddenChunk); + } + out.weightsInput = weightsInput; + out.weightsHidden = weightsHidden; + + // Reduce the weights to form biases. cuDNN uses two separate bias terms: + // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnRNNMode_t - + // oneDNN expects only one bias term. Sum together the coefficients for both + // bias terms to get a single bias term for oneDNN. 
The gradients for + // each term can be computed as one since the gradients with respect to + // the bias subarrays will simply be half of the computed gradient with + // oneDNN + Tensor bias(weights.type()); + int biasStartOffset = numLayers * weightsHiddenSize + + (numLayers - 1) * weightsInputSize + weightsInputSize1L; + // In vanilla RNN modes, the biases can be simply added: + // two biases for each bias in fl cuDNN with CUDNN_RNN_DOUBLE_BIAS (default) + int numBiases = 2; + // First, grab a subarray which contains only both bias terms; then add them + Tensor biasFlat = weightsFlat.flat(fl::range(biasStartOffset, fl::end)); + // Layout is: {numLayers x [numBiases x [bias shape]]} + for(int i = 0; i < numLayers; ++i) { + if(mode == RnnMode::GRU) { + int lbrGruChunkSize = hiddenSize * 6; + // In the case of the LBR GRU, there's an extra bias term which shouldn't + // be combined with the first two pairs of biases. Six chunks total. + // cuDNN --> oneDNN transformation for ordering: + // r1, u1, o, r2, u2, u' --> u1 + u2, r1 + r2, o, u' + int base = i * lbrGruChunkSize; + // The sum of the following tensors yields the correct bias + // u1, r1, o, u' + auto biases1 = fl::concatenate( + 0, + // u1 -- [1, 2] + biasFlat.flat( + fl::range(base + hiddenSize * 1, base + hiddenSize * 2) + ), + // r1 -- [0, 1] + biasFlat.flat( + fl::range(base + hiddenSize * 0, base + hiddenSize * 1) + ), + // o -- [2, 3] + biasFlat.flat( + fl::range(base + hiddenSize * 2, base + hiddenSize * 3) + ), + // 'u -- [5, 6] + biasFlat.flat( + fl::range(base + hiddenSize * 5, base + hiddenSize * 6) + ) + ); + // u2, r2, 0, 0 + auto biases2 = fl::concatenate( + 0, + // u2 -- [4, 5] + biasFlat.flat( + fl::range(base + hiddenSize * 4, base + hiddenSize * 5) + ), + // r2 -- [3, 4] + biasFlat.flat( + fl::range(base + hiddenSize * 3, base + hiddenSize * 4) + ), + // zeroes to add to o and u' + fl::full({hiddenSize* 2}, 0., biasFlat.type()) + ); + auto layerBiasCombined = biases1 + biases2; + bias = 
fl::concatenate(0, bias, layerBiasCombined); + } else { + // The number of bias terms in the tensor per-layer + int layerStride = biasSize / numLayers * numBiases; + auto biases1 = biasFlat( + fl::range( + layerStride * i, + layerStride * i + layerStride / numBiases + ) + ); + auto biases2 = biasFlat( + fl::range( + layerStride * i + layerStride / numBiases, + layerStride * (i + 1) + ) + ); + auto layerBiasCombined = biases1 + biases2; + bias = fl::concatenate(0, bias, layerBiasCombined); + } + } + + if(firstLayerDifferent) { + out.bias1L = bias.flat(fl::range(biasSize / numLayers)); + if(numLayers > 1) { + // bias for the second --> last layer + bias = bias.flat(fl::range(biasSize / numLayers, fl::end)); + } + } + out.bias = bias; + + // Case for a single layer of different in/hidden size + if(firstLayerDifferent && numLayers == 1) { + out.weightsInput = out.weightsInput1L; + out.weightsHidden = out.weightsHidden1L; + out.bias = out.bias1L; + } + + return out; } - } - out.bias = bias; - // Case for a single layer of different in/hidden size - if (firstLayerDifferent && numLayers == 1) { - out.weightsInput = out.weightsInput1L; - out.weightsHidden = out.weightsHidden1L; - out.bias = out.bias1L; - } - - return out; -} - -struct RnnResult { - dnnl::memory workspace; - Tensor y; // output - Tensor hy; // hidden output - Tensor cy; // cell output -}; + struct RnnResult { + dnnl::memory workspace; + Tensor y; // output + Tensor hy; // hidden output + Tensor cy; // cell output + }; /* * Does forward for a single onednn RNN primitive */ -RnnResult rnnImpl( - const Tensor& input, - const Tensor& hiddenState, - const Tensor& cellState, - const Tensor& weightsInput, - const Tensor& weightsHidden, - const Tensor& bias, - int hiddenSize, - int numLayers, - RnnMode mode, - dnnl::algorithm activation, - int numGates, - dnnl::rnn_direction direction, - int directionMult, - dnnl::prop_kind kind, - float dropout) { - RnnResult result; - auto dnnlEngine = 
detail::DnnlEngine::getInstance().getEngine(); - - // Dimensions - int inSize = input.dim(0); - int batchSize = input.ndim() < 2 ? 1 : input.dim(1); - int seqLength = input.ndim() < 3 ? 1 : input.dim(2); - dnnl::memory::dims inputDims = {seqLength, batchSize, inSize}; - dnnl::memory::dims outputDims = { - seqLength, batchSize, hiddenSize * directionMult}; - auto dType = detail::dnnlMapToType(input.type()); - int totalLayers = numLayers; - int outSize = hiddenSize; - dnnl::memory::dims hDims = { - totalLayers, directionMult, batchSize, hiddenSize}; - dnnl::memory::dims cDims = { - totalLayers, directionMult, batchSize, hiddenSize}; - int extraBias = mode == RnnMode::GRU ? 1 : 0; // for LBR GRU - dnnl::memory::dims biasDims = { - numLayers, directionMult, numGates + extraBias, hiddenSize}; - // ldigo - dnnl::memory::dims weightsInputDims = { - numLayers, directionMult, inSize, numGates, hiddenSize}; - dnnl::memory::dims weightsHiddenDims = { - numLayers, directionMult, hiddenSize, numGates, hiddenSize}; - - // Out tensors: output (y), hidden state output (hy), cell state output (cy) - auto y = Tensor({outSize, batchSize, seqLength}, input.type()); - auto hy = Tensor({hiddenSize, batchSize, totalLayers}, input.type()); - Tensor cy; - if (mode == RnnMode::LSTM) { - cy = Tensor(hy.shape(), input.type()); - } - - // Memory for forward - auto tnc = dnnl::memory::format_tag::tnc; - auto ldnc = dnnl::memory::format_tag::ldnc; - auto ldgoi = dnnl::memory::format_tag::ldgoi; - auto ldgo = dnnl::memory::format_tag::ldgo; - const detail::DnnlMemoryWrapper inputMemInit( - input.asContiguousTensor(), {inputDims}, tnc); - const detail::DnnlMemoryWrapper outputMemInit(y, {outputDims}, tnc); - detail::DnnlMemoryWrapper hiddenInMemInit; - if (!hiddenState.isEmpty()) { - hiddenInMemInit = detail::DnnlMemoryWrapper( - hiddenState.asContiguousTensor(), {hDims}, ldnc); - } - const detail::DnnlMemoryWrapper hiddenOutMemInit(hy, {hDims}, ldnc); - const detail::DnnlMemoryWrapper 
weightsInputMemRawInit( - weightsInput.asContiguousTensor(), {weightsInputDims}, ldgoi); - const detail::DnnlMemoryWrapper weightsHiddenMemRawInit( - weightsHidden.asContiguousTensor(), {weightsHiddenDims}, ldgoi); - const detail::DnnlMemoryWrapper biasMemInit( - bias.asContiguousTensor(), {biasDims}, ldgo); - - // TODO(jacobkahn): don't force a format tag - use any and do a reorder based - // on the format of the primitive - what it says - like you're supposed to - // Primitive for reordering input weights: ldgoi --> ldigo - auto weightsInputMemDesc = dnnl::memory::desc( - weightsInputDims, dType, dnnl::memory::format_tag::ldigo); - auto weightsInputMemInit = dnnl::memory(weightsInputMemDesc, dnnlEngine); - // Primitive for reordering iter/hidden weights: ldgoi --> ldigo - auto weightsHiddenMemDesc = dnnl::memory::desc( - weightsHiddenDims, dType, dnnl::memory::format_tag::ldigo); - auto weightsHiddenMemInit = dnnl::memory(weightsHiddenMemDesc, dnnlEngine); - - // Add arguments - std::unordered_map rnnFwdArgs = { - {DNNL_ARG_SRC_LAYER, inputMemInit.getMemory()}, - {DNNL_ARG_SRC_ITER, hiddenInMemInit.getMemory()}, - {DNNL_ARG_WEIGHTS_LAYER, weightsInputMemInit}, - {DNNL_ARG_WEIGHTS_ITER, weightsHiddenMemInit}, - {DNNL_ARG_BIAS, biasMemInit.getMemory()}, - {DNNL_ARG_DST_LAYER, outputMemInit.getMemory()}, - {DNNL_ARG_DST_ITER, hiddenOutMemInit.getMemory()}}; - - // Workspace memory, if needed - dnnl::memory workspace; - std::vector network; - std::vector> fwdArgs; - - // reorder input weights - network.push_back( - dnnl::reorder(weightsInputMemRawInit.getMemory(), weightsInputMemInit)); - fwdArgs.push_back( - {{DNNL_ARG_FROM, weightsInputMemRawInit.getMemory()}, - {DNNL_ARG_TO, weightsInputMemInit}}); - // reorder iter weights - network.push_back( - dnnl::reorder(weightsHiddenMemRawInit.getMemory(), weightsHiddenMemInit)); - fwdArgs.push_back( - {{DNNL_ARG_FROM, weightsHiddenMemRawInit.getMemory()}, - {DNNL_ARG_TO, weightsHiddenMemInit}}); - - // Initialize 
descriptors - if (mode == RnnMode::RELU || mode == RnnMode::TANH) { - auto vanillaPd = dnnl::vanilla_rnn_forward::primitive_desc( - dnnlEngine, - kind, - activation, - direction, - inputMemInit.getDescriptor(), - hiddenInMemInit.getDescriptor(), - weightsInputMemDesc, // weights "layer" - weightsHiddenMemDesc, // weights "iter" - biasMemInit.getDescriptor(), - outputMemInit.getDescriptor(), - hiddenOutMemInit.getDescriptor()); - network.push_back(dnnl::vanilla_rnn_forward(vanillaPd)); - workspace = dnnl::memory(vanillaPd.workspace_desc(), dnnlEngine); - - } else if (mode == RnnMode::LSTM) { - // LSTM-only - // input cell state - // TODO(jacobkahn): function that takes the array and - // returns the desciptor and memory -- takes an argument for - // which determines whether or not it's ok to return empty - // descriptors if the array is empty - detail::DnnlMemoryWrapper cellInMemInit; - if (!cellState.isEmpty()) { - cellInMemInit = detail::DnnlMemoryWrapper( - cellState.asContiguousTensor(), {cDims}, ldnc); + RnnResult rnnImpl( + const Tensor& input, + const Tensor& hiddenState, + const Tensor& cellState, + const Tensor& weightsInput, + const Tensor& weightsHidden, + const Tensor& bias, + int hiddenSize, + int numLayers, + RnnMode mode, + dnnl::algorithm activation, + int numGates, + dnnl::rnn_direction direction, + int directionMult, + dnnl::prop_kind kind, + float dropout + ) { + RnnResult result; + auto dnnlEngine = detail::DnnlEngine::getInstance().getEngine(); + + // Dimensions + int inSize = input.dim(0); + int batchSize = input.ndim() < 2 ? 1 : input.dim(1); + int seqLength = input.ndim() < 3 ? 
1 : input.dim(2); + dnnl::memory::dims inputDims = {seqLength, batchSize, inSize}; + dnnl::memory::dims outputDims = { + seqLength, batchSize, hiddenSize* directionMult}; + auto dType = detail::dnnlMapToType(input.type()); + int totalLayers = numLayers; + int outSize = hiddenSize; + dnnl::memory::dims hDims = { + totalLayers, directionMult, batchSize, hiddenSize}; + dnnl::memory::dims cDims = { + totalLayers, directionMult, batchSize, hiddenSize}; + int extraBias = mode == RnnMode::GRU ? 1 : 0; // for LBR GRU + dnnl::memory::dims biasDims = { + numLayers, directionMult, numGates + extraBias, hiddenSize}; + // ldigo + dnnl::memory::dims weightsInputDims = { + numLayers, directionMult, inSize, numGates, hiddenSize}; + dnnl::memory::dims weightsHiddenDims = { + numLayers, directionMult, hiddenSize, numGates, hiddenSize}; + + // Out tensors: output (y), hidden state output (hy), cell state output (cy) + auto y = Tensor({outSize, batchSize, seqLength}, input.type()); + auto hy = Tensor({hiddenSize, batchSize, totalLayers}, input.type()); + Tensor cy; + if(mode == RnnMode::LSTM) { + cy = Tensor(hy.shape(), input.type()); + } + + // Memory for forward + auto tnc = dnnl::memory::format_tag::tnc; + auto ldnc = dnnl::memory::format_tag::ldnc; + auto ldgoi = dnnl::memory::format_tag::ldgoi; + auto ldgo = dnnl::memory::format_tag::ldgo; + const detail::DnnlMemoryWrapper inputMemInit( + input.asContiguousTensor(), {inputDims}, tnc); + const detail::DnnlMemoryWrapper outputMemInit(y, {outputDims}, tnc); + detail::DnnlMemoryWrapper hiddenInMemInit; + if(!hiddenState.isEmpty()) { + hiddenInMemInit = detail::DnnlMemoryWrapper( + hiddenState.asContiguousTensor(), + {hDims}, + ldnc + ); + } + const detail::DnnlMemoryWrapper hiddenOutMemInit(hy, {hDims}, ldnc); + const detail::DnnlMemoryWrapper weightsInputMemRawInit( + weightsInput.asContiguousTensor(), {weightsInputDims}, ldgoi); + const detail::DnnlMemoryWrapper weightsHiddenMemRawInit( + weightsHidden.asContiguousTensor(), 
{weightsHiddenDims}, ldgoi); + const detail::DnnlMemoryWrapper biasMemInit( + bias.asContiguousTensor(), {biasDims}, ldgo); + + // TODO(jacobkahn): don't force a format tag - use any and do a reorder based + // on the format of the primitive - what it says - like you're supposed to + // Primitive for reordering input weights: ldgoi --> ldigo + auto weightsInputMemDesc = dnnl::memory::desc( + weightsInputDims, + dType, + dnnl::memory::format_tag::ldigo + ); + auto weightsInputMemInit = dnnl::memory(weightsInputMemDesc, dnnlEngine); + // Primitive for reordering iter/hidden weights: ldgoi --> ldigo + auto weightsHiddenMemDesc = dnnl::memory::desc( + weightsHiddenDims, + dType, + dnnl::memory::format_tag::ldigo + ); + auto weightsHiddenMemInit = dnnl::memory(weightsHiddenMemDesc, dnnlEngine); + + // Add arguments + std::unordered_map rnnFwdArgs = { + {DNNL_ARG_SRC_LAYER, inputMemInit.getMemory()}, + {DNNL_ARG_SRC_ITER, hiddenInMemInit.getMemory()}, + {DNNL_ARG_WEIGHTS_LAYER, weightsInputMemInit}, + {DNNL_ARG_WEIGHTS_ITER, weightsHiddenMemInit}, + {DNNL_ARG_BIAS, biasMemInit.getMemory()}, + {DNNL_ARG_DST_LAYER, outputMemInit.getMemory()}, + {DNNL_ARG_DST_ITER, hiddenOutMemInit.getMemory()}}; + + // Workspace memory, if needed + dnnl::memory workspace; + std::vector network; + std::vector> fwdArgs; + + // reorder input weights + network.push_back( + dnnl::reorder(weightsInputMemRawInit.getMemory(), weightsInputMemInit) + ); + fwdArgs.push_back( + {{DNNL_ARG_FROM, weightsInputMemRawInit.getMemory()}, + {DNNL_ARG_TO, weightsInputMemInit}} + ); + // reorder iter weights + network.push_back( + dnnl::reorder(weightsHiddenMemRawInit.getMemory(), weightsHiddenMemInit) + ); + fwdArgs.push_back( + {{DNNL_ARG_FROM, weightsHiddenMemRawInit.getMemory()}, + {DNNL_ARG_TO, weightsHiddenMemInit}} + ); + + // Initialize descriptors + if(mode == RnnMode::RELU || mode == RnnMode::TANH) { + auto vanillaPd = dnnl::vanilla_rnn_forward::primitive_desc( + dnnlEngine, + kind, + activation, + 
direction, + inputMemInit.getDescriptor(), + hiddenInMemInit.getDescriptor(), + weightsInputMemDesc, // weights "layer" + weightsHiddenMemDesc, // weights "iter" + biasMemInit.getDescriptor(), + outputMemInit.getDescriptor(), + hiddenOutMemInit.getDescriptor() + ); + network.push_back(dnnl::vanilla_rnn_forward(vanillaPd)); + workspace = dnnl::memory(vanillaPd.workspace_desc(), dnnlEngine); + + } else if(mode == RnnMode::LSTM) { + // LSTM-only + // input cell state + // TODO(jacobkahn): function that takes the array and + // returns the desciptor and memory -- takes an argument for + // which determines whether or not it's ok to return empty + // descriptors if the array is empty + detail::DnnlMemoryWrapper cellInMemInit; + if(!cellState.isEmpty()) { + cellInMemInit = detail::DnnlMemoryWrapper( + cellState.asContiguousTensor(), + {cDims}, + ldnc + ); + } + // output cell state + detail::DnnlMemoryWrapper cellOutMemInit(cy, cDims, ldnc); + + auto lstmPd = dnnl::lstm_forward::primitive_desc( + dnnlEngine, + kind, + direction, + inputMemInit.getDescriptor(), + hiddenInMemInit.getDescriptor(), + cellInMemInit.getDescriptor(), + weightsInputMemDesc, // weights "layer" + weightsHiddenMemDesc, // weights "iter" + biasMemInit.getDescriptor(), + outputMemInit.getDescriptor(), + hiddenOutMemInit.getDescriptor(), + cellOutMemInit.getDescriptor() + ); + network.push_back(dnnl::lstm_forward(lstmPd)); + workspace = dnnl::memory(lstmPd.workspace_desc(), dnnlEngine); + rnnFwdArgs.insert({DNNL_ARG_SRC_ITER_C, cellInMemInit.getMemory()}); + rnnFwdArgs.insert({DNNL_ARG_DST_ITER_C, cellOutMemInit.getMemory()}); + + } else if(mode == RnnMode::GRU) { + // Use a linear-before-reset GRU so we can have parity with cuDNN + auto gruPd = dnnl::lbr_gru_forward::primitive_desc( + dnnlEngine, + kind, + direction, + inputMemInit.getDescriptor(), + hiddenInMemInit.getDescriptor(), + weightsInputMemDesc, + weightsHiddenMemDesc, + biasMemInit.getDescriptor(), + outputMemInit.getDescriptor(), + 
hiddenOutMemInit.getDescriptor() + ); + network.push_back(dnnl::lbr_gru_forward(gruPd)); + workspace = dnnl::memory(gruPd.workspace_desc(), dnnlEngine); + } + rnnFwdArgs.insert({DNNL_ARG_WORKSPACE, workspace}); + fwdArgs.push_back(rnnFwdArgs); + + detail::executeNetwork(network, fwdArgs); + + result.y = y; + result.hy = hy; + result.cy = cy; + result.workspace = workspace; + return result; } - // output cell state - detail::DnnlMemoryWrapper cellOutMemInit(cy, cDims, ldnc); - - auto lstmPd = dnnl::lstm_forward::primitive_desc( - dnnlEngine, - kind, - direction, - inputMemInit.getDescriptor(), - hiddenInMemInit.getDescriptor(), - cellInMemInit.getDescriptor(), - weightsInputMemDesc, // weights "layer" - weightsHiddenMemDesc, // weights "iter" - biasMemInit.getDescriptor(), - outputMemInit.getDescriptor(), - hiddenOutMemInit.getDescriptor(), - cellOutMemInit.getDescriptor()); - network.push_back(dnnl::lstm_forward(lstmPd)); - workspace = dnnl::memory(lstmPd.workspace_desc(), dnnlEngine); - rnnFwdArgs.insert({DNNL_ARG_SRC_ITER_C, cellInMemInit.getMemory()}); - rnnFwdArgs.insert({DNNL_ARG_DST_ITER_C, cellOutMemInit.getMemory()}); - - } else if (mode == RnnMode::GRU) { - // Use a linear-before-reset GRU so we can have parity with cuDNN - auto gruPd = dnnl::lbr_gru_forward::primitive_desc( - dnnlEngine, - kind, - direction, - inputMemInit.getDescriptor(), - hiddenInMemInit.getDescriptor(), - weightsInputMemDesc, - weightsHiddenMemDesc, - biasMemInit.getDescriptor(), - outputMemInit.getDescriptor(), - hiddenOutMemInit.getDescriptor()); - network.push_back(dnnl::lbr_gru_forward(gruPd)); - workspace = dnnl::memory(gruPd.workspace_desc(), dnnlEngine); - } - rnnFwdArgs.insert({DNNL_ARG_WORKSPACE, workspace}); - fwdArgs.push_back(rnnFwdArgs); - - detail::executeNetwork(network, fwdArgs); - - result.y = y; - result.hy = hy; - result.cy = cy; - result.workspace = workspace; - return result; -} } // namespace @@ -434,125 +478,136 @@ std::tuple OneDnnAutogradExtension::rnn( const 
RnnMode mode, const bool bidirectional, const float dropout, - std::shared_ptr autogradPayload) { - if (dropout > 0.0) { - throw std::invalid_argument("onednn RNN: dropout > 0.0 unsupported"); - } - if (bidirectional) { - throw std::invalid_argument("onednn RNN: bidirectional not yet supported"); - } - - const bool train = (autogradPayload != nullptr); - - // Constants - auto direction = bidirectional - ? dnnl::rnn_direction::bidirectional_concat - : dnnl::rnn_direction::unidirectional_left2right; - int directionMult = bidirectional ? 2 : 1; - auto kind = train ? dnnl::prop_kind::forward_training - : dnnl::prop_kind::forward_inference; - int numGates = 1; - auto activation = dnnl::algorithm::undef; - switch (mode) { - case RnnMode::LSTM: - numGates = 4; - break; - case RnnMode::GRU: - numGates = 3; - break; - case RnnMode::RELU: - activation = dnnl::algorithm::eltwise_relu; - break; - case RnnMode::TANH: - activation = dnnl::algorithm::eltwise_tanh; - break; - default: - break; - } - - int inSize = input.dim(0); - - // In Flashlight, all RNN weights are stored as one contiguous tensor, so we - // have to parse out the input weights, input biases, hidden weights, and - // hidden biases from one tensor. Order doesn't matter since the arrangement - // is a black box - auto parsedWeights = parseWeights( - weights, mode, numLayers, directionMult, inSize, numGates, hiddenSize); - - RnnResult result; - // The oneDNN RNN primitive has an API limitation where input size and - // hidden size can only differ if the primitive has exactly one layer. - // Therefore, for computations for more than one layer, first do the - // operation for one layer, which gives an output vector of size [hidden - // size, batch size, sequence length * number of directions], then use - // that output as the input for layers [2, L]. Since the input size dim 0 - // is now the hidden size, the primitive can fuse computation for - // arbitrarily-many layers. 
- if (input.dim(0) == hiddenSize || numLayers == 1) { - // Input and hidden size are the same, or we only have one layer, which - // means we can call the impl as is and parse weights "normally" - result = rnnImpl( - input, - hiddenState, - cellState, - parsedWeights.weightsInput, - parsedWeights.weightsHidden, - parsedWeights.bias, - hiddenSize, - numLayers, - mode, - activation, - numGates, - direction, - directionMult, - kind, - dropout); - } else { - // We require more than one layer with different input and hidden states - - // see the above. Seek to the first layer's hidden/cell state, weights, and - // bias - RnnResult resultL1 = rnnImpl( - input, - hiddenState(fl::span, fl::span, 0), - cellState(fl::span, fl::span, 0), - parsedWeights.weightsInput1L, - parsedWeights.weightsHidden1L, - parsedWeights.bias1L, - hiddenSize, - 1, + std::shared_ptr autogradPayload +) { + if(dropout > 0.0) { + throw std::invalid_argument("onednn RNN: dropout > 0.0 unsupported"); + } + if(bidirectional) { + throw std::invalid_argument("onednn RNN: bidirectional not yet supported"); + } + + const bool train = (autogradPayload != nullptr); + + // Constants + auto direction = bidirectional + ? dnnl::rnn_direction::bidirectional_concat + : dnnl::rnn_direction::unidirectional_left2right; + int directionMult = bidirectional ? 2 : 1; + auto kind = train ? 
dnnl::prop_kind::forward_training + : dnnl::prop_kind::forward_inference; + int numGates = 1; + auto activation = dnnl::algorithm::undef; + switch(mode) { + case RnnMode::LSTM: + numGates = 4; + break; + case RnnMode::GRU: + numGates = 3; + break; + case RnnMode::RELU: + activation = dnnl::algorithm::eltwise_relu; + break; + case RnnMode::TANH: + activation = dnnl::algorithm::eltwise_tanh; + break; + default: + break; + } + + int inSize = input.dim(0); + + // In Flashlight, all RNN weights are stored as one contiguous tensor, so we + // have to parse out the input weights, input biases, hidden weights, and + // hidden biases from one tensor. Order doesn't matter since the arrangement + // is a black box + auto parsedWeights = parseWeights( + weights, mode, - activation, - numGates, - direction, + numLayers, directionMult, - kind, - dropout); - - /* Layers [2..N] */ - // Seek past the first layer's hidden/cell state, weights, and bias - RnnResult resultL2N = rnnImpl( - resultL1.y, // fixme - hiddenState(fl::span, fl::span, fl::range(1, fl::end)), - cellState(fl::span, fl::span, fl::range(1, fl::end)), - parsedWeights.weightsInput, - parsedWeights.weightsHidden, - parsedWeights.bias, - hiddenSize, - numLayers - 1, // layers [2..N] - mode, - activation, + inSize, numGates, - direction, - directionMult, - kind, - dropout); - - result.y = resultL2N.y; - result.hy = fl::concatenate(2, resultL1.hy, resultL2N.hy); - result.cy = fl::concatenate(2, resultL1.cy, resultL2N.cy); - } + hiddenSize + ); + + RnnResult result; + // The oneDNN RNN primitive has an API limitation where input size and + // hidden size can only differ if the primitive has exactly one layer. + // Therefore, for computations for more than one layer, first do the + // operation for one layer, which gives an output vector of size [hidden + // size, batch size, sequence length * number of directions], then use + // that output as the input for layers [2, L]. 
Since the input size dim 0 + // is now the hidden size, the primitive can fuse computation for + // arbitrarily-many layers. + if(input.dim(0) == hiddenSize || numLayers == 1) { + // Input and hidden size are the same, or we only have one layer, which + // means we can call the impl as is and parse weights "normally" + result = rnnImpl( + input, + hiddenState, + cellState, + parsedWeights.weightsInput, + parsedWeights.weightsHidden, + parsedWeights.bias, + hiddenSize, + numLayers, + mode, + activation, + numGates, + direction, + directionMult, + kind, + dropout + ); + } else { + // We require more than one layer with different input and hidden states - + // see the above. Seek to the first layer's hidden/cell state, weights, and + // bias + RnnResult resultL1 = rnnImpl( + input, + hiddenState(fl::span, fl::span, 0), + cellState(fl::span, fl::span, 0), + parsedWeights.weightsInput1L, + parsedWeights.weightsHidden1L, + parsedWeights.bias1L, + hiddenSize, + 1, + mode, + activation, + numGates, + direction, + directionMult, + kind, + dropout + ); + + /* Layers [2..N] */ + // Seek past the first layer's hidden/cell state, weights, and bias + RnnResult resultL2N = rnnImpl( + resultL1.y, // fixme + hiddenState(fl::span, fl::span, fl::range(1, fl::end)), + cellState(fl::span, fl::span, fl::range(1, fl::end)), + parsedWeights.weightsInput, + parsedWeights.weightsHidden, + parsedWeights.bias, + hiddenSize, + numLayers - 1, // layers [2..N] + mode, + activation, + numGates, + direction, + directionMult, + kind, + dropout + ); + + result.y = resultL2N.y; + result.hy = fl::concatenate(2, resultL1.hy, resultL2N.hy); + result.cy = fl::concatenate(2, resultL1.cy, resultL2N.cy); + } - return {result.y, result.hy, result.cy}; + return {result.y, result.hy, result.cy}; } std::tuple OneDnnAutogradExtension::rnnBackward( @@ -567,9 +622,11 @@ std::tuple OneDnnAutogradExtension::rnnBackward( const RnnMode mode, const bool bidirectional, const float dropProb, - const std::shared_ptr 
payload) { - throw std::runtime_error( - "onednn RNN: Gradient computation not yet supported"); + const std::shared_ptr payload +) { + throw std::runtime_error( + "onednn RNN: Gradient computation not yet supported" + ); } } // namespace fl diff --git a/flashlight/fl/common/Defines.cpp b/flashlight/fl/common/Defines.cpp index 39ea847..9e8ae0c 100644 --- a/flashlight/fl/common/Defines.cpp +++ b/flashlight/fl/common/Defines.cpp @@ -14,34 +14,35 @@ namespace fl { OptimLevel OptimMode::getOptimLevel() { - return optimLevel_; + return optimLevel_; } void OptimMode::setOptimLevel(OptimLevel level) { - optimLevel_ = level; + optimLevel_ = level; } OptimMode& OptimMode::get() { - static OptimMode optimMode; - return optimMode; + static OptimMode optimMode; + return optimMode; } OptimLevel OptimMode::toOptimLevel(const std::string& in) { - auto l = kStringToOptimLevel.find(in); - if (l == kStringToOptimLevel.end()) { - throw std::invalid_argument( - "OptimMode::toOptimLevel - no matching " - "optim level for given string."); - } - return l->second; + auto l = kStringToOptimLevel.find(in); + if(l == kStringToOptimLevel.end()) { + throw std::invalid_argument( + "OptimMode::toOptimLevel - no matching " + "optim level for given string." + ); + } + return l->second; } const std::unordered_map - OptimMode::kStringToOptimLevel = { - {"DEFAULT", OptimLevel::DEFAULT}, - {"O1", OptimLevel::O1}, - {"O2", OptimLevel::O2}, - {"O3", OptimLevel::O3}, +OptimMode::kStringToOptimLevel = { + {"DEFAULT", OptimLevel::DEFAULT}, + {"O1", OptimLevel::O1}, + {"O2", OptimLevel::O2}, + {"O3", OptimLevel::O3}, }; } // namespace fl diff --git a/flashlight/fl/common/Defines.h b/flashlight/fl/common/Defines.h index 44a9342..fe1fbbf 100644 --- a/flashlight/fl/common/Defines.h +++ b/flashlight/fl/common/Defines.h @@ -42,59 +42,56 @@ namespace fl { * Reduction mode to used for CrossEntropy, AdaptiveSoftMax etc ... 
*/ enum class ReduceMode { - NONE = 0, - MEAN = 1, - SUM = 2, + NONE = 0, + MEAN = 1, + SUM = 2, }; /** * Pooling method to be used */ enum class PoolingMode { - - /// Use maximum value inside the pooling window - MAX = 0, - - /// Use average value (including padding) inside the pooling window - AVG_INCLUDE_PADDING = 1, - - /// Use average value (excluding padding) inside the pooling window// Use - /// average value (excluding padding) inside the pooling window - AVG_EXCLUDE_PADDING = 2, + /// Use maximum value inside the pooling window + MAX = 0, + /// Use average value (including padding) inside the pooling window + AVG_INCLUDE_PADDING = 1, + /// Use average value (excluding padding) inside the pooling window// Use + /// average value (excluding padding) inside the pooling window + AVG_EXCLUDE_PADDING = 2, }; /** * RNN network type */ enum class RnnMode { - RELU = 0, - TANH = 1, - LSTM = 2, - GRU = 3, + RELU = 0, + TANH = 1, + LSTM = 2, + GRU = 3, }; enum class PaddingMode { - /// Use smallest possible padding such that out_size = ceil(in_size/stride) - SAME = -1, + /// Use smallest possible padding such that out_size = ceil(in_size/stride) + SAME = -1, }; enum class DistributedBackend { - /// https://github.com/facebookincubator/gloo - GLOO = 0, - /// https://developer.nvidia.com/nccl - NCCL = 1, - STUB = 2, + /// https://github.com/facebookincubator/gloo + GLOO = 0, + /// https://developer.nvidia.com/nccl + NCCL = 1, + STUB = 2, }; enum class DistributedInit { - MPI = 0, - FILE_SYSTEM = 1, + MPI = 0, + FILE_SYSTEM = 1, }; namespace DistributedConstants { -constexpr const char* kMaxDevicePerNode = "MAX_DEVICE_PER_NODE"; -constexpr const char* kFilePath = "FILE_PATH"; -constexpr const std::size_t kCoalesceCacheSize = ((size_t)(20) << 20); // 20 MB + constexpr const char* kMaxDevicePerNode = "MAX_DEVICE_PER_NODE"; + constexpr const char* kFilePath = "FILE_PATH"; + constexpr const std::size_t kCoalesceCacheSize = ((size_t) (20) << 20); // 20 MB } // namespace 
DistributedConstants constexpr std::size_t kDynamicBenchmarkDefaultCount = 10; @@ -114,17 +111,17 @@ constexpr double kAmpMinimumScaleFactorValue = 1e-4; * - https://bit.ly/310k8Z6 */ enum class OptimLevel { - /// All operations occur in default (f32 or f64) precision. - DEFAULT = 0, - /// Operations that perform reduction accumulation, including layer/batch - /// normalization are performed in f32 - all other operations are in fp16. - /// To be used in a standard mixed-precision training setup. - O1 = 1, - /// Only batch and layer normalization occur in f32 - all other operations - /// occur in f16. - O2 = 2, - /// All operations that support it use fp16. - O3 = 3 + /// All operations occur in default (f32 or f64) precision. + DEFAULT = 0, + /// Operations that perform reduction accumulation, including layer/batch + /// normalization are performed in f32 - all other operations are in fp16. + /// To be used in a standard mixed-precision training setup. + O1 = 1, + /// Only batch and layer normalization occur in f32 - all other operations + /// occur in f16. + O2 = 2, + /// All operations that support it use fp16. + O3 = 3 }; /** @@ -132,35 +129,35 @@ enum class OptimLevel { * flashlight. */ class FL_API OptimMode { - public: - /** - * @return the OptimMode singleton - */ - static OptimMode& get(); - - /** - * Gets the current optimization level. Not thread safe. - * - * @return the current optimization level. - */ - OptimLevel getOptimLevel(); - - /** - * Gets the current optimization level. Not thread safe. - * - * @param[in] level the optimization level to set - */ - void setOptimLevel(OptimLevel level); - - /** - * - */ - static OptimLevel toOptimLevel(const std::string& in); - - static const std::unordered_map kStringToOptimLevel; - - private: - OptimLevel optimLevel_{OptimLevel::DEFAULT}; +public: + /** + * @return the OptimMode singleton + */ + static OptimMode& get(); + + /** + * Gets the current optimization level. Not thread safe. 
+ * + * @return the current optimization level. + */ + OptimLevel getOptimLevel(); + + /** + * Gets the current optimization level. Not thread safe. + * + * @param[in] level the optimization level to set + */ + void setOptimLevel(OptimLevel level); + + /** + * + */ + static OptimLevel toOptimLevel(const std::string& in); + + static const std::unordered_map kStringToOptimLevel; + +private: + OptimLevel optimLevel_{OptimLevel::DEFAULT}; }; /** @} */ diff --git a/flashlight/fl/common/DevicePtr.cpp b/flashlight/fl/common/DevicePtr.cpp index 144af9e..2aef961 100644 --- a/flashlight/fl/common/DevicePtr.cpp +++ b/flashlight/fl/common/DevicePtr.cpp @@ -13,42 +13,42 @@ namespace fl { -DevicePtr::DevicePtr(const Tensor& in) - : tensor_(std::make_unique(in.shallowCopy())) { - if (tensor_->isEmpty()) { - ptr_ = nullptr; - } else { - if (!tensor_->isContiguous()) { - throw std::invalid_argument( - "can't get device pointer of non-contiguous Tensor"); +DevicePtr::DevicePtr(const Tensor& in) : tensor_(std::make_unique(in.shallowCopy())) { + if(tensor_->isEmpty()) { + ptr_ = nullptr; + } else { + if(!tensor_->isContiguous()) { + throw std::invalid_argument( + "can't get device pointer of non-contiguous Tensor" + ); + } + ptr_ = tensor_->device(); } - ptr_ = tensor_->device(); - } } DevicePtr::~DevicePtr() { - if (ptr_ != nullptr) { - tensor_->unlock(); - } + if(ptr_ != nullptr) { + tensor_->unlock(); + } } -DevicePtr::DevicePtr(DevicePtr&& d) noexcept - : tensor_(std::move(d.tensor_)), ptr_(d.ptr_) { - d.ptr_ = nullptr; +DevicePtr::DevicePtr(DevicePtr&& d) noexcept : tensor_(std::move(d.tensor_)), + ptr_(d.ptr_) { + d.ptr_ = nullptr; } DevicePtr& DevicePtr::operator=(DevicePtr&& other) noexcept { - if (ptr_ != nullptr) { - tensor_->unlock(); - } - tensor_ = std::move(other.tensor_); - ptr_ = other.ptr_; - other.ptr_ = nullptr; - return *this; + if(ptr_ != nullptr) { + tensor_->unlock(); + } + tensor_ = std::move(other.tensor_); + ptr_ = other.ptr_; + other.ptr_ = nullptr; + return 
*this; } void* DevicePtr::get() const { - return ptr_; + return ptr_; } } // namespace fl diff --git a/flashlight/fl/common/DevicePtr.h b/flashlight/fl/common/DevicePtr.h index 04f43a6..4666413 100644 --- a/flashlight/fl/common/DevicePtr.h +++ b/flashlight/fl/common/DevicePtr.h @@ -36,52 +36,52 @@ class Tensor; * */ class FL_API DevicePtr { - public: - /** - * Creates a null DevicePtr. - */ - DevicePtr() : ptr_(nullptr) {} +public: + /** + * Creates a null DevicePtr. + */ + DevicePtr() : ptr_(nullptr) {} - /** - * @param in input array to get device pointer - */ - explicit DevicePtr(const Tensor& in); + /** + * @param in input array to get device pointer + */ + explicit DevicePtr(const Tensor& in); - /** - *`.unlock()` is called on the underlying array in destructor - */ - ~DevicePtr(); + /** + *`.unlock()` is called on the underlying array in destructor + */ + ~DevicePtr(); - DevicePtr(const DevicePtr& other) = delete; + DevicePtr(const DevicePtr& other) = delete; - DevicePtr& operator=(const DevicePtr& other) = delete; + DevicePtr& operator=(const DevicePtr& other) = delete; - DevicePtr(DevicePtr&& d) noexcept; + DevicePtr(DevicePtr&& d) noexcept; - DevicePtr& operator=(DevicePtr&& other) noexcept; + DevicePtr& operator=(DevicePtr&& other) noexcept; - bool operator==(const DevicePtr& other) const { - return get() == other.get(); - } + bool operator==(const DevicePtr& other) const { + return get() == other.get(); + } - void* get() const; + void* get() const; - template - T* getAs() const { - return reinterpret_cast(ptr_); - } + template + T* getAs() const { + return reinterpret_cast(ptr_); + } - protected: - std::unique_ptr tensor_; +protected: + std::unique_ptr tensor_; - private: - void* ptr_; +private: + void* ptr_; }; struct DevicePtrHasher { - std::size_t operator()(const DevicePtr& k) const { - return std::hash()(k.get()); - } + std::size_t operator()(const DevicePtr& k) const { + return std::hash()(k.get()); + } }; } // namespace fl diff --git 
a/flashlight/fl/common/DynamicBenchmark.cpp b/flashlight/fl/common/DynamicBenchmark.cpp index 8157de0..96797a6 100644 --- a/flashlight/fl/common/DynamicBenchmark.cpp +++ b/flashlight/fl/common/DynamicBenchmark.cpp @@ -15,36 +15,37 @@ bool DynamicBenchmark::benchmarkMode_ = false; void DynamicBenchmark::audit( const std::function& function, - bool incrementCount) { - // Only run the benchmarking components if some options are yet to be - // fully-timed and benchmark mode is on - otherwise, only run the passed - // lambda - if (options_->timingsComplete() || !benchmarkMode_) { - function(); - } else { - start(); - function(); - stop(incrementCount); - } + bool incrementCount +) { + // Only run the benchmarking components if some options are yet to be + // fully-timed and benchmark mode is on - otherwise, only run the passed + // lambda + if(options_->timingsComplete() || !benchmarkMode_) { + function(); + } else { + start(); + function(); + stop(incrementCount); + } } void DynamicBenchmark::start() { - fl::sync(); - currentTimer_ = fl::Timer::start(); + fl::sync(); + currentTimer_ = fl::Timer::start(); } void DynamicBenchmark::stop(bool incrementCount) { - fl::sync(); - auto elapsedTime = fl::Timer::stop(currentTimer_); - options_->accumulateTimeToCurrentOption(elapsedTime, incrementCount); + fl::sync(); + auto elapsedTime = fl::Timer::stop(currentTimer_); + options_->accumulateTimeToCurrentOption(elapsedTime, incrementCount); } void DynamicBenchmark::setBenchmarkMode(bool mode) { - benchmarkMode_ = mode; + benchmarkMode_ = mode; } bool DynamicBenchmark::getBenchmarkMode() { - return benchmarkMode_; + return benchmarkMode_; } } // namespace fl diff --git a/flashlight/fl/common/DynamicBenchmark.h b/flashlight/fl/common/DynamicBenchmark.h index d2bcb22..cc17ce8 100644 --- a/flashlight/fl/common/DynamicBenchmark.h +++ b/flashlight/fl/common/DynamicBenchmark.h @@ -28,23 +28,25 @@ namespace fl { * This type shouldn't be directly constructed. 
*/ struct FL_API DynamicBenchmarkOptionsBase { - virtual ~DynamicBenchmarkOptionsBase() = default; + virtual ~DynamicBenchmarkOptionsBase() = default; - virtual void accumulateTimeToCurrentOption(double, bool = true) { - throw std::logic_error( - "DynamicBenchmarkOptionsBase::accumulateTimeToCurrentOption " - "- unimplemented"); - } + virtual void accumulateTimeToCurrentOption(double, bool = true) { + throw std::logic_error( + "DynamicBenchmarkOptionsBase::accumulateTimeToCurrentOption " + "- unimplemented" + ); + } - virtual bool timingsComplete() { - throw std::logic_error( - "DynamicBenchmarkOptionsBase::timingsComplete " - "- unimplemented"); - } + virtual bool timingsComplete() { + throw std::logic_error( + "DynamicBenchmarkOptionsBase::timingsComplete " + "- unimplemented" + ); + } - protected: - // Not intended for construction - DynamicBenchmarkOptionsBase() = default; +protected: + // Not intended for construction + DynamicBenchmarkOptionsBase() = default; }; /** @@ -57,134 +59,137 @@ struct FL_API DynamicBenchmarkOptionsBase { * accumulated, and provides the option with the lowest timing/best * performance when timings are complete. */ -template +template struct DynamicBenchmarkOptions : DynamicBenchmarkOptionsBase { - /** - * Constructs an instance given a vector of options of specified type. The - * options are assumed to be distinct since benchmarks options are - * determined by index. - * - * @param[in] options vector of options to use - * @param[in] benchCount the number of times to benchmark each option before - * fixing on the optimal option - */ - DynamicBenchmarkOptions(std::vector options, size_t benchCount) - : options_(options), benchCount_(benchCount) { - if (options_.empty()) { - throw std::invalid_argument( - "DynamicBenchmarkOptions: " - "Options must be passed vector with at least one element"); + /** + * Constructs an instance given a vector of options of specified type. 
The + * options are assumed to be distinct since benchmarks options are + * determined by index. + * + * @param[in] options vector of options to use + * @param[in] benchCount the number of times to benchmark each option before + * fixing on the optimal option + */ + DynamicBenchmarkOptions(std::vector options, size_t benchCount) + : options_(options), benchCount_(benchCount) { + if(options_.empty()) { + throw std::invalid_argument( + "DynamicBenchmarkOptions: " + "Options must be passed vector with at least one element" + ); + } + reset(); } - reset(); - } - /** - * Constructs an instance given a set of options. - * - * @param[in] options a set of options to use - * @param[in] benchCount the number of times to benchmark each option before - * fixing on the optimal option - */ - DynamicBenchmarkOptions(std::unordered_set options, size_t benchCount) - : DynamicBenchmarkOptions( - std::vector(options.begin(), options.end()), - benchCount) {} + /** + * Constructs an instance given a set of options. + * + * @param[in] options a set of options to use + * @param[in] benchCount the number of times to benchmark each option before + * fixing on the optimal option + */ + DynamicBenchmarkOptions(std::unordered_set options, size_t benchCount) + : DynamicBenchmarkOptions( + std::vector(options.begin(), options.end()), + benchCount + ) {} - /** - * Gets the current option; updates the current state. - * - * If each option hasn't been used/timed as many times as the max count, pick - * the first option that hasn't been timed to the maximum count. If all - * timings are complete, choose the optimal timing. - * - * @return the current option. - */ - T updateState() { - if (!timingsComplete_) { - for (size_t i = 0; i < options_.size(); ++i) { - if (counts_[i] < benchCount_) { - currentOptionIdx_ = i; - return options_[i]; - } - } - timingsComplete_ = true; + /** + * Gets the current option; updates the current state. 
+ * + * If each option hasn't been used/timed as many times as the max count, pick + * the first option that hasn't been timed to the maximum count. If all + * timings are complete, choose the optimal timing. + * + * @return the current option. + */ + T updateState() { + if(!timingsComplete_) { + for(size_t i = 0; i < options_.size(); ++i) { + if(counts_[i] < benchCount_) { + currentOptionIdx_ = i; + return options_[i]; + } + } + timingsComplete_ = true; - // All options have been benchmarked with the max count - pick the one - // with the lowest time - size_t minTimeOptionIdx{0}; - for (size_t i = 0; i < options_.size(); ++i) { - if (times_[i] < times_[minTimeOptionIdx]) { - minTimeOptionIdx = i; + // All options have been benchmarked with the max count - pick the one + // with the lowest time + size_t minTimeOptionIdx{0}; + for(size_t i = 0; i < options_.size(); ++i) { + if(times_[i] < times_[minTimeOptionIdx]) { + minTimeOptionIdx = i; + } + } + currentOptionIdx_ = minTimeOptionIdx; } - } - currentOptionIdx_ = minTimeOptionIdx; + return options_[currentOptionIdx_]; } - return options_[currentOptionIdx_]; - } - - /** - * Gets the options' current value. This is deterministically computed and - * only changes as per calls to `accumulateTimeToCurrentOption` that may - * increment the count - * - * @return T the current option. - */ - T currentOption() { - return updateState(); - } - /** - * @return whether or not this options' timings are complete. - */ - bool timingsComplete() override { - updateState(); - return timingsComplete_; - } + /** + * Gets the options' current value. This is deterministically computed and + * only changes as per calls to `accumulateTimeToCurrentOption` that may + * increment the count + * + * @return T the current option. + */ + T currentOption() { + return updateState(); + } - /** - * Adds time to the current option tally. 
- * - * @param[in] time duration to add - * @param[in] incrementCount whether or not to increment the benchmark talley - * for the option. This facilitates timing options by using results from - * discontinuous functions - */ - void accumulateTimeToCurrentOption(double time, bool incrementCount = true) - override { - if (timingsComplete()) { - throw std::invalid_argument( - "Options::accumulateTimeToCurrentOption: " - "Tried to accumulate time when benchmarking is complete"); + /** + * @return whether or not this options' timings are complete. + */ + bool timingsComplete() override { + updateState(); + return timingsComplete_; } - updateState(); - times_[currentOptionIdx_] += time; - if (incrementCount) { - counts_[currentOptionIdx_]++; + + /** + * Adds time to the current option tally. + * + * @param[in] time duration to add + * @param[in] incrementCount whether or not to increment the benchmark talley + * for the option. This facilitates timing options by using results from + * discontinuous functions + */ + void accumulateTimeToCurrentOption(double time, bool incrementCount = true) + override { + if(timingsComplete()) { + throw std::invalid_argument( + "Options::accumulateTimeToCurrentOption: " + "Tried to accumulate time when benchmarking is complete" + ); + } + updateState(); + times_[currentOptionIdx_] += time; + if(incrementCount) { + counts_[currentOptionIdx_]++; + } } - } - /** - * Resets options state to the default. Clears timings and counts. - */ - void reset() { - for (size_t i = 0; i < options_.size(); ++i) { - counts_[i] = 0; - times_[i] = 0.; + /** + * Resets options state to the default. Clears timings and counts. 
+ */ + void reset() { + for(size_t i = 0; i < options_.size(); ++i) { + counts_[i] = 0; + times_[i] = 0.; + } + timingsComplete_ = false; + currentOptionIdx_ = 0; } - timingsComplete_ = false; - currentOptionIdx_ = 0; - } - private: - const std::vector options_; - const size_t benchCount_{0}; +private: + const std::vector options_; + const size_t benchCount_{0}; - bool timingsComplete_{false}; - int currentOptionIdx_{0}; // first option is the default - // Number of times the option at each index has been timed - std::unordered_map counts_; - // Accumulated times for each option - std::unordered_map times_; + bool timingsComplete_{false}; + int currentOptionIdx_{0}; // first option is the default + // Number of times the option at each index has been timed + std::unordered_map counts_; + // Accumulated times for each option + std::unordered_map times_; }; /** @@ -194,73 +199,73 @@ struct DynamicBenchmarkOptions : DynamicBenchmarkOptionsBase { * configurations based on detected hardware. */ class FL_API DynamicBenchmark { - public: - explicit DynamicBenchmark( - std::shared_ptr options) - : options_(options) {} +public: + explicit DynamicBenchmark( + std::shared_ptr options + ) : options_(options) {} - virtual ~DynamicBenchmark() = default; + virtual ~DynamicBenchmark() = default; - /** - * Audits a dynamic benchmark function. Acccumulates times based on this - * DynamicBenchmark's options' currently-active option. - * - * If the timings are complete for the benchmark options, simply executes the - passed function. Calls `fl::sync()` before and after function execution to - get an accurate count. - * - * @param[in] function the function to benchmark - */ - void audit(const std::function& function, bool incrementCount = true); + /** + * Audits a dynamic benchmark function. Acccumulates times based on this + * DynamicBenchmark's options' currently-active option. + * + * If the timings are complete for the benchmark options, simply executes the + passed function. 
Calls `fl::sync()` before and after function execution to + get an accurate count. + * + * @param[in] function the function to benchmark + */ + void audit(const std::function& function, bool incrementCount = true); - /** - * Gets the benchmarks' underlying `DynamicBenchmarkOptionsBase` instance. - * - * @return a pointer to the underlying options. - */ - template - std::shared_ptr getOptions() const { - return std::static_pointer_cast(options_); - } + /** + * Gets the benchmarks' underlying `DynamicBenchmarkOptionsBase` instance. + * + * @return a pointer to the underlying options. + */ + template + std::shared_ptr getOptions() const { + return std::static_pointer_cast(options_); + } - /** - * Sets global benchmark mode. If benchmark mode is on, all - * `DynamicBenchmark`s will run normally. If benchmark mode is off, calling - * `DynamicBenchmark::audit` with a given closure will simply execute the - * closure. - * - * @param[in] mode the new value of benchmark mode - */ - static void setBenchmarkMode(bool mode); + /** + * Sets global benchmark mode. If benchmark mode is on, all + * `DynamicBenchmark`s will run normally. If benchmark mode is off, calling + * `DynamicBenchmark::audit` with a given closure will simply execute the + * closure. 
+ * + * @param[in] mode the new value of benchmark mode + */ + static void setBenchmarkMode(bool mode); - /** - * @return whether benchmark mode is globally enabled - */ - static bool getBenchmarkMode(); + /** + * @return whether benchmark mode is globally enabled + */ + static bool getBenchmarkMode(); - private: - // Starts the benchmark timer - void start(); - // Stops the benchmark timer, accumulates times to the current option - void stop(bool incrementCount); +private: + // Starts the benchmark timer + void start(); + // Stops the benchmark timer, accumulates times to the current option + void stop(bool incrementCount); - std::shared_ptr options_; - // Timer for current benchmark iteration - fl::Timer currentTimer_; + std::shared_ptr options_; + // Timer for current benchmark iteration + fl::Timer currentTimer_; - // Global fl benchmark mode - if off, no benchmarks will run, and audited - // functions will be run directly without timings - static bool benchmarkMode_; + // Global fl benchmark mode - if off, no benchmarks will run, and audited + // functions will be run directly without timings + static bool benchmarkMode_; }; // Specific benchmark implementations namespace detail { -struct ConvBenchmarks { - std::shared_ptr bwdFilterBenchmark; - std::shared_ptr bwdDataBenchmark; - std::shared_ptr bwdBiasBenchmark; -}; + struct ConvBenchmarks { + std::shared_ptr bwdFilterBenchmark; + std::shared_ptr bwdDataBenchmark; + std::shared_ptr bwdBiasBenchmark; + }; } // namespace detail diff --git a/flashlight/fl/common/Histogram.cpp b/flashlight/fl/common/Histogram.cpp index 62dfea3..abca254 100644 --- a/flashlight/fl/common/Histogram.cpp +++ b/flashlight/fl/common/Histogram.cpp @@ -13,33 +13,33 @@ namespace fl { void shortFormatCount(std::stringstream& ss, size_t count) { - constexpr size_t stringLen = 5; - if (count >= 10e13) { // >= 10 trillion - ss << std::setw(stringLen - 1) << (count / (size_t)10e12) << 't'; - } else if (count >= 10e10) { // >= 10 billion - ss << 
std::setw(stringLen - 1) << (count / (size_t)10e9) << 'b'; - } else if (count >= 10e7) { // >= 10 million - ss << std::setw(stringLen - 1) << (count / (size_t)10e6) << 'm'; - } else if (count >= 10e4) { // >= 10 thousand - ss << std::setw(stringLen - 1) << (count / (size_t)10e3) << 'k'; - } else { - ss << std::setw(stringLen) << count; - } + constexpr size_t stringLen = 5; + if(count >= 10e13) { // >= 10 trillion + ss << std::setw(stringLen - 1) << (count / (size_t) 10e12) << 't'; + } else if(count >= 10e10) { // >= 10 billion + ss << std::setw(stringLen - 1) << (count / (size_t) 10e9) << 'b'; + } else if(count >= 10e7) { // >= 10 million + ss << std::setw(stringLen - 1) << (count / (size_t) 10e6) << 'm'; + } else if(count >= 10e4) { // >= 10 thousand + ss << std::setw(stringLen - 1) << (count / (size_t) 10e3) << 'k'; + } else { + ss << std::setw(stringLen) << count; + } } void shortFormatMemory(std::stringstream& ss, size_t size) { - constexpr size_t stringLen = 5; - if (size >= (1ULL << 43)) { // >= 8TB - ss << std::setw(stringLen - 1) << (size >> 40) << "T"; - } else if (size >= (1ULL << 33)) { // >= 8G B - ss << std::setw(stringLen - 1) << (size >> 30) << "G"; - } else if (size >= (1ULL << 23)) { // >= 8M B - ss << std::setw(stringLen - 1) << (size >> 20) << "M"; - } else if (size >= (1ULL << 13)) { // >= 8K B - ss << std::setw(stringLen - 1) << (size >> 10) << "K"; - } else { - ss << std::setw(stringLen) << size; - } + constexpr size_t stringLen = 5; + if(size >= (1ULL << 43)) { // >= 8TB + ss << std::setw(stringLen - 1) << (size >> 40) << "T"; + } else if(size >= (1ULL << 33)) { // >= 8G B + ss << std::setw(stringLen - 1) << (size >> 30) << "G"; + } else if(size >= (1ULL << 23)) { // >= 8M B + ss << std::setw(stringLen - 1) << (size >> 20) << "M"; + } else if(size >= (1ULL << 13)) { // >= 8K B + ss << std::setw(stringLen - 1) << (size >> 10) << "K"; + } else { + ss << std::setw(stringLen) << size; + } } } // namespace fl diff --git 
a/flashlight/fl/common/Histogram.h b/flashlight/fl/common/Histogram.h index df0b300..8cc557b 100644 --- a/flashlight/fl/common/Histogram.h +++ b/flashlight/fl/common/Histogram.h @@ -39,56 +39,58 @@ FL_API void shortFormatCount(std::stringstream& ss, size_t count); // count=16777216 (1 << 24) will be written as 16M FL_API void shortFormatMemory(std::stringstream& ss, size_t size); -using histValFmtFunc = std::function; +using histValFmtFunc = std::function; /** * Abstraction of generic histogram bucket. Used in the context of * HistogramStats */ -template +template struct HistogramBucket { - T startInclusive = 0; //! left boundary of the bucket. - T endExclusive = 0; //! right boundary of the bucket. - size_t count = 0; //! Number of elements in this bucket. - - std::string prettyString( - double countPerTick, // ratio of count/bar_length - histValFmtFunc fromatCountIntoStream = shortFormatCount, - histValFmtFunc fromatValuesIntoStream = shortFormatMemory) const; + T startInclusive = 0; // ! left boundary of the bucket. + T endExclusive = 0; // ! right boundary of the bucket. + size_t count = 0; // ! Number of elements in this bucket. + + std::string prettyString( + double countPerTick, // ratio of count/bar_length + histValFmtFunc fromatCountIntoStream = shortFormatCount, + histValFmtFunc fromatValuesIntoStream = shortFormatMemory + ) const; }; /** * Generic data structure for representation of value set stats and histogram. 
*/ -template +template struct HistogramStats { - // double bucketWidth = 0; - T min = 0; - T max = 0; - T sum = 0; - bool sumOverflow = false; - double mean = 0; - size_t numValues = 0; - size_t maxNumValuesPerBucket = 0; - std::vector> buckets; - - std::string prettyString( - size_t maxBarWidth = 50, - histValFmtFunc fromatCountIntoStream = shortFormatCount, - histValFmtFunc fromatValuesIntoStream = shortFormatMemory) const; + // double bucketWidth = 0; + T min = 0; + T max = 0; + T sum = 0; + bool sumOverflow = false; + double mean = 0; + size_t numValues = 0; + size_t maxNumValuesPerBucket = 0; + std::vector> buckets; + + std::string prettyString( + size_t maxBarWidth = 50, + histValFmtFunc fromatCountIntoStream = shortFormatCount, + histValFmtFunc fromatValuesIntoStream = shortFormatMemory + ) const; }; -template +template bool isAdditionSafe(T a, T b) { - if (a > (std::numeric_limits::max() - b)) { - return false; - } - if (std::is_signed::value) { - if (a < 0 && b < 0 && (a < (std::numeric_limits::min() - b))) { - return false; + if(a > (std::numeric_limits::max() - b)) { + return false; } - } - return true; + if(std::is_signed::value) { + if(a < 0 && b < 0 && (a < (std::numeric_limits::min() - b))) { + return false; + } + } + return true; } /** @@ -100,142 +102,149 @@ bool isAdditionSafe(T a, T b) { * @param [clipMinValueInclusive,clipMaxValueExclusive] Consider only values * between the clipping bondenries */ -template +template HistogramStats FixedBucketSizeHistogram( Iterator begin, Iterator end, size_t nBuckets, T clipMinValueInclusive = std::numeric_limits::min(), - T clipMaxValueExclusive = std::numeric_limits::max()) { - if (!nBuckets) { - throw std::invalid_argument( - "FixedBucketSizeHistogram(nBuckets=0) nBuckets " - "must be a positive integer"); - } - - HistogramStats stats; - if (begin == end) { - return stats; - } + T clipMaxValueExclusive = std::numeric_limits::max() +) { + if(!nBuckets) { + throw std::invalid_argument( + 
"FixedBucketSizeHistogram(nBuckets=0) nBuckets " + "must be a positive integer" + ); + } - stats.min = std::numeric_limits::max(); - stats.max = std::numeric_limits::min(); - stats.buckets.resize(nBuckets); + HistogramStats stats; + if(begin == end) { + return stats; + } - // Calculate min/max, sum, ands mean - double simpleMovingAverage = 0.0; - for (auto itr = begin; itr != end; ++itr) { - if ((*itr < clipMinValueInclusive) || (*itr >= clipMaxValueExclusive)) { - continue; + stats.min = std::numeric_limits::max(); + stats.max = std::numeric_limits::min(); + stats.buckets.resize(nBuckets); + + // Calculate min/max, sum, ands mean + double simpleMovingAverage = 0.0; + for(auto itr = begin; itr != end; ++itr) { + if((*itr < clipMinValueInclusive) || (*itr >= clipMaxValueExclusive)) { + continue; + } + if(!stats.sumOverflow) { + if(isAdditionSafe(stats.sum, *itr)) { + stats.sum += *itr; + } else { + stats.sumOverflow = true; + } + } + + stats.min = std::min(stats.min, *itr); + stats.max = std::max(stats.max, *itr); + double denominator = static_cast(stats.numValues + 1); + double ratio = stats.numValues / denominator; + simpleMovingAverage = simpleMovingAverage * ratio + (*itr / denominator); + ++stats.numValues; } - if (!stats.sumOverflow) { - if (isAdditionSafe(stats.sum, *itr)) { - stats.sum += *itr; - } else { - stats.sumOverflow = true; - } + stats.mean = simpleMovingAverage; + + // Calculate bucket size + double range = stats.max - stats.min; + auto bucketWidth = range / nBuckets; + if(range == 0 || bucketWidth == 0) { + stats.buckets[0].count = stats.numValues; + stats.maxNumValuesPerBucket = stats.numValues; + return stats; } - stats.min = std::min(stats.min, *itr); - stats.max = std::max(stats.max, *itr); - double denominator = static_cast(stats.numValues + 1); - double ratio = stats.numValues / denominator; - simpleMovingAverage = simpleMovingAverage * ratio + (*itr / denominator); - ++stats.numValues; - } - stats.mean = simpleMovingAverage; - - // 
Calculate bucket size - double range = stats.max - stats.min; - auto bucketWidth = range / nBuckets; - if (range == 0 || bucketWidth == 0) { - stats.buckets[0].count = stats.numValues; - stats.maxNumValuesPerBucket = stats.numValues; - return stats; - } + // Calculate count per bucket + stats.maxNumValuesPerBucket = 0; + for(auto itr = begin; itr != end; ++itr) { + if(*itr < clipMinValueInclusive || *itr >= clipMaxValueExclusive) { + continue; + } + double index = + std::floor(static_cast(*itr - stats.min) / bucketWidth); + size_t intIndex = std::min(static_cast(index), nBuckets - 1); + + HistogramBucket& bucket = stats.buckets[intIndex]; + ++bucket.count; + + stats.maxNumValuesPerBucket = + std::max(stats.maxNumValuesPerBucket, bucket.count); + } - // Calculate count per bucket - stats.maxNumValuesPerBucket = 0; - for (auto itr = begin; itr != end; ++itr) { - if (*itr < clipMinValueInclusive || *itr >= clipMaxValueExclusive) { - continue; + // Set bucket start/end + int i = 0; + for(auto& bucket : stats.buckets) { + bucket.startInclusive = stats.min + bucketWidth * i; + bucket.endExclusive = stats.min + bucketWidth * (i + 1); + ++i; } - double index = - std::floor(static_cast(*itr - stats.min) / bucketWidth); - size_t intIndex = std::min(static_cast(index), nBuckets - 1); - - HistogramBucket& bucket = stats.buckets[intIndex]; - ++bucket.count; - - stats.maxNumValuesPerBucket = - std::max(stats.maxNumValuesPerBucket, bucket.count); - } - - // Set bucket start/end - int i = 0; - for (auto& bucket : stats.buckets) { - bucket.startInclusive = stats.min + bucketWidth * i; - bucket.endExclusive = stats.min + bucketWidth * (i + 1); - ++i; - } - // Fix possible finite precision algebra mistakes - stats.buckets.rbegin()->endExclusive = stats.max; - - return stats; + // Fix possible finite precision algebra mistakes + stats.buckets.rbegin()->endExclusive = stats.max; + + return stats; } -template +template std::string HistogramBucket::prettyString( double countPerTick, 
histValFmtFunc fromatCountIntoStream, - histValFmtFunc fromatValuesIntoStream) const { - std::stringstream ss; - ss << '['; - fromatValuesIntoStream(ss, startInclusive); - ss << '-'; - fromatValuesIntoStream(ss, endExclusive); - ss << "] "; - fromatCountIntoStream(ss, count); - ss << ": "; - const double numTicks = static_cast(count) / countPerTick; - for (int i = 0; i < std::round(numTicks); ++i) { - ss << "*"; - } - return ss.str(); + histValFmtFunc fromatValuesIntoStream +) const { + std::stringstream ss; + ss << '['; + fromatValuesIntoStream(ss, startInclusive); + ss << '-'; + fromatValuesIntoStream(ss, endExclusive); + ss << "] "; + fromatCountIntoStream(ss, count); + ss << ": "; + const double numTicks = static_cast(count) / countPerTick; + for(int i = 0; i < std::round(numTicks); ++i) { + ss << "*"; + } + return ss.str(); }; -template +template std::string HistogramStats::prettyString( size_t maxBarWidth, histValFmtFunc fromatCountIntoStream, - histValFmtFunc fromatValuesIntoStream) const { - std::stringstream ss; - ss << "HistogramStats{" - << " min=["; - fromatValuesIntoStream(ss, min); - ss << "] max_=["; - fromatValuesIntoStream(ss, max); - ss << "] sum=["; - if (sumOverflow) { - ss << "overflow"; - } else { - fromatCountIntoStream(ss, sum); - } - ss << "] mean=["; - fromatValuesIntoStream(ss, mean); - ss << "] numValues=["; - fromatCountIntoStream(ss, numValues); - ss << "] numBuckets=[" << buckets.size() << "]\n"; - if (buckets.size() > 1) { - double countPerTick = - static_cast(maxNumValuesPerBucket) / maxBarWidth; - for (const auto& bucket : buckets) { - ss << bucket.prettyString( - countPerTick, fromatCountIntoStream, fromatValuesIntoStream); - ss << std::endl; + histValFmtFunc fromatValuesIntoStream +) const { + std::stringstream ss; + ss << "HistogramStats{" + << " min=["; + fromatValuesIntoStream(ss, min); + ss << "] max_=["; + fromatValuesIntoStream(ss, max); + ss << "] sum=["; + if(sumOverflow) { + ss << "overflow"; + } else { + 
fromatCountIntoStream(ss, sum); + } + ss << "] mean=["; + fromatValuesIntoStream(ss, mean); + ss << "] numValues=["; + fromatCountIntoStream(ss, numValues); + ss << "] numBuckets=[" << buckets.size() << "]\n"; + if(buckets.size() > 1) { + double countPerTick = + static_cast(maxNumValuesPerBucket) / maxBarWidth; + for(const auto& bucket : buckets) { + ss << bucket.prettyString( + countPerTick, + fromatCountIntoStream, + fromatValuesIntoStream + ); + ss << std::endl; + } } - } - return ss.str(); + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/common/Logging.cpp b/flashlight/fl/common/Logging.cpp index d666739..7b004cd 100644 --- a/flashlight/fl/common/Logging.cpp +++ b/flashlight/fl/common/Logging.cpp @@ -18,9 +18,9 @@ namespace fl { void initLogging() { - // initialize backward stacktracing. noop if not built with a - // backtrace/stacktrace lib - detail::initBackward(); + // initialize backward stacktracing. noop if not built with a + // backtrace/stacktrace lib + detail::initBackward(); } LogLevel Logging::maxLoggingLevel_ = DEFAULT_MAX_FL_LOGGING_LEVEL; @@ -28,256 +28,259 @@ int VerboseLogging::maxLoggingLevel_ = DEFAULT_MAX_VERBOSE_FL_LOGGING_LEVEL; namespace { // Constatnts for ANSI terminal colors. 
-constexpr const char* RED = "\033[0;31m"; -constexpr const char* GREEN = "\033[0;32m"; -constexpr const char* YELLOW = "\033[0;33m"; -constexpr const char* NO_COLOR = "\033[0m"; + constexpr const char* RED = "\033[0;31m"; + constexpr const char* GREEN = "\033[0;32m"; + constexpr const char* YELLOW = "\033[0;33m"; + constexpr const char* NO_COLOR = "\033[0m"; #ifdef _WIN32 -constexpr const char* kSeparator = "\\"; + constexpr const char* kSeparator = "\\"; #else -constexpr const char* kSeparator = "/"; + constexpr const char* kSeparator = "/"; #endif -std::string getFileName(const std::string& path) { - const size_t separatorIndex = path.rfind(kSeparator, path.length()); - if (separatorIndex == std::string::npos) { - return path; - } - return path.substr(separatorIndex + 1, path.length() - separatorIndex); -} + std::string getFileName(const std::string& path) { + const size_t separatorIndex = path.rfind(kSeparator, path.length()); + if(separatorIndex == std::string::npos) { + return path; + } + return path.substr(separatorIndex + 1, path.length() - separatorIndex); + } // Returns high resolution time formatted as: // MMDD HH MM SS UUUUUU // 0206 08:42:42.123456 -std::string dateTimeWithMicroSeconds() { - auto systemTime = std::chrono::system_clock::now(); - const time_t secondsSinceEpoc = - std::chrono::system_clock::to_time_t(systemTime); - const struct tm* timeinfo = localtime(&secondsSinceEpoc); - - // Formate date and time to the seconds as: - // MMDD HH MM SS - // 1231 08:42:42 - constexpr size_t bufferSize = 50; - char buffer[bufferSize]; - const size_t nWrittenBytes = std::strftime(buffer, 30, "%m%d %T", timeinfo); - if (!nWrittenBytes) { - return "getTime() failed to format time"; - } - - const std::chrono::system_clock::time_point timeInSecondsResolution = - std::chrono::system_clock::from_time_t(secondsSinceEpoc); - const auto usec = std::chrono::duration_cast( - systemTime - timeInSecondsResolution); - - // Add msec and usec. 
- std::snprintf( - buffer + nWrittenBytes, - bufferSize - nWrittenBytes, - ".%06" PRId64, - static_cast(usec.count())); - - return buffer; -} - -void addContext( - const char* fullPath, - int lineNumber, - std::stringstream* outputStream) { - // report only the last threadIdNumDigits of the thread ID for succinctness - // and compatibility with glog. - constexpr size_t maxThreadIdNumDigits = 5; - std::stringstream ss; - ss << std::this_thread::get_id(); - - std::string threadId = ss.str(); - if(threadId.size() > maxThreadIdNumDigits){ - threadId = threadId.substr(threadId.size() - maxThreadIdNumDigits); - } - - - (*outputStream) << dateTimeWithMicroSeconds() << ' ' - << threadId << ' ' - << getFileName(fullPath) << ':' << lineNumber << ' '; -} + std::string dateTimeWithMicroSeconds() { + auto systemTime = std::chrono::system_clock::now(); + const time_t secondsSinceEpoc = + std::chrono::system_clock::to_time_t(systemTime); + const struct tm* timeinfo = localtime(&secondsSinceEpoc); + + // Formate date and time to the seconds as: + // MMDD HH MM SS + // 1231 08:42:42 + constexpr size_t bufferSize = 50; + char buffer[bufferSize]; + const size_t nWrittenBytes = std::strftime(buffer, 30, "%m%d %T", timeinfo); + if(!nWrittenBytes) { + return "getTime() failed to format time"; + } + + const std::chrono::system_clock::time_point timeInSecondsResolution = + std::chrono::system_clock::from_time_t(secondsSinceEpoc); + const auto usec = std::chrono::duration_cast( + systemTime - timeInSecondsResolution + ); + + // Add msec and usec. + std::snprintf( + buffer + nWrittenBytes, + bufferSize - nWrittenBytes, + ".%06" PRId64, + static_cast(usec.count()) + ); + + return buffer; + } + + void addContext( + const char* fullPath, + int lineNumber, + std::stringstream* outputStream + ) { + // report only the last threadIdNumDigits of the thread ID for succinctness + // and compatibility with glog. 
+ constexpr size_t maxThreadIdNumDigits = 5; + std::stringstream ss; + ss << std::this_thread::get_id(); + + std::string threadId = ss.str(); + if(threadId.size() > maxThreadIdNumDigits) { + threadId = threadId.substr(threadId.size() - maxThreadIdNumDigits); + } + + + (*outputStream) << dateTimeWithMicroSeconds() << ' ' + << threadId << ' ' + << getFileName(fullPath) << ':' << lineNumber << ' '; + } } // namespace -Logging::Logging(LogLevel level, const char* fullPath, int lineNumber) - : level_(level), outputStreamPtr_(&std::cerr) { - if (level_ <= Logging::maxLoggingLevel_) { - switch (level_) { - case LogLevel::INFO: - stringStream_ << GREEN << "I"; - break; - case LogLevel::WARNING: - stringStream_ << YELLOW << "W"; - break; - case LogLevel::ERROR: - outputStreamPtr_ = &std::cerr; - stringStream_ << RED << "E"; - break; - case LogLevel::FATAL: - outputStreamPtr_ = &std::cerr; - stringStream_ << RED << "F"; - break; - default: - outputStreamPtr_ = &std::cerr; - stringStream_ << RED << "Invalid log level "; - }; - addContext(fullPath, lineNumber, &stringStream_); - stringStream_ << NO_COLOR; - } +Logging::Logging(LogLevel level, const char* fullPath, int lineNumber) : level_(level), + outputStreamPtr_(&std::cerr) { + if(level_ <= Logging::maxLoggingLevel_) { + switch(level_) { + case LogLevel::INFO: + stringStream_ << GREEN << "I"; + break; + case LogLevel::WARNING: + stringStream_ << YELLOW << "W"; + break; + case LogLevel::ERROR: + outputStreamPtr_ = &std::cerr; + stringStream_ << RED << "E"; + break; + case LogLevel::FATAL: + outputStreamPtr_ = &std::cerr; + stringStream_ << RED << "F"; + break; + default: + outputStreamPtr_ = &std::cerr; + stringStream_ << RED << "Invalid log level "; + } + ; + addContext(fullPath, lineNumber, &stringStream_); + stringStream_ << NO_COLOR; + } } Logging::~Logging() { - if (level_ <= Logging::maxLoggingLevel_) { - stringStream_ << std::endl; - (*outputStreamPtr_) << stringStream_.str(); - outputStreamPtr_->flush(); - if (level_ 
== LogLevel::FATAL) { - exit(-1); + if(level_ <= Logging::maxLoggingLevel_) { + stringStream_ << std::endl; + (*outputStreamPtr_) << stringStream_.str(); + outputStreamPtr_->flush(); + if(level_ == LogLevel::FATAL) { + exit(-1); + } } - } } void Logging::setMaxLoggingLevel(LogLevel maxLoggingLevel) { - if (maxLoggingLevel != Logging::maxLoggingLevel_) { - std::cerr << "Logging::setMaxLoggingLevel(maxLoggingLevel=" - << logLevelName(maxLoggingLevel) << ") Logging::maxLoggingLevel_=" - << logLevelName(Logging::maxLoggingLevel_) << std::endl; - Logging::maxLoggingLevel_ = maxLoggingLevel; - } + if(maxLoggingLevel != Logging::maxLoggingLevel_) { + std::cerr << "Logging::setMaxLoggingLevel(maxLoggingLevel=" + << logLevelName(maxLoggingLevel) << ") Logging::maxLoggingLevel_=" + << logLevelName(Logging::maxLoggingLevel_) << std::endl; + Logging::maxLoggingLevel_ = maxLoggingLevel; + } } Logging&& operator<<(Logging&& log, const std::string& s) { - return std::move(log.print(s)); + return std::move(log.print(s)); } Logging&& operator<<(Logging&& log, const char* s) { - return std::move(log.print(s)); + return std::move(log.print(s)); } Logging&& operator<<(Logging&& log, const void* s) { - return std::move(log.print(s)); + return std::move(log.print(s)); } Logging&& operator<<(Logging&& log, char c) { - return std::move(log.print(c)); + return std::move(log.print(c)); } Logging&& operator<<(Logging&& log, unsigned char u) { - return std::move(log.print(u)); + return std::move(log.print(u)); } Logging&& operator<<(Logging&& log, int i) { - return std::move(log.print(i)); + return std::move(log.print(i)); } Logging&& operator<<(Logging&& log, unsigned int u) { - return std::move(log.print(u)); + return std::move(log.print(u)); } Logging&& operator<<(Logging&& log, long l) { - return std::move(log.print(l)); + return std::move(log.print(l)); } Logging&& operator<<(Logging&& log, long long l) { - return std::move(log.print(l)); + return std::move(log.print(l)); } Logging&& 
operator<<(Logging&& log, unsigned long u) { - return std::move(log.print(u)); + return std::move(log.print(u)); } Logging&& operator<<(Logging&& log, unsigned long long u) { - return std::move(log.print(u)); + return std::move(log.print(u)); } Logging&& operator<<(Logging&& log, float f) { - return std::move(log.print(f)); + return std::move(log.print(f)); } Logging&& operator<<(Logging&& log, double d) { - return std::move(log.print(d)); + return std::move(log.print(d)); } Logging&& operator<<(Logging&& log, bool b) { - return std::move(log.print(b)); + return std::move(log.print(b)); } -VerboseLogging::VerboseLogging(int level, const char* fullPath, int lineNumber) - : level_(level) { - if (level_ <= VerboseLogging::maxLoggingLevel_) { - stringStream_ << "vlog(" << level_ << ") "; - addContext(fullPath, lineNumber, &stringStream_); - } +VerboseLogging::VerboseLogging(int level, const char* fullPath, int lineNumber) : level_(level) { + if(level_ <= VerboseLogging::maxLoggingLevel_) { + stringStream_ << "vlog(" << level_ << ") "; + addContext(fullPath, lineNumber, &stringStream_); + } } VerboseLogging::~VerboseLogging() { - if (level_ <= VerboseLogging::maxLoggingLevel_) { - stringStream_ << std::endl; - std::cerr << stringStream_.str(); - std::cerr.flush(); - } + if(level_ <= VerboseLogging::maxLoggingLevel_) { + stringStream_ << std::endl; + std::cerr << stringStream_.str(); + std::cerr.flush(); + } } void VerboseLogging::setMaxLoggingLevel(int maxLoggingLevel) { - if (maxLoggingLevel != VerboseLogging::maxLoggingLevel_) { - std::cerr << "VerboseLogging::setMaxLoggingLevel(maxLoggingLevel=" - << maxLoggingLevel << ") VerboseLogging::maxLoggingLevel_=" - << VerboseLogging::maxLoggingLevel_ << std::endl; - VerboseLogging::maxLoggingLevel_ = maxLoggingLevel; - } + if(maxLoggingLevel != VerboseLogging::maxLoggingLevel_) { + std::cerr << "VerboseLogging::setMaxLoggingLevel(maxLoggingLevel=" + << maxLoggingLevel << ") VerboseLogging::maxLoggingLevel_=" + << 
VerboseLogging::maxLoggingLevel_ << std::endl; + VerboseLogging::maxLoggingLevel_ = maxLoggingLevel; + } } VerboseLogging&& operator<<(VerboseLogging&& log, const std::string& s) { - return std::move(log.print(s)); + return std::move(log.print(s)); } VerboseLogging&& operator<<(VerboseLogging&& log, const char* s) { - return std::move(log.print(s)); + return std::move(log.print(s)); } VerboseLogging&& operator<<(VerboseLogging&& log, const void* s) { - return std::move(log.print(s)); + return std::move(log.print(s)); } VerboseLogging&& operator<<(VerboseLogging&& log, char c) { - return std::move(log.print(c)); + return std::move(log.print(c)); } VerboseLogging&& operator<<(VerboseLogging&& log, unsigned char u) { - return std::move(log.print(u)); + return std::move(log.print(u)); } VerboseLogging&& operator<<(VerboseLogging&& log, int i) { - return std::move(log.print(i)); + return std::move(log.print(i)); } VerboseLogging&& operator<<(VerboseLogging&& log, unsigned int u) { - return std::move(log.print(u)); + return std::move(log.print(u)); } VerboseLogging&& operator<<(VerboseLogging&& log, long l) { - return std::move(log.print(l)); + return std::move(log.print(l)); } VerboseLogging&& operator<<(VerboseLogging&& log, unsigned long u) { - return std::move(log.print(u)); + return std::move(log.print(u)); } VerboseLogging&& operator<<(VerboseLogging&& log, float f) { - return std::move(log.print(f)); + return std::move(log.print(f)); } VerboseLogging&& operator<<(VerboseLogging&& log, double d) { - return std::move(log.print(d)); + return std::move(log.print(d)); } VerboseLogging&& operator<<(VerboseLogging&& log, bool b) { - return std::move(log.print(b)); + return std::move(log.print(b)); } constexpr std::array flLogLevelValues = { @@ -287,31 +290,31 @@ constexpr std::array flLogLevelValues = { fl::LogLevel::FATAL, fl::LogLevel::DISABLED}; constexpr std::array flLogLevelNames = - {"INFO", "WARNING", "ERROR", "FATAL", "DISABLED"}; +{"INFO", "WARNING", "ERROR", 
"FATAL", "DISABLED"}; std::string logLevelName(LogLevel level) { - for (int i = 0; i < flLogLevelValues.size(); ++i) { - if (level == flLogLevelValues.at(i)) { - return flLogLevelNames.at(i); + for(int i = 0; i < flLogLevelValues.size(); ++i) { + if(level == flLogLevelValues.at(i)) { + return flLogLevelNames.at(i); + } } - } - std::stringstream ss; - ss << "logLevelName(level=" << static_cast(level) - << ") invalid level. Level should be in the range [0.." - << (flLogLevelNames.size() - 1) << "]"; - throw std::invalid_argument(ss.str()); + std::stringstream ss; + ss << "logLevelName(level=" << static_cast(level) + << ") invalid level. Level should be in the range [0.." + << (flLogLevelNames.size() - 1) << "]"; + throw std::invalid_argument(ss.str()); } LogLevel logLevelValue(const std::string& level) { - for (int i = 0; i < flLogLevelValues.size(); ++i) { - if (level == std::string(flLogLevelNames.at(i))) { - return flLogLevelValues.at(i); + for(int i = 0; i < flLogLevelValues.size(); ++i) { + if(level == std::string(flLogLevelNames.at(i))) { + return flLogLevelValues.at(i); + } } - } - std::stringstream ss; - ss << "logLevelValue(level=" << level - << ") invalid level. Level should be INFO, WARNING, ERROR or FATAL"; - throw std::invalid_argument(ss.str()); + std::stringstream ss; + ss << "logLevelValue(level=" << level + << ") invalid level. Level should be INFO, WARNING, ERROR or FATAL"; + throw std::invalid_argument(ss.str()); } } // namespace fl diff --git a/flashlight/fl/common/Logging.h b/flashlight/fl/common/Logging.h index 7169022..d41f7fd 100644 --- a/flashlight/fl/common/Logging.h +++ b/flashlight/fl/common/Logging.h @@ -78,12 +78,12 @@ FL_API void initLogging(); /// \ingroup logging enum class LogLevel { - DISABLED, // use only for when calling setMaxLoggingLevel() or - // setting DEFAULT_MAX_FL_LOGGING_LEVEL. - FATAL, - ERROR, - WARNING, - INFO, + DISABLED, // use only for when calling setMaxLoggingLevel() or + // setting DEFAULT_MAX_FL_LOGGING_LEVEL. 
+ FATAL, + ERROR, + WARNING, + INFO, }; /** @@ -130,114 +130,114 @@ constexpr int DEFAULT_MAX_VERBOSE_FL_LOGGING_LEVEL = 0; #define FL_VLOG(level) fl::VerboseLogging(level, __FILE__, __LINE__) // Optimization macros that allow to run code only we are going to log it. -#define IF_LOG(level) if (fl::Logging::ifLog(level)) -#define IF_VLOG(level) if (fl::VerboseLogging::ifLog(level)) +#define IF_LOG(level) if(fl::Logging::ifLog(level)) +#define IF_VLOG(level) if(fl::VerboseLogging::ifLog(level)) /// \ingroup logging #define FL_LOG_IF(level, exp) \ - if (exp) \ - fl::Logging(level, __FILE__, __LINE__) + if(exp) \ + fl::Logging(level, __FILE__, __LINE__) /// \ingroup logging #define FL_VLOG_IF(level, exp) \ - if (exp) \ - fl::VerboseLogging(level, __FILE__, __LINE__) + if(exp) \ + fl::VerboseLogging(level, __FILE__, __LINE__) class FL_API Logging { - public: - Logging(LogLevel level, const char* filename, int lineNumber); - ~Logging(); - - // Prints t to stdout along with context and sensible font color. - template - Logging&& print(T& t) { - if (level_ <= Logging::maxLoggingLevel_) { - stringStream_ << t; +public: + Logging(LogLevel level, const char* filename, int lineNumber); + ~Logging(); + + // Prints t to stdout along with context and sensible font color. + template + Logging && print(T & t) { + if(level_ <= Logging::maxLoggingLevel_) { + stringStream_ << t; + } + return std::move(*this); } - return std::move(*this); - } - // Overrides DEFAULT_MAX_FL_LOGGING_LEVEL value. - static void setMaxLoggingLevel(LogLevel maxLoggingLevel); + // Overrides DEFAULT_MAX_FL_LOGGING_LEVEL value. 
+ static void setMaxLoggingLevel(LogLevel maxLoggingLevel); - static bool ifLog(LogLevel level) { - return (maxLoggingLevel_ >= level); - } + static bool ifLog(LogLevel level) { + return maxLoggingLevel_ >= level; + } - private: - static LogLevel maxLoggingLevel_; - const LogLevel level_; - std::stringstream stringStream_; - std::ostream* outputStreamPtr_; +private: + static LogLevel maxLoggingLevel_; + const LogLevel level_; + std::stringstream stringStream_; + std::ostream* outputStreamPtr_; }; class FL_API VerboseLogging { - public: - VerboseLogging(int level, const char* filename, int lineNumber); - ~VerboseLogging(); - - // Prints t to stdout along with logging level and context. - template - VerboseLogging&& print(T& t) { - if (level_ <= VerboseLogging::maxLoggingLevel_) { - stringStream_ << t; +public: + VerboseLogging(int level, const char* filename, int lineNumber); + ~VerboseLogging(); + + // Prints t to stdout along with logging level and context. + template + VerboseLogging && print(T & t) { + if(level_ <= VerboseLogging::maxLoggingLevel_) { + stringStream_ << t; + } + return std::move(*this); } - return std::move(*this); - } - // Overrides DEFAULT_MAX_VERBOSE_FL_LOGGING_LEVEL value. - static void setMaxLoggingLevel(int maxLoggingLevel); + // Overrides DEFAULT_MAX_VERBOSE_FL_LOGGING_LEVEL value. + static void setMaxLoggingLevel(int maxLoggingLevel); - static bool ifLog(int level) { - return (maxLoggingLevel_ >= level); - } + static bool ifLog(int level) { + return maxLoggingLevel_ >= level; + } - private: - static int maxLoggingLevel_; - const int level_; - std::stringstream stringStream_; +private: + static int maxLoggingLevel_; + const int level_; + std::stringstream stringStream_; }; // Can't use template here since the compiler will try resolve // to all kind of other existing function before it considers // instantiating a template. 
-FL_API Logging&& operator<<(Logging&& log, const std::string& s); -FL_API Logging&& operator<<(Logging&& log, const char* s); -FL_API Logging&& operator<<(Logging&& log, const void* s); -FL_API Logging&& operator<<(Logging&& log, char c); -FL_API Logging&& operator<<(Logging&& log, unsigned char u); -FL_API Logging&& operator<<(Logging&& log, int i); -FL_API Logging&& operator<<(Logging&& log, unsigned int u); -FL_API Logging&& operator<<(Logging&& log, long l); -FL_API Logging&& operator<<(Logging&& log, long long l); -FL_API Logging&& operator<<(Logging&& log, unsigned long u); -FL_API Logging&& operator<<(Logging&& log, unsigned long long u); -FL_API Logging&& operator<<(Logging&& log, float f); -FL_API Logging&& operator<<(Logging&& log, double d); -FL_API Logging&& operator<<(Logging&& log, bool b); +FL_API Logging && operator<<(Logging && log, const std::string& s); +FL_API Logging && operator<<(Logging && log, const char* s); +FL_API Logging && operator<<(Logging && log, const void* s); +FL_API Logging && operator<<(Logging && log, char c); +FL_API Logging && operator<<(Logging && log, unsigned char u); +FL_API Logging && operator<<(Logging && log, int i); +FL_API Logging && operator<<(Logging && log, unsigned int u); +FL_API Logging && operator<<(Logging && log, long l); +FL_API Logging && operator<<(Logging && log, long long l); +FL_API Logging && operator<<(Logging && log, unsigned long u); +FL_API Logging && operator<<(Logging && log, unsigned long long u); +FL_API Logging && operator<<(Logging && log, float f); +FL_API Logging && operator<<(Logging && log, double d); +FL_API Logging && operator<<(Logging && log, bool b); // Catch all designed mostly for stuff. 
-template -Logging&& operator<<(Logging&& log, const T& t) { - return log.print(t); +template +Logging && operator<<(Logging&& log, const T& t) { + return log.print(t); } -FL_API VerboseLogging&& operator<<(VerboseLogging&& log, const std::string& s); -FL_API VerboseLogging&& operator<<(VerboseLogging&& log, const char* s); -FL_API VerboseLogging&& operator<<(VerboseLogging&& log, const void* s); -FL_API VerboseLogging&& operator<<(VerboseLogging&& log, char c); -FL_API VerboseLogging&& operator<<(VerboseLogging&& log, unsigned char u); -FL_API VerboseLogging&& operator<<(VerboseLogging&& log, int i); -FL_API VerboseLogging&& operator<<(VerboseLogging&& log, unsigned int u); -FL_API VerboseLogging&& operator<<(VerboseLogging&& log, long l); -FL_API VerboseLogging&& operator<<(VerboseLogging&& log, unsigned long u); -FL_API VerboseLogging&& operator<<(VerboseLogging&& log, float f); -FL_API VerboseLogging&& operator<<(VerboseLogging&& log, double d); -FL_API VerboseLogging&& operator<<(VerboseLogging&& log, bool b); +FL_API VerboseLogging && operator<<(VerboseLogging && log, const std::string& s); +FL_API VerboseLogging && operator<<(VerboseLogging && log, const char* s); +FL_API VerboseLogging && operator<<(VerboseLogging && log, const void* s); +FL_API VerboseLogging && operator<<(VerboseLogging && log, char c); +FL_API VerboseLogging && operator<<(VerboseLogging && log, unsigned char u); +FL_API VerboseLogging && operator<<(VerboseLogging && log, int i); +FL_API VerboseLogging && operator<<(VerboseLogging && log, unsigned int u); +FL_API VerboseLogging && operator<<(VerboseLogging && log, long l); +FL_API VerboseLogging && operator<<(VerboseLogging && log, unsigned long u); +FL_API VerboseLogging && operator<<(VerboseLogging && log, float f); +FL_API VerboseLogging && operator<<(VerboseLogging && log, double d); +FL_API VerboseLogging && operator<<(VerboseLogging && log, bool b); // Catch all designed mostly for stuff. 
-template -VerboseLogging&& operator<<(VerboseLogging&& log, const T& t) { - return log.print(t); +template +VerboseLogging && operator<<(VerboseLogging&& log, const T& t) { + return log.print(t); } } // namespace fl diff --git a/flashlight/fl/common/Plugin.cpp b/flashlight/fl/common/Plugin.cpp index 05f7718..4e13b32 100644 --- a/flashlight/fl/common/Plugin.cpp +++ b/flashlight/fl/common/Plugin.cpp @@ -11,65 +11,65 @@ #include #ifdef _WIN32 - #include - #include "flashlight/fl/common/WinUtility.h" - #define PLUGIN_HANDLE HMODULE +#include +#include "flashlight/fl/common/WinUtility.h" +#define PLUGIN_HANDLE HMODULE #else - #include - #define PLUGIN_HANDLE void* +#include +#define PLUGIN_HANDLE void* #endif namespace fl { Plugin::Plugin(const std::string& name) : name_(name) { #ifdef _WIN32 - auto wideName = detail::utf8ToWide(name); - handle_ = (void*)LoadLibraryW(wideName.c_str()); - if (!handle_) { - auto err = detail::getWindowsErrorString(); - throw std::runtime_error("unable to load library <" + name + ">: " + err); - } + auto wideName = detail::utf8ToWide(name); + handle_ = (void*) LoadLibraryW(wideName.c_str()); + if(!handle_) { + auto err = detail::getWindowsErrorString(); + throw std::runtime_error("unable to load library <" + name + ">: " + err); + } #else - dlerror(); // clear errors - handle_ = dlopen(name.c_str(), RTLD_LAZY); - if (!handle_) { - auto err = dlerror(); - throw std::runtime_error("unable to load library <" + name + ">: " + err); - } + dlerror(); // clear errors + handle_ = dlopen(name.c_str(), RTLD_LAZY); + if(!handle_) { + auto err = dlerror(); + throw std::runtime_error("unable to load library <" + name + ">: " + err); + } #endif } void* Plugin::getRawSymbol(const std::string& symbol) { #ifdef _WIN32 - auto addr = (void*)GetProcAddress((HMODULE)handle_, symbol.c_str()); + auto addr = (void*) GetProcAddress((HMODULE) handle_, symbol.c_str()); #else - dlerror(); // clear errors - auto addr = dlsym(handle_, symbol.c_str()); + dlerror(); // 
clear errors + auto addr = dlsym(handle_, symbol.c_str()); #endif - if (!addr) { + if(!addr) { #ifdef _WIN32 - auto err = detail::getWindowsErrorString(); + auto err = detail::getWindowsErrorString(); #else - auto err = dlerror(); + auto err = dlerror(); #endif - std::stringstream msg; - msg << "unable to resolve symbol <" << symbol << ">"; - msg << " in library <" << name_ << ">"; - msg << ":" << err; - throw std::runtime_error(msg.str()); - } - return addr; + std::stringstream msg; + msg << "unable to resolve symbol <" << symbol << ">"; + msg << " in library <" << name_ << ">"; + msg << ":" << err; + throw std::runtime_error(msg.str()); + } + return addr; } Plugin::~Plugin() { - if (handle_) { + if(handle_) { #ifdef _WIN32 - FreeLibrary((HMODULE)handle_); + FreeLibrary((HMODULE) handle_); #else - dlclose(handle_); + dlclose(handle_); #endif - } + } } } // namespace fl diff --git a/flashlight/fl/common/Plugin.h b/flashlight/fl/common/Plugin.h index 58e0e72..d4d050b 100644 --- a/flashlight/fl/common/Plugin.h +++ b/flashlight/fl/common/Plugin.h @@ -14,19 +14,19 @@ namespace fl { class FL_API Plugin { - public: - explicit Plugin(const std::string& name); - ~Plugin(); +public: + explicit Plugin(const std::string& name); + ~Plugin(); - protected: - template - T getSymbol(const std::string& symbol) { - return (T)getRawSymbol(symbol); - } +protected: + template + T getSymbol(const std::string& symbol) { + return (T) getRawSymbol(symbol); + } - private: - void* getRawSymbol(const std::string& symbol); - std::string name_; - void* handle_; +private: + void* getRawSymbol(const std::string& symbol); + std::string name_; + void* handle_; }; } // namespace fl diff --git a/flashlight/fl/common/Serialization-inl.h b/flashlight/fl/common/Serialization-inl.h index f51a4e9..47a1002 100644 --- a/flashlight/fl/common/Serialization-inl.h +++ b/flashlight/fl/common/Serialization-inl.h @@ -23,133 +23,132 @@ namespace fl { namespace detail { -template -using IsOutputArchive = 
std::is_base_of; + template + using IsOutputArchive = std::is_base_of; -template -using IsInputArchive = std::is_base_of; + template + using IsInputArchive = std::is_base_of; /** * Wrapper indicating that an expression should be serialized only if the * version is in a certain range. */ -template -struct Versioned { - T&& ref; - uint32_t minVersion; - uint32_t maxVersion; -}; - -template -struct SerializeAs { - using T0 = std::decay_t; - T&& ref; - std::function saveConverter; - std::function loadConverter; -}; + template + struct Versioned { + T && ref; + uint32_t minVersion; + uint32_t maxVersion; + }; + + template + struct SerializeAs { + using T0 = std::decay_t; + T && ref; + std::function saveConverter; + std::function loadConverter; + }; // 0 arguments (no-op). -template -void applyArchive(Archive& ar, const uint32_t version) {} + template + void applyArchive(Archive& ar, const uint32_t version) {} // 1 argument, general case. -template -void applyArchive(Archive& ar, const uint32_t version, Arg&& arg) { - ar(std::forward(arg)); -} + template + void applyArchive(Archive& ar, const uint32_t version, Arg&& arg) { + ar(std::forward(arg)); + } // 1 argument, version-restricted. -template -void applyArchive(Archive& ar, const uint32_t version, Versioned varg) { - if (version >= varg.minVersion && version <= varg.maxVersion) { - applyArchive(ar, version, std::forward(varg.ref)); - } -} + template + void applyArchive(Archive& ar, const uint32_t version, Versioned varg) { + if(version >= varg.minVersion && version <= varg.maxVersion) { + applyArchive(ar, version, std::forward(varg.ref)); + } + } // 1 argument, with conversion, saving. 
-template < - typename Archive, - typename S, - typename T, - std::enable_if_t::value, int> = 0> -void applyArchive(Archive& ar, const uint32_t version, SerializeAs arg) { - if (arg.saveConverter) { - applyArchive(ar, version, arg.saveConverter(arg.ref)); - } else { - applyArchive(ar, version, static_cast(arg.ref)); - } -} + template< + typename Archive, + typename S, + typename T, + std::enable_if_t::value, int> = 0> + void applyArchive(Archive& ar, const uint32_t version, SerializeAs arg) { + if(arg.saveConverter) { + applyArchive(ar, version, arg.saveConverter(arg.ref)); + } else { + applyArchive(ar, version, static_cast(arg.ref)); + } + } // 1 argument, with conversion, loading. -template < - typename Archive, - typename S, - typename T, - std::enable_if_t::value, int> = 0> -void applyArchive(Archive& ar, const uint32_t version, SerializeAs arg) { - using T0 = std::remove_reference_t; - S s; - applyArchive(ar, version, s); - if (arg.loadConverter) { - arg.ref = arg.loadConverter(std::move(s)); - } else { - arg.ref = static_cast(std::move(s)); - } -} + template< + typename Archive, + typename S, + typename T, + std::enable_if_t::value, int> = 0> + void applyArchive(Archive& ar, const uint32_t version, SerializeAs arg) { + using T0 = std::remove_reference_t; + S s; + applyArchive(ar, version, s); + if(arg.loadConverter) { + arg.ref = arg.loadConverter(std::move(s)); + } else { + arg.ref = static_cast(std::move(s)); + } + } // 2+ arguments (recurse). -template -void applyArchive( - Archive& ar, - const uint32_t version, - Arg&& arg, - Args&&... args) { - applyArchive(ar, version, std::forward(arg)); - applyArchive(ar, version, std::forward(args)...); -} + template + void applyArchive( + Archive& ar, + const uint32_t version, + Arg&& arg, + Args&&... 
args + ) { + applyArchive(ar, version, std::forward(arg)); + applyArchive(ar, version, std::forward(args)...); + } } // namespace detail -template -detail::Versioned -versioned(T&& t, uint32_t minVersion, uint32_t maxVersion) { - return detail::Versioned{std::forward(t), minVersion, maxVersion}; +template +detail::Versioned versioned(T&& t, uint32_t minVersion, uint32_t maxVersion) { + return detail::Versioned{std::forward(t), minVersion, maxVersion}; } -template +template detail::SerializeAs serializeAs(T&& t) { - return detail::SerializeAs{std::forward(t), nullptr, nullptr}; + return detail::SerializeAs{std::forward(t), nullptr, nullptr}; } -template -detail::SerializeAs -serializeAs(T&& t, SaveConvFn saveConverter, LoadConvFn loadConverter) { - return detail::SerializeAs{ - std::forward(t), std::move(saveConverter), std::move(loadConverter)}; +template +detail::SerializeAs serializeAs(T&& t, SaveConvFn saveConverter, LoadConvFn loadConverter) { + return detail::SerializeAs{ + std::forward(t), std::move(saveConverter), std::move(loadConverter)}; } -template +template void save(const fs::path& filepath, const Args&... args) { - std::ofstream ofs(filepath, std::ios::binary); - save(ofs, args...); + std::ofstream ofs(filepath, std::ios::binary); + save(ofs, args...); } -template +template void save(std::ostream& ostr, const Args&... args) { - cereal::BinaryOutputArchive ar(ostr); - ar(args...); + cereal::BinaryOutputArchive ar(ostr); + ar(args...); } -template +template void load(const fs::path& filepath, Args&... args) { - std::ifstream ifs(filepath, std::ios::binary); - load(ifs, args...); + std::ifstream ifs(filepath, std::ios::binary); + load(ifs, args...); } -template +template void load(std::istream& istr, Args&... 
args) { - cereal::BinaryInputArchive ar(istr); - ar(args...); + cereal::BinaryInputArchive ar(istr); + ar(args...); } namespace detail { @@ -166,60 +165,64 @@ namespace detail { * For more info, see https://github.com/USCiLab/cereal/issues/132 * and https://en.cppreference.com/w/cpp/language/implicit_conversion */ -template -struct CerealSave { - /* implicit */ CerealSave(const T& x) : val(x) {} - const T& val; -}; + template + struct CerealSave { + /* implicit */ + CerealSave(const T& x) : val(x) {} + const T& val; + }; } // namespace detail } // namespace fl namespace cereal { // no versioning; simple and unlikely to ever change -template +template void save( Archive& ar, const fl::detail::CerealSave& dims_, - const uint32_t /* version */) { - // TODO{fl::Tensor} -- check version, then op as dim4 (if version ==) - const auto& dims = dims_.val; - const std::vector& vec = dims.get(); - ar(vec); + const uint32_t /* version */ +) { + // TODO{fl::Tensor} -- check version, then op as dim4 (if version ==) + const auto& dims = dims_.val; + const std::vector& vec = dims.get(); + ar(vec); } -template +template void load(Archive& ar, fl::Shape& dims, const uint32_t /* version */) { - // TODO{fl::Tensor} -- check version, then read dim4 into Shape (if version - // ==) - std::vector vec; - ar(vec); - dims = fl::Shape(vec); + // TODO{fl::Tensor} -- check version, then read dim4 into Shape (if version + // ==) + std::vector vec; + ar(vec); + dims = fl::Shape(vec); } -template +template void save( Archive& ar, const fl::detail::CerealSave& tensor_, - const uint32_t /* version */) { - const auto& tensor = tensor_.val; - // TODO{fl::Tensor}{sparse} figure out what to do here... 
- if (tensor.isSparse()) { - throw cereal::Exception( - "Serialzation of sparse Tensor is not supported yet!"); - } - std::vector vec(tensor.bytes()); - tensor.host(vec.data()); - ar(tensor.shape(), tensor.type(), vec); -} - -template + const uint32_t /* version */ +) { + const auto& tensor = tensor_.val; + // TODO{fl::Tensor}{sparse} figure out what to do here... + if(tensor.isSparse()) { + throw cereal::Exception( + "Serialzation of sparse Tensor is not supported yet!" + ); + } + std::vector vec(tensor.bytes()); + tensor.host(vec.data()); + ar(tensor.shape(), tensor.type(), vec); +} + +template void load(Archive& ar, fl::Tensor& tensor, const uint32_t /* version */) { - fl::Shape dims; - fl::dtype ty; - std::vector vec; - ar(dims, ty, vec); - tensor = fl::Tensor::fromVector(dims, vec, ty); + fl::Shape dims; + fl::dtype ty; + std::vector vec; + ar(dims, ty, vec); + tensor = fl::Tensor::fromVector(dims, vec, ty); } } // namespace cereal diff --git a/flashlight/fl/common/Serialization.h b/flashlight/fl/common/Serialization.h index 077cb5c..71cd48f 100644 --- a/flashlight/fl/common/Serialization.h +++ b/flashlight/fl/common/Serialization.h @@ -54,7 +54,7 @@ namespace fl { * @param filepath the file path to save to * @param args the objects to save (e.g. shared_ptr to Module) */ -template +template void save(const fs::path& filepath, const Args&... args); /** @@ -62,7 +62,7 @@ void save(const fs::path& filepath, const Args&... args); * @param ostr output stream * @param args the objects to save (e.g. shared_ptr to Module) */ -template +template void save(std::ostream& ostr, const Args&... args); /** @@ -70,16 +70,14 @@ void save(std::ostream& ostr, const Args&... args); * @param filepath the file path to load from * @param args the objects to load (expects default-constructed) */ -template -void load(const fs::path& filepath, Args&... args); +template void load(const fs::path& filepath, Args & ... 
args); /** * Load (deserialize) the specified args from a binary file (via Cereal). * @param istr input stream * @param args the objects to load (expects default-constructed) */ -template -void load(std::istream& istr, Args&... args); +template void load(std::istream& istr, Args & ... args); /** @} */ } // namespace fl @@ -98,16 +96,14 @@ void load(std::istream& istr, Args&... args); * Supports the common case when one adds fields to a class, which should be * conditionally loaded for newer file versions. See `fl::versioned()`. */ -#define FL_SAVE_LOAD(...) \ - friend class cereal::access; \ - template \ - void save(Archive& ar, const uint32_t version) const { \ - ::fl::detail::applyArchive(ar, version, ##__VA_ARGS__); \ - } \ - template \ - void load(Archive& ar, const uint32_t version) { \ - ::fl::detail::applyArchive(ar, version, ##__VA_ARGS__); \ - } +#define FL_SAVE_LOAD(...) \ + friend class cereal::access; \ + template void save(Archive & ar, const uint32_t version) const { \ + ::fl::detail::applyArchive(ar, version, ## __VA_ARGS__); \ + } \ + template void load(Archive & ar, const uint32_t version) { \ + ::fl::detail::applyArchive(ar, version, ## __VA_ARGS__); \ + } /** * Like `FL_SAVE_LOAD`, but also serializes the base class, which must @@ -117,7 +113,7 @@ void load(std::istream& istr, Args&... args); * using this macro. However you will still need `CEREAL_REGISTER_TYPE`. */ #define FL_SAVE_LOAD_WITH_BASE(Base, ...) \ - FL_SAVE_LOAD(cereal::base_class(this), ##__VA_ARGS__) + FL_SAVE_LOAD(cereal::base_class(this), ## __VA_ARGS__) /** * Declaration-only. Intended to reduce clutter in class definitions. @@ -125,26 +121,24 @@ void load(std::istream& istr, Args&... args); * The method must be defined later (outside the class). * Do not use this macro if you want your class to be unversioned. 
*/ -#define FL_SAVE_LOAD_DECLARE() \ - friend class cereal::access; \ - template \ - void save(Archive& ar, const uint32_t version) const; \ - template \ - void load(Archive& ar, const uint32_t version); +#define FL_SAVE_LOAD_DECLARE() \ + friend class cereal::access; \ + template void save(Archive & ar, const uint32_t version) const; \ + template void load(Archive & ar, const uint32_t version); /** @} */ namespace fl { namespace detail { -template -struct Versioned; + template + struct Versioned; -template -struct SerializeAs; + template + struct SerializeAs; -template -struct CerealSave; + template + struct CerealSave; } // namespace detail @@ -160,9 +154,8 @@ struct CerealSave; * Example: if we have field `x` in version 0, and add field `y` in version 1, * we might write: `FL_SAVE_LOAD(x, fl::versioned(y, 1))`. */ -template -detail::Versioned -versioned(T&& t, uint32_t minVersion, uint32_t maxVersion = UINT32_MAX); +template +detail::Versioned versioned(T&& t, uint32_t minVersion, uint32_t maxVersion = UINT32_MAX); /** * Serialize an object of type T as another type S using static_cast on-the-fly. 
@@ -170,7 +163,7 @@ versioned(T&& t, uint32_t minVersion, uint32_t maxVersion = UINT32_MAX); * * Example: `FL_SAVE_LOAD(fl::serializeAs(x))` */ -template +template detail::SerializeAs serializeAs(T&& t); /** @@ -186,31 +179,32 @@ detail::SerializeAs serializeAs(T&& t); * * Example: please see tests/common/SerializationTest.cpp */ -template -detail::SerializeAs -serializeAs(T&& t, SaveConvFn saveConverter, LoadConvFn loadConverter); +template +detail::SerializeAs serializeAs(T&& t, SaveConvFn saveConverter, LoadConvFn loadConverter); /** @} */ } // namespace fl namespace cereal { -template +template void save( Archive& ar, const fl::detail::CerealSave& dims, - const uint32_t /* version */); + const uint32_t /* version */ +); -template +template void load(Archive& ar, fl::Shape& dims, const uint32_t /* version */); -template +template void save( Archive& ar, const fl::detail::CerealSave& tensor, - const uint32_t /* version */); + const uint32_t /* version */ +); -template +template void load(Archive& ar, fl::Tensor& tensor, const uint32_t /* version */); } // namespace cereal diff --git a/flashlight/fl/common/Timer.cpp b/flashlight/fl/common/Timer.cpp index 60290b6..f90045c 100644 --- a/flashlight/fl/common/Timer.cpp +++ b/flashlight/fl/common/Timer.cpp @@ -10,9 +10,9 @@ namespace fl { Timer Timer::start() { - Timer t; - t.startTime_ = std::chrono::high_resolution_clock::now(); - return t; + Timer t; + t.startTime_ = std::chrono::high_resolution_clock::now(); + return t; } } // namespace fl diff --git a/flashlight/fl/common/Timer.h b/flashlight/fl/common/Timer.h index 3e3f5ec..e409d13 100644 --- a/flashlight/fl/common/Timer.h +++ b/flashlight/fl/common/Timer.h @@ -14,17 +14,18 @@ namespace fl { class FL_API Timer { - std::chrono::time_point startTime_; - - public: - static Timer start(); - - template - static T stop(const Timer& t) { - return std::chrono::duration_cast>( - std::chrono::high_resolution_clock::now() - t.startTime_) - .count(); - } + 
std::chrono::time_point startTime_; + +public: + static Timer start(); + + template + static T stop(const Timer& t) { + return std::chrono::duration_cast>( + std::chrono::high_resolution_clock::now() - t.startTime_ + ) + .count(); + } }; } // namespace fl diff --git a/flashlight/fl/common/Types.h b/flashlight/fl/common/Types.h index 443f56b..11adfcd 100644 --- a/flashlight/fl/common/Types.h +++ b/flashlight/fl/common/Types.h @@ -20,7 +20,7 @@ namespace detail { /** * Precision specifications for autograd operators based on optimization level. */ -const std::unordered_map> + const std::unordered_map> kOptimLevelTypeExclusionMappings = { {OptimLevel::DEFAULT, {}}, // unused {OptimLevel::O1, @@ -45,7 +45,7 @@ const std::unordered_map> // Perform all operations in fp16 except for: {"batchnorm"}}, {OptimLevel::O3, {}} // Perform all operations in f16 -}; + }; } // namespace detail diff --git a/flashlight/fl/common/Utils.cpp b/flashlight/fl/common/Utils.cpp index 42f7db4..453dc7c 100644 --- a/flashlight/fl/common/Utils.cpp +++ b/flashlight/fl/common/Utils.cpp @@ -23,127 +23,129 @@ namespace fl { bool f16Supported() { - return defaultTensorBackend().isDataTypeSupported(fl::dtype::f16); + return defaultTensorBackend().isDataTypeSupported(fl::dtype::f16); } size_t divRoundUp(size_t numerator, size_t denominator) { - if (!numerator) { - return 0; - } - if (!denominator) { - throw std::invalid_argument( - std::string("divRoundUp() zero denominator error")); - } - return (numerator + denominator - 1) / denominator; -} - -namespace { -std::string prettyStringMemorySizeUnits(size_t size) { - if (size == SIZE_MAX) { - return "SIZE_MAX"; - } - std::stringstream ss; - - bool isFirst = true; - while (size) { - size_t shift = 0; - const char* unit = ""; - if (size >= (1ULL << 40)) { // >= 8TB - shift = 40; - unit = "TB"; - } else if (size >= (1ULL << 30)) { // >= 8G B - shift = 30; - unit = "GB"; - } else if (size >= (1ULL << 20)) { // >= 8M B - shift = 20; - unit = "MB"; - } else if 
(size >= (1ULL << 10)) { // >= 8K B - shift = 10; - unit = "KB"; + if(!numerator) { + return 0; } - if (size > 0) { - if (!isFirst) { - ss << '+'; - } - isFirst = false; - size_t nUnits = size >> shift; - ss << nUnits << unit; - size -= (nUnits << shift); + if(!denominator) { + throw std::invalid_argument( + std::string("divRoundUp() zero denominator error") + ); } - } - - return ss.str(); + return (numerator + denominator - 1) / denominator; } -std::string prettyStringCountUnits(size_t count) { - if (count == SIZE_MAX) { - return "SIZE_MAX"; - } - std::stringstream ss; - - bool isFirst = true; - while (count) { - size_t magnitude = 1; - const char* unit = ""; - if (count >= 1e12) { - magnitude = 1e12; - unit = "t"; - } else if (count >= 1e9) { - magnitude = 1e9; - unit = "b"; - } else if (count >= 1e6) { - magnitude = 1e6; - unit = "m"; - } else if (count >= 1e3) { - magnitude = 1e3; - unit = "k"; - } - if (count > 0) { - if (!isFirst) { - ss << '+'; - } - isFirst = false; - size_t nUnits = count / magnitude; - ss << nUnits << unit; - count -= (nUnits * magnitude); +namespace { + std::string prettyStringMemorySizeUnits(size_t size) { + if(size == SIZE_MAX) { + return "SIZE_MAX"; + } + std::stringstream ss; + + bool isFirst = true; + while(size) { + size_t shift = 0; + const char* unit = ""; + if(size >= (1ULL << 40)) { // >= 8TB + shift = 40; + unit = "TB"; + } else if(size >= (1ULL << 30)) { // >= 8G B + shift = 30; + unit = "GB"; + } else if(size >= (1ULL << 20)) { // >= 8M B + shift = 20; + unit = "MB"; + } else if(size >= (1ULL << 10)) { // >= 8K B + shift = 10; + unit = "KB"; + } + if(size > 0) { + if(!isFirst) { + ss << '+'; + } + isFirst = false; + size_t nUnits = size >> shift; + ss << nUnits << unit; + size -= (nUnits << shift); + } + } + + return ss.str(); } - } - return ss.str(); -} + std::string prettyStringCountUnits(size_t count) { + if(count == SIZE_MAX) { + return "SIZE_MAX"; + } + std::stringstream ss; + + bool isFirst = true; + while(count) { + 
size_t magnitude = 1; + const char* unit = ""; + if(count >= 1e12) { + magnitude = 1e12; + unit = "t"; + } else if(count >= 1e9) { + magnitude = 1e9; + unit = "b"; + } else if(count >= 1e6) { + magnitude = 1e6; + unit = "m"; + } else if(count >= 1e3) { + magnitude = 1e3; + unit = "k"; + } + if(count > 0) { + if(!isFirst) { + ss << '+'; + } + isFirst = false; + size_t nUnits = count / magnitude; + ss << nUnits << unit; + count -= (nUnits * magnitude); + } + } + + return ss.str(); + } } // namespace std::string prettyStringMemorySize(size_t size) { - if (size == SIZE_MAX) { - return "SIZE_MAX"; - } - std::stringstream ss; - ss << size; - if (size >= (1UL << 13)) { - ss << '(' << prettyStringMemorySizeUnits(size) << ')'; - } - - return ss.str(); + if(size == SIZE_MAX) { + return "SIZE_MAX"; + } + std::stringstream ss; + ss << size; + if(size >= (1UL << 13)) { + ss << '(' << prettyStringMemorySizeUnits(size) << ')'; + } + + return ss.str(); } std::string prettyStringCount(size_t count) { - if (count == SIZE_MAX) { - return "SIZE_MAX"; - } - std::stringstream ss; - ss << count; - - if (count >= 1e3) { // >= 10 thousand - ss << '(' << prettyStringCountUnits(count) << ')'; - } - return ss.str(); + if(count == SIZE_MAX) { + return "SIZE_MAX"; + } + std::stringstream ss; + ss << count; + + if(count >= 1e3) { // >= 10 thousand + ss << '(' << prettyStringCountUnits(count) << ')'; + } + return ss.str(); } std::string getEnvVar( const std::string& key, - const std::string& dflt /*= "" */) { - char* val = getenv(key.c_str()); - return val ? std::string(val) : dflt; + const std::string& dflt /*= "" */ +) { + char* val = getenv(key.c_str()); + return val ? 
std::string(val) : dflt; } } // namespace fl diff --git a/flashlight/fl/common/Utils.h b/flashlight/fl/common/Utils.h index a851c8d..955824f 100644 --- a/flashlight/fl/common/Utils.h +++ b/flashlight/fl/common/Utils.h @@ -44,37 +44,39 @@ FL_API std::string prettyStringCount(size_t count); * Supports sleeps between retries, with duration starting at `initial` and * multiplying by `factor` each retry. At most `maxIters` calls are made. */ -template +template typename std::invoke_result::type retryWithBackoff( std::chrono::duration initial, double factor, int64_t maxIters, Fn&& f, - Args&&... args) { - if (!(initial.count() >= 0.0)) { - throw std::invalid_argument("retryWithBackoff: bad initial"); - } else if (!(factor >= 0.0)) { - throw std::invalid_argument("retryWithBackoff: bad factor"); - } else if (maxIters <= 0) { - throw std::invalid_argument("retryWithBackoff: bad maxIters"); - } - auto sleepSecs = initial.count(); - for (int64_t i = 0; i < maxIters; ++i) { - try { - return f(std::forward(args)...); - } catch (...) { - if (i >= maxIters - 1) { - throw; - } + Args&&... args +) { + if(!(initial.count() >= 0.0)) { + throw std::invalid_argument("retryWithBackoff: bad initial"); + } else if(!(factor >= 0.0)) { + throw std::invalid_argument("retryWithBackoff: bad factor"); + } else if(maxIters <= 0) { + throw std::invalid_argument("retryWithBackoff: bad maxIters"); } - if (sleepSecs > 0.0) { - /* sleep override */ - std::this_thread::sleep_for( - std::chrono::duration(std::min(1e7, sleepSecs))); + auto sleepSecs = initial.count(); + for(int64_t i = 0; i < maxIters; ++i) { + try { + return f(std::forward(args)...); + } catch(...) 
{ + if(i >= maxIters - 1) { + throw; + } + } + if(sleepSecs > 0.0) { + /* sleep override */ + std::this_thread::sleep_for( + std::chrono::duration(std::min(1e7, sleepSecs)) + ); + } + sleepSecs *= factor; } - sleepSecs *= factor; - } - throw std::logic_error("retryWithBackoff: hit unreachable"); + throw std::logic_error("retryWithBackoff: hit unreachable"); } /** @@ -82,7 +84,8 @@ typename std::invoke_result::type retryWithBackoff( */ FL_API std::string getEnvVar( const std::string& key, - const std::string& dflt = ""); + const std::string& dflt = "" +); /** @} */ diff --git a/flashlight/fl/common/WinUtility.cpp b/flashlight/fl/common/WinUtility.cpp index 1d5212b..f7b7063 100644 --- a/flashlight/fl/common/WinUtility.cpp +++ b/flashlight/fl/common/WinUtility.cpp @@ -15,53 +15,70 @@ namespace fl { namespace detail { -std::wstring utf8ToWide(const std::string& utf8) { - if (utf8.empty()) { - return std::wstring(); - } + std::wstring utf8ToWide(const std::string& utf8) { + if(utf8.empty()) { + return std::wstring(); + } - int wideSize = MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, nullptr, 0); - if (wideSize == 0) { - throw std::runtime_error("Failed to convert UTF-8 to wide string"); - } + int wideSize = MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, nullptr, 0); + if(wideSize == 0) { + throw std::runtime_error("Failed to convert UTF-8 to wide string"); + } - std::wstring wide(wideSize - 1, 0); - MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, &wide[0], wideSize); - return wide; -} + std::wstring wide(wideSize - 1, 0); + MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, &wide[0], wideSize); + return wide; + } -std::string getWindowsErrorString() { - DWORD error = GetLastError(); - if (error == 0) { - return "No error"; - } + std::string getWindowsErrorString() { + DWORD error = GetLastError(); + if(error == 0) { + return "No error"; + } - LPWSTR messageBuffer = nullptr; - FormatMessageW( - FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | - 
FORMAT_MESSAGE_IGNORE_INSERTS, - nullptr, - error, - MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - (LPWSTR)&messageBuffer, - 0, - nullptr); + LPWSTR messageBuffer = nullptr; + FormatMessageW( + FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM + | FORMAT_MESSAGE_IGNORE_INSERTS, + nullptr, + error, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPWSTR) &messageBuffer, + 0, + nullptr + ); - std::string result; - if (messageBuffer) { - int utf8Size = WideCharToMultiByte(CP_UTF8, 0, messageBuffer, -1, nullptr, - 0, nullptr, nullptr); - if (utf8Size > 0) { - result.resize(utf8Size - 1); - WideCharToMultiByte(CP_UTF8, 0, messageBuffer, -1, &result[0], utf8Size, - nullptr, nullptr); + std::string result; + if(messageBuffer) { + int utf8Size = WideCharToMultiByte( + CP_UTF8, + 0, + messageBuffer, + -1, + nullptr, + 0, + nullptr, + nullptr + ); + if(utf8Size > 0) { + result.resize(utf8Size - 1); + WideCharToMultiByte( + CP_UTF8, + 0, + messageBuffer, + -1, + &result[0], + utf8Size, + nullptr, + nullptr + ); + } + LocalFree(messageBuffer); + } else { + result = "Unknown error"; + } + return result; } - LocalFree(messageBuffer); - } else { - result = "Unknown error"; - } - return result; -} } // namespace detail } // namespace fl diff --git a/flashlight/fl/common/WinUtility.h b/flashlight/fl/common/WinUtility.h index 4b29f56..0d9de45 100644 --- a/flashlight/fl/common/WinUtility.h +++ b/flashlight/fl/common/WinUtility.h @@ -20,13 +20,13 @@ namespace detail { * @return Wide string (UTF-16LE) * @throws std::runtime_error if conversion fails */ -std::wstring utf8ToWide(const std::string& utf8); + std::wstring utf8ToWide(const std::string& utf8); /** * Get a human-readable error message from the last Windows error code * @return Error message as UTF-8 string */ -std::string getWindowsErrorString(); + std::string getWindowsErrorString(); } // namespace detail } // namespace fl diff --git a/flashlight/fl/common/stacktrace/Backward.cpp 
b/flashlight/fl/common/stacktrace/Backward.cpp index 506c5d9..9e85ae1 100644 --- a/flashlight/fl/common/stacktrace/Backward.cpp +++ b/flashlight/fl/common/stacktrace/Backward.cpp @@ -4,13 +4,13 @@ // On GNU/Linux, you have few choices to get the most out of your stack trace. // // By default you get: -// - object filename -// - function name +// - object filename +// - function name // // In order to add: -// - source filename -// - line and column numbers -// - source code snippet (assuming the file is accessible) +// - source filename +// - line and column numbers +// - source code snippet (assuming the file is accessible) // Install one of the following libraries then uncomment one of the macro (or // better, add the detection of the lib and the macro definition in your build @@ -44,9 +44,9 @@ namespace fl::detail { void initBackward() { - // If not built with backward, this function is a noop + // If not built with backward, this function is a noop #if FL_USE_BACKWARD_CPP - static ::backward::SignalHandling sh; + static ::backward::SignalHandling sh; #endif } diff --git a/flashlight/fl/common/threadpool/ThreadPool.h b/flashlight/fl/common/threadpool/ThreadPool.h index b653d97..f2077fa 100644 --- a/flashlight/fl/common/threadpool/ThreadPool.h +++ b/flashlight/fl/common/threadpool/ThreadPool.h @@ -22,11 +22,11 @@ namespace fl { /** -* A simple C++11 Thread Pool implementation. -* Source - https://github.com/progschj/ThreadPool -* -* Basic usage: - \code + * A simple C++11 Thread Pool implementation. 
+ * Source - https://github.com/progschj/ThreadPool + * + * Basic usage: + \code // create thread pool with 4 worker threads ThreadPool pool(4); @@ -35,99 +35,108 @@ namespace fl { // get result from future std::cout << result.get() << std::endl; - \endcode -*/ + \endcode + */ class ThreadPool { - public: - /** - * the constructor just launches given amount of workers - * \param [in] threads number of threads - * \param [in] initFn initialization code (if any) that will be run on all the - * threads - */ - ThreadPool( - size_t threads, - const std::function& initFn = nullptr); - - /** - * add new work item to the pool - * \param [in] f function to be executed in threadpool - * \param [in] args varadic arguments for the function - */ - template - auto enqueue(F&& f, Args&&... args) - -> std::future::type>; - /// destructor joins all threads. - ~ThreadPool(); - - private: - // need to keep track of threads so we can join them - std::vector workers; - // the task queue - std::queue> tasks; - - // synchronization - std::mutex queue_mutex; - std::condition_variable condition; - bool stop; +public: + /** + * the constructor just launches given amount of workers + * \param [in] threads number of threads + * \param [in] initFn initialization code (if any) that will be run on all the + * threads + */ + ThreadPool( + size_t threads, + const std::function& initFn = nullptr + ); + + /** + * add new work item to the pool + * \param [in] f function to be executed in threadpool + * \param [in] args varadic arguments for the function + */ + template + auto enqueue(F&& f, Args&&... args) + -> std::future::type>; + /// destructor joins all threads. 
+ ~ThreadPool(); + +private: + // need to keep track of threads so we can join them + std::vector workers; + // the task queue + std::queue> tasks; + + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; }; // namespace flclassThreadPool inline ThreadPool::ThreadPool( size_t threads, - const std::function& initFn /* = nullptr */) - : stop(false) { - for (size_t id = 0; id < threads; ++id) - workers.emplace_back([this, initFn, id] { - if (initFn) { - initFn(id); - } - for (;;) { - std::function task; - - { - std::unique_lock lock(this->queue_mutex); - this->condition.wait( - lock, [this] { return this->stop || !this->tasks.empty(); }); - if (this->stop && this->tasks.empty()) - return; - task = std::move(this->tasks.front()); - this->tasks.pop(); - } - - task(); - } - }); + const std::function& initFn /* = nullptr */ +) : stop(false) { + for(size_t id = 0; id < threads; ++id) { + workers.emplace_back( + [this, initFn, id] { + if(initFn) { + initFn(id); + } + for(;;) { + std::function task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait( + lock, + [this] { return this->stop || !this->tasks.empty(); }); + if(this->stop && this->tasks.empty()) { + return; + } + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + } + ); + } } -template +template auto ThreadPool::enqueue(F&& f, Args&&... args) - -> std::future::type> { - using return_type = typename std::invoke_result::type; +-> std::future::type> { + using return_type = typename std::invoke_result::type; - auto task = std::make_shared>( - std::bind(std::forward(f), std::forward(args)...)); + auto task = std::make_shared>( + std::bind(std::forward(f), std::forward(args)...) 
+ ); - std::future res = task->get_future(); - { - std::unique_lock lock(queue_mutex); + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); - // don't allow enqueueing after stopping the pool - if (stop) - throw std::runtime_error("enqueue on stopped ThreadPool"); + // don't allow enqueueing after stopping the pool + if(stop) { + throw std::runtime_error("enqueue on stopped ThreadPool"); + } - tasks.emplace([task]() { (*task)(); }); - } - condition.notify_one(); - return res; + tasks.emplace([task]() { (*task)(); }); + } + condition.notify_one(); + return res; } inline ThreadPool::~ThreadPool() { - { - std::unique_lock lock(queue_mutex); - stop = true; - } - condition.notify_all(); - for (std::thread& worker : workers) - worker.join(); + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for(std::thread& worker : workers) { + worker.join(); + } } } // namespace fl diff --git a/flashlight/fl/contrib/modules/AdaptiveEmbedding.cpp b/flashlight/fl/contrib/modules/AdaptiveEmbedding.cpp index ff12c8e..3438150 100644 --- a/flashlight/fl/contrib/modules/AdaptiveEmbedding.cpp +++ b/flashlight/fl/contrib/modules/AdaptiveEmbedding.cpp @@ -19,93 +19,108 @@ namespace fl { AdaptiveEmbedding::AdaptiveEmbedding( int embeddingDim, std::vector cutoff, - float divValue /*= 4 */) - : embeddingDim_(embeddingDim), cutoff_(cutoff), divValue_(divValue) { - if (cutoff_.empty()) { - throw std::invalid_argument("Invalid cutoff for AdaptiveEmbedding"); - } - double stdv = std::sqrt(1.0 / static_cast(embeddingDim_)); - // to be in agreement with the adaptive softmax to simplify - // tied version of adaptive input and softmax - auto headEmbedding = fl::normal(cutoff_[0], embeddingDim_, stdv, 0); - params_.push_back(headEmbedding); - auto head = fl::glorotUniform( - {embeddingDim_, embeddingDim_}, embeddingDim_, embeddingDim_); - params_.push_back(head); - - int denominator = 1; - for (int tailIdx = 1; tailIdx < cutoff_.size(); 
tailIdx++) { - denominator *= divValue_; - int tailEmbeddingDim = embeddingDim_ / denominator; - double stdvTail = std::sqrt(1.0 / static_cast(tailEmbeddingDim)); + float divValue /*= 4 */ +) : embeddingDim_(embeddingDim), + cutoff_(cutoff), + divValue_(divValue) { + if(cutoff_.empty()) { + throw std::invalid_argument("Invalid cutoff for AdaptiveEmbedding"); + } + double stdv = std::sqrt(1.0 / static_cast(embeddingDim_)); // to be in agreement with the adaptive softmax to simplify // tied version of adaptive input and softmax - auto tailEmbedding = fl::normal( - cutoff_[tailIdx] - cutoff_[tailIdx - 1], tailEmbeddingDim, stdvTail, 0); - params_.push_back(tailEmbedding); - auto tail = fl::glorotUniform( - {embeddingDim_, tailEmbeddingDim}, tailEmbeddingDim, embeddingDim_); - params_.push_back(tail); - } + auto headEmbedding = fl::normal(cutoff_[0], embeddingDim_, stdv, 0); + params_.push_back(headEmbedding); + auto head = fl::glorotUniform( + {embeddingDim_, embeddingDim_}, + embeddingDim_, + embeddingDim_ + ); + params_.push_back(head); + + int denominator = 1; + for(int tailIdx = 1; tailIdx < cutoff_.size(); tailIdx++) { + denominator *= divValue_; + int tailEmbeddingDim = embeddingDim_ / denominator; + double stdvTail = std::sqrt(1.0 / static_cast(tailEmbeddingDim)); + // to be in agreement with the adaptive softmax to simplify + // tied version of adaptive input and softmax + auto tailEmbedding = fl::normal( + cutoff_[tailIdx] - cutoff_[tailIdx - 1], + tailEmbeddingDim, + stdvTail, + 0 + ); + params_.push_back(tailEmbedding); + auto tail = fl::glorotUniform( + {embeddingDim_, tailEmbeddingDim}, + tailEmbeddingDim, + embeddingDim_ + ); + params_.push_back(tail); + } } Variable AdaptiveEmbedding::forward(const Variable& input) { - if (input.ndim() != 2) { - throw std::invalid_argument( - "AdaptiveEmbedding::forward - input must " - "have 2 dimensions - expect T x B"); - } + if(input.ndim() != 2) { + throw std::invalid_argument( + "AdaptiveEmbedding::forward - input 
must " + "have 2 dimensions - expect T x B" + ); + } - auto flatInput = flat(input); - std::vector indices; - std::vector embeddings; + auto flatInput = flat(input); + std::vector indices; + std::vector embeddings; - Tensor headMask = flatInput.tensor() < cutoff_[0]; - if (fl::sum(headMask).scalar() > 0) { - auto headEmbedding = - embedding(flatInput(headMask), reorder(params_[0], {1, 0})); - headEmbedding = matmul(params_[1], headEmbedding); - indices.emplace_back(fl::nonzero(headMask), false); - embeddings.push_back(headEmbedding); - } + Tensor headMask = flatInput.tensor() < cutoff_[0]; + if(fl::sum(headMask).scalar() > 0) { + auto headEmbedding = + embedding(flatInput(headMask), reorder(params_[0], {1, 0})); + headEmbedding = matmul(params_[1], headEmbedding); + indices.emplace_back(fl::nonzero(headMask), false); + embeddings.push_back(headEmbedding); + } - for (int tailIdx = 1; tailIdx < cutoff_.size(); tailIdx++) { - Tensor tailMask = flatInput.tensor() < cutoff_[tailIdx] && - flatInput.tensor() >= cutoff_[tailIdx - 1]; - if (fl::any(tailMask).asScalar()) { - auto tailEmbedding = embedding( - flatInput(tailMask) - cutoff_[tailIdx - 1], - reorder(params_[tailIdx * 2], {1, 0})); - tailEmbedding = matmul(params_[tailIdx * 2 + 1], tailEmbedding); - indices.emplace_back(fl::nonzero(tailMask), false); - embeddings.push_back(tailEmbedding); + for(int tailIdx = 1; tailIdx < cutoff_.size(); tailIdx++) { + Tensor tailMask = flatInput.tensor() < cutoff_[tailIdx] + && flatInput.tensor() >= cutoff_[tailIdx - 1]; + if(fl::any(tailMask).asScalar()) { + auto tailEmbedding = embedding( + flatInput(tailMask) - cutoff_[tailIdx - 1], + reorder(params_[tailIdx * 2], {1, 0}) + ); + tailEmbedding = matmul(params_[tailIdx * 2 + 1], tailEmbedding); + indices.emplace_back(fl::nonzero(tailMask), false); + embeddings.push_back(tailEmbedding); + } + } + if(embeddings.empty()) { + throw std::invalid_argument( + "Invalid input, no positions in the AdaptiveEmbedding layer" + ); } - } - if 
(embeddings.empty()) { - throw std::invalid_argument( - "Invalid input, no positions in the AdaptiveEmbedding layer"); - } - Shape outShape({embeddingDim_, input.dim(0), input.dim(1)}); - auto result = fl::concatenate(embeddings, 1); - auto resultIndices = fl::concatenate(indices, 0); - Tensor tmpIndices = fl::argsort(resultIndices.tensor(), 0); - return moddims(result(fl::span, tmpIndices), outShape); + Shape outShape({embeddingDim_, input.dim(0), input.dim(1)}); + auto result = fl::concatenate(embeddings, 1); + auto resultIndices = fl::concatenate(indices, 0); + Tensor tmpIndices = fl::argsort(resultIndices.tensor(), 0); + return moddims(result(fl::span, tmpIndices), outShape); } std::unique_ptr AdaptiveEmbedding::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string AdaptiveEmbedding::prettyString() const { - std::ostringstream ss; - ss << "AdaptiveEmbedding (dim: " << embeddingDim_ << "), (cutoff: "; - for (int i = 0; i < cutoff_.size() - 1; i++) { - ss << cutoff_[i] << ", "; - } - ss << cutoff_[cutoff_.size() - 1] << "), " - << "(divValue: " << divValue_ << ")"; - return ss.str(); + std::ostringstream ss; + ss << "AdaptiveEmbedding (dim: " << embeddingDim_ << "), (cutoff: "; + for(int i = 0; i < cutoff_.size() - 1; i++) { + ss << cutoff_[i] << ", "; + } + ss << cutoff_[cutoff_.size() - 1] << "), " + << "(divValue: " << divValue_ << ")"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/contrib/modules/AdaptiveEmbedding.h b/flashlight/fl/contrib/modules/AdaptiveEmbedding.h index 50d26fc..0483df6 100644 --- a/flashlight/fl/contrib/modules/AdaptiveEmbedding.h +++ b/flashlight/fl/contrib/modules/AdaptiveEmbedding.h @@ -20,41 +20,42 @@ namespace fl { * [`embedding_dim`, \f$B_1\f$, \f$B_2\f$ (optional), \f$B_3\f$ (optional)]. 
*/ class FL_API AdaptiveEmbedding : public UnaryModule { - private: - AdaptiveEmbedding() = default; // Intentionally private - int embeddingDim_; - std::vector cutoff_; - float divValue_; - - FL_SAVE_LOAD_WITH_BASE(UnaryModule, embeddingDim_, cutoff_, divValue_) - - public: - /** - * Constructs an Embedding module. - * - * @param[in] embedding_dim the size of each embedding vector - * @param[in] cutoff a sequence of integers sorted in ascending order, which - * determines the relative size of each bucket, and how many partitions are - * created. For example, given cutoffs `{5, 50, 100}`, the head bucket will - * contain `5` targets, the - * first tail bucket will contain `50 - 5 = 45` targets (subtracting the size - * of the head bucket), the second tail bucket will contain `100 - 50 = 50` - * targets (subtracting the size of the first tail bucket). Cutoffs must be - * specified to accommodate all targets: any remaining targets are not - * assigned to an 'overflow' bucket. - * @param[in] divValue is the scaling factor for tail groups dimention - * reduction (see paper https://arxiv.org/pdf/1809.10853.pdf for details). - */ - explicit AdaptiveEmbedding( - int embeddingDim, - std::vector cutoff, - float divValue = 4); - - Variable forward(const Variable& input) override; - - std::unique_ptr clone() const override; - - std::string prettyString() const override; +private: + AdaptiveEmbedding() = default; // Intentionally private + int embeddingDim_; + std::vector cutoff_; + float divValue_; + + FL_SAVE_LOAD_WITH_BASE(UnaryModule, embeddingDim_, cutoff_, divValue_) + +public: + /** + * Constructs an Embedding module. + * + * @param[in] embedding_dim the size of each embedding vector + * @param[in] cutoff a sequence of integers sorted in ascending order, which + * determines the relative size of each bucket, and how many partitions are + * created. 
For example, given cutoffs `{5, 50, 100}`, the head bucket will + * contain `5` targets, the + * first tail bucket will contain `50 - 5 = 45` targets (subtracting the size + * of the head bucket), the second tail bucket will contain `100 - 50 = 50` + * targets (subtracting the size of the first tail bucket). Cutoffs must be + * specified to accommodate all targets: any remaining targets are not + * assigned to an 'overflow' bucket. + * @param[in] divValue is the scaling factor for tail groups dimention + * reduction (see paper https://arxiv.org/pdf/1809.10853.pdf for details). + */ + explicit AdaptiveEmbedding( + int embeddingDim, + std::vector cutoff, + float divValue = 4 + ); + + Variable forward(const Variable& input) override; + + std::unique_ptr clone() const override; + + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/contrib/modules/AsymmetricConv1D.cpp b/flashlight/fl/contrib/modules/AsymmetricConv1D.cpp index a39cb98..ebd70a8 100644 --- a/flashlight/fl/contrib/modules/AsymmetricConv1D.cpp +++ b/flashlight/fl/contrib/modules/AsymmetricConv1D.cpp @@ -14,14 +14,16 @@ namespace fl { void AsymmetricConv1D::checkParams() { - if (xPad_ != static_cast(PaddingMode::SAME) && xPad_ != 0) { - throw std::invalid_argument( - "AsymmetricConv1D: invalid xPad_, now supports only '0' or 'SAME' "); - } - if (futurePart_ < 0 || futurePart_ > 1) { - throw std::invalid_argument( - "AsymmetricConv1D: invalid futurePart_, should be in [0, 1]"); - } + if(xPad_ != static_cast(PaddingMode::SAME) && xPad_ != 0) { + throw std::invalid_argument( + "AsymmetricConv1D: invalid xPad_, now supports only '0' or 'SAME' " + ); + } + if(futurePart_ < 0 || futurePart_ > 1) { + throw std::invalid_argument( + "AsymmetricConv1D: invalid futurePart_, should be in [0, 1]" + ); + } } AsymmetricConv1D::AsymmetricConv1D( @@ -33,10 +35,10 @@ AsymmetricConv1D::AsymmetricConv1D( float futurePart /* 0.5 */, int dx /* 1 */, bool bias /* true */, - int groups /* 1 
*/) - : Conv2D(nIn, nOut, wx, 1, sx, 1, px, 0, dx, 1, bias, groups), - futurePart_(futurePart) { - checkParams(); + int groups /* 1 */ +) : Conv2D(nIn, nOut, wx, 1, sx, 1, px, 0, dx, 1, bias, groups), + futurePart_(futurePart) { + checkParams(); } AsymmetricConv1D::AsymmetricConv1D( @@ -45,9 +47,10 @@ AsymmetricConv1D::AsymmetricConv1D( fl::detail::IntOrPadMode px /*= 0 */, float futurePart /*= 0.5 */, int dx /*= 1 */, - int groups /*= 1 */) - : Conv2D(w, sx, 1, px, 0, dx, 1, groups), futurePart_(futurePart) { - checkParams(); + int groups /*= 1 */ +) : Conv2D(w, sx, 1, px, 0, dx, 1, groups), + futurePart_(futurePart) { + checkParams(); } AsymmetricConv1D::AsymmetricConv1D( @@ -57,61 +60,64 @@ AsymmetricConv1D::AsymmetricConv1D( fl::detail::IntOrPadMode px /*= 0 */, float futurePart /*= 0.5 */, int dx /*= 1 */, - int groups /*= 1 */) - : Conv2D(w, b, sx, 1, px, 0, dx, 1, groups), futurePart_(futurePart) { - checkParams(); + int groups /*= 1 */ +) : Conv2D(w, b, sx, 1, px, 0, dx, 1, groups), + futurePart_(futurePart) { + checkParams(); } Variable AsymmetricConv1D::forward(const Variable& input) { - auto px = - fl::derivePadding(input.dim(0), xFilter_, xStride_, xPad_, xDilation_); - if (!(px >= 0)) { - throw std::invalid_argument("invalid padding for AsymmetricConv1D"); - } - Variable output; - int cutPx = std::abs(2 * (0.5 - futurePart_)) * px; - int asymmetryPx = px + cutPx; - if (bias_) { - output = conv2d( - input, - params_[0], - params_[1], - xStride_, - yStride_, - asymmetryPx, - 0, - xDilation_, - yDilation_, - groups_); - } else { - output = conv2d( - input, - params_[0], - xStride_, - yStride_, - asymmetryPx, - 0, - xDilation_, - yDilation_, - groups_); - } - if (futurePart_ < 0.5) { - output = output(fl::range(0, output.dim(0) - 2 * cutPx)); - } else if (futurePart_ > 0.5) { - output = output(fl::range(2 * cutPx, output.dim(0))); - } - return output; + auto px = + fl::derivePadding(input.dim(0), xFilter_, xStride_, xPad_, xDilation_); + if(!(px >= 0)) { + 
throw std::invalid_argument("invalid padding for AsymmetricConv1D"); + } + Variable output; + int cutPx = std::abs(2 * (0.5 - futurePart_)) * px; + int asymmetryPx = px + cutPx; + if(bias_) { + output = conv2d( + input, + params_[0], + params_[1], + xStride_, + yStride_, + asymmetryPx, + 0, + xDilation_, + yDilation_, + groups_ + ); + } else { + output = conv2d( + input, + params_[0], + xStride_, + yStride_, + asymmetryPx, + 0, + xDilation_, + yDilation_, + groups_ + ); + } + if(futurePart_ < 0.5) { + output = output(fl::range(0, output.dim(0) - 2 * cutPx)); + } else if(futurePart_ > 0.5) { + output = output(fl::range(2 * cutPx, output.dim(0))); + } + return output; } std::unique_ptr AsymmetricConv1D::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string AsymmetricConv1D::prettyString() const { - std::ostringstream ss; - ss << "AsymmetricConv1D"; - ss << " (" << Conv2D::prettyString() << ")"; - return ss.str(); + std::ostringstream ss; + ss << "AsymmetricConv1D"; + ss << " (" << Conv2D::prettyString() << ")"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/contrib/modules/AsymmetricConv1D.h b/flashlight/fl/contrib/modules/AsymmetricConv1D.h index 6e13e9e..3c1038c 100644 --- a/flashlight/fl/contrib/modules/AsymmetricConv1D.h +++ b/flashlight/fl/contrib/modules/AsymmetricConv1D.h @@ -21,47 +21,50 @@ namespace fl { * Note: currently only '0' and SAME padding are supported. 
*/ class FL_API AsymmetricConv1D : public fl::Conv2D { - public: - AsymmetricConv1D( - int nIn, - int nOut, - int wx, - int sx = 1, - fl::detail::IntOrPadMode px = 0, - float futurePart = 0.5, - int dx = 1, - bool bias = true, - int groups = 1); +public: + AsymmetricConv1D( + int nIn, + int nOut, + int wx, + int sx = 1, + fl::detail::IntOrPadMode px = 0, + float futurePart = 0.5, + int dx = 1, + bool bias = true, + int groups = 1 + ); - explicit AsymmetricConv1D( - const fl::Variable& w, - int sx = 1, - fl::detail::IntOrPadMode px = 0, - float futurePart = 0.5, - int dx = 1, - int groups = 1); + explicit AsymmetricConv1D( + const fl::Variable& w, + int sx = 1, + fl::detail::IntOrPadMode px = 0, + float futurePart = 0.5, + int dx = 1, + int groups = 1 + ); - AsymmetricConv1D( - const fl::Variable& w, - const fl::Variable& b, - int sx = 1, - fl::detail::IntOrPadMode px = 0, - float futurePart = 0.5, - int dx = 1, - int groups = 1); + AsymmetricConv1D( + const fl::Variable& w, + const fl::Variable& b, + int sx = 1, + fl::detail::IntOrPadMode px = 0, + float futurePart = 0.5, + int dx = 1, + int groups = 1 + ); - fl::Variable forward(const fl::Variable& input) override; + fl::Variable forward(const fl::Variable& input) override; - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::string prettyString() const override; + std::string prettyString() const override; - private: - FL_SAVE_LOAD_WITH_BASE(fl::Conv2D, futurePart_) - float futurePart_; - void checkParams(); +private: + FL_SAVE_LOAD_WITH_BASE(fl::Conv2D, futurePart_) + float futurePart_; + void checkParams(); - AsymmetricConv1D() = default; + AsymmetricConv1D() = default; }; } // namespace fl diff --git a/flashlight/fl/contrib/modules/Conformer.cpp b/flashlight/fl/contrib/modules/Conformer.cpp index f8c6260..0629d8a 100644 --- a/flashlight/fl/contrib/modules/Conformer.cpp +++ b/flashlight/fl/contrib/modules/Conformer.cpp @@ -25,261 +25,288 @@ Conformer::Conformer( int32_t 
posEmbContextSize, int32_t convKernelSize, float pDropout, - float pLayerDropout /* = 0. */) - : nHeads_(nHeads), - posEmbContextSize_(posEmbContextSize), - convKernelSize_(convKernelSize), - pDropout_(pDropout), - pLayerDropout_(pLayerDropout), - w11_(std::make_shared(conformerInitLinear(modelDim, mlpDim))), - w12_(std::make_shared(conformerInitLinear(mlpDim, modelDim))), - w21_(std::make_shared(conformerInitLinear(modelDim, mlpDim))), - w22_(std::make_shared(conformerInitLinear(mlpDim, modelDim))), - wq_(std::make_shared( - conformerInitLinear(modelDim, headDim * nHeads))), - wk_(std::make_shared( - conformerInitLinear(modelDim, headDim * nHeads))), - wv_(std::make_shared( - conformerInitLinear(modelDim, headDim * nHeads))), - wf_(std::make_shared( - conformerInitLinear(headDim * nHeads, modelDim))), - conv1_(std::make_shared( - conformerInitLinear(modelDim, modelDim * 2))), - conv2_(std::make_shared(conformerInitLinear(modelDim, modelDim))), - norm1_(std::make_shared( - std::vector({0}), - 1e-5, - true, - modelDim)), - norm2_(std::make_shared( - std::vector({0}), - 1e-5, - true, - modelDim)), - normMhsa_(std::make_shared( - std::vector({0}), - 1e-5, - true, - modelDim)), - normConv1_(std::make_shared( - std::vector({0}), - 1e-5, - true, - modelDim)), - normConv2_(std::make_shared( - std::vector({0}), - 1e-5, - true, - modelDim)), - norm3_(std::make_shared( - std::vector({0}), - 1e-5, - true, - modelDim)), - convDepthWise_(std::make_shared( - modelDim, - modelDim, - convKernelSize, - 1, - 1, - 1, - fl::PaddingMode::SAME, - 0, - 1, - 1, - true, - modelDim)) { - if (posEmbContextSize_ > 0) { - params_.push_back(uniform(2 * posEmbContextSize_ - 1, headDim, -0.1, 0.1)); - } - createLayers(); + float pLayerDropout /* = 0. 
*/ +) : nHeads_(nHeads), + posEmbContextSize_(posEmbContextSize), + convKernelSize_(convKernelSize), + pDropout_(pDropout), + pLayerDropout_(pLayerDropout), + w11_(std::make_shared(conformerInitLinear(modelDim, mlpDim))), + w12_(std::make_shared(conformerInitLinear(mlpDim, modelDim))), + w21_(std::make_shared(conformerInitLinear(modelDim, mlpDim))), + w22_(std::make_shared(conformerInitLinear(mlpDim, modelDim))), + wq_(std::make_shared( + conformerInitLinear(modelDim, headDim * nHeads) + )), + wk_(std::make_shared( + conformerInitLinear(modelDim, headDim * nHeads) + )), + wv_(std::make_shared( + conformerInitLinear(modelDim, headDim * nHeads) + )), + wf_(std::make_shared( + conformerInitLinear(headDim * nHeads, modelDim) + )), + conv1_(std::make_shared( + conformerInitLinear(modelDim, modelDim * 2) + )), + conv2_(std::make_shared(conformerInitLinear(modelDim, modelDim))), + norm1_(std::make_shared( + std::vector({0}), + 1e-5, + true, + modelDim + )), + norm2_(std::make_shared( + std::vector({0}), + 1e-5, + true, + modelDim + )), + normMhsa_(std::make_shared( + std::vector({0}), + 1e-5, + true, + modelDim + )), + normConv1_(std::make_shared( + std::vector({0}), + 1e-5, + true, + modelDim + )), + normConv2_(std::make_shared( + std::vector({0}), + 1e-5, + true, + modelDim + )), + norm3_(std::make_shared( + std::vector({0}), + 1e-5, + true, + modelDim + )), + convDepthWise_(std::make_shared( + modelDim, + modelDim, + convKernelSize, + 1, + 1, + 1, + fl::PaddingMode::SAME, + 0, + 1, + 1, + true, + modelDim + )) { + if(posEmbContextSize_ > 0) { + params_.push_back(uniform(2 * posEmbContextSize_ - 1, headDim, -0.1, 0.1)); + } + createLayers(); } Conformer::Conformer(const Conformer& other) { - copy(other); - createLayers(); + copy(other); + createLayers(); } Conformer& Conformer::operator=(const Conformer& other) { - clear(); - copy(other); - createLayers(); - return *this; + clear(); + copy(other); + createLayers(); + return *this; } void Conformer::copy(const Conformer& 
other) { - train_ = other.train_; - nHeads_ = other.nHeads_; - posEmbContextSize_ = other.posEmbContextSize_; - convKernelSize_ = other.convKernelSize_; - pDropout_ = other.pDropout_; - pLayerDropout_ = other.pLayerDropout_; - w11_ = std::make_shared(*other.w11_); - w12_ = std::make_shared(*other.w12_); - w21_ = std::make_shared(*other.w21_); - w22_ = std::make_shared(*other.w22_); - wq_ = std::make_shared(*other.wq_); - wk_ = std::make_shared(*other.wk_); - wv_ = std::make_shared(*other.wv_); - wf_ = std::make_shared(*other.wf_); - conv1_ = std::make_shared(*other.conv1_); - conv2_ = std::make_shared(*other.conv2_); - norm1_ = std::make_shared(*other.norm1_); - norm2_ = std::make_shared(*other.norm2_); - normMhsa_ = std::make_shared(*other.normMhsa_); - normConv1_ = std::make_shared(*other.normConv1_); - normConv2_ = std::make_shared(*other.normConv2_); - norm3_ = std::make_shared(*other.norm3_); - convDepthWise_ = std::make_shared(*other.convDepthWise_); - if (posEmbContextSize_ > 0) { - const auto& p = other.param(0); - params_.emplace_back(p.copy()); - } + train_ = other.train_; + nHeads_ = other.nHeads_; + posEmbContextSize_ = other.posEmbContextSize_; + convKernelSize_ = other.convKernelSize_; + pDropout_ = other.pDropout_; + pLayerDropout_ = other.pLayerDropout_; + w11_ = std::make_shared(*other.w11_); + w12_ = std::make_shared(*other.w12_); + w21_ = std::make_shared(*other.w21_); + w22_ = std::make_shared(*other.w22_); + wq_ = std::make_shared(*other.wq_); + wk_ = std::make_shared(*other.wk_); + wv_ = std::make_shared(*other.wv_); + wf_ = std::make_shared(*other.wf_); + conv1_ = std::make_shared(*other.conv1_); + conv2_ = std::make_shared(*other.conv2_); + norm1_ = std::make_shared(*other.norm1_); + norm2_ = std::make_shared(*other.norm2_); + normMhsa_ = std::make_shared(*other.normMhsa_); + normConv1_ = std::make_shared(*other.normConv1_); + normConv2_ = std::make_shared(*other.normConv2_); + norm3_ = std::make_shared(*other.norm3_); + convDepthWise_ = 
std::make_shared(*other.convDepthWise_); + if(posEmbContextSize_ > 0) { + const auto& p = other.param(0); + params_.emplace_back(p.copy()); + } } void Conformer::createLayers() { - // first feed-forward module - add(w11_); - add(w12_); - add(norm1_); - // second feed-forward module - add(w21_); - add(w22_); - add(norm2_); - // multihead attention module - add(wq_); - add(wk_); - add(wv_); - add(wf_); - add(normMhsa_); - // conv module - add(conv1_); - add(conv2_); - add(convDepthWise_); - add(normConv1_); - add(normConv2_); - // final layer norm of conformer block - add(norm3_); + // first feed-forward module + add(w11_); + add(w12_); + add(norm1_); + // second feed-forward module + add(w21_); + add(w22_); + add(norm2_); + // multihead attention module + add(wq_); + add(wk_); + add(wv_); + add(wf_); + add(normMhsa_); + // conv module + add(conv1_); + add(conv2_); + add(convDepthWise_); + add(normConv1_); + add(normConv2_); + // final layer norm of conformer block + add(norm3_); } Variable Conformer::conformerInitLinear(int32_t inDim, int32_t outDim) { - float std = std::sqrt(1.0 / float(inDim)); - return fl::uniform(outDim, inDim, -std, std); + float std = std::sqrt(1.0 / float(inDim)); + return fl::uniform(outDim, inDim, -std, std); } Variable Conformer::mhsa(const Variable& input, const Variable& inputPadMask) { - float pDropout = train_ ? pDropout_ : 0.0; - int bsz = input.dim(2); + float pDropout = train_ ? 
pDropout_ : 0.0; + int bsz = input.dim(2); - auto normedInput = (*normMhsa_)(input); - auto q = transpose((*wq_)(normedInput), {1, 0, 2}); - auto k = transpose((*wk_)(normedInput), {1, 0, 2}); - auto v = transpose((*wv_)(normedInput), {1, 0, 2}); + auto normedInput = (*normMhsa_)(input); + auto q = transpose((*wq_)(normedInput), {1, 0, 2}); + auto k = transpose((*wk_)(normedInput), {1, 0, 2}); + auto v = transpose((*wv_)(normedInput), {1, 0, 2}); - Variable mask, posEmb; - if (posEmbContextSize_ > 0) { - posEmb = tile(params_[0].astype(input.type()), {1, 1, nHeads_ * bsz}); - } + Variable mask, posEmb; + if(posEmbContextSize_ > 0) { + posEmb = tile(params_[0].astype(input.type()), {1, 1, nHeads_ * bsz}); + } - fl::Variable padMask; - // TODO{fl::Tensor}{resize} - emulate the ArrayFire resize operation for - // transformer pad mask - if (!inputPadMask.isEmpty()) { - auto padMaskArr = inputPadMask.tensor(); - Shape newMaskShape = {input.dim(1), input.dim(2)}; - if (padMaskArr.elements() != newMaskShape.elements()) { - throw std::runtime_error( - "Transformer::selfAttention - pad mask requires resize. " - "This behavior will be fixed in a future release "); + fl::Variable padMask; + // TODO{fl::Tensor}{resize} - emulate the ArrayFire resize operation for + // transformer pad mask + if(!inputPadMask.isEmpty()) { + auto padMaskArr = inputPadMask.tensor(); + Shape newMaskShape = {input.dim(1), input.dim(2)}; + if(padMaskArr.elements() != newMaskShape.elements()) { + throw std::runtime_error( + "Transformer::selfAttention - pad mask requires resize. 
" + "This behavior will be fixed in a future release " + ); + } + padMaskArr = fl::reshape(padMaskArr, newMaskShape); + padMask = fl::Variable(fl::log(padMaskArr), false); } - padMaskArr = fl::reshape(padMaskArr, newMaskShape); - padMask = fl::Variable(fl::log(padMaskArr), false); - } - auto result = - multiheadAttention(q, k, v, posEmb, mask, padMask, nHeads_, pDropout, 0); - result = (*wf_)(transpose(result, {1, 0, 2})); - result = dropout(result, pDropout); - return result; + auto result = + multiheadAttention(q, k, v, posEmb, mask, padMask, nHeads_, pDropout, 0); + result = (*wf_)(transpose(result, {1, 0, 2})); + result = dropout(result, pDropout); + return result; } Variable Conformer::conv(const Variable& _input) { - // Make sure the input has 4 dims for depthwise conv - Shape s = _input.shape(); - Variable input = moddims(_input, {s[0], s[1], s[2], 1}); + // Make sure the input has 4 dims for depthwise conv + Shape s = _input.shape(); + Variable input = moddims(_input, {s[0], s[1], s[2], 1}); - float pDropout = train_ ? pDropout_ : 0.0; - // input C x T x B x 1 - // apply first pointwise conv - auto result = gatedlinearunit( - (*conv1_)(((*normConv1_)(input)).astype(input.type())), 0); - result = reorder(result, {1, 3, 0, 2}); - // T x 1 x C x B - // apply depthwise separable convolutions - result = (*convDepthWise_)(result); - result = reorder(result, {2, 0, 3, 1}); - // C x T x B x 1 - result = fl::swish(((*normConv2_)(result)).astype(input.type()), 1.); - // apply second pointwise conv - result = dropout((*conv2_)(result), pDropout); - return moddims(result, _input.shape()); + float pDropout = train_ ? 
pDropout_ : 0.0; + // input C x T x B x 1 + // apply first pointwise conv + auto result = gatedlinearunit( + (*conv1_)(((*normConv1_)(input)).astype(input.type())), + 0 + ); + result = reorder(result, {1, 3, 0, 2}); + // T x 1 x C x B + // apply depthwise separable convolutions + result = (*convDepthWise_)(result); + result = reorder(result, {2, 0, 3, 1}); + // C x T x B x 1 + result = fl::swish(((*normConv2_)(result)).astype(input.type()), 1.); + // apply second pointwise conv + result = dropout((*conv2_)(result), pDropout); + return moddims(result, _input.shape()); } std::vector Conformer::forward(const std::vector& input) { - if (input.size() != 2) { - throw std::invalid_argument( - "Invalid inputs for conformer block: there should be input " - "and paddding mask (can be empty Variable)"); - } + if(input.size() != 2) { + throw std::invalid_argument( + "Invalid inputs for conformer block: there should be input " + "and paddding mask (can be empty Variable)" + ); + } - auto x = input[0]; + auto x = input[0]; - if (x.ndim() != 3) { - throw std::invalid_argument( - "Conformer::forward - input should be of 3 dimensions " - "expects an input of size C x T x B - see documentation."); - } + if(x.ndim() != 3) { + throw std::invalid_argument( + "Conformer::forward - input should be of 3 dimensions " + "expects an input of size C x T x B - see documentation." + ); + } - float pDropout = train_ ? 
pDropout_ : 0.0; - float f = 1.0; - if (train_ && (fl::rand({1}).scalar() < pLayerDropout_)) { - f = 0.0; - } - // apply first feed-forward module - auto ffn1 = dropout( - (*w12_)(dropout( - fl::swish((*w11_)(((*norm1_)(x)).astype(x.type())), 1.), pDropout)), - pDropout); - x = x + f * 0.5 * ffn1; - // apply multihead attention module - x = x + f * mhsa(x, input[1]); - // apply conv module - x = x + f * conv(x); - // apply second feed-forward module - auto ffn2 = dropout( - (*w22_)(dropout( - fl::swish((*w21_)(((*norm2_)(x)).astype(x.type())), 1.), pDropout)), - pDropout); - x = x + f * 0.5 * ffn2; - x = ((*norm3_)(x)).astype(x.type()); - return {x}; + float pDropout = train_ ? pDropout_ : 0.0; + float f = 1.0; + if(train_ && (fl::rand({1}).scalar() < pLayerDropout_)) { + f = 0.0; + } + // apply first feed-forward module + auto ffn1 = dropout( + (*w12_)( + dropout( + fl::swish((*w11_)(((*norm1_)(x)).astype(x.type())), 1.), + pDropout + ) + ), + pDropout + ); + x = x + f * 0.5 * ffn1; + // apply multihead attention module + x = x + f * mhsa(x, input[1]); + // apply conv module + x = x + f * conv(x); + // apply second feed-forward module + auto ffn2 = dropout( + (*w22_)( + dropout( + fl::swish((*w21_)(((*norm2_)(x)).astype(x.type())), 1.), + pDropout + ) + ), + pDropout + ); + x = x + f * 0.5 * ffn2; + x = ((*norm3_)(x)).astype(x.type()); + return {x}; } std::unique_ptr Conformer::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string Conformer::prettyString() const { - std::ostringstream ss; - ss << "Conformer " - << "(modelDim: " << params_[1].dim(1) << "), " - << "(mlpDim: " << params_[1].dim(0) << "), " - << "(nHeads: " << nHeads_ << "), " - << "(pDropout: " << pDropout_ << "), " - << "(pLayerDropout: " << pLayerDropout_ << "), " - << "(posEmbContextSize: " << posEmbContextSize_ << "), " - << "(convKernelSize: " << convKernelSize_ << ") "; - return ss.str(); + std::ostringstream ss; + ss << "Conformer " + << "(modelDim: 
" << params_[1].dim(1) << "), " + << "(mlpDim: " << params_[1].dim(0) << "), " + << "(nHeads: " << nHeads_ << "), " + << "(pDropout: " << pDropout_ << "), " + << "(pLayerDropout: " << pLayerDropout_ << "), " + << "(posEmbContextSize: " << posEmbContextSize_ << "), " + << "(convKernelSize: " << convKernelSize_ << ") "; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/contrib/modules/Conformer.h b/flashlight/fl/contrib/modules/Conformer.h index 1a6338a..d57c0c3 100644 --- a/flashlight/fl/contrib/modules/Conformer.h +++ b/flashlight/fl/contrib/modules/Conformer.h @@ -34,71 +34,73 @@ namespace fl { * @param pLayerdrop layer dropout probability */ class FL_API Conformer : public Container { - public: - explicit Conformer( - int32_t modelDim, - int32_t headDim, - int32_t mlpDim, - int32_t nHeads, - int32_t posEmbContextSize, - int32_t convKernelSize, - float pDropout, - float pLayerDropout = 0.); - Conformer(const Conformer& other); - Conformer(Conformer&& other) = default; +public: + explicit Conformer( + int32_t modelDim, + int32_t headDim, + int32_t mlpDim, + int32_t nHeads, + int32_t posEmbContextSize, + int32_t convKernelSize, + float pDropout, + float pLayerDropout = 0. 
+ ); + Conformer(const Conformer& other); + Conformer(Conformer&& other) = default; - Conformer& operator=(const Conformer& other); - Conformer& operator=(Conformer&& other) = default; + Conformer& operator=(const Conformer& other); + Conformer& operator=(Conformer&& other) = default; - std::vector forward(const std::vector& input) override; - std::unique_ptr clone() const override; - std::string prettyString() const override; + std::vector forward(const std::vector& input) override; + std::unique_ptr clone() const override; + std::string prettyString() const override; - private: - int32_t nHeads_; - int32_t posEmbContextSize_; - int32_t convKernelSize_; - double pDropout_; - float pLayerDropout_; +private: + int32_t nHeads_; + int32_t posEmbContextSize_; + int32_t convKernelSize_; + double pDropout_; + float pLayerDropout_; - std::shared_ptr w11_, w12_, w21_, w22_, wq_, wk_, wv_, wf_, conv1_, - conv2_; - std::shared_ptr norm1_, norm2_, normMhsa_, normConv1_, normConv2_, - norm3_; - std::shared_ptr convDepthWise_; + std::shared_ptr w11_, w12_, w21_, w22_, wq_, wk_, wv_, wf_, conv1_, + conv2_; + std::shared_ptr norm1_, norm2_, normMhsa_, normConv1_, normConv2_, + norm3_; + std::shared_ptr convDepthWise_; - void copy(const Conformer& other); - void createLayers(); - static Variable conformerInitLinear(int32_t inDim, int32_t outDim); - Variable mhsa(const Variable& input, const Variable& inputPadMask); - Variable conv(const Variable& input); + void copy(const Conformer& other); + void createLayers(); + static Variable conformerInitLinear(int32_t inDim, int32_t outDim); + Variable mhsa(const Variable& input, const Variable& inputPadMask); + Variable conv(const Variable& input); - Conformer() = default; + Conformer() = default; - FL_SAVE_LOAD_WITH_BASE( - Container, - w11_, - w12_, - w21_, - w22_, - wq_, - wk_, - wv_, - wf_, - normMhsa_, - norm1_, - norm2_, - norm3_, - normConv1_, - normConv2_, - conv1_, - conv2_, - convDepthWise_, - nHeads_, - pDropout_, - 
pLayerDropout_, - posEmbContextSize_, - convKernelSize_) + FL_SAVE_LOAD_WITH_BASE( + Container, + w11_, + w12_, + w21_, + w22_, + wq_, + wk_, + wv_, + wf_, + normMhsa_, + norm1_, + norm2_, + norm3_, + normConv1_, + normConv2_, + conv1_, + conv2_, + convDepthWise_, + nHeads_, + pDropout_, + pLayerDropout_, + posEmbContextSize_, + convKernelSize_ + ) }; } // namespace fl diff --git a/flashlight/fl/contrib/modules/PositionEmbedding.cpp b/flashlight/fl/contrib/modules/PositionEmbedding.cpp index e9e5b11..b94078a 100644 --- a/flashlight/fl/contrib/modules/PositionEmbedding.cpp +++ b/flashlight/fl/contrib/modules/PositionEmbedding.cpp @@ -19,54 +19,58 @@ namespace fl { PositionEmbedding::PositionEmbedding( int32_t layerDim, int32_t maxLen, - double dropout) - : dropout_(dropout) { - auto embeddings = uniform(layerDim, maxLen, -0.1, 0.1, fl::dtype::f32, true); - params_ = {embeddings}; + double dropout +) : dropout_(dropout) { + auto embeddings = uniform(layerDim, maxLen, -0.1, 0.1, fl::dtype::f32, true); + params_ = {embeddings}; } -PositionEmbedding::PositionEmbedding(const PositionEmbedding& other) - : Module(other.copyParams()), dropout_(other.dropout_) { - train_ = other.train_; +PositionEmbedding::PositionEmbedding(const PositionEmbedding& other) : Module(other.copyParams()), + dropout_(other.dropout_) { + train_ = other.train_; } PositionEmbedding& PositionEmbedding::operator=( - const PositionEmbedding& other) { - params_ = other.copyParams(); - train_ = other.train_; - dropout_ = other.dropout_; - return *this; + const PositionEmbedding& other +) { + params_ = other.copyParams(); + train_ = other.train_; + dropout_ = other.dropout_; + return *this; } std::vector PositionEmbedding::forward( - const std::vector& input) { - if (input[0].ndim() != 3) { - throw std::invalid_argument( - "PositionEmbedding::forward - expect a tensor with " - "3 dimensions - C x T x B"); - } + const std::vector& input +) { + if(input[0].ndim() != 3) { + throw std::invalid_argument( + 
"PositionEmbedding::forward - expect a tensor with " + "3 dimensions - C x T x B" + ); + } - int n = input[0].dim(1); - Variable posEmb = tileAs( - params_[0].astype(input[0].type())(fl::span, fl::range(0, n)), input[0]); - if (dropout_ > 0.0 && train_) { - return {input[0] + dropout(posEmb, dropout_)}; - } else { - return {input[0] + posEmb}; - } + int n = input[0].dim(1); + Variable posEmb = tileAs( + params_[0].astype(input[0].type())(fl::span, fl::range(0, n)), input[0]); + if(dropout_ > 0.0 && train_) { + return {input[0] + dropout(posEmb, dropout_)}; + } else { + return {input[0] + posEmb}; + } } std::vector PositionEmbedding::operator()( - const std::vector& input) { - return forward(input); + const std::vector& input +) { + return forward(input); } std::unique_ptr PositionEmbedding::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string PositionEmbedding::prettyString() const { - return "Position Embedding Layer"; + return "Position Embedding Layer"; } PositionEmbedding::PositionEmbedding() = default; diff --git a/flashlight/fl/contrib/modules/PositionEmbedding.h b/flashlight/fl/contrib/modules/PositionEmbedding.h index ca81ab0..8b9ed54 100644 --- a/flashlight/fl/contrib/modules/PositionEmbedding.h +++ b/flashlight/fl/contrib/modules/PositionEmbedding.h @@ -24,40 +24,40 @@ namespace fl { * */ class FL_API PositionEmbedding : public Module { - public: - PositionEmbedding(int32_t layerDim, int32_t maxLen, double dropout = 0); +public: + PositionEmbedding(int32_t layerDim, int32_t maxLen, double dropout = 0); - PositionEmbedding(const PositionEmbedding& other); + PositionEmbedding(const PositionEmbedding& other); - PositionEmbedding& operator=(const PositionEmbedding& other); + PositionEmbedding& operator=(const PositionEmbedding& other); - PositionEmbedding(PositionEmbedding&& other) = default; + PositionEmbedding(PositionEmbedding&& other) = default; - PositionEmbedding& operator=(PositionEmbedding&& other) = 
default; + PositionEmbedding& operator=(PositionEmbedding&& other) = default; - /** - * PositionEmbedding::forward(input) expects input[0] to be of - * dimensions C x T x B with C = layerDim and T <= maxLen. - * - * output[0] = input[0] + pos_emb, where pos_emb is a Tensor of dimensions - * C x T x B, and pos_emb = this.param_[0][:T], so pos_emb will be randomly - * initialized absolute position embeddings, that can be learned end-to-end. - * - */ - std::vector forward(const std::vector& input) override; + /** + * PositionEmbedding::forward(input) expects input[0] to be of + * dimensions C x T x B with C = layerDim and T <= maxLen. + * + * output[0] = input[0] + pos_emb, where pos_emb is a Tensor of dimensions + * C x T x B, and pos_emb = this.param_[0][:T], so pos_emb will be randomly + * initialized absolute position embeddings, that can be learned end-to-end. + * + */ + std::vector forward(const std::vector& input) override; - std::vector operator()(const std::vector& input); + std::vector operator()(const std::vector& input); - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::string prettyString() const override; + std::string prettyString() const override; - private: - FL_SAVE_LOAD_WITH_BASE(Module, dropout_) +private: + FL_SAVE_LOAD_WITH_BASE(Module, dropout_) - double dropout_; + double dropout_; - PositionEmbedding(); + PositionEmbedding(); }; } // namespace fl diff --git a/flashlight/fl/contrib/modules/RawWavSpecAugment.cpp b/flashlight/fl/contrib/modules/RawWavSpecAugment.cpp index 26e77b5..49b0666 100644 --- a/flashlight/fl/contrib/modules/RawWavSpecAugment.cpp +++ b/flashlight/fl/contrib/modules/RawWavSpecAugment.cpp @@ -28,182 +28,189 @@ RawWavSpecAugment::RawWavSpecAugment( int highFreqHz /* 8000 */, int sampleRate /* 16000 */, int maxKernelSize /* 20000 */, - MaskingStrategy mStrategy /* = MaskingStrategy::ZERO */) - : timeWarpW_(tWarpW), - freqMaskF_(fMaskF), - numFreqMask_(nFMask), - timeMaskT_(tMaskT), - 
timeMaskP_(tMaskP), - numTimeMask_(nTMask), - maskStrategy_(mStrategy), - rawWavNMels_(nMels), - rawWavLowFreqHz_(lowFreqHz), - rawWavHighFreqHz_(highFreqHz), - rawWavSampleRate_(sampleRate), - maxKernelSize_(maxKernelSize) { - if (numFreqMask_ > 0 && freqMaskF_ <= 0) { - throw std::invalid_argument("invalid arguments for frequency masking."); - } - if (numTimeMask_ > 0 && timeMaskT_ <= 0) { - throw std::invalid_argument("invalid arguments for time masking."); - } - if (numTimeMask_ > 0 && (timeMaskP_ <= 0 || timeMaskP_ > 1.0)) { - throw std::invalid_argument("invalid arguments for time masking."); - } - if (rawWavLowFreqHz_ < 0 || rawWavHighFreqHz_ < 0 || - rawWavLowFreqHz_ >= rawWavHighFreqHz_) { - throw std::invalid_argument( - "invalid arguments for raw Wav high and low frequencies."); - } - if (rawWavNMels_ <= 0) { - throw std::invalid_argument("invalid arguments for raw Wav nMels."); - } - precomputeFilters(); + MaskingStrategy mStrategy /* = MaskingStrategy::ZERO */ +) : timeWarpW_(tWarpW), + freqMaskF_(fMaskF), + numFreqMask_(nFMask), + timeMaskT_(tMaskT), + timeMaskP_(tMaskP), + numTimeMask_(nTMask), + maskStrategy_(mStrategy), + rawWavNMels_(nMels), + rawWavLowFreqHz_(lowFreqHz), + rawWavHighFreqHz_(highFreqHz), + rawWavSampleRate_(sampleRate), + maxKernelSize_(maxKernelSize) { + if(numFreqMask_ > 0 && freqMaskF_ <= 0) { + throw std::invalid_argument("invalid arguments for frequency masking."); + } + if(numTimeMask_ > 0 && timeMaskT_ <= 0) { + throw std::invalid_argument("invalid arguments for time masking."); + } + if(numTimeMask_ > 0 && (timeMaskP_ <= 0 || timeMaskP_ > 1.0)) { + throw std::invalid_argument("invalid arguments for time masking."); + } + if( + rawWavLowFreqHz_ < 0 || rawWavHighFreqHz_ < 0 + || rawWavLowFreqHz_ >= rawWavHighFreqHz_ + ) { + throw std::invalid_argument( + "invalid arguments for raw Wav high and low frequencies." 
+ ); + } + if(rawWavNMels_ <= 0) { + throw std::invalid_argument("invalid arguments for raw Wav nMels."); + } + precomputeFilters(); } void RawWavSpecAugment::precomputeFilters() { - if (!lowPassFilters_.empty()) { - return; - } - auto mel2hz = [](float mel) { - return 700.0 * (std::pow(10, (mel / 2595.0)) - 1.0); - }; - auto hz2mel = [](float hz) { return 2595.0 * std::log10(1.0 + hz / 700.0); }; - float minMel = hz2mel(rawWavLowFreqHz_), maxMel = hz2mel(rawWavHighFreqHz_); - // nMels intervals and nMels + 1 points - float delta = (maxMel - minMel) / rawWavNMels_; - float currentMel = minMel; - // set transition band as half of lowest bin frequency size (left bin) - // for lowest frequency set it to half of the right bin - // cutoff frequency and transmision band are stored from 0 to 0.5 of sampling - // rate - std::vector transBandKhz(rawWavNMels_ + 1); - for (int index = 0; index <= rawWavNMels_; index++) { - cutoff_.push_back(mel2hz(currentMel) / rawWavSampleRate_); - currentMel += delta; - if (index > 0) { - transBandKhz[index] = cutoff_[index - 1] / 4.; + if(!lowPassFilters_.empty()) { + return; + } + auto mel2hz = [](float mel) { + return 700.0 * (std::pow(10, (mel / 2595.0)) - 1.0); + }; + auto hz2mel = [](float hz) { return 2595.0 * std::log10(1.0 + hz / 700.0); }; + float minMel = hz2mel(rawWavLowFreqHz_), maxMel = hz2mel(rawWavHighFreqHz_); + // nMels intervals and nMels + 1 points + float delta = (maxMel - minMel) / rawWavNMels_; + float currentMel = minMel; + // set transition band as half of lowest bin frequency size (left bin) + // for lowest frequency set it to half of the right bin + // cutoff frequency and transmision band are stored from 0 to 0.5 of sampling + // rate + std::vector transBandKhz(rawWavNMels_ + 1); + for(int index = 0; index <= rawWavNMels_; index++) { + cutoff_.push_back(mel2hz(currentMel) / rawWavSampleRate_); + currentMel += delta; + if(index > 0) { + transBandKhz[index] = cutoff_[index - 1] / 4.; + } + } + transBandKhz[0] = 
transBandKhz[1]; + ignoredLowPassFilters_ = 0; + // compute filters for each frequency point, nMel + 1 low pass filters + for(int fidx = 0; fidx < cutoff_.size(); fidx++) { + int width = 2. / (1e-6 + transBandKhz[fidx]); + if(width * 2 + 1 > maxKernelSize_) { + FL_LOG(fl::LogLevel::INFO) + << "RawWavSpecAugment raw wave: frequency " << cutoff_[fidx] + << " will be skipped for eval, too large kernel"; + lowPassFilters_.push_back(nullptr); + ignoredLowPassFilters_++; + continue; + } + Tensor indexArr = fl::iota({2 * width + 1}); + Tensor blackmanWindow = 0.42 - 0.5 * fl::cos(M_PI * indexArr / width) + + 0.08 * fl::cos(2 * M_PI * indexArr / width); + Tensor denom = indexArr - width; + // compute sinc with proper process for index = width + Tensor kernel = fl::sin(2 * M_PI * cutoff_[fidx] * (indexArr - width)); + kernel(denom != 0) = kernel(denom != 0) / denom(denom != 0); + kernel(denom == 0) = 2 * M_PI * cutoff_[fidx]; + kernel = kernel * blackmanWindow; + // normalize kernel + kernel = kernel / fl::tile(fl::sum(kernel, {0}), {2 * width + 1}); + // create low pass filter + auto filter = std::make_shared( + Variable(fl::reshape(kernel, {kernel.dim(0), 1, 1, 1}), false), + 1, + 1, + PaddingMode::SAME, + 0 + ); + filter->eval(); + lowPassFilters_.push_back(filter); } - } - transBandKhz[0] = transBandKhz[1]; - ignoredLowPassFilters_ = 0; - // compute filters for each frequency point, nMel + 1 low pass filters - for (int fidx = 0; fidx < cutoff_.size(); fidx++) { - int width = 2. 
/ (1e-6 + transBandKhz[fidx]); - if (width * 2 + 1 > maxKernelSize_) { - FL_LOG(fl::LogLevel::INFO) - << "RawWavSpecAugment raw wave: frequency " << cutoff_[fidx] - << " will be skipped for eval, too large kernel"; - lowPassFilters_.push_back(nullptr); - ignoredLowPassFilters_++; - continue; + if(ignoredLowPassFilters_ >= lowPassFilters_.size()) { + throw std::invalid_argument( + "All low pass filters are ignored, too huge kernel for all frequencies" + ); } - Tensor indexArr = fl::iota({2 * width + 1}); - Tensor blackmanWindow = 0.42 - 0.5 * fl::cos(M_PI * indexArr / width) + - 0.08 * fl::cos(2 * M_PI * indexArr / width); - Tensor denom = indexArr - width; - // compute sinc with proper process for index = width - Tensor kernel = fl::sin(2 * M_PI * cutoff_[fidx] * (indexArr - width)); - kernel(denom != 0) = kernel(denom != 0) / denom(denom != 0); - kernel(denom == 0) = 2 * M_PI * cutoff_[fidx]; - kernel = kernel * blackmanWindow; - // normalize kernel - kernel = kernel / fl::tile(fl::sum(kernel, {0}), {2 * width + 1}); - // create low pass filter - auto filter = std::make_shared( - Variable(fl::reshape(kernel, {kernel.dim(0), 1, 1, 1}), false), - 1, - 1, - PaddingMode::SAME, - 0); - filter->eval(); - lowPassFilters_.push_back(filter); - } - if (ignoredLowPassFilters_ >= lowPassFilters_.size()) { - throw std::invalid_argument( - "All low pass filters are ignored, too huge kernel for all frequencies"); - } } Variable RawWavSpecAugment::forward(const Variable& input) { - if (input.isCalcGrad()) { - throw std::invalid_argument( - "input gradient calculation is not supported for RawWavSpecAugment."); - } - if (lowPassFilters_.empty()) { - throw std::invalid_argument("invalid RawWavSpecAugment, filters are empty"); - } + if(input.isCalcGrad()) { + throw std::invalid_argument( + "input gradient calculation is not supported for RawWavSpecAugment." 
+ ); + } + if(lowPassFilters_.empty()) { + throw std::invalid_argument("invalid RawWavSpecAugment, filters are empty"); + } - fl::Variable inputCast = detail::adjustInputType(input, "RawWavSpecAugment"); - auto output = Variable(inputCast.tensor(), false); - if (!train_) { - return output; - } + fl::Variable inputCast = detail::adjustInputType(input, "RawWavSpecAugment"); + auto output = Variable(inputCast.tensor(), false); + if(!train_) { + return output; + } - if (input.ndim() != 3) { - throw std::invalid_argument( - "RawWavSpecAugment::forward - invalid input shape: " - "input is expected to be T x C x B"); - } + if(input.ndim() != 3) { + throw std::invalid_argument( + "RawWavSpecAugment::forward - invalid input shape: " + "input is expected to be T x C x B" + ); + } - // input is expected T x C x B (mostly C=1) - const Shape& inShape = inputCast.shape(); - // Conv2D input must be 4 dims (W x H x C x N) (N = batch size) - Shape timeView = {inShape[0], inShape[1] * inShape[2], 1, 1}; - for (int i = 0; i < numFreqMask_; ++i) { - auto low = generateRandomInt(ignoredLowPassFilters_, rawWavNMels_); - auto high = - generateRandomInt(low, std::min(rawWavNMels_, low + freqMaskF_) + 1); - if (high > low) { - auto inputForFilter = fl::moddims(output, timeView); - auto midLowWav = lowPassFilters_[high]->forward(inputForFilter); - auto lowWav = lowPassFilters_[low]->forward(inputForFilter); - output = output - fl::moddims(midLowWav - lowWav, inputCast.shape()); + // input is expected T x C x B (mostly C=1) + const Shape& inShape = inputCast.shape(); + // Conv2D input must be 4 dims (W x H x C x N) (N = batch size) + Shape timeView = {inShape[0], inShape[1] * inShape[2], 1, 1}; + for(int i = 0; i < numFreqMask_; ++i) { + auto low = generateRandomInt(ignoredLowPassFilters_, rawWavNMels_); + auto high = + generateRandomInt(low, std::min(rawWavNMels_, low + freqMaskF_) + 1); + if(high > low) { + auto inputForFilter = fl::moddims(output, timeView); + auto midLowWav = 
lowPassFilters_[high]->forward(inputForFilter); + auto lowWav = lowPassFilters_[low]->forward(inputForFilter); + output = output - fl::moddims(midLowWav - lowWav, inputCast.shape()); + } } - } - double replaceVal = (maskStrategy_ == MaskingStrategy::GLOBAL_MEAN) - ? fl::mean(inputCast.tensor()).asScalar() - : 0.0; + double replaceVal = (maskStrategy_ == MaskingStrategy::GLOBAL_MEAN) + ? fl::mean(inputCast.tensor()).asScalar() + : 0.0; - auto& opArr = output.tensor(); - auto numTimeSteps = inputCast.dim(0); // number of time steps - // an upper bound on the time mask - int T = std::min(timeMaskT_, static_cast(numTimeSteps * timeMaskP_)); - if (T > 0) { - for (int i = 0; i < numTimeMask_; ++i) { - auto t = generateRandomInt(0, T); - auto t0 = generateRandomInt(0, numTimeSteps - t); - opArr(fl::range(t0, t0 + t + 1)) = replaceVal; + auto& opArr = output.tensor(); + auto numTimeSteps = inputCast.dim(0); // number of time steps + // an upper bound on the time mask + int T = std::min(timeMaskT_, static_cast(numTimeSteps * timeMaskP_)); + if(T > 0) { + for(int i = 0; i < numTimeMask_; ++i) { + auto t = generateRandomInt(0, T); + auto t0 = generateRandomInt(0, numTimeSteps - t); + opArr(fl::range(t0, t0 + t + 1)) = replaceVal; + } } - } - return output; + return output; } int RawWavSpecAugment::generateRandomInt(int low, int high) { - std::uniform_int_distribution uniformDist(low, high - 1); - return uniformDist(eng_); + std::uniform_int_distribution uniformDist(low, high - 1); + return uniformDist(eng_); } std::unique_ptr RawWavSpecAugment::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string RawWavSpecAugment::prettyString() const { - std::ostringstream ss; - ss << "RawWavSpecAugment ( "; - ss << "W: " << timeWarpW_ << ", "; - ss << "F: " << freqMaskF_ << ", "; - ss << "mF: " << numFreqMask_ << ", "; - ss << "T: " << timeMaskT_ << ", "; - ss << "p: " << timeMaskP_ << ", "; - ss << "mT: " << numTimeMask_ << ", "; - ss << 
"rawWavNMels: " << rawWavNMels_ << ", "; - ss << "rawWavLowFreqHz: " << rawWavLowFreqHz_ << ", "; - ss << "rawWavHighFreqHz: " << rawWavHighFreqHz_ << ", "; - ss << "rawWavSampleRate: " << rawWavSampleRate_ << ", "; - ss << "maxKernelSize: " << maxKernelSize_ << ", "; - ss << " )"; - return ss.str(); + std::ostringstream ss; + ss << "RawWavSpecAugment ( "; + ss << "W: " << timeWarpW_ << ", "; + ss << "F: " << freqMaskF_ << ", "; + ss << "mF: " << numFreqMask_ << ", "; + ss << "T: " << timeMaskT_ << ", "; + ss << "p: " << timeMaskP_ << ", "; + ss << "mT: " << numTimeMask_ << ", "; + ss << "rawWavNMels: " << rawWavNMels_ << ", "; + ss << "rawWavLowFreqHz: " << rawWavLowFreqHz_ << ", "; + ss << "rawWavHighFreqHz: " << rawWavHighFreqHz_ << ", "; + ss << "rawWavSampleRate: " << rawWavSampleRate_ << ", "; + ss << "maxKernelSize: " << maxKernelSize_ << ", "; + ss << " )"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/contrib/modules/RawWavSpecAugment.h b/flashlight/fl/contrib/modules/RawWavSpecAugment.h index aaf6144..1582a32 100644 --- a/flashlight/fl/contrib/modules/RawWavSpecAugment.h +++ b/flashlight/fl/contrib/modules/RawWavSpecAugment.h @@ -28,106 +28,111 @@ namespace fl { * mask these bins. bins are created on [lowFreqHz, highFreqHz] range. Time mask * is measured in number of input frames: there are sampleRate frames in 1s * audio, e.g. 
50 frames for time masking of standard specAug corresponds to - *8000 frames (in case of 16kHz audio) for time masking with raw wave specaug + * 8000 frames (in case of 16kHz audio) for time masking with raw wave specaug **/ class FL_API RawWavSpecAugment : public UnaryModule { - public: - enum class MaskingStrategy { - ZERO = 0, - GLOBAL_MEAN = 1, - // TODO - add support for mean along time, freq axes - }; - - RawWavSpecAugment( - int tWarpW, - int fMaskF, - int nFMask, - int tMaskT, - float tMaskP, - int nTMask, - int nMels = 80, - int lowFreqHz = 0, - int highFreqHz = 8000, - int sampleRate = 16000, - int maxKernelSize = 20000, - MaskingStrategy mStrategy = MaskingStrategy::ZERO); - - Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; - std::string prettyString() const override; - - private: - // Time Warping - NOT SUPPORTED CURRENTLY - // Use timeWarpW_ = 0 to disable this - int timeWarpW_; - - // Frequency Masking - // Use freqMaskF_ = 0 to disable this - int freqMaskF_; - int numFreqMask_; - - // Time Masking - // Use timeMaskT_ = 0 to disable this - int timeMaskT_; - float timeMaskP_; - int numTimeMask_; - - std::mt19937 eng_{0}; - MaskingStrategy maskStrategy_; - - int rawWavNMels_; - int rawWavLowFreqHz_; - int rawWavHighFreqHz_; - int rawWavSampleRate_; - int maxKernelSize_; - int ignoredLowPassFilters_; - std::vector cutoff_; - std::vector> lowPassFilters_; - - int generateRandomInt(int low, int high); - - void precomputeFilters(); - - Tensor lowPassFilter(int freq, Tensor wav); - - RawWavSpecAugment() = default; - - FL_SAVE_LOAD_DECLARE() +public: + enum class MaskingStrategy { + ZERO = 0, + GLOBAL_MEAN = 1, + // TODO - add support for mean along time, freq axes + }; + + RawWavSpecAugment( + int tWarpW, + int fMaskF, + int nFMask, + int tMaskT, + float tMaskP, + int nTMask, + int nMels = 80, + int lowFreqHz = 0, + int highFreqHz = 8000, + int sampleRate = 16000, + int maxKernelSize = 20000, + MaskingStrategy 
mStrategy = MaskingStrategy::ZERO + ); + + Variable forward(const Variable& input) override; + std::unique_ptr clone() const override; + std::string prettyString() const override; + +private: + // Time Warping - NOT SUPPORTED CURRENTLY + // Use timeWarpW_ = 0 to disable this + int timeWarpW_; + + // Frequency Masking + // Use freqMaskF_ = 0 to disable this + int freqMaskF_; + int numFreqMask_; + + // Time Masking + // Use timeMaskT_ = 0 to disable this + int timeMaskT_; + float timeMaskP_; + int numTimeMask_; + + std::mt19937 eng_{0}; + MaskingStrategy maskStrategy_; + + int rawWavNMels_; + int rawWavLowFreqHz_; + int rawWavHighFreqHz_; + int rawWavSampleRate_; + int maxKernelSize_; + int ignoredLowPassFilters_; + std::vector cutoff_; + std::vector> lowPassFilters_; + + int generateRandomInt(int low, int high); + + void precomputeFilters(); + + Tensor lowPassFilter(int freq, Tensor wav); + + RawWavSpecAugment() = default; + + FL_SAVE_LOAD_DECLARE() }; -template +template void RawWavSpecAugment::save(Archive& ar, const uint32_t /* version */) const { - ar(cereal::base_class(this), - timeWarpW_, - freqMaskF_, - numFreqMask_, - timeMaskT_, - timeMaskP_, - numTimeMask_, - maskStrategy_, - rawWavNMels_, - rawWavLowFreqHz_, - rawWavHighFreqHz_, - rawWavSampleRate_, - maxKernelSize_); + ar( + cereal::base_class(this), + timeWarpW_, + freqMaskF_, + numFreqMask_, + timeMaskT_, + timeMaskP_, + numTimeMask_, + maskStrategy_, + rawWavNMels_, + rawWavLowFreqHz_, + rawWavHighFreqHz_, + rawWavSampleRate_, + maxKernelSize_ + ); } -template +template void RawWavSpecAugment::load(Archive& ar, const uint32_t /* version */) { - ar(cereal::base_class(this), - timeWarpW_, - freqMaskF_, - numFreqMask_, - timeMaskT_, - timeMaskP_, - numTimeMask_, - maskStrategy_, - rawWavNMels_, - rawWavLowFreqHz_, - rawWavHighFreqHz_, - rawWavSampleRate_, - maxKernelSize_); - precomputeFilters(); + ar( + cereal::base_class(this), + timeWarpW_, + freqMaskF_, + numFreqMask_, + timeMaskT_, + timeMaskP_, + 
numTimeMask_, + maskStrategy_, + rawWavNMels_, + rawWavLowFreqHz_, + rawWavHighFreqHz_, + rawWavSampleRate_, + maxKernelSize_ + ); + precomputeFilters(); } } // namespace fl diff --git a/flashlight/fl/contrib/modules/Residual.cpp b/flashlight/fl/contrib/modules/Residual.cpp index 3a10a52..528347d 100644 --- a/flashlight/fl/contrib/modules/Residual.cpp +++ b/flashlight/fl/contrib/modules/Residual.cpp @@ -13,172 +13,185 @@ namespace fl { std::unordered_set Residual::getProjectionsIndices() const { - return projectionsIndices_; + return projectionsIndices_; } void Residual::addScale(int beforeLayer, float scale) { - int nLayers = modules_.size() - projectionsIndices_.size(); - if (beforeLayer < 1 || beforeLayer > nLayers + 1) { - throw std::invalid_argument( - "Residual: invalid layer index " + std::to_string(beforeLayer) + - " before which apply the scaling"); - } - if (scales_.find(beforeLayer - 1) != scales_.end()) { - throw std::invalid_argument( - "Residual: scaling before layer " + std::to_string(beforeLayer) + - " was already added; adding only once is allowed"); - } - scales_[beforeLayer - 1] = scale; + int nLayers = modules_.size() - projectionsIndices_.size(); + if(beforeLayer < 1 || beforeLayer > nLayers + 1) { + throw std::invalid_argument( + "Residual: invalid layer index " + std::to_string(beforeLayer) + + " before which apply the scaling" + ); + } + if(scales_.find(beforeLayer - 1) != scales_.end()) { + throw std::invalid_argument( + "Residual: scaling before layer " + std::to_string(beforeLayer) + + " was already added; adding only once is allowed" + ); + } + scales_[beforeLayer - 1] = scale; } void Residual::checkShortcut(int fromLayer, int toLayer) { - int nLayers = modules_.size() - projectionsIndices_.size(); - - if (fromLayer < 0 || fromLayer >= nLayers || toLayer <= 0 || - toLayer > nLayers + 2 || toLayer - fromLayer <= 1) { - throw std::invalid_argument( - "Residual: invalid skip connection; check fromLayer=" + - std::to_string(fromLayer) + " 
and toLayer=" + std::to_string(toLayer) + - " parameters. They are out of range of added layers"); - } - if (shortcut_.find(toLayer - 1) != shortcut_.end() && - shortcut_[toLayer - 1].find(fromLayer) != shortcut_[toLayer - 1].end()) { - throw std::invalid_argument( - "Residual: skip connection for fromLayer " + std::to_string(fromLayer) + - " to toLayer " + std::to_string(toLayer) + " is already added"); - } + int nLayers = modules_.size() - projectionsIndices_.size(); + + if( + fromLayer < 0 || fromLayer >= nLayers || toLayer <= 0 + || toLayer > nLayers + 2 || toLayer - fromLayer <= 1 + ) { + throw std::invalid_argument( + "Residual: invalid skip connection; check fromLayer=" + + std::to_string(fromLayer) + " and toLayer=" + std::to_string(toLayer) + + " parameters. They are out of range of added layers" + ); + } + if( + shortcut_.find(toLayer - 1) != shortcut_.end() + && shortcut_[toLayer - 1].find(fromLayer) != shortcut_[toLayer - 1].end() + ) { + throw std::invalid_argument( + "Residual: skip connection for fromLayer " + std::to_string(fromLayer) + + " to toLayer " + std::to_string(toLayer) + " is already added" + ); + } } void Residual::processShortcut( int fromLayer, int toLayer, - int projectionIndex) { - shortcut_[toLayer - 1].insert({fromLayer, projectionIndex}); + int projectionIndex +) { + shortcut_[toLayer - 1].insert({fromLayer, projectionIndex}); } void Residual::addShortcut(int fromLayer, int toLayer) { - // fromLayer: 0, .., nLayers_ - 1; toLayer: 1, 2, .., nLayers_ + 1 - // toLayer - fromLayer > 1 (avoid adding skip connection - // from layer K to layer K+1) - checkShortcut(fromLayer, toLayer); - processShortcut(fromLayer, toLayer, -1); + // fromLayer: 0, .., nLayers_ - 1; toLayer: 1, 2, .., nLayers_ + 1 + // toLayer - fromLayer > 1 (avoid adding skip connection + // from layer K to layer K+1) + checkShortcut(fromLayer, toLayer); + processShortcut(fromLayer, toLayer, -1); } Variable Residual::applyScale(const Variable& input, const int layerIndex) 
{ - float scale = - scales_.find(layerIndex) != scales_.end() ? scales_[layerIndex] : 1.; - return input * scale; + float scale = + scales_.find(layerIndex) != scales_.end() ? scales_[layerIndex] : 1.; + return input * scale; } std::vector Residual::forward(const std::vector& inputs) { - if (inputs.size() != 1) { - throw std::invalid_argument("Residual module expects only one input"); - } - return {forward(inputs[0])}; + if(inputs.size() != 1) { + throw std::invalid_argument("Residual module expects only one input"); + } + return {forward(inputs[0])}; } Variable Residual::forward(const Variable& input) { - Variable output = input; - int nLayers = modules_.size() - projectionsIndices_.size(); - std::vector outputs(nLayers + 1, Variable()); - outputs[0] = input; + Variable output = input; + int nLayers = modules_.size() - projectionsIndices_.size(); + std::vector outputs(nLayers + 1, Variable()); + outputs[0] = input; - int moduleIndex = 0, layerIndex = 0; + int moduleIndex = 0, layerIndex = 0; - while (layerIndex < nLayers) { - while (projectionsIndices_.find(moduleIndex) != projectionsIndices_.end()) { - moduleIndex++; - } - if (shortcut_.find(layerIndex) != shortcut_.end()) { - for (const auto& shortcut : shortcut_[layerIndex]) { - Variable connectionOut = outputs[shortcut.first]; - if (shortcut.second != -1) { - connectionOut = modules_[shortcut.second] - ->forward({outputs[shortcut.first]}) - .front(); + while(layerIndex < nLayers) { + while(projectionsIndices_.find(moduleIndex) != projectionsIndices_.end()) { + moduleIndex++; + } + if(shortcut_.find(layerIndex) != shortcut_.end()) { + for(const auto& shortcut : shortcut_[layerIndex]) { + Variable connectionOut = outputs[shortcut.first]; + if(shortcut.second != -1) { + connectionOut = modules_[shortcut.second] + ->forward({outputs[shortcut.first]}) + .front(); + } + output = output + connectionOut.astype(output.type()); + } } - output = output + connectionOut.astype(output.type()); - } + output = 
modules_[moduleIndex] + ->forward({applyScale(output, layerIndex)}) + .front(); + outputs[layerIndex + 1] = output; + layerIndex++; + moduleIndex++; } - output = modules_[moduleIndex] - ->forward({applyScale(output, layerIndex)}) - .front(); - outputs[layerIndex + 1] = output; - layerIndex++; - moduleIndex++; - } - if (shortcut_.find(nLayers) != shortcut_.end()) { - for (const auto& shortcut : shortcut_[nLayers]) { - Variable connectionOut = outputs[shortcut.first]; - if (shortcut.second != -1) { - connectionOut = modules_[shortcut.second] - ->forward({outputs[shortcut.first]}) - .front(); - } - output = output + connectionOut.astype(output.type()); + if(shortcut_.find(nLayers) != shortcut_.end()) { + for(const auto& shortcut : shortcut_[nLayers]) { + Variable connectionOut = outputs[shortcut.first]; + if(shortcut.second != -1) { + connectionOut = modules_[shortcut.second] + ->forward({outputs[shortcut.first]}) + .front(); + } + output = output + connectionOut.astype(output.type()); + } } - } - return applyScale(output, nLayers); + return applyScale(output, nLayers); } std::string Residual::prettyString() const { - std::ostringstream ss; - // prepare inverted residual skip connection - std::unordered_map> - reverseShortcut; // start -> end - for (const auto& shortcut : shortcut_) { - for (const auto& value : shortcut.second) { - reverseShortcut[value.first].insert({shortcut.first, value.second}); - } - } - - int nLayers = modules_.size() - projectionsIndices_.size(); - int moduleIndex = -1, layerIndex = 0; - std::unordered_map::const_iterator scaleIt; - - while (layerIndex <= nLayers) { - ss << "\n\tRes(" << layerIndex << "): "; - if (layerIndex == 0) { - ss << "Input"; - } else { - while (projectionsIndices_.find(moduleIndex) != - projectionsIndices_.end()) { - moduleIndex++; - } - ss << modules_[moduleIndex]->prettyString(); + std::ostringstream ss; + // prepare inverted residual skip connection + std::unordered_map> + reverseShortcut; // start -> end + for(const 
auto& shortcut : shortcut_) { + for(const auto& value : shortcut.second) { + reverseShortcut[value.first].insert({shortcut.first, value.second}); + } } - scaleIt = scales_.find(layerIndex); - if (scaleIt != scales_.end()) { - ss << " with scale (before layer is applied) " << scaleIt->second << ";"; - } + int nLayers = modules_.size() - projectionsIndices_.size(); + int moduleIndex = -1, layerIndex = 0; + std::unordered_map::const_iterator scaleIt; - if (reverseShortcut.find(layerIndex) != reverseShortcut.end() && - !reverseShortcut[layerIndex].empty()) { - ss << "; skip connection to "; - for (auto shortcut : reverseShortcut[layerIndex]) { - if (shortcut.first < nLayers) { - ss << "layer Res(" << shortcut.first + 1 << ")"; + while(layerIndex <= nLayers) { + ss << "\n\tRes(" << layerIndex << "): "; + if(layerIndex == 0) { + ss << "Input"; } else { - ss << "output"; + while( + projectionsIndices_.find(moduleIndex) + != projectionsIndices_.end() + ) { + moduleIndex++; + } + ss << modules_[moduleIndex]->prettyString(); } - if (shortcut.second != -1) { - ss << " with transformation: " - << modules_[shortcut.second]->prettyString() << ";"; + + scaleIt = scales_.find(layerIndex); + if(scaleIt != scales_.end()) { + ss << " with scale (before layer is applied) " << scaleIt->second << ";"; } - ss << " "; - } + + if( + reverseShortcut.find(layerIndex) != reverseShortcut.end() + && !reverseShortcut[layerIndex].empty() + ) { + ss << "; skip connection to "; + for(auto shortcut : reverseShortcut[layerIndex]) { + if(shortcut.first < nLayers) { + ss << "layer Res(" << shortcut.first + 1 << ")"; + } else { + ss << "output"; + } + if(shortcut.second != -1) { + ss << " with transformation: " + << modules_[shortcut.second]->prettyString() << ";"; + } + ss << " "; + } + } + layerIndex++; + moduleIndex++; } - layerIndex++; - moduleIndex++; - } - ss << "\n\tRes(" << nLayers + 1 << "): Output;"; - scaleIt = scales_.find(nLayers + 1); - if (scaleIt != scales_.end()) { - ss << " with scale 
(before layer is applied) " << scaleIt->second << ";"; - } - - return ss.str(); + ss << "\n\tRes(" << nLayers + 1 << "): Output;"; + scaleIt = scales_.find(nLayers + 1); + if(scaleIt != scales_.end()) { + ss << " with scale (before layer is applied) " << scaleIt->second << ";"; + } + + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/contrib/modules/Residual.h b/flashlight/fl/contrib/modules/Residual.h index d50348d..84bb03a 100644 --- a/flashlight/fl/contrib/modules/Residual.h +++ b/flashlight/fl/contrib/modules/Residual.h @@ -24,119 +24,119 @@ namespace fl { /** * A module for creating a generic Residual block as given by [He et al - (2015)](https://arxiv.org/abs/1512.03385) and [Kumar et al - (2015)](https://arxiv.org/abs/1505.00387). + (2015)](https://arxiv.org/abs/1512.03385) and [Kumar et al + (2015)](https://arxiv.org/abs/1505.00387). * * Example: - \code{.cpp} - auto res = Residual(); - // Add multiple layers - res.add(Conv2D(30, 50, 9, 7, 2, 3, 3, 2)); - res.add(BatchNorm(2, 50)); - res.add(ReLU()); - // Add a shortcut from the input to the block to the third layer - res.addShortcut(0, 3); - // Add a shortcut from the second layer to the output - res.addShortcut(2, 4); - // Scale the inputs to the third layer by some constant - res.addScale(3, 0.5); - - // Create a model - Sequential model; - // ... - // Add our residual block as needed - model.add(res); - model.add(Pool2D(2, 3, 1, 1, 1, 1, PoolingMode::MAX)); - model.add(res); - // ... - \endcode + \code{.cpp} + auto res = Residual(); + // Add multiple layers + res.add(Conv2D(30, 50, 9, 7, 2, 3, 3, 2)); + res.add(BatchNorm(2, 50)); + res.add(ReLU()); + // Add a shortcut from the input to the block to the third layer + res.addShortcut(0, 3); + // Add a shortcut from the second layer to the output + res.addShortcut(2, 4); + // Scale the inputs to the third layer by some constant + res.addScale(3, 0.5); + + // Create a model + Sequential model; + // ... 
+ // Add our residual block as needed + model.add(res); + model.add(Pool2D(2, 3, 1, 1, 1, 1, PoolingMode::MAX)); + model.add(res); + // ... + \endcode */ class FL_API Residual : public Container { - private: - FL_SAVE_LOAD_WITH_BASE(Container, shortcut_, scales_, projectionsIndices_) - - void checkShortcut(int fromLayer, int toLayer); - void processShortcut(int fromLayer, int toLayer, int projectionIndex); - Variable applyScale(const Variable& input, const int layerIndex); - - // Maps end -> start - std::unordered_map> shortcut_; - // Indices of projection layers - std::unordered_set projectionsIndices_; - std::unordered_map scales_; - - public: - Residual() = default; - - std::unordered_set getProjectionsIndices() const; - - /** - * Adds a scaling factor to all residual connections connecting to a layer - * given by some index index. Given some scale \f$ \alpha \f$, the input to - * ``beforeLayer`` becomes \f$ (x + f(x)) * \alpha \f$. - * - * @param[in] beforeLayer the index of the layer to which to scale the input - * and residual connection output. - * @param[in] scale the value by which to scale the sum of the previous layer - * and output of the residual connection. - */ - void addScale(int beforeLayer, float scale); - - /** - * Adds a shortcut between two layers. - * - * @param[in] fromLayer the layer index from which the skip connection will - * originate; must be in the range \f$ [0, N_{layers} - 1] \f$. If the index 0 - * is used, the input to the shortcut will be equal to the input to the - * residual block. - * @param[in] toLayer the layer index to which the skip connection outputs a - * tensor; must be in the range \f$ [1, N_{layers} + 1] \f$. If the index - * \f$ N_{layers} + 1 \f$ is used, the output of the shortcut will be added to - * the output of the entire residual block. - */ - void addShortcut(int fromLayer, int toLayer); - - /** - * See ``Residual::addShortcut``. 
- */ - template - void addShortcut(int fromLayer, int toLayer, const T& module) { - addShortcut(fromLayer, toLayer, std::make_shared(module)); - } - - /** - * Adds a shortcut connection between two layers such that tensors passed - * through the connection are forwarded through a passed module before being - * added to the resultant module's input. Can be used to reshape the output of - * input module to match the input dimensions for the output module. - * - * @param[in] fromLayer the layer index from which the shortcut connection - * will originate; must be in the range \f$ [0, N_{layers} - 1] \f$. If the - * index 0 is used, the input to the shortcut will be equal to the input to - * the residual block. - * @param[in] toLayer the layer index to which the shortcut connection outputs - * a tensor; must be in the range \f$ [1, N_{layers} + 1] \f$. If the index - * \f$ N_{layers} + 1 \f$ is used, the output of the shortcut will be added - * to the output of the entire residual block. - * @param[in] module a specified module through which the input to the - * shortcut connection will be forwarded before being added to the input to - * the destination module. 
- */ - template - void addShortcut(int fromLayer, int toLayer, std::shared_ptr module) { - checkShortcut(fromLayer, toLayer); - Container::add(module); - processShortcut(fromLayer, toLayer, modules_.size() - 1); - projectionsIndices_.insert(modules_.size() - 1); - } - - std::vector forward(const std::vector& inputs) override; - - Variable forward(const Variable& input); - - std::string prettyString() const override; - - FL_BASIC_CONTAINER_CLONING(Residual) +private: + FL_SAVE_LOAD_WITH_BASE(Container, shortcut_, scales_, projectionsIndices_) + + void checkShortcut(int fromLayer, int toLayer); + void processShortcut(int fromLayer, int toLayer, int projectionIndex); + Variable applyScale(const Variable& input, const int layerIndex); + + // Maps end -> start + std::unordered_map> shortcut_; + // Indices of projection layers + std::unordered_set projectionsIndices_; + std::unordered_map scales_; + +public: + Residual() = default; + + std::unordered_set getProjectionsIndices() const; + + /** + * Adds a scaling factor to all residual connections connecting to a layer + * given by some index index. Given some scale \f$ \alpha \f$, the input to + * ``beforeLayer`` becomes \f$ (x + f(x)) * \alpha \f$. + * + * @param[in] beforeLayer the index of the layer to which to scale the input + * and residual connection output. + * @param[in] scale the value by which to scale the sum of the previous layer + * and output of the residual connection. + */ + void addScale(int beforeLayer, float scale); + + /** + * Adds a shortcut between two layers. + * + * @param[in] fromLayer the layer index from which the skip connection will + * originate; must be in the range \f$ [0, N_{layers} - 1] \f$. If the index 0 + * is used, the input to the shortcut will be equal to the input to the + * residual block. + * @param[in] toLayer the layer index to which the skip connection outputs a + * tensor; must be in the range \f$ [1, N_{layers} + 1] \f$. 
If the index + * \f$ N_{layers} + 1 \f$ is used, the output of the shortcut will be added to + * the output of the entire residual block. + */ + void addShortcut(int fromLayer, int toLayer); + + /** + * See ``Residual::addShortcut``. + */ + template + void addShortcut(int fromLayer, int toLayer, const T& module) { + addShortcut(fromLayer, toLayer, std::make_shared(module)); + } + + /** + * Adds a shortcut connection between two layers such that tensors passed + * through the connection are forwarded through a passed module before being + * added to the resultant module's input. Can be used to reshape the output of + * input module to match the input dimensions for the output module. + * + * @param[in] fromLayer the layer index from which the shortcut connection + * will originate; must be in the range \f$ [0, N_{layers} - 1] \f$. If the + * index 0 is used, the input to the shortcut will be equal to the input to + * the residual block. + * @param[in] toLayer the layer index to which the shortcut connection outputs + * a tensor; must be in the range \f$ [1, N_{layers} + 1] \f$. If the index + * \f$ N_{layers} + 1 \f$ is used, the output of the shortcut will be added + * to the output of the entire residual block. + * @param[in] module a specified module through which the input to the + * shortcut connection will be forwarded before being added to the input to + * the destination module. 
+ */ + template + void addShortcut(int fromLayer, int toLayer, std::shared_ptr module) { + checkShortcut(fromLayer, toLayer); + Container::add(module); + processShortcut(fromLayer, toLayer, modules_.size() - 1); + projectionsIndices_.insert(modules_.size() - 1); + } + + std::vector forward(const std::vector& inputs) override; + + Variable forward(const Variable& input); + + std::string prettyString() const override; + + FL_BASIC_CONTAINER_CLONING(Residual) }; } // namespace fl diff --git a/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp b/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp index 3897015..f35bf5e 100644 --- a/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp +++ b/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp @@ -20,84 +20,91 @@ namespace fl { SinusoidalPositionEmbedding::SinusoidalPositionEmbedding( int32_t layerDim, - double inputScale /* = 1. */) - : layerDim_(layerDim), inputScale_(inputScale) { - // 'scale_' chosen based on positional embedding from: - // Attention is All You Need - // Ashish Vaswani, et al. (2017) - // https://arxiv.org/pdf/1706.03762.pdf - // Create an `iota` that looks like `[0,0,1,1,2,2,...]`, use it as the scale. - scale_ = fl::exp( - -2 * fl::floor((fl::iota({layerDim_}) / 2)) * std::log(10000) / - layerDim_); - // Create a Cosine phase shift that acts on indices like `[0,1,0,1,...]` - const double sinToCosPhaseShift = M_PI / 2.0; - cosShifts_ = sinToCosPhaseShift * fl::iota({layerDim_}) % 2; - // In the forward pass, the even indices of embedding vectors will have the - // Sine function applied and the odd indices will have the Cosine function - // applied. + double inputScale /* = 1. */ +) : layerDim_(layerDim), + inputScale_(inputScale) { + // 'scale_' chosen based on positional embedding from: + // Attention is All You Need + // Ashish Vaswani, et al. 
(2017) + // https://arxiv.org/pdf/1706.03762.pdf + // Create an `iota` that looks like `[0,0,1,1,2,2,...]`, use it as the scale. + scale_ = fl::exp( + -2 * fl::floor((fl::iota({layerDim_}) / 2)) * std::log(10000) + / layerDim_ + ); + // Create a Cosine phase shift that acts on indices like `[0,1,0,1,...]` + const double sinToCosPhaseShift = M_PI / 2.0; + cosShifts_ = sinToCosPhaseShift * fl::iota({layerDim_}) % 2; + // In the forward pass, the even indices of embedding vectors will have the + // Sine function applied and the odd indices will have the Cosine function + // applied. } SinusoidalPositionEmbedding::SinusoidalPositionEmbedding( - const SinusoidalPositionEmbedding& other) - : layerDim_(other.layerDim_), - inputScale_(other.inputScale_), - scale_(other.scale_.copy()), - cosShifts_(other.cosShifts_.copy()) {} + const SinusoidalPositionEmbedding& other +) : layerDim_(other.layerDim_), + inputScale_(other.inputScale_), + scale_(other.scale_.copy()), + cosShifts_(other.cosShifts_.copy()) {} SinusoidalPositionEmbedding& SinusoidalPositionEmbedding::operator=( - const SinusoidalPositionEmbedding& other) { - layerDim_ = other.layerDim_; - inputScale_ = other.inputScale_; - scale_ = other.scale_.copy(); - cosShifts_ = other.cosShifts_.copy(); - return *this; + const SinusoidalPositionEmbedding& other +) { + layerDim_ = other.layerDim_; + inputScale_ = other.inputScale_; + scale_ = other.scale_.copy(); + cosShifts_ = other.cosShifts_.copy(); + return *this; } std::vector SinusoidalPositionEmbedding::forward( - const std::vector& input) { - if (input[0].dim(0) != layerDim_) { - throw std::invalid_argument( - "Input dimenstion " + std::to_string(input[0].dim(0)) + - " and Embedding dimension " + std::to_string(layerDim_) + - " are different."); - } - // Retrieve the number of tokens (positions) and the numeric type (floating - // point precision). 
- const int nPositions = input[0].dim(1); - const auto numType = input[0].type(); - // Generate the tensor of positions for each token vector [embedding size, num - // positions]. - // positions = [[ 0, 0, ..], - // [ 1, 1, ..], - // [.., .., ..]] - Tensor positions = fl::iota({1, nPositions}, {layerDim_}, numType); - // Generate the embedding transformation with the precomputed scale and shift - // factors. - positions = fl::sin( - positions * fl::tile(scale_.astype(numType), {1, nPositions}) + - fl::tile(cosShifts_.astype(numType), {1, nPositions})); - // Convert the positional embedding into a variable (for gradient tracking). - Variable embeddingsPos = Variable(positions, false); - // Return the inputs with the positional embeddings tiled over the batch - // dimension. - return {input[0] * inputScale_ + tileAs(embeddingsPos, input[0])}; + const std::vector& input +) { + if(input[0].dim(0) != layerDim_) { + throw std::invalid_argument( + "Input dimenstion " + std::to_string(input[0].dim(0)) + + " and Embedding dimension " + std::to_string(layerDim_) + + " are different." + ); + } + // Retrieve the number of tokens (positions) and the numeric type (floating + // point precision). + const int nPositions = input[0].dim(1); + const auto numType = input[0].type(); + // Generate the tensor of positions for each token vector [embedding size, num + // positions]. + // positions = [[ 0, 0, ..], + // [ 1, 1, ..], + // [.., .., ..]] + Tensor positions = fl::iota({1, nPositions}, {layerDim_}, numType); + // Generate the embedding transformation with the precomputed scale and shift + // factors. + positions = fl::sin( + positions * fl::tile(scale_.astype(numType), {1, nPositions}) + + fl::tile(cosShifts_.astype(numType), {1, nPositions}) + ); + // Convert the positional embedding into a variable (for gradient tracking). + Variable embeddingsPos = Variable(positions, false); + // Return the inputs with the positional embeddings tiled over the batch + // dimension. 
+ return {input[0] * inputScale_ + tileAs(embeddingsPos, input[0])}; } std::vector SinusoidalPositionEmbedding::operator()( - const std::vector& input) { - return forward(input); + const std::vector& input +) { + return forward(input); } std::unique_ptr SinusoidalPositionEmbedding::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string SinusoidalPositionEmbedding::prettyString() const { - std::ostringstream ss; - ss << "Sinusoidal Position Embedding Layer (embDim: " << layerDim_ - << "), (input scale " << inputScale_ << ")"; - return ss.str(); + std::ostringstream ss; + ss << "Sinusoidal Position Embedding Layer (embDim: " << layerDim_ + << "), (input scale " << inputScale_ << ")"; + return ss.str(); } SinusoidalPositionEmbedding::SinusoidalPositionEmbedding() = default; diff --git a/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.h b/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.h index bb443e7..d7b9555 100644 --- a/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.h +++ b/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.h @@ -32,41 +32,44 @@ namespace fl { * */ class FL_API SinusoidalPositionEmbedding : public Module { - public: - explicit SinusoidalPositionEmbedding( - int32_t layerDim, - double inputScale = 1.); - SinusoidalPositionEmbedding(const SinusoidalPositionEmbedding& other); - SinusoidalPositionEmbedding& operator=( - const SinusoidalPositionEmbedding& other); - /** - * SinusoidalPositionEmbedding::forward(input) expects input[0] to be of - * dimensions CxTxBx1 with C = layerDim. - * output[0] = input[0] * inputScale + sinPosEmb, where sinPosEmb is a Tensor - * of dimensions CxTxBx1 computed based on position and C. - */ - std::vector forward(const std::vector& input) override; +public: + explicit SinusoidalPositionEmbedding( + int32_t layerDim, + double inputScale = 1. 
+ ); + SinusoidalPositionEmbedding(const SinusoidalPositionEmbedding& other); + SinusoidalPositionEmbedding& operator=( + const SinusoidalPositionEmbedding& other + ); + /** + * SinusoidalPositionEmbedding::forward(input) expects input[0] to be of + * dimensions CxTxBx1 with C = layerDim. + * output[0] = input[0] * inputScale + sinPosEmb, where sinPosEmb is a Tensor + * of dimensions CxTxBx1 computed based on position and C. + */ + std::vector forward(const std::vector& input) override; - std::vector operator()(const std::vector& input); + std::vector operator()(const std::vector& input); - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::string prettyString() const override; + std::string prettyString() const override; - private: - FL_SAVE_LOAD_WITH_BASE( - Module, - layerDim_, - inputScale_, - scale_, - fl::versioned(cosShifts_, 1)) +private: + FL_SAVE_LOAD_WITH_BASE( + Module, + layerDim_, + inputScale_, + scale_, + fl::versioned(cosShifts_, 1) + ) - int32_t layerDim_; - double inputScale_; - Tensor scale_; - Tensor cosShifts_; + int32_t layerDim_; + double inputScale_; + Tensor scale_; + Tensor cosShifts_; - SinusoidalPositionEmbedding(); + SinusoidalPositionEmbedding(); }; } // namespace fl diff --git a/flashlight/fl/contrib/modules/SpecAugment.cpp b/flashlight/fl/contrib/modules/SpecAugment.cpp index d2f44f0..1643106 100644 --- a/flashlight/fl/contrib/modules/SpecAugment.cpp +++ b/flashlight/fl/contrib/modules/SpecAugment.cpp @@ -21,85 +21,86 @@ SpecAugment::SpecAugment( int tMaskT, float tMaskP, int nTMask, - MaskingStrategy mStrategy /* = MaskingStrategy::ZERO */) - : timeWarpW_(tWarpW), - freqMaskF_(fMaskF), - numFreqMask_(nFMask), - timeMaskT_(tMaskT), - timeMaskP_(tMaskP), - numTimeMask_(nTMask), - maskStrategy_(mStrategy) { - if (numFreqMask_ > 0 && freqMaskF_ <= 0) { - throw std::invalid_argument("invalid arguments for frequency masking."); - } - if (numTimeMask_ > 0 && timeMaskT_ <= 0) { - throw 
std::invalid_argument("invalid arguments for time masking."); - } - if (numTimeMask_ > 0 && (timeMaskP_ <= 0 || timeMaskP_ > 1.0)) { - throw std::invalid_argument("invalid arguments for time masking."); - } + MaskingStrategy mStrategy /* = MaskingStrategy::ZERO */ +) : timeWarpW_(tWarpW), + freqMaskF_(fMaskF), + numFreqMask_(nFMask), + timeMaskT_(tMaskT), + timeMaskP_(tMaskP), + numTimeMask_(nTMask), + maskStrategy_(mStrategy) { + if(numFreqMask_ > 0 && freqMaskF_ <= 0) { + throw std::invalid_argument("invalid arguments for frequency masking."); + } + if(numTimeMask_ > 0 && timeMaskT_ <= 0) { + throw std::invalid_argument("invalid arguments for time masking."); + } + if(numTimeMask_ > 0 && (timeMaskP_ <= 0 || timeMaskP_ > 1.0)) { + throw std::invalid_argument("invalid arguments for time masking."); + } } Variable SpecAugment::forward(const Variable& input) { - if (input.isCalcGrad()) { - throw std::invalid_argument( - "input gradient calculation is not supported for SpecAugment."); - } + if(input.isCalcGrad()) { + throw std::invalid_argument( + "input gradient calculation is not supported for SpecAugment." + ); + } - auto output = Variable(input.tensor(), false); - if (!train_) { - return output; - } + auto output = Variable(input.tensor(), false); + if(!train_) { + return output; + } - auto& opArr = output.tensor(); + auto& opArr = output.tensor(); - double replaceVal = (maskStrategy_ == MaskingStrategy::GLOBAL_MEAN) - ? fl::mean(input.tensor()).asScalar() - : 0.0; + double replaceVal = (maskStrategy_ == MaskingStrategy::GLOBAL_MEAN) + ? 
fl::mean(input.tensor()).asScalar() + : 0.0; - auto numFreqChans = input.dim(1); // number of frequency channels - if (numFreqChans < freqMaskF_) { - throw std::runtime_error("Invalid input frequency channels"); - } - for (int i = 0; i < numFreqMask_; ++i) { - auto f = generateRandomInt(0, freqMaskF_); - auto f0 = generateRandomInt(0, numFreqChans - f); - opArr(fl::span, fl::range(f0, f0 + f + 1)) = replaceVal; - } + auto numFreqChans = input.dim(1); // number of frequency channels + if(numFreqChans < freqMaskF_) { + throw std::runtime_error("Invalid input frequency channels"); + } + for(int i = 0; i < numFreqMask_; ++i) { + auto f = generateRandomInt(0, freqMaskF_); + auto f0 = generateRandomInt(0, numFreqChans - f); + opArr(fl::span, fl::range(f0, f0 + f + 1)) = replaceVal; + } - auto numTimeSteps = input.dim(0); // number of time steps - // an upper bound on the time mask - int T = std::min(timeMaskT_, static_cast(numTimeSteps * timeMaskP_)); - if (T > 0) { - for (int i = 0; i < numTimeMask_; ++i) { - auto t = generateRandomInt(0, T); - auto t0 = generateRandomInt(0, numTimeSteps - t); - opArr(fl::range(t0, t0 + t + 1)) = replaceVal; + auto numTimeSteps = input.dim(0); // number of time steps + // an upper bound on the time mask + int T = std::min(timeMaskT_, static_cast(numTimeSteps * timeMaskP_)); + if(T > 0) { + for(int i = 0; i < numTimeMask_; ++i) { + auto t = generateRandomInt(0, T); + auto t0 = generateRandomInt(0, numTimeSteps - t); + opArr(fl::range(t0, t0 + t + 1)) = replaceVal; + } } - } - return output; + return output; } int SpecAugment::generateRandomInt(int low, int high) { - std::uniform_int_distribution uniformDist(low, high - 1); - return uniformDist(eng_); + std::uniform_int_distribution uniformDist(low, high - 1); + return uniformDist(eng_); } std::unique_ptr SpecAugment::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string SpecAugment::prettyString() const { - std::ostringstream ss; - ss << 
"SpecAugment ( "; - ss << "W: " << timeWarpW_ << ", "; - ss << "F: " << freqMaskF_ << ", "; - ss << "mF: " << numFreqMask_ << ", "; - ss << "T: " << timeMaskT_ << ", "; - ss << "p: " << timeMaskP_ << ", "; - ss << "mT: " << numTimeMask_; - ss << " )"; - return ss.str(); + std::ostringstream ss; + ss << "SpecAugment ( "; + ss << "W: " << timeWarpW_ << ", "; + ss << "F: " << freqMaskF_ << ", "; + ss << "mF: " << numFreqMask_ << ", "; + ss << "T: " << timeMaskT_ << ", "; + ss << "p: " << timeMaskP_ << ", "; + ss << "mT: " << numTimeMask_; + ss << " )"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/contrib/modules/SpecAugment.h b/flashlight/fl/contrib/modules/SpecAugment.h index 941a3a2..d88a5c6 100644 --- a/flashlight/fl/contrib/modules/SpecAugment.h +++ b/flashlight/fl/contrib/modules/SpecAugment.h @@ -27,60 +27,62 @@ namespace fl { * Switchboard strong (SS) 40 27 2 70 0.2 2 **/ class FL_API SpecAugment : public UnaryModule { - public: - enum class MaskingStrategy { - ZERO = 0, - GLOBAL_MEAN = 1, - // TODO - add support for mean along time, freq axes - }; - - SpecAugment( - int tWarpW, - int fMaskF, - int nFMask, - int tMaskT, - float tMaskP, - int nTMask, - MaskingStrategy mStrategy = MaskingStrategy::ZERO); - - Variable forward(const Variable& input) override; - - FL_SAVE_LOAD_WITH_BASE( - UnaryModule, - timeWarpW_, - freqMaskF_, - numFreqMask_, - timeMaskT_, - timeMaskP_, - numTimeMask_, - maskStrategy_) - - std::unique_ptr clone() const override; - - std::string prettyString() const override; - - private: - // Time Warping - NOT SUPPORTED CURRENTLY - // Use timeWarpW_ = 0 to disable this - int timeWarpW_; - - // Frequency Masking - // Use freqMaskF_ = 0 to disable this - int freqMaskF_; - int numFreqMask_; - - // Time Masking - // Use timeMaskT_ = 0 to disable this - int timeMaskT_; - float timeMaskP_; - int numTimeMask_; - - std::mt19937 eng_{0}; - MaskingStrategy maskStrategy_; - - int generateRandomInt(int low, int high); - - SpecAugment() 
= default; +public: + enum class MaskingStrategy { + ZERO = 0, + GLOBAL_MEAN = 1, + // TODO - add support for mean along time, freq axes + }; + + SpecAugment( + int tWarpW, + int fMaskF, + int nFMask, + int tMaskT, + float tMaskP, + int nTMask, + MaskingStrategy mStrategy = MaskingStrategy::ZERO + ); + + Variable forward(const Variable& input) override; + + FL_SAVE_LOAD_WITH_BASE( + UnaryModule, + timeWarpW_, + freqMaskF_, + numFreqMask_, + timeMaskT_, + timeMaskP_, + numTimeMask_, + maskStrategy_ + ) + + std::unique_ptr clone() const override; + + std::string prettyString() const override; + +private: + // Time Warping - NOT SUPPORTED CURRENTLY + // Use timeWarpW_ = 0 to disable this + int timeWarpW_; + + // Frequency Masking + // Use freqMaskF_ = 0 to disable this + int freqMaskF_; + int numFreqMask_; + + // Time Masking + // Use timeMaskT_ = 0 to disable this + int timeMaskT_; + float timeMaskP_; + int numTimeMask_; + + std::mt19937 eng_{0}; + MaskingStrategy maskStrategy_; + + int generateRandomInt(int low, int high); + + SpecAugment() = default; }; } // namespace fl diff --git a/flashlight/fl/contrib/modules/TDSBlock.cpp b/flashlight/fl/contrib/modules/TDSBlock.cpp index 7cd22b1..277b87d 100644 --- a/flashlight/fl/contrib/modules/TDSBlock.cpp +++ b/flashlight/fl/contrib/modules/TDSBlock.cpp @@ -16,78 +16,84 @@ TDSBlock::TDSBlock( double dropout /* = 0 */, int innerLinearDim /* = 0 */, int rightPadding /* = -1 */, - bool lNormIncludeTime /* = true */) { - Sequential conv; - auto convPadding = static_cast(fl::PaddingMode::SAME); - if (rightPadding != -1) { - int totalPadding = kernelSize - 1; - if (rightPadding > totalPadding) { - throw std::invalid_argument( - "right padding exceeds the 'SAME' padding required for TDSBlock"); + bool lNormIncludeTime /* = true */ +) { + Sequential conv; + auto convPadding = static_cast(fl::PaddingMode::SAME); + if(rightPadding != -1) { + int totalPadding = kernelSize - 1; + if(rightPadding > totalPadding) { + throw 
std::invalid_argument( + "right padding exceeds the 'SAME' padding required for TDSBlock" + ); + } + conv.add( + Padding( + {std::pair{totalPadding - rightPadding, rightPadding}}, + 0.0 + ) + ); + convPadding = 0; } - conv.add(Padding( - {std::pair{totalPadding - rightPadding, rightPadding}}, 0.0)); - convPadding = 0; - } - conv.add(Conv2D(channels, channels, kernelSize, 1, 1, 1, convPadding, 0)); - conv.add(ReLU()); - conv.add(Dropout(dropout)); + conv.add(Conv2D(channels, channels, kernelSize, 1, 1, 1, convPadding, 0)); + conv.add(ReLU()); + conv.add(Dropout(dropout)); - int linearDim = channels * width; - if (innerLinearDim == 0) { - innerLinearDim = linearDim; - } - Sequential fc; - fc.add(Reorder({2, 1, 0, 3})); - fc.add(View({linearDim, -1, 1, 0})); + int linearDim = channels * width; + if(innerLinearDim == 0) { + innerLinearDim = linearDim; + } + Sequential fc; + fc.add(Reorder({2, 1, 0, 3})); + fc.add(View({linearDim, -1, 1, 0})); - fc.add(Linear(linearDim, innerLinearDim)); - fc.add(ReLU()); - if (dropout > 0) { - fc.add(Dropout(dropout)); - } - fc.add(Linear(innerLinearDim, linearDim)); - fc.add(View({channels, width, -1, 0})); - fc.add(Reorder({2, 1, 0, 3})); - if (dropout > 0) { - fc.add(Dropout(dropout)); - } + fc.add(Linear(linearDim, innerLinearDim)); + fc.add(ReLU()); + if(dropout > 0) { + fc.add(Dropout(dropout)); + } + fc.add(Linear(innerLinearDim, linearDim)); + fc.add(View({channels, width, -1, 0})); + fc.add(Reorder({2, 1, 0, 3})); + if(dropout > 0) { + fc.add(Dropout(dropout)); + } - add(std::move(conv)); - if (lNormIncludeTime) { - add(LayerNorm(std::vector{0, 1, 2})); - } else { - add(LayerNorm(std::vector{1, 2})); - } - add(std::move(fc)); - if (lNormIncludeTime) { - add(LayerNorm(std::vector{0, 1, 2})); - } else { - add(LayerNorm(std::vector{1, 2})); - } + add(std::move(conv)); + if(lNormIncludeTime) { + add(LayerNorm(std::vector{0, 1, 2})); + } else { + add(LayerNorm(std::vector{1, 2})); + } + add(std::move(fc)); + if(lNormIncludeTime) { 
+ add(LayerNorm(std::vector{0, 1, 2})); + } else { + add(LayerNorm(std::vector{1, 2})); + } } std::vector TDSBlock::forward(const std::vector& inputs) { - auto out = inputs[0]; - out = module(0)->forward({out})[0].astype(out.type()) + out; - out = module(1)->forward({out})[0]; - out = module(2)->forward({out})[0].astype(out.type()) + out; - return module(3)->forward({out}); + auto out = inputs[0]; + out = module(0)->forward({out})[0].astype(out.type()) + out; + out = module(1)->forward({out})[0]; + out = module(2)->forward({out})[0].astype(out.type()) + out; + return module(3)->forward({out}); } std::string TDSBlock::prettyString() const { - std::ostringstream ss; - auto convW = param(0); - auto linW = param(4); - int kw = convW.dim(0); - int c = convW.dim(2); - int w = linW.dim(0) / c; - int l = linW.dim(1); - int l2 = linW.dim(0); - ss << "Time-Depth Separable Block ("; - ss << kw << ", " << w << ", " << c << ") [" << l << " -> " << l2 << " -> " - << l << "]"; - return ss.str(); + std::ostringstream ss; + auto convW = param(0); + auto linW = param(4); + int kw = convW.dim(0); + int c = convW.dim(2); + int w = linW.dim(0) / c; + int l = linW.dim(1); + int l2 = linW.dim(0); + ss << "Time-Depth Separable Block ("; + ss << kw << ", " << w << ", " << c << ") [" << l << " -> " << l2 << " -> " + << l << "]"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/contrib/modules/TDSBlock.h b/flashlight/fl/contrib/modules/TDSBlock.h index 855af20..addaea6 100644 --- a/flashlight/fl/contrib/modules/TDSBlock.h +++ b/flashlight/fl/contrib/modules/TDSBlock.h @@ -18,46 +18,47 @@ namespace fl { * This [link](https://imgur.com/a/LAdlwZK) shows the diagram of TDSBlock. */ class FL_API TDSBlock : public Container { - private: - TDSBlock() = default; - FL_SAVE_LOAD_WITH_BASE(Container) +private: + TDSBlock() = default; + FL_SAVE_LOAD_WITH_BASE(Container) - public: - /** - * Constructs a TDS Block. 
Input/Output Dim: T x W x C x B, where - * T = number of time-steps - * W = input/output width - * C = input/output channels - * B = batch size - * - * @param channels Number of input (and output) channels - * @param kernelSize Kernel size for convolution - * @param width Input width - * @param dropout Amount of dropout to be used - * @param innerLinearDim If > 0, the two linear layers perform the `input`-> - `output` channel transform as `linearDim` -> `innerLinearDim` and - `innerLinearDim` -> `linearDim`, where `linearDim` = `W * C` - * @param rightPadding Amount of right padding for asymmetric convolutions. - By default (= `-1`), performs symmetric padding for the "SAME" conv. - * @param lNormIncludeTime If `true`, normalization is performed on the - [T, W, C] axes as in the original paper. If `false`, exclude time - dimension from computing stats to follow standard LayerNorm ==> - normalization is performed on the [W, C] axes with scalar affine - transformation. - */ - explicit TDSBlock( - int channels, - int kernelSize, - int width, - double dropout = 0, - int innerLinearDim = 0, - int rightPadding = -1, - bool lNormIncludeTime = true); +public: + /** + * Constructs a TDS Block. Input/Output Dim: T x W x C x B, where + * T = number of time-steps + * W = input/output width + * C = input/output channels + * B = batch size + * + * @param channels Number of input (and output) channels + * @param kernelSize Kernel size for convolution + * @param width Input width + * @param dropout Amount of dropout to be used + * @param innerLinearDim If > 0, the two linear layers perform the `input`-> + `output` channel transform as `linearDim` -> `innerLinearDim` and + `innerLinearDim` -> `linearDim`, where `linearDim` = `W * C` + * @param rightPadding Amount of right padding for asymmetric convolutions. + By default (= `-1`), performs symmetric padding for the "SAME" conv. 
+ * @param lNormIncludeTime If `true`, normalization is performed on the + [T, W, C] axes as in the original paper. If `false`, exclude time + dimension from computing stats to follow standard LayerNorm ==> + normalization is performed on the [W, C] axes with scalar affine + transformation. + */ + explicit TDSBlock( + int channels, + int kernelSize, + int width, + double dropout = 0, + int innerLinearDim = 0, + int rightPadding = -1, + bool lNormIncludeTime = true + ); - std::vector forward(const std::vector& inputs) override; - std::string prettyString() const override; + std::vector forward(const std::vector& inputs) override; + std::string prettyString() const override; - FL_BASIC_CONTAINER_CLONING(TDSBlock) + FL_BASIC_CONTAINER_CLONING(TDSBlock) }; } // namespace fl diff --git a/flashlight/fl/contrib/modules/Transformer.cpp b/flashlight/fl/contrib/modules/Transformer.cpp index 2638c93..43db939 100644 --- a/flashlight/fl/contrib/modules/Transformer.cpp +++ b/flashlight/fl/contrib/modules/Transformer.cpp @@ -16,8 +16,8 @@ namespace { fl::Variable transformerInitLinear(int32_t inDim, int32_t outDim) { - float std = std::sqrt(1.0 / float(inDim)); - return fl::uniform(outDim, inDim, -std, std, fl::dtype::f32, true); + float std = std::sqrt(1.0 / float(inDim)); + return fl::uniform(outDim, inDim, -std, std, fl::dtype::f32, true); } } // namespace @@ -32,201 +32,212 @@ Transformer::Transformer( float pDropout, float pLayerdrop, bool useMask, - bool preLN) - : nHeads_(nHeads), - bptt_(bptt), - pDropout_(pDropout), - pLayerdrop_(pLayerdrop), - useMask_(useMask), - preLN_(preLN), - w1_(std::make_shared(transformerInitLinear(modelDim, mlpDim))), - w2_(std::make_shared(transformerInitLinear(mlpDim, modelDim))), - wq_(std::make_shared( - transformerInitLinear(modelDim, headDim * nHeads))), - wk_(std::make_shared( - transformerInitLinear(modelDim, headDim * nHeads))), - wv_(std::make_shared( - transformerInitLinear(modelDim, headDim * nHeads))), - wf_(std::make_shared( - 
transformerInitLinear(headDim * nHeads, modelDim))), - norm1_(std::make_shared(std::vector({0, 3}))), - norm2_(std::make_shared(std::vector({0, 3}))) { - if (bptt > 0) { - params_.push_back( - uniform(2 * bptt - 1, headDim, -0.1, 0.1, fl::dtype::f32, true)); - } - - createLayers(); + bool preLN +) : nHeads_(nHeads), + bptt_(bptt), + pDropout_(pDropout), + pLayerdrop_(pLayerdrop), + useMask_(useMask), + preLN_(preLN), + w1_(std::make_shared(transformerInitLinear(modelDim, mlpDim))), + w2_(std::make_shared(transformerInitLinear(mlpDim, modelDim))), + wq_(std::make_shared(transformerInitLinear(modelDim, headDim * nHeads))), + wk_(std::make_shared(transformerInitLinear(modelDim, headDim * nHeads))), + wv_(std::make_shared(transformerInitLinear(modelDim, headDim * nHeads))), + wf_(std::make_shared(transformerInitLinear(headDim * nHeads, modelDim))), + norm1_(std::make_shared(std::vector({0, 3}))), + norm2_(std::make_shared(std::vector({0, 3}))) { + if(bptt > 0) { + params_.push_back( + uniform(2 * bptt - 1, headDim, -0.1, 0.1, fl::dtype::f32, true) + ); + } + + createLayers(); } Transformer::Transformer(const Transformer& other) { - copy(other); - createLayers(); + copy(other); + createLayers(); } Transformer& Transformer::operator=(const Transformer& other) { - clear(); - copy(other); - createLayers(); - return *this; + clear(); + copy(other); + createLayers(); + return *this; } void Transformer::copy(const Transformer& other) { - train_ = other.train_; - nHeads_ = other.nHeads_; - bptt_ = other.bptt_; - pDropout_ = other.pDropout_; - pLayerdrop_ = other.pLayerdrop_; - useMask_ = other.useMask_; - preLN_ = other.preLN_; - w1_ = std::make_shared(*other.w1_); - w2_ = std::make_shared(*other.w2_); - wq_ = std::make_shared(*other.wq_); - wk_ = std::make_shared(*other.wk_); - wv_ = std::make_shared(*other.wv_); - wf_ = std::make_shared(*other.wf_); - norm1_ = std::make_shared(*other.norm1_); - norm2_ = std::make_shared(*other.norm2_); - if (bptt_ > 0) { - const auto& p = 
other.param(0); - params_.emplace_back(p.copy()); - } + train_ = other.train_; + nHeads_ = other.nHeads_; + bptt_ = other.bptt_; + pDropout_ = other.pDropout_; + pLayerdrop_ = other.pLayerdrop_; + useMask_ = other.useMask_; + preLN_ = other.preLN_; + w1_ = std::make_shared(*other.w1_); + w2_ = std::make_shared(*other.w2_); + wq_ = std::make_shared(*other.wq_); + wk_ = std::make_shared(*other.wk_); + wv_ = std::make_shared(*other.wv_); + wf_ = std::make_shared(*other.wf_); + norm1_ = std::make_shared(*other.norm1_); + norm2_ = std::make_shared(*other.norm2_); + if(bptt_ > 0) { + const auto& p = other.param(0); + params_.emplace_back(p.copy()); + } } void Transformer::createLayers() { - add(w1_); - add(w2_); - add(wq_); - add(wk_); - add(wv_); - add(wf_); - add(norm1_); - add(norm2_); + add(w1_); + add(w2_); + add(wq_); + add(wk_); + add(wv_); + add(wf_); + add(norm1_); + add(norm2_); } Variable Transformer::mlp(const Variable& input) { - float pDropout = train_ ? pDropout_ : 0.0; - return (*w2_)(dropout(relu((*w1_)(input)), pDropout)); + float pDropout = train_ ? pDropout_ : 0.0; + return (*w2_)(dropout(relu((*w1_)(input)), pDropout)); } Variable Transformer::getMask(int32_t n, bool cache) { - auto mask = fl::tril(fl::full({n, n}, 1.0)); - if (cache) { - auto maskCache = fl::triu(fl::full({n, n}, 1.0)); - mask = fl::concatenate(1, maskCache, mask); - } - return Variable(fl::log(mask), false); + auto mask = fl::tril(fl::full({n, n}, 1.0)); + if(cache) { + auto maskCache = fl::triu(fl::full({n, n}, 1.0)); + mask = fl::concatenate(1, maskCache, mask); + } + return Variable(fl::log(mask), false); } Variable Transformer::selfAttention(const std::vector& input) { - // previous step[optionally], input, padMask - const auto& encoderInput = input.at(input.size() - 2); - // in case of previous state input[0] has size CxT_prevxB - int n = input[0].dim(1), bsz = input[0].dim(2); - double pDrop = train_ ? 
pDropout_ : 0.0; - - auto q = transpose((*wq_)(encoderInput), {1, 0, 2}); - std::vector inputWithState(input.begin(), input.end() - 1); - auto k = transpose((*wk_)(concatenate(inputWithState, 1)), {1, 0, 2}); - auto v = transpose((*wv_)(concatenate(inputWithState, 1)), {1, 0, 2}); - - Variable mask, posEmb; - if (bptt_ > 0) { - posEmb = - tile(params_[0].astype(encoderInput.type()), {1, 1, nHeads_ * bsz}); - } - if (useMask_ && encoderInput.dim(1) > 1) { - // mask future if we use the previous state (then n is previous time) - mask = getMask(n, input.size() == 3); - } - - int offset = (input.size() == 2) ? 0 : n; - - // time x batch - fl::Variable padMask; - if (!input.back().isEmpty()) { - auto padMaskArr = input.back().tensor(); - Shape newMaskShape = {encoderInput.dim(1), encoderInput.dim(2)}; - // TODO{fl::Tensor}{resize} - emulate the ArrayFire resize operation for - // transformer pad mask - if (padMaskArr.elements() != newMaskShape.elements()) { - throw std::runtime_error( - "Transformer::selfAttention - pad mask requires resize. " - "This behavior will be fixed in a future release "); + // previous step[optionally], input, padMask + const auto& encoderInput = input.at(input.size() - 2); + // in case of previous state input[0] has size CxT_prevxB + int n = input[0].dim(1), bsz = input[0].dim(2); + double pDrop = train_ ? pDropout_ : 0.0; + + auto q = transpose((*wq_)(encoderInput), {1, 0, 2}); + std::vector inputWithState(input.begin(), input.end() - 1); + auto k = transpose((*wk_)(concatenate(inputWithState, 1)), {1, 0, 2}); + auto v = transpose((*wv_)(concatenate(inputWithState, 1)), {1, 0, 2}); + + Variable mask, posEmb; + if(bptt_ > 0) { + posEmb = + tile(params_[0].astype(encoderInput.type()), {1, 1, nHeads_ * bsz}); + } + if(useMask_ && encoderInput.dim(1) > 1) { + // mask future if we use the previous state (then n is previous time) + mask = getMask(n, input.size() == 3); + } + + int offset = (input.size() == 2) ? 
0 : n; + + // time x batch + fl::Variable padMask; + if(!input.back().isEmpty()) { + auto padMaskArr = input.back().tensor(); + Shape newMaskShape = {encoderInput.dim(1), encoderInput.dim(2)}; + // TODO{fl::Tensor}{resize} - emulate the ArrayFire resize operation for + // transformer pad mask + if(padMaskArr.elements() != newMaskShape.elements()) { + throw std::runtime_error( + "Transformer::selfAttention - pad mask requires resize. " + "This behavior will be fixed in a future release " + ); + } + padMaskArr = fl::reshape(padMaskArr, newMaskShape); + padMask = fl::Variable(fl::log(padMaskArr), false); } - padMaskArr = fl::reshape(padMaskArr, newMaskShape); - padMask = fl::Variable(fl::log(padMaskArr), false); - } - auto result = multiheadAttention( - q, k, v, posEmb, mask, padMask, nHeads_, pDrop, offset); - result = (*wf_)(transpose(result, {1, 0, 2})); - - return result; + auto result = multiheadAttention( + q, + k, + v, + posEmb, + mask, + padMask, + nHeads_, + pDrop, + offset + ); + result = (*wf_)(transpose(result, {1, 0, 2})); + + return result; } std::vector Transformer::forward(const std::vector& input) { - // previous step[optionally], input, padMask - // padMask should be empty if previous step is provided - // padMask is expected to have "1" on the used positions and "0" on padded - // positions - if (input.size() != 2) { - throw std::invalid_argument( - "Invalid inputs for transformer block: there should be at least input and mask"); - } - const auto& x = input.at(input.size() - 2); - if (x.ndim() != 3) { - throw std::invalid_argument( - "Transformer::forward - input should be of 3 dimensions " - "expects an input of size C x T x B - see documentation."); - } - - if (!input.back().isEmpty()) { - if (input.back().ndim() < 2) { - throw std::invalid_argument( - "Transformer::forward - invalid size for pad mask - " - "must have at least two dimensions"); - - } else if (x.dim(2) != input.back().dim(1)) { - throw std::invalid_argument( - "Transformer::forward 
- invalid inputs for transformer:" - " input and mask batch sizes are different"); + // previous step[optionally], input, padMask + // padMask should be empty if previous step is provided + // padMask is expected to have "1" on the used positions and "0" on padded + // positions + if(input.size() != 2) { + throw std::invalid_argument( + "Invalid inputs for transformer block: there should be at least input and mask" + ); + } + const auto& x = input.at(input.size() - 2); + if(x.ndim() != 3) { + throw std::invalid_argument( + "Transformer::forward - input should be of 3 dimensions " + "expects an input of size C x T x B - see documentation." + ); + } + + if(!input.back().isEmpty()) { + if(input.back().ndim() < 2) { + throw std::invalid_argument( + "Transformer::forward - invalid size for pad mask - " + "must have at least two dimensions" + ); + + } else if(x.dim(2) != input.back().dim(1)) { + throw std::invalid_argument( + "Transformer::forward - invalid inputs for transformer:" + " input and mask batch sizes are different" + ); + } + } + + float f = 1.0; + if(train_ && (fl::rand({1}).scalar() < pLayerdrop_)) { + f = 0.0; + } + if(preLN_) { + auto h = (f * (*norm1_)(selfAttention(input))).astype(x.type()) + x; + return {f* (*norm2_)(mlp(h)).astype(h.type()) + h}; + } else { + auto h = (*norm1_)((f* selfAttention(input)).astype(x.type()) + x); + return {(*norm2_)((f* mlp(h)).astype(h.type()) + h)}; } - } - - float f = 1.0; - if (train_ && (fl::rand({1}).scalar() < pLayerdrop_)) { - f = 0.0; - } - if (preLN_) { - auto h = (f * (*norm1_)(selfAttention(input))).astype(x.type()) + x; - return {f * (*norm2_)(mlp(h)).astype(h.type()) + h}; - } else { - auto h = (*norm1_)((f * selfAttention(input)).astype(x.type()) + x); - return {(*norm2_)((f * mlp(h)).astype(h.type()) + h)}; - } } void Transformer::setDropout(float value) { - pDropout_ = value; + pDropout_ = value; } void Transformer::setLayerDropout(float value) { - pLayerdrop_ = value; + pLayerdrop_ = value; } 
std::unique_ptr Transformer::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string Transformer::prettyString() const { - std::ostringstream ss; - ss << "Transformer (nHeads: " << nHeads_ << "), " - << "(pDropout: " << pDropout_ << "), " - << "(pLayerdrop: " << pLayerdrop_ << "), " - << "(bptt: " << bptt_ << "), " - << "(useMask: " << useMask_ << "), " - << "(preLayerNorm: " << preLN_ << ")"; - return ss.str(); + std::ostringstream ss; + ss << "Transformer (nHeads: " << nHeads_ << "), " + << "(pDropout: " << pDropout_ << "), " + << "(pLayerdrop: " << pLayerdrop_ << "), " + << "(bptt: " << bptt_ << "), " + << "(useMask: " << useMask_ << "), " + << "(preLayerNorm: " << preLN_ << ")"; + return ss.str(); } Transformer::Transformer() = default; diff --git a/flashlight/fl/contrib/modules/Transformer.h b/flashlight/fl/contrib/modules/Transformer.h index 92d842d..a9d1dc0 100644 --- a/flashlight/fl/contrib/modules/Transformer.h +++ b/flashlight/fl/contrib/modules/Transformer.h @@ -47,62 +47,62 @@ namespace fl { * @param preLN apply layer normalization before or after residual connection */ class FL_API Transformer : public Container { - public: - Transformer( - int32_t modelDim, - int32_t headDim, - int32_t mlpDim, - int32_t nHeads, - int32_t bptt, - float pDropout, - float pLayerdrop, - bool useMask = false, - bool preLN = false); - Transformer(const Transformer& other); - Transformer& operator=(const Transformer& other); - Transformer(Transformer&& other) = default; - Transformer& operator=(Transformer&& other) = default; +public: + Transformer( + int32_t modelDim, + int32_t headDim, + int32_t mlpDim, + int32_t nHeads, + int32_t bptt, + float pDropout, + float pLayerdrop, + bool useMask = false, + bool preLN = false + ); + Transformer(const Transformer& other); + Transformer& operator=(const Transformer& other); + Transformer(Transformer&& other) = default; + Transformer& operator=(Transformer&& other) = default; - std::vector 
forward(const std::vector& input) override; - void setDropout(float value); - void setLayerDropout(float value); - std::unique_ptr clone() const override; - std::string prettyString() const override; + std::vector forward(const std::vector& input) override; + void setDropout(float value); + void setLayerDropout(float value); + std::unique_ptr clone() const override; + std::string prettyString() const override; - private: - int32_t nHeads_; - int32_t bptt_; - double pDropout_; - double pLayerdrop_; - bool useMask_; - bool preLN_; - std::shared_ptr w1_, w2_, wq_, wk_, wv_, wf_; - std::shared_ptr norm1_, norm2_; +private: + int32_t nHeads_; + int32_t bptt_; + double pDropout_; + double pLayerdrop_; + bool useMask_; + bool preLN_; + std::shared_ptr w1_, w2_, wq_, wk_, wv_, wf_; + std::shared_ptr norm1_, norm2_; - void copy(const Transformer& other); - void createLayers(); - Variable mlp(const Variable& input); - Variable getMask(int32_t n, bool cache = false); - Variable selfAttention(const std::vector& input); + void copy(const Transformer& other); + void createLayers(); + Variable mlp(const Variable& input); + Variable getMask(int32_t n, bool cache = false); + Variable selfAttention(const std::vector& input); - FL_SAVE_LOAD_WITH_BASE( - Container, - w1_, - w2_, - wq_, - wk_, - wv_, - wf_, - norm1_, - norm2_, - nHeads_, - pDropout_, - pLayerdrop_, - bptt_, - useMask_, - preLN_) - - Transformer(); + FL_SAVE_LOAD_WITH_BASE( + Container, + w1_, + w2_, + wq_, + wk_, + wv_, + wf_, + norm1_, + norm2_, + nHeads_, + pDropout_, + pLayerdrop_, + bptt_, + useMask_, + preLN_ + ) Transformer(); }; } // namespace fl diff --git a/flashlight/fl/contrib/modules/modules.h b/flashlight/fl/contrib/modules/modules.h index 6015688..699d36e 100644 --- a/flashlight/fl/contrib/modules/modules.h +++ b/flashlight/fl/contrib/modules/modules.h @@ -9,7 +9,7 @@ #include "flashlight/fl/contrib/modules/AdaptiveEmbedding.h" #include "flashlight/fl/contrib/modules/AsymmetricConv1D.h" -#include 
"flashlight/fl/contrib/modules/Conformer.h" +#include "flashlight/fl/contrib/modules/Conformer.h" #include "flashlight/fl/contrib/modules/PositionEmbedding.h" #include "flashlight/fl/contrib/modules/RawWavSpecAugment.h" #include "flashlight/fl/contrib/modules/Residual.h" diff --git a/flashlight/fl/dataset/BatchDataset.cpp b/flashlight/fl/dataset/BatchDataset.cpp index 312a028..a271bd9 100644 --- a/flashlight/fl/dataset/BatchDataset.cpp +++ b/flashlight/fl/dataset/BatchDataset.cpp @@ -17,72 +17,76 @@ BatchDataset::BatchDataset( std::shared_ptr dataset, int64_t batchsize, BatchDatasetPolicy policy /* = BatchDatasetPolicy::INCLUDE_LAST */, - const std::vector& batchfns /* = {} */) - : dataset_(dataset), - batchSize_(batchsize), - batchPolicy_(policy), - batchFns_(batchfns) { - if (!dataset_) { - throw std::invalid_argument("dataset to be batched is null"); - } - if (batchSize_ <= 0) { - throw std::invalid_argument("invalid batch size"); - } - preBatchSize_ = dataset_->size(); - switch (batchPolicy_) { - case BatchDatasetPolicy::INCLUDE_LAST: - size_ = std::ceil(static_cast(preBatchSize_) / batchSize_); - break; - case BatchDatasetPolicy::SKIP_LAST: - size_ = std::floor(static_cast(preBatchSize_) / batchSize_); - break; - case BatchDatasetPolicy::DIVISIBLE_ONLY: - if (size_ % batchSize_ != 0) { - throw std::invalid_argument( - "dataset is not evenly divisible into batches"); - } - size_ = std::ceil(static_cast(preBatchSize_) / batchSize_); - break; - default: - throw std::invalid_argument("unknown BatchDatasetPolicy"); - } + const std::vector& batchfns /* = {} */ +) : dataset_(dataset), + batchSize_(batchsize), + batchPolicy_(policy), + batchFns_(batchfns) { + if(!dataset_) { + throw std::invalid_argument("dataset to be batched is null"); + } + if(batchSize_ <= 0) { + throw std::invalid_argument("invalid batch size"); + } + preBatchSize_ = dataset_->size(); + switch(batchPolicy_) { + case BatchDatasetPolicy::INCLUDE_LAST: + size_ = std::ceil(static_cast(preBatchSize_) 
/ batchSize_); + break; + case BatchDatasetPolicy::SKIP_LAST: + size_ = std::floor(static_cast(preBatchSize_) / batchSize_); + break; + case BatchDatasetPolicy::DIVISIBLE_ONLY: + if(size_ % batchSize_ != 0) { + throw std::invalid_argument( + "dataset is not evenly divisible into batches" + ); + } + size_ = std::ceil(static_cast(preBatchSize_) / batchSize_); + break; + default: + throw std::invalid_argument("unknown BatchDatasetPolicy"); + } } BatchDataset::BatchDataset( std::shared_ptr dataset, const std::vector& batchSizes, - const std::vector& batchfns /* = {} */) - : dataset_(dataset), cumSumBatchSize_(batchSizes), batchFns_(batchfns) { - if (!dataset_) { - throw std::invalid_argument("dataset to be batched is null"); - } - if (cumSumBatchSize_.empty()) { - throw std::invalid_argument("batch size vector should not be empty"); - } - std::partial_sum( - cumSumBatchSize_.begin(), - cumSumBatchSize_.end(), - cumSumBatchSize_.begin()); - preBatchSize_ = dataset_->size(); - size_ = cumSumBatchSize_.size(); + const std::vector& batchfns /* = {} */ +) : dataset_(dataset), + cumSumBatchSize_(batchSizes), + batchFns_(batchfns) { + if(!dataset_) { + throw std::invalid_argument("dataset to be batched is null"); + } + if(cumSumBatchSize_.empty()) { + throw std::invalid_argument("batch size vector should not be empty"); + } + std::partial_sum( + cumSumBatchSize_.begin(), + cumSumBatchSize_.end(), + cumSumBatchSize_.begin() + ); + preBatchSize_ = dataset_->size(); + size_ = cumSumBatchSize_.size(); } std::vector BatchDataset::get(const int64_t idx) const { - checkIndexBounds(idx); - int64_t start, end; - if (cumSumBatchSize_.empty()) { - // batchsize is given - start = batchSize_ * idx; - end = std::min(start + batchSize_, preBatchSize_); - } else { - // specific batchsizes array is provided - start = idx == 0 ? 
0 : cumSumBatchSize_[idx - 1]; - end = std::min(cumSumBatchSize_[idx], preBatchSize_); - } - return makeBatchFromRange(dataset_, batchFns_, start, end); + checkIndexBounds(idx); + int64_t start, end; + if(cumSumBatchSize_.empty()) { + // batchsize is given + start = batchSize_ * idx; + end = std::min(start + batchSize_, preBatchSize_); + } else { + // specific batchsizes array is provided + start = idx == 0 ? 0 : cumSumBatchSize_[idx - 1]; + end = std::min(cumSumBatchSize_[idx], preBatchSize_); + } + return makeBatchFromRange(dataset_, batchFns_, start, end); } int64_t BatchDataset::size() const { - return size_; + return size_; } } // namespace fl diff --git a/flashlight/fl/dataset/BatchDataset.h b/flashlight/fl/dataset/BatchDataset.h index 06e2300..a9f4be2 100644 --- a/flashlight/fl/dataset/BatchDataset.h +++ b/flashlight/fl/dataset/BatchDataset.h @@ -17,13 +17,13 @@ namespace fl { * exactly divisible by `batchsize` while performing batching. */ enum class BatchDatasetPolicy { - /// The last samples not evenly divisible by `batchsize` are packed - /// into a smaller-than-usual batch. - INCLUDE_LAST = 0, - /// The last samples not evenly divisible by `batchsize` are skipped. - SKIP_LAST = 1, - /// Constructor raises an error if sizes are not divisible. - DIVISIBLE_ONLY = 2, + /// The last samples not evenly divisible by `batchsize` are packed + /// into a smaller-than-usual batch. + INCLUDE_LAST = 0, + /// The last samples not evenly divisible by `batchsize` are skipped. + SKIP_LAST = 1, + /// Constructor raises an error if sizes are not divisible. + DIVISIBLE_ONLY = 2, }; // TODO: add RANDOM_LAST to fill up last examples with random ones? @@ -34,64 +34,64 @@ enum class BatchDatasetPolicy { * and it batches along the first singleton dimension. 
* * Example: - \code{.cpp} - // Make a dataset containing 42 tensors of shape [5, 4] - auto tensor = fl::rand({5, 4, 42}); - std::vector fields{{tensor}}; - auto ds = std::make_shared(fields); + \code{.cpp} + // Make a dataset containing 42 tensors of shape [5, 4] + auto tensor = fl::rand({5, 4, 42}); + std::vector fields{{tensor}}; + auto ds = std::make_shared(fields); - // Batch them with batchsize=10 - BatchDataset batchds(ds, 10, BatchDatasetPolicy::INCLUDE_LAST); - std::cout << batchds.get(0)[0].shape() << "\n"; // 5 4 10 1 - std::cout << batchds.get(4)[0].shape() << "\n"; // 5 4 2 1 + // Batch them with batchsize=10 + BatchDataset batchds(ds, 10, BatchDatasetPolicy::INCLUDE_LAST); + std::cout << batchds.get(0)[0].shape() << "\n"; // 5 4 10 1 + std::cout << batchds.get(4)[0].shape() << "\n"; // 5 4 2 1 - // create batch sizes vector specifying the each batch size (dynamic) - std::vector batchSizes = {5, 10, 5, 10, 2, 10} + // create batch sizes vector specifying the each batch size (dynamic) + std::vector batchSizes = {5, 10, 5, 10, 2, 10} - // Batch them with batchSizes - DynamicBatchDataset batchdsDynamic(ds, batchSizes); - std::cout << batchdsDynamic.get(0)[0].shape() << "\n"; // 5 4 5 1 - std::cout << batchdsDynamic.get(5)[0].shape() << "\n"; // 5 4 10 1 - \endcode + // Batch them with batchSizes + DynamicBatchDataset batchdsDynamic(ds, batchSizes); + std::cout << batchdsDynamic.get(0)[0].shape() << "\n"; // 5 4 5 1 + std::cout << batchdsDynamic.get(5)[0].shape() << "\n"; // 5 4 10 1 + \endcode */ class FL_API BatchDataset : public Dataset { - public: - /** - * Creates a `BatchDataset`. - * @param[in] dataset The underlying dataset. - * @param[in] batchsize The desired batch size. - * @param[in] policy How to handle the last batch if sizes are indivisible. - * @param[in] batchfns Custom batch function to use for difference indices. 
- */ - BatchDataset( - std::shared_ptr dataset, - int64_t batchsize, - BatchDatasetPolicy policy = BatchDatasetPolicy::INCLUDE_LAST, - const std::vector& batchfns = {}); +public: + /** + * Creates a `BatchDataset`. + * @param[in] dataset The underlying dataset. + * @param[in] batchsize The desired batch size. + * @param[in] policy How to handle the last batch if sizes are indivisible. + * @param[in] batchfns Custom batch function to use for difference indices. + */ + BatchDataset( + std::shared_ptr dataset, + int64_t batchsize, + BatchDatasetPolicy policy = BatchDatasetPolicy::INCLUDE_LAST, + const std::vector& batchfns = {}); - /** - * Creates a `BatchDataset`. - * @param[in] dataset The underlying dataset. - * @param[in] batchSizes desired batch sizes (dynamic). - * @param[in] batchfns Custom batch function to use for difference indices. - */ - BatchDataset( - std::shared_ptr dataset, - const std::vector& batchSizes, - const std::vector& batchfns = {}); + /** + * Creates a `BatchDataset`. + * @param[in] dataset The underlying dataset. + * @param[in] batchSizes desired batch sizes (dynamic). + * @param[in] batchfns Custom batch function to use for difference indices. 
+ */ + BatchDataset( + std::shared_ptr dataset, + const std::vector& batchSizes, + const std::vector& batchfns = {}); - int64_t size() const override; + int64_t size() const override; - std::vector get(const int64_t idx) const override; + std::vector get(const int64_t idx) const override; - private: - std::shared_ptr dataset_; - int64_t batchSize_; - BatchDatasetPolicy batchPolicy_; - std::vector cumSumBatchSize_; - std::vector batchFns_; +private: + std::shared_ptr dataset_; + int64_t batchSize_; + BatchDatasetPolicy batchPolicy_; + std::vector cumSumBatchSize_; + std::vector batchFns_; - int64_t preBatchSize_; // Size of the dataset before batching - int64_t size_; + int64_t preBatchSize_; // Size of the dataset before batching + int64_t size_; }; } // namespace fl diff --git a/flashlight/fl/dataset/BlobDataset.cpp b/flashlight/fl/dataset/BlobDataset.cpp index 1b605f5..c84be6b 100644 --- a/flashlight/fl/dataset/BlobDataset.cpp +++ b/flashlight/fl/dataset/BlobDataset.cpp @@ -19,239 +19,247 @@ const int64_t magicNumber = 0x31626f6c423a6c66; BlobDatasetEntryBuffer::BlobDatasetEntryBuffer() = default; void BlobDatasetEntryBuffer::clear() { - data_.clear(); + data_.clear(); } int64_t BlobDatasetEntryBuffer::size() const { - return data_.size() / nFieldPerEntry_; + return data_.size() / nFieldPerEntry_; } void BlobDatasetEntryBuffer::resize(int64_t size) { - data_.resize(size * nFieldPerEntry_); + data_.resize(size * nFieldPerEntry_); } BlobDatasetEntry BlobDatasetEntryBuffer::get(const int64_t idx) const { - BlobDatasetEntry e; - auto dataIdx = idx * nFieldPerEntry_; - e.type = static_cast(data_[dataIdx++]); - unsigned numDims = data_[dataIdx++]; - e.dims = Shape(std::vector(numDims)); - for (int i = 0; i < numDims; i++) { - e.dims[i] = data_[dataIdx + i]; - } - e.offset = data_[dataIdx + maxNDims_]; - return e; + BlobDatasetEntry e; + auto dataIdx = idx * nFieldPerEntry_; + e.type = static_cast(data_[dataIdx++]); + unsigned numDims = data_[dataIdx++]; + e.dims = 
Shape(std::vector(numDims)); + for(int i = 0; i < numDims; i++) { + e.dims[i] = data_[dataIdx + i]; + } + e.offset = data_[dataIdx + maxNDims_]; + return e; } void BlobDatasetEntryBuffer::add(const BlobDatasetEntry& e) { - data_.push_back(static_cast(e.type)); - data_.push_back(static_cast(e.dims.ndim())); - int i = 0; - for (; i < e.dims.ndim(); i++) { - data_.push_back(e.dims[i]); - } - for (; i < maxNDims_; ++i) { - data_.push_back(1); // placeholder dim - } - data_.push_back(e.offset); + data_.push_back(static_cast(e.type)); + data_.push_back(static_cast(e.dims.ndim())); + int i = 0; + for(; i < e.dims.ndim(); i++) { + data_.push_back(e.dims[i]); + } + for(; i < maxNDims_; ++i) { + data_.push_back(1); // placeholder dim + } + data_.push_back(e.offset); } char* BlobDatasetEntryBuffer::data() { - return (char*)data_.data(); + return (char*) data_.data(); } int64_t BlobDatasetEntryBuffer::bytes() const { - return data_.size() * sizeof(int64_t); + return data_.size() * sizeof(int64_t); }; BlobDataset::BlobDataset() = default; int64_t BlobDataset::size() const { - return offsets_.size(); + return offsets_.size(); } std::vector BlobDataset::get(const int64_t idx) const { - std::vector sample; - for (int64_t i = 0; i < sizes_.at(idx); i++) { - auto entry = entries_.get(offsets_.at(idx) + i); - sample.push_back(readArray(entry, i)); - } - return sample; + std::vector sample; + for(int64_t i = 0; i < sizes_.at(idx); i++) { + auto entry = entries_.get(offsets_.at(idx) + i); + sample.push_back(readArray(entry, i)); + } + return sample; }; std::vector> BlobDataset::rawGet(const int64_t idx) const { - std::vector> sample; - for (int64_t i = 0; i < sizes_.at(idx); i++) { - auto entry = entries_.get(offsets_.at(idx) + i); - sample.push_back(readRawArray(entry)); - } - return sample; + std::vector> sample; + for(int64_t i = 0; i < sizes_.at(idx); i++) { + auto entry = entries_.get(offsets_.at(idx) + i); + sample.push_back(readRawArray(entry)); + } + return sample; }; void 
BlobDataset::add(const std::vector& sample) { - int64_t entryOffset; - { - std::lock_guard lock(mutex_); - entryOffset = entries_.size(); - offsets_.push_back(entries_.size()); - sizes_.push_back(sample.size()); - for (const auto& tensor : sample) { - if (tensor.ndim() > maxNDims_) { - throw std::invalid_argument( - "BlobDataset::add - no support for serialization of " - "tensors with > 4 dimensions"); - } - BlobDatasetEntry e; - e.type = tensor.type(); - e.dims = tensor.shape(); - e.offset = indexOffset_; - indexOffset_ += tensor.bytes(); - entries_.add(e); + int64_t entryOffset; + { + std::lock_guard lock(mutex_); + entryOffset = entries_.size(); + offsets_.push_back(entries_.size()); + sizes_.push_back(sample.size()); + for(const auto& tensor : sample) { + if(tensor.ndim() > maxNDims_) { + throw std::invalid_argument( + "BlobDataset::add - no support for serialization of " + "tensors with > 4 dimensions" + ); + } + BlobDatasetEntry e; + e.type = tensor.type(); + e.dims = tensor.shape(); + e.offset = indexOffset_; + indexOffset_ += tensor.bytes(); + entries_.add(e); + } + } + for(int64_t i = 0; i < sample.size(); i++) { + auto& array = sample[i]; + const auto& e = entries_.get(entryOffset + i); + writeArray(e, array); } - } - for (int64_t i = 0; i < sample.size(); i++) { - auto& array = sample[i]; - const auto& e = entries_.get(entryOffset + i); - writeArray(e, array); - } } void BlobDataset::add(const BlobDataset& blob, int64_t chunkSize) { - std::lock_guard lock(mutex_); - if (chunkSize <= 0) { - throw std::runtime_error("chunkSize must be positive"); - } - sizes_.insert(sizes_.end(), blob.sizes_.begin(), blob.sizes_.end()); - std::vector offsets = blob.offsets_; - for (auto& offset : offsets) { - offset += entries_.size(); - } - offsets_.insert(offsets_.end(), offsets.begin(), offsets.end()); - for (int64_t i = 0; i < blob.entries_.size(); i++) { - auto e = blob.entries_.get(i); - e.offset += indexOffset_ - 2 * sizeof(int64_t); - entries_.add(e); - } - int64_t 
blobOffset = 2 * sizeof(int64_t); - int64_t copySize = blob.indexOffset_ - blobOffset; - int64_t nChunk = copySize / chunkSize; - int64_t remainCopySize = copySize - nChunk * chunkSize; - std::vector buffer; - auto copyChunk = [&buffer, &blob, this, &blobOffset](int64_t size) { - buffer.resize(size); - blob.readData(blobOffset, buffer.data(), size); - blobOffset += size; - this->writeData(indexOffset_, buffer.data(), size); - this->indexOffset_ += size; - }; - for (int64_t i = 0; i < nChunk; i++) { - copyChunk(chunkSize); - } - if (remainCopySize > 0) { - copyChunk(remainCopySize); - } + std::lock_guard lock(mutex_); + if(chunkSize <= 0) { + throw std::runtime_error("chunkSize must be positive"); + } + sizes_.insert(sizes_.end(), blob.sizes_.begin(), blob.sizes_.end()); + std::vector offsets = blob.offsets_; + for(auto& offset : offsets) { + offset += entries_.size(); + } + offsets_.insert(offsets_.end(), offsets.begin(), offsets.end()); + for(int64_t i = 0; i < blob.entries_.size(); i++) { + auto e = blob.entries_.get(i); + e.offset += indexOffset_ - 2 * sizeof(int64_t); + entries_.add(e); + } + int64_t blobOffset = 2 * sizeof(int64_t); + int64_t copySize = blob.indexOffset_ - blobOffset; + int64_t nChunk = copySize / chunkSize; + int64_t remainCopySize = copySize - nChunk * chunkSize; + std::vector buffer; + auto copyChunk = [&buffer, &blob, this, &blobOffset](int64_t size) { + buffer.resize(size); + blob.readData(blobOffset, buffer.data(), size); + blobOffset += size; + this->writeData(indexOffset_, buffer.data(), size); + this->indexOffset_ += size; + }; + for(int64_t i = 0; i < nChunk; i++) { + copyChunk(chunkSize); + } + if(remainCopySize > 0) { + copyChunk(remainCopySize); + } } std::vector BlobDataset::readRawArray( - const BlobDatasetEntry& e) const { - std::vector buffer; - if (e.dims.elements() > 0) { - buffer.resize(fl::getTypeSize(e.type) * e.dims.elements()); - readData( - e.offset, - reinterpret_cast(buffer.data()), - fl::getTypeSize(e.type) * 
e.dims.elements()); - } - return buffer; + const BlobDatasetEntry& e +) const { + std::vector buffer; + if(e.dims.elements() > 0) { + buffer.resize(fl::getTypeSize(e.type) * e.dims.elements()); + readData( + e.offset, + reinterpret_cast(buffer.data()), + fl::getTypeSize(e.type) * e.dims.elements() + ); + } + return buffer; } Tensor BlobDataset::readArray(const BlobDatasetEntry& e, int i) const { - if (e.dims.elements() > 0) { - auto buffer = readRawArray(e); - auto keyval = hostTransforms_.find(i); - if (keyval == hostTransforms_.end()) { - return Tensor::fromBuffer( - e.dims, e.type, buffer.data(), MemoryLocation::Host); + if(e.dims.elements() > 0) { + auto buffer = readRawArray(e); + auto keyval = hostTransforms_.find(i); + if(keyval == hostTransforms_.end()) { + return Tensor::fromBuffer( + e.dims, + e.type, + buffer.data(), + MemoryLocation::Host + ); + } else { + return keyval->second(buffer.data(), e.dims, e.type); + } } else { - return keyval->second(buffer.data(), e.dims, e.type); + return Tensor(); } - } else { - return Tensor(); - } } void BlobDataset::writeArray(const BlobDatasetEntry& e, const Tensor& array) { - std::vector buffer(array.bytes()); - array.host(buffer.data()); - writeData(e.offset, (char*)buffer.data(), buffer.size()); + std::vector buffer(array.bytes()); + array.host(buffer.data()); + writeData(e.offset, (char*) buffer.data(), buffer.size()); } void BlobDataset::writeIndex() { - std::lock_guard lock(mutex_); - - int64_t offset = 0; - offset += writeData(offset, (char*)&magicNumber, sizeof(int64_t)); - writeData(offset, (char*)&indexOffset_, sizeof(int64_t)); - - offset = indexOffset_; - int64_t size = offsets_.size(); - int64_t entriesSize = entries_.size(); - offset += writeData(offset, (char*)&size, sizeof(int64_t)); - offset += writeData(offset, (char*)&entriesSize, sizeof(int64_t)); - offset += writeData(offset, (char*)sizes_.data(), sizeof(int64_t) * size); - offset += writeData(offset, (char*)offsets_.data(), sizeof(int64_t) * 
size); - writeData(offset, entries_.data(), entries_.bytes()); - flushData(); + std::lock_guard lock(mutex_); + + int64_t offset = 0; + offset += writeData(offset, (char*) &magicNumber, sizeof(int64_t)); + writeData(offset, (char*) &indexOffset_, sizeof(int64_t)); + + offset = indexOffset_; + int64_t size = offsets_.size(); + int64_t entriesSize = entries_.size(); + offset += writeData(offset, (char*) &size, sizeof(int64_t)); + offset += writeData(offset, (char*) &entriesSize, sizeof(int64_t)); + offset += writeData(offset, (char*) sizes_.data(), sizeof(int64_t) * size); + offset += writeData(offset, (char*) offsets_.data(), sizeof(int64_t) * size); + writeData(offset, entries_.data(), entries_.bytes()); + flushData(); } void BlobDataset::readIndex() { - std::lock_guard lock(mutex_); - - entries_.clear(); - - if (isEmptyData()) { - // skip magic number and index location - indexOffset_ = 2 * sizeof(int64_t); - return; - } - - int64_t magicNumberCheck = 0; - int64_t offset = readData(0, (char*)&magicNumberCheck, sizeof(int64_t)); - if (magicNumber != magicNumberCheck) { - throw std::runtime_error("BlobDataset::readIndex - not a fl::BlobDataset"); - } - readData(offset, (char*)&indexOffset_, sizeof(int64_t)); - offset = indexOffset_; - - int64_t size; - int64_t entriesSize; - offset += readData(offset, (char*)&size, sizeof(int64_t)); - offset += readData(offset, (char*)&entriesSize, sizeof(int64_t)); - sizes_.resize(size); - offsets_.resize(size); - entries_.resize(entriesSize); - - offset += readData(offset, (char*)sizes_.data(), sizeof(int64_t) * size); - offset += readData(offset, (char*)offsets_.data(), sizeof(int64_t) * size); - readData(offset, entries_.data(), entries_.bytes()); + std::lock_guard lock(mutex_); + + entries_.clear(); + + if(isEmptyData()) { + // skip magic number and index location + indexOffset_ = 2 * sizeof(int64_t); + return; + } + + int64_t magicNumberCheck = 0; + int64_t offset = readData(0, (char*) &magicNumberCheck, sizeof(int64_t)); + 
if(magicNumber != magicNumberCheck) { + throw std::runtime_error("BlobDataset::readIndex - not a fl::BlobDataset"); + } + readData(offset, (char*) &indexOffset_, sizeof(int64_t)); + offset = indexOffset_; + + int64_t size; + int64_t entriesSize; + offset += readData(offset, (char*) &size, sizeof(int64_t)); + offset += readData(offset, (char*) &entriesSize, sizeof(int64_t)); + sizes_.resize(size); + offsets_.resize(size); + entries_.resize(entriesSize); + + offset += readData(offset, (char*) sizes_.data(), sizeof(int64_t) * size); + offset += readData(offset, (char*) offsets_.data(), sizeof(int64_t) * size); + readData(offset, entries_.data(), entries_.bytes()); } void BlobDataset::flush() { - flushData(); + flushData(); } void BlobDataset::setHostTransform( int field, - std::function func) { - hostTransforms_[field] = func; + std::function func +) { + hostTransforms_[field] = func; } std::vector BlobDataset::getEntries(const int64_t idx) const { - std::vector entries; - for (int64_t i = 0; i < sizes_.at(idx); i++) { - entries.push_back(entries_.get(offsets_.at(idx) + i)); - } - return entries; + std::vector entries; + for(int64_t i = 0; i < sizes_.at(idx); i++) { + entries.push_back(entries_.get(offsets_.at(idx) + i)); + } + return entries; } BlobDataset::~BlobDataset() = default; diff --git a/flashlight/fl/dataset/BlobDataset.h b/flashlight/fl/dataset/BlobDataset.h index 1067069..7401fa2 100644 --- a/flashlight/fl/dataset/BlobDataset.h +++ b/flashlight/fl/dataset/BlobDataset.h @@ -38,163 +38,164 @@ namespace fl { * * * For advanced users, the format of the blob is the following: - \code{.unparsed} - - - ---- raw data ---- - - ... - - ---- index ---- - - - - - - \endcode - * + \code{.unparsed} + + + ---- raw data ---- + + ... 
+ + ---- index ---- + + + + + + \endcode + * */ struct FL_API BlobDatasetEntry { - fl::dtype type; - fl::Shape dims; - int64_t offset; + fl::dtype type; + fl::Shape dims; + int64_t offset; }; class FL_API BlobDatasetEntryBuffer { - private: - std::vector data_; - const int nFieldPerEntry_ = 7; - - public: - static const int maxNDims_ = 4; // max dims supported based on index entries - - BlobDatasetEntryBuffer(); - void clear(); - int64_t size() const; - void resize(int64_t size); - BlobDatasetEntry get(const int64_t idx) const; - void add(const BlobDatasetEntry& entry); - char* data(); - int64_t bytes() const; +private: + std::vector data_; + const int nFieldPerEntry_ = 7; + +public: + static const int maxNDims_ = 4; // max dims supported based on index entries + + BlobDatasetEntryBuffer(); + void clear(); + int64_t size() const; + void resize(int64_t size); + BlobDatasetEntry get(const int64_t idx) const; + void add(const BlobDatasetEntry& entry); + char* data(); + int64_t bytes() const; }; class FL_API BlobDataset : public Dataset { - private: - const int maxNDims_ = BlobDatasetEntryBuffer::maxNDims_; - BlobDatasetEntryBuffer entries_; - std::vector sizes_; - std::vector offsets_; - int64_t indexOffset_; - std::unordered_map hostTransforms_; - mutable std::mutex mutex_; - - std::vector readRawArray(const BlobDatasetEntry& e) const; - Tensor readArray(const BlobDatasetEntry& e, int i) const; - void writeArray(const BlobDatasetEntry& e, const Tensor& array); - - protected: - void readIndex(); - - /** - * Write raw data in the blob. - * Implementation must be thread-safe. - * @param[in] offset Offset in the blob in bytes. - * @param[in] data Raw data bytes. - * @param[in] size Raw data size in bytes. - */ - virtual int64_t writeData(int64_t offset, const char* data, int64_t size) - const = 0; - - /** - * Read raw data in the blob. - * Implementation must be thread-safe. - * @param[in] offset Offset in the blob in bytes. - * @param[out] data Raw data bytes. 
- * @param[in] size Raw data size in bytes. - */ - virtual int64_t readData(int64_t offset, char* data, int64_t size) const = 0; - - /** - * Ensures all written data is flushed in the blob. - * Implementation must be thread-safe. - */ - virtual void flushData() = 0; - - /** - * Return true iff the blob is empty. - * Implementation must be thread-safe. - */ - virtual bool isEmptyData() const = 0; - - public: - /** - * Creates a `BlobDataset`, specifying a blob file name. - * @param[in] name A blob file name. - * @param[in] rw If true, opens in read-write mode. This must be specified - * to use the add() and writeIndex() methods. Except if truncate is true, - * previous stored samples will be read. - * @param[in] truncate In read-write mode, truncate the files if it - * already exists. - */ - BlobDataset(); - - int64_t size() const override; - - std::vector get(const int64_t idx) const override; - - /** - * Return raw data stored in given sample. Dimensions and types of each array - * can be retrieved with getEntries(). - * @param[in] idx An index in the dataset. - */ - std::vector> rawGet(const int64_t idx) const; - - /** - * Add a new sample in the dataset. The dataset must have been opened in - * read-write mode. Data is guaranteed to be on disk only after a flush(). - * @param[in] sample A vector of arrays, possibly of heterogeneous types and - * sizes. - */ - void add(const std::vector& sample); - - /** - * Add an entire blob to the current blob. This efficiently concatenate - * blobs by reading and writing (possibly large) chunks. - * @param[in] blob The blob to be added. - * @param[in] chunkSize Read-write chunk size. - */ - void add(const BlobDataset& blob, int64_t chunkSize = 104857600); - - /** - * Flush all data on disk. The dataset must have been opened in - * read-write mode. - */ - void flush(); - - /** - * Write index and flush data. - */ - void writeIndex(); - - /** - * Set a host transform on specified field. 
If a host transform is - * specified, it will be called to load the data from host to Tensor - * (on device). - * @param[in] field The field on which to apply the transform. - * @param[in] func The corresponding transform. - */ - void setHostTransform( - int field, - std::function func); - - /** - * Return entries in the blob for a given sample index. - * @param[in] idx A sample index. - */ - std::vector getEntries(const int64_t idx) const; - - virtual ~BlobDataset() override; +private: + const int maxNDims_ = BlobDatasetEntryBuffer::maxNDims_; + BlobDatasetEntryBuffer entries_; + std::vector sizes_; + std::vector offsets_; + int64_t indexOffset_; + std::unordered_map hostTransforms_; + mutable std::mutex mutex_; + + std::vector readRawArray(const BlobDatasetEntry& e) const; + Tensor readArray(const BlobDatasetEntry& e, int i) const; + void writeArray(const BlobDatasetEntry& e, const Tensor& array); + +protected: + void readIndex(); + + /** + * Write raw data in the blob. + * Implementation must be thread-safe. + * @param[in] offset Offset in the blob in bytes. + * @param[in] data Raw data bytes. + * @param[in] size Raw data size in bytes. + */ + virtual int64_t writeData(int64_t offset, const char* data, int64_t size) + const = 0; + + /** + * Read raw data in the blob. + * Implementation must be thread-safe. + * @param[in] offset Offset in the blob in bytes. + * @param[out] data Raw data bytes. + * @param[in] size Raw data size in bytes. + */ + virtual int64_t readData(int64_t offset, char* data, int64_t size) const = 0; + + /** + * Ensures all written data is flushed in the blob. + * Implementation must be thread-safe. + */ + virtual void flushData() = 0; + + /** + * Return true iff the blob is empty. + * Implementation must be thread-safe. + */ + virtual bool isEmptyData() const = 0; + +public: + /** + * Creates a `BlobDataset`, specifying a blob file name. + * @param[in] name A blob file name. + * @param[in] rw If true, opens in read-write mode. 
This must be specified + * to use the add() and writeIndex() methods. Except if truncate is true, + * previous stored samples will be read. + * @param[in] truncate In read-write mode, truncate the files if it + * already exists. + */ + BlobDataset(); + + int64_t size() const override; + + std::vector get(const int64_t idx) const override; + + /** + * Return raw data stored in given sample. Dimensions and types of each array + * can be retrieved with getEntries(). + * @param[in] idx An index in the dataset. + */ + std::vector> rawGet(const int64_t idx) const; + + /** + * Add a new sample in the dataset. The dataset must have been opened in + * read-write mode. Data is guaranteed to be on disk only after a flush(). + * @param[in] sample A vector of arrays, possibly of heterogeneous types and + * sizes. + */ + void add(const std::vector& sample); + + /** + * Add an entire blob to the current blob. This efficiently concatenate + * blobs by reading and writing (possibly large) chunks. + * @param[in] blob The blob to be added. + * @param[in] chunkSize Read-write chunk size. + */ + void add(const BlobDataset& blob, int64_t chunkSize = 104857600); + + /** + * Flush all data on disk. The dataset must have been opened in + * read-write mode. + */ + void flush(); + + /** + * Write index and flush data. + */ + void writeIndex(); + + /** + * Set a host transform on specified field. If a host transform is + * specified, it will be called to load the data from host to Tensor + * (on device). + * @param[in] field The field on which to apply the transform. + * @param[in] func The corresponding transform. + */ + void setHostTransform( + int field, + std::function func + ); + + /** + * Return entries in the blob for a given sample index. + * @param[in] idx A sample index. 
+ */ + std::vector getEntries(const int64_t idx) const; + + virtual ~BlobDataset() override; }; } // namespace fl diff --git a/flashlight/fl/dataset/ConcatDataset.cpp b/flashlight/fl/dataset/ConcatDataset.cpp index c81fa63..e939808 100644 --- a/flashlight/fl/dataset/ConcatDataset.cpp +++ b/flashlight/fl/dataset/ConcatDataset.cpp @@ -12,30 +12,34 @@ namespace fl { ConcatDataset::ConcatDataset( - const std::vector>& datasets) - : datasets_(datasets), size_(0) { - if (datasets.empty()) { - throw std::invalid_argument("cannot concat 0 datasets"); - } - cumulativedatasetsizes_.emplace_back(0); - for (const auto& dataset : datasets_) { - size_ += dataset->size(); - cumulativedatasetsizes_.emplace_back(size_); - } + const std::vector>& datasets +) : datasets_(datasets), + size_(0) { + if(datasets.empty()) { + throw std::invalid_argument("cannot concat 0 datasets"); + } + cumulativedatasetsizes_.emplace_back(0); + for(const auto& dataset : datasets_) { + size_ += dataset->size(); + cumulativedatasetsizes_.emplace_back(size_); + } } std::vector ConcatDataset::get(const int64_t idx) const { - checkIndexBounds(idx); + checkIndexBounds(idx); - // get sample from correct dataset - int64_t datasetidx = - std::upper_bound( - cumulativedatasetsizes_.begin(), cumulativedatasetsizes_.end(), idx) - - cumulativedatasetsizes_.begin() - 1; - return datasets_[datasetidx]->get(idx - cumulativedatasetsizes_[datasetidx]); + // get sample from correct dataset + int64_t datasetidx = + std::upper_bound( + cumulativedatasetsizes_.begin(), + cumulativedatasetsizes_.end(), + idx + ) + - cumulativedatasetsizes_.begin() - 1; + return datasets_[datasetidx]->get(idx - cumulativedatasetsizes_[datasetidx]); } int64_t ConcatDataset::size() const { - return size_; + return size_; } } // namespace fl diff --git a/flashlight/fl/dataset/ConcatDataset.h b/flashlight/fl/dataset/ConcatDataset.h index 217ada3..59994db 100644 --- a/flashlight/fl/dataset/ConcatDataset.h +++ b/flashlight/fl/dataset/ConcatDataset.h 
@@ -18,38 +18,39 @@ namespace fl { * concatenated in sequential order. * * Example: - \code{.cpp} - // Make two datasets with sizes 10 and 20 - auto makeDataset = [](int size) { + \code{.cpp} + // Make two datasets with sizes 10 and 20 + auto makeDataset = [](int size) { auto tensor = fl::rand({5, 4, size}); std::vector fields{tensor}; return std::make_shared(fields); - }; - auto ds1 = makeDataset(10); - auto ds2 = makeDataset(20); - - // Concatenate them - ConcatDataset concatds({ds1, ds2}); - std::cout << concatds.size() << "\n"; // 30 - std::cout << allClose(concatds.get(15)[0], ds2->get(5)[0]) << "\n"; // 1 - \endcode + }; + auto ds1 = makeDataset(10); + auto ds2 = makeDataset(20); + + // Concatenate them + ConcatDataset concatds({ds1, ds2}); + std::cout << concatds.size() << "\n"; // 30 + std::cout << allClose(concatds.get(15)[0], ds2->get(5)[0]) << "\n"; // 1 + \endcode */ class FL_API ConcatDataset : public Dataset { - public: - /** - * Creates a `ConcatDataset`. - * @param[in] datasets The underlying datasets. - */ - explicit ConcatDataset( - const std::vector>& datasets); - - int64_t size() const override; - - std::vector get(const int64_t idx) const override; - - private: - std::vector> datasets_; - std::vector cumulativedatasetsizes_; - int64_t size_; +public: + /** + * Creates a `ConcatDataset`. + * @param[in] datasets The underlying datasets. + */ + explicit ConcatDataset( + const std::vector>& datasets + ); + + int64_t size() const override; + + std::vector get(const int64_t idx) const override; + +private: + std::vector> datasets_; + std::vector cumulativedatasetsizes_; + int64_t size_; }; } // namespace fl diff --git a/flashlight/fl/dataset/Dataset.h b/flashlight/fl/dataset/Dataset.h index d34c982..c18bbfb 100644 --- a/flashlight/fl/dataset/Dataset.h +++ b/flashlight/fl/dataset/Dataset.h @@ -29,63 +29,63 @@ namespace fl { * ownership of underlying `Dataset`s. 
*/ class FL_API Dataset { - public: - /** - * A bijective mapping of dataset indices \f$[0, n) \to [0, n)\f$. - */ - using PermutationFunction = std::function; - - /** - * A function to transform an array. - */ - using TransformFunction = std::function; - - /** - * A function to load data from a file into an array. - */ - using LoadFunction = std::function; - - /** - * A function to pack arrays into a batched array. - */ - using BatchFunction = std::function&)>; - - /** - * A function to transform data from host to array. - */ - using DataTransformFunction = - std::function; - - /** - * @return The size of the dataset. - */ - virtual int64_t size() const = 0; - - /** - * @param[in] idx Index of the sample in the dataset. Must be in [0, size()). - * @return The sample fields (a `std::vector`). - */ - virtual std::vector get(const int64_t idx) const = 0; - - virtual ~Dataset() = default; - - // Setup iterators - using iterator = detail::DatasetIterator>; - - iterator begin() { - return iterator(this); - } - - iterator end() { - return iterator(); - } - - protected: - void checkIndexBounds(int64_t idx) const { - if (!(idx >= 0 && idx < size())) { - throw std::out_of_range("Dataset idx out of range"); +public: + /** + * A bijective mapping of dataset indices \f$[0, n) \to [0, n)\f$. + */ + using PermutationFunction = std::function; + + /** + * A function to transform an array. + */ + using TransformFunction = std::function; + + /** + * A function to load data from a file into an array. + */ + using LoadFunction = std::function; + + /** + * A function to pack arrays into a batched array. + */ + using BatchFunction = std::function&)>; + + /** + * A function to transform data from host to array. + */ + using DataTransformFunction = + std::function; + + /** + * @return The size of the dataset. + */ + virtual int64_t size() const = 0; + + /** + * @param[in] idx Index of the sample in the dataset. Must be in [0, size()). + * @return The sample fields (a `std::vector`). 
+ */ + virtual std::vector get(const int64_t idx) const = 0; + + virtual ~Dataset() = default; + + // Setup iterators + using iterator = detail::DatasetIterator>; + + iterator begin() { + return iterator(this); + } + + iterator end() { + return iterator(); + } + +protected: + void checkIndexBounds(int64_t idx) const { + if(!(idx >= 0 && idx < size())) { + throw std::out_of_range("Dataset idx out of range"); + } } - } }; } // namespace fl diff --git a/flashlight/fl/dataset/DatasetIterator.h b/flashlight/fl/dataset/DatasetIterator.h index a6739ec..813ecf1 100644 --- a/flashlight/fl/dataset/DatasetIterator.h +++ b/flashlight/fl/dataset/DatasetIterator.h @@ -16,64 +16,66 @@ namespace detail { * STL style iterator class to easily iterate over a dataset. * * Example: - \ code{.cpp} - Tensor tensor = fl::rand({1, 2, 3}); - TensorDataset tensords(std::vector{tensor}); - for (auto& sample : tensords) { + \ code{.cpp} + Tensor tensor = fl::rand({1, 2, 3}); + TensorDataset tensords(std::vector{tensor}); + for (auto& sample : tensords) { // do something - } - \endcode + } + \endcode */ -template -class DatasetIterator { - protected: - D* dataset_; - int64_t idx_; - F buffer_; + template + class DatasetIterator { + protected: + D* dataset_; + int64_t idx_; + F buffer_; - public: - // DatasetIterator traits, previously from std::iterator. - using value_type = F; - using reference = F&; - using pointer = F*; - using iterator_category = std::forward_iterator_tag; + public: + // DatasetIterator traits, previously from std::iterator. + using value_type = F; + using reference = F&; + using pointer = F*; + using iterator_category = std::forward_iterator_tag; - // Default constructible. - DatasetIterator() : dataset_(nullptr), idx_(-1) {} + // Default constructible. + DatasetIterator() : dataset_(nullptr), + idx_(-1) {} - explicit DatasetIterator(D* dataset) - : dataset_(dataset), idx_(dataset_->size() > 0 ? 
0 : -1) {} + explicit DatasetIterator(D* dataset) : dataset_(dataset), + idx_(dataset_->size() > 0 ? 0 : -1) {} - // Dereferencable - reference operator*() { - buffer_ = dataset_->get(idx_); - return buffer_; - } + // Dereferencable + reference operator*() { + buffer_ = dataset_->get(idx_); + return buffer_; + } - // Pre- and post-incrementable. - DatasetIterator& operator++() { - if (++idx_ >= dataset_->size()) { - idx_ = -1; - } - return *this; - } + // Pre- and post-incrementable. + DatasetIterator& operator++() { + if(++idx_ >= dataset_->size()) { + idx_ = -1; + } + return *this; + } - DatasetIterator operator++(int) { - DatasetIterator tmp(*this); - if (++idx_ >= dataset_->size()) - idx_ = -1; - return tmp; - } + DatasetIterator operator++(int) { + DatasetIterator tmp(*this); + if(++idx_ >= dataset_->size()) { + idx_ = -1; + } + return tmp; + } - // Equality / inequality. - bool operator==(const DatasetIterator& that) const { - return (idx_ == that.idx_); - } + // Equality / inequality. + bool operator==(const DatasetIterator& that) const { + return idx_ == that.idx_; + } - bool operator!=(const DatasetIterator& that) const { - return (idx_ != that.idx_); - } -}; + bool operator!=(const DatasetIterator& that) const { + return idx_ != that.idx_; + } + }; } // namespace detail } // namespace fl diff --git a/flashlight/fl/dataset/FileBlobDataset.cpp b/flashlight/fl/dataset/FileBlobDataset.cpp index f432935..f1717ed 100644 --- a/flashlight/fl/dataset/FileBlobDataset.cpp +++ b/flashlight/fl/dataset/FileBlobDataset.cpp @@ -14,97 +14,99 @@ namespace fl { FileBlobDataset::FileBlobDataset( const fs::path& name, bool rw, - bool truncate) - : name_(name) { - mode_ = (rw ? std::ios_base::in | std::ios_base::out : std::ios_base::in) | - std::ios_base::binary; - { - std::ofstream fs(name_, (truncate ? 
mode_ | std::ios_base::trunc : mode_)); - if (!fs.is_open()) { - throw std::runtime_error("could not open file " + name.string()); + bool truncate +) : name_(name) { + mode_ = (rw ? std::ios_base::in | std::ios_base::out : std::ios_base::in) + | std::ios_base::binary; + { + std::ofstream fs(name_, (truncate ? mode_ | std::ios_base::trunc : mode_)); + if(!fs.is_open()) { + throw std::runtime_error("could not open file " + name.string()); + } } - } - readIndex(); + readIndex(); } std::shared_ptr FileBlobDataset::getStream() const { - static thread_local std::shared_ptr< - std::unordered_map>> - threadFileHandles = std::make_shared< - std::unordered_map>>(); + static thread_local std::shared_ptr< + std::unordered_map>> + threadFileHandles = std::make_shared< + std::unordered_map>>(); - // Get a per-thread file handle. - auto keyval = threadFileHandles->find(reinterpret_cast(this)); - if (keyval == threadFileHandles->end()) { - auto fs = std::make_shared(); - fs->exceptions( - std::ifstream::eofbit | std::ifstream::failbit | std::ifstream::badbit); - fs->open(name_, mode_); - threadFileHandles->insert({reinterpret_cast(this), fs}); - // Link threadFileHandles to the object - // so the file handle can be cleaned at destruction. - { - std::lock_guard lock(afhmutex_); - auto i = allFileHandles_.begin(); - bool match = false; - while (i != std::end(allFileHandles_)) { - auto ptr = i->lock(); - if (ptr) { - if (threadFileHandles == ptr) { - match = true; - } - ++i; - } else { - i = allFileHandles_.erase(i); + // Get a per-thread file handle. + auto keyval = threadFileHandles->find(reinterpret_cast(this)); + if(keyval == threadFileHandles->end()) { + auto fs = std::make_shared(); + fs->exceptions( + std::ifstream::eofbit | std::ifstream::failbit | std::ifstream::badbit + ); + fs->open(name_, mode_); + threadFileHandles->insert({reinterpret_cast(this), fs}); + // Link threadFileHandles to the object + // so the file handle can be cleaned at destruction. 
+ { + std::lock_guard lock(afhmutex_); + auto i = allFileHandles_.begin(); + bool match = false; + while(i != std::end(allFileHandles_)) { + auto ptr = i->lock(); + if(ptr) { + if(threadFileHandles == ptr) { + match = true; + } + ++i; + } else { + i = allFileHandles_.erase(i); + } + } + if(!match) { + allFileHandles_.push_back(threadFileHandles); + } } - } - if (!match) { - allFileHandles_.push_back(threadFileHandles); - } + return fs; + } else { + return keyval->second; } - return fs; - } else { - return keyval->second; - } } int64_t FileBlobDataset::writeData( int64_t offset, const char* data, - int64_t size) const { - auto fs = getStream(); - fs->seekp(offset, std::ios_base::beg); - fs->write(data, size); - return fs->tellp() - offset; + int64_t size +) const { + auto fs = getStream(); + fs->seekp(offset, std::ios_base::beg); + fs->write(data, size); + return fs->tellp() - offset; } int64_t FileBlobDataset::readData(int64_t offset, char* data, int64_t size) - const { - auto fs = getStream(); - fs->seekg(offset, std::ios_base::beg); - fs->read(data, size); - return fs->tellg() - offset; +const { + auto fs = getStream(); + fs->seekg(offset, std::ios_base::beg); + fs->read(data, size); + return fs->tellg() - offset; } void FileBlobDataset::flushData() { - auto fs = getStream(); - fs->flush(); + auto fs = getStream(); + fs->flush(); } bool FileBlobDataset::isEmptyData() const { - auto fs = getStream(); - fs->seekg(0, std::ios_base::end); - return (fs->tellg() == 0); + auto fs = getStream(); + fs->seekg(0, std::ios_base::end); + return fs->tellg() == 0; } FileBlobDataset::~FileBlobDataset() { - std::lock_guard lock(afhmutex_); - for (auto& weakFileHandles : allFileHandles_) { - auto fileHandles = weakFileHandles.lock(); - if (fileHandles) { - fileHandles->erase(reinterpret_cast(this)); + std::lock_guard lock(afhmutex_); + for(auto& weakFileHandles : allFileHandles_) { + auto fileHandles = weakFileHandles.lock(); + if(fileHandles) { + 
fileHandles->erase(reinterpret_cast(this)); + } } - } } } // namespace fl diff --git a/flashlight/fl/dataset/FileBlobDataset.h b/flashlight/fl/dataset/FileBlobDataset.h index 334da0d..29f605e 100644 --- a/flashlight/fl/dataset/FileBlobDataset.h +++ b/flashlight/fl/dataset/FileBlobDataset.h @@ -23,39 +23,40 @@ namespace fl { * */ class FL_API FileBlobDataset : public BlobDataset { - public: - /** - * Creates a `FileBlobDataset`, specifying a blob file name. - * @param[in] name A blob file name. - * @param[in] rw If true, opens in read-write mode. This must be specified - * to use the add() and synch() methods. Except if truncate is true, - * previous stored samples will be read. - * @param[in] truncate In read-write mode, truncate the files if it - * already exists. - */ - explicit FileBlobDataset( - const fs::path& name, - bool rw = false, - bool truncate = false); - - virtual ~FileBlobDataset() override; - - protected: - int64_t writeData(int64_t offset, const char* data, int64_t size) - const override; - int64_t readData(int64_t offset, char* data, int64_t size) const override; - void flushData() override; - bool isEmptyData() const override; - - private: - fs::path name_; - std::ios_base::openmode mode_; - std::shared_ptr getStream() const; - - mutable std::vector>>> - allFileHandles_; - mutable std::mutex afhmutex_; +public: + /** + * Creates a `FileBlobDataset`, specifying a blob file name. + * @param[in] name A blob file name. + * @param[in] rw If true, opens in read-write mode. This must be specified + * to use the add() and synch() methods. Except if truncate is true, + * previous stored samples will be read. + * @param[in] truncate In read-write mode, truncate the files if it + * already exists. 
+ */ + explicit FileBlobDataset( + const fs::path& name, + bool rw = false, + bool truncate = false + ); + + virtual ~FileBlobDataset() override; + +protected: + int64_t writeData(int64_t offset, const char* data, int64_t size) + const override; + int64_t readData(int64_t offset, char* data, int64_t size) const override; + void flushData() override; + bool isEmptyData() const override; + +private: + fs::path name_; + std::ios_base::openmode mode_; + std::shared_ptr getStream() const; + + mutable std::vector>>> + allFileHandles_; + mutable std::mutex afhmutex_; }; } // namespace fl diff --git a/flashlight/fl/dataset/MemoryBlobDataset.cpp b/flashlight/fl/dataset/MemoryBlobDataset.cpp index 113c0d6..d2502ac 100644 --- a/flashlight/fl/dataset/MemoryBlobDataset.cpp +++ b/flashlight/fl/dataset/MemoryBlobDataset.cpp @@ -12,38 +12,41 @@ namespace fl { MemoryBlobDataset::MemoryBlobDataset() { - readIndex(); + readIndex(); } int64_t MemoryBlobDataset::writeData( int64_t offset, const char* data, - int64_t size) const { - std::lock_guard lock(writeMutex_); - if (offset + size > data_.size()) { - data_.resize(offset + size); - } - std::memcpy(data_.data() + offset, data, size); - return size; + int64_t size +) const { + std::lock_guard lock(writeMutex_); + if(offset + size > data_.size()) { + data_.resize(offset + size); + } + std::memcpy(data_.data() + offset, data, size); + return size; } int64_t MemoryBlobDataset::readData(int64_t offset, char* data, int64_t size) - const { - // what is available - int64_t maxSize = std::max( - static_cast(0), static_cast(data_.size()) - offset); - // min(what is available, wanted) - maxSize = std::min(maxSize, size); - std::memcpy(data, data_.data() + offset, maxSize); - return maxSize; +const { + // what is available + int64_t maxSize = std::max( + static_cast(0), + static_cast(data_.size()) - offset + ); + // min(what is available, wanted) + maxSize = std::min(maxSize, size); + std::memcpy(data, data_.data() + offset, maxSize); + return 
maxSize; } void MemoryBlobDataset::flushData() { - std::lock_guard lock(writeMutex_); + std::lock_guard lock(writeMutex_); } bool MemoryBlobDataset::isEmptyData() const { - return (data_.empty()); + return data_.empty(); } } // namespace fl diff --git a/flashlight/fl/dataset/MemoryBlobDataset.h b/flashlight/fl/dataset/MemoryBlobDataset.h index fae1de8..ae496fb 100644 --- a/flashlight/fl/dataset/MemoryBlobDataset.h +++ b/flashlight/fl/dataset/MemoryBlobDataset.h @@ -22,24 +22,24 @@ namespace fl { * */ class FL_API MemoryBlobDataset : public BlobDataset { - public: - /** - * Creates a `MemoryBlobDataset`, specifying a blob file name. - */ - MemoryBlobDataset(); - - virtual ~MemoryBlobDataset() override = default; - - protected: - int64_t writeData(int64_t offset, const char* data, int64_t size) - const override; - int64_t readData(int64_t offset, char* data, int64_t size) const override; - void flushData() override; - bool isEmptyData() const override; - - private: - mutable std::mutex writeMutex_; - mutable std::vector data_; +public: + /** + * Creates a `MemoryBlobDataset`, specifying a blob file name. 
+ */ + MemoryBlobDataset(); + + virtual ~MemoryBlobDataset() override = default; + +protected: + int64_t writeData(int64_t offset, const char* data, int64_t size) + const override; + int64_t readData(int64_t offset, char* data, int64_t size) const override; + void flushData() override; + bool isEmptyData() const override; + +private: + mutable std::mutex writeMutex_; + mutable std::vector data_; }; } // namespace fl diff --git a/flashlight/fl/dataset/MergeDataset.cpp b/flashlight/fl/dataset/MergeDataset.cpp index b21daba..8c75c25 100644 --- a/flashlight/fl/dataset/MergeDataset.cpp +++ b/flashlight/fl/dataset/MergeDataset.cpp @@ -11,31 +11,32 @@ namespace fl { MergeDataset::MergeDataset( - const std::vector>& datasets) - : datasets_(datasets) { - size_ = 0; - for (const auto& dataset : datasets_) { - size_ = std::max(dataset->size(), size_); - } + const std::vector>& datasets +) : datasets_(datasets) { + size_ = 0; + for(const auto& dataset : datasets_) { + size_ = std::max(dataset->size(), size_); + } } std::vector MergeDataset::get(const int64_t idx) const { - checkIndexBounds(idx); + checkIndexBounds(idx); - std::vector result; - for (const auto& dataset : datasets_) { - if (idx < dataset->size()) { - auto f = dataset->get(idx); - result.insert( - result.end(), - std::make_move_iterator(f.begin()), - std::make_move_iterator(f.end())); + std::vector result; + for(const auto& dataset : datasets_) { + if(idx < dataset->size()) { + auto f = dataset->get(idx); + result.insert( + result.end(), + std::make_move_iterator(f.begin()), + std::make_move_iterator(f.end()) + ); + } } - } - return result; + return result; } int64_t MergeDataset::size() const { - return size_; + return size_; } } // namespace fl diff --git a/flashlight/fl/dataset/MergeDataset.h b/flashlight/fl/dataset/MergeDataset.h index 8fb888f..ed3264e 100644 --- a/flashlight/fl/dataset/MergeDataset.h +++ b/flashlight/fl/dataset/MergeDataset.h @@ -23,38 +23,39 @@ namespace fl { * where `merge` concatenates 
the `std::vector` from each dataset. * * Example: - \code{.cpp} - // Make two datasets - auto makeDataset = []() { + \code{.cpp} + // Make two datasets + auto makeDataset = []() { auto tensor = fl::rand({5, 4, 10}); std::vector fields{tensor}; return std::make_shared(fields); - }; - auto ds1 = makeDataset(); - auto ds2 = makeDataset(); - - // Merge them - MergeDataset mergeds({ds1, ds2}); - std::cout << mergeds.size() << "\n"; // 10 - std::cout << allClose(mergeds.get(5)[0], ds1->get(5)[0]) << "\n"; // 1 - std::cout << allClose(mergeds.get(5)[1], ds2->get(5)[0]) << "\n"; // 1 - \endcode + }; + auto ds1 = makeDataset(); + auto ds2 = makeDataset(); + + // Merge them + MergeDataset mergeds({ds1, ds2}); + std::cout << mergeds.size() << "\n"; // 10 + std::cout << allClose(mergeds.get(5)[0], ds1->get(5)[0]) << "\n"; // 1 + std::cout << allClose(mergeds.get(5)[1], ds2->get(5)[0]) << "\n"; // 1 + \endcode */ class FL_API MergeDataset : public Dataset { - public: - /** - * Creates a MergeDataset. - * @param[in] datasets The underlying datasets. - */ - explicit MergeDataset( - const std::vector>& datasets); - - int64_t size() const override; - - std::vector get(const int64_t idx) const override; - - private: - std::vector> datasets_; - int64_t size_; +public: + /** + * Creates a MergeDataset. + * @param[in] datasets The underlying datasets. 
+ */ + explicit MergeDataset( + const std::vector>& datasets + ); + + int64_t size() const override; + + std::vector get(const int64_t idx) const override; + +private: + std::vector> datasets_; + int64_t size_; }; } // namespace fl diff --git a/flashlight/fl/dataset/PrefetchDataset.cpp b/flashlight/fl/dataset/PrefetchDataset.cpp index cd7c208..fe65c1d 100644 --- a/flashlight/fl/dataset/PrefetchDataset.cpp +++ b/flashlight/fl/dataset/PrefetchDataset.cpp @@ -17,57 +17,61 @@ namespace fl { PrefetchDataset::PrefetchDataset( std::shared_ptr dataset, int64_t numThreads, - int64_t prefetchSize) - : dataset_(dataset), - numThreads_(numThreads), - prefetchSize_(prefetchSize), - curIdx_(-1) { - if (!dataset_) { - throw std::invalid_argument("dataset to be prefetched is null"); - } - if (!(numThreads_ > 0 && prefetchSize_ > 0) && - !(numThreads_ == 0 && prefetchSize_ == 0)) { - throw std::invalid_argument("invalid numThreads or prefetchSize"); - } - if (numThreads_ > 0) { - auto deviceId = fl::getDevice(); - threadPool_ = std::make_unique( - numThreads_, - [deviceId](size_t /* threadId */) { fl::setDevice(deviceId); }); - } + int64_t prefetchSize +) : dataset_(dataset), + numThreads_(numThreads), + prefetchSize_(prefetchSize), + curIdx_(-1) { + if(!dataset_) { + throw std::invalid_argument("dataset to be prefetched is null"); + } + if( + !(numThreads_ > 0 && prefetchSize_ > 0) + && !(numThreads_ == 0 && prefetchSize_ == 0) + ) { + throw std::invalid_argument("invalid numThreads or prefetchSize"); + } + if(numThreads_ > 0) { + auto deviceId = fl::getDevice(); + threadPool_ = std::make_unique( + numThreads_, + [deviceId](size_t /* threadId */) { fl::setDevice(deviceId); }); + } } std::vector PrefetchDataset::get(int64_t idx) const { - checkIndexBounds(idx); + checkIndexBounds(idx); - if (numThreads_ == 0) { - return dataset_->get(idx); - } + if(numThreads_ == 0) { + return dataset_->get(idx); + } - // remove from cache (if necessary) - while (!prefetchCache_.empty() && idx != 
curIdx_) { - prefetchCache_.pop(); - ++curIdx_; - } + // remove from cache (if necessary) + while(!prefetchCache_.empty() && idx != curIdx_) { + prefetchCache_.pop(); + ++curIdx_; + } - // add to cache (if necessary) - while (prefetchCache_.size() < prefetchSize_) { - auto fetchIdx = idx + prefetchCache_.size(); - if (fetchIdx >= size()) { - break; + // add to cache (if necessary) + while(prefetchCache_.size() < prefetchSize_) { + auto fetchIdx = idx + prefetchCache_.size(); + if(fetchIdx >= size()) { + break; + } + prefetchCache_.emplace( + threadPool_->enqueue( + [this, fetchIdx]() { return this->dataset_->get(fetchIdx); }) + ); } - prefetchCache_.emplace(threadPool_->enqueue( - [this, fetchIdx]() { return this->dataset_->get(fetchIdx); })); - } - auto curSample = prefetchCache_.front().get(); + auto curSample = prefetchCache_.front().get(); - prefetchCache_.pop(); - curIdx_ = idx + 1; - return curSample; + prefetchCache_.pop(); + curIdx_ = idx + 1; + return curSample; } int64_t PrefetchDataset::size() const { - return dataset_->size(); + return dataset_->size(); } } // namespace fl diff --git a/flashlight/fl/dataset/PrefetchDataset.h b/flashlight/fl/dataset/PrefetchDataset.h index afd9bf1..08300b9 100644 --- a/flashlight/fl/dataset/PrefetchDataset.h +++ b/flashlight/fl/dataset/PrefetchDataset.h @@ -22,45 +22,46 @@ namespace fl { * cache misses leading to a degraded performance. 
* * Example: - \code{.cpp} - // Make a dataset with 100 samples - auto tensor = fl::rand({5, 4, 100}); - std::vector fields{tensor}; - auto ds = std::make_shared(fields); + \code{.cpp} + // Make a dataset with 100 samples + auto tensor = fl::rand({5, 4, 100}); + std::vector fields{tensor}; + auto ds = std::make_shared(fields); - // Iterate over the dataset using 4 background threads prefetching 2 samples - // in advance - for (auto& sample : PrefetchDataset(ds, 4, 2)) { + // Iterate over the dataset using 4 background threads prefetching 2 samples + // in advance + for (auto& sample : PrefetchDataset(ds, 4, 2)) { // do something - } - \endcode + } + \endcode */ class FL_API PrefetchDataset : public Dataset { - public: - /** - * Creates a `PrefetchDataset`. - * @param[in] dataset The underlying dataset. - * @param[in] numThreads Number of threads used by the threadpool - * @param[in] prefetchSize Number of samples to be prefetched - */ - explicit PrefetchDataset( - std::shared_ptr dataset, - int64_t numThreads, - int64_t prefetchSize); +public: + /** + * Creates a `PrefetchDataset`. + * @param[in] dataset The underlying dataset. 
+ * @param[in] numThreads Number of threads used by the threadpool + * @param[in] prefetchSize Number of samples to be prefetched + */ + explicit PrefetchDataset( + std::shared_ptr dataset, + int64_t numThreads, + int64_t prefetchSize + ); - int64_t size() const override; + int64_t size() const override; - std::vector get(const int64_t idx) const override; + std::vector get(const int64_t idx) const override; - protected: - std::shared_ptr dataset_; - int64_t numThreads_, prefetchSize_; +protected: + std::shared_ptr dataset_; + int64_t numThreads_, prefetchSize_; - private: - std::unique_ptr threadPool_; - // state variables - mutable std::queue>> prefetchCache_; - mutable int64_t curIdx_; +private: + std::unique_ptr threadPool_; + // state variables + mutable std::queue>> prefetchCache_; + mutable int64_t curIdx_; }; } // namespace fl diff --git a/flashlight/fl/dataset/ResampleDataset.cpp b/flashlight/fl/dataset/ResampleDataset.cpp index 9e41782..1d7e8cd 100644 --- a/flashlight/fl/dataset/ResampleDataset.cpp +++ b/flashlight/fl/dataset/ResampleDataset.cpp @@ -14,58 +14,60 @@ namespace { std::vector makeIdentityPermutation(int64_t size) { - std::vector perm(size); - std::iota(perm.begin(), perm.end(), 0); - return perm; + std::vector perm(size); + std::iota(perm.begin(), perm.end(), 0); + return perm; } std::vector makePermutationFromFn( int64_t size, - const fl::Dataset::PermutationFunction& fn) { - if (!fn) { - throw std::invalid_argument("PermutationFunction is null"); - } - auto perm = makeIdentityPermutation(size); - std::transform(perm.begin(), perm.end(), perm.begin(), fn); - return perm; + const fl::Dataset::PermutationFunction& fn +) { + if(!fn) { + throw std::invalid_argument("PermutationFunction is null"); + } + auto perm = makeIdentityPermutation(size); + std::transform(perm.begin(), perm.end(), perm.begin(), fn); + return perm; } } // namespace namespace fl { -ResampleDataset::ResampleDataset(std::shared_ptr dataset) - : ResampleDataset(dataset, 
makeIdentityPermutation(dataset->size())) {} +ResampleDataset::ResampleDataset(std::shared_ptr dataset) : ResampleDataset(dataset, + makeIdentityPermutation( + dataset->size())) {} ResampleDataset::ResampleDataset( std::shared_ptr dataset, - std::vector resamplevec) - : dataset_(dataset) { - if (!dataset_) { - throw std::invalid_argument("dataset to be resampled is null"); - } - resample(std::move(resamplevec)); + std::vector resamplevec +) : dataset_(dataset) { + if(!dataset_) { + throw std::invalid_argument("dataset to be resampled is null"); + } + resample(std::move(resamplevec)); } ResampleDataset::ResampleDataset( std::shared_ptr dataset, const PermutationFunction& fn, - int n) - : ResampleDataset( - dataset, - makePermutationFromFn(n == -1 ? dataset->size() : n, fn)) {} + int n +) : ResampleDataset( + dataset, + makePermutationFromFn(n == -1 ? dataset->size() : n, fn)) {} void ResampleDataset::resample(std::vector resamplevec) { - resampleVec_ = std::move(resamplevec); + resampleVec_ = std::move(resamplevec); } std::vector ResampleDataset::get(const int64_t idx) const { - checkIndexBounds(idx); - return dataset_->get(resampleVec_[idx]); + checkIndexBounds(idx); + return dataset_->get(resampleVec_[idx]); } int64_t ResampleDataset::size() const { - return resampleVec_.size(); + return resampleVec_.size(); } } // namespace fl diff --git a/flashlight/fl/dataset/ResampleDataset.h b/flashlight/fl/dataset/ResampleDataset.h index 70af7f2..fa60062 100644 --- a/flashlight/fl/dataset/ResampleDataset.h +++ b/flashlight/fl/dataset/ResampleDataset.h @@ -18,63 +18,65 @@ namespace fl { * Note: the mapping doesn't have to be bijective. 
* * Example: - \code{.cpp} - // Make a dataset with 10 samples - auto tensor = fl::rand({5, 4, 10}); - std::vector fields{tensor}; - auto ds = std::make_shared(fields); + \code{.cpp} + // Make a dataset with 10 samples + auto tensor = fl::rand({5, 4, 10}); + std::vector fields{tensor}; + auto ds = std::make_shared(fields); - // Resample it by reversing it - auto permfn = [ds](int64_t x) { return ds->size() - 1 - x; }; - ResampleDataset resampleds(ds, permfn); - std::cout << resampleds.size() << "\n"; // 10 - std::cout << allClose(resampleds.get(9)[0], ds->get(0)[0]) << "\n"; // 1 - \endcode + // Resample it by reversing it + auto permfn = [ds](int64_t x) { return ds->size() - 1 - x; }; + ResampleDataset resampleds(ds, permfn); + std::cout << resampleds.size() << "\n"; // 10 + std::cout << allClose(resampleds.get(9)[0], ds->get(0)[0]) << "\n"; // 1 + \endcode */ class FL_API ResampleDataset : public Dataset { - public: - /** - * Constructs a ResampleDataset with the identity mapping: - * `ResampleDataset(ds)->get(i) == ds->get(i)` - * @param[in] dataset The underlying dataset. - */ - explicit ResampleDataset(std::shared_ptr dataset); +public: + /** + * Constructs a ResampleDataset with the identity mapping: + * `ResampleDataset(ds)->get(i) == ds->get(i)` + * @param[in] dataset The underlying dataset. + */ + explicit ResampleDataset(std::shared_ptr dataset); - /** - * Constructs a ResampleDataset with mapping specified by a vector: - * `ResampleDataset(ds, v)->get(i) == ds->get(v[i])` - * @param[in] dataset The underlying dataset. - * @param[in] resamplevec The vector specifying the mapping. - */ - ResampleDataset( - std::shared_ptr dataset, - std::vector resamplevec); + /** + * Constructs a ResampleDataset with mapping specified by a vector: + * `ResampleDataset(ds, v)->get(i) == ds->get(v[i])` + * @param[in] dataset The underlying dataset. + * @param[in] resamplevec The vector specifying the mapping. 
+ */ + ResampleDataset( + std::shared_ptr dataset, + std::vector resamplevec + ); - /** - * Constructs a ResampleDataset with mapping specified by a function: - * `ResampleDataset(ds, fn)->get(i) == ds->get(fn(i))` - * The function should be deterministic. - * @param[in] dataset The underlying dataset. - * @param[in] resamplefn The function specifying the mapping. - * @param[in] n The size of the new dataset (if -1, uses previous size) - */ - ResampleDataset( - std::shared_ptr dataset, - const PermutationFunction& resamplefn, - int n = -1); + /** + * Constructs a ResampleDataset with mapping specified by a function: + * `ResampleDataset(ds, fn)->get(i) == ds->get(fn(i))` + * The function should be deterministic. + * @param[in] dataset The underlying dataset. + * @param[in] resamplefn The function specifying the mapping. + * @param[in] n The size of the new dataset (if -1, uses previous size) + */ + ResampleDataset( + std::shared_ptr dataset, + const PermutationFunction& resamplefn, + int n = -1 + ); - int64_t size() const override; + int64_t size() const override; - std::vector get(const int64_t idx) const override; + std::vector get(const int64_t idx) const override; - /** - * Changes the mapping used to resample the dataset. - * @param[in] resamplevec The vector specifying the new mapping. - */ - void resample(std::vector resamplevec); + /** + * Changes the mapping used to resample the dataset. + * @param[in] resamplevec The vector specifying the new mapping. 
+ */ + void resample(std::vector resamplevec); - protected: - std::shared_ptr dataset_; - std::vector resampleVec_; +protected: + std::shared_ptr dataset_; + std::vector resampleVec_; }; } // namespace fl diff --git a/flashlight/fl/dataset/ShuffleDataset.cpp b/flashlight/fl/dataset/ShuffleDataset.cpp index dfc5327..008d461 100644 --- a/flashlight/fl/dataset/ShuffleDataset.cpp +++ b/flashlight/fl/dataset/ShuffleDataset.cpp @@ -14,26 +14,29 @@ namespace fl { ShuffleDataset::ShuffleDataset( std::shared_ptr dataset, - int seed /* = 0 */) - : ResampleDataset(dataset), rng_(seed) { - resample(); + int seed /* = 0 */ +) : ResampleDataset(dataset), + rng_(seed) { + resample(); } void ShuffleDataset::resample() { - std::iota(resampleVec_.begin(), resampleVec_.end(), 0); - auto n = resampleVec_.size(); - // custom implementation of shuffle - - // en.cppreference.com/w/cpp/algorithm/random_shuffle#Possible_implementation - using distr_t = std::uniform_int_distribution; - distr_t D; - for (int i = n - 1; i > 0; --i) { - std::swap( - resampleVec_[i], resampleVec_[D(rng_, distr_t::param_type(0, i))]); - } + std::iota(resampleVec_.begin(), resampleVec_.end(), 0); + auto n = resampleVec_.size(); + // custom implementation of shuffle - + // en.cppreference.com/w/cpp/algorithm/random_shuffle#Possible_implementation + using distr_t = std::uniform_int_distribution; + distr_t D; + for(int i = n - 1; i > 0; --i) { + std::swap( + resampleVec_[i], + resampleVec_[D(rng_, distr_t::param_type(0, i))] + ); + } } void ShuffleDataset::setSeed(int seed) { - rng_.seed(seed); + rng_.seed(seed); } } // namespace fl diff --git a/flashlight/fl/dataset/ShuffleDataset.h b/flashlight/fl/dataset/ShuffleDataset.h index 8ff99b1..6101760 100644 --- a/flashlight/fl/dataset/ShuffleDataset.h +++ b/flashlight/fl/dataset/ShuffleDataset.h @@ -17,44 +17,44 @@ namespace fl { * A view into a dataset, with indices permuted randomly. 
* * Example: - \code{.cpp} - // Make a dataset with 100 samples - auto tensor = fl::rand({5, 4, 100}); - std::vector fields{tensor}; - auto ds = std::make_shared(fields); - - // Shuffle it - ShuffleDataset shuffleds(ds); - std::cout << shuffleds.size() << "\n"; // 100 - std::cout << "first try" << shuffleds.get(0)["x"] << std::endl; - - // Reshuffle it - shuffleds.resample(); - std::cout << "second try" << shuffleds.get(0)["x"] << std::endl; - \endcode + \code{.cpp} + // Make a dataset with 100 samples + auto tensor = fl::rand({5, 4, 100}); + std::vector fields{tensor}; + auto ds = std::make_shared(fields); + + // Shuffle it + ShuffleDataset shuffleds(ds); + std::cout << shuffleds.size() << "\n"; // 100 + std::cout << "first try" << shuffleds.get(0)["x"] << std::endl; + + // Reshuffle it + shuffleds.resample(); + std::cout << "second try" << shuffleds.get(0)["x"] << std::endl; + \endcode */ class FL_API ShuffleDataset : public ResampleDataset { - public: - /** - * Creates a `ShuffleDataset`. - * @param[in] dataset The underlying dataset. - * @param[seed] seed initial seed to be used. - */ - explicit ShuffleDataset(std::shared_ptr dataset, int seed = 0); - - /** - * Generates a new random permutation for the dataset. - */ - void resample(); - - /** - * Sets the PRNG seed. - * @param[in] seed The desired seed. - */ - void setSeed(int seed); - - protected: - std::mt19937_64 rng_; +public: + /** + * Creates a `ShuffleDataset`. + * @param[in] dataset The underlying dataset. + * @param[seed] seed initial seed to be used. + */ + explicit ShuffleDataset(std::shared_ptr dataset, int seed = 0); + + /** + * Generates a new random permutation for the dataset. + */ + void resample(); + + /** + * Sets the PRNG seed. + * @param[in] seed The desired seed. 
+ */ + void setSeed(int seed); + +protected: + std::mt19937_64 rng_; }; } // namespace fl diff --git a/flashlight/fl/dataset/SpanDataset.cpp b/flashlight/fl/dataset/SpanDataset.cpp index f502c79..5b7233f 100644 --- a/flashlight/fl/dataset/SpanDataset.cpp +++ b/flashlight/fl/dataset/SpanDataset.cpp @@ -13,28 +13,31 @@ namespace fl { SpanDataset::SpanDataset( std::shared_ptr dataset, const int64_t offset, - const int64_t length) - : dataset_(dataset), offset_(offset) { - size_ = (length < 0) ? (dataset_->size() - offset_) : length; - if (size_ + offset_ > dataset_->size()) { - throw std::out_of_range( - "Dataset length out of range (larger than underlying dataset)"); - } + const int64_t length +) : dataset_(dataset), + offset_(offset) { + size_ = (length < 0) ? (dataset_->size() - offset_) : length; + if(size_ + offset_ > dataset_->size()) { + throw std::out_of_range( + "Dataset length out of range (larger than underlying dataset)" + ); + } } std::vector SpanDataset::get(const int64_t idx) const { - checkIndexBounds(idx); + checkIndexBounds(idx); - std::vector result; - auto f = dataset_->get(idx + offset_); - result.insert( - result.end(), - std::make_move_iterator(f.begin()), - std::make_move_iterator(f.end())); - return result; + std::vector result; + auto f = dataset_->get(idx + offset_); + result.insert( + result.end(), + std::make_move_iterator(f.begin()), + std::make_move_iterator(f.end()) + ); + return result; } int64_t SpanDataset::size() const { - return size_; + return size_; } } // namespace fl diff --git a/flashlight/fl/dataset/SpanDataset.h b/flashlight/fl/dataset/SpanDataset.h index 927baaf..4494156 100644 --- a/flashlight/fl/dataset/SpanDataset.h +++ b/flashlight/fl/dataset/SpanDataset.h @@ -17,52 +17,53 @@ namespace fl { * A view into an underlying dataset with an offset and optional bounded length. * * The size of the `SpanDataset` is either specified for the size of the input - dataset + dataset * accounting for the offset. 
* * We have, for example `SpanDataset(ds, 13).get(i) == ds.get(13 + i)` * * Example: - \code{.cpp} - // Make a datasets - auto makeDataset = []() { + \code{.cpp} + // Make a datasets + auto makeDataset = []() { auto tensor = fl::rand({5, 4, 10}); std::vector fields{tensor}; return std::make_shared(fields); - }; - auto ds = makeDataset(); + }; + auto ds = makeDataset(); - // Create two spanned datasets - SpanDataset spands1(ds, 2); - SpanDataset spands2(ds, 0, 2); - std::cout << spands1.size() << "\n"; // 8 - std::cout << spands2.size() << "\n"; // 2 - std::cout << allClose(spands1.get(3)[0], ds->get(5)[0]) << "\n"; // 1 - std::cout << allClose(spands2.get(1)[1], ds->get(1)[0]) << "\n"; // 1 - \endcode + // Create two spanned datasets + SpanDataset spands1(ds, 2); + SpanDataset spands2(ds, 0, 2); + std::cout << spands1.size() << "\n"; // 8 + std::cout << spands2.size() << "\n"; // 2 + std::cout << allClose(spands1.get(3)[0], ds->get(5)[0]) << "\n"; // 1 + std::cout << allClose(spands2.get(1)[1], ds->get(1)[0]) << "\n"; // 1 + \endcode */ class FL_API SpanDataset : public Dataset { - public: - /** - * Creates a SpanDataset. - * @param[in] dataset The underlying dataset. - * @param[in] offset The starting index of the new dataset relative to the - * underlying dataset. - * @param[in] length The size of the new dataset (if -1, uses previous size - * minus the offset) - */ - explicit SpanDataset( - std::shared_ptr dataset, - const int64_t offset, - const int64_t length = -1); +public: + /** + * Creates a SpanDataset. + * @param[in] dataset The underlying dataset. + * @param[in] offset The starting index of the new dataset relative to the + * underlying dataset. 
+ * @param[in] length The size of the new dataset (if -1, uses previous size + * minus the offset) + */ + explicit SpanDataset( + std::shared_ptr dataset, + const int64_t offset, + const int64_t length = -1 + ); - int64_t size() const override; + int64_t size() const override; - std::vector get(const int64_t idx) const override; + std::vector get(const int64_t idx) const override; - private: - std::shared_ptr dataset_; - int64_t offset_; - int64_t size_; +private: + std::shared_ptr dataset_; + int64_t offset_; + int64_t size_; }; } // namespace fl diff --git a/flashlight/fl/dataset/TensorDataset.cpp b/flashlight/fl/dataset/TensorDataset.cpp index b06cb9c..58576a1 100644 --- a/flashlight/fl/dataset/TensorDataset.cpp +++ b/flashlight/fl/dataset/TensorDataset.cpp @@ -14,41 +14,41 @@ namespace fl { -TensorDataset::TensorDataset(const std::vector& dataTensors) - : dataTensors_(dataTensors), size_(0) { - if (dataTensors_.empty()) { - throw std::invalid_argument("no tensors passed to TensorDataset"); - } - - for (const auto& tensor : dataTensors_) { - auto ndims = tensor.ndim(); - if (ndims == 0) { - throw std::invalid_argument("tensor for TensorDataset can't be empty"); +TensorDataset::TensorDataset(const std::vector& dataTensors) : dataTensors_(dataTensors), + size_(0) { + if(dataTensors_.empty()) { + throw std::invalid_argument("no tensors passed to TensorDataset"); } - auto lastdim = ndims - 1; - int64_t cursz = tensor.dim(lastdim); - size_ = std::max(size_, cursz); - } + for(const auto& tensor : dataTensors_) { + auto ndims = tensor.ndim(); + if(ndims == 0) { + throw std::invalid_argument("tensor for TensorDataset can't be empty"); + } + + auto lastdim = ndims - 1; + int64_t cursz = tensor.dim(lastdim); + size_ = std::max(size_, cursz); + } } std::vector TensorDataset::get(const int64_t idx) const { - checkIndexBounds(idx); - std::vector result(dataTensors_.size()); - for (int64_t i = 0; i < dataTensors_.size(); ++i) { - auto& tensor = dataTensors_[i]; - - 
std::vector sel(tensor.ndim(), fl::span); - auto lastdim = tensor.ndim() - 1; - if (idx < tensor.dim(lastdim)) { - sel[lastdim] = idx; - result[i] = tensor(sel); + checkIndexBounds(idx); + std::vector result(dataTensors_.size()); + for(int64_t i = 0; i < dataTensors_.size(); ++i) { + auto& tensor = dataTensors_[i]; + + std::vector sel(tensor.ndim(), fl::span); + auto lastdim = tensor.ndim() - 1; + if(idx < tensor.dim(lastdim)) { + sel[lastdim] = idx; + result[i] = tensor(sel); + } } - } - return result; + return result; } int64_t TensorDataset::size() const { - return size_; + return size_; } } // namespace fl diff --git a/flashlight/fl/dataset/TensorDataset.h b/flashlight/fl/dataset/TensorDataset.h index 1b1e5bb..05826ac 100644 --- a/flashlight/fl/dataset/TensorDataset.h +++ b/flashlight/fl/dataset/TensorDataset.h @@ -18,31 +18,31 @@ namespace fl { * Hence, it must be the same across all `int64_t`s in the input. * * Example: - \code{.cpp} - Tensor tensor1 = fl::rand({5, 4, 10}); - Tensor tensor2 = fl::rand({7, 10}); - TensorDataset ds({tensor1, tensor2}); - - std::cout << ds.size() << "\n"; // 10 - std::cout << ds.get(0)[0].shape() << "\n"; // 5 4 - std::cout << ds.get(0)[1].shape() << "\n"; // 7 1 - \endcode + \code{.cpp} + Tensor tensor1 = fl::rand({5, 4, 10}); + Tensor tensor2 = fl::rand({7, 10}); + TensorDataset ds({tensor1, tensor2}); + + std::cout << ds.size() << "\n"; // 10 + std::cout << ds.get(0)[0].shape() << "\n"; // 5 4 + std::cout << ds.get(0)[1].shape() << "\n"; // 7 1 + \endcode */ class FL_API TensorDataset : public Dataset { - public: - /** - * Creates a `TensorDataset` by unpacking the input tensors. - * @param[in] datatensors A vector of tensors, which will be - * unpacked along their last non-singleton dimensions. - */ - explicit TensorDataset(const std::vector& datatensors); +public: + /** + * Creates a `TensorDataset` by unpacking the input tensors. 
+ * @param[in] datatensors A vector of tensors, which will be + * unpacked along their last non-singleton dimensions. + */ + explicit TensorDataset(const std::vector& datatensors); - int64_t size() const override; + int64_t size() const override; - std::vector get(const int64_t idx) const override; + std::vector get(const int64_t idx) const override; - private: - std::vector dataTensors_; - int64_t size_{0}; +private: + std::vector dataTensors_; + int64_t size_{0}; }; } // namespace fl diff --git a/flashlight/fl/dataset/TransformDataset.cpp b/flashlight/fl/dataset/TransformDataset.cpp index f572345..ae35827 100644 --- a/flashlight/fl/dataset/TransformDataset.cpp +++ b/flashlight/fl/dataset/TransformDataset.cpp @@ -13,28 +13,29 @@ namespace fl { TransformDataset::TransformDataset( std::shared_ptr dataset, - const std::vector& transformfns) - : dataset_(dataset), transformFns_(transformfns) { - if (!dataset_) { - throw std::invalid_argument("dataset to be transformed is null"); - } + const std::vector& transformfns +) : dataset_(dataset), + transformFns_(transformfns) { + if(!dataset_) { + throw std::invalid_argument("dataset to be transformed is null"); + } } std::vector TransformDataset::get(const int64_t idx) const { - checkIndexBounds(idx); + checkIndexBounds(idx); - auto result = dataset_->get(idx); + auto result = dataset_->get(idx); - for (int64_t i = 0; i < result.size(); ++i) { - if (i >= transformFns_.size() || !transformFns_[i]) { - continue; + for(int64_t i = 0; i < result.size(); ++i) { + if(i >= transformFns_.size() || !transformFns_[i]) { + continue; + } + result[i] = transformFns_[i](result[i]); } - result[i] = transformFns_[i](result[i]); - } - return result; + return result; } int64_t TransformDataset::size() const { - return dataset_->size(); + return dataset_->size(); } } // namespace fl diff --git a/flashlight/fl/dataset/TransformDataset.h b/flashlight/fl/dataset/TransformDataset.h index 91a711e..bc07f6e 100644 --- 
a/flashlight/fl/dataset/TransformDataset.h +++ b/flashlight/fl/dataset/TransformDataset.h @@ -19,38 +19,39 @@ namespace fl { * The dataset size remains unchanged. * * Example: - \code{.cpp} - // Make a dataset with 10 samples - auto tensor = fl::rand({5, 4, 10}); - std::vector fields{tensor}; - auto ds = std::make_shared(fields); - - // Transform it - auto negate = [](const Tensor& arr) { return -arr; }; - TransformDataset transformds(ds, {negate}); - std::cout << transformds.size() << "\n"; // 10 - std::cout << allClose(transformds.get(5)[0], -ds->get(5)[0]) << "\n"; // 1 - \endcode + \code{.cpp} + // Make a dataset with 10 samples + auto tensor = fl::rand({5, 4, 10}); + std::vector fields{tensor}; + auto ds = std::make_shared(fields); + + // Transform it + auto negate = [](const Tensor& arr) { return -arr; }; + TransformDataset transformds(ds, {negate}); + std::cout << transformds.size() << "\n"; // 10 + std::cout << allClose(transformds.get(5)[0], -ds->get(5)[0]) << "\n"; // 1 + \endcode */ class FL_API TransformDataset : public Dataset { - public: - /** - * Creates a `TransformDataset`. - * @param[in] dataset The underlying dataset. - * @param[in] transformfns The mappings used to transform the values. - * If a `TransformFunction` is null then the corresponding value is not - * transformed. - */ - TransformDataset( - std::shared_ptr dataset, - const std::vector& transformfns); - - int64_t size() const override; - - std::vector get(const int64_t idx) const override; - - private: - std::shared_ptr dataset_; - const std::vector transformFns_; +public: + /** + * Creates a `TransformDataset`. + * @param[in] dataset The underlying dataset. + * @param[in] transformfns The mappings used to transform the values. + * If a `TransformFunction` is null then the corresponding value is not + * transformed. 
+ */ + TransformDataset( + std::shared_ptr dataset, + const std::vector& transformfns + ); + + int64_t size() const override; + + std::vector get(const int64_t idx) const override; + +private: + std::shared_ptr dataset_; + const std::vector transformFns_; }; } // namespace fl diff --git a/flashlight/fl/dataset/Utils.cpp b/flashlight/fl/dataset/Utils.cpp index ac736ff..8e298b0 100644 --- a/flashlight/fl/dataset/Utils.cpp +++ b/flashlight/fl/dataset/Utils.cpp @@ -19,171 +19,182 @@ std::vector partitionByRoundRobin( int64_t partitionId, int64_t numPartitions, int64_t batchSz /* = 1 */, - bool allowEmpty /* = false */) { - if (partitionId < 0 || partitionId >= numPartitions) { - throw std::invalid_argument( - "invalid partitionId, numPartitions for partitionByRoundRobin"); - } - int64_t nSamplesPerGlobalBatch = numPartitions * batchSz; - int64_t nGlobalBatches = numSamples / nSamplesPerGlobalBatch; - bool includeLast = (numSamples % nSamplesPerGlobalBatch) >= numPartitions; - if (allowEmpty && (numSamples % nSamplesPerGlobalBatch) > 0) { - includeLast = true; - } - if (includeLast) { - ++nGlobalBatches; - } - std::vector outSamples; - outSamples.reserve(nGlobalBatches * batchSz); + bool allowEmpty /* = false */ +) { + if(partitionId < 0 || partitionId >= numPartitions) { + throw std::invalid_argument( + "invalid partitionId, numPartitions for partitionByRoundRobin" + ); + } + int64_t nSamplesPerGlobalBatch = numPartitions * batchSz; + int64_t nGlobalBatches = numSamples / nSamplesPerGlobalBatch; + bool includeLast = (numSamples % nSamplesPerGlobalBatch) >= numPartitions; + if(allowEmpty && (numSamples % nSamplesPerGlobalBatch) > 0) { + includeLast = true; + } + if(includeLast) { + ++nGlobalBatches; + } + std::vector outSamples; + outSamples.reserve(nGlobalBatches * batchSz); - for (size_t i = 0; i < nGlobalBatches; i++) { - auto offset = i * nSamplesPerGlobalBatch; - int64_t nCurSamples; // num samples in current batch - if (includeLast && (i == nGlobalBatches - 1)) { 
- nCurSamples = - (numSamples - offset) / numPartitions; // min samples per proc - int64_t remaining = (numSamples - offset) % numPartitions; - offset += nCurSamples * partitionId; - if (partitionId < remaining) { - nCurSamples += 1; - } - offset += std::min(partitionId, remaining); - } else { - offset += batchSz * partitionId; - nCurSamples = batchSz; - } - for (int64_t b = 0; b < nCurSamples; ++b) { - outSamples.emplace_back(b + offset); - } - } - return outSamples; + for(size_t i = 0; i < nGlobalBatches; i++) { + auto offset = i * nSamplesPerGlobalBatch; + int64_t nCurSamples; // num samples in current batch + if(includeLast && (i == nGlobalBatches - 1)) { + nCurSamples = + (numSamples - offset) / numPartitions; // min samples per proc + int64_t remaining = (numSamples - offset) % numPartitions; + offset += nCurSamples * partitionId; + if(partitionId < remaining) { + nCurSamples += 1; + } + offset += std::min(partitionId, remaining); + } else { + offset += batchSz * partitionId; + nCurSamples = batchSz; + } + for(int64_t b = 0; b < nCurSamples; ++b) { + outSamples.emplace_back(b + offset); + } + } + return outSamples; } -std::pair, std::vector> -dynamicPartitionByRoundRobin( +std::pair, std::vector> dynamicPartitionByRoundRobin( const std::vector& samplesSize, int64_t partitionId, int64_t numPartitions, int64_t maxSizePerBatch, - bool allowEmpty /* = false */) { - if (partitionId < 0 || partitionId >= numPartitions) { - throw std::invalid_argument( - "[dynamicPartitionByRoundRobin] invalid partitionId, numPartitions"); - } - std::vector batchSizes, batchOffsets; - int64_t sampleIdx = 0, batchStartSampleIdx = 0; - float maxSampleLen = 0; - while (sampleIdx < samplesSize.size()) { - if (samplesSize[sampleIdx] > maxSizePerBatch) { - throw std::invalid_argument( - "[dynamicPartitionByRoundRobin] invalid samples length: each sample " - "should have size <= maxSizePerBatch, either filter data or set larger maxSizePerBatch. 
" - "maxSizePerBatch were set to " + - std::to_string(maxSizePerBatch) + " sample size is " + - std::to_string(samplesSize[sampleIdx])); - } - float maxSampleLenOld = maxSampleLen; - maxSampleLen = std::max(maxSampleLen, samplesSize[sampleIdx]); - if ((sampleIdx - batchStartSampleIdx + 1) * maxSampleLen > - maxSizePerBatch) { - if (maxSampleLenOld * (sampleIdx - batchStartSampleIdx) > - maxSizePerBatch) { + bool allowEmpty /* = false */ +) { + if(partitionId < 0 || partitionId >= numPartitions) { throw std::invalid_argument( - "dynamicPartitionByRoundRobin is doing wrong packing"); - } - batchSizes.push_back(sampleIdx - batchStartSampleIdx); - batchOffsets.push_back(batchStartSampleIdx); - batchStartSampleIdx = sampleIdx; - maxSampleLen = samplesSize[sampleIdx]; - } else { - sampleIdx++; - } - } - // process last batch with sampleIdx == numSamples, batchStartSampleIdx < - // numSamples - if ((sampleIdx - batchStartSampleIdx) * maxSampleLen < maxSizePerBatch) { - batchSizes.push_back(sampleIdx - batchStartSampleIdx); - batchOffsets.push_back(batchStartSampleIdx); - } + "[dynamicPartitionByRoundRobin] invalid partitionId, numPartitions" + ); + } + std::vector batchSizes, batchOffsets; + int64_t sampleIdx = 0, batchStartSampleIdx = 0; + float maxSampleLen = 0; + while(sampleIdx < samplesSize.size()) { + if(samplesSize[sampleIdx] > maxSizePerBatch) { + throw std::invalid_argument( + "[dynamicPartitionByRoundRobin] invalid samples length: each sample " + "should have size <= maxSizePerBatch, either filter data or set larger maxSizePerBatch. 
" + "maxSizePerBatch were set to " + + std::to_string(maxSizePerBatch) + " sample size is " + + std::to_string(samplesSize[sampleIdx]) + ); + } + float maxSampleLenOld = maxSampleLen; + maxSampleLen = std::max(maxSampleLen, samplesSize[sampleIdx]); + if( + (sampleIdx - batchStartSampleIdx + 1) * maxSampleLen + > maxSizePerBatch + ) { + if( + maxSampleLenOld * (sampleIdx - batchStartSampleIdx) + > maxSizePerBatch + ) { + throw std::invalid_argument( + "dynamicPartitionByRoundRobin is doing wrong packing" + ); + } + batchSizes.push_back(sampleIdx - batchStartSampleIdx); + batchOffsets.push_back(batchStartSampleIdx); + batchStartSampleIdx = sampleIdx; + maxSampleLen = samplesSize[sampleIdx]; + } else { + sampleIdx++; + } + } + // process last batch with sampleIdx == numSamples, batchStartSampleIdx < + // numSamples + if((sampleIdx - batchStartSampleIdx) * maxSampleLen < maxSizePerBatch) { + batchSizes.push_back(sampleIdx - batchStartSampleIdx); + batchOffsets.push_back(batchStartSampleIdx); + } - int64_t nGlobalBatches = batchSizes.size() / numPartitions; - if (allowEmpty && (batchSizes.size() % numPartitions) > 0) { - ++nGlobalBatches; - } - std::vector outSamples, outBatchSizes; - for (size_t i = 0; i < nGlobalBatches; i++) { - int index = i * numPartitions + partitionId; - if (index < batchSizes.size()) { - outBatchSizes.emplace_back(batchSizes[index]); - for (int64_t b = 0; b < batchSizes[index]; ++b) { - outSamples.emplace_back(b + batchOffsets[index]); - } - } - } - return {outSamples, outBatchSizes}; + int64_t nGlobalBatches = batchSizes.size() / numPartitions; + if(allowEmpty && (batchSizes.size() % numPartitions) > 0) { + ++nGlobalBatches; + } + std::vector outSamples, outBatchSizes; + for(size_t i = 0; i < nGlobalBatches; i++) { + int index = i * numPartitions + partitionId; + if(index < batchSizes.size()) { + outBatchSizes.emplace_back(batchSizes[index]); + for(int64_t b = 0; b < batchSizes[index]; ++b) { + outSamples.emplace_back(b + batchOffsets[index]); 
+ } + } + } + return {outSamples, outBatchSizes}; } std::vector makeBatchFromRange( std::shared_ptr dataset, std::vector batchFns, int64_t start, - int64_t end) { - std::vector> buffer; - for (int64_t batchidx = start; batchidx < end; ++batchidx) { - auto fds = dataset->get(batchidx); - if (buffer.size() < fds.size()) { - buffer.resize(fds.size()); - } - for (int64_t i = 0; i < fds.size(); ++i) { - buffer[i].emplace_back(fds[i]); - } - } - std::vector result(buffer.size()); - for (int64_t i = 0; i < buffer.size(); ++i) { - result[i] = - makeBatch(buffer[i], (i < batchFns.size()) ? batchFns[i] : nullptr); - } - return result; + int64_t end +) { + std::vector> buffer; + for(int64_t batchidx = start; batchidx < end; ++batchidx) { + auto fds = dataset->get(batchidx); + if(buffer.size() < fds.size()) { + buffer.resize(fds.size()); + } + for(int64_t i = 0; i < fds.size(); ++i) { + buffer[i].emplace_back(fds[i]); + } + } + std::vector result(buffer.size()); + for(int64_t i = 0; i < buffer.size(); ++i) { + result[i] = + makeBatch(buffer[i], (i < batchFns.size()) ? batchFns[i] : nullptr); + } + return result; } Tensor makeBatch( const std::vector& data, - const Dataset::BatchFunction& batchFn) { - if (batchFn) { - return batchFn(data); - } - // Using default batching function - if (data.empty()) { - return Tensor(); - } - auto& dims = data[0].shape(); + const Dataset::BatchFunction& batchFn +) { + if(batchFn) { + return batchFn(data); + } + // Using default batching function + if(data.empty()) { + return Tensor(); + } + auto& dims = data[0].shape(); - for (const auto& d : data) { - if (d.shape() != dims) { - throw std::invalid_argument("dimension mismatch while batching dataset"); + for(const auto& d : data) { + if(d.shape() != dims) { + throw std::invalid_argument("dimension mismatch while batching dataset"); + } } - } - int ndims = (data[0].elements() > 1) ? dims.ndim() : 0; + int ndims = (data[0].elements() > 1) ? 
dims.ndim() : 0; - // TODO: expand this to > 4 given fl::Tensor - should work out of the box - // by just removing this check? Possibly also change to ndims >= dims.ndims() - if (ndims >= 4) { - throw std::invalid_argument("# of dims must be < ndim - 1 for batching"); - } - // Dimensions of the batched tensor - std::vector batchDims = dims.get(); - if (ndims + 1 > batchDims.size()) { - batchDims.push_back(1); // placeholder dim - } - batchDims[ndims] = data.size(); - auto batcharr = Tensor(Shape(batchDims), data[0].type()); + // TODO: expand this to > 4 given fl::Tensor - should work out of the box + // by just removing this check? Possibly also change to ndims >= dims.ndims() + if(ndims >= 4) { + throw std::invalid_argument("# of dims must be < ndim - 1 for batching"); + } + // Dimensions of the batched tensor + std::vector batchDims = dims.get(); + if(ndims + 1 > batchDims.size()) { + batchDims.push_back(1); // placeholder dim + } + batchDims[ndims] = data.size(); + auto batcharr = Tensor(Shape(batchDims), data[0].type()); - for (size_t i = 0; i < data.size(); ++i) { - std::vector sel(batcharr.ndim(), fl::span); - sel[ndims] = i; - batcharr(sel) = data[i]; - } - return batcharr; + for(size_t i = 0; i < data.size(); ++i) { + std::vector sel(batcharr.ndim(), fl::span); + sel[ndims] = i; + batcharr(sel) = data[i]; + } + return batcharr; } } // namespace fl diff --git a/flashlight/fl/dataset/Utils.h b/flashlight/fl/dataset/Utils.h index 1175478..ca6932b 100644 --- a/flashlight/fl/dataset/Utils.h +++ b/flashlight/fl/dataset/Utils.h @@ -31,7 +31,8 @@ FL_API std::vector partitionByRoundRobin( int64_t partitionId, int64_t numPartitions, int64_t batchSz = 1, - bool allowEmpty = false); + bool allowEmpty = false +); /** * Partitions the samples in a round-robin manner and return ids of the samples @@ -42,13 +43,13 @@ FL_API std::vector partitionByRoundRobin( * @param numPartitions total partitions * @param maxTokens total number of tokens in the batch */ -FL_API std::pair, 
std::vector> -dynamicPartitionByRoundRobin( +FL_API std::pair, std::vector> dynamicPartitionByRoundRobin( const std::vector& samplesSize, int64_t partitionId, int64_t numPartitions, int64_t maxSizePerBatch, - bool allowEmpty = false); + bool allowEmpty = false +); /** * Make batch by applying batchFn to the data @@ -71,7 +72,8 @@ FL_API std::vector makeBatchFromRange( std::shared_ptr dataset, std::vector batchFns, int64_t start, - int64_t end); + int64_t end +); /** @} */ diff --git a/flashlight/fl/distributed/DistributedApi.cpp b/flashlight/fl/distributed/DistributedApi.cpp index 4d29729..a32601e 100644 --- a/flashlight/fl/distributed/DistributedApi.cpp +++ b/flashlight/fl/distributed/DistributedApi.cpp @@ -13,53 +13,54 @@ namespace fl { FL_API bool isDistributedInit() { - return detail::DistributedInfo::getInstance().isInitialized_; + return detail::DistributedInfo::getInstance().isInitialized_; } FL_API DistributedBackend distributedBackend() { - return detail::DistributedInfo::getInstance().backend_; + return detail::DistributedInfo::getInstance().backend_; } -FL_API void -allReduce(Variable& var, double scale /* = 1.0 */, bool async /* = false */) { - if (getWorldSize() > 1) { - allReduce(var.tensor(), async); - } - var.tensor() *= scale; +FL_API void allReduce(Variable& var, double scale /* = 1.0 */, bool async /* = false */) { + if(getWorldSize() > 1) { + allReduce(var.tensor(), async); + } + var.tensor() *= scale; } FL_API void allReduceMultiple( std::vector vars, double scale /* = 1.0 */, bool async /* = false */, - bool contiguous /* = false */) { - // return a vector of pointers to avoid copying - std::vector arrs; - for (auto& var : vars) { - arrs.push_back(&var.tensor()); - } - if (getWorldSize() > 1) { - allReduceMultiple(arrs, async, contiguous); - } - for (auto& var : vars) { - var.tensor() *= scale; - } + bool contiguous /* = false */ +) { + // return a vector of pointers to avoid copying + std::vector arrs; + for(auto& var : vars) { + 
arrs.push_back(&var.tensor()); + } + if(getWorldSize() > 1) { + allReduceMultiple(arrs, async, contiguous); + } + for(auto& var : vars) { + var.tensor() *= scale; + } } FL_API void barrier() { - auto tensor = Tensor::fromVector({0}); - allReduce(tensor, false); + auto tensor = Tensor::fromVector({0}); + allReduce(tensor, false); - // This hack is to make sure `tensor` will not be optimized away by a - // JIT during allreduce(). - fl::sum(tensor).asScalar(); + // This hack is to make sure `tensor` will not be optimized away by a + // JIT during allreduce(). + fl::sum(tensor).asScalar(); } namespace detail { -/* static */ DistributedInfo& DistributedInfo::getInstance() { - static DistributedInfo dinfo; - return dinfo; -} +/* static */ + DistributedInfo& DistributedInfo::getInstance() { + static DistributedInfo dinfo; + return dinfo; + } } // namespace detail } // namespace fl diff --git a/flashlight/fl/distributed/DistributedApi.h b/flashlight/fl/distributed/DistributedApi.h index ed2cb9e..71a8a72 100644 --- a/flashlight/fl/distributed/DistributedApi.h +++ b/flashlight/fl/distributed/DistributedApi.h @@ -99,7 +99,8 @@ FL_API void allReduceMultiple( std::vector vars, double scale = 1.0, bool async = false, - bool contiguous = false); + bool contiguous = false +); /** * Synchronizes a vector of pointers to arrays with allreduce. 
@@ -116,7 +117,8 @@ FL_API void allReduceMultiple( FL_API void allReduceMultiple( std::vector arrs, bool async = false, - bool contiguous = false); + bool contiguous = false +); /** * Synchronizes operations in the Flashlight compute stream with operations in @@ -138,17 +140,17 @@ FL_API void barrier(); /** @} */ namespace detail { -class DistributedInfo { - public: - static DistributedInfo& getInstance(); + class DistributedInfo { + public: + static DistributedInfo& getInstance(); - bool isInitialized_ = false; - DistributedInit initMethod_; - DistributedBackend backend_; + bool isInitialized_ = false; + DistributedInit initMethod_; + DistributedBackend backend_; - private: - DistributedInfo() = default; -}; + private: + DistributedInfo() = default; + }; } // namespace detail } // namespace fl diff --git a/flashlight/fl/distributed/FileStore.cpp b/flashlight/fl/distributed/FileStore.cpp index 9adc678..8a24ccc 100644 --- a/flashlight/fl/distributed/FileStore.cpp +++ b/flashlight/fl/distributed/FileStore.cpp @@ -15,8 +15,8 @@ namespace { static std::string encodeName(const std::string& name) { - static std::hash hashFn; - return std::to_string(hashFn(name)); + static std::hash hashFn; + return std::to_string(hashFn(name)); } } // namespace @@ -24,89 +24,93 @@ static std::string encodeName(const std::string& name) { namespace fl::detail { void FileStore::set(const std::string& key, const std::vector& data) { - fs::path tmp = tmpPath(key); - fs::path path = objectPath(key); - - { - // Fail if the key already exists. This implementation is not race free. - // A race free solution would need to atomically create the file 'path' - // using an API that fails if the file exists (not provided by STL). If - // created successfully, rename the temp file as below. 
- std::ifstream ifs(path); - if (ifs.is_open()) { - throw std::runtime_error( - "FileStore set: file already exists: " + path.string()); + fs::path tmp = tmpPath(key); + fs::path path = objectPath(key); + + { + // Fail if the key already exists. This implementation is not race free. + // A race free solution would need to atomically create the file 'path' + // using an API that fails if the file exists (not provided by STL). If + // created successfully, rename the temp file as below. + std::ifstream ifs(path); + if(ifs.is_open()) { + throw std::runtime_error( + "FileStore set: file already exists: " + path.string() + ); + } } - } - { - std::ofstream ofs(tmp, std::ios::out | std::ios::trunc); - if (!ofs.is_open()) { - throw std::runtime_error( - "FileStore set: file create failed: " + tmp.string()); + { + std::ofstream ofs(tmp, std::ios::out | std::ios::trunc); + if(!ofs.is_open()) { + throw std::runtime_error( + "FileStore set: file create failed: " + tmp.string() + ); + } + ofs.write(data.data(), data.size()); } - ofs.write(data.data(), data.size()); - } - // Atomically move result to final location - fs::rename(tmp, path); + // Atomically move result to final location + fs::rename(tmp, path); } std::vector FileStore::get(const std::string& key) { - fs::path path = objectPath(key); - std::vector result; - - // Block until key is set - wait(key); - - std::ifstream ifs(path, std::ios::in); - if (!ifs) { - throw std::runtime_error( - "FileStore get: file open failed: " + path.string()); - } - - ifs.seekg(0, std::ios::end); - size_t n = ifs.tellg(); - if (n == 0) { - throw std::runtime_error("FileStore get: file is empty: " + path.string()); - } - result.resize(n); - ifs.seekg(0); - ifs.read(result.data(), n); - return result; + fs::path path = objectPath(key); + std::vector result; + + // Block until key is set + wait(key); + + std::ifstream ifs(path, std::ios::in); + if(!ifs) { + throw std::runtime_error( + "FileStore get: file open failed: " + path.string() + ); + 
} + + ifs.seekg(0, std::ios::end); + size_t n = ifs.tellg(); + if(n == 0) { + throw std::runtime_error("FileStore get: file is empty: " + path.string()); + } + result.resize(n); + ifs.seekg(0); + ifs.read(result.data(), n); + return result; } void FileStore::clear(const std::string& key) { - fs::path path = objectPath(key); - fs::remove(path); + fs::path path = objectPath(key); + fs::remove(path); } bool FileStore::check(const std::string& key) { - fs::path path = objectPath(key); - return fs::exists(path); + fs::path path = objectPath(key); + return fs::exists(path); } void FileStore::wait(const std::string& key) { - // Not using inotify because it doesn't work on many - // shared filesystems (such as NFS). - const auto start = std::chrono::steady_clock::now(); - while (!check(key)) { - const auto elapsed = std::chrono::duration_cast( - std::chrono::steady_clock::now() - start); - if (elapsed > FileStore::kDefaultTimeout) { - throw std::runtime_error("FileStore timed out for key: " + key); + // Not using inotify because it doesn't work on many + // shared filesystems (such as NFS). + const auto start = std::chrono::steady_clock::now(); + while(!check(key)) { + const auto elapsed = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start + ); + if(elapsed > FileStore::kDefaultTimeout) { + throw std::runtime_error("FileStore timed out for key: " + key); + } + /* sleep override */ + std::this_thread::sleep_for(std::chrono::milliseconds(10)); } - /* sleep override */ - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } } fs::path FileStore::tmpPath(const std::string& name) { - return basePath_ / fs::path("." + encodeName(name)); + return basePath_ / fs::path("." 
+ encodeName(name)); } fs::path FileStore::objectPath(const std::string& name) { - return basePath_ / fs::path(encodeName(name)); + return basePath_ / fs::path(encodeName(name)); } } // namespace fl diff --git a/flashlight/fl/distributed/FileStore.h b/flashlight/fl/distributed/FileStore.h index f331a94..f707216 100644 --- a/flashlight/fl/distributed/FileStore.h +++ b/flashlight/fl/distributed/FileStore.h @@ -20,23 +20,23 @@ namespace detail { // Inspired from // https://github.com/facebookincubator/gloo/blob/master/gloo/rendezvous/file_store.h -class FL_API FileStore { - public: - static constexpr std::chrono::milliseconds kDefaultTimeout = - std::chrono::seconds(60 * 2); - explicit FileStore(const fs::path& path) : basePath_(path) {} - std::vector get(const std::string& key); - void set(const std::string& key, const std::vector& data); - void clear(const std::string& key); - - private: - fs::path basePath_; - - void wait(const std::string& key); - bool check(const std::string& key); - fs::path objectPath(const std::string& name); - fs::path tmpPath(const std::string& name); -}; + class FL_API FileStore { + public: + static constexpr std::chrono::milliseconds kDefaultTimeout = + std::chrono::seconds(60 * 2); + explicit FileStore(const fs::path& path) : basePath_(path) {} + std::vector get(const std::string& key); + void set(const std::string& key, const std::vector& data); + void clear(const std::string& key); + + private: + fs::path basePath_; + + void wait(const std::string& key); + bool check(const std::string& key); + fs::path objectPath(const std::string& name); + fs::path tmpPath(const std::string& name); + }; } // namespace detail } // namespace fl diff --git a/flashlight/fl/distributed/LRUCache.h b/flashlight/fl/distributed/LRUCache.h index 2dcf355..f13ad51 100644 --- a/flashlight/fl/distributed/LRUCache.h +++ b/flashlight/fl/distributed/LRUCache.h @@ -16,65 +16,65 @@ namespace detail { // The following section is taken from // 
https://github.com/fairinternal/FAIR_rush/blob/master/cpid/distributed.h -template -class LRUCache { - // store keys of cache - std::list dq_; - - // store references of key in cache - std::unordered_map< - K, - std::pair::iterator, std::unique_ptr>> - map_; - - size_t csize_; // maximum capacity of cache - - public: - explicit LRUCache(int n) : csize_(n) {} - - inline V* put(K k, std::unique_ptr&& v) { - if (map_.find(k) == map_.end()) { - // Not in cache, cache size too big - if (dq_.size() == csize_) { - map_.erase(dq_.back()); - dq_.pop_back(); - } - } else { - dq_.erase(map_[k].first); + template + class LRUCache { + // store keys of cache + std::list dq_; + + // store references of key in cache + std::unordered_map< + K, + std::pair::iterator, std::unique_ptr>> + map_; + + size_t csize_; // maximum capacity of cache + + public: + explicit LRUCache(int n) : csize_(n) {} + + inline V* put(K k, std::unique_ptr&& v) { + if(map_.find(k) == map_.end()) { + // Not in cache, cache size too big + if(dq_.size() == csize_) { + map_.erase(dq_.back()); + dq_.pop_back(); + } + } else { + dq_.erase(map_[k].first); + } + + dq_.push_front(k); + map_[k] = std::make_pair(dq_.begin(), std::move(v)); + return map_[k].second.get(); + } + + inline V* get(K const& k) { + if(map_.find(k) == map_.end()) { + return nullptr; + } else { + // Move list node to front + auto& it = map_[k].first; + dq_.splice(dq_.begin(), dq_, it); + return map_[k].second.get(); + } + } + }; + + inline void hashKeyHelper(std::stringstream&) {} + + template + inline void hashKeyHelper(std::stringstream& ss, const T& x, Args&&... 
params) { + ss << " " << x; + hashKeyHelper(ss, std::forward(params)...); } - dq_.push_front(k); - map_[k] = std::make_pair(dq_.begin(), std::move(v)); - return map_[k].second.get(); - } - - inline V* get(K const& k) { - if (map_.find(k) == map_.end()) { - return nullptr; - } else { - // Move list node to front - auto& it = map_[k].first; - dq_.splice(dq_.begin(), dq_, it); - return map_[k].second.get(); + template + inline std::string makeHashKey(T* ptr, Args&&... params) { + std::stringstream ss; + ss << typeid(T).name() << " " << reinterpret_cast(ptr); + hashKeyHelper(ss, std::forward(params)...); + return ss.str(); } - } -}; - -inline void hashKeyHelper(std::stringstream&) {} - -template -inline void hashKeyHelper(std::stringstream& ss, const T& x, Args&&... params) { - ss << " " << x; - hashKeyHelper(ss, std::forward(params)...); -} - -template -inline std::string makeHashKey(T* ptr, Args&&... params) { - std::stringstream ss; - ss << typeid(T).name() << " " << reinterpret_cast(ptr); - hashKeyHelper(ss, std::forward(params)...); - return ss.str(); -} } // namespace detail diff --git a/flashlight/fl/distributed/backend/cpu/DistributedBackend.cpp b/flashlight/fl/distributed/backend/cpu/DistributedBackend.cpp index dc7d1a1..e3ebb6d 100644 --- a/flashlight/fl/distributed/backend/cpu/DistributedBackend.cpp +++ b/flashlight/fl/distributed/backend/cpu/DistributedBackend.cpp @@ -45,133 +45,148 @@ namespace fl { namespace detail { -std::shared_ptr globalContext() { - return glooContext_; -} - -template -inline void allreduceGloo(T* ptr, size_t s) { - auto key = detail::makeHashKey(ptr, s, "allreduceCpu"); - auto algorithm = glooCache_.get(key); - if (algorithm == nullptr) { - using Allreduce = gloo::AllreduceHalvingDoubling; - algorithm = glooCache_.put( - key, - std::make_unique( - globalContext(), - std::vector({ptr}), - s, - gloo::ReductionFunction::sum)); - } - algorithm->run(); -} + std::shared_ptr globalContext() { + return glooContext_; + } + + template + inline 
void allreduceGloo(T* ptr, size_t s) { + auto key = detail::makeHashKey(ptr, s, "allreduceCpu"); + auto algorithm = glooCache_.get(key); + if(algorithm == nullptr) { + using Allreduce = gloo::AllreduceHalvingDoubling; + algorithm = glooCache_.put( + key, + std::make_unique( + globalContext(), + std::vector({ptr}), + s, + gloo::ReductionFunction::sum + ) + ); + } + algorithm->run(); + } } // namespace detail void distributedInit( DistributedInit initMethod, int /* worldRank */, int /* worldSize */, - const std::unordered_map& /* params = {} */) { - if (isDistributedInit()) { - std::cerr << "warning: fl::distributedInit() called more than once\n"; - return; - } - - if (initMethod != DistributedInit::MPI) { - throw std::runtime_error( - "unsupported distributed init method for gloo backend"); - } - - // using MPI - if (glooContext_ != nullptr) { - return; - } - // TODO: ibverbs support. - auto glooDev = gloo::transport::tcp::CreateDevice(""); - - // Create Gloo context from MPI communicator - glooContext_ = gloo::mpi::Context::createManaged(); - glooContext_->setTimeout(gloo::kNoTimeout); - glooContext_->connectFullMesh(glooDev); - - detail::DistributedInfo::getInstance().backend_ = DistributedBackend::GLOO; - detail::DistributedInfo::getInstance().isInitialized_ = true; - if (glooContext_->rank == 0) { - std::cout << "Initialized Gloo successfully!\n"; - } + const std::unordered_map& /* params = {} */ +) { + if(isDistributedInit()) { + std::cerr << "warning: fl::distributedInit() called more than once\n"; + return; + } + + if(initMethod != DistributedInit::MPI) { + throw std::runtime_error( + "unsupported distributed init method for gloo backend" + ); + } + + // using MPI + if(glooContext_ != nullptr) { + return; + } + // TODO: ibverbs support. 
+ auto glooDev = gloo::transport::tcp::CreateDevice(""); + + // Create Gloo context from MPI communicator + glooContext_ = gloo::mpi::Context::createManaged(); + glooContext_->setTimeout(gloo::kNoTimeout); + glooContext_->connectFullMesh(glooDev); + + detail::DistributedInfo::getInstance().backend_ = DistributedBackend::GLOO; + detail::DistributedInfo::getInstance().isInitialized_ = true; + if(glooContext_->rank == 0) { + std::cout << "Initialized Gloo successfully!\n"; + } } void allReduce(fl::Tensor& tensor, bool async /* = false */) { - if (!isDistributedInit()) { - throw std::runtime_error("distributed environment not initialized"); - } - if (async) { - throw std::runtime_error( - "Asynchronous allReduce not yet supported for Gloo backend"); - } - size_t tensorSize = tensor.elements() * fl::getTypeSize(tensor.type()); - if (tensorSize > cacheTensor_.elements()) { - cacheTensor_ = - fl::Tensor({static_cast(tensorSize)}, fl::dtype::b8); - } - DevicePtr tensorPtr(tensor); - DevicePtr cacheTensorPtr(cacheTensor_); - memcpy(cacheTensorPtr.get(), tensorPtr.get(), tensorSize); - switch (tensor.type()) { - case fl::dtype::f32: - detail::allreduceGloo( - static_cast(cacheTensorPtr.get()), tensor.elements()); - break; - case fl::dtype::f64: - detail::allreduceGloo( - static_cast(cacheTensorPtr.get()), tensor.elements()); - break; - case fl::dtype::s32: - detail::allreduceGloo( - static_cast(cacheTensorPtr.get()), tensor.elements()); - break; - case fl::dtype::s64: - detail::allreduceGloo( - static_cast(cacheTensorPtr.get()), tensor.elements()); - break; - default: - throw std::runtime_error("unsupported data type for allreduce with gloo"); - } - memcpy(tensorPtr.get(), cacheTensorPtr.get(), tensorSize); + if(!isDistributedInit()) { + throw std::runtime_error("distributed environment not initialized"); + } + if(async) { + throw std::runtime_error( + "Asynchronous allReduce not yet supported for Gloo backend" + ); + } + size_t tensorSize = tensor.elements() * 
fl::getTypeSize(tensor.type()); + if(tensorSize > cacheTensor_.elements()) { + cacheTensor_ = + fl::Tensor({static_cast(tensorSize)}, fl::dtype::b8); + } + DevicePtr tensorPtr(tensor); + DevicePtr cacheTensorPtr(cacheTensor_); + memcpy(cacheTensorPtr.get(), tensorPtr.get(), tensorSize); + switch(tensor.type()) { + case fl::dtype::f32: + detail::allreduceGloo( + static_cast(cacheTensorPtr.get()), + tensor.elements() + ); + break; + case fl::dtype::f64: + detail::allreduceGloo( + static_cast(cacheTensorPtr.get()), + tensor.elements() + ); + break; + case fl::dtype::s32: + detail::allreduceGloo( + static_cast(cacheTensorPtr.get()), + tensor.elements() + ); + break; + case fl::dtype::s64: + detail::allreduceGloo( + static_cast(cacheTensorPtr.get()), + tensor.elements() + ); + break; + default: + throw std::runtime_error("unsupported data type for allreduce with gloo"); + } + memcpy(tensorPtr.get(), cacheTensorPtr.get(), tensorSize); } // Not yet supported void allReduceMultiple( std::vector tensors, bool async /* = false */, - bool contiguous /* = false */) { - if (contiguous) { - throw std::runtime_error( - "contiguous allReduceMultiple is not yet supported for Gloo backend"); - } - - for (auto& tensor : tensors) { - allReduce(*tensor, async); - } + bool contiguous /* = false */ +) { + if(contiguous) { + throw std::runtime_error( + "contiguous allReduceMultiple is not yet supported for Gloo backend" + ); + } + + for(auto& tensor : tensors) { + allReduce(*tensor, async); + } } void syncDistributed() { - // NOOP since async distributed operations aren't yet supported with the Gloo - // backend - return; + // NOOP since async distributed operations aren't yet supported with the Gloo + // backend + return; } int getWorldRank() { - if (!isDistributedInit()) { - return 0; - } - return detail::globalContext()->rank; + if(!isDistributedInit()) { + return 0; + } + return detail::globalContext()->rank; } int getWorldSize() { - if (!isDistributedInit()) { - return 1; - } - 
return detail::globalContext()->size; + if(!isDistributedInit()) { + return 1; + } + return detail::globalContext()->size; } } // namespace fl diff --git a/flashlight/fl/distributed/backend/cuda/DistributedBackend.cpp b/flashlight/fl/distributed/backend/cuda/DistributedBackend.cpp index 9a6ac72..257bac8 100644 --- a/flashlight/fl/distributed/backend/cuda/DistributedBackend.cpp +++ b/flashlight/fl/distributed/backend/cuda/DistributedBackend.cpp @@ -31,199 +31,216 @@ namespace fl { namespace detail { -namespace { + namespace { // We need to pass this flag to our dedicated NCCL CUDA stream, else activity in // the stream will be precluded from running in parallel with the default stream -constexpr unsigned int kDefaultStreamFlags = cudaStreamNonBlocking; - -constexpr const char* kNcclKey = "ncclUniqueId"; - -class NcclContext { - public: - static NcclContext& getInstance(); - NcclContext() = default; - ~NcclContext(); - void initWithMPI(const std::unordered_map& params); - void initWithFileSystem( - int worldRank, - int worldSize, - const std::unordered_map& params); - ncclComm_t& getComm(); - int getWorldSize() const; - int getWorldRank() const; - const CUDAStream& getReductionStream() const; - const CUDAStream& getWorkerStream() const; - void* getCoalesceBuffer(); - - private: - // create CUDA resources - void createCudaResources(); - ncclComm_t comm_; - int worldSize_, worldRank_; - // CUDA stream in which NCCL calls run if in async mode - std::shared_ptr reductionStream_; - // CUDA stream in which cudaMemcpyAsync calls run if in contiguous mode - std::shared_ptr workerStream_; - // Buffer for storing copied gradients contiguously; exists on device memory - void* coalesceBuffer_{nullptr}; - std::once_flag allocBuffer_; -}; - -bool isNonNegativeInteger(const std::string& s) { - return !s.empty() && std::find_if(s.begin(), s.end(), [](char c) { - return !std::isdigit(c); - }) == s.end(); -} - -ncclDataType_t getNcclTypeForArray(const Tensor& arr) { - switch 
(arr.type()) { - case fl::dtype::f16: - return ncclHalf; - case fl::dtype::f32: - return ncclFloat32; - case fl::dtype::f64: - return ncclFloat64; - case fl::dtype::s32: - return ncclInt32; - case fl::dtype::s64: - return ncclInt64; - break; - default: - throw std::runtime_error("unsupported data type for allreduce with NCCL"); - } -} - -} // namespace - -void ncclCheck(ncclResult_t r); - -void mpiCheck(int ec); - -void allReduceCuda( - const CUDAStream* bufferStream, - void* ptr, - const size_t count, - const ncclDataType_t ncclType, - const bool async, - const bool contiguous); + constexpr unsigned int kDefaultStreamFlags = cudaStreamNonBlocking; + + constexpr const char* kNcclKey = "ncclUniqueId"; + + class NcclContext { + public: + static NcclContext& getInstance(); + NcclContext() = default; + ~NcclContext(); + void initWithMPI(const std::unordered_map& params); + void initWithFileSystem( + int worldRank, + int worldSize, + const std::unordered_map& params + ); + ncclComm_t& getComm(); + int getWorldSize() const; + int getWorldRank() const; + const CUDAStream& getReductionStream() const; + const CUDAStream& getWorkerStream() const; + void* getCoalesceBuffer(); + + private: + // create CUDA resources + void createCudaResources(); + ncclComm_t comm_; + int worldSize_, worldRank_; + // CUDA stream in which NCCL calls run if in async mode + std::shared_ptr reductionStream_; + // CUDA stream in which cudaMemcpyAsync calls run if in contiguous mode + std::shared_ptr workerStream_; + // Buffer for storing copied gradients contiguously; exists on device memory + void* coalesceBuffer_{nullptr}; + std::once_flag allocBuffer_; + }; + + bool isNonNegativeInteger(const std::string& s) { + return !s.empty() && std::find_if( + s.begin(), + s.end(), + [](char c) { + return !std::isdigit(c); + } + ) == s.end(); + } + + ncclDataType_t getNcclTypeForArray(const Tensor& arr) { + switch(arr.type()) { + case fl::dtype::f16: + return ncclHalf; + case fl::dtype::f32: + return 
ncclFloat32; + case fl::dtype::f64: + return ncclFloat64; + case fl::dtype::s32: + return ncclInt32; + case fl::dtype::s64: + return ncclInt64; + break; + default: + throw std::runtime_error("unsupported data type for allreduce with NCCL"); + } + } + + } // namespace + + void ncclCheck(ncclResult_t r); + + void mpiCheck(int ec); + + void allReduceCuda( + const CUDAStream* bufferStream, + void* ptr, + const size_t count, + const ncclDataType_t ncclType, + const bool async, + const bool contiguous + ); } // namespace detail void allReduce(Tensor& arr, bool async /* = false */) { - if (!isDistributedInit()) { - throw std::runtime_error("distributed environment not initialized"); - } - ncclDataType_t type = detail::getNcclTypeForArray(arr); - DevicePtr tensorPtr(arr); - detail::allReduceCuda( - &arr.stream().impl(), - tensorPtr.get(), - arr.elements(), - type, - async, - /* contiguous = */ false); + if(!isDistributedInit()) { + throw std::runtime_error("distributed environment not initialized"); + } + ncclDataType_t type = detail::getNcclTypeForArray(arr); + DevicePtr tensorPtr(arr); + detail::allReduceCuda( + &arr.stream().impl(), + tensorPtr.get(), + arr.elements(), + type, + async, + /* contiguous = */ false + ); } void allReduceMultiple( std::vector arrs, bool async /* = false */, - bool contiguous /* = false */) { - // Fast paths - if (arrs.empty()) { - return; - } - - if (!contiguous) { - // Use nccl groups to do everything in a single kernel launch - NCCLCHECK(ncclGroupStart()); - for (auto& arr : arrs) { - allReduce(*arr, async); + bool contiguous /* = false */ +) { + // Fast paths + if(arrs.empty()) { + return; } - NCCLCHECK(ncclGroupEnd()); - return; - } - - // We can only do a contiguous set reduction if all arrays in the set are of - // the same type, else fail - ncclDataType_t ncclType = detail::getNcclTypeForArray(*arrs[0]); - for (auto& arr : arrs) { - if (detail::getNcclTypeForArray(*arr) != ncclType) { - throw std::runtime_error( - "Cannot perform 
contiguous set allReduce on a set of tensors " - "of different types"); + + if(!contiguous) { + // Use nccl groups to do everything in a single kernel launch + NCCLCHECK(ncclGroupStart()); + for(auto& arr : arrs) { + allReduce(*arr, async); + } + NCCLCHECK(ncclGroupEnd()); + return; + } + + // We can only do a contiguous set reduction if all arrays in the set are of + // the same type, else fail + ncclDataType_t ncclType = detail::getNcclTypeForArray(*arrs[0]); + for(auto& arr : arrs) { + if(detail::getNcclTypeForArray(*arr) != ncclType) { + throw std::runtime_error( + "Cannot perform contiguous set allReduce on a set of tensors " + "of different types" + ); + } + } + // Size of each element in each tensor in bytes + size_t typeSize = fl::getTypeSize(arrs[0]->type()); + + // Device ptrs from each array + std::vector> tensorPtrs; + tensorPtrs.reserve(arrs.size()); + size_t totalEls{0}; + for(auto& arr : arrs) { + totalEls += arr->elements(); + tensorPtrs.emplace_back(DevicePtr(*arr), arr->bytes()); } - } - // Size of each element in each tensor in bytes - size_t typeSize = fl::getTypeSize(arrs[0]->type()); - - // Device ptrs from each array - std::vector> tensorPtrs; - tensorPtrs.reserve(arrs.size()); - size_t totalEls{0}; - for (auto& arr : arrs) { - totalEls += arr->elements(); - tensorPtrs.emplace_back(DevicePtr(*arr), arr->bytes()); - } - - // Make sure our coalesce buffer is large enough. 
Since we're initializing our - // coalescing cache to the same size, if we're using contiguous sync, it - // should never be larger since we flush if adding an additional buffer would - // exceed the max cache size - if (totalEls * typeSize > DistributedConstants::kCoalesceCacheSize) { - throw std::runtime_error( - "Total coalesce buffer size is larger than existing buffer size"); - } - - auto& ncclContext = detail::NcclContext::getInstance(); - const auto& workerStream = ncclContext.getWorkerStream(); - - const auto constTensors = std::vector(arrs.begin(), arrs.end()); - // Block the copy worker stream on Flashlight's active CUDA stream - relativeSync(workerStream, constTensors); - - // In the worker stream, coalesce gradients into one large buffer so we - // only need to call allReduce - void* coalesceBuffer = ncclContext.getCoalesceBuffer(); - auto* cur = reinterpret_cast(coalesceBuffer); - for (auto& entry : tensorPtrs) { - FL_CUDA_CHECK(cudaMemcpyAsync( - cur, - entry.first.get(), - entry.second, - cudaMemcpyDeviceToDevice, - workerStream.handle())); - cur += entry.second; - } - - // Now, call allReduce once on the entire copy buffer - detail::allReduceCuda( - &workerStream, - coalesceBuffer, - totalEls, - ncclType, - async, - contiguous); - - // Block the worker stream's copy operations on allReduce operations that are - // currently enqueued in the reduction stream - if (async) { - workerStream.relativeSync(ncclContext.getReductionStream()); - } else { + + // Make sure our coalesce buffer is large enough. 
Since we're initializing our + // coalescing cache to the same size, if we're using contiguous sync, it + // should never be larger since we flush if adding an additional buffer would + // exceed the max cache size + if(totalEls * typeSize > DistributedConstants::kCoalesceCacheSize) { + throw std::runtime_error( + "Total coalesce buffer size is larger than existing buffer size" + ); + } + + auto& ncclContext = detail::NcclContext::getInstance(); + const auto& workerStream = ncclContext.getWorkerStream(); + + const auto constTensors = std::vector(arrs.begin(), arrs.end()); + // Block the copy worker stream on Flashlight's active CUDA stream relativeSync(workerStream, constTensors); - } - - // Enqueue operations in the stream to copy back to each respective array from - // the coalesce buffer - cur = reinterpret_cast(coalesceBuffer); - for (auto& entry : tensorPtrs) { - FL_CUDA_CHECK(cudaMemcpyAsync( - entry.first.get(), - cur, - entry.second, - cudaMemcpyDeviceToDevice, - workerStream.handle())); - cur += entry.second; - } + + // In the worker stream, coalesce gradients into one large buffer so we + // only need to call allReduce + void* coalesceBuffer = ncclContext.getCoalesceBuffer(); + auto* cur = reinterpret_cast(coalesceBuffer); + for(auto& entry : tensorPtrs) { + FL_CUDA_CHECK( + cudaMemcpyAsync( + cur, + entry.first.get(), + entry.second, + cudaMemcpyDeviceToDevice, + workerStream.handle() + ) + ); + cur += entry.second; + } + + // Now, call allReduce once on the entire copy buffer + detail::allReduceCuda( + &workerStream, + coalesceBuffer, + totalEls, + ncclType, + async, + contiguous + ); + + // Block the worker stream's copy operations on allReduce operations that are + // currently enqueued in the reduction stream + if(async) { + workerStream.relativeSync(ncclContext.getReductionStream()); + } else { + relativeSync(workerStream, constTensors); + } + + // Enqueue operations in the stream to copy back to each respective array from + // the coalesce buffer + 
cur = reinterpret_cast(coalesceBuffer); + for(auto& entry : tensorPtrs) { + FL_CUDA_CHECK( + cudaMemcpyAsync( + entry.first.get(), + cur, + entry.second, + cudaMemcpyDeviceToDevice, + workerStream.handle() + ) + ); + cur += entry.second; + } } /** @@ -231,166 +248,179 @@ void allReduceMultiple( * operations currently running in the NCCL [and worker] CUDA stream. */ void syncDistributed() { - const auto& ncclContext = detail::NcclContext::getInstance(); - const auto& manager = DeviceManager::getInstance(); - const auto& activeCudaDevice = manager.getActiveDevice(DeviceType::CUDA); - const auto& workerStream = ncclContext.getWorkerStream(); - const auto& reductionStream = ncclContext.getReductionStream(); - for (const auto& stream : activeCudaDevice.getStreams()) { - if (stream.get() != &workerStream && stream.get() != &reductionStream) { - stream->relativeSync(workerStream); - stream->relativeSync(reductionStream); + const auto& ncclContext = detail::NcclContext::getInstance(); + const auto& manager = DeviceManager::getInstance(); + const auto& activeCudaDevice = manager.getActiveDevice(DeviceType::CUDA); + const auto& workerStream = ncclContext.getWorkerStream(); + const auto& reductionStream = ncclContext.getReductionStream(); + for(const auto& stream : activeCudaDevice.getStreams()) { + if(stream.get() != &workerStream && stream.get() != &reductionStream) { + stream->relativeSync(workerStream); + stream->relativeSync(reductionStream); + } } - } } int getWorldRank() { - if (!isDistributedInit()) { - return 0; - } - return detail::NcclContext::getInstance().getWorldRank(); + if(!isDistributedInit()) { + return 0; + } + return detail::NcclContext::getInstance().getWorldRank(); } int getWorldSize() { - if (!isDistributedInit()) { - return 1; - } - return detail::NcclContext::getInstance().getWorldSize(); + if(!isDistributedInit()) { + return 1; + } + return detail::NcclContext::getInstance().getWorldSize(); } void distributedInit( DistributedInit initMethod, int 
worldRank, int worldSize, - const std::unordered_map& params /* = {} */) { - if (isDistributedInit()) { - std::cerr << "warning: fl::distributedInit() called more than once\n"; - return; - } - if (initMethod == DistributedInit::MPI) { - detail::NcclContext::getInstance().initWithMPI(params); - detail::DistributedInfo::getInstance().initMethod_ = DistributedInit::MPI; - } else if (initMethod == DistributedInit::FILE_SYSTEM) { - detail::NcclContext::getInstance().initWithFileSystem( - worldRank, worldSize, params); - detail::DistributedInfo::getInstance().initMethod_ = - DistributedInit::FILE_SYSTEM; - } else { - throw std::runtime_error( - "unsupported distributed init method for NCCL backend"); - } - detail::DistributedInfo::getInstance().isInitialized_ = true; - detail::DistributedInfo::getInstance().backend_ = DistributedBackend::NCCL; - if (getWorldRank() == 0) { - std::cout << "Initialized NCCL " << NCCL_MAJOR << "." << NCCL_MINOR << "." - << NCCL_PATCH << " successfully!\n"; - } + const std::unordered_map& params /* = {} */ +) { + if(isDistributedInit()) { + std::cerr << "warning: fl::distributedInit() called more than once\n"; + return; + } + if(initMethod == DistributedInit::MPI) { + detail::NcclContext::getInstance().initWithMPI(params); + detail::DistributedInfo::getInstance().initMethod_ = DistributedInit::MPI; + } else if(initMethod == DistributedInit::FILE_SYSTEM) { + detail::NcclContext::getInstance().initWithFileSystem( + worldRank, + worldSize, + params + ); + detail::DistributedInfo::getInstance().initMethod_ = + DistributedInit::FILE_SYSTEM; + } else { + throw std::runtime_error( + "unsupported distributed init method for NCCL backend" + ); + } + detail::DistributedInfo::getInstance().isInitialized_ = true; + detail::DistributedInfo::getInstance().backend_ = DistributedBackend::NCCL; + if(getWorldRank() == 0) { + std::cout << "Initialized NCCL " << NCCL_MAJOR << "." << NCCL_MINOR << "." 
+ << NCCL_PATCH << " successfully!\n"; + } } namespace detail { -void ncclCheck(ncclResult_t r) { - if (r == ncclSuccess) { - return; - } - const char* err = ncclGetErrorString(r); - if (r == ncclInvalidArgument) { - throw std::invalid_argument(err); - } else if (r == ncclInvalidUsage) { - throw std::logic_error(err); - } else { - throw std::runtime_error(err); - } -} - -void mpiCheck(int ec) { - if (ec == MPI_SUCCESS) { - return; - } else { - char buf[MPI_MAX_ERROR_STRING]; - int resultlen; - MPI_Error_string(ec, buf, &resultlen); - throw std::runtime_error(buf); - } -} - -void allReduceCuda( - const CUDAStream* bufferStream, - void* ptr, - const size_t count, - const ncclDataType_t ncclType, - const bool async, - const bool contiguous) { - const CUDAStream* syncStream; - auto& ncclContext = detail::NcclContext::getInstance(); - if (async) { - syncStream = &ncclContext.getReductionStream(); - } else { - syncStream = bufferStream; - } - - // Synchronize with whatever CUDA stream is performing operations needed - // pre-reduction. If we're in contiguous mode, we need the reduction stream to - // wait for the copy in the worker stream to complete. If we're not in - // CUDA stream. 
- if (contiguous) { - // block future reduction stream ops on the copy-worker stream - syncStream->relativeSync(ncclContext.getWorkerStream()); - } else if (async) { - syncStream->relativeSync(*bufferStream); - } - // don't synchronize streams if not async and not contiguous - the AF CUDA - // stream does everything - - NCCLCHECK(ncclAllReduce( - ptr, - ptr, - count, - ncclType, - ncclSum, - ncclContext.getComm(), - syncStream->handle())); -} -namespace { - -ncclComm_t& NcclContext::getComm() { - return comm_; -} - -int NcclContext::getWorldSize() const { - return worldSize_; -} - -int NcclContext::getWorldRank() const { - return worldRank_; -} - -const CUDAStream& NcclContext::getReductionStream() const { - return *reductionStream_; -} + void ncclCheck(ncclResult_t r) { + if(r == ncclSuccess) { + return; + } + const char* err = ncclGetErrorString(r); + if(r == ncclInvalidArgument) { + throw std::invalid_argument(err); + } else if(r == ncclInvalidUsage) { + throw std::logic_error(err); + } else { + throw std::runtime_error(err); + } + } -const CUDAStream& NcclContext::getWorkerStream() const { - return *workerStream_; -} + void mpiCheck(int ec) { + if(ec == MPI_SUCCESS) { + return; + } else { + char buf[MPI_MAX_ERROR_STRING]; + int resultlen; + MPI_Error_string(ec, buf, &resultlen); + throw std::runtime_error(buf); + } + } -void* NcclContext::getCoalesceBuffer() { - std::call_once(allocBuffer_, [&]() { - FL_CUDA_CHECK( - cudaMalloc(&coalesceBuffer_, DistributedConstants::kCoalesceCacheSize)); - }); - return coalesceBuffer_; -} + void allReduceCuda( + const CUDAStream* bufferStream, + void* ptr, + const size_t count, + const ncclDataType_t ncclType, + const bool async, + const bool contiguous + ) { + const CUDAStream* syncStream; + auto& ncclContext = detail::NcclContext::getInstance(); + if(async) { + syncStream = &ncclContext.getReductionStream(); + } else { + syncStream = bufferStream; + } + + // Synchronize with whatever CUDA stream is performing operations 
needed + // pre-reduction. If we're in contiguous mode, we need the reduction stream to + // wait for the copy in the worker stream to complete. If we're not in + // CUDA stream. + if(contiguous) { + // block future reduction stream ops on the copy-worker stream + syncStream->relativeSync(ncclContext.getWorkerStream()); + } else if(async) { + syncStream->relativeSync(*bufferStream); + } + // don't synchronize streams if not async and not contiguous - the AF CUDA + // stream does everything + + NCCLCHECK( + ncclAllReduce( + ptr, + ptr, + count, + ncclType, + ncclSum, + ncclContext.getComm(), + syncStream->handle() + ) + ); + } + namespace { + + ncclComm_t& NcclContext::getComm() { + return comm_; + } + + int NcclContext::getWorldSize() const { + return worldSize_; + } + + int NcclContext::getWorldRank() const { + return worldRank_; + } + + const CUDAStream& NcclContext::getReductionStream() const { + return *reductionStream_; + } + + const CUDAStream& NcclContext::getWorkerStream() const { + return *workerStream_; + } + + void* NcclContext::getCoalesceBuffer() { + std::call_once( + allocBuffer_, + [&]() { + FL_CUDA_CHECK( + cudaMalloc(&coalesceBuffer_, DistributedConstants::kCoalesceCacheSize) + ); + } + ); + return coalesceBuffer_; + } /* static */ NcclContext& NcclContext::getInstance() { - static NcclContext ncclCtx; - return ncclCtx; -} + static NcclContext ncclCtx; + return ncclCtx; + } -void NcclContext::createCudaResources() { - // initialize - // - dedicated NCCL CUDA stream to support async allReduce - // - a third dedicated stream to asynchronously copy gradients - // into a coalesced form if using a contiguous allReduce + void NcclContext::createCudaResources() { + // initialize + // - dedicated NCCL CUDA stream to support async allReduce + // - a third dedicated stream to asynchronously copy gradients + // into a coalesced form if using a contiguous allReduce // Destroying the dedicated NCCL CUDA stream is a bit odd since the stream // lives in a global 
NcclContext singleton. The CUDA driver shuts down when @@ -400,118 +430,124 @@ void NcclContext::createCudaResources() { // all cases, streams are destroyed when the driver shuts down, so don't // destroy the stream by default. #ifdef CUDA_STREAM_POOL_DESTROY_ON_SHUTDOWN - reductionStream_ = CUDAStream::createManaged(detail::kDefaultStreamFlags); - workerStream_ = CUDAStream::createManaged(detail::kDefaultStreamFlags); + reductionStream_ = CUDAStream::createManaged(detail::kDefaultStreamFlags); + workerStream_ = CUDAStream::createManaged(detail::kDefaultStreamFlags); #else - reductionStream_ = CUDAStream::createUnmanaged(detail::kDefaultStreamFlags); - workerStream_ = CUDAStream::createUnmanaged(detail::kDefaultStreamFlags); + reductionStream_ = CUDAStream::createUnmanaged(detail::kDefaultStreamFlags); + workerStream_ = CUDAStream::createUnmanaged(detail::kDefaultStreamFlags); #endif -} - -void NcclContext::initWithMPI( - const std::unordered_map& params) { - // initializing MPI - MPICHECK(MPI_Init(nullptr, nullptr)); - MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &worldRank_)); - MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &worldSize_)); - - auto maxDevicePerNode = params.find(DistributedConstants::kMaxDevicePerNode); - if (maxDevicePerNode == params.end() || - !isNonNegativeInteger(maxDevicePerNode->second) || - std::stoi(maxDevicePerNode->second) == 0) { - throw std::invalid_argument( - "invalid MaxDevicePerNode for NCCL initWithMPI"); - } - - ncclUniqueId id; - - // TODO: Determining device is ugly. Find a better way. 
- fl::setDevice(worldRank_ % std::stoi(maxDevicePerNode->second)); - - // get NCCL unique ID at rank 0 and broadcast it to all others - if (worldRank_ == 0) { - ncclGetUniqueId(&id); - } - MPICHECK(MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD)); - - // initializing NCCL - NCCLCHECK(ncclCommInitRank(&comm_, worldSize_, id, worldRank_)); - - createCudaResources(); -} - -void NcclContext::initWithFileSystem( - int worldRank, - int worldSize, - const std::unordered_map& params) { - auto filePath = params.find(DistributedConstants::kFilePath); - auto maxDevicePerNode = params.find(DistributedConstants::kMaxDevicePerNode); - - if (filePath == params.end() || filePath->second.empty()) { - throw std::invalid_argument("invalid FilePath for NCCL initWithFileSystem"); - } - if (maxDevicePerNode == params.end()) { - throw std::invalid_argument( - "invalid MaxDevicePerNode for NCCL initWithFileSystem"); - } - - worldRank_ = worldRank; - worldSize_ = worldSize; - - ncclUniqueId id; - - fl::setDevice(worldRank_ % std::stoi(maxDevicePerNode->second)); - - // get NCCL unique ID at rank 0 and broadcast it to all others - if (worldRank_ == 0) { - ncclGetUniqueId(&id); - } - - auto fs = FileStore(filePath->second); - if (worldRank_ == 0) { - std::vector data(sizeof(id)); - std::memcpy(data.data(), &id, sizeof(id)); - fs.set(kNcclKey, data); - } else { - auto data = fs.get(kNcclKey); - std::memcpy(&id, data.data(), sizeof(id)); - } - // No need for barrier here as ncclCommInitRank inherently synchronizes - - // initializing NCCL - NCCLCHECK(ncclCommInitRank(&comm_, worldSize_, id, worldRank_)); - - // Remove the temporary file created for initialization - if (worldRank_ == 0) { - fs.clear(kNcclKey); - } - - createCudaResources(); -} - -NcclContext::~NcclContext() { + } + + void NcclContext::initWithMPI( + const std::unordered_map& params + ) { + // initializing MPI + MPICHECK(MPI_Init(nullptr, nullptr)); + MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &worldRank_)); + 
MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &worldSize_)); + + auto maxDevicePerNode = params.find(DistributedConstants::kMaxDevicePerNode); + if( + maxDevicePerNode == params.end() + || !isNonNegativeInteger(maxDevicePerNode->second) + || std::stoi(maxDevicePerNode->second) == 0 + ) { + throw std::invalid_argument( + "invalid MaxDevicePerNode for NCCL initWithMPI" + ); + } + + ncclUniqueId id; + + // TODO: Determining device is ugly. Find a better way. + fl::setDevice(worldRank_ % std::stoi(maxDevicePerNode->second)); + + // get NCCL unique ID at rank 0 and broadcast it to all others + if(worldRank_ == 0) { + ncclGetUniqueId(&id); + } + MPICHECK(MPI_Bcast((void*) &id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD)); + + // initializing NCCL + NCCLCHECK(ncclCommInitRank(&comm_, worldSize_, id, worldRank_)); + + createCudaResources(); + } + + void NcclContext::initWithFileSystem( + int worldRank, + int worldSize, + const std::unordered_map& params + ) { + auto filePath = params.find(DistributedConstants::kFilePath); + auto maxDevicePerNode = params.find(DistributedConstants::kMaxDevicePerNode); + + if(filePath == params.end() || filePath->second.empty()) { + throw std::invalid_argument("invalid FilePath for NCCL initWithFileSystem"); + } + if(maxDevicePerNode == params.end()) { + throw std::invalid_argument( + "invalid MaxDevicePerNode for NCCL initWithFileSystem" + ); + } + + worldRank_ = worldRank; + worldSize_ = worldSize; + + ncclUniqueId id; + + fl::setDevice(worldRank_ % std::stoi(maxDevicePerNode->second)); + + // get NCCL unique ID at rank 0 and broadcast it to all others + if(worldRank_ == 0) { + ncclGetUniqueId(&id); + } + + auto fs = FileStore(filePath->second); + if(worldRank_ == 0) { + std::vector data(sizeof(id)); + std::memcpy(data.data(), &id, sizeof(id)); + fs.set(kNcclKey, data); + } else { + auto data = fs.get(kNcclKey); + std::memcpy(&id, data.data(), sizeof(id)); + } + // No need for barrier here as ncclCommInitRank inherently synchronizes + + // 
initializing NCCL + NCCLCHECK(ncclCommInitRank(&comm_, worldSize_, id, worldRank_)); + + // Remove the temporary file created for initialization + if(worldRank_ == 0) { + fs.clear(kNcclKey); + } + + createCudaResources(); + } + + NcclContext::~NcclContext() { #ifdef NO_NCCL_COMM_DESTROY_HANDLE // DEBUG : ncclCommDestroy disabled as it leads to segfault. #else - // finalizing NCCL - NCCLCHECK(ncclCommDestroy(comm_)); + // finalizing NCCL + NCCLCHECK(ncclCommDestroy(comm_)); #endif // The CUDA driver has already shut down before we can free, so don't free by // default, as driver shutdown will clean up this memory anyways. #ifdef CUDA_CONTIGUOUS_BUFFER_FREE_ON_SHUTDOWN - // Free the coalesce buffer if it was allocated - if (coalesceBuffer_ != nullptr) { - FL_CUDA_CHECK(cudaFree(coalesceBuffer_)); - } + // Free the coalesce buffer if it was allocated + if(coalesceBuffer_ != nullptr) { + FL_CUDA_CHECK(cudaFree(coalesceBuffer_)); + } #endif - if (DistributedInfo::getInstance().initMethod_ == DistributedInit::MPI) { - // finalizing MPI - MPICHECK(MPI_Finalize()); - } -} -} // namespace + if(DistributedInfo::getInstance().initMethod_ == DistributedInit::MPI) { + // finalizing MPI + MPICHECK(MPI_Finalize()); + } + } + } // namespace } // namespace detail } // namespace fl diff --git a/flashlight/fl/distributed/backend/stub/DistributedBackend.cpp b/flashlight/fl/distributed/backend/stub/DistributedBackend.cpp index 93e888b..a828ba4 100644 --- a/flashlight/fl/distributed/backend/stub/DistributedBackend.cpp +++ b/flashlight/fl/distributed/backend/stub/DistributedBackend.cpp @@ -19,44 +19,48 @@ void distributedInit( DistributedInit /* initMethod */, int worldRank, int worldSize, - const std::unordered_map& /* params = {} */) { - if (isDistributedInit()) { - std::cerr << "warning: fl::distributedInit() called more than once\n"; - return; - } - if (worldSize > 1 || worldRank > 0) { - throw std::runtime_error("worldSize must be 1 with distributed stub"); - } - 
detail::DistributedInfo::getInstance().backend_ = DistributedBackend::STUB; - detail::DistributedInfo::getInstance().isInitialized_ = true; + const std::unordered_map& /* params = {} */ +) { + if(isDistributedInit()) { + std::cerr << "warning: fl::distributedInit() called more than once\n"; + return; + } + if(worldSize > 1 || worldRank > 0) { + throw std::runtime_error("worldSize must be 1 with distributed stub"); + } + detail::DistributedInfo::getInstance().backend_ = DistributedBackend::STUB; + detail::DistributedInfo::getInstance().isInitialized_ = true; } void allReduce(Tensor& arr, bool async /* = false */) { - if (!isDistributedInit()) { - throw std::runtime_error("distributed environment not initialized"); - } - throw std::runtime_error("allReduce not supported for stub backend"); + if(!isDistributedInit()) { + throw std::runtime_error("distributed environment not initialized"); + } + throw std::runtime_error("allReduce not supported for stub backend"); } // Not yet supported void allReduceMultiple( std::vector arrs, bool async /* = false */, - bool contiguous /* = false */) { - throw std::runtime_error( - "allReduceMultiple not supported for distributed stub backend"); + bool contiguous /* = false */ +) { + throw std::runtime_error( + "allReduceMultiple not supported for distributed stub backend" + ); } void syncDistributed() { - throw std::runtime_error( - "Asynchronous allReduce not supported for distributed stub backend"); + throw std::runtime_error( + "Asynchronous allReduce not supported for distributed stub backend" + ); } int getWorldRank() { - return 0; + return 0; } int getWorldSize() { - return 1; + return 1; } } // namespace fl diff --git a/flashlight/fl/distributed/reducers/CoalescingReducer.cpp b/flashlight/fl/distributed/reducers/CoalescingReducer.cpp index f12b323..fb7f66e 100644 --- a/flashlight/fl/distributed/reducers/CoalescingReducer.cpp +++ b/flashlight/fl/distributed/reducers/CoalescingReducer.cpp @@ -11,55 +11,56 @@ namespace fl { 
-CoalescingReducer::CoalescingReducer(double scale, bool async, bool contiguous) - : scale_(scale), - async_(async), - contiguous_(contiguous), - cacheThresholdBytes_(DistributedConstants::kCoalesceCacheSize) {} +CoalescingReducer::CoalescingReducer(double scale, bool async, bool contiguous) : scale_(scale), + async_(async), + contiguous_(contiguous), + cacheThresholdBytes_( + DistributedConstants:: + kCoalesceCacheSize) {} CoalescingReducer::~CoalescingReducer() { - finalize(); + finalize(); } void CoalescingReducer::add(Variable& var) { - // if this tensor would push the cache oversize, flush - if (currCacheSize_ + var.bytes() > cacheThresholdBytes_) { - flush(); - } + // if this tensor would push the cache oversize, flush + if(currCacheSize_ + var.bytes() > cacheThresholdBytes_) { + flush(); + } - // check if the tensor is larger than the cache. If so, reduce immediately - // and don't copy-coalesce - if (var.bytes() > cacheThresholdBytes_) { - allReduce(var, scale_, async_); - } else { - // if async, evaluating the JIT on the value upfront is more efficient than - // evaluating the JIT for each Variable in the cache after we flush it, - // since it more effectively facilitates overlapping compuation between the - // AF and distributed compute streams. - if (async_) { - var.eval(); + // check if the tensor is larger than the cache. If so, reduce immediately + // and don't copy-coalesce + if(var.bytes() > cacheThresholdBytes_) { + allReduce(var, scale_, async_); + } else { + // if async, evaluating the JIT on the value upfront is more efficient than + // evaluating the JIT for each Variable in the cache after we flush it, + // since it more effectively facilitates overlapping compuation between the + // AF and distributed compute streams. 
+ if(async_) { + var.eval(); + } + // otherwise, add to cache + cache_.push_back(var); + currCacheSize_ += var.bytes(); } - // otherwise, add to cache - cache_.push_back(var); - currCacheSize_ += var.bytes(); - } } void CoalescingReducer::finalize() { - flush(); - synchronize(); + flush(); + synchronize(); } void CoalescingReducer::flush() { - allReduceMultiple(cache_, scale_, async_, contiguous_); - currCacheSize_ = 0; - cache_.clear(); + allReduceMultiple(cache_, scale_, async_, contiguous_); + currCacheSize_ = 0; + cache_.clear(); } void CoalescingReducer::synchronize() { - if (async_ || contiguous_) { - syncDistributed(); - } + if(async_ || contiguous_) { + syncDistributed(); + } } } // namespace fl diff --git a/flashlight/fl/distributed/reducers/CoalescingReducer.h b/flashlight/fl/distributed/reducers/CoalescingReducer.h index a9d5d24..e8d3814 100644 --- a/flashlight/fl/distributed/reducers/CoalescingReducer.h +++ b/flashlight/fl/distributed/reducers/CoalescingReducer.h @@ -26,68 +26,68 @@ class Variable; * ``finalize`` must be called before using a given value. */ class FL_API CoalescingReducer : public Reducer { - /// A scale by which to scale reduced gradients - double scale_; - /// Whether or not the distributed synchronization operates in a separate - /// compute stream asynchronously to the ArrayFire stream - bool async_{true}; - /// Determines if the coalesced batch of gradients is put into - /// contiguous memory before being synchronized - bool contiguous_{true}; - /// The threshold at which the cache will be flushed and its contents - /// synchronized, in bytes - const std::size_t cacheThresholdBytes_; - /// A cache that stores coalesced gradients. 
- std::vector cache_; - /// The current cache size, in bytes - std::size_t currCacheSize_{0}; + /// A scale by which to scale reduced gradients + double scale_; + /// Whether or not the distributed synchronization operates in a separate + /// compute stream asynchronously to the ArrayFire stream + bool async_{true}; + /// Determines if the coalesced batch of gradients is put into + /// contiguous memory before being synchronized + bool contiguous_{true}; + /// The threshold at which the cache will be flushed and its contents + /// synchronized, in bytes + const std::size_t cacheThresholdBytes_; + /// A cache that stores coalesced gradients. + std::vector cache_; + /// The current cache size, in bytes + std::size_t currCacheSize_{0}; - public: - /** - * Creates a new coalescing reducer. - * - * @param[in] cache threshold at which the cache will be flushed - * and its contents synchronized, in bytes - * @param[in] async determines whether or not the distributed compute stream - * runs asynchronously to the AF stream. - * @param[in] contiguous forces synchronization of the set of Variables - * to occur in a contiguous buffer, which may improve performance. - */ - CoalescingReducer(double scale, bool async, bool contiguous); +public: + /** + * Creates a new coalescing reducer. + * + * @param[in] cache threshold at which the cache will be flushed + * and its contents synchronized, in bytes + * @param[in] async determines whether or not the distributed compute stream + * runs asynchronously to the AF stream. + * @param[in] contiguous forces synchronization of the set of Variables + * to occur in a contiguous buffer, which may improve performance. + */ + CoalescingReducer(double scale, bool async, bool contiguous); - /** - * Destroy the Reducer. Calls `finalize()` before returning. - */ - ~CoalescingReducer() override; + /** + * Destroy the Reducer. Calls `finalize()` before returning. + */ + ~CoalescingReducer() override; - /** - * Add a ``Variable`` to ``Reducer``. 
Behaves as follows: - * - if the ``Variable`` exceeds the size of the coalescing cache, call - * ``allReduce`` immediately to synchronize. - * - if the ``Variable`` is smaller than the cache and adding it would push - * the cache oversize, flush the cache and synchronize with - * ``allReduceMultiple`` - * - otherwise, add the ``Variable`` to the cache. - */ - void add(Variable& var) override; + /** + * Add a ``Variable`` to ``Reducer``. Behaves as follows: + * - if the ``Variable`` exceeds the size of the coalescing cache, call + * ``allReduce`` immediately to synchronize. + * - if the ``Variable`` is smaller than the cache and adding it would push + * the cache oversize, flush the cache and synchronize with + * ``allReduceMultiple`` + * - otherwise, add the ``Variable`` to the cache. + */ + void add(Variable& var) override; - /** - * Flush any remaining ``Variable``s in the cache and synchronize. - */ - void finalize() override; + /** + * Flush any remaining ``Variable``s in the cache and synchronize. + */ + void finalize() override; - private: - /** - * Synchronize the existing set of Variables with ``allReduceMultiple`` and - * reset the cache. - */ - void flush(); +private: + /** + * Synchronize the existing set of Variables with ``allReduceMultiple`` and + * reset the cache. + */ + void flush(); - /** - * Synchronize the distributed computation stream with the existing AF - * computation stream in a way that doesn't block the main host thread. - */ - void synchronize(); + /** + * Synchronize the distributed computation stream with the existing AF + * computation stream in a way that doesn't block the main host thread. 
+ */ + void synchronize(); }; } // namespace fl diff --git a/flashlight/fl/distributed/reducers/InlineReducer.cpp b/flashlight/fl/distributed/reducers/InlineReducer.cpp index 8ac36f5..8d0f444 100644 --- a/flashlight/fl/distributed/reducers/InlineReducer.cpp +++ b/flashlight/fl/distributed/reducers/InlineReducer.cpp @@ -13,10 +13,10 @@ namespace fl { InlineReducer::InlineReducer(double scale) : scale_(scale) {} void InlineReducer::add(Variable& var) { - if (getWorldSize() > 1) { - allReduce(var.tensor()); - } - var.tensor() *= scale_; + if(getWorldSize() > 1) { + allReduce(var.tensor()); + } + var.tensor() *= scale_; } } // namespace fl diff --git a/flashlight/fl/distributed/reducers/InlineReducer.h b/flashlight/fl/distributed/reducers/InlineReducer.h index d609607..79d997e 100644 --- a/flashlight/fl/distributed/reducers/InlineReducer.h +++ b/flashlight/fl/distributed/reducers/InlineReducer.h @@ -19,27 +19,27 @@ class Variable; * synchronized gradients are scaled by a pre-specified factor. */ class FL_API InlineReducer : public Reducer { - /// A scale by which to scale reduced gradients - double scale_; - - public: - /** - * Creates a new InlineReducer with a given scaling factor - * - * @param[in] scale the factor by which to scale gradients after - * synchronization - */ - explicit InlineReducer(double scale); - - /** - * Ingest a Variable and immediately call allReduce on it. - * - * @param[in] var the Variable to process for synchronization - */ - void add(Variable& var) override; - - // no-op; no state - void finalize() override {} + /// A scale by which to scale reduced gradients + double scale_; + +public: + /** + * Creates a new InlineReducer with a given scaling factor + * + * @param[in] scale the factor by which to scale gradients after + * synchronization + */ + explicit InlineReducer(double scale); + + /** + * Ingest a Variable and immediately call allReduce on it. 
+ * + * @param[in] var the Variable to process for synchronization + */ + void add(Variable& var) override; + + // no-op; no state + void finalize() override {} }; } // namespace fl diff --git a/flashlight/fl/distributed/reducers/Reducer.h b/flashlight/fl/distributed/reducers/Reducer.h index e881dad..c63a060 100644 --- a/flashlight/fl/distributed/reducers/Reducer.h +++ b/flashlight/fl/distributed/reducers/Reducer.h @@ -19,24 +19,24 @@ class Variable; * general. */ class Reducer { - public: - virtual ~Reducer() = default; +public: + virtual ~Reducer() = default; - /** - * Have the Reducer ingest a Variable. What happens next is - * implementation-specific; the implementation may cache the value, - * process/synchronize immediately, or ignore the value. - * - * @param[in] var a Variable to be ingested - */ - virtual void add(Variable& var) = 0; + /** + * Have the Reducer ingest a Variable. What happens next is + * implementation-specific; the implementation may cache the value, + * process/synchronize immediately, or ignore the value. + * + * @param[in] var a Variable to be ingested + */ + virtual void add(Variable& var) = 0; - /** - * Forces a reduction/synchronization of the Reducer. - * For some implementations, this may be a no-op if the Reducer immediately - * processes or synchronizes all gradients that are added. - */ - virtual void finalize() = 0; + /** + * Forces a reduction/synchronization of the Reducer. + * For some implementations, this may be a no-op if the Reducer immediately + * processes or synchronizes all gradients that are added. 
+ */ + virtual void finalize() = 0; }; } // namespace fl diff --git a/flashlight/fl/examples/AdaptiveClassification.cpp b/flashlight/fl/examples/AdaptiveClassification.cpp index d6e68ad..8236dd0 100644 --- a/flashlight/fl/examples/AdaptiveClassification.cpp +++ b/flashlight/fl/examples/AdaptiveClassification.cpp @@ -16,82 +16,83 @@ using namespace fl; int main(int /* unused */, const char** /* unused */) { - fl::init(); - int nsamples = 100; - int categories = 3; - int feature_dim = 10; - Tensor data = - fl::rand({feature_dim, 2 * nsamples * (categories - 1), /* B = */ 1}) * - 5 + - 1; - Tensor label = fl::full({2 * nsamples * (categories - 1), /* B = */ 1}, 0.0); - for (int i = 1; i < categories; i++) { - int start = (categories - 2 + i) * nsamples; - int end = start + nsamples; - data(i, fl::range(start, end)) = 0 - data(i, fl::range(start, end)); - label(fl::range(start, end)) = label(fl::range(start, end)) + i; - } - - Sequential model; - model.add(Linear(feature_dim, feature_dim)); - - std::vector cutoff = {1, categories}; - auto asActivation = std::make_shared(feature_dim, cutoff); - - AdaptiveSoftMaxLoss criterion(asActivation); - auto sgd_m = SGDOptimizer(model.params(), 1e-2); - auto sgd_c = SGDOptimizer(criterion.params(), 1e-2); - - Variable result, l; - int nepochs = 500, warmup_epochs = 10; - model.train(); - criterion.train(); - - const Tensor& in_ = data; - const Tensor& out_ = label; - fl::Timer s; - for (int i = 0; i < nepochs; i++) { - if (i == warmup_epochs) { - s = fl::Timer::start(); + fl::init(); + int nsamples = 100; + int categories = 3; + int feature_dim = 10; + Tensor data = + fl::rand({feature_dim, 2 * nsamples * (categories - 1), /* B = */ 1}) + * 5 + + 1; + Tensor label = fl::full({2 * nsamples * (categories - 1), /* B = */ 1}, 0.0); + for(int i = 1; i < categories; i++) { + int start = (categories - 2 + i) * nsamples; + int end = start + nsamples; + data(i, fl::range(start, end)) = 0 - data(i, fl::range(start, end)); + 
label(fl::range(start, end)) = label(fl::range(start, end)) + i; } - /* Forward propagation */ - result = model(input(in_)); + Sequential model; + model.add(Linear(feature_dim, feature_dim)); - /* Calculate loss */ - l = criterion(result, noGrad(out_)); + std::vector cutoff = {1, categories}; + auto asActivation = std::make_shared(feature_dim, cutoff); + + AdaptiveSoftMaxLoss criterion(asActivation); + auto sgd_m = SGDOptimizer(model.params(), 1e-2); + auto sgd_c = SGDOptimizer(criterion.params(), 1e-2); + + Variable result, l; + int nepochs = 500, warmup_epochs = 10; + model.train(); + criterion.train(); + + const Tensor& in_ = data; + const Tensor& out_ = label; + fl::Timer s; + for(int i = 0; i < nepochs; i++) { + if(i == warmup_epochs) { + s = fl::Timer::start(); + } + + /* Forward propagation */ + result = model(input(in_)); - /* Backward propagation */ - sgd_m.zeroGrad(); - sgd_c.zeroGrad(); - l.backward(); - - /* Update parameters */ - sgd_m.step(); - sgd_c.step(); - } - auto e = fl::Timer::stop(s); - - // loss - model.eval(); - result = model(input(in_)); - l = criterion(result, noGrad(out_)); - auto loss = l.tensor(); - std::cout << "Loss: " << loss << std::endl; - - // accuracy - auto log_prob = criterion.getActivation()->forward(result).tensor(); - Tensor max_value, prediction; - fl::max(max_value, prediction, log_prob, 0); - auto accuracy = mean(prediction == label(fl::span, fl::range(0, 1)), {0}); - std::cout << "Accuracy: " << accuracy << std::endl; - - auto pred = asActivation->predict(result).tensor(); - accuracy = mean( - fl::reshape(pred, label.shape()) == label(fl::span, fl::range(0, 1)), - {0}); - std::cout << "Accuracy: " << accuracy << std::endl; - - // time - fmt::print("Time/iteration: {:.5f} msec\n", e * 1000.0 / (nepochs - warmup_epochs)); + /* Calculate loss */ + l = criterion(result, noGrad(out_)); + + /* Backward propagation */ + sgd_m.zeroGrad(); + sgd_c.zeroGrad(); + l.backward(); + + /* Update parameters */ + sgd_m.step(); + 
sgd_c.step(); + } + auto e = fl::Timer::stop(s); + + // loss + model.eval(); + result = model(input(in_)); + l = criterion(result, noGrad(out_)); + auto loss = l.tensor(); + std::cout << "Loss: " << loss << std::endl; + + // accuracy + auto log_prob = criterion.getActivation()->forward(result).tensor(); + Tensor max_value, prediction; + fl::max(max_value, prediction, log_prob, 0); + auto accuracy = mean(prediction == label(fl::span, fl::range(0, 1)), {0}); + std::cout << "Accuracy: " << accuracy << std::endl; + + auto pred = asActivation->predict(result).tensor(); + accuracy = mean( + fl::reshape(pred, label.shape()) == label(fl::span, fl::range(0, 1)), + {0} + ); + std::cout << "Accuracy: " << accuracy << std::endl; + + // time + fmt::print("Time/iteration: {:.5f} msec\n", e * 1000.0 / (nepochs - warmup_epochs)); } diff --git a/flashlight/fl/examples/Benchmark.cpp b/flashlight/fl/examples/Benchmark.cpp index 5a4f586..0b2a906 100644 --- a/flashlight/fl/examples/Benchmark.cpp +++ b/flashlight/fl/examples/Benchmark.cpp @@ -15,140 +15,140 @@ using namespace fl; -#define TIME(FUNC) \ - std::cout << "Timing " << #FUNC << " ... " << std::flush; \ - std::cout << std::setprecision(5) << FUNC() * 1000.0 << " msec" << std::endl; +#define TIME(FUNC) \ + std::cout << "Timing " << #FUNC << " ... 
" << std::flush; \ + std::cout << std::setprecision(5) << FUNC() * 1000.0 << " msec" << std::endl; double timeit(std::function fn) { - // warmup - for (int i = 0; i < 10; ++i) { - fn(); - } - fl::sync(); - - int num_iters = 100; - fl::sync(); - auto start = fl::Timer::start(); - for (int i = 0; i < num_iters; i++) { - fn(); - } - fl::sync(); - return fl::Timer::stop(start) / num_iters; + // warmup + for(int i = 0; i < 10; ++i) { + fn(); + } + fl::sync(); + + int num_iters = 100; + fl::sync(); + auto start = fl::Timer::start(); + for(int i = 0; i < num_iters; i++) { + fn(); + } + fl::sync(); + return fl::Timer::stop(start) / num_iters; } double alexnet() { - Sequential model; - model.add(Conv2D(3, 64, 11, 11, 4, 4, 2, 2)); // 224 -> 55 - model.add(ReLU()); - model.add(Pool2D(3, 3, 2, 2)); // 55 -> 27 - model.add(Conv2D(64, 192, 5, 5, 1, 1, 2, 2)); // 27 -> 27 - model.add(ReLU()); - model.add(Pool2D(3, 3, 2, 2)); // 27 -> 13 - model.add(Conv2D(192, 384, 3, 3, 1, 1, 1, 1)); // 13 -> 13 - model.add(ReLU()); - model.add(Conv2D(384, 256, 3, 3, 1, 1, 1, 1)); // 13 -> 13 - model.add(ReLU()); - model.add(Conv2D(256, 256, 3, 3, 1, 1, 1, 1)); // 13 -> 13 - model.add(ReLU()); - model.add(Pool2D(3, 3, 2, 2)); // 13 -> 6 - - auto input = Variable(fl::rand({224, 224, 3, 128}) * 2 - 2, false); - - auto b = model.forward(input); - auto gradoutput = Variable(fl::rand(b.shape()) * 2 - 2, false); - - auto alexnet_fn = [&]() { - auto output = model.forward(input); - output.backward(gradoutput); - }; - return timeit(alexnet_fn); + Sequential model; + model.add(Conv2D(3, 64, 11, 11, 4, 4, 2, 2)); // 224 -> 55 + model.add(ReLU()); + model.add(Pool2D(3, 3, 2, 2)); // 55 -> 27 + model.add(Conv2D(64, 192, 5, 5, 1, 1, 2, 2)); // 27 -> 27 + model.add(ReLU()); + model.add(Pool2D(3, 3, 2, 2)); // 27 -> 13 + model.add(Conv2D(192, 384, 3, 3, 1, 1, 1, 1)); // 13 -> 13 + model.add(ReLU()); + model.add(Conv2D(384, 256, 3, 3, 1, 1, 1, 1)); // 13 -> 13 + model.add(ReLU()); + model.add(Conv2D(256, 256, 
3, 3, 1, 1, 1, 1)); // 13 -> 13 + model.add(ReLU()); + model.add(Pool2D(3, 3, 2, 2)); // 13 -> 6 + + auto input = Variable(fl::rand({224, 224, 3, 128}) * 2 - 2, false); + + auto b = model.forward(input); + auto gradoutput = Variable(fl::rand(b.shape()) * 2 - 2, false); + + auto alexnet_fn = [&]() { + auto output = model.forward(input); + output.backward(gradoutput); + }; + return timeit(alexnet_fn); } double embedding() { - int embed_dim = 256; - int vocab_size = 10000; - - Embedding embed(embed_dim, vocab_size); - - int num_elems = 400; - Variable input( - (fl::rand({num_elems}) * vocab_size).astype(fl::dtype::s32), false); - Variable grad_output( - fl::randn({embed_dim, num_elems}, fl::dtype::f32), false); - - auto embed_fn = [&]() { - embed.zeroGrad(); - auto output = embed(input); - output.backward(grad_output); - }; - return timeit(embed_fn); + int embed_dim = 256; + int vocab_size = 10000; + + Embedding embed(embed_dim, vocab_size); + + int num_elems = 400; + Variable input( + (fl::rand({num_elems}) * vocab_size).astype(fl::dtype::s32), false); + Variable grad_output( + fl::randn({embed_dim, num_elems}, fl::dtype::f32), false); + + auto embed_fn = [&]() { + embed.zeroGrad(); + auto output = embed(input); + output.backward(grad_output); + }; + return timeit(embed_fn); } double linear() { - int M = 256; - int N = 512; - int B = 8; - int T = 2; - Variable input(fl::rand({N, T, B}, fl::dtype::f32), true); - Variable dout(fl::rand({M, T, B}, fl::dtype::f32), false); - Linear lin(N, M); - - auto lin_fn = [&]() { - lin.zeroGrad(); - input.zeroGrad(); - auto output = lin(input); - output.backward(dout); - }; - - return timeit(lin_fn); + int M = 256; + int N = 512; + int B = 8; + int T = 2; + Variable input(fl::rand({N, T, B}, fl::dtype::f32), true); + Variable dout(fl::rand({M, T, B}, fl::dtype::f32), false); + Linear lin(N, M); + + auto lin_fn = [&]() { + lin.zeroGrad(); + input.zeroGrad(); + auto output = lin(input); + output.backward(dout); + }; + + return 
timeit(lin_fn); } double batchNorm() { - // Takes around 0.72 ms on Tesla M40 with cudnn torch - int N = 8; - int C = 512; - int H = 32; - int W = 32; - Variable input(fl::rand({W, H, C, N}, fl::dtype::f32), true); - Variable dout(fl::rand({W, H, C, N}, fl::dtype::f32), true); - BatchNorm bn(2, C); // Spatial batchnorm - - auto bn_fn = [&]() { - bn.zeroGrad(); - input.zeroGrad(); - auto output = bn(input); - output.backward(dout); - }; - - return timeit(bn_fn); + // Takes around 0.72 ms on Tesla M40 with cudnn torch + int N = 8; + int C = 512; + int H = 32; + int W = 32; + Variable input(fl::rand({W, H, C, N}, fl::dtype::f32), true); + Variable dout(fl::rand({W, H, C, N}, fl::dtype::f32), true); + BatchNorm bn(2, C); // Spatial batchnorm + + auto bn_fn = [&]() { + bn.zeroGrad(); + input.zeroGrad(); + auto output = bn(input); + output.backward(dout); + }; + + return timeit(bn_fn); } double layerNorm() { - // Takes around 7.8 ms on Tesla M40 with cudnn torch - int N = 8; - int C = 512; - int H = 32; - int W = 32; - Variable input(fl::rand({W, H, C, N}, fl::dtype::f32), true); - Variable dout(fl::rand({W, H, C, N}, fl::dtype::f32), true); - LayerNorm ln(3); - - auto ln_fn = [&]() { - ln.zeroGrad(); - input.zeroGrad(); - auto output = ln(input); - output.backward(dout); - }; - - return timeit(ln_fn); + // Takes around 7.8 ms on Tesla M40 with cudnn torch + int N = 8; + int C = 512; + int H = 32; + int W = 32; + Variable input(fl::rand({W, H, C, N}, fl::dtype::f32), true); + Variable dout(fl::rand({W, H, C, N}, fl::dtype::f32), true); + LayerNorm ln(3); + + auto ln_fn = [&]() { + ln.zeroGrad(); + input.zeroGrad(); + auto output = ln(input); + output.backward(dout); + }; + + return timeit(ln_fn); } int main() { - fl::init(); - TIME(alexnet); - TIME(embedding); - TIME(linear); - TIME(batchNorm); - TIME(layerNorm); - return 0; + fl::init(); + TIME(alexnet); + TIME(embedding); + TIME(linear); + TIME(batchNorm); + TIME(layerNorm); + return 0; } diff --git 
a/flashlight/fl/examples/Classification.cpp b/flashlight/fl/examples/Classification.cpp index 030ec76..3009064 100644 --- a/flashlight/fl/examples/Classification.cpp +++ b/flashlight/fl/examples/Classification.cpp @@ -16,67 +16,67 @@ using namespace fl; int main(int /* unused */, const char** /* unused */) { - fl::init(); - int nsamples = 500; - int categories = 3; - int feature_dim = 10; - Tensor data = fl::rand({feature_dim, nsamples * categories}) * 2; - Tensor label = fl::full({nsamples * categories}, 0.0); - for (int i = 1; i < categories; i++) { - data(fl::span, fl::range(i * nsamples, (i + 1) * nsamples)) = - data(fl::span, fl::range(i * nsamples, (i + 1) * nsamples)) + 2 * i; - label(fl::range(i * nsamples, (i + 1) * nsamples)) = - label(fl::range(i * nsamples, (i + 1) * nsamples)) + i; - } + fl::init(); + int nsamples = 500; + int categories = 3; + int feature_dim = 10; + Tensor data = fl::rand({feature_dim, nsamples * categories}) * 2; + Tensor label = fl::full({nsamples* categories}, 0.0); + for(int i = 1; i < categories; i++) { + data(fl::span, fl::range(i * nsamples, (i + 1) * nsamples)) = + data(fl::span, fl::range(i * nsamples, (i + 1) * nsamples)) + 2 * i; + label(fl::range(i * nsamples, (i + 1) * nsamples)) = + label(fl::range(i * nsamples, (i + 1) * nsamples)) + i; + } - Sequential model; + Sequential model; - model.add(Linear(feature_dim, 10)); - model.add(WeightNorm(Linear(10, categories), 0)); - model.add(LogSoftmax()); + model.add(Linear(feature_dim, 10)); + model.add(WeightNorm(Linear(10, categories), 0)); + model.add(LogSoftmax()); - auto criterion = CategoricalCrossEntropy(); + auto criterion = CategoricalCrossEntropy(); - auto sgd = SGDOptimizer(model.params(), 0.1); + auto sgd = SGDOptimizer(model.params(), 0.1); - Variable result, l; + Variable result, l; - /* Train */ - int nepochs = 1000, warmup_epochs = 10; - model.train(); + /* Train */ + int nepochs = 1000, warmup_epochs = 10; + model.train(); - const Tensor& in_ = data; - const 
Tensor& out_ = label; - fl::Timer s; - for (int i = 0; i < nepochs; i++) { - if (i == warmup_epochs) { - s = fl::Timer::start(); - } + const Tensor& in_ = data; + const Tensor& out_ = label; + fl::Timer s; + for(int i = 0; i < nepochs; i++) { + if(i == warmup_epochs) { + s = fl::Timer::start(); + } - /* Forward propagation */ - result = model(input(in_)); + /* Forward propagation */ + result = model(input(in_)); - /* Calculate loss */ - l = criterion(result, noGrad(out_)); + /* Calculate loss */ + l = criterion(result, noGrad(out_)); - /* Backward propagation */ - sgd.zeroGrad(); - l.backward(); + /* Backward propagation */ + sgd.zeroGrad(); + l.backward(); - /* Update parameters */ - sgd.step(); - } - auto e = fl::Timer::stop(s); + /* Update parameters */ + sgd.step(); + } + auto e = fl::Timer::stop(s); - /* Evaluate */ - model.eval(); - result = model(input(in_)); - l = criterion(result, noGrad(out_)); - auto loss = l.tensor(); - std::cout << "Loss: " << loss << std::endl; - Tensor max_value, prediction; - fl::max(max_value, prediction, result.tensor(), 0); - auto accuracy = mean(prediction == fl::transpose(label, {1, 0}), {0}); - std::cout << "Accuracy: " << accuracy << std::endl; - fmt::print("Time/iteration: {:.5f} msec\n", e * 1000.0 / (nepochs - warmup_epochs)); + /* Evaluate */ + model.eval(); + result = model(input(in_)); + l = criterion(result, noGrad(out_)); + auto loss = l.tensor(); + std::cout << "Loss: " << loss << std::endl; + Tensor max_value, prediction; + fl::max(max_value, prediction, result.tensor(), 0); + auto accuracy = mean(prediction == fl::transpose(label, {1, 0}), {0}); + std::cout << "Accuracy: " << accuracy << std::endl; + fmt::print("Time/iteration: {:.5f} msec\n", e * 1000.0 / (nepochs - warmup_epochs)); } diff --git a/flashlight/fl/examples/DistributedTraining.cpp b/flashlight/fl/examples/DistributedTraining.cpp index d54c93a..00ce190 100644 --- a/flashlight/fl/examples/DistributedTraining.cpp +++ 
b/flashlight/fl/examples/DistributedTraining.cpp @@ -18,98 +18,99 @@ using namespace fl; int main() { - fl::init(); - - fl::distributedInit( - fl::DistributedInit::MPI, - -1, // worldRank - unused. Automatically derived from `MPI_Comm_Rank` - -1, // worldRank - unused. Automatically derived from `MPI_Comm_Size` - {{fl::DistributedConstants::kMaxDevicePerNode, "8"}} // param - ); - - auto worldSize = fl::getWorldSize(); - auto worldRank = fl::getWorldRank(); - bool isMaster = (worldRank == 0); - fl::setSeed(worldRank); - - auto reducer = std::make_shared( - /*scale=*/1.0 / worldSize, - /*async=*/true, - /*contiguous=*/true); - - // Create dataset - const int nSamples = 10000 / worldSize; - const int nFeat = 10; - auto X = fl::rand({nFeat, nSamples}) + 1; // X elements in [1, 2] - auto Y = /* signal */ fl::transpose(fl::sum(fl::power(X, 3), {0})) + - /* noise */ fl::sin(2 * M_PI * fl::rand({nSamples})); - // Create Dataset to simplify the code for iterating over samples - TensorDataset data({X, Y}); - - const int inputIdx = 0, targetIdx = 1; - - // Model definition - 2-layer Perceptron with ReLU activation - auto model = std::make_shared(); - model->add(Linear(nFeat, 100)); - model->add(ReLU()); - model->add(Linear(100, 1)); - // MSE loss - auto loss = MeanSquaredError(); - - // synchronize parameters of the model so that the parameters in each process - // is the same - fl::allReduceParameters(model); - - // Add a hook to synchronize gradients of model parameters as they are - // computed - fl::distributeModuleGrads(model, reducer); - - // Optimizer definition - const float learningRate = 0.0001; - const float momentum = 0.9; - auto sgd = SGDOptimizer(model->params(), learningRate, momentum); - - // Meter definition - AverageValueMeter meter; - - // Start training - - if (isMaster) { - std::cout << "[Multi-layer Perceptron] Started..." 
<< std::endl; - } - const int nEpochs = 100; - for (int e = 1; e <= nEpochs; ++e) { - meter.reset(); - for (auto& sample : data) { - sgd.zeroGrad(); - - // Forward propagation - auto result = model->forward(input(sample[inputIdx])); - - // Calculate loss - auto l = loss(result, noGrad(sample[targetIdx])); - - // Backward propagation - l.backward(); - reducer->finalize(); - - // Update parameters - sgd.step(); - - meter.add(l.scalar()); + fl::init(); + + fl::distributedInit( + fl::DistributedInit::MPI, + -1, // worldRank - unused. Automatically derived from `MPI_Comm_Rank` + -1, // worldRank - unused. Automatically derived from `MPI_Comm_Size` + {{fl::DistributedConstants::kMaxDevicePerNode, "8"}} // param + ); + + auto worldSize = fl::getWorldSize(); + auto worldRank = fl::getWorldRank(); + bool isMaster = (worldRank == 0); + fl::setSeed(worldRank); + + auto reducer = std::make_shared( + /*scale=*/ 1.0 / worldSize, + /*async=*/ true, + /*contiguous=*/ true + ); + + // Create dataset + const int nSamples = 10000 / worldSize; + const int nFeat = 10; + auto X = fl::rand({nFeat, nSamples}) + 1; // X elements in [1, 2] + auto Y = /* signal */ fl::transpose(fl::sum(fl::power(X, 3), {0})) + + /* noise */ fl::sin(2 * M_PI * fl::rand({nSamples})); + // Create Dataset to simplify the code for iterating over samples + TensorDataset data({X, Y}); + + const int inputIdx = 0, targetIdx = 1; + + // Model definition - 2-layer Perceptron with ReLU activation + auto model = std::make_shared(); + model->add(Linear(nFeat, 100)); + model->add(ReLU()); + model->add(Linear(100, 1)); + // MSE loss + auto loss = MeanSquaredError(); + + // synchronize parameters of the model so that the parameters in each process + // is the same + fl::allReduceParameters(model); + + // Add a hook to synchronize gradients of model parameters as they are + // computed + fl::distributeModuleGrads(model, reducer); + + // Optimizer definition + const float learningRate = 0.0001; + const float momentum = 0.9; + 
auto sgd = SGDOptimizer(model->params(), learningRate, momentum); + + // Meter definition + AverageValueMeter meter; + + // Start training + + if(isMaster) { + std::cout << "[Multi-layer Perceptron] Started..." << std::endl; } + const int nEpochs = 100; + for(int e = 1; e <= nEpochs; ++e) { + meter.reset(); + for(auto& sample : data) { + sgd.zeroGrad(); - auto mse = meter.value(); - auto mseArr = Tensor::fromBuffer({1}, mse.data(), MemoryLocation::Host); + // Forward propagation + auto result = model->forward(input(sample[inputIdx])); - fl::allReduce(mseArr); - if (isMaster) { - std::cout << "Epoch: " << e << " Mean Squared Error: " - << mseArr.scalar() / worldSize << std::endl; + // Calculate loss + auto l = loss(result, noGrad(sample[targetIdx])); + + // Backward propagation + l.backward(); + reducer->finalize(); + + // Update parameters + sgd.step(); + + meter.add(l.scalar()); + } + + auto mse = meter.value(); + auto mseArr = Tensor::fromBuffer({1}, mse.data(), MemoryLocation::Host); + + fl::allReduce(mseArr); + if(isMaster) { + std::cout << "Epoch: " << e << " Mean Squared Error: " + << mseArr.scalar() / worldSize << std::endl; + } + } + if(isMaster) { + std::cout << "[Multi-layer Perceptron] Done!" << std::endl; } - } - if (isMaster) { - std::cout << "[Multi-layer Perceptron] Done!" 
<< std::endl; - } - return 0; + return 0; } diff --git a/flashlight/fl/examples/LinearRegression.cpp b/flashlight/fl/examples/LinearRegression.cpp index c6ce26b..43d374f 100644 --- a/flashlight/fl/examples/LinearRegression.cpp +++ b/flashlight/fl/examples/LinearRegression.cpp @@ -15,53 +15,53 @@ #include "flashlight/fl/tensor/TensorBase.h" int main() { - fl::init(); + fl::init(); - // Create data - const int nSamples = 10000; - const int nFeat = 10; - auto X = fl::rand({nFeat, nSamples}) + 1; // X elements in [1, 2] - auto Y = /* signal */ fl::transpose(fl::sum(fl::power(X, 3), {0})) + - /* noise */ fl::sin(2 * M_PI * fl::rand({nSamples})); + // Create data + const int nSamples = 10000; + const int nFeat = 10; + auto X = fl::rand({nFeat, nSamples}) + 1; // X elements in [1, 2] + auto Y = /* signal */ fl::transpose(fl::sum(fl::power(X, 3), {0})) + + /* noise */ fl::sin(2 * M_PI * fl::rand({nSamples})); - // Training params - const int nEpochs = 100; - const float learningRate = 0.001; - auto weight = fl::Variable(fl::rand({1, nFeat}), true /* isCalcGrad */); - auto bias = fl::Variable(fl::full({1}, 0.0), true /* isCalcGrad */); + // Training params + const int nEpochs = 100; + const float learningRate = 0.001; + auto weight = fl::Variable(fl::rand({1, nFeat}), true /* isCalcGrad */); + auto bias = fl::Variable(fl::full({1}, 0.0), true /* isCalcGrad */); - std::cout << "[Linear Regression] Started..." << std::endl; + std::cout << "[Linear Regression] Started..." 
<< std::endl; - for (int e = 1; e <= nEpochs; ++e) { - fl::Tensor error = fl::full({1}, 0); - for (int i = 0; i < nSamples; ++i) { - auto input = fl::Variable(X(fl::span, i), false /* isCalcGrad */); - auto yPred = fl::matmul(weight, input) + bias; + for(int e = 1; e <= nEpochs; ++e) { + fl::Tensor error = fl::full({1}, 0); + for(int i = 0; i < nSamples; ++i) { + auto input = fl::Variable(X(fl::span, i), false /* isCalcGrad */); + auto yPred = fl::matmul(weight, input) + bias; - auto yTrue = fl::Variable(Y(i), false /* isCalcGrad */); + auto yTrue = fl::Variable(Y(i), false /* isCalcGrad */); - // Mean Squared Error - auto loss = ((yPred - yTrue) * (yPred - yTrue)) / nSamples; + // Mean Squared Error + auto loss = ((yPred - yTrue) * (yPred - yTrue)) / nSamples; - // Compute gradients using backprop - loss.backward(); + // Compute gradients using backprop + loss.backward(); - // Update the weight and bias - weight.tensor() = weight.tensor() - learningRate * weight.grad().tensor(); - bias.tensor() = bias.tensor() - learningRate * bias.grad().tensor(); + // Update the weight and bias + weight.tensor() = weight.tensor() - learningRate * weight.grad().tensor(); + bias.tensor() = bias.tensor() - learningRate * bias.grad().tensor(); - // clear the gradients for next iteration - weight.zeroGrad(); - bias.zeroGrad(); + // clear the gradients for next iteration + weight.zeroGrad(); + bias.zeroGrad(); - error += loss.tensor(); - } + error += loss.tensor(); + } - std::cout << "Epoch: " << e - << " Mean Squared Error: " << error.scalar() << std::endl; - } + std::cout << "Epoch: " << e + << " Mean Squared Error: " << error.scalar() << std::endl; + } - std::cout << "[Linear Regression] Done!" << std::endl; + std::cout << "[Linear Regression] Done!" 
<< std::endl; - return 0; + return 0; } diff --git a/flashlight/fl/examples/Mnist.cpp b/flashlight/fl/examples/Mnist.cpp index 7d8b4e7..62f1075 100644 --- a/flashlight/fl/examples/Mnist.cpp +++ b/flashlight/fl/examples/Mnist.cpp @@ -45,156 +45,163 @@ const int INPUT_IDX = 0; const int TARGET_IDX = 1; std::pair eval_loop(Sequential& model, BatchDataset& dataset) { - AverageValueMeter loss_meter; - FrameErrorMeter error_meter; - - // Place the model in eval mode. - model.eval(); - for (auto& example : dataset) { - auto inputs = noGrad(example[INPUT_IDX]); - auto output = model(inputs); - - // Get the predictions in max_ids - Tensor max_vals, max_ids; - max(max_vals, max_ids, output.tensor(), 0); - - auto target = noGrad(example[TARGET_IDX]); - - // Compute and record the prediction error. - error_meter.add(transpose(max_ids, {1, 0}), target.tensor()); - - // Compute and record the loss. - auto loss = categoricalCrossEntropy(output, target); - loss_meter.add(loss.tensor().scalar()); - } - // Place the model back into train mode. - model.train(); - - double error = error_meter.value(); - double loss = loss_meter.value()[0]; - return std::make_pair(loss, error); + AverageValueMeter loss_meter; + FrameErrorMeter error_meter; + + // Place the model in eval mode. + model.eval(); + for(auto& example : dataset) { + auto inputs = noGrad(example[INPUT_IDX]); + auto output = model(inputs); + + // Get the predictions in max_ids + Tensor max_vals, max_ids; + max(max_vals, max_ids, output.tensor(), 0); + + auto target = noGrad(example[TARGET_IDX]); + + // Compute and record the prediction error. + error_meter.add(transpose(max_ids, {1, 0}), target.tensor()); + + // Compute and record the loss. + auto loss = categoricalCrossEntropy(output, target); + loss_meter.add(loss.tensor().scalar()); + } + // Place the model back into train mode. 
+ model.train(); + + double error = error_meter.value(); + double loss = loss_meter.value()[0]; + return std::make_pair(loss, error); } std::pair load_dataset( const std::string& data_dir, - bool test = false); + bool test = false +); } // namespace int main(int argc, char** argv) { - fl::init(); - if (argc != 2) { - throw std::runtime_error("You must pass a data directory."); - } - fl::setSeed(1); - std::string data_dir = argv[1]; - - float learning_rate = 1e-2; - int epochs = 10; - int batch_size = 64; - - Tensor train_x; - Tensor train_y; - std::tie(train_x, train_y) = load_dataset(data_dir); - - // Hold out a dev set - auto val_x = train_x(span, span, 0, fl::range(0, VAL_SIZE)); - train_x = train_x(span, span, 0, fl::range(VAL_SIZE, TRAIN_SIZE)); - auto val_y = train_y(fl::range(0, VAL_SIZE)); - train_y = train_y(fl::range(VAL_SIZE, TRAIN_SIZE)); - - // Make the training batch dataset - BatchDataset trainset( - std::make_shared(std::vector{train_x, train_y}), - batch_size); - - // Make the validation batch dataset - BatchDataset valset( - std::make_shared(std::vector{val_x, val_y}), - batch_size); - - Sequential model; - auto pad = PaddingMode::SAME; - model.add(View({IM_DIM, IM_DIM, 1, -1})); - model.add(Conv2D( - 1 /* input channels */, - 32 /* output channels */, - 5 /* kernel width */, - 5 /* kernel height */, - 1 /* stride x */, - 1 /* stride y */, - pad /* padding mode */, - pad /* padding mode */)); - model.add(ReLU()); - model.add(Pool2D( - 2 /* kernel width */, - 2 /* kernel height */, - 2 /* stride x */, - 2 /* stride y */)); - model.add(Conv2D(32, 64, 5, 5, 1, 1, pad, pad)); - model.add(ReLU()); - model.add(Pool2D(2, 2, 2, 2)); - model.add(View({7 * 7 * 64, -1})); - model.add(Linear(7 * 7 * 64, 1024)); - model.add(ReLU()); - model.add(Dropout(0.5)); - model.add(Linear(1024, 10)); - model.add(LogSoftmax()); - - // Make the optimizer - SGDOptimizer opt(model.params(), learning_rate); - - // The main training loop - for (int e = 0; e < epochs; e++) { - 
AverageValueMeter train_loss_meter; - - // Get an iterator over the data - for (auto& example : trainset) { - // Make a Variable from the input tensor. - auto inputs = noGrad(example[INPUT_IDX]); - - // Get the activations from the model. - auto output = model(inputs); - - // Make a Variable from the target tensor. - auto target = noGrad(example[TARGET_IDX]); - - // Compute and record the loss. - auto loss = categoricalCrossEntropy(output, target); - train_loss_meter.add(loss.tensor().scalar()); - - // Backprop, update the weights and then zero the gradients. - loss.backward(); - opt.step(); - opt.zeroGrad(); + fl::init(); + if(argc != 2) { + throw std::runtime_error("You must pass a data directory."); + } + fl::setSeed(1); + std::string data_dir = argv[1]; + + float learning_rate = 1e-2; + int epochs = 10; + int batch_size = 64; + + Tensor train_x; + Tensor train_y; + std::tie(train_x, train_y) = load_dataset(data_dir); + + // Hold out a dev set + auto val_x = train_x(span, span, 0, fl::range(0, VAL_SIZE)); + train_x = train_x(span, span, 0, fl::range(VAL_SIZE, TRAIN_SIZE)); + auto val_y = train_y(fl::range(0, VAL_SIZE)); + train_y = train_y(fl::range(VAL_SIZE, TRAIN_SIZE)); + + // Make the training batch dataset + BatchDataset trainset( + std::make_shared(std::vector{train_x, train_y}), + batch_size); + + // Make the validation batch dataset + BatchDataset valset( + std::make_shared(std::vector{val_x, val_y}), + batch_size); + + Sequential model; + auto pad = PaddingMode::SAME; + model.add(View({IM_DIM, IM_DIM, 1, -1})); + model.add( + Conv2D( + 1 /* input channels */, + 32 /* output channels */, + 5 /* kernel width */, + 5 /* kernel height */, + 1 /* stride x */, + 1 /* stride y */, + pad /* padding mode */, + pad /* padding mode */ + ) + ); + model.add(ReLU()); + model.add( + Pool2D( + 2 /* kernel width */, + 2 /* kernel height */, + 2 /* stride x */, + 2 /* stride y */ + ) + ); + model.add(Conv2D(32, 64, 5, 5, 1, 1, pad, pad)); + model.add(ReLU()); + 
model.add(Pool2D(2, 2, 2, 2)); + model.add(View({7 * 7 * 64, -1})); + model.add(Linear(7 * 7 * 64, 1024)); + model.add(ReLU()); + model.add(Dropout(0.5)); + model.add(Linear(1024, 10)); + model.add(LogSoftmax()); + + // Make the optimizer + SGDOptimizer opt(model.params(), learning_rate); + + // The main training loop + for(int e = 0; e < epochs; e++) { + AverageValueMeter train_loss_meter; + + // Get an iterator over the data + for(auto& example : trainset) { + // Make a Variable from the input tensor. + auto inputs = noGrad(example[INPUT_IDX]); + + // Get the activations from the model. + auto output = model(inputs); + + // Make a Variable from the target tensor. + auto target = noGrad(example[TARGET_IDX]); + + // Compute and record the loss. + auto loss = categoricalCrossEntropy(output, target); + train_loss_meter.add(loss.tensor().scalar()); + + // Backprop, update the weights and then zero the gradients. + loss.backward(); + opt.step(); + opt.zeroGrad(); + } + + double train_loss = train_loss_meter.value()[0]; + + // Evaluate on the dev set. + double val_loss, val_error; + std::tie(val_loss, val_error) = eval_loop(model, valset); + + std::cout << "Epoch " << e << std::setprecision(3) + << ": Avg Train Loss: " << train_loss + << " Validation Loss: " << val_loss + << " Validation Error (%): " << val_error << std::endl; } - double train_loss = train_loss_meter.value()[0]; - - // Evaluate on the dev set. 
- double val_loss, val_error; - std::tie(val_loss, val_error) = eval_loop(model, valset); - - std::cout << "Epoch " << e << std::setprecision(3) - << ": Avg Train Loss: " << train_loss - << " Validation Loss: " << val_loss - << " Validation Error (%): " << val_error << std::endl; - } - - Tensor test_x; - Tensor test_y; - std::tie(test_x, test_y) = load_dataset(data_dir, true); + Tensor test_x; + Tensor test_y; + std::tie(test_x, test_y) = load_dataset(data_dir, true); - BatchDataset testset( - std::make_shared(std::vector{test_x, test_y}), - batch_size); + BatchDataset testset( + std::make_shared(std::vector{test_x, test_y}), + batch_size); - double test_loss, test_error; - std::tie(test_loss, test_error) = eval_loop(model, testset); - std::cout << "Test Loss: " << test_loss << " Test Error (%): " << test_error - << std::endl; + double test_loss, test_error; + std::tie(test_loss, test_error) = eval_loop(model, testset); + std::cout << "Test Loss: " << test_loss << " Test Error (%): " << test_error + << std::endl; - return 0; + return 0; } namespace { @@ -202,62 +209,64 @@ namespace { // MNIST Data loading functions below. 
int read_int(std::ifstream& f) { - int d = 0; - int c; - for (int i = 0; i < sizeof(int); i++) { - c = 0; - f.read((char*)&c, 1); - d |= (c << (8 * (sizeof(int) - i - 1))); - } - return d; + int d = 0; + int c; + for(int i = 0; i < sizeof(int); i++) { + c = 0; + f.read((char*) &c, 1); + d |= (c << (8 * (sizeof(int) - i - 1))); + } + return d; } -template +template Tensor load_data( const std::string& im_file, - const std::vector& dims) { - std::ifstream file(im_file, std::ios::binary); - if (!file.is_open()) { - throw std::runtime_error("[mnist:load_data] Can't find MNIST file."); - } - read_int(file); // unused magic - size_t elems = 1; - for (auto d : dims) { - int read_d = read_int(file); - elems *= read_d; - if (read_d != d) { - throw std::runtime_error("[mnist:load_data] Unexpected MNIST dimension."); + const std::vector& dims +) { + std::ifstream file(im_file, std::ios::binary); + if(!file.is_open()) { + throw std::runtime_error("[mnist:load_data] Can't find MNIST file."); } - } - - std::vector data; - data.reserve(elems); - for (int i = 0; i < elems; i++) { - unsigned char tmp; - file.read((char*)&tmp, sizeof(tmp)); - data.push_back(tmp); - } - - std::vector rdims(dims.rbegin(), dims.rend()); - // af is column-major - return Tensor::fromBuffer(Shape(rdims), data.data(), MemoryLocation::Host); + read_int(file); // unused magic + size_t elems = 1; + for(auto d : dims) { + int read_d = read_int(file); + elems *= read_d; + if(read_d != d) { + throw std::runtime_error("[mnist:load_data] Unexpected MNIST dimension."); + } + } + + std::vector data; + data.reserve(elems); + for(int i = 0; i < elems; i++) { + unsigned char tmp; + file.read((char*) &tmp, sizeof(tmp)); + data.push_back(tmp); + } + + std::vector rdims(dims.rbegin(), dims.rend()); + // af is column-major + return Tensor::fromBuffer(Shape(rdims), data.data(), MemoryLocation::Host); } std::pair load_dataset( const std::string& data_dir, - bool test /* = false */) { - std::string f = test ? 
"t10k" : "train"; - int size = test ? TEST_SIZE : TRAIN_SIZE; + bool test /* = false */ +) { + std::string f = test ? "t10k" : "train"; + int size = test ? TEST_SIZE : TRAIN_SIZE; - std::string image_file = data_dir + "/" + f + "-images-idx3-ubyte"; - Tensor ims = load_data(image_file, {size, IM_DIM, IM_DIM}); - ims = reshape(ims, {IM_DIM, IM_DIM, 1, size}); - // Rescale to [-0.5, 0.5] - ims = (ims - PIXEL_MAX / 2) / PIXEL_MAX; + std::string image_file = data_dir + "/" + f + "-images-idx3-ubyte"; + Tensor ims = load_data(image_file, {size, IM_DIM, IM_DIM}); + ims = reshape(ims, {IM_DIM, IM_DIM, 1, size}); + // Rescale to [-0.5, 0.5] + ims = (ims - PIXEL_MAX / 2) / PIXEL_MAX; - std::string label_file = data_dir + "/" + f + "-labels-idx1-ubyte"; - Tensor labels = load_data(label_file, {size}); + std::string label_file = data_dir + "/" + f + "-labels-idx1-ubyte"; + Tensor labels = load_data(label_file, {size}); - return std::make_pair(ims, labels); + return std::make_pair(ims, labels); } } // namespace diff --git a/flashlight/fl/examples/Perceptron.cpp b/flashlight/fl/examples/Perceptron.cpp index e7d79f1..78407dc 100644 --- a/flashlight/fl/examples/Perceptron.cpp +++ b/flashlight/fl/examples/Perceptron.cpp @@ -18,61 +18,61 @@ using namespace fl; int main() { - fl::init(); - - // Create dataset - const int nSamples = 10000; - const int nFeat = 10; - auto X = fl::rand({nFeat, nSamples}) + 1; // X elements in [1, 2] - auto Y = /* signal */ fl::transpose(fl::sum(fl::power(X, 3), {0})) + - /* noise */ fl::sin(2 * M_PI * fl::rand({nSamples})); - // Create Dataset to simplify the code for iterating over samples - TensorDataset data({X, Y}); - const int inputIdx = 0, targetIdx = 1; - - // Model definition - 2-layer Perceptron with ReLU activation - Sequential model; - model.add(Linear(nFeat, 100)); - model.add(ReLU()); - model.add(Linear(100, 1)); - // MSE loss - auto loss = MeanSquaredError(); - - // Optimizer definition - const float learningRate = 0.0001; - const float 
momentum = 0.9; - auto sgd = SGDOptimizer(model.params(), learningRate, momentum); - - // Meter definition - AverageValueMeter meter; - - // Start training - - std::cout << "[Multi-layer Perceptron] Started..." << std::endl; - - const int nEpochs = 100; - for (int e = 1; e <= nEpochs; ++e) { - meter.reset(); - for (auto& sample : data) { - sgd.zeroGrad(); - - // Forward propagation - auto result = model(input(sample[inputIdx])); - - // Calculate loss - auto l = loss(result, noGrad(sample[targetIdx])); - - // Backward propagation - l.backward(); - - // Update parameters - sgd.step(); - - meter.add(l.scalar()); + fl::init(); + + // Create dataset + const int nSamples = 10000; + const int nFeat = 10; + auto X = fl::rand({nFeat, nSamples}) + 1; // X elements in [1, 2] + auto Y = /* signal */ fl::transpose(fl::sum(fl::power(X, 3), {0})) + + /* noise */ fl::sin(2 * M_PI * fl::rand({nSamples})); + // Create Dataset to simplify the code for iterating over samples + TensorDataset data({X, Y}); + const int inputIdx = 0, targetIdx = 1; + + // Model definition - 2-layer Perceptron with ReLU activation + Sequential model; + model.add(Linear(nFeat, 100)); + model.add(ReLU()); + model.add(Linear(100, 1)); + // MSE loss + auto loss = MeanSquaredError(); + + // Optimizer definition + const float learningRate = 0.0001; + const float momentum = 0.9; + auto sgd = SGDOptimizer(model.params(), learningRate, momentum); + + // Meter definition + AverageValueMeter meter; + + // Start training + + std::cout << "[Multi-layer Perceptron] Started..." 
<< std::endl; + + const int nEpochs = 100; + for(int e = 1; e <= nEpochs; ++e) { + meter.reset(); + for(auto& sample : data) { + sgd.zeroGrad(); + + // Forward propagation + auto result = model(input(sample[inputIdx])); + + // Calculate loss + auto l = loss(result, noGrad(sample[targetIdx])); + + // Backward propagation + l.backward(); + + // Update parameters + sgd.step(); + + meter.add(l.scalar()); + } + std::cout << "Epoch: " << e << " Mean Squared Error: " << meter.value()[0] + << std::endl; } - std::cout << "Epoch: " << e << " Mean Squared Error: " << meter.value()[0] - << std::endl; - } - std::cout << "[Multi-layer Perceptron] Done!" << std::endl; - return 0; + std::cout << "[Multi-layer Perceptron] Done!" << std::endl; + return 0; } diff --git a/flashlight/fl/examples/RnnClassification.cpp b/flashlight/fl/examples/RnnClassification.cpp index 61fed91..ec32d4a 100644 --- a/flashlight/fl/examples/RnnClassification.cpp +++ b/flashlight/fl/examples/RnnClassification.cpp @@ -8,7 +8,7 @@ /* Approximate re implementation of the char rnn PyTorch tutorial: https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html Dataset: https://download.pytorch.org/tutorial/data.zip -*/ + */ #include #include @@ -33,347 +33,353 @@ using namespace fl; // return a random int between [mini, maxi] int randi(int mini, int maxi) { - if (maxi < mini) { - std::swap(maxi, mini); - } - return rand() % (maxi - mini + 1) + mini; + if(maxi < mini) { + std::swap(maxi, mini); + } + return rand() % (maxi - mini + 1) + mini; } class ClassificationDataset : public Dataset { - public: - using names = std::vector; - std::map datasets; - static std::map Label2Id; - static std::map Id2Label; - unsigned totalExamples = 0; - - // read the folder/lang.txt file and register the examples in the datasets map - void read(const fs::path& folder, const std::string& lang) { - auto fp = folder / (lang + ".txt"); - std::cout << "Opening " << fp << std::endl; - std::ifstream file(fp); - if 
(!file.is_open()) { - throw std::runtime_error("Can't open the input dataset file"); +public: + using names = std::vector; + std::map datasets; + static std::map Label2Id; + static std::map Id2Label; + unsigned totalExamples = 0; + + // read the folder/lang.txt file and register the examples in the datasets map + void read(const fs::path& folder, const std::string& lang) { + auto fp = folder / (lang + ".txt"); + std::cout << "Opening " << fp << std::endl; + std::ifstream file(fp); + if(!file.is_open()) { + throw std::runtime_error("Can't open the input dataset file"); + } + unsigned id = Label2Id.size(); + Label2Id[lang] = id; + Id2Label[id] = lang; + names v; + std::string line; + while(std::getline(file, line)) { + if(line.empty()) { + continue; + } + v.push_back(line); + } + totalExamples += v.size(); + std::cout << "Found " << v.size() << " examples for category " << lang + << ". Total: " << totalExamples << std::endl; + datasets[lang] = v; } - unsigned id = Label2Id.size(); - Label2Id[lang] = id; - Id2Label[id] = lang; - names v; - std::string line; - while (std::getline(file, line)) { - if (line.empty()) { - continue; - } - v.push_back(line); + + // Turn a string into an AF array <1 x line_length> of char indices + static Tensor lineToTensor(const std::string& line) { + std::vector d; + for(char c : line) { + d.push_back(static_cast(c)); // direct cast of char to float + } + return Tensor::fromBuffer( + {1, static_cast(d.size())}, + d.data(), + MemoryLocation::Host + ); + } + + explicit ClassificationDataset(const fs::path& datasetPath) { + // As found in the dataset folder: + std::vector lang = { + "Arabic", + "Greek", + "Chinese", + "Czech", + "Dutch", + "Japanese", + "Korean", + "Russian", + "English", + "Scottish", + "Vietnamese", + "German", + "Spanish", + "French", + "Polish", + "Italian", + "Irish"}; + for(auto& l : lang) { + read(datasetPath, l); + } + for(auto& it : Id2Label) { + std::cout << it.first << ":" << it.second << ", "; + } + std::cout << 
std::endl; } - totalExamples += v.size(); - std::cout << "Found " << v.size() << " examples for category " << lang - << ". Total: " << totalExamples << std::endl; - datasets[lang] = v; - } - - // Turn a string into an AF array <1 x line_length> of char indices - static Tensor lineToTensor(const std::string& line) { - std::vector d; - for (char c : line) { - d.push_back(static_cast(c)); // direct cast of char to float + + // each epoch to go over some percent of the training dataset + int64_t size() const override { + return .3f * totalExamples; } - return Tensor::fromBuffer( - {1, static_cast(d.size())}, d.data(), MemoryLocation::Host); - } - - explicit ClassificationDataset(const fs::path& datasetPath) { - // As found in the dataset folder: - std::vector lang = { - "Arabic", - "Greek", - "Chinese", - "Czech", - "Dutch", - "Japanese", - "Korean", - "Russian", - "English", - "Scottish", - "Vietnamese", - "German", - "Spanish", - "French", - "Polish", - "Italian", - "Irish"}; - for (auto& l : lang) { - read(datasetPath, l); + + // get a random example: name, category + std::pair getRandomExample() const { + std::pair p; + auto it = datasets.begin(); + unsigned cat = randi(0, datasets.size() - 1); // random category index + std::advance(it, cat); + p.second = it->first; // category name + const auto& v = it->second; + unsigned nn = v.size(); + unsigned ri = randi(0, nn - 1); + p.first = v[ri]; // name + return p; } - for (auto& it : Id2Label) { - std::cout << it.first << ":" << it.second << ", "; + + // get a (random) example and return a vector of 2 tensors : the input and the + // expected category index + std::vector get(const int64_t) const override { + auto p = getRandomExample(); + const std::string& n = p.first; + Tensor input = lineToTensor(n); + std::vector cv; + cv.push_back(Label2Id[p.second]); + auto expected = Tensor::fromBuffer({1}, cv.data(), MemoryLocation::Host); + return {input, expected}; } - std::cout << std::endl; - } - - // each epoch to go over 
some percent of the training dataset - int64_t size() const override { - return .3f * totalExamples; - } - - // get a random example: name, category - std::pair getRandomExample() const { - std::pair p; - auto it = datasets.begin(); - unsigned cat = randi(0, datasets.size() - 1); // random category index - std::advance(it, cat); - p.second = it->first; // category name - const auto& v = it->second; - unsigned nn = v.size(); - unsigned ri = randi(0, nn - 1); - p.first = v[ri]; // name - return p; - } - - // get a (random) example and return a vector of 2 tensors : the input and the - // expected category index - std::vector get(const int64_t) const override { - auto p = getRandomExample(); - const std::string& n = p.first; - Tensor input = lineToTensor(n); - std::vector cv; - cv.push_back(Label2Id[p.second]); - auto expected = Tensor::fromBuffer({1}, cv.data(), MemoryLocation::Host); - return {input, expected}; - } }; std::map ClassificationDataset::Label2Id; std::map ClassificationDataset::Id2Label; class RnnClassifier : public Container { - public: - explicit RnnClassifier( - unsigned numClasses, - unsigned vocabSize, - int hiddenSize = 256, - unsigned numLayers = 2) - : embed_(std::make_shared(hiddenSize, vocabSize)), +public: + explicit RnnClassifier( + unsigned numClasses, + unsigned vocabSize, + int hiddenSize = 256, + unsigned numLayers = 2 + ) : embed_(std::make_shared(hiddenSize, vocabSize)), rnn_(std::make_shared( hiddenSize, hiddenSize, numLayers, RnnMode::GRU, - 0 /* Dropout */)), + 0 /* Dropout */ + )), linear_(std::make_shared(hiddenSize, numClasses)), logsoftmax_(0) { - std::cout << "Creating a RNN Classifier with vocab size: " << vocabSize - << " and num classes: " << numClasses << std::endl; - createLayers(); - } - - RnnClassifier(const RnnClassifier& other) { - copy(other); - createLayers(); - } - - // The compiler default generated is made explicit for reference. 
- // Users must be careful to include move and move assignment - // constructors where appropriate. - RnnClassifier(RnnClassifier&& other) = default; - - RnnClassifier& operator=(const RnnClassifier& other) { - clear(); - copy(other); - createLayers(); - return *this; - } - - // The compiler default generated is made explicit for reference. - // Users must be careful to include move and move assignment - // constructors where appropriate. - RnnClassifier& operator=(RnnClassifier&& other) = default; - - void copy(const RnnClassifier& other) { - embed_ = std::make_shared(*other.embed_); - rnn_ = std::make_shared(*other.rnn_); - linear_ = std::make_shared(*other.linear_); - logsoftmax_ = other.logsoftmax_; - } - - void createLayers() { - add(embed_); - add(rnn_); - add(linear_); - } - - std::unique_ptr clone() const override { - return std::make_unique(*this); - } - - std::vector forward(const std::vector& inputs) override { - throw std::runtime_error("Not implemented"); - } - - std::tuple - forward(const Variable& input, const Variable& h, const Variable& c) { - const unsigned numChars = input.dim(1); - Variable ho, co; // hidden and carry output - // output should be hs x bs x ts : [ 𝑋𝑖𝑛, 𝑁, 𝑇 ] - Variable output = embed_->forward(input); - // The input to the RNN is expected to be of shape [ 𝑋𝑖𝑛, 𝑁, 𝑇 ] where 𝑋𝑖𝑛 - // is the hidden size, 𝑁 is the batch size and 𝑇 is the sequence length. 
- std::tie(output, ho, co) = rnn_->forward(output, h, c); - // The output of the RNN should be of shape [ 𝑋𝑜𝑢𝑡, 𝑁, 𝑇], with 𝑋𝑜𝑢𝑡 to be - // the hidden_size (unidirectional RNN) - // Truncate BPTT - ho.setCalcGrad(false); - co.setCalcGrad(false); - output = linear_->forward(output); - output = logsoftmax_(output); - output = output(fl::span, fl::span, input.tensor().dim(1) - 1); - return std::make_tuple(output, ho, co); - } - - std::string prettyString() const override { - return "RnnClassifer"; - } - - // Inference on the given input: returns the category index - unsigned - infer(const std::string& inputString, const Variable& h, const Variable& c) { - Tensor ia = ClassificationDataset::lineToTensor(inputString); - Variable output, ho, co; - std::tie(output, ho, co) = forward(noGrad(ia), h, c); - Tensor maxValue, prediction; - fl::max(maxValue, prediction, output.tensor(), 0); - unsigned classId = prediction.scalar(); - return classId; - } - - // Predict the category of the given input, compared with the expected label - // and print the result - bool unittest(const std::string& input, const std::string& expectedLabel) { - Variable output, h, c; - auto p = ClassificationDataset::Id2Label[infer(input, h, c)]; - const bool passes = p == expectedLabel; - const std::string s = (passes ? "✓ " : "✗ "); - std::cout << "input: " << std::setw(20) << input - << "\t expected: " << expectedLabel << "\t prediction: " << p - << "\t" << s << std::endl; - return passes; - } - - private: - std::shared_ptr embed_; - std::shared_ptr rnn_; - std::shared_ptr linear_; - LogSoftmax logsoftmax_; + std::cout << "Creating a RNN Classifier with vocab size: " << vocabSize + << " and num classes: " << numClasses << std::endl; + createLayers(); + } + + RnnClassifier(const RnnClassifier& other) { + copy(other); + createLayers(); + } + + // The compiler default generated is made explicit for reference. 
+ // Users must be careful to include move and move assignment + // constructors where appropriate. + RnnClassifier(RnnClassifier&& other) = default; + + RnnClassifier& operator=(const RnnClassifier& other) { + clear(); + copy(other); + createLayers(); + return *this; + } + + // The compiler default generated is made explicit for reference. + // Users must be careful to include move and move assignment + // constructors where appropriate. + RnnClassifier& operator=(RnnClassifier&& other) = default; + + void copy(const RnnClassifier& other) { + embed_ = std::make_shared(*other.embed_); + rnn_ = std::make_shared(*other.rnn_); + linear_ = std::make_shared(*other.linear_); + logsoftmax_ = other.logsoftmax_; + } + + void createLayers() { + add(embed_); + add(rnn_); + add(linear_); + } + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } + + std::vector forward(const std::vector& inputs) override { + throw std::runtime_error("Not implemented"); + } + + std::tuple forward( + const Variable& input, + const Variable& h, + const Variable& c + ) { + const unsigned numChars = input.dim(1); + Variable ho, co; // hidden and carry output + // output should be hs x bs x ts : [ 𝑋𝑖𝑛, 𝑁, 𝑇 ] + Variable output = embed_->forward(input); + // The input to the RNN is expected to be of shape [ 𝑋𝑖𝑛, 𝑁, 𝑇 ] where 𝑋𝑖𝑛 + // is the hidden size, 𝑁 is the batch size and 𝑇 is the sequence length. 
+ std::tie(output, ho, co) = rnn_->forward(output, h, c); + // The output of the RNN should be of shape [ 𝑋𝑜𝑢𝑡, 𝑁, 𝑇], with 𝑋𝑜𝑢𝑡 to be + // the hidden_size (unidirectional RNN) + // Truncate BPTT + ho.setCalcGrad(false); + co.setCalcGrad(false); + output = linear_->forward(output); + output = logsoftmax_(output); + output = output(fl::span, fl::span, input.tensor().dim(1) - 1); + return std::make_tuple(output, ho, co); + } + + std::string prettyString() const override { + return "RnnClassifer"; + } + + // Inference on the given input: returns the category index + unsigned infer(const std::string& inputString, const Variable& h, const Variable& c) { + Tensor ia = ClassificationDataset::lineToTensor(inputString); + Variable output, ho, co; + std::tie(output, ho, co) = forward(noGrad(ia), h, c); + Tensor maxValue, prediction; + fl::max(maxValue, prediction, output.tensor(), 0); + unsigned classId = prediction.scalar(); + return classId; + } + + // Predict the category of the given input, compared with the expected label + // and print the result + bool unittest(const std::string& input, const std::string& expectedLabel) { + Variable output, h, c; + auto p = ClassificationDataset::Id2Label[infer(input, h, c)]; + const bool passes = p == expectedLabel; + const std::string s = (passes ? 
"✓ " : "✗ "); + std::cout << "input: " << std::setw(20) << input + << "\t expected: " << expectedLabel << "\t prediction: " << p + << "\t" << s << std::endl; + return passes; + } + +private: + std::shared_ptr embed_; + std::shared_ptr rnn_; + std::shared_ptr linear_; + LogSoftmax logsoftmax_; }; int main(int argc, char** argv) { - fl::init(); - std::cout << "RnnClassification (path to the data dir) (learning rate) (num " - "epochs) (hiddensize)" - << std::endl; - std::cout << "Dataset : https://download.pytorch.org/tutorial/data.zip" - << std::endl; - if (argc < 2) { - std::cout << "To setup the dataset: " << std::endl; - std::cout << "wget https://download.pytorch.org/tutorial/data.zip" - << std::endl; - std::cout << "unzip data.zip" << std::endl; - std::cout << "./RnnClassification data/names" << std::endl; - return 0; - } - fs::path dataDir = argv[1]; - - // To reproduce the pytorch tutorial, the dataset is not split in - // train/dev/test sub datasets but random samples are simply picked from the - // overall dataset: - ClassificationDataset trainSet(dataDir); - - const float learningRate = argc > 2 ? std::stof(argv[2]) : 0.1; - const int epochs = argc > 3 ? std::stol(argv[3]) : 6; - const unsigned hiddenSize = argc > 4 ? 
std::stol(argv[4]) : 256; - const float momentum = 0.9; - const float maxGradNorm = 0.25; - - RnnClassifier model( - ClassificationDataset::Label2Id.size(), - 256, // input vocab size set to 256 to support any possible character, - // ascii or not - hiddenSize); - // https://fl.readthedocs.io/en/latest/modules.html#categoricalcrossentropy - CategoricalCrossEntropy criterion; - auto opt = SGDOptimizer(model.params(), learningRate, momentum); - - // Each epoch to go over a small percent of the dataset - for (int e = 0; e < epochs; e++) { - AverageValueMeter trainLossMeter; - Variable output, h, c; - const int kInputIdx = 0, kTargetIdx = 1; - for (auto& example : trainSet) { - std::tie(output, h, c) = model.forward(noGrad(example[kInputIdx]), h, c); - auto target = noGrad(example[kTargetIdx]); - // Computes the categorical cross entropy loss: - // The input is expected to contain log-probabilities for each class. - // The targets should be the index of the ground truth class for each - // input example. 
- auto loss = criterion(output, target); - trainLossMeter.add(loss.tensor().scalar(), target.elements()); - opt.zeroGrad(); - loss.backward(); - // Clipping is a must have to avoid exploding gradients: - clipGradNorm(model.params(), maxGradNorm); - opt.step(); + fl::init(); + std::cout << "RnnClassification (path to the data dir) (learning rate) (num " + "epochs) (hiddensize)" + << std::endl; + std::cout << "Dataset : https://download.pytorch.org/tutorial/data.zip" + << std::endl; + if(argc < 2) { + std::cout << "To setup the dataset: " << std::endl; + std::cout << "wget https://download.pytorch.org/tutorial/data.zip" + << std::endl; + std::cout << "unzip data.zip" << std::endl; + std::cout << "./RnnClassification data/names" << std::endl; + return 0; } - - double trainLoss = trainLossMeter.value()[0]; - std::cout << "Epoch " << e + 1 << std::setprecision(3) - << " - Train Loss: " << trainLoss << std::endl; - - // compute the accuracy confusion matrix: - const unsigned nCategories = ClassificationDataset::Label2Id.size(); - Tensor confusion = fl::full({nCategories, nCategories}, 0.); - // Go through a bunch of examples and record which are correctly guessed - float numMatch = 0, nConfusion = 1000; - for (unsigned i = 0; i < nConfusion; ++i) { - auto p = trainSet.getRandomExample(); - unsigned pred = model.infer(p.first, h, c); - unsigned correctPred = ClassificationDataset::Label2Id[p.second]; - if (pred == correctPred) { - ++numMatch; - } - confusion(correctPred, pred) = confusion(correctPred, pred) + 1; + fs::path dataDir = argv[1]; + + // To reproduce the pytorch tutorial, the dataset is not split in + // train/dev/test sub datasets but random samples are simply picked from the + // overall dataset: + ClassificationDataset trainSet(dataDir); + + const float learningRate = argc > 2 ? std::stof(argv[2]) : 0.1; + const int epochs = argc > 3 ? std::stol(argv[3]) : 6; + const unsigned hiddenSize = argc > 4 ? 
std::stol(argv[4]) : 256; + const float momentum = 0.9; + const float maxGradNorm = 0.25; + + RnnClassifier model( + ClassificationDataset::Label2Id.size(), + 256, // input vocab size set to 256 to support any possible character, + // ascii or not + hiddenSize); + // https://fl.readthedocs.io/en/latest/modules.html#categoricalcrossentropy + CategoricalCrossEntropy criterion; + auto opt = SGDOptimizer(model.params(), learningRate, momentum); + + // Each epoch to go over a small percent of the dataset + for(int e = 0; e < epochs; e++) { + AverageValueMeter trainLossMeter; + Variable output, h, c; + const int kInputIdx = 0, kTargetIdx = 1; + for(auto& example : trainSet) { + std::tie(output, h, c) = model.forward(noGrad(example[kInputIdx]), h, c); + auto target = noGrad(example[kTargetIdx]); + // Computes the categorical cross entropy loss: + // The input is expected to contain log-probabilities for each class. + // The targets should be the index of the ground truth class for each + // input example. 
+ auto loss = criterion(output, target); + trainLossMeter.add(loss.tensor().scalar(), target.elements()); + opt.zeroGrad(); + loss.backward(); + // Clipping is a must have to avoid exploding gradients: + clipGradNorm(model.params(), maxGradNorm); + opt.step(); + } + + double trainLoss = trainLossMeter.value()[0]; + std::cout << "Epoch " << e + 1 << std::setprecision(3) + << " - Train Loss: " << trainLoss << std::endl; + + // compute the accuracy confusion matrix: + const unsigned nCategories = ClassificationDataset::Label2Id.size(); + Tensor confusion = fl::full({nCategories, nCategories}, 0.); + // Go through a bunch of examples and record which are correctly guessed + float numMatch = 0, nConfusion = 1000; + for(unsigned i = 0; i < nConfusion; ++i) { + auto p = trainSet.getRandomExample(); + unsigned pred = model.infer(p.first, h, c); + unsigned correctPred = ClassificationDataset::Label2Id[p.second]; + if(pred == correctPred) { + ++numMatch; + } + confusion(correctPred, pred) = confusion(correctPred, pred) + 1; + } + confusion = confusion + / fl::tile(fl::sum(confusion, {1}), {1, nCategories}); // average + std::cout << "Global accuracy=" << numMatch / nConfusion << "\t "; + for(unsigned i = 0; i < nCategories; ++i) { + std::cout << ClassificationDataset::Id2Label[i] << ":" << std::fixed + << std::setprecision(2) << confusion(i, i).scalar() + << " "; + } + std::cout << std::endl; + } + // List of names not in the training dataset + const std::vector> quickList = { + {"Samad", "Arabic"}, + {"Papademos", "Greek"}, + {"Birovsky", "Czech"}, + {"Wai", "Chinese"}, + {"Nikolaev", "Russian"}, + {"Washington", "English"}, + {"Voltaire", "French"}, + {"Pfeiffer", "German"}, + {"Tambellini", "Italian"}}; + for(auto& p : quickList) { + model.unittest(p.first, p.second); } - confusion = confusion / - fl::tile(fl::sum(confusion, {1}), {1, nCategories}); // average - std::cout << "Global accuracy=" << numMatch / nConfusion << "\t "; - for (unsigned i = 0; i < nCategories; ++i) 
{ - std::cout << ClassificationDataset::Id2Label[i] << ":" << std::fixed - << std::setprecision(2) << confusion(i, i).scalar() - << " "; + + while(true) { + std::string name; + std::cout << "Enter a surname and press enter to classify it: "; + std::cin >> name; + Variable output, h, c; + std::cout << ClassificationDataset::Id2Label[model.infer(name, h, c)] + << " ?" << std::endl; } - std::cout << std::endl; - } - // List of names not in the training dataset - const std::vector> quickList = { - {"Samad", "Arabic"}, - {"Papademos", "Greek"}, - {"Birovsky", "Czech"}, - {"Wai", "Chinese"}, - {"Nikolaev", "Russian"}, - {"Washington", "English"}, - {"Voltaire", "French"}, - {"Pfeiffer", "German"}, - {"Tambellini", "Italian"}}; - for (auto& p : quickList) { - model.unittest(p.first, p.second); - } - - while (true) { - std::string name; - std::cout << "Enter a surname and press enter to classify it: "; - std::cin >> name; - Variable output, h, c; - std::cout << ClassificationDataset::Id2Label[model.infer(name, h, c)] - << " ?" 
<< std::endl; - } - std::cout << "Finished" << std::endl; - return 0; + std::cout << "Finished" << std::endl; + return 0; } diff --git a/flashlight/fl/examples/RnnLm.cpp b/flashlight/fl/examples/RnnLm.cpp index b279b39..6d4e3d3 100644 --- a/flashlight/fl/examples/RnnLm.cpp +++ b/flashlight/fl/examples/RnnLm.cpp @@ -40,284 +40,298 @@ using namespace fl; namespace { class Preprocessor { - public: - explicit Preprocessor(std::string dataset_path); +public: + explicit Preprocessor(std::string dataset_path); - int to_int(std::string word) { - return word_to_int[word]; - } + int to_int(std::string word) { + return word_to_int[word]; + } - int vocab_size() { - return word_to_int.size(); - } + int vocab_size() { + return word_to_int.size(); + } - static const std::string eos; + static const std::string eos; - private: - std::unordered_map word_to_int; +private: + std::unordered_map word_to_int; }; class LMDataset : public Dataset { - public: - LMDataset( - std::string dataset_path, - int batch_size, - int time_steps, - Preprocessor& preproc); - - int64_t size() const override { - return (data.dim(1) - 1) / time_steps; - } - - std::vector get(const int64_t idx) const override; - - private: - int time_steps; - Tensor data; +public: + LMDataset( + std::string dataset_path, + int batch_size, + int time_steps, + Preprocessor& preproc + ); + + int64_t size() const override { + return (data.dim(1) - 1) / time_steps; + } + + std::vector get(const int64_t idx) const override; + +private: + int time_steps; + Tensor data; }; class RnnLm : public Container { - public: - explicit RnnLm(int vocab_size, int hidden_size = 200) - : embed(std::make_shared(hidden_size, vocab_size)), - rnn(std::make_shared( - hidden_size, - hidden_size, - 2, /* Num layers. 
*/ - RnnMode::LSTM, - 0 /* Dropout */)), - linear(std::make_shared(hidden_size, vocab_size)), - logsoftmax_(0) // max on the main dimension - { - createLayers(); - } - - RnnLm(const RnnLm& other) { - copy(other); - createLayers(); - } - - // The compiler default generated is made explicit for reference. - // Users must be careful to include move and move assignment - // constructors where appropriate. - RnnLm(RnnLm&& other) = default; - - RnnLm& operator=(const RnnLm& other) { - clear(); - copy(other); - createLayers(); - return *this; - } - - // The compiler default generated is made explicit for reference. - // Users must be careful to include move and move assignment - // constructors where appropriate. - RnnLm& operator=(RnnLm&& other) = default; - - void copy(const RnnLm& other) { - train_ = other.train_; - embed = std::make_shared(*other.embed); - rnn = std::make_shared(*other.rnn); - linear = std::make_shared(*other.linear); - } - - void createLayers() { - add(embed); - add(rnn); - add(linear); - } - - std::vector forward(const std::vector& inputs) override { - auto inSz = inputs.size(); - if (inSz < 1 || inSz > 3) { - throw std::invalid_argument("Invalid inputs size"); +public: + explicit RnnLm(int vocab_size, int hidden_size = 200) : embed(std::make_shared( + hidden_size, + vocab_size + )), + rnn(std::make_shared( + hidden_size, + hidden_size, + 2, /* Num layers. */ + RnnMode::LSTM, + 0 /* Dropout */ + )), + linear(std::make_shared( + hidden_size, + vocab_size + )), + logsoftmax_(0) { // max on the main dimension + createLayers(); + } + + RnnLm(const RnnLm& other) { + copy(other); + createLayers(); + } + + // The compiler default generated is made explicit for reference. + // Users must be careful to include move and move assignment + // constructors where appropriate. 
+ RnnLm(RnnLm&& other) = default; + + RnnLm& operator=(const RnnLm& other) { + clear(); + copy(other); + createLayers(); + return *this; + } + + // The compiler default generated is made explicit for reference. + // Users must be careful to include move and move assignment + // constructors where appropriate. + RnnLm& operator=(RnnLm&& other) = default; + + void copy(const RnnLm& other) { + train_ = other.train_; + embed = std::make_shared(*other.embed); + rnn = std::make_shared(*other.rnn); + linear = std::make_shared(*other.linear); + } + + void createLayers() { + add(embed); + add(rnn); + add(linear); + } + + std::vector forward(const std::vector& inputs) override { + auto inSz = inputs.size(); + if(inSz < 1 || inSz > 3) { + throw std::invalid_argument("Invalid inputs size"); + } + return rnn->forward(inputs); + } + + std::tuple forward( + const Variable& input, + const Variable& h, + const Variable& c + ) { + auto output = embed->forward(input); + Variable ho, co; + std::tie(output, ho, co) = rnn->forward(output, h, c); + + // Truncate BPTT + ho.setCalcGrad(false); + co.setCalcGrad(false); + + output = linear->forward(output); + output = logsoftmax_(output); + return std::make_tuple(output, ho, co); + } + + std::tuple + operator()(const Variable& input, const Variable& h, const Variable& c) { + return forward(input, h, c); } - return rnn->forward(inputs); - } - - std::tuple - forward(const Variable& input, const Variable& h, const Variable& c) { - auto output = embed->forward(input); - Variable ho, co; - std::tie(output, ho, co) = rnn->forward(output, h, c); - - // Truncate BPTT - ho.setCalcGrad(false); - co.setCalcGrad(false); - - output = linear->forward(output); - output = logsoftmax_(output); - return std::make_tuple(output, ho, co); - } - - std::tuple - operator()(const Variable& input, const Variable& h, const Variable& c) { - return forward(input, h, c); - } - - std::string prettyString() const override { - return "RnnLm"; - } - - std::unique_ptr clone() 
const override { - return std::make_unique(*this); - } - - private: - std::shared_ptr embed; - std::shared_ptr rnn; - std::shared_ptr linear; - LogSoftmax logsoftmax_; + + std::string prettyString() const override { + return "RnnLm"; + } + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } + +private: + std::shared_ptr embed; + std::shared_ptr rnn; + std::shared_ptr linear; + LogSoftmax logsoftmax_; }; } // namespace int main(int argc, char** argv) { - fl::init(); - if (argc != 2) { - throw std::runtime_error("You must pass a data directory."); - } + fl::init(); + if(argc != 2) { + throw std::runtime_error("You must pass a data directory."); + } - std::string data_dir = argv[1]; + std::string data_dir = argv[1]; - std::string train_dir = data_dir + "/ptb.train.txt"; - std::string valid_dir = data_dir + "/ptb.valid.txt"; - std::string test_dir = data_dir + "/ptb.test.txt"; + std::string train_dir = data_dir + "/ptb.train.txt"; + std::string valid_dir = data_dir + "/ptb.valid.txt"; + std::string test_dir = data_dir + "/ptb.test.txt"; - // Since we also average the loss by time_steps - float learning_rate = 20; - float max_grad_norm = 0.25; + // Since we also average the loss by time_steps + float learning_rate = 20; + float max_grad_norm = 0.25; - int epochs = 10; - int anneal_after_epoch = 4; - int batch_size = 20; - int time_steps = 20; + int epochs = 10; + int anneal_after_epoch = 4; + int batch_size = 20; + int time_steps = 20; - Preprocessor preproc(train_dir); - LMDataset trainset(train_dir, batch_size, time_steps, preproc); - LMDataset valset(valid_dir, batch_size, time_steps, preproc); - const int kInputIdx = 0, kTargetIdx = 1; + Preprocessor preproc(train_dir); + LMDataset trainset(train_dir, batch_size, time_steps, preproc); + LMDataset valset(valid_dir, batch_size, time_steps, preproc); + const int kInputIdx = 0, kTargetIdx = 1; - int vocab_size = preproc.vocab_size(); - std::cout << "Vocab size: " << vocab_size << std::endl; 
+ int vocab_size = preproc.vocab_size(); + std::cout << "Vocab size: " << vocab_size << std::endl; - RnnLm model(vocab_size); - CategoricalCrossEntropy criterion; + RnnLm model(vocab_size); + CategoricalCrossEntropy criterion; - SGDOptimizer opt(model.params(), learning_rate); + SGDOptimizer opt(model.params(), learning_rate); - auto eval_loop = - [&model, &criterion, kInputIdx, kTargetIdx](LMDataset& dataset) { - AverageValueMeter avg_loss_meter; - Variable output, h, c; - for (auto& example : dataset) { - std::tie(output, h, c) = model(noGrad(example[kInputIdx]), h, c); - auto target = noGrad(example[kTargetIdx]); - auto loss = criterion(output, target); - avg_loss_meter.add(loss.tensor().scalar(), target.elements()); - } - return avg_loss_meter.value()[0]; - }; + auto eval_loop = + [&model, &criterion, kInputIdx, kTargetIdx](LMDataset& dataset) { + AverageValueMeter avg_loss_meter; + Variable output, h, c; + for(auto& example : dataset) { + std::tie(output, h, c) = model(noGrad(example[kInputIdx]), h, c); + auto target = noGrad(example[kTargetIdx]); + auto loss = criterion(output, target); + avg_loss_meter.add(loss.tensor().scalar(), target.elements()); + } + return avg_loss_meter.value()[0]; + }; - for (int e = 0; e < epochs; e++) { - AverageValueMeter train_loss_meter; - TimeMeter timer(true); - timer.resume(); + for(int e = 0; e < epochs; e++) { + AverageValueMeter train_loss_meter; + TimeMeter timer(true); + timer.resume(); - Variable output, h, c; + Variable output, h, c; - if (e >= anneal_after_epoch) { - opt.setLr(opt.getLr() / 2); - } + if(e >= anneal_after_epoch) { + opt.setLr(opt.getLr() / 2); + } - for (auto& example : trainset) { - std::tie(output, h, c) = model(noGrad(example[kInputIdx]), h, c); + for(auto& example : trainset) { + std::tie(output, h, c) = model(noGrad(example[kInputIdx]), h, c); - auto target = noGrad(example[kTargetIdx]); + auto target = noGrad(example[kTargetIdx]); - auto loss = criterion(output, target); - 
train_loss_meter.add(loss.tensor().scalar(), target.elements()); + auto loss = criterion(output, target); + train_loss_meter.add(loss.tensor().scalar(), target.elements()); - opt.zeroGrad(); - loss.backward(); + opt.zeroGrad(); + loss.backward(); - clipGradNorm(model.params(), max_grad_norm); - opt.step(); + clipGradNorm(model.params(), max_grad_norm); + opt.step(); - fl::sync(); - timer.incUnit(); - } + fl::sync(); + timer.incUnit(); + } - double train_loss = train_loss_meter.value()[0]; - double val_loss = eval_loop(valset); - double iter_time = timer.value(); + double train_loss = train_loss_meter.value()[0]; + double val_loss = eval_loop(valset); + double iter_time = timer.value(); - std::cout << "Epoch " << e + 1 << std::setprecision(3) - << " - Train Loss: " << train_loss - << " Validation Loss: " << val_loss - << " Validation Perplexity: " << std::exp(val_loss) - << " Time per iteration (ms): " << iter_time * 1000 << std::endl; - } + std::cout << "Epoch " << e + 1 << std::setprecision(3) + << " - Train Loss: " << train_loss + << " Validation Loss: " << val_loss + << " Validation Perplexity: " << std::exp(val_loss) + << " Time per iteration (ms): " << iter_time * 1000 << std::endl; + } - LMDataset testset(test_dir, batch_size, time_steps, preproc); + LMDataset testset(test_dir, batch_size, time_steps, preproc); - double test_loss = eval_loop(testset); - std::cout << " Test Loss: " << test_loss - << " Test Perplexity: " << std::exp(test_loss) << std::endl; + double test_loss = eval_loop(testset); + std::cout << " Test Loss: " << test_loss + << " Test Perplexity: " << std::exp(test_loss) << std::endl; - return 0; + return 0; } const std::string Preprocessor::eos = ""; Preprocessor::Preprocessor(std::string dataset_path) { - std::ifstream file(dataset_path); - if (!file.is_open()) { - throw std::runtime_error("[Preprocessor::Preprocessor] Can't find file."); - } - int v = 0; - std::string word; - while (file >> word) { - if (word_to_int.find(word) == 
word_to_int.end()) { - word_to_int[word] = v++; + std::ifstream file(dataset_path); + if(!file.is_open()) { + throw std::runtime_error("[Preprocessor::Preprocessor] Can't find file."); + } + int v = 0; + std::string word; + while(file >> word) { + if(word_to_int.find(word) == word_to_int.end()) { + word_to_int[word] = v++; + } } - } - word_to_int[eos] = v; + word_to_int[eos] = v; } LMDataset::LMDataset( std::string dataset_path, int batch_size, int time_steps, - Preprocessor& preproc) - : time_steps(time_steps) { - std::vector words; - std::ifstream file(dataset_path); - if (!file.is_open()) { - throw std::runtime_error("[LMDataset::LMDataset] Can't find file."); - } - - std::string line; - while (std::getline(file, line)) { - std::istringstream ss(line); - std::string word; - while (ss >> word) { - words.push_back(preproc.to_int(word)); + Preprocessor& preproc +) : time_steps(time_steps) { + std::vector words; + std::ifstream file(dataset_path); + if(!file.is_open()) { + throw std::runtime_error("[LMDataset::LMDataset] Can't find file."); + } + + std::string line; + while(std::getline(file, line)) { + std::istringstream ss(line); + std::string word; + while(ss >> word) { + words.push_back(preproc.to_int(word)); + } + words.push_back(preproc.to_int(Preprocessor::eos)); } - words.push_back(preproc.to_int(Preprocessor::eos)); - } - int words_per_batch = words.size() / batch_size; - words.resize(batch_size * words_per_batch); + int words_per_batch = words.size() / batch_size; + words.resize(batch_size * words_per_batch); - data = transpose(Tensor::fromBuffer( - {words_per_batch, batch_size}, words.data(), MemoryLocation::Host)); + data = transpose( + Tensor::fromBuffer( + {words_per_batch, batch_size}, + words.data(), + MemoryLocation::Host + ) + ); } std::vector LMDataset::get(const int64_t idx) const { - int start = idx * time_steps; - int end = (idx + 1) * time_steps; - return { - data(fl::span, fl::range(start, end)), - data(fl::span, fl::range(start, end))}; + 
int start = idx * time_steps; + int end = (idx + 1) * time_steps; + return { + data(fl::span, fl::range(start, end)), + data(fl::span, fl::range(start, end))}; } diff --git a/flashlight/fl/examples/Xor.cpp b/flashlight/fl/examples/Xor.cpp index 0f3dc62..7a80461 100644 --- a/flashlight/fl/examples/Xor.cpp +++ b/flashlight/fl/examples/Xor.cpp @@ -18,89 +18,95 @@ using namespace fl; int main(int argc, const char** argv) { - if (argc != 2) { - std::cerr << "usage: " << argv[0] << " [--adam | --rmsprop]\n"; - return 1; - } - fl::init(); - - int optim_mode = 0; - std::string optimizer_arg = std::string(argv[1]); - if (optimizer_arg == "--adam") { - optim_mode = 1; - } else if (optimizer_arg == "--rmsprop") { - optim_mode = 2; - } - - const int inputSize = 2; - const int outputSize = 1; - const float lr = 0.01; - const float mu = 0.1; - const int numSamples = 4; - - std::array hInput = {1, 1, 0, 0, 1, 0, 0, 1}; - std::array hOutput = {1, 0, 1, 1}; - - auto in = Tensor::fromBuffer( - {inputSize, numSamples}, hInput.data(), MemoryLocation::Host); - auto out = Tensor::fromBuffer( - {outputSize, numSamples}, hOutput.data(), MemoryLocation::Host); - - Sequential model; - - model.add(Linear(inputSize, outputSize)); - model.add(Sigmoid()); - - auto loss = MeanSquaredError(); - - std::unique_ptr optim; - - if (optimizer_arg == "--rmsprop") { - optim = std::make_unique(model.params(), lr); - } else if (optimizer_arg == "--adam") { - optim = std::make_unique(model.params(), lr); - } else { - optim = std::make_unique(model.params(), lr, mu); - } - - Variable result, l; - for (int i = 0; i < 1000; i++) { - for (int j = 0; j < numSamples; j++) { - model.train(); - optim->zeroGrad(); - - Tensor in_j = in(fl::span, j); - Tensor out_j = out(fl::span, j); - - // Forward propagation - result = model(input(in_j)); - - // Calculate loss - l = loss(result, noGrad(out_j)); - - // Backward propagation - l.backward(); - - // Update parameters - optim->step(); + if(argc != 2) { + std::cerr << 
"usage: " << argv[0] << " [--adam | --rmsprop]\n"; + return 1; } + fl::init(); + + int optim_mode = 0; + std::string optimizer_arg = std::string(argv[1]); + if(optimizer_arg == "--adam") { + optim_mode = 1; + } else if(optimizer_arg == "--rmsprop") { + optim_mode = 2; + } + + const int inputSize = 2; + const int outputSize = 1; + const float lr = 0.01; + const float mu = 0.1; + const int numSamples = 4; + + std::array hInput = {1, 1, 0, 0, 1, 0, 0, 1}; + std::array hOutput = {1, 0, 1, 1}; + + auto in = Tensor::fromBuffer( + {inputSize, numSamples}, + hInput.data(), + MemoryLocation::Host + ); + auto out = Tensor::fromBuffer( + {outputSize, numSamples}, + hOutput.data(), + MemoryLocation::Host + ); + + Sequential model; + + model.add(Linear(inputSize, outputSize)); + model.add(Sigmoid()); + + auto loss = MeanSquaredError(); + + std::unique_ptr optim; + + if(optimizer_arg == "--rmsprop") { + optim = std::make_unique(model.params(), lr); + } else if(optimizer_arg == "--adam") { + optim = std::make_unique(model.params(), lr); + } else { + optim = std::make_unique(model.params(), lr, mu); + } + + Variable result, l; + for(int i = 0; i < 1000; i++) { + for(int j = 0; j < numSamples; j++) { + model.train(); + optim->zeroGrad(); + + Tensor in_j = in(fl::span, j); + Tensor out_j = out(fl::span, j); + + // Forward propagation + result = model(input(in_j)); + + // Calculate loss + l = loss(result, noGrad(out_j)); + + // Backward propagation + l.backward(); + + // Update parameters + optim->step(); + } + + if((i + 1) % 100 == 0) { + model.eval(); + + // Forward propagation + result = model(input(in)); - if ((i + 1) % 100 == 0) { - model.eval(); - - // Forward propagation - result = model(input(in)); - - // Calculate loss - // TODO: Use loss function - Tensor diff = out - result.tensor(); - std::cout << "Average Error at iteration (" << i + 1 - << ") : " << fl::mean(fl::abs(diff)).scalar() << "\n"; - std::cout << "Predicted\n" - << result.tensor() << std::endl - << "Expected\n" 
- << out << std::endl; + // Calculate loss + // TODO: Use loss function + Tensor diff = out - result.tensor(); + std::cout << "Average Error at iteration (" << i + 1 + << ") : " << fl::mean(fl::abs(diff)).scalar() << "\n"; + std::cout << "Predicted\n" + << result.tensor() << std::endl + << "Expected\n" + << out << std::endl; + } } - } - return 0; + return 0; } diff --git a/flashlight/fl/meter/AverageValueMeter.cpp b/flashlight/fl/meter/AverageValueMeter.cpp index 885dfea..42d967a 100644 --- a/flashlight/fl/meter/AverageValueMeter.cpp +++ b/flashlight/fl/meter/AverageValueMeter.cpp @@ -12,49 +12,49 @@ namespace fl { AverageValueMeter::AverageValueMeter() { - reset(); + reset(); } void AverageValueMeter::reset() { - curMean_ = 0; - curMeanSquaredSum_ = 0; - curWeightSum_ = 0; - curWeightSquaredSum_ = 0; + curMean_ = 0; + curMeanSquaredSum_ = 0; + curWeightSum_ = 0; + curWeightSquaredSum_ = 0; } void AverageValueMeter::add(const double val, const double w /* = 1.0 */) { - curWeightSum_ += w; - curWeightSquaredSum_ += w * w; + curWeightSum_ += w; + curWeightSquaredSum_ += w * w; - if (curWeightSum_ == 0) { - return; - } + if(curWeightSum_ == 0) { + return; + } - curMean_ = curMean_ + w * (val - curMean_) / curWeightSum_; - curMeanSquaredSum_ = - curMeanSquaredSum_ + w * (val * val - curMeanSquaredSum_) / curWeightSum_; + curMean_ = curMean_ + w * (val - curMean_) / curWeightSum_; + curMeanSquaredSum_ = + curMeanSquaredSum_ + w * (val * val - curMeanSquaredSum_) / curWeightSum_; } void AverageValueMeter::add(const Tensor& vals) { - double w = vals.elements(); - curWeightSum_ += w; - curWeightSquaredSum_ += w; - - if (curWeightSum_ == 0) { - return; - } - - curMean_ = curMean_ + - (fl::sum(vals).asScalar() - w * curMean_) / curWeightSum_; - curMeanSquaredSum_ = curMeanSquaredSum_ + - (fl::sum(vals * vals).asScalar() - w * curMeanSquaredSum_) / - curWeightSum_; + double w = vals.elements(); + curWeightSum_ += w; + curWeightSquaredSum_ += w; + + if(curWeightSum_ == 0) { + 
return; + } + + curMean_ = curMean_ + + (fl::sum(vals).asScalar() - w * curMean_) / curWeightSum_; + curMeanSquaredSum_ = curMeanSquaredSum_ + + (fl::sum(vals * vals).asScalar() - w * curMeanSquaredSum_) + / curWeightSum_; } std::vector AverageValueMeter::value() const { - double mean = curMean_; - double var = (curMeanSquaredSum_ - curMean_ * curMean_) / - (1 - curWeightSquaredSum_ / (curWeightSum_ * curWeightSum_)); - return {mean, var, curWeightSum_}; + double mean = curMean_; + double var = (curMeanSquaredSum_ - curMean_ * curMean_) + / (1 - curWeightSquaredSum_ / (curWeightSum_ * curWeightSum_)); + return {mean, var, curWeightSum_}; } } // namespace fl diff --git a/flashlight/fl/meter/AverageValueMeter.h b/flashlight/fl/meter/AverageValueMeter.h index 25ac866..31a6970 100644 --- a/flashlight/fl/meter/AverageValueMeter.h +++ b/flashlight/fl/meter/AverageValueMeter.h @@ -41,32 +41,32 @@ class Tensor; * \endcode */ class FL_API AverageValueMeter { - public: - /** Constructor of `AverageValueMeter`. */ - AverageValueMeter(); +public: + /** Constructor of `AverageValueMeter`. */ + AverageValueMeter(); - /** Updates counters with the given value `val` with weight `w`. */ - void add(const double val, const double w = 1.0); + /** Updates counters with the given value `val` with weight `w`. */ + void add(const double val, const double w = 1.0); - /** Updates counters with all values in `vals` with equal weights. */ - void add(const Tensor& vals); + /** Updates counters with all values in `vals` with equal weights. 
*/ + void add(const Tensor& vals); - /** Returns a vector of four values: - * - `unbiased mean`: \f$ \tilde{mu} \f$ - * - `unbiased variance`: \f$ \tilde{sigma}^2 = \frac{(\tilde{mu}_2 - - * \tilde{mu}^2)}{1 - Sum(P^2)} \f$ - * - `weight_sum`: \f$ Sum(W) \f$ - * - `weight_squared_sum`: \f$ Sum(W^2) \f$ - */ - std::vector value() const; + /** Returns a vector of four values: + * - `unbiased mean`: \f$ \tilde{mu} \f$ + * - `unbiased variance`: \f$ \tilde{sigma}^2 = \frac{(\tilde{mu}_2 - + * \tilde{mu}^2)}{1 - Sum(P^2)} \f$ + * - `weight_sum`: \f$ Sum(W) \f$ + * - `weight_squared_sum`: \f$ Sum(W^2) \f$ + */ + std::vector value() const; - /** Sets all the counters to 0. */ - void reset(); + /** Sets all the counters to 0. */ + void reset(); - private: - double curMean_; - double curMeanSquaredSum_; - double curWeightSum_; - double curWeightSquaredSum_; +private: + double curMean_; + double curMeanSquaredSum_; + double curWeightSum_; + double curWeightSquaredSum_; }; } // namespace fl diff --git a/flashlight/fl/meter/CountMeter.cpp b/flashlight/fl/meter/CountMeter.cpp index ba54c15..17a97ce 100644 --- a/flashlight/fl/meter/CountMeter.cpp +++ b/flashlight/fl/meter/CountMeter.cpp @@ -15,18 +15,18 @@ namespace fl { CountMeter::CountMeter(int num) : counts_(num, 0) {} void CountMeter::add(int id, int64_t val) { - if (!(id >= 0 && id < counts_.size())) { - throw std::out_of_range("invalid id to update count for"); - } - counts_[id] += val; + if(!(id >= 0 && id < counts_.size())) { + throw std::out_of_range("invalid id to update count for"); + } + counts_[id] += val; } std::vector CountMeter::value() const { - return counts_; + return counts_; } void CountMeter::reset() { - std::fill(counts_.begin(), counts_.end(), 0); + std::fill(counts_.begin(), counts_.end(), 0); } } // namespace fl diff --git a/flashlight/fl/meter/CountMeter.h b/flashlight/fl/meter/CountMeter.h index 4e0c7dc..58b094d 100644 --- a/flashlight/fl/meter/CountMeter.h +++ b/flashlight/fl/meter/CountMeter.h @@ 
-20,35 +20,35 @@ namespace fl { * Example usage: * * \code - CountMeter meter(10); // 10 categories in total - meter.add(4, 6); // add 6 count to category 4 - meter.add(7, 2); // add 2 count to category 7 - meter.add(4, -1); // add -1 count to category 4 - - auto counts = meter.value(); - std::cout << counts[4]; // prints 5 - \endcode + CountMeter meter(10); // 10 categories in total + meter.add(4, 6); // add 6 count to category 4 + meter.add(7, 2); // add 2 count to category 7 + meter.add(4, -1); // add -1 count to category 4 + + auto counts = meter.value(); + std::cout << counts[4]; // prints 5 + \endcode */ class FL_API CountMeter { - public: - /** Constructor of `CountMeter`. `num` specifies the total number of - * categories. - */ - explicit CountMeter(int num); - - /** Adds value `val` to category `id`. Note that `id` should be in range [0, - * `num` - 1].*/ - void add(int id, int64_t val); - - /** Returns a vector of `num` values, representing the total value of each - * category. - */ - std::vector value() const; - - /** Sets the value of each category to 0. */ - void reset(); - - private: - std::vector counts_; +public: + /** Constructor of `CountMeter`. `num` specifies the total number of + * categories. + */ + explicit CountMeter(int num); + + /** Adds value `val` to category `id`. Note that `id` should be in range [0, + * `num` - 1].*/ + void add(int id, int64_t val); + + /** Returns a vector of `num` values, representing the total value of each + * category. + */ + std::vector value() const; + + /** Sets the value of each category to 0. 
*/ + void reset(); + +private: + std::vector counts_; }; } // namespace fl diff --git a/flashlight/fl/meter/EditDistanceMeter.cpp b/flashlight/fl/meter/EditDistanceMeter.cpp index fba8a21..e93f9b4 100644 --- a/flashlight/fl/meter/EditDistanceMeter.cpp +++ b/flashlight/fl/meter/EditDistanceMeter.cpp @@ -14,67 +14,70 @@ namespace fl { EditDistanceMeter::EditDistanceMeter() { - reset(); + reset(); } void EditDistanceMeter::reset() { - n_ = 0; - ndel_ = 0; - nins_ = 0; - nsub_ = 0; + n_ = 0; + ndel_ = 0; + nins_ = 0; + nsub_ = 0; } void EditDistanceMeter::add(const Tensor& output, const Tensor& target) { - if (target.ndim() != 1) { - throw std::invalid_argument( - "target must be 1-dimensional for EditDistanceMeter"); - } - if (output.ndim() != 1) { - throw std::invalid_argument( - "output must be 1-dimensional for EditDistanceMeter"); - } - int len1 = output.dim(0); - int len2 = target.dim(0); + if(target.ndim() != 1) { + throw std::invalid_argument( + "target must be 1-dimensional for EditDistanceMeter" + ); + } + if(output.ndim() != 1) { + throw std::invalid_argument( + "output must be 1-dimensional for EditDistanceMeter" + ); + } + int len1 = output.dim(0); + int len2 = target.dim(0); - int* in1raw = output.host(); - int* in2raw = target.host(); - auto err_state = levensteinDistance(in1raw, in2raw, len1, len2); - free(in1raw); - in1raw = nullptr; - free(in2raw); - in2raw = nullptr; - add(err_state, target.dim(0)); + int* in1raw = output.host(); + int* in2raw = target.host(); + auto err_state = levensteinDistance(in1raw, in2raw, len1, len2); + free(in1raw); + in1raw = nullptr; + free(in2raw); + in2raw = nullptr; + add(err_state, target.dim(0)); } void EditDistanceMeter::add( const int64_t n, const int64_t ndel, const int64_t nins, - const int64_t nsub) { - n_ += n; - ndel_ += ndel; - nins_ += nins; - nsub_ += nsub; + const int64_t nsub +) { + n_ += n; + ndel_ += ndel; + nins_ += nins; + nsub_ += nsub; } std::vector EditDistanceMeter::value() const { - return 
{sumErr(), n_, ndel_, nins_, nsub_}; + return {sumErr(), n_, ndel_, nins_, nsub_}; } std::vector EditDistanceMeter::errorRate() const { - double val, valDel, valIns, valSub; - if (n_ > 0) { - val = static_cast(sumErr() * 100.0) / n_; - valDel = static_cast(ndel_ * 100.0) / n_; - valIns = static_cast(nins_ * 100.0) / n_; - valSub = static_cast(nsub_ * 100.0) / n_; - } else { - val = (sumErr() > 0) ? std::numeric_limits::infinity() : 0.0; - valDel = (ndel_ > 0) ? std::numeric_limits::infinity() : 0.0; - valIns = (nins_ > 0) ? std::numeric_limits::infinity() : 0.0; - valSub = (nsub_ > 0) ? std::numeric_limits::infinity() : 0.0; - } - return {val, static_cast(n_), valDel, valIns, valSub}; + double val, valDel, valIns, valSub; + if(n_ > 0) { + val = static_cast(sumErr() * 100.0) / n_; + valDel = static_cast(ndel_ * 100.0) / n_; + valIns = static_cast(nins_ * 100.0) / n_; + valSub = static_cast(nsub_ * 100.0) / n_; + } else { + val = (sumErr() > 0) ? std::numeric_limits::infinity() : 0.0; + valDel = (ndel_ > 0) ? std::numeric_limits::infinity() : 0.0; + valIns = (nins_ > 0) ? std::numeric_limits::infinity() : 0.0; + valSub = (nsub_ > 0) ? std::numeric_limits::infinity() : 0.0; + } + return {val, static_cast(n_), valDel, valIns, valSub}; } } // namespace fl diff --git a/flashlight/fl/meter/EditDistanceMeter.h b/flashlight/fl/meter/EditDistanceMeter.h index 644c927..f98620c 100644 --- a/flashlight/fl/meter/EditDistanceMeter.h +++ b/flashlight/fl/meter/EditDistanceMeter.h @@ -32,142 +32,145 @@ class Tensor; * \endcode */ class FL_API EditDistanceMeter { - public: - /** A structure storing number of different type of errors when computing edit - * distance. */ - struct ErrorState { - int64_t ndel; //!< Number of deletion error - int64_t nins; //!< Number of insertion error - int64_t nsub; //!< Number of substitution error - ErrorState() : ndel(0), nins(0), nsub(0) {} - - /** Sums up all the errors. 
*/ - int64_t sum() const { - return ndel + nins + nsub; +public: + /** A structure storing number of different type of errors when computing edit + * distance. */ + struct ErrorState { + int64_t ndel; // !< Number of deletion error + int64_t nins; // !< Number of insertion error + int64_t nsub; // !< Number of substitution error + ErrorState() : ndel(0), nins(0), nsub(0) {} + + /** Sums up all the errors. */ + int64_t sum() const { + return ndel + nins + nsub; + } + }; + + /** Constructor of `EditDistanceMeter`. An instance will maintain five + * counters initialized to 0: + * - `n`: total target lengths + * - `ndel`: total deletion error + * - `nins`: total insertion error + * - `nsub`: total substitution error + */ + EditDistanceMeter(); + + /** Computes edit distance between two arrayfire arrays `output` and `target` + * and updates the counters. + */ + void add(const Tensor& output, const Tensor& target); + + /** Updates all the counters with inputs sharing the same meaning. */ + void add( + const int64_t n, + const int64_t ndel, + const int64_t nins, + const int64_t nsub + ); + + /** Updates all the counters with an `ErrorState`. */ + void add(const ErrorState& es, const int64_t n) { + add(n, es.ndel, es.nins, es.nsub); } - }; - - /** Constructor of `EditDistanceMeter`. An instance will maintain five - * counters initialized to 0: - * - `n`: total target lengths - * - `ndel`: total deletion error - * - `nins`: total insertion error - * - `nsub`: total substitution error - */ - EditDistanceMeter(); - - /** Computes edit distance between two arrayfire arrays `output` and `target` - * and updates the counters. - */ - void add(const Tensor& output, const Tensor& target); - - /** Updates all the counters with inputs sharing the same meaning. */ - void add( - const int64_t n, - const int64_t ndel, - const int64_t nins, - const int64_t nsub); - - /** Updates all the counters with an `ErrorState`. 
*/ - void add(const ErrorState& es, const int64_t n) { - add(n, es.ndel, es.nins, es.nsub); - } - - /** Returns a vector of five values: - * - `error rate`: \f$ \frac{(ndel + nins + nsub)}{n} \times 100.0 \f$ - * - `total length`: \f$ n \f$ - * - `deletion rate`: \f$ \frac{ndel}{n} \times 100.0\f$ - * - `insertion rate`: \f$ \frac{nins}{n} \times 100.0 \f$ - * - `substitution rate`: \f$ \frac{nsub}{n} \times 100.0 \f$ - */ - std::vector errorRate() const; - - /** Returns a vector of five values: - * - `edit distance`: \f$ (ndel + nins + nsub)\f$ - * - `total length`: \f$ n \f$ - * - `number of deletions`: \f$ ndel \f$ - * - `number of insertions`: \f$ nins \f$ - * - `number of substitution`: \f$ nsub \f$ - */ - std::vector value() const; - - /** Computes edit distance between two arrays `output` and `target`, with - * length `olen` and `tlen` respectively, and updates the counters. - */ - template - void - add(const T& output, const S& target, const size_t olen, const size_t tlen) { - auto err_state = levensteinDistance(output, target, olen, tlen); - add(err_state, tlen); - } - - /** Computes edit distance between two vectors `output` and `target` - * and updates the counters. - */ - template - void add(const std::vector& output, const std::vector& target) { - add(output.data(), target.data(), output.size(), target.size()); - } - - /** Sets all the counters to 0. 
*/ - void reset(); - - private: - int64_t n_; - int64_t ndel_; - int64_t nins_; - int64_t nsub_; - - int64_t sumErr() const { - return ndel_ + nins_ + nsub_; - } - - template - ErrorState levensteinDistance( - const T& in1begin, - const T& in2begin, - size_t len1, - size_t len2) const { - std::vector column(len1 + 1); - for (int i = 0; i <= len1; ++i) { - column[i].nins = i; + + /** Returns a vector of five values: + * - `error rate`: \f$ \frac{(ndel + nins + nsub)}{n} \times 100.0 \f$ + * - `total length`: \f$ n \f$ + * - `deletion rate`: \f$ \frac{ndel}{n} \times 100.0\f$ + * - `insertion rate`: \f$ \frac{nins}{n} \times 100.0 \f$ + * - `substitution rate`: \f$ \frac{nsub}{n} \times 100.0 \f$ + */ + std::vector errorRate() const; + + /** Returns a vector of five values: + * - `edit distance`: \f$ (ndel + nins + nsub)\f$ + * - `total length`: \f$ n \f$ + * - `number of deletions`: \f$ ndel \f$ + * - `number of insertions`: \f$ nins \f$ + * - `number of substitution`: \f$ nsub \f$ + */ + std::vector value() const; + + /** Computes edit distance between two arrays `output` and `target`, with + * length `olen` and `tlen` respectively, and updates the counters. + */ + template + void add(const T& output, const S& target, const size_t olen, const size_t tlen) { + auto err_state = levensteinDistance(output, target, olen, tlen); + add(err_state, tlen); } - auto curin2 = in2begin; - for (int x = 1; x <= len2; x++) { - ErrorState lastdiagonal = column[0]; - column[0].ndel = x; - auto curin1 = in1begin; - for (int y = 1; y <= len1; y++) { - auto olddiagonal = column[y]; - auto possibilities = { - column[y].sum() + 1, - column[y - 1].sum() + 1, - lastdiagonal.sum() + ((*curin1 == *curin2) ? 
0 : 1)}; - auto min_it = - std::min_element(possibilities.begin(), possibilities.end()); - if (std::distance(possibilities.begin(), min_it) == - 0) { // deletion error - ++column[y].ndel; - } else if ( - std::distance(possibilities.begin(), min_it) == 1) { // insertion - // error - column[y] = column[y - 1]; - ++column[y].nins; - } else { - column[y] = lastdiagonal; - if (*curin1 != *curin2) { // substitution error - ++column[y].nsub; - } - } + /** Computes edit distance between two vectors `output` and `target` + * and updates the counters. + */ + template + void add(const std::vector& output, const std::vector& target) { + add(output.data(), target.data(), output.size(), target.size()); + } + + /** Sets all the counters to 0. */ + void reset(); + +private: + int64_t n_; + int64_t ndel_; + int64_t nins_; + int64_t nsub_; - lastdiagonal = olddiagonal; - ++curin1; - } - ++curin2; + int64_t sumErr() const { + return ndel_ + nins_ + nsub_; } - return column[len1]; - } + template + ErrorState levensteinDistance( + const T& in1begin, + const T& in2begin, + size_t len1, + size_t len2 + ) const { + std::vector column(len1 + 1); + for(int i = 0; i <= len1; ++i) { + column[i].nins = i; + } + + auto curin2 = in2begin; + for(int x = 1; x <= len2; x++) { + ErrorState lastdiagonal = column[0]; + column[0].ndel = x; + auto curin1 = in1begin; + for(int y = 1; y <= len1; y++) { + auto olddiagonal = column[y]; + auto possibilities = { + column[y].sum() + 1, + column[y - 1].sum() + 1, + lastdiagonal.sum() + ((*curin1 == *curin2) ? 
0 : 1)}; + auto min_it = + std::min_element(possibilities.begin(), possibilities.end()); + if( + std::distance(possibilities.begin(), min_it) + == 0 + ) { // deletion error + ++column[y].ndel; + } else if( + std::distance(possibilities.begin(), min_it) == 1) { // insertion + // error + column[y] = column[y - 1]; + ++column[y].nins; + } else { + column[y] = lastdiagonal; + if(*curin1 != *curin2) { // substitution error + ++column[y].nsub; + } + } + + lastdiagonal = olddiagonal; + ++curin1; + } + ++curin2; + } + + return column[len1]; + } }; } // namespace fl diff --git a/flashlight/fl/meter/FrameErrorMeter.cpp b/flashlight/fl/meter/FrameErrorMeter.cpp index e128df7..6262e37 100644 --- a/flashlight/fl/meter/FrameErrorMeter.cpp +++ b/flashlight/fl/meter/FrameErrorMeter.cpp @@ -12,32 +12,32 @@ #include "flashlight/fl/tensor/TensorBase.h" namespace fl { -FrameErrorMeter::FrameErrorMeter(bool accuracy /* = false */) - : accuracy_(accuracy) { - reset(); +FrameErrorMeter::FrameErrorMeter(bool accuracy /* = false */) : accuracy_(accuracy) { + reset(); } void FrameErrorMeter::reset() { - n_ = 0; - sum_ = 0; + n_ = 0; + sum_ = 0; } void FrameErrorMeter::add(const Tensor& output, const Tensor& target) { - if (output.shape() != target.shape()) { - throw std::invalid_argument("dimension mismatch in FrameErrorMeter"); - } - if (target.ndim() != 1) { - throw std::invalid_argument( - "output/target must be 1-dimensional for FrameErrorMeter"); - } - - sum_ += fl::countNonzero(output != target).scalar(); - n_ += target.dim(0); + if(output.shape() != target.shape()) { + throw std::invalid_argument("dimension mismatch in FrameErrorMeter"); + } + if(target.ndim() != 1) { + throw std::invalid_argument( + "output/target must be 1-dimensional for FrameErrorMeter" + ); + } + + sum_ += fl::countNonzero(output != target).scalar(); + n_ += target.dim(0); } double FrameErrorMeter::value() const { - double error = (n_ > 0) ? (static_cast(sum_ * 100.0) / n_) : 0.0; - double val = (accuracy_ ? 
(100.0 - error) : error); - return val; + double error = (n_ > 0) ? (static_cast(sum_ * 100.0) / n_) : 0.0; + double val = (accuracy_ ? (100.0 - error) : error); + return val; } } // namespace fl diff --git a/flashlight/fl/meter/FrameErrorMeter.h b/flashlight/fl/meter/FrameErrorMeter.h index 783e347..39ec430 100644 --- a/flashlight/fl/meter/FrameErrorMeter.h +++ b/flashlight/fl/meter/FrameErrorMeter.h @@ -30,32 +30,32 @@ class Tensor; * \endcode */ class FL_API FrameErrorMeter { - public: - /** Constructor of `FrameErrorMeter`. Flag `accuracy` indicates if the meter - * computes and returns accuracy or error rate instead. An instance will - * maintain two counters initialized to 0: - * - `n`: total samples - * - `sum`: total mismatches - */ - explicit FrameErrorMeter(bool accuracy = false); - - /** Computes frame-level mismatch between two arrayfire arrays `output` and - * `target` and updates the counters. Note that the shape of the two input - * arrays should be identical. - */ - void add(const Tensor& output, const Tensor& target); - - /** Returns a single value in percentage. If `accuracy` is `True`, the value - * returned is accuracy, error otherwise. - */ - double value() const; - - /** Sets all the counters to 0. */ - void reset(); - - private: - std::int64_t n_; - std::int64_t sum_; - bool accuracy_; +public: + /** Constructor of `FrameErrorMeter`. Flag `accuracy` indicates if the meter + * computes and returns accuracy or error rate instead. An instance will + * maintain two counters initialized to 0: + * - `n`: total samples + * - `sum`: total mismatches + */ + explicit FrameErrorMeter(bool accuracy = false); + + /** Computes frame-level mismatch between two arrayfire arrays `output` and + * `target` and updates the counters. Note that the shape of the two input + * arrays should be identical. + */ + void add(const Tensor& output, const Tensor& target); + + /** Returns a single value in percentage. 
If `accuracy` is `True`, the value + * returned is accuracy, error otherwise. + */ + double value() const; + + /** Sets all the counters to 0. */ + void reset(); + +private: + std::int64_t n_; + std::int64_t sum_; + bool accuracy_; }; } // namespace fl diff --git a/flashlight/fl/meter/MSEMeter.cpp b/flashlight/fl/meter/MSEMeter.cpp index a9e0782..0998406 100644 --- a/flashlight/fl/meter/MSEMeter.cpp +++ b/flashlight/fl/meter/MSEMeter.cpp @@ -13,26 +13,26 @@ namespace fl { MSEMeter::MSEMeter() { - reset(); + reset(); } void MSEMeter::reset() { - curN_ = 0; - curValue_ = .0; + curN_ = 0; + curValue_ = .0; } void MSEMeter::add(const Tensor& output, const Tensor& target) { - if (output.ndim() != target.ndim()) { - throw std::invalid_argument("dimension mismatch in MSEMeter"); - } - ++curN_; - curValue_ = - (curValue_ * (curN_ - 1) + - fl::sum((output - target) * (output - target)).asScalar()) / - curN_; + if(output.ndim() != target.ndim()) { + throw std::invalid_argument("dimension mismatch in MSEMeter"); + } + ++curN_; + curValue_ = + (curValue_ * (curN_ - 1) + + fl::sum((output - target) * (output - target)).asScalar()) + / curN_; } double MSEMeter::value() const { - return curValue_; + return curValue_; } } // namespace fl diff --git a/flashlight/fl/meter/MSEMeter.h b/flashlight/fl/meter/MSEMeter.h index aa30d99..2b77fdf 100644 --- a/flashlight/fl/meter/MSEMeter.h +++ b/flashlight/fl/meter/MSEMeter.h @@ -30,28 +30,28 @@ class Tensor; * \endcode */ class FL_API MSEMeter { - public: - /** Constructor of `MSEMeter`. An instance will maintain two - * counters initialized to 0: - * - `n`: total samples - * - `mse`: mean square error of samples - */ - MSEMeter(); - - /** Computes mean square error between two arrayfire arrays `output` and - * `target` and updates the counters. Note that the shape of the two input - * arrays should be identical. - */ - void add(const Tensor& output, const Tensor& target); - - /** Returns a single value of mean square error. 
*/ - double value() const; - - /** Sets all the counters to 0. */ - void reset(); - - private: - double curValue_; - int64_t curN_; +public: + /** Constructor of `MSEMeter`. An instance will maintain two + * counters initialized to 0: + * - `n`: total samples + * - `mse`: mean square error of samples + */ + MSEMeter(); + + /** Computes mean square error between two arrayfire arrays `output` and + * `target` and updates the counters. Note that the shape of the two input + * arrays should be identical. + */ + void add(const Tensor& output, const Tensor& target); + + /** Returns a single value of mean square error. */ + double value() const; + + /** Sets all the counters to 0. */ + void reset(); + +private: + double curValue_; + int64_t curN_; }; } // namespace fl diff --git a/flashlight/fl/meter/TimeMeter.cpp b/flashlight/fl/meter/TimeMeter.cpp index b227b7d..d5bde72 100644 --- a/flashlight/fl/meter/TimeMeter.cpp +++ b/flashlight/fl/meter/TimeMeter.cpp @@ -10,58 +10,58 @@ namespace fl { TimeMeter::TimeMeter(bool unit /* = false */) : useUnit_(unit) { - reset(); + reset(); } void TimeMeter::reset() { - curN_ = 0; - curValue_ = 0.; - isStopped_ = true; + curN_ = 0; + curValue_ = 0.; + isStopped_ = true; } void TimeMeter::set(double val, int64_t num /* = 1 */) { - curValue_ = val; - curN_ = num; - start_ = std::chrono::system_clock::now(); + curValue_ = val; + curN_ = num; + start_ = std::chrono::system_clock::now(); } double TimeMeter::value() const { - double val = curValue_; - if (!isStopped_) { - std::chrono::duration duration = - std::chrono::system_clock::now() - start_; - val += duration.count(); - } - if (useUnit_) { - val = (curN_ > 0) ? (val / curN_) : 0.0; - } - return val; + double val = curValue_; + if(!isStopped_) { + std::chrono::duration duration = + std::chrono::system_clock::now() - start_; + val += duration.count(); + } + if(useUnit_) { + val = (curN_ > 0) ? 
(val / curN_) : 0.0; + } + return val; } void TimeMeter::stop() { - if (isStopped_) { - return; - } - std::chrono::duration duration = - std::chrono::system_clock::now() - start_; - curValue_ += duration.count(); - isStopped_ = true; + if(isStopped_) { + return; + } + std::chrono::duration duration = + std::chrono::system_clock::now() - start_; + curValue_ += duration.count(); + isStopped_ = true; } void TimeMeter::resume() { - if (!isStopped_) { - return; - } - start_ = std::chrono::system_clock::now(); - isStopped_ = false; + if(!isStopped_) { + return; + } + start_ = std::chrono::system_clock::now(); + isStopped_ = false; } void TimeMeter::incUnit(int64_t num) { - curN_ += num; + curN_ += num; } void TimeMeter::stopAndIncUnit(int64_t num) { - stop(); - incUnit(num); + stop(); + incUnit(num); } } // namespace fl diff --git a/flashlight/fl/meter/TimeMeter.h b/flashlight/fl/meter/TimeMeter.h index d7aa3f7..f5dfb8a 100644 --- a/flashlight/fl/meter/TimeMeter.h +++ b/flashlight/fl/meter/TimeMeter.h @@ -26,42 +26,42 @@ namespace fl { * \endcode */ class FL_API TimeMeter { - public: - /** Constructor of `TimeMeter`. An instance will maintain a timer which is - * initialized as stopped. The flag `unit` indicates if there is multiple - * units running in sequential in the current timing period. - */ - explicit TimeMeter(bool unit = false); +public: + /** Constructor of `TimeMeter`. An instance will maintain a timer which is + * initialized as stopped. The flag `unit` indicates if there is multiple + * units running in sequential in the current timing period. + */ + explicit TimeMeter(bool unit = false); - /** Stops the timer if still running. If `unit` is `True`, returns the average - * time spend per unit, otherwise the total time in the current timing period. - * Time is measured in seconds. - */ - double value() const; + /** Stops the timer if still running. 
If `unit` is `True`, returns the average + * time spend per unit, otherwise the total time in the current timing period. + * Time is measured in seconds. + */ + double value() const; - /** Refreshes the counters and stops the timer. */ - void reset(); + /** Refreshes the counters and stops the timer. */ + void reset(); - /** Increases the number of units by `num`. */ - void incUnit(int64_t num = 1); + /** Increases the number of units by `num`. */ + void incUnit(int64_t num = 1); - /** Starts the timer. */ - void resume(); + /** Starts the timer. */ + void resume(); - /** Stops the timer. */ - void stop(); + /** Stops the timer. */ + void stop(); - /** Sets the number of units by `num` and the total time spend by `val`. */ - void set(double val, int64_t num = 1); + /** Sets the number of units by `num` and the total time spend by `val`. */ + void set(double val, int64_t num = 1); - /** Stops the timer and increase the number of units by `num`. */ - void stopAndIncUnit(int64_t num = 1); + /** Stops the timer and increase the number of units by `num`. 
*/ + void stopAndIncUnit(int64_t num = 1); - private: - std::chrono::time_point start_; - double curValue_; - int64_t curN_; - bool isStopped_; - bool useUnit_; +private: + std::chrono::time_point start_; + double curValue_; + int64_t curN_; + bool isStopped_; + bool useUnit_; }; } // namespace fl diff --git a/flashlight/fl/meter/TopKMeter.cpp b/flashlight/fl/meter/TopKMeter.cpp index b7c617a..b24dce7 100644 --- a/flashlight/fl/meter/TopKMeter.cpp +++ b/flashlight/fl/meter/TopKMeter.cpp @@ -12,43 +12,46 @@ namespace fl { -TopKMeter::TopKMeter(const int k) : k_(k), correct_(0), n_(0){}; +TopKMeter::TopKMeter(const int k) : k_(k), + correct_(0), + n_(0) {}; void TopKMeter::add(const Tensor& output, const Tensor& target) { - if (output.dim(1) != target.dim(0)) { - throw std::invalid_argument("dimension mismatch in TopKMeter"); - } - if (target.ndim() != 1) { - throw std::invalid_argument( - "output/target must be 1-dimensional for TopKMeter"); - } - - Tensor maxVals, maxIds, match; - topk(maxVals, maxIds, output, k_, 0); - match = maxIds == fl::reshape(target, {1, target.dim(0), 1, 1}); - const Tensor correct = fl::any(match, {0}); - - correct_ += fl::countNonzero(correct).asScalar(); - const int batchsize = target.dim(0); - n_ += batchsize; + if(output.dim(1) != target.dim(0)) { + throw std::invalid_argument("dimension mismatch in TopKMeter"); + } + if(target.ndim() != 1) { + throw std::invalid_argument( + "output/target must be 1-dimensional for TopKMeter" + ); + } + + Tensor maxVals, maxIds, match; + topk(maxVals, maxIds, output, k_, 0); + match = maxIds == fl::reshape(target, {1, target.dim(0), 1, 1}); + const Tensor correct = fl::any(match, {0}); + + correct_ += fl::countNonzero(correct).asScalar(); + const int batchsize = target.dim(0); + n_ += batchsize; } void TopKMeter::reset() { - correct_ = 0; - n_ = 0; + correct_ = 0; + n_ = 0; } double TopKMeter::value() const { - return (static_cast(correct_) / n_) * 100.0f; + return (static_cast(correct_) / n_) * 
100.0f; } std::pair TopKMeter::getStats() { - return std::make_pair(correct_, n_); + return std::make_pair(correct_, n_); } void TopKMeter::set(int32_t correct, int32_t n) { - n_ = n; - correct_ = correct; + n_ = n; + correct_ = correct; } } // namespace fl diff --git a/flashlight/fl/meter/TopKMeter.h b/flashlight/fl/meter/TopKMeter.h index 4c70684..b40e888 100644 --- a/flashlight/fl/meter/TopKMeter.h +++ b/flashlight/fl/meter/TopKMeter.h @@ -31,30 +31,30 @@ class Tensor; * \endcode */ class FL_API TopKMeter { - public: - /** Constructor of `TopKMeter`. - * @param k number of top predictions in order to be considered correct - * Will have two counters: - * - `correct`: total number of correct predictions - * - `n`: total number of of predictions - */ - explicit TopKMeter(const int k); +public: + /** Constructor of `TopKMeter`. + * @param k number of top predictions in order to be considered correct + * Will have two counters: + * - `correct`: total number of correct predictions + * - `n`: total number of of predictions + */ + explicit TopKMeter(const int k); - void add(const Tensor& output, const Tensor& target); + void add(const Tensor& output, const Tensor& target); - void reset(); + void reset(); - // Used for distributed syncing - void set(int32_t correct, int32_t n); + // Used for distributed syncing + void set(int32_t correct, int32_t n); - std::pair getStats(); + std::pair getStats(); - double value() const; + double value() const; - private: - int k_; - int32_t correct_; - int32_t n_; +private: + int k_; + int32_t correct_; + int32_t n_; }; } // namespace fl diff --git a/flashlight/fl/nn/DistributedUtils.cpp b/flashlight/fl/nn/DistributedUtils.cpp index 753315f..cea102c 100644 --- a/flashlight/fl/nn/DistributedUtils.cpp +++ b/flashlight/fl/nn/DistributedUtils.cpp @@ -15,31 +15,34 @@ namespace fl { void distributeModuleGrads( std::shared_ptr module, - std::shared_ptr reducer) { - for (auto& param : module->params()) { - 
param.registerGradHook([reducer](Variable& grad) { reducer->add(grad); }); - } + std::shared_ptr reducer +) { + for(auto& param : module->params()) { + param.registerGradHook([reducer](Variable& grad) { reducer->add(grad); }); + } } void allReduceParameters(std::shared_ptr module) { - if (!module) { - throw std::invalid_argument("null module passed to allReduceParameters"); - } - double scale = 1.0 / getWorldSize(); - for (auto& param : module->params()) { - allReduce(param, scale); - } + if(!module) { + throw std::invalid_argument("null module passed to allReduceParameters"); + } + double scale = 1.0 / getWorldSize(); + for(auto& param : module->params()) { + allReduce(param, scale); + } } void allReduceGradients( std::shared_ptr module, - double scale /*= 1.0 */) { - if (!module) { - throw std::invalid_argument("null module passed to allReduceGradients"); - } - for (auto& param : module->params()) { - allReduce(param.grad(), scale); - }; + double scale /*= 1.0 */ +) { + if(!module) { + throw std::invalid_argument("null module passed to allReduceGradients"); + } + for(auto& param : module->params()) { + allReduce(param.grad(), scale); + } + ; } } // namespace fl diff --git a/flashlight/fl/nn/DistributedUtils.h b/flashlight/fl/nn/DistributedUtils.h index 494cacf..91d6e50 100644 --- a/flashlight/fl/nn/DistributedUtils.h +++ b/flashlight/fl/nn/DistributedUtils.h @@ -31,7 +31,8 @@ namespace fl { FL_API void distributeModuleGrads( std::shared_ptr module, std::shared_ptr reducer = - std::make_shared(1.0 / getWorldSize())); + std::make_shared(1.0 / getWorldSize()) +); /** * Traverses the network and averages its parameters with allreduce. 
@@ -49,7 +50,8 @@ FL_API void allReduceParameters(std::shared_ptr module); */ FL_API void allReduceGradients( std::shared_ptr module, - double scale = 1.0); + double scale = 1.0 +); /** @} */ diff --git a/flashlight/fl/nn/Init.cpp b/flashlight/fl/nn/Init.cpp index 6236cca..0e8d379 100644 --- a/flashlight/fl/nn/Init.cpp +++ b/flashlight/fl/nn/Init.cpp @@ -19,100 +19,106 @@ namespace fl { namespace detail { -Tensor uniform(const Shape& shape, double min, double max, fl::dtype type) { - Tensor result = fl::rand(shape, type); - result = (max - min) * result + min; - return result; -} -Tensor normal(const Shape& shape, double stdv, double mean, fl::dtype type) { - Tensor result = fl::randn(shape, type); - result = stdv * result + mean; - return result; -} - -Tensor kaimingUniform( - const Shape& shape, - int fanIn, - fl::dtype type /* = fl::dtype::f32 */) { - double stdv = std::sqrt(1.0 / static_cast(fanIn)); - double limit = std::sqrt(3.0) * stdv; - return detail::uniform(shape, -limit, limit, type); -} - -Tensor kaimingNormal( - const Shape& shape, - int fanIn, - fl::dtype type /* = fl::dtype::f32 */) { - double stdv = std::sqrt(1.0 / static_cast(fanIn)); - return detail::normal(shape, stdv, 0, type); -} - -Tensor glorotUniform( - const Shape& shape, - int fanIn, - int fanOut, - fl::dtype type /* = fl::dtype::f32 */) { - double stdv = std::sqrt(2.0 / static_cast(fanIn + fanOut)); - double limit = std::sqrt(3.0) * stdv; - return detail::uniform(shape, -limit, limit, type); -} - -Tensor glorotNormal( - const Shape& shape, - int fanIn, - int fanOut, - fl::dtype type /* = fl::dtype::f32 */) { - double stdv = std::sqrt(2.0 / static_cast(fanIn + fanOut)); - return detail::normal(shape, stdv, 0, type); -} - -Tensor erfinv(const Tensor& y) { - if (fl::any(fl::abs(y) >= 1.).scalar()) { - throw std::runtime_error("[erfinv] input is out of range (-1, 1)"); - } - double a[4] = {0.886226899, -1.645349621, 0.914624893, -0.140543331}; - double b[4] = {-2.118377725, 1.442710462, 
-0.329097515, 0.012229801}; - double c[4] = {-1.970840454, -1.624906493, 3.429567803, 1.641345311}; - double d[2] = {3.543889200, 1.637067800}; - - auto centralMask = fl::abs(y) <= 0.7; - - auto z = y * y; - auto num = (((a[3] * z + a[2]) * z + a[1]) * z + a[0]); - auto dem = ((((b[3] * z + b[2]) * z + b[1]) * z + b[0]) * z + 1.0); - z = y * num / dem; - auto x = z * centralMask; - - z = fl::sqrt(-fl::log((1.0 - fl::abs(y)) / 2.0)); - num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0]; - dem = (d[1] * z + d[0]) * z + 1.0; - // TODO{fl::Tensor}{operator} - check af::sign - zero case? - z = fl::sign(y).astype(fl::dtype::f32); // -1 for negative, 1 for positive - z = z * num / dem; - x = x + z * !centralMask; - - /* Two steps of Newton-Raphson correction */ - x = x - (fl::erf(x) - y) / ((2.0 / std::sqrt(M_PI)) * fl::exp(-x * x)); - x = x - (fl::erf(x) - y) / ((2.0 / std::sqrt(M_PI)) * fl::exp(-x * x)); - if (fl::any(fl::isnan(x)).asScalar() || - fl::any(fl::isinf(x)).asScalar()) { - throw std::runtime_error("[erfinv] invalid result"); - } - return x; -} + Tensor uniform(const Shape& shape, double min, double max, fl::dtype type) { + Tensor result = fl::rand(shape, type); + result = (max - min) * result + min; + return result; + } + Tensor normal(const Shape& shape, double stdv, double mean, fl::dtype type) { + Tensor result = fl::randn(shape, type); + result = stdv * result + mean; + return result; + } + + Tensor kaimingUniform( + const Shape& shape, + int fanIn, + fl::dtype type /* = fl::dtype::f32 */ + ) { + double stdv = std::sqrt(1.0 / static_cast(fanIn)); + double limit = std::sqrt(3.0) * stdv; + return detail::uniform(shape, -limit, limit, type); + } + + Tensor kaimingNormal( + const Shape& shape, + int fanIn, + fl::dtype type /* = fl::dtype::f32 */ + ) { + double stdv = std::sqrt(1.0 / static_cast(fanIn)); + return detail::normal(shape, stdv, 0, type); + } + + Tensor glorotUniform( + const Shape& shape, + int fanIn, + int fanOut, + fl::dtype type /* = 
fl::dtype::f32 */ + ) { + double stdv = std::sqrt(2.0 / static_cast(fanIn + fanOut)); + double limit = std::sqrt(3.0) * stdv; + return detail::uniform(shape, -limit, limit, type); + } + + Tensor glorotNormal( + const Shape& shape, + int fanIn, + int fanOut, + fl::dtype type /* = fl::dtype::f32 */ + ) { + double stdv = std::sqrt(2.0 / static_cast(fanIn + fanOut)); + return detail::normal(shape, stdv, 0, type); + } + + Tensor erfinv(const Tensor& y) { + if(fl::any(fl::abs(y) >= 1.).scalar()) { + throw std::runtime_error("[erfinv] input is out of range (-1, 1)"); + } + double a[4] = {0.886226899, -1.645349621, 0.914624893, -0.140543331}; + double b[4] = {-2.118377725, 1.442710462, -0.329097515, 0.012229801}; + double c[4] = {-1.970840454, -1.624906493, 3.429567803, 1.641345311}; + double d[2] = {3.543889200, 1.637067800}; + + auto centralMask = fl::abs(y) <= 0.7; + + auto z = y * y; + auto num = (((a[3] * z + a[2]) * z + a[1]) * z + a[0]); + auto dem = ((((b[3] * z + b[2]) * z + b[1]) * z + b[0]) * z + 1.0); + z = y * num / dem; + auto x = z * centralMask; + + z = fl::sqrt(-fl::log((1.0 - fl::abs(y)) / 2.0)); + num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0]; + dem = (d[1] * z + d[0]) * z + 1.0; + // TODO{fl::Tensor}{operator} - check af::sign - zero case? 
+ z = fl::sign(y).astype(fl::dtype::f32); // -1 for negative, 1 for positive + z = z * num / dem; + x = x + z * !centralMask; + + /* Two steps of Newton-Raphson correction */ + x = x - (fl::erf(x) - y) / ((2.0 / std::sqrt(M_PI)) * fl::exp(-x * x)); + x = x - (fl::erf(x) - y) / ((2.0 / std::sqrt(M_PI)) * fl::exp(-x * x)); + if( + fl::any(fl::isnan(x)).asScalar() + || fl::any(fl::isinf(x)).asScalar() + ) { + throw std::runtime_error("[erfinv] invalid result"); + } + return x; + } } // namespace detail Variable input(const Tensor& arr) { - return Variable(arr, false); + return Variable(arr, false); } Variable noGrad(const Tensor& arr) { - return Variable(arr, false); + return Variable(arr, false); } Variable param(const Tensor& arr) { - return Variable(arr, true); + return Variable(arr, true); } Variable constant( @@ -120,29 +126,29 @@ Variable constant( int outputSize, int inputSize, fl::dtype type, - bool calcGrad) { - return constant(val, Shape({outputSize, inputSize}), type, calcGrad); + bool calcGrad +) { + return constant(val, Shape({outputSize, inputSize}), type, calcGrad); } -Variable -constant(double val, const Shape& dims, fl::dtype type, bool calcGrad) { - return Variable(fl::full(dims, val, type), calcGrad); +Variable constant(double val, const Shape& dims, fl::dtype type, bool calcGrad) { + return Variable(fl::full(dims, val, type), calcGrad); } -Variable -identity(int outputSize, int inputSize, fl::dtype type, bool calcGrad) { - // TODO{fl::Tensor}{fixme} add non-square identity to API - if (inputSize != outputSize) { - throw std::invalid_argument( - "identity - can't create tensor with " - "different in and output size - only square identity " - "tensors supported"); - } - return identity(Shape({inputSize, outputSize}), type, calcGrad); +Variable identity(int outputSize, int inputSize, fl::dtype type, bool calcGrad) { + // TODO{fl::Tensor}{fixme} add non-square identity to API + if(inputSize != outputSize) { + throw std::invalid_argument( + "identity - 
can't create tensor with " + "different in and output size - only square identity " + "tensors supported" + ); + } + return identity(Shape({inputSize, outputSize}), type, calcGrad); } Variable identity(const Shape& dims, fl::dtype type, bool calcGrad) { - return Variable(fl::identity(dims.dim(0), type), calcGrad); + return Variable(fl::identity(dims.dim(0), type), calcGrad); } Variable uniform( @@ -151,8 +157,9 @@ Variable uniform( double min, double max, fl::dtype type, - bool calcGrad) { - return uniform(Shape({outputSize, inputSize}), min, max, type, calcGrad); + bool calcGrad +) { + return uniform(Shape({outputSize, inputSize}), min, max, type, calcGrad); } Variable uniform( @@ -160,8 +167,9 @@ Variable uniform( double min, double max, fl::dtype type, - bool calcGrad) { - return Variable(detail::uniform(dims, min, max, type), calcGrad); + bool calcGrad +) { + return Variable(detail::uniform(dims, min, max, type), calcGrad); } Variable normal( @@ -170,8 +178,9 @@ Variable normal( double stdv, double mean, fl::dtype type, - bool calcGrad) { - return normal(Shape({outputSize, inputSize}), stdv, mean, type, calcGrad); + bool calcGrad +) { + return normal(Shape({outputSize, inputSize}), stdv, mean, type, calcGrad); } Variable normal( @@ -179,24 +188,27 @@ Variable normal( double stdv, double mean, fl::dtype type, - bool calcGrad) { - return Variable(detail::normal(dims, stdv, mean, type), calcGrad); + bool calcGrad +) { + return Variable(detail::normal(dims, stdv, mean, type), calcGrad); } Variable kaimingUniform( const Shape& shape, int fanIn, fl::dtype type /* = fl::dtype::f32 */, - bool calcGrad /* = true */) { - return Variable(detail::kaimingUniform(shape, fanIn, type), calcGrad); + bool calcGrad /* = true */ +) { + return Variable(detail::kaimingUniform(shape, fanIn, type), calcGrad); } Variable kaimingNormal( const Shape& shape, int fanIn, fl::dtype type /* = fl::dtype::f32 */, - bool calcGrad /* = true */) { - return Variable(detail::kaimingNormal(shape, 
fanIn, type), calcGrad); + bool calcGrad /* = true */ +) { + return Variable(detail::kaimingNormal(shape, fanIn, type), calcGrad); } Variable glorotUniform( @@ -204,8 +216,9 @@ Variable glorotUniform( int fanIn, int fanOut, fl::dtype type /* = fl::dtype::f32 */, - bool calcGrad /* = true */) { - return Variable(detail::glorotUniform(shape, fanIn, fanOut, type), calcGrad); + bool calcGrad /* = true */ +) { + return Variable(detail::glorotUniform(shape, fanIn, fanOut, type), calcGrad); } Variable glorotNormal( @@ -213,8 +226,9 @@ Variable glorotNormal( int fanIn, int fanOut, fl::dtype type /* = fl::dtype::f32 */, - bool calcGrad /* = true */) { - return Variable(detail::glorotNormal(shape, fanIn, fanOut, type), calcGrad); + bool calcGrad /* = true */ +) { + return Variable(detail::glorotNormal(shape, fanIn, fanOut, type), calcGrad); } Variable truncNormal( @@ -224,22 +238,23 @@ Variable truncNormal( double minCufOff, double maxCutOff, fl::dtype type, - bool calcGrad) { - // following: https://git.io/JYYAr - auto normCdf = [](double x) { - return (1. + std::erf(x / std::sqrt(2.))) / 2.; - }; - - auto l = 2 * normCdf((minCufOff - mean) / stdv) - 1; - auto u = 2 * normCdf((maxCutOff - mean) / stdv) - 1; - - float eps = 1e-7; - auto result = fl::rand(shape, type) * (u - l) + l; - result = fl::clip(result, -1 + eps, 1 - eps); // make sure erf is in range - result = detail::erfinv(result); - result = mean + result * (stdv * std::sqrt(2.)); - result = fl::clip(result, minCufOff, maxCutOff); - return Variable(result, calcGrad); + bool calcGrad +) { + // following: https://git.io/JYYAr + auto normCdf = [](double x) { + return (1. 
+ std::erf(x / std::sqrt(2.))) / 2.; + }; + + auto l = 2 * normCdf((minCufOff - mean) / stdv) - 1; + auto u = 2 * normCdf((maxCutOff - mean) / stdv) - 1; + + float eps = 1e-7; + auto result = fl::rand(shape, type) * (u - l) + l; + result = fl::clip(result, -1 + eps, 1 - eps); // make sure erf is in range + result = detail::erfinv(result); + result = mean + result * (stdv * std::sqrt(2.)); + result = fl::clip(result, minCufOff, maxCutOff); + return Variable(result, calcGrad); } } // namespace fl diff --git a/flashlight/fl/nn/Init.h b/flashlight/fl/nn/Init.h index d91106f..6bb0422 100644 --- a/flashlight/fl/nn/Init.h +++ b/flashlight/fl/nn/Init.h @@ -42,11 +42,12 @@ namespace detail { * * \ingroup nn_init_utils */ -FL_API Tensor uniform( - const Shape& shape, - double min = 0, - double max = 1, - fl::dtype type = fl::dtype::f32); + FL_API Tensor uniform( + const Shape& shape, + double min = 0, + double max = 1, + fl::dtype type = fl::dtype::f32 + ); /** * Creates a `Tensor` representing a tensor of up to rank 4 with arbitrary @@ -64,11 +65,12 @@ FL_API Tensor uniform( * * \ingroup nn_init_utils */ -FL_API Tensor normal( - const Shape& shape, - double stdv = 1, - double mean = 0, - fl::dtype type = fl::dtype::f32); + FL_API Tensor normal( + const Shape& shape, + double stdv = 1, + double mean = 0, + fl::dtype type = fl::dtype::f32 + ); /** * Creates a `Tensor` representing a tensor with given input dimensions where @@ -87,8 +89,7 @@ FL_API Tensor normal( * * \ingroup nn_init_utils */ -FL_API Tensor -kaimingUniform(const Shape& shape, int fanIn, fl::dtype type = fl::dtype::f32); + FL_API Tensor kaimingUniform(const Shape& shape, int fanIn, fl::dtype type = fl::dtype::f32); /** * Creates a `Tensor` representing a tensor with given input dimensions * where elements are normally distributed according to the method outlined in @@ -104,8 +105,7 @@ kaimingUniform(const Shape& shape, int fanIn, fl::dtype type = fl::dtype::f32); * * \ingroup nn_init_utils */ -FL_API Tensor 
-kaimingNormal(const Shape& shape, int fanIn, fl::dtype type = fl::dtype::f32); + FL_API Tensor kaimingNormal(const Shape& shape, int fanIn, fl::dtype type = fl::dtype::f32); /** * Creates a `Tensor` representing a tensor with given input dimensions @@ -124,11 +124,12 @@ kaimingNormal(const Shape& shape, int fanIn, fl::dtype type = fl::dtype::f32); * * \ingroup nn_init_utils */ -FL_API Tensor glorotUniform( - const Shape& shape, - int fanIn, - int fanOut, - fl::dtype type = fl::dtype::f32); + FL_API Tensor glorotUniform( + const Shape& shape, + int fanIn, + int fanOut, + fl::dtype type = fl::dtype::f32 + ); /** * Creates a `Tensor` representing a tensor with given input dimensions @@ -148,11 +149,12 @@ FL_API Tensor glorotUniform( * * \ingroup nn_init_utils */ -FL_API Tensor glorotNormal( - const Shape& shape, - int fanIn, - int fanOut, - fl::dtype type = fl::dtype::f32); + FL_API Tensor glorotNormal( + const Shape& shape, + int fanIn, + int fanOut, + fl::dtype type = fl::dtype::f32 + ); /* * Approximation of inverse error function. @@ -195,7 +197,7 @@ FL_API Tensor glorotNormal( USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -FL_API Tensor erfinv(const Tensor& y); + FL_API Tensor erfinv(const Tensor& y); } // namespace detail @@ -250,7 +252,8 @@ FL_API Variable constant( int inputSize, int outputSize, fl::dtype type = fl::dtype::f32, - bool calcGrad = true); + bool calcGrad = true +); /** * Creates a `Variable` representing a tensor of up to rank 4 with arbitrary @@ -270,7 +273,8 @@ FL_API Variable constant( double val, const Shape& shape, fl::dtype type = fl::dtype::f32, - bool calcGrad = true); + bool calcGrad = true +); /** * Creates a `Variable` representing a scalar with a given value and type. 
@@ -284,10 +288,9 @@ FL_API Variable constant( * * \ingroup nn_init_utils */ -template -Variable -scalar(T val, fl::dtype type = dtype_traits::ctype, bool calcGrad = true) { - return Variable(fromScalar(val, type), calcGrad); +template +Variable scalar(T val, fl::dtype type = dtype_traits::ctype, bool calcGrad = true) { + return Variable(fromScalar(val, type), calcGrad); } /** @@ -308,7 +311,8 @@ FL_API Variable identity( int inputSize, int outputSize, fl::dtype type = fl::dtype::f32, - bool calcGrad = true); + bool calcGrad = true +); /** * Creates a `Variable` representing an identity tensor of up to rank 4 with @@ -326,7 +330,8 @@ FL_API Variable identity( FL_API Variable identity( const Shape& shape, fl::dtype type = fl::dtype::f32, - bool calcGrad = true); + bool calcGrad = true +); /** * Creates a `Variable` representing a tensor with dimensions `[inputSize, @@ -353,7 +358,8 @@ FL_API Variable uniform( double min = 0, double max = 1, fl::dtype type = fl::dtype::f32, - bool calcGrad = true); + bool calcGrad = true +); /** * Creates a `Variable` representing a tensor of up to rank 4 with arbitrary @@ -378,7 +384,8 @@ FL_API Variable uniform( double min = 0, double max = 1, fl::dtype type = fl::dtype::f32, - bool calcGrad = true); + bool calcGrad = true +); /** * Creates a `Variable` representing a tensor with dimensions `[inputSize, @@ -405,7 +412,8 @@ FL_API Variable normal( double stdv = 1, double mean = 0, fl::dtype type = fl::dtype::f32, - bool calcGrad = true); + bool calcGrad = true +); /** * Creates a `Variable` representing a tensor of up to rank 4 with arbitrary @@ -430,7 +438,8 @@ FL_API Variable normal( double stdv = 1, double mean = 0, fl::dtype type = fl::dtype::f32, - bool calcGrad = true); + bool calcGrad = true +); /** * Creates a `Variable` representing a tensor with given input dimensions where @@ -453,7 +462,8 @@ FL_API Variable kaimingUniform( const Shape& shape, int fanIn, fl::dtype type = fl::dtype::f32, - bool calcGrad = true); + bool 
calcGrad = true +); /** * Creates a `Variable` representing a tensor with given input dimensions where * elements are normally distributed according to the method @@ -476,7 +486,8 @@ FL_API Variable kaimingNormal( const Shape& shape, int fanIn, fl::dtype type = fl::dtype::f32, - bool calcGrad = true); + bool calcGrad = true +); /** * Creates a `Variable` representing a tensor with given input dimensions where @@ -502,7 +513,8 @@ FL_API Variable glorotUniform( int fanIn, int fanOut, fl::dtype type = fl::dtype::f32, - bool calcGrad = true); + bool calcGrad = true +); /** * Creates a `Variable` representing a tensor with given input dimensions where @@ -529,7 +541,8 @@ FL_API Variable glorotNormal( int fanIn, int fanOut, fl::dtype type = fl::dtype::f32, - bool calcGrad = true); + bool calcGrad = true +); /** * Creates a `Variable` representing a tensor with given input dimensions where @@ -557,6 +570,7 @@ FL_API Variable truncNormal( double minCufOff = -2., double maxCutOff = 2., fl::dtype type = fl::dtype::f32, - bool calcGrad = true); + bool calcGrad = true +); } // namespace fl diff --git a/flashlight/fl/nn/Utils.cpp b/flashlight/fl/nn/Utils.cpp index e7afecd..3bf2640 100644 --- a/flashlight/fl/nn/Utils.cpp +++ b/flashlight/fl/nn/Utils.cpp @@ -17,145 +17,150 @@ namespace fl { int64_t numTotalParams(std::shared_ptr module) { - int64_t params = 0; - for (auto& p : module->params()) { - params += p.elements(); - } - return params; + int64_t params = 0; + for(auto& p : module->params()) { + params += p.elements(); + } + return params; } bool allParamsClose( const Module& a, const Module& b, - double absTolerance /* = 1e-5 */) { - if (a.params().size() != b.params().size()) { - return false; - } - const auto aParams = a.params(); - const auto bParams = b.params(); - for (int p = 0; p < aParams.size(); ++p) { - if (!allClose(aParams[p], bParams[p], absTolerance)) { - return false; + double absTolerance /* = 1e-5 */ +) { + if(a.params().size() != b.params().size()) { + 
return false; + } + const auto aParams = a.params(); + const auto bParams = b.params(); + for(int p = 0; p < aParams.size(); ++p) { + if(!allClose(aParams[p], bParams[p], absTolerance)) { + return false; + } } - } - return true; + return true; } namespace detail { -int64_t getNumRnnParams( - int input_size, - int hidden_size, - int num_layers, - RnnMode mode, - bool bidirectional) { - int bidir_mul = (bidirectional ? 2 : 1); - - int64_t i = input_size; - int64_t h = hidden_size; - int64_t n = num_layers; - int64_t b = bidir_mul; - - int64_t n_params = - /* hidden-to-hidden */ - h * h * n + - /* hidden biases */ - h * n + - /* input-to-hidden */ - i * h + b * (n - 1) * h * h + - /* input biases */ - h * n; - - n_params *= b; - - switch (mode) { - case RnnMode::LSTM: - n_params *= 4; - break; - case RnnMode::GRU: - n_params *= 3; - break; - case RnnMode::RELU: - case RnnMode::TANH: - default: - break; - } - - return n_params; -} + int64_t getNumRnnParams( + int input_size, + int hidden_size, + int num_layers, + RnnMode mode, + bool bidirectional + ) { + int bidir_mul = (bidirectional ? 
2 : 1); + + int64_t i = input_size; + int64_t h = hidden_size; + int64_t n = num_layers; + int64_t b = bidir_mul; + + int64_t n_params = + /* hidden-to-hidden */ + h * h * n + + /* hidden biases */ + h * n + + /* input-to-hidden */ + i * h + b * (n - 1) * h * h + + /* input biases */ + h * n; + + n_params *= b; + + switch(mode) { + case RnnMode::LSTM: + n_params *= 4; + break; + case RnnMode::GRU: + n_params *= 3; + break; + case RnnMode::RELU: + case RnnMode::TANH: + default: + break; + } + + return n_params; + } } // namespace detail int derivePadding(int inSz, int filterSz, int stride, int pad, int dilation) { - if (pad == static_cast(PaddingMode::SAME)) { - int newPad; - if (inSz % stride == 0) { - newPad = (filterSz - 1) * dilation - stride + 1; - } else { - newPad = (filterSz - 1) * dilation - (inSz % stride) + 1; + if(pad == static_cast(PaddingMode::SAME)) { + int newPad; + if(inSz % stride == 0) { + newPad = (filterSz - 1) * dilation - stride + 1; + } else { + newPad = (filterSz - 1) * dilation - (inSz % stride) + 1; + } + newPad = (newPad + 1) / 2; // equal pad on both sides + return std::max(newPad, 0); } - newPad = (newPad + 1) / 2; // equal pad on both sides - return std::max(newPad, 0); - } - return pad; + return pad; } Tensor join( const std::vector& inputs, double padValue /* = 0.0 */, - int batchDim /* = -1 */) { - if (inputs.empty()) { - return Tensor(); - } - - Dim maxNumDims = 0; - for (const auto& in : inputs) { - if (in.ndim() > maxNumDims) { - maxNumDims = in.ndim(); + int batchDim /* = -1 */ +) { + if(inputs.empty()) { + return Tensor(); } - } - // If the batch dim > the max number of dims, make those dims singleton - int outNdims = std::max(batchDim + 1, static_cast(maxNumDims)); + Dim maxNumDims = 0; + for(const auto& in : inputs) { + if(in.ndim() > maxNumDims) { + maxNumDims = in.ndim(); + } + } - Shape maxDims(std::vector(outNdims, 1)); + // If the batch dim > the max number of dims, make those dims singleton + int outNdims = 
std::max(batchDim + 1, static_cast(maxNumDims)); + + Shape maxDims(std::vector(outNdims, 1)); + + fl::dtype type = inputs[0].type(); + bool isEmpty = true; + for(const auto& in : inputs) { + isEmpty = isEmpty && in.isEmpty(); + for(int d = 0; d < in.ndim(); ++d) { + maxDims[d] = std::max(maxDims[d], in.dim(d)); + if(in.type() != type) { + throw std::invalid_argument( + "join: all arrays should of same type for join" + ); + } + } + } - fl::dtype type = inputs[0].type(); - bool isEmpty = true; - for (const auto& in : inputs) { - isEmpty = isEmpty && in.isEmpty(); - for (int d = 0; d < in.ndim(); ++d) { - maxDims[d] = std::max(maxDims[d], in.dim(d)); - if (in.type() != type) { + if(batchDim < 0) { + batchDim = maxDims.ndim() - 1; + } + if(batchDim < maxDims.ndim() && maxDims[batchDim] > 1) { throw std::invalid_argument( - "join: all arrays should of same type for join"); - } + "join: no singleton dim available for batching" + ); } - } - - if (batchDim < 0) { - batchDim = maxDims.ndim() - 1; - } - if (batchDim < maxDims.ndim() && maxDims[batchDim] > 1) { - throw std::invalid_argument( - "join: no singleton dim available for batching"); - } - maxDims[batchDim] = inputs.size(); - if (isEmpty) { - return Tensor(maxDims, type); - } - auto padSeq = fl::full(maxDims, padValue, type); - std::vector sel( - std::max(maxNumDims, static_cast(batchDim + 1)), fl::span); - for (int i = 0; i < inputs.size(); ++i) { - for (int d = 0; d < maxNumDims; ++d) { - sel[d] = fl::range(inputs[i].dim(d)); + maxDims[batchDim] = inputs.size(); + if(isEmpty) { + return Tensor(maxDims, type); } - sel[batchDim] = fl::range(i, i + 1); - if (!inputs[i].isEmpty()) { - padSeq(sel) = inputs[i]; + auto padSeq = fl::full(maxDims, padValue, type); + std::vector sel( + std::max(maxNumDims, static_cast(batchDim + 1)), fl::span); + for(int i = 0; i < inputs.size(); ++i) { + for(int d = 0; d < maxNumDims; ++d) { + sel[d] = fl::range(inputs[i].dim(d)); + } + sel[batchDim] = fl::range(i, i + 1); + 
if(!inputs[i].isEmpty()) { + padSeq(sel) = inputs[i]; + } } - } - return padSeq; + return padSeq; } } // namespace fl diff --git a/flashlight/fl/nn/Utils.h b/flashlight/fl/nn/Utils.h index fd13755..d90532c 100644 --- a/flashlight/fl/nn/Utils.h +++ b/flashlight/fl/nn/Utils.h @@ -37,37 +37,38 @@ FL_API int64_t numTotalParams(std::shared_ptr module); * @param absTolerance absolute tolerance allowed * */ -FL_API bool -allParamsClose(const Module& a, const Module& b, double absTolerance = 1e-5); +FL_API bool allParamsClose(const Module& a, const Module& b, double absTolerance = 1e-5); namespace detail { -FL_API int64_t getNumRnnParams( - int input_size, - int hidden_size, - int num_layers, - RnnMode mode, - bool bidirectional); + FL_API int64_t getNumRnnParams( + int input_size, + int hidden_size, + int num_layers, + RnnMode mode, + bool bidirectional + ); /// used for Conv2D and Pool2D params -struct IntOrPadMode { - /* implicit */ IntOrPadMode(int val) : padVal(val) {} - /* implicit */ IntOrPadMode(PaddingMode mode) - : padVal(static_cast(mode)) {} - const int padVal; -}; + struct IntOrPadMode { + /* implicit */ + IntOrPadMode(int val) : padVal(val) {} + /* implicit */ IntOrPadMode(PaddingMode mode) + : padVal(static_cast(mode)) {} + const int padVal; + }; } // namespace detail -FL_API int -derivePadding(int inSz, int filterSz, int stride, int pad, int dilation); +FL_API int derivePadding(int inSz, int filterSz, int stride, int pad, int dilation); /// packs a list of arrays (possibly of different dimensions) to a single array /// by padding them to same dimensions FL_API Tensor join( const std::vector& inputs, double padValue = 0.0, - int batchDim = -1); + int batchDim = -1 +); /** @} */ diff --git a/flashlight/fl/nn/modules/Activations.cpp b/flashlight/fl/nn/modules/Activations.cpp index b9cb945..c583f03 100644 --- a/flashlight/fl/nn/modules/Activations.cpp +++ b/flashlight/fl/nn/modules/Activations.cpp @@ -16,191 +16,191 @@ namespace fl { Sigmoid::Sigmoid() = 
default; Variable Sigmoid::forward(const Variable& input) { - return sigmoid(input); + return sigmoid(input); } std::unique_ptr Sigmoid::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string Sigmoid::prettyString() const { - return "Sigmoid"; + return "Sigmoid"; } Log::Log() = default; Variable Log::forward(const Variable& input) { - return log(input); + return log(input); } std::unique_ptr Log::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string Log::prettyString() const { - return "Log"; + return "Log"; } Tanh::Tanh() = default; Variable Tanh::forward(const Variable& input) { - return tanh(input); + return tanh(input); } std::unique_ptr Tanh::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string Tanh::prettyString() const { - return "Tanh"; + return "Tanh"; } HardTanh::HardTanh() = default; Variable HardTanh::forward(const Variable& input) { - return clamp(input, -1.0, 1.0); + return clamp(input, -1.0, 1.0); } std::unique_ptr HardTanh::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string HardTanh::prettyString() const { - return "HardTanh"; + return "HardTanh"; } ReLU::ReLU() = default; Variable ReLU::forward(const Variable& input) { - return max(input, 0.0); + return max(input, 0.0); } std::unique_ptr ReLU::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string ReLU::prettyString() const { - return "ReLU"; + return "ReLU"; } ReLU6::ReLU6() = default; Variable ReLU6::forward(const Variable& input) { - return clamp(input, 0.0, 6.0); + return clamp(input, 0.0, 6.0); } std::unique_ptr ReLU6::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string ReLU6::prettyString() const { - return "ReLU6"; + return "ReLU6"; } LeakyReLU::LeakyReLU(double slope) : mSlope_(slope) {} Variable LeakyReLU::forward(const 
Variable& input) { - return max(input, mSlope_ * input); + return max(input, mSlope_ * input); } std::unique_ptr LeakyReLU::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string LeakyReLU::prettyString() const { - return "LeakyReLU (" + std::to_string(mSlope_) + ")"; + return "LeakyReLU (" + std::to_string(mSlope_) + ")"; } PReLU::PReLU(const Variable& w) : UnaryModule({w}) {} PReLU::PReLU(int size, double value) { - auto w = constant(value, size, 1); - params_ = {w}; + auto w = constant(value, size, 1); + params_ = {w}; } Variable PReLU::forward(const Variable& input) { - auto mask = input >= 0.0; - return (input * mask) + (input * !mask * tileAs(params_[0], input)); + auto mask = input >= 0.0; + return (input * mask) + (input * !mask * tileAs(params_[0], input)); } std::unique_ptr PReLU::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string PReLU::prettyString() const { - return "PReLU"; + return "PReLU"; } ELU::ELU(double alpha) : mAlpha_(alpha) {} Variable ELU::forward(const Variable& input) { - auto mask = input >= 0.0; - return (mask * input) + (!mask * mAlpha_ * (exp(input) - 1)); + auto mask = input >= 0.0; + return (mask * input) + (!mask * mAlpha_ * (exp(input) - 1)); } std::unique_ptr ELU::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string ELU::prettyString() const { - return "ELU (" + std::to_string(mAlpha_) + ")"; + return "ELU (" + std::to_string(mAlpha_) + ")"; } ThresholdReLU::ThresholdReLU(double threshold) : mThreshold_(threshold) {} Variable ThresholdReLU::forward(const Variable& input) { - auto mask = input >= mThreshold_; - return input * mask; + auto mask = input >= mThreshold_; + return input * mask; } std::unique_ptr ThresholdReLU::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string ThresholdReLU::prettyString() const { - return "ThresholdReLU (" + 
std::to_string(mThreshold_) + ")"; + return "ThresholdReLU (" + std::to_string(mThreshold_) + ")"; } GatedLinearUnit::GatedLinearUnit(int dim) : dim_(dim) {} Variable GatedLinearUnit::forward(const Variable& input) { - return gatedlinearunit(input, dim_); + return gatedlinearunit(input, dim_); } std::unique_ptr GatedLinearUnit::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string GatedLinearUnit::prettyString() const { - return "GatedLinearUnit (" + std::to_string(dim_) + ")"; + return "GatedLinearUnit (" + std::to_string(dim_) + ")"; } LogSoftmax::LogSoftmax(int dim /* = 0 */) : dim_(dim) {} Variable LogSoftmax::forward(const Variable& input) { - return logSoftmax(input, dim_); + return logSoftmax(input, dim_); } std::unique_ptr LogSoftmax::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string LogSoftmax::prettyString() const { - return "LogSoftmax (" + std::to_string(dim_) + ")"; + return "LogSoftmax (" + std::to_string(dim_) + ")"; } Swish::Swish(double beta /* = 1.0 */) : beta_(beta) {} Variable Swish::forward(const Variable& input) { - return swish(input, beta_); + return swish(input, beta_); } std::unique_ptr Swish::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string Swish::prettyString() const { - return "Swish (" + std::to_string(beta_) + ")"; + return "Swish (" + std::to_string(beta_) + ")"; } } // namespace fl diff --git a/flashlight/fl/nn/modules/Activations.h b/flashlight/fl/nn/modules/Activations.h index 8f5d2fe..48ad817 100644 --- a/flashlight/fl/nn/modules/Activations.h +++ b/flashlight/fl/nn/modules/Activations.h @@ -17,16 +17,16 @@ namespace fl { * `Variable`: \f[\text{sigmoid}(x) = \frac{1}{1 + e^{-x}}\f] */ class FL_API Sigmoid : public UnaryModule { - public: - Sigmoid(); +public: + Sigmoid(); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr 
clone() const override; - std::string prettyString() const override; + std::unique_ptr clone() const override; + std::string prettyString() const override; - private: - FL_SAVE_LOAD_WITH_BASE(UnaryModule) +private: + FL_SAVE_LOAD_WITH_BASE(UnaryModule) }; /** @@ -34,16 +34,16 @@ class FL_API Sigmoid : public UnaryModule { * element-wise to a `Variable`. */ class FL_API Log : public UnaryModule { - public: - Log(); +public: + Log(); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; - std::string prettyString() const override; + std::unique_ptr clone() const override; + std::string prettyString() const override; - private: - FL_SAVE_LOAD_WITH_BASE(UnaryModule) +private: + FL_SAVE_LOAD_WITH_BASE(UnaryModule) }; /** @@ -53,16 +53,16 @@ class FL_API Log : public UnaryModule { *\f[\text{tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}\f] */ class FL_API Tanh : public UnaryModule { - public: - Tanh(); +public: + Tanh(); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; - std::string prettyString() const override; + std::unique_ptr clone() const override; + std::string prettyString() const override; - private: - FL_SAVE_LOAD_WITH_BASE(UnaryModule) +private: + FL_SAVE_LOAD_WITH_BASE(UnaryModule) }; /** @@ -76,16 +76,16 @@ class FL_API Tanh : public UnaryModule { \f] */ class FL_API HardTanh : public UnaryModule { - public: - HardTanh(); +public: + HardTanh(); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; - std::string prettyString() const override; + std::unique_ptr clone() const override; + std::string prettyString() const override; - private: - FL_SAVE_LOAD_WITH_BASE(UnaryModule) +private: + FL_SAVE_LOAD_WITH_BASE(UnaryModule) }; /** @@ -95,16 +95,16 @@ class FL_API HardTanh : public 
UnaryModule { * \f[ ReLU(x) = \max(0, x) \f] */ class FL_API ReLU : public UnaryModule { - public: - ReLU(); +public: + ReLU(); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; - std::string prettyString() const override; + std::unique_ptr clone() const override; + std::string prettyString() const override; - private: - FL_SAVE_LOAD_WITH_BASE(UnaryModule) +private: + FL_SAVE_LOAD_WITH_BASE(UnaryModule) }; /** @@ -114,16 +114,16 @@ class FL_API ReLU : public UnaryModule { * function element-wise to a `Variable`: \f[ ReLU6(x) = \min(\max(0, x), 6) \f] */ class FL_API ReLU6 : public UnaryModule { - public: - ReLU6(); +public: + ReLU6(); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; - std::string prettyString() const override; + std::unique_ptr clone() const override; + std::string prettyString() const override; - private: - FL_SAVE_LOAD_WITH_BASE(UnaryModule) +private: + FL_SAVE_LOAD_WITH_BASE(UnaryModule) }; /** @@ -144,24 +144,24 @@ class FL_API ReLU6 : public UnaryModule { * be multiplied if less than zero. 
*/ class FL_API LeakyReLU : public UnaryModule { - private: - double mSlope_; +private: + double mSlope_; - FL_SAVE_LOAD_WITH_BASE(UnaryModule, mSlope_) + FL_SAVE_LOAD_WITH_BASE(UnaryModule, mSlope_) - public: - /** - * Creates a `LeakyReLU` with the specified slope - * - * @param slope a constant by which the input will be multiplied if less than - * 0 - */ - LeakyReLU(double slope = 0.0); +public: + /** + * Creates a `LeakyReLU` with the specified slope + * + * @param slope a constant by which the input will be multiplied if less than + * 0 + */ + LeakyReLU(double slope = 0.0); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; - std::string prettyString() const override; + std::unique_ptr clone() const override; + std::string prettyString() const override; }; /** @@ -182,36 +182,36 @@ class FL_API LeakyReLU : public UnaryModule { * tuned. */ class FL_API PReLU : public UnaryModule { - private: - PReLU() = default; // Intentionally private - - FL_SAVE_LOAD_WITH_BASE(UnaryModule) - - public: - /** - * Creates a `PReLU` with the specified value and input size - * - * @param value a constant by which the input will be multiplied if less than - * 0 - * @param size the number of learnable parameters. The size must be a multiple - * of the first dimension of the input - */ - explicit PReLU(int size, double value = 0.25); - - /** - * Creates a `PReLU` with a custom tensor; if the input is less than zero, the - * output is equal to the tensor product of the input and this tensor. The - * initialization for the learned tensor can be smaller than the input; it - * will be broadcast in order to compute the product. 
- * - * @param w the tensor initializing the learned \f$\text{value}\f$ parameter - */ - explicit PReLU(const Variable& w); - - Variable forward(const Variable& input) override; - - std::unique_ptr clone() const override; - std::string prettyString() const override; +private: + PReLU() = default; // Intentionally private + + FL_SAVE_LOAD_WITH_BASE(UnaryModule) + +public: + /** + * Creates a `PReLU` with the specified value and input size + * + * @param value a constant by which the input will be multiplied if less than + * 0 + * @param size the number of learnable parameters. The size must be a multiple + * of the first dimension of the input + */ + explicit PReLU(int size, double value = 0.25); + + /** + * Creates a `PReLU` with a custom tensor; if the input is less than zero, the + * output is equal to the tensor product of the input and this tensor. The + * initialization for the learned tensor can be smaller than the input; it + * will be broadcast in order to compute the product. + * + * @param w the tensor initializing the learned \f$\text{value}\f$ parameter + */ + explicit PReLU(const Variable& w); + + Variable forward(const Variable& input) override; + + std::unique_ptr clone() const override; + std::string prettyString() const override; }; /** @@ -230,18 +230,18 @@ class FL_API PReLU : public UnaryModule { * where \f$\alpha\f$ is a tunable parameter. 
*/ class FL_API ELU : public UnaryModule { - private: - double mAlpha_; +private: + double mAlpha_; - FL_SAVE_LOAD_WITH_BASE(UnaryModule, mAlpha_) + FL_SAVE_LOAD_WITH_BASE(UnaryModule, mAlpha_) - public: - ELU(double alpha = 1.0); +public: + ELU(double alpha = 1.0); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; - std::string prettyString() const override; + std::unique_ptr clone() const override; + std::string prettyString() const override; }; /** @@ -258,23 +258,23 @@ class FL_API ELU : public UnaryModule { * where \f$\text{threshold}\f$ is a tunable parameter. */ class FL_API ThresholdReLU : public UnaryModule { - private: - double mThreshold_; +private: + double mThreshold_; - FL_SAVE_LOAD_WITH_BASE(UnaryModule, mThreshold_) + FL_SAVE_LOAD_WITH_BASE(UnaryModule, mThreshold_) - public: - /** - * Creates a `ThresholdReLU` with the specified threshold. - * - * @param threshold the threshold value above which the unit returns the input - */ - ThresholdReLU(double threshold = 1.0); +public: + /** + * Creates a `ThresholdReLU` with the specified threshold. + * + * @param threshold the threshold value above which the unit returns the input + */ + ThresholdReLU(double threshold = 1.0); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; - std::string prettyString() const override; + std::unique_ptr clone() const override; + std::string prettyString() const override; }; /** @@ -287,24 +287,24 @@ class FL_API ThresholdReLU : public UnaryModule { * \f$\sigma(x)\f$ is the sigmoid function. */ class FL_API GatedLinearUnit : public UnaryModule { - private: - int dim_; +private: + int dim_; - FL_SAVE_LOAD_WITH_BASE(UnaryModule, dim_) + FL_SAVE_LOAD_WITH_BASE(UnaryModule, dim_) - public: - /** - * Creates a `GatedLinearUnit`. 
- * - * @param dim the dimension along which the GLU will cut the input in half. - * This dimension must be even in size in the input tensor. - */ - GatedLinearUnit(int dim = 0); +public: + /** + * Creates a `GatedLinearUnit`. + * + * @param dim the dimension along which the GLU will cut the input in half. + * This dimension must be even in size in the input tensor. + */ + GatedLinearUnit(int dim = 0); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; - std::string prettyString() const override; + std::unique_ptr clone() const override; + std::string prettyString() const override; }; /** @@ -315,23 +315,23 @@ class FL_API GatedLinearUnit : public UnaryModule { \f] */ class FL_API LogSoftmax : public UnaryModule { - private: - int dim_; +private: + int dim_; - FL_SAVE_LOAD_WITH_BASE(UnaryModule, dim_) + FL_SAVE_LOAD_WITH_BASE(UnaryModule, dim_) - public: - /** - * Creates a `LogSoftmax`. - * - * @param dim the dimension along which to apply the LogSoftmax. - */ - LogSoftmax(int dim = 0); +public: + /** + * Creates a `LogSoftmax`. + * + * @param dim the dimension along which to apply the LogSoftmax. + */ + LogSoftmax(int dim = 0); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; - std::string prettyString() const override; + std::unique_ptr clone() const override; + std::string prettyString() const override; }; /** @@ -342,24 +342,24 @@ class FL_API LogSoftmax : public UnaryModule { * where \f$\beta\f$ is a constant, often is 1. 
*/ class FL_API Swish : public UnaryModule { - public: - /** - * Creates a `Swish` with the specified beta - * - * @param beta a constant by which the input will be multiplied in the x * - * sigma(beta * x) - */ - Swish(double beta = 1.0); +public: + /** + * Creates a `Swish` with the specified beta + * + * @param beta a constant by which the input will be multiplied in the x * + * sigma(beta * x) + */ + Swish(double beta = 1.0); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; - std::string prettyString() const override; + std::unique_ptr clone() const override; + std::string prettyString() const override; - private: - double beta_; +private: + double beta_; - FL_SAVE_LOAD_WITH_BASE(UnaryModule, beta_) + FL_SAVE_LOAD_WITH_BASE(UnaryModule, beta_) }; } // namespace fl diff --git a/flashlight/fl/nn/modules/AdaptiveSoftMax.cpp b/flashlight/fl/nn/modules/AdaptiveSoftMax.cpp index 9ea2ee4..b73575f 100644 --- a/flashlight/fl/nn/modules/AdaptiveSoftMax.cpp +++ b/flashlight/fl/nn/modules/AdaptiveSoftMax.cpp @@ -18,127 +18,143 @@ namespace fl { AdaptiveSoftMax::AdaptiveSoftMax( int inputSize, const std::vector& cutoff, - float divValue) - : UnaryModule(), cutoff_(cutoff), divValue_(divValue) { - if (cutoff_.empty()) { - throw std::invalid_argument("invalid cutoff for AdaptiveSoftMaxLoss"); - } - - int outputSize = cutoff_[0] + cutoff_.size() - 1; - - auto head = kaimingUniform( - {outputSize, inputSize}, inputSize /* fanIn */, fl::dtype::f32, true); - params_.push_back(head); - - int denominator = 1; - for (int i = 0; i < cutoff_.size() - 1; i++) { - denominator *= divValue_; - int hiddenSize = inputSize / denominator; - auto tail1 = kaimingUniform( - {hiddenSize, inputSize}, inputSize /* fanIn */, fl::dtype::f32, true); - auto tail2 = kaimingUniform( - {cutoff_[i + 1] - cutoff_[i], hiddenSize}, - hiddenSize /* fanIn */, + float divValue +) : UnaryModule(), + cutoff_(cutoff), + 
divValue_(divValue) { + if(cutoff_.empty()) { + throw std::invalid_argument("invalid cutoff for AdaptiveSoftMaxLoss"); + } + + int outputSize = cutoff_[0] + cutoff_.size() - 1; + + auto head = kaimingUniform( + {outputSize, inputSize}, + inputSize /* fanIn */, fl::dtype::f32, - true); - - params_.push_back(tail1); - params_.push_back(tail2); - } + true + ); + params_.push_back(head); + + int denominator = 1; + for(int i = 0; i < cutoff_.size() - 1; i++) { + denominator *= divValue_; + int hiddenSize = inputSize / denominator; + auto tail1 = kaimingUniform( + {hiddenSize, inputSize}, + inputSize /* fanIn */, + fl::dtype::f32, + true + ); + auto tail2 = kaimingUniform( + {cutoff_[i + 1] - cutoff_[i], hiddenSize}, + hiddenSize /* fanIn */, + fl::dtype::f32, + true + ); + + params_.push_back(tail1); + params_.push_back(tail2); + } } Variable AdaptiveSoftMax::getFullLogProb( const Variable& inputs, - const Variable& headOutput) const { - auto outputSize = cutoff_[cutoff_.size() - 1]; - auto batchSize = inputs.dim(1); - Tensor output({outputSize, batchSize}, inputs.type()); - - output( - fl::range(0, cutoff_[0] + static_cast(cutoff_.size()) - 1)) = - headOutput.tensor(); - - for (int i = cutoff_.size() - 2; i >= 0; i--) { - auto tailOutput = matmul(params_[1 + i * 2], inputs); - tailOutput = matmul(params_[2 + i * 2], tailOutput); - auto idx = i + cutoff_[0]; - tailOutput = logSoftmax(tailOutput, 0) + - tileAs(headOutput(fl::range(idx, idx + 1)), tailOutput); - output(fl::range(cutoff_[i], cutoff_[i + 1])) = tailOutput.tensor(); - } - - return Variable(output, false); + const Variable& headOutput +) const { + auto outputSize = cutoff_[cutoff_.size() - 1]; + auto batchSize = inputs.dim(1); + Tensor output({outputSize, batchSize}, inputs.type()); + + output( + fl::range(0, cutoff_[0] + static_cast(cutoff_.size()) - 1) + ) = + headOutput.tensor(); + + for(int i = cutoff_.size() - 2; i >= 0; i--) { + auto tailOutput = matmul(params_[1 + i * 2], inputs); + tailOutput = 
matmul(params_[2 + i * 2], tailOutput); + auto idx = i + cutoff_[0]; + tailOutput = logSoftmax(tailOutput, 0) + + tileAs(headOutput(fl::range(idx, idx + 1)), tailOutput); + output(fl::range(cutoff_[i], cutoff_[i + 1])) = tailOutput.tensor(); + } + + return Variable(output, false); } Variable AdaptiveSoftMax::forward(const Variable& inputs) { - // input -- [C_in, .. , N] - // return -- [C_out, .. , N] - auto inputSize = inputs.dim(0); - if (inputSize != params_[0].dim(1)) { - throw std::invalid_argument("invalid input dimension for AdaptiveSoftMax"); - } + // input -- [C_in, .. , N] + // return -- [C_out, .. , N] + auto inputSize = inputs.dim(0); + if(inputSize != params_[0].dim(1)) { + throw std::invalid_argument("invalid input dimension for AdaptiveSoftMax"); + } - auto inputsFlattened = moddims(inputs, {inputSize, -1}); - auto headOutput = logSoftmax(matmul(params_[0], inputsFlattened), 0); + auto inputsFlattened = moddims(inputs, {inputSize, -1}); + auto headOutput = logSoftmax(matmul(params_[0], inputsFlattened), 0); - auto ret = getFullLogProb(inputsFlattened, headOutput); + auto ret = getFullLogProb(inputsFlattened, headOutput); - Shape outDims = inputs.shape(); - outDims[0] = ret.dim(0); - return moddims(ret, outDims); + Shape outDims = inputs.shape(); + outDims[0] = ret.dim(0); + return moddims(ret, outDims); } Variable AdaptiveSoftMax::predict(const Variable& inputs) const { - // input -- [C, .. , N] - // return -- [1, .. 
, N] - auto inputSize = inputs.dim(0); - if (inputSize != params_[0].dim(1)) { - throw std::invalid_argument( - "invalid input dimension for AdaptiveSoftMaxLoss"); - } - - auto inputsFlattened = moddims(inputs, {inputSize, -1}); - auto headOutput = matmul(params_[0], inputsFlattened); - Tensor maxValue, prediction; - fl::max(maxValue, prediction, headOutput.tensor(), 0); - - auto notInShortlist = (prediction >= cutoff_[0]); - Variable ret = Variable(prediction, false); - if (fl::any(notInShortlist).asScalar()) { - headOutput = logSoftmax(headOutput, 0); - auto logProbTailPositions = getFullLogProb( - inputsFlattened(fl::span, notInShortlist), - headOutput(fl::span, notInShortlist)); - Tensor maxValueTailPositions, predictionTailPositions; - fl::max( - maxValueTailPositions, - predictionTailPositions, - logProbTailPositions.tensor(), - 0); - ret.tensor()(notInShortlist) = predictionTailPositions; - } - - Shape outDims = inputs.shape(); - outDims[0] = 1; - return moddims(ret, outDims); + // input -- [C, .. , N] + // return -- [1, .. 
, N] + auto inputSize = inputs.dim(0); + if(inputSize != params_[0].dim(1)) { + throw std::invalid_argument( + "invalid input dimension for AdaptiveSoftMaxLoss" + ); + } + + auto inputsFlattened = moddims(inputs, {inputSize, -1}); + auto headOutput = matmul(params_[0], inputsFlattened); + Tensor maxValue, prediction; + fl::max(maxValue, prediction, headOutput.tensor(), 0); + + auto notInShortlist = (prediction >= cutoff_[0]); + Variable ret = Variable(prediction, false); + if(fl::any(notInShortlist).asScalar()) { + headOutput = logSoftmax(headOutput, 0); + auto logProbTailPositions = getFullLogProb( + inputsFlattened(fl::span, notInShortlist), + headOutput(fl::span, notInShortlist) + ); + Tensor maxValueTailPositions, predictionTailPositions; + fl::max( + maxValueTailPositions, + predictionTailPositions, + logProbTailPositions.tensor(), + 0 + ); + ret.tensor()(notInShortlist) = predictionTailPositions; + } + + Shape outDims = inputs.shape(); + outDims[0] = 1; + return moddims(ret, outDims); } std::vector AdaptiveSoftMax::getCutoff() const { - return cutoff_; + return cutoff_; } std::unique_ptr AdaptiveSoftMax::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string AdaptiveSoftMax::prettyString() const { - std::ostringstream ss; - ss << "Adaptive Softmax ("; - for (int i = 0; i < cutoff_.size() - 1; i++) { - ss << cutoff_[i] << ", "; - } - ss << cutoff_[cutoff_.size() - 1] << ")"; - return ss.str(); + std::ostringstream ss; + ss << "Adaptive Softmax ("; + for(int i = 0; i < cutoff_.size() - 1; i++) { + ss << cutoff_[i] << ", "; + } + ss << cutoff_[cutoff_.size() - 1] << ")"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/nn/modules/AdaptiveSoftMax.h b/flashlight/fl/nn/modules/AdaptiveSoftMax.h index 0cd9eb2..fc7660b 100644 --- a/flashlight/fl/nn/modules/AdaptiveSoftMax.h +++ b/flashlight/fl/nn/modules/AdaptiveSoftMax.h @@ -27,76 +27,77 @@ namespace fl { * up computation. 
*/ class FL_API AdaptiveSoftMax : public UnaryModule { - private: - FL_SAVE_LOAD_WITH_BASE(UnaryModule, cutoff_, divValue_) - std::vector cutoff_; - float divValue_; +private: + FL_SAVE_LOAD_WITH_BASE(UnaryModule, cutoff_, divValue_) + std::vector cutoff_; + float divValue_; - /** - * Compute the output of the entire distribution. - * - * @param inputs values for each class to compute probabilities over - * @param head_output the output of the first frequency bucket (the 'top' - * bucket) - * @returns `Variable` containing the log probabilities over the full - * distribution - */ - Variable getFullLogProb(const Variable& inputs, const Variable& headOutput) - const; + /** + * Compute the output of the entire distribution. + * + * @param inputs values for each class to compute probabilities over + * @param head_output the output of the first frequency bucket (the 'top' + * bucket) + * @returns `Variable` containing the log probabilities over the full + * distribution + */ + Variable getFullLogProb(const Variable& inputs, const Variable& headOutput) + const; - public: - AdaptiveSoftMax() = default; +public: + AdaptiveSoftMax() = default; - /** - * Create an `AdaptiveSoftMax` with given parameters - * - * @param input_size the size of the input tensor, which doesn't has to be the - * number of classes. - * @param cutoff a sequence of integers sorted in ascending order, which - * determines the relative size of each bucket, and how many partitions are - * created. For example, given cutoffs `{5, 50, 100}`, the head bucket will - * contain `5 + 2 = 7` targets (`2` additional from the two tail buckets), the - * first tail bucket will contain `50 - 5 = 45` targets (subtracting the size - * of the head bucket), the second tail bucket will contain `100 - 50 = 50` - * targets (subtracting the size of the first tail bucket). Cutoffs must be - * specified to accommodate all targets: any remaining targets are not - * assigned to an 'overflow' bucket. 
- * @param div_value determines the number of hidden units in the intermediate - * layer for each tail bucket: - * \f[ - * \left\lfloor \frac{input\_size}{div\_value^{idx}} \right\rfloor - * \f] - */ - AdaptiveSoftMax( - int inputSize, - const std::vector& cutoff, - float divValue = 4); + /** + * Create an `AdaptiveSoftMax` with given parameters + * + * @param input_size the size of the input tensor, which doesn't has to be the + * number of classes. + * @param cutoff a sequence of integers sorted in ascending order, which + * determines the relative size of each bucket, and how many partitions are + * created. For example, given cutoffs `{5, 50, 100}`, the head bucket will + * contain `5 + 2 = 7` targets (`2` additional from the two tail buckets), the + * first tail bucket will contain `50 - 5 = 45` targets (subtracting the size + * of the head bucket), the second tail bucket will contain `100 - 50 = 50` + * targets (subtracting the size of the first tail bucket). Cutoffs must be + * specified to accommodate all targets: any remaining targets are not + * assigned to an 'overflow' bucket. + * @param div_value determines the number of hidden units in the intermediate + * layer for each tail bucket: + * \f[ + * \left\lfloor \frac{input\_size}{div\_value^{idx}} \right\rfloor + * \f] + */ + AdaptiveSoftMax( + int inputSize, + const std::vector& cutoff, + float divValue = 4 + ); - /** - * Computes log-probabilities across all classes for some input. - * - * @param inputs a Variable with size [\f$C_{in}\f$, \f$B_1\f$, \f$B_2\f$, - * \f$B_3\f$] - * @return a Variable containing log probabilities for each class with size - * [\f$C\f$, \f$B_1\f$, \f$B_2\f$, \f$B_3\f$], where \f$C\f$ is the number of - * classes. - */ - Variable forward(const Variable& inputs) override; + /** + * Computes log-probabilities across all classes for some input. 
+ * + * @param inputs a Variable with size [\f$C_{in}\f$, \f$B_1\f$, \f$B_2\f$, + * \f$B_3\f$] + * @return a Variable containing log probabilities for each class with size + * [\f$C\f$, \f$B_1\f$, \f$B_2\f$, \f$B_3\f$], where \f$C\f$ is the number of + * classes. + */ + Variable forward(const Variable& inputs) override; - /** - * Computes the class with highest probability for each example in a given - * input. - * - * @param inputs a Variable with size [\f$C_{in}\f$, \f$B_1\f$, \f$B_2\f$, - * \f$B_3\f$]. - * @return a Variable with shape [\f$1\f$, \f$B_1\f$, \f$B_2\f$, \f$B_3\f$], - * containing the classes with the highest probabilities, over each sample. - */ - Variable predict(const Variable& inputs) const; - std::vector getCutoff() const; + /** + * Computes the class with highest probability for each example in a given + * input. + * + * @param inputs a Variable with size [\f$C_{in}\f$, \f$B_1\f$, \f$B_2\f$, + * \f$B_3\f$]. + * @return a Variable with shape [\f$1\f$, \f$B_1\f$, \f$B_2\f$, \f$B_3\f$], + * containing the classes with the highest probabilities, over each sample. 
+ */ + Variable predict(const Variable& inputs) const; + std::vector getCutoff() const; - std::unique_ptr clone() const override; - std::string prettyString() const override; + std::unique_ptr clone() const override; + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/nn/modules/BatchNorm.cpp b/flashlight/fl/nn/modules/BatchNorm.cpp index 75f689b..4039ed9 100644 --- a/flashlight/fl/nn/modules/BatchNorm.cpp +++ b/flashlight/fl/nn/modules/BatchNorm.cpp @@ -18,14 +18,14 @@ BatchNorm::BatchNorm( double momentum /* = 0.1 */, double eps /* = 1e-5*/, bool affine /* = true*/, - bool trackStats /* = true*/) - : BatchNorm( - std::vector(1, featAxis), - featSize, - momentum, - eps, - affine, - trackStats) {} + bool trackStats /* = true*/ +) : BatchNorm( + std::vector(1, featAxis), + featSize, + momentum, + eps, + affine, + trackStats) {} BatchNorm::BatchNorm( const std::vector& featAxis, @@ -33,96 +33,96 @@ BatchNorm::BatchNorm( double momentum /* = 0.1*/, double eps /* = 1e-5 */, bool affine /* = true*/, - bool trackStats /* = true*/) - : featAxis_(featAxis), - featSize_(featSize), - numBatchesTracked_(0), - momentum_(momentum), - epsilon_(eps), - affine_(affine), - trackStats_(trackStats) { - initialize(); + bool trackStats /* = true*/ +) : featAxis_(featAxis), + featSize_(featSize), + numBatchesTracked_(0), + momentum_(momentum), + epsilon_(eps), + affine_(affine), + trackStats_(trackStats) { + initialize(); } -BatchNorm::BatchNorm(const BatchNorm& other) - : featAxis_(other.featAxis_), - featSize_(other.featSize_), - numBatchesTracked_(other.numBatchesTracked_), - runningMean_(other.runningMean_.copy()), - runningVar_(other.runningVar_.copy()), - momentum_(other.momentum_), - epsilon_(other.epsilon_), - affine_(other.affine_), - trackStats_(other.trackStats_) { - train_ = other.train_; +BatchNorm::BatchNorm(const BatchNorm& other) : featAxis_(other.featAxis_), + featSize_(other.featSize_), + 
numBatchesTracked_(other.numBatchesTracked_), + runningMean_(other.runningMean_.copy()), + runningVar_(other.runningVar_.copy()), + momentum_(other.momentum_), + epsilon_(other.epsilon_), + affine_(other.affine_), + trackStats_(other.trackStats_) { + train_ = other.train_; } BatchNorm& BatchNorm::operator=(const BatchNorm& other) { - train_ = other.train_; - featAxis_ = other.featAxis_; - featSize_ = other.featSize_; - numBatchesTracked_ = other.numBatchesTracked_; - runningMean_ = other.runningMean_.copy(); - runningVar_ = other.runningVar_.copy(); - momentum_ = other.momentum_; - epsilon_ = other.epsilon_; - affine_ = other.affine_; - trackStats_ = other.trackStats_; - return *this; + train_ = other.train_; + featAxis_ = other.featAxis_; + featSize_ = other.featSize_; + numBatchesTracked_ = other.numBatchesTracked_; + runningMean_ = other.runningMean_.copy(); + runningVar_ = other.runningVar_.copy(); + momentum_ = other.momentum_; + epsilon_ = other.epsilon_; + affine_ = other.affine_; + trackStats_ = other.trackStats_; + return *this; } Variable BatchNorm::forward(const Variable& input) { - double avgFactor = 0.0; + double avgFactor = 0.0; - if (train_ && trackStats_) { - ++numBatchesTracked_; - if (momentum_ < 0) { // cumulative moving average - avgFactor = 1.0 / numBatchesTracked_; - } else { // exponential moving average - avgFactor = momentum_; + if(train_ && trackStats_) { + ++numBatchesTracked_; + if(momentum_ < 0) { // cumulative moving average + avgFactor = 1.0 / numBatchesTracked_; + } else { // exponential moving average + avgFactor = momentum_; + } } - } - auto paramsType = - (input.type() == fl::dtype::f16) ? fl::dtype::f32 : input.type(); - return batchnorm( - input, - params_.empty() ? Variable(Tensor(paramsType), false) : params_[0], - params_.empty() ? 
Variable(Tensor(paramsType), false) : params_[1], - runningMean_, - runningVar_, - featAxis_, - train_ || (!trackStats_), - avgFactor, - epsilon_); + auto paramsType = + (input.type() == fl::dtype::f16) ? fl::dtype::f32 : input.type(); + return batchnorm( + input, + params_.empty() ? Variable(Tensor(paramsType), false) : params_[0], + params_.empty() ? Variable(Tensor(paramsType), false) : params_[1], + runningMean_, + runningVar_, + featAxis_, + train_ || (!trackStats_), + avgFactor, + epsilon_ + ); } void BatchNorm::initialize() { - if (trackStats_) { - runningMean_ = constant(0.0, {featSize_}, fl::dtype::f32, false); - runningVar_ = constant(1.0, {featSize_}, fl::dtype::f32, false); - } + if(trackStats_) { + runningMean_ = constant(0.0, {featSize_}, fl::dtype::f32, false); + runningVar_ = constant(1.0, {featSize_}, fl::dtype::f32, false); + } - if (affine_) { - auto wt = uniform({featSize_}, 0.0, 1.0, fl::dtype::f32, true); - auto bs = constant(0.0, {featSize_}, fl::dtype::f32, true); - params_ = {wt, bs}; - } + if(affine_) { + auto wt = uniform({featSize_}, 0.0, 1.0, fl::dtype::f32, true); + auto bs = constant(0.0, {featSize_}, fl::dtype::f32, true); + params_ = {wt, bs}; + } } std::unique_ptr BatchNorm::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string BatchNorm::prettyString() const { - std::ostringstream ss; - ss << "BatchNorm"; - ss << " ( axis : { "; - for (auto x : featAxis_) { - ss << x << " "; - } - ss << "}, size : " << featSize_ << " )"; - return ss.str(); + std::ostringstream ss; + ss << "BatchNorm"; + ss << " ( axis : { "; + for(auto x : featAxis_) { + ss << x << " "; + } + ss << "}, size : " << featSize_ << " )"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/nn/modules/BatchNorm.h b/flashlight/fl/nn/modules/BatchNorm.h index 63c0b38..aefcc26 100644 --- a/flashlight/fl/nn/modules/BatchNorm.h +++ b/flashlight/fl/nn/modules/BatchNorm.h @@ -24,114 +24,117 @@ namespace fl { * 
\f$\beta\f$ are learnable parameters for affine transformation. */ class FL_API BatchNorm : public UnaryModule { - protected: - BatchNorm() = default; // intentionally protected - std::vector featAxis_; - int featSize_; - int numBatchesTracked_; - Variable runningMean_, runningVar_; - double momentum_, epsilon_; - bool affine_, trackStats_; - - FL_SAVE_LOAD_WITH_BASE( - UnaryModule, - featAxis_, - fl::serializeAs(featSize_), - fl::serializeAs(numBatchesTracked_), - runningMean_, - runningVar_, - momentum_, - epsilon_, - affine_, - trackStats_) - - /** - * Called in the constructor to initialize parameters for running mean and - * variance if `trackStats` is set to `true`, - * and \f$\gamma\f$ and \f$\beta\f$ as learnable parameters if - * `affine` is set to `true` in the constructor. - */ - void initialize(); - - public: - /** - * Constructs a BatchNorm module. - * - * @param featAxis the axis over which normalizationis performed - * @param featSize the size of the dimension along `featAxis` - * @param momentum an exponential average factor used to compute running mean - * and variance. - * \f[ runningMean = runningMean \times (1-momentum) - * + newMean \times momentum \f] - * If < 0, cumulative moving average is used. - * @param eps \f$\epsilon\f$ - * @param affine a boolean value that controls the learning of \f$\gamma\f$ - * and \f$\beta\f$. \f$\gamma\f$ and \f$\beta\f$ are set to 1, 0 respectively - * if set to `false`, or initialized as learnable parameters - * if set to `true`. - * @param trackStats a boolean value that controls whether to track the - * running mean and variance while in train mode. If `false`, batch - * statistics are used to perform normalization in both train and eval mode. - */ - BatchNorm( - int featAxis, - int featSize, - double momentum = 0.1, - double eps = 1e-5, - bool affine = true, - bool trackStats = true); - - /** - * Constructs a BatchNorm module. 
- * - * @param featAxis the axis over which normalization is performed - * @param featSize total dimension along `featAxis`. - * For example, to perform Temporal Batch Normalization on input of size - * [\f$L\f$, \f$C\f$, \f$N\f$], use `featAxis` = {1}, `featSize` = \f$C\f$. - * To perform normalization per activation on input of size - * [\f$W\f$, \f$H\f$, \f$C\f$, \f$N\f$], use `featAxis` = {0, 1, 2}, - * `featSize` = \f$W \times H \times C\f$. - * @param momentum an exponential average factor used to compute running mean - * and variance. - * \f[ runningMean = runningMean \times (1-momentum) - * + newMean \times momentum \f] - * If < 0, cumulative moving average is used. - * @param eps \f$\epsilon\f$ - * @param affine a boolean value that controls the learning of \f$\gamma\f$ - * and \f$\beta\f$. \f$\gamma\f$ and \f$\beta\f$ are set to 1, 0 respectively - * if set to `false`, or initialized as learnable parameters - * if set to `true`. - * @param trackStats a boolean value that controls whether to track the - * running mean and variance while in train mode. If `false`, batch - * statistics are used to perform normalization in both train and eval mode. - */ - BatchNorm( - const std::vector& featAxis, - int featSize, - double momentum = 0.1, - double eps = 1e-5, - bool affine = true, - bool trackStats = true); - - /** - * Constructs a BatchNorm module from another, performing a copy of the - * stats parameters. - * - * @param other The BatchNorm module to copy from. 
- */ - BatchNorm(const BatchNorm& other); - - BatchNorm& operator=(const BatchNorm& other); - - BatchNorm(BatchNorm&& other) = default; - - BatchNorm& operator=(BatchNorm&& other) = default; - - Variable forward(const Variable& input) override; - - std::unique_ptr clone() const override; - - std::string prettyString() const override; +protected: + BatchNorm() = default; // intentionally protected + std::vector featAxis_; + int featSize_; + int numBatchesTracked_; + Variable runningMean_, runningVar_; + double momentum_, epsilon_; + bool affine_, trackStats_; + + FL_SAVE_LOAD_WITH_BASE( + UnaryModule, + featAxis_, + fl::serializeAs(featSize_), + fl::serializeAs(numBatchesTracked_), + runningMean_, + runningVar_, + momentum_, + epsilon_, + affine_, + trackStats_ + ) + + /** + * Called in the constructor to initialize parameters for running mean and + * variance if `trackStats` is set to `true`, + * and \f$\gamma\f$ and \f$\beta\f$ as learnable parameters if + * `affine` is set to `true` in the constructor. + */ + void initialize(); + +public: + /** + * Constructs a BatchNorm module. + * + * @param featAxis the axis over which normalizationis performed + * @param featSize the size of the dimension along `featAxis` + * @param momentum an exponential average factor used to compute running mean + * and variance. + * \f[ runningMean = runningMean \times (1-momentum) + * + newMean \times momentum \f] + * If < 0, cumulative moving average is used. + * @param eps \f$\epsilon\f$ + * @param affine a boolean value that controls the learning of \f$\gamma\f$ + * and \f$\beta\f$. \f$\gamma\f$ and \f$\beta\f$ are set to 1, 0 respectively + * if set to `false`, or initialized as learnable parameters + * if set to `true`. + * @param trackStats a boolean value that controls whether to track the + * running mean and variance while in train mode. If `false`, batch + * statistics are used to perform normalization in both train and eval mode. 
+ */ + BatchNorm( + int featAxis, + int featSize, + double momentum = 0.1, + double eps = 1e-5, + bool affine = true, + bool trackStats = true + ); + + /** + * Constructs a BatchNorm module. + * + * @param featAxis the axis over which normalization is performed + * @param featSize total dimension along `featAxis`. + * For example, to perform Temporal Batch Normalization on input of size + * [\f$L\f$, \f$C\f$, \f$N\f$], use `featAxis` = {1}, `featSize` = \f$C\f$. + * To perform normalization per activation on input of size + * [\f$W\f$, \f$H\f$, \f$C\f$, \f$N\f$], use `featAxis` = {0, 1, 2}, + * `featSize` = \f$W \times H \times C\f$. + * @param momentum an exponential average factor used to compute running mean + * and variance. + * \f[ runningMean = runningMean \times (1-momentum) + * + newMean \times momentum \f] + * If < 0, cumulative moving average is used. + * @param eps \f$\epsilon\f$ + * @param affine a boolean value that controls the learning of \f$\gamma\f$ + * and \f$\beta\f$. \f$\gamma\f$ and \f$\beta\f$ are set to 1, 0 respectively + * if set to `false`, or initialized as learnable parameters + * if set to `true`. + * @param trackStats a boolean value that controls whether to track the + * running mean and variance while in train mode. If `false`, batch + * statistics are used to perform normalization in both train and eval mode. + */ + BatchNorm( + const std::vector& featAxis, + int featSize, + double momentum = 0.1, + double eps = 1e-5, + bool affine = true, + bool trackStats = true + ); + + /** + * Constructs a BatchNorm module from another, performing a copy of the + * stats parameters. + * + * @param other The BatchNorm module to copy from. 
+ */ + BatchNorm(const BatchNorm& other); + + BatchNorm& operator=(const BatchNorm& other); + + BatchNorm(BatchNorm&& other) = default; + + BatchNorm& operator=(BatchNorm&& other) = default; + + Variable forward(const Variable& input) override; + + std::unique_ptr clone() const override; + + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/nn/modules/Container.cpp b/flashlight/fl/nn/modules/Container.cpp index 159eeae..876de5e 100644 --- a/flashlight/fl/nn/modules/Container.cpp +++ b/flashlight/fl/nn/modules/Container.cpp @@ -14,127 +14,127 @@ namespace fl { Container::Container() = default; void Container::clear() { - childParamIdx_.clear(); - modules_.clear(); - params_.clear(); + childParamIdx_.clear(); + modules_.clear(); + params_.clear(); } std::unordered_multimap Container::getOrphanedParamsIdxMap() const { - // The previous module index which has params - int prevMidx = -1; - std::unordered_multimap orphanedParamsIdxMap; - for (size_t i = 0; i < params_.size();) { - auto paramIdx = childParamIdx_.find(i); - if (paramIdx != childParamIdx_.end()) { - const auto [midx, pidx] = paramIdx->second; - prevMidx = midx; - const auto& mod = modules_.at(midx); - i += mod->params().size(); - } else { - orphanedParamsIdxMap.emplace(prevMidx, static_cast(i)); - ++i; + // The previous module index which has params + int prevMidx = -1; + std::unordered_multimap orphanedParamsIdxMap; + for(size_t i = 0; i < params_.size();) { + auto paramIdx = childParamIdx_.find(i); + if(paramIdx != childParamIdx_.end()) { + const auto [midx, pidx] = paramIdx->second; + prevMidx = midx; + const auto& mod = modules_.at(midx); + i += mod->params().size(); + } else { + orphanedParamsIdxMap.emplace(prevMidx, static_cast(i)); + ++i; + } } - } - return orphanedParamsIdxMap; + return orphanedParamsIdxMap; } ModulePtr Container::module(int id) const { - return modules_[id]; + return modules_[id]; } std::vector Container::modules() const { - return modules_; 
+ return modules_; } void Container::train() { - train_ = true; + train_ = true; - for (int i = 0; i < params_.size(); ++i) { - if (childParamIdx_.find(i) == childParamIdx_.end()) { - params_[i].setCalcGrad(true); + for(int i = 0; i < params_.size(); ++i) { + if(childParamIdx_.find(i) == childParamIdx_.end()) { + params_[i].setCalcGrad(true); + } } - } - for (auto& module : modules_) { - module->train(); - } + for(auto& module : modules_) { + module->train(); + } } void Container::eval() { - train_ = false; + train_ = false; - for (int i = 0; i < params_.size(); ++i) { - if (childParamIdx_.find(i) == childParamIdx_.end()) { - params_[i].setCalcGrad(false); + for(int i = 0; i < params_.size(); ++i) { + if(childParamIdx_.find(i) == childParamIdx_.end()) { + params_[i].setCalcGrad(false); + } } - } - for (auto& module : modules_) { - module->eval(); - } + for(auto& module : modules_) { + module->eval(); + } } void Container::setParams(const Variable& var, int position) { - Module::setParams(var, position); - auto indices = childParamIdx_.find(position); - if (indices != childParamIdx_.end()) { - int midx, pidx; - std::tie(midx, pidx) = indices->second; - modules_[midx]->setParams(var, pidx); - } + Module::setParams(var, position); + auto indices = childParamIdx_.find(position); + if(indices != childParamIdx_.end()) { + int midx, pidx; + std::tie(midx, pidx) = indices->second; + modules_[midx]->setParams(var, pidx); + } } std::string Container::prettyString() const { - std::ostringstream ss; - ss << " [input"; - for (int i = 0; i < modules_.size(); ++i) { - ss << " -> (" << i << ")"; - } - ss << " -> output]"; - for (int i = 0; i < modules_.size(); ++i) { - ss << "\n\t(" << i << "): " << modules_[i]->prettyString(); - } - return ss.str(); + std::ostringstream ss; + ss << " [input"; + for(int i = 0; i < modules_.size(); ++i) { + ss << " -> (" << i << ")"; + } + ss << " -> output]"; + for(int i = 0; i < modules_.size(); ++i) { + ss << "\n\t(" << i << "): " << 
modules_[i]->prettyString(); + } + return ss.str(); } Sequential::Sequential() = default; std::vector Sequential::forward(const std::vector& input) { - auto output = input; - for (auto& module : modules_) { - output = module->forward(output); - } - return output; + auto output = input; + for(auto& module : modules_) { + output = module->forward(output); + } + return output; } Variable Sequential::forward(const Variable& input) { - std::vector output = {input}; - for (auto& module : modules_) { - output = module->forward(output); - } - if (output.size() != 1) { - throw std::invalid_argument("Module output size is not 1"); - } - return output.front(); + std::vector output = {input}; + for(auto& module : modules_) { + output = module->forward(output); + } + if(output.size() != 1) { + throw std::invalid_argument("Module output size is not 1"); + } + return output.front(); } Variable Sequential::operator()(const Variable& input) { - return this->forward(input); + return this->forward(input); } std::string Sequential::prettyString() const { - std::ostringstream ss; - ss << "Sequential"; - ss << " [input"; - for (int i = 0; i < modules_.size(); ++i) { - ss << " -> (" << i << ")"; - } - ss << " -> output]"; - for (int i = 0; i < modules_.size(); ++i) { - ss << "\n\t(" << i << "): " << modules_[i]->prettyString(); - } - return ss.str(); + std::ostringstream ss; + ss << "Sequential"; + ss << " [input"; + for(int i = 0; i < modules_.size(); ++i) { + ss << " -> (" << i << ")"; + } + ss << " -> output]"; + for(int i = 0; i < modules_.size(); ++i) { + ss << "\n\t(" << i << "): " << modules_[i]->prettyString(); + } + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/nn/modules/Container.h b/flashlight/fl/nn/modules/Container.h index 9111079..0d326ba 100644 --- a/flashlight/fl/nn/modules/Container.h +++ b/flashlight/fl/nn/modules/Container.h @@ -26,146 +26,147 @@ typedef std::shared_ptr ModulePtr; * collection of multiple `Module` and their respective parameters. 
*/ class FL_API Container : public Module { - private: - // Keep track of location of submodule parameters in a map - // from param index -> {module index, module param index} - std::unordered_map> childParamIdx_; +private: + // Keep track of location of submodule parameters in a map + // from param index -> {module index, module param index} + std::unordered_map> childParamIdx_; - FL_SAVE_LOAD_WITH_BASE(Module, modules_, childParamIdx_) + FL_SAVE_LOAD_WITH_BASE(Module, modules_, childParamIdx_) - protected: - /** - * A collection of modules contained within a `Container`. - */ - std::vector modules_; +protected: + /** + * A collection of modules contained within a `Container`. + */ + std::vector modules_; - Container(); + Container(); - /** - * Removes all modules and parameters from the container. - */ - virtual void clear(); + /** + * Removes all modules and parameters from the container. + */ + virtual void clear(); - /** - * Find orphaned params (i.e. params not in modules contained in the modules_ - * list). This can be used to preserve the order of orphaned params when - * copying/cloning a container. std::unordered_multimap - * The module_idx is used to identify after which module params should be - * serted and the param_idx is used to index the specific param. The following - * example demonstrates its usage by ensuring params and modules are inserted - * in the same order when making a copy: - * \code - void copy(const MyContainer& other) { - auto orphanParamIdxMap = other.getOrphanedParamsIdxMap(); - for (int i = -1; i < static_cast(other.modules_.size()); ++i) { - if (i >= 0) { - add(other.modules_[i]->clone()); - } - auto [paramIter, pEnd] = orphanParamIdxMap.equal_range(i); - for (; paramIter != pEnd; ++paramIter) { - const auto& param = other.params_[paramIter->second]; - params_.emplace_back(param.copy()); + /** + * Find orphaned params (i.e. params not in modules contained in the modules_ + * list). 
This can be used to preserve the order of orphaned params when + * copying/cloning a container. std::unordered_multimap + * The module_idx is used to identify after which module params should be + * serted and the param_idx is used to index the specific param. The following + * example demonstrates its usage by ensuring params and modules are inserted + * in the same order when making a copy: + * \code + void copy(const MyContainer& other) { + auto orphanParamIdxMap = other.getOrphanedParamsIdxMap(); + for (int i = -1; i < static_cast(other.modules_.size()); ++i) { + if (i >= 0) { + add(other.modules_[i]->clone()); + } + auto [paramIter, pEnd] = orphanParamIdxMap.equal_range(i); + for (; paramIter != pEnd; ++paramIter) { + const auto& param = other.params_[paramIter->second]; + params_.emplace_back(param.copy()); + } } } - } - \endcode - * - * A module_idx of -1 indicates the orphaned params are to be inserted - * before the first module - * - * @return A multimap of orphaned params and the module index they appear - * after - */ - std::unordered_multimap getOrphanedParamsIdxMap() const; + \endcode + * + * A module_idx of -1 indicates the orphaned params are to be inserted + * before the first module + * + * @return A multimap of orphaned params and the module index they appear + * after + */ + std::unordered_multimap getOrphanedParamsIdxMap() const; - public: - /** - * Adds a module to a `Container` by making a copy of the underlying module if - * an lvalue or moving it if and rvalue - * - * @param[in] module the module to add. - */ - template - void add(T&& module) { - static_assert( - !std::is_lvalue_reference_v, - "add() can only accept rvalues. Use std::move()."); - add(std::make_shared>(std::forward(module))); - } - - /** - * Adds a module to a `Container` by moving it and taking ownership. - * - * @param module the module to add. 
- */ - template - void add(std::unique_ptr module) { - add(std::shared_ptr(std::move(module))); - } +public: + /** + * Adds a module to a `Container` by making a copy of the underlying module if + * an lvalue or moving it if and rvalue + * + * @param[in] module the module to add. + */ + template + void add(T&& module) { + static_assert( + !std::is_lvalue_reference_v, + "add() can only accept rvalues. Use std::move()." + ); + add(std::make_shared>(std::forward(module))); + } - /** - * Adds a module to `modules_`, and adds parameters to the container's - * `params_`. - * - * @param module the module to add. - */ - template - void add(std::shared_ptr module) { - if (!module) { - throw std::invalid_argument("can't add null Module to Container"); + /** + * Adds a module to a `Container` by moving it and taking ownership. + * + * @param module the module to add. + */ + template + void add(std::unique_ptr module) { + add(std::shared_ptr(std::move(module))); } - for (int i = 0; i < module->numParamTensors(); i++) { - childParamIdx_[params_.size()] = std::make_tuple(static_cast(modules_.size()), i); - params_.push_back(module->param(i)); + + /** + * Adds a module to `modules_`, and adds parameters to the container's + * `params_`. + * + * @param module the module to add. + */ + template + void add(std::shared_ptr module) { + if(!module) { + throw std::invalid_argument("can't add null Module to Container"); + } + for(int i = 0; i < module->numParamTensors(); i++) { + childParamIdx_[params_.size()] = std::make_tuple(static_cast(modules_.size()), i); + params_.push_back(module->param(i)); + } + modules_.emplace_back(std::move(module)); } - modules_.emplace_back(std::move(module)); - } - /** - * Returns a pointer to the module at the specified index in the container's - * `modules_`. 
- * - * @param id the index of the module to return - * @return a pointer to the requested module - */ - ModulePtr module(int id) const; + /** + * Returns a pointer to the module at the specified index in the container's + * `modules_`. + * + * @param id the index of the module to return + * @return a pointer to the requested module + */ + ModulePtr module(int id) const; - /** - * Returns pointers to each of `Module` in the `Container`. - * - * @return an ordered vector of pointers for each module. - */ - std::vector modules() const; + /** + * Returns pointers to each of `Module` in the `Container`. + * + * @return an ordered vector of pointers for each module. + */ + std::vector modules() const; - /** - * Switches all modules in the `Container` into train mode. See `Module`. - */ - void train() override; + /** + * Switches all modules in the `Container` into train mode. See `Module`. + */ + void train() override; - /** - * Switches all modules in the `Container` into eval mode. See `Module`. - */ - void eval() override; + /** + * Switches all modules in the `Container` into eval mode. See `Module`. + */ + void eval() override; - /** - * Sets a parameter at a specified position with a new, given one. - * - * If the specified position is not valid (it is negative or greater than - * ``params_.size() - 1``), then an error will be thrown. A new parameter - * will not be created at a specified index if out of bounds. - * - * @param var the new replacement `Variable` - * @param position The index of the parameter which will be replaced in - * `params_` - */ - void setParams(const Variable& var, int position) override; + /** + * Sets a parameter at a specified position with a new, given one. + * + * If the specified position is not valid (it is negative or greater than + * ``params_.size() - 1``), then an error will be thrown. A new parameter + * will not be created at a specified index if out of bounds. 
+ * + * @param var the new replacement `Variable` + * @param position The index of the parameter which will be replaced in + * `params_` + */ + void setParams(const Variable& var, int position) override; - /** - * Generates a stringified representation of the module. - * - * @return a string containing the module label - */ - virtual std::string prettyString() const override; + /** + * Generates a stringified representation of the module. + * + * @return a string containing the module label + */ + virtual std::string prettyString() const override; }; /** @@ -188,26 +189,26 @@ class FL_API Container : public Module { }; \endcode */ -#define FL_BASIC_CONTAINER_CLONING(ContainerClass) \ - ContainerClass(const ContainerClass& other) { \ - train_ = other.train_; \ - for (auto& mod : other.modules_) { \ - add(mod->clone()); \ - } \ - } \ - ContainerClass& operator=(const ContainerClass& other) { \ - train_ = other.train_; \ - clear(); \ - for (auto& mod : other.modules_) { \ - add(mod->clone()); \ - } \ - return *this; \ - } \ - ContainerClass(ContainerClass&& other) = default; \ - ContainerClass& operator=(ContainerClass&& other) = default; \ - std::unique_ptr clone() const override { \ - return std::make_unique(*this); \ - } +#define FL_BASIC_CONTAINER_CLONING(ContainerClass) \ + ContainerClass(const ContainerClass& other) { \ + train_ = other.train_; \ + for(auto& mod : other.modules_) { \ + add(mod->clone()); \ + } \ + } \ + ContainerClass& operator=(const ContainerClass& other) { \ + train_ = other.train_; \ + clear(); \ + for(auto& mod : other.modules_) { \ + add(mod->clone()); \ + } \ + return *this; \ + } \ + ContainerClass(ContainerClass && other) = default; \ + ContainerClass& operator=(ContainerClass && other) = default; \ + std::unique_ptr clone() const override { \ + return std::make_unique(*this); \ + } /** * A `Container` representing an ordered sequence of modules, which is capable @@ -231,37 +232,37 @@ class FL_API Container : public Module { \endcode */ 
class FL_API Sequential : public Container { - public: - Sequential(); +public: + Sequential(); - /** - * Performs forward computation for the `Sequential`, calling `forward`, in - * order, for each `Module`, and feeding the result as input to the next - * `Module`. - * - * @param input the value on which the `Container` will perform forward - * computation. - * @return a `Variable` tensor containing the result of the forward - * computation - */ - std::vector forward(const std::vector& input) override; + /** + * Performs forward computation for the `Sequential`, calling `forward`, in + * order, for each `Module`, and feeding the result as input to the next + * `Module`. + * + * @param input the value on which the `Container` will perform forward + * computation. + * @return a `Variable` tensor containing the result of the forward + * computation + */ + std::vector forward(const std::vector& input) override; - Variable forward(const Variable& input); + Variable forward(const Variable& input); - Variable operator()(const Variable& input); + Variable operator()(const Variable& input); - /** - * Generates a stringified representation of the `Sequential` by concatenating - * string representations for each contained `Module` - * - * @return a string containing the module label - */ - std::string prettyString() const override; + /** + * Generates a stringified representation of the `Sequential` by concatenating + * string representations for each contained `Module` + * + * @return a string containing the module label + */ + std::string prettyString() const override; - FL_BASIC_CONTAINER_CLONING(Sequential) + FL_BASIC_CONTAINER_CLONING(Sequential) - private: - FL_SAVE_LOAD_WITH_BASE(Container) +private: + FL_SAVE_LOAD_WITH_BASE(Container) }; } // namespace fl diff --git a/flashlight/fl/nn/modules/Conv2D.cpp b/flashlight/fl/nn/modules/Conv2D.cpp index 57a97d4..ab5f302 100644 --- a/flashlight/fl/nn/modules/Conv2D.cpp +++ b/flashlight/fl/nn/modules/Conv2D.cpp @@ -32,20 
+32,20 @@ Conv2D::Conv2D( int dx, int dy, bool bias, - int groups) - : nIn_(nin), - nOut_(nout), - xFilter_(wx), - yFilter_(wy), - xStride_(sx), - yStride_(sy), - xPad_(px.padVal), - yPad_(py.padVal), - xDilation_(dx), - yDilation_(dy), - bias_(bias), - groups_(groups) { - initialize(); + int groups +) : nIn_(nin), + nOut_(nout), + xFilter_(wx), + yFilter_(wy), + xStride_(sx), + yStride_(sy), + xPad_(px.padVal), + yPad_(py.padVal), + xDilation_(dx), + yDilation_(dy), + bias_(bias), + groups_(groups) { + initialize(); } Conv2D::Conv2D( @@ -56,20 +56,20 @@ Conv2D::Conv2D( IntOrPadMode py, int dx, int dy, - int groups) - : UnaryModule({w}), - nIn_(w.dim(2)), - nOut_(w.dim(3)), - xFilter_(w.dim(0)), - yFilter_(w.dim(1)), - xStride_(sx), - yStride_(sy), - xPad_(px.padVal), - yPad_(py.padVal), - xDilation_(dx), - yDilation_(dy), - bias_(false), - groups_(groups) {} + int groups +) : UnaryModule({w}), + nIn_(w.dim(2)), + nOut_(w.dim(3)), + xFilter_(w.dim(0)), + yFilter_(w.dim(1)), + xStride_(sx), + yStride_(sy), + xPad_(px.padVal), + yPad_(py.padVal), + xDilation_(dx), + yDilation_(dy), + bias_(false), + groups_(groups) {} Conv2D::Conv2D( const Variable& w, @@ -80,148 +80,152 @@ Conv2D::Conv2D( IntOrPadMode py, int dx, int dy, - int groups) - : UnaryModule({w, b}), - nIn_(w.dim(2)), - nOut_(w.dim(3)), - xFilter_(w.dim(0)), - yFilter_(w.dim(1)), - xStride_(sx), - yStride_(sy), - xPad_(px.padVal), - yPad_(py.padVal), - xDilation_(dx), - yDilation_(dy), - bias_(true), - groups_(groups) { - if (b.dim(2) != w.dim(3)) { - throw std::invalid_argument( - "output channel dimension mismatch between Conv2D weight and bias"); - } - if (b.elements() != b.dim(2)) { - throw std::invalid_argument( - "only 3rd dimension of Conv2D bias may be non-singleton"); - } + int groups +) : UnaryModule({w, b}), + nIn_(w.dim(2)), + nOut_(w.dim(3)), + xFilter_(w.dim(0)), + yFilter_(w.dim(1)), + xStride_(sx), + yStride_(sy), + xPad_(px.padVal), + yPad_(py.padVal), + xDilation_(dx), + yDilation_(dy), + 
bias_(true), + groups_(groups) { + if(b.dim(2) != w.dim(3)) { + throw std::invalid_argument( + "output channel dimension mismatch between Conv2D weight and bias" + ); + } + if(b.elements() != b.dim(2)) { + throw std::invalid_argument( + "only 3rd dimension of Conv2D bias may be non-singleton" + ); + } } -Conv2D::Conv2D(const Conv2D& other) - : UnaryModule(other.copyParams()), - nIn_(other.nIn_), - nOut_(other.nOut_), - xFilter_(other.xFilter_), - yFilter_(other.yFilter_), - xStride_(other.xStride_), - yStride_(other.yStride_), - xPad_(other.xPad_), - yPad_(other.yPad_), - xDilation_(other.xDilation_), - yDilation_(other.yDilation_), - bias_(other.bias_), - groups_(other.groups_) { - train_ = other.train_; +Conv2D::Conv2D(const Conv2D& other) : UnaryModule(other.copyParams()), + nIn_(other.nIn_), + nOut_(other.nOut_), + xFilter_(other.xFilter_), + yFilter_(other.yFilter_), + xStride_(other.xStride_), + yStride_(other.yStride_), + xPad_(other.xPad_), + yPad_(other.yPad_), + xDilation_(other.xDilation_), + yDilation_(other.yDilation_), + bias_(other.bias_), + groups_(other.groups_) { + train_ = other.train_; } Conv2D& Conv2D::operator=(const Conv2D& other) { - params_ = other.copyParams(); - train_ = other.train_; - nIn_ = other.nIn_; - nOut_ = other.nOut_; - xFilter_ = other.xFilter_; - yFilter_ = other.yFilter_; - xStride_ = other.xStride_; - yStride_ = other.yStride_; - xPad_ = other.xPad_; - yPad_ = other.yPad_; - xDilation_ = other.xDilation_; - yDilation_ = other.yDilation_; - bias_ = other.bias_; - groups_ = other.groups_; - return *this; + params_ = other.copyParams(); + train_ = other.train_; + nIn_ = other.nIn_; + nOut_ = other.nOut_; + xFilter_ = other.xFilter_; + yFilter_ = other.yFilter_; + xStride_ = other.xStride_; + yStride_ = other.yStride_; + xPad_ = other.xPad_; + yPad_ = other.yPad_; + xDilation_ = other.xDilation_; + yDilation_ = other.yDilation_; + bias_ = other.bias_; + groups_ = other.groups_; + return *this; } Variable Conv2D::forward(const 
Variable& input) { - auto px = derivePadding(input.dim(0), xFilter_, xStride_, xPad_, xDilation_); - auto py = derivePadding(input.dim(1), yFilter_, yStride_, yPad_, yDilation_); - if (!(px >= 0 && py >= 0)) { - throw std::invalid_argument("invalid padding for Conv2D"); - } - - if (bias_) { - return conv2d( - input, - params_[0].astype(input.type()), - params_[1].astype(input.type()), - xStride_, - yStride_, - px, - py, - xDilation_, - yDilation_, - groups_, - benchmarks_); - } else { - return conv2d( - input, - params_[0].astype(input.type()), - xStride_, - yStride_, - px, - py, - xDilation_, - yDilation_, - groups_, - benchmarks_); - } + auto px = derivePadding(input.dim(0), xFilter_, xStride_, xPad_, xDilation_); + auto py = derivePadding(input.dim(1), yFilter_, yStride_, yPad_, yDilation_); + if(!(px >= 0 && py >= 0)) { + throw std::invalid_argument("invalid padding for Conv2D"); + } + + if(bias_) { + return conv2d( + input, + params_[0].astype(input.type()), + params_[1].astype(input.type()), + xStride_, + yStride_, + px, + py, + xDilation_, + yDilation_, + groups_, + benchmarks_ + ); + } else { + return conv2d( + input, + params_[0].astype(input.type()), + xStride_, + yStride_, + px, + py, + xDilation_, + yDilation_, + groups_, + benchmarks_ + ); + } } void Conv2D::initialize() { - int fanIn = xFilter_ * yFilter_ * nIn_ / groups_; - auto wt = kaimingUniform( - Shape({xFilter_, yFilter_, nIn_ / groups_, nOut_}), - fanIn, - fl::dtype::f32, - true); - if (bias_) { - double bound = std::sqrt(1.0 / fanIn); - auto bs = - uniform(Shape({1, 1, nOut_, 1}), -bound, bound, fl::dtype::f32, true); - params_ = {wt, bs}; - } else { - params_ = {wt}; - } - - benchmarks_ = std::make_shared(); + int fanIn = xFilter_ * yFilter_ * nIn_ / groups_; + auto wt = kaimingUniform( + Shape({xFilter_, yFilter_, nIn_ / groups_, nOut_}), + fanIn, + fl::dtype::f32, + true + ); + if(bias_) { + double bound = std::sqrt(1.0 / fanIn); + auto bs = + uniform(Shape({1, 1, nOut_, 1}), -bound, 
bound, fl::dtype::f32, true); + params_ = {wt, bs}; + } else { + params_ = {wt}; + } + + benchmarks_ = std::make_shared(); } std::unique_ptr Conv2D::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string Conv2D::prettyString() const { - std::ostringstream ss; - ss << "Conv2D"; - ss << " (" << nIn_ << "->" << nOut_ << ", " << xFilter_ << "x" << yFilter_ - << ", " << xStride_ << "," << yStride_ << ", "; - if (xPad_ == static_cast(PaddingMode::SAME)) { - ss << "SAME"; - } else { - ss << xPad_; - } - ss << ","; - if (yPad_ == static_cast(PaddingMode::SAME)) { - ss << "SAME"; - } else { - ss << yPad_; - } - ss << ", " << xDilation_ << ", " << yDilation_; - ss << ")"; - - if (bias_) { - ss << " (with bias)"; - } else { - ss << " (without bias)"; - } - return ss.str(); + std::ostringstream ss; + ss << "Conv2D"; + ss << " (" << nIn_ << "->" << nOut_ << ", " << xFilter_ << "x" << yFilter_ + << ", " << xStride_ << "," << yStride_ << ", "; + if(xPad_ == static_cast(PaddingMode::SAME)) { + ss << "SAME"; + } else { + ss << xPad_; + } + ss << ","; + if(yPad_ == static_cast(PaddingMode::SAME)) { + ss << "SAME"; + } else { + ss << yPad_; + } + ss << ", " << xDilation_ << ", " << yDilation_; + ss << ")"; + + if(bias_) { + ss << " (with bias)"; + } else { + ss << " (without bias)"; + } + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/nn/modules/Conv2D.h b/flashlight/fl/nn/modules/Conv2D.h index cf08251..13af593 100644 --- a/flashlight/fl/nn/modules/Conv2D.h +++ b/flashlight/fl/nn/modules/Conv2D.h @@ -14,7 +14,7 @@ namespace fl { namespace detail { -struct ConvBenchmarks; + struct ConvBenchmarks; } /** @@ -38,172 +38,176 @@ struct ConvBenchmarks; * Y_{out} = \lceil{\frac{Y_{in}}{Y_{stride}}}\rceil\f] */ class FL_API Conv2D : public UnaryModule { - private: - FL_SAVE_LOAD_WITH_BASE( - UnaryModule, - nIn_, - nOut_, - xFilter_, - yFilter_, - xStride_, - yStride_, - xPad_, - yPad_, - fl::versioned(xDilation_, 1), - 
fl::versioned(yDilation_, 1), - bias_, - groups_) - - void initialize(); - - protected: - Conv2D() = default; - int nIn_, nOut_; // in/op channels - int xFilter_, yFilter_; // filter dims - int xStride_, yStride_; // stride - int xPad_, yPad_; // padding - int xDilation_{1}, yDilation_{1}; // dilation - bool bias_; - int groups_; - - public: - /** - * Constructs a Conv2D module - * - * @param n_in \f$C_{in}\f$, the number of channels in the input - * @param n_out \f$C_{out}\f$, the number of channels in the output - * @param wx the size of the first dimension of the convolving kernel - * @param wy the size of the second dimension of the convolving kernel - * @param sx the stride of the convolution along the first dimension - * @param sy the stride of the convolution along the second dimension - * @param px the amount of zero-padding added to the both sides of the first - * dimension of the input. Accepts a non-negative integer value or an enum - * fl::PaddingMode - * @param py the amount of zero-padding added to the both sides of the second - * dimension of the input. Accepts a non-negative integer value or an enum - * fl::PaddingMode - * @param dx dilation of the convolution along the first kernel dimension. A - * dilation of 1 is equivalent to a standard convolution along this axis. - * @param dy dilation of the convolution along the second kernel dimension. A - * dilation of 1 is equivalent to a standard convolution along this axis. - * @param bias a boolean value that controls whether to add a learnable bias - * to the output - * @param groups the number of groups that the input and output channels - * are divided into for restricting the connectivity between input and output - * channels. 
If `groups` > 1, the the output channels in the i-th group will - * be only connected to the input channels in the i-th group - */ - Conv2D( - int n_in, - int n_out, - int wx, - int wy, - int sx = 1, - int sy = 1, - detail::IntOrPadMode px = 0, - detail::IntOrPadMode py = 0, - int dx = 1, - int dy = 1, - bool bias = true, - int groups = 1); - - /** - * Constructs a Conv2D module with a kernel `Variable` tensor. No bias term - * will be applied to the output. - * - * @param w the kernel `Variable` tensor. The shape should be - * [\f$kerneldim_0\f$, \f$kerneldim_1\f$, \f$C_{in}\f$, \f$C_{out}\f$]. - * @param sx the stride of the convolution along the first dimension - * @param sy the stride of the convolution along the second dimension - * @param px the amount of zero-padding added to the both sides of the first - * dimension of the input. Accepts a non-negative integer value or an enum - * fl::PaddingMode - * @param py the amount of zero-padding added to the both sides of the second - * dimension of the input. Accepts a non-negative integer value or an enum - * fl::PaddingMode - * @param dx dilation of the convolution along the first kernel dimension. A - * dilation of 1 is equivalent to a standard convolution along this axis. - * @param dy dilation of the convolution along the second kernel dimension. A - * dilation of 1 is equivalent to a standard convolution along this axis. - * @param groups the number of groups that the input and output channels - * are divided into for restricting the connectivity between input and output - * channels. If `groups` > 1, the the output channels in the i-th group will - * be only connected to the input channels in the i-th group. - */ - explicit Conv2D( - const Variable& w, - int sx = 1, - int sy = 1, - detail::IntOrPadMode px = 0, - detail::IntOrPadMode py = 0, - int dx = 1, - int dy = 1, - int groups = 1); - - /** - * Constructs a Conv2D module with a kernel `Variable` tensor and a bias - * `Variable` tensor. 
- * - * @param w the kernel `Variable` tensor. The shape should be - * [\f$kerneldim_0\f$, \f$kerneldim_1\f$, \f$C_{in}\f$, \f$C_{out}\f$]. - * @param b the bias `Variable` tensor. The shape should be - * [\f$1\f$, \f$1\f$, \f$C_{out}\f$, \f$1\f$]. - * @param sx the stride of the convolution along the first dimension - * @param sy the stride of the convolution along the second dimension - * @param px the amount of zero-padding added to the both sides of the first - * dimension of the input. Accepts a non-negative integer value or an enum - * fl::PaddingMode - * @param py the amount of zero-padding added to the both sides of the second - * dimension of the input. Accepts a non-negative integer value or an enum - * fl::PaddingMode - * @param dx dilation of the convolution along the first kernel dimension. A - * dilation of 1 is equivalent to a standard convolution along this axis. - * @param dy dilation of the convolution along the second kernel dimension. A - * dilation of 1 is equivalent to a standard convolution along this axis. - * @param groups the number of groups that the input and output channels - * are divided into for restricting the connectivity between input and output - * channels. If `groups` > 1, the the output channels in the i-th group will - * be only connected to the input channels in the i-th group. - */ - Conv2D( - const Variable& w, - const Variable& b, - int sx = 1, - int sy = 1, - detail::IntOrPadMode px = 0, - detail::IntOrPadMode py = 0, - int dx = 1, - int dy = 1, - int groups = 1); - - /** - * Constructs an Conv2D module from another, performing a copy of the - * parameters. - * - * @param other The Conv2D module to copy from. - */ - Conv2D(const Conv2D& other); - - /** - * Constructs an Conv2D module from another, performing a copy of the - * parameters. - * - * @param other The Conv2D module to copy from. 
- */ - Conv2D& operator=(const Conv2D& other); - - Conv2D(Conv2D&& other) = default; - - Conv2D& operator=(Conv2D&& other) = default; - - Variable forward(const Variable& input) override; - - std::unique_ptr clone() const override; - - std::string prettyString() const override; - - protected: - std::shared_ptr benchmarks_; +private: + FL_SAVE_LOAD_WITH_BASE( + UnaryModule, + nIn_, + nOut_, + xFilter_, + yFilter_, + xStride_, + yStride_, + xPad_, + yPad_, + fl::versioned(xDilation_, 1), + fl::versioned(yDilation_, 1), + bias_, + groups_ + ) + + void initialize(); + +protected: + Conv2D() = default; + int nIn_, nOut_; // in/op channels + int xFilter_, yFilter_; // filter dims + int xStride_, yStride_; // stride + int xPad_, yPad_; // padding + int xDilation_{1}, yDilation_{1}; // dilation + bool bias_; + int groups_; + +public: + /** + * Constructs a Conv2D module + * + * @param n_in \f$C_{in}\f$, the number of channels in the input + * @param n_out \f$C_{out}\f$, the number of channels in the output + * @param wx the size of the first dimension of the convolving kernel + * @param wy the size of the second dimension of the convolving kernel + * @param sx the stride of the convolution along the first dimension + * @param sy the stride of the convolution along the second dimension + * @param px the amount of zero-padding added to the both sides of the first + * dimension of the input. Accepts a non-negative integer value or an enum + * fl::PaddingMode + * @param py the amount of zero-padding added to the both sides of the second + * dimension of the input. Accepts a non-negative integer value or an enum + * fl::PaddingMode + * @param dx dilation of the convolution along the first kernel dimension. A + * dilation of 1 is equivalent to a standard convolution along this axis. + * @param dy dilation of the convolution along the second kernel dimension. A + * dilation of 1 is equivalent to a standard convolution along this axis. 
+ * @param bias a boolean value that controls whether to add a learnable bias + * to the output + * @param groups the number of groups that the input and output channels + * are divided into for restricting the connectivity between input and output + * channels. If `groups` > 1, the the output channels in the i-th group will + * be only connected to the input channels in the i-th group + */ + Conv2D( + int n_in, + int n_out, + int wx, + int wy, + int sx = 1, + int sy = 1, + detail::IntOrPadMode px = 0, + detail::IntOrPadMode py = 0, + int dx = 1, + int dy = 1, + bool bias = true, + int groups = 1 + ); + + /** + * Constructs a Conv2D module with a kernel `Variable` tensor. No bias term + * will be applied to the output. + * + * @param w the kernel `Variable` tensor. The shape should be + * [\f$kerneldim_0\f$, \f$kerneldim_1\f$, \f$C_{in}\f$, \f$C_{out}\f$]. + * @param sx the stride of the convolution along the first dimension + * @param sy the stride of the convolution along the second dimension + * @param px the amount of zero-padding added to the both sides of the first + * dimension of the input. Accepts a non-negative integer value or an enum + * fl::PaddingMode + * @param py the amount of zero-padding added to the both sides of the second + * dimension of the input. Accepts a non-negative integer value or an enum + * fl::PaddingMode + * @param dx dilation of the convolution along the first kernel dimension. A + * dilation of 1 is equivalent to a standard convolution along this axis. + * @param dy dilation of the convolution along the second kernel dimension. A + * dilation of 1 is equivalent to a standard convolution along this axis. + * @param groups the number of groups that the input and output channels + * are divided into for restricting the connectivity between input and output + * channels. If `groups` > 1, the the output channels in the i-th group will + * be only connected to the input channels in the i-th group. 
+ */ + explicit Conv2D( + const Variable& w, + int sx = 1, + int sy = 1, + detail::IntOrPadMode px = 0, + detail::IntOrPadMode py = 0, + int dx = 1, + int dy = 1, + int groups = 1 + ); + + /** + * Constructs a Conv2D module with a kernel `Variable` tensor and a bias + * `Variable` tensor. + * + * @param w the kernel `Variable` tensor. The shape should be + * [\f$kerneldim_0\f$, \f$kerneldim_1\f$, \f$C_{in}\f$, \f$C_{out}\f$]. + * @param b the bias `Variable` tensor. The shape should be + * [\f$1\f$, \f$1\f$, \f$C_{out}\f$, \f$1\f$]. + * @param sx the stride of the convolution along the first dimension + * @param sy the stride of the convolution along the second dimension + * @param px the amount of zero-padding added to the both sides of the first + * dimension of the input. Accepts a non-negative integer value or an enum + * fl::PaddingMode + * @param py the amount of zero-padding added to the both sides of the second + * dimension of the input. Accepts a non-negative integer value or an enum + * fl::PaddingMode + * @param dx dilation of the convolution along the first kernel dimension. A + * dilation of 1 is equivalent to a standard convolution along this axis. + * @param dy dilation of the convolution along the second kernel dimension. A + * dilation of 1 is equivalent to a standard convolution along this axis. + * @param groups the number of groups that the input and output channels + * are divided into for restricting the connectivity between input and output + * channels. If `groups` > 1, the the output channels in the i-th group will + * be only connected to the input channels in the i-th group. + */ + Conv2D( + const Variable& w, + const Variable& b, + int sx = 1, + int sy = 1, + detail::IntOrPadMode px = 0, + detail::IntOrPadMode py = 0, + int dx = 1, + int dy = 1, + int groups = 1 + ); + + /** + * Constructs an Conv2D module from another, performing a copy of the + * parameters. + * + * @param other The Conv2D module to copy from. 
+ */ + Conv2D(const Conv2D& other); + + /** + * Constructs an Conv2D module from another, performing a copy of the + * parameters. + * + * @param other The Conv2D module to copy from. + */ + Conv2D& operator=(const Conv2D& other); + + Conv2D(Conv2D&& other) = default; + + Conv2D& operator=(Conv2D&& other) = default; + + Variable forward(const Variable& input) override; + + std::unique_ptr clone() const override; + + std::string prettyString() const override; + +protected: + std::shared_ptr benchmarks_; }; } // namespace fl diff --git a/flashlight/fl/nn/modules/Dropout.cpp b/flashlight/fl/nn/modules/Dropout.cpp index 3d898b5..b7a3ae2 100644 --- a/flashlight/fl/nn/modules/Dropout.cpp +++ b/flashlight/fl/nn/modules/Dropout.cpp @@ -15,19 +15,19 @@ namespace fl { Dropout::Dropout(double drop_ratio) : ratio_(drop_ratio) {} Variable Dropout::forward(const Variable& input) { - if (train_) { - return dropout(input, ratio_); - } else { - return input; - } + if(train_) { + return dropout(input, ratio_); + } else { + return input; + } } std::unique_ptr Dropout::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string Dropout::prettyString() const { - return ("Dropout (" + std::to_string(ratio_) + ")"); + return "Dropout (" + std::to_string(ratio_) + ")"; } } // namespace fl diff --git a/flashlight/fl/nn/modules/Dropout.h b/flashlight/fl/nn/modules/Dropout.h index 4e413fa..1697aac 100644 --- a/flashlight/fl/nn/modules/Dropout.h +++ b/flashlight/fl/nn/modules/Dropout.h @@ -22,24 +22,24 @@ namespace fl { * evaluating the module gives the identity. */ class FL_API Dropout : public UnaryModule { - private: - double ratio_; +private: + double ratio_; - FL_SAVE_LOAD_WITH_BASE(UnaryModule, ratio_) + FL_SAVE_LOAD_WITH_BASE(UnaryModule, ratio_) - public: - /** - * Creates a `Dropout` layer. 
- * - * @param drop_ratio the probability that a weight will be set to zero - */ - Dropout(double drop_ratio = 0.5); +public: + /** + * Creates a `Dropout` layer. + * + * @param drop_ratio the probability that a weight will be set to zero + */ + Dropout(double drop_ratio = 0.5); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::string prettyString() const override; + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/nn/modules/Embedding.cpp b/flashlight/fl/nn/modules/Embedding.cpp index 82c7534..8e2ec84 100644 --- a/flashlight/fl/nn/modules/Embedding.cpp +++ b/flashlight/fl/nn/modules/Embedding.cpp @@ -14,49 +14,48 @@ namespace fl { -Embedding::Embedding(int embeddingDim, int numEmbeddings) - : embeddingDim_(embeddingDim), numEmbeddings_(numEmbeddings) { - initialize(); +Embedding::Embedding(int embeddingDim, int numEmbeddings) : embeddingDim_(embeddingDim), + numEmbeddings_(numEmbeddings) { + initialize(); } -Embedding::Embedding(const Variable& w) - : UnaryModule({w}), embeddingDim_(w.dim(0)), numEmbeddings_(w.dim(1)) {} +Embedding::Embedding(const Variable& w) : UnaryModule({w}), embeddingDim_(w.dim(0)), + numEmbeddings_(w.dim(1)) {} -Embedding::Embedding(const Embedding& other) - : UnaryModule(other.copyParams()), - embeddingDim_(other.embeddingDim_), - numEmbeddings_(other.numEmbeddings_) { - train_ = other.train_; +Embedding::Embedding(const Embedding& other) : UnaryModule(other.copyParams()), + embeddingDim_(other.embeddingDim_), + numEmbeddings_(other.numEmbeddings_) { + train_ = other.train_; } Embedding& Embedding::operator=(const Embedding& other) { - params_ = other.copyParams(); - train_ = other.train_; - embeddingDim_ = other.embeddingDim_; - numEmbeddings_ = other.numEmbeddings_; - return *this; + params_ = other.copyParams(); + train_ = other.train_; + embeddingDim_ 
= other.embeddingDim_; + numEmbeddings_ = other.numEmbeddings_; + return *this; } void Embedding::initialize() { - double stdv = std::sqrt(1.0 / static_cast(embeddingDim_)); - auto embeddings = - uniform(embeddingDim_, numEmbeddings_, -stdv, stdv, fl::dtype::f32, true); - params_ = {embeddings}; + double stdv = std::sqrt(1.0 / static_cast(embeddingDim_)); + auto embeddings = + uniform(embeddingDim_, numEmbeddings_, -stdv, stdv, fl::dtype::f32, true); + params_ = {embeddings}; } Variable Embedding::forward(const Variable& input) { - return embedding(input, params_[0]); + return embedding(input, params_[0]); } std::unique_ptr Embedding::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string Embedding::prettyString() const { - std::ostringstream ss; - ss << "Embedding (embeddings: " << numEmbeddings_ - << ") (dim: " << embeddingDim_ << ")"; - return ss.str(); + std::ostringstream ss; + ss << "Embedding (embeddings: " << numEmbeddings_ + << ") (dim: " << embeddingDim_ << ")"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/nn/modules/Embedding.h b/flashlight/fl/nn/modules/Embedding.h index 81fcbec..cd06a73 100644 --- a/flashlight/fl/nn/modules/Embedding.h +++ b/flashlight/fl/nn/modules/Embedding.h @@ -19,58 +19,58 @@ namespace fl { * [`embeddingDim`, \f$B_1\f$, \f$B_2\f$ (optional), \f$B_3\f$ (optional)]. */ class FL_API Embedding : public UnaryModule { - private: - Embedding() = default; // Intentionally private - - int embeddingDim_; - int numEmbeddings_; - - FL_SAVE_LOAD_WITH_BASE(UnaryModule, embeddingDim_, numEmbeddings_) - - void initialize(); - - public: - /** - * Constructs an Embedding module. - * - * @param embeddingDim the size of each embedding vector - * @param numEmbeddings the size of the dictionary of embeddings - */ - Embedding(int embeddingDim, int numEmbeddings); - - /** - * Constructs an Embedding module from the weight parameter \f$w\f$. 
- * - * @param w the 2D `Variable` tensor for the weight \f$w\f$. - * The shape should be [`embeddingDim`, `numEmbeddings`]. - */ - explicit Embedding(const Variable& w); - - /** - * Constructs an Embedding module from another, performing a copy of the - * parameters. - * - * @param other The Embedding module to copy from. - */ - Embedding(const Embedding& other); - - /** - * Constructs an Embedding module from another, performing a copy of the - * parameters. - * - * @param other The Embedding module to copy from. - */ - Embedding& operator=(const Embedding& other); - - Embedding(Embedding&& other) = default; - - Embedding& operator=(Embedding&& other) = default; - - Variable forward(const Variable& input) override; - - std::unique_ptr clone() const override; - - std::string prettyString() const override; +private: + Embedding() = default; // Intentionally private + + int embeddingDim_; + int numEmbeddings_; + + FL_SAVE_LOAD_WITH_BASE(UnaryModule, embeddingDim_, numEmbeddings_) + + void initialize(); + +public: + /** + * Constructs an Embedding module. + * + * @param embeddingDim the size of each embedding vector + * @param numEmbeddings the size of the dictionary of embeddings + */ + Embedding(int embeddingDim, int numEmbeddings); + + /** + * Constructs an Embedding module from the weight parameter \f$w\f$. + * + * @param w the 2D `Variable` tensor for the weight \f$w\f$. + * The shape should be [`embeddingDim`, `numEmbeddings`]. + */ + explicit Embedding(const Variable& w); + + /** + * Constructs an Embedding module from another, performing a copy of the + * parameters. + * + * @param other The Embedding module to copy from. + */ + Embedding(const Embedding& other); + + /** + * Constructs an Embedding module from another, performing a copy of the + * parameters. + * + * @param other The Embedding module to copy from. 
+ */ + Embedding& operator=(const Embedding& other); + + Embedding(Embedding&& other) = default; + + Embedding& operator=(Embedding&& other) = default; + + Variable forward(const Variable& input) override; + + std::unique_ptr clone() const override; + + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/nn/modules/Identity.cpp b/flashlight/fl/nn/modules/Identity.cpp index aeb2497..eaaad1e 100644 --- a/flashlight/fl/nn/modules/Identity.cpp +++ b/flashlight/fl/nn/modules/Identity.cpp @@ -10,15 +10,15 @@ namespace fl { std::vector Identity::forward(const std::vector& inputs) { - return inputs; + return inputs; }; std::unique_ptr Identity::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string Identity::prettyString() const { - return "Identity"; + return "Identity"; }; } // namespace fl diff --git a/flashlight/fl/nn/modules/Identity.h b/flashlight/fl/nn/modules/Identity.h index 8a122d4..4d08206 100644 --- a/flashlight/fl/nn/modules/Identity.h +++ b/flashlight/fl/nn/modules/Identity.h @@ -15,14 +15,14 @@ namespace fl { * Identity returns the inputs at forward. 
*/ class FL_API Identity : public Module { - public: - Identity() = default; - std::vector forward(const std::vector& inputs) override; - std::unique_ptr clone() const override; - std::string prettyString() const override; +public: + Identity() = default; + std::vector forward(const std::vector& inputs) override; + std::unique_ptr clone() const override; + std::string prettyString() const override; - private: - FL_SAVE_LOAD_WITH_BASE(Module) +private: + FL_SAVE_LOAD_WITH_BASE(Module) }; } // namespace fl diff --git a/flashlight/fl/nn/modules/LayerNorm.cpp b/flashlight/fl/nn/modules/LayerNorm.cpp index 23c5950..4b2e0c1 100644 --- a/flashlight/fl/nn/modules/LayerNorm.cpp +++ b/flashlight/fl/nn/modules/LayerNorm.cpp @@ -22,139 +22,148 @@ LayerNorm::LayerNorm( int axis, double eps /* = 1e-5*/, bool affine /* = true*/, - int axisSize /* = kLnVariableAxisSize */) - : LayerNorm(std::vector({axis}), eps, affine, axisSize) {} + int axisSize /* = kLnVariableAxisSize */ +) : LayerNorm(std::vector({axis}), eps, affine, axisSize) {} LayerNorm::LayerNorm( const std::vector& axis, double eps /* = 1e-5 */, bool affine /* = true */, - int axisSize /* = kLnVariableAxisSize */) - : epsilon_(eps), affine_(affine), axisSize_(axisSize) { - for (int d = 0; d < kLnExpectedNumDims; ++d) { - if (std::find(axis.begin(), axis.end(), d) == axis.end()) { - axisComplement_.push_back(d); + int axisSize /* = kLnVariableAxisSize */ +) : epsilon_(eps), + affine_(affine), + axisSize_(axisSize) { + for(int d = 0; d < kLnExpectedNumDims; ++d) { + if(std::find(axis.begin(), axis.end(), d) == axis.end()) { + axisComplement_.push_back(d); + } } - } - initialize(); + initialize(); } Variable LayerNorm::forward(const Variable& _input) { - Variable input = _input; - // If the input isn't of kLnExpectedNumDims, reshape so it is -- do this by - // adding singleton dims. 
This is needed per computing the axis complement - // TODO: this is pretty ugly -- eventually fix this up if it can be avoided - if (input.ndim() < kLnExpectedNumDims) { - std::vector s = _input.shape().get(); - for (unsigned i = s.size(); i < kLnExpectedNumDims; ++i) { - s.push_back(1); - } - input = moddims(_input, Shape(s)); - } else if (input.ndim() > kLnExpectedNumDims) { - throw std::invalid_argument( - "LayerNorm::forward - input must be " + - std::to_string(kLnExpectedNumDims) + " or fewer dimensions."); - } - - Variable dummyInMean, dummyInVar; - - Variable inputToBn = input; - std::vector inNormAxes; - // reorder is only required if axisComplement_ is not continuous - Shape reorderDims(std::vector(input.ndim())); - auto maxAxis = - *std::max_element(axisComplement_.begin(), axisComplement_.end()); - auto minAxis = - *std::min_element(axisComplement_.begin(), axisComplement_.end()); - bool axesContinuous = (axisComplement_.size() == (maxAxis - minAxis + 1)); - if (axesContinuous) { - inNormAxes = axisComplement_; - } else { - int i = 0; - for (int d = 0; d < input.ndim(); ++d) { - if (std::find(axisComplement_.begin(), axisComplement_.end(), d) == - axisComplement_.end()) { - reorderDims[i++] = d; - } - } - for (auto n : axisComplement_) { - inNormAxes.push_back(i); - reorderDims[i++] = n; + Variable input = _input; + // If the input isn't of kLnExpectedNumDims, reshape so it is -- do this by + // adding singleton dims. This is needed per computing the axis complement + // TODO: this is pretty ugly -- eventually fix this up if it can be avoided + if(input.ndim() < kLnExpectedNumDims) { + std::vector s = _input.shape().get(); + for(unsigned i = s.size(); i < kLnExpectedNumDims; ++i) { + s.push_back(1); + } + input = moddims(_input, Shape(s)); + } else if(input.ndim() > kLnExpectedNumDims) { + throw std::invalid_argument( + "LayerNorm::forward - input must be " + + std::to_string(kLnExpectedNumDims) + " or fewer dimensions." 
+ ); } - inputToBn = reorder(input, reorderDims); - } - auto paramsType = - (input.type() == fl::dtype::f16) ? fl::dtype::f32 : input.type(); - auto output = batchnorm( - inputToBn, - Variable(Tensor(paramsType), false), - Variable(Tensor(paramsType), false), - dummyInMean, - dummyInVar, - inNormAxes, - true, - 0.0, - epsilon_); - - if (!axesContinuous) { - std::vector> restoreDims; - for (size_t i = 0; i < reorderDims.ndim(); ++i) { - restoreDims.emplace_back(reorderDims[i], i); + + Variable dummyInMean, dummyInVar; + + Variable inputToBn = input; + std::vector inNormAxes; + // reorder is only required if axisComplement_ is not continuous + Shape reorderDims(std::vector(input.ndim())); + auto maxAxis = + *std::max_element(axisComplement_.begin(), axisComplement_.end()); + auto minAxis = + *std::min_element(axisComplement_.begin(), axisComplement_.end()); + bool axesContinuous = (axisComplement_.size() == (maxAxis - minAxis + 1)); + if(axesContinuous) { + inNormAxes = axisComplement_; + } else { + int i = 0; + for(int d = 0; d < input.ndim(); ++d) { + if( + std::find(axisComplement_.begin(), axisComplement_.end(), d) + == axisComplement_.end() + ) { + reorderDims[i++] = d; + } + } + for(auto n : axisComplement_) { + inNormAxes.push_back(i); + reorderDims[i++] = n; + } + inputToBn = reorder(input, reorderDims); } - std::sort(restoreDims.begin(), restoreDims.end()); - Shape restoreDimsShape(std::vector(restoreDims.size())); - for (size_t i = 0; i < restoreDims.size(); ++i) { - restoreDimsShape[i] = restoreDims[i].second; + auto paramsType = + (input.type() == fl::dtype::f16) ? 
fl::dtype::f32 : input.type(); + auto output = batchnorm( + inputToBn, + Variable(Tensor(paramsType), false), + Variable(Tensor(paramsType), false), + dummyInMean, + dummyInVar, + inNormAxes, + true, + 0.0, + epsilon_ + ); + + if(!axesContinuous) { + std::vector> restoreDims; + for(size_t i = 0; i < reorderDims.ndim(); ++i) { + restoreDims.emplace_back(reorderDims[i], i); + } + std::sort(restoreDims.begin(), restoreDims.end()); + Shape restoreDimsShape(std::vector(restoreDims.size())); + for(size_t i = 0; i < restoreDims.size(); ++i) { + restoreDimsShape[i] = restoreDims[i].second; + } + output = reorder(output, restoreDimsShape); } - output = reorder(output, restoreDimsShape); - } - - if (affine_) { - Variable weight = params_[0].astype(output.type()); - Variable bias = params_[1].astype(output.type()); - if (axisSize_ != kLnVariableAxisSize) { - Shape affineDims = input.shape(); - for (int ax : axisComplement_) { - affineDims[ax] = 1; - } - if (affineDims.elements() != axisSize_) { - throw std::invalid_argument( - "[LayerNorm] Input size along the norm axis doesn't with axisSize."); - } - weight = moddims(params_[0].astype(output.type()), affineDims); - bias = moddims(params_[1].astype(output.type()), affineDims); + + if(affine_) { + Variable weight = params_[0].astype(output.type()); + Variable bias = params_[1].astype(output.type()); + if(axisSize_ != kLnVariableAxisSize) { + Shape affineDims = input.shape(); + for(int ax : axisComplement_) { + affineDims[ax] = 1; + } + if(affineDims.elements() != axisSize_) { + throw std::invalid_argument( + "[LayerNorm] Input size along the norm axis doesn't with axisSize." 
+ ); + } + weight = moddims(params_[0].astype(output.type()), affineDims); + bias = moddims(params_[1].astype(output.type()), affineDims); + } + output = tileAs(weight, input) * output + tileAs(bias, input); } - output = tileAs(weight, input) * output + tileAs(bias, input); - } - return moddims(output, _input.shape()); + return moddims(output, _input.shape()); } void LayerNorm::initialize() { - if (affine_) { - auto paramDim = (axisSize_ == kLnVariableAxisSize) ? 1 : axisSize_; - auto wt = constant(1.0, {paramDim}, fl::dtype::f32, true); - auto bs = constant(0.0, {paramDim}, fl::dtype::f32, true); - params_ = {wt, bs}; - } + if(affine_) { + auto paramDim = (axisSize_ == kLnVariableAxisSize) ? 1 : axisSize_; + auto wt = constant(1.0, {paramDim}, fl::dtype::f32, true); + auto bs = constant(0.0, {paramDim}, fl::dtype::f32, true); + params_ = {wt, bs}; + } } std::unique_ptr LayerNorm::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string LayerNorm::prettyString() const { - std::ostringstream ss; - ss << "LayerNorm"; - ss << " ( axis : { "; - for (int d = 0; d < axisComplement_.size(); ++d) { - if (std::find(axisComplement_.begin(), axisComplement_.end(), d) == - axisComplement_.end()) { - ss << d << " "; + std::ostringstream ss; + ss << "LayerNorm"; + ss << " ( axis : { "; + for(int d = 0; d < axisComplement_.size(); ++d) { + if( + std::find(axisComplement_.begin(), axisComplement_.end(), d) + == axisComplement_.end() + ) { + ss << d << " "; + } } - } - ss << "} , size : " << axisSize_ << ")"; - return ss.str(); + ss << "} , size : " << axisSize_ << ")"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/nn/modules/LayerNorm.h b/flashlight/fl/nn/modules/LayerNorm.h index 300ec16..3c6bec0 100644 --- a/flashlight/fl/nn/modules/LayerNorm.h +++ b/flashlight/fl/nn/modules/LayerNorm.h @@ -26,73 +26,76 @@ constexpr const int kLnVariableAxisSize = -1; * \f$\beta\f$ are learnable parameters for affine transformation. 
*/ class FL_API LayerNorm : public UnaryModule { - public: - /** - * Constructs a LayerNorm module. - * - * @param axis the axis along which normalization is computed. Usually set as - * the feature axis. - * @param eps \f$\epsilon\f$ - * @param affine a boolean value that controls the learning of \f$\gamma\f$ - * and \f$\beta\f$. \f$\gamma\f$ and \f$\beta\f$ are set to 1, 0 respectively - * if set to `false`, or initialized as learnable parameters - * if set to `true`. - * @param axisSize total size of features specified by `axis` to perform - * elementwise affine transform. If the feat size is variable, use - * `kLnVariableAxisSize` which uses singleton weight, bias and tiles them - * dynamically according to the given input. - */ - explicit LayerNorm( - int axis, - double eps = 1e-5, - bool affine = true, - int axisSize = kLnVariableAxisSize); +public: + /** + * Constructs a LayerNorm module. + * + * @param axis the axis along which normalization is computed. Usually set as + * the feature axis. + * @param eps \f$\epsilon\f$ + * @param affine a boolean value that controls the learning of \f$\gamma\f$ + * and \f$\beta\f$. \f$\gamma\f$ and \f$\beta\f$ are set to 1, 0 respectively + * if set to `false`, or initialized as learnable parameters + * if set to `true`. + * @param axisSize total size of features specified by `axis` to perform + * elementwise affine transform. If the feat size is variable, use + * `kLnVariableAxisSize` which uses singleton weight, bias and tiles them + * dynamically according to the given input. + */ + explicit LayerNorm( + int axis, + double eps = 1e-5, + bool affine = true, + int axisSize = kLnVariableAxisSize + ); - /** - * Constructs a LayerNorm module. - * - * @param axis the axis along which normalization is computed. Usually set as - * the feature axis. - * @param eps \f$\epsilon\f$ - * @param affine a boolean value that controls the learning of \f$\gamma\f$ - * and \f$\beta\f$. 
\f$\gamma\f$ and \f$\beta\f$ are set to 1, 0 respectively - * if set to `false`, or initialized as learnable parameters - * if set to `true`. - * @param axisSize total size of features specified by `axis` to perform - * elementwise affine transform. If the feat size is variable, use - * `kLnVariableAxisSize` which uses singleton weight, bias and tiles them - * dynamically according to the given input. - */ - explicit LayerNorm( - const std::vector& axis, - double eps = 1e-5, - bool affine = true, - int axisSize = kLnVariableAxisSize); + /** + * Constructs a LayerNorm module. + * + * @param axis the axis along which normalization is computed. Usually set as + * the feature axis. + * @param eps \f$\epsilon\f$ + * @param affine a boolean value that controls the learning of \f$\gamma\f$ + * and \f$\beta\f$. \f$\gamma\f$ and \f$\beta\f$ are set to 1, 0 respectively + * if set to `false`, or initialized as learnable parameters + * if set to `true`. + * @param axisSize total size of features specified by `axis` to perform + * elementwise affine transform. If the feat size is variable, use + * `kLnVariableAxisSize` which uses singleton weight, bias and tiles them + * dynamically according to the given input. 
+ */ + explicit LayerNorm( + const std::vector& axis, + double eps = 1e-5, + bool affine = true, + int axisSize = kLnVariableAxisSize + ); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::string prettyString() const override; + std::string prettyString() const override; - private: - LayerNorm() = default; +private: + LayerNorm() = default; - // For legacy reasons, we store the complement of `axis` - // to not break serialization - std::vector axisComplement_; - double epsilon_; - bool affine_; - int axisSize_{kLnVariableAxisSize}; + // For legacy reasons, we store the complement of `axis` + // to not break serialization + std::vector axisComplement_; + double epsilon_; + bool affine_; + int axisSize_{kLnVariableAxisSize}; - FL_SAVE_LOAD_WITH_BASE( - UnaryModule, - axisComplement_, - epsilon_, - affine_, - fl::versioned(axisSize_, 1)) + FL_SAVE_LOAD_WITH_BASE( + UnaryModule, + axisComplement_, + epsilon_, + affine_, + fl::versioned(axisSize_, 1) + ) - void initialize(); + void initialize(); }; } // namespace fl diff --git a/flashlight/fl/nn/modules/Linear.cpp b/flashlight/fl/nn/modules/Linear.cpp index 2775856..5f86423 100644 --- a/flashlight/fl/nn/modules/Linear.cpp +++ b/flashlight/fl/nn/modules/Linear.cpp @@ -16,77 +16,80 @@ namespace fl { -Linear::Linear(int input_size, int output_size, bool bias) - : UnaryModule(), nIn_(input_size), nOut_(output_size), bias_(bias) { - initialize(); +Linear::Linear(int input_size, int output_size, bool bias) : UnaryModule(), + nIn_(input_size), + nOut_(output_size), + bias_(bias) { + initialize(); } -Linear::Linear(const Variable& w) - : UnaryModule({w}), nIn_(w.dim(1)), nOut_(w.dim(0)), bias_(false) {} +Linear::Linear(const Variable& w) : UnaryModule({w}), nIn_(w.dim(1)), nOut_(w.dim(0)), bias_(false) {} -Linear::Linear(const Variable& w, const Variable& b) - : UnaryModule({w, b}), 
nIn_(w.dim(1)), nOut_(w.dim(0)), bias_(true) { - if (b.dim(0) != w.dim(0)) { - throw std::invalid_argument( - "dimension mismatch between Linear weight and bias"); - } +Linear::Linear(const Variable& w, const Variable& b) : UnaryModule({w, b}), nIn_(w.dim(1)), nOut_(w.dim(0)), + bias_(true) { + if(b.dim(0) != w.dim(0)) { + throw std::invalid_argument( + "dimension mismatch between Linear weight and bias" + ); + } } -Linear::Linear(const Linear& other) - : UnaryModule(other.copyParams()), - nIn_(other.nIn_), - nOut_(other.nOut_), - bias_(other.bias_) { - train_ = other.train_; +Linear::Linear(const Linear& other) : UnaryModule(other.copyParams()), + nIn_(other.nIn_), + nOut_(other.nOut_), + bias_(other.bias_) { + train_ = other.train_; } Linear& Linear::operator=(const Linear& other) { - params_ = other.copyParams(); - train_ = other.train_; - nIn_ = other.nIn_; - nOut_ = other.nOut_; - bias_ = other.bias_; - return *this; + params_ = other.copyParams(); + train_ = other.train_; + nIn_ = other.nIn_; + nOut_ = other.nOut_; + bias_ = other.bias_; + return *this; } Variable Linear::forward(const Variable& input) { - if (bias_) { - return linear( - input, - params_[0].astype(input.type()), - params_[1].astype(input.type())); - } - return linear(input, params_[0].astype(input.type())); + if(bias_) { + return linear( + input, + params_[0].astype(input.type()), + params_[1].astype(input.type()) + ); + } + return linear(input, params_[0].astype(input.type())); } void Linear::initialize() { - int fanIn = nIn_; - auto w = Variable( - detail::kaimingUniform(Shape({nOut_, nIn_}), fanIn, fl::dtype::f32), - true); - if (bias_) { - double bound = std::sqrt(1.0 / fanIn); - auto b = uniform(Shape({nOut_}), -bound, bound, fl::dtype::f32, true); - params_ = {w, b}; - } else { - params_ = {w}; - } + int fanIn = nIn_; + auto w = Variable( + detail::kaimingUniform(Shape({nOut_, nIn_}), fanIn, fl::dtype::f32), + true + ); + if(bias_) { + double bound = std::sqrt(1.0 / fanIn); + auto b = 
uniform(Shape({nOut_}), -bound, bound, fl::dtype::f32, true); + params_ = {w, b}; + } else { + params_ = {w}; + } } std::unique_ptr Linear::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string Linear::prettyString() const { - std::ostringstream ss; - ss << "Linear"; - ss << " (" << nIn_ << "->" << nOut_ << ")"; - if (bias_) { - ss << " (with bias)"; - } else { - ss << " (without bias)"; - } - return ss.str(); + std::ostringstream ss; + ss << "Linear"; + ss << " (" << nIn_ << "->" << nOut_ << ")"; + if(bias_) { + ss << " (with bias)"; + } else { + ss << " (without bias)"; + } + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/nn/modules/Linear.h b/flashlight/fl/nn/modules/Linear.h index cbacd9c..826775c 100644 --- a/flashlight/fl/nn/modules/Linear.h +++ b/flashlight/fl/nn/modules/Linear.h @@ -17,66 +17,66 @@ namespace fl { * it to an output of shape [`output_size`, *, *, *]. */ class FL_API Linear : public UnaryModule { - private: - Linear() = default; // Intentionally private - - int nIn_, nOut_; - bool bias_; - - FL_SAVE_LOAD_WITH_BASE(UnaryModule, nIn_, nOut_, bias_) - - void initialize(); - - public: - /** - * Constructs a Linear module from the input and output sample sizes. - * - * @param input_size the size of each input sample - * @param output_size the size of each output sample - * @param bias a boolean value that controls whether the layer will include - * a bias term \f$b\f$. - */ - Linear(int input_size, int output_size, bool bias = true); - - /** - * Constructs a Linear module from the weight parameter \f$w\f$. The layer - * will not include the bias term \f$b\f$ in this case. - * - * @param w the 2D `Variable` tensor for the weight \f$w\f$. - * The shape should be [`output_size`, `input_size`]. - */ - explicit Linear(const Variable& w); - - /** - * Constructs a Linear module from the weight parameter \f$w\f$ and the bias - * parameter \f$b\f$. 
- * - * @param w the 2D `Variable` tensor for the weight \f$w\f$. - * The shape should be [`output_size`, `input_size`]. - * @param b the 1D `Variable` tensor for the bias \f$b\f$. - * The shape should be [`output_size`]. - */ - Linear(const Variable& w, const Variable& b); - - /** - * Constructs an Linear module from another, performing a deep copy of the - * parameters. - * - * @param other The Linear module to copy from. - */ - Linear(const Linear& other); - - Linear& operator=(const Linear& other); - - Linear(Linear&& other) = default; - - Linear& operator=(Linear&& other) = default; - - Variable forward(const Variable& input) override; - - std::unique_ptr clone() const override; - - std::string prettyString() const override; +private: + Linear() = default; // Intentionally private + + int nIn_, nOut_; + bool bias_; + + FL_SAVE_LOAD_WITH_BASE(UnaryModule, nIn_, nOut_, bias_) + + void initialize(); + +public: + /** + * Constructs a Linear module from the input and output sample sizes. + * + * @param input_size the size of each input sample + * @param output_size the size of each output sample + * @param bias a boolean value that controls whether the layer will include + * a bias term \f$b\f$. + */ + Linear(int input_size, int output_size, bool bias = true); + + /** + * Constructs a Linear module from the weight parameter \f$w\f$. The layer + * will not include the bias term \f$b\f$ in this case. + * + * @param w the 2D `Variable` tensor for the weight \f$w\f$. + * The shape should be [`output_size`, `input_size`]. + */ + explicit Linear(const Variable& w); + + /** + * Constructs a Linear module from the weight parameter \f$w\f$ and the bias + * parameter \f$b\f$. + * + * @param w the 2D `Variable` tensor for the weight \f$w\f$. + * The shape should be [`output_size`, `input_size`]. + * @param b the 1D `Variable` tensor for the bias \f$b\f$. + * The shape should be [`output_size`]. 
+ */ + Linear(const Variable& w, const Variable& b); + + /** + * Constructs an Linear module from another, performing a deep copy of the + * parameters. + * + * @param other The Linear module to copy from. + */ + Linear(const Linear& other); + + Linear& operator=(const Linear& other); + + Linear(Linear&& other) = default; + + Linear& operator=(Linear&& other) = default; + + Variable forward(const Variable& input) override; + + std::unique_ptr clone() const override; + + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/nn/modules/Loss.cpp b/flashlight/fl/nn/modules/Loss.cpp index ff31bf3..81eb725 100644 --- a/flashlight/fl/nn/modules/Loss.cpp +++ b/flashlight/fl/nn/modules/Loss.cpp @@ -16,214 +16,230 @@ namespace fl { Variable MeanSquaredError::forward( const Variable& inputs, - const Variable& targets) { - if (inputs.shape() != targets.shape()) { - throw std::invalid_argument( - "MeanSquaredError::forward - inputs and targets are of different" - " sizes: {inputs: " + - inputs.shape().toString() + ", targets: " + targets.shape().toString() + - "}"); - } + const Variable& targets +) { + if(inputs.shape() != targets.shape()) { + throw std::invalid_argument( + "MeanSquaredError::forward - inputs and targets are of different" + " sizes: {inputs: " + + inputs.shape().toString() + ", targets: " + targets.shape().toString() + + "}" + ); + } - auto df = inputs - targets; - auto res = mean(flat(df * df), {0}); - return res; + auto df = inputs - targets; + auto res = mean(flat(df * df), {0}); + return res; } std::unique_ptr MeanSquaredError::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string MeanSquaredError::prettyString() const { - return "MeanSquaredError"; + return "MeanSquaredError"; } Variable MeanAbsoluteError::forward( const Variable& inputs, - const Variable& targets) { - if (inputs.shape() != targets.shape()) { - throw std::invalid_argument( - "MeanAbsoluteError::forward - 
inputs and targets are of different" - " sizes: {inputs: " + - inputs.shape().toString() + ", targets: " + targets.shape().toString() + - "}"); - } + const Variable& targets +) { + if(inputs.shape() != targets.shape()) { + throw std::invalid_argument( + "MeanAbsoluteError::forward - inputs and targets are of different" + " sizes: {inputs: " + + inputs.shape().toString() + ", targets: " + targets.shape().toString() + + "}" + ); + } - auto df = inputs - targets; - return mean(flat(fl::abs(df)), {0}); + auto df = inputs - targets; + return mean(flat(fl::abs(df)), {0}); } std::unique_ptr MeanAbsoluteError::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string MeanAbsoluteError::prettyString() const { - return "MeanAbsoluteError"; + return "MeanAbsoluteError"; } Variable BinaryCrossEntropy::forward( const Variable& inputs, - const Variable& targets) { - return mean(flat(binaryCrossEntropy(inputs, targets)), {0}); + const Variable& targets +) { + return mean(flat(binaryCrossEntropy(inputs, targets)), {0}); } Variable BinaryCrossEntropy::forward( const Variable& inputs, const Variable& targets, - const Variable& weights) { - return mean(flat(weights * binaryCrossEntropy(inputs, targets)), {0}); + const Variable& weights +) { + return mean(flat(weights * binaryCrossEntropy(inputs, targets)), {0}); } std::unique_ptr BinaryCrossEntropy::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string BinaryCrossEntropy::prettyString() const { - return "BinaryCrossEntropy"; + return "BinaryCrossEntropy"; } Variable CategoricalCrossEntropy::forward( const Variable& inputs, - const Variable& targets) { - return categoricalCrossEntropy(inputs, targets, reduction_, ignoreIndex_); + const Variable& targets +) { + return categoricalCrossEntropy(inputs, targets, reduction_, ignoreIndex_); } std::unique_ptr CategoricalCrossEntropy::clone() const { - return std::make_unique(*this); + return 
std::make_unique(*this); } std::string CategoricalCrossEntropy::prettyString() const { - return "CategoricalCrossEntropy"; + return "CategoricalCrossEntropy"; } AdaptiveSoftMaxLoss::AdaptiveSoftMaxLoss( std::shared_ptr activation, ReduceMode reduction, - int ignoreIndex) - : BinaryModule(), - activation_(activation), - reduction_(reduction), - ignoreIndex_(ignoreIndex) { - params_ = activation_->params(); + int ignoreIndex +) : BinaryModule(), + activation_(activation), + reduction_(reduction), + ignoreIndex_(ignoreIndex) { + params_ = activation_->params(); } Variable AdaptiveSoftMaxLoss::cast( const Variable& input, const Shape& outDims, - const Tensor& indices) { - if (input.elements() != indices.elements()) { - throw std::invalid_argument("AdaptiveSoftMaxLoss: input, indices mismatch"); - } - Tensor output = fl::full(outDims, 0, input.type()); - output(indices) = input.tensor().flatten(); - auto inputDims = input.shape(); - - auto gradFunc = [indices, inputDims]( - std::vector& inputs, - const Variable& grad_output) { - Tensor gradTensor = grad_output.tensor()(indices); - auto grad = Variable(fl::reshape(gradTensor, inputDims), false); - inputs[0].addGrad(grad); - }; - return Variable(output, {input.withoutData()}, gradFunc); + const Tensor& indices +) { + if(input.elements() != indices.elements()) { + throw std::invalid_argument("AdaptiveSoftMaxLoss: input, indices mismatch"); + } + Tensor output = fl::full(outDims, 0, input.type()); + output(indices) = input.tensor().flatten(); + auto inputDims = input.shape(); + + auto gradFunc = [indices, inputDims]( + std::vector& inputs, + const Variable& grad_output) { + Tensor gradTensor = grad_output.tensor()(indices); + auto grad = Variable(fl::reshape(gradTensor, inputDims), false); + inputs[0].addGrad(grad); + }; + return Variable(output, {input.withoutData()}, gradFunc); } Variable AdaptiveSoftMaxLoss::forward( const Variable& inputs, - const Variable& targets) { - // inputs: N x T x B - // targets: T x B - if 
(inputs.ndim() != 3) { - throw std::invalid_argument( - "AdaptiveSoftMaxLoss::forward expects input tensor with " - "3 dimensions in N x T x B ordering."); - } - if (targets.ndim() != 2) { - throw std::invalid_argument( - "AdaptiveSoftMaxLoss::forward expects target tensor with " - "2 dimensions in T x B ordering."); - } - if (inputs.dim(1) != targets.dim(0)) { - throw std::invalid_argument("AdaptiveSoftMaxLoss: length mismatch"); - } else if (inputs.dim(2) != targets.dim(1)) { - throw std::invalid_argument("AdaptiveSoftMaxLoss: batch size mismatch"); - } - - auto N = inputs.dim(0); - auto T = inputs.dim(1); - auto B = inputs.dim(2); - auto cutoff = activation_->getCutoff(); - - auto input = moddims(inputs, {N, T * B}); - auto target = moddims(targets, {T * B}); - - auto headOutput = matmul(params_[0], input); - auto headTarget = Variable(target.tensor(), false) * (target < cutoff[0]); - // TODO: check the type of res - auto res = Variable(fl::full({T * B}, 0, fl::dtype::f32), true); - - // Tail forwawrd - for (int i = 0; i < cutoff.size() - 1; i++) { - auto mask = (target >= cutoff[i]) && (target < cutoff[i + 1]); - if (!fl::any(mask.tensor()).scalar()) { - continue; + const Variable& targets +) { + // inputs: N x T x B + // targets: T x B + if(inputs.ndim() != 3) { + throw std::invalid_argument( + "AdaptiveSoftMaxLoss::forward expects input tensor with " + "3 dimensions in N x T x B ordering." + ); + } + if(targets.ndim() != 2) { + throw std::invalid_argument( + "AdaptiveSoftMaxLoss::forward expects target tensor with " + "2 dimensions in T x B ordering." 
+ ); + } + if(inputs.dim(1) != targets.dim(0)) { + throw std::invalid_argument("AdaptiveSoftMaxLoss: length mismatch"); + } else if(inputs.dim(2) != targets.dim(1)) { + throw std::invalid_argument("AdaptiveSoftMaxLoss: batch size mismatch"); } - auto indicesArray = fl::nonzero(mask.tensor()); - headTarget = - headTarget + (mask * (cutoff[0] + i)).astype(headTarget.type()); - auto tailTarget = target(indicesArray) - cutoff[i]; - auto selectedInput = embedding(Variable(indicesArray, false), input); - auto tailOutput = matmul(params_[1 + i * 2], selectedInput); - tailOutput = matmul(params_[2 + i * 2], tailOutput); - auto localLoss = categoricalCrossEntropy( - logSoftmax(tailOutput, 0), tailTarget, ReduceMode::NONE, ignoreIndex_); - res = res + cast(localLoss, res.shape(), indicesArray); - } - - // Head forward - res = res + - categoricalCrossEntropy( + auto N = inputs.dim(0); + auto T = inputs.dim(1); + auto B = inputs.dim(2); + auto cutoff = activation_->getCutoff(); + + auto input = moddims(inputs, {N, T * B}); + auto target = moddims(targets, {T* B}); + + auto headOutput = matmul(params_[0], input); + auto headTarget = Variable(target.tensor(), false) * (target < cutoff[0]); + // TODO: check the type of res + auto res = Variable(fl::full({T* B}, 0, fl::dtype::f32), true); + + // Tail forwawrd + for(int i = 0; i < cutoff.size() - 1; i++) { + auto mask = (target >= cutoff[i]) && (target < cutoff[i + 1]); + if(!fl::any(mask.tensor()).scalar()) { + continue; + } + + auto indicesArray = fl::nonzero(mask.tensor()); + headTarget = + headTarget + (mask * (cutoff[0] + i)).astype(headTarget.type()); + auto tailTarget = target(indicesArray) - cutoff[i]; + auto selectedInput = embedding(Variable(indicesArray, false), input); + auto tailOutput = matmul(params_[1 + i * 2], selectedInput); + tailOutput = matmul(params_[2 + i * 2], tailOutput); + auto localLoss = categoricalCrossEntropy( + logSoftmax(tailOutput, 0), + tailTarget, + ReduceMode::NONE, + ignoreIndex_ + ); + res = 
res + cast(localLoss, res.shape(), indicesArray); + } + + // Head forward + res = res + + categoricalCrossEntropy( logSoftmax(headOutput, 0), headTarget, ReduceMode::NONE, - ignoreIndex_); + ignoreIndex_ + ); - // Reduce - if (reduction_ == ReduceMode::NONE) { - return moddims(res, targets.shape()); - } - res = sum(res, {0}); - if (reduction_ == ReduceMode::MEAN) { - auto denominator = - fl::countNonzero(target.tensor() != ignoreIndex_).scalar(); - res = res / denominator; - } - return res; + // Reduce + if(reduction_ == ReduceMode::NONE) { + return moddims(res, targets.shape()); + } + res = sum(res, {0}); + if(reduction_ == ReduceMode::MEAN) { + auto denominator = + fl::countNonzero(target.tensor() != ignoreIndex_).scalar(); + res = res / denominator; + } + return res; } std::shared_ptr AdaptiveSoftMaxLoss::getActivation() const { - return activation_; + return activation_; }; void AdaptiveSoftMaxLoss::setParams(const Variable& var, int position) { - Module::setParams(var, position); - activation_->setParams(var, position); + Module::setParams(var, position); + activation_->setParams(var, position); } std::unique_ptr AdaptiveSoftMaxLoss::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string AdaptiveSoftMaxLoss::prettyString() const { - std::ostringstream ss; - auto cutoff = activation_->getCutoff(); - ss << "Adaptive Softmax ("; - for (int i = 0; i < cutoff.size() - 1; i++) { - ss << cutoff[i] << ", "; - } - ss << cutoff[cutoff.size() - 1] << ")"; - return ss.str(); + std::ostringstream ss; + auto cutoff = activation_->getCutoff(); + ss << "Adaptive Softmax ("; + for(int i = 0; i < cutoff.size() - 1; i++) { + ss << cutoff[i] << ", "; + } + ss << cutoff[cutoff.size() - 1] << ")"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/nn/modules/Loss.h b/flashlight/fl/nn/modules/Loss.h index 49f4758..bc79d50 100644 --- a/flashlight/fl/nn/modules/Loss.h +++ b/flashlight/fl/nn/modules/Loss.h @@ -18,50 +18,50 @@ 
class Tensor; /** * Computes the [mean squared - error](https://en.wikipedia.org/wiki/Mean_squared_error) between elements + error](https://en.wikipedia.org/wiki/Mean_squared_error) between elements * across two tensors: * \f[ \mathcal{L}(x, y) = \frac{1}{n} \sum_{i = 0}^n \left( x_i - y_i \right)^2 \f] * for input tensor \f$x\f$ and target tensor \f$y\f$ each of which contain - \f$n\f$ elements. + \f$n\f$ elements. */ class FL_API MeanSquaredError : public BinaryModule { - public: - MeanSquaredError() = default; +public: + MeanSquaredError() = default; - Variable forward(const Variable& inputs, const Variable& targets) override; + Variable forward(const Variable& inputs, const Variable& targets) override; - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::string prettyString() const override; + std::string prettyString() const override; - private: - FL_SAVE_LOAD_WITH_BASE(BinaryModule) +private: + FL_SAVE_LOAD_WITH_BASE(BinaryModule) }; /** * Computes the [mean absolute - error](https://en.wikipedia.org/wiki/Mean_absolute_error) (equivalent to the - \f$L_1\f$ loss): + error](https://en.wikipedia.org/wiki/Mean_absolute_error) (equivalent to the + \f$L_1\f$ loss): * \f[ \mathcal{L}(x, y) = \frac{1}{n} \sum_{i = 0}^n \left| x_i - y_i \right| \f] * for input tensor \f$x\f$ and target tensor \f$y\f$ each of which contain - \f$n\f$ elements. + \f$n\f$ elements. 
*/ class FL_API MeanAbsoluteError : public BinaryModule { - public: - MeanAbsoluteError() = default; +public: + MeanAbsoluteError() = default; - Variable forward(const Variable& inputs, const Variable& targets) override; + Variable forward(const Variable& inputs, const Variable& targets) override; - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::string prettyString() const override; + std::string prettyString() const override; - private: - FL_SAVE_LOAD_WITH_BASE(BinaryModule) +private: + FL_SAVE_LOAD_WITH_BASE(BinaryModule) }; /** @@ -69,37 +69,38 @@ class FL_API MeanAbsoluteError : public BinaryModule { * target tensor \f$y\f$. The binary cross entropy loss is: * \f[ B(x, y) = \frac{1}{n} \sum_{i = 0}^n -\left( w_i \times (y_i \times \log(x_i) - + (1 - y_i) \times \log(1 - x_i)) \right) \f] + + (1 - y_i) \times \log(1 - x_i)) \right) \f] * where \f$w\f$ is an optional weight parameter for rescaling. * * Both the inputs and the targets are expected to be between 0 and 1. */ class FL_API BinaryCrossEntropy : public BinaryModule { - public: - BinaryCrossEntropy() = default; +public: + BinaryCrossEntropy() = default; - using BinaryModule::forward; + using BinaryModule::forward; - Variable forward(const Variable& inputs, const Variable& targets) override; + Variable forward(const Variable& inputs, const Variable& targets) override; - /** - * Perform forward loss computation with an additional weight tensor. - * - * @param inputs a tensor with the predicted values - * @param targets a tensor with the target values - * @param weights a rescaling weight given to the loss of each element. - */ - Variable forward( - const Variable& inputs, - const Variable& targets, - const Variable& weights); + /** + * Perform forward loss computation with an additional weight tensor. 
+ * + * @param inputs a tensor with the predicted values + * @param targets a tensor with the target values + * @param weights a rescaling weight given to the loss of each element. + */ + Variable forward( + const Variable& inputs, + const Variable& targets, + const Variable& weights + ); - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::string prettyString() const override; + std::string prettyString() const override; - private: - FL_SAVE_LOAD_WITH_BASE(BinaryModule) +private: + FL_SAVE_LOAD_WITH_BASE(BinaryModule) }; /** @@ -109,7 +110,7 @@ class FL_API BinaryCrossEntropy : public BinaryModule { * ground truth class for each input example. * * In the batch case, the output loss tensor \f$\{l_1,...,l_N\}^\top\f$, put - \f$l_n = -x_{n, y_n}\f$ + \f$l_n = -x_{n, y_n}\f$ * (only consider the probability of the correct class). Then reduce via: * \f[ \mathcal{L}(x, y) = \sum_{i = 1}^N l_i @@ -123,44 +124,46 @@ class FL_API BinaryCrossEntropy : public BinaryModule { * `ReduceMode`. */ class FL_API CategoricalCrossEntropy : public BinaryModule { - private: - ReduceMode reduction_; - int ignoreIndex_{-1}; - - FL_SAVE_LOAD_WITH_BASE( - BinaryModule, - reduction_, - fl::versioned(ignoreIndex_, 1)) - - public: - /** - * Creates a `CategoricalCrossEntropy`. - * - * @param reduction a reduction with which to compute the loss. See - * documentation on `ReduceMode` for available options. - * @param ignoreIndex a target value that is ignored and does not contribute - * to the loss or the input gradient. If `reduce` is MEAN, the loss is - * averaged over non-ignored targets. - */ - explicit CategoricalCrossEntropy( - ReduceMode reduction = ReduceMode::MEAN, - int ignoreIndex = -1) - : reduction_(reduction), ignoreIndex_(ignoreIndex) {} - - /** - * Computes the categorical cross entropy loss for some input and target - * tensors. 
- * - * @param inputs a `Variable` with shape [\f$C\f$, \f$B_1\f$, \f$B_2\f$, - * \f$B_3\f$] where \f$C\f$ is the number of classes. - * @param targets an integer `Variable` with shape [\f$B_1\f$, \f$B_2\f$, - * \f$B_3\f$]. The values must be in [\f$0\f$, \f$C - 1\f$] - */ - Variable forward(const Variable& inputs, const Variable& targets) override; - - std::unique_ptr clone() const override; - - std::string prettyString() const override; +private: + ReduceMode reduction_; + int ignoreIndex_{-1}; + + FL_SAVE_LOAD_WITH_BASE( + BinaryModule, + reduction_, + fl::versioned(ignoreIndex_, 1) + ) + +public: + /** + * Creates a `CategoricalCrossEntropy`. + * + * @param reduction a reduction with which to compute the loss. See + * documentation on `ReduceMode` for available options. + * @param ignoreIndex a target value that is ignored and does not contribute + * to the loss or the input gradient. If `reduce` is MEAN, the loss is + * averaged over non-ignored targets. + */ + explicit CategoricalCrossEntropy( + ReduceMode reduction = ReduceMode::MEAN, + int ignoreIndex = -1 + ) : reduction_(reduction), + ignoreIndex_(ignoreIndex) {} + + /** + * Computes the categorical cross entropy loss for some input and target + * tensors. + * + * @param inputs a `Variable` with shape [\f$C\f$, \f$B_1\f$, \f$B_2\f$, + * \f$B_3\f$] where \f$C\f$ is the number of classes. + * @param targets an integer `Variable` with shape [\f$B_1\f$, \f$B_2\f$, + * \f$B_3\f$]. The values must be in [\f$0\f$, \f$C - 1\f$] + */ + Variable forward(const Variable& inputs, const Variable& targets) override; + + std::unique_ptr clone() const override; + + std::string prettyString() const override; }; /** @@ -178,53 +181,54 @@ class FL_API CategoricalCrossEntropy : public BinaryModule { * up computation. 
*/ class FL_API AdaptiveSoftMaxLoss : public BinaryModule { - private: - FL_SAVE_LOAD_WITH_BASE( - BinaryModule, - activation_, - reduction_, - fl::versioned(ignoreIndex_, 1)) - std::shared_ptr activation_; - ReduceMode reduction_; - int ignoreIndex_{-1}; - - Variable - cast(const Variable& input, const Shape& outDims, const Tensor& indices); - - public: - AdaptiveSoftMaxLoss() = default; - - /** - * Create an `AdaptiveSoftMaxLoss` with given parameters - * - * @param reduction the reduction mode - see `ReductionMode` See - * documentation on `ReduceMode` for available options. - * @param ignoreIndex a target value that is ignored and does not contribute - * to the loss or the input gradient. If `reduce` is MEAN, the loss is - * averaged over non-ignored targets. - */ - explicit AdaptiveSoftMaxLoss( - std::shared_ptr activation, - ReduceMode reduction = ReduceMode::MEAN, - int ignoreIndex = -1); - std::shared_ptr getActivation() const; - - /** - * Computes the categorical cross entropy loss for some input and target - * tensors (uses adaptive softmax function to do this efficiently) - * - * @param inputs a `Variable` with shape [\f$C\f$, \f$B_1\f$, \f$B_2\f$, - * \f$B_3\f$] where \f$C\f$ is the number of classes. - * @param targets an integer `Variable` with shape [\f$B_1\f$, \f$B_2\f$, - * \f$B_3\f$]. 
The values must be in [\f$0\f$, \f$C - 1\f$] - */ - Variable forward(const Variable& inputs, const Variable& targets) override; - - void setParams(const Variable& var, int position) override; - - std::unique_ptr clone() const override; - - std::string prettyString() const override; +private: + FL_SAVE_LOAD_WITH_BASE( + BinaryModule, + activation_, + reduction_, + fl::versioned(ignoreIndex_, 1) + ) + std::shared_ptr activation_; + ReduceMode reduction_; + int ignoreIndex_{-1}; + + Variable cast(const Variable& input, const Shape& outDims, const Tensor& indices); + +public: + AdaptiveSoftMaxLoss() = default; + + /** + * Create an `AdaptiveSoftMaxLoss` with given parameters + * + * @param reduction the reduction mode - see `ReductionMode` See + * documentation on `ReduceMode` for available options. + * @param ignoreIndex a target value that is ignored and does not contribute + * to the loss or the input gradient. If `reduce` is MEAN, the loss is + * averaged over non-ignored targets. + */ + explicit AdaptiveSoftMaxLoss( + std::shared_ptr activation, + ReduceMode reduction = ReduceMode::MEAN, + int ignoreIndex = -1 + ); + std::shared_ptr getActivation() const; + + /** + * Computes the categorical cross entropy loss for some input and target + * tensors (uses adaptive softmax function to do this efficiently) + * + * @param inputs a `Variable` with shape [\f$C\f$, \f$B_1\f$, \f$B_2\f$, + * \f$B_3\f$] where \f$C\f$ is the number of classes. + * @param targets an integer `Variable` with shape [\f$B_1\f$, \f$B_2\f$, + * \f$B_3\f$]. 
The values must be in [\f$0\f$, \f$C - 1\f$] + */ + Variable forward(const Variable& inputs, const Variable& targets) override; + + void setParams(const Variable& var, int position) override; + + std::unique_ptr clone() const override; + + std::string prettyString() const override; }; typedef MeanSquaredError MSE; diff --git a/flashlight/fl/nn/modules/Module.cpp b/flashlight/fl/nn/modules/Module.cpp index efb2dea..41e1fb8 100644 --- a/flashlight/fl/nn/modules/Module.cpp +++ b/flashlight/fl/nn/modules/Module.cpp @@ -16,98 +16,98 @@ namespace fl { Module::Module() = default; -Module::Module(const std::vector& params) - : params_(params.begin(), params.end()) {} +Module::Module(const std::vector& params) : params_(params.begin(), params.end()) {} Variable Module::param(int position) const { - if (!(position >= 0 && position < params_.size())) { - throw std::out_of_range("Module param index out of range"); - } - return params_[position]; + if(!(position >= 0 && position < params_.size())) { + throw std::out_of_range("Module param index out of range"); + } + return params_[position]; } void Module::setParams(const Variable& var, int position) { - if (!(position >= 0 && position < params_.size())) { - throw std::out_of_range("Module param index out of range"); - } - params_[position] = var; + if(!(position >= 0 && position < params_.size())) { + throw std::out_of_range("Module param index out of range"); + } + params_[position] = var; } std::vector Module::copyParams() const { - std::vector params; - params.reserve(params_.size()); - for (const auto& param : params_) { - params.emplace_back(param.copy()); - } - return params; + std::vector params; + params.reserve(params_.size()); + for(const auto& param : params_) { + params.emplace_back(param.copy()); + } + return params; } void Module::train() { - train_ = true; - for (auto& param : params_) { - param.setCalcGrad(true); - } + train_ = true; + for(auto& param : params_) { + param.setCalcGrad(true); + } } void 
Module::zeroGrad() { - for (auto& param : params_) { - param.zeroGrad(); - } + for(auto& param : params_) { + param.zeroGrad(); + } } void Module::eval() { - train_ = false; - for (auto& param : params_) { - param.setCalcGrad(false); - } + train_ = false; + for(auto& param : params_) { + param.setCalcGrad(false); + } } std::vector Module::params() const { - return params_; + return params_; } int Module::numParamTensors() const { - return static_cast(params_.size()); + return static_cast(params_.size()); } std::vector Module::operator()(const std::vector& input) { - return this->forward(input); + return this->forward(input); } UnaryModule::UnaryModule() = default; -UnaryModule::UnaryModule(const std::vector& params) - : Module(params) {} +UnaryModule::UnaryModule(const std::vector& params) : Module(params) {} std::vector UnaryModule::forward( - const std::vector& inputs) { - if (inputs.size() != 1) { - throw std::invalid_argument("UnaryModule expects only one input"); - } - return {forward(inputs[0])}; + const std::vector& inputs +) { + if(inputs.size() != 1) { + throw std::invalid_argument("UnaryModule expects only one input"); + } + return {forward(inputs[0])}; } Variable UnaryModule::operator()(const Variable& input) { - return this->forward(input); + return this->forward(input); } BinaryModule::BinaryModule() = default; -BinaryModule::BinaryModule(const std::vector& params) - : Module(params) {} +BinaryModule::BinaryModule(const std::vector& params) : Module(params) {} std::vector BinaryModule::forward( - const std::vector& inputs) { - if (inputs.size() != 2) { - throw std::invalid_argument("BinaryModule expects two inputs"); - } - return {forward(inputs[0], inputs[1])}; + const std::vector& inputs +) { + if(inputs.size() != 2) { + throw std::invalid_argument("BinaryModule expects two inputs"); + } + return {forward(inputs[0], inputs[1])}; } Variable BinaryModule::operator()( const Variable& input1, - const Variable& input2) { - return this->forward(input1, 
input2); + const Variable& input2 +) { + return this->forward(input1, input2); } } // namespace fl diff --git a/flashlight/fl/nn/modules/Module.h b/flashlight/fl/nn/modules/Module.h index 1fe3437..b103a28 100644 --- a/flashlight/fl/nn/modules/Module.h +++ b/flashlight/fl/nn/modules/Module.h @@ -23,137 +23,138 @@ namespace fl { * serialized and deserialized with the module. */ class FL_API Module { - private: - /** - * Serialize the module's parameters. - */ - FL_SAVE_LOAD(params_, train_) - - protected: - /** - * Parameters of module, represented as a collection of `Variable`, whose - * ordering is based on the implementation of the respective module. - */ - std::vector params_; - - /** - * A flag specifying whether or not the module is in `train` mode. If - * `Module::train()` is called, it will be set to true, and if - * `Module::eval()` is called, it will be set to false. - */ - bool train_ = true; - - /** - * An empty module constructor, which creates a module with no parameters. - * - */ - Module(); - - /** - * Constructs a module given its parameters. - * - * @param params a vector of `Variable` which will replace `params_` - * This changes all parameters so that gradient calculation will be - * enabled/disabled for any calls to `forward`. - */ - explicit Module(const std::vector& params); - - public: - /** - * Gets the parameters of the module. - * - * @return the modules parameters as a vector of `Variable` - */ - std::vector params() const; - - /** - * Gets the nunber of parameter tensors of the module. - * - * @return the number of parameter tensors - */ - int numParamTensors() const; - - /** - * Switches the module to training mode. Changes all parameters so that - * gradient calculation will be enabled for any calls to `forward`. - */ - virtual void train(); - - /** - * Switches the module to evaluation mode. Changes all parameters so that - * gradient calculation will be disabled for any calls to `forward`. 
- */ - virtual void eval(); - - /** - * Returns a module parameter given a particular position. - * - * @param position the index of the requested parameter in `params_` - * @return a `Variable` tensor for the parameter at the requested position - */ - Variable param(int position) const; - - /** - * Sets a parameter at a specified position with a new, given one. - * - * If the specified position is not valid (it is negative or greater than - * ``params_.size() - 1``), then an error will be thrown. A new parameter - * will not be created at a specified index if out of bounds. - * - * @param[in] var the new replacement `Variable` - * @param position The index of the parameter which will be replaced in - * `params_` - */ - virtual void setParams(const Variable& var, int position); - - /** - * Copies the modules parameters, detaching from the computation graph. - * - * @return a copy of the modules parameters as a vector of `Variable` - */ - virtual std::vector copyParams() const; - - /** - * Clears references to gradient Variables for all parameters in the module. - */ - void zeroGrad(); - - /** - * Performs forward computation for the module, given some inputs. - * - * @param inputs the values to compute forward computation for the - * module. - * @return a vector of `Variable` tensors containing the result of - * the forward computation - */ - virtual std::vector forward( - const std::vector& inputs) = 0; - - /** - * Overload for forward computation for the module. - * - * @param inputs the values to compute forward computation for the - * module. - * @return a vector of `Variable` tensors containing the result of - * the forward computation - */ - std::vector operator()(const std::vector& inputs); - - /** - * Clone the module via deep copy of its parameters and members. - * - * @return A unique pointer of the cloned module. - */ - virtual std::unique_ptr clone() const = 0; - - /** - * Generates a stringified representation of the module. 
- * - * @return a string containing the module label - */ - virtual std::string prettyString() const = 0; - - virtual ~Module() = default; +private: + /** + * Serialize the module's parameters. + */ + FL_SAVE_LOAD(params_, train_) + +protected: + /** + * Parameters of module, represented as a collection of `Variable`, whose + * ordering is based on the implementation of the respective module. + */ + std::vector params_; + + /** + * A flag specifying whether or not the module is in `train` mode. If + * `Module::train()` is called, it will be set to true, and if + * `Module::eval()` is called, it will be set to false. + */ + bool train_ = true; + + /** + * An empty module constructor, which creates a module with no parameters. + * + */ + Module(); + + /** + * Constructs a module given its parameters. + * + * @param params a vector of `Variable` which will replace `params_` + * This changes all parameters so that gradient calculation will be + * enabled/disabled for any calls to `forward`. + */ + explicit Module(const std::vector& params); + +public: + /** + * Gets the parameters of the module. + * + * @return the modules parameters as a vector of `Variable` + */ + std::vector params() const; + + /** + * Gets the nunber of parameter tensors of the module. + * + * @return the number of parameter tensors + */ + int numParamTensors() const; + + /** + * Switches the module to training mode. Changes all parameters so that + * gradient calculation will be enabled for any calls to `forward`. + */ + virtual void train(); + + /** + * Switches the module to evaluation mode. Changes all parameters so that + * gradient calculation will be disabled for any calls to `forward`. + */ + virtual void eval(); + + /** + * Returns a module parameter given a particular position. 
+ * + * @param position the index of the requested parameter in `params_` + * @return a `Variable` tensor for the parameter at the requested position + */ + Variable param(int position) const; + + /** + * Sets a parameter at a specified position with a new, given one. + * + * If the specified position is not valid (it is negative or greater than + * ``params_.size() - 1``), then an error will be thrown. A new parameter + * will not be created at a specified index if out of bounds. + * + * @param[in] var the new replacement `Variable` + * @param position The index of the parameter which will be replaced in + * `params_` + */ + virtual void setParams(const Variable& var, int position); + + /** + * Copies the modules parameters, detaching from the computation graph. + * + * @return a copy of the modules parameters as a vector of `Variable` + */ + virtual std::vector copyParams() const; + + /** + * Clears references to gradient Variables for all parameters in the module. + */ + void zeroGrad(); + + /** + * Performs forward computation for the module, given some inputs. + * + * @param inputs the values to compute forward computation for the + * module. + * @return a vector of `Variable` tensors containing the result of + * the forward computation + */ + virtual std::vector forward( + const std::vector& inputs + ) = 0; + + /** + * Overload for forward computation for the module. + * + * @param inputs the values to compute forward computation for the + * module. + * @return a vector of `Variable` tensors containing the result of + * the forward computation + */ + std::vector operator()(const std::vector& inputs); + + /** + * Clone the module via deep copy of its parameters and members. + * + * @return A unique pointer of the cloned module. + */ + virtual std::unique_ptr clone() const = 0; + + /** + * Generates a stringified representation of the module. 
+ * + * @return a string containing the module label + */ + virtual std::string prettyString() const = 0; + + virtual ~Module() = default; }; /** @@ -162,21 +163,21 @@ class FL_API Module { * For example, `Sigmoid` module can be derived from `UnaryModule`. */ class FL_API UnaryModule : public Module { - public: - UnaryModule(); +public: + UnaryModule(); - explicit UnaryModule(const std::vector& params); + explicit UnaryModule(const std::vector& params); - std::vector forward(const std::vector& inputs) override; + std::vector forward(const std::vector& inputs) override; - virtual Variable forward(const Variable& input) = 0; + virtual Variable forward(const Variable& input) = 0; - Variable operator()(const Variable& input); + Variable operator()(const Variable& input); - virtual ~UnaryModule() = default; + virtual ~UnaryModule() = default; - private: - FL_SAVE_LOAD_WITH_BASE(Module) +private: + FL_SAVE_LOAD_WITH_BASE(Module) }; /** @@ -185,21 +186,21 @@ class FL_API UnaryModule : public Module { * For example, `BinaryCrossEntropy` Loss can be derived from `BinaryModule`. 
*/ class FL_API BinaryModule : public Module { - public: - BinaryModule(); +public: + BinaryModule(); - explicit BinaryModule(const std::vector& params); + explicit BinaryModule(const std::vector& params); - std::vector forward(const std::vector& inputs) override; + std::vector forward(const std::vector& inputs) override; - virtual Variable forward(const Variable& input1, const Variable& input2) = 0; + virtual Variable forward(const Variable& input1, const Variable& input2) = 0; - Variable operator()(const Variable& input1, const Variable& input2); + Variable operator()(const Variable& input1, const Variable& input2); - virtual ~BinaryModule() = default; + virtual ~BinaryModule() = default; - private: - FL_SAVE_LOAD_WITH_BASE(Module) +private: + FL_SAVE_LOAD_WITH_BASE(Module) }; } // namespace fl diff --git a/flashlight/fl/nn/modules/Normalize.cpp b/flashlight/fl/nn/modules/Normalize.cpp index 5a09425..53e9aeb 100644 --- a/flashlight/fl/nn/modules/Normalize.cpp +++ b/flashlight/fl/nn/modules/Normalize.cpp @@ -14,29 +14,32 @@ Normalize::Normalize( const std::vector& axes, double p /* = 2 */, double eps /* = 1e-12 */, - double value /* = 1 */) - : axes_(axes), p_(p), eps_(eps), value_(value) {} + double value /* = 1 */ +) : axes_(axes), + p_(p), + eps_(eps), + value_(value) {} Variable Normalize::forward(const Variable& input) { - return value_ * normalize(input, axes_, p_, eps_); + return value_ * normalize(input, axes_, p_, eps_); } std::unique_ptr Normalize::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string Normalize::prettyString() const { - std::ostringstream ss; - ss << "Normalize"; - ss << " ( axis : { "; - for (auto d : axes_) { - ss << d << " "; - } - ss << "} , p : " << p_; - ss << ", eps : " << eps_; - ss << ", value : " << value_; - ss << " )"; - return ss.str(); + std::ostringstream ss; + ss << "Normalize"; + ss << " ( axis : { "; + for(auto d : axes_) { + ss << d << " "; + } + ss << "} , p : " << p_; + ss 
<< ", eps : " << eps_; + ss << ", value : " << value_; + ss << " )"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/nn/modules/Normalize.h b/flashlight/fl/nn/modules/Normalize.h index bae9098..7207f12 100644 --- a/flashlight/fl/nn/modules/Normalize.h +++ b/flashlight/fl/nn/modules/Normalize.h @@ -13,37 +13,38 @@ namespace fl { class FL_API Normalize : public UnaryModule { - public: - /** - * Constructs a Normalize module. - * - * @param value the target normalization value. - * @param axes reduce over specified axes - * @param p as p in Lp norm - * @param eps min clamping value to avoid overflows - * @param normalization mode, as supported by normalize() - */ - explicit Normalize( - const std::vector& axes, - double p = 2, - double eps = 1e-12, - double value = 1); - - Variable forward(const Variable& input) override; - - std::unique_ptr clone() const override; - - std::string prettyString() const override; - - private: - Normalize() = default; - - std::vector axes_; - double p_; - double eps_; - double value_; - - FL_SAVE_LOAD_WITH_BASE(UnaryModule, axes_, p_, eps_, value_) +public: + /** + * Constructs a Normalize module. + * + * @param value the target normalization value. 
+ * @param axes reduce over specified axes + * @param p as p in Lp norm + * @param eps min clamping value to avoid overflows + * @param normalization mode, as supported by normalize() + */ + explicit Normalize( + const std::vector& axes, + double p = 2, + double eps = 1e-12, + double value = 1 + ); + + Variable forward(const Variable& input) override; + + std::unique_ptr clone() const override; + + std::string prettyString() const override; + +private: + Normalize() = default; + + std::vector axes_; + double p_; + double eps_; + double value_; + + FL_SAVE_LOAD_WITH_BASE(UnaryModule, axes_, p_, eps_, value_) }; } // namespace fl diff --git a/flashlight/fl/nn/modules/Padding.cpp b/flashlight/fl/nn/modules/Padding.cpp index 8797f37..54cfa3b 100644 --- a/flashlight/fl/nn/modules/Padding.cpp +++ b/flashlight/fl/nn/modules/Padding.cpp @@ -11,25 +11,25 @@ namespace fl { -Padding::Padding(std::vector> padding, double val) - : m_pad(std::move(padding)), m_val(val) {} +Padding::Padding(std::vector> padding, double val) : m_pad(std::move(padding)), + m_val(val) {} Variable Padding::forward(const Variable& input) { - return padding(input, m_pad, m_val); + return padding(input, m_pad, m_val); } std::unique_ptr Padding::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string Padding::prettyString() const { - std::ostringstream ss; - ss << "Padding (" << m_val << ", { "; - for (auto p : m_pad) { - ss << "(" << p.first << ", " << p.second << "), "; - } - ss << "})"; - return ss.str(); + std::ostringstream ss; + ss << "Padding (" << m_val << ", { "; + for(auto p : m_pad) { + ss << "(" << p.first << ", " << p.second << "), "; + } + ss << "})"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/nn/modules/Padding.h b/flashlight/fl/nn/modules/Padding.h index 35163f4..df78701 100644 --- a/flashlight/fl/nn/modules/Padding.h +++ b/flashlight/fl/nn/modules/Padding.h @@ -19,32 +19,32 @@ namespace fl { * \f$i\f$ of size specified by 
the tuple `padi` to the input. */ class FL_API Padding : public UnaryModule { - private: - Padding() = default; // intentionally private +private: + Padding() = default; // intentionally private - std::vector> m_pad; - double m_val; + std::vector> m_pad; + double m_val; - FL_SAVE_LOAD_WITH_BASE(UnaryModule, m_pad, m_val) + FL_SAVE_LOAD_WITH_BASE(UnaryModule, m_pad, m_val) - public: - /** - * Constructs a Padding module that pads the first dimension of the input. If - * the input is of shape - * [\f$dim_0\f$, \f$dim_1\f$, \f$dim_2\f$, \f$dim_3\f$], - * the output will be of shape [\f$paddingBefore+dim_0+paddingAfter\f$, - * \f$dim_1\f$, \f$dim_2\f$, \f$dim_3\f$] - * @param[in] padding a vector of tuples representing padding (before, - * after) tuples for each axis - * @param val the value to be padded - */ - Padding(std::vector> padding, double val); +public: + /** + * Constructs a Padding module that pads the first dimension of the input. If + * the input is of shape + * [\f$dim_0\f$, \f$dim_1\f$, \f$dim_2\f$, \f$dim_3\f$], + * the output will be of shape [\f$paddingBefore+dim_0+paddingAfter\f$, + * \f$dim_1\f$, \f$dim_2\f$, \f$dim_3\f$] + * @param[in] padding a vector of tuples representing padding (before, + * after) tuples for each axis + * @param val the value to be padded + */ + Padding(std::vector> padding, double val); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::string prettyString() const override; + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/nn/modules/Pool2D.cpp b/flashlight/fl/nn/modules/Pool2D.cpp index 39b4f45..609c3cc 100644 --- a/flashlight/fl/nn/modules/Pool2D.cpp +++ b/flashlight/fl/nn/modules/Pool2D.cpp @@ -24,69 +24,71 @@ Pool2D::Pool2D( int sy, IntOrPadMode px, IntOrPadMode py, - PoolingMode mode) - : xFilter_(wx), - yFilter_(wy), - xStride_(sx), - 
yStride_(sy), - xPad_(px.padVal), - yPad_(py.padVal), - mode_(mode) {} + PoolingMode mode +) : xFilter_(wx), + yFilter_(wy), + xStride_(sx), + yStride_(sy), + xPad_(px.padVal), + yPad_(py.padVal), + mode_(mode) {} Variable Pool2D::forward(const Variable& input) { - auto px = derivePadding( - input.dim(0), - xFilter_, - xStride_, - xPad_, - /* dilation= */ 1); - auto py = derivePadding( - input.dim(1), - yFilter_, - yStride_, - yPad_, - /* dilation= */ 1); + auto px = derivePadding( + input.dim(0), + xFilter_, + xStride_, + xPad_, + /* dilation= */ 1 + ); + auto py = derivePadding( + input.dim(1), + yFilter_, + yStride_, + yPad_, + /* dilation= */ 1 + ); - if (!(px >= 0 && py >= 0)) { - throw std::invalid_argument("invalid padding for Pool2D"); - } + if(!(px >= 0 && py >= 0)) { + throw std::invalid_argument("invalid padding for Pool2D"); + } - return pool2d(input, xFilter_, yFilter_, xStride_, yStride_, px, py, mode_); + return pool2d(input, xFilter_, yFilter_, xStride_, yStride_, px, py, mode_); } std::unique_ptr Pool2D::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string Pool2D::prettyString() const { - std::ostringstream ss; - ss << "Pool2D"; - switch (mode_) { - case PoolingMode::MAX: - ss << "-max"; - break; - case PoolingMode::AVG_EXCLUDE_PADDING: - ss << "-avg(without pad)"; - break; - case PoolingMode::AVG_INCLUDE_PADDING: - ss << "-avg(with pad)"; - break; - } - ss << " (" << xFilter_ << "x" << yFilter_ << ", " << xStride_ << "," - << yStride_ << ", "; - if (xPad_ == static_cast(PaddingMode::SAME)) { - ss << "SAME"; - } else { - ss << xPad_; - } - ss << ","; - if (yPad_ == static_cast(PaddingMode::SAME)) { - ss << "SAME"; - } else { - ss << yPad_; - } - ss << ")"; - return ss.str(); + std::ostringstream ss; + ss << "Pool2D"; + switch(mode_) { + case PoolingMode::MAX: + ss << "-max"; + break; + case PoolingMode::AVG_EXCLUDE_PADDING: + ss << "-avg(without pad)"; + break; + case PoolingMode::AVG_INCLUDE_PADDING: + 
ss << "-avg(with pad)"; + break; + } + ss << " (" << xFilter_ << "x" << yFilter_ << ", " << xStride_ << "," + << yStride_ << ", "; + if(xPad_ == static_cast(PaddingMode::SAME)) { + ss << "SAME"; + } else { + ss << xPad_; + } + ss << ","; + if(yPad_ == static_cast(PaddingMode::SAME)) { + ss << "SAME"; + } else { + ss << yPad_; + } + ss << ")"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/nn/modules/Pool2D.h b/flashlight/fl/nn/modules/Pool2D.h index 14b668c..c372819 100644 --- a/flashlight/fl/nn/modules/Pool2D.h +++ b/flashlight/fl/nn/modules/Pool2D.h @@ -20,53 +20,55 @@ namespace fl { * of shape [\f$X_{out}\f$, \f$Y_{out}\f$, \f$C\f$, \f$N\f$]. */ class FL_API Pool2D : public UnaryModule { - private: - Pool2D() = default; // Intentionally private +private: + Pool2D() = default; // Intentionally private - int xFilter_, yFilter_; // pooling dims - int xStride_, yStride_; // stride - int xPad_, yPad_; // padding - used iff padding mode is none - PoolingMode mode_; // pooling type + int xFilter_, yFilter_; // pooling dims + int xStride_, yStride_; // stride + int xPad_, yPad_; // padding - used iff padding mode is none + PoolingMode mode_; // pooling type - FL_SAVE_LOAD_WITH_BASE( - UnaryModule, - xFilter_, - yFilter_, - xStride_, - yStride_, - xPad_, - yPad_, - mode_) + FL_SAVE_LOAD_WITH_BASE( + UnaryModule, + xFilter_, + yFilter_, + xStride_, + yStride_, + xPad_, + yPad_, + mode_ + ) - public: - /** Construct a Pool2D layer. - * @param wx pooling window size in the first dimension - * @param wy pooling window size in the second dimension - * @param sx stride in the first dimension - * @param sy stride in the second dimension - * @param px amount of zero-padding on both sides in the first dimension. - * Accepts a non-negative integer value or an enum fl::PaddingMode - * @param py amount of zero-padding on both sides in the second dimension. - * Accepts a non-negative integer value or an enum fl::PaddingMode - * @param mode pooling mode. 
Can be any of: - * - MAX - * - AVG_INCLUDE_PADDING - * - AVG_EXCLUDE_PADDING - */ - Pool2D( - int wx, - int wy, - int sx = 1, - int sy = 1, - detail::IntOrPadMode px = 0, - detail::IntOrPadMode py = 0, - PoolingMode mode = PoolingMode::MAX); +public: + /** Construct a Pool2D layer. + * @param wx pooling window size in the first dimension + * @param wy pooling window size in the second dimension + * @param sx stride in the first dimension + * @param sy stride in the second dimension + * @param px amount of zero-padding on both sides in the first dimension. + * Accepts a non-negative integer value or an enum fl::PaddingMode + * @param py amount of zero-padding on both sides in the second dimension. + * Accepts a non-negative integer value or an enum fl::PaddingMode + * @param mode pooling mode. Can be any of: + * - MAX + * - AVG_INCLUDE_PADDING + * - AVG_EXCLUDE_PADDING + */ + Pool2D( + int wx, + int wy, + int sx = 1, + int sy = 1, + detail::IntOrPadMode px = 0, + detail::IntOrPadMode py = 0, + PoolingMode mode = PoolingMode::MAX + ); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::string prettyString() const override; + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/nn/modules/PrecisionCast.cpp b/flashlight/fl/nn/modules/PrecisionCast.cpp index 8ba97a8..8bfb5f4 100644 --- a/flashlight/fl/nn/modules/PrecisionCast.cpp +++ b/flashlight/fl/nn/modules/PrecisionCast.cpp @@ -14,32 +14,33 @@ namespace fl { PrecisionCast::PrecisionCast(fl::dtype targetType) : targetType_(targetType) {} std::vector PrecisionCast::forward( - const std::vector& inputs) { - std::vector outputs; - for (const auto& input : inputs) { - auto output = input.astype(targetType_); - outputs.push_back(output); - } - return outputs; + const std::vector& inputs +) { + std::vector outputs; + for(const auto& input : 
inputs) { + auto output = input.astype(targetType_); + outputs.push_back(output); + } + return outputs; } Variable PrecisionCast::forward(const Variable& input) { - return forward(std::vector{input}).front(); + return forward(std::vector{input}).front(); } Variable PrecisionCast::operator()(const Variable& input) { - return this->forward(input); + return this->forward(input); } std::unique_ptr PrecisionCast::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string PrecisionCast::prettyString() const { - std::ostringstream ss; - ss << "PrecisionCast"; - ss << " * -> " << targetType_; - return ss.str(); + std::ostringstream ss; + ss << "PrecisionCast"; + ss << " * -> " << targetType_; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/nn/modules/PrecisionCast.h b/flashlight/fl/nn/modules/PrecisionCast.h index d37597d..0070b5b 100644 --- a/flashlight/fl/nn/modules/PrecisionCast.h +++ b/flashlight/fl/nn/modules/PrecisionCast.h @@ -20,34 +20,34 @@ namespace fl { * other attributes of the input variable unchanged. */ class FL_API PrecisionCast : public Module { - private: - fl::dtype targetType_; - PrecisionCast() = default; - FL_SAVE_LOAD_WITH_BASE(Module, targetType_) - - public: - /** - * Constructor of the Cast Module (PrecisionCast class). - * - * @param targetType An ArrayFire type that specifies the target type of the - * cast. Inputs to the the `forward` function will be casted to `targetType`. - */ - explicit PrecisionCast(fl::dtype targetType); - - /** - * Casts every input variable according to the `targetType_`. The value of - * `targetType_` is set during the initialization of the module. - * - * @param inputs A reference to the vector containing the input variables. - * - * @return A vector that contains the casted variables. 
- */ - std::vector forward(const std::vector& inputs) override; - - Variable forward(const Variable& input); - Variable operator()(const Variable& input); - std::unique_ptr clone() const override; - std::string prettyString() const override; +private: + fl::dtype targetType_; + PrecisionCast() = default; + FL_SAVE_LOAD_WITH_BASE(Module, targetType_) + +public: + /** + * Constructor of the Cast Module (PrecisionCast class). + * + * @param targetType An ArrayFire type that specifies the target type of the + * cast. Inputs to the the `forward` function will be casted to `targetType`. + */ + explicit PrecisionCast(fl::dtype targetType); + + /** + * Casts every input variable according to the `targetType_`. The value of + * `targetType_` is set during the initialization of the module. + * + * @param inputs A reference to the vector containing the input variables. + * + * @return A vector that contains the casted variables. + */ + std::vector forward(const std::vector& inputs) override; + + Variable forward(const Variable& input); + Variable operator()(const Variable& input); + std::unique_ptr clone() const override; + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/nn/modules/RNN.cpp b/flashlight/fl/nn/modules/RNN.cpp index 7922706..b425137 100644 --- a/flashlight/fl/nn/modules/RNN.cpp +++ b/flashlight/fl/nn/modules/RNN.cpp @@ -22,149 +22,159 @@ RNN::RNN( int num_layers, RnnMode mode, bool bidirectional /* false */, - float drop_prob /* = 0.0 */) - : inputSize_(input_size), - hiddenSize_(hidden_size), - numLayers_(num_layers), - mode_(mode), - bidirectional_(bidirectional), - dropProb_(drop_prob) { - initialize(); + float drop_prob /* = 0.0 */ +) : inputSize_(input_size), + hiddenSize_(hidden_size), + numLayers_(num_layers), + mode_(mode), + bidirectional_(bidirectional), + dropProb_(drop_prob) { + initialize(); } -RNN::RNN(const RNN& other) - : Module(other.copyParams()), - inputSize_(other.inputSize_), - 
hiddenSize_(other.hiddenSize_), - numLayers_(other.numLayers_), - mode_(other.mode_), - bidirectional_(other.bidirectional_), - dropProb_(other.dropProb_) { - train_ = other.train_; +RNN::RNN(const RNN& other) : Module(other.copyParams()), + inputSize_(other.inputSize_), + hiddenSize_(other.hiddenSize_), + numLayers_(other.numLayers_), + mode_(other.mode_), + bidirectional_(other.bidirectional_), + dropProb_(other.dropProb_) { + train_ = other.train_; } RNN& RNN::operator=(const RNN& other) { - params_ = other.copyParams(); - train_ = other.train_; - inputSize_ = other.inputSize_; - hiddenSize_ = other.hiddenSize_; - numLayers_ = other.numLayers_; - mode_ = other.mode_; - bidirectional_ = other.bidirectional_; - dropProb_ = other.dropProb_; - return *this; + params_ = other.copyParams(); + train_ = other.train_; + inputSize_ = other.inputSize_; + hiddenSize_ = other.hiddenSize_; + numLayers_ = other.numLayers_; + mode_ = other.mode_; + bidirectional_ = other.bidirectional_; + dropProb_ = other.dropProb_; + return *this; } void RNN::initialize() { - int64_t n_params = detail::getNumRnnParams( - inputSize_, hiddenSize_, numLayers_, mode_, bidirectional_); - - double stdv = std::sqrt(1.0 / static_cast(hiddenSize_)); - auto w = uniform({n_params}, -stdv, stdv, fl::dtype::f32, true); - params_ = {w}; + int64_t n_params = detail::getNumRnnParams( + inputSize_, + hiddenSize_, + numLayers_, + mode_, + bidirectional_ + ); + + double stdv = std::sqrt(1.0 / static_cast(hiddenSize_)); + auto w = uniform({n_params}, -stdv, stdv, fl::dtype::f32, true); + params_ = {w}; } std::vector RNN::forward(const std::vector& inputs) { - if (inputs.empty() || inputs.size() > 3) { - throw std::invalid_argument("Invalid inputs size"); - } - - const auto& input = inputs[0]; - const auto& hiddenState = inputs.size() >= 2 ? inputs[1] : Variable(); - const auto& cellState = inputs.size() == 3 ? inputs[2] : Variable(); - - float dropProb = train_ ? 
dropProb_ : 0.0; - auto rnnRes = - rnn(input, - hiddenState.astype(input.type()), - cellState.astype(input.type()), - params_[0].astype(input.type()), - hiddenSize_, - numLayers_, - mode_, - bidirectional_, - dropProb); - - std::vector output(1, std::get<0>(rnnRes)); - if (inputs.size() >= 2) { - output.push_back(std::get<1>(rnnRes)); - } - if (inputs.size() == 3) { - output.push_back(std::get<2>(rnnRes)); - } - return output; + if(inputs.empty() || inputs.size() > 3) { + throw std::invalid_argument("Invalid inputs size"); + } + + const auto& input = inputs[0]; + const auto& hiddenState = inputs.size() >= 2 ? inputs[1] : Variable(); + const auto& cellState = inputs.size() == 3 ? inputs[2] : Variable(); + + float dropProb = train_ ? dropProb_ : 0.0; + auto rnnRes = + rnn( + input, + hiddenState.astype(input.type()), + cellState.astype(input.type()), + params_[0].astype(input.type()), + hiddenSize_, + numLayers_, + mode_, + bidirectional_, + dropProb + ); + + std::vector output(1, std::get<0>(rnnRes)); + if(inputs.size() >= 2) { + output.push_back(std::get<1>(rnnRes)); + } + if(inputs.size() == 3) { + output.push_back(std::get<2>(rnnRes)); + } + return output; } Variable RNN::forward(const Variable& input) { - return forward(std::vector{input}).front(); + return forward(std::vector{input}).front(); } Variable RNN::operator()(const Variable& input) { - return forward(input); + return forward(input); } std::tuple RNN::forward( const Variable& input, - const Variable& hidden_state) { - auto res = forward(std::vector{input, hidden_state}); - return std::make_tuple(res[0], res[1]); + const Variable& hidden_state +) { + auto res = forward(std::vector{input, hidden_state}); + return std::make_tuple(res[0], res[1]); } std::tuple RNN::operator()( const Variable& input, - const Variable& hidden_state) { - return forward(input, hidden_state); + const Variable& hidden_state +) { + return forward(input, hidden_state); } std::tuple RNN::forward( const Variable& input, const 
Variable& hidden_state, - const Variable& cell_state) { - auto res = forward(std::vector{input, hidden_state, cell_state}); - return std::make_tuple(res[0], res[1], res[2]); + const Variable& cell_state +) { + auto res = forward(std::vector{input, hidden_state, cell_state}); + return std::make_tuple(res[0], res[1], res[2]); } std::tuple RNN::operator()( const Variable& input, const Variable& hidden_state, - const Variable& cell_state) { - return forward(input, hidden_state, cell_state); + const Variable& cell_state +) { + return forward(input, hidden_state, cell_state); } std::unique_ptr RNN::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string RNN::prettyString() const { - std::ostringstream ss; - switch (mode_) { - case RnnMode::RELU: - ss << "RNN (relu)"; - break; - case RnnMode::TANH: - ss << "RNN (tanh)"; - break; - case RnnMode::LSTM: - ss << "LSTM"; - break; - case RnnMode::GRU: - ss << "GRU"; - break; - default: - break; - } - int output_size = bidirectional_ ? 2 * hiddenSize_ : hiddenSize_; - ss << " (" << inputSize_ << "->" << output_size << ")"; - if (numLayers_ > 1) { - ss << " (" << numLayers_ << "-layer)"; - } - if (bidirectional_) { - ss << " (bidirectional)"; - } - if (dropProb_ > 0) { - ss << " (dropout=" << dropProb_ << ")"; - } - return ss.str(); + std::ostringstream ss; + switch(mode_) { + case RnnMode::RELU: + ss << "RNN (relu)"; + break; + case RnnMode::TANH: + ss << "RNN (tanh)"; + break; + case RnnMode::LSTM: + ss << "LSTM"; + break; + case RnnMode::GRU: + ss << "GRU"; + break; + default: + break; + } + int output_size = bidirectional_ ? 
2 * hiddenSize_ : hiddenSize_; + ss << " (" << inputSize_ << "->" << output_size << ")"; + if(numLayers_ > 1) { + ss << " (" << numLayers_ << "-layer)"; + } + if(bidirectional_) { + ss << " (bidirectional)"; + } + if(dropProb_ > 0) { + ss << " (dropout=" << dropProb_ << ")"; + } + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/nn/modules/RNN.h b/flashlight/fl/nn/modules/RNN.h index 07d83da..40e4b48 100644 --- a/flashlight/fl/nn/modules/RNN.h +++ b/flashlight/fl/nn/modules/RNN.h @@ -41,127 +41,134 @@ namespace fl { * assumed to be zero. */ class FL_API RNN : public Module { - private: - RNN() = default; // Intentionally private - - int inputSize_; - int hiddenSize_; - int numLayers_; - RnnMode mode_; - bool bidirectional_; - float dropProb_; - - FL_SAVE_LOAD_WITH_BASE( - Module, - inputSize_, - hiddenSize_, - numLayers_, - mode_, - bidirectional_, - dropProb_) - - void initialize(); - - public: - /** Construct an RNN layer. - * @param input_size The dimension of the input (e.g. \f$X_{in}\f$) - * @param hidden_size The hidden dimension of the RNN. - * @param num_layers The number of recurrent layers. - * @param mode The RNN mode to use. Can be any of: - * - RELU - * - TANH - * - LSTM - * - GRU - * @param bidirectional Whether or not the RNN is bidirectional. If `true` the - * output dimension will be doubled. - * @param drop_prob The probability of dropout after each RNN layer except the - * last layer. - */ - RNN(int input_size, - int hidden_size, - int num_layers, - RnnMode mode, - bool bidirectional = false, - float drop_prob = 0.0); - - /** - * Constructs an RNN module from another, performing a copy of the - * parameters. - * - * @param other The RNN module to copy from. - */ - RNN(const RNN& other); - - /** - * Constructs an RNN module from another, performing a copy of the - * parameters. - * - * @param other The RNN module to copy from. 
- */ - RNN& operator=(const RNN& other); - - RNN(RNN&& other) = default; - - RNN& operator=(RNN&& other) = default; - - std::vector forward(const std::vector& inputs) override; - - using Module::operator(); - - /** Forward the RNN Layer. - * @param input Should be of shape [\f$X_{in}\f$, \f$N\f$, \f$T\f$] - * @returns a single output Variable with shape [\f$X_{out}\f$, \f$N\f$, - * \f$T\f$] - */ - Variable forward(const Variable& input); - - Variable operator()(const Variable& input); - - /** Forward the RNN Layer. - * @param input Should be of shape [\f$X_{in}\f$, \f$N\f$, \f$T\f$] - * @param hidden_state Should be of shape [\f$X_{out}\f$, \f$N\f$]. If an - * empty Variable is passed in then the hidden state is assumed zero. - * @returns An tuple of output Variables. - * - The first element is the output of the RNN of shape [\f$X_{out}\f$, - * \f$N\f$, \f$T\f$] - * - The second element is the hidden state of the RNN of shape - * [\f$X_{out}\f$, \f$N\f$] - */ - std::tuple forward( - const Variable& input, - const Variable& hidden_state); - - std::tuple operator()( - const Variable& input, - const Variable& hidden_state); - - /** Forward the RNN Layer. - * @param input Should be of shape [\f$X_{in}\f$, \f$N\f$, \f$T\f$] - * @param hidden_state Should be of shape [\f$X_{out}\f$, \f$N\f$]. If an - * empty Variable is passed in then the hidden state is assumed zero. - * @param cell_state Should be of shape [\f$X_{out}\f$, \f$N\f$]. If an empty - * Variable is passed in then the hidden state is assumed zero. - * @returns An tuple of output Variables. 
- * - The first element is the output of the RNN of shape [\f$X_{out}\f$, - * \f$N\f$, \f$T\f$] - * - The second element is the hidden state of the RNN of shape - * [\f$X_{out}\f$, \f$N\f$] - * - The third element is the cell state of the RNN of shape [\f$X_{out}\f$, - * \f$N\f$] - */ - std::tuple forward( - const Variable& input, - const Variable& hidden_state, - const Variable& cell_state); - - std::tuple operator()( - const Variable& input, - const Variable& hidden_state, - const Variable& cell_state); - - std::unique_ptr clone() const override; - - std::string prettyString() const override; +private: + RNN() = default; // Intentionally private + + int inputSize_; + int hiddenSize_; + int numLayers_; + RnnMode mode_; + bool bidirectional_; + float dropProb_; + + FL_SAVE_LOAD_WITH_BASE( + Module, + inputSize_, + hiddenSize_, + numLayers_, + mode_, + bidirectional_, + dropProb_ + ) + + void initialize(); + +public: + /** Construct an RNN layer. + * @param input_size The dimension of the input (e.g. \f$X_{in}\f$) + * @param hidden_size The hidden dimension of the RNN. + * @param num_layers The number of recurrent layers. + * @param mode The RNN mode to use. Can be any of: + * - RELU + * - TANH + * - LSTM + * - GRU + * @param bidirectional Whether or not the RNN is bidirectional. If `true` the + * output dimension will be doubled. + * @param drop_prob The probability of dropout after each RNN layer except the + * last layer. + */ + RNN( + int input_size, + int hidden_size, + int num_layers, + RnnMode mode, + bool bidirectional = false, + float drop_prob = 0.0 + ); + + /** + * Constructs an RNN module from another, performing a copy of the + * parameters. + * + * @param other The RNN module to copy from. + */ + RNN(const RNN& other); + + /** + * Constructs an RNN module from another, performing a copy of the + * parameters. + * + * @param other The RNN module to copy from. 
+ */ + RNN& operator=(const RNN& other); + + RNN(RNN&& other) = default; + + RNN& operator=(RNN&& other) = default; + + std::vector forward(const std::vector& inputs) override; + + using Module::operator(); + + /** Forward the RNN Layer. + * @param input Should be of shape [\f$X_{in}\f$, \f$N\f$, \f$T\f$] + * @returns a single output Variable with shape [\f$X_{out}\f$, \f$N\f$, + * \f$T\f$] + */ + Variable forward(const Variable& input); + + Variable operator()(const Variable& input); + + /** Forward the RNN Layer. + * @param input Should be of shape [\f$X_{in}\f$, \f$N\f$, \f$T\f$] + * @param hidden_state Should be of shape [\f$X_{out}\f$, \f$N\f$]. If an + * empty Variable is passed in then the hidden state is assumed zero. + * @returns An tuple of output Variables. + * - The first element is the output of the RNN of shape [\f$X_{out}\f$, + * \f$N\f$, \f$T\f$] + * - The second element is the hidden state of the RNN of shape + * [\f$X_{out}\f$, \f$N\f$] + */ + std::tuple forward( + const Variable& input, + const Variable& hidden_state + ); + + std::tuple operator()( + const Variable& input, + const Variable& hidden_state + ); + + /** Forward the RNN Layer. + * @param input Should be of shape [\f$X_{in}\f$, \f$N\f$, \f$T\f$] + * @param hidden_state Should be of shape [\f$X_{out}\f$, \f$N\f$]. If an + * empty Variable is passed in then the hidden state is assumed zero. + * @param cell_state Should be of shape [\f$X_{out}\f$, \f$N\f$]. If an empty + * Variable is passed in then the hidden state is assumed zero. + * @returns An tuple of output Variables. 
+ * - The first element is the output of the RNN of shape [\f$X_{out}\f$, + * \f$N\f$, \f$T\f$] + * - The second element is the hidden state of the RNN of shape + * [\f$X_{out}\f$, \f$N\f$] + * - The third element is the cell state of the RNN of shape [\f$X_{out}\f$, + * \f$N\f$] + */ + std::tuple forward( + const Variable& input, + const Variable& hidden_state, + const Variable& cell_state + ); + + std::tuple operator()( + const Variable& input, + const Variable& hidden_state, + const Variable& cell_state + ); + + std::unique_ptr clone() const override; + + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/nn/modules/Reorder.cpp b/flashlight/fl/nn/modules/Reorder.cpp index 724edec..e609c9c 100644 --- a/flashlight/fl/nn/modules/Reorder.cpp +++ b/flashlight/fl/nn/modules/Reorder.cpp @@ -18,23 +18,24 @@ namespace fl { Reorder::Reorder(Shape shape) : shape_(std::move(shape)) {} Variable Reorder::forward(const Variable& input) { - if (input.ndim() != shape_.ndim()) { - throw std::invalid_argument( - "Reorder::forward - input tensor has different " - "number of dimensions than reorder shape."); - } - return reorder(input, shape_); + if(input.ndim() != shape_.ndim()) { + throw std::invalid_argument( + "Reorder::forward - input tensor has different " + "number of dimensions than reorder shape." 
+ ); + } + return reorder(input, shape_); } std::unique_ptr Reorder::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string Reorder::prettyString() const { - std::ostringstream ss; - ss << "Reorder"; - ss << " (" << shape_ << ")"; - return ss.str(); + std::ostringstream ss; + ss << "Reorder"; + ss << " (" << shape_ << ")"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/nn/modules/Reorder.h b/flashlight/fl/nn/modules/Reorder.h index 3a9c002..226725b 100644 --- a/flashlight/fl/nn/modules/Reorder.h +++ b/flashlight/fl/nn/modules/Reorder.h @@ -26,27 +26,27 @@ namespace fl { * \endcode */ class FL_API Reorder : public UnaryModule { - private: - Reorder() = default; +private: + Reorder() = default; - Shape shape_; + Shape shape_; - FL_SAVE_LOAD_WITH_BASE(UnaryModule, shape_) + FL_SAVE_LOAD_WITH_BASE(UnaryModule, shape_) - public: - /** - * Construct a Reorder layer. The dimension values must not repeat and must - * be between 0 and 3 inclusive. - * - * @param shape The shape to which the input will be reshaped. - */ - explicit Reorder(Shape shape); +public: + /** + * Construct a Reorder layer. The dimension values must not repeat and must + * be between 0 and 3 inclusive. + * + * @param shape The shape to which the input will be reshaped. 
+ */ + explicit Reorder(Shape shape); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::string prettyString() const override; + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/nn/modules/Transform.cpp b/flashlight/fl/nn/modules/Transform.cpp index 632e873..ccf571e 100644 --- a/flashlight/fl/nn/modules/Transform.cpp +++ b/flashlight/fl/nn/modules/Transform.cpp @@ -13,21 +13,22 @@ namespace fl { Transform::Transform( const std::function& func, - const std::string& name /* = "" */) - : func_(func), name_(name) {} + const std::string& name /* = "" */ +) : func_(func), + name_(name) {} Variable Transform::forward(const Variable& input) { - return func_(input); + return func_(input); } std::unique_ptr Transform::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string Transform::prettyString() const { - std::ostringstream ss; - ss << "Transform ('" << name_ << "')"; - return ss.str(); + std::ostringstream ss; + ss << "Transform ('" << name_ << "')"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/nn/modules/Transform.h b/flashlight/fl/nn/modules/Transform.h index 5a790a8..0b0ea4c 100644 --- a/flashlight/fl/nn/modules/Transform.h +++ b/flashlight/fl/nn/modules/Transform.h @@ -24,45 +24,46 @@ namespace fl { * Note this module cannot be serialized. */ class FL_API Transform : public UnaryModule { - private: - Transform() = default; // Intentionally private +private: + Transform() = default; // Intentionally private - std::function func_; + std::function func_; - std::string name_; + std::string name_; - /** - * Transform layers cannot be serialized. This function throws a runtime - * exception. - */ - FL_SAVE_LOAD_DECLARE() + /** + * Transform layers cannot be serialized. This function throws a runtime + * exception. 
+ */ + FL_SAVE_LOAD_DECLARE() - public: - /** - * Construct a Transform (lambda) layer. - * @param func a lambda function which accepts an input Variable and returns - * an output Variable. - * @param name an optional name used by prettyString. - */ - explicit Transform( - const std::function& func, - const std::string& name = ""); +public: + /** + * Construct a Transform (lambda) layer. + * @param func a lambda function which accepts an input Variable and returns + * an output Variable. + * @param name an optional name used by prettyString. + */ + explicit Transform( + const std::function& func, + const std::string& name = "" + ); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::string prettyString() const override; + std::string prettyString() const override; }; -template +template void Transform::save(Archive& /* ar */, const uint32_t /* version */) const { - throw std::runtime_error("Transform module does not support serialization"); + throw std::runtime_error("Transform module does not support serialization"); } -template +template void Transform::load(Archive& /* ar */, const uint32_t /* version */) { - throw std::runtime_error("Transform module does not support serialization"); + throw std::runtime_error("Transform module does not support serialization"); } } // namespace fl diff --git a/flashlight/fl/nn/modules/View.cpp b/flashlight/fl/nn/modules/View.cpp index 5c379d8..3158e36 100644 --- a/flashlight/fl/nn/modules/View.cpp +++ b/flashlight/fl/nn/modules/View.cpp @@ -18,18 +18,18 @@ namespace fl { View::View(Shape dims) : dims_(std::move(dims)) {} Variable View::forward(const Variable& input) { - Shape dims = dims_; - return moddims(input, dims); + Shape dims = dims_; + return moddims(input, dims); } std::unique_ptr View::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } 
std::string View::prettyString() const { - std::ostringstream ss; - ss << "View (" << dims_ << ")"; - return ss.str(); + std::ostringstream ss; + ss << "View (" << dims_ << ")"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/nn/modules/View.h b/flashlight/fl/nn/modules/View.h index 5f82f72..640fb87 100644 --- a/flashlight/fl/nn/modules/View.h +++ b/flashlight/fl/nn/modules/View.h @@ -25,28 +25,28 @@ namespace fl { * tensor will have shape `(120, 20, 100)`. */ class FL_API View : public UnaryModule { - private: - View() = default; // Intentionally private +private: + View() = default; // Intentionally private - Shape dims_; + Shape dims_; - FL_SAVE_LOAD_WITH_BASE(UnaryModule, dims_) + FL_SAVE_LOAD_WITH_BASE(UnaryModule, dims_) - public: - /** - * Creates a `View` with the given dimensions. - * - * @param dims an `Shape` representing the dimensions of the `View`. - */ - explicit View(Shape dims); +public: + /** + * Creates a `View` with the given dimensions. + * + * @param dims an `Shape` representing the dimensions of the `View`. 
+ */ + explicit View(Shape dims); - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::string prettyString() const override; + std::string prettyString() const override; - ~View() = default; + ~View() = default; }; } // namespace fl diff --git a/flashlight/fl/nn/modules/WeightNorm.cpp b/flashlight/fl/nn/modules/WeightNorm.cpp index 70af468..0df4f83 100644 --- a/flashlight/fl/nn/modules/WeightNorm.cpp +++ b/flashlight/fl/nn/modules/WeightNorm.cpp @@ -13,114 +13,114 @@ namespace fl { -WeightNorm::WeightNorm(const WeightNorm& other) - : module_(other.module_->clone()), - dim_(other.dim_), - normDim_(other.normDim_) { - initParams(); +WeightNorm::WeightNorm(const WeightNorm& other) : module_(other.module_->clone()), + dim_(other.dim_), + normDim_(other.normDim_) { + initParams(); } WeightNorm& WeightNorm::operator=(const WeightNorm& other) { - module_ = other.clone(); - dim_ = other.dim_; - normDim_ = other.normDim_; - initParams(); - return *this; + module_ = other.clone(); + dim_ = other.dim_; + normDim_ = other.normDim_; + initParams(); + return *this; } void WeightNorm::transformDims() { - normDim_.clear(); - int vNumdims = module_->param(0).ndim(); - if (dim_ < 0 || dim_ > vNumdims) { - throw std::invalid_argument("invalid dimension for WeightNorm"); - } - for (int i = 0; i < vNumdims; i++) { - if (i != dim_) { - normDim_.push_back(i); + normDim_.clear(); + int vNumdims = module_->param(0).ndim(); + if(dim_ < 0 || dim_ > vNumdims) { + throw std::invalid_argument("invalid dimension for WeightNorm"); + } + for(int i = 0; i < vNumdims; i++) { + if(i != dim_) { + normDim_.push_back(i); + } } - } } void WeightNorm::computeWeight() { - auto v = params_[0]; - auto g = params_[1]; - Variable nm; - // speed of norm operation is the best while doing it across {1} dim - // tested for convlm model training - if (dim_ == 0) { - nm = 
moddims(v, {0, -1}); - nm = norm(nm, {1}, /* p = */ 2, /* keepDims = */ true); - } else if (dim_ == 3) { - // TODO{fl::Tensor}{enforce 4D parameters from child module?} - nm = moddims(v, {-1, 1, 1, 0}); - nm = reorder(nm, {3, 0, 1, 2}); - nm = norm(nm, {1}, /* p = */ 2, /* keepDims = */ true); - nm = reorder(nm, {1, 2, 3, 0}); - } else { - throw std::invalid_argument( - "Wrong dimension for Weight Norm: " + std::to_string(dim_)); - } - auto wt = v * tileAs(g / nm, v); - module_->setParams(wt, 0); + auto v = params_[0]; + auto g = params_[1]; + Variable nm; + // speed of norm operation is the best while doing it across {1} dim + // tested for convlm model training + if(dim_ == 0) { + nm = moddims(v, {0, -1}); + nm = norm(nm, {1}, /* p = */ 2, /* keepDims = */ true); + } else if(dim_ == 3) { + // TODO{fl::Tensor}{enforce 4D parameters from child module?} + nm = moddims(v, {-1, 1, 1, 0}); + nm = reorder(nm, {3, 0, 1, 2}); + nm = norm(nm, {1}, /* p = */ 2, /* keepDims = */ true); + nm = reorder(nm, {1, 2, 3, 0}); + } else { + throw std::invalid_argument( + "Wrong dimension for Weight Norm: " + std::to_string(dim_) + ); + } + auto wt = v * tileAs(g / nm, v); + module_->setParams(wt, 0); } void WeightNorm::initParams() { - auto moduleParams = module_->params(); - auto& v = moduleParams.at(0); - Variable g( - norm(v, normDim_, /* p = */ 2, /* keepDims = */ true).tensor(), true); - if (moduleParams.size() == 2) { - auto& b = moduleParams[1]; - params_ = {v, g, b}; - } else if (moduleParams.size() == 1) { - params_ = {v, g}; - } else { - throw std::invalid_argument("WeightNorm only supports Linear and Conv2D"); - } + auto moduleParams = module_->params(); + auto& v = moduleParams.at(0); + Variable g( + norm(v, normDim_, /* p = */ 2, /* keepDims = */ true).tensor(), true); + if(moduleParams.size() == 2) { + auto& b = moduleParams[1]; + params_ = {v, g, b}; + } else if(moduleParams.size() == 1) { + params_ = {v, g}; + } else { + throw std::invalid_argument("WeightNorm only 
supports Linear and Conv2D"); + } } void WeightNorm::setParams(const Variable& var, int position) { - Module::setParams(var, position); - // it is necessary to copy all params to the parent module - // due to copies stored in the parent module (not pointers) - if (position == 2) { - module_->setParams(var, 1); - } else if (position <= 1) { - computeWeight(); - } + Module::setParams(var, position); + // it is necessary to copy all params to the parent module + // due to copies stored in the parent module (not pointers) + if(position == 2) { + module_->setParams(var, 1); + } else if(position <= 1) { + computeWeight(); + } } std::vector WeightNorm::forward(const std::vector& inputs) { - if (train_) { - computeWeight(); - } - return module_->forward(inputs); + if(train_) { + computeWeight(); + } + return module_->forward(inputs); } ModulePtr WeightNorm::module() const { - return module_; + return module_; } void WeightNorm::train() { - Module::train(); - module_->train(); + Module::train(); + module_->train(); } void WeightNorm::eval() { - Module::eval(); - module_->eval(); - computeWeight(); + Module::eval(); + module_->eval(); + computeWeight(); } std::unique_ptr WeightNorm::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::string WeightNorm::prettyString() const { - std::ostringstream ss; - ss << "WeightNorm"; - ss << " (" << module_->prettyString() << ", " << dim_ << ")"; - return ss.str(); + std::ostringstream ss; + ss << "WeightNorm"; + ss << " (" << module_->prettyString() << ", " << dim_ << ")"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/nn/modules/WeightNorm.h b/flashlight/fl/nn/modules/WeightNorm.h index 09aaeaf..5b8976b 100644 --- a/flashlight/fl/nn/modules/WeightNorm.h +++ b/flashlight/fl/nn/modules/WeightNorm.h @@ -32,99 +32,100 @@ namespace fl { * https://arxiv.org/abs/1602.07868) */ class FL_API WeightNorm : public Module { - private: - WeightNorm() = default; +private: + WeightNorm() = 
default; - std::shared_ptr module_; + std::shared_ptr module_; - // Computes the norm over all dimensions except dim_ - int dim_; - std::vector normDim_; + // Computes the norm over all dimensions except dim_ + int dim_; + std::vector normDim_; - void transformDims(); + void transformDims(); - void computeWeight(); + void computeWeight(); - void initParams(); + void initParams(); - FL_SAVE_LOAD_DECLARE() + FL_SAVE_LOAD_DECLARE() - public: - /** Construct a WeightNorm layer. - * @param module A module to wrap (must be one of Linear or Conv2D). Takes - * ownership of the module. - * @param dim The dimension to normalize. - */ - template - WeightNorm(T&& module, int dim) - : WeightNorm( - std::make_shared>(std::forward(module)), - dim) {} +public: + /** Construct a WeightNorm layer. + * @param module A module to wrap (must be one of Linear or Conv2D). Takes + * ownership of the module. + * @param dim The dimension to normalize. + */ + template + WeightNorm(T&& module, int dim) : WeightNorm( + std::make_shared>(std::forward(module)), + dim + ) {} - /** Construct a WeightNorm layer. - * @param module Shared pointer to a module to wrap (the module must be one - * of Linear or Conv2D) - * @param dim The dimension to normalize. - */ - template - WeightNorm(std::shared_ptr module, int dim) : module_(module), dim_(dim) { - transformDims(); - initParams(); - } + /** Construct a WeightNorm layer. + * @param module Shared pointer to a module to wrap (the module must be one + * of Linear or Conv2D) + * @param dim The dimension to normalize. + */ + template + WeightNorm(std::shared_ptr module, int dim) : module_(module), + dim_(dim) { + transformDims(); + initParams(); + } - /** - * Construct a WeightNorm module from another, performing a deep copy for the - * wrapped module. - * - * @param other The WeightNorm module to copy from. - */ - WeightNorm(const WeightNorm& other); + /** + * Construct a WeightNorm module from another, performing a deep copy for the + * wrapped module. 
+ * + * @param other The WeightNorm module to copy from. + */ + WeightNorm(const WeightNorm& other); - /** - * Construct a WeightNorm module from another, performing a deep copy for the - * wrapped module. - * - * @param other The WeightNorm module to copy from. - */ - WeightNorm& operator=(const WeightNorm& other); + /** + * Construct a WeightNorm module from another, performing a deep copy for the + * wrapped module. + * + * @param other The WeightNorm module to copy from. + */ + WeightNorm& operator=(const WeightNorm& other); - WeightNorm(WeightNorm&& other) = default; + WeightNorm(WeightNorm&& other) = default; - WeightNorm& operator=(WeightNorm&& other) = default; + WeightNorm& operator=(WeightNorm&& other) = default; - /** - * Returns a pointer to the inner `Module` normalized by this `WeightNorm`. - * - * @return a module pointer. - */ - ModulePtr module() const; + /** + * Returns a pointer to the inner `Module` normalized by this `WeightNorm`. + * + * @return a module pointer. + */ + ModulePtr module() const; - void train() override; + void train() override; - void eval() override; + void eval() override; - void setParams(const Variable& var, int position) override; + void setParams(const Variable& var, int position) override; - std::vector forward(const std::vector& inputs) override; + std::vector forward(const std::vector& inputs) override; - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::string prettyString() const override; + std::string prettyString() const override; }; -template +template void WeightNorm::save(Archive& ar, const uint32_t /* version */) const { - // Not saving weight since it can be inferred from from v and g. - auto wt = module_->param(0); - module_->setParams(Variable(), 0); - ar(cereal::base_class(this), module_, dim_, normDim_); - module_->setParams(wt, 0); + // Not saving weight since it can be inferred from from v and g. 
+ auto wt = module_->param(0); + module_->setParams(Variable(), 0); + ar(cereal::base_class(this), module_, dim_, normDim_); + module_->setParams(wt, 0); } -template +template void WeightNorm::load(Archive& ar, const uint32_t /* version */) { - ar(cereal::base_class(this), module_, dim_, normDim_); - computeWeight(); + ar(cereal::base_class(this), module_, dim_, normDim_); + computeWeight(); } } // namespace fl diff --git a/flashlight/fl/optim/AMSgradOptimizer.cpp b/flashlight/fl/optim/AMSgradOptimizer.cpp index 54140cf..f64152a 100644 --- a/flashlight/fl/optim/AMSgradOptimizer.cpp +++ b/flashlight/fl/optim/AMSgradOptimizer.cpp @@ -21,69 +21,69 @@ AMSgradOptimizer::AMSgradOptimizer( float beta1 /* = 0.9 */, float beta2 /* = 0.999 */, float epsilon /* = 1e-8 */, - float weightDecay /* = 0 */) - : FirstOrderOptimizer(parameters, learningRate), - beta1_(beta1), - beta2_(beta2), - eps_(epsilon), - wd_(weightDecay), - biasedFirst_(), - biasedSecond_(), - maxExpAvgSq_() { - biasedFirst_.reserve(parameters.size()); - biasedSecond_.reserve(parameters.size()); - maxExpAvgSq_.reserve(parameters.size()); - - for (const auto& parameter : parameters_) { - biasedFirst_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); - biasedSecond_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); - maxExpAvgSq_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); - - fl::eval(biasedFirst_.back()); - fl::eval(biasedSecond_.back()); - fl::eval(maxExpAvgSq_.back()); - } + float weightDecay /* = 0 */ +) : FirstOrderOptimizer(parameters, learningRate), + beta1_(beta1), + beta2_(beta2), + eps_(epsilon), + wd_(weightDecay), + biasedFirst_(), + biasedSecond_(), + maxExpAvgSq_() { + biasedFirst_.reserve(parameters.size()); + biasedSecond_.reserve(parameters.size()); + maxExpAvgSq_.reserve(parameters.size()); + + for(const auto& parameter : parameters_) { + biasedFirst_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); + 
biasedSecond_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); + maxExpAvgSq_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); + + fl::eval(biasedFirst_.back()); + fl::eval(biasedSecond_.back()); + fl::eval(maxExpAvgSq_.back()); + } } void AMSgradOptimizer::step() { - for (size_t i = 0; i < parameters_.size(); i++) { - if (!parameters_[i].isGradAvailable()) { - continue; - } + for(size_t i = 0; i < parameters_.size(); i++) { + if(!parameters_[i].isGradAvailable()) { + continue; + } - const Tensor& grad = parameters_[i].grad().tensor(); - Tensor& data = parameters_[i].tensor(); + const Tensor& grad = parameters_[i].grad().tensor(); + Tensor& data = parameters_[i].tensor(); - if (wd_ != 0) { - data = data - wd_ * data; - } + if(wd_ != 0) { + data = data - wd_ * data; + } - Tensor& biasedFirst = biasedFirst_[i]; - Tensor& biasedSecond = biasedSecond_[i]; - Tensor& maxExpAvgSq = maxExpAvgSq_[i]; + Tensor& biasedFirst = biasedFirst_[i]; + Tensor& biasedSecond = biasedSecond_[i]; + Tensor& maxExpAvgSq = maxExpAvgSq_[i]; - biasedFirst = beta1_ * biasedFirst + (1 - beta1_) * grad; - biasedSecond = beta2_ * biasedSecond + (1 - beta2_) * grad * grad; - maxExpAvgSq = fl::maximum(maxExpAvgSq, biasedSecond); - fl::eval(biasedFirst); - fl::eval(biasedSecond); - fl::eval(maxExpAvgSq); + biasedFirst = beta1_ * biasedFirst + (1 - beta1_) * grad; + biasedSecond = beta2_ * biasedSecond + (1 - beta2_) * grad * grad; + maxExpAvgSq = fl::maximum(maxExpAvgSq, biasedSecond); + fl::eval(biasedFirst); + fl::eval(biasedSecond); + fl::eval(maxExpAvgSq); - data = data - (lr_ * biasedFirst) / (fl::sqrt(maxExpAvgSq) + eps_); + data = data - (lr_ * biasedFirst) / (fl::sqrt(maxExpAvgSq) + eps_); - fl::eval(data); - } + fl::eval(data); + } } std::string AMSgradOptimizer::prettyString() const { - std::ostringstream ss; - ss << "AMSgrad from "; + std::ostringstream ss; + ss << "AMSgrad from "; - if (wd_ != 0) { - ss << " (weight decay=" << wd_ << ")"; - } + if(wd_ != 0) 
{ + ss << " (weight decay=" << wd_ << ")"; + } - return ss.str(); + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/optim/AMSgradOptimizer.h b/flashlight/fl/optim/AMSgradOptimizer.h index 367c57f..ac148b8 100644 --- a/flashlight/fl/optim/AMSgradOptimizer.h +++ b/flashlight/fl/optim/AMSgradOptimizer.h @@ -22,48 +22,48 @@ namespace fl { * https://openreview.net/pdf?id=ryQu7f-RZ). */ class FL_API AMSgradOptimizer : public FirstOrderOptimizer { - private: - FL_SAVE_LOAD_WITH_BASE( - FirstOrderOptimizer, - beta1_, - beta2_, - eps_, - wd_, - biasedFirst_, - biasedSecond_, - maxExpAvgSq_) +private: + FL_SAVE_LOAD_WITH_BASE( + FirstOrderOptimizer, + beta1_, + beta2_, + eps_, + wd_, + biasedFirst_, + biasedSecond_, + maxExpAvgSq_ + ) AMSgradOptimizer() = default; // Intentionally private - AMSgradOptimizer() = default; // Intentionally private + float beta1_; + float beta2_; + float eps_; + float wd_; + std::vector biasedFirst_; + std::vector biasedSecond_; + std::vector maxExpAvgSq_; - float beta1_; - float beta2_; - float eps_; - float wd_; - std::vector biasedFirst_; - std::vector biasedSecond_; - std::vector maxExpAvgSq_; +public: + /** Construct an AMSgrad optimizer + * @param parameters The parameters from e.g. `model.parameters()`. + * @param learningRate The learning rate. + * @param beta1 AMSgrad hyperparameter \f$ \beta_1 \f$. + * @param beta2 AMSgrad hyperparameter \f$ \beta_2 \f$. + * @param epsilon A small value used for numerical stability. + * @param weightDecay The amount of L2 weight decay to use for all the + * parameters. + */ + AMSgradOptimizer( + const std::vector& parameters, + float learningRate, + float beta1 = 0.9, + float beta2 = 0.999, + float epsilon = 1e-8, + float weightDecay = 0 + ); - public: - /** Construct an AMSgrad optimizer - * @param parameters The parameters from e.g. `model.parameters()`. - * @param learningRate The learning rate. - * @param beta1 AMSgrad hyperparameter \f$ \beta_1 \f$. 
- * @param beta2 AMSgrad hyperparameter \f$ \beta_2 \f$. - * @param epsilon A small value used for numerical stability. - * @param weightDecay The amount of L2 weight decay to use for all the - * parameters. - */ - AMSgradOptimizer( - const std::vector& parameters, - float learningRate, - float beta1 = 0.9, - float beta2 = 0.999, - float epsilon = 1e-8, - float weightDecay = 0); + void step() override; - void step() override; - - std::string prettyString() const override; + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/optim/AdadeltaOptimizer.cpp b/flashlight/fl/optim/AdadeltaOptimizer.cpp index 65fd82b..644aef2 100644 --- a/flashlight/fl/optim/AdadeltaOptimizer.cpp +++ b/flashlight/fl/optim/AdadeltaOptimizer.cpp @@ -18,68 +18,68 @@ AdadeltaOptimizer::AdadeltaOptimizer( float learningRate /* = 1.0 */, float rho /* = 0.9 */, float epsilon /* = 1e-8 */, - float weightDecay /* = 0 */) - : FirstOrderOptimizer(parameters, learningRate), - rho_(rho), - eps_(epsilon), - wd_(weightDecay), - accGrad_(), - accDelta_() { - accGrad_.reserve(parameters.size()); - accDelta_.reserve(parameters.size()); - - for (const auto& parameter : parameters_) { - accGrad_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); - accDelta_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); - - fl::eval(accGrad_.back()); - fl::eval(accDelta_.back()); - } + float weightDecay /* = 0 */ +) : FirstOrderOptimizer(parameters, learningRate), + rho_(rho), + eps_(epsilon), + wd_(weightDecay), + accGrad_(), + accDelta_() { + accGrad_.reserve(parameters.size()); + accDelta_.reserve(parameters.size()); + + for(const auto& parameter : parameters_) { + accGrad_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); + accDelta_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); + + fl::eval(accGrad_.back()); + fl::eval(accDelta_.back()); + } } void AdadeltaOptimizer::step() { - for (size_t i = 0; i < parameters_.size(); 
i++) { - if (!parameters_[i].isGradAvailable()) { - continue; - } + for(size_t i = 0; i < parameters_.size(); i++) { + if(!parameters_[i].isGradAvailable()) { + continue; + } - const Tensor& grad = parameters_[i].grad().tensor(); - Tensor& data = parameters_[i].tensor(); + const Tensor& grad = parameters_[i].grad().tensor(); + Tensor& data = parameters_[i].tensor(); - if (wd_ != 0) { - // Weight decay term - data = data - wd_ * data; - } + if(wd_ != 0) { + // Weight decay term + data = data - wd_ * data; + } - Tensor& accGrad = accGrad_[i]; - Tensor& accDelta = accDelta_[i]; + Tensor& accGrad = accGrad_[i]; + Tensor& accDelta = accDelta_[i]; - accGrad = rho_ * accGrad + (1 - rho_) * grad * grad; - fl::eval(accGrad); + accGrad = rho_ * accGrad + (1 - rho_) * grad * grad; + fl::eval(accGrad); - auto delta = fl::sqrt(accDelta + eps_) / fl::sqrt(accGrad + eps_) * grad; + auto delta = fl::sqrt(accDelta + eps_) / fl::sqrt(accGrad + eps_) * grad; - data = data - lr_ * delta; - fl::eval(data); + data = data - lr_ * delta; + fl::eval(data); - accDelta = rho_ * accDelta + (1 - rho_) * delta * delta; - fl::eval(accDelta); - } + accDelta = rho_ * accDelta + (1 - rho_) * delta * delta; + fl::eval(accDelta); + } } std::string AdadeltaOptimizer::prettyString() const { - std::ostringstream ss; - ss << "Adadelta"; - - if (wd_ != 0) { - ss << " (weight decay=" << wd_ << ")"; - } - ss << " (rho=" << rho_ << ")"; - if (eps_ != 0) { - ss << " (epsilon=" << eps_ << ")"; - } - - return ss.str(); + std::ostringstream ss; + ss << "Adadelta"; + + if(wd_ != 0) { + ss << " (weight decay=" << wd_ << ")"; + } + ss << " (rho=" << rho_ << ")"; + if(eps_ != 0) { + ss << " (epsilon=" << eps_ << ")"; + } + + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/optim/AdadeltaOptimizer.h b/flashlight/fl/optim/AdadeltaOptimizer.h index d75a6e6..9a29d36 100644 --- a/flashlight/fl/optim/AdadeltaOptimizer.h +++ b/flashlight/fl/optim/AdadeltaOptimizer.h @@ -22,43 +22,43 @@ namespace fl { * 
https://arxiv.org/pdf/1212.5701.pdf). */ class FL_API AdadeltaOptimizer : public FirstOrderOptimizer { - private: - FL_SAVE_LOAD_WITH_BASE( - FirstOrderOptimizer, - fl::serializeAs(rho_), - fl::serializeAs(eps_), - fl::serializeAs(wd_), - accGrad_, - accDelta_) +private: + FL_SAVE_LOAD_WITH_BASE( + FirstOrderOptimizer, + fl::serializeAs(rho_), + fl::serializeAs(eps_), + fl::serializeAs(wd_), + accGrad_, + accDelta_ + ) AdadeltaOptimizer() = default; // Intentionally private - AdadeltaOptimizer() = default; // Intentionally private + float rho_; + float eps_; + float wd_; + std::vector accGrad_; + std::vector accDelta_; - float rho_; - float eps_; - float wd_; - std::vector accGrad_; - std::vector accDelta_; +public: + /** Construct an Adadelta optimizer. + * @param parameters The parameters from e.g. `model.parameters()`. + * @param learningRate The learning rate for scaling delta. The original + * paper does not include this term (i.e. learningRate = 1.0). + * @param rho The decay rate for accumulating squared gradients and deltas. + * @param epsilon A small value used for numerical stability. + * @param weightDecay The amount of L2 weight decay to use for all the + * parameters. + */ + explicit AdadeltaOptimizer( + const std::vector& parameters, + float learningRate = 1.0, + float rho = 0.9, + float epsilon = 1e-8, + float weightDecay = 0 + ); - public: - /** Construct an Adadelta optimizer. - * @param parameters The parameters from e.g. `model.parameters()`. - * @param learningRate The learning rate for scaling delta. The original - * paper does not include this term (i.e. learningRate = 1.0). - * @param rho The decay rate for accumulating squared gradients and deltas. - * @param epsilon A small value used for numerical stability. - * @param weightDecay The amount of L2 weight decay to use for all the - * parameters. 
- */ - explicit AdadeltaOptimizer( - const std::vector& parameters, - float learningRate = 1.0, - float rho = 0.9, - float epsilon = 1e-8, - float weightDecay = 0); + void step() override; - void step() override; - - std::string prettyString() const override; + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/optim/AdagradOptimizer.cpp b/flashlight/fl/optim/AdagradOptimizer.cpp index 81731a3..9b2cc25 100644 --- a/flashlight/fl/optim/AdagradOptimizer.cpp +++ b/flashlight/fl/optim/AdagradOptimizer.cpp @@ -17,48 +17,48 @@ AdagradOptimizer::AdagradOptimizer( const std::vector& parameters, float learningRate /* = 1.0 */, float epsilon /* = 1e-8 */, - float weightDecay /* = 0 */) - : FirstOrderOptimizer(parameters, learningRate), - eps_(epsilon), - wd_(weightDecay) { - variance_.reserve(parameters.size()); - for (const auto& param : parameters_) { - variance_.push_back(fl::full(param.shape(), 0, param.type())); - fl::eval(variance_.back()); - } + float weightDecay /* = 0 */ +) : FirstOrderOptimizer(parameters, learningRate), + eps_(epsilon), + wd_(weightDecay) { + variance_.reserve(parameters.size()); + for(const auto& param : parameters_) { + variance_.push_back(fl::full(param.shape(), 0, param.type())); + fl::eval(variance_.back()); + } } void AdagradOptimizer::step() { - for (size_t i = 0; i < parameters_.size(); i++) { - if (!parameters_[i].isGradAvailable()) { - continue; - } + for(size_t i = 0; i < parameters_.size(); i++) { + if(!parameters_[i].isGradAvailable()) { + continue; + } - const Tensor& grad = parameters_[i].grad().tensor(); - Tensor& data = parameters_[i].tensor(); - Tensor& variance = variance_[i]; + const Tensor& grad = parameters_[i].grad().tensor(); + Tensor& data = parameters_[i].tensor(); + Tensor& variance = variance_[i]; - if (wd_ != 0) { - // Weight decay term - data = data - wd_ * data; - } + if(wd_ != 0) { + // Weight decay term + data = data - wd_ * data; + } - variance = variance + grad * grad; - 
fl::eval(variance); - data = data - lr_ * grad / (fl::sqrt(variance) + eps_); - fl::eval(data); - } + variance = variance + grad * grad; + fl::eval(variance); + data = data - lr_ * grad / (fl::sqrt(variance) + eps_); + fl::eval(data); + } } std::string AdagradOptimizer::prettyString() const { - std::ostringstream ss; - ss << "Adagrad"; + std::ostringstream ss; + ss << "Adagrad"; - if (eps_ != 0) { - ss << " (epsilon=" << eps_ << ")"; - } + if(eps_ != 0) { + ss << " (epsilon=" << eps_ << ")"; + } - return ss.str(); + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/optim/AdagradOptimizer.h b/flashlight/fl/optim/AdagradOptimizer.h index 92f6a5a..1fbe188 100644 --- a/flashlight/fl/optim/AdagradOptimizer.h +++ b/flashlight/fl/optim/AdagradOptimizer.h @@ -23,30 +23,29 @@ namespace fl { * http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf). */ class FL_API AdagradOptimizer : public FirstOrderOptimizer { - private: - FL_SAVE_LOAD_WITH_BASE(FirstOrderOptimizer, eps_, wd_, variance_) - - AdagradOptimizer() = default; // Intentionally private - - float eps_; - float wd_; - std::vector variance_; // store sum_{tau=0}^{tau=t} grad_tau*grad_tau - - public: - /** Construct an Adagrad optimizer - * @param parameters The parameters from e.g. `model.parameters()`. - * @param learningRate The learning rate. - * @param epsilon A small value used for numerical stability. - */ - explicit AdagradOptimizer( - const std::vector& parameters, - float learningRate = 1.0, - float epsilon = 1e-8, - float weightDecay = 0); - - void step() override; - - std::string prettyString() const override; +private: + FL_SAVE_LOAD_WITH_BASE(FirstOrderOptimizer, eps_, wd_, variance_) AdagradOptimizer() = default; // Intentionally private + + float eps_; + float wd_; + std::vector variance_; // store sum_{tau=0}^{tau=t} grad_tau*grad_tau + +public: + /** Construct an Adagrad optimizer + * @param parameters The parameters from e.g. `model.parameters()`. 
+ * @param learningRate The learning rate. + * @param epsilon A small value used for numerical stability. + */ + explicit AdagradOptimizer( + const std::vector& parameters, + float learningRate = 1.0, + float epsilon = 1e-8, + float weightDecay = 0 + ); + + void step() override; + + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/optim/AdamOptimizer.cpp b/flashlight/fl/optim/AdamOptimizer.cpp index 39df848..af8d0df 100644 --- a/flashlight/fl/optim/AdamOptimizer.cpp +++ b/flashlight/fl/optim/AdamOptimizer.cpp @@ -21,70 +21,70 @@ AdamOptimizer::AdamOptimizer( float beta1 /* = 0.9 */, float beta2 /* = 0.999 */, float epsilon /* = 1e-8 */, - float weightDecay /* = 0 */) - : FirstOrderOptimizer(parameters, learningRate), - beta1_(beta1), - beta2_(beta2), - eps_(epsilon), - wd_(weightDecay), - count_(0), - biasedFirst_(), - biasedSecond_() { - biasedFirst_.reserve(parameters.size()); - biasedSecond_.reserve(parameters.size()); - - for (const auto& parameter : parameters_) { - biasedFirst_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); - biasedSecond_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); - - fl::eval(biasedFirst_.back()); - fl::eval(biasedSecond_.back()); - } + float weightDecay /* = 0 */ +) : FirstOrderOptimizer(parameters, learningRate), + beta1_(beta1), + beta2_(beta2), + eps_(epsilon), + wd_(weightDecay), + count_(0), + biasedFirst_(), + biasedSecond_() { + biasedFirst_.reserve(parameters.size()); + biasedSecond_.reserve(parameters.size()); + + for(const auto& parameter : parameters_) { + biasedFirst_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); + biasedSecond_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); + + fl::eval(biasedFirst_.back()); + fl::eval(biasedSecond_.back()); + } } void AdamOptimizer::step() { - count_++; - float correctedBias1 = 1 - std::pow(beta1_, count_); - float correctedBias2 = 1 - std::pow(beta2_, count_); - float 
correctedLr = lr_ * std::sqrt(correctedBias2) / correctedBias1; - - for (size_t i = 0; i < parameters_.size(); i++) { - if (!parameters_[i].isGradAvailable()) { - continue; - } + count_++; + float correctedBias1 = 1 - std::pow(beta1_, count_); + float correctedBias2 = 1 - std::pow(beta2_, count_); + float correctedLr = lr_ * std::sqrt(correctedBias2) / correctedBias1; - const Tensor& grad = parameters_[i].grad().tensor(); - Tensor& data = parameters_[i].tensor(); + for(size_t i = 0; i < parameters_.size(); i++) { + if(!parameters_[i].isGradAvailable()) { + continue; + } - if (wd_ != 0) { - // Weight decay term - data = data - wd_ * lr_ * data; - } + const Tensor& grad = parameters_[i].grad().tensor(); + Tensor& data = parameters_[i].tensor(); - Tensor& biasedFirst = biasedFirst_[i]; - Tensor& biasedSecond = biasedSecond_[i]; + if(wd_ != 0) { + // Weight decay term + data = data - wd_ * lr_ * data; + } - biasedFirst = beta1_ * biasedFirst + (1 - beta1_) * grad; - biasedSecond = beta2_ * biasedSecond + (1 - beta2_) * grad * grad; + Tensor& biasedFirst = biasedFirst_[i]; + Tensor& biasedSecond = biasedSecond_[i]; - fl::eval(biasedFirst); - fl::eval(biasedSecond); + biasedFirst = beta1_ * biasedFirst + (1 - beta1_) * grad; + biasedSecond = beta2_ * biasedSecond + (1 - beta2_) * grad * grad; - data = data - (correctedLr * biasedFirst) / (fl::sqrt(biasedSecond) + eps_); + fl::eval(biasedFirst); + fl::eval(biasedSecond); - fl::eval(data); - } + data = data - (correctedLr * biasedFirst) / (fl::sqrt(biasedSecond) + eps_); + + fl::eval(data); + } } std::string AdamOptimizer::prettyString() const { - std::ostringstream ss; - ss << "Adam"; + std::ostringstream ss; + ss << "Adam"; - if (wd_ != 0) { - ss << " (weight decay=" << wd_ << ")"; - } + if(wd_ != 0) { + ss << " (weight decay=" << wd_ << ")"; + } - return ss.str(); + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/optim/AdamOptimizer.h b/flashlight/fl/optim/AdamOptimizer.h index 4919bcc..e223396 100644 
--- a/flashlight/fl/optim/AdamOptimizer.h +++ b/flashlight/fl/optim/AdamOptimizer.h @@ -22,48 +22,48 @@ namespace fl { * https://arxiv.org/abs/1412.6980). */ class FL_API AdamOptimizer : public FirstOrderOptimizer { - private: - FL_SAVE_LOAD_WITH_BASE( - FirstOrderOptimizer, - fl::serializeAs(beta1_), - fl::serializeAs(beta2_), - fl::serializeAs(eps_), - fl::serializeAs(wd_), - count_, - biasedFirst_, - biasedSecond_) +private: + FL_SAVE_LOAD_WITH_BASE( + FirstOrderOptimizer, + fl::serializeAs(beta1_), + fl::serializeAs(beta2_), + fl::serializeAs(eps_), + fl::serializeAs(wd_), + count_, + biasedFirst_, + biasedSecond_ + ) AdamOptimizer() = default; // Intentionally private - AdamOptimizer() = default; // Intentionally private + float beta1_; + float beta2_; + float eps_; + float wd_; + int count_; + std::vector biasedFirst_; + std::vector biasedSecond_; - float beta1_; - float beta2_; - float eps_; - float wd_; - int count_; - std::vector biasedFirst_; - std::vector biasedSecond_; +public: + /** Construct an Adam optimizer. + * @param parameters The parameters from e.g. `model.parameters()`. + * @param learningRate The learning rate. + * @param beta1 Adam hyperparameter \f$ \beta_1 \f$. + * @param beta2 Adam hyperparameter \f$ \beta_2 \f$. + * @param epsilon A small value used for numerical stability. + * @param weightDecay The amount of L2 weight decay to use for all the + * parameters. + */ + AdamOptimizer( + const std::vector& parameters, + float learningRate, + float beta1 = 0.9, + float beta2 = 0.999, + float epsilon = 1e-8, + float weightDecay = 0 + ); - public: - /** Construct an Adam optimizer. - * @param parameters The parameters from e.g. `model.parameters()`. - * @param learningRate The learning rate. - * @param beta1 Adam hyperparameter \f$ \beta_1 \f$. - * @param beta2 Adam hyperparameter \f$ \beta_2 \f$. - * @param epsilon A small value used for numerical stability. 
- * @param weightDecay The amount of L2 weight decay to use for all the - * parameters. - */ - AdamOptimizer( - const std::vector& parameters, - float learningRate, - float beta1 = 0.9, - float beta2 = 0.999, - float epsilon = 1e-8, - float weightDecay = 0); + void step() override; - void step() override; - - std::string prettyString() const override; + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/optim/NAGOptimizer.cpp b/flashlight/fl/optim/NAGOptimizer.cpp index bca9f60..ca33650 100644 --- a/flashlight/fl/optim/NAGOptimizer.cpp +++ b/flashlight/fl/optim/NAGOptimizer.cpp @@ -19,58 +19,59 @@ NAGOptimizer::NAGOptimizer( const vector& parameters, float learningRate, float momentum /* = 0 */, - float weightDecay /* = 0 */) - : FirstOrderOptimizer(parameters, learningRate), - mu_(momentum), - wd_(weightDecay), - velocities_(), - oldLr_(learningRate) { - if (momentum <= 0) { - throw std::runtime_error( - "Invalid momentum for NAG optimizer, it should be > 0"); - } - velocities_.reserve(parameters.size()); - for (const auto& parameter : parameters_) { - velocities_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); - fl::eval(velocities_.back()); - } + float weightDecay /* = 0 */ +) : FirstOrderOptimizer(parameters, learningRate), + mu_(momentum), + wd_(weightDecay), + velocities_(), + oldLr_(learningRate) { + if(momentum <= 0) { + throw std::runtime_error( + "Invalid momentum for NAG optimizer, it should be > 0" + ); + } + velocities_.reserve(parameters.size()); + for(const auto& parameter : parameters_) { + velocities_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); + fl::eval(velocities_.back()); + } } void NAGOptimizer::step() { - float correctedLr = lr_ / oldLr_; + float correctedLr = lr_ / oldLr_; - for (size_t i = 0; i < parameters_.size(); i++) { - if (!parameters_[i].isGradAvailable()) { - continue; - } + for(size_t i = 0; i < parameters_.size(); i++) { + 
if(!parameters_[i].isGradAvailable()) { + continue; + } - Tensor& grad = parameters_[i].grad().tensor(); - Tensor& data = parameters_[i].tensor(); + Tensor& grad = parameters_[i].grad().tensor(); + Tensor& data = parameters_[i].tensor(); - if (wd_ != 0) { - // Weight decay term - data = data * (1 - lr_ * wd_); + if(wd_ != 0) { + // Weight decay term + data = data * (1 - lr_ * wd_); + } + Tensor& velocity = velocities_[i]; + // this velocity corresponds to fairseq velocity * -1 + velocity = mu_ * velocity * correctedLr + lr_ * grad; + fl::eval(velocity); + grad = grad * lr_ + velocity * mu_; + data = data - grad; + fl::eval(data); } - Tensor& velocity = velocities_[i]; - // this velocity corresponds to fairseq velocity * -1 - velocity = mu_ * velocity * correctedLr + lr_ * grad; - fl::eval(velocity); - grad = grad * lr_ + velocity * mu_; - data = data - grad; - fl::eval(data); - } - oldLr_ = lr_; + oldLr_ = lr_; } std::string NAGOptimizer::prettyString() const { - std::ostringstream ss; - ss << "NAG (lr=" << lr_ << " ); (previous lr=" << oldLr_ << ");"; + std::ostringstream ss; + ss << "NAG (lr=" << lr_ << " ); (previous lr=" << oldLr_ << ");"; - if (wd_ != 0) { - ss << " (weight decay=" << wd_ << ");"; - } - ss << " (Nesterov momentum=" << mu_ << ")"; - return ss.str(); + if(wd_ != 0) { + ss << " (weight decay=" << wd_ << ");"; + } + ss << " (Nesterov momentum=" << mu_ << ")"; + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/optim/NAGOptimizer.h b/flashlight/fl/optim/NAGOptimizer.h index 609218b..d7bc673 100644 --- a/flashlight/fl/optim/NAGOptimizer.h +++ b/flashlight/fl/optim/NAGOptimizer.h @@ -18,33 +18,32 @@ namespace fl { * https://github.com/pytorch/fairseq/blob/e75cff5f2c1d62f12dc911e0bf420025eb1a4e33/fairseq/optim/nag.py#L43 */ class FL_API NAGOptimizer : public FirstOrderOptimizer { - private: - FL_SAVE_LOAD_WITH_BASE(FirstOrderOptimizer, mu_, wd_, velocities_, oldLr_) - - NAGOptimizer() = default; // Intentionally private - - float mu_; - 
float wd_; - std::vector velocities_; - float oldLr_; - - public: - /** NAGOptimizer constructor. - * @param parameters The parameters from e.g. `model.parameters()` - * @param learningRate The learning rate. - * @param momentum The momentum. - * @param weightDecay The amount of L2 weight decay to use for all the - * parameters. - */ - NAGOptimizer( - const std::vector& parameters, - float learningRate, - float momentum = 0.99, - float weightDecay = 0); - - void step() override; - - std::string prettyString() const override; +private: + FL_SAVE_LOAD_WITH_BASE(FirstOrderOptimizer, mu_, wd_, velocities_, oldLr_) NAGOptimizer() = default; // Intentionally private + + float mu_; + float wd_; + std::vector velocities_; + float oldLr_; + +public: + /** NAGOptimizer constructor. + * @param parameters The parameters from e.g. `model.parameters()` + * @param learningRate The learning rate. + * @param momentum The momentum. + * @param weightDecay The amount of L2 weight decay to use for all the + * parameters. 
+ */ + NAGOptimizer( + const std::vector& parameters, + float learningRate, + float momentum = 0.99, + float weightDecay = 0 + ); + + void step() override; + + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/optim/NovogradOptimizer.cpp b/flashlight/fl/optim/NovogradOptimizer.cpp index dc30b94..dd7a298 100644 --- a/flashlight/fl/optim/NovogradOptimizer.cpp +++ b/flashlight/fl/optim/NovogradOptimizer.cpp @@ -21,59 +21,59 @@ NovogradOptimizer::NovogradOptimizer( float beta1 /* = 0.9 */, float beta2 /* = 0.999 */, float epsilon /* = 1e-8 */, - float weightDecay /* = 0 */) - : FirstOrderOptimizer(parameters, learningRate), - beta1_(beta1), - beta2_(beta2), - eps_(epsilon), - wd_(weightDecay), - accGradNorm_(), - accGrad_() { - accGradNorm_.reserve(1); - accGrad_.reserve(parameters.size()); - - for (const auto& parameter : parameters_) { - accGradNorm_.emplace_back(0.0); - accGrad_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); - - fl::eval(accGrad_.back()); - } + float weightDecay /* = 0 */ +) : FirstOrderOptimizer(parameters, learningRate), + beta1_(beta1), + beta2_(beta2), + eps_(epsilon), + wd_(weightDecay), + accGradNorm_(), + accGrad_() { + accGradNorm_.reserve(1); + accGrad_.reserve(parameters.size()); + + for(const auto& parameter : parameters_) { + accGradNorm_.emplace_back(0.0); + accGrad_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); + + fl::eval(accGrad_.back()); + } } void NovogradOptimizer::step() { - for (size_t i = 0; i < parameters_.size(); i++) { - if (!parameters_[i].isGradAvailable()) { - continue; - } + for(size_t i = 0; i < parameters_.size(); i++) { + if(!parameters_[i].isGradAvailable()) { + continue; + } - const Tensor& grad = parameters_[i].grad().tensor(); - Tensor& data = parameters_[i].tensor(); - Tensor& accGrad = accGrad_[i]; + const Tensor& grad = parameters_[i].grad().tensor(); + Tensor& data = parameters_[i].tensor(); + Tensor& accGrad = accGrad_[i]; - double 
gradNorm = fl::sum(grad * grad).asScalar(); + double gradNorm = fl::sum(grad * grad).asScalar(); - accGradNorm_[i] = beta2_ * accGradNorm_[i] + (1 - beta2_) * gradNorm; - accGrad = beta1_ * accGrad + - (1 - beta1_) * - (grad / (static_cast(std::sqrt(accGradNorm_[i]) + eps_)) + - wd_ * data); - fl::eval(accGrad); + accGradNorm_[i] = beta2_ * accGradNorm_[i] + (1 - beta2_) * gradNorm; + accGrad = beta1_ * accGrad + + (1 - beta1_) + * (grad / (static_cast(std::sqrt(accGradNorm_[i]) + eps_)) + + wd_ * data); + fl::eval(accGrad); - data = data - (lr_ * accGrad); + data = data - (lr_ * accGrad); - fl::eval(data); - } + fl::eval(data); + } } std::string NovogradOptimizer::prettyString() const { - std::ostringstream ss; - ss << "Novograd"; + std::ostringstream ss; + ss << "Novograd"; - if (wd_ != 0) { - ss << " (weight decay=" << wd_ << ")"; - } + if(wd_ != 0) { + ss << " (weight decay=" << wd_ << ")"; + } - return ss.str(); + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/optim/NovogradOptimizer.h b/flashlight/fl/optim/NovogradOptimizer.h index d2b698a..2cd02d5 100644 --- a/flashlight/fl/optim/NovogradOptimizer.h +++ b/flashlight/fl/optim/NovogradOptimizer.h @@ -22,46 +22,46 @@ namespace fl { * Deep Networks](https://arxiv.org/abs/1905.11286). 
*/ class FL_API NovogradOptimizer : public FirstOrderOptimizer { - private: - FL_SAVE_LOAD_WITH_BASE( - FirstOrderOptimizer, - beta1_, - beta2_, - eps_, - wd_, - accGradNorm_, - accGrad_) +private: + FL_SAVE_LOAD_WITH_BASE( + FirstOrderOptimizer, + beta1_, + beta2_, + eps_, + wd_, + accGradNorm_, + accGrad_ + ) NovogradOptimizer() = default; // Intentionally private - NovogradOptimizer() = default; // Intentionally private + float beta1_; + float beta2_; + float eps_; + float wd_; + std::vector accGradNorm_; + std::vector accGrad_; - float beta1_; - float beta2_; - float eps_; - float wd_; - std::vector accGradNorm_; - std::vector accGrad_; +public: + /** Construct a Novograd optimizer + * @param parameters The parameters from e.g. `model.parameters()`. + * @param learningRate The learning rate. + * @param beta1 Novograd hyperparameter \f$ \beta_1 \f$. + * @param beta2 Novograd hyperparameter \f$ \beta_2 \f$. + * @param epsilon A small value used for numerical stability. + * @param weightDecay The amount of L2 weight decay to use for all the + * parameters. + */ + explicit NovogradOptimizer( + const std::vector& parameters, + float learningRate, + float beta1 = 0.95, + float beta2 = 0.98, + float epsilon = 1e-8, + float weightDecay = 0 + ); - public: - /** Construct a Novograd optimizer - * @param parameters The parameters from e.g. `model.parameters()`. - * @param learningRate The learning rate. - * @param beta1 Novograd hyperparameter \f$ \beta_1 \f$. - * @param beta2 Novograd hyperparameter \f$ \beta_2 \f$. - * @param epsilon A small value used for numerical stability. - * @param weightDecay The amount of L2 weight decay to use for all the - * parameters. 
- */ - explicit NovogradOptimizer( - const std::vector& parameters, - float learningRate, - float beta1 = 0.95, - float beta2 = 0.98, - float epsilon = 1e-8, - float weightDecay = 0); + void step() override; - void step() override; - - std::string prettyString() const override; + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/optim/Optimizers.cpp b/flashlight/fl/optim/Optimizers.cpp index 1ff63fd..bed73de 100644 --- a/flashlight/fl/optim/Optimizers.cpp +++ b/flashlight/fl/optim/Optimizers.cpp @@ -18,13 +18,14 @@ namespace fl { FirstOrderOptimizer::FirstOrderOptimizer( const vector& parameters, - double learningRate) - : parameters_(parameters.begin(), parameters.end()), lr_(learningRate) {} + double learningRate +) : parameters_(parameters.begin(), parameters.end()), + lr_(learningRate) {} void FirstOrderOptimizer::zeroGrad() { - for (auto& parameter : parameters_) { - parameter.zeroGrad(); - } + for(auto& parameter : parameters_) { + parameter.zeroGrad(); + } } } // namespace fl diff --git a/flashlight/fl/optim/Optimizers.h b/flashlight/fl/optim/Optimizers.h index c13ad11..fe44e50 100644 --- a/flashlight/fl/optim/Optimizers.h +++ b/flashlight/fl/optim/Optimizers.h @@ -27,52 +27,53 @@ namespace fl { * \endcode */ class FL_API FirstOrderOptimizer { - private: - /** - * Serialize the module's parameters. - */ - FL_SAVE_LOAD(lr_, parameters_) +private: + /** + * Serialize the module's parameters. + */ + FL_SAVE_LOAD(lr_, parameters_) - protected: - std::vector parameters_; - double lr_; +protected: + std::vector parameters_; + double lr_; - FirstOrderOptimizer() = default; + FirstOrderOptimizer() = default; - public: - /** The `FirstOrderOptimizer` base class constructor. - * @param parameters The parameters from e.g. `model.parameters()` - * @param learningRate The learning rate. 
- */ - FirstOrderOptimizer( - const std::vector& parameters, - double learningRate); +public: + /** The `FirstOrderOptimizer` base class constructor. + * @param parameters The parameters from e.g. `model.parameters()` + * @param learningRate The learning rate. + */ + FirstOrderOptimizer( + const std::vector& parameters, + double learningRate + ); - virtual void step() = 0; + virtual void step() = 0; - /** Get the learning rate. */ - double getLr() const { - return lr_; - } + /** Get the learning rate. */ + double getLr() const { + return lr_; + } - /** Set the learning rate. */ - void setLr(double lr) { - lr_ = lr; - } + /** Set the learning rate. */ + void setLr(double lr) { + lr_ = lr; + } - /** Zero the gradients for all the parameters being optimized. Typically - * this will be called after every call to step(). - */ - virtual void zeroGrad(); + /** Zero the gradients for all the parameters being optimized. Typically + * this will be called after every call to step(). + */ + virtual void zeroGrad(); - /** - * Generates a stringified representation of the optimizer. - * - * @return a string containing the optimizer label - */ - virtual std::string prettyString() const = 0; + /** + * Generates a stringified representation of the optimizer. 
+ * + * @return a string containing the optimizer label + */ + virtual std::string prettyString() const = 0; - virtual ~FirstOrderOptimizer() = default; + virtual ~FirstOrderOptimizer() = default; }; } // namespace fl diff --git a/flashlight/fl/optim/RMSPropOptimizer.cpp b/flashlight/fl/optim/RMSPropOptimizer.cpp index 6118470..18a4256 100644 --- a/flashlight/fl/optim/RMSPropOptimizer.cpp +++ b/flashlight/fl/optim/RMSPropOptimizer.cpp @@ -21,77 +21,77 @@ RMSPropOptimizer::RMSPropOptimizer( float rho /* = 0.99 */, float epsilon /* = 1e-8 */, float weightDecay /* = 0 */, - bool use_first /* = false */) - : FirstOrderOptimizer(parameters, learningRate), - useFirst_(use_first), - rho_(rho), - eps_(epsilon), - wd_(weightDecay), - first_(), - second_() { - if (useFirst_) { - first_.reserve(parameters.size()); - } - second_.reserve(parameters.size()); - - for (const auto& parameter : parameters_) { - if (useFirst_) { - first_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); - fl::eval(first_.back()); + bool use_first /* = false */ +) : FirstOrderOptimizer(parameters, learningRate), + useFirst_(use_first), + rho_(rho), + eps_(epsilon), + wd_(weightDecay), + first_(), + second_() { + if(useFirst_) { + first_.reserve(parameters.size()); } + second_.reserve(parameters.size()); - second_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); - fl::eval(second_.back()); - } -} - -void RMSPropOptimizer::step() { - for (size_t i = 0; i < parameters_.size(); i++) { - if (!parameters_[i].isGradAvailable()) { - continue; - } - - const Tensor& grad = parameters_[i].grad().tensor(); - Tensor& data = parameters_[i].tensor(); + for(const auto& parameter : parameters_) { + if(useFirst_) { + first_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); + fl::eval(first_.back()); + } - if (wd_ != 0) { - // Weight decay term - data = data - wd_ * data; + second_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); + fl::eval(second_.back()); } +} 
- Tensor& second = second_[i]; - second = rho_ * second + (1 - rho_) * grad * grad; - fl::eval(second); - - // Create shallow copy of second so that we don't update - // "second" below - Tensor moments = second; - if (useFirst_) { - Tensor& first = first_[i]; - first = rho_ * first + (1 - rho_) * grad; - moments = moments - first * first; - fl::eval(first); +void RMSPropOptimizer::step() { + for(size_t i = 0; i < parameters_.size(); i++) { + if(!parameters_[i].isGradAvailable()) { + continue; + } + + const Tensor& grad = parameters_[i].grad().tensor(); + Tensor& data = parameters_[i].tensor(); + + if(wd_ != 0) { + // Weight decay term + data = data - wd_ * data; + } + + Tensor& second = second_[i]; + second = rho_ * second + (1 - rho_) * grad * grad; + fl::eval(second); + + // Create shallow copy of second so that we don't update + // "second" below + Tensor moments = second; + if(useFirst_) { + Tensor& first = first_[i]; + first = rho_ * first + (1 - rho_) * grad; + moments = moments - first * first; + fl::eval(first); + } + + data = data - (lr_ * grad) / (fl::sqrt(moments) + eps_); + + fl::eval(data); } - - data = data - (lr_ * grad) / (fl::sqrt(moments) + eps_); - - fl::eval(data); - } } std::string RMSPropOptimizer::prettyString() const { - std::ostringstream ss; - ss << "RMSProp"; + std::ostringstream ss; + ss << "RMSProp"; - if (wd_ != 0) { - ss << " (weight decay=" << wd_ << ")"; - } + if(wd_ != 0) { + ss << " (weight decay=" << wd_ << ")"; + } - if (useFirst_) { - ss << " (use first moment)"; - } + if(useFirst_) { + ss << " (use first moment)"; + } - return ss.str(); + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/optim/RMSPropOptimizer.h b/flashlight/fl/optim/RMSPropOptimizer.h index 65297fc..4b5ed58 100644 --- a/flashlight/fl/optim/RMSPropOptimizer.h +++ b/flashlight/fl/optim/RMSPropOptimizer.h @@ -22,48 +22,48 @@ namespace fl { * and https://arxiv.org/pdf/1308.0850v5.pdf. 
*/ class FL_API RMSPropOptimizer : public FirstOrderOptimizer { - private: - FL_SAVE_LOAD_WITH_BASE( - FirstOrderOptimizer, - useFirst_, - fl::serializeAs(rho_), - fl::serializeAs(eps_), - fl::serializeAs(wd_), - first_, - second_) +private: + FL_SAVE_LOAD_WITH_BASE( + FirstOrderOptimizer, + useFirst_, + fl::serializeAs(rho_), + fl::serializeAs(eps_), + fl::serializeAs(wd_), + first_, + second_ + ) RMSPropOptimizer() = default; // Intentionally private - RMSPropOptimizer() = default; // Intentionally private + bool useFirst_; + float rho_; + float eps_; + float wd_; + std::vector first_; + std::vector second_; - bool useFirst_; - float rho_; - float eps_; - float wd_; - std::vector first_; - std::vector second_; +public: + /** Construct an RMSProp optimizer. + * @param parameters The parameters from e.g. `model.parameters()`. + * @param learningRate The learning rate. + * @param rho The weight in the term \f$ rho * m + (1-rho) * g^2 \f$. + * @param epsilon A small value used for numerical stability. + * @param weightDecay The amount of L2 weight decay to use for all the + * parameters. + * @param use_first Use the first moment in the update. When `true` keep + * a running mean of the gradient and subtract it from the running mean of + * the squared gradients. + */ + RMSPropOptimizer( + const std::vector& parameters, + float learningRate, + float rho = 0.99, + float epsilon = 1e-8, + float weightDecay = 0, + bool use_first = false + ); - public: - /** Construct an RMSProp optimizer. - * @param parameters The parameters from e.g. `model.parameters()`. - * @param learningRate The learning rate. - * @param rho The weight in the term \f$ rho * m + (1-rho) * g^2 \f$. - * @param epsilon A small value used for numerical stability. - * @param weightDecay The amount of L2 weight decay to use for all the - * parameters. - * @param use_first Use the first moment in the update. 
When `true` keep - * a running mean of the gradient and subtract it from the running mean of - * the squared gradients. - */ - RMSPropOptimizer( - const std::vector& parameters, - float learningRate, - float rho = 0.99, - float epsilon = 1e-8, - float weightDecay = 0, - bool use_first = false); + void step() override; - void step() override; - - std::string prettyString() const override; + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/optim/SGDOptimizer.cpp b/flashlight/fl/optim/SGDOptimizer.cpp index 85f9778..6c31092 100644 --- a/flashlight/fl/optim/SGDOptimizer.cpp +++ b/flashlight/fl/optim/SGDOptimizer.cpp @@ -20,67 +20,67 @@ SGDOptimizer::SGDOptimizer( float learningRate, float momentum /* = 0 */, float weightDecay /* = 0 */, - bool useNesterov /* = false */) - : FirstOrderOptimizer(parameters, learningRate), - useNesterov_(useNesterov), - mu_(momentum), - wd_(weightDecay), - velocities_() { - if (momentum != 0) { - velocities_.reserve(parameters.size()); - for (const auto& parameter : parameters_) { - velocities_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); - fl::eval(velocities_.back()); + bool useNesterov /* = false */ +) : FirstOrderOptimizer(parameters, learningRate), + useNesterov_(useNesterov), + mu_(momentum), + wd_(weightDecay), + velocities_() { + if(momentum != 0) { + velocities_.reserve(parameters.size()); + for(const auto& parameter : parameters_) { + velocities_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); + fl::eval(velocities_.back()); + } } - } } void SGDOptimizer::step() { - for (size_t i = 0; i < parameters_.size(); i++) { - if (!parameters_[i].isGradAvailable()) { - continue; - } + for(size_t i = 0; i < parameters_.size(); i++) { + if(!parameters_[i].isGradAvailable()) { + continue; + } - Tensor& grad = parameters_[i].grad().tensor(); - Tensor& data = parameters_[i].tensor(); + Tensor& grad = parameters_[i].grad().tensor(); + Tensor& data = 
parameters_[i].tensor(); - if (wd_ != 0) { - // Weight decay term - grad = grad + wd_ * data; - } + if(wd_ != 0) { + // Weight decay term + grad = grad + wd_ * data; + } - if (mu_ != 0) { - Tensor& velocity = velocities_[i]; + if(mu_ != 0) { + Tensor& velocity = velocities_[i]; - // Regular momentum - velocity = mu_ * velocity + grad; - fl::eval(velocity); - if (useNesterov_) { - // Update for nesterov momentum - grad += velocity * mu_; - } else { - grad = velocity; - } + // Regular momentum + velocity = mu_ * velocity + grad; + fl::eval(velocity); + if(useNesterov_) { + // Update for nesterov momentum + grad += velocity * mu_; + } else { + grad = velocity; + } + } + data = data - lr_ * grad; + fl::eval(data); } - data = data - lr_ * grad; - fl::eval(data); - } } std::string SGDOptimizer::prettyString() const { - std::ostringstream ss; - ss << "SGD"; + std::ostringstream ss; + ss << "SGD"; - if (wd_ != 0) { - ss << " (weight decay=" << wd_ << ")"; - } - if (useNesterov_ && mu_ != 0) { - ss << " (Nesterov momentum=" << mu_ << ")"; - } else if (mu_ != 0) { - ss << " (momentum=" << mu_ << ")"; - } + if(wd_ != 0) { + ss << " (weight decay=" << wd_ << ")"; + } + if(useNesterov_ && mu_ != 0) { + ss << " (Nesterov momentum=" << mu_ << ")"; + } else if(mu_ != 0) { + ss << " (momentum=" << mu_ << ")"; + } - return ss.str(); + return ss.str(); } } // namespace fl diff --git a/flashlight/fl/optim/SGDOptimizer.h b/flashlight/fl/optim/SGDOptimizer.h index 67f6b68..ddca8be 100644 --- a/flashlight/fl/optim/SGDOptimizer.h +++ b/flashlight/fl/optim/SGDOptimizer.h @@ -29,40 +29,40 @@ namespace fl { * http://cs231n.github.io/neural-networks-3/#sgd */ class FL_API SGDOptimizer : public FirstOrderOptimizer { - private: - FL_SAVE_LOAD_WITH_BASE( - FirstOrderOptimizer, - useNesterov_, - fl::serializeAs(mu_), - fl::serializeAs(wd_), - velocities_) +private: + FL_SAVE_LOAD_WITH_BASE( + FirstOrderOptimizer, + useNesterov_, + fl::serializeAs(mu_), + fl::serializeAs(wd_), + velocities_ + ) 
SGDOptimizer() = default; // Intentionally private - SGDOptimizer() = default; // Intentionally private + bool useNesterov_; + float mu_; + float wd_; + std::vector velocities_; - bool useNesterov_; - float mu_; - float wd_; - std::vector velocities_; +public: + /** SGDOptimizer constructor. + * @param parameters The parameters from e.g. `model.parameters()` + * @param learningRate The learning rate. + * @param momentum The momentum. + * @param weightDecay The amount of L2 weight decay to use for all the + * parameters. + * @param useNesterov Whether or not to use nesterov style momentum. + */ + SGDOptimizer( + const std::vector& parameters, + float learningRate, + float momentum = 0, + float weightDecay = 0, + bool useNesterov = false + ); - public: - /** SGDOptimizer constructor. - * @param parameters The parameters from e.g. `model.parameters()` - * @param learningRate The learning rate. - * @param momentum The momentum. - * @param weightDecay The amount of L2 weight decay to use for all the - * parameters. - * @param useNesterov Whether or not to use nesterov style momentum. 
- */ - SGDOptimizer( - const std::vector& parameters, - float learningRate, - float momentum = 0, - float weightDecay = 0, - bool useNesterov = false); + void step() override; - void step() override; - - std::string prettyString() const override; + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/fl/optim/Utils.cpp b/flashlight/fl/optim/Utils.cpp index c43dfd8..a3b341b 100644 --- a/flashlight/fl/optim/Utils.cpp +++ b/flashlight/fl/optim/Utils.cpp @@ -14,26 +14,26 @@ namespace fl { double clipGradNorm(const std::vector& parameters, double maxNorm) { - double gradNorm = 0.0; - for (const auto& p : parameters) { - if (!p.isGradAvailable()) { - continue; + double gradNorm = 0.0; + for(const auto& p : parameters) { + if(!p.isGradAvailable()) { + continue; + } + const auto& grad = p.grad().tensor(); + gradNorm += fl::sum(grad * grad).asScalar(); } - const auto& grad = p.grad().tensor(); - gradNorm += fl::sum(grad * grad).asScalar(); - } - gradNorm = std::sqrt(gradNorm); - double scale = maxNorm / (gradNorm + 1e-6); - if (scale >= 1.0) { - return gradNorm; - } - for (auto& p : parameters) { - if (!p.isGradAvailable()) { - continue; + gradNorm = std::sqrt(gradNorm); + double scale = maxNorm / (gradNorm + 1e-6); + if(scale >= 1.0) { + return gradNorm; + } + for(auto& p : parameters) { + if(!p.isGradAvailable()) { + continue; + } + p.grad().tensor() *= scale; } - p.grad().tensor() *= scale; - } - return gradNorm; + return gradNorm; } } // namespace fl diff --git a/flashlight/fl/optim/Utils.h b/flashlight/fl/optim/Utils.h index 3f23ceb..ef83ffd 100644 --- a/flashlight/fl/optim/Utils.h +++ b/flashlight/fl/optim/Utils.h @@ -16,6 +16,7 @@ namespace fl { FL_API double clipGradNorm( const std::vector& parameters, - double max_norm); + double max_norm +); } diff --git a/flashlight/fl/runtime/CUDADevice.cpp b/flashlight/fl/runtime/CUDADevice.cpp index 178526e..0b74dfc 100644 --- a/flashlight/fl/runtime/CUDADevice.cpp +++ 
b/flashlight/fl/runtime/CUDADevice.cpp @@ -13,11 +13,11 @@ namespace fl { CUDADevice::CUDADevice(const int nativeId) : nativeId_(nativeId) {} int CUDADevice::nativeId() const { - return nativeId_; + return nativeId_; } void CUDADevice::setActiveImpl() const { - FL_CUDA_CHECK(cudaSetDevice(nativeId_)); + FL_CUDA_CHECK(cudaSetDevice(nativeId_)); } } // namespace fl diff --git a/flashlight/fl/runtime/CUDADevice.h b/flashlight/fl/runtime/CUDADevice.h index a674b4a..cde0f6e 100644 --- a/flashlight/fl/runtime/CUDADevice.h +++ b/flashlight/fl/runtime/CUDADevice.h @@ -1,4 +1,3 @@ - /* * Copyright (c) Meta Platforms, Inc. and affiliates. * @@ -16,31 +15,31 @@ namespace fl { * Represents a CUDA device. */ class FL_API CUDADevice : public DeviceTrait { - // native ID of the underlying CUDA device - const int nativeId_; - // TODO metadata, e.g., memory/compute capacity - - public: - static constexpr DeviceType type = DeviceType::CUDA; - - /** - * Creates a wrapper around the CUDA device with given native device ID. - * - * @param[in] nativeId the CUDA device ID with which to create this Device. - */ - explicit CUDADevice(int nativeId); - - /** - * Returns the native CUDA device ID. - * - * @return an integer representing the native CUDA device ID. - */ - int nativeId() const override; - - /** - * Set the underlying CUDA device as active. - */ - void setActiveImpl() const override; + // native ID of the underlying CUDA device + const int nativeId_; + // TODO metadata, e.g., memory/compute capacity + +public: + static constexpr DeviceType type = DeviceType::CUDA; + + /** + * Creates a wrapper around the CUDA device with given native device ID. + * + * @param[in] nativeId the CUDA device ID with which to create this Device. + */ + explicit CUDADevice(int nativeId); + + /** + * Returns the native CUDA device ID. + * + * @return an integer representing the native CUDA device ID. + */ + int nativeId() const override; + + /** + * Set the underlying CUDA device as active. 
+ */ + void setActiveImpl() const override; }; } // namespace fl diff --git a/flashlight/fl/runtime/CUDAStream.cpp b/flashlight/fl/runtime/CUDAStream.cpp index 6717c17..6cc9eb4 100644 --- a/flashlight/fl/runtime/CUDAStream.cpp +++ b/flashlight/fl/runtime/CUDAStream.cpp @@ -14,113 +14,123 @@ namespace fl { -CUDAStream::CUDAStream(CUDADevice& device, cudaStream_t stream, bool managed) - : device_(device), nativeStream_(stream), managed_(managed) { - // Ensure `event_` and `nativeStream_` are associated with the same device - assert( - &DeviceManager::getInstance().getActiveDevice(DeviceType::CUDA) == - &device); - // `event_` is used by relativeSync only -- disable timing to reduce overhead - FL_CUDA_CHECK( - cudaEventCreate(&event_, cudaEventDefault | cudaEventDisableTiming)); +CUDAStream::CUDAStream(CUDADevice& device, cudaStream_t stream, bool managed) : device_(device), + nativeStream_(stream), + managed_(managed) { + // Ensure `event_` and `nativeStream_` are associated with the same device + assert( + &DeviceManager::getInstance().getActiveDevice(DeviceType::CUDA) + == &device + ); + // `event_` is used by relativeSync only -- disable timing to reduce overhead + FL_CUDA_CHECK( + cudaEventCreate(&event_, cudaEventDefault | cudaEventDisableTiming) + ); } std::shared_ptr CUDAStream::makeSharedAndRegister( CUDADevice& device, cudaStream_t stream, - bool managed) { - auto rawStreamPtr = new CUDAStream(device, stream, managed); - auto streamPtr = std::shared_ptr(rawStreamPtr); - device.addStream(streamPtr); - return streamPtr; + bool managed +) { + auto rawStreamPtr = new CUDAStream(device, stream, managed); + auto streamPtr = std::shared_ptr(rawStreamPtr); + device.addStream(streamPtr); + return streamPtr; } std::shared_ptr CUDAStream::create(int flag, bool managed) { - cudaStream_t nativeStream; - FL_CUDA_CHECK(cudaStreamCreateWithFlags(&nativeStream, flag)); - auto& manager = DeviceManager::getInstance(); - auto& device = 
manager.getActiveDevice(DeviceType::CUDA).impl(); - return makeSharedAndRegister(device, nativeStream, managed); + cudaStream_t nativeStream; + FL_CUDA_CHECK(cudaStreamCreateWithFlags(&nativeStream, flag)); + auto& manager = DeviceManager::getInstance(); + auto& device = manager.getActiveDevice(DeviceType::CUDA).impl(); + return makeSharedAndRegister(device, nativeStream, managed); } std::shared_ptr CUDAStream::createManaged(int flag) { - return CUDAStream::create(flag, /* managed */ true); + return CUDAStream::create(flag, /* managed */ true); } std::shared_ptr CUDAStream::createUnmanaged(int flag) { - return CUDAStream::create(flag, /* managed */ false); + return CUDAStream::create(flag, /* managed */ false); } std::shared_ptr CUDAStream::wrapUnmanaged( int deviceId, - cudaStream_t stream) { - auto& manager = DeviceManager::getInstance(); - const auto& oldActiveDevice = manager.getActiveDevice(DeviceType::CUDA); - auto& device = - manager.getDevice(DeviceType::CUDA, deviceId).impl(); - // satisfies assumptions of makeSharedAndRegister - bool needDeviceSwitch = &oldActiveDevice != &device; - if (needDeviceSwitch) { - device.setActive(); - } - auto streamPtr = makeSharedAndRegister(device, stream, /* managed */ false); - if (needDeviceSwitch) { - oldActiveDevice.setActive(); - } - return streamPtr; + cudaStream_t stream +) { + auto& manager = DeviceManager::getInstance(); + const auto& oldActiveDevice = manager.getActiveDevice(DeviceType::CUDA); + auto& device = + manager.getDevice(DeviceType::CUDA, deviceId).impl(); + // satisfies assumptions of makeSharedAndRegister + bool needDeviceSwitch = &oldActiveDevice != &device; + if(needDeviceSwitch) { + device.setActive(); + } + auto streamPtr = makeSharedAndRegister(device, stream, /* managed */ false); + if(needDeviceSwitch) { + oldActiveDevice.setActive(); + } + return streamPtr; } CUDAStream::~CUDAStream() { - if (managed_) { - FL_CUDA_CHECK(cudaStreamDestroy(nativeStream_)); - // Ideally we should unconditionally 
destroy the event we created, but there - // is a race hazard between CUDAStream destructor in global context and CUDA - // shutdown (sometimes the latter may precede the former). So we destroy the - // event only when it's safe to do so - FL_CUDA_CHECK(cudaEventDestroy(event_)); - } else { + if(managed_) { + FL_CUDA_CHECK(cudaStreamDestroy(nativeStream_)); + // Ideally we should unconditionally destroy the event we created, but there + // is a race hazard between CUDAStream destructor in global context and CUDA + // shutdown (sometimes the latter may precede the former). So we destroy the + // event only when it's safe to do so + FL_CUDA_CHECK(cudaEventDestroy(event_)); + } else { #ifdef NO_CUDA_STREAM_DESTROY_EVENT - // Note that this case only results in cuda event "resource leak" if someone - // creates an unmanaged cuda stream. But managed cuda streams are often used - // in a global context and released at program shutdown (e.g., for cudnn). - // So chances of real resource leak is very low. + // Note that this case only results in cuda event "resource leak" if someone + // creates an unmanaged cuda stream. But managed cuda streams are often used + // in a global context and released at program shutdown (e.g., for cudnn). + // So chances of real resource leak is very low. 
#else - FL_CUDA_CHECK(cudaEventDestroy(event_)); + FL_CUDA_CHECK(cudaEventDestroy(event_)); #endif - } + } } const CUDADevice& CUDAStream::device() const { - return device_; + return device_; } CUDADevice& CUDAStream::device() { - return device_; + return device_; } void CUDAStream::sync() const { - FL_CUDA_CHECK(cudaStreamSynchronize(this->nativeStream_)); + FL_CUDA_CHECK(cudaStreamSynchronize(this->nativeStream_)); } void CUDAStream::relativeSync(const CUDAStream& waitOn) const { - auto& manager = DeviceManager::getInstance(); - auto* oldActiveCUDADevice = &manager.getActiveDevice(DeviceType::CUDA); - bool needDeviceSwitch = oldActiveCUDADevice != &device_; - if (needDeviceSwitch) { - device_.setActive(); - } - // event and stream from same instance are guaranteed to have been created - // from the same device - FL_CUDA_CHECK(cudaEventRecord(waitOn.event_, waitOn.nativeStream_)); - FL_CUDA_CHECK(cudaStreamWaitEvent( - this->nativeStream_, waitOn.event_, /* cudaEventWaitDefault = */ 0)); - if (needDeviceSwitch) { - oldActiveCUDADevice->setActive(); - } + auto& manager = DeviceManager::getInstance(); + auto* oldActiveCUDADevice = &manager.getActiveDevice(DeviceType::CUDA); + bool needDeviceSwitch = oldActiveCUDADevice != &device_; + if(needDeviceSwitch) { + device_.setActive(); + } + // event and stream from same instance are guaranteed to have been created + // from the same device + FL_CUDA_CHECK(cudaEventRecord(waitOn.event_, waitOn.nativeStream_)); + FL_CUDA_CHECK( + cudaStreamWaitEvent( + this->nativeStream_, + waitOn.event_, /* cudaEventWaitDefault = */ + 0 + ) + ); + if(needDeviceSwitch) { + oldActiveCUDADevice->setActive(); + } } cudaStream_t CUDAStream::handle() const { - return nativeStream_; + return nativeStream_; } } // namespace fl diff --git a/flashlight/fl/runtime/CUDAStream.h b/flashlight/fl/runtime/CUDAStream.h index 40e7d36..2361192 100644 --- a/flashlight/fl/runtime/CUDAStream.h +++ b/flashlight/fl/runtime/CUDAStream.h @@ -18,110 +18,114 @@ 
namespace fl { * An abstraction for CUDA stream with controlled creation methods. */ class FL_API CUDAStream : public StreamTrait { - // the device upon which the underlying native stream was created - CUDADevice& device_; - // the underlying native stream - cudaStream_t nativeStream_; - // whether the native stream's lifetime is managed by this object - const bool managed_; - // re-used for relative synchronization to reduce overhead. Guaranteed to - // associate with the same device as `nativeStream_`, i.e., `device_` - cudaEvent_t event_; - - /** - * A barebones constructor which just initializes the fields. - * - * @param[in] device the device on which `stream` was created. - * @param[in] stream the underlying native CUDA stream. - * @param[in] managed whether this object will manage `stream`'s lifetime. - * - * ASSUME - * 1. `stream` was created on `device`. - * 2. `device` is the currently active cuda device. - */ - CUDAStream(CUDADevice& device, cudaStream_t stream, bool managed); - - /** - * Allocate a new CUDAStream as a shared_ptr and register it on given device. - * - * @param[in] device the device on which `stream` was created. - * @param[in] nativeStream the underlying native CUDA stream. - * @param[in] managed whether this object will manage `stream`'s lifetime. - * - * ASSUME - * 1. `nativeStream` was created on `device`. - * 2. `device` is the currently active cuda device. - */ - static std::shared_ptr makeSharedAndRegister( - CUDADevice& device, - cudaStream_t nativeStream, - bool managed); - - // A fully configurable create, hidden for internal use. - static std::shared_ptr create(int flag, bool managed); - - public: - // prevent name hiding - using StreamTrait::relativeSync; - - static constexpr StreamType type = StreamType::CUDA; - - /** - * Creates an unmanaged wrapper around an existing native CUDA stream and - * automatically register it on the device with given id in DeviceManager. 
- * - * @param[in] deviceId the native device ID upon which `stream` was created. - * @param[in] stream the underlying CUDA stream. - * @param[in] managed whether the lifetime of the created native stream will - * be managed by this object. - * - * @return a shared pointer to a CUDAStream that wraps around the given native - * stream. - */ - static std::shared_ptr wrapUnmanaged( - int deviceId, - cudaStream_t stream); - - /** - * Create a managed CUDAStream around an internally created native CUDA - * stream and automatically register it on the active CUDA device in - * DeviceManager. - * - * @param[in] flag the flag used for creating native CUDA stream. - * - * @return a shared pointer to the CUDAStream created. - */ - static std::shared_ptr createManaged( - int flag = cudaStreamDefault); - - /** - * Create an unmanaged CUDAStream around an internally created native CUDA - * stream and automatically register it on the active CUDA device in - * DeviceManager. - * - * @param[in] flag the flag used for creating native CUDA stream. - * - * @return a shared pointer to the CUDAStream created. - */ - static std::shared_ptr createUnmanaged( - int flag = cudaStreamDefault); - - /** - * Destroy any stream managed by this object. - */ - ~CUDAStream() override; - - CUDADevice& device() override; - const CUDADevice& device() const override; - void sync() const override; - void relativeSync(const CUDAStream& waitOn) const override; - - /** - * Get the native CUDA stream handle. - * - * @return the native CUDA stream handle. - */ - cudaStream_t handle() const; + // the device upon which the underlying native stream was created + CUDADevice& device_; + // the underlying native stream + cudaStream_t nativeStream_; + // whether the native stream's lifetime is managed by this object + const bool managed_; + // re-used for relative synchronization to reduce overhead. 
Guaranteed to + // associate with the same device as `nativeStream_`, i.e., `device_` + cudaEvent_t event_; + + /** + * A barebones constructor which just initializes the fields. + * + * @param[in] device the device on which `stream` was created. + * @param[in] stream the underlying native CUDA stream. + * @param[in] managed whether this object will manage `stream`'s lifetime. + * + * ASSUME + * 1. `stream` was created on `device`. + * 2. `device` is the currently active cuda device. + */ + CUDAStream(CUDADevice& device, cudaStream_t stream, bool managed); + + /** + * Allocate a new CUDAStream as a shared_ptr and register it on given device. + * + * @param[in] device the device on which `stream` was created. + * @param[in] nativeStream the underlying native CUDA stream. + * @param[in] managed whether this object will manage `stream`'s lifetime. + * + * ASSUME + * 1. `nativeStream` was created on `device`. + * 2. `device` is the currently active cuda device. + */ + static std::shared_ptr makeSharedAndRegister( + CUDADevice& device, + cudaStream_t nativeStream, + bool managed + ); + + // A fully configurable create, hidden for internal use. + static std::shared_ptr create(int flag, bool managed); + +public: + // prevent name hiding + using StreamTrait::relativeSync; + + static constexpr StreamType type = StreamType::CUDA; + + /** + * Creates an unmanaged wrapper around an existing native CUDA stream and + * automatically register it on the device with given id in DeviceManager. + * + * @param[in] deviceId the native device ID upon which `stream` was created. + * @param[in] stream the underlying CUDA stream. + * @param[in] managed whether the lifetime of the created native stream will + * be managed by this object. + * + * @return a shared pointer to a CUDAStream that wraps around the given native + * stream. 
+ */ + static std::shared_ptr wrapUnmanaged( + int deviceId, + cudaStream_t stream + ); + + /** + * Create a managed CUDAStream around an internally created native CUDA + * stream and automatically register it on the active CUDA device in + * DeviceManager. + * + * @param[in] flag the flag used for creating native CUDA stream. + * + * @return a shared pointer to the CUDAStream created. + */ + static std::shared_ptr createManaged( + int flag = cudaStreamDefault + ); + + /** + * Create an unmanaged CUDAStream around an internally created native CUDA + * stream and automatically register it on the active CUDA device in + * DeviceManager. + * + * @param[in] flag the flag used for creating native CUDA stream. + * + * @return a shared pointer to the CUDAStream created. + */ + static std::shared_ptr createUnmanaged( + int flag = cudaStreamDefault + ); + + /** + * Destroy any stream managed by this object. + */ + ~CUDAStream() override; + + CUDADevice& device() override; + const CUDADevice& device() const override; + void sync() const override; + void relativeSync(const CUDAStream& waitOn) const override; + + /** + * Get the native CUDA stream handle. + * + * @return the native CUDA stream handle. 
+ */ + cudaStream_t handle() const; }; } // namespace fl diff --git a/flashlight/fl/runtime/CUDAUtils.cpp b/flashlight/fl/runtime/CUDAUtils.cpp index a2404c2..ec1bcb3 100644 --- a/flashlight/fl/runtime/CUDAUtils.cpp +++ b/flashlight/fl/runtime/CUDAUtils.cpp @@ -15,35 +15,35 @@ namespace fl::cuda { int getActiveDeviceId() { - int cudaActiveDeviceId = 0; - FL_CUDA_CHECK(cudaGetDevice(&cudaActiveDeviceId)); - return cudaActiveDeviceId; + int cudaActiveDeviceId = 0; + FL_CUDA_CHECK(cudaGetDevice(&cudaActiveDeviceId)); + return cudaActiveDeviceId; } std::unordered_map> createCUDADevices() { - std::unordered_map> idToDevice; - int numCudaDevices = 0; - FL_CUDA_CHECK(cudaGetDeviceCount(&numCudaDevices)); - for (auto id = 0; id < numCudaDevices; id++) { - idToDevice.emplace(id, std::make_unique(id)); - } - return idToDevice; + std::unordered_map> idToDevice; + int numCudaDevices = 0; + FL_CUDA_CHECK(cudaGetDeviceCount(&numCudaDevices)); + for(auto id = 0; id < numCudaDevices; id++) { + idToDevice.emplace(id, std::make_unique(id)); + } + return idToDevice; } namespace detail { -void check(cudaError_t err, const char* file, int line) { - check(err, "", file, line); -} - -void check(cudaError_t err, const char* prefix, const char* file, int line) { - if (err != cudaSuccess) { - std::ostringstream ess; - ess << prefix << '[' << file << ':' << line - << "] CUDA error: " << cudaGetErrorString(err); - throw std::runtime_error(ess.str()); - } -} + void check(cudaError_t err, const char* file, int line) { + check(err, "", file, line); + } + + void check(cudaError_t err, const char* prefix, const char* file, int line) { + if(err != cudaSuccess) { + std::ostringstream ess; + ess << prefix << '[' << file << ':' << line + << "] CUDA error: " << cudaGetErrorString(err); + throw std::runtime_error(ess.str()); + } + } } // namespace detail diff --git a/flashlight/fl/runtime/CUDAUtils.h b/flashlight/fl/runtime/CUDAUtils.h index d6c466a..a2f7c38 100644 --- 
a/flashlight/fl/runtime/CUDAUtils.h +++ b/flashlight/fl/runtime/CUDAUtils.h @@ -15,7 +15,7 @@ #include #define FL_CUDA_CHECK(...) \ - ::fl::cuda::detail::check(__VA_ARGS__, __FILE__, __LINE__) + ::fl::cuda::detail::check(__VA_ARGS__, __FILE__, __LINE__) namespace fl { namespace cuda { @@ -25,22 +25,22 @@ namespace cuda { * * @return the native id of the active CUDA device. */ -FL_API int getActiveDeviceId(); + FL_API int getActiveDeviceId(); /** * Return a mapping from native CUDA device id to available CUDA devices. * * @return an unordered map from native CUDA device id to CUDA device. */ -FL_API std::unordered_map> createCUDADevices(); + FL_API std::unordered_map> createCUDADevices(); -namespace detail { + namespace detail { -FL_API void check(cudaError_t err, const char* file, int line); + FL_API void check(cudaError_t err, const char* file, int line); -FL_API void check(cudaError_t err, const char* prefix, const char* file, int line); + FL_API void check(cudaError_t err, const char* prefix, const char* file, int line); -} // namespace detail + } // namespace detail } // namespace cuda } // namespace fl diff --git a/flashlight/fl/runtime/Device.cpp b/flashlight/fl/runtime/Device.cpp index 7c4ab0a..c03fe5d 100644 --- a/flashlight/fl/runtime/Device.cpp +++ b/flashlight/fl/runtime/Device.cpp @@ -13,50 +13,51 @@ namespace fl { void deviceImplTypeCheck(DeviceType expect, DeviceType actual) { - if (expect != actual) { - std::ostringstream oss; - oss << "[fl::Device::impl] " + if(expect != actual) { + std::ostringstream oss; + oss << "[fl::Device::impl] " << "specified device type: [" << expect << "] " << "doesn't match actual device type: [" << actual << "]"; - throw std::invalid_argument(oss.str()); - } + throw std::invalid_argument(oss.str()); + } } const std::unordered_set>& Device::getStreams() const { - return streams_; + return streams_; } void Device::addStream(std::shared_ptr stream) { - if (&stream->device() != this) { - throw std::runtime_error( - 
"[Device::addStream] Must add stream to owner device"); - } - streams_.insert(stream); + if(&stream->device() != this) { + throw std::runtime_error( + "[Device::addStream] Must add stream to owner device" + ); + } + streams_.insert(stream); } void Device::sync() const { - for (const auto& stream : streams_) { - stream->sync(); - } + for(const auto& stream : streams_) { + stream->sync(); + } } void Device::addSetActiveCallback(std::function callback) { - setActiveCallbacks_.push_back(std::move(callback)); + setActiveCallbacks_.push_back(std::move(callback)); } void Device::setActive() const { - setActiveImpl(); - for (auto& callback : setActiveCallbacks_) { - callback(nativeId()); - } + setActiveImpl(); + for(auto& callback : setActiveCallbacks_) { + callback(nativeId()); + } } int X64Device::nativeId() const { - return fl::kX64DeviceId; + return fl::kX64DeviceId; } void X64Device::setActiveImpl() const { - // no op, CPU device is always active + // no op, CPU device is always active } } // namespace fl diff --git a/flashlight/fl/runtime/Device.h b/flashlight/fl/runtime/Device.h index e8db939..1248a2b 100644 --- a/flashlight/fl/runtime/Device.h +++ b/flashlight/fl/runtime/Device.h @@ -28,105 +28,105 @@ FL_API void deviceImplTypeCheck(DeviceType expect, DeviceType actual); * computing device. */ class FL_API Device { - std::unordered_set> streams_; - // Used to update internal backend state for active device, thereby - // eliminating the `setActive --> AnyTensorBackendImpl` dependency(s). - std::vector> setActiveCallbacks_; - - protected: - /** - * Set this device as the active device, without worrying about the callbacks. - */ - virtual void setActiveImpl() const = 0; - - public: - Device() = default; - virtual ~Device() = default; - - // no copy/move - Device(const Device&) = delete; - Device(Device&&) = delete; - Device& operator=(const Device&) = delete; - Device& operator=(Device&&) = delete; - - /** - * Return all streams managed by this device. 
- * - * @return an immutable vector reference containing all streams managed by - * this device. - */ - virtual const std::unordered_set>& getStreams() const; - - /** - * Let this device manage given stream. Do nothing if it was already added. - * - * Throws runtime_error if stream is owned by a different device than this - * one. - */ - virtual void addStream(std::shared_ptr stream); - - /** - * Block calling thread and synchronize w.r.t. all streams on this device. - */ - virtual void sync() const; - - /** - * Get the native ID of this device (semantics are implementation-dependent). - * - * @return the native ID of this device. - */ - virtual int nativeId() const = 0; - - /** - * Returns the type of this device. - * - * @return a enum denoting device type. - */ - virtual DeviceType type() const = 0; - - /** - * Set this device as the active device and invokes any callbacks added. - */ - void setActive() const; - - /** - * Lets this device keep track of the given callback (along with previously - * added ones), which will be invoked with the device's native ID after - * setting the device active. - * - * @param[in] callback the callback to be invoked with this device's native ID - */ - void addSetActiveCallback(std::function callback); - - /** - * Get the underlying implementation of this device. - * - * Throws invalid_argument if the specified type does not match the actual - * derived device type. - * - * @return an immutable reference to the specified device type. - */ - template - const T& impl() const { - deviceImplTypeCheck(T::type, type()); - return *(static_cast(this)); - } - - /** - * Get the underlying implementation of this device. - * - * Throws invalid_argument if the specified type does not match the actual - * derived device type. - * - * @return a reference to the specified device type. 
- */ - template - T& impl() { - deviceImplTypeCheck(T::type, type()); - return *(static_cast(this)); - } - - // TODO metadata, e.g., device name + std::unordered_set> streams_; + // Used to update internal backend state for active device, thereby + // eliminating the `setActive --> AnyTensorBackendImpl` dependency(s). + std::vector> setActiveCallbacks_; + +protected: + /** + * Set this device as the active device, without worrying about the callbacks. + */ + virtual void setActiveImpl() const = 0; + +public: + Device() = default; + virtual ~Device() = default; + + // no copy/move + Device(const Device&) = delete; + Device(Device&&) = delete; + Device& operator=(const Device&) = delete; + Device& operator=(Device&&) = delete; + + /** + * Return all streams managed by this device. + * + * @return an immutable vector reference containing all streams managed by + * this device. + */ + virtual const std::unordered_set>& getStreams() const; + + /** + * Let this device manage given stream. Do nothing if it was already added. + * + * Throws runtime_error if stream is owned by a different device than this + * one. + */ + virtual void addStream(std::shared_ptr stream); + + /** + * Block calling thread and synchronize w.r.t. all streams on this device. + */ + virtual void sync() const; + + /** + * Get the native ID of this device (semantics are implementation-dependent). + * + * @return the native ID of this device. + */ + virtual int nativeId() const = 0; + + /** + * Returns the type of this device. + * + * @return a enum denoting device type. + */ + virtual DeviceType type() const = 0; + + /** + * Set this device as the active device and invokes any callbacks added. + */ + void setActive() const; + + /** + * Lets this device keep track of the given callback (along with previously + * added ones), which will be invoked with the device's native ID after + * setting the device active. 
+ * + * @param[in] callback the callback to be invoked with this device's native ID + */ + void addSetActiveCallback(std::function callback); + + /** + * Get the underlying implementation of this device. + * + * Throws invalid_argument if the specified type does not match the actual + * derived device type. + * + * @return an immutable reference to the specified device type. + */ + template + const T& impl() const { + deviceImplTypeCheck(T::type, type()); + return *(static_cast(this)); + } + + /** + * Get the underlying implementation of this device. + * + * Throws invalid_argument if the specified type does not match the actual + * derived device type. + * + * @return a reference to the specified device type. + */ + template + T& impl() { + deviceImplTypeCheck(T::type, type()); + return *(static_cast(this)); + } + + // TODO metadata, e.g., device name }; /** @@ -135,24 +135,24 @@ class FL_API Device { * REQUIRED definition in derived class: * static DeviceType type; */ -template +template class DeviceTrait : public Device { - public: - DeviceType type() const override { - return Derived::type; - } +public: + DeviceType type() const override { + return Derived::type; + } }; /** * A dummy to represent CPU device. 
*/ class FL_API X64Device : public DeviceTrait { - public: - static constexpr DeviceType type = DeviceType::x64; +public: + static constexpr DeviceType type = DeviceType::x64; - X64Device() = default; - int nativeId() const override; - void setActiveImpl() const override; + X64Device() = default; + int nativeId() const override; + void setActiveImpl() const override; }; } // namespace fl diff --git a/flashlight/fl/runtime/DeviceManager.cpp b/flashlight/fl/runtime/DeviceManager.cpp index f9def0b..22b6bb7 100644 --- a/flashlight/fl/runtime/DeviceManager.cpp +++ b/flashlight/fl/runtime/DeviceManager.cpp @@ -17,15 +17,14 @@ namespace { int getActiveDeviceId(const fl::DeviceType type) { - switch (type) { - case fl::DeviceType::x64: return fl::kX64DeviceId; - case fl::DeviceType::CUDA: { + switch(type) { + case fl::DeviceType::x64: return fl::kX64DeviceId; + case fl::DeviceType::CUDA: #if FL_BACKEND_CUDA - return fl::cuda::getActiveDeviceId(); + return fl::cuda::getActiveDeviceId(); #endif - throw std::runtime_error("CUDA is unsupported"); + throw std::runtime_error("CUDA is unsupported"); } - } throw std::runtime_error("unsupported device type"); } @@ -34,73 +33,79 @@ int getActiveDeviceId(const fl::DeviceType type) { namespace fl { DeviceManager::DeviceManager() { - // initialize for x64 - DeviceTypeInfo x64Info; - x64Info.emplace(kX64DeviceId, std::make_unique()); - deviceTypeToInfo_.emplace(DeviceType::x64, std::move(x64Info)); + // initialize for x64 + DeviceTypeInfo x64Info; + x64Info.emplace(kX64DeviceId, std::make_unique()); + deviceTypeToInfo_.emplace(DeviceType::x64, std::move(x64Info)); - // initialize for CUDA + // initialize for CUDA #if FL_BACKEND_CUDA - deviceTypeToInfo_.insert({DeviceType::CUDA, fl::cuda::createCUDADevices()}); + deviceTypeToInfo_.insert({DeviceType::CUDA, fl::cuda::createCUDADevices()}); #endif } void DeviceManager::enforceDeviceTypeAvailable( - std::string_view errorPrefix, const DeviceType type) const { - if 
(!isDeviceTypeAvailable(type)) { - throw std::runtime_error( - std::string(errorPrefix) + " device type unavailable"); - } + std::string_view errorPrefix, + const DeviceType type +) const { + if(!isDeviceTypeAvailable(type)) { + throw std::runtime_error( + std::string(errorPrefix) + " device type unavailable" + ); + } } DeviceManager& DeviceManager::getInstance() { - static DeviceManager instance; - return instance; + static DeviceManager instance; + return instance; } bool DeviceManager::isDeviceTypeAvailable(const DeviceType type) const { - return deviceTypeToInfo_.contains(type); + return deviceTypeToInfo_.contains(type); } unsigned DeviceManager::getDeviceCount(const DeviceType type) const { - enforceDeviceTypeAvailable("[DeviceManager::getDeviceCount]", type); - return deviceTypeToInfo_.at(type).size(); + enforceDeviceTypeAvailable("[DeviceManager::getDeviceCount]", type); + return deviceTypeToInfo_.at(type).size(); } std::vector DeviceManager::getDevicesOfType( - DeviceType type) { - enforceDeviceTypeAvailable("[DeviceManager::getDevicesOfType]", type); - std::vector devices; - for (auto &[_, device] : deviceTypeToInfo_.at(type)) { - devices.push_back(device.get()); - } - return devices; + DeviceType type +) { + enforceDeviceTypeAvailable("[DeviceManager::getDevicesOfType]", type); + std::vector devices; + for(auto&[_, device] : deviceTypeToInfo_.at(type)) { + devices.push_back(device.get()); + } + return devices; } std::vector DeviceManager::getDevicesOfType( - DeviceType type) const { - enforceDeviceTypeAvailable("[DeviceManager::getDevicesOfType]", type); - std::vector devices; - for (auto &[_, device] : deviceTypeToInfo_.at(type)) { - devices.push_back(device.get()); - } - return devices; + DeviceType type +) const { + enforceDeviceTypeAvailable("[DeviceManager::getDevicesOfType]", type); + std::vector devices; + for(auto&[_, device] : deviceTypeToInfo_.at(type)) { + devices.push_back(device.get()); + } + return devices; } Device& 
DeviceManager::getDevice(const DeviceType type, int id) const { - enforceDeviceTypeAvailable("[DeviceManager::getActiveDevice]", type); - auto& idToDevice = deviceTypeToInfo_.at(type); - if (!idToDevice.contains(id)) { - throw std::runtime_error( - "[DeviceManager::getDevice] unknown device id"); - } - return *idToDevice.at(id); + enforceDeviceTypeAvailable("[DeviceManager::getActiveDevice]", type); + auto& idToDevice = deviceTypeToInfo_.at(type); + if(!idToDevice.contains(id)) { + throw std::runtime_error( + "[DeviceManager::getDevice] unknown device id" + ); + } + return *idToDevice.at(id); } Device& DeviceManager::getActiveDevice(const DeviceType type) const { - enforceDeviceTypeAvailable("[DeviceManager::getActiveDevice]", type); - int activeDeviceId = getActiveDeviceId(type); - return *deviceTypeToInfo_.at(type).at(activeDeviceId); + enforceDeviceTypeAvailable("[DeviceManager::getActiveDevice]", type); + int activeDeviceId = getActiveDeviceId(type); + return *deviceTypeToInfo_.at(type).at(activeDeviceId); } } // namespace fl diff --git a/flashlight/fl/runtime/DeviceManager.h b/flashlight/fl/runtime/DeviceManager.h index 8a18070..efa3b35 100644 --- a/flashlight/fl/runtime/DeviceManager.h +++ b/flashlight/fl/runtime/DeviceManager.h @@ -26,82 +26,83 @@ constexpr int kX64DeviceId = 0; * A singleton to manage all supported types of devices. */ class FL_API DeviceManager { - using DeviceTypeInfo = std::unordered_map>; - - std::unordered_map deviceTypeToInfo_; - - // Help enforce singleton - DeviceManager(); - DeviceManager(const DeviceManager&) = delete; - DeviceManager(DeviceManager&&) = delete; - DeviceManager& operator=(const DeviceManager&) = delete; - DeviceManager& operator=(DeviceManager&&) = delete; - - // throws runtime_error if `type` is unavailable - void enforceDeviceTypeAvailable( - std::string_view errorPrefix, - const DeviceType type) const; - - public: - /** - * Gets the singleton DeviceManager. 
- * - * @return a reference to the singleton DeviceManager. - */ - static DeviceManager& getInstance(); - - /** - * Returns if the given device type is available. - * - * @return a boolean denoting device type availability. - */ - bool isDeviceTypeAvailable(const DeviceType type) const; - - /** - * Gets the number of usable devices of given type. - * - * Throws a runtime_error if given device `type` is unavailable. - * - * @return the number of usable devices of given type. - */ - unsigned getDeviceCount(const DeviceType type) const; - - /** - * Gets all devices of given type. - * - * Throws a runtime_error if given device `type` is unavailable. - * - * @return a vector of pointers to all devices of given type. - */ - std::vector getDevicesOfType(const DeviceType type); - - /** - * Gets all devices of given type. - * - * Throws a runtime_error if given device `type` is unavailable. - * - * @return a vector of immutable pointers to all devices of given type. - */ - std::vector getDevicesOfType(const DeviceType type) const; - - /** - * Gets the device of given type and native device id. - * - * Throws a runtime_error if given device `type` is unavailable - * or `id` does not match any device. - * - * @return a reference to the device of given type and native device id. - */ - Device& getDevice(const DeviceType type, int id) const; - - /** - * Gets the active device of given type. - * - * Throws a runtime_error if given device `type` is unavailable. - * - * @return a reference to the active device of given type. 
- */ - Device& getActiveDevice(const DeviceType type) const; + using DeviceTypeInfo = std::unordered_map>; + + std::unordered_map deviceTypeToInfo_; + + // Help enforce singleton + DeviceManager(); + DeviceManager(const DeviceManager&) = delete; + DeviceManager(DeviceManager&&) = delete; + DeviceManager& operator=(const DeviceManager&) = delete; + DeviceManager& operator=(DeviceManager&&) = delete; + + // throws runtime_error if `type` is unavailable + void enforceDeviceTypeAvailable( + std::string_view errorPrefix, + const DeviceType type + ) const; + +public: + /** + * Gets the singleton DeviceManager. + * + * @return a reference to the singleton DeviceManager. + */ + static DeviceManager& getInstance(); + + /** + * Returns if the given device type is available. + * + * @return a boolean denoting device type availability. + */ + bool isDeviceTypeAvailable(const DeviceType type) const; + + /** + * Gets the number of usable devices of given type. + * + * Throws a runtime_error if given device `type` is unavailable. + * + * @return the number of usable devices of given type. + */ + unsigned getDeviceCount(const DeviceType type) const; + + /** + * Gets all devices of given type. + * + * Throws a runtime_error if given device `type` is unavailable. + * + * @return a vector of pointers to all devices of given type. + */ + std::vector getDevicesOfType(const DeviceType type); + + /** + * Gets all devices of given type. + * + * Throws a runtime_error if given device `type` is unavailable. + * + * @return a vector of immutable pointers to all devices of given type. + */ + std::vector getDevicesOfType(const DeviceType type) const; + + /** + * Gets the device of given type and native device id. + * + * Throws a runtime_error if given device `type` is unavailable + * or `id` does not match any device. + * + * @return a reference to the device of given type and native device id. 
+ */ + Device& getDevice(const DeviceType type, int id) const; + + /** + * Gets the active device of given type. + * + * Throws a runtime_error if given device `type` is unavailable. + * + * @return a reference to the active device of given type. + */ + Device& getActiveDevice(const DeviceType type) const; }; } // namespace fl diff --git a/flashlight/fl/runtime/DeviceType.cpp b/flashlight/fl/runtime/DeviceType.cpp index 6bf9f37..2d923a8 100644 --- a/flashlight/fl/runtime/DeviceType.cpp +++ b/flashlight/fl/runtime/DeviceType.cpp @@ -10,22 +10,22 @@ namespace fl { std::string deviceTypeToString(const DeviceType type) { - switch (type) { - case DeviceType::x64: return "x64"; - case DeviceType::CUDA: return "CUDA"; - } + switch(type) { + case DeviceType::x64: return "x64"; + case DeviceType::CUDA: return "CUDA"; + } } std::ostream& operator<<(std::ostream& os, const DeviceType& type) { - return os << deviceTypeToString(type); + return os << deviceTypeToString(type); } const std::unordered_set& getDeviceTypes() { - static std::unordered_set types = { - DeviceType::x64, - DeviceType::CUDA - }; - return types; + static std::unordered_set types = { + DeviceType::x64, + DeviceType::CUDA + }; + return types; } } // namespace fl diff --git a/flashlight/fl/runtime/DeviceType.h b/flashlight/fl/runtime/DeviceType.h index f914fad..7f906e7 100644 --- a/flashlight/fl/runtime/DeviceType.h +++ b/flashlight/fl/runtime/DeviceType.h @@ -20,8 +20,8 @@ namespace fl { * NOTE update `fl::getAllDeviceTypes` after changing enum values. 
*/ enum class DeviceType { - x64, - CUDA, + x64, + CUDA, }; #if FL_BACKEND_CUDA diff --git a/flashlight/fl/runtime/Stream.cpp b/flashlight/fl/runtime/Stream.cpp index d95416f..2210e99 100644 --- a/flashlight/fl/runtime/Stream.cpp +++ b/flashlight/fl/runtime/Stream.cpp @@ -10,10 +10,11 @@ namespace fl { void Stream::relativeSync( - const std::unordered_set& waitOns) const { - for (const auto* waitOn : waitOns) { - this->relativeSync(*waitOn); - } + const std::unordered_set& waitOns +) const { + for(const auto* waitOn : waitOns) { + this->relativeSync(*waitOn); + } } } // namespace fl diff --git a/flashlight/fl/runtime/Stream.h b/flashlight/fl/runtime/Stream.h index d8fbd5f..52a824c 100644 --- a/flashlight/fl/runtime/Stream.h +++ b/flashlight/fl/runtime/Stream.h @@ -17,8 +17,8 @@ namespace fl { class Device; enum class StreamType { - CUDA, - Synchronous, + CUDA, + Synchronous, }; /** @@ -27,78 +27,80 @@ enum class StreamType { * computations, while being agnostic to the computations themselves. */ class FL_API Stream { - public: - Stream() = default; - virtual ~Stream() = default; - - // no copy/move - Stream(const Stream&) = delete; - Stream(Stream&&) = delete; - Stream& operator=(const Stream&) = delete; - Stream& operator=(const Stream&&) = delete; - - /** - * Get the underlying implementation of this stream. - * - * Throws invalid_argument if the specified type does not match the actual - * derived stream type. - * - * @return an immutable reference to the specified stream type. - */ - template - const T& impl() const { - if (T::type != type()) { - throw std::invalid_argument( - "[fl::Stream::impl] " - "specified stream type doesn't match actual stream type."); +public: + Stream() = default; + virtual ~Stream() = default; + + // no copy/move + Stream(const Stream&) = delete; + Stream(Stream&&) = delete; + Stream& operator=(const Stream&) = delete; + Stream& operator=(const Stream&&) = delete; + + /** + * Get the underlying implementation of this stream. 
+ * + * Throws invalid_argument if the specified type does not match the actual + * derived stream type. + * + * @return an immutable reference to the specified stream type. + */ + template + const T& impl() const { + if(T::type != type()) { + throw std::invalid_argument( + "[fl::Stream::impl] " + "specified stream type doesn't match actual stream type." + ); + } + return *(static_cast(this)); } - return *(static_cast(this)); - } - - /** - * Returns the type of this stream. - * - * @return a enum denoting stream type. - */ - virtual StreamType type() const = 0; - - /** - * Return the owner device of this stream. - * - * @return a reference to the owner device of this stream. - */ - virtual Device& device() = 0; - - /** - * Return the owner device of this stream. - * - * @return an immutable reference to the owner device of this stream. - */ - virtual const Device& device() const = 0; - - /** - * Block calling thread and synchronize w.r.t. all tasks on this stream. - */ - virtual void sync() const = 0; - - /** - * Synchronize future tasks on this stream w.r.t. current tasks on given - * stream, i.e., the former can only start after the completion of the latter. - * NOTE this function may or may not block the calling thread. - * - * @param[in] waitOn the stream to perform relative synchronization against. - */ - virtual void relativeSync(const Stream& waitOn) const = 0; - - /** - * Synchronize future tasks on this stream w.r.t. current tasks on all given - * stream, i.e., the former can only start after the completion of the latter. - * NOTE this function may or may not block the calling thread. - * - * @param[in] waitOns the streams to perform relative synchronization against. - */ - virtual void relativeSync( - const std::unordered_set& waitOns) const; + + /** + * Returns the type of this stream. + * + * @return a enum denoting stream type. + */ + virtual StreamType type() const = 0; + + /** + * Return the owner device of this stream. 
+ * + * @return a reference to the owner device of this stream. + */ + virtual Device& device() = 0; + + /** + * Return the owner device of this stream. + * + * @return an immutable reference to the owner device of this stream. + */ + virtual const Device& device() const = 0; + + /** + * Block calling thread and synchronize w.r.t. all tasks on this stream. + */ + virtual void sync() const = 0; + + /** + * Synchronize future tasks on this stream w.r.t. current tasks on given + * stream, i.e., the former can only start after the completion of the latter. + * NOTE this function may or may not block the calling thread. + * + * @param[in] waitOn the stream to perform relative synchronization against. + */ + virtual void relativeSync(const Stream& waitOn) const = 0; + + /** + * Synchronize future tasks on this stream w.r.t. current tasks on all given + * stream, i.e., the former can only start after the completion of the latter. + * NOTE this function may or may not block the calling thread. + * + * @param[in] waitOns the streams to perform relative synchronization against. + */ + virtual void relativeSync( + const std::unordered_set& waitOns + ) const; }; /** @@ -107,29 +109,30 @@ class FL_API Stream { * REQUIRED definition in derived class: * static StreamType type; */ -template +template class StreamTrait : public Stream { - public: - // prevent name hiding - using Stream::relativeSync; - - // A specialized relativeSync for streams of the same type. - virtual void relativeSync(const Derived& waitOn) const = 0; - - StreamType type() const override { - return Derived::type; - } - - virtual void relativeSync(const Stream& waitOn) const override { - switch (waitOn.type()) { - case Derived::type: - relativeSync(waitOn.impl()); - break; - default: - throw std::runtime_error( - "[Stream::relativeSync] Unsupported for different types of streams"); +public: + // prevent name hiding + using Stream::relativeSync; + + // A specialized relativeSync for streams of the same type. 
+ virtual void relativeSync(const Derived& waitOn) const = 0; + + StreamType type() const override { + return Derived::type; + } + + virtual void relativeSync(const Stream& waitOn) const override { + switch(waitOn.type()) { + case Derived::type: + relativeSync(waitOn.impl()); + break; + default: + throw std::runtime_error( + "[Stream::relativeSync] Unsupported for different types of streams" + ); + } } - } }; } // namespace fl diff --git a/flashlight/fl/runtime/SynchronousStream.cpp b/flashlight/fl/runtime/SynchronousStream.cpp index 0e760a9..b3013c7 100644 --- a/flashlight/fl/runtime/SynchronousStream.cpp +++ b/flashlight/fl/runtime/SynchronousStream.cpp @@ -10,15 +10,15 @@ namespace fl { X64Device& SynchronousStream::device() { - return device_; + return device_; } const X64Device& SynchronousStream::device() const { - return device_; + return device_; } void SynchronousStream::relativeSync(const SynchronousStream& waitOn) const { - waitOn.sync(); + waitOn.sync(); } } // namespace fl diff --git a/flashlight/fl/runtime/SynchronousStream.h b/flashlight/fl/runtime/SynchronousStream.h index a5deb02..8e0081b 100644 --- a/flashlight/fl/runtime/SynchronousStream.h +++ b/flashlight/fl/runtime/SynchronousStream.h @@ -17,20 +17,20 @@ namespace fl { * relative synchronization strategy, i.e., it merely delegates to `sync`. 
*/ class FL_API SynchronousStream : public StreamTrait { - protected: - X64Device& device_{DeviceManager::getInstance() - .getActiveDevice(DeviceType::x64) - .impl()}; +protected: + X64Device& device_{DeviceManager::getInstance() + .getActiveDevice(DeviceType::x64) + .impl()}; - public: - // prevent name hiding - using StreamTrait::relativeSync; +public: + // prevent name hiding + using StreamTrait::relativeSync; - static constexpr StreamType type = StreamType::Synchronous; + static constexpr StreamType type = StreamType::Synchronous; - X64Device& device() override; - const X64Device& device() const override; - void relativeSync(const SynchronousStream& waitOn) const override; + X64Device& device() override; + const X64Device& device() const override; + void relativeSync(const SynchronousStream& waitOn) const override; }; } // namespace fl diff --git a/flashlight/fl/tensor/Compute.cpp b/flashlight/fl/tensor/Compute.cpp index f2a2e99..0a981b2 100644 --- a/flashlight/fl/tensor/Compute.cpp +++ b/flashlight/fl/tensor/Compute.cpp @@ -19,115 +19,119 @@ namespace fl { namespace { -std::unordered_set tensorsToUniqueStreams( - const std::vector& tensors) { - std::unordered_set uniqueStreams; - for (const auto& tensor : tensors) { - uniqueStreams.insert(&tensor.stream()); - } - return uniqueStreams; -} - -std::unordered_set tensorsToUniqueStreams( - const std::vector& tensors) { - std::unordered_set uniqueStreams; - for (const auto& tensor : tensors) { - uniqueStreams.insert(&tensor->stream()); - } - return uniqueStreams; -} + std::unordered_set tensorsToUniqueStreams( + const std::vector& tensors + ) { + std::unordered_set uniqueStreams; + for(const auto& tensor : tensors) { + uniqueStreams.insert(&tensor.stream()); + } + return uniqueStreams; + } + + std::unordered_set tensorsToUniqueStreams( + const std::vector& tensors + ) { + std::unordered_set uniqueStreams; + for(const auto& tensor : tensors) { + uniqueStreams.insert(&tensor->stream()); + } + return uniqueStreams; + } 
} // namespace void sync() { - DeviceManager::getInstance().getActiveDevice(fl::kDefaultDeviceType).sync(); + DeviceManager::getInstance().getActiveDevice(fl::kDefaultDeviceType).sync(); } void sync(const int deviceId) { - DeviceManager::getInstance() - .getDevice(fl::kDefaultDeviceType, deviceId) - .sync(); + DeviceManager::getInstance() + .getDevice(fl::kDefaultDeviceType, deviceId) + .sync(); } void sync(const std::unordered_set& types) { - const auto& manager = DeviceManager::getInstance(); - // TODO consider launching these `Device::sync` calls non-blockingly - for (const auto type : types) { - manager.getActiveDevice(type).sync(); - } + const auto& manager = DeviceManager::getInstance(); + // TODO consider launching these `Device::sync` calls non-blockingly + for(const auto type : types) { + manager.getActiveDevice(type).sync(); + } } void sync(const std::unordered_set& devices) { - // TODO consider launching these `Device::sync` calls non-blockingly - for (const auto* device : devices) { - device->sync(); - } + // TODO consider launching these `Device::sync` calls non-blockingly + for(const auto* device : devices) { + device->sync(); + } } void relativeSync( const Stream& wait, - const std::vector& waitOns) { - // ensure computations are launched - for (const auto* tensor : waitOns) { - tensor->backend().eval(*tensor); - } - wait.relativeSync(tensorsToUniqueStreams(waitOns)); + const std::vector& waitOns +) { + // ensure computations are launched + for(const auto* tensor : waitOns) { + tensor->backend().eval(*tensor); + } + wait.relativeSync(tensorsToUniqueStreams(waitOns)); } void relativeSync(const Stream& wait, const std::vector& waitOns) { - // ensure computations are launched - for (const auto& tensor : waitOns) { - tensor.backend().eval(tensor); - } - wait.relativeSync(tensorsToUniqueStreams(waitOns)); + // ensure computations are launched + for(const auto& tensor : waitOns) { + tensor.backend().eval(tensor); + } + 
wait.relativeSync(tensorsToUniqueStreams(waitOns)); } void relativeSync(const std::vector& waits, const Stream& waitOn) { - for (const auto& stream : tensorsToUniqueStreams(waits)) { - stream->relativeSync(waitOn); - } + for(const auto& stream : tensorsToUniqueStreams(waits)) { + stream->relativeSync(waitOn); + } } void eval(Tensor& tensor) { - tensor.backend().eval(tensor); + tensor.backend().eval(tensor); } int getDevice() { - return DeviceManager::getInstance() - .getActiveDevice(fl::kDefaultDeviceType) - .nativeId(); + return DeviceManager::getInstance() + .getActiveDevice(fl::kDefaultDeviceType) + .nativeId(); } void setDevice(const int deviceId) { - DeviceManager::getInstance() - .getDevice(fl::kDefaultDeviceType, deviceId) - .setActive(); + DeviceManager::getInstance() + .getDevice(fl::kDefaultDeviceType, deviceId) + .setActive(); } int getDeviceCount() { - return DeviceManager::getInstance().getDeviceCount(fl::kDefaultDeviceType); + return DeviceManager::getInstance().getDeviceCount(fl::kDefaultDeviceType); } namespace detail { -void getMemMgrInfo( - const char* msg, - const int deviceId, - std::ostream* ostream /* = &std::cout */) { - defaultTensorBackend().getMemMgrInfo(msg, deviceId, ostream); -} - -void setMemMgrLogStream(std::ostream* stream) { - defaultTensorBackend().setMemMgrLogStream(stream); -} - -void setMemMgrLoggingEnabled(const bool enabled) { - defaultTensorBackend().setMemMgrLoggingEnabled(enabled); -} - -void setMemMgrFlushInterval(const size_t interval) { - defaultTensorBackend().setMemMgrFlushInterval(interval); -} + void getMemMgrInfo( + const char* msg, + const int deviceId, + std::ostream* ostream /* = &std::cout */ + ) { + defaultTensorBackend().getMemMgrInfo(msg, deviceId, ostream); + } + + void setMemMgrLogStream(std::ostream* stream) { + defaultTensorBackend().setMemMgrLogStream(stream); + } + + void setMemMgrLoggingEnabled(const bool enabled) { + defaultTensorBackend().setMemMgrLoggingEnabled(enabled); + } + + void 
setMemMgrFlushInterval(const size_t interval) { + defaultTensorBackend().setMemMgrFlushInterval(interval); + } } // namespace detail diff --git a/flashlight/fl/tensor/Compute.h b/flashlight/fl/tensor/Compute.h index dc3f5fc..76f1eca 100644 --- a/flashlight/fl/tensor/Compute.h +++ b/flashlight/fl/tensor/Compute.h @@ -68,7 +68,8 @@ FL_API void sync(const std::unordered_set& devices); */ FL_API void relativeSync( const Stream& wait, - const std::vector& waitOns); + const std::vector& waitOns +); /** * Synchronize future tasks on given stream w.r.t. current tasks on all unique @@ -82,7 +83,8 @@ FL_API void relativeSync( */ FL_API void relativeSync( const Stream& wait, - const std::vector& waitOns); + const std::vector& waitOns +); /** * Synchronize future tasks on the streams of `waits` w.r.t. current task on @@ -95,7 +97,8 @@ FL_API void relativeSync( */ FL_API void relativeSync( const std::vector& waits, - const Stream& waitOn); + const Stream& waitOn +); /** * Launches computation, [usually] asynchronously, on operations needed to make @@ -155,10 +158,11 @@ namespace detail { * This function may be a noop for backends that do not implement memory * managers with configurable logging. */ -FL_API void getMemMgrInfo( - const char* msg, - const int deviceId, - std::ostream* ostream = &std::cout); + FL_API void getMemMgrInfo( + const char* msg, + const int deviceId, + std::ostream* ostream = & std::cout + ); /** * Configures memory manager log output to write to a specified output stream. @@ -169,7 +173,7 @@ FL_API void getMemMgrInfo( * * @returns the number of active devices usable in Flashlight. */ -FL_API void setMemMgrLogStream(std::ostream* stream); + FL_API void setMemMgrLogStream(std::ostream* stream); /** * Sets (or unsets) logging for memory management. This function may be a noop @@ -180,7 +184,7 @@ FL_API void setMemMgrLogStream(std::ostream* stream); * * @param[in] enabled true to enable logging, false to disable. 
*/ -FL_API void setMemMgrLoggingEnabled(const bool enabled); + FL_API void setMemMgrLoggingEnabled(const bool enabled); /** * Configures memory manager log output to flush to the output stream after a @@ -192,7 +196,7 @@ FL_API void setMemMgrLoggingEnabled(const bool enabled); * @param[in] interval the number of lines after which to flush the temporary * log buffer. Supplied interval must be greater than 1. */ -FL_API void setMemMgrFlushInterval(const size_t interval); + FL_API void setMemMgrFlushInterval(const size_t interval); } // namespace detail } // namespace fl diff --git a/flashlight/fl/tensor/DefaultTensorType.cpp b/flashlight/fl/tensor/DefaultTensorType.cpp index f81fd2f..f2b5b18 100644 --- a/flashlight/fl/tensor/DefaultTensorType.cpp +++ b/flashlight/fl/tensor/DefaultTensorType.cpp @@ -10,8 +10,8 @@ namespace fl { TensorBackend& defaultTensorBackend() { - // TODO: improve this implementation! Hacky/requires creating a tensor - return Tensor().backend(); + // TODO: improve this implementation! 
Hacky/requires creating a tensor + return Tensor().backend(); } } // namespace fl diff --git a/flashlight/fl/tensor/DefaultTensorType.h b/flashlight/fl/tensor/DefaultTensorType.h index 3ec81bb..87ac531 100644 --- a/flashlight/fl/tensor/DefaultTensorType.h +++ b/flashlight/fl/tensor/DefaultTensorType.h @@ -8,12 +8,12 @@ #pragma once #if FL_USE_ARRAYFIRE - #include "flashlight/fl/tensor/backend/af/ArrayFireBackend.h" - #include "flashlight/fl/tensor/backend/af/ArrayFireTensor.h" +#include "flashlight/fl/tensor/backend/af/ArrayFireBackend.h" +#include "flashlight/fl/tensor/backend/af/ArrayFireTensor.h" #endif #if FL_USE_TENSOR_STUB - #include "flashlight/fl/tensor/backend/stub/StubBackend.h" - #include "flashlight/fl/tensor/backend/stub/StubTensor.h" +#include "flashlight/fl/tensor/backend/stub/StubBackend.h" +#include "flashlight/fl/tensor/backend/stub/StubTensor.h" #endif namespace fl { diff --git a/flashlight/fl/tensor/Index.cpp b/flashlight/fl/tensor/Index.cpp index 883aa7a..e4073dc 100644 --- a/flashlight/fl/tensor/Index.cpp +++ b/flashlight/fl/tensor/Index.cpp @@ -11,63 +11,65 @@ namespace fl { range::range(const Dim& i) : range(0, i) {} -range::range(const Dim& start, const idx& end) - : range(start, end, /* stride */ kDefaultStride) {} - -range::range(const Dim& start, const idx& end, const Dim stride) - : start_(start), - end_( - std::holds_alternative(end) - ? std::nullopt - : std::optional(std::get(end))), - stride_(stride) {} +range::range(const Dim& start, const idx& end) : range(start, end, /* stride */ kDefaultStride) {} + +range::range(const Dim& start, const idx& end, const Dim stride) : start_(start), + end_( + std::holds_alternative(end) + ? 
std::nullopt + : std::optional( + std::get( + end + ) + )), + stride_(stride) {} Dim range::start() const { - return start_; + return start_; } const std::optional& range::end() const { - return end_; + return end_; } Dim range::endVal() const { - if (end_.has_value()) { - return end_.value(); - } - throw std::runtime_error("[range::endVal] end is end_t"); + if(end_.has_value()) { + return end_.value(); + } + throw std::runtime_error("[range::endVal] end is end_t"); } Dim range::stride() const { - return stride_; + return stride_; } bool range::operator==(const range& other) const { - return start_ == other.start() && end_ == other.end() && - stride_ == other.stride(); + return start_ == other.start() && end_ == other.end() + && stride_ == other.stride(); } bool range::operator!=(const range& other) const { - return !this->operator==(other); + return !this->operator==(other); } -Index::Index(const Tensor& tensor) - : type_(detail::IndexType::Tensor), index_(tensor) {} +Index::Index(const Tensor& tensor) : type_(detail::IndexType::Tensor), + index_(tensor) {} -Index::Index(const range& range) - : type_(range == span ? detail::IndexType::Span : detail::IndexType::Range), - index_(range) {} +Index::Index(const range& range) : type_(range == span ? 
detail::IndexType::Span : detail::IndexType::Range), + index_(range) {} -Index::Index(const Dim idx) : type_(detail::IndexType::Literal), index_(idx) {} +Index::Index(const Dim idx) : type_(detail::IndexType::Literal), + index_(idx) {} -Index::Index(Index&& other) noexcept - : type_(other.type_), index_(std::move(other.index_)) {} +Index::Index(Index&& other) noexcept : type_(other.type_), + index_(std::move(other.index_)) {} detail::IndexType Index::type() const { - return type_; + return type_; } bool Index::isSpan() const { - return type_ == detail::IndexType::Span; + return type_ == detail::IndexType::Span; } } // namespace fl diff --git a/flashlight/fl/tensor/Index.h b/flashlight/fl/tensor/Index.h index a6204fb..d5f147a 100644 --- a/flashlight/fl/tensor/Index.h +++ b/flashlight/fl/tensor/Index.h @@ -37,53 +37,53 @@ static const end_t end = end_t(); * ------------------------- */ class FL_API range { - using idx = std::variant; - static constexpr Dim kDefaultStride = 1; - - Dim start_{0}; - // end is exclusive; std::nullopt means including the last element - std::optional end_{std::nullopt}; - Dim stride_{kDefaultStride}; - - public: - /** - * Default ctor. - */ - range() = default; - - /** - * Construct a range with the indices [0, idx) (i.e. [0, idx - 1]) - * - * @param[in] idx the end index of the range, which will start from 0 - */ - explicit range(const Dim& idx); - - /** - * Construct a range with the indices [start, end) (i.e. [start, end - 1]) - * - * @param[in] start the starting index of the range - * @param[in] end the end index of the range, which will start from 0 - */ - range(const Dim& start, const idx& end); - - /** - * Construct a range with the indices [start, end) (i.e. [start, end - 1]) - * with the given stride. 
- * - * @param[in] start the starting index of the range - * @param[in] end the end index of the range, which will start from 0 - * @param[in] stride the interval over which successive range elements appear - */ - range(const Dim& start, const idx& end, const Dim stride); - - Dim start() const; - // std::nullopt represents `end_t` - const std::optional& end() const; - // throw if end is `end_t` - Dim endVal() const; - Dim stride() const; - bool operator==(const range& other) const; - bool operator!=(const range& other) const; + using idx = std::variant; + static constexpr Dim kDefaultStride = 1; + + Dim start_{0}; + // end is exclusive; std::nullopt means including the last element + std::optional end_{std::nullopt}; + Dim stride_{kDefaultStride}; + +public: + /** + * Default ctor. + */ + range() = default; + + /** + * Construct a range with the indices [0, idx) (i.e. [0, idx - 1]) + * + * @param[in] idx the end index of the range, which will start from 0 + */ + explicit range(const Dim& idx); + + /** + * Construct a range with the indices [start, end) (i.e. [start, end - 1]) + * + * @param[in] start the starting index of the range + * @param[in] end the end index of the range, which will start from 0 + */ + range(const Dim& start, const idx& end); + + /** + * Construct a range with the indices [start, end) (i.e. [start, end - 1]) + * with the given stride. 
+ * + * @param[in] start the starting index of the range + * @param[in] end the end index of the range, which will start from 0 + * @param[in] stride the interval over which successive range elements appear + */ + range(const Dim& start, const idx& end, const Dim stride); + + Dim start() const; + // std::nullopt represents `end_t` + const std::optional& end() const; + // throw if end is `end_t` + Dim endVal() const; + Dim stride() const; + bool operator==(const range& other) const; + bool operator!=(const range& other) const; }; // span is an instance of range @@ -94,7 +94,7 @@ namespace detail { /** * Allowed indexing operators. */ -enum class IndexType : int { Tensor = 0, Range = 1, Literal = 2, Span = 3 }; + enum class IndexType : int {Tensor = 0, Range = 1, Literal = 2, Span = 3}; } // namespace detail @@ -110,64 +110,64 @@ enum class IndexType : int { Tensor = 0, Range = 1, Literal = 2, Span = 3 }; * indexed. */ struct FL_API Index { - using IndexVariant = std::variant; - - private: - // The type of indexing operator. - detail::IndexType type_; - - // Underlying data referred to by the index - IndexVariant index_; - - // Intentionally private - Index() = default; - - public: - /* implicit */ Index(const Tensor& tensor); - /* implicit */ Index(const range& range); - /* implicit */ Index(const Dim idx); - - /** - * Default copy assignment operator. - */ - Index& operator=(const Index&) = default; - - /** - * Move constructor - moves the index data. - */ - Index(Index&& index) noexcept; - Index(const Index& index) = default; - - /** - * Get the index type for this index. - * - * @return the index type. - */ - detail::IndexType type() const; - - /** - * Returns true if the index represents a span. - */ - bool isSpan() const; - - /** - * Get the internal data for a particular Index. Parameterized by type. Will - * throw as per std::variant if the type doesn't match this Index's underlying - * type. 
- */ - template - const T& get() const { - return std::get(index_); - } - - template - T& get() { - return std::get(index_); - } - - IndexVariant getVariant() const { - return index_; - } + using IndexVariant = std::variant; + +private: + // The type of indexing operator. + detail::IndexType type_; + + // Underlying data referred to by the index + IndexVariant index_; + + // Intentionally private + Index() = default; + +public: + /* implicit */ Index(const Tensor& tensor); + /* implicit */ Index(const range& range); + /* implicit */ Index(const Dim idx); + + /** + * Default copy assignment operator. + */ + Index& operator=(const Index&) = default; + + /** + * Move constructor - moves the index data. + */ + Index(Index && index) noexcept; + Index(const Index& index) = default; + + /** + * Get the index type for this index. + * + * @return the index type. + */ + detail::IndexType type() const; + + /** + * Returns true if the index represents a span. + */ + bool isSpan() const; + + /** + * Get the internal data for a particular Index. Parameterized by type. Will + * throw as per std::variant if the type doesn't match this Index's underlying + * type. + */ + template + const T& get() const { + return std::get(index_); + } + + template + T& get() { + return std::get(index_); + } + + IndexVariant getVariant() const { + return index_; + } }; } // namespace fl diff --git a/flashlight/fl/tensor/Init.cpp b/flashlight/fl/tensor/Init.cpp index 8da7fb2..0a01091 100644 --- a/flashlight/fl/tensor/Init.cpp +++ b/flashlight/fl/tensor/Init.cpp @@ -15,7 +15,7 @@ namespace fl { namespace { -std::once_flag flInitFlag; + std::once_flag flInitFlag; } /** @@ -27,10 +27,13 @@ std::once_flag flInitFlag; * Body is only run once per process. Subsequent calls will be noops. 
*/ void init() { - std::call_once(flInitFlag, []() { - defaultTensorBackend(); - initLogging(); - }); + std::call_once( + flInitFlag, + []() { + defaultTensorBackend(); + initLogging(); + } + ); } } // namespace fl diff --git a/flashlight/fl/tensor/Random.cpp b/flashlight/fl/tensor/Random.cpp index d15355d..e4f69c1 100644 --- a/flashlight/fl/tensor/Random.cpp +++ b/flashlight/fl/tensor/Random.cpp @@ -14,15 +14,15 @@ namespace fl { void setSeed(const int seed) { - defaultTensorBackend().setSeed(seed); + defaultTensorBackend().setSeed(seed); } Tensor randn(const Shape& shape, dtype type) { - return defaultTensorBackend().randn(shape, type); + return defaultTensorBackend().randn(shape, type); } Tensor rand(const Shape& shape, dtype type) { - return defaultTensorBackend().rand(shape, type); + return defaultTensorBackend().rand(shape, type); } } // namespace fl diff --git a/flashlight/fl/tensor/Shape.cpp b/flashlight/fl/tensor/Shape.cpp index 489808d..7fb21f9 100644 --- a/flashlight/fl/tensor/Shape.cpp +++ b/flashlight/fl/tensor/Shape.cpp @@ -21,79 +21,79 @@ Shape::Shape(std::initializer_list d) : Shape(std::vector(d)) {} const Dim kEmptyShapeNumberOfElements = 1; void Shape::checkDimsOrThrow(const size_t dim) const { - if (dim > ndim() - 1) { - std::stringstream ss; - ss << "Shape index " << std::to_string(dim) - << " out of bounds for shape with " << std::to_string(dims_.size()) - << " dimensions."; - throw std::invalid_argument(ss.str()); - } + if(dim > ndim() - 1) { + std::stringstream ss; + ss << "Shape index " << std::to_string(dim) + << " out of bounds for shape with " << std::to_string(dims_.size()) + << " dimensions."; + throw std::invalid_argument(ss.str()); + } } Dim Shape::elements() const { - if (dims_.empty()) { - return kEmptyShapeNumberOfElements; - } - return std::accumulate(dims_.begin(), dims_.end(), static_cast(1), std::multiplies()); + if(dims_.empty()) { + return kEmptyShapeNumberOfElements; + } + return std::accumulate(dims_.begin(), dims_.end(), 
static_cast(1), std::multiplies()); } int Shape::ndim() const { - return dims_.size(); + return dims_.size(); } Dim Shape::dim(const size_t dim) const { - checkDimsOrThrow(dim); - return dims_[dim]; + checkDimsOrThrow(dim); + return dims_[dim]; } Dim& Shape::operator[](const size_t dim) { - checkDimsOrThrow(dim); - return dims_[dim]; + checkDimsOrThrow(dim); + return dims_[dim]; } const Dim& Shape::operator[](const size_t dim) const { - checkDimsOrThrow(dim); - return dims_[dim]; + checkDimsOrThrow(dim); + return dims_[dim]; } bool Shape::operator==(const Shape& other) const { - return dims_ == other.dims_; + return dims_ == other.dims_; } bool Shape::operator!=(const Shape& other) const { - return !(this->operator==(other)); + return !(this->operator==(other)); } bool Shape::operator==(const std::initializer_list& other) const { - return dims_.size() == other.size() && - std::equal(std::begin(dims_), std::end(dims_), std::begin(other)); + return dims_.size() == other.size() + && std::equal(std::begin(dims_), std::end(dims_), std::begin(other)); } bool Shape::operator!=(const std::initializer_list& other) const { - return !(this->operator==(other)); + return !(this->operator==(other)); } const std::vector& Shape::get() const { - return dims_; + return dims_; } std::vector& Shape::get() { - return dims_; + return dims_; }; std::string Shape::toString() const { - std::stringstream ss; - ss << "("; - for (size_t i = 0; i < ndim(); ++i) { - ss << dim(i) << (i == ndim() - 1 ? "" : ", "); - } - ss << ")"; - return ss.str(); + std::stringstream ss; + ss << "("; + for(size_t i = 0; i < ndim(); ++i) { + ss << dim(i) << (i == ndim() - 1 ? 
"" : ", "); + } + ss << ")"; + return ss.str(); } std::ostream& operator<<(std::ostream& ostr, const Shape& s) { - ostr << s.toString(); - return ostr; + ostr << s.toString(); + return ostr; } } // namespace fl diff --git a/flashlight/fl/tensor/Shape.h b/flashlight/fl/tensor/Shape.h index 6e8fa95..7f6ab41 100644 --- a/flashlight/fl/tensor/Shape.h +++ b/flashlight/fl/tensor/Shape.h @@ -42,84 +42,84 @@ using Dim = long long; * backing storage or handles. */ class FL_API Shape { - // Storage for the dimension values. Defaults to an empty Shape {0}, whereas - // {} is a scalar shape. - std::vector dims_; - - /** - * Check if a dimension is valid (i.e. in bounds) given the current size of - * the shape. If not valid, throws an exception. - */ - void checkDimsOrThrow(const size_t dim) const; - - public: - Shape() = default; - ~Shape() = default; - /** - * Gives the maximum number of dimensions a tensor of a particular shape can - * have. - * - * If the maximum size can be arbitrarily high, `std::numeric_limits` - * should be used. - */ - static constexpr size_t kMaxDims = std::numeric_limits::max(); - - /** - * Initialize a Shape via a vector. - */ - explicit Shape(std::vector d); - - /** - * Initialize a Shape via an initializer list. - */ - /* implicit */ Shape(std::initializer_list d); - - /** - * @return the number of elements in a tensor that has the given shape. - */ - Dim elements() const; - - /** - * @return Number of dimensions in the shape. - */ - int ndim() const; - - /** - * Get the size of a given dimension in the number of arguments. Throws if the - * given dimension is larger than the number of dimensions. - * - * @return the number of elements at the given dimension - */ - Dim dim(const size_t dim) const; - - /** - * Returns a reference to the given index - */ - Dim& operator[](const size_t dim); - const Dim& operator[](const size_t dim) const; - - /** - * Compares two shapes. Returns true if their dim vectors are equal. 
- */ - bool operator==(const Shape& other) const; - bool operator!=(const Shape& other) const; - - /** - * Compare a shape to an initializer list. - */ - bool operator==(const std::initializer_list& other) const; - bool operator!=(const std::initializer_list& other) const; - - /** - * Gets a reference to the underying dims vector. - */ - const std::vector& get() const; - std::vector& get(); - - /** - * Returns a string representation of the Shape - */ - std::string toString() const; + // Storage for the dimension values. Defaults to an empty Shape {0}, whereas + // {} is a scalar shape. + std::vector dims_; + + /** + * Check if a dimension is valid (i.e. in bounds) given the current size of + * the shape. If not valid, throws an exception. + */ + void checkDimsOrThrow(const size_t dim) const; + +public: + Shape() = default; + ~Shape() = default; + /** + * Gives the maximum number of dimensions a tensor of a particular shape can + * have. + * + * If the maximum size can be arbitrarily high, `std::numeric_limits` + * should be used. + */ + static constexpr size_t kMaxDims = std::numeric_limits::max(); + + /** + * Initialize a Shape via a vector. + */ + explicit Shape(std::vector d); + + /** + * Initialize a Shape via an initializer list. + */ + /* implicit */ Shape(std::initializer_list d); + + /** + * @return the number of elements in a tensor that has the given shape. + */ + Dim elements() const; + + /** + * @return Number of dimensions in the shape. + */ + int ndim() const; + + /** + * Get the size of a given dimension in the number of arguments. Throws if the + * given dimension is larger than the number of dimensions. + * + * @return the number of elements at the given dimension + */ + Dim dim(const size_t dim) const; + + /** + * Returns a reference to the given index + */ + Dim& operator[](const size_t dim); + const Dim& operator[](const size_t dim) const; + + /** + * Compares two shapes. Returns true if their dim vectors are equal. 
+ */ + bool operator==(const Shape& other) const; + bool operator!=(const Shape& other) const; + + /** + * Compare a shape to an initializer list. + */ + bool operator==(const std::initializer_list& other) const; + bool operator!=(const std::initializer_list& other) const; + + /** + * Gets a reference to the underying dims vector. + */ + const std::vector& get() const; + std::vector& get(); + + /** + * Returns a string representation of the Shape + */ + std::string toString() const; }; /** diff --git a/flashlight/fl/tensor/TensorAdapter.cpp b/flashlight/fl/tensor/TensorAdapter.cpp index 77f34ee..68d2c55 100644 --- a/flashlight/fl/tensor/TensorAdapter.cpp +++ b/flashlight/fl/tensor/TensorAdapter.cpp @@ -19,32 +19,34 @@ namespace fl::detail { DefaultTensorType& DefaultTensorType::getInstance() { - static DefaultTensorType instance; - return instance; + static DefaultTensorType instance; + return instance; } DefaultTensorType::DefaultTensorType() { - // Resolve the default backend in order of preference/availability - // See DefaultTensorType.h + // Resolve the default backend in order of preference/availability + // See DefaultTensorType.h #if FL_DEFAULT_BACKEND_COMPILE_FLAG - creationFunc_ = std::make_unique>(); + creationFunc_ = std::make_unique>(); #else - throw std::runtime_error( - "Cannot construct DefaultTensorType singleton: Flashlight built " - "without an available tensor backend."); + throw std::runtime_error( + "Cannot construct DefaultTensorType singleton: Flashlight built " + "without an available tensor backend." 
+ ); #endif } std::unique_ptr DefaultTensorType::swap( - std::unique_ptr creator) noexcept { - std::unique_ptr old = std::move(creationFunc_); - creationFunc_ = std::move(creator); - return old; + std::unique_ptr creator +) noexcept { + std::unique_ptr old = std::move(creationFunc_); + creationFunc_ = std::move(creator); + return old; } const TensorCreator& DefaultTensorType::getTensorCreator() const { - return *creationFunc_; + return *creationFunc_; } } // namespace fl diff --git a/flashlight/fl/tensor/TensorAdapter.h b/flashlight/fl/tensor/TensorAdapter.h index fc7e61f..189f807 100644 --- a/flashlight/fl/tensor/TensorAdapter.h +++ b/flashlight/fl/tensor/TensorAdapter.h @@ -22,9 +22,9 @@ namespace fl { * @param[in] t the thing to convert * @return a Tensor containing the ArrayFire array */ -template +template Tensor toTensor(T&&... t) { - return Tensor(std::make_unique(std::forward(t)...)); + return Tensor(std::make_unique(std::forward(t)...)); } /** @@ -39,225 +39,227 @@ Tensor toTensor(T&&... t) { * Tensor or other interfaces. */ class FL_API TensorAdapterBase { - public: - TensorAdapterBase() = default; - virtual ~TensorAdapterBase() = default; - - /** - * Construct a tensor from some existing data. - * - * @param[in] shape the shape of the new tensor - * @param[in] ptr the buffer containing underlying tensor data - * @param[in] type the type of the new tensor - * @param[in] memoryLocation the location of the buffer - */ - TensorAdapterBase( - const Shape& shape, - fl::dtype type, - void* ptr, - MemoryLocation memoryLocation); - - TensorAdapterBase( - const Dim nRows, - const Dim nCols, - const Tensor& values, - const Tensor& rowIdx, - const Tensor& colIdx, - StorageType storageType); - - /** - * Copies the tensor adapter. The copy is not required to be eager -- the - * implementation can use copy-on-write. - */ - virtual std::unique_ptr clone() const = 0; - - /** - * Gets the tensor's associated backend. 
- * - * @return TensorBackendType enum associated with the backend - */ - virtual TensorBackendType backendType() const = 0; - - /** - * Gets the backend for a tensor with this adapter implementation. - * - * @return the TensorBackend instance backing this particular tensor. - */ - virtual TensorBackend& backend() const = 0; - - /** - * Deep copy the tensor, including underlying data. - */ - virtual Tensor copy() = 0; - - /** - * Shallow copy the tensor - return a tensor that points to the same - * underlying data. - */ - virtual Tensor shallowCopy() = 0; - - /** - * Get the shape of a tensor. - * - * @return the shape of the tensor - */ - virtual const Shape& shape() = 0; - - /** - * Get the data type of tensor. - * - * @return the dtype of the tensor - */ - virtual dtype type() = 0; - - /** - * Returns if the tensor is sparse. - * - * @return true if the tensor is sparse, else false - */ - virtual bool isSparse() = 0; - - /** - * Get a tensor's location, host or some device. - * - * @return the tensor's location - */ - virtual Location location() = 0; - - /** - * Populate a pointer with a scalar for the first element of the tensor. - */ - virtual void scalar(void* out) = 0; - - /** - * Returns a pointer to the tensor in device memory - */ - virtual void device(void** out) = 0; - - /** - * Populates a pointer with a pointer value in memory pointing to a host - * buffer containing tensor data. - */ - virtual void host(void* out) = 0; - - /** - * Unlocks any device memory associated with the tensor that was acquired with - * Tensor::device(), making it eligible to be freed. - */ - virtual void unlock() = 0; - - /** - * Returns true if the tensor has been memory-locked per a call to - * Tensor::device(). - * - * @return true if the tensor is locked and a device pointer is active. - */ - virtual bool isLocked() = 0; - - /** - * Returns a bool based on Tensor contiguousness in memory. 
- */ - virtual bool isContiguous() = 0; - - /** - * Get the dimension-wise strides for this tensor - the number of bytes to - * step in each direction when traversing. - */ - virtual Shape strides() = 0; - - /** - * Get the stream which contains(ed) the computation required to realize an - * up-to-date value for this tensor. For instance, `device()` may not yield a - * pointer to the up-to-date value -- to use this pointer, `Stream::sync` or - * `Stream::relativeSync` is required. - * - * @return an immutable reference to the stream that contains(ed) the - * computations which create this tensor. - */ - virtual const Stream& stream() const = 0; - - /** - * Returns a tensor with elements cast as a particular type - * - * @param[in] the type to which to cast the tensor - * @return a tensor with element-wise cast to the new type - */ - virtual Tensor astype(const dtype type) = 0; - - /** - * Index into a tensor with a variable number of indices. - * - * @param[in] indices a vector of Index references - * @return an indexed tensor - */ - virtual Tensor index(const std::vector& indices) = 0; - - /** - * Returns a representation of the tensor in 1 dimension. - * - * @return a 1D version of this tensor - */ - virtual Tensor flatten() const = 0; - - /** - * Returns a tensor indexed from this tensor but indexed as a 1D/flattened - * tensor. - * - * @return a 1D version of this tensor 1D-indexed with the given index. - */ - virtual Tensor flat(const Index& idx) const = 0; - - /** - * Returns a copy of the tensor that is contiguous in memory. - */ - virtual Tensor asContiguousTensor() = 0; - - /** - * Sets arbitrary data on a tensor. May be a no-op for some backends. - */ - virtual void setContext(void* context) = 0; - - /** - * Sets arbitrary data on a tensor. May be a no-op for some backends. - * - * @return An arbitrary payload - */ - virtual void* getContext() = 0; - - /** - * Return a string representation of a Tensor. Not intended to be portable - * across backends. 
- */ - virtual std::string toString() = 0; - - /** - * Write a string representation of a tensor to an output stream. - */ - virtual std::ostream& operator<<(std::ostream& ostr) = 0; - - /******************** Assignment Operators ********************/ +public: + TensorAdapterBase() = default; + virtual ~TensorAdapterBase() = default; + + /** + * Construct a tensor from some existing data. + * + * @param[in] shape the shape of the new tensor + * @param[in] ptr the buffer containing underlying tensor data + * @param[in] type the type of the new tensor + * @param[in] memoryLocation the location of the buffer + */ + TensorAdapterBase( + const Shape& shape, + fl::dtype type, + void* ptr, + MemoryLocation memoryLocation + ); + + TensorAdapterBase( + const Dim nRows, + const Dim nCols, + const Tensor& values, + const Tensor& rowIdx, + const Tensor& colIdx, + StorageType storageType + ); + + /** + * Copies the tensor adapter. The copy is not required to be eager -- the + * implementation can use copy-on-write. + */ + virtual std::unique_ptr clone() const = 0; + + /** + * Gets the tensor's associated backend. + * + * @return TensorBackendType enum associated with the backend + */ + virtual TensorBackendType backendType() const = 0; + + /** + * Gets the backend for a tensor with this adapter implementation. + * + * @return the TensorBackend instance backing this particular tensor. + */ + virtual TensorBackend& backend() const = 0; + + /** + * Deep copy the tensor, including underlying data. + */ + virtual Tensor copy() = 0; + + /** + * Shallow copy the tensor - return a tensor that points to the same + * underlying data. + */ + virtual Tensor shallowCopy() = 0; + + /** + * Get the shape of a tensor. + * + * @return the shape of the tensor + */ + virtual const Shape& shape() = 0; + + /** + * Get the data type of tensor. + * + * @return the dtype of the tensor + */ + virtual dtype type() = 0; + + /** + * Returns if the tensor is sparse. 
+ * + * @return true if the tensor is sparse, else false + */ + virtual bool isSparse() = 0; + + /** + * Get a tensor's location, host or some device. + * + * @return the tensor's location + */ + virtual Location location() = 0; + + /** + * Populate a pointer with a scalar for the first element of the tensor. + */ + virtual void scalar(void* out) = 0; + + /** + * Returns a pointer to the tensor in device memory + */ + virtual void device(void** out) = 0; + + /** + * Populates a pointer with a pointer value in memory pointing to a host + * buffer containing tensor data. + */ + virtual void host(void* out) = 0; + + /** + * Unlocks any device memory associated with the tensor that was acquired with + * Tensor::device(), making it eligible to be freed. + */ + virtual void unlock() = 0; + + /** + * Returns true if the tensor has been memory-locked per a call to + * Tensor::device(). + * + * @return true if the tensor is locked and a device pointer is active. + */ + virtual bool isLocked() = 0; + + /** + * Returns a bool based on Tensor contiguousness in memory. + */ + virtual bool isContiguous() = 0; + + /** + * Get the dimension-wise strides for this tensor - the number of bytes to + * step in each direction when traversing. + */ + virtual Shape strides() = 0; + + /** + * Get the stream which contains(ed) the computation required to realize an + * up-to-date value for this tensor. For instance, `device()` may not yield a + * pointer to the up-to-date value -- to use this pointer, `Stream::sync` or + * `Stream::relativeSync` is required. + * + * @return an immutable reference to the stream that contains(ed) the + * computations which create this tensor. 
+ */ + virtual const Stream& stream() const = 0; + + /** + * Returns a tensor with elements cast as a particular type + * + * @param[in] the type to which to cast the tensor + * @return a tensor with element-wise cast to the new type + */ + virtual Tensor astype(const dtype type) = 0; + + /** + * Index into a tensor with a variable number of indices. + * + * @param[in] indices a vector of Index references + * @return an indexed tensor + */ + virtual Tensor index(const std::vector& indices) = 0; + + /** + * Returns a representation of the tensor in 1 dimension. + * + * @return a 1D version of this tensor + */ + virtual Tensor flatten() const = 0; + + /** + * Returns a tensor indexed from this tensor but indexed as a 1D/flattened + * tensor. + * + * @return a 1D version of this tensor 1D-indexed with the given index. + */ + virtual Tensor flat(const Index& idx) const = 0; + + /** + * Returns a copy of the tensor that is contiguous in memory. + */ + virtual Tensor asContiguousTensor() = 0; + + /** + * Sets arbitrary data on a tensor. May be a no-op for some backends. + */ + virtual void setContext(void* context) = 0; + + /** + * Sets arbitrary data on a tensor. May be a no-op for some backends. + * + * @return An arbitrary payload + */ + virtual void* getContext() = 0; + + /** + * Return a string representation of a Tensor. Not intended to be portable + * across backends. + */ + virtual std::string toString() = 0; + + /** + * Write a string representation of a tensor to an output stream. 
+ */ + virtual std::ostream& operator<<(std::ostream& ostr) = 0; + + /******************** Assignment Operators ********************/ #define ASSIGN_OP_TYPE(OP, TYPE) virtual void OP(const TYPE& val) = 0; -#define ASSIGN_OP(OP) \ - ASSIGN_OP_TYPE(OP, Tensor); \ - ASSIGN_OP_TYPE(OP, double); \ - ASSIGN_OP_TYPE(OP, float); \ - ASSIGN_OP_TYPE(OP, int); \ - ASSIGN_OP_TYPE(OP, unsigned); \ - ASSIGN_OP_TYPE(OP, bool); \ - ASSIGN_OP_TYPE(OP, char); \ - ASSIGN_OP_TYPE(OP, unsigned char); \ - ASSIGN_OP_TYPE(OP, short); \ - ASSIGN_OP_TYPE(OP, unsigned short); \ - ASSIGN_OP_TYPE(OP, long); \ - ASSIGN_OP_TYPE(OP, unsigned long); \ - ASSIGN_OP_TYPE(OP, long long); \ - ASSIGN_OP_TYPE(OP, unsigned long long); - - ASSIGN_OP(assign); // = - ASSIGN_OP(inPlaceAdd); // += - ASSIGN_OP(inPlaceSubtract); // -= - ASSIGN_OP(inPlaceMultiply); // *= - ASSIGN_OP(inPlaceDivide); // /= +#define ASSIGN_OP(OP) \ + ASSIGN_OP_TYPE(OP, Tensor); \ + ASSIGN_OP_TYPE(OP, double); \ + ASSIGN_OP_TYPE(OP, float); \ + ASSIGN_OP_TYPE(OP, int); \ + ASSIGN_OP_TYPE(OP, unsigned); \ + ASSIGN_OP_TYPE(OP, bool); \ + ASSIGN_OP_TYPE(OP, char); \ + ASSIGN_OP_TYPE(OP, unsigned char); \ + ASSIGN_OP_TYPE(OP, short); \ + ASSIGN_OP_TYPE(OP, unsigned short); \ + ASSIGN_OP_TYPE(OP, long); \ + ASSIGN_OP_TYPE(OP, unsigned long); \ + ASSIGN_OP_TYPE(OP, long long); \ + ASSIGN_OP_TYPE(OP, unsigned long long); + + ASSIGN_OP(assign); // = + ASSIGN_OP(inPlaceAdd); // += + ASSIGN_OP(inPlaceSubtract); // -= + ASSIGN_OP(inPlaceMultiply); // *= + ASSIGN_OP(inPlaceDivide); ///= #undef ASSIGN_OP_TYPE #undef ASSIGN_OP }; @@ -268,80 +270,92 @@ namespace detail { * An interface with which to construct a tensor. Templated based on used tensor * adapters. 
*/ -struct FL_API TensorCreator { - virtual ~TensorCreator() = default; - - // General tensor ctor - virtual std::unique_ptr get( - const Shape& shape = {0}, // 0 shape is an empty Tensor - fl::dtype type = fl::dtype::f32, - const void* ptr = nullptr, - MemoryLocation memoryLocation = MemoryLocation::Host) const = 0; - - // Sparse tensor ctor - virtual std::unique_ptr get( - const Dim nRows, - const Dim nCols, - const Tensor& values, - const Tensor& rowIdx, - const Tensor& colIdx, - StorageType storageType) const = 0; -}; - -template -struct TensorCreatorImpl : public TensorCreator { - TensorCreatorImpl() = default; - ~TensorCreatorImpl() override = default; - - std::unique_ptr get( - const Shape& shape = {0}, // 0 shape is an empty Tensor - fl::dtype type = fl::dtype::f32, - const void* ptr = nullptr, - MemoryLocation memoryLocation = MemoryLocation::Host) const override { - return std::make_unique(shape, type, ptr, memoryLocation); - } - - std::unique_ptr get( - const Dim nRows, - const Dim nCols, - const Tensor& values, - const Tensor& rowIdx, - const Tensor& colIdx, - StorageType storageType) const override { - return std::make_unique( - nRows, nCols, values, rowIdx, colIdx, storageType); - } -}; + struct FL_API TensorCreator { + virtual ~TensorCreator() = default; + + // General tensor ctor + virtual std::unique_ptr get( + const Shape& shape = { 0 }, // 0 shape is an empty Tensor + fl::dtype type = fl::dtype::f32, + const void* ptr = nullptr, + MemoryLocation memoryLocation = MemoryLocation::Host + ) const = 0; + + // Sparse tensor ctor + virtual std::unique_ptr get( + const Dim nRows, + const Dim nCols, + const Tensor& values, + const Tensor& rowIdx, + const Tensor& colIdx, + StorageType storageType + ) const = 0; + }; + + template + struct TensorCreatorImpl : public TensorCreator { + TensorCreatorImpl() = default; + ~TensorCreatorImpl() override = default; + + std::unique_ptr get( + const Shape& shape = { 0 }, // 0 shape is an empty Tensor + fl::dtype type = 
fl::dtype::f32, + const void* ptr = nullptr, + MemoryLocation memoryLocation = MemoryLocation::Host + ) const override { + return std::make_unique(shape, type, ptr, memoryLocation); + } + + std::unique_ptr get( + const Dim nRows, + const Dim nCols, + const Tensor& values, + const Tensor& rowIdx, + const Tensor& colIdx, + StorageType storageType + ) const override { + return std::make_unique( + nRows, + nCols, + values, + rowIdx, + colIdx, + storageType + ); + } + }; /* * A singleton to hold a closure which creates a new tensor of default type. For * internal use only - use setDefaultTensorType() to set the type with which * to create a default tensor. */ -class FL_API DefaultTensorType { - // The function to use to create a tensor of default type. - std::unique_ptr creationFunc_; + class FL_API DefaultTensorType { + // The function to use to create a tensor of default type. + std::unique_ptr creationFunc_; - public: - static DefaultTensorType& getInstance(); - DefaultTensorType(); + public: + static DefaultTensorType& getInstance(); + DefaultTensorType(); - std::unique_ptr swap( - std::unique_ptr creator) noexcept; - const TensorCreator& getTensorCreator() const; + std::unique_ptr swap( + std::unique_ptr creator + ) noexcept; + const TensorCreator& getTensorCreator() const; - DefaultTensorType(DefaultTensorType const&) = delete; - void operator=(DefaultTensorType const&) = delete; -}; + DefaultTensorType(DefaultTensorType const&) = delete; + void operator=(DefaultTensorType const&) = delete; + }; /** * Get an instance of the default tensor adapter. */ -template -std::unique_ptr getDefaultAdapter(T&&... t) { - return DefaultTensorType::getInstance().getTensorCreator().get( - std::forward(t)...); -} + template + std::unique_ptr getDefaultAdapter(T&&... t) { + return DefaultTensorType::getInstance().getTensorCreator().get( + std::forward(t)... + ); + } } // namespace detail @@ -355,27 +369,31 @@ std::unique_ptr getDefaultAdapter(T&&... 
t) { * * Where TensorType is derived from TensorAdapterBase. */ -template +template void setDefaultTensorType() { - static_assert( - std::is_base_of::value, - "setDefaultTensorType: T must be a derived type of TensorAdapterBase"); - fl::detail::DefaultTensorType::getInstance().swap( - std::make_unique>()); + static_assert( + std::is_base_of::value, + "setDefaultTensorType: T must be a derived type of TensorAdapterBase" + ); + fl::detail::DefaultTensorType::getInstance().swap( + std::make_unique>() + ); } -template +template void withTensorType(B func) { - static_assert( - std::is_base_of::value, - "withTensorType: T must be a derived type of TensorAdapterBase"); - - // Swap - auto oldCreator = fl::detail::DefaultTensorType::getInstance().swap( - std::make_unique>()); - func(); - // Restore - fl::detail::DefaultTensorType::getInstance().swap(std::move(oldCreator)); + static_assert( + std::is_base_of::value, + "withTensorType: T must be a derived type of TensorAdapterBase" + ); + + // Swap + auto oldCreator = fl::detail::DefaultTensorType::getInstance().swap( + std::make_unique>() + ); + func(); + // Restore + fl::detail::DefaultTensorType::getInstance().swap(std::move(oldCreator)); } } // namespace fl diff --git a/flashlight/fl/tensor/TensorBackend.cpp b/flashlight/fl/tensor/TensorBackend.cpp index a6b4404..114e4de 100644 --- a/flashlight/fl/tensor/TensorBackend.cpp +++ b/flashlight/fl/tensor/TensorBackend.cpp @@ -10,82 +10,94 @@ namespace fl { namespace detail { -bool areBackendsEqual(const Tensor& a, const Tensor& b) { - return a.backendType() == b.backendType(); -} + bool areBackendsEqual(const Tensor& a, const Tensor& b) { + return a.backendType() == b.backendType(); + } } // namespace detail bool TensorBackend::isDataTypeSupported(const fl::dtype& dtype) const { - bool supported = this->supportsDataType(dtype); - for (auto& p : extensions_) { - supported &= p.second->isDataTypeSupported(dtype); - } - return supported; + bool supported = 
this->supportsDataType(dtype); + for(auto& p : extensions_) { + supported &= p.second->isDataTypeSupported(dtype); + } + return supported; } Tensor TensorBackend::clip( const Tensor& tensor, const Tensor& low, - const double& high) { - return clip( - tensor, low, full(tensor.shape(), high, dtype_traits::ctype)); + const double& high +) { + return clip( + tensor, + low, + full(tensor.shape(), high, dtype_traits::ctype) + ); } Tensor TensorBackend::clip( const Tensor& tensor, const double& low, - const Tensor& high) { - return clip( - tensor, full(tensor.shape(), low, dtype_traits::ctype), high); + const Tensor& high +) { + return clip( + tensor, + full(tensor.shape(), low, dtype_traits::ctype), + high + ); } Tensor TensorBackend::clip( const Tensor& tensor, const double& low, - const double& high) { - return clip( - tensor, - full(tensor.shape(), low, dtype_traits::ctype), - full(tensor.shape(), high, dtype_traits::ctype)); + const double& high +) { + return clip( + tensor, + full(tensor.shape(), low, dtype_traits::ctype), + full(tensor.shape(), high, dtype_traits::ctype) + ); } Tensor TensorBackend::where( const Tensor& condition, const Tensor& x, - const double& y) { - return where(condition, x, full(condition.shape(), y, x.type())); + const double& y +) { + return where(condition, x, full(condition.shape(), y, x.type())); } Tensor TensorBackend::where( const Tensor& condition, const double& x, - const Tensor& y) { - return where(condition, full(condition.shape(), x, y.type()), y); + const Tensor& y +) { + return where(condition, full(condition.shape(), x, y.type()), y); } Tensor TensorBackend::minimum(const Tensor& lhs, const double& rhs) { - return minimum(lhs, full(lhs.shape(), rhs, dtype_traits::ctype)); + return minimum(lhs, full(lhs.shape(), rhs, dtype_traits::ctype)); } Tensor TensorBackend::minimum(const double& lhs, const Tensor& rhs) { - return minimum(full(rhs.shape(), lhs, dtype_traits::ctype), rhs); + return minimum(full(rhs.shape(), lhs, 
dtype_traits::ctype), rhs); } Tensor TensorBackend::maximum(const Tensor& lhs, const double& rhs) { - return maximum(lhs, full(lhs.shape(), rhs, dtype_traits::ctype)); + return maximum(lhs, full(lhs.shape(), rhs, dtype_traits::ctype)); } Tensor TensorBackend::maximum(const double& lhs, const Tensor& rhs) { - return maximum(full(rhs.shape(), lhs, dtype_traits::ctype), rhs); + return maximum(full(rhs.shape(), lhs, dtype_traits::ctype), rhs); } Tensor TensorBackend::power(const Tensor& lhs, const double& rhs) { - return power(lhs, full(lhs.shape(), rhs, dtype_traits::ctype)); + return power(lhs, full(lhs.shape(), rhs, dtype_traits::ctype)); } Tensor TensorBackend::power(const double& lhs, const Tensor& rhs) { - return power(full(rhs.shape(), lhs, dtype_traits::ctype), rhs); + return power(full(rhs.shape(), lhs, dtype_traits::ctype), rhs); } } // namespace fl diff --git a/flashlight/fl/tensor/TensorBackend.h b/flashlight/fl/tensor/TensorBackend.h index 5257dab..70fc9e9 100644 --- a/flashlight/fl/tensor/TensorBackend.h +++ b/flashlight/fl/tensor/TensorBackend.h @@ -34,274 +34,263 @@ class Stream; * instance. 
*/ class TensorBackend { - public: - TensorBackend() = default; - virtual ~TensorBackend() = default; - virtual TensorBackendType backendType() const = 0; - - /* -------------------------- Compute Functions -------------------------- */ - virtual void eval(const Tensor& tensor) = 0; - virtual bool supportsDataType(const fl::dtype& dtype) const = 0; - // Memory Management - virtual void - getMemMgrInfo(const char* msg, const int deviceId, std::ostream* ostream) = 0; - virtual void setMemMgrLogStream(std::ostream* stream) = 0; - virtual void setMemMgrLoggingEnabled(const bool enabled) = 0; - virtual void setMemMgrFlushInterval(const size_t interval) = 0; - - /* -------------------------- Rand Functions -------------------------- */ - virtual void setSeed(const int seed) = 0; - virtual Tensor randn(const Shape& shape, dtype type) = 0; - virtual Tensor rand(const Shape& shape, dtype type) = 0; - - /* --------------------------- Tensor Operators --------------------------- - * For operator documentation and expected behavior, see TensorBase.h. 
- */ - /******************** Tensor Creation Functions ********************/ -#define FL_CREATE_FUN_LITERAL_BACKEND_DECL(TYPE) \ - virtual Tensor fromScalar(TYPE value, const dtype type) = 0; \ - virtual Tensor full(const Shape& dims, TYPE value, const dtype type) = 0; - FL_CREATE_FUN_LITERAL_BACKEND_DECL(const double&); - FL_CREATE_FUN_LITERAL_BACKEND_DECL(const float&); - FL_CREATE_FUN_LITERAL_BACKEND_DECL(const int&); - FL_CREATE_FUN_LITERAL_BACKEND_DECL(const unsigned&); - FL_CREATE_FUN_LITERAL_BACKEND_DECL(const char&); - FL_CREATE_FUN_LITERAL_BACKEND_DECL(const unsigned char&); - FL_CREATE_FUN_LITERAL_BACKEND_DECL(const long&); - FL_CREATE_FUN_LITERAL_BACKEND_DECL(const unsigned long&); - FL_CREATE_FUN_LITERAL_BACKEND_DECL(const long long&); - FL_CREATE_FUN_LITERAL_BACKEND_DECL(const unsigned long long&); - FL_CREATE_FUN_LITERAL_BACKEND_DECL(const bool&); - FL_CREATE_FUN_LITERAL_BACKEND_DECL(const short&); - FL_CREATE_FUN_LITERAL_BACKEND_DECL(const unsigned short&); +public: + TensorBackend() = default; + virtual ~TensorBackend() = default; + virtual TensorBackendType backendType() const = 0; + + /* -------------------------- Compute Functions -------------------------- */ + virtual void eval(const Tensor& tensor) = 0; + virtual bool supportsDataType(const fl::dtype& dtype) const = 0; + // Memory Management + virtual void getMemMgrInfo(const char* msg, const int deviceId, std::ostream* ostream) = 0; + virtual void setMemMgrLogStream(std::ostream* stream) = 0; + virtual void setMemMgrLoggingEnabled(const bool enabled) = 0; + virtual void setMemMgrFlushInterval(const size_t interval) = 0; + + /* -------------------------- Rand Functions -------------------------- */ + virtual void setSeed(const int seed) = 0; + virtual Tensor randn(const Shape& shape, dtype type) = 0; + virtual Tensor rand(const Shape& shape, dtype type) = 0; + + /* --------------------------- Tensor Operators --------------------------- + * For operator documentation and expected behavior, see 
TensorBase.h. + */ + /******************** Tensor Creation Functions ********************/ +#define FL_CREATE_FUN_LITERAL_BACKEND_DECL(TYPE) \ + virtual Tensor fromScalar(TYPE value, const dtype type) = 0; \ + virtual Tensor full(const Shape& dims, TYPE value, const dtype type) = 0; + FL_CREATE_FUN_LITERAL_BACKEND_DECL(const double&); + FL_CREATE_FUN_LITERAL_BACKEND_DECL(const float&); + FL_CREATE_FUN_LITERAL_BACKEND_DECL(const int&); + FL_CREATE_FUN_LITERAL_BACKEND_DECL(const unsigned&); + FL_CREATE_FUN_LITERAL_BACKEND_DECL(const char&); + FL_CREATE_FUN_LITERAL_BACKEND_DECL(const unsigned char&); + FL_CREATE_FUN_LITERAL_BACKEND_DECL(const long&); + FL_CREATE_FUN_LITERAL_BACKEND_DECL(const unsigned long&); + FL_CREATE_FUN_LITERAL_BACKEND_DECL(const long long&); + FL_CREATE_FUN_LITERAL_BACKEND_DECL(const unsigned long long&); + FL_CREATE_FUN_LITERAL_BACKEND_DECL(const bool&); + FL_CREATE_FUN_LITERAL_BACKEND_DECL(const short&); + FL_CREATE_FUN_LITERAL_BACKEND_DECL(const unsigned short&); #undef FL_CREATE_FUN_LITERAL_BACKEND_DECL - virtual Tensor identity(const Dim dim, const dtype type) = 0; - virtual Tensor - arange(const Shape& shape, const Dim seqDim, const dtype type) = 0; - virtual Tensor - iota(const Shape& dims, const Shape& tileDims, const dtype type) = 0; - - /************************ Shaping and Indexing *************************/ - virtual Tensor reshape(const Tensor& tensor, const Shape& shape) = 0; - virtual Tensor transpose( - const Tensor& tensor, - const Shape& axes /* = {} */) = 0; - virtual Tensor tile(const Tensor& tensor, const Shape& shape) = 0; - virtual Tensor concatenate( - const std::vector& tensors, - const unsigned axis) = 0; - virtual Tensor nonzero(const Tensor& tensor) = 0; - virtual Tensor pad( - const Tensor& input, - const std::vector>& padWidths, - const PadType type) = 0; - - /************************** Unary Operators ***************************/ - virtual Tensor exp(const Tensor& tensor) = 0; - virtual Tensor log(const Tensor& 
tensor) = 0; - virtual Tensor negative(const Tensor& tensor) = 0; - virtual Tensor logicalNot(const Tensor& tensor) = 0; - virtual Tensor log1p(const Tensor& tensor) = 0; - virtual Tensor sin(const Tensor& tensor) = 0; - virtual Tensor cos(const Tensor& tensor) = 0; - virtual Tensor sqrt(const Tensor& tensor) = 0; - virtual Tensor tanh(const Tensor& tensor) = 0; - virtual Tensor floor(const Tensor& tensor) = 0; - virtual Tensor ceil(const Tensor& tensor) = 0; - virtual Tensor rint(const Tensor& tensor) = 0; - virtual Tensor absolute(const Tensor& tensor) = 0; - virtual Tensor sigmoid(const Tensor& tensor) = 0; - virtual Tensor erf(const Tensor& tensor) = 0; - virtual Tensor flip(const Tensor& tensor, const unsigned dim) = 0; - virtual Tensor - clip(const Tensor& tensor, const Tensor& low, const Tensor& high) = 0; - virtual Tensor - clip(const Tensor& tensor, const Tensor& low, const double& high); - virtual Tensor - clip(const Tensor& tensor, const double& low, const Tensor& high); - virtual Tensor - clip(const Tensor& tensor, const double& low, const double& high); - virtual Tensor - roll(const Tensor& tensor, const int shift, const unsigned axis) = 0; - virtual Tensor isnan(const Tensor& tensor) = 0; - virtual Tensor isinf(const Tensor& tensor) = 0; - virtual Tensor sign(const Tensor& tensor) = 0; - virtual Tensor tril(const Tensor& tensor) = 0; - virtual Tensor triu(const Tensor& tensor) = 0; - virtual Tensor - where(const Tensor& condition, const Tensor& x, const Tensor& y) = 0; - virtual Tensor - where(const Tensor& condition, const Tensor& x, const double& y); - virtual Tensor - where(const Tensor& condition, const double& x, const Tensor& y); - virtual void topk( - Tensor& values, - Tensor& indices, - const Tensor& input, - const unsigned k, - const Dim axis, - const SortMode sortMode) = 0; - virtual Tensor - sort(const Tensor& input, const Dim axis, const SortMode sortMode) = 0; - virtual void sort( - Tensor& values, - Tensor& indices, - const Tensor& 
input, - const Dim axis, - const SortMode sortMode) = 0; - virtual Tensor - argsort(const Tensor& input, const Dim axis, const SortMode sortMode) = 0; - - /************************** Binary Operators ***************************/ -#define FL_BINARY_OP_TYPE_DECL(FUNC, TYPE) \ - virtual Tensor FUNC(const Tensor& a, TYPE rhs) = 0; \ - virtual Tensor FUNC(TYPE lhs, const Tensor& a) = 0; - -#define FL_BINARY_OP_LITERALS_DECL(FUNC) \ - FL_BINARY_OP_TYPE_DECL(FUNC, const bool&); \ - FL_BINARY_OP_TYPE_DECL(FUNC, const int&); \ - FL_BINARY_OP_TYPE_DECL(FUNC, const unsigned&); \ - FL_BINARY_OP_TYPE_DECL(FUNC, const char&); \ - FL_BINARY_OP_TYPE_DECL(FUNC, const unsigned char&); \ - FL_BINARY_OP_TYPE_DECL(FUNC, const long&); \ - FL_BINARY_OP_TYPE_DECL(FUNC, const unsigned long&); \ - FL_BINARY_OP_TYPE_DECL(FUNC, const long long&); \ - FL_BINARY_OP_TYPE_DECL(FUNC, const unsigned long long&); \ - FL_BINARY_OP_TYPE_DECL(FUNC, const double&); \ - FL_BINARY_OP_TYPE_DECL(FUNC, const float&); \ - FL_BINARY_OP_TYPE_DECL(FUNC, const short&); \ - FL_BINARY_OP_TYPE_DECL(FUNC, const unsigned short&); - -#define FL_BINARY_OP_DECL(FUNC) \ - virtual Tensor FUNC(const Tensor& lhs, const Tensor& rhs) = 0; \ - FL_BINARY_OP_LITERALS_DECL(FUNC); - - FL_BINARY_OP_DECL(add); - FL_BINARY_OP_DECL(sub); - FL_BINARY_OP_DECL(mul); - FL_BINARY_OP_DECL(div); - FL_BINARY_OP_DECL(eq); - FL_BINARY_OP_DECL(neq); - FL_BINARY_OP_DECL(lessThan); - FL_BINARY_OP_DECL(lessThanEqual); - FL_BINARY_OP_DECL(greaterThan); - FL_BINARY_OP_DECL(greaterThanEqual); - FL_BINARY_OP_DECL(logicalOr); - FL_BINARY_OP_DECL(logicalAnd); - FL_BINARY_OP_DECL(mod); - FL_BINARY_OP_DECL(bitwiseAnd); - FL_BINARY_OP_DECL(bitwiseOr); - FL_BINARY_OP_DECL(bitwiseXor); - FL_BINARY_OP_DECL(lShift); - FL_BINARY_OP_DECL(rShift); + virtual Tensor identity(const Dim dim, const dtype type) = 0; + virtual Tensor arange(const Shape& shape, const Dim seqDim, const dtype type) = 0; + virtual Tensor iota(const Shape& dims, const Shape& tileDims, const 
dtype type) = 0; + + /************************ Shaping and Indexing *************************/ + virtual Tensor reshape(const Tensor& tensor, const Shape& shape) = 0; + virtual Tensor transpose( + const Tensor& tensor, + const Shape& axes /* = {} */ + ) = 0; + virtual Tensor tile(const Tensor& tensor, const Shape& shape) = 0; + virtual Tensor concatenate( + const std::vector& tensors, + const unsigned axis + ) = 0; + virtual Tensor nonzero(const Tensor& tensor) = 0; + virtual Tensor pad( + const Tensor& input, + const std::vector>& padWidths, + const PadType type + ) = 0; + + /************************** Unary Operators ***************************/ + virtual Tensor exp(const Tensor& tensor) = 0; + virtual Tensor log(const Tensor& tensor) = 0; + virtual Tensor negative(const Tensor& tensor) = 0; + virtual Tensor logicalNot(const Tensor& tensor) = 0; + virtual Tensor log1p(const Tensor& tensor) = 0; + virtual Tensor sin(const Tensor& tensor) = 0; + virtual Tensor cos(const Tensor& tensor) = 0; + virtual Tensor sqrt(const Tensor& tensor) = 0; + virtual Tensor tanh(const Tensor& tensor) = 0; + virtual Tensor floor(const Tensor& tensor) = 0; + virtual Tensor ceil(const Tensor& tensor) = 0; + virtual Tensor rint(const Tensor& tensor) = 0; + virtual Tensor absolute(const Tensor& tensor) = 0; + virtual Tensor sigmoid(const Tensor& tensor) = 0; + virtual Tensor erf(const Tensor& tensor) = 0; + virtual Tensor flip(const Tensor& tensor, const unsigned dim) = 0; + virtual Tensor clip(const Tensor& tensor, const Tensor& low, const Tensor& high) = 0; + virtual Tensor clip(const Tensor& tensor, const Tensor& low, const double& high); + virtual Tensor clip(const Tensor& tensor, const double& low, const Tensor& high); + virtual Tensor clip(const Tensor& tensor, const double& low, const double& high); + virtual Tensor roll(const Tensor& tensor, const int shift, const unsigned axis) = 0; + virtual Tensor isnan(const Tensor& tensor) = 0; + virtual Tensor isinf(const Tensor& tensor) = 
0; + virtual Tensor sign(const Tensor& tensor) = 0; + virtual Tensor tril(const Tensor& tensor) = 0; + virtual Tensor triu(const Tensor& tensor) = 0; + virtual Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y) = 0; + virtual Tensor where(const Tensor& condition, const Tensor& x, const double& y); + virtual Tensor where(const Tensor& condition, const double& x, const Tensor& y); + virtual void topk( + Tensor& values, + Tensor& indices, + const Tensor& input, + const unsigned k, + const Dim axis, + const SortMode sortMode + ) = 0; + virtual Tensor sort(const Tensor& input, const Dim axis, const SortMode sortMode) = 0; + virtual void sort( + Tensor& values, + Tensor& indices, + const Tensor& input, + const Dim axis, + const SortMode sortMode + ) = 0; + virtual Tensor argsort(const Tensor& input, const Dim axis, const SortMode sortMode) = 0; + + /************************** Binary Operators ***************************/ +#define FL_BINARY_OP_TYPE_DECL(FUNC, TYPE) \ + virtual Tensor FUNC(const Tensor& a, TYPE rhs) = 0; \ + virtual Tensor FUNC(TYPE lhs, const Tensor& a) = 0; + +#define FL_BINARY_OP_LITERALS_DECL(FUNC) \ + FL_BINARY_OP_TYPE_DECL(FUNC, const bool&); \ + FL_BINARY_OP_TYPE_DECL(FUNC, const int&); \ + FL_BINARY_OP_TYPE_DECL(FUNC, const unsigned&); \ + FL_BINARY_OP_TYPE_DECL(FUNC, const char&); \ + FL_BINARY_OP_TYPE_DECL(FUNC, const unsigned char&); \ + FL_BINARY_OP_TYPE_DECL(FUNC, const long&); \ + FL_BINARY_OP_TYPE_DECL(FUNC, const unsigned long&); \ + FL_BINARY_OP_TYPE_DECL(FUNC, const long long&); \ + FL_BINARY_OP_TYPE_DECL(FUNC, const unsigned long long&); \ + FL_BINARY_OP_TYPE_DECL(FUNC, const double&); \ + FL_BINARY_OP_TYPE_DECL(FUNC, const float&); \ + FL_BINARY_OP_TYPE_DECL(FUNC, const short&); \ + FL_BINARY_OP_TYPE_DECL(FUNC, const unsigned short&); + +#define FL_BINARY_OP_DECL(FUNC) \ + virtual Tensor FUNC(const Tensor& lhs, const Tensor& rhs) = 0; \ + FL_BINARY_OP_LITERALS_DECL(FUNC); + + FL_BINARY_OP_DECL(add); + 
FL_BINARY_OP_DECL(sub); + FL_BINARY_OP_DECL(mul); + FL_BINARY_OP_DECL(div); + FL_BINARY_OP_DECL(eq); + FL_BINARY_OP_DECL(neq); + FL_BINARY_OP_DECL(lessThan); + FL_BINARY_OP_DECL(lessThanEqual); + FL_BINARY_OP_DECL(greaterThan); + FL_BINARY_OP_DECL(greaterThanEqual); + FL_BINARY_OP_DECL(logicalOr); + FL_BINARY_OP_DECL(logicalAnd); + FL_BINARY_OP_DECL(mod); + FL_BINARY_OP_DECL(bitwiseAnd); + FL_BINARY_OP_DECL(bitwiseOr); + FL_BINARY_OP_DECL(bitwiseXor); + FL_BINARY_OP_DECL(lShift); + FL_BINARY_OP_DECL(rShift); #undef FL_BINARY_OP_DECL #undef FL_BINARY_OP_TYPE_DECL #undef FL_BINARY_OP_LITERALS_DECL - virtual Tensor minimum(const Tensor& lhs, const Tensor& rhs) = 0; - virtual Tensor minimum(const Tensor& lhs, const double& rhs); - virtual Tensor minimum(const double& lhs, const Tensor& rhs); - virtual Tensor maximum(const Tensor& lhs, const Tensor& rhs) = 0; - virtual Tensor maximum(const Tensor& lhs, const double& rhs); - virtual Tensor maximum(const double& lhs, const Tensor& rhs); - virtual Tensor power(const Tensor& lhs, const Tensor& rhs) = 0; - virtual Tensor power(const Tensor& lhs, const double& rhs); - virtual Tensor power(const double& lhs, const Tensor& rhs); - - /******************************* BLAS ********************************/ - virtual Tensor matmul( - const Tensor& lhs, - const Tensor& rhs, - MatrixProperty lhsProp, - MatrixProperty rhsProp) = 0; - - /************************** Reductions ***************************/ - virtual Tensor - amin(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; - virtual Tensor - amax(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; - virtual void min( - Tensor& values, - Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims) = 0; - virtual void max( - Tensor& values, - Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims) = 0; - virtual Tensor - sum(const Tensor& input, const std::vector& axes, const bool 
keepDims) = 0; - virtual Tensor cumsum(const Tensor& input, const unsigned axis) = 0; - virtual Tensor - argmax(const Tensor& input, const unsigned axis, const bool keepDims) = 0; - virtual Tensor - argmin(const Tensor& input, const unsigned axis, const bool keepDims) = 0; - virtual Tensor - mean(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; - virtual Tensor - median(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; - virtual Tensor var( - const Tensor& input, - const std::vector& axes, - bool bias, - const bool keepDims) = 0; - virtual Tensor - std(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; - virtual Tensor norm( - const Tensor& input, - const std::vector& axes, - double p, - const bool keepDims) = 0; - virtual Tensor countNonzero( - const Tensor& input, - const std::vector& axes, - const bool keepDims) = 0; - virtual Tensor - any(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; - virtual Tensor - all(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; - - /************************** Utils ***************************/ - virtual void print(const Tensor& tensor) = 0; - - /** - * Checks if a datatype is supported by a TensorBackend and its registered - * extensions. - * - * @param[in] dtype the datatype to check - * - * @return true if the data type is supported, false otherwise - */ - virtual bool isDataTypeSupported(const fl::dtype& dtype) const final; - - /********************* Tensor Extensions **********************/ - template - T& getExtension() { - static_assert( - std::is_base_of::value, - "TensorBackend::getExtension() called with type T " - "that is not derived from TensorExtensionBase."); - - TensorExtensionType e = T::getExtensionType(); - - // If an extension isn't present, instantiate it via its registered - // creation function - only do this once per extension. 
- if (extensions_.find(e) == extensions_.end()) { - auto& creationFunc = - detail::TensorExtensionRegistrar::getInstance() - .getTensorExtensionCreationFunc(this->backendType(), e); - extensions_.emplace(e, creationFunc()); + virtual Tensor minimum(const Tensor& lhs, const Tensor& rhs) = 0; + virtual Tensor minimum(const Tensor& lhs, const double& rhs); + virtual Tensor minimum(const double& lhs, const Tensor& rhs); + virtual Tensor maximum(const Tensor& lhs, const Tensor& rhs) = 0; + virtual Tensor maximum(const Tensor& lhs, const double& rhs); + virtual Tensor maximum(const double& lhs, const Tensor& rhs); + virtual Tensor power(const Tensor& lhs, const Tensor& rhs) = 0; + virtual Tensor power(const Tensor& lhs, const double& rhs); + virtual Tensor power(const double& lhs, const Tensor& rhs); + + /******************************* BLAS ********************************/ + virtual Tensor matmul( + const Tensor& lhs, + const Tensor& rhs, + MatrixProperty lhsProp, + MatrixProperty rhsProp + ) = 0; + + /************************** Reductions ***************************/ + virtual Tensor amin(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; + virtual Tensor amax(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; + virtual void min( + Tensor& values, + Tensor& indices, + const Tensor& input, + const unsigned axis, + const bool keepDims + ) = 0; + virtual void max( + Tensor& values, + Tensor& indices, + const Tensor& input, + const unsigned axis, + const bool keepDims + ) = 0; + virtual Tensor sum(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; + virtual Tensor cumsum(const Tensor& input, const unsigned axis) = 0; + virtual Tensor argmax(const Tensor& input, const unsigned axis, const bool keepDims) = 0; + virtual Tensor argmin(const Tensor& input, const unsigned axis, const bool keepDims) = 0; + virtual Tensor mean(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; + virtual Tensor 
median(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; + virtual Tensor var( + const Tensor& input, + const std::vector& axes, + bool bias, + const bool keepDims + ) = 0; + virtual Tensor std(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; + virtual Tensor norm( + const Tensor& input, + const std::vector& axes, + double p, + const bool keepDims + ) = 0; + virtual Tensor countNonzero( + const Tensor& input, + const std::vector& axes, + const bool keepDims + ) = 0; + virtual Tensor any(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; + virtual Tensor all(const Tensor& input, const std::vector& axes, const bool keepDims) = 0; + + /************************** Utils ***************************/ + virtual void print(const Tensor& tensor) = 0; + + /** + * Checks if a datatype is supported by a TensorBackend and its registered + * extensions. + * + * @param[in] dtype the datatype to check + * + * @return true if the data type is supported, false otherwise + */ + virtual bool isDataTypeSupported(const fl::dtype& dtype) const final; + + /********************* Tensor Extensions **********************/ + template + T& getExtension() { + static_assert( + std::is_base_of::value, + "TensorBackend::getExtension() called with type T " + "that is not derived from TensorExtensionBase." + ); + + TensorExtensionType e = T::getExtensionType(); + + // If an extension isn't present, instantiate it via its registered + // creation function - only do this once per extension. 
+ if(extensions_.find(e) == extensions_.end()) { + auto& creationFunc = + detail::TensorExtensionRegistrar::getInstance() + .getTensorExtensionCreationFunc(this->backendType(), e); + extensions_.emplace(e, creationFunc()); + } + return *(static_cast(extensions_.at(e).get())); } - return *(static_cast(extensions_.at(e).get())); - } - protected: - std::unordered_map> - extensions_; +protected: + std::unordered_map> + extensions_; }; /** @@ -312,26 +301,30 @@ class TensorBackend { * @param[in] in a tensor rvalue reference * @return a tensor with backend type specified by the template */ -template +template Tensor toTensorType(Tensor&& in) { - static_assert( - std::is_base_of::value, - "toTensorType: T must be a derived type of TensorAdapterBase"); - // Fast path - backend is the same - // TODO: make fl::TensorBackendType a static constexpr on the class as well so - // as to not need to instantiate a backend to check the type - if (in.backendType() == T().backendType()) { - return std::move(in); - } - - // As per impl requirements, Tensor::device() should return a pointer to host - // memory if the tensor resides on the host. - return Tensor(std::make_unique( - in.shape(), - in.type(), - // TODO: use the void specialization instead of a reinterpret cast - reinterpret_cast(in.device()), // expects contiguous memory - in.location())); + static_assert( + std::is_base_of::value, + "toTensorType: T must be a derived type of TensorAdapterBase" + ); + // Fast path - backend is the same + // TODO: make fl::TensorBackendType a static constexpr on the class as well so + // as to not need to instantiate a backend to check the type + if(in.backendType() == T().backendType()) { + return std::move(in); + } + + // As per impl requirements, Tensor::device() should return a pointer to host + // memory if the tensor resides on the host. 
+ return Tensor( + std::make_unique( + in.shape(), + in.type(), + // TODO: use the void specialization instead of a reinterpret cast + reinterpret_cast(in.device()), // expects contiguous memory + in.location() + ) + ); } namespace detail { @@ -341,25 +334,25 @@ namespace detail { * * @return true if the backends of both tensors are the same, else false. */ -bool areBackendsEqual(const Tensor& a, const Tensor& b); + bool areBackendsEqual(const Tensor& a, const Tensor& b); /** * Compare the backends of multiple tensors. * * @return true if all tensors' backends are the same, false otherwise. */ -template -bool areBackendsEqual(const Tensor& a, const Tensor& b, const Args&... args) { - return areBackendsEqual(a, b) && areBackendsEqual(a, args...) && - areBackendsEqual(b, args...); -} + template + bool areBackendsEqual(const Tensor& a, const Tensor& b, const Args&... args) { + return areBackendsEqual(a, b) && areBackendsEqual(a, args...) + && areBackendsEqual(b, args...); + } /** * * @return a reference to a tensor backend instance descripting the default - backend. + backend. */ -TensorBackend& getDefaultBackend(); + TensorBackend& getDefaultBackend(); } // namespace detail } // namespace fl diff --git a/flashlight/fl/tensor/TensorBase.cpp b/flashlight/fl/tensor/TensorBase.cpp index 61d52cd..fcc4f0b 100644 --- a/flashlight/fl/tensor/TensorBase.cpp +++ b/flashlight/fl/tensor/TensorBase.cpp @@ -15,20 +15,20 @@ #include "flashlight/fl/tensor/TensorAdapter.h" #include "flashlight/fl/tensor/TensorBackend.h" -#define FL_TENSOR_BACKENDS_MATCH_CHECK(...) \ - if (!detail::areBackendsEqual(__VA_ARGS__)) { \ - throw std::invalid_argument( \ - std::string(__func__) + \ - " called with tensors of different backends."); \ - } +#define FL_TENSOR_BACKENDS_MATCH_CHECK(...) \ + if(!detail::areBackendsEqual(__VA_ARGS__)) { \ + throw std::invalid_argument( \ + std::string(__func__) + \ + " called with tensors of different backends." 
\ + ); \ + } namespace fl { -Tensor::Tensor(std::unique_ptr adapter) - : impl_(std::move(adapter)) {} +Tensor::Tensor(std::unique_ptr adapter) : impl_(std::move(adapter)) {} std::unique_ptr Tensor::releaseAdapter() { - return std::move(impl_); + return std::move(impl_); } Tensor::~Tensor() = default; @@ -43,8 +43,8 @@ Tensor::Tensor( const Shape& shape, fl::dtype type, const void* ptr, - MemoryLocation memoryLocation) - : impl_(detail::getDefaultAdapter(shape, type, ptr, memoryLocation)) {} + MemoryLocation memoryLocation +) : impl_(detail::getDefaultAdapter(shape, type, ptr, memoryLocation)) {} Tensor::Tensor( const Dim nRows, @@ -52,149 +52,149 @@ Tensor::Tensor( const Tensor& values, const Tensor& rowIdx, const Tensor& colIdx, - StorageType storageType) - : impl_(detail::getDefaultAdapter( - nRows, - nCols, - values, - rowIdx, - colIdx, - storageType)) {} + StorageType storageType +) : impl_(detail::getDefaultAdapter( + nRows, + nCols, + values, + rowIdx, + colIdx, + storageType)) {} -Tensor::Tensor(const Shape& shape, fl::dtype type /* = fl::dtype::f32 */) - : impl_(detail::getDefaultAdapter(shape, type)) {} +Tensor::Tensor( + const Shape& shape, + fl::dtype type /* = fl::dtype::f32 */ +) : impl_(detail::getDefaultAdapter(shape, + type)) {} -Tensor::Tensor(fl::dtype type) - : impl_(detail::getDefaultAdapter(Shape({0}), type)) {} +Tensor::Tensor(fl::dtype type) : impl_(detail::getDefaultAdapter(Shape({ 0 }), type)) {} Tensor Tensor::copy() const { - return impl_->copy(); + return impl_->copy(); } Tensor Tensor::shallowCopy() const { - return impl_->shallowCopy(); + return impl_->shallowCopy(); } const Shape& Tensor::shape() const { - return impl_->shape(); + return impl_->shape(); } Location Tensor::location() const { - return impl_->location(); + return impl_->location(); } size_t Tensor::elements() const { - return impl_->shape().elements(); + return impl_->shape().elements(); } Dim Tensor::dim(const size_t dim) const { - return shape().dim(dim); + return 
shape().dim(dim); } int Tensor::ndim() const { - return shape().ndim(); + return shape().ndim(); } bool Tensor::isEmpty() const { - return elements() == 0; + return elements() == 0; } bool Tensor::hasAdapter() const { - return impl_.get() != nullptr; + return impl_.get() != nullptr; } size_t Tensor::bytes() const { - return elements() * getTypeSize(type()); + return elements() * getTypeSize(type()); } dtype Tensor::type() const { - return impl_->type(); + return impl_->type(); } bool Tensor::isSparse() const { - return impl_->isSparse(); + return impl_->isSparse(); } Tensor Tensor::astype(const dtype type) const { - return impl_->astype(type); + return impl_->astype(type); } Tensor Tensor::operator()(const std::vector& indices) const { - return impl_->index(indices); + return impl_->index(indices); } Tensor Tensor::flatten() const { - return impl_->flatten(); + return impl_->flatten(); } Tensor Tensor::flat(const Index& idx) const { - return impl_->flat(idx); + return impl_->flat(idx); } Tensor Tensor::asContiguousTensor() const { - return impl_->asContiguousTensor(); + return impl_->asContiguousTensor(); } TensorBackendType Tensor::backendType() const { - return impl_->backendType(); + return impl_->backendType(); } TensorBackend& Tensor::backend() const { - return impl_->backend(); -} - -#define FL_CREATE_MEMORY_OPS(TYPE) \ - template <> \ - FL_API TYPE Tensor::scalar() const { \ - if (isEmpty()) { \ - throw std::invalid_argument("Tensor::scalar called on empty tensor"); \ - } \ - if (type() != dtype_traits::fl_type) { \ - throw std::invalid_argument( \ - "Tensor::scalar: requested type of " + \ - std::string(dtype_traits::getName()) + \ - " doesn't match tensor type, which is " + dtypeToString(type())); \ - } \ - TYPE out; \ - impl_->scalar(&out); \ - return out; \ - } \ - \ - template <> \ - FL_API TYPE* Tensor::device() const { \ - if (isEmpty()) { \ - return nullptr; \ - } \ - TYPE* out; \ - void** addr = reinterpret_cast(&out); \ - impl_->device(addr); \ - 
return out; \ - } \ - \ - template <> \ - FL_API void Tensor::device(TYPE** ptr) const { \ - if (isEmpty()) { \ - return; \ - } \ - impl_->device(reinterpret_cast(ptr)); \ - } \ - \ - template <> \ - FL_API TYPE* Tensor::host() const { \ - if (isEmpty()) { \ - return nullptr; \ - } \ - TYPE* out = reinterpret_cast(new char[bytes()]); \ - impl_->host(out); \ - return out; \ - } \ - \ - template <> \ - FL_API void Tensor::host(TYPE* ptr) const { \ - if (!isEmpty()) { \ - impl_->host(ptr); \ - } \ - } + return impl_->backend(); +} + +#define FL_CREATE_MEMORY_OPS(TYPE) \ + template<> FL_API TYPE Tensor::scalar() const { \ + if(isEmpty()) { \ + throw std::invalid_argument("Tensor::scalar called on empty tensor"); \ + } \ + if(type() != dtype_traits::fl_type) { \ + throw std::invalid_argument( \ + "Tensor::scalar: requested type of " + \ + std::string(dtype_traits::getName()) + \ + " doesn't match tensor type, which is " + dtypeToString(type()) \ + ); \ + } \ + TYPE out; \ + impl_->scalar(&out); \ + return out; \ + } \ + \ + template<> FL_API TYPE * Tensor::device() const { \ + if(isEmpty()) { \ + return nullptr; \ + } \ + TYPE* out; \ + void** addr = reinterpret_cast(&out); \ + impl_->device(addr); \ + return out; \ + } \ + \ + template<> \ + FL_API void Tensor::device(TYPE * *ptr) const { \ + if(isEmpty()) { \ + return; \ + } \ + impl_->device(reinterpret_cast(ptr)); \ + } \ + \ + template<> FL_API TYPE * Tensor::host() const { \ + if(isEmpty()) { \ + return nullptr; \ + } \ + TYPE* out = reinterpret_cast(new char[bytes()]); \ + impl_->host(out); \ + return out; \ + } \ + \ + template<> \ + FL_API void Tensor::host(TYPE * ptr) const { \ + if(!isEmpty()) { \ + impl_->host(ptr); \ + } \ + } FL_CREATE_MEMORY_OPS(int); FL_CREATE_MEMORY_OPS(unsigned); FL_CREATE_MEMORY_OPS(char); @@ -208,101 +208,101 @@ FL_CREATE_MEMORY_OPS(float); FL_CREATE_MEMORY_OPS(short); FL_CREATE_MEMORY_OPS(unsigned short); // void specializations -template <> +template<> FL_API void* Tensor::device() 
const { - if (isEmpty()) { - return nullptr; - } - void* out; - impl_->device(&out); - return out; + if(isEmpty()) { + return nullptr; + } + void* out; + impl_->device(&out); + return out; } -template <> +template<> FL_API void Tensor::device(void** ptr) const { - if (isEmpty()) { - return; - } - impl_->device(ptr); + if(isEmpty()) { + return; + } + impl_->device(ptr); } -template <> +template<> FL_API void* Tensor::host() const { - if (isEmpty()) { - return nullptr; - } - void* out = reinterpret_cast(new char[bytes()]); - impl_->host(out); - return out; + if(isEmpty()) { + return nullptr; + } + void* out = reinterpret_cast(new char[bytes()]); + impl_->host(out); + return out; } -template <> +template<> FL_API void Tensor::host(void* ptr) const { - impl_->host(ptr); + impl_->host(ptr); } #undef FL_CREATE_MEMORY_OPS void Tensor::unlock() const { - impl_->unlock(); + impl_->unlock(); } bool Tensor::isLocked() const { - return impl_->isLocked(); + return impl_->isLocked(); } bool Tensor::isContiguous() const { - return impl_->isContiguous(); + return impl_->isContiguous(); } Shape Tensor::strides() const { - return impl_->strides(); + return impl_->strides(); } const Stream& Tensor::stream() const { - return impl_->stream(); + return impl_->stream(); } void Tensor::setContext(void* context) { - impl_->setContext(context); + impl_->setContext(context); } void* Tensor::getContext() const { - return impl_->getContext(); + return impl_->getContext(); } std::string Tensor::toString() const { - return impl_->toString(); + return impl_->toString(); } std::ostream& Tensor::operator<<(std::ostream& ostr) const { - return impl_->operator<<(ostr); + return impl_->operator<<(ostr); } /******************** Assignment Operators ********************/ #define FL_ASSIGN_OP_TYPE(OP, FUN, TYPE) \ - Tensor& Tensor::OP(TYPE val) { \ - impl_->FUN(val); \ - return *this; \ - } + Tensor & Tensor::OP(TYPE val) { \ + impl_->FUN(val); \ + return *this; \ + } #define FL_ASSIGN_TENSOR_OP(OP, FUN) 
FL_ASSIGN_OP_TYPE(OP, FUN, const Tensor&); -#define FL_ASSIGN_SCALAR_OP(OP, FUN) \ - FL_ASSIGN_OP_TYPE(OP, FUN, const double&); \ - FL_ASSIGN_OP_TYPE(OP, FUN, const float&); \ - FL_ASSIGN_OP_TYPE(OP, FUN, const int&); \ - FL_ASSIGN_OP_TYPE(OP, FUN, const unsigned&); \ - FL_ASSIGN_OP_TYPE(OP, FUN, const bool&); \ - FL_ASSIGN_OP_TYPE(OP, FUN, const char&); \ - FL_ASSIGN_OP_TYPE(OP, FUN, const unsigned char&); \ - FL_ASSIGN_OP_TYPE(OP, FUN, const short&); \ - FL_ASSIGN_OP_TYPE(OP, FUN, const unsigned short&); \ - FL_ASSIGN_OP_TYPE(OP, FUN, const long&); \ - FL_ASSIGN_OP_TYPE(OP, FUN, const unsigned long&); \ - FL_ASSIGN_OP_TYPE(OP, FUN, const long long&); \ - FL_ASSIGN_OP_TYPE(OP, FUN, const unsigned long long&); - -#define FL_ASSIGN_OP(OP, FUN) \ - FL_ASSIGN_TENSOR_OP(OP, FUN); \ - FL_ASSIGN_SCALAR_OP(OP, FUN); +#define FL_ASSIGN_SCALAR_OP(OP, FUN) \ + FL_ASSIGN_OP_TYPE(OP, FUN, const double&); \ + FL_ASSIGN_OP_TYPE(OP, FUN, const float&); \ + FL_ASSIGN_OP_TYPE(OP, FUN, const int&); \ + FL_ASSIGN_OP_TYPE(OP, FUN, const unsigned&); \ + FL_ASSIGN_OP_TYPE(OP, FUN, const bool&); \ + FL_ASSIGN_OP_TYPE(OP, FUN, const char&); \ + FL_ASSIGN_OP_TYPE(OP, FUN, const unsigned char&); \ + FL_ASSIGN_OP_TYPE(OP, FUN, const short&); \ + FL_ASSIGN_OP_TYPE(OP, FUN, const unsigned short&); \ + FL_ASSIGN_OP_TYPE(OP, FUN, const long&); \ + FL_ASSIGN_OP_TYPE(OP, FUN, const unsigned long&); \ + FL_ASSIGN_OP_TYPE(OP, FUN, const long long&); \ + FL_ASSIGN_OP_TYPE(OP, FUN, const unsigned long long&); + +#define FL_ASSIGN_OP(OP, FUN) \ + FL_ASSIGN_TENSOR_OP(OP, FUN); \ + FL_ASSIGN_SCALAR_OP(OP, FUN); // (operator, function name on impl) FL_ASSIGN_SCALAR_OP(operator=, assign); @@ -318,43 +318,41 @@ FL_ASSIGN_OP(operator/=, inPlaceDivide); // Move assignment operator when `this` is a lvalue, e.g., `x = std::move(y)`. // In such cases, we let `this` take over the tensor data of `other`. 
Tensor& Tensor::operator=(Tensor&& other) & { - this->impl_ = std::move(other.impl_); - return *this; + this->impl_ = std::move(other.impl_); + return *this; } // Move assignment operator when `this` is a rvalue, e.g., `x(0) = // std::move(y)`. In such cases, we copy the data from `other` to `this`. Tensor& Tensor::operator=(Tensor&& other) && { - this->impl_->assign(other); - return *this; + this->impl_->assign(other); + return *this; } // Copy assignment operator when `this` is a lvalue, e.g., `x = y`. // In such cases, we let `this` take over the _cloned_ data from `other`. Tensor& Tensor::operator=(const Tensor& other) & { - this->impl_ = other.impl_->clone(); - return *this; + this->impl_ = other.impl_->clone(); + return *this; } // Copy assignment operator when `this` is a lvalue, e.g., `x(0) = y`. // In such cases, we copy the data from `other` to `this`. Tensor& Tensor::operator=(const Tensor& other) && { - this->impl_->assign(other); - return *this; + this->impl_->assign(other); + return *this; } /* --------------------------- Tensor Operators --------------------------- */ /******************** Tensor Creation Functions ********************/ -#define FL_CREATE_FUN_LITERAL_TYPE(TYPE) \ - template <> \ - FL_API Tensor fromScalar(TYPE value, const dtype type) { \ - return defaultTensorBackend().fromScalar(value, type); \ - } \ - template <> \ - FL_API Tensor full(const Shape& dims, TYPE value, const dtype type) { \ - return defaultTensorBackend().full(dims, value, type); \ - } +#define FL_CREATE_FUN_LITERAL_TYPE(TYPE) \ + template<> FL_API Tensor fromScalar(TYPE value, const dtype type) { \ + return defaultTensorBackend().fromScalar(value, type); \ + } \ + template<> FL_API Tensor full(const Shape& dims, TYPE value, const dtype type) { \ + return defaultTensorBackend().full(dims, value, type); \ + } FL_CREATE_FUN_LITERAL_TYPE(const double&); FL_CREATE_FUN_LITERAL_TYPE(const float&); FL_CREATE_FUN_LITERAL_TYPE(const int&); @@ -371,16 +369,15 @@ 
FL_CREATE_FUN_LITERAL_TYPE(const unsigned short&); #undef FL_CREATE_FUN_LITERAL_TYPE Tensor identity(const Dim dim, const dtype type) { - return defaultTensorBackend().identity(dim, type); + return defaultTensorBackend().identity(dim, type); } -#define FL_ARANGE_FUN_DEF(TYPE) \ - template <> \ - FL_API Tensor arange(TYPE start, TYPE end, TYPE step, const dtype type) { \ - return fl::arange({static_cast((end - start) / step)}, 0, type) * \ - step + \ - start; \ - } +#define FL_ARANGE_FUN_DEF(TYPE) \ + template<> FL_API Tensor arange(TYPE start, TYPE end, TYPE step, const dtype type) { \ + return fl::arange({static_cast((end - start) / step)}, 0, type) * \ + step + \ + start; \ + } FL_ARANGE_FUN_DEF(const double&); FL_ARANGE_FUN_DEF(const float&); FL_ARANGE_FUN_DEF(const int&); @@ -391,178 +388,184 @@ FL_ARANGE_FUN_DEF(const long long&); FL_ARANGE_FUN_DEF(const unsigned long long&); Tensor arange(const Shape& shape, const Dim seqDim, const dtype type) { - return defaultTensorBackend().arange(shape, seqDim, type); + return defaultTensorBackend().arange(shape, seqDim, type); } Tensor iota(const Shape& dims, const Shape& tileDims, const dtype type) { - return defaultTensorBackend().iota(dims, tileDims, type); + return defaultTensorBackend().iota(dims, tileDims, type); } /************************ Shaping and Indexing *************************/ Tensor reshape(const Tensor& tensor, const Shape& shape) { - return tensor.backend().reshape(tensor, shape); + return tensor.backend().reshape(tensor, shape); } Tensor transpose(const Tensor& tensor, const Shape& axes /* = {} */) { - return tensor.backend().transpose(tensor, axes); + return tensor.backend().transpose(tensor, axes); } Tensor tile(const Tensor& tensor, const Shape& shape) { - return tensor.backend().tile(tensor, shape); + return tensor.backend().tile(tensor, shape); } Tensor concatenate(const std::vector& tensors, const unsigned axis) { - if (tensors.empty()) { - throw std::invalid_argument("concatenate: called on 
empty set of tensors"); - } - - // Check all backends match - const TensorBackendType b = tensors.front().backendType(); - const bool matches = - std::all_of(tensors.begin(), tensors.end(), [b](const Tensor& t) { - return t.backendType() == b; - }); - if (!matches) { - throw std::invalid_argument( - "concatenate: tried to concatenate tensors of different backends"); - } - - return tensors.front().backend().concatenate(tensors, axis); + if(tensors.empty()) { + throw std::invalid_argument("concatenate: called on empty set of tensors"); + } + + // Check all backends match + const TensorBackendType b = tensors.front().backendType(); + const bool matches = + std::all_of( + tensors.begin(), + tensors.end(), + [b](const Tensor& t) { + return t.backendType() == b; + } + ); + if(!matches) { + throw std::invalid_argument( + "concatenate: tried to concatenate tensors of different backends" + ); + } + + return tensors.front().backend().concatenate(tensors, axis); } Tensor nonzero(const Tensor& tensor) { - return tensor.backend().nonzero(tensor); + return tensor.backend().nonzero(tensor); } Tensor pad( const Tensor& input, const std::vector>& padWidths, - const PadType type) { - return input.backend().pad(input, padWidths, type); + const PadType type +) { + return input.backend().pad(input, padWidths, type); } /************************** Unary Operators ***************************/ Tensor exp(const Tensor& tensor) { - return tensor.backend().exp(tensor); + return tensor.backend().exp(tensor); } Tensor log(const Tensor& tensor) { - return tensor.backend().log(tensor); + return tensor.backend().log(tensor); } Tensor negative(const Tensor& tensor) { - return tensor.backend().negative(tensor); + return tensor.backend().negative(tensor); } Tensor logicalNot(const Tensor& tensor) { - return tensor.backend().logicalNot(tensor); + return tensor.backend().logicalNot(tensor); } Tensor log1p(const Tensor& tensor) { - return tensor.backend().log1p(tensor); + return 
tensor.backend().log1p(tensor); } Tensor sin(const Tensor& tensor) { - return tensor.backend().sin(tensor); + return tensor.backend().sin(tensor); } Tensor cos(const Tensor& tensor) { - return tensor.backend().cos(tensor); + return tensor.backend().cos(tensor); } Tensor sqrt(const Tensor& tensor) { - return tensor.backend().sqrt(tensor); + return tensor.backend().sqrt(tensor); } Tensor tanh(const Tensor& tensor) { - return tensor.backend().tanh(tensor); + return tensor.backend().tanh(tensor); } Tensor floor(const Tensor& tensor) { - return tensor.backend().floor(tensor); + return tensor.backend().floor(tensor); } Tensor ceil(const Tensor& tensor) { - return tensor.backend().ceil(tensor); + return tensor.backend().ceil(tensor); } Tensor rint(const Tensor& tensor) { - return tensor.backend().rint(tensor); + return tensor.backend().rint(tensor); } Tensor absolute(const Tensor& tensor) { - return tensor.backend().absolute(tensor); + return tensor.backend().absolute(tensor); } Tensor sigmoid(const Tensor& tensor) { - return tensor.backend().sigmoid(tensor); + return tensor.backend().sigmoid(tensor); } Tensor erf(const Tensor& tensor) { - return tensor.backend().erf(tensor); + return tensor.backend().erf(tensor); } Tensor flip(const Tensor& tensor, const unsigned dim) { - return tensor.backend().flip(tensor, dim); + return tensor.backend().flip(tensor, dim); } Tensor clip(const Tensor& tensor, const Tensor& low, const Tensor& high) { - FL_TENSOR_BACKENDS_MATCH_CHECK(tensor, low, high); - return tensor.backend().clip(tensor, low, high); + FL_TENSOR_BACKENDS_MATCH_CHECK(tensor, low, high); + return tensor.backend().clip(tensor, low, high); } Tensor clip(const Tensor& tensor, const Tensor& low, const double& high) { - FL_TENSOR_BACKENDS_MATCH_CHECK(tensor, low); - return tensor.backend().clip(tensor, low, high); + FL_TENSOR_BACKENDS_MATCH_CHECK(tensor, low); + return tensor.backend().clip(tensor, low, high); } Tensor clip(const Tensor& tensor, const double& low, const 
Tensor& high) { - FL_TENSOR_BACKENDS_MATCH_CHECK(tensor, high); - return tensor.backend().clip(tensor, low, high); + FL_TENSOR_BACKENDS_MATCH_CHECK(tensor, high); + return tensor.backend().clip(tensor, low, high); } Tensor clip(const Tensor& tensor, const double& low, const double& high) { - return tensor.backend().clip(tensor, low, high); + return tensor.backend().clip(tensor, low, high); } Tensor roll(const Tensor& tensor, const int shift, const unsigned axis) { - return tensor.backend().roll(tensor, shift, axis); + return tensor.backend().roll(tensor, shift, axis); } Tensor isnan(const Tensor& tensor) { - return tensor.backend().isnan(tensor); + return tensor.backend().isnan(tensor); } Tensor isinf(const Tensor& tensor) { - return tensor.backend().isinf(tensor); + return tensor.backend().isinf(tensor); } Tensor sign(const Tensor& tensor) { - return tensor.backend().sign(tensor); + return tensor.backend().sign(tensor); } Tensor tril(const Tensor& tensor) { - return tensor.backend().tril(tensor); + return tensor.backend().tril(tensor); } Tensor triu(const Tensor& tensor) { - return tensor.backend().triu(tensor); + return tensor.backend().triu(tensor); } Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y) { - FL_TENSOR_BACKENDS_MATCH_CHECK(condition, x, y); - return condition.backend().where(condition, x, y); + FL_TENSOR_BACKENDS_MATCH_CHECK(condition, x, y); + return condition.backend().where(condition, x, y); } Tensor where(const Tensor& condition, const Tensor& x, const double& y) { - FL_TENSOR_BACKENDS_MATCH_CHECK(condition, x); - return condition.backend().where(condition, x, y); + FL_TENSOR_BACKENDS_MATCH_CHECK(condition, x); + return condition.backend().where(condition, x, y); } Tensor where(const Tensor& condition, const double& x, const Tensor& y) { - FL_TENSOR_BACKENDS_MATCH_CHECK(condition, y); - return condition.backend().where(condition, x, y); + FL_TENSOR_BACKENDS_MATCH_CHECK(condition, y); + return 
condition.backend().where(condition, x, y); } void topk( @@ -571,13 +574,14 @@ void topk( const Tensor& input, const unsigned k, const Dim axis, - const SortMode sortMode /* = SortMode::Descending */) { - FL_TENSOR_BACKENDS_MATCH_CHECK(values, indices, input); - input.backend().topk(values, indices, input, k, axis, sortMode); + const SortMode sortMode /* = SortMode::Descending */ +) { + FL_TENSOR_BACKENDS_MATCH_CHECK(values, indices, input); + input.backend().topk(values, indices, input, k, axis, sortMode); } Tensor sort(const Tensor& input, const Dim axis, const SortMode sortMode) { - return input.backend().sort(input, axis, sortMode); + return input.backend().sort(input, axis, sortMode); } void sort( @@ -585,53 +589,54 @@ void sort( Tensor& indices, const Tensor& input, const Dim axis, - const SortMode sortMode /* = SortMode::Descending */) { - return values.backend().sort(values, indices, input, axis, sortMode); + const SortMode sortMode /* = SortMode::Descending */ +) { + return values.backend().sort(values, indices, input, axis, sortMode); } Tensor argsort(const Tensor& input, const Dim axis, const SortMode sortMode) { - return input.backend().argsort(input, axis, sortMode); + return input.backend().argsort(input, axis, sortMode); } /************************** Binary Operators ***************************/ -#define FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, TYPE) \ - Tensor FUNC(TYPE lhs, const Tensor& rhs) { \ - return rhs.backend().FUNC(lhs, rhs); \ - } \ - Tensor FUNC(const Tensor& lhs, TYPE rhs) { \ - return lhs.backend().FUNC(lhs, rhs); \ - } \ - Tensor operator OP(TYPE lhs, const Tensor& rhs) { \ - return FUNC(lhs, rhs); \ - } \ - Tensor operator OP(const Tensor& lhs, TYPE rhs) { \ - return FUNC(lhs, rhs); \ - } - -#define FL_BINARY_OP_LITERALS_DEF(OP, FUNC) \ - FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const bool&); \ - FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const int&); \ - FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const unsigned&); \ - 
FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const char&); \ - FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const unsigned char&); \ - FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const long&); \ - FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const unsigned long&); \ - FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const long long&); \ - FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const unsigned long long&); \ - FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const double&); \ - FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const float&); \ - FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const short&); \ - FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const unsigned short&); - -#define FL_BINARY_OP_DEF(OP, FUNC) \ - Tensor FUNC(const Tensor& lhs, const Tensor& rhs) { \ - FL_TENSOR_BACKENDS_MATCH_CHECK(lhs, rhs); \ - return lhs.backend().FUNC(lhs, rhs); \ - } \ - Tensor operator OP(const Tensor& lhs, const Tensor& rhs) { \ - return FUNC(lhs, rhs); \ - } \ - FL_BINARY_OP_LITERALS_DEF(OP, FUNC); +#define FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, TYPE) \ + Tensor FUNC(TYPE lhs, const Tensor& rhs) { \ + return rhs.backend().FUNC(lhs, rhs); \ + } \ + Tensor FUNC(const Tensor& lhs, TYPE rhs) { \ + return lhs.backend().FUNC(lhs, rhs); \ + } \ + Tensor operator OP(TYPE lhs, const Tensor& rhs) { \ + return FUNC(lhs, rhs); \ + } \ + Tensor operator OP(const Tensor& lhs, TYPE rhs) { \ + return FUNC(lhs, rhs); \ + } + +#define FL_BINARY_OP_LITERALS_DEF(OP, FUNC) \ + FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const bool&); \ + FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const int&); \ + FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const unsigned&); \ + FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const char&); \ + FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const unsigned char&); \ + FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const long&); \ + FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const unsigned long&); \ + FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const long long&); \ + FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const unsigned long long&); \ + 
FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const double&); \ + FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const float&); \ + FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const short&); \ + FL_BINARY_OP_LITERAL_TYPE_DEF(OP, FUNC, const unsigned short&); + +#define FL_BINARY_OP_DEF(OP, FUNC) \ + Tensor FUNC(const Tensor& lhs, const Tensor& rhs) { \ + FL_TENSOR_BACKENDS_MATCH_CHECK(lhs, rhs); \ + return lhs.backend().FUNC(lhs, rhs); \ + } \ + Tensor operator OP(const Tensor& lhs, const Tensor& rhs) { \ + return FUNC(lhs, rhs); \ + } \ + FL_BINARY_OP_LITERALS_DEF(OP, FUNC); FL_BINARY_OP_DEF(+, add); FL_BINARY_OP_DEF(-, sub); @@ -657,42 +662,42 @@ FL_BINARY_OP_DEF(>>, rShift); #undef FL_BINARY_OP_LITERAL_TYPE_DEF Tensor minimum(const Tensor& lhs, const Tensor& rhs) { - FL_TENSOR_BACKENDS_MATCH_CHECK(lhs, rhs); - return lhs.backend().minimum(lhs, rhs); + FL_TENSOR_BACKENDS_MATCH_CHECK(lhs, rhs); + return lhs.backend().minimum(lhs, rhs); } Tensor maximum(const Tensor& lhs, const Tensor& rhs) { - FL_TENSOR_BACKENDS_MATCH_CHECK(lhs, rhs); - return lhs.backend().maximum(lhs, rhs); + FL_TENSOR_BACKENDS_MATCH_CHECK(lhs, rhs); + return lhs.backend().maximum(lhs, rhs); } Tensor minimum(const Tensor& lhs, const double& rhs) { - return lhs.backend().minimum(lhs, rhs); + return lhs.backend().minimum(lhs, rhs); } Tensor minimum(const double& lhs, const Tensor& rhs) { - return rhs.backend().minimum(lhs, rhs); + return rhs.backend().minimum(lhs, rhs); } Tensor maximum(const Tensor& lhs, const double& rhs) { - return lhs.backend().maximum(lhs, rhs); + return lhs.backend().maximum(lhs, rhs); } Tensor maximum(const double& lhs, const Tensor& rhs) { - return rhs.backend().maximum(lhs, rhs); + return rhs.backend().maximum(lhs, rhs); } Tensor power(const Tensor& lhs, const Tensor& rhs) { - FL_TENSOR_BACKENDS_MATCH_CHECK(lhs, rhs); - return lhs.backend().power(lhs, rhs); + FL_TENSOR_BACKENDS_MATCH_CHECK(lhs, rhs); + return lhs.backend().power(lhs, rhs); } Tensor power(const Tensor& lhs, const double& 
rhs) { - return lhs.backend().power(lhs, rhs); + return lhs.backend().power(lhs, rhs); } Tensor power(const double& lhs, const Tensor& rhs) { - return rhs.backend().power(lhs, rhs); + return rhs.backend().power(lhs, rhs); } /******************************* BLAS ********************************/ @@ -700,9 +705,10 @@ Tensor matmul( const Tensor& lhs, const Tensor& rhs, MatrixProperty lhsProp, - MatrixProperty rhsProp) { - FL_TENSOR_BACKENDS_MATCH_CHECK(lhs, rhs); - return lhs.backend().matmul(lhs, rhs, lhsProp, rhsProp); + MatrixProperty rhsProp +) { + FL_TENSOR_BACKENDS_MATCH_CHECK(lhs, rhs); + return lhs.backend().matmul(lhs, rhs, lhsProp, rhsProp); } /************************** Reductions ***************************/ @@ -710,15 +716,17 @@ Tensor matmul( Tensor amin( const Tensor& input, const std::vector& axes /* = {} */, - const bool keepDims /* = false */) { - return input.backend().amin(input, axes, keepDims); + const bool keepDims /* = false */ +) { + return input.backend().amin(input, axes, keepDims); } Tensor amax( const Tensor& input, const std::vector& axes /* = {} */, - const bool keepDims /* = false */) { - return input.backend().amax(input, axes, keepDims); + const bool keepDims /* = false */ +) { + return input.backend().amax(input, axes, keepDims); } void min( @@ -726,9 +734,10 @@ void min( Tensor& indices, const Tensor& input, const unsigned axis, - const bool keepDims) { - FL_TENSOR_BACKENDS_MATCH_CHECK(values, indices, input); - return input.backend().min(values, indices, input, axis, keepDims); + const bool keepDims +) { + FL_TENSOR_BACKENDS_MATCH_CHECK(values, indices, input); + return input.backend().min(values, indices, input, axis, keepDims); } void max( @@ -736,157 +745,170 @@ void max( Tensor& indices, const Tensor& input, const unsigned axis, - const bool keepDims /* = false */) { - FL_TENSOR_BACKENDS_MATCH_CHECK(values, indices, input); - return input.backend().max(values, indices, input, axis, keepDims); + const bool keepDims /* = false 
*/ +) { + FL_TENSOR_BACKENDS_MATCH_CHECK(values, indices, input); + return input.backend().max(values, indices, input, axis, keepDims); } Tensor sum( const Tensor& input, const std::vector& axes /* = {} */, - const bool keepDims /* = false */) { - return input.backend().sum(input, axes, keepDims); + const bool keepDims /* = false */ +) { + return input.backend().sum(input, axes, keepDims); } Tensor cumsum(const Tensor& input, const unsigned axis) { - return input.backend().cumsum(input, axis); + return input.backend().cumsum(input, axis); } Tensor argmax( const Tensor& input, const unsigned axis, - const bool keepDims /* = false */) { - return input.backend().argmax(input, axis, keepDims); + const bool keepDims /* = false */ +) { + return input.backend().argmax(input, axis, keepDims); } Tensor argmin( const Tensor& input, const unsigned axis, - const bool keepDims /* = false */) { - return input.backend().argmin(input, axis, keepDims); + const bool keepDims /* = false */ +) { + return input.backend().argmin(input, axis, keepDims); } Tensor mean( const Tensor& input, const std::vector& axes /* = {} */, - const bool keepDims /* = false */) { - return input.backend().mean(input, axes, keepDims); + const bool keepDims /* = false */ +) { + return input.backend().mean(input, axes, keepDims); } Tensor median( const Tensor& input, const std::vector& axes /* = {} */, - const bool keepDims /* = false */) { - return input.backend().median(input, axes, keepDims); + const bool keepDims /* = false */ +) { + return input.backend().median(input, axes, keepDims); } Tensor var( const Tensor& input, const std::vector& axes /* = {} */, const bool bias, - const bool keepDims /* = false */) { - return input.backend().var(input, axes, bias, keepDims); + const bool keepDims /* = false */ +) { + return input.backend().var(input, axes, bias, keepDims); } Tensor std( const Tensor& input, const std::vector& axes /* = {} */, - const bool keepDims /* = false */) { - return 
input.backend().std(input, axes, keepDims); + const bool keepDims /* = false */ +) { + return input.backend().std(input, axes, keepDims); } Tensor norm( const Tensor& input, const std::vector& axes /* = {} */, double p /* = 2 */, - const bool keepDims /* = false */) { - return input.backend().norm(input, axes, p, keepDims); + const bool keepDims /* = false */ +) { + return input.backend().norm(input, axes, p, keepDims); } Tensor countNonzero( const Tensor& input, const std::vector& axes /* = {} */, - const bool keepDims /* = false */) { - return input.backend().countNonzero(input, axes, keepDims); + const bool keepDims /* = false */ +) { + return input.backend().countNonzero(input, axes, keepDims); } Tensor any( const Tensor& input, const std::vector& axes /* = {} */, - const bool keepDims /* = false */) { - return input.backend().any(input, axes, keepDims); + const bool keepDims /* = false */ +) { + return input.backend().any(input, axes, keepDims); } Tensor all( const Tensor& input, const std::vector& axes /* = {} */, - const bool keepDims /* = false */) { - return input.backend().all(input, axes, keepDims); + const bool keepDims /* = false */ +) { + return input.backend().all(input, axes, keepDims); } /************************** Utilities ***************************/ std::ostream& operator<<(std::ostream& ostr, const Tensor& t) { - t.operator<<(ostr); - return ostr; + t.operator<<(ostr); + return ostr; } void print(const Tensor& tensor) { - tensor.backend().print(tensor); + tensor.backend().print(tensor); } bool allClose( const fl::Tensor& a, const fl::Tensor& b, - const double absTolerance) { - if (a.type() != b.type()) { - return false; - } - if (a.shape() != b.shape()) { - return false; - } - if (a.elements() == 0 && b.elements() == 0) { - return true; - } - return fl::amax(fl::abs(a - b)).astype(dtype::f64).scalar() < - absTolerance; + const double absTolerance +) { + if(a.type() != b.type()) { + return false; + } + if(a.shape() != b.shape()) { + return 
false; + } + if(a.elements() == 0 && b.elements() == 0) { + return true; + } + return fl::amax(fl::abs(a - b)).astype(dtype::f64).scalar() + < absTolerance; } bool isInvalidArray(const Tensor& tensor) { - return fl::any(fl::isnan(tensor)).asScalar() || - fl::any(fl::isinf(tensor)).asScalar(); + return fl::any(fl::isnan(tensor)).asScalar() + || fl::any(fl::isinf(tensor)).asScalar(); } std::string tensorBackendTypeToString(const TensorBackendType type) { - switch (type) { - case TensorBackendType::Stub: - return "Stub"; - case TensorBackendType::Tracer: - return "Tracer"; - case TensorBackendType::ArrayFire: - return "ArrayFire"; - } - throw std::runtime_error("Unreachable -- unrecognized tensor backend type"); + switch(type) { + case TensorBackendType::Stub: + return "Stub"; + case TensorBackendType::Tracer: + return "Tracer"; + case TensorBackendType::ArrayFire: + return "ArrayFire"; + } + throw std::runtime_error("Unreachable -- unrecognized tensor backend type"); } std::ostream& operator<<(std::ostream& os, const TensorBackendType type) { - os << tensorBackendTypeToString(type); - return os; + os << tensorBackendTypeToString(type); + return os; } namespace detail { -std::unique_ptr releaseAdapter(Tensor&& t) { - return t.releaseAdapter(); -} + std::unique_ptr releaseAdapter(Tensor&& t) { + return t.releaseAdapter(); + } -std::unique_ptr releaseAdapterUnsafe(Tensor& t) { - return t.releaseAdapter(); -} + std::unique_ptr releaseAdapterUnsafe(Tensor& t) { + return t.releaseAdapter(); + } -bool areTensorTypesEqual(const Tensor& a, const Tensor& b) { - return a.type() == b.type(); -} + bool areTensorTypesEqual(const Tensor& a, const Tensor& b) { + return a.type() == b.type(); + } } // namespace detail diff --git a/flashlight/fl/tensor/TensorBase.h b/flashlight/fl/tensor/TensorBase.h index 17465a4..4f1f1bd 100644 --- a/flashlight/fl/tensor/TensorBase.h +++ b/flashlight/fl/tensor/TensorBase.h @@ -30,7 +30,7 @@ class Tensor; */ /// Enum for various tensor backends. 
-enum class TensorBackendType { Stub, Tracer, ArrayFire }; +enum class TensorBackendType {Stub, Tracer, ArrayFire}; // See TensorAdapter.h class TensorAdapterBase; @@ -45,19 +45,19 @@ struct Index; class Stream; /// Location of memory or tensors. -enum class Location { Host, Device }; +enum class Location {Host, Device}; /// Alias to make it semantically clearer when referring to buffer location using MemoryLocation = Location; /// Tensor storage types. -enum class StorageType { Dense = 0, CSR = 1, CSC = 2, COO = 3 }; +enum class StorageType {Dense = 0, CSR = 1, CSC = 2, COO = 3}; /* @} */ namespace detail { -FL_API std::unique_ptr releaseAdapter(Tensor&& t); -FL_API std::unique_ptr releaseAdapterUnsafe(Tensor& t); + FL_API std::unique_ptr releaseAdapter(Tensor&& t); + FL_API std::unique_ptr releaseAdapterUnsafe(Tensor& t); } // namespace detail @@ -76,571 +76,578 @@ FL_API std::unique_ptr releaseAdapterUnsafe(Tensor& t); * \warning This API may break and is not yet stable. */ class FL_API Tensor { - // The tensor adapter for the tensor - std::unique_ptr impl_; - - /* - * Construct a tensor with a given shape and using an existing buffer. - * - * This constructor is a void* specialization to facilitate interface - * compliance with TensorAdapter and is intentionally private. - */ - Tensor( - const Shape& shape, - fl::dtype type, - const void* ptr, - MemoryLocation memoryLocation); - - /** - * Shallow-copies the tensor, returning a tensor that points to the same - * underlying data. - * - * For internal use only. Tensor implementations should define when and where - * deep copies happen based on their dataflow graph abstractions. - * - * \todo slated for removal. Rely on copy-on-write and fix bad refcount - * issues. - */ - Tensor shallowCopy() const; - // shallowCopy() above is used in DevicePtr given that it doesn't mutate - // tensors in place with tensor operations, and only pulls out memory. 
- friend class DevicePtr; - // also used in tensor abstractions that wrap and call tensor ops: - friend class TracerTensorBase; - - /** - * Release and transfer ownership of the tensor's underlying - * TensorAdapterBase. - * - * NB: After unlocking the adapter, the resulting Tensor should - * *probably* be destroyed, as it has no adapter and thus can't perform any - * operations. - */ - - std::unique_ptr releaseAdapter(); - friend std::unique_ptr detail::releaseAdapter(Tensor&& t); - friend std::unique_ptr detail::releaseAdapterUnsafe( - Tensor& t); - - public: - explicit Tensor(std::unique_ptr adapter); - virtual ~Tensor(); - - /** - * Copy constructor - calls the implementation-defined copy constructor for - * the TensorAdapter. - */ - Tensor(const Tensor& tensor); - - /** - * Move constructor - moves the pointer to the TensorAdapter - performs no - * other operations. - */ - Tensor(Tensor&& tensor) noexcept; - - /** - * Construct an empty tensor with the default tensor backend's tensor adapter. - */ - Tensor(); - - /** - * Construct a tensor of a given shape (and optionally type) without - * populating its data. - * - * @param[in] shape the shape of the tensor - * @param[in] type (optional) the type of the tensor - */ - explicit Tensor(const Shape& shape, fl::dtype type = fl::dtype::f32); - - /** - * Construct an empty tensor of a given type. - * - * @param[in] type (optional) the type of the tensor - */ - explicit Tensor(fl::dtype type); - - /** - * Construct a sparse tensor. - * - * @param[in] nRows the number of rows of the tensor - * @param[in] nCols the number of columns of the tensor - * @param[in] values the values associated with the tensor - * @param[in] rowIdx the row indices of the sparse array - * @param[in] colIdx the the column indices of the sparse array - * @param[in] storageType the storage type of the underlying tensor - * - * \todo Expand this API with getters as needed. 
- */ - Tensor( - const Dim nRows, - const Dim nCols, - const Tensor& values, - const Tensor& rowIdx, - const Tensor& colIdx, - StorageType storageType); - - /** - * Create a tensor from a vector of values. - * - * @param[in] s the shape of the resulting tensor. - * @param[in] v values with which to populate the tensor. - * @return a tensor with values and shape as given. - */ - template - static Tensor fromVector(Shape s, std::vector v) { - return Tensor(s, fl::dtype_traits::fl_type, v.data(), Location::Host); - } - - template - static Tensor fromArray(Shape s, std::array a) { - return Tensor(s, fl::dtype_traits::fl_type, a.data(), Location::Host); - } - - template - static Tensor fromVector(Shape s, std::vector v, dtype type) { - return Tensor(s, type, v.data(), Location::Host); - } - - template - static Tensor fromArray(Shape s, std::array a, dtype type) { - return Tensor(s, type, a.data(), Location::Host); - } - - template - static Tensor fromVector(std::vector v) { - return Tensor( - {static_cast(v.size())}, - fl::dtype_traits::fl_type, - v.data(), - Location::Host); - } - - template - static Tensor fromArray(std::array a) { - return Tensor( - {static_cast(a.size())}, - fl::dtype_traits::fl_type, - a.data(), - Location::Host); - } - - /** - * Create a tensor from an existing buffer. - * - * @param[in] s the shape of the resulting tensor. - * @param[in] ptr the buffer containing the data - * @param[in] memoryLocation the location in memory where the input buffer - * with which to create the tensor resides. - * @return a tensor with values and shape as given. - */ - template - static Tensor fromBuffer(Shape s, const T* ptr, Location memoryLocation) { - return Tensor(s, fl::dtype_traits::fl_type, ptr, memoryLocation); - } - - /** - * Create a tensor from an existing byte buffer given a type. - * - * @param[in] s the shape of the resulting tensor. 
- * @param[in] t the type of the underlying tensor - * @param[in] ptr the buffer of bytes containing the data - * @param[in] memoryLocation the location in memory where the input buffer - * with which to create the tensor resides. - * @return a tensor with values and shape as given. - */ - static Tensor fromBuffer( - Shape s, - fl::dtype t, - const uint8_t* ptr, - Location memoryLocation) { - return Tensor(s, t, ptr, memoryLocation); - } - - /** - * Deep-copies the tensor, including underlying data. - */ - Tensor copy() const; - - /** - * Get the shape of a tensor. - * - * @return the shape of the tensor - */ - const Shape& shape() const; - - /** - * Get a tensor's location, host or some device. - * - * @return the tensor's location - */ - Location location() const; - - /** - * Get the number of elements in the tensor. - * - * @return the size of the tensor in elements. - */ - size_t elements() const; - - /** - * Get the size of a given dimension of a tensor in the number of arguments. - * Throws if the given dimension is larger than the number of tensor - * dimensions. - * - * @return the number of elements at the given dimension - */ - Dim dim(const size_t dim) const; - - /** - * Get the number of directions of the tensor. - * - * @return the number of dimensions - */ - int ndim() const; - - /** - * Returns true if the tensor has zero elements, else false. - * - * @return true if the tensor is empty - */ - bool isEmpty() const; - - /** - * Returns true if the tensor has an associated underlying adapter. - * - * @return true if the tensor has a valid adapter - */ - bool hasAdapter() const; - - /** - * Get the tensor size in bytes. - * - * @return the size of the tensor in bytes. - */ - size_t bytes() const; - - /** - * Get the data type of tensor. - * - * @return the dtype of the tensor - */ - dtype type() const; - - /** - * Returns whether or not the tensor is sparse. 
- * - * @return true if the tensor is a sparse tensor, else false - */ - bool isSparse() const; - - /** - * Get this tensor's strides - the number of elements/coefficients to step - * when moving along each dimension when traversing the tensor. - * - * @return a Shape containing strides in each dimension. - */ - Shape strides() const; - - /** - * Get the stream which contains(ed) the computation required to realize an - * up-to-date value for this tensor. For instance, `device()` may not yield a - * pointer to the up-to-date value -- to use this pointer, `Stream::sync` or - * `Stream::relativeSync` is required. - * - * @return an immutable reference to the stream that contains(ed) the - * computations which create this tensor. - */ - virtual const Stream& stream() const; - - /** - * Returns a tensor with elements cast as a particular type - * - * @param[in] type the type to which to cast the tensor - * @return a tensor with element-wise cast to the new type - */ - Tensor astype(const dtype type) const; - - /** - * Index into a tensor using a vector of fl::Index references. - * - * @param[in] indices a vector of fl::Index references with which to index. - * @return an indexed tensor - */ - Tensor operator()(const std::vector& indices) const; - - /** - * Index into a tensor using a variable number of fl::Index. - * - * @param[in] args fl::Index instances to use - * @return an indexed tensor - */ - template - Tensor operator()(const Ts&... args) const { - // TODO: add this back if acceptable with C++ 17 ABIs and a nvcc - // static_assert( - // std::conjunction...>::value, - // "Tensor index operator can only take Index-compatible types - " - // "fl::range, fl::Tensor, fl::span, and integer types."); - std::vector indices{{args...}}; - return this->operator()(indices); - } - - /** - * Returns a representation of the tensor in 1 dimension. - * - * @return a 1D version of this tensor 1D-indexed with the given index. 
- */ - Tensor flatten() const; - - /** - * Returns a tensor indexed from this tensor but indexed as a 1D/flattened - * tensor. - * - * @return an indexed, 1D version of this tensor. - */ - Tensor flat(const Index& idx) const; - - /** - * Return a copy (depending on copy-on-write behavior of the underlying - * implementation) of this tensor that is contigous in memory. - * - * @return an identical tensor that is contiguous in memory - */ - Tensor asContiguousTensor() const; - - /** - * Gets the backend enum from the underlying TensorAdapter. - * - * @return the backend in question - */ - TensorBackendType backendType() const; - - /** - * Gets the underlying tensor adapter implementation. - * - * @return the tensor adapter. - */ - template - T& getAdapter() const { - return *static_cast(impl_.get()); - } - - /** - * Return the TensorBackend associated with this tensor. - * - * @return a TensorBackend. - */ - TensorBackend& backend() const; - - /** - * Return a scalar of a specified type for the tensor. If the tensor has more - * than one element, returns the first element as a scalar. - * - * Throws an exception if the specified type does not match the dtype trait of - * the underlying tensor. To implicitly cast the scalar regardless of the - * underlying Tensor's dtype, use `asScalar`. - * - * @return a scalar of the first element in the tensor. - */ - template - T scalar() const; - - /** - * Return a scalar of the specified type of the tensor. If the specified type - * does not match the tensor's underlying dtype, the scalar value is - * implicitly cast. - * - * @return a scalar of the first element in the tensor cast to the specified - * type. 
- */ - template - T asScalar() const { - // Implicitly cast to the requested return type - switch (type()) { - case dtype::f16: - return astype(dtype::f32).scalar(); - case dtype::f32: - return scalar(); - case dtype::f64: - return scalar(); - case dtype::s32: - return scalar(); - case dtype::u32: - return scalar(); - case dtype::b8: - return scalar(); - case dtype::u8: - return scalar(); - case dtype::s64: - return scalar(); - case dtype::u64: - return scalar(); - case dtype::s16: - return scalar(); - case dtype::u16: - return scalar(); - default: - throw std::invalid_argument( - "Tensor::asScaler - no castable type exists."); + // The tensor adapter for the tensor + std::unique_ptr impl_; + + /* + * Construct a tensor with a given shape and using an existing buffer. + * + * This constructor is a void* specialization to facilitate interface + * compliance with TensorAdapter and is intentionally private. + */ + Tensor( + const Shape& shape, + fl::dtype type, + const void* ptr, + MemoryLocation memoryLocation + ); + + /** + * Shallow-copies the tensor, returning a tensor that points to the same + * underlying data. + * + * For internal use only. Tensor implementations should define when and where + * deep copies happen based on their dataflow graph abstractions. + * + * \todo slated for removal. Rely on copy-on-write and fix bad refcount + * issues. + */ + Tensor shallowCopy() const; + // shallowCopy() above is used in DevicePtr given that it doesn't mutate + // tensors in place with tensor operations, and only pulls out memory. + friend class DevicePtr; + // also used in tensor abstractions that wrap and call tensor ops: + friend class TracerTensorBase; + + /** + * Release and transfer ownership of the tensor's underlying + * TensorAdapterBase. + * + * NB: After unlocking the adapter, the resulting Tensor should + * *probably* be destroyed, as it has no adapter and thus can't perform any + * operations. 
+ */ + + std::unique_ptr releaseAdapter(); + friend std::unique_ptr detail::releaseAdapter(Tensor&& t); + friend std::unique_ptr detail::releaseAdapterUnsafe( + Tensor& t + ); + +public: + explicit Tensor(std::unique_ptr adapter); + virtual ~Tensor(); + + /** + * Copy constructor - calls the implementation-defined copy constructor for + * the TensorAdapter. + */ + Tensor(const Tensor& tensor); + + /** + * Move constructor - moves the pointer to the TensorAdapter - performs no + * other operations. + */ + Tensor(Tensor&& tensor) noexcept; + + /** + * Construct an empty tensor with the default tensor backend's tensor adapter. + */ + Tensor(); + + /** + * Construct a tensor of a given shape (and optionally type) without + * populating its data. + * + * @param[in] shape the shape of the tensor + * @param[in] type (optional) the type of the tensor + */ + explicit Tensor(const Shape& shape, fl::dtype type = fl::dtype::f32); + + /** + * Construct an empty tensor of a given type. + * + * @param[in] type (optional) the type of the tensor + */ + explicit Tensor(fl::dtype type); + + /** + * Construct a sparse tensor. + * + * @param[in] nRows the number of rows of the tensor + * @param[in] nCols the number of columns of the tensor + * @param[in] values the values associated with the tensor + * @param[in] rowIdx the row indices of the sparse array + * @param[in] colIdx the the column indices of the sparse array + * @param[in] storageType the storage type of the underlying tensor + * + * \todo Expand this API with getters as needed. + */ + Tensor( + const Dim nRows, + const Dim nCols, + const Tensor& values, + const Tensor& rowIdx, + const Tensor& colIdx, + StorageType storageType + ); + + /** + * Create a tensor from a vector of values. + * + * @param[in] s the shape of the resulting tensor. + * @param[in] v values with which to populate the tensor. + * @return a tensor with values and shape as given. 
+ */ + template + static Tensor fromVector(Shape s, std::vector v) { + return Tensor(s, fl::dtype_traits::fl_type, v.data(), Location::Host); + } + + template + static Tensor fromArray(Shape s, std::array a) { + return Tensor(s, fl::dtype_traits::fl_type, a.data(), Location::Host); + } + + template + static Tensor fromVector(Shape s, std::vector v, dtype type) { + return Tensor(s, type, v.data(), Location::Host); + } + + template + static Tensor fromArray(Shape s, std::array a, dtype type) { + return Tensor(s, type, a.data(), Location::Host); + } + + template + static Tensor fromVector(std::vector v) { + return Tensor( + {static_cast(v.size())}, + fl::dtype_traits::fl_type, + v.data(), + Location::Host + ); } - } - - /** - * Return a pointer to the tensor's underlying data per a certain type. This - * pointer exists on the computation device. - * - * \note The memory allocated here will not be freed until Tensor:unlock() is - * called. - * - * @return the requested pointer on the device. - */ - template - T* device() const; - - /** - * Populate a pointer value with the address of a Tensor's underlying buffer - * on the computation device. - * - * \note The memory allocated here will not be freed until Tensor:unlock() is - * called. - * - * @param[in] ptr the pointer to populate with the Tensor's buffer location on - * device. - */ - template - void device(T** ptr) const; - - /** - * Returns a pointer to the tensor's underlying data, but on the host. If the - * tensor is located on a device, makes a copy of device memory and returns a - * buffer on the host containing the relevant memory. - * - * @return the requested pointer on the host. - */ - template - T* host() const; - - /** - * Populates an existing buffer with the tensor's underlying data, but on the - * host. If the tensor is located on a device, makes a copy of device memory - * and returns a buffer on the host containing the relevant memory. 
- * - * @param[in] ptr a pointer to the region of memory to populate with tensor - * values - */ - template - void host(T* ptr) const; - - /** - * Returns a vector on the host contaning a flat representation of the tensor. - * The resulting vector is a copy of the underlying tensor memory, even if on - * the host. - * - * @return a vector in host memory containing - */ - template - std::vector toHostVector() const { - if (isEmpty()) { - return std::vector(); + + template + static Tensor fromArray(std::array a) { + return Tensor( + {static_cast(a.size())}, + fl::dtype_traits::fl_type, + a.data(), + Location::Host + ); + } + + /** + * Create a tensor from an existing buffer. + * + * @param[in] s the shape of the resulting tensor. + * @param[in] ptr the buffer containing the data + * @param[in] memoryLocation the location in memory where the input buffer + * with which to create the tensor resides. + * @return a tensor with values and shape as given. + */ + template + static Tensor fromBuffer(Shape s, const T* ptr, Location memoryLocation) { + return Tensor(s, fl::dtype_traits::fl_type, ptr, memoryLocation); + } + + /** + * Create a tensor from an existing byte buffer given a type. + * + * @param[in] s the shape of the resulting tensor. + * @param[in] t the type of the underlying tensor + * @param[in] ptr the buffer of bytes containing the data + * @param[in] memoryLocation the location in memory where the input buffer + * with which to create the tensor resides. + * @return a tensor with values and shape as given. + */ + static Tensor fromBuffer( + Shape s, + fl::dtype t, + const uint8_t* ptr, + Location memoryLocation + ) { + return Tensor(s, t, ptr, memoryLocation); + } + + /** + * Deep-copies the tensor, including underlying data. + */ + Tensor copy() const; + + /** + * Get the shape of a tensor. + * + * @return the shape of the tensor + */ + const Shape& shape() const; + + /** + * Get a tensor's location, host or some device. 
+ * + * @return the tensor's location + */ + Location location() const; + + /** + * Get the number of elements in the tensor. + * + * @return the size of the tensor in elements. + */ + size_t elements() const; + + /** + * Get the size of a given dimension of a tensor in the number of arguments. + * Throws if the given dimension is larger than the number of tensor + * dimensions. + * + * @return the number of elements at the given dimension + */ + Dim dim(const size_t dim) const; + + /** + * Get the number of directions of the tensor. + * + * @return the number of dimensions + */ + int ndim() const; + + /** + * Returns true if the tensor has zero elements, else false. + * + * @return true if the tensor is empty + */ + bool isEmpty() const; + + /** + * Returns true if the tensor has an associated underlying adapter. + * + * @return true if the tensor has a valid adapter + */ + bool hasAdapter() const; + + /** + * Get the tensor size in bytes. + * + * @return the size of the tensor in bytes. + */ + size_t bytes() const; + + /** + * Get the data type of tensor. + * + * @return the dtype of the tensor + */ + dtype type() const; + + /** + * Returns whether or not the tensor is sparse. + * + * @return true if the tensor is a sparse tensor, else false + */ + bool isSparse() const; + + /** + * Get this tensor's strides - the number of elements/coefficients to step + * when moving along each dimension when traversing the tensor. + * + * @return a Shape containing strides in each dimension. + */ + Shape strides() const; + + /** + * Get the stream which contains(ed) the computation required to realize an + * up-to-date value for this tensor. For instance, `device()` may not yield a + * pointer to the up-to-date value -- to use this pointer, `Stream::sync` or + * `Stream::relativeSync` is required. + * + * @return an immutable reference to the stream that contains(ed) the + * computations which create this tensor. 
+ */ + virtual const Stream& stream() const; + + /** + * Returns a tensor with elements cast as a particular type + * + * @param[in] type the type to which to cast the tensor + * @return a tensor with element-wise cast to the new type + */ + Tensor astype(const dtype type) const; + + /** + * Index into a tensor using a vector of fl::Index references. + * + * @param[in] indices a vector of fl::Index references with which to index. + * @return an indexed tensor + */ + Tensor operator()(const std::vector& indices) const; + + /** + * Index into a tensor using a variable number of fl::Index. + * + * @param[in] args fl::Index instances to use + * @return an indexed tensor + */ + template + Tensor operator()(const Ts&... args) const { + // TODO: add this back if acceptable with C++ 17 ABIs and a nvcc + // static_assert( + // std::conjunction...>::value, + // "Tensor index operator can only take Index-compatible types - " + // "fl::range, fl::Tensor, fl::span, and integer types."); + std::vector indices{{args...}}; + return this->operator()(indices); + } + + /** + * Returns a representation of the tensor in 1 dimension. + * + * @return a 1D version of this tensor 1D-indexed with the given index. + */ + Tensor flatten() const; + + /** + * Returns a tensor indexed from this tensor but indexed as a 1D/flattened + * tensor. + * + * @return an indexed, 1D version of this tensor. + */ + Tensor flat(const Index& idx) const; + + /** + * Return a copy (depending on copy-on-write behavior of the underlying + * implementation) of this tensor that is contigous in memory. + * + * @return an identical tensor that is contiguous in memory + */ + Tensor asContiguousTensor() const; + + /** + * Gets the backend enum from the underlying TensorAdapter. + * + * @return the backend in question + */ + TensorBackendType backendType() const; + + /** + * Gets the underlying tensor adapter implementation. + * + * @return the tensor adapter. 
+ */ + template + T& getAdapter() const { + return *static_cast(impl_.get()); } - std::vector vec(this->elements()); - host(vec.data()); - return vec; - } - - /** - * Unlocks any device memory associated with the tensor that was acquired with - * Tensor::device(), making it eligible to be freed. - */ - void unlock() const; - - /** - * Returns true if the tensor has been memory-locked per a call to - * Tensor::device(). After unlocking via Tensor::unlock(), the tensor is no - * longer locked. - * - * @return true if the tensor is locked and a device pointer is active. - */ - bool isLocked() const; - - /** - * Returns if the Tensor is contiguous in its memory-based representation. - * - * @return a bool denoting Tensor contiguousness - */ - bool isContiguous() const; - - /** - * Stores arbitrary data on a tensor. For internal use/benchmarking only. This - * may be a no-op for some backends. - * - * @param[in] data a pointer to arbitrary data to pass to a tensor impl. - */ - void setContext(void* data); - - /** - * Gets arbitrary data stored on a tensor. For internal use/benchmarking only. - * This may be a no-op for some backends. - * - * @return a pointer to some implementation-defined data, else nullptr if a - * no-op. - */ - void* getContext() const; - - /** - * Returns a string representation of a Tensor. - * - * \note This is backend-dependent. See Flashlight's serialization utilities - * for ways to serialize Tensors that are portable across Tensor backends. - * - * @return a string representation of the Tensor. - */ - std::string toString() const; - - /** - * Write a string representation of a tensor to an output stream. 
- */ - std::ostream& operator<<(std::ostream& ostr) const; - - /******************** Assignment Operators ********************/ -#define ASSIGN_TENSOR_OP(OP) Tensor& OP(const Tensor& val); -#define ASSIGN_SCALAR_OP(OP) \ - Tensor& OP(const double& val); \ - Tensor& OP(const float& val); \ - Tensor& OP(const int& val); \ - Tensor& OP(const unsigned& val); \ - Tensor& OP(const bool& val); \ - Tensor& OP(const char& val); \ - Tensor& OP(const unsigned char& val); \ - Tensor& OP(const short& val); \ - Tensor& OP(const unsigned short& val); \ - Tensor& OP(const long& val); \ - Tensor& OP(const unsigned long& val); \ - Tensor& OP(const long long& val); \ - Tensor& OP(const unsigned long long& val); -#define ASSIGN_OP(OP) \ - ASSIGN_TENSOR_OP(OP); \ - ASSIGN_SCALAR_OP(OP); - - ASSIGN_SCALAR_OP(operator=); - ASSIGN_OP(operator+=); - ASSIGN_OP(operator-=); - ASSIGN_OP(operator*=); - ASSIGN_OP(operator/=); + + /** + * Return the TensorBackend associated with this tensor. + * + * @return a TensorBackend. + */ + TensorBackend& backend() const; + + /** + * Return a scalar of a specified type for the tensor. If the tensor has more + * than one element, returns the first element as a scalar. + * + * Throws an exception if the specified type does not match the dtype trait of + * the underlying tensor. To implicitly cast the scalar regardless of the + * underlying Tensor's dtype, use `asScalar`. + * + * @return a scalar of the first element in the tensor. + */ + template + T scalar() const; + + /** + * Return a scalar of the specified type of the tensor. If the specified type + * does not match the tensor's underlying dtype, the scalar value is + * implicitly cast. + * + * @return a scalar of the first element in the tensor cast to the specified + * type. 
+ */ + template + T asScalar() const { + // Implicitly cast to the requested return type + switch(type()) { + case dtype::f16: + return astype(dtype::f32).scalar(); + case dtype::f32: + return scalar(); + case dtype::f64: + return scalar(); + case dtype::s32: + return scalar(); + case dtype::u32: + return scalar(); + case dtype::b8: + return scalar(); + case dtype::u8: + return scalar(); + case dtype::s64: + return scalar(); + case dtype::u64: + return scalar(); + case dtype::s16: + return scalar(); + case dtype::u16: + return scalar(); + default: + throw std::invalid_argument( + "Tensor::asScaler - no castable type exists." + ); + } + } + + /** + * Return a pointer to the tensor's underlying data per a certain type. This + * pointer exists on the computation device. + * + * \note The memory allocated here will not be freed until Tensor:unlock() is + * called. + * + * @return the requested pointer on the device. + */ + template + T* device() const; + + /** + * Populate a pointer value with the address of a Tensor's underlying buffer + * on the computation device. + * + * \note The memory allocated here will not be freed until Tensor:unlock() is + * called. + * + * @param[in] ptr the pointer to populate with the Tensor's buffer location on + * device. + */ + template + void device(T** ptr) const; + + /** + * Returns a pointer to the tensor's underlying data, but on the host. If the + * tensor is located on a device, makes a copy of device memory and returns a + * buffer on the host containing the relevant memory. + * + * @return the requested pointer on the host. + */ + template + T* host() const; + + /** + * Populates an existing buffer with the tensor's underlying data, but on the + * host. If the tensor is located on a device, makes a copy of device memory + * and returns a buffer on the host containing the relevant memory. 
+ * + * @param[in] ptr a pointer to the region of memory to populate with tensor + * values + */ + template + void host(T* ptr) const; + + /** + * Returns a vector on the host contaning a flat representation of the tensor. + * The resulting vector is a copy of the underlying tensor memory, even if on + * the host. + * + * @return a vector in host memory containing + */ + template + std::vector toHostVector() const { + if(isEmpty()) { + return std::vector(); + } + std::vector vec(this->elements()); + host(vec.data()); + return vec; + } + + /** + * Unlocks any device memory associated with the tensor that was acquired with + * Tensor::device(), making it eligible to be freed. + */ + void unlock() const; + + /** + * Returns true if the tensor has been memory-locked per a call to + * Tensor::device(). After unlocking via Tensor::unlock(), the tensor is no + * longer locked. + * + * @return true if the tensor is locked and a device pointer is active. + */ + bool isLocked() const; + + /** + * Returns if the Tensor is contiguous in its memory-based representation. + * + * @return a bool denoting Tensor contiguousness + */ + bool isContiguous() const; + + /** + * Stores arbitrary data on a tensor. For internal use/benchmarking only. This + * may be a no-op for some backends. + * + * @param[in] data a pointer to arbitrary data to pass to a tensor impl. + */ + void setContext(void* data); + + /** + * Gets arbitrary data stored on a tensor. For internal use/benchmarking only. + * This may be a no-op for some backends. + * + * @return a pointer to some implementation-defined data, else nullptr if a + * no-op. + */ + void* getContext() const; + + /** + * Returns a string representation of a Tensor. + * + * \note This is backend-dependent. See Flashlight's serialization utilities + * for ways to serialize Tensors that are portable across Tensor backends. + * + * @return a string representation of the Tensor. 
+ */ + std::string toString() const; + + /** + * Write a string representation of a tensor to an output stream. + */ + std::ostream& operator<<(std::ostream& ostr) const; + + /******************** Assignment Operators ********************/ +#define ASSIGN_TENSOR_OP(OP) Tensor & OP(const Tensor& val); +#define ASSIGN_SCALAR_OP(OP) \ + Tensor & OP(const double& val); \ + Tensor& OP(const float& val); \ + Tensor& OP(const int& val); \ + Tensor& OP(const unsigned& val); \ + Tensor& OP(const bool& val); \ + Tensor& OP(const char& val); \ + Tensor& OP(const unsigned char& val); \ + Tensor& OP(const short& val); \ + Tensor& OP(const unsigned short& val); \ + Tensor& OP(const long& val); \ + Tensor& OP(const unsigned long& val); \ + Tensor& OP(const long long& val); \ + Tensor& OP(const unsigned long long& val); +#define ASSIGN_OP(OP) \ + ASSIGN_TENSOR_OP(OP); \ + ASSIGN_SCALAR_OP(OP); + + ASSIGN_SCALAR_OP(operator=); + ASSIGN_OP(operator+=); + ASSIGN_OP(operator-=); + ASSIGN_OP(operator*=); + ASSIGN_OP(operator/=); #undef ASSIGN_TENSOR_OP #undef ASSIGN_SCALAR_OP #undef ASSIGN_OP - /* The following assignment operator differentiation via member method - * ref-qualifier ensures that - * 1. For `x = ...;`, the behavior is the same as the copy/move constructor. - * 2. For `... = ...`, a copy is made from the rhs tensor data to the lhs one. - * This allows tensor mutation via indexing, e.g., `t(0, 0) = 42`. - */ - Tensor& operator=(Tensor&& other) &; - Tensor& operator=(Tensor&& other) &&; - Tensor& operator=(const Tensor& other) &; - Tensor& operator=(const Tensor& other) &&; + /* The following assignment operator differentiation via member method + * ref-qualifier ensures that + * 1. For `x = ...;`, the behavior is the same as the copy/move constructor. + * 2. For `... = ...`, a copy is made from the rhs tensor data to the lhs one. + * This allows tensor mutation via indexing, e.g., `t(0, 0) = 42`. 
+ */ + Tensor& operator=(Tensor&& other) &; + Tensor& operator=(Tensor&& other) &&; + Tensor& operator=(const Tensor& other) &; + Tensor& operator=(const Tensor& other) &&; }; /** @@ -659,9 +666,8 @@ class FL_API Tensor { * on the value type * @return a tensor of the specified shape filled with the specified value */ -template -FL_API Tensor -fromScalar(const T& val, const dtype type = dtype_traits::ctype); +template +FL_API Tensor fromScalar(const T& val, const dtype type = dtype_traits::ctype); /** * Creates a new Tensor with a given Shape and filled with a particular value. @@ -672,11 +678,12 @@ fromScalar(const T& val, const dtype type = dtype_traits::ctype); * on the value type * @return a tensor of the specified shape filled with the specified value */ -template +template FL_API Tensor full( const Shape& dims, const T& val, - const dtype type = dtype_traits::ctype); + const dtype type = dtype_traits::ctype +); /** * Return a the identity tensor of a given size and type. @@ -697,12 +704,13 @@ FL_API Tensor identity(const Dim dim, const dtype type = dtype::f32); * * @return a tensor containing the evenly-spaced values */ -template +template FL_API Tensor arange( const T& start, const T& end, const T& step = 1, - const dtype type = dtype_traits::ctype); + const dtype type = dtype_traits::ctype +); /** * Create a tensor with [0, N] values along dimension given by seqDim and @@ -716,8 +724,7 @@ FL_API Tensor arange( * @return a tensor with the given shape with the sequence along the given * dimension, tiled along other dimensions. 
*/ -FL_API Tensor -arange(const Shape& shape, const Dim seqDim = 0, const dtype type = dtype::f32); +FL_API Tensor arange(const Shape& shape, const Dim seqDim = 0, const dtype type = dtype::f32); /** * Creates a sequence with the range `[0, dims.elements())` sequentially in the @@ -735,8 +742,9 @@ arange(const Shape& shape, const Dim seqDim = 0, const dtype type = dtype::f32); */ FL_API Tensor iota( const Shape& dims, - const Shape& tileDims = {1}, - const dtype type = dtype::f32); + const Shape& tileDims = { 1 }, + const dtype type = dtype::f32 +); /************************ Shaping and Indexing *************************/ @@ -779,8 +787,7 @@ FL_API Tensor tile(const Tensor& tensor, const Shape& shape); * @param[in] axis the axis along which to concatenate tensors * @return a concatenated tensor */ -FL_API Tensor -concatenate(const std::vector& tensors, const unsigned axis = 0); +FL_API Tensor concatenate(const std::vector& tensors, const unsigned axis = 0); /** * Join or concatenate tensors together along a particular axis. @@ -789,10 +796,10 @@ concatenate(const std::vector& tensors, const unsigned axis = 0); * @param[in] args tensors to concatenate * @return a concatenated tensor */ -template +template Tensor concatenate(unsigned axis, const Ts&... args) { - std::vector tensors{{args...}}; - return concatenate(tensors, axis); + std::vector tensors{{args...}}; + return concatenate(tensors, axis); } /** @@ -806,12 +813,12 @@ FL_API Tensor nonzero(const Tensor& tensor); /// Padding types for the pad operator. enum class PadType { - /// pad with a constant zero value. - Constant, - /// pad with the values at the edges of the tensor - Edge, - /// pad with a reflection of the tensor mirrored along each edge - Symmetric + /// pad with a constant zero value. 
+ Constant, + /// pad with the values at the edges of the tensor + Edge, + /// pad with a reflection of the tensor mirrored along each edge + Symmetric }; /** @@ -824,10 +831,11 @@ enum class PadType { * * @return the padded tensor */ -FL_API Tensor -pad(const Tensor& input, +FL_API Tensor pad( + const Tensor& input, const std::vector>& padWidths, - const PadType type = PadType::Constant); + const PadType type = PadType::Constant +); /************************** Unary Operators ***************************/ /** @@ -838,7 +846,7 @@ pad(const Tensor& input, */ FL_API Tensor negative(const Tensor& tensor); inline Tensor operator-(const Tensor& tensor) { - return negative(tensor); + return negative(tensor); } /** @@ -849,7 +857,7 @@ inline Tensor operator-(const Tensor& tensor) { */ FL_API Tensor logicalNot(const Tensor& tensor); inline Tensor operator!(const Tensor& tensor) { - return logicalNot(tensor); + return logicalNot(tensor); } /** @@ -942,7 +950,7 @@ FL_API Tensor absolute(const Tensor& tensor); // \copydoc absolute inline Tensor abs(const Tensor& tensor) { - return absolute(tensor); + return absolute(tensor); } /** @@ -1154,7 +1162,7 @@ FL_API Tensor where(const Tensor& condition, const double& x, const Tensor& y); /*! * Sorting mode for sorting-related functions. */ -enum class SortMode { Descending = 0, Ascending = 1 }; +enum class SortMode {Descending = 0, Ascending = 1}; /** * Get the top-k values and indices from a Tensor. @@ -1174,7 +1182,8 @@ FL_API void topk( const Tensor& input, const unsigned k, const Dim axis, - const SortMode sortMode = SortMode::Descending); + const SortMode sortMode = SortMode::Descending +); /** * Sort the values of a tensor, and return the sorted tensor. 
@@ -1186,7 +1195,8 @@ FL_API void topk( FL_API Tensor sort( const Tensor& input, const Dim axis, - const SortMode sortMode = SortMode::Ascending); + const SortMode sortMode = SortMode::Ascending +); /** * Sort the values of a tensor, and return the sorted tensor and sorted indices. @@ -1202,7 +1212,8 @@ FL_API void sort( Tensor& indices, const Tensor& input, const Dim axis, - const SortMode sortMode = SortMode::Ascending); + const SortMode sortMode = SortMode::Ascending +); /** * Sort the values of a tensor and return the sorted indices. @@ -1214,35 +1225,36 @@ FL_API void sort( FL_API Tensor argsort( const Tensor& input, const Dim axis, - const SortMode sortMode = SortMode::Ascending); + const SortMode sortMode = SortMode::Ascending +); /************************** Binary Operators ***************************/ // \cond DOXYGEN_DO_NOT_DOCUMENT -#define FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, TYPE) \ - FL_API Tensor FUNC(TYPE lhs, const Tensor& rhs); \ - FL_API Tensor FUNC(const Tensor& lhs, TYPE rhs); \ - FL_API Tensor operator OP(TYPE lhs, const Tensor& rhs); \ - FL_API Tensor operator OP(const Tensor& lhs, TYPE rhs); - -#define FL_BINARY_OP_LITERALS_DECL(OP, FUNC) \ - FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const bool&); \ - FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const int&); \ - FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const unsigned&); \ - FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const char&); \ - FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const unsigned char&); \ - FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const long&); \ - FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const unsigned long&); \ - FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const long long&); \ - FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const unsigned long long&); \ - FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const double&); \ - FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const float&); \ - FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const short&); \ - FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const unsigned 
short&); - -#define FL_BINARY_OP_DECL(OP, FUNC) \ - FL_API Tensor FUNC(const Tensor& lhs, const Tensor& rhs); \ - FL_API Tensor operator OP(const Tensor& lhs, const Tensor& rhs); \ - FL_BINARY_OP_LITERALS_DECL(OP, FUNC); +#define FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, TYPE) \ + FL_API Tensor FUNC(TYPE lhs, const Tensor& rhs); \ + FL_API Tensor FUNC(const Tensor& lhs, TYPE rhs); \ + FL_API Tensor operator OP(TYPE lhs, const Tensor& rhs); \ + FL_API Tensor operator OP(const Tensor& lhs, TYPE rhs); + +#define FL_BINARY_OP_LITERALS_DECL(OP, FUNC) \ + FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const bool&); \ + FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const int&); \ + FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const unsigned&); \ + FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const char&); \ + FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const unsigned char&); \ + FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const long&); \ + FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const unsigned long&); \ + FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const long long&); \ + FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const unsigned long long&); \ + FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const double&); \ + FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const float&); \ + FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const short&); \ + FL_BINARY_OP_LITERAL_TYPE_DECL(OP, FUNC, const unsigned short&); + +#define FL_BINARY_OP_DECL(OP, FUNC) \ + FL_API Tensor FUNC(const Tensor& lhs, const Tensor& rhs); \ + FL_API Tensor operator OP(const Tensor& lhs, const Tensor& rhs); \ + FL_BINARY_OP_LITERALS_DECL(OP, FUNC); FL_BINARY_OP_DECL(+, add); FL_BINARY_OP_DECL(-, sub); @@ -1366,7 +1378,7 @@ FL_API Tensor power(const double& lhs, const Tensor& rhs); * Transformations to apply to Tensors (i.e. matrices) before applying certain * operations (i.e. matmul). */ -enum class MatrixProperty { None = 0, Transpose = 1 }; +enum class MatrixProperty {None = 0, Transpose = 1}; /** * Perform matrix multiplication between two tensors. 
@@ -1384,7 +1396,8 @@ FL_API Tensor matmul( const Tensor& lhs, const Tensor& rhs, MatrixProperty lhsProp = MatrixProperty::None, - MatrixProperty rhsProp = MatrixProperty::None); + MatrixProperty rhsProp = MatrixProperty::None +); /************************** Reductions ***************************/ @@ -1402,7 +1415,8 @@ FL_API Tensor matmul( FL_API Tensor amin( const Tensor& input, const std::vector& axes = {}, - const bool keepDims = false); + const bool keepDims = false +); /** * Compute the maximum value along multiple axes. If axes is left empty, @@ -1418,7 +1432,8 @@ FL_API Tensor amin( FL_API Tensor amax( const Tensor& input, const std::vector& axes = {}, - const bool keepDims = false); + const bool keepDims = false +); /** * Compute the maximum value along multiple axes for a tensor, returning both @@ -1438,7 +1453,8 @@ FL_API void min( Tensor& indices, const Tensor& input, const unsigned axis, - const bool keepDims = false); + const bool keepDims = false +); /** * Compute the maximum value along multiple axes for a tensor, returning both @@ -1458,7 +1474,8 @@ FL_API void max( Tensor& indices, const Tensor& input, const unsigned axis, - const bool keepDims = false); + const bool keepDims = false +); /** * Return the indices of the maximum values along an axis. @@ -1469,8 +1486,7 @@ FL_API void max( * as singleton dimensions rather than collapsing them * @return a tensor containing the indices of the max values along each axis */ -FL_API Tensor -argmax(const Tensor& input, const unsigned axis, const bool keepDims = false); +FL_API Tensor argmax(const Tensor& input, const unsigned axis, const bool keepDims = false); /** * Return the indices of the minimum values along an axis. 
@@ -1481,8 +1497,7 @@ argmax(const Tensor& input, const unsigned axis, const bool keepDims = false); * as singleton dimensions rather than collapsing them * @return a tensor containing the indices of the max values along each axis */ -FL_API Tensor -argmin(const Tensor& input, const unsigned axis, const bool keepDims = false); +FL_API Tensor argmin(const Tensor& input, const unsigned axis, const bool keepDims = false); /** * Sum of tensor over given axes. If axes is left empty, computes the sum along @@ -1495,10 +1510,11 @@ argmin(const Tensor& input, const unsigned axis, const bool keepDims = false); * as singleton dimensions rather than collapsing them * @return a tensor containing the sum(s) */ -FL_API Tensor -sum(const Tensor& input, +FL_API Tensor sum( + const Tensor& input, const std::vector& axes = {}, - const bool keepDims = false); + const bool keepDims = false +); /** * Compute the cumulative sum (or the prefix sum, scan, or inclusive scan) of a @@ -1524,7 +1540,8 @@ FL_API Tensor cumsum(const Tensor& input, const unsigned axis); FL_API Tensor mean( const Tensor& input, const std::vector& axes = {}, - const bool keepDims = false); + const bool keepDims = false +); /** * Median of tensor over given axes. If axes is left empty, computes the median @@ -1540,7 +1557,8 @@ FL_API Tensor mean( FL_API Tensor median( const Tensor& input, const std::vector& axes = {}, - const bool keepDims = false); + const bool keepDims = false +); /** * Variance of an tensor over given axes. If axes is left empty, computes the @@ -1554,11 +1572,12 @@ FL_API Tensor median( * as singleton dimensions rather than collapsing them * @return a tensor containing the variance(s) */ -FL_API Tensor -var(const Tensor& input, +FL_API Tensor var( + const Tensor& input, const std::vector& axes = {}, const bool bias = false, - const bool keepDims = false); + const bool keepDims = false +); /** * Standard deviation of an tensor over given axes. 
If axes is left empty, @@ -1571,10 +1590,11 @@ var(const Tensor& input, * as singleton dimensions rather than collapsing them * @return a tensor containing the standard deviation(s) */ -FL_API Tensor -std(const Tensor& input, +FL_API Tensor std( + const Tensor& input, const std::vector& axes = {}, - const bool keepDims = false); + const bool keepDims = false +); /** * Perform Lp-norm computation, reduced over specified dimensions. If axes is @@ -1591,7 +1611,8 @@ FL_API Tensor norm( const Tensor& input, const std::vector& axes = {}, double p = 2, - const bool keepDims = false); + const bool keepDims = false +); /** * Counts the number of nonzero elements in a tensor. @@ -1609,7 +1630,8 @@ FL_API Tensor norm( FL_API Tensor countNonzero( const Tensor& input, const std::vector& axes = {}, - const bool keepDims = false); + const bool keepDims = false +); /** * Checks for any true values in a tensor along one or more axes; returns true @@ -1625,10 +1647,11 @@ FL_API Tensor countNonzero( * @return a bool tensor containing axis-wise values denoting truthy values * along that axis in the input tensor. */ -FL_API Tensor -any(const Tensor& input, +FL_API Tensor any( + const Tensor& input, const std::vector& axes = {}, - const bool keepDims = false); + const bool keepDims = false +); /** * Checks if all values are true in a tensor along one or more axes; returns @@ -1644,10 +1667,11 @@ any(const Tensor& input, * @return a bool tensor containing axis-wise values with true along * axes that contain only true values. 
*/ -FL_API Tensor -all(const Tensor& input, +FL_API Tensor all( + const Tensor& input, const std::vector& axes = {}, - const bool keepDims = false); + const bool keepDims = false +); /************************** Utilities ***************************/ @@ -1678,7 +1702,8 @@ FL_API void print(const Tensor& tensor); FL_API bool allClose( const fl::Tensor& a, const fl::Tensor& b, - const double absTolerance = 1e-5); + const double absTolerance = 1e-5 +); /** * @return if a Tensor contains any NaN or Inf values. @@ -1710,48 +1735,57 @@ FL_API std::ostream& operator<<(std::ostream& os, const TensorBackendType type); * @param[in] t the tensor to convert * @returns a tensor backed by the specified compile time type */ -template +template Tensor to(Tensor&& t) { - // Fast path -- types are the same - if (T::tensorBackendType == t.backendType()) { - return std::move(t); - } - - if (t.isSparse()) { - throw std::invalid_argument( - "Tensor type conversion between sparse " - "tensors not yet supported."); - } else { - // TODO: dynamically fix the memory location based on the type of - // backend/where base memory is - return Tensor(std::make_unique( - t.shape(), t.type(), t.device(), MemoryLocation::Device)); - } + // Fast path -- types are the same + if(T::tensorBackendType == t.backendType()) { + return std::move(t); + } + + if(t.isSparse()) { + throw std::invalid_argument( + "Tensor type conversion between sparse " + "tensors not yet supported." + ); + } else { + // TODO: dynamically fix the memory location based on the type of + // backend/where base memory is + return Tensor( + std::make_unique( + t.shape(), + t.type(), + t.device(), + MemoryLocation::Device + ) + ); + } } /** @} */ namespace detail { -bool areTensorTypesEqual(const Tensor& a, const Tensor& b); + bool areTensorTypesEqual(const Tensor& a, const Tensor& b); -template -bool areTensorTypesEqual( - const Tensor& a, - const Tensor& b, - const Args&... 
args) { - return areTensorTypesEqual(a, b) && areTensorTypesEqual(a, args...); -} + template + bool areTensorTypesEqual( + const Tensor& a, + const Tensor& b, + const Args&... args + ) { + return areTensorTypesEqual(a, b) && areTensorTypesEqual(a, args...); + } } // namespace detail /** * Checks if a variadic number of Tensors have the same type. */ -#define FL_TENSOR_DTYPES_MATCH_CHECK(...) \ - if (!detail::areTensorTypesEqual(__VA_ARGS__)) { \ - throw std::invalid_argument( \ - std::string(__func__) + ": tensors are not all of the same types. "); \ - } +#define FL_TENSOR_DTYPES_MATCH_CHECK(...) \ + if(!detail::areTensorTypesEqual(__VA_ARGS__)) { \ + throw std::invalid_argument( \ + std::string(__func__) + ": tensors are not all of the same types. " \ + ); \ + } } // namespace fl diff --git a/flashlight/fl/tensor/TensorExtension.cpp b/flashlight/fl/tensor/TensorExtension.cpp index b268d01..feca51c 100644 --- a/flashlight/fl/tensor/TensorExtension.cpp +++ b/flashlight/fl/tensor/TensorExtension.cpp @@ -14,49 +14,54 @@ namespace fl::detail { bool TensorExtensionRegistrar::registerTensorExtension( TensorBackendType backend, TensorExtensionType extensionType, - TensorExtensionCallback&& creationFunc) { - auto& _extensions = (*extensions_ - .try_emplace( - backend, - std::unordered_map< - TensorExtensionType, - TensorExtensionCallback>()) - .first) - .second; - - // Add extension to registry - _extensions.try_emplace(extensionType, std::move(creationFunc)); - return true; + TensorExtensionCallback&& creationFunc +) { + auto& _extensions = (*extensions_ + .try_emplace( + backend, + std::unordered_map< + TensorExtensionType, + TensorExtensionCallback>() + ) + .first) + .second; + + // Add extension to registry + _extensions.try_emplace(extensionType, std::move(creationFunc)); + return true; } bool TensorExtensionRegistrar::isTensorExtensionRegistered( - TensorBackendType backend, - TensorExtensionType extensionType) { - return extensions_.count(backend) && - 
extensions_[backend].count(extensionType); + TensorBackendType backend, + TensorExtensionType extensionType +) { + return extensions_.count(backend) + && extensions_[backend].count(extensionType); } -TensorExtensionCallback& -TensorExtensionRegistrar::getTensorExtensionCreationFunc( +TensorExtensionCallback& TensorExtensionRegistrar::getTensorExtensionCreationFunc( TensorBackendType backend, - TensorExtensionType extensionType) { - if (extensions_.find(backend) == extensions_.end()) { - throw std::invalid_argument( - "TensorExtensionRegistrar::getTensorExtensionCreationFunc: " - "no tensor extensions registered for given backend."); - } - auto& _extensions = extensions_[backend]; - if (_extensions.find(extensionType) == _extensions.end()) { - throw std::invalid_argument( - "TensorExtensionRegistrar::getTensorExtensionCreationFunc: " - "no tensor extensions registered for backend " + tensorBackendTypeToString(backend)); - } - return _extensions[extensionType]; + TensorExtensionType extensionType +) { + if(extensions_.find(backend) == extensions_.end()) { + throw std::invalid_argument( + "TensorExtensionRegistrar::getTensorExtensionCreationFunc: " + "no tensor extensions registered for given backend." 
+ ); + } + auto& _extensions = extensions_[backend]; + if(_extensions.find(extensionType) == _extensions.end()) { + throw std::invalid_argument( + "TensorExtensionRegistrar::getTensorExtensionCreationFunc: " + "no tensor extensions registered for backend " + tensorBackendTypeToString(backend) + ); + } + return _extensions[extensionType]; } TensorExtensionRegistrar& TensorExtensionRegistrar::getInstance() { - static TensorExtensionRegistrar instance; - return instance; + static TensorExtensionRegistrar instance; + return instance; } } // namespace fl diff --git a/flashlight/fl/tensor/TensorExtension.h b/flashlight/fl/tensor/TensorExtension.h index a33b9be..adf8c13 100644 --- a/flashlight/fl/tensor/TensorExtension.h +++ b/flashlight/fl/tensor/TensorExtension.h @@ -20,24 +20,24 @@ namespace fl { * A runtime type denoting the tensor extension. */ enum class TensorExtensionType { - Generic, // placeholder - Autograd, - Vision, - JitOptimizer, + Generic, // placeholder + Autograd, + Vision, + JitOptimizer, }; // Common base type class TensorExtensionBase { - public: - virtual ~TensorExtensionBase() = default; +public: + virtual ~TensorExtensionBase() = default; - virtual bool isDataTypeSupported(const fl::dtype& dtype) const = 0; + virtual bool isDataTypeSupported(const fl::dtype& dtype) const = 0; }; namespace detail { -using TensorExtensionCallback = - std::function()>; + using TensorExtensionCallback = + std::function()>; /** * Employ an extensible factory singleton pattern to handle creation callbacks @@ -45,45 +45,51 @@ using TensorExtensionCallback = * * Users should not directly use this singleton and should instead */ -class TensorExtensionRegistrar { - // Intentionally private. Only one instance should exist/it should be accessed - // via getInstance(). 
- TensorExtensionRegistrar() = default; - - // TODO(jacobkahn): change this to an array and have indices for extension - // types correspond to extension instances - std::unordered_map< - TensorBackendType, - std::unordered_map> - extensions_; - - public: - bool registerTensorExtension( - TensorBackendType backend, - TensorExtensionType extensionType, - TensorExtensionCallback&& creationFunc); - - static TensorExtensionRegistrar& getInstance(); - ~TensorExtensionRegistrar() = default; - - template - bool registerTensorExtension(TensorBackendType backend) { - // TODO: use a static T::create instead of a lambda if we can enforce its - // declaration and definition on interface functions - return this->registerTensorExtension( - backend, T::getExtensionType(), []() -> std::unique_ptr { - return std::make_unique(); - }); - } - - bool isTensorExtensionRegistered( - TensorBackendType backend, - TensorExtensionType extensionType); - - TensorExtensionCallback& getTensorExtensionCreationFunc( - TensorBackendType backend, - TensorExtensionType extensionType); -}; + class TensorExtensionRegistrar { + // Intentionally private. Only one instance should exist/it should be accessed + // via getInstance(). 
+ TensorExtensionRegistrar() = default; + + // TODO(jacobkahn): change this to an array and have indices for extension + // types correspond to extension instances + std::unordered_map< + TensorBackendType, + std::unordered_map> + extensions_; + + public: + bool registerTensorExtension( + TensorBackendType backend, + TensorExtensionType extensionType, + TensorExtensionCallback&& creationFunc + ); + + static TensorExtensionRegistrar& getInstance(); + ~TensorExtensionRegistrar() = default; + + template + bool registerTensorExtension(TensorBackendType backend) { + // TODO: use a static T::create instead of a lambda if we can enforce its + // declaration and definition on interface functions + return this->registerTensorExtension( + backend, + T::getExtensionType(), + []() -> std::unique_ptr { + return std::make_unique(); + } + ); + } + + bool isTensorExtensionRegistered( + TensorBackendType backend, + TensorExtensionType extensionType + ); + + TensorExtensionCallback& getTensorExtensionCreationFunc( + TensorBackendType backend, + TensorExtensionType extensionType + ); + }; } // namespace detail @@ -94,25 +100,25 @@ class TensorExtensionRegistrar { * @param[in] backendType the type of the backend to register the extension to. * See TensorBackendType. 
*/ -template +template bool registerTensorExtension(TensorBackendType backendType) { - return detail::TensorExtensionRegistrar::getInstance() - .registerTensorExtension(backendType); + return detail::TensorExtensionRegistrar::getInstance() + .registerTensorExtension(backendType); } -template +template class TensorExtension : public TensorExtensionBase { - public: - static TensorExtensionType getExtensionType() { - return T::extensionType; - } +public: + static TensorExtensionType getExtensionType() { + return T::extensionType; + } }; -template +template struct TensorExtensionRegisterer { - TensorExtensionRegisterer(TensorBackendType t) { - ::fl::registerTensorExtension(t); - } + TensorExtensionRegisterer(TensorBackendType t) { + ::fl::registerTensorExtension(t); + } }; /** @@ -123,6 +129,6 @@ struct TensorExtensionRegisterer { * See TensorBackendType. */ #define FL_REGISTER_TENSOR_EXTENSION(T, BACKEND_TYPE) \ - TensorExtensionRegisterer T##BACKEND_TYPE(TensorBackendType::BACKEND_TYPE) + TensorExtensionRegisterer T ## BACKEND_TYPE(TensorBackendType::BACKEND_TYPE) } // namespace fl diff --git a/flashlight/fl/tensor/Types.cpp b/flashlight/fl/tensor/Types.cpp index a8e5ff3..9667c10 100644 --- a/flashlight/fl/tensor/Types.cpp +++ b/flashlight/fl/tensor/Types.cpp @@ -41,48 +41,48 @@ const std::unordered_map kStringToType = { }; size_t getTypeSize(dtype type) { - switch (type) { - case dtype::f16: - return sizeof(float) / 2; - case dtype::f32: - return sizeof(float); - case dtype::f64: - return sizeof(double); - case dtype::b8: - return sizeof(unsigned char); - case dtype::s16: - return sizeof(short); - case dtype::s64: - return sizeof(long long); - case dtype::s32: - return sizeof(int); - case dtype::u8: - return sizeof(unsigned char); - case dtype::u16: - return sizeof(unsigned short); - case dtype::u32: - return sizeof(unsigned); - case dtype::u64: - return sizeof(unsigned long long); - default: - throw std::invalid_argument("getTypeSize - invalid type queried."); - } 
+ switch(type) { + case dtype::f16: + return sizeof(float) / 2; + case dtype::f32: + return sizeof(float); + case dtype::f64: + return sizeof(double); + case dtype::b8: + return sizeof(unsigned char); + case dtype::s16: + return sizeof(short); + case dtype::s64: + return sizeof(long long); + case dtype::s32: + return sizeof(int); + case dtype::u8: + return sizeof(unsigned char); + case dtype::u16: + return sizeof(unsigned short); + case dtype::u32: + return sizeof(unsigned); + case dtype::u64: + return sizeof(unsigned long long); + default: + throw std::invalid_argument("getTypeSize - invalid type queried."); + } } const std::string& dtypeToString(dtype type) { - return kTypeToString.at(type); + return kTypeToString.at(type); } fl::dtype stringToDtype(const std::string& string) { - if (kStringToType.find(string) != kStringToType.end()) { - return kStringToType.at(string); - } - throw std::invalid_argument("stringToDtype: Invalid input type: " + string); + if(kStringToType.find(string) != kStringToType.end()) { + return kStringToType.at(string); + } + throw std::invalid_argument("stringToDtype: Invalid input type: " + string); } std::ostream& operator<<(std::ostream& ostr, const dtype& s) { - ostr << dtypeToString(s); - return ostr; + ostr << dtypeToString(s); + return ostr; } } // namespace fl diff --git a/flashlight/fl/tensor/Types.h b/flashlight/fl/tensor/Types.h index 148f1a4..16f201a 100644 --- a/flashlight/fl/tensor/Types.h +++ b/flashlight/fl/tensor/Types.h @@ -15,18 +15,18 @@ namespace fl { enum class dtype { - f16 = 0, // 16-bit float - f32 = 1, // 32-bit float - f64 = 2, // 64-bit float - b8 = 3, // 8-bit boolean - s16 = 4, // 16-bit signed integer - s32 = 5, // 32-bit signed integer - s64 = 6, // 64-bit signed integer - u8 = 7, // 8-bit unsigned integer - u16 = 8, // 16-bit unsigned integer - u32 = 9, // 32-bit unsigned integer - u64 = 10 // 64-bit unsigned integer - // TODO: add support for complex-valued tensors? 
(AF) + f16 = 0, // 16-bit float + f32 = 1, // 32-bit float + f64 = 2, // 64-bit float + b8 = 3, // 8-bit boolean + s16 = 4, // 16-bit signed integer + s32 = 5, // 32-bit signed integer + s64 = 6, // 64-bit signed integer + u8 = 7, // 8-bit unsigned integer + u16 = 8, // 16-bit unsigned integer + u32 = 9, // 32-bit unsigned integer + u64 = 10 // 64-bit unsigned integer + // TODO: add support for complex-valued tensors? (AF) }; /** @@ -55,19 +55,19 @@ FL_API fl::dtype stringToDtype(const std::string& string); */ FL_API std::ostream& operator<<(std::ostream& ostr, const dtype& s); -template +template struct dtype_traits; -#define FL_TYPE_TRAIT(BASE_TYPE, DTYPE, CONSTANT_TYPE, STRING_NAME) \ - template <> \ - struct FL_API dtype_traits { \ - static const dtype fl_type = DTYPE; /* corresponding dtype */ \ - static const dtype ctype = CONSTANT_TYPE; /* constant init type */ \ - typedef BASE_TYPE base_type; \ - static const char* getName() { \ - return STRING_NAME; \ - } \ - } +#define FL_TYPE_TRAIT(BASE_TYPE, DTYPE, CONSTANT_TYPE, STRING_NAME) \ + template<> \ + struct FL_API dtype_traits { \ + static const dtype fl_type = DTYPE; /* corresponding dtype */ \ + static const dtype ctype = CONSTANT_TYPE; /* constant init type */ \ + typedef BASE_TYPE base_type; \ + static const char* getName() { \ + return STRING_NAME; \ + } \ + } FL_TYPE_TRAIT(float, dtype::f32, dtype::f32, "float"); FL_TYPE_TRAIT(double, dtype::f64, dtype::f32, "double"); diff --git a/flashlight/fl/tensor/backend/af/AdvancedIndex.cpp b/flashlight/fl/tensor/backend/af/AdvancedIndex.cpp index ceafce9..b84ee1a 100644 --- a/flashlight/fl/tensor/backend/af/AdvancedIndex.cpp +++ b/flashlight/fl/tensor/backend/af/AdvancedIndex.cpp @@ -13,15 +13,16 @@ namespace fl { namespace detail { -void advancedIndex( - const af::array& inp, - const af::dim4& idxStart, - const af::dim4& idxEnd, - const af::dim4& outDims, - const std::vector& idxArr, - af::array& out) { - throw std::runtime_error("gradAdvancedIndex not 
implemented for cpu"); -} + void advancedIndex( + const af::array& inp, + const af::dim4& idxStart, + const af::dim4& idxEnd, + const af::dim4& outDims, + const std::vector& idxArr, + af::array& out + ) { + throw std::runtime_error("gradAdvancedIndex not implemented for cpu"); + } } // namespace detail } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/AdvancedIndex.h b/flashlight/fl/tensor/backend/af/AdvancedIndex.h index cbc56af..5155df1 100644 --- a/flashlight/fl/tensor/backend/af/AdvancedIndex.h +++ b/flashlight/fl/tensor/backend/af/AdvancedIndex.h @@ -30,13 +30,14 @@ namespace detail { * @param out The output Varible which is the gradient of input of index * operator */ -void advancedIndex( - const af::array& inp, - const af::dim4& idxStart, - const af::dim4& idxEnd, - const af::dim4& outDims, - const std::vector& idxArr, - af::array& out); + void advancedIndex( + const af::array& inp, + const af::dim4& idxStart, + const af::dim4& idxEnd, + const af::dim4& outDims, + const std::vector& idxArr, + af::array& out + ); } // namespace detail } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/ArrayFireBLAS.cpp b/flashlight/fl/tensor/backend/af/ArrayFireBLAS.cpp index d8f8a2a..bdbf9ad 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireBLAS.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireBLAS.cpp @@ -18,43 +18,45 @@ Tensor ArrayFireBackend::matmul( const Tensor& lhs, const Tensor& rhs, MatrixProperty lhsProp, - MatrixProperty rhsProp) { - unsigned numDims = std::max(lhs.ndim(), rhs.ndim()); - if ((lhs.ndim() == 1 || rhs.ndim() == 1) && numDims > 1) { - numDims -= 1; - } - - af::array lhsArray = toArray(lhs); - af::array rhsArray = toArray(rhs); - - if (lhs.ndim() == 1 && rhs.ndim() == 1) { - // Simulate a dot product by transpoing the lhs: - // (1, k) x (k, 1) --> (1, 1) --> reshape to (1) - // Ignore other transposes since 1D tensors are the transpose of themselves. 
- // ArrayFire would otherwise transpose a (k) tensor to (1, k) since (k) = - // (k, 1, 1, 1) and ArrayFire transpose transposes the first two dimensions. - lhsProp = MatrixProperty::Transpose; - rhsProp = MatrixProperty::None; - numDims = 1; - } else { - if (rhs.ndim() == 1) { - rhsArray = af::moddims(toArray(rhs), {rhs.dim(0), 1}); + MatrixProperty rhsProp +) { + unsigned numDims = std::max(lhs.ndim(), rhs.ndim()); + if((lhs.ndim() == 1 || rhs.ndim() == 1) && numDims > 1) { + numDims -= 1; } - if (lhs.ndim() == 1) { - lhsArray = af::moddims(toArray(lhs), {1, lhs.dim(0)}); + + af::array lhsArray = toArray(lhs); + af::array rhsArray = toArray(rhs); + + if(lhs.ndim() == 1 && rhs.ndim() == 1) { + // Simulate a dot product by transpoing the lhs: + // (1, k) x (k, 1) --> (1, 1) --> reshape to (1) + // Ignore other transposes since 1D tensors are the transpose of themselves. + // ArrayFire would otherwise transpose a (k) tensor to (1, k) since (k) = + // (k, 1, 1, 1) and ArrayFire transpose transposes the first two dimensions. 
+ lhsProp = MatrixProperty::Transpose; + rhsProp = MatrixProperty::None; + numDims = 1; + } else { + if(rhs.ndim() == 1) { + rhsArray = af::moddims(toArray(rhs), {rhs.dim(0), 1}); + } + if(lhs.ndim() == 1) { + lhsArray = af::moddims(toArray(lhs), {1, lhs.dim(0)}); + } } - } - auto arr = af::matmul( - lhsArray, - rhsArray, - detail::flToAfMatrixProperty(lhsProp), - detail::flToAfMatrixProperty(rhsProp)); + auto arr = af::matmul( + lhsArray, + rhsArray, + detail::flToAfMatrixProperty(lhsProp), + detail::flToAfMatrixProperty(rhsProp) + ); - if (lhs.ndim() == 1 && rhs.ndim() == 2) { - arr = af::moddims(arr, arr.dims(1)); - } + if(lhs.ndim() == 1 && rhs.ndim() == 2) { + arr = af::moddims(arr, arr.dims(1)); + } - return toTensor(std::move(arr), numDims); + return toTensor(std::move(arr), numDims); } } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp b/flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp index d63bcee..9aceaa1 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp @@ -19,17 +19,17 @@ #include "flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.h" #if FL_ARRAYFIRE_USE_CPU - #include "flashlight/fl/tensor/backend/af/ArrayFireCPUStream.h" +#include "flashlight/fl/tensor/backend/af/ArrayFireCPUStream.h" #endif #if FL_ARRAYFIRE_USE_CUDA - #include - #include - #include +#include +#include +#include - #include +#include - #include "flashlight/fl/runtime/CUDAStream.h" +#include "flashlight/fl/runtime/CUDAStream.h" #endif namespace fl { @@ -38,199 +38,223 @@ namespace { // Get the stream associated with given device in the given map; if it's not in // the map, initialize it (by wrapping or creating) and put it into the map. 
-const Stream& getOrWrapAfDeviceStream( - const int afId, - const int nativeId, - std::unordered_map>& afIdToStream) { - auto iter = afIdToStream.find(afId); - if (iter != afIdToStream.end()) { - return *iter->second; - } + const Stream& getOrWrapAfDeviceStream( + const int afId, + const int nativeId, + std::unordered_map>& afIdToStream + ) { + auto iter = afIdToStream.find(afId); + if(iter != afIdToStream.end()) { + return *iter->second; + } #if FL_ARRAYFIRE_USE_CPU - auto resIter = afIdToStream.emplace(afId, ArrayFireCPUStream::create()); - return *resIter.first->second; + auto resIter = afIdToStream.emplace(afId, ArrayFireCPUStream::create()); + return *resIter.first->second; #elif FL_ARRAYFIRE_USE_CUDA - const cudaStream_t cudaNativeStream = afcu::getStream(afId); - auto resIter = afIdToStream.emplace( - afId, CUDAStream::wrapUnmanaged(nativeId, cudaNativeStream)); - return *resIter.first->second; + const cudaStream_t cudaNativeStream = afcu::getStream(afId); + auto resIter = afIdToStream.emplace( + afId, + CUDAStream::wrapUnmanaged(nativeId, cudaNativeStream) + ); + return *resIter.first->second; #else - throw std::runtime_error( - "ArrayFireBackend was not compiled with support for CPU or GPU"); + throw std::runtime_error( + "ArrayFireBackend was not compiled with support for CPU or GPU" + ); #endif -} + } } // namespace ArrayFireBackend::ArrayFireBackend() { - AF_CHECK(af_init()); - - std::call_once(memoryInitFlag, []() { - // TODO: remove this temporary workaround for TextDatasetTest crash on CPU - // backend when tearing down the test environment. This is possibly due to - // AF race conditions when tearing down our custom memory manager. - // TODO: remove this temporary workaround for crashes when using custom - // opencl kernels. - if (FL_BACKEND_CUDA) { - MemoryManagerInstaller::installDefaultMemoryManager(); - } - }); - - for (int id = 0; id < af::getDeviceCount(); id++) { - int nativeId = id; // TODO investigate how OpenCL fits into this. 
+ AF_CHECK(af_init()); + + std::call_once( + memoryInitFlag, + []() { + // TODO: remove this temporary workaround for TextDatasetTest crash on CPU + // backend when tearing down the test environment. This is possibly due to + // AF race conditions when tearing down our custom memory manager. + // TODO: remove this temporary workaround for crashes when using custom + // opencl kernels. + if(FL_BACKEND_CUDA) { + MemoryManagerInstaller::installDefaultMemoryManager(); + } + } + ); + + for(int id = 0; id < af::getDeviceCount(); id++) { + int nativeId = id; // TODO investigate how OpenCL fits into this. #if FL_ARRAYFIRE_USE_CUDA - nativeId = afcu::getNativeId(id); + nativeId = afcu::getNativeId(id); #endif - // TODO make these maps `const` - nativeIdToId_[nativeId] = id; - idToNativeId_[id] = nativeId; - } - - const auto& manager = DeviceManager::getInstance(); - // This callback ensures consistency of AF internal state on active device. - // Capturing by value to avoid destructor race hazard for static objects. - const auto setActiveCallback = [nativeIdToId = nativeIdToId_, - afIdToStream = afIdToStream_](int nativeId) { - auto afId = nativeIdToId.at(nativeId); - af::setDevice(afId); - // this is the latest point we can lazily wrap the AF stream, which may get - // lazily intialized anytime in AF internally, e.g., via tensor computation. - getOrWrapAfDeviceStream(afId, nativeId, *afIdToStream); - }; + // TODO make these maps `const` + nativeIdToId_[nativeId] = id; + idToNativeId_[id] = nativeId; + } + + const auto& manager = DeviceManager::getInstance(); + // This callback ensures consistency of AF internal state on active device. + // Capturing by value to avoid destructor race hazard for static objects. 
+ const auto setActiveCallback = [nativeIdToId = nativeIdToId_, + afIdToStream = afIdToStream_](int nativeId) { + auto afId = nativeIdToId.at(nativeId); + af::setDevice(afId); + // this is the latest point we can lazily wrap the AF stream, which may get + // lazily intialized anytime in AF internally, e.g., via tensor computation. + getOrWrapAfDeviceStream(afId, nativeId, *afIdToStream); + }; #if FL_ARRAYFIRE_USE_CPU - auto& device = manager.getActiveDevice(DeviceType::x64); - device.addSetActiveCallback(setActiveCallback); -#elif FL_ARRAYFIRE_USE_CUDA - const auto deviceCount = manager.getDeviceCount(DeviceType::CUDA); - for (unsigned nativeId = 0; nativeId < deviceCount; nativeId++) { - auto& device = manager.getDevice(DeviceType::CUDA, nativeId); + auto& device = manager.getActiveDevice(DeviceType::x64); device.addSetActiveCallback(setActiveCallback); - } +#elif FL_ARRAYFIRE_USE_CUDA + const auto deviceCount = manager.getDeviceCount(DeviceType::CUDA); + for(unsigned nativeId = 0; nativeId < deviceCount; nativeId++) { + auto& device = manager.getDevice(DeviceType::CUDA, nativeId); + device.addSetActiveCallback(setActiveCallback); + } #endif - // Active device is never set explicitly, so we must wrap its stream eagerly. - auto activeAfId = af::getDevice(); - getOrWrapAfDeviceStream( - activeAfId, idToNativeId_.at(activeAfId), *afIdToStream_); + // Active device is never set explicitly, so we must wrap its stream eagerly. 
+ auto activeAfId = af::getDevice(); + getOrWrapAfDeviceStream( + activeAfId, + idToNativeId_.at(activeAfId), + *afIdToStream_ + ); } ArrayFireBackend& ArrayFireBackend::getInstance() { - static ArrayFireBackend instance; - return instance; + static ArrayFireBackend instance; + return instance; } TensorBackendType ArrayFireBackend::backendType() const { - return TensorBackendType::ArrayFire; + return TensorBackendType::ArrayFire; } /* -------------------------- Compute Functions -------------------------- */ void ArrayFireBackend::eval(const Tensor& tensor) { - af::eval(toArray(tensor)); + af::eval(toArray(tensor)); } const Stream& ArrayFireBackend::getStreamOfArray( - const af::array& arr) { - // TODO once we enforce integrate Device::setDevice into fl::setDevice, each - // array's stream should always be wrapped already (via setDevice callback). - // auto iter = afIdToStream_->find(af::getDeviceId(arr)); - // assert(iter != afIdToStream_->end() && "Stream should have been wrapped"); - // return *iter->second; - auto afId = af::getDeviceId(arr); - auto nativeId = idToNativeId_.at(afId); - return getOrWrapAfDeviceStream(afId, nativeId, *afIdToStream_); + const af::array& arr +) { + // TODO once we enforce integrate Device::setDevice into fl::setDevice, each + // array's stream should always be wrapped already (via setDevice callback). 
+ // auto iter = afIdToStream_->find(af::getDeviceId(arr)); + // assert(iter != afIdToStream_->end() && "Stream should have been wrapped"); + // return *iter->second; + auto afId = af::getDeviceId(arr); + auto nativeId = idToNativeId_.at(afId); + return getOrWrapAfDeviceStream(afId, nativeId, *afIdToStream_); } bool ArrayFireBackend::supportsDataType(const fl::dtype& dtype) const { - switch (dtype) { - case fl::dtype::f16: - return af::isHalfAvailable(af::getDevice()) && - // f16 isn't [yet] supported with the CPU backend per onednn - // limitations - !FL_BACKEND_CPU; - default: - return true; - } + switch(dtype) { + case fl::dtype::f16: + return af::isHalfAvailable(af::getDevice()) + && // f16 isn't [yet] supported with the CPU backend per onednn + // limitations + !FL_BACKEND_CPU; + default: + return true; + } } void ArrayFireBackend::getMemMgrInfo( const char* msg, const int nativeDeviceId, - std::ostream* ostream) { - int deviceId = nativeIdToId_.at(nativeDeviceId); - if (ostream == nullptr) { - throw std::invalid_argument( - "ArrayFireBackend::getMemMgrInfo - got null ostream pointer"); - } - auto* curMemMgr = - fl::MemoryManagerInstaller::currentlyInstalledMemoryManager(); - if (curMemMgr) { - curMemMgr->printInfo(msg, deviceId, ostream); - } + std::ostream* ostream +) { + int deviceId = nativeIdToId_.at(nativeDeviceId); + if(ostream == nullptr) { + throw std::invalid_argument( + "ArrayFireBackend::getMemMgrInfo - got null ostream pointer" + ); + } + auto* curMemMgr = + fl::MemoryManagerInstaller::currentlyInstalledMemoryManager(); + if(curMemMgr) { + curMemMgr->printInfo(msg, deviceId, ostream); + } } void ArrayFireBackend::setMemMgrLogStream(std::ostream* stream) { - if (stream == nullptr) { - throw std::invalid_argument( - "ArrayFireBackend::getMemMgrInfo - got null ostream pointer"); - } - auto* curMemMgr = - fl::MemoryManagerInstaller::currentlyInstalledMemoryManager(); - if (curMemMgr) { - curMemMgr->setLogStream(stream); - } + if(stream == nullptr) { + 
throw std::invalid_argument( + "ArrayFireBackend::getMemMgrInfo - got null ostream pointer" + ); + } + auto* curMemMgr = + fl::MemoryManagerInstaller::currentlyInstalledMemoryManager(); + if(curMemMgr) { + curMemMgr->setLogStream(stream); + } } void ArrayFireBackend::setMemMgrLoggingEnabled(const bool enabled) { - auto* curMemMgr = - fl::MemoryManagerInstaller::currentlyInstalledMemoryManager(); - if (curMemMgr) { - curMemMgr->setLoggingEnabled(enabled); - } + auto* curMemMgr = + fl::MemoryManagerInstaller::currentlyInstalledMemoryManager(); + if(curMemMgr) { + curMemMgr->setLoggingEnabled(enabled); + } } void ArrayFireBackend::setMemMgrFlushInterval(const size_t interval) { - auto* curMemMgr = - fl::MemoryManagerInstaller::currentlyInstalledMemoryManager(); - if (curMemMgr) { - curMemMgr->setLogFlushInterval(interval); - } + auto* curMemMgr = + fl::MemoryManagerInstaller::currentlyInstalledMemoryManager(); + if(curMemMgr) { + curMemMgr->setLogFlushInterval(interval); + } } /* -------------------------- Rand Functions -------------------------- */ void ArrayFireBackend::setSeed(const int seed) { - af::setSeed(seed); + af::setSeed(seed); } Tensor ArrayFireBackend::randn(const Shape& shape, dtype type) { - return toTensor( - af::randn(detail::flToAfDims(shape), detail::flToAfType(type)), - shape.ndim()); + return toTensor( + af::randn(detail::flToAfDims(shape), detail::flToAfType(type)), + shape.ndim() + ); } Tensor ArrayFireBackend::rand(const Shape& shape, dtype type) { - return toTensor( - af::randu(detail::flToAfDims(shape), detail::flToAfType(type)), - shape.ndim()); + return toTensor( + af::randu(detail::flToAfDims(shape), detail::flToAfType(type)), + shape.ndim() + ); } /* --------------------------- Tensor Operators --------------------------- */ /******************** Tensor Creation Functions ********************/ -#define AF_BACKEND_CREATE_FUN_LITERAL_DEF(TYPE) \ - Tensor ArrayFireBackend::fromScalar(TYPE value, const dtype type) { \ - return toTensor( \ - 
af::constant(value, af::dim4(1), detail::flToAfType(type)), \ - /* ndim = */ 0); \ - } \ - Tensor ArrayFireBackend::full( \ - const Shape& shape, TYPE value, const dtype type) { \ - return toTensor( \ - af::constant( \ - value, detail::flToAfDims(shape), detail::flToAfType(type)), \ - shape.ndim()); \ - } +#define AF_BACKEND_CREATE_FUN_LITERAL_DEF(TYPE) \ + Tensor ArrayFireBackend::fromScalar(TYPE value, const dtype type) { \ + return toTensor( \ + af::constant(value, af::dim4(1), detail::flToAfType(type)), \ + /* ndim = */ 0 \ + ); \ + } \ + Tensor ArrayFireBackend::full( \ + const Shape& shape, \ + TYPE value, \ + const dtype type \ + ) { \ + return toTensor( \ + af::constant( \ + value, \ + detail::flToAfDims(shape), \ + detail::flToAfType(type) \ + ), \ + shape.ndim() \ + ); \ + } AF_BACKEND_CREATE_FUN_LITERAL_DEF(const double&); AF_BACKEND_CREATE_FUN_LITERAL_DEF(const float&); AF_BACKEND_CREATE_FUN_LITERAL_DEF(const int&); @@ -246,38 +270,46 @@ AF_BACKEND_CREATE_FUN_LITERAL_DEF(const short&); AF_BACKEND_CREATE_FUN_LITERAL_DEF(const unsigned short&); Tensor ArrayFireBackend::identity(const Dim dim, const dtype type) { - return toTensor( - af::identity({dim, dim}, detail::flToAfType(type)), /* numDims = */ 2); + return toTensor( + af::identity({dim, dim}, detail::flToAfType(type)), /* numDims = */ + 2 + ); } Tensor ArrayFireBackend::arange( const Shape& shape, const Dim seqDim, - const dtype type) { - return toTensor( - af::range(detail::flToAfDims(shape), seqDim, detail::flToAfType(type)), - shape.ndim()); + const dtype type +) { + return toTensor( + af::range(detail::flToAfDims(shape), seqDim, detail::flToAfType(type)), + shape.ndim() + ); } Tensor ArrayFireBackend::iota( const Shape& dims, const Shape& tileDims, - const dtype type) { - return toTensor( - af::iota( - detail::flToAfDims(dims), - detail::flToAfDims(tileDims), - detail::flToAfType(type)), - /* numDims = */ std::max(dims.ndim(), tileDims.ndim())); + const dtype type +) { + return toTensor( + 
af::iota( + detail::flToAfDims(dims), + detail::flToAfDims(tileDims), + detail::flToAfType(type) + ), + /* numDims = */ std::max(dims.ndim(), tileDims.ndim()) + ); } Tensor ArrayFireBackend::where( const Tensor& condition, const Tensor& x, - const Tensor& y) { - Tensor orig = x; - af::replace(toArray(orig), toArray(condition), toArray(y)); - return orig; + const Tensor& y +) { + Tensor orig = x; + af::replace(toArray(orig), toArray(condition), toArray(y)); + return orig; } void ArrayFireBackend::topk( @@ -286,38 +318,48 @@ void ArrayFireBackend::topk( const Tensor& input, const unsigned k, const Dim axis, - const SortMode sortMode) { - if (axis != 0) { - throw std::invalid_argument( - "ArrayFireTensor topk: operation only supported along zero axis."); - } - af::array valuesArr, indicesArr; - af::topk( - valuesArr, - indicesArr, - toArray(input), - k, - axis, - detail::flToAfTopKSortMode(sortMode)); - - values = toTensor(std::move(valuesArr), input.ndim()); - indices = toTensor(std::move(indicesArr), input.ndim()); + const SortMode sortMode +) { + if(axis != 0) { + throw std::invalid_argument( + "ArrayFireTensor topk: operation only supported along zero axis." 
+ ); + } + af::array valuesArr, indicesArr; + af::topk( + valuesArr, + indicesArr, + toArray(input), + k, + axis, + detail::flToAfTopKSortMode(sortMode) + ); + + values = toTensor(std::move(valuesArr), input.ndim()); + indices = toTensor(std::move(indicesArr), input.ndim()); } Tensor ArrayFireBackend::sort( const Tensor& input, const Dim axis, - const SortMode sortMode) { - if (sortMode != SortMode::Descending && sortMode != SortMode::Ascending) { - throw std::invalid_argument( - "Cannot sort ArrayFire tensor with given SortMode: " - "only Descending and Ascending supported."); - } - - af::array values, indices; - af::sort( - values, indices, toArray(input), axis, sortMode == SortMode::Ascending); - return toTensor(std::move(values), input.ndim()); + const SortMode sortMode +) { + if(sortMode != SortMode::Descending && sortMode != SortMode::Ascending) { + throw std::invalid_argument( + "Cannot sort ArrayFire tensor with given SortMode: " + "only Descending and Ascending supported." + ); + } + + af::array values, indices; + af::sort( + values, + indices, + toArray(input), + axis, + sortMode == SortMode::Ascending + ); + return toTensor(std::move(values), input.ndim()); } void ArrayFireBackend::sort( @@ -325,37 +367,51 @@ void ArrayFireBackend::sort( Tensor& indices, const Tensor& input, const Dim axis, - const SortMode sortMode) { - if (sortMode != SortMode::Descending && sortMode != SortMode::Ascending) { - throw std::invalid_argument( - "Cannot sort ArrayFire tensor with given SortMode: " - "only Descending and Ascending supported."); - } - - af::array _values, _indices; - af::sort( - _values, _indices, toArray(input), axis, sortMode == SortMode::Ascending); - values = toTensor(std::move(_values), input.ndim()); - indices = toTensor(std::move(_indices), input.ndim()); + const SortMode sortMode +) { + if(sortMode != SortMode::Descending && sortMode != SortMode::Ascending) { + throw std::invalid_argument( + "Cannot sort ArrayFire tensor with given SortMode: " + 
"only Descending and Ascending supported." + ); + } + + af::array _values, _indices; + af::sort( + _values, + _indices, + toArray(input), + axis, + sortMode == SortMode::Ascending + ); + values = toTensor(std::move(_values), input.ndim()); + indices = toTensor(std::move(_indices), input.ndim()); } Tensor ArrayFireBackend::argsort( const Tensor& input, const Dim axis, - const SortMode sortMode) { - if (sortMode != SortMode::Descending && sortMode != SortMode::Ascending) { - throw std::invalid_argument( - "Cannot sort ArrayFire tensor with given SortMode: " - "only Descending and Ascending supported."); - } - - af::array values, indices; - af::sort( - values, indices, toArray(input), axis, sortMode == SortMode::Ascending); - return toTensor(std::move(indices), input.ndim()); + const SortMode sortMode +) { + if(sortMode != SortMode::Descending && sortMode != SortMode::Ascending) { + throw std::invalid_argument( + "Cannot sort ArrayFire tensor with given SortMode: " + "only Descending and Ascending supported." + ); + } + + af::array values, indices; + af::sort( + values, + indices, + toArray(input), + axis, + sortMode == SortMode::Ascending + ); + return toTensor(std::move(indices), input.ndim()); } void ArrayFireBackend::print(const Tensor& tensor) { - af::print("ArrayFireTensor", toArray(tensor)); + af::print("ArrayFireTensor", toArray(tensor)); } } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/ArrayFireBackend.h b/flashlight/fl/tensor/backend/af/ArrayFireBackend.h index 4cdf4ca..19af6f9 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireBackend.h +++ b/flashlight/fl/tensor/backend/af/ArrayFireBackend.h @@ -25,252 +25,262 @@ namespace fl { * ArrayFire counterparts. */ class ArrayFireBackend : public TensorBackend { - // TODO: consolidate the ArrayFire memory manager here so its global state can - // be stored/we can reduce the number of singletons. 
- std::once_flag memoryInitFlag; + // TODO: consolidate the ArrayFire memory manager here so its global state can + // be stored/we can reduce the number of singletons. + std::once_flag memoryInitFlag; - // These help ensure we are using native device id in public methods. - std::unordered_map nativeIdToId_; - std::unordered_map idToNativeId_; + // These help ensure we are using native device id in public methods. + std::unordered_map nativeIdToId_; + std::unordered_map idToNativeId_; - // keep track of the individual active stream on each ArrayFire device - // NOTE using a `shared_ptr` to allow its capture in setActive callback; - // see constructor for details. - std::shared_ptr>> - afIdToStream_{std::make_shared< - std::unordered_map>>()}; + // keep track of the individual active stream on each ArrayFire device + // NOTE using a `shared_ptr` to allow its capture in setActive callback; + // see constructor for details. + std::shared_ptr>> + afIdToStream_{std::make_shared< + std::unordered_map>>()}; - // Intentionally private. Only one instance should exist/it should be accessed - // via getInstance(). - ArrayFireBackend(); + // Intentionally private. Only one instance should exist/it should be accessed + // via getInstance(). 
+ ArrayFireBackend(); - public: - static ArrayFireBackend& getInstance(); - ~ArrayFireBackend() override = default; - TensorBackendType backendType() const override; +public: + static ArrayFireBackend& getInstance(); + ~ArrayFireBackend() override = default; + TensorBackendType backendType() const override; - // No copy or move construction or assignment - ArrayFireBackend(ArrayFireBackend&&) = delete; - ArrayFireBackend(const ArrayFireBackend&) = delete; - ArrayFireBackend& operator=(ArrayFireBackend&&) = delete; - ArrayFireBackend& operator=(const ArrayFireBackend&) = delete; + // No copy or move construction or assignment + ArrayFireBackend(ArrayFireBackend&&) = delete; + ArrayFireBackend(const ArrayFireBackend&) = delete; + ArrayFireBackend& operator=(ArrayFireBackend&&) = delete; + ArrayFireBackend& operator=(const ArrayFireBackend&) = delete; - /* -------------------------- Compute Functions -------------------------- */ - void eval(const Tensor& tensor) override; + /* -------------------------- Compute Functions -------------------------- */ + void eval(const Tensor& tensor) override; - /** - * Return the stream from which the given array was created. - * - * @return an immutable reference to the stream from which `arr` was created. - */ - const Stream& getStreamOfArray(const af::array& arr); - bool supportsDataType(const fl::dtype& dtype) const override; - // Memory management - void getMemMgrInfo(const char* msg, const int nativeDeviceId, std::ostream* ostream) - override; - void setMemMgrLogStream(std::ostream* stream) override; - void setMemMgrLoggingEnabled(const bool enabled) override; - void setMemMgrFlushInterval(const size_t interval) override; + /** + * Return the stream from which the given array was created. + * + * @return an immutable reference to the stream from which `arr` was created. 
+ */ + const Stream& getStreamOfArray(const af::array& arr); + bool supportsDataType(const fl::dtype& dtype) const override; + // Memory management + void getMemMgrInfo(const char* msg, const int nativeDeviceId, std::ostream* ostream) + override; + void setMemMgrLogStream(std::ostream* stream) override; + void setMemMgrLoggingEnabled(const bool enabled) override; + void setMemMgrFlushInterval(const size_t interval) override; - /* -------------------------- Rand Functions -------------------------- */ - void setSeed(const int seed) override; - Tensor randn(const Shape& shape, dtype type) override; - Tensor rand(const Shape& shape, dtype type) override; + /* -------------------------- Rand Functions -------------------------- */ + void setSeed(const int seed) override; + Tensor randn(const Shape& shape, dtype type) override; + Tensor rand(const Shape& shape, dtype type) override; - /* --------------------------- Tensor Operators --------------------------- */ - /******************** Tensor Creation Functions ********************/ -#define AF_BACKEND_CREATE_FUN_LITERAL_DECL(TYPE) \ - Tensor fromScalar(TYPE value, const dtype type) override; \ - Tensor full(const Shape& dims, TYPE value, const dtype type) override; - AF_BACKEND_CREATE_FUN_LITERAL_DECL(const double&); - AF_BACKEND_CREATE_FUN_LITERAL_DECL(const float&); - AF_BACKEND_CREATE_FUN_LITERAL_DECL(const int&); - AF_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned&); - AF_BACKEND_CREATE_FUN_LITERAL_DECL(const char&); - AF_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned char&); - AF_BACKEND_CREATE_FUN_LITERAL_DECL(const long&); - AF_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned long&); - AF_BACKEND_CREATE_FUN_LITERAL_DECL(const long long&); - AF_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned long long&); - AF_BACKEND_CREATE_FUN_LITERAL_DECL(const bool&); - AF_BACKEND_CREATE_FUN_LITERAL_DECL(const short&); - AF_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned short&); + /* --------------------------- Tensor Operators 
--------------------------- */ + /******************** Tensor Creation Functions ********************/ +#define AF_BACKEND_CREATE_FUN_LITERAL_DECL(TYPE) \ + Tensor fromScalar(TYPE value, const dtype type) override; \ + Tensor full(const Shape& dims, TYPE value, const dtype type) override; + AF_BACKEND_CREATE_FUN_LITERAL_DECL(const double&); + AF_BACKEND_CREATE_FUN_LITERAL_DECL(const float&); + AF_BACKEND_CREATE_FUN_LITERAL_DECL(const int&); + AF_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned&); + AF_BACKEND_CREATE_FUN_LITERAL_DECL(const char&); + AF_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned char&); + AF_BACKEND_CREATE_FUN_LITERAL_DECL(const long&); + AF_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned long&); + AF_BACKEND_CREATE_FUN_LITERAL_DECL(const long long&); + AF_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned long long&); + AF_BACKEND_CREATE_FUN_LITERAL_DECL(const bool&); + AF_BACKEND_CREATE_FUN_LITERAL_DECL(const short&); + AF_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned short&); #undef AF_BACKEND_CREATE_FUN_LITERAL_DECL - Tensor identity(const Dim dim, const dtype type) override; - Tensor arange(const Shape& shape, const Dim seqDim, const dtype type) - override; - Tensor iota(const Shape& dims, const Shape& tileDims, const dtype type) - override; + Tensor identity(const Dim dim, const dtype type) override; + Tensor arange(const Shape& shape, const Dim seqDim, const dtype type) + override; + Tensor iota(const Shape& dims, const Shape& tileDims, const dtype type) + override; - /************************ Shaping and Indexing *************************/ - Tensor reshape(const Tensor& tensor, const Shape& shape) override; - Tensor transpose(const Tensor& tensor, const Shape& axes /* = {} */) override; - Tensor tile(const Tensor& tensor, const Shape& shape) override; - Tensor concatenate(const std::vector& tensors, const unsigned axis) - override; - Tensor nonzero(const Tensor& tensor) override; - Tensor pad( - const Tensor& input, - const std::vector>& 
padWidths, - const PadType type) override; + /************************ Shaping and Indexing *************************/ + Tensor reshape(const Tensor& tensor, const Shape& shape) override; + Tensor transpose(const Tensor& tensor, const Shape& axes /* = {} */) override; + Tensor tile(const Tensor& tensor, const Shape& shape) override; + Tensor concatenate(const std::vector& tensors, const unsigned axis) + override; + Tensor nonzero(const Tensor& tensor) override; + Tensor pad( + const Tensor& input, + const std::vector>& padWidths, + const PadType type + ) override; - /************************** Unary Operators ***************************/ - Tensor exp(const Tensor& tensor) override; - Tensor log(const Tensor& tensor) override; - Tensor negative(const Tensor& tensor) override; - Tensor logicalNot(const Tensor& tensor) override; - Tensor log1p(const Tensor& tensor) override; - Tensor sin(const Tensor& tensor) override; - Tensor cos(const Tensor& tensor) override; - Tensor sqrt(const Tensor& tensor) override; - Tensor tanh(const Tensor& tensor) override; - Tensor floor(const Tensor& tensor) override; - Tensor ceil(const Tensor& tensor) override; - Tensor rint(const Tensor& tensor) override; - Tensor absolute(const Tensor& tensor) override; - Tensor sigmoid(const Tensor& tensor) override; - Tensor erf(const Tensor& tensor) override; - Tensor flip(const Tensor& tensor, const unsigned dim) override; - Tensor clip(const Tensor& tensor, const Tensor& low, const Tensor& high) - override; - Tensor roll(const Tensor& tensor, const int shift, const unsigned axis) - override; - Tensor isnan(const Tensor& tensor) override; - Tensor isinf(const Tensor& tensor) override; - Tensor sign(const Tensor& tensor) override; - Tensor tril(const Tensor& tensor) override; - Tensor triu(const Tensor& tensor) override; - Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y) - override; - void topk( - Tensor& values, - Tensor& indices, - const Tensor& input, - const unsigned k, 
- const Dim axis, - const SortMode sortMode) override; - Tensor sort(const Tensor& input, const Dim axis, const SortMode sortMode) - override; - void sort( - Tensor& values, - Tensor& indices, - const Tensor& input, - const Dim axis, - const SortMode sortMode) override; - Tensor argsort(const Tensor& input, const Dim axis, const SortMode sortMode) - override; + /************************** Unary Operators ***************************/ + Tensor exp(const Tensor& tensor) override; + Tensor log(const Tensor& tensor) override; + Tensor negative(const Tensor& tensor) override; + Tensor logicalNot(const Tensor& tensor) override; + Tensor log1p(const Tensor& tensor) override; + Tensor sin(const Tensor& tensor) override; + Tensor cos(const Tensor& tensor) override; + Tensor sqrt(const Tensor& tensor) override; + Tensor tanh(const Tensor& tensor) override; + Tensor floor(const Tensor& tensor) override; + Tensor ceil(const Tensor& tensor) override; + Tensor rint(const Tensor& tensor) override; + Tensor absolute(const Tensor& tensor) override; + Tensor sigmoid(const Tensor& tensor) override; + Tensor erf(const Tensor& tensor) override; + Tensor flip(const Tensor& tensor, const unsigned dim) override; + Tensor clip(const Tensor& tensor, const Tensor& low, const Tensor& high) + override; + Tensor roll(const Tensor& tensor, const int shift, const unsigned axis) + override; + Tensor isnan(const Tensor& tensor) override; + Tensor isinf(const Tensor& tensor) override; + Tensor sign(const Tensor& tensor) override; + Tensor tril(const Tensor& tensor) override; + Tensor triu(const Tensor& tensor) override; + Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y) + override; + void topk( + Tensor& values, + Tensor& indices, + const Tensor& input, + const unsigned k, + const Dim axis, + const SortMode sortMode + ) override; + Tensor sort(const Tensor& input, const Dim axis, const SortMode sortMode) + override; + void sort( + Tensor& values, + Tensor& indices, + const 
Tensor& input, + const Dim axis, + const SortMode sortMode + ) override; + Tensor argsort(const Tensor& input, const Dim axis, const SortMode sortMode) + override; - /************************** Binary Operators ***************************/ -#define FL_AF_BINARY_OP_TYPE_DECL(FUNC, TYPE) \ - Tensor FUNC(const Tensor& a, TYPE rhs) override; \ - Tensor FUNC(TYPE lhs, const Tensor& a) override; + /************************** Binary Operators ***************************/ +#define FL_AF_BINARY_OP_TYPE_DECL(FUNC, TYPE) \ + Tensor FUNC(const Tensor& a, TYPE rhs) override; \ + Tensor FUNC(TYPE lhs, const Tensor& a) override; -#define FL_AF_BINARY_OP_LITERALS_DECL(FUNC) \ - FL_AF_BINARY_OP_TYPE_DECL(FUNC, const bool&); \ - FL_AF_BINARY_OP_TYPE_DECL(FUNC, const int&); \ - FL_AF_BINARY_OP_TYPE_DECL(FUNC, const unsigned&); \ - FL_AF_BINARY_OP_TYPE_DECL(FUNC, const char&); \ - FL_AF_BINARY_OP_TYPE_DECL(FUNC, const unsigned char&); \ - FL_AF_BINARY_OP_TYPE_DECL(FUNC, const long&); \ - FL_AF_BINARY_OP_TYPE_DECL(FUNC, const unsigned long&); \ - FL_AF_BINARY_OP_TYPE_DECL(FUNC, const long long&); \ - FL_AF_BINARY_OP_TYPE_DECL(FUNC, const unsigned long long&); \ - FL_AF_BINARY_OP_TYPE_DECL(FUNC, const double&); \ - FL_AF_BINARY_OP_TYPE_DECL(FUNC, const float&); \ - FL_AF_BINARY_OP_TYPE_DECL(FUNC, const short&); \ - FL_AF_BINARY_OP_TYPE_DECL(FUNC, const unsigned short&); +#define FL_AF_BINARY_OP_LITERALS_DECL(FUNC) \ + FL_AF_BINARY_OP_TYPE_DECL(FUNC, const bool&); \ + FL_AF_BINARY_OP_TYPE_DECL(FUNC, const int&); \ + FL_AF_BINARY_OP_TYPE_DECL(FUNC, const unsigned&); \ + FL_AF_BINARY_OP_TYPE_DECL(FUNC, const char&); \ + FL_AF_BINARY_OP_TYPE_DECL(FUNC, const unsigned char&); \ + FL_AF_BINARY_OP_TYPE_DECL(FUNC, const long&); \ + FL_AF_BINARY_OP_TYPE_DECL(FUNC, const unsigned long&); \ + FL_AF_BINARY_OP_TYPE_DECL(FUNC, const long long&); \ + FL_AF_BINARY_OP_TYPE_DECL(FUNC, const unsigned long long&); \ + FL_AF_BINARY_OP_TYPE_DECL(FUNC, const double&); \ + FL_AF_BINARY_OP_TYPE_DECL(FUNC, const 
float&); \ + FL_AF_BINARY_OP_TYPE_DECL(FUNC, const short&); \ + FL_AF_BINARY_OP_TYPE_DECL(FUNC, const unsigned short&); -#define FL_AF_BINARY_OP_DECL(FUNC) \ - Tensor FUNC(const Tensor& lhs, const Tensor& rhs) override; \ - FL_AF_BINARY_OP_LITERALS_DECL(FUNC); +#define FL_AF_BINARY_OP_DECL(FUNC) \ + Tensor FUNC(const Tensor& lhs, const Tensor& rhs) override; \ + FL_AF_BINARY_OP_LITERALS_DECL(FUNC); - FL_AF_BINARY_OP_DECL(add); - FL_AF_BINARY_OP_DECL(sub); - FL_AF_BINARY_OP_DECL(mul); - FL_AF_BINARY_OP_DECL(div); - FL_AF_BINARY_OP_DECL(eq); - FL_AF_BINARY_OP_DECL(neq); - FL_AF_BINARY_OP_DECL(lessThan); - FL_AF_BINARY_OP_DECL(lessThanEqual); - FL_AF_BINARY_OP_DECL(greaterThan); - FL_AF_BINARY_OP_DECL(greaterThanEqual); - FL_AF_BINARY_OP_DECL(logicalOr); - FL_AF_BINARY_OP_DECL(logicalAnd); - FL_AF_BINARY_OP_DECL(mod); - FL_AF_BINARY_OP_DECL(bitwiseAnd); - FL_AF_BINARY_OP_DECL(bitwiseOr); - FL_AF_BINARY_OP_DECL(bitwiseXor); - FL_AF_BINARY_OP_DECL(lShift); - FL_AF_BINARY_OP_DECL(rShift); + FL_AF_BINARY_OP_DECL(add); + FL_AF_BINARY_OP_DECL(sub); + FL_AF_BINARY_OP_DECL(mul); + FL_AF_BINARY_OP_DECL(div); + FL_AF_BINARY_OP_DECL(eq); + FL_AF_BINARY_OP_DECL(neq); + FL_AF_BINARY_OP_DECL(lessThan); + FL_AF_BINARY_OP_DECL(lessThanEqual); + FL_AF_BINARY_OP_DECL(greaterThan); + FL_AF_BINARY_OP_DECL(greaterThanEqual); + FL_AF_BINARY_OP_DECL(logicalOr); + FL_AF_BINARY_OP_DECL(logicalAnd); + FL_AF_BINARY_OP_DECL(mod); + FL_AF_BINARY_OP_DECL(bitwiseAnd); + FL_AF_BINARY_OP_DECL(bitwiseOr); + FL_AF_BINARY_OP_DECL(bitwiseXor); + FL_AF_BINARY_OP_DECL(lShift); + FL_AF_BINARY_OP_DECL(rShift); #undef FL_AF_BINARY_OP_DECL #undef FL_AF_BINARY_OP_TYPE_DECL #undef FL_AF_BINARY_OP_LITERALS_DECL - Tensor minimum(const Tensor& lhs, const Tensor& rhs) override; - Tensor maximum(const Tensor& lhs, const Tensor& rhs) override; - Tensor power(const Tensor& lhs, const Tensor& rhs) override; + Tensor minimum(const Tensor& lhs, const Tensor& rhs) override; + Tensor maximum(const Tensor& lhs, const Tensor& 
rhs) override; + Tensor power(const Tensor& lhs, const Tensor& rhs) override; - /******************************* BLAS ********************************/ - Tensor matmul( - const Tensor& lhs, - const Tensor& rhs, - MatrixProperty lhsProp, - MatrixProperty rhsProp) override; + /******************************* BLAS ********************************/ + Tensor matmul( + const Tensor& lhs, + const Tensor& rhs, + MatrixProperty lhsProp, + MatrixProperty rhsProp + ) override; - /************************** Reductions ***************************/ - Tensor amin(const Tensor& input, const std::vector& axes, const bool keepDims) - override; - Tensor amax(const Tensor& input, const std::vector& axes, const bool keepDims) - override; - void min( - Tensor& values, - Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims) override; - void max( - Tensor& values, - Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims) override; - Tensor sum(const Tensor& input, const std::vector& axes, const bool keepDims) - override; - Tensor cumsum(const Tensor& input, const unsigned axis) override; - Tensor argmax(const Tensor& input, const unsigned axis, const bool keepDims) - override; - Tensor argmin(const Tensor& input, const unsigned axis, const bool keepDims) - override; - Tensor mean(const Tensor& input, const std::vector& axes, const bool keepDims) - override; - Tensor median( - const Tensor& input, - const std::vector& axes, - const bool keepDims) override; - Tensor var( - const Tensor& input, - const std::vector& axes, - const bool bias, - const bool keepDims) override; - Tensor std(const Tensor& input, const std::vector& axes, const bool keepDims) - override; - Tensor norm( - const Tensor& input, - const std::vector& axes, - double p, - const bool keepDims) override; - Tensor countNonzero( - const Tensor& input, - const std::vector& axes, - const bool keepDims) override; - Tensor any(const Tensor& input, const std::vector& 
axes, const bool keepDims) - override; - Tensor all(const Tensor& input, const std::vector& axes, const bool keepDims) - override; + /************************** Reductions ***************************/ + Tensor amin(const Tensor& input, const std::vector& axes, const bool keepDims) + override; + Tensor amax(const Tensor& input, const std::vector& axes, const bool keepDims) + override; + void min( + Tensor& values, + Tensor& indices, + const Tensor& input, + const unsigned axis, + const bool keepDims + ) override; + void max( + Tensor& values, + Tensor& indices, + const Tensor& input, + const unsigned axis, + const bool keepDims + ) override; + Tensor sum(const Tensor& input, const std::vector& axes, const bool keepDims) + override; + Tensor cumsum(const Tensor& input, const unsigned axis) override; + Tensor argmax(const Tensor& input, const unsigned axis, const bool keepDims) + override; + Tensor argmin(const Tensor& input, const unsigned axis, const bool keepDims) + override; + Tensor mean(const Tensor& input, const std::vector& axes, const bool keepDims) + override; + Tensor median( + const Tensor& input, + const std::vector& axes, + const bool keepDims + ) override; + Tensor var( + const Tensor& input, + const std::vector& axes, + const bool bias, + const bool keepDims + ) override; + Tensor std(const Tensor& input, const std::vector& axes, const bool keepDims) + override; + Tensor norm( + const Tensor& input, + const std::vector& axes, + double p, + const bool keepDims + ) override; + Tensor countNonzero( + const Tensor& input, + const std::vector& axes, + const bool keepDims + ) override; + Tensor any(const Tensor& input, const std::vector& axes, const bool keepDims) + override; + Tensor all(const Tensor& input, const std::vector& axes, const bool keepDims) + override; - /************************** Utils ***************************/ - void print(const Tensor& tensor) override; + /************************** Utils ***************************/ + void print(const 
Tensor& tensor) override; }; } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/ArrayFireBinaryOps.cpp b/flashlight/fl/tensor/backend/af/ArrayFireBinaryOps.cpp index 43742ba..6223ec6 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireBinaryOps.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireBinaryOps.cpp @@ -15,83 +15,89 @@ namespace fl { namespace { -bool canBroadcast(const Shape& lhs, const Shape& rhs) { - unsigned nDim = std::max(lhs.ndim(), rhs.ndim()); - - for (unsigned i = 0; i < nDim; ++i) { - if (i + 1 > lhs.ndim() || i + 1 > rhs.ndim()) { - // One Shape has more dimensions than the other - will broadcast to the - // smaller tensor - continue; + bool canBroadcast(const Shape& lhs, const Shape& rhs) { + unsigned nDim = std::max(lhs.ndim(), rhs.ndim()); + + for(unsigned i = 0; i < nDim; ++i) { + if(i + 1 > lhs.ndim() || i + 1 > rhs.ndim()) { + // One Shape has more dimensions than the other - will broadcast to the + // smaller tensor + continue; + } + if(lhs[i] != rhs[i] && lhs[i] != 1 && rhs[i] != 1) { + return false; + } + } + return true; } - if (lhs[i] != rhs[i] && lhs[i] != 1 && rhs[i] != 1) { - return false; - } - } - return true; -} // A binary operation on two ArrayFire arrays -using binaryOpFunc_t = - af::array (*)(const af::array& lhs, const af::array& rhs); - -Tensor doBinaryOpOrBroadcast( - const Tensor& lhs, - const Tensor& rhs, - binaryOpFunc_t func) { - // Dims are the same or scalar <> 1-el tensor - no broadcasting - if (lhs.shape() == rhs.shape() || - (lhs.elements() <= 1 && rhs.elements() <= 1)) { - return toTensor( - func(toArray(lhs), toArray(rhs)), lhs.ndim()); - } - - if (canBroadcast(lhs.shape(), rhs.shape())) { - return toTensor( - af::batchFunc(toArray(lhs), toArray(rhs), func), - std::max(lhs.ndim(), rhs.ndim())); - } else { - std::stringstream ss; - ss << "doBinaryOpOrBroadcast: cannot perform operation " - "or broadcasting with tensors of shapes " - << lhs.shape() << " and " << rhs.shape() << " - dimension 
mismatch."; - throw std::invalid_argument(ss.str()); - } -} + using binaryOpFunc_t = + af::array (*)(const af::array& lhs, const af::array& rhs); + + Tensor doBinaryOpOrBroadcast( + const Tensor& lhs, + const Tensor& rhs, + binaryOpFunc_t func + ) { + // Dims are the same or scalar <> 1-el tensor - no broadcasting + if( + lhs.shape() == rhs.shape() + || (lhs.elements() <= 1 && rhs.elements() <= 1) + ) { + return toTensor( + func(toArray(lhs), toArray(rhs)), + lhs.ndim() + ); + } + + if(canBroadcast(lhs.shape(), rhs.shape())) { + return toTensor( + af::batchFunc(toArray(lhs), toArray(rhs), func), + std::max(lhs.ndim(), rhs.ndim()) + ); + } else { + std::stringstream ss; + ss << "doBinaryOpOrBroadcast: cannot perform operation " + "or broadcasting with tensors of shapes " + << lhs.shape() << " and " << rhs.shape() << " - dimension mismatch."; + throw std::invalid_argument(ss.str()); + } + } } // namespace // For ArrayFire, af::array already implements overloads for all needed // operators -- use these by default. 
-#define FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, TYPE) \ - Tensor ArrayFireBackend::FUNC(const Tensor& a, TYPE rhs) { \ - return toTensor(toArray(a) OP rhs, a.ndim()); \ - } \ - Tensor ArrayFireBackend::FUNC(TYPE lhs, const Tensor& a) { \ - return toTensor(lhs OP toArray(a), a.ndim()); \ - } - -#define FL_AF_BINARY_OP_LITERALS_DEF(FUNC, OP) \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const bool&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const int&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const char&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned char&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const long&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned long&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const long long&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned long long&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const double&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const float&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const short&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned short&); +#define FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, TYPE) \ + Tensor ArrayFireBackend::FUNC(const Tensor& a, TYPE rhs) { \ + return toTensor(toArray(a) OP rhs, a.ndim()); \ + } \ + Tensor ArrayFireBackend::FUNC(TYPE lhs, const Tensor& a) { \ + return toTensor(lhs OP toArray(a), a.ndim()); \ + } + +#define FL_AF_BINARY_OP_LITERALS_DEF(FUNC, OP) \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const bool&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const int&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const char&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned char&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const long&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned long&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const long long&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned long long&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const double&); \ + 
FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const float&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const short&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned short&); // Operations on fl::Tensor call the respective operator overloads that are // already defined on af::arrays -#define FL_AF_BINARY_OP_DEF(OP, FUNC) \ - Tensor ArrayFireBackend::FUNC(const Tensor& lhs, const Tensor& rhs) { \ - return doBinaryOpOrBroadcast(lhs, rhs, af::operator OP); \ - } \ - FL_AF_BINARY_OP_LITERALS_DEF(FUNC, OP); +#define FL_AF_BINARY_OP_DEF(OP, FUNC) \ + Tensor ArrayFireBackend::FUNC(const Tensor& lhs, const Tensor& rhs) { \ + return doBinaryOpOrBroadcast(lhs, rhs, af::operator OP); \ + } \ + FL_AF_BINARY_OP_LITERALS_DEF(FUNC, OP); // Definitions // Since ArrayFire implements operator overloads, map both fl::Tensor @@ -120,14 +126,14 @@ FL_AF_BINARY_OP_DEF(>>, rShift); #undef FL_AF_BINARY_OP_LITERALS_DEF Tensor ArrayFireBackend::minimum(const Tensor& lhs, const Tensor& rhs) { - return doBinaryOpOrBroadcast(lhs, rhs, af::min); + return doBinaryOpOrBroadcast(lhs, rhs, af::min); } Tensor ArrayFireBackend::maximum(const Tensor& lhs, const Tensor& rhs) { - return doBinaryOpOrBroadcast(lhs, rhs, af::max); + return doBinaryOpOrBroadcast(lhs, rhs, af::max); } Tensor ArrayFireBackend::power(const Tensor& lhs, const Tensor& rhs) { - return doBinaryOpOrBroadcast(lhs, rhs, af::pow); + return doBinaryOpOrBroadcast(lhs, rhs, af::pow); } } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/ArrayFireCPUStream.cpp b/flashlight/fl/tensor/backend/af/ArrayFireCPUStream.cpp index 50e0b79..c1098a1 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireCPUStream.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireCPUStream.cpp @@ -12,17 +12,17 @@ namespace fl { std::shared_ptr ArrayFireCPUStream::create() { - // TODO `std::make_shared` requires a public constructor, which could be - // abused and lead to unregistered stream. 
However, it has one internal - // allocation and is more cache-friendly than `std::shared_ptr`. - const auto rawStreamPtr = new ArrayFireCPUStream(); - const auto stream = std::shared_ptr(rawStreamPtr); - rawStreamPtr->device_.addStream(stream); - return stream; + // TODO `std::make_shared` requires a public constructor, which could be + // abused and lead to unregistered stream. However, it has one internal + // allocation and is more cache-friendly than `std::shared_ptr`. + const auto rawStreamPtr = new ArrayFireCPUStream(); + const auto stream = std::shared_ptr(rawStreamPtr); + rawStreamPtr->device_.addStream(stream); + return stream; } void ArrayFireCPUStream::sync() const { - af::sync(af::getDevice()); + af::sync(af::getDevice()); } } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/ArrayFireCPUStream.h b/flashlight/fl/tensor/backend/af/ArrayFireCPUStream.h index 27826c0..4ac150c 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireCPUStream.h +++ b/flashlight/fl/tensor/backend/af/ArrayFireCPUStream.h @@ -15,16 +15,16 @@ namespace fl { * An abstraction for ArrayFire's CPU Stream with controlled creation methods. */ class ArrayFireCPUStream : public SynchronousStream { - public: - /** - * Creates an ArrayFireCPUStream and automatically register it with - * the active x64 device from DeviceManager. - * - * @return a shared pointer to the created ArrayFireCPUStream. - */ - static std::shared_ptr create(); +public: + /** + * Creates an ArrayFireCPUStream and automatically register it with + * the active x64 device from DeviceManager. + * + * @return a shared pointer to the created ArrayFireCPUStream. 
+ */ + static std::shared_ptr create(); - void sync() const override; + void sync() const override; }; } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/ArrayFireReductions.cpp b/flashlight/fl/tensor/backend/af/ArrayFireReductions.cpp index da5f2e1..593289f 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireReductions.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireReductions.cpp @@ -18,89 +18,97 @@ namespace fl { namespace { -using reduceFunc_t = af::array (*)(const af::array&, const int); + using reduceFunc_t = af::array (*)(const af::array&, const int); -template -af::array afReduceAxes( - const af::array& input, - const std::vector& axes, - T func, - const bool keepDims = false) { - auto arr = input; - for (int dim : axes) { - arr = func(arr, dim); - } - return fl::detail::condenseIndices(arr, keepDims); -} + template + af::array afReduceAxes( + const af::array& input, + const std::vector& axes, + T func, + const bool keepDims = false + ) { + auto arr = input; + for(int dim : axes) { + arr = func(arr, dim); + } + return fl::detail::condenseIndices(arr, keepDims); + } -unsigned -getReducedNumDims(unsigned inSize, unsigned axisSize, const bool keepDims) { - if (keepDims) { - return inSize; - } else { - if (inSize < axisSize) { - return 0; - } else { - return inSize - axisSize; + unsigned getReducedNumDims(unsigned inSize, unsigned axisSize, const bool keepDims) { + if(keepDims) { + return inSize; + } else { + if(inSize < axisSize) { + return 0; + } else { + return inSize - axisSize; + } + } } - } -} -bool isAllAxisReduction(const Tensor& input, const std::vector& axes) { - if (input.ndim() == 0 || axes.empty()) { - return true; - } - if (input.ndim() != axes.size()) { - return false; - } - // Check that all dims are present - auto _axes = axes; - std::sort(_axes.begin(), _axes.end()); - for (size_t i = 0; i < _axes.size(); ++i) { - if (_axes[i] != i) { - return false; + bool isAllAxisReduction(const Tensor& input, const std::vector& axes) { + 
if(input.ndim() == 0 || axes.empty()) { + return true; + } + if(input.ndim() != axes.size()) { + return false; + } + // Check that all dims are present + auto _axes = axes; + std::sort(_axes.begin(), _axes.end()); + for(size_t i = 0; i < _axes.size(); ++i) { + if(_axes[i] != i) { + return false; + } + } + return true; } - } - return true; -} } // namespace Tensor ArrayFireBackend::amin( const Tensor& input, const std::vector& axes, - const bool keepDims) { - if (isAllAxisReduction(input, axes)) { - // Reduce along all axes returning a singleton tensor - // TODO: modify this to af::min to take advantage of the - // ArrayFire reduce_all kernels once available - return toTensor( - detail::condenseIndices( - af::min(af::min(af::min(af::min(toArray(input)))))), - /* numDims = */ 0); - } else { - return toTensor( - afReduceAxes(toArray(input), axes, af::min, keepDims), - getReducedNumDims(input.ndim(), axes.size(), keepDims)); - } + const bool keepDims +) { + if(isAllAxisReduction(input, axes)) { + // Reduce along all axes returning a singleton tensor + // TODO: modify this to af::min to take advantage of the + // ArrayFire reduce_all kernels once available + return toTensor( + detail::condenseIndices( + af::min(af::min(af::min(af::min(toArray(input))))) + ), + /* numDims = */ 0 + ); + } else { + return toTensor( + afReduceAxes(toArray(input), axes, af::min, keepDims), + getReducedNumDims(input.ndim(), axes.size(), keepDims) + ); + } } Tensor ArrayFireBackend::amax( const Tensor& input, const std::vector& axes, - const bool keepDims) { - if (isAllAxisReduction(input, axes)) { - // Reduce along all axes returning a singleton tensor - // TODO: modify this to af::max to take advantage of the - // ArrayFire reduce_all kernels once available - return toTensor( - detail::condenseIndices( - af::max(af::max(af::max(af::max(toArray(input)))))), - /* numDims = */ 0); - } else { - return toTensor( - afReduceAxes(toArray(input), axes, af::max, keepDims), - 
getReducedNumDims(input.ndim(), axes.size(), keepDims)); - } + const bool keepDims +) { + if(isAllAxisReduction(input, axes)) { + // Reduce along all axes returning a singleton tensor + // TODO: modify this to af::max to take advantage of the + // ArrayFire reduce_all kernels once available + return toTensor( + detail::condenseIndices( + af::max(af::max(af::max(af::max(toArray(input))))) + ), + /* numDims = */ 0 + ); + } else { + return toTensor( + afReduceAxes(toArray(input), axes, af::max, keepDims), + getReducedNumDims(input.ndim(), axes.size(), keepDims) + ); + } } void ArrayFireBackend::min( @@ -108,14 +116,17 @@ void ArrayFireBackend::min( Tensor& indices, const Tensor& input, const unsigned axis, - const bool keepDims) { - af::min(toArray(values), toArray(indices), toArray(input), axis); - values = toTensor( - detail::condenseIndices(toArray(values), keepDims), - getReducedNumDims(input.ndim(), 1, keepDims)); - indices = toTensor( - detail::condenseIndices(toArray(indices), keepDims), - getReducedNumDims(input.ndim(), 1, keepDims)); + const bool keepDims +) { + af::min(toArray(values), toArray(indices), toArray(input), axis); + values = toTensor( + detail::condenseIndices(toArray(values), keepDims), + getReducedNumDims(input.ndim(), 1, keepDims) + ); + indices = toTensor( + detail::condenseIndices(toArray(indices), keepDims), + getReducedNumDims(input.ndim(), 1, keepDims) + ); } void ArrayFireBackend::max( @@ -123,250 +134,303 @@ void ArrayFireBackend::max( Tensor& indices, const Tensor& input, const unsigned axis, - const bool keepDims) { - af::max(toArray(values), toArray(indices), toArray(input), axis); - values = toTensor( - detail::condenseIndices(toArray(values), keepDims), - getReducedNumDims(input.ndim(), 1, keepDims)); - indices = toTensor( - detail::condenseIndices(toArray(indices), keepDims), - getReducedNumDims(input.ndim(), 1, keepDims)); + const bool keepDims +) { + af::max(toArray(values), toArray(indices), toArray(input), axis); + values = 
toTensor( + detail::condenseIndices(toArray(values), keepDims), + getReducedNumDims(input.ndim(), 1, keepDims) + ); + indices = toTensor( + detail::condenseIndices(toArray(indices), keepDims), + getReducedNumDims(input.ndim(), 1, keepDims) + ); } Tensor ArrayFireBackend::sum( const Tensor& input, const std::vector& axes, - const bool keepDims) { - if (isAllAxisReduction(input, axes)) { - // Reduce along all axes returning a singleton tensor - // TODO: modify this to af::sum to take advantage of the - // ArrayFire reduce_all kernels once available - return toTensor( - detail::condenseIndices( - af::sum(af::sum(af::sum(af::sum(toArray(input)))))), - /* numDims = */ 0); - } else { - return toTensor( - afReduceAxes(toArray(input), axes, af::sum, keepDims), - getReducedNumDims(input.ndim(), axes.size(), keepDims)); - } + const bool keepDims +) { + if(isAllAxisReduction(input, axes)) { + // Reduce along all axes returning a singleton tensor + // TODO: modify this to af::sum to take advantage of the + // ArrayFire reduce_all kernels once available + return toTensor( + detail::condenseIndices( + af::sum(af::sum(af::sum(af::sum(toArray(input))))) + ), + /* numDims = */ 0 + ); + } else { + return toTensor( + afReduceAxes(toArray(input), axes, af::sum, keepDims), + getReducedNumDims(input.ndim(), axes.size(), keepDims) + ); + } } Tensor ArrayFireBackend::cumsum(const Tensor& input, const unsigned axis) { - return toTensor( - af::accum(toArray(input), axis), /* numDims = */ input.ndim()); + return toTensor( + af::accum(toArray(input), axis), /* numDims = */ + input.ndim() + ); } Tensor ArrayFireBackend::argmax( const Tensor& input, const unsigned axis, - const bool keepDims) { - af::array tmpVal, indices; - af::max(tmpVal, indices, toArray(input), axis); - return toTensor( - detail::condenseIndices(indices, keepDims), - getReducedNumDims(input.ndim(), 1, keepDims)); + const bool keepDims +) { + af::array tmpVal, indices; + af::max(tmpVal, indices, toArray(input), axis); + 
return toTensor( + detail::condenseIndices(indices, keepDims), + getReducedNumDims(input.ndim(), 1, keepDims) + ); } Tensor ArrayFireBackend::argmin( const Tensor& input, const unsigned axis, - const bool keepDims) { - af::array tmpVal, indices; - af::min(tmpVal, indices, toArray(input), axis); - return toTensor( - detail::condenseIndices(indices, keepDims), - getReducedNumDims(input.ndim(), 1, keepDims)); + const bool keepDims +) { + af::array tmpVal, indices; + af::min(tmpVal, indices, toArray(input), axis); + return toTensor( + detail::condenseIndices(indices, keepDims), + getReducedNumDims(input.ndim(), 1, keepDims) + ); } Tensor ArrayFireBackend::mean( const Tensor& input, const std::vector& axes, - const bool keepDims) { - if (isAllAxisReduction(input, axes)) { - // Reduce along all axes returning a singleton tensor - // TODO: modify this to af::mean to take advantage of the - // ArrayFire reduce_all kernels once available - return toTensor( - detail::condenseIndices( - af::mean(af::mean(af::mean(af::mean(toArray(input)))))), - /* numDims = */ 0); - } else { - return toTensor( - afReduceAxes( - toArray(input), axes, af::mean, keepDims), - getReducedNumDims(input.ndim(), axes.size(), keepDims)); - } + const bool keepDims +) { + if(isAllAxisReduction(input, axes)) { + // Reduce along all axes returning a singleton tensor + // TODO: modify this to af::mean to take advantage of the + // ArrayFire reduce_all kernels once available + return toTensor( + detail::condenseIndices( + af::mean(af::mean(af::mean(af::mean(toArray(input))))) + ), + /* numDims = */ 0 + ); + } else { + return toTensor( + afReduceAxes( + toArray(input), + axes, + af::mean, + keepDims + ), + getReducedNumDims(input.ndim(), axes.size(), keepDims) + ); + } } Tensor ArrayFireBackend::median( const Tensor& input, const std::vector& axes, - const bool keepDims) { - if (isAllAxisReduction(input, axes)) { - // Reduce along all axes returning a singleton tensor - // TODO: modify this to af::median to 
take advantage of the - // ArrayFire reduce_all kernels once available - double median = af::median(toArray(input)); - return toTensor( - af::constant(median, 1), - /* numDims = */ 0); - } else { - return toTensor( - afReduceAxes( - toArray(input), axes, af::median, keepDims), - getReducedNumDims(input.ndim(), axes.size(), keepDims)); - } + const bool keepDims +) { + if(isAllAxisReduction(input, axes)) { + // Reduce along all axes returning a singleton tensor + // TODO: modify this to af::median to take advantage of the + // ArrayFire reduce_all kernels once available + double median = af::median(toArray(input)); + return toTensor( + af::constant(median, 1), + /* numDims = */ 0 + ); + } else { + return toTensor( + afReduceAxes( + toArray(input), + axes, + af::median, + keepDims + ), + getReducedNumDims(input.ndim(), axes.size(), keepDims) + ); + } } Tensor ArrayFireBackend::var( const Tensor& input, const std::vector& axes, const bool bias, - const bool keepDims) { - af_var_bias biasMode = bias ? AF_VARIANCE_SAMPLE : AF_VARIANCE_POPULATION; - // Use ArrayFire default for one dimension which may be optimized - auto& arr = toArray(input); - // Reduce along all axes returning a singleton tensor - // TODO: modify this to af::var to take advantage of the - // ArrayFire reduce_all kernels once available - if (isAllAxisReduction(input, axes)) { - double out = af::var(toArray(input), biasMode); - return toTensor(af::constant(out, 1), /* numDims = */ 0); - } else if (axes.size() == 1) { - return toTensor( - detail::condenseIndices(af::var(arr, biasMode, axes[0]), keepDims), - getReducedNumDims(input.ndim(), axes.size(), keepDims)); - } else { - auto meanArr = mean(input, axes, /* keepDims = */ true); - auto x = af::batchFunc(arr, toArray(meanArr), af::operator-); + const bool keepDims +) { + af_var_bias biasMode = bias ? 
AF_VARIANCE_SAMPLE : AF_VARIANCE_POPULATION; + // Use ArrayFire default for one dimension which may be optimized + auto& arr = toArray(input); + // Reduce along all axes returning a singleton tensor + // TODO: modify this to af::var to take advantage of the + // ArrayFire reduce_all kernels once available + if(isAllAxisReduction(input, axes)) { + double out = af::var(toArray(input), biasMode); + return toTensor(af::constant(out, 1), /* numDims = */ 0); + } else if(axes.size() == 1) { + return toTensor( + detail::condenseIndices(af::var(arr, biasMode, axes[0]), keepDims), + getReducedNumDims(input.ndim(), axes.size(), keepDims) + ); + } else { + auto meanArr = mean(input, axes, /* keepDims = */ true); + auto x = af::batchFunc(arr, toArray(meanArr), af::operator-); - x = af::pow(x, 2); - x = afReduceAxes(x, axes, af::sum, /* keepDims = */ true); + x = af::pow(x, 2); + x = afReduceAxes(x, axes, af::sum, /* keepDims = */ true); - int denominator = 1; - auto dims = arr.dims(); - for (auto dim : axes) { - denominator *= dims[dim]; - } - if (bias) { - denominator--; - } + int denominator = 1; + auto dims = arr.dims(); + for(auto dim : axes) { + denominator *= dims[dim]; + } + if(bias) { + denominator--; + } - x = x / denominator; - return toTensor( - detail::condenseIndices(x, keepDims), - getReducedNumDims(input.ndim(), axes.size(), keepDims)); - } + x = x / denominator; + return toTensor( + detail::condenseIndices(x, keepDims), + getReducedNumDims(input.ndim(), axes.size(), keepDims) + ); + } } Tensor ArrayFireBackend::std( const Tensor& input, const std::vector& axes, - const bool keepDims) { - const bool bias = false; // TODO: make this configurable - af_var_bias biasMode = bias ? 
AF_VARIANCE_SAMPLE : AF_VARIANCE_POPULATION; - if (isAllAxisReduction(input, axes)) { - // TODO: update to af::stdev once specialization is available - double out = af::stdev(toArray(input), biasMode); - return toTensor(af::constant(out, 1), /* numDims = */ 0); - } else if (axes.size() == 1) { - // Use arrayfire default for one dimension which may be optimized - // TODO: update this? stddev is deprecated. - return toTensor( - detail::condenseIndices( - af::stdev(toArray(input), biasMode, axes[0]), keepDims), - getReducedNumDims(input.ndim(), axes.size(), keepDims)); - } - return this->sqrt(this->var(input, axes, /* bias = */ bias, keepDims)); + const bool keepDims +) { + const bool bias = false; // TODO: make this configurable + af_var_bias biasMode = bias ? AF_VARIANCE_SAMPLE : AF_VARIANCE_POPULATION; + if(isAllAxisReduction(input, axes)) { + // TODO: update to af::stdev once specialization is available + double out = af::stdev(toArray(input), biasMode); + return toTensor(af::constant(out, 1), /* numDims = */ 0); + } else if(axes.size() == 1) { + // Use arrayfire default for one dimension which may be optimized + // TODO: update this? stddev is deprecated. + return toTensor( + detail::condenseIndices( + af::stdev(toArray(input), biasMode, axes[0]), + keepDims + ), + getReducedNumDims(input.ndim(), axes.size(), keepDims) + ); + } + return this->sqrt(this->var(input, axes, /* bias = */ bias, keepDims)); } Tensor ArrayFireBackend::norm( const Tensor& input, const std::vector& axes, double p /* = 2 */, - const bool keepDims) { - if (isAllAxisReduction(input, axes)) { - // TODO: update to af::norm if device-side specialization is - // available. 
Either that or use the all-axis specializations with the below - // implementation - auto result = af::pow(af::abs(af::flat(toArray(input))), p); - // Replace with af::sum - result = af::sum(af::sum(af::sum(result))); - result = af::pow(result, 1 / p); - return toTensor( - detail::condenseIndices(result), /* numDims = */ 0); - } else { - auto result = af::pow(af::abs(toArray(input)), p); - result = afReduceAxes(result, axes, af::sum, keepDims); - result = af::pow(result, 1 / p); - return toTensor( - std::move(result), - getReducedNumDims(input.ndim(), axes.size(), keepDims)); - } + const bool keepDims +) { + if(isAllAxisReduction(input, axes)) { + // TODO: update to af::norm if device-side specialization is + // available. Either that or use the all-axis specializations with the below + // implementation + auto result = af::pow(af::abs(af::flat(toArray(input))), p); + // Replace with af::sum + result = af::sum(af::sum(af::sum(result))); + result = af::pow(result, 1 / p); + return toTensor( + detail::condenseIndices(result), /* numDims = */ + 0 + ); + } else { + auto result = af::pow(af::abs(toArray(input)), p); + result = afReduceAxes(result, axes, af::sum, keepDims); + result = af::pow(result, 1 / p); + return toTensor( + std::move(result), + getReducedNumDims(input.ndim(), axes.size(), keepDims) + ); + } } Tensor ArrayFireBackend::countNonzero( const Tensor& input, const std::vector& axes, - const bool keepDims) { - auto& arr = toArray(input); - unsigned numDims; - af::array out; - if (isAllAxisReduction(input, axes)) { - out = detail::condenseIndices( - af::sum(af::sum(af::sum(af::count(arr)))), keepDims); - numDims = 0; - } else if (axes.size() == 1) { - out = af::count(arr, axes.front()); - numDims = getReducedNumDims(input.ndim(), axes.size(), keepDims); - } else { - out = afReduceAxes( - af::count(arr, axes.front()), - std::vector(axes.begin() + 1, axes.end()), - af::sum, - keepDims); - numDims = getReducedNumDims(input.ndim(), axes.size(), keepDims); - } - 
return toTensor( - detail::condenseIndices(out, keepDims), numDims); + const bool keepDims +) { + auto& arr = toArray(input); + unsigned numDims; + af::array out; + if(isAllAxisReduction(input, axes)) { + out = detail::condenseIndices( + af::sum(af::sum(af::sum(af::count(arr)))), + keepDims + ); + numDims = 0; + } else if(axes.size() == 1) { + out = af::count(arr, axes.front()); + numDims = getReducedNumDims(input.ndim(), axes.size(), keepDims); + } else { + out = afReduceAxes( + af::count(arr, axes.front()), + std::vector(axes.begin() + 1, axes.end()), + af::sum, + keepDims + ); + numDims = getReducedNumDims(input.ndim(), axes.size(), keepDims); + } + return toTensor( + detail::condenseIndices(out, keepDims), + numDims + ); } Tensor ArrayFireBackend::any( const Tensor& input, const std::vector& axes, - const bool keepDims) { - if (isAllAxisReduction(input, axes)) { - // Reduce along all axes returning a singleton tensor - // TODO: modify this to af::anyTrue to take advantage of the - // ArrayFire reduce_all kernels once available - return toTensor( - detail::condenseIndices( - af::anyTrue(af::anyTrue(af::anyTrue(af::anyTrue(toArray(input)))))), - /* numDims = */ 0); - } else { - return toTensor( - afReduceAxes(toArray(input), axes, af::anyTrue, keepDims), - getReducedNumDims(input.ndim(), axes.size(), keepDims)); - } + const bool keepDims +) { + if(isAllAxisReduction(input, axes)) { + // Reduce along all axes returning a singleton tensor + // TODO: modify this to af::anyTrue to take advantage of the + // ArrayFire reduce_all kernels once available + return toTensor( + detail::condenseIndices( + af::anyTrue(af::anyTrue(af::anyTrue(af::anyTrue(toArray(input))))) + ), + /* numDims = */ 0 + ); + } else { + return toTensor( + afReduceAxes(toArray(input), axes, af::anyTrue, keepDims), + getReducedNumDims(input.ndim(), axes.size(), keepDims) + ); + } } Tensor ArrayFireBackend::all( const Tensor& input, const std::vector& axes, - const bool keepDims) { - if 
(isAllAxisReduction(input, axes)) { - // Reduce along all axes returning a singleton tensor - // TODO: modify this to af::allTrue to take advantage of the - // ArrayFire reduce_all kernels once available - return toTensor( - detail::condenseIndices( - af::allTrue(af::allTrue(af::allTrue(af::allTrue(toArray(input)))))), - /* numDims = */ 0); - } else { - return toTensor( - afReduceAxes(toArray(input), axes, af::allTrue, keepDims), - getReducedNumDims(input.ndim(), axes.size(), keepDims)); - } + const bool keepDims +) { + if(isAllAxisReduction(input, axes)) { + // Reduce along all axes returning a singleton tensor + // TODO: modify this to af::allTrue to take advantage of the + // ArrayFire reduce_all kernels once available + return toTensor( + detail::condenseIndices( + af::allTrue(af::allTrue(af::allTrue(af::allTrue(toArray(input))))) + ), + /* numDims = */ 0 + ); + } else { + return toTensor( + afReduceAxes(toArray(input), axes, af::allTrue, keepDims), + getReducedNumDims(input.ndim(), axes.size(), keepDims) + ); + } } } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/ArrayFireShapeAndIndex.cpp b/flashlight/fl/tensor/backend/af/ArrayFireShapeAndIndex.cpp index 83402a7..48317c5 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireShapeAndIndex.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireShapeAndIndex.cpp @@ -17,137 +17,162 @@ namespace fl { Tensor ArrayFireBackend::reshape(const Tensor& tensor, const Shape& shape) { - return toTensor( - af::moddims(toArray(tensor), detail::flToAfDims(shape)), shape.ndim()); + return toTensor( + af::moddims(toArray(tensor), detail::flToAfDims(shape)), + shape.ndim() + ); } Tensor ArrayFireBackend::transpose( const Tensor& tensor, - const Shape& axes /* = {} */) { - if (tensor.ndim() == 1) { - return tensor; - } else if ( - tensor.ndim() == 2 && (axes.ndim() == 0 || axes == Shape({1, 0}))) { - // fastpath for matrices - return toTensor( - af::transpose(toArray(tensor)), tensor.ndim()); - } else if (axes.ndim() == 
0) { - std::vector dims(AF_MAX_DIMS); - std::iota(std::begin(dims), std::end(dims), 0); - // Compute the reversed dimensions for as many ndims as are in the input - for (unsigned i = 0; i < tensor.ndim(); ++i) { - dims[i] = tensor.ndim() - 1 - i; - } - - // flip all dimensions - return toTensor( - af::reorder(toArray(tensor), dims[0], dims[1], dims[2], dims[3]), - tensor.ndim()); - } else { - if (axes.ndim() > AF_MAX_DIMS) { - throw std::invalid_argument( - "ArrayFire tensor transpose was given " - "permutation dims with > 4 axes"); - } - if (axes.ndim() != tensor.ndim()) { - throw std::invalid_argument( - "ArrayFire tensor transpose axes don't match tensor's for " - "permutation - axes must have the same number of " - "dimensions as the tensor"); + const Shape& axes /* = {} */ +) { + if(tensor.ndim() == 1) { + return tensor; + } else if( + tensor.ndim() == 2 && (axes.ndim() == 0 || axes == Shape({1, 0}))) { + // fastpath for matrices + return toTensor( + af::transpose(toArray(tensor)), + tensor.ndim() + ); + } else if(axes.ndim() == 0) { + std::vector dims(AF_MAX_DIMS); + std::iota(std::begin(dims), std::end(dims), 0); + // Compute the reversed dimensions for as many ndims as are in the input + for(unsigned i = 0; i < tensor.ndim(); ++i) { + dims[i] = tensor.ndim() - 1 - i; + } + + // flip all dimensions + return toTensor( + af::reorder(toArray(tensor), dims[0], dims[1], dims[2], dims[3]), + tensor.ndim() + ); + } else { + if(axes.ndim() > AF_MAX_DIMS) { + throw std::invalid_argument( + "ArrayFire tensor transpose was given " + "permutation dims with > 4 axes" + ); + } + if(axes.ndim() != tensor.ndim()) { + throw std::invalid_argument( + "ArrayFire tensor transpose axes don't match tensor's for " + "permutation - axes must have the same number of " + "dimensions as the tensor" + ); + } + // reorder based on specified dimensions + std::vector d(AF_MAX_DIMS); + std::iota(std::begin(d), std::end(d), 0); + for(size_t i = 0; i < axes.ndim(); ++i) { + if(axes[i] > 
tensor.ndim() - 1) { + throw std::invalid_argument( + "ArrayFireBackend::transpose - given dimension is larger " + "than the number of dimensions in the tensor" + ); + } + + d[i] = axes[i]; + } + return toTensor( + af::reorder(toArray(tensor), d[0], d[1], d[2], d[3]), + tensor.ndim() + ); } - // reorder based on specified dimensions - std::vector d(AF_MAX_DIMS); - std::iota(std::begin(d), std::end(d), 0); - for (size_t i = 0; i < axes.ndim(); ++i) { - if (axes[i] > tensor.ndim() - 1) { - throw std::invalid_argument( - "ArrayFireBackend::transpose - given dimension is larger " - "than the number of dimensions in the tensor"); - } - - d[i] = axes[i]; - } - return toTensor( - af::reorder(toArray(tensor), d[0], d[1], d[2], d[3]), tensor.ndim()); - } } Tensor ArrayFireBackend::tile(const Tensor& tensor, const Shape& shape) { - return toTensor( - af::tile(toArray(tensor), detail::flToAfDims(shape)), - // TODO: check - std::max(tensor.ndim(), shape.ndim())); + return toTensor( + af::tile(toArray(tensor), detail::flToAfDims(shape)), + // TODO: check + std::max(tensor.ndim(), shape.ndim()) + ); } Tensor ArrayFireBackend::concatenate( const std::vector& tensors, - const unsigned axis) { - af::array out; - switch (tensors.size()) { - case 0: - return toTensor(ArrayFireTensor()); // empty tensor - case 1: - return tensors.front(); - case 2: - out = af::join(axis, toArray(tensors[0]), toArray(tensors[1])); - break; - case 3: - out = af::join( - axis, toArray(tensors[0]), toArray(tensors[1]), toArray(tensors[2])); - break; - case 4: - out = af::join( - axis, - toArray(tensors[0]), - toArray(tensors[1]), - toArray(tensors[2]), - toArray(tensors[3])); - break; - default: - // TODO: iteratively concat to remove this limitation - throw std::invalid_argument( - "ArrayFire concatenate doesn't support > 4 tensors"); - } - - unsigned numDims = tensors[0].ndim(); - if (axis > std::max(numDims - 1, 0u)) { - numDims = axis + 1; - } - - // All tensors have the same numdims else AF would 
throw - return toTensor(std::move(out), numDims); + const unsigned axis +) { + af::array out; + switch(tensors.size()) { + case 0: + return toTensor(ArrayFireTensor()); // empty tensor + case 1: + return tensors.front(); + case 2: + out = af::join(axis, toArray(tensors[0]), toArray(tensors[1])); + break; + case 3: + out = af::join( + axis, + toArray(tensors[0]), + toArray(tensors[1]), + toArray(tensors[2]) + ); + break; + case 4: + out = af::join( + axis, + toArray(tensors[0]), + toArray(tensors[1]), + toArray(tensors[2]), + toArray(tensors[3]) + ); + break; + default: + // TODO: iteratively concat to remove this limitation + throw std::invalid_argument( + "ArrayFire concatenate doesn't support > 4 tensors" + ); + } + + unsigned numDims = tensors[0].ndim(); + if(axis > std::max(numDims - 1, 0u)) { + numDims = axis + 1; + } + + // All tensors have the same numdims else AF would throw + return toTensor(std::move(out), numDims); } Tensor ArrayFireBackend::nonzero(const Tensor& tensor) { - return toTensor( - af::where(toArray(tensor)), /* numDims = */ 1); + return toTensor( + af::where(toArray(tensor)), /* numDims = */ + 1 + ); } Tensor ArrayFireBackend::pad( const Tensor& input, const std::vector>& padWidths, - const PadType type) { - if (padWidths.size() > AF_MAX_DIMS) { - throw std::invalid_argument( - "ArrayFireBackend::pad - given padWidths for more than 4 dimensions"); - } - - // convert ((begin_1, end_1), ..., (begin_k, end_k)) to ((begin_1, ..., - // begin_k), (end_1, ..., end_k)) for ArrayFire - af::dim4 beginPadding, endPadding; - for (size_t i = 0; i < padWidths.size(); ++i) { - auto& [first, second] = padWidths[i]; - beginPadding[i] = first; - endPadding[i] = second; - } - - return toTensor( - af::pad( - toArray(input), - beginPadding, - endPadding, - detail::flToAfPadType(type)), - /* numDims = */ // TODO: check - std::max(input.ndim(), static_cast(padWidths.size()))); + const PadType type +) { + if(padWidths.size() > AF_MAX_DIMS) { + throw 
std::invalid_argument( + "ArrayFireBackend::pad - given padWidths for more than 4 dimensions" + ); + } + + // convert ((begin_1, end_1), ..., (begin_k, end_k)) to ((begin_1, ..., + // begin_k), (end_1, ..., end_k)) for ArrayFire + af::dim4 beginPadding, endPadding; + for(size_t i = 0; i < padWidths.size(); ++i) { + auto& [first, second] = padWidths[i]; + beginPadding[i] = first; + endPadding[i] = second; + } + + return toTensor( + af::pad( + toArray(input), + beginPadding, + endPadding, + detail::flToAfPadType(type) + ), + /* numDims = */ // TODO: check + std::max(input.ndim(), static_cast(padWidths.size())) + ); } } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp b/flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp index d5c9356..9f06309 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp @@ -27,52 +27,56 @@ namespace fl { const af::array& toArray(const Tensor& tensor) { - if (tensor.backendType() != TensorBackendType::ArrayFire) { - throw std::invalid_argument("toArray: tensor is not ArrayFire-backed"); - } - return tensor.getAdapter().getHandle(); + if(tensor.backendType() != TensorBackendType::ArrayFire) { + throw std::invalid_argument("toArray: tensor is not ArrayFire-backed"); + } + return tensor.getAdapter().getHandle(); } af::array& toArray(Tensor& tensor) { - if (tensor.backendType() != TensorBackendType::ArrayFire) { - throw std::invalid_argument("toArray: tensor is not ArrayFire-backed"); - } - return tensor.getAdapter().getHandle(); + if(tensor.backendType() != TensorBackendType::ArrayFire) { + throw std::invalid_argument("toArray: tensor is not ArrayFire-backed"); + } + return tensor.getAdapter().getHandle(); } -ArrayFireTensor::ArrayFireTensor(af::array&& array, const unsigned numDims) - : arrayHandle_(std::make_shared(std::move(array))), - numDims_(numDims) {} +ArrayFireTensor::ArrayFireTensor( + af::array&& array, + const unsigned numDims +) : 
arrayHandle_(std::make_shared(std::move(array))), + numDims_(numDims) {} ArrayFireTensor::ArrayFireTensor( std::shared_ptr arr, std::vector&& afIndices, std::vector&& indexTypes, const unsigned numDims, - const bool isFlat) - : arrayHandle_(arr), - indices_(std::move(afIndices)), - indexTypes_(std::move(indexTypes)), - handle_(IndexedArrayComponent(isFlat)), - numDims_(numDims) {} + const bool isFlat +) : arrayHandle_(arr), + indices_(std::move(afIndices)), + indexTypes_(std::move(indexTypes)), + handle_(IndexedArrayComponent(isFlat)), + numDims_(numDims) {} ArrayFireTensor::ArrayFireTensor( std::shared_ptr arr, - unsigned numDims) - : arrayHandle_(arr), numDims_(numDims) {} + unsigned numDims +) : arrayHandle_(arr), + numDims_(numDims) {} -ArrayFireTensor::ArrayFireTensor() - : arrayHandle_(std::make_shared()), handle_(ArrayComponent()) {} +ArrayFireTensor::ArrayFireTensor() : arrayHandle_(std::make_shared()), + handle_(ArrayComponent()) {} ArrayFireTensor::ArrayFireTensor( const Shape& shape, fl::dtype type, const void* ptr, - Location memoryLocation) - : arrayHandle_(std::make_shared( - detail::fromFlData(shape, ptr, type, memoryLocation))), - handle_(ArrayComponent()), - numDims_(shape.ndim()) {} + Location memoryLocation +) : arrayHandle_(std::make_shared( + detail::fromFlData(shape, ptr, type, memoryLocation) + )), + handle_(ArrayComponent()), + numDims_(shape.ndim()) {} ArrayFireTensor::ArrayFireTensor( const Dim nRows, @@ -80,442 +84,477 @@ ArrayFireTensor::ArrayFireTensor( const Tensor& values, const Tensor& rowIdx, const Tensor& colIdx, - StorageType storageType) - : arrayHandle_(std::make_shared(af::sparse( - nRows, - nCols, - toArray(values), - toArray(rowIdx), - toArray(colIdx), - detail::flToAfStorageType(storageType)))), - handle_(ArrayComponent()), - // ArrayFire only supports 2D sparsity - numDims_(2) {} + StorageType storageType +) : arrayHandle_(std::make_shared( + af::sparse( + nRows, + nCols, + toArray(values), + toArray(rowIdx), + 
toArray(colIdx), + detail::flToAfStorageType(storageType)) + )), + handle_(ArrayComponent()), + // ArrayFire only supports 2D sparsity + numDims_(2) {} unsigned ArrayFireTensor::numDims() const { - return numDims_; + return numDims_; } ArrayFireTensor::IndexedArrayComponent::IndexedArrayComponent( - const bool _isFlat /* = false */) - : isFlat(_isFlat) {} + const bool _isFlat /* = false */ +) : isFlat(_isFlat) {} af::array::array_proxy ArrayFireTensor::IndexedArrayComponent::get( - const ArrayFireTensor& inst) { - auto& i = inst.indices_.value(); - auto& a = *(inst.arrayHandle_); - switch (i.size()) { - case 1: - return a(i[0]); - case 2: - return a(i[0], i[1]); - case 3: - return a(i[0], i[1], i[2]); - case 4: - return a(i[0], i[1], i[2], i[3]); - default: - throw std::invalid_argument( - "ArrayFireTensor::IndexedArrayComponent::get - " - "given invalid number of index components."); - } + const ArrayFireTensor& inst +) { + auto& i = inst.indices_.value(); + auto& a = *(inst.arrayHandle_); + switch(i.size()) { + case 1: + return a(i[0]); + case 2: + return a(i[0], i[1]); + case 3: + return a(i[0], i[1], i[2]); + case 4: + return a(i[0], i[1], i[2], i[3]); + default: + throw std::invalid_argument( + "ArrayFireTensor::IndexedArrayComponent::get - " + "given invalid number of index components." + ); + } } af::array& ArrayFireTensor::ArrayComponent::get(const ArrayFireTensor& inst) { - return *(inst.arrayHandle_); + return *(inst.arrayHandle_); } const af::array& ArrayFireTensor::getHandle() const { - return const_cast(this)->getHandle(); + return const_cast(this)->getHandle(); } af::array& ArrayFireTensor::getHandle() { - // If the handle currently requires indexing, perform the indexing, change the - // getter to visit, and clear the indices. Upcast the af::array::array_proxy - // to an af::array via its operator array() and update the handle. 
- // Additionally, since we can't directly mutate the dimensions of an - // af::array::array_proxy, condense the indices of the resulting array after - // the conversion. - if (!std::holds_alternative(handle_)) { - auto& idxComp = std::get(handle_); - arrayHandle_ = std::make_shared(detail::condenseIndices( - idxComp.get(*this), - /* keepDims = */ false, - indexTypes_, - /* isFlat = */ idxComp.isFlat)); - // Clear state - handle_ = ArrayComponent(); // set to passthrough - indices_ = {}; // remove indices - indexTypes_ = {}; // remove IndexTypes - } - return *arrayHandle_; + // If the handle currently requires indexing, perform the indexing, change the + // getter to visit, and clear the indices. Upcast the af::array::array_proxy + // to an af::array via its operator array() and update the handle. + // Additionally, since we can't directly mutate the dimensions of an + // af::array::array_proxy, condense the indices of the resulting array after + // the conversion. + if(!std::holds_alternative(handle_)) { + auto& idxComp = std::get(handle_); + arrayHandle_ = std::make_shared( + detail::condenseIndices( + idxComp.get(*this), + /* keepDims = */ false, + indexTypes_, + /* isFlat = */ idxComp.isFlat + ) + ); + // Clear state + handle_ = ArrayComponent(); // set to passthrough + indices_ = {}; // remove indices + indexTypes_ = {}; // remove IndexTypes + } + return *arrayHandle_; } std::unique_ptr ArrayFireTensor::clone() const { - af::array arr = getHandle(); // increment internal AF refcount - return std::make_unique( - std::move(arr), numDims()); + af::array arr = getHandle(); // increment internal AF refcount + return std::make_unique( + std::move(arr), + numDims() + ); } Tensor ArrayFireTensor::copy() { - getHandle(); // if this tensor was a view, run indexing and promote - return toTensor(arrayHandle_->copy(), numDims()); + getHandle(); // if this tensor was a view, run indexing and promote + return toTensor(arrayHandle_->copy(), numDims()); } Tensor 
ArrayFireTensor::shallowCopy() { - getHandle(); // if this tensor was a view, run indexing and promote - return Tensor(std::unique_ptr( - new ArrayFireTensor(arrayHandle_, numDims()))); + getHandle(); // if this tensor was a view, run indexing and promote + return Tensor( + std::unique_ptr( + new ArrayFireTensor(arrayHandle_, numDims()) + ) + ); } TensorBackendType ArrayFireTensor::backendType() const { - return TensorBackendType::ArrayFire; + return TensorBackendType::ArrayFire; } TensorBackend& ArrayFireTensor::backend() const { - // The ArrayFire backend has a single ArrayFireBackend instance per process. - return ::fl::ArrayFireBackend::getInstance(); + // The ArrayFire backend has a single ArrayFireBackend instance per process. + return ::fl::ArrayFireBackend::getInstance(); } const Shape& ArrayFireTensor::shape() { - // Update the Shape in-place. Doesn't change any underlying data; only the - // mirrored Shape metadata. - detail::afToFlDims(getHandle().dims(), numDims(), shape_); - return shape_; + // Update the Shape in-place. Doesn't change any underlying data; only the + // mirrored Shape metadata. 
+ detail::afToFlDims(getHandle().dims(), numDims(), shape_); + return shape_; } fl::dtype ArrayFireTensor::type() { - return detail::afToFlType(getHandle().type()); + return detail::afToFlType(getHandle().type()); } bool ArrayFireTensor::isSparse() { - return getHandle().issparse(); + return getHandle().issparse(); } af::dtype ArrayFireTensor::afHandleType() { - return arrayHandle_->type(); + return arrayHandle_->type(); } Location ArrayFireTensor::location() { - switch (af::getBackendId(getHandle())) { - case AF_BACKEND_CUDA: - case AF_BACKEND_OPENCL: - return Location::Device; - case AF_BACKEND_CPU: - return Location::Host; - default: - throw std::logic_error( - "ArrayFireTensor::location got an unmatched location"); - } + switch(af::getBackendId(getHandle())) { + case AF_BACKEND_CUDA: + case AF_BACKEND_OPENCL: + return Location::Device; + case AF_BACKEND_CPU: + return Location::Host; + default: + throw std::logic_error( + "ArrayFireTensor::location got an unmatched location" + ); + } } void ArrayFireTensor::scalar(void* out) { - AF_CHECK(af_get_scalar(out, getHandle().get())); + AF_CHECK(af_get_scalar(out, getHandle().get())); } void ArrayFireTensor::device(void** out) { - AF_CHECK(af_get_device_ptr(out, getHandle().get())); + AF_CHECK(af_get_device_ptr(out, getHandle().get())); } void ArrayFireTensor::host(void* out) { - AF_CHECK(af_get_data_ptr(out, getHandle().get())); + AF_CHECK(af_get_data_ptr(out, getHandle().get())); } void ArrayFireTensor::unlock() { - AF_CHECK(af_unlock_array(getHandle().get())); + AF_CHECK(af_unlock_array(getHandle().get())); } bool ArrayFireTensor::isLocked() { - bool res; - auto err = af_is_locked_array(&res, getHandle().get()); - if (err != AF_SUCCESS) { - throw std::runtime_error( - "ArrayFireTensor::isLocked - af_is_locked_array returned error: " + - std::to_string(err)); - } - return res; + bool res; + auto err = af_is_locked_array(&res, getHandle().get()); + if(err != AF_SUCCESS) { + throw std::runtime_error( + 
"ArrayFireTensor::isLocked - af_is_locked_array returned error: " + + std::to_string(err) + ); + } + return res; } bool ArrayFireTensor::isContiguous() { - return af::isLinear(getHandle()); + return af::isLinear(getHandle()); } Shape ArrayFireTensor::strides() { - return detail::afToFlDims(af::getStrides(getHandle()), numDims()); + return detail::afToFlDims(af::getStrides(getHandle()), numDims()); } const Stream& ArrayFireTensor::stream() const { - // TODO indexing is unlikely to change the stream associated with a tensor. - // But if it can, we need to call `getHandle()` here. - return ArrayFireBackend::getInstance().getStreamOfArray(*arrayHandle_); + // TODO indexing is unlikely to change the stream associated with a tensor. + // But if it can, we need to call `getHandle()` here. + return ArrayFireBackend::getInstance().getStreamOfArray(*arrayHandle_); } Tensor ArrayFireTensor::astype(const dtype type) { - auto a = getHandle().as(detail::flToAfType(type)); - return toTensor(std::move(a), numDims()); + auto a = getHandle().as(detail::flToAfType(type)); + return toTensor(std::move(a), numDims()); } Tensor ArrayFireTensor::index(const std::vector& indices) { - if (indices.size() > AF_MAX_DIMS) { - throw std::invalid_argument( - "ArrayFire-backed tensor was indexed with > 4 elements:" - "ArrayFire tensors support up to 4 dimensions."); - } - - // TODO: vet and stress test this a lot more/add proper support for - // multi-tensor - // If indexing by a single element and it's a tensor with the same number of - // indices as the array being indexed, do a flat index as this is probably a - // filter-based index (for example: a(a < 5)). 
- bool completeTensorIndex = indices.size() == 1 && - indices.front().type() == detail::IndexType::Tensor && - indices.front().get().elements() == getHandle().elements(); - std::vector afIndices; - if (completeTensorIndex) { - afIndices = {af::index(0)}; - } else { - afIndices = {af::span, af::span, af::span, af::span}; // implicit spans - } - - if (indices.size() > afIndices.size()) { - throw std::logic_error( - "ArrayFireTensor::index internal error - passed indiecs is larger " - "than the number of af indices"); - } - - // Fill in corresponding index types for each af index - std::vector indexTypes(afIndices.size()); - size_t i = 0; - for (; i < indices.size(); ++i) { - indexTypes[i] = indices[i].type(); - afIndices[i] = detail::flToAfIndex(indices[i]); - } - // If we're adding implicit spans, fill those indexTypes in - for (; i < afIndices.size(); ++i) { - indexTypes[i] = detail::IndexType::Span; - } - - getHandle(); // if this tensor was a view, run indexing and promote - - assert(afIndices.size() == indexTypes.size()); - // Compute numDums for the new Tensor - unsigned newNumDims = numDims(); - - if (completeTensorIndex) { - // TODO/FIXME: compute this based on the number of els in the indexing - // tensor(s) - newNumDims = 1; - } else { - for (const auto& type : indexTypes) { - if (type == detail::IndexType::Literal) { - newNumDims--; - } + if(indices.size() > AF_MAX_DIMS) { + throw std::invalid_argument( + "ArrayFire-backed tensor was indexed with > 4 elements:" + "ArrayFire tensors support up to 4 dimensions." + ); + } + + // TODO: vet and stress test this a lot more/add proper support for + // multi-tensor + // If indexing by a single element and it's a tensor with the same number of + // indices as the array being indexed, do a flat index as this is probably a + // filter-based index (for example: a(a < 5)). 
+ bool completeTensorIndex = indices.size() == 1 + && indices.front().type() == detail::IndexType::Tensor + && indices.front().get().elements() == getHandle().elements(); + std::vector afIndices; + if(completeTensorIndex) { + afIndices = {af::index(0)}; + } else { + afIndices = {af::span, af::span, af::span, af::span}; // implicit spans + } + + if(indices.size() > afIndices.size()) { + throw std::logic_error( + "ArrayFireTensor::index internal error - passed indiecs is larger " + "than the number of af indices" + ); } - } - newNumDims = std::max(newNumDims, 1u); // can never index to a 0 dim tensor - return fl::Tensor(std::unique_ptr(new ArrayFireTensor( - arrayHandle_, - std::move(afIndices), - std::move(indexTypes), - newNumDims, - /* isFlat = */ false))); + // Fill in corresponding index types for each af index + std::vector indexTypes(afIndices.size()); + size_t i = 0; + for(; i < indices.size(); ++i) { + indexTypes[i] = indices[i].type(); + afIndices[i] = detail::flToAfIndex(indices[i]); + } + // If we're adding implicit spans, fill those indexTypes in + for(; i < afIndices.size(); ++i) { + indexTypes[i] = detail::IndexType::Span; + } + + getHandle(); // if this tensor was a view, run indexing and promote + + assert(afIndices.size() == indexTypes.size()); + // Compute numDums for the new Tensor + unsigned newNumDims = numDims(); + + if(completeTensorIndex) { + // TODO/FIXME: compute this based on the number of els in the indexing + // tensor(s) + newNumDims = 1; + } else { + for(const auto& type : indexTypes) { + if(type == detail::IndexType::Literal) { + newNumDims--; + } + } + } + newNumDims = std::max(newNumDims, 1u); // can never index to a 0 dim tensor + + return fl::Tensor( + std::unique_ptr( + new ArrayFireTensor( + arrayHandle_, + std::move(afIndices), + std::move(indexTypes), + newNumDims, + /* isFlat = */ false + ) + ) + ); } Tensor ArrayFireTensor::flatten() const { - return toTensor(af::flat(getHandle()), /* numDims = */ 1); + return 
toTensor(af::flat(getHandle()), /* numDims = */ 1); } Tensor ArrayFireTensor::flat(const Index& idx) const { - getHandle(); // if this tensor was a view, run indexing and promote - // Return a lazy indexing operation. Indexing with a single index on an - // ArrayFire tensor (with a type that is not an af::array) ends up doing - // flat indexing, so all index assignment operators will work as they are. - return fl::Tensor(std::unique_ptr(new ArrayFireTensor( - arrayHandle_, - {detail::flToAfIndex(idx)}, - {idx.type()}, - /* numDims = */ 1, - /* isFlat = */ true))); + getHandle(); // if this tensor was a view, run indexing and promote + // Return a lazy indexing operation. Indexing with a single index on an + // ArrayFire tensor (with a type that is not an af::array) ends up doing + // flat indexing, so all index assignment operators will work as they are. + return fl::Tensor( + std::unique_ptr( + new ArrayFireTensor( + arrayHandle_, + {detail::flToAfIndex(idx)}, + {idx.type()}, + /* numDims = */ 1, + /* isFlat = */ true + ) + ) + ); } Tensor ArrayFireTensor::asContiguousTensor() { - if (isContiguous()) { - af::array other = getHandle(); - return toTensor(std::move(other), numDims()); - } + if(isContiguous()) { + af::array other = getHandle(); + return toTensor(std::move(other), numDims()); + } - const af::array& array = getHandle(); - auto linearArray = af::array(array.dims(), array.type()); - af::copy(linearArray, array, af::span); - return toTensor(std::move(linearArray), numDims()); + const af::array& array = getHandle(); + auto linearArray = af::array(array.dims(), array.type()); + af::copy(linearArray, array, af::span); + return toTensor(std::move(linearArray), numDims()); } void ArrayFireTensor::setContext(void* context) {} // noop void* ArrayFireTensor::getContext() { - return nullptr; // noop + return nullptr; // noop } std::string ArrayFireTensor::toString() { - const char* afStr = af::toString("ArrayFireTensor", getHandle()); - // std::string copies 
`afStr` content into its own buffer - const std::string str(afStr); - af::freeHost(afStr); - return str; + const char* afStr = af::toString("ArrayFireTensor", getHandle()); + // std::string copies `afStr` content into its own buffer + const std::string str(afStr); + af::freeHost(afStr); + return str; } std::ostream& ArrayFireTensor::operator<<(std::ostream& ostr) { - ostr << this->toString(); - return ostr; + ostr << this->toString(); + return ostr; } /******************** Assignment Operators ********************/ -#define ASSIGN_OP_TYPE(FUN, AF_OP, TYPE) \ - void ArrayFireTensor::FUN(const TYPE& val) { \ - std::visit( \ - [val, this](auto&& arr) { arr.get(*this) AF_OP val; }, handle_); \ - } - -#define ASSIGN_OP_LITERALS(FUN, AF_OP) \ - ASSIGN_OP_TYPE(FUN, AF_OP, double); \ - ASSIGN_OP_TYPE(FUN, AF_OP, float); \ - ASSIGN_OP_TYPE(FUN, AF_OP, int); \ - ASSIGN_OP_TYPE(FUN, AF_OP, unsigned); \ - ASSIGN_OP_TYPE(FUN, AF_OP, bool); \ - ASSIGN_OP_TYPE(FUN, AF_OP, char); \ - ASSIGN_OP_TYPE(FUN, AF_OP, unsigned char); \ - ASSIGN_OP_TYPE(FUN, AF_OP, short); \ - ASSIGN_OP_TYPE(FUN, AF_OP, unsigned short); \ - ASSIGN_OP_TYPE(FUN, AF_OP, long); \ - ASSIGN_OP_TYPE(FUN, AF_OP, unsigned long); \ - ASSIGN_OP_TYPE(FUN, AF_OP, long long); \ - ASSIGN_OP_TYPE(FUN, AF_OP, unsigned long long); +#define ASSIGN_OP_TYPE(FUN, AF_OP, TYPE) \ + void ArrayFireTensor::FUN(const TYPE& val) { \ + std::visit( \ + [val, this](auto&& arr) { arr.get(*this) AF_OP val; }, \ + handle_ \ + ); \ + } + +#define ASSIGN_OP_LITERALS(FUN, AF_OP) \ + ASSIGN_OP_TYPE(FUN, AF_OP, double); \ + ASSIGN_OP_TYPE(FUN, AF_OP, float); \ + ASSIGN_OP_TYPE(FUN, AF_OP, int); \ + ASSIGN_OP_TYPE(FUN, AF_OP, unsigned); \ + ASSIGN_OP_TYPE(FUN, AF_OP, bool); \ + ASSIGN_OP_TYPE(FUN, AF_OP, char); \ + ASSIGN_OP_TYPE(FUN, AF_OP, unsigned char); \ + ASSIGN_OP_TYPE(FUN, AF_OP, short); \ + ASSIGN_OP_TYPE(FUN, AF_OP, unsigned short); \ + ASSIGN_OP_TYPE(FUN, AF_OP, long); \ + ASSIGN_OP_TYPE(FUN, AF_OP, unsigned long); \ + 
ASSIGN_OP_TYPE(FUN, AF_OP, long long); \ + ASSIGN_OP_TYPE(FUN, AF_OP, unsigned long long); af::array ArrayFireTensor::adjustInPlaceOperandDims(const Tensor& operand) { - // optimstically try to moddims the operand's singleton dims - const af::dim4& preIdxDims = arrayHandle_->dims(); - const af::array& operandArr = toArray(operand); - - // dims to which to try to modify the input if doing indexing - af::dim4 newDims; - const af::dim4 operandDims = operandArr.dims(); - - using detail::IndexType; - if (indices_ && indices_.value().size() == 1) { - // This case is only reachable via tensor-based indexing or indexing on a - // tensor via Tensor::flat() - if (numDims_ != 1) { - throw std::invalid_argument( - "ArrayFireTensor::adjustInPlaceOperandDims " - "index size was 1 but tensor has greater than 1 dimension."); - } - } else if (indices_ && !indices_.value().empty()) { - // All other indexing operations - const auto& indices = indices_.value(); - const auto& indexTypes = indexTypes_.value(); - if (indices.size() != indexTypes.size()) { - throw std::invalid_argument( - "ArrayFireTensor adjustInPlaceOperandDims - passed indices" - " and indexTypes are of different sizes."); - } + // optimstically try to moddims the operand's singleton dims + const af::dim4& preIdxDims = arrayHandle_->dims(); + const af::array& operandArr = toArray(operand); + + // dims to which to try to modify the input if doing indexing + af::dim4 newDims; + const af::dim4 operandDims = operandArr.dims(); + + using detail::IndexType; + if(indices_ && indices_.value().size() == 1) { + // This case is only reachable via tensor-based indexing or indexing on a + // tensor via Tensor::flat() + if(numDims_ != 1) { + throw std::invalid_argument( + "ArrayFireTensor::adjustInPlaceOperandDims " + "index size was 1 but tensor has greater than 1 dimension." 
+ ); + } + } else if(indices_ && !indices_.value().empty()) { + // All other indexing operations + const auto& indices = indices_.value(); + const auto& indexTypes = indexTypes_.value(); + if(indices.size() != indexTypes.size()) { + throw std::invalid_argument( + "ArrayFireTensor adjustInPlaceOperandDims - passed indices" + " and indexTypes are of different sizes." + ); + } - // If the dimensions being indexed are 1 and collapsing them yields the same - // shape as the operand, we can safely moddims, the operand, else there's a - // dimension mismatch. For example: - // {4, 5, 6, 7}(span, span, 5) --> {4, 5, 1, 7} --> {4, 5, 7} - // {4, 5, 6, 7}(4) --> {1, 5, 1, 7} --> {5, 1, 7, 1} - std::vector indicesToCompress; - for (unsigned i = 0; i < indices.size(); ++i) { - // If an index literal, the corresponding dimension in the indexed array - // is 1, then we indexed the input to a dim of 1, so we can condense that - // index - if (indexTypes[i] == IndexType::Literal) { - indicesToCompress.push_back(i); - } - } + // If the dimensions being indexed are 1 and collapsing them yields the same + // shape as the operand, we can safely moddims, the operand, else there's a + // dimension mismatch. 
For example: + // {4, 5, 6, 7}(span, span, 5) --> {4, 5, 1, 7} --> {4, 5, 7} + // {4, 5, 6, 7}(4) --> {1, 5, 1, 7} --> {5, 1, 7, 1} + std::vector indicesToCompress; + for(unsigned i = 0; i < indices.size(); ++i) { + // If an index literal, the corresponding dimension in the indexed array + // is 1, then we indexed the input to a dim of 1, so we can condense that + // index + if(indexTypes[i] == IndexType::Literal) { + indicesToCompress.push_back(i); + } + } - af::dim4 condensedDims(1, 1, 1, 1); - af::dim4 postIdxDims = preIdxDims; - unsigned outDimIdx = 0; - unsigned compressIdx = 0; - for (unsigned i = 0; i < AF_MAX_DIMS; ++i) { - if (compressIdx < indicesToCompress.size() && - i == indicesToCompress[compressIdx]) { - compressIdx++; - postIdxDims[i] = 1; - } else { - // Use the size of the dim post-indexing. Span uses the preIdx dim - // and literals are pushed to 1. - if (i < indexTypes.size()) { - if (indexTypes[i] == IndexType::Tensor) { - dim_t size; - AF_CHECK(af_get_elements(&size, indices[i].get().idx.arr)); - postIdxDims[i] = size; - } else if (indexTypes[i] == IndexType::Range) { - postIdxDims[i] = af::seq(indices[i].get().idx.seq).size; - } else if (indexTypes[i] == IndexType::Literal) { - postIdxDims[i] = 1; - } + af::dim4 condensedDims(1, 1, 1, 1); + af::dim4 postIdxDims = preIdxDims; + unsigned outDimIdx = 0; + unsigned compressIdx = 0; + for(unsigned i = 0; i < AF_MAX_DIMS; ++i) { + if( + compressIdx < indicesToCompress.size() + && i == indicesToCompress[compressIdx] + ) { + compressIdx++; + postIdxDims[i] = 1; + } else { + // Use the size of the dim post-indexing. Span uses the preIdx dim + // and literals are pushed to 1. 
+ if(i < indexTypes.size()) { + if(indexTypes[i] == IndexType::Tensor) { + dim_t size; + AF_CHECK(af_get_elements(&size, indices[i].get().idx.arr)); + postIdxDims[i] = size; + } else if(indexTypes[i] == IndexType::Range) { + postIdxDims[i] = af::seq(indices[i].get().idx.seq).size; + } else if(indexTypes[i] == IndexType::Literal) { + postIdxDims[i] = 1; + } + } + condensedDims[outDimIdx] = postIdxDims[i]; + outDimIdx++; + } } - condensedDims[outDimIdx] = postIdxDims[i]; - outDimIdx++; - } - } - // Can modify the operand to work with the proxy or array input only by - // removing singleton dimensions - if (condensedDims == operandDims) { - newDims = postIdxDims; + // Can modify the operand to work with the proxy or array input only by + // removing singleton dimensions + if(condensedDims == operandDims) { + newDims = postIdxDims; + } else { + throw std::invalid_argument( + "ArrayFireTensor adjustInPlaceOperandDims: can't apply operation " + "in-place to indexed ArrayFireTensor - dimensions don't match." + ); + } } else { - throw std::invalid_argument( - "ArrayFireTensor adjustInPlaceOperandDims: can't apply operation " - "in-place to indexed ArrayFireTensor - dimensions don't match."); + // No indexing so no change in dimensions required + newDims = operandDims; } - } else { - // No indexing so no change in dimensions required - newDims = operandDims; - } - // af::moddims involves an eval. This will be fixed in AF 3.8.1/3.8.2 - bool doModdims = operandArr.dims() != newDims; - return (doModdims ? af::moddims(operandArr, newDims) : operandArr); + // af::moddims involves an eval. This will be fixed in AF 3.8.1/3.8.2 + bool doModdims = operandArr.dims() != newDims; + return doModdims ? 
af::moddims(operandArr, newDims) : operandArr; } -#define ASSIGN_OP_TENSOR(FUN, AF_OP) \ - void ArrayFireTensor::FUN(const Tensor& tensor) { \ - std::visit( \ - [&tensor, this](auto&& arr) { \ - arr.get(*this) AF_OP this->adjustInPlaceOperandDims(tensor); \ - }, \ - handle_); \ - } +#define ASSIGN_OP_TENSOR(FUN, AF_OP) \ + void ArrayFireTensor::FUN(const Tensor& tensor) { \ + std::visit( \ + [&tensor, this](auto&& arr) { \ + arr.get(*this) AF_OP this->adjustInPlaceOperandDims(tensor); \ + }, \ + handle_ \ + ); \ + } -#define ASSIGN_OP(FUN, AF_OP) \ - ASSIGN_OP_TENSOR(FUN, AF_OP) \ - ASSIGN_OP_LITERALS(FUN, AF_OP) +#define ASSIGN_OP(FUN, AF_OP) \ + ASSIGN_OP_TENSOR(FUN, AF_OP) \ + ASSIGN_OP_LITERALS(FUN, AF_OP) // (function name, AF op). Use build-in AF operators. -ASSIGN_OP(inPlaceSubtract, -=); -ASSIGN_OP(inPlaceMultiply, *=); -ASSIGN_OP(inPlaceDivide, /=); +ASSIGN_OP(inPlaceSubtract, -= ); +ASSIGN_OP(inPlaceMultiply, *= ); +ASSIGN_OP(inPlaceDivide, /= ); // Instantiate definitions for type literals - those remain unchanged: -ASSIGN_OP_LITERALS(assign, =); +ASSIGN_OP_LITERALS(assign, = ); void ArrayFireTensor::assign(const Tensor& tensor) { - std::visit( - [&tensor, this](auto&& arr) { - if (indices_) { - // If this is an indexing op, do as other in-place ops with lvalue - // temporaries as a result of indexing do - arr.get(*this) = this->adjustInPlaceOperandDims(tensor); - } else { - // Not an indexing op - just assign the tensor, but make sure to - // update the number of dims - arr.get(*this) = toArray(tensor); - this->numDims_ = tensor.ndim(); - } - }, - handle_); + std::visit( + [&tensor, this](auto&& arr) { + if(indices_) { + // If this is an indexing op, do as other in-place ops with lvalue + // temporaries as a result of indexing do + arr.get(*this) = this->adjustInPlaceOperandDims(tensor); + } else { + // Not an indexing op - just assign the tensor, but make sure to + // update the number of dims + arr.get(*this) = toArray(tensor); + this->numDims_ = 
tensor.ndim(); + } + }, + handle_ + ); } /* @@ -525,70 +564,78 @@ void ArrayFireTensor::assign(const Tensor& tensor) { * it properly-handles the case of repeated indices. */ // Instantiate definitions for type literals - those remain unchanged: -ASSIGN_OP_LITERALS(inPlaceAdd, +=); +ASSIGN_OP_LITERALS(inPlaceAdd, += ); // Special tensor op: void ArrayFireTensor::inPlaceAdd(const Tensor& tensor) { - // First, check if this a tensor that's going to be lazily indexed. Don't - // implicitly cast to an array, else that will trigger indexing. - // Carefully get the handle types without calling type(), which will lazily - // evaluate indexing - af::dtype operandHandleType = - tensor.getAdapter().afHandleType(); - af::dtype handleType = arrayHandle_->type(); - // not all types are compatible with the kernel - bool typeIncompatible = - (handleType != af::dtype::f32 && handleType != af::dtype::f16) || - (operandHandleType != af::dtype::f32 && - operandHandleType != af::dtype::f16); - if (!std::holds_alternative(handle_) || - typeIncompatible || - !FL_BACKEND_CUDA // TODO{fl::Tensor} advanced indexing only impl for CUDA - ) { - // Call the regular af::array::operator+= - std::visit( - [&tensor, this](auto&& arr) { - arr.get(*this) += this->adjustInPlaceOperandDims(tensor); - }, - handle_); - return; - } else { - af::dim4 inDims = arrayHandle_->dims(); - af::dim4 idxStart; - af::dim4 idxEnd; - std::vector idxArr(4); - auto idxFunc = [&idxStart, &idxEnd, &idxArr, &inDims]( - const af::index& index, int pos) { - if (index.isspan()) { - idxStart[pos] = 0; - idxEnd[pos] = inDims[pos]; - } else { - const auto& idxSeq = index.get(); - if (idxSeq.isSeq) { - // arrayfire uses inclusive last dimension, we use exclusive - idxStart[pos] = idxSeq.idx.seq.begin; - idxEnd[pos] = idxSeq.idx.seq.end + 1; - } else { - af_array arr; - af_retain_array(&arr, idxSeq.idx.arr); - idxArr[pos] = af::array(arr); - idxStart[pos] = 0; - idxEnd[pos] = idxArr[pos].dims(0); + // First, check if this a tensor 
that's going to be lazily indexed. Don't + // implicitly cast to an array, else that will trigger indexing. + // Carefully get the handle types without calling type(), which will lazily + // evaluate indexing + af::dtype operandHandleType = + tensor.getAdapter().afHandleType(); + af::dtype handleType = arrayHandle_->type(); + // not all types are compatible with the kernel + bool typeIncompatible = + (handleType != af::dtype::f32 && handleType != af::dtype::f16) + || (operandHandleType != af::dtype::f32 + && operandHandleType != af::dtype::f16); + if( + !std::holds_alternative(handle_) + || typeIncompatible + || !FL_BACKEND_CUDA // TODO{fl::Tensor} advanced indexing only impl for CUDA + ) { + // Call the regular af::array::operator+= + std::visit( + [&tensor, this](auto&& arr) { + arr.get(*this) += this->adjustInPlaceOperandDims(tensor); + }, + handle_ + ); + return; + } else { + af::dim4 inDims = arrayHandle_->dims(); + af::dim4 idxStart; + af::dim4 idxEnd; + std::vector idxArr(4); + auto idxFunc = [&idxStart, &idxEnd, &idxArr, &inDims]( + const af::index& index, int pos) { + if(index.isspan()) { + idxStart[pos] = 0; + idxEnd[pos] = inDims[pos]; + } else { + const auto& idxSeq = index.get(); + if(idxSeq.isSeq) { + // arrayfire uses inclusive last dimension, we use exclusive + idxStart[pos] = idxSeq.idx.seq.begin; + idxEnd[pos] = idxSeq.idx.seq.end + 1; + } else { + af_array arr; + af_retain_array(&arr, idxSeq.idx.arr); + idxArr[pos] = af::array(arr); + idxStart[pos] = 0; + idxEnd[pos] = idxArr[pos].dims(0); + } + } + }; + + unsigned i = 0; + for(; i < indices_.value().size(); ++i) { + idxFunc(indices_.value()[i], i); + } + // The kernel needs to be padded with spans for remaining dims + for(; i < AF_MAX_DIMS; ++i) { + idxFunc(af::span, i); } - } - }; - unsigned i = 0; - for (; i < indices_.value().size(); ++i) { - idxFunc(indices_.value()[i], i); + fl::detail::advancedIndex( + toArray(tensor), + idxStart, + idxEnd, + inDims, + idxArr, + *arrayHandle_ + ); } - // 
The kernel needs to be padded with spans for remaining dims - for (; i < AF_MAX_DIMS; ++i) { - idxFunc(af::span, i); - } - - fl::detail::advancedIndex( - toArray(tensor), idxStart, idxEnd, inDims, idxArr, *arrayHandle_); - } } #undef ASSIGN_OP_TYPE diff --git a/flashlight/fl/tensor/backend/af/ArrayFireTensor.h b/flashlight/fl/tensor/backend/af/ArrayFireTensor.h index 5a365e1..faeda13 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireTensor.h +++ b/flashlight/fl/tensor/backend/af/ArrayFireTensor.h @@ -27,221 +27,224 @@ class ArrayFireBackend; * Flashlight Tensors to ArrayFire. */ class ArrayFireTensor : public TensorAdapterBase { - // A pointer to the internal ArrayFire array. Shared amongst tensors that are - // shallow-copied. - std::shared_ptr arrayHandle_; + // A pointer to the internal ArrayFire array. Shared amongst tensors that are + // shallow-copied. + std::shared_ptr arrayHandle_; - // Indices in the event that this tensor is about to be indexed. Cleared the - // next time this array handle is acquired. See getHandle(). - std::optional> indices_; - // Need to maintain the types of each index, as ArrayFire doesn't distinguish - // between an integer index literal and an af::seq of size one; both have - // slightly different behavior with fl::Tensor - std::optional> indexTypes_; - // To be visited when this tensor is to be indexed. Indexes the underlying - // af::array, and returns the proxy to be used as a temporary lvalue. - struct IndexedArrayComponent { - explicit IndexedArrayComponent(const bool _isFlat = false); - af::array::array_proxy get(const ArrayFireTensor& inst); - bool isFlat; - }; - // To be visited when this tensor is holding an array without needing - // indexing. Passthrough - returns the array directly. - struct ArrayComponent { - af::array& get(const ArrayFireTensor& inst); - }; - // An interface to visit when getting an array handle. Indexes lazily - // because we can't store an af::array::proxy as an lvalue. See getHandle(). 
- std::variant handle_{ArrayComponent()}; + // Indices in the event that this tensor is about to be indexed. Cleared the + // next time this array handle is acquired. See getHandle(). + std::optional> indices_; + // Need to maintain the types of each index, as ArrayFire doesn't distinguish + // between an integer index literal and an af::seq of size one; both have + // slightly different behavior with fl::Tensor + std::optional> indexTypes_; + // To be visited when this tensor is to be indexed. Indexes the underlying + // af::array, and returns the proxy to be used as a temporary lvalue. + struct IndexedArrayComponent { + explicit IndexedArrayComponent(const bool _isFlat = false); + af::array::array_proxy get(const ArrayFireTensor& inst); + bool isFlat; + }; + // To be visited when this tensor is holding an array without needing + // indexing. Passthrough - returns the array directly. + struct ArrayComponent { + af::array& get(const ArrayFireTensor& inst); + }; + // An interface to visit when getting an array handle. Indexes lazily + // because we can't store an af::array::proxy as an lvalue. See getHandle(). + std::variant handle_{ArrayComponent()}; - /** - * Constructs an ArrayFireTensor that will be lazily indexed. - * - * This constructor is for internal use only. Because af::array::array_proxy - * objects don't work properly as lvalues, they need to be used as temporary - * lvalues when doing in-place assignment. As such, Tensors are lazily-indexed - * if operators that might need to operate on array proxies are called. This - * ctor sets up that lazy indexing. - * - * Whenever these ArrayFireTensors are mutated, ArrayFireTensor::getHandle() - * is called, which performs indexing if needed and upcasts the array_proxy to - * a full af::array on which operations can be performed. - * - * @param[in] handle a pointer to the ArrayFire array - * @param[in] indices a vector of ArrayFire indices to lazily index. 
- * @param[in] indexTypes a vector of index types to lazily index. Needed to - * determine singleton dimension condensation - * @param[in] isFlat if the indexing op is flat (condense all dims) - */ - ArrayFireTensor( - std::shared_ptr handle, - std::vector&& afIndices, - std::vector&& indexTypes, - const unsigned numDims, - const bool isFlat); + /** + * Constructs an ArrayFireTensor that will be lazily indexed. + * + * This constructor is for internal use only. Because af::array::array_proxy + * objects don't work properly as lvalues, they need to be used as temporary + * lvalues when doing in-place assignment. As such, Tensors are lazily-indexed + * if operators that might need to operate on array proxies are called. This + * ctor sets up that lazy indexing. + * + * Whenever these ArrayFireTensors are mutated, ArrayFireTensor::getHandle() + * is called, which performs indexing if needed and upcasts the array_proxy to + * a full af::array on which operations can be performed. + * + * @param[in] handle a pointer to the ArrayFire array + * @param[in] indices a vector of ArrayFire indices to lazily index. + * @param[in] indexTypes a vector of index types to lazily index. Needed to + * determine singleton dimension condensation + * @param[in] isFlat if the indexing op is flat (condense all dims) + */ + ArrayFireTensor( + std::shared_ptr handle, + std::vector&& afIndices, + std::vector&& indexTypes, + const unsigned numDims, + const bool isFlat + ); - /** - * Construct an ArrayFireTensor from an ArrayFire array handle without copying - * the handle. Used for creating guaranteed-shallow copies. - */ - explicit ArrayFireTensor(std::shared_ptr arr, unsigned numDims); + /** + * Construct an ArrayFireTensor from an ArrayFire array handle without copying + * the handle. Used for creating guaranteed-shallow copies. + */ + explicit ArrayFireTensor(std::shared_ptr arr, unsigned numDims); - /* - * A Flashlight Shape that mirrors ArrayFire dims. 
- * - * NOTE: this shape is only updated on calls to ArrayFireTensor::shape() - * so as to satisfy API requirements as per returning a const reference. - * af::array::dims() should be used for internal computation where - * shape/dimensions are needed. - * - * The default shape is the empty Tensor 0. - */ - Shape shape_; + /* + * A Flashlight Shape that mirrors ArrayFire dims. + * + * NOTE: this shape is only updated on calls to ArrayFireTensor::shape() + * so as to satisfy API requirements as per returning a const reference. + * af::array::dims() should be used for internal computation where + * shape/dimensions are needed. + * + * The default shape is the empty Tensor 0. + */ + Shape shape_; - /* - * The number of dimensions in this ArrayFire tensor that are "expected" per - * interoperability with other tensors. Because ArrayFire doesn't distinguish - * between singleton dimensions that are defaults and those that are - * explicitly specified, this must be explicitly tracked. - * - * The fl::Tensor default Tensor shape is {0} - the default number of numDims - * is thus 1. Scalars have numDims == 0; - */ - unsigned numDims_{1}; + /* + * The number of dimensions in this ArrayFire tensor that are "expected" per + * interoperability with other tensors. Because ArrayFire doesn't distinguish + * between singleton dimensions that are defaults and those that are + * explicitly specified, this must be explicitly tracked. + * + * The fl::Tensor default Tensor shape is {0} - the default number of numDims + * is thus 1. Scalars have numDims == 0; + */ + unsigned numDims_{1}; - public: - constexpr static TensorBackendType tensorBackendType = TensorBackendType::ArrayFire; +public: + constexpr static TensorBackendType tensorBackendType = TensorBackendType::ArrayFire; - /** - * Constructs an ArrayFireTensor. - * - * Since af::arrays are refcounted, an instance of this class - * can only be created using arrays that are moved therein. 
- * - * Tensor operations occurring directly on this tensor's underlying - * af::array should not copy the array else take a performance penalty (via - * an internal copy if refcount is > 1 in some cases). - * - * @param[in] array construct a tensor from an ArrayFire array rvalue - * reference. - */ - explicit ArrayFireTensor(af::array&& array, const unsigned numDims); + /** + * Constructs an ArrayFireTensor. + * + * Since af::arrays are refcounted, an instance of this class + * can only be created using arrays that are moved therein. + * + * Tensor operations occurring directly on this tensor's underlying + * af::array should not copy the array else take a performance penalty (via + * an internal copy if refcount is > 1 in some cases). + * + * @param[in] array construct a tensor from an ArrayFire array rvalue + * reference. + */ + explicit ArrayFireTensor(af::array&& array, const unsigned numDims); - /** - * Default initialization - empty ArrayFire array and empty shape. - */ - ArrayFireTensor(); + /** + * Default initialization - empty ArrayFire array and empty shape. + */ + ArrayFireTensor(); - /** - * Construct an ArrayFire tensor using some data. - * - * @param[in] shape the shape of the new tensor - * @param[in] ptr the buffer containing underlying tensor data - * @param[in] type the type of the new tensor - * @param[in] memoryLocation the location of the buffer - */ - ArrayFireTensor( - const Shape& shape, - fl::dtype type, - const void* ptr, - Location memoryLocation); + /** + * Construct an ArrayFire tensor using some data. 
+ * + * @param[in] shape the shape of the new tensor + * @param[in] ptr the buffer containing underlying tensor data + * @param[in] type the type of the new tensor + * @param[in] memoryLocation the location of the buffer + */ + ArrayFireTensor( + const Shape& shape, + fl::dtype type, + const void* ptr, + Location memoryLocation + ); - ArrayFireTensor( - const Dim nRows, - const Dim nCols, - const Tensor& values, - const Tensor& rowIdx, - const Tensor& colIdx, - StorageType storageType); + ArrayFireTensor( + const Dim nRows, + const Dim nCols, + const Tensor& values, + const Tensor& rowIdx, + const Tensor& colIdx, + StorageType storageType + ); - /** - * Gets an ArrayFire Array from this impl. - * - * Throws if this tensor represents an array_proxy, since it precludes - * promotion to an array. - */ - const af::array& getHandle() const; + /** + * Gets an ArrayFire Array from this impl. + * + * Throws if this tensor represents an array_proxy, since it precludes + * promotion to an array. + */ + const af::array& getHandle() const; - /** - * Gets an ArrayFire Array from this impl. If the underlying handle is an - * array_proxy, may promote it to an af::array condense dimensions as needed, - * replace the handle variant, and return a reference. - */ - af::array& getHandle(); + /** + * Gets an ArrayFire Array from this impl. If the underlying handle is an + * array_proxy, may promote it to an af::array condense dimensions as needed, + * replace the handle variant, and return a reference. 
+ */ + af::array& getHandle(); - ~ArrayFireTensor() override = default; - unsigned numDims() const; - // Used with the fl::Tensor copy constructor - std::unique_ptr clone() const override; - TensorBackendType backendType() const override; - TensorBackend& backend() const override; - Tensor copy() override; - Tensor shallowCopy() override; - const Shape& shape() override; - dtype type() override; - bool isSparse() override; - af::dtype afHandleType(); // for internal use only - Location location() override; - void scalar(void* out) override; - void device(void** out) override; - void host(void* out) override; - void unlock() override; - bool isLocked() override; - bool isContiguous() override; - Shape strides() override; - const Stream& stream() const override; - Tensor astype(const dtype type) override; - Tensor index(const std::vector& indices) override; - Tensor flatten() const override; - Tensor flat(const Index& idx) const override; - Tensor asContiguousTensor() override; - void setContext(void* context) override; // noop - void* getContext() override; // noop - std::string toString() override; - std::ostream& operator<<(std::ostream& ostr) override; + ~ArrayFireTensor() override = default; + unsigned numDims() const; + // Used with the fl::Tensor copy constructor + std::unique_ptr clone() const override; + TensorBackendType backendType() const override; + TensorBackend& backend() const override; + Tensor copy() override; + Tensor shallowCopy() override; + const Shape& shape() override; + dtype type() override; + bool isSparse() override; + af::dtype afHandleType(); // for internal use only + Location location() override; + void scalar(void* out) override; + void device(void** out) override; + void host(void* out) override; + void unlock() override; + bool isLocked() override; + bool isContiguous() override; + Shape strides() override; + const Stream& stream() const override; + Tensor astype(const dtype type) override; + Tensor index(const std::vector& indices) 
override; + Tensor flatten() const override; + Tensor flat(const Index& idx) const override; + Tensor asContiguousTensor() override; + void setContext(void* context) override; // noop + void* getContext() override; // noop + std::string toString() override; + std::ostream& operator<<(std::ostream& ostr) override; - /******************** Assignment Operators ********************/ + /******************** Assignment Operators ********************/ #define ASSIGN_OP_TYPE(OP, TYPE) void OP(const TYPE& val) override; - /** - * When indexing ArrayFire arrays, their dimensions are condensed (i.e. {3, 4, - * 5, 6}(fl::span, 1) --> {3, 5, 6} rather than {3, 1, 5, 6}) when arrays are - * returned as lvalues. In the case of lvalue temporary af::array::array_proxy - * objects that have in-place operations applied to them, one can't modify - * their dimensions without upcasting them into an af::array, which breaks - * in-place op logic. - * - * The only option is thus to modify the dimensions of the operand of the - * inplace operation in order to make the shapes match, but this should only - * be done if the shapes are actually compatible. This function performs that - * op before in-place operations are applied. - * - * @param[in] operand the tensor operand - * @param[in] newNumDims the number of dims of the resulting tensor - */ - af::array adjustInPlaceOperandDims(const Tensor& operand); + /** + * When indexing ArrayFire arrays, their dimensions are condensed (i.e. {3, 4, + * 5, 6}(fl::span, 1) --> {3, 5, 6} rather than {3, 1, 5, 6}) when arrays are + * returned as lvalues. In the case of lvalue temporary af::array::array_proxy + * objects that have in-place operations applied to them, one can't modify + * their dimensions without upcasting them into an af::array, which breaks + * in-place op logic. 
+ * + * The only option is thus to modify the dimensions of the operand of the + * inplace operation in order to make the shapes match, but this should only + * be done if the shapes are actually compatible. This function performs that + * op before in-place operations are applied. + * + * @param[in] operand the tensor operand + * @param[in] newNumDims the number of dims of the resulting tensor + */ + af::array adjustInPlaceOperandDims(const Tensor& operand); -#define ASSIGN_OP(OP) \ - ASSIGN_OP_TYPE(OP, Tensor); \ - ASSIGN_OP_TYPE(OP, double); \ - ASSIGN_OP_TYPE(OP, float); \ - ASSIGN_OP_TYPE(OP, int); \ - ASSIGN_OP_TYPE(OP, unsigned); \ - ASSIGN_OP_TYPE(OP, bool); \ - ASSIGN_OP_TYPE(OP, char); \ - ASSIGN_OP_TYPE(OP, unsigned char); \ - ASSIGN_OP_TYPE(OP, short); \ - ASSIGN_OP_TYPE(OP, unsigned short); \ - ASSIGN_OP_TYPE(OP, long); \ - ASSIGN_OP_TYPE(OP, unsigned long); \ - ASSIGN_OP_TYPE(OP, long long); \ - ASSIGN_OP_TYPE(OP, unsigned long long); +#define ASSIGN_OP(OP) \ + ASSIGN_OP_TYPE(OP, Tensor); \ + ASSIGN_OP_TYPE(OP, double); \ + ASSIGN_OP_TYPE(OP, float); \ + ASSIGN_OP_TYPE(OP, int); \ + ASSIGN_OP_TYPE(OP, unsigned); \ + ASSIGN_OP_TYPE(OP, bool); \ + ASSIGN_OP_TYPE(OP, char); \ + ASSIGN_OP_TYPE(OP, unsigned char); \ + ASSIGN_OP_TYPE(OP, short); \ + ASSIGN_OP_TYPE(OP, unsigned short); \ + ASSIGN_OP_TYPE(OP, long); \ + ASSIGN_OP_TYPE(OP, unsigned long); \ + ASSIGN_OP_TYPE(OP, long long); \ + ASSIGN_OP_TYPE(OP, unsigned long long); - ASSIGN_OP(assign); // = - ASSIGN_OP(inPlaceAdd); // += - ASSIGN_OP(inPlaceSubtract); // -= - ASSIGN_OP(inPlaceMultiply); // *= - ASSIGN_OP(inPlaceDivide); // /= + ASSIGN_OP(assign); // = + ASSIGN_OP(inPlaceAdd); // += + ASSIGN_OP(inPlaceSubtract); // -= + ASSIGN_OP(inPlaceMultiply); // *= + ASSIGN_OP(inPlaceDivide); ///= #undef ASSIGN_OP_TYPE #undef ASSIGN_OP }; diff --git a/flashlight/fl/tensor/backend/af/ArrayFireUnaryOps.cpp b/flashlight/fl/tensor/backend/af/ArrayFireUnaryOps.cpp index 11d237d..8415b8d 100644 --- 
a/flashlight/fl/tensor/backend/af/ArrayFireUnaryOps.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireUnaryOps.cpp @@ -15,114 +15,126 @@ namespace fl { Tensor ArrayFireBackend::exp(const Tensor& tensor) { - return toTensor(af::exp(toArray(tensor)), tensor.ndim()); + return toTensor(af::exp(toArray(tensor)), tensor.ndim()); } Tensor ArrayFireBackend::log(const Tensor& tensor) { - return toTensor(af::log(toArray(tensor)), tensor.ndim()); + return toTensor(af::log(toArray(tensor)), tensor.ndim()); } Tensor ArrayFireBackend::negative(const Tensor& tensor) { - return toTensor(-toArray(tensor), tensor.ndim()); + return toTensor(-toArray(tensor), tensor.ndim()); } Tensor ArrayFireBackend::logicalNot(const Tensor& tensor) { - return toTensor(!toArray(tensor), tensor.ndim()); + return toTensor(!toArray(tensor), tensor.ndim()); } Tensor ArrayFireBackend::log1p(const Tensor& tensor) { - return toTensor(af::log1p(toArray(tensor)), tensor.ndim()); + return toTensor(af::log1p(toArray(tensor)), tensor.ndim()); } Tensor ArrayFireBackend::sin(const Tensor& tensor) { - return toTensor(af::sin(toArray(tensor)), tensor.ndim()); + return toTensor(af::sin(toArray(tensor)), tensor.ndim()); } Tensor ArrayFireBackend::cos(const Tensor& tensor) { - return toTensor(af::cos(toArray(tensor)), tensor.ndim()); + return toTensor(af::cos(toArray(tensor)), tensor.ndim()); } Tensor ArrayFireBackend::sqrt(const Tensor& tensor) { - return toTensor(af::sqrt(toArray(tensor)), tensor.ndim()); + return toTensor(af::sqrt(toArray(tensor)), tensor.ndim()); } Tensor ArrayFireBackend::tanh(const Tensor& tensor) { - return toTensor(af::tanh(toArray(tensor)), tensor.ndim()); + return toTensor(af::tanh(toArray(tensor)), tensor.ndim()); } Tensor ArrayFireBackend::floor(const Tensor& tensor) { - return toTensor(af::floor(toArray(tensor)), tensor.ndim()); + return toTensor(af::floor(toArray(tensor)), tensor.ndim()); } Tensor ArrayFireBackend::ceil(const Tensor& tensor) { - return toTensor(af::ceil(toArray(tensor)), 
tensor.ndim()); + return toTensor(af::ceil(toArray(tensor)), tensor.ndim()); } Tensor ArrayFireBackend::rint(const Tensor& tensor) { - return toTensor(af::round(toArray(tensor)), tensor.ndim()); + return toTensor(af::round(toArray(tensor)), tensor.ndim()); } Tensor ArrayFireBackend::absolute(const Tensor& tensor) { - return toTensor(af::abs(toArray(tensor)), tensor.ndim()); + return toTensor(af::abs(toArray(tensor)), tensor.ndim()); } Tensor ArrayFireBackend::sigmoid(const Tensor& tensor) { - return toTensor(af::sigmoid(toArray(tensor)), tensor.ndim()); + return toTensor(af::sigmoid(toArray(tensor)), tensor.ndim()); } Tensor ArrayFireBackend::erf(const Tensor& tensor) { - return toTensor(af::erf(toArray(tensor)), tensor.ndim()); + return toTensor(af::erf(toArray(tensor)), tensor.ndim()); } Tensor ArrayFireBackend::flip(const Tensor& tensor, const unsigned dim) { - return toTensor( - af::flip(toArray(tensor), dim), tensor.ndim()); + return toTensor( + af::flip(toArray(tensor), dim), + tensor.ndim() + ); } Tensor ArrayFireBackend::clip( const Tensor& tensor, const Tensor& low, - const Tensor& high) { - return toTensor( - af::clamp(toArray(tensor), toArray(low), toArray(high)), tensor.ndim()); + const Tensor& high +) { + return toTensor( + af::clamp(toArray(tensor), toArray(low), toArray(high)), + tensor.ndim() + ); } Tensor ArrayFireBackend::roll( const Tensor& tensor, const int shift, - const unsigned axis) { - if (axis > AF_MAX_DIMS) { - throw std::invalid_argument( - "ArrayFireBackend::roll - given axis > 3 - unsupported"); - } - std::vector shifts(AF_MAX_DIMS, 0); - shifts[axis] = shift; - return toTensor( - af::shift(toArray(tensor), shifts[0], shifts[1], shifts[2], shifts[3]), - tensor.ndim()); + const unsigned axis +) { + if(axis > AF_MAX_DIMS) { + throw std::invalid_argument( + "ArrayFireBackend::roll - given axis > 3 - unsupported" + ); + } + std::vector shifts(AF_MAX_DIMS, 0); + shifts[axis] = shift; + return toTensor( + af::shift(toArray(tensor), 
shifts[0], shifts[1], shifts[2], shifts[3]), + tensor.ndim() + ); } Tensor ArrayFireBackend::isnan(const Tensor& tensor) { - return toTensor(af::isNaN(toArray(tensor)), tensor.ndim()); + return toTensor(af::isNaN(toArray(tensor)), tensor.ndim()); } Tensor ArrayFireBackend::isinf(const Tensor& tensor) { - return toTensor(af::isInf(toArray(tensor)), tensor.ndim()); + return toTensor(af::isInf(toArray(tensor)), tensor.ndim()); } Tensor ArrayFireBackend::sign(const Tensor& tensor) { - auto wSigned = 1 - 2 * af::sign(toArray(tensor)); - wSigned(toArray(tensor) == 0) = 0; - return toTensor(std::move(wSigned), tensor.ndim()); + auto wSigned = 1 - 2 * af::sign(toArray(tensor)); + wSigned(toArray(tensor) == 0) = 0; + return toTensor(std::move(wSigned), tensor.ndim()); } Tensor ArrayFireBackend::tril(const Tensor& tensor) { - return toTensor( - af::lower(toArray(tensor), /* is_unit_diag = */ false), tensor.ndim()); + return toTensor( + af::lower(toArray(tensor), /* is_unit_diag = */ false), + tensor.ndim() + ); } Tensor ArrayFireBackend::triu(const Tensor& tensor) { - return toTensor( - af::upper(toArray(tensor), /* is_unit_diag = */ false), tensor.ndim()); + return toTensor( + af::upper(toArray(tensor), /* is_unit_diag = */ false), + tensor.ndim() + ); } } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/Utils.cpp b/flashlight/fl/tensor/backend/af/Utils.cpp index 6b91d81..5f0bdfe 100644 --- a/flashlight/fl/tensor/backend/af/Utils.cpp +++ b/flashlight/fl/tensor/backend/af/Utils.cpp @@ -16,283 +16,299 @@ namespace fl::detail { af::dtype flToAfType(fl::dtype type) { - static const std::unordered_map - kFlashlightTypeToArrayFire = { - {fl::dtype::f16, af::dtype::f16}, - {fl::dtype::f32, af::dtype::f32}, - {fl::dtype::f64, af::dtype::f64}, - {fl::dtype::b8, af::dtype::b8}, - {fl::dtype::s16, af::dtype::s16}, - {fl::dtype::s32, af::dtype::s32}, - {fl::dtype::s64, af::dtype::s64}, - {fl::dtype::u8, af::dtype::u8}, - {fl::dtype::u16, af::dtype::u16}, - {fl::dtype::u32, 
af::dtype::u32}, - {fl::dtype::u64, af::dtype::u64}}; - return kFlashlightTypeToArrayFire.at(type); + static const std::unordered_map + kFlashlightTypeToArrayFire = { + {fl::dtype::f16, af::dtype::f16}, + {fl::dtype::f32, af::dtype::f32}, + {fl::dtype::f64, af::dtype::f64}, + {fl::dtype::b8, af::dtype::b8}, + {fl::dtype::s16, af::dtype::s16}, + {fl::dtype::s32, af::dtype::s32}, + {fl::dtype::s64, af::dtype::s64}, + {fl::dtype::u8, af::dtype::u8}, + {fl::dtype::u16, af::dtype::u16}, + {fl::dtype::u32, af::dtype::u32}, + {fl::dtype::u64, af::dtype::u64}}; + return kFlashlightTypeToArrayFire.at(type); } fl::dtype afToFlType(af::dtype type) { - static const std::unordered_map - kArrayFireTypeToFlashlight = { - {af::dtype::f16, fl::dtype::f16}, - {af::dtype::f32, fl::dtype::f32}, - {af::dtype::f64, fl::dtype::f64}, - {af::dtype::b8, fl::dtype::b8}, - {af::dtype::s16, fl::dtype::s16}, - {af::dtype::s32, fl::dtype::s32}, - {af::dtype::s64, fl::dtype::s64}, - {af::dtype::u8, fl::dtype::u8}, - {af::dtype::u16, fl::dtype::u16}, - {af::dtype::u32, fl::dtype::u32}, - {af::dtype::u64, fl::dtype::u64}}; - return kArrayFireTypeToFlashlight.at(type); + static const std::unordered_map + kArrayFireTypeToFlashlight = { + {af::dtype::f16, fl::dtype::f16}, + {af::dtype::f32, fl::dtype::f32}, + {af::dtype::f64, fl::dtype::f64}, + {af::dtype::b8, fl::dtype::b8}, + {af::dtype::s16, fl::dtype::s16}, + {af::dtype::s32, fl::dtype::s32}, + {af::dtype::s64, fl::dtype::s64}, + {af::dtype::u8, fl::dtype::u8}, + {af::dtype::u16, fl::dtype::u16}, + {af::dtype::u32, fl::dtype::u32}, + {af::dtype::u64, fl::dtype::u64}}; + return kArrayFireTypeToFlashlight.at(type); } af_mat_prop flToAfMatrixProperty(MatrixProperty property) { - switch (property) { - case MatrixProperty::None: - return AF_MAT_NONE; - case MatrixProperty::Transpose: - return AF_MAT_TRANS; - default: - throw std::invalid_argument( - "flToAfMatrixProperty: invalid property specified"); - } + switch(property) { + case 
MatrixProperty::None: + return AF_MAT_NONE; + case MatrixProperty::Transpose: + return AF_MAT_TRANS; + default: + throw std::invalid_argument( + "flToAfMatrixProperty: invalid property specified" + ); + } } af_storage flToAfStorageType(StorageType storageType) { - switch (storageType) { - case StorageType::Dense: - return AF_STORAGE_DENSE; - case StorageType::CSR: - return AF_STORAGE_CSR; - case StorageType::CSC: - return AF_STORAGE_CSC; - case StorageType::COO: - return AF_STORAGE_COO; - default: - throw std::invalid_argument( - "flToAfStorageType: Flashlight storage type " - "doesn't have an ArrayFire analog"); - } + switch(storageType) { + case StorageType::Dense: + return AF_STORAGE_DENSE; + case StorageType::CSR: + return AF_STORAGE_CSR; + case StorageType::CSC: + return AF_STORAGE_CSC; + case StorageType::COO: + return AF_STORAGE_COO; + default: + throw std::invalid_argument( + "flToAfStorageType: Flashlight storage type " + "doesn't have an ArrayFire analog" + ); + } } af_topk_function flToAfTopKSortMode(SortMode sortMode) { - switch (sortMode) { - case SortMode::Descending: - return AF_TOPK_MAX; - case SortMode::Ascending: - return AF_TOPK_MIN; - default: - throw std::invalid_argument( - "flToAfTopKSortMode: sort mode with no ArrayFire analog specified"); - } + switch(sortMode) { + case SortMode::Descending: + return AF_TOPK_MAX; + case SortMode::Ascending: + return AF_TOPK_MIN; + default: + throw std::invalid_argument( + "flToAfTopKSortMode: sort mode with no ArrayFire analog specified" + ); + } } af::dim4 flToAfDims(const Shape& shape) { - if (shape.ndim() > 4) { - throw std::invalid_argument( - "flToAfDims: ArrayFire shapes can't be more than 4 dimensions"); - } + if(shape.ndim() > 4) { + throw std::invalid_argument( + "flToAfDims: ArrayFire shapes can't be more than 4 dimensions" + ); + } - af::dim4 out(1, 1, 1, 1); - for (size_t i = 0; i < shape.ndim(); ++i) { - out.dims[i] = shape.dim(i); - } - return out; + af::dim4 out(1, 1, 1, 1); + for(size_t i = 
0; i < shape.ndim(); ++i) { + out.dims[i] = shape.dim(i); + } + return out; } void afToFlDims(const af::dim4& d, const unsigned numDims, Shape& s) { - if (numDims > AF_MAX_DIMS) { - throw std::invalid_argument("afToFlDims - numDims > AF_MAX_DIMS"); - } + if(numDims > AF_MAX_DIMS) { + throw std::invalid_argument("afToFlDims - numDims > AF_MAX_DIMS"); + } - auto& storage = s.get(); + auto& storage = s.get(); - // numdims constraint is enforced by the internal API per condenseDims - if (numDims == 1 && d.elements() == 0) { - // Empty tensor - storage.resize(1); - s[0] = 0; - return; - } + // numdims constraint is enforced by the internal API per condenseDims + if(numDims == 1 && d.elements() == 0) { + // Empty tensor + storage.resize(1); + s[0] = 0; + return; + } - // numDims == 0 --> scalar tensor - if (numDims == 0) { - storage.resize(0); - return; - } + // numDims == 0 --> scalar tensor + if(numDims == 0) { + storage.resize(0); + return; + } - storage.resize(numDims); - for (unsigned i = 0; i < numDims; ++i) { - s[i] = d[i]; - } + storage.resize(numDims); + for(unsigned i = 0; i < numDims; ++i) { + s[i] = d[i]; + } } Shape afToFlDims(const af::dim4& d, const unsigned numDims) { - Shape s; - afToFlDims(d, numDims, s); - return s; + Shape s; + afToFlDims(d, numDims, s); + return s; } af::seq flRangeToAfSeq(const fl::range& range) { - const int start = range.start(); - const auto& optEnd = range.end(); - const int end = optEnd.has_value() ? optEnd.value() - 1 : af::end; - // There could be have other empty sequence representations, e.g., (0, -1) - // for axis with 1 element. In those cases, AF will throw internally -- - // we can't throw here because these cases axis-size dependent. 
- if (optEnd.has_value() && optEnd.value() == start) { - throw std::runtime_error( - "flRangeToAfSeq: AF seq can't represent empty sequence"); - } - return af::seq(start, end, range.stride()); + const int start = range.start(); + const auto& optEnd = range.end(); + const int end = optEnd.has_value() ? optEnd.value() - 1 : af::end; + // There could be have other empty sequence representations, e.g., (0, -1) + // for axis with 1 element. In those cases, AF will throw internally -- + // we can't throw here because these cases axis-size dependent. + if(optEnd.has_value() && optEnd.value() == start) { + throw std::runtime_error( + "flRangeToAfSeq: AF seq can't represent empty sequence" + ); + } + return af::seq(start, end, range.stride()); } af::index flToAfIndex(const fl::Index& idx) { - switch (idx.type()) { - case IndexType::Tensor: - return af::index(toArray(idx.get())); - case IndexType::Span: - return af::index(af::span); - case IndexType::Range: - return af::index(flRangeToAfSeq(idx.get())); - case IndexType::Literal: - return af::index(idx.get()); - default: - throw std::invalid_argument( - "flToAfIndex: fl::Index has unknown or invalid type."); - } + switch(idx.type()) { + case IndexType::Tensor: + return af::index(toArray(idx.get())); + case IndexType::Span: + return af::index(af::span); + case IndexType::Range: + return af::index(flRangeToAfSeq(idx.get())); + case IndexType::Literal: + return af::index(idx.get()); + default: + throw std::invalid_argument( + "flToAfIndex: fl::Index has unknown or invalid type." 
+ ); + } } af::dim4 condenseDims(const af::dim4& dims) { - if (dims.elements() == 0) { - return af::dim4(0); - } + if(dims.elements() == 0) { + return af::dim4(0); + } - // Find the condensed shape - af::dim4 newDims(1, 1, 1, 1); - unsigned newDimIdx = 0; - for (unsigned i = 0; i < AF_MAX_DIMS; ++i) { - if (dims[i] != 1) { - // found a non-1 dim size - populate newDims - newDims[newDimIdx] = dims[i]; - newDimIdx++; + // Find the condensed shape + af::dim4 newDims(1, 1, 1, 1); + unsigned newDimIdx = 0; + for(unsigned i = 0; i < AF_MAX_DIMS; ++i) { + if(dims[i] != 1) { + // found a non-1 dim size - populate newDims + newDims[newDimIdx] = dims[i]; + newDimIdx++; + } } - } - return newDims; + return newDims; } af::array condenseIndices( const af::array& arr, const bool keepDims /* = false */, const std::optional>& indexTypes /* = {} */, - const bool isFlat /* = false */) { - // Fast path - return the Array as is if keepDims - don't consolidate - if (keepDims) { - return arr; - } - // Fast path - Array has zero elements or a dim of size zero - if (arr.elements() == 0) { - return arr; - } + const bool isFlat /* = false */ +) { + // Fast path - return the Array as is if keepDims - don't consolidate + if(keepDims) { + return arr; + } + // Fast path - Array has zero elements or a dim of size zero + if(arr.elements() == 0) { + return arr; + } - const af::dim4& dims = arr.dims(); - af::dim4 newDims(1, 1, 1, 1); - unsigned newDimIdx = 0; - for (unsigned i = 0; i < AF_MAX_DIMS; ++i) { - // If we're doing an index op (indexTypes is non-empty), then only collapse - // the dimension if it contains an index literal and we aren't doing flat - // indexing (which collapses all dims) - if (dims[i] == 1 && indexTypes && indexTypes.value().size() > i && - indexTypes.value()[i] != detail::IndexType::Literal && !isFlat) { - newDims[newDimIdx] = 1; - newDimIdx++; - } else if (dims[i] != 1) { - // found a non-1 dim size - populate newDims. 
- newDims[newDimIdx] = dims[i]; - newDimIdx++; + const af::dim4& dims = arr.dims(); + af::dim4 newDims(1, 1, 1, 1); + unsigned newDimIdx = 0; + for(unsigned i = 0; i < AF_MAX_DIMS; ++i) { + // If we're doing an index op (indexTypes is non-empty), then only collapse + // the dimension if it contains an index literal and we aren't doing flat + // indexing (which collapses all dims) + if( + dims[i] == 1 && indexTypes && indexTypes.value().size() > i + && indexTypes.value()[i] != detail::IndexType::Literal && !isFlat + ) { + newDims[newDimIdx] = 1; + newDimIdx++; + } else if(dims[i] != 1) { + // found a non-1 dim size - populate newDims. + newDims[newDimIdx] = dims[i]; + newDimIdx++; + } } - } - // Only change dims if condensing is possible - if (newDims != arr.dims()) { - return af::moddims(arr, newDims); - } else { - return arr; - } + // Only change dims if condensing is possible + if(newDims != arr.dims()) { + return af::moddims(arr, newDims); + } else { + return arr; + } } af_source flToAfLocation(Location location) { - switch (location) { - case Location::Host: - return afHost; - case Location::Device: - return afDevice; - default: - throw std::invalid_argument( - "flToAfLocation: no valid ArrayFire location exists " - " for given Flashlight location."); - } + switch(location) { + case Location::Host: + return afHost; + case Location::Device: + return afDevice; + default: + throw std::invalid_argument( + "flToAfLocation: no valid ArrayFire location exists " + " for given Flashlight location." 
+ ); + } } af::array fromFlData( const Shape& shape, const void* ptr, fl::dtype type, - fl::Location memoryLocation) { - af::dim4 dims = detail::flToAfDims(shape); - af::dtype afType = detail::flToAfType(type); - af_source loc = detail::flToAfLocation(memoryLocation); + fl::Location memoryLocation +) { + af::dim4 dims = detail::flToAfDims(shape); + af::dtype afType = detail::flToAfType(type); + af_source loc = detail::flToAfLocation(memoryLocation); - // No or null buffer - if (!ptr) { - return af::array(dims, afType); - } + // No or null buffer + if(!ptr) { + return af::array(dims, afType); + } - using af::dtype; - switch (afType) { - case f32: - return af::array(dims, reinterpret_cast(ptr), loc); - case f64: - return af::array(dims, reinterpret_cast(ptr), loc); - case s32: - return af::array(dims, reinterpret_cast(ptr), loc); - case u32: - return af::array(dims, reinterpret_cast(ptr), loc); - case s64: - return af::array(dims, reinterpret_cast(ptr), loc); - case u64: - return af::array( - dims, reinterpret_cast(ptr), loc); - case s16: - return af::array(dims, reinterpret_cast(ptr), loc); - case u16: - return af::array(dims, reinterpret_cast(ptr), loc); - case b8: - return af::array(dims, reinterpret_cast(ptr), loc); - case u8: - return af::array(dims, reinterpret_cast(ptr), loc); - default: - throw std::invalid_argument( - "fromFlData: can't construct ArrayFire array from given type."); - } + using af::dtype; + switch(afType) { + case f32: + return af::array(dims, reinterpret_cast(ptr), loc); + case f64: + return af::array(dims, reinterpret_cast(ptr), loc); + case s32: + return af::array(dims, reinterpret_cast(ptr), loc); + case u32: + return af::array(dims, reinterpret_cast(ptr), loc); + case s64: + return af::array(dims, reinterpret_cast(ptr), loc); + case u64: + return af::array( + dims, + reinterpret_cast(ptr), + loc + ); + case s16: + return af::array(dims, reinterpret_cast(ptr), loc); + case u16: + return af::array(dims, reinterpret_cast(ptr), loc); + case 
b8: + return af::array(dims, reinterpret_cast(ptr), loc); + case u8: + return af::array(dims, reinterpret_cast(ptr), loc); + default: + throw std::invalid_argument( + "fromFlData: can't construct ArrayFire array from given type." + ); + } } af_border_type flToAfPadType(PadType type) { - switch (type) { - case PadType::Constant: - return AF_PAD_ZERO; // constant padding --> zero padding in AF - case PadType::Edge: - return AF_PAD_CLAMP_TO_EDGE; - case PadType::Symmetric: - return AF_PAD_SYM; - default: - throw std::invalid_argument( - "flToAfPadType: Flashlight padding " - "type not supported by ArrayFire"); - } + switch(type) { + case PadType::Constant: + return AF_PAD_ZERO; // constant padding --> zero padding in AF + case PadType::Edge: + return AF_PAD_CLAMP_TO_EDGE; + case PadType::Symmetric: + return AF_PAD_SYM; + default: + throw std::invalid_argument( + "flToAfPadType: Flashlight padding " + "type not supported by ArrayFire" + ); + } } } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/Utils.h b/flashlight/fl/tensor/backend/af/Utils.h index e4d8873..0b299ba 100644 --- a/flashlight/fl/tensor/backend/af/Utils.h +++ b/flashlight/fl/tensor/backend/af/Utils.h @@ -19,15 +19,20 @@ #include "flashlight/fl/tensor/TensorBase.h" #include "flashlight/fl/tensor/Types.h" -#define AF_CHECK(fn) \ - do { \ - af_err __err = fn; \ - if (__err == AF_SUCCESS) { \ - break; \ - } \ - throw af::exception( \ - "ArrayFire error: ", __PRETTY_FUNCTION__, __FILE__, __LINE__, __err); \ - } while (0) +#define AF_CHECK(fn) \ + do { \ + af_err __err = fn; \ + if(__err == AF_SUCCESS) { \ + break; \ + } \ + throw af::exception( \ + "ArrayFire error: ", \ + __PRETTY_FUNCTION__, \ + __FILE__, \ + __LINE__, \ + __err \ + ); \ + } while(0) namespace fl { namespace detail { @@ -35,59 +40,59 @@ namespace detail { /** * Convert an fl::dtype into an ArrayFire af::dtype */ -af::dtype flToAfType(fl::dtype type); + af::dtype flToAfType(fl::dtype type); /** * Convert an ArrayFire af::dtype 
into an fl::dtype */ -fl::dtype afToFlType(af::dtype type); + fl::dtype afToFlType(af::dtype type); /** * Convert a Flashlight matrix property into an ArrayFire matrix property. */ -af_mat_prop flToAfMatrixProperty(MatrixProperty property); + af_mat_prop flToAfMatrixProperty(MatrixProperty property); /** * Convert a Flashlight tensor storage type into an ArrayFire storage type. */ -af_storage flToAfStorageType(StorageType storageType); + af_storage flToAfStorageType(StorageType storageType); /** * Convert a Flashlight tensor sort mode into an ArrayFire topk sort mode. */ -af_topk_function flToAfTopKSortMode(SortMode sortMode); + af_topk_function flToAfTopKSortMode(SortMode sortMode); /** * Convert an fl::Shape into an ArrayFire af::dim4 */ -af::dim4 flToAfDims(const Shape& shape); + af::dim4 flToAfDims(const Shape& shape); /** * Convert an ArrayFire af::dim4 into an fl::Shape */ -Shape afToFlDims(const af::dim4& d, const unsigned numDims); + Shape afToFlDims(const af::dim4& d, const unsigned numDims); /** * Convert an ArrayFire af::dim4 into an fl::Shape, in-place */ -void afToFlDims(const af::dim4& d, const unsigned numDims, Shape& s); + void afToFlDims(const af::dim4& d, const unsigned numDims, Shape& s); /** * Convert an fl::range into an af::seq. */ -af::seq flRangeToAfSeq(const fl::range& range); + af::seq flRangeToAfSeq(const fl::range& range); /** * Convert an fl::Index into an af::index. */ -af::index flToAfIndex(const fl::Index& idx); + af::index flToAfIndex(const fl::Index& idx); -std::vector flToAfIndices(const std::vector& flIndices); + std::vector flToAfIndices(const std::vector& flIndices); /** * Strip leading 1 indices from an ArrayFire dim4. 
*/ -af::dim4 condenseDims(const af::dim4& dims); + af::dim4 condenseDims(const af::dim4& dims); /** * Modify the dimensions (in place via af::moddims) or an Array to have no 1 @@ -98,31 +103,33 @@ af::dim4 condenseDims(const af::dim4& dims); * * If keepDims is true, this is a noop, and the array is returned as is. */ -af::array condenseIndices( - const af::array& arr, - const bool keepDims = false, - const std::optional>& indexTypes = {}, - const bool isFlat = false); + af::array condenseIndices( + const af::array& arr, + const bool keepDims = false, + const std::optional>& indexTypes = {}, + const bool isFlat = false + ); /** * Convert a Flashlight Location into an ArrayFire location (host or device). */ -af_source flToAfLocation(Location location); + af_source flToAfLocation(Location location); /** * Construct an ArrayFire array from a buffer and Flashlight details. */ -af::array fromFlData( - const Shape& shape, - const void* ptr, - fl::dtype type, - fl::Location memoryLocation); + af::array fromFlData( + const Shape& shape, + const void* ptr, + fl::dtype type, + fl::Location memoryLocation + ); /** * Convert a Flashlight PadType to an ArrayFire af_border_type for describing * padding. 
*/ -af_border_type flToAfPadType(PadType type); + af_border_type flToAfPadType(PadType type); } // namespace detail } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp b/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp index c19ebaf..7dc12ce 100644 --- a/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp +++ b/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp @@ -22,423 +22,439 @@ namespace fl { namespace { -constexpr size_t kMinBlockSize = - 512; // all sizes are rounded to at least 512 bytes -constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB -constexpr size_t kSmallBuffer = - 2097152; // "small" allocations are packed in 2 MiB blocks -constexpr size_t kLargeBuffer = - 20971520; // "large" allocations may be packed in 20 MiB blocks -constexpr size_t kMinLargeAlloc = - 10485760; // allocations between 1 and 10 MiB may use kLargeBuffer -constexpr size_t kRoundLarge = 2097152; // round up large allocs to 2 MiB + constexpr size_t kMinBlockSize = + 512; // all sizes are rounded to at least 512 bytes + constexpr size_t kSmallSize = 1048576; // largest "small" allocation is 1 MiB + constexpr size_t kSmallBuffer = + 2097152; // "small" allocations are packed in 2 MiB blocks + constexpr size_t kLargeBuffer = + 20971520; // "large" allocations may be packed in 20 MiB blocks + constexpr size_t kMinLargeAlloc = + 10485760; // allocations between 1 and 10 MiB may use kLargeBuffer + constexpr size_t kRoundLarge = 2097152; // round up large allocs to 2 MiB // Environment variables names, specifying number of mega bytes as floats. 
-constexpr const char* kMemRecyclingSize = "FL_MEM_RECYCLING_SIZE_MB"; -constexpr const char* kMemSplitSize = "FL_MEM_SPLIT_SIZE_MB"; -constexpr double kMB = static_cast(1UL << 20); - -size_t roundSize(size_t size) { - if (size < kMinBlockSize) { - return kMinBlockSize; - } else { - return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize); - } -} + constexpr const char* kMemRecyclingSize = "FL_MEM_RECYCLING_SIZE_MB"; + constexpr const char* kMemSplitSize = "FL_MEM_SPLIT_SIZE_MB"; + constexpr double kMB = static_cast(1UL << 20); + + size_t roundSize(size_t size) { + if(size < kMinBlockSize) { + return kMinBlockSize; + } else { + return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize); + } + } -size_t getAllocationSize(size_t size) { - if (size <= kSmallSize) { - return kSmallBuffer; - } else if (size < kMinLargeAlloc) { - return kLargeBuffer; - } else { - return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge); - } -} + size_t getAllocationSize(size_t size) { + if(size <= kSmallSize) { + return kSmallBuffer; + } else if(size < kMinLargeAlloc) { + return kLargeBuffer; + } else { + return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge); + } + } -static bool BlockComparator( - const CachingMemoryManager::Block* a, - const CachingMemoryManager::Block* b) { - if (a->size_ != b->size_) { - return a->size_ < b->size_; - } - return (uintptr_t)a->ptr_ < (uintptr_t)b->ptr_; -} + static bool BlockComparator( + const CachingMemoryManager::Block* a, + const CachingMemoryManager::Block* b + ) { + if(a->size_ != b->size_) { + return a->size_ < b->size_; + } + return (uintptr_t) a->ptr_ < (uintptr_t) b->ptr_; + } -std::string formatMemory(size_t bytes) { - const std::vector units = {"B", "KiB", "MiB", "GiB", "TiB"}; - size_t unitId = - bytes == 0 ? 
0 : std::floor(std::log(bytes) / std::log(1024.0)); - unitId = std::min(unitId, units.size() - 1); - std::string bytesStr = std::to_string(bytes / std::pow(1024.0, unitId)); - bytesStr = bytesStr.substr(0, bytesStr.find('.') + 3); - return bytesStr + " " + units[unitId]; -} + std::string formatMemory(size_t bytes) { + const std::vector units = {"B", "KiB", "MiB", "GiB", "TiB"}; + size_t unitId = + bytes == 0 ? 0 : std::floor(std::log(bytes) / std::log(1024.0)); + unitId = std::min(unitId, units.size() - 1); + std::string bytesStr = std::to_string(bytes / std::pow(1024.0, unitId)); + bytesStr = bytesStr.substr(0, bytesStr.find('.') + 3); + return bytesStr + " " + units[unitId]; + } /** * Returns number of bytes as represented by the named environment variable. The * variable is interperested as a float string specifying value in MBs. Returns * defaultVal on failure to read the variable or parse its value. */ -size_t getEnvAsBytesFromFloatMb(const char* name, size_t defaultVal) { - const char* env = std::getenv(name); - if (env) { - try { - const double mb = std::stod(env); - return std::round(mb * kMB); - } catch (std::exception& ex) { - std::cerr << "getEnvAsBytesFromFloatMb: Invalid environment " + size_t getEnvAsBytesFromFloatMb(const char* name, size_t defaultVal) { + const char* env = std::getenv(name); + if(env) { + try { + const double mb = std::stod(env); + return std::round(mb * kMB); + } catch(std::exception& ex) { + std::cerr << "getEnvAsBytesFromFloatMb: Invalid environment " << "variable value: name=" << name << " value=" << env; - throw ex; + throw ex; + } + } + return defaultVal; } - } - return defaultVal; -} } // namespace -CachingMemoryManager::DeviceMemoryInfo::DeviceMemoryInfo(int id) - : deviceId_(id), - largeBlocks_(BlockComparator), - smallBlocks_(BlockComparator) {} +CachingMemoryManager::DeviceMemoryInfo::DeviceMemoryInfo(int id) : deviceId_(id), + largeBlocks_(BlockComparator), + smallBlocks_(BlockComparator) {} 
CachingMemoryManager::CachingMemoryManager( int numDevices, - std::shared_ptr deviceInterface) - : MemoryManagerAdapter(deviceInterface) { - recyclingSizeLimit_ = - getEnvAsBytesFromFloatMb(kMemRecyclingSize, recyclingSizeLimit_); - splitSizeLimit_ = getEnvAsBytesFromFloatMb(kMemSplitSize, splitSizeLimit_); - - for (int i = 0; i < numDevices; ++i) { - deviceMemInfos_.emplace( - i, std::make_unique(i)); - } + std::shared_ptr deviceInterface +) : MemoryManagerAdapter(deviceInterface) { + recyclingSizeLimit_ = + getEnvAsBytesFromFloatMb(kMemRecyclingSize, recyclingSizeLimit_); + splitSizeLimit_ = getEnvAsBytesFromFloatMb(kMemSplitSize, splitSizeLimit_); + + for(int i = 0; i < numDevices; ++i) { + deviceMemInfos_.emplace( + i, + std::make_unique(i) + ); + } } void CachingMemoryManager::initialize() {} void CachingMemoryManager::setRecyclingSizeLimit(size_t limit) { - recyclingSizeLimit_ = limit; + recyclingSizeLimit_ = limit; } void CachingMemoryManager::setSplitSizeLimit(size_t limit) { - splitSizeLimit_ = limit; + splitSizeLimit_ = limit; } void CachingMemoryManager::shutdown() { - signalMemoryCleanup(); + signalMemoryCleanup(); } void CachingMemoryManager::addMemoryManagement(int device) { - if (deviceMemInfos_.find(device) != deviceMemInfos_.end()) { - return; - } - deviceMemInfos_.emplace( - device, std::make_unique(device)); + if(deviceMemInfos_.find(device) != deviceMemInfos_.end()) { + return; + } + deviceMemInfos_.emplace( + device, + std::make_unique(device) + ); } void CachingMemoryManager::removeMemoryManagement(int device) { - if (deviceMemInfos_.find(device) == deviceMemInfos_.end()) { - return; - } - deviceMemInfos_.erase(device); + if(deviceMemInfos_.find(device) == deviceMemInfos_.end()) { + return; + } + deviceMemInfos_.erase(device); } void* CachingMemoryManager::alloc( bool userLock, const unsigned ndims, dim_t* dims, - const unsigned elementSize) { - auto& memoryInfo = getDeviceMemoryInfo(); - std::lock_guard lock(memoryInfo.mutexAll_); - size_t 
size = elementSize; - for (unsigned i = 0; i < ndims; ++i) { - size *= dims[i]; - } - if (size == 0) { - return nullptr; - } - size = roundSize(size); - const bool isSmallAlloc = (size <= kSmallSize); - CachingMemoryManager::Block searchKey(size); - CachingMemoryManager::BlockSet& pool = - isSmallAlloc ? memoryInfo.smallBlocks_ : memoryInfo.largeBlocks_; - - CachingMemoryManager::Block* block = nullptr; - auto it = pool.lower_bound(&searchKey); - // Recycle blocks if any found, and if small alloc or the block size is not - // too large: - if (it != pool.end() && - (isSmallAlloc || (*it)->size_ < recyclingSizeLimit_)) { - block = *it; - pool.erase(it); - memoryInfo.stats_.cachedBytes_ -= block->size_; - } else { - void* ptr = nullptr; - size_t allocSize = getAllocationSize(size); - mallocWithRetry(allocSize, &ptr); // could throw - block = new Block(allocSize, ptr); - memoryInfo.stats_.allocatedBytes_ += allocSize; - } - - // If the block is larger than the requested size to handle another - // allocation in the same large or small BlockSet, it will be split into two. - // Note that we don't split a small stepsize out of a large one to keep the - // implementation simple. - CachingMemoryManager::Block* remaining = nullptr; - size_t diff = block->size_ - size; - if ((diff >= (isSmallAlloc ? 
kMinBlockSize : kSmallSize)) && - (block->size_ < splitSizeLimit_) // possibly dont split large buffers to - // minimize risk of fragmentation - ) { - remaining = block; - block = new Block(size, block->ptr_); - block->prev_ = remaining->prev_; - if (block->prev_) { - block->prev_->next_ = block; + const unsigned elementSize +) { + auto& memoryInfo = getDeviceMemoryInfo(); + std::lock_guard lock(memoryInfo.mutexAll_); + size_t size = elementSize; + for(unsigned i = 0; i < ndims; ++i) { + size *= dims[i]; + } + if(size == 0) { + return nullptr; } - block->next_ = remaining; - - remaining->prev_ = block; - remaining->ptr_ = static_cast(remaining->ptr_) + size; - remaining->size_ -= size; - pool.insert(remaining); - memoryInfo.stats_.cachedBytes_ += remaining->size_; - } - - block->managerLock_ = !userLock; - block->userLock_ = userLock; - memoryInfo.allocatedBlocks_[block->ptr_] = block; - return static_cast(block->ptr_); + size = roundSize(size); + const bool isSmallAlloc = (size <= kSmallSize); + CachingMemoryManager::Block searchKey(size); + CachingMemoryManager::BlockSet& pool = + isSmallAlloc ? memoryInfo.smallBlocks_ : memoryInfo.largeBlocks_; + + CachingMemoryManager::Block* block = nullptr; + auto it = pool.lower_bound(&searchKey); + // Recycle blocks if any found, and if small alloc or the block size is not + // too large: + if( + it != pool.end() + && (isSmallAlloc || (*it)->size_ < recyclingSizeLimit_) + ) { + block = *it; + pool.erase(it); + memoryInfo.stats_.cachedBytes_ -= block->size_; + } else { + void* ptr = nullptr; + size_t allocSize = getAllocationSize(size); + mallocWithRetry(allocSize, &ptr); // could throw + block = new Block(allocSize, ptr); + memoryInfo.stats_.allocatedBytes_ += allocSize; + } + + // If the block is larger than the requested size to handle another + // allocation in the same large or small BlockSet, it will be split into two. 
+ // Note that we don't split a small stepsize out of a large one to keep the + // implementation simple. + CachingMemoryManager::Block* remaining = nullptr; + size_t diff = block->size_ - size; + if( + (diff >= (isSmallAlloc ? kMinBlockSize : kSmallSize)) + && (block->size_ < splitSizeLimit_) // possibly dont split large buffers to + // minimize risk of fragmentation + ) { + remaining = block; + block = new Block(size, block->ptr_); + block->prev_ = remaining->prev_; + if(block->prev_) { + block->prev_->next_ = block; + } + block->next_ = remaining; + + remaining->prev_ = block; + remaining->ptr_ = static_cast(remaining->ptr_) + size; + remaining->size_ -= size; + pool.insert(remaining); + memoryInfo.stats_.cachedBytes_ += remaining->size_; + } + + block->managerLock_ = !userLock; + block->userLock_ = userLock; + memoryInfo.allocatedBlocks_[block->ptr_] = block; + return static_cast(block->ptr_); } size_t CachingMemoryManager::allocated(void* ptr) { - if (!ptr) { - return 0; - } - auto& memoryInfo = getDeviceMemoryInfo(); - std::lock_guard lock(memoryInfo.mutexAll_); - auto it = memoryInfo.allocatedBlocks_.find(ptr); - if (it == memoryInfo.allocatedBlocks_.end()) { - return 0; - } - return (it->second)->size_; + if(!ptr) { + return 0; + } + auto& memoryInfo = getDeviceMemoryInfo(); + std::lock_guard lock(memoryInfo.mutexAll_); + auto it = memoryInfo.allocatedBlocks_.find(ptr); + if(it == memoryInfo.allocatedBlocks_.end()) { + return 0; + } + return (it->second)->size_; } void CachingMemoryManager::unlock(void* ptr, bool userUnlock) { - if (!ptr) { - return; - } - auto& memoryInfo = getDeviceMemoryInfo(); - std::lock_guard lock(memoryInfo.mutexAll_); - auto it = memoryInfo.allocatedBlocks_.find(ptr); - if (it == memoryInfo.allocatedBlocks_.end()) { - // Probably came from user, just free it - this->deviceInterface->nativeFree(ptr); - ++memoryInfo.stats_.totalNativeFrees_; - return; - } - - CachingMemoryManager::Block* block = it->second; - if (userUnlock) { - 
block->userLock_ = false; - } else { - block->managerLock_ = false; - } - - // Return early if either one is locked - if (block->inUse()) { - return; - } - memoryInfo.allocatedBlocks_.erase(it); - freeBlock(block); + if(!ptr) { + return; + } + auto& memoryInfo = getDeviceMemoryInfo(); + std::lock_guard lock(memoryInfo.mutexAll_); + auto it = memoryInfo.allocatedBlocks_.find(ptr); + if(it == memoryInfo.allocatedBlocks_.end()) { + // Probably came from user, just free it + this->deviceInterface->nativeFree(ptr); + ++memoryInfo.stats_.totalNativeFrees_; + return; + } + + CachingMemoryManager::Block* block = it->second; + if(userUnlock) { + block->userLock_ = false; + } else { + block->managerLock_ = false; + } + + // Return early if either one is locked + if(block->inUse()) { + return; + } + memoryInfo.allocatedBlocks_.erase(it); + freeBlock(block); } void CachingMemoryManager::freeBlock(CachingMemoryManager::Block* block) { - if (block->inUse()) { - throw std::runtime_error("trying to free a block which is in use"); - } - auto& memoryInfo = getDeviceMemoryInfo(); - std::lock_guard lock(memoryInfo.mutexAll_); - - const bool isSmallAlloc = (block->size_ <= kSmallSize); - CachingMemoryManager::BlockSet& pool = - isSmallAlloc ? memoryInfo.smallBlocks_ : memoryInfo.largeBlocks_; - tryMergeBlocks(block, block->prev_, pool); - tryMergeBlocks(block, block->next_, pool); - - pool.insert(block); - memoryInfo.stats_.cachedBytes_ += block->size_; + if(block->inUse()) { + throw std::runtime_error("trying to free a block which is in use"); + } + auto& memoryInfo = getDeviceMemoryInfo(); + std::lock_guard lock(memoryInfo.mutexAll_); + + const bool isSmallAlloc = (block->size_ <= kSmallSize); + CachingMemoryManager::BlockSet& pool = + isSmallAlloc ? 
memoryInfo.smallBlocks_ : memoryInfo.largeBlocks_; + tryMergeBlocks(block, block->prev_, pool); + tryMergeBlocks(block, block->next_, pool); + + pool.insert(block); + memoryInfo.stats_.cachedBytes_ += block->size_; } /** combine previously split blocks */ void CachingMemoryManager::tryMergeBlocks( CachingMemoryManager::Block* dst, CachingMemoryManager::Block* src, - BlockSet& pool) { - if (!src || src->inUse()) { - return; - } - if (dst->prev_ == src) { - dst->ptr_ = src->ptr_; - dst->prev_ = src->prev_; - if (dst->prev_) { - dst->prev_->next_ = dst; + BlockSet& pool +) { + if(!src || src->inUse()) { + return; } - } else { - dst->next_ = src->next_; - if (dst->next_) { - dst->next_->prev_ = dst; + if(dst->prev_ == src) { + dst->ptr_ = src->ptr_; + dst->prev_ = src->prev_; + if(dst->prev_) { + dst->prev_->next_ = dst; + } + } else { + dst->next_ = src->next_; + if(dst->next_) { + dst->next_->prev_ = dst; + } } - } - dst->size_ += src->size_; - pool.erase(src); - getDeviceMemoryInfo().stats_.cachedBytes_ -= src->size_; - delete src; + dst->size_ += src->size_; + pool.erase(src); + getDeviceMemoryInfo().stats_.cachedBytes_ -= src->size_; + delete src; } void CachingMemoryManager::mallocWithRetry(size_t size, void** ptr) { - // Try nativeMalloc. If nativeMalloc fails, frees all non-split cached blocks - // and retries. - auto& memInfo = getDeviceMemoryInfo(); - try { - ++memInfo.stats_.totalNativeMallocs_; - *ptr = this->deviceInterface->nativeAlloc(size); - } catch (std::exception&) { + // Try nativeMalloc. If nativeMalloc fails, frees all non-split cached blocks + // and retries. 
+ auto& memInfo = getDeviceMemoryInfo(); try { - signalMemoryCleanup(); - ++memInfo.stats_.totalNativeMallocs_; - *ptr = this->deviceInterface->nativeAlloc(size); - } catch (std::exception& ex) { - // note: af exception inherits from std exception - std::cerr << "Failed to allocate memory of size " << formatMemory(size) - << " (Device: " << memInfo.deviceId_ << ", Capacity: " - << formatMemory(this->deviceInterface->getMaxMemorySize( - memInfo.deviceId_)) - << ", Allocated: " - << formatMemory(memInfo.stats_.allocatedBytes_) - << ", Cached: " << formatMemory(memInfo.stats_.cachedBytes_) - << ") with error '" << ex.what() << "'" << std::endl; - // note: converting here an af exception to std exception prevents to - // catch the af error code at the user level. Rethrowing. - throw; + ++memInfo.stats_.totalNativeMallocs_; + *ptr = this->deviceInterface->nativeAlloc(size); + } catch(std::exception&) { + try { + signalMemoryCleanup(); + ++memInfo.stats_.totalNativeMallocs_; + *ptr = this->deviceInterface->nativeAlloc(size); + } catch(std::exception& ex) { + // note: af exception inherits from std exception + std::cerr << "Failed to allocate memory of size " << formatMemory(size) + << " (Device: " << memInfo.deviceId_ << ", Capacity: " + << formatMemory( + this->deviceInterface->getMaxMemorySize( + memInfo.deviceId_ + ) + ) + << ", Allocated: " + << formatMemory(memInfo.stats_.allocatedBytes_) + << ", Cached: " << formatMemory(memInfo.stats_.cachedBytes_) + << ") with error '" << ex.what() << "'" << std::endl; + // note: converting here an af exception to std exception prevents to + // catch the af error code at the user level. Rethrowing. 
+ throw; + } } - } } void CachingMemoryManager::freeBlocks( BlockSet& blocks, BlockSet::iterator it, - BlockSet::iterator end) { - // Frees all non-split blocks between `it` and `end` - auto& memoryInfo = getDeviceMemoryInfo(); - while (it != end) { - Block* block = *it; - if (!block->isSplit()) { - this->deviceInterface->nativeFree(static_cast(block->ptr_)); - ++memoryInfo.stats_.totalNativeFrees_; - memoryInfo.stats_.allocatedBytes_ -= block->size_; - memoryInfo.stats_.cachedBytes_ -= block->size_; - auto cur = it; - ++it; - blocks.erase(cur); - delete block; - } else { - ++it; + BlockSet::iterator end +) { + // Frees all non-split blocks between `it` and `end` + auto& memoryInfo = getDeviceMemoryInfo(); + while(it != end) { + Block* block = *it; + if(!block->isSplit()) { + this->deviceInterface->nativeFree(static_cast(block->ptr_)); + ++memoryInfo.stats_.totalNativeFrees_; + memoryInfo.stats_.allocatedBytes_ -= block->size_; + memoryInfo.stats_.cachedBytes_ -= block->size_; + auto cur = it; + ++it; + blocks.erase(cur); + delete block; + } else { + ++it; + } } - } } void CachingMemoryManager::signalMemoryCleanup() { - // Free all non-split cached blocks on device - auto& memoryInfo = getDeviceMemoryInfo(); - std::lock_guard lock(memoryInfo.mutexAll_); - - freeBlocks( - memoryInfo.largeBlocks_, - memoryInfo.largeBlocks_.begin(), - memoryInfo.largeBlocks_.end()); - - freeBlocks( - memoryInfo.smallBlocks_, - memoryInfo.smallBlocks_.begin(), - memoryInfo.smallBlocks_.end()); + // Free all non-split cached blocks on device + auto& memoryInfo = getDeviceMemoryInfo(); + std::lock_guard lock(memoryInfo.mutexAll_); + + freeBlocks( + memoryInfo.largeBlocks_, + memoryInfo.largeBlocks_.begin(), + memoryInfo.largeBlocks_.end() + ); + + freeBlocks( + memoryInfo.smallBlocks_, + memoryInfo.smallBlocks_.begin(), + memoryInfo.smallBlocks_.end() + ); } float CachingMemoryManager::getMemoryPressure() { - return 0.0; // TODO: check if this is optimal + return 0.0; // TODO: check if 
this is optimal } bool CachingMemoryManager::jitTreeExceedsMemoryPressure(size_t /* unused */) { - return false; // TODO: check if this is optimal + return false; // TODO: check if this is optimal } void CachingMemoryManager::printInfo( const char* msg, const int /* unused */, - std::ostream* _ostream) { - std::ostream& ostream = *_ostream; - auto& memInfo = getDeviceMemoryInfo(); - std::lock_guard lock(memInfo.mutexAll_); - - ostream << msg << "\nType: CachingMemoryManager" << std::endl - << "\nDevice: " << memInfo.deviceId_ << ", Capacity: " - << formatMemory( - this->deviceInterface->getMaxMemorySize(memInfo.deviceId_)) - << ", Allocated: " << formatMemory(memInfo.stats_.allocatedBytes_) - << ", Cached: " << formatMemory(memInfo.stats_.cachedBytes_) - << std::endl - << "\nTotal native calls: " << memInfo.stats_.totalNativeMallocs_ - << "(mallocs), " << memInfo.stats_.totalNativeFrees_ << "(frees)" - << std::endl; + std::ostream* _ostream +) { + std::ostream& ostream = *_ostream; + auto& memInfo = getDeviceMemoryInfo(); + std::lock_guard lock(memInfo.mutexAll_); + + ostream << msg << "\nType: CachingMemoryManager" << std::endl + << "\nDevice: " << memInfo.deviceId_ << ", Capacity: " + << formatMemory( + this->deviceInterface->getMaxMemorySize(memInfo.deviceId_) + ) + << ", Allocated: " << formatMemory(memInfo.stats_.allocatedBytes_) + << ", Cached: " << formatMemory(memInfo.stats_.cachedBytes_) + << std::endl + << "\nTotal native calls: " << memInfo.stats_.totalNativeMallocs_ + << "(mallocs), " << memInfo.stats_.totalNativeFrees_ << "(frees)" + << std::endl; } void CachingMemoryManager::userLock(const void* ptr) { - if (!ptr) { - return; - } - auto& memoryInfo = getDeviceMemoryInfo(); - std::lock_guard lock(memoryInfo.mutexAll_); - - auto it = memoryInfo.allocatedBlocks_.find(const_cast(ptr)); - if (it == memoryInfo.allocatedBlocks_.end()) { - // Follows the behavior of DefaultMemoryManager - auto block = new Block(kSmallBuffer, const_cast(ptr)); - 
block->managerLock_ = false; - block->userLock_ = true; - memoryInfo.allocatedBlocks_[block->ptr_] = block; - } else { - it->second->userLock_ = true; - } + if(!ptr) { + return; + } + auto& memoryInfo = getDeviceMemoryInfo(); + std::lock_guard lock(memoryInfo.mutexAll_); + + auto it = memoryInfo.allocatedBlocks_.find(const_cast(ptr)); + if(it == memoryInfo.allocatedBlocks_.end()) { + // Follows the behavior of DefaultMemoryManager + auto block = new Block(kSmallBuffer, const_cast(ptr)); + block->managerLock_ = false; + block->userLock_ = true; + memoryInfo.allocatedBlocks_[block->ptr_] = block; + } else { + it->second->userLock_ = true; + } } void CachingMemoryManager::userUnlock(const void* ptr) { - this->unlock(const_cast(ptr), true); + this->unlock(const_cast(ptr), true); } bool CachingMemoryManager::isUserLocked(const void* ptr) { - if (!ptr) { - return false; - } - auto& memoryInfo = getDeviceMemoryInfo(); - std::lock_guard lock(memoryInfo.mutexAll_); - auto it = memoryInfo.allocatedBlocks_.find(const_cast(ptr)); - if (it == memoryInfo.allocatedBlocks_.end()) { - return false; - } - return it->second->userLock_; + if(!ptr) { + return false; + } + auto& memoryInfo = getDeviceMemoryInfo(); + std::lock_guard lock(memoryInfo.mutexAll_); + auto it = memoryInfo.allocatedBlocks_.find(const_cast(ptr)); + if(it == memoryInfo.allocatedBlocks_.end()) { + return false; + } + return it->second->userLock_; } -CachingMemoryManager::DeviceMemoryInfo& -CachingMemoryManager::getDeviceMemoryInfo(int device /* = -1*/) { - if (device == -1) { - device = this->deviceInterface->getActiveDeviceId(); - } - auto it = deviceMemInfos_.find(device); - if (it == deviceMemInfos_.end() || !it->second) { - throw std::runtime_error("meminfo for the device doesn't exist"); - } - return *(it->second); +CachingMemoryManager::DeviceMemoryInfo& CachingMemoryManager::getDeviceMemoryInfo(int device /* = -1*/) { + if(device == -1) { + device = this->deviceInterface->getActiveDeviceId(); + } + auto it 
= deviceMemInfos_.find(device); + if(it == deviceMemInfos_.end() || !it->second) { + throw std::runtime_error("meminfo for the device doesn't exist"); + } + return *(it->second); } } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.h b/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.h index be8c841..57ea899 100644 --- a/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.h +++ b/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.h @@ -29,130 +29,132 @@ namespace fl { * https://github.com/pytorch/pytorch/blob/master/c10/cuda/CUDACachingAllocator.cpp */ class CachingMemoryManager : public MemoryManagerAdapter { - public: - CachingMemoryManager( - int numDevices, - std::shared_ptr deviceInterface); - ~CachingMemoryManager() override = default; - void initialize() override; - void shutdown() override; - void* alloc( - bool userLock, - const unsigned ndims, - dim_t* dims, - const unsigned elSize) override; - size_t allocated(void* ptr) override; - void unlock(void* ptr, bool userLock) override; - void printInfo( - const char* msg, - const int device, - std::ostream* ostream = &std::cout) override; - void userLock(const void* ptr) override; - void userUnlock(const void* ptr) override; - bool isUserLocked(const void* ptr) override; - void signalMemoryCleanup() override; - float getMemoryPressure() override; - bool jitTreeExceedsMemoryPressure(size_t bytes) override; - void addMemoryManagement(int device) override; - void removeMemoryManagement(int device) override; - // Set runtime options: RecyclingSizeLimit, SplitSizeLimit, ... Warning: not - // thread safe - void setRecyclingSizeLimit(size_t); - void setSplitSizeLimit(size_t); - - // Block denotes a single allocated unit of memory. 
- struct Block { - size_t size_; // size of block in bytes - void* ptr_; // memory address - bool managerLock_; // whether the memory is locked by the memory manager - bool userLock_; // whether the memory is locked by the user - Block* prev_; // prev block if split from a larger allocation - Block* next_; // next block if split from a larger allocation - - bool isSplit() const { - return (prev_ != nullptr) || (next_ != nullptr); - } - - bool inUse() const { - return managerLock_ || userLock_; - } - - explicit Block(size_t size, void* ptr = nullptr) - : size_(size), - ptr_(ptr), - managerLock_(false), - userLock_(false), - prev_(nullptr), - next_(nullptr) {} - }; - - typedef bool (*Comparison)(const Block*, const Block*); - typedef std::set BlockSet; - - // A structure to store allocation stats per device. - struct MemoryAllocationStats { - size_t totalNativeMallocs_; - size_t totalNativeFrees_; - size_t allocatedBytes_; // memory allocated by mem manager for the program - size_t cachedBytes_; // memory held by mem manager & not used by the program - - MemoryAllocationStats() - : totalNativeMallocs_(0), - totalNativeFrees_(0), - allocatedBytes_(0), - cachedBytes_(0) {} - }; - - // Stores the mutex and misc variables per device so that we operate in a - // thredsafe manner. 
- struct DeviceMemoryInfo { - int deviceId_; - - // lock around all operations - std::recursive_mutex mutexAll_; // TODO:: improve perf using R/W locks - - // cached blocks larger than 1 MB - BlockSet largeBlocks_; - - // cached blocks 1 MB or smaller - BlockSet smallBlocks_; - - // allocated blocks by device pointer - std::unordered_map allocatedBlocks_; - - MemoryAllocationStats stats_; - - explicit DeviceMemoryInfo(int id); - }; - - protected: - std::unordered_map> deviceMemInfos_; - - CachingMemoryManager(const CachingMemoryManager& other) = delete; - CachingMemoryManager(CachingMemoryManager&& other) = delete; - CachingMemoryManager& operator=(const CachingMemoryManager& other) = delete; - CachingMemoryManager& operator=(CachingMemoryManager&& other) = delete; - - // Returns the memory info of the caching allocator for the given device. - // Using "-1" will return info for the current active device. - DeviceMemoryInfo& getDeviceMemoryInfo(int device = -1); - - void - freeBlocks(BlockSet& blocks, BlockSet::iterator it, BlockSet::iterator end); - - void mallocWithRetry(size_t size, void** ptr); - - void tryMergeBlocks(Block* dst, Block* src, BlockSet& freeBlocks); - void freeBlock(Block* block); - - private: - // Non-const runtime options in order to fine tune the behavior of this - // manager. 
Prevents to recycle some buffers, to be set by the user if - // desired: - size_t recyclingSizeLimit_{std::numeric_limits::max()}; - // size_t recyclingSizeLimit; - // Prevents to split big buffers, to be set by the user if desired: - size_t splitSizeLimit_{std::numeric_limits::max()}; +public: + CachingMemoryManager( + int numDevices, + std::shared_ptr deviceInterface + ); + ~CachingMemoryManager() override = default; + void initialize() override; + void shutdown() override; + void* alloc( + bool userLock, + const unsigned ndims, + dim_t* dims, + const unsigned elSize + ) override; + size_t allocated(void* ptr) override; + void unlock(void* ptr, bool userLock) override; + void printInfo( + const char* msg, + const int device, + std::ostream* ostream = & std::cout + ) override; + void userLock(const void* ptr) override; + void userUnlock(const void* ptr) override; + bool isUserLocked(const void* ptr) override; + void signalMemoryCleanup() override; + float getMemoryPressure() override; + bool jitTreeExceedsMemoryPressure(size_t bytes) override; + void addMemoryManagement(int device) override; + void removeMemoryManagement(int device) override; + // Set runtime options: RecyclingSizeLimit, SplitSizeLimit, ... Warning: not + // thread safe + void setRecyclingSizeLimit(size_t); + void setSplitSizeLimit(size_t); + + // Block denotes a single allocated unit of memory. 
+ struct Block { + size_t size_; // size of block in bytes + void* ptr_; // memory address + bool managerLock_; // whether the memory is locked by the memory manager + bool userLock_; // whether the memory is locked by the user + Block* prev_; // prev block if split from a larger allocation + Block* next_; // next block if split from a larger allocation + + bool isSplit() const { + return (prev_ != nullptr) || (next_ != nullptr); + } + + bool inUse() const { + return managerLock_ || userLock_; + } + + explicit Block(size_t size, void* ptr = nullptr) + : size_(size), + ptr_(ptr), + managerLock_(false), + userLock_(false), + prev_(nullptr), + next_(nullptr) {} + }; + + typedef bool (*Comparison)(const Block*, const Block*); + typedef std::set BlockSet; + + // A structure to store allocation stats per device. + struct MemoryAllocationStats { + size_t totalNativeMallocs_; + size_t totalNativeFrees_; + size_t allocatedBytes_; // memory allocated by mem manager for the program + size_t cachedBytes_; // memory held by mem manager & not used by the program + + MemoryAllocationStats() + : totalNativeMallocs_(0), + totalNativeFrees_(0), + allocatedBytes_(0), + cachedBytes_(0) {} + }; + + // Stores the mutex and misc variables per device so that we operate in a + // thredsafe manner. 
+ struct DeviceMemoryInfo { + int deviceId_; + + // lock around all operations + std::recursive_mutex mutexAll_; // TODO:: improve perf using R/W locks + + // cached blocks larger than 1 MB + BlockSet largeBlocks_; + + // cached blocks 1 MB or smaller + BlockSet smallBlocks_; + + // allocated blocks by device pointer + std::unordered_map allocatedBlocks_; + + MemoryAllocationStats stats_; + + explicit DeviceMemoryInfo(int id); + }; + +protected: + std::unordered_map> deviceMemInfos_; + + CachingMemoryManager(const CachingMemoryManager& other) = delete; + CachingMemoryManager(CachingMemoryManager&& other) = delete; + CachingMemoryManager& operator=(const CachingMemoryManager& other) = delete; + CachingMemoryManager& operator=(CachingMemoryManager&& other) = delete; + + // Returns the memory info of the caching allocator for the given device. + // Using "-1" will return info for the current active device. + DeviceMemoryInfo& getDeviceMemoryInfo(int device = -1); + + void freeBlocks(BlockSet& blocks, BlockSet::iterator it, BlockSet::iterator end); + + void mallocWithRetry(size_t size, void** ptr); + + void tryMergeBlocks(Block* dst, Block* src, BlockSet& freeBlocks); + void freeBlock(Block* block); + +private: + // Non-const runtime options in order to fine tune the behavior of this + // manager. 
Prevents to recycle some buffers, to be set by the user if + // desired: + size_t recyclingSizeLimit_{std::numeric_limits::max()}; + // size_t recyclingSizeLimit; + // Prevents to split big buffers, to be set by the user if desired: + size_t splitSizeLimit_{std::numeric_limits::max()}; }; } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp b/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp index bbc96fc..31be137 100644 --- a/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp +++ b/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp @@ -22,384 +22,391 @@ #include "flashlight/fl/tensor/backend/af/mem/MemoryManagerDeviceInterface.h" -#define divup(a, b) (((a) + (b)-1) / (b)) +#define divup(a, b) (((a) + (b) - 1) / (b)) namespace fl { DefaultMemoryManager::MemoryInfo& DefaultMemoryManager::getCurrentMemoryInfo() { - return memory[this->deviceInterface->getActiveDeviceId()]; + return memory[this->deviceInterface->getActiveDeviceId()]; } void DefaultMemoryManager::cleanDeviceMemoryManager(int device) { - if (this->debugMode) { - return; - } - - // This vector is used to store the pointers which will be deleted by - // the memory manager. We are using this to avoid calling free while - // the lock is being held because the CPU backend calls sync. 
- std::vector freePtrs; - size_t bytesFreed = 0; - MemoryInfo& current = memory[device]; - { - std::lock_guard lock(this->memoryMutex); - // Return if all buffers are locked - if (current.totalBuffers == current.lockBuffers) { - return; -} - freePtrs.reserve(current.freeMap.size()); - - for (auto& kv : current.freeMap) { - size_t numPtrs = kv.second.size(); - // Free memory by pushing the last element into the freePtrs - // vector which will be freed once outside of the lock - std::move( - std::begin(kv.second), - std::end(kv.second), - std::back_inserter(freePtrs)); - current.totalBytes -= numPtrs * kv.first; - bytesFreed += numPtrs * kv.first; - current.totalBuffers -= numPtrs; + if(this->debugMode) { + return; + } + + // This vector is used to store the pointers which will be deleted by + // the memory manager. We are using this to avoid calling free while + // the lock is being held because the CPU backend calls sync. + std::vector freePtrs; + size_t bytesFreed = 0; + MemoryInfo& current = memory[device]; + { + std::lock_guard lock(this->memoryMutex); + // Return if all buffers are locked + if(current.totalBuffers == current.lockBuffers) { + return; + } + freePtrs.reserve(current.freeMap.size()); + + for(auto& kv : current.freeMap) { + size_t numPtrs = kv.second.size(); + // Free memory by pushing the last element into the freePtrs + // vector which will be freed once outside of the lock + std::move( + std::begin(kv.second), + std::end(kv.second), + std::back_inserter(freePtrs) + ); + current.totalBytes -= numPtrs * kv.first; + bytesFreed += numPtrs * kv.first; + current.totalBuffers -= numPtrs; + } + current.freeMap.clear(); + } + + std::stringstream ss; + ss << "GC: Clearing " << freePtrs.size() << " buffers |" + << std::to_string(bytesFreed) << " bytes"; + this->log(ss.str()); + + // Free memory outside of the lock + for(auto ptr : freePtrs) { + this->deviceInterface->nativeFree(ptr); } - current.freeMap.clear(); - } - - std::stringstream ss; - ss << "GC: 
Clearing " << freePtrs.size() << " buffers |" - << std::to_string(bytesFreed) << " bytes"; - this->log(ss.str()); - - // Free memory outside of the lock - for (auto ptr : freePtrs) { - this->deviceInterface->nativeFree(ptr); - } } DefaultMemoryManager::DefaultMemoryManager( int numDevices, unsigned maxBuffers, bool debug, - std::shared_ptr deviceInterface) - : MemoryManagerAdapter(deviceInterface), - memStepSize(1024), - maxBuffers(maxBuffers), - debugMode(debug), - memory(numDevices) { - // Check for environment variables - // Debug mode - if (const char* c = std::getenv("AF_MEM_DEBUG")) { - this->debugMode = (std::string(c) != "0"); - } - if (this->debugMode) { - memStepSize = 1; - } - - // Max Buffer count - if (const char* c = std::getenv("AF_MAX_BUFFERS")) { - this->maxBuffers = std::max(1, std::stoi(std::string(c))); - } + std::shared_ptr deviceInterface +) : MemoryManagerAdapter(deviceInterface), + memStepSize(1024), + maxBuffers(maxBuffers), + debugMode(debug), + memory(numDevices) { + // Check for environment variables + // Debug mode + if(const char* c = std::getenv("AF_MEM_DEBUG")) { + this->debugMode = (std::string(c) != "0"); + } + if(this->debugMode) { + memStepSize = 1; + } + + // Max Buffer count + if(const char* c = std::getenv("AF_MAX_BUFFERS")) { + this->maxBuffers = std::max(1, std::stoi(std::string(c))); + } } void DefaultMemoryManager::initialize() { - this->setMaxMemorySize(); + this->setMaxMemorySize(); } void DefaultMemoryManager::shutdown() { - signalMemoryCleanup(); + signalMemoryCleanup(); } void DefaultMemoryManager::addMemoryManagement(int device) { - // If there is a memory manager allocated for this device id, we might - // as well use it and the buffers allocated for it - if (static_cast(device) < memory.size()) { - return; -} + // If there is a memory manager allocated for this device id, we might + // as well use it and the buffers allocated for it + if(static_cast(device) < memory.size()) { + return; + } - // Assuming, device 
need not be always the next device Lets resize to - // current_size + device + 1 +1 is to account for device being 0-based - // index of devices - memory.resize(memory.size() + device + 1); + // Assuming, device need not be always the next device Lets resize to + // current_size + device + 1 +1 is to account for device being 0-based + // index of devices + memory.resize(memory.size() + device + 1); } void DefaultMemoryManager::removeMemoryManagement(int device) { - if ((size_t)device >= memory.size()) { - throw std::runtime_error("No matching device found"); - } + if((size_t) device >= memory.size()) { + throw std::runtime_error("No matching device found"); + } - // Do garbage collection for the device and leave the - // MemoryInfo struct from the memory vector intact - cleanDeviceMemoryManager(device); + // Do garbage collection for the device and leave the + // MemoryInfo struct from the memory vector intact + cleanDeviceMemoryManager(device); } void DefaultMemoryManager::setMaxMemorySize() { - for (unsigned n = 0; n < memory.size(); n++) { - // Calls garbage collection when: totalBytes > memsize * 0.75 when - // memsize < 4GB totalBytes > memsize - 1 GB when memsize >= 4GB If - // memsize returned 0, then use 1GB - size_t memsize = this->deviceInterface->getMaxMemorySize(n); - memory[n].maxBytes = memsize == 0 - ? ONE_GB - : std::max(memsize * 0.75, static_cast(memsize - ONE_GB)); - } + for(unsigned n = 0; n < memory.size(); n++) { + // Calls garbage collection when: totalBytes > memsize * 0.75 when + // memsize < 4GB totalBytes > memsize - 1 GB when memsize >= 4GB If + // memsize returned 0, then use 1GB + size_t memsize = this->deviceInterface->getMaxMemorySize(n); + memory[n].maxBytes = memsize == 0 + ? 
ONE_GB + : std::max(memsize * 0.75, static_cast(memsize - ONE_GB)); + } } void* DefaultMemoryManager::alloc( bool userLock, const unsigned ndims, dim_t* dims, - const unsigned elementSize) { - size_t bytes = elementSize; - for (unsigned i = 0; i < ndims; ++i) { - bytes *= dims[i]; - } - - void* ptr = nullptr; - size_t allocBytes = - this->debugMode ? bytes : (divup(bytes, memStepSize) * memStepSize); - - if (bytes > 0) { - MemoryInfo& current = this->getCurrentMemoryInfo(); - LockedInfo info = {!userLock, userLock, allocBytes}; - - // There is no memory cache in debug mode - if (!this->debugMode) { - // FIXME: Add better checks for garbage collection - // Perhaps look at total memory available as a metric - if (current.lockBytes >= current.maxBytes || - current.totalBuffers >= this->maxBuffers) { - this->signalMemoryCleanup(); - } - - std::lock_guard lock(this->memoryMutex); - free_iter iter = current.freeMap.find(allocBytes); - - if (iter != current.freeMap.end() && !iter->second.empty()) { - // Set to existing in from free map - ptr = iter->second.back(); - iter->second.pop_back(); - current.lockedMap[ptr] = info; - current.lockBytes += allocBytes; - current.lockBuffers++; - } + const unsigned elementSize +) { + size_t bytes = elementSize; + for(unsigned i = 0; i < ndims; ++i) { + bytes *= dims[i]; } - // Only comes here if buffer size not found or in debug mode - if (ptr == nullptr) { - // Perform garbage collection if memory can not be allocated - try { - ptr = this->deviceInterface->nativeAlloc(allocBytes); - } catch (std::exception&) { - // FIXME: assume that the exception is due to out of memory, and don't - // continue propagating it - // If out of memory, run garbage collect and try again - // if (ex.err() != AF_ERR_NO_MEM) { - // throw; - // } - this->signalMemoryCleanup(); - ptr = this->deviceInterface->nativeAlloc(allocBytes); - } - std::lock_guard lock(this->memoryMutex); - // Increment these two only when it succeeds to come here. 
- current.totalBytes += allocBytes; - current.totalBuffers += 1; - current.lockedMap[ptr] = info; - current.lockBytes += allocBytes; - current.lockBuffers++; + void* ptr = nullptr; + size_t allocBytes = + this->debugMode ? bytes : (divup(bytes, memStepSize) * memStepSize); + + if(bytes > 0) { + MemoryInfo& current = this->getCurrentMemoryInfo(); + LockedInfo info = {!userLock, userLock, allocBytes}; + + // There is no memory cache in debug mode + if(!this->debugMode) { + // FIXME: Add better checks for garbage collection + // Perhaps look at total memory available as a metric + if( + current.lockBytes >= current.maxBytes + || current.totalBuffers >= this->maxBuffers + ) { + this->signalMemoryCleanup(); + } + + std::lock_guard lock(this->memoryMutex); + free_iter iter = current.freeMap.find(allocBytes); + + if(iter != current.freeMap.end() && !iter->second.empty()) { + // Set to existing in from free map + ptr = iter->second.back(); + iter->second.pop_back(); + current.lockedMap[ptr] = info; + current.lockBytes += allocBytes; + current.lockBuffers++; + } + } + + // Only comes here if buffer size not found or in debug mode + if(ptr == nullptr) { + // Perform garbage collection if memory can not be allocated + try { + ptr = this->deviceInterface->nativeAlloc(allocBytes); + } catch(std::exception&) { + // FIXME: assume that the exception is due to out of memory, and don't + // continue propagating it + // If out of memory, run garbage collect and try again + // if (ex.err() != AF_ERR_NO_MEM) { + // throw; + // } + this->signalMemoryCleanup(); + ptr = this->deviceInterface->nativeAlloc(allocBytes); + } + std::lock_guard lock(this->memoryMutex); + // Increment these two only when it succeeds to come here. 
+ current.totalBytes += allocBytes; + current.totalBuffers += 1; + current.lockedMap[ptr] = info; + current.lockBytes += allocBytes; + current.lockBuffers++; + } } - } - return ptr; + return ptr; } size_t DefaultMemoryManager::allocated(void* ptr) { - if (!ptr) { - return 0; -} - MemoryInfo& current = this->getCurrentMemoryInfo(); - locked_iter iter = current.lockedMap.find((void*)ptr); - if (iter == current.lockedMap.end()) { - return 0; -} - return (iter->second).bytes; -} - -void DefaultMemoryManager::unlock(void* ptr, bool userUnlock) { - // Shortcut for empty arrays - if (!ptr) { - return; - } - - // Frees the pointer outside the lock. - uptr_t freedPtr( - nullptr, [this](void* p) { this->deviceInterface->nativeFree(p); }); - { - std::lock_guard lock(this->memoryMutex); - MemoryInfo& current = this->getCurrentMemoryInfo(); - - locked_iter iter = current.lockedMap.find((void*)ptr); - - // Pointer not found in locked map - if (iter == current.lockedMap.end()) { - // Probably came from user, just free it - freedPtr.reset(ptr); - return; + if(!ptr) { + return 0; } - - if (userUnlock) { - (iter->second).userLock = false; - } else { - (iter->second).managerLock = false; + MemoryInfo& current = this->getCurrentMemoryInfo(); + locked_iter iter = current.lockedMap.find((void*) ptr); + if(iter == current.lockedMap.end()) { + return 0; } + return (iter->second).bytes; +} - // Return early if either one is locked - if ((iter->second).userLock || (iter->second).managerLock) { - return; +void DefaultMemoryManager::unlock(void* ptr, bool userUnlock) { + // Shortcut for empty arrays + if(!ptr) { + return; } - size_t bytes = iter->second.bytes; - current.lockBytes -= iter->second.bytes; - current.lockBuffers--; - - if (this->debugMode) { - // Just free memory in debug mode - if ((iter->second).bytes > 0) { - freedPtr.reset(iter->first); - current.totalBuffers--; - current.totalBytes -= iter->second.bytes; - } - } else { - current.freeMap.at(bytes).emplace_back(ptr); + // Frees 
the pointer outside the lock. + uptr_t freedPtr( + nullptr, [this](void* p) { this->deviceInterface->nativeFree(p); }); + { + std::lock_guard lock(this->memoryMutex); + MemoryInfo& current = this->getCurrentMemoryInfo(); + + locked_iter iter = current.lockedMap.find((void*) ptr); + + // Pointer not found in locked map + if(iter == current.lockedMap.end()) { + // Probably came from user, just free it + freedPtr.reset(ptr); + return; + } + + if(userUnlock) { + (iter->second).userLock = false; + } else { + (iter->second).managerLock = false; + } + + // Return early if either one is locked + if((iter->second).userLock || (iter->second).managerLock) { + return; + } + + size_t bytes = iter->second.bytes; + current.lockBytes -= iter->second.bytes; + current.lockBuffers--; + + if(this->debugMode) { + // Just free memory in debug mode + if((iter->second).bytes > 0) { + freedPtr.reset(iter->first); + current.totalBuffers--; + current.totalBytes -= iter->second.bytes; + } + } else { + current.freeMap.at(bytes).emplace_back(ptr); + } + current.lockedMap.erase(iter); } - current.lockedMap.erase(iter); - } } void DefaultMemoryManager::signalMemoryCleanup() { - cleanDeviceMemoryManager(this->deviceInterface->getActiveDeviceId()); + cleanDeviceMemoryManager(this->deviceInterface->getActiveDeviceId()); } float DefaultMemoryManager::getMemoryPressure() { - std::lock_guard lock(this->memoryMutex); - MemoryInfo& current = this->getCurrentMemoryInfo(); - if (current.lockBytes > current.maxBytes || - current.lockBuffers > maxBuffers) { - return 1.0; - } else { - return 0.0; - } + std::lock_guard lock(this->memoryMutex); + MemoryInfo& current = this->getCurrentMemoryInfo(); + if( + current.lockBytes > current.maxBytes + || current.lockBuffers > maxBuffers + ) { + return 1.0; + } else { + return 0.0; + } } bool DefaultMemoryManager::jitTreeExceedsMemoryPressure(size_t bytes) { - std::lock_guard lock(this->memoryMutex); - MemoryInfo& current = this->getCurrentMemoryInfo(); - return 2 * 
bytes > current.lockBytes; + std::lock_guard lock(this->memoryMutex); + MemoryInfo& current = this->getCurrentMemoryInfo(); + return 2 * bytes > current.lockBytes; } void DefaultMemoryManager::printInfo( const char* msg, const int /* device */, - std::ostream* _ostream) { - std::ostream& ostream = *_ostream; - const MemoryInfo& current = this->getCurrentMemoryInfo(); - - ostream << msg << std::endl - << "---------------------------------------------------------\n" - << "| POINTER | SIZE | AF LOCK | USER LOCK |\n" - << "---------------------------------------------------------\n"; - - std::lock_guard lock(this->memoryMutex); - for (auto& kv : current.lockedMap) { - const char* statusMngr = "Yes"; - const char* statusUser = "Unknown"; - if (kv.second.userLock) { - statusUser = "Yes"; - } else { - statusUser = " No"; -} + std::ostream* _ostream +) { + std::ostream& ostream = *_ostream; + const MemoryInfo& current = this->getCurrentMemoryInfo(); - const char* unit = "KB"; - double size = static_cast(kv.second.bytes) / 1024; - if (size >= 1024) { - size = size / 1024; - unit = "MB"; - } + ostream << msg << std::endl + << "---------------------------------------------------------\n" + << "| POINTER | SIZE | AF LOCK | USER LOCK |\n" + << "---------------------------------------------------------\n"; - ostream << "| " << kv.first << " | " << size << " " << unit << " | " - << statusMngr << " | " << statusUser << " |\n"; - } + std::lock_guard lock(this->memoryMutex); + for(auto& kv : current.lockedMap) { + const char* statusMngr = "Yes"; + const char* statusUser = "Unknown"; + if(kv.second.userLock) { + statusUser = "Yes"; + } else { + statusUser = " No"; + } + + const char* unit = "KB"; + double size = static_cast(kv.second.bytes) / 1024; + if(size >= 1024) { + size = size / 1024; + unit = "MB"; + } + + ostream << "| " << kv.first << " | " << size << " " << unit << " | " + << statusMngr << " | " << statusUser << " |\n"; + } - for (auto& kv : current.freeMap) { - const char* 
statusMngr = "No"; - const char* statusUser = "No"; + for(auto& kv : current.freeMap) { + const char* statusMngr = "No"; + const char* statusUser = "No"; - const char* unit = "KB"; - double size = static_cast(kv.first) / 1024; - if (size >= 1024) { - size = size / 1024; - unit = "MB"; - } + const char* unit = "KB"; + double size = static_cast(kv.first) / 1024; + if(size >= 1024) { + size = size / 1024; + unit = "MB"; + } - for (auto& ptr : kv.second) { - ostream << "| " << ptr << " | " << size << " " << unit << " | " - << statusMngr << " | " << statusUser << " |\n"; + for(auto& ptr : kv.second) { + ostream << "| " << ptr << " | " << size << " " << unit << " | " + << statusMngr << " | " << statusUser << " |\n"; + } } - } - ostream << "---------------------------------------------------------\n"; + ostream << "---------------------------------------------------------\n"; } void DefaultMemoryManager::userLock(const void* ptr) { - MemoryInfo& current = this->getCurrentMemoryInfo(); + MemoryInfo& current = this->getCurrentMemoryInfo(); - std::lock_guard lock(this->memoryMutex); + std::lock_guard lock(this->memoryMutex); - locked_iter iter = current.lockedMap.find(const_cast(ptr)); - if (iter != current.lockedMap.end()) { - iter->second.userLock = true; - } else { - LockedInfo info = {false, true, 100}; // This number is not relevant + locked_iter iter = current.lockedMap.find(const_cast(ptr)); + if(iter != current.lockedMap.end()) { + iter->second.userLock = true; + } else { + LockedInfo info = {false, true, 100}; // This number is not relevant - current.lockedMap[(void*)ptr] = info; - } + current.lockedMap[(void*) ptr] = info; + } } void DefaultMemoryManager::userUnlock(const void* ptr) { - this->unlock(const_cast(ptr), true); + this->unlock(const_cast(ptr), true); } bool DefaultMemoryManager::isUserLocked(const void* ptr) { - MemoryInfo& current = this->getCurrentMemoryInfo(); - std::lock_guard lock(this->memoryMutex); - locked_iter iter = 
current.lockedMap.find(const_cast(ptr)); - if (iter != current.lockedMap.end()) { - return iter->second.userLock; - } else { - return false; - } + MemoryInfo& current = this->getCurrentMemoryInfo(); + std::lock_guard lock(this->memoryMutex); + locked_iter iter = current.lockedMap.find(const_cast(ptr)); + if(iter != current.lockedMap.end()) { + return iter->second.userLock; + } else { + return false; + } } size_t DefaultMemoryManager::getMemStepSize() { - std::lock_guard lock(this->memoryMutex); - return this->memStepSize; + std::lock_guard lock(this->memoryMutex); + return this->memStepSize; } void DefaultMemoryManager::setMemStepSize(size_t new_step_size) { - std::lock_guard lock(this->memoryMutex); - this->memStepSize = new_step_size; + std::lock_guard lock(this->memoryMutex); + this->memStepSize = new_step_size; } size_t DefaultMemoryManager::getMaxBytes() { - std::lock_guard lock(this->memoryMutex); - return this->getCurrentMemoryInfo().maxBytes; + std::lock_guard lock(this->memoryMutex); + return this->getCurrentMemoryInfo().maxBytes; } unsigned DefaultMemoryManager::getMaxBuffers() { - return this->maxBuffers; + return this->maxBuffers; } bool DefaultMemoryManager::checkMemoryLimit() { - const MemoryInfo& current = this->getCurrentMemoryInfo(); - return current.lockBytes >= current.maxBytes || - current.totalBuffers >= this->maxBuffers; + const MemoryInfo& current = this->getCurrentMemoryInfo(); + return current.lockBytes >= current.maxBytes + || current.totalBuffers >= this->maxBuffers; } } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.h b/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.h index bb2e2bb..7fd2134 100644 --- a/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.h +++ b/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.h @@ -27,102 +27,105 @@ namespace fl { * implementations. 
*/ class DefaultMemoryManager : public MemoryManagerAdapter { - constexpr static unsigned MAX_BUFFERS = 1000; - constexpr static size_t ONE_GB = 1 << 30; + constexpr static unsigned MAX_BUFFERS = 1000; + constexpr static size_t ONE_GB = 1 << 30; - struct LockedInfo { - bool managerLock; - bool userLock; - size_t bytes; - }; + struct LockedInfo { + bool managerLock; + bool userLock; + size_t bytes; + }; - using locked_t = typename std::unordered_map; - using locked_iter = typename locked_t::iterator; + using locked_t = typename std::unordered_map; + using locked_iter = typename locked_t::iterator; - using free_t = std::unordered_map>; - using free_iter = typename free_t::iterator; + using free_t = std::unordered_map>; + using free_iter = typename free_t::iterator; - using uptr_t = std::unique_ptr>; + using uptr_t = std::unique_ptr>; - struct MemoryInfo { - locked_t lockedMap; - free_t freeMap; + struct MemoryInfo { + locked_t lockedMap; + free_t freeMap; - size_t lockBytes; - size_t lockBuffers; - size_t totalBytes; - size_t totalBuffers; - size_t maxBytes; + size_t lockBytes; + size_t lockBuffers; + size_t totalBytes; + size_t totalBuffers; + size_t maxBytes; - MemoryInfo() + MemoryInfo() // Calling getMaxMemorySize() here calls the virtual function // that returns 0 Call it from outside the constructor. 
- : lockBytes(0), - lockBuffers(0), - totalBytes(0), - totalBuffers(0), - maxBytes(ONE_GB) {} - - MemoryInfo(MemoryInfo& other) = delete; - MemoryInfo(MemoryInfo&& other) = default; - MemoryInfo& operator=(MemoryInfo& other) = delete; - MemoryInfo& operator=(MemoryInfo&& other) = default; - }; - - size_t memStepSize; - unsigned maxBuffers; - - bool debugMode; - - MemoryInfo& getCurrentMemoryInfo(); - - public: - DefaultMemoryManager( - int numDevices, - unsigned maxBuffers, - bool debug, - std::shared_ptr deviceInterface); - ~DefaultMemoryManager() = default; - void initialize() override; - void shutdown() override; - void* alloc( - bool userLock, - const unsigned ndims, - dim_t* dims, - const unsigned elSize) override; - size_t allocated(void* ptr) override; - void unlock(void* ptr, bool userLock) override; - void printInfo( - const char* msg, - const int device, - std::ostream* ostream = &std::cout) override; - void userLock(const void* ptr) override; - void userUnlock(const void* ptr) override; - bool isUserLocked(const void* ptr) override; - void signalMemoryCleanup() override; - float getMemoryPressure() override; - bool jitTreeExceedsMemoryPressure(size_t bytes) override; - void addMemoryManagement(int device) override; - void removeMemoryManagement(int device) override; - // Implementation-specific functions - void setMaxMemorySize(); - size_t getMemStepSize() override; - void setMemStepSize(size_t size) override; - size_t getMaxBytes(); - unsigned getMaxBuffers(); - bool checkMemoryLimit(); - - protected: - DefaultMemoryManager(const DefaultMemoryManager& other) = delete; - DefaultMemoryManager(const DefaultMemoryManager&& other) = delete; - DefaultMemoryManager& operator=(const DefaultMemoryManager& other) = delete; - DefaultMemoryManager& operator=(const DefaultMemoryManager&& other) = delete; - - std::mutex memoryMutex; - // backend-specific - std::vector memory; - // backend-agnostic - void cleanDeviceMemoryManager(int device); + : lockBytes(0), + 
lockBuffers(0), + totalBytes(0), + totalBuffers(0), + maxBytes(ONE_GB) {} + + MemoryInfo(MemoryInfo & other) = delete; + MemoryInfo(MemoryInfo && other) = default; + MemoryInfo& operator=(MemoryInfo& other) = delete; + MemoryInfo& operator=(MemoryInfo&& other) = default; + }; + + size_t memStepSize; + unsigned maxBuffers; + + bool debugMode; + + MemoryInfo& getCurrentMemoryInfo(); + +public: + DefaultMemoryManager( + int numDevices, + unsigned maxBuffers, + bool debug, + std::shared_ptr deviceInterface + ); + ~DefaultMemoryManager() = default; + void initialize() override; + void shutdown() override; + void* alloc( + bool userLock, + const unsigned ndims, + dim_t* dims, + const unsigned elSize + ) override; + size_t allocated(void* ptr) override; + void unlock(void* ptr, bool userLock) override; + void printInfo( + const char* msg, + const int device, + std::ostream* ostream = & std::cout + ) override; + void userLock(const void* ptr) override; + void userUnlock(const void* ptr) override; + bool isUserLocked(const void* ptr) override; + void signalMemoryCleanup() override; + float getMemoryPressure() override; + bool jitTreeExceedsMemoryPressure(size_t bytes) override; + void addMemoryManagement(int device) override; + void removeMemoryManagement(int device) override; + // Implementation-specific functions + void setMaxMemorySize(); + size_t getMemStepSize() override; + void setMemStepSize(size_t size) override; + size_t getMaxBytes(); + unsigned getMaxBuffers(); + bool checkMemoryLimit(); + +protected: + DefaultMemoryManager(const DefaultMemoryManager& other) = delete; + DefaultMemoryManager(const DefaultMemoryManager&& other) = delete; + DefaultMemoryManager& operator=(const DefaultMemoryManager& other) = delete; + DefaultMemoryManager& operator=(const DefaultMemoryManager&& other) = delete; + + std::mutex memoryMutex; + // backend-specific + std::vector memory; + // backend-agnostic + void cleanDeviceMemoryManager(int device); }; } // namespace fl diff --git 
a/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.cpp b/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.cpp index 8d2aa78..cdafd2d 100644 --- a/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.cpp +++ b/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.cpp @@ -16,61 +16,64 @@ namespace fl { MemoryManagerAdapter::MemoryManagerAdapter( std::shared_ptr itf, - std::ostream* logStream) - : deviceInterface(itf), logStream_(logStream) { - if (!itf) { - throw std::invalid_argument( - "MemoryManagerAdapter::MemoryManagerAdapter - " - "memory manager device interface is null"); - } - if (logStream_) { - loggingEnabled_ = true; - } - - // Create handle and set payload to point to this instance - AF_CHECK(af_create_memory_manager(&interface_)); - AF_CHECK(af_memory_manager_set_payload(interface_, (void*)this)); + std::ostream* logStream +) : deviceInterface(itf), + logStream_(logStream) { + if(!itf) { + throw std::invalid_argument( + "MemoryManagerAdapter::MemoryManagerAdapter - " + "memory manager device interface is null" + ); + } + if(logStream_) { + loggingEnabled_ = true; + } + + // Create handle and set payload to point to this instance + AF_CHECK(af_create_memory_manager(&interface_)); + AF_CHECK(af_memory_manager_set_payload(interface_, (void*) this)); } MemoryManagerAdapter::~MemoryManagerAdapter() { - // Flush the log buffer and log stream - if (logStream_) { - *logStream_ << logStreamBuffer_.str(); - logStream_->flush(); - } - - if (interface_) { - af_release_memory_manager(interface_); // nothrow - } + // Flush the log buffer and log stream + if(logStream_) { + *logStream_ << logStreamBuffer_.str(); + logStream_->flush(); + } + + if(interface_) { + af_release_memory_manager(interface_); // nothrow + } } void MemoryManagerAdapter::setLogStream(std::ostream* logStream) { - logStream_ = logStream; + logStream_ = logStream; } std::ostream* MemoryManagerAdapter::getLogStream() const { - return logStream_; + return logStream_; } void 
MemoryManagerAdapter::setLoggingEnabled(bool log) { - loggingEnabled_ = log; + loggingEnabled_ = log; } void MemoryManagerAdapter::setLogFlushInterval(size_t interval) { - if (interval < 1) { - throw std::invalid_argument( - "MemoryManagerAdapter::setLogFlushInterval - " - "flush interval must be great than zero."); - } - logFlushInterval_ = interval; + if(interval < 1) { + throw std::invalid_argument( + "MemoryManagerAdapter::setLogFlushInterval - " + "flush interval must be greater than zero." + ); + } + logFlushInterval_ = interval; } af_memory_manager MemoryManagerAdapter::getHandle() const { - return interface_; + return interface_; } size_t MemoryManagerAdapter::getMemStepSize() { - return -1; // -1 denotes stepsize is not used by the custom memory manager + return -1; // -1 denotes stepsize is not used by the custom memory manager } void MemoryManagerAdapter::setMemStepSize(size_t size) {} diff --git a/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.h b/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.h index bc0aa77..7ece2fe 100644 --- a/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.h +++ b/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.h @@ -23,7 +23,7 @@ namespace fl { namespace { -const size_t kDefaultLogFlushInterval = 50; + const size_t kDefaultLogFlushInterval = 50; } // namespace @@ -49,138 +49,141 @@ const size_t kDefaultLogFlushInterval = 50; * are called by ArrayFire and the JIT. */ class MemoryManagerAdapter { - public: - /** - * Constructs a MemoryManagerAdapter. - * - * @param[in] deviceInterface a pointer to a `MemoryManagerDeviceInterface`. - * Function pointers on the interface will be defined once the memory manager - * is installed. - * @param[in] logStream a pointer to an output stream to use for logging. All - * function calls to overridden base class methods by ArrayFire will be logged - * to the resulting stream in conjunction with passed arguments. 
If a valid - * output stream is passed, the memory manager will initialize with logging - * enabled. This argument is optional - passing no argument disables logging - * for the memory manager by default. - */ - explicit MemoryManagerAdapter( - std::shared_ptr deviceInterface, - std::ostream* logStream = nullptr); - virtual ~MemoryManagerAdapter(); - - // Standard API methods - see ArrayFire's af/memory.h header for docs. - virtual void initialize() = 0; - virtual void shutdown() = 0; - virtual void* alloc( - bool userLock, - const unsigned ndims, - dim_t* dims, - const unsigned elSize) = 0; - virtual size_t allocated(void* ptr) = 0; - virtual void unlock(void* ptr, bool userLock) = 0; - virtual void signalMemoryCleanup() = 0; - virtual void printInfo( - const char* msg, - const int device, - std::ostream* ostream = &std::cout) = 0; - virtual void userLock(const void* ptr) = 0; - virtual void userUnlock(const void* ptr) = 0; - virtual bool isUserLocked(const void* ptr) = 0; - virtual float getMemoryPressure() = 0; - virtual bool jitTreeExceedsMemoryPressure(size_t bytes) = 0; - virtual void addMemoryManagement(int device) = 0; - virtual void removeMemoryManagement(int device) = 0; - - virtual size_t getMemStepSize(); - virtual void setMemStepSize(size_t size); - - /** - * Logs information to the `MemoryManagerAdapters`'s log stream. If logging - * mode is enabled, function calls to virtual base class methods are logged. - * - * @param[in] fname the name of the function to be logged (or some arbitrary - * prefix string) - * @param[in] vs variadic list of arguments (of `int` type) to be appended in - * a space-delimited fashion after the fname - */ - template - void log(std::string fname, Values... vs); - - /** - * Sets the log stream for a memory manager base. - * - * @param[in] logStream the output stream to set. - */ - void setLogStream(std::ostream* logStream); - - /** - * Returns the log stream for a memory manager base. - * - * @return the manager's log stream. 
- */ - std::ostream* getLogStream() const; - - /** - * Sets the logging mode for the memory manager base. If disabled, no logs are - * written. If enabled, all function calls to virtual base class methods are - * logged. - * - * @param[in] log bool determinig whether logging is enabled. - */ - void setLoggingEnabled(bool log); - - /** - * Sets a number of lines after which the adapter's temporary logging buffer - * gets flushed to the user-supplied output stream. Default value is 50. - * - * @param[in] interval the number of lines after which to flush the temporary - * log buffer. Supplied interval must be greater than 1. - */ - void setLogFlushInterval(size_t interval); - - /** - * Returns the ArrayFire handle for this memory manager. - * - * @return the `af_memory_manager` handle associated with this class. - */ - af_memory_manager getHandle() const; - - // Native and device memory management functions - const std::shared_ptr deviceInterface; - - protected: - // AF memory manager entity containing relevant function pointers - af_memory_manager interface_; - - private: - // Logging components - bool loggingEnabled_{false}; - std::ostream* logStream_; - std::stringstream logStreamBuffer_; - size_t logStreamBufferSize_{0}; // in number of lines - size_t logFlushInterval_{kDefaultLogFlushInterval}; +public: + /** + * Constructs a MemoryManagerAdapter. + * + * @param[in] deviceInterface a pointer to a `MemoryManagerDeviceInterface`. + * Function pointers on the interface will be defined once the memory manager + * is installed. + * @param[in] logStream a pointer to an output stream to use for logging. All + * function calls to overridden base class methods by ArrayFire will be logged + * to the resulting stream in conjunction with passed arguments. If a valid + * output stream is passed, the memory manager will initialize with logging + * enabled. This argument is optional - passing no argument disables logging + * for the memory manager by default. 
+ */ + explicit MemoryManagerAdapter( + std::shared_ptr deviceInterface, + std::ostream* logStream = nullptr + ); + virtual ~MemoryManagerAdapter(); + + // Standard API methods - see ArrayFire's af/memory.h header for docs. + virtual void initialize() = 0; + virtual void shutdown() = 0; + virtual void* alloc( + bool userLock, + const unsigned ndims, + dim_t* dims, + const unsigned elSize + ) = 0; + virtual size_t allocated(void* ptr) = 0; + virtual void unlock(void* ptr, bool userLock) = 0; + virtual void signalMemoryCleanup() = 0; + virtual void printInfo( + const char* msg, + const int device, + std::ostream* ostream = & std::cout + ) = 0; + virtual void userLock(const void* ptr) = 0; + virtual void userUnlock(const void* ptr) = 0; + virtual bool isUserLocked(const void* ptr) = 0; + virtual float getMemoryPressure() = 0; + virtual bool jitTreeExceedsMemoryPressure(size_t bytes) = 0; + virtual void addMemoryManagement(int device) = 0; + virtual void removeMemoryManagement(int device) = 0; + + virtual size_t getMemStepSize(); + virtual void setMemStepSize(size_t size); + + /** + * Logs information to the `MemoryManagerAdapters`'s log stream. If logging + * mode is enabled, function calls to virtual base class methods are logged. + * + * @param[in] fname the name of the function to be logged (or some arbitrary + * prefix string) + * @param[in] vs variadic list of arguments (of `int` type) to be appended in + * a space-delimited fashion after the fname + */ + template void log(std::string fname, Values... vs); + + /** + * Sets the log stream for a memory manager base. + * + * @param[in] logStream the output stream to set. + */ + void setLogStream(std::ostream* logStream); + + /** + * Returns the log stream for a memory manager base. + * + * @return the manager's log stream. + */ + std::ostream* getLogStream() const; + + /** + * Sets the logging mode for the memory manager base. If disabled, no logs are + * written. 
If enabled, all function calls to virtual base class methods are + * logged. + * + * @param[in] log bool determining whether logging is enabled. + */ + void setLoggingEnabled(bool log); + + /** + * Sets a number of lines after which the adapter's temporary logging buffer + * gets flushed to the user-supplied output stream. Default value is 50. + * + * @param[in] interval the number of lines after which to flush the temporary + * log buffer. Supplied interval must be at least 1. + */ + void setLogFlushInterval(size_t interval); + + /** + * Returns the ArrayFire handle for this memory manager. + * + * @return the `af_memory_manager` handle associated with this class. + */ + af_memory_manager getHandle() const; + + // Native and device memory management functions + const std::shared_ptr deviceInterface; + +protected: + // AF memory manager entity containing relevant function pointers + af_memory_manager interface_; + +private: + // Logging components + bool loggingEnabled_{false}; + std::ostream* logStream_; + std::stringstream logStreamBuffer_; + size_t logStreamBufferSize_{0}; // in number of lines + size_t logFlushInterval_{kDefaultLogFlushInterval}; }; -template +template void MemoryManagerAdapter::log(std::string fname, Values... vs) { - if (loggingEnabled_) { - if (!logStream_) { - throw std::runtime_error( - "MemoryManagerAdapter::log: cannot write to logStream_" - " - stream is invalid or uninitialized"); + if(loggingEnabled_) { + if(!logStream_) { + throw std::runtime_error( + "MemoryManagerAdapter::log: cannot write to logStream_" + " - stream is invalid or uninitialized" + ); + } + logStreamBuffer_ << fname << " "; + int unpack[]{0, (logStreamBuffer_ << std::to_string(vs) << " ", 0)...}; + static_cast(unpack); + logStreamBuffer_ << '\n'; + logStreamBufferSize_++; + // Decide whether or not to flush + if(logStreamBufferSize_ == logFlushInterval_) { + *logStream_ << logStreamBuffer_.str(); + logStreamBuffer_.str(""); // clear the log buffer. 
+ logStreamBufferSize_ = 0; + } } - logStreamBuffer_ << fname << " "; - int unpack[]{0, (logStreamBuffer_ << std::to_string(vs) << " ", 0)...}; - static_cast(unpack); - logStreamBuffer_ << '\n'; - logStreamBufferSize_++; - // Decide whether or not to flush - if (logStreamBufferSize_ == logFlushInterval_) { - *logStream_ << logStreamBuffer_.str(); - logStreamBuffer_.str(""); // clear the log buffer. - logStreamBufferSize_ = 0; - } - } } }; // namespace fl diff --git a/flashlight/fl/tensor/backend/af/mem/MemoryManagerDeviceInterface.h b/flashlight/fl/tensor/backend/af/mem/MemoryManagerDeviceInterface.h index 2d1ec00..3e15992 100644 --- a/flashlight/fl/tensor/backend/af/mem/MemoryManagerDeviceInterface.h +++ b/flashlight/fl/tensor/backend/af/mem/MemoryManagerDeviceInterface.h @@ -11,12 +11,12 @@ namespace fl { -using GetActiveDeviceIdFn = std::function; -using GetMaxMemorySizeFn = std::function; -using NativeAllocFn = std::function; -using NativeFreeFn = std::function; -using GetMemoryPressureThresholdFn = std::function; -using SetMemoryPressureThresholdFn = std::function; +using GetActiveDeviceIdFn = std::function; +using GetMaxMemorySizeFn = std::function; +using NativeAllocFn = std::function; +using NativeFreeFn = std::function; +using GetMemoryPressureThresholdFn = std::function; +using SetMemoryPressureThresholdFn = std::function; /** * An interface for using native device memory management and JIT-related memory @@ -37,14 +37,14 @@ using SetMemoryPressureThresholdFn = std::function; * header](https://git.io/Jv7do) for full specifications. 
*/ struct MemoryManagerDeviceInterface { - // Native memory management functions - GetActiveDeviceIdFn getActiveDeviceId; - GetMaxMemorySizeFn getMaxMemorySize; - NativeAllocFn nativeAlloc; - NativeFreeFn nativeFree; - // Memory pressure functions - GetMemoryPressureThresholdFn getMemoryPressureThreshold; - SetMemoryPressureThresholdFn setMemoryPressureThreshold; + // Native memory management functions + GetActiveDeviceIdFn getActiveDeviceId; + GetMaxMemorySizeFn getMaxMemorySize; + NativeAllocFn nativeAlloc; + NativeFreeFn nativeFree; + // Memory pressure functions + GetMemoryPressureThresholdFn getMemoryPressureThreshold; + SetMemoryPressureThresholdFn setMemoryPressureThreshold; }; } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.cpp b/flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.cpp index 9678eff..d6d6859 100644 --- a/flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.cpp +++ b/flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.cpp @@ -20,221 +20,249 @@ namespace fl { // Statics from MemoryManagerInstaller std::shared_ptr - MemoryManagerInstaller::currentlyInstalledMemoryManager_; +MemoryManagerInstaller::currentlyInstalledMemoryManager_; MemoryManagerAdapter* MemoryManagerInstaller::getImpl( - af_memory_manager manager) { - void* ptr; - AF_CHECK(af_memory_manager_get_payload(manager, &ptr)); - return (MemoryManagerAdapter*)ptr; + af_memory_manager manager +) { + void* ptr; + AF_CHECK(af_memory_manager_get_payload(manager, &ptr)); + return (MemoryManagerAdapter*) ptr; } MemoryManagerInstaller::MemoryManagerInstaller( - std::shared_ptr managerImpl) - : impl_(managerImpl) { - if (!impl_) { - throw std::invalid_argument( - "MemoryManagerInstaller::MemoryManagerInstaller - " - "passed MemoryManagerAdapter is null"); - } - - af_memory_manager itf = impl_->getHandle(); - if (!impl_->getHandle()) { - throw std::invalid_argument( - "MemoryManagerInstaller::MemoryManagerInstaller - " - "passed 
MemoryManagerAdapter has null handle"); - } - - // Set appropriate function pointers for each class method - auto initializeFn = [](af_memory_manager manager) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - m->log("initialize"); - m->initialize(); - return AF_SUCCESS; - }; - AF_CHECK(af_memory_manager_set_initialize_fn(itf, initializeFn)); - auto shutdownFn = [](af_memory_manager manager) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - m->log("shutdown"); - m->shutdown(); - return AF_SUCCESS; - }; - AF_CHECK(af_memory_manager_set_shutdown_fn(itf, shutdownFn)); - // ArrayFire expects the memory managers alloc fn to return an af_err, not to - // throw, if a problem with allocation occurred - auto allocFn = [](af_memory_manager manager, - void** ptr, - /* bool */ int userLock, - const unsigned ndims, - dim_t* dims, - const unsigned elSize) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - try { - *ptr = m->alloc(userLock, ndims, dims, elSize); - } catch (af::exception& ex) { - m->log( - "allocFn: alloc failed with af exception " + - std::to_string(ex.err())); - return ex.err(); // AF_ERR_NO_MEM, ... - } catch (...) 
{ - m->log("allocFn: alloc failed with unspecified exception"); - return af_err(AF_ERR_UNKNOWN); + std::shared_ptr managerImpl +) : impl_(managerImpl) { + if(!impl_) { + throw std::invalid_argument( + "MemoryManagerInstaller::MemoryManagerInstaller - " + "passed MemoryManagerAdapter is null" + ); } - // Log - m->log( - "alloc", - /* size */ dims[0], // HACK: dims[0] until af::memAlloc is size-aware - userLock, - (std::uintptr_t)*ptr); - return AF_SUCCESS; - }; - AF_CHECK(af_memory_manager_set_alloc_fn(itf, allocFn)); - auto allocatedFn = [](af_memory_manager manager, size_t* size, void* ptr) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - m->log("allocated", (std::uintptr_t)ptr); - *size = m->allocated(ptr); - return AF_SUCCESS; - }; - AF_CHECK(af_memory_manager_set_allocated_fn(itf, allocatedFn)); - auto unlockFn = [](af_memory_manager manager, void* ptr, int userLock) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - m->log("unlock", (std::uintptr_t)ptr, userLock); - m->unlock(ptr, (bool)userLock); - return AF_SUCCESS; - }; - AF_CHECK(af_memory_manager_set_unlock_fn(itf, unlockFn)); - auto signalMemoryCleanupFn = [](af_memory_manager manager) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - m->log("signalMemoryCleanup"); - m->signalMemoryCleanup(); - return AF_SUCCESS; - }; - AF_CHECK(af_memory_manager_set_signal_memory_cleanup_fn( - itf, signalMemoryCleanupFn)); - auto printInfoFn = [](af_memory_manager manager, char* msg, int device) { - // no log - auto* adapter = MemoryManagerInstaller::getImpl(manager); - adapter->printInfo(msg, device, adapter->getLogStream()); - return AF_SUCCESS; - }; - AF_CHECK(af_memory_manager_set_print_info_fn(itf, printInfoFn)); - auto userLockFn = [](af_memory_manager manager, void* ptr) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - m->log("userLock", (std::uintptr_t)ptr); - m->userLock(ptr); - return AF_SUCCESS; - }; - 
AF_CHECK(af_memory_manager_set_user_lock_fn(itf, userLockFn)); - auto userUnlockFn = [](af_memory_manager manager, void* ptr) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - m->log("userUnlock", (std::uintptr_t)ptr); - MemoryManagerInstaller::getImpl(manager)->userUnlock(ptr); - return AF_SUCCESS; - }; - AF_CHECK(af_memory_manager_set_user_unlock_fn(itf, userUnlockFn)); - auto isUserLockedFn = [](af_memory_manager manager, int* out, void* ptr) { - MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); - m->log("isUserLocked", (std::uintptr_t)ptr); - *out = static_cast(m->isUserLocked(ptr)); - return AF_SUCCESS; - }; - AF_CHECK(af_memory_manager_set_is_user_locked_fn(itf, isUserLockedFn)); - auto getMemoryPressureFn = [](af_memory_manager manager, float* pressure) { - *pressure = MemoryManagerInstaller::getImpl(manager)->getMemoryPressure(); - return AF_SUCCESS; - }; - AF_CHECK( - af_memory_manager_set_get_memory_pressure_fn(itf, getMemoryPressureFn)); - auto jitTreeExceedsMemoryPressureFn = - [](af_memory_manager manager, int* out, size_t bytes) { - *out = static_cast(MemoryManagerInstaller::getImpl(manager) - ->jitTreeExceedsMemoryPressure(bytes)); - return AF_SUCCESS; - }; - AF_CHECK(af_memory_manager_set_jit_tree_exceeds_memory_pressure_fn( - itf, jitTreeExceedsMemoryPressureFn)); - auto addMemoryManagementFn = [](af_memory_manager manager, int device) { - MemoryManagerInstaller::getImpl(manager)->addMemoryManagement(device); - }; - AF_CHECK(af_memory_manager_set_add_memory_management_fn( - itf, addMemoryManagementFn)); - auto removeMemoryManagementFn = [](af_memory_manager manager, int device) { - MemoryManagerInstaller::getImpl(manager)->removeMemoryManagement(device); - }; - AF_CHECK(af_memory_manager_set_remove_memory_management_fn( - itf, removeMemoryManagementFn)); - - // Native and device memory manager functions - auto getActiveDeviceIdFn = [itf]() { - int id; - AF_CHECK(af_memory_manager_get_active_device_id(itf, 
&id)); - return id; - }; - impl_->deviceInterface->getActiveDeviceId = std::move(getActiveDeviceIdFn); - auto getMaxMemorySizeFn = [itf](int id) { - size_t out; - AF_CHECK(af_memory_manager_get_max_memory_size(itf, &out, id)); - return out; - }; - impl_->deviceInterface->getMaxMemorySize = std::move(getMaxMemorySizeFn); - // nativeAlloc could throw via AF_CHECK: - auto nativeAllocFn = [itf](const size_t bytes) { - void* ptr; - AF_CHECK(af_memory_manager_native_alloc(itf, &ptr, bytes)); - MemoryManagerInstaller::getImpl(itf)->log( - "nativeAlloc", bytes, (std::uintptr_t)ptr); - return ptr; - }; - impl_->deviceInterface->nativeAlloc = std::move(nativeAllocFn); - auto nativeFreeFn = [itf](void* ptr) { - MemoryManagerInstaller::getImpl(itf)->log( - "nativeFree", (std::uintptr_t)ptr); - AF_CHECK(af_memory_manager_native_free(itf, ptr)); - }; - impl_->deviceInterface->nativeFree = std::move(nativeFreeFn); - auto getMemoryPressureThresholdFn = [itf]() { - float pressure; - AF_CHECK(af_memory_manager_get_memory_pressure_threshold(itf, &pressure)); - return pressure; - }; - impl_->deviceInterface->getMemoryPressureThreshold = - std::move(getMemoryPressureThresholdFn); - auto setMemoryPressureThresholdFn = [itf](float pressure) { - AF_CHECK(af_memory_manager_set_memory_pressure_threshold(itf, pressure)); - }; - impl_->deviceInterface->setMemoryPressureThreshold = - std::move(setMemoryPressureThresholdFn); + + af_memory_manager itf = impl_->getHandle(); + if(!impl_->getHandle()) { + throw std::invalid_argument( + "MemoryManagerInstaller::MemoryManagerInstaller - " + "passed MemoryManagerAdapter has null handle" + ); + } + + // Set appropriate function pointers for each class method + auto initializeFn = [](af_memory_manager manager) { + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + m->log("initialize"); + m->initialize(); + return AF_SUCCESS; + }; + AF_CHECK(af_memory_manager_set_initialize_fn(itf, initializeFn)); + auto shutdownFn = 
[](af_memory_manager manager) { + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + m->log("shutdown"); + m->shutdown(); + return AF_SUCCESS; + }; + AF_CHECK(af_memory_manager_set_shutdown_fn(itf, shutdownFn)); + // ArrayFire expects the memory managers alloc fn to return an af_err, not to + // throw, if a problem with allocation occurred + auto allocFn = [](af_memory_manager manager, + void** ptr, + /* bool */ int userLock, + const unsigned ndims, + dim_t* dims, + const unsigned elSize) { + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + try { + *ptr = m->alloc(userLock, ndims, dims, elSize); + } catch(af::exception& ex) { + m->log( + "allocFn: alloc failed with af exception " + + std::to_string(ex.err()) + ); + return ex.err(); // AF_ERR_NO_MEM, ... + } catch(...) { + m->log("allocFn: alloc failed with unspecified exception"); + return af_err(AF_ERR_UNKNOWN); + } + // Log + m->log( + "alloc", + /* size */ dims[0], // HACK: dims[0] until af::memAlloc is size-aware + userLock, + (std::uintptr_t) *ptr + ); + return AF_SUCCESS; + }; + AF_CHECK(af_memory_manager_set_alloc_fn(itf, allocFn)); + auto allocatedFn = [](af_memory_manager manager, size_t* size, void* ptr) { + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + m->log("allocated", (std::uintptr_t) ptr); + *size = m->allocated(ptr); + return AF_SUCCESS; + }; + AF_CHECK(af_memory_manager_set_allocated_fn(itf, allocatedFn)); + auto unlockFn = [](af_memory_manager manager, void* ptr, int userLock) { + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + m->log("unlock", (std::uintptr_t) ptr, userLock); + m->unlock(ptr, (bool) userLock); + return AF_SUCCESS; + }; + AF_CHECK(af_memory_manager_set_unlock_fn(itf, unlockFn)); + auto signalMemoryCleanupFn = [](af_memory_manager manager) { + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + m->log("signalMemoryCleanup"); + m->signalMemoryCleanup(); + return AF_SUCCESS; + }; 
+ AF_CHECK( + af_memory_manager_set_signal_memory_cleanup_fn( + itf, + signalMemoryCleanupFn + ) + ); + auto printInfoFn = [](af_memory_manager manager, char* msg, int device) { + // no log + auto* adapter = MemoryManagerInstaller::getImpl(manager); + adapter->printInfo(msg, device, adapter->getLogStream()); + return AF_SUCCESS; + }; + AF_CHECK(af_memory_manager_set_print_info_fn(itf, printInfoFn)); + auto userLockFn = [](af_memory_manager manager, void* ptr) { + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + m->log("userLock", (std::uintptr_t) ptr); + m->userLock(ptr); + return AF_SUCCESS; + }; + AF_CHECK(af_memory_manager_set_user_lock_fn(itf, userLockFn)); + auto userUnlockFn = [](af_memory_manager manager, void* ptr) { + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + m->log("userUnlock", (std::uintptr_t) ptr); + MemoryManagerInstaller::getImpl(manager)->userUnlock(ptr); + return AF_SUCCESS; + }; + AF_CHECK(af_memory_manager_set_user_unlock_fn(itf, userUnlockFn)); + auto isUserLockedFn = [](af_memory_manager manager, int* out, void* ptr) { + MemoryManagerAdapter* m = MemoryManagerInstaller::getImpl(manager); + m->log("isUserLocked", (std::uintptr_t) ptr); + *out = static_cast(m->isUserLocked(ptr)); + return AF_SUCCESS; + }; + AF_CHECK(af_memory_manager_set_is_user_locked_fn(itf, isUserLockedFn)); + auto getMemoryPressureFn = [](af_memory_manager manager, float* pressure) { + *pressure = MemoryManagerInstaller::getImpl(manager)->getMemoryPressure(); + return AF_SUCCESS; + }; + AF_CHECK( + af_memory_manager_set_get_memory_pressure_fn(itf, getMemoryPressureFn) + ); + auto jitTreeExceedsMemoryPressureFn = + [](af_memory_manager manager, int* out, size_t bytes) { + *out = static_cast(MemoryManagerInstaller::getImpl(manager) + ->jitTreeExceedsMemoryPressure(bytes)); + return AF_SUCCESS; + }; + AF_CHECK( + af_memory_manager_set_jit_tree_exceeds_memory_pressure_fn( + itf, + jitTreeExceedsMemoryPressureFn + ) + ); + auto 
addMemoryManagementFn = [](af_memory_manager manager, int device) { + MemoryManagerInstaller::getImpl(manager)->addMemoryManagement(device); + }; + AF_CHECK( + af_memory_manager_set_add_memory_management_fn( + itf, + addMemoryManagementFn + ) + ); + auto removeMemoryManagementFn = [](af_memory_manager manager, int device) { + MemoryManagerInstaller::getImpl(manager)->removeMemoryManagement(device); + }; + AF_CHECK( + af_memory_manager_set_remove_memory_management_fn( + itf, + removeMemoryManagementFn + ) + ); + + // Native and device memory manager functions + auto getActiveDeviceIdFn = [itf]() { + int id; + AF_CHECK(af_memory_manager_get_active_device_id(itf, &id)); + return id; + }; + impl_->deviceInterface->getActiveDeviceId = std::move(getActiveDeviceIdFn); + auto getMaxMemorySizeFn = [itf](int id) { + size_t out; + AF_CHECK(af_memory_manager_get_max_memory_size(itf, &out, id)); + return out; + }; + impl_->deviceInterface->getMaxMemorySize = std::move(getMaxMemorySizeFn); + // nativeAlloc could throw via AF_CHECK: + auto nativeAllocFn = [itf](const size_t bytes) { + void* ptr; + AF_CHECK(af_memory_manager_native_alloc(itf, &ptr, bytes)); + MemoryManagerInstaller::getImpl(itf)->log( + "nativeAlloc", + bytes, + (std::uintptr_t) ptr + ); + return ptr; + }; + impl_->deviceInterface->nativeAlloc = std::move(nativeAllocFn); + auto nativeFreeFn = [itf](void* ptr) { + MemoryManagerInstaller::getImpl(itf)->log( + "nativeFree", + (std::uintptr_t) ptr + ); + AF_CHECK(af_memory_manager_native_free(itf, ptr)); + }; + impl_->deviceInterface->nativeFree = std::move(nativeFreeFn); + auto getMemoryPressureThresholdFn = [itf]() { + float pressure; + AF_CHECK(af_memory_manager_get_memory_pressure_threshold(itf, &pressure)); + return pressure; + }; + impl_->deviceInterface->getMemoryPressureThreshold = + std::move(getMemoryPressureThresholdFn); + auto setMemoryPressureThresholdFn = [itf](float pressure) { + AF_CHECK(af_memory_manager_set_memory_pressure_threshold(itf, pressure)); 
+ }; + impl_->deviceInterface->setMemoryPressureThreshold = + std::move(setMemoryPressureThresholdFn); } void MemoryManagerInstaller::setAsMemoryManager() { - AF_CHECK(af_set_memory_manager(impl_->getHandle())); - currentlyInstalledMemoryManager_ = impl_; + AF_CHECK(af_set_memory_manager(impl_->getHandle())); + currentlyInstalledMemoryManager_ = impl_; } void MemoryManagerInstaller::setAsMemoryManagerPinned() { - AF_CHECK(af_set_memory_manager_pinned(impl_->getHandle())); - currentlyInstalledMemoryManager_ = impl_; + AF_CHECK(af_set_memory_manager_pinned(impl_->getHandle())); + currentlyInstalledMemoryManager_ = impl_; } -MemoryManagerAdapter* -MemoryManagerInstaller::currentlyInstalledMemoryManager() { - return currentlyInstalledMemoryManager_.get(); +MemoryManagerAdapter* MemoryManagerInstaller::currentlyInstalledMemoryManager() { + return currentlyInstalledMemoryManager_.get(); } void MemoryManagerInstaller::installDefaultMemoryManager() { - auto deviceInterface = std::make_shared(); - auto adapter = std::make_shared( - af::getDeviceCount(), deviceInterface); - auto installer = MemoryManagerInstaller(adapter); - installer.setAsMemoryManager(); + auto deviceInterface = std::make_shared(); + auto adapter = std::make_shared( + af::getDeviceCount(), + deviceInterface + ); + auto installer = MemoryManagerInstaller(adapter); + installer.setAsMemoryManager(); } void MemoryManagerInstaller::unsetMemoryManager() { - // Make sure we don't reset the default AF memory manager if it's set - if (currentlyInstalledMemoryManager_) { - AF_CHECK(af_unset_memory_manager()); - currentlyInstalledMemoryManager_ = nullptr; - } + // Make sure we don't reset the default AF memory manager if it's set + if(currentlyInstalledMemoryManager_) { + AF_CHECK(af_unset_memory_manager()); + currentlyInstalledMemoryManager_ = nullptr; + } } } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.h b/flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.h index 
f014ddf..5af4d0c 100644 --- a/flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.h +++ b/flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.h @@ -27,79 +27,80 @@ namespace fl { * active ArrayFire memory manager even if its installer has been destroyed. */ class MemoryManagerInstaller { - public: - /** - * Creates a new instance using a `MemoryManagerAdapter`. Uses the adapter's - * underlying `af_memory_manager` handle and performs the following setup: - * - Sets all function pointers using the Array/Fire C memory management API - * on the underlying `af_memory_manager` handle to point to closures which - * call the installed `MemoryManagerAdapter`'s instance methods. - * - Sets the closures on the adapter's `MemoryManagerDeviceInterface` to call - * ArrayFire C-API native device memory management functions which - * automatically delegate to the proper backend and are use pre-defined - * implementations in ArrayFire internals. - * - * @param[in] managerImpl a pointer to the `MemoryManagerAdapter` to be - * installed. - */ - explicit MemoryManagerInstaller( - std::shared_ptr managerImpl); - ~MemoryManagerInstaller() = default; +public: + /** + * Creates a new instance using a `MemoryManagerAdapter`. Uses the adapter's + * underlying `af_memory_manager` handle and performs the following setup: + * - Sets all function pointers using the Array/Fire C memory management API + * on the underlying `af_memory_manager` handle to point to closures which + * call the installed `MemoryManagerAdapter`'s instance methods. + * - Sets the closures on the adapter's `MemoryManagerDeviceInterface` to call + * ArrayFire C-API native device memory management functions which + * automatically delegate to the proper backend and are use pre-defined + * implementations in ArrayFire internals. + * + * @param[in] managerImpl a pointer to the `MemoryManagerAdapter` to be + * installed. 
+ */ + explicit MemoryManagerInstaller( + std::shared_ptr managerImpl + ); + ~MemoryManagerInstaller() = default; - /** - * Gets the memory manager adapter used in this instance. - * - * @return a pointer to some derived type of `MemoryManagerAdapter` - */ - template - std::shared_ptr getMemoryManager() const { - return std::dynamic_pointer_cast(impl_); - } + /** + * Gets the memory manager adapter used in this instance. + * + * @return a pointer to some derived type of `MemoryManagerAdapter` + */ + template + std::shared_ptr getMemoryManager() const { + return std::dynamic_pointer_cast(impl_); + } - /** - * Sets this `MemoryManagerInstaller`'s `MemoryManagerAdapter` to be the - * active memory manager in ArrayFire. - */ - void setAsMemoryManager(); + /** + * Sets this `MemoryManagerInstaller`'s `MemoryManagerAdapter` to be the + * active memory manager in ArrayFire. + */ + void setAsMemoryManager(); - /** - * Sets this `MemoryManagerInstaller`'s `MemoryManagerAdapter` to be the - * active memory manager for pinned memory operations in ArrayFire. - */ - void setAsMemoryManagerPinned(); + /** + * Sets this `MemoryManagerInstaller`'s `MemoryManagerAdapter` to be the + * active memory manager for pinned memory operations in ArrayFire. + */ + void setAsMemoryManagerPinned(); - /** - * Returns an adapter given a handle. Used to construct C++-style callbacks - * inside lambdas set on the ArrayFire C memory management API. - */ - static MemoryManagerAdapter* getImpl(af_memory_manager manager); + /** + * Returns an adapter given a handle. Used to construct C++-style callbacks + * inside lambdas set on the ArrayFire C memory management API. + */ + static MemoryManagerAdapter* getImpl(af_memory_manager manager); - /** - * Returns the currently installed custom memory manager, or null if none is - * installed. 
- */ - static MemoryManagerAdapter* currentlyInstalledMemoryManager(); + /** + * Returns the currently installed custom memory manager, or null if none is + * installed. + */ + static MemoryManagerAdapter* currentlyInstalledMemoryManager(); - /** - * Initializes and installs the memory manager defaulted to on startup. - * - * Uses a `CachingMemoryManager` by default. Only sets the memory manager - - * doesn't set an AF pinned memory manager. - */ - static void installDefaultMemoryManager(); + /** + * Initializes and installs the memory manager defaulted to on startup. + * + * Uses a `CachingMemoryManager` by default. Only sets the memory manager - + * doesn't set an AF pinned memory manager. + */ + static void installDefaultMemoryManager(); - /** - * Unsets the currently-set custom ArrayFire memory manager. If no custom - * memory manager is set, results in a noop, since the default memory manager - * is set, and unsetting it would result in shutdown/destruction. - */ - static void unsetMemoryManager(); + /** + * Unsets the currently-set custom ArrayFire memory manager. If no custom + * memory manager is set, results in a noop, since the default memory manager + * is set, and unsetting it would result in shutdown/destruction. + */ + static void unsetMemoryManager(); - private: - // The given memory manager implementation - std::shared_ptr impl_; - // Points to the impl_ of the most recently installed manager. - static std::shared_ptr currentlyInstalledMemoryManager_; +private: + // The given memory manager implementation + std::shared_ptr impl_; + // Points to the impl_ of the most recently installed manager. 
+ static std::shared_ptr currentlyInstalledMemoryManager_; }; } // namespace fl diff --git a/flashlight/fl/tensor/backend/stub/StubBackend.cpp b/flashlight/fl/tensor/backend/stub/StubBackend.cpp index 77b347f..d3b6094 100644 --- a/flashlight/fl/tensor/backend/stub/StubBackend.cpp +++ b/flashlight/fl/tensor/backend/stub/StubBackend.cpp @@ -11,90 +11,97 @@ #include "flashlight/fl/tensor/TensorBase.h" -#define FL_STUB_BACKEND_UNIMPLEMENTED \ - throw std::invalid_argument( \ - "StubBackend::" + std::string(__func__) + " - unimplemented."); +#define FL_STUB_BACKEND_UNIMPLEMENTED \ + throw std::invalid_argument( \ + "StubBackend::" + std::string(__func__) + " - unimplemented." \ + ); namespace fl { StubBackend::StubBackend() { - // Set up state + // Set up state } StubBackend& StubBackend::getInstance() { - static StubBackend instance; - return instance; + static StubBackend instance; + return instance; } TensorBackendType StubBackend::backendType() const { - // Implementers of a backend should create their own option in the - // TensorBackendType enum and return it here. - return TensorBackendType::Stub; + // Implementers of a backend should create their own option in the + // TensorBackendType enum and return it here. + return TensorBackendType::Stub; } /* -------------------------- Compute Functions -------------------------- */ void StubBackend::eval(const Tensor& /* tensor */) { - // Launch computation for a given tensor. Can be a noop for non-async - // runtimes. - FL_STUB_BACKEND_UNIMPLEMENTED; + // Launch computation for a given tensor. Can be a noop for non-async + // runtimes. + FL_STUB_BACKEND_UNIMPLEMENTED; } bool StubBackend::supportsDataType(const fl::dtype& /* dtype */) const { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } void StubBackend::getMemMgrInfo( const char* /* msg */, const int /* deviceId */, - std::ostream* /* ostream */) { - // Can be a noop if no memory manager is implemented. 
- FL_STUB_BACKEND_UNIMPLEMENTED; + std::ostream* /* ostream */ +) { + // Can be a noop if no memory manager is implemented. + FL_STUB_BACKEND_UNIMPLEMENTED; } void StubBackend::setMemMgrLogStream(std::ostream* /* stream */) { - // Can be a noop if no memory manager is implemented. - FL_STUB_BACKEND_UNIMPLEMENTED; + // Can be a noop if no memory manager is implemented. + FL_STUB_BACKEND_UNIMPLEMENTED; } void StubBackend::setMemMgrLoggingEnabled(const bool /* enabled */) { - // Can be a noop if no memory manager is implemented. - FL_STUB_BACKEND_UNIMPLEMENTED; + // Can be a noop if no memory manager is implemented. + FL_STUB_BACKEND_UNIMPLEMENTED; } void StubBackend::setMemMgrFlushInterval(const size_t /* interval */) { - // Can be a noop if no memory manager is implemented. - FL_STUB_BACKEND_UNIMPLEMENTED; + // Can be a noop if no memory manager is implemented. + FL_STUB_BACKEND_UNIMPLEMENTED; } /* -------------------------- Rand Functions -------------------------- */ void StubBackend::setSeed(const int /* seed */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::randn(const Shape& /* shape */, dtype /* type */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::rand(const Shape& /* shape */, dtype /* type */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } /* --------------------------- Tensor Operators --------------------------- */ /******************** Tensor Creation Functions ********************/ -#define FL_STUB_BACKEND_CREATE_FUN_LITERAL_DEF(TYPE) \ - Tensor StubBackend::fromScalar(TYPE /* value */, const dtype /* type */) { \ - throw std::invalid_argument( \ - "StubBackend::fromScalar - not implemented for type " + \ - std::string(#TYPE)); \ - } \ - Tensor StubBackend::full( \ - const Shape& /* shape */, TYPE /* value */, const dtype /* type */) { \ - throw std::invalid_argument( \ - "StubBackend::full - not implemented for type " + std::string(#TYPE)); \ 
- } +#define FL_STUB_BACKEND_CREATE_FUN_LITERAL_DEF(TYPE) \ + Tensor StubBackend::fromScalar(TYPE /* value */, const dtype /* type */) { \ + throw std::invalid_argument( \ + "StubBackend::fromScalar - not implemented for type " + \ + std::string(#TYPE) \ + ); \ + } \ + Tensor StubBackend::full( \ + const Shape& /* shape */, \ + TYPE /* value */, \ + const dtype /* type */ \ + ) { \ + throw std::invalid_argument( \ + "StubBackend::full - not implemented for type " + std::string(#TYPE) \ + ); \ + } FL_STUB_BACKEND_CREATE_FUN_LITERAL_DEF(const double&); FL_STUB_BACKEND_CREATE_FUN_LITERAL_DEF(const float&); FL_STUB_BACKEND_CREATE_FUN_LITERAL_DEF(const int&); @@ -110,162 +117,171 @@ FL_STUB_BACKEND_CREATE_FUN_LITERAL_DEF(const short&); FL_STUB_BACKEND_CREATE_FUN_LITERAL_DEF(const unsigned short&); Tensor StubBackend::identity(const Dim /* dim */, const dtype /* type */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::arange( const Shape& /* shape */, const Dim /* seqDim */, - const dtype /* type */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const dtype /* type */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::iota( const Shape& /* dims */, const Shape& /* tileDims */, - const dtype /* type */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const dtype /* type */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } /************************ Shaping and Indexing *************************/ Tensor StubBackend::reshape( const Tensor& /* tensor */, - const Shape& /* shape */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const Shape& /* shape */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::transpose( const Tensor& /* tensor */, - const Shape& /* axes */ /* = {} */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const Shape& /* axes */ /* = {} */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::tile(const Tensor& /* tensor */, const Shape& /* shape */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor 
StubBackend::concatenate( const std::vector& /* tensors */, - const unsigned /* axis */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const unsigned /* axis */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::nonzero(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::pad( const Tensor& /* input */, const std::vector>& /* padWidths */, - const PadType /* type */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const PadType /* type */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } /************************** Unary Operators ***************************/ Tensor StubBackend::exp(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::log(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::negative(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::logicalNot(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::log1p(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::sin(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::cos(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::sqrt(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::tanh(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::floor(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::ceil(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor 
StubBackend::rint(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::absolute(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::sigmoid(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::erf(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::flip(const Tensor& /* tensor */, const unsigned /* dim */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::clip( const Tensor& /* tensor */, const Tensor& /* low */, - const Tensor& /* high */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const Tensor& /* high */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::roll( const Tensor& /* tensor */, const int /* shift */, - const unsigned /* axis */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const unsigned /* axis */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::isnan(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::isinf(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::sign(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::tril(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::triu(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::where( const Tensor& /* condition */, const Tensor& /* x */, - const Tensor& /* y */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const Tensor& /* y */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } void StubBackend::topk( @@ -274,15 +290,17 @@ void StubBackend::topk( const Tensor& /* input */, const unsigned /* k 
*/, const Dim /* axis */, - const SortMode /* sortMode */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const SortMode /* sortMode */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::sort( const Tensor& /* input */, const Dim /* axis */, - const SortMode /* sortMode */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const SortMode /* sortMode */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } void StubBackend::sort( @@ -290,54 +308,59 @@ void StubBackend::sort( Tensor& /* indices */, const Tensor& /* input */, const Dim /* axis */, - const SortMode /* sortMode */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const SortMode /* sortMode */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::argsort( const Tensor& /* input */, const Dim /* axis */, - const SortMode /* sortMode */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const SortMode /* sortMode */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } /************************** Binary Operators ***************************/ -#define FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, TYPE) \ - Tensor StubBackend::FUNC(const Tensor& /* a */, TYPE /* rhs */) { \ - throw std::runtime_error( \ - "StubBackend::" + std::string(#FUNC) + " unimplemented for type " + \ - std::string(#TYPE)); \ - } \ - Tensor StubBackend::FUNC(TYPE /* lhs */, const Tensor& /* a */) { \ - throw std::runtime_error( \ - "StubBackend::" + std::string(#FUNC) + " unimplemented for type " + \ - std::string(#TYPE)); \ - } - -#define FL_AF_BINARY_OP_LITERALS_DEF(FUNC, OP) \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const bool&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const int&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const char&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned char&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const long&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned long&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const long long&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned long long&); \ - 
FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const double&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const float&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const short&); \ - FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned short&); +#define FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, TYPE) \ + Tensor StubBackend::FUNC(const Tensor& /* a */, TYPE /* rhs */) { \ + throw std::runtime_error( \ + "StubBackend::" + std::string(#FUNC) + " unimplemented for type " + \ + std::string(#TYPE) \ + ); \ + } \ + Tensor StubBackend::FUNC(TYPE /* lhs */, const Tensor& /* a */) { \ + throw std::runtime_error( \ + "StubBackend::" + std::string(#FUNC) + " unimplemented for type " + \ + std::string(#TYPE) \ + ); \ + } + +#define FL_AF_BINARY_OP_LITERALS_DEF(FUNC, OP) \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const bool&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const int&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const char&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned char&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const long&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned long&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const long long&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned long long&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const double&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const float&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const short&); \ + FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, const unsigned short&); // Operations on fl::Tensor call the respective operator overloads that are // already defined on af::arrays -#define FL_AF_BINARY_OP_DEF(OP, FUNC) \ - Tensor StubBackend::FUNC(const Tensor& /* lhs */, const Tensor& /* rhs */) { \ - throw std::runtime_error( \ - "StubBackend::" + std::string(#FUNC) + \ - " unimplemented for two-Tensor inputs."); \ - } \ - FL_AF_BINARY_OP_LITERALS_DEF(FUNC, OP); +#define FL_AF_BINARY_OP_DEF(OP, FUNC) \ + Tensor StubBackend::FUNC(const Tensor& /* lhs */, const Tensor& /* rhs */) { 
\ + throw std::runtime_error( \ + "StubBackend::" + std::string(#FUNC) + \ + " unimplemented for two-Tensor inputs." \ + ); \ + } \ + FL_AF_BINARY_OP_LITERALS_DEF(FUNC, OP); // Definitions // Since ArrayFire implements operator overloads, map both fl::Tensor @@ -366,15 +389,15 @@ FL_AF_BINARY_OP_DEF(>>, rShift); #undef FL_AF_BINARY_OP_LITERALS_DEF Tensor StubBackend::minimum(const Tensor& /* lhs */, const Tensor& /* rhs */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::maximum(const Tensor& /* lhs */, const Tensor& /* rhs */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::power(const Tensor& /* lhs */, const Tensor& /* rhs */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } /************************** BLAS ***************************/ @@ -383,8 +406,9 @@ Tensor StubBackend::matmul( const Tensor& /* lhs */, const Tensor& /* rhs */, MatrixProperty /* lhsProp */, - MatrixProperty /* rhsProp */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + MatrixProperty /* rhsProp */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } /************************** Reductions ***************************/ @@ -392,15 +416,17 @@ Tensor StubBackend::matmul( Tensor StubBackend::amin( const Tensor& /* input */, const std::vector& /* axes */, - const bool /* keepDims */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const bool /* keepDims */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::amax( const Tensor& /* input */, const std::vector& /* axes */, - const bool /* keepDims */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const bool /* keepDims */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } void StubBackend::min( @@ -408,8 +434,9 @@ void StubBackend::min( Tensor& /* indices */, const Tensor& /* input */, const unsigned /* axis */, - const bool /* keepDims */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const bool /* keepDims */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } void StubBackend::max( @@ -417,97 +444,110 @@ void 
StubBackend::max( Tensor& /* indices */, const Tensor& /* input */, const unsigned /* axis */, - const bool /* keepDims */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const bool /* keepDims */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::sum( const Tensor& /* input */, const std::vector& /* axes */, - const bool /* keepDims */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const bool /* keepDims */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::cumsum( const Tensor& /* input */, - const unsigned /* axis */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const unsigned /* axis */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::argmax( const Tensor& /* input */, const unsigned /* axis */, - const bool /* keepDims */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const bool /* keepDims */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::argmin( const Tensor& /* input */, const unsigned /* axis */, - const bool /* keepDims */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const bool /* keepDims */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::mean( const Tensor& /* input */, const std::vector& /* axes */, - const bool /* keepDims */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const bool /* keepDims */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::median( const Tensor& /* input */, const std::vector& /* axes */, - const bool /* keepDims */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const bool /* keepDims */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::var( const Tensor& /* input */, const std::vector& /* axes */, const bool /* bias */, - const bool /* keepDims */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const bool /* keepDims */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::std( const Tensor& /* input */, const std::vector& /* axes */, - const bool /* keepDims */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const bool /* keepDims */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::norm( const Tensor& /* input */, 
const std::vector& /* axes */, double /* p */ /* = 2 */, - const bool /* keepDims */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const bool /* keepDims */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::countNonzero( const Tensor& /* input */, const std::vector& /* axes */, - const bool /* keepDims */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const bool /* keepDims */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::any( const Tensor& /* input */, const std::vector& /* axes */, - const bool /* keepDims */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const bool /* keepDims */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::all( const Tensor& /* input */, const std::vector& /* axes */, - const bool /* keepDims */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + const bool /* keepDims */ +) { + FL_STUB_BACKEND_UNIMPLEMENTED; } void StubBackend::print(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; + FL_STUB_BACKEND_UNIMPLEMENTED; } } // namespace fl diff --git a/flashlight/fl/tensor/backend/stub/StubBackend.h b/flashlight/fl/tensor/backend/stub/StubBackend.h index 7f5e390..062d2d3 100644 --- a/flashlight/fl/tensor/backend/stub/StubBackend.h +++ b/flashlight/fl/tensor/backend/stub/StubBackend.h @@ -18,242 +18,259 @@ namespace fl { * This stub can be copied, renamed, and implemented as needed. 
*/ class StubBackend : public TensorBackend { - public: - StubBackend(); +public: + StubBackend(); - static StubBackend& getInstance(); - ~StubBackend() override = default; - TensorBackendType backendType() const override; + static StubBackend& getInstance(); + ~StubBackend() override = default; + TensorBackendType backendType() const override; - // No copy or move construction or assignment - StubBackend(StubBackend&&) = delete; - StubBackend(const StubBackend&) = delete; - StubBackend& operator=(StubBackend&&) = delete; - StubBackend& operator=(const StubBackend&) = delete; + // No copy or move construction or assignment + StubBackend(StubBackend&&) = delete; + StubBackend(const StubBackend&) = delete; + StubBackend& operator=(StubBackend&&) = delete; + StubBackend& operator=(const StubBackend&) = delete; - /* -------------------------- Compute Functions -------------------------- */ - void eval(const Tensor& tensor) override; - bool supportsDataType(const fl::dtype& dtype) const override; - // Memory management - void getMemMgrInfo(const char* msg, const int deviceId, std::ostream* ostream) - override; - void setMemMgrLogStream(std::ostream* stream) override; - void setMemMgrLoggingEnabled(const bool enabled) override; - void setMemMgrFlushInterval(const size_t interval) override; + /* -------------------------- Compute Functions -------------------------- */ + void eval(const Tensor& tensor) override; + bool supportsDataType(const fl::dtype& dtype) const override; + // Memory management + void getMemMgrInfo(const char* msg, const int deviceId, std::ostream* ostream) + override; + void setMemMgrLogStream(std::ostream* stream) override; + void setMemMgrLoggingEnabled(const bool enabled) override; + void setMemMgrFlushInterval(const size_t interval) override; - /* -------------------------- Rand Functions -------------------------- */ - void setSeed(const int seed) override; - Tensor randn(const Shape& shape, dtype type) override; - Tensor rand(const Shape& shape, 
dtype type) override; + /* -------------------------- Rand Functions -------------------------- */ + void setSeed(const int seed) override; + Tensor randn(const Shape& shape, dtype type) override; + Tensor rand(const Shape& shape, dtype type) override; - /* --------------------------- Tensor Operators --------------------------- */ - /******************** Tensor Creation Functions ********************/ -#define FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(TYPE) \ - Tensor fromScalar(TYPE value, const dtype type) override; \ - Tensor full(const Shape& dims, TYPE value, const dtype type) override; - FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const double&); - FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const float&); - FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const int&); - FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned&); - FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const char&); - FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned char&); - FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const long&); - FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned long&); - FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const long long&); - FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned long long&); - FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const bool&); - FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const short&); - FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned short&); + /* --------------------------- Tensor Operators --------------------------- */ + /******************** Tensor Creation Functions ********************/ +#define FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(TYPE) \ + Tensor fromScalar(TYPE value, const dtype type) override; \ + Tensor full(const Shape& dims, TYPE value, const dtype type) override; + FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const double&); + FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const float&); + FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const int&); + FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned&); + 
FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const char&); + FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned char&); + FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const long&); + FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned long&); + FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const long long&); + FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned long long&); + FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const bool&); + FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const short&); + FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL(const unsigned short&); #undef FL_STUB_BACKEND_CREATE_FUN_LITERAL_DECL - Tensor identity(const Dim dim, const dtype type) override; - Tensor arange(const Shape& shape, const Dim seqDim, const dtype type) - override; - Tensor iota(const Shape& dims, const Shape& tileDims, const dtype type) - override; + Tensor identity(const Dim dim, const dtype type) override; + Tensor arange(const Shape& shape, const Dim seqDim, const dtype type) + override; + Tensor iota(const Shape& dims, const Shape& tileDims, const dtype type) + override; - /************************ Shaping and Indexing *************************/ - Tensor reshape(const Tensor& tensor, const Shape& shape) override; - Tensor transpose(const Tensor& tensor, const Shape& axes /* = {} */) override; - Tensor tile(const Tensor& tensor, const Shape& shape) override; - Tensor concatenate(const std::vector& tensors, const unsigned axis) - override; - Tensor nonzero(const Tensor& tensor) override; - Tensor pad( - const Tensor& input, - const std::vector>& padWidths, - const PadType type) override; + /************************ Shaping and Indexing *************************/ + Tensor reshape(const Tensor& tensor, const Shape& shape) override; + Tensor transpose(const Tensor& tensor, const Shape& axes /* = {} */) override; + Tensor tile(const Tensor& tensor, const Shape& shape) override; + Tensor concatenate(const std::vector& tensors, const unsigned axis) + override; + Tensor nonzero(const Tensor& tensor) 
override; + Tensor pad( + const Tensor& input, + const std::vector>& padWidths, + const PadType type + ) override; - /************************** Unary Operators ***************************/ - Tensor exp(const Tensor& tensor) override; - Tensor log(const Tensor& tensor) override; - Tensor negative(const Tensor& tensor) override; - Tensor logicalNot(const Tensor& tensor) override; - Tensor log1p(const Tensor& tensor) override; - Tensor sin(const Tensor& tensor) override; - Tensor cos(const Tensor& tensor) override; - Tensor sqrt(const Tensor& tensor) override; - Tensor tanh(const Tensor& tensor) override; - Tensor floor(const Tensor& tensor) override; - Tensor ceil(const Tensor& tensor) override; - Tensor rint(const Tensor& tensor) override; - Tensor absolute(const Tensor& tensor) override; - Tensor sigmoid(const Tensor& tensor) override; - Tensor erf(const Tensor& tensor) override; - Tensor flip(const Tensor& tensor, const unsigned dim) override; - Tensor clip(const Tensor& tensor, const Tensor& low, const Tensor& high) - override; - Tensor roll(const Tensor& tensor, const int shift, const unsigned axis) - override; - Tensor isnan(const Tensor& tensor) override; - Tensor isinf(const Tensor& tensor) override; - Tensor sign(const Tensor& tensor) override; - Tensor tril(const Tensor& tensor) override; - Tensor triu(const Tensor& tensor) override; - Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y) - override; - void topk( - Tensor& values, - Tensor& indices, - const Tensor& input, - const unsigned k, - const Dim axis, - const SortMode sortMode) override; - Tensor sort(const Tensor& input, const Dim axis, const SortMode sortMode) - override; - void sort( - Tensor& values, - Tensor& indices, - const Tensor& input, - const Dim axis, - const SortMode sortMode) override; - Tensor argsort(const Tensor& input, const Dim axis, const SortMode sortMode) - override; + /************************** Unary Operators ***************************/ + Tensor exp(const 
Tensor& tensor) override; + Tensor log(const Tensor& tensor) override; + Tensor negative(const Tensor& tensor) override; + Tensor logicalNot(const Tensor& tensor) override; + Tensor log1p(const Tensor& tensor) override; + Tensor sin(const Tensor& tensor) override; + Tensor cos(const Tensor& tensor) override; + Tensor sqrt(const Tensor& tensor) override; + Tensor tanh(const Tensor& tensor) override; + Tensor floor(const Tensor& tensor) override; + Tensor ceil(const Tensor& tensor) override; + Tensor rint(const Tensor& tensor) override; + Tensor absolute(const Tensor& tensor) override; + Tensor sigmoid(const Tensor& tensor) override; + Tensor erf(const Tensor& tensor) override; + Tensor flip(const Tensor& tensor, const unsigned dim) override; + Tensor clip(const Tensor& tensor, const Tensor& low, const Tensor& high) + override; + Tensor roll(const Tensor& tensor, const int shift, const unsigned axis) + override; + Tensor isnan(const Tensor& tensor) override; + Tensor isinf(const Tensor& tensor) override; + Tensor sign(const Tensor& tensor) override; + Tensor tril(const Tensor& tensor) override; + Tensor triu(const Tensor& tensor) override; + Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y) + override; + void topk( + Tensor& values, + Tensor& indices, + const Tensor& input, + const unsigned k, + const Dim axis, + const SortMode sortMode + ) override; + Tensor sort(const Tensor& input, const Dim axis, const SortMode sortMode) + override; + void sort( + Tensor& values, + Tensor& indices, + const Tensor& input, + const Dim axis, + const SortMode sortMode + ) override; + Tensor argsort(const Tensor& input, const Dim axis, const SortMode sortMode) + override; - /************************** Binary Operators ***************************/ -#define FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, TYPE) \ - Tensor FUNC(const Tensor& a, TYPE rhs) override; \ - Tensor FUNC(TYPE lhs, const Tensor& a) override; + /************************** Binary Operators 
***************************/ +#define FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, TYPE) \ + Tensor FUNC(const Tensor& a, TYPE rhs) override; \ + Tensor FUNC(TYPE lhs, const Tensor& a) override; -#define FL_STUB_BACKEND_BINARY_OP_LITERALS_DECL(FUNC) \ - FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const bool&); \ - FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const int&); \ - FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const unsigned&); \ - FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const char&); \ - FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const unsigned char&); \ - FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const long&); \ - FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const unsigned long&); \ - FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const long long&); \ - FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const unsigned long long&); \ - FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const double&); \ - FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const float&); \ - FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const short&); \ - FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const unsigned short&); +#define FL_STUB_BACKEND_BINARY_OP_LITERALS_DECL(FUNC) \ + FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const bool&); \ + FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const int&); \ + FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const unsigned&); \ + FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const char&); \ + FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const unsigned char&); \ + FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const long&); \ + FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const unsigned long&); \ + FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const long long&); \ + FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const unsigned long long&); \ + FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const double&); \ + FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const float&); \ + FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const short&); \ + FL_STUB_BACKEND_BINARY_OP_TYPE_DECL(FUNC, const unsigned short&); -#define 
FL_STUB_BACKEND_BINARY_OP_DECL(FUNC) \ - Tensor FUNC(const Tensor& lhs, const Tensor& rhs) override; \ - FL_STUB_BACKEND_BINARY_OP_LITERALS_DECL(FUNC); +#define FL_STUB_BACKEND_BINARY_OP_DECL(FUNC) \ + Tensor FUNC(const Tensor& lhs, const Tensor& rhs) override; \ + FL_STUB_BACKEND_BINARY_OP_LITERALS_DECL(FUNC); - FL_STUB_BACKEND_BINARY_OP_DECL(add); - FL_STUB_BACKEND_BINARY_OP_DECL(sub); - FL_STUB_BACKEND_BINARY_OP_DECL(mul); - FL_STUB_BACKEND_BINARY_OP_DECL(div); - FL_STUB_BACKEND_BINARY_OP_DECL(eq); - FL_STUB_BACKEND_BINARY_OP_DECL(neq); - FL_STUB_BACKEND_BINARY_OP_DECL(lessThan); - FL_STUB_BACKEND_BINARY_OP_DECL(lessThanEqual); - FL_STUB_BACKEND_BINARY_OP_DECL(greaterThan); - FL_STUB_BACKEND_BINARY_OP_DECL(greaterThanEqual); - FL_STUB_BACKEND_BINARY_OP_DECL(logicalOr); - FL_STUB_BACKEND_BINARY_OP_DECL(logicalAnd); - FL_STUB_BACKEND_BINARY_OP_DECL(mod); - FL_STUB_BACKEND_BINARY_OP_DECL(bitwiseAnd); - FL_STUB_BACKEND_BINARY_OP_DECL(bitwiseOr); - FL_STUB_BACKEND_BINARY_OP_DECL(bitwiseXor); - FL_STUB_BACKEND_BINARY_OP_DECL(lShift); - FL_STUB_BACKEND_BINARY_OP_DECL(rShift); + FL_STUB_BACKEND_BINARY_OP_DECL(add); + FL_STUB_BACKEND_BINARY_OP_DECL(sub); + FL_STUB_BACKEND_BINARY_OP_DECL(mul); + FL_STUB_BACKEND_BINARY_OP_DECL(div); + FL_STUB_BACKEND_BINARY_OP_DECL(eq); + FL_STUB_BACKEND_BINARY_OP_DECL(neq); + FL_STUB_BACKEND_BINARY_OP_DECL(lessThan); + FL_STUB_BACKEND_BINARY_OP_DECL(lessThanEqual); + FL_STUB_BACKEND_BINARY_OP_DECL(greaterThan); + FL_STUB_BACKEND_BINARY_OP_DECL(greaterThanEqual); + FL_STUB_BACKEND_BINARY_OP_DECL(logicalOr); + FL_STUB_BACKEND_BINARY_OP_DECL(logicalAnd); + FL_STUB_BACKEND_BINARY_OP_DECL(mod); + FL_STUB_BACKEND_BINARY_OP_DECL(bitwiseAnd); + FL_STUB_BACKEND_BINARY_OP_DECL(bitwiseOr); + FL_STUB_BACKEND_BINARY_OP_DECL(bitwiseXor); + FL_STUB_BACKEND_BINARY_OP_DECL(lShift); + FL_STUB_BACKEND_BINARY_OP_DECL(rShift); #undef FL_STUB_BACKEND_BINARY_OP_DECL #undef FL_STUB_BACKEND_BINARY_OP_TYPE_DECL #undef FL_STUB_BACKEND_BINARY_OP_LITERALS_DECL - 
Tensor minimum(const Tensor& lhs, const Tensor& rhs) override; - Tensor maximum(const Tensor& lhs, const Tensor& rhs) override; - Tensor power(const Tensor& lhs, const Tensor& rhs) override; + Tensor minimum(const Tensor& lhs, const Tensor& rhs) override; + Tensor maximum(const Tensor& lhs, const Tensor& rhs) override; + Tensor power(const Tensor& lhs, const Tensor& rhs) override; - /******************************* BLAS ********************************/ - Tensor matmul( - const Tensor& lhs, - const Tensor& rhs, - MatrixProperty lhsProp, - MatrixProperty rhsProp) override; + /******************************* BLAS ********************************/ + Tensor matmul( + const Tensor& lhs, + const Tensor& rhs, + MatrixProperty lhsProp, + MatrixProperty rhsProp + ) override; - /************************** Reductions ***************************/ - Tensor amin( - const Tensor& input, - const std::vector& axes, - const bool keepDims) override; - Tensor amax( - const Tensor& input, - const std::vector& axes, - const bool keepDims) override; - void min( - Tensor& values, - Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims) override; - void max( - Tensor& values, - Tensor& indices, - const Tensor& input, - const unsigned axis, - const bool keepDims) override; - Tensor sum( - const Tensor& input, - const std::vector& axes, - const bool keepDims) override; - Tensor cumsum(const Tensor& input, const unsigned axis) override; - Tensor argmax(const Tensor& input, const unsigned axis, const bool keepDims) - override; - Tensor argmin(const Tensor& input, const unsigned axis, const bool keepDims) - override; - Tensor mean( - const Tensor& input, - const std::vector& axes, - const bool keepDims) override; - Tensor median( - const Tensor& input, - const std::vector& axes, - const bool keepDims) override; - Tensor var( - const Tensor& input, - const std::vector& axes, - const bool bias, - const bool keepDims) override; - Tensor std( - const Tensor& input, - 
const std::vector& axes, - const bool keepDims) override; - Tensor norm( - const Tensor& input, - const std::vector& axes, - double p, - const bool keepDims) override; - Tensor countNonzero( - const Tensor& input, - const std::vector& axes, - const bool keepDims) override; - Tensor any( - const Tensor& input, - const std::vector& axes, - const bool keepDims) override; - Tensor all( - const Tensor& input, - const std::vector& axes, - const bool keepDims) override; + /************************** Reductions ***************************/ + Tensor amin( + const Tensor& input, + const std::vector& axes, + const bool keepDims + ) override; + Tensor amax( + const Tensor& input, + const std::vector& axes, + const bool keepDims + ) override; + void min( + Tensor& values, + Tensor& indices, + const Tensor& input, + const unsigned axis, + const bool keepDims + ) override; + void max( + Tensor& values, + Tensor& indices, + const Tensor& input, + const unsigned axis, + const bool keepDims + ) override; + Tensor sum( + const Tensor& input, + const std::vector& axes, + const bool keepDims + ) override; + Tensor cumsum(const Tensor& input, const unsigned axis) override; + Tensor argmax(const Tensor& input, const unsigned axis, const bool keepDims) + override; + Tensor argmin(const Tensor& input, const unsigned axis, const bool keepDims) + override; + Tensor mean( + const Tensor& input, + const std::vector& axes, + const bool keepDims + ) override; + Tensor median( + const Tensor& input, + const std::vector& axes, + const bool keepDims + ) override; + Tensor var( + const Tensor& input, + const std::vector& axes, + const bool bias, + const bool keepDims + ) override; + Tensor std( + const Tensor& input, + const std::vector& axes, + const bool keepDims + ) override; + Tensor norm( + const Tensor& input, + const std::vector& axes, + double p, + const bool keepDims + ) override; + Tensor countNonzero( + const Tensor& input, + const std::vector& axes, + const bool keepDims + ) override; + 
Tensor any( + const Tensor& input, + const std::vector& axes, + const bool keepDims + ) override; + Tensor all( + const Tensor& input, + const std::vector& axes, + const bool keepDims + ) override; - /************************** Utils ***************************/ - void print(const Tensor& tensor) override; + /************************** Utils ***************************/ + void print(const Tensor& tensor) override; }; } // namespace fl diff --git a/flashlight/fl/tensor/backend/stub/StubTensor.cpp b/flashlight/fl/tensor/backend/stub/StubTensor.cpp index 2bb5db9..3c2c6a9 100644 --- a/flashlight/fl/tensor/backend/stub/StubTensor.cpp +++ b/flashlight/fl/tensor/backend/stub/StubTensor.cpp @@ -7,9 +7,10 @@ #include "flashlight/fl/tensor/backend/stub/StubTensor.h" -#define FL_STUB_TENSOR_UNIMPLEMENTED \ - throw std::invalid_argument( \ - "StubTensor::" + std::string(__func__) + " - unimplemented."); +#define FL_STUB_TENSOR_UNIMPLEMENTED \ + throw std::invalid_argument( \ + "StubTensor::" + std::string(__func__) + " - unimplemented." 
\ + ); namespace fl { @@ -19,7 +20,8 @@ StubTensor::StubTensor( const Shape& /* shape */, fl::dtype /* type */, const void* /* ptr */, - Location /* memoryLocation */) {} + Location /* memoryLocation */ +) {} StubTensor::StubTensor( const Dim /* nRows */, @@ -27,143 +29,145 @@ StubTensor::StubTensor( const Tensor& /* values */, const Tensor& /* rowIdx */, const Tensor& /* colIdx */, - StorageType /* storageType */) {} + StorageType /* storageType */ +) {} std::unique_ptr StubTensor::clone() const { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } Tensor StubTensor::copy() { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } Tensor StubTensor::shallowCopy() { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } TensorBackendType StubTensor::backendType() const { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } TensorBackend& StubTensor::backend() const { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } const Shape& StubTensor::shape() { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } fl::dtype StubTensor::type() { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } bool StubTensor::isSparse() { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } Location StubTensor::location() { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } void StubTensor::scalar(void* /* out */) { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } void StubTensor::device(void** /* out */) { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } void StubTensor::host(void* /* out */) { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } void StubTensor::unlock() { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } bool StubTensor::isLocked() { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } bool StubTensor::isContiguous() { - FL_STUB_TENSOR_UNIMPLEMENTED; + 
FL_STUB_TENSOR_UNIMPLEMENTED; } Shape StubTensor::strides() { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } const Stream& StubTensor::stream() const { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } Tensor StubTensor::astype(const dtype /* type */) { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } Tensor StubTensor::index(const std::vector& /* indices */) { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } Tensor StubTensor::flatten() const { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } Tensor StubTensor::flat(const Index& /* idx */) const { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } Tensor StubTensor::asContiguousTensor() { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } void StubTensor::setContext(void* /* context */) { - // Used to store arbitrary data on a Tensor - can be a noop. - FL_STUB_TENSOR_UNIMPLEMENTED; + // Used to store arbitrary data on a Tensor - can be a noop. + FL_STUB_TENSOR_UNIMPLEMENTED; } void* StubTensor::getContext() { - // Used to store arbitrary data on a Tensor - can be a noop. - FL_STUB_TENSOR_UNIMPLEMENTED; + // Used to store arbitrary data on a Tensor - can be a noop. 
+ FL_STUB_TENSOR_UNIMPLEMENTED; } std::string StubTensor::toString() { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } std::ostream& StubTensor::operator<<(std::ostream& /* ostr */) { - FL_STUB_TENSOR_UNIMPLEMENTED; + FL_STUB_TENSOR_UNIMPLEMENTED; } /******************** Assignment Operators ********************/ -#define FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, TYPE) \ - void StubTensor::OP(const TYPE& /* val */) { \ - throw std::invalid_argument( \ - "StubTensor::" + std::string(#OP) + " for type " + \ - std::string(#TYPE)); \ - } - -#define FL_STUB_TENSOR_ASSIGN_OP(OP) \ - FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, Tensor); \ - FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, double); \ - FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, float); \ - FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, int); \ - FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, unsigned); \ - FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, bool); \ - FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, char); \ - FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, unsigned char); \ - FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, short); \ - FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, unsigned short); \ - FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, long); \ - FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, unsigned long); \ - FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, long long); \ - FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, unsigned long long); +#define FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, TYPE) \ + void StubTensor::OP(const TYPE& /* val */) { \ + throw std::invalid_argument( \ + "StubTensor::" + std::string(#OP) + " for type " + \ + std::string(#TYPE) \ + ); \ + } + +#define FL_STUB_TENSOR_ASSIGN_OP(OP) \ + FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, Tensor); \ + FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, double); \ + FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, float); \ + FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, int); \ + FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, unsigned); \ + FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, bool); \ + FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, char); \ + FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, unsigned char); \ + FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, short); \ + FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, 
unsigned short); \ + FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, long); \ + FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, unsigned long); \ + FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, long long); \ + FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, unsigned long long); FL_STUB_TENSOR_ASSIGN_OP(assign); // = FL_STUB_TENSOR_ASSIGN_OP(inPlaceAdd); // += FL_STUB_TENSOR_ASSIGN_OP(inPlaceSubtract); // -= FL_STUB_TENSOR_ASSIGN_OP(inPlaceMultiply); // *= -FL_STUB_TENSOR_ASSIGN_OP(inPlaceDivide); // /= +FL_STUB_TENSOR_ASSIGN_OP(inPlaceDivide); ///= #undef FL_STUB_TENSOR_ASSIGN_OP_TYPE #undef FL_STUB_TENSOR_ASSIGN_OP diff --git a/flashlight/fl/tensor/backend/stub/StubTensor.h b/flashlight/fl/tensor/backend/stub/StubTensor.h index 96bd1f9..b5329fd 100644 --- a/flashlight/fl/tensor/backend/stub/StubTensor.h +++ b/flashlight/fl/tensor/backend/stub/StubTensor.h @@ -18,87 +18,89 @@ namespace fl { * This stub can be copied, renamed, and implemented as needed. */ class StubTensor : public TensorAdapterBase { - public: - constexpr static TensorBackendType tensorBackendType = - TensorBackendType::Stub; +public: + constexpr static TensorBackendType tensorBackendType = + TensorBackendType::Stub; - StubTensor(); + StubTensor(); - /** - * Construct a StubTensor using some data. - * - * @param[in] shape the shape of the new tensor - * @param[in] ptr the buffer containing underlying tensor data - * @param[in] type the type of the new tensor - * @param[in] memoryLocation the location of the buffer - */ - StubTensor( - const Shape& shape, - fl::dtype type, - const void* ptr, - Location memoryLocation); + /** + * Construct a StubTensor using some data. + * + * @param[in] shape the shape of the new tensor + * @param[in] ptr the buffer containing underlying tensor data + * @param[in] type the type of the new tensor + * @param[in] memoryLocation the location of the buffer + */ + StubTensor( + const Shape& shape, + fl::dtype type, + const void* ptr, + Location memoryLocation + ); - // Constructor for a sparse StubTensor. 
Can throw if unimplemented. - StubTensor( - const Dim nRows, - const Dim nCols, - const Tensor& values, - const Tensor& rowIdx, - const Tensor& colIdx, - StorageType storageType); + // Constructor for a sparse StubTensor. Can throw if unimplemented. + StubTensor( + const Dim nRows, + const Dim nCols, + const Tensor& values, + const Tensor& rowIdx, + const Tensor& colIdx, + StorageType storageType + ); - ~StubTensor() override = default; - std::unique_ptr clone() const override; - TensorBackendType backendType() const override; - TensorBackend& backend() const override; - Tensor copy() override; - Tensor shallowCopy() override; - const Shape& shape() override; - dtype type() override; - bool isSparse() override; - Location location() override; - void scalar(void* out) override; - void device(void** out) override; - void host(void* out) override; - void unlock() override; - bool isLocked() override; - bool isContiguous() override; - Shape strides() override; - const Stream& stream() const override; - Tensor astype(const dtype type) override; - Tensor index(const std::vector& indices) override; - Tensor flatten() const override; - Tensor flat(const Index& idx) const override; - Tensor asContiguousTensor() override; - void setContext(void* context) override; - void* getContext() override; - std::string toString() override; - std::ostream& operator<<(std::ostream& ostr) override; + ~StubTensor() override = default; + std::unique_ptr clone() const override; + TensorBackendType backendType() const override; + TensorBackend& backend() const override; + Tensor copy() override; + Tensor shallowCopy() override; + const Shape& shape() override; + dtype type() override; + bool isSparse() override; + Location location() override; + void scalar(void* out) override; + void device(void** out) override; + void host(void* out) override; + void unlock() override; + bool isLocked() override; + bool isContiguous() override; + Shape strides() override; + const Stream& stream() const 
override; + Tensor astype(const dtype type) override; + Tensor index(const std::vector& indices) override; + Tensor flatten() const override; + Tensor flat(const Index& idx) const override; + Tensor asContiguousTensor() override; + void setContext(void* context) override; + void* getContext() override; + std::string toString() override; + std::ostream& operator<<(std::ostream& ostr) override; - /******************** Assignment Operators ********************/ + /******************** Assignment Operators ********************/ #define ASSIGN_OP_TYPE(OP, TYPE) void OP(const TYPE& val) override; -#define ASSIGN_OP(OP) \ - ASSIGN_OP_TYPE(OP, Tensor); \ - ASSIGN_OP_TYPE(OP, double); \ - ASSIGN_OP_TYPE(OP, float); \ - ASSIGN_OP_TYPE(OP, int); \ - ASSIGN_OP_TYPE(OP, unsigned); \ - ASSIGN_OP_TYPE(OP, bool); \ - ASSIGN_OP_TYPE(OP, char); \ - ASSIGN_OP_TYPE(OP, unsigned char); \ - ASSIGN_OP_TYPE(OP, short); \ - ASSIGN_OP_TYPE(OP, unsigned short); \ - ASSIGN_OP_TYPE(OP, long); \ - ASSIGN_OP_TYPE(OP, unsigned long); \ - ASSIGN_OP_TYPE(OP, long long); \ - ASSIGN_OP_TYPE(OP, unsigned long long); +#define ASSIGN_OP(OP) \ + ASSIGN_OP_TYPE(OP, Tensor); \ + ASSIGN_OP_TYPE(OP, double); \ + ASSIGN_OP_TYPE(OP, float); \ + ASSIGN_OP_TYPE(OP, int); \ + ASSIGN_OP_TYPE(OP, unsigned); \ + ASSIGN_OP_TYPE(OP, bool); \ + ASSIGN_OP_TYPE(OP, char); \ + ASSIGN_OP_TYPE(OP, unsigned char); \ + ASSIGN_OP_TYPE(OP, short); \ + ASSIGN_OP_TYPE(OP, unsigned short); \ + ASSIGN_OP_TYPE(OP, long); \ + ASSIGN_OP_TYPE(OP, unsigned long); \ + ASSIGN_OP_TYPE(OP, long long); \ + ASSIGN_OP_TYPE(OP, unsigned long long); - ASSIGN_OP(assign); // = - ASSIGN_OP(inPlaceAdd); // += - ASSIGN_OP(inPlaceSubtract); // -= - ASSIGN_OP(inPlaceMultiply); // *= - ASSIGN_OP(inPlaceDivide); // /= + ASSIGN_OP(assign); // = + ASSIGN_OP(inPlaceAdd); // += + ASSIGN_OP(inPlaceSubtract); // -= + ASSIGN_OP(inPlaceMultiply); // *= + ASSIGN_OP(inPlaceDivide); ///= #undef ASSIGN_OP_TYPE #undef ASSIGN_OP }; diff --git 
a/flashlight/fl/tensor/profile/CUDAProfile.cpp b/flashlight/fl/tensor/profile/CUDAProfile.cpp index 13a541d..7e0d1b1 100644 --- a/flashlight/fl/tensor/profile/CUDAProfile.cpp +++ b/flashlight/fl/tensor/profile/CUDAProfile.cpp @@ -15,19 +15,19 @@ namespace fl::detail { ScopedProfiler::ScopedProfiler() { - FL_CUDA_CHECK(cudaProfilerStart()); + FL_CUDA_CHECK(cudaProfilerStart()); } ScopedProfiler::~ScopedProfiler() { - FL_CUDA_CHECK(cudaProfilerStop()); + FL_CUDA_CHECK(cudaProfilerStop()); } ProfileTracer::ProfileTracer(const std::string& name) { - nvtxRangePush(name.c_str()); + nvtxRangePush(name.c_str()); } ProfileTracer::~ProfileTracer() { - nvtxRangePop(); + nvtxRangePop(); } } // namespace fl diff --git a/flashlight/fl/tensor/profile/Profile.h b/flashlight/fl/tensor/profile/Profile.h index 705e7a5..5d65290 100644 --- a/flashlight/fl/tensor/profile/Profile.h +++ b/flashlight/fl/tensor/profile/Profile.h @@ -15,15 +15,15 @@ namespace detail { /** * An RAII abstraction to start and stop profiling recording. */ -class ScopedProfiler { - public: - ScopedProfiler(); - ~ScopedProfiler(); -}; + class ScopedProfiler { + public: + ScopedProfiler(); + ~ScopedProfiler(); + }; /** * An RAII abstractiont to label a profile interval over the lifetime for an - object given a specific scope. For example: + object given a specific scope. 
For example: * \code { ProfileTracer tr("myOperation"); @@ -32,24 +32,24 @@ class ScopedProfiler { } * \endcode */ -class ProfileTracer { - public: - explicit ProfileTracer(const std::string& name); - ~ProfileTracer(); -}; + class ProfileTracer { + public: + explicit ProfileTracer(const std::string& name); + ~ProfileTracer(); + }; } // namespace detail } // namespace fl #if FL_BUILD_PROFILING // Used to generate a unique name for the expansion -#define _FL_PROFILE_CAT(a, b) a##b +#define _FL_PROFILE_CAT(a, b) a ## b #define FL_PROFILE_TRACE(name) \ - fl::detail::ProfileTracer _FL_PROFILE_CAT(profileTracer, __LINE__)(name); + fl::detail::ProfileTracer _FL_PROFILE_CAT(profileTracer, __LINE__)(name); #define FL_SCOPED_PROFILE() \ - fl::detail::ScopedProfiler _FL_PROFILE_CAT(scopedProfile, __LINE__); + fl::detail::ScopedProfiler _FL_PROFILE_CAT(scopedProfile, __LINE__); #else #define FL_PROFILE_TRACE(_) diff --git a/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp b/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp index d0ee0d4..d0b1342 100644 --- a/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp +++ b/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp @@ -18,295 +18,304 @@ using namespace fl; using fl::detail::AutogradTestF16; TEST(AutogradBinaryOpsTest, BasicOps) { - using FuncVar = std::function; - using FuncScalarL = std::function; - using FuncScalarR = std::function; - auto testImpl = [](FuncVar fn1, FuncScalarL fn2, FuncScalarR fn3) { - auto input = Variable(fl::rand({3, 4, 5, 6}, fl::dtype::f64) + 1, true); - auto temp = Variable(fl::rand({3, 4, 5, 6}, fl::dtype::f64) - 2, false); - fl::detail::JacobianFunc fnArrL = [&](Variable& in) { return fn1(in, temp); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(fnArrL, input)); - fl::detail::JacobianFunc fnArrR = [&](Variable& in) { return fn1(temp, in); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(fnArrR, input)); - fl::detail::JacobianFunc fnScalarL = [&](Variable& in) { return fn2(1.414, in); }; - 
ASSERT_TRUE(fl::detail::jacobianTestImpl(fnScalarL, input, 1E-5, 1E-7)); - fl::detail::JacobianFunc fnScalarR = [&](Variable& in) { return fn3(in, 1.732); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(fnScalarR, input, 1E-5, 1E-7)); - }; - - FuncVar funcAdd1 = [](Variable& a, Variable& b) { return a + b; }; - FuncScalarL funcAdd2 = [](double a, Variable& b) { return a + b; }; - FuncScalarR funcAdd3 = [](Variable& a, double b) { return a + b; }; - testImpl(funcAdd1, funcAdd2, funcAdd3); - - FuncVar funcSub1 = [](Variable& a, Variable& b) { return a - b; }; - FuncScalarL funcSub2 = [](double a, Variable& b) { return a - b; }; - FuncScalarR funcSub3 = [](Variable& a, double b) { return a - b; }; - testImpl(funcSub1, funcSub2, funcSub3); - - FuncVar funcDiv1 = [](Variable& a, Variable& b) { return a / b; }; - FuncScalarL funcDiv2 = [](double a, Variable& b) { return a / b; }; - FuncScalarR funcDiv3 = [](Variable& a, double b) { return a / b; }; - testImpl(funcDiv1, funcDiv2, funcDiv3); - - FuncVar funcMul1 = [](Variable& a, Variable& b) { return a * b; }; - FuncScalarL funcMul2 = [](double a, Variable& b) { return a * b; }; - FuncScalarR funcMul3 = [](Variable& a, double b) { return a * b; }; - testImpl(funcMul1, funcMul2, funcMul3); - - FuncVar funcMin1 = [](Variable& a, Variable& b) { return min(a, b); }; - FuncScalarL funcMin2 = [](double a, Variable& b) { return min(a, b); }; - FuncScalarR funcMin3 = [](Variable& a, double b) { return min(a, b); }; - testImpl(funcMin1, funcMin2, funcMin3); - - FuncVar funcMax1 = [](Variable& a, Variable& b) { return max(a, b); }; - FuncScalarL funcMax2 = [](double a, Variable& b) { return max(a, b); }; - FuncScalarR funcMax3 = [](Variable& a, double b) { return max(a, b); }; - testImpl(funcMax1, funcMax2, funcMax3); + using FuncVar = std::function; + using FuncScalarL = std::function; + using FuncScalarR = std::function; + auto testImpl = [](FuncVar fn1, FuncScalarL fn2, FuncScalarR fn3) { + auto input = Variable(fl::rand({3, 4, 5, 
6}, fl::dtype::f64) + 1, true); + auto temp = Variable(fl::rand({3, 4, 5, 6}, fl::dtype::f64) - 2, false); + fl::detail::JacobianFunc fnArrL = [&](Variable& in) { return fn1(in, temp); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(fnArrL, input)); + fl::detail::JacobianFunc fnArrR = [&](Variable& in) { return fn1(temp, in); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(fnArrR, input)); + fl::detail::JacobianFunc fnScalarL = [&](Variable& in) { return fn2(1.414, in); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(fnScalarL, input, 1E-5, 1E-7)); + fl::detail::JacobianFunc fnScalarR = [&](Variable& in) { return fn3(in, 1.732); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(fnScalarR, input, 1E-5, 1E-7)); + }; + + FuncVar funcAdd1 = [](Variable& a, Variable& b) { return a + b; }; + FuncScalarL funcAdd2 = [](double a, Variable& b) { return a + b; }; + FuncScalarR funcAdd3 = [](Variable& a, double b) { return a + b; }; + testImpl(funcAdd1, funcAdd2, funcAdd3); + + FuncVar funcSub1 = [](Variable& a, Variable& b) { return a - b; }; + FuncScalarL funcSub2 = [](double a, Variable& b) { return a - b; }; + FuncScalarR funcSub3 = [](Variable& a, double b) { return a - b; }; + testImpl(funcSub1, funcSub2, funcSub3); + + FuncVar funcDiv1 = [](Variable& a, Variable& b) { return a / b; }; + FuncScalarL funcDiv2 = [](double a, Variable& b) { return a / b; }; + FuncScalarR funcDiv3 = [](Variable& a, double b) { return a / b; }; + testImpl(funcDiv1, funcDiv2, funcDiv3); + + FuncVar funcMul1 = [](Variable& a, Variable& b) { return a * b; }; + FuncScalarL funcMul2 = [](double a, Variable& b) { return a * b; }; + FuncScalarR funcMul3 = [](Variable& a, double b) { return a * b; }; + testImpl(funcMul1, funcMul2, funcMul3); + + FuncVar funcMin1 = [](Variable& a, Variable& b) { return min(a, b); }; + FuncScalarL funcMin2 = [](double a, Variable& b) { return min(a, b); }; + FuncScalarR funcMin3 = [](Variable& a, double b) { return min(a, b); }; + testImpl(funcMin1, funcMin2, funcMin3); + + 
FuncVar funcMax1 = [](Variable& a, Variable& b) { return max(a, b); }; + FuncScalarL funcMax2 = [](double a, Variable& b) { return max(a, b); }; + FuncScalarR funcMax3 = [](Variable& a, double b) { return max(a, b); }; + testImpl(funcMax1, funcMax2, funcMax3); } TEST(AutogradBinaryOpsTest, BinaryCrossEntropy) { - auto x = Variable(fl::rand({10}), true); - auto y = Variable(fl::rand({10}), true); - auto loss = binaryCrossEntropy(x, y); + auto x = Variable(fl::rand({10}), true); + auto y = Variable(fl::rand({10}), true); + auto loss = binaryCrossEntropy(x, y); - // bce loss should be positive - ASSERT_TRUE(fl::all(loss.tensor() > 0).scalar()); + // bce loss should be positive + ASSERT_TRUE(fl::all(loss.tensor() > 0).scalar()); } TEST(AutogradBinaryOpsTest, CrossEntropy) { - auto x = Variable(fl::rand({7, 10, 4}, fl::dtype::f64), true); - auto y = Variable( - (fl::rand({10, 4}, fl::dtype::u32) % 7).astype(fl::dtype::s32), false); - auto ignoreIdx = y(0, 0).scalar(); - - std::vector modes = { - ReduceMode::NONE, ReduceMode::SUM, ReduceMode::MEAN}; - for (auto mode : modes) { - auto func = [&](Variable& input) { - return categoricalCrossEntropy(input, y, mode); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(func, x, 1E-5)); - auto funcIgnore = [&](Variable& input) { - return categoricalCrossEntropy(input, y, mode, ignoreIdx); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcIgnore, x, 1E-5)); - } - - auto lossSum = categoricalCrossEntropy(x, y, ReduceMode::SUM); - auto lossMean = categoricalCrossEntropy(x, y, ReduceMode::MEAN); - ASSERT_NEAR((lossSum / lossMean).scalar(), 40, 1e-5); - - auto lossSumIgnore = - categoricalCrossEntropy(x, y, ReduceMode::SUM, ignoreIdx); - auto lossMeanIgnore = - categoricalCrossEntropy(x, y, ReduceMode::MEAN, ignoreIdx); - auto ignoreCount = fl::sum(y.tensor() == ignoreIdx).scalar(); - ASSERT_NEAR( - (lossSumIgnore / lossMeanIgnore).scalar(), - 40 - ignoreCount, - 1e-5); - - ASSERT_THROW( - categoricalCrossEntropy( - 
Variable(fl::rand({4, 5, 6}), false), - Variable(fl::rand({5, 8}), false)), - std::invalid_argument); - - ASSERT_THROW( - categoricalCrossEntropy( - Variable(fl::rand({4, 5, 6}), false), Variable(fl::rand({5}), false)), - std::invalid_argument); + auto x = Variable(fl::rand({7, 10, 4}, fl::dtype::f64), true); + auto y = Variable( + (fl::rand({10, 4}, fl::dtype::u32) % 7).astype(fl::dtype::s32), + false + ); + auto ignoreIdx = y(0, 0).scalar(); + + std::vector modes = { + ReduceMode::NONE, ReduceMode::SUM, ReduceMode::MEAN}; + for(auto mode : modes) { + auto func = [&](Variable& input) { + return categoricalCrossEntropy(input, y, mode); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(func, x, 1E-5)); + auto funcIgnore = [&](Variable& input) { + return categoricalCrossEntropy(input, y, mode, ignoreIdx); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcIgnore, x, 1E-5)); + } + + auto lossSum = categoricalCrossEntropy(x, y, ReduceMode::SUM); + auto lossMean = categoricalCrossEntropy(x, y, ReduceMode::MEAN); + ASSERT_NEAR((lossSum / lossMean).scalar(), 40, 1e-5); + + auto lossSumIgnore = + categoricalCrossEntropy(x, y, ReduceMode::SUM, ignoreIdx); + auto lossMeanIgnore = + categoricalCrossEntropy(x, y, ReduceMode::MEAN, ignoreIdx); + auto ignoreCount = fl::sum(y.tensor() == ignoreIdx).scalar(); + ASSERT_NEAR( + (lossSumIgnore / lossMeanIgnore).scalar(), + 40 - ignoreCount, + 1e-5 + ); + + ASSERT_THROW( + categoricalCrossEntropy( + Variable(fl::rand({4, 5, 6}), false), + Variable(fl::rand({5, 8}), false) + ), + std::invalid_argument + ); + + ASSERT_THROW( + categoricalCrossEntropy( + Variable(fl::rand({4, 5, 6}), false), + Variable(fl::rand({5}), false) + ), + std::invalid_argument + ); } TEST(AutogradBinaryOpsTest, Linear) { - std::vector batchsizes = {1, 5}; - for (auto b : batchsizes) { - auto in = Variable(fl::rand({3, 4, b}, fl::dtype::f64) * 2 - 1, true); - auto wt = Variable(fl::rand({6, 3}, fl::dtype::f64) * 2 - 1, true); - auto bs = Variable(fl::rand({6}, 
fl::dtype::f64) * 2 - 1, true); - auto funcLinIn = [&](Variable& input) { return linear(input, wt, bs); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLinIn, in, 1E-8, 1E-4, {&wt, &bs})); - auto funcLinWt = [&](Variable& weight) { return linear(in, weight, bs); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLinWt, wt, 1E-8, 1E-4, {&in, &bs})); - auto funcLinBs = [&](Variable& bias) { return linear(in, wt, bias); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLinBs, bs, 1E-8, 1E-4, {&in, &wt})); - } + std::vector batchsizes = {1, 5}; + for(auto b : batchsizes) { + auto in = Variable(fl::rand({3, 4, b}, fl::dtype::f64) * 2 - 1, true); + auto wt = Variable(fl::rand({6, 3}, fl::dtype::f64) * 2 - 1, true); + auto bs = Variable(fl::rand({6}, fl::dtype::f64) * 2 - 1, true); + auto funcLinIn = [&](Variable& input) { return linear(input, wt, bs); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLinIn, in, 1E-8, 1E-4, {&wt, &bs})); + auto funcLinWt = [&](Variable& weight) { return linear(in, weight, bs); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLinWt, wt, 1E-8, 1E-4, {&in, &bs})); + auto funcLinBs = [&](Variable& bias) { return linear(in, wt, bias); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLinBs, bs, 1E-8, 1E-4, {&in, &wt})); + } } TEST_F(AutogradTestF16, LinearF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - std::vector batchsizes = {1, 5}; - const float scale = 4.0; // scale prevent grad underflow - for (auto b : batchsizes) { - auto in = Variable(fl::rand({2, 2, b}, fl::dtype::f16) * scale, true); - auto wt = Variable(fl::rand({2, 2}, fl::dtype::f16) * scale, true); - auto bs = Variable(fl::rand({2}, fl::dtype::f16) * scale, true); - auto funcLinIn = [&](Variable& input) { return linear(input, wt, bs); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLinIn, in, 5E-2, 5E-1, {&wt, &bs})); - auto funcLinWt = [&](Variable& weight) { return linear(in, weight, bs); }; - 
ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLinWt, wt, 5E-2, 5E-1, {&in, &bs})); - auto funcLinBs = [&](Variable& bias) { return linear(in, wt, bias); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLinBs, bs, 5E-2, 5E-1, {&in, &wt})); - } + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + + std::vector batchsizes = {1, 5}; + const float scale = 4.0; // scale prevent grad underflow + for(auto b : batchsizes) { + auto in = Variable(fl::rand({2, 2, b}, fl::dtype::f16) * scale, true); + auto wt = Variable(fl::rand({2, 2}, fl::dtype::f16) * scale, true); + auto bs = Variable(fl::rand({2}, fl::dtype::f16) * scale, true); + auto funcLinIn = [&](Variable& input) { return linear(input, wt, bs); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLinIn, in, 5E-2, 5E-1, {&wt, &bs})); + auto funcLinWt = [&](Variable& weight) { return linear(in, weight, bs); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLinWt, wt, 5E-2, 5E-1, {&in, &bs})); + auto funcLinBs = [&](Variable& bias) { return linear(in, wt, bias); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLinBs, bs, 5E-2, 5E-1, {&in, &wt})); + } } TEST(AutogradBinaryOpsTest, Multiply) { - auto x = Variable(fl::rand({5}), true); - auto y = x * x; - auto dy = Variable(fl::full({5}, 1.0), false); - y.backward(dy); - auto dx = x.grad(); - ASSERT_TRUE(allClose(dx.tensor(), 2 * x.tensor())); + auto x = Variable(fl::rand({5}), true); + auto y = x * x; + auto dy = Variable(fl::full({5}, 1.0), false); + y.backward(dy); + auto dx = x.grad(); + ASSERT_TRUE(allClose(dx.tensor(), 2 * x.tensor())); } TEST(AutogradBinaryOpsTest, MultiplyAdd) { - auto x = Variable(fl::rand({5}), true); - auto y = Variable(fl::rand({5}), true); - auto z = x * x + x * y + y * y; - auto dz = Variable(fl::full({5}, 1.0), false); - z.backward(dz); - auto dx = x.grad(); - auto dy = y.grad(); - ASSERT_TRUE(allClose(dx.tensor(), 2 * x.tensor() + y.tensor())); - ASSERT_TRUE(allClose(dy.tensor(), 2 * y.tensor() + 
x.tensor())); + auto x = Variable(fl::rand({5}), true); + auto y = Variable(fl::rand({5}), true); + auto z = x * x + x * y + y * y; + auto dz = Variable(fl::full({5}, 1.0), false); + z.backward(dz); + auto dx = x.grad(); + auto dy = y.grad(); + ASSERT_TRUE(allClose(dx.tensor(), 2 * x.tensor() + y.tensor())); + ASSERT_TRUE(allClose(dy.tensor(), 2 * y.tensor() + x.tensor())); } TEST(AutogradBinaryOpsTest, MultiplyAddScalar) { - auto x = Variable(fl::rand({5}), true); - auto y = Variable(fl::rand({5}), true); - auto z = 2 * x + x * y + y; - auto dz = Variable(fl::full({5}, 1.0), false); - z.backward(dz); - auto dx = x.grad(); - auto dy = y.grad(); - ASSERT_TRUE(allClose(dx.tensor(), (2.0 + y.tensor()))); - ASSERT_TRUE(allClose(dy.tensor(), (1.0 + x.tensor()))); + auto x = Variable(fl::rand({5}), true); + auto y = Variable(fl::rand({5}), true); + auto z = 2 * x + x * y + y; + auto dz = Variable(fl::full({5}, 1.0), false); + z.backward(dz); + auto dx = x.grad(); + auto dy = y.grad(); + ASSERT_TRUE(allClose(dx.tensor(), (2.0 + y.tensor()))); + ASSERT_TRUE(allClose(dy.tensor(), (1.0 + x.tensor()))); } TEST(AutogradBinaryOpsTest, MultiplySub) { - auto x = Variable(fl::rand({5}), true); - auto y = Variable(fl::rand({5}), true); - auto z = x * x - x * y; - auto dz = Variable(fl::full({5}, 1.0), false); - z.backward(dz); - auto dx = x.grad(); - auto dy = y.grad(); - ASSERT_TRUE(allClose(dx.tensor(), (2 * x.tensor() - y.tensor()))); - ASSERT_TRUE(allClose(dy.tensor(), (-x.tensor()))); + auto x = Variable(fl::rand({5}), true); + auto y = Variable(fl::rand({5}), true); + auto z = x * x - x * y; + auto dz = Variable(fl::full({5}, 1.0), false); + z.backward(dz); + auto dx = x.grad(); + auto dy = y.grad(); + ASSERT_TRUE(allClose(dx.tensor(), (2 * x.tensor() - y.tensor()))); + ASSERT_TRUE(allClose(dy.tensor(), (-x.tensor()))); } TEST(AutogradBinaryOpsTest, DivideAdd) { - auto x = Variable(fl::rand({5}, fl::dtype::f64), true); - auto y = Variable(fl::rand({5}, fl::dtype::f64), true); 
- auto z = x + x / y + y; - auto dz = Variable(fl::full({5}, 1.0, fl::dtype::f64), false); - z.backward(dz); - auto dx = x.grad(); - auto dy = y.grad(); - ASSERT_EQ(z.type(), fl::dtype::f64); - ASSERT_TRUE(allClose(dx.tensor(), (1.0 + 1.0 / y.tensor()))); - ASSERT_TRUE( - allClose(dy.tensor(), (1.0 - x.tensor() / (y.tensor() * y.tensor())))); + auto x = Variable(fl::rand({5}, fl::dtype::f64), true); + auto y = Variable(fl::rand({5}, fl::dtype::f64), true); + auto z = x + x / y + y; + auto dz = Variable(fl::full({5}, 1.0, fl::dtype::f64), false); + z.backward(dz); + auto dx = x.grad(); + auto dy = y.grad(); + ASSERT_EQ(z.type(), fl::dtype::f64); + ASSERT_TRUE(allClose(dx.tensor(), (1.0 + 1.0 / y.tensor()))); + ASSERT_TRUE( + allClose(dy.tensor(), (1.0 - x.tensor() / (y.tensor() * y.tensor()))) + ); } TEST(AutogradBinaryOpsTest, matmul) { - unsigned M = 10; - unsigned K = 12; - unsigned N = 14; - unsigned b2 = 2; - unsigned b3 = 4; - auto mk = Shape({M, K}); - auto mkb2 = Shape({M, K, b2}); // 1 batch dim - auto mkb2b3 = Shape({M, K, b2, b3}); // 2 batch dims - auto kn = Shape({K, N}); - auto knb2 = Shape({K, N, b2}); // 1 batch dim - auto knb2b3 = Shape({K, N, b2, b3}); // 2 batch dims - - // lhs, rhs - std::vector> inputs = { - {mk, kn}, - {mk, knb2}, - {mk, knb2b3}, - {mkb2, kn}, - {mkb2, knb2}, - {mkb2b3, kn}, - {mkb2b3, knb2b3}}; - - auto trFirstTwoDims = [](const Shape& in) -> Shape { - Shape out = in; - auto out1 = out[1]; - out[1] = out[0]; - out[0] = out1; - return out; - }; - - for (auto& pair : inputs) { - auto& aShape = pair.first; - auto& bShape = pair.second; - - auto a = Variable(fl::rand(aShape, fl::dtype::f64) * 2 - 1, true); - auto b = Variable(fl::rand(bShape, fl::dtype::f64) * 2 - 1, true); - - auto aT = Variable(fl::rand(trFirstTwoDims(aShape), fl::dtype::f64), true); - auto bT = Variable(fl::rand(trFirstTwoDims(bShape), fl::dtype::f64), true); - - // matmul - auto funcMatmulLhs = [&](Variable& input) { return matmul(input, b); }; - 
ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulLhs, a, 1E-6)) + unsigned M = 10; + unsigned K = 12; + unsigned N = 14; + unsigned b2 = 2; + unsigned b3 = 4; + auto mk = Shape({M, K}); + auto mkb2 = Shape({M, K, b2}); // 1 batch dim + auto mkb2b3 = Shape({M, K, b2, b3}); // 2 batch dims + auto kn = Shape({K, N}); + auto knb2 = Shape({K, N, b2}); // 1 batch dim + auto knb2b3 = Shape({K, N, b2, b3}); // 2 batch dims + + // lhs, rhs + std::vector> inputs = { + {mk, kn}, + {mk, knb2}, + {mk, knb2b3}, + {mkb2, kn}, + {mkb2, knb2}, + {mkb2b3, kn}, + {mkb2b3, knb2b3}}; + + auto trFirstTwoDims = [](const Shape& in) -> Shape { + Shape out = in; + auto out1 = out[1]; + out[1] = out[0]; + out[0] = out1; + return out; + }; + + for(auto& pair : inputs) { + auto& aShape = pair.first; + auto& bShape = pair.second; + + auto a = Variable(fl::rand(aShape, fl::dtype::f64) * 2 - 1, true); + auto b = Variable(fl::rand(bShape, fl::dtype::f64) * 2 - 1, true); + + auto aT = Variable(fl::rand(trFirstTwoDims(aShape), fl::dtype::f64), true); + auto bT = Variable(fl::rand(trFirstTwoDims(bShape), fl::dtype::f64), true); + + // matmul + auto funcMatmulLhs = [&](Variable& input) { return matmul(input, b); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulLhs, a, 1E-6)) << "matmul lhs gradient: lhs " << a.shape() << " rhs " << b.shape(); - auto funcMatmulRhs = [&](Variable& input) { return matmul(a, input); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulRhs, b, 1E-6)) + auto funcMatmulRhs = [&](Variable& input) { return matmul(a, input); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulRhs, b, 1E-6)) << "matmul rhs gradient: lhs " << a.shape() << " rhs " << b.shape(); - // matmulTN - auto funcMatmulTNLhs = [&](Variable& input) { return matmulTN(input, b); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulTNLhs, aT, 1E-6)) + // matmulTN + auto funcMatmulTNLhs = [&](Variable& input) { return matmulTN(input, b); }; + 
ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulTNLhs, aT, 1E-6)) << "matmulTN lhs gradient: lhs " << a.shape() << " rhs " << b.shape(); - auto funcMatmulTNRhs = [&](Variable& input) { return matmulTN(aT, input); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulTNRhs, b, 1E-6)) + auto funcMatmulTNRhs = [&](Variable& input) { return matmulTN(aT, input); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulTNRhs, b, 1E-6)) << "matmulTN rhs gradient: lhs " << a.shape() << " rhs " << b.shape(); - // matmulNT - auto funcMatmulNTLhs = [&](Variable& input) { return matmulNT(input, bT); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulNTLhs, a, 1E-6)) + // matmulNT + auto funcMatmulNTLhs = [&](Variable& input) { return matmulNT(input, bT); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulNTLhs, a, 1E-6)) << "matmulTN lhs gradient: lhs " << a.shape() << " rhs " << b.shape(); - auto funcMatmulNTRhs = [&](Variable& input) { return matmulNT(a, input); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulNTRhs, bT, 1E-6)) + auto funcMatmulNTRhs = [&](Variable& input) { return matmulNT(a, input); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulNTRhs, bT, 1E-6)) << "matmulTN rhs gradient: lhs " << a.shape() << " rhs " << b.shape(); - } + } } TEST(AutogradNormalizationTest, WeightNormLinear) { - auto v = Variable(fl::rand({3, 2}), true); - auto normDim = {1}; - auto g = Variable(norm(v, normDim).tensor(), true); - auto in = Variable(fl::rand({2, 3}, fl::dtype::f32), true); - - auto funcWeightNormIn = [&](Variable& input) { - auto w = v * tileAs(g / norm(v, normDim), v); - return matmul(w, input); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcWeightNormIn, in, 1E-3)); - - auto funcWeightNormV = [&](Variable& input) { - auto w = input * tileAs(g / norm(input, normDim), input); - return matmul(w, in); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcWeightNormV, v, 1E-2)); - - auto funcWeightNormG = [&](Variable& input) { - auto w = v * 
tileAs(input / norm(v, normDim), v); - return matmul(w, in); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcWeightNormG, g, 5E-3)); + auto v = Variable(fl::rand({3, 2}), true); + auto normDim = {1}; + auto g = Variable(norm(v, normDim).tensor(), true); + auto in = Variable(fl::rand({2, 3}, fl::dtype::f32), true); + + auto funcWeightNormIn = [&](Variable& input) { + auto w = v * tileAs(g / norm(v, normDim), v); + return matmul(w, input); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcWeightNormIn, in, 1E-3)); + + auto funcWeightNormV = [&](Variable& input) { + auto w = input * tileAs(g / norm(input, normDim), input); + return matmul(w, in); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcWeightNormV, v, 1E-2)); + + auto funcWeightNormG = [&](Variable& input) { + auto w = v * tileAs(input / norm(v, normDim), v); + return matmul(w, in); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcWeightNormG, g, 5E-3)); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/autograd/AutogradConv2DTest.cpp b/flashlight/fl/test/autograd/AutogradConv2DTest.cpp index c13f3c6..21329a5 100644 --- a/flashlight/fl/test/autograd/AutogradConv2DTest.cpp +++ b/flashlight/fl/test/autograd/AutogradConv2DTest.cpp @@ -19,265 +19,283 @@ using namespace fl; using fl::detail::AutogradTestF16; TEST(AutogradConv2DTest, Convolve) { - auto in = Variable(fl::rand({10, 9, 8, 7}, fl::dtype::f32), true); - auto wt = Variable(fl::rand({4, 3, 8, 6}, fl::dtype::f32), true); - auto bs = Variable(fl::rand({1, 1, 6, 1}, fl::dtype::f32), true); + auto in = Variable(fl::rand({10, 9, 8, 7}, fl::dtype::f32), true); + auto wt = Variable(fl::rand({4, 3, 8, 6}, fl::dtype::f32), true); + auto bs = Variable(fl::rand({1, 1, 6, 1}, fl::dtype::f32), true); - int px = 2, py = 1; - int sx = 1, sy = 1; - int dx = 1, dy = 1; - auto 
benchmarks = std::make_shared(); - auto funcConvIn = [&](Variable& input) { - return conv2d( - input, - wt, - // bs, - sx, - sy, - px, - py, - dx, - dy, - /* groups */ 1, - benchmarks); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvIn, in, 0.06, 1E-4, {&wt})); + int px = 2, py = 1; + int sx = 1, sy = 1; + int dx = 1, dy = 1; + auto benchmarks = std::make_shared(); + auto funcConvIn = [&](Variable& input) { + return conv2d( + input, + wt, + // bs, + sx, + sy, + px, + py, + dx, + dy, + /* groups */ 1, + benchmarks + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvIn, in, 0.06, 1E-4, {&wt})); - auto funcConvWt = [&](Variable& weight) { - return conv2d( - in, - weight, - // bs, - sx, - sy, - px, - py, - dx, - dy, - /* groups */ 1, - benchmarks); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvWt, wt, 0.06, 1E-4, {&in})); + auto funcConvWt = [&](Variable& weight) { + return conv2d( + in, + weight, + // bs, + sx, + sy, + px, + py, + dx, + dy, + /* groups */ 1, + benchmarks + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvWt, wt, 0.06, 1E-4, {&in})); - auto funcConvBs = [&](Variable& bias) { - return conv2d( - in, - wt, - bias, - sx, - sy, - px, - py, - dx, - dy, - /* groups */ 1, - benchmarks); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvBs, bs, 0.03, 1E-4, {&in, &wt})); + auto funcConvBs = [&](Variable& bias) { + return conv2d( + in, + wt, + bias, + sx, + sy, + px, + py, + dx, + dy, + /* groups */ 1, + benchmarks + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvBs, bs, 0.03, 1E-4, {&in, &wt})); } TEST_F(AutogradTestF16, ConvolveF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } - const float scaleFactor = 10.0; // scale the input to prevent grad underflow - auto in = - Variable(fl::rand({3, 1, 2, 1}, fl::dtype::f16) * scaleFactor, true); - auto wt = 
Variable(fl::rand({2, 1, 2, 1}, fl::dtype::f16), true); - auto bs = Variable(fl::rand({1, 1, 1, 1}, fl::dtype::f16), true); + const float scaleFactor = 10.0; // scale the input to prevent grad underflow + auto in = + Variable(fl::rand({3, 1, 2, 1}, fl::dtype::f16) * scaleFactor, true); + auto wt = Variable(fl::rand({2, 1, 2, 1}, fl::dtype::f16), true); + auto bs = Variable(fl::rand({1, 1, 1, 1}, fl::dtype::f16), true); - int px = 1, py = 1; - int sx = 1, sy = 1; - int dx = 1, dy = 1; - auto benchmarks = std::make_shared(); - auto funcConvIn = [&](Variable& input) { - return conv2d( - input, - wt, - bs, - sx, - sy, - px, - py, - dx, - dy, - /* groups */ 1, - benchmarks); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvIn, in, 5e-1, 0.1, {&wt, &bs})); + int px = 1, py = 1; + int sx = 1, sy = 1; + int dx = 1, dy = 1; + auto benchmarks = std::make_shared(); + auto funcConvIn = [&](Variable& input) { + return conv2d( + input, + wt, + bs, + sx, + sy, + px, + py, + dx, + dy, + /* groups */ 1, + benchmarks + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvIn, in, 5e-1, 0.1, {&wt, &bs})); - auto funcConvWt = [&](Variable& weight) { - return conv2d( - in, - weight, - bs, - sx, - sy, - px, - py, - dx, - dy, - /* groups */ 1, - benchmarks); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvWt, wt, 5e-2, 0.1, {&in, &bs})); + auto funcConvWt = [&](Variable& weight) { + return conv2d( + in, + weight, + bs, + sx, + sy, + px, + py, + dx, + dy, + /* groups */ 1, + benchmarks + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvWt, wt, 5e-2, 0.1, {&in, &bs})); - auto funcConvBs = [&](Variable& bias) { - return conv2d( - in, - wt, - bias, - sx, - sy, - px, - py, - dx, - dy, - /* groups */ 1, - benchmarks); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvBs, bs, 3e-2, 0.1, {&in, &wt})); + auto funcConvBs = [&](Variable& bias) { + return conv2d( + in, + wt, + bias, + sx, + sy, + px, + py, + dx, + dy, + /* groups */ 1, + benchmarks + ); + }; + 
ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvBs, bs, 3e-2, 0.1, {&in, &wt})); } TEST(AutogradConv2DTest, ConvolveFilterGroups) { - int channel = 8; - int groups = 2; - // w x h x c x b - auto in = Variable(fl::rand({10, 9, channel, 7}, fl::dtype::f32), true); - // w x h x in x out - auto wt = - Variable(fl::rand({4, 3, channel / groups, 6}, fl::dtype::f32), true); - auto bs = Variable(fl::rand({1, 1, 6, 1}, fl::dtype::f32), true); + int channel = 8; + int groups = 2; + // w x h x c x b + auto in = Variable(fl::rand({10, 9, channel, 7}, fl::dtype::f32), true); + // w x h x in x out + auto wt = + Variable(fl::rand({4, 3, channel / groups, 6}, fl::dtype::f32), true); + auto bs = Variable(fl::rand({1, 1, 6, 1}, fl::dtype::f32), true); - int px = 2, py = 1; - int sx = 1, sy = 1; - int dx = 1, dy = 1; - auto funcConvIn = [&](Variable& input) { - return conv2d(input, wt, bs, sx, sy, px, py, dx, dy, groups); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvIn, in, 0.06, 1E-4, {&wt, &bs})); - auto funcConvWt = [&](Variable& weight) { - return conv2d(in, weight, bs, sx, sy, px, py, dx, dy, groups); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvWt, wt, 0.05, 1E-4, {&in, &bs})); - auto foncConvBs = [&](Variable& bias) { - return conv2d(in, wt, bias, sx, sy, px, py, dx, dy, groups); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(foncConvBs, bs, 0.02, 1E-4, {&in, &wt})); + int px = 2, py = 1; + int sx = 1, sy = 1; + int dx = 1, dy = 1; + auto funcConvIn = [&](Variable& input) { + return conv2d(input, wt, bs, sx, sy, px, py, dx, dy, groups); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvIn, in, 0.06, 1E-4, {&wt, &bs})); + auto funcConvWt = [&](Variable& weight) { + return conv2d(in, weight, bs, sx, sy, px, py, dx, dy, groups); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvWt, wt, 0.05, 1E-4, {&in, &bs})); + auto foncConvBs = [&](Variable& bias) { + return conv2d(in, wt, bias, sx, sy, px, py, dx, dy, groups); + }; + 
ASSERT_TRUE(fl::detail::jacobianTestImpl(foncConvBs, bs, 0.02, 1E-4, {&in, &wt})); } TEST(AutogradConv2DTest, ConvolveDilation) { - auto in = Variable(fl::rand({10, 9, 8, 7}, fl::dtype::f32), true); - auto wt = Variable(fl::rand({4, 3, 8, 6}, fl::dtype::f32), true); - auto bs = Variable(fl::rand({1, 1, 6, 1}, fl::dtype::f32), true); - int px = 2, py = 1; - int sx = 1, sy = 1; - int dx = 2, dy = 1; - auto funcConvIn = [&](Variable& input) { - return conv2d( - input, - wt, - bs, - sx, - sy, - px, - py, - dx, - dy, - /* groups */ 1); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvIn, in, 0.06, 1E-4, {&wt, &bs})); - auto funcConvWt = [&](Variable& weight) { - return conv2d( - in, - weight, - bs, - sx, - sy, - px, - py, - dx, - dy, - /* groups */ 1); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvWt, wt, 0.05, 1E-4, {&in, &bs})); - auto funcConvBs = [&](Variable& bias) { - return conv2d( - in, - wt, - bias, - sx, - sy, - px, - py, - dx, - dy, - /* groups */ 1); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvBs, bs, 0.02, 1E-4, {&in, &wt})); + auto in = Variable(fl::rand({10, 9, 8, 7}, fl::dtype::f32), true); + auto wt = Variable(fl::rand({4, 3, 8, 6}, fl::dtype::f32), true); + auto bs = Variable(fl::rand({1, 1, 6, 1}, fl::dtype::f32), true); + int px = 2, py = 1; + int sx = 1, sy = 1; + int dx = 2, dy = 1; + auto funcConvIn = [&](Variable& input) { + return conv2d( + input, + wt, + bs, + sx, + sy, + px, + py, + dx, + dy, + /* groups */ 1 + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvIn, in, 0.06, 1E-4, {&wt, &bs})); + auto funcConvWt = [&](Variable& weight) { + return conv2d( + in, + weight, + bs, + sx, + sy, + px, + py, + dx, + dy, + /* groups */ 1 + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvWt, wt, 0.05, 1E-4, {&in, &bs})); + auto funcConvBs = [&](Variable& bias) { + return conv2d( + in, + wt, + bias, + sx, + sy, + px, + py, + dx, + dy, + /* groups */ 1 + ); + }; + 
ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConvBs, bs, 0.02, 1E-4, {&in, &wt})); } TEST(AutogradConv2DTest, WeightNormConv) { - auto v = Variable(fl::rand({3, 3, 3, 8}), true); - auto normDim = {0, 1, 2}; - auto g = Variable( - norm(v, normDim, /* p = */ 2, /* keepDims = */ true).tensor(), true); - auto in = Variable(fl::rand({7, 7, 3, 8}) * 2 - 2, true); + auto v = Variable(fl::rand({3, 3, 3, 8}), true); + auto normDim = {0, 1, 2}; + auto g = Variable( + norm(v, normDim, /* p = */ 2, /* keepDims = */ true).tensor(), + true + ); + auto in = Variable(fl::rand({7, 7, 3, 8}) * 2 - 2, true); - auto funcWeightNormIn = [&](Variable& input) { - auto w = v * - tileAs(g / norm(v, normDim, /* p = */ 2, /* keepDims = */ true), v); - return conv2d( - input, - w, - /* sx */ 1, - /* sy */ 1, - /* px */ 0, - /* py */ 0, - /* dx */ 1, - /* dy */ 1, - /* groups */ 1); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcWeightNormIn, in, 3E-1, 1E-4, {&v, &g})); + auto funcWeightNormIn = [&](Variable& input) { + auto w = v + * tileAs(g / norm(v, normDim, /* p = */ 2, /* keepDims = */ true), v); + return conv2d( + input, + w, + /* sx */ 1, + /* sy */ 1, + /* px */ 0, + /* py */ 0, + /* dx */ 1, + /* dy */ 1, + /* groups */ 1 + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcWeightNormIn, in, 3E-1, 1E-4, {&v, &g})); - auto funcWeightNormV = [&](Variable& input) { - auto w = input * - tileAs(g / norm(input, normDim, /* p = */ 2, /* keepDims = */ true), - input); - return conv2d( - in, - w, - /* sx */ 1, - /* sy */ 1, - /* px */ 0, - /* py */ 0, - /* dx */ 1, - /* dy */ 1, - /* groups */ 1); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcWeightNormV, v, 2E-1, 1E-4, {&g, &in})); + auto funcWeightNormV = [&](Variable& input) { + auto w = input + * tileAs( + g / norm(input, normDim, /* p = */ 2, /* keepDims = */ true), + input + ); + return conv2d( + in, + w, + /* sx */ 1, + /* sy */ 1, + /* px */ 0, + /* py */ 0, + /* dx */ 1, + /* dy */ 1, + /* groups */ 1 + ); + }; + 
ASSERT_TRUE(fl::detail::jacobianTestImpl(funcWeightNormV, v, 2E-1, 1E-4, {&g, &in})); - auto funcWeightNormG = [&](Variable& input) { - auto w = v * - tileAs(input / norm(v, normDim, /* p = */ 2, /* keepDims = */ true), - v); - return conv2d( - in, - w, - /* sx */ 1, - /* sy */ 1, - /* px */ 0, - /* py */ 0, - /* dx */ 1, - /* dy */ 1, - /* groups */ 1); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcWeightNormG, g, 2E-1, 1E-4, {&v, &in})); + auto funcWeightNormG = [&](Variable& input) { + auto w = v + * tileAs( + input / norm(v, normDim, /* p = */ 2, /* keepDims = */ true), + v + ); + return conv2d( + in, + w, + /* sx */ 1, + /* sy */ 1, + /* px */ 0, + /* py */ 0, + /* dx */ 1, + /* dy */ 1, + /* groups */ 1 + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcWeightNormG, g, 2E-1, 1E-4, {&v, &in})); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/autograd/AutogradNormalizationTest.cpp b/flashlight/fl/test/autograd/AutogradNormalizationTest.cpp index 7663b9b..51f69b7 100644 --- a/flashlight/fl/test/autograd/AutogradNormalizationTest.cpp +++ b/flashlight/fl/test/autograd/AutogradNormalizationTest.cpp @@ -22,397 +22,549 @@ using namespace fl; using fl::detail::AutogradTestF16; TEST(AutogradNormalizationTest, Normalize) { - auto x = Variable(fl::rand({5, 3}, fl::dtype::f64), true); - auto funcNormalize2 = [](Variable& in) { return normalize(in, {1}); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcNormalize2, x)); - auto ys = funcNormalize2(x); - ASSERT_TRUE(allClose( - fl::sum(ys.tensor() * ys.tensor(), {1}), - fl::full({5}, 1, fl::dtype::f64))); - auto yb = normalize(x, {1}, 2, 1); - ASSERT_TRUE(fl::all(fl::sqrt(fl::sum(yb.tensor() * yb.tensor(), {1})) <= 1) - .scalar()); + auto x = Variable(fl::rand({5, 3}, fl::dtype::f64), true); + auto funcNormalize2 = [](Variable& 
in) { return normalize(in, {1}); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcNormalize2, x)); + auto ys = funcNormalize2(x); + ASSERT_TRUE( + allClose( + fl::sum(ys.tensor() * ys.tensor(), {1}), + fl::full({5}, 1, fl::dtype::f64) + ) + ); + auto yb = normalize(x, {1}, 2, 1); + ASSERT_TRUE( + fl::all(fl::sqrt(fl::sum(yb.tensor() * yb.tensor(), {1})) <= 1) + .scalar() + ); } TEST(AutogradNormalizationTest, BatchNormEvalModeOutputSingleAxis) { - int featDims = 3; - std::vector featAxes = {2}; - // input order: HWCN, following the docs - auto input = Variable(fl::rand({13, 13, featDims, 16}), false); - auto runningMean = Variable(fl::rand({featDims}, input.type()), false); - auto runningVar = Variable(fl::rand({featDims}, input.type()), false); - auto weight = Variable(fl::rand({featDims}, input.type()), false); - auto bias = Variable(fl::rand({featDims}, input.type()), false); - - auto out = (batchnorm( - input, - weight, - bias, - runningMean, - runningVar, - featAxes, - false, - 0.0, - 1E-5)); - for (int i = 0; i < featDims; ++i) { - std::array sel = {fl::span, fl::span, i, fl::span}; - auto thisInput = input.tensor()(sel[0], sel[1], sel[2], sel[3]); - auto thisMean = runningMean.tensor().flatten()(i).scalar(); - auto thisVar = runningVar.tensor().flatten()(i).scalar(); - auto thisWeight = weight.tensor().flatten()(i).scalar(); - auto thisBias = bias.tensor().flatten()(i).scalar(); - - auto expectedOut = (thisInput - thisMean) / std::sqrt(thisVar + 1E-5); - expectedOut = expectedOut * thisWeight + thisBias; - ASSERT_TRUE(allClose( - out.tensor()(sel[0], sel[1], sel[2], sel[3]), expectedOut, 1E-5)); - } - - // test on empty weigts and bias - out = (batchnorm( - input, - Variable(), - Variable(), - runningMean, - runningVar, - featAxes, - false, - 0.0, - 1E-5)); - for (int i = 0; i < featDims; ++i) { - std::array sel = {fl::span, fl::span, i, fl::span}; - auto thisInput = input.tensor()(sel[0], sel[1], sel[2], sel[3]); - auto thisMean = 
runningMean.tensor().flatten()(i).scalar(); - auto thisVar = runningVar.tensor().flatten()(i).scalar(); - - auto expectedOut = (thisInput - thisMean) / std::sqrt(thisVar + 1E-5); - ASSERT_TRUE(allClose( - out.tensor()(sel[0], sel[1], sel[2], sel[3]), expectedOut, 1E-5)); - } + int featDims = 3; + std::vector featAxes = {2}; + // input order: HWCN, following the docs + auto input = Variable(fl::rand({13, 13, featDims, 16}), false); + auto runningMean = Variable(fl::rand({featDims}, input.type()), false); + auto runningVar = Variable(fl::rand({featDims}, input.type()), false); + auto weight = Variable(fl::rand({featDims}, input.type()), false); + auto bias = Variable(fl::rand({featDims}, input.type()), false); + + auto out = (batchnorm( + input, + weight, + bias, + runningMean, + runningVar, + featAxes, + false, + 0.0, + 1E-5 + )); + for(int i = 0; i < featDims; ++i) { + std::array sel = {fl::span, fl::span, i, fl::span}; + auto thisInput = input.tensor()(sel[0], sel[1], sel[2], sel[3]); + auto thisMean = runningMean.tensor().flatten()(i).scalar(); + auto thisVar = runningVar.tensor().flatten()(i).scalar(); + auto thisWeight = weight.tensor().flatten()(i).scalar(); + auto thisBias = bias.tensor().flatten()(i).scalar(); + + auto expectedOut = (thisInput - thisMean) / std::sqrt(thisVar + 1E-5); + expectedOut = expectedOut * thisWeight + thisBias; + ASSERT_TRUE(allClose( + out.tensor()(sel[0], sel[1], sel[2], sel[3]), expectedOut, 1E-5)); + } + + // test on empty weigts and bias + out = (batchnorm( + input, + Variable(), + Variable(), + runningMean, + runningVar, + featAxes, + false, + 0.0, + 1E-5 + )); + for(int i = 0; i < featDims; ++i) { + std::array sel = {fl::span, fl::span, i, fl::span}; + auto thisInput = input.tensor()(sel[0], sel[1], sel[2], sel[3]); + auto thisMean = runningMean.tensor().flatten()(i).scalar(); + auto thisVar = runningVar.tensor().flatten()(i).scalar(); + + auto expectedOut = (thisInput - thisMean) / std::sqrt(thisVar + 1E-5); + 
ASSERT_TRUE(allClose( + out.tensor()(sel[0], sel[1], sel[2], sel[3]), expectedOut, 1E-5)); + } } TEST(AutogradNormalizationTest, BatchNormEvalModeOutputMultipleAxis) { - // input order: HWCN, following the docs - std::vector featAxes = {0, 1, 2}; - auto input = Variable(fl::rand({13, 13, 4, 16}), false); - - auto nfeatures = 1; - for (auto ax : featAxes) { - nfeatures *= input.dim(ax); - } - auto runningMean = Variable(fl::rand({nfeatures}, input.type()), false); - auto runningVar = Variable(fl::rand({nfeatures}, input.type()), false); - auto weight = Variable(fl::rand({nfeatures}, input.type()), false); - auto bias = Variable(fl::rand({nfeatures}, input.type()), false); - - auto out = (batchnorm( - input, - weight, - bias, - runningMean, - runningVar, - featAxes, - false, - 0.0, - 1E-5)); - for (int i = 0; i < nfeatures; ++i) { - std::array sel = { - i % 13, (i / 13) % 13, (i / 13) / 13, fl::span}; - auto thisInput = input.tensor()(sel[0], sel[1], sel[2], sel[3]); - auto thisMean = runningMean.tensor().flatten()(i).scalar(); - auto thisVar = runningVar.tensor().flatten()(i).scalar(); - auto thisWeight = weight.tensor().flatten()(i).scalar(); - auto thisBias = bias.tensor().flatten()(i).scalar(); - - auto expectedOut = (thisInput - thisMean) / std::sqrt(thisVar + 1e-5); - expectedOut = expectedOut * thisWeight + thisBias; - - ASSERT_TRUE(allClose( - out.tensor()(sel[0], sel[1], sel[2], sel[3]), expectedOut, 1e-4)); - } - - // test on empty weigts and bias - out = (batchnorm( - input, - Variable(), - Variable(), - runningMean, - runningVar, - featAxes, - false, - 0.0, - 1E-5)); - for (int i = 0; i < nfeatures; ++i) { - std::array sel = { - i % 13, (i / 13) % 13, (i / 13) / 13, fl::span}; - auto thisInput = input.tensor()(sel[0], sel[1], sel[2], sel[3]); - auto thisMean = runningMean.tensor().flatten()(i).scalar(); - auto thisVar = runningVar.tensor().flatten()(i).scalar(); - - auto expectedOut = (thisInput - thisMean) / std::sqrt(thisVar + 1e-5); - 
ASSERT_TRUE(allClose( - out.tensor()(sel[0], sel[1], sel[2], sel[3]), expectedOut, 5e-5)); - } + // input order: HWCN, following the docs + std::vector featAxes = {0, 1, 2}; + auto input = Variable(fl::rand({13, 13, 4, 16}), false); + + auto nfeatures = 1; + for(auto ax : featAxes) { + nfeatures *= input.dim(ax); + } + auto runningMean = Variable(fl::rand({nfeatures}, input.type()), false); + auto runningVar = Variable(fl::rand({nfeatures}, input.type()), false); + auto weight = Variable(fl::rand({nfeatures}, input.type()), false); + auto bias = Variable(fl::rand({nfeatures}, input.type()), false); + + auto out = (batchnorm( + input, + weight, + bias, + runningMean, + runningVar, + featAxes, + false, + 0.0, + 1E-5 + )); + for(int i = 0; i < nfeatures; ++i) { + std::array sel = { + i % 13, (i / 13) % 13, (i / 13) / 13, fl::span}; + auto thisInput = input.tensor()(sel[0], sel[1], sel[2], sel[3]); + auto thisMean = runningMean.tensor().flatten()(i).scalar(); + auto thisVar = runningVar.tensor().flatten()(i).scalar(); + auto thisWeight = weight.tensor().flatten()(i).scalar(); + auto thisBias = bias.tensor().flatten()(i).scalar(); + + auto expectedOut = (thisInput - thisMean) / std::sqrt(thisVar + 1e-5); + expectedOut = expectedOut * thisWeight + thisBias; + + ASSERT_TRUE(allClose( + out.tensor()(sel[0], sel[1], sel[2], sel[3]), expectedOut, 1e-4)); + } + + // test on empty weigts and bias + out = (batchnorm( + input, + Variable(), + Variable(), + runningMean, + runningVar, + featAxes, + false, + 0.0, + 1E-5 + )); + for(int i = 0; i < nfeatures; ++i) { + std::array sel = { + i % 13, (i / 13) % 13, (i / 13) / 13, fl::span}; + auto thisInput = input.tensor()(sel[0], sel[1], sel[2], sel[3]); + auto thisMean = runningMean.tensor().flatten()(i).scalar(); + auto thisVar = runningVar.tensor().flatten()(i).scalar(); + + auto expectedOut = (thisInput - thisMean) / std::sqrt(thisVar + 1e-5); + ASSERT_TRUE(allClose( + out.tensor()(sel[0], sel[1], sel[2], sel[3]), expectedOut, 
5e-5)); + } } TEST(AutogradNormalizationTest, BatchNormTrainModeOutputSingleAxis) { - int numFeat = 3; - std::vector featAxes = {2}; - double epsilon = 1E-5; - auto input = Variable(fl::rand({13, 13, numFeat, 8}), true); - auto weight = Variable(fl::rand({numFeat}), true); - auto bias = Variable(fl::rand({numFeat}), true); - auto runningMean = Variable(fl::rand({numFeat}), false); - auto runningVar = Variable(fl::rand({numFeat}), false); - - auto out = batchnorm( - input, - weight, - bias, - runningMean, - runningVar, - featAxes, - true, - 0.0, - epsilon); - - auto todim = Shape({1, 1, numFeat}); - std::vector nrmAxes = {0, 1, 3}; - auto avg = moddims(mean(input, nrmAxes), todim); - auto variance = - moddims(var(input, nrmAxes, true /* population var */), todim); - auto expectedOut = (input - tileAs(avg, input)) / - fl::sqrt(tileAs(variance, input) + epsilon); - expectedOut = expectedOut * tileAs(moddims(weight, todim), input) + - tileAs(moddims(bias, todim), input); - ASSERT_TRUE(allClose(out.tensor(), expectedOut.tensor(), 1e-5)); + int numFeat = 3; + std::vector featAxes = {2}; + double epsilon = 1E-5; + auto input = Variable(fl::rand({13, 13, numFeat, 8}), true); + auto weight = Variable(fl::rand({numFeat}), true); + auto bias = Variable(fl::rand({numFeat}), true); + auto runningMean = Variable(fl::rand({numFeat}), false); + auto runningVar = Variable(fl::rand({numFeat}), false); + + auto out = batchnorm( + input, + weight, + bias, + runningMean, + runningVar, + featAxes, + true, + 0.0, + epsilon + ); + + auto todim = Shape({1, 1, numFeat}); + std::vector nrmAxes = {0, 1, 3}; + auto avg = moddims(mean(input, nrmAxes), todim); + auto variance = + moddims(var(input, nrmAxes, true /* population var */), todim); + auto expectedOut = (input - tileAs(avg, input)) + / fl::sqrt(tileAs(variance, input) + epsilon); + expectedOut = expectedOut * tileAs(moddims(weight, todim), input) + + tileAs(moddims(bias, todim), input); + ASSERT_TRUE(allClose(out.tensor(), 
expectedOut.tensor(), 1e-5)); } TEST(AutogradNormalizationTest, BatchNormTrainModeOutputMultipleAxis) { - std::vector featAxes = {0, 1, 2}; - auto input = Variable(fl::rand({13, 13, 4, 8}), true); - - auto nfeatures = 1; - for (auto ax : featAxes) { - nfeatures *= input.dim(ax); - } - auto weight = Variable(fl::rand({nfeatures}), true); - auto bias = Variable(fl::rand({nfeatures}), true); - auto runningMean = Variable(fl::rand({nfeatures}), false); - auto runningVar = Variable(fl::rand({nfeatures}), false); - - auto out = batchnorm( - input, weight, bias, runningMean, runningVar, featAxes, true, 0.0, 1E-5); - - auto todim = Shape({nfeatures}); - std::vector nrmAxes = {3}; - auto avg = moddims(mean(input, nrmAxes), todim); - auto variance = moddims(var(input, nrmAxes, true), todim); - - for (int i = 0; i < nfeatures; ++i) { - std::array sel = { - i % 13, (i / 13) % 13, (i / 13) / 13, fl::span}; - auto thisInput = input.tensor()(sel[0], sel[1], sel[2], sel[3]); - auto thisMean = avg.tensor().flatten()(i).scalar(); - auto thisVar = variance.tensor().flatten()(i).scalar(); - auto thisWeight = weight.tensor().flatten()(i).scalar(); - auto thisBias = bias.tensor().flatten()(i).scalar(); - - auto expectedOut = (thisInput - thisMean) / std::sqrt(thisVar + 1e-5); - expectedOut = expectedOut * thisWeight + thisBias; - ASSERT_TRUE(allClose( - out.tensor()(sel[0], sel[1], sel[2], sel[3]), expectedOut, 1e-5)); - } + std::vector featAxes = {0, 1, 2}; + auto input = Variable(fl::rand({13, 13, 4, 8}), true); + + auto nfeatures = 1; + for(auto ax : featAxes) { + nfeatures *= input.dim(ax); + } + auto weight = Variable(fl::rand({nfeatures}), true); + auto bias = Variable(fl::rand({nfeatures}), true); + auto runningMean = Variable(fl::rand({nfeatures}), false); + auto runningVar = Variable(fl::rand({nfeatures}), false); + + auto out = batchnorm( + input, + weight, + bias, + runningMean, + runningVar, + featAxes, + true, + 0.0, + 1E-5 + ); + + auto todim = Shape({nfeatures}); + 
std::vector nrmAxes = {3}; + auto avg = moddims(mean(input, nrmAxes), todim); + auto variance = moddims(var(input, nrmAxes, true), todim); + + for(int i = 0; i < nfeatures; ++i) { + std::array sel = { + i % 13, (i / 13) % 13, (i / 13) / 13, fl::span}; + auto thisInput = input.tensor()(sel[0], sel[1], sel[2], sel[3]); + auto thisMean = avg.tensor().flatten()(i).scalar(); + auto thisVar = variance.tensor().flatten()(i).scalar(); + auto thisWeight = weight.tensor().flatten()(i).scalar(); + auto thisBias = bias.tensor().flatten()(i).scalar(); + + auto expectedOut = (thisInput - thisMean) / std::sqrt(thisVar + 1e-5); + expectedOut = expectedOut * thisWeight + thisBias; + ASSERT_TRUE(allClose( + out.tensor()(sel[0], sel[1], sel[2], sel[3]), expectedOut, 1e-5)); + } } TEST(AutogradNormalizationTest, BatchNormJacobian) { - // Jacobian Test with trainMode = true; - - int numFeat = 3; - std::vector featAxes = {2}; - auto input = Variable(fl::rand({8, 8, numFeat, 16}, fl::dtype::f32), true); - auto runningMean = Variable(fl::rand({numFeat}, fl::dtype::f32), false); - auto runningVar = Variable(fl::rand({numFeat}, fl::dtype::f32), false); - auto weight = Variable(fl::rand({numFeat}, fl::dtype::f32), true); - auto bias = Variable(fl::rand({numFeat}, fl::dtype::f32), true); - - auto funcBnIn = [&](Variable& in) { - return (batchnorm( - in, weight, bias, runningMean, runningVar, featAxes, true, 0.0, 1E-5)); - }; - - - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnIn, input, 1e-2, 1e-4, {&weight, &bias})); - - auto funcBnWt = [&](Variable& wt) { - return (batchnorm( - input, wt, bias, runningMean, runningVar, featAxes, true, 0.0, 1E-5)); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnWt, weight, 1e-2, 1e-4, {&input, &bias})); - - - auto funcBnBs = [&](Variable& bs) { - return (batchnorm( - input, weight, bs, runningMean, runningVar, featAxes, true, 0.0, 1E-5)); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnBs, bias, 1e-2, 1e-4, {&input, &weight})); + // Jacobian 
Test with trainMode = true; + + int numFeat = 3; + std::vector featAxes = {2}; + auto input = Variable(fl::rand({8, 8, numFeat, 16}, fl::dtype::f32), true); + auto runningMean = Variable(fl::rand({numFeat}, fl::dtype::f32), false); + auto runningVar = Variable(fl::rand({numFeat}, fl::dtype::f32), false); + auto weight = Variable(fl::rand({numFeat}, fl::dtype::f32), true); + auto bias = Variable(fl::rand({numFeat}, fl::dtype::f32), true); + + auto funcBnIn = [&](Variable& in) { + return batchnorm( + in, + weight, + bias, + runningMean, + runningVar, + featAxes, + true, + 0.0, + 1E-5 + ); + }; + + + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnIn, input, 1e-2, 1e-4, {&weight, &bias})); + + auto funcBnWt = [&](Variable& wt) { + return batchnorm( + input, + wt, + bias, + runningMean, + runningVar, + featAxes, + true, + 0.0, + 1E-5 + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnWt, weight, 1e-2, 1e-4, {&input, &bias})); + + + auto funcBnBs = [&](Variable& bs) { + return batchnorm( + input, + weight, + bs, + runningMean, + runningVar, + featAxes, + true, + 0.0, + 1E-5 + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnBs, bias, 1e-2, 1e-4, {&input, &weight})); } TEST_F(AutogradTestF16, BatchNormJacobianF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - // Jacobian Test with trainMode = true; - - int numFeat = 3; - std::vector featAxes = {2}; - auto input = Variable(fl::rand({8, 8, numFeat, 16}, fl::dtype::f16), true); - auto runningMean = Variable(fl::rand({numFeat}, fl::dtype::f32), false); - auto runningVar = Variable(fl::rand({numFeat}, fl::dtype::f32), false); - auto weight = Variable(fl::rand({numFeat}, fl::dtype::f32), true); - auto bias = Variable(fl::rand({numFeat}, fl::dtype::f32), true); - - // Use larger perturbations to ensure gradients don't underflow with fp16 - - auto funcBnIn = [&](Variable& in) { - return (batchnorm( - in, weight, bias, runningMean, runningVar, featAxes, 
true, 0.0, 1E-5)); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnIn, input, 5e-2, 1e-1, {&weight, &bias})); - - auto funcBnWt = [&](Variable& wt) { - return (batchnorm( - input, wt, bias, runningMean, runningVar, featAxes, true, 0.0, 1E-5)); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnWt, weight, 5e-2, 1e-1, {&input, &bias})); - - auto funcBnBs = [&](Variable& bs) { - return (batchnorm( - input, weight, bs, runningMean, runningVar, featAxes, true, 0.0, 1E-5)); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnBs, bias, 5e-2, 1e-1, {&input, &weight})); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + + // Jacobian Test with trainMode = true; + + int numFeat = 3; + std::vector featAxes = {2}; + auto input = Variable(fl::rand({8, 8, numFeat, 16}, fl::dtype::f16), true); + auto runningMean = Variable(fl::rand({numFeat}, fl::dtype::f32), false); + auto runningVar = Variable(fl::rand({numFeat}, fl::dtype::f32), false); + auto weight = Variable(fl::rand({numFeat}, fl::dtype::f32), true); + auto bias = Variable(fl::rand({numFeat}, fl::dtype::f32), true); + + // Use larger perturbations to ensure gradients don't underflow with fp16 + + auto funcBnIn = [&](Variable& in) { + return batchnorm( + in, + weight, + bias, + runningMean, + runningVar, + featAxes, + true, + 0.0, + 1E-5 + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnIn, input, 5e-2, 1e-1, {&weight, &bias})); + + auto funcBnWt = [&](Variable& wt) { + return batchnorm( + input, + wt, + bias, + runningMean, + runningVar, + featAxes, + true, + 0.0, + 1E-5 + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnWt, weight, 5e-2, 1e-1, {&input, &bias})); + + auto funcBnBs = [&](Variable& bs) { + return batchnorm( + input, + weight, + bs, + runningMean, + runningVar, + featAxes, + true, + 0.0, + 1E-5 + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnBs, bias, 5e-2, 1e-1, {&input, &weight})); } TEST(AutogradNormalizationTest, 
BatchNormJacobianMultipleAxes) { - // Jacobian Test with trainMode = true; - std::vector featAxes = {0, 1, 2}; - auto input = Variable(fl::rand({4, 4, 3, 4}, fl::dtype::f32), true); - auto nfeatures = 1; - for (auto ax : featAxes) { - nfeatures *= input.dim(ax); - } - auto runningMean = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); - auto runningVar = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); - auto weight = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); - auto bias = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); - - auto funcBnIn = [&](Variable& in) { - return (batchnorm( - in, weight, bias, runningMean, runningVar, featAxes, true, 0.0, 1E-5)); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnIn, input, 1e-2, 1e-3, {&weight, &bias})); - - auto funcBnWt = [&](Variable& wt) { - return (batchnorm( - input, wt, bias, runningMean, runningVar, featAxes, true, 0.0, 1E-5)); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnWt, weight, 1e-2, 1e-3, {&input, &bias})); - - auto funcBnBs = [&](Variable& bs) { - return (batchnorm( - input, weight, bs, runningMean, runningVar, featAxes, true, 0.0, 1E-5)); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnBs, bias, 1e-2, 1e-3, {&input, &weight})); + // Jacobian Test with trainMode = true; + std::vector featAxes = {0, 1, 2}; + auto input = Variable(fl::rand({4, 4, 3, 4}, fl::dtype::f32), true); + auto nfeatures = 1; + for(auto ax : featAxes) { + nfeatures *= input.dim(ax); + } + auto runningMean = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); + auto runningVar = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); + auto weight = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); + auto bias = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); + + auto funcBnIn = [&](Variable& in) { + return batchnorm( + in, + weight, + bias, + runningMean, + runningVar, + featAxes, + true, + 0.0, + 1E-5 + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnIn, input, 
1e-2, 1e-3, {&weight, &bias})); + + auto funcBnWt = [&](Variable& wt) { + return batchnorm( + input, + wt, + bias, + runningMean, + runningVar, + featAxes, + true, + 0.0, + 1E-5 + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnWt, weight, 1e-2, 1e-3, {&input, &bias})); + + auto funcBnBs = [&](Variable& bs) { + return batchnorm( + input, + weight, + bs, + runningMean, + runningVar, + featAxes, + true, + 0.0, + 1E-5 + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnBs, bias, 1e-2, 1e-3, {&input, &weight})); } TEST_F(AutogradTestF16, BatchNormJacobianMultipleAxesF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - // Jacobian Test with trainMode = true; - std::vector featAxes = {0, 1, 2}; - auto input = Variable(fl::rand({2, 2, 2, 1}, fl::dtype::f16), true); - auto nfeatures = 1; - for (auto ax : featAxes) { - nfeatures *= input.dim(ax); - } - auto runningMean = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); - auto runningVar = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); - auto weight = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); - auto bias = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); - - // Use larger perturbations to ensure gradients don't underflow with fp16 - - auto funcBnIn = [&](Variable& in) { - return (batchnorm( - in, weight, bias, runningMean, runningVar, featAxes, true, 0.0, 1E-5)); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl( - funcBnIn, input, 5e-2, 1e-1, {&weight, &bias})); // TODO: investigate - - auto funcBnWt = [&](Variable& wt) { - return (batchnorm( - input, wt, bias, runningMean, runningVar, featAxes, true, 0.0, 1E-5)); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnWt, weight, 5e-2, 1e-1, {&input, &bias})); - - auto funcBnBs = [&](Variable& bs) { - return (batchnorm( - input, weight, bs, runningMean, runningVar, featAxes, true, 0.0, 1E-5)); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnBs, bias, 5e-2, 1e-1, 
{&input, &weight})); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + + // Jacobian Test with trainMode = true; + std::vector featAxes = {0, 1, 2}; + auto input = Variable(fl::rand({2, 2, 2, 1}, fl::dtype::f16), true); + auto nfeatures = 1; + for(auto ax : featAxes) { + nfeatures *= input.dim(ax); + } + auto runningMean = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); + auto runningVar = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); + auto weight = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); + auto bias = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); + + // Use larger perturbations to ensure gradients don't underflow with fp16 + + auto funcBnIn = [&](Variable& in) { + return batchnorm( + in, + weight, + bias, + runningMean, + runningVar, + featAxes, + true, + 0.0, + 1E-5 + ); + }; + ASSERT_TRUE( + fl::detail::jacobianTestImpl( + funcBnIn, + input, + 5e-2, + 1e-1, + {&weight, &bias} + ) + ); // TODO: investigate + + auto funcBnWt = [&](Variable& wt) { + return batchnorm( + input, + wt, + bias, + runningMean, + runningVar, + featAxes, + true, + 0.0, + 1E-5 + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnWt, weight, 5e-2, 1e-1, {&input, &bias})); + + auto funcBnBs = [&](Variable& bs) { + return batchnorm( + input, + weight, + bs, + runningMean, + runningVar, + featAxes, + true, + 0.0, + 1E-5 + ); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnBs, bias, 5e-2, 1e-1, {&input, &weight})); } TEST(AutogradNormalizationTest, LayerNormJacobian) { - std::vector featAxes = {0, 1, 2, 3}; - auto input = Variable(fl::rand({7, 7, 3, 10}), true); - auto nfeatures = 1; - for (auto ax : featAxes) { - nfeatures *= input.dim(ax); - } - auto runningMean = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); - auto runningVar = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); - auto weight = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); - auto bias = 
Variable(fl::rand({nfeatures}, fl::dtype::f32), true); - - auto funcLnIn = [&](Variable& in) { - return batchnorm( - in, weight, bias, runningMean, runningVar, featAxes, true, 0.0, 1E-5); - }; - - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLnIn, input, 1e-2, 1e-4, {&weight, &bias})); + std::vector featAxes = {0, 1, 2, 3}; + auto input = Variable(fl::rand({7, 7, 3, 10}), true); + auto nfeatures = 1; + for(auto ax : featAxes) { + nfeatures *= input.dim(ax); + } + auto runningMean = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); + auto runningVar = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); + auto weight = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); + auto bias = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); + + auto funcLnIn = [&](Variable& in) { + return batchnorm( + in, + weight, + bias, + runningMean, + runningVar, + featAxes, + true, + 0.0, + 1E-5 + ); + }; + + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLnIn, input, 1e-2, 1e-4, {&weight, &bias})); } TEST_F(AutogradTestF16, LayerNormJacobianF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - std::vector featAxes = {0, 1, 2, 3}; - const float inputScale = 4.0; // scale the input to prevent grad underflow - auto input = - Variable(inputScale * fl::rand({2, 2, 2, 4}, fl::dtype::f16), true); - auto nfeatures = 1; - for (auto ax : featAxes) { - nfeatures *= input.dim(ax); - } - auto runningMean = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); - auto runningVar = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); - auto weight = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); - auto bias = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); - - auto funcLnIn = [&](Variable& in) { - return batchnorm( - in, weight, bias, runningMean, runningVar, featAxes, true, 0.0, 1E-5); - }; - - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLnIn, input, 1e-4, 1e-2, {&weight, &bias})); + if(!fl::f16Supported()) { 
+ GTEST_SKIP() << "Half-precision not supported on this device"; + } + + std::vector featAxes = {0, 1, 2, 3}; + const float inputScale = 4.0; // scale the input to prevent grad underflow + auto input = + Variable(inputScale * fl::rand({2, 2, 2, 4}, fl::dtype::f16), true); + auto nfeatures = 1; + for(auto ax : featAxes) { + nfeatures *= input.dim(ax); + } + auto runningMean = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); + auto runningVar = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); + auto weight = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); + auto bias = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); + + auto funcLnIn = [&](Variable& in) { + return batchnorm( + in, + weight, + bias, + runningMean, + runningVar, + featAxes, + true, + 0.0, + 1E-5 + ); + }; + + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLnIn, input, 1e-4, 1e-2, {&weight, &bias})); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/autograd/AutogradReductionTest.cpp b/flashlight/fl/test/autograd/AutogradReductionTest.cpp index 423369e..c0d1690 100644 --- a/flashlight/fl/test/autograd/AutogradReductionTest.cpp +++ b/flashlight/fl/test/autograd/AutogradReductionTest.cpp @@ -18,145 +18,145 @@ using namespace fl; using fl::detail::AutogradTestF16; TEST(AutogradReductionTest, Sum) { - for (const bool keepDims : {false, true}) { - Shape s = {6}; - if (keepDims) { - s = {6, 1}; + for(const bool keepDims : {false, true}) { + Shape s = {6}; + if(keepDims) { + s = {6, 1}; + } + + auto x = Variable(fl::rand(s), true); + auto y = Variable(fl::rand({6, 3}), true); + + auto z = x * sum(y, {1}, keepDims); + auto dz = Variable(fl::full(s, 1.0), false); + z.backward(dz); + + auto dy = y.grad(); + auto dx = x.grad(); + ASSERT_TRUE(allClose(dy.tensor(), fl::tile(x.tensor(), {1, 3}))); + 
ASSERT_TRUE(allClose(dx.tensor(), fl::sum(y.tensor(), {1}, keepDims))); + + // Reduce over 1-dim input + auto funcMean_0 = [keepDims](const Variable& in) { + return sum(in, {0}, keepDims); + }; + auto in = Variable(fl::rand({6}), true); + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMean_0, in, 5E-3)); + // Reduce over scalar input + auto inScalar = Variable(fl::fromScalar(3.14), true); + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMean_0, inScalar, 5E-3)); } - auto x = Variable(fl::rand(s), true); - auto y = Variable(fl::rand({6, 3}), true); + auto r = Variable(fl::rand({5, 6, 7, 8}), true); + auto rOut = sum(r, {1, 2}); + auto rOutTensor = fl::sum(r.tensor(), {1, 2}); + ASSERT_TRUE(allClose(rOut.tensor(), rOutTensor)); +} - auto z = x * sum(y, {1}, keepDims); - auto dz = Variable(fl::full(s, 1.0), false); +TEST(AutogradReductionTest, SumAs) { + auto x = Variable(fl::rand({5}), true); + auto y = Variable(fl::rand({5, 2}), true); + auto z = x * sumAs(y, x); + auto dz = Variable(fl::full({5}, 1.0), false); z.backward(dz); - auto dy = y.grad(); auto dx = x.grad(); - ASSERT_TRUE(allClose(dy.tensor(), fl::tile(x.tensor(), {1, 3}))); - ASSERT_TRUE(allClose(dx.tensor(), fl::sum(y.tensor(), {1}, keepDims))); - - // Reduce over 1-dim input - auto funcMean_0 = [keepDims](const Variable& in) { - return sum(in, {0}, keepDims); - }; - auto in = Variable(fl::rand({6}), true); - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMean_0, in, 5E-3)); - // Reduce over scalar input - auto inScalar = Variable(fl::fromScalar(3.14), true); - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMean_0, inScalar, 5E-3)); - } - - auto r = Variable(fl::rand({5, 6, 7, 8}), true); - auto rOut = sum(r, {1, 2}); - auto rOutTensor = fl::sum(r.tensor(), {1, 2}); - ASSERT_TRUE(allClose(rOut.tensor(), rOutTensor)); -} - -TEST(AutogradReductionTest, SumAs) { - auto x = Variable(fl::rand({5}), true); - auto y = Variable(fl::rand({5, 2}), true); - auto z = x * sumAs(y, x); - auto dz = Variable(fl::full({5}, 1.0), 
false); - z.backward(dz); - auto dy = y.grad(); - auto dx = x.grad(); - ASSERT_TRUE(allClose(dy.tensor(), fl::tile(x.tensor(), {1, 2}))); - ASSERT_TRUE(allClose(dx.tensor(), fl::sum(y.tensor(), {1}))); + ASSERT_TRUE(allClose(dy.tensor(), fl::tile(x.tensor(), {1, 2}))); + ASSERT_TRUE(allClose(dx.tensor(), fl::sum(y.tensor(), {1}))); } TEST(AutogradReductionTest, SumAs2) { - auto y = Variable(fl::rand({5, 2}), true); - auto z = sumAs(y, {5}); - auto dz = Variable(fl::full({5}, 1.0), false); - z.backward(dz); - auto dy = y.grad(); - ASSERT_TRUE(allClose(dy.tensor(), fl::full({5, 2}, 1.0))); + auto y = Variable(fl::rand({5, 2}), true); + auto z = sumAs(y, {5}); + auto dz = Variable(fl::full({5}, 1.0), false); + z.backward(dz); + auto dy = y.grad(); + ASSERT_TRUE(allClose(dy.tensor(), fl::full({5, 2}, 1.0))); } TEST(AutogradReductionTest, Mean) { - for (const bool keepDims : {false, true}) { - Shape xShape = keepDims ? Shape({5, 1, 1}) : Shape({5}); - auto x = Variable(fl::rand(xShape), true); - auto y = Variable(fl::rand({5, 3, 2}), true); - auto varOut = mean(y, {1, 2}, keepDims); - auto z = x * mean(y, {1, 2}, keepDims); - auto dz = Variable(fl::full(x.shape(), 1.0), false); - z.backward(dz); - auto dy = y.grad(); - auto dx = x.grad(); - ASSERT_TRUE(allClose(dy.tensor(), fl::tile(x.tensor(), {1, 3, 2}) / 6)); - ASSERT_TRUE(allClose(dx.tensor(), fl::mean(y.tensor(), {1, 2}, keepDims))); - - auto a = Variable(fl::rand({5, 3, 2}, fl::dtype::f64), true); - auto funcMean = [keepDims](Variable& in) { - return mean(in, {1, 2}, keepDims); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMean, a, 1E-4)); - - auto q = Variable(fl::rand({5, 6, 7, 8}), false); - auto qOut = mean(q, {1, 2}, keepDims); - auto qOutTensor = fl::mean(q.tensor(), {1, 2}, keepDims); - ASSERT_TRUE(allClose(qOut.tensor(), qOutTensor)); - - auto funcMean_0 = [keepDims](Variable& in) { - return mean(in, {0}, keepDims); - }; - // Reduce over 1-dim input - auto in = Variable(fl::rand({6}), true); - 
ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMean_0, in, 5E-3)); - // Reduce over scalar input - auto inScalar = Variable(fl::fromScalar(3.14), true); - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMean_0, inScalar, 5E-3)); - } + for(const bool keepDims : {false, true}) { + Shape xShape = keepDims ? Shape({5, 1, 1}) : Shape({5}); + auto x = Variable(fl::rand(xShape), true); + auto y = Variable(fl::rand({5, 3, 2}), true); + auto varOut = mean(y, {1, 2}, keepDims); + auto z = x * mean(y, {1, 2}, keepDims); + auto dz = Variable(fl::full(x.shape(), 1.0), false); + z.backward(dz); + auto dy = y.grad(); + auto dx = x.grad(); + ASSERT_TRUE(allClose(dy.tensor(), fl::tile(x.tensor(), {1, 3, 2}) / 6)); + ASSERT_TRUE(allClose(dx.tensor(), fl::mean(y.tensor(), {1, 2}, keepDims))); + + auto a = Variable(fl::rand({5, 3, 2}, fl::dtype::f64), true); + auto funcMean = [keepDims](Variable& in) { + return mean(in, {1, 2}, keepDims); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMean, a, 1E-4)); + + auto q = Variable(fl::rand({5, 6, 7, 8}), false); + auto qOut = mean(q, {1, 2}, keepDims); + auto qOutTensor = fl::mean(q.tensor(), {1, 2}, keepDims); + ASSERT_TRUE(allClose(qOut.tensor(), qOutTensor)); + + auto funcMean_0 = [keepDims](Variable& in) { + return mean(in, {0}, keepDims); + }; + // Reduce over 1-dim input + auto in = Variable(fl::rand({6}), true); + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMean_0, in, 5E-3)); + // Reduce over scalar input + auto inScalar = Variable(fl::fromScalar(3.14), true); + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMean_0, inScalar, 5E-3)); + } } TEST(AutogradReductionTest, Variance) { - std::vector biased = {true, false}; - for (auto b : biased) { - for (const bool keepDims : {false, true}) { - auto x = Variable(fl::rand({5, 6, 7, 8}, fl::dtype::f64), true); - - // TODO:{fl::Tensor} -- enforce AF versioning and remediate - // Behavior of the bias parameter in af::var was changed in - // https://git.io/Jv5gF and is different in ArrayFire 
v3.7. If isbiased is - // true, sample variance rather than population variance is used. The - // flashlight API implements the opposite behavior to be consistent with - // other libraries. - bool afVarBiasArg = !b; - - auto expectedVar = fl::var(x.tensor(), {1}, afVarBiasArg, keepDims); - auto calculatedVar = var(x, {1}, b, keepDims); - ASSERT_TRUE(allClose(calculatedVar.tensor(), expectedVar)); - - auto funcVar = [b, keepDims](Variable& in) { - return var(in, {1, 2}, b, keepDims); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcVar, x, 1E-5, 1E-5)); + std::vector biased = {true, false}; + for(auto b : biased) { + for(const bool keepDims : {false, true}) { + auto x = Variable(fl::rand({5, 6, 7, 8}, fl::dtype::f64), true); + + // TODO:{fl::Tensor} -- enforce AF versioning and remediate + // Behavior of the bias parameter in af::var was changed in + // https://git.io/Jv5gF and is different in ArrayFire v3.7. If isbiased is + // true, sample variance rather than population variance is used. The + // flashlight API implements the opposite behavior to be consistent with + // other libraries. 
+ bool afVarBiasArg = !b; + + auto expectedVar = fl::var(x.tensor(), {1}, afVarBiasArg, keepDims); + auto calculatedVar = var(x, {1}, b, keepDims); + ASSERT_TRUE(allClose(calculatedVar.tensor(), expectedVar)); + + auto funcVar = [b, keepDims](Variable& in) { + return var(in, {1, 2}, b, keepDims); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcVar, x, 1E-5, 1E-5)); + } } - } } TEST(AutogradReductionTest, Norm) { - auto x = Variable(fl::rand({5, 3}, fl::dtype::f64), true); - for (const bool keepDims : {false, true}) { - auto funcNorm2 = [keepDims](Variable& in) { - return norm(in, {1}, 2, keepDims); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcNorm2, x, 1E-4)); - auto funcNorm1 = [keepDims](Variable& in) { - return norm(in, {1}, 1, keepDims); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcNorm1, x, 1E-4)); - auto funcNorm3 = [keepDims](Variable& in) { - return norm(in, {1}, 3, keepDims); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcNorm3, x, 1E-4)); - } + auto x = Variable(fl::rand({5, 3}, fl::dtype::f64), true); + for(const bool keepDims : {false, true}) { + auto funcNorm2 = [keepDims](Variable& in) { + return norm(in, {1}, 2, keepDims); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcNorm2, x, 1E-4)); + auto funcNorm1 = [keepDims](Variable& in) { + return norm(in, {1}, 1, keepDims); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcNorm1, x, 1E-4)); + auto funcNorm3 = [keepDims](Variable& in) { + return norm(in, {1}, 3, keepDims); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcNorm3, x, 1E-4)); + } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/autograd/AutogradRnnTest.cpp b/flashlight/fl/test/autograd/AutogradRnnTest.cpp index a0eb335..08337d2 100644 --- a/flashlight/fl/test/autograd/AutogradRnnTest.cpp +++ 
b/flashlight/fl/test/autograd/AutogradRnnTest.cpp @@ -35,178 +35,203 @@ void testRnnImpl(RnnMode mode, fl::dtype precision = fl::dtype::f64) { Variable(fl::rand({inputSize, batchSize, seqLength}, precision), true); size_t nParams; - switch (mode) { + switch(mode) { case RnnMode::TANH: - nParams = 56; - break; + nParams = 56; + break; case RnnMode::LSTM: - nParams = 224; - break; + nParams = 224; + break; case RnnMode::GRU: - nParams = 168; - break; + nParams = 168; + break; default: - throw std::invalid_argument("invalid RNN mode for the test"); + throw std::invalid_argument("invalid RNN mode for the test"); } auto w = Variable(fl::rand({static_cast(nParams)}, precision), true); auto funcRnnIn = [&](Variable& input) -> Variable { - return std::get<0>( - rnn(input, - Variable().astype(precision), - Variable().astype(precision), - w, - hiddenSize, - numLayers, - mode, - bidirectional, - 0.0)); - }; + return std::get<0>( + rnn( + input, + Variable().astype(precision), + Variable().astype(precision), + w, + hiddenSize, + numLayers, + mode, + bidirectional, + 0.0 + ) + ); + }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcRnnIn, in, expectedPrecision, perturbation, {&w})); auto funcRnnW = [&](Variable& weights) -> Variable { - return std::get<0>( - rnn(in, - Variable().astype(precision), - Variable().astype(precision), - weights, - hiddenSize, - numLayers, - mode, - bidirectional, - 0.0)); - }; + return std::get<0>( + rnn( + in, + Variable().astype(precision), + Variable().astype(precision), + weights, + hiddenSize, + numLayers, + mode, + bidirectional, + 0.0 + ) + ); + }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcRnnW, w, expectedPrecision, perturbation, {&in})); // We get the correct gradient for hx auto hx = Variable( fl::rand( {inputSize, batchSize, numLayers * (1 + bidirectional)}, - fl::dtype::f64), - true); + fl::dtype::f64 + ), + true + ); auto funcRnnHx = [&](Variable& hiddenState) -> Variable { - return std::get<0>( - rnn(in, - hiddenState.astype(precision), 
- Variable().astype(precision), - w, - hiddenSize, - numLayers, - mode, - bidirectional, - 0.0)); - }; + return std::get<0>( + rnn( + in, + hiddenState.astype(precision), + Variable().astype(precision), + w, + hiddenSize, + numLayers, + mode, + bidirectional, + 0.0 + ) + ); + }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcRnnHx, hx, expectedPrecision, perturbation, {&in, &w})); // We can compute the gradient w.r.t. hy auto funcRnnInDhy = [&](Variable& input) -> Variable { - return std::get<1>( - rnn(input, - Variable().astype(precision), - Variable().astype(precision), - w, - hiddenSize, - numLayers, - mode, - bidirectional, - 0.0)); - }; + return std::get<1>( + rnn( + input, + Variable().astype(precision), + Variable().astype(precision), + w, + hiddenSize, + numLayers, + mode, + bidirectional, + 0.0 + ) + ); + }; ASSERT_TRUE( - fl::detail::jacobianTestImpl(funcRnnInDhy, in, expectedPrecision, perturbation, {&w})); + fl::detail::jacobianTestImpl(funcRnnInDhy, in, expectedPrecision, perturbation, {&w}) + ); - if (mode == RnnMode::LSTM) { + if(mode == RnnMode::LSTM) { // We get the correct gradient for cx auto cx = Variable( fl::rand( {inputSize, batchSize, numLayers * (1 + bidirectional)}, - fl::dtype::f64), - true); + fl::dtype::f64 + ), + true + ); auto funcRnnCx = [&](Variable& cellState) -> Variable { - return std::get<0>( - rnn(in, - Variable().astype(precision), - cellState.astype(precision), - w, - hiddenSize, - numLayers, - mode, - bidirectional, - 0.0)); - }; + return std::get<0>( + rnn( + in, + Variable().astype(precision), + cellState.astype(precision), + w, + hiddenSize, + numLayers, + mode, + bidirectional, + 0.0 + ) + ); + }; ASSERT_TRUE( - fl::detail::jacobianTestImpl(funcRnnCx, cx, expectedPrecision, perturbation, {&in, &w})); + fl::detail::jacobianTestImpl(funcRnnCx, cx, expectedPrecision, perturbation, {&in, &w}) + ); // We can compute the gradient w.r.t. 
cy auto funcRnnInDcy = [&](Variable& input) -> Variable { - return std::get<2>( - rnn(input, - Variable().astype(precision), - Variable().astype(precision), - w, - hiddenSize, - numLayers, - mode, - bidirectional, - 0.0)); - }; + return std::get<2>( + rnn( + input, + Variable().astype(precision), + Variable().astype(precision), + w, + hiddenSize, + numLayers, + mode, + bidirectional, + 0.0 + ) + ); + }; ASSERT_TRUE( - fl::detail::jacobianTestImpl(funcRnnInDcy, in, expectedPrecision, perturbation, {&w})); + fl::detail::jacobianTestImpl(funcRnnInDcy, in, expectedPrecision, perturbation, {&w}) + ); } } } TEST(AutogradRnnTest, Rnn) { - if (FL_BACKEND_CPU) { - GTEST_SKIP() << "RNN gradient computation not yet supported on CPU"; - } + if(FL_BACKEND_CPU) { + GTEST_SKIP() << "RNN gradient computation not yet supported on CPU"; + } - testRnnImpl(RnnMode::TANH); + testRnnImpl(RnnMode::TANH); } TEST(AutogradRnnTest, Lstm) { - if (FL_BACKEND_CPU) { - GTEST_SKIP() << "RNN LSTM graident computation not yet supported on CPU"; - } + if(FL_BACKEND_CPU) { + GTEST_SKIP() << "RNN LSTM graident computation not yet supported on CPU"; + } - testRnnImpl(RnnMode::LSTM); + testRnnImpl(RnnMode::LSTM); } TEST(AutogradRnnTest, Gru) { - if (FL_BACKEND_CPU) { - GTEST_SKIP() << "RNN GRU graident computation not yet supported on CPU"; - } - testRnnImpl(RnnMode::GRU); + if(FL_BACKEND_CPU) { + GTEST_SKIP() << "RNN GRU graident computation not yet supported on CPU"; + } + testRnnImpl(RnnMode::GRU); } TEST_F(AutogradTestF16, RnnF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } - testRnnImpl(RnnMode::TANH, fl::dtype::f16); + testRnnImpl(RnnMode::TANH, fl::dtype::f16); } TEST_F(AutogradTestF16, LstmF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } + if(!fl::f16Supported()) { + GTEST_SKIP() << 
"Half-precision not supported on this device"; + } - testRnnImpl(RnnMode::LSTM, fl::dtype::f16); + testRnnImpl(RnnMode::LSTM, fl::dtype::f16); } TEST_F(AutogradTestF16, GruF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } - testRnnImpl(RnnMode::GRU, fl::dtype::f16); + testRnnImpl(RnnMode::GRU, fl::dtype::f16); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/autograd/AutogradTest.cpp b/flashlight/fl/test/autograd/AutogradTest.cpp index f506e32..46015a8 100644 --- a/flashlight/fl/test/autograd/AutogradTest.cpp +++ b/flashlight/fl/test/autograd/AutogradTest.cpp @@ -25,66 +25,151 @@ using namespace fl; using fl::detail::AutogradTestF16; TEST(AutogradTest, OperatorParenthesis) { - auto x = Variable(fl::rand({1, 3, 3}, fl::dtype::f64), true); - auto y = x(0, 0) + x(0, 1); - auto funcOperatorParen = [](Variable& in) { return in(0, 0) + in(0, 1); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcOperatorParen, x)); + auto x = Variable(fl::rand({1, 3, 3}, fl::dtype::f64), true); + auto y = x(0, 0) + x(0, 1); + auto funcOperatorParen = [](Variable& in) { return in(0, 0) + in(0, 1); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcOperatorParen, x)); } TEST(AutogradTest, AutogradOperatorTypeCompatibility) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - auto f16 = Variable(fl::rand({2, 2}, fl::dtype::f16), true); - auto f32 = Variable(fl::rand({2, 2}, fl::dtype::f32), true); - - // Binary operators - EXPECT_THROW({ auto res = f16 + f32; }, std::invalid_argument); // + - EXPECT_THROW({ auto res = f16 - f32; }, std::invalid_argument); // - - EXPECT_THROW({ auto res = f16 * f32; }, 
std::invalid_argument); // * - EXPECT_THROW({ auto res = f16 / f32; }, std::invalid_argument); // / - EXPECT_THROW({ auto res = f16 > f32; }, std::invalid_argument); // > - EXPECT_THROW({ auto res = f16 < f32; }, std::invalid_argument); // < - EXPECT_THROW({ auto res = f16 >= f32; }, std::invalid_argument); // >= - EXPECT_THROW({ auto res = f16 <= f32; }, std::invalid_argument); // <= - EXPECT_THROW({ auto res = f16 && f32; }, std::invalid_argument); // && - EXPECT_THROW({ max(f16, f32); }, std::invalid_argument); // max - EXPECT_THROW({ min(f16, f32); }, std::invalid_argument); // min - EXPECT_THROW({ matmul(f16, f32); }, std::invalid_argument); // matmul - EXPECT_THROW({ matmulTN(f16, f32); }, std::invalid_argument); // matmulTN - EXPECT_THROW({ matmulNT(f16, f32); }, std::invalid_argument); // matmulNT - EXPECT_NO_THROW({ binaryCrossEntropy(f16, f32); }); - EXPECT_NO_THROW({ - categoricalCrossEntropy( - Variable(fl::rand({7, 10, 4}, fl::dtype::f16), true), - Variable( - (fl::rand({10, 4}, fl::dtype::u32) % 7).astype(fl::dtype::s32), - false)); - }); - EXPECT_NO_THROW({ pool2d(f16, 1, 1, 1, 1, 1, 1); }); - EXPECT_NO_THROW({ embedding(f16, f32); }); // lookup is of a different type - // Ternary operators - auto f32_2 = Variable(fl::rand({2, 2}, fl::dtype::f32), true); - auto f16_2 = Variable(fl::rand({2, 2}, fl::dtype::f16), true); - EXPECT_THROW({ linear(f16, f32, f16_2); }, std::invalid_argument); // linear - EXPECT_THROW({ linear(f16, f32, f32_2); }, std::invalid_argument); // linear - auto w = Variable(fl::rand({1}, fl::dtype::f32), true); - auto b = Variable(fl::rand({1}, fl::dtype::f32), true); - EXPECT_THROW( - { batchnorm(f16, f32, f32_2, w, b, {1}, true, 0.01, 0.01); }, - std::invalid_argument); - EXPECT_THROW( - { batchnorm(f16, f32, f16_2, w, b, {1}, true, 0.01, 0.01); }, - std::invalid_argument); - EXPECT_THROW( - { conv2d(f16, f32, f16_2, 1, 1, 0, 0, 1, 1); }, std::invalid_argument); - // Quaternary - auto f16_3 = Variable(fl::rand({2, 2, 3}, 
fl::dtype::f16), false); - auto f16_4 = Variable(fl::rand({50}, fl::dtype::f16), false); - EXPECT_THROW( - { - rnn(f16_3, + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + + auto f16 = Variable(fl::rand({2, 2}, fl::dtype::f16), true); + auto f32 = Variable(fl::rand({2, 2}, fl::dtype::f32), true); + + // Binary operators + EXPECT_THROW( + {auto res = f16 + f32; + }, + std::invalid_argument + ); // + + EXPECT_THROW( + {auto res = f16 - f32; + }, + std::invalid_argument + ); // - + EXPECT_THROW( + {auto res = f16 * f32; + }, + std::invalid_argument + ); // * + EXPECT_THROW( + {auto res = f16 / f32; + }, + std::invalid_argument + ); /// + EXPECT_THROW( + {auto res = f16 > f32; + }, + std::invalid_argument + ); // > + EXPECT_THROW( + {auto res = f16 < f32; + }, + std::invalid_argument + ); // < + EXPECT_THROW( + {auto res = f16 >= f32; + }, + std::invalid_argument + ); // >= + EXPECT_THROW( + {auto res = f16 <= f32; + }, + std::invalid_argument + ); // <= + EXPECT_THROW( + {auto res = f16 && f32; + }, + std::invalid_argument + ); // && + EXPECT_THROW( + {max(f16, f32); + }, + std::invalid_argument + ); // max + EXPECT_THROW( + {min(f16, f32); + }, + std::invalid_argument + ); // min + EXPECT_THROW( + {matmul(f16, f32); + }, + std::invalid_argument + ); // matmul + EXPECT_THROW( + {matmulTN(f16, f32); + }, + std::invalid_argument + ); // matmulTN + EXPECT_THROW( + {matmulNT(f16, f32); + }, + std::invalid_argument + ); // matmulNT + EXPECT_NO_THROW( + {binaryCrossEntropy(f16, f32); + } + ); + EXPECT_NO_THROW( + { + categoricalCrossEntropy( + Variable(fl::rand({7, 10, 4}, fl::dtype::f16), true), + Variable( + (fl::rand({10, 4}, fl::dtype::u32) % 7).astype(fl::dtype::s32), + false + ) + ); + } + ); + EXPECT_NO_THROW( + {pool2d(f16, 1, 1, 1, 1, 1, 1); + } + ); + EXPECT_NO_THROW( + {embedding(f16, f32); + } + ); // lookup is of a different type + // Ternary operators + auto f32_2 = Variable(fl::rand({2, 2}, fl::dtype::f32), 
true); + auto f16_2 = Variable(fl::rand({2, 2}, fl::dtype::f16), true); + EXPECT_THROW( + {linear(f16, f32, f16_2); + }, + std::invalid_argument + ); // linear + EXPECT_THROW( + {linear(f16, f32, f32_2); + }, + std::invalid_argument + ); // linear + auto w = Variable(fl::rand({1}, fl::dtype::f32), true); + auto b = Variable(fl::rand({1}, fl::dtype::f32), true); + EXPECT_THROW( + {batchnorm(f16, f32, f32_2, w, b, {1}, true, 0.01, 0.01); + }, + std::invalid_argument + ); + EXPECT_THROW( + {batchnorm(f16, f32, f16_2, w, b, {1}, true, 0.01, 0.01); + }, + std::invalid_argument + ); + EXPECT_THROW( + {conv2d(f16, f32, f16_2, 1, 1, 0, 0, 1, 1); + }, + std::invalid_argument + ); + // Quaternary + auto f16_3 = Variable(fl::rand({2, 2, 3}, fl::dtype::f16), false); + auto f16_4 = Variable(fl::rand({50}, fl::dtype::f16), false); + EXPECT_THROW( + { + rnn( + f16_3, Variable(Tensor(fl::dtype::f32), false), Variable(Tensor(fl::dtype::f32), false), f16_4, @@ -92,345 +177,364 @@ TEST(AutogradTest, AutogradOperatorTypeCompatibility) { 2, RnnMode::LSTM, true, - 0.0); - }, - std::invalid_argument); - // Variadic operators - std::vector concatInputs = {f16, f32, f16_2, f32_2}; - EXPECT_THROW({ concatenate(concatInputs, 0); }, std::invalid_argument); + 0.0 + ); + }, + std::invalid_argument + ); + // Variadic operators + std::vector concatInputs = {f16, f32, f16_2, f32_2}; + EXPECT_THROW( + {concatenate(concatInputs, 0); + }, + std::invalid_argument + ); } TEST(AutogradTest, CastingAsDifferentGradTypes) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - auto f32 = Variable(fl::rand({5, 5}), true); - auto f16 = Variable(fl::rand({5, 5}, fl::dtype::f16), true); - // Computing gradients with mixed types fails when the op is applied - ASSERT_THROW({ f32 + f16; }, std::invalid_argument); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + + auto f32 = Variable(fl::rand({5, 5}), true); + auto 
f16 = Variable(fl::rand({5, 5}, fl::dtype::f16), true); + // Computing gradients with mixed types fails when the op is applied + ASSERT_THROW( + {f32 + f16; + }, + std::invalid_argument + ); } TEST(AutogradTest, CastingAs) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - auto var = Variable(fl::rand({5, 5}), true); - auto varF16 = var.astype(fl::dtype::f16); - ASSERT_EQ(var.type(), fl::dtype::f32); - ASSERT_EQ(varF16.type(), fl::dtype::f16); - ASSERT_TRUE(allClose(varF16.tensor(), var.astype(fl::dtype::f16).tensor())); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + + auto var = Variable(fl::rand({5, 5}), true); + auto varF16 = var.astype(fl::dtype::f16); + ASSERT_EQ(var.type(), fl::dtype::f32); + ASSERT_EQ(varF16.type(), fl::dtype::f16); + ASSERT_TRUE(allClose(varF16.tensor(), var.astype(fl::dtype::f16).tensor())); } TEST(AutogradTest, CastingAsBackward) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - auto a = Variable(fl::rand({4, 4}, fl::dtype::f16), true); - auto b = Variable(fl::rand({4, 4}, fl::dtype::f16), false); - auto c = b + a; - c.backward(); - ASSERT_EQ(a.grad().type(), fl::dtype::f16); - ASSERT_EQ(a.grad().type(), fl::dtype::f16); - a = a.astype(fl::dtype::f32); - ASSERT_FALSE(a.isGradAvailable()); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + + auto a = Variable(fl::rand({4, 4}, fl::dtype::f16), true); + auto b = Variable(fl::rand({4, 4}, fl::dtype::f16), false); + auto c = b + a; + c.backward(); + ASSERT_EQ(a.grad().type(), fl::dtype::f16); + ASSERT_EQ(a.grad().type(), fl::dtype::f16); + a = a.astype(fl::dtype::f32); + ASSERT_FALSE(a.isGradAvailable()); } TEST(AutogradTest, CastingAsGrad) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - // compare to f32 case - auto x = 
Variable(fl::full({5}, 2.0), true); - auto y = Variable(fl::full({5}, 3.0), true); - auto z = x * x + x * y + y * y; - auto dz = Variable(fl::full({5}, 1.0), false); - z.backward(dz); - auto dx = x.grad(); - auto dy = y.grad(); - - // f16 -- cast gradients in both directions - auto x32 = Variable(fl::full({5}, 2.0), true); - auto y32 = Variable(fl::full({5}, 3.0), true); - auto xf16 = x32.astype(fl::dtype::f16); - auto yf16 = y32.astype(fl::dtype::f16); - auto zf16 = xf16 * xf16 + xf16 * yf16 + yf16 * yf16; - auto zf32 = zf16.astype(fl::dtype::f32); - zf32.backward(dz); - - ASSERT_EQ(xf16.grad().type(), fl::dtype::f16); - ASSERT_EQ(yf16.grad().type(), fl::dtype::f16); - ASSERT_EQ(zf16.grad().type(), fl::dtype::f16); - ASSERT_EQ(x32.grad().type(), fl::dtype::f32); - ASSERT_EQ(y32.grad().type(), fl::dtype::f32); - ASSERT_TRUE( - allClose(dx.tensor(), xf16.grad().tensor().astype(fl::dtype::f32))); - ASSERT_TRUE( - allClose(dy.tensor(), y32.grad().tensor().astype(fl::dtype::f32))); - ASSERT_TRUE(allClose(dx.tensor(), x32.grad().tensor())); - ASSERT_TRUE(allClose(dy.tensor(), y32.grad().tensor())); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + + // compare to f32 case + auto x = Variable(fl::full({5}, 2.0), true); + auto y = Variable(fl::full({5}, 3.0), true); + auto z = x * x + x * y + y * y; + auto dz = Variable(fl::full({5}, 1.0), false); + z.backward(dz); + auto dx = x.grad(); + auto dy = y.grad(); + + // f16 -- cast gradients in both directions + auto x32 = Variable(fl::full({5}, 2.0), true); + auto y32 = Variable(fl::full({5}, 3.0), true); + auto xf16 = x32.astype(fl::dtype::f16); + auto yf16 = y32.astype(fl::dtype::f16); + auto zf16 = xf16 * xf16 + xf16 * yf16 + yf16 * yf16; + auto zf32 = zf16.astype(fl::dtype::f32); + zf32.backward(dz); + + ASSERT_EQ(xf16.grad().type(), fl::dtype::f16); + ASSERT_EQ(yf16.grad().type(), fl::dtype::f16); + ASSERT_EQ(zf16.grad().type(), fl::dtype::f16); + 
ASSERT_EQ(x32.grad().type(), fl::dtype::f32); + ASSERT_EQ(y32.grad().type(), fl::dtype::f32); + ASSERT_TRUE( + allClose(dx.tensor(), xf16.grad().tensor().astype(fl::dtype::f32)) + ); + ASSERT_TRUE( + allClose(dy.tensor(), y32.grad().tensor().astype(fl::dtype::f32)) + ); + ASSERT_TRUE(allClose(dx.tensor(), x32.grad().tensor())); + ASSERT_TRUE(allClose(dy.tensor(), y32.grad().tensor())); } TEST(AutogradTest, NoCalcGrad) { - auto x = Variable(fl::rand({5}), false); - auto y = Variable(fl::rand({5}), true); - auto z = x * x + x * y + y * y; - auto dz = Variable(fl::full({5}, 1.0), false); - z.backward(dz); - auto dy = y.grad(); - ASSERT_TRUE(allClose(dy.tensor(), 2 * y.tensor() + x.tensor())); - ASSERT_THROW(x.grad(), std::logic_error); + auto x = Variable(fl::rand({5}), false); + auto y = Variable(fl::rand({5}), true); + auto z = x * x + x * y + y * y; + auto dz = Variable(fl::full({5}, 1.0), false); + z.backward(dz); + auto dy = y.grad(); + ASSERT_TRUE(allClose(dy.tensor(), 2 * y.tensor() + x.tensor())); + ASSERT_THROW(x.grad(), std::logic_error); } TEST(AutogradTest, Concatenate) { - auto x1 = Variable(fl::rand({2, 3, 1, 2}, fl::dtype::f64), true); - auto x2 = Variable(fl::rand({2, 3, 3, 2}, fl::dtype::f64), true); - auto x3 = Variable(fl::rand({2, 3, 1, 2}, fl::dtype::f64), true); - auto x4 = Variable(fl::rand({2, 3, 7, 2}, fl::dtype::f64), true); - std::vector inputs = {x1, x2, x3, x4}; - auto output = concatenate(inputs, 2); - - ASSERT_EQ(output.shape(), Shape({2, 3, 12, 2})); - - auto funcConcatenateT1 = [x2, x3, x4](Variable& in) { - return concatenate({in, x2, x3, x4}, 2); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConcatenateT1, x1, 1E-5, 1E-4, {&x2, &x3, &x4})); - - auto funcConcatenateT2 = [x1, x2, x4](Variable& in) { - return concatenate({x1, x2, in, x4}, 2); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConcatenateT2, x3, 1E-5, 1E-4, {&x1, &x2, &x4})); + auto x1 = Variable(fl::rand({2, 3, 1, 2}, fl::dtype::f64), true); + auto x2 = 
Variable(fl::rand({2, 3, 3, 2}, fl::dtype::f64), true); + auto x3 = Variable(fl::rand({2, 3, 1, 2}, fl::dtype::f64), true); + auto x4 = Variable(fl::rand({2, 3, 7, 2}, fl::dtype::f64), true); + std::vector inputs = {x1, x2, x3, x4}; + auto output = concatenate(inputs, 2); + + ASSERT_EQ(output.shape(), Shape({2, 3, 12, 2})); + + auto funcConcatenateT1 = [x2, x3, x4](Variable& in) { + return concatenate({in, x2, x3, x4}, 2); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConcatenateT1, x1, 1E-5, 1E-4, {&x2, &x3, &x4})); + + auto funcConcatenateT2 = [x1, x2, x4](Variable& in) { + return concatenate({x1, x2, in, x4}, 2); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcConcatenateT2, x3, 1E-5, 1E-4, {&x1, &x2, &x4})); } TEST(AutogradTest, Split) { - // check output - auto x = Variable(fl::arange({7, 2}), true); - auto yVec = split(x, 1, 0); - ASSERT_EQ(yVec.size(), 7); - ASSERT_EQ(yVec[0].shape(), Shape({1, 2})); - ASSERT_EQ(yVec[2].shape(), Shape({1, 2})); - ASSERT_TRUE(fl::all(yVec[6].tensor() == 6).scalar()); - - auto a = Variable(fl::arange({5, 3}, 1), true); - auto bVec = split(a, {2, 1}, 1); - ASSERT_EQ(bVec.size(), 2); - ASSERT_EQ(bVec[0].shape(), Shape({5, 2})); - ASSERT_EQ(bVec[1].shape(), Shape({5, 1})); - ASSERT_TRUE( - fl::all(bVec[0].tensor() == fl::arange({5, 2}, 1)).scalar()); - ASSERT_TRUE(fl::all(bVec[1].tensor() == 2).scalar()); - - // check exception handling - ASSERT_THROW(split(a, {2, 2}, 0), std::invalid_argument); - - // check gradient - auto gradFunc = [](Variable& in) { return split(in, 2, 1)[0]; }; - auto input = Variable(fl::rand({2, 3}, fl::dtype::f64), true); - ASSERT_TRUE(fl::detail::jacobianTestImpl(gradFunc, input)); + // check output + auto x = Variable(fl::arange({7, 2}), true); + auto yVec = split(x, 1, 0); + ASSERT_EQ(yVec.size(), 7); + ASSERT_EQ(yVec[0].shape(), Shape({1, 2})); + ASSERT_EQ(yVec[2].shape(), Shape({1, 2})); + ASSERT_TRUE(fl::all(yVec[6].tensor() == 6).scalar()); + + auto a = Variable(fl::arange({5, 3}, 1), 
true); + auto bVec = split(a, {2, 1}, 1); + ASSERT_EQ(bVec.size(), 2); + ASSERT_EQ(bVec[0].shape(), Shape({5, 2})); + ASSERT_EQ(bVec[1].shape(), Shape({5, 1})); + ASSERT_TRUE( + fl::all(bVec[0].tensor() == fl::arange({5, 2}, 1)).scalar() + ); + ASSERT_TRUE(fl::all(bVec[1].tensor() == 2).scalar()); + + // check exception handling + ASSERT_THROW(split(a, {2, 2}, 0), std::invalid_argument); + + // check gradient + auto gradFunc = [](Variable& in) { return split(in, 2, 1)[0]; }; + auto input = Variable(fl::rand({2, 3}, fl::dtype::f64), true); + ASSERT_TRUE(fl::detail::jacobianTestImpl(gradFunc, input)); } TEST(AutogradTest, Tile) { - auto x = Variable(fl::rand({6}), true); - auto y = Variable(fl::rand({6, 3}), true); - auto z = y * tile(x, {1, 3}); - auto dz = Variable(fl::full({6, 3}, 1.0), false); - z.backward(dz); - auto dy = y.grad(); - auto dx = x.grad(); - ASSERT_TRUE(allClose(dy.tensor(), fl::tile(x.tensor(), {1, 3}))); - ASSERT_TRUE(allClose(dx.tensor(), fl::sum(y.tensor(), {1}))); - - // Jacobian - auto input = Variable(fl::rand({10, 1, 5}), true); - auto funcTile = [](Variable& in) { return tile(in, {1, 2}); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcTile, input, 1E-4, 1E-3)); + auto x = Variable(fl::rand({6}), true); + auto y = Variable(fl::rand({6, 3}), true); + auto z = y * tile(x, {1, 3}); + auto dz = Variable(fl::full({6, 3}, 1.0), false); + z.backward(dz); + auto dy = y.grad(); + auto dx = x.grad(); + ASSERT_TRUE(allClose(dy.tensor(), fl::tile(x.tensor(), {1, 3}))); + ASSERT_TRUE(allClose(dx.tensor(), fl::sum(y.tensor(), {1}))); + + // Jacobian + auto input = Variable(fl::rand({10, 1, 5}), true); + auto funcTile = [](Variable& in) { return tile(in, {1, 2}); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcTile, input, 1E-4, 1E-3)); } TEST(AutogradTest, TileAs) { - auto x = Variable(fl::rand({5}), true); - auto y = Variable(fl::rand({5, 2}), true); - auto z = y * tileAs(x, y); - auto dz = Variable(fl::full({5, 2}, 1.0), false); - z.backward(dz); 
- auto dy = y.grad(); - auto dx = x.grad(); - ASSERT_TRUE(allClose(dy.tensor(), fl::tile(x.tensor(), {1, 2}))); - ASSERT_TRUE(allClose(dx.tensor(), fl::sum(y.tensor(), {1}))); + auto x = Variable(fl::rand({5}), true); + auto y = Variable(fl::rand({5, 2}), true); + auto z = y * tileAs(x, y); + auto dz = Variable(fl::full({5, 2}, 1.0), false); + z.backward(dz); + auto dy = y.grad(); + auto dx = x.grad(); + ASSERT_TRUE(allClose(dy.tensor(), fl::tile(x.tensor(), {1, 2}))); + ASSERT_TRUE(allClose(dx.tensor(), fl::sum(y.tensor(), {1}))); } TEST_F(AutogradTestF16, TileAsF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - auto x = Variable(fl::rand({5}, fl::dtype::f16), true); - auto y = Variable(fl::rand({5, 2}, fl::dtype::f16), true); - auto z = y * tileAs(x, y); - ASSERT_EQ(x.type(), z.type()); - auto dz = Variable(fl::full({5, 2}, 1.0, fl::dtype::f16), false); - z.backward(dz); - auto dy = y.grad(); - auto dx = x.grad(); - ASSERT_TRUE(allClose( - dy.tensor(), fl::tile(x.tensor(), {1, 2}).astype(dx.type()), 1e-2)); - ASSERT_TRUE( - allClose(dx.tensor(), fl::sum(y.tensor(), {1}).astype(dx.type()), 1e-2)); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + + auto x = Variable(fl::rand({5}, fl::dtype::f16), true); + auto y = Variable(fl::rand({5, 2}, fl::dtype::f16), true); + auto z = y * tileAs(x, y); + ASSERT_EQ(x.type(), z.type()); + auto dz = Variable(fl::full({5, 2}, 1.0, fl::dtype::f16), false); + z.backward(dz); + auto dy = y.grad(); + auto dx = x.grad(); + ASSERT_TRUE( + allClose( + dy.tensor(), + fl::tile(x.tensor(), {1, 2}).astype(dx.type()), + 1e-2 + ) + ); + ASSERT_TRUE( + allClose(dx.tensor(), fl::sum(y.tensor(), {1}).astype(dx.type()), 1e-2) + ); } TEST(AutogradTest, TileAs2) { - auto x = Variable(fl::rand({10}), true); - auto z = tileAs(x, Shape({10, 3})); - auto dz = Variable(fl::full({10, 3}, 1.0), false); - z.backward(dz); - auto dx = x.grad(); - 
ASSERT_TRUE(allClose(dx.tensor(), fl::full(x.shape(), 3.0))); + auto x = Variable(fl::rand({10}), true); + auto z = tileAs(x, Shape({10, 3})); + auto dz = Variable(fl::full({10, 3}, 1.0), false); + z.backward(dz); + auto dx = x.grad(); + ASSERT_TRUE(allClose(dx.tensor(), fl::full(x.shape(), 3.0))); } TEST(AutogradTest, Indexing) { - auto x = Variable(fl::rand({5, 6, 7, 4}, fl::dtype::f64), true); - - auto funcCol = [](Variable& input) { return input(fl::span, 4); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcCol, x)); - - auto funcRow = [](Variable& input) { return input(4); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcRow, x)); - - auto funcSlice = [](Variable& input) { - return input(fl::span, fl::span, 4); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcSlice, x)); - - auto funcCols = [](Variable& input) { - return input(fl::span, fl::range(2, 5)); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcCols, x)); - - auto funcRows = [](Variable& input) { return input(fl::range(2, 5)); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcRows, x)); - - auto funcSlices = [](Variable& input) { - return input(fl::span, fl::span, fl::range(2, 5)); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcSlices, x)); - auto funcFlat = [](Variable& input) { - return input.flat(fl::range(4, 100)); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcFlat, x)); + auto x = Variable(fl::rand({5, 6, 7, 4}, fl::dtype::f64), true); + + auto funcCol = [](Variable& input) { return input(fl::span, 4); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcCol, x)); + + auto funcRow = [](Variable& input) { return input(4); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcRow, x)); + + auto funcSlice = [](Variable& input) { + return input(fl::span, fl::span, 4); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcSlice, x)); + + auto funcCols = [](Variable& input) { + return input(fl::span, fl::range(2, 5)); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcCols, x)); + + auto 
funcRows = [](Variable& input) { return input(fl::range(2, 5)); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcRows, x)); + + auto funcSlices = [](Variable& input) { + return input(fl::span, fl::span, fl::range(2, 5)); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcSlices, x)); + auto funcFlat = [](Variable& input) { + return input.flat(fl::range(4, 100)); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcFlat, x)); } TEST(AutogradTest, Padding) { - auto in = Variable(fl::rand({3, 3}, fl::dtype::f32), true); - auto funcPad = [&](Variable& input) { - return padding(input, {{1, 2}, {0, 1}}, -1); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcPad, in, 1E-3)); + auto in = Variable(fl::rand({3, 3}, fl::dtype::f32), true); + auto funcPad = [&](Variable& input) { + return padding(input, {{1, 2}, {0, 1}}, -1); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcPad, in, 1E-3)); } TEST(AutogradTest, Pooling) { - auto in = Variable(fl::rand({3, 3, 1, 1}, fl::dtype::f32), true); - auto funcPool = [&](Variable& input) { return pool2d(input, 2, 2, 1, 1); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcPool, in, 1E-3)); + auto in = Variable(fl::rand({3, 3, 1, 1}, fl::dtype::f32), true); + auto funcPool = [&](Variable& input) { return pool2d(input, 2, 2, 1, 1); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcPool, in, 1E-3)); } TEST_F(AutogradTestF16, PoolingF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - const float inputScale = 2.0; // scale the input to prevent grad underflow - auto in = Variable(inputScale * fl::rand({3, 3, 1, 1}, fl::dtype::f16), true); - auto funcPool = [&](Variable& input) { return pool2d(input, 2, 2, 1, 1); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcPool, in, 1e1, 1e-1)); // TODO: investigate + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + + const float inputScale = 2.0; // scale the input to prevent grad underflow + 
auto in = Variable(inputScale * fl::rand({3, 3, 1, 1}, fl::dtype::f16), true); + auto funcPool = [&](Variable& input) { return pool2d(input, 2, 2, 1, 1); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcPool, in, 1e1, 1e-1)); // TODO: investigate } TEST(AutogradTest, Reorder) { - auto in = Variable(fl::rand({3, 1, 4, 1}, fl::dtype::f32) * 2, true); - auto funcReorder = [&](Variable& input) { - return reorder(input, {2, 0, 3, 1}); - }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcReorder, in, 1E-3)); + auto in = Variable(fl::rand({3, 1, 4, 1}, fl::dtype::f32) * 2, true); + auto funcReorder = [&](Variable& input) { + return reorder(input, {2, 0, 3, 1}); + }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcReorder, in, 1E-3)); } TEST(AutogradTest, Embedding) { - int nWords = 10; - auto input = - Variable((fl::rand({4, 2}) * nWords).astype(fl::dtype::f32), false); - auto weights = Variable(fl::randn({4, nWords}, fl::dtype::f64), true); - auto funcEmbed = [&](Variable& w) { return embedding(input, w); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcEmbed, weights, 1E-5)); + int nWords = 10; + auto input = + Variable((fl::rand({4, 2}) * nWords).astype(fl::dtype::f32), false); + auto weights = Variable(fl::randn({4, nWords}, fl::dtype::f64), true); + auto funcEmbed = [&](Variable& w) { return embedding(input, w); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcEmbed, weights, 1E-5)); } TEST(AutogradTest, GetAdvancedIndex) { - // TODO: remove me - if (!FL_BACKEND_CUDA) { - GTEST_SKIP() + // TODO: remove me + if(!FL_BACKEND_CUDA) { + GTEST_SKIP() << "Advanced indexing operator unsupported for non-CUDA backends"; - } - std::vector validIndexTypes = { - fl::dtype::s32, fl::dtype::s64, fl::dtype::u32, fl::dtype::u64}; - for (const auto& dtype : validIndexTypes) { - auto x = Variable(fl::rand({20, 50, 40, 30}, fl::dtype::f32), true); - Tensor a({6}, dtype); - a(0) = 0; - a(1) = 15; - a(2) = 6; - a(3) = 1; - a(4) = 10; - a(5) = 6; - Tensor b({3}, dtype); - b(0) = 5; - b(1) 
= 11; - b(2) = 19; - auto x2 = x(a, b, fl::span, fl::range(0, 4)); - auto y = sum(x2 * x2, {0, 1, 2, 3}, /* keepDims = */ true); - auto res = 2 * sum(x2, {0, 1, 2, 3}, /* keepDims = */ true).tensor(); - y.backward(); - auto grad = sum(x.grad(), {0, 1, 2, 3}, /* keepDims = */ true).tensor(); - ASSERT_TRUE(allClose(grad, res, 1e-3)); - } + } + std::vector validIndexTypes = { + fl::dtype::s32, fl::dtype::s64, fl::dtype::u32, fl::dtype::u64}; + for(const auto& dtype : validIndexTypes) { + auto x = Variable(fl::rand({20, 50, 40, 30}, fl::dtype::f32), true); + Tensor a({6}, dtype); + a(0) = 0; + a(1) = 15; + a(2) = 6; + a(3) = 1; + a(4) = 10; + a(5) = 6; + Tensor b({3}, dtype); + b(0) = 5; + b(1) = 11; + b(2) = 19; + auto x2 = x(a, b, fl::span, fl::range(0, 4)); + auto y = sum(x2 * x2, {0, 1, 2, 3}, /* keepDims = */ true); + auto res = 2 * sum(x2, {0, 1, 2, 3}, /* keepDims = */ true).tensor(); + y.backward(); + auto grad = sum(x.grad(), {0, 1, 2, 3}, /* keepDims = */ true).tensor(); + ASSERT_TRUE(allClose(grad, res, 1e-3)); + } } TEST(AutogradTest, GetAdvancedIndexF16) { - // TODO: remove me - if (!FL_BACKEND_CUDA) { - GTEST_SKIP() + // TODO: remove me + if(!FL_BACKEND_CUDA) { + GTEST_SKIP() << "Advanced indexing operator unsupported for non-CUDA backends"; - } - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - std::vector validIndexTypes = { - fl::dtype::s32, fl::dtype::s64, fl::dtype::u32, fl::dtype::u64}; - for (const auto& dtype : validIndexTypes) { - auto x = Variable(fl::rand({20, 50, 40, 30}, fl::dtype::f16), true); - Tensor a({6}, dtype); - a(0) = 0; - a(1) = 15; - a(2) = 6; - a(3) = 1; - a(4) = 10; - a(5) = 6; - Tensor b({3}, dtype); - b(0) = 5; - b(1) = 11; - b(2) = 19; - auto x2 = x(a, b, fl::span, fl::range(0, 4)); - ASSERT_EQ(x2.type(), fl::dtype::f16); - auto y = sum(x2 * x2, {0, 1, 2, 3}, /* keepDims = */ true); - auto res = 2 * sum(x2, {0, 1, 2, 3}, /* keepDims = */ true).tensor(); - y.backward(); - 
ASSERT_EQ(x.grad().type(), fl::dtype::f16); - auto grad = sum(x.grad(), {0, 1, 2, 3}, /* keepDims = */ true).tensor(); - ASSERT_TRUE(allClose(grad, res, 1e-3)); - } + } + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + std::vector validIndexTypes = { + fl::dtype::s32, fl::dtype::s64, fl::dtype::u32, fl::dtype::u64}; + for(const auto& dtype : validIndexTypes) { + auto x = Variable(fl::rand({20, 50, 40, 30}, fl::dtype::f16), true); + Tensor a({6}, dtype); + a(0) = 0; + a(1) = 15; + a(2) = 6; + a(3) = 1; + a(4) = 10; + a(5) = 6; + Tensor b({3}, dtype); + b(0) = 5; + b(1) = 11; + b(2) = 19; + auto x2 = x(a, b, fl::span, fl::range(0, 4)); + ASSERT_EQ(x2.type(), fl::dtype::f16); + auto y = sum(x2 * x2, {0, 1, 2, 3}, /* keepDims = */ true); + auto res = 2 * sum(x2, {0, 1, 2, 3}, /* keepDims = */ true).tensor(); + y.backward(); + ASSERT_EQ(x.grad().type(), fl::dtype::f16); + auto grad = sum(x.grad(), {0, 1, 2, 3}, /* keepDims = */ true).tensor(); + ASSERT_TRUE(allClose(grad, res, 1e-3)); + } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/autograd/AutogradTestUtils.h b/flashlight/fl/test/autograd/AutogradTestUtils.h index 9dafbdb..99f229a 100644 --- a/flashlight/fl/test/autograd/AutogradTestUtils.h +++ b/flashlight/fl/test/autograd/AutogradTestUtils.h @@ -18,60 +18,60 @@ namespace fl { namespace detail { -class AutogradTestF16 : public ::testing::Test { - void SetUp() override { - // Ensures all operations will be in f16 - OptimMode::get().setOptimLevel(OptimLevel::O3); - } + class AutogradTestF16 : public ::testing::Test { + void SetUp() override { + // Ensures all operations will be in f16 + OptimMode::get().setOptimLevel(OptimLevel::O3); + } - void TearDown() override { - OptimMode::get().setOptimLevel(OptimLevel::DEFAULT); - } -}; + void 
TearDown() override { + OptimMode::get().setOptimLevel(OptimLevel::DEFAULT); + } + }; -using JacobianFunc = std::function; -inline bool jacobianTestImpl( - const JacobianFunc& func, - Variable& input, - float precision = 1E-5, - float perturbation = 1E-4, - const std::vector& zeroGradientVariables = {}) { - auto fwdJacobian = - Tensor({func(input).elements(), input.elements()}, fl::dtype::f32); + using JacobianFunc = std::function; + inline bool jacobianTestImpl( + const JacobianFunc& func, + Variable& input, + float precision = 1E-5, + float perturbation = 1E-4, + const std::vector& zeroGradientVariables = {}) { + auto fwdJacobian = + Tensor({func(input).elements(), input.elements()}, fl::dtype::f32); - for (int i = 0; i < input.elements(); ++i) { - Tensor orig = input.tensor().flatten()(i); - input.tensor().flat(i) = orig - perturbation; - auto outa = func(input).tensor(); + for(int i = 0; i < input.elements(); ++i) { + Tensor orig = input.tensor().flatten()(i); + input.tensor().flat(i) = orig - perturbation; + auto outa = func(input).tensor(); - input.tensor().flat(i) = orig + perturbation; - auto outb = func(input).tensor(); - input.tensor().flat(i) = orig; + input.tensor().flat(i) = orig + perturbation; + auto outb = func(input).tensor(); + input.tensor().flat(i) = orig; - fwdJacobian(fl::span, i) = - fl::reshape((outb - outa), {static_cast(outa.elements())}) * 0.5 / - perturbation; - } + fwdJacobian(fl::span, i) = + fl::reshape((outb - outa), {static_cast(outa.elements())}) * 0.5 + / perturbation; + } - auto bwdJacobian = - Tensor({func(input).elements(), input.elements()}, fl::dtype::f32); - auto dout = - Variable(fl::full(func(input).shape(), 0, func(input).type()), false); + auto bwdJacobian = + Tensor({func(input).elements(), input.elements()}, fl::dtype::f32); + auto dout = + Variable(fl::full(func(input).shape(), 0, func(input).type()), false); - for (int i = 0; i < dout.elements(); ++i) { - dout.tensor().flat(i) = 1; // element in 1D view - 
input.zeroGrad(); - for (auto* var : zeroGradientVariables) { - var->zeroGrad(); - } - auto out = func(input); - out.backward(dout); + for(int i = 0; i < dout.elements(); ++i) { + dout.tensor().flat(i) = 1; // element in 1D view + input.zeroGrad(); + for(auto* var : zeroGradientVariables) { + var->zeroGrad(); + } + auto out = func(input); + out.backward(dout); - bwdJacobian(i) = fl::reshape(input.grad().tensor(), {input.elements()}); - dout.tensor().flat(i) = 0; + bwdJacobian(i) = fl::reshape(input.grad().tensor(), {input.elements()}); + dout.tensor().flat(i) = 0; + } + return allClose(fwdJacobian, bwdJacobian, precision); } - return allClose(fwdJacobian, bwdJacobian, precision); -} } } // namespace fl diff --git a/flashlight/fl/test/autograd/AutogradUnaryOpsTest.cpp b/flashlight/fl/test/autograd/AutogradUnaryOpsTest.cpp index ce20a68..0ea5d22 100644 --- a/flashlight/fl/test/autograd/AutogradUnaryOpsTest.cpp +++ b/flashlight/fl/test/autograd/AutogradUnaryOpsTest.cpp @@ -20,167 +20,175 @@ using namespace fl; using fl::detail::AutogradTestF16; TEST(AutogradUnaryOpsTest, Clamp) { - auto input = Variable(fl::rand({5, 6, 7, 4}, fl::dtype::f64) * 3, true); - double lo = -1.0, hi = 1.0; - float perturb = 1E-5; - // Need to do this as gradient is not continuous when input = lo / hi. - auto& inarr = input.tensor(); - inarr = fl::where(fl::abs(inarr - lo) > perturb, inarr, lo + 10 * perturb); - inarr = fl::where(fl::abs(inarr - hi) > perturb, inarr, hi + 10 * perturb); + auto input = Variable(fl::rand({5, 6, 7, 4}, fl::dtype::f64) * 3, true); + double lo = -1.0, hi = 1.0; + float perturb = 1E-5; + // Need to do this as gradient is not continuous when input = lo / hi. 
+ auto& inarr = input.tensor(); + inarr = fl::where(fl::abs(inarr - lo) > perturb, inarr, lo + 10 * perturb); + inarr = fl::where(fl::abs(inarr - hi) > perturb, inarr, hi + 10 * perturb); - auto funcCol = [lo, hi](Variable& in) { return clamp(in, lo, hi); }; + auto funcCol = [lo, hi](Variable& in) { return clamp(in, lo, hi); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcCol, input, 1E-10, perturb)); + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcCol, input, 1E-10, perturb)); } TEST(AutogradUnaryOpsTest, Glu) { - auto in = Variable(fl::rand({3, 4, 5}, fl::dtype::f64), true); - auto funcGlu = [&](Variable& input) { return gatedlinearunit(input, 1); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcGlu, in, 1E-5)); + auto in = Variable(fl::rand({3, 4, 5}, fl::dtype::f64), true); + auto funcGlu = [&](Variable& input) { return gatedlinearunit(input, 1); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcGlu, in, 1E-5)); } TEST(AutogradUnaryOpsTest, Sigmoid) { - auto x = Variable(fl::rand({5}), true); - auto y = sigmoid(x); - auto dy = Variable(fl::full({5}, 1.0), false); - y.backward(dy); - auto dx = x.grad(); - ASSERT_TRUE(allClose(dx.tensor(), (y.tensor() * (1 - y.tensor())))); - ASSERT_TRUE(allClose( - dx.tensor(), (fl::sigmoid(x.tensor()) * (1 - fl::sigmoid(x.tensor()))))); + auto x = Variable(fl::rand({5}), true); + auto y = sigmoid(x); + auto dy = Variable(fl::full({5}, 1.0), false); + y.backward(dy); + auto dx = x.grad(); + ASSERT_TRUE(allClose(dx.tensor(), (y.tensor() * (1 - y.tensor())))); + ASSERT_TRUE( + allClose( + dx.tensor(), + (fl::sigmoid(x.tensor()) * (1 - fl::sigmoid(x.tensor()))) + ) + ); } TEST(AutogradUnaryOpsTest, Erf) { - auto x = Variable(fl::rand({5}), true); - auto y = erf(x); - ASSERT_TRUE(allClose(fl::erf(x.tensor()), y.tensor())); - - auto dy = Variable(fl::full({5}, 1.0), false); - y.backward(dy); - auto targetGrads = 2 / std::sqrt(M_PI) * exp(negate(x * x)); - auto dx = x.grad(); - ASSERT_TRUE(allClose(dx.tensor(), 
targetGrads.tensor())); - - auto funcErf = [](Variable& in) { return erf(in); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcErf, x, 5e-4, 1e-4)); + auto x = Variable(fl::rand({5}), true); + auto y = erf(x); + ASSERT_TRUE(allClose(fl::erf(x.tensor()), y.tensor())); + + auto dy = Variable(fl::full({5}, 1.0), false); + y.backward(dy); + auto targetGrads = 2 / std::sqrt(M_PI) * exp(negate(x * x)); + auto dx = x.grad(); + ASSERT_TRUE(allClose(dx.tensor(), targetGrads.tensor())); + + auto funcErf = [](Variable& in) { return erf(in); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcErf, x, 5e-4, 1e-4)); } TEST(AutogradUnaryOpsTest, Tanh) { - auto x = Variable(fl::rand({5}), true); - auto y = tanh(x); - auto dy = Variable(fl::full({5}, 1.0), false); - y.backward(dy); - auto dx = x.grad(); - ASSERT_TRUE(allClose(dx.tensor(), (1 - y.tensor() * y.tensor()))); - ASSERT_TRUE(allClose( - dx.tensor(), (1 + fl::tanh(x.tensor())) * (1 - fl::tanh(x.tensor())))); + auto x = Variable(fl::rand({5}), true); + auto y = tanh(x); + auto dy = Variable(fl::full({5}, 1.0), false); + y.backward(dy); + auto dx = x.grad(); + ASSERT_TRUE(allClose(dx.tensor(), (1 - y.tensor() * y.tensor()))); + ASSERT_TRUE( + allClose( + dx.tensor(), + (1 + fl::tanh(x.tensor())) * (1 - fl::tanh(x.tensor())) + ) + ); } TEST(AutogradUnaryOpsTest, Transpose) { - auto in = Variable(fl::rand({5, 6, 7, 8}), true); - auto out = transpose(in, {2, 0, 1, 3}); - out.backward(); - ASSERT_EQ(in.grad().shape(), Shape({5, 6, 7, 8})); + auto in = Variable(fl::rand({5, 6, 7, 8}), true); + auto out = transpose(in, {2, 0, 1, 3}); + out.backward(); + ASSERT_EQ(in.grad().shape(), Shape({5, 6, 7, 8})); - auto funcErf = [](Variable& in) { return transpose(in, {1, 3, 2, 0}); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcErf, in, 5e-4, 1e-4)); + auto funcErf = [](Variable& in) { return transpose(in, {1, 3, 2, 0}); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcErf, in, 5e-4, 1e-4)); - auto in2 = Variable(fl::rand({6, 7, 8, 
9}), true); - auto out2 = transpose(in2); - out2.backward(); - ASSERT_EQ(in2.grad().shape(), Shape({6, 7, 8, 9})); + auto in2 = Variable(fl::rand({6, 7, 8, 9}), true); + auto out2 = transpose(in2); + out2.backward(); + ASSERT_EQ(in2.grad().shape(), Shape({6, 7, 8, 9})); - auto funcErf2 = [](Variable& in) { return transpose(in); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcErf2, in2, 5e-4, 1e-4)); + auto funcErf2 = [](Variable& in) { return transpose(in); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcErf2, in2, 5e-4, 1e-4)); } TEST(AutogradUnaryOpsTest, Exp) { - auto x = Variable(fl::rand({5}), true); - auto y = exp(x); - auto dy = Variable(fl::full({5}, 1.0), false); - y.backward(dy); - auto dx = x.grad(); - ASSERT_TRUE(allClose(dx.tensor(), (fl::exp(x.tensor())))); + auto x = Variable(fl::rand({5}), true); + auto y = exp(x); + auto dy = Variable(fl::full({5}, 1.0), false); + y.backward(dy); + auto dx = x.grad(); + ASSERT_TRUE(allClose(dx.tensor(), (fl::exp(x.tensor())))); } TEST(AutogradUnaryOpsTest, Log1p) { - auto x = Variable(fl::rand({5}), true); - auto y = log1p(x); + auto x = Variable(fl::rand({5}), true); + auto y = log1p(x); - auto xCopy = Variable(x.tensor(), true); - auto yExp = log(1 + xCopy); + auto xCopy = Variable(x.tensor(), true); + auto yExp = log(1 + xCopy); - y.backward(); - yExp.backward(); + y.backward(); + yExp.backward(); - ASSERT_TRUE(allClose(y.tensor(), yExp.tensor())); - ASSERT_TRUE(allClose(y.grad().tensor(), yExp.grad().tensor())); - ASSERT_TRUE(allClose(x.grad().tensor(), xCopy.grad().tensor())); + ASSERT_TRUE(allClose(y.tensor(), yExp.tensor())); + ASSERT_TRUE(allClose(y.grad().tensor(), yExp.grad().tensor())); + ASSERT_TRUE(allClose(x.grad().tensor(), xCopy.grad().tensor())); } TEST(AutogradUnaryOpsTest, Softmax) { - auto in = Variable(fl::rand({3, 5, 1}, fl::dtype::f64), true); - auto funcSm = [&](Variable& input) { return softmax(input, 0); }; + auto in = Variable(fl::rand({3, 5, 1}, fl::dtype::f64), true); + auto funcSm 
= [&](Variable& input) { return softmax(input, 0); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcSm, in, 1E-5)); + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcSm, in, 1E-5)); } TEST_F(AutogradTestF16, SoftmaxF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } - auto in = Variable(fl::rand({3, 5, 1}, fl::dtype::f16), true); - auto funcSm = [&](Variable& input) { return softmax(input, 0); }; + auto in = Variable(fl::rand({3, 5, 1}, fl::dtype::f16), true); + auto funcSm = [&](Variable& input) { return softmax(input, 0); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcSm, in, 1E-2, 1e-1)); + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcSm, in, 1E-2, 1e-1)); } TEST(AutogradUnaryOpsTest, LogSoftmax) { - auto in = Variable(fl::rand({3, 5, 1}, fl::dtype::f64), true); - auto funcLsm = [&](Variable& input) { return logSoftmax(input, 0); }; + auto in = Variable(fl::rand({3, 5, 1}, fl::dtype::f64), true); + auto funcLsm = [&](Variable& input) { return logSoftmax(input, 0); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLsm, in, 1E-5)); + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLsm, in, 1E-5)); } TEST_F(AutogradTestF16, LogSoftmaxF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } - auto in = Variable(fl::rand({3, 5, 1}, fl::dtype::f16), true); - auto funcLsm = [&](Variable& input) { return logSoftmax(input, 0); }; + auto in = Variable(fl::rand({3, 5, 1}, fl::dtype::f16), true); + auto funcLsm = [&](Variable& input) { return logSoftmax(input, 0); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLsm, in, 1E-2, 1e-1)); + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcLsm, in, 1E-2, 1e-1)); } TEST(AutogradUnaryOpsTest, Pow) { - { - auto x = Variable(fl::rand({5}), true); - 
auto y = pow(x, 2); - auto dy = Variable(fl::full({5}, 2.0), false); - y.backward(dy); - auto dx = x.grad(); - ASSERT_TRUE(allClose(dx.tensor(), (2 * 2 * x.tensor()))); - } - { - auto x = Variable(fl::rand({5}), true); - auto y = pow(x, 3); - auto dy = Variable(fl::full({5}, 1.0), false); - y.backward(dy); - auto dx = x.grad(); - ASSERT_TRUE(allClose(dx.tensor(), (3 * fl::power(x.tensor(), 2)))); - } + { + auto x = Variable(fl::rand({5}), true); + auto y = pow(x, 2); + auto dy = Variable(fl::full({5}, 2.0), false); + y.backward(dy); + auto dx = x.grad(); + ASSERT_TRUE(allClose(dx.tensor(), (2 * 2 * x.tensor()))); + } + { + auto x = Variable(fl::rand({5}), true); + auto y = pow(x, 3); + auto dy = Variable(fl::full({5}, 1.0), false); + y.backward(dy); + auto dx = x.grad(); + ASSERT_TRUE(allClose(dx.tensor(), (3 * fl::power(x.tensor(), 2)))); + } } TEST(AutogradUnaryOpsTest, Sqrt) { - auto x = Variable(fl::rand({5, 3}, fl::dtype::f64), true); - auto funcSqrt = [](Variable& in) { return fl::sqrt(in); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcSqrt, x, 1E-3)); + auto x = Variable(fl::rand({5, 3}, fl::dtype::f64), true); + auto funcSqrt = [](Variable& in) { return fl::sqrt(in); }; + ASSERT_TRUE(fl::detail::jacobianTestImpl(funcSqrt, x, 1E-3)); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/common/DevicePtrTest.cpp b/flashlight/fl/test/common/DevicePtrTest.cpp index e705cde..a847f24 100644 --- a/flashlight/fl/test/common/DevicePtrTest.cpp +++ b/flashlight/fl/test/common/DevicePtrTest.cpp @@ -14,52 +14,52 @@ using namespace fl; TEST(DevicePtrTest, Null) { - Tensor x; - DevicePtr xp(x); - EXPECT_EQ(xp.get(), nullptr); + Tensor x; + DevicePtr xp(x); + EXPECT_EQ(xp.get(), nullptr); } TEST(DevicePtrTest, NoCopy) { - Tensor a = fl::full({3, 3}, 5.); + Tensor a = fl::full({3, 3}, 5.); - 
void* devicePtrLoc; - { - DevicePtr p(a); - devicePtrLoc = p.get(); - } - EXPECT_EQ(devicePtrLoc, a.device()); - a.unlock(); + void* devicePtrLoc; + { + DevicePtr p(a); + devicePtrLoc = p.get(); + } + EXPECT_EQ(devicePtrLoc, a.device()); + a.unlock(); } TEST(DevicePtrTest, Locking) { - Tensor x({3, 3}); - EXPECT_FALSE(x.isLocked()); - { - DevicePtr xp(x); - EXPECT_TRUE(x.isLocked()); - } - EXPECT_FALSE(x.isLocked()); + Tensor x({3, 3}); + EXPECT_FALSE(x.isLocked()); + { + DevicePtr xp(x); + EXPECT_TRUE(x.isLocked()); + } + EXPECT_FALSE(x.isLocked()); } TEST(DevicePtrTest, Move) { - Tensor x({3, 3}); - Tensor y({4, 4}); + Tensor x({3, 3}); + Tensor y({4, 4}); - DevicePtr yp(y); - EXPECT_FALSE(x.isLocked()); - EXPECT_TRUE(y.isLocked()); + DevicePtr yp(y); + EXPECT_FALSE(x.isLocked()); + EXPECT_TRUE(y.isLocked()); - yp = DevicePtr(x); - EXPECT_TRUE(x.isLocked()); - EXPECT_FALSE(y.isLocked()); + yp = DevicePtr(x); + EXPECT_TRUE(x.isLocked()); + EXPECT_FALSE(y.isLocked()); - yp = DevicePtr(); - EXPECT_FALSE(x.isLocked()); - EXPECT_FALSE(y.isLocked()); + yp = DevicePtr(); + EXPECT_FALSE(x.isLocked()); + EXPECT_FALSE(y.isLocked()); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/common/DynamicBenchmarkTest.cpp b/flashlight/fl/test/common/DynamicBenchmarkTest.cpp index 8d7fce7..06b811c 100644 --- a/flashlight/fl/test/common/DynamicBenchmarkTest.cpp +++ b/flashlight/fl/test/common/DynamicBenchmarkTest.cpp @@ -20,137 +20,145 @@ namespace { class DynamicBenchmark : public ::testing::Test { - protected: - void SetUp() override { - fl::DynamicBenchmark::setBenchmarkMode(true); - } +protected: + void SetUp() override { + fl::DynamicBenchmark::setBenchmarkMode(true); + } }; } // namespace TEST_F(DynamicBenchmark, OptionsStateBasic) { - size_t maxCount = 5; - std::vector ops = {1, 2, 3, 4, 5}; - 
auto options = - std::make_shared>(ops, maxCount); - - ASSERT_FALSE(options->timingsComplete()); - ASSERT_EQ(options->currentOption(), 1); - for (size_t i = 0; i < maxCount * ops.size(); ++i) { - options->accumulateTimeToCurrentOption(1); - } - ASSERT_TRUE(options->timingsComplete()); - ASSERT_EQ(options->currentOption(), 1); // best idx should never have changed + size_t maxCount = 5; + std::vector ops = {1, 2, 3, 4, 5}; + auto options = + std::make_shared>(ops, maxCount); + + ASSERT_FALSE(options->timingsComplete()); + ASSERT_EQ(options->currentOption(), 1); + for(size_t i = 0; i < maxCount * ops.size(); ++i) { + options->accumulateTimeToCurrentOption(1); + } + ASSERT_TRUE(options->timingsComplete()); + ASSERT_EQ(options->currentOption(), 1); // best idx should never have changed } TEST_F(DynamicBenchmark, OptionscurrentOptionUnchangedWithNoCountIncrement) { - std::vector ops = {1, 2, 3, 4, 5}; - auto options = std::make_shared>( - ops, /* maxCount = */ 3); - - auto state = options->currentOption(); - options->accumulateTimeToCurrentOption(3, /* incrementCount = */ false); - options->accumulateTimeToCurrentOption(4, /* incrementCount = */ false); - ASSERT_EQ(state, options->currentOption()); + std::vector ops = {1, 2, 3, 4, 5}; + auto options = std::make_shared>( + ops, /* maxCount = */ + 3 + ); + + auto state = options->currentOption(); + options->accumulateTimeToCurrentOption(3, /* incrementCount = */ false); + options->accumulateTimeToCurrentOption(4, /* incrementCount = */ false); + ASSERT_EQ(state, options->currentOption()); } TEST_F(DynamicBenchmark, OptionsStateTimed) { - size_t maxCount = 5; - std::unordered_set ops = {1, 2, 3, 4, 5}; - auto options = - std::make_shared>(ops, maxCount); - - for (size_t i = 0; i < maxCount * ops.size(); ++i) { - // option 4 is faster - if (options->currentOption() == 4) { - options->accumulateTimeToCurrentOption(1); - } else { - options->accumulateTimeToCurrentOption( - 10 * (i + 1), /* incrementCount = */ false); - 
options->accumulateTimeToCurrentOption(10 * (i + 1)); + size_t maxCount = 5; + std::unordered_set ops = {1, 2, 3, 4, 5}; + auto options = + std::make_shared>(ops, maxCount); + + for(size_t i = 0; i < maxCount * ops.size(); ++i) { + // option 4 is faster + if(options->currentOption() == 4) { + options->accumulateTimeToCurrentOption(1); + } else { + options->accumulateTimeToCurrentOption( + 10 * (i + 1), /* incrementCount = */ + false + ); + options->accumulateTimeToCurrentOption(10 * (i + 1)); + } + } - ASSERT_TRUE(options->timingsComplete()); - ASSERT_EQ(options->currentOption(), 4); // fastest - ASSERT_EQ(options->currentOption(), 4); + ASSERT_TRUE(options->timingsComplete()); + ASSERT_EQ(options->currentOption(), 4); // fastest + ASSERT_EQ(options->currentOption(), 4); } TEST_F(DynamicBenchmark, DynamicBenchmarkSimple) { - size_t maxCount = 5; - std::vector sleepTimes = {30, 16, 40}; //min 16ms (win) - - auto options = - std::make_shared>(sleepTimes, maxCount); - auto dynamicBench = std::make_shared(options); - - for (size_t i = 0; i < maxCount * sleepTimes.size(); ++i) { - std::chrono::milliseconds sleepTime(options->currentOption()); - dynamicBench->audit( - [sleepTime]() { std::this_thread::sleep_for(sleepTime); }); - } - ASSERT_TRUE(options->timingsComplete()); - // sleeping for fewer miliseconds is faster - ASSERT_EQ(options->currentOption(), sleepTimes[1]); + size_t maxCount = 5; + std::vector sleepTimes = {30, 16, 40}; // min 16ms (win) + + auto options = + std::make_shared>(sleepTimes, maxCount); + auto dynamicBench = std::make_shared(options); + + for(size_t i = 0; i < maxCount * sleepTimes.size(); ++i) { + std::chrono::milliseconds sleepTime(options->currentOption()); + dynamicBench->audit( + [sleepTime]() { std::this_thread::sleep_for(sleepTime); }); + } + ASSERT_TRUE(options->timingsComplete()); + // sleeping for fewer milliseconds is faster + ASSERT_EQ(options->currentOption(), sleepTimes[1]); } TEST_F(DynamicBenchmark, 
DynamicBenchmarkDisjointLambdas) { - size_t maxCount = 5; - std::vector sleepTimes = {30, 16, 40}; - - auto options = - std::make_shared>(sleepTimes, maxCount); - auto dynamicBench = std::make_shared(options); - - for (size_t i = 0; i < maxCount * sleepTimes.size(); ++i) { - std::chrono::milliseconds sleepTime(options->currentOption()); - dynamicBench->audit( - [sleepTime]() { std::this_thread::sleep_for(sleepTime); }, - /* incrementCount = */ false); - - // intermediate sleep is inversely proportional to the audit sleep time: - // 4, 2, 6 --> 18, 24, 12 - // total duration disregarding the audit is therefore: - // 18 + 2 * 4, 24 + 2 * 2, 12 + 2 * 6 ---> 26, 28, 24 - std::chrono::milliseconds intermediateSleepTime( - 30 - (3 * options->currentOption())); - std::this_thread::sleep_for(intermediateSleepTime); - - dynamicBench->audit( - [sleepTime]() { std::this_thread::sleep_for(sleepTime); }); - } - ASSERT_TRUE(options->timingsComplete()); - // option 2 is still fastest disregarding intermediate time - ASSERT_EQ(options->currentOption(), sleepTimes[1]); + size_t maxCount = 5; + std::vector sleepTimes = {30, 16, 40}; + + auto options = + std::make_shared>(sleepTimes, maxCount); + auto dynamicBench = std::make_shared(options); + + for(size_t i = 0; i < maxCount * sleepTimes.size(); ++i) { + std::chrono::milliseconds sleepTime(options->currentOption()); + dynamicBench->audit( + [sleepTime]() { std::this_thread::sleep_for(sleepTime); }, + /* incrementCount = */ false + ); + + // intermediate sleep is inversely proportional to the audit sleep time: + // 4, 2, 6 --> 18, 24, 12 + // total duration disregarding the audit is therefore: + // 18 + 2 * 4, 24 + 2 * 2, 12 + 2 * 6 ---> 26, 28, 24 + std::chrono::milliseconds intermediateSleepTime( + 30 - (3 * options->currentOption())); + std::this_thread::sleep_for(intermediateSleepTime); + + dynamicBench->audit( + [sleepTime]() { std::this_thread::sleep_for(sleepTime); }); + } + ASSERT_TRUE(options->timingsComplete()); + // 
option 2 is still fastest disregarding intermediate time + ASSERT_EQ(options->currentOption(), sleepTimes[1]); } TEST_F(DynamicBenchmark, DynamicBenchmarkMatmul) { - size_t maxCount = 5; - // n x n arrays of different sizes - std::vector arraySizes = {256, 8, 2048}; - - auto options = - std::make_shared>(arraySizes, maxCount); - auto dynamicBench = std::make_shared(options); - - for (size_t i = 0; i < maxCount * arraySizes.size(); ++i) { - auto size = dynamicBench->getOptions>() - ->currentOption(); - dynamicBench->audit([size]() { - auto a = fl::rand({size, size}); - auto b = fl::rand({size, size}); - auto c = fl::matmul(a, b); - fl::eval(c); - }); - } - auto ops = dynamicBench->getOptions>(); - ASSERT_TRUE(ops->timingsComplete()); - ASSERT_EQ( - ops->currentOption(), - *std::min_element(arraySizes.begin(), arraySizes.end())); + size_t maxCount = 5; + // n x n arrays of different sizes + std::vector arraySizes = {256, 8, 2048}; + + auto options = + std::make_shared>(arraySizes, maxCount); + auto dynamicBench = std::make_shared(options); + + for(size_t i = 0; i < maxCount * arraySizes.size(); ++i) { + auto size = dynamicBench->getOptions>() + ->currentOption(); + dynamicBench->audit( + [size]() { + auto a = fl::rand({size, size}); + auto b = fl::rand({size, size}); + auto c = fl::matmul(a, b); + fl::eval(c); + } + ); + } + auto ops = dynamicBench->getOptions>(); + ASSERT_TRUE(ops->timingsComplete()); + ASSERT_EQ( + ops->currentOption(), + *std::min_element(arraySizes.begin(), arraySizes.end()) + ); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/common/HistogramTest.cpp b/flashlight/fl/test/common/HistogramTest.cpp index 37899e2..5bdaab9 100644 --- a/flashlight/fl/test/common/HistogramTest.cpp +++ b/flashlight/fl/test/common/HistogramTest.cpp @@ -20,49 +20,49 @@ namespace { // 
normally distributed container of values. Checks that min,max,sum,mean,count // and count per bucket makes sense. TEST(FixedBucketSizeHistogram, NormalDistribution) { - const int nValues = 10e6; // Random large number. - const int nBuckes = 9; // Odd number of nuckets such that we have a bucket at - // the center with max elements. - const int mean = 100; - const int stddev = 5; - - std::minstd_rand0 generator; - std::normal_distribution distribution(mean, stddev); - - std::vector data(nValues); - for (int i = 0; i < nValues; ++i) { - data[i] = distribution(generator); - } - - HistogramStats hist = - FixedBucketSizeHistogram(data.begin(), data.end(), nBuckes); - - EXPECT_LT(hist.min, mean - stddev); - EXPECT_GT(hist.max, mean + stddev); - // Sum should be smaller than if all values where greater than the mean. - EXPECT_LT(hist.sum, (nValues + 1) * mean); - // Normal max should be greater than uniform distribution. - EXPECT_GT(hist.maxNumValuesPerBucket, nValues / nBuckes); - ASSERT_EQ(hist.buckets.size(), nBuckes); - - // Verify normal distribution. - EXPECT_LT(hist.buckets[0].count, hist.buckets[1].count); - EXPECT_LT(hist.buckets[1].count, hist.buckets[2].count); - EXPECT_LT(hist.buckets[2].count, hist.buckets[3].count); - EXPECT_LT(hist.buckets[3].count, hist.buckets[4].count); - EXPECT_GT(hist.buckets[4].count, hist.buckets[5].count); - EXPECT_GT(hist.buckets[5].count, hist.buckets[6].count); - EXPECT_GT(hist.buckets[6].count, hist.buckets[7].count); - EXPECT_GT(hist.buckets[7].count, hist.buckets[8].count); - - // Verify bounds span the range. - EXPECT_EQ(hist.buckets[0].startInclusive, hist.min); - for (int i = 0; i < (nBuckes - 1); ++i) { - EXPECT_EQ(hist.buckets[i + 1].startInclusive, hist.buckets[i].endExclusive); - } - EXPECT_EQ(hist.buckets[nBuckes - 1].endExclusive, hist.max); - - std::cout << hist.prettyString() << std::endl; + const int nValues = 10e6; // Random large number. 
+ const int nBuckes = 9; // Odd number of buckets such that we have a bucket at + // the center with max elements. + const int mean = 100; + const int stddev = 5; + + std::minstd_rand0 generator; + std::normal_distribution distribution(mean, stddev); + + std::vector data(nValues); + for(int i = 0; i < nValues; ++i) { + data[i] = distribution(generator); + } + + HistogramStats hist = + FixedBucketSizeHistogram(data.begin(), data.end(), nBuckes); + + EXPECT_LT(hist.min, mean - stddev); + EXPECT_GT(hist.max, mean + stddev); + // Sum should be smaller than if all values were greater than the mean. + EXPECT_LT(hist.sum, (nValues + 1) * mean); + // Normal max should be greater than uniform distribution. + EXPECT_GT(hist.maxNumValuesPerBucket, nValues / nBuckes); + ASSERT_EQ(hist.buckets.size(), nBuckes); + + // Verify normal distribution. + EXPECT_LT(hist.buckets[0].count, hist.buckets[1].count); + EXPECT_LT(hist.buckets[1].count, hist.buckets[2].count); + EXPECT_LT(hist.buckets[2].count, hist.buckets[3].count); + EXPECT_LT(hist.buckets[3].count, hist.buckets[4].count); + EXPECT_GT(hist.buckets[4].count, hist.buckets[5].count); + EXPECT_GT(hist.buckets[5].count, hist.buckets[6].count); + EXPECT_GT(hist.buckets[6].count, hist.buckets[7].count); + EXPECT_GT(hist.buckets[7].count, hist.buckets[8].count); + + // Verify bounds span the range. + EXPECT_EQ(hist.buckets[0].startInclusive, hist.min); + for(int i = 0; i < (nBuckes - 1); ++i) { + EXPECT_EQ(hist.buckets[i + 1].startInclusive, hist.buckets[i].endExclusive); + } + EXPECT_EQ(hist.buckets[nBuckes - 1].endExclusive, hist.max); + + std::cout << hist.prettyString() << std::endl; } // Tests that FixedBucketSizeHistogram generate correct statistics for a @@ -72,67 +72,68 @@ TEST(FixedBucketSizeHistogram, NormalDistribution) { // histogram makes sense for exponential distribution. It is an histogram of the // bucket with the most elements in the first histogram. 
TEST(FixedBucketSizeHistogram, ExponentialDistribution) { - const int nValues = 10e6; // Random large number - const int nBuckes = 12; // Random value - const double multiplier = - 10e3; // Should be much bigger than 1 to map floats evenly on to integers. - - std::minstd_rand0 generator; - std::exponential_distribution distribution(0.1); - - std::vector data(nValues); - for (int i = 0; i < nValues; ++i) { - data[i] = distribution(generator) * multiplier; - } - - HistogramStats hist = - FixedBucketSizeHistogram(data.begin(), data.end(), nBuckes); - - // Verify sanity of basic stats. - ASSERT_EQ(hist.buckets.size(), nBuckes); - EXPECT_EQ(hist.numValues, data.size()); - EXPECT_EQ(hist.min, 0); - EXPECT_GT(hist.max, multiplier); - // exponential max should be greater than uniform distribution. - EXPECT_GT(hist.maxNumValuesPerBucket, nValues / nBuckes); - - // Verify exponential distribution. - for (int i = 0; i < (nBuckes - 1); ++i) { - EXPECT_GT(hist.buckets[i].count, hist.buckets[i + 1].count); - } - - // Verify bounds span the range. - EXPECT_EQ(hist.buckets[0].startInclusive, hist.min); - for (int i = 0; i < (nBuckes - 1); ++i) { - EXPECT_EQ(hist.buckets[i + 1].startInclusive, hist.buckets[i].endExclusive); - } - EXPECT_GE(hist.buckets[nBuckes - 1].endExclusive, hist.max); - - std::cout << hist.prettyString() << std::endl; - - // High-resolution histogram. - const HistogramBucket& largestCountBucket = hist.buckets[0]; - HistogramStats hiResHist = FixedBucketSizeHistogram( - data.begin(), - data.end(), - nBuckes, - largestCountBucket.startInclusive, - largestCountBucket.endExclusive); - - // Verify sanity of basic stats. - ASSERT_EQ(hiResHist.buckets.size(), nBuckes); - EXPECT_GE(hiResHist.min, largestCountBucket.startInclusive); - EXPECT_LE(hiResHist.max, largestCountBucket.endExclusive); - // exponential max should be greater than uniform distribution. 
- EXPECT_GT(hiResHist.maxNumValuesPerBucket, nValues / nBuckes); - - std::cout << hiResHist.prettyString() << std::endl; + const int nValues = 10e6; // Random large number + const int nBuckes = 12; // Random value + const double multiplier = + 10e3; // Should be much bigger than 1 to map floats evenly on to integers. + + std::minstd_rand0 generator; + std::exponential_distribution distribution(0.1); + + std::vector data(nValues); + for(int i = 0; i < nValues; ++i) { + data[i] = distribution(generator) * multiplier; + } + + HistogramStats hist = + FixedBucketSizeHistogram(data.begin(), data.end(), nBuckes); + + // Verify sanity of basic stats. + ASSERT_EQ(hist.buckets.size(), nBuckes); + EXPECT_EQ(hist.numValues, data.size()); + EXPECT_EQ(hist.min, 0); + EXPECT_GT(hist.max, multiplier); + // exponential max should be greater than uniform distribution. + EXPECT_GT(hist.maxNumValuesPerBucket, nValues / nBuckes); + + // Verify exponential distribution. + for(int i = 0; i < (nBuckes - 1); ++i) { + EXPECT_GT(hist.buckets[i].count, hist.buckets[i + 1].count); + } + + // Verify bounds span the range. + EXPECT_EQ(hist.buckets[0].startInclusive, hist.min); + for(int i = 0; i < (nBuckes - 1); ++i) { + EXPECT_EQ(hist.buckets[i + 1].startInclusive, hist.buckets[i].endExclusive); + } + EXPECT_GE(hist.buckets[nBuckes - 1].endExclusive, hist.max); + + std::cout << hist.prettyString() << std::endl; + + // High-resolution histogram. + const HistogramBucket& largestCountBucket = hist.buckets[0]; + HistogramStats hiResHist = FixedBucketSizeHistogram( + data.begin(), + data.end(), + nBuckes, + largestCountBucket.startInclusive, + largestCountBucket.endExclusive + ); + + // Verify sanity of basic stats. + ASSERT_EQ(hiResHist.buckets.size(), nBuckes); + EXPECT_GE(hiResHist.min, largestCountBucket.startInclusive); + EXPECT_LE(hiResHist.max, largestCountBucket.endExclusive); + // exponential max should be greater than uniform distribution. 
+ EXPECT_GT(hiResHist.maxNumValuesPerBucket, nValues / nBuckes); + + std::cout << hiResHist.prettyString() << std::endl; } } // namespace int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/common/LoggingTest.cpp b/flashlight/fl/test/common/LoggingTest.cpp index def67cf..6a8a784 100644 --- a/flashlight/fl/test/common/LoggingTest.cpp +++ b/flashlight/fl/test/common/LoggingTest.cpp @@ -23,46 +23,46 @@ using testing::Not; // FL_VLOG(l) should print to stdout when VerboseLogging::setMaxLoggingLevel(i) // i>=l TEST(Logging, vlogOnOff) { - std::stringstream stdoutBuffer; - std::stringstream stderrBuffer; - - std::streambuf* origStdoutBuffer = std::cout.rdbuf(); - std::streambuf* origStderrBuffer = std::cerr.rdbuf(); - - std::cout.rdbuf(stdoutBuffer.rdbuf()); - std::cerr.rdbuf(stderrBuffer.rdbuf()); - - for (int i = 0; i < 11; ++i) { - stdoutBuffer.clear(); - stderrBuffer.clear(); - VerboseLogging::setMaxLoggingLevel(i); - FL_VLOG(0) << "vlog-0"; - FL_VLOG(1) << "vlog-1"; - FL_VLOG(10) << "vlog-10"; - - // Prints to stderr - EXPECT_THAT(stderrBuffer.str(), HasSubstr("vlog-0")); - - if (i >= 1) { - EXPECT_THAT(stderrBuffer.str(), HasSubstr("vlog-1")); - } else { - EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("vlog-1"))); + std::stringstream stdoutBuffer; + std::stringstream stderrBuffer; + + std::streambuf* origStdoutBuffer = std::cout.rdbuf(); + std::streambuf* origStderrBuffer = std::cerr.rdbuf(); + + std::cout.rdbuf(stdoutBuffer.rdbuf()); + std::cerr.rdbuf(stderrBuffer.rdbuf()); + + for(int i = 0; i < 11; ++i) { + stdoutBuffer.clear(); + stderrBuffer.clear(); + VerboseLogging::setMaxLoggingLevel(i); + FL_VLOG(0) << "vlog-0"; + FL_VLOG(1) << "vlog-1"; + FL_VLOG(10) << "vlog-10"; + + // Prints to stderr + EXPECT_THAT(stderrBuffer.str(), HasSubstr("vlog-0")); + + if(i >= 1) { + 
EXPECT_THAT(stderrBuffer.str(), HasSubstr("vlog-1")); + } else { + EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("vlog-1"))); + } + + if(i >= 10) { + EXPECT_THAT(stderrBuffer.str(), HasSubstr("vlog-10")); + } else { + EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("vlog-10"))); + } + + // Does not print to stdout + EXPECT_THAT(stdoutBuffer.str(), Not(HasSubstr("vlog-0"))); + EXPECT_THAT(stdoutBuffer.str(), Not(HasSubstr("vlog-1"))); + EXPECT_THAT(stdoutBuffer.str(), Not(HasSubstr("vlog-10"))); } - if (i >= 10) { - EXPECT_THAT(stderrBuffer.str(), HasSubstr("vlog-10")); - } else { - EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("vlog-10"))); - } - - // Does not print to stdout - EXPECT_THAT(stdoutBuffer.str(), Not(HasSubstr("vlog-0"))); - EXPECT_THAT(stdoutBuffer.str(), Not(HasSubstr("vlog-1"))); - EXPECT_THAT(stdoutBuffer.str(), Not(HasSubstr("vlog-10"))); - } - - std::cout.rdbuf(origStdoutBuffer); - std::cerr.rdbuf(origStderrBuffer); + std::cout.rdbuf(origStdoutBuffer); + std::cerr.rdbuf(origStderrBuffer); } // FL_LOG(l) should print to stdout when Logging::setMaxLoggingLevel(i) i>=l and @@ -70,78 +70,82 @@ TEST(Logging, vlogOnOff) { // FL_LOG(l) should print to stderr when Logging::setMaxLoggingLevel(i) i>=l and // l is ERROR. 
TEST(Logging, logOnOff) { - std::stringstream stdoutBuffer; - std::stringstream stderrBuffer; - - std::streambuf* origStdoutBuffer = std::cout.rdbuf(); - std::streambuf* origStderrBuffer = std::cerr.rdbuf(); - - std::cout.rdbuf(stdoutBuffer.rdbuf()); - std::cerr.rdbuf(stderrBuffer.rdbuf()); - - const std::vector logLevels = { - fl::LogLevel::DISABLED, - fl::LogLevel::FATAL, - fl::LogLevel::ERROR, - fl::LogLevel::WARNING, - fl::LogLevel::INFO}; - for (LogLevel l : logLevels) { - stdoutBuffer.clear(); - stderrBuffer.clear(); - - Logging::setMaxLoggingLevel(l); - FL_LOG(fl::LogLevel::INFO) << "log-info"; - FL_LOG(fl::LogLevel::WARNING) << "log-warning"; - FL_LOG(fl::LogLevel::ERROR) << "log-error"; - - // Prints to stderr - if (l >= fl::LogLevel::INFO) { - EXPECT_THAT(stderrBuffer.str(), HasSubstr("log-info")); - } else { - EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("log-info"))); + std::stringstream stdoutBuffer; + std::stringstream stderrBuffer; + + std::streambuf* origStdoutBuffer = std::cout.rdbuf(); + std::streambuf* origStderrBuffer = std::cerr.rdbuf(); + + std::cout.rdbuf(stdoutBuffer.rdbuf()); + std::cerr.rdbuf(stderrBuffer.rdbuf()); + + const std::vector logLevels = { + fl::LogLevel::DISABLED, + fl::LogLevel::FATAL, + fl::LogLevel::ERROR, + fl::LogLevel::WARNING, + fl::LogLevel::INFO}; + for(LogLevel l : logLevels) { + stdoutBuffer.clear(); + stderrBuffer.clear(); + + Logging::setMaxLoggingLevel(l); + FL_LOG(fl::LogLevel::INFO) << "log-info"; + FL_LOG(fl::LogLevel::WARNING) << "log-warning"; + FL_LOG(fl::LogLevel::ERROR) << "log-error"; + + // Prints to stderr + if(l >= fl::LogLevel::INFO) { + EXPECT_THAT(stderrBuffer.str(), HasSubstr("log-info")); + } else { + EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("log-info"))); + } + + if(l >= fl::LogLevel::WARNING) { + EXPECT_THAT(stderrBuffer.str(), HasSubstr("log-warning")); + } else { + EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("log-warning"))); + } + + // Does not print to stdout + 
EXPECT_THAT(stdoutBuffer.str(), Not(HasSubstr("log-info"))); + EXPECT_THAT(stdoutBuffer.str(), Not(HasSubstr("log-warning"))); + + if(l >= fl::LogLevel::ERROR) { + EXPECT_THAT(stderrBuffer.str(), HasSubstr("log-error")); + } else { + EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("log-error"))); + } } - if (l >= fl::LogLevel::WARNING) { - EXPECT_THAT(stderrBuffer.str(), HasSubstr("log-warning")); - } else { - EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("log-warning"))); - } - - // Does not print to stdout - EXPECT_THAT(stdoutBuffer.str(), Not(HasSubstr("log-info"))); - EXPECT_THAT(stdoutBuffer.str(), Not(HasSubstr("log-warning"))); - - if (l >= fl::LogLevel::ERROR) { - EXPECT_THAT(stderrBuffer.str(), HasSubstr("log-error")); - } else { - EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("log-error"))); - } - } - - std::cout.rdbuf(origStdoutBuffer); - std::cerr.rdbuf(origStderrBuffer); + std::cout.rdbuf(origStdoutBuffer); + std::cerr.rdbuf(origStderrBuffer); } TEST(LoggingDeathTest, FatalOnOff) { - std::stringstream stderrBuffer; - std::streambuf* origStderrBuffer = std::cerr.rdbuf(); - std::cerr.rdbuf(stderrBuffer.rdbuf()); - - Logging::setMaxLoggingLevel(fl::LogLevel::DISABLED); - FL_LOG(fl::LogLevel::FATAL) << "log-fatal"; - EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("log-fatal"))); - EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("log-fatal"))); - - std::cerr.rdbuf(origStderrBuffer); - - Logging::setMaxLoggingLevel(fl::LogLevel::FATAL); - EXPECT_DEATH_IF_SUPPORTED({ FL_LOG(fl::LogLevel::FATAL) << "log-fatal"; }, ""); + std::stringstream stderrBuffer; + std::streambuf* origStderrBuffer = std::cerr.rdbuf(); + std::cerr.rdbuf(stderrBuffer.rdbuf()); + + Logging::setMaxLoggingLevel(fl::LogLevel::DISABLED); + FL_LOG(fl::LogLevel::FATAL) << "log-fatal"; + EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("log-fatal"))); + EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("log-fatal"))); + + std::cerr.rdbuf(origStderrBuffer); + + Logging::setMaxLoggingLevel(fl::LogLevel::FATAL); 
+ EXPECT_DEATH_IF_SUPPORTED( + {FL_LOG(fl::LogLevel::FATAL) << "log-fatal"; + }, + "" + ); } } // namespace int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/common/SerializationTest.cpp b/flashlight/fl/test/common/SerializationTest.cpp index 0fe28b0..00cc2c4 100644 --- a/flashlight/fl/test/common/SerializationTest.cpp +++ b/flashlight/fl/test/common/SerializationTest.cpp @@ -19,155 +19,159 @@ // ========== utility functions ========== -template +template std::string saveToString(const T& t) { - std::ostringstream oss(std::ios::binary); - fl::save(oss, t); - return oss.str(); + std::ostringstream oss(std::ios::binary); + fl::save(oss, t); + return oss.str(); } -template +template void loadFromString(const std::string& s, T& t) { - std::istringstream iss(s, std::ios::binary); - fl::load(iss, t); + std::istringstream iss(s, std::ios::binary); + fl::load(iss, t); } -template +template void checkRoundTrip(const T& t0) { - T t1; - loadFromString(saveToString(t0), t1); - ASSERT_EQ(t0, t1); + T t1; + loadFromString(saveToString(t0), t1); + ASSERT_EQ(t0, t1); } // ========== basic serialization of structs ========== struct Basic { - int x; - double y; - std::string s; - std::vector v; + int x; + double y; + std::string s; + std::vector v; - FL_SAVE_LOAD(x, y, s, v) + FL_SAVE_LOAD(x, y, s, v) Basic() = default; - Basic() = default; + Basic(int x, double y, std::string s, std::vector v) : x(x), + y(y), + s(std::move(s)), + v(std::move(v)) {} - Basic(int x, double y, std::string s, std::vector v) - : x(x), y(y), s(std::move(s)), v(std::move(v)) {} - - bool operator==(const Basic& o) const { - return std::tie(x, y, s, v) == std::tie(o.x, o.y, o.s, o.v); - } + bool operator==(const Basic& o) const { + return std::tie(x, y, s, v) == std::tie(o.x, o.y, o.s, o.v); + } }; TEST(SerializationTest, Basic) 
{ - checkRoundTrip(Basic{3, 5.5, "asdf", {2, 4, 6}}); + checkRoundTrip(Basic{3, 5.5, "asdf", {2, 4, 6}}); } // ========== versioning compatibility =========== struct BasicV1 { - int x; - double y; - float z{-1.0f}; - std::string s; - std::vector v; - - FL_SAVE_LOAD(x, y, fl::versioned(z, 1), s, v) - - BasicV1() = default; - - BasicV1(int x, double y, float z, std::string s, std::vector v) - : x(x), y(y), z(z), s(std::move(s)), v(std::move(v)) {} - - bool operator==(const BasicV1& o) const { - return std::tie(x, y, z, s, v) == std::tie(o.x, o.y, o.z, o.s, o.v); - } + int x; + double y; + float z{-1.0f}; + std::string s; + std::vector v; + + FL_SAVE_LOAD(x, y, fl::versioned(z, 1), s, v) BasicV1() = default; + + BasicV1(int x, double y, float z, std::string s, std::vector v) : x(x), + y(y), + z(z), + s(std::move(s)), + v(std::move(v)) {} + + bool operator==(const BasicV1& o) const { + return std::tie(x, y, z, s, v) == std::tie(o.x, o.y, o.z, o.s, o.v); + } }; CEREAL_CLASS_VERSION(BasicV1, 1); TEST(SerializationTest, Versions) { - checkRoundTrip(BasicV1{3, 5.5, 1.5f, "asdf", {2, 4, 6}}); - - Basic v0{3, 5.5, "asdf", {2, 4, 6}}; - BasicV1 v1; - loadFromString(saveToString(v0), v1); - ASSERT_EQ(v0.x, v1.x); - ASSERT_EQ(v0.y, v1.y); - ASSERT_EQ(-1.0f, v1.z); - ASSERT_EQ(v0.s, v1.s); - ASSERT_EQ(v0.v, v1.v); + checkRoundTrip(BasicV1{3, 5.5, 1.5f, "asdf", {2, 4, 6}}); + + Basic v0{3, 5.5, "asdf", {2, 4, 6}}; + BasicV1 v1; + loadFromString(saveToString(v0), v1); + ASSERT_EQ(v0.x, v1.x); + ASSERT_EQ(v0.y, v1.y); + ASSERT_EQ(-1.0f, v1.z); + ASSERT_EQ(v0.s, v1.s); + ASSERT_EQ(v0.v, v1.v); } // sanity check for testing -- useless in practice struct NestedVersioned { - int x{1}; - int y{2}; - int z{3}; - int w{4}; - - FL_SAVE_LOAD( - fl::versioned(fl::versioned(x, 0, 2), 1, 3), /* 2 \in [1, 2] */ - fl::versioned(fl::versioned(y, 3), 0, 5), /* 2 \not\in [3, 5] */ - fl::versioned(fl::versioned(z, 2, 2), 1), /* 2 \in [2, 2] */ - fl::versioned(fl::versioned(w, 0), 0, 1)) /* 2 \not\in 
[0, 1] */ + int x{1}; + int y{2}; + int z{3}; + int w{4}; + + FL_SAVE_LOAD( + fl::versioned(fl::versioned(x, 0, 2), 1, 3), /* 2 \in [1, 2] */ + fl::versioned(fl::versioned(y, 3), 0, 5), /* 2 \not\in [3, 5] */ + fl::versioned(fl::versioned(z, 2, 2), 1), /* 2 \in [2, 2] */ + fl::versioned(fl::versioned(w, 0), 0, 1) + ) /* 2 \not\in [0, 1] */ }; CEREAL_CLASS_VERSION(NestedVersioned, 2) TEST(SerializationTest, NestedVersioned) { - NestedVersioned t0; - t0.x = 5; - t0.y = 6; - t0.z = 7; - t0.w = 8; - NestedVersioned t1; - loadFromString(saveToString(t0), t1); - ASSERT_EQ(t1.x, 5); - ASSERT_EQ(t1.y, 2); - ASSERT_EQ(t1.z, 7); - ASSERT_EQ(t1.w, 4); + NestedVersioned t0; + t0.x = 5; + t0.y = 6; + t0.z = 7; + t0.w = 8; + NestedVersioned t1; + loadFromString(saveToString(t0), t1); + ASSERT_EQ(t1.x, 5); + ASSERT_EQ(t1.y, 2); + ASSERT_EQ(t1.z, 7); + ASSERT_EQ(t1.w, 4); } // ========== conversions using serializeAs ========== struct SerializeIntAsFloat { - int x; + int x; - FL_SAVE_LOAD(fl::serializeAs(x)) + FL_SAVE_LOAD(fl::serializeAs(x)) - bool operator==(const SerializeIntAsFloat& o) const { - return x == o.x; - } + bool operator==(const SerializeIntAsFloat& o) const { + return x == o.x; + } }; struct SerializeFloatAsInt { - float x; + float x; - FL_SAVE_LOAD(fl::serializeAs(x)) + FL_SAVE_LOAD(fl::serializeAs(x)) }; struct SerializeLongAsSqrt { - long x; - - FL_SAVE_LOAD(fl::serializeAs( - x, - [](const long& x) -> double { return std::sqrt(x); }, - [](double y) -> long { return std::lround(y * y); })) - - bool operator==(const SerializeLongAsSqrt& o) const { - return x == o.x; - } + long x; + + FL_SAVE_LOAD( + fl::serializeAs( + x, + [](const long& x) -> double { return std::sqrt(x); }, + [](double y) -> long { return std::lround(y * y); }) + ) + + bool operator==(const SerializeLongAsSqrt& o) const { + return x == o.x; + } }; TEST(SerializationTest, Conversions) { - checkRoundTrip(SerializeIntAsFloat{12345}); + checkRoundTrip(SerializeIntAsFloat{12345}); - 
SerializeFloatAsInt fi{3.3f}; - loadFromString(saveToString(fi), fi); - ASSERT_EQ(fi.x, 3.0f); // truncated due to static_cast + SerializeFloatAsInt fi{3.3f}; + loadFromString(saveToString(fi), fi); + ASSERT_EQ(fi.x, 3.0f); // truncated due to static_cast - checkRoundTrip(SerializeLongAsSqrt{13579}); + checkRoundTrip(SerializeLongAsSqrt{13579}); } // ========== passing temporary rvalues to FL_SAVE_LOAD ========== @@ -175,72 +179,72 @@ TEST(SerializationTest, Conversions) { // this will compile. saving will write x + 1 as expected, but // loading will read to a temporary which is discarded. struct SerializeNoOpTemporary { - int x{0}; + int x{0}; - FL_SAVE_LOAD(x + 1) + FL_SAVE_LOAD(x + 1) }; struct SerializeNoOpTemporaryInspect { - int x; + int x; - FL_SAVE_LOAD(x) + FL_SAVE_LOAD(x) }; TEST(SerializationTest, TemporaryNoOp) { - SerializeNoOpTemporary t0; - t0.x = 3; - auto s = saveToString(t0); // saves 4 + SerializeNoOpTemporary t0; + t0.x = 3; + auto s = saveToString(t0); // saves 4 - SerializeNoOpTemporaryInspect ins; - loadFromString(s, ins); // loads 4 - ASSERT_EQ(ins.x, 4); + SerializeNoOpTemporaryInspect ins; + loadFromString(s, ins); // loads 4 + ASSERT_EQ(ins.x, 4); - SerializeNoOpTemporary t1; - loadFromString(s, t1); // doesn't actually load - ASSERT_EQ(t1.x, 0); + SerializeNoOpTemporary t1; + loadFromString(s, t1); // doesn't actually load + ASSERT_EQ(t1.x, 0); } // multiplies by 10 before saving, adds 3 before loading -template +template struct WeirdTransform { - T&& x; - - template - void save(Archive& ar) const { - ar(x * 10); - } - - template - void load(Archive& ar) { - std::decay_t y; - ar(y); - x = y + 3; - } + T && x; + + template + void save(Archive& ar) const { + ar(x * 10); + } + + template + void load(Archive& ar) { + std::decay_t y; + ar(y); + x = y + 3; + } }; -template +template WeirdTransform weirdTransform(T&& t) { - return WeirdTransform{std::forward(t)}; + return WeirdTransform{std::forward(t)}; } struct SerializeViaTemporary { - int 
x; - int y; - int z; + int x; + int y; + int z; - FL_SAVE_LOAD(weirdTransform(x), y, weirdTransform(z)) + FL_SAVE_LOAD(weirdTransform(x), y, weirdTransform(z)) }; TEST(SerializationTest, TemporaryRvalues) { - SerializeViaTemporary t{5, 6, 7}; - loadFromString(saveToString(t), t); - ASSERT_EQ(t.x, 53); - ASSERT_EQ(t.y, 6); - ASSERT_EQ(t.z, 73); + SerializeViaTemporary t{5, 6, 7}; + loadFromString(saveToString(t), t); + ASSERT_EQ(t.x, 53); + ASSERT_EQ(t.y, 6); + ASSERT_EQ(t.z, 73); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/common/UtilsTest.cpp b/flashlight/fl/test/common/UtilsTest.cpp index 5a56533..055181e 100644 --- a/flashlight/fl/test/common/UtilsTest.cpp +++ b/flashlight/fl/test/common/UtilsTest.cpp @@ -17,96 +17,102 @@ using namespace fl; static std::function makeSucceedsAfterIters(int iters) { - auto state = std::make_shared(0); - return [state, iters]() { - if (++*state >= iters) { - return 42; - } else { - throw std::runtime_error("bleh"); - } - }; + auto state = std::make_shared(0); + return [state, iters]() { + if(++*state >= iters) { + return 42; + } else { + throw std::runtime_error("bleh"); + } + }; } static std::function makeSucceedsAfterMs(double ms) { - using namespace std::chrono; - auto state = std::make_shared>(); - return [state, ms]() { - auto now = steady_clock::now(); - if (state->time_since_epoch().count() == 0) { - *state = now; - } - if (now - *state >= duration(ms)) { - return 42; - } else { - throw std::runtime_error("bleh"); - } - }; + using namespace std::chrono; + auto state = std::make_shared>(); + return [state, ms]() { + auto now = steady_clock::now(); + if(state->time_since_epoch().count() == 0) { + *state = now; + } + if(now - *state >= duration(ms)) { + return 42; + } else { + throw std::runtime_error("bleh"); + } + }; } -template +template 
std::future::type> retryAsync( std::chrono::duration initial, double factor, int64_t iters, - Fn f) { - return std::async(std::launch::async, [=]() { - return retryWithBackoff(initial, factor, iters, f); - }); + Fn f +) { + return std::async( + std::launch::async, + [ = ]() { + return retryWithBackoff(initial, factor, iters, f); + } + ); } TEST(SystemTest, RetryWithBackoff) { - auto alwaysSucceeds = []() { return 42; }; - auto alwaysFails = []() -> int { throw std::runtime_error("bleh"); }; - - std::vector> goods; - std::vector> bads; - std::vector> invalids; - - auto ms0 = std::chrono::milliseconds(0); - auto ms50 = std::chrono::milliseconds(50); - - goods.push_back(retryAsync(ms0, 1.0, 5, alwaysSucceeds)); - goods.push_back(retryAsync(ms50, 2.0, 5, alwaysSucceeds)); - - bads.push_back(retryAsync(ms0, 1.0, 5, alwaysFails)); - bads.push_back(retryAsync(ms50, 2.0, 5, alwaysFails)); - - bads.push_back(retryAsync(ms0, 1.0, 5, makeSucceedsAfterIters(6))); - bads.push_back(retryAsync(ms50, 2.0, 5, makeSucceedsAfterIters(6))); - goods.push_back(retryAsync(ms0, 1.0, 5, makeSucceedsAfterIters(5))); - goods.push_back(retryAsync(ms50, 2.0, 5, makeSucceedsAfterIters(5))); - - bads.push_back(retryAsync(ms0, 1.0, 5, makeSucceedsAfterMs(999))); - bads.push_back(retryAsync(ms50, 2.0, 5, makeSucceedsAfterMs(999))); - bads.push_back(retryAsync(ms0, 1.0, 5, makeSucceedsAfterMs(500))); - goods.push_back(retryAsync(ms50, 2.0, 5, makeSucceedsAfterMs(500))); - - invalids.push_back(retryAsync(-ms50, 2.0, 5, alwaysSucceeds)); - invalids.push_back(retryAsync(ms50, -1.0, 5, alwaysSucceeds)); - invalids.push_back(retryAsync(ms50, 2.0, 0, alwaysSucceeds)); - invalids.push_back(retryAsync(ms50, 2.0, -1, alwaysSucceeds)); - - for (auto& fut : goods) { - ASSERT_EQ(fut.get(), 42); - } - for (auto& fut : bads) { - ASSERT_THROW(fut.get(), std::runtime_error); - } - for (auto& fut : invalids) { - ASSERT_THROW(fut.get(), std::invalid_argument); - } - - // check special case promise / future - auto 
alwaysSucceedsVoid = []() -> void {}; - auto alwaysFailsVoid = []() -> void { throw std::runtime_error("bleh"); }; - - retryAsync(ms0, 1.0, 5, alwaysSucceedsVoid).get(); - ASSERT_THROW( - retryAsync(ms0, 1.0, 5, alwaysFailsVoid).get(), std::runtime_error); + auto alwaysSucceeds = []() { return 42; }; + auto alwaysFails = []() -> int { throw std::runtime_error("bleh"); }; + + std::vector> goods; + std::vector> bads; + std::vector> invalids; + + auto ms0 = std::chrono::milliseconds(0); + auto ms50 = std::chrono::milliseconds(50); + + goods.push_back(retryAsync(ms0, 1.0, 5, alwaysSucceeds)); + goods.push_back(retryAsync(ms50, 2.0, 5, alwaysSucceeds)); + + bads.push_back(retryAsync(ms0, 1.0, 5, alwaysFails)); + bads.push_back(retryAsync(ms50, 2.0, 5, alwaysFails)); + + bads.push_back(retryAsync(ms0, 1.0, 5, makeSucceedsAfterIters(6))); + bads.push_back(retryAsync(ms50, 2.0, 5, makeSucceedsAfterIters(6))); + goods.push_back(retryAsync(ms0, 1.0, 5, makeSucceedsAfterIters(5))); + goods.push_back(retryAsync(ms50, 2.0, 5, makeSucceedsAfterIters(5))); + + bads.push_back(retryAsync(ms0, 1.0, 5, makeSucceedsAfterMs(999))); + bads.push_back(retryAsync(ms50, 2.0, 5, makeSucceedsAfterMs(999))); + bads.push_back(retryAsync(ms0, 1.0, 5, makeSucceedsAfterMs(500))); + goods.push_back(retryAsync(ms50, 2.0, 5, makeSucceedsAfterMs(500))); + + invalids.push_back(retryAsync(-ms50, 2.0, 5, alwaysSucceeds)); + invalids.push_back(retryAsync(ms50, -1.0, 5, alwaysSucceeds)); + invalids.push_back(retryAsync(ms50, 2.0, 0, alwaysSucceeds)); + invalids.push_back(retryAsync(ms50, 2.0, -1, alwaysSucceeds)); + + for(auto& fut : goods) { + ASSERT_EQ(fut.get(), 42); + } + for(auto& fut : bads) { + ASSERT_THROW(fut.get(), std::runtime_error); + } + for(auto& fut : invalids) { + ASSERT_THROW(fut.get(), std::invalid_argument); + } + + // check special case promise / future + auto alwaysSucceedsVoid = []() -> void {}; + auto alwaysFailsVoid = []() -> void { throw std::runtime_error("bleh"); }; + + 
retryAsync(ms0, 1.0, 5, alwaysSucceedsVoid).get(); + ASSERT_THROW( + retryAsync(ms0, 1.0, 5, alwaysFailsVoid).get(), + std::runtime_error + ); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/contrib/modules/ContribModuleTest.cpp b/flashlight/fl/test/contrib/modules/ContribModuleTest.cpp index d5a0d90..54b05c5 100644 --- a/flashlight/fl/test/contrib/modules/ContribModuleTest.cpp +++ b/flashlight/fl/test/contrib/modules/ContribModuleTest.cpp @@ -19,476 +19,487 @@ using namespace fl; namespace { class ContribModuleTestF16 : public ::testing::Test { - protected: - void SetUp() override { - // Ensures all operations will be in f16 - OptimMode::get().setOptimLevel(OptimLevel::O3); - } - - void TearDown() override { - OptimMode::get().setOptimLevel(OptimLevel::DEFAULT); - } +protected: + void SetUp() override { + // Ensures all operations will be in f16 + OptimMode::get().setOptimLevel(OptimLevel::O3); + } + + void TearDown() override { + OptimMode::get().setOptimLevel(OptimLevel::DEFAULT); + } }; } // namespace TEST(ContribModuleTest, ResidualFwd) { - auto conv = Conv2D(30, 50, 9, 7, 2, 3, 3, 2); - // bn is shared between both residual networks - auto bn = std::make_shared(2, 50, 0.0); - auto relu = ReLU(); - - int batchsize = 10; - auto input = Variable(fl::rand({120, 100, 30, batchsize}), false); - - auto outputConv = conv.forward(input); - auto outputBn = bn->forward(outputConv); - - Residual resModule1; - resModule1.add(Conv2D(conv)); // explicit copy - resModule1.add(bn); - resModule1.add(ReLU()); - resModule1.addShortcut(1, 3); - - auto output1 = resModule1.forward(input); - auto output1True = relu.forward(outputBn + outputConv); - ASSERT_TRUE(allClose(output1, output1True)); - - Residual resModule2; - resModule2.add(std::move(conv)); - resModule2.add(bn); - resModule2.add(ReLU()); 
- resModule2.addShortcut(1, 4); - resModule2.addShortcut(1, 3); - resModule2.addShortcut(2, 4); - - auto output2 = resModule2.forward(input); - auto output2True = - relu.forward(outputBn + outputConv) + outputBn + outputConv; - ASSERT_TRUE(allClose(output2, output2True)); + auto conv = Conv2D(30, 50, 9, 7, 2, 3, 3, 2); + // bn is shared between both residual networks + auto bn = std::make_shared(2, 50, 0.0); + auto relu = ReLU(); + + int batchsize = 10; + auto input = Variable(fl::rand({120, 100, 30, batchsize}), false); + + auto outputConv = conv.forward(input); + auto outputBn = bn->forward(outputConv); + + Residual resModule1; + resModule1.add(Conv2D(conv)); // explicit copy + resModule1.add(bn); + resModule1.add(ReLU()); + resModule1.addShortcut(1, 3); + + auto output1 = resModule1.forward(input); + auto output1True = relu.forward(outputBn + outputConv); + ASSERT_TRUE(allClose(output1, output1True)); + + Residual resModule2; + resModule2.add(std::move(conv)); + resModule2.add(bn); + resModule2.add(ReLU()); + resModule2.addShortcut(1, 4); + resModule2.addShortcut(1, 3); + resModule2.addShortcut(2, 4); + + auto output2 = resModule2.forward(input); + auto output2True = + relu.forward(outputBn + outputConv) + outputBn + outputConv; + ASSERT_TRUE(allClose(output2, output2True)); } TEST(ContribModuleTest, ResidualFwdWithProjection) { - const float proj1FwdScale = 0.24; - const float proj2FwdScale = 0.5; - const float linFwdScale = 0.3; - - auto linear1 = Linear(12, 8); - auto relu1 = ReLU(); - auto linear2 = Linear(8, 4); - auto relu2 = ReLU(); - auto linear3 = Linear(4, 4); - auto relu3 = ReLU(); - auto projection1 = Linear(8, 4); - auto projection2 = Linear(12, 4); - - auto input = Variable(fl::rand({12, 10, 3, 4}), false); - auto output1True = linear1.forward(input); - auto outputTrue = relu1.forward(output1True); - outputTrue = linear2.forward(outputTrue * linFwdScale); - outputTrue = relu2.forward( - (outputTrue + projection1.forward(output1True)) * 
proj1FwdScale); - outputTrue = (outputTrue + projection2.forward(input)) * proj2FwdScale; - outputTrue = linear3.forward(outputTrue); - outputTrue = relu3.forward(outputTrue) + outputTrue; - - auto resModule = Residual(); - resModule.add(std::move(linear1)); - resModule.add(std::move(relu1)); - resModule.add(std::move(linear2)); - resModule.addScale(3, linFwdScale); - resModule.add(std::move(relu2)); - resModule.addShortcut(1, 4, projection1); - resModule.addScale(4, proj1FwdScale); - resModule.add(std::move(linear3)); - resModule.addShortcut(0, 5, projection2); - resModule.addScale(5, proj2FwdScale); - resModule.add(std::move(relu3)); - resModule.addShortcut(5, 7); - - auto outputRes = resModule.forward(input); - ASSERT_TRUE(allClose(outputRes, outputTrue)); + const float proj1FwdScale = 0.24; + const float proj2FwdScale = 0.5; + const float linFwdScale = 0.3; + + auto linear1 = Linear(12, 8); + auto relu1 = ReLU(); + auto linear2 = Linear(8, 4); + auto relu2 = ReLU(); + auto linear3 = Linear(4, 4); + auto relu3 = ReLU(); + auto projection1 = Linear(8, 4); + auto projection2 = Linear(12, 4); + + auto input = Variable(fl::rand({12, 10, 3, 4}), false); + auto output1True = linear1.forward(input); + auto outputTrue = relu1.forward(output1True); + outputTrue = linear2.forward(outputTrue * linFwdScale); + outputTrue = relu2.forward( + (outputTrue + projection1.forward(output1True)) * proj1FwdScale + ); + outputTrue = (outputTrue + projection2.forward(input)) * proj2FwdScale; + outputTrue = linear3.forward(outputTrue); + outputTrue = relu3.forward(outputTrue) + outputTrue; + + auto resModule = Residual(); + resModule.add(std::move(linear1)); + resModule.add(std::move(relu1)); + resModule.add(std::move(linear2)); + resModule.addScale(3, linFwdScale); + resModule.add(std::move(relu2)); + resModule.addShortcut(1, 4, projection1); + resModule.addScale(4, proj1FwdScale); + resModule.add(std::move(linear3)); + resModule.addShortcut(0, 5, projection2); + resModule.addScale(5, 
proj2FwdScale); + resModule.add(std::move(relu3)); + resModule.addShortcut(5, 7); + + auto outputRes = resModule.forward(input); + ASSERT_TRUE(allClose(outputRes, outputTrue)); } TEST(ContribModuleTest, AsymmetricConv1DFwd) { - int batchsize = 10; - int timesteps = 120; - int c = 32; + int batchsize = 10; + int timesteps = 120; + int c = 32; - auto conv = AsymmetricConv1D(c, c, 5, 1, -1, 0, 1); // use only past - auto input = Variable(fl::rand({timesteps, 1, c, batchsize}), false); + auto conv = AsymmetricConv1D(c, c, 5, 1, -1, 0, 1); // use only past + auto input = Variable(fl::rand({timesteps, 1, c, batchsize}), false); - auto output = conv.forward(input); + auto output = conv.forward(input); - ASSERT_EQ(output.dim(0), timesteps); - ASSERT_EQ(output.dim(1), 1); - ASSERT_EQ(output.dim(2), c); + ASSERT_EQ(output.dim(0), timesteps); + ASSERT_EQ(output.dim(1), 1); + ASSERT_EQ(output.dim(2), c); - auto convFuture = AsymmetricConv1D(c, c, 5, 1, -1, 1, 1); // use only future - auto outputFuture = convFuture.forward(input); - ASSERT_EQ(outputFuture.dim(0), timesteps); - ASSERT_EQ(outputFuture.dim(1), 1); - ASSERT_EQ(outputFuture.dim(2), c); + auto convFuture = AsymmetricConv1D(c, c, 5, 1, -1, 1, 1); // use only future + auto outputFuture = convFuture.forward(input); + ASSERT_EQ(outputFuture.dim(0), timesteps); + ASSERT_EQ(outputFuture.dim(1), 1); + ASSERT_EQ(outputFuture.dim(2), c); - ASSERT_FALSE(allClose(output, outputFuture)); + ASSERT_FALSE(allClose(output, outputFuture)); } void transformerPadMaskFwd(bool isfp16) { - int timesteps = 10; - int c = 4; - int nheads = 2; - auto dtype = isfp16 ? 
fl::dtype::f16 : fl::dtype::f32; - - auto tr = - Transformer(c, c / nheads, c, nheads, timesteps, 0, 0, false, false); - auto input1 = Variable(fl::rand({c, timesteps, /* B = */ 1}, dtype), false); - auto input1NoPad = input1(fl::span, fl::range(0, timesteps / 2)); - auto input2 = Variable(fl::rand({c, timesteps, /* B = */ 1}, dtype), false); - auto input = fl::concatenate({input1, input2}, 2); - auto padMask = fl::full({timesteps, 2}, 1); - padMask(fl::iota({timesteps / 2}) + timesteps / 2, 0) = 0; - auto noPadMask = fl::full({timesteps, 2}, 1); - - auto output = tr.forward({input, Variable(padMask, false)}).front(); - auto outputNoPad = tr.forward({input, Variable(noPadMask, false)}).front(); - - ASSERT_EQ(output.dim(0), c); - ASSERT_EQ(output.dim(1), timesteps); - ASSERT_EQ(output.dim(2), 2); - - if (OptimMode::get().getOptimLevel() == OptimLevel::O3) { - ASSERT_EQ(outputNoPad.type(), input.type()); - } else { - ASSERT_EQ(outputNoPad.type(), fl::dtype::f32); // result is upcast - } - - auto output1 = tr.forward({input1NoPad, - Variable( - padMask(fl::range(0, timesteps / 2))( - fl::span, fl::range(0, 1)), - false)}) - .front(); - auto output2 = - tr.forward({input2, Variable(padMask(fl::span, fl::range(1, 2)), false)}) - .front(); - ASSERT_TRUE(allClose( - output.tensor()(fl::span, fl::span, fl::range(1, 2)), output2.tensor())); - ASSERT_TRUE(allClose( - outputNoPad.tensor()(fl::span, fl::span, fl::range(1, 2)), - output2.tensor())); - ASSERT_TRUE(allClose( - output.tensor()(fl::span, fl::iota({timesteps / 2}), fl::range(0, 1)), - output1.tensor())); - ASSERT_FALSE(allClose( - outputNoPad.tensor()( - fl::span, fl::iota({timesteps / 2}), fl::range(0, 1)), - output1.tensor())); + int timesteps = 10; + int c = 4; + int nheads = 2; + auto dtype = isfp16 ? 
fl::dtype::f16 : fl::dtype::f32; + + auto tr = + Transformer(c, c / nheads, c, nheads, timesteps, 0, 0, false, false); + auto input1 = Variable(fl::rand({c, timesteps, /* B = */ 1}, dtype), false); + auto input1NoPad = input1(fl::span, fl::range(0, timesteps / 2)); + auto input2 = Variable(fl::rand({c, timesteps, /* B = */ 1}, dtype), false); + auto input = fl::concatenate({input1, input2}, 2); + auto padMask = fl::full({timesteps, 2}, 1); + padMask(fl::iota({timesteps / 2}) + timesteps / 2, 0) = 0; + auto noPadMask = fl::full({timesteps, 2}, 1); + + auto output = tr.forward({input, Variable(padMask, false)}).front(); + auto outputNoPad = tr.forward({input, Variable(noPadMask, false)}).front(); + + ASSERT_EQ(output.dim(0), c); + ASSERT_EQ(output.dim(1), timesteps); + ASSERT_EQ(output.dim(2), 2); + + if(OptimMode::get().getOptimLevel() == OptimLevel::O3) { + ASSERT_EQ(outputNoPad.type(), input.type()); + } else { + ASSERT_EQ(outputNoPad.type(), fl::dtype::f32); // result is upcast + } + + auto output1 = tr.forward( + {input1NoPad, + Variable( + padMask(fl::range(0, timesteps / 2))( + fl::span, + fl::range(0, 1) + ), + false + )} + ) + .front(); + auto output2 = + tr.forward({input2, Variable(padMask(fl::span, fl::range(1, 2)), false)}) + .front(); + ASSERT_TRUE(allClose( + output.tensor()(fl::span, fl::span, fl::range(1, 2)), output2.tensor())); + ASSERT_TRUE(allClose( + outputNoPad.tensor()(fl::span, fl::span, fl::range(1, 2)), + output2.tensor())); + ASSERT_TRUE(allClose( + output.tensor()(fl::span, fl::iota({timesteps / 2}), fl::range(0, 1)), + output1.tensor())); + ASSERT_FALSE(allClose( + outputNoPad.tensor()( + fl::span, fl::iota({timesteps / 2}), fl::range(0, 1)), + output1.tensor())); } TEST(ContribModuleTest, TransformerPadMaskFwd) { - transformerPadMaskFwd(false); + transformerPadMaskFwd(false); } TEST_F(ContribModuleTestF16, TransformerPadMaskFwd16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - 
transformerPadMaskFwd(true); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + transformerPadMaskFwd(true); } void transformerFwd(bool isfp16) { - int batchsize = 10; - int timesteps = 120; - int c = 32; - int nheads = 4; - auto dtype = isfp16 ? fl::dtype::f16 : fl::dtype::f32; - - auto tr = - Transformer(c, c / nheads, c, nheads, timesteps, 0.2, 0.1, true, false); - auto input = Variable(fl::rand({c, timesteps, batchsize}, dtype), false); - - fl::Variable padMask; - auto output = tr.forward({input, padMask}); - if (OptimMode::get().getOptimLevel() == OptimLevel::O3) { - ASSERT_EQ(output[0].type(), input.type()); - } else { - ASSERT_EQ(output[0].type(), fl::dtype::f32); // result is upcast - } - - ASSERT_EQ(output[0].dim(0), c); - ASSERT_EQ(output[0].dim(1), timesteps); - ASSERT_EQ(output[0].dim(2), batchsize); - - tr.setDropout(0); - tr.setLayerDropout(0); - auto output1 = tr.forward({input, padMask}).front(); - auto output2 = tr.forward({input, padMask}).front(); - ASSERT_TRUE(allClose(output1, output2, 1E-7)); + int batchsize = 10; + int timesteps = 120; + int c = 32; + int nheads = 4; + auto dtype = isfp16 ? 
fl::dtype::f16 : fl::dtype::f32; + + auto tr = + Transformer(c, c / nheads, c, nheads, timesteps, 0.2, 0.1, true, false); + auto input = Variable(fl::rand({c, timesteps, batchsize}, dtype), false); + + fl::Variable padMask; + auto output = tr.forward({input, padMask}); + if(OptimMode::get().getOptimLevel() == OptimLevel::O3) { + ASSERT_EQ(output[0].type(), input.type()); + } else { + ASSERT_EQ(output[0].type(), fl::dtype::f32); // result is upcast + } + + ASSERT_EQ(output[0].dim(0), c); + ASSERT_EQ(output[0].dim(1), timesteps); + ASSERT_EQ(output[0].dim(2), batchsize); + + tr.setDropout(0); + tr.setLayerDropout(0); + auto output1 = tr.forward({input, padMask}).front(); + auto output2 = tr.forward({input, padMask}).front(); + ASSERT_TRUE(allClose(output1, output2, 1E-7)); } TEST(ContribModuleTest, TransformerFwd) { - transformerFwd(false); + transformerFwd(false); } TEST_F(ContribModuleTestF16, TransformerFwdF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - transformerFwd(true); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + transformerFwd(true); } void conformerFwd(bool isfp16) { - int batchsize = 10; - int timesteps = 120; - int c = 32; - int nheads = 4; - auto dtype = isfp16 ? fl::dtype::f16 : fl::dtype::f32; - - auto tr = Conformer(c, c / nheads, c, nheads, timesteps, 33, 0.2, 0.1); - auto input = Variable(fl::rand({c, timesteps, batchsize}, dtype), false); - - auto output = tr.forward({input, Variable()}); - if (OptimMode::get().getOptimLevel() == OptimLevel::O3) { - ASSERT_EQ(output[0].type(), input.type()); - } else { - ASSERT_EQ(output[0].type(), fl::dtype::f32); // result is upcast - } - - ASSERT_EQ(output[0].dim(0), c); - ASSERT_EQ(output[0].dim(1), timesteps); - ASSERT_EQ(output[0].dim(2), batchsize); + int batchsize = 10; + int timesteps = 120; + int c = 32; + int nheads = 4; + auto dtype = isfp16 ? 
fl::dtype::f16 : fl::dtype::f32; + + auto tr = Conformer(c, c / nheads, c, nheads, timesteps, 33, 0.2, 0.1); + auto input = Variable(fl::rand({c, timesteps, batchsize}, dtype), false); + + auto output = tr.forward({input, Variable()}); + if(OptimMode::get().getOptimLevel() == OptimLevel::O3) { + ASSERT_EQ(output[0].type(), input.type()); + } else { + ASSERT_EQ(output[0].type(), fl::dtype::f32); // result is upcast + } + + ASSERT_EQ(output[0].dim(0), c); + ASSERT_EQ(output[0].dim(1), timesteps); + ASSERT_EQ(output[0].dim(2), batchsize); } TEST(ContribModuleTest, ConformerFwd) { - conformerFwd(false); + conformerFwd(false); } TEST_F(ContribModuleTestF16, ConformerFwdF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - conformerFwd(true); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + conformerFwd(true); } void positionEmbeddingFwd(bool isfp16) { - int batchsize = 10; - int timesteps = 120; - int csz = 256; - auto dtype = isfp16 ? fl::dtype::f16 : fl::dtype::f32; + int batchsize = 10; + int timesteps = 120; + int csz = 256; + auto dtype = isfp16 ? 
fl::dtype::f16 : fl::dtype::f32; - auto posemb = PositionEmbedding(csz, timesteps, 0.5); - auto input = Variable(fl::rand({csz, timesteps, batchsize}, dtype), false); + auto posemb = PositionEmbedding(csz, timesteps, 0.5); + auto input = Variable(fl::rand({csz, timesteps, batchsize}, dtype), false); - auto output = posemb.forward({input}); + auto output = posemb.forward({input}); - ASSERT_EQ(output[0].dim(0), csz); - ASSERT_EQ(output[0].dim(1), timesteps); - ASSERT_EQ(output[0].dim(2), batchsize); + ASSERT_EQ(output[0].dim(0), csz); + ASSERT_EQ(output[0].dim(1), timesteps); + ASSERT_EQ(output[0].dim(2), batchsize); - ASSERT_FALSE(allClose(output[0], input)); + ASSERT_FALSE(allClose(output[0], input)); } TEST(ContribModuleTest, PositionEmbeddingFwd) { - positionEmbeddingFwd(false); + positionEmbeddingFwd(false); } TEST_F(ContribModuleTestF16, PositionEmbeddingFwdF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - positionEmbeddingFwd(true); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + positionEmbeddingFwd(true); } void sinusoidalPositionEmbeddingFwd(bool isfp16) { - int batchsize = 10; - int timesteps = 120; - int csz = 256; - auto dtype = isfp16 ? fl::dtype::f16 : fl::dtype::f32; - - auto posemb = SinusoidalPositionEmbedding(csz, /* inputScale = */ 2.); - auto input = - Variable(fl::rand({csz, timesteps, batchsize, 1}, dtype), false) - 0.5; - - auto output = posemb.forward({input}); - - ASSERT_EQ(output[0].dim(0), csz); - ASSERT_EQ(output[0].dim(1), timesteps); - ASSERT_EQ(output[0].dim(2), batchsize); - auto castOutput = output[0].tensor(); - if (isfp16) { - castOutput = output[0].astype(fl::dtype::f32).tensor(); - } - ASSERT_TRUE((fl::amax(castOutput, {0})).scalar() <= 2); - ASSERT_TRUE((fl::amin(castOutput, {0})).scalar() >= -2); + int batchsize = 10; + int timesteps = 120; + int csz = 256; + auto dtype = isfp16 ? 
fl::dtype::f16 : fl::dtype::f32; + + auto posemb = SinusoidalPositionEmbedding(csz, /* inputScale = */ 2.); + auto input = + Variable(fl::rand({csz, timesteps, batchsize, 1}, dtype), false) - 0.5; + + auto output = posemb.forward({input}); + + ASSERT_EQ(output[0].dim(0), csz); + ASSERT_EQ(output[0].dim(1), timesteps); + ASSERT_EQ(output[0].dim(2), batchsize); + auto castOutput = output[0].tensor(); + if(isfp16) { + castOutput = output[0].astype(fl::dtype::f32).tensor(); + } + ASSERT_TRUE((fl::amax(castOutput, {0})).scalar() <= 2); + ASSERT_TRUE((fl::amin(castOutput, {0})).scalar() >= -2); } TEST(ContribModuleTest, SinusoidalPositionEmbeddingFwd) { - sinusoidalPositionEmbeddingFwd(false); + sinusoidalPositionEmbeddingFwd(false); } TEST_F(ContribModuleTestF16, SinusoidalPositionEmbeddingFwdF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - sinusoidalPositionEmbeddingFwd(true); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + sinusoidalPositionEmbeddingFwd(true); } TEST(ContribModuleTest, AdaptiveEmbedding) { - std::vector values = {1, 4, 6, 2, 12, 7, 4, 21, 22, 18, 3, 23}; - int T = 6, B = 2, dim = 128; - auto input = Variable(Tensor::fromVector({T, B}, values), false); - std::vector cutoff = {5, 10, 25}; - auto emb = AdaptiveEmbedding(dim, cutoff); - auto output = emb.forward(input); - - ASSERT_EQ(output.dim(0), dim); - ASSERT_EQ(output.dim(1), T); - ASSERT_EQ(output.dim(2), B); + std::vector values = {1, 4, 6, 2, 12, 7, 4, 21, 22, 18, 3, 23}; + int T = 6, B = 2, dim = 128; + auto input = Variable(Tensor::fromVector({T, B}, values), false); + std::vector cutoff = {5, 10, 25}; + auto emb = AdaptiveEmbedding(dim, cutoff); + auto output = emb.forward(input); + + ASSERT_EQ(output.dim(0), dim); + ASSERT_EQ(output.dim(1), T); + ASSERT_EQ(output.dim(2), B); } void tdsFwd(bool isfp16) { - int batchsize = 10; - int timesteps = 120; - int w = 4; - int c = 10; + int 
batchsize = 10; + int timesteps = 120; + int w = 4; + int c = 10; - auto tds = TDSBlock(c, 9, w); - auto dtype = isfp16 ? fl::dtype::f16 : fl::dtype::f32; - auto input = Variable(fl::rand({timesteps, w, c, batchsize}, dtype), false); + auto tds = TDSBlock(c, 9, w); + auto dtype = isfp16 ? fl::dtype::f16 : fl::dtype::f32; + auto input = Variable(fl::rand({timesteps, w, c, batchsize}, dtype), false); - auto output = tds.forward({input})[0]; + auto output = tds.forward({input})[0]; - ASSERT_EQ(output.dim(0), timesteps); - ASSERT_EQ(output.dim(1), w); - ASSERT_EQ(output.dim(2), c); - ASSERT_EQ(output.type(), input.type()); + ASSERT_EQ(output.dim(0), timesteps); + ASSERT_EQ(output.dim(1), w); + ASSERT_EQ(output.dim(2), c); + ASSERT_EQ(output.type(), input.type()); } TEST(ContribModuleTest, TDSFwd) { - tdsFwd(false); + tdsFwd(false); } TEST_F(ContribModuleTestF16, TDSFwdF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - tdsFwd(true); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + tdsFwd(true); } void streamingTDSFwd(bool isfp16) { - int batchsize = 10; - int timesteps = 120; - int w = 4; - int c = 10; - int kw = 9; - int rPad = 3; - - auto stds = - TDSBlock(c, kw, w, 0 /* dropout */, 0 /* innerLinearDim */, rPad, true); - auto dtype = isfp16 ? fl::dtype::f16 : fl::dtype::f32; - auto input = Variable(fl::rand({timesteps, w, c, batchsize}, dtype), false); - - auto output = stds.forward({input})[0]; - - ASSERT_EQ(output.dim(0), timesteps); - ASSERT_EQ(output.dim(1), w); - ASSERT_EQ(output.dim(2), c); - ASSERT_EQ(output.type(), input.type()); + int batchsize = 10; + int timesteps = 120; + int w = 4; + int c = 10; + int kw = 9; + int rPad = 3; + + auto stds = + TDSBlock(c, kw, w, 0 /* dropout */, 0 /* innerLinearDim */, rPad, true); + auto dtype = isfp16 ? 
fl::dtype::f16 : fl::dtype::f32; + auto input = Variable(fl::rand({timesteps, w, c, batchsize}, dtype), false); + + auto output = stds.forward({input})[0]; + + ASSERT_EQ(output.dim(0), timesteps); + ASSERT_EQ(output.dim(1), w); + ASSERT_EQ(output.dim(2), c); + ASSERT_EQ(output.type(), input.type()); } TEST(ContribModuleTest, StreamingTDSFwd) { - streamingTDSFwd(false); + streamingTDSFwd(false); } TEST_F(ContribModuleTestF16, StreamingTDSFwdF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - streamingTDSFwd(true); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + streamingTDSFwd(true); } TEST(ContribModuleTest, SpecAugmentFwd) { - SpecAugment specAug(0, 27, 2, 100, 0.2, 2); - int T = 512, F = 80; - auto input = Variable(fl::rand({T, F}), false); - - specAug.eval(); - ASSERT_TRUE(fl::allClose(input, specAug(input))); - - specAug.train(); - auto output = specAug(input); - ASSERT_FALSE(fl::allClose(input, output)); - - // Every value of output is either 0 or input - for (int t = 0; t < T; ++t) { - for (int f = 0; f < F; ++f) { - auto o = output.tensor()(t, f).scalar(); - auto i = input.tensor()(t, f).scalar(); - ASSERT_TRUE(o == i || o == 0); - } - } - - // non-zero time frames are masked - int tZeros = 0; - for (int t = 0; t < T; ++t) { - auto curOutSlice = output.tensor()(t); - tZeros = fl::all(curOutSlice == 0).asScalar() ? tZeros + 1 : tZeros; - } - ASSERT_GT(tZeros, 0); - - // non-zero frequency channels are masked - int fZeros = 0; - for (int f = 0; f < F; ++f) { - auto curOutSlice = output.tensor()(fl::span, f); - fZeros = fl::all(curOutSlice == 0).asScalar() ? 
fZeros + 1 : fZeros; - } - ASSERT_GT(fZeros, 0); -} + SpecAugment specAug(0, 27, 2, 100, 0.2, 2); + int T = 512, F = 80; + auto input = Variable(fl::rand({T, F}), false); + + specAug.eval(); + ASSERT_TRUE(fl::allClose(input, specAug(input))); -void computeRawWavSpecAug(bool isfp16, float epsilon) { - // no time, only freq masking - for (int nmask = 1; nmask < 3; nmask++) { - RawWavSpecAugment specAug( - 0, 1, nmask, 0, 0, 0, 1, 2000, 6000, 16000, 20000); specAug.train(); + auto output = specAug(input); + ASSERT_FALSE(fl::allClose(input, output)); + + // Every value of output is either 0 or input + for(int t = 0; t < T; ++t) { + for(int f = 0; f < F; ++f) { + auto o = output.tensor()(t, f).scalar(); + auto i = input.tensor()(t, f).scalar(); + ASSERT_TRUE(o == i || o == 0); + } + } - int T = 300, C = 3, B = 4; - auto time = 2 * M_PI * fl::iota({T}) / 16000; - auto finalWav = fl::sin(time * 500) + fl::sin(time * 1000) + - fl::sin(time * 7000) + fl::sin(time * 7500); - auto inputWav = finalWav + fl::sin(time * 3000) + fl::sin(time * 4000) + - fl::sin(time * 5000); - inputWav = fl::tile(inputWav, {1, C, B}); - finalWav = fl::tile(finalWav, {1, C, B}); - if (isfp16) { - inputWav = inputWav.astype(fl::dtype::f16); - finalWav = finalWav.astype(fl::dtype::f16); + // non-zero time frames are masked + int tZeros = 0; + for(int t = 0; t < T; ++t) { + auto curOutSlice = output.tensor()(t); + tZeros = fl::all(curOutSlice == 0).asScalar() ? 
tZeros + 1 : tZeros; } + ASSERT_GT(tZeros, 0); - auto filteredWav = specAug(fl::Variable(inputWav, false)); - // compare middle of filtered wave to avoid edge artifacts comparison - int halfKernelWidth = 63; - ASSERT_TRUE(fl::allClose( - fl::Variable( - finalWav(fl::range(halfKernelWidth, T - halfKernelWidth)), false), - filteredWav(fl::range(halfKernelWidth, T - halfKernelWidth)), - epsilon)); - } + // non-zero frequency channels are masked + int fZeros = 0; + for(int f = 0; f < F; ++f) { + auto curOutSlice = output.tensor()(fl::span, f); + fZeros = fl::all(curOutSlice == 0).asScalar() ? fZeros + 1 : fZeros; + } + ASSERT_GT(fZeros, 0); +} + +void computeRawWavSpecAug(bool isfp16, float epsilon) { + // no time, only freq masking + for(int nmask = 1; nmask < 3; nmask++) { + RawWavSpecAugment specAug( + 0, 1, nmask, 0, 0, 0, 1, 2000, 6000, 16000, 20000); + specAug.train(); + + int T = 300, C = 3, B = 4; + auto time = 2 * M_PI * fl::iota({T}) / 16000; + auto finalWav = fl::sin(time * 500) + fl::sin(time * 1000) + + fl::sin(time * 7000) + fl::sin(time * 7500); + auto inputWav = finalWav + fl::sin(time * 3000) + fl::sin(time * 4000) + + fl::sin(time * 5000); + inputWav = fl::tile(inputWav, {1, C, B}); + finalWav = fl::tile(finalWav, {1, C, B}); + if(isfp16) { + inputWav = inputWav.astype(fl::dtype::f16); + finalWav = finalWav.astype(fl::dtype::f16); + } + + auto filteredWav = specAug(fl::Variable(inputWav, false)); + // compare middle of filtered wave to avoid edge artifacts comparison + int halfKernelWidth = 63; + ASSERT_TRUE( + fl::allClose( + fl::Variable( + finalWav(fl::range(halfKernelWidth, T - halfKernelWidth)), + false + ), + filteredWav(fl::range(halfKernelWidth, T - halfKernelWidth)), + epsilon + ) + ); + } } TEST(ContribModuleTest, RawWavSpecAugmentFwd) { - computeRawWavSpecAug(false, 1e-3); + computeRawWavSpecAug(false, 1e-3); } TEST_F(ContribModuleTestF16, RawWavSpecAugmentFwdF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not 
supported on this device"; - } - computeRawWavSpecAug(true, 1e-2); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + computeRawWavSpecAug(true, 1e-2); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/contrib/modules/ContribSerializationTest.cpp b/flashlight/fl/test/contrib/modules/ContribSerializationTest.cpp index 208ef69..0217878 100644 --- a/flashlight/fl/test/contrib/modules/ContribSerializationTest.cpp +++ b/flashlight/fl/test/contrib/modules/ContribSerializationTest.cpp @@ -22,186 +22,219 @@ using namespace fl; TEST(SerializationTest, Residual) { - std::shared_ptr model = std::make_shared(); - model->add(Linear(12, 6)); - model->add(Linear(6, 6)); - model->add(ReLU()); - model->addShortcut(1, 3); - const fs::path path = fs::temp_directory_path() / "Residual.mdl"; - save(path, model); - - std::shared_ptr loaded; - load(path, loaded); - - auto input = Variable(fl::rand({12, 10, 3, 4}), false); - auto output = model->forward(input); - auto outputl = loaded->forward(input); - - ASSERT_TRUE(allParamsClose(*loaded.get(), *model)); - ASSERT_TRUE(allClose(outputl, output)); + std::shared_ptr model = std::make_shared(); + model->add(Linear(12, 6)); + model->add(Linear(6, 6)); + model->add(ReLU()); + model->addShortcut(1, 3); + const fs::path path = fs::temp_directory_path() / "Residual.mdl"; + save(path, model); + + std::shared_ptr loaded; + load(path, loaded); + + auto input = Variable(fl::rand({12, 10, 3, 4}), false); + auto output = model->forward(input); + auto outputl = loaded->forward(input); + + ASSERT_TRUE(allParamsClose(*loaded.get(), *model)); + ASSERT_TRUE(allClose(outputl, output)); } TEST(SerializationTest, AsymmetricConv1D) { - int c = 32; - auto model = std::make_shared(c, c, 5, 1, -1, 0, 1); + int c = 32; + auto model = 
std::make_shared(c, c, 5, 1, -1, 0, 1); - const fs::path path = fs::temp_directory_path() / "AsymmetricConv1D.mdl"; - save(path, model); + const fs::path path = fs::temp_directory_path() / "AsymmetricConv1D.mdl"; + save(path, model); - std::shared_ptr loaded; - load(path, loaded); + std::shared_ptr loaded; + load(path, loaded); - auto input = Variable(fl::rand({25, 10, c, 4}), false); - auto output = model->forward(input); - auto outputl = loaded->forward(input); + auto input = Variable(fl::rand({25, 10, c, 4}), false); + auto output = model->forward(input); + auto outputl = loaded->forward(input); - ASSERT_TRUE(allParamsClose(*loaded, *model)); - ASSERT_TRUE(allClose(outputl, output)); + ASSERT_TRUE(allParamsClose(*loaded, *model)); + ASSERT_TRUE(allClose(outputl, output)); } TEST(SerializationTest, Transformer) { - int batchsize = 10; - int timesteps = 120; - int c = 32; - int nheads = 4; - - auto model = std::make_shared( - c, c / nheads, c, nheads, timesteps, 0.2, 0.1, false, false); - model->eval(); - - const fs::path path = fs::temp_directory_path() / "Transformer.mdl"; - save(path, model); - - std::shared_ptr loaded; - load(path, loaded); - loaded->eval(); - - // auto input = Variable(fl::rand({c, timesteps, batchsize, 1}), false); - auto input = Variable(fl::rand({c, timesteps, batchsize}), false); - auto output = model->forward({input, Variable()}); - auto outputl = loaded->forward({input, Variable()}); - - ASSERT_TRUE(allParamsClose(*loaded, *model)); - ASSERT_TRUE(allClose(outputl[0], output[0])); + int batchsize = 10; + int timesteps = 120; + int c = 32; + int nheads = 4; + + auto model = std::make_shared( + c, + c / nheads, + c, + nheads, + timesteps, + 0.2, + 0.1, + false, + false + ); + model->eval(); + + const fs::path path = fs::temp_directory_path() / "Transformer.mdl"; + save(path, model); + + std::shared_ptr loaded; + load(path, loaded); + loaded->eval(); + + // auto input = Variable(fl::rand({c, timesteps, batchsize, 1}), false); + auto input = 
Variable(fl::rand({c, timesteps, batchsize}), false); + auto output = model->forward({input, Variable()}); + auto outputl = loaded->forward({input, Variable()}); + + ASSERT_TRUE(allParamsClose(*loaded, *model)); + ASSERT_TRUE(allClose(outputl[0], output[0])); } TEST(SerializationTest, ConformerSerialization) { - int batchsize = 10; - int timesteps = 120; - int c = 32; - int nheads = 4; - - auto model = std::make_shared( - c, c / nheads, c, nheads, timesteps, 33, 0.2, 0.1); - model->eval(); - - const fs::path path = fs::temp_directory_path() / "Conformer.mdl"; - save(path, model); - - std::shared_ptr loaded; - load(path, loaded); - loaded->eval(); - - // auto input = Variable(fl::rand({c, timesteps, batchsize, 1}), false); - auto input = Variable(fl::rand({c, timesteps, batchsize}), false); - auto output = model->forward({input, Variable()}); - auto outputl = loaded->forward({input, Variable()}); - - ASSERT_TRUE(allParamsClose(*loaded, *model)); - ASSERT_TRUE(allClose(outputl[0], output[0])); + int batchsize = 10; + int timesteps = 120; + int c = 32; + int nheads = 4; + + auto model = std::make_shared( + c, + c / nheads, + c, + nheads, + timesteps, + 33, + 0.2, + 0.1 + ); + model->eval(); + + const fs::path path = fs::temp_directory_path() / "Conformer.mdl"; + save(path, model); + + std::shared_ptr loaded; + load(path, loaded); + loaded->eval(); + + // auto input = Variable(fl::rand({c, timesteps, batchsize, 1}), false); + auto input = Variable(fl::rand({c, timesteps, batchsize}), false); + auto output = model->forward({input, Variable()}); + auto outputl = loaded->forward({input, Variable()}); + + ASSERT_TRUE(allParamsClose(*loaded, *model)); + ASSERT_TRUE(allClose(outputl[0], output[0])); } TEST(SerializationTest, PositionEmbedding) { - auto model = std::make_shared(128, 100, 0.1); - model->eval(); + auto model = std::make_shared(128, 100, 0.1); + model->eval(); - const fs::path path = fs::temp_directory_path() / "PositionEmbedding.mdl"; - save(path, model); + 
const fs::path path = fs::temp_directory_path() / "PositionEmbedding.mdl"; + save(path, model); - std::shared_ptr loaded; - load(path, loaded); - loaded->eval(); + std::shared_ptr loaded; + load(path, loaded); + loaded->eval(); - // auto input = Variable(fl::rand({128, 10, 5, 1}), false); - auto input = Variable(fl::rand({128, 10, 5}), false); - auto output = model->forward({input}); - auto outputl = loaded->forward({input}); + // auto input = Variable(fl::rand({128, 10, 5, 1}), false); + auto input = Variable(fl::rand({128, 10, 5}), false); + auto output = model->forward({input}); + auto outputl = loaded->forward({input}); - ASSERT_TRUE(allParamsClose(*loaded, *model)); - ASSERT_TRUE(allClose(outputl[0], output[0])); + ASSERT_TRUE(allParamsClose(*loaded, *model)); + ASSERT_TRUE(allClose(outputl[0], output[0])); } TEST(SerializationTest, SinusoidalPositionEmbedding) { - auto model = std::make_shared(128, 2.); + auto model = std::make_shared(128, 2.); - const fs::path path = - fs::temp_directory_path() / "SinusoidalPositionEmbedding.mdl"; - save(path, model); + const fs::path path = + fs::temp_directory_path() / "SinusoidalPositionEmbedding.mdl"; + save(path, model); - std::shared_ptr loaded; - load(path, loaded); + std::shared_ptr loaded; + load(path, loaded); - // auto input = Variable(fl::rand({128, 10, 5, 1}), false); - auto input = Variable(fl::rand({128, 10, 5}), false); - auto output = model->forward({input}); - auto outputl = loaded->forward({input}); + // auto input = Variable(fl::rand({128, 10, 5, 1}), false); + auto input = Variable(fl::rand({128, 10, 5}), false); + auto output = model->forward({input}); + auto outputl = loaded->forward({input}); - ASSERT_TRUE(allParamsClose(*loaded, *model)); - ASSERT_TRUE(allClose(outputl[0], output[0])); + ASSERT_TRUE(allParamsClose(*loaded, *model)); + ASSERT_TRUE(allClose(outputl[0], output[0])); } TEST(SerializationTest, AdaptiveEmbedding) { - std::vector cutoff = {5, 10, 25}; - auto model = std::make_shared(128, 
cutoff); + std::vector cutoff = {5, 10, 25}; + auto model = std::make_shared(128, cutoff); - const fs::path path = fs::temp_directory_path() / "AdaptiveEmbedding.mdl"; - save(path, model); + const fs::path path = fs::temp_directory_path() / "AdaptiveEmbedding.mdl"; + save(path, model); - std::shared_ptr loaded; - load(path, loaded); + std::shared_ptr loaded; + load(path, loaded); - std::vector values = {1, 4, 6, 2, 12, 7, 4, 21, 22, 18, 3, 23}; - auto input = - Variable(Tensor::fromVector({6, 2}, values, fl::dtype::f32), false); - auto output = model->forward(input); - auto outputl = loaded->forward(input); + std::vector values = {1, 4, 6, 2, 12, 7, 4, 21, 22, 18, 3, 23}; + auto input = + Variable(Tensor::fromVector({6, 2}, values, fl::dtype::f32), false); + auto output = model->forward(input); + auto outputl = loaded->forward(input); - ASSERT_TRUE(allParamsClose(*loaded, *model)); - ASSERT_TRUE(allClose(outputl, output)); + ASSERT_TRUE(allParamsClose(*loaded, *model)); + ASSERT_TRUE(allClose(outputl, output)); } TEST(SerializationTest, RawWavSpecAugment) { - auto model = std::make_shared( - 0, 1, 1, 0, 0, 0, 1, 2000, 6000, 16000, 20000); - model->eval(); - - const fs::path path = fs::temp_directory_path() / "RawWavSpecAugment.mdl"; - save(path, model); - - std::shared_ptr loaded; - load(path, loaded); - loaded->train(); - - int T = 300; - // Input is T x C x B (here, C, B = 1) - auto time = 2 * M_PI * fl::reshape(fl::iota({T}), {T, 1, 1}) / 16000; - auto finalWav = fl::sin(time * 500) + fl::sin(time * 1000) + - fl::sin(time * 7000) + fl::sin(time * 7500); - auto inputWav = finalWav + fl::sin(time * 3000) + fl::sin(time * 4000) + - fl::sin(time * 5000); - - auto filteredWav = loaded->forward(fl::Variable(inputWav, false)); - // compare middle of filtered wave to avoid edge artifacts comparison - int halfKernelWidth = 63; - ASSERT_TRUE(fl::allClose( - fl::Variable( - finalWav(fl::range(halfKernelWidth, T - halfKernelWidth)), false), - 
filteredWav(fl::range(halfKernelWidth, T - halfKernelWidth)), - 1e-3)); + auto model = std::make_shared( + 0, + 1, + 1, + 0, + 0, + 0, + 1, + 2000, + 6000, + 16000, + 20000 + ); + model->eval(); + + const fs::path path = fs::temp_directory_path() / "RawWavSpecAugment.mdl"; + save(path, model); + + std::shared_ptr loaded; + load(path, loaded); + loaded->train(); + + int T = 300; + // Input is T x C x B (here, C, B = 1) + auto time = 2 * M_PI * fl::reshape(fl::iota({T}), {T, 1, 1}) / 16000; + auto finalWav = fl::sin(time * 500) + fl::sin(time * 1000) + + fl::sin(time * 7000) + fl::sin(time * 7500); + auto inputWav = finalWav + fl::sin(time * 3000) + fl::sin(time * 4000) + + fl::sin(time * 5000); + + auto filteredWav = loaded->forward(fl::Variable(inputWav, false)); + // compare middle of filtered wave to avoid edge artifacts comparison + int halfKernelWidth = 63; + ASSERT_TRUE( + fl::allClose( + fl::Variable( + finalWav(fl::range(halfKernelWidth, T - halfKernelWidth)), + false + ), + filteredWav(fl::range(halfKernelWidth, T - halfKernelWidth)), + 1e-3 + ) + ); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/dataset/DatasetTest.cpp b/flashlight/fl/test/dataset/DatasetTest.cpp index 60ca8cd..f8250d2 100644 --- a/flashlight/fl/test/dataset/DatasetTest.cpp +++ b/flashlight/fl/test/dataset/DatasetTest.cpp @@ -20,581 +20,616 @@ using namespace fl; TEST(DatasetTest, TensorDataset) { - std::vector tensormap = { - fl::rand({100, 200, 300}), fl::rand({150, 300})}; - TensorDataset tensords(tensormap); - - // Check `size` method - ASSERT_EQ(tensords.size(), 300); - - // Values using `get` method - auto ff1 = tensords.get(10); - ASSERT_EQ(ff1.size(), 2); - ASSERT_TRUE(allClose(ff1[0], tensormap[0](fl::span, fl::span, 10))); - ASSERT_TRUE(allClose(ff1[1], tensormap[1](fl::span, 10))); + std::vector 
tensormap = { + fl::rand({100, 200, 300}), fl::rand({150, 300})}; + TensorDataset tensords(tensormap); + + // Check `size` method + ASSERT_EQ(tensords.size(), 300); + + // Values using `get` method + auto ff1 = tensords.get(10); + ASSERT_EQ(ff1.size(), 2); + ASSERT_TRUE(allClose(ff1[0], tensormap[0](fl::span, fl::span, 10))); + ASSERT_TRUE(allClose(ff1[1], tensormap[1](fl::span, 10))); } TEST(DatasetTest, TranformDataset) { - // first create a tensor dataset - std::vector tensormap = {fl::rand({100, 200, 300})}; - auto tensords = std::make_shared(tensormap); - - auto scaleAndAdd = [](const Tensor& a) { return fl::sin(a) + 1.0; }; - TransformDataset transformds(tensords, {scaleAndAdd}); - - // Check `size` method - ASSERT_TRUE(transformds.size() == 300); - - // Values using `get` method - auto ff1 = transformds.get(10); - ASSERT_EQ(ff1.size(), 1); - ASSERT_TRUE( - allClose(ff1[0], fl::sin(tensormap[0](fl::span, fl::span, 10)) + 1.0)); + // first create a tensor dataset + std::vector tensormap = {fl::rand({100, 200, 300})}; + auto tensords = std::make_shared(tensormap); + + auto scaleAndAdd = [](const Tensor& a) { return fl::sin(a) + 1.0; }; + TransformDataset transformds(tensords, {scaleAndAdd}); + + // Check `size` method + ASSERT_TRUE(transformds.size() == 300); + + // Values using `get` method + auto ff1 = transformds.get(10); + ASSERT_EQ(ff1.size(), 1); + ASSERT_TRUE( + allClose(ff1[0], fl::sin(tensormap[0](fl::span, fl::span, 10)) + 1.0) + ); } TEST(DatasetTest, BatchDataset) { - // first create a tensor dataset - std::vector tensormap = {fl::rand({100, 200, 300})}; - auto tensords = std::make_shared(tensormap); - - BatchDataset batchds(tensords, 7, BatchDatasetPolicy::INCLUDE_LAST); - - // Check `size` method - ASSERT_EQ(batchds.size(), 43); - - // Values using `get` method - auto ff1 = batchds.get(42); - ASSERT_EQ(ff1.size(), 1); - ASSERT_TRUE( - allClose(ff1[0], tensormap[0](fl::span, fl::span, fl::range(294, 300)))); - - ff1 = batchds.get(10); - 
ASSERT_EQ(ff1.size(), 1); - ASSERT_TRUE( - allClose(ff1[0], tensormap[0](fl::span, fl::span, fl::range(70, 77)))); + // first create a tensor dataset + std::vector tensormap = {fl::rand({100, 200, 300})}; + auto tensords = std::make_shared(tensormap); + + BatchDataset batchds(tensords, 7, BatchDatasetPolicy::INCLUDE_LAST); + + // Check `size` method + ASSERT_EQ(batchds.size(), 43); + + // Values using `get` method + auto ff1 = batchds.get(42); + ASSERT_EQ(ff1.size(), 1); + ASSERT_TRUE( + allClose(ff1[0], tensormap[0](fl::span, fl::span, fl::range(294, 300))) + ); + + ff1 = batchds.get(10); + ASSERT_EQ(ff1.size(), 1); + ASSERT_TRUE( + allClose(ff1[0], tensormap[0](fl::span, fl::span, fl::range(70, 77))) + ); } TEST(DatasetTest, DynamicBatchDataset) { - // first create a tensor dataset - std::vector tensormap = {fl::rand({100, 200, 300})}; - auto tensords = std::make_shared(tensormap); - std::vector bSzs = {20, 50, 20, 30, 10, 50, 20, 35, 15, 50}; - BatchDataset batchds(tensords, bSzs); - - // Check `size` method - ASSERT_EQ(batchds.size(), bSzs.size()); - - // Values using `get` method - auto ff1 = batchds.get(0); - ASSERT_EQ(ff1.size(), 1); - ASSERT_TRUE( - allClose(ff1[0], tensormap[0](fl::span, fl::span, fl::range(0, 20)))); - - ff1 = batchds.get(3); - ASSERT_EQ(ff1.size(), 1); - ASSERT_TRUE( - allClose(ff1[0], tensormap[0](fl::span, fl::span, fl::range(90, 120)))); + // first create a tensor dataset + std::vector tensormap = {fl::rand({100, 200, 300})}; + auto tensords = std::make_shared(tensormap); + std::vector bSzs = {20, 50, 20, 30, 10, 50, 20, 35, 15, 50}; + BatchDataset batchds(tensords, bSzs); + + // Check `size` method + ASSERT_EQ(batchds.size(), bSzs.size()); + + // Values using `get` method + auto ff1 = batchds.get(0); + ASSERT_EQ(ff1.size(), 1); + ASSERT_TRUE( + allClose(ff1[0], tensormap[0](fl::span, fl::span, fl::range(0, 20))) + ); + + ff1 = batchds.get(3); + ASSERT_EQ(ff1.size(), 1); + ASSERT_TRUE( + allClose(ff1[0], tensormap[0](fl::span, 
fl::span, fl::range(90, 120))) + ); } TEST(DatasetTest, ShuffleDataset) { - std::vector tensormap = {fl::rand({100, 200, 300})}; - auto tensords = std::make_shared(tensormap); - ShuffleDataset shuffleds(tensords); - - // Check `size` method - ASSERT_EQ(shuffleds.size(), 300); - - // Values using `get` method - auto ff1 = shuffleds.get(10); - ASSERT_EQ(ff1.size(), 1); - ASSERT_FALSE(allClose(ff1[0], tensormap[0](fl::span, fl::span, 10))); - - // Same seed produces same order and vice-versa - ShuffleDataset shuffleds2(tensords, 2); - ShuffleDataset shuffleds3(tensords, 2); - ShuffleDataset shuffleds4(tensords, 3); - auto ff2 = shuffleds2.get(10); - auto ff3 = shuffleds3.get(10); - auto ff4 = shuffleds4.get(10); - ASSERT_EQ(ff2.size(), 1); - ASSERT_EQ(ff3.size(), 1); - ASSERT_EQ(ff4.size(), 1); - ASSERT_TRUE(allClose(ff2[0], ff3[0])); - ASSERT_FALSE(allClose(ff2[0], ff4[0])); + std::vector tensormap = {fl::rand({100, 200, 300})}; + auto tensords = std::make_shared(tensormap); + ShuffleDataset shuffleds(tensords); + + // Check `size` method + ASSERT_EQ(shuffleds.size(), 300); + + // Values using `get` method + auto ff1 = shuffleds.get(10); + ASSERT_EQ(ff1.size(), 1); + ASSERT_FALSE(allClose(ff1[0], tensormap[0](fl::span, fl::span, 10))); + + // Same seed produces same order and vice-versa + ShuffleDataset shuffleds2(tensords, 2); + ShuffleDataset shuffleds3(tensords, 2); + ShuffleDataset shuffleds4(tensords, 3); + auto ff2 = shuffleds2.get(10); + auto ff3 = shuffleds3.get(10); + auto ff4 = shuffleds4.get(10); + ASSERT_EQ(ff2.size(), 1); + ASSERT_EQ(ff3.size(), 1); + ASSERT_EQ(ff4.size(), 1); + ASSERT_TRUE(allClose(ff2[0], ff3[0])); + ASSERT_FALSE(allClose(ff2[0], ff4[0])); } TEST(DatasetTest, ResampleDataset) { - std::vector tensormap = {fl::rand({100, 200, 300})}; - auto tensords = std::make_shared(tensormap); - auto permfn = [](int64_t n) { return (n + 5) % 300; }; - ResampleDataset resampleds(tensords, permfn); + std::vector tensormap = {fl::rand({100, 200, 300})}; 
+ auto tensords = std::make_shared(tensormap); + auto permfn = [](int64_t n) { return (n + 5) % 300; }; + ResampleDataset resampleds(tensords, permfn); - // Check `size` method - ASSERT_EQ(resampleds.size(), 300); + // Check `size` method + ASSERT_EQ(resampleds.size(), 300); - auto ff1 = resampleds.get(10); - ASSERT_TRUE(allClose(ff1[0], tensormap[0](fl::span, fl::span, 15))); - ASSERT_FALSE(allClose(ff1[0], tensormap[0](fl::span, fl::span, 10))); + auto ff1 = resampleds.get(10); + ASSERT_TRUE(allClose(ff1[0], tensormap[0](fl::span, fl::span, 15))); + ASSERT_FALSE(allClose(ff1[0], tensormap[0](fl::span, fl::span, 10))); - resampleds.resample({3, 3, 3, 4, 5}); - ASSERT_EQ(resampleds.size(), 5); + resampleds.resample({3, 3, 3, 4, 5}); + ASSERT_EQ(resampleds.size(), 5); - auto ff2 = resampleds.get(1); - ASSERT_TRUE(allClose(ff2[0], tensormap[0](fl::span, fl::span, 3))); + auto ff2 = resampleds.get(1); + ASSERT_TRUE(allClose(ff2[0], tensormap[0](fl::span, fl::span, 3))); } TEST(DatasetTest, SpanDataset) { - std::vector tensormap = { - fl::rand({100, 200, 300}), fl::rand({150, 300})}; - auto tensords = std::make_shared(tensormap); - - SpanDataset frontspands(tensords, 0, 13); - SpanDataset backspands(tensords, 13); - - // Check `size` method - ASSERT_EQ(frontspands.size(), 13); - ASSERT_EQ(backspands.size(), 287); - - // Values using `get` method - auto ff1 = frontspands.get(10); - ASSERT_EQ(ff1.size(), 2); - ASSERT_TRUE(allClose(ff1[0], tensormap[0](fl::span, fl::span, 10))); - ASSERT_TRUE(allClose(ff1[1], tensormap[1](fl::span, 10))); - auto ff2 = backspands.get(10); - ASSERT_EQ(ff2.size(), 2); - ASSERT_TRUE(allClose(ff2[0], tensormap[0](fl::span, fl::span, 13 + 10))); - ASSERT_TRUE(allClose(ff2[1], tensormap[1](fl::span, 13 + 10))); + std::vector tensormap = { + fl::rand({100, 200, 300}), fl::rand({150, 300})}; + auto tensords = std::make_shared(tensormap); + + SpanDataset frontspands(tensords, 0, 13); + SpanDataset backspands(tensords, 13); + + // Check `size` 
method + ASSERT_EQ(frontspands.size(), 13); + ASSERT_EQ(backspands.size(), 287); + + // Values using `get` method + auto ff1 = frontspands.get(10); + ASSERT_EQ(ff1.size(), 2); + ASSERT_TRUE(allClose(ff1[0], tensormap[0](fl::span, fl::span, 10))); + ASSERT_TRUE(allClose(ff1[1], tensormap[1](fl::span, 10))); + auto ff2 = backspands.get(10); + ASSERT_EQ(ff2.size(), 2); + ASSERT_TRUE(allClose(ff2[0], tensormap[0](fl::span, fl::span, 13 + 10))); + ASSERT_TRUE(allClose(ff2[1], tensormap[1](fl::span, 13 + 10))); } TEST(DatasetTest, ConcatDataset) { - auto tensor1 = fl::rand({100, 200, 100}); - auto tensor2 = fl::rand({100, 200, 200}); - std::vector tensormap1 = {tensor1}; - auto tensords1 = std::make_shared(tensormap1); - std::vector tensormap2 = {tensor2}; - auto tensords2 = std::make_shared(tensormap2); - ConcatDataset concatds({tensords1, tensords2}); - - // Check `size` method - ASSERT_TRUE(concatds.size() == 300); - - auto ff1 = concatds.get(100); - ASSERT_EQ(ff1.size(), 1); - ASSERT_TRUE(allClose(ff1[0], tensor2(fl::span, fl::span, 0))); - ff1 = concatds.get(299); - ASSERT_TRUE(allClose(ff1[0], tensor2(fl::span, fl::span, 199))); - ff1 = concatds.get(0); - ASSERT_TRUE(allClose(ff1[0], tensor1(fl::span, fl::span, 0))); - ff1 = concatds.get(10); - ASSERT_TRUE(allClose(ff1[0], tensor1(fl::span, fl::span, 10))); + auto tensor1 = fl::rand({100, 200, 100}); + auto tensor2 = fl::rand({100, 200, 200}); + std::vector tensormap1 = {tensor1}; + auto tensords1 = std::make_shared(tensormap1); + std::vector tensormap2 = {tensor2}; + auto tensords2 = std::make_shared(tensormap2); + ConcatDataset concatds({tensords1, tensords2}); + + // Check `size` method + ASSERT_TRUE(concatds.size() == 300); + + auto ff1 = concatds.get(100); + ASSERT_EQ(ff1.size(), 1); + ASSERT_TRUE(allClose(ff1[0], tensor2(fl::span, fl::span, 0))); + ff1 = concatds.get(299); + ASSERT_TRUE(allClose(ff1[0], tensor2(fl::span, fl::span, 199))); + ff1 = concatds.get(0); + ASSERT_TRUE(allClose(ff1[0], 
tensor1(fl::span, fl::span, 0))); + ff1 = concatds.get(10); + ASSERT_TRUE(allClose(ff1[0], tensor1(fl::span, fl::span, 10))); } TEST(DatasetTest, DatasetIterator) { - std::vector tensormap = {fl::rand({100, 200, 300})}; - auto tensords = std::make_shared(tensormap); - - auto scaleAndAdd = [](const Tensor& a) { return fl::sin(a) + 1.0; }; - TransformDataset transformds(tensords, {scaleAndAdd}); - - int idx = 0; - for (auto& sample : transformds) { - (void)sample; - ++idx; - } - ASSERT_EQ(idx, transformds.size()); + std::vector tensormap = {fl::rand({100, 200, 300})}; + auto tensords = std::make_shared(tensormap); + + auto scaleAndAdd = [](const Tensor& a) { return fl::sin(a) + 1.0; }; + TransformDataset transformds(tensords, {scaleAndAdd}); + + int idx = 0; + for(auto& sample : transformds) { + (void) sample; + ++idx; + } + ASSERT_EQ(idx, transformds.size()); } TEST(DatasetTest, FileBlobDataset) { - std::vector> data; - - auto fillup = [&data](FileBlobDataset& blob) { - for (int64_t i = 0; i < 20; i++) { - std::vector sample; - for (int64_t j = 0; j < i % 4; j++) { - Tensor tensor; - if (j % 2 == 0) { - tensor = fl::rand({100, 3, 100}); - } else { - tensor = fl::rand({100, 200}); + std::vector> data; + + auto fillup = [&data](FileBlobDataset& blob) { + for(int64_t i = 0; i < 20; i++) { + std::vector sample; + for(int64_t j = 0; j < i % 4; j++) { + Tensor tensor; + if(j % 2 == 0) { + tensor = fl::rand({100, 3, 100}); + } else { + tensor = fl::rand({100, 200}); + } + sample.push_back(tensor); + } + data.push_back(sample); + blob.add(sample); + } + blob.flush(); + }; + + auto check = [&data](FileBlobDataset& blob) { + ASSERT_EQ(data.size(), blob.size()); + for(int64_t i = 0; i < blob.size(); i++) { + auto blobSample = blob.get(i); + auto datSample = data.at(i); + ASSERT_EQ(datSample.size(), blobSample.size()); + for(int64_t j = 0; j < blobSample.size(); j++) { + ASSERT_TRUE( + fl::norm(datSample.at(j).flatten() - blobSample.at(j).flatten()) + .scalar() <= 1e-05 + ); + 
} + } + }; + + // check read-write capabilities + { + FileBlobDataset blob(fs::temp_directory_path() / "data.blob", true, true); + fillup(blob); + check(blob); + fillup(blob); + check(blob); + + blob.writeIndex(); + fillup(blob); + check(blob); + + blob.writeIndex(); + check(blob); + + FileBlobDataset blobcopy( + fs::temp_directory_path() / "data-copy.blob", true, true); + blobcopy.add(blob); + blobcopy.add(blob, 1048576); + auto datadup = data; + data.insert(data.end(), datadup.begin(), datadup.end()); + blobcopy.writeIndex(); + check(blobcopy); + data = datadup; + check(blob); + + // check hostTransform + for(auto& vec : data) { + if(!vec.empty()) { + vec[0] += 1; + } } - sample.push_back(tensor); - } - data.push_back(sample); - blob.add(sample); - } - blob.flush(); - }; - - auto check = [&data](FileBlobDataset& blob) { - ASSERT_EQ(data.size(), blob.size()); - for (int64_t i = 0; i < blob.size(); i++) { - auto blobSample = blob.get(i); - auto datSample = data.at(i); - ASSERT_EQ(datSample.size(), blobSample.size()); - for (int64_t j = 0; j < blobSample.size(); j++) { - ASSERT_TRUE( - fl::norm(datSample.at(j).flatten() - blobSample.at(j).flatten()) - .scalar() <= 1e-05); - } - } - }; - - // check read-write capabilities - { - FileBlobDataset blob(fs::temp_directory_path() / "data.blob", true, true); - fillup(blob); - check(blob); - fillup(blob); - check(blob); - - blob.writeIndex(); - fillup(blob); - check(blob); - - blob.writeIndex(); - check(blob); - - FileBlobDataset blobcopy( - fs::temp_directory_path() / "data-copy.blob", true, true); - blobcopy.add(blob); - blobcopy.add(blob, 1048576); - auto datadup = data; - data.insert(data.end(), datadup.begin(), datadup.end()); - blobcopy.writeIndex(); - check(blobcopy); - data = datadup; - check(blob); - - // check hostTransform - for (auto& vec : data) { - if (!vec.empty()) { - vec[0] += 1; - } - } - blob.setHostTransform( - 0, [](void* ptr, fl::Shape size, fl::dtype /* type */) { - float* ptrFl = (float*)ptr; - for 
(int64_t i = 0; i < size.elements(); i++) { - ptrFl[i] += 1; - } - return Tensor::fromBuffer(size, ptrFl, MemoryLocation::Host); - }); - check(blob); - for (auto& vec : data) { - if (!vec.empty()) { - vec[0] -= 1; - } - } - } - - // check tensor dim constraints - { - FileBlobDataset blob( - fs::temp_directory_path() / "max_size.blob", true, true); - ASSERT_THROW(blob.add({fl::rand({4, 5, 6, 7, 8})}), std::invalid_argument); - } - - // check everything is correct after re-opening - { - FileBlobDataset blob(fs::temp_directory_path() / "data.blob"); - check(blob); - } - - // multi-threaded read - { - std::vector> thdata(data.size()); - auto blob = std::make_shared( - fs::temp_directory_path() / "data.blob"); - std::vector workers; - const int nworker = 4; - int nperworker = data.size() / nworker; - for (int i = 0; i < nworker; i++) { - auto device = fl::getDevice(); - workers.emplace_back([i, blob, nperworker, device, &thdata]() { - fl::setDevice(device); - for (int j = 0; j < nperworker; j++) { - thdata[i * nperworker + j] = blob->get(i * nperworker + j); + blob.setHostTransform( + 0, + [](void* ptr, fl::Shape size, fl::dtype /* type */) { + float* ptrFl = (float*) ptr; + for(int64_t i = 0; i < size.elements(); i++) { + ptrFl[i] += 1; + } + return Tensor::fromBuffer(size, ptrFl, MemoryLocation::Host); + } + ); + check(blob); + for(auto& vec : data) { + if(!vec.empty()) { + vec[0] -= 1; + } } - }); - } - for (int i = 0; i < nworker; i++) { - workers[i].join(); } - ASSERT_EQ(data.size(), thdata.size()); - for (int64_t i = 0; i < data.size(); i++) { - auto thdataSample = thdata.at(i); - auto dataSample = data.at(i); - ASSERT_EQ(dataSample.size(), thdataSample.size()); - for (int64_t j = 0; j < thdataSample.size(); j++) { - ASSERT_TRUE(thdataSample.at(j).shape() == dataSample.at(j).shape()); - ASSERT_TRUE( - fl::norm(dataSample.at(j).flatten() - thdataSample.at(j).flatten()) - .scalar() <= 1e-05); - } + + // check tensor dim constraints + { + FileBlobDataset blob( + 
fs::temp_directory_path() / "max_size.blob", true, true); + ASSERT_THROW(blob.add({fl::rand({4, 5, 6, 7, 8})}), std::invalid_argument); } - } - // multi-threaded write - { - // add an index - for (int i = 0; i < data.size(); i++) { - data[i].push_back(fl::full({1}, i, fl::dtype::f32)); + // check everything is correct after re-opening + { + FileBlobDataset blob(fs::temp_directory_path() / "data.blob"); + check(blob); } + + // multi-threaded read { - auto blob = std::make_shared( - fs::temp_directory_path() / "data.blob", true, true); - std::vector workers; - const int nworker = 10; - int nperworker = data.size() / nworker; - auto device = fl::getDevice(); - for (int i = 0; i < nworker; i++) { - workers.emplace_back([i, blob, nperworker, device, &data]() { - fl::setDevice(device); - for (int j = 0; j < nperworker; j++) { - blob->add(data[i * nperworker + j]); - } - }); - } - for (int i = 0; i < nworker; i++) { - workers[i].join(); - } - blob->writeIndex(); + std::vector> thdata(data.size()); + auto blob = std::make_shared( + fs::temp_directory_path() / "data.blob" + ); + std::vector workers; + const int nworker = 4; + int nperworker = data.size() / nworker; + for(int i = 0; i < nworker; i++) { + auto device = fl::getDevice(); + workers.emplace_back( + [i, blob, nperworker, device, &thdata]() { + fl::setDevice(device); + for(int j = 0; j < nperworker; j++) { + thdata[i * nperworker + j] = blob->get(i * nperworker + j); + } + } + ); + } + for(int i = 0; i < nworker; i++) { + workers[i].join(); + } + ASSERT_EQ(data.size(), thdata.size()); + for(int64_t i = 0; i < data.size(); i++) { + auto thdataSample = thdata.at(i); + auto dataSample = data.at(i); + ASSERT_EQ(dataSample.size(), thdataSample.size()); + for(int64_t j = 0; j < thdataSample.size(); j++) { + ASSERT_TRUE(thdataSample.at(j).shape() == dataSample.at(j).shape()); + ASSERT_TRUE( + fl::norm(dataSample.at(j).flatten() - thdataSample.at(j).flatten()) + .scalar() <= 1e-05 + ); + } + } } + + // multi-threaded write 
{ - auto blob = std::make_shared( - fs::temp_directory_path() / "data.blob"); - ASSERT_EQ(data.size(), blob->size()); - for (int64_t i = 0; i < data.size(); i++) { - auto blobSample = blob->get(i); - auto idx = static_cast(blobSample.back().scalar()); - ASSERT_TRUE(idx >= 0 && idx < data.size()); - auto dataSample = data.at(idx); - ASSERT_EQ(dataSample.size(), blobSample.size()); - for (int64_t j = 0; j < blobSample.size(); j++) { - ASSERT_TRUE(dataSample.at(j).shape() == blobSample.at(j).shape()); - ASSERT_TRUE( - fl::norm(dataSample.at(j).flatten() - blobSample.at(j).flatten()) - .scalar() <= 1e-05); + // add an index + for(int i = 0; i < data.size(); i++) { + data[i].push_back(fl::full({1}, i, fl::dtype::f32)); + } + { + auto blob = std::make_shared( + fs::temp_directory_path() / "data.blob", + true, + true + ); + std::vector workers; + const int nworker = 10; + int nperworker = data.size() / nworker; + auto device = fl::getDevice(); + for(int i = 0; i < nworker; i++) { + workers.emplace_back( + [i, blob, nperworker, device, &data]() { + fl::setDevice(device); + for(int j = 0; j < nperworker; j++) { + blob->add(data[i * nperworker + j]); + } + } + ); + } + for(int i = 0; i < nworker; i++) { + workers[i].join(); + } + blob->writeIndex(); + } + { + auto blob = std::make_shared( + fs::temp_directory_path() / "data.blob" + ); + ASSERT_EQ(data.size(), blob->size()); + for(int64_t i = 0; i < data.size(); i++) { + auto blobSample = blob->get(i); + auto idx = static_cast(blobSample.back().scalar()); + ASSERT_TRUE(idx >= 0 && idx < data.size()); + auto dataSample = data.at(idx); + ASSERT_EQ(dataSample.size(), blobSample.size()); + for(int64_t j = 0; j < blobSample.size(); j++) { + ASSERT_TRUE(dataSample.at(j).shape() == blobSample.at(j).shape()); + ASSERT_TRUE( + fl::norm(dataSample.at(j).flatten() - blobSample.at(j).flatten()) + .scalar() <= 1e-05 + ); + } + } } - } } - } } TEST(DatasetTest, MemoryBlobDataset) { - std::vector> data; - - auto fillup = 
[&data](MemoryBlobDataset& blob) { - for (int64_t i = 0; i < 20; i++) { - std::vector sample; - for (int64_t j = 0; j < i % 4; j++) { - Tensor tensor; - if (j % 2 == 0) { - tensor = fl::rand({100, 3, 100}); - } else { - tensor = fl::rand({100, 200}); + std::vector> data; + + auto fillup = [&data](MemoryBlobDataset& blob) { + for(int64_t i = 0; i < 20; i++) { + std::vector sample; + for(int64_t j = 0; j < i % 4; j++) { + Tensor tensor; + if(j % 2 == 0) { + tensor = fl::rand({100, 3, 100}); + } else { + tensor = fl::rand({100, 200}); + } + sample.push_back(tensor); + } + data.push_back(sample); + blob.add(sample); + } + blob.flush(); + }; + + auto check = [&data](MemoryBlobDataset& blob) { + ASSERT_EQ(data.size(), blob.size()); + for(int64_t i = 0; i < blob.size(); i++) { + auto blobSample = blob.get(i); + auto datSample = data.at(i); + ASSERT_EQ(datSample.size(), blobSample.size()); + for(int64_t j = 0; j < blobSample.size(); j++) { + ASSERT_TRUE( + fl::norm(datSample.at(j).flatten() - blobSample.at(j).flatten()) + .scalar() <= 1e-05 + ); + } + } + }; + + // check read-write capabilities + MemoryBlobDataset blob; + { + fillup(blob); + check(blob); + fillup(blob); + check(blob); + + blob.writeIndex(); + fillup(blob); + check(blob); + + blob.writeIndex(); + check(blob); + + MemoryBlobDataset blobcopy; + blobcopy.add(blob); + blobcopy.add(blob, 1048576); + auto datadup = data; + data.insert(data.end(), datadup.begin(), datadup.end()); + blobcopy.writeIndex(); + check(blobcopy); + data = datadup; + check(blob); + + // check hostTransform + for(auto& vec : data) { + if(!vec.empty()) { + vec[0] += 1; + } } - sample.push_back(tensor); - } - data.push_back(sample); - blob.add(sample); - } - blob.flush(); - }; - - auto check = [&data](MemoryBlobDataset& blob) { - ASSERT_EQ(data.size(), blob.size()); - for (int64_t i = 0; i < blob.size(); i++) { - auto blobSample = blob.get(i); - auto datSample = data.at(i); - ASSERT_EQ(datSample.size(), blobSample.size()); - for (int64_t j = 
0; j < blobSample.size(); j++) { - ASSERT_TRUE( - fl::norm(datSample.at(j).flatten() - blobSample.at(j).flatten()) - .scalar() <= 1e-05); - } - } - }; - - // check read-write capabilities - MemoryBlobDataset blob; - { - fillup(blob); - check(blob); - fillup(blob); - check(blob); - - blob.writeIndex(); - fillup(blob); - check(blob); - - blob.writeIndex(); - check(blob); - - MemoryBlobDataset blobcopy; - blobcopy.add(blob); - blobcopy.add(blob, 1048576); - auto datadup = data; - data.insert(data.end(), datadup.begin(), datadup.end()); - blobcopy.writeIndex(); - check(blobcopy); - data = datadup; - check(blob); - - // check hostTransform - for (auto& vec : data) { - if (!vec.empty()) { - vec[0] += 1; - } - } - blob.setHostTransform( - 0, [](void* ptr, fl::Shape size, fl::dtype /* type */) { - float* ptrFl = (float*)ptr; - for (int64_t i = 0; i < size.elements(); i++) { - ptrFl[i] += 1; - } - return Tensor::fromBuffer(size, ptrFl, MemoryLocation::Host); - }); - check(blob); - } - - // multi-threaded read - { - std::vector> thdata(data.size()); - std::vector workers; - const int nworker = 4; - int nperworker = data.size() / nworker; - for (int i = 0; i < nworker; i++) { - auto device = fl::getDevice(); - workers.emplace_back([i, &blob, nperworker, device, &thdata]() { - fl::setDevice(device); - for (int j = 0; j < nperworker; j++) { - thdata[i * nperworker + j] = blob.get(i * nperworker + j); + blob.setHostTransform( + 0, + [](void* ptr, fl::Shape size, fl::dtype /* type */) { + float* ptrFl = (float*) ptr; + for(int64_t i = 0; i < size.elements(); i++) { + ptrFl[i] += 1; + } + return Tensor::fromBuffer(size, ptrFl, MemoryLocation::Host); } - }); - } - for (int i = 0; i < nworker; i++) { - workers[i].join(); - } - ASSERT_EQ(data.size(), thdata.size()); - for (int64_t i = 0; i < data.size(); i++) { - auto thdataSample = thdata.at(i); - auto dataSample = data.at(i); - ASSERT_EQ(dataSample.size(), thdataSample.size()); - for (int64_t j = 0; j < thdataSample.size(); j++) { 
- ASSERT_TRUE(thdataSample.at(j).shape() == dataSample.at(j).shape()); - ASSERT_TRUE( - fl::norm(dataSample.at(j).flatten() - thdataSample.at(j).flatten()) - .scalar() <= 1e-05); - } - } - } - - // multi-threaded write - { - MemoryBlobDataset wblob; - // add an index - for (int i = 0; i < data.size(); i++) { - data[i].push_back(fl::full({1}, i, fl::dtype::f32)); + ); + check(blob); } + + // multi-threaded read { - std::vector workers; - const int nworker = 10; - int nperworker = data.size() / nworker; - auto device = fl::getDevice(); - for (int i = 0; i < nworker; i++) { - workers.emplace_back([i, &wblob, nperworker, device, &data]() { - fl::setDevice(device); - for (int j = 0; j < nperworker; j++) { - wblob.add(data[i * nperworker + j]); - } - }); - } - for (int i = 0; i < nworker; i++) { - workers[i].join(); - } - wblob.writeIndex(); + std::vector> thdata(data.size()); + std::vector workers; + const int nworker = 4; + int nperworker = data.size() / nworker; + for(int i = 0; i < nworker; i++) { + auto device = fl::getDevice(); + workers.emplace_back( + [i, &blob, nperworker, device, &thdata]() { + fl::setDevice(device); + for(int j = 0; j < nperworker; j++) { + thdata[i * nperworker + j] = blob.get(i * nperworker + j); + } + } + ); + } + for(int i = 0; i < nworker; i++) { + workers[i].join(); + } + ASSERT_EQ(data.size(), thdata.size()); + for(int64_t i = 0; i < data.size(); i++) { + auto thdataSample = thdata.at(i); + auto dataSample = data.at(i); + ASSERT_EQ(dataSample.size(), thdataSample.size()); + for(int64_t j = 0; j < thdataSample.size(); j++) { + ASSERT_TRUE(thdataSample.at(j).shape() == dataSample.at(j).shape()); + ASSERT_TRUE( + fl::norm(dataSample.at(j).flatten() - thdataSample.at(j).flatten()) + .scalar() <= 1e-05 + ); + } + } } + + // multi-threaded write { - ASSERT_EQ(data.size(), wblob.size()); - for (int64_t i = 0; i < data.size(); i++) { - auto wblobSample = wblob.get(i); - auto idx = static_cast(wblobSample.back().scalar()); - ASSERT_TRUE(idx >= 0 
&& idx < data.size()); - auto dataSample = data.at(idx); - ASSERT_EQ(dataSample.size(), wblobSample.size()); - for (int64_t j = 0; j < wblobSample.size(); j++) { - ASSERT_TRUE(dataSample.at(j).shape() == wblobSample.at(j).shape()); - ASSERT_TRUE( - fl::norm(dataSample.at(j).flatten() - wblobSample.at(j).flatten()) - .scalar() <= 1e-05); + MemoryBlobDataset wblob; + // add an index + for(int i = 0; i < data.size(); i++) { + data[i].push_back(fl::full({1}, i, fl::dtype::f32)); + } + { + std::vector workers; + const int nworker = 10; + int nperworker = data.size() / nworker; + auto device = fl::getDevice(); + for(int i = 0; i < nworker; i++) { + workers.emplace_back( + [i, &wblob, nperworker, device, &data]() { + fl::setDevice(device); + for(int j = 0; j < nperworker; j++) { + wblob.add(data[i * nperworker + j]); + } + } + ); + } + for(int i = 0; i < nworker; i++) { + workers[i].join(); + } + wblob.writeIndex(); + } + { + ASSERT_EQ(data.size(), wblob.size()); + for(int64_t i = 0; i < data.size(); i++) { + auto wblobSample = wblob.get(i); + auto idx = static_cast(wblobSample.back().scalar()); + ASSERT_TRUE(idx >= 0 && idx < data.size()); + auto dataSample = data.at(idx); + ASSERT_EQ(dataSample.size(), wblobSample.size()); + for(int64_t j = 0; j < wblobSample.size(); j++) { + ASSERT_TRUE(dataSample.at(j).shape() == wblobSample.at(j).shape()); + ASSERT_TRUE( + fl::norm(dataSample.at(j).flatten() - wblobSample.at(j).flatten()) + .scalar() <= 1e-05 + ); + } + } } - } } - } } TEST(DatasetTest, PrefetchDatasetCorrectness) { - std::vector tensormap = {fl::rand({100, 200, 300})}; - auto tensords = std::make_shared(tensormap); - - Dataset::TransformFunction scaleAndAdd = [](const Tensor& a) { - return fl::cos(a) + 10.0; - }; - - auto transformDs = std::make_shared( - tensords, std::vector{scaleAndAdd}); - - auto prefetchDs = std::make_shared(transformDs, 2, 2); - for (int i = 0; i < transformDs->size(); ++i) { - auto sample1 = transformDs->get(i); - auto sample2 = 
prefetchDs->get(i); - ASSERT_EQ(sample1.size(), sample2.size()); - for (int j = 0; j < sample1.size(); ++j) { - ASSERT_TRUE(allClose(sample1[j], sample2[j])); + std::vector tensormap = {fl::rand({100, 200, 300})}; + auto tensords = std::make_shared(tensormap); + + Dataset::TransformFunction scaleAndAdd = [](const Tensor& a) { + return fl::cos(a) + 10.0; + }; + + auto transformDs = std::make_shared( + tensords, + std::vector{scaleAndAdd} + ); + + auto prefetchDs = std::make_shared(transformDs, 2, 2); + for(int i = 0; i < transformDs->size(); ++i) { + auto sample1 = transformDs->get(i); + auto sample2 = prefetchDs->get(i); + ASSERT_EQ(sample1.size(), sample2.size()); + for(int j = 0; j < sample1.size(); ++j) { + ASSERT_TRUE(allClose(sample1[j], sample2[j])); + } } - } } TEST(DatasetTest, DISABLED_PrefetchDatasetPerformance) { - // Flaky test. Disabled for now. - std::vector tensormap = {fl::rand({100, 200, 300})}; - auto tensords = std::make_shared(tensormap); - - Dataset::TransformFunction scaleAndAdd = [](const Tensor& a) { - /* sleep override */ - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - return fl::sin(a) + 1.0; - }; - auto transformDs = std::make_shared( - tensords, std::vector{scaleAndAdd}); - - auto start = std::chrono::high_resolution_clock::now(); - for (auto& sample : *transformDs) { - (void)sample; - } - auto dur = std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - start); - ASSERT_NEAR(dur.count(), transformDs->size(), transformDs->size() / 5); - - int64_t numthreads = 4; - auto prefetchDs = - std::make_shared(transformDs, numthreads, numthreads); - - start = std::chrono::high_resolution_clock::now(); - for (auto& sample : *prefetchDs) { - (void)sample; - } - dur = std::chrono::duration_cast( - std::chrono::high_resolution_clock::now() - start); - ASSERT_NEAR( - dur.count(), - transformDs->size() / numthreads, - transformDs->size() / numthreads / 5); + // Flaky test. Disabled for now. 
+ std::vector tensormap = {fl::rand({100, 200, 300})}; + auto tensords = std::make_shared(tensormap); + + Dataset::TransformFunction scaleAndAdd = [](const Tensor& a) { + /* sleep override */ + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + return fl::sin(a) + 1.0; + }; + auto transformDs = std::make_shared( + tensords, + std::vector{scaleAndAdd} + ); + + auto start = std::chrono::high_resolution_clock::now(); + for(auto& sample : *transformDs) { + (void) sample; + } + auto dur = std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - start + ); + ASSERT_NEAR(dur.count(), transformDs->size(), transformDs->size() / 5); + + int64_t numthreads = 4; + auto prefetchDs = + std::make_shared(transformDs, numthreads, numthreads); + + start = std::chrono::high_resolution_clock::now(); + for(auto& sample : *prefetchDs) { + (void) sample; + } + dur = std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - start + ); + ASSERT_NEAR( + dur.count(), + transformDs->size() / numthreads, + transformDs->size() / numthreads / 5 + ); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/dataset/DatasetUtilsTest.cpp b/flashlight/fl/test/dataset/DatasetUtilsTest.cpp index 04d78e1..ff48e35 100644 --- a/flashlight/fl/test/dataset/DatasetUtilsTest.cpp +++ b/flashlight/fl/test/dataset/DatasetUtilsTest.cpp @@ -16,45 +16,45 @@ using namespace fl; TEST(DatasetTest, RoundRobinPacker) { - auto samples = partitionByRoundRobin(11, 0, 2, 2); - ASSERT_EQ(samples.size(), 6); - ASSERT_EQ(samples, std::vector({0, 1, 4, 5, 8, 9})); + auto samples = partitionByRoundRobin(11, 0, 2, 2); + ASSERT_EQ(samples.size(), 6); + ASSERT_EQ(samples, std::vector({0, 1, 4, 5, 8, 9})); - samples = partitionByRoundRobin(10, 0, 2, 2); - ASSERT_EQ(samples.size(), 5); - ASSERT_EQ(samples, 
std::vector({0, 1, 4, 5, 8})); + samples = partitionByRoundRobin(10, 0, 2, 2); + ASSERT_EQ(samples.size(), 5); + ASSERT_EQ(samples, std::vector({0, 1, 4, 5, 8})); - samples = partitionByRoundRobin(9, 0, 2, 2); - ASSERT_EQ(samples.size(), 4); - ASSERT_EQ(samples, std::vector({0, 1, 4, 5})); + samples = partitionByRoundRobin(9, 0, 2, 2); + ASSERT_EQ(samples.size(), 4); + ASSERT_EQ(samples, std::vector({0, 1, 4, 5})); - samples = partitionByRoundRobin(8, 0, 2, 2); - ASSERT_EQ(samples.size(), 4); - ASSERT_EQ(samples, std::vector({0, 1, 4, 5})); + samples = partitionByRoundRobin(8, 0, 2, 2); + ASSERT_EQ(samples.size(), 4); + ASSERT_EQ(samples, std::vector({0, 1, 4, 5})); } TEST(DatasetTest, DynamicRoundRobinPacker) { - std::vector length = {2, 4, 1, 2, 3, 7, 4, 3}; - auto samples = dynamicPartitionByRoundRobin(length, 0, 2, 12); - ASSERT_EQ(samples.first.size(), 4); - // indices which packed into 0-th thread - ASSERT_EQ(samples.first, std::vector({0, 1, 2, 5})); - ASSERT_EQ(samples.second.size(), 2); - // sizes of batches in the 0-th thread - ASSERT_EQ(samples.second, std::vector({3, 1})); - - length = {2, 4, 1, 2, 3, 7, 4, 3, 5}; - samples = dynamicPartitionByRoundRobin(length, 0, 2, 12); - ASSERT_EQ(samples.first.size(), 4); - // indices which packed into 0-th thread - ASSERT_EQ(samples.first, std::vector({0, 1, 2, 5})); - ASSERT_EQ(samples.second.size(), 2); - // sizes of batches in the 0-th thread - ASSERT_EQ(samples.second, std::vector({3, 1})); + std::vector length = {2, 4, 1, 2, 3, 7, 4, 3}; + auto samples = dynamicPartitionByRoundRobin(length, 0, 2, 12); + ASSERT_EQ(samples.first.size(), 4); + // indices which packed into 0-th thread + ASSERT_EQ(samples.first, std::vector({0, 1, 2, 5})); + ASSERT_EQ(samples.second.size(), 2); + // sizes of batches in the 0-th thread + ASSERT_EQ(samples.second, std::vector({3, 1})); + + length = {2, 4, 1, 2, 3, 7, 4, 3, 5}; + samples = dynamicPartitionByRoundRobin(length, 0, 2, 12); + ASSERT_EQ(samples.first.size(), 4); + // 
indices which packed into 0-th thread + ASSERT_EQ(samples.first, std::vector({0, 1, 2, 5})); + ASSERT_EQ(samples.second.size(), 2); + // sizes of batches in the 0-th thread + ASSERT_EQ(samples.second, std::vector({3, 1})); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/distributed/AllReduceBenchmark.cpp b/flashlight/fl/test/distributed/AllReduceBenchmark.cpp index ea461c5..d080f6b 100644 --- a/flashlight/fl/test/distributed/AllReduceBenchmark.cpp +++ b/flashlight/fl/test/distributed/AllReduceBenchmark.cpp @@ -20,50 +20,51 @@ using namespace fl; int main() { - fl::init(); - distributedInit( - DistributedInit::MPI, - -1, - -1, - {{DistributedConstants::kMaxDevicePerNode, "8"}}); + fl::init(); + distributedInit( + DistributedInit::MPI, + -1, + -1, + {{DistributedConstants::kMaxDevicePerNode, "8"}} + ); - auto wRank = getWorldRank(); - auto wSize = getWorldSize(); + auto wRank = getWorldRank(); + auto wSize = getWorldSize(); - if (wRank == 0) { - std::cout << "Running allreduce on " << wSize << " machines" << std::endl; - } - - const int kNumIters = 10000; - std::vector sizes = {1, 2, 5}; - int64_t multiplier = 10; - int64_t maxSize = 5000000, curMaxSize = 0; - std::vector times(kNumIters); - while (true) { - for (auto& size : sizes) { - for (size_t i = 0; i < kNumIters; ++i) { - Tensor in = fl::rand({size}); - fl::eval(in); - fl::sync(); - auto start = fl::Timer::start(); - allReduce(in); - fl::sync(); - times[i] = fl::Timer::stop(start); - } - auto timesAf = Tensor::fromVector({kNumIters}, times); - if (wRank == 0) { - std::cout << "Size: " << size - << " ; avg: " << fl::mean(timesAf).asScalar() * 1000 - << "ms ; p50: " - << fl::median(timesAf).asScalar() * 1000 << "ms" - << std::endl; - } - curMaxSize = std::max(curMaxSize, size); - size *= multiplier; + if(wRank == 0) { + std::cout 
<< "Running allreduce on " << wSize << " machines" << std::endl; } - if (curMaxSize >= maxSize) { - break; + + const int kNumIters = 10000; + std::vector sizes = {1, 2, 5}; + int64_t multiplier = 10; + int64_t maxSize = 5000000, curMaxSize = 0; + std::vector times(kNumIters); + while(true) { + for(auto& size : sizes) { + for(size_t i = 0; i < kNumIters; ++i) { + Tensor in = fl::rand({size}); + fl::eval(in); + fl::sync(); + auto start = fl::Timer::start(); + allReduce(in); + fl::sync(); + times[i] = fl::Timer::stop(start); + } + auto timesAf = Tensor::fromVector({kNumIters}, times); + if(wRank == 0) { + std::cout << "Size: " << size + << " ; avg: " << fl::mean(timesAf).asScalar() * 1000 + << "ms ; p50: " + << fl::median(timesAf).asScalar() * 1000 << "ms" + << std::endl; + } + curMaxSize = std::max(curMaxSize, size); + size *= multiplier; + } + if(curMaxSize >= maxSize) { + break; + } } - } - return 0; + return 0; } diff --git a/flashlight/fl/test/distributed/AllReduceTest.cpp b/flashlight/fl/test/distributed/AllReduceTest.cpp index 7f11a08..ed05cf8 100644 --- a/flashlight/fl/test/distributed/AllReduceTest.cpp +++ b/flashlight/fl/test/distributed/AllReduceTest.cpp @@ -21,183 +21,187 @@ using namespace fl; TEST(Distributed, AllReduce) { - if (!isDistributedInit()) { - GTEST_SKIP() << "Distributed initialization failed or not enabled."; - } + if(!isDistributedInit()) { + GTEST_SKIP() << "Distributed initialization failed or not enabled."; + } - auto rank = getWorldRank(); - auto size = getWorldSize(); + auto rank = getWorldRank(); + auto size = getWorldSize(); - Variable var(fl::full({10}, rank, dtype::f32), false); + Variable var(fl::full({10}, rank, dtype::f32), false); - allReduce(var, 2.0); + allReduce(var, 2.0); - float expected_val = size * (size - 1.0); - ASSERT_TRUE(fl::all(var.tensor() == expected_val).scalar()); + float expected_val = size * (size - 1.0); + ASSERT_TRUE(fl::all(var.tensor() == expected_val).scalar()); } TEST(Distributed, InlineReducer) { - if 
(!isDistributedInit()) { - GTEST_SKIP() << "Distributed initialization failed or not enabled."; - } + if(!isDistributedInit()) { + GTEST_SKIP() << "Distributed initialization failed or not enabled."; + } - auto rank = getWorldRank(); - auto size = getWorldSize(); + auto rank = getWorldRank(); + auto size = getWorldSize(); - Variable var(fl::full({10}, rank, dtype::f32), false); + Variable var(fl::full({10}, rank, dtype::f32), false); - auto reducer = std::make_shared(1.0 / size); - reducer->add(var); + auto reducer = std::make_shared(1.0 / size); + reducer->add(var); - // The reducer scales down by a factor of 1 / size - auto arr = var.tensor() * (size * 2); + // The reducer scales down by a factor of 1 / size + auto arr = var.tensor() * (size * 2); - float expected_val = size * (size - 1.0); - ASSERT_TRUE(fl::all(arr == expected_val).scalar()); + float expected_val = size * (size - 1.0); + ASSERT_TRUE(fl::all(arr == expected_val).scalar()); } TEST(Distributed, AllReduceAsync) { - if (!isDistributedInit()) { - GTEST_SKIP() << "Distributed initialization failed or not enabled."; - } + if(!isDistributedInit()) { + GTEST_SKIP() << "Distributed initialization failed or not enabled."; + } - auto rank = getWorldRank(); - auto size = getWorldSize(); - // not supported for the CPU backend - bool async = true && !FL_BACKEND_CPU; + auto rank = getWorldRank(); + auto size = getWorldSize(); + // not supported for the CPU backend + bool async = true && !FL_BACKEND_CPU; - Variable var(fl::full({10}, rank, dtype::f32), false); + Variable var(fl::full({10}, rank, dtype::f32), false); - allReduce(var, 2.0, async); - syncDistributed(); + allReduce(var, 2.0, async); + syncDistributed(); - float expected_val = size * (size - 1.0); - ASSERT_TRUE(fl::all(var.tensor() == expected_val).scalar()); + float expected_val = size * (size - 1.0); + ASSERT_TRUE(fl::all(var.tensor() == expected_val).scalar()); } TEST(Distributed, AllReduceSetAsync) { - if (!isDistributedInit()) { - GTEST_SKIP() << 
"Distributed initialization failed or not enabled."; - } - - auto rank = getWorldRank(); - auto size = getWorldSize(); - // not supported for the CPU backend - bool async = true && !FL_BACKEND_CPU; - bool contiguous = true && !FL_BACKEND_CPU; - - unsigned vSize = (1 << 20); - std::vector vars; - for (size_t i = 0; i < 5; ++i) { - vars.emplace_back(fl::full({vSize}, rank + 1, dtype::f32), false); - } - - allReduceMultiple(vars, 2.0, async, contiguous); - syncDistributed(); - - float expected_val = size * (size + 1.0); - for (const auto& var : vars) { - ASSERT_TRUE(fl::all(var.tensor() == expected_val).scalar()); - } - - // Exceed the size of the contiguous buffer without caching, and trigger a - // contiguous sync with a tensor that is too large - for (size_t i = 0; i < 25; ++i) { - vars.emplace_back(fl::full({vSize}, rank, dtype::f32), false); - } - if (size > 1) { - ASSERT_THROW( - allReduceMultiple(vars, 2.0, /*async=*/true, /*contiguous=*/true), - std::runtime_error); - } + if(!isDistributedInit()) { + GTEST_SKIP() << "Distributed initialization failed or not enabled."; + } + + auto rank = getWorldRank(); + auto size = getWorldSize(); + // not supported for the CPU backend + bool async = true && !FL_BACKEND_CPU; + bool contiguous = true && !FL_BACKEND_CPU; + + unsigned vSize = (1 << 20); + std::vector vars; + for(size_t i = 0; i < 5; ++i) { + vars.emplace_back(fl::full({vSize}, rank + 1, dtype::f32), false); + } + + allReduceMultiple(vars, 2.0, async, contiguous); + syncDistributed(); + + float expected_val = size * (size + 1.0); + for(const auto& var : vars) { + ASSERT_TRUE(fl::all(var.tensor() == expected_val).scalar()); + } + + // Exceed the size of the contiguous buffer without caching, and trigger a + // contiguous sync with a tensor that is too large + for(size_t i = 0; i < 25; ++i) { + vars.emplace_back(fl::full({vSize}, rank, dtype::f32), false); + } + if(size > 1) { + ASSERT_THROW( + allReduceMultiple(vars, 2.0, /*async=*/ true, /*contiguous=*/ true), + 
std::runtime_error + ); + } } TEST(Distributed, Barrier) { - auto rank = getWorldRank(); - auto size = getWorldSize(); - auto suffix = "_distributed_barrier_test.txt"; - - // Create files - std::this_thread::sleep_for(std::chrono::milliseconds(5000 * rank)); - auto file = fs::temp_directory_path() / (std::to_string(rank) + suffix); - std::ofstream stream(file); - stream << "done"; - stream.close(); - - barrier(); - for (int i = 0; i < size; i++) { - auto checkingFile = - fs::temp_directory_path() / (std::to_string(i) + suffix); - ASSERT_TRUE(fs::exists(checkingFile)); - } - barrier(); - - // Delete files - std::error_code errorCode; - const bool status = fs::remove(file, errorCode); - if (!status) { - throw std::runtime_error( - "Barrier test cannot delete file: " + std::string(file) + - " error: " + errorCode.message()); - } - barrier(); - for (int i = 0; i < size; i++) { - auto checkingFile = - fs::temp_directory_path() / (std::to_string(i) + suffix); - ASSERT_TRUE(!fs::exists(checkingFile)); - } + auto rank = getWorldRank(); + auto size = getWorldSize(); + auto suffix = "_distributed_barrier_test.txt"; + + // Create files + std::this_thread::sleep_for(std::chrono::milliseconds(5000 * rank)); + auto file = fs::temp_directory_path() / (std::to_string(rank) + suffix); + std::ofstream stream(file); + stream << "done"; + stream.close(); + + barrier(); + for(int i = 0; i < size; i++) { + auto checkingFile = + fs::temp_directory_path() / (std::to_string(i) + suffix); + ASSERT_TRUE(fs::exists(checkingFile)); + } + barrier(); + + // Delete files + std::error_code errorCode; + const bool status = fs::remove(file, errorCode); + if(!status) { + throw std::runtime_error( + "Barrier test cannot delete file: " + std::string(file) + + " error: " + errorCode.message() + ); + } + barrier(); + for(int i = 0; i < size; i++) { + auto checkingFile = + fs::temp_directory_path() / (std::to_string(i) + suffix); + ASSERT_TRUE(!fs::exists(checkingFile)); + } } TEST(Distributed, 
CoalescingReducer) { - if (!isDistributedInit()) { - GTEST_SKIP() << "Distributed initialization failed or not enabled."; - } + if(!isDistributedInit()) { + GTEST_SKIP() << "Distributed initialization failed or not enabled."; + } - auto rank = getWorldRank(); - auto size = getWorldSize(); + auto rank = getWorldRank(); + auto size = getWorldSize(); - auto s = std::make_shared( - /* scale = */ 1.0 / size, - /*async=*/true && !FL_BACKEND_CPU, - /*contiguous=*/true && !FL_BACKEND_CPU); + auto s = std::make_shared( + /* scale = */ 1.0 / size, + /*async=*/ true && !FL_BACKEND_CPU, + /*contiguous=*/ true && !FL_BACKEND_CPU + ); - unsigned vSize = (1 << 20); - std::vector vars; - for (size_t i = 0; i < 1000; ++i) { - vars.emplace_back(fl::full({vSize}, rank + 1, dtype::f32), false); - } + unsigned vSize = (1 << 20); + std::vector vars; + for(size_t i = 0; i < 1000; ++i) { + vars.emplace_back(fl::full({vSize}, rank + 1, dtype::f32), false); + } - for (size_t i = 0; i < vars.size(); ++i) { - s->add(vars[i]); - if ((i + 1) % 10 == 0) { - s->finalize(); + for(size_t i = 0; i < vars.size(); ++i) { + s->add(vars[i]); + if((i + 1) % 10 == 0) { + s->finalize(); + } } - } - float expected_val = size * (size + 1.0); - for (const auto& var : vars) { - // The reducer scales down by a factor of 1 / size - auto arr = var.tensor() * (size * 2); - ASSERT_TRUE(fl::all(arr == expected_val).scalar()); - } + float expected_val = size * (size + 1.0); + for(const auto& var : vars) { + // The reducer scales down by a factor of 1 / size + auto arr = var.tensor() * (size * 2); + ASSERT_TRUE(fl::all(arr == expected_val).scalar()); + } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - - try { - distributedInit( - DistributedInit::MPI, - -1, - -1, - {{DistributedConstants::kMaxDevicePerNode, "8"}}); - } catch (const std::exception& ex) { - // Don't run the test if distributed initialization fails - std::cerr + ::testing::InitGoogleTest(&argc, argv); + 
fl::init(); + + try { + distributedInit( + DistributedInit::MPI, + -1, + -1, + {{DistributedConstants::kMaxDevicePerNode, "8"}} + ); + } catch(const std::exception& ex) { + // Don't run the test if distributed initialization fails + std::cerr << "Distributed initialization failed; tests will be skipped. Reason: " << ex.what() << std::endl; - } + } - return RUN_ALL_TESTS(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/meter/MeterTest.cpp b/flashlight/fl/test/meter/MeterTest.cpp index 4ed61e4..ec027c7 100644 --- a/flashlight/fl/test/meter/MeterTest.cpp +++ b/flashlight/fl/test/meter/MeterTest.cpp @@ -17,84 +17,90 @@ using namespace fl; TEST(MeterTest, EditDistanceMeter) { - EditDistanceMeter meter; - std::vector a = {1, 2, 3, 4, 5}; - std::vector b = {1, 1, 3, 3, 5, 6}; - meter.add(Tensor::fromVector(a), Tensor::fromVector(b)); - ASSERT_EQ(meter.errorRate()[0], 50.0); // 3 / 6 - ASSERT_EQ(meter.value()[0], 3); // 3 / 6 - ASSERT_LT( - std::fabs(16.6666667 - meter.errorRate()[2]), 1e-5); // deletion = 1 / 6 - ASSERT_EQ(meter.value()[2], 1); - ASSERT_EQ(meter.errorRate()[3], 0.0); // insertion error - ASSERT_EQ(meter.value()[3], 0); - ASSERT_LT( - std::fabs(33.3333333 - meter.errorRate()[4]), - 1e-5); // substitution error = 2 / 6 - ASSERT_EQ(meter.value()[4], 2); - // TODO{fl::Tensor}{check} - meter.add( - Tensor::fromBuffer({3}, a.data() + 1, MemoryLocation::Host), - Tensor::fromVector({3}, b)); - ASSERT_LT(std::fabs(66.666666 - meter.errorRate()[0]), 1e-5); // 3 + 3 / 6 + 3 - ASSERT_EQ(meter.value()[0], 6); + EditDistanceMeter meter; + std::vector a = {1, 2, 3, 4, 5}; + std::vector b = {1, 1, 3, 3, 5, 6}; + meter.add(Tensor::fromVector(a), Tensor::fromVector(b)); + ASSERT_EQ(meter.errorRate()[0], 50.0); // 3 / 6 + ASSERT_EQ(meter.value()[0], 3); // 3 / 6 + ASSERT_LT( + std::fabs(16.6666667 - meter.errorRate()[2]), + 1e-5 + ); // deletion = 1 / 6 + ASSERT_EQ(meter.value()[2], 1); + ASSERT_EQ(meter.errorRate()[3], 0.0); // insertion error + 
ASSERT_EQ(meter.value()[3], 0); + ASSERT_LT( + std::fabs(33.3333333 - meter.errorRate()[4]), + 1e-5 + ); // substitution error = 2 / 6 + ASSERT_EQ(meter.value()[4], 2); + // TODO{fl::Tensor}{check} + meter.add( + Tensor::fromBuffer({3}, a.data() + 1, MemoryLocation::Host), + Tensor::fromVector({3}, b) + ); + ASSERT_LT(std::fabs(66.666666 - meter.errorRate()[0]), 1e-5); // 3 + 3 / 6 + 3 + ASSERT_EQ(meter.value()[0], 6); } TEST(MeterTest, FrameErrorMeter) { - FrameErrorMeter meter; - std::vector a = {1, 2, 3, 4, 5}; - std::vector b = {1, 1, 3, 3, 5, 6}; - meter.add(Tensor::fromVector(a), Tensor::fromVector({5}, b)); - ASSERT_EQ(meter.value(), 40.0); // 2 / 5 - // TODO{fl::Tensor}{check} - meter.add( - Tensor::fromBuffer({4}, a.data() + 1, MemoryLocation::Host), - Tensor::fromBuffer({4}, b.data() + 2, MemoryLocation::Host)); - ASSERT_LT(std::fabs(55.5555555 - meter.value()), 1e-5); // 2 + 3 / 5 + 4 + FrameErrorMeter meter; + std::vector a = {1, 2, 3, 4, 5}; + std::vector b = {1, 1, 3, 3, 5, 6}; + meter.add(Tensor::fromVector(a), Tensor::fromVector({5}, b)); + ASSERT_EQ(meter.value(), 40.0); // 2 / 5 + // TODO{fl::Tensor}{check} + meter.add( + Tensor::fromBuffer({4}, a.data() + 1, MemoryLocation::Host), + Tensor::fromBuffer({4}, b.data() + 2, MemoryLocation::Host) + ); + ASSERT_LT(std::fabs(55.5555555 - meter.value()), 1e-5); // 2 + 3 / 5 + 4 } TEST(MeterTest, AverageValueMeter) { - AverageValueMeter meter; - meter.add(1.0, 0.0); - meter.add(2.0); - meter.add(3.0); - meter.add(4.0); - auto val = meter.value(); - ASSERT_EQ(val[0], 3.0); - ASSERT_NEAR(val[1], 1.0, 1e-10); - ASSERT_EQ(val[2], 3.0); + AverageValueMeter meter; + meter.add(1.0, 0.0); + meter.add(2.0); + meter.add(3.0); + meter.add(4.0); + auto val = meter.value(); + ASSERT_EQ(val[0], 3.0); + ASSERT_NEAR(val[1], 1.0, 1e-10); + ASSERT_EQ(val[2], 3.0); - std::vector a = {2.0, 3.0, 4.0}; - meter.add(Tensor::fromVector(a)); - val = meter.value(); - ASSERT_EQ(val[0], 3.0); - ASSERT_NEAR(val[1], 0.8, 1e-10); - 
ASSERT_EQ(val[2], 6.0); + std::vector a = {2.0, 3.0, 4.0}; + meter.add(Tensor::fromVector(a)); + val = meter.value(); + ASSERT_EQ(val[0], 3.0); + ASSERT_NEAR(val[1], 0.8, 1e-10); + ASSERT_EQ(val[2], 6.0); } TEST(MeterTest, MSEMeter) { - MSEMeter meter; - std::vector b = {4, 5, 6, 7, 8}; - meter.add( - Tensor::fromVector({1, 2, 3, 4, 5}), - Tensor::fromVector({4, 5, 6, 7, 8})); - auto val = meter.value(); - ASSERT_EQ(val, 45.0); + MSEMeter meter; + std::vector b = {4, 5, 6, 7, 8}; + meter.add( + Tensor::fromVector({1, 2, 3, 4, 5}), + Tensor::fromVector({4, 5, 6, 7, 8}) + ); + auto val = meter.value(); + ASSERT_EQ(val, 45.0); } TEST(MeterTest, CountMeter) { - CountMeter meter{3}; - meter.add(0, 10); - meter.add(1, 11); - meter.add(0, 12); - auto val = meter.value(); - ASSERT_EQ(val[0], 22); - ASSERT_EQ(val[1], 11); - ASSERT_EQ(val[2], 0); + CountMeter meter{3}; + meter.add(0, 10); + meter.add(1, 11); + meter.add(0, 12); + auto val = meter.value(); + ASSERT_EQ(val[0], 22); + ASSERT_EQ(val[1], 11); + ASSERT_EQ(val[2], 0); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/nn/ModuleTest.cpp b/flashlight/fl/test/nn/ModuleTest.cpp index 340d1b6..ee78fe3 100644 --- a/flashlight/fl/test/nn/ModuleTest.cpp +++ b/flashlight/fl/test/nn/ModuleTest.cpp @@ -21,966 +21,1071 @@ using namespace fl; namespace { class ContainerTestClass : public Sequential { - public: - ContainerTestClass() = default; - ContainerTestClass(const ContainerTestClass& other) { - copy(other); - } - ContainerTestClass& operator=(const ContainerTestClass& other) { - copy(other); - return *this; - } - ContainerTestClass(ContainerTestClass&& other) = default; - ContainerTestClass& operator=(ContainerTestClass&& other) = default; - void copy(const ContainerTestClass& other) { - auto orphanParamIdxMap = 
other.getOrphanedParamsIdxMap(); - for (int i = -1; i < static_cast(other.modules_.size()); ++i) { - if (i >= 0) { - add(other.modules_[i]->clone()); - } - auto [paramIter, pEnd] = orphanParamIdxMap.equal_range(i); - for (; paramIter != pEnd; ++paramIter) { - const auto& param = other.params_[paramIter->second]; - params_.emplace_back(param.copy()); - } - } - } - - std::unique_ptr clone() const override { - return std::make_unique(*this); - } - - void addParam(const Variable& param) { - params_.push_back(param); - } +public: + ContainerTestClass() = default; + ContainerTestClass(const ContainerTestClass& other) { + copy(other); + } + ContainerTestClass& operator=(const ContainerTestClass& other) { + copy(other); + return *this; + } + ContainerTestClass(ContainerTestClass&& other) = default; + ContainerTestClass& operator=(ContainerTestClass&& other) = default; + void copy(const ContainerTestClass& other) { + auto orphanParamIdxMap = other.getOrphanedParamsIdxMap(); + for(int i = -1; i < static_cast(other.modules_.size()); ++i) { + if(i >= 0) { + add(other.modules_[i]->clone()); + } + auto [paramIter, pEnd] = orphanParamIdxMap.equal_range(i); + for(; paramIter != pEnd; ++paramIter) { + const auto& param = other.params_[paramIter->second]; + params_.emplace_back(param.copy()); + } + } + } + + std::unique_ptr clone() const override { + return std::make_unique(*this); + } + + void addParam(const Variable& param) { + params_.push_back(param); + } }; class ModuleTestF16 : public ::testing::Test { - protected: - void SetUp() override { - // Ensures all operations will be in f16 - OptimMode::get().setOptimLevel(OptimLevel::O3); - } - - void TearDown() override { - OptimMode::get().setOptimLevel(OptimLevel::DEFAULT); - } +protected: + void SetUp() override { + // Ensures all operations will be in f16 + OptimMode::get().setOptimLevel(OptimLevel::O3); + } + + void TearDown() override { + OptimMode::get().setOptimLevel(OptimLevel::DEFAULT); + } }; } // namespace 
TEST(ModuleTest, EmbeddingFwd) { - int embDim = 3, nEmb = 5, nQuery = 2, batchSize = 2; - auto wtVar = param(Tensor::fromVector( - {embDim, nEmb}, {8, 2, 2, 10, 5, 3, 3, 4, 6, 12, 3, 8, 0, 5, 2})); - - auto inVar = input(Tensor::fromVector({2, batchSize}, {1, 3, 0, 0})); - - auto expectedOutVar = Variable( - Tensor::fromVector( - {embDim, nQuery, batchSize}, {10, 5, 3, 12, 3, 8, 8, 2, 2, 8, 2, 2}), - true); - - // Var initialization - auto emb = Embedding(wtVar); - ASSERT_TRUE(allClose(emb.forward(inVar), expectedOutVar, 1E-7)); - - // Regular initialization - emb = Embedding(embDim, nEmb); - wtVar = emb.param(0); - ASSERT_EQ(wtVar.shape(), Shape({embDim, nEmb})); - - expectedOutVar = Variable( - fl::reshape( - wtVar.tensor()(fl::span, inVar.tensor()), - {embDim, nQuery, batchSize}), - true); - ASSERT_TRUE(allClose(emb.forward(inVar), expectedOutVar, 1E-7)); + int embDim = 3, nEmb = 5, nQuery = 2, batchSize = 2; + auto wtVar = param( + Tensor::fromVector( + {embDim, nEmb}, + {8, 2, 2, 10, 5, 3, 3, 4, 6, 12, 3, 8, 0, 5, 2} + ) + ); + + auto inVar = input(Tensor::fromVector({2, batchSize}, {1, 3, 0, 0})); + + auto expectedOutVar = Variable( + Tensor::fromVector( + {embDim, nQuery, batchSize}, + {10, 5, 3, 12, 3, 8, 8, 2, 2, 8, 2, 2} + ), + true + ); + + // Var initialization + auto emb = Embedding(wtVar); + ASSERT_TRUE(allClose(emb.forward(inVar), expectedOutVar, 1E-7)); + + // Regular initialization + emb = Embedding(embDim, nEmb); + wtVar = emb.param(0); + ASSERT_EQ(wtVar.shape(), Shape({embDim, nEmb})); + + expectedOutVar = Variable( + fl::reshape( + wtVar.tensor()(fl::span, inVar.tensor()), + {embDim, nQuery, batchSize}), + true); + ASSERT_TRUE(allClose(emb.forward(inVar), expectedOutVar, 1E-7)); } TEST(ModuleTest, LinearFwd) { - int n_in = 2, n_out = 3, x = 4, batchsize = 2; - auto wtVar = - param(Tensor::fromVector({n_out, n_in}, {8, 2, 2, 10, 5, 3})); - - auto inVar = input(Tensor::fromVector( - {n_in, x, batchsize}, {6, 2, 1, 4, 8, 2, 7, 1, 10, 7, 3, 7, 5, 9, 
2, 4})); - - auto expected_outVar = Variable( - Tensor::fromVector( - {n_out, x, batchsize}, - {68, 22, 18, 48, 22, 14, 84, 26, 22, 66, 19, 17, - 150, 55, 41, 94, 41, 27, 130, 55, 37, 56, 24, 16}), - true); - - auto linNoBias = Linear(wtVar); - ASSERT_TRUE(allClose(linNoBias.forward(inVar), expected_outVar, 1E-7)); - - auto bsVar = input(Tensor::fromVector({n_out}, {1, 2, 3})); - expected_outVar = Variable( - Tensor::fromVector( - {n_out, x, batchsize}, - {69, 24, 21, 49, 24, 17, 85, 28, 25, 67, 21, 20, - 151, 57, 44, 95, 43, 30, 131, 57, 40, 57, 26, 19}), - true); - - auto linBias = Linear(wtVar, bsVar); - ASSERT_TRUE(allClose(linBias.forward(inVar), expected_outVar, 1E-7)); + int n_in = 2, n_out = 3, x = 4, batchsize = 2; + auto wtVar = + param(Tensor::fromVector({n_out, n_in}, {8, 2, 2, 10, 5, 3})); + + auto inVar = input( + Tensor::fromVector( + {n_in, x, batchsize}, + {6, 2, 1, 4, 8, 2, 7, 1, 10, 7, 3, 7, 5, 9, 2, 4} + ) + ); + + auto expected_outVar = Variable( + Tensor::fromVector( + {n_out, x, batchsize}, + {68, 22, 18, 48, 22, 14, 84, 26, 22, 66, 19, 17, + 150, 55, 41, 94, 41, 27, 130, 55, 37, 56, 24, 16} + ), + true + ); + + auto linNoBias = Linear(wtVar); + ASSERT_TRUE(allClose(linNoBias.forward(inVar), expected_outVar, 1E-7)); + + auto bsVar = input(Tensor::fromVector({n_out}, {1, 2, 3})); + expected_outVar = Variable( + Tensor::fromVector( + {n_out, x, batchsize}, + {69, 24, 21, 49, 24, 17, 85, 28, 25, 67, 21, 20, + 151, 57, 44, 95, 43, 30, 131, 57, 40, 57, 26, 19} + ), + true + ); + + auto linBias = Linear(wtVar, bsVar); + ASSERT_TRUE(allClose(linBias.forward(inVar), expected_outVar, 1E-7)); } TEST_F(ModuleTestF16, LinearFwdF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - int n_in = 2, n_out = 3, x = 4, batchsize = 2; - auto wtVar = - param(Tensor::fromVector({n_out, n_in}, {8, 2, 2, 10, 5, 3})); - - auto inVar = input(Tensor::fromVector( - {n_in, x, batchsize}, - {6, 2, 1, 4, 8, 2, 7, 1, 10, 
7, 3, 7, 5, 9, 2, 4}) - .astype(fl::dtype::f16)); - - auto expected_outVar = Variable( - Tensor::fromVector( - {n_out, x, batchsize}, - {68, 22, 18, 48, 22, 14, 84, 26, 22, 66, 19, 17, - 150, 55, 41, 94, 41, 27, 130, 55, 37, 56, 24, 16}) - .astype(fl::dtype::f16), - true); - - auto linNoBias = Linear(wtVar); - auto result = linNoBias.forward(inVar); - ASSERT_EQ(result.type(), inVar.type()); - ASSERT_TRUE(allClose(result, expected_outVar, 1E-2)); - - auto bsVar = input(Tensor::fromVector({n_out}, {1, 2, 3})); - ; - expected_outVar = Variable( - Tensor::fromVector( - {n_out, x, batchsize}, - {69, 24, 21, 49, 24, 17, 85, 28, 25, 67, 21, 20, - 151, 57, 44, 95, 43, 30, 131, 57, 40, 57, 26, 19}) - .astype(inVar.type()), - true); - - auto linBias = Linear(wtVar, bsVar); - auto resultBias = linBias.forward(inVar); - ASSERT_EQ(resultBias.type(), fl::dtype::f16); - ASSERT_TRUE(allClose(resultBias, expected_outVar, 1E-3)); - - // OptimLevel::O3 is active with this fixture - ASSERT_EQ(linBias.forward(inVar.astype(fl::dtype::f32)).type(), fl::dtype::f16); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + + int n_in = 2, n_out = 3, x = 4, batchsize = 2; + auto wtVar = + param(Tensor::fromVector({n_out, n_in}, {8, 2, 2, 10, 5, 3})); + + auto inVar = input( + Tensor::fromVector( + {n_in, x, batchsize}, + {6, 2, 1, 4, 8, 2, 7, 1, 10, 7, 3, 7, 5, 9, 2, 4} + ) + .astype(fl::dtype::f16) + ); + + auto expected_outVar = Variable( + Tensor::fromVector( + {n_out, x, batchsize}, + {68, 22, 18, 48, 22, 14, 84, 26, 22, 66, 19, 17, + 150, 55, 41, 94, 41, 27, 130, 55, 37, 56, 24, 16} + ) + .astype(fl::dtype::f16), + true + ); + + auto linNoBias = Linear(wtVar); + auto result = linNoBias.forward(inVar); + ASSERT_EQ(result.type(), inVar.type()); + ASSERT_TRUE(allClose(result, expected_outVar, 1E-2)); + + auto bsVar = input(Tensor::fromVector({n_out}, {1, 2, 3})); + ; + expected_outVar = Variable( + Tensor::fromVector( + {n_out, x, batchsize}, + 
{69, 24, 21, 49, 24, 17, 85, 28, 25, 67, 21, 20, + 151, 57, 44, 95, 43, 30, 131, 57, 40, 57, 26, 19} + ) + .astype(inVar.type()), + true + ); + + auto linBias = Linear(wtVar, bsVar); + auto resultBias = linBias.forward(inVar); + ASSERT_EQ(resultBias.type(), fl::dtype::f16); + ASSERT_TRUE(allClose(resultBias, expected_outVar, 1E-3)); + + // OptimLevel::O3 is active with this fixture + ASSERT_EQ(linBias.forward(inVar.astype(fl::dtype::f32)).type(), fl::dtype::f16); } TEST(ModuleTest, ConvPadding) { - auto conv1 = Conv2D(30, 100, 3, 5, 2, 1, PaddingMode::SAME, 0, true, 1); - auto conv2 = Conv2D( - 30, 100, 3, 5, 2, 1, PaddingMode::SAME, PaddingMode::SAME, true, 1); - auto conv3 = - Conv2D(30, 100, 10, 10, 1, 1, PaddingMode::SAME, PaddingMode::SAME, 4, 4); - auto input = Variable(fl::rand({32, 32, 30, 2}), false); - - auto conv1Op = conv1(input); - ASSERT_EQ(conv1Op.shape(), Shape({16, 28, 100, 2})); - - auto conv2Op = conv2(input); - ASSERT_EQ(conv2Op.shape(), Shape({16, 32, 100, 2})); - - // test dilation - auto conv3Op = conv3(input); - ASSERT_EQ(conv3Op.shape(), Shape({32, 32, 100, 2})); + auto conv1 = Conv2D(30, 100, 3, 5, 2, 1, PaddingMode::SAME, 0, true, 1); + auto conv2 = Conv2D( + 30, + 100, + 3, + 5, + 2, + 1, + PaddingMode::SAME, + PaddingMode::SAME, + true, + 1 + ); + auto conv3 = + Conv2D(30, 100, 10, 10, 1, 1, PaddingMode::SAME, PaddingMode::SAME, 4, 4); + auto input = Variable(fl::rand({32, 32, 30, 2}), false); + + auto conv1Op = conv1(input); + ASSERT_EQ(conv1Op.shape(), Shape({16, 28, 100, 2})); + + auto conv2Op = conv2(input); + ASSERT_EQ(conv2Op.shape(), Shape({16, 32, 100, 2})); + + // test dilation + auto conv3Op = conv3(input); + ASSERT_EQ(conv3Op.shape(), Shape({32, 32, 100, 2})); } TEST(ModuleTest, GLUFwd) { - auto inVar = Variable( - Tensor::fromVector({3, 2}, {0.8, 0.2, 0.2, 0.1, 0.5, 0.3}), true); - - auto expected_outVar = Variable( - Tensor::fromVector({3, 1}, {0.419983, 0.124492, 0.114888}), true); - - GatedLinearUnit glu(1); - 
ASSERT_TRUE(allClose(glu.forward(inVar), expected_outVar, 1E-4)); - - // test batching - int batchsize = 5; - inVar = Variable(fl::rand({10, 7, batchsize}), true); - glu = GatedLinearUnit(0); - - auto batchOutVar = glu(inVar); - - for (int i = 0; i < batchsize; ++i) { - expected_outVar = glu.forward(inVar(fl::span, fl::span, fl::range(i, i + 1))); - ASSERT_TRUE(allClose( - batchOutVar.tensor()(fl::span, fl::span, fl::range(i, i + 1)), - expected_outVar.tensor(), - 1E-7)); - } + auto inVar = Variable( + Tensor::fromVector({3, 2}, {0.8, 0.2, 0.2, 0.1, 0.5, 0.3}), + true + ); + + auto expected_outVar = Variable( + Tensor::fromVector({3, 1}, {0.419983, 0.124492, 0.114888}), + true + ); + + GatedLinearUnit glu(1); + ASSERT_TRUE(allClose(glu.forward(inVar), expected_outVar, 1E-4)); + + // test batching + int batchsize = 5; + inVar = Variable(fl::rand({10, 7, batchsize}), true); + glu = GatedLinearUnit(0); + + auto batchOutVar = glu(inVar); + + for(int i = 0; i < batchsize; ++i) { + expected_outVar = glu.forward(inVar(fl::span, fl::span, fl::range(i, i + 1))); + ASSERT_TRUE(allClose( + batchOutVar.tensor()(fl::span, fl::span, fl::range(i, i + 1)), + expected_outVar.tensor(), + 1E-7)); + } } TEST_F(ModuleTestF16, GLUFwdF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - auto inVar = Variable( - Tensor::fromVector({3, 2}, {0.8, 0.2, 0.2, 0.1, 0.5, 0.3}) - .astype(fl::dtype::f16), - true); - - auto expected_outVar = Variable( - Tensor::fromVector({3, 1}, {0.419983, 0.124492, 0.114888}) - .astype(fl::dtype::f16), - true); - - GatedLinearUnit glu(1); - auto out = glu.forward(inVar); - ASSERT_EQ(out.type(), inVar.type()); - ASSERT_TRUE(allClose(out, expected_outVar, 1E-2)); - - // test batching - int batchsize = 5; - inVar = Variable(fl::rand({10, 7, batchsize}).astype(fl::dtype::f16), true); - glu = GatedLinearUnit(0); - - auto batchOutVar = glu(inVar); - - for (int i = 0; i < batchsize; ++i) { - expected_outVar = 
glu.forward(inVar(fl::span, fl::span, fl::range(i, i + 1))); - ASSERT_EQ(batchOutVar.type(), expected_outVar.type()); - ASSERT_TRUE(allClose( - batchOutVar.tensor()(fl::span, fl::span, fl::range(i, i + 1)), - expected_outVar.tensor(), - 1E-3)); - } + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + + auto inVar = Variable( + Tensor::fromVector({3, 2}, {0.8, 0.2, 0.2, 0.1, 0.5, 0.3}) + .astype(fl::dtype::f16), + true + ); + + auto expected_outVar = Variable( + Tensor::fromVector({3, 1}, {0.419983, 0.124492, 0.114888}) + .astype(fl::dtype::f16), + true + ); + + GatedLinearUnit glu(1); + auto out = glu.forward(inVar); + ASSERT_EQ(out.type(), inVar.type()); + ASSERT_TRUE(allClose(out, expected_outVar, 1E-2)); + + // test batching + int batchsize = 5; + inVar = Variable(fl::rand({10, 7, batchsize}).astype(fl::dtype::f16), true); + glu = GatedLinearUnit(0); + + auto batchOutVar = glu(inVar); + + for(int i = 0; i < batchsize; ++i) { + expected_outVar = glu.forward(inVar(fl::span, fl::span, fl::range(i, i + 1))); + ASSERT_EQ(batchOutVar.type(), expected_outVar.type()); + ASSERT_TRUE(allClose( + batchOutVar.tensor()(fl::span, fl::span, fl::range(i, i + 1)), + expected_outVar.tensor(), + 1E-3)); + } } TEST(ModuleTest, LogSoftmaxFwd) { - auto inVar = Variable( - Tensor::fromVector({3, 2}, {0.8, 0.2, 0.2, 0.1, 0.5, 0.3}), true); - - auto expected_outVar0 = Variable( - Tensor::fromVector( - {3, 2}, {-0.740805, -1.34081, -1.34081, -1.3119, -0.911902, -1.1119}), - true); - LogSoftmax lsm0(0); - ASSERT_TRUE(allClose(lsm0.forward(inVar), expected_outVar0, 1E-4)); - - auto expected_outVar1 = Variable( - Tensor::fromVector( - {3, 2}, - {-0.403186, -0.854355, -0.744397, -1.10319, -0.554355, -0.644397}), - true); - LogSoftmax lsm1(1); - ASSERT_TRUE(allClose(lsm1.forward(inVar), expected_outVar1, 1E-4)); - - // test batching - int batchsize = 5; - inVar = Variable(fl::rand({10, 7, batchsize}), true); - LogSoftmax lsm(0); - - auto 
batchOutVar = lsm(inVar); - - for (int i = 0; i < batchsize; ++i) { - auto expected_outVar = - lsm.forward(inVar(fl::span, fl::span, fl::range(i, i + 1))); - ASSERT_TRUE(allClose( - batchOutVar.tensor()(fl::span, fl::span, fl::range(i, i + 1)), - expected_outVar.tensor(), - 1E-7)); - } + auto inVar = Variable( + Tensor::fromVector({3, 2}, {0.8, 0.2, 0.2, 0.1, 0.5, 0.3}), + true + ); + + auto expected_outVar0 = Variable( + Tensor::fromVector( + {3, 2}, + {-0.740805, -1.34081, -1.34081, -1.3119, -0.911902, -1.1119} + ), + true + ); + LogSoftmax lsm0(0); + ASSERT_TRUE(allClose(lsm0.forward(inVar), expected_outVar0, 1E-4)); + + auto expected_outVar1 = Variable( + Tensor::fromVector( + {3, 2}, + {-0.403186, -0.854355, -0.744397, -1.10319, -0.554355, -0.644397} + ), + true + ); + LogSoftmax lsm1(1); + ASSERT_TRUE(allClose(lsm1.forward(inVar), expected_outVar1, 1E-4)); + + // test batching + int batchsize = 5; + inVar = Variable(fl::rand({10, 7, batchsize}), true); + LogSoftmax lsm(0); + + auto batchOutVar = lsm(inVar); + + for(int i = 0; i < batchsize; ++i) { + auto expected_outVar = + lsm.forward(inVar(fl::span, fl::span, fl::range(i, i + 1))); + ASSERT_TRUE(allClose( + batchOutVar.tensor()(fl::span, fl::span, fl::range(i, i + 1)), + expected_outVar.tensor(), + 1E-7)); + } } TEST_F(ModuleTestF16, LogSoftmaxFwdF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - auto inVar = Variable( - Tensor::fromVector({3, 2}, {0.8, 0.2, 0.2, 0.1, 0.5, 0.3}) - .astype(fl::dtype::f16), - true); - - auto expected_outVar0 = Variable( - Tensor::fromVector( - {3, 2}, {-0.740805, -1.34081, -1.34081, -1.3119, -0.911902, -1.1119}), - true); - LogSoftmax lsm0(0); - auto result0 = lsm0.forward(inVar); - ASSERT_TRUE(allClose(result0, expected_outVar0, 1E-3)); - - auto expected_outVar1 = Variable( - Tensor::fromVector( - {3, 2}, - {-0.403186, -0.854355, -0.744397, -1.10319, -0.554355, -0.644397}), - true); - LogSoftmax lsm1(1); - 
ASSERT_TRUE(allClose(lsm1.forward(inVar), expected_outVar1, 1E-3)); - - // test batching - int batchsize = 5; - inVar = Variable(fl::rand({10, 7, batchsize}), true); - LogSoftmax lsm(0); - - auto batchOutVar = lsm(inVar); - - for (int i = 0; i < batchsize; ++i) { - auto expected_outVar = - lsm.forward(inVar(fl::span, fl::span, fl::range(i, i + 1))); - ASSERT_TRUE(allClose( - batchOutVar.tensor()(fl::span, fl::span, fl::range(i, i + 1)), - expected_outVar.tensor(), - 1E-7)); - } + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + + auto inVar = Variable( + Tensor::fromVector({3, 2}, {0.8, 0.2, 0.2, 0.1, 0.5, 0.3}) + .astype(fl::dtype::f16), + true + ); + + auto expected_outVar0 = Variable( + Tensor::fromVector( + {3, 2}, + {-0.740805, -1.34081, -1.34081, -1.3119, -0.911902, -1.1119} + ), + true + ); + LogSoftmax lsm0(0); + auto result0 = lsm0.forward(inVar); + ASSERT_TRUE(allClose(result0, expected_outVar0, 1E-3)); + + auto expected_outVar1 = Variable( + Tensor::fromVector( + {3, 2}, + {-0.403186, -0.854355, -0.744397, -1.10319, -0.554355, -0.644397} + ), + true + ); + LogSoftmax lsm1(1); + ASSERT_TRUE(allClose(lsm1.forward(inVar), expected_outVar1, 1E-3)); + + // test batching + int batchsize = 5; + inVar = Variable(fl::rand({10, 7, batchsize}), true); + LogSoftmax lsm(0); + + auto batchOutVar = lsm(inVar); + + for(int i = 0; i < batchsize; ++i) { + auto expected_outVar = + lsm.forward(inVar(fl::span, fl::span, fl::range(i, i + 1))); + ASSERT_TRUE(allClose( + batchOutVar.tensor()(fl::span, fl::span, fl::range(i, i + 1)), + expected_outVar.tensor(), + 1E-7)); + } } TEST(ModuleTest, ConvolutionFwd) { - // test batching - auto conv = Conv2D(30, 50, 9, 7, 2, 3, 3, 2, 1, 1, true, 1); - int batchsize = 10; - auto input = fl::rand({120, 100, 30, batchsize}, fl::dtype::f32); - auto batchOutVar = conv(Variable(input, false)); - - for (int i = 0; i < batchsize; ++i) { - auto expected_outVar = conv( - Variable(input(fl::span, 
fl::span, fl::span, fl::range(i, i + 1)), false)); - ASSERT_TRUE(allClose( - batchOutVar.tensor()(fl::span, fl::span, fl::span, fl::range(i, i + 1)), - expected_outVar.tensor(), - 1E-5)); - } + // test batching + auto conv = Conv2D(30, 50, 9, 7, 2, 3, 3, 2, 1, 1, true, 1); + int batchsize = 10; + auto input = fl::rand({120, 100, 30, batchsize}, fl::dtype::f32); + auto batchOutVar = conv(Variable(input, false)); + + for(int i = 0; i < batchsize; ++i) { + auto expected_outVar = conv( + Variable(input(fl::span, fl::span, fl::span, fl::range(i, i + 1)), false) + ); + ASSERT_TRUE(allClose( + batchOutVar.tensor()(fl::span, fl::span, fl::span, fl::range(i, i + 1)), + expected_outVar.tensor(), + 1E-5)); + } } TEST_F(ModuleTestF16, ConvolutionFwdF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - // test batching - auto conv = Conv2D(30, 50, 9, 7, 2, 3, 3, 2, 1, 1, true, 1); - int batchsize = 1; - auto input = fl::rand({120, 100, 30, batchsize}, fl::dtype::f16); - auto batchOutVar = conv(Variable(input, false)); - ASSERT_EQ(batchOutVar.type(), input.type()); - - for (int i = 0; i < batchsize; ++i) { - auto expected_outVar = conv( - Variable(input(fl::span, fl::span, fl::span, fl::range(i, i + 1)), false)); - ASSERT_TRUE(allClose( - batchOutVar.tensor()(fl::span, fl::span, fl::span, fl::range(i, i + 1)), - expected_outVar.tensor(), - 1E-7)); - } - - auto inputF32 = fl::rand({120, 100, 30, batchsize}, fl::dtype::f32); - ASSERT_EQ( - conv(Variable(input, false)).type(), - fl::dtype::f16); // OptimLevel::O3 is active with this fixture + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + + // test batching + auto conv = Conv2D(30, 50, 9, 7, 2, 3, 3, 2, 1, 1, true, 1); + int batchsize = 1; + auto input = fl::rand({120, 100, 30, batchsize}, fl::dtype::f16); + auto batchOutVar = conv(Variable(input, false)); + ASSERT_EQ(batchOutVar.type(), input.type()); + + for(int i = 0; i < 
batchsize; ++i) { + auto expected_outVar = conv( + Variable(input(fl::span, fl::span, fl::span, fl::range(i, i + 1)), false) + ); + ASSERT_TRUE(allClose( + batchOutVar.tensor()(fl::span, fl::span, fl::span, fl::range(i, i + 1)), + expected_outVar.tensor(), + 1E-7)); + } + + auto inputF32 = fl::rand({120, 100, 30, batchsize}, fl::dtype::f32); + ASSERT_EQ( + conv(Variable(input, false)).type(), + fl::dtype::f16 + ); // OptimLevel::O3 is active with this fixture } TEST(ModuleTest, ConvolutionWithGroupFwd) { - // test batching - auto conv = Conv2D(30, 50, 9, 7, 2, 3, 3, 2, true, 2); - int batchsize = 10; - auto input = fl::rand({120, 100, 30, batchsize}); - auto batchOutVar = conv(Variable(input, false)); - for (int i = 0; i < batchsize; ++i) { - auto expected_outVar = conv( - Variable(input(fl::span, fl::span, fl::span, fl::range(i, i + 1)), false)); - ASSERT_TRUE(allClose( - batchOutVar.tensor()(fl::span, fl::span, fl::span, fl::range(i, i + 1)), - expected_outVar.tensor(), - 1E-5)); - } + // test batching + auto conv = Conv2D(30, 50, 9, 7, 2, 3, 3, 2, true, 2); + int batchsize = 10; + auto input = fl::rand({120, 100, 30, batchsize}); + auto batchOutVar = conv(Variable(input, false)); + for(int i = 0; i < batchsize; ++i) { + auto expected_outVar = conv( + Variable(input(fl::span, fl::span, fl::span, fl::range(i, i + 1)), false) + ); + ASSERT_TRUE(allClose( + batchOutVar.tensor()(fl::span, fl::span, fl::span, fl::range(i, i + 1)), + expected_outVar.tensor(), + 1E-5)); + } } TEST(ModuleTest, PoolingFwd) { - // test batching - auto pool = Pool2D(9, 7, 1, 1, PaddingMode::SAME, PaddingMode::SAME); - int batchsize = 10; - auto input = fl::rand({120, 100, 30, batchsize}); - auto batchOutVar = pool(Variable(input, false)); - for (int i = 0; i < batchsize; ++i) { - ASSERT_EQ(input.shape(), batchOutVar.shape()); - auto expected_outVar = pool( - Variable(input(fl::span, fl::span, fl::span, fl::range(i, i + 1)), false)); - ASSERT_TRUE(allClose( - batchOutVar.tensor()(fl::span, 
fl::span, fl::span, fl::range(i, i + 1)), - expected_outVar.tensor(), - 1E-7)); - } + // test batching + auto pool = Pool2D(9, 7, 1, 1, PaddingMode::SAME, PaddingMode::SAME); + int batchsize = 10; + auto input = fl::rand({120, 100, 30, batchsize}); + auto batchOutVar = pool(Variable(input, false)); + for(int i = 0; i < batchsize; ++i) { + ASSERT_EQ(input.shape(), batchOutVar.shape()); + auto expected_outVar = pool( + Variable(input(fl::span, fl::span, fl::span, fl::range(i, i + 1)), false) + ); + ASSERT_TRUE(allClose( + batchOutVar.tensor()(fl::span, fl::span, fl::span, fl::range(i, i + 1)), + expected_outVar.tensor(), + 1E-7)); + } } TEST_F(ModuleTestF16, PoolingFwdF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - // test batching - auto pool = Pool2D(9, 7, 1, 1, PaddingMode::SAME, PaddingMode::SAME); - int batchsize = 10; - auto input = fl::rand({120, 100, 30, batchsize}, fl::dtype::f16); - auto batchOutVar = pool(Variable(input, false)); - for (int i = 0; i < batchsize; ++i) { - ASSERT_EQ(input.shape(), batchOutVar.shape()); - auto expected_outVar = pool( - Variable(input(fl::span, fl::span, fl::span, fl::range(i, i + 1)), false)); - ASSERT_TRUE(allClose( - batchOutVar.tensor()(fl::span, fl::span, fl::span, fl::range(i, i + 1)), - expected_outVar.tensor(), - 1E-7)); - } + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + + // test batching + auto pool = Pool2D(9, 7, 1, 1, PaddingMode::SAME, PaddingMode::SAME); + int batchsize = 10; + auto input = fl::rand({120, 100, 30, batchsize}, fl::dtype::f16); + auto batchOutVar = pool(Variable(input, false)); + for(int i = 0; i < batchsize; ++i) { + ASSERT_EQ(input.shape(), batchOutVar.shape()); + auto expected_outVar = pool( + Variable(input(fl::span, fl::span, fl::span, fl::range(i, i + 1)), false) + ); + ASSERT_TRUE(allClose( + batchOutVar.tensor()(fl::span, fl::span, fl::span, fl::range(i, i + 1)), + 
expected_outVar.tensor(), + 1E-7)); + } } TEST(ModuleTest, RNNFwd) { - auto mode = RnnMode::RELU; - int num_layers = 2; - int hidden_size = 3; - int input_size = 4; - int batch_size = 5; - int seq_length = 6; - - auto in = Variable( - fl::rand({input_size, batch_size, seq_length}, fl::dtype::f32), true); - unsigned n_params = 51; - auto w = Variable(fl::rand({1, 1, n_params}, fl::dtype::f32), true); - for (int i = 0; i < in.elements(); ++i) { - in.tensor().flat(i) = (i + 1) * 0.01; - } - for (int i = 0; i < w.elements(); ++i) { - w.tensor().flat(i) = (i + 1) * 0.01; - } - auto rnn = RNN(input_size, hidden_size, num_layers, mode); - rnn.setParams(w, 0); - - auto out = rnn(in); - Shape expected_dims({3, 5, 6}); - ASSERT_EQ(out.shape(), expected_dims); - // Calculated from Lua Torch Cudnn implementation - - auto expected_outVar = Variable( - Tensor::fromVector( - expected_dims, - {1.5418, 1.6389, 1.7361, 1.5491, 1.6472, 1.7452, 1.5564, - 1.6554, 1.7544, 1.5637, 1.6637, 1.7636, 1.5710, 1.6719, - 1.7728, 3.4571, 3.7458, 4.0345, 3.4761, 3.7670, 4.0578, - 3.4951, 3.7881, 4.0812, 3.5141, 3.8093, 4.1045, 3.5331, - 3.8305, 4.1278, 5.6947, 6.2004, 6.7060, 5.7281, 6.2373, - 6.7466, 5.7614, 6.2743, 6.7871, 5.7948, 6.3112, 6.8276, - 5.8282, 6.3482, 6.8681, 8.2005, 8.9458, 9.6911, 8.2500, - 9.0005, 9.7509, 8.2995, 9.0551, 9.8107, 8.3491, 9.1098, - 9.8705, 8.3986, 9.1645, 9.9303, 10.9520, 11.9587, 12.9655, - 11.0191, 12.0326, 13.0462, 11.0861, 12.1065, 13.1269, 11.1532, - 12.1804, 13.2075, 11.2203, 12.2543, 13.2882, 13.9432, 15.2333, - 16.5233, 14.0291, 15.3277, 16.6263, 14.1149, 15.4221, 16.7292, - 14.2008, 15.5165, 16.8322, 14.2866, 15.6109, 16.9351}), - true); - ASSERT_TRUE(allClose(out, expected_outVar, 1E-4)); + auto mode = RnnMode::RELU; + int num_layers = 2; + int hidden_size = 3; + int input_size = 4; + int batch_size = 5; + int seq_length = 6; + + auto in = Variable( + fl::rand({input_size, batch_size, seq_length}, fl::dtype::f32), + true + ); + unsigned n_params = 51; + 
auto w = Variable(fl::rand({1, 1, n_params}, fl::dtype::f32), true); + for(int i = 0; i < in.elements(); ++i) { + in.tensor().flat(i) = (i + 1) * 0.01; + } + for(int i = 0; i < w.elements(); ++i) { + w.tensor().flat(i) = (i + 1) * 0.01; + } + auto rnn = RNN(input_size, hidden_size, num_layers, mode); + rnn.setParams(w, 0); + + auto out = rnn(in); + Shape expected_dims({3, 5, 6}); + ASSERT_EQ(out.shape(), expected_dims); + // Calculated from Lua Torch Cudnn implementation + + auto expected_outVar = Variable( + Tensor::fromVector( + expected_dims, + {1.5418, 1.6389, 1.7361, 1.5491, 1.6472, 1.7452, 1.5564, + 1.6554, 1.7544, 1.5637, 1.6637, 1.7636, 1.5710, 1.6719, + 1.7728, 3.4571, 3.7458, 4.0345, 3.4761, 3.7670, 4.0578, + 3.4951, 3.7881, 4.0812, 3.5141, 3.8093, 4.1045, 3.5331, + 3.8305, 4.1278, 5.6947, 6.2004, 6.7060, 5.7281, 6.2373, + 6.7466, 5.7614, 6.2743, 6.7871, 5.7948, 6.3112, 6.8276, + 5.8282, 6.3482, 6.8681, 8.2005, 8.9458, 9.6911, 8.2500, + 9.0005, 9.7509, 8.2995, 9.0551, 9.8107, 8.3491, 9.1098, + 9.8705, 8.3986, 9.1645, 9.9303, 10.9520, 11.9587, 12.9655, + 11.0191, 12.0326, 13.0462, 11.0861, 12.1065, 13.1269, 11.1532, + 12.1804, 13.2075, 11.2203, 12.2543, 13.2882, 13.9432, 15.2333, + 16.5233, 14.0291, 15.3277, 16.6263, 14.1149, 15.4221, 16.7292, + 14.2008, 15.5165, 16.8322, 14.2866, 15.6109, 16.9351} + ), + true + ); + ASSERT_TRUE(allClose(out, expected_outVar, 1E-4)); } TEST(ModuleTest, LSTMFwd) { - auto mode = RnnMode::LSTM; - int num_layers = 4; - int hidden_size = 5; - int input_size = 3; - int batch_size = 2; - int seq_length = 2; - - auto in = Variable( - fl::rand({input_size, batch_size, seq_length}, fl::dtype::f32), true); - unsigned n_params = 920; - auto w = Variable(fl::rand({1, 1, n_params}, fl::dtype::f32), true); - - for (int i = 0; i < in.elements(); ++i) { - in.tensor().flat(i) = (i + 1) * 0.001; - } - for (int i = 0; i < w.elements(); ++i) { - w.tensor().flat(i) = (i + 1) * 0.001; - } - - auto rnn = RNN(input_size, hidden_size, num_layers, 
mode); - rnn.setParams(w, 0); - - auto out = rnn(in); - Shape expected_dims({5, 2, 2}); - ASSERT_EQ(out.shape(), expected_dims); - // Calculated from Lua Torch Cudnn implementation - auto expected_outVar = Variable( - Tensor::fromVector( - expected_dims, - {0.7390, 0.7395, 0.7399, 0.7403, 0.7407, 0.7390, 0.7395, - 0.7399, 0.7403, 0.7407, 0.9617, 0.9618, 0.9619, 0.9619, - 0.962, 0.9617, 0.9618, 0.9619, 0.9619, 0.962}), - true); - ASSERT_TRUE(allClose(out, expected_outVar, 1E-4)); + auto mode = RnnMode::LSTM; + int num_layers = 4; + int hidden_size = 5; + int input_size = 3; + int batch_size = 2; + int seq_length = 2; + + auto in = Variable( + fl::rand({input_size, batch_size, seq_length}, fl::dtype::f32), + true + ); + unsigned n_params = 920; + auto w = Variable(fl::rand({1, 1, n_params}, fl::dtype::f32), true); + + for(int i = 0; i < in.elements(); ++i) { + in.tensor().flat(i) = (i + 1) * 0.001; + } + for(int i = 0; i < w.elements(); ++i) { + w.tensor().flat(i) = (i + 1) * 0.001; + } + + auto rnn = RNN(input_size, hidden_size, num_layers, mode); + rnn.setParams(w, 0); + + auto out = rnn(in); + Shape expected_dims({5, 2, 2}); + ASSERT_EQ(out.shape(), expected_dims); + // Calculated from Lua Torch Cudnn implementation + auto expected_outVar = Variable( + Tensor::fromVector( + expected_dims, + {0.7390, 0.7395, 0.7399, 0.7403, 0.7407, 0.7390, 0.7395, + 0.7399, 0.7403, 0.7407, 0.9617, 0.9618, 0.9619, 0.9619, + 0.962, 0.9617, 0.9618, 0.9619, 0.9619, 0.962} + ), + true + ); + ASSERT_TRUE(allClose(out, expected_outVar, 1E-4)); } TEST(ModuleTest, GRUFwd) { - auto mode = RnnMode::GRU; - int num_layers = 4; - int hidden_size = 5; - int input_size = 3; - int batch_size = 2; - int seq_length = 2; - - auto in = Variable( - fl::rand({input_size, batch_size, seq_length}, fl::dtype::f32), true); - unsigned n_params = 690; - auto w = Variable(fl::rand({1, 1, n_params}, fl::dtype::f32), true); - - for (int i = 0; i < in.elements(); ++i) { - in.tensor().flat(i) = (i + 1) * 0.001; - } 
- for (int i = 0; i < w.elements(); ++i) { - w.tensor().flat(i) = (i + 1) * 0.001; - } - - auto rnn = RNN(input_size, hidden_size, num_layers, mode); - rnn.setParams(w, 0); - - auto out = rnn(in); - Shape expected_dims({5, 2, 2}); - ASSERT_EQ(out.shape(), expected_dims); - // Calculated from Lua Torch Cudnn implementation - auto expected_outVar = Variable( - Tensor::fromVector( - expected_dims, - {0.1430, 0.1425, 0.1419, 0.1413, 0.1408, 0.1430, 0.1425, - 0.1419, 0.1413, 0.1408, 0.2206, 0.2194, 0.2181, 0.2168, - 0.2155, 0.2206, 0.2194, 0.2181, 0.2168, 0.2155}), - true); - ASSERT_TRUE(allClose(out, expected_outVar, 1E-4)); + auto mode = RnnMode::GRU; + int num_layers = 4; + int hidden_size = 5; + int input_size = 3; + int batch_size = 2; + int seq_length = 2; + + auto in = Variable( + fl::rand({input_size, batch_size, seq_length}, fl::dtype::f32), + true + ); + unsigned n_params = 690; + auto w = Variable(fl::rand({1, 1, n_params}, fl::dtype::f32), true); + + for(int i = 0; i < in.elements(); ++i) { + in.tensor().flat(i) = (i + 1) * 0.001; + } + for(int i = 0; i < w.elements(); ++i) { + w.tensor().flat(i) = (i + 1) * 0.001; + } + + auto rnn = RNN(input_size, hidden_size, num_layers, mode); + rnn.setParams(w, 0); + + auto out = rnn(in); + Shape expected_dims({5, 2, 2}); + ASSERT_EQ(out.shape(), expected_dims); + // Calculated from Lua Torch Cudnn implementation + auto expected_outVar = Variable( + Tensor::fromVector( + expected_dims, + {0.1430, 0.1425, 0.1419, 0.1413, 0.1408, 0.1430, 0.1425, + 0.1419, 0.1413, 0.1408, 0.2206, 0.2194, 0.2181, 0.2168, + 0.2155, 0.2206, 0.2194, 0.2181, 0.2168, 0.2155} + ), + true + ); + ASSERT_TRUE(allClose(out, expected_outVar, 1E-4)); } TEST_F(ModuleTestF16, RNNFwdF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - auto mode = RnnMode::RELU; - int num_layers = 2; - int hidden_size = 3; - int input_size = 4; - int batch_size = 5; - int seq_length = 6; - - auto in = Variable( - 
fl::rand({input_size, batch_size, seq_length}, fl::dtype::f16), true); - unsigned n_params = 51; - auto w = Variable(fl::rand({1, 1, n_params}, fl::dtype::f16), true); - for (int i = 0; i < in.elements(); ++i) { - in.tensor().flat(i) = (i + 1) * 0.01; - } - for (int i = 0; i < w.elements(); ++i) { - w.tensor().flat(i) = (i + 1) * 0.01; - } - auto rnn = RNN(input_size, hidden_size, num_layers, mode); - rnn.setParams(w, 0); - - auto out = rnn(in); - Shape expected_dims({3, 5, 6}); - ASSERT_EQ(out.shape(), expected_dims); - // Calculated from Lua Torch Cudnn implementation - auto expected_outVar = Variable( - Tensor::fromVector( - expected_dims, - {1.5418, 1.6389, 1.7361, 1.5491, 1.6472, 1.7452, 1.5564, - 1.6554, 1.7544, 1.5637, 1.6637, 1.7636, 1.5710, 1.6719, - 1.7728, 3.4571, 3.7458, 4.0345, 3.4761, 3.7670, 4.0578, - 3.4951, 3.7881, 4.0812, 3.5141, 3.8093, 4.1045, 3.5331, - 3.8305, 4.1278, 5.6947, 6.2004, 6.7060, 5.7281, 6.2373, - 6.7466, 5.7614, 6.2743, 6.7871, 5.7948, 6.3112, 6.8276, - 5.8282, 6.3482, 6.8681, 8.2005, 8.9458, 9.6911, 8.2500, - 9.0005, 9.7509, 8.2995, 9.0551, 9.8107, 8.3491, 9.1098, - 9.8705, 8.3986, 9.1645, 9.9303, 10.9520, 11.9587, 12.9655, - 11.0191, 12.0326, 13.0462, 11.0861, 12.1065, 13.1269, 11.1532, - 12.1804, 13.2075, 11.2203, 12.2543, 13.2882, 13.9432, 15.2333, - 16.5233, 14.0291, 15.3277, 16.6263, 14.1149, 15.4221, 16.7292, - 14.2008, 15.5165, 16.8322, 14.2866, 15.6109, 16.9351}), - true); - ASSERT_TRUE(allClose(out, expected_outVar.astype(in.type()), 5E-2)); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + + auto mode = RnnMode::RELU; + int num_layers = 2; + int hidden_size = 3; + int input_size = 4; + int batch_size = 5; + int seq_length = 6; + + auto in = Variable( + fl::rand({input_size, batch_size, seq_length}, fl::dtype::f16), + true + ); + unsigned n_params = 51; + auto w = Variable(fl::rand({1, 1, n_params}, fl::dtype::f16), true); + for(int i = 0; i < in.elements(); ++i) { + 
in.tensor().flat(i) = (i + 1) * 0.01; + } + for(int i = 0; i < w.elements(); ++i) { + w.tensor().flat(i) = (i + 1) * 0.01; + } + auto rnn = RNN(input_size, hidden_size, num_layers, mode); + rnn.setParams(w, 0); + + auto out = rnn(in); + Shape expected_dims({3, 5, 6}); + ASSERT_EQ(out.shape(), expected_dims); + // Calculated from Lua Torch Cudnn implementation + auto expected_outVar = Variable( + Tensor::fromVector( + expected_dims, + {1.5418, 1.6389, 1.7361, 1.5491, 1.6472, 1.7452, 1.5564, + 1.6554, 1.7544, 1.5637, 1.6637, 1.7636, 1.5710, 1.6719, + 1.7728, 3.4571, 3.7458, 4.0345, 3.4761, 3.7670, 4.0578, + 3.4951, 3.7881, 4.0812, 3.5141, 3.8093, 4.1045, 3.5331, + 3.8305, 4.1278, 5.6947, 6.2004, 6.7060, 5.7281, 6.2373, + 6.7466, 5.7614, 6.2743, 6.7871, 5.7948, 6.3112, 6.8276, + 5.8282, 6.3482, 6.8681, 8.2005, 8.9458, 9.6911, 8.2500, + 9.0005, 9.7509, 8.2995, 9.0551, 9.8107, 8.3491, 9.1098, + 9.8705, 8.3986, 9.1645, 9.9303, 10.9520, 11.9587, 12.9655, + 11.0191, 12.0326, 13.0462, 11.0861, 12.1065, 13.1269, 11.1532, + 12.1804, 13.2075, 11.2203, 12.2543, 13.2882, 13.9432, 15.2333, + 16.5233, 14.0291, 15.3277, 16.6263, 14.1149, 15.4221, 16.7292, + 14.2008, 15.5165, 16.8322, 14.2866, 15.6109, 16.9351} + ), + true + ); + ASSERT_TRUE(allClose(out, expected_outVar.astype(in.type()), 5E-2)); } TEST(ModuleTest, ViewFwd) { - auto module = View(Shape({-1, 0, 6})); - auto input = Variable(Tensor({1, 2, 3, 4}), true); - ASSERT_EQ(module(input).shape(), Shape({2, 2, 6})); + auto module = View(Shape({-1, 0, 6})); + auto input = Variable(Tensor({1, 2, 3, 4}), true); + ASSERT_EQ(module(input).shape(), Shape({2, 2, 6})); } TEST(ModuleTest, DropoutFwd) { - auto module = Dropout(0.5); - // Train Mode - module.train(); - auto in = Variable(fl::rand({1000, 1000}), true); - auto out = module(in); - - ASSERT_NEAR( - out.elements() - fl::countNonzero(out.tensor()).scalar(), - in.elements() / 2, - in.elements() / 16); // Check enough zeroes - - ASSERT_GT( - fl::amax(out.tensor()).scalar(), 
1.5); // Check input is scaled - - // Eval Mode - module.eval(); - out = module(in); - ASSERT_TRUE(allClose(out, in, 1E-5)); + auto module = Dropout(0.5); + // Train Mode + module.train(); + auto in = Variable(fl::rand({1000, 1000}), true); + auto out = module(in); + + ASSERT_NEAR( + out.elements() - fl::countNonzero(out.tensor()).scalar(), + in.elements() / 2, + in.elements() / 16 + ); // Check enough zeroes + + ASSERT_GT( + fl::amax(out.tensor()).scalar(), + 1.5 + ); // Check input is scaled + + // Eval Mode + module.eval(); + out = module(in); + ASSERT_TRUE(allClose(out, in, 1E-5)); } TEST_F(ModuleTestF16, DropoutFwdF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - auto module = Dropout(0.5); - // Train Mode - module.train(); - auto in = Variable(fl::rand({1000, 1000}, fl::dtype::f16), true); - auto out = module(in); - ASSERT_EQ(out.type(), fl::dtype::f16); - - ASSERT_NEAR( - out.elements() - fl::countNonzero(out.tensor()).scalar(), - in.elements() / 2, - in.elements() / 16); // Check enough zeroes - - ASSERT_GT( - fl::amax(out.tensor()).asScalar(), 1.5); // Check input is scaled + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } - // Eval Mode - module.eval(); - out = module(in); - ASSERT_TRUE(allClose(out, in, 1E-5)); + auto module = Dropout(0.5); + // Train Mode + module.train(); + auto in = Variable(fl::rand({1000, 1000}, fl::dtype::f16), true); + auto out = module(in); + ASSERT_EQ(out.type(), fl::dtype::f16); + + ASSERT_NEAR( + out.elements() - fl::countNonzero(out.tensor()).scalar(), + in.elements() / 2, + in.elements() / 16 + ); // Check enough zeroes + + ASSERT_GT( + fl::amax(out.tensor()).asScalar(), + 1.5 + ); // Check input is scaled + + // Eval Mode + module.eval(); + out = module(in); + ASSERT_TRUE(allClose(out, in, 1E-5)); } TEST(ModuleTest, PaddingFwd) { - auto module = Padding({{1, 2}, {3, 4}}, -1); - auto input = Variable(fl::rand({1, 2, 3, 
4}, fl::dtype::f64), true); - auto output = module(input); - ASSERT_EQ(output.shape(), Shape({4, 9, 3, 4})); - ASSERT_TRUE(allClose(input, output(fl::range(1, 2), fl::range(3, 5)))); - ASSERT_NEAR( - fl::sum(input.tensor()).scalar(), - fl::sum(output.tensor()).scalar() + 408, - 1E-5); + auto module = Padding({{1, 2}, {3, 4}}, -1); + auto input = Variable(fl::rand({1, 2, 3, 4}, fl::dtype::f64), true); + auto output = module(input); + ASSERT_EQ(output.shape(), Shape({4, 9, 3, 4})); + ASSERT_TRUE(allClose(input, output(fl::range(1, 2), fl::range(3, 5)))); + ASSERT_NEAR( + fl::sum(input.tensor()).scalar(), + fl::sum(output.tensor()).scalar() + 408, + 1E-5 + ); } TEST(ModuleTest, LayerNormFwd) { - double eps = 1E-5; - std::vector feat_axes = {3}; - int F = 10; - auto input = Variable(fl::rand({4, 4, 3, F}), true); - - auto sample_mean = mean(input, {3}); - auto sample_var = var(input, {3}, true); - auto true_out = (input - tileAs(sample_mean, input)) / - tileAs(fl::sqrt(sample_var + eps), input); - - // no affine transform - auto module1 = LayerNorm(feat_axes, eps, false); - - module1.train(); - auto out = module1.forward(input); - - ASSERT_TRUE(allClose(out, true_out, eps)); - ASSERT_EQ(out.type(), input.type()); - - module1.eval(); - out = module1.forward(input); - - ASSERT_TRUE(allClose(out.tensor(), true_out.tensor(), eps)); - ASSERT_EQ(out.type(), input.type()); - - // with affine transform - auto module2 = LayerNorm(feat_axes, eps, true); - - module2.train(); - auto out_train = module2.forward(input); - module2.eval(); - auto out_eval = module2.forward(input); - - ASSERT_TRUE(allClose(out_train.tensor(), out_eval.tensor(), eps)); - ASSERT_EQ(out_train.shape(), input.shape()); - - // with affine transform - auto module3 = LayerNorm(feat_axes, eps, true, F); - module3.setParams(Variable(fl::full({F}, 1.0), false), 0); - module3.setParams(Variable(fl::full({F}, 0.0), false), 1); - auto out3 = module3.forward(input); - ASSERT_TRUE(allClose(out_train.tensor(), 
out3.tensor(), eps)); - - // With other shapes - auto input3Dim = Variable(fl::rand({4, 4, 3}), true); - auto module4 = LayerNorm(std::vector{0}, eps, false); - out = module4.forward(input3Dim); - ASSERT_EQ(out.shape(), input3Dim.shape()); + double eps = 1E-5; + std::vector feat_axes = {3}; + int F = 10; + auto input = Variable(fl::rand({4, 4, 3, F}), true); + + auto sample_mean = mean(input, {3}); + auto sample_var = var(input, {3}, true); + auto true_out = (input - tileAs(sample_mean, input)) + / tileAs(fl::sqrt(sample_var + eps), input); + + // no affine transform + auto module1 = LayerNorm(feat_axes, eps, false); + + module1.train(); + auto out = module1.forward(input); + + ASSERT_TRUE(allClose(out, true_out, eps)); + ASSERT_EQ(out.type(), input.type()); + + module1.eval(); + out = module1.forward(input); + + ASSERT_TRUE(allClose(out.tensor(), true_out.tensor(), eps)); + ASSERT_EQ(out.type(), input.type()); + + // with affine transform + auto module2 = LayerNorm(feat_axes, eps, true); + + module2.train(); + auto out_train = module2.forward(input); + module2.eval(); + auto out_eval = module2.forward(input); + + ASSERT_TRUE(allClose(out_train.tensor(), out_eval.tensor(), eps)); + ASSERT_EQ(out_train.shape(), input.shape()); + + // with affine transform + auto module3 = LayerNorm(feat_axes, eps, true, F); + module3.setParams(Variable(fl::full({F}, 1.0), false), 0); + module3.setParams(Variable(fl::full({F}, 0.0), false), 1); + auto out3 = module3.forward(input); + ASSERT_TRUE(allClose(out_train.tensor(), out3.tensor(), eps)); + + // With other shapes + auto input3Dim = Variable(fl::rand({4, 4, 3}), true); + auto module4 = LayerNorm(std::vector{0}, eps, false); + out = module4.forward(input3Dim); + ASSERT_EQ(out.shape(), input3Dim.shape()); } TEST_F(ModuleTestF16, LayerNormFwdF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this 
device"; + } - double eps = 5E-2; - std::vector feat_axes = {3}; - auto input = Variable(fl::rand({4, 4, 3, 10}, fl::dtype::f16), true); + double eps = 5E-2; + std::vector feat_axes = {3}; + auto input = Variable(fl::rand({4, 4, 3, 10}, fl::dtype::f16), true); - auto sample_mean = mean(input, {3}); - auto sample_var = var(input, {3}, true); - auto true_out = (input - tileAs(sample_mean, input).astype(input.type())) / - tileAs(fl::sqrt(sample_var + eps), input).astype(input.type()); + auto sample_mean = mean(input, {3}); + auto sample_var = var(input, {3}, true); + auto true_out = (input - tileAs(sample_mean, input).astype(input.type())) + / tileAs(fl::sqrt(sample_var + eps), input).astype(input.type()); - // no affine transform - auto module1 = LayerNorm(feat_axes, eps, false); + // no affine transform + auto module1 = LayerNorm(feat_axes, eps, false); - module1.train(); - auto out = module1.forward(input); + module1.train(); + auto out = module1.forward(input); - ASSERT_TRUE(allClose(out, true_out.astype(out.type()), eps)); + ASSERT_TRUE(allClose(out, true_out.astype(out.type()), eps)); - module1.eval(); - out = module1.forward(input); + module1.eval(); + out = module1.forward(input); - ASSERT_TRUE( - allClose(out.tensor(), true_out.tensor().astype(out.type()), eps)); + ASSERT_TRUE( + allClose(out.tensor(), true_out.tensor().astype(out.type()), eps) + ); - // with affine transform - auto module2 = LayerNorm(feat_axes, eps, true); + // with affine transform + auto module2 = LayerNorm(feat_axes, eps, true); - module2.train(); - auto out_train = module2.forward(input); - module2.eval(); - auto out_eval = module2.forward(input); + module2.train(); + auto out_train = module2.forward(input); + module2.eval(); + auto out_eval = module2.forward(input); - ASSERT_TRUE(allClose(out_train.tensor(), out_eval.tensor(), eps)); - ASSERT_EQ(out_train.shape(), input.shape()); + ASSERT_TRUE(allClose(out_train.tensor(), out_eval.tensor(), eps)); + ASSERT_EQ(out_train.shape(), 
input.shape()); - module2.train(); + module2.train(); } TEST(ModuleTest, NormalizeFwd) { - auto input = Variable(fl::rand({10, 3}, fl::dtype::f64), true); - auto module = Normalize({1}, 2, 1e-12, 5); - module.train(); - auto out = module.forward(input); - ASSERT_TRUE(allClose( - fl::sqrt(fl::sum(out.tensor() * out.tensor(), {1})), - fl::full({10}, 5, fl::dtype::f64))); + auto input = Variable(fl::rand({10, 3}, fl::dtype::f64), true); + auto module = Normalize({1}, 2, 1e-12, 5); + module.train(); + auto out = module.forward(input); + ASSERT_TRUE( + allClose( + fl::sqrt(fl::sum(out.tensor() * out.tensor(), {1})), + fl::full({10}, 5, fl::dtype::f64) + ) + ); } TEST(ModuleTest, TransformFwd) { - auto inVar = Variable(fl::full({4, 5}, 1.0), true); + auto inVar = Variable(fl::full({4, 5}, 1.0), true); - auto l = Transform([](const Variable& in) { return fl::log(in); }); + auto l = Transform([](const Variable& in) { return fl::log(in); }); - ASSERT_TRUE(allClose(l.forward(inVar).tensor(), fl::full(inVar.shape(), 0.0))); + ASSERT_TRUE(allClose(l.forward(inVar).tensor(), fl::full(inVar.shape(), 0.0))); } TEST(ModuleTest, PrecisionCastFwd) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half precision not available on this device"; - } + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half precision not available on this device"; + } - auto in = Variable(fl::full({3, 3}, 1.0), true); - auto precisionCast = PrecisionCast(fl::dtype::f16); - auto out = precisionCast.forward(in); + auto in = Variable(fl::full({3, 3}, 1.0), true); + auto precisionCast = PrecisionCast(fl::dtype::f16); + auto out = precisionCast.forward(in); - ASSERT_EQ(out.type(), fl::dtype::f16); - ASSERT_TRUE(allClose(in.tensor(), out.astype(fl::dtype::f32).tensor())); + ASSERT_EQ(out.type(), fl::dtype::f16); + ASSERT_TRUE(allClose(in.tensor(), out.astype(fl::dtype::f32).tensor())); } TEST(ModuleTest, ContainerReplaceParam) { - auto seq = ContainerTestClass(); - seq.addParam(Variable(fl::rand({5, 5}), true)); - 
seq.add(Linear(10, 20)); - seq.addParam(Variable(fl::rand({5, 5}), true)); - seq.add(ReLU()); - seq.add(Linear(20, 30)); - seq.addParam(Variable(fl::rand({5, 5}), true)); - - // Change the first parameter - auto new_param = Variable(fl::rand({5, 5}), true); - seq.setParams(new_param, 0); - ASSERT_TRUE(allClose(seq.params()[0], new_param)); - - // Change the first linear layer's first parameter - new_param = Variable(fl::rand({10, 20}), true); - seq.setParams(new_param, 1); - ASSERT_TRUE(allClose(seq.params()[1], new_param)); - ASSERT_TRUE(allClose(seq.module(0)->param(0), new_param)); - - // Change the second linear layer's first parameter - new_param = Variable(fl::rand({20, 30}), true); - seq.setParams(new_param, 4); - ASSERT_TRUE(allClose(seq.params()[4], new_param)); - ASSERT_TRUE(allClose(seq.module(2)->param(0), new_param)); - - // Change the last parameter - new_param = Variable(fl::rand({5, 5}), true); - seq.setParams(new_param, 6); - ASSERT_TRUE(allClose(seq.param(6), new_param)); + auto seq = ContainerTestClass(); + seq.addParam(Variable(fl::rand({5, 5}), true)); + seq.add(Linear(10, 20)); + seq.addParam(Variable(fl::rand({5, 5}), true)); + seq.add(ReLU()); + seq.add(Linear(20, 30)); + seq.addParam(Variable(fl::rand({5, 5}), true)); + + // Change the first parameter + auto new_param = Variable(fl::rand({5, 5}), true); + seq.setParams(new_param, 0); + ASSERT_TRUE(allClose(seq.params()[0], new_param)); + + // Change the first linear layer's first parameter + new_param = Variable(fl::rand({10, 20}), true); + seq.setParams(new_param, 1); + ASSERT_TRUE(allClose(seq.params()[1], new_param)); + ASSERT_TRUE(allClose(seq.module(0)->param(0), new_param)); + + // Change the second linear layer's first parameter + new_param = Variable(fl::rand({20, 30}), true); + seq.setParams(new_param, 4); + ASSERT_TRUE(allClose(seq.params()[4], new_param)); + ASSERT_TRUE(allClose(seq.module(2)->param(0), new_param)); + + // Change the last parameter + new_param = 
Variable(fl::rand({5, 5}), true); + seq.setParams(new_param, 6); + ASSERT_TRUE(allClose(seq.param(6), new_param)); } TEST(ModuleTest, AdaptiveSoftMaxPredict) { - // test predict gives the same as argmax along probs - int N = 5; - int C = 5; - int T = 10; - int B = 5; - - auto x = input(fl::rand({N, T, B}, fl::dtype::f32)); - auto y = Variable( - (fl::rand({T, B}, fl::dtype::u32) % C).astype(fl::dtype::s32), false); - - std::vector cutoff{{C / 2, C}}; - auto activation = std::make_shared(N, cutoff); - - auto probs = activation->forward(x); - auto result1 = activation->predict(x).tensor(); - auto result2 = fl::argmax(probs.tensor(), 0, /* keepDims = */ true); - - ASSERT_TRUE(allClose(result1, result2)); + // test predict gives the same as argmax along probs + int N = 5; + int C = 5; + int T = 10; + int B = 5; + + auto x = input(fl::rand({N, T, B}, fl::dtype::f32)); + auto y = Variable( + (fl::rand({T, B}, fl::dtype::u32) % C).astype(fl::dtype::s32), + false + ); + + std::vector cutoff{{C / 2, C}}; + auto activation = std::make_shared(N, cutoff); + + auto probs = activation->forward(x); + auto result1 = activation->predict(x).tensor(); + auto result2 = fl::argmax(probs.tensor(), 0, /* keepDims = */ true); + + ASSERT_TRUE(allClose(result1, result2)); } TEST(ModuleTest, AdaptiveSoftMaxLossBatchFwd) { - // test batching - int N = 5; - int C = 3; - int T = 10; - int B = 5; - - auto x = input(fl::rand({N, T, B}, fl::dtype::f32)); - auto y = Variable( - (fl::rand({T, B}, fl::dtype::u32) % C).astype(fl::dtype::s32), false); - - std::vector cutoff{{C / 2, C}}; - auto activation = std::make_shared(N, cutoff); - auto asml = - std::make_shared(activation, ReduceMode::NONE); - auto batchOutVar = asml->forward(x, y); - - auto singleOut = fl::full({T, B}, 0, fl::dtype::f32); - for (int i = 0; i < B; ++i) { - auto singleOutVar = asml->forward( - x(fl::span, fl::span, fl::range(i, i + 1)), y(fl::span, fl::range(i, i + 1))); - singleOut(fl::span, i) = singleOutVar.tensor(); - } + // 
test batching + int N = 5; + int C = 3; + int T = 10; + int B = 5; + + auto x = input(fl::rand({N, T, B}, fl::dtype::f32)); + auto y = Variable( + (fl::rand({T, B}, fl::dtype::u32) % C).astype(fl::dtype::s32), + false + ); + + std::vector cutoff{{C / 2, C}}; + auto activation = std::make_shared(N, cutoff); + auto asml = + std::make_shared(activation, ReduceMode::NONE); + auto batchOutVar = asml->forward(x, y); + + auto singleOut = fl::full({T, B}, 0, fl::dtype::f32); + for(int i = 0; i < B; ++i) { + auto singleOutVar = asml->forward( + x(fl::span, fl::span, fl::range(i, i + 1)), + y(fl::span, fl::range(i, i + 1)) + ); + singleOut(fl::span, i) = singleOutVar.tensor(); + } - ASSERT_TRUE(allClose(batchOutVar.tensor(), singleOut)); + ASSERT_TRUE(allClose(batchOutVar.tensor(), singleOut)); } TEST(ModuleTest, AdaptiveSoftMaxLossIgnoreIndex) { - // test batching - int N = 5; - int C = 3; - int T = 10; - int B = 5; - - auto x = input(fl::rand({N, T, B}, fl::dtype::f32)); - auto y = Variable( - (fl::rand({T, B}, fl::dtype::u32) % C).astype(fl::dtype::s32), false); - auto ignoreIdx = y(0, 0).scalar(); - auto ignoreCount = fl::sum(y.tensor() != ignoreIdx).scalar(); - - std::vector cutoff{{C / 2, C}}; - auto activation = std::make_shared(N, cutoff); - auto asml1 = std::make_shared( - activation, ReduceMode::SUM, ignoreIdx); - auto asml2 = std::make_shared( - activation, ReduceMode::MEAN, ignoreIdx); - - auto lossSum = asml1->forward(x, y); - auto lossMean = asml2->forward(x, y); - ASSERT_NEAR( - fl::sum(lossSum.tensor()).scalar(), - fl::sum(lossMean.tensor()).scalar() * ignoreCount, - 1E-5); + // test batching + int N = 5; + int C = 3; + int T = 10; + int B = 5; + + auto x = input(fl::rand({N, T, B}, fl::dtype::f32)); + auto y = Variable( + (fl::rand({T, B}, fl::dtype::u32) % C).astype(fl::dtype::s32), + false + ); + auto ignoreIdx = y(0, 0).scalar(); + auto ignoreCount = fl::sum(y.tensor() != ignoreIdx).scalar(); + + std::vector cutoff{{C / 2, C}}; + auto activation = 
std::make_shared(N, cutoff); + auto asml1 = std::make_shared( + activation, + ReduceMode::SUM, + ignoreIdx + ); + auto asml2 = std::make_shared( + activation, + ReduceMode::MEAN, + ignoreIdx + ); + + auto lossSum = asml1->forward(x, y); + auto lossMean = asml2->forward(x, y); + ASSERT_NEAR( + fl::sum(lossSum.tensor()).scalar(), + fl::sum(lossMean.tensor()).scalar() * ignoreCount, + 1E-5 + ); } TEST(ModuleTest, IdentityFwd) { - auto module = Identity(); - std::vector in = { - Variable(fl::rand({1000, 1000}), true), - Variable(fl::rand({100, 100}), true)}; - - // Train Mode - module.train(); - auto out = module(in); - ASSERT_EQ(out.size(), 2); - ASSERT_TRUE(allClose(out.at(0), in.at(0), 1e-20)); - ASSERT_TRUE(allClose(out.at(1), in.at(1), 1e-20)); - - // Eval Mode - module.eval(); - out = module(in); - ASSERT_EQ(out.size(), 2); - ASSERT_TRUE(allClose(out.at(0), in.at(0), 1e-20)); - ASSERT_TRUE(allClose(out.at(1), in.at(1), 1e-20)); + auto module = Identity(); + std::vector in = { + Variable(fl::rand({1000, 1000}), true), + Variable(fl::rand({100, 100}), true)}; + + // Train Mode + module.train(); + auto out = module(in); + ASSERT_EQ(out.size(), 2); + ASSERT_TRUE(allClose(out.at(0), in.at(0), 1e-20)); + ASSERT_TRUE(allClose(out.at(1), in.at(1), 1e-20)); + + // Eval Mode + module.eval(); + out = module(in); + ASSERT_EQ(out.size(), 2); + ASSERT_TRUE(allClose(out.at(0), in.at(0), 1e-20)); + ASSERT_TRUE(allClose(out.at(1), in.at(1), 1e-20)); } TEST(ModuleTest, ModuleCloneCopy) { - int n_in = 1, n_out = 2; - auto wtVar = param(Tensor::fromVector({n_out, n_in}, {2, 4})); - auto inVar = input(Tensor::fromVector({n_in}, {3})); - Variable expected_outVar(Tensor::fromVector({n_out}, {6, 12}), true); - - Linear lin(wtVar); - ASSERT_TRUE(allClose(lin(inVar), expected_outVar, 1E-7)); - - // Intentionally cast to base Module ptr and clone/copy via the various - // options - std::unique_ptr modulePtr = std::make_unique(std::move(lin)); - std::unique_ptr clonedModulePtr = 
modulePtr->clone(); - - // Change the original module param and check the cloned modules have not - // changed - modulePtr->param(0).tensor() += 1.0F; - ASSERT_FALSE( - allClose(modulePtr->forward({inVar}).front(), expected_outVar, 1E-7)); - - ASSERT_TRUE(allClose( - clonedModulePtr->forward({inVar}).front(), expected_outVar, 1E-7)); + int n_in = 1, n_out = 2; + auto wtVar = param(Tensor::fromVector({n_out, n_in}, {2, 4})); + auto inVar = input(Tensor::fromVector({n_in}, {3})); + Variable expected_outVar(Tensor::fromVector({n_out}, {6, 12}), true); + + Linear lin(wtVar); + ASSERT_TRUE(allClose(lin(inVar), expected_outVar, 1E-7)); + + // Intentionally cast to base Module ptr and clone/copy via the various + // options + std::unique_ptr modulePtr = std::make_unique(std::move(lin)); + std::unique_ptr clonedModulePtr = modulePtr->clone(); + + // Change the original module param and check the cloned modules have not + // changed + modulePtr->param(0).tensor() += 1.0F; + ASSERT_FALSE( + allClose(modulePtr->forward({inVar}).front(), expected_outVar, 1E-7) + ); + + ASSERT_TRUE( + allClose( + clonedModulePtr->forward({inVar}).front(), + expected_outVar, + 1E-7 + ) + ); } TEST(ModuleTest, ContainerCloneCopy) { - ContainerTestClass seq; - seq.addParam(Variable(fl::rand({5, 5}), true)); - seq.add(Linear(10, 20)); - // Create copy/clone vis copy constructor - auto seqCopy = seq; - - // Make sure they are the same - ASSERT_TRUE(allClose(seq.params()[0], seqCopy.params()[0])); - ASSERT_TRUE(allClose(seq.params()[1], seqCopy.params()[1])); - - // Change the first parameter and check the copy has not changed - Variable new_param(fl::rand({5, 5}), true); - seq.setParams(new_param, 0); - ASSERT_TRUE(allClose(seq.params()[0], new_param)); - ASSERT_FALSE(allClose(seqCopy.params()[0], seq.params()[0])); - - // Change the linear layer's first parameter and check the copy has not - // changed - new_param = Variable(fl::rand({10, 20}), true); - seq.setParams(new_param, 1); - 
ASSERT_TRUE(allClose(seq.params()[1], new_param)); - ASSERT_TRUE(allClose(seq.module(0)->param(0), new_param)); - ASSERT_FALSE(allClose(seqCopy.params()[1], seq.params()[1])); - ASSERT_FALSE(allClose(seqCopy.module(0)->param(0), seq.module(0)->param(0))); - - // Intentionally cast to base Module ptr and clone/copy via the various - // options - std::unique_ptr modulePtr = - std::make_unique(std::move(seq)); - std::unique_ptr clonedModulePtr = modulePtr->clone(); - - ASSERT_TRUE(allClose(clonedModulePtr->params()[0], modulePtr->params()[0])); + ContainerTestClass seq; + seq.addParam(Variable(fl::rand({5, 5}), true)); + seq.add(Linear(10, 20)); + // Create copy/clone vis copy constructor + auto seqCopy = seq; + + // Make sure they are the same + ASSERT_TRUE(allClose(seq.params()[0], seqCopy.params()[0])); + ASSERT_TRUE(allClose(seq.params()[1], seqCopy.params()[1])); + + // Change the first parameter and check the copy has not changed + Variable new_param(fl::rand({5, 5}), true); + seq.setParams(new_param, 0); + ASSERT_TRUE(allClose(seq.params()[0], new_param)); + ASSERT_FALSE(allClose(seqCopy.params()[0], seq.params()[0])); + + // Change the linear layer's first parameter and check the copy has not + // changed + new_param = Variable(fl::rand({10, 20}), true); + seq.setParams(new_param, 1); + ASSERT_TRUE(allClose(seq.params()[1], new_param)); + ASSERT_TRUE(allClose(seq.module(0)->param(0), new_param)); + ASSERT_FALSE(allClose(seqCopy.params()[1], seq.params()[1])); + ASSERT_FALSE(allClose(seqCopy.module(0)->param(0), seq.module(0)->param(0))); + + // Intentionally cast to base Module ptr and clone/copy via the various + // options + std::unique_ptr modulePtr = + std::make_unique(std::move(seq)); + std::unique_ptr clonedModulePtr = modulePtr->clone(); + + ASSERT_TRUE(allClose(clonedModulePtr->params()[0], modulePtr->params()[0])); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + 
::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/nn/NNSerializationTest.cpp b/flashlight/fl/test/nn/NNSerializationTest.cpp index a725b48..1a70ab3 100644 --- a/flashlight/fl/test/nn/NNSerializationTest.cpp +++ b/flashlight/fl/test/nn/NNSerializationTest.cpp @@ -24,28 +24,28 @@ using namespace fl; namespace { class ContainerTestClass : public Sequential { - public: - ContainerTestClass() = default; +public: + ContainerTestClass() = default; - void addParam(const Variable& param) { - params_.push_back(param); - } + void addParam(const Variable& param) { + params_.push_back(param); + } - private: - FL_SAVE_LOAD_WITH_BASE(Sequential) +private: + FL_SAVE_LOAD_WITH_BASE(Sequential) }; auto filesizebytes = []() -> std::uintmax_t { - return fs::file_size(fs::temp_directory_path() / "FileSize.txt"); -}; + return fs::file_size(fs::temp_directory_path() / "FileSize.txt"); + }; auto paramsizebytes = [](const std::vector& parameters) { - int64_t paramsize = 0; - for (const auto& param : parameters) { - paramsize += (param.elements() * fl::getTypeSize(param.type())); - } - return paramsize; -}; + int64_t paramsize = 0; + for(const auto& param : parameters) { + paramsize += (param.elements() * fl::getTypeSize(param.type())); + } + return paramsize; + }; const float kThreshold = 1.01; // within 1% @@ -54,321 +54,324 @@ const float kThreshold = 1.01; // within 1% CEREAL_REGISTER_TYPE(ContainerTestClass) TEST(NNSerializationTest, Variable) { - auto testimpl = [](const Tensor& arr, bool calc_grad) { - Variable a(arr, calc_grad); - std::stringstream ss; - { - cereal::BinaryOutputArchive ar(ss); - ar(a); - } - Variable b; - { - cereal::BinaryInputArchive ar(ss); - ar(b); - } - ASSERT_TRUE(allClose(a, b)); - }; - - testimpl(Tensor(), true); - testimpl(fl::randn({3, 6, 7, 8}), false); - testimpl(fl::rand({1, 2, 3, 5}, fl::dtype::b8), false); - testimpl(fl::rand({1, 2, 3, 5}, fl::dtype::s16), true); - testimpl(fl::randn({5, 
6, 7, 9}, fl::dtype::f64), false); - testimpl(fl::rand({1, 9, 9, 2}, fl::dtype::s32), true); - testimpl(fl::rand({2, 9, 1, 8}, fl::dtype::s64), false); - testimpl(fl::rand({100, 150}, fl::dtype::u8), true); - testimpl(fl::rand({32, 32, 3}, fl::dtype::u16), false); + auto testimpl = [](const Tensor& arr, bool calc_grad) { + Variable a(arr, calc_grad); + std::stringstream ss; + { + cereal::BinaryOutputArchive ar(ss); + ar(a); + } + Variable b; + { + cereal::BinaryInputArchive ar(ss); + ar(b); + } + ASSERT_TRUE(allClose(a, b)); + }; + + testimpl(Tensor(), true); + testimpl(fl::randn({3, 6, 7, 8}), false); + testimpl(fl::rand({1, 2, 3, 5}, fl::dtype::b8), false); + testimpl(fl::rand({1, 2, 3, 5}, fl::dtype::s16), true); + testimpl(fl::randn({5, 6, 7, 9}, fl::dtype::f64), false); + testimpl(fl::rand({1, 9, 9, 2}, fl::dtype::s32), true); + testimpl(fl::rand({2, 9, 1, 8}, fl::dtype::s64), false); + testimpl(fl::rand({100, 150}, fl::dtype::u8), true); + testimpl(fl::rand({32, 32, 3}, fl::dtype::u16), false); } TEST(NNSerializationTest, Linear) { - auto wt = param(fl::rand({4, 3})); - auto bs = param(fl::rand({4})); - auto in = input(fl::rand({3, 2})); - auto lin = std::make_shared(wt, bs); + auto wt = param(fl::rand({4, 3})); + auto bs = param(fl::rand({4})); + auto in = input(fl::rand({3, 2})); + auto lin = std::make_shared(wt, bs); - const fs::path path = fs::temp_directory_path() / "Linear.mdl"; - save(path, lin); + const fs::path path = fs::temp_directory_path() / "Linear.mdl"; + save(path, lin); - std::shared_ptr lin2; - load(path, lin2); - ASSERT_TRUE(lin2); + std::shared_ptr lin2; + load(path, lin2); + ASSERT_TRUE(lin2); - ASSERT_TRUE(allParamsClose(*lin2, *lin)); - ASSERT_TRUE(allClose(lin2->forward(in), lin->forward(in))); + ASSERT_TRUE(allParamsClose(*lin2, *lin)); + ASSERT_TRUE(allClose(lin2->forward(in), lin->forward(in))); } TEST(NNSerializationTest, Conv2D) { - auto wt = param(fl::rand({5, 5, 2, 4})); - auto bs = param(fl::rand({1, 1, 4, 1})); - auto in = 
input(fl::rand({25, 25, 2, 5})); - auto conv = std::make_shared(wt, bs); + auto wt = param(fl::rand({5, 5, 2, 4})); + auto bs = param(fl::rand({1, 1, 4, 1})); + auto in = input(fl::rand({25, 25, 2, 5})); + auto conv = std::make_shared(wt, bs); - const fs::path path = fs::temp_directory_path() / "Conv2D.mdl"; - save(path, conv); + const fs::path path = fs::temp_directory_path() / "Conv2D.mdl"; + save(path, conv); - std::shared_ptr conv2; - load(path, conv2); - ASSERT_TRUE(conv2); + std::shared_ptr conv2; + load(path, conv2); + ASSERT_TRUE(conv2); - ASSERT_TRUE(allParamsClose(*conv2, *conv)); - ASSERT_TRUE(allClose(conv2->forward(in), conv->forward(in))); + ASSERT_TRUE(allParamsClose(*conv2, *conv)); + ASSERT_TRUE(allClose(conv2->forward(in), conv->forward(in))); } TEST(NNSerializationTest, Pool2D) { - auto in = input(fl::rand({8, 8})); - auto pool = std::make_shared(2, 3, 1, 1, 1, 1, PoolingMode::MAX); + auto in = input(fl::rand({8, 8})); + auto pool = std::make_shared(2, 3, 1, 1, 1, 1, PoolingMode::MAX); - const fs::path path = fs::temp_directory_path() / "Pool2D.mdl"; - save(path, pool); + const fs::path path = fs::temp_directory_path() / "Pool2D.mdl"; + save(path, pool); - std::shared_ptr pool2; - load(path, pool2); - ASSERT_TRUE(pool2); + std::shared_ptr pool2; + load(path, pool2); + ASSERT_TRUE(pool2); - ASSERT_TRUE(allParamsClose(*pool2, *pool)); - ASSERT_TRUE(allClose(pool2->forward(in), pool->forward(in))); + ASSERT_TRUE(allParamsClose(*pool2, *pool)); + ASSERT_TRUE(allClose(pool2->forward(in), pool->forward(in))); } TEST(NNSerializationTest, BaseModule) { - auto in = input(fl::rand({8, 8})); - ModulePtr dout = std::make_shared(0.75); + auto in = input(fl::rand({8, 8})); + ModulePtr dout = std::make_shared(0.75); - const fs::path path = fs::temp_directory_path() / "BaseModule.mdl"; - save(path, dout); + const fs::path path = fs::temp_directory_path() / "BaseModule.mdl"; + save(path, dout); - ModulePtr dout2; - load(path, dout2); - ASSERT_TRUE(dout2); + 
ModulePtr dout2; + load(path, dout2); + ASSERT_TRUE(dout2); - ASSERT_TRUE(allParamsClose(*dout2, *dout)); + ASSERT_TRUE(allParamsClose(*dout2, *dout)); } TEST(NNSerializationTest, PrecisionCast) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half precision not available on this device"; - } + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half precision not available on this device"; + } - auto in = input(fl::rand({8, 8})); - auto precisionCast = std::make_shared(fl::dtype::f16); + auto in = input(fl::rand({8, 8})); + auto precisionCast = std::make_shared(fl::dtype::f16); - const fs::path path = fs::temp_directory_path() / "PrecisionCast.mdl"; - save(path, precisionCast); + const fs::path path = fs::temp_directory_path() / "PrecisionCast.mdl"; + save(path, precisionCast); - std::shared_ptr precisionCast2; - load(path, precisionCast2); - ASSERT_TRUE(precisionCast2); + std::shared_ptr precisionCast2; + load(path, precisionCast2); + ASSERT_TRUE(precisionCast2); - ASSERT_TRUE( - allClose(precisionCast->forward(in), precisionCast2->forward(in))); + ASSERT_TRUE( + allClose(precisionCast->forward(in), precisionCast2->forward(in)) + ); } TEST(NNSerializationTest, WeightNormLinear) { - auto in = input(fl::randn({2, 10, 1, 1})); - auto wlin = std::make_shared(Linear(2, 3), 0); + auto in = input(fl::randn({2, 10, 1, 1})); + auto wlin = std::make_shared(Linear(2, 3), 0); - const fs::path path = fs::temp_directory_path() / "WeightNormLinear.mdl"; - save(path, wlin); + const fs::path path = fs::temp_directory_path() / "WeightNormLinear.mdl"; + save(path, wlin); - std::shared_ptr wlin2; - load(path, wlin2); - ASSERT_TRUE(wlin2); + std::shared_ptr wlin2; + load(path, wlin2); + ASSERT_TRUE(wlin2); - ASSERT_TRUE(allParamsClose(*wlin2, *wlin)); - ASSERT_TRUE( - allClose(wlin2->forward({in}).front(), wlin->forward({in}).front())); + ASSERT_TRUE(allParamsClose(*wlin2, *wlin)); + ASSERT_TRUE( + allClose(wlin2->forward({in}).front(), wlin->forward({in}).front()) + ); } 
TEST(NNSerializationTest, WeightNormConvSeq) { - auto in = input(fl::randn({70, 70, 30, 2})); - auto seq = std::make_shared(); - seq->add(std::make_shared(Conv2D(30, 80, 3, 3), 3)); - seq->add(std::make_shared(2)); - seq->add(std::make_shared(Conv2D(40, 90, 3, 3), 3)); - seq->add(std::make_shared(2)); - seq->add(std::make_shared(Conv2D(45, 100, 3, 3), 3)); - seq->add(std::make_shared(2)); + auto in = input(fl::randn({70, 70, 30, 2})); + auto seq = std::make_shared(); + seq->add(std::make_shared(Conv2D(30, 80, 3, 3), 3)); + seq->add(std::make_shared(2)); + seq->add(std::make_shared(Conv2D(40, 90, 3, 3), 3)); + seq->add(std::make_shared(2)); + seq->add(std::make_shared(Conv2D(45, 100, 3, 3), 3)); + seq->add(std::make_shared(2)); } TEST(NNSerializationTest, AdaptiveSoftMaxLoss) { - auto in = input(fl::rand({5, 10, /* B= */ 1})); - std::vector h_target{1, 1, 1, 2, 2, 2, 0, 0, 0, 0}; - auto g_target = Tensor::fromVector({10, /* B = */ 1}, h_target); - auto target = input(g_target); - - std::vector cutoff{{1, 3}}; - auto activation = std::make_shared(5, cutoff); - auto asml = std::make_shared(activation); - - const fs::path path = fs::temp_directory_path() / "AdaptiveSoftMaxLoss.mdl"; - save(path, asml); - - std::shared_ptr asml2; - load(path, asml2); - ASSERT_TRUE(asml2); - - ASSERT_TRUE(allParamsClose(*asml2, *asml)); - auto activation2 = asml2->getActivation(); - ASSERT_TRUE(allParamsClose(*activation2, *activation)); - ASSERT_TRUE(allClose(activation2->forward(in), activation->forward(in))); - ASSERT_TRUE(allClose(asml2->forward(in, target), asml->forward(in, target))); - - auto activation3 = std::make_shared(5, cutoff); - auto asml3 = std::make_shared(activation3); - int index = 0; - for (const auto& param : asml->params()) { - asml3->setParams(param, index); - index++; - } - ASSERT_TRUE(allParamsClose(*asml3->getActivation(), *activation)); + auto in = input(fl::rand({5, 10, /* B= */ 1})); + std::vector h_target{1, 1, 1, 2, 2, 2, 0, 0, 0, 0}; + auto g_target = 
Tensor::fromVector({10, /* B = */ 1}, h_target); + auto target = input(g_target); + + std::vector cutoff{{1, 3}}; + auto activation = std::make_shared(5, cutoff); + auto asml = std::make_shared(activation); + + const fs::path path = fs::temp_directory_path() / "AdaptiveSoftMaxLoss.mdl"; + save(path, asml); + + std::shared_ptr asml2; + load(path, asml2); + ASSERT_TRUE(asml2); + + ASSERT_TRUE(allParamsClose(*asml2, *asml)); + auto activation2 = asml2->getActivation(); + ASSERT_TRUE(allParamsClose(*activation2, *activation)); + ASSERT_TRUE(allClose(activation2->forward(in), activation->forward(in))); + ASSERT_TRUE(allClose(asml2->forward(in, target), asml->forward(in, target))); + + auto activation3 = std::make_shared(5, cutoff); + auto asml3 = std::make_shared(activation3); + int index = 0; + for(const auto& param : asml->params()) { + asml3->setParams(param, index); + index++; + } + ASSERT_TRUE(allParamsClose(*asml3->getActivation(), *activation)); } TEST(NNSerializationTest, PrettyString) { - Sequential seq; - seq.add(Conv2D(3, 64, 5, 5)); - seq.add(Pool2D(3, 3, 2, 2, 1, 1)); - seq.add(ReLU()); - seq.add(Dropout(0.4)); - seq.add(Linear(5, 10, false)); - seq.add(Tanh()); - seq.add(LeakyReLU(0.2)); - - auto prettystr = seq.prettyString(); - - std::string expectedstr = - "Sequential [input -> (0) -> (1) -> (2) -> (3) " - "-> (4) -> (5) -> (6) -> output]" - "(0): Conv2D (3->64, 5x5, 1, 1, 0, 0, 1, 1) (with bias)" - "(1): Pool2D-max (3x3, 2,2, 1,1)" - "(2): ReLU" - "(3): Dropout (0.400000)" - "(4): Linear (5->10) (without bias)" - "(5): Tanh" - "(6): LeakyReLU (0.200000)"; - - auto remove_ws = [](std::string& str) { - str.erase(std::remove(str.begin(), str.end(), ' '), str.end()); - str.erase(std::remove(str.begin(), str.end(), '\n'), str.end()); - str.erase(std::remove(str.begin(), str.end(), '\t'), str.end()); - str.erase(std::remove(str.begin(), str.end(), '\r'), str.end()); - }; - - remove_ws(expectedstr); - remove_ws(prettystr); - - ASSERT_EQ(expectedstr, 
prettystr); + Sequential seq; + seq.add(Conv2D(3, 64, 5, 5)); + seq.add(Pool2D(3, 3, 2, 2, 1, 1)); + seq.add(ReLU()); + seq.add(Dropout(0.4)); + seq.add(Linear(5, 10, false)); + seq.add(Tanh()); + seq.add(LeakyReLU(0.2)); + + auto prettystr = seq.prettyString(); + + std::string expectedstr = + "Sequential [input -> (0) -> (1) -> (2) -> (3) " + "-> (4) -> (5) -> (6) -> output]" + "(0): Conv2D (3->64, 5x5, 1, 1, 0, 0, 1, 1) (with bias)" + "(1): Pool2D-max (3x3, 2,2, 1,1)" + "(2): ReLU" + "(3): Dropout (0.400000)" + "(4): Linear (5->10) (without bias)" + "(5): Tanh" + "(6): LeakyReLU (0.200000)"; + + auto remove_ws = [](std::string& str) { + str.erase(std::remove(str.begin(), str.end(), ' '), str.end()); + str.erase(std::remove(str.begin(), str.end(), '\n'), str.end()); + str.erase(std::remove(str.begin(), str.end(), '\t'), str.end()); + str.erase(std::remove(str.begin(), str.end(), '\r'), str.end()); + }; + + remove_ws(expectedstr); + remove_ws(prettystr); + + ASSERT_EQ(expectedstr, prettystr); } TEST(NNSerializationTest, LeNet) { - auto leNet = std::make_shared(); + auto leNet = std::make_shared(); - leNet->add(Conv2D(3, 6, 5, 5)); - leNet->add(ReLU()); - leNet->add(Pool2D(2, 2, 2, 2)); + leNet->add(Conv2D(3, 6, 5, 5)); + leNet->add(ReLU()); + leNet->add(Pool2D(2, 2, 2, 2)); - leNet->add(Conv2D(6, 16, 5, 5)); - leNet->add(ReLU()); - leNet->add(Pool2D(2, 2, 2, 2)); + leNet->add(Conv2D(6, 16, 5, 5)); + leNet->add(ReLU()); + leNet->add(Pool2D(2, 2, 2, 2)); - leNet->add(View(Shape({16 * 5 * 5}))); + leNet->add(View(Shape({16 * 5 * 5}))); - leNet->add(Linear(16 * 5 * 5, 120)); - leNet->add(ReLU()); + leNet->add(Linear(16 * 5 * 5, 120)); + leNet->add(ReLU()); - leNet->add(Linear(120, 84)); - leNet->add(ReLU()); + leNet->add(Linear(120, 84)); + leNet->add(ReLU()); - leNet->add(Linear(84, 10)); + leNet->add(Linear(84, 10)); - const fs::path path = fs::temp_directory_path() / "LeNet.mdl"; - save(path, leNet); + const fs::path path = fs::temp_directory_path() / "LeNet.mdl"; + 
save(path, leNet); - std::shared_ptr leNet2; - load(path, leNet2); - ASSERT_TRUE(leNet2); + std::shared_ptr leNet2; + load(path, leNet2); + ASSERT_TRUE(leNet2); - ASSERT_TRUE(allParamsClose(*leNet2, *leNet)); + ASSERT_TRUE(allParamsClose(*leNet2, *leNet)); - auto in = input(fl::rand({32, 32, 3, 1})); - ASSERT_TRUE(allClose(leNet2->forward(in), leNet->forward(in))); + auto in = input(fl::rand({32, 32, 3, 1})); + ASSERT_TRUE(allClose(leNet2->forward(in), leNet->forward(in))); } // Make sure serialized file size if not too high TEST(NNSerializationTest, FileSize) { - auto conv = std::make_shared(300, 600, 10, 10); - - const fs::path path = fs::temp_directory_path() / "FileSize.txt"; - save(path, conv); - ASSERT_LT(filesizebytes(), paramsizebytes(conv->params()) * kThreshold); - - auto seq = Sequential(); - seq.add(Conv2D(64, 64, 3, 100)); - seq.add(ReLU()); - seq.add(Pool2D(2, 2, 2, 2)); - seq.add(Conv2D(64, 64, 100, 200)); - seq.add(ReLU()); - seq.add(Pool2D(2, 2, 2, 2)); - seq.add(Linear(200, 500)); - seq.add(MeanSquaredError()); - save(path, seq); - ASSERT_LT(filesizebytes(), paramsizebytes(seq.params()) * kThreshold); + auto conv = std::make_shared(300, 600, 10, 10); + + const fs::path path = fs::temp_directory_path() / "FileSize.txt"; + save(path, conv); + ASSERT_LT(filesizebytes(), paramsizebytes(conv->params()) * kThreshold); + + auto seq = Sequential(); + seq.add(Conv2D(64, 64, 3, 100)); + seq.add(ReLU()); + seq.add(Pool2D(2, 2, 2, 2)); + seq.add(Conv2D(64, 64, 100, 200)); + seq.add(ReLU()); + seq.add(Pool2D(2, 2, 2, 2)); + seq.add(Linear(200, 500)); + seq.add(MeanSquaredError()); + save(path, seq); + ASSERT_LT(filesizebytes(), paramsizebytes(seq.params()) * kThreshold); } TEST(NNSerializationTest, VariableTwice) { - Variable v(Tensor({1000, 1000}), false); - auto v2 = v; // The array for this variable shouldn't be saved again + Variable v(Tensor({1000, 1000}), false); + auto v2 = v; // The array for this variable shouldn't be saved again - const fs::path path 
= fs::temp_directory_path() / "ContainerWithParams.mdl"; - save(path, v2, v); + const fs::path path = fs::temp_directory_path() / "ContainerWithParams.mdl"; + save(path, v2, v); - ASSERT_LT( - static_cast(fs::file_size(path)), - paramsizebytes({v}) * kThreshold); + ASSERT_LT( + static_cast(fs::file_size(path)), + paramsizebytes({v}) * kThreshold + ); } TEST(NNSerializationTest, ContainerBackward) { - auto seq = std::make_shared(); - seq->add(Linear(10, 20)); - seq->add(ReLU()); - seq->add(Linear(20, 30)); - - const fs::path path = fs::temp_directory_path() / "ContainerBackward.mdl"; - save(path, static_cast(seq)); - - ModulePtr seq2; - load(path, seq2); - - auto in = input(fl::rand({10, 10})); - auto output = seq2->forward({in}).front(); - output.backward(); - for (auto& p : seq2->params()) { - ASSERT_TRUE(p.isGradAvailable()); - } + auto seq = std::make_shared(); + seq->add(Linear(10, 20)); + seq->add(ReLU()); + seq->add(Linear(20, 30)); + + const fs::path path = fs::temp_directory_path() / "ContainerBackward.mdl"; + save(path, static_cast(seq)); + + ModulePtr seq2; + load(path, seq2); + + auto in = input(fl::rand({10, 10})); + auto output = seq2->forward({in}).front(); + output.backward(); + for(auto& p : seq2->params()) { + ASSERT_TRUE(p.isGradAvailable()); + } } TEST(NNSerializationTest, ContainerWithParams) { - auto seq = std::make_shared(); - seq->addParam(Variable(fl::rand({5, 5}), true)); - seq->add(WeightNorm(Linear(10, 20), 0)); - seq->addParam(Variable(fl::rand({5, 5}), true)); - seq->add(ReLU()); - seq->add(Linear(20, 30)); - seq->addParam(Variable(fl::rand({5, 5}), true)); + auto seq = std::make_shared(); + seq->addParam(Variable(fl::rand({5, 5}), true)); + seq->add(WeightNorm(Linear(10, 20), 0)); + seq->addParam(Variable(fl::rand({5, 5}), true)); + seq->add(ReLU()); + seq->add(Linear(20, 30)); + seq->addParam(Variable(fl::rand({5, 5}), true)); - const fs::path path = fs::temp_directory_path() / "ContainerWithParams.mdl"; - save(path, 
static_cast(seq)); + const fs::path path = fs::temp_directory_path() / "ContainerWithParams.mdl"; + save(path, static_cast(seq)); - ModulePtr seq2; - load(path, seq2); - ASSERT_TRUE(seq2); + ModulePtr seq2; + load(path, seq2); + ASSERT_TRUE(seq2); - ASSERT_TRUE(allParamsClose(*seq, *seq2)); + ASSERT_TRUE(allParamsClose(*seq, *seq2)); - auto in = input(fl::rand({10, 10})); - ASSERT_TRUE(allClose(seq->forward(in), seq2->forward({in}).front())); + auto in = input(fl::rand({10, 10})); + ASSERT_TRUE(allClose(seq->forward(in), seq2->forward({in}).front())); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/nn/NNUtilsTest.cpp b/flashlight/fl/test/nn/NNUtilsTest.cpp index 448046f..a07f731 100644 --- a/flashlight/fl/test/nn/NNUtilsTest.cpp +++ b/flashlight/fl/test/nn/NNUtilsTest.cpp @@ -16,103 +16,145 @@ using namespace fl; TEST(UtilsTest, Join) { - // Empty vector - auto empty = join({}); - ASSERT_TRUE(empty.isEmpty()); + // Empty vector + auto empty = join({}); + ASSERT_TRUE(empty.isEmpty()); - // Single array - auto i = fl::rand({50, 60, 70, 1}); - auto o = join({i}, -1, 3); - ASSERT_TRUE(fl::all(o == i).asScalar()); + // Single array + auto i = fl::rand({50, 60, 70, 1}); + auto o = join({i}, -1, 3); + ASSERT_TRUE(fl::all(o == i).asScalar()); - // no dim for batching adds singleton dims - ASSERT_EQ( - join({fl::rand({50, 60, 70})}, -1, 3).shape(), Shape({50, 60, 70, 1})); - ASSERT_EQ(join({fl::rand({50, 60})}, -1, 3).shape(), Shape({50, 60, 1, 1})); + // no dim for batching adds singleton dims + ASSERT_EQ( + join({fl::rand({50, 60, 70})}, -1, 3).shape(), + Shape({50, 60, 70, 1}) + ); + ASSERT_EQ(join({fl::rand({50, 60})}, -1, 3).shape(), Shape({50, 60, 1, 1})); - // more than one array - auto a = fl::full({25, 1, 300, 1}, 1); - auto b = fl::full({20, 1, 300, 1}, 2); - auto c = fl::full({30, 
1, 300, 1}, 3); + // more than one array + auto a = fl::full({25, 1, 300, 1}, 1); + auto b = fl::full({20, 1, 300, 1}, 2); + auto c = fl::full({30, 1, 300, 1}, 3); - auto o1 = join({a, b, c}); - ASSERT_EQ(o1.shape(), Shape({30, 1, 300, 3})); - ASSERT_TRUE( - fl::all( - o1(fl::range(25), fl::range(0, 1), fl::range(300), fl::range(0, 1)) == - a) - .asScalar()); - ASSERT_TRUE(fl::all( - o1(fl::range(25, 29), - fl::range(0, 1), - fl::range(300), - fl::range(0, 1)) == 0) - .asScalar()); - ASSERT_TRUE( - fl::all( - o1(fl::range(20), fl::range(0, 1), fl::range(300), fl::range(1, 2)) == - b) - .asScalar()); - ASSERT_TRUE(fl::all( - o1(fl::range(20, 29), - fl::range(0, 1), - fl::range(300), - fl::range(1, 2)) == 0) - .asScalar()); - ASSERT_TRUE( - fl::all( - o1(fl::range(30), fl::range(0, 1), fl::range(300), fl::range(2, 3)) == - c) - .asScalar()); + auto o1 = join({a, b, c}); + ASSERT_EQ(o1.shape(), Shape({30, 1, 300, 3})); + ASSERT_TRUE( + fl::all( + o1(fl::range(25), fl::range(0, 1), fl::range(300), fl::range(0, 1)) + == a + ) + .asScalar() + ); + ASSERT_TRUE( + fl::all( + o1( + fl::range(25, 29), + fl::range(0, 1), + fl::range(300), + fl::range(0, 1) + ) == 0 + ) + .asScalar() + ); + ASSERT_TRUE( + fl::all( + o1(fl::range(20), fl::range(0, 1), fl::range(300), fl::range(1, 2)) + == b + ) + .asScalar() + ); + ASSERT_TRUE( + fl::all( + o1( + fl::range(20, 29), + fl::range(0, 1), + fl::range(300), + fl::range(1, 2) + ) == 0 + ) + .asScalar() + ); + ASSERT_TRUE( + fl::all( + o1(fl::range(30), fl::range(0, 1), fl::range(300), fl::range(2, 3)) + == c + ) + .asScalar() + ); - auto o2 = join({a, b, c}, -1); - ASSERT_EQ(o2.shape(), Shape({30, 1, 300, 3})); - ASSERT_TRUE( - fl::all( - o2(fl::range(25), fl::range(0, 1), fl::range(300), fl::range(0, 1)) == - a) - .asScalar()); - ASSERT_TRUE(fl::all( - o2(fl::range(25, 29), - fl::range(0, 1), - fl::range(300), - fl::range(0, 1)) == -1) - .asScalar()); - ASSERT_TRUE( - fl::all( - o2(fl::range(20), fl::range(0, 1), fl::range(300), 
fl::range(1, 2)) == - b) - .asScalar()); - ASSERT_TRUE(fl::all( - o2(fl::range(20, 29), - fl::range(0, 1), - fl::range(300), - fl::range(1, 2)) == -1) - .asScalar()); - ASSERT_TRUE( - fl::all( - o2(fl::range(30), fl::range(0, 1), fl::range(300), fl::range(2, 3)) == - c) - .asScalar()); + auto o2 = join({a, b, c}, -1); + ASSERT_EQ(o2.shape(), Shape({30, 1, 300, 3})); + ASSERT_TRUE( + fl::all( + o2(fl::range(25), fl::range(0, 1), fl::range(300), fl::range(0, 1)) + == a + ) + .asScalar() + ); + ASSERT_TRUE( + fl::all( + o2( + fl::range(25, 29), + fl::range(0, 1), + fl::range(300), + fl::range(0, 1) + ) == -1 + ) + .asScalar() + ); + ASSERT_TRUE( + fl::all( + o2(fl::range(20), fl::range(0, 1), fl::range(300), fl::range(1, 2)) + == b + ) + .asScalar() + ); + ASSERT_TRUE( + fl::all( + o2( + fl::range(20, 29), + fl::range(0, 1), + fl::range(300), + fl::range(1, 2) + ) == -1 + ) + .asScalar() + ); + ASSERT_TRUE( + fl::all( + o2(fl::range(30), fl::range(0, 1), fl::range(300), fl::range(2, 3)) + == c + ) + .asScalar() + ); - auto o3 = join({a, b, c}, -1, 1); - ASSERT_EQ(o3.shape(), Shape({30, 3, 300, 1})); - ASSERT_TRUE(fl::all(o3(fl::range(25), fl::range(0, 1), fl::range(300)) == a) - .asScalar()); - ASSERT_TRUE( - fl::all(o3(fl::range(25, 29), fl::range(0, 1), fl::range(300)) == -1) - .asScalar()); - ASSERT_TRUE(fl::all(o3(fl::range(20), fl::range(1, 2), fl::range(300)) == b) - .asScalar()); - ASSERT_TRUE( - fl::all(o3(fl::range(20, 29), fl::range(1, 2), fl::range(300)) == -1) - .asScalar()); - ASSERT_TRUE(fl::all(o3(fl::range(30), fl::range(2, 3), fl::range(300)) == c) - .asScalar()); + auto o3 = join({a, b, c}, -1, 1); + ASSERT_EQ(o3.shape(), Shape({30, 3, 300, 1})); + ASSERT_TRUE( + fl::all(o3(fl::range(25), fl::range(0, 1), fl::range(300)) == a) + .asScalar() + ); + ASSERT_TRUE( + fl::all(o3(fl::range(25, 29), fl::range(0, 1), fl::range(300)) == -1) + .asScalar() + ); + ASSERT_TRUE( + fl::all(o3(fl::range(20), fl::range(1, 2), fl::range(300)) == b) + .asScalar() + ); + 
ASSERT_TRUE( + fl::all(o3(fl::range(20, 29), fl::range(1, 2), fl::range(300)) == -1) + .asScalar() + ); + ASSERT_TRUE( + fl::all(o3(fl::range(30), fl::range(2, 3), fl::range(300)) == c) + .asScalar() + ); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/optim/OptimBenchmark.cpp b/flashlight/fl/test/optim/OptimBenchmark.cpp index a2de9d6..4562d39 100644 --- a/flashlight/fl/test/optim/OptimBenchmark.cpp +++ b/flashlight/fl/test/optim/OptimBenchmark.cpp @@ -16,76 +16,76 @@ using namespace fl; -#define TIME(FUNC) \ - std::cout << "Timing " << #FUNC << " ... " << std::flush; \ - std::cout << std::setprecision(5) << FUNC() * 1000.0 << " msec" << std::endl; +#define TIME(FUNC) \ + std::cout << "Timing " << #FUNC << " ... " << std::flush; \ + std::cout << std::setprecision(5) << FUNC() * 1000.0 << " msec" << std::endl; double timeit(std::function fn) { - // warmup - for (int i = 0; i < 10; ++i) { - fn(); - } - fl::sync(); + // warmup + for(int i = 0; i < 10; ++i) { + fn(); + } + fl::sync(); - int num_iters = 100; - fl::sync(); - auto start = fl::Timer::start(); - for (int i = 0; i < num_iters; i++) { - fn(); - } - fl::sync(); - return fl::Timer::stop(start) / num_iters; + int num_iters = 100; + fl::sync(); + auto start = fl::Timer::start(); + for(int i = 0; i < num_iters; i++) { + fn(); + } + fl::sync(); + return fl::Timer::stop(start) / num_iters; } double optloop(FirstOrderOptimizer& opt, const Variable& w) { - auto input = Variable(fl::randn({10, 10}), false); - auto fn = [&]() { - for (int it = 0; it < 100; it++) { - opt.zeroGrad(); - auto loss = fl::matmul(w, input); - loss.backward(); - opt.step(); - } - }; - return timeit(fn); + auto input = Variable(fl::randn({10, 10}), false); + auto fn = [&]() { + for(int it = 0; it < 100; it++) { + opt.zeroGrad(); + auto loss = fl::matmul(w, 
input); + loss.backward(); + opt.step(); + } + }; + return timeit(fn); } double sgd() { - auto w = Variable(fl::randn({1, 10}), true); - auto opt = SGDOptimizer({w}, 1e-3); - return optloop(opt, w); + auto w = Variable(fl::randn({1, 10}), true); + auto opt = SGDOptimizer({w}, 1e-3); + return optloop(opt, w); } double adam() { - auto w = Variable(fl::randn({1, 10}), true); - auto opt = AdamOptimizer({w}, 1e-3); - return optloop(opt, w); + auto w = Variable(fl::randn({1, 10}), true); + auto opt = AdamOptimizer({w}, 1e-3); + return optloop(opt, w); } double rmsprop() { - auto w = Variable(fl::randn({1, 10}), true); - auto opt = RMSPropOptimizer({w}, 1e-3); - return optloop(opt, w); + auto w = Variable(fl::randn({1, 10}), true); + auto opt = RMSPropOptimizer({w}, 1e-3); + return optloop(opt, w); } double adadelta() { - auto w = Variable(fl::randn({1, 10}), true); - auto opt = AdadeltaOptimizer({w}); - return optloop(opt, w); + auto w = Variable(fl::randn({1, 10}), true); + auto opt = AdadeltaOptimizer({w}); + return optloop(opt, w); } double nag() { - auto w = Variable(fl::randn({1, 10}), true); - auto opt = NAGOptimizer({w}, 1e-3); - return optloop(opt, w); + auto w = Variable(fl::randn({1, 10}), true); + auto opt = NAGOptimizer({w}, 1e-3); + return optloop(opt, w); } int main() { - fl::init(); - TIME(sgd); - TIME(nag); - TIME(adam); - TIME(rmsprop); - TIME(adadelta); - return 0; + fl::init(); + TIME(sgd); + TIME(nag); + TIME(adam); + TIME(rmsprop); + TIME(adadelta); + return 0; } diff --git a/flashlight/fl/test/optim/OptimTest.cpp b/flashlight/fl/test/optim/OptimTest.cpp index 62abbd0..3951fc2 100644 --- a/flashlight/fl/test/optim/OptimTest.cpp +++ b/flashlight/fl/test/optim/OptimTest.cpp @@ -16,102 +16,108 @@ using namespace fl; TEST(OptimTest, GradNorm) { - std::vector parameters; - for (int i = 0; i < 5; i++) { - auto v = Variable(fl::randn({10, 10, 10}), true); - v = v.astype(fl::dtype::f64); - v.addGrad(Variable(fl::randn({10, 10, 10}, fl::dtype::f64), false)); 
- parameters.push_back(v); - } - double max_norm = 5.0; - clipGradNorm(parameters, max_norm); - - double clipped = 0.0; - for (auto& v : parameters) { - auto& g = v.grad().tensor(); - clipped += fl::sum(g * g).asScalar(); - } - clipped = std::sqrt(clipped); - ASSERT_TRUE(allClose(fl::full({1}, max_norm), fl::full({1}, clipped))); + std::vector parameters; + for(int i = 0; i < 5; i++) { + auto v = Variable(fl::randn({10, 10, 10}), true); + v = v.astype(fl::dtype::f64); + v.addGrad(Variable(fl::randn({10, 10, 10}, fl::dtype::f64), false)); + parameters.push_back(v); + } + double max_norm = 5.0; + clipGradNorm(parameters, max_norm); + + double clipped = 0.0; + for(auto& v : parameters) { + auto& g = v.grad().tensor(); + clipped += fl::sum(g * g).asScalar(); + } + clipped = std::sqrt(clipped); + ASSERT_TRUE(allClose(fl::full({1}, max_norm), fl::full({1}, clipped))); } TEST(OptimTest, GradNormF16) { - if (!fl::f16Supported()) { - GTEST_SKIP() << "Half-precision not supported on this device"; - } - - std::vector parameters; - for (int i = 0; i < 5; i++) { - auto v = Variable(fl::randn({10, 10, 10}), true); - v = v.astype(fl::dtype::f16); - v.addGrad(Variable(fl::randn({10, 10, 10}, fl::dtype::f16), false)); - parameters.push_back(v); - } - double max_norm = 5.0; - clipGradNorm(parameters, max_norm); - - double clipped = 0.0; - for (auto& v : parameters) { - auto& g = v.grad().tensor(); - clipped += fl::sum(g * g).asScalar(); - } - clipped = std::sqrt(clipped); - ASSERT_TRUE(allClose(fl::full({1}, max_norm), fl::full({1}, clipped), 1e-2)); + if(!fl::f16Supported()) { + GTEST_SKIP() << "Half-precision not supported on this device"; + } + + std::vector parameters; + for(int i = 0; i < 5; i++) { + auto v = Variable(fl::randn({10, 10, 10}), true); + v = v.astype(fl::dtype::f16); + v.addGrad(Variable(fl::randn({10, 10, 10}, fl::dtype::f16), false)); + parameters.push_back(v); + } + double max_norm = 5.0; + clipGradNorm(parameters, max_norm); + + double clipped = 0.0; + 
for(auto& v : parameters) { + auto& g = v.grad().tensor(); + clipped += fl::sum(g * g).asScalar(); + } + clipped = std::sqrt(clipped); + ASSERT_TRUE(allClose(fl::full({1}, max_norm), fl::full({1}, clipped), 1e-2)); } TEST(SerializationTest, OptimizerSerialize) { - const fs::path path = fs::temp_directory_path() / "optmizer.bin"; - - std::vector parameters; - for (int i = 0; i < 5; i++) { - auto v = Variable(fl::randn({10, 10, 10}, fl::dtype::f64), true); - v.addGrad(Variable(fl::randn({10, 10, 10}, fl::dtype::f64), false)); - parameters.push_back(v); - } - - std::shared_ptr opt; - opt = std::make_shared(parameters, 0.0001f); - opt->step(); - - save( - path, parameters, static_cast>(opt)); - - std::vector parameters2; - std::shared_ptr opt2; - load(path, parameters2, opt2); - - for (int i = 0; i < 5; i++) { - parameters2[i].addGrad(Variable(parameters[i].grad().tensor(), false)); - } - - opt->step(); - opt2->step(); - - for (int i = 0; i < 5; i++) { - ASSERT_TRUE(allClose(parameters[i].tensor(), parameters2[i].tensor())); - } - - opt = std::make_shared(parameters, 0.01f); - opt->step(); - - save( - path, parameters, static_cast>(opt)); - load(path, parameters2, opt2); - - for (int i = 0; i < 5; i++) { - parameters2[i].addGrad(Variable(parameters[i].grad().tensor(), false)); - } - - opt->step(); - opt2->step(); - - for (int i = 0; i < 5; i++) { - ASSERT_TRUE(allClose(parameters[i].tensor(), parameters2[i].tensor())); - } + const fs::path path = fs::temp_directory_path() / "optmizer.bin"; + + std::vector parameters; + for(int i = 0; i < 5; i++) { + auto v = Variable(fl::randn({10, 10, 10}, fl::dtype::f64), true); + v.addGrad(Variable(fl::randn({10, 10, 10}, fl::dtype::f64), false)); + parameters.push_back(v); + } + + std::shared_ptr opt; + opt = std::make_shared(parameters, 0.0001f); + opt->step(); + + save( + path, + parameters, + static_cast>(opt) + ); + + std::vector parameters2; + std::shared_ptr opt2; + load(path, parameters2, opt2); + + for(int i = 0; i < 5; 
i++) { + parameters2[i].addGrad(Variable(parameters[i].grad().tensor(), false)); + } + + opt->step(); + opt2->step(); + + for(int i = 0; i < 5; i++) { + ASSERT_TRUE(allClose(parameters[i].tensor(), parameters2[i].tensor())); + } + + opt = std::make_shared(parameters, 0.01f); + opt->step(); + + save( + path, + parameters, + static_cast>(opt) + ); + load(path, parameters2, opt2); + + for(int i = 0; i < 5; i++) { + parameters2[i].addGrad(Variable(parameters[i].grad().tensor(), false)); + } + + opt->step(); + opt2->step(); + + for(int i = 0; i < 5; i++) { + ASSERT_TRUE(allClose(parameters[i].tensor(), parameters2[i].tensor())); + } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/runtime/CUDADeviceTest.cpp b/flashlight/fl/test/runtime/CUDADeviceTest.cpp index 0a10409..7afb356 100644 --- a/flashlight/fl/test/runtime/CUDADeviceTest.cpp +++ b/flashlight/fl/test/runtime/CUDADeviceTest.cpp @@ -18,31 +18,31 @@ using fl::DeviceManager; using fl::DeviceType; TEST(CUDADeviceTest, impl) { - auto& manager = DeviceManager::getInstance(); + auto& manager = DeviceManager::getInstance(); - auto& cudaDevice = manager.getActiveDevice(DeviceType::CUDA); - ASSERT_NO_THROW(cudaDevice.impl()); - ASSERT_THROW(cudaDevice.impl(), std::invalid_argument); + auto& cudaDevice = manager.getActiveDevice(DeviceType::CUDA); + ASSERT_NO_THROW(cudaDevice.impl()); + ASSERT_THROW(cudaDevice.impl(), std::invalid_argument); - auto& x64Device = manager.getActiveDevice(DeviceType::x64); - ASSERT_NO_THROW(x64Device.impl()); - ASSERT_THROW(x64Device.impl(), std::invalid_argument); + auto& x64Device = manager.getActiveDevice(DeviceType::x64); + ASSERT_NO_THROW(x64Device.impl()); + ASSERT_THROW(x64Device.impl(), std::invalid_argument); } TEST(CUDADeviceTest, nativeId) { - auto& manager = DeviceManager::getInstance(); - int 
numCudaDevices = 0; - cudaGetDeviceCount(&numCudaDevices); - - for (auto id = 0; id < numCudaDevices; id++) { - auto& cudaDevice = - manager.getDevice(DeviceType::CUDA, id).impl(); - ASSERT_EQ(cudaDevice.nativeId(), id); - } + auto& manager = DeviceManager::getInstance(); + int numCudaDevices = 0; + cudaGetDeviceCount(&numCudaDevices); + + for(auto id = 0; id < numCudaDevices; id++) { + auto& cudaDevice = + manager.getDevice(DeviceType::CUDA, id).impl(); + ASSERT_EQ(cudaDevice.nativeId(), id); + } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/runtime/CUDAStreamTest.cpp b/flashlight/fl/test/runtime/CUDAStreamTest.cpp index 8f061be..8a53e6b 100644 --- a/flashlight/fl/test/runtime/CUDAStreamTest.cpp +++ b/flashlight/fl/test/runtime/CUDAStreamTest.cpp @@ -21,99 +21,99 @@ using fl::Stream; using fl::StreamType; TEST(CUDAStreamTest, createManaged) { - auto& manager = DeviceManager::getInstance(); - for (auto cudaDevice : manager.getDevicesOfType(DeviceType::CUDA)) { - cudaDevice->setActive(); - auto cudaStream = CUDAStream::createManaged(); - - ASSERT_EQ(cudaStream->type, StreamType::CUDA); - ASSERT_EQ(&cudaStream->device(), cudaDevice); - ASSERT_EQ(&cudaStream->impl(), cudaStream.get()); - } + auto& manager = DeviceManager::getInstance(); + for(auto cudaDevice : manager.getDevicesOfType(DeviceType::CUDA)) { + cudaDevice->setActive(); + auto cudaStream = CUDAStream::createManaged(); + + ASSERT_EQ(cudaStream->type, StreamType::CUDA); + ASSERT_EQ(&cudaStream->device(), cudaDevice); + ASSERT_EQ(&cudaStream->impl(), cudaStream.get()); + } } TEST(CUDAStreamTest, createUnmanaged) { - auto& manager = DeviceManager::getInstance(); - for (auto cudaDevice : manager.getDevicesOfType(DeviceType::CUDA)) { - cudaDevice->setActive(); - auto cudaStream = CUDAStream::createUnmanaged(); - - 
ASSERT_EQ(cudaStream->type, StreamType::CUDA); - ASSERT_EQ(&cudaStream->device(), cudaDevice); - ASSERT_EQ(&cudaStream->impl(), cudaStream.get()); - // safe to destroy since underlying stream isn't managed. - FL_CUDA_CHECK(cudaStreamDestroy(cudaStream->handle())); - } + auto& manager = DeviceManager::getInstance(); + for(auto cudaDevice : manager.getDevicesOfType(DeviceType::CUDA)) { + cudaDevice->setActive(); + auto cudaStream = CUDAStream::createUnmanaged(); + + ASSERT_EQ(cudaStream->type, StreamType::CUDA); + ASSERT_EQ(&cudaStream->device(), cudaDevice); + ASSERT_EQ(&cudaStream->impl(), cudaStream.get()); + // safe to destroy since underlying stream isn't managed. + FL_CUDA_CHECK(cudaStreamDestroy(cudaStream->handle())); + } } TEST(CUDAStreamTest, unmanagedWrapper) { - auto& manager = DeviceManager::getInstance(); - int numCudaDevices = 0; - FL_CUDA_CHECK(cudaGetDeviceCount(&numCudaDevices)); - - for (int id = 0; id < numCudaDevices; id++) { - FL_CUDA_CHECK(cudaSetDevice(id)); - cudaStream_t nativeStream; - FL_CUDA_CHECK(cudaStreamCreate(&nativeStream)); - auto& cudaDevice = manager.getDevice(DeviceType::CUDA, id); - auto cudaStream = CUDAStream::wrapUnmanaged(id, nativeStream); - - ASSERT_EQ(cudaStream->type, StreamType::CUDA); - ASSERT_EQ(&cudaStream->device(), &cudaDevice); - ASSERT_EQ(cudaStream->handle(), cudaStream->handle()); - ASSERT_EQ(&cudaStream->impl(), cudaStream.get()); - // safe to destroy since wrapper won't manage underlying stream by default. 
- FL_CUDA_CHECK(cudaStreamDestroy(nativeStream)); - } + auto& manager = DeviceManager::getInstance(); + int numCudaDevices = 0; + FL_CUDA_CHECK(cudaGetDeviceCount(&numCudaDevices)); + + for(int id = 0; id < numCudaDevices; id++) { + FL_CUDA_CHECK(cudaSetDevice(id)); + cudaStream_t nativeStream; + FL_CUDA_CHECK(cudaStreamCreate(&nativeStream)); + auto& cudaDevice = manager.getDevice(DeviceType::CUDA, id); + auto cudaStream = CUDAStream::wrapUnmanaged(id, nativeStream); + + ASSERT_EQ(cudaStream->type, StreamType::CUDA); + ASSERT_EQ(&cudaStream->device(), &cudaDevice); + ASSERT_EQ(cudaStream->handle(), cudaStream->handle()); + ASSERT_EQ(&cudaStream->impl(), cudaStream.get()); + // safe to destroy since wrapper won't manage underlying stream by default. + FL_CUDA_CHECK(cudaStreamDestroy(nativeStream)); + } } TEST(CUDAStreamTest, unmanagedWrapperDeviceSwitch) { - auto& manager = DeviceManager::getInstance(); - if (manager.getDeviceCount(DeviceType::CUDA) > 1) { - const auto& device0 = manager.getDevice(DeviceType::CUDA, 0); - const auto& device1 = manager.getDevice(DeviceType::CUDA, 1); - - device0.setActive(); - cudaStream_t nativeStream; - FL_CUDA_CHECK(cudaStreamCreate(&nativeStream)); - device1.setActive(); - // wrapper will switch to device0, create, and then switch back. 
- auto cudaStreamWrapped = CUDAStream::wrapUnmanaged(device0.nativeId(), nativeStream); - auto cudaStreamCreated = CUDAStream::createManaged(); - - // nothing blows up -- event was created and used correctly - ASSERT_NO_THROW(cudaStreamWrapped->relativeSync(*cudaStreamCreated)); - ASSERT_NO_THROW(cudaStreamCreated->relativeSync(*cudaStreamWrapped)); - - ASSERT_EQ(&cudaStreamWrapped->device(), &device0); - ASSERT_EQ(&cudaStreamCreated->device(), &device1); - ASSERT_EQ(&manager.getActiveDevice(DeviceType::CUDA), &device1); - } + auto& manager = DeviceManager::getInstance(); + if(manager.getDeviceCount(DeviceType::CUDA) > 1) { + const auto& device0 = manager.getDevice(DeviceType::CUDA, 0); + const auto& device1 = manager.getDevice(DeviceType::CUDA, 1); + + device0.setActive(); + cudaStream_t nativeStream; + FL_CUDA_CHECK(cudaStreamCreate(&nativeStream)); + device1.setActive(); + // wrapper will switch to device0, create, and then switch back. + auto cudaStreamWrapped = CUDAStream::wrapUnmanaged(device0.nativeId(), nativeStream); + auto cudaStreamCreated = CUDAStream::createManaged(); + + // nothing blows up -- event was created and used correctly + ASSERT_NO_THROW(cudaStreamWrapped->relativeSync(*cudaStreamCreated)); + ASSERT_NO_THROW(cudaStreamCreated->relativeSync(*cudaStreamWrapped)); + + ASSERT_EQ(&cudaStreamWrapped->device(), &device0); + ASSERT_EQ(&cudaStreamCreated->device(), &device1); + ASSERT_EQ(&manager.getActiveDevice(DeviceType::CUDA), &device1); + } } TEST(CUDAStreamTest, relativeSync) { - auto cs1 = CUDAStream::createManaged(); - auto cs2 = CUDAStream::createManaged(); - std::shared_ptr s1 = cs1; - std::shared_ptr s2 = cs2; - ASSERT_NO_THROW(s1->relativeSync(*s2)); - ASSERT_NO_THROW(s1->relativeSync(*cs2)); - ASSERT_NO_THROW(cs1->relativeSync(*s2)); - ASSERT_NO_THROW(cs1->relativeSync(*cs2)); - - std::unordered_set streams { s1.get(), s2.get() }; - std::shared_ptr s3 = CUDAStream::createManaged(); - ASSERT_NO_THROW(s3->relativeSync(streams)); + auto cs1 
= CUDAStream::createManaged(); + auto cs2 = CUDAStream::createManaged(); + std::shared_ptr s1 = cs1; + std::shared_ptr s2 = cs2; + ASSERT_NO_THROW(s1->relativeSync(*s2)); + ASSERT_NO_THROW(s1->relativeSync(*cs2)); + ASSERT_NO_THROW(cs1->relativeSync(*s2)); + ASSERT_NO_THROW(cs1->relativeSync(*cs2)); + + std::unordered_set streams {s1.get(), s2.get()}; + std::shared_ptr s3 = CUDAStream::createManaged(); + ASSERT_NO_THROW(s3->relativeSync(streams)); } TEST(CUDAStreamTest, sync) { - auto cs1 = CUDAStream::createManaged(); - ASSERT_NO_THROW(cs1->sync()); + auto cs1 = CUDAStream::createManaged(); + ASSERT_NO_THROW(cs1->sync()); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/runtime/DeviceManagerTest.cpp b/flashlight/fl/test/runtime/DeviceManagerTest.cpp index bbe6426..a3c7ba4 100644 --- a/flashlight/fl/test/runtime/DeviceManagerTest.cpp +++ b/flashlight/fl/test/runtime/DeviceManagerTest.cpp @@ -16,69 +16,73 @@ using fl::DeviceManager; using fl::DeviceType; TEST(DeviceManagerTest, getInstance) { - ASSERT_EQ(&DeviceManager::getInstance(), &DeviceManager::getInstance()); + ASSERT_EQ(&DeviceManager::getInstance(), &DeviceManager::getInstance()); } TEST(DeviceManagerTest, isDeviceTypeAvailable) { - auto& manager = DeviceManager::getInstance(); - // x64 (CPU) should be always available - ASSERT_TRUE(manager.isDeviceTypeAvailable(DeviceType::x64)); + auto& manager = DeviceManager::getInstance(); + // x64 (CPU) should be always available + ASSERT_TRUE(manager.isDeviceTypeAvailable(DeviceType::x64)); - // CUDA availability depends on compilation - bool expectCUDA = FL_BACKEND_CUDA; - ASSERT_EQ(manager.isDeviceTypeAvailable(DeviceType::CUDA), expectCUDA); + // CUDA availability depends on compilation + bool expectCUDA = FL_BACKEND_CUDA; + ASSERT_EQ(manager.isDeviceTypeAvailable(DeviceType::CUDA), 
expectCUDA); } TEST(DeviceManagerTest, getDeviceCount) { - auto& manager = DeviceManager::getInstance(); - // For now we always treat CPU as a single device - ASSERT_EQ(manager.getDeviceCount(DeviceType::x64), 1); + auto& manager = DeviceManager::getInstance(); + // For now we always treat CPU as a single device + ASSERT_EQ(manager.getDeviceCount(DeviceType::x64), 1); - if (manager.isDeviceTypeAvailable(DeviceType::CUDA)) { - ASSERT_NO_THROW(manager.getDeviceCount(DeviceType::CUDA)); - } else { - ASSERT_THROW(manager.getDeviceCount(DeviceType::CUDA), - std::runtime_error); - } + if(manager.isDeviceTypeAvailable(DeviceType::CUDA)) { + ASSERT_NO_THROW(manager.getDeviceCount(DeviceType::CUDA)); + } else { + ASSERT_THROW( + manager.getDeviceCount(DeviceType::CUDA), + std::runtime_error + ); + } } TEST(DeviceManagerTest, getDevicesOfType) { - auto& manager = DeviceManager::getInstance(); - // For now we always treat CPU as a single device - ASSERT_EQ(manager.getDevicesOfType(DeviceType::x64).size(), 1); + auto& manager = DeviceManager::getInstance(); + // For now we always treat CPU as a single device + ASSERT_EQ(manager.getDevicesOfType(DeviceType::x64).size(), 1); - for (auto type : fl::getDeviceTypes()) { - if (manager.isDeviceTypeAvailable(DeviceType::CUDA)) { - for (auto device : manager.getDevicesOfType(type)) { - ASSERT_EQ(device->type(), type); - } - } else { - ASSERT_THROW(manager.getDeviceCount(DeviceType::CUDA), - std::runtime_error); + for(auto type : fl::getDeviceTypes()) { + if(manager.isDeviceTypeAvailable(DeviceType::CUDA)) { + for(auto device : manager.getDevicesOfType(type)) { + ASSERT_EQ(device->type(), type); + } + } else { + ASSERT_THROW( + manager.getDeviceCount(DeviceType::CUDA), + std::runtime_error + ); + } } - } } TEST(DeviceManagerTest, getDevice) { - auto& manager = DeviceManager::getInstance(); - auto& x64Device = - manager.getDevice(DeviceType::x64, fl::kX64DeviceId); - ASSERT_EQ(x64Device.type(), DeviceType::x64); + auto& manager = 
DeviceManager::getInstance(); + auto& x64Device = + manager.getDevice(DeviceType::x64, fl::kX64DeviceId); + ASSERT_EQ(x64Device.type(), DeviceType::x64); } TEST(DeviceManagerTest, getActiveDevice) { - auto& manager = DeviceManager::getInstance(); - for (auto type : fl::getDeviceTypes()) { - if (manager.isDeviceTypeAvailable(type)) { - ASSERT_EQ(manager.getActiveDevice(type).type(), type); - } else { - ASSERT_THROW(manager.getActiveDevice(type), std::runtime_error); + auto& manager = DeviceManager::getInstance(); + for(auto type : fl::getDeviceTypes()) { + if(manager.isDeviceTypeAvailable(type)) { + ASSERT_EQ(manager.getActiveDevice(type).type(), type); + } else { + ASSERT_THROW(manager.getActiveDevice(type), std::runtime_error); + } } - } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/runtime/DeviceTest.cpp b/flashlight/fl/test/runtime/DeviceTest.cpp index 72fd290..62fd078 100644 --- a/flashlight/fl/test/runtime/DeviceTest.cpp +++ b/flashlight/fl/test/runtime/DeviceTest.cpp @@ -14,76 +14,76 @@ using fl::DeviceManager; using fl::DeviceType; TEST(DeviceTest, type) { - auto& manager = DeviceManager::getInstance(); - for (auto type : fl::getDeviceTypes()) { - if (manager.isDeviceTypeAvailable(type)) { - for (auto* device : manager.getDevicesOfType(type)) { - ASSERT_EQ(device->type(), type); - } + auto& manager = DeviceManager::getInstance(); + for(auto type : fl::getDeviceTypes()) { + if(manager.isDeviceTypeAvailable(type)) { + for(auto* device : manager.getDevicesOfType(type)) { + ASSERT_EQ(device->type(), type); + } + } } - } } TEST(DeviceTest, nativeId) { - const auto& manager = DeviceManager::getInstance(); - for (const auto* device : manager.getDevicesOfType(DeviceType::x64)) { - ASSERT_EQ(device->nativeId(), fl::kX64DeviceId); - } + const auto& manager = 
DeviceManager::getInstance(); + for(const auto* device : manager.getDevicesOfType(DeviceType::x64)) { + ASSERT_EQ(device->nativeId(), fl::kX64DeviceId); + } } TEST(DeviceTest, setActive) { - auto& manager = DeviceManager::getInstance(); - for (auto type : fl::getDeviceTypes()) { - if (manager.isDeviceTypeAvailable(type)) { - for (auto* device : manager.getDevicesOfType(type)) { - device->setActive(); - ASSERT_EQ(&manager.getActiveDevice(type), device); - } + auto& manager = DeviceManager::getInstance(); + for(auto type : fl::getDeviceTypes()) { + if(manager.isDeviceTypeAvailable(type)) { + for(auto* device : manager.getDevicesOfType(type)) { + device->setActive(); + ASSERT_EQ(&manager.getActiveDevice(type), device); + } + } } - } } TEST(DeviceTest, addSetActiveCallback) { - auto& manager = DeviceManager::getInstance(); - for (const auto type : fl::getDeviceTypes()) { - if (manager.isDeviceTypeAvailable(type)) { - for (auto* device : manager.getDevicesOfType(type)) { - int count = 0; - auto incCount = [&count](int){ count++; }; - device->addSetActiveCallback(incCount); - device->setActive(); - ASSERT_EQ(count, 1); - } + auto& manager = DeviceManager::getInstance(); + for(const auto type : fl::getDeviceTypes()) { + if(manager.isDeviceTypeAvailable(type)) { + for(auto* device : manager.getDevicesOfType(type)) { + int count = 0; + auto incCount = [&count](int) { count++; }; + device->addSetActiveCallback(incCount); + device->setActive(); + ASSERT_EQ(count, 1); + } + } } - } } TEST(DeviceTest, sync) { - const auto& manager = DeviceManager::getInstance(); - for (const auto type : fl::getDeviceTypes()) { - if (manager.isDeviceTypeAvailable(type)) { - for (const auto* device : manager.getDevicesOfType(type)) { - ASSERT_NO_THROW(device->sync()); - } + const auto& manager = DeviceManager::getInstance(); + for(const auto type : fl::getDeviceTypes()) { + if(manager.isDeviceTypeAvailable(type)) { + for(const auto* device : manager.getDevicesOfType(type)) { + 
ASSERT_NO_THROW(device->sync()); + } + } } - } } TEST(DeviceTest, getStream) { - auto& manager = DeviceManager::getInstance(); - for (const auto type : fl::getDeviceTypes()) { - if (manager.isDeviceTypeAvailable(type)) { - for (const auto* device : manager.getDevicesOfType(type)) { - for (const auto& stream : device->getStreams()) { - ASSERT_EQ(&stream->device(), device); + auto& manager = DeviceManager::getInstance(); + for(const auto type : fl::getDeviceTypes()) { + if(manager.isDeviceTypeAvailable(type)) { + for(const auto* device : manager.getDevicesOfType(type)) { + for(const auto& stream : device->getStreams()) { + ASSERT_EQ(&stream->device(), device); + } + } } - } } - } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/runtime/DeviceTypeTest.cpp b/flashlight/fl/test/runtime/DeviceTypeTest.cpp index f98369c..9d42285 100644 --- a/flashlight/fl/test/runtime/DeviceTypeTest.cpp +++ b/flashlight/fl/test/runtime/DeviceTypeTest.cpp @@ -12,18 +12,18 @@ using fl::DeviceType; TEST(DeviceTypeTest, getAllDeviceTypes) { - const auto& allDevices = fl::getDeviceTypes(); - ASSERT_TRUE(allDevices.contains(DeviceType::x64)); - ASSERT_TRUE(allDevices.contains(DeviceType::CUDA)); - ASSERT_EQ(allDevices.size(), 2); + const auto& allDevices = fl::getDeviceTypes(); + ASSERT_TRUE(allDevices.contains(DeviceType::x64)); + ASSERT_TRUE(allDevices.contains(DeviceType::CUDA)); + ASSERT_EQ(allDevices.size(), 2); } TEST(DeviceTypeTest, deviceTypeToString) { - ASSERT_EQ(deviceTypeToString(DeviceType::x64), "x64"); - ASSERT_EQ(deviceTypeToString(DeviceType::CUDA), "CUDA"); + ASSERT_EQ(deviceTypeToString(DeviceType::x64), "x64"); + ASSERT_EQ(deviceTypeToString(DeviceType::CUDA), "CUDA"); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + 
::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/tensor/ComputeTest.cpp b/flashlight/fl/test/tensor/ComputeTest.cpp index a160ce4..21f1d8d 100644 --- a/flashlight/fl/test/tensor/ComputeTest.cpp +++ b/flashlight/fl/test/tensor/ComputeTest.cpp @@ -14,31 +14,31 @@ #include "flashlight/fl/tensor/TensorBase.h" TEST(TensorComputeTest, sync) { - // Testing whether a value is ready isn't meaningful since any function to - // inspect its state will implicitly synchronize -- this test simply ensures - // sync runs - auto t1 = fl::full({10, 10}, 1.); - auto t2 = fl::full({10, 10}, 2.); - auto t3 = t1 + t2; - fl::sync(); + // Testing whether a value is ready isn't meaningful since any function to + // inspect its state will implicitly synchronize -- this test simply ensures + // sync runs + auto t1 = fl::full({10, 10}, 1.); + auto t2 = fl::full({10, 10}, 2.); + auto t3 = t1 + t2; + fl::sync(); - int deviceId = fl::getDevice(); - auto t4 = t1 + t2 + t3; - fl::sync(deviceId); + int deviceId = fl::getDevice(); + auto t4 = t1 + t2 + t3; + fl::sync(deviceId); } TEST(TensorComputeTest, eval) { - // Testing whether a value is ready isn't meaningful since any function to - // inspect its state will implicitly synchronize -- this test simply ensures - // eval runs - auto t1 = fl::full({10, 10}, 3.); - auto t2 = fl::full({10, 10}, 4.); - auto t3 = t1 * t2; - fl::eval(t3); + // Testing whether a value is ready isn't meaningful since any function to + // inspect its state will implicitly synchronize -- this test simply ensures + // eval runs + auto t1 = fl::full({10, 10}, 3.); + auto t2 = fl::full({10, 10}, 4.); + auto t3 = t1 * t2; + fl::eval(t3); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/tensor/IndexTest.cpp 
b/flashlight/fl/test/tensor/IndexTest.cpp index bc99b62..36eedc2 100644 --- a/flashlight/fl/test/tensor/IndexTest.cpp +++ b/flashlight/fl/test/tensor/IndexTest.cpp @@ -16,256 +16,269 @@ using namespace ::testing; using namespace fl; TEST(IndexTest, range) { - auto s1 = fl::range(3); - ASSERT_EQ(s1.start(), 0); - ASSERT_EQ(s1.endVal(), 3); - ASSERT_EQ(s1.stride(), 1); - - auto s2 = fl::range(4, 5); - ASSERT_EQ(s2.start(), 4); - ASSERT_EQ(s2.endVal(), 5); - ASSERT_EQ(s2.stride(), 1); - - auto s3 = fl::range(7, 8, 9); - ASSERT_EQ(s3.stride(), 9); - - auto s4 = fl::range(1, fl::end, 2); - ASSERT_EQ(s4.start(), 1); - ASSERT_EQ(s4.end(), std::nullopt); - ASSERT_EQ(s4.stride(), 2); + auto s1 = fl::range(3); + ASSERT_EQ(s1.start(), 0); + ASSERT_EQ(s1.endVal(), 3); + ASSERT_EQ(s1.stride(), 1); + + auto s2 = fl::range(4, 5); + ASSERT_EQ(s2.start(), 4); + ASSERT_EQ(s2.endVal(), 5); + ASSERT_EQ(s2.stride(), 1); + + auto s3 = fl::range(7, 8, 9); + ASSERT_EQ(s3.stride(), 9); + + auto s4 = fl::range(1, fl::end, 2); + ASSERT_EQ(s4.start(), 1); + ASSERT_EQ(s4.end(), std::nullopt); + ASSERT_EQ(s4.stride(), 2); } TEST(IndexTest, rangeEq) { - ASSERT_EQ(fl::range(4), fl::range(4)); - ASSERT_EQ(fl::range(2, 3), fl::range(2, 3)); - ASSERT_EQ(fl::range(5, 6, 7), fl::range(5, 6, 7)); - ASSERT_NE(fl::range(5, 11, 7), fl::range(5, 6, 7)); + ASSERT_EQ(fl::range(4), fl::range(4)); + ASSERT_EQ(fl::range(2, 3), fl::range(2, 3)); + ASSERT_EQ(fl::range(5, 6, 7), fl::range(5, 6, 7)); + ASSERT_NE(fl::range(5, 11, 7), fl::range(5, 6, 7)); } TEST(IndexTest, Type) { - using namespace detail; - ASSERT_EQ(fl::Index(3).type(), IndexType::Literal); - ASSERT_EQ(fl::Index(fl::range(3)).type(), IndexType::Range); - ASSERT_EQ(fl::Index(fl::span).type(), IndexType::Span); - ASSERT_EQ(fl::Index(fl::full({2, 2}, 4)).type(), IndexType::Tensor); - ASSERT_TRUE(fl::Index(fl::span).isSpan()); + using namespace detail; + ASSERT_EQ(fl::Index(3).type(), IndexType::Literal); + ASSERT_EQ(fl::Index(fl::range(3)).type(), 
IndexType::Range); + ASSERT_EQ(fl::Index(fl::span).type(), IndexType::Span); + ASSERT_EQ(fl::Index(fl::full({2, 2}, 4)).type(), IndexType::Tensor); + ASSERT_TRUE(fl::Index(fl::span).isSpan()); } TEST(IndexTest, ArrayFireMaxIndex) { - auto t = fl::full({2, 3, 4, 5}, 6.); - if (t.backendType() != TensorBackendType::ArrayFire) { - GTEST_SKIP() << "Default Tensor type isn't ArrayFire"; - } - ASSERT_THROW(t(1, 2, 3, 4, 5), std::invalid_argument); + auto t = fl::full({2, 3, 4, 5}, 6.); + if(t.backendType() != TensorBackendType::ArrayFire) { + GTEST_SKIP() << "Default Tensor type isn't ArrayFire"; + } + ASSERT_THROW(t(1, 2, 3, 4, 5), std::invalid_argument); } TEST(IndexTest, Shape) { - auto t = fl::full({4, 4}, 3.); - ASSERT_EQ(t(2, 2).shape(), Shape({1})); - ASSERT_EQ(t(2, fl::span).shape(), Shape({4})); - ASSERT_EQ(t(2).shape(), Shape({4})); - ASSERT_EQ(t(fl::range(3)).shape(), Shape({3, 4})); - // TODO {0, 4} once empty ranges are supported across all backends - // ASSERT_EQ(t(fl::range(1, 1)).shape(), Shape({1, 4})); - ASSERT_EQ(t(fl::range(1, 2)).shape(), Shape({1, 4})); - // TODO ditto - // ASSERT_EQ(t(fl::span, fl::range(1, 1)).shape(), Shape({4, 1})); - ASSERT_EQ(t(fl::range(1, 2), fl::range(1, 2)).shape(), Shape({1, 1})); - ASSERT_EQ(t(fl::range(0, fl::end)).shape(), Shape({4, 4})); - ASSERT_EQ(t(fl::range(0, fl::end, 2)).shape(), Shape({2, 4})); - - auto t2 = fl::full({5, 6, 7, 8}, 3.); - ASSERT_EQ(t2(2, fl::range(2, 4), fl::span, 3).shape(), Shape({2, 7})); - ASSERT_EQ(t2(fl::span, 3, fl::span, fl::span).shape(), Shape({5, 7, 8})); - ASSERT_EQ( - t2(fl::span, fl::range(1, 2), fl::span, fl::span).shape(), - Shape({5, 1, 7, 8})); + auto t = fl::full({4, 4}, 3.); + ASSERT_EQ(t(2, 2).shape(), Shape({1})); + ASSERT_EQ(t(2, fl::span).shape(), Shape({4})); + ASSERT_EQ(t(2).shape(), Shape({4})); + ASSERT_EQ(t(fl::range(3)).shape(), Shape({3, 4})); + // TODO {0, 4} once empty ranges are supported across all backends + // ASSERT_EQ(t(fl::range(1, 1)).shape(), Shape({1, 
4})); + ASSERT_EQ(t(fl::range(1, 2)).shape(), Shape({1, 4})); + // TODO ditto + // ASSERT_EQ(t(fl::span, fl::range(1, 1)).shape(), Shape({4, 1})); + ASSERT_EQ(t(fl::range(1, 2), fl::range(1, 2)).shape(), Shape({1, 1})); + ASSERT_EQ(t(fl::range(0, fl::end)).shape(), Shape({4, 4})); + ASSERT_EQ(t(fl::range(0, fl::end, 2)).shape(), Shape({2, 4})); + + auto t2 = fl::full({5, 6, 7, 8}, 3.); + ASSERT_EQ(t2(2, fl::range(2, 4), fl::span, 3).shape(), Shape({2, 7})); + ASSERT_EQ(t2(fl::span, 3, fl::span, fl::span).shape(), Shape({5, 7, 8})); + ASSERT_EQ( + t2(fl::span, fl::range(1, 2), fl::span, fl::span).shape(), + Shape({5, 1, 7, 8}) + ); } TEST(IndexTest, IndexAssignment) { - auto t = fl::full({4, 4}, 0, fl::dtype::s32); - t(fl::span, 0) = 1; - t(fl::span, 1) += 1; - t(fl::span, fl::range(2, fl::end)) += 1; - t(fl::span, fl::span) *= 7; - t /= 7; - ASSERT_TRUE(allClose(t, fl::full({4, 4}, 1))); - - auto a = fl::full({6, 6}, 0.); - a(3, 4) = 4.; - ASSERT_TRUE(allClose(a(3, 4), fl::full({1}, 4.))); - a(2) = fl::full({6}, 8.); - ASSERT_TRUE(allClose(a(2), fl::full({6}, 8.))); - - auto b = fl::full({3, 3}, 1.); - auto c = b; - b += 1; - ASSERT_TRUE(allClose(b, fl::full({3, 3}, 2.))); - ASSERT_TRUE(allClose(c, fl::full({3, 3}, 1.))); - - auto q = fl::full({4, 4}, 2.); - auto r = fl::full({4}, 3.); - q(0) = r; - ASSERT_TRUE(allClose(q(0), r)); - ASSERT_TRUE(allClose(q(fl::range(1, fl::end)), fl::full({3, 4}, 2.))); - - auto k = fl::rand({100, 200}); - k(3) = fl::full({200}, 0.); - ASSERT_TRUE(allClose(k(3), fl::full({200}, 0.))); - - // Weak ref - auto g = fl::rand({3, 4, 5}); - auto gC = g.copy(); - auto gI = g(fl::span, fl::range(0, 3)); - g(fl::span, fl::range(0, 3)) += 3; - gI -= 3; - ASSERT_TRUE(allClose(gC(fl::span, fl::range(0, 3)), gI)); - - auto x = fl::rand({5, 6, 7, 8}); - x(3) = fl::full({6, 7, 8}, 0.); - ASSERT_TRUE(allClose(x(3), fl::full({6, 7, 8}, 0.))); - x(fl::span, fl::span, 2) = fl::full({5, 6, 8}, 3.); - ASSERT_TRUE(allClose(x(fl::span, fl::span, 2), 
fl::full({5, 6, 8}, 3.))); - ASSERT_THROW( - x(fl::span, fl::span, 4) -= fl::rand({5, 6, 1, 8}), - std::invalid_argument); - - x(fl::span, fl::range(1, 3), fl::span) = fl::full({5, 2, 7, 8}, 2.); - ASSERT_TRUE(allClose( - x(fl::span, fl::range(1, 3), fl::span), fl::full({5, 2, 7, 8}, 2.))); - - x(fl::span, fl::arange({5}), fl::span, fl::arange({5})) = - fl::full({5, 5, 7, 5}, 2.); - ASSERT_TRUE(allClose( - x(fl::span, fl::range(1, 3), fl::span), fl::full({5, 2, 7, 8}, 2.))); + auto t = fl::full({4, 4}, 0, fl::dtype::s32); + t(fl::span, 0) = 1; + t(fl::span, 1) += 1; + t(fl::span, fl::range(2, fl::end)) += 1; + t(fl::span, fl::span) *= 7; + t /= 7; + ASSERT_TRUE(allClose(t, fl::full({4, 4}, 1))); + + auto a = fl::full({6, 6}, 0.); + a(3, 4) = 4.; + ASSERT_TRUE(allClose(a(3, 4), fl::full({1}, 4.))); + a(2) = fl::full({6}, 8.); + ASSERT_TRUE(allClose(a(2), fl::full({6}, 8.))); + + auto b = fl::full({3, 3}, 1.); + auto c = b; + b += 1; + ASSERT_TRUE(allClose(b, fl::full({3, 3}, 2.))); + ASSERT_TRUE(allClose(c, fl::full({3, 3}, 1.))); + + auto q = fl::full({4, 4}, 2.); + auto r = fl::full({4}, 3.); + q(0) = r; + ASSERT_TRUE(allClose(q(0), r)); + ASSERT_TRUE(allClose(q(fl::range(1, fl::end)), fl::full({3, 4}, 2.))); + + auto k = fl::rand({100, 200}); + k(3) = fl::full({200}, 0.); + ASSERT_TRUE(allClose(k(3), fl::full({200}, 0.))); + + // Weak ref + auto g = fl::rand({3, 4, 5}); + auto gC = g.copy(); + auto gI = g(fl::span, fl::range(0, 3)); + g(fl::span, fl::range(0, 3)) += 3; + gI -= 3; + ASSERT_TRUE(allClose(gC(fl::span, fl::range(0, 3)), gI)); + + auto x = fl::rand({5, 6, 7, 8}); + x(3) = fl::full({6, 7, 8}, 0.); + ASSERT_TRUE(allClose(x(3), fl::full({6, 7, 8}, 0.))); + x(fl::span, fl::span, 2) = fl::full({5, 6, 8}, 3.); + ASSERT_TRUE(allClose(x(fl::span, fl::span, 2), fl::full({5, 6, 8}, 3.))); + ASSERT_THROW( + x(fl::span, fl::span, 4) -= fl::rand({5, 6, 1, 8}), + std::invalid_argument + ); + + x(fl::span, fl::range(1, 3), fl::span) = fl::full({5, 2, 7, 8}, 2.); + 
ASSERT_TRUE( + allClose( + x(fl::span, fl::range(1, 3), fl::span), + fl::full({5, 2, 7, 8}, 2.) + ) + ); + + x(fl::span, fl::arange({5}), fl::span, fl::arange({5})) = + fl::full({5, 5, 7, 5}, 2.); + ASSERT_TRUE( + allClose( + x(fl::span, fl::range(1, 3), fl::span), + fl::full({5, 2, 7, 8}, 2.) + ) + ); } TEST(IndexTest, IndexInPlaceOps) { - auto a = fl::full({4, 5, 6}, 0.); - auto b = fl::full({5, 6}, 1.); - a(2) += b; - ASSERT_TRUE(allClose(a(2), b)); - a(2) -= b; - ASSERT_TRUE(allClose(a, fl::full({4, 5, 6}, 0.))); - - auto f = fl::full({1, 3, 3}, 4.); - auto d = fl::full({3}, 6.); - f({0, 1}) += d; - ASSERT_TRUE(allClose(f({0, 1}), d + 4.)); - - // Integral type - auto s = fl::full({4, 5, 6}, 5, fl::dtype::s32); - auto sA = fl::full({6}, 3, fl::dtype::s32); - s(0, 1) += sA; - ASSERT_TRUE(allClose(s(0, 1), sA + 5)); + auto a = fl::full({4, 5, 6}, 0.); + auto b = fl::full({5, 6}, 1.); + a(2) += b; + ASSERT_TRUE(allClose(a(2), b)); + a(2) -= b; + ASSERT_TRUE(allClose(a, fl::full({4, 5, 6}, 0.))); + + auto f = fl::full({1, 3, 3}, 4.); + auto d = fl::full({3}, 6.); + f({0, 1}) += d; + ASSERT_TRUE(allClose(f({0, 1}), d + 4.)); + + // Integral type + auto s = fl::full({4, 5, 6}, 5, fl::dtype::s32); + auto sA = fl::full({6}, 3, fl::dtype::s32); + s(0, 1) += sA; + ASSERT_TRUE(allClose(s(0, 1), sA + 5)); } TEST(IndexTest, flat) { - auto m = fl::rand({4, 6}); - for (unsigned i = 0; i < m.elements(); ++i) { - ASSERT_TRUE(allClose(m.flat(i), m(i % 4, i / 4))); - } - - auto n = fl::rand({4, 6, 8}); - for (unsigned i = 0; i < n.elements(); ++i) { - ASSERT_TRUE(allClose(n.flat(i), n(i % 4, (i / 4) % 6, (i / (4 * 6)) % 8))); - } - - auto a = fl::full({5, 6, 7, 8}, 9.); - std::vector testIndices = {0, 1, 4, 11, 62, 104, 288}; - for (const int i : testIndices) { - ASSERT_EQ(a.flat(i).scalar(), 9.); - } - - a.flat(8) = 5.; - ASSERT_EQ(a.flat(8).scalar(), 5.); - - for (const int i : testIndices) { - a.flat(i) = i + 1; - } - for (const int i : testIndices) { - ASSERT_EQ( - a(i % 5, 
(i / 5) % 6, (i / (5 * 6)) % 7, (i / (5 * 6 * 7)) % 8) + auto m = fl::rand({4, 6}); + for(unsigned i = 0; i < m.elements(); ++i) { + ASSERT_TRUE(allClose(m.flat(i), m(i % 4, i / 4))); + } + + auto n = fl::rand({4, 6, 8}); + for(unsigned i = 0; i < n.elements(); ++i) { + ASSERT_TRUE(allClose(n.flat(i), n(i % 4, (i / 4) % 6, (i / (4 * 6)) % 8))); + } + + auto a = fl::full({5, 6, 7, 8}, 9.); + std::vector testIndices = {0, 1, 4, 11, 62, 104, 288}; + for(const int i : testIndices) { + ASSERT_EQ(a.flat(i).scalar(), 9.); + } + + a.flat(8) = 5.; + ASSERT_EQ(a.flat(8).scalar(), 5.); + + for(const int i : testIndices) { + a.flat(i) = i + 1; + } + for(const int i : testIndices) { + ASSERT_EQ( + a(i % 5, (i / 5) % 6, (i / (5 * 6)) % 7, (i / (5 * 6 * 7)) % 8) .scalar(), - i + 1); - } - - // Tensor assignment - a.flat(32) = fl::full({1}, 7.4); - ASSERT_TRUE(allClose(a.flatten()(32), fl::full({1}, 7.4))); - // In-place - a.flat(100) += 33; - ASSERT_TRUE(allClose(a.flatten()(100), fl::full({1}, 33 + 9.))); - - // Tensor indexing - auto indexer = Tensor::fromVector(testIndices); - auto ref = a.flat(indexer).copy(); - ASSERT_EQ(ref.shape(), Shape({(Dim)indexer.elements()})); - a.flat(indexer) -= 10; - ASSERT_TRUE(allClose(a.flat(indexer), ref - 10)); - for (const int i : testIndices) { - ASSERT_EQ( - a(i % 5, (i / 5) % 6, (i / (5 * 6)) % 7, (i / (5 * 6 * 7)) % 8) + i + 1 + ); + } + + // Tensor assignment + a.flat(32) = fl::full({1}, 7.4); + ASSERT_TRUE(allClose(a.flatten()(32), fl::full({1}, 7.4))); + // In-place + a.flat(100) += 33; + ASSERT_TRUE(allClose(a.flatten()(100), fl::full({1}, 33 + 9.))); + + // Tensor indexing + auto indexer = Tensor::fromVector(testIndices); + auto ref = a.flat(indexer).copy(); + ASSERT_EQ(ref.shape(), Shape({(Dim) indexer.elements()})); + a.flat(indexer) -= 10; + ASSERT_TRUE(allClose(a.flat(indexer), ref - 10)); + for(const int i : testIndices) { + ASSERT_EQ( + a(i % 5, (i / 5) % 6, (i / (5 * 6)) % 7, (i / (5 * 6 * 7)) % 8) .scalar(), - i + 1 - 10); - 
} - - // Range flat assignment - auto rA = fl::rand({6}); - a.flat(fl::range(1, 7)) = rA; - ASSERT_TRUE(allClose(rA, a.flatten()(fl::range(1, 7)))); - - // With leading singleton dims - auto b = fl::rand({1, 1, 10}); - ASSERT_EQ(b.flat(fl::range(3)).shape(), Shape({3})); - b.flat(fl::range(3)) = fl::full({3}, 6.); - ASSERT_TRUE(allClose(b.flatten()(fl::range(3)), fl::full({3}, 6.))); + i + 1 - 10 + ); + } + + // Range flat assignment + auto rA = fl::rand({6}); + a.flat(fl::range(1, 7)) = rA; + ASSERT_TRUE(allClose(rA, a.flatten()(fl::range(1, 7)))); + + // With leading singleton dims + auto b = fl::rand({1, 1, 10}); + ASSERT_EQ(b.flat(fl::range(3)).shape(), Shape({3})); + b.flat(fl::range(3)) = fl::full({3}, 6.); + ASSERT_TRUE(allClose(b.flatten()(fl::range(3)), fl::full({3}, 6.))); } TEST(IndexTest, TensorIndex) { - std::vector idxs = {0, 1, 4, 9, 11, 13, 16, 91}; - unsigned size = idxs.size(); - auto indices = fl::full({size}, 0); - for (int i = 0; i < size; ++i) { - indices(i) = idxs[i]; - } - auto a = fl::rand({100}); - auto indexed = a(indices); - for (int i = 0; i < size; ++i) { - ASSERT_TRUE(allClose(indexed(i), a(idxs[i]))); - } - - a(indices) = 5.; - ASSERT_TRUE(allClose(a(indices), fl::full({size}, 5.))); - - // Out of range indices - auto i = fl::arange({10}, 0, fl::dtype::u32); - auto b = fl::rand({20, 20}); - auto ref = b; - ASSERT_EQ(b(i).shape(), b(fl::range(10)).shape()); - ASSERT_TRUE(allClose(b(i), b(fl::range(10)))); - - b(i) += 3.; - ASSERT_TRUE(allClose(b(i), b(fl::range(10)))); - ASSERT_TRUE(allClose(b(i), (ref + 3)(i))); - b(i) += fl::full({(Dim)i.elements(), b.dim(1)}, 10.); - ASSERT_EQ(b(i).shape(), (ref + 13)(i).shape()); - ASSERT_TRUE(allClose(b(i), (ref + 13)(i))); - - // Tensor index a > 1D tensor - auto c = fl::rand({10, 10, 10}); - ASSERT_EQ(c(fl::arange({5})).shape(), Shape({5, 10, 10})); + std::vector idxs = {0, 1, 4, 9, 11, 13, 16, 91}; + unsigned size = idxs.size(); + auto indices = fl::full({size}, 0); + for(int i = 0; i < size; 
++i) { + indices(i) = idxs[i]; + } + auto a = fl::rand({100}); + auto indexed = a(indices); + for(int i = 0; i < size; ++i) { + ASSERT_TRUE(allClose(indexed(i), a(idxs[i]))); + } + + a(indices) = 5.; + ASSERT_TRUE(allClose(a(indices), fl::full({size}, 5.))); + + // Out of range indices + auto i = fl::arange({10}, 0, fl::dtype::u32); + auto b = fl::rand({20, 20}); + auto ref = b; + ASSERT_EQ(b(i).shape(), b(fl::range(10)).shape()); + ASSERT_TRUE(allClose(b(i), b(fl::range(10)))); + + b(i) += 3.; + ASSERT_TRUE(allClose(b(i), b(fl::range(10)))); + ASSERT_TRUE(allClose(b(i), (ref + 3)(i))); + b(i) += fl::full({(Dim) i.elements(), b.dim(1)}, 10.); + ASSERT_EQ(b(i).shape(), (ref + 13)(i).shape()); + ASSERT_TRUE(allClose(b(i), (ref + 13)(i))); + + // Tensor index a > 1D tensor + auto c = fl::rand({10, 10, 10}); + ASSERT_EQ(c(fl::arange({5})).shape(), Shape({5, 10, 10})); } TEST(IndexTest, ExpressionIndex) { - auto a = Tensor::fromVector({2, 5}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - ASSERT_TRUE(allClose(a(a < 5), Tensor::fromVector({0, 1, 2, 3, 4}))); - ASSERT_TRUE( - allClose(a(a < 7), Tensor::fromVector({0, 1, 2, 3, 4, 5, 6}))); + auto a = Tensor::fromVector({2, 5}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + ASSERT_TRUE(allClose(a(a < 5), Tensor::fromVector({0, 1, 2, 3, 4}))); + ASSERT_TRUE( + allClose(a(a < 7), Tensor::fromVector({0, 1, 2, 3, 4, 5, 6})) + ); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/tensor/ShapeTest.cpp b/flashlight/fl/test/tensor/ShapeTest.cpp index d69b5a1..2f3d1e1 100644 --- a/flashlight/fl/test/tensor/ShapeTest.cpp +++ b/flashlight/fl/test/tensor/ShapeTest.cpp @@ -16,79 +16,79 @@ using namespace ::testing; using namespace fl; TEST(ShapeTest, Basic) { - auto s = Shape({3, 4}); - ASSERT_EQ(s.ndim(), 2); - ASSERT_EQ(s.dim(0), 3); - ASSERT_EQ(s.dim(1), 4); - 
EXPECT_THROW(s.dim(5), std::invalid_argument); + auto s = Shape({3, 4}); + ASSERT_EQ(s.ndim(), 2); + ASSERT_EQ(s.dim(0), 3); + ASSERT_EQ(s.dim(1), 4); + EXPECT_THROW(s.dim(5), std::invalid_argument); } TEST(ShapeTest, ManyDims) { - if (Shape::kMaxDims <= 4) { - GTEST_SKIP() << "Max shape dimensions is <= 4"; - } - auto many = Shape({1, 2, 3, 4, 5, 6, 7}); - ASSERT_EQ(many.ndim(), 7); - ASSERT_EQ(many.dim(5), 6); + if(Shape::kMaxDims <= 4) { + GTEST_SKIP() << "Max shape dimensions is <= 4"; + } + auto many = Shape({1, 2, 3, 4, 5, 6, 7}); + ASSERT_EQ(many.ndim(), 7); + ASSERT_EQ(many.dim(5), 6); } TEST(ShapeTest, ndim) { - ASSERT_EQ(Shape().ndim(), 0); - ASSERT_EQ(Shape({1, 0, 1}).ndim(), 3); - ASSERT_EQ(Shape({1, 1, 1}).ndim(), 3); - ASSERT_EQ(Shape({5, 2, 3}).ndim(), 3); - ASSERT_EQ(Shape({1, 2, 3, 6}).ndim(), 4); - if (Shape::kMaxDims > 4) { - ASSERT_EQ(Shape({1, 2, 3, 1, 1, 1}).ndim(), 6); - ASSERT_EQ(Shape({1, 2, 3, 1, 1, 1, 5}).ndim(), 7); - ASSERT_EQ(Shape({4, 2, 3, 1, 1, 1, 5}).ndim(), 7); - } + ASSERT_EQ(Shape().ndim(), 0); + ASSERT_EQ(Shape({1, 0, 1}).ndim(), 3); + ASSERT_EQ(Shape({1, 1, 1}).ndim(), 3); + ASSERT_EQ(Shape({5, 2, 3}).ndim(), 3); + ASSERT_EQ(Shape({1, 2, 3, 6}).ndim(), 4); + if(Shape::kMaxDims > 4) { + ASSERT_EQ(Shape({1, 2, 3, 1, 1, 1}).ndim(), 6); + ASSERT_EQ(Shape({1, 2, 3, 1, 1, 1, 5}).ndim(), 7); + ASSERT_EQ(Shape({4, 2, 3, 1, 1, 1, 5}).ndim(), 7); + } } TEST(ShapeTest, elements) { - ASSERT_EQ(Shape().elements(), 1); // empty shape = scalar - ASSERT_EQ(Shape({0}).elements(), 0); // empty tensor - ASSERT_EQ(Shape({1, 1, 1, 1}).elements(), 1); - ASSERT_EQ(Shape({1, 2, 3, 4}).elements(), 24); - ASSERT_EQ(Shape({1, 2, 3, 0}).elements(), 0); + ASSERT_EQ(Shape().elements(), 1); // empty shape = scalar + ASSERT_EQ(Shape({0}).elements(), 0); // empty tensor + ASSERT_EQ(Shape({1, 1, 1, 1}).elements(), 1); + ASSERT_EQ(Shape({1, 2, 3, 4}).elements(), 24); + ASSERT_EQ(Shape({1, 2, 3, 0}).elements(), 0); } TEST(ShapeTest, Equality) { - auto a = 
Shape({1, 2, 3, 4}); - ASSERT_EQ(a, Shape({1, 2, 3, 4})); - ASSERT_NE(a, Shape({4, 3, 4})); - ASSERT_NE(Shape({1, 2}), Shape({1, 1, 1, 2})); - ASSERT_NE(Shape({5, 2, 3}), Shape({5, 2, 3, 1})); - ASSERT_EQ(Shape({5, 2, 3, 1}), Shape({5, 2, 3, 1})); - ASSERT_NE(Shape({5, 2, 1, 1}), Shape({5, 2, 1, 4})); + auto a = Shape({1, 2, 3, 4}); + ASSERT_EQ(a, Shape({1, 2, 3, 4})); + ASSERT_NE(a, Shape({4, 3, 4})); + ASSERT_NE(Shape({1, 2}), Shape({1, 1, 1, 2})); + ASSERT_NE(Shape({5, 2, 3}), Shape({5, 2, 3, 1})); + ASSERT_EQ(Shape({5, 2, 3, 1}), Shape({5, 2, 3, 1})); + ASSERT_NE(Shape({5, 2, 1, 1}), Shape({5, 2, 1, 4})); } TEST(ShapeTest, Indexing) { - auto a = Shape({3, 4, 5, 2}); - ASSERT_EQ(a[0], 3); - ASSERT_EQ(a[1], 4); - ASSERT_EQ(a[2], 5); - ASSERT_EQ(a[3], 2); - ASSERT_THROW(a[4], std::invalid_argument); + auto a = Shape({3, 4, 5, 2}); + ASSERT_EQ(a[0], 3); + ASSERT_EQ(a[1], 4); + ASSERT_EQ(a[2], 5); + ASSERT_EQ(a[3], 2); + ASSERT_THROW(a[4], std::invalid_argument); } TEST(ShapeTest, string) { - auto checkShapeStrEqual = [](const Shape& s, const std::string& str) -> void { - auto sStr = s.toString(); - ASSERT_EQ(sStr, str); - std::stringstream ss; - ss << sStr; - ASSERT_EQ(sStr, ss.str()); - }; + auto checkShapeStrEqual = [](const Shape& s, const std::string& str) -> void { + auto sStr = s.toString(); + ASSERT_EQ(sStr, str); + std::stringstream ss; + ss << sStr; + ASSERT_EQ(sStr, ss.str()); + }; - checkShapeStrEqual(Shape({3, 4, 7, 9}), "(3, 4, 7, 9)"); - checkShapeStrEqual(Shape({}), "()"); - checkShapeStrEqual(Shape({0}), "(0)"); - checkShapeStrEqual(Shape({7, 7, 7, 7, 7, 7, 7}), "(7, 7, 7, 7, 7, 7, 7)"); + checkShapeStrEqual(Shape({3, 4, 7, 9}), "(3, 4, 7, 9)"); + checkShapeStrEqual(Shape({}), "()"); + checkShapeStrEqual(Shape({0}), "(0)"); + checkShapeStrEqual(Shape({7, 7, 7, 7, 7, 7, 7}), "(7, 7, 7, 7, 7, 7, 7)"); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, 
argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/tensor/TensorAdapterTest.cpp b/flashlight/fl/test/tensor/TensorAdapterTest.cpp index f8e63d3..fb1e801 100644 --- a/flashlight/fl/test/tensor/TensorAdapterTest.cpp +++ b/flashlight/fl/test/tensor/TensorAdapterTest.cpp @@ -16,29 +16,29 @@ using namespace ::testing; using namespace fl; TEST(TensorBaseTest, DefaultBackend) { - Tensor t; - ASSERT_EQ(t.backendType(), DefaultTensorType_t::tensorBackendType); + Tensor t; + ASSERT_EQ(t.backendType(), DefaultTensorType_t::tensorBackendType); } TEST(TensorBaseTest, ImplTypeConversion) { - // Converting to the same type is a noop - auto a = fl::rand({6, 8}); - auto c = a.copy(); - TensorBackendType aBackend = a.backendType(); - auto b = to(std::move(a)); - ASSERT_EQ(aBackend, b.backendType()); - ASSERT_TRUE(allClose(b, c)); + // Converting to the same type is a noop + auto a = fl::rand({6, 8}); + auto c = a.copy(); + TensorBackendType aBackend = a.backendType(); + auto b = to(std::move(a)); + ASSERT_EQ(aBackend, b.backendType()); + ASSERT_TRUE(allClose(b, c)); } TEST(TensorBaseTest, hasAdapter) { - Tensor a = fromScalar(3.14, fl::dtype::f32); - ASSERT_TRUE(a.hasAdapter()); - detail::releaseAdapterUnsafe(a); - ASSERT_FALSE(a.hasAdapter()); + Tensor a = fromScalar(3.14, fl::dtype::f32); + ASSERT_TRUE(a.hasAdapter()); + detail::releaseAdapterUnsafe(a); + ASSERT_FALSE(a.hasAdapter()); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/tensor/TensorBLASTest.cpp b/flashlight/fl/test/tensor/TensorBLASTest.cpp index 0447bd4..c39d6e6 100644 --- a/flashlight/fl/test/tensor/TensorBLASTest.cpp +++ b/flashlight/fl/test/tensor/TensorBLASTest.cpp @@ -15,167 +15,217 @@ using namespace ::testing; using namespace fl; TEST(TensorBLASTest, matmul) { - // TODO: test tensors with order > 2 
+ // TODO: test tensors with order > 2 - // Reference impl - auto matmulRef = [](const Tensor& lhs, const Tensor& rhs) { - // (M x N) x (N x K) --> (M x K) - int M = lhs.dim(0); - int N = lhs.dim(1); - int K = rhs.dim(1); + // Reference impl + auto matmulRef = [](const Tensor& lhs, const Tensor& rhs) { + // (M x N) x (N x K) --> (M x K) + int M = lhs.dim(0); + int N = lhs.dim(1); + int K = rhs.dim(1); - auto out = fl::full({M, K}, 0.); + auto out = fl::full({M, K}, 0.); - for (unsigned i = 0; i < M; ++i) { - for (unsigned j = 0; j < K; ++j) { - for (unsigned k = 0; k < N; ++k) { - out(i, j) += lhs(i, k) * rhs(k, j); - } - } - } - return out; - }; + for(unsigned i = 0; i < M; ++i) { + for(unsigned j = 0; j < K; ++j) { + for(unsigned k = 0; k < N; ++k) { + out(i, j) += lhs(i, k) * rhs(k, j); + } + } + } + return out; + }; - int i = 10; - int j = 20; - int k = 12; + int i = 10; + int j = 20; + int k = 12; - auto a = fl::rand({i, j}); - auto b = fl::rand({j, k}); - auto ref = matmulRef(a, b); - ASSERT_TRUE(allClose(fl::matmul(a, b), ref)); - ASSERT_TRUE(allClose( - fl::matmul( - a, - fl::transpose(b), - fl::MatrixProperty::None, - fl::MatrixProperty::Transpose), - ref)); - ASSERT_TRUE(allClose( - fl::matmul(fl::transpose(a), b, fl::MatrixProperty::Transpose), ref)); + auto a = fl::rand({i, j}); + auto b = fl::rand({j, k}); + auto ref = matmulRef(a, b); + ASSERT_TRUE(allClose(fl::matmul(a, b), ref)); + ASSERT_TRUE( + allClose( + fl::matmul( + a, + fl::transpose(b), + fl::MatrixProperty::None, + fl::MatrixProperty::Transpose + ), + ref + ) + ); + ASSERT_TRUE( + allClose( + fl::matmul(fl::transpose(a), b, fl::MatrixProperty::Transpose), + ref + ) + ); } TEST(TensorBLASTest, matmulShapes) { - using T = fl::MatrixProperty; - // Matrix/vector/scalar multiplies - ASSERT_EQ(fl::matmul(fl::rand({10}), fl::rand({10})).shape(), Shape({1})); - ASSERT_EQ( - fl::matmul(fl::rand({10}), fl::rand({10}), T::Transpose).shape(), - Shape({1})); - ASSERT_EQ( - fl::matmul(fl::rand({10}), 
fl::rand({10}), T::Transpose, T::Transpose) - .shape(), - Shape({1})); - ASSERT_EQ( - fl::matmul(fl::rand({10}), fl::rand({10}), T::None, T::Transpose).shape(), - Shape({1})); - ASSERT_EQ(fl::matmul(fl::rand({1, 10}), fl::rand({10})).shape(), Shape({1})); - ASSERT_EQ(fl::matmul(fl::rand({1}), fl::rand({1, 10})).shape(), Shape({10})); - ASSERT_EQ( - fl::matmul(fl::rand({10}), fl::rand({10}), T::Transpose).shape(), - Shape({1})); - ASSERT_EQ(fl::matmul(fl::rand({3, 4}), fl::rand({4})).shape(), Shape({3})); - ASSERT_EQ(fl::matmul(fl::rand({5}), fl::rand({5, 7})).shape(), Shape({7})); - ASSERT_THROW(fl::matmul(fl::rand({1}), fl::rand({10})), std::exception); - ASSERT_THROW(fl::matmul(fl::rand({3}), fl::rand({5, 7})), std::exception); + using T = fl::MatrixProperty; + // Matrix/vector/scalar multiplies + ASSERT_EQ(fl::matmul(fl::rand({10}), fl::rand({10})).shape(), Shape({1})); + ASSERT_EQ( + fl::matmul(fl::rand({10}), fl::rand({10}), T::Transpose).shape(), + Shape({1}) + ); + ASSERT_EQ( + fl::matmul(fl::rand({10}), fl::rand({10}), T::Transpose, T::Transpose) + .shape(), + Shape({1}) + ); + ASSERT_EQ( + fl::matmul(fl::rand({10}), fl::rand({10}), T::None, T::Transpose).shape(), + Shape({1}) + ); + ASSERT_EQ(fl::matmul(fl::rand({1, 10}), fl::rand({10})).shape(), Shape({1})); + ASSERT_EQ(fl::matmul(fl::rand({1}), fl::rand({1, 10})).shape(), Shape({10})); + ASSERT_EQ( + fl::matmul(fl::rand({10}), fl::rand({10}), T::Transpose).shape(), + Shape({1}) + ); + ASSERT_EQ(fl::matmul(fl::rand({3, 4}), fl::rand({4})).shape(), Shape({3})); + ASSERT_EQ(fl::matmul(fl::rand({5}), fl::rand({5, 7})).shape(), Shape({7})); + ASSERT_THROW(fl::matmul(fl::rand({1}), fl::rand({10})), std::exception); + ASSERT_THROW(fl::matmul(fl::rand({3}), fl::rand({5, 7})), std::exception); - // Batch matrix multiply - unsigned M = 10; - unsigned K = 12; - unsigned N = 14; - unsigned b2 = 2; - unsigned b3 = 4; - ASSERT_EQ( - fl::matmul(fl::rand({M, K}), fl::rand({K, N})).shape(), Shape({M, N})); - ASSERT_EQ( - 
fl::matmul(fl::rand({M, K, b2}), fl::rand({K, N, b2})).shape(), - Shape({M, N, b2})); - ASSERT_EQ( - fl::matmul(fl::rand({M, K, b2, b3}), fl::rand({K, N, b2, b3})).shape(), - Shape({M, N, b2, b3})); - ASSERT_EQ( - fl::matmul(fl::rand({M, K, b2, b3}), fl::rand({K, N})).shape(), - Shape({M, N, b2, b3})); - ASSERT_EQ( - fl::matmul(fl::rand({M, K}), fl::rand({K, N, b2, b3})).shape(), - Shape({M, N, b2, b3})); - // Batch matrix multiply with transpose - ASSERT_EQ( - fl::matmul(fl::rand({K, M}), fl::rand({K, N}), T::Transpose).shape(), - Shape({M, N})); - ASSERT_EQ( - fl::matmul(fl::rand({M, K}), fl::rand({N, K}), T::None, T::Transpose) - .shape(), - Shape({M, N})); - // b2 transpose - ASSERT_EQ( - fl::matmul(fl::rand({K, M, b2}), fl::rand({K, N}), T::Transpose).shape(), - Shape({M, N, b2})); - ASSERT_EQ( - fl::matmul(fl::rand({M, K, b2}), fl::rand({N, K}), T::None, T::Transpose) - .shape(), - Shape({M, N, b2})); - ASSERT_EQ( - fl::matmul(fl::rand({K, M}), fl::rand({K, N, b2}), T::Transpose).shape(), - Shape({M, N, b2})); - ASSERT_EQ( - fl::matmul(fl::rand({M, K}), fl::rand({N, K, b2}), T::None, T::Transpose) - .shape(), - Shape({M, N, b2})); - ASSERT_EQ( - fl::matmul(fl::rand({K, M, b2}), fl::rand({K, N, b2}), T::Transpose) - .shape(), - Shape({M, N, b2})); - ASSERT_EQ( - fl::matmul( - fl::rand({M, K, b2}), fl::rand({N, K, b2}), T::None, T::Transpose) - .shape(), - Shape({M, N, b2})); - // b2, b3 transpose - ASSERT_EQ( - fl::matmul(fl::rand({K, M, b2, b3}), fl::rand({K, N}), T::Transpose) - .shape(), - Shape({M, N, b2, b3})); - ASSERT_EQ( - fl::matmul( - fl::rand({M, K, b2, b3}), fl::rand({N, K}), T::None, T::Transpose) - .shape(), - Shape({M, N, b2, b3})); - ASSERT_EQ( - fl::matmul(fl::rand({K, M}), fl::rand({K, N, b2, b3}), T::Transpose) - .shape(), - Shape({M, N, b2, b3})); - ASSERT_EQ( - fl::matmul( - fl::rand({M, K}), fl::rand({N, K, b2, b3}), T::None, T::Transpose) - .shape(), - Shape({M, N, b2, b3})); - ASSERT_EQ( - fl::matmul( - fl::rand({K, M, b2, b3}), 
fl::rand({K, N, b2, b3}), T::Transpose) - .shape(), - Shape({M, N, b2, b3})); - ASSERT_EQ( - fl::matmul( - fl::rand({M, K, b2, b3}), - fl::rand({N, K, b2, b3}), - T::None, - T::Transpose) - .shape(), - Shape({M, N, b2, b3})); + // Batch matrix multiply + unsigned M = 10; + unsigned K = 12; + unsigned N = 14; + unsigned b2 = 2; + unsigned b3 = 4; + ASSERT_EQ( + fl::matmul(fl::rand({M, K}), fl::rand({K, N})).shape(), + Shape({M, N}) + ); + ASSERT_EQ( + fl::matmul(fl::rand({M, K, b2}), fl::rand({K, N, b2})).shape(), + Shape({M, N, b2}) + ); + ASSERT_EQ( + fl::matmul(fl::rand({M, K, b2, b3}), fl::rand({K, N, b2, b3})).shape(), + Shape({M, N, b2, b3}) + ); + ASSERT_EQ( + fl::matmul(fl::rand({M, K, b2, b3}), fl::rand({K, N})).shape(), + Shape({M, N, b2, b3}) + ); + ASSERT_EQ( + fl::matmul(fl::rand({M, K}), fl::rand({K, N, b2, b3})).shape(), + Shape({M, N, b2, b3}) + ); + // Batch matrix multiply with transpose + ASSERT_EQ( + fl::matmul(fl::rand({K, M}), fl::rand({K, N}), T::Transpose).shape(), + Shape({M, N}) + ); + ASSERT_EQ( + fl::matmul(fl::rand({M, K}), fl::rand({N, K}), T::None, T::Transpose) + .shape(), + Shape({M, N}) + ); + // b2 transpose + ASSERT_EQ( + fl::matmul(fl::rand({K, M, b2}), fl::rand({K, N}), T::Transpose).shape(), + Shape({M, N, b2}) + ); + ASSERT_EQ( + fl::matmul(fl::rand({M, K, b2}), fl::rand({N, K}), T::None, T::Transpose) + .shape(), + Shape({M, N, b2}) + ); + ASSERT_EQ( + fl::matmul(fl::rand({K, M}), fl::rand({K, N, b2}), T::Transpose).shape(), + Shape({M, N, b2}) + ); + ASSERT_EQ( + fl::matmul(fl::rand({M, K}), fl::rand({N, K, b2}), T::None, T::Transpose) + .shape(), + Shape({M, N, b2}) + ); + ASSERT_EQ( + fl::matmul(fl::rand({K, M, b2}), fl::rand({K, N, b2}), T::Transpose) + .shape(), + Shape({M, N, b2}) + ); + ASSERT_EQ( + fl::matmul( + fl::rand({M, K, b2}), + fl::rand({N, K, b2}), + T::None, + T::Transpose + ) + .shape(), + Shape({M, N, b2}) + ); + // b2, b3 transpose + ASSERT_EQ( + fl::matmul(fl::rand({K, M, b2, b3}), fl::rand({K, N}), 
T::Transpose) + .shape(), + Shape({M, N, b2, b3}) + ); + ASSERT_EQ( + fl::matmul( + fl::rand({M, K, b2, b3}), + fl::rand({N, K}), + T::None, + T::Transpose + ) + .shape(), + Shape({M, N, b2, b3}) + ); + ASSERT_EQ( + fl::matmul(fl::rand({K, M}), fl::rand({K, N, b2, b3}), T::Transpose) + .shape(), + Shape({M, N, b2, b3}) + ); + ASSERT_EQ( + fl::matmul( + fl::rand({M, K}), + fl::rand({N, K, b2, b3}), + T::None, + T::Transpose + ) + .shape(), + Shape({M, N, b2, b3}) + ); + ASSERT_EQ( + fl::matmul( + fl::rand({K, M, b2, b3}), + fl::rand({K, N, b2, b3}), + T::Transpose + ) + .shape(), + Shape({M, N, b2, b3}) + ); + ASSERT_EQ( + fl::matmul( + fl::rand({M, K, b2, b3}), + fl::rand({N, K, b2, b3}), + T::None, + T::Transpose + ) + .shape(), + Shape({M, N, b2, b3}) + ); - ASSERT_EQ( - fl::matmul( - fl::rand({256, 200, 2}), - fl::rand({256, 200, 2}), - T::None, - T::Transpose) - .shape(), - Shape({256, 256, 2})); + ASSERT_EQ( + fl::matmul( + fl::rand({256, 200, 2}), + fl::rand({256, 200, 2}), + T::None, + T::Transpose + ) + .shape(), + Shape({256, 256, 2}) + ); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/tensor/TensorBaseTest.cpp b/flashlight/fl/test/tensor/TensorBaseTest.cpp index 283c55e..1f14124 100644 --- a/flashlight/fl/test/tensor/TensorBaseTest.cpp +++ b/flashlight/fl/test/tensor/TensorBaseTest.cpp @@ -20,600 +20,660 @@ using namespace ::testing; using namespace fl; TEST(TensorBaseTest, DefaultConstruction) { - Tensor t; - ASSERT_EQ(t.shape(), Shape({0})); - ASSERT_EQ(t.type(), fl::dtype::f32); + Tensor t; + ASSERT_EQ(t.shape(), Shape({0})); + ASSERT_EQ(t.type(), fl::dtype::f32); - Tensor u({1, 2, 3}); - ASSERT_EQ(u.shape(), Shape({1, 2, 3})); - ASSERT_EQ(u.type(), fl::dtype::f32); - Tensor x({0, 3}); - ASSERT_EQ(x.shape(), Shape({0, 3})); + Tensor u({1, 2, 3}); + ASSERT_EQ(u.shape(), 
Shape({1, 2, 3})); + ASSERT_EQ(u.type(), fl::dtype::f32); + Tensor x({0, 3}); + ASSERT_EQ(x.shape(), Shape({0, 3})); - Tensor q(fl::dtype::f64); - ASSERT_EQ(q.shape(), Shape({0})); - ASSERT_EQ(q.type(), fl::dtype::f64); + Tensor q(fl::dtype::f64); + ASSERT_EQ(q.shape(), Shape({0})); + ASSERT_EQ(q.type(), fl::dtype::f64); - Tensor v({4, 5, 6}, fl::dtype::u64); - ASSERT_EQ(v.shape(), Shape({4, 5, 6})); - ASSERT_EQ(v.type(), fl::dtype::u64); + Tensor v({4, 5, 6}, fl::dtype::u64); + ASSERT_EQ(v.shape(), Shape({4, 5, 6})); + ASSERT_EQ(v.type(), fl::dtype::u64); } TEST(TensorBaseTest, CopyConstruction) { - Shape shape{2, 2}; - auto x = fl::full(shape, 0); - auto y = x; // actual copy (implementation may be CoW) + Shape shape{2, 2}; + auto x = fl::full(shape, 0); + auto y = x; // actual copy (implementation may be CoW) - ASSERT_TRUE(allClose(x, fl::full(shape, 0))); - ASSERT_TRUE(allClose(y, fl::full(shape, 0))); - x += 23; // affects both tensors - ASSERT_TRUE(allClose(x, fl::full(shape, 23))); - ASSERT_TRUE(allClose(y, fl::full(shape, 0))); + ASSERT_TRUE(allClose(x, fl::full(shape, 0))); + ASSERT_TRUE(allClose(y, fl::full(shape, 0))); + x += 23; // affects both tensors + ASSERT_TRUE(allClose(x, fl::full(shape, 23))); + ASSERT_TRUE(allClose(y, fl::full(shape, 0))); } TEST(TensorBaseTest, MoveConstruction) { - Shape shape{2, 2}; - auto x = fl::full(shape, 0); - auto y = x(span, span); // view of x + Shape shape{2, 2}; + auto x = fl::full(shape, 0); + auto y = x(span, span); // view of x - auto z = std::move(x); // `z` takes over `x`'s data - // TODO the following line (or any read to `y`, as it seems) promotes view to - // copy; to avoid this, we must update impl of `assign` - // ASSERT_TRUE(allClose(y, fl::full(shape, 0))); - ASSERT_TRUE(allClose(z, fl::full(shape, 0))); + auto z = std::move(x); // `z` takes over `x`'s data + // TODO the following line (or any read to `y`, as it seems) promotes view to + // copy; to avoid this, we must update impl of `assign` + // 
ASSERT_TRUE(allClose(y, fl::full(shape, 0))); + ASSERT_TRUE(allClose(z, fl::full(shape, 0))); - z += 42; // `y` is now a view of `z`, so it's affected - ASSERT_TRUE(allClose(y, fl::full(shape, 42))); - ASSERT_TRUE(allClose(z, fl::full(shape, 42))); + z += 42; // `y` is now a view of `z`, so it's affected + ASSERT_TRUE(allClose(y, fl::full(shape, 42))); + ASSERT_TRUE(allClose(z, fl::full(shape, 42))); } TEST(TensorBaseTest, AssignmentOperatorLvalueWithRvalue) { - Shape shape{2, 2}; - auto x = fl::full({2, 2}, 0); - auto y = x(span, span); + Shape shape{2, 2}; + auto x = fl::full({2, 2}, 0); + auto y = x(span, span); - // view as a lvalue cannot be used to update original tensor - y = fl::full({2, 2}, 42); // `x` isn't affected - y += 1; // `x` isn't affected - ASSERT_TRUE(allClose(x, fl::full(shape, 0))); - ASSERT_TRUE(allClose(y, fl::full(shape, 43))); + // view as a lvalue cannot be used to update original tensor + y = fl::full({2, 2}, 42); // `x` isn't affected + y += 1; // `x` isn't affected + ASSERT_TRUE(allClose(x, fl::full(shape, 0))); + ASSERT_TRUE(allClose(y, fl::full(shape, 43))); } TEST(TensorBaseTest, AssignmentOperatorLvalueWithLvalue) { - Shape shape{2, 2}; - auto x = fl::full({2, 2}, 0); - auto y = x(span, span); - auto z = fl::full({2, 2}, 1); + Shape shape{2, 2}; + auto x = fl::full({2, 2}, 0); + auto y = x(span, span); + auto z = fl::full({2, 2}, 1); - y = z; // `x` is a copy of `z` now (impl may be CoW) - y += 1; // `z` isn't affected - ASSERT_TRUE(allClose(x, fl::full(shape, 0))); - ASSERT_TRUE(allClose(y, fl::full(shape, 2))); - ASSERT_TRUE(allClose(z, fl::full(shape, 1))); + y = z; // `x` is a copy of `z` now (impl may be CoW) + y += 1; // `z` isn't affected + ASSERT_TRUE(allClose(x, fl::full(shape, 0))); + ASSERT_TRUE(allClose(y, fl::full(shape, 2))); + ASSERT_TRUE(allClose(z, fl::full(shape, 1))); } TEST(TensorBaseTest, AssignmentOperatorRvalueWithRvalue) { - Shape shape{2, 2}; - auto type = dtype::f32; - auto x = fl::full({2, 2}, 0, type); - 
auto y = x(span, span); + Shape shape{2, 2}; + auto type = dtype::f32; + auto x = fl::full({2, 2}, 0, type); + auto y = x(span, span); - x(0, span) = fl::full({2}, 1); // `x` is updated by copying from rhs data - auto res = fl::Tensor::fromVector(shape, {1, 0, 1, 0}, type); - ASSERT_TRUE(allClose(x, res)); - ASSERT_TRUE(allClose(y, res)); + x(0, span) = fl::full({2}, 1); // `x` is updated by copying from rhs data + auto res = fl::Tensor::fromVector(shape, {1, 0, 1, 0}, type); + ASSERT_TRUE(allClose(x, res)); + ASSERT_TRUE(allClose(y, res)); } TEST(TensorBaseTest, AssignmentOperatorRvalueWithLvalue) { - Shape shape{2, 2}; - auto type = dtype::f32; - auto x = fl::full(shape, 0, type); - auto y = x(span, span); // view of `x` - auto z = fl::full({2}, 1, type); + Shape shape{2, 2}; + auto type = dtype::f32; + auto x = fl::full(shape, 0, type); + auto y = x(span, span); // view of `x` + auto z = fl::full({2}, 1, type); - x(span, 1) = z; // `x` is updated by copying from `z`'s data - x += 1; // `z` isn't affected - auto res = fl::Tensor::fromVector(shape, {1, 1, 2, 2}, type); - ASSERT_TRUE(allClose(x, res)); - ASSERT_TRUE(allClose(y, res)); - ASSERT_TRUE(allClose(z, fl::full({2}, 1, type))); + x(span, 1) = z; // `x` is updated by copying from `z`'s data + x += 1; // `z` isn't affected + auto res = fl::Tensor::fromVector(shape, {1, 1, 2, 2}, type); + ASSERT_TRUE(allClose(x, res)); + ASSERT_TRUE(allClose(y, res)); + ASSERT_TRUE(allClose(z, fl::full({2}, 1, type))); } TEST(TensorBaseTest, Metadata) { - int s = 9; - auto t = fl::rand({s, s}); - ASSERT_EQ(t.elements(), s * s); - ASSERT_FALSE(t.isEmpty()); - ASSERT_EQ(t.bytes(), s * s * sizeof(float)); + int s = 9; + auto t = fl::rand({s, s}); + ASSERT_EQ(t.elements(), s * s); + ASSERT_FALSE(t.isEmpty()); + ASSERT_EQ(t.bytes(), s * s * sizeof(float)); - Tensor e; - ASSERT_EQ(e.elements(), 0); - ASSERT_TRUE(e.isEmpty()); - ASSERT_FALSE(e.isSparse()); - ASSERT_FALSE(e.isLocked()); + Tensor e; + ASSERT_EQ(e.elements(), 0); + 
ASSERT_TRUE(e.isEmpty()); + ASSERT_FALSE(e.isSparse()); + ASSERT_FALSE(e.isLocked()); } TEST(TensorBaseTest, fromScalar) { - Tensor a = fromScalar(3.14, fl::dtype::f32); - ASSERT_EQ(a.elements(), 1); - ASSERT_EQ(a.ndim(), 0); - ASSERT_FALSE(a.isEmpty()); - ASSERT_EQ(a.shape(), Shape({})); + Tensor a = fromScalar(3.14, fl::dtype::f32); + ASSERT_EQ(a.elements(), 1); + ASSERT_EQ(a.ndim(), 0); + ASSERT_FALSE(a.isEmpty()); + ASSERT_EQ(a.shape(), Shape({})); } TEST(TensorBaseTest, string) { - // Different backends might print tensors differently - check for consistency - // across two identical tensors - auto a = fl::full({3, 4, 5}, 6.); - auto b = fl::full({3, 4, 5}, 6.); - ASSERT_EQ(a.toString(), b.toString()); + // Different backends might print tensors differently - check for consistency + // across two identical tensors + auto a = fl::full({3, 4, 5}, 6.); + auto b = fl::full({3, 4, 5}, 6.); + ASSERT_EQ(a.toString(), b.toString()); - std::stringstream ssa, ssb; - ssa << a; - ssb << b; - ASSERT_EQ(ssa.str(), ssb.str()); + std::stringstream ssa, ssb; + ssa << a; + ssb << b; + ASSERT_EQ(ssa.str(), ssb.str()); } TEST(TensorBaseTest, AssignmentOperators) { - auto a = fl::full({3, 3}, 1.); - a += 2; - ASSERT_TRUE(allClose(a, fl::full({3, 3}, 3.))); - a -= 1; - ASSERT_TRUE(allClose(a, fl::full({3, 3}, 2.))); - a *= 8; - ASSERT_TRUE(allClose(a, fl::full({3, 3}, 16.))); - a /= 4; - ASSERT_TRUE(allClose(a, fl::full({3, 3}, 4.))); - - a = fl::full({4, 4}, 7.); - ASSERT_TRUE(allClose(a, fl::full({4, 4}, 7.))); - auto b = a; - ASSERT_TRUE(allClose(b, fl::full({4, 4}, 7.))); - a = 6.; - ASSERT_TRUE(allClose(a, fl::full({4, 4}, 6.))); - - a = fl::full({5, 6, 7}, 8.); - ASSERT_TRUE(allClose(a, fl::full({5, 6, 7}, 8.))); + auto a = fl::full({3, 3}, 1.); + a += 2; + ASSERT_TRUE(allClose(a, fl::full({3, 3}, 3.))); + a -= 1; + ASSERT_TRUE(allClose(a, fl::full({3, 3}, 2.))); + a *= 8; + ASSERT_TRUE(allClose(a, fl::full({3, 3}, 16.))); + a /= 4; + ASSERT_TRUE(allClose(a, fl::full({3, 3}, 
4.))); + + a = fl::full({4, 4}, 7.); + ASSERT_TRUE(allClose(a, fl::full({4, 4}, 7.))); + auto b = a; + ASSERT_TRUE(allClose(b, fl::full({4, 4}, 7.))); + a = 6.; + ASSERT_TRUE(allClose(a, fl::full({4, 4}, 6.))); + + a = fl::full({5, 6, 7}, 8.); + ASSERT_TRUE(allClose(a, fl::full({5, 6, 7}, 8.))); } TEST(TensorBaseTest, CopyOperators) { - auto a = fl::full({3, 3}, 1.); - auto b = a; - a += 1; - ASSERT_TRUE(allClose(b, fl::full({3, 3}, 1.))); - ASSERT_TRUE(allClose(a, fl::full({3, 3}, 2.))); + auto a = fl::full({3, 3}, 1.); + auto b = a; + a += 1; + ASSERT_TRUE(allClose(b, fl::full({3, 3}, 1.))); + ASSERT_TRUE(allClose(a, fl::full({3, 3}, 2.))); - auto c = a.copy(); - a += 1; - ASSERT_TRUE(allClose(a, fl::full({3, 3}, 3.))); - ASSERT_TRUE(allClose(c, fl::full({3, 3}, 2.))); + auto c = a.copy(); + a += 1; + ASSERT_TRUE(allClose(a, fl::full({3, 3}, 3.))); + ASSERT_TRUE(allClose(c, fl::full({3, 3}, 2.))); } TEST(TensorBaseTest, ConstructFromData) { - // Tensor::fromVector - float val = 3.; - std::vector vec(100, val); - fl::Shape s = {10, 10}; - ASSERT_TRUE(allClose(fl::Tensor::fromVector(s, vec), fl::full(s, val))); - - ASSERT_TRUE(allClose( - fl::Tensor::fromBuffer(s, vec.data(), fl::MemoryLocation::Host), - fl::full(s, val))); - - std::vector ascending = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - auto t = fl::Tensor::fromVector({3, 4}, ascending); - ASSERT_EQ(t.type(), fl::dtype::f32); - for (int i = 0; i < ascending.size(); ++i) { - ASSERT_FLOAT_EQ(t(i % 3, i / 3).scalar(), ascending[i]); - } - - // TODO: add fixtures/check stuff - std::vector intV = {1, 2, 3}; - ASSERT_EQ(fl::Tensor::fromVector({3}, intV).type(), fl::dtype::s32); - ASSERT_EQ( - fl::Tensor::fromVector({5}, {0., 1., 2., 3., 4.}).type(), - fl::dtype::f32); - - std::vector flat = {0, 1, 2, 3, 4, 5, 6, 7}; - unsigned size = flat.size(); - ASSERT_EQ(fl::Tensor::fromVector(flat).shape(), Shape({size})); - - // Tensor::fromArray - constexpr unsigned arrFSize = 5; - std::array arrF = {1, 2, 3, 4, 5}; - auto 
tArrF = Tensor::fromArray(arrF); - ASSERT_EQ(tArrF.type(), fl::dtype::f32); - ASSERT_EQ(tArrF.shape(), Shape({arrFSize})); - auto tArrD = Tensor::fromArray({arrFSize}, arrF, fl::dtype::f64); - ASSERT_EQ(tArrD.type(), fl::dtype::f64); - - constexpr unsigned arrISize = 8; - std::array arrI = {1, 2, 3, 4, 5, 6, 7, 8}; - auto tArrI = Tensor::fromArray(arrI); - ASSERT_EQ(tArrI.type(), fl::dtype::u32); - ASSERT_EQ(tArrI.shape(), Shape({arrISize})); - auto tArrIs = Tensor::fromArray({2, 4}, arrI); - ASSERT_EQ(tArrIs.shape(), Shape({2, 4})); + // Tensor::fromVector + float val = 3.; + std::vector vec(100, val); + fl::Shape s = {10, 10}; + ASSERT_TRUE(allClose(fl::Tensor::fromVector(s, vec), fl::full(s, val))); + + ASSERT_TRUE( + allClose( + fl::Tensor::fromBuffer(s, vec.data(), fl::MemoryLocation::Host), + fl::full(s, val) + ) + ); + + std::vector ascending = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + auto t = fl::Tensor::fromVector({3, 4}, ascending); + ASSERT_EQ(t.type(), fl::dtype::f32); + for(int i = 0; i < ascending.size(); ++i) { + ASSERT_FLOAT_EQ(t(i % 3, i / 3).scalar(), ascending[i]); + } + + // TODO: add fixtures/check stuff + std::vector intV = {1, 2, 3}; + ASSERT_EQ(fl::Tensor::fromVector({3}, intV).type(), fl::dtype::s32); + ASSERT_EQ( + fl::Tensor::fromVector({5}, {0., 1., 2., 3., 4.}).type(), + fl::dtype::f32 + ); + + std::vector flat = {0, 1, 2, 3, 4, 5, 6, 7}; + unsigned size = flat.size(); + ASSERT_EQ(fl::Tensor::fromVector(flat).shape(), Shape({size})); + + // Tensor::fromArray + constexpr unsigned arrFSize = 5; + std::array arrF = {1, 2, 3, 4, 5}; + auto tArrF = Tensor::fromArray(arrF); + ASSERT_EQ(tArrF.type(), fl::dtype::f32); + ASSERT_EQ(tArrF.shape(), Shape({arrFSize})); + auto tArrD = Tensor::fromArray({arrFSize}, arrF, fl::dtype::f64); + ASSERT_EQ(tArrD.type(), fl::dtype::f64); + + constexpr unsigned arrISize = 8; + std::array arrI = {1, 2, 3, 4, 5, 6, 7, 8}; + auto tArrI = Tensor::fromArray(arrI); + ASSERT_EQ(tArrI.type(), fl::dtype::u32); + 
ASSERT_EQ(tArrI.shape(), Shape({arrISize})); + auto tArrIs = Tensor::fromArray({2, 4}, arrI); + ASSERT_EQ(tArrIs.shape(), Shape({2, 4})); } TEST(TensorBaseTest, reshape) { - auto a = fl::full({4, 4}, 3.); - auto b = fl::reshape(a, Shape({8, 2})); - ASSERT_EQ(b.shape(), Shape({8, 2})); - ASSERT_TRUE(allClose(a, fl::reshape(b, {4, 4}))); + auto a = fl::full({4, 4}, 3.); + auto b = fl::reshape(a, Shape({8, 2})); + ASSERT_EQ(b.shape(), Shape({8, 2})); + ASSERT_TRUE(allClose(a, fl::reshape(b, {4, 4}))); - ASSERT_THROW(fl::reshape(a, {}), std::exception); + ASSERT_THROW(fl::reshape(a, {}), std::exception); } TEST(TensorBaseTest, transpose) { - // TODO: expand to check els - ASSERT_TRUE( - allClose(fl::transpose(fl::full({3, 4}, 3.)), fl::full({4, 3}, 3.))); - ASSERT_TRUE(allClose( - fl::transpose(fl::full({4, 5, 6, 7}, 3.), {2, 0, 1, 3}), - fl::full({6, 4, 5, 7}, 3.))); - ASSERT_THROW(fl::transpose(fl::rand({3, 4, 5}), {0, 1}), std::exception); - ASSERT_THROW( - fl::transpose(fl::rand({2, 4, 6, 8}), {1, 0, 2}), std::exception); - ASSERT_THROW( - fl::transpose(fl::rand({2, 4, 6, 8}), {1, 0, 2, 4}), std::exception); - - auto a = fl::rand({4}); - ASSERT_TRUE(allClose(fl::transpose(a), a)); - - ASSERT_EQ(fl::transpose(fl::rand({5, 6, 7})).shape(), Shape({7, 6, 5})); - ASSERT_EQ(fl::transpose(fl::rand({5, 6, 1, 7})).shape(), Shape({7, 1, 6, 5})); - ASSERT_EQ(fl::transpose(fl::rand({1, 1})).shape(), Shape({1, 1})); - ASSERT_EQ( - fl::transpose(fl::rand({7, 2, 1, 3}), {0, 2, 1, 3}).shape(), - Shape({7, 1, 2, 3})); + // TODO: expand to check els + ASSERT_TRUE( + allClose(fl::transpose(fl::full({3, 4}, 3.)), fl::full({4, 3}, 3.)) + ); + ASSERT_TRUE( + allClose( + fl::transpose(fl::full({4, 5, 6, 7}, 3.), {2, 0, 1, 3}), + fl::full({6, 4, 5, 7}, 3.) 
+ ) + ); + ASSERT_THROW(fl::transpose(fl::rand({3, 4, 5}), {0, 1}), std::exception); + ASSERT_THROW( + fl::transpose(fl::rand({2, 4, 6, 8}), {1, 0, 2}), + std::exception + ); + ASSERT_THROW( + fl::transpose(fl::rand({2, 4, 6, 8}), {1, 0, 2, 4}), + std::exception + ); + + auto a = fl::rand({4}); + ASSERT_TRUE(allClose(fl::transpose(a), a)); + + ASSERT_EQ(fl::transpose(fl::rand({5, 6, 7})).shape(), Shape({7, 6, 5})); + ASSERT_EQ(fl::transpose(fl::rand({5, 6, 1, 7})).shape(), Shape({7, 1, 6, 5})); + ASSERT_EQ(fl::transpose(fl::rand({1, 1})).shape(), Shape({1, 1})); + ASSERT_EQ( + fl::transpose(fl::rand({7, 2, 1, 3}), {0, 2, 1, 3}).shape(), + Shape({7, 1, 2, 3}) + ); } TEST(TensorBaseTest, tile) { - auto a = fl::full({4, 4}, 3.); - auto tiled = fl::tile(a, {2, 2}); - ASSERT_EQ(tiled.shape(), Shape({8, 8})); - ASSERT_TRUE(allClose(tiled, fl::full({8, 8}, 3.))); - ASSERT_EQ(fl::tile(a, {}).shape(), a.shape()); + auto a = fl::full({4, 4}, 3.); + auto tiled = fl::tile(a, {2, 2}); + ASSERT_EQ(tiled.shape(), Shape({8, 8})); + ASSERT_TRUE(allClose(tiled, fl::full({8, 8}, 3.))); + ASSERT_EQ(fl::tile(a, {}).shape(), a.shape()); - auto s = fl::fromScalar(3.14); - ASSERT_EQ(fl::tile(s, {3, 3}).shape(), Shape({3, 3})); - ASSERT_EQ(fl::tile(s, {}).shape(), s.shape()); + auto s = fl::fromScalar(3.14); + ASSERT_EQ(fl::tile(s, {3, 3}).shape(), Shape({3, 3})); + ASSERT_EQ(fl::tile(s, {}).shape(), s.shape()); } TEST(TensorBaseTest, concatenate) { - auto a = fl::full({3, 3}, 1.); - auto b = fl::full({3, 3}, 2.); - auto c = fl::full({3, 3}, 3.); - ASSERT_TRUE( - allClose(fl::concatenate(0, a, b, c), fl::concatenate({a, b, c}))); - auto out = fl::concatenate(0, a, b, c); - ASSERT_EQ(out.shape(), Shape({9, 3})); - - // Empty tenors - ASSERT_EQ(fl::concatenate(0, Tensor(), Tensor()).shape(), Shape({0})); - ASSERT_EQ(fl::concatenate(2, Tensor(), Tensor()).shape(), Shape({0, 1, 1})); - ASSERT_EQ( - fl::concatenate(1, fl::rand({5, 5}), Tensor()).shape(), Shape({5, 5})); - - // More tensors - // 
TODO{fl::Tensor}{concat} just concat everything once we enforce - // arbitrarily-many tensors - const float val = 3.; - const int axis = 0; - auto t = fl::concatenate( - axis, - fl::full({4, 2}, val), - fl::full({4, 2}, val), - fl::full({4, 2}, val), - fl::concatenate( - axis, - fl::full({4, 2}, val), - fl::full({4, 2}, val), - fl::full({4, 2}, val))); - ASSERT_EQ(t.shape(), Shape({24, 2})); - ASSERT_TRUE(allClose(t, fl::full({24, 2}, val))); + auto a = fl::full({3, 3}, 1.); + auto b = fl::full({3, 3}, 2.); + auto c = fl::full({3, 3}, 3.); + ASSERT_TRUE( + allClose(fl::concatenate(0, a, b, c), fl::concatenate({a, b, c})) + ); + auto out = fl::concatenate(0, a, b, c); + ASSERT_EQ(out.shape(), Shape({9, 3})); + + // Empty tenors + ASSERT_EQ(fl::concatenate(0, Tensor(), Tensor()).shape(), Shape({0})); + ASSERT_EQ(fl::concatenate(2, Tensor(), Tensor()).shape(), Shape({0, 1, 1})); + ASSERT_EQ( + fl::concatenate(1, fl::rand({5, 5}), Tensor()).shape(), + Shape({5, 5}) + ); + + // More tensors + // TODO{fl::Tensor}{concat} just concat everything once we enforce + // arbitrarily-many tensors + const float val = 3.; + const int axis = 0; + auto t = fl::concatenate( + axis, + fl::full({4, 2}, val), + fl::full({4, 2}, val), + fl::full({4, 2}, val), + fl::concatenate( + axis, + fl::full({4, 2}, val), + fl::full({4, 2}, val), + fl::full({4, 2}, val) + ) + ); + ASSERT_EQ(t.shape(), Shape({24, 2})); + ASSERT_TRUE(allClose(t, fl::full({24, 2}, val))); } TEST(TensorBaseTest, nonzero) { - std::vector idxs = {0, 1, 4, 9, 11, 23, 55, 82, 91}; - auto a = fl::full({10, 10}, 1, fl::dtype::u32); - for (const auto idx : idxs) { - a(idx / 10, idx % 10) = 0; - } - auto indices = fl::nonzero(a); - int nnz = a.elements() - idxs.size(); - ASSERT_EQ(indices.shape(), Shape({nnz})); - ASSERT_TRUE( - allClose(a.flatten()(indices), fl::full({nnz}, 1, fl::dtype::u32))); + std::vector idxs = {0, 1, 4, 9, 11, 23, 55, 82, 91}; + auto a = fl::full({10, 10}, 1, fl::dtype::u32); + for(const auto idx : idxs) 
{ + a(idx / 10, idx % 10) = 0; + } + auto indices = fl::nonzero(a); + int nnz = a.elements() - idxs.size(); + ASSERT_EQ(indices.shape(), Shape({nnz})); + ASSERT_TRUE( + allClose(a.flatten()(indices), fl::full({nnz}, 1, fl::dtype::u32))); } TEST(TensorBaseTest, flatten) { - unsigned s = 6; - auto a = fl::full({s, s, s}, 2.); - auto flat = a.flatten(); - ASSERT_EQ(flat.shape(), Shape({s * s * s})); - ASSERT_TRUE(allClose(flat, fl::full({s * s * s}, 2.))); + unsigned s = 6; + auto a = fl::full({s, s, s}, 2.); + auto flat = a.flatten(); + ASSERT_EQ(flat.shape(), Shape({s * s * s})); + ASSERT_TRUE(allClose(flat, fl::full({s * s * s}, 2.))); } TEST(TensorBaseTest, pad) { - auto t = fl::rand({5, 2}); - auto zeroPadded = fl::pad(t, {{1, 2}, {3, 4}}); - auto zeroTest = fl::concatenate( - 1, - fl::full({8, 3}, 0.), - fl::concatenate(0, fl::full({1, 2}, 0.), t, fl::full({2, 2}, 0.)), - fl::full({8, 4}, 0.)); - ASSERT_TRUE(allClose(zeroPadded, zeroTest)); - - auto edgePadded = fl::pad(t, {{1, 1}, {2, 2}}, PadType::Edge); - auto vertTiled = fl::concatenate( - 0, - fl::reshape(t(0, fl::span), {1, 2}), - t, - fl::reshape(t(t.dim(0) - 1, fl::span), {1, 2})); - auto vTiled0 = vertTiled(fl::span, 0); - auto vTiled1 = vertTiled(fl::span, 1); - ASSERT_TRUE(allClose( - edgePadded, - fl::concatenate( - 1, fl::tile(vTiled0, {1, 3}), fl::tile(vTiled1, {1, 3})))); - - auto symmetricPadded = fl::pad(t, {{1, 1}, {2, 2}}, PadType::Symmetric); - ASSERT_TRUE(allClose( - symmetricPadded, - // TODO{fl::Tensor}{concat} just concat everything once we enforce - // arbitrarily-many tensors - fl::concatenate( - 1, - vTiled1, - vTiled0, - vTiled0, - fl::concatenate(1, vTiled1, vTiled1, vTiled0)))); + auto t = fl::rand({5, 2}); + auto zeroPadded = fl::pad(t, {{1, 2}, {3, 4}}); + auto zeroTest = fl::concatenate( + 1, + fl::full({8, 3}, 0.), + fl::concatenate(0, fl::full({1, 2}, 0.), t, fl::full({2, 2}, 0.)), + fl::full({8, 4}, 0.) 
+ ); + ASSERT_TRUE(allClose(zeroPadded, zeroTest)); + + auto edgePadded = fl::pad(t, {{1, 1}, {2, 2}}, PadType::Edge); + auto vertTiled = fl::concatenate( + 0, + fl::reshape(t(0, fl::span), {1, 2}), + t, + fl::reshape(t(t.dim(0) - 1, fl::span), {1, 2}) + ); + auto vTiled0 = vertTiled(fl::span, 0); + auto vTiled1 = vertTiled(fl::span, 1); + ASSERT_TRUE( + allClose( + edgePadded, + fl::concatenate( + 1, + fl::tile(vTiled0, {1, 3}), + fl::tile(vTiled1, {1, 3}) + ) + ) + ); + + auto symmetricPadded = fl::pad(t, {{1, 1}, {2, 2}}, PadType::Symmetric); + ASSERT_TRUE( + allClose( + symmetricPadded, + // TODO{fl::Tensor}{concat} just concat everything once we enforce + // arbitrarily-many tensors + fl::concatenate( + 1, + vTiled1, + vTiled0, + vTiled0, + fl::concatenate(1, vTiled1, vTiled1, vTiled0) + ) + ) + ); } TEST(TensorBaseTest, astype) { - auto a = fl::rand({3, 3}); - ASSERT_EQ(a.type(), dtype::f32); - ASSERT_EQ(a.astype(dtype::f64).type(), dtype::f64); + auto a = fl::rand({3, 3}); + ASSERT_EQ(a.type(), dtype::f32); + ASSERT_EQ(a.astype(dtype::f64).type(), dtype::f64); } TEST(TensorBaseTest, where) { - auto a = Tensor::fromVector({2, 5}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - auto out = fl::where(a < 5, a, a * 10); - a(a >= 5) *= 10; - ASSERT_TRUE(allClose(out, a)); - auto outC = fl::where(a < 5, a, 3); - a(a >= 5) = 3; - ASSERT_TRUE(allClose(outC, a)); - auto outC2 = fl::where(a < 5, 3, a); - a(a < 5) = 3; - ASSERT_TRUE(allClose(outC2, a)); - - // non b8-type vector throws - EXPECT_THROW( - fl::where((a < 5).astype(fl::dtype::f32), a, a * 10), std::exception); + auto a = Tensor::fromVector({2, 5}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto out = fl::where(a < 5, a, a * 10); + a(a >= 5) *= 10; + ASSERT_TRUE(allClose(out, a)); + auto outC = fl::where(a < 5, a, 3); + a(a >= 5) = 3; + ASSERT_TRUE(allClose(outC, a)); + auto outC2 = fl::where(a < 5, 3, a); + a(a < 5) = 3; + ASSERT_TRUE(allClose(outC2, a)); + + // non b8-type vector throws + EXPECT_THROW( + fl::where((a < 
5).astype(fl::dtype::f32), a, a * 10), + std::exception + ); } TEST(TensorBaseTest, topk) { - auto a = fl::arange({10, 2}); - Tensor values; - Tensor indices; - fl::topk(values, indices, a, /* k = */ 3, /* axis = */ 0); // descending sort - ASSERT_TRUE( - allClose(values, Tensor::fromVector({3, 2}, {9, 8, 7, 9, 8, 7}))); - - fl::topk( - values, - indices, - a, - /* k = */ 4, - /* axis = */ 0, - fl::SortMode::Ascending); - ASSERT_TRUE(allClose( - values, Tensor::fromVector({4, 2}, {0, 1, 2, 3, 0, 1, 2, 3}))); + auto a = fl::arange({10, 2}); + Tensor values; + Tensor indices; + fl::topk(values, indices, a, /* k = */ 3, /* axis = */ 0); // descending sort + ASSERT_TRUE( + allClose(values, Tensor::fromVector({3, 2}, {9, 8, 7, 9, 8, 7})) + ); + + fl::topk( + values, + indices, + a, + /* k = */ 4, + /* axis = */ 0, + fl::SortMode::Ascending + ); + ASSERT_TRUE( + allClose( + values, + Tensor::fromVector({4, 2}, {0, 1, 2, 3, 0, 1, 2, 3}) + ) + ); } TEST(TensorBaseTest, sort) { - Shape dims({10, 2}); - auto a = fl::arange(dims); - auto sorted = fl::sort(a, /* axis = */ 0, SortMode::Descending); - - Tensor expected({dims[0]}, a.type()); - for (int i = 0; i < dims[0]; ++i) { - expected(i) = dims[0] - i - 1; - } - auto tiled = fl::tile(expected, {1, 2}); - ASSERT_TRUE(allClose(sorted, tiled)); - - ASSERT_TRUE(allClose(a, fl::sort(tiled, 0, SortMode::Ascending))); - - auto b = fl::rand({10}); - Tensor values, indices; - fl::sort(values, indices, b, /* axis = */ 0, SortMode::Descending); - ASSERT_TRUE( - allClose(values, fl::sort(b, /* axis = */ 0, SortMode::Descending))); - ASSERT_TRUE( - allClose(fl::argsort(b, /* axis = */ 0, SortMode::Descending), indices)); + Shape dims({10, 2}); + auto a = fl::arange(dims); + auto sorted = fl::sort(a, /* axis = */ 0, SortMode::Descending); + + Tensor expected({dims[0]}, a.type()); + for(int i = 0; i < dims[0]; ++i) { + expected(i) = dims[0] - i - 1; + } + auto tiled = fl::tile(expected, {1, 2}); + ASSERT_TRUE(allClose(sorted, tiled)); + + 
ASSERT_TRUE(allClose(a, fl::sort(tiled, 0, SortMode::Ascending))); + + auto b = fl::rand({10}); + Tensor values, indices; + fl::sort(values, indices, b, /* axis = */ 0, SortMode::Descending); + ASSERT_TRUE( + allClose(values, fl::sort(b, /* axis = */ 0, SortMode::Descending)) + ); + ASSERT_TRUE( + allClose(fl::argsort(b, /* axis = */ 0, SortMode::Descending), indices) + ); } TEST(TensorBaseTest, argsort) { - Shape dims({10, 2}); - auto a = fl::arange(dims); - auto sorted = fl::argsort(a, /* axis = */ 0, SortMode::Descending); + Shape dims({10, 2}); + auto a = fl::arange(dims); + auto sorted = fl::argsort(a, /* axis = */ 0, SortMode::Descending); - Tensor expected({dims[0]}, fl::dtype::u32); - for (int i = 0; i < dims[0]; ++i) { - expected(i) = dims[0] - i - 1; - } - auto tiled = fl::tile(expected, {1, 2}); - ASSERT_TRUE(allClose(sorted, tiled)); + Tensor expected({dims[0]}, fl::dtype::u32); + for(int i = 0; i < dims[0]; ++i) { + expected(i) = dims[0] - i - 1; + } + auto tiled = fl::tile(expected, {1, 2}); + ASSERT_TRUE(allClose(sorted, tiled)); - ASSERT_TRUE(allClose(tiled, fl::argsort(tiled, 0, SortMode::Ascending))); + ASSERT_TRUE(allClose(tiled, fl::argsort(tiled, 0, SortMode::Ascending))); } -template +template void assertScalarBehavior(fl::dtype type) { - ScalarArgType scalar = 42; // small enough for any scalar type - auto one = fl::full({1}, scalar, type); + ScalarArgType scalar = 42; // small enough for any scalar type + auto one = fl::full({1}, scalar, type); - if (dtype_traits::fl_type != type) { - ASSERT_THROW(one.template scalar(), std::invalid_argument) + if(dtype_traits::fl_type != type) { + ASSERT_THROW(one.template scalar(), std::invalid_argument) << "dtype: " << type << ", ScalarArgType: " << dtype_traits::getName(); - return; - } - - if ((type == fl::dtype::f16) || (type == fl::dtype::f32) || - (type == fl::dtype::f64)) { - ASSERT_FLOAT_EQ(one.template scalar(), scalar) + return; + } + + if( + (type == fl::dtype::f16) || (type == fl::dtype::f32) + 
|| (type == fl::dtype::f64) + ) { + ASSERT_FLOAT_EQ(one.template scalar(), scalar) << "dtype: " << type << ", ScalarArgType: " << dtype_traits::getName(); - } else { - ASSERT_EQ(one.template scalar(), scalar) + } else { + ASSERT_EQ(one.template scalar(), scalar) << "dtype: " << type << ", ScalarArgType: " << dtype_traits::getName(); - } + } + + ScalarArgType val = static_cast(rand()); + auto a = fl::full({5, 6}, val, type); - ScalarArgType val = static_cast(rand()); - auto a = fl::full({5, 6}, val, type); - - ASSERT_TRUE(allClose(fl::full({1}, a.template scalar(), type), a(0, 0))) - << "dtype: " << type - << ", ScalarArgType: " << dtype_traits::getName(); + ASSERT_TRUE(allClose(fl::full({1}, a.template scalar(), type), a(0, 0))) + << "dtype: " << type + << ", ScalarArgType: " << dtype_traits::getName(); } TEST(TensorBaseTest, scalar) { - auto types = { - fl::dtype::b8, - fl::dtype::u8, - fl::dtype::s16, - fl::dtype::u16, - fl::dtype::s32, - fl::dtype::u32, - fl::dtype::s64, - fl::dtype::u64, - fl::dtype::f16, - fl::dtype::f32, - fl::dtype::f64}; - for (auto type : types) { - assertScalarBehavior(type); - assertScalarBehavior(type); - assertScalarBehavior(type); - assertScalarBehavior(type); - assertScalarBehavior(type); - assertScalarBehavior(type); - assertScalarBehavior(type); - assertScalarBehavior(type); - assertScalarBehavior(type); - assertScalarBehavior(type); - assertScalarBehavior(type); - assertScalarBehavior(type); - } + auto types = { + fl::dtype::b8, + fl::dtype::u8, + fl::dtype::s16, + fl::dtype::u16, + fl::dtype::s32, + fl::dtype::u32, + fl::dtype::s64, + fl::dtype::u64, + fl::dtype::f16, + fl::dtype::f32, + fl::dtype::f64}; + for(auto type : types) { + assertScalarBehavior(type); + assertScalarBehavior(type); + assertScalarBehavior(type); + assertScalarBehavior(type); + assertScalarBehavior(type); + assertScalarBehavior(type); + assertScalarBehavior(type); + assertScalarBehavior(type); + assertScalarBehavior(type); + assertScalarBehavior(type); + 
assertScalarBehavior(type); + assertScalarBehavior(type); + } } TEST(TensorBaseTest, isContiguous) { - // Contiguous by default - auto a = fl::rand({10, 10}); - ASSERT_TRUE(a.isContiguous()); + // Contiguous by default + auto a = fl::rand({10, 10}); + ASSERT_TRUE(a.isContiguous()); } TEST(TensorBaseTest, strides) { - auto t = fl::rand({10, 10}); - ASSERT_EQ(t.strides(), Shape({1, 10})); + auto t = fl::rand({10, 10}); + ASSERT_EQ(t.strides(), Shape({1, 10})); } TEST(TensorBaseTest, stream) { - auto t1 = fl::rand({10, 10}); - auto t2 = -t1; - auto t3 = t1 + t2; - ASSERT_EQ(&t1.stream(), &t2.stream()); - ASSERT_EQ(&t1.stream(), &t3.stream()); + auto t1 = fl::rand({10, 10}); + auto t2 = -t1; + auto t3 = t1 + t2; + ASSERT_EQ(&t1.stream(), &t2.stream()); + ASSERT_EQ(&t1.stream(), &t3.stream()); } TEST(TensorBaseTest, asContiguousTensor) { - auto t = fl::rand({5, 6, 7, 8}); - auto indexed = - t(fl::range(1, 4, 2), - fl::range(0, 6, 2), - fl::range(0, 6, 3), - fl::range(0, 5, 3)); - - auto contiguous = indexed.asContiguousTensor(); - std::vector strides; - unsigned stride = 1; - for (unsigned i = 0; i < contiguous.ndim(); ++i) { - strides.push_back(stride); - stride *= contiguous.dim(i); - } - ASSERT_EQ(contiguous.strides(), Shape(strides)); + auto t = fl::rand({5, 6, 7, 8}); + auto indexed = + t( + fl::range(1, 4, 2), + fl::range(0, 6, 2), + fl::range(0, 6, 3), + fl::range(0, 5, 3) + ); + + auto contiguous = indexed.asContiguousTensor(); + std::vector strides; + unsigned stride = 1; + for(unsigned i = 0; i < contiguous.ndim(); ++i) { + strides.push_back(stride); + stride *= contiguous.dim(i); + } + ASSERT_EQ(contiguous.strides(), Shape(strides)); } TEST(TensorBaseTest, host) { - auto a = fl::rand({10, 10}); + auto a = fl::rand({10, 10}); - float* ptr = a.host(); - for (int i = 0; i < a.elements(); ++i) { - ASSERT_EQ(ptr[i], a.flatten()(i).scalar()); - } + float* ptr = a.host(); + for(int i = 0; i < a.elements(); ++i) { + ASSERT_EQ(ptr[i], a.flatten()(i).scalar()); + } - 
float* existingBuffer = new float[100]; - a.host(existingBuffer); - for (int i = 0; i < a.elements(); ++i) { - ASSERT_EQ(existingBuffer[i], a.flatten()(i).scalar()); - } + float* existingBuffer = new float[100]; + a.host(existingBuffer); + for(int i = 0; i < a.elements(); ++i) { + ASSERT_EQ(existingBuffer[i], a.flatten()(i).scalar()); + } - ASSERT_EQ(Tensor().host(), nullptr); + ASSERT_EQ(Tensor().host(), nullptr); } TEST(TensorBaseTest, toHostVector) { - auto a = fl::rand({10, 10}); - auto vec = a.toHostVector(); + auto a = fl::rand({10, 10}); + auto vec = a.toHostVector(); - for (int i = 0; i < a.elements(); ++i) { - ASSERT_EQ(vec[i], a.flatten()(i).scalar()); - } + for(int i = 0; i < a.elements(); ++i) { + ASSERT_EQ(vec[i], a.flatten()(i).scalar()); + } - ASSERT_EQ(Tensor().toHostVector().size(), 0); + ASSERT_EQ(Tensor().toHostVector().size(), 0); } TEST(TensorBaseTest, arange) { - // Range/step overload - ASSERT_TRUE( - allClose(fl::arange(2, 10, 2), Tensor::fromVector({2, 4, 6, 8}))); - ASSERT_TRUE( - allClose(fl::arange(0, 6), Tensor::fromVector({0, 1, 2, 3, 4, 5}))); - ASSERT_TRUE(allClose( - fl::arange(0., 1.22, 0.25), - Tensor::fromVector({0., 0.25, 0.5, 0.75}))); - ASSERT_TRUE(allClose( - fl::arange(0., 4.1), Tensor::fromVector({0., 1., 2., 3.}))); - - // Shape overload - auto v = Tensor::fromVector({0., 1., 2., 3.}); - ASSERT_TRUE(allClose(fl::arange({4}), v)); - - ASSERT_TRUE(allClose(fl::arange({4, 5}), fl::tile(v, {1, 5}))); - ASSERT_EQ(fl::arange({4, 5}, 1).shape(), Shape({4, 5})); - ASSERT_TRUE(allClose( - fl::arange({4, 5}, 1), - fl::tile( - fl::reshape(Tensor::fromVector({0., 1., 2., 3., 4.}), {1, 5}), - {4}))); - ASSERT_EQ(fl::arange({2, 6}, 0, fl::dtype::f64).type(), fl::dtype::f64); + // Range/step overload + ASSERT_TRUE( + allClose(fl::arange(2, 10, 2), Tensor::fromVector({2, 4, 6, 8})) + ); + ASSERT_TRUE( + allClose(fl::arange(0, 6), Tensor::fromVector({0, 1, 2, 3, 4, 5})) + ); + ASSERT_TRUE( + allClose( + fl::arange(0., 1.22, 0.25), + 
Tensor::fromVector({0., 0.25, 0.5, 0.75}) + ) + ); + ASSERT_TRUE( + allClose( + fl::arange(0., 4.1), + Tensor::fromVector({0., 1., 2., 3.}) + ) + ); + + // Shape overload + auto v = Tensor::fromVector({0., 1., 2., 3.}); + ASSERT_TRUE(allClose(fl::arange({4}), v)); + + ASSERT_TRUE(allClose(fl::arange({4, 5}), fl::tile(v, {1, 5}))); + ASSERT_EQ(fl::arange({4, 5}, 1).shape(), Shape({4, 5})); + ASSERT_TRUE( + allClose( + fl::arange({4, 5}, 1), + fl::tile( + fl::reshape(Tensor::fromVector({0., 1., 2., 3., 4.}), {1, 5}), + {4} + ) + ) + ); + ASSERT_EQ(fl::arange({2, 6}, 0, fl::dtype::f64).type(), fl::dtype::f64); } TEST(TensorBaseTest, iota) { - ASSERT_TRUE(allClose( - fl::iota({5, 3}, {1, 2}), - fl::tile(fl::reshape(fl::arange({15}), {5, 3}), {1, 2}))); - ASSERT_EQ(fl::iota({2, 2}, {2, 2}, fl::dtype::f64).type(), fl::dtype::f64); - ASSERT_EQ(fl::iota({1, 10}, {5}).shape(), Shape({5, 10})); + ASSERT_TRUE( + allClose( + fl::iota({5, 3}, {1, 2}), + fl::tile(fl::reshape(fl::arange({15}), {5, 3}), {1, 2}) + ) + ); + ASSERT_EQ(fl::iota({2, 2}, {2, 2}, fl::dtype::f64).type(), fl::dtype::f64); + ASSERT_EQ(fl::iota({1, 10}, {5}).shape(), Shape({5, 10})); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp b/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp index 2753144..ea63d03 100644 --- a/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp +++ b/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp @@ -20,490 +20,497 @@ using namespace fl; namespace { // Always cast towards potentially signed type because otherwise ArrayFire may // clip values, e.g., negative casting to unsigned becomse 0. 
-template +template void assertTensorScalarBinop( const Tensor& in, ScalarType scalar, Op op, - const Tensor& expectOut) { - auto result = op(in, scalar); - auto expect = expectOut.astype(result.type()); - ASSERT_TRUE(allClose(result, expect)) - << "in.type(): " << in.type() - << ", ScalarType: " << dtype_traits::getName(); + const Tensor& expectOut +) { + auto result = op(in, scalar); + auto expect = expectOut.astype(result.type()); + ASSERT_TRUE(allClose(result, expect)) + << "in.type(): " << in.type() + << ", ScalarType: " << dtype_traits::getName(); } -template +template void assertScalarTensorBinop( ScalarType scalar, const Tensor& in, Op op, - const Tensor& expectOut) { - auto result = op(scalar, in); - auto expect = expectOut.astype(result.type()); - ASSERT_TRUE(allClose(result, expect)) - << "ScalarType: " << dtype_traits::getName() - << ", in.type(): " << in.type(); + const Tensor& expectOut +) { + auto result = op(scalar, in); + auto expect = expectOut.astype(result.type()); + ASSERT_TRUE(allClose(result, expect)) + << "ScalarType: " << dtype_traits::getName() + << ", in.type(): " << in.type(); } -template +template void assertScalarTensorCommutativeBinop( char scalar, const Tensor& in, Op op, - const Tensor& out) { - assertScalarTensorBinop(scalar, in, op, out); - assertTensorScalarBinop(in, scalar, op, out); + const Tensor& out +) { + assertScalarTensorBinop(scalar, in, op, out); + assertTensorScalarBinop(in, scalar, op, out); } -template +template void assertCommutativeBinop( const Tensor& in1, const Tensor& in2, Op op, - const Tensor& out) { - ASSERT_TRUE(allClose(op(in1, in2), out)) - << "in1.type(): " << in1.type() << ", in2.type(): " << in2.type(); - ASSERT_TRUE(allClose(op(in2, in1), out)) - << "in1.type(): " << in1.type() << ", in2.type(): " << in2.type(); + const Tensor& out +) { + ASSERT_TRUE(allClose(op(in1, in2), out)) + << "in1.type(): " << in1.type() << ", in2.type(): " << in2.type(); + ASSERT_TRUE(allClose(op(in2, in1), out)) + << 
"in1.type(): " << in1.type() << ", in2.type(): " << in2.type(); } void applyToAllFpDtypes(std::function func) { - func(dtype::f16); - func(dtype::f32); - func(dtype::f64); + func(dtype::f16); + func(dtype::f32); + func(dtype::f64); } void applyToAllIntegralDtypes(std::function func) { - // TODO casting to `b8` clips values to 0 and 1, which breaks the fixtures - // func(dtype::b8); - func(dtype::u8); - func(dtype::s16); - func(dtype::u16); - func(dtype::s32); - func(dtype::u32); - func(dtype::s64); - func(dtype::u64); + // TODO casting to `b8` clips values to 0 and 1, which breaks the fixtures + // func(dtype::b8); + func(dtype::u8); + func(dtype::s16); + func(dtype::u16); + func(dtype::s32); + func(dtype::u32); + func(dtype::s64); + func(dtype::u64); } void applyToAllDtypes(std::function func) { - applyToAllFpDtypes(func); - applyToAllIntegralDtypes(func); + applyToAllFpDtypes(func); + applyToAllIntegralDtypes(func); } } // namespace TEST(TensorBinaryOpsTest, ArithmeticBinaryOperators) { - auto testArithmeticBinops = [](dtype type) { - auto a = Tensor::fromVector({2, 2}, {0, 1, 2, 3}).astype(type); - auto b = Tensor::fromVector({2, 2}, {1, 2, 3, 4}).astype(type); - auto c = Tensor::fromVector({2, 2}, {1, 3, 5, 7}).astype(type); - auto d = Tensor::fromVector({2, 2}, {1, 6, 15, 28}).astype(type); - auto e = Tensor::fromVector({2, 2}, {3, 2, 1, 0}).astype(type); - auto f = Tensor::fromVector({2, 2}, {2, 4, 6, 8}).astype(type); - auto z = fl::full({2, 2}, 0, type); - - assertCommutativeBinop(a, z, std::plus<>(), a); - assertCommutativeBinop(a, b, std::plus<>(), c); - assertScalarTensorCommutativeBinop(1, a, std::plus<>(), b); - assertScalarTensorCommutativeBinop(0, a, std::plus<>(), a); - - ASSERT_TRUE(allClose((c - z), c)) << "dtype: " << type; - ASSERT_TRUE(allClose((z - c), -c)) << "dtype: " << type; - ASSERT_TRUE(allClose((c - b), a)) << "dtype: " << type; - assertTensorScalarBinop(b, 1, std::minus<>(), a); - assertScalarTensorBinop(3, a, std::minus<>(), e); - 
assertTensorScalarBinop(a, 0, std::minus<>(), a); - - assertCommutativeBinop(c, z, std::multiplies<>(), z); - assertCommutativeBinop(c, b, std::multiplies<>(), d); - assertScalarTensorCommutativeBinop(0, a, std::multiplies<>(), z); - assertScalarTensorCommutativeBinop(1, a, std::multiplies<>(), a); - assertScalarTensorCommutativeBinop(2, b, std::multiplies<>(), f); - - ASSERT_TRUE(allClose((z / b), z)) << "dtype: " << type; - ASSERT_TRUE(allClose((d / b), c)) << "dtype: " << type; - assertTensorScalarBinop(z, 1, std::divides<>(), z); - assertTensorScalarBinop(a, 1, std::divides<>(), a); - assertTensorScalarBinop(f, 2, std::divides<>(), b); - // TODO division by zero doesn't always fail. - // e.g., ArrayFire yields max value of dtype - }; - - applyToAllDtypes(testArithmeticBinops); + auto testArithmeticBinops = [](dtype type) { + auto a = Tensor::fromVector({2, 2}, {0, 1, 2, 3}).astype(type); + auto b = Tensor::fromVector({2, 2}, {1, 2, 3, 4}).astype(type); + auto c = Tensor::fromVector({2, 2}, {1, 3, 5, 7}).astype(type); + auto d = Tensor::fromVector({2, 2}, {1, 6, 15, 28}).astype(type); + auto e = Tensor::fromVector({2, 2}, {3, 2, 1, 0}).astype(type); + auto f = Tensor::fromVector({2, 2}, {2, 4, 6, 8}).astype(type); + auto z = fl::full({2, 2}, 0, type); + + assertCommutativeBinop(a, z, std::plus<>(), a); + assertCommutativeBinop(a, b, std::plus<>(), c); + assertScalarTensorCommutativeBinop(1, a, std::plus<>(), b); + assertScalarTensorCommutativeBinop(0, a, std::plus<>(), a); + + ASSERT_TRUE(allClose((c - z), c)) << "dtype: " << type; + ASSERT_TRUE(allClose((z - c), -c)) << "dtype: " << type; + ASSERT_TRUE(allClose((c - b), a)) << "dtype: " << type; + assertTensorScalarBinop(b, 1, std::minus<>(), a); + assertScalarTensorBinop(3, a, std::minus<>(), e); + assertTensorScalarBinop(a, 0, std::minus<>(), a); + + assertCommutativeBinop(c, z, std::multiplies<>(), z); + assertCommutativeBinop(c, b, std::multiplies<>(), d); + assertScalarTensorCommutativeBinop(0, a, 
std::multiplies<>(), z); + assertScalarTensorCommutativeBinop(1, a, std::multiplies<>(), a); + assertScalarTensorCommutativeBinop(2, b, std::multiplies<>(), f); + + ASSERT_TRUE(allClose((z / b), z)) << "dtype: " << type; + ASSERT_TRUE(allClose((d / b), c)) << "dtype: " << type; + assertTensorScalarBinop(z, 1, std::divides<>(), z); + assertTensorScalarBinop(a, 1, std::divides<>(), a); + assertTensorScalarBinop(f, 2, std::divides<>(), b); + // TODO division by zero doesn't always fail. + // e.g., ArrayFire yields max value of dtype + }; + + applyToAllDtypes(testArithmeticBinops); } TEST(TensorBinaryOpsTest, ComparisonBinaryOperators) { - auto falses = fl::full({2, 2}, 0, dtype::b8); - auto trues = fl::full({2, 2}, 1, dtype::b8); - auto falseTrues = - Tensor::fromVector({2, 2}, {0, 1, 0, 1}).astype(fl::dtype::b8); - auto trueFalses = - Tensor::fromVector({2, 2}, {1, 0, 1, 0}).astype(fl::dtype::b8); - - auto testComparisonBinops = [&](dtype type) { - auto a = Tensor::fromVector({2, 2}, {0, 1, 2, 3}).astype(type); - auto b = Tensor::fromVector({2, 2}, {0, 0, 2, 0}).astype(type); - auto c = Tensor::fromVector({2, 2}, {2, 3, 4, 5}).astype(type); - auto d = Tensor::fromVector({2, 2}, {0, 4, 2, 6}).astype(type); - auto e = Tensor::fromVector({2, 2}, {0, 1, 0, 1}).astype(type); - - ASSERT_TRUE(allClose((a == a), trues)) << "dtype: " << type; - assertCommutativeBinop(a, b, std::equal_to<>(), trueFalses); - assertCommutativeBinop(a, c, std::equal_to<>(), falses); - assertScalarTensorCommutativeBinop(4, a, std::equal_to<>(), falses); - assertScalarTensorCommutativeBinop(1, e, std::equal_to<>(), falseTrues); - - ASSERT_TRUE(allClose((a != a), falses)) << "dtype: " << type; - assertCommutativeBinop(a, b, std::not_equal_to<>(), falseTrues); - assertCommutativeBinop(a, c, std::not_equal_to<>(), trues); - assertScalarTensorCommutativeBinop(4, a, std::not_equal_to<>(), trues); - assertScalarTensorCommutativeBinop(1, e, std::not_equal_to<>(), trueFalses); - - ASSERT_TRUE(allClose((a > 
a), falses)) << "dtype: " << type; - ASSERT_TRUE(allClose((c > a), trues)) << "dtype: " << type; - ASSERT_TRUE(allClose((d > a), falseTrues)) << "dtype: " << type; - ASSERT_TRUE(allClose((a > d), falses)) << "dtype: " << type; - assertTensorScalarBinop(c, 1, std::greater<>(), trues); - assertScalarTensorBinop(0, c, std::greater<>(), falses); - assertTensorScalarBinop(d, 3, std::greater<>(), falseTrues); - assertScalarTensorBinop(3, d, std::greater<>(), trueFalses); - - ASSERT_TRUE(allClose((a < a), falses)) << "dtype: " << type; - ASSERT_TRUE(allClose((c < a), falses)) << "dtype: " << type; - ASSERT_TRUE(allClose((d < a), falses)) << "dtype: " << type; - ASSERT_TRUE(allClose((a < d), falseTrues)) << "dtype: " << type; - assertTensorScalarBinop(c, 1, std::less<>(), falses); - assertScalarTensorBinop(0, c, std::less<>(), trues); - assertTensorScalarBinop(d, 3, std::less<>(), trueFalses); - assertScalarTensorBinop(3, d, std::less<>(), falseTrues); - - ASSERT_TRUE(allClose((a >= a), trues)) << "dtype: " << type; - ASSERT_TRUE(allClose((c >= a), trues)) << "dtype: " << type; - ASSERT_TRUE(allClose((d >= a), trues)) << "dtype: " << type; - ASSERT_TRUE(allClose((a >= d), trueFalses)) << "dtype: " << type; - assertTensorScalarBinop(c, 2, std::greater_equal<>(), trues); - assertScalarTensorBinop(1, c, std::greater_equal<>(), falses); - assertTensorScalarBinop(d, 3, std::greater_equal<>(), falseTrues); - assertScalarTensorBinop(3, d, std::greater_equal<>(), trueFalses); - - ASSERT_TRUE(allClose((a <= a), trues)) << "dtype: " << type; - ASSERT_TRUE(allClose((c <= a), falses)) << "dtype: " << type; - ASSERT_TRUE(allClose((d <= a), trueFalses)) << "dtype: " << type; - ASSERT_TRUE(allClose((a <= d), trues)) << "dtype: " << type; - assertTensorScalarBinop(c, 1, std::less_equal<>(), falses); - assertScalarTensorBinop(2, c, std::less_equal<>(), trues); - assertTensorScalarBinop(d, 3, std::less_equal<>(), trueFalses); - assertScalarTensorBinop(3, d, std::less_equal<>(), falseTrues); 
- }; - - applyToAllDtypes(testComparisonBinops); + auto falses = fl::full({2, 2}, 0, dtype::b8); + auto trues = fl::full({2, 2}, 1, dtype::b8); + auto falseTrues = + Tensor::fromVector({2, 2}, {0, 1, 0, 1}).astype(fl::dtype::b8); + auto trueFalses = + Tensor::fromVector({2, 2}, {1, 0, 1, 0}).astype(fl::dtype::b8); + + auto testComparisonBinops = [&](dtype type) { + auto a = Tensor::fromVector({2, 2}, {0, 1, 2, 3}).astype(type); + auto b = Tensor::fromVector({2, 2}, {0, 0, 2, 0}).astype(type); + auto c = Tensor::fromVector({2, 2}, {2, 3, 4, 5}).astype(type); + auto d = Tensor::fromVector({2, 2}, {0, 4, 2, 6}).astype(type); + auto e = Tensor::fromVector({2, 2}, {0, 1, 0, 1}).astype(type); + + ASSERT_TRUE(allClose((a == a), trues)) << "dtype: " << type; + assertCommutativeBinop(a, b, std::equal_to<>(), trueFalses); + assertCommutativeBinop(a, c, std::equal_to<>(), falses); + assertScalarTensorCommutativeBinop(4, a, std::equal_to<>(), falses); + assertScalarTensorCommutativeBinop(1, e, std::equal_to<>(), falseTrues); + + ASSERT_TRUE(allClose((a != a), falses)) << "dtype: " << type; + assertCommutativeBinop(a, b, std::not_equal_to<>(), falseTrues); + assertCommutativeBinop(a, c, std::not_equal_to<>(), trues); + assertScalarTensorCommutativeBinop(4, a, std::not_equal_to<>(), trues); + assertScalarTensorCommutativeBinop(1, e, std::not_equal_to<>(), trueFalses); + + ASSERT_TRUE(allClose((a > a), falses)) << "dtype: " << type; + ASSERT_TRUE(allClose((c > a), trues)) << "dtype: " << type; + ASSERT_TRUE(allClose((d > a), falseTrues)) << "dtype: " << type; + ASSERT_TRUE(allClose((a > d), falses)) << "dtype: " << type; + assertTensorScalarBinop(c, 1, std::greater<>(), trues); + assertScalarTensorBinop(0, c, std::greater<>(), falses); + assertTensorScalarBinop(d, 3, std::greater<>(), falseTrues); + assertScalarTensorBinop(3, d, std::greater<>(), trueFalses); + + ASSERT_TRUE(allClose((a < a), falses)) << "dtype: " << type; + ASSERT_TRUE(allClose((c < a), falses)) << "dtype: " << 
type; + ASSERT_TRUE(allClose((d < a), falses)) << "dtype: " << type; + ASSERT_TRUE(allClose((a < d), falseTrues)) << "dtype: " << type; + assertTensorScalarBinop(c, 1, std::less<>(), falses); + assertScalarTensorBinop(0, c, std::less<>(), trues); + assertTensorScalarBinop(d, 3, std::less<>(), trueFalses); + assertScalarTensorBinop(3, d, std::less<>(), falseTrues); + + ASSERT_TRUE(allClose((a >= a), trues)) << "dtype: " << type; + ASSERT_TRUE(allClose((c >= a), trues)) << "dtype: " << type; + ASSERT_TRUE(allClose((d >= a), trues)) << "dtype: " << type; + ASSERT_TRUE(allClose((a >= d), trueFalses)) << "dtype: " << type; + assertTensorScalarBinop(c, 2, std::greater_equal<>(), trues); + assertScalarTensorBinop(1, c, std::greater_equal<>(), falses); + assertTensorScalarBinop(d, 3, std::greater_equal<>(), falseTrues); + assertScalarTensorBinop(3, d, std::greater_equal<>(), trueFalses); + + ASSERT_TRUE(allClose((a <= a), trues)) << "dtype: " << type; + ASSERT_TRUE(allClose((c <= a), falses)) << "dtype: " << type; + ASSERT_TRUE(allClose((d <= a), trueFalses)) << "dtype: " << type; + ASSERT_TRUE(allClose((a <= d), trues)) << "dtype: " << type; + assertTensorScalarBinop(c, 1, std::less_equal<>(), falses); + assertScalarTensorBinop(2, c, std::less_equal<>(), trues); + assertTensorScalarBinop(d, 3, std::less_equal<>(), trueFalses); + assertScalarTensorBinop(3, d, std::less_equal<>(), falseTrues); + }; + + applyToAllDtypes(testComparisonBinops); } TEST(TensorBinaryOpsTest, LogicalBinaryOperators) { - auto falses = fl::full({2, 2}, 0, dtype::b8); - auto trues = fl::full({2, 2}, 1, dtype::b8); - auto falseTrues = - Tensor::fromVector({2, 2}, {0, 1, 0, 1}).astype(fl::dtype::b8); - - auto testLogicalBinops = [&](dtype type) { - auto a = Tensor::fromVector({2, 2}, {0, 1, 0, 3}).astype(type); - auto b = Tensor::fromVector({2, 2}, {2, 3, 4, 5}).astype(type); - auto z = fl::full({2, 2}, 0, type); - - ASSERT_TRUE(allClose((z || z), falses)) << "dtype: " << type; - 
assertCommutativeBinop(a, z, std::logical_or<>(), falseTrues); - assertCommutativeBinop(z, b, std::logical_or<>(), trues); - assertCommutativeBinop(a, b, std::logical_or<>(), trues); - assertScalarTensorCommutativeBinop(0, a, std::logical_or<>(), falseTrues); - assertScalarTensorCommutativeBinop(2, z, std::logical_or<>(), trues); - - ASSERT_TRUE(allClose((z && z), falses)) << "dtype: " << type; - assertCommutativeBinop(a, z, std::logical_and<>(), falses); - assertCommutativeBinop(z, b, std::logical_and<>(), falses); - assertCommutativeBinop(a, b, std::logical_and<>(), falseTrues); - assertScalarTensorCommutativeBinop(0, a, std::logical_and<>(), falses); - assertScalarTensorCommutativeBinop(2, a, std::logical_and<>(), falseTrues); - }; - - applyToAllDtypes(testLogicalBinops); + auto falses = fl::full({2, 2}, 0, dtype::b8); + auto trues = fl::full({2, 2}, 1, dtype::b8); + auto falseTrues = + Tensor::fromVector({2, 2}, {0, 1, 0, 1}).astype(fl::dtype::b8); + + auto testLogicalBinops = [&](dtype type) { + auto a = Tensor::fromVector({2, 2}, {0, 1, 0, 3}).astype(type); + auto b = Tensor::fromVector({2, 2}, {2, 3, 4, 5}).astype(type); + auto z = fl::full({2, 2}, 0, type); + + ASSERT_TRUE(allClose((z || z), falses)) << "dtype: " << type; + assertCommutativeBinop(a, z, std::logical_or<>(), falseTrues); + assertCommutativeBinop(z, b, std::logical_or<>(), trues); + assertCommutativeBinop(a, b, std::logical_or<>(), trues); + assertScalarTensorCommutativeBinop(0, a, std::logical_or<>(), falseTrues); + assertScalarTensorCommutativeBinop(2, z, std::logical_or<>(), trues); + + ASSERT_TRUE(allClose((z && z), falses)) << "dtype: " << type; + assertCommutativeBinop(a, z, std::logical_and<>(), falses); + assertCommutativeBinop(z, b, std::logical_and<>(), falses); + assertCommutativeBinop(a, b, std::logical_and<>(), falseTrues); + assertScalarTensorCommutativeBinop(0, a, std::logical_and<>(), falses); + assertScalarTensorCommutativeBinop(2, a, std::logical_and<>(), falseTrues); + }; + 
+ applyToAllDtypes(testLogicalBinops); } TEST(TensorBinaryOpsTest, ModuloBinaryOperators) { - auto testModuloBinop = [](dtype type) { - auto a = Tensor::fromVector({2, 2}, {1, 2, 3, 4}).astype(type); - auto b = Tensor::fromVector({2, 2}, {2, 3, 5, 7}).astype(type); - auto c = Tensor::fromVector({2, 2}, {0, 1, 2, 3}).astype(type); - auto z = fl::full({2, 2}, 0, type); - - ASSERT_TRUE(allClose((z % b), z)) << "dtype: " << type; - ASSERT_TRUE(allClose((a % a), z)) << "dtype: " << type; - ASSERT_TRUE(allClose((a % b), a)) << "dtype: " << type; - ASSERT_TRUE(allClose((b % a), c)) << "dtype: " << type; - - assertScalarTensorBinop(0, a, std::modulus<>(), z); - assertScalarTensorBinop(11, a, std::modulus<>(), c); - assertTensorScalarBinop(a, 1, std::modulus<>(), z); - assertTensorScalarBinop(a, 5, std::modulus<>(), a); - }; - - applyToAllIntegralDtypes(testModuloBinop); - // TODO ArrayFire needs software impl for fp16 modulo on CUDA backend; - // bring this test back when supported. - // testModuloBinop(dtype::f16); - testModuloBinop(dtype::f32); - testModuloBinop(dtype::f64); + auto testModuloBinop = [](dtype type) { + auto a = Tensor::fromVector({2, 2}, {1, 2, 3, 4}).astype(type); + auto b = Tensor::fromVector({2, 2}, {2, 3, 5, 7}).astype(type); + auto c = Tensor::fromVector({2, 2}, {0, 1, 2, 3}).astype(type); + auto z = fl::full({2, 2}, 0, type); + + ASSERT_TRUE(allClose((z % b), z)) << "dtype: " << type; + ASSERT_TRUE(allClose((a % a), z)) << "dtype: " << type; + ASSERT_TRUE(allClose((a % b), a)) << "dtype: " << type; + ASSERT_TRUE(allClose((b % a), c)) << "dtype: " << type; + + assertScalarTensorBinop(0, a, std::modulus<>(), z); + assertScalarTensorBinop(11, a, std::modulus<>(), c); + assertTensorScalarBinop(a, 1, std::modulus<>(), z); + assertTensorScalarBinop(a, 5, std::modulus<>(), a); + }; + + applyToAllIntegralDtypes(testModuloBinop); + // TODO ArrayFire needs software impl for fp16 modulo on CUDA backend; + // bring this test back when supported. 
+ // testModuloBinop(dtype::f16); + testModuloBinop(dtype::f32); + testModuloBinop(dtype::f64); } TEST(TensorBinaryOpsTest, BitBinaryOperators) { - auto testBitBinops = [](dtype type) { - auto a = Tensor::fromVector({2, 1}, {0b0001, 0b1000}).astype(type); - auto b = Tensor::fromVector({2, 1}, {0b0010, 0b0100}).astype(type); - auto c = Tensor::fromVector({2, 1}, {0b0011, 0b1100}).astype(type); - auto d = Tensor::fromVector({2, 1}, {0b0110, 0b0110}).astype(type); - auto e = Tensor::fromVector({2, 1}, {0b1000, 0b0001}).astype(type); - auto g = Tensor::fromVector({2, 1}, {2, 1}).astype(type); - auto h = Tensor::fromVector({2, 1}, {0b1000, 0b1000}).astype(type); - auto z = Tensor::fromVector({2, 1}, {0b0000, 0b0000}).astype(type); - - ASSERT_TRUE(allClose((z & z), z)) << "dtype: " << type; - assertCommutativeBinop(a, b, std::bit_and<>(), z); - assertCommutativeBinop(z, b, std::bit_and<>(), z); - assertCommutativeBinop(d, b, std::bit_and<>(), b); - assertScalarTensorCommutativeBinop(0b0000, b, std::bit_and<>(), z); - assertScalarTensorCommutativeBinop(0b0110, b, std::bit_and<>(), b); - - ASSERT_TRUE(allClose((z | z), z)) << "dtype: " << type; - assertCommutativeBinop(a, z, std::bit_or<>(), a); - assertCommutativeBinop(z, b, std::bit_or<>(), b); - assertCommutativeBinop(a, b, std::bit_or<>(), c); - assertScalarTensorCommutativeBinop(0b0000, b, std::bit_or<>(), b); - assertScalarTensorCommutativeBinop(0b0110, b, std::bit_or<>(), d); - - ASSERT_TRUE(allClose((z ^ z), z)) << "dtype: " << type; - assertCommutativeBinop(a, z, std::bit_xor<>(), a); - assertCommutativeBinop(z, b, std::bit_xor<>(), b); - assertCommutativeBinop(a, b, std::bit_xor<>(), c); - assertCommutativeBinop(c, c, std::bit_xor<>(), z); - assertScalarTensorCommutativeBinop(0b0000, b, std::bit_xor<>(), b); - assertScalarTensorCommutativeBinop(0b1001, a, std::bit_xor<>(), e); - - // TODO test scalar input (need right/left_shift operator) - ASSERT_TRUE(allClose((z << z), z)) << "dtype: " << type; - 
ASSERT_TRUE(allClose((a << z), a)) << "dtype: " << type; - ASSERT_TRUE(allClose((z << a), z)) << "dtype: " << type; - ASSERT_TRUE(allClose((b << g), h)) << "dtype: " << type; - - ASSERT_TRUE(allClose((z >> z), z)) << "dtype: " << type; - ASSERT_TRUE(allClose((a >> z), a)) << "dtype: " << type; - ASSERT_TRUE(allClose((z >> a), z)) << "dtype: " << type; - ASSERT_TRUE(allClose((h >> g), b)) << "dtype: " << type; - }; - - applyToAllIntegralDtypes(testBitBinops); - // ArrayFire doesn't support bit ops for fps + auto testBitBinops = [](dtype type) { + auto a = Tensor::fromVector({2, 1}, {0b0001, 0b1000}).astype(type); + auto b = Tensor::fromVector({2, 1}, {0b0010, 0b0100}).astype(type); + auto c = Tensor::fromVector({2, 1}, {0b0011, 0b1100}).astype(type); + auto d = Tensor::fromVector({2, 1}, {0b0110, 0b0110}).astype(type); + auto e = Tensor::fromVector({2, 1}, {0b1000, 0b0001}).astype(type); + auto g = Tensor::fromVector({2, 1}, {2, 1}).astype(type); + auto h = Tensor::fromVector({2, 1}, {0b1000, 0b1000}).astype(type); + auto z = Tensor::fromVector({2, 1}, {0b0000, 0b0000}).astype(type); + + ASSERT_TRUE(allClose((z & z), z)) << "dtype: " << type; + assertCommutativeBinop(a, b, std::bit_and<>(), z); + assertCommutativeBinop(z, b, std::bit_and<>(), z); + assertCommutativeBinop(d, b, std::bit_and<>(), b); + assertScalarTensorCommutativeBinop(0b0000, b, std::bit_and<>(), z); + assertScalarTensorCommutativeBinop(0b0110, b, std::bit_and<>(), b); + + ASSERT_TRUE(allClose((z | z), z)) << "dtype: " << type; + assertCommutativeBinop(a, z, std::bit_or<>(), a); + assertCommutativeBinop(z, b, std::bit_or<>(), b); + assertCommutativeBinop(a, b, std::bit_or<>(), c); + assertScalarTensorCommutativeBinop(0b0000, b, std::bit_or<>(), b); + assertScalarTensorCommutativeBinop(0b0110, b, std::bit_or<>(), d); + + ASSERT_TRUE(allClose((z ^ z), z)) << "dtype: " << type; + assertCommutativeBinop(a, z, std::bit_xor<>(), a); + assertCommutativeBinop(z, b, std::bit_xor<>(), b); + 
assertCommutativeBinop(a, b, std::bit_xor<>(), c); + assertCommutativeBinop(c, c, std::bit_xor<>(), z); + assertScalarTensorCommutativeBinop(0b0000, b, std::bit_xor<>(), b); + assertScalarTensorCommutativeBinop(0b1001, a, std::bit_xor<>(), e); + + // TODO test scalar input (need right/left_shift operator) + ASSERT_TRUE(allClose((z << z), z)) << "dtype: " << type; + ASSERT_TRUE(allClose((a << z), a)) << "dtype: " << type; + ASSERT_TRUE(allClose((z << a), z)) << "dtype: " << type; + ASSERT_TRUE(allClose((b << g), h)) << "dtype: " << type; + + ASSERT_TRUE(allClose((z >> z), z)) << "dtype: " << type; + ASSERT_TRUE(allClose((a >> z), a)) << "dtype: " << type; + ASSERT_TRUE(allClose((z >> a), z)) << "dtype: " << type; + ASSERT_TRUE(allClose((h >> g), b)) << "dtype: " << type; + }; + + applyToAllIntegralDtypes(testBitBinops); + // ArrayFire doesn't support bit ops for fps } TEST(TensorBinaryOpsTest, BinaryOperatorIncompatibleShapes) { - auto testTensorIncompatibleShapes = [](dtype type, - const Tensor& lhs, - const Tensor& rhs) { - ASSERT_THROW((void)Values(lhs + rhs), std::invalid_argument) << "dtype: " << type; - ASSERT_THROW((void)Values(lhs - rhs), std::invalid_argument) << "dtype: " << type; - ASSERT_THROW((void)Values(lhs * rhs), std::invalid_argument) << "dtype: " << type; - ASSERT_THROW((void)Values(lhs / rhs), std::invalid_argument) << "dtype: " << type; - ASSERT_THROW((void)Values(lhs == rhs), std::invalid_argument) - << "dtype: " << type; - ASSERT_THROW((void)Values(lhs != rhs), std::invalid_argument) - << "dtype: " << type; - ASSERT_THROW((void)Values(lhs < rhs), std::invalid_argument) << "dtype: " << type; - ASSERT_THROW((void)Values(lhs <= rhs), std::invalid_argument) - << "dtype: " << type; - ASSERT_THROW((void)Values(lhs > rhs), std::invalid_argument) << "dtype: " << type; - ASSERT_THROW((void)Values(lhs >= rhs), std::invalid_argument) - << "dtype: " << type; - ASSERT_THROW((void)Values(lhs || rhs), std::invalid_argument) - << "dtype: " << type; - 
ASSERT_THROW((void)Values(lhs && rhs), std::invalid_argument) - << "dtype: " << type; - // TODO ArrayFire needs software impl for fp16 modulo on CUDA backend; - // bring this test back when supported. - if (type != dtype::f16) { - ASSERT_THROW((void)Values(lhs % rhs), std::invalid_argument) - << "dtype: " << type; - } - // these operators are generally not well-defined for fps - if (type != dtype::f16 && type != dtype::f32 && type != dtype::f64) { - ASSERT_THROW((void)Values(lhs | rhs), std::invalid_argument) - << "dtype: " << type; - ASSERT_THROW((void)Values(lhs ^ rhs), std::invalid_argument) - << "dtype: " << type; - ASSERT_THROW((void)Values(lhs << rhs), std::invalid_argument) - << "dtype: " << type; - ASSERT_THROW((void)Values(lhs >> rhs), std::invalid_argument) - << "dtype: " << type; - } - }; - - auto testTensorIncompatibleShapesForType = [&](dtype type) { - auto a = fl::rand({2, 2}, type); - auto tooManyAxises = fl::rand({4, 5, 6}, type); - auto tooFewAxises = fl::rand({3}, type); - auto diffDim = fl::rand({2, 3}, type); - testTensorIncompatibleShapes(type, a, tooManyAxises); - testTensorIncompatibleShapes(type, a, tooFewAxises); - testTensorIncompatibleShapes(type, a, diffDim); - }; - - applyToAllDtypes(testTensorIncompatibleShapesForType); + auto testTensorIncompatibleShapes = [](dtype type, + const Tensor& lhs, + const Tensor& rhs) { + ASSERT_THROW((void) Values(lhs + rhs), std::invalid_argument) << "dtype: " << type; + ASSERT_THROW((void) Values(lhs - rhs), std::invalid_argument) << "dtype: " << type; + ASSERT_THROW((void) Values(lhs * rhs), std::invalid_argument) << "dtype: " << type; + ASSERT_THROW((void) Values(lhs / rhs), std::invalid_argument) << "dtype: " << type; + ASSERT_THROW((void) Values(lhs == rhs), std::invalid_argument) + << "dtype: " << type; + ASSERT_THROW((void) Values(lhs != rhs), std::invalid_argument) + << "dtype: " << type; + ASSERT_THROW((void) Values(lhs < rhs), std::invalid_argument) << "dtype: " << type; + ASSERT_THROW((void) 
Values(lhs <= rhs), std::invalid_argument) + << "dtype: " << type; + ASSERT_THROW((void) Values(lhs > rhs), std::invalid_argument) << "dtype: " << type; + ASSERT_THROW((void) Values(lhs >= rhs), std::invalid_argument) + << "dtype: " << type; + ASSERT_THROW((void) Values(lhs || rhs), std::invalid_argument) + << "dtype: " << type; + ASSERT_THROW((void) Values(lhs && rhs), std::invalid_argument) + << "dtype: " << type; + // TODO ArrayFire needs software impl for fp16 modulo on CUDA backend; + // bring this test back when supported. + if(type != dtype::f16) { + ASSERT_THROW((void) Values(lhs % rhs), std::invalid_argument) + << "dtype: " << type; + } + // these operators are generally not well-defined for fps + if(type != dtype::f16 && type != dtype::f32 && type != dtype::f64) { + ASSERT_THROW((void) Values(lhs | rhs), std::invalid_argument) + << "dtype: " << type; + ASSERT_THROW((void) Values(lhs ^ rhs), std::invalid_argument) + << "dtype: " << type; + ASSERT_THROW((void) Values(lhs << rhs), std::invalid_argument) + << "dtype: " << type; + ASSERT_THROW((void) Values(lhs >> rhs), std::invalid_argument) + << "dtype: " << type; + } + }; + + auto testTensorIncompatibleShapesForType = [&](dtype type) { + auto a = fl::rand({2, 2}, type); + auto tooManyAxises = fl::rand({4, 5, 6}, type); + auto tooFewAxises = fl::rand({3}, type); + auto diffDim = fl::rand({2, 3}, type); + testTensorIncompatibleShapes(type, a, tooManyAxises); + testTensorIncompatibleShapes(type, a, tooFewAxises); + testTensorIncompatibleShapes(type, a, diffDim); + }; + + applyToAllDtypes(testTensorIncompatibleShapesForType); } TEST(TensorBinaryOpsTest, minimum) { - auto a = fl::full({3, 3}, 1); - auto b = fl::full({3, 3}, 2); - auto c = fl::minimum(a, b); - ASSERT_EQ(a.type(), c.type()); - ASSERT_TRUE(allClose(a, c)); - ASSERT_TRUE(allClose(fl::minimum(1, b).astype(a.type()), a)); - ASSERT_TRUE(allClose(fl::minimum(b, 1).astype(a.type()), a)); + auto a = fl::full({3, 3}, 1); + auto b = fl::full({3, 3}, 2); + 
auto c = fl::minimum(a, b); + ASSERT_EQ(a.type(), c.type()); + ASSERT_TRUE(allClose(a, c)); + ASSERT_TRUE(allClose(fl::minimum(1, b).astype(a.type()), a)); + ASSERT_TRUE(allClose(fl::minimum(b, 1).astype(a.type()), a)); } TEST(TensorBinaryOpsTest, maximum) { - auto a = fl::full({3, 3}, 1); - auto b = fl::full({3, 3}, 2); - auto c = fl::maximum(a, b); - ASSERT_EQ(b.type(), c.type()); - ASSERT_TRUE(allClose(b, c)); - ASSERT_TRUE(allClose(fl::maximum(1, b).astype(a.type()), b)); - ASSERT_TRUE(allClose(fl::maximum(b, 1).astype(a.type()), b)); + auto a = fl::full({3, 3}, 1); + auto b = fl::full({3, 3}, 2); + auto c = fl::maximum(a, b); + ASSERT_EQ(b.type(), c.type()); + ASSERT_TRUE(allClose(b, c)); + ASSERT_TRUE(allClose(fl::maximum(1, b).astype(a.type()), b)); + ASSERT_TRUE(allClose(fl::maximum(b, 1).astype(a.type()), b)); } using binaryOpFunc_t = Tensor (*)(const Tensor& lhs, const Tensor& rhs); TEST(TensorBinaryOpsTest, broadcasting) { - // Collection of {lhs, rhs, tileShapeLhs, tileShapeRhs} corresponding to - // broadcasting [lhs] to [rhs] by tiling by the the respective tileShapes - struct ShapeData { - Shape lhs; // broadcast from - Shape rhs; // broadcast to - Shape tileShapeLhs; - Shape tileShapeRhs; - }; - std::vector shapes = { - {{3, 1}, {3, 3}, {1, 3}, {1, 1}}, - {{3}, {3, 3}, {1, 3}, {1, 1}}, - {{3, 1, 4}, {3, 6, 4}, {1, 6, 1}, {1, 1, 1}}, - {{3, 1, 4, 1}, {3, 2, 4, 5}, {1, 2, 1, 5}, {1, 1, 1, 1}}, - {{1, 10}, {8, 10}, {8, 1}, {1, 1}}, - {{2, 1, 5, 1}, {2, 3, 5, 3}, {1, 3, 1, 3}, {1, 1, 1, 1}}, - {{3, 1, 2, 1}, {1, 4, 1, 5}, {1, 4, 1, 5}, {3, 1, 2, 1}}, - {{3, 2, 1}, {3, 1, 4, 1}, {1, 1, 4}, {1, 2, 1, 1}}}; - - std::unordered_map functions = { - {fl::minimum, "minimum"}, - {fl::maximum, "maximum"}, - {fl::power, "power"}, - {fl::add, "add"}, - {fl::add, "add"}, - {fl::sub, "sub"}, - {fl::mul, "mul"}, - {fl::div, "div"}, - {fl::eq, "eq"}, - {fl::neq, "neq"}, - {fl::lessThan, "lessThan"}, - {fl::lessThanEqual, "lessThanEqual"}, - {fl::greaterThan, 
"greaterThan"}, - {fl::greaterThanEqual, "greaterThanEqual"}, - {fl::logicalOr, "logicalOr"}, - {fl::logicalAnd, "logicalAnd"}, - {fl::mod, "mod"}, - {fl::bitwiseOr, "bitwiseOr"}, - {fl::bitwiseXor, "bitwiseXor"}, - {fl::lShift, "lShift"}, - {fl::rShift, "rShift"}}; - - auto doBinaryOp = [](const Tensor& lhs, - const Tensor& rhs, - const Shape& tileShapeLhs, - const Shape& tileShapeRhs, - binaryOpFunc_t func) -> std::pair { - assert(lhs.ndim() <= rhs.ndim()); - return { - func(lhs, rhs), func(tile(lhs, tileShapeLhs), tile(rhs, tileShapeRhs))}; - }; - - auto computeBroadcastShape = [](const Shape& lhsShape, - const Shape& rhsShape) -> Shape { - unsigned maxnDim = std::max(lhsShape.ndim(), rhsShape.ndim()); - Shape outShape{std::vector(maxnDim)}; - for (unsigned i = 0; i < maxnDim; ++i) { - if (i > lhsShape.ndim() - 1) { - outShape[i] = rhsShape[i]; - } else if (i > rhsShape.ndim() - 1) { - outShape[i] = lhsShape[i]; - } else if (lhsShape[i] == 1) { - outShape[i] = rhsShape[i]; - } else if (rhsShape[i] == 1) { - outShape[i] = lhsShape[i]; - } else if (lhsShape[i] == rhsShape[i]) { - outShape[i] = lhsShape[i]; - } else if (lhsShape[i] != rhsShape[i]) { - throw std::runtime_error( - "computeBroadcastShape - cannot broadcast shape"); - } + // Collection of {lhs, rhs, tileShapeLhs, tileShapeRhs} corresponding to + // broadcasting [lhs] to [rhs] by tiling by the the respective tileShapes + struct ShapeData { + Shape lhs; // broadcast from + Shape rhs; // broadcast to + Shape tileShapeLhs; + Shape tileShapeRhs; + }; + std::vector shapes = { + {{3, 1}, {3, 3}, {1, 3}, {1, 1}}, + {{3}, {3, 3}, {1, 3}, {1, 1}}, + {{3, 1, 4}, {3, 6, 4}, {1, 6, 1}, {1, 1, 1}}, + {{3, 1, 4, 1}, {3, 2, 4, 5}, {1, 2, 1, 5}, {1, 1, 1, 1}}, + {{1, 10}, {8, 10}, {8, 1}, {1, 1}}, + {{2, 1, 5, 1}, {2, 3, 5, 3}, {1, 3, 1, 3}, {1, 1, 1, 1}}, + {{3, 1, 2, 1}, {1, 4, 1, 5}, {1, 4, 1, 5}, {3, 1, 2, 1}}, + {{3, 2, 1}, {3, 1, 4, 1}, {1, 1, 4}, {1, 2, 1, 1}}}; + + std::unordered_map functions = { + 
{fl::minimum, "minimum"}, + {fl::maximum, "maximum"}, + {fl::power, "power"}, + {fl::add, "add"}, + {fl::add, "add"}, + {fl::sub, "sub"}, + {fl::mul, "mul"}, + {fl::div, "div"}, + {fl::eq, "eq"}, + {fl::neq, "neq"}, + {fl::lessThan, "lessThan"}, + {fl::lessThanEqual, "lessThanEqual"}, + {fl::greaterThan, "greaterThan"}, + {fl::greaterThanEqual, "greaterThanEqual"}, + {fl::logicalOr, "logicalOr"}, + {fl::logicalAnd, "logicalAnd"}, + {fl::mod, "mod"}, + {fl::bitwiseOr, "bitwiseOr"}, + {fl::bitwiseXor, "bitwiseXor"}, + {fl::lShift, "lShift"}, + {fl::rShift, "rShift"}}; + + auto doBinaryOp = [](const Tensor& lhs, + const Tensor& rhs, + const Shape& tileShapeLhs, + const Shape& tileShapeRhs, + binaryOpFunc_t func) -> std::pair { + assert(lhs.ndim() <= rhs.ndim()); + return { + func(lhs, rhs), func(tile(lhs, tileShapeLhs), tile(rhs, tileShapeRhs))}; + }; + + auto computeBroadcastShape = [](const Shape& lhsShape, + const Shape& rhsShape) -> Shape { + unsigned maxnDim = std::max(lhsShape.ndim(), rhsShape.ndim()); + Shape outShape{std::vector(maxnDim)}; + for(unsigned i = 0; i < maxnDim; ++i) { + if(i > lhsShape.ndim() - 1) { + outShape[i] = rhsShape[i]; + } else if(i > rhsShape.ndim() - 1) { + outShape[i] = lhsShape[i]; + } else if(lhsShape[i] == 1) { + outShape[i] = rhsShape[i]; + } else if(rhsShape[i] == 1) { + outShape[i] = lhsShape[i]; + } else if(lhsShape[i] == rhsShape[i]) { + outShape[i] = lhsShape[i]; + } else if(lhsShape[i] != rhsShape[i]) { + throw std::runtime_error( + "computeBroadcastShape - cannot broadcast shape" + ); + } + } + return outShape; + }; + + for(const auto& funcp : functions) { + for(auto& shapeData : shapes) { + auto lhs = ((fl::rand(shapeData.lhs) + 1) * 10).astype(fl::dtype::s32); + auto rhs = ((fl::rand(shapeData.rhs) + 1) * 10).astype(fl::dtype::s32); + + auto [actualOut, expectedOut] = doBinaryOp( + lhs, + rhs, + shapeData.tileShapeLhs, + shapeData.tileShapeRhs, + funcp.first + ); + + Shape expectedShape = 
computeBroadcastShape(shapeData.lhs, shapeData.rhs); + + std::stringstream ss; + ss << "lhs: " << shapeData.lhs << " rhs: " << shapeData.rhs + << " function: " << funcp.second; + auto testData = ss.str(); + + ASSERT_EQ(actualOut.shape(), expectedShape) << testData; + ASSERT_TRUE(allClose(actualOut, expectedOut)) << testData; + } + + // Scalar broadcasting + const double scalarVal = 4; + const Shape inShape = {2, 3, 4}; + const auto lhs = fl::rand(inShape).astype(fl::dtype::s32); + const auto rhs = fl::fromScalar(scalarVal, fl::dtype::s32); + const auto rhsTiled = fl::full(inShape, scalarVal, fl::dtype::s32); + ASSERT_TRUE(allClose(funcp.first(lhs, rhs), funcp.first(lhs, rhsTiled))); } - return outShape; - }; - - for (const auto& funcp : functions) { - for (auto& shapeData : shapes) { - auto lhs = ((fl::rand(shapeData.lhs) + 1) * 10).astype(fl::dtype::s32); - auto rhs = ((fl::rand(shapeData.rhs) + 1) * 10).astype(fl::dtype::s32); - - auto [actualOut, expectedOut] = doBinaryOp( - lhs, - rhs, - shapeData.tileShapeLhs, - shapeData.tileShapeRhs, - funcp.first); - - Shape expectedShape = computeBroadcastShape(shapeData.lhs, shapeData.rhs); - - std::stringstream ss; - ss << "lhs: " << shapeData.lhs << " rhs: " << shapeData.rhs - << " function: " << funcp.second; - auto testData = ss.str(); - - ASSERT_EQ(actualOut.shape(), expectedShape) << testData; - ASSERT_TRUE(allClose(actualOut, expectedOut)) << testData; - } - - // Scalar broadcasting - const double scalarVal = 4; - const Shape inShape = {2, 3, 4}; - const auto lhs = fl::rand(inShape).astype(fl::dtype::s32); - const auto rhs = fl::fromScalar(scalarVal, fl::dtype::s32); - const auto rhsTiled = fl::full(inShape, scalarVal, fl::dtype::s32); - ASSERT_TRUE(allClose(funcp.first(lhs, rhs), funcp.first(lhs, rhsTiled))); - } } TEST(TensorBinaryOpsTest, power) { - auto a = fl::full({3, 3}, 2.); - auto b = fl::full({3, 3}, 2.); - ASSERT_TRUE(allClose(fl::power(a, b), a * b)); + auto a = fl::full({3, 3}, 2.); + auto b = 
fl::full({3, 3}, 2.); + ASSERT_TRUE(allClose(fl::power(a, b), a * b)); } TEST(TensorBinaryOpsTest, powerDouble) { - auto a = fl::full({3, 3}, 2.); - ASSERT_TRUE(allClose(fl::power(a, 3), a * a * a)); + auto a = fl::full({3, 3}, 2.); + ASSERT_TRUE(allClose(fl::power(a, 3), a * a * a)); - auto b = fl::full({3, 3}, 2.); - ASSERT_TRUE( - allClose(fl::power(3, a), fl::full(b.shape(), 3 * 3, fl::dtype::f32))); + auto b = fl::full({3, 3}, 2.); + ASSERT_TRUE( + allClose(fl::power(3, a), fl::full(b.shape(), 3 * 3, fl::dtype::f32)) + ); } int main(int argc, char** argv) { - InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/tensor/TensorExtensionTest.cpp b/flashlight/fl/test/tensor/TensorExtensionTest.cpp index 7e1b6be..8f4f790 100644 --- a/flashlight/fl/test/tensor/TensorExtensionTest.cpp +++ b/flashlight/fl/test/tensor/TensorExtensionTest.cpp @@ -22,53 +22,57 @@ using namespace fl; // Extension interface class TestTensorExtension : public TensorExtension { - public: - static constexpr TensorExtensionType extensionType = - TensorExtensionType::Generic; +public: + static constexpr TensorExtensionType extensionType = + TensorExtensionType::Generic; - TestTensorExtension() = default; - virtual ~TestTensorExtension() = default; + TestTensorExtension() = default; + virtual ~TestTensorExtension() = default; - virtual Tensor testExtensionFunc(const Tensor& tensor) = 0; + virtual Tensor testExtensionFunc(const Tensor& tensor) = 0; }; // Specific extension implementation class TestArrayFireTensorExtension : public TestTensorExtension { - public: - Tensor testExtensionFunc(const Tensor& tensor) override { - return tensor + 1; - } - - bool isDataTypeSupported(const fl::dtype&) const override { - return true; - } +public: + Tensor testExtensionFunc(const Tensor& tensor) override { + return tensor + 1; + } + + bool isDataTypeSupported(const fl::dtype&) const 
override { + return true; + } }; // Op in API Tensor testExtensionFunc(const Tensor& tensor) { - return tensor.backend().getExtension().testExtensionFunc( - tensor); + return tensor.backend().getExtension().testExtensionFunc( + tensor + ); } FL_REGISTER_TENSOR_EXTENSION(TestArrayFireTensorExtension, ArrayFire); TEST(TensorExtensionTest, TestExtension) { - auto a = fl::rand({4, 5, 6}); + auto a = fl::rand({4, 5, 6}); - // TODO: this test only works with the ArrayFire backend - gate accordingly - if (Tensor().backendType() != TensorBackendType::ArrayFire) { - GTEST_SKIP() << "Flashlight not built with ArrayFire backend."; - } + // TODO: this test only works with the ArrayFire backend - gate accordingly + if(Tensor().backendType() != TensorBackendType::ArrayFire) { + GTEST_SKIP() << "Flashlight not built with ArrayFire backend."; + } - // TODO: add a fixture to check with available backends - ASSERT_TRUE(::fl::registerTensorExtension( - TensorBackendType::ArrayFire)); + // TODO: add a fixture to check with available backends + ASSERT_TRUE( + ::fl::registerTensorExtension( + TensorBackendType::ArrayFire + ) + ); - ASSERT_TRUE(allClose(testExtensionFunc(a), a + 1)); + ASSERT_TRUE(allClose(testExtensionFunc(a), a + 1)); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/tensor/TensorReductionTest.cpp b/flashlight/fl/test/tensor/TensorReductionTest.cpp index 04ce54b..b29b59b 100644 --- a/flashlight/fl/test/tensor/TensorReductionTest.cpp +++ b/flashlight/fl/test/tensor/TensorReductionTest.cpp @@ -17,367 +17,419 @@ using namespace ::testing; using namespace fl; TEST(TensorReductionTest, countNonzero) { - std::vector idxs = {0, 3, 4, 7, 24, 78}; - auto a = fl::full({10, 10}, 1, fl::dtype::u32); - for (const auto idx : idxs) { - a(idx / 10, idx % 10) = 0; - } - - ASSERT_TRUE(allClose( - 
fl::fromScalar(a.elements() - idxs.size(), a.type()), - fl::countNonzero(a))); - - std::vector sizes(a.shape().dim(0)); - for (unsigned i = 0; i < a.shape().dim(0); ++i) { - sizes[i] = - a.shape().dim(0) - fl::sum(a(fl::span, i) == 0, {0}).scalar(); - } - ASSERT_TRUE(allClose(Tensor::fromVector(sizes), Tensor::fromVector(sizes))); - - auto b = fl::full({2, 2, 2}, 1, fl::dtype::u32); - b(0, 1, 1) = 0; - b(1, 0, 1) = 0; - b(1, 1, 1) = 0; - ASSERT_TRUE(allClose( - fl::Tensor::fromVector({2, 2}, {2, 2, 1, 0}), - fl::countNonzero(b, {0}))); - ASSERT_TRUE(allClose( - fl::Tensor::fromVector({2}, {4, 1}), - fl::countNonzero(b, {0, 1}))); - ASSERT_TRUE( - allClose(fl::fromScalar(b.elements() - 3, b.type()), fl::countNonzero(b, {0, 1, 2}))); + std::vector idxs = {0, 3, 4, 7, 24, 78}; + auto a = fl::full({10, 10}, 1, fl::dtype::u32); + for(const auto idx : idxs) { + a(idx / 10, idx % 10) = 0; + } + + ASSERT_TRUE( + allClose( + fl::fromScalar(a.elements() - idxs.size(), a.type()), + fl::countNonzero(a) + ) + ); + + std::vector sizes(a.shape().dim(0)); + for(unsigned i = 0; i < a.shape().dim(0); ++i) { + sizes[i] = + a.shape().dim(0) - fl::sum(a(fl::span, i) == 0, {0}).scalar(); + } + ASSERT_TRUE(allClose(Tensor::fromVector(sizes), Tensor::fromVector(sizes))); + + auto b = fl::full({2, 2, 2}, 1, fl::dtype::u32); + b(0, 1, 1) = 0; + b(1, 0, 1) = 0; + b(1, 1, 1) = 0; + ASSERT_TRUE( + allClose( + fl::Tensor::fromVector({2, 2}, {2, 2, 1, 0}), + fl::countNonzero(b, {0}) + ) + ); + ASSERT_TRUE( + allClose( + fl::Tensor::fromVector({2}, {4, 1}), + fl::countNonzero(b, {0, 1}) + ) + ); + ASSERT_TRUE( + allClose(fl::fromScalar(b.elements() - 3, b.type()), fl::countNonzero(b, {0, 1, 2})) + ); } TEST(TensorReductionTest, amin) { - auto a = fl::rand({4, 5, 6}); - const float val = -300; - a(2, 3, 4) = val; - ASSERT_EQ(fl::amin(a).shape(), Shape({})); - ASSERT_EQ(fl::amin(a).elements(), 1); - ASSERT_EQ(fl::amin(a).scalar(), val); - auto b = fl::rand({4, 4}); - b(1, 1) = val; - 
ASSERT_EQ(fl::amin(b, {0}).shape(), Shape({4})); - ASSERT_EQ(fl::amin(b, {0}, /* keepDims = */ true).shape(), Shape({1, 4})); - ASSERT_EQ(fl::amin(b, {0})(1).scalar(), val); - ASSERT_EQ(fl::amin(b, {1})(1).scalar(), val); - auto q = fl::amin(fl::full({5, 5, 5, 5}, 1)); - ASSERT_EQ(q.shape(), Shape({})); - ASSERT_EQ(q.elements(), 1); - ASSERT_EQ(q.scalar(), 1); - - const float v = 3.14; - auto s = fl::amin(fl::fromScalar(v)); - ASSERT_EQ(s.shape(), Shape()); - ASSERT_EQ(s.scalar(), v); - ASSERT_EQ(fl::amin(fl::fromScalar(v), {0}).shape(), Shape()); + auto a = fl::rand({4, 5, 6}); + const float val = -300; + a(2, 3, 4) = val; + ASSERT_EQ(fl::amin(a).shape(), Shape({})); + ASSERT_EQ(fl::amin(a).elements(), 1); + ASSERT_EQ(fl::amin(a).scalar(), val); + auto b = fl::rand({4, 4}); + b(1, 1) = val; + ASSERT_EQ(fl::amin(b, {0}).shape(), Shape({4})); + ASSERT_EQ(fl::amin(b, {0}, /* keepDims = */ true).shape(), Shape({1, 4})); + ASSERT_EQ(fl::amin(b, {0})(1).scalar(), val); + ASSERT_EQ(fl::amin(b, {1})(1).scalar(), val); + auto q = fl::amin(fl::full({5, 5, 5, 5}, 1)); + ASSERT_EQ(q.shape(), Shape({})); + ASSERT_EQ(q.elements(), 1); + ASSERT_EQ(q.scalar(), 1); + + const float v = 3.14; + auto s = fl::amin(fl::fromScalar(v)); + ASSERT_EQ(s.shape(), Shape()); + ASSERT_EQ(s.scalar(), v); + ASSERT_EQ(fl::amin(fl::fromScalar(v), {0}).shape(), Shape()); } TEST(TensorReductionTest, amax) { - auto a = fl::rand({4, 5, 6}); - const float val = 300; - a(2, 3, 4) = val; - ASSERT_EQ(fl::amax(a).shape(), Shape({})); - ASSERT_EQ(fl::amax(a).elements(), 1); - ASSERT_EQ(fl::amax(a).scalar(), val); - auto b = fl::rand({4, 4}); - b(1, 1) = val; - ASSERT_EQ(fl::amax(b, {0}).shape(), Shape({4})); - ASSERT_EQ(fl::amax(b, {0}, /* keepDims = */ true).shape(), Shape({1, 4})); - ASSERT_EQ(fl::amax(b, {0})(1).scalar(), val); - ASSERT_EQ(fl::amax(b, {1})(1).scalar(), val); - auto q = fl::amax(fl::full({5, 5, 5, 5}, 1)); - ASSERT_EQ(q.shape(), Shape({})); - ASSERT_EQ(q.elements(), 1); - 
ASSERT_EQ(q.scalar(), 1); - - const float v = 3.14; - auto s = fl::amax(fl::fromScalar(v)); - ASSERT_EQ(s.shape(), Shape()); - ASSERT_EQ(s.scalar(), v); - ASSERT_EQ(fl::amax(fl::fromScalar(v), {0}).shape(), Shape()); + auto a = fl::rand({4, 5, 6}); + const float val = 300; + a(2, 3, 4) = val; + ASSERT_EQ(fl::amax(a).shape(), Shape({})); + ASSERT_EQ(fl::amax(a).elements(), 1); + ASSERT_EQ(fl::amax(a).scalar(), val); + auto b = fl::rand({4, 4}); + b(1, 1) = val; + ASSERT_EQ(fl::amax(b, {0}).shape(), Shape({4})); + ASSERT_EQ(fl::amax(b, {0}, /* keepDims = */ true).shape(), Shape({1, 4})); + ASSERT_EQ(fl::amax(b, {0})(1).scalar(), val); + ASSERT_EQ(fl::amax(b, {1})(1).scalar(), val); + auto q = fl::amax(fl::full({5, 5, 5, 5}, 1)); + ASSERT_EQ(q.shape(), Shape({})); + ASSERT_EQ(q.elements(), 1); + ASSERT_EQ(q.scalar(), 1); + + const float v = 3.14; + auto s = fl::amax(fl::fromScalar(v)); + ASSERT_EQ(s.shape(), Shape()); + ASSERT_EQ(s.scalar(), v); + ASSERT_EQ(fl::amax(fl::fromScalar(v), {0}).shape(), Shape()); } TEST(TensorReductionTest, argmin) { - Tensor in = Tensor::fromVector({2, 3}, {4, 8, 6, 3, 5, 9}); - auto a0 = fl::argmin(in, 0); - auto a1 = fl::argmin(in, 1); - - ASSERT_EQ(a0.shape(), Shape({in.dim(1)})); - ASSERT_EQ(a1.shape(), Shape({in.dim(0)})); - ASSERT_TRUE(allClose(a0, Tensor::fromVector({3}, {0, 1, 0}))); - ASSERT_TRUE(allClose(a1, Tensor::fromVector({2}, {0, 1}))); - ASSERT_EQ( - fl::argmin(in, 0, /* keepDims = */ true).shape(), Shape({1, in.dim(1)})); - ASSERT_EQ( - fl::argmin(in, 1, /* keepDims = */ true).shape(), Shape({in.dim(0), 1})); + Tensor in = Tensor::fromVector({2, 3}, {4, 8, 6, 3, 5, 9}); + auto a0 = fl::argmin(in, 0); + auto a1 = fl::argmin(in, 1); + + ASSERT_EQ(a0.shape(), Shape({in.dim(1)})); + ASSERT_EQ(a1.shape(), Shape({in.dim(0)})); + ASSERT_TRUE(allClose(a0, Tensor::fromVector({3}, {0, 1, 0}))); + ASSERT_TRUE(allClose(a1, Tensor::fromVector({2}, {0, 1}))); + ASSERT_EQ( + fl::argmin(in, 0, /* keepDims = */ true).shape(), + Shape({1, 
in.dim(1)}) + ); + ASSERT_EQ( + fl::argmin(in, 1, /* keepDims = */ true).shape(), + Shape({in.dim(0), 1}) + ); } TEST(TensorReductionTest, argmax) { - Tensor in = Tensor::fromVector({2, 3}, {4, 8, 6, 3, 5, 9}); - auto a0 = fl::argmax(in, 0); - auto a1 = fl::argmax(in, 1); - - ASSERT_EQ(a0.shape(), Shape({in.dim(1)})); - ASSERT_EQ(a1.shape(), Shape({in.dim(0)})); - ASSERT_TRUE(allClose(a0, Tensor::fromVector({3}, {1, 0, 1}))); - ASSERT_TRUE(allClose(a1, Tensor::fromVector({2}, {1, 2}))); - ASSERT_EQ( - fl::argmax(in, 0, /* keepDims = */ true).shape(), Shape({1, in.dim(1)})); - ASSERT_EQ( - fl::argmax(in, 1, /* keepDims = */ true).shape(), Shape({in.dim(0), 1})); + Tensor in = Tensor::fromVector({2, 3}, {4, 8, 6, 3, 5, 9}); + auto a0 = fl::argmax(in, 0); + auto a1 = fl::argmax(in, 1); + + ASSERT_EQ(a0.shape(), Shape({in.dim(1)})); + ASSERT_EQ(a1.shape(), Shape({in.dim(0)})); + ASSERT_TRUE(allClose(a0, Tensor::fromVector({3}, {1, 0, 1}))); + ASSERT_TRUE(allClose(a1, Tensor::fromVector({2}, {1, 2}))); + ASSERT_EQ( + fl::argmax(in, 0, /* keepDims = */ true).shape(), + Shape({1, in.dim(1)}) + ); + ASSERT_EQ( + fl::argmax(in, 1, /* keepDims = */ true).shape(), + Shape({in.dim(0), 1}) + ); } TEST(TensorReductionTest, min) { - Tensor in = Tensor::fromVector({2, 3}, {4, 8, 6, 3, 5, 9}); - Tensor values, indices; - fl::min(values, indices, in, 0); - ASSERT_EQ(indices.shape(), Shape({in.dim(1)})); - ASSERT_TRUE(allClose(indices, Tensor::fromVector({3}, {0, 1, 0}))); - for (unsigned i = 0; i < values.elements(); ++i) { - ASSERT_TRUE(allClose(values.flat(i), in(fl::span, i)(indices(i)))); - } - - fl::min(values, indices, in, 1); - ASSERT_EQ(indices.shape(), Shape({in.dim(0)})); - ASSERT_TRUE(allClose(indices, Tensor::fromVector({2}, {0, 1}))); - for (unsigned i = 0; i < values.elements(); ++i) { - ASSERT_TRUE(allClose(values.flat(i), in(i)(indices(i)))); - } - - fl::min(values, indices, in, 0, /* keepDims = */ true); - ASSERT_EQ(values.shape(), Shape({1, in.dim(1)})); - - 
fl::min(values, indices, in, 1, /* keepDims = */ true); - ASSERT_EQ(values.shape(), Shape({in.dim(0), 1})); + Tensor in = Tensor::fromVector({2, 3}, {4, 8, 6, 3, 5, 9}); + Tensor values, indices; + fl::min(values, indices, in, 0); + ASSERT_EQ(indices.shape(), Shape({in.dim(1)})); + ASSERT_TRUE(allClose(indices, Tensor::fromVector({3}, {0, 1, 0}))); + for(unsigned i = 0; i < values.elements(); ++i) { + ASSERT_TRUE(allClose(values.flat(i), in(fl::span, i)(indices(i)))); + } + + fl::min(values, indices, in, 1); + ASSERT_EQ(indices.shape(), Shape({in.dim(0)})); + ASSERT_TRUE(allClose(indices, Tensor::fromVector({2}, {0, 1}))); + for(unsigned i = 0; i < values.elements(); ++i) { + ASSERT_TRUE(allClose(values.flat(i), in(i)(indices(i)))); + } + + fl::min(values, indices, in, 0, /* keepDims = */ true); + ASSERT_EQ(values.shape(), Shape({1, in.dim(1)})); + + fl::min(values, indices, in, 1, /* keepDims = */ true); + ASSERT_EQ(values.shape(), Shape({in.dim(0), 1})); } TEST(TensorReductionTest, max) { - Tensor in = Tensor::fromVector({2, 3}, {4, 8, 6, 3, 5, 9}); - Tensor values, indices; - fl::max(values, indices, in, 0); - ASSERT_EQ(indices.shape(), Shape({in.dim(1)})); - ASSERT_TRUE(allClose(indices, Tensor::fromVector({3}, {1, 0, 1}))); - for (unsigned i = 0; i < values.elements(); ++i) { - ASSERT_TRUE(allClose(values.flat(i), in(fl::span, i)(indices(i)))); - } - - fl::max(values, indices, in, 1); - ASSERT_EQ(indices.shape(), Shape({in.dim(0)})); - ASSERT_TRUE(allClose(indices, Tensor::fromVector({2}, {1, 2}))); - for (unsigned i = 0; i < values.elements(); ++i) { - ASSERT_TRUE(allClose(values.flat(i), in(i)(indices(i)))); - } - - fl::max(values, indices, in, 0, /* keepDims = */ true); - ASSERT_EQ(values.shape(), Shape({1, in.dim(1)})); - - fl::max(values, indices, in, 1, /* keepDims = */ true); - ASSERT_EQ(values.shape(), Shape({in.dim(0), 1})); + Tensor in = Tensor::fromVector({2, 3}, {4, 8, 6, 3, 5, 9}); + Tensor values, indices; + fl::max(values, indices, in, 0); + 
ASSERT_EQ(indices.shape(), Shape({in.dim(1)})); + ASSERT_TRUE(allClose(indices, Tensor::fromVector({3}, {1, 0, 1}))); + for(unsigned i = 0; i < values.elements(); ++i) { + ASSERT_TRUE(allClose(values.flat(i), in(fl::span, i)(indices(i)))); + } + + fl::max(values, indices, in, 1); + ASSERT_EQ(indices.shape(), Shape({in.dim(0)})); + ASSERT_TRUE(allClose(indices, Tensor::fromVector({2}, {1, 2}))); + for(unsigned i = 0; i < values.elements(); ++i) { + ASSERT_TRUE(allClose(values.flat(i), in(i)(indices(i)))); + } + + fl::max(values, indices, in, 0, /* keepDims = */ true); + ASSERT_EQ(values.shape(), Shape({1, in.dim(1)})); + + fl::max(values, indices, in, 1, /* keepDims = */ true); + ASSERT_EQ(values.shape(), Shape({in.dim(0), 1})); } TEST(TensorReductionTest, cumsum) { - int max = 30; - auto a = fl::tile(fl::arange(1, max), {1, 2}); - - auto ref = fl::arange(1, max); - for (int i = 1; i < max - 1; ++i) { - ref += fl::concatenate({fl::full({i}, 0), fl::arange(1, max - i)}); - } - - ASSERT_TRUE(allClose(fl::cumsum(a, 0), fl::tile(ref, {1, 2}))); - ASSERT_TRUE(allClose( - fl::cumsum(a, 1), - fl::concatenate( - {fl::arange(1, max), 2 * fl::arange(1, max)}, /* axis = */ 1))); + int max = 30; + auto a = fl::tile(fl::arange(1, max), {1, 2}); + + auto ref = fl::arange(1, max); + for(int i = 1; i < max - 1; ++i) { + ref += fl::concatenate({fl::full({i}, 0), fl::arange(1, max - i)}); + } + + ASSERT_TRUE(allClose(fl::cumsum(a, 0), fl::tile(ref, {1, 2}))); + ASSERT_TRUE( + allClose( + fl::cumsum(a, 1), + fl::concatenate( + {fl::arange(1, max), 2 * fl::arange(1, max)}, /* axis = */ + 1 + ) + ) + ); } TEST(TensorReductionTest, sum) { - auto t = fl::full({3, 4, 5, 6}, 1.0); - ASSERT_TRUE(allClose(fl::sum(t, {0}), fl::full({4, 5, 6}, 3.0))); - ASSERT_TRUE( - allClose(fl::sum(t, {1, 2}), fl::full({3, 6}, 4 * 5, fl::dtype::f32))); - auto res = fl::sum( - fl::sum(t, {2}, /* keepDims = */ true), {1}, /* keepDims = */ true); - ASSERT_EQ(res.shape(), Shape({t.dim(0), 1, 1, t.dim(3)})); - 
ASSERT_TRUE( - allClose(fl::reshape(res, {t.dim(0), t.dim(3)}), fl::sum(t, {2, 1}))); - - unsigned dim = 5; - auto q = fl::sum(fl::full({dim, dim, dim, dim}, 1)); - ASSERT_EQ(q.shape(), Shape({})); - ASSERT_EQ(q.elements(), 1); - ASSERT_EQ(q.scalar(), dim * dim * dim * dim); - - ASSERT_TRUE(allClose( - fl::sum(fl::sum(q, {0, 1, 2}), {0}), - fl::fromScalar(dim * dim * dim * dim, fl::dtype::s32))); + auto t = fl::full({3, 4, 5, 6}, 1.0); + ASSERT_TRUE(allClose(fl::sum(t, {0}), fl::full({4, 5, 6}, 3.0))); + ASSERT_TRUE( + allClose(fl::sum(t, {1, 2}), fl::full({3, 6}, 4 * 5, fl::dtype::f32)) + ); + auto res = fl::sum( + fl::sum(t, {2}, /* keepDims = */ true), + {1}, /* keepDims = */ + true + ); + ASSERT_EQ(res.shape(), Shape({t.dim(0), 1, 1, t.dim(3)})); + ASSERT_TRUE( + allClose(fl::reshape(res, {t.dim(0), t.dim(3)}), fl::sum(t, {2, 1})) + ); + + unsigned dim = 5; + auto q = fl::sum(fl::full({dim, dim, dim, dim}, 1)); + ASSERT_EQ(q.shape(), Shape({})); + ASSERT_EQ(q.elements(), 1); + ASSERT_EQ(q.scalar(), dim * dim * dim * dim); + + ASSERT_TRUE( + allClose( + fl::sum(fl::sum(q, {0, 1, 2}), {0}), + fl::fromScalar(dim * dim * dim * dim, fl::dtype::s32) + ) + ); } TEST(TensorReductionTest, mean) { - auto r = fl::rand({8, 7, 6}); - ASSERT_NEAR(fl::mean(r).scalar(), 0.5, 0.05); - ASSERT_EQ( - fl::mean(r, {0, 1}, /* keepDims = */ true).shape(), Shape({1, 1, 6})); - - auto s = fl::full({5, 6, 7}, 1); - ASSERT_TRUE(allClose(fl::mean(s, {0}), fl::full({6, 7}, 1.))); - - auto a = fl::mean(fl::full({5, 5, 5, 5}, 1)); - ASSERT_EQ(a.shape(), Shape({})); - ASSERT_EQ(a.elements(), 1); - ASSERT_EQ(a.scalar(), 1.); - - // TODO: fixture this - const float v = 3.14; - auto q = fl::mean(fl::fromScalar(v)); - ASSERT_EQ(q.shape(), Shape()); - ASSERT_EQ(q.scalar(), v); - ASSERT_EQ(fl::mean(fl::fromScalar(v), {0}).shape(), Shape()); + auto r = fl::rand({8, 7, 6}); + ASSERT_NEAR(fl::mean(r).scalar(), 0.5, 0.05); + ASSERT_EQ( + fl::mean(r, {0, 1}, /* keepDims = */ true).shape(), + Shape({1, 1, 
6}) + ); + + auto s = fl::full({5, 6, 7}, 1); + ASSERT_TRUE(allClose(fl::mean(s, {0}), fl::full({6, 7}, 1.))); + + auto a = fl::mean(fl::full({5, 5, 5, 5}, 1)); + ASSERT_EQ(a.shape(), Shape({})); + ASSERT_EQ(a.elements(), 1); + ASSERT_EQ(a.scalar(), 1.); + + // TODO: fixture this + const float v = 3.14; + auto q = fl::mean(fl::fromScalar(v)); + ASSERT_EQ(q.shape(), Shape()); + ASSERT_EQ(q.scalar(), v); + ASSERT_EQ(fl::mean(fl::fromScalar(v), {0}).shape(), Shape()); } TEST(TensorReductionTest, median) { - auto a = Tensor::fromVector({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); - ASSERT_EQ(fl::median(a).scalar(), 4.5); - ASSERT_TRUE(allClose(fl::median(a, {0}), fl::fromScalar(4.5))); - ASSERT_EQ(fl::median(fl::rand({5, 6, 7, 8}), {1, 2}).shape(), Shape({5, 8})); - ASSERT_EQ( - fl::median(fl::rand({5, 6, 7, 8}), {1, 2}, /* keepDims = */ true).shape(), - Shape({5, 1, 1, 8})); - - auto b = fl::median(fl::full({5, 5, 5, 5}, 1)); - ASSERT_EQ(b.shape(), Shape({})); - ASSERT_EQ(b.elements(), 1); - ASSERT_EQ(b.scalar(), 1.); - - const float v = 3.14; - auto q = fl::median(fl::fromScalar(v)); - ASSERT_EQ(q.shape(), Shape()); - ASSERT_EQ(q.scalar(), v); - ASSERT_EQ(fl::median(fl::fromScalar(v), {0}).shape(), Shape()); + auto a = Tensor::fromVector({0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + ASSERT_EQ(fl::median(a).scalar(), 4.5); + ASSERT_TRUE(allClose(fl::median(a, {0}), fl::fromScalar(4.5))); + ASSERT_EQ(fl::median(fl::rand({5, 6, 7, 8}), {1, 2}).shape(), Shape({5, 8})); + ASSERT_EQ( + fl::median(fl::rand({5, 6, 7, 8}), {1, 2}, /* keepDims = */ true).shape(), + Shape({5, 1, 1, 8}) + ); + + auto b = fl::median(fl::full({5, 5, 5, 5}, 1)); + ASSERT_EQ(b.shape(), Shape({})); + ASSERT_EQ(b.elements(), 1); + ASSERT_EQ(b.scalar(), 1.); + + const float v = 3.14; + auto q = fl::median(fl::fromScalar(v)); + ASSERT_EQ(q.shape(), Shape()); + ASSERT_EQ(q.scalar(), v); + ASSERT_EQ(fl::median(fl::fromScalar(v), {0}).shape(), Shape()); } TEST(TensorReductionTest, var) { - auto r = fl::rand({7, 8, 9}); - auto 
varAll = fl::var(r); - ASSERT_NEAR(varAll.scalar(), 0.08333, 0.01); - ASSERT_EQ(varAll.shape(), Shape({})); - ASSERT_EQ(varAll.elements(), 1); - - ASSERT_EQ( - fl::var(r, {0, 1}, /* bias = */ false, /* keepDims = */ true).shape(), - Shape({1, 1, 9})); - - auto s = fl::full({5, 6, 7}, 1); - ASSERT_TRUE(allClose(fl::var(s, {0}), fl::full({6, 7}, 0.))); - auto a = fl::rand({5, 5}); - ASSERT_TRUE(allClose(fl::var(a), fl::var(a, {0, 1}))); - - const float v = 3.14; - auto q = fl::var(fl::fromScalar(v)); - ASSERT_EQ(q.shape(), Shape()); - ASSERT_EQ(q.scalar(), 0); - ASSERT_EQ(fl::var(fl::fromScalar(v), {0}).shape(), Shape()); + auto r = fl::rand({7, 8, 9}); + auto varAll = fl::var(r); + ASSERT_NEAR(varAll.scalar(), 0.08333, 0.01); + ASSERT_EQ(varAll.shape(), Shape({})); + ASSERT_EQ(varAll.elements(), 1); + + ASSERT_EQ( + fl::var(r, {0, 1}, /* bias = */ false, /* keepDims = */ true).shape(), + Shape({1, 1, 9}) + ); + + auto s = fl::full({5, 6, 7}, 1); + ASSERT_TRUE(allClose(fl::var(s, {0}), fl::full({6, 7}, 0.))); + auto a = fl::rand({5, 5}); + ASSERT_TRUE(allClose(fl::var(a), fl::var(a, {0, 1}))); + + const float v = 3.14; + auto q = fl::var(fl::fromScalar(v)); + ASSERT_EQ(q.shape(), Shape()); + ASSERT_EQ(q.scalar(), 0); + ASSERT_EQ(fl::var(fl::fromScalar(v), {0}).shape(), Shape()); } TEST(TensorReductionTest, std) { - auto r = fl::rand({7, 8, 9}); - ASSERT_NEAR(fl::std(r).scalar(), 0.2886, 0.005); - ASSERT_EQ( - fl::std(r, {0, 1}, /* keepDims = */ true).shape(), Shape({1, 1, 9})); - - auto s = fl::full({5, 6, 7}, 1); - ASSERT_TRUE(allClose(fl::std(s, {0}), fl::full({6, 7}, 0.))); - ASSERT_TRUE(allClose(fl::std(s, {1}), fl::sqrt(fl::var(s, {1})))); - - const float v = 3.14; - auto q = fl::std(fl::fromScalar(v)); - ASSERT_EQ(q.shape(), Shape()); - ASSERT_EQ(q.scalar(), 0); - ASSERT_EQ(fl::std(fl::fromScalar(v), {0}).shape(), Shape()); + auto r = fl::rand({7, 8, 9}); + ASSERT_NEAR(fl::std(r).scalar(), 0.2886, 0.005); + ASSERT_EQ( + fl::std(r, {0, 1}, /* keepDims = */ 
true).shape(), + Shape({1, 1, 9}) + ); + + auto s = fl::full({5, 6, 7}, 1); + ASSERT_TRUE(allClose(fl::std(s, {0}), fl::full({6, 7}, 0.))); + ASSERT_TRUE(allClose(fl::std(s, {1}), fl::sqrt(fl::var(s, {1})))); + + const float v = 3.14; + auto q = fl::std(fl::fromScalar(v)); + ASSERT_EQ(q.shape(), Shape()); + ASSERT_EQ(q.scalar(), 0); + ASSERT_EQ(fl::std(fl::fromScalar(v), {0}).shape(), Shape()); } TEST(TensorReductionTest, norm) { - auto r = fl::full({7, 8, 9}, 1); - auto normAll = fl::norm(r); - ASSERT_FLOAT_EQ(normAll.scalar(), std::sqrt(7 * 8 * 9)); - ASSERT_EQ(normAll.shape(), Shape({})); - ASSERT_EQ(normAll.elements(), 1); - ASSERT_FLOAT_EQ( - fl::norm(fl::full({5, 5}, 1.)).scalar(), std::sqrt(5 * 5)); - ASSERT_EQ( - fl::norm(r, {0, 1}, /* p = */ 2, /* keepDims = */ true).shape(), - Shape({1, 1, 9})); - - ASSERT_FLOAT_EQ(fl::norm(r, {0}).scalar(), std::sqrt(7)); - - const float v = 3.14; - auto q = fl::norm(fl::fromScalar(v)); - ASSERT_EQ(q.shape(), Shape()); - ASSERT_NEAR(q.scalar(), 3.14, 1e-4); - ASSERT_EQ(fl::norm(fl::fromScalar(v), {0}).shape(), Shape()); + auto r = fl::full({7, 8, 9}, 1); + auto normAll = fl::norm(r); + ASSERT_FLOAT_EQ(normAll.scalar(), std::sqrt(7 * 8 * 9)); + ASSERT_EQ(normAll.shape(), Shape({})); + ASSERT_EQ(normAll.elements(), 1); + ASSERT_FLOAT_EQ( + fl::norm(fl::full({5, 5}, 1.)).scalar(), + std::sqrt(5 * 5) + ); + ASSERT_EQ( + fl::norm(r, {0, 1}, /* p = */ 2, /* keepDims = */ true).shape(), + Shape({1, 1, 9}) + ); + + ASSERT_FLOAT_EQ(fl::norm(r, {0}).scalar(), std::sqrt(7)); + + const float v = 3.14; + auto q = fl::norm(fl::fromScalar(v)); + ASSERT_EQ(q.shape(), Shape()); + ASSERT_NEAR(q.scalar(), 3.14, 1e-4); + ASSERT_EQ(fl::norm(fl::fromScalar(v), {0}).shape(), Shape()); } TEST(TensorReductionTest, any) { - using fl::dtype; - auto t = Tensor::fromVector({3, 3}, {1, 0, 0, 0, 0, 0, 0, 0, 1}); - auto anyAll = fl::any(t); - ASSERT_EQ(anyAll.shape(), Shape({})); - ASSERT_EQ(anyAll.elements(), 1); - ASSERT_TRUE(anyAll.scalar()); - 
ASSERT_TRUE(allClose( - fl::any(t, {0}), - Tensor::fromVector({1, 0, 1}).astype(dtype::b8))); - ASSERT_TRUE(allClose(fl::any(t, {0, 1}), fl::fromScalar(true, dtype::b8))); - ASSERT_FALSE(fl::any(Tensor::fromVector({0, 0, 0})).scalar()); - - auto keptDims = fl::any( - fl::any(t, {1}, /* keepDims = */ true), {0}, /* keepDims = */ true); - ASSERT_EQ(keptDims.shape(), Shape({1, 1})); - ASSERT_EQ(keptDims.scalar(), fl::any(t, {0, 1}).scalar()); - auto q = fl::any(fl::full({5, 5, 5, 5}, 1)); - ASSERT_EQ(q.shape(), Shape({})); - ASSERT_EQ(q.elements(), 1); - ASSERT_EQ(q.scalar(), true); - - const float v = 3.14; - auto r = fl::any(fl::fromScalar(v)); - ASSERT_EQ(r.shape(), Shape()); - ASSERT_TRUE(r.scalar()); - ASSERT_EQ(fl::any(fl::fromScalar(v), {0}).shape(), Shape()); + using fl::dtype; + auto t = Tensor::fromVector({3, 3}, {1, 0, 0, 0, 0, 0, 0, 0, 1}); + auto anyAll = fl::any(t); + ASSERT_EQ(anyAll.shape(), Shape({})); + ASSERT_EQ(anyAll.elements(), 1); + ASSERT_TRUE(anyAll.scalar()); + ASSERT_TRUE( + allClose( + fl::any(t, {0}), + Tensor::fromVector({1, 0, 1}).astype(dtype::b8) + ) + ); + ASSERT_TRUE(allClose(fl::any(t, {0, 1}), fl::fromScalar(true, dtype::b8))); + ASSERT_FALSE(fl::any(Tensor::fromVector({0, 0, 0})).scalar()); + + auto keptDims = fl::any( + fl::any(t, {1}, /* keepDims = */ true), + {0}, /* keepDims = */ + true + ); + ASSERT_EQ(keptDims.shape(), Shape({1, 1})); + ASSERT_EQ(keptDims.scalar(), fl::any(t, {0, 1}).scalar()); + auto q = fl::any(fl::full({5, 5, 5, 5}, 1)); + ASSERT_EQ(q.shape(), Shape({})); + ASSERT_EQ(q.elements(), 1); + ASSERT_EQ(q.scalar(), true); + + const float v = 3.14; + auto r = fl::any(fl::fromScalar(v)); + ASSERT_EQ(r.shape(), Shape()); + ASSERT_TRUE(r.scalar()); + ASSERT_EQ(fl::any(fl::fromScalar(v), {0}).shape(), Shape()); } TEST(TensorReductionTest, all) { - using fl::dtype; - auto t = Tensor::fromVector({3, 3}, {1, 0, 0, 0, 0, 0, 0, 0, 1}); - auto allAll = fl::all(t); - ASSERT_EQ(allAll.shape(), Shape({})); - 
ASSERT_EQ(allAll.elements(), 1); - ASSERT_FALSE(allAll.scalar()); - ASSERT_TRUE(allClose( - fl::all(t, {0}), - Tensor::fromVector({0, 0, 0}).astype(dtype::b8))); - ASSERT_TRUE(allClose(fl::all(t, {0, 1}), fl::fromScalar(false, dtype::b8))); - ASSERT_TRUE(fl::all(Tensor::fromVector({1, 1, 1})).scalar()); - - auto keptDims = fl::all( - fl::all(t, {1}, /* keepDims = */ true), {0}, /* keepDims = */ true); - ASSERT_EQ(keptDims.shape(), Shape({1, 1})); - ASSERT_EQ(keptDims.scalar(), fl::all(t, {0, 1}).scalar()); - auto q = fl::all(fl::full({5, 5, 5, 5}, 1)); - ASSERT_EQ(q.shape(), Shape({})); - ASSERT_EQ(q.elements(), 1); - ASSERT_EQ(q.scalar(), true); - - const float v = 3.14; - auto a = fl::all(fl::fromScalar(v)); - ASSERT_EQ(a.shape(), Shape()); - ASSERT_TRUE(a.scalar()); - ASSERT_EQ(fl::all(fl::fromScalar(v), {0}).shape(), Shape()); + using fl::dtype; + auto t = Tensor::fromVector({3, 3}, {1, 0, 0, 0, 0, 0, 0, 0, 1}); + auto allAll = fl::all(t); + ASSERT_EQ(allAll.shape(), Shape({})); + ASSERT_EQ(allAll.elements(), 1); + ASSERT_FALSE(allAll.scalar()); + ASSERT_TRUE( + allClose( + fl::all(t, {0}), + Tensor::fromVector({0, 0, 0}).astype(dtype::b8) + ) + ); + ASSERT_TRUE(allClose(fl::all(t, {0, 1}), fl::fromScalar(false, dtype::b8))); + ASSERT_TRUE(fl::all(Tensor::fromVector({1, 1, 1})).scalar()); + + auto keptDims = fl::all( + fl::all(t, {1}, /* keepDims = */ true), + {0}, /* keepDims = */ + true + ); + ASSERT_EQ(keptDims.shape(), Shape({1, 1})); + ASSERT_EQ(keptDims.scalar(), fl::all(t, {0, 1}).scalar()); + auto q = fl::all(fl::full({5, 5, 5, 5}, 1)); + ASSERT_EQ(q.shape(), Shape({})); + ASSERT_EQ(q.elements(), 1); + ASSERT_EQ(q.scalar(), true); + + const float v = 3.14; + auto a = fl::all(fl::fromScalar(v)); + ASSERT_EQ(a.shape(), Shape()); + ASSERT_TRUE(a.scalar()); + ASSERT_EQ(fl::all(fl::fromScalar(v), {0}).shape(), Shape()); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + 
::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/tensor/TensorUnaryOpsTest.cpp b/flashlight/fl/test/tensor/TensorUnaryOpsTest.cpp index b3bcfb7..81a3f93 100644 --- a/flashlight/fl/test/tensor/TensorUnaryOpsTest.cpp +++ b/flashlight/fl/test/tensor/TensorUnaryOpsTest.cpp @@ -16,199 +16,230 @@ using namespace ::testing; using namespace fl; TEST(TensorUnaryOpsTest, negative) { - auto a = fl::full({3, 3}, 1); - auto b = fl::full({3, 3}, 2); - auto c = a - b; - ASSERT_TRUE(allClose(c, -a)); - ASSERT_TRUE(allClose(c, negative(a))); + auto a = fl::full({3, 3}, 1); + auto b = fl::full({3, 3}, 2); + auto c = a - b; + ASSERT_TRUE(allClose(c, -a)); + ASSERT_TRUE(allClose(c, negative(a))); } TEST(TensorUnaryOpsTest, logicalNot) { - ASSERT_TRUE(allClose( - !fl::full({3, 3}, true), fl::full({3, 3}, false).astype(dtype::b8))); + ASSERT_TRUE( + allClose( + !fl::full({3, 3}, true), + fl::full({3, 3}, false).astype(dtype::b8) + ) + ); } TEST(TensorUnaryOpsTest, clip) { - float h = 3.; - float l = 2.; - Shape s = {3, 3}; - auto high = fl::full(s, h); - auto low = fl::full(s, l); - ASSERT_TRUE(allClose(fl::clip(fl::full({3, 3}, 4.), low, high), high)); - ASSERT_TRUE(allClose(fl::clip(fl::full({3, 3}, 4.), l, high), high)); - ASSERT_TRUE(allClose(fl::clip(fl::full({3, 3}, 4.), low, h), high)); - ASSERT_TRUE(allClose(fl::clip(fl::full({3, 3}, 4.), l, h), high)); + float h = 3.; + float l = 2.; + Shape s = {3, 3}; + auto high = fl::full(s, h); + auto low = fl::full(s, l); + ASSERT_TRUE(allClose(fl::clip(fl::full({3, 3}, 4.), low, high), high)); + ASSERT_TRUE(allClose(fl::clip(fl::full({3, 3}, 4.), l, high), high)); + ASSERT_TRUE(allClose(fl::clip(fl::full({3, 3}, 4.), low, h), high)); + ASSERT_TRUE(allClose(fl::clip(fl::full({3, 3}, 4.), l, h), high)); } TEST(TensorUnaryOpsTest, roll) { - auto t = fl::full({5, 5}, 4.); - ASSERT_TRUE(allClose(t, fl::roll(t, /* shift = */ 3, /* axis = */ 1))); - - Shape dims({4, 5}); - auto r 
= fl::arange(dims); - auto result = fl::roll(r, /* shift = */ 1, /* axis = */ 0); - ASSERT_EQ(r.shape(), result.shape()); - ASSERT_TRUE(allClose(result(0), fl::full({dims[1]}, dims[0] - 1, r.type()))); - ASSERT_TRUE(allClose( - result(fl::range(1, fl::end)), - fl::arange({dims[0] - 1, dims[1]}, /* seqDim = */ 0, r.type()))); + auto t = fl::full({5, 5}, 4.); + ASSERT_TRUE(allClose(t, fl::roll(t, /* shift = */ 3, /* axis = */ 1))); + + Shape dims({4, 5}); + auto r = fl::arange(dims); + auto result = fl::roll(r, /* shift = */ 1, /* axis = */ 0); + ASSERT_EQ(r.shape(), result.shape()); + ASSERT_TRUE(allClose(result(0), fl::full({dims[1]}, dims[0] - 1, r.type()))); + ASSERT_TRUE( + allClose( + result(fl::range(1, fl::end)), + fl::arange({dims[0] - 1, dims[1]}, /* seqDim = */ 0, r.type()) + ) + ); } TEST(TensorUnaryOpsTest, isnan) { - Shape s = {3, 3}; - ASSERT_TRUE(allClose( - fl::isnan(fl::full(s, 1.) / 3), - fl::full(s, false).astype(fl::dtype::b8))); + Shape s = {3, 3}; + ASSERT_TRUE( + allClose( + fl::isnan(fl::full(s, 1.) / 3), + fl::full(s, false).astype(fl::dtype::b8) + ) + ); } TEST(TensorUnaryOpsTest, isinf) { - Shape s = {3, 3}; - ASSERT_TRUE(allClose( - fl::isinf(fl::full(s, 1.) / 3), - fl::full(s, false).astype(fl::dtype::b8))); - ASSERT_TRUE(allClose( - fl::isinf(fl::full(s, 1.) / 0.), - fl::full(s, true).astype(fl::dtype::b8))); + Shape s = {3, 3}; + ASSERT_TRUE( + allClose( + fl::isinf(fl::full(s, 1.) / 3), + fl::full(s, false).astype(fl::dtype::b8) + ) + ); + ASSERT_TRUE( + allClose( + fl::isinf(fl::full(s, 1.) 
/ 0.), + fl::full(s, true).astype(fl::dtype::b8) + ) + ); } TEST(TensorUnaryOpsTest, sign) { - auto vals = fl::rand({5, 5}) - 0.5; - vals(2, 2) = 0.; - auto signs = fl::sign(vals); - vals(vals > 0) = 1; - vals(vals == 0) = 0; - vals(vals < 0) = -1; - ASSERT_TRUE(allClose(signs, vals)); + auto vals = fl::rand({5, 5}) - 0.5; + vals(2, 2) = 0.; + auto signs = fl::sign(vals); + vals(vals > 0) = 1; + vals(vals == 0) = 0; + vals(vals < 0) = -1; + ASSERT_TRUE(allClose(signs, vals)); } TEST(TensorUnaryOpsTest, tril) { - auto checkSquareTril = - [](const Dim dim, const Tensor& res, const Tensor& in) { - for (int i = 0; i < dim; ++i) { - for (int j = i + 1; j < dim; ++j) { - ASSERT_EQ(res(i, j).scalar(), 0.); - } - } - for (int i = 0; i < dim; ++i) { - for (int j = 0; j < i; ++j) { - ASSERT_TRUE(allClose(res(i, j), in(i, j))); - } - } - }; - Dim dim = 10; - auto t = fl::rand({dim, dim}); - auto out = fl::tril(t); - checkSquareTril(dim, out, t); - - // TODO: this could be bogus behavior - // > 2 dims - Dim dim2 = 3; - auto t2 = fl::rand({dim2, dim2, dim2}); - auto out2 = fl::tril(t2); - for (unsigned i = 0; i < dim2; ++i) { - checkSquareTril( - dim2, out2(fl::span, fl::span, i), t2(fl::span, fl::span, i)); - } + auto checkSquareTril = + [](const Dim dim, const Tensor& res, const Tensor& in) { + for(int i = 0; i < dim; ++i) { + for(int j = i + 1; j < dim; ++j) { + ASSERT_EQ(res(i, j).scalar(), 0.); + } + } + for(int i = 0; i < dim; ++i) { + for(int j = 0; j < i; ++j) { + ASSERT_TRUE(allClose(res(i, j), in(i, j))); + } + } + }; + Dim dim = 10; + auto t = fl::rand({dim, dim}); + auto out = fl::tril(t); + checkSquareTril(dim, out, t); + + // TODO: this could be bogus behavior + // > 2 dims + Dim dim2 = 3; + auto t2 = fl::rand({dim2, dim2, dim2}); + auto out2 = fl::tril(t2); + for(unsigned i = 0; i < dim2; ++i) { + checkSquareTril( + dim2, + out2(fl::span, fl::span, i), + t2(fl::span, fl::span, i) + ); + } } TEST(TensorUnaryOpsTest, triu) { - auto checkSquareTriu = - [](const Dim 
dim, const Tensor& res, const Tensor& in) { - for (unsigned i = 0; i < dim; ++i) { - for (unsigned j = i + 1; j < dim; ++j) { - ASSERT_TRUE(allClose(res(i, j), in(i, j))); - } - } - for (unsigned i = 0; i < dim; ++i) { - for (unsigned j = 0; j < i; ++j) { - ASSERT_EQ(res(i, j).scalar(), 0.); - } - } - }; - - int dim = 10; - auto t = fl::rand({dim, dim}); - auto out = fl::triu(t); - checkSquareTriu(dim, out, t); - - // TODO: this could be bogus behavior - // > 2 dims - int dim2 = 3; - auto t2 = fl::rand({dim2, dim2, dim2}); - auto out2 = fl::triu(t2); - for (int i = 0; i < dim2; ++i) { - checkSquareTriu( - dim2, out2(fl::span, fl::span, i), t2(fl::span, fl::span, i)); - } + auto checkSquareTriu = + [](const Dim dim, const Tensor& res, const Tensor& in) { + for(unsigned i = 0; i < dim; ++i) { + for(unsigned j = i + 1; j < dim; ++j) { + ASSERT_TRUE(allClose(res(i, j), in(i, j))); + } + } + for(unsigned i = 0; i < dim; ++i) { + for(unsigned j = 0; j < i; ++j) { + ASSERT_EQ(res(i, j).scalar(), 0.); + } + } + }; + + int dim = 10; + auto t = fl::rand({dim, dim}); + auto out = fl::triu(t); + checkSquareTriu(dim, out, t); + + // TODO: this could be bogus behavior + // > 2 dims + int dim2 = 3; + auto t2 = fl::rand({dim2, dim2, dim2}); + auto out2 = fl::triu(t2); + for(int i = 0; i < dim2; ++i) { + checkSquareTriu( + dim2, + out2(fl::span, fl::span, i), + t2(fl::span, fl::span, i) + ); + } } TEST(TensorUnaryOpsTest, floor) { - auto a = fl::rand({10, 10}) + 0.5; - ASSERT_TRUE(allClose((a >= 1.).astype(fl::dtype::f32), fl::floor(a))); + auto a = fl::rand({10, 10}) + 0.5; + ASSERT_TRUE(allClose((a >= 1.).astype(fl::dtype::f32), fl::floor(a))); } TEST(TensorUnaryOpsTest, ceil) { - auto a = fl::rand({10, 10}) + 0.5; - ASSERT_TRUE(allClose((a >= 1).astype(fl::dtype::f32), fl::ceil(a) - 1)); + auto a = fl::rand({10, 10}) + 0.5; + ASSERT_TRUE(allClose((a >= 1).astype(fl::dtype::f32), fl::ceil(a) - 1)); } TEST(TensorUnaryOpsTest, rint) { - Shape s = {10, 10}; - auto a = fl::rand(s) - 
0.5; - ASSERT_TRUE(allClose(fl::rint(a), fl::full(s, 0.))); - auto b = fl::rand(s) + 0.5; - ASSERT_TRUE(allClose(fl::rint(b), fl::full(s, 1.))); + Shape s = {10, 10}; + auto a = fl::rand(s) - 0.5; + ASSERT_TRUE(allClose(fl::rint(a), fl::full(s, 0.))); + auto b = fl::rand(s) + 0.5; + ASSERT_TRUE(allClose(fl::rint(b), fl::full(s, 1.))); } TEST(TensorUnaryOpsTest, sigmoid) { - auto a = fl::rand({10, 10}); - ASSERT_TRUE(allClose(1 / (1 + fl::exp(-a)), fl::sigmoid(a))); + auto a = fl::rand({10, 10}); + ASSERT_TRUE(allClose(1 / (1 + fl::exp(-a)), fl::sigmoid(a))); } TEST(TensorUnaryOpsTest, flip) { - const unsigned high = 10; - auto a = fl::arange({high}); - auto flipped = fl::flip(a, /* dim = */ 0); - a *= -1; - a += (high - 1); - ASSERT_TRUE(allClose(a, flipped)); + const unsigned high = 10; + auto a = fl::arange({high}); + auto flipped = fl::flip(a, /* dim = */ 0); + a *= -1; + a += (high - 1); + ASSERT_TRUE(allClose(a, flipped)); - auto b = fl::arange({high, high}, /* seqDim = */ 0); - ASSERT_TRUE(allClose(fl::flip(b, 1), b)); - auto c = fl::arange({high, high}, /* seqDim = */ 1); - ASSERT_TRUE(allClose(fl::flip(c, 0), c)); + auto b = fl::arange({high, high}, /* seqDim = */ 0); + ASSERT_TRUE(allClose(fl::flip(b, 1), b)); + auto c = fl::arange({high, high}, /* seqDim = */ 1); + ASSERT_TRUE(allClose(fl::flip(c, 0), c)); } TEST(TensorUnaryOpsTest, where) { - // 1 0 - // 0 1 - auto cond = fl::Tensor::fromVector({2, 2}, {1, 0, 0, 1}); - // 0 2 - // 1 3 - auto x = fl::Tensor::fromVector({2, 2}, {0, 1, 2, 3}); - // 4 6 - // 5 7 - auto y = fl::Tensor::fromVector({2, 2}, {4, 5, 6, 7}); - - // 0 6 - // 5 3 - ASSERT_TRUE(allClose( - fl::where(cond, x, y), - fl::Tensor::fromVector({2, 2}, {0, 5, 6, 3}))); - // 0 1 - // 1 3 - ASSERT_TRUE(allClose( - fl::where(cond, x, 1.0), - fl::Tensor::fromVector({2, 2}, {0, 1, 1, 3}))); - // 2 6 - // 5 2 - ASSERT_TRUE(allClose( - fl::where(cond, 2.0, y), - fl::Tensor::fromVector({2, 2}, {2, 5, 6, 2}))); + // 1 0 + // 0 1 + auto cond = 
fl::Tensor::fromVector({2, 2}, {1, 0, 0, 1}); + // 0 2 + // 1 3 + auto x = fl::Tensor::fromVector({2, 2}, {0, 1, 2, 3}); + // 4 6 + // 5 7 + auto y = fl::Tensor::fromVector({2, 2}, {4, 5, 6, 7}); + + // 0 6 + // 5 3 + ASSERT_TRUE( + allClose( + fl::where(cond, x, y), + fl::Tensor::fromVector({2, 2}, {0, 5, 6, 3}) + ) + ); + // 0 1 + // 1 3 + ASSERT_TRUE( + allClose( + fl::where(cond, x, 1.0), + fl::Tensor::fromVector({2, 2}, {0, 1, 1, 3}) + ) + ); + // 2 6 + // 5 2 + ASSERT_TRUE( + allClose( + fl::where(cond, 2.0, y), + fl::Tensor::fromVector({2, 2}, {2, 5, 6, 2}) + ) + ); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/tensor/af/ArrayFireCPUStreamTest.cpp b/flashlight/fl/test/tensor/af/ArrayFireCPUStreamTest.cpp index b1f8c96..b59edc0 100644 --- a/flashlight/fl/test/tensor/af/ArrayFireCPUStreamTest.cpp +++ b/flashlight/fl/test/tensor/af/ArrayFireCPUStreamTest.cpp @@ -18,10 +18,8 @@ using fl::StreamType; using fl::ArrayFireCPUStream; - - int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/tensor/af/ArrayFireTensorBaseTest.cpp b/flashlight/fl/test/tensor/af/ArrayFireTensorBaseTest.cpp index 3980232..af86e5d 100644 --- a/flashlight/fl/test/tensor/af/ArrayFireTensorBaseTest.cpp +++ b/flashlight/fl/test/tensor/af/ArrayFireTensorBaseTest.cpp @@ -31,17 +31,18 @@ namespace { bool allClose( const af::array& a, const af::array& b, - double absTolerance = 1e-5) { - if (a.type() != b.type()) { - return false; - } - if (a.dims() != b.dims()) { - return false; - } - if (a.isempty() && b.isempty()) { - return true; - } - return af::max(af::abs(a - b)) < absTolerance; + double absTolerance = 1e-5 +) { + if(a.type() != 
b.type()) { + return false; + } + if(a.dims() != b.dims()) { + return false; + } + if(a.isempty() && b.isempty()) { + return true; + } + return af::max(af::abs(a - b)) < absTolerance; } } // namespace @@ -49,25 +50,25 @@ bool allClose( namespace fl { TEST(ArrayFireTensorBaseTest, ArrayFireShapeInterop) { - ASSERT_EQ(detail::afToFlDims(af::dim4(), 0), Shape({})); // scalar - ASSERT_EQ(detail::afToFlDims(af::dim4(0), 1), Shape({0})); - ASSERT_EQ(detail::afToFlDims(af::dim4(0, 5), 2), Shape({0, 5})); - ASSERT_EQ(detail::afToFlDims(af::dim4(1, 0, 2), 3), Shape({1, 0, 2})); - ASSERT_EQ(detail::afToFlDims(af::dim4(0, 1, 1, 1), 4), Shape({0, 1, 1, 1})); - - using namespace fl::detail; - auto dimsEq = [](const af::dim4& d, const Shape& s) { - return detail::afToFlDims(d, s.ndim()) == s; - }; - - ASSERT_TRUE(dimsEq(af::dim4(3), {3})); // not 3, 1, 1, 1 - ASSERT_TRUE(dimsEq(af::dim4(3, 2), {3, 2})); // not 3, 2, 1, 1 - ASSERT_TRUE(dimsEq(af::dim4(3, 1), {3})); - // if explicitly specified, uses implicit 1 dim - ASSERT_TRUE(dimsEq(af::dim4(3, 1), {3, 1})); - ASSERT_TRUE(dimsEq(af::dim4(1, 3, 2), {1, 3, 2})); - ASSERT_TRUE(dimsEq(af::dim4(1), {1})); - ASSERT_TRUE(dimsEq(af::dim4(1, 1, 1), {1})); + ASSERT_EQ(detail::afToFlDims(af::dim4(), 0), Shape({})); // scalar + ASSERT_EQ(detail::afToFlDims(af::dim4(0), 1), Shape({0})); + ASSERT_EQ(detail::afToFlDims(af::dim4(0, 5), 2), Shape({0, 5})); + ASSERT_EQ(detail::afToFlDims(af::dim4(1, 0, 2), 3), Shape({1, 0, 2})); + ASSERT_EQ(detail::afToFlDims(af::dim4(0, 1, 1, 1), 4), Shape({0, 1, 1, 1})); + + using namespace fl::detail; + auto dimsEq = [](const af::dim4& d, const Shape& s) { + return detail::afToFlDims(d, s.ndim()) == s; + }; + + ASSERT_TRUE(dimsEq(af::dim4(3), {3})); // not 3, 1, 1, 1 + ASSERT_TRUE(dimsEq(af::dim4(3, 2), {3, 2})); // not 3, 2, 1, 1 + ASSERT_TRUE(dimsEq(af::dim4(3, 1), {3})); + // if explicitly specified, uses implicit 1 dim + ASSERT_TRUE(dimsEq(af::dim4(3, 1), {3, 1})); + ASSERT_TRUE(dimsEq(af::dim4(1, 3, 2), 
{1, 3, 2})); + ASSERT_TRUE(dimsEq(af::dim4(1), {1})); + ASSERT_TRUE(dimsEq(af::dim4(1, 1, 1), {1})); } } // namespace fl @@ -75,369 +76,417 @@ TEST(ArrayFireTensorBaseTest, ArrayFireShapeInterop) { namespace { int getRefCount(const af::array& arr, bool sync = true) { - if (sync) { - arr.eval(); - af::sync(); - } - int refCount = 0; - AF_CHECK(af_get_data_ref_count(&refCount, arr.get())); - return refCount; + if(sync) { + arr.eval(); + af::sync(); + } + int refCount = 0; + AF_CHECK(af_get_data_ref_count(&refCount, arr.get())); + return refCount; } } // namespace TEST(ArrayFireTensorBaseTest, AfRefCountBasic) { - // Sanity check that af::arrays moved into fl::Tensors don't have their - // refcount inrcremented/show proper usage of refs in tensor ops - auto q = af::constant(1, {2, 2}); - // without eval/sync, no refcount - ASSERT_EQ(getRefCount(q, /* sync = */ false), 0); + // Sanity check that af::arrays moved into fl::Tensors don't have their + // refcount inrcremented/show proper usage of refs in tensor ops + auto q = af::constant(1, {2, 2}); + // without eval/sync, no refcount + ASSERT_EQ(getRefCount(q, /* sync = */ false), 0); - auto a = af::constant(1, {2, 2}); - ASSERT_EQ(getRefCount(a), 1); + auto a = af::constant(1, {2, 2}); + ASSERT_EQ(getRefCount(a), 1); - auto tensor = toTensor(std::move(a), /* numDims = */ 2); - auto& aRef = toArray(tensor); - ASSERT_EQ(getRefCount(aRef), 1); - // Sanity check copying bumps things - auto aNoRef = toArray(tensor); - ASSERT_EQ(getRefCount(aNoRef), 2); + auto tensor = toTensor(std::move(a), /* numDims = */ 2); + auto& aRef = toArray(tensor); + ASSERT_EQ(getRefCount(aRef), 1); + // Sanity check copying bumps things + auto aNoRef = toArray(tensor); + ASSERT_EQ(getRefCount(aNoRef), 2); } TEST(ArrayFireTensorBaseTest, AfRefCountModify) { - // Compositional operations don't increment refcount - auto a = af::constant(1, {2, 2}); - auto b = af::constant(1, {2, 2}); - auto arrRes = a + b; - ASSERT_EQ(getRefCount(a), 1); - 
ASSERT_EQ(getRefCount(b), 1); - // Multiple uses of the same variable doesn't push count - auto c = af::constant(1, {2, 2}); - auto d = af::constant(1, {2, 2}); - auto arrResMult = c * c + d * d; - ASSERT_EQ(getRefCount(c), 1); - ASSERT_EQ(getRefCount(d), 1); - - // Same behavior with Tensors - auto v = fl::full({2, 2}, 1); - auto w = fl::full({2, 2}, 1); - auto varRes = v + w; - ASSERT_EQ(getRefCount(toArray(v)), 1); - ASSERT_EQ(getRefCount(toArray(w)), 1); - // Multiuse with variables - auto y = fl::full({2, 2}, 1); - auto z = fl::full({2, 2}, 1); - auto varResMult = y * y + z * z; - ASSERT_EQ(getRefCount(toArray(y)), 1); - ASSERT_EQ(getRefCount(toArray(z)), 1); + // Compositional operations don't increment refcount + auto a = af::constant(1, {2, 2}); + auto b = af::constant(1, {2, 2}); + auto arrRes = a + b; + ASSERT_EQ(getRefCount(a), 1); + ASSERT_EQ(getRefCount(b), 1); + // Multiple uses of the same variable doesn't push count + auto c = af::constant(1, {2, 2}); + auto d = af::constant(1, {2, 2}); + auto arrResMult = c * c + d * d; + ASSERT_EQ(getRefCount(c), 1); + ASSERT_EQ(getRefCount(d), 1); + + // Same behavior with Tensors + auto v = fl::full({2, 2}, 1); + auto w = fl::full({2, 2}, 1); + auto varRes = v + w; + ASSERT_EQ(getRefCount(toArray(v)), 1); + ASSERT_EQ(getRefCount(toArray(w)), 1); + // Multiuse with variables + auto y = fl::full({2, 2}, 1); + auto z = fl::full({2, 2}, 1); + auto varResMult = y * y + z * z; + ASSERT_EQ(getRefCount(toArray(y)), 1); + ASSERT_EQ(getRefCount(toArray(z)), 1); } TEST(ArrayFireTensorBaseTest, astypeRefcount) { - auto t = fl::rand({5, 5}); - ASSERT_EQ(getRefCount(toArray(t)), 1); - auto t64 = t.astype(fl::dtype::f64); - ASSERT_EQ(getRefCount(toArray(t64)), 1); + auto t = fl::rand({5, 5}); + ASSERT_EQ(getRefCount(toArray(t)), 1); + auto t64 = t.astype(fl::dtype::f64); + ASSERT_EQ(getRefCount(toArray(t64)), 1); } TEST(ArrayFireTensorBaseTest, astypeInPlaceRefcount) { - auto a = fl::rand({4, 4}); - 
ASSERT_EQ(getRefCount(toArray(a)), 1); - a = a.astype(fl::dtype::f64); - ASSERT_EQ(getRefCount(toArray(a)), 1); - ASSERT_EQ(a.type(), fl::dtype::f64); - a = a.astype(fl::dtype::f32); - ASSERT_EQ(getRefCount(toArray(a)), 1); + auto a = fl::rand({4, 4}); + ASSERT_EQ(getRefCount(toArray(a)), 1); + a = a.astype(fl::dtype::f64); + ASSERT_EQ(getRefCount(toArray(a)), 1); + ASSERT_EQ(a.type(), fl::dtype::f64); + a = a.astype(fl::dtype::f32); + ASSERT_EQ(getRefCount(toArray(a)), 1); } TEST(ArrayFireTensorBaseTest, BackendInterop) { - // TODO: test toTensorBackend here since we know we have a backend available; - // design a test that tests with mulitple backends once available - auto a = fl::rand({10, 12}); - ASSERT_EQ(a.backendType(), TensorBackendType::ArrayFire); - auto b = a; - auto t = fl::toTensorType(std::move(a)); - ASSERT_EQ(t.backendType(), TensorBackendType::ArrayFire); - ASSERT_TRUE(allClose(b, t)); + // TODO: test toTensorBackend here since we know we have a backend available; + // design a test that tests with mulitple backends once available + auto a = fl::rand({10, 12}); + ASSERT_EQ(a.backendType(), TensorBackendType::ArrayFire); + auto b = a; + auto t = fl::toTensorType(std::move(a)); + ASSERT_EQ(t.backendType(), TensorBackendType::ArrayFire); + ASSERT_TRUE(allClose(b, t)); } TEST(ArrayFireTensorBaseTest, withTensorType) { - // TODO: test with here since we know we have a backend available; - // design a test that tests with mulitple backends once available - Tensor t; - fl::withTensorType([&t]() { - t = fl::full({5, 5}, 6.); - t += 1; - }); - ASSERT_TRUE(allClose(t, fl::full({5, 5}, 7.))); + // TODO: test with here since we know we have a backend available; + // design a test that tests with mulitple backends once available + Tensor t; + fl::withTensorType( + [&t]() { + t = fl::full({5, 5}, 6.); + t += 1; + } + ); + ASSERT_TRUE(allClose(t, fl::full({5, 5}, 7.))); } TEST(ArrayFireTensorBaseTest, ArrayFireAssignmentOperators) { - fl::Tensor a = fl::full({3, 
3}, 1.); - af::array& aArr = toArray(a); - ASSERT_EQ(getRefCount(aArr), 1); + fl::Tensor a = fl::full({3, 3}, 1.); + af::array& aArr = toArray(a); + ASSERT_EQ(getRefCount(aArr), 1); - auto b = a; // share the same underlying array but bump refcount - ASSERT_EQ(getRefCount(aArr), 2); + auto b = a; // share the same underlying array but bump refcount + ASSERT_EQ(getRefCount(aArr), 2); - auto c = a.copy(); // defers deep copy to AF - ASSERT_EQ(getRefCount(aArr), 2); + auto c = a.copy(); // defers deep copy to AF + ASSERT_EQ(getRefCount(aArr), 2); - // copy, else it'll get released below when we reassign a new tensor to b - af::array bArr = toArray(b); - ASSERT_EQ(getRefCount(bArr), 3); // aArr, bArr, b.arrayHandle_ + // copy, else it'll get released below when we reassign a new tensor to b + af::array bArr = toArray(b); + ASSERT_EQ(getRefCount(bArr), 3); // aArr, bArr, b.arrayHandle_ - b = fl::full({4, 4}, 2.); // b points to a new array now - ASSERT_EQ(getRefCount(bArr), 2); // aArr, bArr + b = fl::full({4, 4}, 2.); // b points to a new array now + ASSERT_EQ(getRefCount(bArr), 2); // aArr, bArr - ASSERT_EQ(getRefCount(aArr), 2); // aArr, bArr + ASSERT_EQ(getRefCount(aArr), 2); // aArr, bArr } TEST(ArrayFireTensorBaseTest, BinaryOperators) { - auto a = - toTensor(af::constant(1, {2, 2}), /* numDims = */ 2); - auto b = - toTensor(af::constant(2, {2, 2}), /* numDims = */ 2); - auto c = - toTensor(af::constant(3, {2, 2}), /* numDims = */ 2); + auto a = + toTensor(af::constant(1, {2, 2}), /* numDims = */ 2); + auto b = + toTensor(af::constant(2, {2, 2}), /* numDims = */ 2); + auto c = + toTensor(af::constant(3, {2, 2}), /* numDims = */ 2); - ASSERT_TRUE(allClose(toArray(a == b), (toArray(a) == toArray(b)))); - ASSERT_TRUE(allClose((a == b), eq(a, b))); - ASSERT_TRUE(allClose((a + b), c)); - ASSERT_TRUE(allClose((a + b), add(a, b))); + ASSERT_TRUE(allClose(toArray(a == b), (toArray(a) == toArray(b)))); + ASSERT_TRUE(allClose((a == b), eq(a, b))); + ASSERT_TRUE(allClose((a 
+ b), c)); + ASSERT_TRUE(allClose((a + b), add(a, b))); } TEST(ArrayFireTensorBaseTest, full) { - // TODO: expand with fixtures for each type - auto a = fl::full({3, 4}, 3.); - ASSERT_EQ(a.shape(), Shape({3, 4})); - ASSERT_EQ(a.type(), fl::dtype::f32); - ASSERT_TRUE(allClose(toArray(a), af::constant(3., {3, 4}))); + // TODO: expand with fixtures for each type + auto a = fl::full({3, 4}, 3.); + ASSERT_EQ(a.shape(), Shape({3, 4})); + ASSERT_EQ(a.type(), fl::dtype::f32); + ASSERT_TRUE(allClose(toArray(a), af::constant(3., {3, 4}))); - auto b = fl::full({1, 1, 5, 4}, 4.5); - ASSERT_EQ(b.shape(), Shape({1, 1, 5, 4})); - ASSERT_EQ(b.type(), fl::dtype::f32); - ASSERT_TRUE(allClose(toArray(b), af::constant(4.5, {1, 1, 5, 4}))); + auto b = fl::full({1, 1, 5, 4}, 4.5); + ASSERT_EQ(b.shape(), Shape({1, 1, 5, 4})); + ASSERT_EQ(b.type(), fl::dtype::f32); + ASSERT_TRUE(allClose(toArray(b), af::constant(4.5, {1, 1, 5, 4}))); } TEST(ArrayFireTensorBaseTest, identity) { - auto a = fl::identity(6); - ASSERT_EQ(a.shape(), Shape({6, 6})); - ASSERT_EQ(a.type(), fl::dtype::f32); - ASSERT_TRUE(allClose(toArray(a), af::identity({6, 6}))); + auto a = fl::identity(6); + ASSERT_EQ(a.shape(), Shape({6, 6})); + ASSERT_EQ(a.type(), fl::dtype::f32); + ASSERT_TRUE(allClose(toArray(a), af::identity({6, 6}))); - ASSERT_EQ(fl::identity(6, fl::dtype::f64).type(), fl::dtype::f64); + ASSERT_EQ(fl::identity(6, fl::dtype::f64).type(), fl::dtype::f64); } TEST(ArrayFireTensorBaseTest, randn) { - int s = 30; - auto a = fl::randn({s, s}); - ASSERT_EQ(a.shape(), Shape({s, s})); - ASSERT_EQ(a.type(), fl::dtype::f32); - ASSERT_TRUE(af::allTrue( - af::abs(af::mean(af::moddims(toArray(a), s * s, 1, 1, 1))) < 2)); + int s = 30; + auto a = fl::randn({s, s}); + ASSERT_EQ(a.shape(), Shape({s, s})); + ASSERT_EQ(a.type(), fl::dtype::f32); + ASSERT_TRUE( + af::allTrue( + af::abs(af::mean(af::moddims(toArray(a), s * s, 1, 1, 1))) < 2 + ) + ); - ASSERT_EQ(fl::randn({1}, fl::dtype::f64).type(), fl::dtype::f64); + 
ASSERT_EQ(fl::randn({1}, fl::dtype::f64).type(), fl::dtype::f64); } TEST(ArrayFireTensorBaseTest, rand) { - int s = 30; - auto a = fl::rand({s, s}); - ASSERT_EQ(a.shape(), Shape({s, s})); - ASSERT_EQ(a.type(), fl::dtype::f32); - ASSERT_TRUE(af::allTrue(toArray(a) <= 1)); - ASSERT_TRUE(af::allTrue(toArray(a) >= 0)); + int s = 30; + auto a = fl::rand({s, s}); + ASSERT_EQ(a.shape(), Shape({s, s})); + ASSERT_EQ(a.type(), fl::dtype::f32); + ASSERT_TRUE(af::allTrue(toArray(a) <= 1)); + ASSERT_TRUE(af::allTrue(toArray(a) >= 0)); - ASSERT_EQ(fl::rand({1}, fl::dtype::f64).type(), fl::dtype::f64); + ASSERT_EQ(fl::rand({1}, fl::dtype::f64).type(), fl::dtype::f64); } TEST(ArrayFireTensorBaseTest, amin) { - auto a = fl::rand({3, 3}); - ASSERT_EQ(fl::amin(a).scalar(), af::min(toArray(a))); - ASSERT_TRUE(allClose( - toArray(fl::amin(a, {0})), - fl::detail::condenseIndices(af::min(toArray(a), 0)))); + auto a = fl::rand({3, 3}); + ASSERT_EQ(fl::amin(a).scalar(), af::min(toArray(a))); + ASSERT_TRUE( + allClose( + toArray(fl::amin(a, {0})), + fl::detail::condenseIndices(af::min(toArray(a), 0)) + ) + ); } TEST(ArrayFireTensorBaseTest, amax) { - auto a = fl::rand({3, 3}); - ASSERT_EQ(fl::amax(a).scalar(), af::max(toArray(a))); - ASSERT_TRUE(allClose( - toArray(fl::amax(a, {0})), - fl::detail::condenseIndices(af::max(toArray(a), 0)))); + auto a = fl::rand({3, 3}); + ASSERT_EQ(fl::amax(a).scalar(), af::max(toArray(a))); + ASSERT_TRUE( + allClose( + toArray(fl::amax(a, {0})), + fl::detail::condenseIndices(af::max(toArray(a), 0)) + ) + ); } TEST(ArrayFireTensorBaseTest, sum) { - auto a = fl::rand({3, 3}); - ASSERT_NEAR(fl::sum(a).scalar(), af::sum(toArray(a)), 1e-5); - ASSERT_TRUE(allClose( - toArray(fl::sum(a, {0})), - fl::detail::condenseIndices(af::sum(toArray(a), 0)))); - - auto b = fl::rand({5, 6, 7, 8}); - ASSERT_NEAR(fl::sum(b).scalar(), af::sum(toArray(b)), 1e-3); - ASSERT_TRUE(allClose( - toArray(fl::sum(b, {1, 2})), - fl::detail::condenseIndices(af::sum(af::sum(toArray(b), 1), 
2)))); + auto a = fl::rand({3, 3}); + ASSERT_NEAR(fl::sum(a).scalar(), af::sum(toArray(a)), 1e-5); + ASSERT_TRUE( + allClose( + toArray(fl::sum(a, {0})), + fl::detail::condenseIndices(af::sum(toArray(a), 0)) + ) + ); + + auto b = fl::rand({5, 6, 7, 8}); + ASSERT_NEAR(fl::sum(b).scalar(), af::sum(toArray(b)), 1e-3); + ASSERT_TRUE( + allClose( + toArray(fl::sum(b, {1, 2})), + fl::detail::condenseIndices(af::sum(af::sum(toArray(b), 1), 2)) + ) + ); } TEST(ArrayFireTensorBaseTest, exp) { - auto in = fl::full({3, 3}, 4.f); - ASSERT_TRUE(allClose(toArray(fl::exp(in)), af::exp(toArray(in)))); + auto in = fl::full({3, 3}, 4.f); + ASSERT_TRUE(allClose(toArray(fl::exp(in)), af::exp(toArray(in)))); } TEST(ArrayFireTensorBaseTest, log) { - auto in = fl::full({3, 3}, 2.f); - ASSERT_TRUE(allClose(toArray(fl::log(in)), af::log(toArray(in)))); + auto in = fl::full({3, 3}, 2.f); + ASSERT_TRUE(allClose(toArray(fl::log(in)), af::log(toArray(in)))); } TEST(ArrayFireTensorBaseTest, log1p) { - auto in = fl::rand({3, 3}); - ASSERT_TRUE(allClose(fl::log1p(in), fl::log(1 + in))); + auto in = fl::rand({3, 3}); + ASSERT_TRUE(allClose(fl::log1p(in), fl::log(1 + in))); } TEST(ArrayFireTensorBaseTest, sin) { - auto in = fl::rand({3, 3}); - ASSERT_TRUE(allClose(toArray(fl::sin(in)), af::sin(toArray(in)))); + auto in = fl::rand({3, 3}); + ASSERT_TRUE(allClose(toArray(fl::sin(in)), af::sin(toArray(in)))); } TEST(ArrayFireTensorBaseTest, cos) { - auto in = fl::rand({3, 3}); - ASSERT_TRUE(allClose(toArray(fl::cos(in)), af::cos(toArray(in)))); + auto in = fl::rand({3, 3}); + ASSERT_TRUE(allClose(toArray(fl::cos(in)), af::cos(toArray(in)))); } TEST(ArrayFireTensorBaseTest, sqrt) { - auto in = fl::full({3, 3}, 4.f); - ASSERT_TRUE(allClose(fl::sqrt(in), in / 2)); + auto in = fl::full({3, 3}, 4.f); + ASSERT_TRUE(allClose(fl::sqrt(in), in / 2)); } TEST(ArrayFireTensorBaseTest, tanh) { - auto in = fl::rand({3, 3}); - ASSERT_TRUE(allClose(toArray(fl::tanh(in)), af::tanh(toArray(in)))); + auto in = 
fl::rand({3, 3}); + ASSERT_TRUE(allClose(toArray(fl::tanh(in)), af::tanh(toArray(in)))); } TEST(ArrayFireTensorBaseTest, absolute) { - float val = -3.1; - ASSERT_TRUE(allClose(fl::abs(fl::full({3, 3}, val)), fl::full({3, 3}, -val))); + float val = -3.1; + ASSERT_TRUE(allClose(fl::abs(fl::full({3, 3}, val)), fl::full({3, 3}, -val))); } TEST(ArrayFireTensorBaseTest, erf) { - auto in = fl::rand({3, 3}); - ASSERT_TRUE(allClose(toArray(fl::erf(in)), af::erf(toArray(in)))); + auto in = fl::rand({3, 3}); + ASSERT_TRUE(allClose(toArray(fl::erf(in)), af::erf(toArray(in)))); } TEST(ArrayFireTensorBaseTest, mean) { - auto a = fl::rand({3, 50}); - ASSERT_NEAR(fl::mean(a).scalar(), af::mean(toArray(a)), 1e-4); - ASSERT_TRUE(allClose( - toArray(fl::mean(a, {0})), - detail::condenseIndices(af::mean(toArray(a), 0)))); + auto a = fl::rand({3, 50}); + ASSERT_NEAR(fl::mean(a).scalar(), af::mean(toArray(a)), 1e-4); + ASSERT_TRUE( + allClose( + toArray(fl::mean(a, {0})), + detail::condenseIndices(af::mean(toArray(a), 0)) + ) + ); } TEST(ArrayFireTensorBaseTest, median) { - auto a = fl::rand({3, 50}); - ASSERT_NEAR( - fl::median(a).scalar(), af::median(toArray(a)), 1e-3); - ASSERT_TRUE(allClose( - toArray(fl::median(a, {0})), - detail::condenseIndices(af::median(toArray(a), 0)))); + auto a = fl::rand({3, 50}); + ASSERT_NEAR( + fl::median(a).scalar(), + af::median(toArray(a)), + 1e-3 + ); + ASSERT_TRUE( + allClose( + toArray(fl::median(a, {0})), + detail::condenseIndices(af::median(toArray(a), 0)) + ) + ); } TEST(ArrayFireTensorBaseTest, var) { - const bool bias = false; - af_var_bias biasMode = bias ? 
AF_VARIANCE_SAMPLE : AF_VARIANCE_POPULATION; - auto a = fl::rand({3, 3}); - ASSERT_EQ(fl::var(a).scalar(), af::var(toArray(a), biasMode)); - ASSERT_TRUE(allClose( - toArray(fl::var(a, {0})), - detail::condenseIndices(af::var(toArray(a), /* mode = */ biasMode, 0)))); - ASSERT_TRUE(allClose( - toArray(fl::var(a, {1})), - detail::condenseIndices(af::var(toArray(a), /* mode = */ biasMode, 1)))); - // Make sure multidimension matches computing for all - ASSERT_FLOAT_EQ( - toArray(fl::var(a)).scalar(), - af::var(toArray(a), biasMode)); - ASSERT_FLOAT_EQ( - toArray(fl::var(a, {0, 1}, /* biased = */ true)).scalar(), - af::var(toArray(a), /* mode = */ AF_VARIANCE_SAMPLE)); + const bool bias = false; + af_var_bias biasMode = bias ? AF_VARIANCE_SAMPLE : AF_VARIANCE_POPULATION; + auto a = fl::rand({3, 3}); + ASSERT_EQ(fl::var(a).scalar(), af::var(toArray(a), biasMode)); + ASSERT_TRUE( + allClose( + toArray(fl::var(a, {0})), + detail::condenseIndices(af::var(toArray(a), /* mode = */ biasMode, 0)) + ) + ); + ASSERT_TRUE( + allClose( + toArray(fl::var(a, {1})), + detail::condenseIndices(af::var(toArray(a), /* mode = */ biasMode, 1)) + ) + ); + // Make sure multidimension matches computing for all + ASSERT_FLOAT_EQ( + toArray(fl::var(a)).scalar(), + af::var(toArray(a), biasMode) + ); + ASSERT_FLOAT_EQ( + toArray(fl::var(a, {0, 1}, /* biased = */ true)).scalar(), + af::var(toArray(a), /* mode = */ AF_VARIANCE_SAMPLE) + ); } TEST(ArrayFireTensorBaseTest, std) { - auto a = fl::rand({3, 3}); - ASSERT_TRUE(allClose( - toArray(fl::std(a, {0}, /* keepDims = */ true)), - af::stdev(toArray(a), AF_VARIANCE_POPULATION, 0))); - ASSERT_TRUE(allClose( - toArray(fl::std(a, {1}, /* keepDims = */ true)), - af::stdev(toArray(a), AF_VARIANCE_POPULATION, 1))); - // Make sure multidimension matches computing for all - ASSERT_FLOAT_EQ( - toArray(fl::std(a, {0, 1})).scalar(), - std::sqrt(af::var(toArray(a), AF_VARIANCE_POPULATION))); + auto a = fl::rand({3, 3}); + ASSERT_TRUE( + allClose( + 
toArray(fl::std(a, {0}, /* keepDims = */ true)), + af::stdev(toArray(a), AF_VARIANCE_POPULATION, 0) + ) + ); + ASSERT_TRUE( + allClose( + toArray(fl::std(a, {1}, /* keepDims = */ true)), + af::stdev(toArray(a), AF_VARIANCE_POPULATION, 1) + ) + ); + // Make sure multidimension matches computing for all + ASSERT_FLOAT_EQ( + toArray(fl::std(a, {0, 1})).scalar(), + std::sqrt(af::var(toArray(a), AF_VARIANCE_POPULATION)) + ); } TEST(ArrayFireTensorBaseTest, norm) { - auto a = fl::rand({3, 3}); - ASSERT_NEAR(fl::norm(a).scalar(), af::norm(toArray(a)), 1e-4); + auto a = fl::rand({3, 3}); + ASSERT_NEAR(fl::norm(a).scalar(), af::norm(toArray(a)), 1e-4); } TEST(ArrayFireTensorBaseTest, tile) { - auto a = fl::rand({3, 3}); - ASSERT_TRUE(allClose( - toArray(fl::tile(a, {4, 5, 6})), af::tile(toArray(a), {4, 5, 6}))); + auto a = fl::rand({3, 3}); + ASSERT_TRUE( + allClose( + toArray(fl::tile(a, {4, 5, 6})), + af::tile(toArray(a), {4, 5, 6}) + ) + ); } TEST(ArrayFireTensorBaseTest, nonzero) { - auto a = fl::rand({10, 10}).astype(fl::dtype::u32); - auto nz = fl::nonzero(a); - ASSERT_TRUE(allClose(toArray(nz), af::where(toArray(a)))); + auto a = fl::rand({10, 10}).astype(fl::dtype::u32); + auto nz = fl::nonzero(a); + ASSERT_TRUE(allClose(toArray(nz), af::where(toArray(a)))); } TEST(ArrayFireTensorBaseTest, transpose) { - auto a = fl::rand({3, 5}); - ASSERT_THROW(fl::transpose(a, {0, 1, 2, 3, 4}), std::invalid_argument); - ASSERT_TRUE(allClose(toArray(fl::transpose(a)), af::transpose(toArray(a)))); + auto a = fl::rand({3, 5}); + ASSERT_THROW(fl::transpose(a, {0, 1, 2, 3, 4}), std::invalid_argument); + ASSERT_TRUE(allClose(toArray(fl::transpose(a)), af::transpose(toArray(a)))); - auto b = fl::rand({3, 5, 4, 8}); - ASSERT_TRUE(allClose( - toArray(fl::transpose(b, {2, 0, 1, 3})), - af::reorder(toArray(b), 2, 0, 1, 3))); + auto b = fl::rand({3, 5, 4, 8}); + ASSERT_TRUE( + allClose( + toArray(fl::transpose(b, {2, 0, 1, 3})), + af::reorder(toArray(b), 2, 0, 1, 3) + ) + ); } 
TEST(ArrayFireTensorBaseTest, concatenate) { - std::vector tensors(11); - ASSERT_THROW(fl::concatenate(tensors), std::invalid_argument); + std::vector tensors(11); + ASSERT_THROW(fl::concatenate(tensors), std::invalid_argument); } TEST(ArrayFireTensorBaseTest, device) { - auto a = fl::rand({5, 5}); - float* flPtr = a.device(); - af::array& arr = toArray(a); - float* afPtr = arr.device(); - ASSERT_EQ(flPtr, afPtr); - a.unlock(); - AF_CHECK(af_unlock_array(arr.get())); // safety + auto a = fl::rand({5, 5}); + float* flPtr = a.device(); + af::array& arr = toArray(a); + float* afPtr = arr.device(); + ASSERT_EQ(flPtr, afPtr); + a.unlock(); + AF_CHECK(af_unlock_array(arr.get())); // safety } TEST(ArrayFireTensorBaseTest, defaultConstructor) { - auto t = ArrayFireTensor(); - ASSERT_TRUE(t.getHandle().isempty()); + auto t = ArrayFireTensor(); + ASSERT_TRUE(t.getHandle().isempty()); } TEST(ArrayFireTensorBaseTest, emptyRangeIndexing) { - // TODO the following should all return empty tensor, but AF currently doesn't - // have a way to represent empty range, and we are just throwing internally. - auto t = fl::rand({5}); - ASSERT_THROW(t(fl::range(5, fl::end)).shape(), std::exception); - ASSERT_THROW(t(fl::range(4, -1)).shape(), std::exception); - ASSERT_THROW(t(fl::range(0, 0)).shape(), std::exception); - ASSERT_THROW(t(fl::range(1, 1)).shape(), std::exception); - ASSERT_THROW(t(fl::range(4, 4)).shape(), std::exception); - ASSERT_THROW(t(fl::range(0, -5)).shape(), std::exception); + // TODO the following should all return empty tensor, but AF currently doesn't + // have a way to represent empty range, and we are just throwing internally. 
+ auto t = fl::rand({5}); + ASSERT_THROW(t(fl::range(5, fl::end)).shape(), std::exception); + ASSERT_THROW(t(fl::range(4, -1)).shape(), std::exception); + ASSERT_THROW(t(fl::range(0, 0)).shape(), std::exception); + ASSERT_THROW(t(fl::range(1, 1)).shape(), std::exception); + ASSERT_THROW(t(fl::range(4, 4)).shape(), std::exception); + ASSERT_THROW(t(fl::range(0, -5)).shape(), std::exception); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/tensor/af/CachingMemoryManagerTest.cpp b/flashlight/fl/test/tensor/af/CachingMemoryManagerTest.cpp index 8c41f82..2414a14 100644 --- a/flashlight/fl/test/tensor/af/CachingMemoryManagerTest.cpp +++ b/flashlight/fl/test/tensor/af/CachingMemoryManagerTest.cpp @@ -20,186 +20,189 @@ #include "flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.h" class CachingMemoryManagerTest : public ::testing::Test { - protected: - virtual void SetUp() override { - deviceInterface_ = std::make_shared(); - adapter_ = std::make_shared( - af::getDeviceCount(), deviceInterface_); - installer_ = std::make_unique(adapter_); - installer_->setAsMemoryManager(); - } - - virtual void TearDown() override { - af_unset_memory_manager(); - } - - std::shared_ptr deviceInterface_; - std::shared_ptr adapter_; - std::unique_ptr installer_; +protected: + virtual void SetUp() override { + deviceInterface_ = std::make_shared(); + adapter_ = std::make_shared( + af::getDeviceCount(), + deviceInterface_ + ); + installer_ = std::make_unique(adapter_); + installer_->setAsMemoryManager(); + } + + virtual void TearDown() override { + af_unset_memory_manager(); + } + + std::shared_ptr deviceInterface_; + std::shared_ptr adapter_; + std::unique_ptr installer_; }; TEST_F(CachingMemoryManagerTest, BasicOps) { - // This test checks if the basic math operations like additions, - // 
multiplication, division are performed correctly - - const int nx = 8; - const int ny = 8; - af::array in1 = af::constant(2.0, nx, ny); - af::array in2 = af::constant(3.0, nx, ny); - ASSERT_TRUE(af::allTrue(in1 + in2 == 5)); - - // NOTE: includes JIT ops - af::array in3 = in1 * in2; - af::array in4 = in3 / in2; - af::array in5 = in4 * in2; - ASSERT_TRUE(af::allTrue(in3 == in5)); + // This test checks if the basic math operations like additions, + // multiplication, division are performed correctly + + const int nx = 8; + const int ny = 8; + af::array in1 = af::constant(2.0, nx, ny); + af::array in2 = af::constant(3.0, nx, ny); + ASSERT_TRUE(af::allTrue(in1 + in2 == 5)); + + // NOTE: includes JIT ops + af::array in3 = in1 * in2; + af::array in4 = in3 / in2; + af::array in5 = in4 * in2; + ASSERT_TRUE(af::allTrue(in3 == in5)); } TEST_F(CachingMemoryManagerTest, DevicePtr) { - // This test checks whether device pointer API works for the arrays - // The CPU backend in AF allocates a buffer for empty arrays - see - // https://github.com/arrayfire/arrayfire/issues/3058. When this is fixed, - // this can be relaxed. - if (FL_BACKEND_CPU) { - GTEST_SKIP() << "ArrayFire CPU backend allocates buffers for empty arrays"; - } - - // Empty array - auto arr1 = af::array(0, 0, 0, 0, af::dtype::f32); - auto* ptr1 = arr1.device(); - ASSERT_EQ(ptr1, nullptr); - arr1.unlock(); - - // Non-Empty array - auto arr2 = af::array(10, 8, 9, 23, af::dtype::f32); - auto* ptr2 = arr2.device(); - ASSERT_NE(ptr2, nullptr); - arr2.unlock(); + // This test checks whether device pointer API works for the arrays + // The CPU backend in AF allocates a buffer for empty arrays - see + // https://github.com/arrayfire/arrayfire/issues/3058. When this is fixed, + // this can be relaxed. 
+ if(FL_BACKEND_CPU) { + GTEST_SKIP() << "ArrayFire CPU backend allocates buffers for empty arrays"; + } + + // Empty array + auto arr1 = af::array(0, 0, 0, 0, af::dtype::f32); + auto* ptr1 = arr1.device(); + ASSERT_EQ(ptr1, nullptr); + arr1.unlock(); + + // Non-Empty array + auto arr2 = af::array(10, 8, 9, 23, af::dtype::f32); + auto* ptr2 = arr2.device(); + ASSERT_NE(ptr2, nullptr); + arr2.unlock(); } TEST_F(CachingMemoryManagerTest, IndexedDevice) { - // This test is checking to see if calling `.device()` will force copy to a - // new buffer unlike `getRawPtr()`. It is required to copy as the the memory - // manager releases the lock on the array after calling `.device()` - const int nx = 8; - const int ny = 8; + // This test is checking to see if calling `.device()` will force copy to a + // new buffer unlike `getRawPtr()`. It is required to copy as the the memory + // manager releases the lock on the array after calling `.device()` + const int nx = 8; + const int ny = 8; - af::array in = af::randu(nx, ny); + af::array in = af::randu(nx, ny); - std::vector in1(in.elements()); - in.host(in1.data()); + std::vector in1(in.elements()); + in.host(in1.data()); - int offx = nx / 4; - int offy = ny / 4; + int offx = nx / 4; + int offy = ny / 4; - in = in(af::seq(offx, offx + nx / 2 - 1), af::seq(offy, offy + ny / 2 - 1)); + in = in(af::seq(offx, offx + nx / 2 - 1), af::seq(offy, offy + ny / 2 - 1)); - int nxo = static_cast(in.dims(0)); - int nyo = static_cast(in.dims(1)); + int nxo = static_cast(in.dims(0)); + int nyo = static_cast(in.dims(1)); - void* rawPtr = af::getRawPtr(in); - void* devPtr = in.device(); - ASSERT_NE(devPtr, rawPtr); - in.unlock(); + void* rawPtr = af::getRawPtr(in); + void* devPtr = in.device(); + ASSERT_NE(devPtr, rawPtr); + in.unlock(); - std::vector in2(in.elements()); - in.host(in2.data()); + std::vector in2(in.elements()); + in.host(in2.data()); - for (int y = 0; y < nyo; y++) { - for (int x = 0; x < nxo; x++) { - ASSERT_EQ(in1[(offy + y) * 
nx + offx + x], in2[y * nxo + x]); + for(int y = 0; y < nyo; y++) { + for(int x = 0; x < nxo; x++) { + ASSERT_EQ(in1[(offy + y) * nx + offx + x], in2[y * nxo + x]); + } } - } } TEST_F(CachingMemoryManagerTest, LargeNumberOfAllocs) { - GTEST_SKIP() << "Causes spurious OOMs even with exception handling."; - // This test performs stress test to allocate and free a large of number of - // array of variable sizes - - af::array a; - for (int i = 0; i < 5000; ++i) { - auto dimsArr = (af::randu(4, af::dtype::s32)) % 100 + 100; - std::vector dims(4); - dimsArr.as(af::dtype::s32).host(dims.data()); - EXPECT_NO_THROW(a = af::array(dims[0], dims[1], dims[2], dims[3])); - } + GTEST_SKIP() << "Causes spurious OOMs even with exception handling."; + // This test performs stress test to allocate and free a large of number of + // array of variable sizes + + af::array a; + for(int i = 0; i < 5000; ++i) { + auto dimsArr = (af::randu(4, af::dtype::s32)) % 100 + 100; + std::vector dims(4); + dimsArr.as(af::dtype::s32).host(dims.data()); + EXPECT_NO_THROW(a = af::array(dims[0], dims[1], dims[2], dims[3])); + } } TEST_F(CachingMemoryManagerTest, OOM) { - GTEST_SKIP() << "Causes spurious OOMs even with exception handling."; - af_backend b; - af_get_active_backend(&b); - // Despite that test is trying to allocate PB of memory, - // depending on the drivers, afopencl does not seem to guarantee to send an - // OOM signal. https://github.com/arrayfire/arrayfire/issues/2650 At the - // moment, skipping afopencl. - if (b == AF_BACKEND_OPENCL) { - GTEST_SKIP() << "Can't run test with the ArrayFire OpenCL backend"; -} - af::array a; - // N^3 tensor means about 3PB: expected to OOM on today's cuda GPU. - const unsigned N = 99999; - try { - a = af::randu({N, N, N}, f32); - } catch (af::exception& ex) { - ASSERT_EQ(ex.err(), AF_ERR_NO_MEM); - } catch (...) 
{ - EXPECT_TRUE(false) << "CachingMemoryManagerTest OOM: unexpected exception"; - } + GTEST_SKIP() << "Causes spurious OOMs even with exception handling."; + af_backend b; + af_get_active_backend(&b); + // Despite that test is trying to allocate PB of memory, + // depending on the drivers, afopencl does not seem to guarantee to send an + // OOM signal. https://github.com/arrayfire/arrayfire/issues/2650 At the + // moment, skipping afopencl. + if(b == AF_BACKEND_OPENCL) { + GTEST_SKIP() << "Can't run test with the ArrayFire OpenCL backend"; + } + af::array a; + // N^3 tensor means about 3PB: expected to OOM on today's cuda GPU. + const unsigned N = 99999; + try { + a = af::randu({N, N, N}, f32); + } catch(af::exception& ex) { + ASSERT_EQ(ex.err(), AF_ERR_NO_MEM); + } catch(...) { + EXPECT_TRUE(false) << "CachingMemoryManagerTest OOM: unexpected exception"; + } } void testFragmentation( std::shared_ptr deviceInterface_, std::shared_ptr adapter_, - bool expectOOM) { - af::Backend b = af::getActiveBackend(); + bool expectOOM +) { + af::Backend b = af::getActiveBackend(); - if (b != AF_BACKEND_CUDA) { - GTEST_SKIP() + if(b != AF_BACKEND_CUDA) { + GTEST_SKIP() << "CachingMemoryManager fragmentation tests require CUDA backend"; - } - - const auto mms = deviceInterface_->getMaxMemorySize(0); - const auto maxNumf32 = mms / sizeof(float); // AF f32 is supposed to be 32b - ASSERT_NE(mms, 0); - { - af::array a1(.5f * maxNumf32); - adapter_->printInfo("After creating a1:", 0); - } // The a1 buffer will not be freed here, just registered to the cache - adapter_->printInfo("After releasing a1:", 0); - - af::array a2(.1f * maxNumf32); - adapter_->printInfo("After creating a2:", 0); - - af::array a3; - try { - a3 = af::array(.5f * maxNumf32); - } catch (af::exception& ex) { - if (expectOOM) { - ASSERT_EQ(ex.err(), AF_ERR_NO_MEM); - } else { - EXPECT_TRUE(false) - << "CachingMemoryManagerTest fragmentaiton not supposed to throw: " - << ex.what(); } - } + + const auto mms = 
deviceInterface_->getMaxMemorySize(0); + const auto maxNumf32 = mms / sizeof(float); // AF f32 is supposed to be 32b + ASSERT_NE(mms, 0); + { + af::array a1(.5f * maxNumf32); + adapter_->printInfo("After creating a1:", 0); + } // The a1 buffer will not be freed here, just registered to the cache + adapter_->printInfo("After releasing a1:", 0); + + af::array a2(.1f * maxNumf32); + adapter_->printInfo("After creating a2:", 0); + + af::array a3; + try { + a3 = af::array(.5f * maxNumf32); + } catch(af::exception& ex) { + if(expectOOM) { + ASSERT_EQ(ex.err(), AF_ERR_NO_MEM); + } else { + EXPECT_TRUE(false) + << "CachingMemoryManagerTest fragmentaiton not supposed to throw: " + << ex.what(); + } + } } TEST_F(CachingMemoryManagerTest, Fragmentation) { - GTEST_SKIP() << "Causes spurious OOMs even with exception handling."; - testFragmentation(deviceInterface_, adapter_, true); // should OOM + GTEST_SKIP() << "Causes spurious OOMs even with exception handling."; + testFragmentation(deviceInterface_, adapter_, true); // should OOM } TEST_F(CachingMemoryManagerTest, RecLimit) { - constexpr static size_t ONE_GB = 1 << 30; - // Fine set the manager in order not to recycle big tensors: - adapter_->setRecyclingSizeLimit(2 * ONE_GB); - testFragmentation(deviceInterface_, adapter_, false); // should not OOM + constexpr static size_t ONE_GB = 1 << 30; + // Fine set the manager in order not to recycle big tensors: + adapter_->setRecyclingSizeLimit(2 * ONE_GB); + testFragmentation(deviceInterface_, adapter_, false); // should not OOM } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/tensor/af/MemoryFrameworkTest.cpp b/flashlight/fl/test/tensor/af/MemoryFrameworkTest.cpp index 18636c8..ccc1f32 100644 --- a/flashlight/fl/test/tensor/af/MemoryFrameworkTest.cpp +++ 
b/flashlight/fl/test/tensor/af/MemoryFrameworkTest.cpp @@ -39,138 +39,140 @@ namespace { * state. */ class TestMemoryManager : public MemoryManagerAdapter { - public: - TestMemoryManager( - std::shared_ptr deviceInterface, - std::ostream* logStream) - : MemoryManagerAdapter(deviceInterface, logStream) {} - - void initialize() override {} - - void shutdown() override {} - - void* alloc( - bool userLock, - const unsigned ndims, - dim_t* dims, - const unsigned elSize) override { - size_t size = elSize; - for (unsigned i = 0; i < ndims; ++i) { - size *= dims[i]; +public: + TestMemoryManager( + std::shared_ptr deviceInterface, + std::ostream* logStream + ) : MemoryManagerAdapter(deviceInterface, logStream) {} + + void initialize() override {} + + void shutdown() override {} + + void* alloc( + bool userLock, + const unsigned ndims, + dim_t* dims, + const unsigned elSize + ) override { + size_t size = elSize; + for(unsigned i = 0; i < ndims; ++i) { + size *= dims[i]; + } + void* ptr = nullptr; + + if(size > 0) { + if(lockedBytes >= maxBytes || totalBytes >= maxBuffers) { + signalMemoryCleanup(); + } + + ptr = this->deviceInterface->nativeAlloc(size); + lockedPtrToSizeMap[ptr] = size; + totalBytes += size; + totalBuffers++; + + // Simple implementation: treat user and AF allocations the same + locked.insert(ptr); + lockedBytes += size; + + lastDims = af::dim4(ndims, dims); + } + return ptr; } - void* ptr = nullptr; - if (size > 0) { - if (lockedBytes >= maxBytes || totalBytes >= maxBuffers) { - signalMemoryCleanup(); - } + size_t allocated(void* ptr) override { + if(lockedPtrToSizeMap.find(ptr) == lockedPtrToSizeMap.end()) { + return 0; + } else { + return lockedPtrToSizeMap[ptr]; + } + } - ptr = this->deviceInterface->nativeAlloc(size); - lockedPtrToSizeMap[ptr] = size; - totalBytes += size; - totalBuffers++; + void unlock(void* ptr, bool userLock) override { + if(!ptr) { + return; + } - // Simple implementation: treat user and AF allocations the same - 
locked.insert(ptr); - lockedBytes += size; + if(lockedPtrToSizeMap.find(ptr) == lockedPtrToSizeMap.end()) { + return; + } - lastDims = af::dim4(ndims, dims); + // For testing, treat user-allocated and AF-allocated memory identically + if(locked.find(ptr) != locked.end()) { + locked.erase(ptr); + lockedBytes -= lockedPtrToSizeMap[ptr]; + } } - return ptr; - } - - size_t allocated(void* ptr) override { - if (lockedPtrToSizeMap.find(ptr) == lockedPtrToSizeMap.end()) { - return 0; - } else { - return lockedPtrToSizeMap[ptr]; + + void signalMemoryCleanup() override { + // Free unlocked memory + std::vector freed; + for(auto& entry : lockedPtrToSizeMap) { + if(!isUserLocked(entry.first)) { + void* ptr = entry.first; + this->deviceInterface->nativeFree(ptr); + totalBytes -= lockedPtrToSizeMap[entry.first]; + freed.push_back(entry.first); + } + } + for(auto ptr : freed) { + lockedPtrToSizeMap.erase(ptr); + } } - } - void unlock(void* ptr, bool userLock) override { - if (!ptr) { - return; + void printInfo( + const char* /* msg */, + const int /* device */, + std::ostream* /* ostream */ + ) override {} + + void userLock(const void* cPtr) override { + void* ptr = const_cast(cPtr); + if(locked.find(ptr) == locked.end()) { + locked.insert(ptr); + lockedBytes += lockedPtrToSizeMap[ptr]; + } } - if (lockedPtrToSizeMap.find(ptr) == lockedPtrToSizeMap.end()) { - return; + void userUnlock(const void* cPtr) override { + void* ptr = const_cast(cPtr); + unlock(ptr, /* user */ true); + lockedBytes -= lockedPtrToSizeMap[ptr]; } - // For testing, treat user-allocated and AF-allocated memory identically - if (locked.find(ptr) != locked.end()) { - locked.erase(ptr); - lockedBytes -= lockedPtrToSizeMap[ptr]; + bool isUserLocked(const void* ptr) override { + return locked.find(const_cast(ptr)) != locked.end(); } - } - - void signalMemoryCleanup() override { - // Free unlocked memory - std::vector freed; - for (auto& entry : lockedPtrToSizeMap) { - if (!isUserLocked(entry.first)) { - void* ptr 
= entry.first; - this->deviceInterface->nativeFree(ptr); - totalBytes -= lockedPtrToSizeMap[entry.first]; - freed.push_back(entry.first); - } + + float getMemoryPressure() override { + if(lockedBytes > maxBytes || totalBuffers > maxBuffers) { + return 1.0; + } else { + return 0.0; + } } - for (auto ptr : freed) { - lockedPtrToSizeMap.erase(ptr); + + bool jitTreeExceedsMemoryPressure(size_t bytes) override { + return 2 * bytes > lockedBytes; } - } - - void printInfo( - const char* /* msg */, - const int /* device */, - std::ostream* /* ostream */) override {} - - void userLock(const void* cPtr) override { - void* ptr = const_cast(cPtr); - if (locked.find(ptr) == locked.end()) { - locked.insert(ptr); - lockedBytes += lockedPtrToSizeMap[ptr]; + + void addMemoryManagement(int device) override { + throw std::logic_error("Not implemented"); } - } - - void userUnlock(const void* cPtr) override { - void* ptr = const_cast(cPtr); - unlock(ptr, /* user */ true); - lockedBytes -= lockedPtrToSizeMap[ptr]; - } - - bool isUserLocked(const void* ptr) override { - return locked.find(const_cast(ptr)) != locked.end(); - } - - float getMemoryPressure() override { - if (lockedBytes > maxBytes || totalBuffers > maxBuffers) { - return 1.0; - } else { - return 0.0; + + void removeMemoryManagement(int device) override { + throw std::logic_error("Not implemented"); } - } - - bool jitTreeExceedsMemoryPressure(size_t bytes) override { - return 2 * bytes > lockedBytes; - } - - void addMemoryManagement(int device) override { - throw std::logic_error("Not implemented"); - } - - void removeMemoryManagement(int device) override { - throw std::logic_error("Not implemented"); - } - - std::unordered_map lockedPtrToSizeMap; - std::unordered_set locked; - size_t totalBytes{0}; - size_t totalBuffers{0}; - size_t lockedBytes{0}; - size_t maxBuffers{64}; - size_t maxBytes{1024}; - // helps test dim_t* argument to alloc - af::dim4 lastDims{0, 0, 0, 0}; + + std::unordered_map lockedPtrToSizeMap; + 
std::unordered_set locked; + size_t totalBytes{0}; + size_t totalBuffers{0}; + size_t lockedBytes{0}; + size_t maxBuffers{64}; + size_t maxBytes{1024}; + // helps test dim_t* argument to alloc + af::dim4 lastDims{0, 0, 0, 0}; }; /** @@ -187,74 +189,117 @@ class TestMemoryManager : public MemoryManagerAdapter { * aren't directly tested. */ class MockTestMemoryManager : public TestMemoryManager { - public: - MockTestMemoryManager( - std::shared_ptr real, - std::shared_ptr deviceInterface, - std::ostream* logStream) - : TestMemoryManager(deviceInterface, logStream), real_(real) { - ON_CALL(*this, initialize()).WillByDefault(Invoke([this]() { - real_->initialize(); - })); - ON_CALL(*this, shutdown()).WillByDefault(Invoke([this]() { - real_->shutdown(); - })); - ON_CALL(*this, alloc(_, _, _, _)) - .WillByDefault(Invoke([this]( - bool userLock, - const unsigned ndims, - dim_t* dims, - const unsigned elSize) { - return real_->alloc(userLock, ndims, dims, elSize); - })); - ON_CALL(*this, allocated(_)).WillByDefault(Invoke([this](void* ptr) { - return real_->allocated(ptr); - })); - ON_CALL(*this, unlock(_, _)) - .WillByDefault(Invoke([this](void* ptr, bool userLock) { - real_->unlock(ptr, userLock); - })); - ON_CALL(*this, signalMemoryCleanup()).WillByDefault(Invoke([this]() { - real_->signalMemoryCleanup(); - })); - ON_CALL(*this, printInfo(_, _, _)) - .WillByDefault(Invoke( - [this](const char* msg, const int device, std::ostream* ostream) { - real_->printInfo(msg, device, ostream); - })); - ON_CALL(*this, userLock(_)).WillByDefault(Invoke([this](const void* cPtr) { - real_->userLock(cPtr); - })); - ON_CALL(*this, userUnlock(_)) +public: + MockTestMemoryManager( + std::shared_ptr real, + std::shared_ptr deviceInterface, + std::ostream* logStream + ) : TestMemoryManager(deviceInterface, logStream), + real_(real) { + ON_CALL(*this, initialize()).WillByDefault( + Invoke( + [this]() { + real_->initialize(); + } + ) + ); + ON_CALL(*this, shutdown()).WillByDefault( + Invoke( + 
[this]() { + real_->shutdown(); + } + ) + ); + ON_CALL(*this, alloc(_, _, _, _)) + .WillByDefault( + Invoke( + [this]( + bool userLock, + const unsigned ndims, + dim_t* dims, + const unsigned elSize) { + return real_->alloc(userLock, ndims, dims, elSize); + } + ) + ); + ON_CALL(*this, allocated(_)).WillByDefault( + Invoke( + [this](void* ptr) { + return real_->allocated(ptr); + } + ) + ); + ON_CALL(*this, unlock(_, _)) + .WillByDefault( + Invoke( + [this](void* ptr, bool userLock) { + real_->unlock(ptr, userLock); + } + ) + ); + ON_CALL(*this, signalMemoryCleanup()).WillByDefault( + Invoke( + [this]() { + real_->signalMemoryCleanup(); + } + ) + ); + ON_CALL(*this, printInfo(_, _, _)) + .WillByDefault( + Invoke( + [this](const char* msg, const int device, std::ostream* ostream) { + real_->printInfo(msg, device, ostream); + } + ) + ); + ON_CALL(*this, userLock(_)).WillByDefault( + Invoke( + [this](const void* cPtr) { + real_->userLock(cPtr); + } + ) + ); + ON_CALL(*this, userUnlock(_)) .WillByDefault( - Invoke([this](const void* cPtr) { real_->userUnlock(cPtr); })); - ON_CALL(*this, isUserLocked(_)) - .WillByDefault(Invoke( - [this](const void* cPtr) { return real_->isUserLocked(cPtr); })); - ON_CALL(*this, getMemoryPressure()).WillByDefault(Invoke([this]() { - return real_->getMemoryPressure(); - })); - ON_CALL(*this, jitTreeExceedsMemoryPressure(_)) - .WillByDefault(Invoke([this](size_t bytes) { - return real_->jitTreeExceedsMemoryPressure(bytes); - })); - } - - MOCK_METHOD(void, initialize, ()); - MOCK_METHOD(void, shutdown, ()); - MOCK_METHOD(void*, alloc, (bool, const unsigned, dim_t*, const unsigned)); - MOCK_METHOD(size_t, allocated, (void*)); - MOCK_METHOD(void, unlock, (void*, bool)); - MOCK_METHOD(void, signalMemoryCleanup, ()); - MOCK_METHOD(void, printInfo, (const char*, const int, std::ostream*)); - MOCK_METHOD(void, userLock, (const void*)); - MOCK_METHOD(void, userUnlock, (const void*)); - MOCK_METHOD(bool, isUserLocked, (const void*)); - 
MOCK_METHOD(float, getMemoryPressure, ()); - MOCK_METHOD(bool, jitTreeExceedsMemoryPressure, (size_t)); - - private: - std::shared_ptr real_; + Invoke([this](const void* cPtr) { real_->userUnlock(cPtr); }) + ); + ON_CALL(*this, isUserLocked(_)) + .WillByDefault( + Invoke( + [this](const void* cPtr) { return real_->isUserLocked(cPtr); }) + ); + ON_CALL(*this, getMemoryPressure()).WillByDefault( + Invoke( + [this]() { + return real_->getMemoryPressure(); + } + ) + ); + ON_CALL(*this, jitTreeExceedsMemoryPressure(_)) + .WillByDefault( + Invoke( + [this](size_t bytes) { + return real_->jitTreeExceedsMemoryPressure(bytes); + } + ) + ); + } + + MOCK_METHOD(void, initialize, ()); + MOCK_METHOD(void, shutdown, ()); + MOCK_METHOD(void*, alloc, (bool, const unsigned, dim_t*, const unsigned)); + MOCK_METHOD(size_t, allocated, (void*)); + MOCK_METHOD(void, unlock, (void*, bool)); + MOCK_METHOD(void, signalMemoryCleanup, ()); + MOCK_METHOD(void, printInfo, (const char*, const int, std::ostream*)); + MOCK_METHOD(void, userLock, (const void*)); + MOCK_METHOD(void, userUnlock, (const void*)); + MOCK_METHOD(bool, isUserLocked, (const void*)); + MOCK_METHOD(float, getMemoryPressure, ()); + MOCK_METHOD(bool, jitTreeExceedsMemoryPressure, (size_t)); + +private: + std::shared_ptr real_; }; } // namespace @@ -277,171 +322,178 @@ class MockTestMemoryManager : public TestMemoryManager { * memory manager */ TEST(MemoryFramework, AdapterInstallerDeviceInterfaceTest) { - // The CPU backend in AF allocates a buffer for empty arrays - see - // https://github.com/arrayfire/arrayfire/issues/3058. 
When this is fixed, - // this can be relaxed/this test will pass - if (FL_BACKEND_CPU) { - GTEST_SKIP() << "ArrayFire CPU backend allocates buffers for empty arrays"; - } - - std::stringstream logStream; - std::stringstream mockLogStream; - { - auto deviceInterface = std::make_shared(); - - auto memoryManager = - std::make_shared(deviceInterface, &logStream); - auto mockMemoryManager = std::make_shared( - memoryManager, deviceInterface, &mockLogStream); - - auto installer = - std::make_unique(mockMemoryManager); - // initialize should only be called once while the custom memory - // manager was set - EXPECT_CALL(*mockMemoryManager, initialize()).Times(Exactly(1)); - installer->setAsMemoryManager(); - - // flush the mock log stream every two lines - mockMemoryManager->setLogFlushInterval(2); - - { - // Do some sample allocations using `af::alloc` (which allocates - // user-locked memory) and `af::randu`, which calls the `af::array` - // constructor to allocate memory. - EXPECT_CALL(*mockMemoryManager, alloc(/* user lock */ true, 1, _, 1)) - .Times(Exactly(1)); - size_t aSize = 8; - void* a = af::allocV2(aSize * af::getSizeOf(af::dtype::f32)); - // Allocated memory should properly appear in our internal data structures - // given correct passage of state - EXPECT_EQ(memoryManager->lockedPtrToSizeMap.size(), 1); - EXPECT_EQ(memoryManager->lockedPtrToSizeMap[a], aSize * sizeof(float)); - // Check that the dims for this array are correct. ArrayFire currently - // passes (dims, 1, 1, 1) for all allocations. 
- EXPECT_EQ(memoryManager->lastDims, af::dim4(aSize * sizeof(float))); - - // Check logs which should be flushed to our output stream after 2 ops - std::string log1; - std::getline(mockLogStream, log1); - EXPECT_EQ(log1, "initialize "); - std::string log2; - std::getline(mockLogStream, log2); - EXPECT_EQ( - log2.substr(0, 14), - "nativeAlloc " + std::to_string(aSize * sizeof(float))); - - // Buffer more logs in the default memory manager - memoryManager->setLogFlushInterval(50); - - // Allocate an `af::array`, which won't be user locked, and has - // information about array size passed to alloc - dim_t bDim = 2; - EXPECT_CALL( - *mockMemoryManager, - alloc(/* user lock */ false, 1, _, sizeof(float))); - af::array b = af::randu({bDim, bDim}); - // Again, allocated should properly appear in our internal data structures - // given correct passage of state - EXPECT_EQ(memoryManager->totalBytes, aSize * sizeof(float) + b.bytes()); - EXPECT_EQ(memoryManager->totalBuffers, 2); - // Our array is locked, but not user locked - EXPECT_EQ(memoryManager->lockedBytes, aSize * sizeof(float) + b.bytes()); - EXPECT_EQ(memoryManager->locked.size(), 2); - // Check that the dims for this array are correct - EXPECT_EQ(memoryManager->lastDims, af::dim4(bDim * b.numdims())); - - // Free user-locked memory. 
Check that freeing memory properly calls - // unlock with user-locked memory (since we used af::alloc) - EXPECT_CALL(*mockMemoryManager, unlock(a, /* user lock */ true)) - .Times(Exactly(1)); - af::freeV2(a); - // Internal data structures should be updated accordingly to reflect - // removal of a buffer - EXPECT_EQ(memoryManager->totalBytes, aSize * sizeof(float) + b.bytes()); - EXPECT_EQ(memoryManager->totalBuffers, 2); - EXPECT_EQ(memoryManager->lockedBytes, b.bytes()); - EXPECT_EQ(memoryManager->locked.size(), 1); - - // af::array b is out of scope, which is not user-locked memory - EXPECT_CALL(*mockMemoryManager, unlock(_, /* user lock */ false)) - .Times(Exactly(1)); + // The CPU backend in AF allocates a buffer for empty arrays - see + // https://github.com/arrayfire/arrayfire/issues/3058. When this is fixed, + // this can be relaxed/this test will pass + if(FL_BACKEND_CPU) { + GTEST_SKIP() << "ArrayFire CPU backend allocates buffers for empty arrays"; } - // Memory reset calls signalMemoryCleanup() and clears the map - EXPECT_CALL(*mockMemoryManager, signalMemoryCleanup()).Times(Exactly(1)); - af::deviceGC(); - EXPECT_TRUE(memoryManager->lockedPtrToSizeMap.empty()); - - // printInfo - const std::string printInfoMsg = "testPrintInfo"; - int printInfoDeviceId = 0; - EXPECT_CALL( - *mockMemoryManager, - printInfo( - printInfoMsg.c_str(), - printInfoDeviceId, - mockMemoryManager->getLogStream())) + std::stringstream logStream; + std::stringstream mockLogStream; + { + auto deviceInterface = std::make_shared(); + + auto memoryManager = + std::make_shared(deviceInterface, &logStream); + auto mockMemoryManager = std::make_shared( + memoryManager, + deviceInterface, + &mockLogStream + ); + + auto installer = + std::make_unique(mockMemoryManager); + // initialize should only be called once while the custom memory + // manager was set + EXPECT_CALL(*mockMemoryManager, initialize()).Times(Exactly(1)); + installer->setAsMemoryManager(); + + // flush the mock log stream 
every two lines + mockMemoryManager->setLogFlushInterval(2); + + { + // Do some sample allocations using `af::alloc` (which allocates + // user-locked memory) and `af::randu`, which calls the `af::array` + // constructor to allocate memory. + EXPECT_CALL(*mockMemoryManager, alloc(/* user lock */ true, 1, _, 1)) + .Times(Exactly(1)); + size_t aSize = 8; + void* a = af::allocV2(aSize * af::getSizeOf(af::dtype::f32)); + // Allocated memory should properly appear in our internal data structures + // given correct passage of state + EXPECT_EQ(memoryManager->lockedPtrToSizeMap.size(), 1); + EXPECT_EQ(memoryManager->lockedPtrToSizeMap[a], aSize * sizeof(float)); + // Check that the dims for this array are correct. ArrayFire currently + // passes (dims, 1, 1, 1) for all allocations. + EXPECT_EQ(memoryManager->lastDims, af::dim4(aSize * sizeof(float))); + + // Check logs which should be flushed to our output stream after 2 ops + std::string log1; + std::getline(mockLogStream, log1); + EXPECT_EQ(log1, "initialize "); + std::string log2; + std::getline(mockLogStream, log2); + EXPECT_EQ( + log2.substr(0, 14), + "nativeAlloc " + std::to_string(aSize * sizeof(float)) + ); + + // Buffer more logs in the default memory manager + memoryManager->setLogFlushInterval(50); + + // Allocate an `af::array`, which won't be user locked, and has + // information about array size passed to alloc + dim_t bDim = 2; + EXPECT_CALL( + *mockMemoryManager, + alloc(/* user lock */ false, 1, _, sizeof(float)) + ); + af::array b = af::randu({bDim, bDim}); + // Again, allocated should properly appear in our internal data structures + // given correct passage of state + EXPECT_EQ(memoryManager->totalBytes, aSize * sizeof(float) + b.bytes()); + EXPECT_EQ(memoryManager->totalBuffers, 2); + // Our array is locked, but not user locked + EXPECT_EQ(memoryManager->lockedBytes, aSize * sizeof(float) + b.bytes()); + EXPECT_EQ(memoryManager->locked.size(), 2); + // Check that the dims for this array are correct + 
EXPECT_EQ(memoryManager->lastDims, af::dim4(bDim * b.numdims())); + + // Free user-locked memory. Check that freeing memory properly calls + // unlock with user-locked memory (since we used af::alloc) + EXPECT_CALL(*mockMemoryManager, unlock(a, /* user lock */ true)) + .Times(Exactly(1)); + af::freeV2(a); + // Internal data structures should be updated accordingly to reflect + // removal of a buffer + EXPECT_EQ(memoryManager->totalBytes, aSize * sizeof(float) + b.bytes()); + EXPECT_EQ(memoryManager->totalBuffers, 2); + EXPECT_EQ(memoryManager->lockedBytes, b.bytes()); + EXPECT_EQ(memoryManager->locked.size(), 1); + + // af::array b is out of scope, which is not user-locked memory + EXPECT_CALL(*mockMemoryManager, unlock(_, /* user lock */ false)) + .Times(Exactly(1)); + } + + // Memory reset calls signalMemoryCleanup() and clears the map + EXPECT_CALL(*mockMemoryManager, signalMemoryCleanup()).Times(Exactly(1)); + af::deviceGC(); + EXPECT_TRUE(memoryManager->lockedPtrToSizeMap.empty()); + + // printInfo + const std::string printInfoMsg = "testPrintInfo"; + int printInfoDeviceId = 0; + EXPECT_CALL( + *mockMemoryManager, + printInfo( + printInfoMsg.c_str(), + printInfoDeviceId, + mockMemoryManager->getLogStream() + ) + ) .Times(Exactly(1)); - af::printMemInfo(printInfoMsg.c_str(), printInfoDeviceId); - - // all allocations are either freed or out of scope - check that the map is - // empty - EXPECT_TRUE(memoryManager->lockedPtrToSizeMap.empty()); - // reset to default memory manager - // shutdown is called for each device with that current device set - EXPECT_CALL(*mockMemoryManager, shutdown()) + af::printMemInfo(printInfoMsg.c_str(), printInfoDeviceId); + + // all allocations are either freed or out of scope - check that the map is + // empty + EXPECT_TRUE(memoryManager->lockedPtrToSizeMap.empty()); + // reset to default memory manager + // shutdown is called for each device with that current device set + EXPECT_CALL(*mockMemoryManager, shutdown()) 
.Times(Exactly(af::getDeviceCount())); - MemoryManagerInstaller::unsetMemoryManager(); - // Test that unsetting a memory manager via the global singleton restores - // the default ArrayFire memory manager - auto* manager = MemoryManagerInstaller::currentlyInstalledMemoryManager(); - ASSERT_EQ(manager, nullptr); - - // Any allocations made should not call the custom memory manager since - // we've called `MemoryManagerInstaller::unsetMemoryManager()` above, which - // restores the default memory manager as the primary memory manager. - EXPECT_CALL(*mockMemoryManager, alloc(_, _, _, _)).Times(Exactly(0)); - EXPECT_CALL(*mockMemoryManager, unlock(_, _)).Times(Exactly(0)); + MemoryManagerInstaller::unsetMemoryManager(); + // Test that unsetting a memory manager via the global singleton restores + // the default ArrayFire memory manager + auto* manager = MemoryManagerInstaller::currentlyInstalledMemoryManager(); + ASSERT_EQ(manager, nullptr); + + // Any allocations made should not call the custom memory manager since + // we've called `MemoryManagerInstaller::unsetMemoryManager()` above, which + // restores the default memory manager as the primary memory manager. 
+ EXPECT_CALL(*mockMemoryManager, alloc(_, _, _, _)).Times(Exactly(0)); + EXPECT_CALL(*mockMemoryManager, unlock(_, _)).Times(Exactly(0)); + dim_t cDim = 4; + size_t pSize = 8; + const af::dtype type = af::dtype::f32; + auto c = af::randu({cDim, cDim}, type); + void* p = af::allocV2(pSize * af::getSizeOf(type)); + af::freeV2(p); + } + // The custom memory is destroyed; check that the log stream, which is flushed + // on destruction, contains the correct output + std::vector expectedLinePrefixes = { + "initialize", + "nativeAlloc", + "alloc", + "nativeAlloc", + "alloc", + "unlock", + "unlock", + "signalMemoryCleanup", + "nativeFree", + "nativeFree", + "shutdown", + "shutdown"}; + size_t idx = 0; + for(std::string line; std::getline(logStream, line);) { + EXPECT_EQ(line.substr(0, line.find(' ')), expectedLinePrefixes[idx]); + idx++; + } + + // Test that normal allocations work now that the custom memory manager has + // been destroyed and its function pointers and closures invalidated dim_t cDim = 4; size_t pSize = 8; const af::dtype type = af::dtype::f32; auto c = af::randu({cDim, cDim}, type); void* p = af::allocV2(pSize * af::getSizeOf(type)); af::freeV2(p); - } - // The custom memory is destroyed; check that the log stream, which is flushed - // on destruction, contains the correct output - std::vector expectedLinePrefixes = { - "initialize", - "nativeAlloc", - "alloc", - "nativeAlloc", - "alloc", - "unlock", - "unlock", - "signalMemoryCleanup", - "nativeFree", - "nativeFree", - "shutdown", - "shutdown"}; - size_t idx = 0; - for (std::string line; std::getline(logStream, line);) { - EXPECT_EQ(line.substr(0, line.find(' ')), expectedLinePrefixes[idx]); - idx++; - } - - // Test that normal allocations work now that the custom memory manager has - // been destroyed and its function pointers and closures invalidated - dim_t cDim = 4; - size_t pSize = 8; - const af::dtype type = af::dtype::f32; - auto c = af::randu({cDim, cDim}, type); - void* p = af::allocV2(pSize * 
af::getSizeOf(type)); - af::freeV2(p); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/fl/test/tensor/af/MemoryInitTest.cpp b/flashlight/fl/test/tensor/af/MemoryInitTest.cpp index 0122f8e..97bdd2d 100644 --- a/flashlight/fl/test/tensor/af/MemoryInitTest.cpp +++ b/flashlight/fl/test/tensor/af/MemoryInitTest.cpp @@ -17,18 +17,18 @@ using namespace fl; TEST(MemoryInitTest, DefaultManagerInitializesCorrectType) { - if (FL_BACKEND_CPU) { - GTEST_SKIP() << "CachingMemoryManager is not used on CPU backend"; - } - auto* manager = MemoryManagerInstaller::currentlyInstalledMemoryManager(); - // A non-null value means that a) a custom memory manager has been installed - // and b) that a CachingMemoryManager has been installed which is the desired - // default behavior. - ASSERT_NE(dynamic_cast(manager), nullptr); + if(FL_BACKEND_CPU) { + GTEST_SKIP() << "CachingMemoryManager is not used on CPU backend"; + } + auto* manager = MemoryManagerInstaller::currentlyInstalledMemoryManager(); + // A non-null value means that a) a custom memory manager has been installed + // and b) that a CachingMemoryManager has been installed which is the desired + // default behavior. 
+ ASSERT_NE(dynamic_cast(manager), nullptr); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/runtime/Runtime.cpp b/flashlight/pkg/runtime/Runtime.cpp index cbc7e37..148b8e8 100644 --- a/flashlight/pkg/runtime/Runtime.cpp +++ b/flashlight/pkg/runtime/Runtime.cpp @@ -16,71 +16,71 @@ namespace fl::pkg::runtime { constexpr size_t kRunFileNameIntWidth = 3; -std::string -getRunFile(const std::string& name, const int runidx, const fs::path& runpath) { - std::stringstream ss; - ss << std::setw(kRunFileNameIntWidth) << std::setfill('0') << runidx << "_" - << name; - return runpath / ss.str(); +std::string getRunFile(const std::string& name, const int runidx, const fs::path& runpath) { + std::stringstream ss; + ss << std::setw(kRunFileNameIntWidth) << std::setfill('0') << runidx << "_" + << name; + return runpath / ss.str(); }; std::string serializeGflags(const std::string& separator) { - std::stringstream serialized; - std::vector allFlags; - gflags::GetAllFlags(&allFlags); - std::string currVal; - for (auto itr = allFlags.begin(); itr != allFlags.end(); ++itr) { - gflags::GetCommandLineOption(itr->name.c_str(), &currVal); - serialized << "--" << itr->name << "=" << currVal << separator; - } - return serialized.str(); + std::stringstream serialized; + std::vector allFlags; + gflags::GetAllFlags(&allFlags); + std::string currVal; + for(auto itr = allFlags.begin(); itr != allFlags.end(); ++itr) { + gflags::GetCommandLineOption(itr->name.c_str(), &currVal); + serialized << "--" << itr->name << "=" << currVal << separator; + } + return serialized.str(); } bool backwardWithScaling( const fl::Variable& loss, std::vector& params, std::shared_ptr dynamicScaler, - std::shared_ptr reducer) { - auto scaledLoss = loss; - if (dynamicScaler) { - scaledLoss = dynamicScaler->scale(loss); - } - - 
scaledLoss.backward(); - if (reducer) { - reducer->finalize(); - } - - if (dynamicScaler) { - if (!dynamicScaler->unscale(params)) { - return false; + std::shared_ptr reducer +) { + auto scaledLoss = loss; + if(dynamicScaler) { + scaledLoss = dynamicScaler->scale(loss); + } + + scaledLoss.backward(); + if(reducer) { + reducer->finalize(); + } + + if(dynamicScaler) { + if(!dynamicScaler->unscale(params)) { + return false; + } + dynamicScaler->update(); } - dynamicScaler->update(); - } - return true; + return true; } std::string getCurrentDate() { - time_t now = time(nullptr); - struct tm tmbuf; - struct tm* tstruct; - tstruct = localtime_r(&now, &tmbuf); - - std::array buf; - strftime(buf.data(), buf.size(), "%Y-%m-%d", tstruct); - return std::string(buf.data()); + time_t now = time(nullptr); + struct tm tmbuf; + struct tm* tstruct; + tstruct = localtime_r(&now, &tmbuf); + + std::array buf; + strftime(buf.data(), buf.size(), "%Y-%m-%d", tstruct); + return std::string(buf.data()); } std::string getCurrentTime() { - time_t now = time(nullptr); - struct tm tmbuf; - struct tm* tstruct; - tstruct = localtime_r(&now, &tmbuf); - - std::array buf; - strftime(buf.data(), buf.size(), "%X", tstruct); - return std::string(buf.data()); + time_t now = time(nullptr); + struct tm tmbuf; + struct tm* tstruct; + tstruct = localtime_r(&now, &tmbuf); + + std::array buf; + strftime(buf.data(), buf.size(), "%X", tstruct); + return std::string(buf.data()); } } // end namespace fl diff --git a/flashlight/pkg/runtime/Runtime.h b/flashlight/pkg/runtime/Runtime.h index b529d63..c866bb1 100644 --- a/flashlight/pkg/runtime/Runtime.h +++ b/flashlight/pkg/runtime/Runtime.h @@ -13,17 +13,16 @@ namespace fl { namespace pkg { -namespace runtime { + namespace runtime { /** * Get a certain checkpoint by `runidx`. 
*/ -std::string -getRunFile(const std::string& name, int runidx, const fs::path& runpath); + std::string getRunFile(const std::string& name, int runidx, const fs::path& runpath); /** * Serialize gflags into a buffer. */ -std::string serializeGflags(const std::string& separator = "\n"); + std::string serializeGflags(const std::string& separator = "\n"); /** * Properly scale the loss for back-propogation. @@ -35,22 +34,23 @@ std::string serializeGflags(const std::string& separator = "\n"); * gradients. * @param[in] reducer - to synchronize gradients in back-propogation. */ -bool backwardWithScaling( - const fl::Variable& loss, - std::vector& params, - std::shared_ptr dynamicScaler, - std::shared_ptr reducer); + bool backwardWithScaling( + const fl::Variable& loss, + std::vector& params, + std::shared_ptr dynamicScaler, + std::shared_ptr reducer + ); /** * Returns the current date as a string */ -std::string getCurrentDate(); + std::string getCurrentDate(); /** * Returns the current time as a string */ -std::string getCurrentTime(); + std::string getCurrentTime(); -} // namespace runtime + } // namespace runtime } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/runtime/amp/DynamicScaler.cpp b/flashlight/pkg/runtime/amp/DynamicScaler.cpp index 56f36a9..d9de6ac 100644 --- a/flashlight/pkg/runtime/amp/DynamicScaler.cpp +++ b/flashlight/pkg/runtime/amp/DynamicScaler.cpp @@ -14,64 +14,64 @@ namespace fl::pkg::runtime { DynamicScaler::DynamicScaler( double initFactor, double maxFactor, - unsigned int updateInterval) - : scaleFactor_(initFactor), - maxScaleFactor_(maxFactor), - updateInterval_(updateInterval) {} + unsigned int updateInterval +) : scaleFactor_(initFactor), + maxScaleFactor_(maxFactor), + updateInterval_(updateInterval) {} fl::Variable DynamicScaler::scale(const fl::Variable& loss) { - // Force casting to fp32 to avoid overflow in scaling. 
- auto scaledLoss = loss.astype(fl::dtype::f32); - scaledLoss = scaledLoss * scaleFactor_; - return scaledLoss; + // Force casting to fp32 to avoid overflow in scaling. + auto scaledLoss = loss.astype(fl::dtype::f32); + scaledLoss = scaledLoss * scaleFactor_; + return scaledLoss; } bool DynamicScaler::unscale(std::vector& params) { - for (auto& p : params) { - if (!p.isGradAvailable()) { - // Add a dummy grad for params not used in the backwards pass - p.addGrad(Variable(fl::full(p.shape(), 0., p.type()), false)); + for(auto& p : params) { + if(!p.isGradAvailable()) { + // Add a dummy grad for params not used in the backwards pass + p.addGrad(Variable(fl::full(p.shape(), 0., p.type()), false)); + } + p.grad() = p.grad() / scaleFactor_; + if(fl::isInvalidArray(p.grad().tensor())) { + if(scaleFactor_ >= fl::kAmpMinimumScaleFactorValue) { + scaleFactor_ = scaleFactor_ / 2.0f; + FL_LOG(LogLevel::INFO) + << "AMP: Scale factor decreased. New value:\t" << scaleFactor_; + } else { + FL_LOG(LogLevel::FATAL) + << "Minimum loss scale reached: " << fl::kAmpMinimumScaleFactorValue + << " with over/underflowing gradients. Lowering the " + << "learning rate, using gradient clipping, or " + << "increasing the batch size can help resolve " + << "loss explosion."; + } + successCounter_ = 0; + return false; + } } - p.grad() = p.grad() / scaleFactor_; - if (fl::isInvalidArray(p.grad().tensor())) { - if (scaleFactor_ >= fl::kAmpMinimumScaleFactorValue) { - scaleFactor_ = scaleFactor_ / 2.0f; - FL_LOG(LogLevel::INFO) - << "AMP: Scale factor decreased. New value:\t" << scaleFactor_; - } else { - FL_LOG(LogLevel::FATAL) - << "Minimum loss scale reached: " << fl::kAmpMinimumScaleFactorValue - << " with over/underflowing gradients. 
Lowering the " - << "learning rate, using gradient clipping, or " - << "increasing the batch size can help resolve " - << "loss explosion."; - } - successCounter_ = 0; - return false; - } - } - ++successCounter_; - return true; + ++successCounter_; + return true; } void DynamicScaler::update() { - if (scaleFactor_ >= maxScaleFactor_) { - return; - } + if(scaleFactor_ >= maxScaleFactor_) { + return; + } - if (scaleFactor_ == updateInterval_) { - scaleFactor_ *= 2; - FL_VLOG(2) << "AMP: Scale factor doubled. New value:\t" << scaleFactor_; - successCounter_ = 0; - } else { - scaleFactor_ += 2; - FL_VLOG(3) << "AMP: Scale factor incremented. New value\t" << scaleFactor_; - } + if(scaleFactor_ == updateInterval_) { + scaleFactor_ *= 2; + FL_VLOG(2) << "AMP: Scale factor doubled. New value:\t" << scaleFactor_; + successCounter_ = 0; + } else { + scaleFactor_ += 2; + FL_VLOG(3) << "AMP: Scale factor incremented. New value\t" << scaleFactor_; + } } double DynamicScaler::getScaleFactor() const { - return scaleFactor_; + return scaleFactor_; } } // namespace fl diff --git a/flashlight/pkg/runtime/amp/DynamicScaler.h b/flashlight/pkg/runtime/amp/DynamicScaler.h index 843a778..a2f162c 100644 --- a/flashlight/pkg/runtime/amp/DynamicScaler.h +++ b/flashlight/pkg/runtime/amp/DynamicScaler.h @@ -12,7 +12,7 @@ namespace fl { namespace pkg { -namespace runtime { + namespace runtime { /** * Dynamically scales up the training loss as well as all the gradients in back @@ -36,48 +36,53 @@ namespace runtime { * opt.step(); * } */ -class DynamicScaler { - public: - DynamicScaler( - double initFactor, - double maxFactor, - unsigned int updateInterval); + class DynamicScaler { + public: + DynamicScaler( + double initFactor, + double maxFactor, + unsigned int updateInterval + ); - /* - * Scale loss before back propagation. - */ - fl::Variable scale(const fl::Variable& loss); + /* + * Scale loss before back propagation. 
+ */ + fl::Variable scale(const fl::Variable& loss); - /* - * Unscale the gradients after back propagation. - * Return false when NAN or INF occurs in gradients and halve the scale - * factor, true otherwise. - */ - bool unscale(std::vector& params); + /* + * Unscale the gradients after back propagation. + * Return false when NAN or INF occurs in gradients and halve the scale + * factor, true otherwise. + */ + bool unscale(std::vector& params); - /* - * Increase scale factor - */ - void update(); + /* + * Increase scale factor + */ + void update(); - /* - * Return the current scale factor - */ - double getScaleFactor() const; + /* + * Return the current scale factor + */ + double getScaleFactor() const; - private: - double scaleFactor_; - // The maximum value of scaleFactor_. - double maxScaleFactor_; - // Number of iterations without changing scaleFactor_. - unsigned int successCounter_{0}; - // Double up the scaleFactor_ when successCounter_ equals updateInterval_. - unsigned int updateInterval_; + private: + double scaleFactor_; + // The maximum value of scaleFactor_. + double maxScaleFactor_; + // Number of iterations without changing scaleFactor_. + unsigned int successCounter_{0}; + // Double up the scaleFactor_ when successCounter_ equals updateInterval_. 
+ unsigned int updateInterval_; - FL_SAVE_LOAD(scaleFactor_, maxScaleFactor_, updateInterval_, successCounter_) - DynamicScaler() = default; -}; + FL_SAVE_LOAD( + scaleFactor_, + maxScaleFactor_, + updateInterval_, + successCounter_ + ) DynamicScaler() = default; + }; -} // namespace runtime + } // namespace runtime } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/runtime/common/DistributedUtils.cpp b/flashlight/pkg/runtime/common/DistributedUtils.cpp index 47215c3..627f4ea 100644 --- a/flashlight/pkg/runtime/common/DistributedUtils.cpp +++ b/flashlight/pkg/runtime/common/DistributedUtils.cpp @@ -15,89 +15,93 @@ void initDistributed( int worldRank, int worldSize, int maxDevicesPerNode, - const std::string& rndvFilepath) { - if (rndvFilepath.empty()) { - distributedInit( - fl::DistributedInit::MPI, - -1, // unused for MPI - -1, // unused for MPI - {{fl::DistributedConstants::kMaxDevicePerNode, - std::to_string(maxDevicesPerNode)}}); - } else { - distributedInit( - fl::DistributedInit::FILE_SYSTEM, - worldRank, - worldSize, - {{fl::DistributedConstants::kMaxDevicePerNode, - std::to_string(maxDevicesPerNode)}, - {fl::DistributedConstants::kFilePath, rndvFilepath}}); - } + const std::string& rndvFilepath +) { + if(rndvFilepath.empty()) { + distributedInit( + fl::DistributedInit::MPI, + -1, // unused for MPI + -1, // unused for MPI + {{fl::DistributedConstants::kMaxDevicePerNode, + std::to_string(maxDevicesPerNode)}} + ); + } else { + distributedInit( + fl::DistributedInit::FILE_SYSTEM, + worldRank, + worldSize, + {{fl::DistributedConstants::kMaxDevicePerNode, + std::to_string(maxDevicesPerNode)}, + {fl::DistributedConstants::kFilePath, rndvFilepath}} + ); + } } Tensor allreduceGet(fl::AverageValueMeter& mtr) { - auto mtrVal = mtr.value(); - mtrVal[0] *= mtrVal[2]; - return Tensor::fromVector(mtrVal); + auto mtrVal = mtr.value(); + mtrVal[0] *= mtrVal[2]; + return Tensor::fromVector(mtrVal); } Tensor allreduceGet(fl::EditDistanceMeter& mtr) { - auto 
mtrVal0 = mtr.value(); - std::vector mtrVal(mtrVal0.begin(), mtrVal0.end()); - return Tensor::fromVector(mtrVal); + auto mtrVal0 = mtr.value(); + std::vector mtrVal(mtrVal0.begin(), mtrVal0.end()); + return Tensor::fromVector(mtrVal); } Tensor allreduceGet(fl::CountMeter& mtr) { - auto mtrVal0 = mtr.value(); - std::vector mtrVal(mtrVal0.begin(), mtrVal0.end()); - return Tensor::fromVector(mtrVal); + auto mtrVal0 = mtr.value(); + std::vector mtrVal(mtrVal0.begin(), mtrVal0.end()); + return Tensor::fromVector(mtrVal); } Tensor allreduceGet(fl::TimeMeter& mtr) { - return fl::full({1}, mtr.value(), fl::dtype::f64); + return fl::full({1}, mtr.value(), fl::dtype::f64); } Tensor allreduceGet(fl::TopKMeter& mtr) { - std::pair stats = mtr.getStats(); - std::vector vec = {stats.first, stats.second}; - return Tensor::fromVector(vec); + std::pair stats = mtr.getStats(); + std::vector vec = {stats.first, stats.second}; + return Tensor::fromVector(vec); } void allreduceSet(fl::AverageValueMeter& mtr, Tensor& val) { - mtr.reset(); - auto valVec = val.toHostVector(); - if (valVec[2] != 0) { - valVec[0] /= valVec[2]; - } - mtr.add(valVec[0], valVec[2]); + mtr.reset(); + auto valVec = val.toHostVector(); + if(valVec[2] != 0) { + valVec[0] /= valVec[2]; + } + mtr.add(valVec[0], valVec[2]); } void allreduceSet(fl::EditDistanceMeter& mtr, Tensor& val) { - mtr.reset(); - auto valVec = val.toHostVector(); - mtr.add( - static_cast(valVec[1]), - static_cast(valVec[2]), - static_cast(valVec[3]), - static_cast(valVec[4])); + mtr.reset(); + auto valVec = val.toHostVector(); + mtr.add( + static_cast(valVec[1]), + static_cast(valVec[2]), + static_cast(valVec[3]), + static_cast(valVec[4]) + ); } void allreduceSet(fl::CountMeter& mtr, Tensor& val) { - mtr.reset(); - auto valVec = val.toHostVector(); - for (size_t i = 0; i < valVec.size(); ++i) { - mtr.add(i, valVec[i]); - } + mtr.reset(); + auto valVec = val.toHostVector(); + for(size_t i = 0; i < valVec.size(); ++i) { + mtr.add(i, valVec[i]); + 
} } void allreduceSet(fl::TimeMeter& mtr, Tensor& val) { - auto worldSize = fl::getWorldSize(); - auto valVec = val.toHostVector(); - mtr.set(valVec[0] / worldSize); + auto worldSize = fl::getWorldSize(); + auto valVec = val.toHostVector(); + mtr.set(valVec[0] / worldSize); } void allreduceSet(fl::TopKMeter& mtr, Tensor& val) { - mtr.reset(); - auto valVec = val.toHostVector(); - mtr.set(valVec[0], valVec[1]); + mtr.reset(); + auto valVec = val.toHostVector(); + mtr.set(valVec[0], valVec[1]); } } // namespace fl diff --git a/flashlight/pkg/runtime/common/DistributedUtils.h b/flashlight/pkg/runtime/common/DistributedUtils.h index 25ebc35..6a52b35 100644 --- a/flashlight/pkg/runtime/common/DistributedUtils.h +++ b/flashlight/pkg/runtime/common/DistributedUtils.h @@ -17,48 +17,49 @@ namespace fl { class Tensor; namespace pkg { -namespace runtime { + namespace runtime { /** * Call Flashlight API to initialize distributed environment. */ -void initDistributed( - int worldRank, - int worldSize, - int maxDevicesPerNode, - const std::string& rndvFilepath); + void initDistributed( + int worldRank, + int worldSize, + int maxDevicesPerNode, + const std::string& rndvFilepath + ); -Tensor allreduceGet(AverageValueMeter& mtr); -Tensor allreduceGet(EditDistanceMeter& mtr); -Tensor allreduceGet(CountMeter& mtr); -Tensor allreduceGet(TimeMeter& mtr); -Tensor allreduceGet(TopKMeter& mtr); + Tensor allreduceGet(AverageValueMeter& mtr); + Tensor allreduceGet(EditDistanceMeter& mtr); + Tensor allreduceGet(CountMeter& mtr); + Tensor allreduceGet(TimeMeter& mtr); + Tensor allreduceGet(TopKMeter& mtr); -void allreduceSet(AverageValueMeter& mtr, Tensor& val); -void allreduceSet(EditDistanceMeter& mtr, Tensor& val); -void allreduceSet(CountMeter& mtr, Tensor& val); -void allreduceSet(TimeMeter& mtr, Tensor& val); -void allreduceSet(TopKMeter& mtr, Tensor& val); + void allreduceSet(AverageValueMeter& mtr, Tensor& val); + void allreduceSet(EditDistanceMeter& mtr, Tensor& val); + void 
allreduceSet(CountMeter& mtr, Tensor& val); + void allreduceSet(TimeMeter& mtr, Tensor& val); + void allreduceSet(TopKMeter& mtr, Tensor& val); /** * Synchronize meters across process. */ -template -void syncMeter(T& mtr) { - if (!fl::isDistributedInit()) { - return; - } - Tensor arr = allreduceGet(mtr); - fl::allReduce(arr); - allreduceSet(mtr, arr); -} + template + void syncMeter(T& mtr) { + if(!fl::isDistributedInit()) { + return; + } + Tensor arr = allreduceGet(mtr); + fl::allReduce(arr); + allreduceSet(mtr, arr); + } -template void syncMeter(AverageValueMeter& mtr); -template void syncMeter(EditDistanceMeter& mtr); -template void syncMeter(CountMeter& mtr); -template void syncMeter(TimeMeter& mtr); -template void syncMeter(TopKMeter& mtr); + template void syncMeter(AverageValueMeter& mtr); + template void syncMeter(EditDistanceMeter& mtr); + template void syncMeter(CountMeter& mtr); + template void syncMeter(TimeMeter& mtr); + template void syncMeter(TopKMeter& mtr); -} // namespace runtime + } // namespace runtime } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/runtime/common/SequentialBuilder.cpp b/flashlight/pkg/runtime/common/SequentialBuilder.cpp index 3545ba5..9e0972e 100644 --- a/flashlight/pkg/runtime/common/SequentialBuilder.cpp +++ b/flashlight/pkg/runtime/common/SequentialBuilder.cpp @@ -21,7 +21,8 @@ std::shared_ptr parseLine(const std::string& line); std::shared_ptr parseLines( const std::vector& lines, const int lineIdx, - int& numLinesParsed); + int& numLinesParsed +); } // namespace namespace fl::pkg::runtime { @@ -29,573 +30,621 @@ namespace fl::pkg::runtime { std::shared_ptr buildSequentialModule( const fs::path& archfile, int64_t nFeatures, - int64_t nClasses) { - auto net = std::make_shared(); - - std::vector layers; - { - std::ifstream in(archfile); - if (!in) { - throw std::runtime_error( - "fl::pkg::runtime::buildSequentialModule given invalid arch filepath"); - } - for (std::string str; std::getline(in, str);) { - 
layers.emplace_back(str); + int64_t nClasses +) { + auto net = std::make_shared(); + + std::vector layers; + { + std::ifstream in(archfile); + if(!in) { + throw std::runtime_error( + "fl::pkg::runtime::buildSequentialModule given invalid arch filepath" + ); + } + for(std::string str; std::getline(in, str);) { + layers.emplace_back(str); + } } - } - int numLinesParsed = 0; + int numLinesParsed = 0; - // preprocess - std::vector processedLayers; - for (auto& l : layers) { - std::string lrepl = fl::lib::trim(l); - fl::lib::replaceAll(lrepl, "NFEAT", std::to_string(nFeatures)); - fl::lib::replaceAll(lrepl, "NLABEL", std::to_string(nClasses)); + // preprocess + std::vector processedLayers; + for(auto& l : layers) { + std::string lrepl = fl::lib::trim(l); + fl::lib::replaceAll(lrepl, "NFEAT", std::to_string(nFeatures)); + fl::lib::replaceAll(lrepl, "NLABEL", std::to_string(nClasses)); - if (lrepl.empty() || fl::lib::startsWith(lrepl, "#")) { - continue; // ignore empty lines / comments + if(lrepl.empty() || fl::lib::startsWith(lrepl, "#")) { + continue; // ignore empty lines / comments + } + processedLayers.emplace_back(lrepl); } - processedLayers.emplace_back(lrepl); - } - int lid = 0; - while (lid < processedLayers.size()) { - net->add(parseLines(processedLayers, lid, numLinesParsed)); - lid += (numLinesParsed + 1); - } + int lid = 0; + while(lid < processedLayers.size()) { + net->add(parseLines(processedLayers, lid, numLinesParsed)); + lid += (numLinesParsed + 1); + } - return net; + return net; } fl::Variable forwardSequentialModuleWithPadMask( const fl::Variable& input, std::shared_ptr ntwrk, - const Tensor& inputSizes) { - // expected input dims T x C x 1 x B - int T = input.dim(0), B = input.dim(3); - auto inputMaxSize = fl::tile(fl::amax(inputSizes, {1}), {1, B}); - Tensor inputNotPaddedSize = fl::ceil(inputSizes * T / inputMaxSize); - auto padMask = - fl::iota({T, 1}, {1, B}) < fl::tile(inputNotPaddedSize, {T, 1}); - auto ntwrkSeq = 
std::dynamic_pointer_cast(ntwrk); - auto output = input; - for (auto& module : ntwrkSeq->modules()) { - auto tr = std::dynamic_pointer_cast(module); - auto cfr = std::dynamic_pointer_cast(module); - if (tr != nullptr || cfr != nullptr) { - output = module->forward({output, fl::noGrad(padMask)}).front(); - } else { - output = module->forward({output}).front(); - } - } - return output.astype(input.type()); + const Tensor& inputSizes +) { + // expected input dims T x C x 1 x B + int T = input.dim(0), B = input.dim(3); + auto inputMaxSize = fl::tile(fl::amax(inputSizes, {1}), {1, B}); + Tensor inputNotPaddedSize = fl::ceil(inputSizes * T / inputMaxSize); + auto padMask = + fl::iota({T, 1}, {1, B}) < fl::tile(inputNotPaddedSize, {T, 1}); + auto ntwrkSeq = std::dynamic_pointer_cast(ntwrk); + auto output = input; + for(auto& module : ntwrkSeq->modules()) { + auto tr = std::dynamic_pointer_cast(module); + auto cfr = std::dynamic_pointer_cast(module); + if(tr != nullptr || cfr != nullptr) { + output = module->forward({output, fl::noGrad(padMask)}).front(); + } else { + output = module->forward({output}).front(); + } + } + return output.astype(input.type()); } } // namespace fl namespace { std::shared_ptr parseLine(const std::string& line) { - int dummy; - return parseLines({line}, 0, dummy); + int dummy; + return parseLines({line}, 0, dummy); } std::shared_ptr parseLines( const std::vector& lines, const int lineIdx, - int& numLinesParsed) { - auto line = lines[lineIdx]; - numLinesParsed = 0; - auto params = fl::lib::splitOnWhitespace(line, true); - - auto inRange = [&](const int a, const int b, const int c) { - return (a <= b && b <= c); - }; - - /* ========== TRANSFORMATIONS ========== */ - - if ((params[0] == "RO") || (params[0] == "V")) { - if (params.size() < 2) { - throw std::invalid_argument("Failed parsing - " + line); - } - Shape shape(std::vector(params.size() - 1)); - for (unsigned i = 1; i < params.size(); ++i) { - shape[i - 1] = std::stoi(params[i]); - } - if 
(params[0] == "RO") { - return std::make_shared(shape); - } else { - return std::make_shared(shape); - } - } - - if (params[0] == "PD") { - if (!inRange(4, params.size(), 10) || (params.size() & 1)) { - throw std::invalid_argument("Failed parsing - " + line); - } - auto val = std::stod(params[1]); - params.resize(10, "0"); - std::vector> paddings = { - {std::stoi(params[2]), std::stoi(params[3])}, - {std::stoi(params[4]), std::stoi(params[5])}, - {std::stoi(params[6]), std::stoi(params[7])}, - {std::stoi(params[8]), std::stoi(params[9])}}; - // TODO{fl::Tensor} -- rearrange arguments - return std::make_shared(paddings, val); - } - - /* ========== TRANSFORMERS ========== */ - - if (params[0] == "TR") { - if (!inRange(6, params.size(), 9)) { - throw std::invalid_argument("Failed parsing - " + line); - } - int modelDim = std::stoi(params[1]); - int mlpDim = std::stoi(params[2]); - int nHead = std::stoi(params[3]); - int csz = std::stoi(params[4]); - float pDropout = std::stof(params[5]); - float pLayerdrop = (params.size() >= 7) ? std::stof(params[6]) : 0.0; - int preLN = (params.size() >= 8) ? std::stoi(params[7]) : 0; - bool useFutureMask = (params.size() >= 9) ? std::stoi(params[8]) : false; - return std::make_shared( - modelDim, - modelDim / nHead, - mlpDim, - nHead, - csz, - pDropout, - pLayerdrop, - useFutureMask, - preLN); - } - - if (params[0] == "CFR") { - if (!inRange(7, params.size(), 8)) { - throw std::invalid_argument("Failed parsing - " + line); - } - int modelDim = std::stoi(params[1]); - int mlpDim = std::stoi(params[2]); - int nHead = std::stoi(params[3]); - int csz = std::stoi(params[4]); - int kernel = std::stoi(params[5]); - float pDropout = std::stof(params[6]); - float pLayerdrop = (params.size() >= 8) ? 
std::stof(params[7]) : 0.0; - return std::make_shared( - modelDim, - modelDim / nHead, - mlpDim, - nHead, - csz, - kernel, - pDropout, - pLayerdrop); - } - - if (params[0] == "POSEMB") { - if (!inRange(3, params.size(), 4)) { - throw std::invalid_argument("Failed parsing - " + line); - } - int layerDim = std::stoi(params[1]); - int csz = std::stoi(params[2]); - float dropout = (params.size() >= 4) ? std::stof(params[3]) : 0.0; - return std::make_shared(layerDim, csz, dropout); - } - - if (params[0] == "SINPOSEMB") { - if (!inRange(2, params.size(), 3)) { - throw std::invalid_argument("Failed parsing - " + line); - } - int layerDim = std::stoi(params[1]); - float inputScale = (params.size() >= 3) ? std::stof(params[2]) : 1.0; - return std::make_shared(layerDim, inputScale); - } - - /* ========== CONVOLUTIONS ========== */ - - if (params[0] == "C" || params[0] == "C1") { - if (!inRange(5, params.size(), 7)) { - throw std::invalid_argument("Failed parsing - " + line); - } - int cisz = std::stoi(params[1]); - int cosz = std::stoi(params[2]); - int cwx = std::stoi(params[3]); - int csx = std::stoi(params[4]); - int cpx = (params.size() >= 6) ? std::stoi(params[5]) : 0; - int cdx = (params.size() >= 7) ? std::stoi(params[6]) : 1; - return std::make_shared(cisz, cosz, cwx, 1, csx, 1, cpx, 0, cdx, 1); - } - - if (params[0] == "TDS") { - if (!inRange(4, params.size(), 8)) { - throw std::invalid_argument("Failed parsing - " + line); - } - int cisz = std::stoi(params[1]); - int cwx = std::stoi(params[2]); - int freqdim = std::stoi(params[3]); - double dropprob = (params.size() >= 5 ? std::stod(params[4]) : 0); - int l2 = (params.size() >= 6 ? std::stoi(params[5]) : 0); - int rPad = (params.size() >= 7) ? std::stoi(params[6]) : -1; - bool lNormIncludeTime = - (params.size() >= 8 && std::stoi(params[7]) == 0) ? 
false : true; - return std::make_shared( - cisz, cwx, freqdim, dropprob, l2, rPad, lNormIncludeTime); - } - - if (params[0] == "AC") { - if (!inRange(5, params.size(), 8)) { - throw std::invalid_argument("Failed parsing - " + line); - } - int cisz = std::stoi(params[1]); - int cosz = std::stoi(params[2]); - int cwx = std::stoi(params[3]); - int csx = std::stoi(params[4]); - int cpx = (params.size() >= 6) ? std::stoi(params[5]) : 0; - float futurePartPx = (params.size() >= 7) ? std::stof(params[6]) : 1.; - int cdx = (params.size() >= 8) ? std::stoi(params[7]) : 1; - return std::make_shared( - cisz, cosz, cwx, csx, cpx, futurePartPx, cdx); - } - - if (params[0] == "C2") { - if (!inRange(7, params.size(), 11)) { - throw std::invalid_argument("Failed parsing - " + line); - } - int cisz = std::stoi(params[1]); - int cosz = std::stoi(params[2]); - int cwx = std::stoi(params[3]); - int cwy = std::stoi(params[4]); - int csx = std::stoi(params[5]); - int csy = std::stoi(params[6]); - int cpx = (params.size() >= 8) ? std::stoi(params[7]) : 0; - int cpy = (params.size() >= 9) ? std::stoi(params[8]) : 0; - int cdx = (params.size() >= 10) ? std::stoi(params[9]) : 1; - int cdy = (params.size() >= 11) ? std::stoi(params[10]) : 1; - return std::make_shared( - cisz, cosz, cwx, cwy, csx, csy, cpx, cpy, cdx, cdy); - } - - /* ========== LINEAR ========== */ - - if (params[0] == "L") { - if (!inRange(3, params.size(), 4)) { - throw std::invalid_argument("Failed parsing - " + line); - } - int lisz = std::stoi(params[1]); - int losz = std::stoi(params[2]); - bool bias = (params.size() == 4) && params[3] == "0" ? 
false : true; - return std::make_shared(lisz, losz, bias); - } - - /* ========== EMBEDDING ========== */ - - if (params[0] == "E") { - if (params.size() != 3) { - throw std::invalid_argument("Failed parsing - " + line); - } - int embsz = std::stoi(params[1]); - int ntokens = std::stoi(params[2]); - return std::make_shared(embsz, ntokens); - } - - if (params[0] == "ADAPTIVEE") { - if (params.size() != 3) { - throw std::invalid_argument("Failed parsing - " + line); - } - int embsz = std::stoi(params[1]); - std::vector cutoffs; - auto tokens = fl::lib::split(',', params[2], true); - for (const auto& token : tokens) { - cutoffs.push_back(std::stoi(fl::lib::trim(token))); - } - for (int i = 1; i < cutoffs.size(); ++i) { - if (cutoffs[i - 1] >= cutoffs[i]) { - throw std::invalid_argument("cutoffs must be strictly ascending"); - } - } - return std::make_shared(embsz, cutoffs); - } - - /* ========== NORMALIZATIONS ========== */ - - if (params[0] == "BN") { - if (!inRange(3, params.size(), 5)) { - throw std::invalid_argument("Failed parsing - " + line); - } - int featSz = std::stoi(params[1]); - std::vector featDims; - for (int i = 2; i < params.size(); ++i) { - featDims.emplace_back(std::stoi(params[i])); - } - return std::make_shared(featDims, featSz); - } - - if (params[0] == "LN") { - if (!inRange(2, params.size(), 4)) { - throw std::invalid_argument("Failed parsing - " + line); - } - std::vector featDims; - for (int i = 1; i < params.size(); ++i) { - featDims.emplace_back(std::stoi(params[i])); - } - if (featDims == std::vector{3}) { - if (!inRange(7, params.size(), 11)) { - throw std::invalid_argument( - "Failed parsing - " - "flashlight LayerNorm API for specifying `featAxes` is modified " - "recently - https://git.io/Je70U. You probably would want to " - "specify LN 0 1 2 instead of LN 3. 
If you really know what you're " - "doing, comment out this check and build again."); - } - } - return std::make_shared(featDims); - } - - if (params[0] == "WN") { - if (params.size() < 3) { - throw std::invalid_argument("Failed parsing - " + line); - } - int dim = std::stoi(params[1]); - std::string childStr = fl::lib::join(" ", params.begin() + 2, params.end()); - return std::make_shared(parseLine(childStr), dim); - } - - if (params[0] == "DO") { - if (params.size() != 2) { - throw std::invalid_argument("Failed parsing - " + line); - } - auto drpVal = std::stod(params[1]); - return std::make_shared(drpVal); - } - - /* ========== POOLING ========== */ - - if ((params[0] == "M") || (params[0] == "A")) { - if (params.size() < 5) { - throw std::invalid_argument("Failed parsing - " + line); - } - int wx = std::stoi(params[1]); - int wy = std::stoi(params[2]); - int dx = std::stoi(params[3]); - int dy = std::stoi(params[4]); - int px = params.size() > 5 ? std::stoi(params[5]) : 0; - int py = params.size() > 6 ? std::stoi(params[6]) : 0; - auto mode = (params[0] == "A") ? 
PoolingMode::AVG_INCLUDE_PADDING - : PoolingMode::MAX; - - return std::make_shared(wx, wy, dx, dy, px, py, mode); - } + int& numLinesParsed +) { + auto line = lines[lineIdx]; + numLinesParsed = 0; + auto params = fl::lib::splitOnWhitespace(line, true); + + auto inRange = [&](const int a, const int b, const int c) { + return a <= b && b <= c; + }; + + /* ========== TRANSFORMATIONS ========== */ - /* ========== ACTIVATIONS ========== */ - - if (params[0] == "ELU") { - if (params.size() != 1) { - throw std::invalid_argument("Failed parsing - " + line); - } - return std::make_shared(); - } - - if (params[0] == "R") { - if (params.size() != 1) { - throw std::invalid_argument("Failed parsing - " + line); - } - return std::make_shared(); - } + if((params[0] == "RO") || (params[0] == "V")) { + if(params.size() < 2) { + throw std::invalid_argument("Failed parsing - " + line); + } + Shape shape(std::vector(params.size() - 1)); + for(unsigned i = 1; i < params.size(); ++i) { + shape[i - 1] = std::stoi(params[i]); + } + if(params[0] == "RO") { + return std::make_shared(shape); + } else { + return std::make_shared(shape); + } + } - if (params[0] == "R6") { - if (params.size() != 1) { - throw std::invalid_argument("Failed parsing - " + line); + if(params[0] == "PD") { + if(!inRange(4, params.size(), 10) || (params.size() & 1)) { + throw std::invalid_argument("Failed parsing - " + line); + } + auto val = std::stod(params[1]); + params.resize(10, "0"); + std::vector> paddings = { + {std::stoi(params[2]), std::stoi(params[3])}, + {std::stoi(params[4]), std::stoi(params[5])}, + {std::stoi(params[6]), std::stoi(params[7])}, + {std::stoi(params[8]), std::stoi(params[9])}}; + // TODO{fl::Tensor} -- rearrange arguments + return std::make_shared(paddings, val); + } + + /* ========== TRANSFORMERS ========== */ + + if(params[0] == "TR") { + if(!inRange(6, params.size(), 9)) { + throw std::invalid_argument("Failed parsing - " + line); + } + int modelDim = std::stoi(params[1]); + int mlpDim 
= std::stoi(params[2]); + int nHead = std::stoi(params[3]); + int csz = std::stoi(params[4]); + float pDropout = std::stof(params[5]); + float pLayerdrop = (params.size() >= 7) ? std::stof(params[6]) : 0.0; + int preLN = (params.size() >= 8) ? std::stoi(params[7]) : 0; + bool useFutureMask = (params.size() >= 9) ? std::stoi(params[8]) : false; + return std::make_shared( + modelDim, + modelDim / nHead, + mlpDim, + nHead, + csz, + pDropout, + pLayerdrop, + useFutureMask, + preLN + ); + } + + if(params[0] == "CFR") { + if(!inRange(7, params.size(), 8)) { + throw std::invalid_argument("Failed parsing - " + line); + } + int modelDim = std::stoi(params[1]); + int mlpDim = std::stoi(params[2]); + int nHead = std::stoi(params[3]); + int csz = std::stoi(params[4]); + int kernel = std::stoi(params[5]); + float pDropout = std::stof(params[6]); + float pLayerdrop = (params.size() >= 8) ? std::stof(params[7]) : 0.0; + return std::make_shared( + modelDim, + modelDim / nHead, + mlpDim, + nHead, + csz, + kernel, + pDropout, + pLayerdrop + ); + } + + if(params[0] == "POSEMB") { + if(!inRange(3, params.size(), 4)) { + throw std::invalid_argument("Failed parsing - " + line); + } + int layerDim = std::stoi(params[1]); + int csz = std::stoi(params[2]); + float dropout = (params.size() >= 4) ? std::stof(params[3]) : 0.0; + return std::make_shared(layerDim, csz, dropout); } - return std::make_shared(); - } - if (params[0] == "PR") { - if (!inRange(1, params.size(), 3)) { - throw std::invalid_argument("Failed parsing - " + line); + if(params[0] == "SINPOSEMB") { + if(!inRange(2, params.size(), 3)) { + throw std::invalid_argument("Failed parsing - " + line); + } + int layerDim = std::stoi(params[1]); + float inputScale = (params.size() >= 3) ? std::stof(params[2]) : 1.0; + return std::make_shared(layerDim, inputScale); } - auto numParams = params.size() > 1 ? std::stoi(params[1]) : 1; - auto initVal = params.size() > 2 ? 
std::stod(params[2]) : 0.25; - return std::make_shared(numParams, initVal); - } - if (params[0] == "LG") { - if (params.size() != 1) { - throw std::invalid_argument("Failed parsing - " + line); + /* ========== CONVOLUTIONS ========== */ + + if(params[0] == "C" || params[0] == "C1") { + if(!inRange(5, params.size(), 7)) { + throw std::invalid_argument("Failed parsing - " + line); + } + int cisz = std::stoi(params[1]); + int cosz = std::stoi(params[2]); + int cwx = std::stoi(params[3]); + int csx = std::stoi(params[4]); + int cpx = (params.size() >= 6) ? std::stoi(params[5]) : 0; + int cdx = (params.size() >= 7) ? std::stoi(params[6]) : 1; + return std::make_shared(cisz, cosz, cwx, 1, csx, 1, cpx, 0, cdx, 1); + } + + if(params[0] == "TDS") { + if(!inRange(4, params.size(), 8)) { + throw std::invalid_argument("Failed parsing - " + line); + } + int cisz = std::stoi(params[1]); + int cwx = std::stoi(params[2]); + int freqdim = std::stoi(params[3]); + double dropprob = (params.size() >= 5 ? std::stod(params[4]) : 0); + int l2 = (params.size() >= 6 ? std::stoi(params[5]) : 0); + int rPad = (params.size() >= 7) ? std::stoi(params[6]) : -1; + bool lNormIncludeTime = + (params.size() >= 8 && std::stoi(params[7]) == 0) ? false : true; + return std::make_shared( + cisz, + cwx, + freqdim, + dropprob, + l2, + rPad, + lNormIncludeTime + ); + } + + if(params[0] == "AC") { + if(!inRange(5, params.size(), 8)) { + throw std::invalid_argument("Failed parsing - " + line); + } + int cisz = std::stoi(params[1]); + int cosz = std::stoi(params[2]); + int cwx = std::stoi(params[3]); + int csx = std::stoi(params[4]); + int cpx = (params.size() >= 6) ? std::stoi(params[5]) : 0; + float futurePartPx = (params.size() >= 7) ? std::stof(params[6]) : 1.; + int cdx = (params.size() >= 8) ? 
std::stoi(params[7]) : 1; + return std::make_shared( + cisz, + cosz, + cwx, + csx, + cpx, + futurePartPx, + cdx + ); + } + + if(params[0] == "C2") { + if(!inRange(7, params.size(), 11)) { + throw std::invalid_argument("Failed parsing - " + line); + } + int cisz = std::stoi(params[1]); + int cosz = std::stoi(params[2]); + int cwx = std::stoi(params[3]); + int cwy = std::stoi(params[4]); + int csx = std::stoi(params[5]); + int csy = std::stoi(params[6]); + int cpx = (params.size() >= 8) ? std::stoi(params[7]) : 0; + int cpy = (params.size() >= 9) ? std::stoi(params[8]) : 0; + int cdx = (params.size() >= 10) ? std::stoi(params[9]) : 1; + int cdy = (params.size() >= 11) ? std::stoi(params[10]) : 1; + return std::make_shared( + cisz, + cosz, + cwx, + cwy, + csx, + csy, + cpx, + cpy, + cdx, + cdy + ); + } + + /* ========== LINEAR ========== */ + + if(params[0] == "L") { + if(!inRange(3, params.size(), 4)) { + throw std::invalid_argument("Failed parsing - " + line); + } + int lisz = std::stoi(params[1]); + int losz = std::stoi(params[2]); + bool bias = (params.size() == 4) && params[3] == "0" ? 
false : true; + return std::make_shared(lisz, losz, bias); } - return std::make_shared(); - } - - if (params[0] == "HT") { - if (params.size() != 1) { - throw std::invalid_argument("Failed parsing - " + line); - } - return std::make_shared(); - } - - if (params[0] == "T") { - if (params.size() != 1) { - throw std::invalid_argument("Failed parsing - " + line); - } - return std::make_shared(); - } - - if (params[0] == "GLU") { - if (params.size() != 2) { - throw std::invalid_argument("Failed parsing - " + line); - } - int dim = std::stoi(params[1]); - return std::make_shared(dim); - } - - if (params[0] == "LSM") { - if (params.size() != 2) { - throw std::invalid_argument("Failed parsing - " + line); - } - int dim = std::stoi(params[1]); - return std::make_shared(dim); - } - - if (params[0] == "SH") { - if (!inRange(1, params.size(), 2)) { - throw std::invalid_argument("Failed parsing - " + line); - } - auto beta = params.size() > 1 ? std::stof(params[1]) : 1.0; - return std::make_shared(beta); - } - - /* ========== RNNs ========== */ - - auto rnnLayer = [&](const std::vector& prms, RnnMode mode) { - int iSz = std::stoi(prms[1]); - int oSz = std::stoi(prms[2]); - int numLayers = (prms.size() > 3) ? std::stoi(prms[3]) : 1; - bool bidirectional = (prms.size() > 4) ? std::stoi(prms[4]) > 0 : false; - float dropout = (prms.size() > 5) ? 
std::stof(prms[5]) : 0.0; - return std::make_shared( - iSz, oSz, numLayers, mode, bidirectional, dropout); - }; - - if (params[0] == "RNN") { - if (params.size() < 3) { - throw std::invalid_argument("Failed parsing - " + line); - } - return rnnLayer(params, RnnMode::RELU); - } - - if (params[0] == "GRU") { - if (params.size() < 3) { - throw std::invalid_argument("Failed parsing - " + line); - } - return rnnLayer(params, RnnMode::GRU); - } - - if (params[0] == "LSTM") { - if (params.size() < 3) { - throw std::invalid_argument("Failed parsing - " + line); - } - return rnnLayer(params, RnnMode::LSTM); - } - - /* ========== Residual block ========== */ - if (params[0] == "RES") { - if (params.size() <= 3) { - throw std::invalid_argument("Failed parsing - " + line); - } - - auto residualBlock = [&](const std::vector& prms, - int& numResLayerAndSkip) { - int numResLayers = std::stoi(prms[1]); - int numSkipConnections = std::stoi(prms[2]); - std::shared_ptr resPtr = std::make_shared(); - - int numProjections = 0; - - for (int i = 1; i <= numResLayers + numSkipConnections; ++i) { - if (lineIdx + i + numProjections >= lines.size()) { - throw std::invalid_argument("Failed parsing Residual block"); - } - const std::string& resLine = lines[lineIdx + i + numProjections]; - auto resLinePrms = fl::lib::splitOnWhitespace(resLine, true); - - if (resLinePrms[0] == "SKIP") { - if (!inRange(3, resLinePrms.size(), 4)) { - throw std::invalid_argument("Failed parsing - " + resLine); - } - resPtr->addShortcut( - std::stoi(resLinePrms[1]), std::stoi(resLinePrms[2])); - if (resLinePrms.size() == 4) { - resPtr->addScale( - std::stoi(resLinePrms[2]), std::stof(resLinePrms[3])); - } - } else if (resLinePrms[0] == "SKIPL") { - if (!inRange(4, resLinePrms.size(), 5)) { - throw std::invalid_argument("Failed parsing - " + resLine); - } - int numProjectionLayers = std::stoi(resLinePrms[3]); - auto projection = std::make_shared(); - - for (int j = 1; j <= numProjectionLayers; ++j) { - if (lineIdx + 
i + numProjections + j >= lines.size()) { - throw std::invalid_argument("Failed parsing Residual block"); + + /* ========== EMBEDDING ========== */ + + if(params[0] == "E") { + if(params.size() != 3) { + throw std::invalid_argument("Failed parsing - " + line); + } + int embsz = std::stoi(params[1]); + int ntokens = std::stoi(params[2]); + return std::make_shared(embsz, ntokens); + } + + if(params[0] == "ADAPTIVEE") { + if(params.size() != 3) { + throw std::invalid_argument("Failed parsing - " + line); + } + int embsz = std::stoi(params[1]); + std::vector cutoffs; + auto tokens = fl::lib::split(',', params[2], true); + for(const auto& token : tokens) { + cutoffs.push_back(std::stoi(fl::lib::trim(token))); + } + for(int i = 1; i < cutoffs.size(); ++i) { + if(cutoffs[i - 1] >= cutoffs[i]) { + throw std::invalid_argument("cutoffs must be strictly ascending"); } - projection->add(parseLine(lines[lineIdx + i + numProjections + j])); - } - resPtr->addShortcut( - std::stoi(resLinePrms[1]), std::stoi(resLinePrms[2]), projection); - if (resLinePrms.size() == 5) { - resPtr->addScale( - std::stoi(resLinePrms[2]), std::stof(resLinePrms[4])); - } - numProjections += numProjectionLayers; + } + return std::make_shared(embsz, cutoffs); + } + + /* ========== NORMALIZATIONS ========== */ + + if(params[0] == "BN") { + if(!inRange(3, params.size(), 5)) { + throw std::invalid_argument("Failed parsing - " + line); + } + int featSz = std::stoi(params[1]); + std::vector featDims; + for(int i = 2; i < params.size(); ++i) { + featDims.emplace_back(std::stoi(params[i])); + } + return std::make_shared(featDims, featSz); + } + + if(params[0] == "LN") { + if(!inRange(2, params.size(), 4)) { + throw std::invalid_argument("Failed parsing - " + line); + } + std::vector featDims; + for(int i = 1; i < params.size(); ++i) { + featDims.emplace_back(std::stoi(params[i])); + } + if(featDims == std::vector{3}) { + if(!inRange(7, params.size(), 11)) { + throw std::invalid_argument( + "Failed parsing - " + 
"flashlight LayerNorm API for specifying `featAxes` is modified " + "recently - https://git.io/Je70U. You probably would want to " + "specify LN 0 1 2 instead of LN 3. If you really know what you're " + "doing, comment out this check and build again." + ); + } + } + return std::make_shared(featDims); + } + + if(params[0] == "WN") { + if(params.size() < 3) { + throw std::invalid_argument("Failed parsing - " + line); + } + int dim = std::stoi(params[1]); + std::string childStr = fl::lib::join(" ", params.begin() + 2, params.end()); + return std::make_shared(parseLine(childStr), dim); + } + + if(params[0] == "DO") { + if(params.size() != 2) { + throw std::invalid_argument("Failed parsing - " + line); + } + auto drpVal = std::stod(params[1]); + return std::make_shared(drpVal); + } + + /* ========== POOLING ========== */ + + if((params[0] == "M") || (params[0] == "A")) { + if(params.size() < 5) { + throw std::invalid_argument("Failed parsing - " + line); + } + int wx = std::stoi(params[1]); + int wy = std::stoi(params[2]); + int dx = std::stoi(params[3]); + int dy = std::stoi(params[4]); + int px = params.size() > 5 ? std::stoi(params[5]) : 0; + int py = params.size() > 6 ? std::stoi(params[6]) : 0; + auto mode = (params[0] == "A") ? 
PoolingMode::AVG_INCLUDE_PADDING + : PoolingMode::MAX; + + return std::make_shared(wx, wy, dx, dy, px, py, mode); + } + + /* ========== ACTIVATIONS ========== */ + + if(params[0] == "ELU") { + if(params.size() != 1) { + throw std::invalid_argument("Failed parsing - " + line); + } + return std::make_shared(); + } + + if(params[0] == "R") { + if(params.size() != 1) { + throw std::invalid_argument("Failed parsing - " + line); + } + return std::make_shared(); + } + + if(params[0] == "R6") { + if(params.size() != 1) { + throw std::invalid_argument("Failed parsing - " + line); + } + return std::make_shared(); + } + + if(params[0] == "PR") { + if(!inRange(1, params.size(), 3)) { + throw std::invalid_argument("Failed parsing - " + line); + } + auto numParams = params.size() > 1 ? std::stoi(params[1]) : 1; + auto initVal = params.size() > 2 ? std::stod(params[2]) : 0.25; + return std::make_shared(numParams, initVal); + } + + if(params[0] == "LG") { + if(params.size() != 1) { + throw std::invalid_argument("Failed parsing - " + line); + } + return std::make_shared(); + } + + if(params[0] == "HT") { + if(params.size() != 1) { + throw std::invalid_argument("Failed parsing - " + line); + } + return std::make_shared(); + } + + if(params[0] == "T") { + if(params.size() != 1) { + throw std::invalid_argument("Failed parsing - " + line); + } + return std::make_shared(); + } + + if(params[0] == "GLU") { + if(params.size() != 2) { + throw std::invalid_argument("Failed parsing - " + line); + } + int dim = std::stoi(params[1]); + return std::make_shared(dim); + } + + if(params[0] == "LSM") { + if(params.size() != 2) { + throw std::invalid_argument("Failed parsing - " + line); + } + int dim = std::stoi(params[1]); + return std::make_shared(dim); + } + + if(params[0] == "SH") { + if(!inRange(1, params.size(), 2)) { + throw std::invalid_argument("Failed parsing - " + line); + } + auto beta = params.size() > 1 ? 
std::stof(params[1]) : 1.0; + return std::make_shared(beta); + } + + /* ========== RNNs ========== */ + + auto rnnLayer = [&](const std::vector& prms, RnnMode mode) { + int iSz = std::stoi(prms[1]); + int oSz = std::stoi(prms[2]); + int numLayers = (prms.size() > 3) ? std::stoi(prms[3]) : 1; + bool bidirectional = (prms.size() > 4) ? std::stoi(prms[4]) > 0 : false; + float dropout = (prms.size() > 5) ? std::stof(prms[5]) : 0.0; + return std::make_shared( + iSz, + oSz, + numLayers, + mode, + bidirectional, + dropout + ); + }; + + if(params[0] == "RNN") { + if(params.size() < 3) { + throw std::invalid_argument("Failed parsing - " + line); + } + return rnnLayer(params, RnnMode::RELU); + } + + if(params[0] == "GRU") { + if(params.size() < 3) { + throw std::invalid_argument("Failed parsing - " + line); + } + return rnnLayer(params, RnnMode::GRU); + } + + if(params[0] == "LSTM") { + if(params.size() < 3) { + throw std::invalid_argument("Failed parsing - " + line); + } + return rnnLayer(params, RnnMode::LSTM); + } + + /* ========== Residual block ========== */ + if(params[0] == "RES") { + if(params.size() <= 3) { + throw std::invalid_argument("Failed parsing - " + line); + } + + auto residualBlock = [&](const std::vector& prms, + int& numResLayerAndSkip) { + int numResLayers = std::stoi(prms[1]); + int numSkipConnections = std::stoi(prms[2]); + std::shared_ptr resPtr = std::make_shared(); + + int numProjections = 0; + + for(int i = 1; i <= numResLayers + numSkipConnections; ++i) { + if(lineIdx + i + numProjections >= lines.size()) { + throw std::invalid_argument("Failed parsing Residual block"); + } + const std::string& resLine = lines[lineIdx + i + numProjections]; + auto resLinePrms = fl::lib::splitOnWhitespace(resLine, true); + + if(resLinePrms[0] == "SKIP") { + if(!inRange(3, resLinePrms.size(), 4)) { + throw std::invalid_argument("Failed parsing - " + resLine); + } + resPtr->addShortcut( + std::stoi(resLinePrms[1]), + std::stoi(resLinePrms[2]) + ); + 
if(resLinePrms.size() == 4) { + resPtr->addScale( + std::stoi(resLinePrms[2]), + std::stof(resLinePrms[3]) + ); + } + } else if(resLinePrms[0] == "SKIPL") { + if(!inRange(4, resLinePrms.size(), 5)) { + throw std::invalid_argument("Failed parsing - " + resLine); + } + int numProjectionLayers = std::stoi(resLinePrms[3]); + auto projection = std::make_shared(); + + for(int j = 1; j <= numProjectionLayers; ++j) { + if(lineIdx + i + numProjections + j >= lines.size()) { + throw std::invalid_argument("Failed parsing Residual block"); + } + projection->add(parseLine(lines[lineIdx + i + numProjections + j])); + } + resPtr->addShortcut( + std::stoi(resLinePrms[1]), + std::stoi(resLinePrms[2]), + projection + ); + if(resLinePrms.size() == 5) { + resPtr->addScale( + std::stoi(resLinePrms[2]), + std::stof(resLinePrms[4]) + ); + } + numProjections += numProjectionLayers; + } else { + resPtr->add(parseLine(resLine)); + } + } + + numResLayerAndSkip = numResLayers + numSkipConnections + numProjections; + return resPtr; + }; + + auto numBlocks = params.size() == 4 ? std::stoi(params.back()) : 1; + if(numBlocks <= 0) { + throw std::invalid_argument( + "Invalid number of residual blocks: " + std::to_string(numBlocks) + ); + } + + if(numBlocks > 1) { + auto res = std::make_shared(); + for(int n = 0; n < numBlocks; ++n) { + res->add(residualBlock(params, numLinesParsed)); + } + return res; } else { - resPtr->add(parseLine(resLine)); - } - } - - numResLayerAndSkip = numResLayers + numSkipConnections + numProjections; - return resPtr; - }; - - auto numBlocks = params.size() == 4 ? 
std::stoi(params.back()) : 1; - if (numBlocks <= 0) { - throw std::invalid_argument( - "Invalid number of residual blocks: " + std::to_string(numBlocks)); - } - - if (numBlocks > 1) { - auto res = std::make_shared(); - for (int n = 0; n < numBlocks; ++n) { - res->add(residualBlock(params, numLinesParsed)); - } - return res; - } else { - return residualBlock(params, numLinesParsed); - } - } - - /* ========== Data Augmentation ========== */ - if (params[0] == "SAUG") { - if (params.size() != 7) { - throw std::invalid_argument("Failed parsing - " + line); - } - return std::make_shared( - std::stoi(params[1]), - std::stoi(params[2]), - std::stoi(params[3]), - std::stoi(params[4]), - std::stod(params[5]), - std::stoi(params[6])); - } - - /* ========== Precision Cast ========== */ - if (params[0] == "PC") { - if (params.size() != 2) { - throw std::invalid_argument("Failed parsing - " + line); - } - auto targetType = fl::stringToDtype(params[1]); - return std::make_shared(targetType); - } - - throw std::invalid_argument("Failed parsing - " + line); + return residualBlock(params, numLinesParsed); + } + } + + /* ========== Data Augmentation ========== */ + if(params[0] == "SAUG") { + if(params.size() != 7) { + throw std::invalid_argument("Failed parsing - " + line); + } + return std::make_shared( + std::stoi(params[1]), + std::stoi(params[2]), + std::stoi(params[3]), + std::stoi(params[4]), + std::stod(params[5]), + std::stoi(params[6]) + ); + } + + /* ========== Precision Cast ========== */ + if(params[0] == "PC") { + if(params.size() != 2) { + throw std::invalid_argument("Failed parsing - " + line); + } + auto targetType = fl::stringToDtype(params[1]); + return std::make_shared(targetType); + } + + throw std::invalid_argument("Failed parsing - " + line); } // namespace } // namespace diff --git a/flashlight/pkg/runtime/common/SequentialBuilder.h b/flashlight/pkg/runtime/common/SequentialBuilder.h index 98d6c92..c4595f3 100644 --- 
a/flashlight/pkg/runtime/common/SequentialBuilder.h +++ b/flashlight/pkg/runtime/common/SequentialBuilder.h @@ -16,16 +16,17 @@ namespace fl { class Tensor; namespace pkg { -namespace runtime { + namespace runtime { /** * Build a sequential module by parsing a file that * defines the model architecture. */ -std::shared_ptr buildSequentialModule( - const fs::path& archfile, - int64_t nFeatures, - int64_t nClasses); + std::shared_ptr buildSequentialModule( + const fs::path& archfile, + int64_t nFeatures, + int64_t nClasses + ); /** * Utility function for to run forward with pad masking @@ -35,11 +36,12 @@ std::shared_ptr buildSequentialModule( * with a transformer block in it! * TODO remove with landing plugin arch instead of arch files */ -fl::Variable forwardSequentialModuleWithPadMask( - const fl::Variable& input, - std::shared_ptr ntwrk, - const Tensor& inputSizes); + fl::Variable forwardSequentialModuleWithPadMask( + const fl::Variable& input, + std::shared_ptr ntwrk, + const Tensor& inputSizes + ); -} // namespace runtime + } // namespace runtime } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/runtime/common/Serializer.h b/flashlight/pkg/runtime/common/Serializer.h index 84826b7..da74b01 100644 --- a/flashlight/pkg/runtime/common/Serializer.h +++ b/flashlight/pkg/runtime/common/Serializer.h @@ -16,75 +16,81 @@ namespace fl { namespace pkg { -namespace runtime { + namespace runtime { -struct Serializer { - public: - template - static void save( - const fs::path& filepath, - const std::string& version, - const Args&... args) { - fl::retryWithBackoff( - std::chrono::seconds(1), - 2.0, - 6, - saveImpl, - filepath, - version, - args...); // max wait 31s - } + struct Serializer { + public: + template + static void save( + const fs::path& filepath, + const std::string& version, + const Args&... args + ) { + fl::retryWithBackoff( + std::chrono::seconds(1), + 2.0, + 6, + saveImpl, + filepath, + version, + args... 
+ ); // max wait 31s + } - template - static void load(const fs::path& filepath, Args&... args) { - fl::retryWithBackoff( - std::chrono::seconds(1), - 2.0, - 6, - loadImpl, - filepath, - args...); // max wait 31s - } + template + static void load(const fs::path& filepath, Args&... args) { + fl::retryWithBackoff( + std::chrono::seconds(1), + 2.0, + 6, + loadImpl, + filepath, + args... + ); // max wait 31s + } - private: - template - static void saveImpl( - const fs::path& filepath, - const std::string& version, - const Args&... args) { - try { - std::ofstream file(filepath, std::ios::binary); - if (!file.is_open()) { - throw std::runtime_error( - "failed to open file for writing: " + filepath.string()); - } - cereal::BinaryOutputArchive ar(file); - ar(version); - ar(args...); - } catch (const std::exception& ex) { - FL_LOG(fl::LogLevel::ERROR) - << "Error while saving \"" << filepath << "\": " << ex.what() << "\n"; - throw; - } - } + private: + template + static void saveImpl( + const fs::path& filepath, + const std::string& version, + const Args&... args + ) { + try { + std::ofstream file(filepath, std::ios::binary); + if(!file.is_open()) { + throw std::runtime_error( + "failed to open file for writing: " + filepath.string() + ); + } + cereal::BinaryOutputArchive ar(file); + ar(version); + ar(args...); + } catch(const std::exception& ex) { + FL_LOG(fl::LogLevel::ERROR) + << "Error while saving \"" << filepath << "\": " << ex.what() << "\n"; + throw; + } + } - template - static void loadImpl(const fs::path& filepath, Args&... 
args) { - try { - std::ifstream file(filepath, std::ios::binary); - if (!file.is_open()) { - throw std::runtime_error( - "failed to open file for reading: " + filepath.string()); - } - cereal::BinaryInputArchive ar(file); - ar(args...); - } catch (const std::exception& ex) { - FL_LOG(fl::LogLevel::ERROR) << "Error while loading \"" << filepath - << "\": " << ex.what() << "\n"; - throw; - } - } -}; -} // namespace runtime + template + static void loadImpl(const fs::path& filepath, Args&... args) { + try { + std::ifstream file(filepath, std::ios::binary); + if(!file.is_open()) { + throw std::runtime_error( + "failed to open file for reading: " + filepath.string() + ); + } + cereal::BinaryInputArchive ar(file); + ar(args...); + } catch(const std::exception& ex) { + FL_LOG(fl::LogLevel::ERROR) << "Error while loading \"" << filepath + << "\": " << ex.what() << "\n"; + throw; + } + } + }; + } // namespace runtime } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/runtime/plugin/ModulePlugin.cpp b/flashlight/pkg/runtime/plugin/ModulePlugin.cpp index 0c1853a..3528392 100644 --- a/flashlight/pkg/runtime/plugin/ModulePlugin.cpp +++ b/flashlight/pkg/runtime/plugin/ModulePlugin.cpp @@ -10,13 +10,14 @@ namespace fl::pkg::runtime { ModulePlugin::ModulePlugin(const std::string& name) : fl::Plugin(name) { - arch_ = getSymbol("createModule"); + arch_ = getSymbol("createModule"); } std::shared_ptr ModulePlugin::arch( int64_t nFeatures, - int64_t nClasses) { - return std::shared_ptr(arch_(nFeatures, nClasses)); + int64_t nClasses +) { + return std::shared_ptr(arch_(nFeatures, nClasses)); } } // namespace fl diff --git a/flashlight/pkg/runtime/plugin/ModulePlugin.h b/flashlight/pkg/runtime/plugin/ModulePlugin.h index 5d1cbfa..a612b06 100644 --- a/flashlight/pkg/runtime/plugin/ModulePlugin.h +++ b/flashlight/pkg/runtime/plugin/ModulePlugin.h @@ -12,19 +12,19 @@ namespace fl { namespace pkg { -namespace runtime { + namespace runtime { -typedef Module* 
(*w2l_module_plugin_t)(int64_t nFeatures, int64_t nClasses); + typedef Module* (*w2l_module_plugin_t)(int64_t nFeatures, int64_t nClasses); -class ModulePlugin : public Plugin { - public: - explicit ModulePlugin(const std::string& name); - std::shared_ptr arch(int64_t nFeatures, int64_t nClasses); + class ModulePlugin : public Plugin { + public: + explicit ModulePlugin(const std::string& name); + std::shared_ptr arch(int64_t nFeatures, int64_t nClasses); - private: - w2l_module_plugin_t arch_; -}; + private: + w2l_module_plugin_t arch_; + }; -} // namespace runtime + } // namespace runtime } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/runtime/plugin/plugincompiler/PluginModule.cpp b/flashlight/pkg/runtime/plugin/plugincompiler/PluginModule.cpp index 2657c04..4bcdf41 100644 --- a/flashlight/pkg/runtime/plugin/plugincompiler/PluginModule.cpp +++ b/flashlight/pkg/runtime/plugin/plugincompiler/PluginModule.cpp @@ -18,6 +18,6 @@ * source paths can be freely specified. */ extern "C" fl::Module* createModule(int64_t nFeature, int64_t nLabel) { - auto seq = std::make_unique(); // placeholder - return seq.release(); + auto seq = std::make_unique(); // placeholder + return seq.release(); } diff --git a/flashlight/pkg/runtime/test/DynamicScalerTest.cpp b/flashlight/pkg/runtime/test/DynamicScalerTest.cpp index 17a2821..a9536c0 100644 --- a/flashlight/pkg/runtime/test/DynamicScalerTest.cpp +++ b/flashlight/pkg/runtime/test/DynamicScalerTest.cpp @@ -15,42 +15,43 @@ #include "flashlight/fl/tensor/Init.h" TEST(DynamicScalerTest, Scaling) { - auto dynamicScaler = fl::pkg::runtime::DynamicScaler( - 32, // initFactor - 32, // maxFactor - 100 // updateInterval - ); - - auto loss = fl::uniform({5, 5, 5, 5}); - - auto scaledLoss = dynamicScaler.scale(loss); - ASSERT_TRUE(allClose(loss * 32, scaledLoss)); - - scaledLoss.addGrad(scaledLoss); - std::vector params{scaledLoss}; - bool unscaleStatus = dynamicScaler.unscale(params); - ASSERT_TRUE(unscaleStatus); - 
ASSERT_TRUE(allClose(loss, scaledLoss.grad())); + auto dynamicScaler = fl::pkg::runtime::DynamicScaler( + 32, // initFactor + 32, // maxFactor + 100 // updateInterval + ); + + auto loss = fl::uniform({5, 5, 5, 5}); + + auto scaledLoss = dynamicScaler.scale(loss); + ASSERT_TRUE(allClose(loss * 32, scaledLoss)); + + scaledLoss.addGrad(scaledLoss); + std::vector params{scaledLoss}; + bool unscaleStatus = dynamicScaler.unscale(params); + ASSERT_TRUE(unscaleStatus); + ASSERT_TRUE(allClose(loss, scaledLoss.grad())); } TEST(DynamicScalerTest, Serialization) { - auto dynamicScaler = std::make_shared( - 32, // initFactor - 32, // maxFactor - 100 // updateInterval - ); - - const fs::path path = fs::temp_directory_path() / "DynamicScaler.bin"; - fl::save(path, dynamicScaler); - - std::shared_ptr dynamicScaler1; - fl::load(path, dynamicScaler1); - ASSERT_TRUE( - dynamicScaler->getScaleFactor() == dynamicScaler1->getScaleFactor()); + auto dynamicScaler = std::make_shared( + 32, // initFactor + 32, // maxFactor + 100 // updateInterval + ); + + const fs::path path = fs::temp_directory_path() / "DynamicScaler.bin"; + fl::save(path, dynamicScaler); + + std::shared_ptr dynamicScaler1; + fl::load(path, dynamicScaler1); + ASSERT_TRUE( + dynamicScaler->getScaleFactor() == dynamicScaler1->getScaleFactor() + ); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/runtime/test/common/SequentialBuilderTest.cpp b/flashlight/pkg/runtime/test/common/SequentialBuilderTest.cpp index 9657939..d60d8da 100644 --- a/flashlight/pkg/runtime/test/common/SequentialBuilderTest.cpp +++ b/flashlight/pkg/runtime/test/common/SequentialBuilderTest.cpp @@ -22,66 +22,66 @@ fs::path archDir = ""; } // namespace TEST(SequentialBuilderTest, SeqModule) { - if (FL_BACKEND_CPU) { - GTEST_SKIP() << "Bidirectional RNN not supported"; - } - 
const fs::path archfile = archDir / "arch.txt"; - int nchannel = 4; - int nclass = 40; - int batchsize = 2; - int inputsteps = 100; + if(FL_BACKEND_CPU) { + GTEST_SKIP() << "Bidirectional RNN not supported"; + } + const fs::path archfile = archDir / "arch.txt"; + int nchannel = 4; + int nclass = 40; + int batchsize = 2; + int inputsteps = 100; - auto model = buildSequentialModule(archfile, nchannel, nclass); + auto model = buildSequentialModule(archfile, nchannel, nclass); - auto input = fl::randn({inputsteps, 1, nchannel, batchsize}, fl::dtype::f32); + auto input = fl::randn({inputsteps, 1, nchannel, batchsize}, fl::dtype::f32); - auto output = model->forward(noGrad(input)); + auto output = model->forward(noGrad(input)); - ASSERT_EQ(output.shape(), Shape({nclass, inputsteps, batchsize})); + ASSERT_EQ(output.shape(), Shape({nclass, inputsteps, batchsize})); - batchsize = 1; - input = fl::randn({inputsteps, 1, nchannel, batchsize}, fl::dtype::f32); - output = model->forward(noGrad(input)); - ASSERT_EQ(output.shape(), Shape({nclass, inputsteps, batchsize})); + batchsize = 1; + input = fl::randn({inputsteps, 1, nchannel, batchsize}, fl::dtype::f32); + output = model->forward(noGrad(input)); + ASSERT_EQ(output.shape(), Shape({nclass, inputsteps, batchsize})); } TEST(SequentialBuilderTest, Serialization) { - if (FL_BACKEND_CPU) { - GTEST_SKIP() << "Bidirectional RNN not supported"; - } - char* user = getenv("USER"); - std::string userstr = "unknown"; - if (user != nullptr) { - userstr = std::string(user); - } - const fs::path path = fs::temp_directory_path() / "test.mdl"; - const fs::path archfile = archDir / "arch.txt"; + if(FL_BACKEND_CPU) { + GTEST_SKIP() << "Bidirectional RNN not supported"; + } + char* user = getenv("USER"); + std::string userstr = "unknown"; + if(user != nullptr) { + userstr = std::string(user); + } + const fs::path path = fs::temp_directory_path() / "test.mdl"; + const fs::path archfile = archDir / "arch.txt"; - int C = 1, N = 5, B = 1, T = 10; - 
auto model = buildSequentialModule(archfile, C, N); + int C = 1, N = 5, B = 1, T = 10; + auto model = buildSequentialModule(archfile, C, N); - auto input = noGrad(fl::randn({T, 1, C, B}, fl::dtype::f32)); - auto output = model->forward(input); + auto input = noGrad(fl::randn({T, 1, C, B}, fl::dtype::f32)); + auto output = model->forward(input); - save(path, model); + save(path, model); - std::shared_ptr loaded; - load(path, loaded); + std::shared_ptr loaded; + load(path, loaded); - auto outputl = loaded->forward(input); + auto outputl = loaded->forward(input); - ASSERT_TRUE(allParamsClose(*loaded.get(), *model)); - ASSERT_TRUE(allClose(outputl.tensor(), output.tensor())); + ASSERT_TRUE(allParamsClose(*loaded.get(), *model)); + ASSERT_TRUE(allClose(outputl.tensor(), output.tensor())); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); // Resolve directory for arch #ifdef ARCHDIR - archDir = ARCHDIR; + archDir = ARCHDIR; #endif - return RUN_ALL_TESTS(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/runtime/test/plugin/ModulePluginTest.cpp b/flashlight/pkg/runtime/test/plugin/ModulePluginTest.cpp index bfe33a4..ebafa4d 100644 --- a/flashlight/pkg/runtime/test/plugin/ModulePluginTest.cpp +++ b/flashlight/pkg/runtime/test/plugin/ModulePluginTest.cpp @@ -17,35 +17,35 @@ using namespace fl; fs::path pluginDir; TEST(ModulePluginTest, ModulePlugin) { - const fs::path libfile = pluginDir / "test_module_plugin.so"; - - const int ninput = 80; - const int noutput = 10; - const int batchsize = 4; - - // Note: the following works only if the plugin is *not* statically - // linked against AF/FL. - // - // auto model = fl::pkg::runtime::ModulePlugin(libfile).arch(ninput, noutput); - // - // If AF/FL are linked, then there will be some issue at deallocation of - // the plugin. 
For that reason, we stick to the following conservative - // way in this test (plugin destroyed after model). - fl::pkg::runtime::ModulePlugin plugin(libfile); - auto model = plugin.arch(ninput, noutput); - auto input = fl::randn({ninput, batchsize}, fl::dtype::f32); - auto output = model->forward({noGrad(input)}).front(); - ASSERT_EQ(output.shape(), Shape({noutput, batchsize})); + const fs::path libfile = pluginDir / "test_module_plugin.so"; + + const int ninput = 80; + const int noutput = 10; + const int batchsize = 4; + + // Note: the following works only if the plugin is *not* statically + // linked against AF/FL. + // + // auto model = fl::pkg::runtime::ModulePlugin(libfile).arch(ninput, noutput); + // + // If AF/FL are linked, then there will be some issue at deallocation of + // the plugin. For that reason, we stick to the following conservative + // way in this test (plugin destroyed after model). + fl::pkg::runtime::ModulePlugin plugin(libfile); + auto model = plugin.arch(ninput, noutput); + auto input = fl::randn({ninput, batchsize}, fl::dtype::f32); + auto output = model->forward({noGrad(input)}).front(); + ASSERT_EQ(output.shape(), Shape({noutput, batchsize})); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); - // Resolve directory for arch + // Resolve directory for arch #ifdef PLUGINDIR - pluginDir = PLUGINDIR; + pluginDir = PLUGINDIR; #endif - return RUN_ALL_TESTS(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/runtime/test/plugin/test_module_plugin.cpp b/flashlight/pkg/runtime/test/plugin/test_module_plugin.cpp index c5bcd4c..3551d8d 100644 --- a/flashlight/pkg/runtime/test/plugin/test_module_plugin.cpp +++ b/flashlight/pkg/runtime/test/plugin/test_module_plugin.cpp @@ -9,7 +9,7 @@ #include "flashlight/fl/contrib/contrib.h" extern "C" fl::Module* createModule(int64_t nFeature, int64_t nLabel) { - auto seq = new fl::Sequential(); - 
seq->add(std::make_shared(nFeature, nLabel)); - return seq; + auto seq = new fl::Sequential(); + seq->add(std::make_shared(nFeature, nLabel)); + return seq; } diff --git a/flashlight/pkg/speech/audio/feature/Ceplifter.cpp b/flashlight/pkg/speech/audio/feature/Ceplifter.cpp index 8632723..91f3a28 100644 --- a/flashlight/pkg/speech/audio/feature/Ceplifter.cpp +++ b/flashlight/pkg/speech/audio/feature/Ceplifter.cpp @@ -14,31 +14,33 @@ namespace fl::lib::audio { -Ceplifter::Ceplifter(int numfilters, int lifterparam) - : numFilters_(numfilters), lifterParam_(lifterparam), coefs_(numFilters_) { - std::iota(coefs_.begin(), coefs_.end(), 0.0); - for (auto& c : coefs_) { - c = 1.0 + 0.5 * lifterParam_ * std::sin(M_PI * c / lifterParam_); - } +Ceplifter::Ceplifter(int numfilters, int lifterparam) : numFilters_(numfilters), + lifterParam_(lifterparam), + coefs_(numFilters_) { + std::iota(coefs_.begin(), coefs_.end(), 0.0); + for(auto& c : coefs_) { + c = 1.0 + 0.5 * lifterParam_ * std::sin(M_PI * c / lifterParam_); + } } std::vector Ceplifter::apply(const std::vector& input) const { - auto output(input); - applyInPlace(output); - return output; + auto output(input); + applyInPlace(output); + return output; } void Ceplifter::applyInPlace(std::vector& input) const { - if (input.size() % numFilters_ != 0) { - throw std::invalid_argument( - "Ceplifter: input size is not divisible by numFilters"); - } - size_t n = 0; - for (auto& in : input) { - in *= coefs_[n++]; - if (n == numFilters_) { - n = 0; + if(input.size() % numFilters_ != 0) { + throw std::invalid_argument( + "Ceplifter: input size is not divisible by numFilters" + ); + } + size_t n = 0; + for(auto& in : input) { + in *= coefs_[n++]; + if(n == numFilters_) { + n = 0; + } } - } } } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/Ceplifter.h b/flashlight/pkg/speech/audio/feature/Ceplifter.h index d451d1b..019aff4 100644 --- a/flashlight/pkg/speech/audio/feature/Ceplifter.h +++ 
b/flashlight/pkg/speech/audio/feature/Ceplifter.h @@ -12,24 +12,24 @@ namespace fl { namespace lib { -namespace audio { + namespace audio { // Re-scale the cepstral coefficients using liftering -// c'(n) = c(n) * (1 + 0.5 * L * sin(pi * n/ L)) where L is lifterparam +// c'(n) = c(n) * (1 + 0.5 * L * sin(pi * n/ L)) where L is lifterparam -class Ceplifter { - public: - Ceplifter(int numfilters, int lifterparam); + class Ceplifter { + public: + Ceplifter(int numfilters, int lifterparam); - std::vector apply(const std::vector& input) const; + std::vector apply(const std::vector& input) const; - void applyInPlace(std::vector& input) const; + void applyInPlace(std::vector& input) const; - private: - int numFilters_; // number of filterbank channels - int lifterParam_; // liftering parameter - std::vector coefs_; // coefficients to scale cepstral coefficients -}; -} // namespace audio + private: + int numFilters_; // number of filterbank channels + int lifterParam_; // liftering parameter + std::vector coefs_; // coefficients to scale cepstral coefficients + }; + } // namespace audio } // namespace lib } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/Dct.cpp b/flashlight/pkg/speech/audio/feature/Dct.cpp index e84b25d..c88c339 100644 --- a/flashlight/pkg/speech/audio/feature/Dct.cpp +++ b/flashlight/pkg/speech/audio/feature/Dct.cpp @@ -15,19 +15,18 @@ namespace fl::lib::audio { -Dct::Dct(int numfilters, int numceps) - : numFilters_(numfilters), - numCeps_(numceps), - dctMat_(numfilters * numceps) { - for (size_t f = 0; f < numFilters_; ++f) { - for (size_t c = 0; c < numCeps_; ++c) { - dctMat_[f * numCeps_ + c] = std::sqrt(2.0 / numFilters_) * - std::cos(M_PI * c * (f + 0.5) / numFilters_); +Dct::Dct(int numfilters, int numceps) : numFilters_(numfilters), + numCeps_(numceps), + dctMat_(numfilters * numceps) { + for(size_t f = 0; f < numFilters_; ++f) { + for(size_t c = 0; c < numCeps_; ++c) { + dctMat_[f * numCeps_ + c] = std::sqrt(2.0 / numFilters_) + * 
std::cos(M_PI * c * (f + 0.5) / numFilters_); + } } - } } std::vector Dct::apply(const std::vector& input) const { - return cblasGemm(input, dctMat_, numCeps_, numFilters_); + return cblasGemm(input, dctMat_, numCeps_, numFilters_); } } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/Dct.h b/flashlight/pkg/speech/audio/feature/Dct.h index 7092489..dbf8e0f 100644 --- a/flashlight/pkg/speech/audio/feature/Dct.h +++ b/flashlight/pkg/speech/audio/feature/Dct.h @@ -12,23 +12,23 @@ namespace fl { namespace lib { -namespace audio { + namespace audio { // Compute Discrete Cosine Transform -// c(i) = sqrt(2/N) SUM_j (m(j) * cos(pi * i * (j - 0.5)/ N)) -// where j in [1, N], m - log filterbank amplitudes +// c(i) = sqrt(2/N) SUM_j (m(j) * cos(pi * i * (j - 0.5)/ N)) +// where j in [1, N], m - log filterbank amplitudes -class Dct { - public: - Dct(int numfilters, int numceps); + class Dct { + public: + Dct(int numfilters, int numceps); - std::vector apply(const std::vector& input) const; + std::vector apply(const std::vector& input) const; - private: - int numFilters_; // Number of filterbank channels - int numCeps_; // Number of cepstral coefficients - std::vector dctMat_; // Dct matrix -}; -} // namespace audio + private: + int numFilters_; // Number of filterbank channels + int numCeps_; // Number of cepstral coefficients + std::vector dctMat_; // Dct matrix + }; + } // namespace audio } // namespace lib } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/Derivatives.cpp b/flashlight/pkg/speech/audio/feature/Derivatives.cpp index 6fd6684..084f845 100644 --- a/flashlight/pkg/speech/audio/feature/Derivatives.cpp +++ b/flashlight/pkg/speech/audio/feature/Derivatives.cpp @@ -13,73 +13,79 @@ namespace fl::lib::audio { -Derivatives::Derivatives(int deltawindow, int accwindow) - : deltaWindow_(deltawindow), accWindow_(accwindow) {} +Derivatives::Derivatives(int deltawindow, int accwindow) : deltaWindow_(deltawindow), + accWindow_(accwindow) {} 
std::vector Derivatives::apply( const std::vector& input, - int numfeat) const { - if (input.size() % numfeat != 0) { - throw std::invalid_argument( - "Derivatives: input size is not divisible by numFeatures"); - } - // Compute deltas - if (deltaWindow_ <= 0) { - return input; - } + int numfeat +) const { + if(input.size() % numfeat != 0) { + throw std::invalid_argument( + "Derivatives: input size is not divisible by numFeatures" + ); + } + // Compute deltas + if(deltaWindow_ <= 0) { + return input; + } - auto deltas = computeDerivative(input, deltaWindow_, numfeat); - size_t szMul = 2; - std::vector doubledeltas; - if (accWindow_ > 0) { - // Compute double deltas (only if required) - szMul = 3; - doubledeltas = computeDerivative(deltas, accWindow_, numfeat); - } - std::vector output(input.size() * szMul); - int numframes = input.size() / numfeat; - for (size_t i = 0; i < numframes; ++i) { - size_t curInIdx = i * numfeat; - size_t curOutIdx = curInIdx * szMul; - // copy input - std::copy( - input.data() + curInIdx, - input.data() + curInIdx + numfeat, - output.data() + curOutIdx); - // copy deltas - std::copy( - deltas.data() + curInIdx, - deltas.data() + curInIdx + numfeat, - output.data() + curOutIdx + numfeat); - // copy double-deltas - if (accWindow_ > 0) { - std::copy( - doubledeltas.data() + curInIdx, - doubledeltas.data() + curInIdx + numfeat, - output.data() + curOutIdx + 2 * numfeat); + auto deltas = computeDerivative(input, deltaWindow_, numfeat); + size_t szMul = 2; + std::vector doubledeltas; + if(accWindow_ > 0) { + // Compute double deltas (only if required) + szMul = 3; + doubledeltas = computeDerivative(deltas, accWindow_, numfeat); + } + std::vector output(input.size() * szMul); + int numframes = input.size() / numfeat; + for(size_t i = 0; i < numframes; ++i) { + size_t curInIdx = i * numfeat; + size_t curOutIdx = curInIdx * szMul; + // copy input + std::copy( + input.data() + curInIdx, + input.data() + curInIdx + numfeat, + output.data() + 
curOutIdx + ); + // copy deltas + std::copy( + deltas.data() + curInIdx, + deltas.data() + curInIdx + numfeat, + output.data() + curOutIdx + numfeat + ); + // copy double-deltas + if(accWindow_ > 0) { + std::copy( + doubledeltas.data() + curInIdx, + doubledeltas.data() + curInIdx + numfeat, + output.data() + curOutIdx + 2 * numfeat + ); + } } - } - return output; + return output; } std::vector Derivatives::computeDerivative( const std::vector& input, int windowlen, - int numfeat) const { - int numframes = input.size() / numfeat; - std::vector output(input.size(), 0.0); - float denominator = (windowlen * (windowlen + 1) * (2 * windowlen + 1)) / 3.0; - for (size_t i = 0; i < numframes; ++i) { - for (size_t j = 0; j < numfeat; ++j) { - size_t curIdx = i * numfeat + j; - for (size_t d = 1; d <= windowlen; ++d) { - output[curIdx] += d * - (input[curIdx + std::min((numframes - i - 1), d) * numfeat] - - input[curIdx - std::min(i, d) * numfeat]); - } - output[curIdx] /= denominator; + int numfeat +) const { + int numframes = input.size() / numfeat; + std::vector output(input.size(), 0.0); + float denominator = (windowlen * (windowlen + 1) * (2 * windowlen + 1)) / 3.0; + for(size_t i = 0; i < numframes; ++i) { + for(size_t j = 0; j < numfeat; ++j) { + size_t curIdx = i * numfeat + j; + for(size_t d = 1; d <= windowlen; ++d) { + output[curIdx] += d + * (input[curIdx + std::min((numframes - i - 1), d) * numfeat] + - input[curIdx - std::min(i, d) * numfeat]); + } + output[curIdx] /= denominator; + } } - } - return output; + return output; } } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/Derivatives.h b/flashlight/pkg/speech/audio/feature/Derivatives.h index 3e61c16..54eb8a8 100644 --- a/flashlight/pkg/speech/audio/feature/Derivatives.h +++ b/flashlight/pkg/speech/audio/feature/Derivatives.h @@ -12,29 +12,30 @@ namespace fl { namespace lib { -namespace audio { + namespace audio { // Compute first order (deltas) and second order (acceleration) derivatives of 
-// cepstral coefficients -// d(i) = 0.5 * SUM_t (t * (c(i + t) - c (i - t))) / SUM_t t^2 -// where t in [1, maxlagsize] - -class Derivatives { - public: - Derivatives(int deltawindow, int accwindow); - - std::vector apply(const std::vector& input, int numfeat) const; - - private: - int deltaWindow_; // delta derivatives lag size - int accWindow_; // acceleration derivatives lag size - - // Helper function to compute derivatives of single order - std::vector computeDerivative( - const std::vector& input, - int windowlen, - int numfeat) const; -}; -} // namespace audio +// cepstral coefficients +// d(i) = 0.5 * SUM_t (t * (c(i + t) - c (i - t))) / SUM_t t^2 +// where t in [1, maxlagsize] + + class Derivatives { + public: + Derivatives(int deltawindow, int accwindow); + + std::vector apply(const std::vector& input, int numfeat) const; + + private: + int deltaWindow_; // delta derivatives lag size + int accWindow_; // acceleration derivatives lag size + + // Helper function to compute derivatives of single order + std::vector computeDerivative( + const std::vector& input, + int windowlen, + int numfeat + ) const; + }; + } // namespace audio } // namespace lib } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/Dither.cpp b/flashlight/pkg/speech/audio/feature/Dither.cpp index 2cd4c8a..08a135f 100644 --- a/flashlight/pkg/speech/audio/feature/Dither.cpp +++ b/flashlight/pkg/speech/audio/feature/Dither.cpp @@ -11,19 +11,19 @@ namespace fl::lib::audio { -Dither::Dither(float ditherVal) - : ditherVal_(ditherVal), rng_((ditherVal > 0.0) ? 123456 : time(nullptr)){}; +Dither::Dither(float ditherVal) : ditherVal_(ditherVal), + rng_((ditherVal > 0.0) ? 
123456 : time(nullptr)) {}; std::vector Dither::apply(const std::vector& input) { - auto output(input); - applyInPlace(output); - return output; + auto output(input); + applyInPlace(output); + return output; } void Dither::applyInPlace(std::vector& input) { - std::uniform_real_distribution distribution(0.0, 1.0); - for (auto& i : input) { - i += ditherVal_ * distribution(rng_); - } + std::uniform_real_distribution distribution(0.0, 1.0); + for(auto& i : input) { + i += ditherVal_ * distribution(rng_); + } } } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/Dither.h b/flashlight/pkg/speech/audio/feature/Dither.h index 8585630..f69a676 100644 --- a/flashlight/pkg/speech/audio/feature/Dither.h +++ b/flashlight/pkg/speech/audio/feature/Dither.h @@ -12,27 +12,27 @@ namespace fl { namespace lib { -namespace audio { + namespace audio { // Dither the signal by adding small amount of random noise to the signal -// s'(n) = s(n) + q * RND() where RND() is uniformly distributed in [-1, 1) -// and `q` is the dithering constant +// s'(n) = s(n) + q * RND() where RND() is uniformly distributed in [-1, 1) +// and `q` is the dithering constant // Similar to HTK, positive value of `q` causes the same noise signal to be // added everytime and with negative value of `q`, noise is random and the same // file may produce slightly different results in different trials -class Dither { - public: - explicit Dither(float ditherVal); + class Dither { + public: + explicit Dither(float ditherVal); - std::vector apply(const std::vector& input); + std::vector apply(const std::vector& input); - void applyInPlace(std::vector& input); + void applyInPlace(std::vector& input); - private: - float ditherVal_; - std::mt19937 rng_; // Standard mersenne_twister_engine -}; -} // namespace audio + private: + float ditherVal_; + std::mt19937 rng_; // Standard mersenne_twister_engine + }; + } // namespace audio } // namespace lib } // namespace fl diff --git 
a/flashlight/pkg/speech/audio/feature/FeatureParams.h b/flashlight/pkg/speech/audio/feature/FeatureParams.h index 35aa881..f5562cb 100644 --- a/flashlight/pkg/speech/audio/feature/FeatureParams.h +++ b/flashlight/pkg/speech/audio/feature/FeatureParams.h @@ -13,159 +13,160 @@ namespace fl { namespace lib { -namespace audio { - -enum class WindowType { - HAMMING = 0, - HANNING = 1, -}; -enum class FrequencyScale { - MEL = 0, - LINEAR = 1, - LOG10 = 2, -}; - -struct FeatureParams { - // frequency (Hz) of speech signal recording - int64_t samplingFreq; - - // frame size in milliseconds - int64_t frameSizeMs; - - // frame step size in milliseconds - int64_t frameStrideMs; - - // number of filterbank channels - // Kaldi recommends using 23 for 16KHz and 15 for 8KHz sampled data - int64_t numFilterbankChans; - - // lower cutoff frequency (HZ) for the filterbank - int64_t lowFreqFilterbank; - - // upper cutoff frequency (HZ) for the filterbank - int64_t highFreqFilterbank; - - // number of cepstral coefficients - int64_t numCepstralCoeffs; - - // liftering parameter - int64_t lifterParam; - - // number of delta (first order regression) coefficients - int64_t deltaWindow; - - // number of acceleration (second order regression) coefficients - int64_t accWindow; - - // analysis window function handle for framing (hamming by default) - WindowType windowType; - - // preemphasis filtering coefficient (0.7 default) - float preemCoef; - - // option controlling the size of the mel floor (1.0 default) - float melFloor; - - // dithering constant (0.0 default ==> no dithering) - float ditherVal; - - // use power instead of magnitude for filterbank energies - bool usePower; - - // If true, append log energy term as a feature to MFSC - // For MFCC, C0 is replaced with energy term - bool useEnergy; - - // If true, use energy before PreEmphasis and Windowing - bool rawEnergy; - - // If true, remove DC offset from the signal frames - bool zeroMeanFrame; - - FeatureParams( - int64_t 
samplingfreq = 16000, - int64_t framesizems = 25, - int64_t framestridems = 10, - int64_t numfilterbankchans = 23, - int64_t lowfreqfilterbank = 0.0, - int64_t highfreqfilterbank = -1.0, // If -ve value, then samplingFreq/2 - int64_t numcepstralcoeffs = 13, - int64_t lifterparam = 22, - int64_t deltawindow = 2, - int64_t accwindow = 2, - WindowType windowtype = WindowType::HAMMING, - float preemcoef = 0.97, - float melfloor = 1.0, - float ditherval = 0.0, - bool usepower = true, - bool usenergy = true, - bool rawenergy = true, - bool zeromeanframe = true) - : samplingFreq(samplingfreq), - frameSizeMs(framesizems), - frameStrideMs(framestridems), - numFilterbankChans(numfilterbankchans), - lowFreqFilterbank(lowfreqfilterbank), - highFreqFilterbank(highfreqfilterbank), - numCepstralCoeffs(numcepstralcoeffs), - lifterParam(lifterparam), - deltaWindow(deltawindow), - accWindow(accwindow), - windowType(windowtype), - preemCoef(preemcoef), - melFloor(melfloor), - ditherVal(ditherval), - usePower(usepower), - useEnergy(usenergy), - rawEnergy(rawenergy), - zeroMeanFrame(zeromeanframe) {} - - // frame size (no of samples) - // the last frame is discarded, if less than the frame size - int64_t numFrameSizeSamples() const { - return static_cast(round(1e-3 * frameSizeMs * samplingFreq)); - } - - int64_t numFrameStrideSamples() const { - return static_cast(round(1e-3 * frameStrideMs * samplingFreq)); - } - - int64_t nFft() const { - int64_t nsamples = numFrameSizeSamples(); - return (nsamples > 0) - ? 1 << static_cast(std::ceil(std::log2(nsamples))) - : 0; - } - - int64_t filterFreqResponseLen() const { - return (nFft() >> 1) + 1; - } - - int64_t powSpecFeatSz() const { - return filterFreqResponseLen(); - } - - int64_t mfscFeatSz() const { - int64_t devMultiplier = - 1 + (deltaWindow > 0 ? 1 : 0) + (accWindow > 0 ? 1 : 0); - return (numFilterbankChans + (useEnergy ? 
1 : 0)) * (devMultiplier); - } - - int64_t mfccFeatSz() const { - int64_t devMultiplier = - 1 + (deltaWindow > 0 ? 1 : 0) + (accWindow > 0 ? 1 : 0); - return numCepstralCoeffs * devMultiplier; - } - - int64_t numFrames(int64_t inSize) const { - auto frameSize = numFrameSizeSamples(); - auto frameStride = numFrameStrideSamples(); - if (frameStride <= 0 || inSize < frameSize) { - return 0; - } - return 1 + std::floor((inSize - frameSize) * 1.0 / frameStride); - } -}; -} // namespace audio + namespace audio { + + enum class WindowType { + HAMMING = 0, + HANNING = 1, + }; + enum class FrequencyScale { + MEL = 0, + LINEAR = 1, + LOG10 = 2, + }; + + struct FeatureParams { + // frequency (Hz) of speech signal recording + int64_t samplingFreq; + + // frame size in milliseconds + int64_t frameSizeMs; + + // frame step size in milliseconds + int64_t frameStrideMs; + + // number of filterbank channels + // Kaldi recommends using 23 for 16KHz and 15 for 8KHz sampled data + int64_t numFilterbankChans; + + // lower cutoff frequency (HZ) for the filterbank + int64_t lowFreqFilterbank; + + // upper cutoff frequency (HZ) for the filterbank + int64_t highFreqFilterbank; + + // number of cepstral coefficients + int64_t numCepstralCoeffs; + + // liftering parameter + int64_t lifterParam; + + // number of delta (first order regression) coefficients + int64_t deltaWindow; + + // number of acceleration (second order regression) coefficients + int64_t accWindow; + + // analysis window function handle for framing (hamming by default) + WindowType windowType; + + // preemphasis filtering coefficient (0.7 default) + float preemCoef; + + // option controlling the size of the mel floor (1.0 default) + float melFloor; + + // dithering constant (0.0 default ==> no dithering) + float ditherVal; + + // use power instead of magnitude for filterbank energies + bool usePower; + + // If true, append log energy term as a feature to MFSC + // For MFCC, C0 is replaced with energy term + bool useEnergy; + 
+ // If true, use energy before PreEmphasis and Windowing + bool rawEnergy; + + // If true, remove DC offset from the signal frames + bool zeroMeanFrame; + + FeatureParams( + int64_t samplingfreq = 16000, + int64_t framesizems = 25, + int64_t framestridems = 10, + int64_t numfilterbankchans = 23, + int64_t lowfreqfilterbank = 0.0, + int64_t highfreqfilterbank = -1.0, // If -ve value, then samplingFreq/2 + int64_t numcepstralcoeffs = 13, + int64_t lifterparam = 22, + int64_t deltawindow = 2, + int64_t accwindow = 2, + WindowType windowtype = WindowType::HAMMING, + float preemcoef = 0.97, + float melfloor = 1.0, + float ditherval = 0.0, + bool usepower = true, + bool usenergy = true, + bool rawenergy = true, + bool zeromeanframe = true + ) + : samplingFreq(samplingfreq), + frameSizeMs(framesizems), + frameStrideMs(framestridems), + numFilterbankChans(numfilterbankchans), + lowFreqFilterbank(lowfreqfilterbank), + highFreqFilterbank(highfreqfilterbank), + numCepstralCoeffs(numcepstralcoeffs), + lifterParam(lifterparam), + deltaWindow(deltawindow), + accWindow(accwindow), + windowType(windowtype), + preemCoef(preemcoef), + melFloor(melfloor), + ditherVal(ditherval), + usePower(usepower), + useEnergy(usenergy), + rawEnergy(rawenergy), + zeroMeanFrame(zeromeanframe) {} + + // frame size (no of samples) + // the last frame is discarded, if less than the frame size + int64_t numFrameSizeSamples() const { + return static_cast(round(1e-3 * frameSizeMs * samplingFreq)); + } + + int64_t numFrameStrideSamples() const { + return static_cast(round(1e-3 * frameStrideMs * samplingFreq)); + } + + int64_t nFft() const { + int64_t nsamples = numFrameSizeSamples(); + return (nsamples > 0) + ? 1 << static_cast(std::ceil(std::log2(nsamples))) + : 0; + } + + int64_t filterFreqResponseLen() const { + return (nFft() >> 1) + 1; + } + + int64_t powSpecFeatSz() const { + return filterFreqResponseLen(); + } + + int64_t mfscFeatSz() const { + int64_t devMultiplier = + 1 + (deltaWindow > 0 ? 
1 : 0) + (accWindow > 0 ? 1 : 0); + return (numFilterbankChans + (useEnergy ? 1 : 0)) * (devMultiplier); + } + + int64_t mfccFeatSz() const { + int64_t devMultiplier = + 1 + (deltaWindow > 0 ? 1 : 0) + (accWindow > 0 ? 1 : 0); + return numCepstralCoeffs * devMultiplier; + } + + int64_t numFrames(int64_t inSize) const { + auto frameSize = numFrameSizeSamples(); + auto frameStride = numFrameStrideSamples(); + if(frameStride <= 0 || inSize < frameSize) { + return 0; + } + return 1 + std::floor((inSize - frameSize) * 1.0 / frameStride); + } + }; + } // namespace audio } // namespace lib } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/Mfcc.cpp b/flashlight/pkg/speech/audio/feature/Mfcc.cpp index 84ae29b..838893d 100644 --- a/flashlight/pkg/speech/audio/feature/Mfcc.cpp +++ b/flashlight/pkg/speech/audio/feature/Mfcc.cpp @@ -13,61 +13,60 @@ namespace fl::lib::audio { -Mfcc::Mfcc(const FeatureParams& params) - : Mfsc(params), - dct_(params.numFilterbankChans, params.numCepstralCoeffs), - ceplifter_(params.numCepstralCoeffs, params.lifterParam), - derivatives_(params.deltaWindow, params.accWindow) { - validateMfccParams(); +Mfcc::Mfcc(const FeatureParams& params) : Mfsc(params), + dct_(params.numFilterbankChans, params.numCepstralCoeffs), + ceplifter_(params.numCepstralCoeffs, params.lifterParam), + derivatives_(params.deltaWindow, params.accWindow) { + validateMfccParams(); } std::vector Mfcc::apply(const std::vector& input) { - auto frames = frameSignal(input, this->featParams_); - if (frames.empty()) { - return {}; - } + auto frames = frameSignal(input, this->featParams_); + if(frames.empty()) { + return {}; + } - int nSamples = this->featParams_.numFrameSizeSamples(); - int nFrames = frames.size() / nSamples; + int nSamples = this->featParams_.numFrameSizeSamples(); + int nFrames = frames.size() / nSamples; - std::vector energy(nFrames); - if (this->featParams_.useEnergy && this->featParams_.rawEnergy) { - for (size_t f = 0; f < nFrames; ++f) { - auto 
begin = frames.data() + f * nSamples; - energy[f] = - std::log(std::inner_product(begin, begin + nSamples, begin, 0.0)); + std::vector energy(nFrames); + if(this->featParams_.useEnergy && this->featParams_.rawEnergy) { + for(size_t f = 0; f < nFrames; ++f) { + auto begin = frames.data() + f * nSamples; + energy[f] = + std::log(std::inner_product(begin, begin + nSamples, begin, 0.0)); + } } - } - auto mfscfeat = this->mfscImpl(frames); - auto cep = dct_.apply(mfscfeat); - ceplifter_.applyInPlace(cep); + auto mfscfeat = this->mfscImpl(frames); + auto cep = dct_.apply(mfscfeat); + ceplifter_.applyInPlace(cep); - auto nFeat = this->featParams_.numCepstralCoeffs; - if (this->featParams_.useEnergy) { - if (!this->featParams_.rawEnergy) { - for (size_t f = 0; f < nFrames; ++f) { - auto begin = frames.data() + f * nSamples; - energy[f] = - std::log(std::inner_product(begin, begin + nSamples, begin, 0.0)); - } - } - // Replace C0 with energy - for (size_t f = 0; f < nFrames; ++f) { - cep[f * nFeat] = energy[f]; + auto nFeat = this->featParams_.numCepstralCoeffs; + if(this->featParams_.useEnergy) { + if(!this->featParams_.rawEnergy) { + for(size_t f = 0; f < nFrames; ++f) { + auto begin = frames.data() + f * nSamples; + energy[f] = + std::log(std::inner_product(begin, begin + nSamples, begin, 0.0)); + } + } + // Replace C0 with energy + for(size_t f = 0; f < nFrames; ++f) { + cep[f * nFeat] = energy[f]; + } } - } - return derivatives_.apply(cep, nFeat); + return derivatives_.apply(cep, nFeat); } int Mfcc::outputSize(int inputSz) { - return this->featParams_.mfccFeatSz() * this->featParams_.numFrames(inputSz); + return this->featParams_.mfccFeatSz() * this->featParams_.numFrames(inputSz); } void Mfcc::validateMfccParams() const { - this->validatePowSpecParams(); - this->validateMfscParams(); - if (this->featParams_.lifterParam < 0) { - throw std::invalid_argument("Mfcc: lifterparam must be nonnegative"); - } + this->validatePowSpecParams(); + this->validateMfscParams(); + 
if(this->featParams_.lifterParam < 0) { + throw std::invalid_argument("Mfcc: lifterparam must be nonnegative"); + } } } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/Mfcc.h b/flashlight/pkg/speech/audio/feature/Mfcc.h index eb5325b..4c19f07 100644 --- a/flashlight/pkg/speech/audio/feature/Mfcc.h +++ b/flashlight/pkg/speech/audio/feature/Mfcc.h @@ -16,61 +16,61 @@ namespace fl { namespace lib { -namespace audio { + namespace audio { // Computes Mel Frequency Cepstral Coefficient (MFCC) for a speech signal. // Feature calculation is similar to the calculation of default HTK MFCCs except // for Energy Normalization. // TODO: Support -// 1. ENORMALIZE -// 2. Cepstral Mean Normalisation -// 3. Vocal Tract Length Normalisation +// 1. ENORMALIZE +// 2. Cepstral Mean Normalisation +// 3. Vocal Tract Length Normalisation // Example usage: -// FeatureParams params; -// Mfcc mfcc(params); Tensor input = fl::rand({123456, 987}); -// Tensor mfccfeatures = mfcc->apply(input); +// FeatureParams params; +// Mfcc mfcc(params); Tensor input = fl::rand({123456, 987}); +// Tensor mfccfeatures = mfcc->apply(input); // // References -// [1] Young, S., Evermann, G., Gales, M., Hain, T., Kershaw, D., -// Liu, X., Moore, G., Odell, J., Ollason, D., Povey, D., -// Valtchev, V., Woodland, P., 2006. The HTK Book (for HTK -// Version 3.4.1). Engineering Department, Cambridge University. -// URL: http://htk.eng.cam.ac.uk -// [2] Daniel Povey , Arnab Ghoshal , Gilles Boulianne , Nagendra Goel , Mirko -// Hannemann , Yanmin Qian , Petr Schwarz , Georg Stemmer - The kaldi -// speech recognition toolkit -// URL: http://kaldi-asr.org/ -// [3] Ellis, D., 2005. PLP and RASTA (and MFCC, and inversion) in Matlab -// URL: https://labrosa.ee.columbia.edu/matlab/rastamat/ -// [4] Huang, X., Acero, A., Hon, H., 2001. Spoken Language -// Processing: A guide to theory, algorithm, and system -// development. Prentice Hall, Upper Saddle River, NJ, -// USA (pp. 314-315). 
-// [5] Kamil Wojcicki, HTK MFCC MATLAB, -// URL: -// https://www.mathworks.com/matlabcentral/fileexchange/32849-htk-mfcc-matlab +// [1] Young, S., Evermann, G., Gales, M., Hain, T., Kershaw, D., +// Liu, X., Moore, G., Odell, J., Ollason, D., Povey, D., +// Valtchev, V., Woodland, P., 2006. The HTK Book (for HTK +// Version 3.4.1). Engineering Department, Cambridge University. +// URL: http://htk.eng.cam.ac.uk +// [2] Daniel Povey , Arnab Ghoshal , Gilles Boulianne , Nagendra Goel , Mirko +// Hannemann , Yanmin Qian , Petr Schwarz , Georg Stemmer - The kaldi +// speech recognition toolkit +// URL: http://kaldi-asr.org/ +// [3] Ellis, D., 2005. PLP and RASTA (and MFCC, and inversion) in Matlab +// URL: https://labrosa.ee.columbia.edu/matlab/rastamat/ +// [4] Huang, X., Acero, A., Hon, H., 2001. Spoken Language +// Processing: A guide to theory, algorithm, and system +// development. Prentice Hall, Upper Saddle River, NJ, +// USA (pp. 314-315). +// [5] Kamil Wojcicki, HTK MFCC MATLAB, +// URL: +// https://www.mathworks.com/matlabcentral/fileexchange/32849-htk-mfcc-matlab // -class Mfcc : public Mfsc { - public: - explicit Mfcc(const FeatureParams& params); + class Mfcc : public Mfsc { + public: + explicit Mfcc(const FeatureParams& params); - virtual ~Mfcc() override {} + virtual ~Mfcc() override {} - // input - input speech signal (T) - // Returns - MFCC features (Col Major : FEAT X FRAMESZ) - std::vector apply(const std::vector& input) override; + // input - input speech signal (T) + // Returns - MFCC features (Col Major : FEAT X FRAMESZ) + std::vector apply(const std::vector& input) override; - int outputSize(int inputSz) override; + int outputSize(int inputSz) override; - private: - // The following classes are defined in the order they are applied - Dct dct_; - Ceplifter ceplifter_; - Derivatives derivatives_; + private: + // The following classes are defined in the order they are applied + Dct dct_; + Ceplifter ceplifter_; + Derivatives derivatives_; - void 
validateMfccParams() const; -}; -} // namespace audio + void validateMfccParams() const; + }; + } // namespace audio } // namespace lib } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/Mfsc.cpp b/flashlight/pkg/speech/audio/feature/Mfsc.cpp index 9864f5e..7669db1 100644 --- a/flashlight/pkg/speech/audio/feature/Mfsc.cpp +++ b/flashlight/pkg/speech/audio/feature/Mfsc.cpp @@ -15,92 +15,111 @@ namespace fl::lib::audio { -Mfsc::Mfsc(const FeatureParams& params) - : PowerSpectrum(params), - triFltBank_( - params.numFilterbankChans, - params.filterFreqResponseLen(), - params.samplingFreq, - params.lowFreqFilterbank, - params.highFreqFilterbank, - FrequencyScale::MEL), - derivatives_(params.deltaWindow, params.accWindow) { - validateMfscParams(); +Mfsc::Mfsc(const FeatureParams& params) : PowerSpectrum(params), + triFltBank_( + params.numFilterbankChans, + params.filterFreqResponseLen(), + params.samplingFreq, + params.lowFreqFilterbank, + params.highFreqFilterbank, + FrequencyScale::MEL), + derivatives_(params.deltaWindow, params.accWindow) { + validateMfscParams(); } std::vector Mfsc::apply(const std::vector& input) { - auto frames = frameSignal(input, this->featParams_); - if (frames.empty()) { - return {}; - } + auto frames = frameSignal(input, this->featParams_); + if(frames.empty()) { + return {}; + } - int nSamples = this->featParams_.numFrameSizeSamples(); - int nFrames = frames.size() / nSamples; + int nSamples = this->featParams_.numFrameSizeSamples(); + int nFrames = frames.size() / nSamples; - std::vector energy(nFrames); - if (this->featParams_.useEnergy && this->featParams_.rawEnergy) { - for (size_t f = 0; f < nFrames; ++f) { - auto begin = frames.data() + f * nSamples; - energy[f] = std::log(std::max( - std::inner_product( - begin, begin + nSamples, begin, static_cast(0.0)), - std::numeric_limits::lowest())); - } - } - auto mfscFeat = mfscImpl(frames); - auto numFeat = this->featParams_.numFilterbankChans; - if 
(this->featParams_.useEnergy) { - if (!this->featParams_.rawEnergy) { - for (size_t f = 0; f < nFrames; ++f) { - auto begin = frames.data() + f * nSamples; - energy[f] = std::log(std::max( - std::inner_product( - begin, begin + nSamples, begin, static_cast(0.0)), - std::numeric_limits::lowest())); - } + std::vector energy(nFrames); + if(this->featParams_.useEnergy && this->featParams_.rawEnergy) { + for(size_t f = 0; f < nFrames; ++f) { + auto begin = frames.data() + f * nSamples; + energy[f] = std::log( + std::max( + std::inner_product( + begin, + begin + nSamples, + begin, + static_cast(0.0) + ), + std::numeric_limits::lowest() + ) + ); + } } - std::vector newMfscFeat(mfscFeat.size() + nFrames); - for (size_t f = 0; f < nFrames; ++f) { - size_t start = f * numFeat; - newMfscFeat[start + f] = energy[f]; - std::copy( - mfscFeat.data() + start, - mfscFeat.data() + start + numFeat, - newMfscFeat.data() + start + f + 1); + auto mfscFeat = mfscImpl(frames); + auto numFeat = this->featParams_.numFilterbankChans; + if(this->featParams_.useEnergy) { + if(!this->featParams_.rawEnergy) { + for(size_t f = 0; f < nFrames; ++f) { + auto begin = frames.data() + f * nSamples; + energy[f] = std::log( + std::max( + std::inner_product( + begin, + begin + nSamples, + begin, + static_cast(0.0) + ), + std::numeric_limits::lowest() + ) + ); + } + } + std::vector newMfscFeat(mfscFeat.size() + nFrames); + for(size_t f = 0; f < nFrames; ++f) { + size_t start = f * numFeat; + newMfscFeat[start + f] = energy[f]; + std::copy( + mfscFeat.data() + start, + mfscFeat.data() + start + numFeat, + newMfscFeat.data() + start + f + 1 + ); + } + std::swap(mfscFeat, newMfscFeat); + ++numFeat; } - std::swap(mfscFeat, newMfscFeat); - ++numFeat; - } - // Derivatives will not be computed if windowsize < 0 - return derivatives_.apply(mfscFeat, numFeat); + // Derivatives will not be computed if windowsize < 0 + return derivatives_.apply(mfscFeat, numFeat); } std::vector Mfsc::mfscImpl(std::vector& frames) { 
- auto powspectrum = this->powSpectrumImpl(frames); - if (this->featParams_.usePower) { + auto powspectrum = this->powSpectrumImpl(frames); + if(this->featParams_.usePower) { + std::transform( + powspectrum.begin(), + powspectrum.end(), + powspectrum.begin(), + [](float x) { return x * x; }); + } + auto triflt = triFltBank_.apply(powspectrum, this->featParams_.melFloor); std::transform( - powspectrum.begin(), - powspectrum.end(), - powspectrum.begin(), - [](float x) { return x * x; }); - } - auto triflt = triFltBank_.apply(powspectrum, this->featParams_.melFloor); - std::transform(triflt.begin(), triflt.end(), triflt.begin(), [](float x) { - return std::log(x); - }); - return triflt; + triflt.begin(), + triflt.end(), + triflt.begin(), + [](float x) { + return std::log(x); + } + ); + return triflt; } int Mfsc::outputSize(int inputSz) { - return this->featParams_.mfscFeatSz() * this->featParams_.numFrames(inputSz); + return this->featParams_.mfscFeatSz() * this->featParams_.numFrames(inputSz); } void Mfsc::validateMfscParams() const { - this->validatePowSpecParams(); - if (this->featParams_.numFilterbankChans <= 0) { - throw std::invalid_argument("Mfsc: numFilterbankChans must be positive"); - } else if (this->featParams_.melFloor <= 0.0) { - throw std::invalid_argument("Mfsc: melfloor must be positive"); - } + this->validatePowSpecParams(); + if(this->featParams_.numFilterbankChans <= 0) { + throw std::invalid_argument("Mfsc: numFilterbankChans must be positive"); + } else if(this->featParams_.melFloor <= 0.0) { + throw std::invalid_argument("Mfsc: melfloor must be positive"); + } } } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/Mfsc.h b/flashlight/pkg/speech/audio/feature/Mfsc.h index 14765c4..aafdcd9 100644 --- a/flashlight/pkg/speech/audio/feature/Mfsc.h +++ b/flashlight/pkg/speech/audio/feature/Mfsc.h @@ -15,32 +15,32 @@ namespace fl { namespace lib { -namespace audio { + namespace audio { // Computes MFSC features for a speech signal. 
-class Mfsc : public PowerSpectrum { - public: - explicit Mfsc(const FeatureParams& params); + class Mfsc : public PowerSpectrum { + public: + explicit Mfsc(const FeatureParams& params); - virtual ~Mfsc() override {} + virtual ~Mfsc() override {} - // input - input speech signal (T) - // Returns - MFSC feature (Col Major : FEAT X FRAMESZ) - std::vector apply(const std::vector& input) override; + // input - input speech signal (T) + // Returns - MFSC feature (Col Major : FEAT X FRAMESZ) + std::vector apply(const std::vector& input) override; - int outputSize(int inputSz) override; + int outputSize(int inputSz) override; - protected: - // Helper function which takes input as signal after dividing the signal into - // frames. Main purpose of this function is to reuse it in MFCC code - std::vector mfscImpl(std::vector& frames); - void validateMfscParams() const; + protected: + // Helper function which takes input as signal after dividing the signal into + // frames. Main purpose of this function is to reuse it in MFCC code + std::vector mfscImpl(std::vector& frames); + void validateMfscParams() const; - private: - TriFilterbank triFltBank_; - Derivatives derivatives_; -}; -} // namespace audio + private: + TriFilterbank triFltBank_; + Derivatives derivatives_; + }; + } // namespace audio } // namespace lib } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/PowerSpectrum.cpp b/flashlight/pkg/speech/audio/feature/PowerSpectrum.cpp index 19cfb71..38c626b 100644 --- a/flashlight/pkg/speech/audio/feature/PowerSpectrum.cpp +++ b/flashlight/pkg/speech/audio/feature/PowerSpectrum.cpp @@ -18,129 +18,147 @@ namespace fl::lib::audio { std::mutex PowerSpectrum::fftPlanMutex_; -PowerSpectrum::PowerSpectrum(const FeatureParams& params) - : featParams_(params), - dither_(params.ditherVal), - preEmphasis_(params.preemCoef, params.numFrameSizeSamples()), - windowing_(params.numFrameSizeSamples(), params.windowType) { - // Need to lock plan creation, which only happens 
once per instance - // https://www.fftw.org/fftw3_doc/Thread-safety.html -- multiple threads can - // use the same plans with fftw_execute - std::lock_guard lock(fftPlanMutex_); - - validatePowSpecParams(); - auto nFFt = featParams_.nFft(); - inFftBuf_.resize(nFFt, 0.0); - outFftBuf_.resize(2 * nFFt); - fftPlan_ = std::make_unique(std::move(fftw_plan_dft_r2c_1d( - nFFt, inFftBuf_.data(), (fftw_complex*)outFftBuf_.data(), FFTW_MEASURE))); +PowerSpectrum::PowerSpectrum(const FeatureParams& params) : featParams_(params), + dither_(params.ditherVal), + preEmphasis_(params.preemCoef, + params.numFrameSizeSamples()), + windowing_(params.numFrameSizeSamples(), + params.windowType) { + // Need to lock plan creation, which only happens once per instance + // https://www.fftw.org/fftw3_doc/Thread-safety.html -- multiple threads can + // use the same plans with fftw_execute + std::lock_guard lock(fftPlanMutex_); + + validatePowSpecParams(); + auto nFFt = featParams_.nFft(); + inFftBuf_.resize(nFFt, 0.0); + outFftBuf_.resize(2 * nFFt); + fftPlan_ = std::make_unique( + std::move( + fftw_plan_dft_r2c_1d( + nFFt, + inFftBuf_.data(), + (fftw_complex*) outFftBuf_.data(), + FFTW_MEASURE + ) + ) + ); } std::vector PowerSpectrum::apply(const std::vector& input) { - auto frames = frameSignal(input, featParams_); - if (frames.empty()) { - return {}; - } - return powSpectrumImpl(frames); + auto frames = frameSignal(input, featParams_); + if(frames.empty()) { + return {}; + } + return powSpectrumImpl(frames); } std::vector PowerSpectrum::powSpectrumImpl(std::vector& frames) { - int nSamples = featParams_.numFrameSizeSamples(); - int nFrames = frames.size() / nSamples; - int nFft = featParams_.nFft(); - int K = featParams_.filterFreqResponseLen(); - - if (featParams_.ditherVal != 0.0) { - frames = dither_.apply(frames); - } - if (featParams_.zeroMeanFrame) { - for (size_t f = 0; f < nFrames; ++f) { - auto begin = frames.data() + f * nSamples; - float mean = std::accumulate(begin, begin + 
nSamples, 0.0); - mean /= nSamples; - std::transform( - begin, begin + nSamples, begin, [mean](float x) { return x - mean; }); + int nSamples = featParams_.numFrameSizeSamples(); + int nFrames = frames.size() / nSamples; + int nFft = featParams_.nFft(); + int K = featParams_.filterFreqResponseLen(); + + if(featParams_.ditherVal != 0.0) { + frames = dither_.apply(frames); + } + if(featParams_.zeroMeanFrame) { + for(size_t f = 0; f < nFrames; ++f) { + auto begin = frames.data() + f * nSamples; + float mean = std::accumulate(begin, begin + nSamples, 0.0); + mean /= nSamples; + std::transform( + begin, + begin + nSamples, + begin, + [mean](float x) { return x - mean; }); + } } - } - if (featParams_.preemCoef != 0) { - preEmphasis_.applyInPlace(frames); - } - windowing_.applyInPlace(frames); - std::vector dft(K * nFrames); - for (size_t f = 0; f < nFrames; ++f) { - auto begin = frames.data() + f * nSamples; - { - std::lock_guard lock(fftMutex_); - std::copy(begin, begin + nSamples, inFftBuf_.data()); - std::fill(outFftBuf_.begin(), outFftBuf_.end(), 0.0); - fftw_execute(*fftPlan_); - - // Copy stuff to the redundant part - for (size_t i = K; i < nFft; ++i) { - outFftBuf_[2 * i] = outFftBuf_[2 * nFft - 2 * i]; - outFftBuf_[2 * i + 1] = -outFftBuf_[2 * nFft - 2 * i + 1]; - } - - for (size_t i = 0; i < K; ++i) { - dft[f * K + i] = std::sqrt( - outFftBuf_[2 * i] * outFftBuf_[2 * i] + - outFftBuf_[2 * i + 1] * outFftBuf_[2 * i + 1]); - } + if(featParams_.preemCoef != 0) { + preEmphasis_.applyInPlace(frames); } - } - return dft; + windowing_.applyInPlace(frames); + std::vector dft(K * nFrames); + for(size_t f = 0; f < nFrames; ++f) { + auto begin = frames.data() + f * nSamples; + { + std::lock_guard lock(fftMutex_); + std::copy(begin, begin + nSamples, inFftBuf_.data()); + std::fill(outFftBuf_.begin(), outFftBuf_.end(), 0.0); + fftw_execute(*fftPlan_); + + // Copy stuff to the redundant part + for(size_t i = K; i < nFft; ++i) { + outFftBuf_[2 * i] = outFftBuf_[2 * nFft - 2 * 
i]; + outFftBuf_[2 * i + 1] = -outFftBuf_[2 * nFft - 2 * i + 1]; + } + + for(size_t i = 0; i < K; ++i) { + dft[f * K + i] = std::sqrt( + outFftBuf_[2 * i] * outFftBuf_[2 * i] + + outFftBuf_[2 * i + 1] * outFftBuf_[2 * i + 1] + ); + } + } + } + return dft; } std::vector PowerSpectrum::batchApply( const std::vector& input, - int batchSz) { - if (batchSz <= 0) { - throw std::invalid_argument("PowerSpectrum: negative batchSz"); - } else if (input.size() % batchSz != 0) { - throw std::invalid_argument( - "PowerSpectrum: input size is not divisible by batchSz"); - } - int N = input.size() / batchSz; - int outputSz = outputSize(N); - std::vector feat(outputSz * batchSz); + int batchSz +) { + if(batchSz <= 0) { + throw std::invalid_argument("PowerSpectrum: negative batchSz"); + } else if(input.size() % batchSz != 0) { + throw std::invalid_argument( + "PowerSpectrum: input size is not divisible by batchSz" + ); + } + int N = input.size() / batchSz; + int outputSz = outputSize(N); + std::vector feat(outputSz * batchSz); #pragma omp parallel for num_threads(batchSz) - for (int b = 0; b < batchSz; ++b) { - auto start = input.begin() + b * N; - std::vector inputBuf(start, start + N); - auto curFeat = apply(inputBuf); - if (outputSz != curFeat.size()) { - throw std::logic_error("PowerSpectrum: apply() returned wrong size"); + for(int b = 0; b < batchSz; ++b) { + auto start = input.begin() + b * N; + std::vector inputBuf(start, start + N); + auto curFeat = apply(inputBuf); + if(outputSz != curFeat.size()) { + throw std::logic_error("PowerSpectrum: apply() returned wrong size"); + } + std::copy( + curFeat.begin(), + curFeat.end(), + feat.begin() + b * curFeat.size() + ); } - std::copy( - curFeat.begin(), curFeat.end(), feat.begin() + b * curFeat.size()); - } - return feat; + return feat; } FeatureParams PowerSpectrum::getFeatureParams() const { - return featParams_; + return featParams_; } int PowerSpectrum::outputSize(int inputSz) { - return featParams_.powSpecFeatSz() * 
featParams_.numFrames(inputSz); + return featParams_.powSpecFeatSz() * featParams_.numFrames(inputSz); } void PowerSpectrum::validatePowSpecParams() const { - if (featParams_.samplingFreq <= 0) { - throw std::invalid_argument("PowerSpectrum: samplingFreq is negative"); - } else if (featParams_.frameSizeMs <= 0) { - throw std::invalid_argument("PowerSpectrum: frameSizeMs is negative"); - } else if (featParams_.frameStrideMs <= 0) { - throw std::invalid_argument("PowerSpectrum: frameStrideMs is negative"); - } else if (featParams_.numFrameSizeSamples() <= 0) { - throw std::invalid_argument("PowerSpectrum: frameSizeMs is too low"); - } else if (featParams_.numFrameStrideSamples() <= 0) { - throw std::invalid_argument("PowerSpectrum: frameStrideMs is too low"); - } + if(featParams_.samplingFreq <= 0) { + throw std::invalid_argument("PowerSpectrum: samplingFreq is negative"); + } else if(featParams_.frameSizeMs <= 0) { + throw std::invalid_argument("PowerSpectrum: frameSizeMs is negative"); + } else if(featParams_.frameStrideMs <= 0) { + throw std::invalid_argument("PowerSpectrum: frameStrideMs is negative"); + } else if(featParams_.numFrameSizeSamples() <= 0) { + throw std::invalid_argument("PowerSpectrum: frameSizeMs is too low"); + } else if(featParams_.numFrameStrideSamples() <= 0) { + throw std::invalid_argument("PowerSpectrum: frameStrideMs is too low"); + } } PowerSpectrum::~PowerSpectrum() { - fftw_destroy_plan(*fftPlan_); + fftw_destroy_plan(*fftPlan_); } } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/PowerSpectrum.h b/flashlight/pkg/speech/audio/feature/PowerSpectrum.h index aa1cda1..93b17ba 100644 --- a/flashlight/pkg/speech/audio/feature/PowerSpectrum.h +++ b/flashlight/pkg/speech/audio/feature/PowerSpectrum.h @@ -17,52 +17,52 @@ // Fwd decl class fftw_plan_s; -using fftw_plan = fftw_plan_s *; +using fftw_plan = fftw_plan_s*; namespace fl { namespace lib { -namespace audio { + namespace audio { // Computes Power Spectrum features for a 
speech signal. -class PowerSpectrum { - public: - explicit PowerSpectrum(const FeatureParams& params); + class PowerSpectrum { + public: + explicit PowerSpectrum(const FeatureParams& params); - virtual ~PowerSpectrum(); + virtual ~PowerSpectrum(); - // input - input speech signal (T) - // Returns - Power spectrum (Col Major : FEAT X FRAMESZ) - virtual std::vector apply(const std::vector& input); + // input - input speech signal (T) + // Returns - Power spectrum (Col Major : FEAT X FRAMESZ) + virtual std::vector apply(const std::vector& input); - // input - input speech signal (Col Major : T X BATCHSZ) - // Returns - Output features (Col Major : FEAT X FRAMESZ X BATCHSZ) - std::vector batchApply(const std::vector& input, int batchSz); + // input - input speech signal (Col Major : T X BATCHSZ) + // Returns - Output features (Col Major : FEAT X FRAMESZ X BATCHSZ) + std::vector batchApply(const std::vector& input, int batchSz); - virtual int outputSize(int inputSz); + virtual int outputSize(int inputSz); - FeatureParams getFeatureParams() const; + FeatureParams getFeatureParams() const; - protected: - FeatureParams featParams_; + protected: + FeatureParams featParams_; - // Helper function which takes input as signal after dividing the signal into - // frames. Main purpose of this function is to reuse it in MFSC, MFCC code - std::vector powSpectrumImpl(std::vector& frames); + // Helper function which takes input as signal after dividing the signal into + // frames. 
Main purpose of this function is to reuse it in MFSC, MFCC code + std::vector powSpectrumImpl(std::vector& frames); - void validatePowSpecParams() const; + void validatePowSpecParams() const; - private: - // The following classes are defined in the order they are applied - Dither dither_; - PreEmphasis preEmphasis_; - Windowing windowing_; + private: + // The following classes are defined in the order they are applied + Dither dither_; + PreEmphasis preEmphasis_; + Windowing windowing_; - std::unique_ptr fftPlan_; // fftw_plan is an opque pointer type - std::vector inFftBuf_, outFftBuf_; - std::mutex fftMutex_; - static std::mutex fftPlanMutex_; -}; -} // namespace audio + std::unique_ptr fftPlan_; // fftw_plan is an opque pointer type + std::vector inFftBuf_, outFftBuf_; + std::mutex fftMutex_; + static std::mutex fftPlanMutex_; + }; + } // namespace audio } // namespace lib } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/PreEmphasis.cpp b/flashlight/pkg/speech/audio/feature/PreEmphasis.cpp index 0ae91bf..8b55836 100644 --- a/flashlight/pkg/speech/audio/feature/PreEmphasis.cpp +++ b/flashlight/pkg/speech/audio/feature/PreEmphasis.cpp @@ -12,35 +12,36 @@ namespace fl::lib::audio { -PreEmphasis::PreEmphasis(float alpha, int N) - : preemCoef_(alpha), windowLength_(N) { - if (windowLength_ <= 1) { - throw std::invalid_argument("PreEmphasis: windowLength must be > 1"); - } - if (preemCoef_ < 0.0 || preemCoef_ >= 1.0) { - throw std::invalid_argument("PreEmphasis: alpha must be in [0, 1)"); - } +PreEmphasis::PreEmphasis(float alpha, int N) : preemCoef_(alpha), + windowLength_(N) { + if(windowLength_ <= 1) { + throw std::invalid_argument("PreEmphasis: windowLength must be > 1"); + } + if(preemCoef_ < 0.0 || preemCoef_ >= 1.0) { + throw std::invalid_argument("PreEmphasis: alpha must be in [0, 1)"); + } }; std::vector PreEmphasis::apply(const std::vector& input) const { - auto output(input); - applyInPlace(output); - return output; + auto output(input); + 
applyInPlace(output); + return output; } void PreEmphasis::applyInPlace(std::vector& input) const { - if (input.size() % windowLength_ != 0) { - throw std::invalid_argument( - "PreEmphasis: input.size() not divisible by windowLength"); - } - size_t nframes = input.size() / windowLength_; - for (size_t n = nframes; n > 0; --n) { - size_t e = n * windowLength_ - 1; // end of current frame - size_t s = (n - 1) * windowLength_; // start of current frame - for (size_t i = e; i > s; --i) { - input[i] -= (preemCoef_ * input[i - 1]); + if(input.size() % windowLength_ != 0) { + throw std::invalid_argument( + "PreEmphasis: input.size() not divisible by windowLength" + ); + } + size_t nframes = input.size() / windowLength_; + for(size_t n = nframes; n > 0; --n) { + size_t e = n * windowLength_ - 1; // end of current frame + size_t s = (n - 1) * windowLength_; // start of current frame + for(size_t i = e; i > s; --i) { + input[i] -= (preemCoef_ * input[i - 1]); + } + input[s] *= (1 - preemCoef_); } - input[s] *= (1 - preemCoef_); - } } } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/PreEmphasis.h b/flashlight/pkg/speech/audio/feature/PreEmphasis.h index a0312c2..3809c63 100644 --- a/flashlight/pkg/speech/audio/feature/PreEmphasis.h +++ b/flashlight/pkg/speech/audio/feature/PreEmphasis.h @@ -12,23 +12,23 @@ namespace fl { namespace lib { -namespace audio { + namespace audio { // Pre-emphasise the signal by applying the first order difference equation -// s'(n) = s(n) - k * s(n-1) where k in [0, 1) +// s'(n) = s(n) - k * s(n-1) where k in [0, 1) -class PreEmphasis { - public: - PreEmphasis(float alpha, int N); + class PreEmphasis { + public: + PreEmphasis(float alpha, int N); - std::vector apply(const std::vector& input) const; + std::vector apply(const std::vector& input) const; - void applyInPlace(std::vector& input) const; + void applyInPlace(std::vector& input) const; - private: - float preemCoef_; - int windowLength_; -}; -} // namespace audio + private: 
+ float preemCoef_; + int windowLength_; + }; + } // namespace audio } // namespace lib } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/SpeechUtils.cpp b/flashlight/pkg/speech/audio/feature/SpeechUtils.cpp index 3b774cb..ff5c40a 100644 --- a/flashlight/pkg/speech/audio/feature/SpeechUtils.cpp +++ b/flashlight/pkg/speech/audio/feature/SpeechUtils.cpp @@ -23,65 +23,70 @@ namespace fl::lib::audio { std::vector frameSignal( const std::vector& input, - const FeatureParams& params) { - auto frameSize = params.numFrameSizeSamples(); - auto frameStride = params.numFrameStrideSamples(); - int numframes = params.numFrames(input.size()); - // HTK: Values coming out of rasta treat samples as integers, - // not range -1..1, hence scale up here to match (approx) - float scale = 32768.0; - std::vector frames(numframes * frameSize); - for (size_t f = 0; f < numframes; ++f) { - for (size_t i = 0; i < frameSize; ++i) { - frames[f * frameSize + i] = scale * input[f * frameStride + i]; + const FeatureParams& params +) { + auto frameSize = params.numFrameSizeSamples(); + auto frameStride = params.numFrameStrideSamples(); + int numframes = params.numFrames(input.size()); + // HTK: Values coming out of rasta treat samples as integers, + // not range -1..1, hence scale up here to match (approx) + float scale = 32768.0; + std::vector frames(numframes * frameSize); + for(size_t f = 0; f < numframes; ++f) { + for(size_t i = 0; i < frameSize; ++i) { + frames[f * frameSize + i] = scale * input[f * frameStride + i]; + } } - } - return frames; + return frames; } std::vector cblasGemm( const std::vector& matA, const std::vector& matB, int n, - int k) { - if (n <= 0 || k <= 0 || matA.empty() || (matA.size() % k != 0) || - (matB.size() != n * k)) { - throw std::invalid_argument("cblasGemm: invalid arguments"); - } + int k +) { + if( + n <= 0 || k <= 0 || matA.empty() || (matA.size() % k != 0) + || (matB.size() != n * k) + ) { + throw std::invalid_argument("cblasGemm: invalid 
arguments"); + } - int m = matA.size() / k; + int m = matA.size() / k; - std::vector matC(m * n); + std::vector matC(m * n); #if FL_USE_MKL - auto prevMaxThreads = mkl_get_max_threads(); - mkl_set_num_threads_local(1); + auto prevMaxThreads = mkl_get_max_threads(); + mkl_set_num_threads_local(1); #else // TODO: to be tested #endif - cblas_sgemm( - CblasRowMajor, - CblasNoTrans, - CblasNoTrans, - m, - n, - k, - 1.0, // alpha - matA.data(), - k, - matB.data(), - n, - 0.0, // beta - matC.data(), - n); + cblas_sgemm( + CblasRowMajor, + CblasNoTrans, + CblasNoTrans, + m, + n, + k, + 1.0, // alpha + matA.data(), + k, + matB.data(), + n, + 0.0, // beta + matC.data(), + n + ); #if FL_USE_MKL - mkl_set_num_threads_local(prevMaxThreads); + mkl_set_num_threads_local(prevMaxThreads); #else // TODO: to be tested #endif - return matC; + return matC; }; } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/SpeechUtils.h b/flashlight/pkg/speech/audio/feature/SpeechUtils.h index fc9ac48..c6bffda 100644 --- a/flashlight/pkg/speech/audio/feature/SpeechUtils.h +++ b/flashlight/pkg/speech/audio/feature/SpeechUtils.h @@ -13,21 +13,23 @@ namespace fl { namespace lib { -namespace audio { + namespace audio { // Convert the speech signal into frames -std::vector frameSignal( - const std::vector& input, - const FeatureParams& params); + std::vector frameSignal( + const std::vector& input, + const FeatureParams& params + ); // row major; matA - m x k , matB - k x n -std::vector cblasGemm( - const std::vector& matA, - const std::vector& matB, - int n, - int k); -} // namespace audio + std::vector cblasGemm( + const std::vector& matA, + const std::vector& matB, + int n, + int k + ); + } // namespace audio } // namespace lib } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/TriFilterbank.cpp b/flashlight/pkg/speech/audio/feature/TriFilterbank.cpp index e95de11..ba58d4d 100644 --- a/flashlight/pkg/speech/audio/feature/TriFilterbank.cpp +++ 
b/flashlight/pkg/speech/audio/feature/TriFilterbank.cpp @@ -22,76 +22,77 @@ TriFilterbank::TriFilterbank( int samplingfreq, int lowfreq /* = 0 */, int highfreq /* = -1 */, - FrequencyScale freqscale /* = FrequencyScale::MEL */) - : numFilters_(numfilters), - filterLen_(filterlen), - samplingFreq_(samplingfreq), - lowFreq_(lowfreq), - highFreq_((highfreq > 0) ? highfreq : (samplingfreq >> 1)), - freqScale_(freqscale), - H_(filterlen * numfilters) { - float minwarpfreq = hertzToWarpedScale(lowFreq_, freqScale_); - float maxwarpfreq = hertzToWarpedScale(highFreq_, freqScale_); - float dwarp = (maxwarpfreq - minwarpfreq) / (numfilters + 1); + FrequencyScale freqscale /* = FrequencyScale::MEL */ +) : numFilters_(numfilters), + filterLen_(filterlen), + samplingFreq_(samplingfreq), + lowFreq_(lowfreq), + highFreq_((highfreq > 0) ? highfreq : (samplingfreq >> 1)), + freqScale_(freqscale), + H_(filterlen * numfilters) { + float minwarpfreq = hertzToWarpedScale(lowFreq_, freqScale_); + float maxwarpfreq = hertzToWarpedScale(highFreq_, freqScale_); + float dwarp = (maxwarpfreq - minwarpfreq) / (numfilters + 1); - std::vector f(numFilters_ + 2); - for (int i = 0; i < (numFilters_ + 2); ++i) { - f[i] = warpedToHertzScale(i * dwarp + minwarpfreq, freqScale_) * - (filterLen_ - 1) * 2.0 / samplingFreq_; - } + std::vector f(numFilters_ + 2); + for(int i = 0; i < (numFilters_ + 2); ++i) { + f[i] = warpedToHertzScale(i * dwarp + minwarpfreq, freqScale_) + * (filterLen_ - 1) * 2.0 / samplingFreq_; + } - float minH = 0.0; + float minH = 0.0; - for (size_t i = 0; i < filterLen_; ++i) { - for (size_t j = 0; j < numFilters_; ++j) { - float hislope = (i - f[j]) / (f[j + 1] - f[j]); - float loslope = (f[j + 2] - i) / (f[j + 2] - f[j + 1]); - H_[i * numFilters_ + j] = std::max(std::min(hislope, loslope), minH); + for(size_t i = 0; i < filterLen_; ++i) { + for(size_t j = 0; j < numFilters_; ++j) { + float hislope = (i - f[j]) / (f[j + 1] - f[j]); + float loslope = (f[j + 2] - i) / (f[j + 2] - 
f[j + 1]); + H_[i * numFilters_ + j] = std::max(std::min(hislope, loslope), minH); + } } - } } std::vector TriFilterbank::apply( const std::vector& input, - float melfloor /* = 0.0 */) const { - std::vector output = cblasGemm(input, H_, numFilters_, filterLen_); - std::transform( - output.begin(), - output.end(), - output.begin(), - [melfloor](float n) -> float { return std::max(n, melfloor); }); - return output; + float melfloor /* = 0.0 */ +) const { + std::vector output = cblasGemm(input, H_, numFilters_, filterLen_); + std::transform( + output.begin(), + output.end(), + output.begin(), + [melfloor](float n) -> float { return std::max(n, melfloor); }); + return output; } std::vector TriFilterbank::filterbank() const { - return H_; + return H_; } float TriFilterbank::hertzToWarpedScale(float hz, FrequencyScale freqscale) - const { - switch (freqscale) { - case FrequencyScale::MEL: - return 2595.0 * log10(1.0 + hz / 700.0); - case FrequencyScale::LOG10: - return log10(hz); - case FrequencyScale::LINEAR: - return hz; - default: - throw std::invalid_argument("TriFilterbank: unsupported frequency scale"); - } +const { + switch(freqscale) { + case FrequencyScale::MEL: + return 2595.0 * log10(1.0 + hz / 700.0); + case FrequencyScale::LOG10: + return log10(hz); + case FrequencyScale::LINEAR: + return hz; + default: + throw std::invalid_argument("TriFilterbank: unsupported frequency scale"); + } } float TriFilterbank::warpedToHertzScale(float wrp, FrequencyScale freqscale) - const { - switch (freqscale) { - case FrequencyScale::MEL: - return 700.0 * (pow(10, wrp / 2595.0) - 1); - case FrequencyScale::LOG10: - return pow(10, wrp); - case FrequencyScale::LINEAR: - return wrp; - default: - throw std::invalid_argument("TriFilterbank: unsupported frequency scale"); - } +const { + switch(freqscale) { + case FrequencyScale::MEL: + return 700.0 * (pow(10, wrp / 2595.0) - 1); + case FrequencyScale::LOG10: + return pow(10, wrp); + case FrequencyScale::LINEAR: + return wrp; + 
default: + throw std::invalid_argument("TriFilterbank: unsupported frequency scale"); + } } } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/TriFilterbank.h b/flashlight/pkg/speech/audio/feature/TriFilterbank.h index 28b0f33..2e7df1f 100644 --- a/flashlight/pkg/speech/audio/feature/TriFilterbank.h +++ b/flashlight/pkg/speech/audio/feature/TriFilterbank.h @@ -14,38 +14,40 @@ namespace fl { namespace lib { -namespace audio { - -class TriFilterbank { - public: - TriFilterbank( - int numfilters, - int filterlen, - int samplingfreq, - int lowfreq = 0, - int highfreq = -1, - FrequencyScale freqscale = FrequencyScale::MEL); - - std::vector apply( - const std::vector& input, - float melfloor = 0.0) const; - - // Returns triangular filterbank matrix - std::vector filterbank() const; - - private: - int numFilters_; // Number of filterbank channels - int filterLen_; // length of each filterbank channel - int samplingFreq_; // sampling frequency (Hz) - int lowFreq_; // lower cutoff frequency (Hz) - int highFreq_; // higher cutoff frequency (Hz) - FrequencyScale freqScale_; // frequency warp type Ex. 
FrequencyScale::MEL - std::vector - H_; // (numFilters_ x filterLen_) triangular filterbank matrix - - float hertzToWarpedScale(float hz, FrequencyScale freqscale) const; - float warpedToHertzScale(float wrp, FrequencyScale freqscale) const; -}; -} // namespace audio + namespace audio { + + class TriFilterbank { + public: + TriFilterbank( + int numfilters, + int filterlen, + int samplingfreq, + int lowfreq = 0, + int highfreq = -1, + FrequencyScale freqscale = FrequencyScale::MEL + ); + + std::vector apply( + const std::vector& input, + float melfloor = 0.0 + ) const; + + // Returns triangular filterbank matrix + std::vector filterbank() const; + + private: + int numFilters_; // Number of filterbank channels + int filterLen_; // length of each filterbank channel + int samplingFreq_; // sampling frequency (Hz) + int lowFreq_; // lower cutoff frequency (Hz) + int highFreq_; // higher cutoff frequency (Hz) + FrequencyScale freqScale_; // frequency warp type Ex. FrequencyScale::MEL + std::vector + H_; // (numFilters_ x filterLen_) triangular filterbank matrix + + float hertzToWarpedScale(float hz, FrequencyScale freqscale) const; + float warpedToHertzScale(float wrp, FrequencyScale freqscale) const; + }; + } // namespace audio } // namespace lib } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/Windowing.cpp b/flashlight/pkg/speech/audio/feature/Windowing.cpp index 71e56b1..aa9d10d 100644 --- a/flashlight/pkg/speech/audio/feature/Windowing.cpp +++ b/flashlight/pkg/speech/audio/feature/Windowing.cpp @@ -14,45 +14,47 @@ namespace fl::lib::audio { -Windowing::Windowing(int N, WindowType windowtype) - : windowLength_(N), windowType_(windowtype), coefs_(N) { - if (windowLength_ <= 1) { - throw std::invalid_argument("Windowing: windowLength must be > 1"); - } - std::iota(coefs_.begin(), coefs_.end(), 0.0); - switch (windowtype) { - case WindowType::HAMMING: - for (auto& c : coefs_) { - c = 0.54 - 0.46 * std::cos(2 * M_PI * c / (N - 1)); - } - break; - case 
WindowType::HANNING: - for (auto& c : coefs_) { - c = 0.5 * (1.0 - std::cos(2 * M_PI * c / (N - 1))); - } - break; - default: - throw std::invalid_argument("Windowing: unsupported window type"); - } +Windowing::Windowing(int N, WindowType windowtype) : windowLength_(N), + windowType_(windowtype), + coefs_(N) { + if(windowLength_ <= 1) { + throw std::invalid_argument("Windowing: windowLength must be > 1"); + } + std::iota(coefs_.begin(), coefs_.end(), 0.0); + switch(windowtype) { + case WindowType::HAMMING: + for(auto& c : coefs_) { + c = 0.54 - 0.46 * std::cos(2 * M_PI * c / (N - 1)); + } + break; + case WindowType::HANNING: + for(auto& c : coefs_) { + c = 0.5 * (1.0 - std::cos(2 * M_PI * c / (N - 1))); + } + break; + default: + throw std::invalid_argument("Windowing: unsupported window type"); + } } std::vector Windowing::apply(const std::vector& input) const { - auto output(input); - applyInPlace(output); - return output; + auto output(input); + applyInPlace(output); + return output; } void Windowing::applyInPlace(std::vector& input) const { - if (input.size() % windowLength_ != 0) { - throw std::invalid_argument( - "Windowing: input size is not divisible by windowLength"); - } - size_t n = 0; - for (auto& in : input) { - in *= coefs_[n++]; - if (n == windowLength_) { - n = 0; + if(input.size() % windowLength_ != 0) { + throw std::invalid_argument( + "Windowing: input size is not divisible by windowLength" + ); + } + size_t n = 0; + for(auto& in : input) { + in *= coefs_[n++]; + if(n == windowLength_) { + n = 0; + } } - } } } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/Windowing.h b/flashlight/pkg/speech/audio/feature/Windowing.h index 350222e..4884c7b 100644 --- a/flashlight/pkg/speech/audio/feature/Windowing.h +++ b/flashlight/pkg/speech/audio/feature/Windowing.h @@ -14,24 +14,24 @@ namespace fl { namespace lib { -namespace audio { + namespace audio { // Applies a given window on input -// s'(n) = w(n) * s(n) where w(n) are the window 
coefficients +// s'(n) = w(n) * s(n) where w(n) are the window coefficients -class Windowing { - public: - Windowing(int N, WindowType window); + class Windowing { + public: + Windowing(int N, WindowType window); - std::vector apply(const std::vector& input) const; + std::vector apply(const std::vector& input) const; - void applyInPlace(std::vector& input) const; + void applyInPlace(std::vector& input) const; - private: - int windowLength_; - WindowType windowType_; - std::vector coefs_; -}; -} // namespace audio + private: + int windowLength_; + WindowType windowType_; + std::vector coefs_; + }; + } // namespace audio } // namespace lib } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/AdditiveNoise.cpp b/flashlight/pkg/speech/augmentation/AdditiveNoise.cpp index f823dd8..31f1624 100644 --- a/flashlight/pkg/speech/augmentation/AdditiveNoise.cpp +++ b/flashlight/pkg/speech/augmentation/AdditiveNoise.cpp @@ -19,79 +19,82 @@ namespace fl::pkg::speech::sfx { std::string AdditiveNoise::Config::prettyString() const { - std::stringstream ss; - ss << "AdditiveNoise::Config{ratio_=" << ratio_ << " minSnr_=" << minSnr_ - << " maxSnr_=" << maxSnr_ << " nClipsMin_=" << nClipsMin_ << " nClipsMax_" - << nClipsMax_ << " listFilePath_=" << listFilePath_ << '}'; - return ss.str(); + std::stringstream ss; + ss << "AdditiveNoise::Config{ratio_=" << ratio_ << " minSnr_=" << minSnr_ + << " maxSnr_=" << maxSnr_ << " nClipsMin_=" << nClipsMin_ << " nClipsMax_" + << nClipsMax_ << " listFilePath_=" << listFilePath_ << '}'; + return ss.str(); } std::string AdditiveNoise::prettyString() const { - std::stringstream ss; - ss << "AdditiveNoise{config={" << conf_.prettyString() << '}'; - return ss.str(); + std::stringstream ss; + ss << "AdditiveNoise{config={" << conf_.prettyString() << '}'; + return ss.str(); }; AdditiveNoise::AdditiveNoise( const AdditiveNoise::Config& config, - unsigned int seed /* = 0 */) - : conf_(config), rng_(seed) { - std::ifstream 
listFile(conf_.listFilePath_); - if (!listFile) { - throw std::runtime_error( - "AdditiveNoise failed to open listFilePath_=" + conf_.listFilePath_); - } - while (!listFile.eof()) { - try { - std::string filename; - std::getline(listFile, filename); - if (!filename.empty()) { - noiseFiles_.push_back(filename); - } - } catch (std::exception& ex) { - throw std::runtime_error( - "AdditiveNoise failed to read listFilePath_=" + conf_.listFilePath_ + - " with error=" + ex.what()); + unsigned int seed /* = 0 */ +) : conf_(config), + rng_(seed) { + std::ifstream listFile(conf_.listFilePath_); + if(!listFile) { + throw std::runtime_error( + "AdditiveNoise failed to open listFilePath_=" + conf_.listFilePath_ + ); + } + while(!listFile.eof()) { + try { + std::string filename; + std::getline(listFile, filename); + if(!filename.empty()) { + noiseFiles_.push_back(filename); + } + } catch(std::exception& ex) { + throw std::runtime_error( + "AdditiveNoise failed to read listFilePath_=" + conf_.listFilePath_ + + " with error=" + ex.what() + ); + } } - } } void AdditiveNoise::apply(std::vector& signal) { - if (rng_.random() >= conf_.proba_) { - return; - } - const float signalRms = rootMeanSquare(signal); - const float snr = rng_.uniform(conf_.minSnr_, conf_.maxSnr_); - const int nClips = rng_.randInt(conf_.nClipsMin_, conf_.nClipsMax_); - if (nClips == 0) { - return; - } - int augStart = rng_.randInt(0, signal.size() - 1); - // overflow implies we start at the beginning again. 
- int augEnd = augStart + conf_.ratio_ * signal.size(); - - std::vector mixedNoise(signal.size(), 0.0f); - for (int i = 0; i < nClips; ++i) { - auto curNoiseFileIdx = rng_.randInt(0, noiseFiles_.size() - 1); - auto curNoise = loadSound(noiseFiles_[curNoiseFileIdx]); - int shift = rng_.randInt(0, curNoise.size() - 1); - for (int j = augStart; j < augEnd; ++j) { - mixedNoise[j % mixedNoise.size()] += - curNoise[(shift + j) % curNoise.size()]; + if(rng_.random() >= conf_.proba_) { + return; + } + const float signalRms = rootMeanSquare(signal); + const float snr = rng_.uniform(conf_.minSnr_, conf_.maxSnr_); + const int nClips = rng_.randInt(conf_.nClipsMin_, conf_.nClipsMax_); + if(nClips == 0) { + return; } - } + int augStart = rng_.randInt(0, signal.size() - 1); + // overflow implies we start at the beginning again. + int augEnd = augStart + conf_.ratio_ * signal.size(); - const float noiseRms = rootMeanSquare(mixedNoise); - if (noiseRms > 0) { - // https://en.wikipedia.org/wiki/Signal-to-noise_ratio - const float noiseMult = (signalRms / (noiseRms * std::pow(10, snr / 20.0))); - for (int i = 0; i < signal.size(); ++i) { - signal[i] += mixedNoise[i] * noiseMult; + std::vector mixedNoise(signal.size(), 0.0f); + for(int i = 0; i < nClips; ++i) { + auto curNoiseFileIdx = rng_.randInt(0, noiseFiles_.size() - 1); + auto curNoise = loadSound(noiseFiles_[curNoiseFileIdx]); + int shift = rng_.randInt(0, curNoise.size() - 1); + for(int j = augStart; j < augEnd; ++j) { + mixedNoise[j % mixedNoise.size()] += + curNoise[(shift + j) % curNoise.size()]; + } } - } else { - FL_LOG(fl::LogLevel::WARNING) + + const float noiseRms = rootMeanSquare(mixedNoise); + if(noiseRms > 0) { + // https://en.wikipedia.org/wiki/Signal-to-noise_ratio + const float noiseMult = (signalRms / (noiseRms * std::pow(10, snr / 20.0))); + for(int i = 0; i < signal.size(); ++i) { + signal[i] += mixedNoise[i] * noiseMult; + } + } else { + FL_LOG(fl::LogLevel::WARNING) << "AdditiveNoise::apply() invalid 
noiseRms=" << noiseRms; - } + } } } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/AdditiveNoise.h b/flashlight/pkg/speech/augmentation/AdditiveNoise.h index 585e19a..f0641e3 100644 --- a/flashlight/pkg/speech/augmentation/AdditiveNoise.h +++ b/flashlight/pkg/speech/augmentation/AdditiveNoise.h @@ -17,8 +17,8 @@ namespace fl { namespace pkg { -namespace speech { -namespace sfx { + namespace speech { + namespace sfx { /** * The additive noise sound effect loads noise files and augments them to the * signal with hyper parameters that are chosen randomly within a configured @@ -36,36 +36,37 @@ namespace sfx { * augmented interval. rms(noise) is calculated on the sum of all noise clipse * over the augmented interval. */ -class AdditiveNoise : public SoundEffect { - public: - struct Config { - /** - * probability of aapplying reverb. - */ - float proba_ = 1.0; - double ratio_ = 1.0; - double minSnr_ = 0; - double maxSnr_ = 30; - int nClipsMin_ = 1; - int nClipsMax_ = 3; - std::string listFilePath_; - std::string prettyString() const; - }; + class AdditiveNoise : public SoundEffect { + public: + struct Config { + /** + * probability of aapplying reverb. 
+ */ + float proba_ = 1.0; + double ratio_ = 1.0; + double minSnr_ = 0; + double maxSnr_ = 30; + int nClipsMin_ = 1; + int nClipsMax_ = 3; + std::string listFilePath_; + std::string prettyString() const; + }; - explicit AdditiveNoise( - const AdditiveNoise::Config& config, - unsigned int seed = 0); - ~AdditiveNoise() override = default; - void apply(std::vector& signal) override; - std::string prettyString() const override; + explicit AdditiveNoise( + const AdditiveNoise::Config& config, + unsigned int seed = 0 + ); + ~AdditiveNoise() override = default; + void apply(std::vector& signal) override; + std::string prettyString() const override; - private: - const AdditiveNoise::Config conf_; - std::vector noiseFiles_; - RandomNumberGenerator rng_; -}; + private: + const AdditiveNoise::Config conf_; + std::vector noiseFiles_; + RandomNumberGenerator rng_; + }; -} // namespace sfx -} // namespace speech + } // namespace sfx + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/GaussianNoise.cpp b/flashlight/pkg/speech/augmentation/GaussianNoise.cpp index dcf45db..b276115 100644 --- a/flashlight/pkg/speech/augmentation/GaussianNoise.cpp +++ b/flashlight/pkg/speech/augmentation/GaussianNoise.cpp @@ -15,34 +15,35 @@ namespace fl::pkg::speech::sfx { std::string GaussianNoise::Config::prettyString() const { - std::stringstream ss; - ss << "GaussianNoise::Config{minSnr_=" << minSnr_ << " maxSnr_=" << maxSnr_ - << '}'; - return ss.str(); + std::stringstream ss; + ss << "GaussianNoise::Config{minSnr_=" << minSnr_ << " maxSnr_=" << maxSnr_ + << '}'; + return ss.str(); } std::string GaussianNoise::prettyString() const { - std::stringstream ss; - ss << "GaussianNoise{config={" << conf_.prettyString() << "}}"; - return ss.str(); + std::stringstream ss; + ss << "GaussianNoise{config={" << conf_.prettyString() << "}}"; + return ss.str(); }; GaussianNoise::GaussianNoise( const GaussianNoise::Config& config, - unsigned int seed /* 
= 0 */) - : conf_(config), rng_(seed) {} + unsigned int seed /* = 0 */ +) : conf_(config), + rng_(seed) {} void GaussianNoise::apply(std::vector& signal) { - if (rng_.random() >= conf_.proba_) { - return; - } - const float signalRms = rootMeanSquare(signal); - const float snr = rng_.uniform(conf_.minSnr_, conf_.maxSnr_); - const float noiseMult = signalRms / std::pow(10, snr / 20.0); - - for (int i = 0; i < signal.size(); ++i) { - signal[i] += rng_.gaussian(0, noiseMult); - } + if(rng_.random() >= conf_.proba_) { + return; + } + const float signalRms = rootMeanSquare(signal); + const float snr = rng_.uniform(conf_.minSnr_, conf_.maxSnr_); + const float noiseMult = signalRms / std::pow(10, snr / 20.0); + + for(int i = 0; i < signal.size(); ++i) { + signal[i] += rng_.gaussian(0, noiseMult); + } } } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/GaussianNoise.h b/flashlight/pkg/speech/augmentation/GaussianNoise.h index acbe227..8e1024a 100644 --- a/flashlight/pkg/speech/augmentation/GaussianNoise.h +++ b/flashlight/pkg/speech/augmentation/GaussianNoise.h @@ -17,34 +17,35 @@ namespace fl { namespace pkg { -namespace speech { -namespace sfx { + namespace speech { + namespace sfx { /** * Add gaussian noise to the samples with given Signal to Noise Ratio (SNR) */ -class GaussianNoise : public SoundEffect { - public: - struct Config { - float proba_ = 1.0; - double minSnr_ = 0; - double maxSnr_ = 30; - std::string prettyString() const; - }; - - explicit GaussianNoise( - const GaussianNoise::Config& config, - unsigned int seed = 0); - ~GaussianNoise() override = default; - void apply(std::vector& signal) override; - std::string prettyString() const override; - - private: - const GaussianNoise::Config conf_; - RandomNumberGenerator rng_; -}; - -} // namespace sfx -} // namespace speech + class GaussianNoise : public SoundEffect { + public: + struct Config { + float proba_ = 1.0; + double minSnr_ = 0; + double maxSnr_ = 30; + std::string prettyString() const; 
+ }; + + explicit GaussianNoise( + const GaussianNoise::Config& config, + unsigned int seed = 0 + ); + ~GaussianNoise() override = default; + void apply(std::vector& signal) override; + std::string prettyString() const override; + + private: + const GaussianNoise::Config conf_; + RandomNumberGenerator rng_; + }; + + } // namespace sfx + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/Reverberation.cpp b/flashlight/pkg/speech/augmentation/Reverberation.cpp index b2808a6..7e2e03b 100644 --- a/flashlight/pkg/speech/augmentation/Reverberation.cpp +++ b/flashlight/pkg/speech/augmentation/Reverberation.cpp @@ -15,71 +15,77 @@ namespace fl::pkg::speech::sfx { ReverbEcho::ReverbEcho( const ReverbEcho::Config& conf, - unsigned int seed /* = 0 */) - : conf_(conf), rng_(seed) {} + unsigned int seed /* = 0 */ +) : conf_(conf), + rng_(seed) {} void ReverbEcho::applyReverb( std::vector& source, float initial, float firstDelay, - float rt60) { - size_t length = source.size(); - std::vector reverb(length, 0); - for (int i = 0; i < conf_.repeat_; ++i) { - float frac = 1; - // echo = initial * source - std::vector echo = source; - std::transform( - echo.begin(), echo.end(), echo.begin(), [initial](float x) -> float { - return x * initial; - }); - while (frac > 1e-3) { - // Add jitter noise for the delay - float jitter = 1 + rng_.uniform(-conf_.jitter_, conf_.jitter_); - size_t delay = 1 + int(jitter * firstDelay * conf_.sampleRate_); - if (delay > length - 1) { - break; - } - for (int j = 0; j < length - delay - 1; ++j) { - reverb[delay + j] += echo[j] * frac; - } + float rt60 +) { + size_t length = source.size(); + std::vector reverb(length, 0); + for(int i = 0; i < conf_.repeat_; ++i) { + float frac = 1; + // echo = initial * source + std::vector echo = source; + std::transform( + echo.begin(), + echo.end(), + echo.begin(), + [initial](float x) -> float { + return x * initial; + } + ); + while(frac > 1e-3) { + // Add 
jitter noise for the delay + float jitter = 1 + rng_.uniform(-conf_.jitter_, conf_.jitter_); + size_t delay = 1 + int(jitter * firstDelay * conf_.sampleRate_); + if(delay > length - 1) { + break; + } + for(int j = 0; j < length - delay - 1; ++j) { + reverb[delay + j] += echo[j] * frac; + } - // Add jitter noise for the attenuation - jitter = 1 + rng_.uniform(-conf_.jitter_, conf_.jitter_); - const float attenuation = std::pow(10, -3 * jitter * firstDelay / rt60); + // Add jitter noise for the attenuation + jitter = 1 + rng_.uniform(-conf_.jitter_, conf_.jitter_); + const float attenuation = std::pow(10, -3 * jitter * firstDelay / rt60); - frac *= attenuation; + frac *= attenuation; + } + } + for(int i = 0; i < length; ++i) { + source[i] += reverb[i]; } - } - for (int i = 0; i < length; ++i) { - source[i] += reverb[i]; - } } void ReverbEcho::apply(std::vector& sound) { - if (rng_.random() >= conf_.proba_) { - return; - } - // Sample characteristics for the reverb - float initial = rng_.uniform(conf_.initialMin_, conf_.initialMax_); - float firstDelay = rng_.uniform(conf_.firstDelayMin_, conf_.firstDelayMax_); - float rt60 = rng_.uniform(conf_.rt60Min_, conf_.rt60Max_); + if(rng_.random() >= conf_.proba_) { + return; + } + // Sample characteristics for the reverb + float initial = rng_.uniform(conf_.initialMin_, conf_.initialMax_); + float firstDelay = rng_.uniform(conf_.firstDelayMin_, conf_.firstDelayMax_); + float rt60 = rng_.uniform(conf_.rt60Min_, conf_.rt60Max_); - applyReverb(sound, initial, firstDelay, rt60); + applyReverb(sound, initial, firstDelay, rt60); } std::string ReverbEcho::prettyString() const { - return "ReverbEcho{conf_=" + conf_.prettyString() + "}}"; + return "ReverbEcho{conf_=" + conf_.prettyString() + "}}"; } std::string ReverbEcho::Config::prettyString() const { - std::stringstream ss; - ss << " proba_=" << proba_ << " initialMin_=" << initialMin_ - << " initialMax_=" << initialMax_ << " rt60Min_=" << rt60Min_ - << " rt60Max_=" << rt60Max_ << 
" firstDelayMin_=" << firstDelayMin_ - << " firstDelayMax_=" << firstDelayMax_ << " repeat_=" << repeat_ - << " jitter_=" << jitter_ << " sampleRate_=" << sampleRate_; - return ss.str(); + std::stringstream ss; + ss << " proba_=" << proba_ << " initialMin_=" << initialMin_ + << " initialMax_=" << initialMax_ << " rt60Min_=" << rt60Min_ + << " rt60Max_=" << rt60Max_ << " firstDelayMin_=" << firstDelayMin_ + << " firstDelayMax_=" << firstDelayMax_ << " repeat_=" << repeat_ + << " jitter_=" << jitter_ << " sampleRate_=" << sampleRate_; + return ss.str(); } } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/Reverberation.h b/flashlight/pkg/speech/augmentation/Reverberation.h index 40ccc17..81f6b94 100644 --- a/flashlight/pkg/speech/augmentation/Reverberation.h +++ b/flashlight/pkg/speech/augmentation/Reverberation.h @@ -17,8 +17,8 @@ namespace fl { namespace pkg { -namespace speech { -namespace sfx { + namespace speech { + namespace sfx { /** * Applies reverberation of generated RIR, crudely calculated based on random: @@ -26,76 +26,77 @@ namespace sfx { * This a c++ port of: * https://github.com/facebookresearch/denoiser/blob/master/denoiser/augment.py */ -class ReverbEcho : public SoundEffect { - public: - struct Config { - /** - * probability of aapplying reverb. - */ - float proba_ = 1.0; - /** - * amplitude of the first echo as a fraction of the input signal. For each - * sample, actually sampled from`[0, initial]`. Larger values means louder - * reverb. Physically, this would depend on the absorption of the room - walls. - */ - float initialMin_ = 0; - float initialMax_ = 0.3; - /** - * range of values to sample the RT60 in seconds, i.e. after RT60 - * seconds, the echo amplitude is 1e-3 of the first echo. The - * default values follow the recommendations of - * https://arxiv.org/ftp/arxiv/papers/2001/2001.08662.pdf, - * Section 2.4. 
Physically this would also be related to the - * absorption of the room walls and there is likely a relation - * between `RT60` and`initial`, which we ignore here. - */ - float rt60Min_ = 0.3; - float rt60Max_ = 1.3; - /** - * range of values to sample the first echo delay in seconds. The default - * values are equivalent to sampling a room of 3 to 10 meters. - */ - float firstDelayMin_ = 0.01; - float firstDelayMax_ = 0.03; - /** - * how many train of echos with differents jitters to add.Higher values - * means a denser reverb. - */ - size_t repeat_ = 3; - /** - * jitter used to make each repetition of the reverb echo train slightly - * different.For instance a jitter of 0.1 means the delay between two echos - * will be in the range `firstDelay + -10 %`, with the jittering noise - * being resampled after each single echo.- - */ - float jitter_ = 0.1; - /** - * fraction of the reverb of the clean speech to add back to the ground - * truth .0 = dereverberation, 1 = no dereverberation. - */ - size_t sampleRate_ = 16000; - std::string prettyString() const; - }; + class ReverbEcho : public SoundEffect { + public: + struct Config { + /** + * probability of aapplying reverb. + */ + float proba_ = 1.0; + /** + * amplitude of the first echo as a fraction of the input signal. For each + * sample, actually sampled from`[0, initial]`. Larger values means louder + * reverb. Physically, this would depend on the absorption of the room + walls. + */ + float initialMin_ = 0; + float initialMax_ = 0.3; + /** + * range of values to sample the RT60 in seconds, i.e. after RT60 + * seconds, the echo amplitude is 1e-3 of the first echo. The + * default values follow the recommendations of + * https://arxiv.org/ftp/arxiv/papers/2001/2001.08662.pdf, + * Section 2.4. Physically this would also be related to the + * absorption of the room walls and there is likely a relation + * between `RT60` and`initial`, which we ignore here. 
+ */ + float rt60Min_ = 0.3; + float rt60Max_ = 1.3; + /** + * range of values to sample the first echo delay in seconds. The default + * values are equivalent to sampling a room of 3 to 10 meters. + */ + float firstDelayMin_ = 0.01; + float firstDelayMax_ = 0.03; + /** + * how many train of echos with differents jitters to add.Higher values + * means a denser reverb. + */ + size_t repeat_ = 3; + /** + * jitter used to make each repetition of the reverb echo train slightly + * different.For instance a jitter of 0.1 means the delay between two echos + * will be in the range `firstDelay + -10 %`, with the jittering noise + * being resampled after each single echo.- + */ + float jitter_ = 0.1; + /** + * fraction of the reverb of the clean speech to add back to the ground + * truth .0 = dereverberation, 1 = no dereverberation. + */ + size_t sampleRate_ = 16000; + std::string prettyString() const; + }; - explicit ReverbEcho(const ReverbEcho::Config& config, unsigned int seed = 0); - ~ReverbEcho() override = default; - void apply(std::vector& sound) override; - std::string prettyString() const override; + explicit ReverbEcho(const ReverbEcho::Config& config, unsigned int seed = 0); + ~ReverbEcho() override = default; + void apply(std::vector& sound) override; + std::string prettyString() const override; - private: - // augments source with reverberation noise - void applyReverb( - std::vector& source, - float initial, - float firstDelay, - float rt60); + private: + // augments source with reverberation noise + void applyReverb( + std::vector& source, + float initial, + float firstDelay, + float rt60 + ); - const ReverbEcho::Config conf_; - RandomNumberGenerator rng_; -}; + const ReverbEcho::Config conf_; + RandomNumberGenerator rng_; + }; -} // namespace sfx -} // namespace speech + } // namespace sfx + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/SoundEffect.cpp 
b/flashlight/pkg/speech/augmentation/SoundEffect.cpp index a89071c..2137ff8 100644 --- a/flashlight/pkg/speech/augmentation/SoundEffect.cpp +++ b/flashlight/pkg/speech/augmentation/SoundEffect.cpp @@ -14,80 +14,87 @@ namespace fl::pkg::speech::sfx { std::string SoundEffectChain::prettyString() const { - std::stringstream ss; - ss << '{' << std::endl; - for (const std::shared_ptr& sfx : soundEffects_) { - ss << "{" << sfx->prettyString() << '}' << std::endl; - } - ss << '}'; - return ss.str(); + std::stringstream ss; + ss << '{' << std::endl; + for(const std::shared_ptr& sfx : soundEffects_) { + ss << "{" << sfx->prettyString() << '}' << std::endl; + } + ss << '}'; + return ss.str(); } void SoundEffectChain::add(std::shared_ptr SoundEffect) { - soundEffects_.push_back(SoundEffect); + soundEffects_.push_back(SoundEffect); } void SoundEffectChain::apply(std::vector& sound) { - for (std::shared_ptr& effect : soundEffects_) { - effect->apply(sound); - } + for(std::shared_ptr& effect : soundEffects_) { + effect->apply(sound); + } } bool SoundEffectChain::empty() { - return soundEffects_.empty(); + return soundEffects_.empty(); } Normalize::Normalize(bool onlyIfTooHigh) : onlyIfTooHigh_(onlyIfTooHigh) {} void Normalize::apply(std::vector& sound) { - float maxAbs = 0.0f; - for (float i : sound) { - maxAbs = std::fmax(maxAbs, std::fabs(i)); - } - if (!onlyIfTooHigh_ || maxAbs > 1.0f) { - std::transform( - sound.begin(), - sound.end(), - sound.begin(), - [maxAbs](float amp) -> float { return amp / maxAbs; }); - } + float maxAbs = 0.0f; + for(float i : sound) { + maxAbs = std::fmax(maxAbs, std::fabs(i)); + } + if(!onlyIfTooHigh_ || maxAbs > 1.0f) { + std::transform( + sound.begin(), + sound.end(), + sound.begin(), + [maxAbs](float amp) -> float { return amp / maxAbs; }); + } } std::string Normalize::prettyString() const { - std::stringstream ss; - ss << "Normalize={onlyIfTooHigh=" << onlyIfTooHigh_ << "}"; - return ss.str(); + std::stringstream ss; + ss << 
"Normalize={onlyIfTooHigh=" << onlyIfTooHigh_ << "}"; + return ss.str(); } std::string ClampAmplitude::prettyString() const { - return "ClampAmplitude"; + return "ClampAmplitude"; } void ClampAmplitude::apply(std::vector& sound) { - std::transform( - sound.begin(), sound.end(), sound.begin(), [](float amp) -> float { - return std::fmax(std::fmin(amp, 1.0), -1.0); - }); + std::transform( + sound.begin(), + sound.end(), + sound.begin(), + [](float amp) -> float { + return std::fmax(std::fmin(amp, 1.0), -1.0); + } + ); } -Amplify::Amplify(const Amplify::Config& config) - : randomEngine_(config.randomSeed_), - randomRatio_(config.ratioMin_, config.ratioMax_) {} +Amplify::Amplify(const Amplify::Config& config) : randomEngine_(config.randomSeed_), + randomRatio_(config.ratioMin_, config.ratioMax_) {} std::string Amplify::prettyString() const { - return "Amplify"; + return "Amplify"; } void Amplify::apply(std::vector& sound) { - float ratio = 0; - { - std::lock_guard guard(mutex_); - ratio = randomRatio_(randomEngine_); - } - std::transform( - sound.begin(), sound.end(), sound.begin(), [ratio](float amp) -> float { - return amp * ratio; - }); + float ratio = 0; + { + std::lock_guard guard(mutex_); + ratio = randomRatio_(randomEngine_); + } + std::transform( + sound.begin(), + sound.end(), + sound.begin(), + [ratio](float amp) -> float { + return amp * ratio; + } + ); } } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/SoundEffect.h b/flashlight/pkg/speech/augmentation/SoundEffect.h index 39f5f97..7a0e80b 100644 --- a/flashlight/pkg/speech/augmentation/SoundEffect.h +++ b/flashlight/pkg/speech/augmentation/SoundEffect.h @@ -16,89 +16,89 @@ namespace fl { namespace pkg { -namespace speech { -namespace sfx { + namespace speech { + namespace sfx { /** * Base class for sound effects. 
*/ -class SoundEffect { - public: - SoundEffect() = default; - virtual ~SoundEffect() = default; - virtual void apply(std::vector& sound) = 0; - virtual std::string prettyString() const = 0; -}; + class SoundEffect { + public: + SoundEffect() = default; + virtual ~SoundEffect() = default; + virtual void apply(std::vector& sound) = 0; + virtual std::string prettyString() const = 0; + }; /** * A container for chaining sound effect. It serially applies calls to all added * sound effects. */ -class SoundEffectChain : public SoundEffect { - public: - SoundEffectChain() {} - ~SoundEffectChain() override = default; - void apply(std::vector& sound) override; - std::string prettyString() const override; - void add(std::shared_ptr SoundEffect); - bool empty(); + class SoundEffectChain : public SoundEffect { + public: + SoundEffectChain() {} + ~SoundEffectChain() override = default; + void apply(std::vector& sound) override; + std::string prettyString() const override; + void add(std::shared_ptr SoundEffect); + bool empty(); - protected: - std::vector> soundEffects_; -}; + protected: + std::vector> soundEffects_; + }; /** * Normalize amplitudes to range -1..1 using dynamic range linear compression. * No-op if the signal's amplitudes are already within that range. */ -class Normalize : public SoundEffect { - public: - explicit Normalize(bool onlyIfTooHigh = true); - ~Normalize() override = default; - void apply(std::vector& sound) override; - std::string prettyString() const override; + class Normalize : public SoundEffect { + public: + explicit Normalize(bool onlyIfTooHigh = true); + ~Normalize() override = default; + void apply(std::vector& sound) override; + std::string prettyString() const override; - private: - bool onlyIfTooHigh_; -}; + private: + bool onlyIfTooHigh_; + }; /** * Clamps amplitudes to range -1..1. * No-op if the signal's amplitudes are already within that range. 
*/ -class ClampAmplitude : public SoundEffect { - public: - explicit ClampAmplitude() {} - ~ClampAmplitude() override = default; - void apply(std::vector& sound) override; + class ClampAmplitude : public SoundEffect { + public: + explicit ClampAmplitude() {} + ~ClampAmplitude() override = default; + void apply(std::vector& sound) override; - std::string prettyString() const override; -}; + std::string prettyString() const override; + }; /** * Amplifies (or decreases amplitude of) the signal with a random ratio in the * specified range. */ -class Amplify : public SoundEffect { - public: - struct Config { - float ratioMin_; - float ratioMax_; - unsigned int randomSeed_; - }; + class Amplify : public SoundEffect { + public: + struct Config { + float ratioMin_; + float ratioMax_; + unsigned int randomSeed_; + }; - explicit Amplify(const Amplify::Config& config); - ~Amplify() override = default; - void apply(std::vector& sound) override; - std::string prettyString() const override; + explicit Amplify(const Amplify::Config& config); + ~Amplify() override = default; + void apply(std::vector& sound) override; + std::string prettyString() const override; - private: - std::mt19937 randomEngine_; - std::uniform_real_distribution<> randomRatio_; - std::mutex mutex_; -}; + private: + std::mt19937 randomEngine_; + std::uniform_real_distribution<> randomRatio_; + std::mutex mutex_; + }; -} // namespace sfx -} // namespace speech + } // namespace sfx + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/SoundEffectApply.cpp b/flashlight/pkg/speech/augmentation/SoundEffectApply.cpp index 48f3fc6..533f4c0 100644 --- a/flashlight/pkg/speech/augmentation/SoundEffectApply.cpp +++ b/flashlight/pkg/speech/augmentation/SoundEffectApply.cpp @@ -16,7 +16,8 @@ DEFINE_string(input, "", "Sound file to augment."); DEFINE_string( output, "augmented.flac", - "Path to store result of augmenting the input file"); + "Path to store result of 
augmenting the input file" +); DEFINE_string(config, "", "Path to a sound effect json config file"); using namespace ::fl::pkg::speech::sfx; @@ -25,44 +26,46 @@ using ::fl::pkg::speech::loadSoundInfo; using ::fl::pkg::speech::saveSound; int main(int argc, char** argv) { - google::InitGoogleLogging(argv[0]); - google::InstallFailureSignalHandler(); - std::string exec(argv[0]); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + std::string exec(argv[0]); - gflags::SetUsageMessage( - "Usage: \n " + exec + - " --input=[path to input file] --output=[path to output file] " + - "--config=[path to config file]"); + gflags::SetUsageMessage( + "Usage: \n " + exec + + " --input=[path to input file] --output=[path to output file] " + + "--config=[path to config file]" + ); - if (argc <= 1) { - LOG(FATAL) << gflags::ProgramUsage(); - } + if(argc <= 1) { + LOG(FATAL) << gflags::ProgramUsage(); + } - gflags::ParseCommandLineFlags(&argc, &argv, false); + gflags::ParseCommandLineFlags(&argc, &argv, false); - if (FLAGS_config.empty()) { - LOG(FATAL) << "flag --config must point to sound effect config file"; - } - if (FLAGS_input.empty()) { - LOG(FATAL) << "flag --input must point to input file"; - } + if(FLAGS_config.empty()) { + LOG(FATAL) << "flag --config must point to sound effect config file"; + } + if(FLAGS_input.empty()) { + LOG(FATAL) << "flag --input must point to input file"; + } - auto sound = loadSound(FLAGS_input); - auto info = loadSoundInfo(FLAGS_input); + auto sound = loadSound(FLAGS_input); + auto info = loadSoundInfo(FLAGS_input); - std::shared_ptr sfx = - createSoundEffect(readSoundEffectConfigFile(FLAGS_config)); - sfx->apply(sound); + std::shared_ptr sfx = + createSoundEffect(readSoundEffectConfigFile(FLAGS_config)); + sfx->apply(sound); - saveSound( - FLAGS_output, - sound, - info.samplerate, - info.channels, - fl::pkg::speech::SoundFormat::FLAC, - fl::pkg::speech::SoundSubFormat::PCM_16); + saveSound( + FLAGS_output, + sound, + 
info.samplerate, + info.channels, + fl::pkg::speech::SoundFormat::FLAC, + fl::pkg::speech::SoundSubFormat::PCM_16 + ); - LOG(INFO) << "Saving augmented file to=" << FLAGS_output; + LOG(INFO) << "Saving augmented file to=" << FLAGS_output; - return 0; + return 0; } diff --git a/flashlight/pkg/speech/augmentation/SoundEffectConfig.cpp b/flashlight/pkg/speech/augmentation/SoundEffectConfig.cpp index bd702bb..1fbc3f9 100644 --- a/flashlight/pkg/speech/augmentation/SoundEffectConfig.cpp +++ b/flashlight/pkg/speech/augmentation/SoundEffectConfig.cpp @@ -21,60 +21,72 @@ using namespace ::fl::pkg::speech::sfx; namespace cereal { -template +template void serialize(Archive& ar, Amplify::Config& conf) { - ar(cereal::make_nvp("ratioMin", conf.ratioMin_), - cereal::make_nvp("ratioMax", conf.ratioMax_)); + ar( + cereal::make_nvp("ratioMin", conf.ratioMin_), + cereal::make_nvp("ratioMax", conf.ratioMax_) + ); } -template +template void serialize(Archive& ar, AdditiveNoise::Config& conf) { - ar(cereal::make_nvp("proba", conf.proba_), - cereal::make_nvp("ratio", conf.ratio_), - cereal::make_nvp("minSnr", conf.minSnr_), - cereal::make_nvp("maxSnr", conf.maxSnr_), - cereal::make_nvp("nClipsMin", conf.nClipsMin_), - cereal::make_nvp("nClipsMax", conf.nClipsMax_), - cereal::make_nvp("listFilePath", conf.listFilePath_)); + ar( + cereal::make_nvp("proba", conf.proba_), + cereal::make_nvp("ratio", conf.ratio_), + cereal::make_nvp("minSnr", conf.minSnr_), + cereal::make_nvp("maxSnr", conf.maxSnr_), + cereal::make_nvp("nClipsMin", conf.nClipsMin_), + cereal::make_nvp("nClipsMax", conf.nClipsMax_), + cereal::make_nvp("listFilePath", conf.listFilePath_) + ); } -template +template void serialize(Archive& ar, ReverbEcho::Config& conf) { - ar(cereal::make_nvp("proba", conf.proba_), - cereal::make_nvp("initialMin", conf.initialMin_), - cereal::make_nvp("initialMax", conf.initialMax_), - cereal::make_nvp("rt60Min", conf.rt60Min_), - cereal::make_nvp("rt60Max", conf.rt60Max_), - 
cereal::make_nvp("firstDelayMin", conf.firstDelayMin_), - cereal::make_nvp("firstDelayMax", conf.firstDelayMax_), - cereal::make_nvp("repeat", conf.repeat_), - cereal::make_nvp("jitter", conf.jitter_), - cereal::make_nvp("sampleRate", conf.sampleRate_)); + ar( + cereal::make_nvp("proba", conf.proba_), + cereal::make_nvp("initialMin", conf.initialMin_), + cereal::make_nvp("initialMax", conf.initialMax_), + cereal::make_nvp("rt60Min", conf.rt60Min_), + cereal::make_nvp("rt60Max", conf.rt60Max_), + cereal::make_nvp("firstDelayMin", conf.firstDelayMin_), + cereal::make_nvp("firstDelayMax", conf.firstDelayMax_), + cereal::make_nvp("repeat", conf.repeat_), + cereal::make_nvp("jitter", conf.jitter_), + cereal::make_nvp("sampleRate", conf.sampleRate_) + ); } -template +template void serialize(Archive& ar, TimeStretch::Config& conf) { - ar(cereal::make_nvp("proba", conf.proba_), - cereal::make_nvp("minFactor", conf.minFactor_), - cereal::make_nvp("maxFactor", conf.maxFactor_), - cereal::make_nvp("sampleRate", conf.sampleRate_)); + ar( + cereal::make_nvp("proba", conf.proba_), + cereal::make_nvp("minFactor", conf.minFactor_), + cereal::make_nvp("maxFactor", conf.maxFactor_), + cereal::make_nvp("sampleRate", conf.sampleRate_) + ); } -template +template void serialize(Archive& ar, SoundEffectConfig& conf) { - ar(cereal::make_nvp("type", conf.type_)); - if (conf.type_ == kAdditiveNoise) { - ar(cereal::make_nvp("additiveNoiseConfig", conf.additiveNoiseConfig_)); - } else if (conf.type_ == kAmplify) { - ar(cereal::make_nvp("amplifyConfig", conf.amplifyConfig_)); - } else if (conf.type_ == kNormalize) { - ar(cereal::make_nvp( - "normalizeOnlyIfTooHigh", conf.normalizeOnlyIfTooHigh_)); - } else if (conf.type_ == kReverbEcho) { - ar(cereal::make_nvp("reverbEchoConfig", conf.reverbEchoConfig_)); - } else if (conf.type_ == kTimeStretch) { - ar(cereal::make_nvp("timeStretchConfig", conf.timeStretchConfig_)); - } + ar(cereal::make_nvp("type", conf.type_)); + if(conf.type_ == 
kAdditiveNoise) { + ar(cereal::make_nvp("additiveNoiseConfig", conf.additiveNoiseConfig_)); + } else if(conf.type_ == kAmplify) { + ar(cereal::make_nvp("amplifyConfig", conf.amplifyConfig_)); + } else if(conf.type_ == kNormalize) { + ar( + cereal::make_nvp( + "normalizeOnlyIfTooHigh", + conf.normalizeOnlyIfTooHigh_ + ) + ); + } else if(conf.type_ == kReverbEcho) { + ar(cereal::make_nvp("reverbEchoConfig", conf.reverbEchoConfig_)); + } else if(conf.type_ == kTimeStretch) { + ar(cereal::make_nvp("timeStretchConfig", conf.timeStretchConfig_)); + } } } // namespace cereal @@ -83,61 +95,66 @@ namespace fl::pkg::speech::sfx { void writeSoundEffectConfigFile( const fs::path& filename, - const std::vector& sfxConfigs) { - try { - const fs::path path = filename.parent_path(); - fs::create_directory(path); - std::ofstream file(filename); - cereal::JSONOutputArchive archive(file); - archive(cereal::make_nvp("soundEffectChain", sfxConfigs)); - } catch (std::exception& ex) { - std::stringstream ss; - ss << "writeSoundEffectConfigFile(filename=" << filename - << ") failed with error={" << ex.what() << "}"; - throw std::runtime_error(ss.str()); - } + const std::vector& sfxConfigs +) { + try { + const fs::path path = filename.parent_path(); + fs::create_directory(path); + std::ofstream file(filename); + cereal::JSONOutputArchive archive(file); + archive(cereal::make_nvp("soundEffectChain", sfxConfigs)); + } catch(std::exception& ex) { + std::stringstream ss; + ss << "writeSoundEffectConfigFile(filename=" << filename + << ") failed with error={" << ex.what() << "}"; + throw std::runtime_error(ss.str()); + } } std::vector readSoundEffectConfigFile( - const fs::path& filename) { - try { - std::ifstream file(filename); - cereal::JSONInputArchive archive(file); - std::vector sfxConfigs; - archive(sfxConfigs); - return sfxConfigs; - } catch (std::exception& ex) { - std::stringstream ss; - ss << "readSoundEffectConfigFile(filename=" << filename - << ") failed with error={" << ex.what() 
<< "}"; - throw std::runtime_error(ss.str()); - } + const fs::path& filename +) { + try { + std::ifstream file(filename); + cereal::JSONInputArchive archive(file); + std::vector sfxConfigs; + archive(sfxConfigs); + return sfxConfigs; + } catch(std::exception& ex) { + std::stringstream ss; + ss << "readSoundEffectConfigFile(filename=" << filename + << ") failed with error={" << ex.what() << "}"; + throw std::runtime_error(ss.str()); + } } std::shared_ptr createSoundEffect( const std::vector& sfxConfigs, - unsigned int seed /* = 0 */) { - auto sfxChain = std::make_shared(); - for (const SoundEffectConfig& conf : sfxConfigs) { - if (conf.type_ == kAdditiveNoise) { - sfxChain->add( - std::make_shared(conf.additiveNoiseConfig_, seed)); - } else if (conf.type_ == kAmplify) { - sfxChain->add(std::make_shared(conf.amplifyConfig_)); - } else if (conf.type_ == kClampAmplitude) { - sfxChain->add(std::make_shared()); - } else if (conf.type_ == kNormalize) { - sfxChain->add(std::make_shared(conf.normalizeOnlyIfTooHigh_)); - } else if (conf.type_ == kReverbEcho) { - sfxChain->add(std::make_shared(conf.reverbEchoConfig_, seed)); - } else if (conf.type_ == kTimeStretch) { - sfxChain->add( - std::make_shared(conf.timeStretchConfig_, seed)); - } else { - LOG(FATAL) << "Invalid sound effect config type=" << conf.type_; + unsigned int seed /* = 0 */ +) { + auto sfxChain = std::make_shared(); + for(const SoundEffectConfig& conf : sfxConfigs) { + if(conf.type_ == kAdditiveNoise) { + sfxChain->add( + std::make_shared(conf.additiveNoiseConfig_, seed) + ); + } else if(conf.type_ == kAmplify) { + sfxChain->add(std::make_shared(conf.amplifyConfig_)); + } else if(conf.type_ == kClampAmplitude) { + sfxChain->add(std::make_shared()); + } else if(conf.type_ == kNormalize) { + sfxChain->add(std::make_shared(conf.normalizeOnlyIfTooHigh_)); + } else if(conf.type_ == kReverbEcho) { + sfxChain->add(std::make_shared(conf.reverbEchoConfig_, seed)); + } else if(conf.type_ == kTimeStretch) { + 
sfxChain->add( + std::make_shared(conf.timeStretchConfig_, seed) + ); + } else { + LOG(FATAL) << "Invalid sound effect config type=" << conf.type_; + } } - } - return sfxChain; + return sfxChain; } } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/SoundEffectConfig.h b/flashlight/pkg/speech/augmentation/SoundEffectConfig.h index 5efdb73..219480b 100644 --- a/flashlight/pkg/speech/augmentation/SoundEffectConfig.h +++ b/flashlight/pkg/speech/augmentation/SoundEffectConfig.h @@ -19,43 +19,46 @@ namespace fl { namespace pkg { -namespace speech { -namespace sfx { + namespace speech { + namespace sfx { // Values for SoundEffectConfig.type_ -constexpr const char* const kAdditiveNoise = "AdditiveNoise"; -constexpr const char* const kAmplify = "Amplify"; -constexpr const char* const kClampAmplitude = "ClampAmplitude"; -constexpr const char* const kNormalize = "Normalize"; -constexpr const char* const kReverbEcho = "ReverbEcho"; -constexpr const char* const kTimeStretch = "TimeStretch"; - -struct SoundEffectConfig { - std::string type_; - // The fields below should be treated like a union, that is, only the field - // that corresponds to the type_ field should be used. Union cannot be used - // here since it does not support types like string. 
- bool normalizeOnlyIfTooHigh_ = true; - AdditiveNoise::Config additiveNoiseConfig_; - Amplify::Config amplifyConfig_; - ReverbEcho::Config reverbEchoConfig_; - TimeStretch::Config timeStretchConfig_; -}; - -std::shared_ptr createSoundEffect( - const std::vector& config, - unsigned int seed = 0); + constexpr const char* const kAdditiveNoise = "AdditiveNoise"; + constexpr const char* const kAmplify = "Amplify"; + constexpr const char* const kClampAmplitude = "ClampAmplitude"; + constexpr const char* const kNormalize = "Normalize"; + constexpr const char* const kReverbEcho = "ReverbEcho"; + constexpr const char* const kTimeStretch = "TimeStretch"; + + struct SoundEffectConfig { + std::string type_; + // The fields below should be treated like a union, that is, only the field + // that corresponds to the type_ field should be used. Union cannot be used + // here since it does not support types like string. + bool normalizeOnlyIfTooHigh_ = true; + AdditiveNoise::Config additiveNoiseConfig_; + Amplify::Config amplifyConfig_; + ReverbEcho::Config reverbEchoConfig_; + TimeStretch::Config timeStretchConfig_; + }; + + std::shared_ptr createSoundEffect( + const std::vector& config, + unsigned int seed = 0 + ); // Write configuration vector into json file -void writeSoundEffectConfigFile( - const fs::path& filename, - const std::vector& config); + void writeSoundEffectConfigFile( + const fs::path& filename, + const std::vector& config + ); // Read configuration vector from json file -std::vector readSoundEffectConfigFile( - const fs::path& filename); + std::vector readSoundEffectConfigFile( + const fs::path& filename + ); -} // namespace sfx -} // namespace speech + } // namespace sfx + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/SoundEffectUtil.cpp b/flashlight/pkg/speech/augmentation/SoundEffectUtil.cpp index c2f2f48..3685dd5 100644 --- a/flashlight/pkg/speech/augmentation/SoundEffectUtil.cpp +++ 
b/flashlight/pkg/speech/augmentation/SoundEffectUtil.cpp @@ -12,55 +12,56 @@ namespace fl::pkg::speech::sfx { -RandomNumberGenerator::RandomNumberGenerator(int seed /* = 0 */) - : randomEngine_(seed), uniformDist_(0, 1), gaussianDist_(0, 1) {} +RandomNumberGenerator::RandomNumberGenerator(int seed /* = 0 */) : randomEngine_(seed), + uniformDist_(0, 1), + gaussianDist_(0, 1) {} int RandomNumberGenerator::randInt(int minVal, int maxVal) { - if (minVal > maxVal) { - std::swap(minVal, maxVal); - } - return randomEngine_() % (maxVal - minVal + 1) + minVal; + if(minVal > maxVal) { + std::swap(minVal, maxVal); + } + return randomEngine_() % (maxVal - minVal + 1) + minVal; } float RandomNumberGenerator::random() { - return uniformDist_(randomEngine_); + return uniformDist_(randomEngine_); } float RandomNumberGenerator::uniform(float minVal, float maxVal) { - return minVal + (maxVal - minVal) * uniformDist_(randomEngine_); + return minVal + (maxVal - minVal) * uniformDist_(randomEngine_); } float RandomNumberGenerator::gaussian(float mean, float sigma) { - return mean + gaussianDist_(randomEngine_) * sigma; + return mean + gaussianDist_(randomEngine_) * sigma; } float rootMeanSquare(const std::vector& signal) { - float sumSquares = 0; - for (int i = 0; i < signal.size(); ++i) { - sumSquares += signal[i] * signal[i]; - } - return std::sqrt(sumSquares / signal.size()); + float sumSquares = 0; + for(int i = 0; i < signal.size(); ++i) { + sumSquares += signal[i] * signal[i]; + } + return std::sqrt(sumSquares / signal.size()); } float signalToNoiseRatio( const std::vector& signal, - const std::vector& noise) { - auto singalRms = rootMeanSquare(signal); - auto noiseRms = rootMeanSquare(noise); - return 20 * std::log10(singalRms / noiseRms); + const std::vector& noise +) { + auto singalRms = rootMeanSquare(signal); + auto noiseRms = rootMeanSquare(noise); + return 20 * std::log10(singalRms / noiseRms); } -std::vector -genTestSinWave(size_t numSamples, size_t freq, size_t 
sampleRate, float amplitude) { - std::vector output(numSamples, 0); - const float waveLenSamples = - static_cast(sampleRate) / static_cast(freq); - const float ratio = (2 * M_PI) / waveLenSamples; +std::vector genTestSinWave(size_t numSamples, size_t freq, size_t sampleRate, float amplitude) { + std::vector output(numSamples, 0); + const float waveLenSamples = + static_cast(sampleRate) / static_cast(freq); + const float ratio = (2 * M_PI) / waveLenSamples; - for (size_t i = 0; i < numSamples; ++i) { - output.at(i) = amplitude * std::sin(static_cast(i) * ratio); - } - return output; + for(size_t i = 0; i < numSamples; ++i) { + output.at(i) = amplitude * std::sin(static_cast(i) * ratio); + } + return output; } } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/SoundEffectUtil.h b/flashlight/pkg/speech/augmentation/SoundEffectUtil.h index c982066..bead384 100644 --- a/flashlight/pkg/speech/augmentation/SoundEffectUtil.h +++ b/flashlight/pkg/speech/augmentation/SoundEffectUtil.h @@ -15,45 +15,47 @@ namespace fl { namespace pkg { -namespace speech { -namespace sfx { + namespace speech { + namespace sfx { -class RandomNumberGenerator { - public: - explicit RandomNumberGenerator(int seed = 0); + class RandomNumberGenerator { + public: + explicit RandomNumberGenerator(int seed = 0); - /// Returns a random integer N such that minVal <= N <= maxVal - int randInt(int minVal, int maxVal); + /// Returns a random integer N such that minVal <= N <= maxVal + int randInt(int minVal, int maxVal); - /// Returns a random floating point number in the range [0.0, 1.0). - float random(); + /// Returns a random floating point number in the range [0.0, 1.0). 
+ float random(); - /// Returns a random floating point number N such that minVal <= N <= maxVal - float uniform(float minVal, float mx); + /// Returns a random floating point number N such that minVal <= N <= maxVal + float uniform(float minVal, float mx); - /// Returns a random floating point from a gaussian(normal) distribution - /// where mu is the mean, and sigma is the standard deviation - float gaussian(float mean, float sigma); + /// Returns a random floating point from a gaussian(normal) distribution + /// where mu is the mean, and sigma is the standard deviation + float gaussian(float mean, float sigma); - private: - std::mt19937_64 randomEngine_; - std::uniform_real_distribution uniformDist_; - std::normal_distribution gaussianDist_; -}; + private: + std::mt19937_64 randomEngine_; + std::uniform_real_distribution uniformDist_; + std::normal_distribution gaussianDist_; + }; -float rootMeanSquare(const std::vector& signal); + float rootMeanSquare(const std::vector& signal); -float signalToNoiseRatio( - const std::vector& signal, - const std::vector& noise); + float signalToNoiseRatio( + const std::vector& signal, + const std::vector& noise + ); -std::vector genTestSinWave( - size_t numSamples, - size_t freq, - size_t sampleRate, - float amplitude); + std::vector genTestSinWave( + size_t numSamples, + size_t freq, + size_t sampleRate, + float amplitude + ); -} // namespace sfx -} // namespace speech + } // namespace sfx + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/SoxWrapper.cpp b/flashlight/pkg/speech/augmentation/SoxWrapper.cpp index a3e88a5..67dc3e4 100644 --- a/flashlight/pkg/speech/augmentation/SoxWrapper.cpp +++ b/flashlight/pkg/speech/augmentation/SoxWrapper.cpp @@ -21,197 +21,203 @@ namespace fl::pkg::speech::sfx { namespace { -struct SoxData { - std::vector* data; - size_t index = 0; -}; - -static int outputFlow( - sox_effect_t* effp LSX_UNUSED, - sox_sample_t const* ibuf, - 
sox_sample_t* obuf LSX_UNUSED, - size_t* isamp, - size_t* osamp) { - if (*isamp) { - auto priv = static_cast(effp->priv); - - int i = 0; - for (; i < *isamp; ++i) { - SOX_SAMPLE_LOCALS; - priv->data->push_back(SOX_SAMPLE_TO_FLOAT_32BIT(ibuf[i], effp->clips)); + struct SoxData { + std::vector* data; + size_t index = 0; + }; + + static int outputFlow( + sox_effect_t* effp LSX_UNUSED, + sox_sample_t const* ibuf, + sox_sample_t* obuf LSX_UNUSED, + size_t* isamp, + size_t* osamp + ) { + if(*isamp) { + auto priv = static_cast(effp->priv); + + int i = 0; + for(; i < *isamp; ++i) { + SOX_SAMPLE_LOCALS; + priv->data->push_back(SOX_SAMPLE_TO_FLOAT_32BIT(ibuf[i], effp->clips)); + } + + if(i != *isamp) { + LOG(ERROR) << "outputFlow number of bytes written=" << i + << " expected=" << *isamp + << " priv->data->size()=" << priv->data->size(); + return SOX_EOF; + } + } + + *osamp = 0; + return SOX_SUCCESS; } - if (i != *isamp) { - LOG(ERROR) << "outputFlow number of bytes written=" << i - << " expected=" << *isamp - << " priv->data->size()=" << priv->data->size(); - return SOX_EOF; + int inputDrain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { + auto priv = static_cast(effp->priv); + + int i = 0; + for(; i < *osamp && priv->index < priv->data->size(); ++i, ++priv->index) { + SOX_SAMPLE_LOCALS; + obuf[i] = + SOX_FLOAT_32BIT_TO_SAMPLE(priv->data->at(priv->index), effp->clips); + } + *osamp = i; + return *osamp ? SOX_SUCCESS : SOX_EOF; } - } - *osamp = 0; - return SOX_SUCCESS; -} - -int inputDrain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { - auto priv = static_cast(effp->priv); - - int i = 0; - for (; i < *osamp && priv->index < priv->data->size(); ++i, ++priv->index) { - SOX_SAMPLE_LOCALS; - obuf[i] = - SOX_FLOAT_32BIT_TO_SAMPLE(priv->data->at(priv->index), effp->clips); - } - *osamp = i; - return *osamp ? 
SOX_SUCCESS : SOX_EOF; -} - -std::unique_ptr createSignalInfo(size_t sampleRate) { - auto sigInfo = std::make_unique(); - *sigInfo = { - .rate = (sox_rate_t)sampleRate, - .channels = 1, // Sounds effects are limited to single channel - .precision = 16, // Any valid value is ok here. - .length = 0, - .mult = nullptr}; - return sigInfo; -} + std::unique_ptr createSignalInfo(size_t sampleRate) { + auto sigInfo = std::make_unique(); + *sigInfo = { + .rate = (sox_rate_t) sampleRate, + .channels = 1, // Sounds effects are limited to single channel + .precision = 16, // Any valid value is ok here. + .length = 0, + .mult = nullptr}; + return sigInfo; + } } // namespace std::unique_ptr SoxWrapper::instance_; -SoxWrapper::SoxWrapper(size_t sampleRate) - : signalInfo_(createSignalInfo(sampleRate)) { - FL_SOX_CHECK(sox_init()); +SoxWrapper::SoxWrapper(size_t sampleRate) : signalInfo_(createSignalInfo(sampleRate)) { + FL_SOX_CHECK(sox_init()); } SoxWrapper::~SoxWrapper() { - sox_quit(); + sox_quit(); } SoxWrapper* SoxWrapper::instance(size_t sampleRate /* =16000*/) { - if (!instance_) { - auto s = new SoxWrapper(sampleRate); - instance_.reset(s); - } - return instance_.get(); + if(!instance_) { + auto s = new SoxWrapper(sampleRate); + instance_.reset(s); + } + return instance_.get(); } void SoxWrapper::applyAndFreeEffect( std::vector& signal, - sox_effect_t* effect) const { - sox_effects_chain_t* chain = createChain(); - addInput(chain, &signal); - FL_SOX_CHECK( - sox_add_effect(chain, effect, signalInfo_.get(), signalInfo_.get())); - free(effect); - std::vector augmented; - addOutput(chain, &augmented); - - sox_flow_effects(chain, nullptr, nullptr); - - sox_delete_effects_chain(chain); - signal.swap(augmented); + sox_effect_t* effect +) const { + sox_effects_chain_t* chain = createChain(); + addInput(chain, &signal); + FL_SOX_CHECK( + sox_add_effect(chain, effect, signalInfo_.get(), signalInfo_.get()) + ); + free(effect); + std::vector augmented; + addOutput(chain, 
&augmented); + + sox_flow_effects(chain, nullptr, nullptr); + + sox_delete_effects_chain(chain); + signal.swap(augmented); } void SoxWrapper::addInput( sox_effects_chain_t* chain, - std::vector* signal) const { - const static sox_effect_handler_t handler{ - /*name=*/"input", - /*usage=*/nullptr, - /*flags=*/SOX_EFF_MCHAN, - /*getopts=*/nullptr, - /*start=*/nullptr, - /*flow=*/nullptr, - /*drain=*/inputDrain, - /*stop=*/nullptr, - /*kill=*/nullptr, - /*priv_size=*/sizeof(SoxData)}; - sox_effect_t* e = nullptr; - FL_SOX_CHECK(e = sox_create_effect(&handler)); - auto input = (SoxData*)e->priv; - input->data = signal; - input->index = 0; - FL_SOX_CHECK(sox_add_effect(chain, e, signalInfo_.get(), signalInfo_.get())); - free(e); + std::vector* signal +) const { + const static sox_effect_handler_t handler{ + /*name=*/ "input", + /*usage=*/ nullptr, + /*flags=*/ SOX_EFF_MCHAN, + /*getopts=*/ nullptr, + /*start=*/ nullptr, + /*flow=*/ nullptr, + /*drain=*/ inputDrain, + /*stop=*/ nullptr, + /*kill=*/ nullptr, + /*priv_size=*/ sizeof(SoxData)}; + sox_effect_t* e = nullptr; + FL_SOX_CHECK(e = sox_create_effect(&handler)); + auto input = (SoxData*) e->priv; + input->data = signal; + input->index = 0; + FL_SOX_CHECK(sox_add_effect(chain, e, signalInfo_.get(), signalInfo_.get())); + free(e); } void SoxWrapper::addOutput( sox_effects_chain_t* chain, - std::vector* emptyBuf) const { - const static sox_effect_handler_t handler = { - /*name=*/"output", - /*usage=*/nullptr, - /*flags=*/SOX_EFF_MCHAN, - /*getopts=*/nullptr, - /*start=*/nullptr, - /*flow=*/outputFlow, - /*drain=*/nullptr, - /*stop=*/nullptr, - /*kill=*/nullptr, - /*priv_size=*/sizeof(SoxData)}; - sox_effect_t* e = nullptr; - FL_SOX_CHECK(e = sox_create_effect(&handler)); - auto output = (SoxData*)e->priv; - output->data = emptyBuf; - FL_SOX_CHECK(sox_add_effect(chain, e, signalInfo_.get(), signalInfo_.get())); - free(e); + std::vector* emptyBuf +) const { + const static sox_effect_handler_t handler = { + /*name=*/ 
"output", + /*usage=*/ nullptr, + /*flags=*/ SOX_EFF_MCHAN, + /*getopts=*/ nullptr, + /*start=*/ nullptr, + /*flow=*/ outputFlow, + /*drain=*/ nullptr, + /*stop=*/ nullptr, + /*kill=*/ nullptr, + /*priv_size=*/ sizeof(SoxData)}; + sox_effect_t* e = nullptr; + FL_SOX_CHECK(e = sox_create_effect(&handler)); + auto output = (SoxData*) e->priv; + output->data = emptyBuf; + FL_SOX_CHECK(sox_add_effect(chain, e, signalInfo_.get(), signalInfo_.get())); + free(e); } void SoxWrapper::addAndFreeEffect( sox_effects_chain_t* chain, - sox_effect_t* effect) const { - FL_SOX_CHECK( - sox_add_effect(chain, effect, signalInfo_.get(), signalInfo_.get())); - free(effect); + sox_effect_t* effect +) const { + FL_SOX_CHECK( + sox_add_effect(chain, effect, signalInfo_.get(), signalInfo_.get()) + ); + free(effect); } sox_effects_chain_t* SoxWrapper::createChain() const { - const static sox_encodinginfo_t encoding = { - .encoding = SOX_ENCODING_FLOAT, - .bits_per_sample = 0, - .compression = HUGE_VAL, // no compression - .reverse_bytes = sox_option_no, - .reverse_nibbles = sox_option_no, - .reverse_bits = sox_option_no, - .opposite_endian = sox_false}; - sox_effects_chain_t* chain = nullptr; - FL_SOX_CHECK(chain = sox_create_effects_chain(&encoding, &encoding)); - return chain; + const static sox_encodinginfo_t encoding = { + .encoding = SOX_ENCODING_FLOAT, + .bits_per_sample = 0, + .compression = HUGE_VAL, // no compression + .reverse_bytes = sox_option_no, + .reverse_nibbles = sox_option_no, + .reverse_bits = sox_option_no, + .opposite_endian = sox_false}; + sox_effects_chain_t* chain = nullptr; + FL_SOX_CHECK(chain = sox_create_effects_chain(&encoding, &encoding)); + return chain; } namespace detail { -void check(bool success, const char* msg, const char* file, int line) { - if (!success) { - std::stringstream ss; - ss << file << ':' << line << "] libsox error when executing: " << msg; - LOG(ERROR) << ss.str(); - throw std::runtime_error(ss.str()); - } -} + void check(bool success, 
const char* msg, const char* file, int line) { + if(!success) { + std::stringstream ss; + ss << file << ':' << line << "] libsox error when executing: " << msg; + LOG(ERROR) << ss.str(); + throw std::runtime_error(ss.str()); + } + } -void check(int status, const char* msg, const char* file, int line) { - if (status != SOX_SUCCESS) { - std::stringstream ss; - ss << file << ':' << line << "] libsox error: " << status - << " when executing: " << msg; - LOG(ERROR) << ss.str(); - throw std::runtime_error(ss.str()); - } -} + void check(int status, const char* msg, const char* file, int line) { + if(status != SOX_SUCCESS) { + std::stringstream ss; + ss << file << ':' << line << "] libsox error: " << status + << " when executing: " << msg; + LOG(ERROR) << ss.str(); + throw std::runtime_error(ss.str()); + } + } -void check(const void* ptr, const char* msg, const char* file, int line) { - if (!ptr) { - std::stringstream ss; - ss << file << ':' << line - << "] libsox failed to allocate when executing: " << msg; - LOG(ERROR) << ss.str(); - throw std::runtime_error(ss.str()); - } -} + void check(const void* ptr, const char* msg, const char* file, int line) { + if(!ptr) { + std::stringstream ss; + ss << file << ':' << line + << "] libsox failed to allocate when executing: " << msg; + LOG(ERROR) << ss.str(); + throw std::runtime_error(ss.str()); + } + } } // namespace detail diff --git a/flashlight/pkg/speech/augmentation/SoxWrapper.h b/flashlight/pkg/speech/augmentation/SoxWrapper.h index f909ae7..717e4ca 100644 --- a/flashlight/pkg/speech/augmentation/SoxWrapper.h +++ b/flashlight/pkg/speech/augmentation/SoxWrapper.h @@ -17,7 +17,7 @@ #include "flashlight/pkg/speech/augmentation/SoundEffectUtil.h" #define FL_SOX_CHECK(expr) \ - ::fl::pkg::speech::sfx::detail::check((expr), #expr, __FILE__, __LINE__) + ::fl::pkg::speech::sfx::detail::check((expr), #expr, __FILE__, __LINE__) // Add forward declareations of sox.h related types so we can keep the include // of sox.h in the cpp 
file, thus avoiding proliferation of dependency on a @@ -29,8 +29,8 @@ class sox_signalinfo_t; namespace fl { namespace pkg { -namespace speech { -namespace sfx { + namespace speech { + namespace sfx { #ifdef FL_BUILD_APP_ASR_SFX_SOX /** @@ -51,54 +51,54 @@ namespace sfx { * SoxWrapper::instance()->applyAndFreeEffect(signal, e); * \endcode */ -class SoxWrapper { - public: - static SoxWrapper* instance(size_t sampleRate = 16000); - ~SoxWrapper(); - - /** - * Apply the given libsox effect on the signal. - */ - void applyAndFreeEffect(std::vector& signal, sox_effect_t* effect) - const; - - private: - explicit SoxWrapper(size_t sampleRate); - - sox_effects_chain_t* createChain() const; - void addInput(sox_effects_chain_t* chain, std::vector* signal) const; - void addOutput(sox_effects_chain_t* chain, std::vector* emptyBuf) - const; - void addAndFreeEffect(sox_effects_chain_t* chain, sox_effect_t* effect) const; - - // sox wants pointer to non-const sox_signalinfo_t but it does not change it. - // mutable is so we can pass pointer to signalInfo_ from const methods. - mutable std::unique_ptr signalInfo_; - static std::unique_ptr instance_; -}; + class SoxWrapper { + public: + static SoxWrapper* instance(size_t sampleRate = 16000); + ~SoxWrapper(); + + /** + * Apply the given libsox effect on the signal. + */ + void applyAndFreeEffect(std::vector& signal, sox_effect_t* effect) + const; + + private: + explicit SoxWrapper(size_t sampleRate); + + sox_effects_chain_t* createChain() const; + void addInput(sox_effects_chain_t* chain, std::vector* signal) const; + void addOutput(sox_effects_chain_t* chain, std::vector* emptyBuf) + const; + void addAndFreeEffect(sox_effects_chain_t* chain, sox_effect_t* effect) const; + + // sox wants pointer to non-const sox_signalinfo_t but it does not change it. + // mutable is so we can pass pointer to signalInfo_ from const methods. 
+ mutable std::unique_ptr signalInfo_; + static std::unique_ptr instance_; + }; #else /* ifdef FL_BUILD_APP_ASR_SFX_SOX */ // Definition with null implementation to stub out SoxWrapper // when building sound effects without libsox. -class SoxWrapper { - public: - static SoxWrapper* instance(size_t sampleRate = 16000) { - return nullptr; - } - void applyAndFreeEffect(std::vector& signal, sox_effect_t* effect) - const {} -}; -#endif /* FL_BUILD_APP_ASR_SFX_SOX */ - -namespace detail { - -void check(bool success, const char* msg, const char* file, int line); -void check(int status, const char* msg, const char* file, int line); -void check(const void* ptr, const char* msg, const char* file, int line); - -} // namespace detail - -} // namespace sfx -} // namespace speech + class SoxWrapper { + public: + static SoxWrapper* instance(size_t sampleRate = 16000) { + return nullptr; + } + void applyAndFreeEffect(std::vector& signal, sox_effect_t* effect) + const {} + }; +#endif /* FL_BUILD_APP_ASR_SFX_SOX */ + + namespace detail { + + void check(bool success, const char* msg, const char* file, int line); + void check(int status, const char* msg, const char* file, int line); + void check(const void* ptr, const char* msg, const char* file, int line); + + } // namespace detail + + } // namespace sfx + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/TimeStretch.cpp b/flashlight/pkg/speech/augmentation/TimeStretch.cpp index b7e28cb..6d4f88d 100644 --- a/flashlight/pkg/speech/augmentation/TimeStretch.cpp +++ b/flashlight/pkg/speech/augmentation/TimeStretch.cpp @@ -20,37 +20,37 @@ namespace fl::pkg::speech::sfx { TimeStretch::TimeStretch( const TimeStretch::Config& config, - unsigned int seed /* =0 */) - : conf_(config), - rng_(seed), - sox_(SoxWrapper::instance(config.sampleRate_)) { - FL_SOX_CHECK(stretchEffect_ = sox_find_effect("stretch")); + unsigned int seed /* =0 */ +) : conf_(config), + rng_(seed), + 
sox_(SoxWrapper::instance(config.sampleRate_)) { + FL_SOX_CHECK(stretchEffect_ = sox_find_effect("stretch")); } void TimeStretch::apply(std::vector& signal) { - if (rng_.random() >= conf_.proba_) { - return; - } - const float factor = rng_.uniform(conf_.minFactor_, conf_.maxFactor_); - sox_effect_t* e = sox_create_effect(stretchEffect_); - std::string _factor = std::to_string(factor); - char* args[] = {_factor.data()}; - FL_SOX_CHECK(sox_effect_options(e, 1, args)); - sox_->applyAndFreeEffect(signal, e); + if(rng_.random() >= conf_.proba_) { + return; + } + const float factor = rng_.uniform(conf_.minFactor_, conf_.maxFactor_); + sox_effect_t* e = sox_create_effect(stretchEffect_); + std::string _factor = std::to_string(factor); + char* args[] = {_factor.data()}; + FL_SOX_CHECK(sox_effect_options(e, 1, args)); + sox_->applyAndFreeEffect(signal, e); } std::string TimeStretch::Config::prettyString() const { - std::stringstream ss; - ss << "TimeStretch::Config{minFactor_=" << minFactor_ - << " maxFactor_=" << maxFactor_ << " proba_=" << proba_ - << " sampleRate_=" << sampleRate_ << '}'; - return ss.str(); + std::stringstream ss; + ss << "TimeStretch::Config{minFactor_=" << minFactor_ + << " maxFactor_=" << maxFactor_ << " proba_=" << proba_ + << " sampleRate_=" << sampleRate_ << '}'; + return ss.str(); } std::string TimeStretch::prettyString() const { - std::stringstream ss; - ss << "TimeStretch{config={" << conf_.prettyString() << '}'; - return ss.str(); + std::stringstream ss; + ss << "TimeStretch{config={" << conf_.prettyString() << '}'; + return ss.str(); }; } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/TimeStretch.h b/flashlight/pkg/speech/augmentation/TimeStretch.h index 62b3f40..aa0a66f 100644 --- a/flashlight/pkg/speech/augmentation/TimeStretch.h +++ b/flashlight/pkg/speech/augmentation/TimeStretch.h @@ -18,70 +18,72 @@ namespace fl { namespace pkg { -namespace speech { -namespace sfx { + namespace speech { + namespace sfx { #ifdef 
FL_BUILD_APP_ASR_SFX_SOX /** * Stretches signal within specified ratio range using libSOX as backend. */ -class TimeStretch : public SoundEffect { - public: - struct Config { - /** - * probability of applying reverb. - */ - float proba_ = 1.0; - double minFactor_ = 0.8; /* stretch factor. 1.0 means copy. */ - double maxFactor_ = 1.25; /* stretch factor. 1.0 means copy. */ - size_t sampleRate_ = 16000; - std::string prettyString() const; - }; + class TimeStretch : public SoundEffect { + public: + struct Config { + /** + * probability of applying reverb. + */ + float proba_ = 1.0; + double minFactor_ = 0.8; /* stretch factor. 1.0 means copy. */ + double maxFactor_ = 1.25; /* stretch factor. 1.0 means copy. */ + size_t sampleRate_ = 16000; + std::string prettyString() const; + }; - explicit TimeStretch( - const TimeStretch::Config& config, - unsigned int seed = 0); - ~TimeStretch() override = default; - void apply(std::vector& data) override; - std::string prettyString() const override; + explicit TimeStretch( + const TimeStretch::Config& config, + unsigned int seed = 0 + ); + ~TimeStretch() override = default; + void apply(std::vector& data) override; + std::string prettyString() const override; - private: - const TimeStretch::Config conf_; - RandomNumberGenerator rng_; - // The next 2 pointers are kept for optimization and readability. - // They point to existing objects that are not constructed or - // destroyed by this class - SoxWrapper const* sox_; - const sox_effect_handler_t* stretchEffect_; -}; + private: + const TimeStretch::Config conf_; + RandomNumberGenerator rng_; + // The next 2 pointers are kept for optimization and readability. + // They point to existing objects that are not constructed or + // destroyed by this class + SoxWrapper const* sox_; + const sox_effect_handler_t* stretchEffect_; + }; #else /* ifdef FL_BUILD_APP_ASR_SFX_SOX */ // Definition with null implementation to stub out TimeStretch // when building sound effects without libsox. 
-class TimeStretch : public SoundEffect { - public: - struct Config { - float proba_; - double minFactor_; - double maxFactor_; - size_t sampleRate_; - std::string prettyString() const { - return ""; - } - }; - explicit TimeStretch( - const TimeStretch::Config& config, - unsigned int seed = 0) {} - ~TimeStretch() override = default; - void apply(std::vector& data) override {} - std::string prettyString() const override { - return ""; - } -}; + class TimeStretch : public SoundEffect { + public: + struct Config { + float proba_; + double minFactor_; + double maxFactor_; + size_t sampleRate_; + std::string prettyString() const { + return ""; + } + }; + explicit TimeStretch( + const TimeStretch::Config& config, + unsigned int seed = 0 + ) {} + ~TimeStretch() override = default; + void apply(std::vector& data) override {} + std::string prettyString() const override { + return ""; + } + }; #endif -} // namespace sfx -} // namespace speech + } // namespace sfx + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/common/Defines.h b/flashlight/pkg/speech/common/Defines.h index 252aec2..c83490a 100644 --- a/flashlight/pkg/speech/common/Defines.h +++ b/flashlight/pkg/speech/common/Defines.h @@ -14,63 +14,63 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { // Dataset indices // If a new field is added, `kNumDataIdx` should be modified accordingly. 
-constexpr size_t kInputIdx = 0; -constexpr size_t kTargetIdx = 1; -constexpr size_t kWordIdx = 2; -constexpr size_t kSampleIdx = 3; -constexpr size_t kPathIdx = 4; -constexpr size_t kDurationIdx = 5; -constexpr size_t kTargetSizeIdx = 6; -constexpr size_t kNumDataIdx = 7; // total number of dataset indices + constexpr size_t kInputIdx = 0; + constexpr size_t kTargetIdx = 1; + constexpr size_t kWordIdx = 2; + constexpr size_t kSampleIdx = 3; + constexpr size_t kPathIdx = 4; + constexpr size_t kDurationIdx = 5; + constexpr size_t kTargetSizeIdx = 6; + constexpr size_t kNumDataIdx = 7; // total number of dataset indices // Various constants used in asr task -constexpr const char* kTrainMode = "train"; -constexpr const char* kContinueMode = "continue"; -constexpr const char* kForkMode = "fork"; -constexpr const char* kGflags = "gflags"; -constexpr const char* kCommandLine = "commandline"; -constexpr const char* kUserName = "username"; -constexpr const char* kHostName = "hostname"; -constexpr const char* kTimestamp = "timestamp"; -constexpr const char* kRunIdx = "runIdx"; -constexpr const char* kRunPath = "runPath"; -constexpr const char* kProgramName = "programname"; -constexpr const char* kEpoch = "epoch"; -constexpr const char* kUpdates = "updates"; -constexpr const char* kScaleFactor = "scalefactor"; -constexpr const char* kSGDOptimizer = "sgd"; -constexpr const char* kAdamOptimizer = "adam"; -constexpr const char* kRMSPropOptimizer = "rmsprop"; -constexpr const char* kAdadeltaOptimizer = "adadelta"; -constexpr const char* kAdagradOptimizer = "adagrad"; -constexpr const char* kAMSgradOptimizer = "amsgrad"; -constexpr const char* kNovogradOptimizer = "novograd"; -constexpr const char* kCtcCriterion = "ctc"; -constexpr const char* kAsgCriterion = "asg"; -constexpr const char* kSeq2SeqRNNCriterion = "s2srnn"; -constexpr const char* kSeq2SeqTransformerCriterion = "s2stransformer"; -constexpr const char* kBatchStrategyNone = "none"; -constexpr const char* 
kBatchStrategyDynamic = "dynamic"; -constexpr const char* kBatchStrategyRandDynamic = "randdynamic"; -constexpr const char* kBatchStrategyRand = "rand"; -constexpr const char* kFeaturesMFSC = "mfsc"; -constexpr const char* kFeaturesMFCC = "mfcc"; -constexpr const char* kFeaturesPow = "pow"; -constexpr const char* kFeaturesRaw = "raw"; -constexpr int kTargetPadValue = -1; + constexpr const char* kTrainMode = "train"; + constexpr const char* kContinueMode = "continue"; + constexpr const char* kForkMode = "fork"; + constexpr const char* kGflags = "gflags"; + constexpr const char* kCommandLine = "commandline"; + constexpr const char* kUserName = "username"; + constexpr const char* kHostName = "hostname"; + constexpr const char* kTimestamp = "timestamp"; + constexpr const char* kRunIdx = "runIdx"; + constexpr const char* kRunPath = "runPath"; + constexpr const char* kProgramName = "programname"; + constexpr const char* kEpoch = "epoch"; + constexpr const char* kUpdates = "updates"; + constexpr const char* kScaleFactor = "scalefactor"; + constexpr const char* kSGDOptimizer = "sgd"; + constexpr const char* kAdamOptimizer = "adam"; + constexpr const char* kRMSPropOptimizer = "rmsprop"; + constexpr const char* kAdadeltaOptimizer = "adadelta"; + constexpr const char* kAdagradOptimizer = "adagrad"; + constexpr const char* kAMSgradOptimizer = "amsgrad"; + constexpr const char* kNovogradOptimizer = "novograd"; + constexpr const char* kCtcCriterion = "ctc"; + constexpr const char* kAsgCriterion = "asg"; + constexpr const char* kSeq2SeqRNNCriterion = "s2srnn"; + constexpr const char* kSeq2SeqTransformerCriterion = "s2stransformer"; + constexpr const char* kBatchStrategyNone = "none"; + constexpr const char* kBatchStrategyDynamic = "dynamic"; + constexpr const char* kBatchStrategyRandDynamic = "randdynamic"; + constexpr const char* kBatchStrategyRand = "rand"; + constexpr const char* kFeaturesMFSC = "mfsc"; + constexpr const char* kFeaturesMFCC = "mfcc"; + constexpr const char* 
kFeaturesPow = "pow"; + constexpr const char* kFeaturesRaw = "raw"; + constexpr int kTargetPadValue = -1; // Feature params -constexpr int kLifterParam = 22; -constexpr int kPrefetchSize = 2; + constexpr int kLifterParam = 22; + constexpr int kPrefetchSize = 2; -constexpr const char* kEosToken = "$"; -constexpr const char* kBlankToken = "#"; -constexpr const char* kSilToken = "|"; + constexpr const char* kEosToken = "$"; + constexpr const char* kBlankToken = "#"; + constexpr const char* kSilToken = "|"; -} // namespace speech + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/common/Flags.cpp b/flashlight/pkg/speech/common/Flags.cpp index 8a30492..175904a 100644 --- a/flashlight/pkg/speech/common/Flags.cpp +++ b/flashlight/pkg/speech/common/Flags.cpp @@ -25,38 +25,46 @@ DEFINE_string( train, "", "[train] Comma-separated list of training files where each row specifies sample " - "information in the format [sample_id audio_absolute_path size transcription]"); + "information in the format [sample_id audio_absolute_path size transcription]" +); DEFINE_string( valid, "", "[train] Comma-separated list of validation files where each row specifies sample " - "information in the format [sample_id audio_absolute_path size transcription]"); + "information in the format [sample_id audio_absolute_path size transcription]" +); DEFINE_string( test, "", "[test, decode] Comma-separated list of test files where each row specifies sample " - "information in the format [sample_id audio_absolute_path size transcription]"); + "information in the format [sample_id audio_absolute_path size transcription]" +); DEFINE_int64( batchsize, 1, - "[train] Batch size for training data (per process in distributed training)"); + "[train] Batch size for training data (per process in distributed training)" +); DEFINE_int64( validbatchsize, -1, "[train] Batch size for validation data (per process in distributed training). 
" - "If -1 then use the value of the 'batchsize' flag"); + "If -1 then use the value of the 'batchsize' flag" +); DEFINE_int64( samplerate, 16000, - "Sample rate (Hz) for training, validation and test audio data"); + "Sample rate (Hz) for training, validation and test audio data" +); DEFINE_int64( channels, 1, - "Number of input channels in training, validation and test audio data"); + "Number of input channels in training, validation and test audio data" +); DEFINE_string( tokens, "tokens.txt", - "Tokens file path, the 'tokensdir' flag is used as a prefix for this path"); + "Tokens file path, the 'tokensdir' flag is used as a prefix for this path" +); DEFINE_string( batching_strategy, "none", @@ -66,143 +74,172 @@ DEFINE_string( "and 'max_tokens' will be used to compute the effective batch size. " "To use unordered input data to pack batches, use either 'rand' " "or 'randdynamic' which shuffles data before packing, " - " then follows the same packing strategies as 'none' or 'dynamic', respectively."); + " then follows the same packing strategies as 'none' or 'dynamic', respectively." +); DEFINE_int64( batching_max_duration, 0, "Maximum number of tokens/frames in the batch when using 'dynamic' batching strategy. " - "Measured with the same unit as input sizes are specified in data list files"); + "Measured with the same unit as input sizes are specified in data list files" +); DEFINE_bool( usewordpiece, false, "Specify if a word separator can be used inside of a token. 
" "Should be used if the SentencePiece tool is used to " - "construct a token set containing word-pieces"); + "construct a token set containing word-pieces" +); DEFINE_int64( replabel, 0, - "Replace up to replabel reptitions by additional token classes"); + "Replace up to replabel reptitions by additional token classes" +); DEFINE_string( surround, "", - "Surround target tokens sequence with provided token (duplicates are removed)"); + "Surround target tokens sequence with provided token (duplicates are removed)" +); DEFINE_string( wordseparator, kSilToken, "Defines a word separator token used to map tokens sequences to words. " - "Defaults to a pre-defined value."); + "Defaults to a pre-defined value." +); DEFINE_double( sampletarget, 0.0, "The probability [0.0, 1.0] with which targets are randomly sampled from a " - "lexicon if multiple token constructions exist for a given word"); + "lexicon if multiple token constructions exist for a given word" +); // NORMALIZATION OPTIONS DEFINE_int64( localnrmlleftctx, 0, "Left context size for local normalization of input " - "audio after featurization (computation of MFCC, etc.)"); + "audio after featurization (computation of MFCC, etc.)" +); DEFINE_int64( localnrmlrightctx, 0, "Right context size for local normalization of input " - "audio after featurization (computation of MFCC, etc.)"); + "audio after featurization (computation of MFCC, etc.)" +); DEFINE_string( onorm, "none", "[train] Criterion normalization mode. 
One of: " "{'none' - no normalization, 'target' - by target size, " - "'input' - by input size}"); + "'input' - by input size}" +); DEFINE_bool( sqnorm, false, - "[train] Use square-root while normalizing criterion loss with 'onorm' mode"); + "[train] Use square-root while normalizing criterion loss with 'onorm' mode" +); DEFINE_bool( lrcosine, false, - "[train] Use cosine learning rate schedule, see usage for more details"); + "[train] Use cosine learning rate schedule, see usage for more details" +); // LEARNING HYPER-PARAMETER OPTIONS DEFINE_int64( iter, std::numeric_limits::max(), - "[train] Total number of updates for training"); + "[train] Total number of updates for training" +); DEFINE_bool(itersave, false, "Save model or not at each update"); DEFINE_double(lr, 1.0, "[train] Learning rate for the network parameters"); DEFINE_double( momentum, 0.0, - "[train] Momentum factor used in SGD optimizer for network only"); + "[train] Momentum factor used in SGD optimizer for network only" +); DEFINE_double( weightdecay, 0.0, - "[train] Weight decay (L2 penalty) for optimization for network only"); + "[train] Weight decay (L2 penalty) for optimization for network only" +); DEFINE_double( lrcrit, 0, - "[train] Criterion learning rate (for 'asg', 'seq2seq' and 'transformer' criterions)"); + "[train] Criterion learning rate (for 'asg', 'seq2seq' and 'transformer' criterions)" +); DEFINE_int64( warmup, 1, - "[train] Number of updates for warmup learning rate from 0 to 'lr'/'lrcrit' for network/criterion"); + "[train] Number of updates for warmup learning rate from 0 to 'lr'/'lrcrit' for network/criterion" +); DEFINE_int64( saug_start_update, -1, "[train] Use SpecAugment starting at the update number inputted. -1 means no SpecAugment. 
" "In case of raw wav input ('mfcc', 'pow' and 'mfsc' are all false) " - "we apply RawSpecAugment which emulates behaviour of SpecAugment"); + "we apply RawSpecAugment which emulates behaviour of SpecAugment" +); DEFINE_int64( lr_decay, std::numeric_limits::max(), - "[train] Epoch value when we start to decay 'lr'/'lrcrit'"); + "[train] Epoch value when we start to decay 'lr'/'lrcrit'" +); DEFINE_int64( lr_decay_step, std::numeric_limits::max(), - "[train] Amount to decay 'lr' and 'lrcrit' starting from epoch 'lr_decay'"); + "[train] Amount to decay 'lr' and 'lrcrit' starting from epoch 'lr_decay'" +); DEFINE_double( maxgradnorm, 0, - "[train] Maximum gradient norm to which gradients exceeding it will be clipped (0 = no clipping)"); + "[train] Maximum gradient norm to which gradients exceeding it will be clipped (0 = no clipping)" +); DEFINE_double( adambeta1, 0.9, - "[train] Beta1 parameter in the Adam, AMSGrad and NovoGrad optimizers"); + "[train] Beta1 parameter in the Adam, AMSGrad and NovoGrad optimizers" +); DEFINE_double( adambeta2, 0.999, - "[train] Beta2 parameter in the Adam, AMSGrad and NovoGrad optimizers"); + "[train] Beta2 parameter in the Adam, AMSGrad and NovoGrad optimizers" +); DEFINE_double( optimrho, 0.9, - "[train] Rho parameter in the RMSProp and Adadelta optimizers"); + "[train] Rho parameter in the RMSProp and Adadelta optimizers" +); DEFINE_double( optimepsilon, 1e-8, - "[train] Epsilon parameter in the Adam, AMSGrad, NovoGrad, Adadelta, RMSProp and Adagrad optimizers"); + "[train] Epsilon parameter in the Adam, AMSGrad, NovoGrad, Adadelta, RMSProp and Adagrad optimizers" +); // LR-SCHEDULER OPTIONS DEFINE_int64( stepsize, std::numeric_limits::max(), "[train] Learning rate schedule if 'lrcosine=false'." 
- "We multiply 'lr'/'lrcrit' by 'gamma' every 'stepsize' updates"); + "We multiply 'lr'/'lrcrit' by 'gamma' every 'stepsize' updates" +); DEFINE_double( gamma, 1.0, - "[train] Learning rate annealing multiplier, see detail in 'stepsize' flag"); + "[train] Learning rate annealing multiplier, see detail in 'stepsize' flag" +); // OPTIMIZER OPTIONS DEFINE_string( netoptim, kSGDOptimizer, "[train] Optimizer for the network, supported ones " - "'sgd', 'adam', 'rmsprop', 'adadelta', 'adagrad', 'amsgrad', 'novograd'"); + "'sgd', 'adam', 'rmsprop', 'adadelta', 'adagrad', 'amsgrad', 'novograd'" +); DEFINE_string( critoptim, kSGDOptimizer, "[train] Optimizer for the criterion (for 'asg', 'seq2seq' and 'transformer' criterions), " - "supported ones 'sgd', 'adam', 'rmsprop', 'adadelta', 'adagrad', 'amsgrad', 'novograd'"); + "supported ones 'sgd', 'adam', 'rmsprop', 'adadelta', 'adagrad', 'amsgrad', 'novograd'" +); // MFCC OPTIONS DEFINE_string( @@ -210,113 +247,135 @@ DEFINE_string( "mfsc", "Features type to compute input by processing audio. 
Could be " "mfcc: standard htk mfcc features; mfsc: standard mfsc features; " - "pow: standard power spectrum; raw: raw wave"); + "pow: standard power spectrum; raw: raw wave" +); DEFINE_int64(mfcccoeffs, 13, "Number of mfcc coefficients"); DEFINE_double(melfloor, 1.0, "Specify optional mel floor for mfcc/mfsc/pow"); DEFINE_int64( filterbanks, 80, "Number of mel-filter bank channels, " - "used also with RawSpecAugment to define number of mel-scale bins"); + "used also with RawSpecAugment to define number of mel-scale bins" +); DEFINE_int64(devwin, 0, "Window length for delta and doubledelta derivatives"); DEFINE_int64(fftcachesize, 1, "Number of cached cuFFT plans in GPU memory"); DEFINE_int64( framesizems, 25, - "Window size in millisecond for power spectrum features"); + "Window size in millisecond for power spectrum features" +); DEFINE_int64( framestridems, 10, - "Stride in milliseconds for power spectrum features"); + "Stride in milliseconds for power spectrum features" +); DEFINE_int64( lowfreqfilterbank, 0, "Low freq filter bank (Hz). " - "Is used also in RawSpecAugment to define the lowest frequecny bound for augmentation"); + "Is used also in RawSpecAugment to define the lowest frequecny bound for augmentation" +); DEFINE_int64( highfreqfilterbank, -1, "High freq filter bank (Hz). 
" - "Is used also in RawSpecAugment to define the highest frequecny bound for augmentation"); + "Is used also in RawSpecAugment to define the highest frequecny bound for augmentation" +); // SPECAUGMENT OPTIONS DEFINE_int64( saug_fmaskf, 27, "[train] Maximum number of frequency bands / mel-scale bands " - "that are masked in SpecAugment/RawSpecAugment"); + "that are masked in SpecAugment/RawSpecAugment" +); DEFINE_int64( saug_fmaskn, 2, - "[train] Number of frequency masks in SpecAugment/RawSpecAugment"); + "[train] Number of frequency masks in SpecAugment/RawSpecAugment" +); DEFINE_int64( saug_tmaskt, 100, - "[train] Maximum number of frames (input elements) that are masked in SpecAugment/RawSpecAugment"); + "[train] Maximum number of frames (input elements) that are masked in SpecAugment/RawSpecAugment" +); DEFINE_double( saug_tmaskp, 1.0, "[train] Maximum proportion of the input sequence (1.0 is 100%) " - "that can be masked in time for SpecAugment/RawSpecAugment"); + "that can be masked in time for SpecAugment/RawSpecAugment" +); DEFINE_int64( saug_tmaskn, 2, - "[train] Number of time masks in SpecAugment/RawSpecAugment"); + "[train] Number of time masks in SpecAugment/RawSpecAugment" +); // SOUND EFFECTS AUGMENTATION OPTIONS DEFINE_string( sfx_config, "", "[train] Path to a sound effect json config file. When set the sound effect is " - "applied to augment the input data."); + "applied to augment the input data." +); DEFINE_int64( sfx_start_update, std::numeric_limits::max(), - "[train] Start sount effect augmentation starting at this update iteration."); + "[train] Start sount effect augmentation starting at this update iteration." 
+); // RUN OPTIONS DEFINE_string(datadir, "", "Prefix to the 'train'/'valid'/'test' files paths"); DEFINE_string( rundir, "", - "[train] Name of the experiment root directory where logs, snapshots will be stored"); + "[train] Name of the experiment root directory where logs, snapshots will be stored" +); DEFINE_string( flagsfile, "", - "File specifying gflags, could specify only part of flags"); + "File specifying gflags, could specify only part of flags" +); DEFINE_int64( nthread, 1, - "[train] Number of threads for data parallelization (prefetching the data)"); + "[train] Number of threads for data parallelization (prefetching the data)" +); DEFINE_int64( seed, 0, - "[train] Manually specify Arrayfire seed for reproducibility"); + "[train] Manually specify Arrayfire seed for reproducibility" +); DEFINE_int64( reportiters, 0, "[train] Number of updates after which we will run evaluation on validation data \ - and save model, if 0 we only do this at end of each epoch"); + and save model, if 0 we only do this at end of each epoch" +); DEFINE_double( pcttraineval, 100, - "[train] Percentage of training set (by number of utts) to use for evaluation"); + "[train] Percentage of training set (by number of utts) to use for evaluation" +); DEFINE_bool( fl_benchmark_mode, true, "[train] Sets flashlight benchmark mode, which dynamically " "benchmarks various operations based on their empirical performance on " - "current hardware throughout training"); + "current hardware throughout training" +); DEFINE_string( fl_optim_mode, "", "[train] Sets the flashlight optimization mode. " - "Optim modes can be O1, O2, or O3."); + "Optim modes can be O1, O2, or O3." 
+); DEFINE_string( fl_log_level, "", "Sets the logging level - " - "must be [FATAL, ERROR, WARNING, INFO]"); + "must be [FATAL, ERROR, WARNING, INFO]" +); DEFINE_int64(fl_vlog_level, 0, "Sets the verbose logging level"); DEFINE_int64( @@ -324,7 +383,8 @@ DEFINE_int64( 0, "Flushes memory manager logs after a specified " "number of log entries. 1000000 is a reasonable " - "value which will reduce overhead."); + "value which will reduce overhead." +); // MIXED PRECISION OPTIONS DEFINE_bool( @@ -333,50 +393,60 @@ DEFINE_bool( "[train] Use mixed precision for training - scale loss and gradients up and down " "by a scale factor that changes over time. If no fl optim mode is " "specified with --fl_optim_mode when passing this flag, automatically " - "sets the optim mode to O1."); + "sets the optim mode to O1." +); DEFINE_double( fl_amp_scale_factor, 4096., "[train] Starting scale factor to use for loss scaling " - " with mixed precision training"); + " with mixed precision training" +); DEFINE_uint64( fl_amp_scale_factor_update_interval, 2000, - "[train] Update interval for adjusting loss scaling in mixed precision training"); + "[train] Update interval for adjusting loss scaling in mixed precision training" +); DEFINE_uint64( fl_amp_max_scale_factor, 32000, - "[train] Maximum value for the loss scale factor in mixed precision training"); + "[train] Maximum value for the loss scale factor in mixed precision training" +); // ARCHITECTURE OPTIONS DEFINE_string( arch, "default", - "[train] Network architecture file path"); + "[train] Network architecture file path" +); DEFINE_string( criterion, kAsgCriterion, "[train] Training criterion to apply on top of network: 'asg', 'ctc', " "'seq2seq' (seq2seq with attention rnn decoder), " - "'transformer' (seq2seq with attention and transfromer decoder)"); + "'transformer' (seq2seq with attention and transfromer decoder)" +); DEFINE_int64( encoderdim, 0, - "[train]: Dimension of encoded hidden state for 'seq2seq' and 'transformer' 
criterions"); + "[train]: Dimension of encoded hidden state for 'seq2seq' and 'transformer' criterions" +); // Seq2Seq Transformer decoder DEFINE_int64( am_decoder_tr_layers, 1, - "[train]: 'transformer' criterion decoder architecture: number of layers"); + "[train]: 'transformer' criterion decoder architecture: number of layers" +); DEFINE_double( am_decoder_tr_dropout, 0.0, - "[train]: 'transformer' criterion decoder architecture: dropout"); + "[train]: 'transformer' criterion decoder architecture: dropout" +); DEFINE_double( am_decoder_tr_layerdrop, 0.0, - "[train]: 'transformer' criterion decoder architecture: layerdrop"); + "[train]: 'transformer' criterion decoder architecture: layerdrop" +); // DECODER OPTIONS @@ -384,280 +454,340 @@ DEFINE_bool(show, false, "[test, decode] Show predictions in the stdout"); DEFINE_bool( showletters, false, - "[decode] Show tokens predictions in the stdout"); + "[decode] Show tokens predictions in the stdout" +); DEFINE_bool( logadd, false, - "[decode] Use logadd operation when merging decoder nodes"); + "[decode] Use logadd operation when merging decoder nodes" +); DEFINE_bool( uselexicon, true, - "[test, decode] Use lexicon file to map between words and tokens sequence"); + "[test, decode] Use lexicon file to map between words and tokens sequence" +); DEFINE_bool(isbeamdump, false, "[decode] Dump the decoding beam to the disk"); DEFINE_string( smearing, "none", "[decode] How to perform trie smearing to have proxy " - "on scores in the middle of a word: 'none', 'max' or 'logadd'"); + "on scores in the middle of a word: 'none', 'max' or 'logadd'" +); DEFINE_string( lmtype, "kenlm", - "[decode] Language model type used along with acoustic model: 'kenlm', 'convlm'"); + "[decode] Language model type used along with acoustic model: 'kenlm', 'convlm'" +); DEFINE_string( lexicon, "", - "path/to/lexicon.txt which contains on each row space separated mapping of a word into tokens sequence"); + "path/to/lexicon.txt which contains on 
each row space separated mapping of a word into tokens sequence" +); DEFINE_string( lm_vocab, "", - "[decode] path/to/lm_vocab.txt for the 'convlm' language model: each token is mapped to its file row index"); + "[decode] path/to/lm_vocab.txt for the 'convlm' language model: each token is mapped to its file row index" +); DEFINE_string( emission_dir, "", - "path/to/emission_dir/ where emissions data will be stored"); + "path/to/emission_dir/ where emissions data will be stored" +); DEFINE_string(lm, "", "[decode] path/to/language_model"); DEFINE_string( am, "", - "path/to/acoustic_model, used also to continue and fork training"); + "path/to/acoustic_model, used also to continue and fork training" +); DEFINE_string(sclite, "", "[decode] path/to/sclite to be written"); DEFINE_string( decodertype, "wrd", - "[decode] Defines at which level language model should be applied: 'wrd', 'tkn'"); + "[decode] Defines at which level language model should be applied: 'wrd', 'tkn'" +); DEFINE_double( lmweight, 0.0, - "[decode] language model weight in the beam search"); + "[decode] language model weight in the beam search" +); DEFINE_double( wordscore, 0.0, - "[decode] word insertion score for lexicon-based decoding"); + "[decode] word insertion score for lexicon-based decoding" +); DEFINE_double(silscore, 0.0, "[decode] word separator insertion score"); DEFINE_double( unkscore, -std::numeric_limits::infinity(), - "[decode] unknown word insertion score"); + "[decode] unknown word insertion score" +); DEFINE_double( eosscore, 0.0, - "[decode] End-of-sentence insertion score (for decoding of seq2seq with attention models)"); + "[decode] End-of-sentence insertion score (for decoding of seq2seq with attention models)" +); DEFINE_double( beamthreshold, 25, - "[decode] beam score threshold for early pruning of hypothesis"); + "[decode] beam score threshold for early pruning of hypothesis" +); DEFINE_int32( maxload, -1, - "[test, decode] Maximum number of testing samples to process"); + 
"[test, decode] Maximum number of testing samples to process" +); DEFINE_int32( maxword, -1, - "Maximum number of words (rows) to use from the lexicon file"); + "Maximum number of words (rows) to use from the lexicon file" +); DEFINE_int32(beamsize, 2500, "[decode] Maximum overall beam size"); DEFINE_int32( beamsizetoken, 250000, - "[decode] Maximum beam for tokens selection"); + "[decode] Maximum beam for tokens selection" +); DEFINE_int32( nthread_decoder_am_forward, 1, - "[test, decoder] Number of threads for acoustic model forward"); + "[test, decoder] Number of threads for acoustic model forward" +); DEFINE_int32( nthread_decoder, 1, - "[decode] Number of threads for beam-search decoding"); + "[decode] Number of threads for beam-search decoding" +); DEFINE_int32( lm_memory, 5000, - "[decode] Total memory size for batch forming for 'convlm' LM forward pass"); + "[decode] Total memory size for batch forming for 'convlm' LM forward pass" +); DEFINE_int32( emission_queue_size, 3000, - "[test, decode] Maximum size of emission queue for acoustic model forward pass"); + "[test, decode] Maximum size of emission queue for acoustic model forward pass" +); DEFINE_double( smoothingtemperature, 1.0, "[train] Smoothening the probability distribution in seq2seq " - "decoder of 'seq2seq' and 'transformer' criterions"); + "decoder of 'seq2seq' and 'transformer' criterions" +); DEFINE_int32( attentionthreshold, std::numeric_limits::max(), - "[train] Hard attention limit in seq2seq decoder only for 'seq2seq' criterion"); + "[train] Hard attention limit in seq2seq decoder only for 'seq2seq' criterion" +); DEFINE_double( lmweight_low, 0.0, - "language model weight (low boundary, search)"); + "language model weight (low boundary, search)" +); DEFINE_double( lmweight_high, 4.0, - "language model weight (high boundary, search)"); + "language model weight (high boundary, search)" +); DEFINE_double(lmweight_step, 0.2, "language model weight (step, search)"); // ASG OPTIONS 
DEFINE_int64( linseg, 0, - "[train] Number of updates of LinSeg to init transitions for ASG"); + "[train] Number of updates of LinSeg to init transitions for ASG" +); DEFINE_double( linlr, -1.0, - "[train] LinSeg: learning rate for network parameters (if < 0, use lr)"); + "[train] LinSeg: learning rate for network parameters (if < 0, use lr)" +); DEFINE_double( linlrcrit, -1.0, - "[train] LinSeg criterion learning rate (if < 0, use lrcrit)"); + "[train] LinSeg criterion learning rate (if < 0, use lrcrit)" +); DEFINE_double( transdiag, 0.0, - "[train] 'asg' criterion: initial value along diagonal of ASG transition matrix"); + "[train] 'asg' criterion: initial value along diagonal of ASG transition matrix" +); // SEQ2SEQ OPTIONS DEFINE_int64( maxdecoderoutputlen, 200, "'seq2seq'/'transformer' criterion: max decoder steps during inference; " - "(for 'transformer' cannot be changed after initialization)"); + "(for 'transformer' cannot be changed after initialization)" +); DEFINE_int64( pctteacherforcing, 100, - "[train] 'seq2seq'/'transformer' criterion: percentage of steps to train using teacher forcing"); + "[train] 'seq2seq'/'transformer' criterion: percentage of steps to train using teacher forcing" +); DEFINE_string( samplingstrategy, "rand", "[train] 'seq2seq'/'transformer' criterion: sampling strategy " - "to use when `pctteacherforcing` < 100. One of: {'rand', 'model'}"); + "to use when `pctteacherforcing` < 100. One of: {'rand', 'model'}" +); DEFINE_double( labelsmooth, 0.0, - "[train] 'seq2seq'/'transformer' criterion: fraction to smooth targets with uniform distribution."); + "[train] 'seq2seq'/'transformer' criterion: fraction to smooth targets with uniform distribution." 
+); DEFINE_bool( inputfeeding, false, - "[train] 'seq2seq' criterion: feed encoder summary to the decoder RNN"); + "[train] 'seq2seq' criterion: feed encoder summary to the decoder RNN" +); DEFINE_string( attention, "content", "[train] 'seq2seq'/'transformer' criterion: attention type in the encoder-decoder, " "supported options: 'content', 'keyvalue', 'location', 'multi', 'multikv', 'multisplit', 'multikvsplit', " - "'neural', 'neuralloc', 'simpleloc'"); + "'neural', 'neuralloc', 'simpleloc'" +); DEFINE_string( attnWindow, "no", "[train] 'seq2seq'/'transformer' criterion: attention window type in the encoder-decoder, " - "supported options: 'median', 'no', 'soft', 'softPretrain', 'step'"); + "supported options: 'median', 'no', 'soft', 'softPretrain', 'step'" +); DEFINE_int64( attndim, 0, - "[train] 'seq2seq'/'transformer' criterion: dimension of neural location attention"); + "[train] 'seq2seq'/'transformer' criterion: dimension of neural location attention" +); DEFINE_int64( attnconvchannel, 0, "[train] 'seq2seq'/'transformer' criterion: " - "number of convolutional channels for location attention"); + "number of convolutional channels for location attention" +); DEFINE_int64( attnconvkernel, 0, - "[train] 'seq2seq'/'transformer' criterion: kernel width for location attention"); + "[train] 'seq2seq'/'transformer' criterion: kernel width for location attention" +); DEFINE_int64( numattnhead, 8, - "[train] 'seq2seq'/'transformer' criterion: number of heads for multihead attention"); + "[train] 'seq2seq'/'transformer' criterion: number of heads for multihead attention" +); DEFINE_int64( leftWindowSize, 50, - "[train] 'seq2seq'/'transformer' criterion: left median window width"); + "[train] 'seq2seq'/'transformer' criterion: left median window width" +); DEFINE_int64( rightWindowSize, 50, - "[train] 'seq2seq'/'transformer' criterion: right median window width"); + "[train] 'seq2seq'/'transformer' criterion: right median window width" +); DEFINE_int64( maxsil, 50, 
"[train] 'seq2seq'/'transformer' criterion: maximum number of " - "leading silence frames for the step window"); + "leading silence frames for the step window" +); DEFINE_int64( minsil, 0, "[train] 'seq2seq'/'transformer' criterion: minimum number of " - "leading silence frames for the step window"); + "leading silence frames for the step window" +); DEFINE_double( maxrate, 10, "[train] 'seq2seq'/'transformer' criterion: maximum ratio between the transcript " - "and the encoded input lengths for the step window"); + "and the encoded input lengths for the step window" +); DEFINE_double( minrate, 3, "[train] 'seq2seq'/'transformer' criterion: minimum ratio between the " - "transcript and the encoded input lengths for the step window"); + "transcript and the encoded input lengths for the step window" +); DEFINE_int64( softwoffset, 10, "[train] 'seq2seq'/'transformer' criterion: offset for the soft " - "window center (= offset + step * rate)"); + "window center (= offset + step * rate)" +); DEFINE_double( softwrate, 5, "[train] 'seq2seq'/'transformer' criterion: moving " - "rate for the soft window center (= offset + step * rate)"); + "rate for the soft window center (= offset + step * rate)" +); DEFINE_double( softwstd, 5, "[train] 'seq2seq'/'transformer' criterion: std for the soft " - "window shape (=exp(-(t - center)^2 / (2 * std^2)))"); + "window shape (=exp(-(t - center)^2 / (2 * std^2)))" +); DEFINE_bool( trainWithWindow, false, "[train] 'seq2seq'/'transformer' criterion: use " - "force-aligned diagonal attention window during the whole training"); + "force-aligned diagonal attention window during the whole training" +); DEFINE_int64( pretrainWindow, 0, "[train] 'seq2seq'/'transformer' criterion: use force-aligned diagonal attention window" - "in training for 'pretrainWindow' updates"); + "in training for 'pretrainWindow' updates" +); DEFINE_double( gumbeltemperature, 1.0, - "[train] 'seq2seq' criterion decoder: temperature in gumbel softmax"); + "[train] 
'seq2seq' criterion decoder: temperature in gumbel softmax" +); DEFINE_int64( decoderrnnlayer, 1, - "[train] 'seq2seq' criterion decoder: the number of decoder rnn layers"); + "[train] 'seq2seq' criterion decoder: the number of decoder rnn layers" +); DEFINE_int64( decoderattnround, 1, - "[train] 'seq2seq' criterion decoder: the number of decoder attention rounds"); + "[train] 'seq2seq' criterion decoder: the number of decoder attention rounds" +); DEFINE_double( decoderdropout, 0.0, - "[train] 'seq2seq' criterion decoder: dropout"); + "[train] 'seq2seq' criterion decoder: dropout" +); // DISTRIBUTED TRAINING DEFINE_bool(enable_distributed, false, "[train] Enable distributed training"); DEFINE_int64( world_rank, 0, - "[train] Rank of the process (Used if rndv_filepath is not empty)"); + "[train] Rank of the process (Used if rndv_filepath is not empty)" +); DEFINE_int64( world_size, 1, - "[train] Total number of the processes (Used if rndv_filepath is not empty)"); + "[train] Total number of the processes (Used if rndv_filepath is not empty)" +); DEFINE_int64( max_devices_per_node, 8, - "[train] The maximum number of devices per training node"); + "[train] The maximum number of devices per training node" +); DEFINE_string( rndv_filepath, "", "[train] Shared file path used for setting up rendezvous." - "If empty, uses MPI to initialize."); + "If empty, uses MPI to initialize." +); // FB SPECIFIC DEFINE_bool(everstoredb, false, "use Everstore db for reading data"); @@ -665,79 +795,84 @@ DEFINE_bool(use_memcache, false, "use Memcache for reading data"); namespace detail { /***************************** Deprecated Flags *****************************/ -namespace { + namespace { -void registerDeprecatedFlags() { - // Register deprecated flags here using DEPRECATE_FLAGS. For example: - // DEPRECATE_FLAGS(my_now_deprecated_flag_name, my_new_flag_name); -} + void registerDeprecatedFlags() { + // Register deprecated flags here using DEPRECATE_FLAGS. 
For example: + // DEPRECATE_FLAGS(my_now_deprecated_flag_name, my_new_flag_name); + } -} // namespace + } // namespace -DeprecatedFlagsMap& getDeprecatedFlags() { - static DeprecatedFlagsMap flagsMap = DeprecatedFlagsMap(); - return flagsMap; -} + DeprecatedFlagsMap& getDeprecatedFlags() { + static DeprecatedFlagsMap flagsMap = DeprecatedFlagsMap(); + return flagsMap; + } -void addDeprecatedFlag( - const std::string& deprecatedFlagName, - const std::string& newFlagName) { - auto& map = getDeprecatedFlags(); - map.emplace(deprecatedFlagName, newFlagName); -} + void addDeprecatedFlag( + const std::string& deprecatedFlagName, + const std::string& newFlagName + ) { + auto& map = getDeprecatedFlags(); + map.emplace(deprecatedFlagName, newFlagName); + } -bool isFlagSet(const std::string& name) { - gflags::CommandLineFlagInfo flagInfo; - if (!gflags::GetCommandLineFlagInfo(name.c_str(), &flagInfo)) { - std::stringstream ss; - ss << "Flag name " << name << " not found - check that it's declared."; - throw std::invalid_argument(ss.str()); - } - return !flagInfo.is_default; -} + bool isFlagSet(const std::string& name) { + gflags::CommandLineFlagInfo flagInfo; + if(!gflags::GetCommandLineFlagInfo(name.c_str(), &flagInfo)) { + std::stringstream ss; + ss << "Flag name " << name << " not found - check that it's declared."; + throw std::invalid_argument(ss.str()); + } + return !flagInfo.is_default; + } } // namespace detail void handleDeprecatedFlags() { - auto& map = detail::getDeprecatedFlags(); - // Register flags - static std::once_flag registerFlagsOnceFlag; - std::call_once(registerFlagsOnceFlag, detail::registerDeprecatedFlags); - - for (auto& flagPair : map) { - std::string deprecatedFlagValue; - gflags::GetCommandLineOption(flagPair.first.c_str(), &deprecatedFlagValue); - - bool deprecatedFlagSet = detail::isFlagSet(flagPair.first); - bool newFlagSet = detail::isFlagSet(flagPair.second); - - if (deprecatedFlagSet && newFlagSet) { - // Use the new flag value - std::cerr 
<< "[WARNING] Both deprecated flag " << flagPair.first - << " and new flag " << flagPair.second - << " are set. Only the new flag will be " - << "serialized when the model saved." << std::endl; - ; - } else if (deprecatedFlagSet && !newFlagSet) { - std::cerr - << "[WARNING] Usage of flag --" << flagPair.first - << " is deprecated and has been replaced with " - << "--" << flagPair.second - << ". Setting the new flag equal to the value of the deprecated flag." - << "The old flag will not be serialized when the model is saved." - << std::endl; - if (gflags::SetCommandLineOption( - flagPair.second.c_str(), deprecatedFlagValue.c_str()) - .empty()) { - std::stringstream ss; - ss << "Failed to set new flag " << flagPair.second << " to value from " - << flagPair.first << "."; - throw std::logic_error(ss.str()); - } + auto& map = detail::getDeprecatedFlags(); + // Register flags + static std::once_flag registerFlagsOnceFlag; + std::call_once(registerFlagsOnceFlag, detail::registerDeprecatedFlags); + + for(auto& flagPair : map) { + std::string deprecatedFlagValue; + gflags::GetCommandLineOption(flagPair.first.c_str(), &deprecatedFlagValue); + + bool deprecatedFlagSet = detail::isFlagSet(flagPair.first); + bool newFlagSet = detail::isFlagSet(flagPair.second); + + if(deprecatedFlagSet && newFlagSet) { + // Use the new flag value + std::cerr << "[WARNING] Both deprecated flag " << flagPair.first + << " and new flag " << flagPair.second + << " are set. Only the new flag will be " + << "serialized when the model saved." << std::endl; + ; + } else if(deprecatedFlagSet && !newFlagSet) { + std::cerr + << "[WARNING] Usage of flag --" << flagPair.first + << " is deprecated and has been replaced with " + << "--" << flagPair.second + << ". Setting the new flag equal to the value of the deprecated flag." + << "The old flag will not be serialized when the model is saved." 
+ << std::endl; + if( + gflags::SetCommandLineOption( + flagPair.second.c_str(), + deprecatedFlagValue.c_str() + ) + .empty() + ) { + std::stringstream ss; + ss << "Failed to set new flag " << flagPair.second << " to value from " + << flagPair.first << "."; + throw std::logic_error(ss.str()); + } + } + + // If the user set the new flag but not the deprecated flag, noop. If the + // user set neither flag, noop. } - - // If the user set the new flag but not the deprecated flag, noop. If the - // user set neither flag, noop. - } } } // namespace fl diff --git a/flashlight/pkg/speech/common/Flags.h b/flashlight/pkg/speech/common/Flags.h index 2d86cd6..1551b61 100644 --- a/flashlight/pkg/speech/common/Flags.h +++ b/flashlight/pkg/speech/common/Flags.h @@ -15,11 +15,11 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { -namespace detail { + namespace detail { -using DeprecatedFlagsMap = std::unordered_map; + using DeprecatedFlagsMap = std::unordered_map; /** * Creates and maintains a map of deprecated flags. 
The map takes @@ -27,18 +27,19 @@ using DeprecatedFlagsMap = std::unordered_map; * ---> {myOldFlag, myNewFlag} * corresponds to the deprecation of myOldFlag */ -DeprecatedFlagsMap& getDeprecatedFlags(); + DeprecatedFlagsMap& getDeprecatedFlags(); // Adds a flag to the global deprecated map -void addDeprecatedFlag( - const std::string& depreactedFlagName, - const std::string& newFlagName); + void addDeprecatedFlag( + const std::string& depreactedFlagName, + const std::string& newFlagName + ); // Whether the flag has been explicitly set either from the cmd line or // de-serialization -bool isFlagSet(const std::string& name); + bool isFlagSet(const std::string& name); -} // namespace detail + } // namespace detail /** * Globally-accessible and recommended to be called immediately after gflags @@ -54,7 +55,7 @@ bool isFlagSet(const std::string& name); * - Does nothing if the user set neither the new nor the deprecated flag, or if * the user correctly set only the new flag and not the deprecated flag. */ -void handleDeprecatedFlags(); + void handleDeprecatedFlags(); /** * Deprecate a command line flag. 
@@ -63,211 +64,211 @@ void handleDeprecatedFlags(); * DEPRECATE_FLAGS(myOldFlagName, my_new_flag_name) */ #define DEPRECATE_FLAGS(DEPRECATED, NEW) \ - detail::addDeprecatedFlag(#DEPRECATED, #NEW); + detail::addDeprecatedFlag(#DEPRECATED, #NEW); /* ========== DATA OPTIONS ========== */ -DECLARE_string(train); -DECLARE_string(valid); -DECLARE_string(test); -DECLARE_int64(batchsize); -DECLARE_int64(validbatchsize); -DECLARE_int64(samplerate); -DECLARE_int64(channels); -DECLARE_string(tokens); -DECLARE_string(batching_strategy); -DECLARE_int64(batching_max_duration); -DECLARE_bool(usewordpiece); -DECLARE_int64(replabel); -DECLARE_string(surround); -DECLARE_string(wordseparator); -DECLARE_double(sampletarget); + DECLARE_string(train); + DECLARE_string(valid); + DECLARE_string(test); + DECLARE_int64(batchsize); + DECLARE_int64(validbatchsize); + DECLARE_int64(samplerate); + DECLARE_int64(channels); + DECLARE_string(tokens); + DECLARE_string(batching_strategy); + DECLARE_int64(batching_max_duration); + DECLARE_bool(usewordpiece); + DECLARE_int64(replabel); + DECLARE_string(surround); + DECLARE_string(wordseparator); + DECLARE_double(sampletarget); /* ========== NORMALIZATION OPTIONS ========== */ -DECLARE_int64(localnrmlleftctx); -DECLARE_int64(localnrmlrightctx); -DECLARE_string(onorm); -DECLARE_bool(sqnorm); -DECLARE_bool(lrcosine); + DECLARE_int64(localnrmlleftctx); + DECLARE_int64(localnrmlrightctx); + DECLARE_string(onorm); + DECLARE_bool(sqnorm); + DECLARE_bool(lrcosine); /* ========== LEARNING HYPER-PARAMETER OPTIONS ========== */ -DECLARE_int64(iter); -DECLARE_bool(itersave); -DECLARE_double(lr); -DECLARE_double(momentum); -DECLARE_double(weightdecay); -DECLARE_bool(sqnorm); -DECLARE_double(lrcrit); -DECLARE_int64(warmup); -DECLARE_int64(saug_start_update); -DECLARE_int64(lr_decay); -DECLARE_int64(lr_decay_step); -DECLARE_double(maxgradnorm); -DECLARE_double(adambeta1); // TODO rename into optim beta1 -DECLARE_double(adambeta2); // TODO rename into optim beta2 
-DECLARE_double(optimrho); -DECLARE_double(optimepsilon); + DECLARE_int64(iter); + DECLARE_bool(itersave); + DECLARE_double(lr); + DECLARE_double(momentum); + DECLARE_double(weightdecay); + DECLARE_bool(sqnorm); + DECLARE_double(lrcrit); + DECLARE_int64(warmup); + DECLARE_int64(saug_start_update); + DECLARE_int64(lr_decay); + DECLARE_int64(lr_decay_step); + DECLARE_double(maxgradnorm); + DECLARE_double(adambeta1); // TODO rename into optim beta1 + DECLARE_double(adambeta2); // TODO rename into optim beta2 + DECLARE_double(optimrho); + DECLARE_double(optimepsilon); /* ========== LR-SCHEDULER OPTIONS ========== */ -DECLARE_int64(stepsize); -DECLARE_double(gamma); + DECLARE_int64(stepsize); + DECLARE_double(gamma); /* ========== OPTIMIZER OPTIONS ========== */ -DECLARE_string(netoptim); -DECLARE_string(critoptim); + DECLARE_string(netoptim); + DECLARE_string(critoptim); /* ========== MFCC OPTIONS ========== */ -DECLARE_string(features_type); -DECLARE_int64(mfcccoeffs); -DECLARE_double(melfloor); -DECLARE_int64(filterbanks); -DECLARE_int64(devwin); -DECLARE_int64(fftcachesize); -DECLARE_int64(framesizems); -DECLARE_int64(framestridems); -DECLARE_int64(lowfreqfilterbank); -DECLARE_int64(highfreqfilterbank); + DECLARE_string(features_type); + DECLARE_int64(mfcccoeffs); + DECLARE_double(melfloor); + DECLARE_int64(filterbanks); + DECLARE_int64(devwin); + DECLARE_int64(fftcachesize); + DECLARE_int64(framesizems); + DECLARE_int64(framestridems); + DECLARE_int64(lowfreqfilterbank); + DECLARE_int64(highfreqfilterbank); /* ========== SPECAUGMENT OPTIONS ========== */ -DECLARE_int64(saug_fmaskf); -DECLARE_int64(saug_fmaskn); -DECLARE_int64(saug_tmaskt); -DECLARE_double(saug_tmaskp); -DECLARE_int64(saug_tmaskn); + DECLARE_int64(saug_fmaskf); + DECLARE_int64(saug_fmaskn); + DECLARE_int64(saug_tmaskt); + DECLARE_double(saug_tmaskp); + DECLARE_int64(saug_tmaskn); /* ========== SOUND EFFECT AUGMENTATION OPTIONS ========== */ -DECLARE_string(sfx_config); 
-DECLARE_int64(sfx_start_update); + DECLARE_string(sfx_config); + DECLARE_int64(sfx_start_update); /* ========== RUN OPTIONS ========== */ -DECLARE_string(datadir); -DECLARE_string(rundir); -DECLARE_string(flagsfile); -DECLARE_int64(nthread); -DECLARE_int64(seed); -DECLARE_int64(memstepsize); -DECLARE_int64(reportiters); -DECLARE_double(pcttraineval); -DECLARE_bool(fl_benchmark_mode); -DECLARE_string(fl_optim_mode); -DECLARE_string(fl_log_level); -DECLARE_int64(fl_vlog_level); -DECLARE_int64(fl_log_mem_ops_interval); + DECLARE_string(datadir); + DECLARE_string(rundir); + DECLARE_string(flagsfile); + DECLARE_int64(nthread); + DECLARE_int64(seed); + DECLARE_int64(memstepsize); + DECLARE_int64(reportiters); + DECLARE_double(pcttraineval); + DECLARE_bool(fl_benchmark_mode); + DECLARE_string(fl_optim_mode); + DECLARE_string(fl_log_level); + DECLARE_int64(fl_vlog_level); + DECLARE_int64(fl_log_mem_ops_interval); /* ========== MIXED PRECISION OPTIONS ========== */ -DECLARE_bool(fl_amp_use_mixed_precision); -DECLARE_double(fl_amp_scale_factor); -DECLARE_uint64(fl_amp_scale_factor_update_interval); -DECLARE_uint64(fl_amp_max_scale_factor); + DECLARE_bool(fl_amp_use_mixed_precision); + DECLARE_double(fl_amp_scale_factor); + DECLARE_uint64(fl_amp_scale_factor_update_interval); + DECLARE_uint64(fl_amp_max_scale_factor); /* ========== ARCHITECTURE OPTIONS ========== */ -DECLARE_string(arch); -DECLARE_string(criterion); -DECLARE_int64(encoderdim); + DECLARE_string(arch); + DECLARE_string(criterion); + DECLARE_int64(encoderdim); // Seq2Seq Transformer decoder -DECLARE_int64(am_decoder_tr_layers); -DECLARE_double(am_decoder_tr_dropout); -DECLARE_double(am_decoder_tr_layerdrop); + DECLARE_int64(am_decoder_tr_layers); + DECLARE_double(am_decoder_tr_dropout); + DECLARE_double(am_decoder_tr_layerdrop); /* ========== DECODER OPTIONS ========== */ -DECLARE_bool(show); -DECLARE_bool(showletters); -DECLARE_bool(logadd); -DECLARE_bool(uselexicon); -DECLARE_bool(isbeamdump); - 
-DECLARE_string(smearing); -DECLARE_string(lmtype); -DECLARE_string(lexicon); -DECLARE_string(lm_vocab); -DECLARE_string(emission_dir); -DECLARE_string(lm); -DECLARE_string(am); -DECLARE_string(sclite); -DECLARE_string(decodertype); - -DECLARE_double(lmweight); -DECLARE_double(wordscore); -DECLARE_double(silscore); -DECLARE_double(unkscore); -DECLARE_double(eosscore); -DECLARE_double(beamthreshold); - -DECLARE_int32(maxload); -DECLARE_int32(maxword); -DECLARE_int32(beamsize); -DECLARE_int32(beamsizetoken); -DECLARE_int32(nthread_decoder_am_forward); -DECLARE_int32(nthread_decoder); -DECLARE_int32(lm_memory); - -DECLARE_int32(emission_queue_size); - -DECLARE_double(lmweight_low); -DECLARE_double(lmweight_high); -DECLARE_double(lmweight_step); + DECLARE_bool(show); + DECLARE_bool(showletters); + DECLARE_bool(logadd); + DECLARE_bool(uselexicon); + DECLARE_bool(isbeamdump); + + DECLARE_string(smearing); + DECLARE_string(lmtype); + DECLARE_string(lexicon); + DECLARE_string(lm_vocab); + DECLARE_string(emission_dir); + DECLARE_string(lm); + DECLARE_string(am); + DECLARE_string(sclite); + DECLARE_string(decodertype); + + DECLARE_double(lmweight); + DECLARE_double(wordscore); + DECLARE_double(silscore); + DECLARE_double(unkscore); + DECLARE_double(eosscore); + DECLARE_double(beamthreshold); + + DECLARE_int32(maxload); + DECLARE_int32(maxword); + DECLARE_int32(beamsize); + DECLARE_int32(beamsizetoken); + DECLARE_int32(nthread_decoder_am_forward); + DECLARE_int32(nthread_decoder); + DECLARE_int32(lm_memory); + + DECLARE_int32(emission_queue_size); + + DECLARE_double(lmweight_low); + DECLARE_double(lmweight_high); + DECLARE_double(lmweight_step); // Seq2Seq -DECLARE_double(smoothingtemperature); -DECLARE_int32(attentionthreshold); + DECLARE_double(smoothingtemperature); + DECLARE_int32(attentionthreshold); /* ========== ASG OPTIONS ========== */ -DECLARE_int64(linseg); -DECLARE_double(linlr); -DECLARE_double(linlrcrit); -DECLARE_double(transdiag); + DECLARE_int64(linseg); + 
DECLARE_double(linlr); + DECLARE_double(linlrcrit); + DECLARE_double(transdiag); /* ========== SEQ2SEQ OPTIONS ========== */ -DECLARE_int64(maxdecoderoutputlen); -DECLARE_int64(pctteacherforcing); -DECLARE_string(samplingstrategy); -DECLARE_double(labelsmooth); -DECLARE_bool(inputfeeding); -DECLARE_string(attention); -DECLARE_string(attnWindow); -DECLARE_int64(attndim); -DECLARE_int64(attnconvchannel); -DECLARE_int64(attnconvkernel); -DECLARE_int64(numattnhead); -DECLARE_int64(leftWindowSize); -DECLARE_int64(rightWindowSize); -DECLARE_int64(maxsil); -DECLARE_int64(minsil); -DECLARE_double(maxrate); -DECLARE_double(minrate); -DECLARE_int64(softwoffset); -DECLARE_double(softwrate); -DECLARE_double(softwstd); -DECLARE_bool(trainWithWindow); -DECLARE_int64(pretrainWindow); -DECLARE_double(gumbeltemperature); -DECLARE_int64(decoderrnnlayer); -DECLARE_int64(decoderattnround); -DECLARE_double(decoderdropout); + DECLARE_int64(maxdecoderoutputlen); + DECLARE_int64(pctteacherforcing); + DECLARE_string(samplingstrategy); + DECLARE_double(labelsmooth); + DECLARE_bool(inputfeeding); + DECLARE_string(attention); + DECLARE_string(attnWindow); + DECLARE_int64(attndim); + DECLARE_int64(attnconvchannel); + DECLARE_int64(attnconvkernel); + DECLARE_int64(numattnhead); + DECLARE_int64(leftWindowSize); + DECLARE_int64(rightWindowSize); + DECLARE_int64(maxsil); + DECLARE_int64(minsil); + DECLARE_double(maxrate); + DECLARE_double(minrate); + DECLARE_int64(softwoffset); + DECLARE_double(softwrate); + DECLARE_double(softwstd); + DECLARE_bool(trainWithWindow); + DECLARE_int64(pretrainWindow); + DECLARE_double(gumbeltemperature); + DECLARE_int64(decoderrnnlayer); + DECLARE_int64(decoderattnround); + DECLARE_double(decoderdropout); /* ========== DISTRIBUTED TRAINING ========== */ -DECLARE_bool(enable_distributed); -DECLARE_int64(world_rank); -DECLARE_int64(world_size); -DECLARE_int64(max_devices_per_node); -DECLARE_string(rndv_filepath); + DECLARE_bool(enable_distributed); + 
DECLARE_int64(world_rank); + DECLARE_int64(world_size); + DECLARE_int64(max_devices_per_node); + DECLARE_string(rndv_filepath); /* ========== FB SPECIFIC ========== */ -DECLARE_bool(everstoredb); -DECLARE_bool(use_memcache); -} // namespace speech + DECLARE_bool(everstoredb); + DECLARE_bool(use_memcache); + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/common/ProducerConsumerQueue.h b/flashlight/pkg/speech/common/ProducerConsumerQueue.h index 06e1794..f88ad2a 100644 --- a/flashlight/pkg/speech/common/ProducerConsumerQueue.h +++ b/flashlight/pkg/speech/common/ProducerConsumerQueue.h @@ -51,100 +51,102 @@ namespace lib { * */ -template -class ProducerConsumerQueue { - public: - explicit ProducerConsumerQueue(int maxSize = 3000) - : maxSize_(maxSize), isAddingFinished_(false) {} - - /* - * - Adds an element to the queue. - * - Ignores the current one if adding is finished. - * - Notifies another producer when queue is not full. - * - Notifies a consumer. - */ - void add(T unit) { - std::unique_lock lock(mutex_); - producerCondition_.wait( - lock, [this]() { return !isFull() || isAddingFinished_; }); - - if (isAddingFinished_) { - return; - } - queue_.push(std::move(unit)); - - if (!isFull()) { - producerCondition_.notify_one(); - } - consumerCondition_.notify_one(); - } - - /* - * - Pops an element from the queue. - * - Returns false when adding is finished and queue is empty. - * - Notifies another consumer when queue is not empty. - * - Notifies a producer. - */ - bool get(T& unit) { - std::unique_lock lock(mutex_); - consumerCondition_.wait( - lock, [this]() { return !isEmpty() || isAddingFinished_; }); - if (isEmpty()) { - return false; - } - unit = std::move(queue_.front()); - queue_.pop(); - - if (!isEmpty()) { - consumerCondition_.notify_one(); - } - producerCondition_.notify_one(); - - return true; - } - - /* - * - Sets the status of the queue to be adding-finished. 
- * - Notifies all the consumers to consume the remaining elements. - */ - void finishAdding() { - std::unique_lock lock(mutex_); - isAddingFinished_ = true; - consumerCondition_.notify_all(); - } - - /* - * - Clears the queue. - * - Resets the status of the queue to be adding-unfinished. - * - Notifies all the consumers and producers to work. - */ - void clear() { - std::unique_lock lock(mutex_); - while (!isEmpty()) { - queue_.pop(); - } - isAddingFinished_ = false; - - producerCondition_.notify_all(); - consumerCondition_.notify_all(); - } - - private: - std::condition_variable producerCondition_; - std::condition_variable consumerCondition_; - - std::mutex mutex_; - std::queue queue_; - int maxSize_; - bool isAddingFinished_; - - bool isFull() const { - return queue_.size() >= maxSize_; - } - - bool isEmpty() const { - return queue_.size() == 0; - } -}; + template + class ProducerConsumerQueue { + public: + explicit ProducerConsumerQueue(int maxSize = 3000) : maxSize_(maxSize), + isAddingFinished_(false) {} + + /* + * - Adds an element to the queue. + * - Ignores the current one if adding is finished. + * - Notifies another producer when queue is not full. + * - Notifies a consumer. + */ + void add(T unit) { + std::unique_lock lock(mutex_); + producerCondition_.wait( + lock, + [this]() { return !isFull() || isAddingFinished_; }); + + if(isAddingFinished_) { + return; + } + queue_.push(std::move(unit)); + + if(!isFull()) { + producerCondition_.notify_one(); + } + consumerCondition_.notify_one(); + } + + /* + * - Pops an element from the queue. + * - Returns false when adding is finished and queue is empty. + * - Notifies another consumer when queue is not empty. + * - Notifies a producer. 
+ */ + bool get(T& unit) { + std::unique_lock lock(mutex_); + consumerCondition_.wait( + lock, + [this]() { return !isEmpty() || isAddingFinished_; }); + if(isEmpty()) { + return false; + } + unit = std::move(queue_.front()); + queue_.pop(); + + if(!isEmpty()) { + consumerCondition_.notify_one(); + } + producerCondition_.notify_one(); + + return true; + } + + /* + * - Sets the status of the queue to be adding-finished. + * - Notifies all the consumers to consume the remaining elements. + */ + void finishAdding() { + std::unique_lock lock(mutex_); + isAddingFinished_ = true; + consumerCondition_.notify_all(); + } + + /* + * - Clears the queue. + * - Resets the status of the queue to be adding-unfinished. + * - Notifies all the consumers and producers to work. + */ + void clear() { + std::unique_lock lock(mutex_); + while(!isEmpty()) { + queue_.pop(); + } + isAddingFinished_ = false; + + producerCondition_.notify_all(); + consumerCondition_.notify_all(); + } + + private: + std::condition_variable producerCondition_; + std::condition_variable consumerCondition_; + + std::mutex mutex_; + std::queue queue_; + int maxSize_; + bool isAddingFinished_; + + bool isFull() const { + return queue_.size() >= maxSize_; + } + + bool isEmpty() const { + return queue_.size() == 0; + } + }; } // namespace lib } // namespace fl diff --git a/flashlight/pkg/speech/criterion/AutoSegmentationCriterion.h b/flashlight/pkg/speech/criterion/AutoSegmentationCriterion.h index 5a3c95d..d91ede8 100644 --- a/flashlight/pkg/speech/criterion/AutoSegmentationCriterion.h +++ b/flashlight/pkg/speech/criterion/AutoSegmentationCriterion.h @@ -16,87 +16,91 @@ using fl::lib::seq::CriterionScaleMode; namespace fl { namespace pkg { -namespace speech { - -class AutoSegmentationCriterion : public SequenceCriterion { - public: - explicit AutoSegmentationCriterion( - int N, - CriterionScaleMode scalemode = CriterionScaleMode::NONE, - double transdiag = 0.0) - : N_(N), - scaleMode_(scalemode), - 
fac_(ForceAlignmentCriterion(N, scalemode)), - fcc_(FullConnectionCriterion(N, scalemode)) { - if (N_ <= 0) { - throw std::invalid_argument("ASG: N is zero or negative."); - } - fl::Variable transition(transdiag * fl::identity(N_), true); - params_ = {transition}; - syncTransitions(); - } - - std::unique_ptr clone() const override { - throw std::runtime_error( - "Cloning is unimplemented in Module 'AutoSegmentationCriterion'"); - } - - std::vector forward( - const std::vector& inputs) override { - if (inputs.size() != 2) { - throw std::invalid_argument("Invalid inputs size"); - } - return { - fcc_.forward(inputs[0], inputs[1]) - - fac_.forward(inputs[0], inputs[1])}; - } - - Tensor viterbiPath(const Tensor& input, const Tensor& inputSize = Tensor()) - override { - return fl::pkg::speech::viterbiPath(input, params_[0].tensor()); - } - - Tensor viterbiPathWithTarget( - const Tensor& input, - const Tensor& target, - const Tensor& inputSizes = Tensor(), - const Tensor& targetSizes = Tensor()) override { - return fac_.viterbiPath(input, target); - } - - void setParams(const fl::Variable& var, int position) override { - Module::setParams(var, position); - syncTransitions(); - } - - std::string prettyString() const override { - return "AutoSegmentationCriterion"; - } - - protected: - AutoSegmentationCriterion() = default; - - void syncTransitions() { - fac_.setParams(params_[0], 0); - fcc_.setParams(params_[0], 0); - } - - private: - int N_; - CriterionScaleMode scaleMode_; - ForceAlignmentCriterion fac_; - FullConnectionCriterion fcc_; - - FL_SAVE_LOAD_WITH_BASE( - SequenceCriterion, - fl::serializeAs(N_), - scaleMode_, - fac_, - fcc_) -}; - -using ASGLoss = AutoSegmentationCriterion; -} // namespace speech + namespace speech { + + class AutoSegmentationCriterion : public SequenceCriterion { + public: + explicit AutoSegmentationCriterion( + int N, + CriterionScaleMode scalemode = CriterionScaleMode::NONE, + double transdiag = 0.0 + ) : N_(N), + scaleMode_(scalemode), + 
fac_(ForceAlignmentCriterion(N, scalemode)), + fcc_(FullConnectionCriterion(N, scalemode)) { + if(N_ <= 0) { + throw std::invalid_argument("ASG: N is zero or negative."); + } + fl::Variable transition(transdiag * fl::identity(N_), true); + params_ = {transition}; + syncTransitions(); + } + + std::unique_ptr clone() const override { + throw std::runtime_error( + "Cloning is unimplemented in Module 'AutoSegmentationCriterion'" + ); + } + + std::vector forward( + const std::vector& inputs + ) override { + if(inputs.size() != 2) { + throw std::invalid_argument("Invalid inputs size"); + } + return { + fcc_.forward(inputs[0], inputs[1]) + - fac_.forward(inputs[0], inputs[1])}; + } + + Tensor viterbiPath(const Tensor& input, const Tensor& inputSize = Tensor()) + override { + return fl::pkg::speech::viterbiPath(input, params_[0].tensor()); + } + + Tensor viterbiPathWithTarget( + const Tensor& input, + const Tensor& target, + const Tensor& inputSizes = Tensor(), + const Tensor& targetSizes = Tensor() + ) override { + return fac_.viterbiPath(input, target); + } + + void setParams(const fl::Variable& var, int position) override { + Module::setParams(var, position); + syncTransitions(); + } + + std::string prettyString() const override { + return "AutoSegmentationCriterion"; + } + + protected: + AutoSegmentationCriterion() = default; + + void syncTransitions() { + fac_.setParams(params_[0], 0); + fcc_.setParams(params_[0], 0); + } + + private: + int N_; + CriterionScaleMode scaleMode_; + ForceAlignmentCriterion fac_; + FullConnectionCriterion fcc_; + + FL_SAVE_LOAD_WITH_BASE( + SequenceCriterion, + fl::serializeAs(N_), + scaleMode_, + fac_, + fcc_ + ) + }; + + using ASGLoss = AutoSegmentationCriterion; + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/ConnectionistTemporalClassificationCriterion.cpp b/flashlight/pkg/speech/criterion/ConnectionistTemporalClassificationCriterion.cpp index 372d24d..3838c78 100644 --- 
a/flashlight/pkg/speech/criterion/ConnectionistTemporalClassificationCriterion.cpp +++ b/flashlight/pkg/speech/criterion/ConnectionistTemporalClassificationCriterion.cpp @@ -17,29 +17,34 @@ namespace { using namespace fl; struct CTCContext { - std::vector targetVec; - std::vector targetSizeVec; - std::vector workspaceVec; + std::vector targetVec; + std::vector targetSizeVec; + std::vector workspaceVec; }; Tensor logSoftmax(const Tensor& input, const int dim) { - Tensor maxvals = fl::amax(input, {dim}, /* keepDims = */ true); - Shape tiledims(std::vector(input.ndim(), 1)); - if (dim > 3) { - throw std::invalid_argument("logSoftmax: Dimension must be less than 3"); - } - tiledims[dim] = input.dim(dim); - // Compute log softmax. - // Subtracting then adding maxvals is for numerical stability. - auto result = input - - fl::tile(fl::log(fl::sum( - fl::exp(input - fl::tile(maxvals, tiledims)), - {dim}, - /* keepDims = */ true)) + - maxvals, - tiledims); - fl::eval(result); - return result; + Tensor maxvals = fl::amax(input, {dim}, /* keepDims = */ true); + Shape tiledims(std::vector(input.ndim(), 1)); + if(dim > 3) { + throw std::invalid_argument("logSoftmax: Dimension must be less than 3"); + } + tiledims[dim] = input.dim(dim); + // Compute log softmax. + // Subtracting then adding maxvals is for numerical stability. 
+ auto result = input + - fl::tile( + fl::log( + fl::sum( + fl::exp(input - fl::tile(maxvals, tiledims)), + {dim}, + /* keepDims = */ true + ) + ) + + maxvals, + tiledims + ); + fl::eval(result); + return result; }; } // namespace @@ -47,23 +52,25 @@ Tensor logSoftmax(const Tensor& input, const int dim) { namespace fl::pkg::speech { ConnectionistTemporalClassificationCriterion:: - ConnectionistTemporalClassificationCriterion( - fl::lib::seq::CriterionScaleMode - scalemode /* = fl::lib::seq::CriterionScaleMode::NONE */) - : scaleMode_(scalemode) {} +ConnectionistTemporalClassificationCriterion( + fl::lib::seq::CriterionScaleMode + scalemode /* = fl::lib::seq::CriterionScaleMode::NONE */ +) : scaleMode_(scalemode) {} std::unique_ptr ConnectionistTemporalClassificationCriterion::clone() - const { - throw std::runtime_error( - "Cloning is unimplemented in Module 'ConnectionistTemporalClassificationCriterion'"); +const { + throw std::runtime_error( + "Cloning is unimplemented in Module 'ConnectionistTemporalClassificationCriterion'" + ); } Tensor ConnectionistTemporalClassificationCriterion::viterbiPath( const Tensor& input, - const Tensor& inputSize /* = Tensor() */) { - Tensor bestpath, maxvalues; - fl::max(maxvalues, bestpath, input, 0); - return bestpath; + const Tensor& inputSize /* = Tensor() */ +) { + Tensor bestpath, maxvalues; + fl::max(maxvalues, bestpath, input, 0); + return bestpath; } Tensor ConnectionistTemporalClassificationCriterion::viterbiPathWithTarget( @@ -72,64 +79,70 @@ Tensor ConnectionistTemporalClassificationCriterion::viterbiPathWithTarget( const Tensor& inputSizes /* = Tensor() */, const Tensor& targetSizes /* = Tensor() */ ) { - if (input.ndim() != 3) { - throw std::invalid_argument( - "ConnectionistTemporalClassificationCriterion::viterbiPathWithTarget: " - "expected input of shape {N x T x B}"); - } - int N = input.dim(0); - int T = input.dim(1); - int B = input.dim(2); - int L = target.dim(0); - - const Tensor targetSize = 
getTargetSizeArray(target, T); - std::shared_ptr ctx = std::make_shared(); - Tensor softmax = ::logSoftmax(input, 0); - std::vector inputVec = softmax.toHostVector(); - ctx->targetVec = target.toHostVector(); - ctx->targetSizeVec = targetSize.toHostVector(); - ctx->workspaceVec.assign(CTC::getWorkspaceSize(B, T, N, L), 0); - std::vector bestPaths(B * T); - CTC::viterbi( - B, - T, - N, - L, - inputVec.data(), - ctx->targetVec.data(), - ctx->targetSizeVec.data(), - bestPaths.data(), - ctx->workspaceVec.data()); - Tensor result = - Tensor::fromBuffer({T, B}, bestPaths.data(), MemoryLocation::Host); - return result; + if(input.ndim() != 3) { + throw std::invalid_argument( + "ConnectionistTemporalClassificationCriterion::viterbiPathWithTarget: " + "expected input of shape {N x T x B}" + ); + } + int N = input.dim(0); + int T = input.dim(1); + int B = input.dim(2); + int L = target.dim(0); + + const Tensor targetSize = getTargetSizeArray(target, T); + std::shared_ptr ctx = std::make_shared(); + Tensor softmax = ::logSoftmax(input, 0); + std::vector inputVec = softmax.toHostVector(); + ctx->targetVec = target.toHostVector(); + ctx->targetSizeVec = targetSize.toHostVector(); + ctx->workspaceVec.assign(CTC::getWorkspaceSize(B, T, N, L), 0); + std::vector bestPaths(B * T); + CTC::viterbi( + B, + T, + N, + L, + inputVec.data(), + ctx->targetVec.data(), + ctx->targetSizeVec.data(), + bestPaths.data(), + ctx->workspaceVec.data() + ); + Tensor result = + Tensor::fromBuffer({T, B}, bestPaths.data(), MemoryLocation::Host); + return result; } std::string ConnectionistTemporalClassificationCriterion::prettyString() const { - return "ConnectionistTemporalClassificationCriterion"; + return "ConnectionistTemporalClassificationCriterion"; } void ConnectionistTemporalClassificationCriterion::validate( const Variable& input, - const Variable& target) { - if (input.isEmpty()) { - throw std::invalid_argument("CTC: Input cannot be empty"); - } - if (target.ndim() < 2) { - throw 
std::invalid_argument( - "CTC: Incorrect dimensions for target. Expected {L, B}, got " + - target.shape().toString()); - } - if (input.ndim() < 3) { - throw std::invalid_argument( - "CTC: Incorrect dimensions for input. Expected {N, T, B}, got " + - input.shape().toString()); - } - if (input.dim(2) != target.dim(1)) { - throw std::invalid_argument( - "CTC: Batchsize mismatch for input and target with dims " + - input.shape().toString() + " and " + target.shape().toString() + - ", respectively"); - } + const Variable& target +) { + if(input.isEmpty()) { + throw std::invalid_argument("CTC: Input cannot be empty"); + } + if(target.ndim() < 2) { + throw std::invalid_argument( + "CTC: Incorrect dimensions for target. Expected {L, B}, got " + + target.shape().toString() + ); + } + if(input.ndim() < 3) { + throw std::invalid_argument( + "CTC: Incorrect dimensions for input. Expected {N, T, B}, got " + + input.shape().toString() + ); + } + if(input.dim(2) != target.dim(1)) { + throw std::invalid_argument( + "CTC: Batchsize mismatch for input and target with dims " + + input.shape().toString() + " and " + target.shape().toString() + + ", respectively" + ); + } } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/ConnectionistTemporalClassificationCriterion.h b/flashlight/pkg/speech/criterion/ConnectionistTemporalClassificationCriterion.h index 318ab2b..e4f8de1 100644 --- a/flashlight/pkg/speech/criterion/ConnectionistTemporalClassificationCriterion.h +++ b/flashlight/pkg/speech/criterion/ConnectionistTemporalClassificationCriterion.h @@ -13,42 +13,46 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { -class ConnectionistTemporalClassificationCriterion : public SequenceCriterion { - public: - ConnectionistTemporalClassificationCriterion( - fl::lib::seq::CriterionScaleMode scalemode = - fl::lib::seq::CriterionScaleMode::NONE); + class ConnectionistTemporalClassificationCriterion : public SequenceCriterion { + public: + 
ConnectionistTemporalClassificationCriterion( + fl::lib::seq::CriterionScaleMode scalemode = + fl::lib::seq::CriterionScaleMode::NONE + ); - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::vector forward( - const std::vector& inputs) override; + std::vector forward( + const std::vector& inputs + ) override; - Tensor viterbiPath(const Tensor& input, const Tensor& inputSize = Tensor()) - override; + Tensor viterbiPath(const Tensor& input, const Tensor& inputSize = Tensor()) + override; - Tensor viterbiPathWithTarget( - const Tensor& input, - const Tensor& target, - const Tensor& inputSizes = Tensor(), - const Tensor& targetSizes = Tensor()) override; + Tensor viterbiPathWithTarget( + const Tensor& input, + const Tensor& target, + const Tensor& inputSizes = Tensor(), + const Tensor& targetSizes = Tensor() + ) override; - std::string prettyString() const override; + std::string prettyString() const override; - private: - fl::lib::seq::CriterionScaleMode scaleMode_; + private: + fl::lib::seq::CriterionScaleMode scaleMode_; - FL_SAVE_LOAD_WITH_BASE(SequenceCriterion, scaleMode_) + FL_SAVE_LOAD_WITH_BASE(SequenceCriterion, scaleMode_) - void validate(const fl::Variable& input, const fl::Variable& target); -}; + void validate(const fl::Variable& input, const fl::Variable& target); + }; -typedef ConnectionistTemporalClassificationCriterion CTCLoss; -} // namespace speech + typedef ConnectionistTemporalClassificationCriterion CTCLoss; + } // namespace speech } // namespace pkg } // namespace fl CEREAL_REGISTER_TYPE( - fl::pkg::speech::ConnectionistTemporalClassificationCriterion) + fl::pkg::speech::ConnectionistTemporalClassificationCriterion +) diff --git a/flashlight/pkg/speech/criterion/CriterionUtils.cpp b/flashlight/pkg/speech/criterion/CriterionUtils.cpp index 72bb202..1f0cd65 100644 --- a/flashlight/pkg/speech/criterion/CriterionUtils.cpp +++ b/flashlight/pkg/speech/criterion/CriterionUtils.cpp @@ -17,83 +17,86 @@ using 
fl::lib::seq::CriterionScaleMode; namespace fl::pkg::speech { int countRepeats(const int* labels, int len) { - int r = 0; - for (int i = 1; i < len; ++i) { - if (labels[i] == labels[i - 1]) { - ++r; + int r = 0; + for(int i = 1; i < len; ++i) { + if(labels[i] == labels[i - 1]) { + ++r; + } } - } - return r; + return r; } int getTargetSize(const int* labels, int len) { - while (len > 0 && labels[len - 1] < 0) { - --len; - } - return len; + while(len > 0 && labels[len - 1] < 0) { + --len; + } + return len; } CriterionScaleMode getCriterionScaleMode( const std::string& onorm, - bool sqnorm) { - if (onorm == "none") { - return CriterionScaleMode::NONE; - } else if (onorm == "input") { - return sqnorm ? CriterionScaleMode::INPUT_SZ_SQRT - : CriterionScaleMode::INPUT_SZ; - } else if (onorm == "target") { - return sqnorm ? CriterionScaleMode::TARGET_SZ_SQRT - : CriterionScaleMode::TARGET_SZ; - } else { - throw std::invalid_argument("invalid onorm option"); - } + bool sqnorm +) { + if(onorm == "none") { + return CriterionScaleMode::NONE; + } else if(onorm == "input") { + return sqnorm ? CriterionScaleMode::INPUT_SZ_SQRT + : CriterionScaleMode::INPUT_SZ; + } else if(onorm == "target") { + return sqnorm ? 
CriterionScaleMode::TARGET_SZ_SQRT + : CriterionScaleMode::TARGET_SZ; + } else { + throw std::invalid_argument("invalid onorm option"); + } } Variable getLinearTarget(const Variable& targetVar, int T) { - int L = targetVar.dim(0); - int B = targetVar.dim(1); + int L = targetVar.dim(0); + int B = targetVar.dim(1); - std::vector target(B * L); - std::vector newTarget(B * T); + std::vector target(B * L); + std::vector newTarget(B * T); - targetVar.host(target.data()); - for (int b = 0; b < B; ++b) { - const auto pTarget = target.data() + b * L; - auto pNewTarget = newTarget.data() + b * T; + targetVar.host(target.data()); + for(int b = 0; b < B; ++b) { + const auto pTarget = target.data() + b * L; + auto pNewTarget = newTarget.data() + b * T; - int targetSize = std::min(T, fl::pkg::speech::getTargetSize(pTarget, L)); - if (targetSize == 0) { - // hacky way to make ASG think L == 0. - std::fill(pNewTarget, pNewTarget + T, -1); - } else { - for (int t = 0; t < T; ++t) { - pNewTarget[t] = pTarget[t * targetSize / T]; - } + int targetSize = std::min(T, fl::pkg::speech::getTargetSize(pTarget, L)); + if(targetSize == 0) { + // hacky way to make ASG think L == 0. + std::fill(pNewTarget, pNewTarget + T, -1); + } else { + for(int t = 0; t < T; ++t) { + pNewTarget[t] = pTarget[t * targetSize / T]; + } + } } - } - return Variable(Tensor::fromVector({T, B}, newTarget), false); + return Variable(Tensor::fromVector({T, B}, newTarget), false); } fl::Variable applySeq2SeqMask( const fl::Variable& input, const Tensor& targetClasses, - int padValue) { - if (input.shape() != targetClasses.shape()) { - throw std::runtime_error( - "applySeq2SeqMask: input and mask should have the same dimentions."); - } - Tensor output = input.tensor(); - Tensor mask = targetClasses == padValue; - output(mask) = 0.; + int padValue +) { + if(input.shape() != targetClasses.shape()) { + throw std::runtime_error( + "applySeq2SeqMask: input and mask should have the same dimentions." 
+ ); + } + Tensor output = input.tensor(); + Tensor mask = targetClasses == padValue; + output(mask) = 0.; - auto gradFunc = [mask]( - std::vector& inputs, - const fl::Variable& gradOutput) { - Tensor gradArray = gradOutput.tensor(); - gradArray(mask) = 0.; - inputs[0].addGrad(fl::Variable(gradArray, false)); - }; - return fl::Variable(output, {input.withoutData()}, gradFunc); + auto gradFunc = [mask]( + std::vector& inputs, + const fl::Variable& gradOutput) { + Tensor gradArray = gradOutput.tensor(); + gradArray(mask) = 0.; + inputs[0].addGrad(fl::Variable(gradArray, false)); + }; + return fl::Variable(output, {input.withoutData()}, gradFunc); } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/CriterionUtils.h b/flashlight/pkg/speech/criterion/CriterionUtils.h index 669d300..60aafc4 100644 --- a/flashlight/pkg/speech/criterion/CriterionUtils.h +++ b/flashlight/pkg/speech/criterion/CriterionUtils.h @@ -17,89 +17,90 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { #define NEG_INFINITY_FLT -std::numeric_limits::infinity() #define NEG_INFINITY_DBL -std::numeric_limits::infinity() -template -inline T logSumExp(T logA, T logB) { - if (logA < logB) { - std::swap(logA, logB); - } - if (logB == -std::numeric_limits::infinity()) { - return logA; - } - return logA + std::log1p(std::exp(logB - logA)); -} - -template -inline T logSumExp(T logA, T logB, T logC) { - if (logA < logB) { - std::swap(logA, logB); - } - if (logA < logC) { - std::swap(logA, logC); - } - if (logB < logC) { - std::swap(logB, logC); - } - if (logC == -std::numeric_limits::infinity()) { - return logSumExp(logA, logB); - } - return logA + std::log1p(std::exp(logB - logA) + std::exp(logC - logA)); -} - -template -inline void dLogSumExp(T in1, T in2, T& d1, T& d2, const float scale) { - T maxIn = std::max(in1, in2); - - in1 = std::exp(in1 - maxIn); - in2 = std::exp(in2 - maxIn); - - T Z = in1 + in2; - - d1 += scale * (in1 / Z); - d2 += scale * (in2 / Z); -} - -template 
-inline void -dLogSumExp(T in1, T in2, T in3, T& d1, T& d2, T& d3, const float scale) { - T maxIn = std::max(std::max(in1, in2), in3); - - in1 = std::exp(in1 - maxIn); - in2 = std::exp(in2 - maxIn); - in3 = std::exp(in3 - maxIn); - - T Z = in1 + in2 + in3; - - d1 += scale * (in1 / Z); - d2 += scale * (in2 / Z); - d3 += scale * (in3 / Z); -} - -int countRepeats(const int* labels, int len); - -int getTargetSize(const int* labels, int len); - -Tensor getTargetSizeArray(const Tensor& target, int maxSize); - -lib::seq::CriterionScaleMode getCriterionScaleMode( - const std::string& onorm, - bool sqnorm); + template + inline T logSumExp(T logA, T logB) { + if(logA < logB) { + std::swap(logA, logB); + } + if(logB == -std::numeric_limits::infinity()) { + return logA; + } + return logA + std::log1p(std::exp(logB - logA)); + } + + template + inline T logSumExp(T logA, T logB, T logC) { + if(logA < logB) { + std::swap(logA, logB); + } + if(logA < logC) { + std::swap(logA, logC); + } + if(logB < logC) { + std::swap(logB, logC); + } + if(logC == -std::numeric_limits::infinity()) { + return logSumExp(logA, logB); + } + return logA + std::log1p(std::exp(logB - logA) + std::exp(logC - logA)); + } + + template + inline void dLogSumExp(T in1, T in2, T& d1, T& d2, const float scale) { + T maxIn = std::max(in1, in2); + + in1 = std::exp(in1 - maxIn); + in2 = std::exp(in2 - maxIn); + + T Z = in1 + in2; + + d1 += scale * (in1 / Z); + d2 += scale * (in2 / Z); + } + + template + inline void dLogSumExp(T in1, T in2, T in3, T& d1, T& d2, T& d3, const float scale) { + T maxIn = std::max(std::max(in1, in2), in3); + + in1 = std::exp(in1 - maxIn); + in2 = std::exp(in2 - maxIn); + in3 = std::exp(in3 - maxIn); + + T Z = in1 + in2 + in3; + + d1 += scale * (in1 / Z); + d2 += scale * (in2 / Z); + d3 += scale * (in3 / Z); + } + + int countRepeats(const int* labels, int len); + + int getTargetSize(const int* labels, int len); + + Tensor getTargetSizeArray(const Tensor& target, int maxSize); + + 
lib::seq::CriterionScaleMode getCriterionScaleMode( + const std::string& onorm, + bool sqnorm + ); // Input: N x T x B (type: float), Output: T x B (type: int) -Tensor viterbiPath(const Tensor& input, const Tensor& trans); + Tensor viterbiPath(const Tensor& input, const Tensor& trans); -fl::Variable getLinearTarget(const fl::Variable& target, int T); + fl::Variable getLinearTarget(const fl::Variable& target, int T); // apply mask to the input with proper grad. // Mask should be the same size as input -fl::Variable applySeq2SeqMask( - const fl::Variable& input, - const Tensor& targetClasses, - int padValue); -} // namespace speech + fl::Variable applySeq2SeqMask( + const fl::Variable& input, + const Tensor& targetClasses, + int padValue + ); + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/Defines.h b/flashlight/pkg/speech/criterion/Defines.h index e35aa96..6fcc3fb 100644 --- a/flashlight/pkg/speech/criterion/Defines.h +++ b/flashlight/pkg/speech/criterion/Defines.h @@ -11,12 +11,12 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { // sampling strategy to use in decoder in place of teacher forcing -constexpr const char* kModelSampling = "model"; -constexpr const char* kRandSampling = "rand"; -constexpr const char* kGumbelSampling = "gumbel"; -} // namespace speech + constexpr const char* kModelSampling = "model"; + constexpr const char* kRandSampling = "rand"; + constexpr const char* kGumbelSampling = "gumbel"; + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/ForceAlignmentCriterion.cpp b/flashlight/pkg/speech/criterion/ForceAlignmentCriterion.cpp index 6e87c1c..a86bfe4 100644 --- a/flashlight/pkg/speech/criterion/ForceAlignmentCriterion.cpp +++ b/flashlight/pkg/speech/criterion/ForceAlignmentCriterion.cpp @@ -11,22 +11,25 @@ namespace fl::pkg::speech { ForceAlignmentCriterion::ForceAlignmentCriterion( int N, - 
fl::lib::seq::CriterionScaleMode scalemode) - : N_(N), scaleMode_(scalemode) { - if (N_ <= 0) { - throw std::invalid_argument( - "FAC: Size of transition matrix is less than 0"); - } - auto transition = fl::constant(0.0, {N_, N_}); - params_ = {transition}; + fl::lib::seq::CriterionScaleMode scalemode +) : N_(N), + scaleMode_(scalemode) { + if(N_ <= 0) { + throw std::invalid_argument( + "FAC: Size of transition matrix is less than 0" + ); + } + auto transition = fl::constant(0.0, {N_, N_}); + params_ = {transition}; } std::unique_ptr ForceAlignmentCriterion::clone() const { - throw std::runtime_error( - "Cloning is unimplemented in Module 'ForceAlignmentCriterion'"); + throw std::runtime_error( + "Cloning is unimplemented in Module 'ForceAlignmentCriterion'" + ); } std::string ForceAlignmentCriterion::prettyString() const { - return "ForceAlignmentCriterion"; + return "ForceAlignmentCriterion"; } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/ForceAlignmentCriterion.h b/flashlight/pkg/speech/criterion/ForceAlignmentCriterion.h index 52104d6..a6cb77b 100644 --- a/flashlight/pkg/speech/criterion/ForceAlignmentCriterion.h +++ b/flashlight/pkg/speech/criterion/ForceAlignmentCriterion.h @@ -16,38 +16,40 @@ using fl::lib::seq::CriterionScaleMode; namespace fl { namespace pkg { -namespace speech { + namespace speech { -class ForceAlignmentCriterion : public fl::BinaryModule { - public: - explicit ForceAlignmentCriterion( - int N, - CriterionScaleMode scalemode = CriterionScaleMode::NONE); + class ForceAlignmentCriterion : public fl::BinaryModule { + public: + explicit ForceAlignmentCriterion( + int N, + CriterionScaleMode scalemode = CriterionScaleMode::NONE + ); - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - fl::Variable forward(const fl::Variable& input, const fl::Variable& target) - override; + fl::Variable forward(const fl::Variable& input, const fl::Variable& target) + override; - Tensor viterbiPath(const Tensor& 
input, const Tensor& target); + Tensor viterbiPath(const Tensor& input, const Tensor& target); - std::string prettyString() const override; + std::string prettyString() const override; - private: - friend class AutoSegmentationCriterion; - ForceAlignmentCriterion() = default; + private: + friend class AutoSegmentationCriterion; + ForceAlignmentCriterion() = default; - int N_; - CriterionScaleMode scaleMode_; + int N_; + CriterionScaleMode scaleMode_; - FL_SAVE_LOAD_WITH_BASE( - fl::BinaryModule, - fl::serializeAs(N_), - scaleMode_) -}; + FL_SAVE_LOAD_WITH_BASE( + fl::BinaryModule, + fl::serializeAs(N_), + scaleMode_ + ) + }; -typedef ForceAlignmentCriterion FACLoss; -} // namespace speech + typedef ForceAlignmentCriterion FACLoss; + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/FullConnectionCriterion.cpp b/flashlight/pkg/speech/criterion/FullConnectionCriterion.cpp index da04fc9..faa945b 100644 --- a/flashlight/pkg/speech/criterion/FullConnectionCriterion.cpp +++ b/flashlight/pkg/speech/criterion/FullConnectionCriterion.cpp @@ -15,22 +15,25 @@ namespace fl::pkg::speech { FullConnectionCriterion::FullConnectionCriterion( int N, - fl::lib::seq::CriterionScaleMode scalemode) - : N_(N), scaleMode_(scalemode) { - if (N_ <= 0) { - throw std::invalid_argument( - "FCC: Size of transition matrix is less than 0."); - } - auto transition = constant(0.0, {N_, N_}); - params_ = {transition}; + fl::lib::seq::CriterionScaleMode scalemode +) : N_(N), + scaleMode_(scalemode) { + if(N_ <= 0) { + throw std::invalid_argument( + "FCC: Size of transition matrix is less than 0." 
+ ); + } + auto transition = constant(0.0, {N_, N_}); + params_ = {transition}; } std::unique_ptr FullConnectionCriterion::clone() const { - throw std::runtime_error( - "Cloning is unimplemented in Module 'FullConnectionCriterion'"); + throw std::runtime_error( + "Cloning is unimplemented in Module 'FullConnectionCriterion'" + ); } std::string FullConnectionCriterion::prettyString() const { - return "FullConnectionCriterion"; + return "FullConnectionCriterion"; } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/FullConnectionCriterion.h b/flashlight/pkg/speech/criterion/FullConnectionCriterion.h index e76f148..165e153 100644 --- a/flashlight/pkg/speech/criterion/FullConnectionCriterion.h +++ b/flashlight/pkg/speech/criterion/FullConnectionCriterion.h @@ -13,37 +13,39 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { -class FullConnectionCriterion : public fl::BinaryModule { - public: - explicit FullConnectionCriterion( - int N, - fl::lib::seq::CriterionScaleMode scalemode = - fl::lib::seq::CriterionScaleMode::NONE); + class FullConnectionCriterion : public fl::BinaryModule { + public: + explicit FullConnectionCriterion( + int N, + fl::lib::seq::CriterionScaleMode scalemode = + fl::lib::seq::CriterionScaleMode::NONE + ); - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - fl::Variable forward(const fl::Variable& input, const fl::Variable& target) - override; + fl::Variable forward(const fl::Variable& input, const fl::Variable& target) + override; - std::string prettyString() const override; + std::string prettyString() const override; - private: - friend class AutoSegmentationCriterion; - FullConnectionCriterion() = default; + private: + friend class AutoSegmentationCriterion; + FullConnectionCriterion() = default; - int N_; - fl::lib::seq::CriterionScaleMode scaleMode_; + int N_; + fl::lib::seq::CriterionScaleMode scaleMode_; - FL_SAVE_LOAD_WITH_BASE( - fl::BinaryModule, - fl::serializeAs(N_), 
- scaleMode_) -}; + FL_SAVE_LOAD_WITH_BASE( + fl::BinaryModule, + fl::serializeAs(N_), + scaleMode_ + ) + }; -typedef FullConnectionCriterion FCCLoss; -} // namespace speech + typedef FullConnectionCriterion FCCLoss; + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/LinearSegmentationCriterion.h b/flashlight/pkg/speech/criterion/LinearSegmentationCriterion.h index 511c55b..fffeeb9 100644 --- a/flashlight/pkg/speech/criterion/LinearSegmentationCriterion.h +++ b/flashlight/pkg/speech/criterion/LinearSegmentationCriterion.h @@ -14,43 +14,45 @@ using fl::lib::seq::CriterionScaleMode; namespace fl { namespace pkg { -namespace speech { - -class LinearSegmentationCriterion : public AutoSegmentationCriterion { - public: - explicit LinearSegmentationCriterion( - int N, - CriterionScaleMode scaleMode = CriterionScaleMode::NONE) - : AutoSegmentationCriterion(N, scaleMode) {} - - std::unique_ptr clone() const override { - throw std::runtime_error( - "Cloning is unimplemented in Module 'LinearSegmentationCriterion'"); - } - - std::vector forward( - const std::vector& inputs) override { - if (inputs.size() != 2) { - throw std::invalid_argument("Invalid inputs size"); - } - const auto& input = inputs[0]; - const auto& target = inputs[1]; - return AutoSegmentationCriterion::forward( - {input, getLinearTarget(target, input.dim(1))}); - } - - std::string prettyString() const override { - return "LinearSegmentationCriterion"; - } - - private: - LinearSegmentationCriterion() = default; - - FL_SAVE_LOAD_WITH_BASE(AutoSegmentationCriterion) -}; - -using LinSegCriterion = LinearSegmentationCriterion; -} // namespace speech + namespace speech { + + class LinearSegmentationCriterion : public AutoSegmentationCriterion { + public: + explicit LinearSegmentationCriterion( + int N, + CriterionScaleMode scaleMode = CriterionScaleMode::NONE + ) : AutoSegmentationCriterion(N, scaleMode) {} + + std::unique_ptr clone() const override { + throw 
std::runtime_error( + "Cloning is unimplemented in Module 'LinearSegmentationCriterion'" + ); + } + + std::vector forward( + const std::vector& inputs + ) override { + if(inputs.size() != 2) { + throw std::invalid_argument("Invalid inputs size"); + } + const auto& input = inputs[0]; + const auto& target = inputs[1]; + return AutoSegmentationCriterion::forward( + {input, getLinearTarget(target, input.dim(1))}); + } + + std::string prettyString() const override { + return "LinearSegmentationCriterion"; + } + + private: + LinearSegmentationCriterion() = default; + + FL_SAVE_LOAD_WITH_BASE(AutoSegmentationCriterion) + }; + + using LinSegCriterion = LinearSegmentationCriterion; + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/Seq2SeqCriterion.cpp b/flashlight/pkg/speech/criterion/Seq2SeqCriterion.cpp index f8d6a97..f85546b 100644 --- a/flashlight/pkg/speech/criterion/Seq2SeqCriterion.cpp +++ b/flashlight/pkg/speech/criterion/Seq2SeqCriterion.cpp @@ -19,57 +19,57 @@ namespace fl::pkg::speech { namespace detail { -Seq2SeqState concatState(std::vector& stateVec) { - if (stateVec.empty()) { - throw std::runtime_error("Empty stateVec"); - } - - int nAttnRound = stateVec[0].hidden.size(); - Seq2SeqState newState(nAttnRound); - newState.step = stateVec[0].step; - newState.peakAttnPos = stateVec[0].peakAttnPos; - newState.isValid = stateVec[0].isValid; - - std::vector alphaVec; - std::vector> hiddenVec(nAttnRound); - std::vector summaryVec; - for (auto& state : stateVec) { - if (state.step != newState.step) { - throw std::runtime_error("step unmatched"); - } else if (state.isValid != newState.isValid) { - throw std::runtime_error("isValid unmatched"); - } - alphaVec.push_back(state.alpha); - for (int i = 0; i < nAttnRound; i++) { - hiddenVec[i].push_back(state.hidden[i]); + Seq2SeqState concatState(std::vector& stateVec) { + if(stateVec.empty()) { + throw std::runtime_error("Empty stateVec"); + } + + int nAttnRound = 
stateVec[0].hidden.size(); + Seq2SeqState newState(nAttnRound); + newState.step = stateVec[0].step; + newState.peakAttnPos = stateVec[0].peakAttnPos; + newState.isValid = stateVec[0].isValid; + + std::vector alphaVec; + std::vector> hiddenVec(nAttnRound); + std::vector summaryVec; + for(auto& state : stateVec) { + if(state.step != newState.step) { + throw std::runtime_error("step unmatched"); + } else if(state.isValid != newState.isValid) { + throw std::runtime_error("isValid unmatched"); + } + alphaVec.push_back(state.alpha); + for(int i = 0; i < nAttnRound; i++) { + hiddenVec[i].push_back(state.hidden[i]); + } + summaryVec.push_back(state.summary); + } + + newState.alpha = concatenate(alphaVec, 2); + for(int i = 0; i < nAttnRound; i++) { + newState.hidden[i] = concatenate(hiddenVec[i], 1); + } + newState.summary = concatenate(summaryVec, 2); + return newState; } - summaryVec.push_back(state.summary); - } - - newState.alpha = concatenate(alphaVec, 2); - for (int i = 0; i < nAttnRound; i++) { - newState.hidden[i] = concatenate(hiddenVec[i], 1); - } - newState.summary = concatenate(summaryVec, 2); - return newState; -} -Seq2SeqState selectState(Seq2SeqState& state, int batchIdx) { - int nAttnRound = state.hidden.size(); - Seq2SeqState newState(nAttnRound); - newState.step = state.step; - newState.peakAttnPos = state.peakAttnPos; - newState.isValid = state.isValid; - newState.alpha = - state.alpha(fl::span, fl::span, fl::range(batchIdx, batchIdx + 1)); - newState.summary = - state.summary(fl::span, fl::span, fl::range(batchIdx, batchIdx + 1)); - for (int i = 0; i < nAttnRound; i++) { - newState.hidden[i] = - state.hidden[i](fl::span, fl::range(batchIdx, batchIdx + 1)); - } - return newState; -} + Seq2SeqState selectState(Seq2SeqState& state, int batchIdx) { + int nAttnRound = state.hidden.size(); + Seq2SeqState newState(nAttnRound); + newState.step = state.step; + newState.peakAttnPos = state.peakAttnPos; + newState.isValid = state.isValid; + newState.alpha = + 
state.alpha(fl::span, fl::span, fl::range(batchIdx, batchIdx + 1)); + newState.summary = + state.summary(fl::span, fl::span, fl::range(batchIdx, batchIdx + 1)); + for(int i = 0; i < nAttnRound; i++) { + newState.hidden[i] = + state.hidden[i](fl::span, fl::range(batchIdx, batchIdx + 1)); + } + return newState; + } } // namespace detail Seq2SeqCriterion::Seq2SeqCriterion( @@ -88,46 +88,55 @@ Seq2SeqCriterion::Seq2SeqCriterion( double gumbelTemperature /* = 1.0 */, int nRnnLayer /* = 1 */, int nAttnRound /* = 1 */, - float dropOut /* = 0.0 */) - : eos_(eos), - pad_(pad), - maxDecoderOutputLen_(maxDecoderOutputLen), - window_(window), - trainWithWindow_(trainWithWindow), - pctTeacherForcing_(pctTeacherForcing), - labelSmooth_(labelSmooth), - inputFeeding_(inputFeeding), - nClass_(nClass), - samplingStrategy_(samplingStrategy), - gumbelTemperature_(gumbelTemperature), - nAttnRound_(nAttnRound) { - // 1. Embedding - add(std::make_shared(hiddenDim, nClass_)); - - // 2. RNN - for (int i = 0; i < nAttnRound_; i++) { - add(std::make_shared( - hiddenDim, hiddenDim, nRnnLayer, RnnMode::GRU, false, dropOut)); - } - - // 3. Linear - add(std::make_shared(hiddenDim, nClass_)); - // FIXME: Having a linear layer in between RNN and attention is only for - // backward compatibility. - - // 4. Attention - for (int i = 0; i < nAttnRound_; i++) { - add(attentions[i]); - } - - // 5. Initial hidden state - params_.push_back(fl::uniform(Shape{hiddenDim}, -1e-1, 1e-1)); - setUseSequentialDecoder(); + float dropOut /* = 0.0 */ +) : eos_(eos), + pad_(pad), + maxDecoderOutputLen_(maxDecoderOutputLen), + window_(window), + trainWithWindow_(trainWithWindow), + pctTeacherForcing_(pctTeacherForcing), + labelSmooth_(labelSmooth), + inputFeeding_(inputFeeding), + nClass_(nClass), + samplingStrategy_(samplingStrategy), + gumbelTemperature_(gumbelTemperature), + nAttnRound_(nAttnRound) { + // 1. Embedding + add(std::make_shared(hiddenDim, nClass_)); + + // 2. 
RNN + for(int i = 0; i < nAttnRound_; i++) { + add( + std::make_shared( + hiddenDim, + hiddenDim, + nRnnLayer, + RnnMode::GRU, + false, + dropOut + ) + ); + } + + // 3. Linear + add(std::make_shared(hiddenDim, nClass_)); + // FIXME: Having a linear layer in between RNN and attention is only for + // backward compatibility. + + // 4. Attention + for(int i = 0; i < nAttnRound_; i++) { + add(attentions[i]); + } + + // 5. Initial hidden state + params_.push_back(fl::uniform(Shape{hiddenDim}, -1e-1, 1e-1)); + setUseSequentialDecoder(); } std::unique_ptr Seq2SeqCriterion::clone() const { - throw std::runtime_error( - "Cloning is unimplemented in Module 'Seq2SeqCriterion'"); + throw std::runtime_error( + "Cloning is unimplemented in Module 'Seq2SeqCriterion'" + ); } /** @@ -140,210 +149,234 @@ std::unique_ptr Seq2SeqCriterion::clone() const { */ std::vector Seq2SeqCriterion::forward( - const std::vector& inputs) { - if (inputs.size() < 2 || (inputs.size() > 4)) { - throw std::invalid_argument( - "Invalid inputs size; Seq2Seq criterion takes input, target, inputSizes [optional]"); - } - const auto& input = inputs[0]; - const auto& target = inputs[1]; - const auto& inputSizes = - inputs.size() == 2 ? Tensor() : inputs[2].tensor(); // 1 x B - const auto& targetSizes = - inputs.size() == 3 ? 
Tensor() : inputs[3].tensor(); // 1 x B - - Variable out, alpha; - if (useSequentialDecoder_) { - std::tie(out, alpha) = decoder(input, target, inputSizes, targetSizes); - } else { - std::tie(out, alpha) = - vectorizedDecoder(input, target, inputSizes, targetSizes); - } - - out = logSoftmax(out, 0); // C x U x B - - auto losses = moddims( - sum(categoricalCrossEntropy(out, target, ReduceMode::NONE, pad_), {0}), - {-1}); - if (train_ && labelSmooth_ > 0) { - size_t nClass = out.dim(0); - auto targetTiled = fl::tile( - fl::reshape(target.tensor(), {1, target.dim(0), target.dim(1)}), - {static_cast(nClass)}); - out = applySeq2SeqMask(out, targetTiled, pad_); - auto smoothLoss = moddims(sum(out, {0, 1}), {-1}); - losses = (1 - labelSmooth_) * losses - (labelSmooth_ / nClass) * smoothLoss; - } - - return {losses, out}; + const std::vector& inputs +) { + if(inputs.size() < 2 || (inputs.size() > 4)) { + throw std::invalid_argument( + "Invalid inputs size; Seq2Seq criterion takes input, target, inputSizes [optional]" + ); + } + const auto& input = inputs[0]; + const auto& target = inputs[1]; + const auto& inputSizes = + inputs.size() == 2 ? Tensor() : inputs[2].tensor(); // 1 x B + const auto& targetSizes = + inputs.size() == 3 ? 
Tensor() : inputs[3].tensor(); // 1 x B + + Variable out, alpha; + if(useSequentialDecoder_) { + std::tie(out, alpha) = decoder(input, target, inputSizes, targetSizes); + } else { + std::tie(out, alpha) = + vectorizedDecoder(input, target, inputSizes, targetSizes); + } + + out = logSoftmax(out, 0); // C x U x B + + auto losses = moddims( + sum(categoricalCrossEntropy(out, target, ReduceMode::NONE, pad_), {0}), + {-1} + ); + if(train_ && labelSmooth_ > 0) { + size_t nClass = out.dim(0); + auto targetTiled = fl::tile( + fl::reshape(target.tensor(), {1, target.dim(0), target.dim(1)}), + {static_cast(nClass)} + ); + out = applySeq2SeqMask(out, targetTiled, pad_); + auto smoothLoss = moddims(sum(out, {0, 1}), {-1}); + losses = (1 - labelSmooth_) * losses - (labelSmooth_ / nClass) * smoothLoss; + } + + return {losses, out}; } std::pair Seq2SeqCriterion::vectorizedDecoder( const Variable& input, const Variable& target, const Tensor& inputSizes, - const Tensor& targetSizes) { - if (target.ndim() != 2) { - throw std::invalid_argument( - "Seq2SeqCriterion::vectorizedDecoder: " - "target expects to be shape {U, B}"); - } - int U = target.dim(0); - int B = target.dim(1); - int T = input.dim(1); - - auto hy = tile(startEmbedding(), {1, 1, B}); // H x 1 x B - - if (U > 1) { - // Slice off eos - auto y = target(fl::range(0, U - 1), fl::span); - if (train_) { - if (samplingStrategy_ == fl::pkg::speech::kModelSampling) { - throw std::logic_error( - "vectorizedDecoder does not support model sampling"); - } else if (samplingStrategy_ == fl::pkg::speech::kRandSampling) { - auto mask = Variable( - (fl::rand(y.shape()) * 100 <= pctTeacherForcing_).astype(y.type()), - false); - auto samples = Variable( - (fl::rand(y.shape()) * (nClass_ - 1)).astype(y.type()), false); - - y = mask * y + (1 - mask) * samples; - } + const Tensor& targetSizes +) { + if(target.ndim() != 2) { + throw std::invalid_argument( + "Seq2SeqCriterion::vectorizedDecoder: " + "target expects to be shape {U, B}" + ); + } 
+ int U = target.dim(0); + int B = target.dim(1); + int T = input.dim(1); + + auto hy = tile(startEmbedding(), {1, 1, B}); // H x 1 x B + + if(U > 1) { + // Slice off eos + auto y = target(fl::range(0, U - 1), fl::span); + if(train_) { + if(samplingStrategy_ == fl::pkg::speech::kModelSampling) { + throw std::logic_error( + "vectorizedDecoder does not support model sampling" + ); + } else if(samplingStrategy_ == fl::pkg::speech::kRandSampling) { + auto mask = Variable( + (fl::rand(y.shape()) * 100 <= pctTeacherForcing_).astype(y.type()), + false + ); + auto samples = Variable( + (fl::rand(y.shape()) * (nClass_ - 1)).astype(y.type()), + false + ); + + y = mask * y + (1 - mask) * samples; + } + } + + auto yEmbed = embedding()->forward(y); + hy = concatenate({hy, yEmbed}, 1); // H x U x B } - auto yEmbed = embedding()->forward(y); - hy = concatenate({hy, yEmbed}, 1); // H x U x B - } + Variable alpha, summaries; + for(int i = 0; i < nAttnRound_; i++) { + hy = fl::transpose(hy, {0, 2, 1}); // H x U x B -> H x B x U + hy = decodeRNN(i)->forward(hy); + hy = fl::transpose(hy, {0, 2, 1}); // H x B x U -> H x U x B - Variable alpha, summaries; - for (int i = 0; i < nAttnRound_; i++) { - hy = fl::transpose(hy, {0, 2, 1}); // H x U x B -> H x B x U - hy = decodeRNN(i)->forward(hy); - hy = fl::transpose(hy, {0, 2, 1}); // H x B x U -> H x U x B + Variable windowWeight; + if(window_ && (!train_ || trainWithWindow_)) { + windowWeight = + window_->computeVectorizedWindow(U, T, B, inputSizes, targetSizes); + } - Variable windowWeight; - if (window_ && (!train_ || trainWithWindow_)) { - windowWeight = - window_->computeVectorizedWindow(U, T, B, inputSizes, targetSizes); + std::tie(alpha, summaries) = attention(i)->forward( + hy, + input, + Variable(), // vectorizedDecoder does not support prev_attn input + windowWeight, + fl::noGrad(inputSizes) + ); + hy = hy + summaries; } - std::tie(alpha, summaries) = attention(i)->forward( - hy, - input, - Variable(), // vectorizedDecoder does 
not support prev_attn input - windowWeight, - fl::noGrad(inputSizes)); - hy = hy + summaries; - } - - auto out = linearOut()->forward(hy); // C x U x B - return std::make_pair(out, alpha); + auto out = linearOut()->forward(hy); // C x U x B + return std::make_pair(out, alpha); } std::pair Seq2SeqCriterion::decoder( const Variable& input, const Variable& target, const Tensor& inputSizes, - const Tensor& targetSizes) { - int U = target.dim(0); - - std::vector outvec; - std::vector alphaVec; - Seq2SeqState state(nAttnRound_); - Variable y; - for (int u = 0; u < U; u++) { - Variable ox; - std::tie(ox, state) = - decodeStep(input, y, state, inputSizes, targetSizes, U); - - if (!train_) { - y = target(fl::range(u, u + 1), fl::span); - } else if (samplingStrategy_ == fl::pkg::speech::kGumbelSampling) { - double eps = 1e-7; - auto gb = -log(-log((1 - 2 * eps) * fl::rand(ox.shape()) + eps)); - ox = logSoftmax((ox + Variable(gb, false)) / gumbelTemperature_, 0); - y = Variable(exp(ox).tensor(), false); - } else if (fl::all(fl::rand({1}) * 100 <= fl::full({1}, pctTeacherForcing_)) - .asScalar()) { - y = target(fl::range(u, u + 1), fl::span); - } else if (samplingStrategy_ == fl::pkg::speech::kModelSampling) { - Tensor maxIdx, maxValues; - fl::max(maxValues, maxIdx, ox.tensor(), 0); - y = Variable(maxIdx, false); - } else if (samplingStrategy_ == fl::pkg::speech::kRandSampling) { - y = Variable( - (fl::rand({1, target.dim(1)}) * (nClass_ - 1)).astype(fl::dtype::s32), - false); - } else { - throw std::invalid_argument("Invalid sampling strategy"); - } + const Tensor& targetSizes +) { + int U = target.dim(0); + + std::vector outvec; + std::vector alphaVec; + Seq2SeqState state(nAttnRound_); + Variable y; + for(int u = 0; u < U; u++) { + Variable ox; + std::tie(ox, state) = + decodeStep(input, y, state, inputSizes, targetSizes, U); + + if(!train_) { + y = target(fl::range(u, u + 1), fl::span); + } else if(samplingStrategy_ == fl::pkg::speech::kGumbelSampling) { + double eps = 
1e-7; + auto gb = -log(-log((1 - 2 * eps) * fl::rand(ox.shape()) + eps)); + ox = logSoftmax((ox + Variable(gb, false)) / gumbelTemperature_, 0); + y = Variable(exp(ox).tensor(), false); + } else if( + fl::all(fl::rand({1}) * 100 <= fl::full({1}, pctTeacherForcing_)) + .asScalar() + ) { + y = target(fl::range(u, u + 1), fl::span); + } else if(samplingStrategy_ == fl::pkg::speech::kModelSampling) { + Tensor maxIdx, maxValues; + fl::max(maxValues, maxIdx, ox.tensor(), 0); + y = Variable(maxIdx, false); + } else if(samplingStrategy_ == fl::pkg::speech::kRandSampling) { + y = Variable( + (fl::rand({1, target.dim(1)}) * (nClass_ - 1)).astype(fl::dtype::s32), + false + ); + } else { + throw std::invalid_argument("Invalid sampling strategy"); + } - outvec.push_back(ox); - alphaVec.push_back(state.alpha); - } + outvec.push_back(ox); + alphaVec.push_back(state.alpha); + } - auto out = concatenate(outvec, 1); // C x U x B - auto alpha = concatenate(alphaVec, 0); // U x T x B + auto out = concatenate(outvec, 1); // C x U x B + auto alpha = concatenate(alphaVec, 0); // U x T x B - return std::make_pair(out, alpha); + return std::make_pair(out, alpha); } Tensor Seq2SeqCriterion::viterbiPath( const Tensor& input, - const Tensor& inputSizes /* = Tensor() */) { - return viterbiPathBase(input, inputSizes, false).first; + const Tensor& inputSizes /* = Tensor() */ +) { + return viterbiPathBase(input, inputSizes, false).first; } std::pair Seq2SeqCriterion::viterbiPathBase( const Tensor& input, const Tensor& inputSizes, - bool saveAttn) { - // NB: xEncoded has to be with batchsize 1 - bool wasTrain = train_; - eval(); - std::vector maxPath; - std::vector alphaVec; - Variable alpha; - Seq2SeqState state(nAttnRound_); - Variable y, ox; - Tensor maxIdx, maxValues; - int pred; - for (int u = 0; u < maxDecoderOutputLen_; u++) { - std::tie(ox, state) = decodeStep( - Variable(input, false), y, state, inputSizes, Tensor(), input.dim(1)); - fl::max(maxValues, maxIdx, ox.tensor(), 0); - pred = 
maxIdx.asScalar(); - if (saveAttn) { - alphaVec.push_back(state.alpha); + bool saveAttn +) { + // NB: xEncoded has to be with batchsize 1 + bool wasTrain = train_; + eval(); + std::vector maxPath; + std::vector alphaVec; + Variable alpha; + Seq2SeqState state(nAttnRound_); + Variable y, ox; + Tensor maxIdx, maxValues; + int pred; + for(int u = 0; u < maxDecoderOutputLen_; u++) { + std::tie(ox, state) = decodeStep( + Variable(input, false), + y, + state, + inputSizes, + Tensor(), + input.dim(1) + ); + fl::max(maxValues, maxIdx, ox.tensor(), 0); + pred = maxIdx.asScalar(); + if(saveAttn) { + alphaVec.push_back(state.alpha); + } + + if(pred == eos_) { + break; + } + y = constant(pred, {1}, fl::dtype::s32, false); + maxPath.push_back(pred); + } + if(saveAttn) { + alpha = concatenate(alphaVec, 0); } - if (pred == eos_) { - break; + if(wasTrain) { + train(); } - y = constant(pred, {1}, fl::dtype::s32, false); - maxPath.push_back(pred); - } - if (saveAttn) { - alpha = concatenate(alphaVec, 0); - } - - if (wasTrain) { - train(); - } - Tensor vPath = maxPath.empty() ? Tensor() : Tensor::fromVector(maxPath); - return std::make_pair(vPath, alpha); + Tensor vPath = maxPath.empty() ? 
Tensor() : Tensor::fromVector(maxPath); + return std::make_pair(vPath, alpha); } std::vector Seq2SeqCriterion::beamPath( const Tensor& input, const Tensor& inputSizes, - int beamSize /* = 10 */) { - std::vector beam; - beam.emplace_back(); - auto beamPaths = - beamSearch(input, inputSizes, beam, beamSize, maxDecoderOutputLen_); - return beamPaths[0].path; + int beamSize /* = 10 */ +) { + std::vector beam; + beam.emplace_back(); + auto beamPaths = + beamSearch(input, inputSizes, beam, beamSize, maxDecoderOutputLen_); + return beamPaths[0].path; } // beam are candidates that need to be extended @@ -352,112 +385,126 @@ std::vector Seq2SeqCriterion::beamSearch( const Tensor& inputSizes, // 1 x B std::vector beam, int beamSize = 10, - int maxLen = 200) { - bool wasTrain = train_; - eval(); - - std::vector complete; - std::vector newBeam; - auto cmpfn = [](Seq2SeqCriterion::CandidateHypo& lhs, - Seq2SeqCriterion::CandidateHypo& rhs) { - return lhs.score > rhs.score; - }; - - for (int l = 0; l < maxLen; l++) { - newBeam.resize(0); - - std::vector prevYVec; - std::vector prevStateVec; - std::vector prevScoreVec; - for (auto& hypo : beam) { - Variable y; - if (!hypo.path.empty()) { - y = constant(hypo.path.back(), {1}, fl::dtype::s32, false); - } - prevYVec.push_back(y); - prevStateVec.push_back(hypo.state); - prevScoreVec.push_back(hypo.score); - } - auto prevY = concatenate(prevYVec, 1); // 1 x B - auto prevState = detail::concatState(prevStateVec); - int B = prevY.ndim() < 2 ? 
1 : prevY.dim(1); - - Variable ox; - Seq2SeqState state; - // do proper cast of input size to batch size - // because we have beam now for the input - auto tiledInputSizes = fl::tile(inputSizes, {1, B}); - std::tie(ox, state) = decodeStep( - Variable(input, false), - prevY, - prevState, - tiledInputSizes, - Tensor(), - input.dim(1)); - ox = logSoftmax(ox, 0); // C x 1 x B - ox = fl::reorder(ox, {0, 2, 1}); - - auto scoreArr = Tensor::fromBuffer( - {1, static_cast(beam.size()), 1}, - prevScoreVec.data(), - MemoryLocation::Host); - scoreArr = fl::tile(scoreArr, {ox.dim(0)}); - - scoreArr = scoreArr + ox.tensor(); // C x B - scoreArr = scoreArr.flatten(); // column-first - auto scoreVec = scoreArr.toHostVector(); - - std::vector indices(scoreVec.size()); - std::iota(indices.begin(), indices.end(), 0); - std::partial_sort( - indices.begin(), - indices.begin() + - std::min(2 * beamSize, static_cast(scoreVec.size())), - indices.end(), - [&scoreVec](size_t i1, size_t i2) { - return scoreVec[i1] > scoreVec[i2]; - }); - - int nClass = ox.dim(0); - for (int j = 0; j < indices.size(); j++) { - int hypIdx = indices[j] / nClass; - int clsIdx = indices[j] % nClass; - std::vector path_(beam[hypIdx].path); - path_.push_back(clsIdx); - if (j < beamSize && clsIdx == eos_) { - path_.pop_back(); - complete.emplace_back( - scoreVec[indices[j]], path_, detail::selectState(state, hypIdx)); - } else if (clsIdx != eos_) { - newBeam.emplace_back( - scoreVec[indices[j]], path_, detail::selectState(state, hypIdx)); - } - if (newBeam.size() >= beamSize) { - break; - } - } - beam.resize(newBeam.size()); - beam = std::move(newBeam); - - if (complete.size() >= beamSize) { - std::partial_sort( - complete.begin(), complete.begin() + beamSize, complete.end(), cmpfn); - complete.resize(beamSize); - - // if lowest score in complete is better than best future hypo - // then its not possible for any future hypothesis to replace existing - // hypothesises in complete. 
- if (complete.back().score > beam[0].score) { - break; - } + int maxLen = 200 +) { + bool wasTrain = train_; + eval(); + + std::vector complete; + std::vector newBeam; + auto cmpfn = [](Seq2SeqCriterion::CandidateHypo& lhs, + Seq2SeqCriterion::CandidateHypo& rhs) { + return lhs.score > rhs.score; + }; + + for(int l = 0; l < maxLen; l++) { + newBeam.resize(0); + + std::vector prevYVec; + std::vector prevStateVec; + std::vector prevScoreVec; + for(auto& hypo : beam) { + Variable y; + if(!hypo.path.empty()) { + y = constant(hypo.path.back(), {1}, fl::dtype::s32, false); + } + prevYVec.push_back(y); + prevStateVec.push_back(hypo.state); + prevScoreVec.push_back(hypo.score); + } + auto prevY = concatenate(prevYVec, 1); // 1 x B + auto prevState = detail::concatState(prevStateVec); + int B = prevY.ndim() < 2 ? 1 : prevY.dim(1); + + Variable ox; + Seq2SeqState state; + // do proper cast of input size to batch size + // because we have beam now for the input + auto tiledInputSizes = fl::tile(inputSizes, {1, B}); + std::tie(ox, state) = decodeStep( + Variable(input, false), + prevY, + prevState, + tiledInputSizes, + Tensor(), + input.dim(1) + ); + ox = logSoftmax(ox, 0); // C x 1 x B + ox = fl::reorder(ox, {0, 2, 1}); + + auto scoreArr = Tensor::fromBuffer( + {1, static_cast(beam.size()), 1}, + prevScoreVec.data(), + MemoryLocation::Host + ); + scoreArr = fl::tile(scoreArr, {ox.dim(0)}); + + scoreArr = scoreArr + ox.tensor(); // C x B + scoreArr = scoreArr.flatten(); // column-first + auto scoreVec = scoreArr.toHostVector(); + + std::vector indices(scoreVec.size()); + std::iota(indices.begin(), indices.end(), 0); + std::partial_sort( + indices.begin(), + indices.begin() + + std::min(2 * beamSize, static_cast(scoreVec.size())), + indices.end(), + [&scoreVec](size_t i1, size_t i2) { + return scoreVec[i1] > scoreVec[i2]; + } + ); + + int nClass = ox.dim(0); + for(int j = 0; j < indices.size(); j++) { + int hypIdx = indices[j] / nClass; + int clsIdx = indices[j] % nClass; + 
std::vector path_(beam[hypIdx].path); + path_.push_back(clsIdx); + if(j < beamSize && clsIdx == eos_) { + path_.pop_back(); + complete.emplace_back( + scoreVec[indices[j]], + path_, + detail::selectState(state, hypIdx) + ); + } else if(clsIdx != eos_) { + newBeam.emplace_back( + scoreVec[indices[j]], + path_, + detail::selectState(state, hypIdx) + ); + } + if(newBeam.size() >= beamSize) { + break; + } + } + beam.resize(newBeam.size()); + beam = std::move(newBeam); + + if(complete.size() >= beamSize) { + std::partial_sort( + complete.begin(), + complete.begin() + beamSize, + complete.end(), + cmpfn + ); + complete.resize(beamSize); + + // if lowest score in complete is better than best future hypo + // then its not possible for any future hypothesis to replace existing + // hypothesises in complete. + if(complete.back().score > beam[0].score) { + break; + } + } } - } - if (wasTrain) { - train(); - } + if(wasTrain) { + train(); + } - return complete.empty() ? beam : complete; + return complete.empty() ? 
beam : complete; } std::pair Seq2SeqCriterion::decodeStep( @@ -466,178 +513,191 @@ std::pair Seq2SeqCriterion::decodeStep( const Seq2SeqState& inState, const Tensor& inputSizes, const Tensor& targetSizes, - const int maxDecoderSteps) const { - if (xEncoded.ndim() != 3) { - throw std::invalid_argument( - "Seq2SeqCriterion::decodeStep: " - "expected xEncoded to have at least three dimensions"); - } - - Variable hy; - if (y.isEmpty()) { - hy = tile(startEmbedding(), {1, 1, static_cast(xEncoded.dim(2))}); - } else if (train_ && samplingStrategy_ == fl::pkg::speech::kGumbelSampling) { - hy = linear(y, embedding()->param(0)); - } else { - hy = embedding()->forward(y); - } - - if (inputFeeding_ && !y.isEmpty()) { - hy = hy + moddims(inState.summary, hy.shape()); - } - hy = moddims(hy, {hy.dim(0), -1}); // H x B - - Seq2SeqState outState(nAttnRound_); - outState.step = inState.step + 1; - - Variable summaries; - for (int i = 0; i < nAttnRound_; i++) { - hy = moddims(hy, {hy.dim(0), -1}); // H x 1 x B -> H x B - std::tie(hy, outState.hidden[i]) = - decodeRNN(i)->forward(hy, inState.hidden[i]); - hy = moddims(hy, {hy.dim(0), 1, hy.dim(1)}); // H x B -> H x 1 x B - - Variable windowWeight; - // because of the beam search batchsize can be - // different for xEncoded and y (xEncoded batch = 1 and y batch = beam - // size) - int batchsize = - y.isEmpty() ? xEncoded.dim(2) : (y.ndim() < 2 ? 
1 : y.dim(1)); - if (window_ && (!train_ || trainWithWindow_)) { - // TODO fix for softpretrain where target size is used - // for now force to xEncoded.dim(1) - windowWeight = window_->computeWindow( - inState.alpha, - inState.step, - maxDecoderSteps, - xEncoded.dim(1), - batchsize, - inputSizes, - targetSizes); + const int maxDecoderSteps +) const { + if(xEncoded.ndim() != 3) { + throw std::invalid_argument( + "Seq2SeqCriterion::decodeStep: " + "expected xEncoded to have at least three dimensions" + ); + } + + Variable hy; + if(y.isEmpty()) { + hy = tile(startEmbedding(), {1, 1, static_cast(xEncoded.dim(2))}); + } else if(train_ && samplingStrategy_ == fl::pkg::speech::kGumbelSampling) { + hy = linear(y, embedding()->param(0)); + } else { + hy = embedding()->forward(y); + } + + if(inputFeeding_ && !y.isEmpty()) { + hy = hy + moddims(inState.summary, hy.shape()); } - std::tie(outState.alpha, summaries) = attention(i)->forward( - hy, xEncoded, inState.alpha, windowWeight, fl::noGrad(inputSizes)); - hy = hy + summaries; - } - outState.summary = summaries; - - auto out = linearOut()->forward(hy); // C x 1 x B - return std::make_pair(out, outState); + hy = moddims(hy, {hy.dim(0), -1}); // H x B + + Seq2SeqState outState(nAttnRound_); + outState.step = inState.step + 1; + + Variable summaries; + for(int i = 0; i < nAttnRound_; i++) { + hy = moddims(hy, {hy.dim(0), -1}); // H x 1 x B -> H x B + std::tie(hy, outState.hidden[i]) = + decodeRNN(i)->forward(hy, inState.hidden[i]); + hy = moddims(hy, {hy.dim(0), 1, hy.dim(1)}); // H x B -> H x 1 x B + + Variable windowWeight; + // because of the beam search batchsize can be + // different for xEncoded and y (xEncoded batch = 1 and y batch = beam + // size) + int batchsize = + y.isEmpty() ? xEncoded.dim(2) : (y.ndim() < 2 ? 
1 : y.dim(1)); + if(window_ && (!train_ || trainWithWindow_)) { + // TODO fix for softpretrain where target size is used + // for now force to xEncoded.dim(1) + windowWeight = window_->computeWindow( + inState.alpha, + inState.step, + maxDecoderSteps, + xEncoded.dim(1), + batchsize, + inputSizes, + targetSizes + ); + } + std::tie(outState.alpha, summaries) = attention(i)->forward( + hy, + xEncoded, + inState.alpha, + windowWeight, + fl::noGrad(inputSizes) + ); + hy = hy + summaries; + } + outState.summary = summaries; + + auto out = linearOut()->forward(hy); // C x 1 x B + return std::make_pair(out, outState); } -std::pair>, std::vector> -Seq2SeqCriterion::decodeBatchStep( +std::pair>, std::vector> Seq2SeqCriterion::decodeBatchStep( const fl::Variable& xEncoded, std::vector& ys, const std::vector& inStates, const int attentionThreshold, - const float smoothingTemperature) const { - // NB: xEncoded has to be with batchsize 1 - int batchSize = ys.size(); - std::vector statesVector(batchSize); - - // Batch Ys - for (int i = 0; i < batchSize; i++) { - if (ys[i].isEmpty()) { - ys[i] = startEmbedding(); - } else { - ys[i] = embedding()->forward(ys[i]); - if (inputFeeding_) { - ys[i] = ys[i] + moddims(inStates[i]->summary, ys[i].shape()); - } - } - ys[i] = moddims(ys[i], {ys[i].dim(0), -1}); - } - Variable yBatched = concatenate(ys, 1); // H x B - - std::vector outstates(batchSize); - for (int i = 0; i < batchSize; i++) { - outstates[i] = std::make_shared(nAttnRound_); - outstates[i]->step = inStates[i]->step + 1; - } - Variable outStateBatched; - - for (int n = 0; n < nAttnRound_; n++) { - /* (1) RNN forward */ - if (inStates[0]->hidden[n].isEmpty()) { - std::tie(yBatched, outStateBatched) = - decodeRNN(n)->forward(yBatched, Variable()); - } else { - for (int i = 0; i < batchSize; i++) { - statesVector[i] = inStates[i]->hidden[n]; - } - Variable inStateHiddenBatched = - concatenate(statesVector, 1).asContiguous(); - std::tie(yBatched, outStateBatched) = - 
decodeRNN(n)->forward(yBatched, inStateHiddenBatched); + const float smoothingTemperature +) const { + // NB: xEncoded has to be with batchsize 1 + int batchSize = ys.size(); + std::vector statesVector(batchSize); + + // Batch Ys + for(int i = 0; i < batchSize; i++) { + if(ys[i].isEmpty()) { + ys[i] = startEmbedding(); + } else { + ys[i] = embedding()->forward(ys[i]); + if(inputFeeding_) { + ys[i] = ys[i] + moddims(inStates[i]->summary, ys[i].shape()); + } + } + ys[i] = moddims(ys[i], {ys[i].dim(0), -1}); } + Variable yBatched = concatenate(ys, 1); // H x B - for (int i = 0; i < batchSize; i++) { - outstates[i]->hidden[n] = outStateBatched(fl::span, fl::range(i, i + 1)); + std::vector outstates(batchSize); + for(int i = 0; i < batchSize; i++) { + outstates[i] = std::make_shared(nAttnRound_); + outstates[i]->step = inStates[i]->step + 1; } + Variable outStateBatched; + + for(int n = 0; n < nAttnRound_; n++) { + /* (1) RNN forward */ + if(inStates[0]->hidden[n].isEmpty()) { + std::tie(yBatched, outStateBatched) = + decodeRNN(n)->forward(yBatched, Variable()); + } else { + for(int i = 0; i < batchSize; i++) { + statesVector[i] = inStates[i]->hidden[n]; + } + Variable inStateHiddenBatched = + concatenate(statesVector, 1).asContiguous(); + std::tie(yBatched, outStateBatched) = + decodeRNN(n)->forward(yBatched, inStateHiddenBatched); + } + + for(int i = 0; i < batchSize; i++) { + outstates[i]->hidden[n] = outStateBatched(fl::span, fl::range(i, i + 1)); + } + + /* (2) Attention forward */ + if(window_ && (!train_ || trainWithWindow_)) { + throw std::runtime_error( + "Batched decoding does not support models with window" + ); + } - /* (2) Attention forward */ - if (window_ && (!train_ || trainWithWindow_)) { - throw std::runtime_error( - "Batched decoding does not support models with window"); + Variable summaries, alphaBatched; + // NB: + // - Third Variable is set to empty since no attention use it. 
+ // - Only ContentAttention is supported + std::tie(alphaBatched, summaries) = + attention(n)->forward(yBatched, xEncoded, Variable(), Variable()); + alphaBatched = fl::transpose(alphaBatched, {1, 0, 2}); // B x T -> T x B + yBatched = yBatched + summaries; // H x B + + Tensor bestpath, maxvalues; + fl::max(maxvalues, bestpath, alphaBatched.tensor(), 0); + std::vector maxIdx = bestpath.toHostVector(); + for(int i = 0; i < batchSize; i++) { + outstates[i]->peakAttnPos = maxIdx[i]; + // TODO: std::abs maybe unnecessary + outstates[i]->isValid = + std::abs(outstates[i]->peakAttnPos - inStates[i]->peakAttnPos) + <= attentionThreshold; + outstates[i]->alpha = alphaBatched(fl::span, fl::range(i, i + 1)); + outstates[i]->summary = yBatched(fl::span, fl::range(i, i + 1)); + } } - Variable summaries, alphaBatched; - // NB: - // - Third Variable is set to empty since no attention use it. - // - Only ContentAttention is supported - std::tie(alphaBatched, summaries) = - attention(n)->forward(yBatched, xEncoded, Variable(), Variable()); - alphaBatched = fl::transpose(alphaBatched, {1, 0, 2}); // B x T -> T x B - yBatched = yBatched + summaries; // H x B - - Tensor bestpath, maxvalues; - fl::max(maxvalues, bestpath, alphaBatched.tensor(), 0); - std::vector maxIdx = bestpath.toHostVector(); - for (int i = 0; i < batchSize; i++) { - outstates[i]->peakAttnPos = maxIdx[i]; - // TODO: std::abs maybe unnecessary - outstates[i]->isValid = - std::abs(outstates[i]->peakAttnPos - inStates[i]->peakAttnPos) <= - attentionThreshold; - outstates[i]->alpha = alphaBatched(fl::span, fl::range(i, i + 1)); - outstates[i]->summary = yBatched(fl::span, fl::range(i, i + 1)); + /* (3) Linear forward */ + auto outBatched = linearOut()->forward(yBatched); + outBatched = logSoftmax(outBatched / smoothingTemperature, 0); + std::vector> out(batchSize); + for(int i = 0; i < batchSize; i++) { + out[i] = outBatched(fl::span, fl::range(i, i + 1)) + .tensor() + .toHostVector(); } - } - - /* (3) Linear forward 
*/ - auto outBatched = linearOut()->forward(yBatched); - outBatched = logSoftmax(outBatched / smoothingTemperature, 0); - std::vector> out(batchSize); - for (int i = 0; i < batchSize; i++) { - out[i] = outBatched(fl::span, fl::range(i, i + 1)) - .tensor() - .toHostVector(); - } - - return std::make_pair(out, outstates); + + return std::make_pair(out, outstates); } void Seq2SeqCriterion::setUseSequentialDecoder() { - useSequentialDecoder_ = false; - if ((pctTeacherForcing_ < 100 && - samplingStrategy_ == fl::pkg::speech::kModelSampling) || - samplingStrategy_ == fl::pkg::speech::kGumbelSampling || inputFeeding_) { - useSequentialDecoder_ = true; - } else if ( - std::dynamic_pointer_cast(attention(0)) || - std::dynamic_pointer_cast(attention(0)) || - std::dynamic_pointer_cast(attention(0))) { - useSequentialDecoder_ = true; - } else if ( - window_ && trainWithWindow_ && - std::dynamic_pointer_cast(window_)) { - useSequentialDecoder_ = true; - } + useSequentialDecoder_ = false; + if( + (pctTeacherForcing_ < 100 + && samplingStrategy_ == fl::pkg::speech::kModelSampling) + || samplingStrategy_ == fl::pkg::speech::kGumbelSampling || inputFeeding_ + ) { + useSequentialDecoder_ = true; + } else if( + std::dynamic_pointer_cast(attention(0)) + || std::dynamic_pointer_cast(attention(0)) + || std::dynamic_pointer_cast(attention(0)) + ) { + useSequentialDecoder_ = true; + } else if( + window_ && trainWithWindow_ + && std::dynamic_pointer_cast(window_) + ) { + useSequentialDecoder_ = true; + } } std::string Seq2SeqCriterion::prettyString() const { - return "Seq2SeqCriterion"; + return "Seq2SeqCriterion"; } EmittingModelUpdateFunc buildSeq2SeqRnnUpdateFunction( @@ -645,68 +705,75 @@ EmittingModelUpdateFunc buildSeq2SeqRnnUpdateFunction( int attRound, int beamSize, float attThr, - float smoothingTemp) { - auto buf = std::make_shared( - attRound, beamSize, attThr, smoothingTemp); - - const Seq2SeqCriterion* s2sCriterion = - static_cast(criterion.get()); - auto 
emittingModelUpdateFunc = - [buf, s2sCriterion]( - const float* emissions, - const int N, - const int T, - const std::vector& rawY, - const std::vector& /* prevHypBeamIdxs */, - const std::vector& rawPrevStates, - int& t) { - if (t == 0) { - buf->input = fl::Variable( - Tensor::fromBuffer({N, T}, emissions, MemoryLocation::Host), - false); - } - int batchSize = rawY.size(); - buf->prevStates.resize(0); - buf->ys.resize(0); - - // Cast to seq2seq states - for (int i = 0; i < batchSize; i++) { - Seq2SeqState* prevState = - static_cast(rawPrevStates[i].get()); - fl::Variable y; - if (t > 0) { - y = fl::constant(rawY[i], {1}, fl::dtype::s32, false); - } else { - prevState = &buf->dummyState; - } - buf->ys.push_back(y); - buf->prevStates.push_back(prevState); - } - - // Run forward in batch - std::vector> amScores; - std::vector outStates; - - std::tie(amScores, outStates) = s2sCriterion->decodeBatchStep( - buf->input, - buf->ys, - buf->prevStates, - buf->attentionThreshold, - buf->smoothingTemperature); - - // Cast back to void* - std::vector out; - for (auto& os : outStates) { - if (os->isValid) { - out.push_back(os); - } else { - out.push_back(nullptr); - } - } - return std::make_pair(amScores, out); - }; - - return emittingModelUpdateFunc; + float smoothingTemp +) { + auto buf = std::make_shared( + attRound, + beamSize, + attThr, + smoothingTemp + ); + + const Seq2SeqCriterion* s2sCriterion = + static_cast(criterion.get()); + auto emittingModelUpdateFunc = + [buf, s2sCriterion]( + const float* emissions, + const int N, + const int T, + const std::vector& rawY, + const std::vector& /* prevHypBeamIdxs */, + const std::vector& rawPrevStates, + int& t) { + if(t == 0) { + buf->input = fl::Variable( + Tensor::fromBuffer({N, T}, emissions, MemoryLocation::Host), + false + ); + } + int batchSize = rawY.size(); + buf->prevStates.resize(0); + buf->ys.resize(0); + + // Cast to seq2seq states + for(int i = 0; i < batchSize; i++) { + Seq2SeqState* prevState = + 
static_cast(rawPrevStates[i].get()); + fl::Variable y; + if(t > 0) { + y = fl::constant(rawY[i], {1}, fl::dtype::s32, false); + } else { + prevState = &buf->dummyState; + } + buf->ys.push_back(y); + buf->prevStates.push_back(prevState); + } + + // Run forward in batch + std::vector> amScores; + std::vector outStates; + + std::tie(amScores, outStates) = s2sCriterion->decodeBatchStep( + buf->input, + buf->ys, + buf->prevStates, + buf->attentionThreshold, + buf->smoothingTemperature + ); + + // Cast back to void* + std::vector out; + for(auto& os : outStates) { + if(os->isValid) { + out.push_back(os); + } else { + out.push_back(nullptr); + } + } + return std::make_pair(amScores, out); + }; + + return emittingModelUpdateFunc; } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/Seq2SeqCriterion.h b/flashlight/pkg/speech/criterion/Seq2SeqCriterion.h index 0138d1e..c06a59a 100644 --- a/flashlight/pkg/speech/criterion/Seq2SeqCriterion.h +++ b/flashlight/pkg/speech/criterion/Seq2SeqCriterion.h @@ -16,213 +16,222 @@ namespace fl { namespace pkg { -namespace speech { - -struct Seq2SeqState { - fl::Variable alpha; - std::vector hidden; - fl::Variable summary; - int step; - int peakAttnPos; - bool isValid; - - Seq2SeqState() : hidden(1), step(0), peakAttnPos(-1), isValid(false) {} - - explicit Seq2SeqState(int nAttnRound) - : hidden(nAttnRound), step(0), peakAttnPos(-1), isValid(false) {} -}; - -typedef std::shared_ptr Seq2SeqStatePtr; - -class Seq2SeqCriterion : public SequenceCriterion { - public: - struct CandidateHypo { - float score; - std::vector path; - Seq2SeqState state; - explicit CandidateHypo() : score(0.0) { - path.resize(0); - } - CandidateHypo(float score_, std::vector path_, Seq2SeqState state_) - : score(score_), path(path_), state(state_) {} - }; - - Seq2SeqCriterion( - int nClass, - int hiddenDim, - int eos, - int pad, - int maxDecoderOutputLen, - const std::vector>& attentions, - std::shared_ptr window = nullptr, - bool trainWithWindow = false, 
- int pctTeacherForcing = 100, - double labelSmooth = 0.0, - bool inputFeeding = false, - std::string samplingStrategy = fl::pkg::speech::kRandSampling, - double gumbelTemperature = 1.0, - int nRnnLayer = 1, - int nAttnRound = 1, - float dropOut = 0.0); - - std::unique_ptr clone() const override; - - std::vector forward( - const std::vector& inputs) override; - - /* Next step predictions are based on the target at - * the previous time-step so this function should only - * be used for training purposes. */ - std::pair decoder( - const fl::Variable& input, - const fl::Variable& target, - const Tensor& inputSizes, - const Tensor& targetSizes); - - std::pair vectorizedDecoder( - const fl::Variable& input, - const fl::Variable& target, - const Tensor& inputSizes, - const Tensor& targetSizes); - - Tensor viterbiPath(const Tensor& input, const Tensor& inputSizes = Tensor()) - override; - - std::pair - viterbiPathBase(const Tensor& input, const Tensor& inputSizes, bool saveAttn); - - std::vector beamSearch( - const Tensor& input, - const Tensor& inputSizes, - std::vector beam, - int beamSize, - int maxLen); - - std::vector - beamPath(const Tensor& input, const Tensor& inputSizes, int beamSize = 10); - - std::string prettyString() const override; - - std::shared_ptr embedding() const { - return std::static_pointer_cast(module(0)); - } - - std::shared_ptr decodeRNN(int n) const { - return std::static_pointer_cast(module(n + 1)); - } - - std::shared_ptr attention(int n) const { - return std::static_pointer_cast(module(nAttnRound_ + n + 2)); - } - - std::shared_ptr linearOut() const { - return std::static_pointer_cast(module(nAttnRound_ + 1)); - } - - fl::Variable startEmbedding() const { - return params_.back(); - } - - std::pair>, std::vector> - decodeBatchStep( - const fl::Variable& xEncoded, - std::vector& ys, - const std::vector& inStates, - const int attentionThreshold = std::numeric_limits::infinity(), - const float smoothingTemperature = 1.0) const; - - std::pair 
decodeStep( - const fl::Variable& xEncoded, - const fl::Variable& y, - const Seq2SeqState& instate, - const Tensor& inputSizes, - const Tensor& targetSizes, - int targetLen) const; - - void clearWindow() { - trainWithWindow_ = false; - window_ = nullptr; - } - - void setSampling(std::string newSamplingStrategy, int newPctTeacherForcing) { - pctTeacherForcing_ = newPctTeacherForcing; - samplingStrategy_ = newSamplingStrategy; - setUseSequentialDecoder(); - } - - void setGumbelTemperature(double temperature) { - gumbelTemperature_ = temperature; - } - - void setLabelSmooth(double labelSmooth) { - labelSmooth_ = labelSmooth; - } - - private: - int eos_; - int pad_; - int maxDecoderOutputLen_; - std::shared_ptr window_; - bool trainWithWindow_; - int pctTeacherForcing_; - bool useSequentialDecoder_; - double labelSmooth_; - bool inputFeeding_; - int nClass_; - std::string samplingStrategy_; - double gumbelTemperature_; - int nAttnRound_{1}; - - FL_SAVE_LOAD_WITH_BASE( - SequenceCriterion, - eos_, - maxDecoderOutputLen_, - window_, - trainWithWindow_, - pctTeacherForcing_, - useSequentialDecoder_, - labelSmooth_, - inputFeeding_, - nClass_, - fl::versioned(samplingStrategy_, 1), - fl::versioned(gumbelTemperature_, 2), - fl::versioned(nAttnRound_, 3), - fl::versioned(pad_, 4)) - - Seq2SeqCriterion() = default; - - void setUseSequentialDecoder(); -}; + namespace speech { + + struct Seq2SeqState { + fl::Variable alpha; + std::vector hidden; + fl::Variable summary; + int step; + int peakAttnPos; + bool isValid; + + Seq2SeqState() : hidden(1), step(0), peakAttnPos(-1), isValid(false) {} + + explicit Seq2SeqState(int nAttnRound) + : hidden(nAttnRound), step(0), peakAttnPos(-1), isValid(false) {} + }; + + typedef std::shared_ptr Seq2SeqStatePtr; + + class Seq2SeqCriterion : public SequenceCriterion { + public: + struct CandidateHypo { + float score; + std::vector path; + Seq2SeqState state; + explicit CandidateHypo() : score(0.0) { + path.resize(0); + } + CandidateHypo(float 
score_, std::vector path_, Seq2SeqState state_) + : score(score_), path(path_), state(state_) {} + }; + + Seq2SeqCriterion( + int nClass, + int hiddenDim, + int eos, + int pad, + int maxDecoderOutputLen, + const std::vector>& attentions, + std::shared_ptr window = nullptr, + bool trainWithWindow = false, + int pctTeacherForcing = 100, + double labelSmooth = 0.0, + bool inputFeeding = false, + std::string samplingStrategy = fl::pkg::speech::kRandSampling, + double gumbelTemperature = 1.0, + int nRnnLayer = 1, + int nAttnRound = 1, + float dropOut = 0.0 + ); + + std::unique_ptr clone() const override; + + std::vector forward( + const std::vector& inputs + ) override; + + /* Next step predictions are based on the target at + * the previous time-step so this function should only + * be used for training purposes. */ + std::pair decoder( + const fl::Variable& input, + const fl::Variable& target, + const Tensor& inputSizes, + const Tensor& targetSizes + ); + + std::pair vectorizedDecoder( + const fl::Variable& input, + const fl::Variable& target, + const Tensor& inputSizes, + const Tensor& targetSizes + ); + + Tensor viterbiPath(const Tensor& input, const Tensor& inputSizes = Tensor()) + override; + + std::pair viterbiPathBase( + const Tensor& input, + const Tensor& inputSizes, + bool saveAttn + ); + + std::vector beamSearch( + const Tensor& input, + const Tensor& inputSizes, + std::vector beam, + int beamSize, + int maxLen + ); + + std::vector beamPath(const Tensor& input, const Tensor& inputSizes, int beamSize = 10); + + std::string prettyString() const override; + + std::shared_ptr embedding() const { + return std::static_pointer_cast(module(0)); + } + + std::shared_ptr decodeRNN(int n) const { + return std::static_pointer_cast(module(n + 1)); + } + + std::shared_ptr attention(int n) const { + return std::static_pointer_cast(module(nAttnRound_ + n + 2)); + } + + std::shared_ptr linearOut() const { + return std::static_pointer_cast(module(nAttnRound_ + 1)); + } + + 
fl::Variable startEmbedding() const { + return params_.back(); + } + + std::pair>, std::vector> decodeBatchStep( + const fl::Variable& xEncoded, + std::vector& ys, + const std::vector& inStates, + const int attentionThreshold = std::numeric_limits::infinity(), + const float smoothingTemperature = 1.0 + ) const; + + std::pair decodeStep( + const fl::Variable& xEncoded, + const fl::Variable& y, + const Seq2SeqState& instate, + const Tensor& inputSizes, + const Tensor& targetSizes, + int targetLen + ) const; + + void clearWindow() { + trainWithWindow_ = false; + window_ = nullptr; + } + + void setSampling(std::string newSamplingStrategy, int newPctTeacherForcing) { + pctTeacherForcing_ = newPctTeacherForcing; + samplingStrategy_ = newSamplingStrategy; + setUseSequentialDecoder(); + } + + void setGumbelTemperature(double temperature) { + gumbelTemperature_ = temperature; + } + + void setLabelSmooth(double labelSmooth) { + labelSmooth_ = labelSmooth; + } + + private: + int eos_; + int pad_; + int maxDecoderOutputLen_; + std::shared_ptr window_; + bool trainWithWindow_; + int pctTeacherForcing_; + bool useSequentialDecoder_; + double labelSmooth_; + bool inputFeeding_; + int nClass_; + std::string samplingStrategy_; + double gumbelTemperature_; + int nAttnRound_{1}; + + FL_SAVE_LOAD_WITH_BASE( + SequenceCriterion, + eos_, + maxDecoderOutputLen_, + window_, + trainWithWindow_, + pctTeacherForcing_, + useSequentialDecoder_, + labelSmooth_, + inputFeeding_, + nClass_, + fl::versioned(samplingStrategy_, 1), + fl::versioned(gumbelTemperature_, 2), + fl::versioned(nAttnRound_, 3), + fl::versioned(pad_, 4) + ) Seq2SeqCriterion() = default; + + void setUseSequentialDecoder(); + }; /* Decoder helpers */ -struct Seq2SeqDecoderBuffer { - fl::Variable input; - Seq2SeqState dummyState; - std::vector ys; - std::vector prevStates; - int attentionThreshold; - double smoothingTemperature; - - Seq2SeqDecoderBuffer( - int nAttnRound, - int beamSize, - int attnThre, - int smootTemp) - : 
dummyState(nAttnRound), - attentionThreshold(attnThre), - smoothingTemperature(smootTemp) { - ys.reserve(beamSize); - prevStates.reserve(beamSize); - } -}; - -EmittingModelUpdateFunc buildSeq2SeqRnnUpdateFunction( - std::shared_ptr& criterion, - int attRound, - int beamSize, - float attThr, - float smoothingTemp); -} // namespace speech + struct Seq2SeqDecoderBuffer { + fl::Variable input; + Seq2SeqState dummyState; + std::vector ys; + std::vector prevStates; + int attentionThreshold; + double smoothingTemperature; + + Seq2SeqDecoderBuffer( + int nAttnRound, + int beamSize, + int attnThre, + int smootTemp + ) + : dummyState(nAttnRound), + attentionThreshold(attnThre), + smoothingTemperature(smootTemp) { + ys.reserve(beamSize); + prevStates.reserve(beamSize); + } + }; + + EmittingModelUpdateFunc buildSeq2SeqRnnUpdateFunction( + std::shared_ptr& criterion, + int attRound, + int beamSize, + float attThr, + float smoothingTemp + ); + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/SequenceCriterion.h b/flashlight/pkg/speech/criterion/SequenceCriterion.h index 392571a..1b8551b 100644 --- a/flashlight/pkg/speech/criterion/SequenceCriterion.h +++ b/flashlight/pkg/speech/criterion/SequenceCriterion.h @@ -13,48 +13,50 @@ namespace fl { namespace pkg { -namespace speech { - -class SequenceCriterion : public fl::Container { - public: - /** - * Find the most likely path through input using viterbi algorithm - * https://en.wikipedia.org/wiki/Viterbi_algorithm - */ - virtual Tensor viterbiPath( - const Tensor& input, - const Tensor& inputSizes = Tensor()) = 0; - - /** - * Finds the most likely path using viterbi algorithm that is constrained - * to go through target - */ - virtual Tensor viterbiPathWithTarget( - const Tensor& input, - const Tensor& target, - const Tensor& inputSizes = Tensor(), - const Tensor& targetSizes = Tensor()) { - throw std::runtime_error("Not implemented"); - return Tensor(); - } - - private: - 
FL_SAVE_LOAD_WITH_BASE(fl::Container) -}; - -using EmittingModelStatePtr = std::shared_ptr; -using EmittingModelUpdateFunc = std::function>, - std::vector>( - const float*, - const int, - const int, - const std::vector&, - const std::vector&, - const std::vector&, - int&)>; - -} // namespace speech + namespace speech { + + class SequenceCriterion : public fl::Container { + public: + /** + * Find the most likely path through input using viterbi algorithm + * https://en.wikipedia.org/wiki/Viterbi_algorithm + */ + virtual Tensor viterbiPath( + const Tensor& input, + const Tensor& inputSizes = Tensor() + ) = 0; + + /** + * Finds the most likely path using viterbi algorithm that is constrained + * to go through target + */ + virtual Tensor viterbiPathWithTarget( + const Tensor& input, + const Tensor& target, + const Tensor& inputSizes = Tensor(), + const Tensor& targetSizes = Tensor() + ) { + throw std::runtime_error("Not implemented"); + return Tensor(); + } + + private: + FL_SAVE_LOAD_WITH_BASE(fl::Container) + }; + + using EmittingModelStatePtr = std::shared_ptr; + using EmittingModelUpdateFunc = std::function>, + std::vector>( + const float*, + const int, + const int, + const std::vector&, + const std::vector&, + const std::vector&, + int&)>; + + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/TransformerCriterion.cpp b/flashlight/pkg/speech/criterion/TransformerCriterion.cpp index 86f56cb..b696a9c 100644 --- a/flashlight/pkg/speech/criterion/TransformerCriterion.cpp +++ b/flashlight/pkg/speech/criterion/TransformerCriterion.cpp @@ -29,72 +29,80 @@ TransformerCriterion::TransformerCriterion( double labelSmooth, double pctTeacherForcing, double pDropout, - double pLayerDrop) - : nClass_(nClass), - eos_(eos), - pad_(pad), - maxDecoderOutputLen_(maxDecoderOutputLen), - nLayer_(nLayer), - window_(window), - trainWithWindow_(trainWithWindow), - labelSmooth_(labelSmooth), - pctTeacherForcing_(pctTeacherForcing) { - 
add(std::make_shared(hiddenDim, nClass)); - for (size_t i = 0; i < nLayer_; i++) { - add(std::make_shared( - hiddenDim, - hiddenDim / 4, - hiddenDim * 4, - 4, - maxDecoderOutputLen, - pDropout, - pLayerDrop, - true)); - } - add(std::make_shared(hiddenDim, nClass)); - add(attention); - params_.push_back(fl::uniform(Shape{hiddenDim}, -1e-1, 1e-1)); + double pLayerDrop +) : nClass_(nClass), + eos_(eos), + pad_(pad), + maxDecoderOutputLen_(maxDecoderOutputLen), + nLayer_(nLayer), + window_(window), + trainWithWindow_(trainWithWindow), + labelSmooth_(labelSmooth), + pctTeacherForcing_(pctTeacherForcing) { + add(std::make_shared(hiddenDim, nClass)); + for(size_t i = 0; i < nLayer_; i++) { + add( + std::make_shared( + hiddenDim, + hiddenDim / 4, + hiddenDim * 4, + 4, + maxDecoderOutputLen, + pDropout, + pLayerDrop, + true + ) + ); + } + add(std::make_shared(hiddenDim, nClass)); + add(attention); + params_.push_back(fl::uniform(Shape{hiddenDim}, -1e-1, 1e-1)); } std::unique_ptr TransformerCriterion::clone() const { - throw std::runtime_error( - "Cloning is unimplemented in Module 'TransformerCriterion'"); + throw std::runtime_error( + "Cloning is unimplemented in Module 'TransformerCriterion'" + ); } std::vector TransformerCriterion::forward( - const std::vector& inputs) { - if (inputs.size() < 2 || inputs.size() > 4) { - throw std::invalid_argument( - "Invalid inputs size; Transformer criterion takes input," - " target, inputSizes [optional], targetSizes [optional]"); - } - const Variable& input = inputs[0]; - const Variable& target = inputs[1]; - const auto& inputSizes = - inputs.size() == 2 ? Tensor() : inputs[2].tensor(); // 1 x B - const auto& targetSizes = - inputs.size() == 3 ? 
Tensor() : inputs[3].tensor(); // 1 x B - - Variable out, alpha; - std::tie(out, alpha) = - vectorizedDecoder(input, target, inputSizes, targetSizes); - - out = logSoftmax(out, 0); - - auto losses = fl::moddims( - sum(categoricalCrossEntropy(out, target, ReduceMode::NONE, pad_), {0}), - {-1}); - if (train_ && labelSmooth_ > 0) { - long long nClass = out.dim(0); - auto targetTiled = fl::tile( - fl::reshape(target.tensor(), {1, target.dim(0), target.dim(1)}), - {nClass}); - out = applySeq2SeqMask(out, targetTiled, pad_); - auto smoothLoss = fl::moddims(sum(out, {0, 1}), {-1}); - losses = (1 - labelSmooth_) * losses - (labelSmooth_ / nClass) * smoothLoss; - } - - return {losses, out}; + const std::vector& inputs +) { + if(inputs.size() < 2 || inputs.size() > 4) { + throw std::invalid_argument( + "Invalid inputs size; Transformer criterion takes input," + " target, inputSizes [optional], targetSizes [optional]" + ); + } + const Variable& input = inputs[0]; + const Variable& target = inputs[1]; + const auto& inputSizes = + inputs.size() == 2 ? Tensor() : inputs[2].tensor(); // 1 x B + const auto& targetSizes = + inputs.size() == 3 ? 
Tensor() : inputs[3].tensor(); // 1 x B + + Variable out, alpha; + std::tie(out, alpha) = + vectorizedDecoder(input, target, inputSizes, targetSizes); + + out = logSoftmax(out, 0); + + auto losses = fl::moddims( + sum(categoricalCrossEntropy(out, target, ReduceMode::NONE, pad_), {0}), + {-1} + ); + if(train_ && labelSmooth_ > 0) { + long long nClass = out.dim(0); + auto targetTiled = fl::tile( + fl::reshape(target.tensor(), {1, target.dim(0), target.dim(1)}), + {nClass} + ); + out = applySeq2SeqMask(out, targetTiled, pad_); + auto smoothLoss = fl::moddims(sum(out, {0, 1}), {-1}); + losses = (1 - labelSmooth_) * losses - (labelSmooth_ / nClass) * smoothLoss; + } + + return {losses, out}; } // input : D x T x B @@ -103,310 +111,333 @@ std::pair TransformerCriterion::vectorizedDecoder( const Variable& input, const Variable& target, const Tensor& inputSizes, - const Tensor& targetSizes) { - int U = target.dim(0); - int B = target.dim(1); - int T = input.isEmpty() ? 0 : input.dim(1); - - auto hy = tile(startEmbedding(), {1, 1, B}); - - if (U > 1) { - auto y = target(fl::range(0, U - 1), fl::span); - - if (train_) { - // TODO: other sampling strategies - auto mask = Variable( - (fl::rand(y.shape()) * 100 <= pctTeacherForcing_).astype(y.type()), - false); - auto samples = Variable( - (fl::rand(y.shape()) * (nClass_ - 1)).astype(y.type()), false); + const Tensor& targetSizes +) { + int U = target.dim(0); + int B = target.dim(1); + int T = input.isEmpty() ? 
0 : input.dim(1); + + auto hy = tile(startEmbedding(), {1, 1, B}); + + if(U > 1) { + auto y = target(fl::range(0, U - 1), fl::span); + + if(train_) { + // TODO: other sampling strategies + auto mask = Variable( + (fl::rand(y.shape()) * 100 <= pctTeacherForcing_).astype(y.type()), + false + ); + auto samples = Variable( + (fl::rand(y.shape()) * (nClass_ - 1)).astype(y.type()), + false + ); + + y = mask * y + (1 - mask) * samples; + } - y = mask * y + (1 - mask) * samples; + auto yEmbed = embedding()->forward(y); + hy = concatenate({hy, yEmbed}, 1); } - auto yEmbed = embedding()->forward(y); - hy = concatenate({hy, yEmbed}, 1); - } - - Variable alpha, summaries; - Variable padMask; // no mask, decoder is not looking into future - for (int i = 0; i < nLayer_; i++) { - hy = layer(i)->forward(std::vector({hy, padMask})).front(); - } - - if (!input.isEmpty()) { - Variable windowWeight; - if (window_ && (!train_ || trainWithWindow_)) { - windowWeight = - window_->computeVectorizedWindow(U, T, B, inputSizes, targetSizes); + Variable alpha, summaries; + Variable padMask; // no mask, decoder is not looking into future + for(int i = 0; i < nLayer_; i++) { + hy = layer(i)->forward(std::vector({hy, padMask})).front(); } - std::tie(alpha, summaries) = attention()->forward( - hy, input, Variable(), windowWeight, fl::noGrad(inputSizes)); + if(!input.isEmpty()) { + Variable windowWeight; + if(window_ && (!train_ || trainWithWindow_)) { + windowWeight = + window_->computeVectorizedWindow(U, T, B, inputSizes, targetSizes); + } - hy = hy + summaries; - } + std::tie(alpha, summaries) = attention()->forward( + hy, + input, + Variable(), + windowWeight, + fl::noGrad(inputSizes) + ); - auto out = linearOut()->forward(hy); + hy = hy + summaries; + } - return std::make_pair(out, alpha); + auto out = linearOut()->forward(hy); + + return std::make_pair(out, alpha); } Tensor TransformerCriterion::viterbiPath( const Tensor& input, - const Tensor& inputSizes /* = Tensor() */) { - return 
viterbiPathBase(input, inputSizes, false).first; + const Tensor& inputSizes /* = Tensor() */ +) { + return viterbiPathBase(input, inputSizes, false).first; } std::pair TransformerCriterion::viterbiPathBase( const Tensor& input, const Tensor& inputSizes, - bool /* TODO: saveAttn */) { - bool wasTrain = train_; - eval(); - std::vector path; - std::vector alphaVec; - Variable alpha; - TS2SState state; - Variable y, ox; - Tensor maxIdx, maxValues; - int pred; - - for (int u = 0; u < maxDecoderOutputLen_; u++) { - std::tie(ox, state) = - decodeStep(Variable(input, false), y, state, inputSizes); - max(maxValues, maxIdx, ox.tensor(), 0); - maxIdx.host(&pred); + bool /* TODO: saveAttn */ +) { + bool wasTrain = train_; + eval(); + std::vector path; + std::vector alphaVec; + Variable alpha; + TS2SState state; + Variable y, ox; + Tensor maxIdx, maxValues; + int pred; + + for(int u = 0; u < maxDecoderOutputLen_; u++) { + std::tie(ox, state) = + decodeStep(Variable(input, false), y, state, inputSizes); + max(maxValues, maxIdx, ox.tensor(), 0); + maxIdx.host(&pred); + // TODO: saveAttn + + if(pred == eos_) { + break; + } + y = constant(pred, {1}, fl::dtype::s32, false); + path.push_back(pred); + } // TODO: saveAttn - if (pred == eos_) { - break; + if(wasTrain) { + train(); } - y = constant(pred, {1}, fl::dtype::s32, false); - path.push_back(pred); - } - // TODO: saveAttn - if (wasTrain) { - train(); - } - - auto vPath = path.empty() ? Tensor() : Tensor::fromVector(path); - return std::make_pair(vPath, alpha); + auto vPath = path.empty() ? 
Tensor() : Tensor::fromVector(path); + return std::make_pair(vPath, alpha); } std::pair TransformerCriterion::decodeStep( const Variable& xEncoded, const Variable& y, const TS2SState& inState, - const Tensor& inputSizes) const { - Variable hy; - if (y.isEmpty()) { - hy = tile(startEmbedding(), {1, 1, xEncoded.dim(2)}); - } else { - hy = embedding()->forward(y); - } - - // TODO: inputFeeding - - TS2SState outState; - outState.step = inState.step + 1; - Tensor padMask; // no mask because we are doing step by step decoding here, - // no look in the future - for (int i = 0; i < nLayer_; i++) { - if (inState.step == 0) { - outState.hidden.push_back(hy); - hy = layer(i) - ->forward(std::vector({hy, fl::noGrad(padMask)})) - .front(); + const Tensor& inputSizes +) const { + Variable hy; + if(y.isEmpty()) { + hy = tile(startEmbedding(), {1, 1, xEncoded.dim(2)}); } else { - outState.hidden.push_back(concatenate({inState.hidden[i], hy}, 1)); - hy = layer(i) - ->forward({inState.hidden[i], hy, fl::noGrad(padMask)}) - .front(); + hy = embedding()->forward(y); } - } - Variable windowWeight, alpha, summary; - if (window_ && (!train_ || trainWithWindow_)) { - // TODO fix for softpretrain where target size is used - // for now force to xEncoded.dim(1) - windowWeight = window_->computeWindow( - Variable(), - inState.step, - xEncoded.dim(1), - xEncoded.dim(1), - xEncoded.dim(2), - inputSizes, - Tensor()); - } + // TODO: inputFeeding + + TS2SState outState; + outState.step = inState.step + 1; + Tensor padMask; // no mask because we are doing step by step decoding here, + // no look in the future + for(int i = 0; i < nLayer_; i++) { + if(inState.step == 0) { + outState.hidden.push_back(hy); + hy = layer(i) + ->forward(std::vector({hy, fl::noGrad(padMask)})) + .front(); + } else { + outState.hidden.push_back(concatenate({inState.hidden[i], hy}, 1)); + hy = layer(i) + ->forward({inState.hidden[i], hy, fl::noGrad(padMask)}) + .front(); + } + } + + Variable windowWeight, alpha, summary; + 
if(window_ && (!train_ || trainWithWindow_)) { + // TODO fix for softpretrain where target size is used + // for now force to xEncoded.dim(1) + windowWeight = window_->computeWindow( + Variable(), + inState.step, + xEncoded.dim(1), + xEncoded.dim(1), + xEncoded.dim(2), + inputSizes, + Tensor() + ); + } - std::tie(alpha, summary) = attention()->forward( - hy, xEncoded, Variable(), windowWeight, fl::noGrad(inputSizes)); + std::tie(alpha, summary) = attention()->forward( + hy, + xEncoded, + Variable(), + windowWeight, + fl::noGrad(inputSizes) + ); - hy = hy + summary; + hy = hy + summary; - auto out = linearOut()->forward(hy); - return std::make_pair(out, outState); + auto out = linearOut()->forward(hy); + return std::make_pair(out, outState); } -std::pair>, std::vector> -TransformerCriterion::decodeBatchStep( +std::pair>, std::vector> TransformerCriterion::decodeBatchStep( const fl::Variable& xEncoded, std::vector& ys, const std::vector& inStates, const int /* attentionThreshold */, - const float smoothingTemperature) const { - // assume xEncoded has batch 1 - int B = ys.size(); + const float smoothingTemperature +) const { + // assume xEncoded has batch 1 + int B = ys.size(); + + for(int i = 0; i < B; i++) { + if(ys[i].isEmpty()) { + ys[i] = startEmbedding(); + } else { + ys[i] = embedding()->forward(ys[i]); + } // TODO: input feeding + ys[i] = moddims(ys[i], {ys[i].dim(0), 1, -1}); + } + Variable yBatched = concatenate(ys, 2); // D x 1 x B - for (int i = 0; i < B; i++) { - if (ys[i].isEmpty()) { - ys[i] = startEmbedding(); - } else { - ys[i] = embedding()->forward(ys[i]); - } // TODO: input feeding - ys[i] = moddims(ys[i], {ys[i].dim(0), 1, -1}); - } - Variable yBatched = concatenate(ys, 2); // D x 1 x B - - std::vector outstates(B); - for (int i = 0; i < B; i++) { - outstates[i] = std::make_shared(); - outstates[i]->step = inStates[i]->step + 1; - } - - Variable outStateBatched; - for (int i = 0; i < nLayer_; i++) { - if (inStates[0]->step == 0) { - for (int j = 
0; j < B; j++) { - outstates[j]->hidden.push_back(yBatched(fl::span, fl::span, j)); - } - yBatched = layer(i)->forward(std::vector({yBatched})).front(); - } else { - std::vector statesVector(B); - for (int j = 0; j < B; j++) { - statesVector[j] = inStates[j]->hidden[i]; - } - Variable inStateHiddenBatched = concatenate(statesVector, 2); - auto tmp = std::vector({inStateHiddenBatched, yBatched}); - auto tmp2 = concatenate(tmp, 1); - for (int j = 0; j < B; j++) { - outstates[j]->hidden.push_back(tmp2(fl::span, fl::span, j)); - } - yBatched = layer(i)->forward(tmp).front(); + std::vector outstates(B); + for(int i = 0; i < B; i++) { + outstates[i] = std::make_shared(); + outstates[i]->step = inStates[i]->step + 1; + } + + Variable outStateBatched; + for(int i = 0; i < nLayer_; i++) { + if(inStates[0]->step == 0) { + for(int j = 0; j < B; j++) { + outstates[j]->hidden.push_back(yBatched(fl::span, fl::span, j)); + } + yBatched = layer(i)->forward(std::vector({yBatched})).front(); + } else { + std::vector statesVector(B); + for(int j = 0; j < B; j++) { + statesVector[j] = inStates[j]->hidden[i]; + } + Variable inStateHiddenBatched = concatenate(statesVector, 2); + auto tmp = std::vector({inStateHiddenBatched, yBatched}); + auto tmp2 = concatenate(tmp, 1); + for(int j = 0; j < B; j++) { + outstates[j]->hidden.push_back(tmp2(fl::span, fl::span, j)); + } + yBatched = layer(i)->forward(tmp).front(); + } } - } - - Variable alpha, summary; - yBatched = moddims(yBatched, {yBatched.dim(0), -1}); - std::tie(alpha, summary) = - attention()->forward(yBatched, xEncoded, Variable(), Variable()); - alpha = fl::transpose(alpha, {1, 0}); - yBatched = yBatched + summary; - - auto outBatched = linearOut()->forward(yBatched); - outBatched = logSoftmax(outBatched / smoothingTemperature, 0); - std::vector> out(B); - for (int i = 0; i < B; i++) { - out[i] = outBatched(fl::span, i).tensor().toHostVector(); - } - - return std::make_pair(out, outstates); + + Variable alpha, summary; + yBatched = 
moddims(yBatched, {yBatched.dim(0), -1}); + std::tie(alpha, summary) = + attention()->forward(yBatched, xEncoded, Variable(), Variable()); + alpha = fl::transpose(alpha, {1, 0}); + yBatched = yBatched + summary; + + auto outBatched = linearOut()->forward(yBatched); + outBatched = logSoftmax(outBatched / smoothingTemperature, 0); + std::vector> out(B); + for(int i = 0; i < B; i++) { + out[i] = outBatched(fl::span, i).tensor().toHostVector(); + } + + return std::make_pair(out, outstates); } EmittingModelUpdateFunc buildSeq2SeqTransformerUpdateFunction( std::shared_ptr& criterion, int beamSize, float attThr, - float smoothingTemp) { - auto buf = - std::make_shared(beamSize, attThr, smoothingTemp); - - const TransformerCriterion* criterionCast = - static_cast(criterion.get()); - - auto emittingModelUpdateFunc = - [buf, criterionCast]( - const float* emissions, - const int N, - const int T, - const std::vector& rawY, - const std::vector& /* prevHypBeamIdxs */, - const std::vector& rawPrevStates, - int& t) { - if (t == 0) { - buf->input = fl::Variable( - Tensor::fromBuffer({N, T}, emissions, MemoryLocation::Host), - false); - } - int B = rawY.size(); - std::vector out; - std::vector> amScoresAll; - - // Store the latest index of the hidden state when we can clear it - std::map lastIndexOfStatePtr; - for (int index = 0; index < rawPrevStates.size(); index++) { - TS2SState* ptr = static_cast(rawPrevStates[index].get()); - lastIndexOfStatePtr[ptr] = index; - } - - int start = 0, step = std::min(10, 1000 / (t + 1)); - while (start < B) { - buf->prevStates.resize(0); - buf->ys.resize(0); - - int end = start + step; - if (end > B) { - end = B; - } - for (int i = start; i < end; i++) { - TS2SState* prevState = - static_cast(rawPrevStates[i].get()); - fl::Variable y; - if (t > 0) { - y = fl::constant(rawY[i], {1}, fl::dtype::s32, false); - } else { - prevState = &buf->dummyState; + float smoothingTemp +) { + auto buf = + std::make_shared(beamSize, attThr, smoothingTemp); + + 
const TransformerCriterion* criterionCast = + static_cast(criterion.get()); + + auto emittingModelUpdateFunc = + [buf, criterionCast]( + const float* emissions, + const int N, + const int T, + const std::vector& rawY, + const std::vector& /* prevHypBeamIdxs */, + const std::vector& rawPrevStates, + int& t) { + if(t == 0) { + buf->input = fl::Variable( + Tensor::fromBuffer({N, T}, emissions, MemoryLocation::Host), + false + ); } - buf->ys.push_back(y); - buf->prevStates.push_back(prevState); - } - std::vector> amScores; - std::vector outStates; - std::tie(amScores, outStates) = criterionCast->decodeBatchStep( - buf->input, - buf->ys, - buf->prevStates, - buf->attentionThreshold, - buf->smoothingTemperature); - for (auto& os : outStates) { - out.push_back(os); - } - for (auto& s : amScores) { - amScoresAll.push_back(s); - } - // clean the previous state which is not needed anymore - // to prevent from OOM - for (int i = start; i < end; i++) { - TS2SState* prevState = - static_cast(rawPrevStates[i].get()); - if (prevState && - (lastIndexOfStatePtr.find(prevState) == - lastIndexOfStatePtr.end() || - lastIndexOfStatePtr.find(prevState)->second == i)) { - prevState->hidden.clear(); + int B = rawY.size(); + std::vector out; + std::vector> amScoresAll; + + // Store the latest index of the hidden state when we can clear it + std::map lastIndexOfStatePtr; + for(int index = 0; index < rawPrevStates.size(); index++) { + TS2SState* ptr = static_cast(rawPrevStates[index].get()); + lastIndexOfStatePtr[ptr] = index; } - } - start += step; - } - return std::make_pair(amScoresAll, out); - }; - return emittingModelUpdateFunc; + int start = 0, step = std::min(10, 1000 / (t + 1)); + while(start < B) { + buf->prevStates.resize(0); + buf->ys.resize(0); + + int end = start + step; + if(end > B) { + end = B; + } + for(int i = start; i < end; i++) { + TS2SState* prevState = + static_cast(rawPrevStates[i].get()); + fl::Variable y; + if(t > 0) { + y = fl::constant(rawY[i], {1}, 
fl::dtype::s32, false); + } else { + prevState = &buf->dummyState; + } + buf->ys.push_back(y); + buf->prevStates.push_back(prevState); + } + std::vector> amScores; + std::vector outStates; + std::tie(amScores, outStates) = criterionCast->decodeBatchStep( + buf->input, + buf->ys, + buf->prevStates, + buf->attentionThreshold, + buf->smoothingTemperature + ); + for(auto& os : outStates) { + out.push_back(os); + } + for(auto& s : amScores) { + amScoresAll.push_back(s); + } + // clean the previous state which is not needed anymore + // to prevent from OOM + for(int i = start; i < end; i++) { + TS2SState* prevState = + static_cast(rawPrevStates[i].get()); + if( + prevState + && (lastIndexOfStatePtr.find(prevState) + == lastIndexOfStatePtr.end() + || lastIndexOfStatePtr.find(prevState)->second == i) + ) { + prevState->hidden.clear(); + } + } + start += step; + } + return std::make_pair(amScoresAll, out); + }; + + return emittingModelUpdateFunc; } std::string TransformerCriterion::prettyString() const { - return "TransformerCriterion"; + return "TransformerCriterion"; } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/TransformerCriterion.h b/flashlight/pkg/speech/criterion/TransformerCriterion.h index 68300ab..449cd61 100644 --- a/flashlight/pkg/speech/criterion/TransformerCriterion.h +++ b/flashlight/pkg/speech/criterion/TransformerCriterion.h @@ -18,141 +18,148 @@ namespace fl { namespace pkg { -namespace speech { - -struct TS2SState { - fl::Variable alpha; - std::vector hidden; - fl::Variable summary; - int step; - - TS2SState() : step(0) {} -}; - -typedef std::shared_ptr TS2SStatePtr; - -class TransformerCriterion : public SequenceCriterion { - public: - TransformerCriterion( - int nClass, - int hiddenDim, - int eos, - int pad, - int maxDecoderOutputLen, - int nLayer, - std::shared_ptr attention, - std::shared_ptr window, - bool trainWithWindow, - double labelSmooth, - double pctTeacherForcing, - double pDropout, - double pLayerDrop); - - std::unique_ptr 
clone() const override; - - std::vector forward( - const std::vector& inputs) override; - - Tensor viterbiPath(const Tensor& input, const Tensor& inputSizes = Tensor()) - override; - - std::pair - viterbiPathBase(const Tensor& input, const Tensor& inputSizes, bool saveAttn); - - std::pair vectorizedDecoder( - const fl::Variable& input, - const fl::Variable& target, - const Tensor& inputSizes, - const Tensor& targetSizes); - - std::pair decodeStep( - const fl::Variable& xEncoded, - const fl::Variable& y, - const TS2SState& inState, - const Tensor& inputSizes) const; - - std::pair>, std::vector> - decodeBatchStep( - const fl::Variable& xEncoded, - std::vector& ys, - const std::vector& inStates, - const int attentionThreshold, - const float smoothingTemperature) const; - - void clearWindow() { - trainWithWindow_ = false; - window_ = nullptr; - } - - std::string prettyString() const override; - - std::shared_ptr embedding() const { - return std::static_pointer_cast(module(0)); - } - - std::shared_ptr layer(int i) const { - return std::static_pointer_cast(module(i + 1)); - } - - std::shared_ptr linearOut() const { - return std::static_pointer_cast(module(nLayer_ + 1)); - } - - std::shared_ptr attention() const { - return std::static_pointer_cast(module(nLayer_ + 2)); - } - - fl::Variable startEmbedding() const { - return params_.back(); - } - - private: - int nClass_; - int eos_; - int pad_; - int maxDecoderOutputLen_; - int nLayer_; - std::shared_ptr window_; - bool trainWithWindow_; - double labelSmooth_; - double pctTeacherForcing_; - - FL_SAVE_LOAD_WITH_BASE( - SequenceCriterion, - nClass_, - eos_, - maxDecoderOutputLen_, - nLayer_, - window_, - trainWithWindow_, - labelSmooth_, - pctTeacherForcing_, - fl::versioned(pad_, 1)) - - TransformerCriterion() = default; -}; - -struct TS2SDecoderBuffer { - fl::Variable input; - TS2SState dummyState; - std::vector ys; - std::vector prevStates; - int attentionThreshold; - double smoothingTemperature; - - TS2SDecoderBuffer(int 
beamSize, int attnThre, float smootTemp) - : attentionThreshold(attnThre), smoothingTemperature(smootTemp) { - ys.reserve(beamSize); - prevStates.reserve(beamSize); - } -}; - -EmittingModelUpdateFunc buildSeq2SeqTransformerUpdateFunction( - std::shared_ptr& criterion, - int beamSize, - float attThr, - float smoothingTemp); -} // namespace speech + namespace speech { + + struct TS2SState { + fl::Variable alpha; + std::vector hidden; + fl::Variable summary; + int step; + + TS2SState() : step(0) {} + }; + + typedef std::shared_ptr TS2SStatePtr; + + class TransformerCriterion : public SequenceCriterion { + public: + TransformerCriterion( + int nClass, + int hiddenDim, + int eos, + int pad, + int maxDecoderOutputLen, + int nLayer, + std::shared_ptr attention, + std::shared_ptr window, + bool trainWithWindow, + double labelSmooth, + double pctTeacherForcing, + double pDropout, + double pLayerDrop + ); + + std::unique_ptr clone() const override; + + std::vector forward( + const std::vector& inputs + ) override; + + Tensor viterbiPath(const Tensor& input, const Tensor& inputSizes = Tensor()) + override; + + std::pair viterbiPathBase( + const Tensor& input, + const Tensor& inputSizes, + bool saveAttn + ); + + std::pair vectorizedDecoder( + const fl::Variable& input, + const fl::Variable& target, + const Tensor& inputSizes, + const Tensor& targetSizes + ); + + std::pair decodeStep( + const fl::Variable& xEncoded, + const fl::Variable& y, + const TS2SState& inState, + const Tensor& inputSizes + ) const; + + std::pair>, std::vector> decodeBatchStep( + const fl::Variable& xEncoded, + std::vector& ys, + const std::vector& inStates, + const int attentionThreshold, + const float smoothingTemperature + ) const; + + void clearWindow() { + trainWithWindow_ = false; + window_ = nullptr; + } + + std::string prettyString() const override; + + std::shared_ptr embedding() const { + return std::static_pointer_cast(module(0)); + } + + std::shared_ptr layer(int i) const { + return 
std::static_pointer_cast(module(i + 1)); + } + + std::shared_ptr linearOut() const { + return std::static_pointer_cast(module(nLayer_ + 1)); + } + + std::shared_ptr attention() const { + return std::static_pointer_cast(module(nLayer_ + 2)); + } + + fl::Variable startEmbedding() const { + return params_.back(); + } + + private: + int nClass_; + int eos_; + int pad_; + int maxDecoderOutputLen_; + int nLayer_; + std::shared_ptr window_; + bool trainWithWindow_; + double labelSmooth_; + double pctTeacherForcing_; + + FL_SAVE_LOAD_WITH_BASE( + SequenceCriterion, + nClass_, + eos_, + maxDecoderOutputLen_, + nLayer_, + window_, + trainWithWindow_, + labelSmooth_, + pctTeacherForcing_, + fl::versioned(pad_, 1) + ) TransformerCriterion() = default; + }; + + struct TS2SDecoderBuffer { + fl::Variable input; + TS2SState dummyState; + std::vector ys; + std::vector prevStates; + int attentionThreshold; + double smoothingTemperature; + + TS2SDecoderBuffer(int beamSize, int attnThre, float smootTemp) + : attentionThreshold(attnThre), smoothingTemperature(smootTemp) { + ys.reserve(beamSize); + prevStates.reserve(beamSize); + } + }; + + EmittingModelUpdateFunc buildSeq2SeqTransformerUpdateFunction( + std::shared_ptr& criterion, + int beamSize, + float attThr, + float smoothingTemp + ); + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/AttentionBase.h b/flashlight/pkg/speech/criterion/attention/AttentionBase.h index a87bbf2..4d76043 100644 --- a/flashlight/pkg/speech/criterion/attention/AttentionBase.h +++ b/flashlight/pkg/speech/criterion/attention/AttentionBase.h @@ -11,87 +11,99 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { /** * Attention base class for encoder-decoder: decoder attends to particular * encoder part */ -class AttentionBase : public fl::Container { - public: - AttentionBase() {} + class AttentionBase : public fl::Container { + public: + AttentionBase() {} - std::vector 
forward(const std::vector& inputs) override { - if (inputs.size() != 3 && inputs.size() != 4 && inputs.size() != 5) { - throw std::invalid_argument( - "Attention encoder-decoder: Invalid inputs size, should be 3, 4, or 5 arguments"); - } + std::vector forward(const std::vector& inputs) override { + if(inputs.size() != 3 && inputs.size() != 4 && inputs.size() != 5) { + throw std::invalid_argument( + "Attention encoder-decoder: Invalid inputs size, should be 3, 4, or 5 arguments" + ); + } - auto logAttnWeight = inputs.size() == 4 ? inputs[3] : Variable(); - auto xEncodedSizes = inputs.size() == 5 ? inputs[4] : Variable(); - auto res = forwardBase( - inputs[0], inputs[1], inputs[2], logAttnWeight, xEncodedSizes); - return {res.first, res.second}; - } + auto logAttnWeight = inputs.size() == 4 ? inputs[3] : Variable(); + auto xEncodedSizes = inputs.size() == 5 ? inputs[4] : Variable(); + auto res = forwardBase( + inputs[0], + inputs[1], + inputs[2], + logAttnWeight, + xEncodedSizes + ); + return {res.first, res.second}; + } - std::pair forward( - const Variable& state, - const Variable& xEncoded, - const Variable& prevAttn) { - return forward( - state, - xEncoded, - prevAttn, - Variable() /* logAttnWeight */, - Variable() /* xEncodedSizes */); - } + std::pair forward( + const Variable& state, + const Variable& xEncoded, + const Variable& prevAttn + ) { + return forward( + state, + xEncoded, + prevAttn, + Variable() /* logAttnWeight */, + Variable() /* xEncodedSizes */ + ); + } - std::pair forward( - const Variable& state, - const Variable& xEncoded, - const Variable& prevAttn, - const Variable& logAttnWeight) { - return forwardBase( - state, - xEncoded, - prevAttn, - logAttnWeight, - Variable() /* xEncodedSizes */); - } + std::pair forward( + const Variable& state, + const Variable& xEncoded, + const Variable& prevAttn, + const Variable& logAttnWeight + ) { + return forwardBase( + state, + xEncoded, + prevAttn, + logAttnWeight, + Variable() /* xEncodedSizes */ + ); + } 
- virtual std::pair forward( - const Variable& state, - const Variable& xEncoded, - const Variable& prevAttn, - const Variable& logAttnWeight, - const Variable& xEncodedSizes) { - return forwardBase(state, xEncoded, prevAttn, logAttnWeight, xEncodedSizes); - } + virtual std::pair forward( + const Variable& state, + const Variable& xEncoded, + const Variable& prevAttn, + const Variable& logAttnWeight, + const Variable& xEncodedSizes + ) { + return forwardBase(state, xEncoded, prevAttn, logAttnWeight, xEncodedSizes); + } - protected: - /** - * Forward pass - * @param state current decoder state - * @param xEncoded encoder output = decoder input - * @param prevAttn previous attention vector - * @param logAttnWeight attention weights to add: finalAttn = - * exp(logAttnWeight + logComputedAttn) - * @param xEncodedSizes encoder output actual sizes has (1, B) dims - * Returns of sizes - * [targetlen, seqlen, batchsize] for attention, - * [hiddendim, targetlen, batchsize] for summary - */ - virtual std::pair forwardBase( - const Variable& state, - const Variable& xEncoded, - const Variable& prevAttn, - const Variable& logAttnWeight, - const Variable& xEncodedSizes) = 0; + protected: + /** + * Forward pass + * @param state current decoder state + * @param xEncoded encoder output = decoder input + * @param prevAttn previous attention vector + * @param logAttnWeight attention weights to add: finalAttn = + * exp(logAttnWeight + logComputedAttn) + * @param xEncodedSizes encoder output actual sizes has (1, B) dims + * Returns of sizes + * [targetlen, seqlen, batchsize] for attention, + * [hiddendim, targetlen, batchsize] for summary + */ + virtual std::pair forwardBase( + const Variable& state, + const Variable& xEncoded, + const Variable& prevAttn, + const Variable& logAttnWeight, + const Variable& xEncodedSizes + ) = 0; - private: - FL_SAVE_LOAD_WITH_BASE(fl::Container) -}; -} // namespace speech + private: + FL_SAVE_LOAD_WITH_BASE(fl::Container) + }; + } // namespace speech } 
// namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/ContentAttention.cpp b/flashlight/pkg/speech/criterion/attention/ContentAttention.cpp index 6003c6f..6cf50c8 100644 --- a/flashlight/pkg/speech/criterion/attention/ContentAttention.cpp +++ b/flashlight/pkg/speech/criterion/attention/ContentAttention.cpp @@ -15,8 +15,9 @@ namespace fl::pkg::speech { std::unique_ptr ContentAttention::clone() const { - throw std::runtime_error( - "Cloning is unimplemented in Module 'ContentAttention'"); + throw std::runtime_error( + "Cloning is unimplemented in Module 'ContentAttention'" + ); } std::pair ContentAttention::forwardBase( @@ -24,52 +25,56 @@ std::pair ContentAttention::forwardBase( const Variable& xEncoded, const Variable& /* unused */, const Variable& logAttnWeight, - const Variable& xEncodedSizes) { - int dim = xEncoded.dim(0); - if (dim != (1 + ((keyValue_) ? 1 : 0)) * state.dim(0)) { - throw std::invalid_argument( - "ContentAttention: Invalid dimension for content attention"); - } - auto keys = keyValue_ ? xEncoded(fl::range(0, dim / 2)) : xEncoded; - auto values = keyValue_ ? xEncoded(fl::range(dim / 2, dim)) : xEncoded; - // [targetlen, seqlen, batchsize] - auto innerProd = matmulTN(state, keys) / std::sqrt(state.dim(0)); - if (!logAttnWeight.isEmpty()) { - if (logAttnWeight.shape() != innerProd.shape()) { - throw std::invalid_argument( - "ContentAttention: logAttnWeight has wong dimentions"); + const Variable& xEncodedSizes +) { + int dim = xEncoded.dim(0); + if(dim != (1 + ((keyValue_) ? 
1 : 0)) * state.dim(0)) { + throw std::invalid_argument( + "ContentAttention: Invalid dimension for content attention" + ); } - innerProd = innerProd + logAttnWeight; - } - Tensor padMask; - if (!xEncodedSizes.isEmpty()) { - innerProd = maskAttention(innerProd, xEncodedSizes); - } - // [targetlen, seqlen, batchsize] - auto attention = softmax(innerProd, 1); - // [hiddendim, targetlen, batchsize] - auto summaries = matmulNT(values, attention); - return std::make_pair(attention, summaries); + auto keys = keyValue_ ? xEncoded(fl::range(0, dim / 2)) : xEncoded; + auto values = keyValue_ ? xEncoded(fl::range(dim / 2, dim)) : xEncoded; + // [targetlen, seqlen, batchsize] + auto innerProd = matmulTN(state, keys) / std::sqrt(state.dim(0)); + if(!logAttnWeight.isEmpty()) { + if(logAttnWeight.shape() != innerProd.shape()) { + throw std::invalid_argument( + "ContentAttention: logAttnWeight has wong dimentions" + ); + } + innerProd = innerProd + logAttnWeight; + } + Tensor padMask; + if(!xEncodedSizes.isEmpty()) { + innerProd = maskAttention(innerProd, xEncodedSizes); + } + // [targetlen, seqlen, batchsize] + auto attention = softmax(innerProd, 1); + // [hiddendim, targetlen, batchsize] + auto summaries = matmulNT(values, attention); + return std::make_pair(attention, summaries); } std::string ContentAttention::prettyString() const { - return "ContentBasedAttention"; + return "ContentBasedAttention"; } NeuralContentAttention::NeuralContentAttention(int dim, int layers /* = 1 */) { - Sequential net; - net.add(ReLU()); - for (int i = 1; i < layers; i++) { - net.add(Linear(dim, dim)); + Sequential net; net.add(ReLU()); - } - net.add(Linear(dim, 1)); - add(std::move(net)); + for(int i = 1; i < layers; i++) { + net.add(Linear(dim, dim)); + net.add(ReLU()); + } + net.add(Linear(dim, 1)); + add(std::move(net)); } std::unique_ptr NeuralContentAttention::clone() const { - throw std::runtime_error( - "Cloning is unimplemented in Module 'NeuralContentAttention'"); + throw 
std::runtime_error( + "Cloning is unimplemented in Module 'NeuralContentAttention'" + ); } std::pair NeuralContentAttention::forwardBase( @@ -77,37 +82,39 @@ std::pair NeuralContentAttention::forwardBase( const Variable& xEncoded, const Variable& /* unused */, const Variable& logAttnWeight, - const Variable& xEncodedSizes) { - int U = state.dim(1); - int H = xEncoded.dim(0); - int T = xEncoded.dim(1); - int B = xEncoded.dim(2); + const Variable& xEncodedSizes +) { + int U = state.dim(1); + int H = xEncoded.dim(0); + int T = xEncoded.dim(1); + int B = xEncoded.dim(2); - auto tileHx = tile(moddims(xEncoded, {H, 1, T, B}), {1, U, 1, 1}); - auto tileHy = tile(moddims(state, {H, U, 1, B}), {1, 1, T, 1}); - // [hiddendim, targetlen, seqlen, batchsize] - auto hidden = tileHx + tileHy; - // [targetlen, seqlen, batchsize] - auto nnOut = moddims(module(0)->forward({hidden}).front(), {U, T, B}); - if (!logAttnWeight.isEmpty()) { - if (logAttnWeight.shape() != nnOut.shape()) { - throw std::invalid_argument( - "ContentAttention: logAttnWeight has wong dimentions"); + auto tileHx = tile(moddims(xEncoded, {H, 1, T, B}), {1, U, 1, 1}); + auto tileHy = tile(moddims(state, {H, U, 1, B}), {1, 1, T, 1}); + // [hiddendim, targetlen, seqlen, batchsize] + auto hidden = tileHx + tileHy; + // [targetlen, seqlen, batchsize] + auto nnOut = moddims(module(0)->forward({hidden}).front(), {U, T, B}); + if(!logAttnWeight.isEmpty()) { + if(logAttnWeight.shape() != nnOut.shape()) { + throw std::invalid_argument( + "ContentAttention: logAttnWeight has wong dimentions" + ); + } + nnOut = nnOut + logAttnWeight; } - nnOut = nnOut + logAttnWeight; - } - if (!xEncodedSizes.isEmpty()) { - nnOut = maskAttention(nnOut, xEncodedSizes); - } - // [targetlen, seqlen, batchsize] - auto attention = softmax(nnOut, 1); - // [hiddendim, targetlen, batchsize] - auto summaries = matmulNT(xEncoded, attention); - return std::make_pair(attention, summaries); + if(!xEncodedSizes.isEmpty()) { + nnOut = maskAttention(nnOut, 
xEncodedSizes); + } + // [targetlen, seqlen, batchsize] + auto attention = softmax(nnOut, 1); + // [hiddendim, targetlen, batchsize] + auto summaries = matmulNT(xEncoded, attention); + return std::make_pair(attention, summaries); } std::string NeuralContentAttention::prettyString() const { - return "NeuralContentBasedAttention"; + return "NeuralContentBasedAttention"; } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/ContentAttention.h b/flashlight/pkg/speech/criterion/attention/ContentAttention.h index b00331c..945d48c 100644 --- a/flashlight/pkg/speech/criterion/attention/ContentAttention.h +++ b/flashlight/pkg/speech/criterion/attention/ContentAttention.h @@ -11,49 +11,51 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { -class ContentAttention : public AttentionBase { - public: - ContentAttention(bool keyValue = false) : keyValue_(keyValue) {} + class ContentAttention : public AttentionBase { + public: + ContentAttention(bool keyValue = false) : keyValue_(keyValue) {} - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::pair forwardBase( - const Variable& state, - const Variable& xEncoded, - const Variable& prevAttn, - const Variable& logAttnWeight, - const Variable& xEncodedSizes) override; + std::pair forwardBase( + const Variable& state, + const Variable& xEncoded, + const Variable& prevAttn, + const Variable& logAttnWeight, + const Variable& xEncodedSizes + ) override; - std::string prettyString() const override; + std::string prettyString() const override; - private: - bool keyValue_; + private: + bool keyValue_; - FL_SAVE_LOAD_WITH_BASE(AttentionBase, fl::versioned(keyValue_, 1)) -}; + FL_SAVE_LOAD_WITH_BASE(AttentionBase, fl::versioned(keyValue_, 1)) + }; -class NeuralContentAttention : public AttentionBase { - public: - NeuralContentAttention() {} - explicit NeuralContentAttention(int dim, int layers = 1); + class NeuralContentAttention : public AttentionBase { + 
public: + NeuralContentAttention() {} + explicit NeuralContentAttention(int dim, int layers = 1); - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::pair forwardBase( - const Variable& state, - const Variable& xEncoded, - const Variable& prevAttn, - const Variable& logAttnWeight, - const Variable& xEncodedSizes) override; + std::pair forwardBase( + const Variable& state, + const Variable& xEncoded, + const Variable& prevAttn, + const Variable& logAttnWeight, + const Variable& xEncodedSizes + ) override; - std::string prettyString() const override; + std::string prettyString() const override; - private: - FL_SAVE_LOAD_WITH_BASE(AttentionBase) -}; -} // namespace speech + private: + FL_SAVE_LOAD_WITH_BASE(AttentionBase) + }; + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/Defines.h b/flashlight/pkg/speech/criterion/attention/Defines.h index fc49315..ff10729 100644 --- a/flashlight/pkg/speech/criterion/attention/Defines.h +++ b/flashlight/pkg/speech/criterion/attention/Defines.h @@ -11,31 +11,31 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { // attention -const std::string kContentAttention = "content"; -const std::string kKeyValueAttention = "keyvalue"; -const std::string kLocationAttention = "location"; -const std::string kMultiHeadContentAttention = "multi"; -const std::string kMultiHeadKeyValueContentAttention = "multikv"; -const std::string kMultiHeadSplitContentAttention = "multisplit"; -const std::string kMultiHeadKeyValueSplitContentAttention = "multikvsplit"; -const std::string kNeuralContentAttention = "neural"; -const std::string kNeuralLocationAttention = "neuralloc"; -const std::string kSimpleLocationAttention = "simpleloc"; + const std::string kContentAttention = "content"; + const std::string kKeyValueAttention = "keyvalue"; + const std::string kLocationAttention = "location"; + const std::string 
kMultiHeadContentAttention = "multi"; + const std::string kMultiHeadKeyValueContentAttention = "multikv"; + const std::string kMultiHeadSplitContentAttention = "multisplit"; + const std::string kMultiHeadKeyValueSplitContentAttention = "multikvsplit"; + const std::string kNeuralContentAttention = "neural"; + const std::string kNeuralLocationAttention = "neuralloc"; + const std::string kSimpleLocationAttention = "simpleloc"; // window -const std::string kMedianWindow = "median"; -const std::string kNoWindow = "no"; -const std::string kSoftWindow = "soft"; -const std::string kSoftPretrainWindow = "softPretrain"; -const std::string kStepWindow = "step"; + const std::string kMedianWindow = "median"; + const std::string kNoWindow = "no"; + const std::string kSoftWindow = "soft"; + const std::string kSoftPretrainWindow = "softPretrain"; + const std::string kStepWindow = "step"; // to avoid nans when apply log to these var // which cannot be propagated correctly if we set -inf -constexpr float kAttentionMaskValue = -10000; + constexpr float kAttentionMaskValue = -10000; -} // namespace speech + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/LocationAttention.cpp b/flashlight/pkg/speech/criterion/attention/LocationAttention.cpp index 59f494b..5756d1d 100644 --- a/flashlight/pkg/speech/criterion/attention/LocationAttention.cpp +++ b/flashlight/pkg/speech/criterion/attention/LocationAttention.cpp @@ -11,16 +11,17 @@ namespace fl::pkg::speech { SimpleLocationAttention::SimpleLocationAttention(int convKernel) { - Sequential pa; - pa.add(Conv2D(1, 1, 1, convKernel, 1, 1, -1, -1)); - pa.add(Reorder({2, 0, 1, 3})); - pa.add(ReLU()); - add(std::move(pa)); + Sequential pa; + pa.add(Conv2D(1, 1, 1, convKernel, 1, 1, -1, -1)); + pa.add(Reorder({2, 0, 1, 3})); + pa.add(ReLU()); + add(std::move(pa)); } std::unique_ptr SimpleLocationAttention::clone() const { - throw std::runtime_error( - "Cloning is unimplemented in 
Module 'SimpleLocationAttention'"); + throw std::runtime_error( + "Cloning is unimplemented in Module 'SimpleLocationAttention'" + ); } std::pair SimpleLocationAttention::forwardBase( @@ -28,58 +29,63 @@ std::pair SimpleLocationAttention::forwardBase( const Variable& xEncoded, const Variable& prevAttn, const Variable& logAttnWeight, - const Variable& xEncodedSizes) { - int U = state.dim(1); - if (U > 1) { - throw std::invalid_argument( - prettyString() + " only works on single step forward"); - } - - int T = xEncoded.dim(1); - int B = xEncoded.dim(2); - - // [1, seqlen, batchsize] - auto innerProd = matmulTN(state, xEncoded); - - if (!prevAttn.isEmpty()) { - auto addAttn = moddims( - module(0)->forward({moddims(prevAttn, {1, T, 1, B})}).front(), - {1, T, B}); - innerProd = innerProd + addAttn; - } - - if (!logAttnWeight.isEmpty()) { - if (logAttnWeight.shape() != innerProd.shape()) { - throw std::invalid_argument( - "SimpleLocationAttention: logAttnWeight has wong dimentions"); + const Variable& xEncodedSizes +) { + int U = state.dim(1); + if(U > 1) { + throw std::invalid_argument( + prettyString() + " only works on single step forward" + ); } - innerProd = innerProd + logAttnWeight; - } - if (!xEncodedSizes.isEmpty()) { - innerProd = maskAttention(innerProd, xEncodedSizes); - } - // [1, seqlen, batchsize] - auto attention = softmax(innerProd, 1); - // [hiddendim, 1, batchsize] - auto summaries = matmulNT(xEncoded, attention); - return std::make_pair(attention, summaries); + + int T = xEncoded.dim(1); + int B = xEncoded.dim(2); + + // [1, seqlen, batchsize] + auto innerProd = matmulTN(state, xEncoded); + + if(!prevAttn.isEmpty()) { + auto addAttn = moddims( + module(0)->forward({moddims(prevAttn, {1, T, 1, B})}).front(), + {1, T, B} + ); + innerProd = innerProd + addAttn; + } + + if(!logAttnWeight.isEmpty()) { + if(logAttnWeight.shape() != innerProd.shape()) { + throw std::invalid_argument( + "SimpleLocationAttention: logAttnWeight has wong dimentions" + ); + } + 
innerProd = innerProd + logAttnWeight; + } + if(!xEncodedSizes.isEmpty()) { + innerProd = maskAttention(innerProd, xEncodedSizes); + } + // [1, seqlen, batchsize] + auto attention = softmax(innerProd, 1); + // [hiddendim, 1, batchsize] + auto summaries = matmulNT(xEncoded, attention); + return std::make_pair(attention, summaries); } std::string SimpleLocationAttention::prettyString() const { - return "SimpleLocationBasedAttention"; + return "SimpleLocationBasedAttention"; } LocationAttention::LocationAttention(int encDim, int convKernel) { - Sequential pa; - pa.add(Conv2D(1, encDim, 1, convKernel, 1, 1, -1, -1)); - pa.add(Reorder({2, 0, 1, 3})); - pa.add(ReLU()); - add(std::move(pa)); + Sequential pa; + pa.add(Conv2D(1, encDim, 1, convKernel, 1, 1, -1, -1)); + pa.add(Reorder({2, 0, 1, 3})); + pa.add(ReLU()); + add(std::move(pa)); } std::unique_ptr LocationAttention::clone() const { - throw std::runtime_error( - "Cloning is unimplemented in Module 'LocationAttention'"); + throw std::runtime_error( + "Cloning is unimplemented in Module 'LocationAttention'" + ); } std::pair LocationAttention::forwardBase( @@ -87,66 +93,72 @@ std::pair LocationAttention::forwardBase( const Variable& xEncoded, const Variable& prevAttn, const Variable& logAttnWeight, - const Variable& xEncodedSizes) { - int U = state.dim(1); - if (U > 1) { - throw std::invalid_argument( - prettyString() + " only works on single step forward"); - } - - int H = xEncoded.dim(0); - int T = xEncoded.dim(1); - int B = xEncoded.dim(2); - - auto innerProd = matmulTN(state, xEncoded); - - if (!prevAttn.isEmpty()) { - auto addAttn = moddims( - module(0)->forward({moddims(prevAttn, {1, T, 1, B})}).front(), - {H, T, B}); - innerProd = innerProd + matmulTN(state, addAttn); - } - - if (!logAttnWeight.isEmpty()) { - if (logAttnWeight.shape() != innerProd.shape()) { - throw std::invalid_argument( - "LocationAttention: logAttnWeight has wong dimentions"); + const Variable& xEncodedSizes +) { + int U = state.dim(1); + 
if(U > 1) { + throw std::invalid_argument( + prettyString() + " only works on single step forward" + ); + } + + int H = xEncoded.dim(0); + int T = xEncoded.dim(1); + int B = xEncoded.dim(2); + + auto innerProd = matmulTN(state, xEncoded); + + if(!prevAttn.isEmpty()) { + auto addAttn = moddims( + module(0)->forward({moddims(prevAttn, {1, T, 1, B})}).front(), + {H, T, B} + ); + innerProd = innerProd + matmulTN(state, addAttn); } - innerProd = innerProd + logAttnWeight; - } - if (!xEncodedSizes.isEmpty()) { - innerProd = maskAttention(innerProd, xEncodedSizes); - } - // [1, seqlen, batchsize] - auto attention = softmax(innerProd, 1); - // [hiddendim, 1, batchsize] - auto summaries = matmulNT(xEncoded, attention); - return std::make_pair(attention, summaries); + + if(!logAttnWeight.isEmpty()) { + if(logAttnWeight.shape() != innerProd.shape()) { + throw std::invalid_argument( + "LocationAttention: logAttnWeight has wong dimentions" + ); + } + innerProd = innerProd + logAttnWeight; + } + if(!xEncodedSizes.isEmpty()) { + innerProd = maskAttention(innerProd, xEncodedSizes); + } + // [1, seqlen, batchsize] + auto attention = softmax(innerProd, 1); + // [hiddendim, 1, batchsize] + auto summaries = matmulNT(xEncoded, attention); + return std::make_pair(attention, summaries); } std::string LocationAttention::prettyString() const { - return "LocationBasedAttention"; + return "LocationBasedAttention"; } NeuralLocationAttention::NeuralLocationAttention( int encDim, int attnDim, int convChannel, - int convKernel) { - add(Linear(encDim, attnDim)); - add(Linear(encDim, attnDim, false)); - Sequential pa; - pa.add(Conv2D(1, convChannel, 1, convKernel, 1, 1, -1, -1)); - pa.add(Reorder({2, 0, 1, 3})); - pa.add(Linear(convChannel, attnDim, false)); - add(std::move(pa)); - add(Tanh()); - add(Linear(attnDim, 1, false)); + int convKernel +) { + add(Linear(encDim, attnDim)); + add(Linear(encDim, attnDim, false)); + Sequential pa; + pa.add(Conv2D(1, convChannel, 1, convKernel, 1, 1, -1, -1)); 
+ pa.add(Reorder({2, 0, 1, 3})); + pa.add(Linear(convChannel, attnDim, false)); + add(std::move(pa)); + add(Tanh()); + add(Linear(attnDim, 1, false)); } std::unique_ptr NeuralLocationAttention::clone() const { - throw std::runtime_error( - "Cloning is unimplemented in Module 'NeuralLocationAttention'"); + throw std::runtime_error( + "Cloning is unimplemented in Module 'NeuralLocationAttention'" + ); } std::pair NeuralLocationAttention::forwardBase( @@ -154,49 +166,53 @@ std::pair NeuralLocationAttention::forwardBase( const Variable& xEncoded, const Variable& prevAttn, const Variable& logAttnWeight, - const Variable& xEncodedSizes) { - int U = state.dim(1); - if (U > 1) { - throw std::invalid_argument( - prettyString() + " only works on single step forward"); - } - - int T = xEncoded.dim(1); - int B = xEncoded.dim(2); - - auto Hx = module(0)->forward({xEncoded}).front(); - auto tileHy = tile(module(1)->forward({state}).front(), {1, T, 1}); - - // [1, seqlen, batchsize] - auto hidden = Hx + tileHy; - if (!prevAttn.isEmpty()) { - auto addAttn = moddims( - module(2)->forward({moddims(prevAttn, {1, T, 1, B})}).front(), - {-1, T, B}); - hidden = hidden + addAttn; - } - hidden = module(3)->forward({hidden}).front(); - auto nnOut = module(4)->forward({hidden}).front(); - - if (!logAttnWeight.isEmpty()) { - if (logAttnWeight.shape() != nnOut.shape()) { - throw std::invalid_argument( - "NeuralLocationAttention: logAttnWeight has wong dimentions"); + const Variable& xEncodedSizes +) { + int U = state.dim(1); + if(U > 1) { + throw std::invalid_argument( + prettyString() + " only works on single step forward" + ); + } + + int T = xEncoded.dim(1); + int B = xEncoded.dim(2); + + auto Hx = module(0)->forward({xEncoded}).front(); + auto tileHy = tile(module(1)->forward({state}).front(), {1, T, 1}); + + // [1, seqlen, batchsize] + auto hidden = Hx + tileHy; + if(!prevAttn.isEmpty()) { + auto addAttn = moddims( + module(2)->forward({moddims(prevAttn, {1, T, 1, B})}).front(), + {-1, 
T, B} + ); + hidden = hidden + addAttn; + } + hidden = module(3)->forward({hidden}).front(); + auto nnOut = module(4)->forward({hidden}).front(); + + if(!logAttnWeight.isEmpty()) { + if(logAttnWeight.shape() != nnOut.shape()) { + throw std::invalid_argument( + "NeuralLocationAttention: logAttnWeight has wong dimentions" + ); + } + nnOut = nnOut + logAttnWeight; + } + + if(!xEncodedSizes.isEmpty()) { + nnOut = maskAttention(nnOut, xEncodedSizes); } - nnOut = nnOut + logAttnWeight; - } - - if (!xEncodedSizes.isEmpty()) { - nnOut = maskAttention(nnOut, xEncodedSizes); - } - // [1, seqlen, batchsize] - auto attention = softmax(nnOut, 1); - // [hiddendim, 1, batchsize] - auto summaries = matmulNT(xEncoded, attention); - return std::make_pair(attention, summaries); + // [1, seqlen, batchsize] + auto attention = softmax(nnOut, 1); + // [hiddendim, 1, batchsize] + auto summaries = matmulNT(xEncoded, attention); + return std::make_pair(attention, summaries); } std::string NeuralLocationAttention::prettyString() const { - return "NeuralLocationBasedAttention"; + return "NeuralLocationBasedAttention"; } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/LocationAttention.h b/flashlight/pkg/speech/criterion/attention/LocationAttention.h index 3544221..dbed65c 100644 --- a/flashlight/pkg/speech/criterion/attention/LocationAttention.h +++ b/flashlight/pkg/speech/criterion/attention/LocationAttention.h @@ -11,75 +11,79 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { -class SimpleLocationAttention : public AttentionBase { - public: - explicit SimpleLocationAttention(int convKernel); + class SimpleLocationAttention : public AttentionBase { + public: + explicit SimpleLocationAttention(int convKernel); - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::pair forwardBase( - const Variable& state, - const Variable& xEncoded, - const Variable& prevAttn, - const Variable& logAttnWeight, - const 
Variable& xEncodedSizes) override; + std::pair forwardBase( + const Variable& state, + const Variable& xEncoded, + const Variable& prevAttn, + const Variable& logAttnWeight, + const Variable& xEncodedSizes + ) override; - std::string prettyString() const override; + std::string prettyString() const override; - private: - SimpleLocationAttention() = default; + private: + SimpleLocationAttention() = default; - FL_SAVE_LOAD_WITH_BASE(AttentionBase) -}; + FL_SAVE_LOAD_WITH_BASE(AttentionBase) + }; -class LocationAttention : public AttentionBase { - public: - LocationAttention(int encDim, int convKernel); + class LocationAttention : public AttentionBase { + public: + LocationAttention(int encDim, int convKernel); - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::pair forwardBase( - const Variable& state, - const Variable& xEncoded, - const Variable& prevAttn, - const Variable& logAttnWeight, - const Variable& xEncodedSizes) override; + std::pair forwardBase( + const Variable& state, + const Variable& xEncoded, + const Variable& prevAttn, + const Variable& logAttnWeight, + const Variable& xEncodedSizes + ) override; - std::string prettyString() const override; + std::string prettyString() const override; - private: - LocationAttention() = default; + private: + LocationAttention() = default; - FL_SAVE_LOAD_WITH_BASE(AttentionBase) -}; + FL_SAVE_LOAD_WITH_BASE(AttentionBase) + }; -class NeuralLocationAttention : public AttentionBase { - public: - NeuralLocationAttention( - int encDim, - int attnDim, - int convChannel, - int convKernel); + class NeuralLocationAttention : public AttentionBase { + public: + NeuralLocationAttention( + int encDim, + int attnDim, + int convChannel, + int convKernel + ); - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::pair forwardBase( - const Variable& state, - const Variable& xEncoded, - const Variable& prevAttn, - const Variable& logAttnWeight, - const 
Variable& xEncodedSizes) override; + std::pair forwardBase( + const Variable& state, + const Variable& xEncoded, + const Variable& prevAttn, + const Variable& logAttnWeight, + const Variable& xEncodedSizes + ) override; - std::string prettyString() const override; + std::string prettyString() const override; - private: - NeuralLocationAttention() = default; + private: + NeuralLocationAttention() = default; - FL_SAVE_LOAD_WITH_BASE(AttentionBase) -}; -} // namespace speech + FL_SAVE_LOAD_WITH_BASE(AttentionBase) + }; + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/MedianWindow.cpp b/flashlight/pkg/speech/criterion/attention/MedianWindow.cpp index c394ba4..c6f392c 100644 --- a/flashlight/pkg/speech/criterion/attention/MedianWindow.cpp +++ b/flashlight/pkg/speech/criterion/attention/MedianWindow.cpp @@ -15,7 +15,8 @@ namespace fl::pkg::speech { MedianWindow::MedianWindow() = default; -MedianWindow::MedianWindow(int wL, int wR) : wL_(wL), wR_(wR) {} +MedianWindow::MedianWindow(int wL, int wR) : wL_(wL), + wR_(wR) {} Variable MedianWindow::computeWindow( const Variable& prevAttn, // [1, windowsize, batchSize] @@ -24,61 +25,74 @@ Variable MedianWindow::computeWindow( int inputSteps, int batchSize, const Tensor& inputSizes, - const Tensor& targetSizes) const { - // Each row of prevAttn is the attention for an input utterance. - // The attention vector is output from a softmax. - // The definition of "median" is the point where cdf passes 0.5. + const Tensor& targetSizes +) const { + // Each row of prevAttn is the attention for an input utterance. + // The attention vector is output from a softmax. + // The definition of "median" is the point where cdf passes 0.5. 
- int width = std::min(wL_ + wR_, inputSteps); - Tensor inputNotPaddedSize = - computeInputNotPaddedSize(inputSizes, inputSteps, batchSize, 0, false); + int width = std::min(wL_ + wR_, inputSteps); + Tensor inputNotPaddedSize = + computeInputNotPaddedSize(inputSizes, inputSteps, batchSize, 0, false); - if (step == 0 || width == inputSteps) { - // [1, inputSteps] - auto maskArray = fl::full({1, inputSteps, batchSize}, 0.0); - maskArray(fl::span, fl::range(0, width), fl::span) = 1.0; - auto indicesAdd = fl::arange({1, inputSteps, batchSize}, 1); - maskArray(indicesAdd >= fl::tile(inputNotPaddedSize, {1, inputSteps})) = - 0.0; - // [1, inputSteps, batchSize] - return Variable(fl::log(maskArray), false); - } + if(step == 0 || width == inputSteps) { + // [1, inputSteps] + auto maskArray = fl::full({1, inputSteps, batchSize}, 0.0); + maskArray(fl::span, fl::range(0, width), fl::span) = 1.0; + auto indicesAdd = fl::arange({1, inputSteps, batchSize}, 1); + maskArray(indicesAdd >= fl::tile(inputNotPaddedSize, {1, inputSteps})) = + 0.0; + // [1, inputSteps, batchSize] + return Variable(fl::log(maskArray), false); + } - auto mIdx = - fl::sum( - fl::cumsum(prevAttn.tensor(), 1) < 0.5, {1}, /* keepDims = */ true) - .astype(fl::dtype::s32); - auto startIdx = mIdx - wL_; + auto mIdx = + fl::sum( + fl::cumsum(prevAttn.tensor(), 1) < 0.5, + {1}, /* keepDims = */ + true + ) + .astype(fl::dtype::s32); + auto startIdx = mIdx - wL_; - // check boundary conditions and adjust the window - auto startDiff = fl::abs(fl::clip(startIdx, -wL_, 0)); - startIdx = startIdx + startDiff; + // check boundary conditions and adjust the window + auto startDiff = fl::abs(fl::clip(startIdx, -wL_, 0)); + startIdx = startIdx + startDiff; - auto endDiff = fl::abs( - fl::clip(startIdx + wL_ + wR_ - inputNotPaddedSize, 0, wL_ + wR_)); - startIdx = startIdx - endDiff; + auto endDiff = fl::abs( + fl::clip(startIdx + wL_ + wR_ - inputNotPaddedSize, 0, wL_ + wR_) + ); + startIdx = startIdx - endDiff; - auto 
maskArray = fl::full({1, inputSteps, batchSize}, 0.0, fl::dtype::f32); - auto indices = fl::arange({width, batchSize}, 0) + - fl::tile(fl::reshape(startIdx, {1, batchSize}), {width, 1}) + - fl::tile(fl::reshape( - fl::arange(0, batchSize * inputSteps, inputSteps), - {1, batchSize}), - {width, 1}); - maskArray(indices.flatten()) = 1.0; - auto indicesAdd = fl::arange({1, inputSteps, batchSize}, 1); - maskArray(indicesAdd >= fl::tile(inputNotPaddedSize, {1, inputSteps})) = 0.0; + auto maskArray = fl::full({1, inputSteps, batchSize}, 0.0, fl::dtype::f32); + auto indices = fl::arange({width, batchSize}, 0) + + fl::tile(fl::reshape(startIdx, {1, batchSize}), {width, 1}) + + fl::tile( + fl::reshape( + fl::arange(0, batchSize * inputSteps, inputSteps), + {1, batchSize} + ), + {width, 1} + ); + maskArray(indices.flatten()) = 1.0; + auto indicesAdd = fl::arange({1, inputSteps, batchSize}, 1); + maskArray(indicesAdd >= fl::tile(inputNotPaddedSize, {1, inputSteps})) = 0.0; - if (!targetSizes.isEmpty()) { - Tensor targetNotPaddedSize = computeTargetNotPaddedSize( - targetSizes, inputSteps, targetLen, batchSize, 1); - maskArray(step >= targetNotPaddedSize) = 0.0; - } - maskArray = fl::log(maskArray); - // force all -inf values to be kAttentionMaskValue to avoid nan in softmax - maskArray(maskArray < kAttentionMaskValue) = kAttentionMaskValue; - // [1, inputSteps, batchSize] - return Variable(maskArray, false); + if(!targetSizes.isEmpty()) { + Tensor targetNotPaddedSize = computeTargetNotPaddedSize( + targetSizes, + inputSteps, + targetLen, + batchSize, + 1 + ); + maskArray(step >= targetNotPaddedSize) = 0.0; + } + maskArray = fl::log(maskArray); + // force all -inf values to be kAttentionMaskValue to avoid nan in softmax + maskArray(maskArray < kAttentionMaskValue) = kAttentionMaskValue; + // [1, inputSteps, batchSize] + return Variable(maskArray, false); } Variable MedianWindow::computeVectorizedWindow( @@ -86,8 +100,10 @@ Variable MedianWindow::computeVectorizedWindow( int /* 
unused */, int /* unused */, const Tensor& /* unused */, - const Tensor& /* unused */) const { - throw std::invalid_argument( - "MedianWindow does not support vectorized window mask"); + const Tensor& /* unused */ +) const { + throw std::invalid_argument( + "MedianWindow does not support vectorized window mask" + ); } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/MedianWindow.h b/flashlight/pkg/speech/criterion/attention/MedianWindow.h index ea4927b..75c845f 100644 --- a/flashlight/pkg/speech/criterion/attention/MedianWindow.h +++ b/flashlight/pkg/speech/criterion/attention/MedianWindow.h @@ -11,36 +11,38 @@ namespace fl { namespace pkg { -namespace speech { - -class MedianWindow : public WindowBase { - public: - MedianWindow(); - MedianWindow(int wL, int wR); - - Variable computeWindow( - const Variable& prevAttn, - int step, - int targetLen, - int inputSteps, - int batchSize, - const Tensor& inputSizes = Tensor(), - const Tensor& targetSizes = Tensor()) const override; - - Variable computeVectorizedWindow( - int targetLen, - int inputSteps, - int batchSize, - const Tensor& inputSizes = Tensor(), - const Tensor& targetSizes = Tensor()) const override; - - private: - int wL_; - int wR_; - - FL_SAVE_LOAD_WITH_BASE(WindowBase, wL_, wR_) -}; -} // namespace speech + namespace speech { + + class MedianWindow : public WindowBase { + public: + MedianWindow(); + MedianWindow(int wL, int wR); + + Variable computeWindow( + const Variable& prevAttn, + int step, + int targetLen, + int inputSteps, + int batchSize, + const Tensor& inputSizes = Tensor(), + const Tensor& targetSizes = Tensor() + ) const override; + + Variable computeVectorizedWindow( + int targetLen, + int inputSteps, + int batchSize, + const Tensor& inputSizes = Tensor(), + const Tensor& targetSizes = Tensor() + ) const override; + + private: + int wL_; + int wR_; + + FL_SAVE_LOAD_WITH_BASE(WindowBase, wL_, wR_) + }; + } // namespace speech } // namespace pkg } // namespace fl diff 
--git a/flashlight/pkg/speech/criterion/attention/MultiHeadAttention.cpp b/flashlight/pkg/speech/criterion/attention/MultiHeadAttention.cpp index 53e92b1..32354c6 100644 --- a/flashlight/pkg/speech/criterion/attention/MultiHeadAttention.cpp +++ b/flashlight/pkg/speech/criterion/attention/MultiHeadAttention.cpp @@ -19,23 +19,26 @@ MultiHeadContentAttention::MultiHeadContentAttention( int dim, int numHeads /* = 8 */, bool keyValue /* = false */, - bool splitInput /* = false */) - : numHeads_(numHeads), keyValue_(keyValue), splitInput_(splitInput) { - if (splitInput && dim % numHeads != 0) { - throw std::invalid_argument("Invalid dimensions"); - } - - if (!splitInput) { - add(Linear(dim, dim)); // query - add(Linear(dim, dim)); // key - add(Linear(dim, dim)); // value - } - add(Linear(dim, dim)); + bool splitInput /* = false */ +) : numHeads_(numHeads), + keyValue_(keyValue), + splitInput_(splitInput) { + if(splitInput && dim % numHeads != 0) { + throw std::invalid_argument("Invalid dimensions"); + } + + if(!splitInput) { + add(Linear(dim, dim)); // query + add(Linear(dim, dim)); // key + add(Linear(dim, dim)); // value + } + add(Linear(dim, dim)); } std::unique_ptr MultiHeadContentAttention::clone() const { - throw std::runtime_error( - "Cloning is unimplemented in Module 'MultiHeadContentAttention'"); + throw std::runtime_error( + "Cloning is unimplemented in Module 'MultiHeadContentAttention'" + ); } std::pair MultiHeadContentAttention::forwardBase( @@ -43,76 +46,81 @@ std::pair MultiHeadContentAttention::forwardBase( const Variable& xEncoded, const Variable& /* unused */, const Variable& logAttnWeight, - const Variable& xEncodedSizes) { - if (state.ndim() != 3) { - throw std::invalid_argument( - "MultiHeadContentAttention::forwardBase: " - "state input must be of shape {H, U, B}"); - } - int hEncode = xEncoded.dim(0); - int T = xEncoded.dim(1); - int hState = state.dim(0); - int U = state.dim(1); - int B = state.dim(2); - auto hiddenDim = hState / numHeads_; - if 
(hEncode != (1 + keyValue_) * hState) { - throw std::invalid_argument("Invalid input encoder dimension"); - } - - auto xEncodedKey = keyValue_ - ? xEncoded(fl::arange(0, hEncode / 2), fl::span, fl::span) - : xEncoded; - auto xEncodedValue = keyValue_ - ? xEncoded(fl::arange(hEncode / 2, hEncode), fl::span, fl::span) - : xEncoded; - - auto query = splitInput_ ? state : module(0)->forward({state})[0]; - auto key = splitInput_ ? xEncodedKey : module(1)->forward({xEncodedKey})[0]; - auto value = - splitInput_ ? xEncodedValue : module(2)->forward({xEncodedValue})[0]; - - query = - moddims(fl::transpose(query, {1, 0, 2}), {U, hiddenDim, B * numHeads_}); - key = moddims(fl::transpose(key, {1, 0, 2}), {T, hiddenDim, B * numHeads_}); - value = - moddims(fl::transpose(value, {1, 0, 2}), {T, hiddenDim, B * numHeads_}); - - // [U, T, B * numHeads_] - auto innerProd = - matmulNT(query, key) / std::sqrt(static_cast(hiddenDim)); - - if (!logAttnWeight.isEmpty()) { - auto tiledLogAttnWeight = tile(logAttnWeight, {1, 1, numHeads_}); - if (tiledLogAttnWeight.shape() != innerProd.shape()) { - throw std::invalid_argument( - "MultiHeadContentAttention: logAttnWeight has wong dimentions"); + const Variable& xEncodedSizes +) { + if(state.ndim() != 3) { + throw std::invalid_argument( + "MultiHeadContentAttention::forwardBase: " + "state input must be of shape {H, U, B}" + ); } - innerProd = innerProd + tiledLogAttnWeight; - } - - if (!xEncodedSizes.isEmpty()) { - innerProd = maskAttention( - innerProd, - moddims(tile(xEncodedSizes, {numHeads_, 1}), {1, B * numHeads_})); - } - - // [U, T, B * numHeads_] - auto attention = softmax(innerProd, 1); - // [U, hiddendim, B * numHeads_] - auto summaries = matmul(attention, value); - // [hiddendim * numHeads_, U, B]; - summaries = reorder(moddims(summaries, {U, hState, B}), {1, 0, 2}); - - auto out_summaries = modules().back()->forward({summaries}).front(); - - // [U * numHeads_, T, B] - attention = moddims( - reorder(moddims(attention, {U, T, 
numHeads_, B}), {0, 2, 1, 3}), - {U * numHeads_, T, B}); - return std::make_pair(attention, out_summaries); + int hEncode = xEncoded.dim(0); + int T = xEncoded.dim(1); + int hState = state.dim(0); + int U = state.dim(1); + int B = state.dim(2); + auto hiddenDim = hState / numHeads_; + if(hEncode != (1 + keyValue_) * hState) { + throw std::invalid_argument("Invalid input encoder dimension"); + } + + auto xEncodedKey = keyValue_ + ? xEncoded(fl::arange(0, hEncode / 2), fl::span, fl::span) + : xEncoded; + auto xEncodedValue = keyValue_ + ? xEncoded(fl::arange(hEncode / 2, hEncode), fl::span, fl::span) + : xEncoded; + + auto query = splitInput_ ? state : module(0)->forward({state})[0]; + auto key = splitInput_ ? xEncodedKey : module(1)->forward({xEncodedKey})[0]; + auto value = + splitInput_ ? xEncodedValue : module(2)->forward({xEncodedValue})[0]; + + query = + moddims(fl::transpose(query, {1, 0, 2}), {U, hiddenDim, B * numHeads_}); + key = moddims(fl::transpose(key, {1, 0, 2}), {T, hiddenDim, B * numHeads_}); + value = + moddims(fl::transpose(value, {1, 0, 2}), {T, hiddenDim, B * numHeads_}); + + // [U, T, B * numHeads_] + auto innerProd = + matmulNT(query, key) / std::sqrt(static_cast(hiddenDim)); + + if(!logAttnWeight.isEmpty()) { + auto tiledLogAttnWeight = tile(logAttnWeight, {1, 1, numHeads_}); + if(tiledLogAttnWeight.shape() != innerProd.shape()) { + throw std::invalid_argument( + "MultiHeadContentAttention: logAttnWeight has wong dimentions" + ); + } + innerProd = innerProd + tiledLogAttnWeight; + } + + if(!xEncodedSizes.isEmpty()) { + innerProd = maskAttention( + innerProd, + moddims(tile(xEncodedSizes, {numHeads_, 1}), {1, B * numHeads_}) + ); + } + + // [U, T, B * numHeads_] + auto attention = softmax(innerProd, 1); + // [U, hiddendim, B * numHeads_] + auto summaries = matmul(attention, value); + // [hiddendim * numHeads_, U, B]; + summaries = reorder(moddims(summaries, {U, hState, B}), {1, 0, 2}); + + auto out_summaries = 
modules().back()->forward({summaries}).front(); + + // [U * numHeads_, T, B] + attention = moddims( + reorder(moddims(attention, {U, T, numHeads_, B}), {0, 2, 1, 3}), + {U* numHeads_, T, B} + ); + return std::make_pair(attention, out_summaries); } std::string MultiHeadContentAttention::prettyString() const { - return "MultiHeadContentAttention"; + return "MultiHeadContentAttention"; } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/MultiHeadAttention.h b/flashlight/pkg/speech/criterion/attention/MultiHeadAttention.h index 0b69ef7..a54ad79 100644 --- a/flashlight/pkg/speech/criterion/attention/MultiHeadAttention.h +++ b/flashlight/pkg/speech/criterion/attention/MultiHeadAttention.h @@ -11,34 +11,36 @@ namespace fl { namespace pkg { -namespace speech { - -class MultiHeadContentAttention : public AttentionBase { - public: - MultiHeadContentAttention() {} - explicit MultiHeadContentAttention( - int dim, - int num_heads = 8, - bool keyValue = false, - bool splitInput = false); - std::unique_ptr clone() const override; - - std::pair forwardBase( - const Variable& state, - const Variable& xEncoded, - const Variable& prevAttn, - const Variable& logAttnWeight, - const Variable& xEncodedSizes) override; - - std::string prettyString() const override; - - private: - int numHeads_; - bool keyValue_; - bool splitInput_; - FL_SAVE_LOAD_WITH_BASE(AttentionBase, numHeads_, keyValue_, splitInput_) -}; -} // namespace speech + namespace speech { + + class MultiHeadContentAttention : public AttentionBase { + public: + MultiHeadContentAttention() {} + explicit MultiHeadContentAttention( + int dim, + int num_heads = 8, + bool keyValue = false, + bool splitInput = false + ); + std::unique_ptr clone() const override; + + std::pair forwardBase( + const Variable& state, + const Variable& xEncoded, + const Variable& prevAttn, + const Variable& logAttnWeight, + const Variable& xEncodedSizes + ) override; + + std::string prettyString() const override; + + private: + 
int numHeads_; + bool keyValue_; + bool splitInput_; + FL_SAVE_LOAD_WITH_BASE(AttentionBase, numHeads_, keyValue_, splitInput_) + }; + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/SoftPretrainWindow.cpp b/flashlight/pkg/speech/criterion/attention/SoftPretrainWindow.cpp index cd45025..719da8f 100644 --- a/flashlight/pkg/speech/criterion/attention/SoftPretrainWindow.cpp +++ b/flashlight/pkg/speech/criterion/attention/SoftPretrainWindow.cpp @@ -21,32 +21,46 @@ Variable SoftPretrainWindow::compute( int batchSize, const Tensor& inputSizes, const Tensor& targetSizes, - Tensor& decoderSteps) const { - int decoderStepsDim = decoderSteps.dim(0); - auto ts = fl::arange({decoderStepsDim, inputSteps, batchSize}, 1); - if (inputSizes.isEmpty() && targetSizes.isEmpty()) { - return Variable( - -fl::power(ts - inputSteps / targetLen * decoderSteps, 2) / - (2 * std_ * std_), - false); - } + Tensor& decoderSteps +) const { + int decoderStepsDim = decoderSteps.dim(0); + auto ts = fl::arange({decoderStepsDim, inputSteps, batchSize}, 1); + if(inputSizes.isEmpty() && targetSizes.isEmpty()) { + return Variable( + -fl::power(ts - inputSteps / targetLen * decoderSteps, 2) + / (2 * std_ * std_), + false + ); + } - Tensor inputNotPaddedSize = computeInputNotPaddedSize( - inputSizes, inputSteps, batchSize, decoderStepsDim, true); - Tensor targetNotPaddedSize = computeTargetNotPaddedSize( - targetSizes, inputSteps, targetLen, batchSize, decoderStepsDim); + Tensor inputNotPaddedSize = computeInputNotPaddedSize( + inputSizes, + inputSteps, + batchSize, + decoderStepsDim, + true + ); + Tensor targetNotPaddedSize = computeTargetNotPaddedSize( + targetSizes, + inputSteps, + targetLen, + batchSize, + decoderStepsDim + ); - auto maskArray = - -fl::power( - ts - inputNotPaddedSize / targetNotPaddedSize * decoderSteps, 2) / - (2 * std_ * std_); - maskArray(ts >= inputNotPaddedSize) = -std::numeric_limits::infinity(); - 
maskArray(decoderSteps >= targetNotPaddedSize) = - -std::numeric_limits::infinity(); - // force all -inf values to be kAttentionMaskValue to avoid nan in softmax - maskArray(maskArray < kAttentionMaskValue) = kAttentionMaskValue; - // [decoderStepsDim, inputSteps, batchSize] - return Variable(maskArray, false); + auto maskArray = + -fl::power( + ts - inputNotPaddedSize / targetNotPaddedSize * decoderSteps, + 2 + ) + / (2 * std_ * std_); + maskArray(ts >= inputNotPaddedSize) = -std::numeric_limits::infinity(); + maskArray(decoderSteps >= targetNotPaddedSize) = + -std::numeric_limits::infinity(); + // force all -inf values to be kAttentionMaskValue to avoid nan in softmax + maskArray(maskArray < kAttentionMaskValue) = kAttentionMaskValue; + // [decoderStepsDim, inputSteps, batchSize] + return Variable(maskArray, false); } Variable SoftPretrainWindow::computeWindow( @@ -56,10 +70,17 @@ Variable SoftPretrainWindow::computeWindow( int inputSteps, int batchSize, const Tensor& inputSizes, - const Tensor& targetSizes) const { - Tensor decoderSteps = fl::full({1, inputSteps, batchSize}, step); - return compute( - targetLen, inputSteps, batchSize, inputSizes, targetSizes, decoderSteps); + const Tensor& targetSizes +) const { + Tensor decoderSteps = fl::full({1, inputSteps, batchSize}, step); + return compute( + targetLen, + inputSteps, + batchSize, + inputSizes, + targetSizes, + decoderSteps + ); } Variable SoftPretrainWindow::computeVectorizedWindow( @@ -67,9 +88,16 @@ Variable SoftPretrainWindow::computeVectorizedWindow( int inputSteps, int batchSize, const Tensor& inputSizes, - const Tensor& targetSizes) const { - Tensor decoderSteps = fl::arange({targetLen, inputSteps, batchSize}, 0); - return compute( - targetLen, inputSteps, batchSize, inputSizes, targetSizes, decoderSteps); + const Tensor& targetSizes +) const { + Tensor decoderSteps = fl::arange({targetLen, inputSteps, batchSize}, 0); + return compute( + targetLen, + inputSteps, + batchSize, + inputSizes, + 
targetSizes, + decoderSteps + ); } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/SoftPretrainWindow.h b/flashlight/pkg/speech/criterion/attention/SoftPretrainWindow.h index d172b53..66a5fd0 100644 --- a/flashlight/pkg/speech/criterion/attention/SoftPretrainWindow.h +++ b/flashlight/pkg/speech/criterion/attention/SoftPretrainWindow.h @@ -11,44 +11,47 @@ namespace fl { namespace pkg { -namespace speech { - -class SoftPretrainWindow : public WindowBase { - public: - explicit SoftPretrainWindow(double std); - - Variable computeWindow( - const Variable& prevAttn, - int step, - int targetLen, - int inputSteps, - int batchSize, - const Tensor& inputSizes = Tensor(), - const Tensor& targetSizes = Tensor()) const override; - - Variable computeVectorizedWindow( - int targetLen, - int inputSteps, - int batchSize, - const Tensor& inputSizes = Tensor(), - const Tensor& targetSizes = Tensor()) const override; - - private: - SoftPretrainWindow() = default; - - double std_; - - Variable compute( - int targetLen, - int inputSteps, - int batchSize, - const Tensor& inputSizes, - const Tensor& targetSizes, - Tensor& decoderSteps) const; - - FL_SAVE_LOAD_WITH_BASE(WindowBase, std_) -}; -} // namespace speech + namespace speech { + + class SoftPretrainWindow : public WindowBase { + public: + explicit SoftPretrainWindow(double std); + + Variable computeWindow( + const Variable& prevAttn, + int step, + int targetLen, + int inputSteps, + int batchSize, + const Tensor& inputSizes = Tensor(), + const Tensor& targetSizes = Tensor() + ) const override; + + Variable computeVectorizedWindow( + int targetLen, + int inputSteps, + int batchSize, + const Tensor& inputSizes = Tensor(), + const Tensor& targetSizes = Tensor() + ) const override; + + private: + SoftPretrainWindow() = default; + + double std_; + + Variable compute( + int targetLen, + int inputSteps, + int batchSize, + const Tensor& inputSizes, + const Tensor& targetSizes, + Tensor& decoderSteps + ) const; + + 
FL_SAVE_LOAD_WITH_BASE(WindowBase, std_) + }; + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/SoftWindow.cpp b/flashlight/pkg/speech/criterion/attention/SoftWindow.cpp index 36af8fe..00ad2b0 100644 --- a/flashlight/pkg/speech/criterion/attention/SoftWindow.cpp +++ b/flashlight/pkg/speech/criterion/attention/SoftWindow.cpp @@ -13,8 +13,9 @@ namespace fl::pkg::speech { SoftWindow::SoftWindow() = default; -SoftWindow::SoftWindow(double std, double avgRate, int offset) - : std_(std), avgRate_(avgRate), offset_(offset) {} +SoftWindow::SoftWindow(double std, double avgRate, int offset) : std_(std), + avgRate_(avgRate), + offset_(offset) {} Variable SoftWindow::compute( int targetLen, @@ -22,24 +23,39 @@ Variable SoftWindow::compute( int batchSize, const Tensor& inputSizes, const Tensor& targetSizes, - Tensor& decoderSteps) const { - int decoderStepsDim = decoderSteps.dim(0); - auto ts = fl::arange({decoderStepsDim, inputSteps, batchSize}, 1); - Tensor inputNotPaddedSize = computeInputNotPaddedSize( - inputSizes, inputSteps, batchSize, decoderStepsDim, true); + Tensor& decoderSteps +) const { + int decoderStepsDim = decoderSteps.dim(0); + auto ts = fl::arange({decoderStepsDim, inputSteps, batchSize}, 1); + Tensor inputNotPaddedSize = computeInputNotPaddedSize( + inputSizes, + inputSteps, + batchSize, + decoderStepsDim, + true + ); - Tensor centers = fl::rint(fl::minimum( - offset_ + decoderSteps * avgRate_, inputNotPaddedSize - avgRate_)); - auto maskArray = -fl::power(ts - centers, 2) / (2 * std_ * std_); - maskArray(ts >= inputNotPaddedSize) = -std::numeric_limits::infinity(); + Tensor centers = fl::rint( + fl::minimum( + offset_ + decoderSteps * avgRate_, + inputNotPaddedSize - avgRate_ + ) + ); + auto maskArray = -fl::power(ts - centers, 2) / (2 * std_ * std_); + maskArray(ts >= inputNotPaddedSize) = -std::numeric_limits::infinity(); - if (!targetSizes.isEmpty()) { - Tensor targetNotPaddedSize = 
computeTargetNotPaddedSize( - targetSizes, inputSteps, targetLen, batchSize, decoderStepsDim); - maskArray(decoderSteps >= targetNotPaddedSize) = kAttentionMaskValue; - } - // [decoderStepsDim, inputSteps, batchSize] - return Variable(maskArray, false); + if(!targetSizes.isEmpty()) { + Tensor targetNotPaddedSize = computeTargetNotPaddedSize( + targetSizes, + inputSteps, + targetLen, + batchSize, + decoderStepsDim + ); + maskArray(decoderSteps >= targetNotPaddedSize) = kAttentionMaskValue; + } + // [decoderStepsDim, inputSteps, batchSize] + return Variable(maskArray, false); } Variable SoftWindow::computeWindow( @@ -49,10 +65,17 @@ Variable SoftWindow::computeWindow( int inputSteps, int batchSize, const Tensor& inputSizes, - const Tensor& targetSizes) const { - Tensor decoderSteps = fl::full({1, inputSteps, batchSize}, step); - return compute( - targetLen, inputSteps, batchSize, inputSizes, targetSizes, decoderSteps); + const Tensor& targetSizes +) const { + Tensor decoderSteps = fl::full({1, inputSteps, batchSize}, step); + return compute( + targetLen, + inputSteps, + batchSize, + inputSizes, + targetSizes, + decoderSteps + ); } Variable SoftWindow::computeVectorizedWindow( @@ -60,9 +83,16 @@ Variable SoftWindow::computeVectorizedWindow( int inputSteps, int batchSize, const Tensor& inputSizes, - const Tensor& targetSizes) const { - Tensor decoderSteps = fl::arange({targetLen, inputSteps, batchSize}, 0); - return compute( - targetLen, inputSteps, batchSize, inputSizes, targetSizes, decoderSteps); + const Tensor& targetSizes +) const { + Tensor decoderSteps = fl::arange({targetLen, inputSteps, batchSize}, 0); + return compute( + targetLen, + inputSteps, + batchSize, + inputSizes, + targetSizes, + decoderSteps + ); } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/SoftWindow.h b/flashlight/pkg/speech/criterion/attention/SoftWindow.h index 5477d63..f4326ba 100644 --- a/flashlight/pkg/speech/criterion/attention/SoftWindow.h +++ 
b/flashlight/pkg/speech/criterion/attention/SoftWindow.h @@ -11,45 +11,48 @@ namespace fl { namespace pkg { -namespace speech { - -class SoftWindow : public WindowBase { - public: - SoftWindow(); - SoftWindow(double std, double avgRate, int offset); - - Variable computeWindow( - const Variable& prevAttn, - int step, - int targetLen, - int inputSteps, - int batchSize, - const Tensor& inputSizes = Tensor(), - const Tensor& targetSizes = Tensor()) const override; - - Variable computeVectorizedWindow( - int targetLen, - int inputSteps, - int batchSize, - const Tensor& inputSizes = Tensor(), - const Tensor& targetSizes = Tensor()) const override; - - private: - Variable compute( - int targetLen, - int inputSteps, - int batchSize, - const Tensor& inputSizes, - const Tensor& targetSizes, - Tensor& decoderSteps) const; - - double std_; - double avgRate_; - int offset_; - - FL_SAVE_LOAD_WITH_BASE(WindowBase, std_, avgRate_, offset_) -}; -} // namespace speech + namespace speech { + + class SoftWindow : public WindowBase { + public: + SoftWindow(); + SoftWindow(double std, double avgRate, int offset); + + Variable computeWindow( + const Variable& prevAttn, + int step, + int targetLen, + int inputSteps, + int batchSize, + const Tensor& inputSizes = Tensor(), + const Tensor& targetSizes = Tensor() + ) const override; + + Variable computeVectorizedWindow( + int targetLen, + int inputSteps, + int batchSize, + const Tensor& inputSizes = Tensor(), + const Tensor& targetSizes = Tensor() + ) const override; + + private: + Variable compute( + int targetLen, + int inputSteps, + int batchSize, + const Tensor& inputSizes, + const Tensor& targetSizes, + Tensor& decoderSteps + ) const; + + double std_; + double avgRate_; + int offset_; + + FL_SAVE_LOAD_WITH_BASE(WindowBase, std_, avgRate_, offset_) + }; + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/StepWindow.cpp 
b/flashlight/pkg/speech/criterion/attention/StepWindow.cpp index 22c0e41..eebaf6b 100644 --- a/flashlight/pkg/speech/criterion/attention/StepWindow.cpp +++ b/flashlight/pkg/speech/criterion/attention/StepWindow.cpp @@ -13,8 +13,10 @@ namespace fl::pkg::speech { StepWindow::StepWindow() = default; -StepWindow::StepWindow(int sMin, int sMax, double vMin, double vMax) - : sMin_(sMin), sMax_(sMax), vMin_(vMin), vMax_(vMax) {} +StepWindow::StepWindow(int sMin, int sMax, double vMin, double vMax) : sMin_(sMin), + sMax_(sMax), + vMin_(vMin), + vMax_(vMax) {} Variable StepWindow::compute( int targetLen, @@ -22,34 +24,48 @@ Variable StepWindow::compute( int batchSize, const Tensor& inputSizes, const Tensor& targetSizes, - Tensor& decoderSteps) const { - int decoderStepsDim = decoderSteps.dim(0); - Tensor inputNotPaddedSize = computeInputNotPaddedSize( - inputSizes, inputSteps, batchSize, decoderStepsDim, true); - Tensor startIdx = fl::maximum( - 0, - fl::rint( - fl::minimum(inputNotPaddedSize - vMax_, sMin_ + decoderSteps * vMin_)) - .astype(fl::dtype::s32)); - auto endIdx = fl::minimum( - inputNotPaddedSize, - fl::rint(sMax_ + decoderSteps * vMax_).astype(fl::dtype::s32)); - Tensor indices = - fl::iota({1, inputSteps, 1}, {decoderStepsDim, 1, batchSize}); + Tensor& decoderSteps +) const { + int decoderStepsDim = decoderSteps.dim(0); + Tensor inputNotPaddedSize = computeInputNotPaddedSize( + inputSizes, + inputSteps, + batchSize, + decoderStepsDim, + true + ); + Tensor startIdx = fl::maximum( + 0, + fl::rint( + fl::minimum(inputNotPaddedSize - vMax_, sMin_ + decoderSteps * vMin_) + ) + .astype(fl::dtype::s32) + ); + auto endIdx = fl::minimum( + inputNotPaddedSize, + fl::rint(sMax_ + decoderSteps * vMax_).astype(fl::dtype::s32) + ); + Tensor indices = + fl::iota({1, inputSteps, 1}, {decoderStepsDim, 1, batchSize}); - // [decoderStepsDim, inputSteps, batchSize] - Tensor maskTensor = fl::full({decoderStepsDim, inputSteps, batchSize}, 1.0); - maskTensor(indices < startIdx) = 
0.0; - maskTensor(indices >= endIdx) = 0.0; - if (!targetSizes.isEmpty()) { - Tensor targetNotPaddedSize = computeTargetNotPaddedSize( - targetSizes, inputSteps, targetLen, batchSize, decoderStepsDim); - maskTensor(decoderSteps >= targetNotPaddedSize) = 0.0; - } - // force all -inf values to be kAttentionMaskValue to avoid nan in softmax - maskTensor = fl::log(maskTensor); - maskTensor(maskTensor < kAttentionMaskValue) = kAttentionMaskValue; - return Variable(maskTensor, false); + // [decoderStepsDim, inputSteps, batchSize] + Tensor maskTensor = fl::full({decoderStepsDim, inputSteps, batchSize}, 1.0); + maskTensor(indices < startIdx) = 0.0; + maskTensor(indices >= endIdx) = 0.0; + if(!targetSizes.isEmpty()) { + Tensor targetNotPaddedSize = computeTargetNotPaddedSize( + targetSizes, + inputSteps, + targetLen, + batchSize, + decoderStepsDim + ); + maskTensor(decoderSteps >= targetNotPaddedSize) = 0.0; + } + // force all -inf values to be kAttentionMaskValue to avoid nan in softmax + maskTensor = fl::log(maskTensor); + maskTensor(maskTensor < kAttentionMaskValue) = kAttentionMaskValue; + return Variable(maskTensor, false); } Variable StepWindow::computeWindow( @@ -59,10 +75,17 @@ Variable StepWindow::computeWindow( int inputSteps, int batchSize, const Tensor& inputSizes, - const Tensor& targetSizes) const { - auto decoderSteps = fl::full({1, inputSteps, batchSize}, step); - return compute( - targetLen, inputSteps, batchSize, inputSizes, targetSizes, decoderSteps); + const Tensor& targetSizes +) const { + auto decoderSteps = fl::full({1, inputSteps, batchSize}, step); + return compute( + targetLen, + inputSteps, + batchSize, + inputSizes, + targetSizes, + decoderSteps + ); } Variable StepWindow::computeVectorizedWindow( @@ -70,9 +93,16 @@ Variable StepWindow::computeVectorizedWindow( int inputSteps, int batchSize, const Tensor& inputSizes, - const Tensor& targetSizes) const { - auto decoderSteps = fl::iota({targetLen}, {1, inputSteps, batchSize}); - return compute( - 
targetLen, inputSteps, batchSize, inputSizes, targetSizes, decoderSteps); + const Tensor& targetSizes +) const { + auto decoderSteps = fl::iota({targetLen}, {1, inputSteps, batchSize}); + return compute( + targetLen, + inputSteps, + batchSize, + inputSizes, + targetSizes, + decoderSteps + ); } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/StepWindow.h b/flashlight/pkg/speech/criterion/attention/StepWindow.h index 194ae97..30e2b5d 100644 --- a/flashlight/pkg/speech/criterion/attention/StepWindow.h +++ b/flashlight/pkg/speech/criterion/attention/StepWindow.h @@ -11,46 +11,49 @@ namespace fl { namespace pkg { -namespace speech { - -class StepWindow : public WindowBase { - public: - StepWindow(); - StepWindow(int sMin, int sMax, double vMin, double vMax); - - Variable computeWindow( - const Variable& prevAttn, - int step, - int targetLen, - int inputSteps, - int batchSize, - const Tensor& inputSizes = Tensor(), - const Tensor& targetSizes = Tensor()) const override; - - Variable computeVectorizedWindow( - int targetLen, - int inputSteps, - int batchSize, - const Tensor& inputSizes = Tensor(), - const Tensor& targetSizes = Tensor()) const override; - - private: - int sMin_; - int sMax_; - double vMin_; - double vMax_; - - Variable compute( - int targetLen, - int inputSteps, - int batchSize, - const Tensor& inputSizes, - const Tensor& targetSizes, - Tensor& decoderSteps) const; - - FL_SAVE_LOAD_WITH_BASE(WindowBase, sMin_, sMax_, vMin_, vMax_) -}; -} // namespace speech + namespace speech { + + class StepWindow : public WindowBase { + public: + StepWindow(); + StepWindow(int sMin, int sMax, double vMin, double vMax); + + Variable computeWindow( + const Variable& prevAttn, + int step, + int targetLen, + int inputSteps, + int batchSize, + const Tensor& inputSizes = Tensor(), + const Tensor& targetSizes = Tensor() + ) const override; + + Variable computeVectorizedWindow( + int targetLen, + int inputSteps, + int batchSize, + const Tensor& 
inputSizes = Tensor(), + const Tensor& targetSizes = Tensor() + ) const override; + + private: + int sMin_; + int sMax_; + double vMin_; + double vMax_; + + Variable compute( + int targetLen, + int inputSteps, + int batchSize, + const Tensor& inputSizes, + const Tensor& targetSizes, + Tensor& decoderSteps + ) const; + + FL_SAVE_LOAD_WITH_BASE(WindowBase, sMin_, sMax_, vMin_, vMax_) + }; + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/Utils.cpp b/flashlight/pkg/speech/criterion/attention/Utils.cpp index d7baab4..a0fd19d 100644 --- a/flashlight/pkg/speech/criterion/attention/Utils.cpp +++ b/flashlight/pkg/speech/criterion/attention/Utils.cpp @@ -14,25 +14,25 @@ namespace fl::pkg::speech { Variable maskAttention(const Variable& input, const Variable& sizes) { - int B = input.dim(2); - int T = input.dim(1); - // xEncodedSizes is (1, B) size - Tensor inputNotPaddedSize = - fl::ceil(sizes.tensor() / fl::amax(sizes.tensor()).asScalar() * T); - Tensor padMask = - fl::iota({T, 1}, {1, B}) >= fl::tile(inputNotPaddedSize, {T, 1}); - padMask = fl::tile(fl::reshape(padMask, {1, T, B}), {input.dim(0), 1, 1}); + int B = input.dim(2); + int T = input.dim(1); + // xEncodedSizes is (1, B) size + Tensor inputNotPaddedSize = + fl::ceil(sizes.tensor() / fl::amax(sizes.tensor()).asScalar() * T); + Tensor padMask = + fl::iota({T, 1}, {1, B}) >= fl::tile(inputNotPaddedSize, {T, 1}); + padMask = fl::tile(fl::reshape(padMask, {1, T, B}), {input.dim(0), 1, 1}); - Tensor output = input.tensor(); - output(padMask) = kAttentionMaskValue; + Tensor output = input.tensor(); + output(padMask) = kAttentionMaskValue; - auto gradFunc = - [padMask](std::vector& inputs, const Variable& gradOutput) { - Tensor gradArray = gradOutput.tensor(); - gradArray(padMask) = 0.; - inputs[0].addGrad(Variable(gradArray, false)); - }; - return Variable(output, {input.withoutData()}, gradFunc); + auto gradFunc = + [padMask](std::vector& inputs, const 
Variable& gradOutput) { + Tensor gradArray = gradOutput.tensor(); + gradArray(padMask) = 0.; + inputs[0].addGrad(Variable(gradArray, false)); + }; + return Variable(output, {input.withoutData()}, gradFunc); } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/Utils.h b/flashlight/pkg/speech/criterion/attention/Utils.h index 146871c..eac4f5a 100644 --- a/flashlight/pkg/speech/criterion/attention/Utils.h +++ b/flashlight/pkg/speech/criterion/attention/Utils.h @@ -11,11 +11,12 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { -fl::Variable maskAttention( - const fl::Variable& input, - const fl::Variable& sizes); -} // namespace speech + fl::Variable maskAttention( + const fl::Variable& input, + const fl::Variable& sizes + ); + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/WindowBase.cpp b/flashlight/pkg/speech/criterion/attention/WindowBase.cpp index 3546e14..81b0958 100644 --- a/flashlight/pkg/speech/criterion/attention/WindowBase.cpp +++ b/flashlight/pkg/speech/criterion/attention/WindowBase.cpp @@ -14,27 +14,33 @@ Tensor WindowBase::computeInputNotPaddedSize( int inputSteps, int batchSize, int decoderStepsDim, - bool doTile) const { - if (inputSizes.isEmpty()) { - if (doTile) { - return fl::full( - {decoderStepsDim, inputSteps, batchSize}, inputSteps, fl::dtype::f32); - } else { - return fl::full({1, 1, batchSize}, inputSteps, fl::dtype::f32); + bool doTile +) const { + if(inputSizes.isEmpty()) { + if(doTile) { + return fl::full( + {decoderStepsDim, inputSteps, batchSize}, + inputSteps, + fl::dtype::f32 + ); + } else { + return fl::full({1, 1, batchSize}, inputSteps, fl::dtype::f32); + } } - } - if (inputSizes.elements() != batchSize) { - throw std::runtime_error( - "Attention Window: wrong size of the input sizes vector, doesn't match with batchsize"); - } - Tensor inputNotPaddedSize = fl::ceil( - inputSizes / fl::amax(inputSizes).asScalar() * 
inputSteps); - inputNotPaddedSize = fl::reshape(inputNotPaddedSize, {1, 1, batchSize}); - if (doTile) { - inputNotPaddedSize = - fl::tile(inputNotPaddedSize, {decoderStepsDim, inputSteps, 1}); - } - return inputNotPaddedSize; + if(inputSizes.elements() != batchSize) { + throw std::runtime_error( + "Attention Window: wrong size of the input sizes vector, doesn't match with batchsize" + ); + } + Tensor inputNotPaddedSize = fl::ceil( + inputSizes / fl::amax(inputSizes).asScalar() * inputSteps + ); + inputNotPaddedSize = fl::reshape(inputNotPaddedSize, {1, 1, batchSize}); + if(doTile) { + inputNotPaddedSize = + fl::tile(inputNotPaddedSize, {decoderStepsDim, inputSteps, 1}); + } + return inputNotPaddedSize; } Tensor WindowBase::computeTargetNotPaddedSize( @@ -42,22 +48,29 @@ Tensor WindowBase::computeTargetNotPaddedSize( int inputSteps, int targetLen, int batchSize, - int decoderStepsDim) const { - if (targetSizes.isEmpty()) { - return fl::full( - {decoderStepsDim, inputSteps, batchSize}, targetLen, fl::dtype::f32); - } - if (targetSizes.elements() != batchSize) { - throw std::runtime_error( - "Window Attention: wrong size of the target sizes vector, doesn't match with batchsize"); - } - Tensor targetNotPaddedSize = fl::reshape( - fl::ceil( - targetSizes / fl::amax(targetSizes).asScalar() * targetLen), - {1, 1, batchSize}); - targetNotPaddedSize = - fl::tile(targetNotPaddedSize, {decoderStepsDim, inputSteps, 1}); - return targetNotPaddedSize; + int decoderStepsDim +) const { + if(targetSizes.isEmpty()) { + return fl::full( + {decoderStepsDim, inputSteps, batchSize}, + targetLen, + fl::dtype::f32 + ); + } + if(targetSizes.elements() != batchSize) { + throw std::runtime_error( + "Window Attention: wrong size of the target sizes vector, doesn't match with batchsize" + ); + } + Tensor targetNotPaddedSize = fl::reshape( + fl::ceil( + targetSizes / fl::amax(targetSizes).asScalar() * targetLen + ), + {1, 1, batchSize} + ); + targetNotPaddedSize = + 
fl::tile(targetNotPaddedSize, {decoderStepsDim, inputSteps, 1}); + return targetNotPaddedSize; } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/attention/WindowBase.h b/flashlight/pkg/speech/criterion/attention/WindowBase.h index f251701..e953f82 100644 --- a/flashlight/pkg/speech/criterion/attention/WindowBase.h +++ b/flashlight/pkg/speech/criterion/attention/WindowBase.h @@ -11,7 +11,7 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { /** * Pretraining window class which defines attention mask @@ -19,90 +19,94 @@ namespace speech { * thus we can have good proxy for attention between encoder-decoder * right at the beginning */ -class WindowBase { - public: - WindowBase() {} + class WindowBase { + public: + WindowBase() {} - /** - * Compute window for the data for particular output step using - * @param prevAttn previous step attention - * @param step decoder step - * @param inputSteps encoder output / decoder input length (max in the batch) - * @param batchSize batch size - * @param inputSizes actual encoder output / decoder input sizes (even before - * the encoder, we need only the proportions); can be empty which means all - * are treated to be the same size, size is 1xB - * @param targetSizes actual decoder target output sizes (excluding padding); - * can be empty - */ - virtual Variable computeWindow( - const Variable& prevAttn, - int step, - int targetLen, - int inputSteps, - int batchSize, - const Tensor& inputSizes = Tensor(), - const Tensor& targetSizes = Tensor()) const = 0; + /** + * Compute window for the data for particular output step using + * @param prevAttn previous step attention + * @param step decoder step + * @param inputSteps encoder output / decoder input length (max in the batch) + * @param batchSize batch size + * @param inputSizes actual encoder output / decoder input sizes (even before + * the encoder, we need only the proportions); can be empty which means all + * are treated to be the same size, 
size is 1xB + * @param targetSizes actual decoder target output sizes (excluding padding); + * can be empty + */ + virtual Variable computeWindow( + const Variable& prevAttn, + int step, + int targetLen, + int inputSteps, + int batchSize, + const Tensor& inputSizes = Tensor(), + const Tensor& targetSizes = Tensor() + ) const = 0; - /** - * Compute window for the data for entire decoder known target size - * @param targetLen target size (max in the batch) - * @param inputSteps encoder output / decoder input length (max in the batch) - * @param batchSize batch size - * @param inputSizes actual encoder output / decoder input sizes (even before - * the encoder, we need only the proportions); can be empty which means all - * are treated to be the same size, size is 1xB - * @param targetSizes actual decoder target output sizes (excluding padding); - * can be empty - */ - virtual Variable computeVectorizedWindow( - int targetLen, - int inputSteps, - int batchSize, - const Tensor& inputSizes = Tensor(), - const Tensor& targetSizes = Tensor()) const = 0; + /** + * Compute window for the data for entire decoder known target size + * @param targetLen target size (max in the batch) + * @param inputSteps encoder output / decoder input length (max in the batch) + * @param batchSize batch size + * @param inputSizes actual encoder output / decoder input sizes (even before + * the encoder, we need only the proportions); can be empty which means all + * are treated to be the same size, size is 1xB + * @param targetSizes actual decoder target output sizes (excluding padding); + * can be empty + */ + virtual Variable computeVectorizedWindow( + int targetLen, + int inputSteps, + int batchSize, + const Tensor& inputSizes = Tensor(), + const Tensor& targetSizes = Tensor() + ) const = 0; - virtual ~WindowBase() {} + virtual ~WindowBase() {} - protected: - /** - * Compute necessary matrix to process the padding later from the input sizes - * @param inputSizes actual encoder output / 
decoder input sizes (even before - * the encoder, we need only the proportions); can be empty which means all - * are treated to be the same size, size is 1xB - * @param inputSteps encoder output / decoder input length (max in the batch) - * @param batchSize batch size - * @param decoderStepsDim max decoder steps - * @param doTile Do necessary tile to (decoderStepsDim, inputSteps, BatchSize) - * or return (1, 1, BatchSize) vector (depends on the window we need to use) - */ - Tensor computeInputNotPaddedSize( - const Tensor& inputSizes, - int inputSteps, - int batchSize, - int decoderStepsDim, - bool doTile) const; + protected: + /** + * Compute necessary matrix to process the padding later from the input sizes + * @param inputSizes actual encoder output / decoder input sizes (even before + * the encoder, we need only the proportions); can be empty which means all + * are treated to be the same size, size is 1xB + * @param inputSteps encoder output / decoder input length (max in the batch) + * @param batchSize batch size + * @param decoderStepsDim max decoder steps + * @param doTile Do necessary tile to (decoderStepsDim, inputSteps, BatchSize) + * or return (1, 1, BatchSize) vector (depends on the window we need to use) + */ + Tensor computeInputNotPaddedSize( + const Tensor& inputSizes, + int inputSteps, + int batchSize, + int decoderStepsDim, + bool doTile + ) const; - /** - * Compute necessary matrix to process the padding later from the target sizes - * @param targetSizes actual decoder target output sizes (excluding padding); - * can be empty - * @param inputSteps encoder output / decoder input length (max in the batch) - * @param targetLen target size (max in the batch) - * @param batchSize batch size - * @param decoderStepsDim max decoder steps - * @return A tensor with shape {decoderStepsDim, inputSteps, batchSize} - */ - Tensor computeTargetNotPaddedSize( - const Tensor& targetSizes, - int inputSteps, - int targetLen, - int batchSize, - int decoderStepsDim) 
const; + /** + * Compute necessary matrix to process the padding later from the target sizes + * @param targetSizes actual decoder target output sizes (excluding padding); + * can be empty + * @param inputSteps encoder output / decoder input length (max in the batch) + * @param targetLen target size (max in the batch) + * @param batchSize batch size + * @param decoderStepsDim max decoder steps + * @return A tensor with shape {decoderStepsDim, inputSteps, batchSize} + */ + Tensor computeTargetNotPaddedSize( + const Tensor& targetSizes, + int inputSteps, + int targetLen, + int batchSize, + int decoderStepsDim + ) const; - private: - FL_SAVE_LOAD() -}; -} // namespace speech + private: + FL_SAVE_LOAD() + }; + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/backend/cpu/ConnectionistTemporalClassificationCriterion.cpp b/flashlight/pkg/speech/criterion/backend/cpu/ConnectionistTemporalClassificationCriterion.cpp index ba867e2..3826908 100644 --- a/flashlight/pkg/speech/criterion/backend/cpu/ConnectionistTemporalClassificationCriterion.cpp +++ b/flashlight/pkg/speech/criterion/backend/cpu/ConnectionistTemporalClassificationCriterion.cpp @@ -17,214 +17,244 @@ using CriterionUtils = fl::lib::cpu::CriterionUtils; namespace fl { namespace pkg { -namespace speech { - -std::vector ConnectionistTemporalClassificationCriterion::forward( - const std::vector& inputs) { - if (inputs.size() != 2) { - throw std::invalid_argument("Invalid inputs size"); - } - const auto& input = inputs[0]; - const auto& target = inputs[1]; - validate(input, target); - auto logprobs = logSoftmax(input, 0); - - std::vector> batchAlphas; - std::vector batchLoss; - std::vector batchScales; - std::vector batchTargetSizes; - { - const int64_t N = logprobs.dim(0); - const int64_t T = logprobs.dim(1); - const int64_t B = logprobs.dim(2); - const int64_t batchL = target.dim(0); - - batchAlphas.resize(B); - batchLoss.resize(B); - batchScales.resize(B); - 
batchTargetSizes.resize(B); - - // get host pointers - std::vector batchInputVec(logprobs.elements()); - logprobs.host(batchInputVec.data()); - - std::vector batchTargetVec(target.elements()); - target.host(batchTargetVec.data()); - - CriterionUtils::batchTargetSize( - B, batchL, batchL, batchTargetVec.data(), batchTargetSizes.data()); - - CriterionUtils::computeScale( - B, T, N, scaleMode_, batchTargetSizes.data(), batchScales.data()); + namespace speech { + + std::vector ConnectionistTemporalClassificationCriterion::forward( + const std::vector& inputs + ) { + if(inputs.size() != 2) { + throw std::invalid_argument("Invalid inputs size"); + } + const auto& input = inputs[0]; + const auto& target = inputs[1]; + validate(input, target); + auto logprobs = logSoftmax(input, 0); + + std::vector> batchAlphas; + std::vector batchLoss; + std::vector batchScales; + std::vector batchTargetSizes; + { + const int64_t N = logprobs.dim(0); + const int64_t T = logprobs.dim(1); + const int64_t B = logprobs.dim(2); + const int64_t batchL = target.dim(0); + + batchAlphas.resize(B); + batchLoss.resize(B); + batchScales.resize(B); + batchTargetSizes.resize(B); + + // get host pointers + std::vector batchInputVec(logprobs.elements()); + logprobs.host(batchInputVec.data()); + + std::vector batchTargetVec(target.elements()); + target.host(batchTargetVec.data()); + + CriterionUtils::batchTargetSize( + B, + batchL, + batchL, + batchTargetVec.data(), + batchTargetSizes.data() + ); + + CriterionUtils::computeScale( + B, + T, + N, + scaleMode_, + batchTargetSizes.data(), + batchScales.data() + ); #pragma omp parallel for num_threads(B) - for (int64_t b = 0; b < B; ++b) { - const float* inputVec = batchInputVec.data() + b * N * T; - const int* targetVec = batchTargetVec.data() + b * batchL; - - int64_t L = batchTargetSizes[b]; - const int64_t S = 2 * L + 1; - int64_t R = fl::pkg::speech::countRepeats(targetVec, L); - - // A heuristic to modify target length to be able to compute CTC loss - L 
= std::min(L + R, T) - R; - R = fl::pkg::speech::countRepeats( - targetVec, L); // Recompute repeats as L has changed - - auto& alphas = batchAlphas[b]; - alphas.resize(T * S, NEG_INFINITY_FLT); - - int64_t start = (T - (L + R)) > 0 ? 0 : 1; - int64_t end = (S == 1) ? 1 : 2; - - // base case - alphas[0] = (start == 0) ? inputVec[N - 1] : NEG_INFINITY_FLT; - if (S != 1) { - alphas[1] = inputVec[targetVec[0]]; - } - for (int64_t t = 1; t < T; ++t) { - // At each time frame t, only few states can be reached depending - // on the labels, their ordering and the current time frame. - if (T - t <= L + R) { - if (start & 1 && targetVec[start / 2] != targetVec[start / 2 + 1]) { - ++start; - } - ++start; - } - if (t <= L + R) { - if (end % 2 == 0 && end < 2 * L && - (targetVec[end / 2 - 1] != targetVec[end / 2])) { - ++end; - } - ++end; - } - // Use dynamic programming to recursively compute alphas - for (int64_t s = start; s < end; ++s) { - int64_t ts = t * S + s; - int64_t curLabel = t * N + ((s & 1) ? targetVec[s / 2] : N - 1); - if (s == 0) { - alphas[ts] = alphas[ts - S]; - } else if ( - (s % 2 == 0) || s == 1 || - targetVec[s / 2] == targetVec[s / 2 - 1]) { - alphas[ts] = - fl::pkg::speech::logSumExp(alphas[ts - S], alphas[ts - S - 1]); - } else { - alphas[ts] = fl::pkg::speech::logSumExp( - alphas[ts - S], alphas[ts - S - 1], alphas[ts - S - 2]); - } - alphas[ts] += inputVec[curLabel]; - } - } - batchLoss[b] = -fl::pkg::speech::logSumExp( - alphas.end()[-1], - (S == 1) ? 
NEG_INFINITY_FLT : alphas.end()[-2]) * - batchScales[b]; - } - } - auto result = Tensor::fromVector(batchLoss); - - auto gradFunc = [batchAlphas, batchScales, batchTargetSizes]( - std::vector& moduleInputs, - const Variable& gradOutput) { - const int64_t N = moduleInputs[0].dim(0); - const int64_t T = moduleInputs[0].dim(1); - const int64_t B = moduleInputs[0].dim(2); - const int64_t batchL = moduleInputs[1].dim(0); - - std::vector batchInGrad(moduleInputs[0].elements(), 0.0); - - std::vector batchTargetVec(moduleInputs[1].elements()); - moduleInputs[1].host(batchTargetVec.data()); - - std::vector batchOutGrad(gradOutput.elements()); - gradOutput.host(batchOutGrad.data()); + for(int64_t b = 0; b < B; ++b) { + const float* inputVec = batchInputVec.data() + b * N * T; + const int* targetVec = batchTargetVec.data() + b * batchL; + + int64_t L = batchTargetSizes[b]; + const int64_t S = 2 * L + 1; + int64_t R = fl::pkg::speech::countRepeats(targetVec, L); + + // A heuristic to modify target length to be able to compute CTC loss + L = std::min(L + R, T) - R; + R = fl::pkg::speech::countRepeats( + targetVec, + L + ); // Recompute repeats as L has changed + + auto& alphas = batchAlphas[b]; + alphas.resize(T * S, NEG_INFINITY_FLT); + + int64_t start = (T - (L + R)) > 0 ? 0 : 1; + int64_t end = (S == 1) ? 1 : 2; + + // base case + alphas[0] = (start == 0) ? inputVec[N - 1] : NEG_INFINITY_FLT; + if(S != 1) { + alphas[1] = inputVec[targetVec[0]]; + } + for(int64_t t = 1; t < T; ++t) { + // At each time frame t, only few states can be reached depending + // on the labels, their ordering and the current time frame. 
+ if(T - t <= L + R) { + if(start & 1 && targetVec[start / 2] != targetVec[start / 2 + 1]) { + ++start; + } + ++start; + } + if(t <= L + R) { + if( + end % 2 == 0 && end < 2 * L + && (targetVec[end / 2 - 1] != targetVec[end / 2]) + ) { + ++end; + } + ++end; + } + // Use dynamic programming to recursively compute alphas + for(int64_t s = start; s < end; ++s) { + int64_t ts = t * S + s; + int64_t curLabel = t * N + ((s & 1) ? targetVec[s / 2] : N - 1); + if(s == 0) { + alphas[ts] = alphas[ts - S]; + } else if( + (s % 2 == 0) || s == 1 + || targetVec[s / 2] == targetVec[s / 2 - 1] + ) { + alphas[ts] = + fl::pkg::speech::logSumExp(alphas[ts - S], alphas[ts - S - 1]); + } else { + alphas[ts] = fl::pkg::speech::logSumExp( + alphas[ts - S], + alphas[ts - S - 1], + alphas[ts - S - 2] + ); + } + alphas[ts] += inputVec[curLabel]; + } + } + batchLoss[b] = -fl::pkg::speech::logSumExp( + alphas.end()[-1], + (S == 1) ? NEG_INFINITY_FLT : alphas.end()[-2] + ) + * batchScales[b]; + } + } + auto result = Tensor::fromVector(batchLoss); + + auto gradFunc = [batchAlphas, batchScales, batchTargetSizes]( + std::vector& moduleInputs, + const Variable& gradOutput) { + const int64_t N = moduleInputs[0].dim(0); + const int64_t T = moduleInputs[0].dim(1); + const int64_t B = moduleInputs[0].dim(2); + const int64_t batchL = moduleInputs[1].dim(0); + + std::vector batchInGrad(moduleInputs[0].elements(), 0.0); + + std::vector batchTargetVec(moduleInputs[1].elements()); + moduleInputs[1].host(batchTargetVec.data()); + + std::vector batchOutGrad(gradOutput.elements()); + gradOutput.host(batchOutGrad.data()); #pragma omp parallel for num_threads(B) - for (int64_t b = 0; b < B; ++b) { - const int* targetVec = batchTargetVec.data() + b * batchL; - float* grad = batchInGrad.data() + b * N * T; - - int64_t L = batchTargetSizes[b]; - - L = std::min(L, T); - const int64_t R = fl::pkg::speech::countRepeats(targetVec, L); - L = std::min(L + R, T) - R; - - const int64_t S = 2 * L + 1; - const auto& alphas 
= batchAlphas[b]; - - int64_t start = (S == 1) ? S : S - 1; - int64_t end = S; - std::vector dAlphas(T * S, 0.0); - - // Compute dAlphas for the last timeframe - if (S == 1) { - dAlphas[T * S - 1] = -1.0; - } else { - fl::pkg::speech::dLogSumExp( - alphas[T * S - 2], - alphas[T * S - 1], - dAlphas[T * S - 2], - dAlphas[T * S - 1], - -1.0); - } - float gradScale = batchOutGrad[b] * batchScales[b]; - - for (int64_t t = T - 1; t >= 0; --t) { - // Compute start and end values at time (t) similar to calculation - // of alpha in CTC forward pass - if (T - t <= L + R + 1) { - if (start & 1 && start > 1 && - targetVec[start / 2] != targetVec[start / 2 - 1]) { - --start; - } - --start; - } - if (t < L + R) { - if (end % 2 == 0 && - (targetVec[end / 2 - 1] != targetVec[end / 2 - 2])) { - --end; - } - --end; - } - // Compute grad and dAlphas for (t-1)th frame using chain rule - for (int64_t s = start; s < end; ++s) { - int64_t ts = t * S + s; - int64_t curLabel = t * N + ((s & 1) ? targetVec[s / 2] : N - 1); - grad[curLabel] += dAlphas[ts] * gradScale; - if (t == 0) { - continue; - } - if (s == 0) { - dAlphas[ts - S] += dAlphas[ts]; - } else if ( - (s % 2 == 0) || s == 1 || - targetVec[s / 2] == targetVec[s / 2 - 1]) { - fl::pkg::speech::dLogSumExp( - alphas[ts - S], - alphas[ts - S - 1], - dAlphas[ts - S], - dAlphas[ts - S - 1], - dAlphas[ts]); - } else { - fl::pkg::speech::dLogSumExp( - alphas[ts - S], - alphas[ts - S - 1], - alphas[ts - S - 2], - dAlphas[ts - S], - dAlphas[ts - S - 1], - dAlphas[ts - S - 2], - dAlphas[ts]); - } + for(int64_t b = 0; b < B; ++b) { + const int* targetVec = batchTargetVec.data() + b * batchL; + float* grad = batchInGrad.data() + b * N * T; + + int64_t L = batchTargetSizes[b]; + + L = std::min(L, T); + const int64_t R = fl::pkg::speech::countRepeats(targetVec, L); + L = std::min(L + R, T) - R; + + const int64_t S = 2 * L + 1; + const auto& alphas = batchAlphas[b]; + + int64_t start = (S == 1) ? 
S : S - 1; + int64_t end = S; + std::vector dAlphas(T * S, 0.0); + + // Compute dAlphas for the last timeframe + if(S == 1) { + dAlphas[T * S - 1] = -1.0; + } else { + fl::pkg::speech::dLogSumExp( + alphas[T * S - 2], + alphas[T * S - 1], + dAlphas[T * S - 2], + dAlphas[T * S - 1], + -1.0 + ); + } + float gradScale = batchOutGrad[b] * batchScales[b]; + + for(int64_t t = T - 1; t >= 0; --t) { + // Compute start and end values at time (t) similar to calculation + // of alpha in CTC forward pass + if(T - t <= L + R + 1) { + if( + start & 1 && start > 1 + && targetVec[start / 2] != targetVec[start / 2 - 1] + ) { + --start; + } + --start; + } + if(t < L + R) { + if( + end % 2 == 0 + && (targetVec[end / 2 - 1] != targetVec[end / 2 - 2]) + ) { + --end; + } + --end; + } + // Compute grad and dAlphas for (t-1)th frame using chain rule + for(int64_t s = start; s < end; ++s) { + int64_t ts = t * S + s; + int64_t curLabel = t * N + ((s & 1) ? targetVec[s / 2] : N - 1); + grad[curLabel] += dAlphas[ts] * gradScale; + if(t == 0) { + continue; + } + if(s == 0) { + dAlphas[ts - S] += dAlphas[ts]; + } else if( + (s % 2 == 0) || s == 1 + || targetVec[s / 2] == targetVec[s / 2 - 1] + ) { + fl::pkg::speech::dLogSumExp( + alphas[ts - S], + alphas[ts - S - 1], + dAlphas[ts - S], + dAlphas[ts - S - 1], + dAlphas[ts] + ); + } else { + fl::pkg::speech::dLogSumExp( + alphas[ts - S], + alphas[ts - S - 1], + alphas[ts - S - 2], + dAlphas[ts - S], + dAlphas[ts - S - 1], + dAlphas[ts - S - 2], + dAlphas[ts] + ); + } + } + } + } + moduleInputs[0].addGrad( + Variable(Tensor::fromVector({N, T, B}, batchInGrad), false) + ); + }; + return {Variable(result, {logprobs, target}, gradFunc)}; } - } - } - moduleInputs[0].addGrad( - Variable(Tensor::fromVector({N, T, B}, batchInGrad), false)); - }; - return {Variable(result, {logprobs, target}, gradFunc)}; -} -} // namespace speech + } // namespace speech } // namespace pkg } // namespace fl diff --git 
a/flashlight/pkg/speech/criterion/backend/cpu/CriterionUtils.cpp b/flashlight/pkg/speech/criterion/backend/cpu/CriterionUtils.cpp index c97ef83..b0a0784 100644 --- a/flashlight/pkg/speech/criterion/backend/cpu/CriterionUtils.cpp +++ b/flashlight/pkg/speech/criterion/backend/cpu/CriterionUtils.cpp @@ -15,50 +15,56 @@ using ViterbiPath = fl::lib::cpu::ViterbiPath; namespace fl { namespace pkg { -namespace speech { + namespace speech { -Tensor viterbiPath(const Tensor& input, const Tensor& trans) { - auto B = input.dim(2); - auto T = input.dim(1); - auto N = input.dim(0); + Tensor viterbiPath(const Tensor& input, const Tensor& trans) { + auto B = input.dim(2); + auto T = input.dim(1); + auto N = input.dim(0); - if (N != trans.dim(0) || N != trans.dim(1)) { - throw std::invalid_argument("viterbiPath: mismatched dims"); - } else if (input.type() != fl::dtype::f32) { - throw std::invalid_argument("viterbiPath: input must be float32"); - } else if (trans.type() != fl::dtype::f32) { - throw std::invalid_argument("viterbiPath: trans must be float32"); - } + if(N != trans.dim(0) || N != trans.dim(1)) { + throw std::invalid_argument("viterbiPath: mismatched dims"); + } else if(input.type() != fl::dtype::f32) { + throw std::invalid_argument("viterbiPath: input must be float32"); + } else if(trans.type() != fl::dtype::f32) { + throw std::invalid_argument("viterbiPath: trans must be float32"); + } - auto inputVec = input.toHostVector(); - auto transVec = trans.toHostVector(); - std::vector pathVec(B * T); - std::vector workspaceVec(ViterbiPath::getWorkspaceSize(B, T, N)); + auto inputVec = input.toHostVector(); + auto transVec = trans.toHostVector(); + std::vector pathVec(B * T); + std::vector workspaceVec(ViterbiPath::getWorkspaceSize(B, T, N)); - ViterbiPath::compute( - B, - T, - N, - inputVec.data(), - transVec.data(), - pathVec.data(), - workspaceVec.data()); + ViterbiPath::compute( + B, + T, + N, + inputVec.data(), + transVec.data(), + pathVec.data(), + workspaceVec.data() 
+ ); - return Tensor::fromVector({T, B}, pathVec); -} + return Tensor::fromVector({T, B}, pathVec); + } -Tensor getTargetSizeArray(const Tensor& target, int maxSize) { - int B = target.dim(1); - int L = target.dim(0); + Tensor getTargetSizeArray(const Tensor& target, int maxSize) { + int B = target.dim(1); + int L = target.dim(0); - auto targetVec = target.toHostVector(); - std::vector targetSizeVec(B); + auto targetVec = target.toHostVector(); + std::vector targetSizeVec(B); - CriterionUtils::batchTargetSize( - B, L, maxSize, targetVec.data(), targetSizeVec.data()); + CriterionUtils::batchTargetSize( + B, + L, + maxSize, + targetVec.data(), + targetSizeVec.data() + ); - return Tensor::fromVector(targetSizeVec); -} -} // namespace speech + return Tensor::fromVector(targetSizeVec); + } + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/backend/cpu/ForceAlignmentCriterion.cpp b/flashlight/pkg/speech/criterion/backend/cpu/ForceAlignmentCriterion.cpp index 08f654c..ea23972 100644 --- a/flashlight/pkg/speech/criterion/backend/cpu/ForceAlignmentCriterion.cpp +++ b/flashlight/pkg/speech/criterion/backend/cpu/ForceAlignmentCriterion.cpp @@ -16,139 +16,149 @@ using FAC = fl::lib::cpu::ForceAlignmentCriterion; namespace { // By passing shared_ptr we avoid copies from forward to backward. 
struct Context { - std::vector targetVec; - std::vector targetSizeVec; - std::vector workspaceVec; + std::vector targetVec; + std::vector targetSizeVec; + std::vector workspaceVec; }; } // namespace namespace fl { namespace pkg { -namespace speech { - -static void backward( - std::vector& inputs, - const Variable& gradVar, - int B, - int T, - int N, - int L, - const std::shared_ptr& ctx) { - if (gradVar.type() != fl::dtype::f32) { - throw std::invalid_argument("FAC: grad must be float32"); - } - - auto gradVec = gradVar.tensor().toHostVector(); - std::vector inputGradVec(B * T * N); - std::vector transGradVec(N * N); - - FAC::backward( - B, - T, - N, - L, - ctx->targetVec.data(), - ctx->targetSizeVec.data(), - gradVec.data(), - inputGradVec.data(), - transGradVec.data(), - ctx->workspaceVec.data()); - - auto inputGrad = Tensor::fromVector({N, T, B}, inputGradVec); - auto transGrad = Tensor::fromVector({N, N}, transGradVec); - - inputs[0].addGrad(Variable(inputGrad, false)); - inputs[1].addGrad(Variable(transGrad, false)); -} - -Variable ForceAlignmentCriterion::forward( - const Variable& inputVar, - const Variable& targetVar) { - const auto& transVar = param(0); - int B = inputVar.dim(2); - int T = inputVar.dim(1); - int N = inputVar.dim(0); - int L = targetVar.dim(0); - - if (N != transVar.dim(0)) { - throw std::invalid_argument( - "ForceAlignmentCriterion(cpu)::forward: input dim doesn't match N"); - } else if (inputVar.type() != fl::dtype::f32) { - throw std::invalid_argument( - "ForceAlignmentCriterion(cpu)::forward: input must be float32"); - } else if (targetVar.type() != fl::dtype::s32) { - throw std::invalid_argument( - "ForceAlignmentCriterion(cpu)::forward: target must be int32"); - } - - const auto& targetSize = getTargetSizeArray(targetVar.tensor(), T); - auto ctx = std::make_shared(); - auto inputVec = inputVar.tensor().toHostVector(); - ctx->targetVec = targetVar.tensor().toHostVector(); - ctx->targetSizeVec = targetSize.toHostVector(); - auto 
transVec = transVar.tensor().toHostVector(); - std::vector lossVec(B); - ctx->workspaceVec.assign(FAC::getWorkspaceSize(B, T, N, L), 0); - - FAC::forward( - B, - T, - N, - L, - scaleMode_, - inputVec.data(), - ctx->targetVec.data(), - ctx->targetSizeVec.data(), - transVec.data(), - lossVec.data(), - ctx->workspaceVec.data()); - - return Variable( - Tensor::fromVector(lossVec), - {inputVar.withoutData(), transVar.withoutData()}, - [=](std::vector& inputs, const Variable& gradVar) { - backward(inputs, gradVar, B, T, N, L, ctx); - }); -} - -Tensor ForceAlignmentCriterion::viterbiPath( - const Tensor& input, - const Tensor& target) { - const Tensor& trans = param(0).tensor(); - int N = input.dim(0); // Number of output tokens - int T = input.dim(1); // Utterance length - int B = input.dim(2); // Batchsize - int L = target.dim(0); // Target length - - if (N != trans.dim(0)) { - throw std::invalid_argument("FAC: input dim doesn't match N:"); - } else if (input.type() != fl::dtype::f32) { - throw std::invalid_argument("FAC: input must be float32"); - } else if (target.type() != fl::dtype::s32) { - throw std::invalid_argument("FAC: target must be int32"); - } - const Tensor targetSize = getTargetSizeArray(target, T); - std::shared_ptr ctx = std::make_shared(); - std::vector inputVec = input.toHostVector(); - ctx->targetVec = target.toHostVector(); - ctx->targetSizeVec = targetSize.toHostVector(); - std::vector transVec = trans.toHostVector(); - std::vector lossVec(B); - ctx->workspaceVec.assign(FAC::getWorkspaceSize(B, T, N, L), 0); - std::vector bestPaths(B * T); - FAC::viterbi( - B, - T, - N, - L, - inputVec.data(), - ctx->targetVec.data(), - ctx->targetSizeVec.data(), - transVec.data(), - bestPaths.data(), - ctx->workspaceVec.data()); - return Tensor::fromVector({T, B}, bestPaths); -} -} // namespace speech + namespace speech { + + static void backward( + std::vector& inputs, + const Variable& gradVar, + int B, + int T, + int N, + int L, + const std::shared_ptr& ctx + ) 
{ + if(gradVar.type() != fl::dtype::f32) { + throw std::invalid_argument("FAC: grad must be float32"); + } + + auto gradVec = gradVar.tensor().toHostVector(); + std::vector inputGradVec(B * T * N); + std::vector transGradVec(N * N); + + FAC::backward( + B, + T, + N, + L, + ctx->targetVec.data(), + ctx->targetSizeVec.data(), + gradVec.data(), + inputGradVec.data(), + transGradVec.data(), + ctx->workspaceVec.data() + ); + + auto inputGrad = Tensor::fromVector({N, T, B}, inputGradVec); + auto transGrad = Tensor::fromVector({N, N}, transGradVec); + + inputs[0].addGrad(Variable(inputGrad, false)); + inputs[1].addGrad(Variable(transGrad, false)); + } + + Variable ForceAlignmentCriterion::forward( + const Variable& inputVar, + const Variable& targetVar + ) { + const auto& transVar = param(0); + int B = inputVar.dim(2); + int T = inputVar.dim(1); + int N = inputVar.dim(0); + int L = targetVar.dim(0); + + if(N != transVar.dim(0)) { + throw std::invalid_argument( + "ForceAlignmentCriterion(cpu)::forward: input dim doesn't match N" + ); + } else if(inputVar.type() != fl::dtype::f32) { + throw std::invalid_argument( + "ForceAlignmentCriterion(cpu)::forward: input must be float32" + ); + } else if(targetVar.type() != fl::dtype::s32) { + throw std::invalid_argument( + "ForceAlignmentCriterion(cpu)::forward: target must be int32" + ); + } + + const auto& targetSize = getTargetSizeArray(targetVar.tensor(), T); + auto ctx = std::make_shared(); + auto inputVec = inputVar.tensor().toHostVector(); + ctx->targetVec = targetVar.tensor().toHostVector(); + ctx->targetSizeVec = targetSize.toHostVector(); + auto transVec = transVar.tensor().toHostVector(); + std::vector lossVec(B); + ctx->workspaceVec.assign(FAC::getWorkspaceSize(B, T, N, L), 0); + + FAC::forward( + B, + T, + N, + L, + scaleMode_, + inputVec.data(), + ctx->targetVec.data(), + ctx->targetSizeVec.data(), + transVec.data(), + lossVec.data(), + ctx->workspaceVec.data() + ); + + return Variable( + Tensor::fromVector(lossVec), + 
{inputVar.withoutData(), transVar.withoutData()}, + [ = ](std::vector& inputs, const Variable& gradVar) { + backward(inputs, gradVar, B, T, N, L, ctx); + } + ); + } + + Tensor ForceAlignmentCriterion::viterbiPath( + const Tensor& input, + const Tensor& target + ) { + const Tensor& trans = param(0).tensor(); + int N = input.dim(0); // Number of output tokens + int T = input.dim(1); // Utterance length + int B = input.dim(2); // Batchsize + int L = target.dim(0); // Target length + + if(N != trans.dim(0)) { + throw std::invalid_argument("FAC: input dim doesn't match N:"); + } else if(input.type() != fl::dtype::f32) { + throw std::invalid_argument("FAC: input must be float32"); + } else if(target.type() != fl::dtype::s32) { + throw std::invalid_argument("FAC: target must be int32"); + } + const Tensor targetSize = getTargetSizeArray(target, T); + std::shared_ptr ctx = std::make_shared(); + std::vector inputVec = input.toHostVector(); + ctx->targetVec = target.toHostVector(); + ctx->targetSizeVec = targetSize.toHostVector(); + std::vector transVec = trans.toHostVector(); + std::vector lossVec(B); + ctx->workspaceVec.assign(FAC::getWorkspaceSize(B, T, N, L), 0); + std::vector bestPaths(B * T); + FAC::viterbi( + B, + T, + N, + L, + inputVec.data(), + ctx->targetVec.data(), + ctx->targetSizeVec.data(), + transVec.data(), + bestPaths.data(), + ctx->workspaceVec.data() + ); + return Tensor::fromVector({T, B}, bestPaths); + } + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/backend/cpu/FullConnectionCriterion.cpp b/flashlight/pkg/speech/criterion/backend/cpu/FullConnectionCriterion.cpp index a99c93f..7394b4c 100644 --- a/flashlight/pkg/speech/criterion/backend/cpu/FullConnectionCriterion.cpp +++ b/flashlight/pkg/speech/criterion/backend/cpu/FullConnectionCriterion.cpp @@ -17,90 +17,95 @@ using FCC = fl::lib::cpu::FullConnectionCriterion; namespace { // By passing shared_ptr we avoid copies from forward to backward. 
struct Context { - std::vector transVec; - std::vector workspaceVec; + std::vector transVec; + std::vector workspaceVec; }; } // namespace namespace fl { namespace pkg { -namespace speech { - -static void backward( - std::vector& inputs, - const Variable& gradVar, - int B, - int T, - int N, - const std::shared_ptr& ctx) { - if (gradVar.type() != fl::dtype::f32) { - throw std::invalid_argument("FCC: grad must be float32"); - } - - auto gradVec = gradVar.tensor().toHostVector(); - std::vector inputGradVec(B * T * N); - std::vector transGradVec(N * N); - - FCC::backward( - B, - T, - N, - ctx->transVec.data(), - gradVec.data(), - inputGradVec.data(), - transGradVec.data(), - ctx->workspaceVec.data()); - - Tensor inputGrad = Tensor::fromVector({N, T, B}, inputGradVec); - Tensor transGrad = Tensor::fromVector({N, N}, transGradVec); - - inputs[0].addGrad(Variable(inputGrad, false)); - inputs[1].addGrad(Variable(transGrad, false)); -} - -Variable FullConnectionCriterion::forward( - const Variable& inputVar, - const Variable& targetVar) { - const auto& transVar = param(0); - int B = inputVar.dim(2); - int T = inputVar.dim(1); - int N = inputVar.dim(0); - - if (N != transVar.dim(0)) { - throw std::invalid_argument("FCC: input dim doesn't match N"); - } else if (inputVar.type() != fl::dtype::f32) { - throw std::invalid_argument("FCC: input must be float32"); - } else if (targetVar.type() != fl::dtype::s32) { - throw std::invalid_argument("FCC: target must be int32"); - } - - const auto& targetSize = getTargetSizeArray(targetVar.tensor(), T); - auto ctx = std::make_shared(); - auto inputVec = inputVar.tensor().toHostVector(); - auto targetVec = targetVar.tensor().toHostVector(); - auto targetSizeVec = targetSize.toHostVector(); - ctx->transVec = transVar.tensor().toHostVector(); - std::vector lossVec(B); - ctx->workspaceVec.assign(FCC::getWorkspaceSize(B, T, N), 0); - - FCC::forward( - B, - T, - N, - scaleMode_, - inputVec.data(), - targetSizeVec.data(), - 
ctx->transVec.data(), - lossVec.data(), - ctx->workspaceVec.data()); - - return Variable( - Tensor::fromVector(lossVec), - {inputVar.withoutData(), transVar.withoutData()}, - [=](std::vector& inputs, const Variable& gradVar) mutable { - backward(inputs, gradVar, B, T, N, ctx); - }); -} -} // namespace speech + namespace speech { + + static void backward( + std::vector& inputs, + const Variable& gradVar, + int B, + int T, + int N, + const std::shared_ptr& ctx + ) { + if(gradVar.type() != fl::dtype::f32) { + throw std::invalid_argument("FCC: grad must be float32"); + } + + auto gradVec = gradVar.tensor().toHostVector(); + std::vector inputGradVec(B * T * N); + std::vector transGradVec(N * N); + + FCC::backward( + B, + T, + N, + ctx->transVec.data(), + gradVec.data(), + inputGradVec.data(), + transGradVec.data(), + ctx->workspaceVec.data() + ); + + Tensor inputGrad = Tensor::fromVector({N, T, B}, inputGradVec); + Tensor transGrad = Tensor::fromVector({N, N}, transGradVec); + + inputs[0].addGrad(Variable(inputGrad, false)); + inputs[1].addGrad(Variable(transGrad, false)); + } + + Variable FullConnectionCriterion::forward( + const Variable& inputVar, + const Variable& targetVar + ) { + const auto& transVar = param(0); + int B = inputVar.dim(2); + int T = inputVar.dim(1); + int N = inputVar.dim(0); + + if(N != transVar.dim(0)) { + throw std::invalid_argument("FCC: input dim doesn't match N"); + } else if(inputVar.type() != fl::dtype::f32) { + throw std::invalid_argument("FCC: input must be float32"); + } else if(targetVar.type() != fl::dtype::s32) { + throw std::invalid_argument("FCC: target must be int32"); + } + + const auto& targetSize = getTargetSizeArray(targetVar.tensor(), T); + auto ctx = std::make_shared(); + auto inputVec = inputVar.tensor().toHostVector(); + auto targetVec = targetVar.tensor().toHostVector(); + auto targetSizeVec = targetSize.toHostVector(); + ctx->transVec = transVar.tensor().toHostVector(); + std::vector lossVec(B); + 
ctx->workspaceVec.assign(FCC::getWorkspaceSize(B, T, N), 0); + + FCC::forward( + B, + T, + N, + scaleMode_, + inputVec.data(), + targetSizeVec.data(), + ctx->transVec.data(), + lossVec.data(), + ctx->workspaceVec.data() + ); + + return Variable( + Tensor::fromVector(lossVec), + {inputVar.withoutData(), transVar.withoutData()}, + [ = ](std::vector& inputs, const Variable& gradVar) mutable { + backward(inputs, gradVar, B, T, N, ctx); + } + ); + } + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/criterion/backend/cuda/ConnectionistTemporalClassificationCriterion.cpp b/flashlight/pkg/speech/criterion/backend/cuda/ConnectionistTemporalClassificationCriterion.cpp index 3dcb0f7..98f3b81 100644 --- a/flashlight/pkg/speech/criterion/backend/cuda/ConnectionistTemporalClassificationCriterion.cpp +++ b/flashlight/pkg/speech/criterion/backend/cuda/ConnectionistTemporalClassificationCriterion.cpp @@ -24,143 +24,153 @@ using CriterionUtils = fl::lib::cuda::CriterionUtils; namespace fl::pkg::speech { namespace { -inline void throw_on_error(ctcStatus_t status, const char* message) { - if (status != CTC_STATUS_SUCCESS) { - throw std::runtime_error( - message + (", stat = " + std::string(ctcGetStatusString(status)))); - } -} + inline void throw_on_error(ctcStatus_t status, const char* message) { + if(status != CTC_STATUS_SUCCESS) { + throw std::runtime_error( + message + (", stat = " + std::string(ctcGetStatusString(status))) + ); + } + } } // namespace std::vector ConnectionistTemporalClassificationCriterion::forward( - const std::vector& inputs) { - if (inputs.size() != 2) { - throw std::invalid_argument("Invalid inputs size"); - } - const auto& input = - fl::moddims(inputs[0], {0, 0, 0}); // remove trailing singleton dims - const auto& target = inputs[1]; - validate(input, target); - const int N = input.dim(0); - const int T = input.dim(1); - const int B = input.dim(2); - const int batchL = target.dim(0); - cudaStream_t stream = 
input.tensor().stream().impl().handle(); - - ctcOptions options; - options.loc = CTC_GPU; - options.stream = stream; - options.blank_label = N - 1; - - Tensor inputarr({N, B, T}, input.type()); - inputarr(fl::span, fl::span, fl::span) = - fl::transpose(input.tensor(), {0, 2, 1}); - - Tensor grad; - if (input.isCalcGrad()) { - grad = fl::full(inputarr.shape(), 0.0, inputarr.type()); - } - - std::vector inputLengths(B, T); - std::vector labels; - std::vector labelLengths; - std::vector batchTargetVec(target.elements()); - target.host(batchTargetVec.data()); - - Tensor targetSize({B}, fl::dtype::s32); - Tensor scale({B}, fl::dtype::f32); - - { - fl::DevicePtr targetRaw(target.tensor()); - fl::DevicePtr targetSizeRaw(targetSize); - fl::DevicePtr scaleRaw(scale); - - CriterionUtils::batchTargetSize( - B, - batchL, - batchL, - static_cast(targetRaw.get()), - static_cast(targetSizeRaw.get()), - stream); - - CriterionUtils::computeScale( - B, - T, - N, - scaleMode_, - static_cast(targetSizeRaw.get()), - static_cast(scaleRaw.get()), - stream); - } - - auto batchTargetSizeVec = targetSize.toHostVector(); - auto batchScaleVec = scale.toHostVector(); - - for (int b = 0; b < B; ++b) { - const int* targetVec = batchTargetVec.data() + b * batchL; - int L = batchTargetSizeVec[b]; - - // A heuristic to modify target length to be able to compute CTC loss - L = std::min(L, T); - const int R = fl::pkg::speech::countRepeats(targetVec, L); - L = std::min(L + R, T) - R; - - labelLengths.push_back(L); - for (int l = 0; l < L; ++l) { - labels.push_back(targetVec[l]); + const std::vector& inputs +) { + if(inputs.size() != 2) { + throw std::invalid_argument("Invalid inputs size"); + } + const auto& input = + fl::moddims(inputs[0], {0, 0, 0}); // remove trailing singleton dims + const auto& target = inputs[1]; + validate(input, target); + const int N = input.dim(0); + const int T = input.dim(1); + const int B = input.dim(2); + const int batchL = target.dim(0); + cudaStream_t stream = 
input.tensor().stream().impl().handle(); + + ctcOptions options; + options.loc = CTC_GPU; + options.stream = stream; + options.blank_label = N - 1; + + Tensor inputarr({N, B, T}, input.type()); + inputarr(fl::span, fl::span, fl::span) = + fl::transpose(input.tensor(), {0, 2, 1}); + + Tensor grad; + if(input.isCalcGrad()) { + grad = fl::full(inputarr.shape(), 0.0, inputarr.type()); } - } - Tensor batchScales = Tensor::fromVector({B}, batchScaleVec); - - size_t workspace_size; - throw_on_error( - get_workspace_size( - labelLengths.data(), - inputLengths.data(), - N, - B, - options, - &workspace_size), - "Error: get_workspace_size"); - - Tensor workspace({static_cast(workspace_size)}, fl::dtype::b8); - - std::vector costs(B, 0.0); - { - DevicePtr inputarrraw(inputarr); - DevicePtr gradraw(grad); - DevicePtr workspaceraw(workspace); + + std::vector inputLengths(B, T); + std::vector labels; + std::vector labelLengths; + std::vector batchTargetVec(target.elements()); + target.host(batchTargetVec.data()); + + Tensor targetSize({B}, fl::dtype::s32); + Tensor scale({B}, fl::dtype::f32); + + { + fl::DevicePtr targetRaw(target.tensor()); + fl::DevicePtr targetSizeRaw(targetSize); + fl::DevicePtr scaleRaw(scale); + + CriterionUtils::batchTargetSize( + B, + batchL, + batchL, + static_cast(targetRaw.get()), + static_cast(targetSizeRaw.get()), + stream + ); + + CriterionUtils::computeScale( + B, + T, + N, + scaleMode_, + static_cast(targetSizeRaw.get()), + static_cast(scaleRaw.get()), + stream + ); + } + + auto batchTargetSizeVec = targetSize.toHostVector(); + auto batchScaleVec = scale.toHostVector(); + + for(int b = 0; b < B; ++b) { + const int* targetVec = batchTargetVec.data() + b * batchL; + int L = batchTargetSizeVec[b]; + + // A heuristic to modify target length to be able to compute CTC loss + L = std::min(L, T); + const int R = fl::pkg::speech::countRepeats(targetVec, L); + L = std::min(L + R, T) - R; + + labelLengths.push_back(L); + for(int l = 0; l < L; ++l) { + 
labels.push_back(targetVec[l]); + } + } + Tensor batchScales = Tensor::fromVector({B}, batchScaleVec); + + size_t workspace_size; throw_on_error( - compute_ctc_loss( - (float*)inputarrraw.get(), - (float*)gradraw.get(), - labels.data(), + get_workspace_size( labelLengths.data(), inputLengths.data(), N, B, - costs.data(), - workspaceraw.get(), - options), - "Error: compute_ctc_loss"); - } - - Tensor result = Tensor::fromVector(costs); - - result = result * batchScales; - - auto gradFunc = [grad, batchScales]( - std::vector& moduleInputs, - const Variable& grad_output) { - auto gradScales = grad_output.tensor() * batchScales; - auto& in = moduleInputs[0]; - gradScales = fl::tile( - fl::reshape(gradScales, {1, grad_output.dim(0), 1}), - {in.dim(0), 1, in.dim(1)}); - moduleInputs[0].addGrad( - Variable(fl::transpose(grad * gradScales, {0, 2, 1}), false)); - }; - - return {Variable(result, {input, target}, gradFunc)}; + options, + &workspace_size + ), + "Error: get_workspace_size" + ); + + Tensor workspace({static_cast(workspace_size)}, fl::dtype::b8); + + std::vector costs(B, 0.0); + { + DevicePtr inputarrraw(inputarr); + DevicePtr gradraw(grad); + DevicePtr workspaceraw(workspace); + throw_on_error( + compute_ctc_loss( + (float*) inputarrraw.get(), + (float*) gradraw.get(), + labels.data(), + labelLengths.data(), + inputLengths.data(), + N, + B, + costs.data(), + workspaceraw.get(), + options + ), + "Error: compute_ctc_loss" + ); + } + + Tensor result = Tensor::fromVector(costs); + + result = result * batchScales; + + auto gradFunc = [grad, batchScales]( + std::vector& moduleInputs, + const Variable& grad_output) { + auto gradScales = grad_output.tensor() * batchScales; + auto& in = moduleInputs[0]; + gradScales = fl::tile( + fl::reshape(gradScales, {1, grad_output.dim(0), 1}), + {in.dim(0), 1, in.dim(1)} + ); + moduleInputs[0].addGrad( + Variable(fl::transpose(grad * gradScales, {0, 2, 1}), false) + ); + }; + + return {Variable(result, {input, target}, gradFunc)}; } 
} // namespace fl diff --git a/flashlight/pkg/speech/criterion/backend/cuda/CriterionUtils.cpp b/flashlight/pkg/speech/criterion/backend/cuda/CriterionUtils.cpp index a926f53..6b73e80 100644 --- a/flashlight/pkg/speech/criterion/backend/cuda/CriterionUtils.cpp +++ b/flashlight/pkg/speech/criterion/backend/cuda/CriterionUtils.cpp @@ -22,71 +22,75 @@ using ViterbiPath = fl::lib::cuda::ViterbiPath; namespace fl::pkg::speech { Tensor viterbiPath(const Tensor& input, const Tensor& trans) { - if (input.ndim() != 3) { - throw std::invalid_argument( - "Criterion viterbiPath expects input of shape {N, T, B}"); - } - if (trans.ndim() != 2) { - throw std::invalid_argument( - "Criterion viterbiPath expects trans of shape {N, N}"); - } - - auto B = input.dim(2); - auto T = input.dim(1); - auto N = input.dim(0); - - if (N != trans.dim(0) || N != trans.dim(1)) { - throw std::invalid_argument("viterbiPath: mismatched dims"); - } else if (input.type() != fl::dtype::f32) { - throw std::invalid_argument("viterbiPath: input must be float32"); - } else if (trans.type() != fl::dtype::f32) { - throw std::invalid_argument("viterbiPath: trans must be float32"); - } - - Tensor path({T, B}, fl::dtype::s32); - Tensor workspace( - {static_cast(ViterbiPath::getWorkspaceSize(B, T, N))}, - fl::dtype::u8); - - { - fl::DevicePtr inputRaw(input); - fl::DevicePtr transRaw(trans); - fl::DevicePtr pathRaw(path); - fl::DevicePtr workspaceRaw(workspace); - - ViterbiPath::compute( - B, - T, - N, - static_cast(inputRaw.get()), - static_cast(transRaw.get()), - static_cast(pathRaw.get()), - workspaceRaw.get(), - input.stream().impl().handle()); - } - - return path; + if(input.ndim() != 3) { + throw std::invalid_argument( + "Criterion viterbiPath expects input of shape {N, T, B}" + ); + } + if(trans.ndim() != 2) { + throw std::invalid_argument( + "Criterion viterbiPath expects trans of shape {N, N}" + ); + } + + auto B = input.dim(2); + auto T = input.dim(1); + auto N = input.dim(0); + + if(N != trans.dim(0) 
|| N != trans.dim(1)) { + throw std::invalid_argument("viterbiPath: mismatched dims"); + } else if(input.type() != fl::dtype::f32) { + throw std::invalid_argument("viterbiPath: input must be float32"); + } else if(trans.type() != fl::dtype::f32) { + throw std::invalid_argument("viterbiPath: trans must be float32"); + } + + Tensor path({T, B}, fl::dtype::s32); + Tensor workspace( + {static_cast(ViterbiPath::getWorkspaceSize(B, T, N))}, + fl::dtype::u8); + + { + fl::DevicePtr inputRaw(input); + fl::DevicePtr transRaw(trans); + fl::DevicePtr pathRaw(path); + fl::DevicePtr workspaceRaw(workspace); + + ViterbiPath::compute( + B, + T, + N, + static_cast(inputRaw.get()), + static_cast(transRaw.get()), + static_cast(pathRaw.get()), + workspaceRaw.get(), + input.stream().impl().handle() + ); + } + + return path; } Tensor getTargetSizeArray(const Tensor& target, int maxSize) { - int B = target.dim(1); - int L = target.dim(0); - - Tensor targetSize({B}, fl::dtype::s32); - - { - fl::DevicePtr targetRaw(target); - fl::DevicePtr targetSizeRaw(targetSize); - - CriterionUtils::batchTargetSize( - B, - L, - maxSize, - static_cast(targetRaw.get()), - static_cast(targetSizeRaw.get()), - target.stream().impl().handle()); - } - - return targetSize; + int B = target.dim(1); + int L = target.dim(0); + + Tensor targetSize({B}, fl::dtype::s32); + + { + fl::DevicePtr targetRaw(target); + fl::DevicePtr targetSizeRaw(targetSize); + + CriterionUtils::batchTargetSize( + B, + L, + maxSize, + static_cast(targetRaw.get()), + static_cast(targetSizeRaw.get()), + target.stream().impl().handle() + ); + } + + return targetSize; } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/backend/cuda/ForceAlignmentCriterion.cpp b/flashlight/pkg/speech/criterion/backend/cuda/ForceAlignmentCriterion.cpp index 24cdde0..4959fd3 100644 --- a/flashlight/pkg/speech/criterion/backend/cuda/ForceAlignmentCriterion.cpp +++ b/flashlight/pkg/speech/criterion/backend/cuda/ForceAlignmentCriterion.cpp @@ -29,154 
+29,163 @@ static void backward( int L, const Tensor& target, const Tensor& targetSize, - Tensor& workspace) { - if (gradVar.type() != fl::dtype::f32) { - throw std::invalid_argument("FAC: grad must be float32"); - } - if (inputs.size() != 2) { - throw std::invalid_argument( - "ForceAlignmentCriterion backward expects two input args"); - } - - const auto& grad = gradVar.tensor(); - Tensor inputGrad({N, T, B}, fl::dtype::f32); - Tensor transGrad({N, N}, fl::dtype::f32); - - { - fl::DevicePtr targetRaw(target); - fl::DevicePtr targetSizeRaw(targetSize); - fl::DevicePtr gradRaw(grad); - fl::DevicePtr inputGradRaw(inputGrad); - fl::DevicePtr transGradRaw(transGrad); - fl::DevicePtr workspaceRaw(workspace); - FAC::backward( - B, - T, - N, - L, - static_cast(targetRaw.get()), - static_cast(targetSizeRaw.get()), - static_cast(gradRaw.get()), - static_cast(inputGradRaw.get()), - static_cast(transGradRaw.get()), - workspaceRaw.get(), - inputs[0].tensor().stream().impl().handle()); - } - - inputs[0].addGrad(Variable(inputGrad, false)); - inputs[1].addGrad(Variable(transGrad, false)); + Tensor& workspace +) { + if(gradVar.type() != fl::dtype::f32) { + throw std::invalid_argument("FAC: grad must be float32"); + } + if(inputs.size() != 2) { + throw std::invalid_argument( + "ForceAlignmentCriterion backward expects two input args" + ); + } + + const auto& grad = gradVar.tensor(); + Tensor inputGrad({N, T, B}, fl::dtype::f32); + Tensor transGrad({N, N}, fl::dtype::f32); + + { + fl::DevicePtr targetRaw(target); + fl::DevicePtr targetSizeRaw(targetSize); + fl::DevicePtr gradRaw(grad); + fl::DevicePtr inputGradRaw(inputGrad); + fl::DevicePtr transGradRaw(transGrad); + fl::DevicePtr workspaceRaw(workspace); + FAC::backward( + B, + T, + N, + L, + static_cast(targetRaw.get()), + static_cast(targetSizeRaw.get()), + static_cast(gradRaw.get()), + static_cast(inputGradRaw.get()), + static_cast(transGradRaw.get()), + workspaceRaw.get(), + inputs[0].tensor().stream().impl().handle() + ); + } 
+ + inputs[0].addGrad(Variable(inputGrad, false)); + inputs[1].addGrad(Variable(transGrad, false)); } Variable ForceAlignmentCriterion::forward( const Variable& inputVar, - const Variable& targetVar) { - const auto& transVar = param(0); - int B = inputVar.dim(2); - int T = inputVar.dim(1); - int N = inputVar.dim(0); - int L = targetVar.dim(0); - - if (N != transVar.dim(0)) { - throw std::invalid_argument("FAC: input dim doesn't match N"); - } else if (inputVar.type() != fl::dtype::f32) { - throw std::invalid_argument("FAC: input must be float32"); - } else if (targetVar.type() != fl::dtype::s32) { - throw std::invalid_argument("FAC: target must be int32"); - } - - const auto& input = inputVar.tensor(); - const auto& target = targetVar.tensor(); - const auto& targetSize = getTargetSizeArray(target, T); - const auto& trans = transVar.tensor(); - Tensor loss({B}, fl::dtype::f32); - Tensor workspace( - {static_cast(FAC::getWorkspaceSize(B, T, N, L))}, - fl::dtype::u8); - - { - fl::DevicePtr inputRaw(input); - fl::DevicePtr targetRaw(target); - fl::DevicePtr targetSizeRaw(targetSize); - fl::DevicePtr transRaw(trans); - fl::DevicePtr lossRaw(loss); - fl::DevicePtr workspaceRaw(workspace); - - FAC::forward( - B, - T, - N, - L, - scaleMode_, - static_cast(inputRaw.get()), - static_cast(targetRaw.get()), - static_cast(targetSizeRaw.get()), - static_cast(transRaw.get()), - static_cast(lossRaw.get()), - workspaceRaw.get(), - input.stream().impl().handle()); - } - - return Variable( - loss, - {inputVar.withoutData(), transVar.withoutData()}, - [=](std::vector& inputs, const Variable& gradVar) mutable { - backward(inputs, gradVar, B, T, N, L, target, targetSize, workspace); - }); + const Variable& targetVar +) { + const auto& transVar = param(0); + int B = inputVar.dim(2); + int T = inputVar.dim(1); + int N = inputVar.dim(0); + int L = targetVar.dim(0); + + if(N != transVar.dim(0)) { + throw std::invalid_argument("FAC: input dim doesn't match N"); + } else if(inputVar.type() != 
fl::dtype::f32) { + throw std::invalid_argument("FAC: input must be float32"); + } else if(targetVar.type() != fl::dtype::s32) { + throw std::invalid_argument("FAC: target must be int32"); + } + + const auto& input = inputVar.tensor(); + const auto& target = targetVar.tensor(); + const auto& targetSize = getTargetSizeArray(target, T); + const auto& trans = transVar.tensor(); + Tensor loss({B}, fl::dtype::f32); + Tensor workspace( + {static_cast(FAC::getWorkspaceSize(B, T, N, L))}, + fl::dtype::u8); + + { + fl::DevicePtr inputRaw(input); + fl::DevicePtr targetRaw(target); + fl::DevicePtr targetSizeRaw(targetSize); + fl::DevicePtr transRaw(trans); + fl::DevicePtr lossRaw(loss); + fl::DevicePtr workspaceRaw(workspace); + + FAC::forward( + B, + T, + N, + L, + scaleMode_, + static_cast(inputRaw.get()), + static_cast(targetRaw.get()), + static_cast(targetSizeRaw.get()), + static_cast(transRaw.get()), + static_cast(lossRaw.get()), + workspaceRaw.get(), + input.stream().impl().handle() + ); + } + + return Variable( + loss, + {inputVar.withoutData(), transVar.withoutData()}, + [ = ](std::vector& inputs, const Variable& gradVar) mutable { + backward(inputs, gradVar, B, T, N, L, target, targetSize, workspace); + } + ); } Tensor ForceAlignmentCriterion::viterbiPath( const Tensor& input, - const Tensor& target) { - if (input.ndim() != 3) { - throw std::invalid_argument( - "ForceAlignmentCriterion::viterbiPath: " - "expects input with dimensions {N, T, B}"); - } - int N = input.dim(0); - int T = input.dim(1); - int B = input.dim(2); - int L = target.dim(0); - - std::vector> bestPaths; - const auto& transVar = param(0); - - if (N != transVar.dim(0)) { - throw std::invalid_argument("FAC: input dim doesn't match N:"); - } else if (input.type() != fl::dtype::f32) { - throw std::invalid_argument("FAC: input must be float32"); - } else if (target.type() != fl::dtype::s32) { - throw std::invalid_argument("FAC: target must be int32"); - } - - const auto& targetSize = 
getTargetSizeArray(target, T); - const auto& trans = transVar.tensor(); - Tensor bestPathsVar({T, B}, fl::dtype::s32); - Tensor workspace( - {static_cast(FAC::getWorkspaceSize(B, T, N, L))}, - fl::dtype::u8); - - { - fl::DevicePtr inputRaw(input); - fl::DevicePtr targetRaw(target); - fl::DevicePtr targetSizeRaw(targetSize); - fl::DevicePtr transRaw(trans); - fl::DevicePtr bestPathsRaw(bestPathsVar); - ; - fl::DevicePtr workspaceRaw(workspace); - - FAC::viterbiPath( - B, - T, - N, - L, - static_cast(inputRaw.get()), - static_cast(targetRaw.get()), - static_cast(targetSizeRaw.get()), - static_cast(transRaw.get()), - static_cast(bestPathsRaw.get()), - workspaceRaw.get(), - input.stream().impl().handle()); - } - return bestPathsVar; + const Tensor& target +) { + if(input.ndim() != 3) { + throw std::invalid_argument( + "ForceAlignmentCriterion::viterbiPath: " + "expects input with dimensions {N, T, B}" + ); + } + int N = input.dim(0); + int T = input.dim(1); + int B = input.dim(2); + int L = target.dim(0); + + std::vector> bestPaths; + const auto& transVar = param(0); + + if(N != transVar.dim(0)) { + throw std::invalid_argument("FAC: input dim doesn't match N:"); + } else if(input.type() != fl::dtype::f32) { + throw std::invalid_argument("FAC: input must be float32"); + } else if(target.type() != fl::dtype::s32) { + throw std::invalid_argument("FAC: target must be int32"); + } + + const auto& targetSize = getTargetSizeArray(target, T); + const auto& trans = transVar.tensor(); + Tensor bestPathsVar({T, B}, fl::dtype::s32); + Tensor workspace( + {static_cast(FAC::getWorkspaceSize(B, T, N, L))}, + fl::dtype::u8); + + { + fl::DevicePtr inputRaw(input); + fl::DevicePtr targetRaw(target); + fl::DevicePtr targetSizeRaw(targetSize); + fl::DevicePtr transRaw(trans); + fl::DevicePtr bestPathsRaw(bestPathsVar); + ; + fl::DevicePtr workspaceRaw(workspace); + + FAC::viterbiPath( + B, + T, + N, + L, + static_cast(inputRaw.get()), + static_cast(targetRaw.get()), + 
static_cast(targetSizeRaw.get()), + static_cast(transRaw.get()), + static_cast(bestPathsRaw.get()), + workspaceRaw.get(), + input.stream().impl().handle() + ); + } + return bestPathsVar; } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/backend/cuda/FullConnectionCriterion.cpp b/flashlight/pkg/speech/criterion/backend/cuda/FullConnectionCriterion.cpp index 1e54e61..a541d4e 100644 --- a/flashlight/pkg/speech/criterion/backend/cuda/FullConnectionCriterion.cpp +++ b/flashlight/pkg/speech/criterion/backend/cuda/FullConnectionCriterion.cpp @@ -27,101 +27,109 @@ static void backward( int T, int N, const Tensor& trans, - Tensor& workspace) { - if (gradVar.type() != fl::dtype::f32) { - throw std::invalid_argument("FCC: grad must be float32"); - } - if (inputs.size() != 2) { - throw std::invalid_argument( - "FullConnectionCriterion backward expects two input args"); - } - - const auto& grad = gradVar.tensor(); - Tensor inputGrad({N, T, B}, fl::dtype::f32); - Tensor transGrad({N, N}, fl::dtype::f32); - - { - fl::DevicePtr transRaw(trans); - fl::DevicePtr gradRaw(grad); - fl::DevicePtr inputGradRaw(inputGrad); - fl::DevicePtr transGradRaw(transGrad); - fl::DevicePtr workspaceRaw(workspace); - FCC::backward( - B, - T, - N, - static_cast(transRaw.get()), - static_cast(gradRaw.get()), - static_cast(inputGradRaw.get()), - static_cast(transGradRaw.get()), - workspaceRaw.get(), - inputs[0].tensor().stream().impl().handle()); - } - - inputs[0].addGrad(Variable(inputGrad, false)); - inputs[1].addGrad(Variable(transGrad, false)); + Tensor& workspace +) { + if(gradVar.type() != fl::dtype::f32) { + throw std::invalid_argument("FCC: grad must be float32"); + } + if(inputs.size() != 2) { + throw std::invalid_argument( + "FullConnectionCriterion backward expects two input args" + ); + } + + const auto& grad = gradVar.tensor(); + Tensor inputGrad({N, T, B}, fl::dtype::f32); + Tensor transGrad({N, N}, fl::dtype::f32); + + { + fl::DevicePtr transRaw(trans); + fl::DevicePtr 
gradRaw(grad); + fl::DevicePtr inputGradRaw(inputGrad); + fl::DevicePtr transGradRaw(transGrad); + fl::DevicePtr workspaceRaw(workspace); + FCC::backward( + B, + T, + N, + static_cast(transRaw.get()), + static_cast(gradRaw.get()), + static_cast(inputGradRaw.get()), + static_cast(transGradRaw.get()), + workspaceRaw.get(), + inputs[0].tensor().stream().impl().handle() + ); + } + + inputs[0].addGrad(Variable(inputGrad, false)); + inputs[1].addGrad(Variable(transGrad, false)); } Variable FullConnectionCriterion::forward( const Variable& inputVar, - const Variable& targetVar) { - if (inputVar.ndim() != 3) { - throw std::invalid_argument( - "FullConnectionCriterion::forward: " - "expects input with dimensions {N, T, B}"); - } - if (targetVar.ndim() != 2) { - throw std::invalid_argument( - "FullConnectionCriterion::forward: " - "expects target with dimensions {B, L}"); - } - - const auto& transVar = param(0); - int B = inputVar.dim(2); - int T = inputVar.dim(1); - int N = inputVar.dim(0); - - if (N != transVar.dim(0)) { - throw std::invalid_argument("FCC: input dim doesn't match N"); - } else if (inputVar.type() != fl::dtype::f32) { - throw std::invalid_argument("FCC: input must be float32"); - } else if (targetVar.type() != fl::dtype::s32) { - throw std::invalid_argument("FCC: target must be int32"); - } - - const auto& input = inputVar.tensor(); - const auto& target = targetVar.tensor(); - const auto& targetSize = getTargetSizeArray(target, T); - const auto& trans = transVar.tensor(); - Tensor loss({B}, fl::dtype::f32); - Tensor workspace( - {static_cast(FCC::getWorkspaceSize(B, T, N))}, fl::dtype::u8); - - { - fl::DevicePtr inputRaw(input); - fl::DevicePtr targetSizeRaw(targetSize); - fl::DevicePtr transRaw(trans); - fl::DevicePtr lossRaw(loss); - fl::DevicePtr workspaceRaw(workspace); - - FCC::forward( - B, - T, - N, - scaleMode_, - static_cast(inputRaw.get()), - static_cast(targetSizeRaw.get()), - static_cast(transRaw.get()), - static_cast(lossRaw.get()), - 
workspaceRaw.get(), - input.stream().impl().handle()); - } - - return Variable( - loss, - {inputVar.withoutData(), transVar.withoutData()}, - [=](std::vector& inputs, const Variable& gradVar) mutable { - backward(inputs, gradVar, B, T, N, trans, workspace); - }); + const Variable& targetVar +) { + if(inputVar.ndim() != 3) { + throw std::invalid_argument( + "FullConnectionCriterion::forward: " + "expects input with dimensions {N, T, B}" + ); + } + if(targetVar.ndim() != 2) { + throw std::invalid_argument( + "FullConnectionCriterion::forward: " + "expects target with dimensions {B, L}" + ); + } + + const auto& transVar = param(0); + int B = inputVar.dim(2); + int T = inputVar.dim(1); + int N = inputVar.dim(0); + + if(N != transVar.dim(0)) { + throw std::invalid_argument("FCC: input dim doesn't match N"); + } else if(inputVar.type() != fl::dtype::f32) { + throw std::invalid_argument("FCC: input must be float32"); + } else if(targetVar.type() != fl::dtype::s32) { + throw std::invalid_argument("FCC: target must be int32"); + } + + const auto& input = inputVar.tensor(); + const auto& target = targetVar.tensor(); + const auto& targetSize = getTargetSizeArray(target, T); + const auto& trans = transVar.tensor(); + Tensor loss({B}, fl::dtype::f32); + Tensor workspace( + {static_cast(FCC::getWorkspaceSize(B, T, N))}, fl::dtype::u8); + + { + fl::DevicePtr inputRaw(input); + fl::DevicePtr targetSizeRaw(targetSize); + fl::DevicePtr transRaw(trans); + fl::DevicePtr lossRaw(loss); + fl::DevicePtr workspaceRaw(workspace); + + FCC::forward( + B, + T, + N, + scaleMode_, + static_cast(inputRaw.get()), + static_cast(targetSizeRaw.get()), + static_cast(transRaw.get()), + static_cast(lossRaw.get()), + workspaceRaw.get(), + input.stream().impl().handle() + ); + } + + return Variable( + loss, + {inputVar.withoutData(), transVar.withoutData()}, + [ = ](std::vector& inputs, const Variable& gradVar) mutable { + backward(inputs, gradVar, B, T, N, trans, workspace); + } + ); } } // namespace fl 
diff --git a/flashlight/pkg/speech/data/FeatureTransforms.cpp b/flashlight/pkg/speech/data/FeatureTransforms.cpp index e00f7a9..097c151 100644 --- a/flashlight/pkg/speech/data/FeatureTransforms.cpp +++ b/flashlight/pkg/speech/data/FeatureTransforms.cpp @@ -29,22 +29,22 @@ using fl::lib::text::packReplabels; namespace { size_t getSfxSeed() { - // A naive seed based on thread ID - return std::hash()(std::this_thread::get_id()); + // A naive seed based on thread ID + return std::hash()(std::this_thread::get_id()); } class StartSfxCounter { - public: - explicit StartSfxCounter(int n) : iters_(n) {} - bool decrementAndCheck() { - std::lock_guard lock(mutex_); - iters_ = iters_ > 0 ? iters_ - 1 : iters_; - return iters_ <= 0; - } - - private: - int iters_; - std::mutex mutex_; +public: + explicit StartSfxCounter(int n) : iters_(n) {} + bool decrementAndCheck() { + std::lock_guard lock(mutex_); + iters_ = iters_ > 0 ? iters_ - 1 : iters_; + return iters_ <= 0; + } + +private: + int iters_; + std::mutex mutex_; }; } // namespace @@ -56,134 +56,140 @@ fl::Dataset::DataTransformFunction inputFeatures( const FeatureType& featureType, const std::pair& localNormCtx, const std::vector& sfxConf /* = {} */, - const int sfxStartUpdate /* = 0 */) { - auto sfxCounter = std::make_shared(sfxStartUpdate); - - std::shared_ptr spectralFeature; - int featSz = 1; - - if (featureType == FeatureType::POW_SPECTRUM) { - spectralFeature = std::make_shared(params); - featSz = params.powSpecFeatSz(); - } else if (featureType == FeatureType::MFSC) { - spectralFeature = std::make_shared(params); - featSz = params.mfscFeatSz(); - } else if (featureType == FeatureType::MFCC) { - spectralFeature = std::make_shared(params); - featSz = params.mfccFeatSz(); - } - - return [featSz, spectralFeature, localNormCtx, sfxConf, sfxCounter]( - void* data, Shape dims, fl::dtype type) { - if (type != fl::dtype::f32) { - throw std::invalid_argument("Invalid input type"); - } - if (dims.ndim() != 2) { - throw 
std::invalid_argument( - "'inputFeatures': Invalid input dims . Expected 2d array - Channels x T"); - } - auto channels = dims[0]; - std::vector input(dims.elements()); - std::copy_n(static_cast(data), input.size(), input.data()); - if (channels > 1) { - input = transpose2d(input, dims[1], channels); - } - if (!sfxConf.empty() && sfxCounter->decrementAndCheck()) { - if (channels > 1) { - throw std::invalid_argument( - "'inputFeatures': Invalid input dims. sound effect supports a single channel audio"); - } - thread_local auto seed = getSfxSeed(); - thread_local std::shared_ptr sfx = - sfx::createSoundEffect(sfxConf, seed); - sfx->apply(input); - } - - std::vector output; - if (spectralFeature) { - output = spectralFeature->batchApply(input, channels); - } else { - // use raw audio - output = input; // T X CHANNELS (Col Major) + const int sfxStartUpdate /* = 0 */ +) { + auto sfxCounter = std::make_shared(sfxStartUpdate); + + std::shared_ptr spectralFeature; + int featSz = 1; + + if(featureType == FeatureType::POW_SPECTRUM) { + spectralFeature = std::make_shared(params); + featSz = params.powSpecFeatSz(); + } else if(featureType == FeatureType::MFSC) { + spectralFeature = std::make_shared(params); + featSz = params.mfscFeatSz(); + } else if(featureType == FeatureType::MFCC) { + spectralFeature = std::make_shared(params); + featSz = params.mfccFeatSz(); } - auto T = output.size() / (featSz * channels); - // Before: FEAT X FRAMES X CHANNELS (Col Major) - output = transpose2d(output, T, featSz, channels); - // After: FRAMES X FEAT X CHANNELS (Col Major) - if (localNormCtx.first > 0 || localNormCtx.second > 0) { - output = - localNormalize(output, localNormCtx.first, localNormCtx.second, T); - } else { - output = normalize(output); - } - return Tensor::fromBuffer( - {static_cast(T), featSz, channels}, - output.data(), - MemoryLocation::Host); - }; + return [featSz, spectralFeature, localNormCtx, sfxConf, sfxCounter]( + void* data, Shape dims, fl::dtype type) { + if(type 
!= fl::dtype::f32) { + throw std::invalid_argument("Invalid input type"); + } + if(dims.ndim() != 2) { + throw std::invalid_argument( + "'inputFeatures': Invalid input dims . Expected 2d array - Channels x T" + ); + } + auto channels = dims[0]; + std::vector input(dims.elements()); + std::copy_n(static_cast(data), input.size(), input.data()); + if(channels > 1) { + input = transpose2d(input, dims[1], channels); + } + if(!sfxConf.empty() && sfxCounter->decrementAndCheck()) { + if(channels > 1) { + throw std::invalid_argument( + "'inputFeatures': Invalid input dims. sound effect supports a single channel audio" + ); + } + thread_local auto seed = getSfxSeed(); + thread_local std::shared_ptr sfx = + sfx::createSoundEffect(sfxConf, seed); + sfx->apply(input); + } + + std::vector output; + if(spectralFeature) { + output = spectralFeature->batchApply(input, channels); + } else { + // use raw audio + output = input; // T X CHANNELS (Col Major) + } + + auto T = output.size() / (featSz * channels); + // Before: FEAT X FRAMES X CHANNELS (Col Major) + output = transpose2d(output, T, featSz, channels); + // After: FRAMES X FEAT X CHANNELS (Col Major) + if(localNormCtx.first > 0 || localNormCtx.second > 0) { + output = + localNormalize(output, localNormCtx.first, localNormCtx.second, T); + } else { + output = normalize(output); + } + return Tensor::fromBuffer( + {static_cast(T), featSz, channels}, + output.data(), + MemoryLocation::Host + ); + }; } // target fl::Dataset::DataTransformFunction targetFeatures( const Dictionary& tokenDict, const LexiconMap& lexicon, - const TargetGenerationConfig& config) { - return [tokenDict, lexicon, config]( - void* data, Shape dims, fl::dtype /* unused */) { - std::string transcript( - static_cast(data), static_cast(data) + dims.elements()); - auto words = splitOnWhitespace(transcript, true); - auto target = wrd2Target( - words, - lexicon, - tokenDict, - config.wordSeparator_, - config.targetSamplePct_, - config.fallbackToLetterWordSepLeft_, 
- config.fallbackToLetterWordSepRight_, - config.skipUnk_); - auto tgtVec = tokenDict.mapEntriesToIndices(target); - if (!config.surround_.empty()) { - // add surround token at the beginning and end of target - // only if begin/end tokens are not surround - auto idx = tokenDict.getIndex(config.surround_); - if (tgtVec.empty() || tgtVec.back() != idx) { - tgtVec.emplace_back(idx); - } - if (tgtVec.size() > 1 && tgtVec.front() != idx) { - tgtVec.emplace_back(idx); - std::rotate(tgtVec.begin(), tgtVec.end() - 1, tgtVec.end()); - } - } - if (config.replabel_ > 0) { - tgtVec = packReplabels(tgtVec, tokenDict, config.replabel_); - } - if (config.criterion_ == kAsgCriterion) { - dedup(tgtVec); - } - if (config.eosToken_) { - tgtVec.emplace_back(tokenDict.getIndex(kEosToken)); - } - if (tgtVec.empty()) { - // support empty target - return Tensor(fl::dtype::s32); - } - return Tensor::fromVector(tgtVec); - }; + const TargetGenerationConfig& config +) { + return [tokenDict, lexicon, config]( + void* data, Shape dims, fl::dtype /* unused */) { + std::string transcript( + static_cast(data), static_cast(data) + dims.elements()); + auto words = splitOnWhitespace(transcript, true); + auto target = wrd2Target( + words, + lexicon, + tokenDict, + config.wordSeparator_, + config.targetSamplePct_, + config.fallbackToLetterWordSepLeft_, + config.fallbackToLetterWordSepRight_, + config.skipUnk_ + ); + auto tgtVec = tokenDict.mapEntriesToIndices(target); + if(!config.surround_.empty()) { + // add surround token at the beginning and end of target + // only if begin/end tokens are not surround + auto idx = tokenDict.getIndex(config.surround_); + if(tgtVec.empty() || tgtVec.back() != idx) { + tgtVec.emplace_back(idx); + } + if(tgtVec.size() > 1 && tgtVec.front() != idx) { + tgtVec.emplace_back(idx); + std::rotate(tgtVec.begin(), tgtVec.end() - 1, tgtVec.end()); + } + } + if(config.replabel_ > 0) { + tgtVec = packReplabels(tgtVec, tokenDict, config.replabel_); + } + if(config.criterion_ == 
kAsgCriterion) { + dedup(tgtVec); + } + if(config.eosToken_) { + tgtVec.emplace_back(tokenDict.getIndex(kEosToken)); + } + if(tgtVec.empty()) { + // support empty target + return Tensor(fl::dtype::s32); + } + return Tensor::fromVector(tgtVec); + }; } fl::Dataset::DataTransformFunction wordFeatures(const Dictionary& wrdDict) { - return [wrdDict](void* data, Shape dims, fl::dtype /* unused */) { - std::string transcript( - static_cast(data), static_cast(data) + dims.elements()); - auto words = splitOnWhitespace(transcript, true); - auto wrdVec = wrdDict.mapEntriesToIndices(words); - if (wrdVec.empty()) { - // support empty target - return Tensor(fl::dtype::s32); - } - return Tensor::fromVector(wrdVec); - }; + return [wrdDict](void* data, Shape dims, fl::dtype /* unused */) { + std::string transcript( + static_cast(data), static_cast(data) + dims.elements()); + auto words = splitOnWhitespace(transcript, true); + auto wrdVec = wrdDict.mapEntriesToIndices(words); + if(wrdVec.empty()) { + // support empty target + return Tensor(fl::dtype::s32); + } + return Tensor::fromVector(wrdVec); + }; } } // namespace fl diff --git a/flashlight/pkg/speech/data/FeatureTransforms.h b/flashlight/pkg/speech/data/FeatureTransforms.h index 388d217..69223dc 100644 --- a/flashlight/pkg/speech/data/FeatureTransforms.h +++ b/flashlight/pkg/speech/data/FeatureTransforms.h @@ -17,168 +17,183 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { -enum class FeatureType { POW_SPECTRUM, MFSC, MFCC, NONE }; + enum class FeatureType {POW_SPECTRUM, MFSC, MFCC, NONE}; -struct TargetGenerationConfig { - TargetGenerationConfig( - const std::string& wordSeparator, - float targetSamplePct, - const std::string& criterion, - const std::string& surround, - bool eosToken, - int replabel, - bool skipUnk, - bool fallbackToLetterWordSepLeft, - bool fallbackToLetterWordSepRight) - : wordSeparator_(wordSeparator), - targetSamplePct_(targetSamplePct), - criterion_(criterion), - 
surround_(surround), - eosToken_(eosToken), - replabel_(replabel), - skipUnk_(skipUnk), - fallbackToLetterWordSepLeft_(fallbackToLetterWordSepLeft), - fallbackToLetterWordSepRight_(fallbackToLetterWordSepRight) {} + struct TargetGenerationConfig { + TargetGenerationConfig( + const std::string& wordSeparator, + float targetSamplePct, + const std::string& criterion, + const std::string& surround, + bool eosToken, + int replabel, + bool skipUnk, + bool fallbackToLetterWordSepLeft, + bool fallbackToLetterWordSepRight + ) + : wordSeparator_(wordSeparator), + targetSamplePct_(targetSamplePct), + criterion_(criterion), + surround_(surround), + eosToken_(eosToken), + replabel_(replabel), + skipUnk_(skipUnk), + fallbackToLetterWordSepLeft_(fallbackToLetterWordSepLeft), + fallbackToLetterWordSepRight_(fallbackToLetterWordSepRight) {} - // token separator between words - const std::string wordSeparator_; - // sampling fraction if multiple spellings are present in lexicon - const float targetSamplePct_; - // loss criterion - const std::string criterion_; - // token to add - const std::string surround_; - // end of sentence token - const bool eosToken_; - // repeat label (used in ASG) - const int replabel_; - // skip unknown tokens - const bool skipUnk_; - // use letters of word as tokens if a word is not present in lexicon - // + add wordseparator at the beginning - const bool fallbackToLetterWordSepLeft_; - // use letters of word as tokens if a word is not present in lexicon - // + add wordseparator at the end - const bool fallbackToLetterWordSepRight_; -}; + // token separator between words + const std::string wordSeparator_; + // sampling fraction if multiple spellings are present in lexicon + const float targetSamplePct_; + // loss criterion + const std::string criterion_; + // token to add + const std::string surround_; + // end of sentence token + const bool eosToken_; + // repeat label (used in ASG) + const int replabel_; + // skip unknown tokens + const bool skipUnk_; 
+ // use letters of word as tokens if a word is not present in lexicon + // + add wordseparator at the beginning + const bool fallbackToLetterWordSepLeft_; + // use letters of word as tokens if a word is not present in lexicon + // + add wordseparator at the end + const bool fallbackToLetterWordSepRight_; + }; -fl::Dataset::DataTransformFunction inputFeatures( - const lib::audio::FeatureParams& params, - const FeatureType& featureType, - const std::pair& localNormCtx, - const std::vector& sfxConf = {}, - const int sfxStartUpdate = 0 ); + fl::Dataset::DataTransformFunction inputFeatures( + const lib::audio::FeatureParams& params, + const FeatureType& featureType, + const std::pair& localNormCtx, + const std::vector& sfxConf = {}, + const int sfxStartUpdate = 0 + ); -fl::Dataset::DataTransformFunction targetFeatures( - const lib::text::Dictionary& tokenDict, - const lib::text::LexiconMap& lexicon, - const TargetGenerationConfig& config); + fl::Dataset::DataTransformFunction targetFeatures( + const lib::text::Dictionary& tokenDict, + const lib::text::LexiconMap& lexicon, + const TargetGenerationConfig& config + ); -fl::Dataset::DataTransformFunction wordFeatures( - const lib::text::Dictionary& wrdDict); + fl::Dataset::DataTransformFunction wordFeatures( + const lib::text::Dictionary& wrdDict + ); // ============================== Helper function ============================== // Input: B x inRow x inCol (Row Major), Output: B x inCol x inRow (Row Major) -template -std::vector transpose2d( - const std::vector& in, - int64_t inRow, - int64_t inCol, - int64_t inBatch = 1) { - if (in.size() != inRow * inCol * inBatch) { - throw std::invalid_argument("Invalid input size"); - } - std::vector out(in.size()); - for (size_t b = 0; b < inBatch; ++b) { - int64_t start = b * inRow * inCol; - for (size_t c = 0; c < inCol; ++c) { - for (size_t r = 0; r < inRow; ++r) { - out[start + c * inRow + r] = in[start + r * inCol + c]; - } - } - } - return out; -} + template + std::vector 
transpose2d( + const std::vector& in, + int64_t inRow, + int64_t inCol, + int64_t inBatch = 1 + ) { + if(in.size() != inRow * inCol * inBatch) { + throw std::invalid_argument("Invalid input size"); + } + std::vector out(in.size()); + for(size_t b = 0; b < inBatch; ++b) { + int64_t start = b * inRow * inCol; + for(size_t c = 0; c < inCol; ++c) { + for(size_t r = 0; r < inRow; ++r) { + out[start + c * inRow + r] = in[start + r * inCol + c]; + } + } + } + return out; + } -template -std::vector localNormalize( - const std::vector& in, - int64_t leftCtxSize, - int64_t rightCtxSize, - int64_t frameSz = 1, - int64_t batchSz = 1, - double threshold = 0.0) { - if (in.empty()) { - return {}; - } - int64_t perBatchSz = in.size() / batchSz; - int64_t perFrameSz = perBatchSz / frameSz; - auto out(in); - for (size_t b = 0; b < batchSz; ++b) { - std::vector sum(frameSz, 0.0), sum2(frameSz, 0.0); - int64_t curFrame = 0; - // accumulate sum, sum^2 for computing mean, stddev - for (auto i = b * perBatchSz; i < (b + 1) * perBatchSz; ++i) { - auto start = std::max(curFrame - rightCtxSize, 0L); - auto end = std::min(curFrame + leftCtxSize, frameSz - 1); - for (int64_t j = start; j <= end; ++j) { - sum[j] += in[i]; - sum2[j] += in[i] * in[i]; - } - curFrame = (curFrame + 1) % frameSz; - } - // compute mean, stddev - for (auto j = 0; j < frameSz; ++j) { - int64_t N = (std::min(j + rightCtxSize, frameSz - 1) - - std::max(j - leftCtxSize, 0L) + 1) * - perFrameSz; - sum[j] /= N; - sum2[j] /= N; - sum2[j] -= (sum[j] * sum[j]); - sum2[j] = std::sqrt(sum2[j]); - } - // perform local normalization - curFrame = 0; - for (auto i = b * perBatchSz; i < (b + 1) * perBatchSz; ++i) { - out[i] -= sum[curFrame]; - if (sum2[curFrame] > threshold) { - out[i] /= sum2[curFrame]; - } - curFrame = (curFrame + 1) % frameSz; - } - } - return out; -} + template + std::vector localNormalize( + const std::vector& in, + int64_t leftCtxSize, + int64_t rightCtxSize, + int64_t frameSz = 1, + int64_t batchSz = 1, + 
double threshold = 0.0 + ) { + if(in.empty()) { + return {}; + } + int64_t perBatchSz = in.size() / batchSz; + int64_t perFrameSz = perBatchSz / frameSz; + auto out(in); + for(size_t b = 0; b < batchSz; ++b) { + std::vector sum(frameSz, 0.0), sum2(frameSz, 0.0); + int64_t curFrame = 0; + // accumulate sum, sum^2 for computing mean, stddev + for(auto i = b * perBatchSz; i < (b + 1) * perBatchSz; ++i) { + auto start = std::max(curFrame - rightCtxSize, 0L); + auto end = std::min(curFrame + leftCtxSize, frameSz - 1); + for(int64_t j = start; j <= end; ++j) { + sum[j] += in[i]; + sum2[j] += in[i] * in[i]; + } + curFrame = (curFrame + 1) % frameSz; + } + // compute mean, stddev + for(auto j = 0; j < frameSz; ++j) { + int64_t N = (std::min(j + rightCtxSize, frameSz - 1) + - std::max(j - leftCtxSize, 0L) + 1) + * perFrameSz; + sum[j] /= N; + sum2[j] /= N; + sum2[j] -= (sum[j] * sum[j]); + sum2[j] = std::sqrt(sum2[j]); + } + // perform local normalization + curFrame = 0; + for(auto i = b * perBatchSz; i < (b + 1) * perBatchSz; ++i) { + out[i] -= sum[curFrame]; + if(sum2[curFrame] > threshold) { + out[i] /= sum2[curFrame]; + } + curFrame = (curFrame + 1) % frameSz; + } + } + return out; + } -template -std::vector normalize( - const std::vector& in, - int64_t batchSz = 1, - double threshold = 0.0) { - if (in.empty()) { - return {}; - } - auto out(in); - int64_t perBatchSz = out.size() / batchSz; - for (size_t b = 0; b < batchSz; ++b) { - auto start = out.begin() + b * perBatchSz; - T sum = std::accumulate(start, start + perBatchSz, 0.0); - T mean = sum / perBatchSz; - std::transform( - start, start + perBatchSz, start, [mean](T x) { return x - mean; }); - T sq_sum = std::inner_product(start, start + perBatchSz, start, 0.0); - T stddev = std::sqrt(sq_sum / perBatchSz); - if (stddev > threshold) { - std::transform(start, start + perBatchSz, start, [stddev](T x) { - return x / stddev; - }); - } - } - return out; -} -} // namespace speech + template + std::vector normalize( + 
const std::vector& in, + int64_t batchSz = 1, + double threshold = 0.0 + ) { + if(in.empty()) { + return {}; + } + auto out(in); + int64_t perBatchSz = out.size() / batchSz; + for(size_t b = 0; b < batchSz; ++b) { + auto start = out.begin() + b * perBatchSz; + T sum = std::accumulate(start, start + perBatchSz, 0.0); + T mean = sum / perBatchSz; + std::transform( + start, + start + perBatchSz, + start, + [mean](T x) { return x - mean; }); + T sq_sum = std::inner_product(start, start + perBatchSz, start, 0.0); + T stddev = std::sqrt(sq_sum / perBatchSz); + if(stddev > threshold) { + std::transform( + start, + start + perBatchSz, + start, + [stddev](T x) { + return x / stddev; + } + ); + } + } + return out; + } + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/data/ListFileDataset.cpp b/flashlight/pkg/speech/data/ListFileDataset.cpp index a432c1b..f5e55f5 100644 --- a/flashlight/pkg/speech/data/ListFileDataset.cpp +++ b/flashlight/pkg/speech/data/ListFileDataset.cpp @@ -26,124 +26,141 @@ ListFileDataset::ListFileDataset( const std::string& filename, const DataTransformFunction& inFeatFunc /* = nullptr */, const DataTransformFunction& tgtFeatFunc /* = nullptr */, - const DataTransformFunction& wrdFeatFunc /* = nullptr */) - : inFeatFunc_(inFeatFunc), - tgtFeatFunc_(tgtFeatFunc), - wrdFeatFunc_(wrdFeatFunc), - numRows_(0) { - std::ifstream inFile(filename); - if (!inFile) { - throw std::invalid_argument("Unable to open file -" + filename); - } - std::string line; - while (std::getline(inFile, line)) { - if (line.empty()) { - continue; + const DataTransformFunction& wrdFeatFunc /* = nullptr */ +) : inFeatFunc_(inFeatFunc), + tgtFeatFunc_(tgtFeatFunc), + wrdFeatFunc_(wrdFeatFunc), + numRows_(0) { + std::ifstream inFile(filename); + if(!inFile) { + throw std::invalid_argument("Unable to open file -" + filename); } - auto splits = splitOnWhitespace(line, true); - if (splits.size() < 3) { - throw std::runtime_error( - "File " + 
filename + - " has invalid columns in line (expected 3 columns at least): " + - line); + std::string line; + while(std::getline(inFile, line)) { + if(line.empty()) { + continue; + } + auto splits = splitOnWhitespace(line, true); + if(splits.size() < 3) { + throw std::runtime_error( + "File " + filename + + " has invalid columns in line (expected 3 columns at least): " + + line + ); + } + + ids_.emplace_back(std::move(splits[kIdIdx])); + inputs_.emplace_back(std::move(splits[kInIdx])); + inputSizes_.emplace_back(std::stof(splits[kSzIdx])); + targets_.emplace_back( + fl::lib::join( + " ", + std::vector(splits.begin() + kTgtIdx, splits.end()) + ) + ); + ++numRows_; } - - ids_.emplace_back(std::move(splits[kIdIdx])); - inputs_.emplace_back(std::move(splits[kInIdx])); - inputSizes_.emplace_back(std::stof(splits[kSzIdx])); - targets_.emplace_back(fl::lib::join( - " ", std::vector(splits.begin() + kTgtIdx, splits.end()))); - ++numRows_; - } - inFile.close(); - targetSizesCache_.resize(inputSizes_.size(), -1); + inFile.close(); + targetSizesCache_.resize(inputSizes_.size(), -1); } int64_t ListFileDataset::size() const { - return numRows_; + return numRows_; } std::vector ListFileDataset::get(const int64_t idx) const { - checkIndexBounds(idx); - - auto audio = loadAudio(inputs_[idx]); // channels x time - Tensor input; - if (inFeatFunc_) { - input = inFeatFunc_( - static_cast(audio.first.data()), audio.second, fl::dtype::f32); - } else { - input = Tensor::fromBuffer( - {audio.second}, audio.first.data(), MemoryLocation::Host); - } - - Tensor target; - if (tgtFeatFunc_) { - std::vector curTarget(targets_[idx].begin(), targets_[idx].end()); - target = tgtFeatFunc_( - static_cast(curTarget.data()), - {static_cast(curTarget.size())}, - fl::dtype::b8); - } - targetSizesCache_[idx] = target.elements(); + checkIndexBounds(idx); + + auto audio = loadAudio(inputs_[idx]); // channels x time + Tensor input; + if(inFeatFunc_) { + input = inFeatFunc_( + static_cast(audio.first.data()), 
+ audio.second, + fl::dtype::f32 + ); + } else { + input = Tensor::fromBuffer( + {audio.second}, + audio.first.data(), + MemoryLocation::Host + ); + } - Tensor words; - if (wrdFeatFunc_) { - std::vector curTarget(targets_[idx].begin(), targets_[idx].end()); - words = wrdFeatFunc_( - static_cast(curTarget.data()), - {static_cast(curTarget.size())}, - fl::dtype::b8); - } - - Tensor sampleIdx = Tensor::fromBuffer( - {static_cast(ids_[idx].length())}, - const_cast(ids_[idx].data()), // fix me post C++-17? - MemoryLocation::Host); - Tensor samplePath = Tensor::fromBuffer( - {static_cast(inputs_[idx].length())}, - inputs_[idx].data(), - MemoryLocation::Host); - Tensor sampleDuration = - Tensor::fromBuffer({1}, inputSizes_.data() + idx, MemoryLocation::Host); - Tensor sampleTargetSize = fl::full({1}, float(target.elements())); - - return { - input, - target, - words, - sampleIdx, - samplePath, - sampleDuration, - sampleTargetSize}; + Tensor target; + if(tgtFeatFunc_) { + std::vector curTarget(targets_[idx].begin(), targets_[idx].end()); + target = tgtFeatFunc_( + static_cast(curTarget.data()), + {static_cast(curTarget.size())}, + fl::dtype::b8 + ); + } + targetSizesCache_[idx] = target.elements(); + + Tensor words; + if(wrdFeatFunc_) { + std::vector curTarget(targets_[idx].begin(), targets_[idx].end()); + words = wrdFeatFunc_( + static_cast(curTarget.data()), + {static_cast(curTarget.size())}, + fl::dtype::b8 + ); + } + + Tensor sampleIdx = Tensor::fromBuffer( + {static_cast(ids_[idx].length())}, + const_cast(ids_[idx].data()), // fix me post C++-17? 
+ MemoryLocation::Host + ); + Tensor samplePath = Tensor::fromBuffer( + {static_cast(inputs_[idx].length())}, + inputs_[idx].data(), + MemoryLocation::Host + ); + Tensor sampleDuration = + Tensor::fromBuffer({1}, inputSizes_.data() + idx, MemoryLocation::Host); + Tensor sampleTargetSize = fl::full({1}, float(target.elements())); + + return { + input, + target, + words, + sampleIdx, + samplePath, + sampleDuration, + sampleTargetSize}; } std::pair, Shape> ListFileDataset::loadAudio( - const std::string& handle) const { - auto info = loadSoundInfo(handle.c_str()); - return {loadSound(handle.c_str()), {info.channels, info.frames}}; + const std::string& handle +) const { + auto info = loadSoundInfo(handle.c_str()); + return {loadSound(handle.c_str()), {info.channels, info.frames}}; } float ListFileDataset::getInputSize(const int64_t idx) const { - checkIndexBounds(idx); - return inputSizes_[idx]; + checkIndexBounds(idx); + return inputSizes_[idx]; } int64_t ListFileDataset::getTargetSize(const int64_t idx) const { - checkIndexBounds(idx); - if (targetSizesCache_[idx] >= 0) { - return targetSizesCache_[idx]; - } - if (!tgtFeatFunc_) { - return 0; - } - std::vector curTarget(targets_[idx].begin(), targets_[idx].end()); - auto tgtSize = tgtFeatFunc_( - static_cast(curTarget.data()), - {static_cast(curTarget.size())}, - fl::dtype::b8) - .elements(); - targetSizesCache_[idx] = tgtSize; - return tgtSize; + checkIndexBounds(idx); + if(targetSizesCache_[idx] >= 0) { + return targetSizesCache_[idx]; + } + if(!tgtFeatFunc_) { + return 0; + } + std::vector curTarget(targets_[idx].begin(), targets_[idx].end()); + auto tgtSize = tgtFeatFunc_( + static_cast(curTarget.data()), + {static_cast(curTarget.size())}, + fl::dtype::b8 + ) + .elements(); + targetSizesCache_[idx] = tgtSize; + return tgtSize; } } // namespace fl diff --git a/flashlight/pkg/speech/data/ListFileDataset.h b/flashlight/pkg/speech/data/ListFileDataset.h index 5043379..07aa4f8 100644 --- 
a/flashlight/pkg/speech/data/ListFileDataset.h +++ b/flashlight/pkg/speech/data/ListFileDataset.h @@ -16,7 +16,7 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { /** * @@ -43,35 +43,37 @@ namespace speech { * `target`, `word_transcription`, `sample_id` in the same order. * */ -class ListFileDataset : public fl::Dataset { - public: - explicit ListFileDataset( - const std::string& filename, - const DataTransformFunction& inFeatFunc = nullptr, - const DataTransformFunction& tgtFeatFunc = nullptr, - const DataTransformFunction& wrdFeatFunc = nullptr); + class ListFileDataset : public fl::Dataset { + public: + explicit ListFileDataset( + const std::string& filename, + const DataTransformFunction& inFeatFunc = nullptr, + const DataTransformFunction& tgtFeatFunc = nullptr, + const DataTransformFunction& wrdFeatFunc = nullptr + ); - int64_t size() const override; + int64_t size() const override; - std::vector get(const int64_t idx) const override; + std::vector get(const int64_t idx) const override; - float getInputSize(const int64_t idx) const; + float getInputSize(const int64_t idx) const; - int64_t getTargetSize(const int64_t idx) const; + int64_t getTargetSize(const int64_t idx) const; - virtual std::pair, Shape> loadAudio( - const std::string& handle) const; + virtual std::pair, Shape> loadAudio( + const std::string& handle + ) const; - protected: - DataTransformFunction inFeatFunc_, tgtFeatFunc_, wrdFeatFunc_; - int64_t numRows_; - std::vector ids_; - std::vector inputs_; - std::vector targets_; - std::vector inputSizes_; - mutable std::vector targetSizesCache_; -}; + protected: + DataTransformFunction inFeatFunc_, tgtFeatFunc_, wrdFeatFunc_; + int64_t numRows_; + std::vector ids_; + std::vector inputs_; + std::vector targets_; + std::vector inputSizes_; + mutable std::vector targetSizesCache_; + }; -} // namespace speech + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/data/Sound.cpp 
b/flashlight/pkg/speech/data/Sound.cpp index 10fbd31..9a08d2a 100644 --- a/flashlight/pkg/speech/data/Sound.cpp +++ b/flashlight/pkg/speech/data/Sound.cpp @@ -18,10 +18,10 @@ using namespace fl::pkg::speech; namespace { struct EnumClassHash { - template - std::size_t operator()(T t) const { - return static_cast(t); - } + template + std::size_t operator()(T t) const { + return static_cast(t); + } }; const std::unordered_map formats{ @@ -82,283 +82,292 @@ namespace fl::pkg::speech { extern "C" { static sf_count_t sf_vio_ro_get_filelen(void* user_data) { - std::istream* f = reinterpret_cast(user_data); - auto pos = f->tellg(); - f->seekg(0, std::ios_base::end); - auto size = f->tellg(); - f->seekg(pos, std::ios_base::beg); - return (sf_count_t)size; + std::istream* f = reinterpret_cast(user_data); + auto pos = f->tellg(); + f->seekg(0, std::ios_base::end); + auto size = f->tellg(); + f->seekg(pos, std::ios_base::beg); + return (sf_count_t) size; } -static sf_count_t -sf_vio_ro_seek(sf_count_t offset, int whence, void* user_data) { - std::istream* f = reinterpret_cast(user_data); - std::ios_base::seekdir way; - switch (whence) { - case SEEK_CUR: - way = std::ios_base::cur; - break; - case SEEK_SET: - way = std::ios_base::beg; - break; - case SEEK_END: - way = std::ios_base::end; - break; - default: - throw std::invalid_argument("whence is invalid"); - } - f->seekg(offset, way); - return offset; +static sf_count_t sf_vio_ro_seek(sf_count_t offset, int whence, void* user_data) { + std::istream* f = reinterpret_cast(user_data); + std::ios_base::seekdir way; + switch(whence) { + case SEEK_CUR: + way = std::ios_base::cur; + break; + case SEEK_SET: + way = std::ios_base::beg; + break; + case SEEK_END: + way = std::ios_base::end; + break; + default: + throw std::invalid_argument("whence is invalid"); + } + f->seekg(offset, way); + return offset; } static sf_count_t sf_vio_ro_read(void* ptr, sf_count_t count, void* user_data) { - std::istream* f = reinterpret_cast(user_data); 
- f->read((char*)ptr, count); - auto n = f->gcount(); - if (!f->good()) { - f->clear(); - } - return n; + std::istream* f = reinterpret_cast(user_data); + f->read((char*) ptr, count); + auto n = f->gcount(); + if(!f->good()) { + f->clear(); + } + return n; } static sf_count_t sf_vio_ro_write( const void* /* ptr */, sf_count_t /* count */, - void* /* user_data */) { - throw std::invalid_argument("read-only stream"); - return 0; + void* /* user_data */ +) { + throw std::invalid_argument("read-only stream"); + return 0; } static sf_count_t sf_vio_ro_tell(void* user_data) { - std::istream* f = reinterpret_cast(user_data); - return f->tellg(); + std::istream* f = reinterpret_cast(user_data); + return f->tellg(); } static sf_count_t sf_vio_wo_get_filelen(void* user_data) { - std::ostream* f = reinterpret_cast(user_data); - auto pos = f->tellp(); - f->seekp(0, std::ios_base::end); - auto size = f->tellp(); - f->seekp(pos, std::ios_base::beg); - return (sf_count_t)size; + std::ostream* f = reinterpret_cast(user_data); + auto pos = f->tellp(); + f->seekp(0, std::ios_base::end); + auto size = f->tellp(); + f->seekp(pos, std::ios_base::beg); + return (sf_count_t) size; } -static sf_count_t -sf_vio_wo_seek(sf_count_t offset, int whence, void* user_data) { - std::ostream* f = reinterpret_cast(user_data); - std::ios_base::seekdir way; - switch (whence) { - case SEEK_CUR: - way = std::ios_base::cur; - break; - case SEEK_SET: - way = std::ios_base::beg; - break; - case SEEK_END: - way = std::ios_base::end; - break; - default: - throw std::invalid_argument("whence is invalid"); - } - f->seekp(offset, way); - return offset; +static sf_count_t sf_vio_wo_seek(sf_count_t offset, int whence, void* user_data) { + std::ostream* f = reinterpret_cast(user_data); + std::ios_base::seekdir way; + switch(whence) { + case SEEK_CUR: + way = std::ios_base::cur; + break; + case SEEK_SET: + way = std::ios_base::beg; + break; + case SEEK_END: + way = std::ios_base::end; + break; + default: + throw 
std::invalid_argument("whence is invalid"); + } + f->seekp(offset, way); + return offset; } -static sf_count_t -sf_vio_wo_read(void* /* ptr */, sf_count_t /* count */, void* /* user_data */) { - throw std::invalid_argument("write-only stream"); - return 0; +static sf_count_t sf_vio_wo_read(void* /* ptr */, sf_count_t /* count */, void* /* user_data */) { + throw std::invalid_argument("write-only stream"); + return 0; } -static sf_count_t -sf_vio_wo_write(const void* ptr, sf_count_t count, void* user_data) { - std::ostream* f = reinterpret_cast(user_data); - auto pos = f->tellp(); - f->write((const char*)ptr, count); - return f->tellp() - pos; +static sf_count_t sf_vio_wo_write(const void* ptr, sf_count_t count, void* user_data) { + std::ostream* f = reinterpret_cast(user_data); + auto pos = f->tellp(); + f->write((const char*) ptr, count); + return f->tellp() - pos; } static sf_count_t sf_vio_wo_tell(void* user_data) { - std::ostream* f = reinterpret_cast(user_data); - return f->tellp(); + std::ostream* f = reinterpret_cast(user_data); + return f->tellp(); } } /* extern "C" */ SoundInfo loadSoundInfo(const std::string& filename) { - std::ifstream f(filename); - if (!f.is_open()) { - throw std::runtime_error("could not open file for read " + filename); - } - return loadSoundInfo(f); + std::ifstream f(filename); + if(!f.is_open()) { + throw std::runtime_error("could not open file for read " + filename); + } + return loadSoundInfo(f); } SoundInfo loadSoundInfo(std::istream& f) { - SF_VIRTUAL_IO vsf = {sf_vio_ro_get_filelen, - sf_vio_ro_seek, - sf_vio_ro_read, - sf_vio_ro_write, - sf_vio_ro_tell}; - - SNDFILE* file; - SF_INFO info; - - /* mandatory */ - info.format = 0; - - if (!(file = sf_open_virtual(&vsf, SFM_READ, &info, &f))) { - throw std::runtime_error( - "loadSoundInfo: unknown format or could not open stream"); - } - - sf_close(file); - - SoundInfo usrinfo; - usrinfo.frames = info.frames; - usrinfo.samplerate = info.samplerate; - usrinfo.channels = 
info.channels; - return usrinfo; + SF_VIRTUAL_IO vsf = {sf_vio_ro_get_filelen, + sf_vio_ro_seek, + sf_vio_ro_read, + sf_vio_ro_write, + sf_vio_ro_tell}; + + SNDFILE* file; + SF_INFO info; + + /* mandatory */ + info.format = 0; + + if(!(file = sf_open_virtual(&vsf, SFM_READ, &info, &f))) { + throw std::runtime_error( + "loadSoundInfo: unknown format or could not open stream" + ); + } + + sf_close(file); + + SoundInfo usrinfo; + usrinfo.frames = info.frames; + usrinfo.samplerate = info.samplerate; + usrinfo.channels = info.channels; + return usrinfo; } -template +template std::vector loadSound(const std::string& filename) { - std::ifstream f(filename); - if (!f.is_open()) { - throw std::runtime_error("could not open file " + filename); - } - return loadSound(f); + std::ifstream f(filename); + if(!f.is_open()) { + throw std::runtime_error("could not open file " + filename); + } + return loadSound(f); } -template +template std::vector loadSound(std::istream& f) { - SF_VIRTUAL_IO vsf = {sf_vio_ro_get_filelen, - sf_vio_ro_seek, - sf_vio_ro_read, - sf_vio_ro_write, - sf_vio_ro_tell}; - SNDFILE* file; - SF_INFO info; - - info.format = 0; - - if (!(file = sf_open_virtual(&vsf, SFM_READ, &info, &f))) { - throw std::runtime_error( - "loadSound: unknown format or could not open stream"); - } - - std::vector in(info.frames * info.channels); - sf_count_t nframe; - if (std::is_same::value) { - nframe = - sf_readf_float(file, reinterpret_cast(in.data()), info.frames); - } else if (std::is_same::value) { - nframe = sf_readf_double( - file, reinterpret_cast(in.data()), info.frames); - } else if (std::is_same::value) { - nframe = sf_readf_int(file, reinterpret_cast(in.data()), info.frames); - } else if (std::is_same::value) { - nframe = - sf_readf_short(file, reinterpret_cast(in.data()), info.frames); - } else { - throw std::logic_error("loadSound: called with unsupported T"); - } - sf_close(file); - if (nframe != info.frames) { - throw std::runtime_error("loadSound: read error"); - 
} - return in; + SF_VIRTUAL_IO vsf = {sf_vio_ro_get_filelen, + sf_vio_ro_seek, + sf_vio_ro_read, + sf_vio_ro_write, + sf_vio_ro_tell}; + SNDFILE* file; + SF_INFO info; + + info.format = 0; + + if(!(file = sf_open_virtual(&vsf, SFM_READ, &info, &f))) { + throw std::runtime_error( + "loadSound: unknown format or could not open stream" + ); + } + + std::vector in(info.frames * info.channels); + sf_count_t nframe; + if(std::is_same::value) { + nframe = + sf_readf_float(file, reinterpret_cast(in.data()), info.frames); + } else if(std::is_same::value) { + nframe = sf_readf_double( + file, + reinterpret_cast(in.data()), + info.frames + ); + } else if(std::is_same::value) { + nframe = sf_readf_int(file, reinterpret_cast(in.data()), info.frames); + } else if(std::is_same::value) { + nframe = + sf_readf_short(file, reinterpret_cast(in.data()), info.frames); + } else { + throw std::logic_error("loadSound: called with unsupported T"); + } + sf_close(file); + if(nframe != info.frames) { + throw std::runtime_error("loadSound: read error"); + } + return in; } -template +template void saveSound( const std::string& filename, const std::vector& input, int64_t samplerate, int64_t channels, SoundFormat format, - SoundSubFormat subformat) { - std::ofstream f(filename); - if (!f.is_open()) { - throw std::runtime_error("could not open file for write " + filename); - } - saveSound(f, input, samplerate, channels, format, subformat); + SoundSubFormat subformat +) { + std::ofstream f(filename); + if(!f.is_open()) { + throw std::runtime_error("could not open file for write " + filename); + } + saveSound(f, input, samplerate, channels, format, subformat); } -template +template void saveSound( std::ostream& f, const std::vector& input, int64_t samplerate, int64_t channels, SoundFormat format, - SoundSubFormat subformat) { - SF_VIRTUAL_IO vsf = {sf_vio_wo_get_filelen, - sf_vio_wo_seek, - sf_vio_wo_read, - sf_vio_wo_write, - sf_vio_wo_tell}; - SNDFILE* file; - SF_INFO info; - - if 
(formats.find(format) == formats.end()) { - throw std::invalid_argument("saveSound: invalid format"); - } - if (subformats.find(subformat) == subformats.end()) { - throw std::invalid_argument("saveSound: invalid subformat"); - } - - info.channels = channels; - info.samplerate = samplerate; - info.format = - formats.find(format)->second | subformats.find(subformat)->second; - - if (!(file = sf_open_virtual(&vsf, SFM_WRITE, &info, &f))) { - throw std::runtime_error( - "saveSound: invalid format or could not write stream"); - } - - /* Circumvent a bug in Vorbis with large buffers */ - sf_count_t remainCount = input.size() / channels; - sf_count_t offsetCount = 0; - const sf_count_t chunkSize = 65536; - while (remainCount > 0) { - sf_count_t writableCount = std::min(chunkSize, remainCount); - sf_count_t writtenCount = 0; - if (std::is_same::value) { - writtenCount = sf_writef_float( - file, - const_cast(reinterpret_cast(input.data())) + - offsetCount * channels, - writableCount); - } else if (std::is_same::value) { - writtenCount = sf_writef_double( - file, - const_cast(reinterpret_cast(input.data())) + - offsetCount * channels, - writableCount); - } else if (std::is_same::value) { - writtenCount = sf_writef_int( - file, - const_cast(reinterpret_cast(input.data())) + - offsetCount * channels, - writableCount); - } else if (std::is_same::value) { - writtenCount = sf_writef_short( - file, - const_cast(reinterpret_cast(input.data())) + - offsetCount * channels, - writableCount); - } else { - throw std::logic_error("saveSound: called with unsupported T"); + SoundSubFormat subformat +) { + SF_VIRTUAL_IO vsf = {sf_vio_wo_get_filelen, + sf_vio_wo_seek, + sf_vio_wo_read, + sf_vio_wo_write, + sf_vio_wo_tell}; + SNDFILE* file; + SF_INFO info; + + if(formats.find(format) == formats.end()) { + throw std::invalid_argument("saveSound: invalid format"); } - if (writtenCount != writableCount) { - sf_close(file); - throw std::runtime_error("saveSound: write error"); + 
if(subformats.find(subformat) == subformats.end()) { + throw std::invalid_argument("saveSound: invalid subformat"); + } + + info.channels = channels; + info.samplerate = samplerate; + info.format = + formats.find(format)->second | subformats.find(subformat)->second; + + if(!(file = sf_open_virtual(&vsf, SFM_WRITE, &info, &f))) { + throw std::runtime_error( + "saveSound: invalid format or could not write stream" + ); + } + + /* Circumvent a bug in Vorbis with large buffers */ + sf_count_t remainCount = input.size() / channels; + sf_count_t offsetCount = 0; + const sf_count_t chunkSize = 65536; + while(remainCount > 0) { + sf_count_t writableCount = std::min(chunkSize, remainCount); + sf_count_t writtenCount = 0; + if(std::is_same::value) { + writtenCount = sf_writef_float( + file, + const_cast(reinterpret_cast(input.data())) + + offsetCount * channels, + writableCount + ); + } else if(std::is_same::value) { + writtenCount = sf_writef_double( + file, + const_cast(reinterpret_cast(input.data())) + + offsetCount * channels, + writableCount + ); + } else if(std::is_same::value) { + writtenCount = sf_writef_int( + file, + const_cast(reinterpret_cast(input.data())) + + offsetCount * channels, + writableCount + ); + } else if(std::is_same::value) { + writtenCount = sf_writef_short( + file, + const_cast(reinterpret_cast(input.data())) + + offsetCount * channels, + writableCount + ); + } else { + throw std::logic_error("saveSound: called with unsupported T"); + } + if(writtenCount != writableCount) { + sf_close(file); + throw std::runtime_error("saveSound: write error"); + } + remainCount -= writtenCount; + offsetCount += writtenCount; } - remainCount -= writtenCount; - offsetCount += writtenCount; - } - sf_close(file); + sf_close(file); } template std::vector loadSound(const std::string&); diff --git a/flashlight/pkg/speech/data/Sound.h b/flashlight/pkg/speech/data/Sound.h index 86390f2..52eba4d 100644 --- a/flashlight/pkg/speech/data/Sound.h +++ 
b/flashlight/pkg/speech/data/Sound.h @@ -13,93 +13,95 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { -enum class SoundFormat { - WAV, // Microsoft WAV format (little endian) - AIFF, // Apple/SGI AIFF format (big endian). - AU, // Sun/NeXT AU format (big endian). - RAW, // RAW PCM data. - PAF, // Ensoniq PARIS file format. - SVX, // Amiga IFF / SVX8 / SV16 format. - NIST, // Sphere NIST format. - VOC, // VOC files. - IRCAM, // Berkeley/IRCAM/CARL - W64, // Sonic Foundry's 64 bit RIFF/WAV - MAT4, // Matlab (tm) V4.2 / GNU Octave 2.0 - MAT5, // Matlab (tm) V5.0 / GNU Octave 2.1 - PVF, // Portable Voice Format - XI, // Fasttracker 2 Extended Instrument - HTK, // HMM Tool Kit format - SDS, // Midi Sample Dump Standard - AVR, // Audio Visual Research - WAVEX, // MS WAVE with WAVEFORMATEX - SD2, // Sound Designer 2 - FLAC, // FLAC lossless file format - CAF, // Core Audio File format - WVE, // Psion WVE format - OGG, // Xiph OGG container - MPC2K, // Akai MPC 2000 sampler - RF64, // RF64 WAV file -}; + enum class SoundFormat { + WAV, // Microsoft WAV format (little endian) + AIFF, // Apple/SGI AIFF format (big endian). + AU, // Sun/NeXT AU format (big endian). + RAW, // RAW PCM data. + PAF, // Ensoniq PARIS file format. + SVX, // Amiga IFF / SVX8 / SV16 format. + NIST, // Sphere NIST format. + VOC, // VOC files. 
+ IRCAM, // Berkeley/IRCAM/CARL + W64, // Sonic Foundry's 64 bit RIFF/WAV + MAT4, // Matlab (tm) V4.2 / GNU Octave 2.0 + MAT5, // Matlab (tm) V5.0 / GNU Octave 2.1 + PVF, // Portable Voice Format + XI, // Fasttracker 2 Extended Instrument + HTK, // HMM Tool Kit format + SDS, // Midi Sample Dump Standard + AVR, // Audio Visual Research + WAVEX, // MS WAVE with WAVEFORMATEX + SD2, // Sound Designer 2 + FLAC, // FLAC lossless file format + CAF, // Core Audio File format + WVE, // Psion WVE format + OGG, // Xiph OGG container + MPC2K, // Akai MPC 2000 sampler + RF64, // RF64 WAV file + }; -enum class SoundSubFormat { - PCM_S8, // Signed 8 bit data - PCM_16, // Signed 16 bit data - PCM_24, // Signed 24 bit data - PCM_32, // Signed 32 bit data - PCM_U8, // Unsigned 8 bit data (WAV and RAW only) - FLOAT, // 32 bit float data - DOUBLE, // 64 bit float data - ULAW, // U-Law encoded. - ALAW, // A-Law encoded. - IMA_ADPCM, // IMA ADPCM. - MS_ADPCM, // Microsoft ADPCM. - GSM610, // GSM 6.10 encoding. - VOX_ADPCM, // Oki Dialogic ADPCM encoding. - G721_32, // 32kbs G721 ADPCM encoding. - G723_24, // 24kbs G723 ADPCM encoding. - G723_40, // 40kbs G723 ADPCM encoding. - DWVW_12, // 12 bit Delta Width Variable Word encoding. - DWVW_16, // 16 bit Delta Width Variable Word encoding. - DWVW_24, // 24 bit Delta Width Variable Word encoding. - DWVW_N, // N bit Delta Width Variable Word encoding. - DPCM_8, // 8 bit differential PCM (XI only) - DPCM_16, // 16 bit differential PCM (XI only) - VORBIS // Xiph Vorbis encoding. -}; + enum class SoundSubFormat { + PCM_S8, // Signed 8 bit data + PCM_16, // Signed 16 bit data + PCM_24, // Signed 24 bit data + PCM_32, // Signed 32 bit data + PCM_U8, // Unsigned 8 bit data (WAV and RAW only) + FLOAT, // 32 bit float data + DOUBLE, // 64 bit float data + ULAW, // U-Law encoded. + ALAW, // A-Law encoded. + IMA_ADPCM, // IMA ADPCM. + MS_ADPCM, // Microsoft ADPCM. + GSM610, // GSM 6.10 encoding. + VOX_ADPCM, // Oki Dialogic ADPCM encoding. 
+ G721_32, // 32kbs G721 ADPCM encoding. + G723_24, // 24kbs G723 ADPCM encoding. + G723_40, // 40kbs G723 ADPCM encoding. + DWVW_12, // 12 bit Delta Width Variable Word encoding. + DWVW_16, // 16 bit Delta Width Variable Word encoding. + DWVW_24, // 24 bit Delta Width Variable Word encoding. + DWVW_N, // N bit Delta Width Variable Word encoding. + DPCM_8, // 8 bit differential PCM (XI only) + DPCM_16, // 16 bit differential PCM (XI only) + VORBIS // Xiph Vorbis encoding. + }; -struct SoundInfo { - int64_t frames; - int64_t samplerate; - int64_t channels; -}; + struct SoundInfo { + int64_t frames; + int64_t samplerate; + int64_t channels; + }; -SoundInfo loadSoundInfo(std::istream& f); -SoundInfo loadSoundInfo(const std::string& filename); + SoundInfo loadSoundInfo(std::istream& f); + SoundInfo loadSoundInfo(const std::string& filename); -template -std::vector loadSound(std::istream& f); -template -std::vector loadSound(const std::string& filename); + template + std::vector loadSound(std::istream& f); + template + std::vector loadSound(const std::string& filename); -template -void saveSound( - std::ostream& f, - const std::vector& input, - int64_t samplerate, - int64_t channels, - const SoundFormat format, - const SoundSubFormat subformat); + template + void saveSound( + std::ostream& f, + const std::vector& input, + int64_t samplerate, + int64_t channels, + const SoundFormat format, + const SoundSubFormat subformat + ); -template -void saveSound( - const std::string& filename, - const std::vector& input, - int64_t samplerate, - int64_t channels, - const SoundFormat format, - const SoundSubFormat subformat); -} // namespace speech + template + void saveSound( + const std::string& filename, + const std::vector& input, + int64_t samplerate, + int64_t channels, + const SoundFormat format, + const SoundSubFormat subformat + ); + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/data/Utils.cpp 
b/flashlight/pkg/speech/data/Utils.cpp index 3513ca1..1798645 100644 --- a/flashlight/pkg/speech/data/Utils.cpp +++ b/flashlight/pkg/speech/data/Utils.cpp @@ -23,45 +23,49 @@ std::vector wrd2Target( float targetSamplePct /* = 0 */, bool fallback2LtrWordSepLeft /* = false */, bool fallback2LtrWordSepRight /* = false */, - bool skipUnk /* = false */) { - // find the word in the lexicon and use its spelling - auto lit = lexicon.find(word); - if (lit != lexicon.end()) { - // sample random spelling if word has different spellings - if (lit->second.size() > 1 && - targetSamplePct > - static_cast(std::rand()) / static_cast(RAND_MAX)) { - return lit->second[std::rand() % lit->second.size()]; - } else { - return lit->second[0]; + bool skipUnk /* = false */ +) { + // find the word in the lexicon and use its spelling + auto lit = lexicon.find(word); + if(lit != lexicon.end()) { + // sample random spelling if word has different spellings + if( + lit->second.size() > 1 + && targetSamplePct + > static_cast(std::rand()) / static_cast(RAND_MAX) + ) { + return lit->second[std::rand() % lit->second.size()]; + } else { + return lit->second[0]; + } } - } - std::vector word2tokens; - if (fallback2LtrWordSepLeft || fallback2LtrWordSepRight) { - if (fallback2LtrWordSepLeft && !wordSeparator.empty()) { - // add word separator at the beginning of fallback word - word2tokens.push_back(wordSeparator); - } - auto tokens = splitWrd(word); - for (const auto& tkn : tokens) { - if (dict.contains(tkn)) { - word2tokens.push_back(tkn); - } else if (!skipUnk) { - throw std::invalid_argument( - "Unknown token '" + tkn + - "' when falling back to letter target for the unknown word: " + - word); - } - } - if (fallback2LtrWordSepRight && !wordSeparator.empty()) { - // add word separator at the end of fallback word - word2tokens.push_back(wordSeparator); + std::vector word2tokens; + if(fallback2LtrWordSepLeft || fallback2LtrWordSepRight) { + if(fallback2LtrWordSepLeft && !wordSeparator.empty()) { + // add 
word separator at the beginning of fallback word + word2tokens.push_back(wordSeparator); + } + auto tokens = splitWrd(word); + for(const auto& tkn : tokens) { + if(dict.contains(tkn)) { + word2tokens.push_back(tkn); + } else if(!skipUnk) { + throw std::invalid_argument( + "Unknown token '" + tkn + + "' when falling back to letter target for the unknown word: " + + word + ); + } + } + if(fallback2LtrWordSepRight && !wordSeparator.empty()) { + // add word separator at the end of fallback word + word2tokens.push_back(wordSeparator); + } + } else if(!skipUnk) { + throw std::invalid_argument("Unknown word in the lexicon: " + word); } - } else if (!skipUnk) { - throw std::invalid_argument("Unknown word in the lexicon: " + word); - } - return word2tokens; + return word2tokens; } std::vector wrd2Target( @@ -72,45 +76,51 @@ std::vector wrd2Target( float targetSamplePct /* = 0 */, bool fallback2LtrWordSepLeft /* = false */, bool fallback2LtrWordSepRight /* = false */, - bool skipUnk /* = false */) { - std::vector res; - for (const auto& w : words) { - auto w2tokens = wrd2Target( - w, - lexicon, - dict, - wordSeparator, - targetSamplePct, - fallback2LtrWordSepLeft, - fallback2LtrWordSepRight, - skipUnk); + bool skipUnk /* = false */ +) { + std::vector res; + for(const auto& w : words) { + auto w2tokens = wrd2Target( + w, + lexicon, + dict, + wordSeparator, + targetSamplePct, + fallback2LtrWordSepLeft, + fallback2LtrWordSepRight, + skipUnk + ); - if (w2tokens.empty()) { - continue; + if(w2tokens.empty()) { + continue; + } + res.insert(res.end(), w2tokens.begin(), w2tokens.end()); } - res.insert(res.end(), w2tokens.begin(), w2tokens.end()); - } - return res; + return res; } std::pair getFeatureType( const std::string& featuresType, int channels, - const fl::lib::audio::FeatureParams& featParams) { - if (featuresType == kFeaturesPow) { - return std::make_pair( - featParams.powSpecFeatSz(), FeatureType::POW_SPECTRUM); - } else if (featuresType == kFeaturesMFSC) { - return 
std::make_pair(featParams.mfscFeatSz(), FeatureType::MFSC); - } else if (featuresType == kFeaturesMFSC) { - return std::make_pair(featParams.mfccFeatSz(), FeatureType::MFCC); - } else if (featuresType == kFeaturesRaw) { - return std::make_pair(channels, FeatureType::NONE); - } else { - throw std::runtime_error( - "Unsupported feature type for audio preprocessing '" + featuresType + - "'"); - } + const fl::lib::audio::FeatureParams& featParams +) { + if(featuresType == kFeaturesPow) { + return std::make_pair( + featParams.powSpecFeatSz(), + FeatureType::POW_SPECTRUM + ); + } else if(featuresType == kFeaturesMFSC) { + return std::make_pair(featParams.mfscFeatSz(), FeatureType::MFSC); + } else if(featuresType == kFeaturesMFSC) { + return std::make_pair(featParams.mfccFeatSz(), FeatureType::MFCC); + } else if(featuresType == kFeaturesRaw) { + return std::make_pair(channels, FeatureType::NONE); + } else { + throw std::runtime_error( + "Unsupported feature type for audio preprocessing '" + featuresType + + "'" + ); + } } } // namespace fl diff --git a/flashlight/pkg/speech/data/Utils.h b/flashlight/pkg/speech/data/Utils.h index 31de4aa..fa045e0 100644 --- a/flashlight/pkg/speech/data/Utils.h +++ b/flashlight/pkg/speech/data/Utils.h @@ -17,33 +17,36 @@ namespace fl { namespace pkg { -namespace speech { - -std::vector wrd2Target( - const std::string& word, - const lib::text::LexiconMap& lexicon, - const lib::text::Dictionary& dict, - const std::string& wordSeparator = "", - float targetSamplePct = 0, - bool fallback2LtrWordSepLeft = false, - bool fallback2LtrWordSepRight = false, - bool skipUnk = false); - -std::vector wrd2Target( - const std::vector& words, - const lib::text::LexiconMap& lexicon, - const lib::text::Dictionary& dict, - const std::string& wordSeparator = "", - float targetSamplePct = 0, - bool fallback2LtrWordSepLeft = false, - bool fallback2LtrWordSepRight = false, - bool skipUnk = false); - -std::pair getFeatureType( - const std::string& featuresType, - 
int channels, - const fl::lib::audio::FeatureParams& featParams); - -} // namespace speech + namespace speech { + + std::vector wrd2Target( + const std::string& word, + const lib::text::LexiconMap& lexicon, + const lib::text::Dictionary& dict, + const std::string& wordSeparator = "", + float targetSamplePct = 0, + bool fallback2LtrWordSepLeft = false, + bool fallback2LtrWordSepRight = false, + bool skipUnk = false + ); + + std::vector wrd2Target( + const std::vector& words, + const lib::text::LexiconMap& lexicon, + const lib::text::Dictionary& dict, + const std::string& wordSeparator = "", + float targetSamplePct = 0, + bool fallback2LtrWordSepLeft = false, + bool fallback2LtrWordSepRight = false, + bool skipUnk = false + ); + + std::pair getFeatureType( + const std::string& featuresType, + int channels, + const fl::lib::audio::FeatureParams& featParams + ); + + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/decoder/ConvLmModule.cpp b/flashlight/pkg/speech/decoder/ConvLmModule.cpp index 3cf5c18..361113f 100644 --- a/flashlight/pkg/speech/decoder/ConvLmModule.cpp +++ b/flashlight/pkg/speech/decoder/ConvLmModule.cpp @@ -16,51 +16,55 @@ namespace fl::pkg::speech { GetConvLmScoreFunc buildGetConvLmScoreFunction( - std::shared_ptr network) { - auto getConvLmScoreFunc = [network]( - const std::vector& inputs, - const std::vector& lastTokenPositions, - int sampleSize = -1, - int batchSize = 1) { - sampleSize = sampleSize > 0 ? 
sampleSize : inputs.size(); - if (sampleSize * batchSize > inputs.size()) { - throw std::invalid_argument( - "[ConvLM] Incorrect sample size (" + std::to_string(sampleSize) + - ") or batch size (" + std::to_string(batchSize) + ")."); - } - Tensor inputData = Tensor::fromVector({sampleSize, batchSize}, inputs); - fl::Variable output = network->forward({fl::input(inputData)})[0]; + std::shared_ptr network +) { + auto getConvLmScoreFunc = [network]( + const std::vector& inputs, + const std::vector& lastTokenPositions, + int sampleSize = -1, + int batchSize = 1) { + sampleSize = sampleSize > 0 ? sampleSize : inputs.size(); + if(sampleSize * batchSize > inputs.size()) { + throw std::invalid_argument( + "[ConvLM] Incorrect sample size (" + std::to_string(sampleSize) + + ") or batch size (" + std::to_string(batchSize) + ")." + ); + } + Tensor inputData = Tensor::fromVector({sampleSize, batchSize}, inputs); + fl::Variable output = network->forward({fl::input(inputData)})[0]; - if (fl::countNonzero(fl::isnan(output.tensor())).asScalar() != 0) { - throw std::runtime_error("[ConvLM] Encountered NaNs in propagation"); - } - int32_t C = output.dim(0), T = output.dim(1), B = output.dim(2); - if (B != batchSize) { - throw std::logic_error( - "[ConvLM]: incorrect predictions: batch should be " + - std::to_string(batchSize) + " but it is " + std::to_string(B)); - } - if (batchSize != static_cast(lastTokenPositions.size())) { - throw std::logic_error( - "[ConvLM]: incorrect postions for accessing: size should be " + - std::to_string(batchSize) + " but it is " + - std::to_string(lastTokenPositions.size())); - } - // output (c, t, b) - // set global indices: offset by channel - Tensor globalIndices = fl::iota({C, 1}, {1, B}, fl::dtype::s32); - // set global indices: offset by batch - globalIndices = - globalIndices + fl::iota({1, B}, {C, 1}, fl::dtype::s32) * T * C; - // set global indices: offset by time which we need to take - globalIndices = globalIndices + - 
fl::tile(Tensor::fromVector({1, B}, lastTokenPositions), {C, 1}) * C; - Tensor preds = - fl::reshape(output.tensor().flatten()(globalIndices.flatten()), {C, B}); - // vector of B X C predictions - return preds.toHostVector(); - }; + if(fl::countNonzero(fl::isnan(output.tensor())).asScalar() != 0) { + throw std::runtime_error("[ConvLM] Encountered NaNs in propagation"); + } + int32_t C = output.dim(0), T = output.dim(1), B = output.dim(2); + if(B != batchSize) { + throw std::logic_error( + "[ConvLM]: incorrect predictions: batch should be " + + std::to_string(batchSize) + " but it is " + std::to_string(B) + ); + } + if(batchSize != static_cast(lastTokenPositions.size())) { + throw std::logic_error( + "[ConvLM]: incorrect postions for accessing: size should be " + + std::to_string(batchSize) + " but it is " + + std::to_string(lastTokenPositions.size()) + ); + } + // output (c, t, b) + // set global indices: offset by channel + Tensor globalIndices = fl::iota({C, 1}, {1, B}, fl::dtype::s32); + // set global indices: offset by batch + globalIndices = + globalIndices + fl::iota({1, B}, {C, 1}, fl::dtype::s32) * T * C; + // set global indices: offset by time which we need to take + globalIndices = globalIndices + + fl::tile(Tensor::fromVector({1, B}, lastTokenPositions), {C, 1}) * C; + Tensor preds = + fl::reshape(output.tensor().flatten()(globalIndices.flatten()), {C, B}); + // vector of B X C predictions + return preds.toHostVector(); + }; - return getConvLmScoreFunc; + return getConvLmScoreFunc; } } // namespace fl diff --git a/flashlight/pkg/speech/decoder/ConvLmModule.h b/flashlight/pkg/speech/decoder/ConvLmModule.h index 807992a..6131d76 100644 --- a/flashlight/pkg/speech/decoder/ConvLmModule.h +++ b/flashlight/pkg/speech/decoder/ConvLmModule.h @@ -16,13 +16,13 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { -using GetConvLmScoreFunc = std::function(const std::vector&, const std::vector&, int, int)>; + using GetConvLmScoreFunc = 
std::function(const std::vector&, const std::vector&, int, int)>; -GetConvLmScoreFunc buildGetConvLmScoreFunction(std::shared_ptr network); + GetConvLmScoreFunc buildGetConvLmScoreFunction(std::shared_ptr network); -} // namespace speech + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/decoder/DecodeMaster.cpp b/flashlight/pkg/speech/decoder/DecodeMaster.cpp index b1d7994..8b89da1 100644 --- a/flashlight/pkg/speech/decoder/DecodeMaster.cpp +++ b/flashlight/pkg/speech/decoder/DecodeMaster.cpp @@ -27,10 +27,10 @@ constexpr size_t kDMWordPredIdx = 3; using namespace fl; Tensor removeNegative(const fl::Tensor& arr) { - return arr(arr >= 0); + return arr(arr >= 0); } Tensor removePad(const Tensor& arr, int32_t padIdx) { - return arr(arr != padIdx); + return arr(arr != padIdx); } } // namespace @@ -45,175 +45,200 @@ DecodeMaster::DecodeMaster( const bool usePlugin, const fl::lib::text::Dictionary& tokenDict, const fl::lib::text::Dictionary& wordDict, - const DecodeMasterTrainOptions& trainOpt) - : net_(net), - lm_(lm), - isTokenLM_(isTokenLM), - usePlugin_(usePlugin), - tokenDict_(tokenDict), - wordDict_(wordDict), - trainOpt_(trainOpt) {} + const DecodeMasterTrainOptions& trainOpt +) : net_(net), + lm_(lm), + isTokenLM_(isTokenLM), + usePlugin_(usePlugin), + tokenDict_(tokenDict), + wordDict_(wordDict), + trainOpt_(trainOpt) {} -std::pair, std::vector> -DecodeMaster::computeMetrics(const std::shared_ptr& predDataset) { - fl::EditDistanceMeter wordEditDist, tokenEditDist; +std::pair, + std::vector> DecodeMaster::computeMetrics(const std::shared_ptr& predDataset) { + fl::EditDistanceMeter wordEditDist, tokenEditDist; - for (auto& sample : *predDataset) { - if (sample.size() <= kDMWordPredIdx) { - throw std::runtime_error( - "computeMetrics: need token/word target to compute WER"); - } - auto predictionWrd = sample[kDMWordPredIdx]; - auto targetWrd = sample[kDMWordTargetIdx]; - auto prediction = sample[kDMTokenPredIdx]; - auto 
target = sample[kDMTokenTargetIdx]; - bool isPredictingWrd = !predictionWrd.isEmpty(); + for(auto& sample : *predDataset) { + if(sample.size() <= kDMWordPredIdx) { + throw std::runtime_error( + "computeMetrics: need token/word target to compute WER" + ); + } + auto predictionWrd = sample[kDMWordPredIdx]; + auto targetWrd = sample[kDMWordTargetIdx]; + auto prediction = sample[kDMTokenPredIdx]; + auto target = sample[kDMTokenTargetIdx]; + bool isPredictingWrd = !predictionWrd.isEmpty(); - if (prediction.ndim() > 2 || target.ndim() > 2) { - throw std::runtime_error( - "computeMetrics: expecting TxB for prediction and target"); - } - if (isPredictingWrd && (predictionWrd.ndim() > 2 || targetWrd.ndim() > 2)) { - throw std::runtime_error( - "computeMetrics: expecting TxB for prediction and target"); - } + if(prediction.ndim() > 2 || target.ndim() > 2) { + throw std::runtime_error( + "computeMetrics: expecting TxB for prediction and target" + ); + } + if(isPredictingWrd && (predictionWrd.ndim() > 2 || targetWrd.ndim() > 2)) { + throw std::runtime_error( + "computeMetrics: expecting TxB for prediction and target" + ); + } - if (!prediction.isEmpty() && !target.isEmpty() && - (prediction.dim(1) != target.dim(1))) { - throw std::runtime_error( - "computeMetrics: prediction and target do not match"); - } - if (isPredictingWrd && !predictionWrd.isEmpty() && !targetWrd.isEmpty() && - (predictionWrd.dim(1) != targetWrd.dim(1))) { - throw std::runtime_error( - "computeMetrics: prediction and target do not match"); - } - // token predictions and target - std::vector predictionV = prediction.toHostVector(); - std::vector targetV = target.toHostVector(); + if( + !prediction.isEmpty() && !target.isEmpty() + && (prediction.dim(1) != target.dim(1)) + ) { + throw std::runtime_error( + "computeMetrics: prediction and target do not match" + ); + } + if( + isPredictingWrd && !predictionWrd.isEmpty() && !targetWrd.isEmpty() + && (predictionWrd.dim(1) != targetWrd.dim(1)) + ) { + throw 
std::runtime_error( + "computeMetrics: prediction and target do not match" + ); + } + // token predictions and target + std::vector predictionV = prediction.toHostVector(); + std::vector targetV = target.toHostVector(); - auto predictionS = computeStringPred(predictionV); - auto targetS = computeStringTarget(targetV); - tokenEditDist.add(predictionS, targetS); + auto predictionS = computeStringPred(predictionV); + auto targetS = computeStringTarget(targetV); + tokenEditDist.add(predictionS, targetS); - std::vector targetWrdS, predictionWrdS; - if (isPredictingWrd) { - targetWrdS = wrdIdx2Wrd(targetWrd.toHostVector(), wordDict_); - predictionWrdS = wrdIdx2Wrd(predictionWrd.toHostVector(), wordDict_); - } else { - targetWrdS = tkn2Wrd(targetS, trainOpt_.wordSep); - predictionWrdS = tkn2Wrd(predictionS, trainOpt_.wordSep); + std::vector targetWrdS, predictionWrdS; + if(isPredictingWrd) { + targetWrdS = wrdIdx2Wrd(targetWrd.toHostVector(), wordDict_); + predictionWrdS = wrdIdx2Wrd(predictionWrd.toHostVector(), wordDict_); + } else { + targetWrdS = tkn2Wrd(targetS, trainOpt_.wordSep); + predictionWrdS = tkn2Wrd(predictionS, trainOpt_.wordSep); + } + wordEditDist.add(predictionWrdS, targetWrdS); } - wordEditDist.add(predictionWrdS, targetWrdS); - } - return {tokenEditDist.value(), wordEditDist.value()}; + return {tokenEditDist.value(), wordEditDist.value()}; } std::shared_ptr DecodeMaster::buildTrie( const fl::lib::text::LexiconMap& lexicon, - fl::lib::text::SmearingMode smearMode) const { - auto trie = std::make_shared( - tokenDict_.indexSize(), tokenDict_.getIndex(trainOpt_.wordSep)); - auto startState = lm_->start(false); - for (auto& it : lexicon) { - const std::string& word = it.first; - int usrIdx = wordDict_.getIndex(word); - float score = 0; - if (!isTokenLM_) { - fl::lib::text::LMStatePtr dummyState; - std::tie(dummyState, score) = lm_->score(startState, usrIdx); - } - for (auto& tokens : it.second) { - auto tokensTensor = tkn2Idx(tokens, tokenDict_, 
trainOpt_.repLabel); - trie->insert(tokensTensor, usrIdx, score); + fl::lib::text::SmearingMode smearMode +) const { + auto trie = std::make_shared( + tokenDict_.indexSize(), + tokenDict_.getIndex(trainOpt_.wordSep) + ); + auto startState = lm_->start(false); + for(auto& it : lexicon) { + const std::string& word = it.first; + int usrIdx = wordDict_.getIndex(word); + float score = 0; + if(!isTokenLM_) { + fl::lib::text::LMStatePtr dummyState; + std::tie(dummyState, score) = lm_->score(startState, usrIdx); + } + for(auto& tokens : it.second) { + auto tokensTensor = tkn2Idx(tokens, tokenDict_, trainOpt_.repLabel); + trie->insert(tokensTensor, usrIdx, score); + } } - } - // Smearing - trie->smear(smearMode); - return trie; + // Smearing + trie->smear(smearMode); + return trie; } std::shared_ptr DecodeMaster::forward( - const std::shared_ptr& ds) { - auto emissionDataset = std::make_shared(); - for (auto& batch : *ds) { - Tensor output; - if (batch.empty()) { - continue; - } - if (usePlugin_) { - output = net_->forward({fl::input(batch[kInputIdx]), - fl::noGrad(batch[kDurationIdx])}) - .front() - .tensor(); - } else { - output = fl::pkg::runtime::forwardSequentialModuleWithPadMask( - fl::input(batch[kInputIdx]), net_, batch[kDurationIdx]) - .tensor(); - } - if (output.ndim() > 3) { - throw std::runtime_error("output should be NxTxB"); - } - Tensor tokenTarget = - (batch.size() > kTargetIdx ? batch[kTargetIdx] : Tensor()); - Tensor wordTarget = (batch.size() > kWordIdx ? 
batch[kWordIdx] : Tensor()); + const std::shared_ptr& ds +) { + auto emissionDataset = std::make_shared(); + for(auto& batch : *ds) { + Tensor output; + if(batch.empty()) { + continue; + } + if(usePlugin_) { + output = net_->forward( + {fl::input(batch[kInputIdx]), + fl::noGrad(batch[kDurationIdx])} + ) + .front() + .tensor(); + } else { + output = fl::pkg::runtime::forwardSequentialModuleWithPadMask( + fl::input(batch[kInputIdx]), + net_, + batch[kDurationIdx] + ) + .tensor(); + } + if(output.ndim() > 3) { + throw std::runtime_error("output should be NxTxB"); + } + Tensor tokenTarget = + (batch.size() > kTargetIdx ? batch[kTargetIdx] : Tensor()); + Tensor wordTarget = (batch.size() > kWordIdx ? batch[kWordIdx] : Tensor()); - int B = output.dim(2); - if (!tokenTarget.isEmpty() && - (tokenTarget.ndim() > 2 || tokenTarget.dim(1) != B)) { - throw std::runtime_error("token target should be LxB"); + int B = output.dim(2); + if( + !tokenTarget.isEmpty() + && (tokenTarget.ndim() > 2 || tokenTarget.dim(1) != B) + ) { + throw std::runtime_error("token target should be LxB"); + } + if( + !wordTarget.isEmpty() + && (wordTarget.ndim() > 2 || wordTarget.dim(1) != B) + ) { + throw std::runtime_error("word target should be LxB"); + } + // todo s2s, if we pad only with -1 we will be good here (not pad with eos) + for(int b = 0; b < B; b++) { + std::vector res(4); + res[kDMTokenPredIdx] = output(fl::span, fl::span, b); + res[kDMTokenTargetIdx] = removeNegative(tokenTarget(fl::span, b)); + res[kDMTokenTargetIdx] = + removePad(res[kDMTokenTargetIdx], trainOpt_.targetPadIdx); + res[kDMWordTargetIdx] = removeNegative(wordTarget(fl::span, b)); + res[kDMWordTargetIdx] = + removePad(res[kDMWordTargetIdx], trainOpt_.targetPadIdx); + emissionDataset->add(res); + } } - if (!wordTarget.isEmpty() && - (wordTarget.ndim() > 2 || wordTarget.dim(1) != B)) { - throw std::runtime_error("word target should be LxB"); - } - // todo s2s, if we pad only with -1 we will be good here (not pad with eos) - 
for (int b = 0; b < B; b++) { - std::vector res(4); - res[kDMTokenPredIdx] = output(fl::span, fl::span, b); - res[kDMTokenTargetIdx] = removeNegative(tokenTarget(fl::span, b)); - res[kDMTokenTargetIdx] = - removePad(res[kDMTokenTargetIdx], trainOpt_.targetPadIdx); - res[kDMWordTargetIdx] = removeNegative(wordTarget(fl::span, b)); - res[kDMWordTargetIdx] = - removePad(res[kDMWordTargetIdx], trainOpt_.targetPadIdx); - emissionDataset->add(res); - } - } - emissionDataset->writeIndex(); - return emissionDataset; + emissionDataset->writeIndex(); + return emissionDataset; } std::shared_ptr DecodeMaster::decode( const std::shared_ptr& emissionDataset, - fl::lib::text::Decoder& decoder) { - auto predDataset = std::make_shared(); - for (auto& sample : *emissionDataset) { - auto emission = sample[kDMTokenPredIdx]; - if (emission.ndim() > 2) { - throw std::runtime_error("emission should be NxT"); - } - std::vector emissionV(emission.elements()); - emission.astype(fl::dtype::f32).host(emissionV.data()); - auto results = - decoder.decode(emissionV.data(), emission.dim(1), emission.dim(0)); + fl::lib::text::Decoder& decoder +) { + auto predDataset = std::make_shared(); + for(auto& sample : *emissionDataset) { + auto emission = sample[kDMTokenPredIdx]; + if(emission.ndim() > 2) { + throw std::runtime_error("emission should be NxT"); + } + std::vector emissionV(emission.elements()); + emission.astype(fl::dtype::f32).host(emissionV.data()); + auto results = + decoder.decode(emissionV.data(), emission.dim(1), emission.dim(0)); - std::vector tokensV, wordsV; - if (!results.empty()) { - tokensV = results[0].tokens; - wordsV = results[0].words; + std::vector tokensV, wordsV; + if(!results.empty()) { + tokensV = results[0].tokens; + wordsV = results[0].words; + } + tokensV.erase( + std::remove(tokensV.begin(), tokensV.end(), -1), + tokensV.end() + ); + wordsV.erase(std::remove(wordsV.begin(), wordsV.end(), -1), wordsV.end()); + sample[kDMTokenPredIdx] = + (!tokensV.empty() ? 
Tensor::fromVector(tokensV) : Tensor()); + sample[kDMWordPredIdx] = + (!wordsV.empty() ? Tensor::fromVector(wordsV) : Tensor()); + predDataset->add(sample); } - tokensV.erase( - std::remove(tokensV.begin(), tokensV.end(), -1), tokensV.end()); - wordsV.erase(std::remove(wordsV.begin(), wordsV.end(), -1), wordsV.end()); - sample[kDMTokenPredIdx] = - (!tokensV.empty() ? Tensor::fromVector(tokensV) : Tensor()); - sample[kDMWordPredIdx] = - (!wordsV.empty() ? Tensor::fromVector(wordsV) : Tensor()); - predDataset->add(sample); - } - predDataset->writeIndex(); - return predDataset; + predDataset->writeIndex(); + return predDataset; } TokenDecodeMaster::TokenDecodeMaster( @@ -223,75 +248,81 @@ TokenDecodeMaster::TokenDecodeMaster( const bool usePlugin, const fl::lib::text::Dictionary& tokenDict, const fl::lib::text::Dictionary& wordDict, - const DecodeMasterTrainOptions& trainOpt) - : DecodeMaster(net, lm, true, usePlugin, tokenDict, wordDict, trainOpt), - transition_(transition) {} + const DecodeMasterTrainOptions& trainOpt +) : DecodeMaster(net, lm, true, usePlugin, tokenDict, wordDict, trainOpt), + transition_(transition) {} std::shared_ptr TokenDecodeMaster::decode( const std::shared_ptr& emissionDataset, - DecodeMasterLexiconFreeOptions opt) { - fl::lib::text::LexiconFreeDecoderOptions decoderOpt{ - .beamSize = opt.beamSize, - .beamSizeToken = opt.beamSizeToken, - .beamThreshold = opt.beamThreshold, - .lmWeight = opt.lmWeight, - .silScore = opt.silScore, - .logAdd = opt.logAdd, - .criterionType = fl::lib::text::CriterionType::CTC}; - auto silIdx = tokenDict_.getIndex(opt.silToken); - auto blankIdx = tokenDict_.getIndex(opt.blankToken); - fl::lib::text::LexiconFreeDecoder decoder( - decoderOpt, lm_, silIdx, blankIdx, transition_); - return DecodeMaster::decode(emissionDataset, decoder); + DecodeMasterLexiconFreeOptions opt +) { + fl::lib::text::LexiconFreeDecoderOptions decoderOpt{ + .beamSize = opt.beamSize, + .beamSizeToken = opt.beamSizeToken, + .beamThreshold = 
opt.beamThreshold, + .lmWeight = opt.lmWeight, + .silScore = opt.silScore, + .logAdd = opt.logAdd, + .criterionType = fl::lib::text::CriterionType::CTC}; + auto silIdx = tokenDict_.getIndex(opt.silToken); + auto blankIdx = tokenDict_.getIndex(opt.blankToken); + fl::lib::text::LexiconFreeDecoder decoder( + decoderOpt, lm_, silIdx, blankIdx, transition_); + return DecodeMaster::decode(emissionDataset, decoder); } std::shared_ptr TokenDecodeMaster::decode( const std::shared_ptr& emissionDataset, const fl::lib::text::LexiconMap& lexicon, - DecodeMasterLexiconOptions opt) { - auto trie = buildTrie(lexicon, opt.smearMode); - fl::lib::text::LexiconDecoderOptions decoderOpt{ - .beamSize = opt.beamSize, - .beamSizeToken = opt.beamSizeToken, - .beamThreshold = opt.beamThreshold, - .lmWeight = opt.lmWeight, - .wordScore = opt.wordScore, - .unkScore = opt.unkScore, - .silScore = opt.silScore, - .logAdd = opt.logAdd, - .criterionType = fl::lib::text::CriterionType::CTC}; - auto silIdx = tokenDict_.getIndex(opt.silToken); - auto blankIdx = tokenDict_.getIndex(opt.blankToken); - auto unkWordIdx = wordDict_.getIndex(fl::lib::text::kUnkToken); - fl::lib::text::LexiconDecoder decoder( - decoderOpt, trie, lm_, silIdx, blankIdx, unkWordIdx, transition_, true); - return DecodeMaster::decode(emissionDataset, decoder); + DecodeMasterLexiconOptions opt +) { + auto trie = buildTrie(lexicon, opt.smearMode); + fl::lib::text::LexiconDecoderOptions decoderOpt{ + .beamSize = opt.beamSize, + .beamSizeToken = opt.beamSizeToken, + .beamThreshold = opt.beamThreshold, + .lmWeight = opt.lmWeight, + .wordScore = opt.wordScore, + .unkScore = opt.unkScore, + .silScore = opt.silScore, + .logAdd = opt.logAdd, + .criterionType = fl::lib::text::CriterionType::CTC}; + auto silIdx = tokenDict_.getIndex(opt.silToken); + auto blankIdx = tokenDict_.getIndex(opt.blankToken); + auto unkWordIdx = wordDict_.getIndex(fl::lib::text::kUnkToken); + fl::lib::text::LexiconDecoder decoder( + decoderOpt, trie, lm_, silIdx, 
blankIdx, unkWordIdx, transition_, true); + return DecodeMaster::decode(emissionDataset, decoder); } std::vector TokenDecodeMaster::computeStringPred( - const std::vector& tokenIdxSeq) { - return tknPrediction2Ltr( - tokenIdxSeq, - tokenDict_, - "ctc", - trainOpt_.surround, - false, // eosToken - trainOpt_.repLabel, - trainOpt_.wordSepIsPartOfToken, - trainOpt_.wordSep); + const std::vector& tokenIdxSeq +) { + return tknPrediction2Ltr( + tokenIdxSeq, + tokenDict_, + "ctc", + trainOpt_.surround, + false, // eosToken + trainOpt_.repLabel, + trainOpt_.wordSepIsPartOfToken, + trainOpt_.wordSep + ); } std::vector TokenDecodeMaster::computeStringTarget( - const std::vector& tokenIdxSeq) { - return tknTarget2Ltr( - tokenIdxSeq, - tokenDict_, - "ctc", - trainOpt_.surround, - false, // eosToken - trainOpt_.repLabel, - trainOpt_.wordSepIsPartOfToken, - trainOpt_.wordSep); + const std::vector& tokenIdxSeq +) { + return tknTarget2Ltr( + tokenIdxSeq, + tokenDict_, + "ctc", + trainOpt_.surround, + false, // eosToken + trainOpt_.repLabel, + trainOpt_.wordSepIsPartOfToken, + trainOpt_.wordSep + ); } WordDecodeMaster::WordDecodeMaster( @@ -301,57 +332,62 @@ WordDecodeMaster::WordDecodeMaster( const bool usePlugin, const fl::lib::text::Dictionary& tokenDict, const fl::lib::text::Dictionary& wordDict, - const DecodeMasterTrainOptions& trainOpt) - : DecodeMaster(net, lm, false, usePlugin, tokenDict, wordDict, trainOpt), - transition_(transition) {} + const DecodeMasterTrainOptions& trainOpt +) : DecodeMaster(net, lm, false, usePlugin, tokenDict, wordDict, trainOpt), + transition_(transition) {} std::shared_ptr WordDecodeMaster::decode( const std::shared_ptr& emissionDataset, const fl::lib::text::LexiconMap& lexicon, - DecodeMasterLexiconOptions opt) { - auto trie = buildTrie(lexicon, opt.smearMode); - fl::lib::text::LexiconDecoderOptions decoderOpt{ - .beamSize = opt.beamSize, - .beamSizeToken = opt.beamSizeToken, - .beamThreshold = opt.beamThreshold, - .lmWeight = opt.lmWeight, - 
.wordScore = opt.wordScore, - .unkScore = opt.unkScore, - .silScore = opt.silScore, - .logAdd = opt.logAdd, - .criterionType = fl::lib::text::CriterionType::CTC}; - auto silIdx = tokenDict_.getIndex(opt.silToken); - auto blankIdx = tokenDict_.getIndex(opt.blankToken); - auto unkWordIdx = wordDict_.getIndex(opt.unkToken); - fl::lib::text::LexiconDecoder decoder( - decoderOpt, trie, lm_, silIdx, blankIdx, unkWordIdx, transition_, false); - return DecodeMaster::decode(emissionDataset, decoder); + DecodeMasterLexiconOptions opt +) { + auto trie = buildTrie(lexicon, opt.smearMode); + fl::lib::text::LexiconDecoderOptions decoderOpt{ + .beamSize = opt.beamSize, + .beamSizeToken = opt.beamSizeToken, + .beamThreshold = opt.beamThreshold, + .lmWeight = opt.lmWeight, + .wordScore = opt.wordScore, + .unkScore = opt.unkScore, + .silScore = opt.silScore, + .logAdd = opt.logAdd, + .criterionType = fl::lib::text::CriterionType::CTC}; + auto silIdx = tokenDict_.getIndex(opt.silToken); + auto blankIdx = tokenDict_.getIndex(opt.blankToken); + auto unkWordIdx = wordDict_.getIndex(opt.unkToken); + fl::lib::text::LexiconDecoder decoder( + decoderOpt, trie, lm_, silIdx, blankIdx, unkWordIdx, transition_, false); + return DecodeMaster::decode(emissionDataset, decoder); } std::vector WordDecodeMaster::computeStringPred( - const std::vector& tokenIdxSeq) { - return tknPrediction2Ltr( - tokenIdxSeq, - tokenDict_, - "ctc", - trainOpt_.surround, - false, // eosToken - trainOpt_.repLabel, - trainOpt_.wordSepIsPartOfToken, - trainOpt_.wordSep); + const std::vector& tokenIdxSeq +) { + return tknPrediction2Ltr( + tokenIdxSeq, + tokenDict_, + "ctc", + trainOpt_.surround, + false, // eosToken + trainOpt_.repLabel, + trainOpt_.wordSepIsPartOfToken, + trainOpt_.wordSep + ); } std::vector WordDecodeMaster::computeStringTarget( - const std::vector& tokenIdxSeq) { - return tknTarget2Ltr( - tokenIdxSeq, - tokenDict_, - "ctc", - trainOpt_.surround, - false, // eosToken - trainOpt_.repLabel, - 
trainOpt_.wordSepIsPartOfToken, - trainOpt_.wordSep); + const std::vector& tokenIdxSeq +) { + return tknTarget2Ltr( + tokenIdxSeq, + tokenDict_, + "ctc", + trainOpt_.surround, + false, // eosToken + trainOpt_.repLabel, + trainOpt_.wordSepIsPartOfToken, + trainOpt_.wordSep + ); } } // namespace fl diff --git a/flashlight/pkg/speech/decoder/DecodeMaster.h b/flashlight/pkg/speech/decoder/DecodeMaster.h index aea17ea..a14bd66 100644 --- a/flashlight/pkg/speech/decoder/DecodeMaster.h +++ b/flashlight/pkg/speech/decoder/DecodeMaster.h @@ -17,154 +17,170 @@ namespace fl { namespace pkg { -namespace speech { - -struct DecodeMasterLexiconFreeOptions { - int beamSize; - int beamSizeToken; - double beamThreshold; - double lmWeight; - double silScore; - bool logAdd; - std::string silToken; - std::string blankToken; -}; - -struct DecodeMasterLexiconOptions { - int beamSize; - int beamSizeToken; - double beamThreshold; - double lmWeight; - double silScore; - double wordScore; - double unkScore; - bool logAdd; - std::string silToken; - std::string blankToken; - std::string unkToken; - fl::lib::text::SmearingMode smearMode; -}; - -struct DecodeMasterTrainOptions { - int repLabel; - bool wordSepIsPartOfToken; - std::string surround; - std::string wordSep; - int32_t targetPadIdx; -}; - -class DecodeMaster { - public: - explicit DecodeMaster( - const std::shared_ptr net, - const std::shared_ptr lm, - const bool isTokenLM, - const bool usePlugin, - const fl::lib::text::Dictionary& tokenDict, - const fl::lib::text::Dictionary& wordDict, - const DecodeMasterTrainOptions& trainOpt); - - // compute emissions - virtual std::shared_ptr forward( - const std::shared_ptr& ds); - - // decode emissions with an existing decoder - std::shared_ptr decode( - const std::shared_ptr& eds, - fl::lib::text::Decoder& decoder); - - // returns token edit distance and word edit distance stats - std::pair, std::vector> computeMetrics( - const std::shared_ptr& pds); - - // convert tokens indices predictions 
into tokens string - virtual std::vector computeStringPred( - const std::vector& tokenIdxSeq) = 0; - - // convert tokens indices predictions into tokens string - virtual std::vector computeStringTarget( - const std::vector& tokenIdxSeq) = 0; - - virtual ~DecodeMaster() = default; - - protected: - std::shared_ptr buildTrie( - const fl::lib::text::LexiconMap& lexicon, - fl::lib::text::SmearingMode smearMode) const; - - std::shared_ptr net_; - std::shared_ptr lm_; - bool isTokenLM_; - bool usePlugin_; - fl::lib::text::Dictionary tokenDict_; - fl::lib::text::Dictionary wordDict_; - DecodeMasterTrainOptions trainOpt_; -}; + namespace speech { + + struct DecodeMasterLexiconFreeOptions { + int beamSize; + int beamSizeToken; + double beamThreshold; + double lmWeight; + double silScore; + bool logAdd; + std::string silToken; + std::string blankToken; + }; + + struct DecodeMasterLexiconOptions { + int beamSize; + int beamSizeToken; + double beamThreshold; + double lmWeight; + double silScore; + double wordScore; + double unkScore; + bool logAdd; + std::string silToken; + std::string blankToken; + std::string unkToken; + fl::lib::text::SmearingMode smearMode; + }; + + struct DecodeMasterTrainOptions { + int repLabel; + bool wordSepIsPartOfToken; + std::string surround; + std::string wordSep; + int32_t targetPadIdx; + }; + + class DecodeMaster { + public: + explicit DecodeMaster( + const std::shared_ptr net, + const std::shared_ptr lm, + const bool isTokenLM, + const bool usePlugin, + const fl::lib::text::Dictionary& tokenDict, + const fl::lib::text::Dictionary& wordDict, + const DecodeMasterTrainOptions& trainOpt + ); + + // compute emissions + virtual std::shared_ptr forward( + const std::shared_ptr& ds + ); + + // decode emissions with an existing decoder + std::shared_ptr decode( + const std::shared_ptr& eds, + fl::lib::text::Decoder& decoder + ); + + // returns token edit distance and word edit distance stats + std::pair, std::vector> computeMetrics( + const 
std::shared_ptr& pds + ); + + // convert tokens indices predictions into tokens string + virtual std::vector computeStringPred( + const std::vector& tokenIdxSeq + ) = 0; + + // convert tokens indices predictions into tokens string + virtual std::vector computeStringTarget( + const std::vector& tokenIdxSeq + ) = 0; + + virtual ~DecodeMaster() = default; + + protected: + std::shared_ptr buildTrie( + const fl::lib::text::LexiconMap& lexicon, + fl::lib::text::SmearingMode smearMode + ) const; + + std::shared_ptr net_; + std::shared_ptr lm_; + bool isTokenLM_; + bool usePlugin_; + fl::lib::text::Dictionary tokenDict_; + fl::lib::text::Dictionary wordDict_; + DecodeMasterTrainOptions trainOpt_; + }; // token-based CTC/ASG decoder (lexicon or lexicon-free) -class TokenDecodeMaster : public DecodeMaster { - public: - explicit TokenDecodeMaster( - const std::shared_ptr net, - const std::shared_ptr lm, - const std::vector& transition, - const bool usePlugin, - const fl::lib::text::Dictionary& tokenDict, - const fl::lib::text::Dictionary& wordDict, - const DecodeMasterTrainOptions& trainOpt); - - // compute predictions from emissions for lexicon free case - std::shared_ptr decode( - const std::shared_ptr& eds, - DecodeMasterLexiconFreeOptions opt); - - // compute predictions from emissions for lexicon case - std::shared_ptr decode( - const std::shared_ptr& eds, - const fl::lib::text::LexiconMap& lexicon, - DecodeMasterLexiconOptions opt); - - // convert tokens indices predictions into tokens string - virtual std::vector computeStringPred( - const std::vector& tokenIdxSeq) override; - - // convert tokens indices predictions into tokens string - virtual std::vector computeStringTarget( - const std::vector& tokenIdxSeq) override; - - private: - std::vector transition_; -}; + class TokenDecodeMaster : public DecodeMaster { + public: + explicit TokenDecodeMaster( + const std::shared_ptr net, + const std::shared_ptr lm, + const std::vector& transition, + const bool usePlugin, + 
const fl::lib::text::Dictionary& tokenDict, + const fl::lib::text::Dictionary& wordDict, + const DecodeMasterTrainOptions& trainOpt + ); + + // compute predictions from emissions for lexicon free case + std::shared_ptr decode( + const std::shared_ptr& eds, + DecodeMasterLexiconFreeOptions opt + ); + + // compute predictions from emissions for lexicon case + std::shared_ptr decode( + const std::shared_ptr& eds, + const fl::lib::text::LexiconMap& lexicon, + DecodeMasterLexiconOptions opt + ); + + // convert tokens indices predictions into tokens string + virtual std::vector computeStringPred( + const std::vector& tokenIdxSeq + ) override; + + // convert tokens indices predictions into tokens string + virtual std::vector computeStringTarget( + const std::vector& tokenIdxSeq + ) override; + + private: + std::vector transition_; + }; // word-based CTC/ASG decoder (lexicon or lexicon-free) -class WordDecodeMaster : public DecodeMaster { - public: - explicit WordDecodeMaster( - const std::shared_ptr net, - const std::shared_ptr lm, - const std::vector& transition, - const bool usePlugin, - const fl::lib::text::Dictionary& tokenDict, - const fl::lib::text::Dictionary& wordDict, - const DecodeMasterTrainOptions& trainOpt); - - // compute predictions from emissions - std::shared_ptr decode( - const std::shared_ptr& eds, - const fl::lib::text::LexiconMap& lexicon, - DecodeMasterLexiconOptions opt); - - // convert tokens indices predictions into tokens string - virtual std::vector computeStringPred( - const std::vector& tokenIdxSeq) override; - - // convert tokens indices predictions into tokens string - virtual std::vector computeStringTarget( - const std::vector& tokenIdxSeq) override; - - private: - std::vector transition_; -}; -} // namespace speech + class WordDecodeMaster : public DecodeMaster { + public: + explicit WordDecodeMaster( + const std::shared_ptr net, + const std::shared_ptr lm, + const std::vector& transition, + const bool usePlugin, + const 
fl::lib::text::Dictionary& tokenDict, + const fl::lib::text::Dictionary& wordDict, + const DecodeMasterTrainOptions& trainOpt + ); + + // compute predictions from emissions + std::shared_ptr decode( + const std::shared_ptr& eds, + const fl::lib::text::LexiconMap& lexicon, + DecodeMasterLexiconOptions opt + ); + + // convert tokens indices predictions into tokens string + virtual std::vector computeStringPred( + const std::vector& tokenIdxSeq + ) override; + + // convert tokens indices predictions into tokens string + virtual std::vector computeStringTarget( + const std::vector& tokenIdxSeq + ) override; + + private: + std::vector transition_; + }; + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/decoder/DecodeUtils.cpp b/flashlight/pkg/speech/decoder/DecodeUtils.cpp index 0f5cf5f..37719d5 100644 --- a/flashlight/pkg/speech/decoder/DecodeUtils.cpp +++ b/flashlight/pkg/speech/decoder/DecodeUtils.cpp @@ -20,40 +20,44 @@ std::shared_ptr buildTrie( const fl::lib::text::LexiconMap& lexicon, const fl::lib::text::Dictionary& wordDict, const int wordSeparatorIdx, - const int repLabel) { - if (!(decoderType == "wrd" || useLexicon)) { - return nullptr; - } - auto trie = std::make_shared( - tokenDict.indexSize(), wordSeparatorIdx); - auto startState = lm->start(false); + const int repLabel +) { + if(!(decoderType == "wrd" || useLexicon)) { + return nullptr; + } + auto trie = std::make_shared( + tokenDict.indexSize(), + wordSeparatorIdx + ); + auto startState = lm->start(false); - for (auto& it : lexicon) { - const std::string& word = it.first; - int usrIdx = wordDict.getIndex(word); - float score = -1; - if (decoderType == "wrd") { - fl::lib::text::LMStatePtr dummyState; - std::tie(dummyState, score) = lm->score(startState, usrIdx); + for(auto& it : lexicon) { + const std::string& word = it.first; + int usrIdx = wordDict.getIndex(word); + float score = -1; + if(decoderType == "wrd") { + fl::lib::text::LMStatePtr dummyState; + 
std::tie(dummyState, score) = lm->score(startState, usrIdx); + } + for(auto& tokens : it.second) { + auto tokensTensor = tkn2Idx(tokens, tokenDict, repLabel); + trie->insert(tokensTensor, usrIdx, score); + } } - for (auto& tokens : it.second) { - auto tokensTensor = tkn2Idx(tokens, tokenDict, repLabel); - trie->insert(tokensTensor, usrIdx, score); + // Smearing + SmearingMode smearMode = SmearingMode::NONE; + if(smearing == "logadd") { + smearMode = SmearingMode::LOGADD; + } else if(smearing == "max") { + smearMode = SmearingMode::MAX; + } else if(smearing != "none") { + throw std::runtime_error( + "[buildTrie] Invalid smearing option, can be {logadd, max, none}, provided value is " + + smearing + ); } - } - // Smearing - SmearingMode smearMode = SmearingMode::NONE; - if (smearing == "logadd") { - smearMode = SmearingMode::LOGADD; - } else if (smearing == "max") { - smearMode = SmearingMode::MAX; - } else if (smearing != "none") { - throw std::runtime_error( - "[buildTrie] Invalid smearing option, can be {logadd, max, none}, provided value is " + - smearing); - } - trie->smear(smearMode); - return trie; + trie->smear(smearMode); + return trie; } } // namespace fl diff --git a/flashlight/pkg/speech/decoder/DecodeUtils.h b/flashlight/pkg/speech/decoder/DecodeUtils.h index c11253a..ca8bf9e 100644 --- a/flashlight/pkg/speech/decoder/DecodeUtils.h +++ b/flashlight/pkg/speech/decoder/DecodeUtils.h @@ -22,21 +22,22 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { /* A series of vector to vector mapping operations */ -std::shared_ptr buildTrie( - const std::string& decoderType, - bool useLexicon, - std::shared_ptr lm, - const std::string& smearing, - const fl::lib::text::Dictionary& tokenDict, - const fl::lib::text::LexiconMap& lexicon, - const fl::lib::text::Dictionary& wordDict, - const int wordSeparatorIdx, - const int repLabel); - -} // namespace speech + std::shared_ptr buildTrie( + const std::string& decoderType, + bool useLexicon, + 
std::shared_ptr lm, + const std::string& smearing, + const fl::lib::text::Dictionary& tokenDict, + const fl::lib::text::LexiconMap& lexicon, + const fl::lib::text::Dictionary& wordDict, + const int wordSeparatorIdx, + const int repLabel + ); + + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/decoder/Defines.h b/flashlight/pkg/speech/decoder/Defines.h index abd12aa..a99d312 100644 --- a/flashlight/pkg/speech/decoder/Defines.h +++ b/flashlight/pkg/speech/decoder/Defines.h @@ -13,38 +13,39 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { // Convenience structs for serializing emissions and targets -struct EmissionUnit { - std::vector emission; // A column-major tensor with shape T x N. - std::string sampleId; - int nFrames; - int nTokens; - - FL_SAVE_LOAD(emission, sampleId, nFrames, nTokens) - - EmissionUnit() : nFrames(0), nTokens(0) {} - - EmissionUnit( - const std::vector& emission, - const std::string& sampleId, - int nFrames, - int nTokens) - : emission(emission), - sampleId(sampleId), - nFrames(nFrames), - nTokens(nTokens) {} -}; - -struct TargetUnit { - std::vector wordTargetStr; // Word targets in strings - std::vector tokenTarget; // Token targets in indices - - FL_SAVE_LOAD(wordTargetStr, tokenTarget) -}; - -using EmissionTargetPair = std::pair; -} // namespace speech + struct EmissionUnit { + std::vector emission; // A column-major tensor with shape T x N. 
+ std::string sampleId; + int nFrames; + int nTokens; + + FL_SAVE_LOAD(emission, sampleId, nFrames, nTokens) + + EmissionUnit() : nFrames(0), nTokens(0) {} + + EmissionUnit( + const std::vector& emission, + const std::string& sampleId, + int nFrames, + int nTokens + ) + : emission(emission), + sampleId(sampleId), + nFrames(nFrames), + nTokens(nTokens) {} + }; + + struct TargetUnit { + std::vector wordTargetStr; // Word targets in strings + std::vector tokenTarget; // Token targets in indices + + FL_SAVE_LOAD(wordTargetStr, tokenTarget) + }; + + using EmissionTargetPair = std::pair; + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/decoder/PlGenerator.cpp b/flashlight/pkg/speech/decoder/PlGenerator.cpp index 4ba8db8..76bf39f 100644 --- a/flashlight/pkg/speech/decoder/PlGenerator.cpp +++ b/flashlight/pkg/speech/decoder/PlGenerator.cpp @@ -47,197 +47,209 @@ PlGenerator::PlGenerator( fl::Dataset::DataTransformFunction inputTransform, fl::Dataset::DataTransformFunction targetTransform, fl::Dataset::DataTransformFunction wordTransform, - TokenToWordFunc tokenToWord) - : worldRank_(worldRank), - isMaster_(worldRank_ == 0), - worldSize_(worldSize), - batchSize_(batchSize), - tokenDict_(tokenDict), - plDir_(runPath / kPlDir), - useExistingPl_(useExistingPl), - seedModelWER_(seedModelWER), - minInputSize_(minInputSize), - maxInputSize_(maxInputSize), - minTargetSize_(minTargetSize), - maxTargetSize_(maxTargetSize), - padVal_(padVal), - inputTransform_(inputTransform), - targetTransform_(targetTransform), - wordTransform_(wordTransform), - tokenToWord_(tokenToWord) { - // 1. 
Load PL generating intervals - auto plEpochVec = lib::split(',', plEpoch, true); - auto plRatioVec = lib::split(',', plRatio, true); - - if (plEpochVec.size() != plRatioVec.size()) { - throw std::invalid_argument( - "[PlGenerator] Size mismatch between pl_epoch and pl_ratio."); - } - - plEpochs_.resize(plEpochVec.size()); - for (int i = 0; i < plEpochVec.size(); i++) { - plEpochs_[i] = stoi(plEpochVec[i]); - } - - for (int i = 0; i < plEpochVec.size(); i++) { - auto ratio = stof(plRatioVec[i]); - if (ratio < 0 || ratio > 1) { - throw std::invalid_argument( - "[PlGenerator] The value of pl_ratio should be in [0, 1]."); + TokenToWordFunc tokenToWord +) : worldRank_(worldRank), + isMaster_(worldRank_ == 0), + worldSize_(worldSize), + batchSize_(batchSize), + tokenDict_(tokenDict), + plDir_(runPath / kPlDir), + useExistingPl_(useExistingPl), + seedModelWER_(seedModelWER), + minInputSize_(minInputSize), + maxInputSize_(maxInputSize), + minTargetSize_(minTargetSize), + maxTargetSize_(maxTargetSize), + padVal_(padVal), + inputTransform_(inputTransform), + targetTransform_(targetTransform), + wordTransform_(wordTransform), + tokenToWord_(tokenToWord) { + // 1. Load PL generating intervals + auto plEpochVec = lib::split(',', plEpoch, true); + auto plRatioVec = lib::split(',', plRatio, true); + + if(plEpochVec.size() != plRatioVec.size()) { + throw std::invalid_argument( + "[PlGenerator] Size mismatch between pl_epoch and pl_ratio." + ); } - if (i > 0 && plEpochs_[i] <= plEpochs_[i - 1]) { - throw std::invalid_argument( - "[PlGenerator] Elements in pl_epoch should be in ascendant order."); + + plEpochs_.resize(plEpochVec.size()); + for(int i = 0; i < plEpochVec.size(); i++) { + plEpochs_[i] = stoi(plEpochVec[i]); } - plUpdateMap_[plEpochs_[i]] = ratio; - } - - // 2. 
Build the full unlabeled set - std::vector> allListDs; - auto paths = lib::split(',', trainUnsupLists, true); - for (auto& path : paths) { - auto curListDs = std::make_shared( - trainUnsupDir / path, - inputTransform_, - targetTransform_, - wordTransform_); - allListDs.emplace_back(curListDs); - } - if (!allListDs.empty()) { - if (isMaster_) { - fs::create_directory(plDir_); + for(int i = 0; i < plEpochVec.size(); i++) { + auto ratio = stof(plRatioVec[i]); + if(ratio < 0 || ratio > 1) { + throw std::invalid_argument( + "[PlGenerator] The value of pl_ratio should be in [0, 1]." + ); + } + if(i > 0 && plEpochs_[i] <= plEpochs_[i - 1]) { + throw std::invalid_argument( + "[PlGenerator] Elements in pl_epoch should be in ascendant order." + ); + } + plUpdateMap_[plEpochs_[i]] = ratio; + } + + // 2. Build the full unlabeled set + std::vector> allListDs; + auto paths = lib::split(',', trainUnsupLists, true); + for(auto& path : paths) { + auto curListDs = std::make_shared( + trainUnsupDir / path, + inputTransform_, + targetTransform_, + wordTransform_ + ); + + allListDs.emplace_back(curListDs); + } + if(!allListDs.empty()) { + if(isMaster_) { + fs::create_directory(plDir_); + } + fullUnsupDs_ = std::make_shared(allListDs); } - fullUnsupDs_ = std::make_shared(allListDs); - } } std::string PlGenerator::reloadPl(int curEpoch) const { - int lastPlEpoch = findLastPlEpoch(curEpoch); - if (lastPlEpoch < 0) { - return ""; - } - - fs::path plDir = plDir_ / (kPlSubdirPrefix + std::to_string(lastPlEpoch)); - - bool isPLReady = true; - for (int i = 0; i < worldSize_; i++) { - auto listFinishPath = plDir / (std::to_string(i) + ".fns"); - if (!fs::exists(listFinishPath)) { - isPLReady = false; - break; + int lastPlEpoch = findLastPlEpoch(curEpoch); + if(lastPlEpoch < 0) { + return ""; + } + + fs::path plDir = plDir_ / (kPlSubdirPrefix + std::to_string(lastPlEpoch)); + + bool isPLReady = true; + for(int i = 0; i < worldSize_; i++) { + auto listFinishPath = plDir / (std::to_string(i) + 
".fns"); + if(!fs::exists(listFinishPath)) { + isPLReady = false; + break; + } + } + if(isPLReady) { + logMaster("[PlGenerator] Loading existing PL from " + plDir.string()); + return plDir; + } else { + logMaster("[PlGenerator] Failed to load PL from " + plDir.string()); + return ""; } - } - if (isPLReady) { - logMaster("[PlGenerator] Loading existing PL from " + plDir.string()); - return plDir; - } else { - logMaster("[PlGenerator] Failed to load PL from " + plDir.string()); - return ""; - } } std::string PlGenerator::regeneratePl( int curEpoch, const std::shared_ptr& ntwrk, const std::shared_ptr criterion, - const bool usePlugin /* = false */) const { - if (plUpdateMap_.find(curEpoch) == plUpdateMap_.end()) { - return ""; - } - if (!fullUnsupDs_) { - throw std::runtime_error("No unlabeled data is provided"); - } - - logMaster( - "[PlGenerator] Regenerating PL at epoch " + std::to_string(curEpoch)); - fs::path plDir = plDir_ / (kPlSubdirPrefix + std::to_string(curEpoch)); - - /* 0. Create logging folder */ - try { - fs::create_directory(plDir); - } catch (...) { - // Pass. Allowing attempts from all processes to create the folder. - } - - if (!fs::is_directory(plDir)) { - throw std::runtime_error( - "[PlGenerator] Failed to create " + plDir.string()); - } - - /* 1. select data */ - // shuffle - auto ds1 = std::make_shared(fullUnsupDs_, curEpoch); - - // select - float ratio = plUpdateMap_.at(curEpoch); - int nSelectedSamples = int(fullUnsupDs_->size() * ratio); - std::vector sortedIds(nSelectedSamples); - std::iota(sortedIds.begin(), sortedIds.end(), 0); - auto ds2 = std::make_shared(ds1, sortedIds); - - // dispatch - auto partitions = - fl::partitionByRoundRobin(ds2->size(), worldRank_, worldSize_, 1); - auto ds3 = std::make_shared(ds2, partitions); - - // prefetch - auto selectedDs = std::make_shared(ds3, 3, 3); - - logMaster( - "[PlGenerator] " + std::to_string(nSelectedSamples) + "/" + - std::to_string(fullUnsupDs_->size()) + " samples selected"); - - /* 2. 
pseudo label generation */ - ntwrk->eval(); - auto newPlFile = plDir / (std::to_string(worldRank_) + ".lst"); - std::ofstream plStream(newPlFile); - for (auto& sample : *selectedDs) { - auto duration = sample[kDurationIdx].scalar(); - if (duration < minInputSize_ || duration > maxInputSize_) { - continue; + const bool usePlugin /* = false */ +) const { + if(plUpdateMap_.find(curEpoch) == plUpdateMap_.end()) { + return ""; + } + if(!fullUnsupDs_) { + throw std::runtime_error("No unlabeled data is provided"); } - std::vector words; - if (useExistingPl_ && seedModelWER_ < currentModelWER_) { - auto tokenTarget = sample[kTargetIdx].toHostVector(); - words = tokenToWord_(tokenTarget, tokenDict_, false); - } else { - fl::Variable rawEmission; - if (usePlugin) { - rawEmission = ntwrk - ->forward( - {fl::input(sample[kInputIdx]), - fl::noGrad(sample[kDurationIdx])}) - .front(); - } else { - rawEmission = fl::pkg::runtime::forwardSequentialModuleWithPadMask( - fl::input(sample[kInputIdx]), ntwrk, sample[kDurationIdx]); - } - auto tokenPrediction = - criterion->viterbiPath(rawEmission.tensor()).toHostVector(); - words = tokenToWord_(tokenPrediction, tokenDict_, true); + logMaster( + "[PlGenerator] Regenerating PL at epoch " + std::to_string(curEpoch) + ); + fs::path plDir = plDir_ / (kPlSubdirPrefix + std::to_string(curEpoch)); + + /* 0. Create logging folder */ + try { + fs::create_directory(plDir); + } catch(...) { + // Pass. Allowing attempts from all processes to create the folder. 
} - if (words.size() < minTargetSize_ || words.size() > maxTargetSize_) { - continue; + + if(!fs::is_directory(plDir)) { + throw std::runtime_error( + "[PlGenerator] Failed to create " + plDir.string() + ); } - auto sampleId = readSampleIds(sample[kSampleIdx]).front(); - auto inputPath = readSampleIds(sample[kPathIdx]).front(); - plStream << sampleId << "\t" << inputPath << "\t" - << std::to_string(duration) << "\t" << lib::join(" ", words) - << std::endl; - } - plStream.close(); - - fs::path finishPlFile = plDir / (std::to_string(worldRank_) + ".fns"); - std::ofstream fnsStream(finishPlFile); - fnsStream << "done"; - fnsStream.close(); - - /* 3. waiting for all the other processes */ - fl::barrier(); - return plDir; + /* 1. select data */ + // shuffle + auto ds1 = std::make_shared(fullUnsupDs_, curEpoch); + + // select + float ratio = plUpdateMap_.at(curEpoch); + int nSelectedSamples = int(fullUnsupDs_->size() * ratio); + std::vector sortedIds(nSelectedSamples); + std::iota(sortedIds.begin(), sortedIds.end(), 0); + auto ds2 = std::make_shared(ds1, sortedIds); + + // dispatch + auto partitions = + fl::partitionByRoundRobin(ds2->size(), worldRank_, worldSize_, 1); + auto ds3 = std::make_shared(ds2, partitions); + + // prefetch + auto selectedDs = std::make_shared(ds3, 3, 3); + + logMaster( + "[PlGenerator] " + std::to_string(nSelectedSamples) + "/" + + std::to_string(fullUnsupDs_->size()) + " samples selected" + ); + + /* 2. 
pseudo label generation */ + ntwrk->eval(); + auto newPlFile = plDir / (std::to_string(worldRank_) + ".lst"); + std::ofstream plStream(newPlFile); + for(auto& sample : *selectedDs) { + auto duration = sample[kDurationIdx].scalar(); + if(duration < minInputSize_ || duration > maxInputSize_) { + continue; + } + + std::vector words; + if(useExistingPl_ && seedModelWER_ < currentModelWER_) { + auto tokenTarget = sample[kTargetIdx].toHostVector(); + words = tokenToWord_(tokenTarget, tokenDict_, false); + } else { + fl::Variable rawEmission; + if(usePlugin) { + rawEmission = ntwrk + ->forward( + {fl::input(sample[kInputIdx]), + fl::noGrad(sample[kDurationIdx])} + ) + .front(); + } else { + rawEmission = fl::pkg::runtime::forwardSequentialModuleWithPadMask( + fl::input(sample[kInputIdx]), + ntwrk, + sample[kDurationIdx] + ); + } + auto tokenPrediction = + criterion->viterbiPath(rawEmission.tensor()).toHostVector(); + words = tokenToWord_(tokenPrediction, tokenDict_, true); + } + if(words.size() < minTargetSize_ || words.size() > maxTargetSize_) { + continue; + } + + auto sampleId = readSampleIds(sample[kSampleIdx]).front(); + auto inputPath = readSampleIds(sample[kPathIdx]).front(); + plStream << sampleId << "\t" << inputPath << "\t" + << std::to_string(duration) << "\t" << lib::join(" ", words) + << std::endl; + } + plStream.close(); + + fs::path finishPlFile = plDir / (std::to_string(worldRank_) + ".fns"); + std::ofstream fnsStream(finishPlFile); + fnsStream << "done"; + fnsStream.close(); + + /* 3. 
waiting for all the other processes */ + fl::barrier(); + return plDir; } std::shared_ptr PlGenerator::createTrainSet( @@ -245,50 +257,52 @@ std::shared_ptr PlGenerator::createTrainSet( const fs::path& trainLists, const fs::path& trainUnsupDir, const std::string& batchingStrategy /* = kBatchStrategyNone */, - int maxDurationPerBatch /* = 0 */) const { - std::vector files; - for (const auto& file : lib::split(",", trainLists, true)) { - files.emplace_back(trainDir / file); - } - for (int i = 0; i < worldSize_; i++) { - files.emplace_back(trainUnsupDir / (std::to_string(i) + ".lst")); - } - - return createDataset( - files, - "", - batchSize_, - inputTransform_, - targetTransform_, - wordTransform_, - padVal_, - worldRank_, - worldSize_, - false, // allowEmpty - batchingStrategy, - maxDurationPerBatch); + int maxDurationPerBatch /* = 0 */ +) const { + std::vector files; + for(const auto& file : lib::split(",", trainLists, true)) { + files.emplace_back(trainDir / file); + } + for(int i = 0; i < worldSize_; i++) { + files.emplace_back(trainUnsupDir / (std::to_string(i) + ".lst")); + } + + return createDataset( + files, + "", + batchSize_, + inputTransform_, + targetTransform_, + wordTransform_, + padVal_, + worldRank_, + worldSize_, + false, // allowEmpty + batchingStrategy, + maxDurationPerBatch + ); } void PlGenerator::setModelWER(const float& wer) { - currentModelWER_ = wer; + currentModelWER_ = wer; } int PlGenerator::findLastPlEpoch(int curEpoch) const { - int lastPlEpoch = -1; - for (const auto& i : plEpochs_) { - if (i > curEpoch) { - break; + int lastPlEpoch = -1; + for(const auto& i : plEpochs_) { + if(i > curEpoch) { + break; + } + lastPlEpoch = i; } - lastPlEpoch = i; - } - return lastPlEpoch; + return lastPlEpoch; } void PlGenerator::logMaster(const std::string& message) const { - if (worldRank_ != 0) { - return; - } - std::cerr << message << std::endl; + if(worldRank_ != 0) { + return; + } + std::cerr << message << std::endl; } } // namespace fl diff --git 
a/flashlight/pkg/speech/decoder/PlGenerator.h b/flashlight/pkg/speech/decoder/PlGenerator.h index 31ed5e6..f54c3aa 100644 --- a/flashlight/pkg/speech/decoder/PlGenerator.h +++ b/flashlight/pkg/speech/decoder/PlGenerator.h @@ -18,10 +18,10 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { -using TokenToWordFunc = std::function(const std::vector&, const lib::text::Dictionary&, bool)>; + using TokenToWordFunc = std::function(const std::vector&, const lib::text::Dictionary&, bool)>; /** * PlGenerator is an easy plug-in to Train.cpp for generating pseudo labels on @@ -56,93 +56,96 @@ using TokenToWordFunc = std::function& padVal, - fl::Dataset::DataTransformFunction inputTransform, - fl::Dataset::DataTransformFunction targetTransform, - fl::Dataset::DataTransformFunction wordTransform, - TokenToWordFunc tokenToWord); - - /* - * To resume trainig, try to load existing pseudo labels. - * `nullptr` is returned if loading fails. - */ - std::string reloadPl(int curEpoch) const; - - /* - * To regenerate pseudo labels with the current model. - * `nullptr` is returned if it's not supposed to do relabeling at the current - * epoch. - */ - std::string regeneratePl( - int curEpoch, - const std::shared_ptr& ntwrk, - const std::shared_ptr criterion, - const bool usePlugin = false) const; - - /* - * This function will create a mixture of supervised data and unalabeled data - * with pseudo labels. 
- */ - std::shared_ptr createTrainSet( - const fs::path& trainDir, - const fs::path& trainLists, - const fs::path& trainUnsupDir, - const std::string& batchingStrategy = kBatchStrategyNone, - int maxDurationPerBatch = 0) const; - - /* To set the WER of current model in PlGenerator */ - void setModelWER(const float& wer); - - private: - int worldRank_; - bool isMaster_; - int worldSize_; - int batchSize_; - - lib::text::Dictionary tokenDict_; - fs::path plDir_; - - bool useExistingPl_; - double seedModelWER_; - double currentModelWER_; - - float minInputSize_; - float maxInputSize_; - int minTargetSize_; - int maxTargetSize_; - - std::tuple padVal_; - fl::Dataset::DataTransformFunction inputTransform_; - fl::Dataset::DataTransformFunction targetTransform_; - fl::Dataset::DataTransformFunction wordTransform_; - TokenToWordFunc tokenToWord_; - - std::shared_ptr fullUnsupDs_; - std::vector plEpochs_; - std::unordered_map plUpdateMap_; - - int findLastPlEpoch(int curEpoch) const; - void logMaster(const std::string& message) const; -}; - -} // namespace speech + class PlGenerator { + public: + PlGenerator( + const lib::text::Dictionary& tokenDict, + const fs::path& runPath, + int worldRank, + int worldSize, + int batchSize, + const fs::path& trainUnsupDir, + const std::string& trainUnsupLists, + const std::string& plEpoch, + const std::string& plRatio, + bool useExistingPl, + float seedModelWER, + double minInputSize, // in milliseconds + double maxInputSize, // in milliseconds + int minTargetSize, // in words + int maxTargetSize, // in words + const std::tuple& padVal, + fl::Dataset::DataTransformFunction inputTransform, + fl::Dataset::DataTransformFunction targetTransform, + fl::Dataset::DataTransformFunction wordTransform, + TokenToWordFunc tokenToWord + ); + + /* + * To resume trainig, try to load existing pseudo labels. + * `nullptr` is returned if loading fails. 
+ */ + std::string reloadPl(int curEpoch) const; + + /* + * To regenerate pseudo labels with the current model. + * `nullptr` is returned if it's not supposed to do relabeling at the current + * epoch. + */ + std::string regeneratePl( + int curEpoch, + const std::shared_ptr& ntwrk, + const std::shared_ptr criterion, + const bool usePlugin = false + ) const; + + /* + * This function will create a mixture of supervised data and unalabeled data + * with pseudo labels. + */ + std::shared_ptr createTrainSet( + const fs::path& trainDir, + const fs::path& trainLists, + const fs::path& trainUnsupDir, + const std::string& batchingStrategy = kBatchStrategyNone, + int maxDurationPerBatch = 0 + ) const; + + /* To set the WER of current model in PlGenerator */ + void setModelWER(const float& wer); + + private: + int worldRank_; + bool isMaster_; + int worldSize_; + int batchSize_; + + lib::text::Dictionary tokenDict_; + fs::path plDir_; + + bool useExistingPl_; + double seedModelWER_; + double currentModelWER_; + + float minInputSize_; + float maxInputSize_; + int minTargetSize_; + int maxTargetSize_; + + std::tuple padVal_; + fl::Dataset::DataTransformFunction inputTransform_; + fl::Dataset::DataTransformFunction targetTransform_; + fl::Dataset::DataTransformFunction wordTransform_; + TokenToWordFunc tokenToWord_; + + std::shared_ptr fullUnsupDs_; + std::vector plEpochs_; + std::unordered_map plUpdateMap_; + + int findLastPlEpoch(int curEpoch) const; + void logMaster(const std::string& message) const; + }; + + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/decoder/TranscriptionUtils.cpp b/flashlight/pkg/speech/decoder/TranscriptionUtils.cpp index 0c46c5c..90ff9f1 100644 --- a/flashlight/pkg/speech/decoder/TranscriptionUtils.cpp +++ b/flashlight/pkg/speech/decoder/TranscriptionUtils.cpp @@ -19,62 +19,65 @@ std::vector tknIdx2Ltr( const std::vector& labels, const Dictionary& d, const bool useWordPiece, - const std::string& wordSep) 
{ - std::vector result; - - for (auto id : labels) { - auto token = d.getEntry(id); - if (useWordPiece) { - auto splitToken = splitWrd(token); - for (const auto& c : splitToken) { - result.emplace_back(c); - } - } else { - result.emplace_back(token); + const std::string& wordSep +) { + std::vector result; + + for(auto id : labels) { + auto token = d.getEntry(id); + if(useWordPiece) { + auto splitToken = splitWrd(token); + for(const auto& c : splitToken) { + result.emplace_back(c); + } + } else { + result.emplace_back(token); + } } - } - if (!result.empty() && !wordSep.empty()) { - if (result.front() == wordSep) { - result.erase(result.begin()); + if(!result.empty() && !wordSep.empty()) { + if(result.front() == wordSep) { + result.erase(result.begin()); + } + if(!result.empty() && result.back() == wordSep) { + result.pop_back(); + } } - if (!result.empty() && result.back() == wordSep) { - result.pop_back(); - } - } - return result; + return result; } std::vector tkn2Wrd( const std::vector& input, - const std::string& wordSep) { - std::vector words; - std::string currentWord; - for (auto& tkn : input) { - if (tkn == wordSep) { - if (!currentWord.empty()) { + const std::string& wordSep +) { + std::vector words; + std::string currentWord; + for(auto& tkn : input) { + if(tkn == wordSep) { + if(!currentWord.empty()) { + words.push_back(currentWord); + currentWord = ""; + } + } else { + currentWord += tkn; + } + } + if(!currentWord.empty()) { words.push_back(currentWord); - currentWord = ""; - } - } else { - currentWord += tkn; } - } - if (!currentWord.empty()) { - words.push_back(currentWord); - } - return words; + return words; } std::vector wrdIdx2Wrd( const std::vector& input, - const Dictionary& wordDict) { - std::vector words; - for (auto wrdIdx : input) { - words.push_back(wordDict.getEntry(wrdIdx)); - } - return words; + const Dictionary& wordDict +) { + std::vector words; + for(auto wrdIdx : input) { + words.push_back(wordDict.getEntry(wrdIdx)); + } + return 
words; } std::vector tknTarget2Ltr( @@ -85,19 +88,20 @@ std::vector tknTarget2Ltr( const bool isSeq2seqCrit, const int replabel, const bool useWordPiece, - const std::string& wordSep) { - if (tokens.empty()) { - return std::vector{}; - } - - if (isSeq2seqCrit) { - if (tokens.back() == tokenDict.getIndex(kEosToken)) { - tokens.pop_back(); + const std::string& wordSep +) { + if(tokens.empty()) { + return std::vector{}; + } + + if(isSeq2seqCrit) { + if(tokens.back() == tokenDict.getIndex(kEosToken)) { + tokens.pop_back(); + } } - } - remapLabels(tokens, tokenDict, surround, isSeq2seqCrit, replabel); + remapLabels(tokens, tokenDict, surround, isSeq2seqCrit, replabel); - return tknIdx2Ltr(tokens, tokenDict, useWordPiece, wordSep); + return tknIdx2Ltr(tokens, tokenDict, useWordPiece, wordSep); } std::vector tknPrediction2Ltr( @@ -108,35 +112,38 @@ std::vector tknPrediction2Ltr( const bool isSeq2seqCrit, const int replabel, const bool useWordPiece, - const std::string& wordSep) { - if (tokens.empty()) { - return std::vector{}; - } - - if (criterion == kCtcCriterion || criterion == kAsgCriterion) { - dedup(tokens); - } - if (criterion == kCtcCriterion) { - int blankIdx = tokenDict.getIndex(kBlankToken); - tokens.erase( - std::remove(tokens.begin(), tokens.end(), blankIdx), tokens.end()); - } - tokens = validateIdx(tokens, -1); - remapLabels(tokens, tokenDict, surround, isSeq2seqCrit, replabel); - - return tknIdx2Ltr(tokens, tokenDict, useWordPiece, wordSep); + const std::string& wordSep +) { + if(tokens.empty()) { + return std::vector{}; + } + + if(criterion == kCtcCriterion || criterion == kAsgCriterion) { + dedup(tokens); + } + if(criterion == kCtcCriterion) { + int blankIdx = tokenDict.getIndex(kBlankToken); + tokens.erase( + std::remove(tokens.begin(), tokens.end(), blankIdx), + tokens.end() + ); + } + tokens = validateIdx(tokens, -1); + remapLabels(tokens, tokenDict, surround, isSeq2seqCrit, replabel); + + return tknIdx2Ltr(tokens, tokenDict, useWordPiece, wordSep); } 
std::vector validateIdx(std::vector input, int unkIdx) { - int newSize = 0; - for (int i = 0; i < input.size(); i++) { - if (input[i] >= 0 && input[i] != unkIdx) { - input[newSize] = input[i]; - newSize++; + int newSize = 0; + for(int i = 0; i < input.size(); i++) { + if(input[i] >= 0 && input[i] != unkIdx) { + input[newSize] = input[i]; + newSize++; + } } - } - input.resize(newSize); + input.resize(newSize); - return input; + return input; } } // namespace fl diff --git a/flashlight/pkg/speech/decoder/TranscriptionUtils.h b/flashlight/pkg/speech/decoder/TranscriptionUtils.h index b6d388d..e8ce931 100644 --- a/flashlight/pkg/speech/decoder/TranscriptionUtils.h +++ b/flashlight/pkg/speech/decoder/TranscriptionUtils.h @@ -22,83 +22,91 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { /* A series of vector to vector mapping operations */ -std::vector tknIdx2Ltr( - const std::vector& labels, - const fl::lib::text::Dictionary& d, - bool useWordPiece, - const std::string& wordSep); + std::vector tknIdx2Ltr( + const std::vector& labels, + const fl::lib::text::Dictionary& d, + bool useWordPiece, + const std::string& wordSep + ); -std::vector tkn2Wrd( - const std::vector& input, - const std::string& wordSep); + std::vector tkn2Wrd( + const std::vector& input, + const std::string& wordSep + ); -std::vector wrdIdx2Wrd( - const std::vector& input, - const fl::lib::text::Dictionary& wordDict); + std::vector wrdIdx2Wrd( + const std::vector& input, + const fl::lib::text::Dictionary& wordDict + ); -std::vector tknTarget2Ltr( - std::vector tokens, - const fl::lib::text::Dictionary& tokenDict, - const std::string& criterion, - const std::string& surround, - const bool isSeq2seqCrit, - const int replabel, - const bool useWordPiece, - const std::string& wordSep); + std::vector tknTarget2Ltr( + std::vector tokens, + const fl::lib::text::Dictionary& tokenDict, + const std::string& criterion, + const std::string& surround, + const bool isSeq2seqCrit, + const int 
replabel, + const bool useWordPiece, + const std::string& wordSep + ); -std::vector tknPrediction2Ltr( - std::vector tokens, - const fl::lib::text::Dictionary& tokenDict, - const std::string& criterion, - const std::string& surround, - const bool isSeq2seqCrit, - const int replabel, - const bool useWordPiece, - const std::string& wordSep); + std::vector tknPrediction2Ltr( + std::vector tokens, + const fl::lib::text::Dictionary& tokenDict, + const std::string& criterion, + const std::string& surround, + const bool isSeq2seqCrit, + const int replabel, + const bool useWordPiece, + const std::string& wordSep + ); -std::vector validateIdx(std::vector input, int unkIdx); + std::vector validateIdx(std::vector input, int unkIdx); -template -void remapLabels( - std::vector& labels, - const fl::lib::text::Dictionary& dict, - const std::string& surround, - const bool isSeq2seqCrit, - const int replabel) { - if (isSeq2seqCrit) { - int eosidx = dict.getIndex(kEosToken); - int padidx = dict.getIndex(fl::lib::text::kPadToken); - while (!labels.empty() && - (labels.back() == eosidx || labels.back() == padidx)) { - labels.pop_back(); - } - } else { - while (!labels.empty() && labels.back() == kTargetPadValue) { - labels.pop_back(); - } - } - if (replabel > 0) { - labels = unpackReplabels(labels, dict, replabel); - } - auto trimLabels = [&labels](int idx) { - if (!labels.empty() && labels.back() == idx) { - labels.pop_back(); - } - if (!labels.empty() && labels.front() == idx) { - labels.erase(labels.begin()); - } - }; - if (dict.contains(kSilToken)) { - trimLabels(dict.getIndex(kSilToken)); - } - if (!surround.empty()) { - trimLabels(dict.getIndex(surround)); - } -}; -} // namespace speech + template + void remapLabels( + std::vector& labels, + const fl::lib::text::Dictionary& dict, + const std::string& surround, + const bool isSeq2seqCrit, + const int replabel + ) { + if(isSeq2seqCrit) { + int eosidx = dict.getIndex(kEosToken); + int padidx = 
dict.getIndex(fl::lib::text::kPadToken); + while( + !labels.empty() + && (labels.back() == eosidx || labels.back() == padidx) + ) { + labels.pop_back(); + } + } else { + while(!labels.empty() && labels.back() == kTargetPadValue) { + labels.pop_back(); + } + } + if(replabel > 0) { + labels = unpackReplabels(labels, dict, replabel); + } + auto trimLabels = [&labels](int idx) { + if(!labels.empty() && labels.back() == idx) { + labels.pop_back(); + } + if(!labels.empty() && labels.front() == idx) { + labels.erase(labels.begin()); + } + }; + if(dict.contains(kSilToken)) { + trimLabels(dict.getIndex(kSilToken)); + } + if(!surround.empty()) { + trimLabels(dict.getIndex(surround)); + } + }; + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/runtime/Attention.cpp b/flashlight/pkg/speech/runtime/Attention.cpp index 5be94de..fe72446 100644 --- a/flashlight/pkg/speech/runtime/Attention.cpp +++ b/flashlight/pkg/speech/runtime/Attention.cpp @@ -10,64 +10,90 @@ namespace fl::pkg::speech { std::shared_ptr createAttention() { - std::shared_ptr attention; - if (FLAGS_attention == fl::pkg::speech::kContentAttention) { - attention = std::make_shared(); - } else if (FLAGS_attention == fl::pkg::speech::kKeyValueAttention) { - attention = std::make_shared(true); - } else if (FLAGS_attention == fl::pkg::speech::kNeuralContentAttention) { - attention = std::make_shared(FLAGS_encoderdim); - } else if (FLAGS_attention == fl::pkg::speech::kSimpleLocationAttention) { - attention = std::make_shared(FLAGS_attnconvkernel); - } else if (FLAGS_attention == fl::pkg::speech::kLocationAttention) { - attention = std::make_shared( - FLAGS_encoderdim, FLAGS_attnconvkernel); - } else if (FLAGS_attention == fl::pkg::speech::kNeuralLocationAttention) { - attention = std::make_shared( - FLAGS_encoderdim, - FLAGS_attndim, - FLAGS_attnconvchannel, - FLAGS_attnconvkernel); - } // is it fine for transformer criterion? 
- else if (FLAGS_attention == fl::pkg::speech::kMultiHeadContentAttention) { - attention = std::make_shared( - FLAGS_encoderdim, FLAGS_numattnhead); - } else if ( - FLAGS_attention == fl::pkg::speech::kMultiHeadKeyValueContentAttention) { - attention = std::make_shared( - FLAGS_encoderdim, FLAGS_numattnhead, true); - } else if (FLAGS_attention == fl::pkg::speech::kMultiHeadSplitContentAttention) { - attention = std::make_shared( - FLAGS_encoderdim, FLAGS_numattnhead, false, true); - } else if ( - FLAGS_attention == - fl::pkg::speech::kMultiHeadKeyValueSplitContentAttention) { - attention = std::make_shared( - FLAGS_encoderdim, FLAGS_numattnhead, true, true); - } else { - throw std::runtime_error("Unimplmented attention: " + FLAGS_attention); - } - return attention; + std::shared_ptr attention; + if(FLAGS_attention == fl::pkg::speech::kContentAttention) { + attention = std::make_shared(); + } else if(FLAGS_attention == fl::pkg::speech::kKeyValueAttention) { + attention = std::make_shared(true); + } else if(FLAGS_attention == fl::pkg::speech::kNeuralContentAttention) { + attention = std::make_shared(FLAGS_encoderdim); + } else if(FLAGS_attention == fl::pkg::speech::kSimpleLocationAttention) { + attention = std::make_shared(FLAGS_attnconvkernel); + } else if(FLAGS_attention == fl::pkg::speech::kLocationAttention) { + attention = std::make_shared( + FLAGS_encoderdim, + FLAGS_attnconvkernel + ); + } else if(FLAGS_attention == fl::pkg::speech::kNeuralLocationAttention) { + attention = std::make_shared( + FLAGS_encoderdim, + FLAGS_attndim, + FLAGS_attnconvchannel, + FLAGS_attnconvkernel + ); + } // is it fine for transformer criterion? 
+ else if(FLAGS_attention == fl::pkg::speech::kMultiHeadContentAttention) { + attention = std::make_shared( + FLAGS_encoderdim, + FLAGS_numattnhead + ); + } else if( + FLAGS_attention == fl::pkg::speech::kMultiHeadKeyValueContentAttention) { + attention = std::make_shared( + FLAGS_encoderdim, + FLAGS_numattnhead, + true + ); + } else if(FLAGS_attention == fl::pkg::speech::kMultiHeadSplitContentAttention) { + attention = std::make_shared( + FLAGS_encoderdim, + FLAGS_numattnhead, + false, + true + ); + } else if( + FLAGS_attention + == fl::pkg::speech::kMultiHeadKeyValueSplitContentAttention + ) { + attention = std::make_shared( + FLAGS_encoderdim, + FLAGS_numattnhead, + true, + true + ); + } else { + throw std::runtime_error("Unimplmented attention: " + FLAGS_attention); + } + return attention; } std::shared_ptr createAttentionWindow() { - std::shared_ptr window; - if (FLAGS_attnWindow == fl::pkg::speech::kNoWindow) { - window = nullptr; - } else if (FLAGS_attnWindow == fl::pkg::speech::kMedianWindow) { - window = std::make_shared( - FLAGS_leftWindowSize, FLAGS_rightWindowSize); - } else if (FLAGS_attnWindow == fl::pkg::speech::kStepWindow) { - window = std::make_shared( - FLAGS_minsil, FLAGS_maxsil, FLAGS_minrate, FLAGS_maxrate); - } else if (FLAGS_attnWindow == fl::pkg::speech::kSoftWindow) { - window = std::make_shared( - FLAGS_softwstd, FLAGS_softwrate, FLAGS_softwoffset); - } else if (FLAGS_attnWindow == fl::pkg::speech::kSoftPretrainWindow) { - window = std::make_shared(FLAGS_softwstd); - } else { - throw std::runtime_error("Unimplmented window: " + FLAGS_attnWindow); - } - return window; + std::shared_ptr window; + if(FLAGS_attnWindow == fl::pkg::speech::kNoWindow) { + window = nullptr; + } else if(FLAGS_attnWindow == fl::pkg::speech::kMedianWindow) { + window = std::make_shared( + FLAGS_leftWindowSize, + FLAGS_rightWindowSize + ); + } else if(FLAGS_attnWindow == fl::pkg::speech::kStepWindow) { + window = std::make_shared( + FLAGS_minsil, + FLAGS_maxsil, + 
FLAGS_minrate, + FLAGS_maxrate + ); + } else if(FLAGS_attnWindow == fl::pkg::speech::kSoftWindow) { + window = std::make_shared( + FLAGS_softwstd, + FLAGS_softwrate, + FLAGS_softwoffset + ); + } else if(FLAGS_attnWindow == fl::pkg::speech::kSoftPretrainWindow) { + window = std::make_shared(FLAGS_softwstd); + } else { + throw std::runtime_error("Unimplmented window: " + FLAGS_attnWindow); + } + return window; } } // namespace fl diff --git a/flashlight/pkg/speech/runtime/Attention.h b/flashlight/pkg/speech/runtime/Attention.h index f4467ed..29476f4 100644 --- a/flashlight/pkg/speech/runtime/Attention.h +++ b/flashlight/pkg/speech/runtime/Attention.h @@ -16,14 +16,14 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { /* * Utility function to create an attention for s2s in encoder-decoder. * From gflags it uses FLAGS_attention, FLAGS_encoderdim, FLAGS_attnconvkernel, * FLAGS_attnconvchannel, FLAGS_attndim, FLAGS_encoderdim, FLAGS_numattnhead */ -std::shared_ptr createAttention(); + std::shared_ptr createAttention(); /* * Utility function to create an force attention (attention window) @@ -32,8 +32,8 @@ std::shared_ptr createAttention(); * FLAGS_leftWindowSize, FLAGS_rightWindowSize FLAGS_softwstd, FLAGS_softwrate, * FLAGS_softwoffset */ -std::shared_ptr createAttentionWindow(); + std::shared_ptr createAttentionWindow(); -} // namespace speech + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/runtime/Helpers.cpp b/flashlight/pkg/speech/runtime/Helpers.cpp index cfdc6a1..ac5bbe6 100644 --- a/flashlight/pkg/speech/runtime/Helpers.cpp +++ b/flashlight/pkg/speech/runtime/Helpers.cpp @@ -16,7 +16,7 @@ #include "flashlight/lib/text/String.h" #ifdef FL_BUILD_FB_DEPENDENCIES - #include "deeplearning/projects/flashlight/fb/EverstoreDataset.h" +#include "deeplearning/projects/flashlight/fb/EverstoreDataset.h" #endif using fl::lib::format; @@ -26,69 +26,68 @@ using fl::lib::text::LexiconMap; namespace 
fl::pkg::speech { -template +template std::vector tensorMatrixToStrings( const Tensor& tensor, - T terminator) { - int L = tensor.dim(0); // padded length of string - int N = tensor.dim(1); // number of strings - std::vector result; - auto values = tensor.toHostVector(); - for (int i = 0; i < N; ++i) { - const T* row = &values[i * L]; - int len = 0; - while (len < L && row[len] != terminator) { - ++len; + T terminator +) { + int L = tensor.dim(0); // padded length of string + int N = tensor.dim(1); // number of strings + std::vector result; + auto values = tensor.toHostVector(); + for(int i = 0; i < N; ++i) { + const T* row = &values[i * L]; + int len = 0; + while(len < L && row[len] != terminator) { + ++len; + } + result.emplace_back(row, row + len); } - result.emplace_back(row, row + len); - } - return result; + return result; } -std::string -getRunFile(const std::string& name, int runidx, const fs::path& runpath) { - auto fname = format("%03d_%s", runidx, name.c_str()); - return runpath / fname; +std::string getRunFile(const std::string& name, int runidx, const fs::path& runpath) { + auto fname = format("%03d_%s", runidx, name.c_str()); + return runpath / fname; }; std::string cleanFilepath(const fs::path& in) { - std::string replace = in; - std::string sep(1, fs::path::preferred_separator); - replaceAll(replace, sep, "#"); - return replace; + std::string replace = in; + std::string sep(1, fs::path::preferred_separator); + replaceAll(replace, sep, "#"); + return replace; } std::string serializeGflags(const std::string& separator /* = "\n" */) { - std::stringstream serialized; - std::vector allFlags; - gflags::GetAllFlags(&allFlags); - std::string currVal; - auto& deprecatedFlags = detail::getDeprecatedFlags(); - for (auto itr = allFlags.begin(); itr != allFlags.end(); ++itr) { - // Check if the flag is deprecated - if so, skip it - if (deprecatedFlags.find(itr->name) == deprecatedFlags.end()) { - gflags::GetCommandLineOption(itr->name.c_str(), &currVal); - 
serialized << "--" << itr->name << "=" << currVal << separator; + std::stringstream serialized; + std::vector allFlags; + gflags::GetAllFlags(&allFlags); + std::string currVal; + auto& deprecatedFlags = detail::getDeprecatedFlags(); + for(auto itr = allFlags.begin(); itr != allFlags.end(); ++itr) { + // Check if the flag is deprecated - if so, skip it + if(deprecatedFlags.find(itr->name) == deprecatedFlags.end()) { + gflags::GetCommandLineOption(itr->name.c_str(), &currVal); + serialized << "--" << itr->name << "=" << currVal << separator; + } } - } - return serialized.str(); + return serialized.str(); } -std::unordered_set -getTrainEvalIds(int64_t dsSize, double pctTrainEval, int64_t seed) { - std::mt19937_64 rng(seed); - std::bernoulli_distribution dist(pctTrainEval / 100.0); - std::unordered_set result; - for (int64_t i = 0; i < dsSize; ++i) { - if (dist(rng)) { - result.insert(i); +std::unordered_set getTrainEvalIds(int64_t dsSize, double pctTrainEval, int64_t seed) { + std::mt19937_64 rng(seed); + std::bernoulli_distribution dist(pctTrainEval / 100.0); + std::unordered_set result; + for(int64_t i = 0; i < dsSize; ++i) { + if(dist(rng)) { + result.insert(i); + } } - } - return result; + return result; } std::vector readSampleIds(const Tensor& tensor) { - return tensorMatrixToStrings(tensor, '\0'); + return tensorMatrixToStrings(tensor, '\0'); } std::shared_ptr createDataset( @@ -103,131 +102,163 @@ std::shared_ptr createDataset( int worldSize /* = 1 */, const bool allowEmpty /* = false */, const std::string& batchingStrategy /* kBatchStrategyNone */, - int maxDurationPerBatch /* = 0 */) { - std::vector> allListDs; - std::vector sizes; - for (auto& path : paths) { - std::shared_ptr curListDs; - if (FLAGS_everstoredb) { + int maxDurationPerBatch /* = 0 */ +) { + std::vector> allListDs; + std::vector sizes; + for(auto& path : paths) { + std::shared_ptr curListDs; + if(FLAGS_everstoredb) { #ifdef FL_BUILD_FB_DEPENDENCIES - curListDs = std::make_shared( - rootDir / 
path, - inputTransform, - targetTransform, - wordTransform, - FLAGS_use_memcache); + curListDs = std::make_shared( + rootDir / path, + inputTransform, + targetTransform, + wordTransform, + FLAGS_use_memcache + ); #else - LOG(FATAL) << "EverstoreDataset not supported: " - << "build with -DFL_BUILD_FB_DEPENDENCIES"; + LOG(FATAL) << "EverstoreDataset not supported: " + << "build with -DFL_BUILD_FB_DEPENDENCIES"; #endif - } else { - curListDs = std::make_shared( - rootDir / path, inputTransform, targetTransform, wordTransform); + } else { + curListDs = std::make_shared( + rootDir / path, + inputTransform, + targetTransform, + wordTransform + ); + } + + allListDs.emplace_back(curListDs); + sizes.reserve(sizes.size() + curListDs->size()); + for(int64_t i = 0; i < curListDs->size(); ++i) { + sizes.push_back(curListDs->getInputSize(i)); + } } - allListDs.emplace_back(curListDs); - sizes.reserve(sizes.size() + curListDs->size()); - for (int64_t i = 0; i < curListDs->size(); ++i) { - sizes.push_back(curListDs->getInputSize(i)); + // Order Dataset + std::vector sortedIds(sizes.size()); + std::iota(sortedIds.begin(), sortedIds.end(), 0); + auto cmp = [&sizes](const int64_t& l, const int64_t& r) { + return sizes[l] > sizes[r]; + }; + if( + batchingStrategy == kBatchStrategyRand + || batchingStrategy == kBatchStrategyRandDynamic + ) { + auto rng = std::mt19937(sizes.size()); + for(int i = sizes.size(); i >= 1; i--) { + int index = rng() % sizes.size(); + std::swap(sortedIds[i - 1], sortedIds[index]); + std::swap(sizes[i - 1], sizes[index]); + } + } else { + std::stable_sort(sortedIds.begin(), sortedIds.end(), cmp); + std::stable_sort(sizes.begin(), sizes.end(), std::greater()); } - } - - // Order Dataset - std::vector sortedIds(sizes.size()); - std::iota(sortedIds.begin(), sortedIds.end(), 0); - auto cmp = [&sizes](const int64_t& l, const int64_t& r) { - return sizes[l] > sizes[r]; - }; - if (batchingStrategy == kBatchStrategyRand || - batchingStrategy == 
kBatchStrategyRandDynamic) { - auto rng = std::mt19937(sizes.size()); - for (int i = sizes.size(); i >= 1; i--) { - int index = rng() % sizes.size(); - std::swap(sortedIds[i - 1], sortedIds[index]); - std::swap(sizes[i - 1], sizes[index]); + + auto concatListDs = std::make_shared(allListDs); + + auto sortedDs = + std::make_shared(concatListDs, sortedIds); + + int inPad, tgtPad, wrdPad; + std::tie(inPad, tgtPad, wrdPad) = padVal; + auto batchFns = std::vector{ + [inPad](const std::vector& tensor) { + return fl::join(tensor, inPad, 3); + }, + [tgtPad](const std::vector& tensor) { + return fl::join(tensor, tgtPad, 1); + }, + [wrdPad](const std::vector& tensor) { + return fl::join(tensor, wrdPad, 1); + }, + [](const std::vector& tensor) { return fl::join(tensor, 0, 1); }, + [](const std::vector& tensor) { return fl::join(tensor, 0, 1); }, + [](const std::vector& tensor) { return fl::join(tensor, 0, 1); }, + [](const std::vector& tensor) { return fl::join(tensor, 0, 1); } + }; + if( + batchingStrategy == kBatchStrategyDynamic + || batchingStrategy == kBatchStrategyRandDynamic + ) { + // Partition the dataset and distribute + auto result = fl::dynamicPartitionByRoundRobin( + sizes, + worldRank, + worldSize, + maxDurationPerBatch, + allowEmpty + ); + auto partitions = result.first; + auto batchSizes = result.second; + auto paritionDs = + std::make_shared(sortedDs, partitions); + // Batch the dataset + return std::make_shared(paritionDs, batchSizes, batchFns); + } else if( + batchingStrategy == kBatchStrategyNone + || batchingStrategy == kBatchStrategyRand + ) { + // Partition the dataset and distribute + auto partitions = fl::partitionByRoundRobin( + sortedDs->size(), + worldRank, + worldSize, + batchSize, + allowEmpty + ); + auto paritionDs = + std::make_shared(sortedDs, partitions); + // Batch the dataset + return std::make_shared( + paritionDs, + batchSize, + fl::BatchDatasetPolicy::INCLUDE_LAST, + batchFns + ); + } else { + throw std::runtime_error( + "Unsupported 
batching strategy '" + batchingStrategy + "'" + ); } - } else { - std::stable_sort(sortedIds.begin(), sortedIds.end(), cmp); - std::stable_sort(sizes.begin(), sizes.end(), std::greater()); - } - - auto concatListDs = std::make_shared(allListDs); - - auto sortedDs = - std::make_shared(concatListDs, sortedIds); - - int inPad, tgtPad, wrdPad; - std::tie(inPad, tgtPad, wrdPad) = padVal; - auto batchFns = std::vector{ - [inPad](const std::vector& tensor) { - return fl::join(tensor, inPad, 3); - }, - [tgtPad](const std::vector& tensor) { - return fl::join(tensor, tgtPad, 1); - }, - [wrdPad](const std::vector& tensor) { - return fl::join(tensor, wrdPad, 1); - }, - [](const std::vector& tensor) { return fl::join(tensor, 0, 1); }, - [](const std::vector& tensor) { return fl::join(tensor, 0, 1); }, - [](const std::vector& tensor) { return fl::join(tensor, 0, 1); }, - [](const std::vector& tensor) { return fl::join(tensor, 0, 1); }}; - if (batchingStrategy == kBatchStrategyDynamic || - batchingStrategy == kBatchStrategyRandDynamic) { - // Partition the dataset and distribute - auto result = fl::dynamicPartitionByRoundRobin( - sizes, worldRank, worldSize, maxDurationPerBatch, allowEmpty); - auto partitions = result.first; - auto batchSizes = result.second; - auto paritionDs = - std::make_shared(sortedDs, partitions); - // Batch the dataset - return std::make_shared(paritionDs, batchSizes, batchFns); - } else if ( - batchingStrategy == kBatchStrategyNone || - batchingStrategy == kBatchStrategyRand) { - // Partition the dataset and distribute - auto partitions = fl::partitionByRoundRobin( - sortedDs->size(), worldRank, worldSize, batchSize, allowEmpty); - auto paritionDs = - std::make_shared(sortedDs, partitions); - // Batch the dataset - return std::make_shared( - paritionDs, batchSize, fl::BatchDatasetPolicy::INCLUDE_LAST, batchFns); - } else { - throw std::runtime_error( - "Unsupported batching strategy '" + batchingStrategy + "'"); - } } std::shared_ptr loadPrefetchDataset( 
std::shared_ptr dataset, int prefetchThreads, bool shuffle, - int shuffleSeed /*= 0 */) { - if (shuffle) { - dataset = std::make_shared(dataset, shuffleSeed); - } - if (prefetchThreads > 0) { - dataset = std::make_shared( - dataset, prefetchThreads, prefetchThreads /* prefetch size */); - } - return dataset; + int shuffleSeed /*= 0 */ +) { + if(shuffle) { + dataset = std::make_shared(dataset, shuffleSeed); + } + if(prefetchThreads > 0) { + dataset = std::make_shared( + dataset, + prefetchThreads, + prefetchThreads /* prefetch size */ + ); + } + return dataset; } std::vector> parseValidSets( - const std::string& valid) { - auto validSets = fl::lib::split(',', fl::lib::trim(valid), true); - std::vector> validTagSets; - for (const auto& s : validSets) { - // assume the format is tag:filepath - auto ts = fl::lib::splitOnAnyOf(":", s); - if (ts.size() == 1) { - validTagSets.emplace_back(s, s); - } else { - validTagSets.emplace_back(ts[0], ts[1]); + const std::string& valid +) { + auto validSets = fl::lib::split(',', fl::lib::trim(valid), true); + std::vector> validTagSets; + for(const auto& s : validSets) { + // assume the format is tag:filepath + auto ts = fl::lib::splitOnAnyOf(":", s); + if(ts.size() == 1) { + validTagSets.emplace_back(s, s); + } else { + validTagSets.emplace_back(ts[0], ts[1]); + } } - } - return validTagSets; + return validTagSets; } } // namespace fl diff --git a/flashlight/pkg/speech/runtime/Helpers.h b/flashlight/pkg/speech/runtime/Helpers.h index 7c4c87c..ef76ddb 100644 --- a/flashlight/pkg/speech/runtime/Helpers.h +++ b/flashlight/pkg/speech/runtime/Helpers.h @@ -30,31 +30,30 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { /** * Given a filename, remove any filepath delimiters - returns a contiguous * string that won't be subdivided into a filepath. */ -std::string cleanFilepath(const fs::path& inputFileName); + std::string cleanFilepath(const fs::path& inputFileName); /** * Serialize gflags into a buffer. 
* * Only serializes gflags that aren't explicitly deprecated. */ -std::string serializeGflags(const std::string& separator = "\n"); + std::string serializeGflags(const std::string& separator = "\n"); /** * Sample indices for the `--pcttraineval` flag. */ -std::unordered_set -getTrainEvalIds(int64_t dsSize, double pctTrainEval, int64_t seed); + std::unordered_set getTrainEvalIds(int64_t dsSize, double pctTrainEval, int64_t seed); /** * Read sample ids from an `Tensor`. */ -std::vector readSampleIds(const Tensor& arr); + std::vector readSampleIds(const Tensor& arr); /* * Utility function for creating a w2l dataset. @@ -68,35 +67,38 @@ std::vector readSampleIds(const Tensor& arr); * @param maxDurationPerBatch - is used for batchingStrategy="dynamic", max * total duration in a batch */ -std::shared_ptr createDataset( - const std::vector& paths, - const fs::path& rootDir = "", - int batchSize = 1, - const fl::Dataset::DataTransformFunction& inputTransform = nullptr, - const fl::Dataset::DataTransformFunction& targetTransform = nullptr, - const fl::Dataset::DataTransformFunction& wordTransform = nullptr, - const std::tuple& padVal = - std::tuple{0, -1, -1}, - int worldRank = 0, - int worldSize = 1, - const bool allowEmpty = false, - const std::string& batchingStrategy = kBatchStrategyNone, - int maxDurationPerBatch = 0); + std::shared_ptr createDataset( + const std::vector& paths, + const fs::path& rootDir = "", + int batchSize = 1, + const fl::Dataset::DataTransformFunction& inputTransform = nullptr, + const fl::Dataset::DataTransformFunction& targetTransform = nullptr, + const fl::Dataset::DataTransformFunction& wordTransform = nullptr, + const std::tuple& padVal = + std::tuple{0, -1, -1}, + int worldRank = 0, + int worldSize = 1, + const bool allowEmpty = false, + const std::string& batchingStrategy = kBatchStrategyNone, + int maxDurationPerBatch = 0 + ); -std::shared_ptr loadPrefetchDataset( - std::shared_ptr dataset, - int prefetchThreads, - bool shuffle, - int 
shuffleSeed = 0); + std::shared_ptr loadPrefetchDataset( + std::shared_ptr dataset, + int prefetchThreads, + bool shuffle, + int shuffleSeed = 0 + ); /* * Function to parse valid set string describing multiple datasets into a vector * Input Format: d1:d1.lst,d2:d2.lst returns {{d1, d1.lst}, {d2, d2.lst}} * Input Format: d1.lst,d2.lst returns {{d1.lst, d1.lst}, {d2.lst, d2.lst}} */ -std::vector> parseValidSets( - const std::string& valid); + std::vector> parseValidSets( + const std::string& valid + ); -} // namespace speech + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/runtime/Logger.cpp b/flashlight/pkg/speech/runtime/Logger.cpp index f9be1f8..4bee467 100644 --- a/flashlight/pkg/speech/runtime/Logger.cpp +++ b/flashlight/pkg/speech/runtime/Logger.cpp @@ -30,126 +30,137 @@ std::string getLogString( double lr, double lrcrit, double scaleFactor, - const std::string& separator /* = " | " */) { - std::string status; - auto insertItem = [&](std::string key, std::string val) { - val = key + ": " + val; - status = status + (status.empty() ? "" : separator) + val; - }; - insertItem("epoch", format("%8d", epoch)); - insertItem("nupdates", format("%12d", nupdates)); - insertItem("lr", format("%4.6lf", lr)); - insertItem("lrcriterion", format("%4.6lf", lrcrit)); - insertItem("scale-factor", format("%4.6lf", scaleFactor)); + const std::string& separator /* = " | " */ +) { + std::string status; + auto insertItem = [&](std::string key, std::string val) { + val = key + ": " + val; + status = status + (status.empty() ? 
"" : separator) + val; + }; + insertItem("epoch", format("%8d", epoch)); + insertItem("nupdates", format("%12d", nupdates)); + insertItem("lr", format("%4.6lf", lr)); + insertItem("lrcriterion", format("%4.6lf", lrcrit)); + insertItem("scale-factor", format("%4.6lf", scaleFactor)); - int rt = meters.runtime.value(); - insertItem( - "runtime", - format("%02d:%02d:%02d", (rt / 60 / 60), (rt / 60) % 60, rt % 60)); - insertItem("bch(ms)", format("%.2f", meters.timer.value() * 1000)); - insertItem("smp(ms)", format("%.2f", meters.sampletimer.value() * 1000)); - insertItem("fwd(ms)", format("%.2f", meters.fwdtimer.value() * 1000)); - insertItem( - "crit-fwd(ms)", format("%.2f", meters.critfwdtimer.value() * 1000)); - insertItem("bwd(ms)", format("%.2f", meters.bwdtimer.value() * 1000)); - insertItem("optim(ms)", format("%.2f", meters.optimtimer.value() * 1000)); - insertItem("loss", format("%10.5f", meters.train.loss.value()[0])); - - insertItem("train-TER", format("%5.2f", meters.train.tknEdit.errorRate()[0])); - insertItem("train-WER", format("%5.2f", meters.train.wrdEdit.errorRate()[0])); - for (auto& v : meters.valid) { - insertItem(v.first + "-loss", format("%10.5f", v.second.loss.value()[0])); + int rt = meters.runtime.value(); insertItem( - v.first + "-TER", format("%5.2f", v.second.tknEdit.errorRate()[0])); + "runtime", + format("%02d:%02d:%02d", (rt / 60 / 60), (rt / 60) % 60, rt % 60) + ); + insertItem("bch(ms)", format("%.2f", meters.timer.value() * 1000)); + insertItem("smp(ms)", format("%.2f", meters.sampletimer.value() * 1000)); + insertItem("fwd(ms)", format("%.2f", meters.fwdtimer.value() * 1000)); insertItem( - v.first + "-WER", format("%5.2f", v.second.wrdEdit.errorRate()[0])); - auto vDecoderIter = validDecoderWer.find(v.first); - if (vDecoderIter != validDecoderWer.end()) { - insertItem( - v.first + "-WER-decoded", format("%5.2f", vDecoderIter->second)); + "crit-fwd(ms)", + format("%.2f", meters.critfwdtimer.value() * 1000) + ); + 
insertItem("bwd(ms)", format("%.2f", meters.bwdtimer.value() * 1000)); + insertItem("optim(ms)", format("%.2f", meters.optimtimer.value() * 1000)); + insertItem("loss", format("%10.5f", meters.train.loss.value()[0])); + + insertItem("train-TER", format("%5.2f", meters.train.tknEdit.errorRate()[0])); + insertItem("train-WER", format("%5.2f", meters.train.wrdEdit.errorRate()[0])); + for(auto& v : meters.valid) { + insertItem(v.first + "-loss", format("%10.5f", v.second.loss.value()[0])); + insertItem( + v.first + "-TER", + format("%5.2f", v.second.tknEdit.errorRate()[0]) + ); + insertItem( + v.first + "-WER", + format("%5.2f", v.second.wrdEdit.errorRate()[0]) + ); + auto vDecoderIter = validDecoderWer.find(v.first); + if(vDecoderIter != validDecoderWer.end()) { + insertItem( + v.first + "-WER-decoded", + format("%5.2f", vDecoderIter->second) + ); + } + } + auto stats = meters.stats.value(); + auto numsamples = std::max(stats[4], 1); + auto numbatches = std::max(stats[5], 1); + // assumed to be in ms of original audios + auto isztotal = stats[0]; + auto tsztotal = stats[1]; + auto tszmax = stats[3]; + auto iszAvrFrames = isztotal / numsamples; + if(FLAGS_features_type != kFeaturesRaw) { + iszAvrFrames = iszAvrFrames / FLAGS_framestridems; + } else { + iszAvrFrames = iszAvrFrames / 1000 * FLAGS_samplerate; } - } - auto stats = meters.stats.value(); - auto numsamples = std::max(stats[4], 1); - auto numbatches = std::max(stats[5], 1); - // assumed to be in ms of original audios - auto isztotal = stats[0]; - auto tsztotal = stats[1]; - auto tszmax = stats[3]; - auto iszAvrFrames = isztotal / numsamples; - if (FLAGS_features_type != kFeaturesRaw) { - iszAvrFrames = iszAvrFrames / FLAGS_framestridems; - } else { - iszAvrFrames = iszAvrFrames / 1000 * FLAGS_samplerate; - } - insertItem("avg-isz", format("%03d", iszAvrFrames)); - insertItem("avg-tsz", format("%03d", tsztotal / numsamples)); - insertItem("max-tsz", format("%03d", tszmax)); + insertItem("avg-isz", 
format("%03d", iszAvrFrames)); + insertItem("avg-tsz", format("%03d", tsztotal / numsamples)); + insertItem("max-tsz", format("%03d", tszmax)); - auto worldSize = fl::getWorldSize(); - double timeTakenSec = meters.timer.value() * numbatches / worldSize; + auto worldSize = fl::getWorldSize(); + double timeTakenSec = meters.timer.value() * numbatches / worldSize; - insertItem("avr-batchsz", format("%7.2f", float(numsamples) / numbatches)); - insertItem("hrs", format("%7.2f", isztotal / 1000 / 3600.0)); - insertItem( - "thrpt(sec/sec)", - timeTakenSec > 0.0 ? format("%.2f", isztotal / 1000 / timeTakenSec) - : "n/a"); - insertItem("timestamp", getCurrentDate() + " " + getCurrentTime()); - return status; + insertItem("avr-batchsz", format("%7.2f", float(numsamples) / numbatches)); + insertItem("hrs", format("%7.2f", isztotal / 1000 / 3600.0)); + insertItem( + "thrpt(sec/sec)", + timeTakenSec > 0.0 ? format("%.2f", isztotal / 1000 / timeTakenSec) + : "n/a" + ); + insertItem("timestamp", getCurrentDate() + " " + getCurrentTime()); + return status; } void appendToLog(std::ofstream& logfile, const std::string& logstr) { - auto write = [&]() { - logfile.clear(); // reset flags - logfile << logstr << std::endl; - if (!logfile) { - throw std::runtime_error("appending to log failed"); - } - }; - retryWithBackoff(std::chrono::seconds(1), 1.0, 6, write); + auto write = [&]() { + logfile.clear(); // reset flags + logfile << logstr << std::endl; + if(!logfile) { + throw std::runtime_error("appending to log failed"); + } + }; + retryWithBackoff(std::chrono::seconds(1), 1.0, 6, write); } Tensor allreduceGet(SpeechStatMeter& mtr) { - auto mtrValRaw = mtr.value(); - std::vector mtrVal(mtrValRaw.begin(), mtrValRaw.end()); - // Caveat: maxInputSz_, maxTargetSz_ would be approximate - mtrVal[2] *= mtrVal[4]; - mtrVal[3] *= mtrVal[4]; - return Tensor::fromVector(mtrVal); + auto mtrValRaw = mtr.value(); + std::vector mtrVal(mtrValRaw.begin(), mtrValRaw.end()); + // Caveat: maxInputSz_, 
maxTargetSz_ would be approximate + mtrVal[2] *= mtrVal[4]; + mtrVal[3] *= mtrVal[4]; + return Tensor::fromVector(mtrVal); } void allreduceSet(SpeechStatMeter& mtr, Tensor& val) { - mtr.reset(); - // Caveat: maxInputSz_, maxTargetSz_ would be approximate - auto valVec = val.toHostVector(); - SpeechStats stats; - auto denom = (valVec[4] == 0) ? 1 : valVec[4]; - stats.totalInputSz_ = valVec[0]; - stats.totalTargetSz_ = valVec[1]; - stats.maxInputSz_ = valVec[2] / denom; - stats.maxTargetSz_ = valVec[3] / denom; - stats.numSamples_ = valVec[4]; - stats.numBatches_ = valVec[5]; - mtr.add(stats); + mtr.reset(); + // Caveat: maxInputSz_, maxTargetSz_ would be approximate + auto valVec = val.toHostVector(); + SpeechStats stats; + auto denom = (valVec[4] == 0) ? 1 : valVec[4]; + stats.totalInputSz_ = valVec[0]; + stats.totalTargetSz_ = valVec[1]; + stats.maxInputSz_ = valVec[2] / denom; + stats.maxTargetSz_ = valVec[3] / denom; + stats.numSamples_ = valVec[4]; + stats.numBatches_ = valVec[5]; + mtr.add(stats); } void syncMeter(TrainMeters& mtrs) { - fl::pkg::runtime::syncMeter(mtrs.stats); - fl::pkg::runtime::syncMeter(mtrs.runtime); - fl::pkg::runtime::syncMeter(mtrs.timer); - fl::pkg::runtime::syncMeter(mtrs.fwdtimer); - fl::pkg::runtime::syncMeter(mtrs.critfwdtimer); - fl::pkg::runtime::syncMeter(mtrs.bwdtimer); - fl::pkg::runtime::syncMeter(mtrs.optimtimer); - fl::pkg::runtime::syncMeter(mtrs.train.tknEdit); - fl::pkg::runtime::syncMeter(mtrs.train.wrdEdit); - fl::pkg::runtime::syncMeter(mtrs.train.loss); - for (auto& v : mtrs.valid) { - fl::pkg::runtime::syncMeter(v.second.tknEdit); - fl::pkg::runtime::syncMeter(v.second.wrdEdit); - fl::pkg::runtime::syncMeter(v.second.loss); - } + fl::pkg::runtime::syncMeter(mtrs.stats); + fl::pkg::runtime::syncMeter(mtrs.runtime); + fl::pkg::runtime::syncMeter(mtrs.timer); + fl::pkg::runtime::syncMeter(mtrs.fwdtimer); + fl::pkg::runtime::syncMeter(mtrs.critfwdtimer); + fl::pkg::runtime::syncMeter(mtrs.bwdtimer); + 
fl::pkg::runtime::syncMeter(mtrs.optimtimer); + fl::pkg::runtime::syncMeter(mtrs.train.tknEdit); + fl::pkg::runtime::syncMeter(mtrs.train.wrdEdit); + fl::pkg::runtime::syncMeter(mtrs.train.loss); + for(auto& v : mtrs.valid) { + fl::pkg::runtime::syncMeter(v.second.tknEdit); + fl::pkg::runtime::syncMeter(v.second.wrdEdit); + fl::pkg::runtime::syncMeter(v.second.loss); + } } } // namespace fl diff --git a/flashlight/pkg/speech/runtime/Logger.h b/flashlight/pkg/speech/runtime/Logger.h index e5ae3e7..a1e0fdf 100644 --- a/flashlight/pkg/speech/runtime/Logger.h +++ b/flashlight/pkg/speech/runtime/Logger.h @@ -16,56 +16,57 @@ namespace fl { namespace pkg { -namespace speech { -struct DatasetMeters { - fl::EditDistanceMeter tknEdit, wrdEdit; - fl::AverageValueMeter loss; -}; + namespace speech { + struct DatasetMeters { + fl::EditDistanceMeter tknEdit, wrdEdit; + fl::AverageValueMeter loss; + }; -struct TrainMeters { - fl::TimeMeter runtime; - fl::TimeMeter timer{true}; - fl::TimeMeter sampletimer{true}; - fl::TimeMeter fwdtimer{true}; // includes network + criterion time - fl::TimeMeter critfwdtimer{true}; - fl::TimeMeter bwdtimer{true}; // includes network + criterion time - fl::TimeMeter optimtimer{true}; + struct TrainMeters { + fl::TimeMeter runtime; + fl::TimeMeter timer{true}; + fl::TimeMeter sampletimer{true}; + fl::TimeMeter fwdtimer{true}; // includes network + criterion time + fl::TimeMeter critfwdtimer{true}; + fl::TimeMeter bwdtimer{true}; // includes network + criterion time + fl::TimeMeter optimtimer{true}; - DatasetMeters train; - std::map valid; + DatasetMeters train; + std::map valid; - SpeechStatMeter stats; -}; + SpeechStatMeter stats; + }; -struct TestMeters { - fl::TimeMeter timer; - fl::EditDistanceMeter wrdDstSlice; - fl::EditDistanceMeter wrdDst; - fl::EditDistanceMeter tknDstSlice; - fl::EditDistanceMeter tknDst; -}; + struct TestMeters { + fl::TimeMeter timer; + fl::EditDistanceMeter wrdDstSlice; + fl::EditDistanceMeter wrdDst; + 
fl::EditDistanceMeter tknDstSlice; + fl::EditDistanceMeter tknDst; + }; /* * Utility function to log results (learning rate, WER, TER, epoch, timing) * From gflags it uses FLAGS_batchsize, FLAGS_features_type * FLAGS_framestridems, FLAGS_samplerate */ -std::string getLogString( - TrainMeters& meters, - const std::unordered_map& dmErrs, - int64_t epoch, - int64_t nupdates, - double lr, - double lrcrit, - double scaleFactor, - const std::string& separator = " | "); + std::string getLogString( + TrainMeters& meters, + const std::unordered_map& dmErrs, + int64_t epoch, + int64_t nupdates, + double lr, + double lrcrit, + double scaleFactor, + const std::string& separator = " | " + ); -void appendToLog(std::ofstream& logfile, const std::string& logstr); + void appendToLog(std::ofstream& logfile, const std::string& logstr); -Tensor allreduceGet(SpeechStatMeter& mtr); -void allreduceSet(SpeechStatMeter& mtr, Tensor& val); + Tensor allreduceGet(SpeechStatMeter& mtr); + void allreduceSet(SpeechStatMeter& mtr, Tensor& val); -void syncMeter(TrainMeters& mtrs); -} // namespace speech + void syncMeter(TrainMeters& mtrs); + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/runtime/Optimizer.cpp b/flashlight/pkg/speech/runtime/Optimizer.cpp index b1fc262..b4987af 100644 --- a/flashlight/pkg/speech/runtime/Optimizer.cpp +++ b/flashlight/pkg/speech/runtime/Optimizer.cpp @@ -16,59 +16,74 @@ std::shared_ptr initOptimizer( const std::string& optimizer, double lr, double momentum, - double weightdecay) { - if (nets.empty()) { - throw std::invalid_argument( - "[InitOptimizer]: No network for initializing the optimizer"); - } + double weightdecay +) { + if(nets.empty()) { + throw std::invalid_argument( + "[InitOptimizer]: No network for initializing the optimizer" + ); + } - std::vector params; - for (const auto& n : nets) { - auto p = n->params(); - params.insert(params.end(), p.begin(), p.end()); - } + std::vector params; + for(const auto& n : 
nets) { + auto p = n->params(); + params.insert(params.end(), p.begin(), p.end()); + } - std::shared_ptr opt; - if (optimizer == kSGDOptimizer) { - opt = std::make_shared(params, lr, momentum, weightdecay); - } else if (optimizer == kAdamOptimizer) { - opt = std::make_shared( - params, - lr, - FLAGS_adambeta1, - FLAGS_adambeta2, - FLAGS_optimepsilon, - weightdecay); - } else if (optimizer == kRMSPropOptimizer) { - opt = std::make_shared( - params, lr, FLAGS_optimrho, FLAGS_optimepsilon, weightdecay); - } else if (optimizer == kAdadeltaOptimizer) { - opt = std::make_shared( - params, lr, FLAGS_optimrho, FLAGS_optimepsilon, weightdecay); - } else if (optimizer == kAdagradOptimizer) { - opt = - std::make_shared(params, lr, FLAGS_optimepsilon); - } else if (optimizer == kAMSgradOptimizer) { - opt = std::make_shared( - params, - lr, - FLAGS_adambeta1, - FLAGS_adambeta2, - FLAGS_optimepsilon, - weightdecay); + std::shared_ptr opt; + if(optimizer == kSGDOptimizer) { + opt = std::make_shared(params, lr, momentum, weightdecay); + } else if(optimizer == kAdamOptimizer) { + opt = std::make_shared( + params, + lr, + FLAGS_adambeta1, + FLAGS_adambeta2, + FLAGS_optimepsilon, + weightdecay + ); + } else if(optimizer == kRMSPropOptimizer) { + opt = std::make_shared( + params, + lr, + FLAGS_optimrho, + FLAGS_optimepsilon, + weightdecay + ); + } else if(optimizer == kAdadeltaOptimizer) { + opt = std::make_shared( + params, + lr, + FLAGS_optimrho, + FLAGS_optimepsilon, + weightdecay + ); + } else if(optimizer == kAdagradOptimizer) { + opt = + std::make_shared(params, lr, FLAGS_optimepsilon); + } else if(optimizer == kAMSgradOptimizer) { + opt = std::make_shared( + params, + lr, + FLAGS_adambeta1, + FLAGS_adambeta2, + FLAGS_optimepsilon, + weightdecay + ); - } else if (optimizer == kNovogradOptimizer) { - opt = std::make_shared( - params, - lr, - FLAGS_adambeta1, - FLAGS_adambeta2, - FLAGS_optimepsilon, - weightdecay); - } else { - LOG(FATAL) << "Optimizer option " << optimizer << " 
not implemented"; - } + } else if(optimizer == kNovogradOptimizer) { + opt = std::make_shared( + params, + lr, + FLAGS_adambeta1, + FLAGS_adambeta2, + FLAGS_optimepsilon, + weightdecay + ); + } else { + LOG(FATAL) << "Optimizer option " << optimizer << " not implemented"; + } - return opt; + return opt; } } // namespace fl diff --git a/flashlight/pkg/speech/runtime/Optimizer.h b/flashlight/pkg/speech/runtime/Optimizer.h index 888e053..cef77ae 100644 --- a/flashlight/pkg/speech/runtime/Optimizer.h +++ b/flashlight/pkg/speech/runtime/Optimizer.h @@ -14,7 +14,7 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { /* * Utility function to create an optimizer. @@ -22,12 +22,13 @@ namespace speech { * `amsgrad`, `novograd`. From gflags it uses FLAGS_optimrho, FLAGS_adambeta1, * FLAGS_adambeta2, FLAGS_optimepsilon, */ -std::shared_ptr initOptimizer( - const std::vector>& nets, - const std::string& optimizer, - double lr, - double momentum, - double weightdecay); -} // namespace speech + std::shared_ptr initOptimizer( + const std::vector>& nets, + const std::string& optimizer, + double lr, + double momentum, + double weightdecay + ); + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/runtime/SpeechStatMeter.cpp b/flashlight/pkg/speech/runtime/SpeechStatMeter.cpp index 44e0420..1128294 100644 --- a/flashlight/pkg/speech/runtime/SpeechStatMeter.cpp +++ b/flashlight/pkg/speech/runtime/SpeechStatMeter.cpp @@ -10,65 +10,65 @@ namespace fl::pkg::speech { SpeechStatMeter::SpeechStatMeter() { - reset(); + reset(); } void SpeechStatMeter::reset() { - stats_.reset(); + stats_.reset(); } void SpeechStatMeter::add(const Tensor& inputSizes, const Tensor& targetSizes) { - int64_t curInputSz = fl::sum(inputSizes).asScalar(); - int64_t curTargetSz = fl::sum(targetSizes).asScalar(); + int64_t curInputSz = fl::sum(inputSizes).asScalar(); + int64_t curTargetSz = fl::sum(targetSizes).asScalar(); - stats_.totalInputSz_ += 
curInputSz; - stats_.totalTargetSz_ += curTargetSz; + stats_.totalInputSz_ += curInputSz; + stats_.totalTargetSz_ += curTargetSz; - stats_.maxInputSz_ = - std::max(stats_.maxInputSz_, fl::amax(inputSizes).asScalar()); - stats_.maxTargetSz_ = - std::max(stats_.maxTargetSz_, fl::amax(targetSizes).asScalar()); + stats_.maxInputSz_ = + std::max(stats_.maxInputSz_, fl::amax(inputSizes).asScalar()); + stats_.maxTargetSz_ = + std::max(stats_.maxTargetSz_, fl::amax(targetSizes).asScalar()); - stats_.numSamples_ += inputSizes.dim(1); - stats_.numBatches_++; + stats_.numSamples_ += inputSizes.dim(1); + stats_.numBatches_++; } void SpeechStatMeter::add(const SpeechStats& stats) { - stats_.totalInputSz_ += stats.totalInputSz_; - stats_.totalTargetSz_ += stats.totalTargetSz_; + stats_.totalInputSz_ += stats.totalInputSz_; + stats_.totalTargetSz_ += stats.totalTargetSz_; - stats_.maxInputSz_ = std::max(stats_.maxInputSz_, stats.maxInputSz_); - stats_.maxTargetSz_ = std::max(stats_.maxTargetSz_, stats.maxTargetSz_); + stats_.maxInputSz_ = std::max(stats_.maxInputSz_, stats.maxInputSz_); + stats_.maxTargetSz_ = std::max(stats_.maxTargetSz_, stats.maxTargetSz_); - stats_.numSamples_ += stats.numSamples_; - stats_.numBatches_ += stats.numBatches_; + stats_.numSamples_ += stats.numSamples_; + stats_.numBatches_ += stats.numBatches_; } std::vector SpeechStatMeter::value() const { - return stats_.toArray(); + return stats_.toArray(); } SpeechStats::SpeechStats() { - reset(); + reset(); } void SpeechStats::reset() { - totalInputSz_ = 0; - totalTargetSz_ = 0; - maxInputSz_ = 0; - maxTargetSz_ = 0; - numSamples_ = 0; - numBatches_ = 0; + totalInputSz_ = 0; + totalTargetSz_ = 0; + maxInputSz_ = 0; + maxTargetSz_ = 0; + numSamples_ = 0; + numBatches_ = 0; } std::vector SpeechStats::toArray() const { - std::vector arr(6); - arr[0] = totalInputSz_; - arr[1] = totalTargetSz_; - arr[2] = maxInputSz_; - arr[3] = maxTargetSz_; - arr[4] = numSamples_; - arr[5] = numBatches_; - return arr; + 
std::vector arr(6); + arr[0] = totalInputSz_; + arr[1] = totalTargetSz_; + arr[2] = maxInputSz_; + arr[3] = maxTargetSz_; + arr[4] = numSamples_; + arr[5] = numBatches_; + return arr; } } // namespace fl diff --git a/flashlight/pkg/speech/runtime/SpeechStatMeter.h b/flashlight/pkg/speech/runtime/SpeechStatMeter.h index d4fc566..a530e7c 100644 --- a/flashlight/pkg/speech/runtime/SpeechStatMeter.h +++ b/flashlight/pkg/speech/runtime/SpeechStatMeter.h @@ -11,32 +11,32 @@ namespace fl { namespace pkg { -namespace speech { + namespace speech { -struct SpeechStats { - int64_t totalInputSz_; - int64_t totalTargetSz_; - int64_t maxInputSz_; - int64_t maxTargetSz_; - int64_t numSamples_; - int64_t numBatches_; + struct SpeechStats { + int64_t totalInputSz_; + int64_t totalTargetSz_; + int64_t maxInputSz_; + int64_t maxTargetSz_; + int64_t numSamples_; + int64_t numBatches_; - SpeechStats(); - void reset(); - std::vector toArray() const; -}; + SpeechStats(); + void reset(); + std::vector toArray() const; + }; -class SpeechStatMeter { - public: - SpeechStatMeter(); - void add(const Tensor& inputSizes, const Tensor& targetSizes); - void add(const SpeechStats& stats); - std::vector value() const; - void reset(); + class SpeechStatMeter { + public: + SpeechStatMeter(); + void add(const Tensor& inputSizes, const Tensor& targetSizes); + void add(const SpeechStats& stats); + std::vector value() const; + void reset(); - private: - SpeechStats stats_; -}; -} // namespace speech + private: + SpeechStats stats_; + }; + } // namespace speech } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/speech/test/audio/CeplifterTest.cpp b/flashlight/pkg/speech/test/audio/CeplifterTest.cpp index f2c3624..9fdd255 100644 --- a/flashlight/pkg/speech/test/audio/CeplifterTest.cpp +++ b/flashlight/pkg/speech/test/audio/CeplifterTest.cpp @@ -16,60 +16,66 @@ using fl::lib::audio::Ceplifter; // ceplifter = @( N, L )( 1+0.5*L*sin(pi*[0:N-1]/L) ); // CC = diag( lifter ) * CC; // Reference: 
Kamil Wojcicki, HTK MFCC MATLAB, URL: -// https://www.mathworks.com/matlabcentral/fileexchange/32849-htk-mfcc-matlab +// https://www.mathworks.com/matlabcentral/fileexchange/32849-htk-mfcc-matlab TEST(CeplifterTest, matlabCompareTest) { - // Test Case: 1 - Ceplifter cep1(25, 22); - std::vector input1(25, 1.0); - std::vector matlaboutput1{ - 1, 2.565463, 4.099058, 5.569565, 6.947048, 8.203468, 9.313245, - 10.25378, 11.00595, 11.55442, 11.88803, 12, 11.88803, 11.55442, - 11.00595, 10.25378, 9.313245, 8.203468, 6.947048, 5.569565, 4.099058, - 2.565463, 1.000000, -0.5654632, -2.0990581}; - auto output1 = cep1.apply(input1); - // Implementation should match with matlab for Test case 1. - ASSERT_TRUE(compareVec(output1, matlaboutput1)); - // Test Case: 2 - Ceplifter cep2(40, 13); - std::vector input2{ - 3.827583, 3.975999, 0.9343630, 2.448821, 2.227931, 3.231565, 3.546824, - 3.773433, 1.380125, 3.398513, 3.275490, 0.8130586, 0.5949884, 2.491820, - 4.798719, 1.701928, 2.926338, 1.119059, 3.756335, 1.275475, 2.529785, - 3.495383, 4.454516, 4.796457, 2.736077, 0.6931222, 0.7464700, 1.287541, - 4.203586, 1.271410, 4.071424, 1.217624, 4.646318, 1.749918, 0.9829762, - 1.255419, 3.080223, 2.366444, 1.758297, 4.154143}; - std::vector matlaboutput2{ - 3.82758300, 10.1608714, 3.75679389, 13.0039674, 14.1460142, - 22.8717424, 26.4330877, 28.1219157, 9.76798039, 21.5785018, - 17.3938256, 3.26906521, 1.52052368, 2.49182000, -2.66593706, - -3.43908696, -9.68704871, -4.86722976, -19.0731875, -6.95466478, - -13.7939251, -17.7481762, -19.3744501, -15.8776985, -5.52879248, - -0.385065298, 0.746470000, 3.29037774, 16.9013608, 6.75156506, - 25.8510797, 8.61786241, 34.6271852, 13.0414523, 6.95711783, - 7.97115128, 16.3568998, 9.51476285, 4.49341909, 4.15414300}; - auto output2 = cep2.apply(input2); - // Implementation should match with matlab for Test case 2. 
- ASSERT_TRUE(compareVec(output2, matlaboutput2)); + // Test Case: 1 + Ceplifter cep1(25, 22); + std::vector input1(25, 1.0); + std::vector matlaboutput1{ + 1, 2.565463, 4.099058, 5.569565, 6.947048, 8.203468, 9.313245, + 10.25378, 11.00595, 11.55442, 11.88803, 12, 11.88803, 11.55442, + 11.00595, 10.25378, 9.313245, 8.203468, 6.947048, 5.569565, 4.099058, + 2.565463, 1.000000, -0.5654632, -2.0990581}; + auto output1 = cep1.apply(input1); + // Implementation should match with matlab for Test case 1. + ASSERT_TRUE(compareVec(output1, matlaboutput1)); + // Test Case: 2 + Ceplifter cep2(40, 13); + std::vector input2{ + 3.827583, 3.975999, 0.9343630, 2.448821, 2.227931, 3.231565, 3.546824, + 3.773433, 1.380125, 3.398513, 3.275490, 0.8130586, 0.5949884, 2.491820, + 4.798719, 1.701928, 2.926338, 1.119059, 3.756335, 1.275475, 2.529785, + 3.495383, 4.454516, 4.796457, 2.736077, 0.6931222, 0.7464700, 1.287541, + 4.203586, 1.271410, 4.071424, 1.217624, 4.646318, 1.749918, 0.9829762, + 1.255419, 3.080223, 2.366444, 1.758297, 4.154143}; + std::vector matlaboutput2{ + 3.82758300, 10.1608714, 3.75679389, 13.0039674, 14.1460142, + 22.8717424, 26.4330877, 28.1219157, 9.76798039, 21.5785018, + 17.3938256, 3.26906521, 1.52052368, 2.49182000, -2.66593706, + -3.43908696, -9.68704871, -4.86722976, -19.0731875, -6.95466478, + -13.7939251, -17.7481762, -19.3744501, -15.8776985, -5.52879248, + -0.385065298, 0.746470000, 3.29037774, 16.9013608, 6.75156506, + 25.8510797, 8.61786241, 34.6271852, 13.0414523, 6.95711783, + 7.97115128, 16.3568998, 9.51476285, 4.49341909, 4.15414300}; + auto output2 = cep2.apply(input2); + // Implementation should match with matlab for Test case 2. 
+ ASSERT_TRUE(compareVec(output2, matlaboutput2)); } TEST(CeplifterTest, batchingTest) { - int N = 16, B = 15; - auto input = randVec(N * B); - auto cep = Ceplifter(N, 25); - auto output = cep.apply(input); - ASSERT_EQ(output.size(), input.size()); - for (int i = 0; i < B; ++i) { - std::vector curInput(N), expOutput(N); - std::copy( - input.data() + i * N, input.data() + (i + 1) * N, curInput.data()); - std::copy( - output.data() + i * N, output.data() + (i + 1) * N, expOutput.data()); - auto curOutput = cep.apply(curInput); - ASSERT_TRUE(compareVec(curOutput, expOutput, 1E-10)); - } + int N = 16, B = 15; + auto input = randVec(N * B); + auto cep = Ceplifter(N, 25); + auto output = cep.apply(input); + ASSERT_EQ(output.size(), input.size()); + for(int i = 0; i < B; ++i) { + std::vector curInput(N), expOutput(N); + std::copy( + input.data() + i * N, + input.data() + (i + 1) * N, + curInput.data() + ); + std::copy( + output.data() + i * N, + output.data() + (i + 1) * N, + expOutput.data() + ); + auto curOutput = cep.apply(curInput); + ASSERT_TRUE(compareVec(curOutput, expOutput, 1E-10)); + } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/audio/DctTest.cpp b/flashlight/pkg/speech/test/audio/DctTest.cpp index 126e5e6..cf6cdae 100644 --- a/flashlight/pkg/speech/test/audio/DctTest.cpp +++ b/flashlight/pkg/speech/test/audio/DctTest.cpp @@ -14,56 +14,62 @@ using fl::lib::audio::Dct; // Matlab code used: // dctm = @( N, M )( sqrt(2.0/M) * cos( repmat([0:N-1].',1,M) ... 
-// .* repmat(pi*([1:M]-0.5)/M,N,1) ) ); +// .* repmat(pi*([1:M]-0.5)/M,N,1) ) ); // DCT * IN; // Reference: Kamil Wojcicki, HTK MFCC MATLAB, URL: -// https://www.mathworks.com/matlabcentral/fileexchange/32849-htk-mfcc-matlab +// https://www.mathworks.com/matlabcentral/fileexchange/32849-htk-mfcc-matlab TEST(DctTest, matlabCompareTest) { - // Test Case: 1 - Dct dct1(9, 6); - std::vector input1(9, 1.0); - std::vector matlaboutput1{4.24264, 0.0, 0.0, 0.0, 0.0, 0.0}; - auto output1 = dct1.apply(input1); - // Implementation should match with matlab for Test case 1. - ASSERT_TRUE(compareVec(output1, matlaboutput1)); + // Test Case: 1 + Dct dct1(9, 6); + std::vector input1(9, 1.0); + std::vector matlaboutput1{4.24264, 0.0, 0.0, 0.0, 0.0, 0.0}; + auto output1 = dct1.apply(input1); + // Implementation should match with matlab for Test case 1. + ASSERT_TRUE(compareVec(output1, matlaboutput1)); - // Test Case: 2 - Dct dct2(40, 23); - std::vector input2{ - 3.827583, 3.975999, 0.9343630, 2.448821, 2.227931, 3.231565, 3.546824, - 3.773433, 1.380125, 3.398513, 3.275490, 0.8130586, 0.5949884, 2.491820, - 4.798719, 1.701928, 2.926338, 1.119059, 3.756335, 1.275475, 2.529785, - 3.495383, 4.454516, 4.796457, 2.736077, 0.6931222, 0.7464700, 1.287541, - 4.203586, 1.271410, 4.071424, 1.217624, 4.646318, 1.749918, 0.9829762, - 1.255419, 3.080223, 2.366444, 1.758297, 4.154143}; - std::vector matlaboutput2{ - 23.03049, 0.7171224, 0.09039740, 0.5560513, 1.210070, -0.6701894, - -0.7615307, 0.1116579, 1.157483, -2.012746, 2.964205, 2.444191, - -0.4926429, -0.1332636, 1.275104, 0.2767147, 0.2781188, 2.661390, - -0.03644234, -2.326455, -0.1963445, -1.229159, 2.124846}; - auto output2 = dct2.apply(input2); - // Implementation should match with matlab for Test case 2. 
- ASSERT_TRUE(compareVec(output2, matlaboutput2)); + // Test Case: 2 + Dct dct2(40, 23); + std::vector input2{ + 3.827583, 3.975999, 0.9343630, 2.448821, 2.227931, 3.231565, 3.546824, + 3.773433, 1.380125, 3.398513, 3.275490, 0.8130586, 0.5949884, 2.491820, + 4.798719, 1.701928, 2.926338, 1.119059, 3.756335, 1.275475, 2.529785, + 3.495383, 4.454516, 4.796457, 2.736077, 0.6931222, 0.7464700, 1.287541, + 4.203586, 1.271410, 4.071424, 1.217624, 4.646318, 1.749918, 0.9829762, + 1.255419, 3.080223, 2.366444, 1.758297, 4.154143}; + std::vector matlaboutput2{ + 23.03049, 0.7171224, 0.09039740, 0.5560513, 1.210070, -0.6701894, + -0.7615307, 0.1116579, 1.157483, -2.012746, 2.964205, 2.444191, + -0.4926429, -0.1332636, 1.275104, 0.2767147, 0.2781188, 2.661390, + -0.03644234, -2.326455, -0.1963445, -1.229159, 2.124846}; + auto output2 = dct2.apply(input2); + // Implementation should match with matlab for Test case 2. + ASSERT_TRUE(compareVec(output2, matlaboutput2)); } TEST(DctTest, batchingTest) { - int F = 16, C = 10, B = 15; - auto input = randVec(F * B); - auto dct = Dct(F, C); - auto output = dct.apply(input); - ASSERT_EQ(output.size(), C * B); - for (int i = 0; i < B; ++i) { - std::vector curInput(F), expOutput(C); - std::copy( - input.data() + i * F, input.data() + (i + 1) * F, curInput.data()); - std::copy( - output.data() + i * C, output.data() + (i + 1) * C, expOutput.data()); - auto curOutput = dct.apply(curInput); - ASSERT_TRUE(compareVec(curOutput, expOutput, 1E-5)); - } + int F = 16, C = 10, B = 15; + auto input = randVec(F * B); + auto dct = Dct(F, C); + auto output = dct.apply(input); + ASSERT_EQ(output.size(), C * B); + for(int i = 0; i < B; ++i) { + std::vector curInput(F), expOutput(C); + std::copy( + input.data() + i * F, + input.data() + (i + 1) * F, + curInput.data() + ); + std::copy( + output.data() + i * C, + output.data() + (i + 1) * C, + expOutput.data() + ); + auto curOutput = dct.apply(curInput); + ASSERT_TRUE(compareVec(curOutput, expOutput, 
1E-5)); + } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/audio/DerivativesTest.cpp b/flashlight/pkg/speech/test/audio/DerivativesTest.cpp index d794a32..0b25fb0 100644 --- a/flashlight/pkg/speech/test/audio/DerivativesTest.cpp +++ b/flashlight/pkg/speech/test/audio/DerivativesTest.cpp @@ -14,103 +14,103 @@ using fl::lib::audio::Derivatives; // Reference C++ code taken from HTK - http://htk.eng.cam.ac.uk/ -// float *fp,*fp1,*fp2, *back, *forw; -// float sum, sigmaT2; -// int i,t,j; +// float *fp,*fp1,*fp2, *back, *forw; +// float sum, sigmaT2; +// int i,t,j; // -// sigmaT2 = 0.0; -// for (t=1;t<=delwin;t++) -// sigmaT2 += t*t; -// sigmaT2 *= 2.0; -// fp = data; -// for (i=1;i<=n;i++){ -// fp1 = fp; fp2 = fp+offset; -// for (j=1;j<=vSize;j++){ -// back = forw = fp1; sum = 0.0; -// for (t=1;t<=delwin;t++) { -// if (head+i-t > 0) back -= step; -// if (tail+n-i+1-t > 0) forw += step; -// if (!simpleDiffs) sum += t * (*forw - *back); -// } -// if (simpleDiffs) -// *fp2 = (*forw - *back) / (2*delwin); -// else -// *fp2 = sum / sigmaT2; -// ++fp1; ++fp2; -// } -// fp += step; -// } +// sigmaT2 = 0.0; +// for (t=1;t<=delwin;t++) +// sigmaT2 += t*t; +// sigmaT2 *= 2.0; +// fp = data; +// for (i=1;i<=n;i++){ +// fp1 = fp; fp2 = fp+offset; +// for (j=1;j<=vSize;j++){ +// back = forw = fp1; sum = 0.0; +// for (t=1;t<=delwin;t++) { +// if (head+i-t > 0) back -= step; +// if (tail+n-i+1-t > 0) forw += step; +// if (!simpleDiffs) sum += t * (*forw - *back); +// } +// if (simpleDiffs) +// *fp2 = (*forw - *back) / (2*delwin); +// else +// *fp2 = sum / sigmaT2; +// ++fp1; ++fp2; +// } +// fp += step; +// } TEST(DerivativesTest, matlabCompareTest) { - // Test Case: 1 - Derivatives dev1(4, 4); - std::vector input1(12); - std::iota(input1.begin(), input1.end(), 0.0); - std::vector matlaboutput1{ - 0.0000, 1.0000, 2.0000, 
3.0000, 4.0000, 5.0000, - 6.0000, 7.0000, 8.0000, 9.0000, 10.0000, 11.0000, - 0.5000000, 0.6666667, 0.8166667, 0.9333333, 1.0000000, 1.0000000, - 1.0000000, 1.0000000, 0.9333333, 0.8166667, 0.6666667, 0.5000000, - 0.0683333, 0.0780556, 0.0794444, 0.0725000, 0.0527778, 0.0180556, - -0.0180556, -0.0527778, -0.0725000, -0.0794444, -0.0780556, -0.0683333}; - auto output1 = dev1.apply(input1, 1); - // Implementation should match with matlab for Test case 1. - ASSERT_TRUE(compareVec(output1, transposeVec(matlaboutput1, 3, 12))); + // Test Case: 1 + Derivatives dev1(4, 4); + std::vector input1(12); + std::iota(input1.begin(), input1.end(), 0.0); + std::vector matlaboutput1{ + 0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, + 6.0000, 7.0000, 8.0000, 9.0000, 10.0000, 11.0000, + 0.5000000, 0.6666667, 0.8166667, 0.9333333, 1.0000000, 1.0000000, + 1.0000000, 1.0000000, 0.9333333, 0.8166667, 0.6666667, 0.5000000, + 0.0683333, 0.0780556, 0.0794444, 0.0725000, 0.0527778, 0.0180556, + -0.0180556, -0.0527778, -0.0725000, -0.0794444, -0.0780556, -0.0683333}; + auto output1 = dev1.apply(input1, 1); + // Implementation should match with matlab for Test case 1. 
+ ASSERT_TRUE(compareVec(output1, transposeVec(matlaboutput1, 3, 12))); - // Test Case: 2 - Derivatives dev2(9, 7); - std::vector input2{ - 3.827583, 3.975999, 0.9343630, 2.448821, 2.227931, 3.231565, 3.546824, - 3.773433, 1.380125, 3.398513, 3.275490, 0.8130586, 0.5949884, 2.491820, - 4.798719, 1.701928, 2.926338, 1.119059, 3.756335, 1.275475, 2.529785, - 3.495383, 4.454516, 4.796457, 2.736077, 0.6931222, 0.7464700, 1.287541, - 4.203586, 1.271410, 4.071424, 1.217624, 4.646318, 1.749918, 0.9829762, - 1.255419, 3.080223, 2.366444, 1.758297, 4.154143}; - std::vector matlaboutput2{ - 3.827583, 3.975999, 0.9343630, 2.448821, 2.227931, 3.231565, - 3.546824, 3.773433, 1.380125, 3.398513, 3.275490, 0.8130586, - 0.5949884, 2.491820, 4.798719, 1.701928, 2.926338, 1.119059, - 3.756335, 1.275475, 2.529785, 3.495383, 4.454516, 4.796457, - 2.736077, 0.6931222, 0.7464700, 1.287541, 4.203586, 1.271410, - 4.071424, 1.217624, 4.646318, 1.749918, 0.9829762, 1.255419, - 3.080223, 2.366444, 1.758297, 4.154143, -0.0783472, -0.0703440, - -0.1002527, -0.1283159, -0.1207580, -0.0744319, -0.0787063, -0.0599186, - -0.0680858, -0.0298600, -0.0306807, -0.0046153, -0.0141285, 0.0135790, - 0.0392915, 0.0455732, 0.0259977, 0.0162468, -0.0216384, 0.0220920, - 0.0159542, 0.0143425, -0.0418714, -0.0117627, 0.0093056, -0.0307167, - -0.0436951, -0.0566360, -0.0380197, -0.0700912, -0.0431751, -0.0021685, - 0.0545093, 0.1177130, 0.1458966, 0.1357510, 0.1204694, 0.1087019, - 0.1430639, 0.1260710, -0.0007462, -0.0001886, 0.0012880, 0.0025709, - 0.0043352, 0.0055983, 0.0073248, 0.0093658, 0.0111437, 0.0122183, - 0.0118505, 0.0093177, 0.0077131, 0.0067685, 0.0053387, 0.0027080, - 0.0005322, -0.0002259, -0.0021479, -0.0036494, -0.0056067, -0.0061552, - -0.0065865, -0.0057748, -0.0041803, -0.0013468, 0.0018477, 0.0064985, - 0.0102782, 0.0132019, 0.0138463, 0.0156723, 0.0171224, 0.0170120, - 0.0159708, 0.0139536, 0.0118158, 0.0081756, 0.0046038, 0.0015992}; - auto output2 = dev2.apply(input2, 1); - // 
Implementation should match with matlab for Test case 2. - ASSERT_TRUE(compareVec(output2, transposeVec(matlaboutput2, 3, 40))); + // Test Case: 2 + Derivatives dev2(9, 7); + std::vector input2{ + 3.827583, 3.975999, 0.9343630, 2.448821, 2.227931, 3.231565, 3.546824, + 3.773433, 1.380125, 3.398513, 3.275490, 0.8130586, 0.5949884, 2.491820, + 4.798719, 1.701928, 2.926338, 1.119059, 3.756335, 1.275475, 2.529785, + 3.495383, 4.454516, 4.796457, 2.736077, 0.6931222, 0.7464700, 1.287541, + 4.203586, 1.271410, 4.071424, 1.217624, 4.646318, 1.749918, 0.9829762, + 1.255419, 3.080223, 2.366444, 1.758297, 4.154143}; + std::vector matlaboutput2{ + 3.827583, 3.975999, 0.9343630, 2.448821, 2.227931, 3.231565, + 3.546824, 3.773433, 1.380125, 3.398513, 3.275490, 0.8130586, + 0.5949884, 2.491820, 4.798719, 1.701928, 2.926338, 1.119059, + 3.756335, 1.275475, 2.529785, 3.495383, 4.454516, 4.796457, + 2.736077, 0.6931222, 0.7464700, 1.287541, 4.203586, 1.271410, + 4.071424, 1.217624, 4.646318, 1.749918, 0.9829762, 1.255419, + 3.080223, 2.366444, 1.758297, 4.154143, -0.0783472, -0.0703440, + -0.1002527, -0.1283159, -0.1207580, -0.0744319, -0.0787063, -0.0599186, + -0.0680858, -0.0298600, -0.0306807, -0.0046153, -0.0141285, 0.0135790, + 0.0392915, 0.0455732, 0.0259977, 0.0162468, -0.0216384, 0.0220920, + 0.0159542, 0.0143425, -0.0418714, -0.0117627, 0.0093056, -0.0307167, + -0.0436951, -0.0566360, -0.0380197, -0.0700912, -0.0431751, -0.0021685, + 0.0545093, 0.1177130, 0.1458966, 0.1357510, 0.1204694, 0.1087019, + 0.1430639, 0.1260710, -0.0007462, -0.0001886, 0.0012880, 0.0025709, + 0.0043352, 0.0055983, 0.0073248, 0.0093658, 0.0111437, 0.0122183, + 0.0118505, 0.0093177, 0.0077131, 0.0067685, 0.0053387, 0.0027080, + 0.0005322, -0.0002259, -0.0021479, -0.0036494, -0.0056067, -0.0061552, + -0.0065865, -0.0057748, -0.0041803, -0.0013468, 0.0018477, 0.0064985, + 0.0102782, 0.0132019, 0.0138463, 0.0156723, 0.0171224, 0.0170120, + 0.0159708, 0.0139536, 0.0118158, 0.0081756, 0.0046038, 
0.0015992}; + auto output2 = dev2.apply(input2, 1); + // Implementation should match with matlab for Test case 2. + ASSERT_TRUE(compareVec(output2, transposeVec(matlaboutput2, 3, 40))); } TEST(DerivativesTest, batchingTest) { - int numFeat = 60, frameSz = 20; - auto input = randVec(numFeat * frameSz); - Derivatives dev(6, 7); - auto output = dev.apply(input, numFeat); - ASSERT_EQ(output.size(), input.size() * 3); - for (int i = 0; i < numFeat; ++i) { - std::vector curInput(frameSz), expOutput(frameSz * 3); - for (int j = 0; j < frameSz; ++j) { - curInput[j] = input[j * numFeat + i]; - expOutput[j * 3] = output[j * numFeat * 3 + i]; - expOutput[j * 3 + 1] = output[j * numFeat * 3 + numFeat + i]; - expOutput[j * 3 + 2] = output[j * numFeat * 3 + 2 * numFeat + i]; + int numFeat = 60, frameSz = 20; + auto input = randVec(numFeat * frameSz); + Derivatives dev(6, 7); + auto output = dev.apply(input, numFeat); + ASSERT_EQ(output.size(), input.size() * 3); + for(int i = 0; i < numFeat; ++i) { + std::vector curInput(frameSz), expOutput(frameSz * 3); + for(int j = 0; j < frameSz; ++j) { + curInput[j] = input[j * numFeat + i]; + expOutput[j * 3] = output[j * numFeat * 3 + i]; + expOutput[j * 3 + 1] = output[j * numFeat * 3 + numFeat + i]; + expOutput[j * 3 + 2] = output[j * numFeat * 3 + 2 * numFeat + i]; + } + auto curOutput = dev.apply(curInput, 1); + ASSERT_TRUE(compareVec(curOutput, expOutput, 1E-4)); } - auto curOutput = dev.apply(curInput, 1); - ASSERT_TRUE(compareVec(curOutput, expOutput, 1E-4)); - } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/audio/DitherTest.cpp b/flashlight/pkg/speech/test/audio/DitherTest.cpp index 9cc7abd..d5eeeab 100644 --- a/flashlight/pkg/speech/test/audio/DitherTest.cpp +++ b/flashlight/pkg/speech/test/audio/DitherTest.cpp @@ -15,40 +15,40 @@ using fl::lib::audio::Dither; 
TEST(DitherTest, basicTest) { - int N = 1000; - - for (int bch = 1; bch <= 8; bch *= 2) { - Dither ditherpos(0.01); - auto input = randVec(N * bch); - auto output = ditherpos.apply(input); - // Dithering should change input slightly. - ASSERT_FALSE(compareVec(output, input, 1E-5)); - - Dither ditherpos2(0.01); - auto output2 = ditherpos2.apply(input); - // Dither constant > 0 should give same result in multiple runs - ASSERT_TRUE(compareVec(output, output2, 1E-5)); - } - - for (int bch = 1; bch <= 8; bch *= 2) { - Dither ditherneg(-0.01); - auto input = randVec(N * bch); - auto output = ditherneg.apply(input); - // Dithering should change input slightly. - ASSERT_FALSE(compareVec(output, input, 1E-6)); - - // time(NULL) resolution is in seconds - std::chrono::seconds dura(2); - std::this_thread::sleep_for(dura); - - Dither ditherneg2(-0.01); - auto output2 = ditherneg2.apply(input); - // Dithering should change input slightly. - ASSERT_FALSE(compareVec(output, input, 1E-6)); - } + int N = 1000; + + for(int bch = 1; bch <= 8; bch *= 2) { + Dither ditherpos(0.01); + auto input = randVec(N * bch); + auto output = ditherpos.apply(input); + // Dithering should change input slightly. + ASSERT_FALSE(compareVec(output, input, 1E-5)); + + Dither ditherpos2(0.01); + auto output2 = ditherpos2.apply(input); + // Dither constant > 0 should give same result in multiple runs + ASSERT_TRUE(compareVec(output, output2, 1E-5)); + } + + for(int bch = 1; bch <= 8; bch *= 2) { + Dither ditherneg(-0.01); + auto input = randVec(N * bch); + auto output = ditherneg.apply(input); + // Dithering should change input slightly. + ASSERT_FALSE(compareVec(output, input, 1E-6)); + + // time(NULL) resolution is in seconds + std::chrono::seconds dura(2); + std::this_thread::sleep_for(dura); + + Dither ditherneg2(-0.01); + auto output2 = ditherneg2.apply(input); + // Dithering should change input slightly. 
+ ASSERT_FALSE(compareVec(output, input, 1E-6)); + } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/audio/MfccTest.cpp b/flashlight/pkg/speech/test/audio/MfccTest.cpp index 08371cf..310a26b 100644 --- a/flashlight/pkg/speech/test/audio/MfccTest.cpp +++ b/flashlight/pkg/speech/test/audio/MfccTest.cpp @@ -23,189 +23,190 @@ namespace { fs::path loadPath = ""; auto loadData = [](const std::string& filepath) { - std::vector data; - std::ifstream file(filepath); - std::istream_iterator eos; - std::istream_iterator iit(file); - std::copy(iit, eos, std::back_inserter(data)); - return data; -}; + std::vector data; + std::ifstream file(filepath); + std::istream_iterator eos; + std::istream_iterator iit(file); + std::copy(iit, eos, std::back_inserter(data)); + return data; + }; } // namespace // HTK Code used - -// HCopy -C config.mfcc sa1.wav sa1-mfcc.htk +// HCopy -C config.mfcc sa1.wav sa1-mfcc.htk // Reference : https://labrosa.ee.columbia.edu/matlab/rastamat/mfccs.html TEST(MfccTest, htkCompareTest) { - // read wav data - auto wavinput = loadData(loadPath / "sa1.dat"); - ASSERT_TRUE(!wavinput.empty() && "sa1 frames not loaded properly!"); - - // read expected output data computed from HTK - auto htkfeat = loadData(loadPath / "sa1-mfcc.htk"); - // HTK features not read properly! 
- ASSERT_TRUE(!htkfeat.empty()); - FeatureParams params; - params.samplingFreq = 16000; - params.lowFreqFilterbank = 0; - params.highFreqFilterbank = 8000; - params.zeroMeanFrame = true; - params.numFilterbankChans = 20; - params.numCepstralCoeffs = 13; - params.useEnergy = false; - params.zeroMeanFrame = false; - params.usePower = false; - Mfcc mfcc(params); - auto feat = mfcc.apply(wavinput); - ASSERT_EQ(feat.size(), htkfeat.size()); - - ASSERT_TRUE(feat.size() % 39 == 0); - auto numframes = feat.size() / 39; - - // HTK keeps C0 at last position. adjust accordingly. - auto featcopy(feat); - for (int f = 0; f < numframes; ++f) { - for (int i = 1; i < 39; ++i) { - feat[f * 39 + i - 1] = feat[f * 39 + i]; + // read wav data + auto wavinput = loadData(loadPath / "sa1.dat"); + ASSERT_TRUE(!wavinput.empty() && "sa1 frames not loaded properly!"); + + // read expected output data computed from HTK + auto htkfeat = loadData(loadPath / "sa1-mfcc.htk"); + // HTK features not read properly! + ASSERT_TRUE(!htkfeat.empty()); + FeatureParams params; + params.samplingFreq = 16000; + params.lowFreqFilterbank = 0; + params.highFreqFilterbank = 8000; + params.zeroMeanFrame = true; + params.numFilterbankChans = 20; + params.numCepstralCoeffs = 13; + params.useEnergy = false; + params.zeroMeanFrame = false; + params.usePower = false; + Mfcc mfcc(params); + auto feat = mfcc.apply(wavinput); + ASSERT_EQ(feat.size(), htkfeat.size()); + + ASSERT_TRUE(feat.size() % 39 == 0); + auto numframes = feat.size() / 39; + + // HTK keeps C0 at last position. adjust accordingly. 
+ auto featcopy(feat); + for(int f = 0; f < numframes; ++f) { + for(int i = 1; i < 39; ++i) { + feat[f * 39 + i - 1] = feat[f * 39 + i]; + } + feat[f * 39 + 12] = featcopy[f * 39 + 0]; + feat[f * 39 + 25] = featcopy[f * 39 + 13]; + feat[f * 39 + 38] = featcopy[f * 39 + 26]; } - feat[f * 39 + 12] = featcopy[f * 39 + 0]; - feat[f * 39 + 25] = featcopy[f * 39 + 13]; - feat[f * 39 + 38] = featcopy[f * 39 + 26]; - } - float sum = 0.0, max = 0.0; - for (int i = 0; i < feat.size(); ++i) { - auto curdiff = std::abs(feat[i] - htkfeat[i]); - sum += curdiff; - if (max < curdiff) { - max = curdiff; + float sum = 0.0, max = 0.0; + for(int i = 0; i < feat.size(); ++i) { + auto curdiff = std::abs(feat[i] - htkfeat[i]); + sum += curdiff; + if(max < curdiff) { + max = curdiff; + } } - } - std::cerr << "| Max diff across all dimensions " << max << "\n"; // 0.325853 + std::cerr << "| Max diff across all dimensions " << max << "\n"; // 0.325853 - std::cerr << "| Avg diff across all dimensions " << sum / feat.size() - << "\n"; // 0.00252719 + std::cerr << "| Avg diff across all dimensions " << sum / feat.size() + << "\n"; // 0.00252719 } TEST(MfccTest, BatchingTest) { - int Tmax = 10000; - auto input = randVec(Tmax); - FeatureParams featparams; - featparams.deltaWindow = 0; - featparams.frameSizeMs = 25; - std::vector energies = {true, false}; - std::vector rawEnergies = {true, false}; - std::vector zMeans = {true, false}; - std::vector usePow = {true, false}; - - int numTrials = 3; - for (auto e : energies) { - for (auto r : rawEnergies) { - for (auto z : zMeans) { - for (auto p : usePow) { - featparams.useEnergy = e; - featparams.rawEnergy = r; - featparams.zeroMeanFrame = z; - featparams.usePower = p; - - Mfcc mfcc(featparams); - - auto output = mfcc.apply(input); - for (int i = 0; i < numTrials; ++i) { - int chunkSz = 500 + (1000 * i) % 5000, curSz = 0; - while (curSz + chunkSz < Tmax) { - curSz += chunkSz; - std::vector curInput(curSz); - std::copy(input.begin(), input.begin() + 
curSz, curInput.begin()); - auto curOutput = mfcc.apply(curInput); - ASSERT_GT(curOutput.size(), 0); - for (int j = 0; j < curOutput.size(); ++j) { - ASSERT_NEAR(curOutput[j], output[j], 1E-4); - } + int Tmax = 10000; + auto input = randVec(Tmax); + FeatureParams featparams; + featparams.deltaWindow = 0; + featparams.frameSizeMs = 25; + std::vector energies = {true, false}; + std::vector rawEnergies = {true, false}; + std::vector zMeans = {true, false}; + std::vector usePow = {true, false}; + + int numTrials = 3; + for(auto e : energies) { + for(auto r : rawEnergies) { + for(auto z : zMeans) { + for(auto p : usePow) { + featparams.useEnergy = e; + featparams.rawEnergy = r; + featparams.zeroMeanFrame = z; + featparams.usePower = p; + + Mfcc mfcc(featparams); + + auto output = mfcc.apply(input); + for(int i = 0; i < numTrials; ++i) { + int chunkSz = 500 + (1000 * i) % 5000, curSz = 0; + while(curSz + chunkSz < Tmax) { + curSz += chunkSz; + std::vector curInput(curSz); + std::copy(input.begin(), input.begin() + curSz, curInput.begin()); + auto curOutput = mfcc.apply(curInput); + ASSERT_GT(curOutput.size(), 0); + for(int j = 0; j < curOutput.size(); ++j) { + ASSERT_NEAR(curOutput[j], output[j], 1E-4); + } + } + } + } } - } } - } } - } } TEST(MfccTest, BatchingTest2) { - int Tmax = 10000; - int batchSz = 100; - auto input = randVec(Tmax); - FeatureParams featparams; - featparams.frameSizeMs = 25; - std::vector energies = {true, false}; - std::vector rawEnergies = {true, false}; - std::vector zMeans = {true, false}; - std::vector usePow = {true, false}; - - for (auto e : energies) { - for (auto r : rawEnergies) { - for (auto z : zMeans) { - for (auto p : usePow) { - featparams.useEnergy = e; - featparams.rawEnergy = r; - featparams.zeroMeanFrame = z; - featparams.usePower = p; - - Mfcc mfcc(featparams); - - auto output = mfcc.batchApply(input, batchSz); - - auto perBatchOutSz = output.size() / batchSz; - auto perBatchInSz = input.size() / batchSz; - for (int i = 0; i < 
batchSz; ++i) { - std::vector curInput(perBatchInSz); - std::copy( - input.begin() + i * perBatchInSz, - input.begin() + (i + 1) * perBatchInSz, - curInput.begin()); - auto curOutput = mfcc.apply(curInput); - ASSERT_EQ(curOutput.size(), perBatchOutSz); - for (int j = 0; j < curOutput.size(); ++j) { - ASSERT_NEAR(curOutput[j], output[j + i * perBatchOutSz], 1E-4); + int Tmax = 10000; + int batchSz = 100; + auto input = randVec(Tmax); + FeatureParams featparams; + featparams.frameSizeMs = 25; + std::vector energies = {true, false}; + std::vector rawEnergies = {true, false}; + std::vector zMeans = {true, false}; + std::vector usePow = {true, false}; + + for(auto e : energies) { + for(auto r : rawEnergies) { + for(auto z : zMeans) { + for(auto p : usePow) { + featparams.useEnergy = e; + featparams.rawEnergy = r; + featparams.zeroMeanFrame = z; + featparams.usePower = p; + + Mfcc mfcc(featparams); + + auto output = mfcc.batchApply(input, batchSz); + + auto perBatchOutSz = output.size() / batchSz; + auto perBatchInSz = input.size() / batchSz; + for(int i = 0; i < batchSz; ++i) { + std::vector curInput(perBatchInSz); + std::copy( + input.begin() + i * perBatchInSz, + input.begin() + (i + 1) * perBatchInSz, + curInput.begin() + ); + auto curOutput = mfcc.apply(curInput); + ASSERT_EQ(curOutput.size(), perBatchOutSz); + for(int j = 0; j < curOutput.size(); ++j) { + ASSERT_NEAR(curOutput[j], output[j + i * perBatchOutSz], 1E-4); + } + } + } } - } } - } } - } } TEST(MfccTest, EmptyTest) { - std::vector input; - FeatureParams featparams; - Mfcc mfcc(featparams); - auto output = mfcc.apply(input); - ASSERT_TRUE(output.empty()); - - int Tmax = 500; - for (int t = 1; t <= Tmax; ++t) { - input = randVec(Tmax); - output = mfcc.apply(input); - ASSERT_TRUE(output.size() >= 0); - } + std::vector input; + FeatureParams featparams; + Mfcc mfcc(featparams); + auto output = mfcc.apply(input); + ASSERT_TRUE(output.empty()); + + int Tmax = 500; + for(int t = 1; t <= Tmax; ++t) { + input = 
randVec(Tmax); + output = mfcc.apply(input); + ASSERT_TRUE(output.size() >= 0); + } } TEST(MfccTest, ZeroInputTest) { - auto params = FeatureParams(); - params.useEnergy = false; - Mfsc mfcc(params); - auto input = std::vector(10000, 0.0); - auto output = mfcc.apply(input); - for (auto o : output) { - ASSERT_NEAR(o, 0.0, 1E-4); - } + auto params = FeatureParams(); + params.useEnergy = false; + Mfsc mfcc(params); + auto input = std::vector(10000, 0.0); + auto output = mfcc.apply(input); + for(auto o : output) { + ASSERT_NEAR(o, 0.0, 1E-4); + } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); + ::testing::InitGoogleTest(&argc, argv); // Resolve directory for data #ifdef FEATURE_TEST_DATADIR - loadPath = fs::path(FEATURE_TEST_DATADIR); + loadPath = fs::path(FEATURE_TEST_DATADIR); #endif - return RUN_ALL_TESTS(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/audio/PreEmphasisTest.cpp b/flashlight/pkg/speech/test/audio/PreEmphasisTest.cpp index e27a1c2..131602f 100644 --- a/flashlight/pkg/speech/test/audio/PreEmphasisTest.cpp +++ b/flashlight/pkg/speech/test/audio/PreEmphasisTest.cpp @@ -15,62 +15,68 @@ using fl::lib::audio::PreEmphasis; // Matlab code used: B=[1, -0.95]; = filter(B, 1, data, [], 2); // For first element in data multiply by (1 - alpha) TEST(PreEmphasisTest, matlabCompareTest) { - int N = 8; - PreEmphasis preemphasis1d(0.95, N); - std::vector input{0.098589, - 0.715877, - 0.750572, - 0.787636, - 0.116829, - 0.242914, - 0.327526, - 0.410389}; - // ndim = 1 - std::vector matlaboutput1d{0.004929, - 0.622218, - 0.070489, - 0.074592, - -0.631425, - 0.131927, - 0.096757, - 0.099240}; - auto output1d = preemphasis1d.apply(input); - // Implementation should match with matlab. 
- ASSERT_TRUE(compareVec(output1d, matlaboutput1d)); + int N = 8; + PreEmphasis preemphasis1d(0.95, N); + std::vector input{0.098589, + 0.715877, + 0.750572, + 0.787636, + 0.116829, + 0.242914, + 0.327526, + 0.410389}; + // ndim = 1 + std::vector matlaboutput1d{0.004929, + 0.622218, + 0.070489, + 0.074592, + -0.631425, + 0.131927, + 0.096757, + 0.099240}; + auto output1d = preemphasis1d.apply(input); + // Implementation should match with matlab. + ASSERT_TRUE(compareVec(output1d, matlaboutput1d)); - // ndim = 2 - PreEmphasis preemphasis2d(0.95, N / 2); - std::vector matlaboutput2d{0.004929, - 0.622218, - 0.070489, - 0.074592, - 0.005841, - 0.131927, - 0.096757, - 0.099240}; - auto output2d = preemphasis2d.apply(input); - // Implementation should match with matlab. - ASSERT_TRUE(compareVec(output2d, matlaboutput2d)); + // ndim = 2 + PreEmphasis preemphasis2d(0.95, N / 2); + std::vector matlaboutput2d{0.004929, + 0.622218, + 0.070489, + 0.074592, + 0.005841, + 0.131927, + 0.096757, + 0.099240}; + auto output2d = preemphasis2d.apply(input); + // Implementation should match with matlab. 
+ ASSERT_TRUE(compareVec(output2d, matlaboutput2d)); } TEST(PreEmphasisTest, batchingTest) { - int N = 16, B = 15; - auto input = randVec(N * B); - auto preemphasis = PreEmphasis(0.5, N); - auto output = preemphasis.apply(input); - ASSERT_EQ(output.size(), input.size()); - for (int i = 0; i < B; ++i) { - std::vector curInput(N), expOutput(N); - std::copy( - input.data() + i * N, input.data() + (i + 1) * N, curInput.data()); - std::copy( - output.data() + i * N, output.data() + (i + 1) * N, expOutput.data()); - auto curOutput = preemphasis.apply(curInput); - ASSERT_TRUE(compareVec(curOutput, expOutput, 1E-10)); - } + int N = 16, B = 15; + auto input = randVec(N * B); + auto preemphasis = PreEmphasis(0.5, N); + auto output = preemphasis.apply(input); + ASSERT_EQ(output.size(), input.size()); + for(int i = 0; i < B; ++i) { + std::vector curInput(N), expOutput(N); + std::copy( + input.data() + i * N, + input.data() + (i + 1) * N, + curInput.data() + ); + std::copy( + output.data() + i * N, + output.data() + (i + 1) * N, + expOutput.data() + ); + auto curOutput = preemphasis.apply(curInput); + ASSERT_TRUE(compareVec(curOutput, expOutput, 1E-10)); + } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/audio/SpeechUtilsTest.cpp b/flashlight/pkg/speech/test/audio/SpeechUtilsTest.cpp index 2820a1a..51d22b7 100644 --- a/flashlight/pkg/speech/test/audio/SpeechUtilsTest.cpp +++ b/flashlight/pkg/speech/test/audio/SpeechUtilsTest.cpp @@ -13,21 +13,21 @@ using namespace fl::lib::audio; TEST(SpeechUtilsTest, SimpleMatmul) { - /* - A B - [ 2 3 4 ] [ 2 3 ] - [ 3 4 5 ], [ 3 4 ] - [ 4 5 6 ], [ 4 5 ] - [ 5 6 7 ], - */ - std::vector A = {2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7}; - std::vector B = {2, 3, 3, 4, 4, 5}; - auto op = cblasGemm(A, B, 2, 3); - std::vector expectedOp = {29, 38, 38, 50, 47, 62, 56, 74}; - 
EXPECT_TRUE(compareVec(op, expectedOp, 1E-10)); + /* + A B + [ 2 3 4 ] [ 2 3 ] + [ 3 4 5 ], [ 3 4 ] + [ 4 5 6 ], [ 4 5 ] + [ 5 6 7 ], + */ + std::vector A = {2, 3, 4, 3, 4, 5, 4, 5, 6, 5, 6, 7}; + std::vector B = {2, 3, 3, 4, 4, 5}; + auto op = cblasGemm(A, B, 2, 3); + std::vector expectedOp = {29, 38, 38, 50, 47, 62, 56, 74}; + EXPECT_TRUE(compareVec(op, expectedOp, 1E-10)); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/audio/TestUtils.h b/flashlight/pkg/speech/test/audio/TestUtils.h index 893de36..3826a11 100644 --- a/flashlight/pkg/speech/test/audio/TestUtils.h +++ b/flashlight/pkg/speech/test/audio/TestUtils.h @@ -11,36 +11,36 @@ #include #include -template +template bool compareVec(std::vector A, std::vector B, float precision = 1E-5) { - if (A.size() != B.size()) { - return false; - } - for (std::size_t i = 0; i < A.size(); ++i) { - if (std::abs(A[i] - B[i]) > precision) { - return false; + if(A.size() != B.size()) { + return false; } - } - return true; + for(std::size_t i = 0; i < A.size(); ++i) { + if(std::abs(A[i] - B[i]) > precision) { + return false; + } + } + return true; } -template +template std::vector randVec(std::size_t N, float min = -1.0, float max = 1.0) { - std::vector vec(N); - for (auto& v : vec) { - v = static_cast(rand()) / static_cast(RAND_MAX); - v = v * (max - min) + min; - } - return vec; + std::vector vec(N); + for(auto& v : vec) { + v = static_cast(rand()) / static_cast(RAND_MAX); + v = v * (max - min) + min; + } + return vec; } -template +template std::vector transposeVec(const std::vector& in, int inRow, int inCol) { - std::vector out(inRow * inCol); - for (size_t r = 0; r < inRow; ++r) { - for (size_t c = 0; c < inCol; ++c) { - out[c * inRow + r] = in[r * inCol + c]; + std::vector out(inRow * inCol); + for(size_t r = 0; r < inRow; ++r) { + for(size_t c = 0; c < 
inCol; ++c) { + out[c * inRow + r] = in[r * inCol + c]; + } } - } - return out; + return out; } diff --git a/flashlight/pkg/speech/test/audio/TriFilterbankTest.cpp b/flashlight/pkg/speech/test/audio/TriFilterbankTest.cpp index 1c3ce97..5bcf84e 100644 --- a/flashlight/pkg/speech/test/audio/TriFilterbankTest.cpp +++ b/flashlight/pkg/speech/test/audio/TriFilterbankTest.cpp @@ -17,65 +17,67 @@ using fl::lib::audio::TriFilterbank; // H = trifbank( M, K, R, fs, hz2mel, mel2hz ); % size of H is M x K // FBE = H * MAG(1:K,:); // Reference: Kamil Wojcicki, HTK MFCC MATLAB, URL: -// https://www.mathworks.com/matlabcentral/fileexchange/32849-htk-mfcc-matlab +// https://www.mathworks.com/matlabcentral/fileexchange/32849-htk-mfcc-matlab TEST(TriFilterbankTest, matlabCompareTest) { - // Test Case: 1 - TriFilterbank triflt1(10, 9, 20000, 0, 10000, FrequencyScale::MEL); - std::vector matlabfbank1{ - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0.881121, 0.118879, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0.882891, 0.117109, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0.569722, 0.430278, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0.571075, 0.428925, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0.763933, 0.236067, 0, 0, 0, 0, 0, 0, - 0, 0, 0.082177, 0.917823, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0.532067, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0}; - auto outputfbank1 = triflt1.filterbank(); - // Implementation should match with matlab for Test case 1. 
- ASSERT_TRUE(compareVec(outputfbank1, matlabfbank1)); + // Test Case: 1 + TriFilterbank triflt1(10, 9, 20000, 0, 10000, FrequencyScale::MEL); + std::vector matlabfbank1{ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0.881121, 0.118879, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0.882891, 0.117109, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0.569722, 0.430278, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0.571075, 0.428925, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0.763933, 0.236067, 0, 0, 0, 0, 0, 0, + 0, 0, 0.082177, 0.917823, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0.532067, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0}; + auto outputfbank1 = triflt1.filterbank(); + // Implementation should match with matlab for Test case 1. + ASSERT_TRUE(compareVec(outputfbank1, matlabfbank1)); - // Test Case: 2 - TriFilterbank triflt2(23, 33, 8000, 300, 3700, FrequencyScale::MEL); - std::vector input2{ - 0.0461713, 0.0971317, 0.823457, 0.694828, 0.317099, 0.950222, 0.0344460, - 0.438744, 0.381558, 0.765516, 0.795199, 0.186872, 0.489764, 0.445586, - 0.646313, 0.709364, 0.754686, 0.276025, 0.679702, 0.655098, 0.162611, - 0.118997, 0.498364, 0.959743, 0.340385, 0.585267, 0.223811, 0.751267, - 0.255095, 0.505957, 0.699076, 0.890903, 0.959291}; - std::vector matlabop2{ - 0.578693, 0.131362, 0.301871, 0.426760, 0.523461, 0.0338169, - 0.285265, 0.311304, 0.424245, 0.714087, 0.680402, 0.267582, - 0.526783, 0.612373, 0.814208, 0.962699, 0.620225, 0.907083, - 0.326320, 0.879130, 1.07004, 0.844134, 0.957356}; + // Test Case: 2 + TriFilterbank triflt2(23, 33, 8000, 300, 3700, FrequencyScale::MEL); + std::vector input2{ + 0.0461713, 0.0971317, 0.823457, 0.694828, 0.317099, 0.950222, 0.0344460, + 0.438744, 0.381558, 0.765516, 0.795199, 0.186872, 0.489764, 0.445586, + 0.646313, 0.709364, 0.754686, 0.276025, 0.679702, 0.655098, 0.162611, + 0.118997, 0.498364, 0.959743, 0.340385, 0.585267, 0.223811, 0.751267, + 0.255095, 0.505957, 0.699076, 0.890903, 0.959291}; + std::vector matlabop2{ + 0.578693, 0.131362, 0.301871, 0.426760, 0.523461, 0.0338169, + 0.285265, 0.311304, 
0.424245, 0.714087, 0.680402, 0.267582, + 0.526783, 0.612373, 0.814208, 0.962699, 0.620225, 0.907083, + 0.326320, 0.879130, 1.07004, 0.844134, 0.957356}; - auto output2 = triflt2.apply(input2); - // Implementation should match with matlab for Test case 2. - ASSERT_TRUE(compareVec(output2, matlabop2)); + auto output2 = triflt2.apply(input2); + // Implementation should match with matlab for Test case 2. + ASSERT_TRUE(compareVec(output2, matlabop2)); } TEST(TriFilterbankTest, batchingTest) { - int numFilters = 16, filterLen = 10, B = 15; - auto input = randVec(filterLen * B); - auto triflt = TriFilterbank(numFilters, filterLen, 16000); - auto output = triflt.apply(input); - ASSERT_EQ(output.size(), numFilters * B); - for (int i = 0; i < B; ++i) { - std::vector curInput(filterLen), expOutput(numFilters); - std::copy( - input.data() + i * filterLen, - input.data() + (i + 1) * filterLen, - curInput.data()); - std::copy( - output.data() + i * numFilters, - output.data() + (i + 1) * numFilters, - expOutput.data()); - auto curOutput = triflt.apply(curInput); - ASSERT_TRUE(compareVec(curOutput, expOutput, 1E-5)); - } + int numFilters = 16, filterLen = 10, B = 15; + auto input = randVec(filterLen * B); + auto triflt = TriFilterbank(numFilters, filterLen, 16000); + auto output = triflt.apply(input); + ASSERT_EQ(output.size(), numFilters * B); + for(int i = 0; i < B; ++i) { + std::vector curInput(filterLen), expOutput(numFilters); + std::copy( + input.data() + i * filterLen, + input.data() + (i + 1) * filterLen, + curInput.data() + ); + std::copy( + output.data() + i * numFilters, + output.data() + (i + 1) * numFilters, + expOutput.data() + ); + auto curOutput = triflt.apply(curInput); + ASSERT_TRUE(compareVec(curOutput, expOutput, 1E-5)); + } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/audio/WindowingTest.cpp 
b/flashlight/pkg/speech/test/audio/WindowingTest.cpp index 1c77e5f..2750c7b 100644 --- a/flashlight/pkg/speech/test/audio/WindowingTest.cpp +++ b/flashlight/pkg/speech/test/audio/WindowingTest.cpp @@ -14,58 +14,64 @@ using fl::lib::audio::Windowing; using fl::lib::audio::WindowType; TEST(WindowingTest, hammingCoeffsTest) { - int N = 64; - auto hammwindow = Windowing(N, WindowType::HAMMING); - std::vector matlabcoeffs{ - 0.080000, 0.082286, 0.089121, 0.100437, 0.116121, 0.136018, 0.159930, - 0.187620, 0.218811, 0.253195, 0.290429, 0.330143, 0.371943, 0.415413, - 0.460122, 0.505624, 0.551468, 0.597198, 0.642360, 0.686504, 0.729192, - 0.770000, 0.808522, 0.844375, 0.877204, 0.906681, 0.932514, 0.954446, - 0.972259, 0.985776, 0.994862, 0.999428, 0.999428, 0.994862, 0.985776, - 0.972259, 0.954446, 0.932514, 0.906681, 0.877204, 0.844375, 0.808522, - 0.770000, 0.729192, 0.686504, 0.642360, 0.597198, 0.551468, 0.505624, - 0.460122, 0.415413, 0.371943, 0.330143, 0.290429, 0.253195, 0.218811, - 0.187620, 0.159930, 0.136018, 0.116121, 0.100437, 0.089121, 0.082286, - 0.080000, - }; - std::vector input(N, 1.0); - auto output = hammwindow.apply(input); - // Hamming window coefficients should match with matlab implementation. 
- ASSERT_TRUE(compareVec(output, matlabcoeffs)); + int N = 64; + auto hammwindow = Windowing(N, WindowType::HAMMING); + std::vector matlabcoeffs{ + 0.080000, 0.082286, 0.089121, 0.100437, 0.116121, 0.136018, 0.159930, + 0.187620, 0.218811, 0.253195, 0.290429, 0.330143, 0.371943, 0.415413, + 0.460122, 0.505624, 0.551468, 0.597198, 0.642360, 0.686504, 0.729192, + 0.770000, 0.808522, 0.844375, 0.877204, 0.906681, 0.932514, 0.954446, + 0.972259, 0.985776, 0.994862, 0.999428, 0.999428, 0.994862, 0.985776, + 0.972259, 0.954446, 0.932514, 0.906681, 0.877204, 0.844375, 0.808522, + 0.770000, 0.729192, 0.686504, 0.642360, 0.597198, 0.551468, 0.505624, + 0.460122, 0.415413, 0.371943, 0.330143, 0.290429, 0.253195, 0.218811, + 0.187620, 0.159930, 0.136018, 0.116121, 0.100437, 0.089121, 0.082286, + 0.080000, + }; + std::vector input(N, 1.0); + auto output = hammwindow.apply(input); + // Hamming window coefficients should match with matlab implementation. + ASSERT_TRUE(compareVec(output, matlabcoeffs)); } TEST(WindowingTest, hanningCoeffsTest) { - int N = 32; - auto hannwindow = Windowing(N, WindowType::HANNING); - std::vector matlabcoeffs{ - 0.00000, 0.01024, 0.04052, 0.08962, 0.15552, 0.23552, 0.32635, 0.42429, - 0.52532, 0.62533, 0.72020, 0.80605, 0.87938, 0.93717, 0.97707, 0.99743, - 0.99743, 0.97707, 0.93717, 0.87938, 0.80605, 0.72020, 0.62533, 0.52532, - 0.42429, 0.32635, 0.23552, 0.15552, 0.08962, 0.04052, 0.01024, 0.00000}; - std::vector input(N, 1.0); - auto output = hannwindow.apply(input); - // Hamming window coefficients should match with matlab implementation. 
- ASSERT_TRUE(compareVec(output, matlabcoeffs)); + int N = 32; + auto hannwindow = Windowing(N, WindowType::HANNING); + std::vector matlabcoeffs{ + 0.00000, 0.01024, 0.04052, 0.08962, 0.15552, 0.23552, 0.32635, 0.42429, + 0.52532, 0.62533, 0.72020, 0.80605, 0.87938, 0.93717, 0.97707, 0.99743, + 0.99743, 0.97707, 0.93717, 0.87938, 0.80605, 0.72020, 0.62533, 0.52532, + 0.42429, 0.32635, 0.23552, 0.15552, 0.08962, 0.04052, 0.01024, 0.00000}; + std::vector input(N, 1.0); + auto output = hannwindow.apply(input); + // Hamming window coefficients should match with matlab implementation. + ASSERT_TRUE(compareVec(output, matlabcoeffs)); } TEST(WindowingTest, batchingTest) { - int N = 16, B = 15; - auto input = randVec(N * B); - auto hannwindow = Windowing(N, WindowType::HANNING); - auto output = hannwindow.apply(input); - ASSERT_EQ(output.size(), input.size()); - for (int i = 0; i < B; ++i) { - std::vector curInput(N), expOutput(N); - std::copy( - input.data() + i * N, input.data() + (i + 1) * N, curInput.data()); - std::copy( - output.data() + i * N, output.data() + (i + 1) * N, expOutput.data()); - auto curOutput = hannwindow.apply(curInput); - ASSERT_TRUE(compareVec(curOutput, expOutput, 1E-10)); - } + int N = 16, B = 15; + auto input = randVec(N * B); + auto hannwindow = Windowing(N, WindowType::HANNING); + auto output = hannwindow.apply(input); + ASSERT_EQ(output.size(), input.size()); + for(int i = 0; i < B; ++i) { + std::vector curInput(N), expOutput(N); + std::copy( + input.data() + i * N, + input.data() + (i + 1) * N, + curInput.data() + ); + std::copy( + output.data() + i * N, + output.data() + (i + 1) * N, + expOutput.data() + ); + auto curOutput = hannwindow.apply(curInput); + ASSERT_TRUE(compareVec(curOutput, expOutput, 1E-10)); + } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git 
a/flashlight/pkg/speech/test/augmentation/AdditiveNoiseTest.cpp b/flashlight/pkg/speech/test/augmentation/AdditiveNoiseTest.cpp index 64652a9..9053771 100644 --- a/flashlight/pkg/speech/test/augmentation/AdditiveNoiseTest.cpp +++ b/flashlight/pkg/speech/test/augmentation/AdditiveNoiseTest.cpp @@ -23,9 +23,9 @@ using ::testing::Pointwise; const size_t sampleRate = 16000; MATCHER_P(FloatNearPointwise, tol, "Out of range") { - return ( - std::get<0>(arg) > std::get<1>(arg) - tol && - std::get<0>(arg) < std::get<1>(arg) + tol); + return + std::get<0>(arg) > std::get<1>(arg) - tol + && std::get<0>(arg) < std::get<1>(arg) + tol; } /** @@ -38,64 +38,67 @@ MATCHER_P(FloatNearPointwise, tol, "Out of range") { * considering the SNR value. */ TEST(AdditiveNoise, Snr) { - const fs::path tmpDir = fs::temp_directory_path() / "AdditiveNoise"; - fs::create_directory(tmpDir); - const fs::path listFilePath = tmpDir / "noise.lst"; - const fs::path noiseFilePath = tmpDir / "noise.flac"; - - const float signalAmplitude = -1.0; - const int signalLen = 10; - std::vector signal(signalLen, signalAmplitude); - const float noiseAmplitude = 1.0; - const int noiseLen = 10; - std::vector noise(noiseLen, noiseAmplitude); - - saveSound( - noiseFilePath, - noise, - sampleRate, - 1, - fl::pkg::speech::SoundFormat::FLAC, - fl::pkg::speech::SoundSubFormat::PCM_16); - - // Create test list file - { - std::ofstream listFile(listFilePath); - listFile << noiseFilePath.string(); - } - - float threshold = 0.02; // allow 2% difference from expected value - - for (float snr = 1; snr < 30; ++snr) { - AdditiveNoise::Config conf; - conf.proba_ = 1.0; - conf.ratio_ = 1.0; - conf.minSnr_ = snr; - conf.maxSnr_ = snr; - conf.nClipsMin_ = 1; - conf.nClipsMax_ = 1; - conf.listFilePath_ = listFilePath; - - AdditiveNoise sfx(conf); - auto augmented = signal; - sfx.apply(augmented); - - std::vector extractNoise(augmented.size()); - for (int i = 0; i < extractNoise.size(); ++i) { - extractNoise[i] = (augmented[i] - 
signal[i]); + const fs::path tmpDir = fs::temp_directory_path() / "AdditiveNoise"; + fs::create_directory(tmpDir); + const fs::path listFilePath = tmpDir / "noise.lst"; + const fs::path noiseFilePath = tmpDir / "noise.flac"; + + const float signalAmplitude = -1.0; + const int signalLen = 10; + std::vector signal(signalLen, signalAmplitude); + const float noiseAmplitude = 1.0; + const int noiseLen = 10; + std::vector noise(noiseLen, noiseAmplitude); + + saveSound( + noiseFilePath, + noise, + sampleRate, + 1, + fl::pkg::speech::SoundFormat::FLAC, + fl::pkg::speech::SoundSubFormat::PCM_16 + ); + + // Create test list file + { + std::ofstream listFile(listFilePath); + listFile << noiseFilePath.string(); } - ASSERT_LE( - signalToNoiseRatio(signal, extractNoise), - (conf.maxSnr_ * (1 + threshold))); - ASSERT_GE( - signalToNoiseRatio(signal, extractNoise), - (conf.minSnr_ * (1 - threshold))); - } + float threshold = 0.02; // allow 2% difference from expected value + + for(float snr = 1; snr < 30; ++snr) { + AdditiveNoise::Config conf; + conf.proba_ = 1.0; + conf.ratio_ = 1.0; + conf.minSnr_ = snr; + conf.maxSnr_ = snr; + conf.nClipsMin_ = 1; + conf.nClipsMax_ = 1; + conf.listFilePath_ = listFilePath; + + AdditiveNoise sfx(conf); + auto augmented = signal; + sfx.apply(augmented); + + std::vector extractNoise(augmented.size()); + for(int i = 0; i < extractNoise.size(); ++i) { + extractNoise[i] = (augmented[i] - signal[i]); + } + + ASSERT_LE( + signalToNoiseRatio(signal, extractNoise), + (conf.maxSnr_ * (1 + threshold)) + ); + ASSERT_GE( + signalToNoiseRatio(signal, extractNoise), + (conf.minSnr_ * (1 - threshold)) + ); + } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/augmentation/GaussianNoiseTest.cpp b/flashlight/pkg/speech/test/augmentation/GaussianNoiseTest.cpp index 
f65c25b..b83b05d 100644 --- a/flashlight/pkg/speech/test/augmentation/GaussianNoiseTest.cpp +++ b/flashlight/pkg/speech/test/augmentation/GaussianNoiseTest.cpp @@ -17,34 +17,34 @@ using namespace ::fl::pkg::speech::sfx; const int numSamples = 10000; TEST(GaussianNoise, SnrCheck) { - int numTrys = 10; - float tolerance = 1e-1; - // Use `r` as seed so that we test different input samples at different SNRs - for (int r = 0; r < numTrys; ++r) { - RandomNumberGenerator rng(r); - std::vector signal(numSamples); - for (auto& i : signal) { - i = rng.random() ; - } + int numTrys = 10; + float tolerance = 1e-1; + // Use `r` as seed so that we test different input samples at different SNRs + for(int r = 0; r < numTrys; ++r) { + RandomNumberGenerator rng(r); + std::vector signal(numSamples); + for(auto& i : signal) { + i = rng.random(); + } - GaussianNoise::Config cfg; - cfg.minSnr_ = 8; - cfg.maxSnr_ = 12; - GaussianNoise sfx(cfg, r); - auto originalSignal = signal; - sfx.apply(signal); - ASSERT_EQ(signal.size(), originalSignal.size()); - std::vector noise(signal.size()); - for (int i = 0 ;i < noise.size(); ++i) { - noise[i] = signal[i] - originalSignal[i]; + GaussianNoise::Config cfg; + cfg.minSnr_ = 8; + cfg.maxSnr_ = 12; + GaussianNoise sfx(cfg, r); + auto originalSignal = signal; + sfx.apply(signal); + ASSERT_EQ(signal.size(), originalSignal.size()); + std::vector noise(signal.size()); + for(int i = 0; i < noise.size(); ++i) { + noise[i] = signal[i] - originalSignal[i]; + } + ASSERT_LE(signalToNoiseRatio(originalSignal, noise), cfg.maxSnr_ + tolerance); + ASSERT_GE(signalToNoiseRatio(originalSignal, noise), cfg.minSnr_ - tolerance); } - ASSERT_LE(signalToNoiseRatio(originalSignal, noise), cfg.maxSnr_ + tolerance); - ASSERT_GE(signalToNoiseRatio(originalSignal, noise), cfg.minSnr_ - tolerance); - } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + 
return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/augmentation/ReverberationTest.cpp b/flashlight/pkg/speech/test/augmentation/ReverberationTest.cpp index b584623..4107683 100644 --- a/flashlight/pkg/speech/test/augmentation/ReverberationTest.cpp +++ b/flashlight/pkg/speech/test/augmentation/ReverberationTest.cpp @@ -22,9 +22,9 @@ const size_t sampleRate = 16000; const float amplitude = 1.0; MATCHER_P(FloatNearPointwise, tol, "Out of range") { - return ( - std::get<0>(arg) > std::get<1>(arg) - tol && - std::get<0>(arg) < std::get<1>(arg) + tol); + return + std::get<0>(arg) > std::get<1>(arg) - tol + && std::get<0>(arg) < std::get<1>(arg) + tol; } /** @@ -45,78 +45,82 @@ MATCHER_P(FloatNearPointwise, tol, "Out of range") { * simpler. */ TEST(ReverbEcho, SinWaveReverb) { - // Make the reverb start at the center of the sample vector. - const size_t firstReverbIdx = numSamples / 2; - const float firstDelay = - static_cast(firstReverbIdx) / static_cast(sampleRate); - - ReverbEcho::Config conf; - conf.proba_ = 1.0f; // revern every sample - // Force delay to a specific period - conf.firstDelayMin_ = firstDelay; - conf.firstDelayMax_ = firstDelay; - // No jitter so delay is deterministic - conf.jitter_ = 0; - // Make very long rt60 so attenuation over the period of the signal is nearly - // zero. - conf.rt60Min_ = firstDelay * 100; - conf.rt60Min_ = firstDelay * 100; - conf.repeat_ = 3; - // Keep inital echo aplitude same as orig. 
- conf.initialMin_ = 1; - conf.initialMax_ = 1; - - std::vector signal = - genTestSinWave(numSamples, freq, sampleRate, amplitude); - - std::vector input = signal; - std::vector inpuBeforeDelay( - signal.begin(), signal.begin() + firstReverbIdx - 1); - std::vector inpuAfterDelay( - signal.begin() + firstReverbIdx, signal.end()); - - ReverbEcho sfx(conf); - sfx.apply(signal); - - std::vector outputBeforeDelay( - signal.begin(), signal.begin() + firstReverbIdx - 1); - std::vector outputAfterDelay( - signal.begin() + firstReverbIdx, signal.end()); - - EXPECT_EQ(inpuBeforeDelay, outputBeforeDelay); - EXPECT_NE(inpuAfterDelay, outputAfterDelay); - - // Extract the noise and compare with input that is the source of that noise. - std::vector noise(firstReverbIdx); - for (int k = firstReverbIdx; k < signal.size(); ++k) { - noise[k - firstReverbIdx] = signal[k] - input[k]; - } - // Because we use very long rt60 and we use multiple repeasts, the reverb sum - // can get to very high values. We normalize by mean of the abs diffs. - float noiseSum = 0; - float inputSum = 0; - for (int j = firstReverbIdx; j < signal.size(); ++j) { - noiseSum += std::abs(signal[j] - input[j]); - inputSum += std::abs(input[j - firstReverbIdx]); - } - float norm = noiseSum / inputSum; - std::transform( - noise.begin(), noise.end(), noise.begin(), [norm](float x) -> float { + // Make the reverb start at the center of the sample vector. + const size_t firstReverbIdx = numSamples / 2; + const float firstDelay = + static_cast(firstReverbIdx) / static_cast(sampleRate); + + ReverbEcho::Config conf; + conf.proba_ = 1.0f; // revern every sample + // Force delay to a specific period + conf.firstDelayMin_ = firstDelay; + conf.firstDelayMax_ = firstDelay; + // No jitter so delay is deterministic + conf.jitter_ = 0; + // Make very long rt60 so attenuation over the period of the signal is nearly + // zero. 
+ conf.rt60Min_ = firstDelay * 100; + conf.rt60Min_ = firstDelay * 100; + conf.repeat_ = 3; + // Keep inital echo aplitude same as orig. + conf.initialMin_ = 1; + conf.initialMax_ = 1; + + std::vector signal = + genTestSinWave(numSamples, freq, sampleRate, amplitude); + + std::vector input = signal; + std::vector inpuBeforeDelay( + signal.begin(), signal.begin() + firstReverbIdx - 1); + std::vector inpuAfterDelay( + signal.begin() + firstReverbIdx, signal.end()); + + ReverbEcho sfx(conf); + sfx.apply(signal); + + std::vector outputBeforeDelay( + signal.begin(), signal.begin() + firstReverbIdx - 1); + std::vector outputAfterDelay( + signal.begin() + firstReverbIdx, signal.end()); + + EXPECT_EQ(inpuBeforeDelay, outputBeforeDelay); + EXPECT_NE(inpuAfterDelay, outputAfterDelay); + + // Extract the noise and compare with input that is the source of that noise. + std::vector noise(firstReverbIdx); + for(int k = firstReverbIdx; k < signal.size(); ++k) { + noise[k - firstReverbIdx] = signal[k] - input[k]; + } + // Because we use very long rt60 and we use multiple repeasts, the reverb sum + // can get to very high values. We normalize by mean of the abs diffs. + float noiseSum = 0; + float inputSum = 0; + for(int j = firstReverbIdx; j < signal.size(); ++j) { + noiseSum += std::abs(signal[j] - input[j]); + inputSum += std::abs(input[j - firstReverbIdx]); + } + float norm = noiseSum / inputSum; + std::transform( + noise.begin(), + noise.end(), + noise.begin(), + [norm](float x) -> float { return x / norm; - }); + } + ); - // To reduce test flakiness, we trim the edges of the noise and compare only - // with the part in the input that is the source of this reverb noise. - std::vector noiseMain(noise.begin() + 10, noise.end() - 10); - std::vector noiseSrc( - input.begin() + 9, input.begin() + firstReverbIdx - 11); + // To reduce test flakiness, we trim the edges of the noise and compare only + // with the part in the input that is the source of this reverb noise. 
+ std::vector noiseMain(noise.begin() + 10, noise.end() - 10); + std::vector noiseSrc( + input.begin() + 9, input.begin() + firstReverbIdx - 11); - EXPECT_EQ(noiseMain.size(), noiseSrc.size()); - EXPECT_THAT(noiseMain, Pointwise(FloatNearPointwise(0.1), noiseSrc)); + EXPECT_EQ(noiseMain.size(), noiseSrc.size()); + EXPECT_THAT(noiseMain, Pointwise(FloatNearPointwise(0.1), noiseSrc)); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/augmentation/SoundEffectConfigTest.cpp b/flashlight/pkg/speech/test/augmentation/SoundEffectConfigTest.cpp index 1035a2e..bf85e19 100644 --- a/flashlight/pkg/speech/test/augmentation/SoundEffectConfigTest.cpp +++ b/flashlight/pkg/speech/test/augmentation/SoundEffectConfigTest.cpp @@ -24,66 +24,66 @@ using namespace ::fl::pkg::speech::sfx; * configured sound effect chain. */ TEST(SoundEffectConfigFile, ReadWriteJson) { - const fs::path configPath = fs::temp_directory_path() / "sfxConfig.json"; - // This log line alllows the user to inspect the config file or copy/paste - // configuration. - LOG(INFO) << "output config file= " << configPath; + const fs::path configPath = fs::temp_directory_path() / "sfxConfig.json"; + // This log line alllows the user to inspect the config file or copy/paste + // configuration. + LOG(INFO) << "output config file= " << configPath; - std::vector sfxConf1(6); + std::vector sfxConf1(6); - // Create mock noise list file. 
- const fs::path noiseListPath = fs::temp_directory_path() / "noise.lst"; - { - std::ofstream noiseListFile(noiseListPath); - noiseListFile << "/fake/path.flac"; - } - sfxConf1[0].type_ = kAdditiveNoise; - sfxConf1[0].additiveNoiseConfig_.ratio_ = 0.8; - sfxConf1[0].additiveNoiseConfig_.minSnr_ = 0; - sfxConf1[0].additiveNoiseConfig_.maxSnr_ = 30; - sfxConf1[0].additiveNoiseConfig_.nClipsMin_ = 0; - sfxConf1[0].additiveNoiseConfig_.nClipsMax_ = 4; - sfxConf1[0].additiveNoiseConfig_.listFilePath_ = noiseListPath; + // Create mock noise list file. + const fs::path noiseListPath = fs::temp_directory_path() / "noise.lst"; + { + std::ofstream noiseListFile(noiseListPath); + noiseListFile << "/fake/path.flac"; + } + sfxConf1[0].type_ = kAdditiveNoise; + sfxConf1[0].additiveNoiseConfig_.ratio_ = 0.8; + sfxConf1[0].additiveNoiseConfig_.minSnr_ = 0; + sfxConf1[0].additiveNoiseConfig_.maxSnr_ = 30; + sfxConf1[0].additiveNoiseConfig_.nClipsMin_ = 0; + sfxConf1[0].additiveNoiseConfig_.nClipsMax_ = 4; + sfxConf1[0].additiveNoiseConfig_.listFilePath_ = noiseListPath; - sfxConf1[1].type_ = kAmplify; - sfxConf1[1].amplifyConfig_.ratioMin_ = 1; - sfxConf1[1].amplifyConfig_.ratioMax_ = 10; + sfxConf1[1].type_ = kAmplify; + sfxConf1[1].amplifyConfig_.ratioMin_ = 1; + sfxConf1[1].amplifyConfig_.ratioMax_ = 10; - sfxConf1[2].type_ = kClampAmplitude; + sfxConf1[2].type_ = kClampAmplitude; - sfxConf1[3].type_ = kReverbEcho; - sfxConf1[3].reverbEchoConfig_.proba_ = 0.5; - sfxConf1[3].reverbEchoConfig_.initialMin_ = 0.1; - sfxConf1[3].reverbEchoConfig_.initialMax_ = 0.3; - sfxConf1[3].reverbEchoConfig_.rt60Min_ = 0.3; - sfxConf1[3].reverbEchoConfig_.rt60Max_ = 1.3; - sfxConf1[3].reverbEchoConfig_.firstDelayMin_ = 0.01; - sfxConf1[3].reverbEchoConfig_.firstDelayMax_ = 0.03; - sfxConf1[3].reverbEchoConfig_.repeat_ = 3; - sfxConf1[3].reverbEchoConfig_.jitter_ = 0.2; - sfxConf1[3].reverbEchoConfig_.sampleRate_ = 16000; + sfxConf1[3].type_ = kReverbEcho; + sfxConf1[3].reverbEchoConfig_.proba_ = 
0.5; + sfxConf1[3].reverbEchoConfig_.initialMin_ = 0.1; + sfxConf1[3].reverbEchoConfig_.initialMax_ = 0.3; + sfxConf1[3].reverbEchoConfig_.rt60Min_ = 0.3; + sfxConf1[3].reverbEchoConfig_.rt60Max_ = 1.3; + sfxConf1[3].reverbEchoConfig_.firstDelayMin_ = 0.01; + sfxConf1[3].reverbEchoConfig_.firstDelayMax_ = 0.03; + sfxConf1[3].reverbEchoConfig_.repeat_ = 3; + sfxConf1[3].reverbEchoConfig_.jitter_ = 0.2; + sfxConf1[3].reverbEchoConfig_.sampleRate_ = 16000; - sfxConf1[4].type_ = kNormalize; - sfxConf1[4].normalizeOnlyIfTooHigh_ = false; + sfxConf1[4].type_ = kNormalize; + sfxConf1[4].normalizeOnlyIfTooHigh_ = false; - sfxConf1[5].type_ = kTimeStretch; - sfxConf1[5].timeStretchConfig_.proba_ = 1.0; - sfxConf1[5].timeStretchConfig_.minFactor_ = 0.8; - sfxConf1[5].timeStretchConfig_.maxFactor_ = 1.5; + sfxConf1[5].type_ = kTimeStretch; + sfxConf1[5].timeStretchConfig_.proba_ = 1.0; + sfxConf1[5].timeStretchConfig_.minFactor_ = 0.8; + sfxConf1[5].timeStretchConfig_.maxFactor_ = 1.5; - writeSoundEffectConfigFile(configPath, sfxConf1); - const std::vector sfxConf2 = - readSoundEffectConfigFile(configPath); - EXPECT_EQ(sfxConf1.size(), sfxConf2.size()); + writeSoundEffectConfigFile(configPath, sfxConf1); + const std::vector sfxConf2 = + readSoundEffectConfigFile(configPath); + EXPECT_EQ(sfxConf1.size(), sfxConf2.size()); - std::shared_ptr sfx = createSoundEffect(sfxConf2); - EXPECT_NE(sfx.get(), nullptr); + std::shared_ptr sfx = createSoundEffect(sfxConf2); + EXPECT_NE(sfx.get(), nullptr); } int main(int argc, char** argv) { - fl::init(); - google::InitGoogleLogging(argv[0]); - google::InstallFailureSignalHandler(); - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + fl::init(); + google::InitGoogleLogging(argv[0]); + google::InstallFailureSignalHandler(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/augmentation/SoundEffectTest.cpp b/flashlight/pkg/speech/test/augmentation/SoundEffectTest.cpp 
index d61a508..9929fca 100644 --- a/flashlight/pkg/speech/test/augmentation/SoundEffectTest.cpp +++ b/flashlight/pkg/speech/test/augmentation/SoundEffectTest.cpp @@ -30,12 +30,12 @@ const size_t sampleRate = 16000; * it and verifies the result. */ TEST(SoundEffect, ClampAmplitude) { - ClampAmplitude sfx; - const float amplitude = 2.0; - std::vector signal = - genTestSinWave(numSamples, freq, sampleRate, amplitude); - sfx.apply(signal); - EXPECT_THAT(signal, Each(AllOf(Ge(-1.0), Le(1.0)))); + ClampAmplitude sfx; + const float amplitude = 2.0; + std::vector signal = + genTestSinWave(numSamples, freq, sampleRate, amplitude); + sfx.apply(signal); + EXPECT_THAT(signal, Each(AllOf(Ge(-1.0), Le(1.0)))); } /** @@ -44,12 +44,12 @@ TEST(SoundEffect, ClampAmplitude) { * then normalizes it and verifies the result. */ TEST(SoundEffect, NormalizeTooHigh) { - Normalize sfx(/*onlyIfTooHigh=*/true); - const float amplitude = 2.0; - std::vector signal = - genTestSinWave(numSamples, freq, sampleRate, amplitude); - sfx.apply(signal); - EXPECT_THAT(signal, Each(AllOf(Ge(-1.0), Le(1.0)))); + Normalize sfx(/*onlyIfTooHigh=*/ true); + const float amplitude = 2.0; + std::vector signal = + genTestSinWave(numSamples, freq, sampleRate, amplitude); + sfx.apply(signal); + EXPECT_THAT(signal, Each(AllOf(Ge(-1.0), Le(1.0)))); } /** @@ -57,13 +57,13 @@ TEST(SoundEffect, NormalizeTooHigh) { * unchanged when the input is in valid range */ TEST(SoundEffect, NoNormalizeTooLow) { - Normalize sfx(/*onlyIfTooHigh=*/true); - const float amplitude = 0.5; - std::vector signal = - genTestSinWave(numSamples, freq, sampleRate, amplitude); - std::vector signalCopy = signal; - sfx.apply(signalCopy); - EXPECT_EQ(signal, signalCopy); + Normalize sfx(/*onlyIfTooHigh=*/ true); + const float amplitude = 0.5; + std::vector signal = + genTestSinWave(numSamples, freq, sampleRate, amplitude); + std::vector signalCopy = signal; + sfx.apply(signalCopy); + EXPECT_EQ(signal, signalCopy); } /** @@ -71,14 +71,14 @@ 
TEST(SoundEffect, NoNormalizeTooLow) { * amplitude of sine-wave in range [-0.5, 0.5] to range [-1,1] */ TEST(SoundEffect, NormalizeTooLow) { - Normalize sfx(/*onlyIfTooHigh=*/false); - const float amplitude = 0.5; - std::vector signal = - genTestSinWave(numSamples, freq, sampleRate, amplitude); - sfx.apply(signal); - EXPECT_THAT(signal, Each(AllOf(Ge(-1.0f), Le(1.0f)))); - EXPECT_THAT(signal, testing::Contains(-1.0f)); - EXPECT_THAT(signal, testing::Contains(1.0f)); + Normalize sfx(/*onlyIfTooHigh=*/ false); + const float amplitude = 0.5; + std::vector signal = + genTestSinWave(numSamples, freq, sampleRate, amplitude); + sfx.apply(signal); + EXPECT_THAT(signal, Each(AllOf(Ge(-1.0f), Le(1.0f)))); + EXPECT_THAT(signal, testing::Contains(-1.0f)); + EXPECT_THAT(signal, testing::Contains(1.0f)); } /** @@ -89,35 +89,37 @@ TEST(SoundEffect, NormalizeTooLow) { * see maximum amplification that is at least half amp.ratioMax_ */ TEST(SoundEffect, Amplify) { - const float amplitude = 1.0; - Amplify::Config conf; - conf.ratioMin_ = amplitude / 10; - conf.ratioMax_ = amplitude * 10; - Amplify sfx(conf); - - std::vector sound = - genTestSinWave(numSamples, freq, sampleRate, amplitude); - - // get min/max amplification after 100 apply, each chooses a random value with - // in range. - float minMaxAbsAmp = conf.ratioMax_; - float maxMaxAbsAmp = 0; - for (int i = 0; i < 100; ++i) { - std::vector soundCopy = sound; - sfx.apply(soundCopy); - // Ensure that current augmentation amplitude is within expected range. 
- EXPECT_THAT( - soundCopy, Each(AllOf(Ge(-conf.ratioMax_), Le(conf.ratioMax_)))); - - for (auto amp : soundCopy) { - minMaxAbsAmp = std::min(std::fabs(amp), minMaxAbsAmp); - maxMaxAbsAmp = std::max(std::fabs(amp), maxMaxAbsAmp); + const float amplitude = 1.0; + Amplify::Config conf; + conf.ratioMin_ = amplitude / 10; + conf.ratioMax_ = amplitude * 10; + Amplify sfx(conf); + + std::vector sound = + genTestSinWave(numSamples, freq, sampleRate, amplitude); + + // get min/max amplification after 100 apply, each chooses a random value with + // in range. + float minMaxAbsAmp = conf.ratioMax_; + float maxMaxAbsAmp = 0; + for(int i = 0; i < 100; ++i) { + std::vector soundCopy = sound; + sfx.apply(soundCopy); + // Ensure that current augmentation amplitude is within expected range. + EXPECT_THAT( + soundCopy, + Each(AllOf(Ge(-conf.ratioMax_), Le(conf.ratioMax_))) + ); + + for(auto amp : soundCopy) { + minMaxAbsAmp = std::min(std::fabs(amp), minMaxAbsAmp); + maxMaxAbsAmp = std::max(std::fabs(amp), maxMaxAbsAmp); + } } - } - // Ensure that all random augmentations amplitudes are within expected - // random range. EXPECT_LT(minMaxAbsAmp, amp.ratioMin_ * 2); - EXPECT_GT(maxMaxAbsAmp, conf.ratioMax_ / 2); + // Ensure that all random augmentations amplitudes are within expected + // random range. EXPECT_LT(minMaxAbsAmp, amp.ratioMin_ * 2); + EXPECT_GT(maxMaxAbsAmp, conf.ratioMax_ / 2); } // Test that basic sound effect chain processes in the correct order. @@ -125,27 +127,27 @@ TEST(SoundEffect, Amplify) { // range (-1..1) and multiply by amplitude. We expect that the result is // in the range of: -amplitude..amplitude. 
TEST(SoundEffect, SfxChain) { - const float amplitude = 2.0; - Amplify::Config amp1; - amp1.ratioMin_ = amplitude / 10; - amp1.ratioMax_ = amplitude * 10; - Amplify::Config amp2; - amp2.ratioMin_ = amplitude; - amp2.ratioMax_ = amplitude; - - auto sfxChain = std::make_shared(); - sfxChain->add(std::make_shared(amp1)); - sfxChain->add(std::make_shared()); - sfxChain->add(std::make_shared(amp2)); - - std::vector signal = - genTestSinWave(numSamples, freq, sampleRate, amplitude); - sfxChain->apply(signal); - EXPECT_THAT(signal, Each(AllOf(Ge(-amplitude), Le(amplitude)))); + const float amplitude = 2.0; + Amplify::Config amp1; + amp1.ratioMin_ = amplitude / 10; + amp1.ratioMax_ = amplitude * 10; + Amplify::Config amp2; + amp2.ratioMin_ = amplitude; + amp2.ratioMax_ = amplitude; + + auto sfxChain = std::make_shared(); + sfxChain->add(std::make_shared(amp1)); + sfxChain->add(std::make_shared()); + sfxChain->add(std::make_shared(amp2)); + + std::vector signal = + genTestSinWave(numSamples, freq, sampleRate, amplitude); + sfxChain->apply(signal); + EXPECT_THAT(signal, Each(AllOf(Ge(-amplitude), Le(amplitude)))); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/augmentation/TimeStretchTest.cpp b/flashlight/pkg/speech/test/augmentation/TimeStretchTest.cpp index 4e3e82c..7a758e2 100644 --- a/flashlight/pkg/speech/test/augmentation/TimeStretchTest.cpp +++ b/flashlight/pkg/speech/test/augmentation/TimeStretchTest.cpp @@ -28,29 +28,29 @@ const float amplitude = 0.5; * signal times the factor. 
*/ TEST(TimeStretch, SinWave) { - float tolerance = 0.05; + float tolerance = 0.05; - const std::vector signal = - genTestSinWave(numSamples, freq, sampleRate, amplitude); + const std::vector signal = + genTestSinWave(numSamples, freq, sampleRate, amplitude); - for (float factor = 0.5; factor <= 2; factor += 0.1) { - std::vector augmented = signal; + for(float factor = 0.5; factor <= 2; factor += 0.1) { + std::vector augmented = signal; - TimeStretch::Config conf = { - .proba_ = 1.0, .minFactor_ = factor, .maxFactor_ = factor}; - TimeStretch sfx(conf); - sfx.apply(augmented); + TimeStretch::Config conf = { + .proba_ = 1.0, .minFactor_ = factor, .maxFactor_ = factor}; + TimeStretch sfx(conf); + sfx.apply(augmented); - const float stretchRatio = static_cast(augmented.size()) / - static_cast(signal.size()); + const float stretchRatio = static_cast(augmented.size()) + / static_cast(signal.size()); - EXPECT_GE(stretchRatio, factor * (1 - tolerance)); - EXPECT_LE(stretchRatio, factor * (1 + tolerance)); - } + EXPECT_GE(stretchRatio, factor * (1 - tolerance)); + EXPECT_LE(stretchRatio, factor * (1 + tolerance)); + } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/common/ProducerConsumerQueueTest.cpp b/flashlight/pkg/speech/test/common/ProducerConsumerQueueTest.cpp index 95efe7d..ad4ca1f 100644 --- a/flashlight/pkg/speech/test/common/ProducerConsumerQueueTest.cpp +++ b/flashlight/pkg/speech/test/common/ProducerConsumerQueueTest.cpp @@ -20,76 +20,76 @@ using namespace fl::lib; TEST(ProducerConsumerQueueTest, SingleThread) { - ProducerConsumerQueue queue(10); - - // Producing - for (int i = 1; i <= 5; i++) { - queue.add(i); - } - queue.finishAdding(); - - // Consuming - std::vector output; - int element; - while (queue.get(element)) { - output.emplace_back(element); - } - - // Check - 
ASSERT_THAT(output, testing::ElementsAre(1, 2, 3, 4, 5)); + ProducerConsumerQueue queue(10); + + // Producing + for(int i = 1; i <= 5; i++) { + queue.add(i); + } + queue.finishAdding(); + + // Consuming + std::vector output; + int element; + while(queue.get(element)) { + output.emplace_back(element); + } + + // Check + ASSERT_THAT(output, testing::ElementsAre(1, 2, 3, 4, 5)); } TEST(ProducerConsumerQueueTest, MultiThreads) { - const int nElements = 1000, targetSum = 499500; - const int nProducer = folly::available_concurrency() / 2, - nConsumer = folly::available_concurrency() / 2; - std::vector consumerResults(nConsumer, 0); + const int nElements = 1000, targetSum = 499500; + const int nProducer = folly::available_concurrency() / 2, + nConsumer = folly::available_concurrency() / 2; + std::vector consumerResults(nConsumer, 0); + + ProducerConsumerQueue queue(nElements); + + // Define producer and consumers + auto produce = [nProducer, &queue](int tid) { + for(int i = tid; i < nElements; i += nProducer) { + queue.add(i); + } + }; + + auto consume = [&consumerResults, &queue](int tid) { + int element; + while(queue.get(element)) { + consumerResults[tid] += element; + } + }; + + // Run Test + std::vector> producerFutures(nConsumer); + for(int i = 0; i < nProducer; i++) { + producerFutures[i] = std::async(std::launch::async, produce, i); + } - ProducerConsumerQueue queue(nElements); + std::vector> consumerFutures(nConsumer); + for(int i = 0; i < nConsumer; i++) { + consumerFutures[i] = std::async(std::launch::async, consume, i); + } - // Define producer and consumers - auto produce = [nProducer, &queue](int tid) { - for (int i = tid; i < nElements; i += nProducer) { - queue.add(i); + for(int i = 0; i < nConsumer; i++) { + producerFutures[i].wait(); } - }; + queue.finishAdding(); - auto consume = [&consumerResults, &queue](int tid) { - int element; - while (queue.get(element)) { - consumerResults[tid] += element; + for(int i = 0; i < nConsumer; i++) { + 
consumerFutures[i].wait(); + } + + // Check + int predictSum = 0; + for(const auto& element : consumerResults) { + predictSum += element; } - }; - - // Run Test - std::vector> producerFutures(nConsumer); - for (int i = 0; i < nProducer; i++) { - producerFutures[i] = std::async(std::launch::async, produce, i); - } - - std::vector> consumerFutures(nConsumer); - for (int i = 0; i < nConsumer; i++) { - consumerFutures[i] = std::async(std::launch::async, consume, i); - } - - for (int i = 0; i < nConsumer; i++) { - producerFutures[i].wait(); - } - queue.finishAdding(); - - for (int i = 0; i < nConsumer; i++) { - consumerFutures[i].wait(); - } - - // Check - int predictSum = 0; - for (const auto& element : consumerResults) { - predictSum += element; - } - ASSERT_EQ(predictSum, targetSum); + ASSERT_EQ(predictSum, targetSum); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/criterion/BenchmarkASG.cpp b/flashlight/pkg/speech/test/criterion/BenchmarkASG.cpp index ff7a083..8778d40 100644 --- a/flashlight/pkg/speech/test/criterion/BenchmarkASG.cpp +++ b/flashlight/pkg/speech/test/criterion/BenchmarkASG.cpp @@ -18,36 +18,37 @@ using namespace fl; using namespace fl::pkg::speech; int main() { - fl::setDevice(0); - fl::init(); - - int N = 30, T = 487, L = 34, B = 20; - - auto asg = AutoSegmentationCriterion(N); - - auto input = Variable(fl::rand({N, T, B}) * 2 - 1, true); - - auto target = Variable( - fl::abs(fl::rand({L, B}, fl::dtype::s32)).astype(fl::dtype::s32) % - (N - 1), - false); - - int ntimes = 50; - Variable b = asg.forward({input, target}).front(); - Variable gradoutput = Variable(fl::rand(b.shape()) * 2 - 2, false); - for (int i = 0; i < 5; ++i) { - b = asg.forward({input, target}).front(); - b.backward(); - } - fl::sync(); - auto s = fl::Timer::start(); - for (int i = 0; i < ntimes; ++i) { - b = 
asg.forward({input, target}).front(); - b.backward(gradoutput); - } - fl::sync(); - auto e = fl::Timer::stop(s); - std::cout << "Total time (fwd+bwd pass) " << std::setprecision(5) - << e * 1000.0 / ntimes << " msec" << std::endl; - return 0; + fl::setDevice(0); + fl::init(); + + int N = 30, T = 487, L = 34, B = 20; + + auto asg = AutoSegmentationCriterion(N); + + auto input = Variable(fl::rand({N, T, B}) * 2 - 1, true); + + auto target = Variable( + fl::abs(fl::rand({L, B}, fl::dtype::s32)).astype(fl::dtype::s32) + % (N - 1), + false + ); + + int ntimes = 50; + Variable b = asg.forward({input, target}).front(); + Variable gradoutput = Variable(fl::rand(b.shape()) * 2 - 2, false); + for(int i = 0; i < 5; ++i) { + b = asg.forward({input, target}).front(); + b.backward(); + } + fl::sync(); + auto s = fl::Timer::start(); + for(int i = 0; i < ntimes; ++i) { + b = asg.forward({input, target}).front(); + b.backward(gradoutput); + } + fl::sync(); + auto e = fl::Timer::stop(s); + std::cout << "Total time (fwd+bwd pass) " << std::setprecision(5) + << e * 1000.0 / ntimes << " msec" << std::endl; + return 0; } diff --git a/flashlight/pkg/speech/test/criterion/BenchmarkCTC.cpp b/flashlight/pkg/speech/test/criterion/BenchmarkCTC.cpp index feb31ad..68a6d2c 100644 --- a/flashlight/pkg/speech/test/criterion/BenchmarkCTC.cpp +++ b/flashlight/pkg/speech/test/criterion/BenchmarkCTC.cpp @@ -19,40 +19,40 @@ using namespace fl; using namespace fl::pkg::speech; int main() { - fl::setDevice(0); - fl::init(); - - auto ctc = ConnectionistTemporalClassificationCriterion(); - - int N = 30, T = 487, L = 34, B = 10; - - auto input = Variable(fl::log(fl::rand({N, T, B})), true); - - auto t = fl::abs(fl::rand({L, B}, fl::dtype::s32)).astype(fl::dtype::s32) % - (N - 2); - - for (int i = 0; i < B; ++i) { - int r = rand() % (L / 2); - t(fl::range(L / 2 + r, fl::end), i) = -1; - } - - Variable target(t, false); - int ntimes = 50; - Variable b = ctc.forward({input, target}).front(); - Variable 
gradoutput = Variable(fl::rand(b.shape()) * 2 - 2, false); - for (int i = 0; i < 5; ++i) { - b = ctc.forward({input, target}).front(); - b.backward(); - } - fl::sync(); - auto s = fl::Timer::start(); - for (int i = 0; i < ntimes; ++i) { - b = ctc.forward({input, target}).front(); - b.backward(gradoutput); - } - fl::sync(); - auto e = fl::Timer::stop(s); - std::cout << "Total time (fwd+bwd pass) " << std::setprecision(5) - << e * 1000.0 / ntimes << " msec" << std::endl; - return 0; + fl::setDevice(0); + fl::init(); + + auto ctc = ConnectionistTemporalClassificationCriterion(); + + int N = 30, T = 487, L = 34, B = 10; + + auto input = Variable(fl::log(fl::rand({N, T, B})), true); + + auto t = fl::abs(fl::rand({L, B}, fl::dtype::s32)).astype(fl::dtype::s32) + % (N - 2); + + for(int i = 0; i < B; ++i) { + int r = rand() % (L / 2); + t(fl::range(L / 2 + r, fl::end), i) = -1; + } + + Variable target(t, false); + int ntimes = 50; + Variable b = ctc.forward({input, target}).front(); + Variable gradoutput = Variable(fl::rand(b.shape()) * 2 - 2, false); + for(int i = 0; i < 5; ++i) { + b = ctc.forward({input, target}).front(); + b.backward(); + } + fl::sync(); + auto s = fl::Timer::start(); + for(int i = 0; i < ntimes; ++i) { + b = ctc.forward({input, target}).front(); + b.backward(gradoutput); + } + fl::sync(); + auto e = fl::Timer::stop(s); + std::cout << "Total time (fwd+bwd pass) " << std::setprecision(5) + << e * 1000.0 / ntimes << " msec" << std::endl; + return 0; } diff --git a/flashlight/pkg/speech/test/criterion/BenchmarkSeq2Seq.cpp b/flashlight/pkg/speech/test/criterion/BenchmarkSeq2Seq.cpp index f1e2d60..b41b56a 100644 --- a/flashlight/pkg/speech/test/criterion/BenchmarkSeq2Seq.cpp +++ b/flashlight/pkg/speech/test/criterion/BenchmarkSeq2Seq.cpp @@ -19,74 +19,75 @@ using namespace fl; using namespace fl::pkg::speech; void timeBeamSearch() { - int N = 40, H = 256, T = 200; - - Seq2SeqCriterion seq2seq( - // Make eos -1 so beam search runs to outputlen - N, /* nClass 
*/ - H, /* hiddenDim */ - -1, /* eosIdx */ - N - 1, /* padIdx */ - 200, /* maxDecoderOutputLen */ - {std::make_shared() /* attentions */}); - - auto input = fl::randn({H, T, 1}, fl::dtype::f32); + int N = 40, H = 256, T = 200; + + Seq2SeqCriterion seq2seq( + // Make eos -1 so beam search runs to outputlen + N, /* nClass */ + H, /* hiddenDim */ + -1, /* eosIdx */ + N - 1, /* padIdx */ + 200, /* maxDecoderOutputLen */ + {std::make_shared() /* attentions */}); + + auto input = fl::randn({H, T, 1}, fl::dtype::f32); + + // Warmup + seq2seq.beamPath(input, Tensor()); + + int iters = 10; + std::vector beamsizes = {1, 5, 10, 20}; + for(auto b : beamsizes) { + auto s = fl::Timer::start(); + for(int i = 0; i < iters; ++i) { + seq2seq.beamPath(input, Tensor(), b); + } + fl::sync(); + auto e = fl::Timer::stop(s); + std::cout << "Total time (beam size: " << b << ") " << std::setprecision(5) + << e * 1000.0 / iters << " msec" << std::endl; + } +} - // Warmup - seq2seq.beamPath(input, Tensor()); +void timeForwardBackward() { + int N = 40, H = 256, B = 2, T = 200, U = 50; + + Seq2SeqCriterion seq2seq( + N, /* nClass */ + H, /* hiddenDim */ + N - 2, /* eosIdx */ + N - 1, /* padIdx */ + 0, /* maxDecoderOutputLen */ + {std::make_shared()} /* attentions */); + + auto input = Variable(fl::randn({H, T, B}, fl::dtype::f32), true); + auto target = noGrad( + (fl::rand({U, B}, fl::dtype::f32) * 0.99 * N).astype(fl::dtype::s32) + ); + + // Warmup + for(int i = 0; i < 10; ++i) { + auto loss = seq2seq({input, target}).front(); + loss.backward(); + } + fl::sync(); - int iters = 10; - std::vector beamsizes = {1, 5, 10, 20}; - for (auto b : beamsizes) { + int iters = 100; auto s = fl::Timer::start(); - for (int i = 0; i < iters; ++i) { - seq2seq.beamPath(input, Tensor(), b); + for(int i = 0; i < iters; ++i) { + auto loss = seq2seq({input, target}).front(); + loss.backward(); } fl::sync(); auto e = fl::Timer::stop(s); - std::cout << "Total time (beam size: " << b << ") " << std::setprecision(5) - 
<< e * 1000.0 / iters << " msec" << std::endl; - } -} - -void timeForwardBackward() { - int N = 40, H = 256, B = 2, T = 200, U = 50; - - Seq2SeqCriterion seq2seq( - N, /* nClass */ - H, /* hiddenDim */ - N - 2, /* eosIdx */ - N - 1, /* padIdx */ - 0, /* maxDecoderOutputLen */ - {std::make_shared()} /* attentions */); - - auto input = Variable(fl::randn({H, T, B}, fl::dtype::f32), true); - auto target = noGrad( - (fl::rand({U, B}, fl::dtype::f32) * 0.99 * N).astype(fl::dtype::s32)); - - // Warmup - for (int i = 0; i < 10; ++i) { - auto loss = seq2seq({input, target}).front(); - loss.backward(); - } - fl::sync(); - - int iters = 100; - auto s = fl::Timer::start(); - for (int i = 0; i < iters; ++i) { - auto loss = seq2seq({input, target}).front(); - loss.backward(); - } - fl::sync(); - auto e = fl::Timer::stop(s); - std::cout << "Total time (fwd+bwd pass) " << std::setprecision(5) - << e * 1000.0 / iters << " msec" << std::endl; + std::cout << "Total time (fwd+bwd pass) " << std::setprecision(5) + << e * 1000.0 / iters << " msec" << std::endl; } int main() { - fl::init(); + fl::init(); - timeForwardBackward(); - timeBeamSearch(); - return 0; + timeForwardBackward(); + timeBeamSearch(); + return 0; } diff --git a/flashlight/pkg/speech/test/criterion/CompareASG.cpp b/flashlight/pkg/speech/test/criterion/CompareASG.cpp index 6c1abf7..f0f9c1e 100644 --- a/flashlight/pkg/speech/test/criterion/CompareASG.cpp +++ b/flashlight/pkg/speech/test/criterion/CompareASG.cpp @@ -35,138 +35,138 @@ constexpr int L = 2000; std::random_device rd; void usage(const char* argv0) { - std::cerr << "usage: " << argv0 << " [COMMAND] [FILES]...\n"; - std::cerr << " generate input_file\n"; - std::cerr << " baseline input_file baseline_file\n"; - std::cerr << " compare input_file baseline_file" << std::endl; - std::exit(1); + std::cerr << "usage: " << argv0 << " [COMMAND] [FILES]...\n"; + std::cerr << " generate input_file\n"; + std::cerr << " baseline input_file baseline_file\n"; + std::cerr << " 
compare input_file baseline_file" << std::endl; + std::exit(1); } struct CriterionOutput { - Tensor loss; - Tensor inputGrad; - Tensor transGrad; + Tensor loss; + Tensor inputGrad; + Tensor transGrad; }; -CriterionOutput -run(const Tensor& input, const Tensor& target, const Tensor& trans) { - fl::Variable inputVar(input, true); - fl::Variable targetVar(target, false); - fl::Variable transVar(trans, true); - - AutoSegmentationCriterion crit(N, CriterionScaleMode::TARGET_SZ_SQRT); - crit.setParams(transVar, 0); - auto loss = crit.forward({inputVar, targetVar}).front(); - loss.backward(); - - CriterionOutput result; - result.loss = loss.tensor(); - result.inputGrad = inputVar.grad().tensor(); - result.transGrad = transVar.grad().tensor(); - return result; +CriterionOutput run(const Tensor& input, const Tensor& target, const Tensor& trans) { + fl::Variable inputVar(input, true); + fl::Variable targetVar(target, false); + fl::Variable transVar(trans, true); + + AutoSegmentationCriterion crit(N, CriterionScaleMode::TARGET_SZ_SQRT); + crit.setParams(transVar, 0); + auto loss = crit.forward({inputVar, targetVar}).front(); + loss.backward(); + + CriterionOutput result; + result.loss = loss.tensor(); + result.inputGrad = inputVar.grad().tensor(); + result.transGrad = transVar.grad().tensor(); + return result; } // Discrepancy value of 1 corresponds to `allclose(a, b, rtol=1e-5, atol=1e-7)` // just barely returning true. 
double discrepancy(const Tensor& a, const Tensor& b) { - const auto& ad = a.astype(fl::dtype::f64); - const auto& bd = b.astype(fl::dtype::f64); - return fl::amax(fl::abs(ad - bd) / (1e-7 + 1e-5 * fl::abs(bd))) - .scalar(); + const auto& ad = a.astype(fl::dtype::f64); + const auto& bd = b.astype(fl::dtype::f64); + return fl::amax(fl::abs(ad - bd) / (1e-7 + 1e-5 * fl::abs(bd))) + .scalar(); } void printDiscrepancies( const std::string& prefix, const Tensor& compare, - const Tensor& baseline) { - std::cerr << prefix << "discrepancy=" << std::setprecision(17) - << discrepancy(compare, baseline); - // Check for NaN discrepancies manually. - auto compareNaN = fl::isnan(compare); - auto baselineNaN = fl::isnan(baseline); - if (fl::any(compareNaN && !baselineNaN).asScalar()) { - std::cerr << " (warning: compare has NaNs where baseline does not)"; - } else if (fl::any(compareNaN && baselineNaN).asScalar()) { - std::cerr << " (warning: both baseline and compare have NaNs)"; - } else if (fl::any(baselineNaN).asScalar()) { - std::cerr << " (warning: baseline has NaNs where compare does not)"; - } - std::cerr << std::endl; + const Tensor& baseline +) { + std::cerr << prefix << "discrepancy=" << std::setprecision(17) + << discrepancy(compare, baseline); + // Check for NaN discrepancies manually. 
+ auto compareNaN = fl::isnan(compare); + auto baselineNaN = fl::isnan(baseline); + if(fl::any(compareNaN && !baselineNaN).asScalar()) { + std::cerr << " (warning: compare has NaNs where baseline does not)"; + } else if(fl::any(compareNaN && baselineNaN).asScalar()) { + std::cerr << " (warning: both baseline and compare have NaNs)"; + } else if(fl::any(baselineNaN).asScalar()) { + std::cerr << " (warning: baseline has NaNs where compare does not)"; + } + std::cerr << std::endl; } } // namespace int main(int argc, char** argv) { - fl::init(); - if (argc < 2) { - usage(argv[0]); - } - - std::string command = argv[1]; - - if (command == "generate") { - if (argc != 3) { - usage(argv[0]); - } - - std::seed_seq seeds({rd(), rd(), rd(), rd()}); - std::mt19937 rng(seeds); - - // generate random target sizes - std::vector targetSize(B); - for (int b = 0; b < B; ++b) { - // ensure we have a sample with targetSize=1 and targetSize=L - targetSize[b] = (b == B - 1) ? L : (b == B - 2) ? 1 : (1 + rng() % L); - } - std::shuffle(targetSize.begin(), targetSize.end(), rng); - - // generate random targets with the above sizes - std::vector targetHost(B * L); - for (int b = 0; b < B; ++b) { - auto* targetCur = &targetHost[b * L]; - for (int i = 0; i < targetSize[b]; ++i) { - targetCur[i] = rng() % N; - } - for (int i = targetSize[b]; i < L; ++i) { - targetCur[i] = -1; - } + fl::init(); + if(argc < 2) { + usage(argv[0]); } - uint64_t afSeed = rng(); - afSeed <<= 32; - afSeed ^= rng(); - fl::setSeed(afSeed); - - auto input = fl::randn({N, T, B}); - auto target = Tensor::fromVector({L, B}, targetHost); - auto trans = fl::randn({N, N}); - fl::save(argv[2], input, target, trans); - std::cerr << "input generated" << std::endl; - } else if (command == "baseline") { - if (argc != 4) { - usage(argv[0]); + std::string command = argv[1]; + + if(command == "generate") { + if(argc != 3) { + usage(argv[0]); + } + + std::seed_seq seeds({rd(), rd(), rd(), rd()}); + std::mt19937 rng(seeds); + + // 
generate random target sizes + std::vector targetSize(B); + for(int b = 0; b < B; ++b) { + // ensure we have a sample with targetSize=1 and targetSize=L + targetSize[b] = (b == B - 1) ? L : (b == B - 2) ? 1 : (1 + rng() % L); + } + std::shuffle(targetSize.begin(), targetSize.end(), rng); + + // generate random targets with the above sizes + std::vector targetHost(B * L); + for(int b = 0; b < B; ++b) { + auto* targetCur = &targetHost[b * L]; + for(int i = 0; i < targetSize[b]; ++i) { + targetCur[i] = rng() % N; + } + for(int i = targetSize[b]; i < L; ++i) { + targetCur[i] = -1; + } + } + + uint64_t afSeed = rng(); + afSeed <<= 32; + afSeed ^= rng(); + fl::setSeed(afSeed); + + auto input = fl::randn({N, T, B}); + auto target = Tensor::fromVector({L, B}, targetHost); + auto trans = fl::randn({N, N}); + fl::save(argv[2], input, target, trans); + std::cerr << "input generated" << std::endl; + } else if(command == "baseline") { + if(argc != 4) { + usage(argv[0]); + } + + Tensor input, target, trans; + fl::load(argv[2], input, target, trans); + auto out = run(input, target, trans); + fl::save(argv[3], out.loss, out.inputGrad, out.transGrad); + std::cerr << "baseline saved" << std::endl; + } else if(command == "compare") { + if(argc != 4) { + usage(argv[0]); + } + + Tensor input, target, trans; + fl::load(argv[2], input, target, trans); + CriterionOutput out0; + fl::load(argv[3], out0.loss, out0.inputGrad, out0.transGrad); + auto out = run(input, target, trans); + std::cerr << "Computing discrepancies vs. 
1e-5 rel + 1e-7 abs tolerance\n"; + printDiscrepancies("loss: ", out.loss, out0.loss); + printDiscrepancies("inputGrad: ", out.inputGrad, out0.inputGrad); + printDiscrepancies("transGrad: ", out.transGrad, out0.transGrad); + } else { + usage(argv[0]); } - - Tensor input, target, trans; - fl::load(argv[2], input, target, trans); - auto out = run(input, target, trans); - fl::save(argv[3], out.loss, out.inputGrad, out.transGrad); - std::cerr << "baseline saved" << std::endl; - } else if (command == "compare") { - if (argc != 4) { - usage(argv[0]); - } - - Tensor input, target, trans; - fl::load(argv[2], input, target, trans); - CriterionOutput out0; - fl::load(argv[3], out0.loss, out0.inputGrad, out0.transGrad); - auto out = run(input, target, trans); - std::cerr << "Computing discrepancies vs. 1e-5 rel + 1e-7 abs tolerance\n"; - printDiscrepancies("loss: ", out.loss, out0.loss); - printDiscrepancies("inputGrad: ", out.inputGrad, out0.inputGrad); - printDiscrepancies("transGrad: ", out.transGrad, out0.transGrad); - } else { - usage(argv[0]); - } } diff --git a/flashlight/pkg/speech/test/criterion/CriterionTest.cpp b/flashlight/pkg/speech/test/criterion/CriterionTest.cpp index 049e754..33ba450 100644 --- a/flashlight/pkg/speech/test/criterion/CriterionTest.cpp +++ b/flashlight/pkg/speech/test/criterion/CriterionTest.cpp @@ -22,325 +22,337 @@ namespace { constexpr float kEpsilon = 1E-5; void checkZero(const Tensor& val, float precision = kEpsilon) { - ASSERT_LE(fl::amax(fl::abs(val)).scalar(), precision); + ASSERT_LE(fl::amax(fl::abs(val)).scalar(), precision); } -using JacobianFunc = std::function; +using JacobianFunc = std::function; void jacobianTest( JacobianFunc func, Variable& input, float precision = 1E-3, - float perturbation = 1E-2) { - auto fwdJacobian = - Tensor({func(input).elements(), input.elements()}, fl::dtype::f32); - - for (int i = 0; i < input.elements(); ++i) { - Tensor orig = input.tensor().flatten()(i); - input.tensor().flat(i) = orig - 
perturbation; - auto outa = func(input).tensor(); - - input.tensor().flat(i) = orig + perturbation; - auto outb = func(input).tensor(); - input.tensor().flat(i) = orig; - - fwdJacobian(fl::span, i) = - fl::reshape((outb - outa), {static_cast(outa.elements())}) * 0.5 / - perturbation; - } - - auto bwdJacobian = - Tensor({func(input).elements(), input.elements()}, fl::dtype::f32); - auto dout = - Variable(fl::full(func(input).shape(), 0, func(input).type()), false); - - for (int i = 0; i < dout.elements(); ++i) { - dout.tensor().flat(i) = 1; // element in 1D view - input.zeroGrad(); - auto out = func(input); - out.backward(dout); - - bwdJacobian(i) = fl::reshape(input.grad().tensor(), {input.elements()}); - dout.tensor().flat(i) = 0; - } - - checkZero(fwdJacobian - bwdJacobian, precision); + float perturbation = 1E-2 +) { + auto fwdJacobian = + Tensor({func(input).elements(), input.elements()}, fl::dtype::f32); + + for(int i = 0; i < input.elements(); ++i) { + Tensor orig = input.tensor().flatten()(i); + input.tensor().flat(i) = orig - perturbation; + auto outa = func(input).tensor(); + + input.tensor().flat(i) = orig + perturbation; + auto outb = func(input).tensor(); + input.tensor().flat(i) = orig; + + fwdJacobian(fl::span, i) = + fl::reshape((outb - outa), {static_cast(outa.elements())}) * 0.5 + / perturbation; + } + + auto bwdJacobian = + Tensor({func(input).elements(), input.elements()}, fl::dtype::f32); + auto dout = + Variable(fl::full(func(input).shape(), 0, func(input).type()), false); + + for(int i = 0; i < dout.elements(); ++i) { + dout.tensor().flat(i) = 1; // element in 1D view + input.zeroGrad(); + auto out = func(input); + out.backward(dout); + + bwdJacobian(i) = fl::reshape(input.grad().tensor(), {input.elements()}); + dout.tensor().flat(i) = 0; + } + + checkZero(fwdJacobian - bwdJacobian, precision); } } // namespace TEST(CriterionTest, CTCEmptyTarget) { - // Subtle - related to memory manager initialization. 
Will be fixed in a - // future version of ArrayFire after which time this can be removed. The test - // passes/works properly in isolation. - if (FL_BACKEND_CPU) { - GTEST_SKIP() << "Skipping test for CPU backend"; - } - - // Non-empty input, Empty target, batchsize > 0 - auto input = Variable(Tensor({3, 2, 5}), true); - auto target = Variable(Tensor({0, 5}), false); - auto ctc = ConnectionistTemporalClassificationCriterion(); - auto loss = ctc({input, target}).front(); - loss.backward(); - ASSERT_FALSE(fl::any(fl::isnan(loss.tensor())).asScalar()); - - auto funcConvIn = [&](Variable& inp) { - return ctc.forward({inp, target}).front(); - }; - jacobianTest(funcConvIn, input); + // Subtle - related to memory manager initialization. Will be fixed in a + // future version of ArrayFire after which time this can be removed. The test + // passes/works properly in isolation. + if(FL_BACKEND_CPU) { + GTEST_SKIP() << "Skipping test for CPU backend"; + } + + // Non-empty input, Empty target, batchsize > 0 + auto input = Variable(Tensor({3, 2, 5}), true); + auto target = Variable(Tensor({0, 5}), false); + auto ctc = ConnectionistTemporalClassificationCriterion(); + auto loss = ctc({input, target}).front(); + loss.backward(); + ASSERT_FALSE(fl::any(fl::isnan(loss.tensor())).asScalar()); + + auto funcConvIn = [&](Variable& inp) { + return ctc.forward({inp, target}).front(); + }; + jacobianTest(funcConvIn, input); } TEST(CriterionTest, CTCCost) { - // Test case: 1 - auto neginf = -std::numeric_limits::infinity(); - std::array input1 = {0.0, neginf, neginf, 0.0, 0.0, neginf}; - std::array target1 = {0, 0}; - const int N1 = 2, L1 = 2, T1 = 3; + // Test case: 1 + auto neginf = -std::numeric_limits::infinity(); + std::array input1 = {0.0, neginf, neginf, 0.0, 0.0, neginf}; + std::array target1 = {0, 0}; + const int N1 = 2, L1 = 2, T1 = 3; - auto ctc1 = ConnectionistTemporalClassificationCriterion(); - auto input1af = Variable(Tensor::fromArray({N1, T1, 1}, input1), true); - auto 
target1af = Variable(Tensor::fromArray({L1, 1}, target1), false); + auto ctc1 = ConnectionistTemporalClassificationCriterion(); + auto input1af = Variable(Tensor::fromArray({N1, T1, 1}, input1), true); + auto target1af = Variable(Tensor::fromArray({L1, 1}, target1), false); - auto loss1 = ctc1({input1af, target1af}).front(); - ASSERT_NEAR(loss1.scalar(), 0.0, kEpsilon); + auto loss1 = ctc1({input1af, target1af}).front(); + ASSERT_NEAR(loss1.scalar(), 0.0, kEpsilon); - // Test case: 2 - std::array target2 = {1, 2}; - int N2 = 4, L2 = 2, T2 = 3; + // Test case: 2 + std::array target2 = {1, 2}; + int N2 = 4, L2 = 2, T2 = 3; - auto ctc2 = ConnectionistTemporalClassificationCriterion(); - auto input2af = Variable(fl::full({N2, T2, 1}, 0.0, fl::dtype::f32), true); - auto target2af = Variable(Tensor::fromArray({L2, 1}, target2), false); + auto ctc2 = ConnectionistTemporalClassificationCriterion(); + auto input2af = Variable(fl::full({N2, T2, 1}, 0.0, fl::dtype::f32), true); + auto target2af = Variable(Tensor::fromArray({L2, 1}, target2), false); - auto loss2 = ctc2({input2af, target2af}).front(); - ASSERT_NEAR(loss2.scalar(), -log(0.25 * 0.25 * 0.25 * 5), kEpsilon); + auto loss2 = ctc2({input2af, target2af}).front(); + ASSERT_NEAR(loss2.scalar(), -log(0.25 * 0.25 * 0.25 * 5), kEpsilon); } TEST(CriterionTest, CTCJacobian) { - int N = 30, T = 80, L = 20; - auto in = Variable(fl::log(fl::rand({N, T, 1})), true); - auto t = fl::abs(fl::rand({L, 1}, fl::dtype::s32)) % (N - 2); - auto tgt = Variable(t.astype(fl::dtype::s32), false); - auto l = ConnectionistTemporalClassificationCriterion( - CriterionScaleMode::INPUT_SZ_SQRT); - auto funcConvIn = [&](Variable& inp) { - return l.forward({inp, tgt}).front(); - }; - jacobianTest(funcConvIn, in); -} - -TEST(CriterionTest, Batching) { - { - int N = 10, T = 25, L = 15, B = 5; - auto in = Variable(fl::log(fl::rand({N, T, B})), true); - auto t = fl::abs(fl::rand({L, B}, fl::dtype::s32)) % (N - 2); - for (int i = 0; i < B; ++i) { - int r 
= rand() % L; - if (r > 0) { - t(fl::range(r, fl::end), i) = -1; - } - } + int N = 30, T = 80, L = 20; + auto in = Variable(fl::log(fl::rand({N, T, 1})), true); + auto t = fl::abs(fl::rand({L, 1}, fl::dtype::s32)) % (N - 2); auto tgt = Variable(t.astype(fl::dtype::s32), false); auto l = ConnectionistTemporalClassificationCriterion( - CriterionScaleMode::TARGET_SZ_SQRT); + CriterionScaleMode::INPUT_SZ_SQRT + ); auto funcConvIn = [&](Variable& inp) { - return l.forward({inp, tgt}).front(); - }; + return l.forward({inp, tgt}).front(); + }; jacobianTest(funcConvIn, in); - } - { - int N = 80, T = 50, L = 25, B = 10; - auto in = Variable(fl::log(fl::rand({N, T, B})), true); - auto t = fl::abs(fl::rand({L, B}, fl::dtype::s32)) % (N - 2); - for (int i = 0; i < B; ++i) { - int r = rand() % L; - if (r > 0) { - t(fl::range(r, fl::end), i) = -1; - } - } - auto tgt = Variable(t.astype(fl::dtype::s32), false); - auto l = ConnectionistTemporalClassificationCriterion( - CriterionScaleMode::TARGET_SZ); - auto output = l.forward({in, tgt}).front(); +} - for (int i = 0; i < B; ++i) { - auto inel = moddims(in(fl::span, fl::span, i), {N, T, 1}); - auto tgtel = moddims(tgt(fl::span, i), {L, 1}); - auto outputCur = l.forward({inel, tgtel}).front(); - checkZero(output.tensor()(i) - outputCur.tensor(), 1E-6); +TEST(CriterionTest, Batching) { + { + int N = 10, T = 25, L = 15, B = 5; + auto in = Variable(fl::log(fl::rand({N, T, B})), true); + auto t = fl::abs(fl::rand({L, B}, fl::dtype::s32)) % (N - 2); + for(int i = 0; i < B; ++i) { + int r = rand() % L; + if(r > 0) { + t(fl::range(r, fl::end), i) = -1; + } + } + auto tgt = Variable(t.astype(fl::dtype::s32), false); + auto l = ConnectionistTemporalClassificationCriterion( + CriterionScaleMode::TARGET_SZ_SQRT + ); + auto funcConvIn = [&](Variable& inp) { + return l.forward({inp, tgt}).front(); + }; + jacobianTest(funcConvIn, in); + } + { + int N = 80, T = 50, L = 25, B = 10; + auto in = Variable(fl::log(fl::rand({N, T, B})), true); + auto t 
= fl::abs(fl::rand({L, B}, fl::dtype::s32)) % (N - 2); + for(int i = 0; i < B; ++i) { + int r = rand() % L; + if(r > 0) { + t(fl::range(r, fl::end), i) = -1; + } + } + auto tgt = Variable(t.astype(fl::dtype::s32), false); + auto l = ConnectionistTemporalClassificationCriterion( + CriterionScaleMode::TARGET_SZ + ); + auto output = l.forward({in, tgt}).front(); + + for(int i = 0; i < B; ++i) { + auto inel = moddims(in(fl::span, fl::span, i), {N, T, 1}); + auto tgtel = moddims(tgt(fl::span, i), {L, 1}); + auto outputCur = l.forward({inel, tgtel}).front(); + checkZero(output.tensor()(i) - outputCur.tensor(), 1E-6); + } } - } } TEST(CriterionTest, CTCCompareTensorflow) { - // The following test cases are taken from Tensor Flow CTC implementation - // tinyurl.com/y9du5v5a - - // Test Case: 1 - const int T1 = 5, N1 = 6, L1 = 5; - std::array target1 = {0, 1, 2, 1, 0}; - float lossExpected1 = 3.34211; - std::array input1 = { - 0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553, - 0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436, - 0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688, - 0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533, - 0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107, - }; - std::transform( - input1.begin(), input1.end(), input1.begin(), [](float p) -> float { + // The following test cases are taken from Tensor Flow CTC implementation + // tinyurl.com/y9du5v5a + + // Test Case: 1 + const int T1 = 5, N1 = 6, L1 = 5; + std::array target1 = {0, 1, 2, 1, 0}; + float lossExpected1 = 3.34211; + std::array input1 = { + 0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553, + 0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436, + 0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688, + 0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533, + 0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107, + }; + std::transform( + input1.begin(), + 
input1.end(), + input1.begin(), + [](float p) -> float { return log(p); - }); - std::array gradExpected1 = { - -0.366234, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553, - 0.111121, -0.411608, 0.278779, 0.0055756, 0.00569609, 0.010436, - 0.0357786, 0.633813, -0.678582, 0.00249248, 0.00272882, 0.0037688, - 0.0663296, -0.356151, 0.280111, 0.00283995, 0.0035545, 0.00331533, - -0.541765, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107}; - - auto ctc1 = ConnectionistTemporalClassificationCriterion(); - auto input1af = Variable(Tensor::fromArray({N1, T1, 1}, input1), true); - auto target1af = Variable(Tensor::fromArray({L1, 1}, target1), false); - auto gradExpected1af = - Variable(Tensor::fromArray({N1, T1, 1}, gradExpected1), false); - - auto loss1 = ctc1({input1af, target1af}).front(); - ASSERT_NEAR(loss1.scalar(), lossExpected1, kEpsilon); - - loss1.backward(); - checkZero(input1af.grad().tensor() - gradExpected1af.tensor()); - - // Test Case: 2 - const int T2 = 5, N2 = 6, L2 = 4; - std::array target2 = {0, 1, 1, 0}; - float lossExpected2 = 5.42262; - std::array input2 = { - 0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508, - 0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549, - 0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456, - 0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345, - 0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046, - }; - std::transform( - input2.begin(), input2.end(), input2.begin(), [](float p) -> float { + } + ); + std::array gradExpected1 = { + -0.366234, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553, + 0.111121, -0.411608, 0.278779, 0.0055756, 0.00569609, 0.010436, + 0.0357786, 0.633813, -0.678582, 0.00249248, 0.00272882, 0.0037688, + 0.0663296, -0.356151, 0.280111, 0.00283995, 0.0035545, 0.00331533, + -0.541765, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107}; + + auto ctc1 = ConnectionistTemporalClassificationCriterion(); + auto input1af = 
Variable(Tensor::fromArray({N1, T1, 1}, input1), true); + auto target1af = Variable(Tensor::fromArray({L1, 1}, target1), false); + auto gradExpected1af = + Variable(Tensor::fromArray({N1, T1, 1}, gradExpected1), false); + + auto loss1 = ctc1({input1af, target1af}).front(); + ASSERT_NEAR(loss1.scalar(), lossExpected1, kEpsilon); + + loss1.backward(); + checkZero(input1af.grad().tensor() - gradExpected1af.tensor()); + + // Test Case: 2 + const int T2 = 5, N2 = 6, L2 = 4; + std::array target2 = {0, 1, 1, 0}; + float lossExpected2 = 5.42262; + std::array input2 = { + 0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508, + 0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549, + 0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456, + 0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345, + 0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046, + }; + std::transform( + input2.begin(), + input2.end(), + input2.begin(), + [](float p) -> float { return log(p); - }); - std::array gradExpected2 = { - -0.69824, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508, - 0.24082, -0.602467, 0.0557226, 0.0546814, 0.0557528, 0.19549, - 0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, -0.797544, - 0.280884, -0.570478, 0.0326593, 0.0339046, 0.0326856, 0.190345, - -0.576714, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046, - }; - - auto ctc2 = ConnectionistTemporalClassificationCriterion(); - auto input2af = Variable(Tensor::fromArray({N2, T2, 1}, input2), true); - auto target2af = Variable(Tensor::fromArray({L2, 1}, target2), false); - auto gradExpected2af = - Variable(Tensor::fromArray({N2, T2, 1}, gradExpected2), false); - - auto loss2 = ctc2({input2af, target2af}).front(); - ASSERT_NEAR(loss2.scalar(), lossExpected2, kEpsilon); - - loss2.backward(); - checkZero(input2af.grad().tensor() - gradExpected2af.tensor()); + } + ); + std::array gradExpected2 = { + -0.69824, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508, + 0.24082, -0.602467, 
0.0557226, 0.0546814, 0.0557528, 0.19549, + 0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, -0.797544, + 0.280884, -0.570478, 0.0326593, 0.0339046, 0.0326856, 0.190345, + -0.576714, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046, + }; + + auto ctc2 = ConnectionistTemporalClassificationCriterion(); + auto input2af = Variable(Tensor::fromArray({N2, T2, 1}, input2), true); + auto target2af = Variable(Tensor::fromArray({L2, 1}, target2), false); + auto gradExpected2af = + Variable(Tensor::fromArray({N2, T2, 1}, gradExpected2), false); + + auto loss2 = ctc2({input2af, target2af}).front(); + ASSERT_NEAR(loss2.scalar(), lossExpected2, kEpsilon); + + loss2.backward(); + checkZero(input2af.grad().tensor() - gradExpected2af.tensor()); } TEST(CriterionTest, ViterbiPath) { - // Test case: 1 - auto in = fl::rand({4, 5, 1}); // All values < 1 - std::array expectedpath1 = {3, 2, 0, 2, 2}; - for (int j = 0; j < 5; ++j) { - in(expectedpath1[j], j) = 2; - } - ConnectionistTemporalClassificationCriterion ctc; - auto vpath1Arr = ctc.viterbiPath(in); - Tensor expPath1Arr = Tensor::fromArray({5, 1}, expectedpath1); - checkZero(vpath1Arr - expPath1Arr); - - // test batch input - auto intile = fl::tile(in, {1, 1, 2}); - auto vpath1bArr = ctc.viterbiPath(intile); - checkZero(vpath1bArr - fl::tile(expPath1Arr, {1, 2})); - - // Test case: 2 - constexpr int T2 = 4, N2 = 3; - - // clang-format off - std::array input2Vec = { - 0, 0, 7, - 5, 4, 3, - 5, 8, 5, - 5, 4, 3, - }; - std::array trans2Vec = { - 0, 2, 0, - 0, 0, 2, - 2, 0, 0, - }; - std::array expectedPath2Vec = {2, 1, 1, 0}; - // clang-format on - - Tensor input2 = Tensor::fromArray({N2, T2, 1}, input2Vec); - Tensor trans2 = Tensor::fromArray({{N2, N2}}, trans2Vec); - Tensor expectedPath2 = Tensor::fromArray({T2, 1}, expectedPath2Vec); - - AutoSegmentationCriterion asg(N2); - asg.setParams(Variable(trans2, true), 0); - auto path2 = asg.viterbiPath(input2); - checkZero(path2 - expectedPath2); - - // Test case: 2b (batching) - 
auto input2b = fl::tile(input2, {1, 1, 77}); - auto expectedPath2b = fl::tile(expectedPath2, {1, 77}); - auto path2b = asg.viterbiPath(input2b); - checkZero(path2b - expectedPath2b); - - // If trasition probablities are same, CTC and ASG viterbi paths should - // match - AutoSegmentationCriterion asg2(30); - asg.param(0).tensor() = fl::full({30, 30}, 1.0); - for (int t = 1; t < 5; ++t) { - Tensor randInput = fl::rand({30, t * 10, t}); - auto asgPathArr = asg2.viterbiPath(randInput); - auto ctcPathArr = ctc.viterbiPath(randInput); - ASSERT_EQ(asgPathArr.shape(), asgPathArr.shape()); - checkZero(asgPathArr - ctcPathArr); - } + // Test case: 1 + auto in = fl::rand({4, 5, 1}); // All values < 1 + std::array expectedpath1 = {3, 2, 0, 2, 2}; + for(int j = 0; j < 5; ++j) { + in(expectedpath1[j], j) = 2; + } + ConnectionistTemporalClassificationCriterion ctc; + auto vpath1Arr = ctc.viterbiPath(in); + Tensor expPath1Arr = Tensor::fromArray({5, 1}, expectedpath1); + checkZero(vpath1Arr - expPath1Arr); + + // test batch input + auto intile = fl::tile(in, {1, 1, 2}); + auto vpath1bArr = ctc.viterbiPath(intile); + checkZero(vpath1bArr - fl::tile(expPath1Arr, {1, 2})); + + // Test case: 2 + constexpr int T2 = 4, N2 = 3; + + // clang-format off + std::array input2Vec = { + 0, 0, 7, + 5, 4, 3, + 5, 8, 5, + 5, 4, 3, + }; + std::array trans2Vec = { + 0, 2, 0, + 0, 0, 2, + 2, 0, 0, + }; + std::array expectedPath2Vec = {2, 1, 1, 0}; + // clang-format on + + Tensor input2 = Tensor::fromArray({N2, T2, 1}, input2Vec); + Tensor trans2 = Tensor::fromArray({{N2, N2}}, trans2Vec); + Tensor expectedPath2 = Tensor::fromArray({T2, 1}, expectedPath2Vec); + + AutoSegmentationCriterion asg(N2); + asg.setParams(Variable(trans2, true), 0); + auto path2 = asg.viterbiPath(input2); + checkZero(path2 - expectedPath2); + + // Test case: 2b (batching) + auto input2b = fl::tile(input2, {1, 1, 77}); + auto expectedPath2b = fl::tile(expectedPath2, {1, 77}); + auto path2b = asg.viterbiPath(input2b); + 
checkZero(path2b - expectedPath2b); + + // If trasition probablities are same, CTC and ASG viterbi paths should + // match + AutoSegmentationCriterion asg2(30); + asg.param(0).tensor() = fl::full({30, 30}, 1.0); + for(int t = 1; t < 5; ++t) { + Tensor randInput = fl::rand({30, t * 10, t}); + auto asgPathArr = asg2.viterbiPath(randInput); + auto ctcPathArr = ctc.viterbiPath(randInput); + ASSERT_EQ(asgPathArr.shape(), asgPathArr.shape()); + checkZero(asgPathArr - ctcPathArr); + } } // Test with alternating blanks and with varying target sizes TEST(CriterionTest, ASGAlternatingBlanks) { - AutoSegmentationCriterion criterion(2); - int C = 2; // (one class + blank) - int T = 7; - int mL = 3; - int B = 2; - std::vector xV = { - -0x1.1f60fap+1, -0x1.0518e2p+0, 0x1.2016e2p-3, -0x1.dfe0dp-4, - 0x1.00ee32p-2, 0x1.af74fp-2, -0x1.f29964p-2, 0x1.977e08p-2, - -0x1.52548ep-1, -0x1.ae9504p-3, 0x1.bcf1fcp+1, 0x1.31ad5p+0, - 0x1.9bc5aep-1, 0x1.3c7dacp-1, 0x1.3e2852p-1, 0x1.6699f4p-1, - 0x1.095a5p+0, 0x1.1840bcp-1, 0x1.465a4ep-1, 0x1.2c4cacp-1, - 0x1.754998p-1, 0x1.cb6698p-2, -0x1.1cadcp+0, 0x1.757b88p-2, - 0x1.3dec32p+0, 0x1.320fp+0, -0x1.9eb1a4p-1, -0x1.e43beap-2}; - Tensor x = Tensor::fromVector(Shape({C, T, B}), xV); - Tensor y = fl::full({mL * 2 + 1, B}, -1, fl::dtype::s32); - int L; - L = 2; - y(fl::range(0, 2 * L + 1, 2), 0) = 1; - y(fl::range(1, 2 * L, 2), 0) = 0; - L = 3; - y(fl::range(0, 2 * L + 1, 2), 1) = 1; - y(fl::range(1, 2 * L, 2), 1) = 0; - Tensor expectedPath = fl::full({T, B}, 1); - expectedPath(1, 0) = 0; - expectedPath(5, 0) = 0; - expectedPath(1, 1) = 0; - expectedPath(3, 1) = 0; - expectedPath(5, 1) = 0; - Tensor path = criterion.viterbiPathWithTarget(x, y); - checkZero(path - expectedPath); + AutoSegmentationCriterion criterion(2); + int C = 2; // (one class + blank) + int T = 7; + int mL = 3; + int B = 2; + std::vector xV = { + -0x1.1f60fap+1, -0x1.0518e2p+0, 0x1.2016e2p-3, -0x1.dfe0dp-4, + 0x1.00ee32p-2, 0x1.af74fp-2, -0x1.f29964p-2, 0x1.977e08p-2, + 
-0x1.52548ep-1, -0x1.ae9504p-3, 0x1.bcf1fcp+1, 0x1.31ad5p+0, + 0x1.9bc5aep-1, 0x1.3c7dacp-1, 0x1.3e2852p-1, 0x1.6699f4p-1, + 0x1.095a5p+0, 0x1.1840bcp-1, 0x1.465a4ep-1, 0x1.2c4cacp-1, + 0x1.754998p-1, 0x1.cb6698p-2, -0x1.1cadcp+0, 0x1.757b88p-2, + 0x1.3dec32p+0, 0x1.320fp+0, -0x1.9eb1a4p-1, -0x1.e43beap-2}; + Tensor x = Tensor::fromVector(Shape({C, T, B}), xV); + Tensor y = fl::full({mL* 2 + 1, B}, -1, fl::dtype::s32); + int L; + L = 2; + y(fl::range(0, 2 * L + 1, 2), 0) = 1; + y(fl::range(1, 2 * L, 2), 0) = 0; + L = 3; + y(fl::range(0, 2 * L + 1, 2), 1) = 1; + y(fl::range(1, 2 * L, 2), 1) = 0; + Tensor expectedPath = fl::full({T, B}, 1); + expectedPath(1, 0) = 0; + expectedPath(5, 0) = 0; + expectedPath(1, 1) = 0; + expectedPath(3, 1) = 0; + expectedPath(5, 1) = 0; + Tensor path = criterion.viterbiPathWithTarget(x, y); + checkZero(path - expectedPath); } // Test constrained viterbi path for ctc and asg criterion. @@ -349,567 +361,576 @@ TEST(CriterionTest, ASGAlternatingBlanks) { // Expected output is [0, 0, 1, 1, 2, 2, 3, 3] for both ctc and ASG with // constant transitions TEST(CriterionTest, ViterbiPathConstrained) { - const int B = 2; - const int T = 8; - const int N = 5; - const int L = 4; - Tensor target = fl::arange({L, B}, 0, fl::dtype::s32); - Tensor input = fl::full({N, T, B}, 0.01); - Tensor expectedPath = fl::full({T, B}, 0, fl::dtype::s32); - for (int i = 0; i < L; i++) { - input(i, i * 2 + 1, fl::span) = 1.0; - expectedPath(i * 2, fl::span) = i; - expectedPath(i * 2 + 1, fl::span) = i; - } - - ConnectionistTemporalClassificationCriterion ctc; - Tensor ctcPath = ctc.viterbiPathWithTarget(input, target); - Tensor diff = ctcPath - expectedPath; - ASSERT_LE(fl::amax(fl::abs(diff)).scalar(), kEpsilon); - - AutoSegmentationCriterion asg(N); - asg.param(0).tensor() = fl::full({N, N}, 1.0); - Tensor asgPath = asg.viterbiPathWithTarget(input, target); - diff = asgPath - expectedPath; - ASSERT_LE(fl::amax(fl::abs(diff)).scalar(), kEpsilon); + const int B = 2; 
+ const int T = 8; + const int N = 5; + const int L = 4; + Tensor target = fl::arange({L, B}, 0, fl::dtype::s32); + Tensor input = fl::full({N, T, B}, 0.01); + Tensor expectedPath = fl::full({T, B}, 0, fl::dtype::s32); + for(int i = 0; i < L; i++) { + input(i, i * 2 + 1, fl::span) = 1.0; + expectedPath(i * 2, fl::span) = i; + expectedPath(i * 2 + 1, fl::span) = i; + } + + ConnectionistTemporalClassificationCriterion ctc; + Tensor ctcPath = ctc.viterbiPathWithTarget(input, target); + Tensor diff = ctcPath - expectedPath; + ASSERT_LE(fl::amax(fl::abs(diff)).scalar(), kEpsilon); + + AutoSegmentationCriterion asg(N); + asg.param(0).tensor() = fl::full({N, N}, 1.0); + Tensor asgPath = asg.viterbiPathWithTarget(input, target); + diff = asgPath - expectedPath; + ASSERT_LE(fl::amax(fl::abs(diff)).scalar(), kEpsilon); } // Test that CTC can return a path with no spaces TEST(CriterionTest, CTCViterbiPathNoSpaces) { - const int B = 3; // Batchsize - const int T = 10; // Utterance length - const int N = 30; // Number of tokens - const int L = 1; // Length of target - const int targetIdx = 1; // Token Idx of target - - Tensor input = fl::full({N, T, B}, 0.01); - Tensor target = fl::full({L, B}, 0, fl::dtype::s32); - Tensor expectedPath = fl::full({T, B}, targetIdx, fl::dtype::s32); - - // targetIdx has the highest prob for all t in T - input(targetIdx, fl::span, fl::span) = 1.0; - target(fl::span, fl::span) = targetIdx; - - ConnectionistTemporalClassificationCriterion ctc; - Tensor vpathArr = ctc.viterbiPathWithTarget(input, target); - Tensor diff = expectedPath - vpathArr; - ASSERT_LE(fl::amax(fl::abs(diff)).scalar(), kEpsilon); + const int B = 3; // Batchsize + const int T = 10; // Utterance length + const int N = 30; // Number of tokens + const int L = 1; // Length of target + const int targetIdx = 1; // Token Idx of target + + Tensor input = fl::full({N, T, B}, 0.01); + Tensor target = fl::full({L, B}, 0, fl::dtype::s32); + Tensor expectedPath = fl::full({T, B}, targetIdx, 
fl::dtype::s32); + + // targetIdx has the highest prob for all t in T + input(targetIdx, fl::span, fl::span) = 1.0; + target(fl::span, fl::span) = targetIdx; + + ConnectionistTemporalClassificationCriterion ctc; + Tensor vpathArr = ctc.viterbiPathWithTarget(input, target); + Tensor diff = expectedPath - vpathArr; + ASSERT_LE(fl::amax(fl::abs(diff)).scalar(), kEpsilon); } // Test that CTC can return a path that optionally ends with a space TEST(CriterionTest, CTCViterbiPathConstrainedEndWithSpace) { - const int B = 3; // Batchsize - const int T = 10; // Utterance length - const int N = 30; // Number of tokens - const int blankLabel = N - 1; - const int L = 1; // Length of target - const int targetIdx = 1; // Token Idx of target - - Tensor input = fl::full({N, T, B}, 0.01); - Tensor target = fl::full({L, B}, targetIdx, fl::dtype::s32); - Tensor expectedPath = fl::full({T, B}, targetIdx, fl::dtype::s32); - // targetIdx has the highest prob for all t in T, except for T - 1, which is - // a blank label - input(targetIdx, fl::span, fl::span) = 1.0; - input(targetIdx, T - 1, fl::span) = 0.00; - input(blankLabel, T - 1, fl::span) = 1.0; - expectedPath(T - 1, fl::span) = blankLabel; - - ConnectionistTemporalClassificationCriterion ctc; - Tensor vpathArr = ctc.viterbiPathWithTarget(input, target); - Tensor diff = expectedPath - vpathArr; - ASSERT_LE(fl::amax(fl::abs(diff)).scalar(), kEpsilon); + const int B = 3; // Batchsize + const int T = 10; // Utterance length + const int N = 30; // Number of tokens + const int blankLabel = N - 1; + const int L = 1; // Length of target + const int targetIdx = 1; // Token Idx of target + + Tensor input = fl::full({N, T, B}, 0.01); + Tensor target = fl::full({L, B}, targetIdx, fl::dtype::s32); + Tensor expectedPath = fl::full({T, B}, targetIdx, fl::dtype::s32); + // targetIdx has the highest prob for all t in T, except for T - 1, which is + // a blank label + input(targetIdx, fl::span, fl::span) = 1.0; + input(targetIdx, T - 1, fl::span) = 
0.00; + input(blankLabel, T - 1, fl::span) = 1.0; + expectedPath(T - 1, fl::span) = blankLabel; + + ConnectionistTemporalClassificationCriterion ctc; + Tensor vpathArr = ctc.viterbiPathWithTarget(input, target); + Tensor diff = expectedPath - vpathArr; + ASSERT_LE(fl::amax(fl::abs(diff)).scalar(), kEpsilon); } // Test that CTC can return a path that optionally begins with a space TEST(CriterionTest, CTCViterbiPathConstrainedBeginWithSpace) { - const int B = 3; // Batchsize - const int T = 10; // Utterance length - const int N = 30; // Number of tokens - const int blankLabel = N - 1; - const int L = 1; // Length of target - const int targetIdx = 1; // Token Idx of target - - // targetIdx has the highest prob for all t in T, except for 0, which is - // a blank label - Tensor input = fl::full({N, T, B}, 0.01); - Tensor target = fl::full({L, B}, targetIdx, fl::dtype::s32); - Tensor expectedPath = fl::full({T, B}, targetIdx, fl::dtype::s32); - input(targetIdx, fl::span, fl::span) = 1.0; - input(targetIdx, 0, fl::span) = 0.00; - input(blankLabel, 0, fl::span) = 1.0; - expectedPath(0, fl::span) = blankLabel; - - ConnectionistTemporalClassificationCriterion ctc; - Tensor vpathArr = ctc.viterbiPathWithTarget(input, target); - Tensor diff = expectedPath - vpathArr; - ASSERT_LE(fl::amax(fl::abs(diff)).scalar(), kEpsilon); + const int B = 3; // Batchsize + const int T = 10; // Utterance length + const int N = 30; // Number of tokens + const int blankLabel = N - 1; + const int L = 1; // Length of target + const int targetIdx = 1; // Token Idx of target + + // targetIdx has the highest prob for all t in T, except for 0, which is + // a blank label + Tensor input = fl::full({N, T, B}, 0.01); + Tensor target = fl::full({L, B}, targetIdx, fl::dtype::s32); + Tensor expectedPath = fl::full({T, B}, targetIdx, fl::dtype::s32); + input(targetIdx, fl::span, fl::span) = 1.0; + input(targetIdx, 0, fl::span) = 0.00; + input(blankLabel, 0, fl::span) = 1.0; + expectedPath(0, fl::span) = 
blankLabel; + + ConnectionistTemporalClassificationCriterion ctc; + Tensor vpathArr = ctc.viterbiPathWithTarget(input, target); + Tensor diff = expectedPath - vpathArr; + ASSERT_LE(fl::amax(fl::abs(diff)).scalar(), kEpsilon); } // Test that CTC can return a path that optionally begins and ends with a space TEST(CriterionTest, CTCViterbiPathConstrainedBeginAndEndWithSpace) { - const int B = 3; // Batchsize - const int T = 10; // Utterance length - const int N = 30; // Number of tokens - const int blankLabel = N - 1; - const int L = 1; // Length of target - const int targetIdx = 1; // Token Idx of target - - // targetIdx has the highest prob for all t in T, except for 0 and T -1 - // which is // a blank label - Tensor input = fl::full({N, T, B}, 0.01); - Tensor target = fl::full({L, B}, targetIdx, fl::dtype::s32); - Tensor expectedPath = fl::full({T, B}, targetIdx, fl::dtype::s32); - input(targetIdx, fl::span, fl::span) = 1.0; - input(targetIdx, 0, fl::span) = 0.00; - input(blankLabel, 0, fl::span) = 1.0; - expectedPath(0, fl::span) = blankLabel; - - input(targetIdx, T - 1, fl::span) = 0.00; - expectedPath(T - 1, fl::span) = blankLabel; - - ConnectionistTemporalClassificationCriterion ctc; - Tensor vpathArr = ctc.viterbiPathWithTarget(input, target); - Tensor diff = expectedPath - vpathArr; - ASSERT_LE(fl::amax(fl::abs(diff)).scalar(), kEpsilon); + const int B = 3; // Batchsize + const int T = 10; // Utterance length + const int N = 30; // Number of tokens + const int blankLabel = N - 1; + const int L = 1; // Length of target + const int targetIdx = 1; // Token Idx of target + + // targetIdx has the highest prob for all t in T, except for 0 and T -1 + // which is // a blank label + Tensor input = fl::full({N, T, B}, 0.01); + Tensor target = fl::full({L, B}, targetIdx, fl::dtype::s32); + Tensor expectedPath = fl::full({T, B}, targetIdx, fl::dtype::s32); + input(targetIdx, fl::span, fl::span) = 1.0; + input(targetIdx, 0, fl::span) = 0.00; + input(blankLabel, 0, 
fl::span) = 1.0; + expectedPath(0, fl::span) = blankLabel; + + input(targetIdx, T - 1, fl::span) = 0.00; + expectedPath(T - 1, fl::span) = blankLabel; + + ConnectionistTemporalClassificationCriterion ctc; + Tensor vpathArr = ctc.viterbiPathWithTarget(input, target); + Tensor diff = expectedPath - vpathArr; + ASSERT_LE(fl::amax(fl::abs(diff)).scalar(), kEpsilon); } TEST(CriterionTest, FCCCost) { - // Test case: 1 - std::array input1 = { - 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0}; - std::transform( - input1.begin(), input1.end(), input1.begin(), [](float p) -> float { + // Test case: 1 + std::array input1 = { + 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0}; + std::transform( + input1.begin(), + input1.end(), + input1.begin(), + [](float p) -> float { return log(p); - }); - std::array dummyTarget1 = {0, 0}; - const int N1 = 2, L1 = 2, T1 = 3, B1 = 2; - - auto fcc1 = FullConnectionCriterion(N1); - auto input1af = Variable(Tensor::fromArray({N1, T1, B1}, input1), true); - auto target1af = Variable(Tensor::fromArray({L1, B1}, dummyTarget1), false); - - auto l1 = fcc1(input1af, target1af); - std::vector loss1Host(B1); - l1.host(loss1Host.data()); - ASSERT_NEAR(loss1Host[0], 0.0, kEpsilon); - ASSERT_NEAR(loss1Host[1], 0.0, kEpsilon); - - // Test case: 2 - std::array input2; - std::fill_n(input2.data(), 12, log(0.25)); - std::array dummyTarget2 = {1, 2}; - int N2 = 4, L2 = 2, T2 = 3, B2 = 1; - - auto fcc2 = FullConnectionCriterion(N2); - auto input2af = Variable(Tensor::fromArray({N2, T2, B2}, input2), true); - auto target2af = Variable(Tensor::fromArray({L2, B2}, dummyTarget2), false); - - auto l2 = fcc2(input2af, target2af); - - std::vector loss2Host(1); - l2.host(loss2Host.data()); - ASSERT_NEAR(loss2Host[0], 0.0, kEpsilon); - - // Test case: 3 - int N3 = 40, T3 = 300, L3 = 50, B3 = 3; - auto in = logSoftmax(Variable(fl::rand({N3, T3, B3}), true), 0); - auto t = fl::abs(fl::rand({L3, B3}, fl::dtype::s32)) % (N3 - 1); - auto tgt = 
Variable(t.astype(fl::dtype::s32), false); - auto fcc3 = FullConnectionCriterion(N3); - - auto l3 = fcc3(in, tgt); - auto loss3Host = l3.tensor().toHostVector(); - ASSERT_NEAR(loss3Host[0], 0.0, kEpsilon); - ASSERT_NEAR(loss3Host[1], 0.0, kEpsilon); - ASSERT_NEAR(loss3Host[2], 0.0, kEpsilon); + } + ); + std::array dummyTarget1 = {0, 0}; + const int N1 = 2, L1 = 2, T1 = 3, B1 = 2; + + auto fcc1 = FullConnectionCriterion(N1); + auto input1af = Variable(Tensor::fromArray({N1, T1, B1}, input1), true); + auto target1af = Variable(Tensor::fromArray({L1, B1}, dummyTarget1), false); + + auto l1 = fcc1(input1af, target1af); + std::vector loss1Host(B1); + l1.host(loss1Host.data()); + ASSERT_NEAR(loss1Host[0], 0.0, kEpsilon); + ASSERT_NEAR(loss1Host[1], 0.0, kEpsilon); + + // Test case: 2 + std::array input2; + std::fill_n(input2.data(), 12, log(0.25)); + std::array dummyTarget2 = {1, 2}; + int N2 = 4, L2 = 2, T2 = 3, B2 = 1; + + auto fcc2 = FullConnectionCriterion(N2); + auto input2af = Variable(Tensor::fromArray({N2, T2, B2}, input2), true); + auto target2af = Variable(Tensor::fromArray({L2, B2}, dummyTarget2), false); + + auto l2 = fcc2(input2af, target2af); + + std::vector loss2Host(1); + l2.host(loss2Host.data()); + ASSERT_NEAR(loss2Host[0], 0.0, kEpsilon); + + // Test case: 3 + int N3 = 40, T3 = 300, L3 = 50, B3 = 3; + auto in = logSoftmax(Variable(fl::rand({N3, T3, B3}), true), 0); + auto t = fl::abs(fl::rand({L3, B3}, fl::dtype::s32)) % (N3 - 1); + auto tgt = Variable(t.astype(fl::dtype::s32), false); + auto fcc3 = FullConnectionCriterion(N3); + + auto l3 = fcc3(in, tgt); + auto loss3Host = l3.tensor().toHostVector(); + ASSERT_NEAR(loss3Host[0], 0.0, kEpsilon); + ASSERT_NEAR(loss3Host[1], 0.0, kEpsilon); + ASSERT_NEAR(loss3Host[2], 0.0, kEpsilon); } TEST(CriterionTest, FCCJacobian) { - int N = 3, T = 8, L = 1, B = 2; - auto in = Variable(fl::log(fl::rand({N, T, B})), true); - auto t = fl::abs(fl::rand({L, B}, fl::dtype::s32)) % (N - 1); - auto tgt = 
Variable(t.astype(fl::dtype::s32), false); - auto l = FullConnectionCriterion(N, CriterionScaleMode::TARGET_SZ_SQRT); - - // Test case for input - auto funcIn = [&](Variable& inp) { return l.forward(inp, tgt); }; - jacobianTest(funcIn, in); - - // Test case for transition - auto transition = Variable(fl::rand({N, N}), true); - auto funcTrans = [&](Variable& transitionP) { - l.setParams(transitionP, 0); - return l.forward(in, tgt); - }; - jacobianTest(funcTrans, transition); + int N = 3, T = 8, L = 1, B = 2; + auto in = Variable(fl::log(fl::rand({N, T, B})), true); + auto t = fl::abs(fl::rand({L, B}, fl::dtype::s32)) % (N - 1); + auto tgt = Variable(t.astype(fl::dtype::s32), false); + auto l = FullConnectionCriterion(N, CriterionScaleMode::TARGET_SZ_SQRT); + + // Test case for input + auto funcIn = [&](Variable& inp) { return l.forward(inp, tgt); }; + jacobianTest(funcIn, in); + + // Test case for transition + auto transition = Variable(fl::rand({N, N}), true); + auto funcTrans = [&](Variable& transitionP) { + l.setParams(transitionP, 0); + return l.forward(in, tgt); + }; + jacobianTest(funcTrans, transition); } TEST(CriterionTest, FACCost) { - // Test case: 1 - std::array input1 = { - 1.0, 0.0, 0.0, 1.0, 0.5, 0.5, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0}; - std::array target1 = {0, 1, 0, 1}; - const int N1 = 2, L1 = 2, T1 = 3, B1 = 2; - - auto fac1 = ForceAlignmentCriterion(N1); - auto input1af = Variable(Tensor::fromArray({N1, T1, B1}, input1), true); - auto target1af = Variable(Tensor::fromArray({{L1, B1}}, target1), false); - - auto loss1 = fac1(input1af, target1af); - std::vector loss1Host(B1); - loss1.host(loss1Host.data()); - ASSERT_NEAR(loss1Host[0], log(exp(1.5) + exp(2.5)), kEpsilon); - ASSERT_NEAR(loss1Host[1], log(exp(2) + exp(3)), kEpsilon); - - // Test case: 2 - std::array input2; - std::fill_n(input2.data(), 12, log(0.25)); - std::array target2 = {0, 1}; - int N2 = 4, L2 = 2, T2 = 3, B2 = 1; - - auto fac2 = ForceAlignmentCriterion(N2); - auto input2af = 
Variable(Tensor::fromArray({N2, T2, B2}, input2), true); - auto target2af = Variable(Tensor::fromArray({L2, B2}, target2), false); - - auto loss2 = fac2(input2af, target2af); - ASSERT_NEAR(loss2.scalar(), -log(32), kEpsilon); + // Test case: 1 + std::array input1 = { + 1.0, 0.0, 0.0, 1.0, 0.5, 0.5, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0}; + std::array target1 = {0, 1, 0, 1}; + const int N1 = 2, L1 = 2, T1 = 3, B1 = 2; + + auto fac1 = ForceAlignmentCriterion(N1); + auto input1af = Variable(Tensor::fromArray({N1, T1, B1}, input1), true); + auto target1af = Variable(Tensor::fromArray({{L1, B1}}, target1), false); + + auto loss1 = fac1(input1af, target1af); + std::vector loss1Host(B1); + loss1.host(loss1Host.data()); + ASSERT_NEAR(loss1Host[0], log(exp(1.5) + exp(2.5)), kEpsilon); + ASSERT_NEAR(loss1Host[1], log(exp(2) + exp(3)), kEpsilon); + + // Test case: 2 + std::array input2; + std::fill_n(input2.data(), 12, log(0.25)); + std::array target2 = {0, 1}; + int N2 = 4, L2 = 2, T2 = 3, B2 = 1; + + auto fac2 = ForceAlignmentCriterion(N2); + auto input2af = Variable(Tensor::fromArray({N2, T2, B2}, input2), true); + auto target2af = Variable(Tensor::fromArray({L2, B2}, target2), false); + + auto loss2 = fac2(input2af, target2af); + ASSERT_NEAR(loss2.scalar(), -log(32), kEpsilon); } TEST(CriterionTest, FACJacobian) { - int N = 3, T = 10, B = 3, L = 3; - auto in = Variable(fl::log(fl::rand({N, T, B})), true); - std::array target = {0, 1, -1, 1, -1, -1, 0, 2, 1}; - auto tgt = Variable(Tensor::fromArray({L, B}, target), false); - auto l = ForceAlignmentCriterion(N, CriterionScaleMode::TARGET_SZ_SQRT); - - // Test case for input - auto funcIn = [&](Variable& inp) { return l.forward(inp, tgt); }; - jacobianTest(funcIn, in); - - // Test case for transition - auto transition = Variable(fl::rand({N, N}), true); - auto funcTrans = [&](Variable& transitionP) { - l.setParams(transitionP, 0); - return l.forward(in, tgt); - }; - jacobianTest(funcTrans, transition); + int N = 3, T = 10, B = 3, L = 
3; + auto in = Variable(fl::log(fl::rand({N, T, B})), true); + std::array target = {0, 1, -1, 1, -1, -1, 0, 2, 1}; + auto tgt = Variable(Tensor::fromArray({L, B}, target), false); + auto l = ForceAlignmentCriterion(N, CriterionScaleMode::TARGET_SZ_SQRT); + + // Test case for input + auto funcIn = [&](Variable& inp) { return l.forward(inp, tgt); }; + jacobianTest(funcIn, in); + + // Test case for transition + auto transition = Variable(fl::rand({N, N}), true); + auto funcTrans = [&](Variable& transitionP) { + l.setParams(transitionP, 0); + return l.forward(in, tgt); + }; + jacobianTest(funcTrans, transition); } TEST(CriterionTest, ASGCost) { - // Test case: 1 - constexpr int N1 = 2, L1 = 2, T1 = 3, B1 = 2; - std::array input1 = { - 1.0, 0.0, 0.0, 1.0, 0.5, 0.5, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0}; - std::transform( - input1.begin(), input1.end(), input1.begin(), [](float p) -> float { + // Test case: 1 + constexpr int N1 = 2, L1 = 2, T1 = 3, B1 = 2; + std::array input1 = { + 1.0, 0.0, 0.0, 1.0, 0.5, 0.5, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0}; + std::transform( + input1.begin(), + input1.end(), + input1.begin(), + [](float p) -> float { return log(p); - }); - std::array target1 = {0, 1, 0, 1}; - std::array trans1 = {}; - - auto asg1 = AutoSegmentationCriterion(N1); - auto input1af = Variable(Tensor::fromArray({N1, T1, B1}, input1), true); - auto target1af = Variable(Tensor::fromArray({L1, B1}, target1), false); - asg1.setParams(Variable(Tensor::fromArray({N1, N1}, trans1), true), 0); - - auto loss1 = asg1({input1af, target1af}).front(); - std::vector loss1Host(B1); - loss1.host(loss1Host.data()); - ASSERT_NEAR(loss1Host[0], -log(0.5), kEpsilon); - ASSERT_NEAR(loss1Host[1], 0.0, kEpsilon); - - // Test case: 2 - constexpr int N2 = 4, L2 = 2, T2 = 3, B2 = 1; - std::array input2; - std::fill(input2.begin(), input2.end(), log(0.25)); - std::array target2 = {0, 1}; - std::array trans2 = {}; - - auto asg2 = AutoSegmentationCriterion(N2); - auto input2af = Variable(Tensor::fromArray({N2, 
T2, B2}, input2), true); - auto target2af = Variable(Tensor::fromArray({L2, B2}, target2), false); - asg2.setParams(Variable(Tensor::fromArray({N2, N2}, trans2), true), 0); - - auto loss2 = asg2({input2af, target2af}).front(); - ASSERT_NEAR(loss2.scalar(), log(32), kEpsilon); - - // Test case: 3 - constexpr int N3 = 4, L3_1 = 4, L3_2 = 3, T3 = 3, B3 = 1; - std::array input3; - std::fill(input3.begin(), input3.end(), log(0.25)); - std::array target3 = {0, 1, 1, 1}; - std::array trans3 = {}; - - auto asg3 = AutoSegmentationCriterion(N3); - auto input3af = Variable(Tensor::fromArray({N3, T3, B3}, input3), true); - auto target3_1 = Variable(Tensor::fromArray({L3_1, B3}, target3), false); - auto target3_2 = Variable(Tensor::fromArray({L3_2, B3}, target3), false); - asg3.setParams(Variable(Tensor::fromArray({N3, N3}, trans3), true), 0); - // check if target is truncated - checkZero( - asg3({input3af, target3_1}).front().tensor() - - asg3({input3af, target3_2}).front().tensor(), - 1E-5); + } + ); + std::array target1 = {0, 1, 0, 1}; + std::array trans1 = {}; + + auto asg1 = AutoSegmentationCriterion(N1); + auto input1af = Variable(Tensor::fromArray({N1, T1, B1}, input1), true); + auto target1af = Variable(Tensor::fromArray({L1, B1}, target1), false); + asg1.setParams(Variable(Tensor::fromArray({N1, N1}, trans1), true), 0); + + auto loss1 = asg1({input1af, target1af}).front(); + std::vector loss1Host(B1); + loss1.host(loss1Host.data()); + ASSERT_NEAR(loss1Host[0], -log(0.5), kEpsilon); + ASSERT_NEAR(loss1Host[1], 0.0, kEpsilon); + + // Test case: 2 + constexpr int N2 = 4, L2 = 2, T2 = 3, B2 = 1; + std::array input2; + std::fill(input2.begin(), input2.end(), log(0.25)); + std::array target2 = {0, 1}; + std::array trans2 = {}; + + auto asg2 = AutoSegmentationCriterion(N2); + auto input2af = Variable(Tensor::fromArray({N2, T2, B2}, input2), true); + auto target2af = Variable(Tensor::fromArray({L2, B2}, target2), false); + asg2.setParams(Variable(Tensor::fromArray({N2, N2}, 
trans2), true), 0); + + auto loss2 = asg2({input2af, target2af}).front(); + ASSERT_NEAR(loss2.scalar(), log(32), kEpsilon); + + // Test case: 3 + constexpr int N3 = 4, L3_1 = 4, L3_2 = 3, T3 = 3, B3 = 1; + std::array input3; + std::fill(input3.begin(), input3.end(), log(0.25)); + std::array target3 = {0, 1, 1, 1}; + std::array trans3 = {}; + + auto asg3 = AutoSegmentationCriterion(N3); + auto input3af = Variable(Tensor::fromArray({N3, T3, B3}, input3), true); + auto target3_1 = Variable(Tensor::fromArray({L3_1, B3}, target3), false); + auto target3_2 = Variable(Tensor::fromArray({L3_2, B3}, target3), false); + asg3.setParams(Variable(Tensor::fromArray({N3, N3}, trans3), true), 0); + // check if target is truncated + checkZero( + asg3({input3af, target3_1}).front().tensor() + - asg3({input3af, target3_2}).front().tensor(), + 1E-5 + ); } TEST(CriterionTest, ASGJacobian) { - int N = 3, T = 10, B = 3, L = 3; - auto in = Variable(fl::log(fl::rand({N, T, B})), true); - std::array target = {0, 1, -1, 1, -1, -1, 0, 2, 1}; - auto tgt = Variable(Tensor::fromArray({L, B}, target), false); - auto l = AutoSegmentationCriterion(N, CriterionScaleMode::TARGET_SZ_SQRT); - - // Test case for input - auto funcIn = [&](Variable& inp) { return l.forward({inp, tgt}).front(); }; - jacobianTest(funcIn, in); - - // Test case for transition - auto transition = Variable(fl::rand({N, N}), true); - auto funcTrans = [&](Variable& transitionP) { - l.setParams(transitionP, 0); - return l.forward({in, tgt}).front(); - }; - jacobianTest(funcTrans, transition); + int N = 3, T = 10, B = 3, L = 3; + auto in = Variable(fl::log(fl::rand({N, T, B})), true); + std::array target = {0, 1, -1, 1, -1, -1, 0, 2, 1}; + auto tgt = Variable(Tensor::fromArray({L, B}, target), false); + auto l = AutoSegmentationCriterion(N, CriterionScaleMode::TARGET_SZ_SQRT); + + // Test case for input + auto funcIn = [&](Variable& inp) { return l.forward({inp, tgt}).front(); }; + jacobianTest(funcIn, in); + + // Test case for 
transition + auto transition = Variable(fl::rand({N, N}), true); + auto funcTrans = [&](Variable& transitionP) { + l.setParams(transitionP, 0); + return l.forward({in, tgt}).front(); + }; + jacobianTest(funcTrans, transition); } TEST(CriterionTest, LinSegJacobian) { - int N = 3, T = 10, B = 3, L = 3; - auto in = Variable(fl::log(fl::rand({N, T, B})), true); - std::array target = {0, 1, -1, 1, -1, -1, 0, 2, 1}; - auto tgt = Variable(Tensor::fromArray({L, B}, target), false); - auto l = LinearSegmentationCriterion(N, CriterionScaleMode::TARGET_SZ_SQRT); - - // Test case for input - auto funcIn = [&](Variable& inp) { return l.forward({inp, tgt}).front(); }; - jacobianTest(funcIn, in); - - // Test case for transition - auto transition = Variable(fl::rand({N, N}), true); - auto funcTrans = [&](Variable& transitionP) { - l.setParams(transitionP, 0); - return l.forward({in, tgt}).front(); - }; - jacobianTest(funcTrans, transition); + int N = 3, T = 10, B = 3, L = 3; + auto in = Variable(fl::log(fl::rand({N, T, B})), true); + std::array target = {0, 1, -1, 1, -1, -1, 0, 2, 1}; + auto tgt = Variable(Tensor::fromArray({L, B}, target), false); + auto l = LinearSegmentationCriterion(N, CriterionScaleMode::TARGET_SZ_SQRT); + + // Test case for input + auto funcIn = [&](Variable& inp) { return l.forward({inp, tgt}).front(); }; + jacobianTest(funcIn, in); + + // Test case for transition + auto transition = Variable(fl::rand({N, N}), true); + auto funcTrans = [&](Variable& transitionP) { + l.setParams(transitionP, 0); + return l.forward({in, tgt}).front(); + }; + jacobianTest(funcTrans, transition); } TEST(CriterionTest, ASGBatching) { - int N = 80, T = 50, L = 25, B = 10; - auto in = Variable(fl::log(fl::rand({N, T, B})), true); - auto t = fl::abs(fl::rand({L, B}, fl::dtype::s32)) % (N - 2); - for (int i = 0; i < B; ++i) { - int r = std::rand() % L; - if (r > 0) { - t(fl::range(r, fl::end), i) = -1; + int N = 80, T = 50, L = 25, B = 10; + auto in = Variable(fl::log(fl::rand({N, 
T, B})), true); + auto t = fl::abs(fl::rand({L, B}, fl::dtype::s32)) % (N - 2); + for(int i = 0; i < B; ++i) { + int r = std::rand() % L; + if(r > 0) { + t(fl::range(r, fl::end), i) = -1; + } + } + auto tgt = Variable(t.astype(fl::dtype::s32), false); + auto l = AutoSegmentationCriterion(N, CriterionScaleMode::TARGET_SZ); + auto output = l.forward({in, tgt}).front(); + + for(int i = 0; i < B; ++i) { + auto inel = moddims(in(fl::span, fl::span, i), {N, T, 1}); + auto tgtel = moddims(tgt(fl::span, i), {L, 1}); + + auto outputCur = l.forward({inel, tgtel}).front(); + checkZero(output.tensor()(i) - outputCur.tensor(), 1E-6); } - } - auto tgt = Variable(t.astype(fl::dtype::s32), false); - auto l = AutoSegmentationCriterion(N, CriterionScaleMode::TARGET_SZ); - auto output = l.forward({in, tgt}).front(); - - for (int i = 0; i < B; ++i) { - auto inel = moddims(in(fl::span, fl::span, i), {N, T, 1}); - auto tgtel = moddims(tgt(fl::span, i), {L, 1}); - - auto outputCur = l.forward({inel, tgtel}).front(); - checkZero(output.tensor()(i) - outputCur.tensor(), 1E-6); - } } TEST(CriterionTest, ASGCompareLua) { - // Compare with lua version - const int N = 6, L = 5, T = 5, B = 3; - - // clang-format off - std::array input = { - -0.4340, -0.0254, 0.3667, 0.4180, -0.3805, -0.1707, - 0.1060, 0.3631, -0.1122, -0.3825, -0.0031, -0.3801, - 0.0443, -0.3795, 0.3194, -0.3130, 0.0094, 0.1560, - 0.1252, 0.2877, 0.1997, -0.4554, 0.2774, -0.2526, - -0.4001, -0.2402, 0.1295, 0.0172, 0.1805, -0.3299, - - 0.3298, -0.2259, -0.0959, 0.4909, 0.2996, -0.2543, - -0.2863, 0.3239, -0.3988, 0.0732, -0.2107, -0.4739, - -0.0906, 0.0480, -0.1301, 0.3975, -0.3317, -0.1967, - 0.4372, -0.2006, 0.0094, 0.3281, 0.1873, -0.2945, - 0.2399, 0.0320, -0.3768, -0.2849, -0.2248, 0.3186, - - 0.0225, -0.3867, -0.1929, -0.2904, -0.4958, -0.2533, - 0.4001, -0.1517, -0.2799, -0.2915, 0.4198, 0.4506, - 0.1446, -0.4753, -0.0711, 0.2876, -0.1851, -0.1066, - 0.2081, -0.1190, -0.3902, -0.1668, 0.1911, -0.2848, - -0.3846, 0.1175, 
0.1052, 0.2172, -0.0362, 0.3055, - }; - std::array target = { - 2, 1, 5, 1, 3, - 4, 3, 5, -1, -1, - 3, 2, 2, 1, -1, - }; - - std::vector expectedLoss = { - 7.7417464256287, - 6.4200420379639, - 8.2780694961548, - }; - std::array expectedInputGrad = { - 0.1060, 0.1595, -0.7639, 0.2485, 0.1118, 0.1380, - 0.1915, -0.7524, 0.1539, 0.1175, 0.1717, 0.1178, - 0.1738, 0.1137, 0.2288, 0.1216, 0.1678, -0.8057, - 0.1766, -0.7923, 0.1902, 0.0988, 0.2056, 0.1210, - 0.1212, 0.1422, 0.2059, -0.8160, 0.2166, 0.1300, - - 0.2029, 0.1164, 0.1325, 0.2383, -0.8032, 0.1131, - 0.1414, 0.2602, 0.1263, -0.3441, -0.3009, 0.1172, - 0.1557, 0.1788, 0.1496, -0.5498, 0.0140, 0.0516, - 0.2306, 0.1219, 0.1503, -0.4244, 0.1796, -0.2579, - 0.2149, 0.1745, 0.1160, 0.1271, 0.1350, -0.7675, - - 0.2195, 0.1458, 0.1770, -0.8395, 0.1307, 0.1666, - 0.2148, 0.1237, -0.6613, -0.1223, 0.2191, 0.2259, - 0.2002, 0.1077, -0.8386, 0.2310, 0.1440, 0.1557, - 0.2197, -0.1466, -0.5742, 0.1510, 0.2160, 0.1342, - 0.1050, -0.8265, 0.1714, 0.1917, 0.1488, 0.2094, - }; - std::array expectedTransGrad = { - 0.3990, 0.3396, 0.3486, 0.3922, 0.3504, 0.3155, - 0.3666, 0.0116, -1.6678, 0.3737, 0.3361, -0.7152, - 0.3468, 0.3163, -1.1583, -0.6803, 0.3216, 0.2722, - 0.3694, -0.6688, 0.3047, -0.8531, -0.6571, 0.2870, - 0.3866, 0.3321, 0.3447, 0.3664, -0.2163, 0.3039, - 0.3640, -0.6943, 0.2988, -0.6722, 0.3215, -0.1860, - }; - // clang-format on - - auto asg = AutoSegmentationCriterion(N); - auto inputAf = Variable(Tensor::fromArray({N, T, B}, input), true); - auto targetAf = Variable(Tensor::fromArray({L, B}, target), false); - asg.setParams(constant(0., {N, N}), 0); - - auto loss = asg({inputAf, targetAf}).front(); - std::vector lossHost(B); - loss.host(lossHost.data()); - for (int i = 0; i < B; i++) { - ASSERT_NEAR(lossHost[i], expectedLoss[i], 1e-3); - } - - loss.backward(); - auto inputGrad = inputAf.grad().tensor(); - checkZero(inputGrad - Tensor::fromArray({N, T, B}, expectedInputGrad), 1e-4); - auto transGrad = 
asg.param(0).grad().tensor(); - checkZero(transGrad - Tensor::fromArray({N, N}, expectedTransGrad), 1e-4); + // Compare with lua version + const int N = 6, L = 5, T = 5, B = 3; + + // clang-format off + std::array input = { + -0.4340, -0.0254, 0.3667, 0.4180, -0.3805, -0.1707, + 0.1060, 0.3631, -0.1122, -0.3825, -0.0031, -0.3801, + 0.0443, -0.3795, 0.3194, -0.3130, 0.0094, 0.1560, + 0.1252, 0.2877, 0.1997, -0.4554, 0.2774, -0.2526, + -0.4001, -0.2402, 0.1295, 0.0172, 0.1805, -0.3299, + + 0.3298, -0.2259, -0.0959, 0.4909, 0.2996, -0.2543, + -0.2863, 0.3239, -0.3988, 0.0732, -0.2107, -0.4739, + -0.0906, 0.0480, -0.1301, 0.3975, -0.3317, -0.1967, + 0.4372, -0.2006, 0.0094, 0.3281, 0.1873, -0.2945, + 0.2399, 0.0320, -0.3768, -0.2849, -0.2248, 0.3186, + + 0.0225, -0.3867, -0.1929, -0.2904, -0.4958, -0.2533, + 0.4001, -0.1517, -0.2799, -0.2915, 0.4198, 0.4506, + 0.1446, -0.4753, -0.0711, 0.2876, -0.1851, -0.1066, + 0.2081, -0.1190, -0.3902, -0.1668, 0.1911, -0.2848, + -0.3846, 0.1175, 0.1052, 0.2172, -0.0362, 0.3055, + }; + std::array target = { + 2, 1, 5, 1, 3, + 4, 3, 5, -1, -1, + 3, 2, 2, 1, -1, + }; + + std::vector expectedLoss = { + 7.7417464256287, + 6.4200420379639, + 8.2780694961548, + }; + std::array expectedInputGrad = { + 0.1060, 0.1595, -0.7639, 0.2485, 0.1118, 0.1380, + 0.1915, -0.7524, 0.1539, 0.1175, 0.1717, 0.1178, + 0.1738, 0.1137, 0.2288, 0.1216, 0.1678, -0.8057, + 0.1766, -0.7923, 0.1902, 0.0988, 0.2056, 0.1210, + 0.1212, 0.1422, 0.2059, -0.8160, 0.2166, 0.1300, + + 0.2029, 0.1164, 0.1325, 0.2383, -0.8032, 0.1131, + 0.1414, 0.2602, 0.1263, -0.3441, -0.3009, 0.1172, + 0.1557, 0.1788, 0.1496, -0.5498, 0.0140, 0.0516, + 0.2306, 0.1219, 0.1503, -0.4244, 0.1796, -0.2579, + 0.2149, 0.1745, 0.1160, 0.1271, 0.1350, -0.7675, + + 0.2195, 0.1458, 0.1770, -0.8395, 0.1307, 0.1666, + 0.2148, 0.1237, -0.6613, -0.1223, 0.2191, 0.2259, + 0.2002, 0.1077, -0.8386, 0.2310, 0.1440, 0.1557, + 0.2197, -0.1466, -0.5742, 0.1510, 0.2160, 0.1342, + 0.1050, -0.8265, 0.1714, 
0.1917, 0.1488, 0.2094, + }; + std::array expectedTransGrad = { + 0.3990, 0.3396, 0.3486, 0.3922, 0.3504, 0.3155, + 0.3666, 0.0116, -1.6678, 0.3737, 0.3361, -0.7152, + 0.3468, 0.3163, -1.1583, -0.6803, 0.3216, 0.2722, + 0.3694, -0.6688, 0.3047, -0.8531, -0.6571, 0.2870, + 0.3866, 0.3321, 0.3447, 0.3664, -0.2163, 0.3039, + 0.3640, -0.6943, 0.2988, -0.6722, 0.3215, -0.1860, + }; + // clang-format on + + auto asg = AutoSegmentationCriterion(N); + auto inputAf = Variable(Tensor::fromArray({N, T, B}, input), true); + auto targetAf = Variable(Tensor::fromArray({L, B}, target), false); + asg.setParams(constant(0., {N, N}), 0); + + auto loss = asg({inputAf, targetAf}).front(); + std::vector lossHost(B); + loss.host(lossHost.data()); + for(int i = 0; i < B; i++) { + ASSERT_NEAR(lossHost[i], expectedLoss[i], 1e-3); + } + + loss.backward(); + auto inputGrad = inputAf.grad().tensor(); + checkZero(inputGrad - Tensor::fromArray({N, T, B}, expectedInputGrad), 1e-4); + auto transGrad = asg.param(0).grad().tensor(); + checkZero(transGrad - Tensor::fromArray({N, N}, expectedTransGrad), 1e-4); } TEST(CriterionTest, LinSegCompareLua) { - // Compare LinSegCriterion with lua version - constexpr int N = 6, L = 5, T = 5, B = 3; - - // clang-format off - std::array input = { - -0.4340, -0.0254, 0.3667, 0.4180, -0.3805, -0.1707, - 0.1060, 0.3631, -0.1122, -0.3825, -0.0031, -0.3801, - 0.0443, -0.3795, 0.3194, -0.3130, 0.0094, 0.1560, - 0.1252, 0.2877, 0.1997, -0.4554, 0.2774, -0.2526, - -0.4001, -0.2402, 0.1295, 0.0172, 0.1805, -0.3299, - - 0.3298, -0.2259, -0.0959, 0.4909, 0.2996, -0.2543, - -0.2863, 0.3239, -0.3988, 0.0732, -0.2107, -0.4739, - -0.0906, 0.0480, -0.1301, 0.3975, -0.3317, -0.1967, - 0.4372, -0.2006, 0.0094, 0.3281, 0.1873, -0.2945, - 0.2399, 0.0320, -0.3768, -0.2849, -0.2248, 0.3186, - - 0.0225, -0.3867, -0.1929, -0.2904, -0.4958, -0.2533, - 0.4001, -0.1517, -0.2799, -0.2915, 0.4198, 0.4506, - 0.1446, -0.4753, -0.0711, 0.2876, -0.1851, -0.1066, - 0.2081, -0.1190, -0.3902, 
-0.1668, 0.1911, -0.2848, - -0.3846, 0.1175, 0.1052, 0.2172, -0.0362, 0.3055, - }; - // target is zero-indexed here; add 1 for Lua counterpart - std::array target = { - 2, 1, 5, 1, 3, - 4, 3, 5, -1, -1, - 3, 2, 2, 1, -1, - }; - // clang-format on - - auto linseg = - LinearSegmentationCriterion(N, CriterionScaleMode::TARGET_SZ_SQRT); - auto inputAf = Variable(Tensor::fromArray({N, T, B}, input), true); - auto targetAf = Variable(Tensor::fromArray({L, B}, target), false); - linseg.setParams(constant(0.0, {N, N}), 0); - - // clang-format off - std::vector expectedLoss = { - 3.4622850827983, - 3.5390825164779, - 4.359541315858, - }; - std::array expectedInputGrad = { - 0.0474, 0.0713, -0.3416, 0.1112, 0.0500, 0.0617, - 0.0856, -0.3365, 0.0688, 0.0525, 0.0768, 0.0527, - 0.0777, 0.0509, 0.1023, 0.0544, 0.0750, -0.3603, - 0.0790, -0.3543, 0.0851, 0.0442, 0.0920, 0.0541, - 0.0542, 0.0636, 0.0921, -0.3649, 0.0969, 0.0582, - - 0.0907, 0.0520, 0.0593, 0.1066, -0.3592, 0.0506, - 0.0632, 0.1164, 0.0565, 0.0906, -0.3790, 0.0524, - 0.0696, 0.0800, 0.0669, -0.3338, 0.0547, 0.0626, - 0.1031, 0.0545, 0.0672, -0.3548, 0.0803, 0.0496, - 0.0961, 0.0781, 0.0519, 0.0569, 0.0604, -0.3433, - - 0.0982, 0.0652, 0.0791, -0.3754, 0.0585, 0.0745, - 0.0961, 0.0553, 0.0487, -0.3991, 0.0980, 0.1010, - 0.0895, 0.0482, -0.3750, 0.1033, 0.0644, 0.0696, - 0.0982, 0.0708, -0.3932, 0.0675, 0.0966, 0.0600, - 0.0470, -0.3696, 0.0767, 0.0857, 0.0666, 0.0937, - }; - std::array expectedTransGrad = { - 0.1784, 0.1519, 0.1559, 0.1754, 0.1567, 0.1411, - 0.1640, 0.1416, -0.7458, 0.1671, 0.1503, -0.3198, - 0.1551, 0.1414, -0.3100, -0.3042, 0.1438, 0.1217, - 0.1652, -0.2991, 0.1363, -0.7343, -0.2939, 0.1284, - 0.1729, 0.1485, 0.1542, 0.1638, -0.2928, 0.1359, - 0.1628, -0.3105, 0.1336, -0.3006, 0.1438, 0.1213, - }; - // clang-format on - - auto loss = linseg({inputAf, targetAf}).front(); - std::vector lossHost(B); - loss.host(lossHost.data()); - for (int i = 0; i < B; i++) { - ASSERT_NEAR(lossHost[i], 
expectedLoss[i], 1e-3); - } - - loss.backward(); - auto inputGrad = inputAf.grad().tensor(); - checkZero(inputGrad - Tensor::fromArray({N, T, B}, expectedInputGrad), 1e-4); - auto transGrad = linseg.param(0).grad().tensor(); - checkZero(transGrad - Tensor::fromArray({N, N}, expectedTransGrad), 1e-4); + // Compare LinSegCriterion with lua version + constexpr int N = 6, L = 5, T = 5, B = 3; + + // clang-format off + std::array input = { + -0.4340, -0.0254, 0.3667, 0.4180, -0.3805, -0.1707, + 0.1060, 0.3631, -0.1122, -0.3825, -0.0031, -0.3801, + 0.0443, -0.3795, 0.3194, -0.3130, 0.0094, 0.1560, + 0.1252, 0.2877, 0.1997, -0.4554, 0.2774, -0.2526, + -0.4001, -0.2402, 0.1295, 0.0172, 0.1805, -0.3299, + + 0.3298, -0.2259, -0.0959, 0.4909, 0.2996, -0.2543, + -0.2863, 0.3239, -0.3988, 0.0732, -0.2107, -0.4739, + -0.0906, 0.0480, -0.1301, 0.3975, -0.3317, -0.1967, + 0.4372, -0.2006, 0.0094, 0.3281, 0.1873, -0.2945, + 0.2399, 0.0320, -0.3768, -0.2849, -0.2248, 0.3186, + + 0.0225, -0.3867, -0.1929, -0.2904, -0.4958, -0.2533, + 0.4001, -0.1517, -0.2799, -0.2915, 0.4198, 0.4506, + 0.1446, -0.4753, -0.0711, 0.2876, -0.1851, -0.1066, + 0.2081, -0.1190, -0.3902, -0.1668, 0.1911, -0.2848, + -0.3846, 0.1175, 0.1052, 0.2172, -0.0362, 0.3055, + }; + // target is zero-indexed here; add 1 for Lua counterpart + std::array target = { + 2, 1, 5, 1, 3, + 4, 3, 5, -1, -1, + 3, 2, 2, 1, -1, + }; + // clang-format on + + auto linseg = + LinearSegmentationCriterion(N, CriterionScaleMode::TARGET_SZ_SQRT); + auto inputAf = Variable(Tensor::fromArray({N, T, B}, input), true); + auto targetAf = Variable(Tensor::fromArray({L, B}, target), false); + linseg.setParams(constant(0.0, {N, N}), 0); + + // clang-format off + std::vector expectedLoss = { + 3.4622850827983, + 3.5390825164779, + 4.359541315858, + }; + std::array expectedInputGrad = { + 0.0474, 0.0713, -0.3416, 0.1112, 0.0500, 0.0617, + 0.0856, -0.3365, 0.0688, 0.0525, 0.0768, 0.0527, + 0.0777, 0.0509, 0.1023, 0.0544, 0.0750, -0.3603, + 0.0790, 
-0.3543, 0.0851, 0.0442, 0.0920, 0.0541, + 0.0542, 0.0636, 0.0921, -0.3649, 0.0969, 0.0582, + + 0.0907, 0.0520, 0.0593, 0.1066, -0.3592, 0.0506, + 0.0632, 0.1164, 0.0565, 0.0906, -0.3790, 0.0524, + 0.0696, 0.0800, 0.0669, -0.3338, 0.0547, 0.0626, + 0.1031, 0.0545, 0.0672, -0.3548, 0.0803, 0.0496, + 0.0961, 0.0781, 0.0519, 0.0569, 0.0604, -0.3433, + + 0.0982, 0.0652, 0.0791, -0.3754, 0.0585, 0.0745, + 0.0961, 0.0553, 0.0487, -0.3991, 0.0980, 0.1010, + 0.0895, 0.0482, -0.3750, 0.1033, 0.0644, 0.0696, + 0.0982, 0.0708, -0.3932, 0.0675, 0.0966, 0.0600, + 0.0470, -0.3696, 0.0767, 0.0857, 0.0666, 0.0937, + }; + std::array expectedTransGrad = { + 0.1784, 0.1519, 0.1559, 0.1754, 0.1567, 0.1411, + 0.1640, 0.1416, -0.7458, 0.1671, 0.1503, -0.3198, + 0.1551, 0.1414, -0.3100, -0.3042, 0.1438, 0.1217, + 0.1652, -0.2991, 0.1363, -0.7343, -0.2939, 0.1284, + 0.1729, 0.1485, 0.1542, 0.1638, -0.2928, 0.1359, + 0.1628, -0.3105, 0.1336, -0.3006, 0.1438, 0.1213, + }; + // clang-format on + + auto loss = linseg({inputAf, targetAf}).front(); + std::vector lossHost(B); + loss.host(lossHost.data()); + for(int i = 0; i < B; i++) { + ASSERT_NEAR(lossHost[i], expectedLoss[i], 1e-3); + } + + loss.backward(); + auto inputGrad = inputAf.grad().tensor(); + checkZero(inputGrad - Tensor::fromArray({N, T, B}, expectedInputGrad), 1e-4); + auto transGrad = linseg.param(0).grad().tensor(); + checkZero(transGrad - Tensor::fromArray({N, N}, expectedTransGrad), 1e-4); } TEST(CriterionTest, AsgSerialization) { - char* user = getenv("USER"); - std::string userstr = "unknown"; - if (user != nullptr) { - userstr = std::string(user); - } - const fs::path path = fs::temp_directory_path() / "test.mdl"; - int N = 500; + char* user = getenv("USER"); + std::string userstr = "unknown"; + if(user != nullptr) { + userstr = std::string(user); + } + const fs::path path = fs::temp_directory_path() / "test.mdl"; + int N = 500; - auto asg = std::make_shared(N); - fl::save(path, asg); + auto asg = std::make_shared(N); + 
fl::save(path, asg); - std::shared_ptr asg2; - fl::load(path, asg2); + std::shared_ptr asg2; + fl::load(path, asg2); - checkZero((asg->param(0) - asg2->param(0)).tensor(), 1e-4); + checkZero((asg->param(0) - asg2->param(0)).tensor(), 1e-4); - auto input = fl::rand({N, 200, 2}); - auto target = fl::clip(fl::rand({100, 2}, fl::dtype::s32), 0, N - 1); - checkZero((asg->param(0) - asg2->param(0)).tensor(), 1e-4); + auto input = fl::rand({N, 200, 2}); + auto target = fl::clip(fl::rand({100, 2}, fl::dtype::s32), 0, N - 1); + checkZero((asg->param(0) - asg2->param(0)).tensor(), 1e-4); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/criterion/Seq2SeqTest.cpp b/flashlight/pkg/speech/test/criterion/Seq2SeqTest.cpp index c1014dd..05b9039 100644 --- a/flashlight/pkg/speech/test/criterion/Seq2SeqTest.cpp +++ b/flashlight/pkg/speech/test/criterion/Seq2SeqTest.cpp @@ -18,423 +18,466 @@ using namespace fl; using namespace fl::pkg::speech; TEST(Seq2SeqTest, Seq2Seq) { - if (FL_BACKEND_CPU) { - GTEST_SKIP() << "RNN gradient computation not supported for CPU backend"; - } - int nclass = 40; - int hiddendim = 256; - int batchsize = 2; - int inputsteps = 200; - int outputsteps = 50; - int maxoutputlen = 100; - // int nAttnRound = 2; - int nAttnRound = 1; - - std::vector> attentions( - nAttnRound, std::make_shared()); - Seq2SeqCriterion seq2seq( - nclass, - hiddendim, - nclass - 2 /* eos token index */, - nclass - 1 /* pad token index */, - maxoutputlen, - attentions, - nullptr, - false, - 100, - 0.0, - false, - kRandSampling, - 1.0, - 2, // nRnnLayer - nAttnRound, - 0.0); - - auto input = fl::randn({hiddendim, inputsteps, batchsize}, fl::dtype::f32); - auto target = - fl::rand({outputsteps, batchsize}, fl::dtype::f32) * 0.99 * nclass; - target = target.astype(fl::dtype::s32); - - Variable 
output, attention; - std::tie(output, attention) = seq2seq.vectorizedDecoder( - noGrad(input), noGrad(target), Tensor(), Tensor()); - - ASSERT_EQ(output.shape(), Shape({nclass, outputsteps, batchsize})); - - ASSERT_EQ(attention.shape(), Shape({outputsteps, inputsteps, batchsize})); - - auto losses = - seq2seq({fl::noGrad(input), fl::noGrad(target), fl::noGrad(Tensor())}) - .front(); - ASSERT_EQ(losses.dim(0), batchsize); - - // Backward runs. - losses.backward(); - - // Check that vecotrized decoder and sequential decoder give the same - // results. - Variable outSeq, attentionSeq; - std::tie(outSeq, attentionSeq) = - seq2seq.decoder(noGrad(input), noGrad(target), Tensor(), Tensor()); - - ASSERT_TRUE(allClose(output, outSeq, 1e-6)); - ASSERT_TRUE(allClose(attention, attentionSeq, 1e-6)); - - // Check size 1 Target works - target = target(fl::range(0, 1), fl::span); - auto loss = - seq2seq({noGrad(input), noGrad(target), fl::noGrad(Tensor())}).front(); - - // Make sure eval mode is not storing variables. 
- seq2seq.eval(); - std::tie(outSeq, attentionSeq) = - seq2seq.decoder(noGrad(input), noGrad(target), Tensor(), Tensor()); - ASSERT_FALSE(outSeq.isCalcGrad()); - ASSERT_FALSE(attentionSeq.isCalcGrad()); + if(FL_BACKEND_CPU) { + GTEST_SKIP() << "RNN gradient computation not supported for CPU backend"; + } + int nclass = 40; + int hiddendim = 256; + int batchsize = 2; + int inputsteps = 200; + int outputsteps = 50; + int maxoutputlen = 100; + // int nAttnRound = 2; + int nAttnRound = 1; + + std::vector> attentions( + nAttnRound, std::make_shared()); + Seq2SeqCriterion seq2seq( + nclass, + hiddendim, + nclass - 2 /* eos token index */, + nclass - 1 /* pad token index */, + maxoutputlen, + attentions, + nullptr, + false, + 100, + 0.0, + false, + kRandSampling, + 1.0, + 2, // nRnnLayer + nAttnRound, + 0.0); + + auto input = fl::randn({hiddendim, inputsteps, batchsize}, fl::dtype::f32); + auto target = + fl::rand({outputsteps, batchsize}, fl::dtype::f32) * 0.99 * nclass; + target = target.astype(fl::dtype::s32); + + Variable output, attention; + std::tie(output, attention) = seq2seq.vectorizedDecoder( + noGrad(input), + noGrad(target), + Tensor(), + Tensor() + ); + + ASSERT_EQ(output.shape(), Shape({nclass, outputsteps, batchsize})); + + ASSERT_EQ(attention.shape(), Shape({outputsteps, inputsteps, batchsize})); + + auto losses = + seq2seq({fl::noGrad(input), fl::noGrad(target), fl::noGrad(Tensor())}) + .front(); + ASSERT_EQ(losses.dim(0), batchsize); + + // Backward runs. + losses.backward(); + + // Check that vecotrized decoder and sequential decoder give the same + // results. 
+ Variable outSeq, attentionSeq; + std::tie(outSeq, attentionSeq) = + seq2seq.decoder(noGrad(input), noGrad(target), Tensor(), Tensor()); + + ASSERT_TRUE(allClose(output, outSeq, 1e-6)); + ASSERT_TRUE(allClose(attention, attentionSeq, 1e-6)); + + // Check size 1 Target works + target = target(fl::range(0, 1), fl::span); + auto loss = + seq2seq({noGrad(input), noGrad(target), fl::noGrad(Tensor())}).front(); + + // Make sure eval mode is not storing variables. + seq2seq.eval(); + std::tie(outSeq, attentionSeq) = + seq2seq.decoder(noGrad(input), noGrad(target), Tensor(), Tensor()); + ASSERT_FALSE(outSeq.isCalcGrad()); + ASSERT_FALSE(attentionSeq.isCalcGrad()); } TEST(Seq2SeqTest, Seq2SeqViterbi) { - int nclass = 40; - int hiddendim = 256; - int inputsteps = 200; - int maxoutputlen = 100; - - fl::setSeed(1); - Seq2SeqCriterion seq2seq( - nclass, - hiddendim, - nclass - 1 /* eos token index */, - nclass - 2 /* pad token index */, - maxoutputlen, - {std::make_shared()}); - - seq2seq.eval(); - auto input = fl::randn({hiddendim, inputsteps, 1}, fl::dtype::f32); - - auto path = seq2seq.viterbiPath(input); - ASSERT_GT(path.elements(), 0); - ASSERT_LE(path.elements(), maxoutputlen); + int nclass = 40; + int hiddendim = 256; + int inputsteps = 200; + int maxoutputlen = 100; + + fl::setSeed(1); + Seq2SeqCriterion seq2seq( + nclass, + hiddendim, + nclass - 1 /* eos token index */, + nclass - 2 /* pad token index */, + maxoutputlen, + {std::make_shared()}); + + seq2seq.eval(); + auto input = fl::randn({hiddendim, inputsteps, 1}, fl::dtype::f32); + + auto path = seq2seq.viterbiPath(input); + ASSERT_GT(path.elements(), 0); + ASSERT_LE(path.elements(), maxoutputlen); } TEST(Seq2SeqTest, Seq2SeqBeamSearchViterbi) { - int nclass = 40; - int hiddendim = 256; - int inputsteps = 200; - int maxoutputlen = 100; - - Seq2SeqCriterion seq2seq( - nclass, - hiddendim, - nclass - 2 /* eos token index */, - nclass - 1 /* pad token index */, - maxoutputlen, - {std::make_shared()}); - - 
seq2seq.eval(); - auto input = fl::randn({hiddendim, inputsteps, 1}, fl::dtype::f32); - - auto viterbipath = seq2seq.viterbiPath(input); - auto beampath = seq2seq.beamPath(input, Tensor(), 1); - ASSERT_EQ(beampath.size(), viterbipath.elements()); - for (int idx = 0; idx < beampath.size(); idx++) { - ASSERT_EQ(beampath[idx], viterbipath(idx).scalar()); - } + int nclass = 40; + int hiddendim = 256; + int inputsteps = 200; + int maxoutputlen = 100; + + Seq2SeqCriterion seq2seq( + nclass, + hiddendim, + nclass - 2 /* eos token index */, + nclass - 1 /* pad token index */, + maxoutputlen, + {std::make_shared()}); + + seq2seq.eval(); + auto input = fl::randn({hiddendim, inputsteps, 1}, fl::dtype::f32); + + auto viterbipath = seq2seq.viterbiPath(input); + auto beampath = seq2seq.beamPath(input, Tensor(), 1); + ASSERT_EQ(beampath.size(), viterbipath.elements()); + for(int idx = 0; idx < beampath.size(); idx++) { + ASSERT_EQ(beampath[idx], viterbipath(idx).scalar()); + } } TEST(Seq2SeqTest, Seq2SeqMedianWindow) { - int nclass = 40; - int hiddendim = 256; - int inputsteps = 200; - int maxoutputlen = 100; - - Seq2SeqCriterion seq2seq( - nclass, - hiddendim, - nclass - 2 /* eos token index */, - nclass - 1 /* pad token index */, - maxoutputlen, - {std::make_shared()}, - std::make_shared(10, 10)); - - seq2seq.eval(); - auto input = fl::randn({hiddendim, inputsteps, 1}, fl::dtype::f32); - - auto viterbipath = seq2seq.viterbiPath(input); - auto beampath = seq2seq.beamPath(input, Tensor(), 1); - ASSERT_EQ(beampath.size(), viterbipath.elements()); - for (int idx = 0; idx < beampath.size(); idx++) { - ASSERT_EQ(beampath[idx], viterbipath(idx).scalar()); - } + int nclass = 40; + int hiddendim = 256; + int inputsteps = 200; + int maxoutputlen = 100; + + Seq2SeqCriterion seq2seq( + nclass, + hiddendim, + nclass - 2 /* eos token index */, + nclass - 1 /* pad token index */, + maxoutputlen, + {std::make_shared()}, + std::make_shared(10, 10)); + + seq2seq.eval(); + auto input = 
fl::randn({hiddendim, inputsteps, 1}, fl::dtype::f32); + + auto viterbipath = seq2seq.viterbiPath(input); + auto beampath = seq2seq.beamPath(input, Tensor(), 1); + ASSERT_EQ(beampath.size(), viterbipath.elements()); + for(int idx = 0; idx < beampath.size(); idx++) { + ASSERT_EQ(beampath[idx], viterbipath(idx).scalar()); + } } TEST(Seq2SeqTest, Seq2SeqStepWindow) { - int nclass = 40; - int hiddendim = 256; - int inputsteps = 200; - int maxoutputlen = 100; - - Seq2SeqCriterion seq2seq( - nclass, - hiddendim, - nclass - 2 /* eos token index */, - nclass - 1 /* pad token index */, - maxoutputlen, - {std::make_shared()}, - std::make_shared(1, 20, 2.2, 5.8)); - - seq2seq.eval(); - auto input = fl::randn({hiddendim, inputsteps, 1}, fl::dtype::f32); - - auto viterbipath = seq2seq.viterbiPath(input); - auto beampath = seq2seq.beamPath(input, Tensor(), 1); - ASSERT_EQ(beampath.size(), viterbipath.elements()); - for (int idx = 0; idx < beampath.size(); idx++) { - ASSERT_EQ(beampath[idx], viterbipath(idx).scalar()); - } + int nclass = 40; + int hiddendim = 256; + int inputsteps = 200; + int maxoutputlen = 100; + + Seq2SeqCriterion seq2seq( + nclass, + hiddendim, + nclass - 2 /* eos token index */, + nclass - 1 /* pad token index */, + maxoutputlen, + {std::make_shared()}, + std::make_shared(1, 20, 2.2, 5.8)); + + seq2seq.eval(); + auto input = fl::randn({hiddendim, inputsteps, 1}, fl::dtype::f32); + + auto viterbipath = seq2seq.viterbiPath(input); + auto beampath = seq2seq.beamPath(input, Tensor(), 1); + ASSERT_EQ(beampath.size(), viterbipath.elements()); + for(int idx = 0; idx < beampath.size(); idx++) { + ASSERT_EQ(beampath[idx], viterbipath(idx).scalar()); + } } TEST(Seq2SeqTest, Seq2SeqStepWindowVectorized) { - int nclass = 20; - int hiddendim = 16; - int batchsize = 2; - int inputsteps = 20; - int outputsteps = 10; - int maxoutputlen = 20; - - Seq2SeqCriterion seq2seq( - nclass, - hiddendim, - nclass - 2 /* eos token index */, - nclass - 1 /* pad token index */, - 
maxoutputlen, - {std::make_shared()}, - std::make_shared(0, 5, 2.2, 5.8), - true); - - auto input = fl::randn({hiddendim, inputsteps, batchsize}, fl::dtype::f32); - auto target = - fl::rand({outputsteps, batchsize}, fl::dtype::f32) * 0.99 * nclass; - target = target.astype(fl::dtype::s32); - - Variable outputV, attentionV, outputS, attentionS; - std::tie(outputV, attentionV) = seq2seq.vectorizedDecoder( - noGrad(input), noGrad(target), Tensor(), Tensor()); - - std::tie(outputS, attentionS) = - seq2seq.decoder(noGrad(input), noGrad(target), Tensor(), Tensor()); - - ASSERT_TRUE(allClose(outputV, outputS, 1e-6)); - ASSERT_TRUE(allClose(attentionV, attentionS, 1e-6)); + int nclass = 20; + int hiddendim = 16; + int batchsize = 2; + int inputsteps = 20; + int outputsteps = 10; + int maxoutputlen = 20; + + Seq2SeqCriterion seq2seq( + nclass, + hiddendim, + nclass - 2 /* eos token index */, + nclass - 1 /* pad token index */, + maxoutputlen, + {std::make_shared()}, + std::make_shared(0, 5, 2.2, 5.8), + true); + + auto input = fl::randn({hiddendim, inputsteps, batchsize}, fl::dtype::f32); + auto target = + fl::rand({outputsteps, batchsize}, fl::dtype::f32) * 0.99 * nclass; + target = target.astype(fl::dtype::s32); + + Variable outputV, attentionV, outputS, attentionS; + std::tie(outputV, attentionV) = seq2seq.vectorizedDecoder( + noGrad(input), + noGrad(target), + Tensor(), + Tensor() + ); + + std::tie(outputS, attentionS) = + seq2seq.decoder(noGrad(input), noGrad(target), Tensor(), Tensor()); + + ASSERT_TRUE(allClose(outputV, outputS, 1e-6)); + ASSERT_TRUE(allClose(attentionV, attentionS, 1e-6)); } TEST(Seq2SeqTest, Seq2SeqAttn) { - int N = 5, H = 8, B = 1, T = 10, U = 5, maxoutputlen = 100; - Seq2SeqCriterion seq2seq( - N, - H, - N - 2, - N - 1, - maxoutputlen, - {std::make_shared()}, - std::make_shared(2, 3)); - seq2seq.eval(); - - auto input = noGrad(fl::randn({H, T, B}, fl::dtype::f32)); - auto target = noGrad( - (fl::rand({U, B}, fl::dtype::f32) * 0.99 * 
N).astype(fl::dtype::s32)); - - Variable output, attention; - std::tie(output, attention) = - seq2seq.decoder(input, target, Tensor(), Tensor()); - // check padding works - ASSERT_EQ(attention.shape(), Shape({U, T, B})); + int N = 5, H = 8, B = 1, T = 10, U = 5, maxoutputlen = 100; + Seq2SeqCriterion seq2seq( + N, + H, + N - 2, + N - 1, + maxoutputlen, + {std::make_shared()}, + std::make_shared(2, 3)); + seq2seq.eval(); + + auto input = noGrad(fl::randn({H, T, B}, fl::dtype::f32)); + auto target = noGrad( + (fl::rand({U, B}, fl::dtype::f32) * 0.99 * N).astype(fl::dtype::s32) + ); + + Variable output, attention; + std::tie(output, attention) = + seq2seq.decoder(input, target, Tensor(), Tensor()); + // check padding works + ASSERT_EQ(attention.shape(), Shape({U, T, B})); } TEST(Seq2SeqTest, Seq2SeqMixedAttn) { - int N = 5, H = 8, B = 1, T = 10, U = 5, maxoutputlen = 100, nHead = 2; - Seq2SeqCriterion seq2seq( - N, - H, - N - 2, - N - 1, - maxoutputlen, - {std::make_shared(), - std::make_shared(H, nHead)}, - std::make_shared(1, 20, 2.2, 5.8), - false, - 100, - 0.0, - false, - kRandSampling, - 1.0, - 1, - 2); - seq2seq.eval(); - - auto input = noGrad(fl::randn({H, T, B}, fl::dtype::f32)); - auto target = noGrad( - (fl::rand({U, B}, fl::dtype::f32) * 0.99 * N).astype(fl::dtype::s32)); - - Variable output, attention; - std::tie(output, attention) = - seq2seq.decoder(input, target, Tensor(), Tensor()); - ASSERT_EQ(attention.shape(), Shape({U * nHead, T, B})); -} + int N = 5, H = 8, B = 1, T = 10, U = 5, maxoutputlen = 100, nHead = 2; + Seq2SeqCriterion seq2seq( + N, + H, + N - 2, + N - 1, + maxoutputlen, + {std::make_shared(), + std::make_shared(H, nHead)}, + std::make_shared(1, 20, 2.2, 5.8), + false, + 100, + 0.0, + false, + kRandSampling, + 1.0, + 1, + 2); + seq2seq.eval(); -TEST(Seq2SeqTest, Serialization) { - char* user = getenv("USER"); - std::string userstr = "unknown"; - if (user != nullptr) { - userstr = std::string(user); - } - const fs::path path = 
fs::temp_directory_path() / "test.mdl"; - - int N = 5, H = 8, B = 1, T = 10, U = 5, maxoutputlen = 100, nAttnRound = 2; - - std::vector> attentions( - nAttnRound, std::make_shared()); - - auto seq2seq = std::make_shared( - N, - H, - N - 2, - N - 1, - maxoutputlen, - attentions, - std::make_shared(2, 3), - false, - 100, - 0.0, - false, - kRandSampling, - 1.0, - 2, // nRnnLayer - nAttnRound, - 0.0); - seq2seq->eval(); - - auto input = noGrad(fl::randn({H, T, B}, fl::dtype::f32)); - auto target = noGrad( - (fl::rand({U, B}, fl::dtype::f32) * 0.99 * N).astype(fl::dtype::s32)); - - Variable output, attention; - std::tie(output, attention) = - seq2seq->decoder(input, target, Tensor(), Tensor()); - - save(path, seq2seq); - - std::shared_ptr loaded; - load(path, loaded); - loaded->eval(); - - Variable outputl, attentionl; - std::tie(outputl, attentionl) = - loaded->decoder(input, target, Tensor(), Tensor()); - - ASSERT_TRUE(allParamsClose(*loaded, *seq2seq)); - ASSERT_TRUE(allClose(outputl, output)); - ASSERT_TRUE(allClose(attentionl, attention)); + auto input = noGrad(fl::randn({H, T, B}, fl::dtype::f32)); + auto target = noGrad( + (fl::rand({U, B}, fl::dtype::f32) * 0.99 * N).astype(fl::dtype::s32) + ); + + Variable output, attention; + std::tie(output, attention) = + seq2seq.decoder(input, target, Tensor(), Tensor()); + ASSERT_EQ(attention.shape(), Shape({U* nHead, T, B})); } -TEST(Seq2SeqTest, BatchedDecoderStep) { - int N = 5, H = 8, B = 10, T = 20, maxoutputlen = 100; - int nRnnLayer = 2, nAttnRound = 2; - std::vector> contentAttentions( - nAttnRound, std::make_shared()); - std::vector> neuralContentAttentions( - nAttnRound, std::make_shared(H)); - - std::vector criterions{ - Seq2SeqCriterion( - N, - H, - N - 2, - N - 1, - maxoutputlen, - contentAttentions, - nullptr, - false, - 100, - 0.0, - false, - kRandSampling, - 1.0, - nRnnLayer, - nAttnRound, - 0.0), - Seq2SeqCriterion( - N, - H, - N - 2, - N - 1, - maxoutputlen, - neuralContentAttentions, - nullptr, - false, 
- 100, - 0.0, - false, - kRandSampling, - 1.0, - nRnnLayer, - nAttnRound, - 0.0)}; - - for (auto& seq2seq : criterions) { - seq2seq.eval(); - std::vector ys; - std::vector inStates(B, Seq2SeqState(nAttnRound)); - std::vector inStatePtrs(B); - - auto input = noGrad(fl::randn({H, T, 1}, fl::dtype::f32)); - std::vector> single_scores(B); - std::vector> batched_scores; - - for (int i = 0; i < B; i++) { - Variable y = constant(i % N, {1}, fl::dtype::s32, false); - ys.push_back(y); - - inStates[i].alpha = noGrad(fl::randn({1, T, 1}, fl::dtype::f32)); - for (int j = 0; j < nAttnRound; j++) { - inStates[i].hidden[j] = - noGrad(fl::randn({H, 1, nRnnLayer}, fl::dtype::f32)); - } - inStates[i].summary = noGrad(fl::randn({H, 1, 1}, fl::dtype::f32)); - inStatePtrs[i] = &inStates[i]; - - // Single forward - Seq2SeqState outstate(nAttnRound); - Variable ox; - std::tie(ox, outstate) = seq2seq.decodeStep( - input, y, inStates[i], Tensor(), Tensor(), input.dim(1)); - ox = logSoftmax(ox, 0); - single_scores[i] = ox.tensor().toHostVector(); +TEST(Seq2SeqTest, Serialization) { + char* user = getenv("USER"); + std::string userstr = "unknown"; + if(user != nullptr) { + userstr = std::string(user); } + const fs::path path = fs::temp_directory_path() / "test.mdl"; + + int N = 5, H = 8, B = 1, T = 10, U = 5, maxoutputlen = 100, nAttnRound = 2; + + std::vector> attentions( + nAttnRound, std::make_shared()); + + auto seq2seq = std::make_shared( + N, + H, + N - 2, + N - 1, + maxoutputlen, + attentions, + std::make_shared(2, 3), + false, + 100, + 0.0, + false, + kRandSampling, + 1.0, + 2, // nRnnLayer + nAttnRound, + 0.0 + ); + seq2seq->eval(); + + auto input = noGrad(fl::randn({H, T, B}, fl::dtype::f32)); + auto target = noGrad( + (fl::rand({U, B}, fl::dtype::f32) * 0.99 * N).astype(fl::dtype::s32) + ); - // Batched forward - std::vector outstates; - std::tie(batched_scores, outstates) = - seq2seq.decodeBatchStep(input, ys, inStatePtrs); + Variable output, attention; + std::tie(output, 
attention) = + seq2seq->decoder(input, target, Tensor(), Tensor()); - // Check - for (int i = 0; i < B; i++) { - for (int j = 0; j < N; j++) { - ASSERT_NEAR(single_scores[i][j], batched_scores[i][j], 1e-5); - } + save(path, seq2seq); + + std::shared_ptr loaded; + load(path, loaded); + loaded->eval(); + + Variable outputl, attentionl; + std::tie(outputl, attentionl) = + loaded->decoder(input, target, Tensor(), Tensor()); + + ASSERT_TRUE(allParamsClose(*loaded, *seq2seq)); + ASSERT_TRUE(allClose(outputl, output)); + ASSERT_TRUE(allClose(attentionl, attention)); +} + +TEST(Seq2SeqTest, BatchedDecoderStep) { + int N = 5, H = 8, B = 10, T = 20, maxoutputlen = 100; + int nRnnLayer = 2, nAttnRound = 2; + std::vector> contentAttentions( + nAttnRound, std::make_shared()); + std::vector> neuralContentAttentions( + nAttnRound, std::make_shared(H)); + + std::vector criterions{ + Seq2SeqCriterion( + N, + H, + N - 2, + N - 1, + maxoutputlen, + contentAttentions, + nullptr, + false, + 100, + 0.0, + false, + kRandSampling, + 1.0, + nRnnLayer, + nAttnRound, + 0.0 + ), + Seq2SeqCriterion( + N, + H, + N - 2, + N - 1, + maxoutputlen, + neuralContentAttentions, + nullptr, + false, + 100, + 0.0, + false, + kRandSampling, + 1.0, + nRnnLayer, + nAttnRound, + 0.0 + )}; + + for(auto& seq2seq : criterions) { + seq2seq.eval(); + std::vector ys; + std::vector inStates(B, Seq2SeqState(nAttnRound)); + std::vector inStatePtrs(B); + + auto input = noGrad(fl::randn({H, T, 1}, fl::dtype::f32)); + std::vector> single_scores(B); + std::vector> batched_scores; + + for(int i = 0; i < B; i++) { + Variable y = constant(i % N, {1}, fl::dtype::s32, false); + ys.push_back(y); + + inStates[i].alpha = noGrad(fl::randn({1, T, 1}, fl::dtype::f32)); + for(int j = 0; j < nAttnRound; j++) { + inStates[i].hidden[j] = + noGrad(fl::randn({H, 1, nRnnLayer}, fl::dtype::f32)); + } + inStates[i].summary = noGrad(fl::randn({H, 1, 1}, fl::dtype::f32)); + inStatePtrs[i] = &inStates[i]; + + // Single forward + Seq2SeqState 
outstate(nAttnRound); + Variable ox; + std::tie(ox, outstate) = seq2seq.decodeStep( + input, + y, + inStates[i], + Tensor(), + Tensor(), + input.dim(1) + ); + ox = logSoftmax(ox, 0); + single_scores[i] = ox.tensor().toHostVector(); + } + + // Batched forward + std::vector outstates; + std::tie(batched_scores, outstates) = + seq2seq.decodeBatchStep(input, ys, inStatePtrs); + + // Check + for(int i = 0; i < B; i++) { + for(int j = 0; j < N; j++) { + ASSERT_NEAR(single_scores[i][j], batched_scores[i][j], 1e-5); + } + } } - } } TEST(Seq2SeqTest, Seq2SeqSampling) { - int N = 5, H = 8, B = 1, T = 10, U = 5, maxoutputlen = 100; - auto input = noGrad(fl::randn({H, T, B}, fl::dtype::f32)); - auto target = noGrad( - (fl::rand({U, B}, fl::dtype::f32) * 0.99 * N).astype(fl::dtype::s32)); - - std::vector samplingStrategy({kRandSampling, kModelSampling}); + int N = 5, H = 8, B = 1, T = 10, U = 5, maxoutputlen = 100; + auto input = noGrad(fl::randn({H, T, B}, fl::dtype::f32)); + auto target = noGrad( + (fl::rand({U, B}, fl::dtype::f32) * 0.99 * N).astype(fl::dtype::s32) + ); + + std::vector samplingStrategy({kRandSampling, kModelSampling}); + + for(const auto& ss : samplingStrategy) { + Seq2SeqCriterion seq2seq( + N, + H, + N - 2, + N - 1, + maxoutputlen, + {std::make_shared()}, + nullptr, + false, + 0, + 0.05, + false, + ss); + seq2seq.train(); + + Variable output, attention; + std::tie(output, attention) = + seq2seq.decoder(input, target, Tensor(), Tensor()); + ASSERT_EQ(attention.shape(), Shape({U, T, B})); + ASSERT_EQ(output.shape(), Shape({N, U, B})); + } - for (const auto& ss : samplingStrategy) { - Seq2SeqCriterion seq2seq( + Seq2SeqCriterion seq2seq1( N, H, N - 2, @@ -443,61 +486,40 @@ TEST(Seq2SeqTest, Seq2SeqSampling) { {std::make_shared()}, nullptr, false, - 0, + 60, 0.05, false, - ss); - seq2seq.train(); + kRandSampling); + seq2seq1.train(); Variable output, attention; std::tie(output, attention) = - seq2seq.decoder(input, target, Tensor(), Tensor()); + 
seq2seq1.vectorizedDecoder(input, target, Tensor(), Tensor()); ASSERT_EQ(attention.shape(), Shape({U, T, B})); ASSERT_EQ(output.shape(), Shape({N, U, B})); - } - - Seq2SeqCriterion seq2seq1( - N, - H, - N - 2, - N - 1, - maxoutputlen, - {std::make_shared()}, - nullptr, - false, - 60, - 0.05, - false, - kRandSampling); - seq2seq1.train(); - - Variable output, attention; - std::tie(output, attention) = - seq2seq1.vectorizedDecoder(input, target, Tensor(), Tensor()); - ASSERT_EQ(attention.shape(), Shape({U, T, B})); - ASSERT_EQ(output.shape(), Shape({N, U, B})); - - Seq2SeqCriterion seq2seq2( - N, - H, - N - 2, - N - 1, - maxoutputlen, - {std::make_shared()}, - nullptr, - false, - 60, - 0.05, - false, - kModelSampling); - seq2seq2.train(); - ASSERT_THROW( - seq2seq2.vectorizedDecoder(input, target, Tensor(), Tensor()), - std::logic_error); + + Seq2SeqCriterion seq2seq2( + N, + H, + N - 2, + N - 1, + maxoutputlen, + {std::make_shared()}, + nullptr, + false, + 60, + 0.05, + false, + kModelSampling); + seq2seq2.train(); + ASSERT_THROW( + seq2seq2.vectorizedDecoder(input, target, Tensor(), Tensor()), + std::logic_error + ); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/criterion/attention/AttentionTest.cpp b/flashlight/pkg/speech/test/criterion/attention/AttentionTest.cpp index 1bddd6b..be1345c 100644 --- a/flashlight/pkg/speech/test/criterion/attention/AttentionTest.cpp +++ b/flashlight/pkg/speech/test/criterion/attention/AttentionTest.cpp @@ -17,225 +17,250 @@ using namespace fl; using namespace fl::pkg::speech; namespace { -using JacobianFunc = std::function; +using JacobianFunc = std::function; bool jacobianTestImpl( const JacobianFunc& func, Variable& input, float precision = 1E-5, float perturbation = 1E-4, const std::vector& zeroGradientVariables = {}) { - auto fwdJacobian 
= - Tensor({func(input).elements(), input.elements()}, fl::dtype::f32); - - for (int i = 0; i < input.elements(); ++i) { - Tensor orig = input.tensor().flatten()(i); - input.tensor().flat(i) = orig - perturbation; - auto outa = func(input).tensor(); - - input.tensor().flat(i) = orig + perturbation; - auto outb = func(input).tensor(); - input.tensor().flat(i) = orig; - - fwdJacobian(fl::span, i) = - fl::reshape((outb - outa), {static_cast(outa.elements())}) * 0.5 / - perturbation; - } - - auto bwdJacobian = - Tensor({func(input).elements(), input.elements()}, fl::dtype::f32); - auto dout = - Variable(fl::full(func(input).shape(), 0, func(input).type()), false); - - for (int i = 0; i < dout.elements(); ++i) { - dout.tensor().flat(i) = 1; // element in 1D view - input.zeroGrad(); - for (auto* var : zeroGradientVariables) { - var->zeroGrad(); + auto fwdJacobian = + Tensor({func(input).elements(), input.elements()}, fl::dtype::f32); + + for(int i = 0; i < input.elements(); ++i) { + Tensor orig = input.tensor().flatten()(i); + input.tensor().flat(i) = orig - perturbation; + auto outa = func(input).tensor(); + + input.tensor().flat(i) = orig + perturbation; + auto outb = func(input).tensor(); + input.tensor().flat(i) = orig; + + fwdJacobian(fl::span, i) = + fl::reshape((outb - outa), {static_cast(outa.elements())}) * 0.5 + / perturbation; } - auto out = func(input); - out.backward(dout); - bwdJacobian(i) = fl::reshape(input.grad().tensor(), {input.elements()}); - dout.tensor().flat(i) = 0; - } - return allClose(fwdJacobian, bwdJacobian, precision); + auto bwdJacobian = + Tensor({func(input).elements(), input.elements()}, fl::dtype::f32); + auto dout = + Variable(fl::full(func(input).shape(), 0, func(input).type()), false); + + for(int i = 0; i < dout.elements(); ++i) { + dout.tensor().flat(i) = 1; // element in 1D view + input.zeroGrad(); + for(auto* var : zeroGradientVariables) { + var->zeroGrad(); + } + auto out = func(input); + out.backward(dout); + + bwdJacobian(i) = 
fl::reshape(input.grad().tensor(), {input.elements()}); + dout.tensor().flat(i) = 0; + } + return allClose(fwdJacobian, bwdJacobian, precision); } void sequentialTest(std::shared_ptr attention, int H) { - int B = 2, T = 10; - - Variable encodedx(fl::randn({H, T, B}), true); - Variable encodedy(fl::randn({H, 1, B}), true); - - Variable alphas, summaries; - for (int step = 0; step < 3; ++step) { - std::tie(alphas, summaries) = - attention->forward(encodedy, encodedx, alphas); - ASSERT_EQ(alphas.shape(), Shape({1, T, B})); - ASSERT_EQ(summaries.shape(), Shape({H, 1, B})); - - auto alphasum = fl::sum(alphas.tensor(), {1}); - auto ones = fl::full(alphasum.shape(), 1.0, alphasum.type()); - ASSERT_TRUE(allClose(alphasum, ones, 1e-5)); - } - - Variable windowMask = Variable(fl::full({1, T, B}, 0.), false); - auto alphas1 = - std::get<0>(attention->forward(encodedy, encodedx, alphas, windowMask)); - auto alphas2 = std::get<0>(attention->forward(encodedy, encodedx, alphas)); - ASSERT_TRUE(allClose(alphas1, alphas2, 1e-6)); - - Variable encodedyInvalid(fl::randn({H, 10, B}), true); - EXPECT_THROW( - attention->forward(encodedyInvalid, encodedx, alphas), - std::invalid_argument); + int B = 2, T = 10; + + Variable encodedx(fl::randn({H, T, B}), true); + Variable encodedy(fl::randn({H, 1, B}), true); + + Variable alphas, summaries; + for(int step = 0; step < 3; ++step) { + std::tie(alphas, summaries) = + attention->forward(encodedy, encodedx, alphas); + ASSERT_EQ(alphas.shape(), Shape({1, T, B})); + ASSERT_EQ(summaries.shape(), Shape({H, 1, B})); + + auto alphasum = fl::sum(alphas.tensor(), {1}); + auto ones = fl::full(alphasum.shape(), 1.0, alphasum.type()); + ASSERT_TRUE(allClose(alphasum, ones, 1e-5)); + } + + Variable windowMask = Variable(fl::full({1, T, B}, 0.), false); + auto alphas1 = + std::get<0>(attention->forward(encodedy, encodedx, alphas, windowMask)); + auto alphas2 = std::get<0>(attention->forward(encodedy, encodedx, alphas)); + ASSERT_TRUE(allClose(alphas1, 
alphas2, 1e-6)); + + Variable encodedyInvalid(fl::randn({H, 10, B}), true); + EXPECT_THROW( + attention->forward(encodedyInvalid, encodedx, alphas), + std::invalid_argument + ); } void sequentialTestWithPad(std::shared_ptr attention, int H) { - int B = 2, T = 10; - - Variable encodedx(fl::randn({H, T, B}), true); - std::vector padRaw = {T / 2, T}; - Variable pad = Variable(Tensor::fromVector({1, B}, padRaw), false); - Variable encodedy(fl::randn({H, 1, B}), true); - - Variable alphas, summaries; - for (int step = 0; step < 3; ++step) { - std::tie(alphas, summaries) = - attention->forward(encodedy, encodedx, alphas, Variable(), pad); - ASSERT_EQ(alphas.shape(), Shape({1, T, B})); - ASSERT_EQ(summaries.shape(), Shape({H, 1, B})); - - auto alphasum = fl::sum(alphas.tensor(), {1}); - auto ones = fl::full(alphasum.shape(), 1.0, alphasum.type()); - ASSERT_TRUE(allClose(alphasum, ones, 1e-5)); - ASSERT_EQ( - fl::countNonzero( - alphas.tensor()(fl::span, fl::range(T - T / 2, T), 0) == 0) - .scalar(), - T / 2); - } - - Variable windowMask = Variable(fl::full({1, T, B}, 0.0), false); - auto alphas1 = std::get<0>( - attention->forward(encodedy, encodedx, alphas, windowMask, pad)); - auto alphas2 = std::get<0>( - attention->forward(encodedy, encodedx, alphas, Variable{}, pad)); - ASSERT_TRUE(allClose(alphas1, alphas2, 1e-6)); - - Variable encodedyInvalid(fl::randn({H, 10, B}), true); - EXPECT_THROW( - attention->forward(encodedyInvalid, encodedx, alphas), - std::invalid_argument); + int B = 2, T = 10; + + Variable encodedx(fl::randn({H, T, B}), true); + std::vector padRaw = {T / 2, T}; + Variable pad = Variable(Tensor::fromVector({1, B}, padRaw), false); + Variable encodedy(fl::randn({H, 1, B}), true); + + Variable alphas, summaries; + for(int step = 0; step < 3; ++step) { + std::tie(alphas, summaries) = + attention->forward(encodedy, encodedx, alphas, Variable(), pad); + ASSERT_EQ(alphas.shape(), Shape({1, T, B})); + ASSERT_EQ(summaries.shape(), Shape({H, 1, B})); + + auto 
alphasum = fl::sum(alphas.tensor(), {1}); + auto ones = fl::full(alphasum.shape(), 1.0, alphasum.type()); + ASSERT_TRUE(allClose(alphasum, ones, 1e-5)); + ASSERT_EQ( + fl::countNonzero( + alphas.tensor()(fl::span, fl::range(T - T / 2, T), 0) == 0) + .scalar(), + T / 2); + } + + Variable windowMask = Variable(fl::full({1, T, B}, 0.0), false); + auto alphas1 = std::get<0>( + attention->forward(encodedy, encodedx, alphas, windowMask, pad) + ); + auto alphas2 = std::get<0>( + attention->forward(encodedy, encodedx, alphas, Variable{}, pad) + ); + ASSERT_TRUE(allClose(alphas1, alphas2, 1e-6)); + + Variable encodedyInvalid(fl::randn({H, 10, B}), true); + EXPECT_THROW( + attention->forward(encodedyInvalid, encodedx, alphas), + std::invalid_argument + ); } } // namespace TEST(AttentionTest, NeuralContentAttention) { - int H = 8, B = 2, T = 10, U = 5; - NeuralContentAttention attention(H); + int H = 8, B = 2, T = 10, U = 5; + NeuralContentAttention attention(H); - Variable encodedx(fl::randn({H, T, B}), true); - Variable encodedy(fl::randn({H, U, B}), true); + Variable encodedx(fl::randn({H, T, B}), true); + Variable encodedy(fl::randn({H, U, B}), true); - std::vector padRaw = {T / 2, T}; - Variable pad = Variable(Tensor::fromVector(Shape({1, B}), padRaw), false); + std::vector padRaw = {T / 2, T}; + Variable pad = Variable(Tensor::fromVector(Shape({1, B}), padRaw), false); - std::vector padV = {Variable(), pad}; - for (const auto& currentPad : padV) { - Variable alphas, summaries; - std::tie(alphas, summaries) = attention.forward( - encodedy, encodedx, Variable{}, Variable{}, currentPad); - ASSERT_EQ(alphas.shape(), Shape({U, T, B})); - ASSERT_EQ(summaries.shape(), Shape({H, U, B})); - if (!currentPad.isEmpty()) { - ASSERT_EQ( - fl::countNonzero( - alphas.tensor()(fl::span, fl::range(T - T / 2, T), 0) == 0) - .scalar(), - T / 2 * U); + std::vector padV = {Variable(), pad}; + for(const auto& currentPad : padV) { + Variable alphas, summaries; + std::tie(alphas, summaries) = 
attention.forward( + encodedy, + encodedx, + Variable{}, + Variable{}, + currentPad + ); + ASSERT_EQ(alphas.shape(), Shape({U, T, B})); + ASSERT_EQ(summaries.shape(), Shape({H, U, B})); + if(!currentPad.isEmpty()) { + ASSERT_EQ( + fl::countNonzero( + alphas.tensor()(fl::span, fl::range(T - T / 2, T), 0) == 0) + .scalar(), + T / 2 * U); + } + auto alphasum = sum(alphas.tensor(), {1}); + auto ones = fl::full(alphasum.shape(), 1.0, alphasum.type()); + ASSERT_TRUE(allClose(alphasum, ones, 1e-5)); + + Variable windowMask = Variable(fl::full({U, T, B}, 0.0), false); + auto alphas1 = std::get<0>( + attention.forward( + encodedy, + encodedx, + Variable{}, + windowMask, + currentPad + ) + ); + ASSERT_TRUE(allClose(alphas, alphas1, 1e-6)); } - auto alphasum = sum(alphas.tensor(), {1}); - auto ones = fl::full(alphasum.shape(), 1.0, alphasum.type()); - ASSERT_TRUE(allClose(alphasum, ones, 1e-5)); - - Variable windowMask = Variable(fl::full({U, T, B}, 0.0), false); - auto alphas1 = std::get<0>(attention.forward( - encodedy, encodedx, Variable{}, windowMask, currentPad)); - ASSERT_TRUE(allClose(alphas, alphas1, 1e-6)); - } } TEST(AttentionTest, SimpleLocationAttention) { - int H = 8, K = 5; - sequentialTest(std::make_shared(K), H); - sequentialTestWithPad(std::make_shared(K), H); + int H = 8, K = 5; + sequentialTest(std::make_shared(K), H); + sequentialTestWithPad(std::make_shared(K), H); } TEST(AttentionTest, LocationAttention) { - int H = 8, K = 5; - sequentialTest(std::make_shared(H, K), H); - sequentialTestWithPad(std::make_shared(H, K), H); + int H = 8, K = 5; + sequentialTest(std::make_shared(H, K), H); + sequentialTestWithPad(std::make_shared(H, K), H); } TEST(AttentionTest, NeuralLocationAttention) { - int H = 8, A = 8, C = 5, K = 3; - sequentialTest(std::make_shared(H, A, C, K), H); - sequentialTestWithPad( - std::make_shared(H, A, C, K), H); + int H = 8, A = 8, C = 5, K = 3; + sequentialTest(std::make_shared(H, A, C, K), H); + sequentialTestWithPad( + 
std::make_shared(H, A, C, K), + H + ); } TEST(AttentionTest, MultiHeadContentAttention) { - int H = 512, B = 2, T = 10, U = 5, NH = 8; - - std::vector padRaw = {T / 2, T}; - Variable pad = Variable(Tensor::fromVector({1, B}, padRaw), false); - - std::vector padV = {Variable(), pad}; - for (const auto& currentPad : padV) { - for (bool keyValue : {true, false}) { - for (bool splitInput : {true, false}) { - MultiHeadContentAttention attention(H, NH, keyValue, splitInput); - - auto hEncode = keyValue ? H * 2 : H; - Variable encodedx(fl::randn({hEncode, T, B}), true); - Variable encodedy(fl::randn({H, U, B}), true); - - Variable alphas, summaries; - std::tie(alphas, summaries) = attention.forward( - encodedy, encodedx, Variable{}, Variable{}, currentPad); - ASSERT_EQ(alphas.shape(), Shape({U * NH, T, B})); - ASSERT_EQ(summaries.shape(), Shape({H, U, B})); - if (!currentPad.isEmpty()) { - ASSERT_EQ( - fl::countNonzero( - alphas.tensor()(fl::span, fl::range(T - T / 2, T), 0) == 0) - .scalar(), - T / 2 * U * NH); + int H = 512, B = 2, T = 10, U = 5, NH = 8; + + std::vector padRaw = {T / 2, T}; + Variable pad = Variable(Tensor::fromVector({1, B}, padRaw), false); + + std::vector padV = {Variable(), pad}; + for(const auto& currentPad : padV) { + for(bool keyValue : {true, false}) { + for(bool splitInput : {true, false}) { + MultiHeadContentAttention attention(H, NH, keyValue, splitInput); + + auto hEncode = keyValue ? 
H * 2 : H; + Variable encodedx(fl::randn({hEncode, T, B}), true); + Variable encodedy(fl::randn({H, U, B}), true); + + Variable alphas, summaries; + std::tie(alphas, summaries) = attention.forward( + encodedy, + encodedx, + Variable{}, + Variable{}, + currentPad + ); + ASSERT_EQ(alphas.shape(), Shape({U* NH, T, B})); + ASSERT_EQ(summaries.shape(), Shape({H, U, B})); + if(!currentPad.isEmpty()) { + ASSERT_EQ( + fl::countNonzero( + alphas.tensor()(fl::span, fl::range(T - T / 2, T), 0) == 0) + .scalar(), + T / 2 * U * NH); + } + + auto alphasum = sum(alphas.tensor(), {1}); + auto ones = fl::full(alphasum.shape(), 1.0, alphasum.type()); + ASSERT_TRUE(allClose(alphasum, ones, 1e-5)); + } } - - auto alphasum = sum(alphas.tensor(), {1}); - auto ones = fl::full(alphasum.shape(), 1.0, alphasum.type()); - ASSERT_TRUE(allClose(alphasum, ones, 1e-5)); - } } - } } TEST(AttentionTest, JacobianMaskAttention) { - // CxTxB - auto in = Variable(fl::rand({10, 9, 5}, fl::dtype::f32), true); - std::vector inpSzRaw = {1, 2, 4, 8, 16}; - Tensor inpSz = Tensor::fromVector( - {1, static_cast(inpSzRaw.size())}, inpSzRaw); - auto func_in = [&](Variable& input) { - return fl::pkg::speech::maskAttention(input, fl::Variable(inpSz, false)); - }; - ASSERT_TRUE(jacobianTestImpl(func_in, in, 2e-4)); + // CxTxB + auto in = Variable(fl::rand({10, 9, 5}, fl::dtype::f32), true); + std::vector inpSzRaw = {1, 2, 4, 8, 16}; + Tensor inpSz = Tensor::fromVector( + {1, static_cast(inpSzRaw.size())}, + inpSzRaw + ); + auto func_in = [&](Variable& input) { + return fl::pkg::speech::maskAttention(input, fl::Variable(inpSz, false)); + }; + ASSERT_TRUE(jacobianTestImpl(func_in, in, 2e-4)); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/criterion/attention/WindowTest.cpp 
b/flashlight/pkg/speech/test/criterion/attention/WindowTest.cpp index 80489db..f39a0a0 100644 --- a/flashlight/pkg/speech/test/criterion/attention/WindowTest.cpp +++ b/flashlight/pkg/speech/test/criterion/attention/WindowTest.cpp @@ -16,280 +16,366 @@ using namespace fl; using namespace fl::pkg::speech; TEST(WindowTest, MedianWindow) { - int inputsteps = 12; - int batchsize = 4; - int hiddendim = 16; - int wl = 2; - int wr = 3; - auto inputAttnArray = - fl::abs(fl::randn({1, wl + wr, batchsize}, fl::dtype::f32)); - auto inputAttn = Variable( - inputAttnArray / - fl::tile( - sum(inputAttnArray, {1}, /* keepDims = */ true), - {1, inputAttnArray.dim(1)}), - false); - - MedianWindow window(wl, wr); - - // check initialization - auto mask0 = window.computeWindow(inputAttn, 0, -1, inputsteps, batchsize); - - auto trueSumMask0 = fl::full({1, inputsteps, batchsize}, 0.0, fl::dtype::f32); - trueSumMask0(fl::span, fl::range(0, wl + wr), fl::span) = 1.0; - - ASSERT_EQ(mask0.shape(), Shape({1, inputsteps, batchsize})); - ASSERT_TRUE(allClose(fl::exp(mask0.tensor()), trueSumMask0)); - - // check next step - auto mask1 = - window.computeWindow(inputAttn, 1, inputsteps, inputsteps, batchsize); - ASSERT_EQ(mask1.shape(), Shape({1, inputsteps, batchsize})); - - // make sure large window size is handled - MedianWindow largeWindow(100, 100); - auto maskLarge = largeWindow.computeWindow( - inputAttn, 0, inputsteps, inputsteps, batchsize); - trueSumMask0 = fl::full({1, inputsteps, batchsize}, 0, fl::dtype::f32); - ASSERT_TRUE(allClose(maskLarge.tensor(), trueSumMask0)); + int inputsteps = 12; + int batchsize = 4; + int hiddendim = 16; + int wl = 2; + int wr = 3; + auto inputAttnArray = + fl::abs(fl::randn({1, wl + wr, batchsize}, fl::dtype::f32)); + auto inputAttn = Variable( + inputAttnArray + / fl::tile( + sum(inputAttnArray, {1}, /* keepDims = */ true), + {1, inputAttnArray.dim(1)} + ), + false + ); + + MedianWindow window(wl, wr); + + // check initialization + auto mask0 = 
window.computeWindow(inputAttn, 0, -1, inputsteps, batchsize); + + auto trueSumMask0 = fl::full({1, inputsteps, batchsize}, 0.0, fl::dtype::f32); + trueSumMask0(fl::span, fl::range(0, wl + wr), fl::span) = 1.0; + + ASSERT_EQ(mask0.shape(), Shape({1, inputsteps, batchsize})); + ASSERT_TRUE(allClose(fl::exp(mask0.tensor()), trueSumMask0)); + + // check next step + auto mask1 = + window.computeWindow(inputAttn, 1, inputsteps, inputsteps, batchsize); + ASSERT_EQ(mask1.shape(), Shape({1, inputsteps, batchsize})); + + // make sure large window size is handled + MedianWindow largeWindow(100, 100); + auto maskLarge = largeWindow.computeWindow( + inputAttn, + 0, + inputsteps, + inputsteps, + batchsize + ); + trueSumMask0 = fl::full({1, inputsteps, batchsize}, 0, fl::dtype::f32); + ASSERT_TRUE(allClose(maskLarge.tensor(), trueSumMask0)); } TEST(WindowTest, MedianWindowWithPad) { - int inputsteps = 12; - int batchsize = 2; - int wl = 3; - int wr = 5; - auto inputAttnArray = - fl::abs(fl::randn({1, wl + wr, batchsize}, fl::dtype::f32)); - auto inputAttn = Variable( - inputAttnArray / - fl::tile( - sum(inputAttnArray, {1}, /* keepDims = */ true), - {1, inputAttnArray.dim(1)}), - false); - - MedianWindow window(wl, wr); - std::vector inpSzRaw = {1, 2}; - Tensor inpSz = Tensor::fromVector({1, batchsize}, inpSzRaw); - std::vector tgSzRaw = {1, 2}; - Tensor tgSz = Tensor::fromVector({1, batchsize}, tgSzRaw); - - // check initialization - auto mask0 = window.computeWindow( - inputAttn, 0, -1, inputsteps, batchsize, inpSz, tgSz); - - auto trueSumMask0 = fl::full({1, inputsteps, batchsize}, 0.0, fl::dtype::f32); - trueSumMask0(fl::span, fl::range(0, wl + wr), fl::span) = 1.0; - trueSumMask0(fl::span, fl::range(inputsteps / 2, inputsteps), 0) = 0.0; - - ASSERT_EQ(mask0.shape(), Shape({1, inputsteps, batchsize})); - ASSERT_TRUE(allClose(fl::exp(mask0.tensor()), trueSumMask0)); - - // check next step - auto mask2 = - window.computeWindow(inputAttn, 2, 2, inputsteps, batchsize, inpSz, 
tgSz); - ASSERT_EQ(mask2.shape(), Shape({1, inputsteps, batchsize})); - ASSERT_TRUE( - fl::countNonzero( - fl::exp(mask2.tensor())( - 0, fl::range(inputsteps - inputsteps / 2, inputsteps), 0) == 0) - .scalar() == inputsteps / 2); - ASSERT_TRUE( - fl::countNonzero(fl::exp(mask2.tensor())(0, fl::span, 0) == 0) - .scalar() == inputsteps); + int inputsteps = 12; + int batchsize = 2; + int wl = 3; + int wr = 5; + auto inputAttnArray = + fl::abs(fl::randn({1, wl + wr, batchsize}, fl::dtype::f32)); + auto inputAttn = Variable( + inputAttnArray + / fl::tile( + sum(inputAttnArray, {1}, /* keepDims = */ true), + {1, inputAttnArray.dim(1)} + ), + false + ); + + MedianWindow window(wl, wr); + std::vector inpSzRaw = {1, 2}; + Tensor inpSz = Tensor::fromVector({1, batchsize}, inpSzRaw); + std::vector tgSzRaw = {1, 2}; + Tensor tgSz = Tensor::fromVector({1, batchsize}, tgSzRaw); + + // check initialization + auto mask0 = window.computeWindow( + inputAttn, + 0, + -1, + inputsteps, + batchsize, + inpSz, + tgSz + ); + + auto trueSumMask0 = fl::full({1, inputsteps, batchsize}, 0.0, fl::dtype::f32); + trueSumMask0(fl::span, fl::range(0, wl + wr), fl::span) = 1.0; + trueSumMask0(fl::span, fl::range(inputsteps / 2, inputsteps), 0) = 0.0; + + ASSERT_EQ(mask0.shape(), Shape({1, inputsteps, batchsize})); + ASSERT_TRUE(allClose(fl::exp(mask0.tensor()), trueSumMask0)); + + // check next step + auto mask2 = + window.computeWindow(inputAttn, 2, 2, inputsteps, batchsize, inpSz, tgSz); + ASSERT_EQ(mask2.shape(), Shape({1, inputsteps, batchsize})); + ASSERT_TRUE( + fl::countNonzero( + fl::exp(mask2.tensor())( + 0, + fl::range(inputsteps - inputsteps / 2, inputsteps), + 0 + ) == 0 + ) + .scalar() == inputsteps / 2 + ); + ASSERT_TRUE( + fl::countNonzero(fl::exp(mask2.tensor())(0, fl::span, 0) == 0) + .scalar() == inputsteps + ); } TEST(WindowTest, StepWindow) { - int inputsteps = 100; - int batchsize = 4; - int hiddendim = 16; - int targetlen = 30; - int sMin = 3, sMax = 15; - double vMin = 2.3, 
vMax = 7.5; - - Variable inputAttn; // dummy - std::vector windowBoundaries(2, 0); - - StepWindow window(sMin, sMax, vMin, vMax); - - // check initialization - auto mask0 = - window.computeWindow(inputAttn, 0, inputsteps, inputsteps, batchsize); - auto trueSumMask0 = fl::full({1, inputsteps, batchsize}, 0.0, fl::dtype::f32); - windowBoundaries[0] = sMin; - windowBoundaries[1] = sMax; - - trueSumMask0( - fl::span, fl::range(windowBoundaries[0], windowBoundaries[1]), fl::span) = - 1.0; - - ASSERT_EQ(mask0.shape(), Shape({1, inputsteps, batchsize})); - ASSERT_TRUE(allClose(fl::exp(mask0.tensor()), trueSumMask0)); - - auto mask1 = - window.computeWindow(inputAttn, 1, inputsteps, inputsteps, batchsize); - auto trueSumMask1 = fl::full({1, inputsteps, batchsize}, 0.0, fl::dtype::f32); - windowBoundaries[0] = static_cast(std::round(sMin + vMin)); - windowBoundaries[1] = static_cast(std::round(sMax + vMax)); - - trueSumMask1( - fl::span, fl::range(windowBoundaries[0], windowBoundaries[1]), fl::span) = - 1.0; - - ASSERT_EQ(mask1.shape(), Shape({1, inputsteps, batchsize})); - ASSERT_TRUE(allClose(fl::exp(mask1.tensor()), trueSumMask1)); - - auto maskLarge = - window.computeWindow(inputAttn, 1000, inputsteps, inputsteps, batchsize); - auto trueSumMaskLarge = - fl::full({1, inputsteps, batchsize}, 0.0, fl::dtype::f32); - windowBoundaries[0] = static_cast(std::round(inputsteps - vMax)); - windowBoundaries[1] = inputsteps; - - trueSumMaskLarge( - fl::span, fl::range(windowBoundaries[0], windowBoundaries[1]), fl::span) = - 1.0; - - ASSERT_EQ(maskLarge.shape(), Shape({1, inputsteps, batchsize})); - ASSERT_TRUE(allClose(fl::exp(maskLarge.tensor()), trueSumMaskLarge)); - - auto maskV = window.computeVectorizedWindow(targetlen, inputsteps, batchsize); - ASSERT_EQ(maskV.shape(), Shape({targetlen, inputsteps, batchsize})); - - std::vector inpSzRaw = {1, 2, 2, 2}; - Tensor inpSz = Tensor::fromVector({1, batchsize}, inpSzRaw); - std::vector tgSzRaw = {1, 2, 2, 2}; - Tensor tgSz = 
Tensor::fromVector({1, batchsize}, tgSzRaw); - - auto maskVPad = fl::exp(window - .computeVectorizedWindow( - targetlen, inputsteps, batchsize, inpSz, tgSz) - .tensor()); - ASSERT_EQ(maskVPad.shape(), Shape({targetlen, inputsteps, batchsize})); - ASSERT_TRUE( - fl::countNonzero( - maskVPad( - fl::span, - fl::range(inputsteps - inputsteps / 2, inputsteps), - 0) == 0) - .scalar() == inputsteps / 2 * targetlen); - ASSERT_TRUE( - fl::countNonzero( - maskVPad( - fl::range(targetlen - targetlen / 2, targetlen), fl::span, 0) == - 0) - .scalar() == targetlen / 2 * inputsteps); + int inputsteps = 100; + int batchsize = 4; + int hiddendim = 16; + int targetlen = 30; + int sMin = 3, sMax = 15; + double vMin = 2.3, vMax = 7.5; + + Variable inputAttn; // dummy + std::vector windowBoundaries(2, 0); + + StepWindow window(sMin, sMax, vMin, vMax); + + // check initialization + auto mask0 = + window.computeWindow(inputAttn, 0, inputsteps, inputsteps, batchsize); + auto trueSumMask0 = fl::full({1, inputsteps, batchsize}, 0.0, fl::dtype::f32); + windowBoundaries[0] = sMin; + windowBoundaries[1] = sMax; + + trueSumMask0( + fl::span, + fl::range(windowBoundaries[0], windowBoundaries[1]), + fl::span + ) = + 1.0; + + ASSERT_EQ(mask0.shape(), Shape({1, inputsteps, batchsize})); + ASSERT_TRUE(allClose(fl::exp(mask0.tensor()), trueSumMask0)); + + auto mask1 = + window.computeWindow(inputAttn, 1, inputsteps, inputsteps, batchsize); + auto trueSumMask1 = fl::full({1, inputsteps, batchsize}, 0.0, fl::dtype::f32); + windowBoundaries[0] = static_cast(std::round(sMin + vMin)); + windowBoundaries[1] = static_cast(std::round(sMax + vMax)); + + trueSumMask1( + fl::span, + fl::range(windowBoundaries[0], windowBoundaries[1]), + fl::span + ) = + 1.0; + + ASSERT_EQ(mask1.shape(), Shape({1, inputsteps, batchsize})); + ASSERT_TRUE(allClose(fl::exp(mask1.tensor()), trueSumMask1)); + + auto maskLarge = + window.computeWindow(inputAttn, 1000, inputsteps, inputsteps, batchsize); + auto trueSumMaskLarge = + 
fl::full({1, inputsteps, batchsize}, 0.0, fl::dtype::f32); + windowBoundaries[0] = static_cast(std::round(inputsteps - vMax)); + windowBoundaries[1] = inputsteps; + + trueSumMaskLarge( + fl::span, + fl::range(windowBoundaries[0], windowBoundaries[1]), + fl::span + ) = + 1.0; + + ASSERT_EQ(maskLarge.shape(), Shape({1, inputsteps, batchsize})); + ASSERT_TRUE(allClose(fl::exp(maskLarge.tensor()), trueSumMaskLarge)); + + auto maskV = window.computeVectorizedWindow(targetlen, inputsteps, batchsize); + ASSERT_EQ(maskV.shape(), Shape({targetlen, inputsteps, batchsize})); + + std::vector inpSzRaw = {1, 2, 2, 2}; + Tensor inpSz = Tensor::fromVector({1, batchsize}, inpSzRaw); + std::vector tgSzRaw = {1, 2, 2, 2}; + Tensor tgSz = Tensor::fromVector({1, batchsize}, tgSzRaw); + + auto maskVPad = fl::exp( + window + .computeVectorizedWindow( + targetlen, + inputsteps, + batchsize, + inpSz, + tgSz + ) + .tensor() + ); + ASSERT_EQ(maskVPad.shape(), Shape({targetlen, inputsteps, batchsize})); + ASSERT_TRUE( + fl::countNonzero( + maskVPad( + fl::span, + fl::range(inputsteps - inputsteps / 2, inputsteps), + 0 + ) == 0 + ) + .scalar() == inputsteps / 2 * targetlen + ); + ASSERT_TRUE( + fl::countNonzero( + maskVPad( + fl::range(targetlen - targetlen / 2, targetlen), + fl::span, + 0 + ) + == 0 + ) + .scalar() == targetlen / 2 * inputsteps + ); } TEST(WindowTest, SoftWindow) { - int inputsteps = 100; - int batchsize = 4; - int targetlen = 15; - int offset = 10; - double avgRate = 5.2, std = 5.0; - - Variable inputAttn; // dummy - SoftWindow window(std, avgRate, offset); - - auto mask0 = - window.computeWindow(inputAttn, 0, inputsteps, inputsteps, batchsize); - - Tensor maxv, maxidx; - max(maxv, maxidx, mask0.tensor(), 1, /* keepDims = */ true); - std::vector trueMaxidx(batchsize, offset); - - ASSERT_EQ(mask0.shape(), Shape({1, inputsteps, batchsize})); - ASSERT_TRUE(allClose( - maxidx.astype(fl::dtype::s32), - Tensor::fromVector({1, 1, batchsize}, trueMaxidx, fl::dtype::s32))); - - auto 
maskV = window.computeVectorizedWindow(targetlen, inputsteps, batchsize); - ASSERT_EQ(maskV.shape(), Shape({targetlen, inputsteps, batchsize})); - - std::vector inpSzRaw = {1, 2, 2, 2}; - Tensor inpSz = Tensor::fromVector({1, batchsize}, inpSzRaw); - std::vector tgSzRaw = {1, 2, 2, 2}; - Tensor tgSz = Tensor::fromVector({1, batchsize}, tgSzRaw); - - auto maskVPad = fl::exp(window - .computeVectorizedWindow( - targetlen, inputsteps, batchsize, inpSz, tgSz) - .tensor()); - ASSERT_EQ(maskVPad.shape(), Shape({targetlen, inputsteps, batchsize})); - ASSERT_TRUE( - fl::countNonzero( - maskVPad( - fl::span, - fl::range(inputsteps - inputsteps / 2, inputsteps), - 0) == 0) - .scalar() == inputsteps / 2 * targetlen); - ASSERT_TRUE( - fl::countNonzero( - maskVPad( - fl::range(targetlen - targetlen / 2, targetlen), fl::span, 0) == - 0) - .scalar() == targetlen / 2 * inputsteps); + int inputsteps = 100; + int batchsize = 4; + int targetlen = 15; + int offset = 10; + double avgRate = 5.2, std = 5.0; + + Variable inputAttn; // dummy + SoftWindow window(std, avgRate, offset); + + auto mask0 = + window.computeWindow(inputAttn, 0, inputsteps, inputsteps, batchsize); + + Tensor maxv, maxidx; + max(maxv, maxidx, mask0.tensor(), 1, /* keepDims = */ true); + std::vector trueMaxidx(batchsize, offset); + + ASSERT_EQ(mask0.shape(), Shape({1, inputsteps, batchsize})); + ASSERT_TRUE( + allClose( + maxidx.astype(fl::dtype::s32), + Tensor::fromVector({1, 1, batchsize}, trueMaxidx, fl::dtype::s32) + ) + ); + + auto maskV = window.computeVectorizedWindow(targetlen, inputsteps, batchsize); + ASSERT_EQ(maskV.shape(), Shape({targetlen, inputsteps, batchsize})); + + std::vector inpSzRaw = {1, 2, 2, 2}; + Tensor inpSz = Tensor::fromVector({1, batchsize}, inpSzRaw); + std::vector tgSzRaw = {1, 2, 2, 2}; + Tensor tgSz = Tensor::fromVector({1, batchsize}, tgSzRaw); + + auto maskVPad = fl::exp( + window + .computeVectorizedWindow( + targetlen, + inputsteps, + batchsize, + inpSz, + tgSz + ) + .tensor() + 
); + ASSERT_EQ(maskVPad.shape(), Shape({targetlen, inputsteps, batchsize})); + ASSERT_TRUE( + fl::countNonzero( + maskVPad( + fl::span, + fl::range(inputsteps - inputsteps / 2, inputsteps), + 0 + ) == 0 + ) + .scalar() == inputsteps / 2 * targetlen + ); + ASSERT_TRUE( + fl::countNonzero( + maskVPad( + fl::range(targetlen - targetlen / 2, targetlen), + fl::span, + 0 + ) + == 0 + ) + .scalar() == targetlen / 2 * inputsteps + ); } TEST(WindowTest, SoftPretrainWindow) { - int inputsteps = 32; - int targetlen = 8; - int batchsize = 4; - double std = 5.0; - - std::vector peaks = {0, 4, 8, 12, 16, 20, 24, 28}; - - Variable inputAttn; - SoftPretrainWindow window(std); - - // single step - std::vector masks; - for (int step = 0; step < targetlen; ++step) { - masks.emplace_back(window.computeWindow( - inputAttn, step, targetlen, inputsteps, batchsize)); - } - auto maskS = concatenate(masks, 0); - Tensor maxv, maxidx; - max(maxv, maxidx, maskS.tensor()(fl::span, fl::span, 0), 1); - - ASSERT_EQ(maskS.shape(), Shape({targetlen, inputsteps, batchsize})); - ASSERT_TRUE(allClose(maxidx, Tensor::fromVector({8}, peaks))); - - // vectorized - auto maskV = window.computeVectorizedWindow(targetlen, inputsteps, batchsize); - max(maxv, maxidx, maskV.tensor()(fl::span, fl::span, 0), 1); - - ASSERT_EQ(maskV.shape(), Shape({targetlen, inputsteps, batchsize})); - ASSERT_TRUE(allClose(maxidx, Tensor::fromVector({8}, peaks))); - ASSERT_TRUE(allClose(maskS, maskV)); - - std::vector inpSzRaw = {1, 2, 2, 2}; - Tensor inpSz = Tensor::fromVector({1, batchsize}, inpSzRaw); - std::vector tgSzRaw = {1, 2, 2, 2}; - Tensor tgSz = Tensor::fromVector({1, batchsize}, tgSzRaw); - - auto maskVPad = fl::exp(window - .computeVectorizedWindow( - targetlen, inputsteps, batchsize, inpSz, tgSz) - .tensor()); - ASSERT_EQ(maskVPad.shape(), Shape({targetlen, inputsteps, batchsize})); - ASSERT_TRUE( - fl::countNonzero( - maskVPad( - fl::span, - fl::range(inputsteps - inputsteps / 2, inputsteps), - 0) == 0) - .scalar() 
== inputsteps / 2 * targetlen); - ASSERT_TRUE( - fl::countNonzero( - maskVPad( - fl::range(targetlen - targetlen / 2, targetlen), fl::span, 0) == - 0) - .scalar() == targetlen / 2 * inputsteps); + int inputsteps = 32; + int targetlen = 8; + int batchsize = 4; + double std = 5.0; + + std::vector peaks = {0, 4, 8, 12, 16, 20, 24, 28}; + + Variable inputAttn; + SoftPretrainWindow window(std); + + // single step + std::vector masks; + for(int step = 0; step < targetlen; ++step) { + masks.emplace_back( + window.computeWindow( + inputAttn, + step, + targetlen, + inputsteps, + batchsize + ) + ); + } + auto maskS = concatenate(masks, 0); + Tensor maxv, maxidx; + max(maxv, maxidx, maskS.tensor()(fl::span, fl::span, 0), 1); + + ASSERT_EQ(maskS.shape(), Shape({targetlen, inputsteps, batchsize})); + ASSERT_TRUE(allClose(maxidx, Tensor::fromVector({8}, peaks))); + + // vectorized + auto maskV = window.computeVectorizedWindow(targetlen, inputsteps, batchsize); + max(maxv, maxidx, maskV.tensor()(fl::span, fl::span, 0), 1); + + ASSERT_EQ(maskV.shape(), Shape({targetlen, inputsteps, batchsize})); + ASSERT_TRUE(allClose(maxidx, Tensor::fromVector({8}, peaks))); + ASSERT_TRUE(allClose(maskS, maskV)); + + std::vector inpSzRaw = {1, 2, 2, 2}; + Tensor inpSz = Tensor::fromVector({1, batchsize}, inpSzRaw); + std::vector tgSzRaw = {1, 2, 2, 2}; + Tensor tgSz = Tensor::fromVector({1, batchsize}, tgSzRaw); + + auto maskVPad = fl::exp( + window + .computeVectorizedWindow( + targetlen, + inputsteps, + batchsize, + inpSz, + tgSz + ) + .tensor() + ); + ASSERT_EQ(maskVPad.shape(), Shape({targetlen, inputsteps, batchsize})); + ASSERT_TRUE( + fl::countNonzero( + maskVPad( + fl::span, + fl::range(inputsteps - inputsteps / 2, inputsteps), + 0 + ) == 0 + ) + .scalar() == inputsteps / 2 * targetlen + ); + ASSERT_TRUE( + fl::countNonzero( + maskVPad( + fl::range(targetlen - targetlen / 2, targetlen), + fl::span, + 0 + ) + == 0 + ) + .scalar() == targetlen / 2 * inputsteps + ); } int main(int argc, 
char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/data/FeaturizationTest.cpp b/flashlight/pkg/speech/test/data/FeaturizationTest.cpp index 8a49db0..902b986 100644 --- a/flashlight/pkg/speech/test/data/FeaturizationTest.cpp +++ b/flashlight/pkg/speech/test/data/FeaturizationTest.cpp @@ -25,438 +25,462 @@ using namespace fl::lib::text; using namespace fl::pkg::speech; namespace { -template +template bool compareVec(std::vector A, std::vector B, float precision = 1E-5) { - if (A.size() != B.size()) { - return false; - } - for (std::size_t i = 0; i < A.size(); ++i) { - if (std::abs(A[i] - B[i]) > precision) { - return false; + if(A.size() != B.size()) { + return false; } - } - return true; + for(std::size_t i = 0; i < A.size(); ++i) { + if(std::abs(A[i] - B[i]) > precision) { + return false; + } + } + return true; } Dictionary getDict() { - Dictionary dict; - std::string ltr = "a"; - int alphabet_sz = 26; - while (alphabet_sz--) { - dict.addEntry(ltr); - ltr[0] += 1; - } - dict.addEntry("|"); - dict.addEntry("'"); - dict.addEntry("L", dict.getIndex("|")); - dict.addEntry("N", dict.getIndex("|")); - return dict; + Dictionary dict; + std::string ltr = "a"; + int alphabet_sz = 26; + while(alphabet_sz--) { + dict.addEntry(ltr); + ltr[0] += 1; + } + dict.addEntry("|"); + dict.addEntry("'"); + dict.addEntry("L", dict.getIndex("|")); + dict.addEntry("N", dict.getIndex("|")); + return dict; } LexiconMap getLexicon() { - LexiconMap lexicon; - lexicon["uh"].push_back({"u", "h"}); - lexicon["oh"].push_back({"o", "h"}); - lexicon[kUnkToken] = {}; - return lexicon; + LexiconMap lexicon; + lexicon["uh"].push_back({"u", "h"}); + lexicon["oh"].push_back({"o", "h"}); + lexicon[kUnkToken] = {}; + return lexicon; } } // namespace TEST(FeaturizationTest, AfMatmulCompare) { - int numTests = 1000; - while (numTests--) { - 
int m = (rand() % 64) + 1; - int n = (rand() % 128) + 1; - int k = (rand() % 256) + 1; - // Note: Arrayfire is column major - Tensor a = fl::rand({k, m}); - Tensor b = fl::rand({n, k}); - Tensor c = fl::transpose( - fl::matmul(a, b, MatrixProperty::Transpose, MatrixProperty::Transpose)); - auto aVec = a.toHostVector(); - auto bVec = b.toHostVector(); - auto cVec = cblasGemm(aVec, bVec, n, k); - ASSERT_TRUE(compareVec(cVec, c.toHostVector(), 1E-4)); - } + int numTests = 1000; + while(numTests--) { + int m = (rand() % 64) + 1; + int n = (rand() % 128) + 1; + int k = (rand() % 256) + 1; + // Note: Arrayfire is column major + Tensor a = fl::rand({k, m}); + Tensor b = fl::rand({n, k}); + Tensor c = fl::transpose( + fl::matmul(a, b, MatrixProperty::Transpose, MatrixProperty::Transpose) + ); + auto aVec = a.toHostVector(); + auto bVec = b.toHostVector(); + auto cVec = cblasGemm(aVec, bVec, n, k); + ASSERT_TRUE(compareVec(cVec, c.toHostVector(), 1E-4)); + } } TEST(FeaturizationTest, Normalize) { - double threshold = 0.01; - auto afNormalize = [threshold](const Tensor& in, int batchdim) { - int64_t elementsPerBatch = in.elements() / in.dim(batchdim); - auto in2d = fl::reshape(in, {elementsPerBatch, in.dim(batchdim)}); - - Tensor meandiff = - (in2d - - fl::tile( - fl::mean(in2d, {0}, /* keepDims = */ true), {elementsPerBatch})); - - Tensor stddev = fl::std(in2d, {0}, /* keepDims = */ true); - stddev = fl::where(stddev > threshold, stddev, 1.0); - - return fl::reshape( - meandiff / fl::tile(stddev, {elementsPerBatch}), in.shape()); - }; - auto arr = fl::rand({13, 17, 19}); - auto arrVec = arr.toHostVector(); - - auto arrVecNrm = normalize(arrVec, 19, threshold); - auto arrNrm = - Tensor::fromBuffer(arr.shape(), arrVecNrm.data(), MemoryLocation::Host); - ASSERT_TRUE( - fl::all(fl::abs(arrNrm - afNormalize(arr, 2)) < 1E-5).asScalar()); + double threshold = 0.01; + auto afNormalize = [threshold](const Tensor& in, int batchdim) { + int64_t elementsPerBatch = in.elements() / 
in.dim(batchdim); + auto in2d = fl::reshape(in, {elementsPerBatch, in.dim(batchdim)}); + + Tensor meandiff = + (in2d + - fl::tile( + fl::mean(in2d, {0}, /* keepDims = */ true), + {elementsPerBatch} + )); + + Tensor stddev = fl::std(in2d, {0}, /* keepDims = */ true); + stddev = fl::where(stddev > threshold, stddev, 1.0); + + return fl::reshape( + meandiff / fl::tile(stddev, {elementsPerBatch}), + in.shape() + ); + }; + auto arr = fl::rand({13, 17, 19}); + auto arrVec = arr.toHostVector(); + + auto arrVecNrm = normalize(arrVec, 19, threshold); + auto arrNrm = + Tensor::fromBuffer(arr.shape(), arrVecNrm.data(), MemoryLocation::Host); + ASSERT_TRUE( + fl::all(fl::abs(arrNrm - afNormalize(arr, 2)) < 1E-5).asScalar() + ); } TEST(FeaturizationTest, Transpose) { - auto arr = fl::rand({13, 17, 19, 23}); - auto arrVec = arr.toHostVector(); - auto arrVecT = transpose2d(arrVec, 17, 13, 19 * 23); - auto arrT = Tensor::fromVector({17, 13, 19, 23}, arrVecT); - ASSERT_TRUE( - fl::all(arrT - fl::transpose(arr, {1, 0, 2, 3}) == 0.0).asScalar()); + auto arr = fl::rand({13, 17, 19, 23}); + auto arrVec = arr.toHostVector(); + auto arrVecT = transpose2d(arrVec, 17, 13, 19 * 23); + auto arrT = Tensor::fromVector({17, 13, 19, 23}, arrVecT); + ASSERT_TRUE( + fl::all(arrT - fl::transpose(arr, {1, 0, 2, 3}) == 0.0).asScalar() + ); } TEST(FeaturizationTest, localNormalize) { - auto afNormalize = [](const Tensor& in, int64_t lw, int64_t rw) { - auto out = in; - for (int64_t b = 0; b < in.dim(3); ++b) { - for (int64_t i = 0; i < in.dim(0); ++i) { - int64_t b_idx = (i - lw > 0) ? (i - lw) : 0; - int64_t e_idx = (in.dim(0) - 1 > i + rw) ? 
(i + rw) : (in.dim(0) - 1); - - Tensor slice = in(fl::range(b_idx, e_idx + 1), fl::span, fl::span, b); - auto mean = fl::mean(slice).scalar(); - auto stddev = fl::std(slice).scalar(); - - out(i, fl::span, fl::span, b) -= mean; - if (stddev > 0.0) { - out(i, fl::span, fl::span, b) /= stddev; - } - } + auto afNormalize = [](const Tensor& in, int64_t lw, int64_t rw) { + auto out = in; + for(int64_t b = 0; b < in.dim(3); ++b) { + for(int64_t i = 0; i < in.dim(0); ++i) { + int64_t b_idx = (i - lw > 0) ? (i - lw) : 0; + int64_t e_idx = (in.dim(0) - 1 > i + rw) ? (i + rw) : (in.dim(0) - 1); + + Tensor slice = in(fl::range(b_idx, e_idx + 1), fl::span, fl::span, b); + auto mean = fl::mean(slice).scalar(); + auto stddev = fl::std(slice).scalar(); + + out(i, fl::span, fl::span, b) -= mean; + if(stddev > 0.0) { + out(i, fl::span, fl::span, b) /= stddev; + } + } + } + return out; + }; + auto arr = fl::rand({47, 67, 2, 10}); // FRAMES X FEAT X CHANNELS X BATCHSIZE + auto arrVec = arr.toHostVector(); + + std::vector> ctx = { + {0, 0}, {1, 1}, {2, 2}, {4, 4}, {1024, 1024}, {10, 0}, {2, 12}}; + + for(auto c : ctx) { + auto arrVecNrm = localNormalize( + arrVec, + c.first /* context */, + c.second, + arr.dim(0) /* frames */, + arr.dim(3) /*batches */ + ); + auto arrNrm = + Tensor::fromBuffer(arr.shape(), arrVecNrm.data(), MemoryLocation::Host); + ASSERT_TRUE( + fl::all(fl::abs(arrNrm - afNormalize(arr, c.first, c.second)) < 1E-4) + .asScalar() + ); } - return out; - }; - auto arr = fl::rand({47, 67, 2, 10}); // FRAMES X FEAT X CHANNELS X BATCHSIZE - auto arrVec = arr.toHostVector(); - - std::vector> ctx = { - {0, 0}, {1, 1}, {2, 2}, {4, 4}, {1024, 1024}, {10, 0}, {2, 12}}; - - for (auto c : ctx) { - auto arrVecNrm = localNormalize( - arrVec, - c.first /* context */, - c.second, - arr.dim(0) /* frames */, - arr.dim(3) /*batches */); - auto arrNrm = - Tensor::fromBuffer(arr.shape(), arrVecNrm.data(), MemoryLocation::Host); - ASSERT_TRUE( - fl::all(fl::abs(arrNrm - afNormalize(arr, 
c.first, c.second)) < 1E-4) - .asScalar()); - } } TEST(FeaturizationTest, TargetTknTestStandaloneSep) { - Dictionary tokens; - std::string sep = "||"; - tokens.addEntry("ab"); - tokens.addEntry("cd"); - tokens.addEntry("ef"); - tokens.addEntry("t"); - tokens.addEntry("r"); - tokens.addEntry(sep); - - LexiconMap lexicon; - lexicon["abcd"].push_back({"ab", "cd", "||"}); - lexicon["abcdef"].push_back({"ab", "cd", "ef", "||"}); - - std::vector words = {"abcdef", "abcd", "tr"}; - auto res = wrd2Target( - words, - lexicon, - tokens, - sep, - 0, - false, - true, // fallback right - false); - - std::vector resT = { - "ab", "cd", "ef", "||", "ab", "cd", "||", "t", "r", "||"}; - ASSERT_EQ(res.size(), resT.size()); - for (int index = 0; index < res.size(); index++) { - ASSERT_EQ(res[index], resT[index]); - } - - auto res2 = wrd2Target( - words, - lexicon, - tokens, - sep, - 0, - true, // fallback left - false, - false); - - std::vector resT2 = { - "ab", "cd", "ef", "||", "ab", "cd", "||", "||", "t", "r"}; - ASSERT_EQ(res2.size(), resT2.size()); - for (int index = 0; index < res2.size(); index++) { - ASSERT_EQ(res2[index], resT2[index]); - } + Dictionary tokens; + std::string sep = "||"; + tokens.addEntry("ab"); + tokens.addEntry("cd"); + tokens.addEntry("ef"); + tokens.addEntry("t"); + tokens.addEntry("r"); + tokens.addEntry(sep); + + LexiconMap lexicon; + lexicon["abcd"].push_back({"ab", "cd", "||"}); + lexicon["abcdef"].push_back({"ab", "cd", "ef", "||"}); + + std::vector words = {"abcdef", "abcd", "tr"}; + auto res = wrd2Target( + words, + lexicon, + tokens, + sep, + 0, + false, + true, // fallback right + false + ); + + std::vector resT = { + "ab", "cd", "ef", "||", "ab", "cd", "||", "t", "r", "||"}; + ASSERT_EQ(res.size(), resT.size()); + for(int index = 0; index < res.size(); index++) { + ASSERT_EQ(res[index], resT[index]); + } + + auto res2 = wrd2Target( + words, + lexicon, + tokens, + sep, + 0, + true, // fallback left + false, + false + ); + + std::vector resT2 = { + 
"ab", "cd", "ef", "||", "ab", "cd", "||", "||", "t", "r"}; + ASSERT_EQ(res2.size(), resT2.size()); + for(int index = 0; index < res2.size(); index++) { + ASSERT_EQ(res2[index], resT2[index]); + } } TEST(FeaturizationTest, TargetTknTestInsideSep) { - Dictionary tokens; - std::string sep = "_"; - tokens.addEntry("_hel"); - tokens.addEntry("lo"); - tokens.addEntry("_ma"); - tokens.addEntry("ma"); - tokens.addEntry(sep); - tokens.addEntry("f"); - tokens.addEntry("a"); - - LexiconMap lexicon; - lexicon["hello"].push_back({"_hel", "lo"}); - lexicon["mama"].push_back({"_ma", "ma"}); - lexicon["af"].push_back({"_", "a", "f"}); - - std::vector words = {"aff", "hello", "mama", "af"}; - auto res = wrd2Target( - words, - lexicon, - tokens, - sep, - 0, - true, // fallback left - false, - false); - - std::vector resT = { - "_", "a", "f", "f", "_hel", "lo", "_ma", "ma", "_", "a", "f"}; - ASSERT_EQ(res.size(), resT.size()); - for (int index = 0; index < res.size(); index++) { - ASSERT_EQ(res[index], resT[index]); - } - - auto res2 = wrd2Target( - words, - lexicon, - tokens, - sep, - 0, - false, - true, // fallback right - false); - - std::vector resT2 = { - "a", "f", "f", "_", "_hel", "lo", "_ma", "ma", "_", "a", "f"}; - ASSERT_EQ(res.size(), resT2.size()); - for (int index = 0; index < res2.size(); index++) { - ASSERT_EQ(res2[index], resT2[index]); - } + Dictionary tokens; + std::string sep = "_"; + tokens.addEntry("_hel"); + tokens.addEntry("lo"); + tokens.addEntry("_ma"); + tokens.addEntry("ma"); + tokens.addEntry(sep); + tokens.addEntry("f"); + tokens.addEntry("a"); + + LexiconMap lexicon; + lexicon["hello"].push_back({"_hel", "lo"}); + lexicon["mama"].push_back({"_ma", "ma"}); + lexicon["af"].push_back({"_", "a", "f"}); + + std::vector words = {"aff", "hello", "mama", "af"}; + auto res = wrd2Target( + words, + lexicon, + tokens, + sep, + 0, + true, // fallback left + false, + false + ); + + std::vector resT = { + "_", "a", "f", "f", "_hel", "lo", "_ma", "ma", "_", "a", "f"}; 
+ ASSERT_EQ(res.size(), resT.size()); + for(int index = 0; index < res.size(); index++) { + ASSERT_EQ(res[index], resT[index]); + } + + auto res2 = wrd2Target( + words, + lexicon, + tokens, + sep, + 0, + false, + true, // fallback right + false + ); + + std::vector resT2 = { + "a", "f", "f", "_", "_hel", "lo", "_ma", "ma", "_", "a", "f"}; + ASSERT_EQ(res.size(), resT2.size()); + for(int index = 0; index < res2.size(); index++) { + ASSERT_EQ(res2[index], resT2[index]); + } } TEST(FeaturizationTest, WrdToTarget) { - LexiconMap lexicon; - // word pieces with word separator in the end - lexicon["123"].push_back({"1", "23_"}); - lexicon["456"].push_back({"456_"}); - // word pieces with word separator in the beginning - lexicon["789"].push_back({"_7", "89"}); - lexicon["010"].push_back({"_0", "10"}); - // word pieces without word separators - lexicon["105"].push_back({"10", "5"}); - lexicon["2100"].push_back({"2", "1", "00"}); - // letters - lexicon["888"].push_back({"8", "8", "8", "_"}); - lexicon["12"].push_back({"1", "2", "_"}); - lexicon[kUnkToken] = {}; - - Dictionary dict; - for (const auto& l : lexicon) { - for (const auto& p : l.second) { - for (const auto& c : p) { - if (!dict.contains(c)) { - dict.addEntry(c); + LexiconMap lexicon; + // word pieces with word separator in the end + lexicon["123"].push_back({"1", "23_"}); + lexicon["456"].push_back({"456_"}); + // word pieces with word separator in the beginning + lexicon["789"].push_back({"_7", "89"}); + lexicon["010"].push_back({"_0", "10"}); + // word pieces without word separators + lexicon["105"].push_back({"10", "5"}); + lexicon["2100"].push_back({"2", "1", "00"}); + // letters + lexicon["888"].push_back({"8", "8", "8", "_"}); + lexicon["12"].push_back({"1", "2", "_"}); + lexicon[kUnkToken] = {}; + + Dictionary dict; + for(const auto& l : lexicon) { + for(const auto& p : l.second) { + for(const auto& c : p) { + if(!dict.contains(c)) { + dict.addEntry(c); + } + } } - } } - } - - // NOTE: word separator has 
no effect when fallback2Ltr is false - std::vector words = {"123", "456"}; - auto target = wrd2Target(words, lexicon, dict, "", 0, false, false, false); - ASSERT_THAT(target, ::testing::ElementsAreArray({"1", "23_", "456_"})); - - std::vector words1 = {"789", "010"}; - auto target1 = wrd2Target(words1, lexicon, dict, "_", 0, false, false, false); - ASSERT_THAT(target1, ::testing::ElementsAreArray({"_7", "89", "_0", "10"})); - - std::vector words2 = {"105", "2100"}; - auto target2 = wrd2Target(words2, lexicon, dict, "", 0, false, false, false); - ASSERT_THAT( - target2, ::testing::ElementsAreArray({"10", "5", "2", "1", "00"})); - - std::vector words3 = {"12", "888", "12"}; - auto target3 = wrd2Target(words3, lexicon, dict, "_", 0, false, false, false); - ASSERT_THAT( - target3, - ::testing::ElementsAreArray( - {"1", "2", "_", "8", "8", "8", "_", "1", "2", "_"})); - - // unknown words "111", "199" - std::vector words4 = {"111", "789", "199"}; - // fall back to letters, wordsep to left and skip unknown - auto target4 = wrd2Target(words4, lexicon, dict, "_", 0, true, false, true); - ASSERT_THAT( - target4, - ::testing::ElementsAreArray({"_", "1", "1", "1", "_7", "89", "_", "1"})); - // fall back to letters, wordsep to right and skip unknown - target4 = wrd2Target(words4, lexicon, dict, "_", 0, false, true, true); - ASSERT_THAT( - target4, - ::testing::ElementsAreArray({"1", "1", "1", "_", "_7", "89", "1", "_"})); - - // skip unknown - target4 = wrd2Target(words4, lexicon, dict, "", 0, false, false, true); - ASSERT_THAT(target4, ::testing::ElementsAreArray({"_7", "89"})); + + // NOTE: word separator has no effect when fallback2Ltr is false + std::vector words = {"123", "456"}; + auto target = wrd2Target(words, lexicon, dict, "", 0, false, false, false); + ASSERT_THAT(target, ::testing::ElementsAreArray({"1", "23_", "456_"})); + + std::vector words1 = {"789", "010"}; + auto target1 = wrd2Target(words1, lexicon, dict, "_", 0, false, false, false); + ASSERT_THAT(target1, 
::testing::ElementsAreArray({"_7", "89", "_0", "10"})); + + std::vector words2 = {"105", "2100"}; + auto target2 = wrd2Target(words2, lexicon, dict, "", 0, false, false, false); + ASSERT_THAT( + target2, + ::testing::ElementsAreArray({"10", "5", "2", "1", "00"}) + ); + + std::vector words3 = {"12", "888", "12"}; + auto target3 = wrd2Target(words3, lexicon, dict, "_", 0, false, false, false); + ASSERT_THAT( + target3, + ::testing::ElementsAreArray( + {"1", "2", "_", "8", "8", "8", "_", "1", "2", "_"} + ) + ); + + // unknown words "111", "199" + std::vector words4 = {"111", "789", "199"}; + // fall back to letters, wordsep to left and skip unknown + auto target4 = wrd2Target(words4, lexicon, dict, "_", 0, true, false, true); + ASSERT_THAT( + target4, + ::testing::ElementsAreArray({"_", "1", "1", "1", "_7", "89", "_", "1"}) + ); + // fall back to letters, wordsep to right and skip unknown + target4 = wrd2Target(words4, lexicon, dict, "_", 0, false, true, true); + ASSERT_THAT( + target4, + ::testing::ElementsAreArray({"1", "1", "1", "_", "_7", "89", "1", "_"}) + ); + + // skip unknown + target4 = wrd2Target(words4, lexicon, dict, "", 0, false, false, true); + ASSERT_THAT(target4, ::testing::ElementsAreArray({"_7", "89"})); } TEST(FeaturizationTest, TargetToSingleLtr) { - std::string wordseparator = "_"; - bool usewordpiece = true; - - Dictionary dict; - for (int i = 0; i < 10; ++i) { - dict.addEntry(std::to_string(i), i); - } - dict.addEntry("_", 10); - dict.addEntry("23_", 230); - dict.addEntry("456_", 4560); - - std::vector words = {1, 230, 4560}; - auto target = tknIdx2Ltr(words, dict, usewordpiece, wordseparator); - ASSERT_THAT( - target, ::testing::ElementsAreArray({"1", "2", "3", "_", "4", "5", "6"})); + std::string wordseparator = "_"; + bool usewordpiece = true; + + Dictionary dict; + for(int i = 0; i < 10; ++i) { + dict.addEntry(std::to_string(i), i); + } + dict.addEntry("_", 10); + dict.addEntry("23_", 230); + dict.addEntry("456_", 4560); + + std::vector 
words = {1, 230, 4560}; + auto target = tknIdx2Ltr(words, dict, usewordpiece, wordseparator); + ASSERT_THAT( + target, + ::testing::ElementsAreArray({"1", "2", "3", "_", "4", "5", "6"}) + ); } TEST(FeaturizationTest, inputFeaturizer) { - auto channels = 2; - auto samplerate = 16000; - FeatureParams featParams( - samplerate, - 25, // framesize - 10, // framestride - 40, // filterbanks - 0, // lowfreqfilterbank, - samplerate / 2, // highfreqfilterbank - -1, // mfcccoeffs - kLifterParam, // lifterparam - 0, // delta window - 0 // delta-delta window - ); - featParams.useEnergy = false; - featParams.usePower = false; - featParams.zeroMeanFrame = false; - auto inputFeaturizerRaw = - inputFeatures(featParams, FeatureType::NONE, {-1, -1}, {}); - auto inputFeaturizerMfsc = - inputFeatures(featParams, FeatureType::MFSC, {-1, -1}, {}); - for (int size = 1; size < 10; ++size) { - std::vector input(size * samplerate * channels); - for (int j = 0; j < input.size(); ++j) { - // channel 1 is same as channel 2 - input[j] = std::sin(2 * M_PI * (j / 2) / samplerate); - } + auto channels = 2; + auto samplerate = 16000; + FeatureParams featParams( + samplerate, + 25, // framesize + 10, // framestride + 40, // filterbanks + 0, // lowfreqfilterbank, + samplerate / 2, // highfreqfilterbank + -1, // mfcccoeffs + kLifterParam, // lifterparam + 0, // delta window + 0 // delta-delta window + ); + featParams.useEnergy = false; + featParams.usePower = false; + featParams.zeroMeanFrame = false; + auto inputFeaturizerRaw = + inputFeatures(featParams, FeatureType::NONE, {-1, -1}, {}); + auto inputFeaturizerMfsc = + inputFeatures(featParams, FeatureType::MFSC, {-1, -1}, {}); + for(int size = 1; size < 10; ++size) { + std::vector input(size * samplerate * channels); + for(int j = 0; j < input.size(); ++j) { + // channel 1 is same as channel 2 + input[j] = std::sin(2 * M_PI * (j / 2) / samplerate); + } - int insize = size * samplerate; - auto inArray = - inputFeaturizerRaw(input.data(), {channels, 
insize}, fl::dtype::f32); - ASSERT_EQ(inArray.shape(), Shape({insize, 1, channels})); - Tensor ch1 = inArray(fl::span, fl::span, 0); - Tensor ch2 = inArray(fl::span, fl::span, 1); - ASSERT_TRUE(fl::amax(fl::abs(ch1 - ch2)).scalar() < 1E-5); - - inArray = - inputFeaturizerMfsc(input.data(), {channels, insize}, fl::dtype::f32); - auto nFrames = 1 + (insize - 25 * 16) / (10 * 16); - ASSERT_EQ(inArray.shape(), Shape({nFrames, 40, channels})); - ch1 = inArray(fl::span, fl::span, 0, fl::span); - ch2 = inArray(fl::span, fl::span, 1, fl::span); - ASSERT_TRUE(fl::amax(fl::abs(ch1 - ch2)).scalar() < 1E-5); - } + int insize = size * samplerate; + auto inArray = + inputFeaturizerRaw(input.data(), {channels, insize}, fl::dtype::f32); + ASSERT_EQ(inArray.shape(), Shape({insize, 1, channels})); + Tensor ch1 = inArray(fl::span, fl::span, 0); + Tensor ch2 = inArray(fl::span, fl::span, 1); + ASSERT_TRUE(fl::amax(fl::abs(ch1 - ch2)).scalar() < 1E-5); + + inArray = + inputFeaturizerMfsc(input.data(), {channels, insize}, fl::dtype::f32); + auto nFrames = 1 + (insize - 25 * 16) / (10 * 16); + ASSERT_EQ(inArray.shape(), Shape({nFrames, 40, channels})); + ch1 = inArray(fl::span, fl::span, 0, fl::span); + ch2 = inArray(fl::span, fl::span, 1, fl::span); + ASSERT_TRUE(fl::amax(fl::abs(ch1 - ch2)).scalar() < 1E-5); + } } TEST(FeaturizationTest, targetFeaturizer) { - using fl::pkg::speech::kEosToken; - - auto tokenDict = getDict(); - tokenDict.addEntry(kEosToken); - auto lexicon = getLexicon(); - std::vector> targets = { - {'a', 'b', 'c', 'c', 'c'}, {'b', 'c', 'd', 'd'}}; - - TargetGenerationConfig targetGenConfig( - "", - 0, - kCtcCriterion, - "", - false, - 0, - true /* skip unk */, - false /* fallback2LetterWordSepLeft */, - true /* fallback2LetterWordSepLeft */); - - auto targetFeaturizer = targetFeatures(tokenDict, lexicon, targetGenConfig); - - auto tgtArray = targetFeaturizer( - targets[0].data(), - {static_cast(targets[0].size())}, - fl::dtype::b8); - int tgtLen = 5; - 
ASSERT_EQ(tgtArray.shape(), Shape({tgtLen})); - ASSERT_EQ(tgtArray.type(), fl::dtype::s32); - std::vector tgtArrayVec(tgtLen); - tgtArray.host(tgtArrayVec.data()); - - ASSERT_EQ(tgtArrayVec[0], 0); - ASSERT_EQ(tgtArrayVec[1], 1); - ASSERT_EQ(tgtArrayVec[2], 2); - ASSERT_EQ(tgtArrayVec[3], 2); - ASSERT_EQ(tgtArrayVec[4], 2); - - auto targetGenConfigEos = TargetGenerationConfig( - "", - 0, - kCtcCriterion, - "", - true, // changed from above - 0, - true /* skip unk */, - false /* fallback2LetterWordSepLeft */, - true /* fallback2LetterWordSepLeft */); - targetFeaturizer = targetFeatures(tokenDict, lexicon, targetGenConfigEos); - tgtArray = targetFeaturizer( - targets[1].data(), - {static_cast(targets[1].size())}, - fl::dtype::b8); - tgtLen = 5; - int eosIdx = tokenDict.getIndex(kEosToken); - ASSERT_EQ(tgtArray.shape(), Shape({tgtLen})); - ASSERT_EQ(tgtArray.type(), fl::dtype::s32); - tgtArray.host(tgtArrayVec.data()); - ASSERT_EQ(tgtArrayVec[0], 1); - ASSERT_EQ(tgtArrayVec[1], 2); - ASSERT_EQ(tgtArrayVec[2], 3); - ASSERT_EQ(tgtArrayVec[3], 3); - ASSERT_EQ(tgtArrayVec[4], eosIdx); + using fl::pkg::speech::kEosToken; + + auto tokenDict = getDict(); + tokenDict.addEntry(kEosToken); + auto lexicon = getLexicon(); + std::vector> targets = { + {'a', 'b', 'c', 'c', 'c'}, {'b', 'c', 'd', 'd'}}; + + TargetGenerationConfig targetGenConfig( + "", + 0, + kCtcCriterion, + "", + false, + 0, + true /* skip unk */, + false /* fallback2LetterWordSepLeft */, + true /* fallback2LetterWordSepLeft */); + + auto targetFeaturizer = targetFeatures(tokenDict, lexicon, targetGenConfig); + + auto tgtArray = targetFeaturizer( + targets[0].data(), + {static_cast(targets[0].size())}, + fl::dtype::b8 + ); + int tgtLen = 5; + ASSERT_EQ(tgtArray.shape(), Shape({tgtLen})); + ASSERT_EQ(tgtArray.type(), fl::dtype::s32); + std::vector tgtArrayVec(tgtLen); + tgtArray.host(tgtArrayVec.data()); + + ASSERT_EQ(tgtArrayVec[0], 0); + ASSERT_EQ(tgtArrayVec[1], 1); + ASSERT_EQ(tgtArrayVec[2], 2); + 
ASSERT_EQ(tgtArrayVec[3], 2); + ASSERT_EQ(tgtArrayVec[4], 2); + + auto targetGenConfigEos = TargetGenerationConfig( + "", + 0, + kCtcCriterion, + "", + true, // changed from above + 0, + true /* skip unk */, + false /* fallback2LetterWordSepLeft */, + true /* fallback2LetterWordSepLeft */ + ); + targetFeaturizer = targetFeatures(tokenDict, lexicon, targetGenConfigEos); + tgtArray = targetFeaturizer( + targets[1].data(), + {static_cast(targets[1].size())}, + fl::dtype::b8 + ); + tgtLen = 5; + int eosIdx = tokenDict.getIndex(kEosToken); + ASSERT_EQ(tgtArray.shape(), Shape({tgtLen})); + ASSERT_EQ(tgtArray.type(), fl::dtype::s32); + tgtArray.host(tgtArrayVec.data()); + ASSERT_EQ(tgtArrayVec[0], 1); + ASSERT_EQ(tgtArrayVec[1], 2); + ASSERT_EQ(tgtArrayVec[2], 3); + ASSERT_EQ(tgtArrayVec[3], 3); + ASSERT_EQ(tgtArrayVec[4], eosIdx); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/data/ListFileDatasetTest.cpp b/flashlight/pkg/speech/test/data/ListFileDatasetTest.cpp index f2fe397..58566ed 100644 --- a/flashlight/pkg/speech/test/data/ListFileDatasetTest.cpp +++ b/flashlight/pkg/speech/test/data/ListFileDatasetTest.cpp @@ -26,65 +26,66 @@ using namespace fl; fs::path loadPath = ""; auto letterToTarget = [](void* data, Shape dims, fl::dtype /* unused */) { - std::string transcript( - static_cast(data), static_cast(data) + dims.elements()); - std::vector tgt; - for (auto c : transcript) { - tgt.push_back(static_cast(c)); - } - return Tensor::fromVector(tgt); -}; + std::string transcript( + static_cast(data), static_cast(data) + dims.elements()); + std::vector tgt; + for(auto c : transcript) { + tgt.push_back(static_cast(c)); + } + return Tensor::fromVector(tgt); + }; } // namespace TEST(ListFileDatasetTest, LoadData) { - const fs::path dataPath = loadPath / "data.lst"; - if 
(!fs::exists(dataPath)) { - throw std::runtime_error( - "ListFileDatasetTest, LoadData - can't open test data.lst"); - } - std::vector data; - { - std::ifstream in(dataPath); - for (std::string s; std::getline(in, s);) { - data.emplace_back(s); + const fs::path dataPath = loadPath / "data.lst"; + if(!fs::exists(dataPath)) { + throw std::runtime_error( + "ListFileDatasetTest, LoadData - can't open test data.lst" + ); + } + std::vector data; + { + std::ifstream in(dataPath); + for(std::string s; std::getline(in, s);) { + data.emplace_back(s); + } } - } - const fs::path rootPath = fs::temp_directory_path() / "data.lst"; - std::ofstream out(rootPath); - for (auto& d : data) { - replaceAll(d, "", loadPath); - out << d; - out << "\n"; - } - out.close(); - ListFileDataset audiods(rootPath, nullptr, letterToTarget); - ASSERT_EQ(audiods.size(), 3); - std::vector expectedTgtLen = {45, 23, 26}; - std::vector expectedDuration = {1.2, 2.1, 0.6}; - for (int i = 0; i < 3; ++i) { - ASSERT_EQ(audiods.get(i).size(), 7); - ASSERT_EQ(audiods.get(i)[0].shape(), Shape({1, 24000})); - ASSERT_EQ(audiods.get(i)[1].elements(), expectedTgtLen[i]); - ASSERT_EQ(audiods.get(i)[1].elements(), audiods.getTargetSize(i)); - ASSERT_TRUE(audiods.get(i)[2].isEmpty()); - ASSERT_EQ(audiods.get(i)[3].elements(), 1); - ASSERT_GE(audiods.get(i)[4].elements(), 15); - ASSERT_EQ(audiods.get(i)[5].elements(), 1); - ASSERT_EQ(audiods.get(i)[5].scalar(), expectedDuration[i]); - ASSERT_EQ(audiods.get(i)[6].elements(), 1); - ASSERT_EQ(audiods.get(i)[6].scalar(), expectedTgtLen[i]); - } + const fs::path rootPath = fs::temp_directory_path() / "data.lst"; + std::ofstream out(rootPath); + for(auto& d : data) { + replaceAll(d, "", loadPath); + out << d; + out << "\n"; + } + out.close(); + ListFileDataset audiods(rootPath, nullptr, letterToTarget); + ASSERT_EQ(audiods.size(), 3); + std::vector expectedTgtLen = {45, 23, 26}; + std::vector expectedDuration = {1.2, 2.1, 0.6}; + for(int i = 0; i < 3; ++i) { + 
ASSERT_EQ(audiods.get(i).size(), 7); + ASSERT_EQ(audiods.get(i)[0].shape(), Shape({1, 24000})); + ASSERT_EQ(audiods.get(i)[1].elements(), expectedTgtLen[i]); + ASSERT_EQ(audiods.get(i)[1].elements(), audiods.getTargetSize(i)); + ASSERT_TRUE(audiods.get(i)[2].isEmpty()); + ASSERT_EQ(audiods.get(i)[3].elements(), 1); + ASSERT_GE(audiods.get(i)[4].elements(), 15); + ASSERT_EQ(audiods.get(i)[5].elements(), 1); + ASSERT_EQ(audiods.get(i)[5].scalar(), expectedDuration[i]); + ASSERT_EQ(audiods.get(i)[6].elements(), 1); + ASSERT_EQ(audiods.get(i)[6].scalar(), expectedTgtLen[i]); + } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); // Resolve directory for data #ifdef DATA_TEST_DATADIR - loadPath = fs::path(DATA_TEST_DATADIR); + loadPath = fs::path(DATA_TEST_DATADIR); #endif - return RUN_ALL_TESTS(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/data/SoundTest.cpp b/flashlight/pkg/speech/test/data/SoundTest.cpp index 31136b4..6f64719 100644 --- a/flashlight/pkg/speech/test/data/SoundTest.cpp +++ b/flashlight/pkg/speech/test/data/SoundTest.cpp @@ -21,151 +21,161 @@ namespace { fs::path loadPath = ""; auto loadData = [](const std::string& filepath) { - std::vector data; - std::ifstream file(filepath); - std::istream_iterator eos; - std::istream_iterator iit(file); - std::copy(iit, eos, std::back_inserter(data)); - return data; -}; + std::vector data; + std::ifstream file(filepath); + std::istream_iterator eos; + std::istream_iterator iit(file); + std::copy(iit, eos, std::back_inserter(data)); + return data; + }; } // namespace TEST(SoundTest, Mono) { - auto audiopath = loadPath / "test_mono.wav"; // 16-bit Signed Integer PCM - auto datapath = loadPath / "test_mono.dat"; - - auto info = loadSoundInfo(audiopath); - ASSERT_EQ(info.samplerate, 48000); - ASSERT_EQ(info.channels, 1); - ASSERT_EQ(info.frames, 24576); - - auto data = loadData(datapath); - - // 
Double - auto vecDouble = loadSound(audiopath); - ASSERT_EQ(vecDouble.size(), info.channels * info.frames); - for (int64_t i = 0; i < vecDouble.size(); ++i) { - ASSERT_NEAR(vecDouble[i], data[i], 1E-8); - } - - // Float - auto vecFloat = loadSound(audiopath); - ASSERT_EQ(vecFloat.size(), info.channels * info.frames); - - for (int64_t i = 0; i < vecFloat.size(); ++i) { - ASSERT_NEAR(vecFloat[i], data[i], 1E-6); - } - - // scale by max value for short - std::transform( - data.begin(), data.end(), data.begin(), [](double d) -> double { + auto audiopath = loadPath / "test_mono.wav"; // 16-bit Signed Integer PCM + auto datapath = loadPath / "test_mono.dat"; + + auto info = loadSoundInfo(audiopath); + ASSERT_EQ(info.samplerate, 48000); + ASSERT_EQ(info.channels, 1); + ASSERT_EQ(info.frames, 24576); + + auto data = loadData(datapath); + + // Double + auto vecDouble = loadSound(audiopath); + ASSERT_EQ(vecDouble.size(), info.channels * info.frames); + for(int64_t i = 0; i < vecDouble.size(); ++i) { + ASSERT_NEAR(vecDouble[i], data[i], 1E-8); + } + + // Float + auto vecFloat = loadSound(audiopath); + ASSERT_EQ(vecFloat.size(), info.channels * info.frames); + + for(int64_t i = 0; i < vecFloat.size(); ++i) { + ASSERT_NEAR(vecFloat[i], data[i], 1E-6); + } + + // scale by max value for short + std::transform( + data.begin(), + data.end(), + data.begin(), + [](double d) -> double { return d * (1 << 15); - }); - - // Short - auto vecShort = loadSound(audiopath); - ASSERT_EQ(vecShort.size(), info.channels * info.frames); - - for (int64_t i = 0; i < vecShort.size(); ++i) { - ASSERT_NEAR(vecShort[i], data[i], 0.5); - } - - // scale by (max value for int64_t / max value of short) - std::transform( - data.begin(), data.end(), data.begin(), [](double d) -> double { + } + ); + + // Short + auto vecShort = loadSound(audiopath); + ASSERT_EQ(vecShort.size(), info.channels * info.frames); + + for(int64_t i = 0; i < vecShort.size(); ++i) { + ASSERT_NEAR(vecShort[i], data[i], 0.5); + } + + // 
scale by (max value for int64_t / max value of short) + std::transform( + data.begin(), + data.end(), + data.begin(), + [](double d) -> double { return d * (1 << 16); - }); - // Int - auto vecInt = loadSound(audiopath); - ASSERT_EQ(vecInt.size(), info.channels * info.frames); - for (int64_t i = 0; i < vecInt.size(); ++i) { - ASSERT_NEAR(vecInt[i], data[i], 25); - } + } + ); + // Int + auto vecInt = loadSound(audiopath); + ASSERT_EQ(vecInt.size(), info.channels * info.frames); + for(int64_t i = 0; i < vecInt.size(); ++i) { + ASSERT_NEAR(vecInt[i], data[i], 25); + } } TEST(SoundTest, Stereo) { - auto audiopath = loadPath / "test_stereo.wav"; // 16-bit Signed Integer PCM - auto datapath = loadPath / "test_stereo.dat"; - auto info = loadSoundInfo(audiopath); + auto audiopath = loadPath / "test_stereo.wav"; // 16-bit Signed Integer PCM + auto datapath = loadPath / "test_stereo.dat"; + auto info = loadSoundInfo(audiopath); - ASSERT_EQ(info.samplerate, 48000); - ASSERT_EQ(info.channels, 2); - ASSERT_EQ(info.frames, 24576); + ASSERT_EQ(info.samplerate, 48000); + ASSERT_EQ(info.channels, 2); + ASSERT_EQ(info.frames, 24576); - auto vecFloat = loadSound(audiopath); - ASSERT_EQ(vecFloat.size(), info.channels * info.frames); + auto vecFloat = loadSound(audiopath); + ASSERT_EQ(vecFloat.size(), info.channels * info.frames); - auto data = loadData(datapath); - ASSERT_EQ(data.size(), info.channels * info.frames); + auto data = loadData(datapath); + ASSERT_EQ(data.size(), info.channels * info.frames); - for (int64_t i = 0; i < vecFloat.size(); ++i) { - ASSERT_NEAR(vecFloat[i], data[i], 1E-6); - } + for(int64_t i = 0; i < vecFloat.size(); ++i) { + ASSERT_NEAR(vecFloat[i], data[i], 1E-6); + } } TEST(SoundTest, OggReadWrite) { - auto audiopath = loadPath / "test_stereo.wav"; - const fs::path outaudiopath = fs::temp_directory_path() / "test.ogg"; - auto oggaudiopath = loadPath / "test_stereo.ogg"; - auto info = loadSoundInfo(audiopath); - auto vecShort = loadSound(audiopath); - - 
saveSound( - outaudiopath, - vecShort, - info.samplerate, - info.channels, - SoundFormat::OGG, - SoundSubFormat::VORBIS); - auto vecFloatOut = loadSound(outaudiopath); - auto vecFloat = loadSound(oggaudiopath); - - ASSERT_EQ(vecFloat.size(), vecFloatOut.size()); - - for (int64_t i = 0; i < vecFloat.size(); ++i) { - ASSERT_NEAR(vecFloat[i], vecFloatOut[i], 5E-3); - } + auto audiopath = loadPath / "test_stereo.wav"; + const fs::path outaudiopath = fs::temp_directory_path() / "test.ogg"; + auto oggaudiopath = loadPath / "test_stereo.ogg"; + auto info = loadSoundInfo(audiopath); + auto vecShort = loadSound(audiopath); + + saveSound( + outaudiopath, + vecShort, + info.samplerate, + info.channels, + SoundFormat::OGG, + SoundSubFormat::VORBIS + ); + auto vecFloatOut = loadSound(outaudiopath); + auto vecFloat = loadSound(oggaudiopath); + + ASSERT_EQ(vecFloat.size(), vecFloatOut.size()); + + for(int64_t i = 0; i < vecFloat.size(); ++i) { + ASSERT_NEAR(vecFloat[i], vecFloatOut[i], 5E-3); + } } TEST(SoundTest, StreamReadWrite) { - auto audiopath = loadPath / "test_stereo.wav"; - auto info = loadSoundInfo(audiopath); - auto vecShort = loadSound(audiopath); - - std::stringstream f; - saveSound( - f, - vecShort, - info.samplerate, - info.channels, - SoundFormat::WAV, - SoundSubFormat::PCM_16); - - f.seekg(0); - f.clear(); - auto infostream = loadSoundInfo(f); - ASSERT_EQ(info.samplerate, infostream.samplerate); - ASSERT_EQ(info.channels, infostream.channels); - ASSERT_EQ(info.frames, infostream.frames); - - f.seekg(0); - f.clear(); - auto vecShortStream = loadSound(f); - - ASSERT_EQ(vecShort.size(), vecShortStream.size()); - for (int64_t i = 0; i < vecShort.size(); ++i) { - ASSERT_EQ(vecShort[i], vecShortStream[i]); - } + auto audiopath = loadPath / "test_stereo.wav"; + auto info = loadSoundInfo(audiopath); + auto vecShort = loadSound(audiopath); + + std::stringstream f; + saveSound( + f, + vecShort, + info.samplerate, + info.channels, + SoundFormat::WAV, + 
SoundSubFormat::PCM_16 + ); + + f.seekg(0); + f.clear(); + auto infostream = loadSoundInfo(f); + ASSERT_EQ(info.samplerate, infostream.samplerate); + ASSERT_EQ(info.channels, infostream.channels); + ASSERT_EQ(info.frames, infostream.frames); + + f.seekg(0); + f.clear(); + auto vecShortStream = loadSound(f); + + ASSERT_EQ(vecShort.size(), vecShortStream.size()); + for(int64_t i = 0; i < vecShort.size(); ++i) { + ASSERT_EQ(vecShort[i], vecShortStream[i]); + } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); // Resolve directory for data #ifdef DATA_TEST_DATADIR - loadPath = fs::path(DATA_TEST_DATADIR); + loadPath = fs::path(DATA_TEST_DATADIR); #endif - return RUN_ALL_TESTS(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/decoder/ConvLmModuleTest.cpp b/flashlight/pkg/speech/test/decoder/ConvLmModuleTest.cpp index f7cdccd..5c93709 100644 --- a/flashlight/pkg/speech/test/decoder/ConvLmModuleTest.cpp +++ b/flashlight/pkg/speech/test/decoder/ConvLmModuleTest.cpp @@ -22,104 +22,104 @@ fs::path archDir = ""; } // namespace TEST(ConvLmModuleTest, GCNN14BAdaptiveSoftmax) { - const fs::path archfile = archDir / "gcnn_14B_lm_arch_as.txt"; - int nclass = 221452; - int batchsize = 2; - int inputlength = 100; - std::vector tail = {10000, 50000, 200000, nclass}; - - auto model = buildSequentialModule(archfile, 1, nclass); - auto as = std::make_shared(4096, tail); - auto criterion = std::make_shared(as); - model->eval(); - criterion->eval(); - auto input = fl::arange({inputlength, batchsize}); - auto output = model->forward(noGrad(input)); - output = as->forward(output); - - ASSERT_EQ(output.shape(), Shape({nclass, inputlength, batchsize})); - - // batchsize = 1 - batchsize = 1; - input = fl::arange({inputlength, batchsize}); - output = model->forward(noGrad(input)); - output = as->forward(output); - ASSERT_EQ(output.shape(), Shape({nclass, inputlength, 
batchsize})); + const fs::path archfile = archDir / "gcnn_14B_lm_arch_as.txt"; + int nclass = 221452; + int batchsize = 2; + int inputlength = 100; + std::vector tail = {10000, 50000, 200000, nclass}; + + auto model = buildSequentialModule(archfile, 1, nclass); + auto as = std::make_shared(4096, tail); + auto criterion = std::make_shared(as); + model->eval(); + criterion->eval(); + auto input = fl::arange({inputlength, batchsize}); + auto output = model->forward(noGrad(input)); + output = as->forward(output); + + ASSERT_EQ(output.shape(), Shape({nclass, inputlength, batchsize})); + + // batchsize = 1 + batchsize = 1; + input = fl::arange({inputlength, batchsize}); + output = model->forward(noGrad(input)); + output = as->forward(output); + ASSERT_EQ(output.shape(), Shape({nclass, inputlength, batchsize})); } TEST(ConvLmModuleTest, GCNN14BCrossEntropy) { - const fs::path archfile = archDir / "gcnn_14B_lm_arch_ce.txt"; - int nclass = 30; - int batchsize = 2; - int inputlength = 100; - - auto model = buildSequentialModule(archfile, 1, nclass); - model->eval(); - auto input = fl::arange({inputlength, batchsize}); - auto output = model->forward(noGrad(input)); - ASSERT_EQ(output.shape(), Shape({nclass, inputlength, batchsize})); - - // batchsize = 1 - batchsize = 1; - input = fl::arange({inputlength, batchsize}); - output = model->forward(noGrad(input)); - ASSERT_EQ(output.shape(), Shape({nclass, inputlength, batchsize})); + const fs::path archfile = archDir / "gcnn_14B_lm_arch_ce.txt"; + int nclass = 30; + int batchsize = 2; + int inputlength = 100; + + auto model = buildSequentialModule(archfile, 1, nclass); + model->eval(); + auto input = fl::arange({inputlength, batchsize}); + auto output = model->forward(noGrad(input)); + ASSERT_EQ(output.shape(), Shape({nclass, inputlength, batchsize})); + + // batchsize = 1 + batchsize = 1; + input = fl::arange({inputlength, batchsize}); + output = model->forward(noGrad(input)); + ASSERT_EQ(output.shape(), Shape({nclass, 
inputlength, batchsize})); } TEST(ConvLmModuleTest, SerializationGCNN14BAdaptiveSoftmax) { - char* user = getenv("USER"); - std::string userstr = "unknown"; - if (user != nullptr) { - userstr = std::string(user); - } - const fs::path path = fs::temp_directory_path() / "test.mdl"; - const fs::path archfile = archDir / "gcnn_14B_lm_arch_as.txt"; - - int nclass = 221452; - int batchsize = 2; - int inputlength = 10; - std::vector tail = {10000, 50000, 200000, nclass}; - - std::shared_ptr model = - buildSequentialModule(archfile, 1, nclass); - auto as = std::make_shared(4096, tail); - std::shared_ptr criterion = - std::make_shared(as); - model->eval(); - criterion->eval(); - auto input = noGrad(fl::arange({inputlength, batchsize})); - auto output = model->forward({input})[0]; - auto output_criterion = - std::dynamic_pointer_cast(criterion) - ->getActivation() - ->forward(output); - - save(path, model, criterion); - - std::shared_ptr loaded_model; - std::shared_ptr loaded_criterion; - load(path, loaded_model, loaded_criterion); - - auto outputl = loaded_model->forward({input})[0]; - auto outputl_criterion = - std::dynamic_pointer_cast(loaded_criterion) - ->getActivation() - ->forward(outputl); - - ASSERT_TRUE(allParamsClose(*loaded_model.get(), *model)); - ASSERT_TRUE(allParamsClose(*loaded_criterion.get(), *criterion)); - ASSERT_TRUE(allClose(outputl, output)); - ASSERT_TRUE(allClose(outputl_criterion, output_criterion)); + char* user = getenv("USER"); + std::string userstr = "unknown"; + if(user != nullptr) { + userstr = std::string(user); + } + const fs::path path = fs::temp_directory_path() / "test.mdl"; + const fs::path archfile = archDir / "gcnn_14B_lm_arch_as.txt"; + + int nclass = 221452; + int batchsize = 2; + int inputlength = 10; + std::vector tail = {10000, 50000, 200000, nclass}; + + std::shared_ptr model = + buildSequentialModule(archfile, 1, nclass); + auto as = std::make_shared(4096, tail); + std::shared_ptr criterion = + std::make_shared(as); + 
model->eval(); + criterion->eval(); + auto input = noGrad(fl::arange({inputlength, batchsize})); + auto output = model->forward({input})[0]; + auto output_criterion = + std::dynamic_pointer_cast(criterion) + ->getActivation() + ->forward(output); + + save(path, model, criterion); + + std::shared_ptr loaded_model; + std::shared_ptr loaded_criterion; + load(path, loaded_model, loaded_criterion); + + auto outputl = loaded_model->forward({input})[0]; + auto outputl_criterion = + std::dynamic_pointer_cast(loaded_criterion) + ->getActivation() + ->forward(outputl); + + ASSERT_TRUE(allParamsClose(*loaded_model.get(), *model)); + ASSERT_TRUE(allParamsClose(*loaded_criterion.get(), *criterion)); + ASSERT_TRUE(allClose(outputl, output)); + ASSERT_TRUE(allClose(outputl_criterion, output_criterion)); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); // Resolve directory for arch #ifdef DECODER_TEST_DATADIR - archDir = fs::path(DECODER_TEST_DATADIR); + archDir = fs::path(DECODER_TEST_DATADIR); #endif - return RUN_ALL_TESTS(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/test/runtime/RuntimeTest.cpp b/flashlight/pkg/speech/test/runtime/RuntimeTest.cpp index 6218e6d..44057a8 100644 --- a/flashlight/pkg/speech/test/runtime/RuntimeTest.cpp +++ b/flashlight/pkg/speech/test/runtime/RuntimeTest.cpp @@ -24,108 +24,110 @@ namespace { const fs::path kPath = fs::temp_directory_path() / "test.mdl"; bool afEqual(const fl::Variable& a, const fl::Variable& b) { - if (a.isCalcGrad() != b.isCalcGrad()) { - return false; - } - return allClose(a.tensor(), b.tensor(), 1E-7); + if(a.isCalcGrad() != b.isCalcGrad()) { + return false; + } + return allClose(a.tensor(), b.tensor(), 1E-7); } } // namespace TEST(RuntimeTest, LoadAndSave) { - std::unordered_map config( - {{"date", "01-01-01"}, {"lr", "0.1"}, {"user", "guy_fawkes"}}); - fl::Sequential model; - model.add(fl::Conv2D(4, 6, 2, 
1)); - model.add(fl::GatedLinearUnit(2)); - model.add(fl::Dropout(0.2)); - model.add(fl::Conv2D(3, 4, 3, 1, 1, 1, 0, 0, 1, 1, false)); - model.add(fl::GatedLinearUnit(2)); - model.add(fl::Dropout(0.214)); - - fl::pkg::runtime::Serializer::save(kPath, FL_APP_ASR_VERSION, config, model); - - fl::Sequential modelload; - std::unordered_map configload; - std::string versionload; - fl::pkg::runtime::Serializer::load(kPath, versionload, configload, modelload); - - EXPECT_EQ(configload.size(), config.size()); - EXPECT_EQ(versionload, FL_APP_ASR_VERSION); - EXPECT_THAT(config, ::testing::ContainerEq(configload)); - - ASSERT_EQ(model.prettyString(), modelload.prettyString()); - - model.eval(); - modelload.eval(); - - for (int i = 0; i < 10; ++i) { - auto in = fl::Variable(fl::rand({10, 1, 4, 1}), i & 1); - ASSERT_TRUE(afEqual(model.forward(in), modelload.forward(in))); - } + std::unordered_map config( + {{"date", "01-01-01"}, {"lr", "0.1"}, {"user", "guy_fawkes"}}); + fl::Sequential model; + model.add(fl::Conv2D(4, 6, 2, 1)); + model.add(fl::GatedLinearUnit(2)); + model.add(fl::Dropout(0.2)); + model.add(fl::Conv2D(3, 4, 3, 1, 1, 1, 0, 0, 1, 1, false)); + model.add(fl::GatedLinearUnit(2)); + model.add(fl::Dropout(0.214)); + + fl::pkg::runtime::Serializer::save(kPath, FL_APP_ASR_VERSION, config, model); + + fl::Sequential modelload; + std::unordered_map configload; + std::string versionload; + fl::pkg::runtime::Serializer::load(kPath, versionload, configload, modelload); + + EXPECT_EQ(configload.size(), config.size()); + EXPECT_EQ(versionload, FL_APP_ASR_VERSION); + EXPECT_THAT(config, ::testing::ContainerEq(configload)); + + ASSERT_EQ(model.prettyString(), modelload.prettyString()); + + model.eval(); + modelload.eval(); + + for(int i = 0; i < 10; ++i) { + auto in = fl::Variable(fl::rand({10, 1, 4, 1}), i & 1); + ASSERT_TRUE(afEqual(model.forward(in), modelload.forward(in))); + } } TEST(RuntimeTest, TestCleanFilepath) { - auto s = cleanFilepath("timit/train.\\mymodel"); - 
std::string sep(1, fs::path::preferred_separator); - if (sep == "/") { - ASSERT_EQ(s, "timit#train.\\mymodel"); - } else if (sep == "\\") { - ASSERT_EQ(s, "timit/train.#mymodel"); - } else { - GTEST_SKIP() << "System uses a different separator"; - } + auto s = cleanFilepath("timit/train.\\mymodel"); + std::string sep(1, fs::path::preferred_separator); + if(sep == "/") { + ASSERT_EQ(s, "timit#train.\\mymodel"); + } else if(sep == "\\") { + ASSERT_EQ(s, "timit/train.#mymodel"); + } else { + GTEST_SKIP() << "System uses a different separator"; + } } TEST(RuntimeTest, SpeechStatMeter) { - SpeechStatMeter meter; - std::array inpSizes1{4, 5}; - std::array tgSizes1{6, 10}; - std::array inpSizes2{2, 4, 2, 8}; - std::array tgSizes2{3, 7, 2, 4}; - meter.add( - Tensor::fromArray({1, 2}, inpSizes1), - Tensor::fromArray({1, 2}, tgSizes1)); - auto stats1 = meter.value(); - ASSERT_EQ(stats1[0], 9.0); - ASSERT_EQ(stats1[1], 16.0); - ASSERT_EQ(stats1[2], 5.0); - ASSERT_EQ(stats1[3], 10.0); - ASSERT_EQ(stats1[4], 2.0); - ASSERT_EQ(stats1[5], 1); - meter.add( - Tensor::fromArray({1, 4}, inpSizes2), - Tensor::fromArray({1, 4}, tgSizes2)); - auto stats2 = meter.value(); - ASSERT_EQ(stats2[0], 25.0); - ASSERT_EQ(stats2[1], 32.0); - ASSERT_EQ(stats2[2], 8.0); - ASSERT_EQ(stats2[3], 10.0); - ASSERT_EQ(stats2[4], 6.0); - ASSERT_EQ(stats2[5], 2); + SpeechStatMeter meter; + std::array inpSizes1{4, 5}; + std::array tgSizes1{6, 10}; + std::array inpSizes2{2, 4, 2, 8}; + std::array tgSizes2{3, 7, 2, 4}; + meter.add( + Tensor::fromArray({1, 2}, inpSizes1), + Tensor::fromArray({1, 2}, tgSizes1) + ); + auto stats1 = meter.value(); + ASSERT_EQ(stats1[0], 9.0); + ASSERT_EQ(stats1[1], 16.0); + ASSERT_EQ(stats1[2], 5.0); + ASSERT_EQ(stats1[3], 10.0); + ASSERT_EQ(stats1[4], 2.0); + ASSERT_EQ(stats1[5], 1); + meter.add( + Tensor::fromArray({1, 4}, inpSizes2), + Tensor::fromArray({1, 4}, tgSizes2) + ); + auto stats2 = meter.value(); + ASSERT_EQ(stats2[0], 25.0); + ASSERT_EQ(stats2[1], 32.0); + 
ASSERT_EQ(stats2[2], 8.0); + ASSERT_EQ(stats2[3], 10.0); + ASSERT_EQ(stats2[4], 6.0); + ASSERT_EQ(stats2[5], 2); } TEST(RuntimeTest, parseValidSets) { - std::string in; - auto op = parseValidSets(in); - ASSERT_EQ(op.size(), 0); - - std::string in1 = "d1:d1.lst,d2:d2.lst"; - auto op1 = parseValidSets(in1); - ASSERT_EQ(op1.size(), 2); - ASSERT_EQ(op1[0], (std::pair("d1", "d1.lst"))); - ASSERT_EQ(op1[1], (std::pair("d2", "d2.lst"))); - - std::string in2 = "d1.lst,d2.lst,d3.lst"; - auto op2 = parseValidSets(in2); - ASSERT_EQ(op2.size(), 3); - ASSERT_EQ(op2[0], (std::pair("d1.lst", "d1.lst"))); - ASSERT_EQ(op2[1], (std::pair("d2.lst", "d2.lst"))); - ASSERT_EQ(op2[2], (std::pair("d3.lst", "d3.lst"))); + std::string in; + auto op = parseValidSets(in); + ASSERT_EQ(op.size(), 0); + + std::string in1 = "d1:d1.lst,d2:d2.lst"; + auto op1 = parseValidSets(in1); + ASSERT_EQ(op1.size(), 2); + ASSERT_EQ(op1[0], (std::pair("d1", "d1.lst"))); + ASSERT_EQ(op1[1], (std::pair("d2", "d2.lst"))); + + std::string in2 = "d1.lst,d2.lst,d3.lst"; + auto op2 = parseValidSets(in2); + ASSERT_EQ(op2.size(), 3); + ASSERT_EQ(op2[0], (std::pair("d1.lst", "d1.lst"))); + ASSERT_EQ(op2[1], (std::pair("d2.lst", "d2.lst"))); + ASSERT_EQ(op2[2], (std::pair("d3.lst", "d3.lst"))); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/mgpuenums.h b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/mgpuenums.h index 390eedb..c585089 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/mgpuenums.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/mgpuenums.h @@ -1,6 +1,6 @@ /***************************************************************************** * Copyright (c) 2013, NVIDIA 
CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,10 +11,10 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; @@ -26,45 +26,45 @@ ******************************************************************************/ /****************************************************************************** - * - * Code and text by Sean Baxter, NVIDIA Research - * See http://nvlabs.github.io/moderngpu for repository and documentation. - * - ******************************************************************************/ +* +* Code and text by Sean Baxter, NVIDIA Research +* See http://nvlabs.github.io/moderngpu for repository and documentation. 
+* +******************************************************************************/ -#pragma once +#pragma once namespace mgpu { enum MgpuBounds { - MgpuBoundsLower, - MgpuBoundsUpper + MgpuBoundsLower, + MgpuBoundsUpper }; enum MgpuScanType { - MgpuScanTypeExc, - MgpuScanTypeInc + MgpuScanTypeExc, + MgpuScanTypeInc }; enum MgpuSearchType { - MgpuSearchTypeNone, - MgpuSearchTypeIndex, - MgpuSearchTypeMatch, - MgpuSearchTypeIndexMatch + MgpuSearchTypeNone, + MgpuSearchTypeIndex, + MgpuSearchTypeMatch, + MgpuSearchTypeIndexMatch }; enum MgpuJoinKind { - MgpuJoinKindInner, - MgpuJoinKindLeft, - MgpuJoinKindRight, - MgpuJoinKindOuter + MgpuJoinKindInner, + MgpuJoinKindLeft, + MgpuJoinKindRight, + MgpuJoinKindOuter }; enum MgpuSetOp { - MgpuSetOpIntersection, - MgpuSetOpUnion, - MgpuSetOpDiff, - MgpuSetOpSymDiff + MgpuSetOpIntersection, + MgpuSetOpUnion, + MgpuSetOpDiff, + MgpuSetOpSymDiff }; } // namespace mgpu diff --git a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/util/static.h b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/util/static.h index 408637b..5aa1f37 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/util/static.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/util/static.h @@ -1,6 +1,6 @@ /***************************************************************************** * Copyright (c) 2013, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,10 +11,10 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; @@ -26,11 +26,11 @@ ******************************************************************************/ /****************************************************************************** - * - * Code and text by Sean Baxter, NVIDIA Research - * See http://nvlabs.github.io/moderngpu for repository and documentation. - * - ******************************************************************************/ +* +* Code and text by Sean Baxter, NVIDIA Research +* See http://nvlabs.github.io/moderngpu for repository and documentation. +* +******************************************************************************/ #pragma once @@ -57,7 +57,7 @@ #define MGPU_DIV_UP(x, y) (((x) + (y) - 1) / (y)) #define MGPU_DIV_ROUND(x, y) (((x) + (y) / 2) / (y)) #define MGPU_ROUND_UP(x, y) ((y) * MGPU_DIV_UP(x, y)) -#define MGPU_SHIFT_DIV_UP(x, y) (((x) + ((1<< (y)) - 1))>> y) +#define MGPU_SHIFT_DIV_UP(x, y) (((x) + ((1 << (y)) - 1)) >> y) #define MGPU_ROUND_UP_POW2(x, y) (((x) + (y) - 1) & ~((y) - 1)) #define MGPU_ROUND_DOWN_POW2(x, y) ((x) & ~((y) - 1)) #define MGPU_IS_POW_2(x) (0 == ((x) & ((x) - 1))) @@ -79,105 +79,117 @@ typedef long long int64; typedef unsigned long long uint64; // IsPow2::value is true if X is a power of 2. 
-template struct sIsPow2 { - enum { value = 0 == (X & (X - 1)) }; +template +struct sIsPow2 { + enum {value = 0 == (X & (X - 1))}; }; // Finds the base-2 logarithm of X. value is -1 if X is not a power of 2. -template struct sLogPow2 { - enum { extra = sIsPow2::value ? 0 : (roundUp ? 1 : 0) }; - enum { inner = sLogPow2::inner + 1 }; - enum { value = inner + extra }; +template +struct sLogPow2 { + enum {extra = sIsPow2::value ? 0 : (roundUp ? 1 : 0)}; + enum {inner = sLogPow2::inner + 1}; + enum {value = inner + extra}; }; -template struct sLogPow2<0, roundUp> { - enum { inner = 0 }; - enum { value = 0 }; +template +struct sLogPow2<0, roundUp> { + enum {inner = 0}; + enum {value = 0}; }; -template struct sLogPow2<1, roundUp> { - enum { inner = 0 }; - enum { value = 0 }; +template +struct sLogPow2<1, roundUp> { + enum {inner = 0}; + enum {value = 0}; }; template struct sDivUp { - enum { value = (X + Y - 1) / Y }; + enum {value = (X + Y - 1) / Y}; }; -template struct sDiv2RoundUp { - enum { value = sDiv2RoundUp::value, levels - 1>::value }; +template +struct sDiv2RoundUp { + enum {value = sDiv2RoundUp::value, levels - 1>::value}; }; -template struct sDiv2RoundUp { - enum { value = count }; +template +struct sDiv2RoundUp { + enum {value = count}; }; template struct sDivSafe { - enum { value = X / Y }; + enum {value = X / Y}; }; template struct sDivSafe { - enum { value = 0 }; + enum {value = 0}; }; template struct sRoundUp { - enum { rem = X % Y }; - enum { value = X + (rem ? (Y - rem) : 0) }; + enum {rem = X % Y}; + enum {value = X + (rem ? (Y - rem) : 0)}; }; template struct sRoundDown { - enum { rem = X % Y }; - enum { value = X - rem }; + enum {rem = X % Y}; + enum {value = X - rem}; }; -// IntegerDiv is a template for avoiding divisions by zero in template +// IntegerDiv is a template for avoiding divisions by zero in template // evaluation. Templates always evaluate both b and c in an expression like // a ? 
b : c, and will error if either rhs contains an illegal expression, // even if the ternary is explictly designed to guard against that. template struct sIntegerDiv { - enum { value = X / (Y ? Y : (X + 1)) }; + enum {value = X / (Y ? Y : (X + 1))}; }; template struct sMax { - enum { value = (X >= Y) ? X : Y }; + enum {value = (X >= Y) ? X : Y}; }; template struct sMin { - enum { value = (X <= Y) ? X : Y }; + enum {value = (X <= Y) ? X : Y}; }; template struct sAbs { - enum { value = (X >= 0) ? X : -X }; + enum {value = (X >= 0) ? X : -X}; }; // Finds the number of powers of 2 in the prime factorization of X. -template struct sNumFactorsOf2 { - enum { shifted = X >> 1 }; - enum { value = 1 + sNumFactorsOf2::value }; +template +struct sNumFactorsOf2 { + enum {shifted = X >> 1}; + enum {value = 1 + sNumFactorsOf2::value}; }; -template struct sNumFactorsOf2 { - enum { value = 0 }; +template +struct sNumFactorsOf2 { + enum {value = 0}; }; // Returns the divisor for a conflict-free transpose. -template struct sBankConflictDivisor { - enum { value = - (1 & X) ? 0 : - (sIsPow2::value ? NumBanks : - (1<< sNumFactorsOf2::value)) }; - enum { log_value = sLogPow2::value }; -}; - -template struct sConflictFreeStorage { - enum { count = NT * X }; - enum { divisor = sBankConflictDivisor::value }; - enum { padding = sDivSafe::value }; - enum { value = count + padding }; +template +struct sBankConflictDivisor { + enum { + value = + (1 & X) ? 0 + : (sIsPow2::value ? 
NumBanks + : (1 << sNumFactorsOf2::value)) + }; + enum {log_value = sLogPow2::value}; +}; + +template +struct sConflictFreeStorage { + enum {count = NT * X}; + enum {divisor = sBankConflictDivisor::value}; + enum {padding = sDivSafe::value}; + enum {value = count + padding}; }; } // namespace mgpu diff --git a/flashlight/pkg/speech/third_party/warpctc/include/ctc.h b/flashlight/pkg/speech/third_party/warpctc/include/ctc.h index 160c50d..d45ae5d 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/ctc.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/ctc.h @@ -10,7 +10,7 @@ extern "C" { #endif -//forward declare of CUDA typedef to avoid needing to pull in CUDA headers +// forward declare of CUDA typedef to avoid needing to pull in CUDA headers typedef struct CUstream_st* CUstream; typedef enum { @@ -92,16 +92,18 @@ struct ctcOptions { * \return Status information * * */ -ctcStatus_t compute_ctc_loss(const float* const activations, - float* gradients, - const int* const flat_labels, - const int* const label_lengths, - const int* const input_lengths, - int alphabet_size, - int minibatch, - float *costs, - void *workspace, - ctcOptions options); +ctcStatus_t compute_ctc_loss( + const float* const activations, + float* gradients, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths, + int alphabet_size, + int minibatch, + float* costs, + void* workspace, + ctcOptions options +); /** For a given set of labels and minibatch size return the required workspace @@ -121,11 +123,14 @@ ctcStatus_t compute_ctc_loss(const float* const activations, * * \return Status information **/ -ctcStatus_t get_workspace_size(const int* const label_lengths, - const int* const input_lengths, - int alphabet_size, int minibatch, - ctcOptions info, - size_t* size_bytes); +ctcStatus_t get_workspace_size( + const int* const label_lengths, + const int* const input_lengths, + int alphabet_size, + int minibatch, + ctcOptions info, + 
size_t* size_bytes +); #ifdef __cplusplus } diff --git a/flashlight/pkg/speech/third_party/warpctc/include/detail/cpu_ctc.h b/flashlight/pkg/speech/third_party/warpctc/include/detail/cpu_ctc.h index f65e8a2..e9bb074 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/detail/cpu_ctc.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/detail/cpu_ctc.h @@ -17,14 +17,20 @@ template class CpuCTC { public: // Noncopyable - CpuCTC(int alphabet_size, int minibatch, void* workspace, int num_threads, - int blank_label) : - alphabet_size_(alphabet_size), minibatch_(minibatch), - num_threads_(num_threads), blank_label_(blank_label), - workspace_(workspace) { + CpuCTC( + int alphabet_size, + int minibatch, + void* workspace, + int num_threads, + int blank_label + ) : alphabet_size_(alphabet_size), + minibatch_(minibatch), + num_threads_(num_threads), + blank_label_(blank_label), + workspace_(workspace) { #if defined(CTC_DISABLE_OMP) || defined(APPLE) #else - if (num_threads > 0) { + if(num_threads > 0) { omp_set_num_threads(num_threads); } else { num_threads_ = omp_get_max_threads(); @@ -35,19 +41,23 @@ class CpuCTC { CpuCTC(const CpuCTC&) = delete; CpuCTC& operator=(const CpuCTC&) = delete; - ctcStatus_t cost_and_grad(const ProbT* const activations, - ProbT *grads, - ProbT* costs, - const int* const flat_labels, - const int* const label_lengths, - const int* const input_lengths); + ctcStatus_t cost_and_grad( + const ProbT* const activations, + ProbT* grads, + ProbT* costs, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths + ); - ctcStatus_t score_forward(const ProbT* const activations, - ProbT* costs, - const int* const flat_labels, - const int* const label_lengths, - const int* const input_lengths); + ctcStatus_t score_forward( + const ProbT* const activations, + ProbT* costs, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths + ); private: @@ -57,9 +67,17 @@ 
class CpuCTC { int setup_labels(const int* const labels, int blank_label, int L, int S); public: - CpuCTC_metadata(int L, int S, int T, int mb, int alphabet_size, - void* workspace, size_t bytes_used, int blank_label, - const int* const labels); + CpuCTC_metadata( + int L, + int S, + int T, + int mb, + int alphabet_size, + void* workspace, + size_t bytes_used, + int blank_label, + const int* const labels + ); ProbT* alphas; ProbT* betas; @@ -76,58 +94,87 @@ class CpuCTC { int blank_label_; void* workspace_; - void softmax(const ProbT* const activations, ProbT* probs, - const int* const input_lengths); - - std::tuple - cost_and_grad_kernel(ProbT *grad, const ProbT* const probs, - const int* const labels, int T, int L, - int mb, size_t bytes_used); - - ProbT compute_alphas(const ProbT* probs, int repeats, int S, int T, - const int* const e_inc, - const int* const s_inc, - const int* const labels, - ProbT* alphas); - - ProbT compute_betas_and_grad(ProbT* grad, const ProbT* const probs, - ProbT log_partition, int repeats, - int S, int T, const int* const e_inc, - const int* const s_inc, - const int* const labels, - ProbT* alphas, - ProbT* betas, - ProbT* output); + void softmax( + const ProbT* const activations, + ProbT* probs, + const int* const input_lengths + ); + + std::tuple cost_and_grad_kernel( + ProbT* grad, + const ProbT* const probs, + const int* const labels, + int T, + int L, + int mb, + size_t bytes_used + ); + + ProbT compute_alphas( + const ProbT* probs, + int repeats, + int S, + int T, + const int* const e_inc, + const int* const s_inc, + const int* const labels, + ProbT* alphas + ); + + ProbT compute_betas_and_grad( + ProbT* grad, + const ProbT* const probs, + ProbT log_partition, + int repeats, + int S, + int T, + const int* const e_inc, + const int* const s_inc, + const int* const labels, + ProbT* alphas, + ProbT* betas, + ProbT* output + ); }; template -CpuCTC::CpuCTC_metadata::CpuCTC_metadata(int L, int S, int T, int mb, - int alphabet_size, - 
void* workspace, size_t bytes_used, - int blank_label, - const int* const labels) { - - alphas = reinterpret_cast(static_cast(workspace) + bytes_used); +CpuCTC::CpuCTC_metadata::CpuCTC_metadata( + int L, + int S, + int T, + int mb, + int alphabet_size, + void* workspace, + size_t bytes_used, + int blank_label, + const int* const labels +) { + + alphas = reinterpret_cast(static_cast(workspace) + bytes_used); bytes_used += sizeof(ProbT) * S * T; std::fill(alphas, alphas + S * T, ctc_helper::neg_inf()); - betas = reinterpret_cast(static_cast(workspace) + bytes_used); + betas = reinterpret_cast(static_cast(workspace) + bytes_used); bytes_used += sizeof(ProbT) * S; std::fill(betas, betas + S, ctc_helper::neg_inf()); - labels_w_blanks = reinterpret_cast(static_cast(workspace) + bytes_used); + labels_w_blanks = reinterpret_cast(static_cast(workspace) + bytes_used); bytes_used += sizeof(int) * S; - e_inc = reinterpret_cast(static_cast(workspace) + bytes_used); + e_inc = reinterpret_cast(static_cast(workspace) + bytes_used); bytes_used += sizeof(int) * S; - s_inc = reinterpret_cast(static_cast(workspace) + bytes_used); + s_inc = reinterpret_cast(static_cast(workspace) + bytes_used); bytes_used += sizeof(int) * S; - output = reinterpret_cast(static_cast(workspace) + bytes_used); + output = reinterpret_cast(static_cast(workspace) + bytes_used); bytes_used += sizeof(ProbT) * alphabet_size; repeats = setup_labels(labels, blank_label, L, S); } template -int CpuCTC::CpuCTC_metadata::setup_labels(const int* const labels, - int blank_label, int L, int S) { +int CpuCTC::CpuCTC_metadata::setup_labels( + const int* const labels, + int blank_label, + int L, + int S +) { int e_counter = 0; int s_counter = 0; @@ -135,22 +182,21 @@ int CpuCTC::CpuCTC_metadata::setup_labels(const int* const labels, int repeats = 0; - for (int i = 1; i < L; ++i) { - if (labels[i-1] == labels[i]) { + for(int i = 1; i < L; ++i) { + if(labels[i - 1] == labels[i]) { s_inc[s_counter++] = 1; s_inc[s_counter++] = 
1; e_inc[e_counter++] = 1; e_inc[e_counter++] = 1; ++repeats; - } - else { + } else { s_inc[s_counter++] = 2; e_inc[e_counter++] = 2; } } e_inc[e_counter++] = 1; - for (int i = 0; i < L; ++i) { + for(int i = 0; i < L; ++i) { labels_w_blanks[2 * i] = blank_label; labels_w_blanks[2 * i + 1] = labels[i]; } @@ -160,16 +206,19 @@ int CpuCTC::CpuCTC_metadata::setup_labels(const int* const labels, } template -void -CpuCTC::softmax(const ProbT* const activations, ProbT* probs, - const int* const input_lengths) { +void CpuCTC::softmax( + const ProbT* const activations, + ProbT* probs, + const int* const input_lengths +) { #pragma omp parallel for - for (int mb = 0; mb < minibatch_; ++mb) { + for(int mb = 0; mb < minibatch_; ++mb) { for(int c = 0; c < input_lengths[mb]; ++c) { int col_offset = (mb + minibatch_ * c) * alphabet_size_; ProbT max_activation = -std::numeric_limits::infinity(); - for(int r = 0; r < alphabet_size_; ++r) + for(int r = 0; r < alphabet_size_; ++r) { max_activation = std::max(max_activation, activations[r + col_offset]); + } ProbT denom = ProbT(0.); for(int r = 0; r < alphabet_size_; ++r) { @@ -185,34 +234,54 @@ CpuCTC::softmax(const ProbT* const activations, ProbT* probs, } template -std::tuple -CpuCTC::cost_and_grad_kernel(ProbT *grad, const ProbT* const probs, - const int* const labels, - int T, int L, int mb, size_t bytes_used) { - - const int S = 2*L + 1; // Number of labels with blanks +std::tuple CpuCTC::cost_and_grad_kernel( + ProbT* grad, + const ProbT* const probs, + const int* const labels, + int T, + int L, + int mb, + size_t bytes_used +) { + + const int S = 2 * L + 1; // Number of labels with blanks CpuCTC_metadata ctcm(L, S, T, mb, alphabet_size_, workspace_, bytes_used, blank_label_, labels); bool over_threshold = false; - if (L + ctcm.repeats > T) { + if(L + ctcm.repeats > T) { return std::make_tuple(ProbT(0), over_threshold); // TODO, not right to return 0 } - ProbT llForward = compute_alphas(probs, ctcm.repeats, S, T, ctcm.e_inc, - 
ctcm.s_inc, ctcm.labels_w_blanks, - ctcm.alphas); - - ProbT llBackward = compute_betas_and_grad(grad, probs, llForward, ctcm.repeats, - S, T, ctcm.e_inc, ctcm.s_inc, - ctcm.labels_w_blanks, - ctcm.alphas, - ctcm.betas, - ctcm.output); + ProbT llForward = compute_alphas( + probs, + ctcm.repeats, + S, + T, + ctcm.e_inc, + ctcm.s_inc, + ctcm.labels_w_blanks, + ctcm.alphas + ); + + ProbT llBackward = compute_betas_and_grad( + grad, + probs, + llForward, + ctcm.repeats, + S, + T, + ctcm.e_inc, + ctcm.s_inc, + ctcm.labels_w_blanks, + ctcm.alphas, + ctcm.betas, + ctcm.output + ); ProbT diff = std::abs(llForward - llBackward); - if (diff > ctc_helper::threshold) { + if(diff > ctc_helper::threshold) { over_threshold = true; } @@ -221,39 +290,47 @@ CpuCTC::cost_and_grad_kernel(ProbT *grad, const ProbT* const probs, // Computes forward probabilities template -ProbT CpuCTC::compute_alphas(const ProbT* probs, int repeats, int S, int T, - const int* const e_inc, - const int* const s_inc, - const int* const labels, - ProbT* alphas) { +ProbT CpuCTC::compute_alphas( + const ProbT* probs, + int repeats, + int S, + int T, + const int* const e_inc, + const int* const s_inc, + const int* const labels, + ProbT* alphas +) { + + int start = (((S / 2) + repeats - T) < 0) ? 0 : 1, + end = S > 1 ? 2 : 1; - int start = (((S /2) + repeats - T) < 0) ? 0 : 1, - end = S > 1 ? 
2 : 1; - - for (int i = start; i < end; ++i) { + for(int i = start; i < end; ++i) { alphas[i] = std::log(probs[labels[i]]); } for(int t = 1; t < T; ++t) { int remain = (S / 2) + repeats - (T - t); - if(remain >= 0) + if(remain >= 0) { start += s_inc[remain]; - if(t <= (S / 2) + repeats) + } + if(t <= (S / 2) + repeats) { end += e_inc[t - 1]; + } int startloop = start; int idx1 = t * S, idx2 = (t - 1) * S, idx3 = t * (alphabet_size_ * minibatch_); - if (start == 0) { + if(start == 0) { alphas[idx1] = alphas[idx2] + std::log(probs[blank_label_ + idx3]); startloop += 1; } for(int i = startloop; i < end; ++i) { - ProbT prev_sum = ctc_helper::log_plus()(alphas[i + idx2], alphas[(i-1) + idx2]); + ProbT prev_sum = ctc_helper::log_plus()(alphas[i + idx2], alphas[(i - 1) + idx2]); // Skip two if not on blank and not on repeat. - if (labels[i] != blank_label_ && i != 1 && labels[i] != labels[i-2]) - prev_sum = ctc_helper::log_plus()(prev_sum, alphas[(i-2) + idx2]); + if(labels[i] != blank_label_ && i != 1 && labels[i] != labels[i - 2]) { + prev_sum = ctc_helper::log_plus()(prev_sum, alphas[(i - 2) + idx2]); + } alphas[i + idx1] = prev_sum + std::log(probs[labels[i] + idx3]); } @@ -273,52 +350,64 @@ ProbT CpuCTC::compute_alphas(const ProbT* probs, int repeats, int S, int // NOTE computes gradient w.r.t UNNORMALIZED final layer activations. // Assumed passed in grads are already zeroed! template -ProbT CpuCTC::compute_betas_and_grad(ProbT* grad, const ProbT* const probs, - ProbT log_partition, int repeats, - int S, int T, const int* const e_inc, - const int* const s_inc, - const int* const labels, - ProbT* alphas, - ProbT* betas, - ProbT* output) { +ProbT CpuCTC::compute_betas_and_grad( + ProbT* grad, + const ProbT* const probs, + ProbT log_partition, + int repeats, + int S, + int T, + const int* const e_inc, + const int* const s_inc, + const int* const labels, + ProbT* alphas, + ProbT* betas, + ProbT* output +) { int start = S > 1 ? 
(S - 2) : 0, - end = (T > (S / 2) + repeats) ? S : S-1; + end = (T > (S / 2) + repeats) ? S : S - 1; std::fill(output, output + alphabet_size_, ctc_helper::neg_inf()); - //set the starting values in the beta column at the very right edge - for (int i = start; i < end; ++i) { + // set the starting values in the beta column at the very right edge + for(int i = start; i < end; ++i) { betas[i] = std::log(probs[labels[i] + (T - 1) * (alphabet_size_ * minibatch_)]); - //compute alpha * beta in log space at this position in (S, T) space + // compute alpha * beta in log space at this position in (S, T) space alphas[i + (T - 1) * S] += betas[i]; - //update the gradient associated with this label - //essentially performing a reduce-by-key in a sequential manner + // update the gradient associated with this label + // essentially performing a reduce-by-key in a sequential manner output[labels[i]] = - ctc_helper::log_plus()(alphas[i + (T - 1) * S], output[labels[i]]); + ctc_helper::log_plus()(alphas[i + (T - 1) * S], output[labels[i]]); } - //update the gradient wrt to each unique label - for (int i = 0; i < alphabet_size_; ++i) { + // update the gradient wrt to each unique label + for(int i = 0; i < alphabet_size_; ++i) { int idx3 = (T - 1) * alphabet_size_ * minibatch_ + i; - if (output[i] == 0.0 || output[i] == ctc_helper::neg_inf() || - probs[idx3] == 0.0) { + if( + output[i] == 0.0 || output[i] == ctc_helper::neg_inf() + || probs[idx3] == 0.0 + ) { grad[idx3] = probs[idx3]; } else { - grad[idx3] = probs[idx3] - std::exp(output[i] - - std::log(probs[idx3]) - log_partition); + grad[idx3] = probs[idx3] - std::exp( + output[i] + - std::log(probs[idx3]) - log_partition + ); } } - //loop from the second to last column all the way to the left + // loop from the second to last column all the way to the left for(int t = T - 2; t >= 0; --t) { int remain = (S / 2) + repeats - (T - t); - if(remain >= -1) + if(remain >= -1) { start -= s_inc[remain + 1]; - if(t < (S / 2) + repeats) + } 
+ if(t < (S / 2) + repeats) { end -= e_inc[t]; + } int endloop = end == S ? end - 1 : end; int idx1 = t * S, idx3 = t * (alphabet_size_ * minibatch_); @@ -326,39 +415,43 @@ ProbT CpuCTC::compute_betas_and_grad(ProbT* grad, const ProbT* const prob std::fill(output, output + alphabet_size_, ctc_helper::neg_inf()); for(int i = start; i < endloop; ++i) { - ProbT next_sum = ctc_helper::log_plus()(betas[i], betas[(i+1)]); + ProbT next_sum = ctc_helper::log_plus()(betas[i], betas[(i + 1)]); // Skip two if not on blank and not on repeat. - if (labels[i] != blank_label_ && i != (S-2) && labels[i] != labels[i+2]){ - next_sum = ctc_helper::log_plus()(next_sum, betas[(i+2)]); + if(labels[i] != blank_label_ && i != (S - 2) && labels[i] != labels[i + 2]) { + next_sum = ctc_helper::log_plus()(next_sum, betas[(i + 2)]); } betas[i] = next_sum + std::log(probs[labels[i] + idx3]); - //compute alpha * beta in log space + // compute alpha * beta in log space alphas[i + idx1] += betas[i]; - //update the gradient associated with this label + // update the gradient associated with this label output[labels[i]] = - ctc_helper::log_plus()(alphas[i + idx1], output[labels[i]]); + ctc_helper::log_plus()(alphas[i + idx1], output[labels[i]]); } - if (end == S) { - betas[(S-1)] = betas[(S-1)] + std::log(probs[blank_label_ + idx3]); - alphas[(S-1) + idx1] += betas[(S-1)]; + if(end == S) { + betas[(S - 1)] = betas[(S - 1)] + std::log(probs[blank_label_ + idx3]); + alphas[(S - 1) + idx1] += betas[(S - 1)]; - output[labels[S-1]] = - ctc_helper::log_plus()(alphas[S-1 + idx1], output[labels[S-1]]); + output[labels[S - 1]] = + ctc_helper::log_plus()(alphas[S - 1 + idx1], output[labels[S - 1]]); } - //go over the unique labels and compute the final grad + // go over the unique labels and compute the final grad // wrt to each one at this time step - for (int i = 0; i < alphabet_size_; ++i) { + for(int i = 0; i < alphabet_size_; ++i) { - if (output[i] == 0.0 || output[i] == ctc_helper::neg_inf() || - 
probs[idx3] == 0.0) { + if( + output[i] == 0.0 || output[i] == ctc_helper::neg_inf() + || probs[idx3] == 0.0 + ) { grad[idx3] = probs[idx3]; } else { - grad[idx3] = probs[idx3] - std::exp(output[i] - - std::log(probs[idx3]) - log_partition); + grad[idx3] = probs[idx3] - std::exp( + output[i] + - std::log(probs[idx3]) - log_partition + ); } ++idx3; } @@ -373,123 +466,141 @@ ProbT CpuCTC::compute_betas_and_grad(ProbT* grad, const ProbT* const prob } template -ctcStatus_t -CpuCTC::cost_and_grad(const ProbT* const activations, - ProbT *grads, - ProbT *costs, - const int* const flat_labels, - const int* const label_lengths, - const int* const input_lengths) { - if (activations == nullptr || - grads == nullptr || - costs == nullptr || - flat_labels == nullptr || - label_lengths == nullptr || - input_lengths == nullptr - ) +ctcStatus_t CpuCTC::cost_and_grad( + const ProbT* const activations, + ProbT* grads, + ProbT* costs, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths +) { + if( + activations == nullptr + || grads == nullptr + || costs == nullptr + || flat_labels == nullptr + || label_lengths == nullptr + || input_lengths == nullptr + ) { return CTC_STATUS_INVALID_VALUE; + } - ProbT* probs = static_cast(workspace_); + ProbT* probs = static_cast(workspace_); int maxT = *std::max_element(input_lengths, input_lengths + minibatch_); size_t bytes_used = sizeof(ProbT) * minibatch_ * alphabet_size_ * maxT; - //per minibatch memory + // per minibatch memory size_t per_minibatch_bytes = 0; int maxL = *std::max_element(label_lengths, label_lengths + minibatch_); int maxS = 2 * maxL + 1; - //output + // output per_minibatch_bytes += sizeof(float) * alphabet_size_; - //alphas + // alphas per_minibatch_bytes += sizeof(float) * maxS * maxT; - //betas + // betas per_minibatch_bytes += sizeof(float) * maxS; - //labels w/blanks, e_inc, s_inc + // labels w/blanks, e_inc, s_inc per_minibatch_bytes += 3 * sizeof(int) * maxS; 
softmax(activations, probs, input_lengths); #pragma omp parallel for - for (int mb = 0; mb < minibatch_; ++mb) { + for(int mb = 0; mb < minibatch_; ++mb) { const int T = input_lengths[mb]; // Length of utterance (time) const int L = label_lengths[mb]; // Number of labels in transcription bool mb_status; std::tie(costs[mb], mb_status) = - cost_and_grad_kernel(grads + mb * alphabet_size_, - probs + mb * alphabet_size_, - flat_labels + std::accumulate(label_lengths, label_lengths + mb, 0), - T, L, mb, - bytes_used + mb * per_minibatch_bytes); + cost_and_grad_kernel( + grads + mb * alphabet_size_, + probs + mb * alphabet_size_, + flat_labels + std::accumulate(label_lengths, label_lengths + mb, 0), + T, + L, + mb, + bytes_used + mb * per_minibatch_bytes + ); } return CTC_STATUS_SUCCESS; } template -ctcStatus_t CpuCTC::score_forward(const ProbT* const activations, - ProbT* costs, - const int* const flat_labels, - const int* const label_lengths, - const int* const input_lengths) { - if (activations == nullptr || - costs == nullptr || - flat_labels == nullptr || - label_lengths == nullptr || - input_lengths == nullptr - ) +ctcStatus_t CpuCTC::score_forward( + const ProbT* const activations, + ProbT* costs, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths +) { + if( + activations == nullptr + || costs == nullptr + || flat_labels == nullptr + || label_lengths == nullptr + || input_lengths == nullptr + ) { return CTC_STATUS_INVALID_VALUE; + } - ProbT* probs = static_cast(workspace_); + ProbT* probs = static_cast(workspace_); int maxT = *std::max_element(input_lengths, input_lengths + minibatch_); size_t bytes_used = sizeof(ProbT) * minibatch_ * alphabet_size_ * maxT; - //per minibatch memory + // per minibatch memory size_t per_minibatch_bytes = 0; int maxL = *std::max_element(label_lengths, label_lengths + minibatch_); int maxS = 2 * maxL + 1; - //output + // output per_minibatch_bytes += sizeof(float) * alphabet_size_; - 
//alphas + // alphas per_minibatch_bytes += sizeof(float) * maxS * maxT; - //betas + // betas per_minibatch_bytes += sizeof(float) * maxS; - //labels w/blanks, e_inc, s_inc + // labels w/blanks, e_inc, s_inc per_minibatch_bytes += 3 * sizeof(int) * maxS; softmax(activations, probs, input_lengths); #pragma omp parallel for - for (int mb = 0; mb < minibatch_; ++mb) { + for(int mb = 0; mb < minibatch_; ++mb) { const int T = input_lengths[mb]; // Length of utterance (time) const int L = label_lengths[mb]; // Number of labels in transcription - const int S = 2*L + 1; // Number of labels with blanks + const int S = 2 * L + 1; // Number of labels with blanks CpuCTC_metadata ctcm(L, S, T, mb, alphabet_size_, workspace_, - bytes_used + mb * per_minibatch_bytes, blank_label_, - flat_labels + std::accumulate(label_lengths, label_lengths + mb, 0)); + bytes_used + mb * per_minibatch_bytes, blank_label_, + flat_labels + std::accumulate(label_lengths, label_lengths + mb, 0)); - if (L + ctcm.repeats > T) + if(L + ctcm.repeats > T) { costs[mb] = ProbT(0); - else { - costs[mb] = -compute_alphas(probs + mb * alphabet_size_, ctcm.repeats, S, T, - ctcm.e_inc, ctcm.s_inc, ctcm.labels_w_blanks, - ctcm.alphas); + } else { + costs[mb] = -compute_alphas( + probs + mb * alphabet_size_, + ctcm.repeats, + S, + T, + ctcm.e_inc, + ctcm.s_inc, + ctcm.labels_w_blanks, + ctcm.alphas + ); } } diff --git a/flashlight/pkg/speech/third_party/warpctc/include/detail/ctc_helper.h b/flashlight/pkg/speech/third_party/warpctc/include/detail/ctc_helper.h index cd30b90..8653b2c 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/detail/ctc_helper.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/detail/ctc_helper.h @@ -18,45 +18,53 @@ inline int div_up(int x, int y) { return (x + y - 1) / y; } -template struct maximum { +template +struct maximum { HOSTDEVICE Res operator()(const Arg& x, const Arg& y) const { return x < y ? 
y : x; } }; -template struct add { +template +struct add { HOSTDEVICE Res operator()(const Arg& x, const Arg& y) const { return x + y; } }; -template struct identity { - HOSTDEVICE Res operator()(const Arg& x) const {return Res(x);} +template +struct identity { + HOSTDEVICE Res operator()(const Arg& x) const { return Res(x); } }; -template struct negate { - HOSTDEVICE Res operator()(const Arg& x) const {return Res(-x);} +template +struct negate { + HOSTDEVICE Res operator()(const Arg& x) const { return Res(-x); } }; -template struct logarithmic { - HOSTDEVICE Res operator()(const Arg& x) const {return std::log(x);} +template +struct logarithmic { + HOSTDEVICE Res operator()(const Arg& x) const { return std::log(x); } }; -template struct exponential { - HOSTDEVICE Res operator()(const Arg& x) const {return std::exp(x);} +template +struct exponential { + HOSTDEVICE Res operator()(const Arg& x) const { return std::exp(x); } }; -template +template struct log_plus { typedef Res result_type; HOSTDEVICE Res operator()(const Arg1& p1, const Arg2& p2) { - if (p1 == neg_inf()) + if(p1 == neg_inf()) { return p2; - if (p2 == neg_inf()) + } + if(p2 == neg_inf()) { return p1; + } Res result = log1p(exp(-fabs(p1 - p2))) + maximum()(p1, p2); return result; } diff --git a/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc.h b/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc.h index 33f2132..2c46d9d 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc.h @@ -4,130 +4,140 @@ #include "gpu_ctc_kernels.h" #include "reduce.h" -template +template class GpuCTC { - public: - GpuCTC(int alphabet_size, - int minibatch, - void *workspace, - CUstream stream, - int blank_label) : - out_dim_(alphabet_size), minibatch_(minibatch), - gpu_workspace_(workspace), stream_(stream), - blank_label_(blank_label) {}; - - // Noncopyable - GpuCTC(const GpuCTC&) = delete; - GpuCTC& 
operator=(const GpuCTC&) = delete; - - ctcStatus_t - cost_and_grad(const ProbT* const activations, - ProbT* grads, - ProbT* costs, - const int* const flat_labels, - const int* const label_lengths, - const int* const input_lengths); - - ctcStatus_t - score_forward(const ProbT* const activations, - ProbT* costs, - const int* const flat_labels, - const int* const label_lengths, - const int* const input_lengths); - - private: - - template - ctcStatus_t launch_alpha_beta_kernels(const ProbT* const probs, - ProbT *grads, - bool compute_alpha, - bool compute_beta); - - ctcStatus_t - launch_gpu_kernels(const ProbT* const probs, - ProbT *grads, - size_t config, - bool launch_alpha, - bool launch_beta); - - ctcStatus_t - setup_gpu_metadata(const int* const flat_labels, - const int* const label_lengths, - const int* const input_lengths); - - ctcStatus_t - create_metadata_and_choose_config(const int* const label_lengths, - const int* const flat_labels, - const int* const input_lengths, - size_t& best_config); - - ctcStatus_t - compute_log_probs(const ProbT* const activations); - - ctcStatus_t - compute_cost_and_score(const ProbT* const activations, - ProbT* grads, - ProbT* costs, - const int* const flat_labels, - const int* const label_lengths, - const int* const input_lengths, - bool compute_alpha, - bool compute_betas_and_grad); - - - int out_dim_; // Number of characters plus blank - int minibatch_; - - int S_; - int T_; - - int activation_cols_; // Number of columns in activations - - CUstream stream_; - int blank_label_; - - void *gpu_workspace_; // Buffer for all temporary GPU memory - int *utt_length_; // T - int *label_sizes_; // L - int *repeats_; // repeats_ - int *label_offsets_; - int *labels_without_blanks_; - int *labels_with_blanks_; - ProbT *alphas_; - ProbT *nll_forward_; - ProbT *nll_backward_; - ProbT *denoms_; // Temporary storage for denoms for softmax - ProbT *probs_; // Temporary storage for probabilities (softmax output) +public: + GpuCTC( + int 
alphabet_size, + int minibatch, + void* workspace, + CUstream stream, + int blank_label + ) : out_dim_(alphabet_size), + minibatch_(minibatch), + gpu_workspace_(workspace), + stream_(stream), + blank_label_(blank_label) {}; + + // Noncopyable + GpuCTC(const GpuCTC&) = delete; + GpuCTC& operator=(const GpuCTC&) = delete; + + ctcStatus_t cost_and_grad( + const ProbT* const activations, + ProbT* grads, + ProbT* costs, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths + ); + + ctcStatus_t score_forward( + const ProbT* const activations, + ProbT* costs, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths + ); + +private: + + template + ctcStatus_t launch_alpha_beta_kernels( + const ProbT* const probs, + ProbT* grads, + bool compute_alpha, + bool compute_beta + ); + + ctcStatus_t launch_gpu_kernels( + const ProbT* const probs, + ProbT* grads, + size_t config, + bool launch_alpha, + bool launch_beta + ); + + ctcStatus_t setup_gpu_metadata( + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths + ); + + ctcStatus_t create_metadata_and_choose_config( + const int* const label_lengths, + const int* const flat_labels, + const int* const input_lengths, + size_t& best_config + ); + + ctcStatus_t compute_log_probs(const ProbT* const activations); + + ctcStatus_t compute_cost_and_score( + const ProbT* const activations, + ProbT* grads, + ProbT* costs, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths, + bool compute_alpha, + bool compute_betas_and_grad + ); + + + int out_dim_; // Number of characters plus blank + int minibatch_; + + int S_; + int T_; + + int activation_cols_; // Number of columns in activations + + CUstream stream_; + int blank_label_; + + void* gpu_workspace_; // Buffer for all temporary GPU memory + int* utt_length_; // T + int* label_sizes_; // L + int* repeats_; // repeats_ 
+ int* label_offsets_; + int* labels_without_blanks_; + int* labels_with_blanks_; + ProbT* alphas_; + ProbT* nll_forward_; + ProbT* nll_backward_; + ProbT* denoms_; // Temporary storage for denoms for softmax + ProbT* probs_; // Temporary storage for probabilities (softmax output) }; template -ctcStatus_t -GpuCTC::setup_gpu_metadata(const int* const flat_labels, - const int* const label_lengths, - const int* const input_lengths) -{ +ctcStatus_t GpuCTC::setup_gpu_metadata( + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths +) { size_t gpu_bytes_used = 0; nll_forward_ = - reinterpret_cast(static_cast(gpu_workspace_) + - gpu_bytes_used); + reinterpret_cast(static_cast(gpu_workspace_) + + gpu_bytes_used); gpu_bytes_used += minibatch_ * sizeof(ProbT); nll_backward_ = - reinterpret_cast(static_cast(gpu_workspace_) + - gpu_bytes_used); + reinterpret_cast(static_cast(gpu_workspace_) + + gpu_bytes_used); gpu_bytes_used += minibatch_ * sizeof(ProbT); repeats_ = - reinterpret_cast(static_cast(gpu_workspace_) + - gpu_bytes_used); + reinterpret_cast(static_cast(gpu_workspace_) + + gpu_bytes_used); gpu_bytes_used += minibatch_ * sizeof(int); label_offsets_ = - reinterpret_cast(static_cast(gpu_workspace_) + - gpu_bytes_used); + reinterpret_cast(static_cast(gpu_workspace_) + + gpu_bytes_used); gpu_bytes_used += minibatch_ * sizeof(int); @@ -151,23 +161,24 @@ GpuCTC::setup_gpu_metadata(const int* const flat_labels, cudaError_t cuda_status; - for (int pass = 0; pass < num_passes; ++pass) { + for(int pass = 0; pass < num_passes; ++pass) { const int start_idx = pass * cpu_buffer_size; - const int end_idx = std::min(minibatch_, (pass+1) * cpu_buffer_size); + const int end_idx = std::min(minibatch_, (pass + 1) * cpu_buffer_size); - for (int j = start_idx; j < end_idx; ++j) { + for(int j = start_idx; j < end_idx; ++j) { const int L = label_lengths[j]; const int local_T = input_lengths[j]; - const int *label_ptr = 
&(flat_labels[total_label_length]); + const int* label_ptr = &(flat_labels[total_label_length]); label_offsets[j % cpu_buffer_size] = total_label_length; total_label_length += L; int repeat_counter = 0; - for (int i = 1; i < L; ++i) - repeat_counter += (label_ptr[i] == label_ptr[i-1]); + for(int i = 1; i < L; ++i) { + repeat_counter += (label_ptr[i] == label_ptr[i - 1]); + } repeats[j % cpu_buffer_size] = repeat_counter; const bool valid_label = ((L + repeat_counter) <= local_T); @@ -180,18 +191,28 @@ GpuCTC::setup_gpu_metadata(const int* const flat_labels, Lmax = std::max(Lmax, L); } - cuda_status = cudaMemcpyAsync(&(repeats_[start_idx]), repeats, - (end_idx - start_idx) * sizeof(int), - cudaMemcpyHostToDevice, stream_); - if (cuda_status != cudaSuccess) + cuda_status = cudaMemcpyAsync( + &(repeats_[start_idx]), + repeats, + (end_idx - start_idx) * sizeof(int), + cudaMemcpyHostToDevice, + stream_ + ); + if(cuda_status != cudaSuccess) { return CTC_STATUS_MEMOPS_FAILED; + } - cuda_status = cudaMemcpyAsync(&(label_offsets_[start_idx]), label_offsets, - (end_idx - start_idx) * sizeof(int), - cudaMemcpyHostToDevice, stream_); - if (cuda_status != cudaSuccess) + cuda_status = cudaMemcpyAsync( + &(label_offsets_[start_idx]), + label_offsets, + (end_idx - start_idx) * sizeof(int), + cudaMemcpyHostToDevice, + stream_ + ); + if(cuda_status != cudaSuccess) { return CTC_STATUS_MEMOPS_FAILED; + } } S_ = 2 * S_ + 1; @@ -201,55 +222,70 @@ GpuCTC::setup_gpu_metadata(const int* const flat_labels, // Allocate memory for T utt_length_ = - reinterpret_cast(static_cast(gpu_workspace_) + - gpu_bytes_used); - gpu_bytes_used += minibatch_ * sizeof(int); - - cuda_status = cudaMemcpyAsync(utt_length_, input_lengths, - minibatch_ * sizeof(int), - cudaMemcpyHostToDevice, stream_); - if (cuda_status != cudaSuccess) + reinterpret_cast(static_cast(gpu_workspace_) + + gpu_bytes_used); + gpu_bytes_used += minibatch_ * sizeof(int); + + cuda_status = cudaMemcpyAsync( + utt_length_, + input_lengths, 
+ minibatch_ * sizeof(int), + cudaMemcpyHostToDevice, + stream_ + ); + if(cuda_status != cudaSuccess) { return CTC_STATUS_MEMOPS_FAILED; + } label_sizes_ = - reinterpret_cast(static_cast(gpu_workspace_) + - gpu_bytes_used); + reinterpret_cast(static_cast(gpu_workspace_) + + gpu_bytes_used); gpu_bytes_used += minibatch_ * sizeof(int); - cuda_status = cudaMemcpyAsync(label_sizes_, label_lengths, - minibatch_ * sizeof(int), - cudaMemcpyHostToDevice, stream_); - if (cuda_status != cudaSuccess) + cuda_status = cudaMemcpyAsync( + label_sizes_, + label_lengths, + minibatch_ * sizeof(int), + cudaMemcpyHostToDevice, + stream_ + ); + if(cuda_status != cudaSuccess) { return CTC_STATUS_MEMOPS_FAILED; + } labels_without_blanks_ = - reinterpret_cast(static_cast(gpu_workspace_) + - gpu_bytes_used); + reinterpret_cast(static_cast(gpu_workspace_) + + gpu_bytes_used); gpu_bytes_used += Lmax * minibatch_ * sizeof(int); - cuda_status = cudaMemcpyAsync(labels_without_blanks_, flat_labels, - total_label_length * sizeof(int), - cudaMemcpyHostToDevice, stream_); - if (cuda_status != cudaSuccess) + cuda_status = cudaMemcpyAsync( + labels_without_blanks_, + flat_labels, + total_label_length * sizeof(int), + cudaMemcpyHostToDevice, + stream_ + ); + if(cuda_status != cudaSuccess) { return CTC_STATUS_MEMOPS_FAILED; + } labels_with_blanks_ = - reinterpret_cast(static_cast(gpu_workspace_) + - gpu_bytes_used); + reinterpret_cast(static_cast(gpu_workspace_) + + gpu_bytes_used); gpu_bytes_used += Smax * minibatch_ * sizeof(int); alphas_ = - reinterpret_cast(static_cast(gpu_workspace_) + - gpu_bytes_used); + reinterpret_cast(static_cast(gpu_workspace_) + + gpu_bytes_used); gpu_bytes_used += (S_ * T_) * minibatch_ * sizeof(ProbT); denoms_ = - reinterpret_cast(static_cast(gpu_workspace_) + - gpu_bytes_used); + reinterpret_cast(static_cast(gpu_workspace_) + + gpu_bytes_used); gpu_bytes_used += activation_cols_ * sizeof(ProbT); probs_ = - reinterpret_cast(static_cast(gpu_workspace_) + - gpu_bytes_used); 
+ reinterpret_cast(static_cast(gpu_workspace_) + + gpu_bytes_used); gpu_bytes_used += out_dim_ * activation_cols_ * sizeof(ProbT); return CTC_STATUS_SUCCESS; @@ -257,10 +293,12 @@ GpuCTC::setup_gpu_metadata(const int* const flat_labels, template template -ctcStatus_t GpuCTC::launch_alpha_beta_kernels(const ProbT* const probs, - ProbT* grads, - bool compute_alpha, - bool compute_beta ) { +ctcStatus_t GpuCTC::launch_alpha_beta_kernels( + const ProbT* const probs, + ProbT* grads, + bool compute_alpha, + bool compute_beta +) { // One thread block per utterance const int grid_size = minibatch_; @@ -269,108 +307,126 @@ ctcStatus_t GpuCTC::launch_alpha_beta_kernels(const ProbT* const probs, // away const int stride = minibatch_; - if (compute_alpha) - compute_alpha_kernel<<>> - (probs, label_sizes_, utt_length_, - repeats_, labels_without_blanks_, label_offsets_, - labels_with_blanks_, alphas_, nll_forward_, - stride, out_dim_, S_, T_, blank_label_); + if(compute_alpha) { + compute_alpha_kernel<< < grid_size, NT, 0, stream_ >> + > (probs, label_sizes_, utt_length_, + repeats_, labels_without_blanks_, label_offsets_, + labels_with_blanks_, alphas_, nll_forward_, + stride, out_dim_, S_, T_, blank_label_); + } - if (compute_beta) { - compute_betas_and_grad_kernel<<>> - (probs, label_sizes_, utt_length_, repeats_, - labels_with_blanks_, alphas_, nll_forward_, nll_backward_, - grads, stride, out_dim_, S_, T_, blank_label_); + if(compute_beta) { + compute_betas_and_grad_kernel<< < grid_size, NT, 0, stream_ >> + > (probs, label_sizes_, utt_length_, repeats_, + labels_with_blanks_, alphas_, nll_forward_, nll_backward_, + grads, stride, out_dim_, S_, T_, blank_label_); cudaStreamSynchronize(stream_); } cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) + if(err != cudaSuccess) { return CTC_STATUS_EXECUTION_FAILED; + } return CTC_STATUS_SUCCESS; } template -ctcStatus_t -GpuCTC::create_metadata_and_choose_config(const int* const flat_labels, - const int* const 
label_lengths, - const int* const input_lengths, - size_t& best_config) { +ctcStatus_t GpuCTC::create_metadata_and_choose_config( + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths, + size_t& best_config +) { // Setup the metadata for GPU ctcStatus_t status = setup_gpu_metadata(flat_labels, label_lengths, input_lengths); - if (status != CTC_STATUS_SUCCESS) + if(status != CTC_STATUS_SUCCESS) { return status; + } constexpr int num_configs = 12; int config_NT[num_configs] = - {32, 64, 128, 64, 128, 32, 64, 128, 64, 128, 128, 128}; + {32, 64, 128, 64, 128, 32, 64, 128, 64, 128, 128, 128}; int config_VT[num_configs] = - { 1, 1, 1, 3, 2, 9, 6, 4, 9, 6, 9, 10}; + {1, 1, 1, 3, 2, 9, 6, 4, 9, 6, 9, 10}; best_config = 0; - for (int i = 0; i < num_configs; ++i) { - if ((config_NT[i]* config_VT[i]) >= S_) + for(int i = 0; i < num_configs; ++i) { + if((config_NT[i] * config_VT[i]) >= S_) { break; - else + } else { best_config++; + } } - if (best_config >= num_configs) + if(best_config >= num_configs) { return CTC_STATUS_LABEL_LENGTH_TOO_LARGE; + } return CTC_STATUS_SUCCESS; } template -ctcStatus_t -GpuCTC::launch_gpu_kernels(const ProbT* const probs, - ProbT* grads, - size_t config, - bool l_a, - bool l_b) { +ctcStatus_t GpuCTC::launch_gpu_kernels( + const ProbT* const probs, + ProbT* grads, + size_t config, + bool l_a, + bool l_b +) { switch(config) { - case 0: {return launch_alpha_beta_kernels<32, 1>(probs, grads, l_a, l_b);} - case 1: {return launch_alpha_beta_kernels<64, 1>(probs, grads, l_a, l_b);} - case 2: {return launch_alpha_beta_kernels<128, 1>(probs, grads, l_a, l_b);} - case 3: {return launch_alpha_beta_kernels<64, 3>(probs, grads, l_a, l_b);} - case 4: {return launch_alpha_beta_kernels<128, 2>(probs, grads, l_a, l_b);} - case 5: {return launch_alpha_beta_kernels<32, 9>(probs, grads, l_a, l_b);} - case 6: {return launch_alpha_beta_kernels<64, 6>(probs, grads, l_a, l_b);} - case 7: {return launch_alpha_beta_kernels<128, 
4>(probs, grads, l_a, l_b);} - case 8: {return launch_alpha_beta_kernels<64, 9>(probs, grads, l_a, l_b);} - case 9: {return launch_alpha_beta_kernels<128, 6>(probs, grads, l_a, l_b);} - case 10: {return launch_alpha_beta_kernels<128, 9>(probs, grads, l_a, l_b);} - case 11: {return launch_alpha_beta_kernels<128, 10>(probs, grads, l_a, l_b);} + case 0: return launch_alpha_beta_kernels<32, 1>(probs, grads, l_a, l_b); + case 1: return launch_alpha_beta_kernels<64, 1>(probs, grads, l_a, l_b); + case 2: return launch_alpha_beta_kernels<128, 1>(probs, grads, l_a, l_b); + case 3: return launch_alpha_beta_kernels<64, 3>(probs, grads, l_a, l_b); + case 4: return launch_alpha_beta_kernels<128, 2>(probs, grads, l_a, l_b); + case 5: return launch_alpha_beta_kernels<32, 9>(probs, grads, l_a, l_b); + case 6: return launch_alpha_beta_kernels<64, 6>(probs, grads, l_a, l_b); + case 7: return launch_alpha_beta_kernels<128, 4>(probs, grads, l_a, l_b); + case 8: return launch_alpha_beta_kernels<64, 9>(probs, grads, l_a, l_b); + case 9: return launch_alpha_beta_kernels<128, 6>(probs, grads, l_a, l_b); + case 10: return launch_alpha_beta_kernels<128, 9>(probs, grads, l_a, l_b); + case 11: return launch_alpha_beta_kernels<128, 10>(probs, grads, l_a, l_b); } return CTC_STATUS_EXECUTION_FAILED; } template -ctcStatus_t -GpuCTC::compute_log_probs(const ProbT* const activations) { +ctcStatus_t GpuCTC::compute_log_probs(const ProbT* const activations) { cudaError_t cuda_status; cuda_status = - cudaMemcpyAsync(probs_, activations, - activation_cols_ * out_dim_ *sizeof(ProbT), - cudaMemcpyDeviceToDevice, stream_); - if (cuda_status != cudaSuccess) + cudaMemcpyAsync( + probs_, + activations, + activation_cols_ * out_dim_ * sizeof(ProbT), + cudaMemcpyDeviceToDevice, + stream_ + ); + if(cuda_status != cudaSuccess) { return CTC_STATUS_MEMOPS_FAILED; + } // Numerically stable SM ctcStatus_t ctc_status = - reduce_max(probs_, denoms_, out_dim_, - activation_cols_, 1, stream_); - if (ctc_status != 
CTC_STATUS_SUCCESS) + reduce_max( + probs_, + denoms_, + out_dim_, + activation_cols_, + 1, + stream_ + ); + if(ctc_status != CTC_STATUS_SUCCESS) { return ctc_status; + } // Kernel launch to subtract maximum const int NT = 128; @@ -379,99 +435,144 @@ GpuCTC::compute_log_probs(const ProbT* const activations) { const int num_elements = out_dim_ * activation_cols_; const int grid_size = ctc_helper::div_up(num_elements, NV); - prepare_stable_SM_kernel <<< grid_size, NT, 0, stream_>>> - (ctc_helper::identity(), probs_, + prepare_stable_SM_kernel << < grid_size, NT, 0, stream_ >> + > (ctc_helper::identity(), probs_, denoms_, out_dim_, num_elements); // Reduce along columns to calculate denominator ctc_status = - reduce_exp(probs_, denoms_, out_dim_, - activation_cols_, 1, stream_); - if (ctc_status != CTC_STATUS_SUCCESS) + reduce_exp( + probs_, + denoms_, + out_dim_, + activation_cols_, + 1, + stream_ + ); + if(ctc_status != CTC_STATUS_SUCCESS) { return ctc_status; + } // Kernel launch to calculate probabilities - compute_log_probs_kernel<<>> - (ctc_helper::logarithmic(), probs_, - denoms_, out_dim_, num_elements); + compute_log_probs_kernel<< < grid_size, NT, 0, stream_ >> + > (ctc_helper::logarithmic(), probs_, + denoms_, out_dim_, num_elements); return CTC_STATUS_SUCCESS; } template -ctcStatus_t -GpuCTC::compute_cost_and_score(const ProbT* const activations, - ProbT* grads, - ProbT* costs, - const int* const flat_labels, - const int* const label_lengths, - const int* const input_lengths, - bool compute_alpha, - bool compute_betas_and_grad) { +ctcStatus_t GpuCTC::compute_cost_and_score( + const ProbT* const activations, + ProbT* grads, + ProbT* costs, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths, + bool compute_alpha, + bool compute_betas_and_grad +) { size_t best_config; - ctcStatus_t status = create_metadata_and_choose_config(flat_labels, - label_lengths, - input_lengths, - best_config); - if (status != 
CTC_STATUS_SUCCESS) + ctcStatus_t status = create_metadata_and_choose_config( + flat_labels, + label_lengths, + input_lengths, + best_config + ); + if(status != CTC_STATUS_SUCCESS) { return status; + } status = compute_log_probs(activations); - if (status != CTC_STATUS_SUCCESS) + if(status != CTC_STATUS_SUCCESS) { return status; + } - status = launch_gpu_kernels(probs_, grads, best_config, - compute_alpha, compute_betas_and_grad); + status = launch_gpu_kernels( + probs_, + grads, + best_config, + compute_alpha, + compute_betas_and_grad + ); - if (status != CTC_STATUS_SUCCESS) + if(status != CTC_STATUS_SUCCESS) { return status; + } cudaError_t cuda_status_mem, cuda_status_sync; - cuda_status_mem = cudaMemcpyAsync(costs, nll_forward_, - sizeof(ProbT) * minibatch_, - cudaMemcpyDeviceToHost, stream_); + cuda_status_mem = cudaMemcpyAsync( + costs, + nll_forward_, + sizeof(ProbT) * minibatch_, + cudaMemcpyDeviceToHost, + stream_ + ); cuda_status_sync = cudaStreamSynchronize(stream_); - if (cuda_status_mem != cudaSuccess || cuda_status_sync != cudaSuccess) + if(cuda_status_mem != cudaSuccess || cuda_status_sync != cudaSuccess) { return CTC_STATUS_MEMOPS_FAILED; + } return CTC_STATUS_SUCCESS; } template -ctcStatus_t -GpuCTC::cost_and_grad(const ProbT* const activations, - ProbT* grads, - ProbT* costs, - const int* const flat_labels, - const int* const label_lengths, - const int* const input_lengths) { - if (activations == nullptr || - grads == nullptr || - costs == nullptr || - label_lengths == nullptr || - input_lengths == nullptr - ) +ctcStatus_t GpuCTC::cost_and_grad( + const ProbT* const activations, + ProbT* grads, + ProbT* costs, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths +) { + if( + activations == nullptr + || grads == nullptr + || costs == nullptr + || label_lengths == nullptr + || input_lengths == nullptr + ) { return CTC_STATUS_INVALID_VALUE; + } - return compute_cost_and_score(activations, grads, costs, 
flat_labels, - label_lengths, input_lengths, true, true); + return compute_cost_and_score( + activations, + grads, + costs, + flat_labels, + label_lengths, + input_lengths, + true, + true + ); } template -ctcStatus_t -GpuCTC::score_forward(const ProbT* const activations, - ProbT* costs, - const int* const flat_labels, - const int* const label_lengths, - const int* const input_lengths) { - if (activations == nullptr || - costs == nullptr || - label_lengths == nullptr || - input_lengths == nullptr - ) +ctcStatus_t GpuCTC::score_forward( + const ProbT* const activations, + ProbT* costs, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths +) { + if( + activations == nullptr + || costs == nullptr + || label_lengths == nullptr + || input_lengths == nullptr + ) { return CTC_STATUS_INVALID_VALUE; + } - return compute_cost_and_score(activations, nullptr, costs, flat_labels, - label_lengths, input_lengths, true, false); + return compute_cost_and_score( + activations, + nullptr, + costs, + flat_labels, + label_lengths, + input_lengths, + true, + false + ); } diff --git a/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc_kernels.h b/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc_kernels.h index 1735ba8..3318e37 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc_kernels.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc_kernels.h @@ -16,18 +16,23 @@ struct CTASegReduce { int indices[NV]; }; - //adapted from global kernel KernelReduceByKeyPreprocess - __device__ static void preprocessKeys(KeyT *keys, int count, - int *numUniqueLabels, int seg_start[VT], - int seg_end[VT], int *scanout) { + // adapted from global kernel KernelReduceByKeyPreprocess + __device__ static void preprocessKeys( + KeyT* keys, + int count, + int* numUniqueLabels, + int seg_start[VT], + int seg_end[VT], + int* scanout + ) { __shared__ Storage shared; const int tid = threadIdx.x; 
// Compare adjacent keys within each thread and mark discontinuities int endFlags = 0; T key = keys[VT * tid]; - #pragma unroll - for (int i = 0; i < VT; ++i) { +#pragma unroll + for(int i = 0; i < VT; ++i) { int index = VT * tid + 1 + i; T next = keys[index]; if(index == count || (index < count && key != next)) { @@ -38,18 +43,18 @@ struct CTASegReduce { __syncthreads(); - //Count the number of encountered end flags + // Count the number of encountered end flags int scan = CTAScan::Scan(tid, popc(endFlags), shared.scanStorage, numUniqueLabels); __syncthreads(); - //output the unique keys - //use indices as scratch space + // output the unique keys + // use indices as scratch space int outputPos = scan; - #pragma unroll - for (int i = 0; i < VT; ++i) { +#pragma unroll + for(int i = 0; i < VT; ++i) { - if ( (endFlags >> i) & 1) { + if((endFlags >> i) & 1) { shared.indices[outputPos] = keys[VT * tid + i]; scanout[outputPos] = VT * tid + i; outputPos++; @@ -59,16 +64,16 @@ struct CTASegReduce { __syncthreads(); // Create start and end - for (int idx = tid, j = 0; idx < (*numUniqueLabels); idx += blockDim.x, ++j) { - seg_start[j] = (idx == 0) ? 0 : (scanout[idx-1] + 1); + for(int idx = tid, j = 0; idx < (*numUniqueLabels); idx += blockDim.x, ++j) { + seg_start[j] = (idx == 0) ? 0 : (scanout[idx - 1] + 1); seg_end[j] = scanout[idx]; } __syncthreads(); - //copy from the scratch space back into the keys - #pragma unroll - for (int i = 0; i < VT; ++i) { + // copy from the scratch space back into the keys +#pragma unroll + for(int i = 0; i < VT; ++i) { keys[i * NT + tid] = shared.indices[i * NT + tid]; } @@ -85,67 +90,79 @@ struct CTASegReduce { // than the labels. This is much more true for Mandarin than English. 
template __global__ -void compute_alpha_kernel (const ProbT* probs, const int *label_sizes, - const int *utt_length, const int *repeats_in_labels, - const int *labels_without_blanks, const int *label_offsets, - int *labels_with_blanks, ProbT *alphas, - ProbT* nll_forward, int stride, int out_dim, - int S_memoffset, int T_memoffset, int blank_label) { +void compute_alpha_kernel( + const ProbT* probs, + const int* label_sizes, + const int* utt_length, + const int* repeats_in_labels, + const int* labels_without_blanks, + const int* label_offsets, + int* labels_with_blanks, + ProbT* alphas, + ProbT* nll_forward, + int stride, + int out_dim, + int S_memoffset, + int T_memoffset, + int blank_label +) { ctc_helper::log_plus log_plus_f; const int tid = threadIdx.x; const int L = label_sizes[blockIdx.x]; const int T = utt_length[blockIdx.x]; - const int S = 2*L + 1; + const int S = 2 * L + 1; const int prob_offset = out_dim * blockIdx.x; const int repeats = repeats_in_labels[blockIdx.x]; const int NV = NT * VT; __shared__ int label[NV]; - if ((L + repeats) > T) + if((L + repeats) > T) { return; + } // Generate labels with blanks from labels without blanks { const int label_start_offset = label_offsets[blockIdx.x]; - for (int idx = tid; idx < L; idx += blockDim.x) { + for(int idx = tid; idx < L; idx += blockDim.x) { const int offset = (blockIdx.x * S_memoffset) + 2 * idx; labels_with_blanks[offset] = blank_label; - labels_with_blanks[offset+1] = labels_without_blanks[label_start_offset + idx]; + labels_with_blanks[offset + 1] = labels_without_blanks[label_start_offset + idx]; } - if (tid == 0) { + if(tid == 0) { labels_with_blanks[(blockIdx.x * S_memoffset) + 2 * L] = blank_label; } } __syncthreads(); - const int *labels = labels_with_blanks; + const int* labels = labels_with_blanks; const int* label_global = &labels[blockIdx.x * S_memoffset]; ProbT* alpha = &alphas[blockIdx.x * (S_memoffset * T_memoffset)]; // Set the first row of alpha neg_inf - it is much more efficient 
to do it // here than outside - #pragma unroll - for (int idx = tid; idx < min(S, NV); idx += blockDim.x) { +#pragma unroll + for(int idx = tid; idx < min(S, NV); idx += blockDim.x) { alpha[idx] = ctc_helper::neg_inf(); } // Load labels into shared memory - #pragma unroll - for (int i = tid; i < S; i += NT) { +#pragma unroll + for(int i = tid; i < S; i += NT) { label[i] = label_global[i]; } __syncthreads(); - int start = (L + repeats < T) ? 0 : 1; + int start = (L + repeats < T) ? 0 : 1; int end = S > 1 ? 2 : 1; // Initialize the first row corresponding to t=0; - for(int i = tid; i < (end-start); i += blockDim.x) + for(int i = tid; i < (end - start); i += blockDim.x) { alpha[i + start] = probs[prob_offset + label[i + start]]; + } __syncthreads(); @@ -161,12 +178,11 @@ void compute_alpha_kernel (const ProbT* probs, const int *label_sizes, const int start_prob_col = t * (out_dim * stride); // This is the first column and in this case there is nothing left of it - if (tid == 0) { - if (start == 0) { - alpha[start_cur_row] = alpha[start_prev_row] + - probs[prob_offset + start_prob_col + blank_label]; - } - else if (start == 1) { + if(tid == 0) { + if(start == 0) { + alpha[start_cur_row] = alpha[start_prev_row] + + probs[prob_offset + start_prob_col + blank_label]; + } else if(start == 1) { alpha[start_cur_row] = alpha[start_prev_row]; } } @@ -177,15 +193,18 @@ void compute_alpha_kernel (const ProbT* probs, const int *label_sizes, // input is the row above. We sum either two or three adjacent values from the // row above depending on whether we have a blank or repeated characters. 
Finally // we add the probability corresponding to this label at time t - #pragma unroll - for (int idx = (tid+1); idx < S; idx += blockDim.x) { +#pragma unroll + for(int idx = (tid + 1); idx < S; idx += blockDim.x) { - ProbT prev_sum = log_plus_f(alpha[idx + start_prev_row], alpha[(idx-1) + start_prev_row]); + ProbT prev_sum = log_plus_f(alpha[idx + start_prev_row], alpha[(idx - 1) + start_prev_row]); // Skip two if not on blank and not on repeat. - if ((label[idx] != blank_label) && - (idx != 1) && (label[idx] != label[idx-2])) - prev_sum = log_plus_f(prev_sum, alpha[(idx-2) + start_prev_row]); + if( + (label[idx] != blank_label) + && (idx != 1) && (label[idx] != label[idx - 2]) + ) { + prev_sum = log_plus_f(prev_sum, alpha[(idx - 2) + start_prev_row]); + } alpha[idx + start_cur_row] = prev_sum + probs[prob_offset + start_prob_col + label[idx]]; @@ -194,18 +213,19 @@ void compute_alpha_kernel (const ProbT* probs, const int *label_sizes, __syncthreads(); } - if (tid == 0) { + if(tid == 0) { // Add and return the rightmost two/one element(s) in the last row. ProbT loglike = ctc_helper::neg_inf(); // This is the total increment for s_inc and e_inc through the loop - const int val = 2 * (L-1) + 1 - (((L + repeats) == T) ? 1 : 0); + const int val = 2 * (L - 1) + 1 - (((L + repeats) == T) ? 1 : 0); - start = (val * (L!=0) + start); - end = (val * (L!=0) + end); + start = (val * (L != 0) + start); + end = (val * (L != 0) + end); - for(int i = start; i < end; ++i) + for(int i = start; i < end; ++i) { loglike = log_plus_f(loglike, alpha[i + (T - 1) * S]); + } nll_forward[blockIdx.x] = -loglike; } @@ -216,12 +236,22 @@ void compute_alpha_kernel (const ProbT* probs, const int *label_sizes, // See comments above compute_alphas for more context. 
template __global__ -void compute_betas_and_grad_kernel (const ProbT* probs, const int *label_sizes, - const int *utt_length, const int *repeats_in_labels, - const int *labels_with_blanks, ProbT *alphas, - const ProbT* nll_forward, ProbT *nll_backward, - ProbT *grads, int stride, int out_dim, - int S_memoffset, int T_memoffset, int blank_label) { +void compute_betas_and_grad_kernel( + const ProbT* probs, + const int* label_sizes, + const int* utt_length, + const int* repeats_in_labels, + const int* labels_with_blanks, + ProbT* alphas, + const ProbT* nll_forward, + ProbT* nll_backward, + ProbT* grads, + int stride, + int out_dim, + int S_memoffset, + int T_memoffset, + int blank_label +) { ctc_helper::log_plus log_plus_f; typedef CTASegReduce> SegReduce; @@ -229,7 +259,7 @@ void compute_betas_and_grad_kernel (const ProbT* probs, const int *label_sizes, const int tid = threadIdx.x; const int L = label_sizes[blockIdx.x]; const int T = utt_length[blockIdx.x]; - const int S = 2*L + 1; + const int S = 2 * L + 1; const int prob_offset = out_dim * blockIdx.x; const int repeats = repeats_in_labels[blockIdx.x]; const ProbT log_partition = -nll_forward[blockIdx.x]; @@ -257,15 +287,16 @@ void compute_betas_and_grad_kernel (const ProbT* probs, const int *label_sizes, ProbT beta_val[VT]; - if ((L + repeats) > T) + if((L + repeats) > T) { return; + } int start = S > 1 ? (S - 2) : 0; - int end = (L + repeats < T) ? S : S-1; + int end = (L + repeats < T) ? S : S - 1; // Setup shared memory buffers - #pragma unroll - for (int idx = tid; idx < NV; idx += NT) { +#pragma unroll + for(int idx = tid; idx < NV; idx += NT) { label[idx] = (idx < S) ? 
label_global[idx] : INT_MAX; } @@ -281,8 +312,8 @@ void compute_betas_and_grad_kernel (const ProbT* probs, const int *label_sizes, int key[VT]; int gather_val[VT]; - #pragma unroll - for (int i = 0; i < VT; ++i) { +#pragma unroll + for(int i = 0; i < VT; ++i) { const int idx = tid * VT + i; gather_val[i] = idx; key[i] = label[idx]; @@ -290,20 +321,33 @@ void compute_betas_and_grad_kernel (const ProbT* probs, const int *label_sizes, __syncthreads(); - CTAMergesort> - (key, gather_val, keys_shared, gather_indices, S, tid, mgpu::less()); + CTAMergesort>( + key, + gather_val, + keys_shared, + gather_indices, + S, + tid, + mgpu::less() + ); __syncthreads(); - for (int i = 0; i < VT; ++i) { + for(int i = 0; i < VT; ++i) { const int idx = tid * VT + i; gather_indices[idx] = gather_val[i]; } __syncthreads(); - SegReduce::preprocessKeys(keys_shared, S, &uniquelabels, seg_start, seg_end, - temp_buffer.result); + SegReduce::preprocessKeys( + keys_shared, + S, + &uniquelabels, + seg_start, + seg_end, + temp_buffer.result + ); __syncthreads(); } @@ -311,22 +355,23 @@ void compute_betas_and_grad_kernel (const ProbT* probs, const int *label_sizes, __syncthreads(); // Load labels back - #pragma unroll - for (int idx = tid; idx < NV; idx += NT) { +#pragma unroll + for(int idx = tid; idx < NV; idx += NT) { temp_buffer.beta[idx] = ctc_helper::neg_inf(); } __syncthreads(); // Initialize the two rightmost values in the last row (assuming L non-zero) - for(int i = tid; i < (end-start); i += blockDim.x) + for(int i = tid; i < (end - start); i += blockDim.x) { temp_buffer.beta[i + start] = probs[prob_offset + (T - 1) * (out_dim * stride) + label[i + start]]; + } __syncthreads(); // Load output data in registers through the transpose trick - should really be a function - #pragma unroll - for (int idx = tid; idx < S; idx += NT) { +#pragma unroll + for(int idx = tid; idx < S; idx += NT) { output[idx] = alpha[idx + (T - 1) * S] + temp_buffer.beta[idx]; } @@ -341,19 +386,22 @@ void 
compute_betas_and_grad_kernel (const ProbT* probs, const int *label_sizes, // Starting offset of column that we read from the probs array const int start_prob_col = t * (out_dim * stride); - if (t < T-1) { + if(t < T - 1) { // Filling up one row at at time but going back in time from the last row // to the first. As in the forward pass, there is no loop dependence and we // do a variable length filter of maximum filter size of 3 - #pragma unroll - for(int idx = tid, i = 0; idx < (S-1); idx += NT, i++) { - ProbT next_sum = log_plus_f(temp_buffer.beta[idx], temp_buffer.beta[idx+1]); - - // Skip two if not on blank and not on repeat. - if ((label[idx] != blank_label) && - (idx != (S-2)) && (label[idx] != label[idx+2])) - next_sum = log_plus_f(next_sum, temp_buffer.beta[idx+2]); +#pragma unroll + for(int idx = tid, i = 0; idx < (S - 1); idx += NT, i++) { + ProbT next_sum = log_plus_f(temp_buffer.beta[idx], temp_buffer.beta[idx + 1]); + + // Skip two if not on blank and not on repeat. + if( + (label[idx] != blank_label) + && (idx != (S - 2)) && (label[idx] != label[idx + 2]) + ) { + next_sum = log_plus_f(next_sum, temp_buffer.beta[idx + 2]); + } beta_val[i] = next_sum + probs[prob_offset + start_prob_col + label[idx]]; } @@ -362,22 +410,23 @@ void compute_betas_and_grad_kernel (const ProbT* probs, const int *label_sizes, // Initialize values for the rightmost column since there is nothing to the right // Update input buffer for next iteration - if ((tid == 0) && (end == S)) - temp_buffer.beta[(S-1)] = temp_buffer.beta[(S-1)] + - probs[prob_offset + start_prob_col + blank_label]; + if((tid == 0) && (end == S)) { + temp_buffer.beta[(S - 1)] = temp_buffer.beta[(S - 1)] + + probs[prob_offset + start_prob_col + blank_label]; + } - #pragma unroll - for(int idx = tid, i = 0; idx < (S-1); idx += NT, i++) { - temp_buffer.beta[idx] = beta_val[i]; +#pragma unroll + for(int idx = tid, i = 0; idx < (S - 1); idx += NT, i++) { + temp_buffer.beta[idx] = beta_val[i]; } __syncthreads(); 
// Beta Computation done - add to alpha and update the gradient. Reload // the gradient back for segmented reduce later on - #pragma unroll +#pragma unroll for(int idx = tid; idx < S; idx += NT) { - output[idx] = alpha[idx + start_cur_row] + temp_buffer.beta[idx]; + output[idx] = alpha[idx + start_cur_row] + temp_buffer.beta[idx]; } __syncthreads(); @@ -391,38 +440,39 @@ void compute_betas_and_grad_kernel (const ProbT* probs, const int *label_sizes, // Somewhat faster key value reduce ProbT accum[VT]; - for (int idx = tid, j = 0; idx < uniquelabels; idx += blockDim.x, ++j) { + for(int idx = tid, j = 0; idx < uniquelabels; idx += blockDim.x, ++j) { accum[j] = ctc_helper::neg_inf(); - for (int i = seg_start[j]; i <= seg_end[j]; ++i) { + for(int i = seg_start[j]; i <= seg_end[j]; ++i) { accum[j] = log_plus_f(accum[j], output[gather_indices[i]]); } } __syncthreads(); // Write accumulated value into output since that is not used - for (int idx = tid, j = 0; idx < uniquelabels; idx += blockDim.x, ++j) { + for(int idx = tid, j = 0; idx < uniquelabels; idx += blockDim.x, ++j) { output[idx] = accum[j]; } __syncthreads(); - for (int idx = tid; idx < out_dim; idx += blockDim.x) { + for(int idx = tid; idx < out_dim; idx += blockDim.x) { const int grads_offset = prob_offset + start_prob_col + idx; grads[grads_offset] = exp(probs[grads_offset]); } __syncthreads(); - for (int idx = tid; idx < uniquelabels; idx += blockDim.x) { + for(int idx = tid; idx < uniquelabels; idx += blockDim.x) { const int grads_offset = prob_offset + start_prob_col + keys_shared[idx]; ProbT grad = output[idx]; - if ((grad == 0.0) || (exp(probs[grads_offset]) == 0.0) || - (grad == ctc_helper::neg_inf())) { - } else { - grads[grads_offset] = exp(probs[grads_offset]) - - exp(grad - probs[grads_offset] - log_partition); + if( + (grad == 0.0) || (exp(probs[grads_offset]) == 0.0) + || (grad == ctc_helper::neg_inf()) + ) {} else { + grads[grads_offset] = exp(probs[grads_offset]) + - exp(grad - 
probs[grads_offset] - log_partition); } } @@ -430,17 +480,18 @@ void compute_betas_and_grad_kernel (const ProbT* probs, const int *label_sizes, } // Output backward log likelihood - if ((t == 0) && (tid == 0)) { + if((t == 0) && (tid == 0)) { ProbT loglike = ctc_helper::neg_inf(); - const int val = 2 * (L-1) + 1 - (((L + repeats) == T) ? 1 : 0); + const int val = 2 * (L - 1) + 1 - (((L + repeats) == T) ? 1 : 0); start = (-val * (L != 0) + start); end = (-val * (L != 0) + end); // Sum and return the leftmost one/two value(s) in first row - for(int i = start; i < end; ++i) + for(int i = start; i < end; ++i) { loglike = log_plus_f(loglike, temp_buffer.beta[i]); + } nll_backward[blockIdx.x] = -loglike; } @@ -450,17 +501,20 @@ void compute_betas_and_grad_kernel (const ProbT* probs, const int *label_sizes, } } -template -__global__ void compute_log_probs_kernel(Op f, ProbT* probs, - const ProbT* const denom, - int alphabet_size, - int count) { +template +__global__ void compute_log_probs_kernel( + Op f, + ProbT* probs, + const ProbT* const denom, + int alphabet_size, + int count +) { int idx = blockDim.x * blockIdx.x + threadIdx.x; int stride = blockDim.x * gridDim.x; #pragma unroll for(int i = 0; i < VT; i++) { - if (idx < count) { + if(idx < count) { const int column_idx = idx / alphabet_size; probs[idx] -= f(denom[column_idx]); } @@ -468,20 +522,23 @@ __global__ void compute_log_probs_kernel(Op f, ProbT* probs, } } -template -__global__ void prepare_stable_SM_kernel(Op f, ProbT* probs, - const ProbT* const col_max, - int alphabet_size, - int count) { +template +__global__ void prepare_stable_SM_kernel( + Op f, + ProbT* probs, + const ProbT* const col_max, + int alphabet_size, + int count +) { int idx = blockDim.x * blockIdx.x + threadIdx.x; int stride = blockDim.x * gridDim.x; #pragma unroll for(int i = 0; i < VT; i++) { - if (idx < count) { + if(idx < count) { const int column_idx = idx / alphabet_size; probs[idx] = f(probs[idx] - col_max[column_idx]); } idx += 
stride; } -} \ No newline at end of file +} diff --git a/flashlight/pkg/speech/third_party/warpctc/include/detail/hostdevice.h b/flashlight/pkg/speech/third_party/warpctc/include/detail/hostdevice.h index 7bec1e0..ed27cf1 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/detail/hostdevice.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/detail/hostdevice.h @@ -1,7 +1,7 @@ #pragma once #ifdef __CUDACC__ - #define HOSTDEVICE __host__ __device__ +#define HOSTDEVICE __host__ __device__ #else - #define HOSTDEVICE +#define HOSTDEVICE #endif diff --git a/flashlight/pkg/speech/third_party/warpctc/include/detail/reduce.h b/flashlight/pkg/speech/third_party/warpctc/include/detail/reduce.h index cdcad76..793c193 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/detail/reduce.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/detail/reduce.h @@ -1,5 +1,12 @@ #pragma once -ctcStatus_t reduce_negate(const float* input, float* output, int rows, int cols, bool axis, cudaStream_t stream); +ctcStatus_t reduce_negate( + const float* input, + float* output, + int rows, + int cols, + bool axis, + cudaStream_t stream +); ctcStatus_t reduce_exp(const float* input, float* output, int rows, int cols, bool axis, cudaStream_t stream); ctcStatus_t reduce_max(const float* input, float* output, int rows, int cols, bool axis, cudaStream_t stream); diff --git a/flashlight/pkg/text/data/TextDataset.cpp b/flashlight/pkg/text/data/TextDataset.cpp index 2ffbb68..6ae0bd7 100644 --- a/flashlight/pkg/text/data/TextDataset.cpp +++ b/flashlight/pkg/text/data/TextDataset.cpp @@ -30,147 +30,153 @@ TextDataset::TextDataset( int64_t batchSize /* = 1 */, const std::string& sampleBreakMode /* = "none" */, const bool useDynamicBatching /* = false */, - const size_t reserveSpaceSize /* = kMaxTokenInBuffer */) - : pad_(dictionary.getIndex(fl::lib::text::kPadToken)) { - /* 1. Read data */ - // data_ will have the following layout: - // sentence sentence ... 
sentence - data_.clear(); - data_.reserve(reserveSpaceSize); - const auto eos = dictionary.getIndex(fl::lib::text::kEosToken); - data_.push_back(eos); - - // Each pair of indices in sentenceRanges indicates the position in data_ of - // the 2 tokens around a given sentence. - std::vector> sentenceRanges; - auto files = lib::split(',', filenames); - for (const auto& file : files) { - const fs::path path = dataDirectory / file; - reader.loadFile(path); - - while (reader.hasNextLine()) { - const auto currentEosPosition = data_.size() - 1; - if (!sentenceRanges.empty()) { - sentenceRanges.back().second = currentEosPosition; - } - - const auto tokens = tokenizer.tokenize(reader.getLine()); - const auto indices = dictionary.mapEntriesToIndices(tokens); - if (data_.size() + indices.size() > kMaxTokenInBuffer) { - FL_LOG(LogLevel::INFO) - << "[TextDataset] stop loading at 10,000,000,000 tokens"; - break; - } - sentenceRanges.emplace_back(currentEosPosition, -1); - data_.insert(data_.end(), indices.begin(), indices.end()); - data_.push_back(eos); - } - if (!sentenceRanges.empty()) { - sentenceRanges.back().second = data_.size() - 1; - } - } - const int64_t nTokens = data_.size(); - - /* 2. 
Batchify */ - if (batchSize <= 0) { - throw std::invalid_argument( - "[TextDataset] BatchSize needs to be positive."); - } - - if (sampleBreakMode == "none") { - // Sentences are split into equal size (=`tokensPerSample`) - // Total tokens per batch is `batchSize` * `tokensPerSample` - - const int64_t nSamples = (nTokens + tokensPerSample - 1) / tokensPerSample; - const int64_t nBatches = (nSamples + batchSize - 1) / batchSize; - for (int64_t b = 0; b < nBatches; ++b) { - const int64_t firstSample = b * batchSize; - const int64_t lastSample = std::min((b + 1) * batchSize, nSamples); - std::vector batch; - for (int64_t s = firstSample; s < lastSample; ++s) { - const int64_t firstToken = s * tokensPerSample; - const int64_t lastToken = std::min((s + 1) * tokensPerSample, nTokens); - batch.emplace_back(SamplePosition{firstToken, lastToken - 1}); - } - batches_.push_back(std::move(batch)); - } - } else if (sampleBreakMode == "eos") { - // Each sentence must begin and end in . - // Sentences with length > `tokensPerSample` are skipped; - // Total tokens per batch <= `batchSize` * `tokensPerSample` - - if (useDynamicBatching) { - // sorting samples by length in ascending order - std::sort( - sentenceRanges.begin(), - sentenceRanges.end(), - [](const std::pair& p1, - const std::pair& p2) { - return p1.second - p1.first < p2.second - p2.first; - }); + const size_t reserveSpaceSize /* = kMaxTokenInBuffer */ +) : pad_(dictionary.getIndex(fl::lib::text::kPadToken)) { + /* 1. Read data */ + // data_ will have the following layout: + // sentence sentence ... sentence + data_.clear(); + data_.reserve(reserveSpaceSize); + const auto eos = dictionary.getIndex(fl::lib::text::kEosToken); + data_.push_back(eos); + + // Each pair of indices in sentenceRanges indicates the position in data_ of + // the 2 tokens around a given sentence. 
+ std::vector> sentenceRanges; + auto files = lib::split(',', filenames); + for(const auto& file : files) { + const fs::path path = dataDirectory / file; + reader.loadFile(path); + + while(reader.hasNextLine()) { + const auto currentEosPosition = data_.size() - 1; + if(!sentenceRanges.empty()) { + sentenceRanges.back().second = currentEosPosition; + } + + const auto tokens = tokenizer.tokenize(reader.getLine()); + const auto indices = dictionary.mapEntriesToIndices(tokens); + if(data_.size() + indices.size() > kMaxTokenInBuffer) { + FL_LOG(LogLevel::INFO) + << "[TextDataset] stop loading at 10,000,000,000 tokens"; + break; + } + sentenceRanges.emplace_back(currentEosPosition, -1); + data_.insert(data_.end(), indices.begin(), indices.end()); + data_.push_back(eos); + } + if(!sentenceRanges.empty()) { + sentenceRanges.back().second = data_.size() - 1; + } } + const int64_t nTokens = data_.size(); - std::vector batch; - for (int64_t i = 0; i < sentenceRanges.size(); ++i) { - const auto startPoint = sentenceRanges[i].first; - const auto endPoint = sentenceRanges[i].second; - const int64_t sampleSize = endPoint - startPoint + 1; - batch.emplace_back(SamplePosition{startPoint, endPoint}); - - bool isFull; - if (useDynamicBatching) { - isFull = sampleSize * (batch.size() + 1) > batchSize * tokensPerSample; - } else { - isFull = batch.size() == batchSize; - } - if (isFull) { - batches_.push_back(std::move(batch)); - batch = std::vector(); - } + /* 2. Batchify */ + if(batchSize <= 0) { + throw std::invalid_argument( + "[TextDataset] BatchSize needs to be positive." 
+ ); } - if (!batch.empty()) { - batches_.push_back(std::move(batch)); + + if(sampleBreakMode == "none") { + // Sentences are split into equal size (=`tokensPerSample`) + // Total tokens per batch is `batchSize` * `tokensPerSample` + + const int64_t nSamples = (nTokens + tokensPerSample - 1) / tokensPerSample; + const int64_t nBatches = (nSamples + batchSize - 1) / batchSize; + for(int64_t b = 0; b < nBatches; ++b) { + const int64_t firstSample = b * batchSize; + const int64_t lastSample = std::min((b + 1) * batchSize, nSamples); + std::vector batch; + for(int64_t s = firstSample; s < lastSample; ++s) { + const int64_t firstToken = s * tokensPerSample; + const int64_t lastToken = std::min((s + 1) * tokensPerSample, nTokens); + batch.emplace_back(SamplePosition{firstToken, lastToken - 1}); + } + batches_.push_back(std::move(batch)); + } + } else if(sampleBreakMode == "eos") { + // Each sentence must begin and end in . + // Sentences with length > `tokensPerSample` are skipped; + // Total tokens per batch <= `batchSize` * `tokensPerSample` + + if(useDynamicBatching) { + // sorting samples by length in ascending order + std::sort( + sentenceRanges.begin(), + sentenceRanges.end(), + [](const std::pair& p1, + const std::pair& p2) { + return p1.second - p1.first < p2.second - p2.first; + } + ); + } + + std::vector batch; + for(int64_t i = 0; i < sentenceRanges.size(); ++i) { + const auto startPoint = sentenceRanges[i].first; + const auto endPoint = sentenceRanges[i].second; + const int64_t sampleSize = endPoint - startPoint + 1; + batch.emplace_back(SamplePosition{startPoint, endPoint}); + + bool isFull; + if(useDynamicBatching) { + isFull = sampleSize * (batch.size() + 1) > batchSize * tokensPerSample; + } else { + isFull = batch.size() == batchSize; + } + if(isFull) { + batches_.push_back(std::move(batch)); + batch = std::vector(); + } + } + if(!batch.empty()) { + batches_.push_back(std::move(batch)); + } + } else { + throw std::invalid_argument( + "Invalid 
sampleBreakMode: should be none or eos, but it is given " + + sampleBreakMode + ); } - } else { - throw std::invalid_argument( - "Invalid sampleBreakMode: should be none or eos, but it is given " + - sampleBreakMode); - } - - FL_LOG(LogLevel::INFO) << "[TextDataset] (" << reader.getRank() << "/" - << reader.getTotalReaders() << ") Loaded " << nTokens - << " tokens, " << sentenceRanges.size() - << " sentences and " << size() << " batches"; + + FL_LOG(LogLevel::INFO) << "[TextDataset] (" << reader.getRank() << "/" + << reader.getTotalReaders() << ") Loaded " << nTokens + << " tokens, " << sentenceRanges.size() + << " sentences and " << size() << " batches"; } int64_t TextDataset::size() const { - return batches_.size(); + return batches_.size(); } std::vector TextDataset::get(const int64_t idx) const { - const auto& batch = batches_[idx % size()]; - int64_t maxLength = 0; - for (const auto& pos : batch) { - maxLength = std::max(maxLength, pos.last - pos.first + 1); - } - std::vector buffer(batch.size() * maxLength, pad_); - for (int64_t i = 0; i < batch.size(); ++i) { - const auto& pos = batch[i]; - std::memcpy( - buffer.data() + i * maxLength, - data_.data() + pos.first, - sizeof(int) * (pos.last - pos.first + 1)); - } - return {Tensor::fromVector( - {maxLength, static_cast(batch.size())}, buffer)}; + const auto& batch = batches_[idx % size()]; + int64_t maxLength = 0; + for(const auto& pos : batch) { + maxLength = std::max(maxLength, pos.last - pos.first + 1); + } + std::vector buffer(batch.size() * maxLength, pad_); + for(int64_t i = 0; i < batch.size(); ++i) { + const auto& pos = batch[i]; + std::memcpy( + buffer.data() + i * maxLength, + data_.data() + pos.first, + sizeof(int) * (pos.last - pos.first + 1) + ); + } + return {Tensor::fromVector( + {maxLength, static_cast(batch.size())}, + buffer + )}; } void TextDataset::shuffle(uint64_t seed) { - std::mt19937_64 rng(seed); - // Deterministic method across compilers. 
- for (uint64_t i = size() - 1; i >= 1; --i) { - std::swap(batches_[i], batches_[rng() % (i + 1)]); - } + std::mt19937_64 rng(seed); + // Deterministic method across compilers. + for(uint64_t i = size() - 1; i >= 1; --i) { + std::swap(batches_[i], batches_[rng() % (i + 1)]); + } } } // namespace fl diff --git a/flashlight/pkg/text/data/TextDataset.h b/flashlight/pkg/text/data/TextDataset.h index 1e8a585..a357d01 100644 --- a/flashlight/pkg/text/data/TextDataset.h +++ b/flashlight/pkg/text/data/TextDataset.h @@ -18,16 +18,16 @@ namespace fl { namespace pkg { -namespace text { + namespace text { -namespace { + namespace { // Maximum number of tokens to keep in memory for each `TextDataset` instance. // Setting the default value to 10,000,000,000 which requires 40GB in memory, // since indices are stored as int32. -constexpr size_t kMaxTokenInBuffer = 10000000000; + constexpr size_t kMaxTokenInBuffer = 10000000000; -} // namespace + } // namespace /** * TextDataset prepares text data for LM training. It returns a single tensor of @@ -58,38 +58,39 @@ constexpr size_t kMaxTokenInBuffer = 10000000000; * batch, samples are sorted by length. 
*/ -class TextDataset : public fl::Dataset { - public: - TextDataset( - const fs::path& dataDirectory, - const std::string& filenames, - fl::lib::text::PartialFileReader& reader, - const fl::lib::text::Tokenizer& tokenizer, - const fl::lib::text::Dictionary& dictionary, - int64_t tokensPerSample = 1024, - int64_t batchSize = 1, - const std::string& sampleBreakMode = "none", - const bool useDynamicBatching = false, - const size_t reserveSpaceSize = kMaxTokenInBuffer); + class TextDataset : public fl::Dataset { + public: + TextDataset( + const fs::path& dataDirectory, + const std::string& filenames, + fl::lib::text::PartialFileReader& reader, + const fl::lib::text::Tokenizer& tokenizer, + const fl::lib::text::Dictionary& dictionary, + int64_t tokensPerSample = 1024, + int64_t batchSize = 1, + const std::string& sampleBreakMode = "none", + const bool useDynamicBatching = false, + const size_t reserveSpaceSize = kMaxTokenInBuffer + ); - int64_t size() const override; + int64_t size() const override; - std::vector get(const int64_t idx) const override; + std::vector get(const int64_t idx) const override; - void shuffle(uint64_t seed); + void shuffle(uint64_t seed); - private: - int pad_; + private: + int pad_; - struct SamplePosition { - int64_t first; - int64_t last; - }; + struct SamplePosition { + int64_t first; + int64_t last; + }; - std::vector data_; // eos prepended, so all indices shifted by 1 - std::vector> batches_; -}; + std::vector data_; // eos prepended, so all indices shifted by 1 + std::vector> batches_; + }; -} // namespace text + } // namespace text } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/text/test/data/TextDatasetTest.cpp b/flashlight/pkg/text/test/data/TextDatasetTest.cpp index 191f4e3..ec37c03 100644 --- a/flashlight/pkg/text/test/data/TextDatasetTest.cpp +++ b/flashlight/pkg/text/test/data/TextDatasetTest.cpp @@ -27,126 +27,126 @@ using namespace fl::pkg::text; fs::path dataDir = ""; Dictionary createDictionary(const 
std::string& path) { - Dictionary dictionary; - std::ifstream stream(path); - if (!stream) { - throw std::runtime_error("createDictionary - invalid path"); - } - - std::string line; - while (std::getline(stream, line)) { - if (line.empty()) { - continue; + Dictionary dictionary; + std::ifstream stream(path); + if(!stream) { + throw std::runtime_error("createDictionary - invalid path"); } - auto tkns = splitOnWhitespace(line, true); - dictionary.addEntry(tkns.front()); - } - if (!dictionary.isContiguous()) { - throw std::runtime_error("Invalid dictionary_ format - not contiguous"); - } - dictionary.setDefaultIndex(dictionary.getIndex(fl::lib::text::kUnkToken)); - return dictionary; + + std::string line; + while(std::getline(stream, line)) { + if(line.empty()) { + continue; + } + auto tkns = splitOnWhitespace(line, true); + dictionary.addEntry(tkns.front()); + } + if(!dictionary.isContiguous()) { + throw std::runtime_error("Invalid dictionary_ format - not contiguous"); + } + dictionary.setDefaultIndex(dictionary.getIndex(fl::lib::text::kUnkToken)); + return dictionary; } TEST(TextDatasetTest, NoneMode) { - fl::lib::text::Tokenizer tokenizer; - fl::lib::text::PartialFileReader partialFileReader(0, 1); - Dictionary dictionary = createDictionary(dataDir / "dictionary.txt"); - - int tokensPerSample = 5; - int batchSize = 2; - - TextDataset dataset( - dataDir, - "train.txt", - partialFileReader, - tokenizer, - dictionary, - tokensPerSample, - batchSize, - "none", - /* useDynamicBatching = */ false, - /* reserveSpaceSize = */ 0); - - ASSERT_EQ(dataset.size(), 4); - for (int i = 0; i < dataset.size(); i++) { - auto sample = dataset.get(i); - ASSERT_EQ(sample.size(), 1); - ASSERT_EQ(sample[0].dim(0), tokensPerSample); - ASSERT_EQ(sample[0].dim(1), batchSize); - } + fl::lib::text::Tokenizer tokenizer; + fl::lib::text::PartialFileReader partialFileReader(0, 1); + Dictionary dictionary = createDictionary(dataDir / "dictionary.txt"); + + int tokensPerSample = 5; + int batchSize 
= 2; + + TextDataset dataset( + dataDir, + "train.txt", + partialFileReader, + tokenizer, + dictionary, + tokensPerSample, + batchSize, + "none", + /* useDynamicBatching = */ false, + /* reserveSpaceSize = */ 0); + + ASSERT_EQ(dataset.size(), 4); + for(int i = 0; i < dataset.size(); i++) { + auto sample = dataset.get(i); + ASSERT_EQ(sample.size(), 1); + ASSERT_EQ(sample[0].dim(0), tokensPerSample); + ASSERT_EQ(sample[0].dim(1), batchSize); + } } TEST(TextDatasetTest, EosMode) { - fl::lib::text::Tokenizer tokenizer; - fl::lib::text::PartialFileReader partialFileReader(0, 1); - Dictionary dictionary = createDictionary(dataDir / "dictionary.txt"); - - int tokensPerSample = 5; - int batchSize = 2; - - TextDataset dataset( - dataDir, - "train.txt", - partialFileReader, - tokenizer, - dictionary, - tokensPerSample, - batchSize, - "eos", - /* useDynamicBatching = */ false, - /* reserveSpaceSize = */ 0); - - ASSERT_EQ(dataset.size(), 4); - - std::vector targetLen = {7, 5, 5, 7}; - for (int i = 0; i < dataset.size(); i++) { - auto sample = dataset.get(i); - ASSERT_EQ(sample.size(), 1); - ASSERT_EQ(sample[0].dim(0), targetLen[i]); - ASSERT_EQ(sample[0].dim(1), batchSize); - } + fl::lib::text::Tokenizer tokenizer; + fl::lib::text::PartialFileReader partialFileReader(0, 1); + Dictionary dictionary = createDictionary(dataDir / "dictionary.txt"); + + int tokensPerSample = 5; + int batchSize = 2; + + TextDataset dataset( + dataDir, + "train.txt", + partialFileReader, + tokenizer, + dictionary, + tokensPerSample, + batchSize, + "eos", + /* useDynamicBatching = */ false, + /* reserveSpaceSize = */ 0); + + ASSERT_EQ(dataset.size(), 4); + + std::vector targetLen = {7, 5, 5, 7}; + for(int i = 0; i < dataset.size(); i++) { + auto sample = dataset.get(i); + ASSERT_EQ(sample.size(), 1); + ASSERT_EQ(sample[0].dim(0), targetLen[i]); + ASSERT_EQ(sample[0].dim(1), batchSize); + } } TEST(TextDatasetTest, EosModeWithDynamicBatching) { - fl::lib::text::Tokenizer tokenizer; - 
fl::lib::text::PartialFileReader partialFileReader( - fl::getWorldRank(), fl::getWorldSize()); - Dictionary dictionary = createDictionary(dataDir / "dictionary.txt"); - - int tokensPerSample = 15; - - TextDataset dataset( - dataDir, - "train.txt", - partialFileReader, - tokenizer, - dictionary, - tokensPerSample, - 1, - "eos", - /* useDynamicBatching = */ true, - /* reserveSpaceSize = */ 0); - - ASSERT_EQ(dataset.size(), 4); - - std::vector targetLen = {5, 6, 7, 7}; - std::vector targetBsz = {3, 2, 2, 1}; - for (int i = 0; i < dataset.size(); i++) { - auto sample = dataset.get(i); - ASSERT_EQ(sample.size(), 1); - ASSERT_EQ(sample[0].dim(0), targetLen[i]); - ASSERT_EQ(sample[0].dim(1), targetBsz[i]); - } + fl::lib::text::Tokenizer tokenizer; + fl::lib::text::PartialFileReader partialFileReader( + fl::getWorldRank(), fl::getWorldSize()); + Dictionary dictionary = createDictionary(dataDir / "dictionary.txt"); + + int tokensPerSample = 15; + + TextDataset dataset( + dataDir, + "train.txt", + partialFileReader, + tokenizer, + dictionary, + tokensPerSample, + 1, + "eos", + /* useDynamicBatching = */ true, + /* reserveSpaceSize = */ 0); + + ASSERT_EQ(dataset.size(), 4); + + std::vector targetLen = {5, 6, 7, 7}; + std::vector targetBsz = {3, 2, 2, 1}; + for(int i = 0; i < dataset.size(); i++) { + auto sample = dataset.get(i); + ASSERT_EQ(sample.size(), 1); + ASSERT_EQ(sample[0].dim(0), targetLen[i]); + ASSERT_EQ(sample[0].dim(1), targetBsz[i]); + } } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); #ifdef TEXTDATASET_TEST_DATADIR - dataDir = fs::path(TEXTDATASET_TEST_DATADIR); + dataDir = fs::path(TEXTDATASET_TEST_DATADIR); #endif - return RUN_ALL_TESTS(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/vision/common/BetaDistribution.h b/flashlight/pkg/vision/common/BetaDistribution.h index 82282bc..c0775a6 100644 --- a/flashlight/pkg/vision/common/BetaDistribution.h 
+++ b/flashlight/pkg/vision/common/BetaDistribution.h @@ -1,4 +1,3 @@ - /* * Copyright (c) Meta Platforms, Inc. and affiliates. * @@ -28,123 +27,126 @@ namespace lib { * This implementation utilized `std::gamma_distribution`, given that * X / (X + Y) ~ Beta(a, b), when X ~ Gamma(a, c), Y ~ Gamma(b, c). */ -template -class beta_distribution { - public: - using result_type = RealType; - - class param_type { - public: - using distribution_type = beta_distribution; - - explicit param_type(RealType a = 1.0, RealType b = 1.0) - : a_param(a), b_param(b) {} - - RealType a() const { - return a_param; - } - RealType b() const { - return b_param; - } - - bool operator==(const param_type& other) const { - return (a_param == other.a_param && b_param == other.b_param); + template + class beta_distribution { + public: + using result_type = RealType; + + class param_type { + public: + using distribution_type = beta_distribution; + + explicit param_type(RealType a = 1.0, RealType b = 1.0) : a_param(a), + b_param(b) {} + + RealType a() const { + return a_param; + } + RealType b() const { + return b_param; + } + + bool operator==(const param_type& other) const { + return a_param == other.a_param && b_param == other.b_param; + } + + bool operator!=(const param_type& other) const { + return !(*this == other); + } + + private: + RealType a_param, b_param; + }; + + explicit beta_distribution(RealType a = 1.0, RealType b = 1.0) : a_gamma(a), + b_gamma(b) {} + explicit beta_distribution(const param_type& param) : a_gamma(param.a()), + b_gamma(param.b()) {} + + void reset() {} + + param_type param() const { + return param_type(a(), b()); + } + + void param(const param_type& param) { + a_gamma = gamma_dist_type(param.a()); + b_gamma = gamma_dist_type(param.b()); + } + + template + result_type operator()(URNG& engine) { + return generate(engine, a_gamma, b_gamma); + } + + template + result_type operator()(URNG& engine, const param_type& param) { + gamma_dist_type a_param_gamma(param.a()), 
b_param_gamma(param.b()); + return generate(engine, a_param_gamma, b_param_gamma); + } + + result_type min() const { + return 0.0; + } + result_type max() const { + return 1.0; + } + + RealType a() const { + return a_gamma.alpha(); + } + RealType b() const { + return b_gamma.alpha(); + } + + bool operator==(const beta_distribution& other) const { + return + param() == other.param() && a_gamma == other.a_gamma + && b_gamma == other.b_gamma; + } + + bool operator!=(const beta_distribution& other) const { + return !(*this == other); + } + + private: + using gamma_dist_type = std::gamma_distribution; + + gamma_dist_type a_gamma, b_gamma; + + template + result_type generate(URNG& engine, gamma_dist_type& x_gamma, gamma_dist_type& y_gamma) { + result_type x = x_gamma(engine); + return x / (x + y_gamma(engine)); + } + }; + + template + std::basic_ostream& operator<<( + std::basic_ostream& os, + const beta_distribution& beta + ) { + os << "~Beta(" << beta.a() << "," << beta.b() << ")"; + return os; } - bool operator!=(const param_type& other) const { - return !(*this == other); + template + std::basic_istream& operator>>( + std::basic_istream& is, + beta_distribution& beta + ) { + std::string str; + RealType a, b; + if( + std::getline(is, str, '(') && str == "~Beta" && is >> a + && is.get() == ',' && is >> b && is.get() == ')' + ) { + beta = beta_distribution(a, b); + } else { + is.setstate(std::ios::failbit); + } + return is; } - private: - RealType a_param, b_param; - }; - - explicit beta_distribution(RealType a = 1.0, RealType b = 1.0) - : a_gamma(a), b_gamma(b) {} - explicit beta_distribution(const param_type& param) - : a_gamma(param.a()), b_gamma(param.b()) {} - - void reset() {} - - param_type param() const { - return param_type(a(), b()); - } - - void param(const param_type& param) { - a_gamma = gamma_dist_type(param.a()); - b_gamma = gamma_dist_type(param.b()); - } - - template - result_type operator()(URNG& engine) { - return generate(engine, a_gamma, b_gamma); - 
} - - template - result_type operator()(URNG& engine, const param_type& param) { - gamma_dist_type a_param_gamma(param.a()), b_param_gamma(param.b()); - return generate(engine, a_param_gamma, b_param_gamma); - } - - result_type min() const { - return 0.0; - } - result_type max() const { - return 1.0; - } - - RealType a() const { - return a_gamma.alpha(); - } - RealType b() const { - return b_gamma.alpha(); - } - - bool operator==(const beta_distribution& other) const { - return ( - param() == other.param() && a_gamma == other.a_gamma && - b_gamma == other.b_gamma); - } - - bool operator!=(const beta_distribution& other) const { - return !(*this == other); - } - - private: - using gamma_dist_type = std::gamma_distribution; - - gamma_dist_type a_gamma, b_gamma; - - template - result_type - generate(URNG& engine, gamma_dist_type& x_gamma, gamma_dist_type& y_gamma) { - result_type x = x_gamma(engine); - return x / (x + y_gamma(engine)); - } -}; - -template -std::basic_ostream& operator<<( - std::basic_ostream& os, - const beta_distribution& beta) { - os << "~Beta(" << beta.a() << "," << beta.b() << ")"; - return os; -} - -template -std::basic_istream& operator>>( - std::basic_istream& is, - beta_distribution& beta) { - std::string str; - RealType a, b; - if (std::getline(is, str, '(') && str == "~Beta" && is >> a && - is.get() == ',' && is >> b && is.get() == ')') { - beta = beta_distribution(a, b); - } else { - is.setstate(std::ios::failbit); - } - return is; -} - } // namespace lib } // namespace fl diff --git a/flashlight/pkg/vision/criterion/Hungarian.cpp b/flashlight/pkg/vision/criterion/Hungarian.cpp index 861648a..83faee3 100644 --- a/flashlight/pkg/vision/criterion/Hungarian.cpp +++ b/flashlight/pkg/vision/criterion/Hungarian.cpp @@ -15,29 +15,34 @@ namespace { using namespace fl; Tensor softmax(const Tensor& input, const int dim) { - auto maxvals = fl::amax(input, {dim}, /* keepDims = */ true); - Shape tiledims(std::vector(input.ndim(), 1)); - tiledims[dim] = 
input.dim(dim); + auto maxvals = fl::amax(input, {dim}, /* keepDims = */ true); + Shape tiledims(std::vector(input.ndim(), 1)); + tiledims[dim] = input.dim(dim); - auto expInput = fl::exp(input - fl::tile(maxvals, tiledims)); - auto result = expInput / - fl::tile(fl::sum(expInput, {dim}, /* keepDims = */ true), tiledims); - return result; + auto expInput = fl::exp(input - fl::tile(maxvals, tiledims)); + auto result = expInput + / fl::tile(fl::sum(expInput, {dim}, /* keepDims = */ true), tiledims); + return result; } std::pair hungarian(Tensor& cost) { - cost = fl::transpose(cost, {1, 0, 2, 3}); - const int M = cost.dim(0); - const int N = cost.dim(1); - std::vector costHost(cost.elements()); - std::vector rowIdxs(M); - std::vector colIdxs(M); - cost.host(costHost.data()); - fl::lib::set::hungarian( - costHost.data(), rowIdxs.data(), colIdxs.data(), M, N); - auto rowIdxsArray = Tensor::fromVector(rowIdxs); - auto colIdxsArray = Tensor::fromVector(colIdxs); - return {rowIdxsArray, colIdxsArray}; + cost = fl::transpose(cost, {1, 0, 2, 3}); + const int M = cost.dim(0); + const int N = cost.dim(1); + std::vector costHost(cost.elements()); + std::vector rowIdxs(M); + std::vector colIdxs(M); + cost.host(costHost.data()); + fl::lib::set::hungarian( + costHost.data(), + rowIdxs.data(), + colIdxs.data(), + M, + N + ); + auto rowIdxsArray = Tensor::fromVector(rowIdxs); + auto colIdxsArray = Tensor::fromVector(colIdxs); + return {rowIdxsArray, colIdxsArray}; } } // namespace @@ -46,56 +51,65 @@ namespace fl::pkg::vision { HungarianMatcher::HungarianMatcher( const float costClass, const float costBbox, - const float costGiou) - : costClass_(costClass), costBbox_(costBbox), costGiou_(costGiou){}; + const float costGiou +) : costClass_(costClass), + costBbox_(costBbox), + costGiou_(costGiou) {}; std::pair HungarianMatcher::matchBatch( const Tensor& predBoxes, const Tensor& predLogits, const Tensor& targetBoxes, - const Tensor& targetClasses) const { - // Kind of a hack... 
- if (targetClasses.isEmpty()) { - return {fl::fromScalar(0), fl::fromScalar(0)}; - } + const Tensor& targetClasses +) const { + // Kind of a hack... + if(targetClasses.isEmpty()) { + return {fl::fromScalar(0), fl::fromScalar(0)}; + } - // Create an M X N cost matrix where M is the number of targets and N is the - // number of preds - // Class cost - auto outProbs = ::softmax(predLogits, 0); - auto costClass = transpose((0 - outProbs(targetClasses)), {1, 0, 2}); + // Create an M X N cost matrix where M is the number of targets and N is the + // number of preds + // Class cost + auto outProbs = ::softmax(predLogits, 0); + auto costClass = transpose((0 - outProbs(targetClasses)), {1, 0, 2}); - // Generalized IOU loss - Tensor costGiou = - 0 - generalizedBoxIou(cxcywh2xyxy(predBoxes), cxcywh2xyxy(targetBoxes)); + // Generalized IOU loss + Tensor costGiou = + 0 - generalizedBoxIou(cxcywh2xyxy(predBoxes), cxcywh2xyxy(targetBoxes)); - // Bbox Cost - Tensor costBbox = - cartesian(predBoxes, targetBoxes, [](const Tensor& x, const Tensor& y) { - return fl::sum(fl::abs(x - y), {0}, /* keepDims = */ true); - }); - costBbox = flatten(costBbox, 0, 1); + // Bbox Cost + Tensor costBbox = + cartesian( + predBoxes, + targetBoxes, + [](const Tensor& x, const Tensor& y) { + return fl::sum(fl::abs(x - y), {0}, /* keepDims = */ true); + } + ); + costBbox = flatten(costBbox, 0, 1); - auto cost = - costBbox_ * costBbox + costClass_ * costClass + costGiou_ * costGiou; - return ::hungarian(cost); + auto cost = + costBbox_ * costBbox + costClass_ * costClass + costGiou_ * costGiou; + return ::hungarian(cost); } std::vector> HungarianMatcher::compute( const Tensor& predBoxes, const Tensor& predLogits, const std::vector& targetBoxes, - const std::vector& targetClasses) const { - std::vector> results; - for (int b = 0; b < predBoxes.dim(2); b++) { - auto result = matchBatch( - predBoxes(fl::span, fl::span, fl::range(b, b + 1)), - predLogits(fl::span, fl::span, fl::range(b, b + 1)), - 
targetBoxes[b], - targetClasses[b]); - results.emplace_back(result); - } - return results; + const std::vector& targetClasses +) const { + std::vector> results; + for(int b = 0; b < predBoxes.dim(2); b++) { + auto result = matchBatch( + predBoxes(fl::span, fl::span, fl::range(b, b + 1)), + predLogits(fl::span, fl::span, fl::range(b, b + 1)), + targetBoxes[b], + targetClasses[b] + ); + results.emplace_back(result); + } + return results; }; } // namespace fl diff --git a/flashlight/pkg/vision/criterion/Hungarian.h b/flashlight/pkg/vision/criterion/Hungarian.h index 8f34325..96795f3 100644 --- a/flashlight/pkg/vision/criterion/Hungarian.h +++ b/flashlight/pkg/vision/criterion/Hungarian.h @@ -10,38 +10,41 @@ namespace fl { namespace pkg { -namespace vision { - -class HungarianMatcher { - public: - HungarianMatcher() = default; - - HungarianMatcher( - const float costClass, - const float costBbox, - const float costGiou); - - std::vector> compute( - const Tensor& predBoxes, - const Tensor& predLogits, - const std::vector& targetBoxes, - const std::vector& targetClasses) const; - - private: - float costClass_; - float costBbox_; - float costGiou_; - - // First is SrcIdx, second is ColIdx - std::pair matchBatch( - const Tensor& predBoxes, - const Tensor& predLogits, - const Tensor& targetBoxes, - const Tensor& targetClasses) const; - - Tensor getCostMatrix(const Tensor& input, const Tensor& target); -}; - -} // namespace vision + namespace vision { + + class HungarianMatcher { + public: + HungarianMatcher() = default; + + HungarianMatcher( + const float costClass, + const float costBbox, + const float costGiou + ); + + std::vector> compute( + const Tensor& predBoxes, + const Tensor& predLogits, + const std::vector& targetBoxes, + const std::vector& targetClasses + ) const; + + private: + float costClass_; + float costBbox_; + float costGiou_; + + // First is SrcIdx, second is ColIdx + std::pair matchBatch( + const Tensor& predBoxes, + const Tensor& predLogits, + const 
Tensor& targetBoxes, + const Tensor& targetClasses + ) const; + + Tensor getCostMatrix(const Tensor& input, const Tensor& target); + }; + + } // namespace vision } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/vision/criterion/HungarianImpl.cpp b/flashlight/pkg/vision/criterion/HungarianImpl.cpp index c537f14..f2365d6 100644 --- a/flashlight/pkg/vision/criterion/HungarianImpl.cpp +++ b/flashlight/pkg/vision/criterion/HungarianImpl.cpp @@ -13,7 +13,7 @@ namespace { -enum Mark : int { None = 0, Star = 1, Prime = 2 }; +enum Mark : int {None = 0, Star = 1, Prime = 2}; void findUncoveredZero( float* costs, @@ -22,56 +22,59 @@ void findUncoveredZero( int nrows, int ncols, int* row, - int* col) { - bool done = false; - *row = -1; - *col = -1; - for (int c = 0; c < ncols && !done; c++) { - for (int r = 0; r < nrows && !done; r++) { - const float cost = costs[c * nrows + r]; - if (cost == 0 && colCover[c] == 0 && rowCover[r] == 0) { - *row = r; - *col = c; - done = true; - } + int* col +) { + bool done = false; + *row = -1; + *col = -1; + for(int c = 0; c < ncols && !done; c++) { + for(int r = 0; r < nrows && !done; r++) { + const float cost = costs[c * nrows + r]; + if(cost == 0 && colCover[c] == 0 && rowCover[r] == 0) { + *row = r; + *col = c; + done = true; + } + } } - } } bool isStarInRow(int* marks, int row, int nrows, int ncols) { - for (int c = 0; c < ncols; c++) { - if (marks[c * nrows + row] == Mark::Star) { - return true; - }; - } - return false; + for(int c = 0; c < ncols; c++) { + if(marks[c * nrows + row] == Mark::Star) { + return true; + } + ; + } + return false; } int findStarInRow(int* marks, int row, int nrows, int ncols) { - for (int c = 0; c < ncols; c++) { - if (marks[c * nrows + row] == Mark::Star) { - return c; - }; - } - return -1; + for(int c = 0; c < ncols; c++) { + if(marks[c * nrows + row] == Mark::Star) { + return c; + } + ; + } + return -1; } // M x N matrix M = nrows, N = ncols // For each row, substract it's minimum value int 
stepOne(float* costs, const int nrows, const int ncols) { - for (int r = 0; r < nrows; r++) { - float min_val = std::numeric_limits::max(); - for (int c = 0; c < ncols; c++) { - float val = costs[c * nrows + r]; - if (val < min_val) { - min_val = val; - } - } - for (int c = 0; c < ncols; c++) { - costs[c * nrows + r] -= min_val; + for(int r = 0; r < nrows; r++) { + float min_val = std::numeric_limits::max(); + for(int c = 0; c < ncols; c++) { + float val = costs[c * nrows + r]; + if(val < min_val) { + min_val = val; + } + } + for(int c = 0; c < ncols; c++) { + costs[c * nrows + r] -= min_val; + } } - } - return 2; + return 2; } // Iterate through rows, and mark 0s with '1' (Star) if they have not already @@ -82,24 +85,25 @@ int stepTwo( int* colCover, int* rowCover, const int nrows, - const int ncols) { - for (int r = 0; r < nrows; r++) { - for (int c = 0; c < ncols; c++) { - float cost = costs[c * nrows + r]; - if (cost == 0.0 && rowCover[r] == 0 && colCover[c] == 0) { - marks[c * nrows + r] = Mark::Star; - rowCover[r] = 1; - colCover[c] = 1; - } + const int ncols +) { + for(int r = 0; r < nrows; r++) { + for(int c = 0; c < ncols; c++) { + float cost = costs[c * nrows + r]; + if(cost == 0.0 && rowCover[r] == 0 && colCover[c] == 0) { + marks[c * nrows + r] = Mark::Star; + rowCover[r] = 1; + colCover[c] = 1; + } + } } - } - for (int r = 0; r < nrows; r++) { - rowCover[r] = 0; - } - for (int c = 0; c < ncols; c++) { - colCover[c] = 0; - } - return 3; + for(int r = 0; r < nrows; r++) { + rowCover[r] = 0; + } + for(int c = 0; c < ncols; c++) { + colCover[c] = 0; + } + return 3; } // Count the number of lines needed to cover all "Stars" @@ -108,24 +112,25 @@ int stepThree( int* colCover, int* /*rowCover*/, int nrows, - int ncols) { - for (int r = 0; r < nrows; r++) { - for (int c = 0; c < ncols; c++) { - const int mark = marks[c * nrows + r]; - if (mark == 1) { - colCover[c] = 1; - } + int ncols +) { + for(int r = 0; r < nrows; r++) { + for(int c = 0; c < ncols; c++) { 
+ const int mark = marks[c * nrows + r]; + if(mark == 1) { + colCover[c] = 1; + } + } + } + int coveredCols = 0; + for(int c = 0; c < ncols; c++) { + coveredCols += colCover[c]; + } + if(coveredCols == ncols || coveredCols >= nrows) { + return 7; + } else { + return 4; } - } - int coveredCols = 0; - for (int c = 0; c < ncols; c++) { - coveredCols += colCover[c]; - } - if (coveredCols == ncols || coveredCols >= nrows) { - return 7; - } else { - return 4; - } } // Find a noncovered zero and "prime it". If there are no uncovered zeros in the @@ -140,47 +145,48 @@ int stepFour( int nrows, int ncols, int* firstPathRow, - int* firstPathCol) { - bool done = false; - while (!done) { - int row, col; - findUncoveredZero(costs, colCover, rowCover, nrows, ncols, &row, &col); - if (row < 0 && col < 0) { - return 6; - } else { - // "Prime it" - marks[col * nrows + row] = Mark::Prime; - if (isStarInRow(marks, row, nrows, ncols)) { - int c = findStarInRow(marks, row, nrows, ncols); - rowCover[row] = 1; - colCover[c] = 0; - } else { - *firstPathRow = row; - *firstPathCol = col; - done = true; - return 5; - } + int* firstPathCol +) { + bool done = false; + while(!done) { + int row, col; + findUncoveredZero(costs, colCover, rowCover, nrows, ncols, &row, &col); + if(row < 0 && col < 0) { + return 6; + } else { + // "Prime it" + marks[col * nrows + row] = Mark::Prime; + if(isStarInRow(marks, row, nrows, ncols)) { + int c = findStarInRow(marks, row, nrows, ncols); + rowCover[row] = 1; + colCover[c] = 0; + } else { + *firstPathRow = row; + *firstPathCol = col; + done = true; + return 5; + } + } } - } - return -1; + return -1; } int findStarInCol(int* masks, int col, int nrows, int /*ncols*/) { - for (int r = 0; r < nrows; r++) { - if (masks[col * nrows + r] == 1) { - return r; + for(int r = 0; r < nrows; r++) { + if(masks[col * nrows + r] == 1) { + return r; + } } - } - return -1; + return -1; } int findPrimeInRow(int* masks, int row, int nrows, int ncols) { - for (int c = 0; c < ncols; 
c++) { - if (masks[c * nrows + row] == 2) { - return c; + for(int c = 0; c < ncols; c++) { + if(masks[c * nrows + row] == 2) { + return c; + } } - } - return -1; + return -1; } void augmentPaths( @@ -188,32 +194,33 @@ void augmentPaths( int pathCount, int* marks, int nrows, - int /*ncols*/) { - for (int p = 0; p < pathCount; p++) { - int row = paths[p * 2]; - int col = paths[p * 2 + 1]; - if (marks[col * nrows + row] == Mark::Star) { - marks[col * nrows + row] = Mark::None; - } else { - marks[col * nrows + row] = Mark::Star; + int /*ncols*/ +) { + for(int p = 0; p < pathCount; p++) { + int row = paths[p * 2]; + int col = paths[p * 2 + 1]; + if(marks[col * nrows + row] == Mark::Star) { + marks[col * nrows + row] = Mark::None; + } else { + marks[col * nrows + row] = Mark::Star; + } } - } } void clearCover(int* cover, int n) { - for (int i = 0; i < n; i++) { - cover[i] = 0; - } + for(int i = 0; i < n; i++) { + cover[i] = 0; + } } void erasePrimes(int* marks, int nrows, int ncols) { - for (int c = 0; c < ncols; c++) { - for (int r = 0; r < nrows; r++) { - if (marks[c * nrows + r] == Mark::Prime) { - marks[c * nrows + r] = Mark::None; - } + for(int c = 0; c < ncols; c++) { + for(int r = 0; r < nrows; r++) { + if(marks[c * nrows + r] == Mark::Prime) { + marks[c * nrows + r] = Mark::None; + } + } } - } } int stepFive( @@ -225,34 +232,35 @@ int stepFive( int firstPathRow, int firstPathCol, int nrows, - int ncols) { - int r = -1; - int c = -1; - int pathCount = 1; - path[(pathCount - 1) * 2] = firstPathRow; - path[(pathCount - 1) * 2 + 1] = firstPathCol; - bool done = false; - while (!done) { - r = findStarInCol(marks, path[(pathCount - 1) * 2 + 1], nrows, ncols); - if (r > -1) { - pathCount += 1; - path[(pathCount - 1) * 2] = r; - path[(pathCount - 1) * 2 + 1] = path[(pathCount - 2) * 2 + 1]; - } else { - done = true; - } - if (!done) { - c = findPrimeInRow(marks, path[(pathCount - 1) * 2], nrows, ncols); - pathCount += 1; - path[(pathCount - 1) * 2] = path[(pathCount - 2) 
* 2]; - path[(pathCount - 1) * 2 + 1] = c; + int ncols +) { + int r = -1; + int c = -1; + int pathCount = 1; + path[(pathCount - 1) * 2] = firstPathRow; + path[(pathCount - 1) * 2 + 1] = firstPathCol; + bool done = false; + while(!done) { + r = findStarInCol(marks, path[(pathCount - 1) * 2 + 1], nrows, ncols); + if(r > -1) { + pathCount += 1; + path[(pathCount - 1) * 2] = r; + path[(pathCount - 1) * 2 + 1] = path[(pathCount - 2) * 2 + 1]; + } else { + done = true; + } + if(!done) { + c = findPrimeInRow(marks, path[(pathCount - 1) * 2], nrows, ncols); + pathCount += 1; + path[(pathCount - 1) * 2] = path[(pathCount - 2) * 2]; + path[(pathCount - 1) * 2 + 1] = c; + } } - } - augmentPaths(path, pathCount, marks, nrows, ncols); - clearCover(colCover, ncols); - clearCover(rowCover, nrows); - erasePrimes(marks, nrows, ncols); - return 3; + augmentPaths(path, pathCount, marks, nrows, ncols); + clearCover(colCover, ncols); + clearCover(rowCover, nrows); + erasePrimes(marks, nrows, ncols); + return 3; } float findSmallestNotCovered( @@ -260,19 +268,20 @@ float findSmallestNotCovered( int* colCover, int* rowCover, int nrows, - int ncols) { - float minValue = std::numeric_limits::max(); - for (int c = 0; c < ncols; c++) { - for (int r = 0; r < nrows; r++) { - if (colCover[c] == 0 && rowCover[r] == 0) { - const float cost = costs[c * nrows + r]; - if (cost < minValue) { - minValue = cost; + int ncols +) { + float minValue = std::numeric_limits::max(); + for(int c = 0; c < ncols; c++) { + for(int r = 0; r < nrows; r++) { + if(colCover[c] == 0 && rowCover[r] == 0) { + const float cost = costs[c * nrows + r]; + if(cost < minValue) { + minValue = cost; + } + } } - } } - } - return minValue; + return minValue; } int stepSix( @@ -281,34 +290,35 @@ int stepSix( int* colCover, int* rowCover, int nrows, - int ncols) { - float minVal = - findSmallestNotCovered(costs, colCover, rowCover, nrows, ncols); - for (int c = 0; c < ncols; c++) { - for (int r = 0; r < nrows; r++) { - if 
(rowCover[r] == 1) { - costs[c * nrows + r] += minVal; - } - if (colCover[c] == 0) { - costs[c * nrows + r] -= minVal; - } + int ncols +) { + float minVal = + findSmallestNotCovered(costs, colCover, rowCover, nrows, ncols); + for(int c = 0; c < ncols; c++) { + for(int r = 0; r < nrows; r++) { + if(rowCover[r] == 1) { + costs[c * nrows + r] += minVal; + } + if(colCover[c] == 0) { + costs[c * nrows + r] -= minVal; + } + } } - } - return 4; + return 4; } void stepSeven(int* marks, int* rowIdxs, int* colIdxs, int M, int N) { - int i = 0; - for (int r = 0; r < M; r++) { - for (int c = 0; c < N; c++) { - const int mark = marks[c * M + r]; - if (mark == Mark::Star) { - rowIdxs[i] = r; - colIdxs[i] = c; - i += 1; - } + int i = 0; + for(int r = 0; r < M; r++) { + for(int c = 0; c < N; c++) { + const int mark = marks[c * M + r]; + if(mark == Mark::Star) { + rowIdxs[i] = r; + colIdxs[i] = c; + i += 1; + } + } } - } }; } // namespace @@ -316,64 +326,66 @@ void stepSeven(int* marks, int* rowIdxs, int* colIdxs, int M, int N) { namespace fl::lib::set { void hungarian(float* costs, int* assignments, int M, int N) { - // Ensure there are more rows than columns - assert(N >= M); - std::vector rowCover(M); - std::vector colCover(N); - std::vector paths(N * N * 2); - int firstPathRow, firstPathCol; - bool done = false; - int step = 1; - while (!done) { - switch (step) { - case 1: - step = stepOne(costs, M, N); - break; - case 2: - step = - stepTwo(costs, assignments, colCover.data(), rowCover.data(), M, N); - break; - case 3: - step = stepThree(assignments, colCover.data(), rowCover.data(), M, N); - break; - case 4: - step = stepFour( - costs, - assignments, - colCover.data(), - rowCover.data(), - M, - N, - &firstPathRow, - &firstPathCol); - break; - case 5: - step = stepFive( - costs, - assignments, - colCover.data(), - rowCover.data(), - paths.data(), - firstPathRow, - firstPathCol, - M, - N); - break; - case 6: - step = - stepSix(costs, assignments, colCover.data(), 
rowCover.data(), M, N); - break; - case 7: - done = true; - break; + // Ensure there are more rows than columns + assert(N >= M); + std::vector rowCover(M); + std::vector colCover(N); + std::vector paths(N * N * 2); + int firstPathRow, firstPathCol; + bool done = false; + int step = 1; + while(!done) { + switch(step) { + case 1: + step = stepOne(costs, M, N); + break; + case 2: + step = + stepTwo(costs, assignments, colCover.data(), rowCover.data(), M, N); + break; + case 3: + step = stepThree(assignments, colCover.data(), rowCover.data(), M, N); + break; + case 4: + step = stepFour( + costs, + assignments, + colCover.data(), + rowCover.data(), + M, + N, + &firstPathRow, + &firstPathCol + ); + break; + case 5: + step = stepFive( + costs, + assignments, + colCover.data(), + rowCover.data(), + paths.data(), + firstPathRow, + firstPathCol, + M, + N + ); + break; + case 6: + step = + stepSix(costs, assignments, colCover.data(), rowCover.data(), M, N); + break; + case 7: + done = true; + break; + } } - } } void hungarian(float* costs, int* rowIdxs, int* colIdxs, int M, int N) { - std::vector marks(M * N); - hungarian(costs, marks.data(), M, N); - stepSeven(marks.data(), rowIdxs, colIdxs, M, N); + std::vector marks(M * N); + hungarian(costs, marks.data(), M, N); + stepSeven(marks.data(), rowIdxs, colIdxs, M, N); } } // namespace fl diff --git a/flashlight/pkg/vision/criterion/HungarianImpl.h b/flashlight/pkg/vision/criterion/HungarianImpl.h index 7656bb3..9e687ac 100644 --- a/flashlight/pkg/vision/criterion/HungarianImpl.h +++ b/flashlight/pkg/vision/criterion/HungarianImpl.h @@ -9,7 +9,7 @@ namespace fl { namespace lib { -namespace set { + namespace set { /* * Performs linear sum assignment @@ -21,15 +21,15 @@ namespace set { * rowIdxs will contain the row idx for each assignment * and colIdxs wiill contain the colIdx for each assignment */ -void hungarian(float* costs, int* rowIdxs, int* colIdxs, int M, int N); + void hungarian(float* costs, int* rowIdxs, int* colIdxs, 
int M, int N); /* * Same as above except it will output an M X N assignment matrix where * assignments[m][n] == 1 means m and n are assigned. */ - void hungarian(float* costs, int* assignments, int M, int N); + void hungarian(float* costs, int* assignments, int M, int N); -} // namespace set + } // namespace set } // namespace lib } // namespace fl diff --git a/flashlight/pkg/vision/criterion/SetCriterion.cpp b/flashlight/pkg/vision/criterion/SetCriterion.cpp index c99f1ca..7ff6468 100644 --- a/flashlight/pkg/vision/criterion/SetCriterion.cpp +++ b/flashlight/pkg/vision/criterion/SetCriterion.cpp @@ -21,111 +21,116 @@ namespace { using namespace fl; Tensor span(const Shape& inDims, const int index) { - Shape dims(std::vector(std::max(inDims.ndim(), index + 1), 1)); - if (index > inDims.ndim() - 1) { - dims[index] = 1; - } else { - dims[index] = inDims[index]; - } - return fl::iota(dims); + Shape dims(std::vector(std::max(inDims.ndim(), index + 1), 1)); + if(index > inDims.ndim() - 1) { + dims[index] = 1; + } else { + dims[index] = inDims[index]; + } + return fl::iota(dims); } Shape calcStrides(const Shape& dims) { - return {1, dims[0], dims[0] * dims[1], dims[0] * dims[1] * dims[2]}; + return {1, dims[0], dims[0] * dims[1], dims[0] * dims[1] * dims[2]}; }; Shape calcOutDims(const std::vector& coords) { - unsigned maxNdim = 0; - for (const auto& coord : coords) { - if (coord.ndim() > maxNdim) { - maxNdim = coord.ndim(); + unsigned maxNdim = 0; + for(const auto& coord : coords) { + if(coord.ndim() > maxNdim) { + maxNdim = coord.ndim(); + } } - } - Shape oDims(std::vector(maxNdim, 1)); + Shape oDims(std::vector(maxNdim, 1)); - for (const auto& coord : coords) { - auto iDims = coord.shape(); - for (int i = 0; i < coord.ndim(); i++) { - if (iDims[i] > 1 && oDims[i] == 1) { - oDims[i] = iDims[i]; - } - assert(iDims[i] == 1 || iDims[i] == oDims[i]); + for(const auto& coord : coords) { + auto iDims = coord.shape(); + for(int i = 0; i < coord.ndim(); i++) { + if(iDims[i] > 
1 && oDims[i] == 1) { + oDims[i] = iDims[i]; + } + assert(iDims[i] == 1 || iDims[i] == oDims[i]); + } } - } - return oDims; + return oDims; } Tensor applyStrides(const std::vector& coords, const Shape& strides) { - auto oDims = coords[0].shape(); - return std::inner_product( - coords.begin(), - coords.end(), - strides.get().begin(), - fl::full(oDims, 0), - [](const Tensor& x, const Tensor& y) { return x + y; }, - [](const Tensor& x, int y) { return x * y; }); + auto oDims = coords[0].shape(); + return std::inner_product( + coords.begin(), + coords.end(), + strides.get().begin(), + fl::full(oDims, 0), + [](const Tensor& x, const Tensor& y) { return x + y; }, + [](const Tensor& x, int y) { return x * y; }); } std::vector spanIfEmpty(const std::vector& coords, Shape dims) { - std::vector result(coords.size()); - for (int i = 0; i < coords.size(); i++) { - result[i] = (coords[i].isEmpty()) ? span(dims, i) : coords[i]; - } - return result; + std::vector result(coords.size()); + for(int i = 0; i < coords.size(); i++) { + result[i] = (coords[i].isEmpty()) ? 
span(dims, i) : coords[i]; + } + return result; } // Then, broadcast the indices std::vector broadcastCoords(const std::vector& input) { - std::vector result(input.size()); - auto oDims = calcOutDims(input); - std::transform( - input.begin(), input.end(), result.begin(), [&oDims](const Tensor& idx) { - return detail::tileAs(idx, oDims); - }); - return result; + std::vector result(input.size()); + auto oDims = calcOutDims(input); + std::transform( + input.begin(), + input.end(), + result.begin(), + [&oDims](const Tensor& idx) { + return detail::tileAs(idx, oDims); + } + ); + return result; } Tensor ravelIndices( const std::vector& input_coords, - const Shape& in_dims) { - std::vector coords; - coords = spanIfEmpty(input_coords, in_dims); - coords = broadcastCoords(coords); - return applyStrides(coords, calcStrides(in_dims)); + const Shape& in_dims +) { + std::vector coords; + coords = spanIfEmpty(input_coords, in_dims); + coords = broadcastCoords(coords); + return applyStrides(coords, calcStrides(in_dims)); } Tensor index(const Tensor& in, const std::vector& idxs) { - auto linearIndices = ravelIndices(idxs, in.shape()); - Tensor output = fl::full(linearIndices.shape(), 0., in.type()); - output.flat(fl::range(static_cast(linearIndices.elements()))) = - in.flatten()(linearIndices); - return output; + auto linearIndices = ravelIndices(idxs, in.shape()); + Tensor output = fl::full(linearIndices.shape(), 0., in.type()); + output.flat(fl::range(static_cast(linearIndices.elements()))) = + in.flatten()(linearIndices); + return output; } fl::Variable index(const fl::Variable& in, std::vector idxs) { - auto idims = in.shape(); - auto result = index(in.tensor(), idxs); - auto gradFunction = [idxs, idims]( - std::vector& inputs, - const Variable& grad_output) { - if (!inputs[0].isGradAvailable()) { - auto grad = fl::full(idims, 0., inputs[0].type()); - inputs[0].addGrad(Variable(grad, false)); - return; - } - auto grad = fl::Variable(fl::full(idims, 0, inputs[0].type()), 
false); - auto linearIndices = ravelIndices(idxs, idims); - grad.tensor()(linearIndices) = grad_output.tensor()( - fl::range(static_cast(linearIndices.elements()))); - // TODO Can parallize this if needed but does not work for duplicate keys - // for(int i = 0; i < linearIndices.elements(); i++) { - // Tensor index = linearIndices(i); - // grad.tensor()(index) += grad_output.tensor()(i); - //} - inputs[0].addGrad(grad); - }; - return fl::Variable(result, {in.withoutData()}, gradFunction); + auto idims = in.shape(); + auto result = index(in.tensor(), idxs); + auto gradFunction = [idxs, idims]( + std::vector& inputs, + const Variable& grad_output) { + if(!inputs[0].isGradAvailable()) { + auto grad = fl::full(idims, 0., inputs[0].type()); + inputs[0].addGrad(Variable(grad, false)); + return; + } + auto grad = fl::Variable(fl::full(idims, 0, inputs[0].type()), false); + auto linearIndices = ravelIndices(idxs, idims); + grad.tensor()(linearIndices) = grad_output.tensor()( + fl::range(static_cast(linearIndices.elements()))); + // TODO Can parallelize this if needed but does not work for duplicate keys + // for(int i = 0; i < linearIndices.elements(); i++) { + // Tensor index = linearIndices(i); + // grad.tensor()(index) += grad_output.tensor()(i); + // } + inputs[0].addGrad(grad); + }; + return fl::Variable(result, {in.withoutData()}, gradFunction); } } // namespace @@ -136,68 +141,82 @@ SetCriterion::SetCriterion( const int numClasses, const HungarianMatcher& matcher, const std::unordered_map& weightDict, - const float eosCoef) - : numClasses_(numClasses), - matcher_(matcher), - weightDict_(weightDict), - eosCoef_(eosCoef){}; + const float eosCoef +) : numClasses_(numClasses), + matcher_(matcher), + weightDict_(weightDict), + eosCoef_(eosCoef) {}; SetCriterion::LossDict SetCriterion::forward( const Variable& predBoxesAux, const Variable& predLogitsAux, const std::vector& targetBoxes, - const std::vector& targetClasses) { - LossDict losses; - - for (int i = 0; i < 
predBoxesAux.dim(3); i++) { - auto predBoxes = predBoxesAux(fl::span, fl::span, fl::span, i); - auto predLogits = - predLogitsAux(fl::span, fl::span, fl::span, fl::range(i, i + 1)); - - std::vector targetBoxesArray(targetBoxes.size()); - std::vector targetClassesArray(targetClasses.size()); - std::transform( - targetBoxes.begin(), - targetBoxes.end(), - targetBoxesArray.begin(), - [](const Variable& in) { return in.tensor(); }); - std::transform( - targetClasses.begin(), - targetClasses.end(), - targetClassesArray.begin(), - [](const Variable& in) { return in.tensor(); }); - - auto indices = matcher_.compute( - predBoxes.tensor(), - predLogits.tensor(), - targetBoxesArray, - targetClassesArray); - - int numBoxes = std::accumulate( - targetBoxes.begin(), - targetBoxes.end(), - 0, - [](int curr, const Variable& label) { return curr + label.dim(1); }); - - Tensor numBoxesArray = fl::fromScalar(numBoxes, fl::dtype::s32); - if (isDistributedInit()) { - allReduce(numBoxesArray); + const std::vector& targetClasses +) { + LossDict losses; + + for(int i = 0; i < predBoxesAux.dim(3); i++) { + auto predBoxes = predBoxesAux(fl::span, fl::span, fl::span, i); + auto predLogits = + predLogitsAux(fl::span, fl::span, fl::span, fl::range(i, i + 1)); + + std::vector targetBoxesArray(targetBoxes.size()); + std::vector targetClassesArray(targetClasses.size()); + std::transform( + targetBoxes.begin(), + targetBoxes.end(), + targetBoxesArray.begin(), + [](const Variable& in) { return in.tensor(); }); + std::transform( + targetClasses.begin(), + targetClasses.end(), + targetClassesArray.begin(), + [](const Variable& in) { return in.tensor(); }); + + auto indices = matcher_.compute( + predBoxes.tensor(), + predLogits.tensor(), + targetBoxesArray, + targetClassesArray + ); + + int numBoxes = std::accumulate( + targetBoxes.begin(), + targetBoxes.end(), + 0, + [](int curr, const Variable& label) { return curr + label.dim(1); }); + + Tensor numBoxesArray = fl::fromScalar(numBoxes, 
fl::dtype::s32); + if(isDistributedInit()) { + allReduce(numBoxesArray); + } + numBoxes = numBoxesArray.scalar(); + numBoxes = std::max(numBoxes / fl::getWorldSize(), 1); + + auto labelLoss = lossLabels( + predBoxes, + predLogits, + targetBoxes, + targetClasses, + indices, + numBoxes + ); + auto bboxLoss = lossBoxes( + predBoxes, + predLogits, + targetBoxes, + targetClasses, + indices, + numBoxes + ); + for(std::pair l : labelLoss) { + losses[l.first + "_" + std::to_string(i)] = l.second; + } + for(std::pair l : bboxLoss) { + losses[l.first + "_" + std::to_string(i)] = l.second; + } } - numBoxes = numBoxesArray.scalar(); - numBoxes = std::max(numBoxes / fl::getWorldSize(), 1); - - auto labelLoss = lossLabels( - predBoxes, predLogits, targetBoxes, targetClasses, indices, numBoxes); - auto bboxLoss = lossBoxes( - predBoxes, predLogits, targetBoxes, targetClasses, indices, numBoxes); - for (std::pair l : labelLoss) { - losses[l.first + "_" + std::to_string(i)] = l.second; - } - for (std::pair l : bboxLoss) { - losses[l.first + "_" + std::to_string(i)] = l.second; - } - } - return losses; + return losses; } SetCriterion::LossDict SetCriterion::lossBoxes( @@ -206,44 +225,45 @@ SetCriterion::LossDict SetCriterion::lossBoxes( const std::vector& targetBoxes, const std::vector& /*targetClasses*/, const std::vector>& indices, - const int numBoxes) { - auto srcIdx = this->getSrcPermutationIdx(indices); - if (srcIdx.first.isEmpty()) { - return { - {"lossGiou", fl::Variable(fl::fromScalar(0, predBoxes.type()), false)}, - {"lossBbox", fl::Variable(fl::fromScalar(0, predBoxes.type()), false)}}; - } - auto colIdxs = fl::reshape(srcIdx.second, {1, srcIdx.second.dim(0)}); - auto batchIdxs = fl::reshape(srcIdx.first, {1, srcIdx.first.dim(0)}); - - auto srcBoxes = index(predBoxes, {Tensor(), colIdxs, batchIdxs}); - - int i = 0; - std::vector permuted; - for (const auto& idx : indices) { - auto targetIdxs = idx.first; - auto reordered = targetBoxes[i](fl::span, targetIdxs); - if 
(!reordered.isEmpty()) { - permuted.emplace_back(reordered); + const int numBoxes +) { + auto srcIdx = this->getSrcPermutationIdx(indices); + if(srcIdx.first.isEmpty()) { + return { + {"lossGiou", fl::Variable(fl::fromScalar(0, predBoxes.type()), false)}, + {"lossBbox", fl::Variable(fl::fromScalar(0, predBoxes.type()), false)}}; } - i += 1; - } - auto tgtBoxes = fl::concatenate(permuted, 1); + auto colIdxs = fl::reshape(srcIdx.second, {1, srcIdx.second.dim(0)}); + auto batchIdxs = fl::reshape(srcIdx.first, {1, srcIdx.first.dim(0)}); + + auto srcBoxes = index(predBoxes, {Tensor(), colIdxs, batchIdxs}); + + int i = 0; + std::vector permuted; + for(const auto& idx : indices) { + auto targetIdxs = idx.first; + auto reordered = targetBoxes[i](fl::span, targetIdxs); + if(!reordered.isEmpty()) { + permuted.emplace_back(reordered); + } + i += 1; + } + auto tgtBoxes = fl::concatenate(permuted, 1); - auto costGiou = - generalizedBoxIou(cxcywh2xyxy(srcBoxes), cxcywh2xyxy(tgtBoxes)); + auto costGiou = + generalizedBoxIou(cxcywh2xyxy(srcBoxes), cxcywh2xyxy(tgtBoxes)); - // Extract diagonal - auto dims = costGiou.shape(); - auto rng = fl::arange({dims[0]}); - costGiou = 1 - index(costGiou, {rng, rng, Tensor(), Tensor()}); + // Extract diagonal + auto dims = costGiou.shape(); + auto rng = fl::arange({dims[0]}); + costGiou = 1 - index(costGiou, {rng, rng, Tensor(), Tensor()}); - costGiou = sum(costGiou, {0}) / numBoxes; + costGiou = sum(costGiou, {0}) / numBoxes; - auto lossBbox = l1Loss(srcBoxes, tgtBoxes); - lossBbox = sum(lossBbox, {0}) / numBoxes; + auto lossBbox = l1Loss(srcBoxes, tgtBoxes); + lossBbox = sum(lossBbox, {0}) / numBoxes; - return {{"lossGiou", costGiou}, {"lossBbox", lossBbox}}; + return {{"lossGiou", costGiou}, {"lossBbox", lossBbox}}; } SetCriterion::LossDict SetCriterion::lossLabels( @@ -252,78 +272,83 @@ SetCriterion::LossDict SetCriterion::lossLabels( const std::vector& /*targetBoxes*/, const std::vector& targetClasses, const std::vector>& indices, - const 
int /*numBoxes*/) { - assert(predLogits.dim(0) == numClasses_ + 1); - - auto target_classes_full = fl::full( - // TODO: this thing requires predLogits to have > 2 dimensions - {predLogits.dim(1), predLogits.dim(2), 1}, - numClasses_, - predLogits.type()); - - int i = 0; - for (const auto& idx : indices) { - auto targetIdxs = idx.first; - auto srcIdxs = idx.second; - auto reordered = targetClasses[i](targetIdxs); - target_classes_full(srcIdxs, i) = - fl::reshape( - targetClasses[i].tensor()(targetIdxs), - {static_cast(srcIdxs.elements()), 1}) - .astype(target_classes_full.type()); - i += 1; - } - - auto softmaxed = logSoftmax(predLogits, 0); - auto weight = fl::full({numClasses_ + 1}, 1.0f); - weight.flat(numClasses_) = eosCoef_; - auto weightVar = Variable(weight, false); - auto lossCe = weightedCategoricalCrossEntropy( - softmaxed, - fl::Variable(target_classes_full.astype(fl::dtype::f32), false), - weightVar, - -1); - return {{"lossCe", lossCe.astype(predLogits.type())}}; + const int /*numBoxes*/ +) { + assert(predLogits.dim(0) == numClasses_ + 1); + + auto target_classes_full = fl::full( + // TODO: this thing requires predLogits to have > 2 dimensions + {predLogits.dim(1), predLogits.dim(2), 1}, + numClasses_, + predLogits.type() + ); + + int i = 0; + for(const auto& idx : indices) { + auto targetIdxs = idx.first; + auto srcIdxs = idx.second; + auto reordered = targetClasses[i](targetIdxs); + target_classes_full(srcIdxs, i) = + fl::reshape( + targetClasses[i].tensor()(targetIdxs), + {static_cast(srcIdxs.elements()), 1}) + .astype(target_classes_full.type()); + i += 1; + } + + auto softmaxed = logSoftmax(predLogits, 0); + auto weight = fl::full({numClasses_ + 1}, 1.0f); + weight.flat(numClasses_) = eosCoef_; + auto weightVar = Variable(weight, false); + auto lossCe = weightedCategoricalCrossEntropy( + softmaxed, + fl::Variable(target_classes_full.astype(fl::dtype::f32), false), + weightVar, + -1 + ); + return {{"lossCe", lossCe.astype(predLogits.type())}}; } 
std::unordered_map SetCriterion::getWeightDict() { - return weightDict_; + return weightDict_; } std::pair SetCriterion::getTgtPermutationIdx( - const std::vector>& indices) { - long batchSize = static_cast(indices.size()); - auto batchIdxs = fl::full({1, 1, 1, batchSize}, -1); - auto first = indices[0].first; - auto dims = first.shape(); - auto tgtIdxs = fl::full({1, dims[0], batchSize}, -1); - int idx = 0; - for (const auto& pair : indices) { - batchIdxs(0, 0, 0, idx) = fl::fromScalar(idx); - tgtIdxs(fl::span, fl::span, idx) = pair.first; - idx++; - } - return std::make_pair(batchIdxs, tgtIdxs); + const std::vector>& indices +) { + long batchSize = static_cast(indices.size()); + auto batchIdxs = fl::full({1, 1, 1, batchSize}, -1); + auto first = indices[0].first; + auto dims = first.shape(); + auto tgtIdxs = fl::full({1, dims[0], batchSize}, -1); + int idx = 0; + for(const auto& pair : indices) { + batchIdxs(0, 0, 0, idx) = fl::fromScalar(idx); + tgtIdxs(fl::span, fl::span, idx) = pair.first; + idx++; + } + return std::make_pair(batchIdxs, tgtIdxs); } std::pair SetCriterion::getSrcPermutationIdx( - const std::vector>& indices) { - std::vector srcIdxs; - std::vector batchIdxs; - for (int i = 0; i < indices.size(); i++) { - auto index = indices[i].second; - if (!index.isEmpty()) { - srcIdxs.emplace_back(index, false); - auto batchIdx = fl::full(index.shape(), i, fl::dtype::s32); - batchIdxs.emplace_back(batchIdx, false); + const std::vector>& indices +) { + std::vector srcIdxs; + std::vector batchIdxs; + for(int i = 0; i < indices.size(); i++) { + auto index = indices[i].second; + if(!index.isEmpty()) { + srcIdxs.emplace_back(index, false); + auto batchIdx = fl::full(index.shape(), i, fl::dtype::s32); + batchIdxs.emplace_back(batchIdx, false); + } + } + fl::Variable srcIdx, batchIdx; + if(!srcIdxs.empty()) { + srcIdx = concatenate(srcIdxs, 0); + batchIdx = concatenate(batchIdxs, 0); } - } - fl::Variable srcIdx, batchIdx; - if (!srcIdxs.empty()) { - srcIdx = 
concatenate(srcIdxs, 0); - batchIdx = concatenate(batchIdxs, 0); - } - return {batchIdx.tensor(), srcIdx.tensor()}; + return {batchIdx.tensor(), srcIdx.tensor()}; } } // namespace fl diff --git a/flashlight/pkg/vision/criterion/SetCriterion.h b/flashlight/pkg/vision/criterion/SetCriterion.h index 3a09c4d..552a14b 100644 --- a/flashlight/pkg/vision/criterion/SetCriterion.h +++ b/flashlight/pkg/vision/criterion/SetCriterion.h @@ -15,73 +15,82 @@ namespace fl { namespace pkg { -namespace vision { - -class SetCriterion { - public: - using LossDict = std::unordered_map; - - SetCriterion( - const int numClasses, - const HungarianMatcher& matcher, - const std::unordered_map& weightDict, - const float eosCoef); - - std::vector match( - const Variable& predBoxes, - const Variable& predLogits, - const std::vector& targetBoxes, - const std::vector& targetClasses); - - LossDict lossLabels( - const Variable& predBoxes, - const Variable& predLogits, - const std::vector& targetBoxes, - const std::vector& targetClasses, - const std::vector>& indices, - const int numBoxes); - - LossDict lossCardinality( - const Variable& predBoxes, - const Variable& predLogits, - const std::vector& targetBoxes, - const std::vector& targetClasses); - - LossDict lossBoxes( - const Variable& predBoxes, - const Variable& predLogits, - const std::vector& targetBoxes, - const std::vector& targetClasses, - const std::vector>& indices, - const int numBoxes); - - LossDict lossMasks( - const Variable& predBoxes, - const Variable& predLogits, - const std::vector& targetBoxes, - const std::vector& targetClasses); - - LossDict forward( - const Variable& predBoxes, - const Variable& predLogits, - const std::vector& targetBoxes, - const std::vector& targetClasses); - - std::unordered_map getWeightDict(); - - private: - std::pair getSrcPermutationIdx( - const std::vector>& indices); - - std::pair getTgtPermutationIdx( - const std::vector>& indices); - - const int numClasses_; - const HungarianMatcher matcher_; - 
const std::unordered_map weightDict_; - const float eosCoef_; -}; - -} // namespace vision + namespace vision { + + class SetCriterion { + public: + using LossDict = std::unordered_map; + + SetCriterion( + const int numClasses, + const HungarianMatcher& matcher, + const std::unordered_map& weightDict, + const float eosCoef + ); + + std::vector match( + const Variable& predBoxes, + const Variable& predLogits, + const std::vector& targetBoxes, + const std::vector& targetClasses + ); + + LossDict lossLabels( + const Variable& predBoxes, + const Variable& predLogits, + const std::vector& targetBoxes, + const std::vector& targetClasses, + const std::vector>& indices, + const int numBoxes + ); + + LossDict lossCardinality( + const Variable& predBoxes, + const Variable& predLogits, + const std::vector& targetBoxes, + const std::vector& targetClasses + ); + + LossDict lossBoxes( + const Variable& predBoxes, + const Variable& predLogits, + const std::vector& targetBoxes, + const std::vector& targetClasses, + const std::vector>& indices, + const int numBoxes + ); + + LossDict lossMasks( + const Variable& predBoxes, + const Variable& predLogits, + const std::vector& targetBoxes, + const std::vector& targetClasses + ); + + LossDict forward( + const Variable& predBoxes, + const Variable& predLogits, + const std::vector& targetBoxes, + const std::vector& targetClasses + ); + + std::unordered_map getWeightDict(); + + private: + std::pair getSrcPermutationIdx( + const std::vector>& indices + ); + + std::pair getTgtPermutationIdx( + const std::vector>& indices + ); + + const int numClasses_; + const HungarianMatcher matcher_; + const std::unordered_map weightDict_; + const float eosCoef_; + }; + + } // namespace vision } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/vision/dataset/BatchTransformDataset.h b/flashlight/pkg/vision/dataset/BatchTransformDataset.h index a7c9504..9a010f9 100644 --- a/flashlight/pkg/vision/dataset/BatchTransformDataset.h +++ 
b/flashlight/pkg/vision/dataset/BatchTransformDataset.h @@ -13,11 +13,11 @@ namespace fl { namespace pkg { -namespace vision { + namespace vision { -template -using BatchTransformFunction = - std::function>&)>; + template + using BatchTransformFunction = + std::function>&)>; /* * This is a slightly more generalized batching dataset than allows you to @@ -25,81 +25,82 @@ using BatchTransformFunction = * because we would like to keep the target boxes and classes as a separate * unbatched vector of arrays, while still batching the images */ -template -class BatchTransformDataset { - public: - BatchTransformDataset( - std::shared_ptr dataset, - int64_t batchsize, - BatchDatasetPolicy policy /* = BatchDatasetPolicy::INCLUDE_LAST */, - BatchTransformFunction batchFn) - : dataset_(dataset), - batchSize_(batchsize), - batchPolicy_(policy), - batchFn_(batchFn) { - if (!dataset_) { - throw std::invalid_argument("dataset to be batched is null"); - } - if (batchSize_ <= 0) { - throw std::invalid_argument("invalid batch size"); - } - preBatchSize_ = dataset_->size(); - switch (batchPolicy_) { - case BatchDatasetPolicy::INCLUDE_LAST: - size_ = std::ceil(static_cast(preBatchSize_) / batchSize_); - break; - case BatchDatasetPolicy::SKIP_LAST: - size_ = std::floor(static_cast(preBatchSize_) / batchSize_); - break; - case BatchDatasetPolicy::DIVISIBLE_ONLY: - if (size_ % batchSize_ != 0) { - throw std::invalid_argument( - "dataset is not evenly divisible into batches"); - } - size_ = std::ceil(static_cast(preBatchSize_) / batchSize_); - break; - default: - throw std::invalid_argument("unknown BatchDatasetPolicy"); - } - } + template + class BatchTransformDataset { + public: + BatchTransformDataset( + std::shared_ptr dataset, + int64_t batchsize, + BatchDatasetPolicy policy /* = BatchDatasetPolicy::INCLUDE_LAST */, + BatchTransformFunction batchFn + ) : dataset_(dataset), + batchSize_(batchsize), + batchPolicy_(policy), + batchFn_(batchFn) { + if(!dataset_) { + throw 
std::invalid_argument("dataset to be batched is null"); + } + if(batchSize_ <= 0) { + throw std::invalid_argument("invalid batch size"); + } + preBatchSize_ = dataset_->size(); + switch(batchPolicy_) { + case BatchDatasetPolicy::INCLUDE_LAST: + size_ = std::ceil(static_cast(preBatchSize_) / batchSize_); + break; + case BatchDatasetPolicy::SKIP_LAST: + size_ = std::floor(static_cast(preBatchSize_) / batchSize_); + break; + case BatchDatasetPolicy::DIVISIBLE_ONLY: + if(size_ % batchSize_ != 0) { + throw std::invalid_argument( + "dataset is not evenly divisible into batches" + ); + } + size_ = std::ceil(static_cast(preBatchSize_) / batchSize_); + break; + default: + throw std::invalid_argument("unknown BatchDatasetPolicy"); + } + } - ~BatchTransformDataset() {} + ~BatchTransformDataset() {} - T get(const int64_t idx) { - if (!(idx >= 0 && idx < size())) { - throw std::out_of_range("Dataset idx out of range"); - } - std::vector> buffer; + T get(const int64_t idx) { + if(!(idx >= 0 && idx < size())) { + throw std::out_of_range("Dataset idx out of range"); + } + std::vector> buffer; - int64_t start = batchSize_ * idx; - int64_t end = std::min(start + batchSize_, preBatchSize_); + int64_t start = batchSize_ * idx; + int64_t end = std::min(start + batchSize_, preBatchSize_); - for (int64_t batchidx = start; batchidx < end; ++batchidx) { - auto fds = dataset_->get(batchidx); - if (buffer.size() < fds.size()) { - buffer.resize(fds.size()); - } - for (int64_t i = 0; i < fds.size(); ++i) { - buffer[i].emplace_back(fds[i]); - } - } - return batchFn_(buffer); - } + for(int64_t batchidx = start; batchidx < end; ++batchidx) { + auto fds = dataset_->get(batchidx); + if(buffer.size() < fds.size()) { + buffer.resize(fds.size()); + } + for(int64_t i = 0; i < fds.size(); ++i) { + buffer[i].emplace_back(fds[i]); + } + } + return batchFn_(buffer); + } - int64_t size() const { - return size_; - } + int64_t size() const { + return size_; + } - private: - std::shared_ptr dataset_; - int64_t 
batchSize_; - BatchDatasetPolicy batchPolicy_; - BatchTransformFunction batchFn_; + private: + std::shared_ptr dataset_; + int64_t batchSize_; + BatchDatasetPolicy batchPolicy_; + BatchTransformFunction batchFn_; - int64_t preBatchSize_; // Size of the dataset before batching - int64_t size_; -}; + int64_t preBatchSize_; // Size of the dataset before batching + int64_t size_; + }; -} // namespace vision + } // namespace vision } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/vision/dataset/BoxUtils.cpp b/flashlight/pkg/vision/dataset/BoxUtils.cpp index 4dbe29b..3a29eb9 100644 --- a/flashlight/pkg/vision/dataset/BoxUtils.cpp +++ b/flashlight/pkg/vision/dataset/BoxUtils.cpp @@ -18,214 +18,245 @@ namespace fl::pkg::vision { Tensor cxcywh2xyxy(const Tensor& bboxes) { - auto xc = bboxes(fl::range(0, 1)); - auto yc = bboxes(fl::range(1, 2)); - auto w = bboxes(fl::range(2, 3)); - auto h = bboxes(fl::range(3, 4)); + auto xc = bboxes(fl::range(0, 1)); + auto yc = bboxes(fl::range(1, 2)); + auto w = bboxes(fl::range(2, 3)); + auto h = bboxes(fl::range(3, 4)); - return fl::concatenate( - 0, xc - 0.5 * w, yc - 0.5 * h, xc + 0.5 * w, yc + 0.5 * h); + return fl::concatenate( + 0, + xc - 0.5 * w, + yc - 0.5 * h, + xc + 0.5 * w, + yc + 0.5 * h + ); } fl::Variable cxcywh2xyxy(const Variable& bboxes) { - auto xc = bboxes(fl::range(0, 1)); - auto yc = bboxes(fl::range(1, 2)); - auto w = bboxes(fl::range(2, 3)); - auto h = bboxes(fl::range(3, 4)); + auto xc = bboxes(fl::range(0, 1)); + auto yc = bboxes(fl::range(1, 2)); + auto w = bboxes(fl::range(2, 3)); + auto h = bboxes(fl::range(3, 4)); - return fl::concatenate( - {xc - 0.5 * w, yc - 0.5 * h, xc + 0.5 * w, yc + 0.5 * h}, 0); + return fl::concatenate( + {xc - 0.5 * w, yc - 0.5 * h, xc + 0.5 * w, yc + 0.5 * h}, + 0 + ); } Tensor xyxy2cxcywh(const Tensor& bboxes) { - auto x0 = bboxes(fl::range(0, 1)); - auto y0 = bboxes(fl::range(1, 2)); - auto x1 = bboxes(fl::range(2, 3)); - auto y1 = bboxes(fl::range(3, 4)); - 
Tensor result = - fl::concatenate(0, (x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)); - return result; + auto x0 = bboxes(fl::range(0, 1)); + auto y0 = bboxes(fl::range(1, 2)); + auto x1 = bboxes(fl::range(2, 3)); + auto y1 = bboxes(fl::range(3, 4)); + Tensor result = + fl::concatenate(0, (x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)); + return result; } Tensor flatten(const Tensor& x, int start, int stop) { - auto dims = x.shape(); - Shape newDims(std::vector(x.ndim(), 1)); - int flattenedDims = 1; - for (int i = start; i <= stop; i++) { - flattenedDims = flattenedDims * dims[i]; - } - for (int i = 0; i < start; i++) { - newDims[i] = dims[i]; - } - newDims[start] = flattenedDims; - for (int i = start + 1; i < (x.ndim() - stop); i++) { - newDims[i] = dims[i + stop]; - } - return fl::reshape(x, newDims); + auto dims = x.shape(); + Shape newDims(std::vector(x.ndim(), 1)); + int flattenedDims = 1; + for(int i = start; i <= stop; i++) { + flattenedDims = flattenedDims * dims[i]; + } + for(int i = 0; i < start; i++) { + newDims[i] = dims[i]; + } + newDims[start] = flattenedDims; + for(int i = start + 1; i < (x.ndim() - stop); i++) { + newDims[i] = dims[i + stop]; + } + return fl::reshape(x, newDims); }; fl::Variable flatten(const fl::Variable& x, int start, int stop) { - unsigned n = x.ndim(); - auto dims = x.shape(); - Shape newDims(std::vector(n, 1)); - int flattenedDims = 1; - for (int i = start; i <= stop; i++) { - flattenedDims = flattenedDims * dims[i]; - } - for (int i = 0; i < start; i++) { - newDims[i] = dims[i]; - } - newDims[start] = flattenedDims; - for (int i = start + 1; i < (n - stop); i++) { - newDims[i] = dims[i + stop]; - } - return moddims(x, newDims); + unsigned n = x.ndim(); + auto dims = x.shape(); + Shape newDims(std::vector(n, 1)); + int flattenedDims = 1; + for(int i = start; i <= stop; i++) { + flattenedDims = flattenedDims * dims[i]; + } + for(int i = 0; i < start; i++) { + newDims[i] = dims[i]; + } + newDims[start] = flattenedDims; + 
for(int i = start + 1; i < (n - stop); i++) { + newDims[i] = dims[i + stop]; + } + return moddims(x, newDims); }; Tensor boxArea(const Tensor& bboxes) { - auto x0 = bboxes(fl::range(0, 1)); - auto y0 = bboxes(fl::range(1, 2)); - auto x1 = bboxes(fl::range(2, 3)); - auto y1 = bboxes(fl::range(3, 4)); - auto result = (x1 - x0) * (y1 - y0); - return result; + auto x0 = bboxes(fl::range(0, 1)); + auto y0 = bboxes(fl::range(1, 2)); + auto x1 = bboxes(fl::range(2, 3)); + auto y1 = bboxes(fl::range(3, 4)); + auto result = (x1 - x0) * (y1 - y0); + return result; } fl::Variable boxArea(const fl::Variable& bboxes) { - auto x0 = bboxes(fl::range(0, 1)); - auto y0 = bboxes(fl::range(1, 2)); - auto x1 = bboxes(fl::range(2, 3)); - auto y1 = bboxes(fl::range(3, 4)); - auto result = (x1 - x0) * (y1 - y0); - return result; + auto x0 = bboxes(fl::range(0, 1)); + auto y0 = bboxes(fl::range(1, 2)); + auto x1 = bboxes(fl::range(2, 3)); + auto y1 = bboxes(fl::range(3, 4)); + auto result = (x1 - x0) * (y1 - y0); + return result; } Variable cartesian(const Variable& x, const Variable& y, batchFuncVar_t fn) { - if (x.ndim() != 3 || y.ndim() != 3) { - throw std::invalid_argument( - "vision::cartesian - x and y inputs must have 3 dimensions"); - } - assert(x.dim(2) == y.dim(2)); - Shape yDims = {y.dim(0), 1, y.dim(1), y.dim(2)}; - auto yMod = moddims(y, {y.dim(0), 1, y.dim(1), y.dim(2)}); - auto xMod = moddims(x, {x.dim(0), x.dim(1), 1, x.dim(2)}); - Shape outputDims = {x.dim(0), x.dim(1), y.dim(1), x.dim(2)}; - xMod = tileAs(xMod, outputDims); - yMod = tileAs(yMod, outputDims); - - auto out = fn(xMod, yMod); - return out; + if(x.ndim() != 3 || y.ndim() != 3) { + throw std::invalid_argument( + "vision::cartesian - x and y inputs must have 3 dimensions" + ); + } + assert(x.dim(2) == y.dim(2)); + Shape yDims = {y.dim(0), 1, y.dim(1), y.dim(2)}; + auto yMod = moddims(y, {y.dim(0), 1, y.dim(1), y.dim(2)}); + auto xMod = moddims(x, {x.dim(0), x.dim(1), 1, x.dim(2)}); + Shape outputDims = 
{x.dim(0), x.dim(1), y.dim(1), x.dim(2)}; + xMod = tileAs(xMod, outputDims); + yMod = tileAs(yMod, outputDims); + + auto out = fn(xMod, yMod); + return out; } Tensor cartesian(const Tensor& x, const Tensor& y, batchFuncArr_t fn) { - if (x.ndim() != 3 || y.ndim() != 3) { - throw std::invalid_argument( - "vision::cartesian - x and y inputs must have 3 dimensions"); - } - assert(x.dim(2) == y.dim(2)); - Shape yDims = {y.dim(0), 1, y.dim(1), y.dim(2)}; - auto yMod = fl::reshape(y, {y.dim(0), 1, y.dim(1), y.dim(2)}); - auto xMod = fl::reshape(x, {x.dim(0), x.dim(1), 1, x.dim(2)}); - Shape outputDims = {x.dim(0), x.dim(1), y.dim(1), x.dim(2)}; - xMod = detail::tileAs(xMod, outputDims); - yMod = detail::tileAs(yMod, outputDims); - return fn(xMod, yMod); + if(x.ndim() != 3 || y.ndim() != 3) { + throw std::invalid_argument( + "vision::cartesian - x and y inputs must have 3 dimensions" + ); + } + assert(x.dim(2) == y.dim(2)); + Shape yDims = {y.dim(0), 1, y.dim(1), y.dim(2)}; + auto yMod = fl::reshape(y, {y.dim(0), 1, y.dim(1), y.dim(2)}); + auto xMod = fl::reshape(x, {x.dim(0), x.dim(1), 1, x.dim(2)}); + Shape outputDims = {x.dim(0), x.dim(1), y.dim(1), x.dim(2)}; + xMod = detail::tileAs(xMod, outputDims); + yMod = detail::tileAs(yMod, outputDims); + return fn(xMod, yMod); } std::tuple boxIou( const Tensor& bboxes1, - const Tensor& bboxes2) { - if (bboxes1.ndim() != 3 || bboxes2.ndim() != 3) { - throw std::invalid_argument( - "vision::boxIou - bbox inputs must be of shape " - "[4, N, B, ...] 
and [4, M, B, ...]"); - } - auto area1 = boxArea(bboxes1); - auto area2 = boxArea(bboxes2); - auto lt = cartesian( - bboxes1(fl::range(0, 2)), bboxes2(fl::range(0, 2)), fl::maximum); - auto rb = cartesian( - bboxes1(fl::range(2, 4)), bboxes2(fl::range(2, 4)), fl::minimum); - auto wh = fl::maximum((rb - lt), 0.0); - auto inter = wh(fl::range(0, 1)) * wh(fl::range(1, 2)); - auto uni = cartesian(area1, area2, fl::operator+) - inter; - auto iou = inter / uni; - iou = flatten(iou, 0, 1); - uni = flatten(uni, 0, 1); - return std::tie(iou, uni); + const Tensor& bboxes2 +) { + if(bboxes1.ndim() != 3 || bboxes2.ndim() != 3) { + throw std::invalid_argument( + "vision::boxIou - bbox inputs must be of shape " + "[4, N, B, ...] and [4, M, B, ...]" + ); + } + auto area1 = boxArea(bboxes1); + auto area2 = boxArea(bboxes2); + auto lt = cartesian( + bboxes1(fl::range(0, 2)), + bboxes2(fl::range(0, 2)), + fl::maximum + ); + auto rb = cartesian( + bboxes1(fl::range(2, 4)), + bboxes2(fl::range(2, 4)), + fl::minimum + ); + auto wh = fl::maximum((rb - lt), 0.0); + auto inter = wh(fl::range(0, 1)) * wh(fl::range(1, 2)); + auto uni = cartesian(area1, area2, fl::operator+) - inter; + auto iou = inter / uni; + iou = flatten(iou, 0, 1); + uni = flatten(uni, 0, 1); + return std::tie(iou, uni); } std::tuple boxIou( const fl::Variable& bboxes1, - const fl::Variable& bboxes2) { - if (bboxes1.ndim() != 3 || bboxes2.ndim() != 3) { - std::stringstream ss; - ss << "vision::boxIou - bbox inputs must be of shape " - "[4, N, B] and [4, M, B]. 
Got boxes with dimensions " - << bboxes1.shape() << " and " << bboxes2.shape(); - throw std::invalid_argument(ss.str()); - } - auto area1 = boxArea(bboxes1); - auto area2 = boxArea(bboxes2); - auto lt = - cartesian(bboxes1(fl::range(0, 2)), bboxes2(fl::range(0, 2)), fl::max); - auto rb = cartesian(bboxes1(fl::range(2, 4)), bboxes2(fl::range(2, 4)), min); - auto wh = max((rb - lt), 0.0); - auto inter = wh(fl::range(0, 1)) * wh(fl::range(1, 2)); - auto uni = cartesian(area1, area2, fl::operator+) - inter; - auto iou = inter / uni; - iou = flatten(iou, 0, 1); - uni = flatten(uni, 0, 1); - return std::tie(iou, uni); + const fl::Variable& bboxes2 +) { + if(bboxes1.ndim() != 3 || bboxes2.ndim() != 3) { + std::stringstream ss; + ss << "vision::boxIou - bbox inputs must be of shape " + "[4, N, B] and [4, M, B]. Got boxes with dimensions " + << bboxes1.shape() << " and " << bboxes2.shape(); + throw std::invalid_argument(ss.str()); + } + auto area1 = boxArea(bboxes1); + auto area2 = boxArea(bboxes2); + auto lt = + cartesian(bboxes1(fl::range(0, 2)), bboxes2(fl::range(0, 2)), fl::max); + auto rb = cartesian(bboxes1(fl::range(2, 4)), bboxes2(fl::range(2, 4)), min); + auto wh = max((rb - lt), 0.0); + auto inter = wh(fl::range(0, 1)) * wh(fl::range(1, 2)); + auto uni = cartesian(area1, area2, fl::operator+) - inter; + auto iou = inter / uni; + iou = flatten(iou, 0, 1); + uni = flatten(uni, 0, 1); + return std::tie(iou, uni); } fl::Variable generalizedBoxIou( const fl::Variable& bboxes1, - const fl::Variable& bboxes2) { - // Make sure all boxes are properly formed - assert(fl::countNonzero(fl::all( - bboxes1.tensor()(fl::range(2, 4)) >= - bboxes1.tensor()(fl::range(0, 2)))) - .scalar()); - - assert(fl::countNonzero(fl::all( - bboxes2.tensor()(fl::range(2, 4)) >= - bboxes2.tensor()(fl::range(0, 2)))) - .scalar()); - - Variable iou, uni; - std::tie(iou, uni) = boxIou(bboxes1, bboxes2); - auto lt = cartesian(bboxes1(fl::range(0, 2)), bboxes2(fl::range(0, 2)), min); - auto rb = 
cartesian(bboxes1(fl::range(2, 4)), bboxes2(fl::range(2, 4)), max); - auto wh = max((rb - lt), 0.0); - auto area = wh(fl::range(0, 1)) * wh(fl::range(1, 2)); - area = flatten(area, 0, 1); - return iou - (area - uni) / area; + const fl::Variable& bboxes2 +) { + // Make sure all boxes are properly formed + assert(fl::countNonzero(fl::all( + bboxes1.tensor()(fl::range(2, 4)) + >= bboxes1.tensor()(fl::range(0, 2)))) + .scalar()); + + assert(fl::countNonzero(fl::all( + bboxes2.tensor()(fl::range(2, 4)) + >= bboxes2.tensor()(fl::range(0, 2)))) + .scalar()); + + Variable iou, uni; + std::tie(iou, uni) = boxIou(bboxes1, bboxes2); + auto lt = cartesian(bboxes1(fl::range(0, 2)), bboxes2(fl::range(0, 2)), min); + auto rb = cartesian(bboxes1(fl::range(2, 4)), bboxes2(fl::range(2, 4)), max); + auto wh = max((rb - lt), 0.0); + auto area = wh(fl::range(0, 1)) * wh(fl::range(1, 2)); + area = flatten(area, 0, 1); + return iou - (area - uni) / area; } Tensor generalizedBoxIou(const Tensor& bboxes1, const Tensor& bboxes2) { - // Make sure all boxes are properly formed - assert(fl::countNonzero( - fl::all(bboxes1(fl::range(2, 4)) >= bboxes1(fl::range(0, 2)))) - .scalar()); - assert(fl::countNonzero( - fl::all(bboxes2(fl::range(2, 4)) >= bboxes2(fl::range(0, 2)))) - .scalar()); - - Tensor iou, uni; - std::tie(iou, uni) = boxIou(bboxes1, bboxes2); - auto lt = cartesian( - bboxes1(fl::range(0, 2)), bboxes2(fl::range(0, 2)), fl::minimum); - auto rb = cartesian( - bboxes1(fl::range(2, 4)), bboxes2(fl::range(2, 4)), fl::maximum); - auto wh = fl::maximum((rb - lt), 0.0); - auto area = wh(fl::range(0, 1)) * wh(fl::range(1, 2)); - area = flatten(area, 0, 1); - return iou - (area - uni) / area; + // Make sure all boxes are properly formed + assert( + fl::countNonzero( + fl::all(bboxes1(fl::range(2, 4)) >= bboxes1(fl::range(0, 2))) + ) + .scalar() + ); + assert( + fl::countNonzero( + fl::all(bboxes2(fl::range(2, 4)) >= bboxes2(fl::range(0, 2))) + ) + .scalar() + ); + + Tensor iou, uni; + 
std::tie(iou, uni) = boxIou(bboxes1, bboxes2); + auto lt = cartesian( + bboxes1(fl::range(0, 2)), + bboxes2(fl::range(0, 2)), + fl::minimum + ); + auto rb = cartesian( + bboxes1(fl::range(2, 4)), + bboxes2(fl::range(2, 4)), + fl::maximum + ); + auto wh = fl::maximum((rb - lt), 0.0); + auto area = wh(fl::range(0, 1)) * wh(fl::range(1, 2)); + area = flatten(area, 0, 1); + return iou - (area - uni) / area; } Variable l1Loss(const Variable& input, const Variable& target) { - return flatten(fl::sum(fl::abs(input - target), {0}), 0, 1); + return flatten(fl::sum(fl::abs(input - target), {0}), 0, 1); } } // namespace fl diff --git a/flashlight/pkg/vision/dataset/BoxUtils.h b/flashlight/pkg/vision/dataset/BoxUtils.h index 388bf45..385248b 100644 --- a/flashlight/pkg/vision/dataset/BoxUtils.h +++ b/flashlight/pkg/vision/dataset/BoxUtils.h @@ -13,11 +13,11 @@ namespace fl { namespace pkg { -namespace vision { + namespace vision { -using batchFuncVar_t = Variable (*)(const Variable &, const Variable &); + using batchFuncVar_t = Variable (*)(const Variable&, const Variable&); -using batchFuncArr_t = Tensor (*)(const Tensor &, const Tensor &); + using batchFuncArr_t = Tensor (*)(const Tensor&, const Tensor&); /** * Converts bounding box coordinates from center (x, y) coordinate, with width @@ -26,7 +26,7 @@ using batchFuncArr_t = Tensor (*)(const Tensor &, const Tensor &); * boxes * @return a `Tensor` with transformed bboxes of same shape */ -Tensor cxcywh2xyxy(const Tensor& bboxes); + Tensor cxcywh2xyxy(const Tensor& bboxes); /** * Converts bounding box coordinates from center (x, y) coordinate, with width @@ -35,7 +35,7 @@ Tensor cxcywh2xyxy(const Tensor& bboxes); * of boxes * @return a `fl::Variable` with transformed bboxes of same shape */ -fl::Variable cxcywh2xyxy(const fl::Variable& bboxes); + fl::Variable cxcywh2xyxy(const fl::Variable& bboxes); /** * Converts bounding box coordinates from bottom left (x1, y1) top right @@ -45,7 +45,7 @@ fl::Variable cxcywh2xyxy(const 
fl::Variable& bboxes); * boxes * @return a `Tensor` with transformed bboxes of same shape */ -Tensor xyxy2cxcywh(const Tensor& bboxes); + Tensor xyxy2cxcywh(const Tensor& bboxes); /** * A generalized function for getting the "cartesian" product of a function @@ -58,8 +58,7 @@ Tensor xyxy2cxcywh(const Tensor& bboxes); * @return a fl::Variable of shape [ X x N X M X K ] * */ -fl::Variable -cartesian(const fl::Variable& x, const fl::Variable& y, batchFuncVar_t fn); + fl::Variable cartesian(const fl::Variable& x, const fl::Variable& y, batchFuncVar_t fn); /** * A generalized function for getting the "cartesian" product of a function @@ -72,7 +71,7 @@ cartesian(const fl::Variable& x, const fl::Variable& y, batchFuncVar_t fn); * @return a Tensor of shape [ X x N X M X K ] * */ -Tensor cartesian(const Tensor& x, const Tensor& y, batchFuncArr_t fn); + Tensor cartesian(const Tensor& x, const Tensor& y, batchFuncArr_t fn); /** * Flattens dimension between start and stop in an Tensor @@ -81,7 +80,7 @@ Tensor cartesian(const Tensor& x, const Tensor& y, batchFuncArr_t fn); * @param stop an int, the end dimension to flatten * @return an Tensor with collasped dimensions */ -Tensor flatten(const Tensor& x, int start, int stop); + Tensor flatten(const Tensor& x, int start, int stop); /** * Flattens dimension between start and stop in an Tensor @@ -90,7 +89,7 @@ Tensor flatten(const Tensor& x, int start, int stop); * @param stop an int, the end dimension to flatten * @return an fl::Variable with collasped dimensions */ -Variable flatten(const fl::Variable& x, int start, int stop); + Variable flatten(const fl::Variable& x, int start, int stop); /** * Computes the generalizedBoxIou pairwise across to arrays @@ -100,7 +99,7 @@ Variable flatten(const fl::Variable& x, int start, int stop); * @return an Tensor of shape [N x M x B] where each entry represents * the giou between two boxes */ -Tensor generalizedBoxIou(const Tensor& bboxes1, const Tensor& bboxes2); + Tensor 
generalizedBoxIou(const Tensor& bboxes1, const Tensor& bboxes2); /** * Computes the generalizedBoxIou pairwise across to fl::Variables @@ -110,7 +109,7 @@ Tensor generalizedBoxIou(const Tensor& bboxes1, const Tensor& bboxes2); * @return an fl::Variable of shape [N x M x B] where each entry represents * the giou between two boxes */ -Variable generalizedBoxIou(const Variable& bboxes1, const Variable& bboxes2); + Variable generalizedBoxIou(const Variable& bboxes1, const Variable& bboxes2); /** * Computes the iou pairwise across two arrays of bboxes @@ -119,9 +118,10 @@ Variable generalizedBoxIou(const Variable& bboxes1, const Variable& bboxes2); * @return an tuple of Tensor of shape [N x M x B] where each entry * represents the iou and intersection between two boxes */ -std::tuple boxIou( - const Tensor& bboxes1, - const Tensor& bboxes2); + std::tuple boxIou( + const Tensor& bboxes1, + const Tensor& bboxes2 + ); /** * Computes the iou pairwise across to Variables of bboxes @@ -130,9 +130,10 @@ std::tuple boxIou( * @return an tuple of fl::Variable of shape [N x M x B] where each entry * represents the iou and intersection between two boxes */ -std::tuple boxIou( - const fl::Variable& bboxes1, - const fl::Variable& bboxes2); + std::tuple boxIou( + const fl::Variable& bboxes1, + const fl::Variable& bboxes2 + ); /** * Computes the l1_loss pairwise across two arrays of boxes @@ -141,8 +142,8 @@ std::tuple boxIou( * @return an tuple of fl::Variable of shape [N x M x B] where each entry * represents the l1Loss between two boxes */ -Variable l1Loss(const Variable& input, const Variable& target); + Variable l1Loss(const Variable& input, const Variable& target); -} // namespace vision + } // namespace vision } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/vision/dataset/Coco.cpp b/flashlight/pkg/vision/dataset/Coco.cpp index 221e283..4f08247 100644 --- a/flashlight/pkg/vision/dataset/Coco.cpp +++ b/flashlight/pkg/vision/dataset/Coco.cpp @@ -27,58 +27,59 @@ 
using namespace fl; constexpr int kElementsPerBbox = 4; std::pair makeImageAndMaskBatch( - const std::vector& data) { - int maxW = -1; - int maxH = -1; - - for (const auto& d : data) { - int w = d.dim(0); - int h = d.dim(1); - maxW = std::max(w, maxW); - maxH = std::max(h, maxH); - } - - Shape outDims = {maxW, maxH, 3, static_cast(data.size())}; - Shape maskDims = {maxW, maxH, 1, static_cast(data.size())}; - - auto batcharr = fl::full(outDims, 0); - auto maskarr = fl::full(maskDims, 0); - - for (long i = 0; i < data.size(); ++i) { - Tensor sample = data[i]; - Shape dims = sample.shape(); - int w = dims[0]; - int h = dims[1]; - batcharr(fl::range(0, w), fl::range(0, h), fl::span, fl::range(i, i + 1)) = - data[i]; - maskarr(fl::range(0, w), fl::range(0, h), fl::span, fl::range(i, i + 1)) = - fl::full({w, h}, 1); - } - return std::make_pair(batcharr, maskarr); + const std::vector& data +) { + int maxW = -1; + int maxH = -1; + + for(const auto& d : data) { + int w = d.dim(0); + int h = d.dim(1); + maxW = std::max(w, maxW); + maxH = std::max(h, maxH); + } + + Shape outDims = {maxW, maxH, 3, static_cast(data.size())}; + Shape maskDims = {maxW, maxH, 1, static_cast(data.size())}; + + auto batcharr = fl::full(outDims, 0); + auto maskarr = fl::full(maskDims, 0); + + for(long i = 0; i < data.size(); ++i) { + Tensor sample = data[i]; + Shape dims = sample.shape(); + int w = dims[0]; + int h = dims[1]; + batcharr(fl::range(0, w), fl::range(0, h), fl::span, fl::range(i, i + 1)) = + data[i]; + maskarr(fl::range(0, w), fl::range(0, h), fl::span, fl::range(i, i + 1)) = + fl::full({w, h}, 1); + } + return std::make_pair(batcharr, maskarr); } // Since the bboxes and classes are variable length, we don't actually want // to batch them together. 
CocoData cocoBatchFunc(const std::vector>& batches) { - Tensor imageBatch, masks; - std::tie(imageBatch, masks) = makeImageAndMaskBatch(batches[ImageIdx]); - return { - imageBatch, - masks, - makeBatch(batches[TargetSizeIdx]), - makeBatch(batches[ImageIdIdx]), - makeBatch(batches[OriginalSizeIdx]), - batches[BboxesIdx], - batches[ClassesIdx]}; + Tensor imageBatch, masks; + std::tie(imageBatch, masks) = makeImageAndMaskBatch(batches[ImageIdx]); + return { + imageBatch, + masks, + makeBatch(batches[TargetSizeIdx]), + makeBatch(batches[ImageIdIdx]), + makeBatch(batches[OriginalSizeIdx]), + batches[BboxesIdx], + batches[ClassesIdx]}; } int64_t getImageId(const std::string& fp) { - const std::string slash("/"); - const std::string period("."); - int start = fp.rfind(slash); - int end = fp.rfind(period); - std::string substring = fp.substr(start + 1, end - start); - return std::stol(substring); + const std::string slash("/"); + const std::string period("."); + int start = fp.rfind(slash); + int end = fp.rfind(period); + std::string substring = fp.substr(start + 1, end - start); + return std::stol(substring); } } // namespace @@ -92,116 +93,126 @@ CocoDataset::CocoDataset( int batch_size, int num_threads, int prefetch_size, - bool val) { - // Create vector of CocoDataSample which will be loaded into arrayfire arrays - std::vector data; - std::ifstream ifs(list_file); - if (!ifs) { - throw std::runtime_error("Could not open list file: " + list_file); - } - // We use tabs a deliminators between the filepath and each bbox - // We use spaced to separate the different fields of the bbox - const std::string delim = "\t"; - const std::string bbox_delim = " "; - std::string line; - while (std::getline(ifs, line)) { - int item = line.find(delim); - std::string filepath = line.substr(0, item); - std::vector bboxes; - std::vector classes; - item = line.find(delim, item); - while (item != std::string::npos) { - int pos = item; - int next; - for (int i = 0; i < 4; i++) { - next = 
line.find(bbox_delim, pos + 1); - assert(next != std::string::npos); - bboxes.emplace_back(std::stof(line.substr(pos, next - pos))); - pos = next; - } - next = line.find(bbox_delim, pos + 1); - classes.emplace_back(std::stod(line.substr(pos, next - pos))); - item = line.find(delim, pos); + bool val +) { + // Create vector of CocoDataSample which will be loaded into arrayfire arrays + std::vector data; + std::ifstream ifs(list_file); + if(!ifs) { + throw std::runtime_error("Could not open list file: " + list_file); } - data.emplace_back(CocoDataSample{filepath, bboxes, classes}); - } - assert(!data.empty()); - - // Now define how to load the data from CocoDataSampoles in arrayfire - std::shared_ptr ds = std::make_shared>( - data, [](const CocoDataSample& sample) { - Tensor image = loadJpeg(sample.filepath); - - std::vector targetSizes = {image.dim(1), image.dim(0)}; - Tensor targetSize = Tensor::fromVector(targetSizes); - Tensor imageId = + // We use tabs a deliminators between the filepath and each bbox + // We use spaced to separate the different fields of the bbox + const std::string delim = "\t"; + const std::string bbox_delim = " "; + std::string line; + while(std::getline(ifs, line)) { + int item = line.find(delim); + std::string filepath = line.substr(0, item); + std::vector bboxes; + std::vector classes; + item = line.find(delim, item); + while(item != std::string::npos) { + int pos = item; + int next; + for(int i = 0; i < 4; i++) { + next = line.find(bbox_delim, pos + 1); + assert(next != std::string::npos); + bboxes.emplace_back(std::stof(line.substr(pos, next - pos))); + pos = next; + } + next = line.find(bbox_delim, pos + 1); + classes.emplace_back(std::stod(line.substr(pos, next - pos))); + item = line.find(delim, pos); + } + data.emplace_back(CocoDataSample{filepath, bboxes, classes}); + } + assert(!data.empty()); + + // Now define how to load the data from CocoDataSampoles in arrayfire + std::shared_ptr ds = std::make_shared>( + data, + [](const 
CocoDataSample& sample) { + Tensor image = loadJpeg(sample.filepath); + + std::vector targetSizes = {image.dim(1), image.dim(0)}; + Tensor targetSize = Tensor::fromVector(targetSizes); + Tensor imageId = fl::full({getImageId(sample.filepath)}, 1, fl::dtype::s64); - const int num_elements = sample.bboxes.size(); - const int num_bboxes = num_elements / kElementsPerBbox; - Tensor bboxes, classes; - if (num_bboxes > 0) { - bboxes = - Tensor::fromVector({kElementsPerBbox, num_bboxes}, sample.bboxes); - classes = Tensor::fromVector({1, num_bboxes}, sample.classes); - } else { - // Arrayfire doesn't allow you to create 0 length dimension on - // anything other than the first dimension so we need this switch - bboxes = Tensor(); - classes = Tensor(); - } - // image, size, imageId, original_size - return std::vector{ + const int num_elements = sample.bboxes.size(); + const int num_bboxes = num_elements / kElementsPerBbox; + Tensor bboxes, classes; + if(num_bboxes > 0) { + bboxes = + Tensor::fromVector({kElementsPerBbox, num_bboxes}, sample.bboxes); + classes = Tensor::fromVector({1, num_bboxes}, sample.classes); + } else { + // Arrayfire doesn't allow you to create 0 length dimension on + // anything other than the first dimension so we need this switch + bboxes = Tensor(); + classes = Tensor(); + } + // image, size, imageId, original_size + return std::vector{ image, targetSize, imageId, targetSize, bboxes, classes}; - }); - - const int maxSize = 1333; - if (val) { - ds = - std::make_shared(ds, randomResize({800}, maxSize)); - } else { - std::vector scales = { - 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800}; - TransformAllFunction trainTransform = compose( - {randomHorizontalFlip(0.5), - randomSelect( - {randomResize(scales, maxSize), - compose( - {randomResize({400, 500, 600}, -1), - randomSizeCrop(384, 600), - randomResize(scales, 1333)})})}); - - ds = std::make_shared(ds, trainTransform); - } - - ds = std::make_shared(ds, Normalize()); - - // Skip shuffling if 
doing eval. - if (!val) { - shuffled_ = std::make_shared(ds); - ds = shuffled_; - } - auto permfn = [world_size, world_rank](int64_t idx) { - return (idx * world_size) + world_rank; - }; - - ds = std::make_shared(ds, permfn, ds->size() / world_size); - ds = std::make_shared(ds, num_threads, prefetch_size); - batched_ = std::make_shared>( - ds, batch_size, BatchDatasetPolicy::SKIP_LAST, cocoBatchFunc); + } + ); + + const int maxSize = 1333; + if(val) { + ds = + std::make_shared(ds, randomResize({800}, maxSize)); + } else { + std::vector scales = { + 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800}; + TransformAllFunction trainTransform = compose( + {randomHorizontalFlip(0.5), + randomSelect( + {randomResize(scales, maxSize), + compose( + {randomResize({400, 500, 600}, -1), + randomSizeCrop(384, 600), + randomResize(scales, 1333)} + )} + )} + ); + + ds = std::make_shared(ds, trainTransform); + } + + ds = std::make_shared(ds, Normalize()); + + // Skip shuffling if doing eval. + if(!val) { + shuffled_ = std::make_shared(ds); + ds = shuffled_; + } + auto permfn = [world_size, world_rank](int64_t idx) { + return (idx * world_size) + world_rank; + }; + + ds = std::make_shared(ds, permfn, ds->size() / world_size); + ds = std::make_shared(ds, num_threads, prefetch_size); + batched_ = std::make_shared>( + ds, + batch_size, + BatchDatasetPolicy::SKIP_LAST, + cocoBatchFunc + ); } void CocoDataset::resample() { - if (shuffled_) { - shuffled_->resample(); - } + if(shuffled_) { + shuffled_->resample(); + } } int64_t CocoDataset::size() const { - return batched_->size(); + return batched_->size(); } CocoData CocoDataset::get(const uint64_t idx) { - return batched_->get(idx); + return batched_->get(idx); } } // namespace fl diff --git a/flashlight/pkg/vision/dataset/Coco.h b/flashlight/pkg/vision/dataset/Coco.h index bca9a1f..c1be7e4 100644 --- a/flashlight/pkg/vision/dataset/Coco.h +++ b/flashlight/pkg/vision/dataset/Coco.h @@ -14,62 +14,64 @@ namespace fl { namespace pkg { 
-namespace vision { - -struct CocoDataSample { - std::string filepath; - std::vector bboxes; - std::vector classes; -}; - -struct CocoData { - Tensor images; - Tensor masks; - Tensor imageSizes; - Tensor imageIds; - Tensor originalImageSizes; - std::vector target_boxes; - std::vector target_labels; -}; - -class CocoDataset { - public: - CocoDataset( - const std::string& list_file, - int world_rank, - int world_size, - int batch_size, - int num_threads, - int prefetch_size, - bool val); - - std::shared_ptr getLabels(std::string list_file); - - std::shared_ptr getImages( - std::string list_file, - std::vector& transformfns); - - using iterator = detail::DatasetIterator; - - iterator begin() { - return iterator(this); - } - - iterator end() { - return iterator(); - } - - int64_t size() const; - - CocoData get(const uint64_t idx); - - void resample(); - - private: - std::shared_ptr> batched_; - std::shared_ptr shuffled_; -}; - -} // namespace vision + namespace vision { + + struct CocoDataSample { + std::string filepath; + std::vector bboxes; + std::vector classes; + }; + + struct CocoData { + Tensor images; + Tensor masks; + Tensor imageSizes; + Tensor imageIds; + Tensor originalImageSizes; + std::vector target_boxes; + std::vector target_labels; + }; + + class CocoDataset { + public: + CocoDataset( + const std::string& list_file, + int world_rank, + int world_size, + int batch_size, + int num_threads, + int prefetch_size, + bool val + ); + + std::shared_ptr getLabels(std::string list_file); + + std::shared_ptr getImages( + std::string list_file, + std::vector& transformfns + ); + + using iterator = detail::DatasetIterator; + + iterator begin() { + return iterator(this); + } + + iterator end() { + return iterator(); + } + + int64_t size() const; + + CocoData get(const uint64_t idx); + + void resample(); + + private: + std::shared_ptr> batched_; + std::shared_ptr shuffled_; + }; + + } // namespace vision } // namespace pkg } // namespace fl diff --git 
a/flashlight/pkg/vision/dataset/CocoTransforms.cpp b/flashlight/pkg/vision/dataset/CocoTransforms.cpp index 1739367..0c05dbb 100644 --- a/flashlight/pkg/vision/dataset/CocoTransforms.cpp +++ b/flashlight/pkg/vision/dataset/CocoTransforms.cpp @@ -18,241 +18,242 @@ namespace { int randomInt(int min, int max) { - return std::rand() % (max - min + 1) + min; + return std::rand() % (max - min + 1) + min; } } // namespace namespace fl::pkg::vision { -std::vector -crop(const std::vector& in, int x, int y, int tw, int th) { - const Tensor& image = in[ImageIdx]; - const Tensor croppedImage = fl::pkg::vision::crop(image, x, y, tw, th); - - const Tensor& boxes = in[BboxesIdx]; - - const std::vector translateVector = {x, y, x, y}; - const std::vector maxSizeVector = {tw, th}; - Tensor targetSize = Tensor::fromVector(maxSizeVector); - - const Tensor translateArray = Tensor::fromVector(translateVector); - const Tensor maxSizeArray = Tensor::fromVector(maxSizeVector); - - Tensor croppedBoxes = boxes; - Tensor labels = in[ClassesIdx]; - - if (!croppedBoxes.isEmpty()) { - croppedBoxes = croppedBoxes - translateArray; - croppedBoxes = fl::reshape(croppedBoxes, {2, 2, boxes.dim(1)}); - croppedBoxes = fl::minimum(croppedBoxes, maxSizeArray); - croppedBoxes = fl::maximum(croppedBoxes, 0.0); - Tensor keep = fl::all( - croppedBoxes(fl::span, fl::range(1, 2), fl::span) > - croppedBoxes(fl::span, fl::range(0, 1), fl::span), - {0}); - croppedBoxes = fl::reshape(croppedBoxes, {4, boxes.dim(1)}); - croppedBoxes = croppedBoxes(fl::span, keep); - labels = labels(fl::span, keep); - } - return { - croppedImage, - targetSize, - in[ImageIdIdx], - in[OriginalSizeIdx], - croppedBoxes, - labels}; +std::vector crop(const std::vector& in, int x, int y, int tw, int th) { + const Tensor& image = in[ImageIdx]; + const Tensor croppedImage = fl::pkg::vision::crop(image, x, y, tw, th); + + const Tensor& boxes = in[BboxesIdx]; + + const std::vector translateVector = {x, y, x, y}; + const std::vector 
maxSizeVector = {tw, th}; + Tensor targetSize = Tensor::fromVector(maxSizeVector); + + const Tensor translateArray = Tensor::fromVector(translateVector); + const Tensor maxSizeArray = Tensor::fromVector(maxSizeVector); + + Tensor croppedBoxes = boxes; + Tensor labels = in[ClassesIdx]; + + if(!croppedBoxes.isEmpty()) { + croppedBoxes = croppedBoxes - translateArray; + croppedBoxes = fl::reshape(croppedBoxes, {2, 2, boxes.dim(1)}); + croppedBoxes = fl::minimum(croppedBoxes, maxSizeArray); + croppedBoxes = fl::maximum(croppedBoxes, 0.0); + Tensor keep = fl::all( + croppedBoxes(fl::span, fl::range(1, 2), fl::span) + > croppedBoxes(fl::span, fl::range(0, 1), fl::span), + {0} + ); + croppedBoxes = fl::reshape(croppedBoxes, {4, boxes.dim(1)}); + croppedBoxes = croppedBoxes(fl::span, keep); + labels = labels(fl::span, keep); + } + return { + croppedImage, + targetSize, + in[ImageIdIdx], + in[OriginalSizeIdx], + croppedBoxes, + labels}; }; std::vector hflip(const std::vector& in) { - Tensor image = in[ImageIdx]; - const int w = image.dim(0); - image = image(fl::range(w - 1, -1, -1)); - - Tensor bboxes = in[BboxesIdx]; - if (!bboxes.isEmpty()) { - Tensor bboxes_flip = Tensor(bboxes.shape()); - bboxes_flip(0) = (bboxes(2) * -1) + w; - bboxes_flip(1) = bboxes(1); - bboxes_flip(2) = (bboxes(0) * -1) + w; - bboxes_flip(3) = bboxes(3); - bboxes = bboxes_flip; - } - return { - image, - in[TargetSizeIdx], - in[ImageIdIdx], - in[OriginalSizeIdx], - bboxes, - in[ClassesIdx]}; -} - -std::vector normalize(const std::vector& in) { - auto boxes = in[BboxesIdx]; - - if (!boxes.isEmpty()) { - auto image = in[ImageIdx]; - auto w = float(image.dim(0)); - auto h = float(image.dim(1)); - - boxes = xyxy2cxcywh(boxes); - const std::vector ratioVector = {w, h, w, h}; - Tensor ratioArray = Tensor::fromVector(ratioVector); - boxes = boxes / ratioArray; - } - return { - in[ImageIdx], - in[TargetSizeIdx], - in[ImageIdIdx], - in[OriginalSizeIdx], - boxes, - in[ClassesIdx]}; -} - -std::vector 
-randomResize(std::vector inputs, int size, int maxsize) { - auto getSize = [](const Tensor& in, int size, int maxSize = 0) { - int w = in.dim(0); - int h = in.dim(1); - // long size; - if (maxSize > 0) { - float minOriginalSize = std::min(w, h); - float maxOriginalSize = std::max(w, h); - if (maxOriginalSize / minOriginalSize * size > maxSize) { - size = round(maxSize * minOriginalSize / maxOriginalSize); - } - } - - if ((w <= h && w == size) || (h <= w && h == size)) { - return std::make_pair(w, h); - } - int ow, oh; - if (w < h) { - ow = size; - oh = size * h / w; - } else { - oh = size; - ow = size * w / h; + Tensor image = in[ImageIdx]; + const int w = image.dim(0); + image = image(fl::range(w - 1, -1, -1)); + + Tensor bboxes = in[BboxesIdx]; + if(!bboxes.isEmpty()) { + Tensor bboxes_flip = Tensor(bboxes.shape()); + bboxes_flip(0) = (bboxes(2) * -1) + w; + bboxes_flip(1) = bboxes(1); + bboxes_flip(2) = (bboxes(0) * -1) + w; + bboxes_flip(3) = bboxes(3); + bboxes = bboxes_flip; } - return std::make_pair(ow, oh); - }; - - Tensor image = inputs[ImageIdx]; - auto output_size = getSize(image, size, maxsize); - const Shape originalDims = image.shape(); - Tensor resizedImage; - resizedImage = fl::resize( - image, - {output_size.first, output_size.second}, - InterpolationMode::Bilinear); - const Shape resizedDims = resizedImage.shape(); - - Tensor boxes = inputs[BboxesIdx]; - if (!boxes.isEmpty()) { - const float ratioWidth = float(resizedDims[0]) / float(originalDims[0]); - const float ratioHeight = float(resizedDims[1]) / float(originalDims[1]); - - const std::vector resizeVector = { - ratioWidth, ratioHeight, ratioWidth, ratioHeight}; - Tensor resizedArray = Tensor::fromVector(resizeVector); - boxes = boxes * resizedArray; - } - - std::vector imageSizeArray = {resizedImage.dim(1), resizedImage.dim(0)}; - Tensor sizeArray = Tensor::fromVector(imageSizeArray); - return { - resizedImage, - sizeArray, - inputs[ImageIdIdx], - inputs[OriginalSizeIdx], - boxes, - 
inputs[ClassesIdx]}; + return { + image, + in[TargetSizeIdx], + in[ImageIdIdx], + in[OriginalSizeIdx], + bboxes, + in[ClassesIdx]}; } -TransformAllFunction Normalize( - std::vector meanVector, - std::vector stdVector) { - const Tensor mean = Tensor::fromVector({1, 1, 3}, meanVector); - const Tensor std = Tensor::fromVector({1, 1, 3}, stdVector); - return [mean, std](const std::vector& in) { - // Normalize Boxes +std::vector normalize(const std::vector& in) { auto boxes = in[BboxesIdx]; - if (!boxes.isEmpty()) { - auto image = in[ImageIdx]; - auto w = float(image.dim(0)); - auto h = float(image.dim(1)); + if(!boxes.isEmpty()) { + auto image = in[ImageIdx]; + auto w = float(image.dim(0)); + auto h = float(image.dim(1)); - boxes = xyxy2cxcywh(boxes); - const std::vector ratioVector = {w, h, w, h}; - Tensor ratioArray = Tensor::fromVector(ratioVector); - boxes = boxes / ratioArray; + boxes = xyxy2cxcywh(boxes); + const std::vector ratioVector = {w, h, w, h}; + Tensor ratioArray = Tensor::fromVector(ratioVector); + boxes = boxes / ratioArray; } - // Normalize Image - Tensor image = in[ImageIdx].astype(fl::dtype::f32) / 255.f; - image = image - mean; - image = image / std; - std::vector outputs = { - image, + return { + in[ImageIdx], in[TargetSizeIdx], in[ImageIdIdx], in[OriginalSizeIdx], boxes, in[ClassesIdx]}; - return outputs; - }; +} + +std::vector randomResize(std::vector inputs, int size, int maxsize) { + auto getSize = [](const Tensor& in, int size, int maxSize = 0) { + int w = in.dim(0); + int h = in.dim(1); + // long size; + if(maxSize > 0) { + float minOriginalSize = std::min(w, h); + float maxOriginalSize = std::max(w, h); + if(maxOriginalSize / minOriginalSize * size > maxSize) { + size = round(maxSize * minOriginalSize / maxOriginalSize); + } + } + + if((w <= h && w == size) || (h <= w && h == size)) { + return std::make_pair(w, h); + } + int ow, oh; + if(w < h) { + ow = size; + oh = size * h / w; + } else { + oh = size; + ow = size * w / h; + } + return 
std::make_pair(ow, oh); + }; + + Tensor image = inputs[ImageIdx]; + auto output_size = getSize(image, size, maxsize); + const Shape originalDims = image.shape(); + Tensor resizedImage; + resizedImage = fl::resize( + image, + {output_size.first, output_size.second}, + InterpolationMode::Bilinear + ); + const Shape resizedDims = resizedImage.shape(); + + Tensor boxes = inputs[BboxesIdx]; + if(!boxes.isEmpty()) { + const float ratioWidth = float(resizedDims[0]) / float(originalDims[0]); + const float ratioHeight = float(resizedDims[1]) / float(originalDims[1]); + + const std::vector resizeVector = { + ratioWidth, ratioHeight, ratioWidth, ratioHeight}; + Tensor resizedArray = Tensor::fromVector(resizeVector); + boxes = boxes * resizedArray; + } + + std::vector imageSizeArray = {resizedImage.dim(1), resizedImage.dim(0)}; + Tensor sizeArray = Tensor::fromVector(imageSizeArray); + return { + resizedImage, + sizeArray, + inputs[ImageIdIdx], + inputs[OriginalSizeIdx], + boxes, + inputs[ClassesIdx]}; +} + +TransformAllFunction Normalize( + std::vector meanVector, + std::vector stdVector +) { + const Tensor mean = Tensor::fromVector({1, 1, 3}, meanVector); + const Tensor std = Tensor::fromVector({1, 1, 3}, stdVector); + return [mean, std](const std::vector& in) { + // Normalize Boxes + auto boxes = in[BboxesIdx]; + + if(!boxes.isEmpty()) { + auto image = in[ImageIdx]; + auto w = float(image.dim(0)); + auto h = float(image.dim(1)); + + boxes = xyxy2cxcywh(boxes); + const std::vector ratioVector = {w, h, w, h}; + Tensor ratioArray = Tensor::fromVector(ratioVector); + boxes = boxes / ratioArray; + } + // Normalize Image + Tensor image = in[ImageIdx].astype(fl::dtype::f32) / 255.f; + image = image - mean; + image = image / std; + std::vector outputs = { + image, + in[TargetSizeIdx], + in[ImageIdIdx], + in[OriginalSizeIdx], + boxes, + in[ClassesIdx]}; + return outputs; + }; } TransformAllFunction randomSelect(std::vector fns) { - return [fns](const std::vector& in) { - 
TransformAllFunction randomFunc = fns[std::rand() % fns.size()]; - return randomFunc(in); - }; + return [fns](const std::vector& in) { + TransformAllFunction randomFunc = fns[std::rand() % fns.size()]; + return randomFunc(in); + }; }; TransformAllFunction randomSizeCrop(int minSize, int maxSize) { - return [minSize, maxSize](const std::vector& in) { - const Tensor& image = in[0]; - const int w = image.dim(0); - const int h = image.dim(1); - const int tw = randomInt(minSize, std::min(w, maxSize)); - const int th = randomInt(minSize, std::min(h, maxSize)); - const int x = std::rand() % (w - tw + 1); - const int y = std::rand() % (h - th + 1); - return crop(in, x, y, tw, th); - }; + return [minSize, maxSize](const std::vector& in) { + const Tensor& image = in[0]; + const int w = image.dim(0); + const int h = image.dim(1); + const int tw = randomInt(minSize, std::min(w, maxSize)); + const int th = randomInt(minSize, std::min(h, maxSize)); + const int x = std::rand() % (w - tw + 1); + const int y = std::rand() % (h - th + 1); + return crop(in, x, y, tw, th); + }; }; TransformAllFunction randomResize(std::vector sizes, int maxsize) { - assert(!sizes.empty()); - auto resizeCoco = [sizes, maxsize](std::vector in) { - assert(in.size() == 6); assert(!sizes.empty()); - int randomIndex = rand() % sizes.size(); - int size = sizes[randomIndex]; - const Tensor originalImage = in[0]; - return randomResize(in, size, maxsize); - }; - return resizeCoco; + auto resizeCoco = [sizes, maxsize](std::vector in) { + assert(in.size() == 6); + assert(!sizes.empty()); + int randomIndex = rand() % sizes.size(); + int size = sizes[randomIndex]; + const Tensor originalImage = in[0]; + return randomResize(in, size, maxsize); + }; + return resizeCoco; } TransformAllFunction randomHorizontalFlip(float p) { - return [p](const std::vector& in) { - if (static_cast(std::rand()) / static_cast(RAND_MAX) > p) { - return hflip(in); - } else { - return in; - } - }; + return [p](const std::vector& in) { + 
if(static_cast(std::rand()) / static_cast(RAND_MAX) > p) { + return hflip(in); + } else { + return in; + } + }; } TransformAllFunction compose(std::vector fns) { - return [fns](const std::vector& in) { - std::vector out = in; - for (const auto& fn : fns) { - out = fn(out); - } - return out; - }; + return [fns](const std::vector& in) { + std::vector out = in; + for(const auto& fn : fns) { + out = fn(out); + } + return out; + }; } } // namespace fl diff --git a/flashlight/pkg/vision/dataset/CocoTransforms.h b/flashlight/pkg/vision/dataset/CocoTransforms.h index 15aca2e..a112d26 100644 --- a/flashlight/pkg/vision/dataset/CocoTransforms.h +++ b/flashlight/pkg/vision/dataset/CocoTransforms.h @@ -11,16 +11,16 @@ namespace fl { namespace pkg { -namespace vision { + namespace vision { -enum DatasetIndices { - ImageIdx = 0, - TargetSizeIdx = 1, - ImageIdIdx = 2, - OriginalSizeIdx = 3, - BboxesIdx = 4, - ClassesIdx = 5 -}; + enum DatasetIndices { + ImageIdx = 0, + TargetSizeIdx = 1, + ImageIdIdx = 2, + OriginalSizeIdx = 3, + BboxesIdx = 4, + ClassesIdx = 5 + }; /* * Crop the image and translate bounding boxes accordingly @@ -30,14 +30,13 @@ enum DatasetIndices { * @param ty is the target height * This function will remove bounding boxes which do not exist within the crop */ -std::vector -crop(const std::vector& in, int x, int y, int tw, int th); + std::vector crop(const std::vector& in, int x, int y, int tw, int th); /* * Flip the image horizontally and adjust the bounding boxes acordingly * @param in vector of input arrays */ -std::vector hflip(const std::vector& in); + std::vector hflip(const std::vector& in); /* * "normalize" the bounding boxes @@ -45,7 +44,7 @@ std::vector hflip(const std::vector& in); * adjust bounding boxes from bottom left and top right coordinates to center * x,y and width and height and then divide by total image width and height */ -std::vector normalize(const std::vector& in); + std::vector normalize(const std::vector& in); /* * Randomly resize 
image and bounding boxes from @param inputs, where shortest @@ -53,49 +52,48 @@ std::vector normalize(const std::vector& in); * side is shorter than @param maxsize. * Adjust bboxes accordingly. */ -std::vector -randomResize(std::vector inputs, int size, int maxsize); + std::vector randomResize(std::vector inputs, int size, int maxsize); /* * Returns a function that "Normalizes" bounding boxes so that they represent * center coordintates and height and width ratios of the entire image. * Also normalize images by @param meanVector and @param stdVector */ -TransformAllFunction Normalize( - std::vector meanVector = {0.485, 0.456, 0.406}, - std::vector stdVector = {0.229, 0.224, 0.225}); + TransformAllFunction Normalize( + std::vector meanVector = { 0.485, 0.456, 0.406 }, + std::vector stdVector = { 0.229, 0.224, 0.225 }); /* * Returns a `TransformAllFunction` which randomly selects from @param fns * and call it on imput data */ -TransformAllFunction randomSelect(std::vector fns); + TransformAllFunction randomSelect(std::vector fns); /* * Returns a `TransformAllFunction` which randomly resizes image and bounding * boxes between @param minSize and @param maxSize */ -TransformAllFunction randomSizeCrop(int minSize, int maxSize); + TransformAllFunction randomSizeCrop(int minSize, int maxSize); /* * Returns a `TransformAllFunction` which randomly resizes images from a choice * in @param sizes, and ensures longest side of image is less than @param * maxsize */ -TransformAllFunction randomResize(std::vector sizes, int maxsize); + TransformAllFunction randomResize(std::vector sizes, int maxsize); /* * Returns a `TransformAllFunction` which random flips image and bounding * boxes with a probility of @param p */ -TransformAllFunction randomHorizontalFlip(float p); + TransformAllFunction randomHorizontalFlip(float p); /* * Returns a `TransformAllFunction` which calls each @param fns on * the input data */ -TransformAllFunction compose(std::vector fns); + TransformAllFunction 
compose(std::vector fns); -} // namespace vision + } // namespace vision } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/vision/dataset/DistributedDataset.cpp b/flashlight/pkg/vision/dataset/DistributedDataset.cpp index 71b560a..2980ec0 100644 --- a/flashlight/pkg/vision/dataset/DistributedDataset.cpp +++ b/flashlight/pkg/vision/dataset/DistributedDataset.cpp @@ -17,34 +17,35 @@ DistributedDataset::DistributedDataset( int64_t nRepeated, int64_t numThreads, int64_t prefetchSize, - BatchDatasetPolicy batchPolicy) { - shuffle_ = std::make_shared(base); - auto permfn = [worldSize, worldRank, nRepeated](int64_t idx) { - return (idx * worldSize + worldRank) / nRepeated; - }; - - int partitionSize = shuffle_->size() / worldSize; - int leftOver = shuffle_->size() % worldSize; - if (worldRank < leftOver) { - partitionSize++; - } - ds_ = std::make_shared(shuffle_, permfn, partitionSize); - ds_ = std::make_shared(ds_, numThreads, prefetchSize); - ds_ = std::make_shared(ds_, batchSize, batchPolicy); + BatchDatasetPolicy batchPolicy +) { + shuffle_ = std::make_shared(base); + auto permfn = [worldSize, worldRank, nRepeated](int64_t idx) { + return (idx * worldSize + worldRank) / nRepeated; + }; + + int partitionSize = shuffle_->size() / worldSize; + int leftOver = shuffle_->size() % worldSize; + if(worldRank < leftOver) { + partitionSize++; + } + ds_ = std::make_shared(shuffle_, permfn, partitionSize); + ds_ = std::make_shared(ds_, numThreads, prefetchSize); + ds_ = std::make_shared(ds_, batchSize, batchPolicy); } std::vector DistributedDataset::get(const int64_t idx) const { - checkIndexBounds(idx); - return ds_->get(idx); + checkIndexBounds(idx); + return ds_->get(idx); } void DistributedDataset::resample(const int seed) { - shuffle_->setSeed(seed); - shuffle_->resample(); + shuffle_->setSeed(seed); + shuffle_->resample(); } int64_t DistributedDataset::size() const { - return ds_->size(); + return ds_->size(); } } // namespace fl diff --git 
a/flashlight/pkg/vision/dataset/DistributedDataset.h b/flashlight/pkg/vision/dataset/DistributedDataset.h index f352467..78135a8 100644 --- a/flashlight/pkg/vision/dataset/DistributedDataset.h +++ b/flashlight/pkg/vision/dataset/DistributedDataset.h @@ -11,31 +11,32 @@ namespace fl { namespace pkg { -namespace vision { + namespace vision { -class DistributedDataset : public Dataset { - public: - DistributedDataset( - std::shared_ptr base, - int64_t worldRank, - int64_t worldSize, - int64_t batchSize, - int64_t nRepeated, - int64_t numThreads, - int64_t prefetchSize, - BatchDatasetPolicy batchpolicy = fl::BatchDatasetPolicy::INCLUDE_LAST); + class DistributedDataset : public Dataset { + public: + DistributedDataset( + std::shared_ptr base, + int64_t worldRank, + int64_t worldSize, + int64_t batchSize, + int64_t nRepeated, + int64_t numThreads, + int64_t prefetchSize, + BatchDatasetPolicy batchpolicy = fl::BatchDatasetPolicy::INCLUDE_LAST + ); - std::vector get(const int64_t idx) const override; + std::vector get(const int64_t idx) const override; - void resample(const int seed = 0); + void resample(const int seed = 0); - int64_t size() const override; + int64_t size() const override; - private: - std::shared_ptr ds_; - std::shared_ptr shuffle_; -}; + private: + std::shared_ptr ds_; + std::shared_ptr shuffle_; + }; -} // namespace vision + } // namespace vision } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/vision/dataset/Imagenet.cpp b/flashlight/pkg/vision/dataset/Imagenet.cpp index bbbe24b..a0c25f1 100644 --- a/flashlight/pkg/vision/dataset/Imagenet.cpp +++ b/flashlight/pkg/vision/dataset/Imagenet.cpp @@ -22,88 +22,102 @@ namespace { // TODO: test against std::filesystem::recursive_directory_iterator std::vector fileGlob(const std::string& pattern) { - glob_t result; - glob(pattern.c_str(), GLOB_TILDE, nullptr, &result); - std::vector ret; - for (unsigned int i = 0; i < result.gl_pathc; ++i) { - ret.emplace_back(result.gl_pathv[i]); - } - 
globfree(&result); - return ret; + glob_t result; + glob(pattern.c_str(), GLOB_TILDE, nullptr, &result); + std::vector ret; + for(unsigned int i = 0; i < result.gl_pathc; ++i) { + ret.emplace_back(result.gl_pathv[i]); + } + globfree(&result); + return ret; } } // namespace namespace fl::pkg::vision { std::unordered_map getImagenetLabels( - const fs::path& labelFile) { - std::unordered_map labels; - std::vector lines; - std::ifstream inFile(labelFile); - if (!inFile) { - throw std::invalid_argument( - "fl::pkg::vision::getImagenetLabels given invalid labelFile path"); - } - for (std::string str; std::getline(inFile, str);) { - lines.emplace_back(str); - } - - if (lines.empty()) { - throw std::runtime_error( - "In function imagenetLabels: No lines in file:" + labelFile.string()); - } - for (int i = 0; i < lines.size(); i++) { - std::string line = lines[i]; - auto it = line.find(','); - if (it != std::string::npos) { - std::string label = line.substr(0, it); - labels[label] = i; - } else { - throw std::runtime_error( - "In function imagenetLabels: Invalid label format for line: " + line); + const fs::path& labelFile +) { + std::unordered_map labels; + std::vector lines; + std::ifstream inFile(labelFile); + if(!inFile) { + throw std::invalid_argument( + "fl::pkg::vision::getImagenetLabels given invalid labelFile path" + ); + } + for(std::string str; std::getline(inFile, str);) { + lines.emplace_back(str); + } + + if(lines.empty()) { + throw std::runtime_error( + "In function imagenetLabels: No lines in file:" + labelFile.string() + ); + } + for(int i = 0; i < lines.size(); i++) { + std::string line = lines[i]; + auto it = line.find(','); + if(it != std::string::npos) { + std::string label = line.substr(0, it); + labels[label] = i; + } else { + throw std::runtime_error( + "In function imagenetLabels: Invalid label format for line: " + line + ); + } } - } - return labels; + return labels; } std::shared_ptr imagenetDataset( const fs::path& imgDir, const 
std::unordered_map& labelMap, - std::vector transformfns) { - std::vector filepaths = fileGlob(imgDir.string() + "/**/*.JPEG"); - - if (filepaths.empty()) { - throw std::runtime_error( - "No images were found in imagenet directory: " + imgDir.string()); - } - - // Create image dataset - std::shared_ptr imageDataset = - fl::pkg::vision::jpegLoader(filepaths); - imageDataset = std::make_shared(imageDataset, transformfns); - - // Create labels from filepaths - auto getLabelIdxs = [&labelMap](const std::string& s) -> uint64_t { - std::string parentPath = s.substr(0, s.rfind('/')); - std::string label = parentPath.substr(parentPath.rfind('/') + 1); - if (labelMap.find(label) != labelMap.end()) { - return labelMap.at(label); - } else { - throw std::runtime_error("Label: " + label + " not found in label map"); + std::vector transformfns +) { + std::vector filepaths = fileGlob(imgDir.string() + "/**/*.JPEG"); + + if(filepaths.empty()) { + throw std::runtime_error( + "No images were found in imagenet directory: " + imgDir.string() + ); } - return labelMap.at(label); - }; - - std::vector labels(filepaths.size()); - std::transform( - filepaths.begin(), filepaths.end(), labels.begin(), getLabelIdxs); - - auto labelDataset = std::make_shared(labels, [](uint64_t x) { - std::vector result{fl::fromScalar(x, fl::dtype::u64)}; - return result; - }); - return std::make_shared( - MergeDataset({imageDataset, labelDataset})); + + // Create image dataset + std::shared_ptr imageDataset = + fl::pkg::vision::jpegLoader(filepaths); + imageDataset = std::make_shared(imageDataset, transformfns); + + // Create labels from filepaths + auto getLabelIdxs = [&labelMap](const std::string& s) -> uint64_t { + std::string parentPath = s.substr(0, s.rfind('/')); + std::string label = parentPath.substr(parentPath.rfind('/') + 1); + if(labelMap.find(label) != labelMap.end()) { + return labelMap.at(label); + } else { + throw std::runtime_error("Label: " + label + " not found in label map"); + } + return 
labelMap.at(label); + }; + + std::vector labels(filepaths.size()); + std::transform( + filepaths.begin(), + filepaths.end(), + labels.begin(), + getLabelIdxs + ); + + auto labelDataset = std::make_shared( + labels, + [](uint64_t x) { + std::vector result{fl::fromScalar(x, fl::dtype::u64)}; + return result; + } + ); + return std::make_shared( + MergeDataset({imageDataset, labelDataset}) + ); } } // namespace fl diff --git a/flashlight/pkg/vision/dataset/Imagenet.h b/flashlight/pkg/vision/dataset/Imagenet.h index e5aac55..9b92268 100644 --- a/flashlight/pkg/vision/dataset/Imagenet.h +++ b/flashlight/pkg/vision/dataset/Imagenet.h @@ -39,13 +39,14 @@ */ namespace fl { namespace pkg { -namespace vision { + namespace vision { /* Given the path to the imagenet labels file labels.txt, * create a map with a unique id for each label that can be used for training */ -std::unordered_map getImagenetLabels( - const fs::path& labelFile); + std::unordered_map getImagenetLabels( + const fs::path& labelFile + ); /* * Creates an `ImageDataset` by globbing for images in @@ -65,14 +66,15 @@ std::unordered_map getImagenetLabels( * std::cout << sample[1].shape() << std::endl; // {1, 1, 1, 1} * */ -std::shared_ptr imagenetDataset( - const fs::path& imgDir, - const std::unordered_map& labelMap, - std::vector transformfns); + std::shared_ptr imagenetDataset( + const fs::path& imgDir, + const std::unordered_map& labelMap, + std::vector transformfns + ); -constexpr uint64_t kImagenetInputIdx = 0; -constexpr uint64_t kImagenetTargetIdx = 1; + constexpr uint64_t kImagenetInputIdx = 0; + constexpr uint64_t kImagenetTargetIdx = 1; -} // namespace vision + } // namespace vision } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/vision/dataset/Jpeg.cpp b/flashlight/pkg/vision/dataset/Jpeg.cpp index d73c391..f2b74a6 100644 --- a/flashlight/pkg/vision/dataset/Jpeg.cpp +++ b/flashlight/pkg/vision/dataset/Jpeg.cpp @@ -22,30 +22,35 @@ namespace fl::pkg::vision { * number of channels to 
create an array with 3 channels */ Tensor loadJpeg(const std::string& fp, int desiredNumberOfChannels /* = 3 */) { - int w, h, c; - // STB image will automatically return desiredNumberOfChannels. - // NB: c will be the original number of channels - unsigned char* img = - stbi_load(fp.c_str(), &w, &h, &c, desiredNumberOfChannels); - if (img) { - // Load array first as C X W X H, since stb has channel along first - // dimension - Tensor result = Tensor::fromBuffer( - {desiredNumberOfChannels, w, h}, img, MemoryLocation::Host); - stbi_image_free(img); - // Then reorder to W X H X C - return fl::transpose(result, {1, 2, 0}); - } else { - throw std::invalid_argument("Could not load from filepath" + fp); - } + int w, h, c; + // STB image will automatically return desiredNumberOfChannels. + // NB: c will be the original number of channels + unsigned char* img = + stbi_load(fp.c_str(), &w, &h, &c, desiredNumberOfChannels); + if(img) { + // Load array first as C X W X H, since stb has channel along first + // dimension + Tensor result = Tensor::fromBuffer( + {desiredNumberOfChannels, w, h}, + img, + MemoryLocation::Host + ); + stbi_image_free(img); + // Then reorder to W X H X C + return fl::transpose(result, {1, 2, 0}); + } else { + throw std::invalid_argument("Could not load from filepath" + fp); + } } std::shared_ptr jpegLoader(std::vector fps) { - return std::make_shared>( - fps, [](const std::string& fp) { - std::vector result = {loadJpeg(fp)}; - return result; - }); + return std::make_shared>( + fps, + [](const std::string& fp) { + std::vector result = {loadJpeg(fp)}; + return result; + } + ); } } // namespace fl diff --git a/flashlight/pkg/vision/dataset/Jpeg.h b/flashlight/pkg/vision/dataset/Jpeg.h index ef008a1..85a5817 100644 --- a/flashlight/pkg/vision/dataset/Jpeg.h +++ b/flashlight/pkg/vision/dataset/Jpeg.h @@ -13,12 +13,12 @@ namespace fl { namespace pkg { -namespace vision { + namespace vision { -Tensor loadJpeg(const std::string& fp, int 
desiredNumberOfChannels = 3); + Tensor loadJpeg(const std::string& fp, int desiredNumberOfChannels = 3); -std::shared_ptr jpegLoader(std::vector fps); + std::shared_ptr jpegLoader(std::vector fps); -} // namespace vision + } // namespace vision } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/vision/dataset/LoaderDataset.h b/flashlight/pkg/vision/dataset/LoaderDataset.h index 8d49fb3..43df63d 100644 --- a/flashlight/pkg/vision/dataset/LoaderDataset.h +++ b/flashlight/pkg/vision/dataset/LoaderDataset.h @@ -11,33 +11,33 @@ namespace fl { namespace pkg { -namespace vision { + namespace vision { /* * Small generic utility class for loading data from a vector of type T into an * vector of arrayfire arrays */ -template -class LoaderDataset : public fl::Dataset { - public: - using LoadFunc = std::function(const T&)>; + template + class LoaderDataset : public fl::Dataset { + public: + using LoadFunc = std::function(const T&)>; - LoaderDataset(const std::vector& list, LoadFunc loadfn) - : list_(list), loadfn_(loadfn) {} + LoaderDataset(const std::vector& list, LoadFunc loadfn) : list_(list), + loadfn_(loadfn) {} - std::vector get(const int64_t idx) const override { - return loadfn_(list_[idx]); - } + std::vector get(const int64_t idx) const override { + return loadfn_(list_[idx]); + } - int64_t size() const override { - return list_.size(); - } + int64_t size() const override { + return list_.size(); + } - private: - std::vector list_; - LoadFunc loadfn_; -}; + private: + std::vector list_; + LoadFunc loadfn_; + }; -} // namespace vision + } // namespace vision } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/vision/dataset/TransformAllDataset.h b/flashlight/pkg/vision/dataset/TransformAllDataset.h index 6a0484b..c651e9b 100644 --- a/flashlight/pkg/vision/dataset/TransformAllDataset.h +++ b/flashlight/pkg/vision/dataset/TransformAllDataset.h @@ -11,10 +11,10 @@ namespace fl { namespace pkg { -namespace vision { + namespace vision { -using 
TransformAllFunction = - std::function(const std::vector&)>; + using TransformAllFunction = + std::function(const std::vector&)>; /* * A view into a dataset where all arrays are transformed using the same @@ -24,26 +24,27 @@ using TransformAllFunction = * independent of each other. For example, a random crop must crop the iamge * but then also adjust the bounding boxes accordinly */ -class TransformAllDataset : public Dataset { - public: - TransformAllDataset( - std::shared_ptr dataset, - TransformAllFunction fn) - : dataset_(dataset), fn_(fn) {} - - std::vector get(const int64_t idx) const override { - return fn_(dataset_->get(idx)); - } - - int64_t size() const override { - return dataset_->size(); - } - - private: - std::shared_ptr dataset_; - const TransformAllFunction fn_; -}; - -} // namespace vision + class TransformAllDataset : public Dataset { + public: + TransformAllDataset( + std::shared_ptr dataset, + TransformAllFunction fn + ) : dataset_(dataset), + fn_(fn) {} + + std::vector get(const int64_t idx) const override { + return fn_(dataset_->get(idx)); + } + + int64_t size() const override { + return dataset_->size(); + } + + private: + std::shared_ptr dataset_; + const TransformAllFunction fn_; + }; + + } // namespace vision } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/vision/dataset/Transforms.cpp b/flashlight/pkg/vision/dataset/Transforms.cpp index b65dc10..a40b89f 100644 --- a/flashlight/pkg/vision/dataset/Transforms.cpp +++ b/flashlight/pkg/vision/dataset/Transforms.cpp @@ -21,20 +21,20 @@ namespace { float randomFloat(float a, float b) { - float r = static_cast(std::rand()) / static_cast(RAND_MAX); - return a + (b - a) * r; + float r = static_cast(std::rand()) / static_cast(RAND_MAX); + return a + (b - a) * r; } -template +template T randomNegate(T a) { - float r = randomFloat(0, 1); - return r > 0.5 ? a : -a; + float r = randomFloat(0, 1); + return r > 0.5 ? 
a : -a; } -template +template T randomPerturbNegate(T base, T minNoise, T maxNoise) { - float noise = randomFloat(minNoise, maxNoise); - return randomNegate(base + noise); + float noise = randomFloat(minNoise, maxNoise); + return randomNegate(base + noise); } } // namespace @@ -42,160 +42,164 @@ T randomPerturbNegate(T base, T minNoise, T maxNoise) { namespace fl::pkg::vision { Tensor resizeSmallest(const Tensor& in, const int resize) { - const int w = in.dim(0); - const int h = in.dim(1); - int th, tw; - if (h > w) { - th = (resize * h) / w; - tw = resize; - } else { - th = resize; - tw = (resize * w) / h; - } - return fl::resize(in, {tw, th}, InterpolationMode::Bilinear); + const int w = in.dim(0); + const int h = in.dim(1); + int th, tw; + if(h > w) { + th = (resize * h) / w; + tw = resize; + } else { + th = resize; + tw = (resize * w) / h; + } + return fl::resize(in, {tw, th}, InterpolationMode::Bilinear); } Tensor resize(const Tensor& in, const int resize) { - return fl::resize(in, {resize, resize}, InterpolationMode::Bilinear); + return fl::resize(in, {resize, resize}, InterpolationMode::Bilinear); } -Tensor -crop(const Tensor& in, const int x, const int y, const int w, const int h) { - return in(fl::range(x, x + w), fl::range(y, y + h)); +Tensor crop(const Tensor& in, const int x, const int y, const int w, const int h) { + return in(fl::range(x, x + w), fl::range(y, y + h)); } Tensor centerCrop(const Tensor& in, const int size) { - const int w = in.dim(0); - const int h = in.dim(1); - const int cropLeft = std::round((static_cast(w) - size) / 2.); - const int cropTop = std::round((static_cast(h) - size) / 2.); - return crop(in, cropLeft, cropTop, size, size); + const int w = in.dim(0); + const int h = in.dim(1); + const int cropLeft = std::round((static_cast(w) - size) / 2.); + const int cropTop = std::round((static_cast(h) - size) / 2.); + return crop(in, cropLeft, cropTop, size, size); } Tensor rotate(const Tensor& input, const float theta, const Tensor& 
fillImg) { - return fl::rotate(input, theta, fillImg); + return fl::rotate(input, theta, fillImg); } Tensor skewX(const Tensor& input, const float theta, const Tensor& fillImg) { - return fl::shear(input, {theta, 0}, {}, fillImg); + return fl::shear(input, {theta, 0}, {}, fillImg); } Tensor skewY(const Tensor& input, const float theta, const Tensor& fillImg) { - return fl::shear(input, {0, theta}, {}, fillImg); + return fl::shear(input, {0, theta}, {}, fillImg); } Tensor translateX(const Tensor& input, const int shift, const Tensor& fillImg) { - return fl::translate(input, {shift, 0}, {}, fillImg); + return fl::translate(input, {shift, 0}, {}, fillImg); } Tensor translateY(const Tensor& input, const int shift, const Tensor& fillImg) { - return fl::translate(input, {0, shift}, {}, fillImg); + return fl::translate(input, {0, shift}, {}, fillImg); } Tensor colorEnhance(const Tensor& input, const float enhance) { - auto c = input.dim(2); - auto meanPic = fl::mean(input, {2}, /* keepDims = */ true); - meanPic = fl::tile(meanPic, {1, 1, c}); - return meanPic + enhance * (input - meanPic); + auto c = input.dim(2); + auto meanPic = fl::mean(input, {2}, /* keepDims = */ true); + meanPic = fl::tile(meanPic, {1, 1, c}); + return meanPic + enhance * (input - meanPic); } Tensor autoContrast(const Tensor& input) { - auto minPic = fl::amin(input); - auto maxPic = fl::amax(input); - if (fl::all(minPic == maxPic).asScalar()) { - return input; - } - - auto scale = fl::tile(255. / (maxPic - minPic), input.shape()); - minPic = fl::tile(minPic, input.shape()); - return scale * (input - minPic); + auto minPic = fl::amin(input); + auto maxPic = fl::amax(input); + if(fl::all(minPic == maxPic).asScalar()) { + return input; + } + + auto scale = fl::tile(255. 
/ (maxPic - minPic), input.shape()); + minPic = fl::tile(minPic, input.shape()); + return scale * (input - minPic); } Tensor contrastEnhance(const Tensor& input, const float enhance) { - auto meanPic = fl::mean(input); - meanPic = fl::tile(meanPic, input.shape()); - return meanPic + enhance * (input - meanPic); + auto meanPic = fl::mean(input); + meanPic = fl::tile(meanPic, input.shape()); + return meanPic + enhance * (input - meanPic); } Tensor brightnessEnhance(const Tensor& input, const float enhance) { - return input * enhance; + return input * enhance; } Tensor invert(const Tensor& input) { - return 255. - input; + return 255. - input; } Tensor solarize(const Tensor& input, const float threshold) { - auto mask = (input < threshold); - return mask * input + (1 - mask) * (255 - input); + auto mask = (input < threshold); + return mask * input + (1 - mask) * (255 - input); } -Tensor -solarizeAdd(const Tensor& input, const float threshold, const float addValue) { - auto mask = (input < threshold); - return mask * (input + addValue) + (1 - mask) * input; +Tensor solarizeAdd(const Tensor& input, const float threshold, const float addValue) { + auto mask = (input < threshold); + return mask * (input + addValue) + (1 - mask) * input; } Tensor equalize(const Tensor& input) { - auto res = input; - for (int i = 0; i < 3; i++) { - auto resSlice = res(fl::span, fl::span, i); - auto hist = fl::histogram( - resSlice, /* numBins = */ 256, /* minVal = */ 0, /* maxVal = */ 255); - res(fl::span, fl::span, i) = fl::equalize(resSlice, hist); - } - return res; + auto res = input; + for(int i = 0; i < 3; i++) { + auto resSlice = res(fl::span, fl::span, i); + auto hist = fl::histogram( + resSlice, /* numBins = */ + 256, /* minVal = */ + 0, /* maxVal = */ + 255 + ); + res(fl::span, fl::span, i) = fl::equalize(resSlice, hist); + } + return res; } Tensor posterize(const Tensor& input, const int bitsToKeep) { - if (bitsToKeep < 1 || bitsToKeep > 8) { - throw 
std::invalid_argument("bitsToKeep needs to be in [1, 8]"); - } - uint8_t mask = ~((1 << (8 - bitsToKeep)) - 1); - auto res = input.astype(fl::dtype::u8) && mask; - return res.astype(input.type()); + if(bitsToKeep < 1 || bitsToKeep > 8) { + throw std::invalid_argument("bitsToKeep needs to be in [1, 8]"); + } + uint8_t mask = ~((1 << (8 - bitsToKeep)) - 1); + auto res = input.astype(fl::dtype::u8) && mask; + return res.astype(input.type()); } Tensor sharpnessEnhance(const Tensor& input, const float enhance) { - const int w = input.dim(0); - const int h = input.dim(1); - const int c = input.dim(2); - const int kernelSize = 7; - const int stride = 1; - const int samePad = static_cast(PaddingMode::SAME); - - auto meanPic = fl::mean(input, {2}); - auto blurKernel = fl::gaussianFilter({kernelSize, kernelSize}); - auto blur = fl::conv2d( - fl::reshape(meanPic, {w, h, 1, 1}), - blurKernel, - /* sx = */ stride, - /* sy = */ stride, - // ensure output size is the same as input size - /* px = */ derivePadding(w, kernelSize, 1, samePad, 1), - /* py = */ derivePadding(h, kernelSize, 1, samePad, 1)); - blur = fl::reshape(blur, {w, h}); - auto diff = fl::tile(meanPic - blur, {1, 1, c}); - return input + enhance * diff; + const int w = input.dim(0); + const int h = input.dim(1); + const int c = input.dim(2); + const int kernelSize = 7; + const int stride = 1; + const int samePad = static_cast(PaddingMode::SAME); + + auto meanPic = fl::mean(input, {2}); + auto blurKernel = fl::gaussianFilter({kernelSize, kernelSize}); + auto blur = fl::conv2d( + fl::reshape(meanPic, {w, h, 1, 1}), + blurKernel, + /* sx = */ stride, + /* sy = */ stride, + // ensure output size is the same as input size + /* px = */ derivePadding(w, kernelSize, 1, samePad, 1), + /* py = */ derivePadding(h, kernelSize, 1, samePad, 1) + ); + blur = fl::reshape(blur, {w, h}); + auto diff = fl::tile(meanPic - blur, {1, 1, c}); + return input + enhance * diff; } Tensor oneHot( const Tensor& targets, const int numClasses, - 
const float labelSmoothing) { - float offValue = labelSmoothing / numClasses; - float onValue = 1. - labelSmoothing; + const float labelSmoothing +) { + float offValue = labelSmoothing / numClasses; + float onValue = 1. - labelSmoothing; - int X = targets.elements(); - auto y = fl::reshape(targets, {1, X}); - auto A = fl::arange({numClasses, X}); - auto B = fl::tile(y, {numClasses}); - auto mask = A == B; // [C X] + int X = targets.elements(); + auto y = fl::reshape(targets, {1, X}); + auto A = fl::arange({numClasses, X}); + auto B = fl::tile(y, {numClasses}); + auto mask = A == B; // [C X] - Tensor out = fl::full({numClasses, X}, onValue); - out = out * mask + offValue; + Tensor out = fl::full({numClasses, X}, onValue); + out = out * mask + offValue; - return out; + return out; } std::pair mixupBatch( @@ -203,25 +207,26 @@ std::pair mixupBatch( const Tensor& input, const Tensor& target, const int numClasses, - const float labelSmoothing) { - // in : W x H x C x B - // target: B x 1 - auto targetOneHot = oneHot(target, numClasses, labelSmoothing); - if (lambda == 0) { - return {input, targetOneHot}; - } - - // mix input - auto inputFlipped = fl::flip(input, 3); - auto inputMixed = lambda * inputFlipped + (1 - lambda) * input; - - // mix target - auto targetOneHotFlipped = - oneHot(fl::flip(target, 0), numClasses, labelSmoothing); - auto targetOneHotMixed = - lambda * targetOneHotFlipped + (1 - lambda) * targetOneHot; - - return {inputMixed, targetOneHotMixed}; + const float labelSmoothing +) { + // in : W x H x C x B + // target: B x 1 + auto targetOneHot = oneHot(target, numClasses, labelSmoothing); + if(lambda == 0) { + return {input, targetOneHot}; + } + + // mix input + auto inputFlipped = fl::flip(input, 3); + auto inputMixed = lambda * inputFlipped + (1 - lambda) * input; + + // mix target + auto targetOneHotFlipped = + oneHot(fl::flip(target, 0), numClasses, labelSmoothing); + auto targetOneHotMixed = + lambda * targetOneHotFlipped + (1 - lambda) * 
targetOneHot; + + return {inputMixed, targetOneHotMixed}; } std::pair cutmixBatch( @@ -229,72 +234,73 @@ std::pair cutmixBatch( const Tensor& input, const Tensor& target, const int numClasses, - const float labelSmoothing) { - // in : W x H x C x B - // target: B x 1 - auto targetOneHot = oneHot(target, numClasses, labelSmoothing); - if (lambda == 0) { - return {input, targetOneHot}; - } - - // mix input - auto inputFlipped = fl::flip(input, 3); - - const float lambdaSqrt = std::sqrt(lambda); - const int w = input.dim(0); - const int h = input.dim(1); - const int maskW = std::round(w * lambdaSqrt); - const int maskH = std::round(h * lambdaSqrt); - const int centerW = randomFloat(0, w); - const int centerH = randomFloat(0, h); - - const int x1 = std::max(0, centerW - maskW / 2); - const int x2 = std::min(w, centerW + maskW / 2 + 1); - const int y1 = std::max(0, centerH - maskH / 2); - const int y2 = std::min(h, centerH + maskH / 2 + 1); - - auto inputMixed = input; - inputMixed(fl::range(x1, x2), fl::range(y1, y2)) = - inputFlipped(fl::range(x1, x2), fl::range(y1, y2)); - auto newLambda = static_cast(x2 - x1) * (y2 - y1) / (w * h); - - // mix target - auto targetOneHotFlipped = - oneHot(fl::flip(target, 0), numClasses, labelSmoothing); - auto targetOneHotMixed = - newLambda * targetOneHotFlipped + (1 - newLambda) * targetOneHot; - - return {inputMixed, targetOneHotMixed}; + const float labelSmoothing +) { + // in : W x H x C x B + // target: B x 1 + auto targetOneHot = oneHot(target, numClasses, labelSmoothing); + if(lambda == 0) { + return {input, targetOneHot}; + } + + // mix input + auto inputFlipped = fl::flip(input, 3); + + const float lambdaSqrt = std::sqrt(lambda); + const int w = input.dim(0); + const int h = input.dim(1); + const int maskW = std::round(w * lambdaSqrt); + const int maskH = std::round(h * lambdaSqrt); + const int centerW = randomFloat(0, w); + const int centerH = randomFloat(0, h); + + const int x1 = std::max(0, centerW - maskW / 2); + const 
int x2 = std::min(w, centerW + maskW / 2 + 1); + const int y1 = std::max(0, centerH - maskH / 2); + const int y2 = std::min(h, centerH + maskH / 2 + 1); + + auto inputMixed = input; + inputMixed(fl::range(x1, x2), fl::range(y1, y2)) = + inputFlipped(fl::range(x1, x2), fl::range(y1, y2)); + auto newLambda = static_cast(x2 - x1) * (y2 - y1) / (w * h); + + // mix target + auto targetOneHotFlipped = + oneHot(fl::flip(target, 0), numClasses, labelSmoothing); + auto targetOneHotMixed = + newLambda * targetOneHotFlipped + (1 - newLambda) * targetOneHot; + + return {inputMixed, targetOneHotMixed}; } ImageTransform resizeTransform(const uint64_t resize) { - return [resize](const Tensor& in) { return resizeSmallest(in, resize); }; + return [resize](const Tensor& in) { return resizeSmallest(in, resize); }; } ImageTransform compose(std::vector transformfns) { - return [transformfns](const Tensor& in) { - Tensor out = in; - for (const auto& fn : transformfns) { - out = fn(out); - } - return out; - }; + return [transformfns](const Tensor& in) { + Tensor out = in; + for(const auto& fn : transformfns) { + out = fn(out); + } + return out; + }; } ImageTransform centerCropTransform(const int size) { - return [size](const Tensor& in) { return centerCrop(in, size); }; + return [size](const Tensor& in) { return centerCrop(in, size); }; }; ImageTransform randomHorizontalFlipTransform(const float p) { - return [p](const Tensor& in) { - Tensor out = in; - if (static_cast(std::rand()) / static_cast(RAND_MAX) > p) { - const long long w = in.dim(0); - // reverse indices - w --> 0 - TODO: use fl::flip? - out = out(fl::range(w - 1, 1, -1)); - } - return out; - }; + return [p](const Tensor& in) { + Tensor out = in; + if(static_cast(std::rand()) / static_cast(RAND_MAX) > p) { + const long long w = in.dim(0); + // reverse indices - w --> 0 - TODO: use fl::flip? 
+ out = out(fl::range(w - 1, 1, -1)); + } + return out; + }; }; ImageTransform randomResizeCropTransform( @@ -302,66 +308,69 @@ ImageTransform randomResizeCropTransform( const float scaleLow, const float scaleHigh, const float ratioLow, - const float ratioHigh) { - return [=](const Tensor& in) mutable { - const int w = in.dim(0); - const int h = in.dim(1); - const float area = w * h; - for (int i = 0; i < 10; i++) { - const float scale = randomFloat(scaleLow, scaleHigh); - const float logRatio = - randomFloat(std::log(ratioLow), std::log(ratioHigh)); - ; - const float targetArea = scale * area; - const float targetRatio = std::exp(logRatio); - const int tw = std::round(std::sqrt(targetArea * targetRatio)); - const int th = std::round(std::sqrt(targetArea / targetRatio)); - if (0 < tw && tw <= w && 0 < th && th <= h) { - const int x = std::rand() % (w - tw + 1); - const int y = std::rand() % (h - th + 1); - - return resize(crop(in, x, y, tw, th), size); - } - } - return centerCrop(resizeSmallest(in, size), size); - }; + const float ratioHigh +) { + return [ = ](const Tensor& in) mutable { + const int w = in.dim(0); + const int h = in.dim(1); + const float area = w * h; + for(int i = 0; i < 10; i++) { + const float scale = randomFloat(scaleLow, scaleHigh); + const float logRatio = + randomFloat(std::log(ratioLow), std::log(ratioHigh)); + ; + const float targetArea = scale * area; + const float targetRatio = std::exp(logRatio); + const int tw = std::round(std::sqrt(targetArea * targetRatio)); + const int th = std::round(std::sqrt(targetArea / targetRatio)); + if(0 < tw && tw <= w && 0 < th && th <= h) { + const int x = std::rand() % (w - tw + 1); + const int y = std::rand() % (h - th + 1); + + return resize(crop(in, x, y, tw, th), size); + } + } + return centerCrop(resizeSmallest(in, size), size); + }; } ImageTransform randomResizeTransform(const int low, const int high) { - return [low, high](const Tensor& in) { - const float scale = - static_cast(std::rand()) / 
static_cast(RAND_MAX); - const int resize = low + (high - low) * scale; - return resizeSmallest(in, resize); - }; + return [low, high](const Tensor& in) { + const float scale = + static_cast(std::rand()) / static_cast(RAND_MAX); + const int resize = low + (high - low) * scale; + return resizeSmallest(in, resize); + }; }; ImageTransform randomCropTransform(const int tw, const int th) { - return [th, tw](const Tensor& in) { - Tensor out = in; - const uint64_t w = in.dim(0); - const uint64_t h = in.dim(1); - if (th > h || tw > w) { - throw std::runtime_error( - "Target th and target width are great the image size"); - } - const int x = std::rand() % (w - tw + 1); - const int y = std::rand() % (h - th + 1); - return crop(in, x, y, tw, th); - }; + return [th, tw](const Tensor& in) { + Tensor out = in; + const uint64_t w = in.dim(0); + const uint64_t h = in.dim(1); + if(th > h || tw > w) { + throw std::runtime_error( + "Target th and target width are great the image size" + ); + } + const int x = std::rand() % (w - tw + 1); + const int y = std::rand() % (h - th + 1); + return crop(in, x, y, tw, th); + }; }; ImageTransform normalizeImage( const std::vector& meanVector, - const std::vector& stdVector) { - const Tensor mean = Tensor::fromVector({1, 1, 3}, meanVector); - const Tensor std = Tensor::fromVector({1, 1, 3}, stdVector); - return [mean, std](const Tensor& in) { - Tensor out = in.astype(fl::dtype::f32) / 255.f; - out = out - mean; - out = out / std; - return out; - }; + const std::vector& stdVector +) { + const Tensor mean = Tensor::fromVector({1, 1, 3}, meanVector); + const Tensor std = Tensor::fromVector({1, 1, 3}, stdVector); + return [mean, std](const Tensor& in) { + Tensor out = in.astype(fl::dtype::f32) / 255.f; + out = out - mean; + out = out / std; + return out; + }; }; ImageTransform randomEraseTransform( @@ -369,136 +378,138 @@ ImageTransform randomEraseTransform( const float areaRatioMin, const float areaRatioMax, const float edgeRatioMin, - const float 
edgeRatioMax) { - // follows: https://git.io/JY9R7 - return [p, areaRatioMin, areaRatioMax, edgeRatioMin, edgeRatioMax]( - const Tensor& in) { - if (p < randomFloat(0, 1)) { - return in; - } - - const float epsilon = 1e-7; - const int w = in.dim(0); - const int h = in.dim(1); - const int c = in.dim(2); - - Tensor out = in; - for (int i = 0; i < 10; i++) { - const float s = w * h * randomFloat(areaRatioMin, areaRatioMax); - const float r = - std::exp(randomFloat(std::log(edgeRatioMin), std::log(edgeRatioMax))); - const int maskW = std::round(std::sqrt(s * r)); - const int maskH = std::round(std::sqrt(s / r)); - if (maskW >= w || maskH >= h) { - continue; - } - - const int x = static_cast(randomFloat(0, w - maskW - epsilon)); - const int y = static_cast(randomFloat(0, h - maskH - epsilon)); - Tensor fillValue = fl::randn({maskW, maskH, c}, in.type()); - - out(fl::range(x, x + maskW), fl::range(y, y + maskH)) = fillValue; - break; - } - return out; - }; + const float edgeRatioMax +) { + // follows: https://git.io/JY9R7 + return [p, areaRatioMin, areaRatioMax, edgeRatioMin, edgeRatioMax]( + const Tensor& in) { + if(p < randomFloat(0, 1)) { + return in; + } + + const float epsilon = 1e-7; + const int w = in.dim(0); + const int h = in.dim(1); + const int c = in.dim(2); + + Tensor out = in; + for(int i = 0; i < 10; i++) { + const float s = w * h * randomFloat(areaRatioMin, areaRatioMax); + const float r = + std::exp(randomFloat(std::log(edgeRatioMin), std::log(edgeRatioMax))); + const int maskW = std::round(std::sqrt(s * r)); + const int maskH = std::round(std::sqrt(s / r)); + if(maskW >= w || maskH >= h) { + continue; + } + + const int x = static_cast(randomFloat(0, w - maskW - epsilon)); + const int y = static_cast(randomFloat(0, h - maskH - epsilon)); + Tensor fillValue = fl::randn({maskW, maskH, c}, in.type()); + + out(fl::range(x, x + maskW), fl::range(y, y + maskH)) = fillValue; + break; + } + return out; + }; }; ImageTransform randomAugmentationDeitTransform( const 
float p, const int n, - const Tensor& fillImg) { - // Selected 15 transform functions with specific parameters - // following https://git.io/JYGG6 - - return [p, n, fillImg](const Tensor& in) { - auto res = in; - for (int i = 0; i < n; i++) { - if (p < randomFloat(0, 1)) { - continue; - } - - int mode = std::floor(randomFloat(0, 15 - 1e-5)); - if (mode == 0) { - // rotate - float baseTheta = .47; - float theta = randomPerturbNegate(baseTheta, -0.02, 0.02); - - res = fl::rotate(res, theta, fillImg); - } else if (mode == 1) { - // skew-x - float baseTheta = .27; - float theta = randomPerturbNegate(baseTheta, -0.02, 0.02); - - res = skewX(res, theta, fillImg); - } else if (mode == 2) { - // skew-y - float baseTheta = .27; - float theta = randomPerturbNegate(baseTheta, -0.02, 0.02); - - res = skewY(res, theta, fillImg); - } else if (mode == 3) { - // translate-x - int baseDelta = 90; - int delta = randomPerturbNegate(baseDelta, -3, 3); - - res = translateX(res, delta, fillImg); - } else if (mode == 4) { - // translate-y - int baseDelta = 90; - int delta = randomPerturbNegate(baseDelta, -3, 3); - - res = translateY(res, delta, fillImg); - } else if (mode == 5) { - // color - float baseEnhance = .8; - float enhance = - 1 + randomPerturbNegate(baseEnhance, -0.03, 0.03); - - res = colorEnhance(res, enhance); - } else if (mode == 6) { - // auto contrast - res = autoContrast(res); - } else if (mode == 7) { - // contrast - float baseEnhance = .8; - float enhance = - 1 + randomPerturbNegate(baseEnhance, -0.03, 0.03); - - res = contrastEnhance(res, enhance); - } else if (mode == 8) { - // brightness - float baseEnhance = .8; - float enhance = - 1 + randomPerturbNegate(baseEnhance, -0.03, 0.03); - - res = brightnessEnhance(res, enhance); - } else if (mode == 9) { - // invert - res = invert(res); - } else if (mode == 10) { - // solarize - res = solarize(res, 26.); - } else if (mode == 11) { - // solarize add - res = solarizeAdd(res, 128., 100.); - } else if (mode == 12) { - // 
equalize - res = equalize(res); - } else if (mode == 13) { - // posterize - res = posterize(res, 1); - } else if (mode == 14) { - // sharpness - float baseEnhance = .5; - float enhance = randomPerturbNegate(baseEnhance, -0.01, 0.01); - - res = sharpnessEnhance(res, enhance); - } - res = fl::clip(res, 0., 255.).astype(res.type()); - } - return res; - }; + const Tensor& fillImg +) { + // Selected 15 transform functions with specific parameters + // following https://git.io/JYGG6 + + return [p, n, fillImg](const Tensor& in) { + auto res = in; + for(int i = 0; i < n; i++) { + if(p < randomFloat(0, 1)) { + continue; + } + + int mode = std::floor(randomFloat(0, 15 - 1e-5)); + if(mode == 0) { + // rotate + float baseTheta = .47; + float theta = randomPerturbNegate(baseTheta, -0.02, 0.02); + + res = fl::rotate(res, theta, fillImg); + } else if(mode == 1) { + // skew-x + float baseTheta = .27; + float theta = randomPerturbNegate(baseTheta, -0.02, 0.02); + + res = skewX(res, theta, fillImg); + } else if(mode == 2) { + // skew-y + float baseTheta = .27; + float theta = randomPerturbNegate(baseTheta, -0.02, 0.02); + + res = skewY(res, theta, fillImg); + } else if(mode == 3) { + // translate-x + int baseDelta = 90; + int delta = randomPerturbNegate(baseDelta, -3, 3); + + res = translateX(res, delta, fillImg); + } else if(mode == 4) { + // translate-y + int baseDelta = 90; + int delta = randomPerturbNegate(baseDelta, -3, 3); + + res = translateY(res, delta, fillImg); + } else if(mode == 5) { + // color + float baseEnhance = .8; + float enhance = + 1 + randomPerturbNegate(baseEnhance, -0.03, 0.03); + + res = colorEnhance(res, enhance); + } else if(mode == 6) { + // auto contrast + res = autoContrast(res); + } else if(mode == 7) { + // contrast + float baseEnhance = .8; + float enhance = + 1 + randomPerturbNegate(baseEnhance, -0.03, 0.03); + + res = contrastEnhance(res, enhance); + } else if(mode == 8) { + // brightness + float baseEnhance = .8; + float enhance = + 1 + 
randomPerturbNegate(baseEnhance, -0.03, 0.03); + + res = brightnessEnhance(res, enhance); + } else if(mode == 9) { + // invert + res = invert(res); + } else if(mode == 10) { + // solarize + res = solarize(res, 26.); + } else if(mode == 11) { + // solarize add + res = solarizeAdd(res, 128., 100.); + } else if(mode == 12) { + // equalize + res = equalize(res); + } else if(mode == 13) { + // posterize + res = posterize(res, 1); + } else if(mode == 14) { + // sharpness + float baseEnhance = .5; + float enhance = randomPerturbNegate(baseEnhance, -0.01, 0.01); + + res = sharpnessEnhance(res, enhance); + } + res = fl::clip(res, 0., 255.).astype(res.type()); + } + return res; + }; } } // namespace fl diff --git a/flashlight/pkg/vision/dataset/Transforms.h b/flashlight/pkg/vision/dataset/Transforms.h index ef64ba1..5a598a2 100644 --- a/flashlight/pkg/vision/dataset/Transforms.h +++ b/flashlight/pkg/vision/dataset/Transforms.h @@ -13,158 +13,158 @@ namespace fl { namespace pkg { -namespace vision { + namespace vision { /* * Resizes the smallest length edge of an image to be resize while keeping * the aspect ratio */ -Tensor resizeSmallest(const Tensor& in, const int resize); + Tensor resizeSmallest(const Tensor& in, const int resize); /* * Resize both sides of image to be length * @param resize` * will change aspect ratio */ -Tensor resize(const Tensor& in, const int resize); + Tensor resize(const Tensor& in, const int resize); /* * Crop image @param in, starting from position @param x and @param y * with a target width and height of @param w and @param h respectively */ -Tensor -crop(const Tensor& in, const int x, const int y, const int w, const int h); + Tensor crop(const Tensor& in, const int x, const int y, const int w, const int h); /* * Take a center crop of image @param in, * where both image sides with be of length @param size */ -Tensor centerCrop(const Tensor& in, const int size); + Tensor centerCrop(const Tensor& in, const int size); /* * Rotate an image * @param 
theta to which degree (in radius) a image will rotate * @param fillImg filling values on the empty spots */ -Tensor rotate(const Tensor& input, const float theta, const Tensor& fillImg); + Tensor rotate(const Tensor& input, const float theta, const Tensor& fillImg); /* * Skew an image on the first dimension * @param theta to which degree (in radius) a image will skew * @param fillImg filling values on the empty spots */ -Tensor skewX(const Tensor& input, const float theta, const Tensor& fillImg); + Tensor skewX(const Tensor& input, const float theta, const Tensor& fillImg); /* * Skew an image on the second dimension * @param theta to which degree (in radius) a image will skew * @param fillImg filling values on the empty spots */ -Tensor skewY(const Tensor& input, const float theta, const Tensor& fillImg); + Tensor skewY(const Tensor& input, const float theta, const Tensor& fillImg); /* * Translate an image on the first dimension * @param shift number of pixels a image will translate * @param fillImg filling values on the empty spots */ -Tensor translateX(const Tensor& input, const int shift, const Tensor& fillImg); + Tensor translateX(const Tensor& input, const int shift, const Tensor& fillImg); /* * Translate an image on the second dimension * @param shift number of pixels a image will translate * @param fillImg filling values on the empty spots */ -Tensor translateY(const Tensor& input, const int shift, const Tensor& fillImg); + Tensor translateY(const Tensor& input, const int shift, const Tensor& fillImg); /* * Enhance the color of an image * @param enhance to which extend the color will change. */ -Tensor colorEnhance(const Tensor& input, const float enhance); + Tensor colorEnhance(const Tensor& input, const float enhance); /* * Remaps the image so that the darkest pixel becomes black (0), and the * lightest becomes white (255). 
*/ -Tensor autoContrast(const Tensor& input); + Tensor autoContrast(const Tensor& input); /* * Enhance the contrast of an image * @param enhance to which extend the contrast will change. */ -Tensor contrastEnhance(const Tensor& input, const float enhance); + Tensor contrastEnhance(const Tensor& input, const float enhance); /* * Enhance the brightness of an image * @param enhance to which extend the brightness will change. */ -Tensor brightnessEnhance(const Tensor& input, const float enhance); + Tensor brightnessEnhance(const Tensor& input, const float enhance); /* * Enhance the sharpness of an image * @param enhance to which extend the sharpness will change. */ -Tensor sharpnessEnhance(const Tensor& input, const float enhance); + Tensor sharpnessEnhance(const Tensor& input, const float enhance); /* * Invert each pixel of the image */ -Tensor invert(const Tensor& input); + Tensor invert(const Tensor& input); /* * Invert all pixel values above a threshold. */ -Tensor solarize(const Tensor& input, const float threshold); + Tensor solarize(const Tensor& input, const float threshold); /* * Increase all pixel values below a threshold. */ -Tensor -solarizeAdd(const Tensor& input, const float threshold, const float addValue); + Tensor solarizeAdd(const Tensor& input, const float threshold, const float addValue); /* * Applies a non-linear mapping to the input image, in order to create a uniform * distribution of grayscale values in the output image. */ -Tensor equalize(const Tensor& input); + Tensor equalize(const Tensor& input); /* * Reduce the number of bits for each color channel. 
*/ -Tensor posterize(const Tensor& input, const int bitsToKeep); + Tensor posterize(const Tensor& input, const int bitsToKeep); /* * Transform a target array with label indices into a one-hot matrix */ -Tensor -oneHot(const Tensor& targets, const int numClasses, const float labelSmoothing); + Tensor oneHot(const Tensor& targets, const int numClasses, const float labelSmoothing); /* * Apply mixup for a given batch as in https://arxiv.org/abs/1710.09412 */ -std::pair mixupBatch( - const float lambda, - const Tensor& input, - const Tensor& target, - const int numClasses, - const float labelSmoothing); + std::pair mixupBatch( + const float lambda, + const Tensor& input, + const Tensor& target, + const int numClasses, + const float labelSmoothing + ); /* * Apply cutmix as in https://arxiv.org/abs/1905.04899 */ -std::pair cutmixBatch( - const float lambda, - const Tensor& input, - const Tensor& target, - const int numClasses, - const float labelSmoothing); + std::pair cutmixBatch( + const float lambda, + const Tensor& input, + const Tensor& target, + const int numClasses, + const float labelSmoothing + ); // Same function signature as DataTransform but removes fl dep -using ImageTransform = std::function; + using ImageTransform = std::function; -ImageTransform normalizeImage( - const std::vector& meanVec, - const std::vector& stdVec); + ImageTransform normalizeImage( + const std::vector& meanVec, + const std::vector& stdVec + ); /* * Randomly resize the image between sizes @@ -172,35 +172,36 @@ ImageTransform normalizeImage( * @param high * This transform helps to create scale invariance */ -ImageTransform randomResizeTransform(const int low, const int high); + ImageTransform randomResizeTransform(const int low, const int high); -ImageTransform randomResizeCropTransform( - const int resize, - const float scaleLow, - const float scaleHigh, - const float ratioLow, - const float ratioHigh); + ImageTransform randomResizeCropTransform( + const int resize, + const float 
scaleLow, + const float scaleHigh, + const float ratioLow, + const float ratioHigh + ); /* * Randomly crop an image with target height of @param th and a target width of * @params tw */ -ImageTransform randomCropTransform(const int th, const int tw); + ImageTransform randomCropTransform(const int th, const int tw); /* * Resize the shortest edge of the image to size @param resize */ -ImageTransform resizeTransform(const uint64_t resize); + ImageTransform resizeTransform(const uint64_t resize); /* * Take a center crop of an image so its size is @param size */ -ImageTransform centerCropTransform(const int size); + ImageTransform centerCropTransform(const int size); /* * Flip an image horizontally with a probability @param p */ -ImageTransform randomHorizontalFlipTransform(const float p = 0.5); + ImageTransform randomHorizontalFlipTransform(const float p = 0.5); /* * Randomly erase. @@ -211,12 +212,13 @@ ImageTransform randomHorizontalFlipTransform(const float p = 0.5); * @param[edgeRatioMin] minimum w/h ratio for the area to erase * @param[edgeRatioMax] maximum w/h ratio for the area to erase */ -ImageTransform randomEraseTransform( - const float p = 0.5, - const float areaRatioMin = 0.02, - const float areaRatioMax = 1. / 3., - const float edgeRatioMin = 0.3, - const float edgeRatioMax = 10 / 3.); + ImageTransform randomEraseTransform( + const float p = 0.5, + const float areaRatioMin = 0.02, + const float areaRatioMax = 1. / 3., + const float edgeRatioMin = 0.3, + const float edgeRatioMax = 10 / 3. 
+ ); /* * Randon Augmentation @@ -234,16 +236,17 @@ ImageTransform randomEraseTransform( * @param[fillImg] filling values on the empty spots generated in some * transforms */ -ImageTransform randomAugmentationDeitTransform( - const float p = 0.5, - const int n = 2, - const Tensor& fillImg = Tensor()); + ImageTransform randomAugmentationDeitTransform( + const float p = 0.5, + const int n = 2, + const Tensor& fillImg = Tensor() + ); /* * Utility method for composing multiple transform functions */ -ImageTransform compose(std::vector transformfns); + ImageTransform compose(std::vector transformfns); -} // namespace vision + } // namespace vision } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/vision/models/Detr.cpp b/flashlight/pkg/vision/models/Detr.cpp index b64f0d1..8ec4411 100644 --- a/flashlight/pkg/vision/models/Detr.cpp +++ b/flashlight/pkg/vision/models/Detr.cpp @@ -14,30 +14,30 @@ namespace { double calculateGain(double negativeSlope) { - return std::sqrt(2.0 / (1 + std::pow(negativeSlope, 2))); + return std::sqrt(2.0 / (1 + std::pow(negativeSlope, 2))); } std::shared_ptr makeLinear(int inDim, int outDim) { - int fanIn = inDim; - float gain = calculateGain(std::sqrt(5.0)); - float std = gain / std::sqrt(fanIn); - float bound = std::sqrt(3.0) * std; - auto w = fl::uniform(outDim, inDim, -bound, bound, fl::dtype::f32, true); - bound = 1.0 / std::sqrt(fanIn); - auto b = fl::uniform({outDim}, -bound, bound, fl::dtype::f32, true); - return std::make_shared(w, b); + int fanIn = inDim; + float gain = calculateGain(std::sqrt(5.0)); + float std = gain / std::sqrt(fanIn); + float bound = std::sqrt(3.0) * std; + auto w = fl::uniform(outDim, inDim, -bound, bound, fl::dtype::f32, true); + bound = 1.0 / std::sqrt(fanIn); + auto b = fl::uniform({outDim}, -bound, bound, fl::dtype::f32, true); + return std::make_shared(w, b); } std::shared_ptr makeConv2D(int inDim, int outDim, int wx, int wy) { - int fanIn = wx * wy * inDim; - float gain = 
calculateGain(std::sqrt(5.0f)); - float std = gain / std::sqrt(fanIn); - float bound = std::sqrt(3.0f) * std; - auto w = - fl::uniform({wx, wy, inDim, outDim}, -bound, bound, fl::dtype::f32, true); - bound = 1.0f / std::sqrt(fanIn); - auto b = fl::uniform({1, 1, outDim, 1}, -bound, bound, fl::dtype::f32, true); - return std::make_shared(w, b, 1, 1); + int fanIn = wx * wy * inDim; + float gain = calculateGain(std::sqrt(5.0f)); + float std = gain / std::sqrt(fanIn); + float bound = std::sqrt(3.0f) * std; + auto w = + fl::uniform({wx, wy, inDim, outDim}, -bound, bound, fl::dtype::f32, true); + bound = 1.0f / std::sqrt(fanIn); + auto b = fl::uniform({1, 1, outDim, 1}, -bound, bound, fl::dtype::f32, true); + return std::make_shared(w, b, 1, 1); } } // namespace @@ -48,14 +48,15 @@ MLP::MLP( const int32_t inputDim, const int32_t hiddenDim, const int32_t outputDim, - const int32_t numLayers) { - add(makeLinear(inputDim, hiddenDim)); - for (int i = 1; i < numLayers - 1; i++) { + const int32_t numLayers +) { + add(makeLinear(inputDim, hiddenDim)); + for(int i = 1; i < numLayers - 1; i++) { + add(ReLU()); + add(makeLinear(hiddenDim, hiddenDim)); + } add(ReLU()); - add(makeLinear(hiddenDim, hiddenDim)); - } - add(ReLU()); - add(makeLinear(hiddenDim, outputDim)); + add(makeLinear(hiddenDim, outputDim)); } Detr::Detr( @@ -64,94 +65,102 @@ Detr::Detr( const int32_t hiddenDim, const int32_t numClasses, const int32_t numQueries, - const bool auxLoss) - : backbone_(backbone), - transformer_(transformer), - classEmbed_(makeLinear(hiddenDim, numClasses + 1)), - bboxEmbed_(std::make_shared(hiddenDim, hiddenDim, 4, 3)), - queryEmbed_( - std::make_shared(fl::normal({hiddenDim, numQueries}))), - posEmbed_(std::make_shared( - hiddenDim / 2, - 10000, - true, - 6.283185307179586f)), - inputProj_(makeConv2D(2048, hiddenDim, 1, 1)), - numClasses_(numClasses), - numQueries_(numQueries), - auxLoss_(auxLoss) { - add(transformer_); - add(classEmbed_); - add(bboxEmbed_); - add(queryEmbed_); - 
add(inputProj_); - add(backbone_); - add(posEmbed_); + const bool auxLoss +) : backbone_(backbone), + transformer_(transformer), + classEmbed_(makeLinear(hiddenDim, numClasses + 1)), + bboxEmbed_(std::make_shared(hiddenDim, hiddenDim, 4, 3)), + queryEmbed_( + std::make_shared(fl::normal({ hiddenDim, numQueries }))), + posEmbed_(std::make_shared( + hiddenDim / 2, + 10000, + true, + 6.283185307179586f + )), + inputProj_(makeConv2D(2048, hiddenDim, 1, 1)), + numClasses_(numClasses), + numQueries_(numQueries), + auxLoss_(auxLoss) { + add(transformer_); + add(classEmbed_); + add(bboxEmbed_); + add(queryEmbed_); + add(inputProj_); + add(backbone_); + add(posEmbed_); } std::vector Detr::forward(const std::vector& input) { - // input: {input, mask} - if (input.size() != 2) { - throw std::invalid_argument( - "Detr takes 2 Variables as input but gets " + - std::to_string(input.size())); - } - auto feature = forwardBackbone(input.front()); - return forwardTransformer({feature, input[1]}); + // input: {input, mask} + if(input.size() != 2) { + throw std::invalid_argument( + "Detr takes 2 Variables as input but gets " + + std::to_string(input.size()) + ); + } + auto feature = forwardBackbone(input.front()); + return forwardTransformer({feature, input[1]}); } Variable Detr::forwardBackbone(const Variable& input) { - return backbone_->forward({input})[1]; + return backbone_->forward({input})[1]; } std::vector Detr::forwardTransformer( - const std::vector& input) { - // input: {feature, mask} - fl::Variable mask = fl::Variable( - fl::resize( - input[1].tensor(), {input[0].shape()}, InterpolationMode::Nearest), - true); - auto inputProjection = inputProj_->forward(input[0]); - auto posEmbed = posEmbed_->forward({mask})[0]; - auto hs = transformer_->forward( - inputProjection, - mask.astype(inputProjection.type()), - queryEmbed_->param(0).astype(inputProjection.type()), - posEmbed.astype(inputProjection.type())); - - auto outputClasses = classEmbed_->forward(hs[0]); - auto 
outputCoord = sigmoid(bboxEmbed_->forward(hs)[0]); - - return {outputClasses, outputCoord}; + const std::vector& input +) { + // input: {feature, mask} + fl::Variable mask = fl::Variable( + fl::resize( + input[1].tensor(), + {input[0].shape()}, + InterpolationMode::Nearest + ), + true + ); + auto inputProjection = inputProj_->forward(input[0]); + auto posEmbed = posEmbed_->forward({mask})[0]; + auto hs = transformer_->forward( + inputProjection, + mask.astype(inputProjection.type()), + queryEmbed_->param(0).astype(inputProjection.type()), + posEmbed.astype(inputProjection.type()) + ); + + auto outputClasses = classEmbed_->forward(hs[0]); + auto outputCoord = sigmoid(bboxEmbed_->forward(hs)[0]); + + return {outputClasses, outputCoord}; } std::unique_ptr Detr::clone() const { - throw std::runtime_error("Cloning is unimplemented in Module 'Detr'"); + throw std::runtime_error("Cloning is unimplemented in Module 'Detr'"); } std::string Detr::prettyString() const { - std::ostringstream ss; - ss << "Detection Transformer"; - ss << Container::prettyString(); - return ss.str(); + std::ostringstream ss; + ss << "Detection Transformer"; + ss << Container::prettyString(); + return ss.str(); } std::vector Detr::paramsWithoutBackbone() { - std::vector results; - std::vector> childParams; - childParams.push_back(transformer_->params()); - childParams.push_back(classEmbed_->params()); - childParams.push_back(bboxEmbed_->params()); - childParams.push_back(queryEmbed_->params()); - childParams.push_back(inputProj_->params()); - for (auto params : childParams) { - results.insert(results.end(), params.begin(), params.end()); - } - return results; + std::vector results; + std::vector> childParams; + childParams.push_back(transformer_->params()); + childParams.push_back(classEmbed_->params()); + childParams.push_back(bboxEmbed_->params()); + childParams.push_back(queryEmbed_->params()); + childParams.push_back(inputProj_->params()); + for(auto params : childParams) { + 
results.insert(results.end(), params.begin(), params.end()); + } + return results; } std::vector Detr::backboneParams() { - return backbone_->params(); + return backbone_->params(); } } // namespace fl diff --git a/flashlight/pkg/vision/models/Detr.h b/flashlight/pkg/vision/models/Detr.h index 9328701..4aaf0be 100644 --- a/flashlight/pkg/vision/models/Detr.h +++ b/flashlight/pkg/vision/models/Detr.h @@ -12,68 +12,72 @@ namespace fl { namespace pkg { -namespace vision { + namespace vision { // TODO (padentomasello) this can just be a function -class MLP : public Sequential { - public: - MLP(const int32_t inputDim, - const int32_t hiddenDim, - const int32_t outputDim, - const int32_t numLayers); + class MLP : public Sequential { + public: + MLP( + const int32_t inputDim, + const int32_t hiddenDim, + const int32_t outputDim, + const int32_t numLayers + ); - private: - MLP() = default; - FL_SAVE_LOAD_WITH_BASE(fl::Sequential) -}; + private: + MLP() = default; + FL_SAVE_LOAD_WITH_BASE(fl::Sequential) + }; -class Detr : public Container { - public: - Detr( - std::shared_ptr transformer, - std::shared_ptr backbone, - const int32_t hiddenDim, - const int32_t numClasses, - const int32_t numQueries, - const bool auxLoss); + class Detr : public Container { + public: + Detr( + std::shared_ptr transformer, + std::shared_ptr backbone, + const int32_t hiddenDim, + const int32_t numClasses, + const int32_t numQueries, + const bool auxLoss + ); - std::vector forward(const std::vector& input) override; - Variable forwardBackbone(const Variable& input); - std::vector forwardTransformer(const std::vector& input); + std::vector forward(const std::vector& input) override; + Variable forwardBackbone(const Variable& input); + std::vector forwardTransformer(const std::vector& input); - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - std::string prettyString() const override; + std::string prettyString() const override; - std::vector 
paramsWithoutBackbone(); + std::vector paramsWithoutBackbone(); - std::vector backboneParams(); + std::vector backboneParams(); - private: - Detr() = default; - std::shared_ptr backbone_; - std::shared_ptr transformer_; - std::shared_ptr classEmbed_; - std::shared_ptr bboxEmbed_; - std::shared_ptr queryEmbed_; - std::shared_ptr posEmbed_; - std::shared_ptr inputProj_; - int32_t hiddenDim_; - int32_t numClasses_; - int32_t numQueries_; - bool auxLoss_; - FL_SAVE_LOAD_WITH_BASE( - fl::Container, - backbone_, - transformer_, - classEmbed_, - bboxEmbed_, - queryEmbed_, - posEmbed_, - inputProj_) -}; + private: + Detr() = default; + std::shared_ptr backbone_; + std::shared_ptr transformer_; + std::shared_ptr classEmbed_; + std::shared_ptr bboxEmbed_; + std::shared_ptr queryEmbed_; + std::shared_ptr posEmbed_; + std::shared_ptr inputProj_; + int32_t hiddenDim_; + int32_t numClasses_; + int32_t numQueries_; + bool auxLoss_; + FL_SAVE_LOAD_WITH_BASE( + fl::Container, + backbone_, + transformer_, + classEmbed_, + bboxEmbed_, + queryEmbed_, + posEmbed_, + inputProj_ + ) + }; -} // namespace vision + } // namespace vision } // namespace pkg } // namespace fl CEREAL_REGISTER_TYPE(fl::pkg::vision::Detr) diff --git a/flashlight/pkg/vision/models/Resnet.cpp b/flashlight/pkg/vision/models/Resnet.cpp index 5f573ee..8443e5c 100644 --- a/flashlight/pkg/vision/models/Resnet.cpp +++ b/flashlight/pkg/vision/models/Resnet.cpp @@ -11,15 +11,15 @@ namespace fl::pkg::vision { namespace { -Conv2D conv3x3(int inC, int outC, int stride, int groups) { - const auto pad = PaddingMode::SAME; - return Conv2D(inC, outC, 3, 3, stride, stride, pad, pad, 1, 1, false, groups); -} + Conv2D conv3x3(int inC, int outC, int stride, int groups) { + const auto pad = PaddingMode::SAME; + return Conv2D(inC, outC, 3, 3, stride, stride, pad, pad, 1, 1, false, groups); + } -Conv2D conv1x1(int inC, int outC, int stride, int groups) { - const auto pad = PaddingMode::SAME; - return Conv2D(inC, outC, 1, 1, stride, 
stride, pad, pad, 1, 1, false, groups); -} + Conv2D conv1x1(int inC, int outC, int stride, int groups) { + const auto pad = PaddingMode::SAME; + return Conv2D(inC, outC, 1, 1, stride, stride, pad, pad, 1, 1, false, groups); + } } // namespace @@ -33,33 +33,34 @@ ConvBnAct::ConvBnAct( const int sx, const int sy, bool bn, - bool act) { - const auto pad = PaddingMode::SAME; - const bool bias = !bn; - add(fl::Conv2D(inC, outC, kw, kh, sx, sy, pad, pad, 1, 1, bias)); - if (bn) { - add(fl::BatchNorm(2, outC)); - } - if (act) { - add(fl::ReLU()); - } + bool act +) { + const auto pad = PaddingMode::SAME; + const bool bias = !bn; + add(fl::Conv2D(inC, outC, kw, kh, sx, sy, pad, pad, 1, 1, bias)); + if(bn) { + add(fl::BatchNorm(2, outC)); + } + if(act) { + add(fl::ReLU()); + } } ResNetBlock::ResNetBlock() = default; ResNetBlock::ResNetBlock(const int inC, const int outC, const int stride) { - add(conv3x3(inC, outC, stride, 1)); - add(BatchNorm(2, outC)); - add(ReLU()); - add(conv3x3(outC, outC, 1, 1)); - add(BatchNorm(2, outC)); - add(ReLU()); - if (inC != outC || stride > 1) { - Sequential downsample; - downsample.add(conv1x1(inC, outC, stride, 1)); - downsample.add(BatchNorm(2, outC)); - add(std::move(downsample)); - } + add(conv3x3(inC, outC, stride, 1)); + add(BatchNorm(2, outC)); + add(ReLU()); + add(conv3x3(outC, outC, 1, 1)); + add(BatchNorm(2, outC)); + add(ReLU()); + if(inC != outC || stride > 1) { + Sequential downsample; + downsample.add(conv1x1(inC, outC, stride, 1)); + downsample.add(BatchNorm(2, outC)); + add(std::move(downsample)); + } } ResNetBottleneckBlock::ResNetBottleneckBlock() = default; @@ -67,108 +68,112 @@ ResNetBottleneckBlock::ResNetBottleneckBlock() = default; ResNetBottleneckBlock::ResNetBottleneckBlock( const int inC, const int planes, - const int stride) { - const int expansionFactor = 4; - add(conv1x1(inC, planes, 1, 1)); - add(BatchNorm(2, planes)); - add(ReLU()); - add(conv3x3(planes, planes, stride, 1)); - add(BatchNorm(2, planes)); - 
add(ReLU()); - add(conv1x1(planes, planes * expansionFactor, 1, 1)); - add(BatchNorm(2, planes * expansionFactor)); - add(ReLU()); - if (inC != planes * expansionFactor || stride > 1) { - Sequential downsample; - downsample.add(conv1x1(inC, planes * expansionFactor, stride, 1)); - downsample.add(BatchNorm(2, planes * expansionFactor)); - add(std::move(downsample)); - } + const int stride +) { + const int expansionFactor = 4; + add(conv1x1(inC, planes, 1, 1)); + add(BatchNorm(2, planes)); + add(ReLU()); + add(conv3x3(planes, planes, stride, 1)); + add(BatchNorm(2, planes)); + add(ReLU()); + add(conv1x1(planes, planes * expansionFactor, 1, 1)); + add(BatchNorm(2, planes * expansionFactor)); + add(ReLU()); + if(inC != planes * expansionFactor || stride > 1) { + Sequential downsample; + downsample.add(conv1x1(inC, planes * expansionFactor, stride, 1)); + downsample.add(BatchNorm(2, planes * expansionFactor)); + add(std::move(downsample)); + } } std::vector ResNetBottleneckBlock::forward( - const std::vector& inputs) { - const auto& c1 = module(0); - const auto& bn1 = module(1); - const auto& relu1 = module(2); - const auto& c2 = module(3); - const auto& bn2 = module(4); - const auto& relu2 = module(5); - const auto& c3 = module(6); - const auto& bn3 = module(7); - const auto& relu3 = module(8); - - std::vector out; - out = c1->forward(inputs); - out = bn1->forward(out); - - out = relu1->forward(out); - - out = c2->forward(out); - out = bn2->forward(out); - out = relu2->forward(out); - - out = c3->forward(out); - out = bn3->forward(out); - - std::vector shortcut; - if (modules().size() > 9) { - shortcut = module(9)->forward(inputs); - } else { - shortcut = inputs; - } - return relu3->forward({out[0] + shortcut[0]}); + const std::vector& inputs +) { + const auto& c1 = module(0); + const auto& bn1 = module(1); + const auto& relu1 = module(2); + const auto& c2 = module(3); + const auto& bn2 = module(4); + const auto& relu2 = module(5); + const auto& c3 = module(6); + const 
auto& bn3 = module(7); + const auto& relu3 = module(8); + + std::vector out; + out = c1->forward(inputs); + out = bn1->forward(out); + + out = relu1->forward(out); + + out = c2->forward(out); + out = bn2->forward(out); + out = relu2->forward(out); + + out = c3->forward(out); + out = bn3->forward(out); + + std::vector shortcut; + if(modules().size() > 9) { + shortcut = module(9)->forward(inputs); + } else { + shortcut = inputs; + } + return relu3->forward({out[0] + shortcut[0]}); } std::string ResNetBottleneckBlock::prettyString() const { - std::ostringstream ss; - ss << "ResNetBottleneckBlock"; - ss << Container::prettyString(); - return ss.str(); + std::ostringstream ss; + ss << "ResNetBottleneckBlock"; + ss << Container::prettyString(); + return ss.str(); } std::vector ResNetBlock::forward( - const std::vector& inputs) { - const auto& c1 = module(0); - const auto& bn1 = module(1); - const auto& relu1 = module(2); - const auto& c2 = module(3); - const auto& bn2 = module(4); - const auto& relu2 = module(5); - std::vector out; - out = c1->forward(inputs); - out = bn1->forward(out); - out = relu1->forward(out); - out = c2->forward(out); - out = bn2->forward(out); - - std::vector shortcut; - if (modules().size() > 6) { - shortcut = module(6)->forward(inputs); - } else { - shortcut = inputs; - } - return relu2->forward({out[0] + shortcut[0]}); + const std::vector& inputs +) { + const auto& c1 = module(0); + const auto& bn1 = module(1); + const auto& relu1 = module(2); + const auto& c2 = module(3); + const auto& bn2 = module(4); + const auto& relu2 = module(5); + std::vector out; + out = c1->forward(inputs); + out = bn1->forward(out); + out = relu1->forward(out); + out = c2->forward(out); + out = bn2->forward(out); + + std::vector shortcut; + if(modules().size() > 6) { + shortcut = module(6)->forward(inputs); + } else { + shortcut = inputs; + } + return relu2->forward({out[0] + shortcut[0]}); } std::string ResNetBlock::prettyString() const { - std::ostringstream ss; - 
ss << "ResNetBlock"; - ss << Container::prettyString(); - return ss.str(); + std::ostringstream ss; + ss << "ResNetBlock"; + ss << Container::prettyString(); + return ss.str(); } ResNetBottleneckStage::ResNetBottleneckStage( const int inC, const int outC, const int numBlocks, - const int stride) { - add(ResNetBottleneckBlock(inC, outC, stride)); - const int expansionFactor = 4; - const int inPlanes = outC * expansionFactor; - for (int i = 1; i < numBlocks; i++) { - add(ResNetBottleneckBlock(inPlanes, outC)); - } + const int stride +) { + add(ResNetBottleneckBlock(inC, outC, stride)); + const int expansionFactor = 4; + const int inPlanes = outC * expansionFactor; + for(int i = 1; i < numBlocks; i++) { + add(ResNetBottleneckBlock(inPlanes, outC)); + } }; ResNetBottleneckStage::ResNetBottleneckStage() = default; @@ -179,54 +184,55 @@ ResNetStage::ResNetStage( const int inC, const int outC, const int numBlocks, - const int stride) { - add(ResNetBlock(inC, outC, stride)); - for (int i = 1; i < numBlocks; i++) { - add(ResNetBlock(outC, outC)); - } + const int stride +) { + add(ResNetBlock(inC, outC, stride)); + for(int i = 1; i < numBlocks; i++) { + add(ResNetBlock(outC, outC)); + } } std::shared_ptr resnet34() { - auto model = std::make_shared(); - // conv1 -> 244x244x3 -> 112x112x64 - model->add(ConvBnAct(3, 64, 7, 7, 2, 2)); - // maxpool -> 112x122x64 -> 56x56x64 - model->add(Pool2D(3, 3, 2, 2, -1, -1, PoolingMode::MAX)); - // conv2_x -> 56x56x64 -> 56x56x64 - model->add(ResNetStage(64, 64, 3, 1)); - // conv3_x -> 56x56x64 -> 28x28x128 - model->add(ResNetStage(64, 128, 4, 2)); - // conv4_x -> 28x28x128 -> 14x14x256 - model->add(ResNetStage(128, 256, 6, 2)); - // conv5_x -> 14x14x256 -> 7x7x256 - model->add(ResNetStage(256, 512, 3, 2)); - // pool 7x7x512 -> 1x1x512 - model->add(Pool2D(7, 7, 1, 1, 0, 0, fl::PoolingMode::AVG_EXCLUDE_PADDING)); - - model->add(View({512, -1})); - model->add(Linear(512, 1000)); - return model; + auto model = std::make_shared(); + // conv1 
-> 244x244x3 -> 112x112x64 + model->add(ConvBnAct(3, 64, 7, 7, 2, 2)); + // maxpool -> 112x122x64 -> 56x56x64 + model->add(Pool2D(3, 3, 2, 2, -1, -1, PoolingMode::MAX)); + // conv2_x -> 56x56x64 -> 56x56x64 + model->add(ResNetStage(64, 64, 3, 1)); + // conv3_x -> 56x56x64 -> 28x28x128 + model->add(ResNetStage(64, 128, 4, 2)); + // conv4_x -> 28x28x128 -> 14x14x256 + model->add(ResNetStage(128, 256, 6, 2)); + // conv5_x -> 14x14x256 -> 7x7x256 + model->add(ResNetStage(256, 512, 3, 2)); + // pool 7x7x512 -> 1x1x512 + model->add(Pool2D(7, 7, 1, 1, 0, 0, fl::PoolingMode::AVG_EXCLUDE_PADDING)); + + model->add(View({512, -1})); + model->add(Linear(512, 1000)); + return model; }; std::shared_ptr resnet50() { - auto model = std::make_shared(); - // conv1 -> 244x244x3 -> 112x112x64 - model->add(ConvBnAct(3, 64, 7, 7, 2, 2)); - // maxpool -> 112x122x64 -> 56x56x64 - model->add(Pool2D(3, 3, 2, 2, -1, -1, PoolingMode::MAX)); - // conv2_x -> 56x56x64 -> 56x56x64 - model->add(ResNetBottleneckStage(64, 64, 3, 1)); - // conv3_x -> 56x56x64 -> 28x28x128 - model->add(ResNetBottleneckStage(64 * 4, 128, 4, 2)); - // conv4_x -> 28x28x128 -> 14x14x256 - model->add(ResNetBottleneckStage(128 * 4, 256, 6, 2)); - // conv5_x -> 14x14x256 -> 7x7x256 - model->add(ResNetBottleneckStage(256 * 4, 512, 3, 2)); - // pool 7x7x512 -> 1x1x512 - model->add(Pool2D(7, 7, 1, 1, 0, 0, fl::PoolingMode::AVG_EXCLUDE_PADDING)); - - model->add(View({512 * 4, -1})); - model->add(Linear(512 * 4, 1000)); - return model; + auto model = std::make_shared(); + // conv1 -> 244x244x3 -> 112x112x64 + model->add(ConvBnAct(3, 64, 7, 7, 2, 2)); + // maxpool -> 112x122x64 -> 56x56x64 + model->add(Pool2D(3, 3, 2, 2, -1, -1, PoolingMode::MAX)); + // conv2_x -> 56x56x64 -> 56x56x64 + model->add(ResNetBottleneckStage(64, 64, 3, 1)); + // conv3_x -> 56x56x64 -> 28x28x128 + model->add(ResNetBottleneckStage(64 * 4, 128, 4, 2)); + // conv4_x -> 28x28x128 -> 14x14x256 + model->add(ResNetBottleneckStage(128 * 4, 256, 6, 2)); + // 
conv5_x -> 14x14x256 -> 7x7x256 + model->add(ResNetBottleneckStage(256 * 4, 512, 3, 2)); + // pool 7x7x512 -> 1x1x512 + model->add(Pool2D(7, 7, 1, 1, 0, 0, fl::PoolingMode::AVG_EXCLUDE_PADDING)); + + model->add(View({512 * 4, -1})); + model->add(Linear(512 * 4, 1000)); + return model; } } // namespace fl diff --git a/flashlight/pkg/vision/models/Resnet.h b/flashlight/pkg/vision/models/Resnet.h index c84d98f..a83b577 100644 --- a/flashlight/pkg/vision/models/Resnet.h +++ b/flashlight/pkg/vision/models/Resnet.h @@ -12,93 +12,98 @@ namespace fl { namespace pkg { -namespace vision { - -class ConvBnAct : public fl::Sequential { - public: - ConvBnAct( - const int inChannels, - const int outChannels, - const int kw, - const int kh, - const int sx = 1, - const int sy = 1, - bool bn = true, - bool act = true); - - private: - ConvBnAct(); - FL_SAVE_LOAD_WITH_BASE(fl::Sequential) -}; - -class ResNetBlock : public fl::Container { - private: - FL_SAVE_LOAD_WITH_BASE(fl::Container) - ResNetBlock(); - - public: - ResNetBlock( - const int inChannels, - const int outChannels, - const int stride = 1); - - std::vector forward( - const std::vector& inputs) override; - - std::string prettyString() const override; - - FL_BASIC_CONTAINER_CLONING(ResNetBlock) -}; - -class ResNetBottleneckBlock : public fl::Container { - private: - FL_SAVE_LOAD_WITH_BASE(fl::Container) - ResNetBottleneckBlock(); - - public: - ResNetBottleneckBlock( - const int inChannels, - const int outChannels, - const int stride = 1); - - std::vector forward( - const std::vector& inputs) override; - - std::string prettyString() const override; - - FL_BASIC_CONTAINER_CLONING(ResNetBottleneckBlock) -}; - -class ResNetBottleneckStage : public fl::Sequential { - public: - ResNetBottleneckStage( - const int inChannels, - const int outChannels, - const int numBlocks, - const int stride); - - private: - ResNetBottleneckStage(); - FL_SAVE_LOAD_WITH_BASE(fl::Sequential) -}; - -class ResNetStage : public fl::Sequential { - 
public: - ResNetStage( - const int inChannels, - const int outChannels, - const int numBlocks, - const int stride); - - private: - ResNetStage(); - FL_SAVE_LOAD_WITH_BASE(fl::Sequential) -}; - -std::shared_ptr resnet34(); -std::shared_ptr resnet50(); - -} // namespace vision + namespace vision { + + class ConvBnAct : public fl::Sequential { + public: + ConvBnAct( + const int inChannels, + const int outChannels, + const int kw, + const int kh, + const int sx = 1, + const int sy = 1, + bool bn = true, + bool act = true + ); + + private: + ConvBnAct(); + FL_SAVE_LOAD_WITH_BASE(fl::Sequential) + }; + + class ResNetBlock : public fl::Container { + private: + FL_SAVE_LOAD_WITH_BASE(fl::Container) ResNetBlock(); + + public: + ResNetBlock( + const int inChannels, + const int outChannels, + const int stride = 1 + ); + + std::vector forward( + const std::vector& inputs + ) override; + + std::string prettyString() const override; + + FL_BASIC_CONTAINER_CLONING(ResNetBlock) + }; + + class ResNetBottleneckBlock : public fl::Container { + private: + FL_SAVE_LOAD_WITH_BASE(fl::Container) ResNetBottleneckBlock(); + + public: + ResNetBottleneckBlock( + const int inChannels, + const int outChannels, + const int stride = 1 + ); + + std::vector forward( + const std::vector& inputs + ) override; + + std::string prettyString() const override; + + FL_BASIC_CONTAINER_CLONING(ResNetBottleneckBlock) + }; + + class ResNetBottleneckStage : public fl::Sequential { + public: + ResNetBottleneckStage( + const int inChannels, + const int outChannels, + const int numBlocks, + const int stride + ); + + private: + ResNetBottleneckStage(); + FL_SAVE_LOAD_WITH_BASE(fl::Sequential) + }; + + class ResNetStage : public fl::Sequential { + public: + ResNetStage( + const int inChannels, + const int outChannels, + const int numBlocks, + const int stride + ); + + private: + ResNetStage(); + FL_SAVE_LOAD_WITH_BASE(fl::Sequential) + }; + + std::shared_ptr resnet34(); + std::shared_ptr resnet50(); + + } // 
namespace vision } // namespace pkg } // namespace fl CEREAL_REGISTER_TYPE(fl::pkg::vision::ConvBnAct) diff --git a/flashlight/pkg/vision/models/Resnet50Backbone.cpp b/flashlight/pkg/vision/models/Resnet50Backbone.cpp index 693c7f0..b1c8017 100644 --- a/flashlight/pkg/vision/models/Resnet50Backbone.cpp +++ b/flashlight/pkg/vision/models/Resnet50Backbone.cpp @@ -10,38 +10,40 @@ namespace fl::pkg::vision { Resnet50Backbone::Resnet50Backbone() { - Sequential backbone; - backbone.add(ConvFrozenBatchNormActivation(3, 64, 7, 7, 2, 2)); - // maxpool -> 112x122x64 -> 56x56x64 - backbone.add(Pool2D(3, 3, 2, 2, -1, -1, PoolingMode::MAX)); - // conv2_x -> 56x56x64 -> 56x56x64 - backbone.add(ResNetBottleneckStageFrozenBatchNorm(64, 64, 3, 1)); - // conv3_x -> 56x56x64 -> 28x28x128 - backbone.add(ResNetBottleneckStageFrozenBatchNorm(64 * 4, 128, 4, 2)); - // conv4_x -> 28x28x128 -> 14x14x256 - backbone.add(ResNetBottleneckStageFrozenBatchNorm(128 * 4, 256, 6, 2)); - // conv5_x -> 14x14x256 -> 7x7x256 - backbone.add(ResNetBottleneckStageFrozenBatchNorm(256 * 4, 512, 3, 2)); - add(std::move(backbone)); + Sequential backbone; + backbone.add(ConvFrozenBatchNormActivation(3, 64, 7, 7, 2, 2)); + // maxpool -> 112x122x64 -> 56x56x64 + backbone.add(Pool2D(3, 3, 2, 2, -1, -1, PoolingMode::MAX)); + // conv2_x -> 56x56x64 -> 56x56x64 + backbone.add(ResNetBottleneckStageFrozenBatchNorm(64, 64, 3, 1)); + // conv3_x -> 56x56x64 -> 28x28x128 + backbone.add(ResNetBottleneckStageFrozenBatchNorm(64 * 4, 128, 4, 2)); + // conv4_x -> 28x28x128 -> 14x14x256 + backbone.add(ResNetBottleneckStageFrozenBatchNorm(128 * 4, 256, 6, 2)); + // conv5_x -> 14x14x256 -> 7x7x256 + backbone.add(ResNetBottleneckStageFrozenBatchNorm(256 * 4, 512, 3, 2)); + add(std::move(backbone)); - Sequential tail; - tail.add(Pool2D(7, 7, 1, 1, 0, 0, fl::PoolingMode::AVG_EXCLUDE_PADDING)); - tail.add( - ConvFrozenBatchNormActivation(512 * 4, 1000, 1, 1, 1, 1, false, false)); - tail.add(View({1000, -1})); - 
tail.add(LogSoftmax()); - add(std::move(tail)); + Sequential tail; + tail.add(Pool2D(7, 7, 1, 1, 0, 0, fl::PoolingMode::AVG_EXCLUDE_PADDING)); + tail.add( + ConvFrozenBatchNormActivation(512 * 4, 1000, 1, 1, 1, 1, false, false) + ); + tail.add(View({1000, -1})); + tail.add(LogSoftmax()); + add(std::move(tail)); } std::vector Resnet50Backbone::forward( - const std::vector& input) { - const auto& features = module(0)->forward(input); - const auto& output = module(1)->forward(features); - return {output[0], features[0]}; + const std::vector& input +) { + const auto& features = module(0)->forward(input); + const auto& output = module(1)->forward(features); + return {output[0], features[0]}; } std::string Resnet50Backbone::prettyString() const { - return "Resnet50Backbone"; + return "Resnet50Backbone"; } } // namespace fl diff --git a/flashlight/pkg/vision/models/Resnet50Backbone.h b/flashlight/pkg/vision/models/Resnet50Backbone.h index 2b30e0a..85407b2 100644 --- a/flashlight/pkg/vision/models/Resnet50Backbone.h +++ b/flashlight/pkg/vision/models/Resnet50Backbone.h @@ -13,25 +13,25 @@ namespace fl { namespace pkg { -namespace vision { + namespace vision { -using namespace fl::pkg::vision; + using namespace fl::pkg::vision; -class Resnet50Backbone : public Container { - public: - Resnet50Backbone(); + class Resnet50Backbone : public Container { + public: + Resnet50Backbone(); - std::vector forward(const std::vector& input) override; + std::vector forward(const std::vector& input) override; - std::string prettyString() const override; + std::string prettyString() const override; - FL_BASIC_CONTAINER_CLONING(Resnet50Backbone) + FL_BASIC_CONTAINER_CLONING(Resnet50Backbone) - private: - FL_SAVE_LOAD_WITH_BASE(fl::Container) -}; + private: + FL_SAVE_LOAD_WITH_BASE(fl::Container) + }; -} // namespace vision + } // namespace vision } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/vision/models/ResnetFrozenBatchNorm.cpp 
b/flashlight/pkg/vision/models/ResnetFrozenBatchNorm.cpp index 51e7985..05c9425 100644 --- a/flashlight/pkg/vision/models/ResnetFrozenBatchNorm.cpp +++ b/flashlight/pkg/vision/models/ResnetFrozenBatchNorm.cpp @@ -11,15 +11,15 @@ namespace fl::pkg::vision { namespace { -Conv2D conv3x3(int inC, int outC, int stride, int groups) { - const auto pad = PaddingMode::SAME; - return Conv2D(inC, outC, 3, 3, stride, stride, pad, pad, 1, 1, false, groups); -} + Conv2D conv3x3(int inC, int outC, int stride, int groups) { + const auto pad = PaddingMode::SAME; + return Conv2D(inC, outC, 3, 3, stride, stride, pad, pad, 1, 1, false, groups); + } -Conv2D conv1x1(int inC, int outC, int stride, int groups) { - const auto pad = PaddingMode::SAME; - return Conv2D(inC, outC, 1, 1, stride, stride, pad, pad, 1, 1, false, groups); -} + Conv2D conv1x1(int inC, int outC, int stride, int groups) { + const auto pad = PaddingMode::SAME; + return Conv2D(inC, outC, 1, 1, stride, stride, pad, pad, 1, 1, false, groups); + } } // namespace @@ -33,16 +33,17 @@ ConvFrozenBatchNormActivation::ConvFrozenBatchNormActivation( const int sx, const int sy, bool bn, - bool act) { - const auto pad = PaddingMode::SAME; - const bool bias = !bn; - add(fl::Conv2D(inC, outC, kw, kh, sx, sy, pad, pad, 1, 1, bias)); - if (bn) { - add(fl::FrozenBatchNorm(2, outC)); - } - if (act) { - add(fl::ReLU()); - } + bool act +) { + const auto pad = PaddingMode::SAME; + const bool bias = !bn; + add(fl::Conv2D(inC, outC, kw, kh, sx, sy, pad, pad, 1, 1, bias)); + if(bn) { + add(fl::FrozenBatchNorm(2, outC)); + } + if(act) { + add(fl::ReLU()); + } } ResNetBlockFrozenBatchNorm::ResNetBlockFrozenBatchNorm() = default; @@ -50,133 +51,138 @@ ResNetBlockFrozenBatchNorm::ResNetBlockFrozenBatchNorm() = default; ResNetBlockFrozenBatchNorm::ResNetBlockFrozenBatchNorm( const int inC, const int outC, - const int stride) { - add(Conv2D(conv3x3(inC, outC, stride, 1))); - add(FrozenBatchNorm(FrozenBatchNorm(2, outC))); - add(ReLU()); - 
add(conv3x3(outC, outC, 1, 1)); - add(FrozenBatchNorm(2, outC)); - add(ReLU()); - if (inC != outC || stride > 1) { - Sequential downsample; - downsample.add(conv1x1(inC, outC, stride, 1)); - downsample.add(FrozenBatchNorm(2, outC)); - add(std::move(downsample)); - } + const int stride +) { + add(Conv2D(conv3x3(inC, outC, stride, 1))); + add(FrozenBatchNorm(FrozenBatchNorm(2, outC))); + add(ReLU()); + add(conv3x3(outC, outC, 1, 1)); + add(FrozenBatchNorm(2, outC)); + add(ReLU()); + if(inC != outC || stride > 1) { + Sequential downsample; + downsample.add(conv1x1(inC, outC, stride, 1)); + downsample.add(FrozenBatchNorm(2, outC)); + add(std::move(downsample)); + } } ResNetBottleneckBlockFrozenBatchNorm::ResNetBottleneckBlockFrozenBatchNorm() = - default; +default; ResNetBottleneckBlockFrozenBatchNorm::ResNetBottleneckBlockFrozenBatchNorm( const int inC, const int planes, - const int stride) { - const int expansionFactor = 4; - add(conv1x1(inC, planes, 1, 1)); - add(FrozenBatchNorm(FrozenBatchNorm(2, planes))); - add(ReLU()); - add(Conv2D(conv3x3(planes, planes, stride, 1))); - add(FrozenBatchNorm(FrozenBatchNorm(2, planes))); - add(ReLU()); - add(conv1x1(planes, planes * expansionFactor, 1, 1)); - add(FrozenBatchNorm(2, planes * expansionFactor)); - add(ReLU()); - if (inC != planes * expansionFactor || stride > 1) { - Sequential downsample; - downsample.add(conv1x1(inC, planes * expansionFactor, stride, 1)); - downsample.add(FrozenBatchNorm(2, planes * expansionFactor)); - add(std::move(downsample)); - } + const int stride +) { + const int expansionFactor = 4; + add(conv1x1(inC, planes, 1, 1)); + add(FrozenBatchNorm(FrozenBatchNorm(2, planes))); + add(ReLU()); + add(Conv2D(conv3x3(planes, planes, stride, 1))); + add(FrozenBatchNorm(FrozenBatchNorm(2, planes))); + add(ReLU()); + add(conv1x1(planes, planes * expansionFactor, 1, 1)); + add(FrozenBatchNorm(2, planes * expansionFactor)); + add(ReLU()); + if(inC != planes * expansionFactor || stride > 1) { + Sequential 
downsample; + downsample.add(conv1x1(inC, planes * expansionFactor, stride, 1)); + downsample.add(FrozenBatchNorm(2, planes * expansionFactor)); + add(std::move(downsample)); + } } std::vector ResNetBottleneckBlockFrozenBatchNorm::forward( - const std::vector& inputs) { - const auto& c1 = module(0); - const auto& bn1 = module(1); - const auto& relu1 = module(2); - const auto& c2 = module(3); - const auto& bn2 = module(4); - const auto& relu2 = module(5); - const auto& c3 = module(6); - const auto& bn3 = module(7); - const auto& relu3 = module(8); - - std::vector out; - out = c1->forward(inputs); - out = bn1->forward(out); - - out = relu1->forward(out); - - out = c2->forward(out); - out = bn2->forward(out); - out = relu2->forward(out); - - out = c3->forward(out); - out = bn3->forward(out); - - std::vector shortcut; - if (modules().size() > 9) { - shortcut = module(9)->forward(inputs); - } else { - shortcut = inputs; - } - return relu3->forward({out[0] + shortcut[0]}); + const std::vector& inputs +) { + const auto& c1 = module(0); + const auto& bn1 = module(1); + const auto& relu1 = module(2); + const auto& c2 = module(3); + const auto& bn2 = module(4); + const auto& relu2 = module(5); + const auto& c3 = module(6); + const auto& bn3 = module(7); + const auto& relu3 = module(8); + + std::vector out; + out = c1->forward(inputs); + out = bn1->forward(out); + + out = relu1->forward(out); + + out = c2->forward(out); + out = bn2->forward(out); + out = relu2->forward(out); + + out = c3->forward(out); + out = bn3->forward(out); + + std::vector shortcut; + if(modules().size() > 9) { + shortcut = module(9)->forward(inputs); + } else { + shortcut = inputs; + } + return relu3->forward({out[0] + shortcut[0]}); } std::string ResNetBottleneckBlockFrozenBatchNorm::prettyString() const { - std::ostringstream ss; - ss << "ResnetBottleneckBlockFrozenBn"; - ss << Container::prettyString(); - return ss.str(); + std::ostringstream ss; + ss << "ResnetBottleneckBlockFrozenBn"; + ss << 
Container::prettyString(); + return ss.str(); } std::vector ResNetBlockFrozenBatchNorm::forward( - const std::vector& inputs) { - const auto& c1 = module(0); - const auto& bn1 = module(1); - const auto& relu1 = module(2); - const auto& c2 = module(3); - const auto& bn2 = module(4); - const auto& relu2 = module(5); - std::vector out; - out = c1->forward(inputs); - out = bn1->forward(out); - out = relu1->forward(out); - out = c2->forward(out); - out = bn2->forward(out); - - std::vector shortcut; - if (modules().size() > 6) { - shortcut = module(6)->forward(inputs); - } else { - shortcut = inputs; - } - return relu2->forward({out[0] + shortcut[0]}); + const std::vector& inputs +) { + const auto& c1 = module(0); + const auto& bn1 = module(1); + const auto& relu1 = module(2); + const auto& c2 = module(3); + const auto& bn2 = module(4); + const auto& relu2 = module(5); + std::vector out; + out = c1->forward(inputs); + out = bn1->forward(out); + out = relu1->forward(out); + out = c2->forward(out); + out = bn2->forward(out); + + std::vector shortcut; + if(modules().size() > 6) { + shortcut = module(6)->forward(inputs); + } else { + shortcut = inputs; + } + return relu2->forward({out[0] + shortcut[0]}); } std::string ResNetBlockFrozenBatchNorm::prettyString() const { - std::ostringstream ss; - ss << "ResnetBlockFrozenBn"; - ss << Container::prettyString(); - return ss.str(); + std::ostringstream ss; + ss << "ResnetBlockFrozenBn"; + ss << Container::prettyString(); + return ss.str(); } ResNetBottleneckStageFrozenBatchNorm::ResNetBottleneckStageFrozenBatchNorm( const int inC, const int outC, const int numBlocks, - const int stride) { - add(ResNetBottleneckBlockFrozenBatchNorm(inC, outC, stride)); - const int expansionFactor = 4; - const int inPlanes = outC * expansionFactor; - for (int i = 1; i < numBlocks; i++) { - add(ResNetBottleneckBlockFrozenBatchNorm(inPlanes, outC)); - } + const int stride +) { + add(ResNetBottleneckBlockFrozenBatchNorm(inC, outC, stride)); + const int 
expansionFactor = 4; + const int inPlanes = outC * expansionFactor; + for(int i = 1; i < numBlocks; i++) { + add(ResNetBottleneckBlockFrozenBatchNorm(inPlanes, outC)); + } }; ResNetBottleneckStageFrozenBatchNorm::ResNetBottleneckStageFrozenBatchNorm() = - default; +default; ResNetStageFrozenBatchNorm::ResNetStageFrozenBatchNorm() = default; @@ -184,11 +190,12 @@ ResNetStageFrozenBatchNorm::ResNetStageFrozenBatchNorm( const int inC, const int outC, const int numBlocks, - const int stride) { - add(ResNetBlockFrozenBatchNorm(inC, outC, stride)); - for (int i = 1; i < numBlocks; i++) { - add(ResNetBlockFrozenBatchNorm(outC, outC)); - } + const int stride +) { + add(ResNetBlockFrozenBatchNorm(inC, outC, stride)); + for(int i = 1; i < numBlocks; i++) { + add(ResNetBlockFrozenBatchNorm(outC, outC)); + } } } // namespace fl diff --git a/flashlight/pkg/vision/models/ResnetFrozenBatchNorm.h b/flashlight/pkg/vision/models/ResnetFrozenBatchNorm.h index 5c6eee2..7514793 100644 --- a/flashlight/pkg/vision/models/ResnetFrozenBatchNorm.h +++ b/flashlight/pkg/vision/models/ResnetFrozenBatchNorm.h @@ -12,7 +12,7 @@ namespace fl { namespace pkg { -namespace vision { + namespace vision { // Note these are identical to those in Resnet.h. There are a number of ways to // refactor and consolidate including passing norm factory functions to the @@ -20,86 +20,95 @@ namespace vision { // the default Resnet implementation dead simple, we are recreating a lot // of functionality here. 
-class ConvFrozenBatchNormActivation : public fl::Sequential { - public: - ConvFrozenBatchNormActivation( - const int inChannels, - const int outChannels, - const int kw, - const int kh, - const int sx = 1, - const int sy = 1, - bool bn = true, - bool act = true); - - private: - ConvFrozenBatchNormActivation(); - FL_SAVE_LOAD_WITH_BASE(fl::Sequential) -}; - -class ResNetBlockFrozenBatchNorm : public fl::Container { - private: - ResNetBlockFrozenBatchNorm(); - FL_SAVE_LOAD_WITH_BASE(fl::Container) - public: - ResNetBlockFrozenBatchNorm( - const int inChannels, - const int outChannels, - const int stride = 1); - - std::vector forward( - const std::vector& inputs) override; - - std::string prettyString() const override; - - FL_BASIC_CONTAINER_CLONING(ResNetBlockFrozenBatchNorm) -}; - -class ResNetBottleneckBlockFrozenBatchNorm : public fl::Container { - private: - ResNetBottleneckBlockFrozenBatchNorm(); - FL_SAVE_LOAD_WITH_BASE(fl::Container) - public: - ResNetBottleneckBlockFrozenBatchNorm( - const int inChannels, - const int outChannels, - const int stride = 1); - - std::vector forward( - const std::vector& inputs) override; - - std::string prettyString() const override; - - FL_BASIC_CONTAINER_CLONING(ResNetBottleneckBlockFrozenBatchNorm) -}; - -class ResNetBottleneckStageFrozenBatchNorm : public fl::Sequential { - public: - ResNetBottleneckStageFrozenBatchNorm( - const int inChannels, - const int outChannels, - const int numBlocks, - const int stride); - - private: - ResNetBottleneckStageFrozenBatchNorm(); - FL_SAVE_LOAD_WITH_BASE(fl::Sequential) -}; - -class ResNetStageFrozenBatchNorm : public fl::Sequential { - public: - ResNetStageFrozenBatchNorm( - const int inChannels, - const int outChannels, - const int numBlocks, - const int stride); - - private: - ResNetStageFrozenBatchNorm(); - FL_SAVE_LOAD_WITH_BASE(fl::Sequential) -}; - -} // namespace vision + class ConvFrozenBatchNormActivation : public fl::Sequential { + public: + ConvFrozenBatchNormActivation( + 
const int inChannels, + const int outChannels, + const int kw, + const int kh, + const int sx = 1, + const int sy = 1, + bool bn = true, + bool act = true + ); + + private: + ConvFrozenBatchNormActivation(); + FL_SAVE_LOAD_WITH_BASE(fl::Sequential) + }; + + class ResNetBlockFrozenBatchNorm : public fl::Container { + private: + ResNetBlockFrozenBatchNorm(); + FL_SAVE_LOAD_WITH_BASE(fl::Container) + + public: + ResNetBlockFrozenBatchNorm( + const int inChannels, + const int outChannels, + const int stride = 1 + ); + + std::vector forward( + const std::vector& inputs + ) override; + + std::string prettyString() const override; + + FL_BASIC_CONTAINER_CLONING(ResNetBlockFrozenBatchNorm) + }; + + class ResNetBottleneckBlockFrozenBatchNorm : public fl::Container { + private: + ResNetBottleneckBlockFrozenBatchNorm(); + FL_SAVE_LOAD_WITH_BASE(fl::Container) + + public: + ResNetBottleneckBlockFrozenBatchNorm( + const int inChannels, + const int outChannels, + const int stride = 1 + ); + + std::vector forward( + const std::vector& inputs + ) override; + + std::string prettyString() const override; + + FL_BASIC_CONTAINER_CLONING(ResNetBottleneckBlockFrozenBatchNorm) + }; + + class ResNetBottleneckStageFrozenBatchNorm : public fl::Sequential { + public: + ResNetBottleneckStageFrozenBatchNorm( + const int inChannels, + const int outChannels, + const int numBlocks, + const int stride + ); + + private: + ResNetBottleneckStageFrozenBatchNorm(); + FL_SAVE_LOAD_WITH_BASE(fl::Sequential) + }; + + class ResNetStageFrozenBatchNorm : public fl::Sequential { + public: + ResNetStageFrozenBatchNorm( + const int inChannels, + const int outChannels, + const int numBlocks, + const int stride + ); + + private: + ResNetStageFrozenBatchNorm(); + FL_SAVE_LOAD_WITH_BASE(fl::Sequential) + }; + + } // namespace vision } // namespace pkg } // namespace fl CEREAL_REGISTER_TYPE(fl::pkg::vision::ConvFrozenBatchNormActivation) diff --git a/flashlight/pkg/vision/models/ViT.cpp 
b/flashlight/pkg/vision/models/ViT.cpp index 7c45a1b..eec5098 100644 --- a/flashlight/pkg/vision/models/ViT.cpp +++ b/flashlight/pkg/vision/models/ViT.cpp @@ -20,127 +20,136 @@ ViT::ViT( const int nHeads, const float pDropout, const float pLayerDrop, - const int nClasses) - : nLayers_(nLayers), - hiddenEmbSize_(hiddenEmbSize), - mlpSize_(mlpSize), - nHeads_(nHeads), - pDropout_(pDropout), - nClasses_(nClasses), - patchEmbedding_( - std::make_shared(3, hiddenEmbSize_, 16, 16, 16, 16)) { - // Class token - params_.emplace_back(fl::truncNormal({hiddenEmbSize_, 1}, 0.02)); - - // Positional embedding - params_.emplace_back(fl::truncNormal({hiddenEmbSize_, 14 * 14 + 1}, 0.02)); - - // Modules - add(patchEmbedding_); - - for (int i = 0; i < nLayers_; ++i) { - transformers_.emplace_back(std::make_shared( - hiddenEmbSize_, - hiddenEmbSize_ / nHeads_, - mlpSize_, - nHeads_, - pDropout_, - pLayerDrop * i / (nLayers_ - 1))); - add(transformers_.back()); - } - - linearOut_ = std::make_shared( - fl::truncNormal({nClasses_, hiddenEmbSize_}, 0.02), - fl::constant(0., nClasses_, 1, fl::dtype::f32)); - add(linearOut_); - - ln_ = std::make_shared( - std::vector({0}), 1e-6, true, hiddenEmbSize_); - add(ln_); + const int nClasses +) : nLayers_(nLayers), + hiddenEmbSize_(hiddenEmbSize), + mlpSize_(mlpSize), + nHeads_(nHeads), + pDropout_(pDropout), + nClasses_(nClasses), + patchEmbedding_( + std::make_shared(3, hiddenEmbSize_, 16, 16, 16, 16)) { + // Class token + params_.emplace_back(fl::truncNormal({hiddenEmbSize_, 1}, 0.02)); + + // Positional embedding + params_.emplace_back(fl::truncNormal({hiddenEmbSize_, 14 * 14 + 1}, 0.02)); + + // Modules + add(patchEmbedding_); + + for(int i = 0; i < nLayers_; ++i) { + transformers_.emplace_back( + std::make_shared( + hiddenEmbSize_, + hiddenEmbSize_ / nHeads_, + mlpSize_, + nHeads_, + pDropout_, + pLayerDrop * i / (nLayers_ - 1) + ) + ); + add(transformers_.back()); + } + + linearOut_ = std::make_shared( + fl::truncNormal({nClasses_, 
hiddenEmbSize_}, 0.02), + fl::constant(0., nClasses_, 1, fl::dtype::f32) + ); + add(linearOut_); + + ln_ = std::make_shared( + std::vector({0}), + 1e-6, + true, + hiddenEmbSize_ + ); + add(ln_); } void ViT::copy(const ViT& other) { - clear(); + clear(); - // Class token - auto clsTkn = other.param(0); - params_.emplace_back(clsTkn.copy()); + // Class token + auto clsTkn = other.param(0); + params_.emplace_back(clsTkn.copy()); - // Positional embedding - auto posEmb = other.param(1); - params_.emplace_back(posEmb.copy()); + // Positional embedding + auto posEmb = other.param(1); + params_.emplace_back(posEmb.copy()); - // Modules - patchEmbedding_ = std::make_shared(*other.patchEmbedding_); - add(patchEmbedding_); + // Modules + patchEmbedding_ = std::make_shared(*other.patchEmbedding_); + add(patchEmbedding_); - for (const auto& vit : other.transformers_) { - transformers_.emplace_back(std::make_shared(*vit)); - add(transformers_.back()); - } + for(const auto& vit : other.transformers_) { + transformers_.emplace_back(std::make_shared(*vit)); + add(transformers_.back()); + } - linearOut_ = std::make_shared(*other.linearOut_); - add(linearOut_); + linearOut_ = std::make_shared(*other.linearOut_); + add(linearOut_); - ln_ = std::make_shared(*other.ln_); - add(ln_); + ln_ = std::make_shared(*other.ln_); + add(ln_); } ViT::ViT(const ViT& other) { - copy(other); + copy(other); } ViT& ViT::operator=(const ViT& other) { - copy(other); - return *this; + copy(other); + return *this; } std::unique_ptr ViT::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::vector ViT::forward( - const std::vector& inputs) { - // Patching - auto output = patchEmbedding_->forward(inputs[0]); // H x W x C x B - output = moddims(output, {-1, 1, 0, 0}); // T x 1 x C x B - output = reorder(output, {2, 0, 3, 1}); // C x T x B x 1 - output = moddims(output, {0, 0, 0}); // C x T x B - auto B = output.dim(2); - - // Prepending the class token - auto clsToken = - 
tile(params_[0], {1, 1, B}).astype(output.type()); // C x 1 x B - output = concatenate({clsToken, output}, 1); - - // Positional embedding - auto posEmb = tile(params_[1], {1, 1, B}).astype(output.type()); - output = output + posEmb; - if (train_) { - output = dropout(output, pDropout_); - } - - // Transformers - for (int i = 0; i < nLayers_; ++i) { - output = transformers_[i]->forward({output}).front(); - } - - // Linear - output = ln_->forward(output); // C x T x B - output = reorder(output, {0, 2, 1})(fl::span, fl::span, 0); // C x B x 1 - output = linearOut_->forward(output); - - return {output}; + const std::vector& inputs +) { + // Patching + auto output = patchEmbedding_->forward(inputs[0]); // H x W x C x B + output = moddims(output, {-1, 1, 0, 0}); // T x 1 x C x B + output = reorder(output, {2, 0, 3, 1}); // C x T x B x 1 + output = moddims(output, {0, 0, 0}); // C x T x B + auto B = output.dim(2); + + // Prepending the class token + auto clsToken = + tile(params_[0], {1, 1, B}).astype(output.type()); // C x 1 x B + output = concatenate({clsToken, output}, 1); + + // Positional embedding + auto posEmb = tile(params_[1], {1, 1, B}).astype(output.type()); + output = output + posEmb; + if(train_) { + output = dropout(output, pDropout_); + } + + // Transformers + for(int i = 0; i < nLayers_; ++i) { + output = transformers_[i]->forward({output}).front(); + } + + // Linear + output = ln_->forward(output); // C x T x B + output = reorder(output, {0, 2, 1})(fl::span, fl::span, 0); // C x B x 1 + output = linearOut_->forward(output); + + return {output}; } std::string ViT::prettyString() const { - std::ostringstream ss; - ss << "ViT (" << nClasses_ << " classes) with " << nLayers_ - << " Transformers:\n"; - for (const auto& transformers : transformers_) { - ss << transformers->prettyString() << "\n"; - } - return ss.str(); + std::ostringstream ss; + ss << "ViT (" << nClasses_ << " classes) with " << nLayers_ + << " Transformers:\n"; + for(const auto& transformers 
: transformers_) { + ss << transformers->prettyString() << "\n"; + } + return ss.str(); } } // namespace fl diff --git a/flashlight/pkg/vision/models/ViT.h b/flashlight/pkg/vision/models/ViT.h index 836db7c..13b848f 100644 --- a/flashlight/pkg/vision/models/ViT.h +++ b/flashlight/pkg/vision/models/ViT.h @@ -12,7 +12,7 @@ namespace fl { namespace pkg { -namespace vision { + namespace vision { /* * Implementation of Vision Transformer (ViT) models following [AN IMAGE IS @@ -21,57 +21,61 @@ namespace vision { * * This implementation is highly inspired by [timm](https://git.io/JYOql). */ -class ViT : public fl::Container { - private: - FL_SAVE_LOAD_WITH_BASE( - fl::Container, - nLayers_, - hiddenEmbSize_, - mlpSize_, - nHeads_, - pDropout_, - patchEmbedding_, - transformers_, - linearOut_, - ln_) + class ViT : public fl::Container { + private: + FL_SAVE_LOAD_WITH_BASE( + fl::Container, + nLayers_, + hiddenEmbSize_, + mlpSize_, + nHeads_, + pDropout_, + patchEmbedding_, + transformers_, + linearOut_, + ln_ + ) - int nLayers_; - int hiddenEmbSize_; - int mlpSize_; - int nHeads_; - float pDropout_; - int nClasses_; + int nLayers_; + int hiddenEmbSize_; + int mlpSize_; + int nHeads_; + float pDropout_; + int nClasses_; - std::shared_ptr patchEmbedding_; - std::vector> transformers_; - std::shared_ptr linearOut_; - std::shared_ptr ln_; + std::shared_ptr patchEmbedding_; + std::vector> transformers_; + std::shared_ptr linearOut_; + std::shared_ptr ln_; - ViT() = default; - void copy(const ViT& other); + ViT() = default; + void copy(const ViT& other); - public: - ViT(const int nLayers, - const int hiddenEmbSize, - const int mlpSize, - const int nHeads, - const float pDropout, - const float pLayerDrop, - const int nClasses); + public: + ViT( + const int nLayers, + const int hiddenEmbSize, + const int mlpSize, + const int nHeads, + const float pDropout, + const float pLayerDrop, + const int nClasses + ); - ViT(const ViT& other); - ViT& operator=(const ViT& other); - ViT(ViT&& 
other) = default; - ViT& operator=(ViT&& other) = default; + ViT(const ViT& other); + ViT& operator=(const ViT& other); + ViT(ViT&& other) = default; + ViT& operator=(ViT&& other) = default; - std::vector forward( - const std::vector& inputs) override; + std::vector forward( + const std::vector& inputs + ) override; - std::unique_ptr clone() const override; - std::string prettyString() const override; -}; + std::unique_ptr clone() const override; + std::string prettyString() const override; + }; -} // namespace vision + } // namespace vision } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/vision/nn/FrozenBatchNorm.cpp b/flashlight/pkg/vision/nn/FrozenBatchNorm.cpp index 8006908..99544af 100644 --- a/flashlight/pkg/vision/nn/FrozenBatchNorm.cpp +++ b/flashlight/pkg/vision/nn/FrozenBatchNorm.cpp @@ -18,14 +18,14 @@ FrozenBatchNorm::FrozenBatchNorm( double momentum /* = 0.1 */, double eps /* = 1e-5*/, bool affine /* = true*/, - bool trackStats /* = true*/) - : FrozenBatchNorm( - std::vector(1, featAxis), - featSize, - momentum, - eps, - affine, - trackStats) {} + bool trackStats /* = true*/ +) : FrozenBatchNorm( + std::vector(1, featAxis), + featSize, + momentum, + eps, + affine, + trackStats) {} FrozenBatchNorm::FrozenBatchNorm( const std::vector& featAxis, @@ -33,49 +33,49 @@ FrozenBatchNorm::FrozenBatchNorm( double momentum /* = 0.1*/, double eps /* = 1e-5 */, bool affine /* = true*/, - bool trackStats /* = true*/) - : BatchNorm(featAxis, featSize, momentum, eps, affine, trackStats) { - BatchNorm::initialize(); + bool trackStats /* = true*/ +) : BatchNorm(featAxis, featSize, momentum, eps, affine, trackStats) { + BatchNorm::initialize(); } std::unique_ptr FrozenBatchNorm::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } Variable FrozenBatchNorm::forward(const Variable& input) { - auto scale = params_[0] / fl::sqrt(runningVar_ + epsilon_); - auto bias = params_[1] - runningMean_ * scale; - bias = 
fl::moddims(bias, {1, 1, bias.dim(0), 1}).astype(input.type()); - scale = fl::moddims(scale, {1, 1, scale.dim(0), 1}).astype(input.type()); - return (input * fl::tileAs(scale, input)) + fl::tileAs(bias, input); + auto scale = params_[0] / fl::sqrt(runningVar_ + epsilon_); + auto bias = params_[1] - runningMean_ * scale; + bias = fl::moddims(bias, {1, 1, bias.dim(0), 1}).astype(input.type()); + scale = fl::moddims(scale, {1, 1, scale.dim(0), 1}).astype(input.type()); + return (input * fl::tileAs(scale, input)) + fl::tileAs(bias, input); } void FrozenBatchNorm::setRunningMean(const fl::Variable& x) { - runningMean_ = x; + runningMean_ = x; } void FrozenBatchNorm::setRunningVar(const fl::Variable& x) { - runningVar_ = x; + runningVar_ = x; } void FrozenBatchNorm::train() { - for (auto& param : params_) { - param.setCalcGrad(false); - } - runningVar_.setCalcGrad(false); - runningMean_.setCalcGrad(false); - train_ = false; + for(auto& param : params_) { + param.setCalcGrad(false); + } + runningVar_.setCalcGrad(false); + runningMean_.setCalcGrad(false); + train_ = false; } std::string FrozenBatchNorm::prettyString() const { - std::ostringstream ss; - ss << "FrozenBatchNorm"; - ss << " ( axis : { "; - for (auto x : featAxis_) { - ss << x << " "; - } - ss << "}, size : " << featSize_ << " )"; - return ss.str(); + std::ostringstream ss; + ss << "FrozenBatchNorm"; + ss << " ( axis : { "; + for(auto x : featAxis_) { + ss << x << " "; + } + ss << "}, size : " << featSize_ << " )"; + return ss.str(); } } // namespace fl diff --git a/flashlight/pkg/vision/nn/FrozenBatchNorm.h b/flashlight/pkg/vision/nn/FrozenBatchNorm.h index 79556b5..50d9750 100644 --- a/flashlight/pkg/vision/nn/FrozenBatchNorm.h +++ b/flashlight/pkg/vision/nn/FrozenBatchNorm.h @@ -24,81 +24,83 @@ namespace fl { * \f$\beta\f$ are learnable parameters for affine transformation. 
*/ class FrozenBatchNorm : public BatchNorm { - private: - FrozenBatchNorm() = default; // intentionally private - FL_SAVE_LOAD_WITH_BASE(BatchNorm) +private: + FrozenBatchNorm() = default; // intentionally private + FL_SAVE_LOAD_WITH_BASE(BatchNorm) - public: - /** - * Constructs a FrozenBatchNorm module. - * - * @param featAxis the axis over which normalizationis performed - * @param featSize the size of the dimension along `featAxis` - * @param momentum an exponential average factor used to compute running mean - * and variance. - * \f[ runningMean = runningMean \times (1-momentum) - * + newMean \times momentum \f] - * If < 0, cumulative moving average is used. - * @param eps \f$\epsilon\f$ - * @param affine a boolean value that controls the learning of \f$\gamma\f$ - * and \f$\beta\f$. \f$\gamma\f$ and \f$\beta\f$ are set to 1, 0 respectively - * if set to `false`, or initialized as learnable parameters - * if set to `true`. - * @param trackStats a boolean value that controls whether to track the - * running mean and variance while in train mode. If `false`, batch - * statistics are used to perform normalization in both train and eval mode. - */ - FrozenBatchNorm( - int featAxis, - int featSize, - double momentum = 0.1, - double eps = 1e-5, - bool affine = true, - bool trackStats = true); +public: + /** + * Constructs a FrozenBatchNorm module. + * + * @param featAxis the axis over which normalizationis performed + * @param featSize the size of the dimension along `featAxis` + * @param momentum an exponential average factor used to compute running mean + * and variance. + * \f[ runningMean = runningMean \times (1-momentum) + * + newMean \times momentum \f] + * If < 0, cumulative moving average is used. + * @param eps \f$\epsilon\f$ + * @param affine a boolean value that controls the learning of \f$\gamma\f$ + * and \f$\beta\f$. 
\f$\gamma\f$ and \f$\beta\f$ are set to 1, 0 respectively + * if set to `false`, or initialized as learnable parameters + * if set to `true`. + * @param trackStats a boolean value that controls whether to track the + * running mean and variance while in train mode. If `false`, batch + * statistics are used to perform normalization in both train and eval mode. + */ + FrozenBatchNorm( + int featAxis, + int featSize, + double momentum = 0.1, + double eps = 1e-5, + bool affine = true, + bool trackStats = true + ); - /** - * Constructs a FrozenBatchNorm module. - * - * @param featAxis the axis over which normalization is performed - * @param featSize total dimension along `featAxis`. - * For example, to perform Temporal Batch Normalization on input of size - * [\f$L\f$, \f$C\f$, \f$N\f$], use `featAxis` = {1}, `featSize` = \f$C\f$. - * To perform normalization per activation on input of size - * [\f$W\f$, \f$H\f$, \f$C\f$, \f$N\f$], use `featAxis` = {0, 1, 2}, - * `featSize` = \f$W \times H \times C\f$. - * @param momentum an exponential average factor used to compute running mean - * and variance. - * \f[ runningMean = runningMean \times (1-momentum) - * + newMean \times momentum \f] - * If < 0, cumulative moving average is used. - * @param eps \f$\epsilon\f$ - * @param affine a boolean value that controls the learning of \f$\gamma\f$ - * and \f$\beta\f$. \f$\gamma\f$ and \f$\beta\f$ are set to 1, 0 respectively - * if set to `false`, or initialized as learnable parameters - * if set to `true`. - * @param trackStats a boolean value that controls whether to track the - * running mean and variance while in train mode. If `false`, batch - * statistics are used to perform normalization in both train and eval mode. - */ - FrozenBatchNorm( - const std::vector& featAxis, - int featSize, - double momentum = 0.1, - double eps = 1e-5, - bool affine = true, - bool trackStats = true); + /** + * Constructs a FrozenBatchNorm module. 
+ * + * @param featAxis the axis over which normalization is performed + * @param featSize total dimension along `featAxis`. + * For example, to perform Temporal Batch Normalization on input of size + * [\f$L\f$, \f$C\f$, \f$N\f$], use `featAxis` = {1}, `featSize` = \f$C\f$. + * To perform normalization per activation on input of size + * [\f$W\f$, \f$H\f$, \f$C\f$, \f$N\f$], use `featAxis` = {0, 1, 2}, + * `featSize` = \f$W \times H \times C\f$. + * @param momentum an exponential average factor used to compute running mean + * and variance. + * \f[ runningMean = runningMean \times (1-momentum) + * + newMean \times momentum \f] + * If < 0, cumulative moving average is used. + * @param eps \f$\epsilon\f$ + * @param affine a boolean value that controls the learning of \f$\gamma\f$ + * and \f$\beta\f$. \f$\gamma\f$ and \f$\beta\f$ are set to 1, 0 respectively + * if set to `false`, or initialized as learnable parameters + * if set to `true`. + * @param trackStats a boolean value that controls whether to track the + * running mean and variance while in train mode. If `false`, batch + * statistics are used to perform normalization in both train and eval mode. 
+ */ + FrozenBatchNorm( + const std::vector& featAxis, + int featSize, + double momentum = 0.1, + double eps = 1e-5, + bool affine = true, + bool trackStats = true + ); - std::unique_ptr clone() const override; + std::unique_ptr clone() const override; - Variable forward(const Variable& input) override; + Variable forward(const Variable& input) override; - void setRunningVar(const fl::Variable& x); + void setRunningVar(const fl::Variable& x); - void setRunningMean(const fl::Variable& x); + void setRunningMean(const fl::Variable& x); - void train() override; + void train() override; - std::string prettyString() const override; + std::string prettyString() const override; }; } // namespace fl diff --git a/flashlight/pkg/vision/nn/PositionalEmbeddingSine.cpp b/flashlight/pkg/vision/nn/PositionalEmbeddingSine.cpp index 7dce07d..43799e5 100644 --- a/flashlight/pkg/vision/nn/PositionalEmbeddingSine.cpp +++ b/flashlight/pkg/vision/nn/PositionalEmbeddingSine.cpp @@ -14,108 +14,113 @@ namespace fl::pkg::vision { std::string PositionalEmbeddingSine::prettyString() const { - return "PositionalEmbeddingSine"; + return "PositionalEmbeddingSine"; } PositionalEmbeddingSine::PositionalEmbeddingSine( const int numPosFeats, const int temperature, const bool normalize, - const float scale) - : numPosFeats_(numPosFeats), - temperature_(temperature), - normalize_(normalize), - scale_(scale){}; + const float scale +) : numPosFeats_(numPosFeats), + temperature_(temperature), + normalize_(normalize), + scale_(scale) {}; PositionalEmbeddingSine::PositionalEmbeddingSine( - const PositionalEmbeddingSine& other) - : numPosFeats_(other.numPosFeats_), - temperature_(other.temperature_), - normalize_(other.normalize_), - scale_(other.scale_) { - train_ = other.train_; - for (auto& mod : other.modules_) { - add(mod->clone()); - } + const PositionalEmbeddingSine& other +) : numPosFeats_(other.numPosFeats_), + temperature_(other.temperature_), + normalize_(other.normalize_), + scale_(other.scale_) 
{ + train_ = other.train_; + for(auto& mod : other.modules_) { + add(mod->clone()); + } } PositionalEmbeddingSine& PositionalEmbeddingSine::operator=( - const PositionalEmbeddingSine& other) { - train_ = other.train_; - numPosFeats_ = other.numPosFeats_; - temperature_ = other.temperature_; - normalize_ = other.normalize_; - scale_ = other.scale_; - clear(); - for (auto& mod : other.modules_) { - add(mod->clone()); - } - return *this; + const PositionalEmbeddingSine& other +) { + train_ = other.train_; + numPosFeats_ = other.numPosFeats_; + temperature_ = other.temperature_; + normalize_ = other.normalize_; + scale_ = other.scale_; + clear(); + for(auto& mod : other.modules_) { + add(mod->clone()); + } + return *this; } std::unique_ptr PositionalEmbeddingSine::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::vector PositionalEmbeddingSine::forward( - const std::vector& inputs) { - assert(inputs.size() == 1); - auto input = inputs[0]; - - auto inputDims = input.shape(); - // Input mask will be [ w x h x 1 x b ] - // but implementation expects [ w x h x b ] in order to do interleaves easier - auto nonMask = fl::reshape( - input.tensor(), {inputDims[0], inputDims[1], inputDims[3], 1}); - - auto expandDims = [](const Tensor& in) { - auto dims = in.shape(); - assert(dims[3] == 1); - return fl::reshape(in, {1, dims[0], dims[1], dims[2]}); - }; - - auto interleave = [](Tensor x, Tensor y) { - auto dims = x.shape(); - x = x.flatten(); - y = y.flatten(); - x = fl::reshape(x, {1, x.dim(0)}); - y = fl::reshape(y, {1, y.dim(0)}); - auto joined = fl::concatenate(0, x, y); - dims[0] = dims[0] * 2; - return fl::reshape(joined, dims); - }; - - Tensor xEmbed = fl::cumsum(nonMask, 0); - Tensor yEmbed = fl::cumsum(nonMask, 1); - if (normalize_) { - const float eps = 1e-6; - yEmbed = - (yEmbed / yEmbed(fl::span, yEmbed.dim(1) - 1, fl::span) + eps) * scale_; - xEmbed = - (xEmbed / xEmbed(xEmbed.dim(0) - 1, fl::span, fl::span) + eps) * scale_; - 
} - - auto dim = fl::arange({numPosFeats_}, 0, fl::dtype::f32); - dim = fl::power(temperature_, ((2 * fl::floor(dim / 2)) / numPosFeats_)); - - auto posX = expandDims(xEmbed) / dim; - auto posY = expandDims(yEmbed) / dim; - - auto posXSin = fl::sin(posX(fl::range(0, fl::end, 2), fl::span)); - auto posXCos = fl::cos(posX(fl::range(1, fl::end, 2), fl::span)); - auto posYSin = fl::sin(posY(fl::range(0, fl::end, 2), fl::span)); - auto posYCos = fl::cos(posY(fl::range(1, fl::end, 2), fl::span)); - - posX = interleave(posXSin, posXCos); - posY = interleave(posYSin, posYCos); - auto result = fl::concatenate(0, posY, posX); - result = fl::transpose(result, {1, 2, 0, 3}); - return {fl::Variable(result, false)}; + const std::vector& inputs +) { + assert(inputs.size() == 1); + auto input = inputs[0]; + + auto inputDims = input.shape(); + // Input mask will be [ w x h x 1 x b ] + // but implementation expects [ w x h x b ] in order to do interleaves easier + auto nonMask = fl::reshape( + input.tensor(), + {inputDims[0], inputDims[1], inputDims[3], 1} + ); + + auto expandDims = [](const Tensor& in) { + auto dims = in.shape(); + assert(dims[3] == 1); + return fl::reshape(in, {1, dims[0], dims[1], dims[2]}); + }; + + auto interleave = [](Tensor x, Tensor y) { + auto dims = x.shape(); + x = x.flatten(); + y = y.flatten(); + x = fl::reshape(x, {1, x.dim(0)}); + y = fl::reshape(y, {1, y.dim(0)}); + auto joined = fl::concatenate(0, x, y); + dims[0] = dims[0] * 2; + return fl::reshape(joined, dims); + }; + + Tensor xEmbed = fl::cumsum(nonMask, 0); + Tensor yEmbed = fl::cumsum(nonMask, 1); + if(normalize_) { + const float eps = 1e-6; + yEmbed = + (yEmbed / yEmbed(fl::span, yEmbed.dim(1) - 1, fl::span) + eps) * scale_; + xEmbed = + (xEmbed / xEmbed(xEmbed.dim(0) - 1, fl::span, fl::span) + eps) * scale_; + } + + auto dim = fl::arange({numPosFeats_}, 0, fl::dtype::f32); + dim = fl::power(temperature_, ((2 * fl::floor(dim / 2)) / numPosFeats_)); + + auto posX = expandDims(xEmbed) / dim; + 
auto posY = expandDims(yEmbed) / dim; + + auto posXSin = fl::sin(posX(fl::range(0, fl::end, 2), fl::span)); + auto posXCos = fl::cos(posX(fl::range(1, fl::end, 2), fl::span)); + auto posYSin = fl::sin(posY(fl::range(0, fl::end, 2), fl::span)); + auto posYCos = fl::cos(posY(fl::range(1, fl::end, 2), fl::span)); + + posX = interleave(posXSin, posXCos); + posY = interleave(posYSin, posYCos); + auto result = fl::concatenate(0, posY, posX); + result = fl::transpose(result, {1, 2, 0, 3}); + return {fl::Variable(result, false)}; } std::vector PositionalEmbeddingSine::operator()( - const std::vector& input) { - return forward(input); + const std::vector& input +) { + return forward(input); } } // namespace fl diff --git a/flashlight/pkg/vision/nn/PositionalEmbeddingSine.h b/flashlight/pkg/vision/nn/PositionalEmbeddingSine.h index 47d6e0c..19816fe 100644 --- a/flashlight/pkg/vision/nn/PositionalEmbeddingSine.h +++ b/flashlight/pkg/vision/nn/PositionalEmbeddingSine.h @@ -12,43 +12,45 @@ namespace fl { namespace pkg { -namespace vision { - -class PositionalEmbeddingSine : public Container { - public: - PositionalEmbeddingSine( - const int numPosFeats, - const int temperature, - const bool normalize, - const float scale); - - PositionalEmbeddingSine(const PositionalEmbeddingSine& other); - PositionalEmbeddingSine(PositionalEmbeddingSine&& other) = default; - PositionalEmbeddingSine& operator=(const PositionalEmbeddingSine& other); - PositionalEmbeddingSine& operator=(PositionalEmbeddingSine&& other) = default; - std::unique_ptr clone() const override; - - std::vector forward(const std::vector& input) override; - - std::vector operator()(const std::vector& input); - - std::string prettyString() const override; - - private: - PositionalEmbeddingSine() = default; - FL_SAVE_LOAD_WITH_BASE( - fl::Container, - numPosFeats_, - temperature_, - normalize_, - scale_) - int numPosFeats_; - int temperature_; - bool normalize_; - float scale_; -}; - -} // namespace vision + namespace 
vision { + + class PositionalEmbeddingSine : public Container { + public: + PositionalEmbeddingSine( + const int numPosFeats, + const int temperature, + const bool normalize, + const float scale + ); + + PositionalEmbeddingSine(const PositionalEmbeddingSine& other); + PositionalEmbeddingSine(PositionalEmbeddingSine&& other) = default; + PositionalEmbeddingSine& operator=(const PositionalEmbeddingSine& other); + PositionalEmbeddingSine& operator=(PositionalEmbeddingSine&& other) = default; + std::unique_ptr clone() const override; + + std::vector forward(const std::vector& input) override; + + std::vector operator()(const std::vector& input); + + std::string prettyString() const override; + + private: + PositionalEmbeddingSine() = default; + FL_SAVE_LOAD_WITH_BASE( + fl::Container, + numPosFeats_, + temperature_, + normalize_, + scale_ + ) + int numPosFeats_; + int temperature_; + bool normalize_; + float scale_; + }; + + } // namespace vision } // namespace pkg } // namespace fl CEREAL_REGISTER_TYPE(fl::pkg::vision::PositionalEmbeddingSine) diff --git a/flashlight/pkg/vision/nn/Transformer.cpp b/flashlight/pkg/vision/nn/Transformer.cpp index 61b727b..3a1f4cb 100644 --- a/flashlight/pkg/vision/nn/Transformer.cpp +++ b/flashlight/pkg/vision/nn/Transformer.cpp @@ -14,28 +14,26 @@ using namespace fl; namespace { -std::shared_ptr -makeTransformerLinear(int inDim, int outDim, float gain = 1.0f) { - int fanIn = inDim; - int fanOut = outDim; - float std = gain * std::sqrt(2.0 / (fanIn + fanOut)); - float bound = std::sqrt(3.0) * std; - auto w = fl::uniform(outDim, inDim, -bound, bound, fl::dtype::f32, true); - bound = std::sqrt(1.0 / fanIn); - auto b = fl::uniform({outDim}, -bound, bound, fl::dtype::f32, true); - return std::make_shared(w, b); -} - -std::shared_ptr -makeMultiheadedAttentionLinear(int inDim, int outDim, int fanOutMult = 1) { - int fanIn = inDim; - int fanOut = outDim * fanOutMult; - float gain = 1.0; - float std = gain * std::sqrt(2.0 / (fanIn + fanOut)); - 
float bound = std::sqrt(3.0) * std; - auto w = fl::uniform(outDim, inDim, -bound, bound, fl::dtype::f32, true); - auto b = fl::param(fl::full({outDim}, 0)); - return std::make_shared(w, b); +std::shared_ptr makeTransformerLinear(int inDim, int outDim, float gain = 1.0f) { + int fanIn = inDim; + int fanOut = outDim; + float std = gain * std::sqrt(2.0 / (fanIn + fanOut)); + float bound = std::sqrt(3.0) * std; + auto w = fl::uniform(outDim, inDim, -bound, bound, fl::dtype::f32, true); + bound = std::sqrt(1.0 / fanIn); + auto b = fl::uniform({outDim}, -bound, bound, fl::dtype::f32, true); + return std::make_shared(w, b); +} + +std::shared_ptr makeMultiheadedAttentionLinear(int inDim, int outDim, int fanOutMult = 1) { + int fanIn = inDim; + int fanOut = outDim * fanOutMult; + float gain = 1.0; + float std = gain * std::sqrt(2.0 / (fanIn + fanOut)); + float bound = std::sqrt(3.0) * std; + auto w = fl::uniform(outDim, inDim, -bound, bound, fl::dtype::f32, true); + auto b = fl::param(fl::full({outDim}, 0)); + return std::make_shared(w, b); } } // namespace @@ -50,374 +48,404 @@ fl::Variable transformerMultiheadAttention( const fl::Variable& value, const fl::Variable& keyPaddingMask, const int32_t nHead, - const double pDropout) { - int32_t bsz = query.dim(1); - int32_t modelDim = query.dim(0); - int32_t headDim = modelDim / nHead; - int32_t tgtLen = query.dim(2); - int32_t srcLen = key.dim(2); - - auto q = moddims(query, {headDim, nHead, bsz, tgtLen}); - auto v = moddims(value, {headDim, nHead, bsz, srcLen}); - auto k = moddims(key, {headDim, nHead, bsz, srcLen}); - // Reorder so that the "Sequence" is along the first dimension, - // the embedding is along the zeroth dimension - q = reorder(q, {0, 3, 1, 2}); - v = reorder(v, {0, 3, 1, 2}); - k = reorder(k, {0, 3, 1, 2}); - - auto scores = matmulTN(q, k); - - if (!keyPaddingMask.isEmpty()) { - scores = scores + - tileAs(moddims(log(keyPaddingMask), {1, srcLen, 1, bsz}), scores); - } - - auto attn = dropout(softmax(scores, 
1), pDropout); - auto result = matmulNT(attn.astype(v.type()), v); - result = moddims(result, {tgtLen, modelDim, bsz}); - result = reorder(result, {1, 2, 0}); - return result; + const double pDropout +) { + int32_t bsz = query.dim(1); + int32_t modelDim = query.dim(0); + int32_t headDim = modelDim / nHead; + int32_t tgtLen = query.dim(2); + int32_t srcLen = key.dim(2); + + auto q = moddims(query, {headDim, nHead, bsz, tgtLen}); + auto v = moddims(value, {headDim, nHead, bsz, srcLen}); + auto k = moddims(key, {headDim, nHead, bsz, srcLen}); + // Reorder so that the "Sequence" is along the first dimension, + // the embedding is along the zeroth dimension + q = reorder(q, {0, 3, 1, 2}); + v = reorder(v, {0, 3, 1, 2}); + k = reorder(k, {0, 3, 1, 2}); + + auto scores = matmulTN(q, k); + + if(!keyPaddingMask.isEmpty()) { + scores = scores + + tileAs(moddims(log(keyPaddingMask), {1, srcLen, 1, bsz}), scores); + } + + auto attn = dropout(softmax(scores, 1), pDropout); + auto result = matmulNT(attn.astype(v.type()), v); + result = moddims(result, {tgtLen, modelDim, bsz}); + result = reorder(result, {1, 2, 0}); + return result; } MultiheadAttention::MultiheadAttention( int32_t modelDim, int32_t headDim, int32_t numHeads, - float pDropout) - : pDropout_(pDropout), numHeads_(numHeads) { - wq_ = makeMultiheadedAttentionLinear(modelDim, headDim * numHeads, 3); - wk_ = makeMultiheadedAttentionLinear(modelDim, headDim * numHeads, 3); - wv_ = makeMultiheadedAttentionLinear(modelDim, headDim * numHeads, 3); - wf_ = makeMultiheadedAttentionLinear(headDim * numHeads, modelDim); - createLayers(); + float pDropout +) : pDropout_(pDropout), + numHeads_(numHeads) { + wq_ = makeMultiheadedAttentionLinear(modelDim, headDim * numHeads, 3); + wk_ = makeMultiheadedAttentionLinear(modelDim, headDim * numHeads, 3); + wv_ = makeMultiheadedAttentionLinear(modelDim, headDim * numHeads, 3); + wf_ = makeMultiheadedAttentionLinear(headDim * numHeads, modelDim); + createLayers(); } 
MultiheadAttention::MultiheadAttention(const MultiheadAttention& other) { - copy(other); - createLayers(); + copy(other); + createLayers(); } MultiheadAttention& MultiheadAttention::operator=( - const MultiheadAttention& other) { - clear(); - copy(other); - createLayers(); - return *this; + const MultiheadAttention& other +) { + clear(); + copy(other); + createLayers(); + return *this; } void MultiheadAttention::copy(const MultiheadAttention& other) { - train_ = other.train_; - pDropout_ = other.pDropout_; - numHeads_ = other.numHeads_; - wq_ = std::make_shared(*other.wq_); - wk_ = std::make_shared(*other.wk_); - wv_ = std::make_shared(*other.wv_); - wf_ = std::make_shared(*other.wf_); + train_ = other.train_; + pDropout_ = other.pDropout_; + numHeads_ = other.numHeads_; + wq_ = std::make_shared(*other.wq_); + wk_ = std::make_shared(*other.wk_); + wv_ = std::make_shared(*other.wv_); + wf_ = std::make_shared(*other.wf_); } void MultiheadAttention::createLayers() { - add(wq_); - add(wk_); - add(wv_); - add(wf_); + add(wq_); + add(wk_); + add(wv_); + add(wf_); } std::unique_ptr MultiheadAttention::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::vector MultiheadAttention::forward( const Variable& queries, const Variable& keys, const Variable& values, - const Variable& keyPaddingMask) { - assert(queries.dim(0) == keys.dim(0)); - assert(queries.dim(0) == values.dim(0)); - assert(queries.dim(1) == keys.dim(1)); - assert(queries.dim(1) == values.dim(1)); - assert(values.dim(2) == keys.dim(2)); - - int32_t modelDim = queries.dim(0); - int32_t headDim = modelDim / numHeads_; - - if (!keyPaddingMask.isEmpty()) { - assert(keyPaddingMask.dim(0) == keys.dim(2)); - assert(keyPaddingMask.dim(1) == keys.dim(1)); - } - - auto q = wq_->forward(queries); - auto k = wk_->forward(keys); - auto v = wv_->forward(values); - - q = q / std::sqrt(float(headDim)); - float dropout = train_ ? 
pDropout_ : 0.0f; - auto result = transformerMultiheadAttention( - q, k, v, keyPaddingMask, numHeads_, dropout); - result = (*wf_)(result); - - assert(result.shape() == queries.shape()); - std::vector results = {result}; - return results; + const Variable& keyPaddingMask +) { + assert(queries.dim(0) == keys.dim(0)); + assert(queries.dim(0) == values.dim(0)); + assert(queries.dim(1) == keys.dim(1)); + assert(queries.dim(1) == values.dim(1)); + assert(values.dim(2) == keys.dim(2)); + + int32_t modelDim = queries.dim(0); + int32_t headDim = modelDim / numHeads_; + + if(!keyPaddingMask.isEmpty()) { + assert(keyPaddingMask.dim(0) == keys.dim(2)); + assert(keyPaddingMask.dim(1) == keys.dim(1)); + } + + auto q = wq_->forward(queries); + auto k = wk_->forward(keys); + auto v = wv_->forward(values); + + q = q / std::sqrt(float(headDim)); + float dropout = train_ ? pDropout_ : 0.0f; + auto result = transformerMultiheadAttention( + q, + k, + v, + keyPaddingMask, + numHeads_, + dropout + ); + result = (*wf_)(result); + + assert(result.shape() == queries.shape()); + std::vector results = {result}; + return results; }; std::vector MultiheadAttention::forward( - const std::vector& input) { - assert(input.size() == 4); - return this->forward(input[0], input[1], input[2], input[3]); + const std::vector& input +) { + assert(input.size() == 4); + return this->forward(input[0], input[1], input[2], input[3]); } std::string MultiheadAttention::prettyString() const { - std::ostringstream ss; - ss << "MultiheadAttention"; - ss << Container::prettyString(); - return ss.str(); + std::ostringstream ss; + ss << "MultiheadAttention"; + ss << Container::prettyString(); + return ss.str(); } TransformerBaseLayer::TransformerBaseLayer( int32_t modelDim, int32_t mlpDim, int32_t nHeads, - float pDropout) - : self_attn_(std::make_shared( - modelDim, - modelDim / nHeads, - nHeads, - pDropout)), - w1_(makeTransformerLinear(modelDim, mlpDim)), - w2_(makeTransformerLinear(mlpDim, modelDim)), - 
norm1_(std::make_shared( - std::vector{0}, - 1e-5, - true, - modelDim)), - norm2_(std::make_shared( - std::vector{0}, - 1e-5, - true, - modelDim)), - pDropout_(pDropout) { - createLayers(); + float pDropout +) : self_attn_(std::make_shared( + modelDim, + modelDim / nHeads, + nHeads, + pDropout + )), + w1_(makeTransformerLinear(modelDim, mlpDim)), + w2_(makeTransformerLinear(mlpDim, modelDim)), + norm1_(std::make_shared( + std::vector{0}, + 1e-5, + true, + modelDim + )), + norm2_(std::make_shared( + std::vector{0}, + 1e-5, + true, + modelDim + )), + pDropout_(pDropout) { + createLayers(); } TransformerBaseLayer::TransformerBaseLayer(const TransformerBaseLayer& other) { - copy(other); - createLayers(); + copy(other); + createLayers(); } TransformerBaseLayer& TransformerBaseLayer::operator=( - const TransformerBaseLayer& other) { - clear(); - copy(other); - createLayers(); - return *this; + const TransformerBaseLayer& other +) { + clear(); + copy(other); + createLayers(); + return *this; } void TransformerBaseLayer::copy(const TransformerBaseLayer& other) { - train_ = other.train_; - pDropout_ = other.pDropout_; - self_attn_ = std::make_shared(*other.self_attn_); - w1_ = std::make_shared(*other.w1_); - w2_ = std::make_shared(*other.w2_); - norm1_ = std::make_shared(*other.norm1_); - norm2_ = std::make_shared(*other.norm2_); + train_ = other.train_; + pDropout_ = other.pDropout_; + self_attn_ = std::make_shared(*other.self_attn_); + w1_ = std::make_shared(*other.w1_); + w2_ = std::make_shared(*other.w2_); + norm1_ = std::make_shared(*other.norm1_); + norm2_ = std::make_shared(*other.norm2_); } void TransformerBaseLayer::createLayers() { - add(self_attn_); - add(w1_); - add(w2_); - add(norm1_); - add(norm2_); + add(self_attn_); + add(w1_); + add(w2_); + add(norm1_); + add(norm2_); }; Variable TransformerBaseLayer::mlp(const Variable& in) { - float pDropout = train_ ? pDropout_ : 0.0; - return (*w2_)(dropout(relu((*w1_)(in)), pDropout)); + float pDropout = train_ ? 
pDropout_ : 0.0; + return (*w2_)(dropout(relu((*w1_)(in)), pDropout)); } Variable TransformerBaseLayer::withPosEmbed( const Variable& input, - const Variable& pos) { - if (pos.isEmpty()) { - return input; - } - return input + pos; + const Variable& pos +) { + if(pos.isEmpty()) { + return input; + } + return input + pos; } Variable TransformerBaseLayer::selfAttention( const Variable& input, const Variable& pos, - const Variable& keyPaddingMask) { - auto k = withPosEmbed(input, pos); - auto q = withPosEmbed(input, pos); - return self_attn_->forward(q, k, input, keyPaddingMask)[0]; + const Variable& keyPaddingMask +) { + auto k = withPosEmbed(input, pos); + auto q = withPosEmbed(input, pos); + return self_attn_->forward(q, k, input, keyPaddingMask)[0]; } TransformerEncoderLayer::TransformerEncoderLayer( int32_t modelDim, int32_t mlpDim, int32_t nHeads, - float pDropout) - : TransformerBaseLayer(modelDim, mlpDim, nHeads, pDropout){}; + float pDropout +) : TransformerBaseLayer(modelDim, mlpDim, nHeads, pDropout) {}; std::unique_ptr TransformerEncoderLayer::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::vector TransformerEncoderLayer::forward( - const std::vector& input) { - auto src = input[0]; - auto mask = input[1]; - auto pos = input[2]; + const std::vector& input +) { + auto src = input[0]; + auto mask = input[1]; + auto pos = input[2]; - float pDropout = train_ ? pDropout_ : 0.0f; + float pDropout = train_ ? 
pDropout_ : 0.0f; - auto src2 = this->selfAttention(src, pos, mask); - src = src + dropout(src2, pDropout); - src = (*norm1_)(src); - src2 = mlp(src); - src = src + dropout(src2, pDropout); - src = (*norm2_)(src); + auto src2 = this->selfAttention(src, pos, mask); + src = src + dropout(src2, pDropout); + src = (*norm1_)(src); + src2 = mlp(src); + src = src + dropout(src2, pDropout); + src = (*norm2_)(src); - return {src, mask, pos}; + return {src, mask, pos}; } std::string TransformerEncoderLayer::prettyString() const { - std::ostringstream ss; - ss << "TransformerEncoderLayer"; - ss << Container::prettyString(); - return ss.str(); + std::ostringstream ss; + ss << "TransformerEncoderLayer"; + ss << Container::prettyString(); + return ss.str(); } TransformerDecoderLayer::TransformerDecoderLayer( int32_t modelDim, int32_t mlpDim, int32_t nHeads, - float pDropout) - : self_attn_(std::make_shared( - modelDim, - modelDim / nHeads, - nHeads, - pDropout)), - encoder_attn_(std::make_shared( - modelDim, - modelDim / nHeads, - nHeads, - pDropout)), - w1_(makeTransformerLinear(modelDim, mlpDim)), - w2_(makeTransformerLinear(mlpDim, modelDim)), - norm1_(std::make_shared( - std::vector{0}, - 1e-5, - true, - modelDim)), - norm2_(std::make_shared( - std::vector{0}, - 1e-5, - true, - modelDim)), - norm3_(std::make_shared( - std::vector{0}, - 1e-5, - true, - modelDim)), - pDropout_(pDropout) { - createLayers(); + float pDropout +) : self_attn_(std::make_shared( + modelDim, + modelDim / nHeads, + nHeads, + pDropout + )), + encoder_attn_(std::make_shared( + modelDim, + modelDim / nHeads, + nHeads, + pDropout + )), + w1_(makeTransformerLinear(modelDim, mlpDim)), + w2_(makeTransformerLinear(mlpDim, modelDim)), + norm1_(std::make_shared( + std::vector{0}, + 1e-5, + true, + modelDim + )), + norm2_(std::make_shared( + std::vector{0}, + 1e-5, + true, + modelDim + )), + norm3_(std::make_shared( + std::vector{0}, + 1e-5, + true, + modelDim + )), + pDropout_(pDropout) { + createLayers(); } 
TransformerDecoderLayer::TransformerDecoderLayer( - const TransformerDecoderLayer& other) { - copy(other); - createLayers(); + const TransformerDecoderLayer& other +) { + copy(other); + createLayers(); } TransformerDecoderLayer& TransformerDecoderLayer::operator=( - const TransformerDecoderLayer& other) { - clear(); - copy(other); - createLayers(); - return *this; + const TransformerDecoderLayer& other +) { + clear(); + copy(other); + createLayers(); + return *this; } void TransformerDecoderLayer::copy(const TransformerDecoderLayer& other) { - train_ = other.train_; - pDropout_ = other.pDropout_; - self_attn_ = std::make_shared(*other.self_attn_); - encoder_attn_ = std::make_shared(*other.encoder_attn_); - w1_ = std::make_shared(*other.w1_); - w2_ = std::make_shared(*other.w2_); - norm1_ = std::make_shared(*other.norm1_); - norm2_ = std::make_shared(*other.norm2_); - norm3_ = std::make_shared(*other.norm3_); + train_ = other.train_; + pDropout_ = other.pDropout_; + self_attn_ = std::make_shared(*other.self_attn_); + encoder_attn_ = std::make_shared(*other.encoder_attn_); + w1_ = std::make_shared(*other.w1_); + w2_ = std::make_shared(*other.w2_); + norm1_ = std::make_shared(*other.norm1_); + norm2_ = std::make_shared(*other.norm2_); + norm3_ = std::make_shared(*other.norm3_); } void TransformerDecoderLayer::createLayers() { - add(self_attn_); - add(encoder_attn_); - add(w1_); - add(w2_); - add(norm1_); - add(norm2_); - add(norm3_); + add(self_attn_); + add(encoder_attn_); + add(w1_); + add(w2_); + add(norm1_); + add(norm2_); + add(norm3_); } std::unique_ptr TransformerDecoderLayer::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } Variable TransformerDecoderLayer::mlp(const Variable& in) { - float pDropout = train_ ? pDropout_ : 0.0; - return (*w2_)(dropout(relu((*w1_)(in)), pDropout)); + float pDropout = train_ ? 
pDropout_ : 0.0; + return (*w2_)(dropout(relu((*w1_)(in)), pDropout)); } Variable TransformerDecoderLayer::withPosEmbed( const Variable& input, - const Variable& pos) { - if (pos.isEmpty()) { - return input; - } - return input + pos; + const Variable& pos +) { + if(pos.isEmpty()) { + return input; + } + return input + pos; } Variable TransformerDecoderLayer::selfAttention( const Variable& input, const Variable& pos, - const Variable& keyPaddingMask /* = Variable() */) { - auto k = withPosEmbed(input, pos); - auto q = withPosEmbed(input, pos); - return self_attn_->forward(q, k, input, keyPaddingMask)[0]; + const Variable& keyPaddingMask /* = Variable() */ +) { + auto k = withPosEmbed(input, pos); + auto q = withPosEmbed(input, pos); + return self_attn_->forward(q, k, input, keyPaddingMask)[0]; } std::vector TransformerDecoderLayer::forward( - const std::vector& input) { - auto tgt = input[0]; - auto memory = input[1]; - auto pos = (input.size() > 2) ? input[2] : Variable(); - auto queryPos = (input.size() > 3) ? input[3] : Variable(); - auto memoryKeyPaddingMask = (input.size() > 4) ? input[4] : Variable(); - - float pDropout = train_ ? pDropout_ : 0.0f; - - auto tgt2 = this->selfAttention(tgt, queryPos); - tgt = tgt + dropout(tgt2, pDropout); - tgt = (*norm1_)(tgt); - tgt2 = encoder_attn_->forward({ - this->withPosEmbed(tgt, queryPos), // queries - this->withPosEmbed(memory, pos), // keys - memory, // values - memoryKeyPaddingMask // mask - })[0]; - tgt = tgt + dropout(tgt2, pDropout); - tgt = (*norm2_)(tgt); - tgt2 = mlp(tgt); - tgt = tgt + dropout(tgt2, pDropout); - tgt = (*norm3_)(tgt); - return {tgt}; + const std::vector& input +) { + auto tgt = input[0]; + auto memory = input[1]; + auto pos = (input.size() > 2) ? input[2] : Variable(); + auto queryPos = (input.size() > 3) ? input[3] : Variable(); + auto memoryKeyPaddingMask = (input.size() > 4) ? input[4] : Variable(); + + float pDropout = train_ ? 
pDropout_ : 0.0f; + + auto tgt2 = this->selfAttention(tgt, queryPos); + tgt = tgt + dropout(tgt2, pDropout); + tgt = (*norm1_)(tgt); + tgt2 = encoder_attn_->forward( + { + this->withPosEmbed(tgt, queryPos), // queries + this->withPosEmbed(memory, pos), // keys + memory, // values + memoryKeyPaddingMask // mask + } + )[0]; + tgt = tgt + dropout(tgt2, pDropout); + tgt = (*norm2_)(tgt); + tgt2 = mlp(tgt); + tgt = tgt + dropout(tgt2, pDropout); + tgt = (*norm3_)(tgt); + return {tgt}; } std::string TransformerDecoderLayer::prettyString() const { - std::ostringstream ss; - ss << "TransformerDecoderLayer"; - ss << Container::prettyString(); - return ss.str(); + std::ostringstream ss; + ss << "TransformerDecoderLayer"; + ss << Container::prettyString(); + return ss.str(); } TransformerDecoder::TransformerDecoder( @@ -425,38 +453,41 @@ TransformerDecoder::TransformerDecoder( int32_t mlpDim, int32_t nHeads, int32_t layers, - float pDropout) { - // TODO add norm - for (int i = 0; i < layers; i++) { - add(TransformerDecoderLayer(modelDim, mlpDim, nHeads, pDropout)); - } - add(LayerNorm(std::vector{0}, 1e-5, true, modelDim)); + float pDropout +) { + // TODO add norm + for(int i = 0; i < layers; i++) { + add(TransformerDecoderLayer(modelDim, mlpDim, nHeads, pDropout)); + } + add(LayerNorm(std::vector{0}, 1e-5, true, modelDim)); } std::vector TransformerDecoder::forward( - const std::vector& input) { - auto tgt = input[0]; - auto memory = input[1]; - auto pos = (input.size() > 2) ? input[2] : Variable(); - auto query_pos = (input.size() > 3) ? input[3] : Variable(); - auto mask = (input.size() > 4) ? 
input[4] : Variable(); - - fl::Variable output = tgt; - auto mods = modules(); - - std::vector intermediate; - for (int i = 0; i < mods.size() - 1; i++) { - output = mods[i]->forward({output, memory, pos, query_pos, mask})[0]; - intermediate.push_back( - moddims(mods.back()->forward({output})[0], {0, 0, 0, 1})); - } - return {concatenate(intermediate, 3)}; + const std::vector& input +) { + auto tgt = input[0]; + auto memory = input[1]; + auto pos = (input.size() > 2) ? input[2] : Variable(); + auto query_pos = (input.size() > 3) ? input[3] : Variable(); + auto mask = (input.size() > 4) ? input[4] : Variable(); + + fl::Variable output = tgt; + auto mods = modules(); + + std::vector intermediate; + for(int i = 0; i < mods.size() - 1; i++) { + output = mods[i]->forward({output, memory, pos, query_pos, mask})[0]; + intermediate.push_back( + moddims(mods.back()->forward({output})[0], {0, 0, 0, 1}) + ); + } + return {concatenate(intermediate, 3)}; } std::string TransformerDecoder::prettyString() const { - std::ostringstream ss; - ss << "TransformerDecoder"; - ss << Container::prettyString(); - return ss.str(); + std::ostringstream ss; + ss << "TransformerDecoder"; + ss << Container::prettyString(); + return ss.str(); } TransformerEncoder::TransformerEncoder( @@ -464,27 +495,29 @@ TransformerEncoder::TransformerEncoder( int32_t mlpDim, int32_t nHeads, int32_t layers, - float pDropout) { - for (int i = 0; i < layers; i++) { - add(TransformerEncoderLayer(modelDim, mlpDim, nHeads, pDropout)); - } + float pDropout +) { + for(int i = 0; i < layers; i++) { + add(TransformerEncoderLayer(modelDim, mlpDim, nHeads, pDropout)); + } } std::vector TransformerEncoder::forward( - const std::vector& input) { - std::vector output = input; - auto mods = modules(); - for (int i = 0; i < mods.size(); i++) { - output = mods[i]->forward(output); - } - return output; + const std::vector& input +) { + std::vector output = input; + auto mods = modules(); + for(int i = 0; i < mods.size(); i++) { + 
output = mods[i]->forward(output); + } + return output; } std::string TransformerEncoder::prettyString() const { - std::ostringstream ss; - ss << "TransformerDecoder"; - ss << Container::prettyString(); - return ss.str(); + std::ostringstream ss; + ss << "TransformerDecoder"; + ss << Container::prettyString(); + return ss.str(); } Transformer::Transformer( @@ -493,106 +526,110 @@ Transformer::Transformer( int32_t numEncoderLayers, int32_t numDecoderLayers, int32_t mlpDim, - float pDropout) - : encoder_(std::make_shared( - modelDim, - mlpDim, - numHeads, - numEncoderLayers, - pDropout)), - decoder_(std::make_shared( - modelDim, - mlpDim, - numHeads, - numDecoderLayers, - pDropout)) { - createLayers(); + float pDropout +) : encoder_(std::make_shared( + modelDim, + mlpDim, + numHeads, + numEncoderLayers, + pDropout + )), + decoder_(std::make_shared( + modelDim, + mlpDim, + numHeads, + numDecoderLayers, + pDropout + )) { + createLayers(); } Transformer::Transformer(const Transformer& other) { - copy(other); - createLayers(); + copy(other); + createLayers(); } Transformer& Transformer::operator=(const Transformer& other) { - clear(); - copy(other); - createLayers(); - return *this; + clear(); + copy(other); + createLayers(); + return *this; } void Transformer::copy(const Transformer& other) { - train_ = other.train_; - encoder_ = std::make_shared(*other.encoder_); - decoder_ = std::make_shared(*other.decoder_); + train_ = other.train_; + encoder_ = std::make_shared(*other.encoder_); + decoder_ = std::make_shared(*other.decoder_); } void Transformer::createLayers() { - add(encoder_); - add(decoder_); + add(encoder_); + add(decoder_); }; std::unique_ptr Transformer::clone() const { - return std::make_unique(*this); + return std::make_unique(*this); } std::vector Transformer::forward( Variable src, Variable mask, Variable queryEmbed, - Variable posEmbed) { - if (src.ndim() != 4) { - throw std::invalid_argument( - "vision::Transformer::forward - " - "expect src to be of 
shape (W, H, C, B)."); - } - assert(src.dim(2) == queryEmbed.dim(0)); + Variable posEmbed +) { + if(src.ndim() != 4) { + throw std::invalid_argument( + "vision::Transformer::forward - " + "expect src to be of shape (W, H, C, B)." + ); + } + assert(src.dim(2) == queryEmbed.dim(0)); - int B = src.dim(3); - // Reshape from [ W X H X C X B ] to [ WH X C X B ] + int B = src.dim(3); + // Reshape from [ W X H X C X B ] to [ WH X C X B ] - src = flatten(src, 0, 1); - // Flatten to C x B x WH x 1 - src = reorder(src, {1, 2, 0, 3}); - // Squeeze to C x B x WH - src = moddims(src, {0, 0, 0}); + src = flatten(src, 0, 1); + // Flatten to C x B x WH x 1 + src = reorder(src, {1, 2, 0, 3}); + // Squeeze to C x B x WH + src = moddims(src, {0, 0, 0}); - posEmbed = flatten(posEmbed, 0, 1); - posEmbed = reorder(posEmbed, {1, 2, 0, 3}); - posEmbed = moddims(posEmbed, {0, 0, 0}); + posEmbed = flatten(posEmbed, 0, 1); + posEmbed = reorder(posEmbed, {1, 2, 0, 3}); + posEmbed = moddims(posEmbed, {0, 0, 0}); - mask = flatten(mask, 0, 2); + mask = flatten(mask, 0, 2); - // Tile object queries for each batch - Shape unsqueeze = {queryEmbed.dim(0), 1, queryEmbed.dim(1)}; - queryEmbed = moddims(queryEmbed, unsqueeze); - queryEmbed = tile(queryEmbed, {1, B, 1}); - assert(queryEmbed.dim(1) == src.dim(1)); - assert(queryEmbed.dim(0) == src.dim(0)); + // Tile object queries for each batch + Shape unsqueeze = {queryEmbed.dim(0), 1, queryEmbed.dim(1)}; + queryEmbed = moddims(queryEmbed, unsqueeze); + queryEmbed = tile(queryEmbed, {1, B, 1}); + assert(queryEmbed.dim(1) == src.dim(1)); + assert(queryEmbed.dim(0) == src.dim(0)); - auto tgt = fl::Variable(fl::full(queryEmbed.shape(), 0, src.type()), false); + auto tgt = fl::Variable(fl::full(queryEmbed.shape(), 0, src.type()), false); - auto memory = encoder_->forward({src, mask, posEmbed}); - auto hs = decoder_->forward({tgt, memory[0], posEmbed, queryEmbed, mask})[0]; + auto memory = encoder_->forward({src, mask, posEmbed}); + auto hs = 
decoder_->forward({tgt, memory[0], posEmbed, queryEmbed, mask})[0]; - auto reordered = reorder(hs, {0, 2, 1, 3}); - return {reordered}; + auto reordered = reorder(hs, {0, 2, 1, 3}); + return {reordered}; } std::vector Transformer::forward(const std::vector& input) { - assert(input.size() > 3); - auto src = input[0]; - auto mask = (input.size() > 1) ? input[1] : Variable(); - auto query_embed = (input.size() > 2) ? input[2] : Variable(); - auto pos_embed = (input.size() > 3) ? input[3] : Variable(); - return forward(src, mask, query_embed, pos_embed); + assert(input.size() > 3); + auto src = input[0]; + auto mask = (input.size() > 1) ? input[1] : Variable(); + auto query_embed = (input.size() > 2) ? input[2] : Variable(); + auto pos_embed = (input.size() > 3) ? input[3] : Variable(); + return forward(src, mask, query_embed, pos_embed); } std::string Transformer::prettyString() const { - std::ostringstream ss; - ss << "Transformer"; - ss << Container::prettyString(); - return ss.str(); + std::ostringstream ss; + ss << "Transformer"; + ss << Container::prettyString(); + return ss.str(); } } // namespace fl diff --git a/flashlight/pkg/vision/nn/Transformer.h b/flashlight/pkg/vision/nn/Transformer.h index e5493d9..949c379 100644 --- a/flashlight/pkg/vision/nn/Transformer.h +++ b/flashlight/pkg/vision/nn/Transformer.h @@ -14,262 +14,280 @@ namespace fl { namespace pkg { -namespace vision { - -fl::Variable transformerMultiheadAttention( - const fl::Variable& query, - const fl::Variable& key, - const fl::Variable& value, - const fl::Variable& keyPaddingMask, - const int32_t nHead, - const double pDropout); - -class MultiheadAttention : public Container { - public: - MultiheadAttention( - int32_t modelDim, - int32_t headDim, - int32_t numHeads, - float pDropout = 0.f); - - MultiheadAttention(const MultiheadAttention& other); - MultiheadAttention(MultiheadAttention&& other) = default; - - MultiheadAttention& operator=(const MultiheadAttention& other); - MultiheadAttention& 
operator=(MultiheadAttention&& other) = default; - - std::unique_ptr clone() const override; - - // queries [ E, N, L ], where L is target length, N is batch size. - // keys / values [ E, N, S ], where S is src length, N is batch size. - // keyPaddingMask [ S, N ] - std::vector forward( - const Variable& queries, - const Variable& keys, - const Variable& values, - const Variable& keyPaddingMask); - - std::vector forward(const std::vector& input) override; - - std::string prettyString() const override; - - protected: - std::shared_ptr wq_; - std::shared_ptr wk_; - std::shared_ptr wv_; - std::shared_ptr wf_; - float pDropout_; - int32_t numHeads_; - - private: - MultiheadAttention() = default; - void createLayers(); - void copy(const MultiheadAttention& other); - FL_SAVE_LOAD_WITH_BASE( - fl::Container, - pDropout_, - numHeads_, - wq_, - wk_, - wv_, - wf_) -}; - -class TransformerBaseLayer : public Container { - public: - TransformerBaseLayer( - int32_t modelDim, - int32_t mlpDim, - int32_t nHeads, - float pDropout); - - TransformerBaseLayer(const TransformerBaseLayer& other); - TransformerBaseLayer(TransformerBaseLayer&& other) = default; - - TransformerBaseLayer& operator=(const TransformerBaseLayer& other); - TransformerBaseLayer& operator=(TransformerBaseLayer&& other) = default; - - protected: - TransformerBaseLayer() = default; - std::shared_ptr self_attn_; - std::shared_ptr w1_, w2_; - std::shared_ptr norm1_, norm2_; - float pDropout_; - - Variable mlp(const Variable& in); - - Variable withPosEmbed(const Variable& input, const Variable& pos); - - Variable selfAttention( - const Variable& input, - const Variable& pos, - const Variable& keyPaddingMask = Variable()); - - private: - void createLayers(); - void copy(const TransformerBaseLayer& other); - FL_SAVE_LOAD_WITH_BASE( - fl::Container, - pDropout_, - self_attn_, - w1_, - w2_, - norm1_, - norm2_) -}; - -class TransformerEncoderLayer : public TransformerBaseLayer { - public: - TransformerEncoderLayer( - 
int32_t modelDim, - int32_t mlpDim, - int32_t nHeads, - float pDropout); - - std::unique_ptr clone() const override; - - std::vector forward(const std::vector& input) override; - - std::string prettyString() const override; - - private: - TransformerEncoderLayer() = default; - FL_SAVE_LOAD_WITH_BASE(TransformerBaseLayer) -}; - -class TransformerDecoderLayer : public Container { - public: - TransformerDecoderLayer( - int32_t modelDim, - int32_t mlpDim, - int32_t nHeads, - float pDropout); - - TransformerDecoderLayer(const TransformerDecoderLayer& other); - TransformerDecoderLayer(TransformerDecoderLayer&& other) = default; - - TransformerDecoderLayer& operator=(const TransformerDecoderLayer& other); - TransformerDecoderLayer& operator=(TransformerDecoderLayer&& other) = default; - - std::unique_ptr clone() const override; - - protected: - Variable mlp(const Variable& in); - - Variable withPosEmbed(const Variable& input, const Variable& pos); - Variable selfAttention( - const Variable& input, - const Variable& pos, - const Variable& keyPaddingMask = Variable()); - - std::vector forward(const std::vector& input) override; - - std::string prettyString() const override; - - private: - TransformerDecoderLayer() = default; - void createLayers(); - void copy(const TransformerDecoderLayer& other); - - std::shared_ptr self_attn_, encoder_attn_; - std::shared_ptr w1_, w2_; - std::shared_ptr norm1_, norm2_, norm3_; - float pDropout_; - FL_SAVE_LOAD_WITH_BASE( - fl::Container, - pDropout_, - self_attn_, - encoder_attn_, - w1_, - w2_, - norm1_, - norm2_, - norm3_) -}; - -class TransformerDecoder : public Container { - public: - TransformerDecoder( - int32_t modelDim, - int32_t mlpDim, - int32_t nHeads, - int32_t layers, - float pDropout); - - std::vector forward(const std::vector& input) override; - - std::string prettyString() const override; - - FL_BASIC_CONTAINER_CLONING(TransformerDecoder) - - private: - TransformerDecoder() = default; - FL_SAVE_LOAD_WITH_BASE(fl::Container) 
-}; - -class TransformerEncoder : public Container { - public: - TransformerEncoder( - int32_t modelDim, - int32_t mlpDim, - int32_t nHeads, - int32_t layers, - float pDropout); - - std::vector forward(const std::vector& input) override; - - std::string prettyString() const override; - - FL_BASIC_CONTAINER_CLONING(TransformerEncoder) - - private: - TransformerEncoder() = default; - FL_SAVE_LOAD_WITH_BASE(fl::Container) -}; - -class Transformer : public Container { - public: - Transformer( - int32_t modelDim, - int32_t numHeads, - int32_t numEncoderLayers, - int32_t numDecoderLayers, - int32_t mlpDim, - float pDropout); - - Transformer(const Transformer& other); - Transformer(Transformer&& other) = default; - - Transformer& operator=(const Transformer& other); - Transformer& operator=(Transformer&& other) = default; - - std::unique_ptr clone() const override; - - /* - * We expect src to be [ W X H X C X B ] - * mask to be [ W X H X 1 X B ] - * query embed [ C X N ] (where N is number of query vectors) - * and posEmbed to be [ W X H X C X B ] - * where C is modelDim, B is Batch size, and W and H are width and height of - * image - */ - std::vector - forward(Variable src, Variable mask, Variable queryEmbed, Variable posEmbed); - - std::vector forward(const std::vector& input) override; - - std::string prettyString() const override; - - private: - Transformer() = default; - void createLayers(); - void copy(const Transformer& other); - std::shared_ptr encoder_; - std::shared_ptr decoder_; - FL_SAVE_LOAD_WITH_BASE(fl::Container, encoder_, decoder_) -}; - -} // namespace vision + namespace vision { + + fl::Variable transformerMultiheadAttention( + const fl::Variable& query, + const fl::Variable& key, + const fl::Variable& value, + const fl::Variable& keyPaddingMask, + const int32_t nHead, + const double pDropout + ); + + class MultiheadAttention : public Container { + public: + MultiheadAttention( + int32_t modelDim, + int32_t headDim, + int32_t numHeads, + float pDropout 
= 0.f + ); + + MultiheadAttention(const MultiheadAttention& other); + MultiheadAttention(MultiheadAttention&& other) = default; + + MultiheadAttention& operator=(const MultiheadAttention& other); + MultiheadAttention& operator=(MultiheadAttention&& other) = default; + + std::unique_ptr clone() const override; + + // queries [ E, N, L ], where L is target length, N is batch size. + // keys / values [ E, N, S ], where S is src length, N is batch size. + // keyPaddingMask [ S, N ] + std::vector forward( + const Variable& queries, + const Variable& keys, + const Variable& values, + const Variable& keyPaddingMask + ); + + std::vector forward(const std::vector& input) override; + + std::string prettyString() const override; + + protected: + std::shared_ptr wq_; + std::shared_ptr wk_; + std::shared_ptr wv_; + std::shared_ptr wf_; + float pDropout_; + int32_t numHeads_; + + private: + MultiheadAttention() = default; + void createLayers(); + void copy(const MultiheadAttention& other); + FL_SAVE_LOAD_WITH_BASE( + fl::Container, + pDropout_, + numHeads_, + wq_, + wk_, + wv_, + wf_ + ) + }; + + class TransformerBaseLayer : public Container { + public: + TransformerBaseLayer( + int32_t modelDim, + int32_t mlpDim, + int32_t nHeads, + float pDropout + ); + + TransformerBaseLayer(const TransformerBaseLayer& other); + TransformerBaseLayer(TransformerBaseLayer&& other) = default; + + TransformerBaseLayer& operator=(const TransformerBaseLayer& other); + TransformerBaseLayer& operator=(TransformerBaseLayer&& other) = default; + + protected: + TransformerBaseLayer() = default; + std::shared_ptr self_attn_; + std::shared_ptr w1_, w2_; + std::shared_ptr norm1_, norm2_; + float pDropout_; + + Variable mlp(const Variable& in); + + Variable withPosEmbed(const Variable& input, const Variable& pos); + + Variable selfAttention( + const Variable& input, + const Variable& pos, + const Variable& keyPaddingMask = Variable() + ); + + private: + void createLayers(); + void copy(const 
TransformerBaseLayer& other); + FL_SAVE_LOAD_WITH_BASE( + fl::Container, + pDropout_, + self_attn_, + w1_, + w2_, + norm1_, + norm2_ + ) + }; + + class TransformerEncoderLayer : public TransformerBaseLayer { + public: + TransformerEncoderLayer( + int32_t modelDim, + int32_t mlpDim, + int32_t nHeads, + float pDropout + ); + + std::unique_ptr clone() const override; + + std::vector forward(const std::vector& input) override; + + std::string prettyString() const override; + + private: + TransformerEncoderLayer() = default; + FL_SAVE_LOAD_WITH_BASE(TransformerBaseLayer) + }; + + class TransformerDecoderLayer : public Container { + public: + TransformerDecoderLayer( + int32_t modelDim, + int32_t mlpDim, + int32_t nHeads, + float pDropout + ); + + TransformerDecoderLayer(const TransformerDecoderLayer& other); + TransformerDecoderLayer(TransformerDecoderLayer&& other) = default; + + TransformerDecoderLayer& operator=(const TransformerDecoderLayer& other); + TransformerDecoderLayer& operator=(TransformerDecoderLayer&& other) = default; + + std::unique_ptr clone() const override; + + protected: + Variable mlp(const Variable& in); + + Variable withPosEmbed(const Variable& input, const Variable& pos); + Variable selfAttention( + const Variable& input, + const Variable& pos, + const Variable& keyPaddingMask = Variable() + ); + + std::vector forward(const std::vector& input) override; + + std::string prettyString() const override; + + private: + TransformerDecoderLayer() = default; + void createLayers(); + void copy(const TransformerDecoderLayer& other); + + std::shared_ptr self_attn_, encoder_attn_; + std::shared_ptr w1_, w2_; + std::shared_ptr norm1_, norm2_, norm3_; + float pDropout_; + FL_SAVE_LOAD_WITH_BASE( + fl::Container, + pDropout_, + self_attn_, + encoder_attn_, + w1_, + w2_, + norm1_, + norm2_, + norm3_ + ) + }; + + class TransformerDecoder : public Container { + public: + TransformerDecoder( + int32_t modelDim, + int32_t mlpDim, + int32_t nHeads, + int32_t layers, 
+ float pDropout + ); + + std::vector forward(const std::vector& input) override; + + std::string prettyString() const override; + + FL_BASIC_CONTAINER_CLONING(TransformerDecoder) + + private: + TransformerDecoder() = default; + FL_SAVE_LOAD_WITH_BASE(fl::Container) + }; + + class TransformerEncoder : public Container { + public: + TransformerEncoder( + int32_t modelDim, + int32_t mlpDim, + int32_t nHeads, + int32_t layers, + float pDropout + ); + + std::vector forward(const std::vector& input) override; + + std::string prettyString() const override; + + FL_BASIC_CONTAINER_CLONING(TransformerEncoder) + + private: + TransformerEncoder() = default; + FL_SAVE_LOAD_WITH_BASE(fl::Container) + }; + + class Transformer : public Container { + public: + Transformer( + int32_t modelDim, + int32_t numHeads, + int32_t numEncoderLayers, + int32_t numDecoderLayers, + int32_t mlpDim, + float pDropout + ); + + Transformer(const Transformer& other); + Transformer(Transformer&& other) = default; + + Transformer& operator=(const Transformer& other); + Transformer& operator=(Transformer&& other) = default; + + std::unique_ptr clone() const override; + + /* + * We expect src to be [ W X H X C X B ] + * mask to be [ W X H X 1 X B ] + * query embed [ C X N ] (where N is number of query vectors) + * and posEmbed to be [ W X H X C X B ] + * where C is modelDim, B is Batch size, and W and H are width and height of + * image + */ + std::vector forward( + Variable src, + Variable mask, + Variable queryEmbed, + Variable posEmbed + ); + + std::vector forward(const std::vector& input) override; + + std::string prettyString() const override; + + private: + Transformer() = default; + void createLayers(); + void copy(const Transformer& other); + std::shared_ptr encoder_; + std::shared_ptr decoder_; + FL_SAVE_LOAD_WITH_BASE(fl::Container, encoder_, decoder_) + }; + + } // namespace vision } // namespace pkg } // namespace fl CEREAL_REGISTER_TYPE(fl::pkg::vision::Transformer) diff --git 
a/flashlight/pkg/vision/nn/VisionTransformer.cpp b/flashlight/pkg/vision/nn/VisionTransformer.cpp index ef9aa56..ac205a0 100644 --- a/flashlight/pkg/vision/nn/VisionTransformer.cpp +++ b/flashlight/pkg/vision/nn/VisionTransformer.cpp @@ -20,173 +20,181 @@ VisionTransformer::VisionTransformer( int32_t mlpDim, int32_t nHeads, float pDropout, - float pLayerdrop) - : modelDim_(modelDim), - headDim_(headDim), - mlpDim_(mlpDim), - nHeads_(nHeads), - pDropout_(pDropout), - pLayerdrop_(pLayerdrop), - w1_(initLinear(modelDim, mlpDim)), - w2_(initLinear(mlpDim, modelDim)), - wq_(initLinear(modelDim, headDim * nHeads)), - wk_(initLinear(modelDim, headDim * nHeads)), - wv_(initLinear(modelDim, headDim * nHeads)), - wf_(initLinear(headDim * nHeads, modelDim)), - norm1_(std::make_shared( - std::vector({0}), - 1e-6, // eps - true, // affine - modelDim)), - norm2_(std::make_shared( - std::vector({0}), - 1e-6, // eps - true, // affine - modelDim)) { - createLayers(); + float pLayerdrop +) : modelDim_(modelDim), + headDim_(headDim), + mlpDim_(mlpDim), + nHeads_(nHeads), + pDropout_(pDropout), + pLayerdrop_(pLayerdrop), + w1_(initLinear(modelDim, mlpDim)), + w2_(initLinear(mlpDim, modelDim)), + wq_(initLinear(modelDim, headDim * nHeads)), + wk_(initLinear(modelDim, headDim * nHeads)), + wv_(initLinear(modelDim, headDim * nHeads)), + wf_(initLinear(headDim * nHeads, modelDim)), + norm1_(std::make_shared( + std::vector({0}), + 1e-6, // eps + true, // affine + modelDim + )), + norm2_(std::make_shared( + std::vector({0}), + 1e-6, // eps + true, // affine + modelDim + )) { + createLayers(); } VisionTransformer::VisionTransformer(const VisionTransformer& other) { - copy(other); - createLayers(); + copy(other); + createLayers(); } VisionTransformer& VisionTransformer::operator=( - const VisionTransformer& other) { - clear(); - copy(other); - createLayers(); - return *this; + const VisionTransformer& other +) { + clear(); + copy(other); + createLayers(); + return *this; } void 
VisionTransformer::copy(const VisionTransformer& other) { - train_ = other.train_; - modelDim_ = other.modelDim_; - headDim_ = other.headDim_; - mlpDim_ = other.mlpDim_; - nHeads_ = other.nHeads_; - pDropout_ = other.pDropout_; - pLayerdrop_ = other.pLayerdrop_; - w1_ = std::make_shared(*other.w1_); - w2_ = std::make_shared(*other.w2_); - wq_ = std::make_shared(*other.wq_); - wk_ = std::make_shared(*other.wk_); - wv_ = std::make_shared(*other.wv_); - wf_ = std::make_shared(*other.wf_); - norm1_ = std::make_shared(*other.norm1_); - norm2_ = std::make_shared(*other.norm2_); + train_ = other.train_; + modelDim_ = other.modelDim_; + headDim_ = other.headDim_; + mlpDim_ = other.mlpDim_; + nHeads_ = other.nHeads_; + pDropout_ = other.pDropout_; + pLayerdrop_ = other.pLayerdrop_; + w1_ = std::make_shared(*other.w1_); + w2_ = std::make_shared(*other.w2_); + wq_ = std::make_shared(*other.wq_); + wk_ = std::make_shared(*other.wk_); + wv_ = std::make_shared(*other.wv_); + wf_ = std::make_shared(*other.wf_); + norm1_ = std::make_shared(*other.norm1_); + norm2_ = std::make_shared(*other.norm2_); } void VisionTransformer::createLayers() { - add(w1_); - add(w2_); - add(wq_); - add(wk_); - add(wv_); - add(wf_); - add(norm1_); - add(norm2_); + add(w1_); + add(w2_); + add(wq_); + add(wk_); + add(wv_); + add(wf_); + add(norm1_); + add(norm2_); } std::unique_ptr VisionTransformer::clone() const { - throw std::runtime_error( - "Cloning is unimplemented in Module 'VisionTransformer'"); + throw std::runtime_error( + "Cloning is unimplemented in Module 'VisionTransformer'" + ); } Variable VisionTransformer::gelu(const Variable& input) { - // https://arxiv.org/pdf/1606.08415.pdf - auto geluConst = 1 / std::sqrt(2); - auto res = 0.5 * input * (1 + erf(input * geluConst)); - return res; + // https://arxiv.org/pdf/1606.08415.pdf + auto geluConst = 1 / std::sqrt(2); + auto res = 0.5 * input * (1 + erf(input * geluConst)); + return res; } Variable VisionTransformer::mlp(const Variable& input) { 
- float pDropout = train_ ? pDropout_ : 0.0; - auto output = (*w1_)(input); - output = gelu(output.astype(fl::dtype::f32)).astype(input.type()); - output = dropout(output, pDropout); - output = (*w2_)(output); - output = dropout(output, pDropout); - - return output; + float pDropout = train_ ? pDropout_ : 0.0; + auto output = (*w1_)(input); + output = gelu(output.astype(fl::dtype::f32)).astype(input.type()); + output = dropout(output, pDropout); + output = (*w2_)(output); + output = dropout(output, pDropout); + + return output; } Variable VisionTransformer::selfAttention(const Variable& x) { - // x - C x T x B - double pDrop = train_ ? pDropout_ : 0.0; - - auto q = transpose((*wq_)(x), {1, 0, 2}); - auto k = transpose((*wk_)(x), {1, 0, 2}); - auto v = transpose((*wv_)(x), {1, 0, 2}); - - auto result = multiheadAttention( - q, - k, - v, - fl::Variable(), // posEmb - fl::Variable(), // mask - fl::Variable(), // padMask - nHeads_, - pDrop, - 0 // offset - ); - result = (*wf_)(transpose(result, {1, 0, 2})); - result = dropout(result, pDrop); - - return result; + // x - C x T x B + double pDrop = train_ ? 
pDropout_ : 0.0; + + auto q = transpose((*wq_)(x), {1, 0, 2}); + auto k = transpose((*wk_)(x), {1, 0, 2}); + auto v = transpose((*wv_)(x), {1, 0, 2}); + + auto result = multiheadAttention( + q, + k, + v, + fl::Variable(), // posEmb + fl::Variable(), // mask + fl::Variable(), // padMask + nHeads_, + pDrop, + 0 // offset + ); + result = (*wf_)(transpose(result, {1, 0, 2})); + result = dropout(result, pDrop); + + return result; } Variable VisionTransformer::dropPath(const Variable& x) { - if (!train_) { - return x; - } - - // https://git.io/JYOkq - int C = x.dim(0); - int T = x.dim(1); - int B = x.dim(2); - auto keepMask = (fl::rand({1, 1, B}) > pLayerdrop_).astype(x.type()); - auto keepRatio = - fl::mean(keepMask, {2}).astype(fl::dtype::f32).scalar(); - // Note: this `keepRatio` is computed for real here, while in the PT - // implementatino above, `keepRatio` = 1 - pLayerdrop_. - keepMask = keepMask / keepRatio; - return x * Variable(fl::tile(keepMask, {C, T}).astype(x.type()), false); + if(!train_) { + return x; + } + + // https://git.io/JYOkq + int C = x.dim(0); + int T = x.dim(1); + int B = x.dim(2); + auto keepMask = (fl::rand({1, 1, B}) > pLayerdrop_).astype(x.type()); + auto keepRatio = + fl::mean(keepMask, {2}).astype(fl::dtype::f32).scalar(); + // Note: this `keepRatio` is computed for real here, while in the PT + // implementatino above, `keepRatio` = 1 - pLayerdrop_. 
+ keepMask = keepMask / keepRatio; + return x * Variable(fl::tile(keepMask, {C, T}).astype(x.type()), false); } std::vector VisionTransformer::forward( - const std::vector& inputs) { - if (inputs.size() != 1) { - throw std::runtime_error("VisionTransformer forward, !1 input Variables"); - } - - auto x = inputs.front(); - - if (x.ndim() != 3) { - throw std::invalid_argument( - "VisionTransformer::forward - " - "expected input with 3 dimensions - got input with " + - std::to_string(x.ndim())); - } - - x = x + dropPath(selfAttention((*norm1_)(x))); - x = x + dropPath(mlp((*norm2_)(x))); - return {x}; + const std::vector& inputs +) { + if(inputs.size() != 1) { + throw std::runtime_error("VisionTransformer forward, !1 input Variables"); + } + + auto x = inputs.front(); + + if(x.ndim() != 3) { + throw std::invalid_argument( + "VisionTransformer::forward - " + "expected input with 3 dimensions - got input with " + + std::to_string(x.ndim()) + ); + } + + x = x + dropPath(selfAttention((*norm1_)(x))); + x = x + dropPath(mlp((*norm2_)(x))); + return {x}; } std::string VisionTransformer::prettyString() const { - std::ostringstream ss; - ss << "VisionTransformer (nHeads: " << nHeads_ << "), " - << "(modelDim_: " << modelDim_ << "), " - << "(mlpDim_: " << mlpDim_ << "), " - << "(pDropout: " << pDropout_ << "), " - << "(pLayerdrop: " << pLayerdrop_ << "), "; - return ss.str(); + std::ostringstream ss; + ss << "VisionTransformer (nHeads: " << nHeads_ << "), " + << "(modelDim_: " << modelDim_ << "), " + << "(mlpDim_: " << mlpDim_ << "), " + << "(pDropout: " << pDropout_ << "), " + << "(pLayerdrop: " << pLayerdrop_ << "), "; + return ss.str(); } std::shared_ptr VisionTransformer::initLinear( int inDim, - int outDim) { - return std::make_shared( - fl::truncNormal({outDim, inDim}, 0.02), - fl::constant(0., outDim, 1, fl::dtype::f32)); + int outDim +) { + return std::make_shared( + fl::truncNormal({outDim, inDim}, 0.02), + fl::constant(0., outDim, 1, fl::dtype::f32) + ); } } // 
namespace fl diff --git a/flashlight/pkg/vision/nn/VisionTransformer.h b/flashlight/pkg/vision/nn/VisionTransformer.h index 41ea6f9..7bcc92c 100644 --- a/flashlight/pkg/vision/nn/VisionTransformer.h +++ b/flashlight/pkg/vision/nn/VisionTransformer.h @@ -11,7 +11,7 @@ namespace fl { namespace pkg { -namespace vision { + namespace vision { /* * Implementation of the transformer blocks of Vision Transformer (ViT) models @@ -20,71 +20,71 @@ namespace vision { * * This implementation is highly inspired by [timm](https://git.io/JYOql). */ -class VisionTransformer : public Container { - public: - VisionTransformer( - int32_t modelDim, - int32_t headDim, - int32_t mlpDim, - int32_t nHeads, - float pDropout, - float pLayerdrop); - VisionTransformer(const VisionTransformer& other); - VisionTransformer(VisionTransformer&& other) = default; - - VisionTransformer& operator=(const VisionTransformer& other); - VisionTransformer& operator=(VisionTransformer&& other) = default; - - ~VisionTransformer() override = default; - - std::unique_ptr clone() const override; - - std::vector forward(const std::vector& input) override; - std::string prettyString() const override; - - private: - int32_t modelDim_; - int32_t headDim_; - int32_t mlpDim_; - int32_t nHeads_; - double pDropout_; - double pLayerdrop_; - std::shared_ptr w1_, w2_; - std::shared_ptr wq_, wk_, wv_; - std::shared_ptr wf_; - std::shared_ptr norm1_, norm2_; - - void createLayers(); - void copy(const VisionTransformer& other); - - Variable gelu(const Variable& input); - Variable mlp(const Variable& input); - Variable selfAttention(const Variable& input); - Variable dropPath(const Variable& input); - - FL_SAVE_LOAD_WITH_BASE( - Container, - w1_, - w2_, - wq_, - wk_, - wv_, - wf_, - norm1_, - norm2_, - modelDim_, - headDim_, - mlpDim_, - nHeads_, - pDropout_, - pLayerdrop_) - - VisionTransformer() = default; - - std::shared_ptr initLinear(int inDim, int outDim); -}; - -} // namespace vision + class VisionTransformer : public 
Container { + public: + VisionTransformer( + int32_t modelDim, + int32_t headDim, + int32_t mlpDim, + int32_t nHeads, + float pDropout, + float pLayerdrop + ); + VisionTransformer(const VisionTransformer& other); + VisionTransformer(VisionTransformer&& other) = default; + + VisionTransformer& operator=(const VisionTransformer& other); + VisionTransformer& operator=(VisionTransformer&& other) = default; + + ~VisionTransformer() override = default; + + std::unique_ptr clone() const override; + + std::vector forward(const std::vector& input) override; + std::string prettyString() const override; + + private: + int32_t modelDim_; + int32_t headDim_; + int32_t mlpDim_; + int32_t nHeads_; + double pDropout_; + double pLayerdrop_; + std::shared_ptr w1_, w2_; + std::shared_ptr wq_, wk_, wv_; + std::shared_ptr wf_; + std::shared_ptr norm1_, norm2_; + + void createLayers(); + void copy(const VisionTransformer& other); + + Variable gelu(const Variable& input); + Variable mlp(const Variable& input); + Variable selfAttention(const Variable& input); + Variable dropPath(const Variable& input); + + FL_SAVE_LOAD_WITH_BASE( + Container, + w1_, + w2_, + wq_, + wk_, + wv_, + wf_, + norm1_, + norm2_, + modelDim_, + headDim_, + mlpDim_, + nHeads_, + pDropout_, + pLayerdrop_ + ) VisionTransformer() = default; + + std::shared_ptr initLinear(int inDim, int outDim); + }; + + } // namespace vision } // namespace pkg } // namespace fl diff --git a/flashlight/pkg/vision/tensor/VisionExtension.h b/flashlight/pkg/vision/tensor/VisionExtension.h index c8a671b..4c09c8d 100644 --- a/flashlight/pkg/vision/tensor/VisionExtension.h +++ b/flashlight/pkg/vision/tensor/VisionExtension.h @@ -15,57 +15,63 @@ namespace fl { // TODO: rename this file to VisionExtension class VisionExtension : public TensorExtension { - public: - static constexpr TensorExtensionType extensionType = - TensorExtensionType::Vision; +public: + static constexpr TensorExtensionType extensionType = + TensorExtensionType::Vision; - 
VisionExtension() = default; - virtual ~VisionExtension() = default; + VisionExtension() = default; + virtual ~VisionExtension() = default; - virtual Tensor histogram( - const Tensor& tensor, - const unsigned numBins, - const double minVal, - const double maxVal) = 0; - virtual Tensor histogram(const Tensor& tensor, const unsigned numBins) = 0; + virtual Tensor histogram( + const Tensor& tensor, + const unsigned numBins, + const double minVal, + const double maxVal + ) = 0; + virtual Tensor histogram(const Tensor& tensor, const unsigned numBins) = 0; - virtual Tensor equalize(const Tensor& input, const Tensor& histogram) = 0; + virtual Tensor equalize(const Tensor& input, const Tensor& histogram) = 0; - virtual Tensor resize( - const Tensor& tensor, - const Shape& shape, - const InterpolationMode mode) = 0; + virtual Tensor resize( + const Tensor& tensor, + const Shape& shape, + const InterpolationMode mode + ) = 0; - virtual Tensor - rotate(const Tensor& input, const float theta, const Tensor& fill) = 0; - virtual Tensor rotate( - const Tensor& input, - const float theta, - const InterpolationMode mode) = 0; + virtual Tensor rotate(const Tensor& input, const float theta, const Tensor& fill) = 0; + virtual Tensor rotate( + const Tensor& input, + const float theta, + const InterpolationMode mode + ) = 0; - virtual Tensor translate( - const Tensor& input, - const Shape& translation, - const Shape& outputDims, - const Tensor& fill) = 0; - virtual Tensor translate( - const Tensor& input, - const Shape& translation, - const Shape& outputDims, - const InterpolationMode mode) = 0; + virtual Tensor translate( + const Tensor& input, + const Shape& translation, + const Shape& outputDims, + const Tensor& fill + ) = 0; + virtual Tensor translate( + const Tensor& input, + const Shape& translation, + const Shape& outputDims, + const InterpolationMode mode + ) = 0; - virtual Tensor shear( - const Tensor& input, - const std::vector& skews, - const Shape& outputDims, - const 
Tensor& fill) = 0; - virtual Tensor shear( - const Tensor& input, - const std::vector& skews, - const Shape& outputDims, - const InterpolationMode mode) = 0; + virtual Tensor shear( + const Tensor& input, + const std::vector& skews, + const Shape& outputDims, + const Tensor& fill + ) = 0; + virtual Tensor shear( + const Tensor& input, + const std::vector& skews, + const Shape& outputDims, + const InterpolationMode mode + ) = 0; - virtual Tensor gaussianFilter(const Shape& shape) = 0; + virtual Tensor gaussianFilter(const Shape& shape) = 0; }; } // namespace fl diff --git a/flashlight/pkg/vision/tensor/VisionExtensionBackends.h b/flashlight/pkg/vision/tensor/VisionExtensionBackends.h index c8bc7de..e543b75 100644 --- a/flashlight/pkg/vision/tensor/VisionExtensionBackends.h +++ b/flashlight/pkg/vision/tensor/VisionExtensionBackends.h @@ -14,7 +14,7 @@ * Conditionally include vision extensions */ #if FL_USE_ARRAYFIRE - #include "flashlight/pkg/vision/tensor/backend/af/ArrayFireVisionExtension.h" +#include "flashlight/pkg/vision/tensor/backend/af/ArrayFireVisionExtension.h" #endif namespace fl { diff --git a/flashlight/pkg/vision/tensor/VisionOps.cpp b/flashlight/pkg/vision/tensor/VisionOps.cpp index 1531eda..042b146 100644 --- a/flashlight/pkg/vision/tensor/VisionOps.cpp +++ b/flashlight/pkg/vision/tensor/VisionOps.cpp @@ -17,85 +17,127 @@ Tensor histogram( const Tensor& tensor, const unsigned numBins, const double minVal, - const double maxVal) { - return tensor.backend().getExtension().histogram( - tensor, numBins, minVal, maxVal); + const double maxVal +) { + return tensor.backend().getExtension().histogram( + tensor, + numBins, + minVal, + maxVal + ); } Tensor histogram(const Tensor& tensor, const unsigned numBins) { - return tensor.backend().getExtension().histogram( - tensor, numBins); + return tensor.backend().getExtension().histogram( + tensor, + numBins + ); } Tensor equalize(const Tensor& input, const Tensor& histogram) { - return 
input.backend().getExtension().equalize( - input, histogram); + return input.backend().getExtension().equalize( + input, + histogram + ); } Tensor resize( const Tensor& tensor, const Shape& shape, - const InterpolationMode mode /* = InterpolationMode::Nearest */) { - return tensor.backend().getExtension().resize( - tensor, shape, mode); + const InterpolationMode mode /* = InterpolationMode::Nearest */ +) { + return tensor.backend().getExtension().resize( + tensor, + shape, + mode + ); } Tensor rotate( const Tensor& input, const float theta, - const Tensor& fill /* = Tensor() */) { - return input.backend().getExtension().rotate( - input, theta, fill); + const Tensor& fill /* = Tensor() */ +) { + return input.backend().getExtension().rotate( + input, + theta, + fill + ); } Tensor rotate( const Tensor& input, const float theta, - const InterpolationMode mode /* = InterpolationMode::Nearest */) { - return input.backend().getExtension().rotate( - input, theta, mode); + const InterpolationMode mode /* = InterpolationMode::Nearest */ +) { + return input.backend().getExtension().rotate( + input, + theta, + mode + ); } Tensor translate( const Tensor& input, const Shape& translation, const Shape& outputDims /* = {} */, - const Tensor& fill /* = Tensor() */) { - return input.backend().getExtension().translate( - input, translation, outputDims, fill); + const Tensor& fill /* = Tensor() */ +) { + return input.backend().getExtension().translate( + input, + translation, + outputDims, + fill + ); } Tensor translate( const Tensor& input, const Shape& translation, const Shape& outputDims /* = {} */, - const InterpolationMode mode /* = InterpolationMode::Nearest */) { - return input.backend().getExtension().translate( - input, translation, outputDims, mode); + const InterpolationMode mode /* = InterpolationMode::Nearest */ +) { + return input.backend().getExtension().translate( + input, + translation, + outputDims, + mode + ); } Tensor shear( const Tensor& input, const std::vector& 
skews, const Shape& outputDims /* = {} */, - const Tensor& fill /* = Tensor() */) { - return input.backend().getExtension().shear( - input, skews, outputDims, fill); + const Tensor& fill /* = Tensor() */ +) { + return input.backend().getExtension().shear( + input, + skews, + outputDims, + fill + ); } Tensor shear( const Tensor& input, const std::vector& skews, const Shape& outputDims /* = {} */, - const InterpolationMode mode /* = InterpolationMode::Nearest */) { - return input.backend().getExtension().shear( - input, skews, outputDims, mode); + const InterpolationMode mode /* = InterpolationMode::Nearest */ +) { + return input.backend().getExtension().shear( + input, + skews, + outputDims, + mode + ); } Tensor gaussianFilter(const Shape& shape) { - // TODO{fl::Tensor} - empty tensor instantiation for default backend - return defaultTensorBackend().getExtension().gaussianFilter( - shape); + // TODO{fl::Tensor} - empty tensor instantiation for default backend + return defaultTensorBackend().getExtension().gaussianFilter( + shape + ); } } // namespace fl diff --git a/flashlight/pkg/vision/tensor/VisionOps.h b/flashlight/pkg/vision/tensor/VisionOps.h index 42bc8b8..52f8f3f 100644 --- a/flashlight/pkg/vision/tensor/VisionOps.h +++ b/flashlight/pkg/vision/tensor/VisionOps.h @@ -26,7 +26,8 @@ Tensor histogram( const Tensor& tensor, const unsigned numBins, const double minVal, - const double maxVal); + const double maxVal +); Tensor histogram(const Tensor& tensor, const unsigned numBins); /** @@ -45,7 +46,7 @@ Tensor equalize(const Tensor& input, const Tensor& histogram); * TODO{fl::Tensor} -- consider moving this to a more general place - other * things will need to support interpolation */ -enum class InterpolationMode { Nearest, Linear, Bilinear, Cubic, Bicubic }; +enum class InterpolationMode {Nearest, Linear, Bilinear, Cubic, Bicubic}; /** * Resize a tensor, performing interpolation as needed. 
@@ -59,7 +60,8 @@ enum class InterpolationMode { Nearest, Linear, Bilinear, Cubic, Bicubic }; Tensor resize( const Tensor& tensor, const Shape& shape, - const InterpolationMode mode = InterpolationMode::Nearest); + const InterpolationMode mode = InterpolationMode::Nearest +); /** * Rotate a tensor by a given angle, filling in unfilled locations as needed. @@ -70,8 +72,7 @@ Tensor resize( * empty image regions post-transformation * @return a Tensor with the rotation operation applied. */ -Tensor -rotate(const Tensor& input, const float theta, const Tensor& fill = Tensor()); +Tensor rotate(const Tensor& input, const float theta, const Tensor& fill = Tensor()); /** * Rotate a tensor by a given angle, filling in unfilled locations as needed. @@ -85,7 +86,8 @@ rotate(const Tensor& input, const float theta, const Tensor& fill = Tensor()); Tensor rotate( const Tensor& input, const float theta, - const InterpolationMode mode = InterpolationMode::Nearest); + const InterpolationMode mode = InterpolationMode::Nearest +); /** * Translate a tensor by given amounts along some axes, filling in unfilled @@ -104,7 +106,8 @@ Tensor translate( const Tensor& input, const Shape& translation, const Shape& outputDims = {}, - const Tensor& fill = Tensor()); + const Tensor& fill = Tensor() +); /** * Translate a tensor by given amounts along some axes, filling in unfilled @@ -123,7 +126,8 @@ Tensor translate( const Tensor& input, const Shape& translation, const Shape& outputDims = {}, - const InterpolationMode mode = InterpolationMode::Nearest); + const InterpolationMode mode = InterpolationMode::Nearest +); /** * Apply a shear transformation (also called a skew transformation) to an input @@ -142,7 +146,8 @@ Tensor shear( const Tensor& input, const std::vector& skews, const Shape& outputDims = {}, - const Tensor& fill = Tensor()); + const Tensor& fill = Tensor() +); /** * Apply a shear transformation (also called a skew transformation) to an input @@ -161,7 +166,8 @@ Tensor shear( const 
Tensor& input, const std::vector& skews, const Shape& outputDims = {}, - const InterpolationMode mode = InterpolationMode::Nearest); + const InterpolationMode mode = InterpolationMode::Nearest +); /** * Create a Tensor with the given shape that is Gaussian distributed across the diff --git a/flashlight/pkg/vision/tensor/backend/af/ArrayFireVisionExtension.cpp b/flashlight/pkg/vision/tensor/backend/af/ArrayFireVisionExtension.cpp index 3641f19..ade467d 100644 --- a/flashlight/pkg/vision/tensor/backend/af/ArrayFireVisionExtension.cpp +++ b/flashlight/pkg/vision/tensor/backend/af/ArrayFireVisionExtension.cpp @@ -20,26 +20,27 @@ namespace fl { namespace detail { -constexpr af_interp_type flToAfInterpType(InterpolationMode mode) { - switch (mode) { - case InterpolationMode::Nearest: - return AF_INTERP_NEAREST; - case InterpolationMode::Linear: - return AF_INTERP_LINEAR; - case InterpolationMode::Bilinear: - return AF_INTERP_BILINEAR; - case InterpolationMode::Cubic: - return AF_INTERP_CUBIC; - case InterpolationMode::Bicubic: - return AF_INTERP_BICUBIC; - default: - throw std::invalid_argument( - "flToAfInterpType - no corresponding ArrayFire " - "interpolation mode for given interpolation mode."); - } -} - -namespace { + constexpr af_interp_type flToAfInterpType(InterpolationMode mode) { + switch(mode) { + case InterpolationMode::Nearest: + return AF_INTERP_NEAREST; + case InterpolationMode::Linear: + return AF_INTERP_LINEAR; + case InterpolationMode::Bilinear: + return AF_INTERP_BILINEAR; + case InterpolationMode::Cubic: + return AF_INTERP_CUBIC; + case InterpolationMode::Bicubic: + return AF_INTERP_BICUBIC; + default: + throw std::invalid_argument( + "flToAfInterpType - no corresponding ArrayFire " + "interpolation mode for given interpolation mode." + ); + } + } + + namespace { /* * Performs a fill image operation using a fill Tensor on an input Tensor in @@ -48,228 +49,264 @@ namespace { * This is needed because ArrayFire only supports zero-filling on empty spots. 
* Once AF supports filling directly, this can be removed. */ -template -af::array addFillTensor( - const af::array& input, - const af::array& fillImg, - af_image_transform_func_t transformFunc, - Args&&... args) { - af::array res = input; - - const double delta = 1e-2; - if (!fillImg.isempty()) { - res = res + delta; - } - - // Call the transform - res = transformFunc(res, std::forward(args)...); - - if (!fillImg.isempty()) { - auto mask = af::sum(res, 2) == 0; - mask = af::tile(mask, {1, 1, 3}); - res = mask * fillImg + (1 - mask) * (res - delta); - } - return res; -} - -} // namespace + template + af::array addFillTensor( + const af::array& input, + const af::array& fillImg, + af_image_transform_func_t transformFunc, + Args&&... args + ) { + af::array res = input; + + const double delta = 1e-2; + if(!fillImg.isempty()) { + res = res + delta; + } + + // Call the transform + res = transformFunc(res, std::forward(args)...); + + if(!fillImg.isempty()) { + auto mask = af::sum(res, 2) == 0; + mask = af::tile(mask, {1, 1, 3}); + res = mask * fillImg + (1 - mask) * (res - delta); + } + return res; + } + + } // namespace } // namespace detail bool ArrayFireVisionExtension::isDataTypeSupported( - const fl::dtype& dtype) const { - return ArrayFireBackend::getInstance().isDataTypeSupported(dtype); + const fl::dtype& dtype +) const { + return ArrayFireBackend::getInstance().isDataTypeSupported(dtype); } Tensor ArrayFireVisionExtension::histogram( const Tensor& tensor, const unsigned numBins, const double minVal, - const double maxVal) { - // TODO: add ndim to this - return toTensor( - af::histogram(toArray(tensor), numBins, minVal, maxVal), - /* numDims = */ 1); + const double maxVal +) { + // TODO: add ndim to this + return toTensor( + af::histogram(toArray(tensor), numBins, minVal, maxVal), + /* numDims = */ 1 + ); } Tensor ArrayFireVisionExtension::histogram( const Tensor& tensor, - const unsigned numBins) { - // TODO: add ndim to this - return toTensor( - 
af::histogram(toArray(tensor), numBins), /* numDims = */ 1); + const unsigned numBins +) { + // TODO: add ndim to this + return toTensor( + af::histogram(toArray(tensor), numBins), /* numDims = */ + 1 + ); } Tensor ArrayFireVisionExtension::equalize( const Tensor& input, - const Tensor& histogram) { - return toTensor( - af::histEqual(toArray(input), toArray(histogram)), input.ndim()); + const Tensor& histogram +) { + return toTensor( + af::histEqual(toArray(input), toArray(histogram)), + input.ndim() + ); } Tensor ArrayFireVisionExtension::resize( const Tensor& tensor, const Shape& shape, - const InterpolationMode mode) { - af::dim4 _shape = detail::flToAfDims(shape); - return toTensor( - af::resize( - toArray(tensor), - _shape[0], - _shape[1], - detail::flToAfInterpType(mode)), - tensor.ndim()); + const InterpolationMode mode +) { + af::dim4 _shape = detail::flToAfDims(shape); + return toTensor( + af::resize( + toArray(tensor), + _shape[0], + _shape[1], + detail::flToAfInterpType(mode) + ), + tensor.ndim() + ); } Tensor ArrayFireVisionExtension::rotate( const Tensor& input, const float theta, - const Tensor& fill /* = Tensor() */) { - return toTensor( - detail::addFillTensor( - toArray(input), - toArray(fill), - af::rotate, - theta, - /* crop = */ true, - AF_INTERP_NEAREST), - input.ndim()); + const Tensor& fill /* = Tensor() */ +) { + return toTensor( + detail::addFillTensor( + toArray(input), + toArray(fill), + af::rotate, + theta, + /* crop = */ true, + AF_INTERP_NEAREST + ), + input.ndim() + ); } Tensor ArrayFireVisionExtension::rotate( const Tensor& input, const float theta, - const InterpolationMode mode) { - return toTensor( - af::rotate(toArray(input), theta, detail::flToAfInterpType(mode)), - input.ndim()); + const InterpolationMode mode +) { + return toTensor( + af::rotate(toArray(input), theta, detail::flToAfInterpType(mode)), + input.ndim() + ); } Tensor ArrayFireVisionExtension::translate( const Tensor& input, const Shape& translation, const Shape& 
outputDimsIn /* = {} */, - const Tensor& fill /* = Tensor() */) { - // If no output dims specified, AF expects 2D 0's which to discard OOB data - Shape outputDims = outputDimsIn; - if (outputDimsIn.ndim() == 0) { - outputDims = Shape({0, 0}); - } - - if (translation.ndim() != 2 || outputDims.ndim() != 2) { - throw std::invalid_argument( - "ArrayFireVisionExtension::shear - " - "only 2D skews shapes and empty or 2D output shapes are supported"); - } - - return toTensor( - detail::addFillTensor( - toArray(input), - toArray(fill), - af::translate, - translation[0], - translation[1], - outputDims[0], - outputDims[1], - AF_INTERP_NEAREST), - input.ndim()); + const Tensor& fill /* = Tensor() */ +) { + // If no output dims specified, AF expects 2D 0's which to discard OOB data + Shape outputDims = outputDimsIn; + if(outputDimsIn.ndim() == 0) { + outputDims = Shape({0, 0}); + } + + if(translation.ndim() != 2 || outputDims.ndim() != 2) { + throw std::invalid_argument( + "ArrayFireVisionExtension::shear - " + "only 2D skews shapes and empty or 2D output shapes are supported" + ); + } + + return toTensor( + detail::addFillTensor( + toArray(input), + toArray(fill), + af::translate, + translation[0], + translation[1], + outputDims[0], + outputDims[1], + AF_INTERP_NEAREST + ), + input.ndim() + ); } Tensor ArrayFireVisionExtension::translate( const Tensor& input, const Shape& translation, const Shape& outputDimsIn /* = {} */, - const InterpolationMode mode) { - // If no output dims specified, AF expects 2D 0's which to discard OOB data - Shape outputDims = outputDimsIn; - if (outputDimsIn.ndim() == 0) { - outputDims = Shape({0, 0}); - } - - if (translation.ndim() != 2 || outputDims.ndim() != 2) { - throw std::invalid_argument( - "ArrayFireVisionExtension::shear - " - "only 2D skews shapes and empty or 2D output shapes are supported"); - } - - af::dim4 _translations = detail::flToAfDims(translation); - af::dim4 _outputDims = detail::flToAfInterpType(mode); - - return toTensor( - 
 af::translate( - toArray(input), - _translations[0], - _translations[1], - _outputDims[0], - _outputDims[1]), - input.ndim()); + const InterpolationMode mode +) { + // If no output dims specified, AF expects 2D 0's which to discard OOB data + Shape outputDims = outputDimsIn; + if(outputDimsIn.ndim() == 0) { + outputDims = Shape({0, 0}); + } + + if(translation.ndim() != 2 || outputDims.ndim() != 2) { + throw std::invalid_argument( + "ArrayFireVisionExtension::shear - " + "only 2D skews shapes and empty or 2D output shapes are supported" + ); + } + + af::dim4 _translations = detail::flToAfDims(translation); + af::dim4 _outputDims = detail::flToAfDims(outputDims); + + return toTensor( + af::translate( + toArray(input), + _translations[0], + _translations[1], + _outputDims[0], + _outputDims[1], detail::flToAfInterpType(mode) + ), + input.ndim() + ); } Tensor ArrayFireVisionExtension::shear( const Tensor& input, const std::vector& skews, const Shape& outputDimsIn /* = {} */, - const Tensor& fill /* = Tensor() */) { - // If no output dims specified, AF expects 2D 0's which to discard OOB data - Shape outputDims = outputDimsIn; - if (outputDimsIn.ndim() == 0) { - outputDims = Shape({0, 0}); - } - - if (skews.size() != 2 || outputDims.ndim() != 2) { - throw std::invalid_argument( - "ArrayFireVisionExtension::shear - " - "only 2D skews shapes and empty or 2D output shapes are supported"); - } - - af::dim4 _outputDims = detail::flToAfDims(outputDims); - - return toTensor( - detail::addFillTensor( - toArray(input), - toArray(fill), - af::skew, - skews[0], - skews[1], - _outputDims[0], - _outputDims[1], - /* inverse = */ true, - AF_INTERP_NEAREST), - input.ndim()); + const Tensor& fill /* = Tensor() */ +) { + // If no output dims specified, AF expects 2D 0's which to discard OOB data + Shape outputDims = outputDimsIn; + if(outputDimsIn.ndim() == 0) { + outputDims = Shape({0, 0}); + } + + if(skews.size() != 2 || outputDims.ndim() != 2) { + throw std::invalid_argument( + "ArrayFireVisionExtension::shear - " + 
"only 2D skews shapes and empty or 2D output shapes are supported" + ); + } + + af::dim4 _outputDims = detail::flToAfDims(outputDims); + + return toTensor( + detail::addFillTensor( + toArray(input), + toArray(fill), + af::skew, + skews[0], + skews[1], + _outputDims[0], + _outputDims[1], + /* inverse = */ true, + AF_INTERP_NEAREST + ), + input.ndim() + ); } Tensor ArrayFireVisionExtension::shear( const Tensor& input, const std::vector& skews, const Shape& outputDimsIn /* = {} */, - const InterpolationMode mode) { - // If no output dims specified, AF expects 2D 0's which to discard OOB data - Shape outputDims = outputDimsIn; - if (outputDimsIn.ndim() == 0) { - outputDims = Shape({0, 0}); - } - - if (skews.size() != 2 || outputDims.ndim() != 2) { - throw std::invalid_argument( - "ArrayFireVisionExtension::shear - " - "only 2D skews shapes and empty or 2D output shapes are supported"); - } - - return toTensor( - af::skew( - toArray(input), - skews[0], - skews[1], - outputDims[0], - outputDims[1], - /* inverse = */ true, - detail::flToAfInterpType(mode)), - input.ndim()); + const InterpolationMode mode +) { + // If no output dims specified, AF expects 2D 0's which to discard OOB data + Shape outputDims = outputDimsIn; + if(outputDimsIn.ndim() == 0) { + outputDims = Shape({0, 0}); + } + + if(skews.size() != 2 || outputDims.ndim() != 2) { + throw std::invalid_argument( + "ArrayFireVisionExtension::shear - " + "only 2D skews shapes and empty or 2D output shapes are supported" + ); + } + + return toTensor( + af::skew( + toArray(input), + skews[0], + skews[1], + outputDims[0], + outputDims[1], + /* inverse = */ true, + detail::flToAfInterpType(mode) + ), + input.ndim() + ); } Tensor ArrayFireVisionExtension::gaussianFilter(const Shape& shape) { - af::dim4 _shape = detail::flToAfDims(shape); - return toTensor( - af::gaussianKernel(_shape[0], _shape[1]), shape.ndim()); + af::dim4 _shape = detail::flToAfDims(shape); + return toTensor( + af::gaussianKernel(_shape[0], _shape[1]), 
+ shape.ndim() + ); } } // namespace fl diff --git a/flashlight/pkg/vision/tensor/backend/af/ArrayFireVisionExtension.h b/flashlight/pkg/vision/tensor/backend/af/ArrayFireVisionExtension.h index 7371b84..bf33fee 100644 --- a/flashlight/pkg/vision/tensor/backend/af/ArrayFireVisionExtension.h +++ b/flashlight/pkg/vision/tensor/backend/af/ArrayFireVisionExtension.h @@ -17,58 +17,65 @@ namespace detail { /* * Convert a Flashlight interpolation mode into an ArrayFire interpolation mode */ -constexpr af_interp_type flToAfInterpType(InterpolationMode mode); + constexpr af_interp_type flToAfInterpType(InterpolationMode mode); } // namespace detail class ArrayFireVisionExtension : public VisionExtension { - public: - bool isDataTypeSupported(const fl::dtype& dtype) const override; +public: + bool isDataTypeSupported(const fl::dtype& dtype) const override; - Tensor histogram( - const Tensor& tensor, - const unsigned numBins, - const double minVal, - const double maxVal) override; - Tensor histogram(const Tensor& tensor, const unsigned numBins) override; + Tensor histogram( + const Tensor& tensor, + const unsigned numBins, + const double minVal, + const double maxVal + ) override; + Tensor histogram(const Tensor& tensor, const unsigned numBins) override; - Tensor equalize(const Tensor& input, const Tensor& histogram) override; + Tensor equalize(const Tensor& input, const Tensor& histogram) override; - Tensor resize( - const Tensor& tensor, - const Shape& shape, - const InterpolationMode mode) override; + Tensor resize( + const Tensor& tensor, + const Shape& shape, + const InterpolationMode mode + ) override; - Tensor rotate(const Tensor& input, const float theta, const Tensor& fill) - override; - Tensor rotate( - const Tensor& input, - const float theta, - const InterpolationMode mode) override; + Tensor rotate(const Tensor& input, const float theta, const Tensor& fill) + override; + Tensor rotate( + const Tensor& input, + const float theta, + const InterpolationMode mode + ) 
override; - Tensor translate( - const Tensor& input, - const Shape& translation, - const Shape& outputDims, - const Tensor& fill) override; - Tensor translate( - const Tensor& input, - const Shape& translation, - const Shape& outputDims, - const InterpolationMode mode) override; + Tensor translate( + const Tensor& input, + const Shape& translation, + const Shape& outputDims, + const Tensor& fill + ) override; + Tensor translate( + const Tensor& input, + const Shape& translation, + const Shape& outputDims, + const InterpolationMode mode + ) override; - Tensor shear( - const Tensor& input, - const std::vector& skews, - const Shape& outputDims, - const Tensor& fill) override; - Tensor shear( - const Tensor& input, - const std::vector& skews, - const Shape& outputDims, - const InterpolationMode mode) override; + Tensor shear( + const Tensor& input, + const std::vector& skews, + const Shape& outputDims, + const Tensor& fill + ) override; + Tensor shear( + const Tensor& input, + const std::vector& skews, + const Shape& outputDims, + const InterpolationMode mode + ) override; - Tensor gaussianFilter(const Shape& shape) override; + Tensor gaussianFilter(const Shape& shape) override; }; } // namespace fl diff --git a/flashlight/pkg/vision/test/ModelSerializationTest.cpp b/flashlight/pkg/vision/test/ModelSerializationTest.cpp index 4255a7a..7da1be6 100644 --- a/flashlight/pkg/vision/test/ModelSerializationTest.cpp +++ b/flashlight/pkg/vision/test/ModelSerializationTest.cpp @@ -18,61 +18,68 @@ using namespace fl; using namespace fl::pkg::vision; TEST(SerializationTest, VisionTransformer) { - int hiddenEmbSize = 768; - int nHeads = 12; - int mlpSize = 3072; - - auto model = std::make_shared( - hiddenEmbSize, hiddenEmbSize / nHeads, mlpSize, nHeads, 0, 0); - model->eval(); - - const fs::path path = fs::temp_directory_path() / "VisionTransformer.mdl"; - save(path, model); - - std::shared_ptr loaded; - load(path, loaded); - loaded->eval(); - - auto input = 
Variable(fl::rand({hiddenEmbSize, 197, 20}), false); - auto output = model->forward({input}); - auto outputl = loaded->forward({input}); - - ASSERT_TRUE(allParamsClose(*loaded, *model)); - ASSERT_TRUE(allClose(outputl[0], output[0])); + int hiddenEmbSize = 768; + int nHeads = 12; + int mlpSize = 3072; + + auto model = std::make_shared( + hiddenEmbSize, + hiddenEmbSize / nHeads, + mlpSize, + nHeads, + 0, + 0 + ); + model->eval(); + + const fs::path path = fs::temp_directory_path() / "VisionTransformer.mdl"; + save(path, model); + + std::shared_ptr loaded; + load(path, loaded); + loaded->eval(); + + auto input = Variable(fl::rand({hiddenEmbSize, 197, 20}), false); + auto output = model->forward({input}); + auto outputl = loaded->forward({input}); + + ASSERT_TRUE(allParamsClose(*loaded, *model)); + ASSERT_TRUE(allClose(outputl[0], output[0])); } TEST(SerializationTest, ViT) { - int hiddenEmbSize = 768; - int nHeads = 12; - int mlpSize = 3072; - - auto model = std::make_shared( - 12, // FLAGS_model_layers, - hiddenEmbSize, - mlpSize, - nHeads, - 0.1, // setting non-zero drop prob for testing purpose - 0.1, // setting non-zero drop prob for testing purpose - 1000); - model->eval(); - - const fs::path path = fs::temp_directory_path() / "ViT.mdl"; - save(path, model); - - std::shared_ptr loaded; - load(path, loaded); - loaded->eval(); - - auto input = Variable(fl::rand({224, 224, 3, 20}), false); - auto output = model->forward({input}); - auto outputl = loaded->forward({input}); - - ASSERT_TRUE(allParamsClose(*loaded, *model)); - ASSERT_TRUE(allClose(outputl[0], output[0])); + int hiddenEmbSize = 768; + int nHeads = 12; + int mlpSize = 3072; + + auto model = std::make_shared( + 12, // FLAGS_model_layers, + hiddenEmbSize, + mlpSize, + nHeads, + 0.1, // setting non-zero drop prob for testing purpose + 0.1, // setting non-zero drop prob for testing purpose + 1000 + ); + model->eval(); + + const fs::path path = fs::temp_directory_path() / "ViT.mdl"; + save(path, model); + + 
std::shared_ptr loaded; + load(path, loaded); + loaded->eval(); + + auto input = Variable(fl::rand({224, 224, 3, 20}), false); + auto output = model->forward({input}); + auto outputl = loaded->forward({input}); + + ASSERT_TRUE(allParamsClose(*loaded, *model)); + ASSERT_TRUE(allClose(outputl[0], output[0])); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/vision/test/PositionalEmbeddingSineTest.cpp b/flashlight/pkg/vision/test/PositionalEmbeddingSineTest.cpp index 98a0373..a35f85e 100644 --- a/flashlight/pkg/vision/test/PositionalEmbeddingSineTest.cpp +++ b/flashlight/pkg/vision/test/PositionalEmbeddingSineTest.cpp @@ -14,19 +14,19 @@ using namespace fl; using namespace fl::pkg::vision; TEST(PositionalEmbeddingSine, PytorchComparision) { - int hiddenDim = 8; - int H = 6; - int W = 6; - int B = 1; - Shape dims = {W, H, 1, B}; - auto inputArray = fl::full(dims, 0); - inputArray(fl::range(0, 4), fl::range(0, 4)) = fl::full({4, 4, 1, B}, 1); - auto input = Variable(inputArray, false); + int hiddenDim = 8; + int H = 6; + int W = 6; + int B = 1; + Shape dims = {W, H, 1, B}; + auto inputArray = fl::full(dims, 0); + inputArray(fl::range(0, 4), fl::range(0, 4)) = fl::full({4, 4, 1, B}, 1); + auto input = Variable(inputArray, false); - PositionalEmbeddingSine pos(hiddenDim / 2, 10000.0f, false, 0.0f); + PositionalEmbeddingSine pos(hiddenDim / 2, 10000.0f, false, 0.0f); - auto result = pos.forward({input})[0]; - ASSERT_EQ(result.shape(), Shape({6, 6, 8, 1})); - ASSERT_LE(result(0, 5, 3, 0).tensor().scalar() - 0.9992f, 1e-5); - ASSERT_LE(result(0, 0, 0, 0).tensor().scalar() - 0.841471f, 1e-5); + auto result = pos.forward({input})[0]; + ASSERT_EQ(result.shape(), Shape({6, 6, 8, 1})); + ASSERT_LE(result(0, 5, 3, 0).tensor().scalar() - 0.9992f, 1e-5); + ASSERT_LE(result(0, 0, 0, 0).tensor().scalar() - 
0.841471f, 1e-5); } diff --git a/flashlight/pkg/vision/test/TransformerTest.cpp b/flashlight/pkg/vision/test/TransformerTest.cpp index ff7d1e0..6ebbc15 100644 --- a/flashlight/pkg/vision/test/TransformerTest.cpp +++ b/flashlight/pkg/vision/test/TransformerTest.cpp @@ -16,254 +16,260 @@ using namespace fl; using namespace fl::pkg::vision; TEST(Tranformer, BasicAttention) { - int B = 1; - int S = 5; - int E = 1; - int L = 1; - int nHeads = 1; - - auto keys = fl::full({E, B, S}, 0.); - keys(0, 0, 2) = 10000; - - auto query = Variable(fl::full({E, B, L}, 10000, fl::dtype::f32), false); - auto key = Variable(keys, false); - auto value = Variable(fl::iota({E, B, S}, {}, fl::dtype::f32), false); - - auto result = transformerMultiheadAttention( - query, - key, - value, - Variable(), - nHeads, // num_heads - 0.0); - - ASSERT_EQ(result.scalar(), 2.0); + int B = 1; + int S = 5; + int E = 1; + int L = 1; + int nHeads = 1; + + auto keys = fl::full({E, B, S}, 0.); + keys(0, 0, 2) = 10000; + + auto query = Variable(fl::full({E, B, L}, 10000, fl::dtype::f32), false); + auto key = Variable(keys, false); + auto value = Variable(fl::iota({E, B, S}, {}, fl::dtype::f32), false); + + auto result = transformerMultiheadAttention( + query, + key, + value, + Variable(), + nHeads, // num_heads + 0.0 + ); + + ASSERT_EQ(result.scalar(), 2.0); }; TEST(Tranformer, BasicAttentionNonMasked) { - int B = 1; - int S = 5; - int E = 1; - int L = 1; - int nHeads = 1; - - auto keys = fl::full({E, B, S}, 0.); - keys(0, 0, 2) = 10000; - keys(0, 0, 4) = 10000; - - auto query = Variable(fl::full({E, B, L}, 10000, fl::dtype::f32), false); - auto key = Variable(keys, false); - auto value = Variable(fl::iota({E, B, S}, {}, fl::dtype::f32), false); - - auto result = transformerMultiheadAttention( - query, - key, - value, - Variable(), - nHeads, // num_heads - 0.0); - - ASSERT_EQ(result.scalar(), 3.0); + int B = 1; + int S = 5; + int E = 1; + int L = 1; + int nHeads = 1; + + auto keys = fl::full({E, B, S}, 0.); + 
keys(0, 0, 2) = 10000; + keys(0, 0, 4) = 10000; + + auto query = Variable(fl::full({E, B, L}, 10000, fl::dtype::f32), false); + auto key = Variable(keys, false); + auto value = Variable(fl::iota({E, B, S}, {}, fl::dtype::f32), false); + + auto result = transformerMultiheadAttention( + query, + key, + value, + Variable(), + nHeads, // num_heads + 0.0 + ); + + ASSERT_EQ(result.scalar(), 3.0); }; TEST(Tranformer, BasicAttentionMasked) { - int B = 1; - int S = 5; - int E = 1; - int L = 1; - int nHeads = 1; - - auto keys = fl::full({E, B, S}, 0.); - keys(0, 0, 2) = 10000; - keys(0, 0, 4) = 10000; - - auto query = Variable(fl::full({E, B, L}, 10000, fl::dtype::f32), false); - auto key = Variable(keys, false); - auto value = Variable(fl::iota({E, B, S}, {}, fl::dtype::f32), false); - int maskLength = 3; - auto mask = fl::full({S, B}, 0); - mask(fl::range(0, maskLength)) = fl::full({maskLength, B}, 1); - - auto result = transformerMultiheadAttention( - query, - key, - value, - Variable(mask, false), - nHeads, // num_heads - 0.0); - - ASSERT_EQ(result.scalar(), 2.0); + int B = 1; + int S = 5; + int E = 1; + int L = 1; + int nHeads = 1; + + auto keys = fl::full({E, B, S}, 0.); + keys(0, 0, 2) = 10000; + keys(0, 0, 4) = 10000; + + auto query = Variable(fl::full({E, B, L}, 10000, fl::dtype::f32), false); + auto key = Variable(keys, false); + auto value = Variable(fl::iota({E, B, S}, {}, fl::dtype::f32), false); + int maskLength = 3; + auto mask = fl::full({S, B}, 0); + mask(fl::range(0, maskLength)) = fl::full({maskLength, B}, 1); + + auto result = transformerMultiheadAttention( + query, + key, + value, + Variable(mask, false), + nHeads, // num_heads + 0.0 + ); + + ASSERT_EQ(result.scalar(), 2.0); }; TEST(Tranformer, MultiHeadedAttention) { - int B = 1; - int S = 5; - int E = 2; - int L = 1; - int nHeads = 2; - - auto keys = fl::full({E, B, S}, 0.); - keys(0, 0, 2) = 10000; // First head --> 2 - keys(1, 0, 3) = 10000; // Second head attend to 3 - - auto query = 
Variable(fl::full({E, B, L}, 10000, fl::dtype::f32), false); - auto key = Variable(keys, false); - // auto value = Variable(fl::iota({ E, B, S }), false); - auto value = Variable(fl::iota({1, 1, S}, {E, B, 1}), false); - - auto result = transformerMultiheadAttention( - query, - key, - value, - Variable(), - nHeads, // num_heads - 0.0); - - ASSERT_EQ(result(0).scalar(), 2.0f); - ASSERT_EQ(result(1).scalar(), 3.0f); + int B = 1; + int S = 5; + int E = 2; + int L = 1; + int nHeads = 2; + + auto keys = fl::full({E, B, S}, 0.); + keys(0, 0, 2) = 10000; // First head --> 2 + keys(1, 0, 3) = 10000; // Second head attend to 3 + + auto query = Variable(fl::full({E, B, L}, 10000, fl::dtype::f32), false); + auto key = Variable(keys, false); + // auto value = Variable(fl::iota({ E, B, S }), false); + auto value = Variable(fl::iota({1, 1, S}, {E, B, 1}), false); + + auto result = transformerMultiheadAttention( + query, + key, + value, + Variable(), + nHeads, // num_heads + 0.0 + ); + + ASSERT_EQ(result(0).scalar(), 2.0f); + ASSERT_EQ(result(1).scalar(), 3.0f); } TEST(Tranformer, MultiHeadedAttentionBatch) { - int B = 2; - int S = 5; - int E = 2; - int L = 1; - int nHeads = 2; - - auto keys = fl::full({E, B, S}, 0.); - keys(0, 0, 2) = 10000; // First head --> 2 - keys(1, 0, 3) = 10000; // Second head attend to 3 - keys(0, 1, 1) = 10000; // First head --> 2 - keys(1, 1, 3) = 10000; // Second head attend to 3 - - auto query = Variable(fl::full({E, B, L}, 10000, fl::dtype::f32), false); - auto key = Variable(keys, false); - auto value = Variable(fl::iota({1, 1, S}, {E, B, 1}), false); - - auto result = transformerMultiheadAttention( - query, - key, - value, - Variable(), - nHeads, // num_heads - 0.0); - - ASSERT_EQ(result(0, 0).scalar(), 2.0f); - ASSERT_EQ(result(1, 0).scalar(), 3.0f); - ASSERT_EQ(result(0, 1).scalar(), 1.0f); - ASSERT_EQ(result(1, 1).scalar(), 3.0f); + int B = 2; + int S = 5; + int E = 2; + int L = 1; + int nHeads = 2; + + auto keys = fl::full({E, B, S}, 0.); + 
keys(0, 0, 2) = 10000; // First head --> 2 + keys(1, 0, 3) = 10000; // Second head attend to 3 + keys(0, 1, 1) = 10000; // First head --> 2 + keys(1, 1, 3) = 10000; // Second head attend to 3 + + auto query = Variable(fl::full({E, B, L}, 10000, fl::dtype::f32), false); + auto key = Variable(keys, false); + auto value = Variable(fl::iota({1, 1, S}, {E, B, 1}), false); + + auto result = transformerMultiheadAttention( + query, + key, + value, + Variable(), + nHeads, // num_heads + 0.0 + ); + + ASSERT_EQ(result(0, 0).scalar(), 2.0f); + ASSERT_EQ(result(1, 0).scalar(), 3.0f); + ASSERT_EQ(result(0, 1).scalar(), 1.0f); + ASSERT_EQ(result(1, 1).scalar(), 3.0f); } TEST(Tranformer, MultiHeadedAttentionMultipleQueries) { - int B = 1; - int S = 5; - int E = 2; - int L = 2; - int nHeads = 2; - - auto keys = fl::full({E, B, S}, 0.); - keys(0, 0, 2) = 10000; - keys(0, 0, 1) = -10000; - // Second head - keys(1, 0, 3) = -10000; - keys(1, 0, 0) = 10000; - - auto queries = fl::full({E, B, L}, 0.); - queries(0, 0, 0) = 10000; // Matches with 2 - queries(1, 0, 0) = -10000; // matches with 3 - // Second query - queries(0, 0, 1) = -10000; // Matches with 1 - queries(1, 0, 1) = 10000; // matches 0 - - auto query = Variable(queries, false); - auto key = Variable(keys, false); - auto value = Variable(fl::iota({1, 1, S}, {E, B, 1}), false); - - auto result = transformerMultiheadAttention( - query, - key, - value, - Variable(), - nHeads, // num_heads - 0.0); - - ASSERT_EQ(result(0, 0, 0).scalar(), 2.0f); - ASSERT_EQ(result(1, 0, 0).scalar(), 3.0f); - // Second query - ASSERT_EQ(result(0, 0, 1).scalar(), 1.0f); - ASSERT_EQ(result(1, 0, 1).scalar(), 0.0f); + int B = 1; + int S = 5; + int E = 2; + int L = 2; + int nHeads = 2; + + auto keys = fl::full({E, B, S}, 0.); + keys(0, 0, 2) = 10000; + keys(0, 0, 1) = -10000; + // Second head + keys(1, 0, 3) = -10000; + keys(1, 0, 0) = 10000; + + auto queries = fl::full({E, B, L}, 0.); + queries(0, 0, 0) = 10000; // Matches with 2 + queries(1, 0, 0) = 
-10000; // matches with 3 + // Second query + queries(0, 0, 1) = -10000; // Matches with 1 + queries(1, 0, 1) = 10000; // matches 0 + + auto query = Variable(queries, false); + auto key = Variable(keys, false); + auto value = Variable(fl::iota({1, 1, S}, {E, B, 1}), false); + + auto result = transformerMultiheadAttention( + query, + key, + value, + Variable(), + nHeads, // num_heads + 0.0 + ); + + ASSERT_EQ(result(0, 0, 0).scalar(), 2.0f); + ASSERT_EQ(result(1, 0, 0).scalar(), 3.0f); + // Second query + ASSERT_EQ(result(0, 0, 1).scalar(), 1.0f); + ASSERT_EQ(result(1, 0, 1).scalar(), 0.0f); } TEST(Tranformer, Size) { - int B = 3; - int H = 5; - int W = 5; - int C = 16; - float dropout = 0.5; - int bbox_queries = 100; - int numEncoderDecoder = 2; - int mlpDim = 32; - int numHeads = 8; - fl::pkg::vision::Transformer tr( - C, numHeads, numEncoderDecoder, numEncoderDecoder, mlpDim, dropout); - - std::vector inputs = { - Variable(fl::rand({W, H, C, B}), false), // input Projection - Variable(fl::rand({W, H, 1, B}), false), // mask - Variable(fl::rand({C, bbox_queries}), false), // query_embed - Variable(fl::rand({W, H, C, B}), false) // query_embed - }; - auto output = tr(inputs)[0]; - ASSERT_EQ(output.dim(0), C) - << "Transformer should return model dim as first dimension"; - ASSERT_EQ(output.dim(1), bbox_queries) - << "Transformer did not return the correct number of labels"; - ASSERT_EQ(output.dim(2), B) - << "Transformer did not return the correct number of batches"; + int B = 3; + int H = 5; + int W = 5; + int C = 16; + float dropout = 0.5; + int bbox_queries = 100; + int numEncoderDecoder = 2; + int mlpDim = 32; + int numHeads = 8; + fl::pkg::vision::Transformer tr( + C, numHeads, numEncoderDecoder, numEncoderDecoder, mlpDim, dropout); + + std::vector inputs = { + Variable(fl::rand({W, H, C, B}), false), // input Projection + Variable(fl::rand({W, H, 1, B}), false), // mask + Variable(fl::rand({C, bbox_queries}), false), // query_embed + Variable(fl::rand({W, H, C, 
B}), false) // pos embedding (4th input; label was copy-pasted from query_embed)
numEncoderDecoder, numEncoderDecoder, mlpDim, dropout); + + PositionalEmbeddingSine pos(C / 2, 10000.0f, false, 0.0f); + + auto nonMask = fl::full({maskW, maskH, 1, B}, 1); + + auto mask = fl::full({W, H, 1, B}, 0); + mask(fl::range(0, maskW), fl::range(0, maskH)) = nonMask; + auto nonMaskPos = pos.forward({Variable(nonMask, false)})[0]; + + std::cout << "--- nonMaskPos " << nonMaskPos.shape() << std::endl; + + std::vector nonMaskInput = { + Variable(fl::rand({maskW, maskH, C, B}), false), // input Projection + Variable(fl::full({maskW, maskH, 1, B}, 1), false), // mask + Variable(fl::rand({C, bbox_queries}), false), // query_embed + nonMaskPos}; + auto nonMaskOutput = tr(nonMaskInput)[0]; + + auto nonMaskedSrc = fl::rand({W, H, C, B}); + nonMaskedSrc(fl::range(0, maskW), fl::range(0, maskH)) = + nonMaskInput[0].tensor(); + + auto maskPos = pos.forward({fl::Variable(mask, false)})[0]; + + std::vector maskInput = { + Variable(nonMaskedSrc, false), // input Projection + Variable(mask, false), // mask + nonMaskInput[2], // query_embed + maskPos}; + auto maskOutput = tr(maskInput)[0]; } diff --git a/flashlight/pkg/vision/test/TransformsTest.cpp b/flashlight/pkg/vision/test/TransformsTest.cpp index 5d330d9..2a5ee6a 100644 --- a/flashlight/pkg/vision/test/TransformsTest.cpp +++ b/flashlight/pkg/vision/test/TransformsTest.cpp @@ -13,86 +13,86 @@ using namespace fl; TEST(Crop, CropBasic) { - std::vector bboxesVector = { - 10, - 10, - 20, - 20, // box1 - 20, - 20, - 30, - 30 // box2 - }; + std::vector bboxesVector = { + 10, + 10, + 20, + 20, // box1 + 20, + 20, + 30, + 30 // box2 + }; - std::vector in = { - fl::full({256, 256, 10}, 1.), - Tensor(), - Tensor(), - Tensor(), - Tensor::fromVector({4, 2}, bboxesVector), - fl::full({1, 2}, 0.)}; + std::vector in = { + fl::full({256, 256, 10}, 1.), + Tensor(), + Tensor(), + Tensor(), + Tensor::fromVector({4, 2}, bboxesVector), + fl::full({1, 2}, 0.)}; - // Crop from x, y (10, 10), with target heigh and width to be ten - 
std::vector out = fl::pkg::vision::crop(in, 10, 5, 20, 25); - auto outBoxes = out[4]; - std::vector expVector = { - 0, - 5, - 10, - 15, // box1 - 10, - 15, - 20, - 25 // box2 - }; - Tensor expOut = Tensor::fromVector({4, 2}, expVector); - ASSERT_TRUE(allClose(expOut, outBoxes, 1e-5)); + // Crop from x, y (10, 10), with target heigh and width to be ten + std::vector out = fl::pkg::vision::crop(in, 10, 5, 20, 25); + auto outBoxes = out[4]; + std::vector expVector = { + 0, + 5, + 10, + 15, // box1 + 10, + 15, + 20, + 25 // box2 + }; + Tensor expOut = Tensor::fromVector({4, 2}, expVector); + ASSERT_TRUE(allClose(expOut, outBoxes, 1e-5)); } TEST(Crop, CropClip) { - int numBoxes = 3; - int numElementsPerBoxes = 4; + int numBoxes = 3; + int numElementsPerBoxes = 4; - std::vector bboxesVector = { - 0, - 0, - 100, - 100, // box1 - 0, - 0, - 4, - 4, // box3 // will be removed - 5, - 5, - 105, - 105 // box2 - }; + std::vector bboxesVector = { + 0, + 0, + 100, + 100, // box1 + 0, + 0, + 4, + 4, // box3 // will be removed + 5, + 5, + 105, + 105 // box2 + }; - std::vector in = { - fl::full({256, 256, 10}, 1.), - Tensor(), - Tensor(), - Tensor(), - Tensor::fromVector({numElementsPerBoxes, numBoxes}, bboxesVector), - fl::iota({1, 3})}; + std::vector in = { + fl::full({256, 256, 10}, 1.), + Tensor(), + Tensor(), + Tensor(), + Tensor::fromVector({numElementsPerBoxes, numBoxes}, bboxesVector), + fl::iota({1, 3})}; - // Crop from x, y (10, 10), with target heigh and width to be ten - std::vector out = fl::pkg::vision::crop(in, 5, 5, 100, 100); - auto outBoxes = out[4]; - auto outClasses = out[5]; - std::vector expVector = { - 0, - 0, - 95, - 95, // box1 - 0, - 0, - 100, - 100 // box2 - }; - Tensor expOut = Tensor::fromVector({4, 2}, expVector); - std::vector expClassVector = {0, 2}; - Tensor expClassOut = Tensor::fromVector({1, 2}, expClassVector); - ASSERT_TRUE(allClose(expOut, outBoxes, 1e-5)); - ASSERT_TRUE(allClose(expClassOut, outClasses, 1e-5)); + // Crop from x, y (10, 10), 
with target height and width to be ten
index " << r; - } - for (int c = 0; c < N; c++) { - EXPECT_EQ(rowIdxs[c], expRowIdxs[c]) << "Assignment differs at index " << c; - } } TEST(HungarianTest, FullPipelineFromWiki) { - int M = 3; // Rows - int N = 3; // Columns - // From https://en.wikipedia.org/wiki/Hungarian_algorithm - std::vector costsVec = {2, 3, 3, 3, 2, 3, 3, 3, 2}; - - std::vector expRowIdxs = {0, 1, 2}; - std::vector expColIdxs = {0, 1, 2}; - - std::vector rowIdxs(N); - std::vector colIdxs(M); - hungarian(costsVec.data(), rowIdxs.data(), colIdxs.data(), M, N); - for (int r = 0; r < M; r++) { - EXPECT_EQ(rowIdxs[r], expRowIdxs[r]) << "Assignment differs at index " << r; - } - for (int c = 0; c < N; c++) { - EXPECT_EQ(rowIdxs[c], expRowIdxs[c]) << "Assignment differs at index " << c; - } + int M = 3; // Rows + int N = 3; // Columns + // From https://en.wikipedia.org/wiki/Hungarian_algorithm + std::vector costsVec = {2, 3, 3, 3, 2, 3, 3, 3, 2}; + + std::vector expRowIdxs = {0, 1, 2}; + std::vector expColIdxs = {0, 1, 2}; + + std::vector rowIdxs(N); + std::vector colIdxs(M); + hungarian(costsVec.data(), rowIdxs.data(), colIdxs.data(), M, N); + for(int r = 0; r < M; r++) { + EXPECT_EQ(rowIdxs[r], expRowIdxs[r]) << "Assignment differs at index " << r; + } + for(int c = 0; c < N; c++) { + EXPECT_EQ(rowIdxs[c], expRowIdxs[c]) << "Assignment differs at index " << c; + } } TEST(HungarianTest, FullPipelineSimple1) { - int M = 3; // Rows - int N = 3; // Columns - std::vector costsVec = { - 1500, - 2000, - 2000, - 4000, - 6000, - 4000, - 4500, - 3500, - 2500, - }; - - std::vector expAssignment = {0, 1, 0, 1, 0, 0, 0, 0, 1}; - std::vector assignment(N * M); - hungarian(costsVec.data(), assignment.data(), N, M); - for (int c = 0; c < N; c++) { - for (int r = 0; r < M; r++) { - EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) - << "Assignment differs at row " << r << " and col " << c; + int M = 3; // Rows + int N = 3; // Columns + std::vector costsVec = { + 1500, + 2000, + 2000, + 4000, + 6000, + 
4000, + 4500, + 3500, + 2500, + }; + + std::vector expAssignment = {0, 1, 0, 1, 0, 0, 0, 0, 1}; + std::vector assignment(N * M); + hungarian(costsVec.data(), assignment.data(), N, M); + for(int c = 0; c < N; c++) { + for(int r = 0; r < M; r++) { + EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) + << "Assignment differs at row " << r << " and col " << c; + } } - } } TEST(HungarianTest, FullPipelineSimple2) { - int M = 3; // Rows - int N = 3; // Columns - std::vector costsVec = { - 2500, 4000, 2000, 4000, 6000, 4000, 3500, 3500, 2500}; - - std::vector expAssignment = {0, 0, 1, 1, 0, 0, 0, 1, 0}; - std::vector assignment(N * M); - hungarian(costsVec.data(), assignment.data(), N, M); - for (int c = 0; c < N; c++) { - for (int r = 0; r < M; r++) { - EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) - << "Assignment differs at row " << r << " and col " << c; + int M = 3; // Rows + int N = 3; // Columns + std::vector costsVec = { + 2500, 4000, 2000, 4000, 6000, 4000, 3500, 3500, 2500}; + + std::vector expAssignment = {0, 0, 1, 1, 0, 0, 0, 1, 0}; + std::vector assignment(N * M); + hungarian(costsVec.data(), assignment.data(), N, M); + for(int c = 0; c < N; c++) { + for(int r = 0; r < M; r++) { + EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) + << "Assignment differs at row " << r << " and col " << c; + } } - } } TEST(HungarianTest, FullPipelineSimple3) { - int M = 3; // Rows - int N = 3; // Columns - std::vector costsVec = {108, 150, 122, 125, 135, 148, 150, 175, 250}; - - std::vector expAssignment = {0, 0, 1, 0, 1, 0, 1, 0, 0}; - std::vector assignment(N * M); - hungarian(costsVec.data(), assignment.data(), N, M); - for (int c = 0; c < N; c++) { - for (int r = 0; r < M; r++) { - EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) - << "Assignment differs at row " << r << " and col " << c; + int M = 3; // Rows + int N = 3; // Columns + std::vector costsVec = {108, 150, 122, 125, 135, 148, 150, 175, 250}; + + std::vector expAssignment 
= {0, 0, 1, 0, 1, 0, 1, 0, 0}; + std::vector assignment(N * M); + hungarian(costsVec.data(), assignment.data(), N, M); + for(int c = 0; c < N; c++) { + for(int r = 0; r < M; r++) { + EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) + << "Assignment differs at row " << r << " and col " << c; + } } - } } TEST(HungarianTest, FullPipelineSize6) { - int M = 6; // Rows - int N = 6; // Columns - std::vector costsVec = {7, 9, 3, 7, 8, 4, 2, 6, 8, 9, 4, 2, - 1, 9, 3, 4, 7, 9, 9, 5, 1, 2, 4, 3, - 4, 5, 8, 2, 8, 1, 4, 2, 9, 3, 2, 9}; - - std::vector expAssignment = {0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, - 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, - 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0}; - std::vector assignment(N * M); - hungarian(costsVec.data(), assignment.data(), N, M); - for (int c = 0; c < N; c++) { - for (int r = 0; r < M; r++) { - EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) - << "Assignment differs at row " << r << " and col " << c; + int M = 6; // Rows + int N = 6; // Columns + std::vector costsVec = {7, 9, 3, 7, 8, 4, 2, 6, 8, 9, 4, 2, + 1, 9, 3, 4, 7, 9, 9, 5, 1, 2, 4, 3, + 4, 5, 8, 2, 8, 1, 4, 2, 9, 3, 2, 9}; + + std::vector expAssignment = {0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0}; + std::vector assignment(N * M); + hungarian(costsVec.data(), assignment.data(), N, M); + for(int c = 0; c < N; c++) { + for(int r = 0; r < M; r++) { + EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) + << "Assignment differs at row " << r << " and col " << c; + } } - } } TEST(HungarianTest, 6x6Example2) { - int M = 6; // Rows - int N = 6; // Columns - std::vector costsVec = {7, 9, 3, 7, 8, 4, 2, 6, 8, 9, 4, 2, - 1, 9, 3, 4, 7, 9, 1, 3, 4, 8, 2, 7, - 4, 5, 8, 2, 8, 1, 4, 2, 9, 3, 2, 9}; - - std::vector expAssignment = {0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, - 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0}; - std::vector assignment(N * M); - hungarian(costsVec.data(), 
assignment.data(), N, M); - for (int c = 0; c < N; c++) { - for (int r = 0; r < M; r++) { - EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) - << "Assignment differs at row " << r << " and col " << c; + int M = 6; // Rows + int N = 6; // Columns + std::vector costsVec = {7, 9, 3, 7, 8, 4, 2, 6, 8, 9, 4, 2, + 1, 9, 3, 4, 7, 9, 1, 3, 4, 8, 2, 7, + 4, 5, 8, 2, 8, 1, 4, 2, 9, 3, 2, 9}; + + std::vector expAssignment = {0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, + 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0}; + std::vector assignment(N * M); + hungarian(costsVec.data(), assignment.data(), N, M); + for(int c = 0; c < N; c++) { + for(int r = 0; r < M; r++) { + EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) + << "Assignment differs at row " << r << " and col " << c; + } } - } } TEST(HungarianTest, NonSquare2) { - int M = 1; // Rows - int N = 2; // Columns - std::vector costsVec = {0, 0.5}; - - std::vector expRowIdxs = {0}; - std::vector expColIdxs = {1}; - - const int num_indices = std::min(N, M); - std::vector rowIdxs(num_indices, -1); - std::vector colIdxs(num_indices, -1); - hungarian(costsVec.data(), rowIdxs.data(), colIdxs.data(), M, N); - for (int i = 0; i < num_indices; i++) { - EXPECT_EQ(rowIdxs[i], expRowIdxs[i]) << "Assignment differs at index " << i; - EXPECT_EQ(colIdxs[i], colIdxs[i]) << "Assignment differs at index " << i; - } + int M = 1; // Rows + int N = 2; // Columns + std::vector costsVec = {0, 0.5}; + + std::vector expRowIdxs = {0}; + std::vector expColIdxs = {1}; + + const int num_indices = std::min(N, M); + std::vector rowIdxs(num_indices, -1); + std::vector colIdxs(num_indices, -1); + hungarian(costsVec.data(), rowIdxs.data(), colIdxs.data(), M, N); + for(int i = 0; i < num_indices; i++) { + EXPECT_EQ(rowIdxs[i], expRowIdxs[i]) << "Assignment differs at index " << i; + EXPECT_EQ(colIdxs[i], colIdxs[i]) << "Assignment differs at index " << i; + } } TEST(HungarianTest, NonSquare) { - int M = 1; // Rows - int N 
= 2; // Columns - std::vector costsVec = {0.5, 0}; - - std::vector expRowIdxs = {0}; - std::vector expColIdxs = {0}; - - const int num_indices = std::min(N, M); - std::vector rowIdxs(num_indices, -1); - std::vector colIdxs(num_indices, -1); - hungarian(costsVec.data(), rowIdxs.data(), colIdxs.data(), M, N); - for (int i = 0; i < num_indices; i++) { - EXPECT_EQ(rowIdxs[i], expRowIdxs[i]) << "Assignment differs at index " << i; - EXPECT_EQ(colIdxs[i], colIdxs[i]) << "Assignment differs at index " << i; - } + int M = 1; // Rows + int N = 2; // Columns + std::vector costsVec = {0.5, 0}; + + std::vector expRowIdxs = {0}; + std::vector expColIdxs = {0}; + + const int num_indices = std::min(N, M); + std::vector rowIdxs(num_indices, -1); + std::vector colIdxs(num_indices, -1); + hungarian(costsVec.data(), rowIdxs.data(), colIdxs.data(), M, N); + for(int i = 0; i < num_indices; i++) { + EXPECT_EQ(rowIdxs[i], expRowIdxs[i]) << "Assignment differs at index " << i; + EXPECT_EQ(colIdxs[i], colIdxs[i]) << "Assignment differs at index " << i; + } } TEST(HungarianTest, NonSquare3) { - int M = 2; // Rows - int N = 3; // Columns - std::vector costsVec = { - 0, - 0.5, - 0.5, - 2, - 2, - 3, - }; - - std::vector expRowIdxs = { - 0, - 1, - }; - std::vector expColIdxs = {1, 0}; - - const int num_indices = std::min(N, M); - std::vector rowIdxs(num_indices, -1); - std::vector colIdxs(num_indices, -1); - hungarian(costsVec.data(), rowIdxs.data(), colIdxs.data(), M, N); - for (int i = 0; i < num_indices; i++) { - EXPECT_EQ(rowIdxs[i], expRowIdxs[i]) << "Assignment differs at index " << i; - EXPECT_EQ(colIdxs[i], colIdxs[i]) << "Assignment differs at index " << i; - } + int M = 2; // Rows + int N = 3; // Columns + std::vector costsVec = { + 0, + 0.5, + 0.5, + 2, + 2, + 3, + }; + + std::vector expRowIdxs = { + 0, + 1, + }; + std::vector expColIdxs = {1, 0}; + + const int num_indices = std::min(N, M); + std::vector rowIdxs(num_indices, -1); + std::vector colIdxs(num_indices, -1); + 
hungarian(costsVec.data(), rowIdxs.data(), colIdxs.data(), M, N); + for(int i = 0; i < num_indices; i++) { + EXPECT_EQ(rowIdxs[i], expRowIdxs[i]) << "Assignment differs at index " << i; + EXPECT_EQ(colIdxs[i], colIdxs[i]) << "Assignment differs at index " << i; + } } diff --git a/flashlight/pkg/vision/test/criterion/SetCriterionTest.cpp b/flashlight/pkg/vision/test/criterion/SetCriterionTest.cpp index 8e28291..b64fc6d 100644 --- a/flashlight/pkg/vision/test/criterion/SetCriterionTest.cpp +++ b/flashlight/pkg/vision/test/criterion/SetCriterionTest.cpp @@ -17,364 +17,409 @@ using namespace fl; using namespace fl::pkg::vision; std::unordered_map getLossWeights() { - const std::unordered_map lossWeightsBase = { - {"lossCe", 1.f}, {"lossGiou", 1.f}, {"lossBbox", 1.f}}; - - std::unordered_map lossWeights; - for (int i = 0; i < 6; i++) { - for (const auto& l : lossWeightsBase) { - std::string key = l.first + "_" + std::to_string(i); - lossWeights[key] = l.second; + const std::unordered_map lossWeightsBase = { + {"lossCe", 1.f}, {"lossGiou", 1.f}, {"lossBbox", 1.f}}; + + std::unordered_map lossWeights; + for(int i = 0; i < 6; i++) { + for(const auto& l : lossWeightsBase) { + std::string key = l.first + "_" + std::to_string(i); + lossWeights[key] = l.second; + } } - } - return lossWeights; + return lossWeights; } TEST(SetCriterion, PytorchRepro) { - const int numClasses = 80; - const int numTargets = 1; - const int numPreds = 1; - const int numBatches = 1; - std::vector predBoxesVec = {2, 2, 3, 3}; - - std::vector targetBoxesVec = {2, 2, 3, 3}; - - std::vector targetClassVec = {1}; - auto predBoxes = fl::Variable( - Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), true); - auto predLogits = - fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); - - std::vector targetBoxes = {fl::Variable( - Tensor::fromVector({4, numTargets, numBatches}, targetBoxesVec), false)}; - - std::vector targetClasses = {fl::Variable( - 
Tensor::fromVector({numTargets, numBatches}, targetClassVec), false)}; - auto matcher = HungarianMatcher(1, 1, 1); - auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); - auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); - EXPECT_EQ(loss["lossGiou_0"].scalar(), 0.0); + const int numClasses = 80; + const int numTargets = 1; + const int numPreds = 1; + const int numBatches = 1; + std::vector predBoxesVec = {2, 2, 3, 3}; + + std::vector targetBoxesVec = {2, 2, 3, 3}; + + std::vector targetClassVec = {1}; + auto predBoxes = fl::Variable( + Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), + true + ); + auto predLogits = + fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); + + std::vector targetBoxes = {fl::Variable( + Tensor::fromVector({4, numTargets, numBatches}, targetBoxesVec), + false + )}; + + std::vector targetClasses = {fl::Variable( + Tensor::fromVector({numTargets, numBatches}, targetClassVec), + false + )}; + auto matcher = HungarianMatcher(1, 1, 1); + auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); + auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); + EXPECT_EQ(loss["lossGiou_0"].scalar(), 0.0); } TEST(SetCriterion, PytorchReproMultiplePreds) { - // TODO: This should really be a fixture - const int numClasses = 80; - const int numTargets = 1; - const int numPreds = 2; - const int numBatches = 1; - std::vector predBoxesVec = {2, 2, 3, 3, 1, 1, 2, 2}; - - std::vector targetBoxesVec = {2, 2, 3, 3}; - - std::vector targetClassVec = {1}; - auto predBoxes = fl::Variable( - Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), true); - auto predLogits = - fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); - - std::vector targetBoxes = {fl::Variable( - Tensor::fromVector({4, numTargets, numBatches}, targetBoxesVec), false)}; - - std::vector targetClasses = {fl::Variable( - Tensor::fromVector({1, numTargets, numBatches}, 
targetClassVec), false)}; - auto matcher = HungarianMatcher(1, 1, 1); - auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); - auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); - EXPECT_EQ(loss["lossGiou_0"].scalar(), 0.0); + // TODO: This should really be a fixture + const int numClasses = 80; + const int numTargets = 1; + const int numPreds = 2; + const int numBatches = 1; + std::vector predBoxesVec = {2, 2, 3, 3, 1, 1, 2, 2}; + + std::vector targetBoxesVec = {2, 2, 3, 3}; + + std::vector targetClassVec = {1}; + auto predBoxes = fl::Variable( + Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), + true + ); + auto predLogits = + fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); + + std::vector targetBoxes = {fl::Variable( + Tensor::fromVector({4, numTargets, numBatches}, targetBoxesVec), + false + )}; + + std::vector targetClasses = {fl::Variable( + Tensor::fromVector({1, numTargets, numBatches}, targetClassVec), + false + )}; + auto matcher = HungarianMatcher(1, 1, 1); + auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); + auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); + EXPECT_EQ(loss["lossGiou_0"].scalar(), 0.0); } TEST(SetCriterion, PytorchReproMultipleTargets) { - const int numClasses = 80; - const int numTargets = 2; - const int numPreds = 2; - const int numBatches = 1; - std::vector predBoxesVec = {2, 2, 3, 3, 1, 1, 2, 2}; - - std::vector targetBoxesVec = { - 1, - 1, - 2, - 2, - 2, - 2, - 3, - 3, - }; - - std::vector targetClassVec = {1}; - auto predBoxes = fl::Variable( - Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), true); - auto predLogits = - fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); - - std::vector targetBoxes = {fl::Variable( - Tensor::fromVector({4, numTargets, numBatches}, targetBoxesVec), false)}; - - std::vector targetClasses = {fl::Variable( - Tensor::fromVector({numTargets, numBatches}, 
targetClassVec), false)}; - auto matcher = HungarianMatcher(1, 1, 1); - auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); - auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); - EXPECT_FLOAT_EQ(loss["lossGiou_0"].scalar(), 0.0); + const int numClasses = 80; + const int numTargets = 2; + const int numPreds = 2; + const int numBatches = 1; + std::vector predBoxesVec = {2, 2, 3, 3, 1, 1, 2, 2}; + + std::vector targetBoxesVec = { + 1, + 1, + 2, + 2, + 2, + 2, + 3, + 3, + }; + + std::vector targetClassVec = {1}; + auto predBoxes = fl::Variable( + Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), + true + ); + auto predLogits = + fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); + + std::vector targetBoxes = {fl::Variable( + Tensor::fromVector({4, numTargets, numBatches}, targetBoxesVec), + false + )}; + + std::vector targetClasses = {fl::Variable( + Tensor::fromVector({numTargets, numBatches}, targetClassVec), + false + )}; + auto matcher = HungarianMatcher(1, 1, 1); + auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); + auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); + EXPECT_FLOAT_EQ(loss["lossGiou_0"].scalar(), 0.0); } TEST(SetCriterion, PytorchReproNoPerfectMatch) { - const int numClasses = 80; - const int numTargets = 2; - const int numPreds = 2; - const int numBatches = 1; - std::vector predBoxesVec = {2, 2, 3, 3, 1, 1, 2, 2}; - - std::vector targetBoxesVec = { - 0.9, 0.8, 1.9, 1.95, 1.9, 1.95, 2.9, 2.95}; - - // std::vector predLogitsVec((numClasses + 1) * numPreds * numPreds, - // 0.0); - - std::vector targetClassVec = {1, 1}; - - auto predBoxes = fl::Variable( - Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), true); - auto predLogits = - fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); - - std::vector targetBoxes = {fl::Variable( - Tensor::fromVector({4, numTargets, numBatches}, targetBoxesVec), false)}; - - 
std::vector targetClasses = {fl::Variable( - Tensor::fromVector({numTargets, numBatches}, targetClassVec), false)}; - auto matcher = HungarianMatcher(1, 1, 1); - auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); - auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); - EXPECT_FLOAT_EQ(loss["lossGiou_0"].scalar(), 0.18111613); - EXPECT_FLOAT_EQ(loss["lossBbox_0"].scalar(), 0.3750); + const int numClasses = 80; + const int numTargets = 2; + const int numPreds = 2; + const int numBatches = 1; + std::vector predBoxesVec = {2, 2, 3, 3, 1, 1, 2, 2}; + + std::vector targetBoxesVec = { + 0.9, 0.8, 1.9, 1.95, 1.9, 1.95, 2.9, 2.95}; + + // std::vector predLogitsVec((numClasses + 1) * numPreds * numPreds, + // 0.0); + + std::vector targetClassVec = {1, 1}; + + auto predBoxes = fl::Variable( + Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), + true + ); + auto predLogits = + fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); + + std::vector targetBoxes = {fl::Variable( + Tensor::fromVector({4, numTargets, numBatches}, targetBoxesVec), + false + )}; + + std::vector targetClasses = {fl::Variable( + Tensor::fromVector({numTargets, numBatches}, targetClassVec), + false + )}; + auto matcher = HungarianMatcher(1, 1, 1); + auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); + auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); + EXPECT_FLOAT_EQ(loss["lossGiou_0"].scalar(), 0.18111613); + EXPECT_FLOAT_EQ(loss["lossBbox_0"].scalar(), 0.3750); } TEST(SetCriterion, PytorchMismatch1) { - const int numClasses = 80; - const int numTargets = 1; - const int numPreds = 1; - const int numBatches = 1; - std::vector predBoxesVec = { - 2, - 2, - 3, - 3, - }; - - std::vector targetBoxesVec1 = { - 1, - 1, - 2, - 2, - }; - - std::vector targetClassVec = {1, 1}; - - auto predBoxes = fl::Variable( - Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), true); - auto predLogits = - 
fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); - - std::vector targetBoxes = { - fl::Variable( - Tensor::fromVector({4, numTargets, numPreds}, targetBoxesVec1), - false), - }; - - std::vector targetClasses = { - fl::Variable( - Tensor::fromVector({1, numTargets, numPreds}, targetClassVec), false), - }; - auto matcher = HungarianMatcher(1, 1, 1); - auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); - auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); - EXPECT_FLOAT_EQ(loss["lossGiou_0"].scalar(), 0.91314667f); - EXPECT_FLOAT_EQ(loss["lossBbox_0"].scalar(), 4.f); + const int numClasses = 80; + const int numTargets = 1; + const int numPreds = 1; + const int numBatches = 1; + std::vector predBoxesVec = { + 2, + 2, + 3, + 3, + }; + + std::vector targetBoxesVec1 = { + 1, + 1, + 2, + 2, + }; + + std::vector targetClassVec = {1, 1}; + + auto predBoxes = fl::Variable( + Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), + true + ); + auto predLogits = + fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); + + std::vector targetBoxes = { + fl::Variable( + Tensor::fromVector({4, numTargets, numPreds}, targetBoxesVec1), + false + ), + }; + + std::vector targetClasses = { + fl::Variable( + Tensor::fromVector({1, numTargets, numPreds}, targetClassVec), + false + ), + }; + auto matcher = HungarianMatcher(1, 1, 1); + auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); + auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); + EXPECT_FLOAT_EQ(loss["lossGiou_0"].scalar(), 0.91314667f); + EXPECT_FLOAT_EQ(loss["lossBbox_0"].scalar(), 4.f); } TEST(SetCriterion, PytorchMismatch2) { - const int numClasses = 80; - const int numTargets = 1; - const int numPreds = 1; - const int numBatches = 1; - std::vector predBoxesVec = { - 1, - 1, - 2, - 2, - }; - - std::vector targetBoxesVec1 = { - 2, - 2, - 3, - 3, - }; - - std::vector targetClassVec = {1, 1}; - - auto 
predBoxes = fl::Variable( - Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), true); - auto predLogits = - fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); - - std::vector targetBoxes = { - fl::Variable( - Tensor::fromVector({4, numTargets, numPreds}, targetBoxesVec1), - false), - }; - - std::vector targetClasses = { - fl::Variable( - Tensor::fromVector({1, numTargets, numPreds}, targetClassVec), false), - }; - auto matcher = HungarianMatcher(1, 1, 1); - auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); - auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); - EXPECT_FLOAT_EQ(loss["lossGiou_0"].scalar(), 0.91314667f); - EXPECT_FLOAT_EQ(loss["lossBbox_0"].scalar(), 4.0f); + const int numClasses = 80; + const int numTargets = 1; + const int numPreds = 1; + const int numBatches = 1; + std::vector predBoxesVec = { + 1, + 1, + 2, + 2, + }; + + std::vector targetBoxesVec1 = { + 2, + 2, + 3, + 3, + }; + + std::vector targetClassVec = {1, 1}; + + auto predBoxes = fl::Variable( + Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), + true + ); + auto predLogits = + fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); + + std::vector targetBoxes = { + fl::Variable( + Tensor::fromVector({4, numTargets, numPreds}, targetBoxesVec1), + false + ), + }; + + std::vector targetClasses = { + fl::Variable( + Tensor::fromVector({1, numTargets, numPreds}, targetClassVec), + false + ), + }; + auto matcher = HungarianMatcher(1, 1, 1); + auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); + auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); + EXPECT_FLOAT_EQ(loss["lossGiou_0"].scalar(), 0.91314667f); + EXPECT_FLOAT_EQ(loss["lossBbox_0"].scalar(), 4.0f); } TEST(SetCriterion, PytorchReproBatching) { - const int numClasses = 80; - const int numTargets = 1; - const int numPreds = 1; - const int numBatches = 2; - std::vector predBoxesVec = {2, 2, 3, 3, 1, 
1, 2, 2}; - - std::vector targetBoxesVec1 = { - 1, - 1, - 2, - 2, - }; - - std::vector targetBoxesVec2 = { - 2, - 2, - 3, - 3, - }; - - std::vector targetClassVec = {1, 1}; - - auto predBoxes = fl::Variable( - Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), true); - auto predLogits = - fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); - - std::vector targetBoxes = { - fl::Variable( - Tensor::fromVector({4, numTargets, numPreds}, targetBoxesVec1), - false), - fl::Variable( - Tensor::fromVector({4, numTargets, numPreds}, targetBoxesVec2), - false)}; - - std::vector targetClasses = { - fl::Variable( - Tensor::fromVector({numTargets, numPreds, 1}, targetClassVec), false), - fl::Variable( - Tensor::fromVector({numTargets, numPreds, 1}, targetClassVec), - false)}; - auto matcher = HungarianMatcher(1, 1, 1); - auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); - auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); - EXPECT_FLOAT_EQ(loss["lossGiou_0"].scalar(), 0.91314667f); - EXPECT_FLOAT_EQ(loss["lossBbox_0"].scalar(), 4.f); + const int numClasses = 80; + const int numTargets = 1; + const int numPreds = 1; + const int numBatches = 2; + std::vector predBoxesVec = {2, 2, 3, 3, 1, 1, 2, 2}; + + std::vector targetBoxesVec1 = { + 1, + 1, + 2, + 2, + }; + + std::vector targetBoxesVec2 = { + 2, + 2, + 3, + 3, + }; + + std::vector targetClassVec = {1, 1}; + + auto predBoxes = fl::Variable( + Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), + true + ); + auto predLogits = + fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); + + std::vector targetBoxes = { + fl::Variable( + Tensor::fromVector({4, numTargets, numPreds}, targetBoxesVec1), + false + ), + fl::Variable( + Tensor::fromVector({4, numTargets, numPreds}, targetBoxesVec2), + false + )}; + + std::vector targetClasses = { + fl::Variable( + Tensor::fromVector({numTargets, numPreds, 1}, targetClassVec), + false + ), + 
fl::Variable( + Tensor::fromVector({numTargets, numPreds, 1}, targetClassVec), + false + )}; + auto matcher = HungarianMatcher(1, 1, 1); + auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); + auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); + EXPECT_FLOAT_EQ(loss["lossGiou_0"].scalar(), 0.91314667f); + EXPECT_FLOAT_EQ(loss["lossBbox_0"].scalar(), 4.f); } TEST(SetCriterion, DifferentNumberOfLabels) { - const int numClasses = 80; - const int numPreds = 2; - const int numBatches = 2; - std::vector predBoxesVec = { - 2, 2, 3, 3, 1, 1, 2, 2, 2, 2, 3, 3, 1, 1, 2, 2}; - - std::vector targetBoxesVec1 = { - 1, - 1, - 2, - 2, - 2, - 2, - 3, - 3, - }; - - std::vector targetBoxesVec2 = { - 2, - 2, - 3, - 3, - }; - - // std::vector predLogitsVec((numClasses + 1) * numPreds * numPreds, - // 0.0); - - std::vector targetClassVec = {1, 1}; - - auto predBoxes = fl::Variable( - Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), true); - auto predLogits = - fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); - - std::vector targetBoxes = { - fl::Variable(Tensor::fromVector({4, 2, 1}, targetBoxesVec1), false), - fl::Variable(Tensor::fromVector({4, 1, 1}, targetBoxesVec2), false)}; - - std::vector targetClasses = { - fl::Variable(fl::full({2, 1, 1}, 1), false), - fl::Variable(fl::full({1, 1, 1}, 1), false)}; - auto matcher = HungarianMatcher(1, 1, 1); - auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); - auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); - EXPECT_FLOAT_EQ(loss["lossGiou_0"].scalar(), 0.f); - EXPECT_FLOAT_EQ(loss["lossBbox_0"].scalar(), 0.f); + const int numClasses = 80; + const int numPreds = 2; + const int numBatches = 2; + std::vector predBoxesVec = { + 2, 2, 3, 3, 1, 1, 2, 2, 2, 2, 3, 3, 1, 1, 2, 2}; + + std::vector targetBoxesVec1 = { + 1, + 1, + 2, + 2, + 2, + 2, + 3, + 3, + }; + + std::vector targetBoxesVec2 = { + 2, + 2, + 3, + 3, + }; + + // std::vector 
predLogitsVec((numClasses + 1) * numPreds * numPreds, + // 0.0); + + std::vector targetClassVec = {1, 1}; + + auto predBoxes = fl::Variable( + Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), + true + ); + auto predLogits = + fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); + + std::vector targetBoxes = { + fl::Variable(Tensor::fromVector({4, 2, 1}, targetBoxesVec1), false), + fl::Variable(Tensor::fromVector({4, 1, 1}, targetBoxesVec2), false)}; + + std::vector targetClasses = { + fl::Variable(fl::full({2, 1, 1}, 1), false), + fl::Variable(fl::full({1, 1, 1}, 1), false)}; + auto matcher = HungarianMatcher(1, 1, 1); + auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); + auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); + EXPECT_FLOAT_EQ(loss["lossGiou_0"].scalar(), 0.f); + EXPECT_FLOAT_EQ(loss["lossBbox_0"].scalar(), 0.f); } // Test to make sure class labels are properly handles across batches TEST(SetCriterion, DifferentNumberOfLabelsClass) { - const int numClasses = 80; - const int numPreds = 3; - const int numBatches = 2; - std::vector predBoxesVec = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; - - std::vector targetBoxesVec1 = {1, 1, 1, 1, 1, 1, 1, 1}; - - std::vector targetBoxesVec2 = { - 1, - 1, - 1, - 1, - }; - - auto predBoxes = fl::Variable( - Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), true); - auto predLogitsT = fl::full({numClasses + 1, numPreds, numBatches}, 1.); - predLogitsT(1, 1, 0) = 10; // These should get matched - predLogitsT(2, 2, 0) = 10; - predLogitsT(9, 1, 1) = 10; - auto predLogits = fl::Variable(predLogitsT, true); - - std::vector targetBoxes = { - fl::Variable(Tensor::fromVector({4, 2, 1}, targetBoxesVec1), false), - fl::Variable(Tensor::fromVector({4, 1, 1}, targetBoxesVec2), false)}; - - std::vector targetClasses = { - fl::Variable(fl::iota({2}), false), - fl::Variable(fl::full({1, 1, 1}, 9), false)}; - auto 
matcher = HungarianMatcher(1, 1, 1); - auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); - auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); - EXPECT_FLOAT_EQ(loss["lossGiou_0"].scalar(), 0.f); - EXPECT_FLOAT_EQ(loss["lossBbox_0"].scalar(), 0.f); - EXPECT_NEAR(loss["lossCe_0"].scalar(), 1.4713663f, 1e-4); + const int numClasses = 80; + const int numPreds = 3; + const int numBatches = 2; + std::vector predBoxesVec = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + + std::vector targetBoxesVec1 = {1, 1, 1, 1, 1, 1, 1, 1}; + + std::vector targetBoxesVec2 = { + 1, + 1, + 1, + 1, + }; + + auto predBoxes = fl::Variable( + Tensor::fromVector({4, numPreds, numBatches, 1}, predBoxesVec), + true + ); + auto predLogitsT = fl::full({numClasses + 1, numPreds, numBatches}, 1.); + predLogitsT(1, 1, 0) = 10; // These should get matched + predLogitsT(2, 2, 0) = 10; + predLogitsT(9, 1, 1) = 10; + auto predLogits = fl::Variable(predLogitsT, true); + + std::vector targetBoxes = { + fl::Variable(Tensor::fromVector({4, 2, 1}, targetBoxesVec1), false), + fl::Variable(Tensor::fromVector({4, 1, 1}, targetBoxesVec2), false)}; + + std::vector targetClasses = { + fl::Variable(fl::iota({2}), false), + fl::Variable(fl::full({1, 1, 1}, 9), false)}; + auto matcher = HungarianMatcher(1, 1, 1); + auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); + auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); + EXPECT_FLOAT_EQ(loss["lossGiou_0"].scalar(), 0.f); + EXPECT_FLOAT_EQ(loss["lossBbox_0"].scalar(), 0.f); + EXPECT_NEAR(loss["lossCe_0"].scalar(), 1.4713663f, 1e-4); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + ::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } diff --git a/flashlight/pkg/vision/test/dataset/BoxUtilsTest.cpp b/flashlight/pkg/vision/test/dataset/BoxUtilsTest.cpp index 
2da62d2..1ded939 100644 --- a/flashlight/pkg/vision/test/dataset/BoxUtilsTest.cpp +++ b/flashlight/pkg/vision/test/dataset/BoxUtilsTest.cpp @@ -14,151 +14,153 @@ using namespace fl::pkg::vision; TEST(BoxUtils, IOU) { - std::vector labels = {0, 0, 10, 10, 1}; - std::vector preds = {1, 1, 11, 11, 1}; - std::vector costs = {0.680672268908}; - fl::Variable labelArr = {fl::Tensor::fromVector({5, 1, 1}, labels), false}; - fl::Variable predArr = {fl::Tensor::fromVector({5, 1, 1}, preds), false}; - fl::Variable result, uni; - std::tie(result, uni) = boxIou(predArr, labelArr); - EXPECT_EQ(result.tensor()(0, 0).scalar(), costs[0]); + std::vector labels = {0, 0, 10, 10, 1}; + std::vector preds = {1, 1, 11, 11, 1}; + std::vector costs = {0.680672268908}; + fl::Variable labelArr = {fl::Tensor::fromVector({5, 1, 1}, labels), false}; + fl::Variable predArr = {fl::Tensor::fromVector({5, 1, 1}, preds), false}; + fl::Variable result, uni; + std::tie(result, uni) = boxIou(predArr, labelArr); + EXPECT_EQ(result.tensor()(0, 0).scalar(), costs[0]); } TEST(BoxUtils, IOU2) { - std::vector labels = {0, 0, 10, 10, 1}; - std::vector preds = {12, 12, 22, 22, 1}; - std::vector costs = {0.0}; - fl::Variable labelArr = {fl::Tensor::fromVector({5, 1, 1}, labels), false}; - fl::Variable predArr = {fl::Tensor::fromVector({5, 1, 1}, preds), false}; - fl::Variable result, uni; - std::tie(result, uni) = boxIou(predArr, labelArr); - EXPECT_EQ(result(0, 0).tensor().scalar(), costs[0]); + std::vector labels = {0, 0, 10, 10, 1}; + std::vector preds = {12, 12, 22, 22, 1}; + std::vector costs = {0.0}; + fl::Variable labelArr = {fl::Tensor::fromVector({5, 1, 1}, labels), false}; + fl::Variable predArr = {fl::Tensor::fromVector({5, 1, 1}, preds), false}; + fl::Variable result, uni; + std::tie(result, uni) = boxIou(predArr, labelArr); + EXPECT_EQ(result(0, 0).tensor().scalar(), costs[0]); } TEST(BoxUtils, IOU3) { - std::vector labels = {0, 0, 2, 2, 1}; - std::vector preds = {1, 1, 3, 3, 1}; - std::vector costs 
= {0.142857142857}; - fl::Variable labelArr = - fl::Variable(fl::Tensor::fromVector({5, 1, 1}, labels), false); - fl::Variable predArr = - fl::Variable(fl::Tensor::fromVector({5, 1, 1}, preds), false); - fl::Variable result, uni; - std::tie(result, uni) = boxIou(predArr, labelArr); - EXPECT_EQ(result(0, 0).tensor().scalar(), costs[0]); + std::vector labels = {0, 0, 2, 2, 1}; + std::vector preds = {1, 1, 3, 3, 1}; + std::vector costs = {0.142857142857}; + fl::Variable labelArr = + fl::Variable(fl::Tensor::fromVector({5, 1, 1}, labels), false); + fl::Variable predArr = + fl::Variable(fl::Tensor::fromVector({5, 1, 1}, preds), false); + fl::Variable result, uni; + std::tie(result, uni) = boxIou(predArr, labelArr); + EXPECT_EQ(result(0, 0).tensor().scalar(), costs[0]); } TEST(BoxUtils, IOU4) { - std::vector labels = {0, 0, 2, 2, 1}; - std::vector preds = {3, 0, 5, 2, 1, 1}; - std::vector costs = {0.0}; - fl::Variable labelArr = {fl::Tensor::fromVector({5, 1, 1}, labels), false}; - fl::Variable predArr = {fl::Tensor::fromVector({5, 1, 1}, preds), false}; - fl::Variable result, uni; - std::tie(result, uni) = boxIou(predArr, labelArr); - EXPECT_EQ(result(0, 0).tensor().scalar(), costs[0]); + std::vector labels = {0, 0, 2, 2, 1}; + std::vector preds = {3, 0, 5, 2, 1, 1}; + std::vector costs = {0.0}; + fl::Variable labelArr = {fl::Tensor::fromVector({5, 1, 1}, labels), false}; + fl::Variable predArr = {fl::Tensor::fromVector({5, 1, 1}, preds), false}; + fl::Variable result, uni; + std::tie(result, uni) = boxIou(predArr, labelArr); + EXPECT_EQ(result(0, 0).tensor().scalar(), costs[0]); } TEST(BoxUtils, IOU5) { - std::vector labels = {0, 0, 2, 2, 1}; - std::vector preds = {1, 1, 3, 3, 1}; - std::vector costs = {0.14285714285714285}; - fl::Variable labelArr = {fl::Tensor::fromVector({5, 1, 1}, labels), false}; - fl::Variable predArr = {fl::Tensor::fromVector({5, 1, 1}, preds), false}; - fl::Variable result, uni; - std::tie(result, uni) = boxIou(predArr, labelArr); - 
EXPECT_EQ(result.tensor()(0, 0).scalar(), costs[0]); + std::vector labels = {0, 0, 2, 2, 1}; + std::vector preds = {1, 1, 3, 3, 1}; + std::vector costs = {0.14285714285714285}; + fl::Variable labelArr = {fl::Tensor::fromVector({5, 1, 1}, labels), false}; + fl::Variable predArr = {fl::Tensor::fromVector({5, 1, 1}, preds), false}; + fl::Variable result, uni; + std::tie(result, uni) = boxIou(predArr, labelArr); + EXPECT_EQ(result.tensor()(0, 0).scalar(), costs[0]); } TEST(BoxUtils, IOU6) { - std::vector labels = {0, 0, 4, 4, 1}; - std::vector preds = {1, 1, 3, 3, 1}; - std::vector costs = {0.25}; - fl::Variable labelArr = {fl::Tensor::fromVector({5, 1, 1}, labels), false}; - fl::Variable predArr = {fl::Tensor::fromVector({5, 1, 1}, preds), false}; - fl::Variable result, uni; - std::tie(result, uni) = boxIou(predArr, labelArr); - EXPECT_EQ(result(0, 0).tensor().scalar(), costs[0]); + std::vector labels = {0, 0, 4, 4, 1}; + std::vector preds = {1, 1, 3, 3, 1}; + std::vector costs = {0.25}; + fl::Variable labelArr = {fl::Tensor::fromVector({5, 1, 1}, labels), false}; + fl::Variable predArr = {fl::Tensor::fromVector({5, 1, 1}, preds), false}; + fl::Variable result, uni; + std::tie(result, uni) = boxIou(predArr, labelArr); + EXPECT_EQ(result(0, 0).tensor().scalar(), costs[0]); } TEST(BoxUtils, IOU7) { - std::vector preds = {1, 1, 3, 3, 1, 0, 1, 2, 3, 1}; - std::vector labels = {0, 0, 4, 4, 1, 0, 0, 2, 2, 1}; - std::vector costs = { - 0.25, - 0.25, // Both boxes are contained in first box - 0.14285714285714285, - 0.3333333333 // - }; - fl::Variable labelArr = {fl::Tensor::fromVector({5, 2, 1}, labels), false}; - fl::Variable predArr = {fl::Tensor::fromVector({5, 2, 1}, preds), false}; - fl::Variable result, uni; - std::tie(result, uni) = boxIou(predArr, labelArr); - EXPECT_EQ(result(0, 0).tensor().scalar(), costs[0]); - EXPECT_EQ(result(1, 0).tensor().scalar(), costs[1]); - EXPECT_EQ(result(0, 1).tensor().scalar(), costs[2]); - EXPECT_EQ(result(1, 1).tensor().scalar(), 
costs[3]); + std::vector preds = {1, 1, 3, 3, 1, 0, 1, 2, 3, 1}; + std::vector labels = {0, 0, 4, 4, 1, 0, 0, 2, 2, 1}; + std::vector costs = { + 0.25, + 0.25, // Both boxes are contained in first box + 0.14285714285714285, + 0.3333333333 // + }; + fl::Variable labelArr = {fl::Tensor::fromVector({5, 2, 1}, labels), false}; + fl::Variable predArr = {fl::Tensor::fromVector({5, 2, 1}, preds), false}; + fl::Variable result, uni; + std::tie(result, uni) = boxIou(predArr, labelArr); + EXPECT_EQ(result(0, 0).tensor().scalar(), costs[0]); + EXPECT_EQ(result(1, 0).tensor().scalar(), costs[1]); + EXPECT_EQ(result(0, 1).tensor().scalar(), costs[2]); + EXPECT_EQ(result(1, 1).tensor().scalar(), costs[3]); } TEST(BoxUtils, IOU8) { - std::vector preds = {1, 1, 3, 3, 1, 0, 1, 2, 3, 1}; - std::vector labels = { - 0, - 0, - 4, - 4, - 1, - }; - std::vector costs = { - 0.25, 0.25, // Both boxes are contained in first box - }; - fl::Variable labelArr = {fl::Tensor::fromVector({5, 2, 1}, labels), false}; - fl::Variable predArr = {fl::Tensor::fromVector({5, 2, 1}, preds), false}; - fl::Variable result, uni; - std::tie(result, uni) = boxIou(predArr, labelArr); - EXPECT_EQ(result(0, 0).tensor().scalar(), costs[0]); - EXPECT_EQ(result(1, 0).tensor().scalar(), costs[1]); + std::vector preds = {1, 1, 3, 3, 1, 0, 1, 2, 3, 1}; + std::vector labels = { + 0, + 0, + 4, + 4, + 1, + }; + std::vector costs = { + 0.25, 0.25, // Both boxes are contained in first box + }; + fl::Variable labelArr = {fl::Tensor::fromVector({5, 2, 1}, labels), false}; + fl::Variable predArr = {fl::Tensor::fromVector({5, 2, 1}, preds), false}; + fl::Variable result, uni; + std::tie(result, uni) = boxIou(predArr, labelArr); + EXPECT_EQ(result(0, 0).tensor().scalar(), costs[0]); + EXPECT_EQ(result(1, 0).tensor().scalar(), costs[1]); } TEST(BoxUtils, IOUBatched) { - std::vector preds = {1, 1, 3, 3, 1, 0, 1, 2, 3, 1}; - std::vector labels = { - 0, - 0, - 4, - 4, - 1, - 0, - 0, - 4, - 4, - 1, - }; - std::vector costs = { - 0.25, 
0.25, // Both boxes are contained in first box - }; - fl::Variable labelArr = {fl::Tensor::fromVector({5, 1, 2}, labels), false}; - fl::Variable predArr = {fl::Tensor::fromVector({5, 1, 2}, preds), false}; - fl::Variable result, uni; - std::tie(result, uni) = boxIou(predArr, labelArr); - EXPECT_EQ(result(0, 0, 0).tensor().scalar(), costs[0]); - EXPECT_EQ(result(0, 0, 1).tensor().scalar(), costs[1]); + std::vector preds = {1, 1, 3, 3, 1, 0, 1, 2, 3, 1}; + std::vector labels = { + 0, + 0, + 4, + 4, + 1, + 0, + 0, + 4, + 4, + 1, + }; + std::vector costs = { + 0.25, 0.25, // Both boxes are contained in first box + }; + fl::Variable labelArr = {fl::Tensor::fromVector({5, 1, 2}, labels), false}; + fl::Variable predArr = {fl::Tensor::fromVector({5, 1, 2}, preds), false}; + fl::Variable result, uni; + std::tie(result, uni) = boxIou(predArr, labelArr); + EXPECT_EQ(result(0, 0, 0).tensor().scalar(), costs[0]); + EXPECT_EQ(result(0, 0, 1).tensor().scalar(), costs[1]); } //// Test GIOU //// The first box is further away from the second box and should have a smaller //// score TEST(BoxUtils, GIOU) { - std::vector preds = {0, 0, 1, 1, 1, 1, 1, 2, 2, 1}; - std::vector labels = {2, 2, 3, 3, 1}; - fl::Variable labelArr = {fl::Tensor::fromVector({5, 1, 1}, labels), false}; - fl::Variable predArr = {fl::Tensor::fromVector({5, 2, 1}, preds), false}; - fl::Variable result = generalizedBoxIou(predArr, labelArr); - EXPECT_LT( - result(0, 0).tensor().scalar(), result(1, 0).scalar()); + std::vector preds = {0, 0, 1, 1, 1, 1, 1, 2, 2, 1}; + std::vector labels = {2, 2, 3, 3, 1}; + fl::Variable labelArr = {fl::Tensor::fromVector({5, 1, 1}, labels), false}; + fl::Variable predArr = {fl::Tensor::fromVector({5, 2, 1}, preds), false}; + fl::Variable result = generalizedBoxIou(predArr, labelArr); + EXPECT_LT( + result(0, 0).tensor().scalar(), + result(1, 0).scalar() + ); } int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - fl::init(); - return RUN_ALL_TESTS(); + 
::testing::InitGoogleTest(&argc, argv); + fl::init(); + return RUN_ALL_TESTS(); } From 40551e00435a32c06d8ade696c19a6e75270dcdc Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Mon, 23 Feb 2026 20:02:53 +0100 Subject: [PATCH 15/24] added .cu files to cpp globber --- cmake/utils/fm_target_utilities.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/utils/fm_target_utilities.cmake b/cmake/utils/fm_target_utilities.cmake index b625a2b..9403bfa 100644 --- a/cmake/utils/fm_target_utilities.cmake +++ b/cmake/utils/fm_target_utilities.cmake @@ -84,7 +84,7 @@ endfunction() #]] function(fm_glob_cpp OUT_VAR) - fm_glob(${OUT_VAR} ${ARGN} PATTERNS "*.cpp" "*.hpp" "*.inl" "*.h") + fm_glob(${OUT_VAR} ${ARGN} PATTERNS "*.cpp" "*.hpp" "*.inl" "*.h" "*.cu") set(${OUT_VAR} ${${OUT_VAR}} PARENT_SCOPE) endfunction() From 8198ad5e0cbc5fd010a750eb9121fff09c347c89 Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Mon, 23 Feb 2026 20:04:31 +0100 Subject: [PATCH 16/24] formatted .cu files --- .../fl/tensor/backend/af/AdvancedIndex.cu | 350 +++++++++--------- .../third_party/warpctc/src/ctc_entrypoint.cu | 217 ++++++----- .../speech/third_party/warpctc/src/reduce.cu | 173 ++++++--- 3 files changed, 437 insertions(+), 303 deletions(-) diff --git a/flashlight/fl/tensor/backend/af/AdvancedIndex.cu b/flashlight/fl/tensor/backend/af/AdvancedIndex.cu index 64bd359..84b729f 100644 --- a/flashlight/fl/tensor/backend/af/AdvancedIndex.cu +++ b/flashlight/fl/tensor/backend/af/AdvancedIndex.cu @@ -19,190 +19,198 @@ #define GRID_SIZE 32 #define BLOCK_SIZE 256 -const std::unordered_set validIndexTypes{ +const std::unordered_set < af::dtype > validIndexTypes { af::dtype::s32, af::dtype::s64, af::dtype::u32, - af::dtype::u64}; + af::dtype::u64 +}; -template -__global__ void advancedIndexKernel( +template < class Float, class Index +> __global__ void advancedIndexKernel( const Float* inp, const dim_t* idxStart, const dim_t* idxEnd, const dim_t* outDims, const dim_t* idxArr, - 
Float* out) { - // Compute striding information for - // the input and output tensors - dim_t dims[4], strides[4]; - dim_t outStrides[4]; - for (int i = 0; i < 4; i++) { - dims[i] = idxEnd[i] - idxStart[i]; - } - strides[0] = 1; - outStrides[0] = 1; - // arrayfire dimensions are inverted compared to numpy - // hence, stride computation starts from 1 to 4 - for (int i = 1; i < 4; i++) { - strides[i] = strides[i - 1] * dims[i - 1]; - outStrides[i] = outStrides[i - 1] * outDims[i - 1]; - } - - // Map CUDA thread to an element in the input array - for (dim_t tid = threadIdx.x + blockIdx.x * BLOCK_SIZE; - tid < (strides[3] * dims[3]); - tid += (GRID_SIZE * BLOCK_SIZE)) { - // Compute input array index for CUDA thread - dim_t index[4]; - dim_t cursor = tid; - for (int i = 3; i >= 0; i--) { - index[i] = cursor / strides[i]; - cursor = cursor % strides[i]; + Float* out +) { + // Compute striding information for + // the input and output tensors + dim_t dims[4], strides[4]; + dim_t outStrides[4]; + for(int i = 0; i < 4; i++) { + dims[i] = idxEnd[i] - idxStart[i]; } - - dim_t inpIdx = tid; - dim_t outIdx = 0; - for (int i = 0; i < 4; i++) { - // If indexing array specified, use it - if (idxArr[i]) { - auto idxArrPtr = (Index*)idxArr[i]; - outIdx += idxArrPtr[index[i]] * outStrides[i]; - } else { - outIdx += (idxStart[i] + index[i]) * outStrides[i]; - } + strides[0] = 1; + outStrides[0] = 1; + // arrayfire dimensions are inverted compared to numpy + // hence, stride computation starts from 1 to 4 + for(int i = 1; i < 4; i++) { + strides[i] = strides[i - 1] * dims[i - 1]; + outStrides[i] = outStrides[i - 1] * outDims[i - 1]; } - // atomic addition is done to ensure correct - // gradient computation for repeated indices - atomicAdd(&out[outIdx], inp[inpIdx]); - } -} -namespace fl { -namespace detail { - -void advancedIndex( - const af::array& inp, - const af::dim4& idxStart, - const af::dim4& idxEnd, - const af::dim4& outDims, - const std::vector& idxArr, - af::array& out) { - 
auto inpType = inp.type(); - auto outType = out.type(); - - if ((inpType != af::dtype::f32) && (inpType != af::dtype::f16)) { - throw std::invalid_argument("Input type must be f16/f32"); - } - if ((outType != af::dtype::f32) && (outType != af::dtype::f16)) { - throw std::invalid_argument("Output type must be f16/f32"); - } - if (idxArr.size() != 4) { - throw std::invalid_argument("Index array vector must be length 4"); - } - - af::dim4 idxPtr; - // Extract raw device pointers for dimensions - // that have an array as af::index variable - - // Dtype checking - std::vector idxTypes; - for (int i = 0; i < 4; i++) { - if (idxArr[i].isempty()) { - idxPtr[i] = 0; - continue; + // Map CUDA thread to an element in the input array + for( + dim_t tid = threadIdx.x + blockIdx.x * BLOCK_SIZE; + tid < (strides[3] * dims[3]); + tid += (GRID_SIZE * BLOCK_SIZE) + ) { + // Compute input array index for CUDA thread + dim_t index[4]; + dim_t cursor = tid; + for(int i = 3; i >= 0; i--) { + index[i] = cursor / strides[i]; + cursor = cursor % strides[i]; + } + + dim_t inpIdx = tid; + dim_t outIdx = 0; + for(int i = 0; i < 4; i++) { + // If indexing array specified, use it + if(idxArr[i]) { + auto idxArrPtr = (Index*) idxArr[i]; + outIdx += idxArrPtr[index[i]] * outStrides[i]; + } else { + outIdx += (idxStart[i] + index[i]) * outStrides[i]; + } + } + // atomic addition is done to ensure correct + // gradient computation for repeated indices + atomicAdd(&out[outIdx], inp[inpIdx]); } - if (validIndexTypes.find(idxArr[i].type()) == validIndexTypes.end()) { - throw std::invalid_argument( - "Index type must be one of s32/s64/u32/u64, observed type is " + - std::to_string(idxArr[i].type())); - } - idxTypes.push_back(idxArr[i].type()); - idxPtr[i] = (dim_t)(idxArr[i].device()); - } - for (int i = 0; i + 1 < idxTypes.size(); i++) { - if (idxTypes[i] != idxTypes[i + 1]) { - throw std::invalid_argument( - "Index type must be the same across all dimensions"); - } - } - - af::array inpCast = inp; - 
af::array outCast = out; - if (inpType == af::dtype::f16) { - inpCast = inp.as(af::dtype::f32); - } - if (outType == af::dtype::f16) { - outCast = out.as(af::dtype::f32); - } - - void* inpRawPtr = inpCast.device(); - void* outRawPtr = outCast.device(); - af::array arrIdxPtr(4, idxPtr.get()); - af::array arrIdxEnd(4, idxEnd.get()); - af::array arrIdxStart(4, idxStart.get()); - af::array arrOutDims(4, outDims.get()); - void* arrIdxStartDev = arrIdxStart.device(); - void* arrIdxEndDev = arrIdxEnd.device(); - void* arrOutDimsDev = arrOutDims.device(); - void* arrIdxPtrDev = arrIdxPtr.device(); - - cudaStream_t stream = afcu::getStream(af::getDevice()); - if (idxTypes.size() == 0 || idxTypes[0] == af::dtype::s32) { - advancedIndexKernel<<>>( - static_cast(inpRawPtr), - static_cast(arrIdxStartDev), - static_cast(arrIdxEndDev), - static_cast(arrOutDimsDev), - static_cast(arrIdxPtrDev), - static_cast(outRawPtr)); - } else if (idxTypes[0] == af::dtype::s64) { - advancedIndexKernel<<>>( - static_cast(inpRawPtr), - static_cast(arrIdxStartDev), - static_cast(arrIdxEndDev), - static_cast(arrOutDimsDev), - static_cast(arrIdxPtrDev), - static_cast(outRawPtr)); - } else if (idxTypes[0] == af::dtype::u32) { - advancedIndexKernel<<>>( - static_cast(inpRawPtr), - static_cast(arrIdxStartDev), - static_cast(arrIdxEndDev), - static_cast(arrOutDimsDev), - static_cast(arrIdxPtrDev), - static_cast(outRawPtr)); - } else if (idxTypes[0] == af::dtype::u64) { - advancedIndexKernel<<>>( - static_cast(inpRawPtr), - static_cast(arrIdxStartDev), - static_cast(arrIdxEndDev), - static_cast(arrOutDimsDev), - static_cast(arrIdxPtrDev), - static_cast(outRawPtr)); - } else { - throw std::invalid_argument("Index type must be one of s32/s64/u32/u64"); - } - if (cudaPeekAtLastError() != cudaSuccess) { - throw std::runtime_error( - "ArrayFireTensor advancedIndex kernel CUDA failure"); - } - - inpCast.unlock(); - outCast.unlock(); - arrIdxStart.unlock(); - arrIdxEnd.unlock(); - arrOutDims.unlock(); - 
arrIdxPtr.unlock(); - for (const auto& arr : idxArr) { - arr.unlock(); - } - - out = outCast; - if (outType == af::dtype::f16) { - out = outCast.as(af::dtype::f16); - } } -} // namespace detail +namespace fl { + namespace detail { + + void advancedIndex( + const af::array& inp, + const af::dim4& idxStart, + const af::dim4& idxEnd, + const af::dim4& outDims, + const std::vector < af::array > &idxArr, + af::array& out + ) { + auto inpType = inp.type(); + auto outType = out.type(); + + if((inpType != af::dtype::f32) && (inpType != af::dtype::f16)) { + throw std::invalid_argument("Input type must be f16/f32"); + } + if((outType != af::dtype::f32) && (outType != af::dtype::f16)) { + throw std::invalid_argument("Output type must be f16/f32"); + } + if(idxArr.size() != 4) { + throw std::invalid_argument("Index array vector must be length 4"); + } + + af::dim4 idxPtr; + // Extract raw device pointers for dimensions + // that have an array as af::index variable + + // Dtype checking + std::vector < af::dtype > idxTypes; + for(int i = 0; i < 4; i++) { + if(idxArr[i].isempty()) { + idxPtr[i] = 0; + continue; + } + if(validIndexTypes.find(idxArr[i].type()) == validIndexTypes.end()) { + throw std::invalid_argument( + "Index type must be one of s32/s64/u32/u64, observed type is " + + std::to_string(idxArr[i].type()) + ); + } + idxTypes.push_back(idxArr[i].type()); + idxPtr[i] = (dim_t) (idxArr[i].device < void > ()); + } + for(int i = 0; i + 1 < idxTypes.size(); i++) { + if(idxTypes[i] != idxTypes[i + 1]) { + throw std::invalid_argument( + "Index type must be the same across all dimensions" + ); + } + } + + af::array inpCast = inp; + af::array outCast = out; + if(inpType == af::dtype::f16) { + inpCast = inp.as(af::dtype::f32); + } + if(outType == af::dtype::f16) { + outCast = out.as(af::dtype::f32); + } + + void* inpRawPtr = inpCast.device < void > (); + void* outRawPtr = outCast.device < void > (); + af::array arrIdxPtr(4, idxPtr.get()); + af::array arrIdxEnd(4, idxEnd.get()); 
+ af::array arrIdxStart(4, idxStart.get()); + af::array arrOutDims(4, outDims.get()); + void* arrIdxStartDev = arrIdxStart.device < void > (); + void* arrIdxEndDev = arrIdxEnd.device < void > (); + void* arrOutDimsDev = arrOutDims.device < void > (); + void* arrIdxPtrDev = arrIdxPtr.device < void > (); + + cudaStream_t stream = afcu::getStream(af::getDevice()); + if(idxTypes.size() == 0 || idxTypes[0] == af::dtype::s32) { + advancedIndexKernel < float, int32_t > << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( + static_cast < const float* > (inpRawPtr), + static_cast < const dim_t * > (arrIdxStartDev), + static_cast < const dim_t * > (arrIdxEndDev), + static_cast < const dim_t * > (arrOutDimsDev), + static_cast < const dim_t * > (arrIdxPtrDev), + static_cast < float* > (outRawPtr)); + } else if(idxTypes[0] == af::dtype::s64) { + advancedIndexKernel < float, int64_t > << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( + static_cast < const float* > (inpRawPtr), + static_cast < const dim_t * > (arrIdxStartDev), + static_cast < const dim_t * > (arrIdxEndDev), + static_cast < const dim_t * > (arrOutDimsDev), + static_cast < const dim_t * > (arrIdxPtrDev), + static_cast < float* > (outRawPtr)); + } else if(idxTypes[0] == af::dtype::u32) { + advancedIndexKernel < float, uint32_t > << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( + static_cast < const float* > (inpRawPtr), + static_cast < const dim_t * > (arrIdxStartDev), + static_cast < const dim_t * > (arrIdxEndDev), + static_cast < const dim_t * > (arrOutDimsDev), + static_cast < const dim_t * > (arrIdxPtrDev), + static_cast < float* > (outRawPtr)); + } else if(idxTypes[0] == af::dtype::u64) { + advancedIndexKernel < float, uint64_t > << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( + static_cast < const float* > (inpRawPtr), + static_cast < const dim_t * > (arrIdxStartDev), + static_cast < const dim_t * > (arrIdxEndDev), + static_cast < const dim_t * > (arrOutDimsDev), + static_cast < const dim_t * > (arrIdxPtrDev), + static_cast < 
float* > (outRawPtr)); + } else { + throw std::invalid_argument("Index type must be one of s32/s64/u32/u64"); + } + if(cudaPeekAtLastError() != cudaSuccess) { + throw std::runtime_error( + "ArrayFireTensor advancedIndex kernel CUDA failure" + ); + } + + inpCast.unlock(); + outCast.unlock(); + arrIdxStart.unlock(); + arrIdxEnd.unlock(); + arrOutDims.unlock(); + arrIdxPtr.unlock(); + for(const auto& arr : idxArr) { + arr.unlock(); + } + + out = outCast; + if(outType == af::dtype::f16) { + out = outCast.as(af::dtype::f16); + } + } + + } // namespace detail } // namespace fl diff --git a/flashlight/pkg/speech/third_party/warpctc/src/ctc_entrypoint.cu b/flashlight/pkg/speech/third_party/warpctc/src/ctc_entrypoint.cu index c817840..04678e1 100644 --- a/flashlight/pkg/speech/third_party/warpctc/src/ctc_entrypoint.cu +++ b/flashlight/pkg/speech/third_party/warpctc/src/ctc_entrypoint.cu @@ -6,7 +6,7 @@ #include "detail/cpu_ctc.h" #ifdef __CUDACC__ - #include "detail/gpu_ctc.h" +#include "detail/gpu_ctc.h" #endif @@ -17,69 +17,105 @@ int get_warpctc_version() { } const char* ctcGetStatusString(ctcStatus_t status) { - switch (status) { - case CTC_STATUS_SUCCESS: - return "no error"; - case CTC_STATUS_MEMOPS_FAILED: - return "cuda memcpy or memset failed"; - case CTC_STATUS_INVALID_VALUE: - return "invalid value"; - case CTC_STATUS_EXECUTION_FAILED: - return "execution failed"; - case CTC_STATUS_LABEL_LENGTH_TOO_LARGE: - return "label length >639 is not supported"; - case CTC_STATUS_UNKNOWN_ERROR: - default: - return "unknown error"; + switch(status) { + case CTC_STATUS_SUCCESS: + return "no error"; + case CTC_STATUS_MEMOPS_FAILED: + return "cuda memcpy or memset failed"; + case CTC_STATUS_INVALID_VALUE: + return "invalid value"; + case CTC_STATUS_EXECUTION_FAILED: + return "execution failed"; + case CTC_STATUS_LABEL_LENGTH_TOO_LARGE: + return "label length >639 is not supported"; + case CTC_STATUS_UNKNOWN_ERROR: + default: + return "unknown error"; } } -ctcStatus_t 
compute_ctc_loss(const float* const activations, - float* gradients, - const int* const flat_labels, - const int* const label_lengths, - const int* const input_lengths, - int alphabet_size, - int minibatch, - float *costs, - void *workspace, - ctcOptions options) { - if (activations == nullptr || - label_lengths == nullptr || - input_lengths == nullptr || - costs == nullptr || - workspace == nullptr || - alphabet_size <= 0 || - minibatch <= 0) +ctcStatus_t compute_ctc_loss( + const float* const activations, + float* gradients, + const int* const flat_labels, + const int* const label_lengths, + const int* const input_lengths, + int alphabet_size, + int minibatch, + float* costs, + void* workspace, + ctcOptions options +) { + if( + activations == nullptr + || label_lengths == nullptr + || input_lengths == nullptr + || costs == nullptr + || workspace == nullptr + || alphabet_size <= 0 + || minibatch <= 0 + ) { return CTC_STATUS_INVALID_VALUE; + } - if (options.loc == CTC_CPU) { - CpuCTC ctc(alphabet_size, minibatch, workspace, options.num_threads, - options.blank_label); - - if (gradients != NULL) - return ctc.cost_and_grad(activations, gradients, - costs, - flat_labels, label_lengths, - input_lengths); - else - return ctc.score_forward(activations, costs, flat_labels, - label_lengths, input_lengths); - } else if (options.loc == CTC_GPU) { + if(options.loc == CTC_CPU) { + CpuCTC < float > ctc( + alphabet_size, + minibatch, + workspace, + options.num_threads, + options.blank_label + ); + + if(gradients != NULL) { + return ctc.cost_and_grad( + activations, + gradients, + costs, + flat_labels, + label_lengths, + input_lengths + ); + } else { + return ctc.score_forward( + activations, + costs, + flat_labels, + label_lengths, + input_lengths + ); + } + } else if(options.loc == CTC_GPU) { #ifdef __CUDACC__ - GpuCTC ctc(alphabet_size, minibatch, workspace, options.stream, - options.blank_label); - - if (gradients != NULL) - return ctc.cost_and_grad(activations, gradients, 
costs, - flat_labels, label_lengths, - input_lengths); - else - return ctc.score_forward(activations, costs, flat_labels, - label_lengths, input_lengths); + GpuCTC < float > ctc( + alphabet_size, + minibatch, + workspace, + options.stream, + options.blank_label + ); + + if(gradients != NULL) { + return ctc.cost_and_grad( + activations, + gradients, + costs, + flat_labels, + label_lengths, + input_lengths + ); + } else { + return ctc.score_forward( + activations, + costs, + flat_labels, + label_lengths, + input_lengths + ); + } #else std::cerr << "GPU execution requested, but not compiled with GPU support" << std::endl; return CTC_STATUS_EXECUTION_FAILED; @@ -90,19 +126,24 @@ ctcStatus_t compute_ctc_loss(const float* const activations, } -ctcStatus_t get_workspace_size(const int* const label_lengths, - const int* const input_lengths, - int alphabet_size, int minibatch, - ctcOptions options, - size_t* size_bytes) -{ - - if (label_lengths == nullptr || - input_lengths == nullptr || - size_bytes == nullptr || - alphabet_size <= 0 || - minibatch <= 0) +ctcStatus_t get_workspace_size( + const int* const label_lengths, + const int* const input_lengths, + int alphabet_size, + int minibatch, + ctcOptions options, + size_t* size_bytes +) { + + if( + label_lengths == nullptr + || input_lengths == nullptr + || size_bytes == nullptr + || alphabet_size <= 0 + || minibatch <= 0 + ) { return CTC_STATUS_INVALID_VALUE; + } // This is the max of all S and T for all examples in the minibatch. 
int maxL = *std::max_element(label_lengths, label_lengths + minibatch); @@ -112,61 +153,61 @@ ctcStatus_t get_workspace_size(const int* const label_lengths, *size_bytes = 0; - if (options.loc == CTC_GPU) { + if(options.loc == CTC_GPU) { // GPU storage - //nll_forward, nll_backward + // nll_forward, nll_backward *size_bytes += 2 * sizeof(float) * minibatch; - //repeats + // repeats *size_bytes += sizeof(int) * minibatch; - //label offsets + // label offsets *size_bytes += sizeof(int) * minibatch; - //utt_length + // utt_length *size_bytes += sizeof(int) * minibatch; - //label lengths + // label lengths *size_bytes += sizeof(int) * minibatch; - //labels without blanks - overallocate for now + // labels without blanks - overallocate for now *size_bytes += sizeof(int) * maxL * minibatch; - //labels with blanks + // labels with blanks *size_bytes += sizeof(int) * S * minibatch; - //alphas + // alphas *size_bytes += sizeof(float) * S * maxT * minibatch; - //denoms + // denoms *size_bytes += sizeof(float) * maxT * minibatch; - //probs (since we will pass in activations) + // probs (since we will pass in activations) *size_bytes += sizeof(float) * alphabet_size * maxT * minibatch; } else { - //cpu can eventually replace all minibatch with - //max number of concurrent threads if memory is - //really tight + // cpu can eventually replace all minibatch with + // max number of concurrent threads if memory is + // really tight - //per minibatch memory + // per minibatch memory size_t per_minibatch_bytes = 0; - //output - per_minibatch_bytes += sizeof(float) * alphabet_size ; + // output + per_minibatch_bytes += sizeof(float) * alphabet_size; - //alphas + // alphas per_minibatch_bytes += sizeof(float) * S * maxT; - //betas + // betas per_minibatch_bytes += sizeof(float) * S; - //labels w/blanks, e_inc, s_inc + // labels w/blanks, e_inc, s_inc per_minibatch_bytes += 3 * sizeof(int) * S; *size_bytes = per_minibatch_bytes * minibatch; - //probs + // probs *size_bytes += 
sizeof(float) * alphabet_size * maxT * minibatch; } diff --git a/flashlight/pkg/speech/third_party/warpctc/src/reduce.cu b/flashlight/pkg/speech/third_party/warpctc/src/reduce.cu index a11d2ff..051afe7 100644 --- a/flashlight/pkg/speech/third_party/warpctc/src/reduce.cu +++ b/flashlight/pkg/speech/third_party/warpctc/src/reduce.cu @@ -15,13 +15,15 @@ const int warp_size = 32; -template -struct CTAReduce; +template < int NT, typename T, typename Rop +> struct CTAReduce; -template -struct CTAReduce { - enum { Size = NT, Capacity = NT }; - struct Storage { T shared[Capacity]; }; +template < int NT, typename T, typename Rop +> struct CTAReduce { + enum {Size = NT, Capacity = NT}; + struct Storage { + T shared[Capacity]; + }; __device__ static T reduce(int tid, T x, Storage& storage, int count, Rop g) { T* s = storage.shared; @@ -40,20 +42,27 @@ struct CTAReduce { } T shuff; - for (int offset = warp_size / 2; offset > 0; offset /= 2) { + for(int offset = warp_size / 2; offset > 0; offset /= 2) { shuff = __shfl_down_sync(0xffffffff, x, offset); - if (tid + offset < count && tid < offset) + if(tid + offset < count && tid < offset) { x = g(x, shuff); + } } return x; } }; -template -__global__ void reduce_rows(Iop f, Rop g, const T* input, T* output, - int num_rows, int num_cols) { - - typedef CTAReduce R; +template < int NT, typename Iop, typename Rop, typename T +> __global__ void reduce_rows( + Iop f, + Rop g, + const T* input, + T* output, + int num_rows, + int num_cols +) { + + typedef CTAReduce < NT, T, Rop > R; __shared__ typename R::Storage storage; int tid = threadIdx.x; @@ -62,13 +71,14 @@ __global__ void reduce_rows(Iop f, Rop g, const T* input, T* output, T curr; // Each block works on a column - if (idx < num_rows) - curr = f(input[idx + col*num_rows]); + if(idx < num_rows) { + curr = f(input[idx + col * num_rows]); + } idx += NT; - while (idx < num_rows) { - curr = g(curr, f(input[idx + col*num_rows])); + while(idx < num_rows) { + curr = g(curr, f(input[idx + 
col * num_rows])); idx += NT; } @@ -76,13 +86,20 @@ __global__ void reduce_rows(Iop f, Rop g, const T* input, T* output, curr = R::reduce(tid, curr, storage, num_rows, g); // Store result in out - if (tid == 0) + if(tid == 0) { output[col] = curr; + } } -template -__global__ void reduce_cols(Iop f, Rop g, const T* input, T* output, - int num_rows, int num_cols) { +template < int NT, typename Iop, typename Rop, typename T +> __global__ void reduce_cols( + Iop f, + Rop g, + const T* input, + T* output, + int num_rows, + int num_cols +) { __shared__ T s[NT]; @@ -91,11 +108,11 @@ __global__ void reduce_cols(Iop f, Rop g, const T* input, T* output, int col = threadIdx.y; T curr; - if (row < num_rows && col < num_cols) { - curr = f(input[row + col*num_rows]); + if(row < num_rows && col < num_cols) { + curr = f(input[row + col * num_rows]); col += blockDim.y; - while (col < num_cols) { - curr = g(curr, f(input[row + col*num_rows])); + while(col < num_cols) { + curr = g(curr, f(input[row + col * num_rows])); col += blockDim.y; } } @@ -103,56 +120,124 @@ __global__ void reduce_cols(Iop f, Rop g, const T* input, T* output, __syncthreads(); // Reduce - if (threadIdx.y == 0 && row < num_rows) { + if(threadIdx.y == 0 && row < num_rows) { #pragma unroll - for (int i = 1; i < warps_per_block && i < num_cols; ++i) + for(int i = 1; i < warps_per_block && i < num_cols; ++i) { curr = g(curr, s[i + threadIdx.x * warps_per_block]); + } output[row] = curr; } } struct ReduceHelper { - template - static void impl(Iof f, Rof g, const T* input, T* output, int num_rows, int num_cols, bool axis, cudaStream_t stream) { + template < typename T, typename Iof, typename Rof + > static void impl( + Iof f, + Rof g, + const T* input, + T* output, + int num_rows, + int num_cols, + bool axis, + cudaStream_t stream + ) { int grid_size; - if (axis) { + if(axis) { grid_size = num_cols; - reduce_rows<128><<>> - (f, g, input, output, num_rows, num_cols); + reduce_rows < 128 > << < grid_size, 128, 0, stream 
>> + > (f, g, input, output, num_rows, num_cols); } else { dim3 tpb(warp_size, 128 / warp_size); - grid_size = (num_cols + warp_size - 1)/warp_size; - reduce_cols<128><<>> - (f, g, input, output, num_rows, num_cols); + grid_size = (num_cols + warp_size - 1) / warp_size; + reduce_cols < 128 > << < grid_size, tpb, 0, stream >> + > (f, g, input, output, num_rows, num_cols); } } }; -template -ctcStatus_t reduce(Iof f, Rof g, const T* input, T* output, int rows, int cols, bool axis, cudaStream_t stream) { +template < typename T, typename Iof, typename Rof +> ctcStatus_t reduce( + Iof f, + Rof g, + const T* input, + T* output, + int rows, + int cols, + bool axis, + cudaStream_t stream +) { ReduceHelper::impl(f, g, input, output, rows, cols, axis, stream); cudaStreamSynchronize(stream); cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) + if(err != cudaSuccess) { return CTC_STATUS_EXECUTION_FAILED; + } return CTC_STATUS_SUCCESS; } -ctcStatus_t reduce_negate(const float *input, float *output, int rows, int cols, bool axis, cudaStream_t stream) { - return reduce(ctc_helper::negate(), ctc_helper::add(), input, output, rows, cols, axis, stream); +ctcStatus_t reduce_negate( + const float* input, + float* output, + int rows, + int cols, + bool axis, + cudaStream_t stream +) { + return reduce( + ctc_helper::negate < float > (), + ctc_helper::add < float > (), + input, + output, + rows, + cols, + axis, + stream + ); } -ctcStatus_t reduce_exp(const float *input, float *output, int rows, int cols, bool axis, cudaStream_t stream) { - return reduce(ctc_helper::exponential(), ctc_helper::add(), input, output, rows, cols, axis, stream); +ctcStatus_t reduce_exp( + const float* input, + float* output, + int rows, + int cols, + bool axis, + cudaStream_t stream +) { + return reduce( + ctc_helper::exponential < float > (), + ctc_helper::add < float > (), + input, + output, + rows, + cols, + axis, + stream + ); } -ctcStatus_t reduce_max(const float *input, float *output, int 
rows, int cols, bool axis, cudaStream_t stream) { - return reduce(ctc_helper::identity(), ctc_helper::maximum(),input, output, rows, cols, axis, stream); +ctcStatus_t reduce_max( + const float* input, + float* output, + int rows, + int cols, + bool axis, + cudaStream_t stream +) { + return reduce( + ctc_helper::identity < float > (), + ctc_helper::maximum < float > (), + input, + output, + rows, + cols, + axis, + stream + ); } From 470c33a28923ad0c7a386031a596e6b3ef728a0e Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Mon, 23 Feb 2026 20:14:05 +0100 Subject: [PATCH 17/24] One liner ifs and fors now are broken but omit braces --- flashlight/fl/autograd/Functions.cpp | 233 ++++++------------ flashlight/fl/autograd/Functions.h | 12 +- flashlight/fl/autograd/Variable.cpp | 32 +-- .../tensor/backend/cudnn/BatchNorm.cpp | 28 +-- .../autograd/tensor/backend/cudnn/Conv2D.cpp | 79 +++--- .../tensor/backend/cudnn/CudnnUtils.cpp | 18 +- .../fl/autograd/tensor/backend/cudnn/RNN.cpp | 31 +-- .../tensor/backend/onednn/BatchNorm.cpp | 34 +-- .../autograd/tensor/backend/onednn/Conv2D.cpp | 33 +-- .../tensor/backend/onednn/DnnlUtils.cpp | 12 +- .../tensor/backend/onednn/DnnlUtils.h | 9 +- .../autograd/tensor/backend/onednn/Pool2D.cpp | 12 +- .../fl/autograd/tensor/backend/onednn/RNN.cpp | 25 +- flashlight/fl/common/Defines.cpp | 3 +- flashlight/fl/common/DevicePtr.cpp | 13 +- flashlight/fl/common/DynamicBenchmark.cpp | 4 +- flashlight/fl/common/DynamicBenchmark.h | 18 +- flashlight/fl/common/Histogram.cpp | 22 +- flashlight/fl/common/Histogram.h | 34 +-- flashlight/fl/common/Logging.cpp | 24 +- flashlight/fl/common/Logging.h | 6 +- flashlight/fl/common/Serialization-inl.h | 16 +- flashlight/fl/common/Utils.cpp | 30 +-- flashlight/fl/common/Utils.h | 13 +- flashlight/fl/common/WinUtility.cpp | 12 +- flashlight/fl/common/threadpool/ThreadPool.h | 15 +- .../fl/contrib/modules/AdaptiveEmbedding.cpp | 12 +- .../fl/contrib/modules/AsymmetricConv1D.cpp | 19 +- 
flashlight/fl/contrib/modules/Conformer.cpp | 18 +- .../fl/contrib/modules/PositionEmbedding.cpp | 8 +- .../fl/contrib/modules/RawWavSpecAugment.cpp | 39 +-- flashlight/fl/contrib/modules/Residual.cpp | 51 ++-- .../modules/SinusoidalPositionEmbedding.cpp | 3 +- flashlight/fl/contrib/modules/SpecAugment.cpp | 21 +- flashlight/fl/contrib/modules/TDSBlock.cpp | 22 +- flashlight/fl/contrib/modules/Transformer.cpp | 26 +- flashlight/fl/dataset/BatchDataset.cpp | 15 +- flashlight/fl/dataset/BlobDataset.cpp | 35 +-- flashlight/fl/dataset/ConcatDataset.cpp | 3 +- flashlight/fl/dataset/Dataset.h | 3 +- flashlight/fl/dataset/DatasetIterator.h | 6 +- flashlight/fl/dataset/FileBlobDataset.cpp | 18 +- flashlight/fl/dataset/MemoryBlobDataset.cpp | 3 +- flashlight/fl/dataset/MergeDataset.cpp | 6 +- flashlight/fl/dataset/PrefetchDataset.cpp | 12 +- flashlight/fl/dataset/ResampleDataset.cpp | 6 +- flashlight/fl/dataset/ShuffleDataset.cpp | 3 +- flashlight/fl/dataset/SpanDataset.cpp | 3 +- flashlight/fl/dataset/TensorDataset.cpp | 6 +- flashlight/fl/dataset/TransformDataset.cpp | 6 +- flashlight/fl/dataset/Utils.cpp | 57 ++--- flashlight/fl/distributed/DistributedApi.cpp | 12 +- flashlight/fl/distributed/FileStore.cpp | 15 +- flashlight/fl/distributed/LRUCache.h | 7 +- .../backend/cpu/DistributedBackend.cpp | 30 +-- .../backend/cuda/DistributedBackend.cpp | 87 +++---- .../backend/stub/DistributedBackend.cpp | 6 +- .../reducers/CoalescingReducer.cpp | 13 +- .../fl/distributed/reducers/InlineReducer.cpp | 3 +- .../fl/examples/AdaptiveClassification.cpp | 3 +- flashlight/fl/examples/Benchmark.cpp | 6 +- flashlight/fl/examples/Classification.cpp | 3 +- .../fl/examples/DistributedTraining.cpp | 9 +- flashlight/fl/examples/Mnist.cpp | 9 +- flashlight/fl/examples/RnnClassification.cpp | 24 +- flashlight/fl/examples/RnnLm.cpp | 18 +- flashlight/fl/examples/Xor.cpp | 12 +- flashlight/fl/meter/AverageValueMeter.cpp | 6 +- flashlight/fl/meter/CountMeter.cpp | 3 +- 
flashlight/fl/meter/EditDistanceMeter.cpp | 6 +- flashlight/fl/meter/EditDistanceMeter.h | 10 +- flashlight/fl/meter/FrameErrorMeter.cpp | 6 +- flashlight/fl/meter/MSEMeter.cpp | 3 +- flashlight/fl/meter/TimeMeter.cpp | 9 +- flashlight/fl/meter/TopKMeter.cpp | 6 +- flashlight/fl/nn/DistributedUtils.cpp | 15 +- flashlight/fl/nn/Init.cpp | 9 +- flashlight/fl/nn/Utils.cpp | 44 ++-- flashlight/fl/nn/modules/AdaptiveSoftMax.cpp | 12 +- flashlight/fl/nn/modules/BatchNorm.cpp | 8 +- flashlight/fl/nn/modules/Container.cpp | 39 +-- flashlight/fl/nn/modules/Container.h | 3 +- flashlight/fl/nn/modules/Conv2D.cpp | 32 +-- flashlight/fl/nn/modules/Dropout.cpp | 5 +- flashlight/fl/nn/modules/LayerNorm.cpp | 40 ++- flashlight/fl/nn/modules/Linear.cpp | 14 +- flashlight/fl/nn/modules/Loss.cpp | 29 +-- flashlight/fl/nn/modules/Module.cpp | 24 +- flashlight/fl/nn/modules/Normalize.cpp | 3 +- flashlight/fl/nn/modules/Padding.cpp | 3 +- flashlight/fl/nn/modules/Pool2D.cpp | 13 +- flashlight/fl/nn/modules/RNN.cpp | 18 +- flashlight/fl/nn/modules/Reorder.cpp | 3 +- flashlight/fl/nn/modules/WeightNorm.cpp | 25 +- flashlight/fl/optim/AMSgradOptimizer.cpp | 9 +- flashlight/fl/optim/AdadeltaOptimizer.cpp | 12 +- flashlight/fl/optim/AdagradOptimizer.cpp | 9 +- flashlight/fl/optim/AdamOptimizer.cpp | 9 +- flashlight/fl/optim/NAGOptimizer.cpp | 12 +- flashlight/fl/optim/NovogradOptimizer.cpp | 6 +- flashlight/fl/optim/Optimizers.cpp | 3 +- flashlight/fl/optim/RMSPropOptimizer.cpp | 15 +- flashlight/fl/optim/SGDOptimizer.cpp | 19 +- flashlight/fl/optim/Utils.cpp | 9 +- flashlight/fl/runtime/CUDAStream.cpp | 12 +- flashlight/fl/runtime/CUDAUtils.cpp | 3 +- flashlight/fl/runtime/Device.cpp | 9 +- flashlight/fl/runtime/DeviceManager.cpp | 12 +- flashlight/fl/runtime/Stream.cpp | 3 +- flashlight/fl/runtime/Stream.h | 3 +- flashlight/fl/tensor/Compute.cpp | 21 +- flashlight/fl/tensor/Index.cpp | 3 +- flashlight/fl/tensor/Shape.cpp | 6 +- flashlight/fl/tensor/TensorBackend.cpp | 3 +- 
flashlight/fl/tensor/TensorBackend.h | 3 +- flashlight/fl/tensor/TensorBase.cpp | 24 +- flashlight/fl/tensor/TensorBase.h | 11 +- flashlight/fl/tensor/TensorExtension.cpp | 6 +- flashlight/fl/tensor/Types.cpp | 3 +- .../fl/tensor/backend/af/AdvancedIndex.cu | 50 ++-- .../fl/tensor/backend/af/ArrayFireBLAS.cpp | 12 +- .../fl/tensor/backend/af/ArrayFireBackend.cpp | 36 +-- .../tensor/backend/af/ArrayFireBinaryOps.cpp | 13 +- .../tensor/backend/af/ArrayFireReductions.cpp | 70 ++---- .../backend/af/ArrayFireShapeAndIndex.cpp | 26 +- .../fl/tensor/backend/af/ArrayFireTensor.cpp | 68 ++--- .../tensor/backend/af/ArrayFireUnaryOps.cpp | 3 +- flashlight/fl/tensor/backend/af/Utils.cpp | 35 +-- .../backend/af/mem/CachingMemoryManager.cpp | 83 +++---- .../backend/af/mem/DefaultMemoryManager.cpp | 72 ++---- .../backend/af/mem/MemoryManagerAdapter.cpp | 9 +- .../backend/af/mem/MemoryManagerAdapter.h | 3 +- .../backend/af/mem/MemoryManagerInstaller.cpp | 6 +- .../test/autograd/AutogradBinaryOpsTest.cpp | 3 +- .../fl/test/autograd/AutogradConv2DTest.cpp | 3 +- .../autograd/AutogradNormalizationTest.cpp | 27 +- .../test/autograd/AutogradReductionTest.cpp | 6 +- .../fl/test/autograd/AutogradRnnTest.cpp | 18 +- flashlight/fl/test/autograd/AutogradTest.cpp | 30 +-- .../fl/test/autograd/AutogradTestUtils.h | 3 +- .../fl/test/autograd/AutogradUnaryOpsTest.cpp | 6 +- .../fl/test/common/DynamicBenchmarkTest.cpp | 7 +- flashlight/fl/test/common/HistogramTest.cpp | 15 +- flashlight/fl/test/common/LoggingTest.cpp | 25 +- flashlight/fl/test/common/UtilsTest.cpp | 22 +- .../contrib/modules/ContribModuleTest.cpp | 42 ++-- flashlight/fl/test/dataset/DatasetTest.cpp | 89 +++---- .../test/distributed/AllReduceBenchmark.cpp | 9 +- .../fl/test/distributed/AllReduceTest.cpp | 36 +-- flashlight/fl/test/nn/ModuleTest.cpp | 54 ++-- flashlight/fl/test/nn/NNSerializationTest.cpp | 9 +- flashlight/fl/test/optim/OptimBenchmark.cpp | 6 +- flashlight/fl/test/optim/OptimTest.cpp | 15 +- 
.../fl/test/runtime/DeviceManagerTest.cpp | 18 +- flashlight/fl/test/runtime/DeviceTest.cpp | 45 ++-- flashlight/fl/test/tensor/IndexTest.cpp | 27 +- flashlight/fl/test/tensor/ShapeTest.cpp | 3 +- flashlight/fl/test/tensor/TensorBLASTest.cpp | 9 +- flashlight/fl/test/tensor/TensorBaseTest.cpp | 26 +- .../fl/test/tensor/TensorBinaryOpsTest.cpp | 16 +- .../fl/test/tensor/TensorExtensionTest.cpp | 3 +- .../fl/test/tensor/TensorReductionTest.cpp | 21 +- .../fl/test/tensor/TensorUnaryOpsTest.cpp | 30 +-- .../tensor/af/ArrayFireTensorBaseTest.cpp | 9 +- .../tensor/af/CachingMemoryManagerTest.cpp | 20 +- .../fl/test/tensor/af/MemoryFrameworkTest.cpp | 31 +-- .../fl/test/tensor/af/MemoryInitTest.cpp | 3 +- flashlight/pkg/runtime/Runtime.cpp | 9 +- flashlight/pkg/runtime/amp/DynamicScaler.cpp | 9 +- .../pkg/runtime/common/DistributedUtils.cpp | 11 +- .../pkg/runtime/common/DistributedUtils.h | 3 +- .../pkg/runtime/common/SequentialBuilder.cpp | 172 +++++-------- flashlight/pkg/runtime/common/Serializer.h | 6 +- .../test/common/SequentialBuilderTest.cpp | 9 +- .../pkg/speech/audio/feature/Ceplifter.cpp | 9 +- flashlight/pkg/speech/audio/feature/Dct.cpp | 6 +- .../pkg/speech/audio/feature/Derivatives.cpp | 15 +- .../pkg/speech/audio/feature/Dither.cpp | 3 +- .../pkg/speech/audio/feature/FeatureParams.h | 3 +- flashlight/pkg/speech/audio/feature/Mfcc.cpp | 15 +- flashlight/pkg/speech/audio/feature/Mfsc.cpp | 17 +- .../speech/audio/feature/PowerSpectrum.cpp | 34 +-- .../pkg/speech/audio/feature/PreEmphasis.cpp | 12 +- .../pkg/speech/audio/feature/SpeechUtils.cpp | 9 +- .../speech/audio/feature/TriFilterbank.cpp | 6 +- .../pkg/speech/audio/feature/Windowing.cpp | 15 +- .../pkg/speech/augmentation/AdditiveNoise.cpp | 21 +- .../pkg/speech/augmentation/GaussianNoise.cpp | 6 +- .../pkg/speech/augmentation/Reverberation.cpp | 12 +- .../pkg/speech/augmentation/SoundEffect.cpp | 12 +- .../speech/augmentation/SoundEffectApply.cpp | 9 +- .../speech/augmentation/SoundEffectConfig.cpp | 26 
+- .../speech/augmentation/SoundEffectUtil.cpp | 9 +- .../pkg/speech/augmentation/TimeStretch.cpp | 3 +- .../pkg/speech/common/ProducerConsumerQueue.h | 12 +- .../criterion/AutoSegmentationCriterion.h | 6 +- ...tionistTemporalClassificationCriterion.cpp | 18 +- .../pkg/speech/criterion/CriterionUtils.cpp | 26 +- .../pkg/speech/criterion/CriterionUtils.h | 18 +- .../criterion/ForceAlignmentCriterion.cpp | 3 +- .../criterion/FullConnectionCriterion.cpp | 3 +- .../criterion/LinearSegmentationCriterion.h | 3 +- .../pkg/speech/criterion/Seq2SeqCriterion.cpp | 144 +++++------ .../speech/criterion/TransformerCriterion.cpp | 64 ++--- .../criterion/attention/AttentionBase.h | 3 +- .../criterion/attention/ContentAttention.cpp | 15 +- .../criterion/attention/LocationAttention.cpp | 27 +- .../attention/MultiHeadAttention.cpp | 15 +- .../attention/SoftPretrainWindow.cpp | 3 +- .../speech/criterion/attention/WindowBase.cpp | 17 +- ...tionistTemporalClassificationCriterion.cpp | 44 ++-- .../criterion/backend/cpu/CriterionUtils.cpp | 7 +- .../backend/cpu/ForceAlignmentCriterion.cpp | 17 +- .../backend/cpu/FullConnectionCriterion.cpp | 10 +- ...tionistTemporalClassificationCriterion.cpp | 12 +- .../criterion/backend/cuda/CriterionUtils.cpp | 13 +- .../backend/cuda/ForceAlignmentCriterion.cpp | 23 +- .../backend/cuda/FullConnectionCriterion.cpp | 19 +- .../pkg/speech/data/FeatureTransforms.cpp | 39 ++- .../pkg/speech/data/FeatureTransforms.h | 21 +- .../pkg/speech/data/ListFileDataset.cpp | 20 +- flashlight/pkg/speech/data/Sound.cpp | 52 ++-- flashlight/pkg/speech/data/Utils.cpp | 33 +-- .../pkg/speech/decoder/ConvLmModule.cpp | 12 +- .../pkg/speech/decoder/DecodeMaster.cpp | 35 +-- flashlight/pkg/speech/decoder/DecodeUtils.cpp | 10 +- flashlight/pkg/speech/decoder/PlGenerator.cpp | 50 ++-- .../pkg/speech/decoder/TranscriptionUtils.cpp | 39 +-- .../pkg/speech/decoder/TranscriptionUtils.h | 18 +- flashlight/pkg/speech/runtime/Attention.cpp | 42 ++-- 
flashlight/pkg/speech/runtime/Helpers.cpp | 29 +-- flashlight/pkg/speech/runtime/Logger.cpp | 11 +- flashlight/pkg/speech/runtime/Optimizer.cpp | 20 +- flashlight/pkg/speech/test/audio/MfccTest.cpp | 33 +-- flashlight/pkg/speech/test/audio/TestUtils.h | 15 +- .../test/augmentation/AdditiveNoiseTest.cpp | 3 +- .../test/augmentation/GaussianNoiseTest.cpp | 6 +- .../test/augmentation/ReverberationTest.cpp | 3 +- .../test/common/ProducerConsumerQueueTest.cpp | 21 +- .../test/criterion/BenchmarkSeq2Seq.cpp | 3 +- .../pkg/speech/test/criterion/CompareASG.cpp | 31 +-- .../speech/test/criterion/CriterionTest.cpp | 24 +- .../pkg/speech/test/criterion/Seq2SeqTest.cpp | 24 +- .../criterion/attention/AttentionTest.cpp | 15 +- .../test/criterion/attention/WindowTest.cpp | 3 +- .../speech/test/data/FeaturizationTest.cpp | 45 ++-- .../speech/test/data/ListFileDatasetTest.cpp | 9 +- flashlight/pkg/speech/test/data/SoundTest.cpp | 21 +- .../speech/test/decoder/ConvLmModuleTest.cpp | 3 +- .../pkg/speech/test/runtime/RuntimeTest.cpp | 10 +- .../warpctc/include/detail/cpu_ctc.h | 66 ++--- .../warpctc/include/detail/ctc_helper.h | 6 +- .../warpctc/include/detail/gpu_ctc.h | 62 ++--- .../warpctc/include/detail/gpu_ctc_kernels.h | 74 ++---- .../third_party/warpctc/src/ctc_entrypoint.cu | 20 +- .../speech/third_party/warpctc/src/reduce.cu | 15 +- flashlight/pkg/text/data/TextDataset.cpp | 29 +-- .../pkg/text/test/data/TextDatasetTest.cpp | 9 +- .../pkg/vision/common/BetaDistribution.h | 5 +- flashlight/pkg/vision/criterion/Hungarian.cpp | 3 +- .../pkg/vision/criterion/HungarianImpl.cpp | 95 +++---- .../pkg/vision/criterion/SetCriterion.cpp | 32 +-- .../vision/dataset/BatchTransformDataset.h | 18 +- flashlight/pkg/vision/dataset/BoxUtils.cpp | 27 +- flashlight/pkg/vision/dataset/Coco.cpp | 10 +- .../pkg/vision/dataset/CocoTransforms.cpp | 14 +- .../pkg/vision/dataset/DistributedDataset.cpp | 3 +- flashlight/pkg/vision/dataset/Imagenet.cpp | 23 +- flashlight/pkg/vision/dataset/Jpeg.cpp | 3 
+- flashlight/pkg/vision/dataset/Transforms.cpp | 43 ++-- flashlight/pkg/vision/models/Detr.cpp | 6 +- flashlight/pkg/vision/models/Resnet.cpp | 22 +- .../vision/models/ResnetFrozenBatchNorm.cpp | 22 +- flashlight/pkg/vision/models/ViT.cpp | 9 +- flashlight/pkg/vision/nn/FrozenBatchNorm.cpp | 6 +- .../pkg/vision/nn/PositionalEmbeddingSine.cpp | 6 +- flashlight/pkg/vision/nn/Transformer.cpp | 21 +- .../pkg/vision/nn/VisionTransformer.cpp | 9 +- .../backend/af/ArrayFireVisionExtension.cpp | 27 +- .../vision/test/criterion/HungarianTest.cpp | 48 ++-- .../test/criterion/SetCriterionTest.cpp | 3 +- uncrustify.cfg | 5 +- 282 files changed, 1989 insertions(+), 3646 deletions(-) diff --git a/flashlight/fl/autograd/Functions.cpp b/flashlight/fl/autograd/Functions.cpp index 3c6c530..782b574 100644 --- a/flashlight/fl/autograd/Functions.cpp +++ b/flashlight/fl/autograd/Functions.cpp @@ -26,9 +26,8 @@ namespace detail { Tensor tileAs(const Tensor& input, const Shape& rdims) { // Scalar tensor - if(input.ndim() == 0) { + if(input.ndim() == 0) return tile(input, rdims); - } Shape dims(std::vector(rdims.ndim(), 1)); Shape idims = input.shape(); @@ -48,11 +47,9 @@ namespace detail { Tensor sumAs(const Tensor& input, const Shape& rdims) { Shape idims = input.shape(); auto result = input; - for(int i = 0; i < input.ndim(); i++) { - if(i + 1 > rdims.ndim() || idims[i] != rdims[i]) { + for(int i = 0; i < input.ndim(); i++) + if(i + 1 > rdims.ndim() || idims[i] != rdims[i]) result = fl::sum(result, {i}, /* keepDims = */ true); - } - } return fl::reshape(result.astype(input.type()), rdims); } @@ -63,23 +60,21 @@ namespace detail { bool keepDims /* = false */ ) { // Fast path - tensor already retained its shape - if(keepDims) { + if(keepDims) return input.shape(); - } // If we output a scalar, - if(input.ndim() == 0) { + if(input.ndim() == 0) return {}; - } unsigned preNDims = input.ndim() + axes.size(); Shape newShape(std::vector(preNDims, 1)); unsigned axesIdx = 0; unsigned inputIdx = 
0; for(unsigned i = 0; i < preNDims; ++i) { - if(i == axes[axesIdx]) { + if(i == axes[axesIdx]) // This dim was reduced over, leave as 1 in the new shape axesIdx++; - } else { + else { // Dim wasn't reduced over - add the shape from the new tensor newShape[i] = input.dim(inputIdx); inputIdx++; @@ -176,16 +171,14 @@ Variable operator*(const Variable& lhs, const Variable& rhs) { auto result = lhs.tensor() * rhs.tensor(); auto gradFunc = [](std::vector& inputs, const Variable& gradOutput) { - if(inputs[0].isCalcGrad()) { + if(inputs[0].isCalcGrad()) inputs[0].addGrad( Variable(gradOutput.tensor() * inputs[1].tensor(), false) ); - } - if(inputs[1].isCalcGrad()) { + if(inputs[1].isCalcGrad()) inputs[1].addGrad( Variable(gradOutput.tensor() * inputs[0].tensor(), false) ); - } }; return Variable( result, @@ -215,17 +208,15 @@ Variable operator/(const Variable& lhs, const Variable& rhs) { const Variable& gradOutput) { auto inputs1rec = reciprocal(inputs[1]); auto gradInput0 = gradOutput * inputs1rec; - if(inputs[0].isCalcGrad()) { + if(inputs[0].isCalcGrad()) inputs[0].addGrad(Variable(gradInput0.tensor(), false)); - } - if(inputs[1].isCalcGrad()) { + if(inputs[1].isCalcGrad()) inputs[1].addGrad( Variable( (gradInput0 * negate(inputs[0]) * inputs1rec).tensor(), false ) ); - } }; return Variable( result, @@ -560,9 +551,8 @@ Variable transpose(const Variable& input, const Shape& dims /* = {} */) { reverseShape = Shape(dVec); } - for(unsigned i = 0; i < reverseShape.ndim(); ++i) { + for(unsigned i = 0; i < reverseShape.ndim(); ++i) reverseShape[dims[i]] = i; - } inputs[0].addGrad( Variable(fl::transpose(gradOutput.tensor(), reverseShape), false) @@ -607,32 +597,26 @@ Variable sumAs(const Variable& input, const Variable& reference) { } Variable concatenate(const std::vector& concatInputs, int dim) { - if(concatInputs.empty()) { + if(concatInputs.empty()) throw std::invalid_argument("cannot concatenate zero variables"); - } - if(concatInputs.size() == 1) { + 
if(concatInputs.size() == 1) return concatInputs[0]; - } // All Variables must be of the same type fl::dtype type = concatInputs[0].type(); - for(auto& var : concatInputs) { - if(var.type() != type) { + for(auto& var : concatInputs) + if(var.type() != type) throw std::invalid_argument( "concatenate: all input Variables must be of the same type" ); - } - } // All Variables must have the same number of dims unsigned numDims = concatInputs[0].ndim(); - for(auto& var : concatInputs) { - if(numDims != var.ndim()) { + for(auto& var : concatInputs) + if(numDims != var.ndim()) throw std::invalid_argument( "concatenate: all input Variables must " "have the same number of dimensions" ); - } - } // All Variables must have the same size when indexed along the dim not being // concatenated along @@ -640,13 +624,11 @@ Variable concatenate(const std::vector& concatInputs, int dim) { int concatSize = dims[dim]; for(int i = 1; i < concatInputs.size(); i++) { concatSize += concatInputs[i].dim(dim); - for(int d = 0; d < numDims; d++) { - if(dim != d && concatInputs[i].dim(d) != dims[d]) { + for(int d = 0; d < numDims; d++) + if(dim != d && concatInputs[i].dim(d) != dims[d]) throw std::invalid_argument( "mismatch in dimension not being concatenated" ); - } - } } dims[dim] = concatSize; Tensor result(dims, concatInputs[0].type()); @@ -682,25 +664,22 @@ Variable concatenate(const std::vector& concatInputs, int dim) { } std::vector split(const Variable& input, long splitSize, int dim) { - if(splitSize <= 0) { + if(splitSize <= 0) throw std::invalid_argument("split size must be a positive integer"); - } auto dimSize = input.dim(dim); std::vector splitSizes(dimSize / splitSize, splitSize); - if(dimSize % splitSize > 0) { + if(dimSize % splitSize > 0) splitSizes.push_back(dimSize % splitSize); - } return split(input, splitSizes, dim); } std::vector split(const Variable& input, const std::vector& splitSizes, int dim) { - if(dim >= input.ndim()) { + if(dim >= input.ndim()) throw 
std::invalid_argument( "split: passed dim is larger than the number of dimensions " "of the input." ); - } auto dimSize = input.dim(dim); auto N = splitSizes.size(); @@ -708,17 +687,15 @@ std::vector split(const Variable& input, const std::vector& spli std::vector sel(input.ndim(), fl::span); int start = 0; for(int i = 0; i < N; ++i) { - if(splitSizes[i] <= 0) { + if(splitSizes[i] <= 0) throw std::invalid_argument("elements in split sizes has to be positive"); - } int end = start + splitSizes[i]; sel[dim] = fl::range(start, end); outputs[i] = input(sel); start = end; } - if(start != dimSize) { + if(start != dimSize) throw std::invalid_argument("sum of split sizes must match split dim"); - } return outputs; } @@ -812,14 +789,12 @@ Variable var( auto avg = fl::mean(input, axes, keepDims); auto n = 1; - for(auto ax : axes) { + for(auto ax : axes) n *= input.dim(ax); - } - if(!isbiased && n == 1) { + if(!isbiased && n == 1) throw std::invalid_argument( "cannot compute unbiased variance with only one sample" ); - } auto val = 1.0 / (isbiased ? 
n : n - 1); result = val * (result - n * avg * avg); @@ -851,9 +826,8 @@ Variable norm( double p /* = 2 */, bool keepDims /* = false */ ) { - if(p <= 0) { + if(p <= 0) throw std::out_of_range("Lp norm: p must be > 0"); - } auto result = fl::power(fl::abs(FL_ADJUST_INPUT_TYPE(input.tensor())), p); result = fl::sum(result, axes, /* keepDims = */ keepDims); @@ -903,13 +877,11 @@ Variable matmul(const Variable& lhs, const Variable& rhs) { const Variable& gradOutput) { if(inputs[0].isCalcGrad()) { Tensor _lhs = gradOutput.tensor(); - if(_lhs.ndim() == 1) { + if(_lhs.ndim() == 1) _lhs = fl::reshape(_lhs, {1, _lhs.dim(0)}); - } Tensor _rhs = inputs[1].tensor(); - if(_rhs.ndim() == 1) { + if(_rhs.ndim() == 1) _rhs = fl::reshape(_rhs, {_rhs.dim(0), 1}); - } // matmulNT(gradOutput, inputs[1]) // -- matmulNT([M, K], [N, K]) @@ -924,13 +896,11 @@ Variable matmul(const Variable& lhs, const Variable& rhs) { } if(inputs[1].isCalcGrad()) { Tensor _lhs = inputs[0].tensor(); - if(_lhs.ndim() == 1) { + if(_lhs.ndim() == 1) _lhs = fl::reshape(_lhs, {1, _lhs.dim(0)}); - } Tensor _rhs = gradOutput.tensor(); - if(_rhs.ndim() == 1) { + if(_rhs.ndim() == 1) _rhs = fl::reshape(_rhs, {_rhs.dim(0), 1}); - } // matmulTN(inputs[0], gradOutput) // -- matmulTN([M, N], [M, K]) @@ -1042,46 +1012,38 @@ Variable flat(const Variable& input) { } Variable moddims(const Variable& input, const Shape& dims) { - if(input.ndim() == 0) { + if(input.ndim() == 0) return input; - } Shape inferDims = dims; unsigned maxNDims = std::max(input.ndim(), static_cast(dims.ndim())); // Check for inferred dims that are beyond the input's number of dims - for(int i = 0; i < maxNDims; ++i) { - if(i >= input.ndim() && inferDims[i] == 0) { + for(int i = 0; i < maxNDims; ++i) + if(i >= input.ndim() && inferDims[i] == 0) throw std::invalid_argument( "moddims: tried to infer dimension " + std::to_string(i) + " which exceeds the number of dimensions of the input." 
); - } - } // Infer any 0 dim - for(int i = 0; i < maxNDims; ++i) { - if(i < inferDims.ndim() && inferDims[i] == 0) { + for(int i = 0; i < maxNDims; ++i) + if(i < inferDims.ndim() && inferDims[i] == 0) inferDims[i] = input.dim(i); - } - } // Infer any -1 dim int nInfer = 0; - for(int i = 0; i < maxNDims; i++) { + for(int i = 0; i < maxNDims; i++) if(i < inferDims.ndim() && inferDims[i] == -1) { nInfer++; inferDims[i] = -(input.elements() / inferDims.elements()); } - } - if(nInfer > 1) { + if(nInfer > 1) throw std::invalid_argument("moddims: too many dimensions infer"); - } - if(inferDims.elements() != input.elements()) { + if(inferDims.elements() != input.elements()) throw std::invalid_argument("moddims: mismatched # of elements"); - } auto result = fl::reshape(input.tensor(), inferDims); @@ -1167,19 +1129,16 @@ Variable categoricalCrossEntropy( auto input = FL_ADJUST_INPUT_TYPE(in); // input -- [C, X1, X2, X3] // target -- [X1, X2, X3, 1] - if(input.ndim() != targets.ndim() + 1) { + if(input.ndim() != targets.ndim() + 1) throw std::invalid_argument( "dimension mismatch in categorical cross entropy: " "target must have one fewer dimension than input" ); - } - for(int i = 1; i < input.ndim(); i++) { - if(input.dim(i) != targets.dim(i - 1)) { + for(int i = 1; i < input.ndim(); i++) + if(input.dim(i) != targets.dim(i - 1)) throw std::invalid_argument( "dimension mismatch in categorical cross entropy" ); - } - } int C = input.dim(0); int X = targets.elements(); @@ -1189,12 +1148,11 @@ Variable categoricalCrossEntropy( && (targets.tensor() != ignoreIndex) ) .scalar() - ) { + ) throw std::invalid_argument( "target contains elements out of valid range [0, num_categories) " "in categorical cross entropy" ); - } auto x = fl::reshape(input.tensor(), Shape({C, X})); auto y = fl::reshape(targets.tensor(), Shape({1, X})); @@ -1216,24 +1174,22 @@ Variable categoricalCrossEntropy( result = fl::sum(result, {0}) / denominator; // [1] } else if(reduction == ReduceMode::SUM) { result 
= fl::sum(result, {0}); // [1] - } else { + } else throw std::invalid_argument( "unknown reduction method for categorical cross entropy" ); - } auto inputDims = input.shape(); auto gradFunc = [C, X, mask, ignoreMask, denominator, reduction, inputDims]( std::vector& inputs, const Variable& gradOutput) { Tensor grad = gradOutput.tensor(); - if(reduction == ReduceMode::NONE) { + if(reduction == ReduceMode::NONE) grad = fl::reshape(grad, {X}); - } else if(reduction == ReduceMode::MEAN) { + else if(reduction == ReduceMode::MEAN) grad = fl::tile(grad / denominator, {X}); - } else if(reduction == ReduceMode::SUM) { + else if(reduction == ReduceMode::SUM) grad = fl::tile(grad, {X}); - } // [1 X] grad(ignoreMask) = 0.; grad = fl::reshape(grad, {1, X}); @@ -1252,37 +1208,32 @@ Variable weightedCategoricalCrossEntropy( ) { // input -- [C, X1, X2, X3] // target -- [X1, X2, X3] - if(input.ndim() < targets.ndim() - 1) { + if(input.ndim() < targets.ndim() - 1) throw std::invalid_argument( "weightedCategoricalCrossEntropy: input must have one more than the " "number of target dimensions minus 1" ); - } - for(int i = 1; i < targets.ndim() - 2; i++) { - if(input.dim(i) != targets.dim(i - 1)) { + for(int i = 1; i < targets.ndim() - 2; i++) + if(input.dim(i) != targets.dim(i - 1)) throw std::invalid_argument( "weightedCategoricalCrossEntropy: dimension mismatch in categorical cross entropy" ); - } - } - if(weight.dim(0) != input.dim(0)) { + if(weight.dim(0) != input.dim(0)) throw std::invalid_argument( "weightedCategoricalCrossEntropy: dimension mismatch in categorical cross entropy" ); - } int C = input.dim(0); int X = targets.elements(); if( fl::any((targets.tensor() < 0) || (targets.tensor() >= C)) .scalar() - ) { + ) throw std::invalid_argument( "weightedCategoricalCrossEntropy: target contains elements out of valid range " "[0, num_categories) in categorical cross entropy" ); - } auto x = fl::reshape(input.tensor(), {C, X}); auto y = fl::reshape(targets.tensor(), {1, X}); @@ 
-1321,23 +1272,20 @@ Variable weightedCategoricalCrossEntropy( Variable reorder(const Variable& input, const Shape& shape) { auto result = fl::transpose(input.tensor(), shape); - if(!result.isContiguous()) { + if(!result.isContiguous()) result = result.asContiguousTensor(); - } std::vector> dimGrad(shape.ndim()); - for(unsigned i = 0; i < shape.ndim(); ++i) { + for(unsigned i = 0; i < shape.ndim(); ++i) dimGrad[i] = {shape.dim(i), i}; - } std::sort(dimGrad.begin(), dimGrad.end()); auto gradFunc = [dimGrad](std::vector& inputs, const Variable& gradOutput) { Shape reordered(std::vector(dimGrad.size())); - for(unsigned i = 0; i < dimGrad.size(); ++i) { + for(unsigned i = 0; i < dimGrad.size(); ++i) reordered[i] = dimGrad[i].second; - } inputs[0].addGrad( Variable(fl::transpose(gradOutput.tensor(), reordered), false) @@ -1401,9 +1349,8 @@ Variable linear(const Variable& in, const Variable& wt, const Variable& bs) { wt.addGrad(Variable(wtGrad, false)); } }; - if(hasBias) { + if(hasBias) return Variable(output, {input, weight, bias}, gradFunc); - } return Variable(output, {input, weight}, gradFunc); } @@ -1560,14 +1507,12 @@ Variable conv2d( if(inputs[1].isCalcGrad()) { inputs[1].addGrad(Variable(filterGrad, false)); // filter/weight } - if(computeBiasGrad) { + if(computeBiasGrad) inputs[2].addGrad(Variable(biasGrad, false)); - } } }; - if(hasBias) { + if(hasBias) return Variable(output, {input, weights, bias}, gradFunc); - } return Variable(output, {input, weights}, gradFunc); } @@ -1589,9 +1534,8 @@ Variable pool2d( std::vector& inputs, const Variable& gradOutput) { auto& in = inputs[0]; - if(!in.isCalcGrad()) { + if(!in.isCalcGrad()) return; - } in.addGrad( Variable( @@ -1658,9 +1602,8 @@ Variable batchnorm( auto gradOutput = detail::adjustInputType(_gradOutput, "batchnorm"); - if(!in.isCalcGrad() && !wt.isCalcGrad() && !bs.isCalcGrad()) { + if(!in.isCalcGrad() && !wt.isCalcGrad() && !bs.isCalcGrad()) return; - } auto [gradIn, gradWt, gradBs] = in.tensor() @@ -1680,27 
+1623,24 @@ Variable batchnorm( in.addGrad(Variable(gradIn.astype(in.type()), false)); wt.addGrad(Variable(gradWt.astype(wt.type()), false)); - if(!bs.isEmpty()) { + if(!bs.isEmpty()) bs.addGrad(Variable(gradBs.astype(bs.type()), false)); - } }; return Variable(output, {input, weight, bias}, gradFunc); } Variable gatedlinearunit(const Variable& input, const int dim) { - if(dim >= input.ndim()) { + if(dim >= input.ndim()) throw std::invalid_argument( "gatedlinearunit - passed dim is great than the " "number of dimensions of the input." ); - } auto inDims = input.shape(); auto inType = input.type(); auto inSize = inDims[dim]; - if(inSize % 2 == 1) { + if(inSize % 2 == 1) throw std::invalid_argument("halving dimension must be even for GLU"); - } std::vector fhalf(input.ndim(), fl::span); std::vector shalf(input.ndim(), fl::span); @@ -1775,9 +1715,8 @@ std::tuple rnn( if( !(input.isCalcGrad() || hiddenState.isCalcGrad() || cellState.isCalcGrad() || weights.isCalcGrad()) - ) { + ) return; - } auto [dy, dhy, dcy, dweights] = input.tensor().backend().getExtension().rnnBackward( @@ -1805,25 +1744,22 @@ std::tuple rnn( auto dyGradFunc = [gradData](std::vector& inputs, const Variable& gradOutput) { - if(!inputs[0].isGradAvailable()) { + if(!inputs[0].isGradAvailable()) inputs[0].addGrad(Variable(Tensor(), false)); - } gradData->dy = gradOutput.tensor().asContiguousTensor(); }; auto dhyGradFunc = [gradData](std::vector& inputs, const Variable& gradOutput) { - if(!inputs[0].isGradAvailable()) { + if(!inputs[0].isGradAvailable()) inputs[0].addGrad(Variable(Tensor(), false)); - } gradData->dhy = gradOutput.tensor().asContiguousTensor(); }; auto dcyGradFunc = [gradData](std::vector& inputs, const Variable& gradOutput) { - if(!inputs[0].isGradAvailable()) { + if(!inputs[0].isGradAvailable()) inputs[0].addGrad(Variable(Tensor(), false)); - } gradData->dcy = gradOutput.tensor().asContiguousTensor(); }; @@ -1835,26 +1771,23 @@ std::tuple rnn( Variable embedding(const Variable& input, 
const Variable& embeddings) { // TODO{fl::Tensor}{4-dims} - relax this - if(input.ndim() >= 4) { + if(input.ndim() >= 4) throw std::invalid_argument("embedding input must have 3 or fewer dims"); - } auto idxs = input.tensor().flatten(); auto inDims = input.shape(); std::vector rDims(input.ndim() + 1); rDims[0] = embeddings.dim(0); - for(unsigned i = 1; i < input.ndim() + 1; i++) { + for(unsigned i = 1; i < input.ndim() + 1; i++) rDims[i] = inDims[i - 1]; - } Shape resultDims(rDims); Tensor result = fl::reshape(embeddings.tensor()(fl::span, idxs), resultDims); auto gradFunc = [](std::vector& inputs, const Variable& gradOutput) { auto& w = inputs[1]; - if(!w.isCalcGrad()) { + if(!w.isCalcGrad()) return; - } auto ip = inputs[0].tensor().flatten(); unsigned size = ip.elements(); @@ -1888,12 +1821,11 @@ Variable padding( std::vector> pad, double val ) { - if(pad.size() > input.ndim()) { + if(pad.size() > input.ndim()) throw std::invalid_argument( "padding: number of padding dimensions exceeds number " "of input dimensions" ); - } Shape opDims = input.shape(); std::vector inSeq(input.ndim(), fl::span); @@ -1918,9 +1850,8 @@ Variable dropout(const Variable& input, double p) { false ); return 1.0 / (1.0 - p) * mask * input; - } else { + } else return input; - } } Variable relu(const Variable& input) { @@ -1935,12 +1866,11 @@ Variable gelu(const Variable& in) { } fl::Variable relativePositionEmbeddingRotate(const fl::Variable& input) { - if(input.ndim() != 3) { + if(input.ndim() != 3) throw std::invalid_argument( "relativePositionEmbeddingRotate - " "input tensor must have 3 dimensions" ); - } auto data = input.tensor(); int d0 = data.dim(0); @@ -1982,24 +1912,21 @@ fl::Variable multiheadAttention( const double pDropout, const int32_t offset /* = 0 */ ) { - if(query.ndim() != 3) { + if(query.ndim() != 3) throw std::invalid_argument( "multiheadAttention - query input tensor should be 3 dimensions: " "Time x (nHeads * headDim) x B" ); - } - if(key.ndim() != 3) { + 
if(key.ndim() != 3) throw std::invalid_argument( "multiheadAttention - key input tensor should be 3 dimensions: " "Time x (nHeads * headDim) x B" ); - } - if(value.ndim() != 3) { + if(value.ndim() != 3) throw std::invalid_argument( "multiheadAttention - value input tensor should be 3 dimensions: " "Time x (nHeads * headDim) x B" ); - } int32_t bsz = query.dim(2); int32_t modelDim = query.dim(1); @@ -2018,15 +1945,13 @@ fl::Variable multiheadAttention( scores = scores + transpose(pscores(fl::range(n, n + k.dim(0))), {1, 0, 2}); } - if(!mask.isEmpty()) { + if(!mask.isEmpty()) scores = scores + tileAs(mask.astype(scores.type()), scores); - } if(!padMask.isEmpty()) { - if(padMask.dim(0) != query.dim(0)) { + if(padMask.dim(0) != query.dim(0)) throw std::invalid_argument( "multiheadAttention: invalid padding mask size" ); - } auto padMaskTile = moddims(padMask, {1, padMask.dim(0), 1, bsz}); padMaskTile = tileAs(padMaskTile, {padMask.dim(0), padMask.dim(0), nHeads, bsz}); diff --git a/flashlight/fl/autograd/Functions.h b/flashlight/fl/autograd/Functions.h index 2e64e80..b2d23a0 100644 --- a/flashlight/fl/autograd/Functions.h +++ b/flashlight/fl/autograd/Functions.h @@ -60,9 +60,8 @@ namespace detail { T adjustInputType(const T& in, const char* funcname) { OptimLevel optimLevel = OptimMode::get().getOptimLevel(); // Fastpath - DEFAULT mode never casts tensors - if(optimLevel == OptimLevel::DEFAULT) { + if(optimLevel == OptimLevel::DEFAULT) return in; - } T res; auto& funcs = kOptimLevelTypeExclusionMappings.find(optimLevel)->second; @@ -70,16 +69,15 @@ namespace detail { if( funcs.find(std::string(funcname)) == funcs.end() && optimLevel != OptimLevel::DEFAULT - ) { + ) // Not in the excluded list - cast to f16 res = in.astype(fl::dtype::f16); - } else { + else { // Upcast to f32 only if we have an f16 input - otherwise, leave as is - if(in.type() == fl::dtype::f16) { + if(in.type() == fl::dtype::f16) res = in.astype(fl::dtype::f32); - } else { + else res = in; - } } return 
res; diff --git a/flashlight/fl/autograd/Variable.cpp b/flashlight/fl/autograd/Variable.cpp index 256a11d..5aa3db8 100644 --- a/flashlight/fl/autograd/Variable.cpp +++ b/flashlight/fl/autograd/Variable.cpp @@ -107,13 +107,11 @@ Variable Variable::astype(fl::dtype newType) const { } Variable& Variable::grad() const { - if(!sharedGrad_->calcGrad) { + if(!sharedGrad_->calcGrad) throw std::logic_error("gradient calculation disabled for this Variable"); - } - if(!sharedGrad_->grad) { + if(!sharedGrad_->grad) throw std::logic_error("gradient not calculated yet for this Variable"); - } return *sharedGrad_->grad; } @@ -127,9 +125,8 @@ bool Variable::isCalcGrad() const { } bool Variable::isGradAvailable() const { - if(!sharedGrad_->calcGrad) { + if(!sharedGrad_->calcGrad) return false; - } return sharedGrad_->grad != nullptr; } @@ -146,9 +143,8 @@ bool Variable::isContiguous() const { } Variable Variable::asContiguous() const { - if(!isEmpty() && !isContiguous()) { + if(!isEmpty() && !isContiguous()) tensor() = tensor().asContiguousTensor(); - } return *this; } @@ -209,7 +205,7 @@ void Variable::addGrad(const Variable& childGrad) { << childGrad.shape() << std::endl; throw std::invalid_argument(ss.str()); } - if(sharedGrad_->grad) { + if(sharedGrad_->grad) // Prevent increment of array refcount to avoid a copy // if getting a device pointer. 
See // https://git.io/fp9oM for more @@ -217,12 +213,11 @@ void Variable::addGrad(const Variable& childGrad) { sharedGrad_->grad->tensor() + childGrad.tensor(), false ); - } else { + else // Copy the childGrad Variable so as to share a reference // to the underlying childGrad.tensor() rather than copying // the tensor into a new variable sharedGrad_->grad = std::make_unique(childGrad); - } } } @@ -243,15 +238,13 @@ void Variable::applyGradHook() { void Variable::calcGradInputs(bool retainGraph) { if(sharedGrad_->gradFunc) { - if(!sharedGrad_->grad) { + if(!sharedGrad_->grad) throw std::logic_error("gradient was not propagated to this Variable"); - } sharedGrad_->gradFunc(sharedGrad_->inputs, *sharedGrad_->grad); } - if(!retainGraph) { + if(!retainGraph) sharedGrad_->inputs.clear(); - } } void Variable::backward(const Variable& grad, bool retainGraph) { @@ -260,9 +253,8 @@ void Variable::backward(const Variable& grad, bool retainGraph) { for(auto iter = dag.rbegin(); iter != dag.rend(); iter++) { iter->calcGradInputs(retainGraph); iter->applyGradHook(); - if(!retainGraph) { + if(!retainGraph) *iter = Variable(); - } } } @@ -288,12 +280,10 @@ Variable::DAG Variable::build() const { // Topological sort recurse = [&](const Variable& var) { auto id = var.sharedGrad_.get(); - if(cache.find(id) != cache.end()) { + if(cache.find(id) != cache.end()) return; - } - for(const auto& input : var.getInputs()) { + for(const auto& input : var.getInputs()) recurse(input); - } cache.insert(id); dag.push_back(var); }; diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp index 1e7dec9..25b4159 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp @@ -29,13 +29,12 @@ namespace { ) { int nfeatures = 1; for(auto ax : axes) { - if(ax > input.ndim() - 1) { + if(ax > input.ndim() - 1) throw std::invalid_argument( "batchnorm - passed 
axes (axis value " + std::to_string(ax) + ") exceeds the number of dimensions of the input (" + std::to_string(input.ndim()) + ")" ); - } nfeatures *= input.dim(ax); } @@ -44,9 +43,8 @@ namespace { // assuming no duplicates bool axes_continuous = (axes.size() == (maxAxis - minAxis + 1)); - if(!axes_continuous) { + if(!axes_continuous) throw std::invalid_argument("unsupported axis config for cuDNN batchnorm"); - } if(minAxis == 0) { modeOut = CUDNN_BATCHNORM_PER_ACTIVATION; @@ -60,14 +58,13 @@ namespace { } else { modeOut = CUDNN_BATCHNORM_SPATIAL; #if CUDNN_VERSION >= 7003 - if(train) { + if(train) modeOut = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; - } + #endif int batchsz = 1; - for(int i = maxAxis + 1; i < input.ndim(); ++i) { + for(int i = maxAxis + 1; i < input.ndim(); ++i) batchsz *= input.dim(i); - } inDescDimsOut = Shape( { 1, @@ -96,11 +93,10 @@ Tensor CudnnAutogradExtension::batchnorm( const double epsilon, std::shared_ptr ) { - if(input.type() == fl::dtype::f16 && weight.type() != fl::dtype::f32) { + if(input.type() == fl::dtype::f16 && weight.type() != fl::dtype::f32) throw std::invalid_argument( "fl::batchnorm: non-input tensors must be of type f32" ); - } FL_TENSOR_DTYPES_MATCH_CHECK(weight, bias, runningMean, runningVar); auto output = Tensor(input.shape(), input.type()); @@ -109,13 +105,11 @@ Tensor CudnnAutogradExtension::batchnorm( Shape inDescDims, wtDescDims; getBatchnormMetadata(mode, inDescDims, wtDescDims, input, axes, train); - if(!weight.isEmpty() && weight.elements() != wtDescDims.elements()) { + if(!weight.isEmpty() && weight.elements() != wtDescDims.elements()) throw std::invalid_argument("[BatchNorm] Invalid shape for weight."); - } - if(!bias.isEmpty() && bias.elements() != wtDescDims.elements()) { + if(!bias.isEmpty() && bias.elements() != wtDescDims.elements()) throw std::invalid_argument("[BatchNorm] Invalid shape for bias."); - } // Weight, bias, and running mean/var arrays can't be fp16 (must be fp32) Tensor weightArray = 
weight.isEmpty() ? fl::full(wtDescDims, 1.0, fl::dtype::f32) @@ -172,7 +166,7 @@ Tensor CudnnAutogradExtension::batchnorm( saveVarRaw.get() ) ); - } else { + } else CUDNN_CHECK_ERR( cudnnBatchNormalizationForwardInference( getCudnnHandle(), @@ -191,7 +185,6 @@ Tensor CudnnAutogradExtension::batchnorm( epsilon ) ); - } // ensure output stream waits on cudnn compute stream relativeSync({output}, cudnnStream); } @@ -209,11 +202,10 @@ std::tuple CudnnAutogradExtension::batchnormBackward( const float epsilon, std::shared_ptr ) { - if(!train) { + if(!train) throw std::logic_error( "can't compute batchnorm grad when train was not specified" ); - } cudnnBatchNormMode_t mode; Shape inDescDims, wtDescDims; diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp index 0102cf5..84ea052 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp @@ -63,7 +63,7 @@ namespace { ) { T reserved; bool algoFound = false; - for(const auto& algoPerf : algoPerfs) { + for(const auto& algoPerf : algoPerfs) if( algoPerf.status == CUDNN_STATUS_SUCCESS && algoPerf.memory < kWorkspaceSizeLimitBytes @@ -71,19 +71,17 @@ namespace { if( !(arithmeticPrecision == fl::dtype::f16) || (preferredAlgos.find(algoPerf.algo) != preferredAlgos.end()) - ) { + ) return algoPerf; - } else if(!algoFound) { + else if(!algoFound) { reserved = algoPerf; algoFound = true; } } - } - if(algoFound) { + if(algoFound) return reserved; - } else { + else throw std::runtime_error("Error while finding cuDNN Conv Algorithm."); - } } cudnnConvolutionFwdAlgoPerf_t getFwdAlgo( @@ -171,26 +169,23 @@ namespace { if( arithmeticPrecision != fl::dtype::f16 && bestAlgo.algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_1 - ) { + ) isAlgoBlacklisted = true; - } + #endif #if CUDNN_VERSION < 7500 if( isStrided && (bestAlgo.algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING || bestAlgo.algo == 
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT) - ) { + ) isAlgoBlacklisted = true; - } + #endif - if(isAlgoBlacklisted) { - for(const auto& algoPerf : bwdDataAlgoPerfs) { - if(algoPerf.algo == kBwdDataDefaultAlgo) { + if(isAlgoBlacklisted) + for(const auto& algoPerf : bwdDataAlgoPerfs) + if(algoPerf.algo == kBwdDataDefaultAlgo) return algoPerf; - } - } - } return bestAlgo; } @@ -238,17 +233,14 @@ namespace { if( arithmeticPrecision != fl::dtype::f16 && bestAlgo.algo == CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 - ) { + ) isAlgoBlacklisted = true; - } + #endif - if(isAlgoBlacklisted) { - for(const auto& algoPerf : bwdFilterAlgoPerfs) { - if(algoPerf.algo == kBwdFilterDefaultAlgo) { + if(isAlgoBlacklisted) + for(const auto& algoPerf : bwdFilterAlgoPerfs) + if(algoPerf.algo == kBwdFilterDefaultAlgo) return algoPerf; - } - } - } return bestAlgo; } @@ -276,18 +268,17 @@ namespace { } void setDefaultMathType(ConvDescriptor& cDesc, const Tensor& input) { - if(input.type() == fl::dtype::f16) { + if(input.type() == fl::dtype::f16) CUDNN_CHECK_ERR( cudnnSetConvolutionMathType( cDesc.descriptor, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION ) ); - } else { + else CUDNN_CHECK_ERR( cudnnSetConvolutionMathType(cDesc.descriptor, CUDNN_DEFAULT_MATH) ); - } } } // namespace @@ -305,31 +296,29 @@ Tensor CudnnAutogradExtension::conv2d( const int groups, std::shared_ptr ) { - if(input.ndim() != 4) { + if(input.ndim() != 4) throw std::invalid_argument( "conv2d: expects input tensor to be 4 dimensions: " "in WHCN ordering. Given tensor has " + std::to_string(input.ndim()) + " dimensions." 
); - } auto hasBias = bias.elements() > 0; auto inDesc = TensorDescriptor(input); auto wtDesc = FilterDescriptor(weights); auto convDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups); - if(input.type() == fl::dtype::f16) { + if(input.type() == fl::dtype::f16) CUDNN_CHECK_ERR( cudnnSetConvolutionMathType( convDesc.descriptor, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION ) ); - } else { + else CUDNN_CHECK_ERR( cudnnSetConvolutionMathType(convDesc.descriptor, CUDNN_DEFAULT_MATH) ); - } std::array odims; CUDNN_CHECK_ERR( @@ -472,12 +461,11 @@ Tensor CudnnAutogradExtension::conv2dBackwardData( FilterDescriptor& wDesc, ConvDescriptor& cDesc, TensorDescriptor& oDesc) -> Tensor { - if(dataGradBenchmark && DynamicBenchmark::getBenchmarkMode()) { + if(dataGradBenchmark && DynamicBenchmark::getBenchmarkMode()) setCudnnConvMathType( cDesc, dataGradBenchmark->getOptions>() ); - } DevicePtr wPtr(wtTensor); // ensure cudnn compute stream waits on stream of weight tensor @@ -608,7 +596,7 @@ Tensor CudnnAutogradExtension::conv2dBackwardData( } ); - } else { + } else dataGradBenchmark->audit( [&dataGradOut, &convolutionBackwardData, @@ -630,9 +618,8 @@ Tensor CudnnAutogradExtension::conv2dBackwardData( ); } ); - } - } else { + } else // No benchmarking - proceed normally dataGradOut = convolutionBackwardData( input, @@ -643,7 +630,6 @@ Tensor CudnnAutogradExtension::conv2dBackwardData( cDesc, oDesc ); - } return dataGradOut; } @@ -693,13 +679,12 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( FilterDescriptor& wDesc, ConvDescriptor& cDesc, TensorDescriptor& oDesc) -> Tensor { - if(filterGradBenchmark && DynamicBenchmark::getBenchmarkMode()) { + if(filterGradBenchmark && DynamicBenchmark::getBenchmarkMode()) setCudnnConvMathType( cDesc, filterGradBenchmark ->getOptions>() ); - } DevicePtr iPtr(inTensor); // ensure cudnn compute stream waits on stream of input tensor @@ -828,7 +813,7 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( } ); - } else 
{ + } else filterGradBenchmark->audit( [&filterGradOut, &convolutionBackwardFilter, @@ -850,9 +835,8 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( ); } ); - } - } else { + } else filterGradOut = convolutionBackwardFilter( input, weight, @@ -862,7 +846,6 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( cDesc, oDesc ); - } auto convolutionBackwardBias = [&hndl, &cudnnStream, oneg, zerog]( const Tensor& bsTensor, @@ -931,7 +914,7 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( convolutionBackwardBias(biasF32, gradOutputF32, oDescF32); } ); - } else { + } else // Grad output and bias types are already the same, so perform the // computation using whatever input type is given biasGradBenchmark->audit( @@ -943,11 +926,9 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias( biasGradOut = convolutionBackwardBias(bias, gradOutput, oDesc); } ); - } - } else { + } else // No benchmark; proceed normally biasGradOut = convolutionBackwardBias(bias, gradOutput, oDesc); - } } return {filterGradOut, biasGradOut}; diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp index db30ea6..bf4edeb 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp @@ -86,9 +86,8 @@ struct CudnnDropoutStruct { namespace fl { void cudnnCheckErr(cudnnStatus_t status) { - if(status == CUDNN_STATUS_SUCCESS) { + if(status == CUDNN_STATUS_SUCCESS) return; - } const char* err = cudnnGetErrorString(status); switch(status) { case CUDNN_STATUS_BAD_PARAM: @@ -146,15 +145,13 @@ TensorDescriptor::TensorDescriptor(const fl::dtype type, const Shape& flDims) { std::array dims = {1, 1, 1, 1}; // We want, if dims exist: // {flDims[3], flDims[2], flDims[1], flDims[0]}; - for(unsigned i = 0; i < flDims.ndim(); ++i) { + for(unsigned i = 0; i < flDims.ndim(); ++i) dims[3 - i] = flDims[i]; - } // Sets strides so 
array is contiguous row-major for cudnn std::vector r_strides = {1}; - for(auto it = dims.rbegin(); it != dims.rend() - 1; ++it) { + for(auto it = dims.rbegin(); it != dims.rend() - 1; ++it) r_strides.push_back(r_strides.back() * (*it)); - } std::vector strides(r_strides.rbegin(), r_strides.rend()); CUDNN_CHECK_ERR( @@ -178,15 +175,13 @@ TensorDescriptor::TensorDescriptor(const Tensor& input) { // reverse the dims (column -> row major) and cast to int type std::array strides = {1, 1, 1, 1}; // {flStrides[3], flStrides[2], flStrides[1], flStrides[0]}; - for(unsigned i = 0; i < flStrides.ndim(); ++i) { + for(unsigned i = 0; i < flStrides.ndim(); ++i) strides[3 - i] = flStrides[i]; - } std::array dims = {1, 1, 1, 1}; // {flDims[3], flDims[2], flDims[1], flDims[0]}; - for(unsigned i = 0; i < flDims.ndim(); ++i) { + for(unsigned i = 0; i < flDims.ndim(); ++i) dims[3 - i] = flDims[i]; - } CUDNN_CHECK_ERR( cudnnSetTensorNdDescriptor( @@ -258,9 +253,8 @@ FilterDescriptor::FilterDescriptor(const Tensor& input) { std::array dims = {1, 1, 1, 1}; // We want, if dims exist: // {flDims[3], flDims[2], flDims[1], flDims[0]}; - for(unsigned i = 0; i < flDims.ndim(); ++i) { + for(unsigned i = 0; i < flDims.ndim(); ++i) dims[3 - i] = flDims[i]; - } CUDNN_CHECK_ERR( cudnnSetFilterNdDescriptor( diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp index 1c85956..ec209fd 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/RNN.cpp @@ -54,18 +54,17 @@ namespace { } void setCudnnRnnMathType(const Tensor& input, const RNNDescriptor& rnnDesc) { - if(input.type() == fl::dtype::f16) { + if(input.type() == fl::dtype::f16) CUDNN_CHECK_ERR( cudnnSetRNNMatrixMathType( rnnDesc.descriptor, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION ) ); - } else { + else CUDNN_CHECK_ERR( cudnnSetRNNMatrixMathType(rnnDesc.descriptor, CUDNN_DEFAULT_MATH) ); - } } struct CudnnRnnAutogradPayload : 
public detail::AutogradPayloadData { @@ -90,9 +89,8 @@ std::tuple CudnnAutogradExtension::rnn( bool train = (autogradPayload != nullptr); auto payload = std::make_shared(); - if(train) { + if(train) autogradPayload->data = payload; - } Tensor x = input.asContiguousTensor(); Tensor hiddenState = hiddenStateIn.asContiguousTensor(); @@ -123,18 +121,16 @@ std::tuple CudnnAutogradExtension::rnn( if( !(hxHiddenSize == hiddenSize && hxBatchSize == batchSize && hxTotalLayers == totalLayers) - ) { + ) throw std::invalid_argument("invalid hidden state dims for RNN"); - } } if( !cellState.isEmpty() && !(mode == RnnMode::LSTM && cellState.dim(0) == hiddenSize && cellState.dim(1) == batchSize && cellState.dim(2) == totalLayers) - ) { + ) throw std::invalid_argument("invalid cell state dims for RNN"); - } Shape hDims = {1, hiddenSize, batchSize, totalLayers}; TensorDescriptor hxDesc(x.type(), hDims); @@ -153,11 +149,10 @@ std::tuple CudnnAutogradExtension::rnn( cudnnMapToType(weights.type()) ) ); - if(paramSize != weights.bytes()) { + if(paramSize != weights.bytes()) throw std::invalid_argument( "invalid # of parameters or wrong input shape for RNN" ); - } FilterDescriptor wDesc(weights); Tensor y({outSize, batchSize, seqLength}, input.type()); @@ -167,9 +162,8 @@ std::tuple CudnnAutogradExtension::rnn( TensorDescriptor hyDesc(x.type(), hDims); Tensor cy; - if(mode == RnnMode::LSTM) { + if(mode == RnnMode::LSTM) cy = Tensor(hy.shape(), x.type()); - } TensorDescriptor cyDesc(x.type(), hDims); @@ -250,11 +244,10 @@ std::tuple CudnnAutogradExtension::rnnBackward( const float dropProb, std::shared_ptr autogradPayload ) { - if(!autogradPayload) { + if(!autogradPayload) throw std::invalid_argument( "CudnnAutogradExtension::rnnBackward given null detail::AutogradPayload" ); - } auto payload = std::static_pointer_cast(autogradPayload->data); @@ -301,9 +294,8 @@ std::tuple CudnnAutogradExtension::rnnBackward( Tensor workspace({static_cast(workspaceSize)}, fl::dtype::b8); auto& dy = 
gradData->dy; - if(dy.isEmpty()) { + if(dy.isEmpty()) dy = fl::full(y.shape(), 0.0, y.type()); - } auto& dhy = gradData->dhy; auto& dcy = gradData->dcy; @@ -367,18 +359,17 @@ std::tuple CudnnAutogradExtension::rnnBackward( ); } - if(input.type() == fl::dtype::f16) { + if(input.type() == fl::dtype::f16) CUDNN_CHECK_ERR( cudnnSetRNNMatrixMathType( rnnDesc.descriptor, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION ) ); - } else { + else CUDNN_CHECK_ERR( cudnnSetRNNMatrixMathType(rnnDesc.descriptor, CUDNN_DEFAULT_MATH) ); - } TensorDescriptorArray xDescs( seqLength, x.type(), {1, 1, inputSize, batchSize}); Tensor dw = fl::full(weights.shape(), 0, weights.type()); diff --git a/flashlight/fl/autograd/tensor/backend/onednn/BatchNorm.cpp b/flashlight/fl/autograd/tensor/backend/onednn/BatchNorm.cpp index 285ec91..aca3ea5 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/BatchNorm.cpp +++ b/flashlight/fl/autograd/tensor/backend/onednn/BatchNorm.cpp @@ -31,9 +31,8 @@ namespace { int getNfeatures(const Shape& inputShape, const std::vector& axes) { int nfeatures = 1; - for(auto ax : axes) { + for(auto ax : axes) nfeatures *= inputShape.dim(ax); - } return nfeatures; } @@ -44,18 +43,17 @@ namespace { const int nfeatures ) { Shape inDescDims; - if(minAxis == 0) { + if(minAxis == 0) inDescDims = Shape( {1, 1, nfeatures, static_cast(input.elements() / nfeatures)} ); - } else { + else { int batchsz = 1; - for(int i = maxAxis + 1; i < input.ndim(); ++i) { + for(int i = maxAxis + 1; i < input.ndim(); ++i) batchsz *= input.dim(i); - } inDescDims = Shape( {1, static_cast(input.elements() / (nfeatures * batchsz)), @@ -102,36 +100,30 @@ Tensor OneDnnAutogradExtension::batchnorm( const double epsilon, std::shared_ptr autogradPayload ) { - if(momentum != 0.) { + if(momentum != 0.) 
throw std::runtime_error("OneDNN batchnorm op doesn't support momentum."); - } - if(input.type() == fl::dtype::f16) { + if(input.type() == fl::dtype::f16) throw std::runtime_error("OneDNN batchnorm op - f16 inputs not supported."); - } auto payload = std::make_shared(); - if(train && autogradPayload) { + if(train && autogradPayload) autogradPayload->data = payload; - } auto output = Tensor(input.shape(), input.type()); int nfeatures = getNfeatures(input.shape(), axes); - if(runningVar.isEmpty()) { + if(runningVar.isEmpty()) runningVar = fl::full({nfeatures}, 1., input.type()); - } - if(runningMean.isEmpty()) { + if(runningMean.isEmpty()) runningMean = fl::full({nfeatures}, 0., input.type()); - } // Check if axes are valid auto maxAxis = *std::max_element(axes.begin(), axes.end()); auto minAxis = *std::min_element(axes.begin(), axes.end()); bool axesContinuous = (axes.size() == (maxAxis - minAxis + 1)); - if(!axesContinuous) { + if(!axesContinuous) throw std::invalid_argument("axis array should be continuous"); - } auto& dnnlEngine = detail::DnnlEngine::getInstance().getEngine(); @@ -215,11 +207,10 @@ std::tuple OneDnnAutogradExtension::batchnormBackward( const float epsilon, std::shared_ptr autogradPayload ) { - if(!autogradPayload) { + if(!autogradPayload) throw std::invalid_argument( "OneDnnAutogradExtension::pool2dBackward given null detail::AutogradPayload" ); - } auto payload = std::static_pointer_cast(autogradPayload->data); @@ -228,9 +219,8 @@ std::tuple OneDnnAutogradExtension::batchnormBackward( auto maxAxis = *std::max_element(axes.begin(), axes.end()); auto minAxis = *std::min_element(axes.begin(), axes.end()); const bool axesContinuous = (axes.size() == (maxAxis - minAxis + 1)); - if(!axesContinuous) { + if(!axesContinuous) throw std::invalid_argument("axis array should be continuous"); - } const int nfeatures = getNfeatures(input.shape(), axes); auto inputOutputDims = getInputOutputDims(minAxis, maxAxis, input, nfeatures); diff --git 
a/flashlight/fl/autograd/tensor/backend/onednn/Conv2D.cpp b/flashlight/fl/autograd/tensor/backend/onednn/Conv2D.cpp index a83b62b..881e076 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/Conv2D.cpp +++ b/flashlight/fl/autograd/tensor/backend/onednn/Conv2D.cpp @@ -79,14 +79,14 @@ namespace { inputShape.dim(kHIdx), inputShape.dim(kWIdx)} ); - if(groups == 1) { + if(groups == 1) out.weightDims = detail::convertToDnnlDims( {weightsShape.dim(kWeightOutputChannelSizeIdx), inputShape.dim(kIOChannelSizeIdx), weightsShape.dim(kHIdx), weightsShape.dim(kWIdx)} ); - } else { + else out.weightDims = detail::convertToDnnlDims( {groups, weightsShape.dim(kWeightOutputChannelSizeIdx) / groups, @@ -94,7 +94,6 @@ namespace { weightsShape.dim(kHIdx), weightsShape.dim(kWIdx)} ); - } out.outputDims = detail::convertToDnnlDims( {inputShape.dim(kIOBatchSizeIdx), weightsShape.dim(kWeightOutputChannelSizeIdx), @@ -125,7 +124,7 @@ namespace { auto& dnnlEngine = detail::DnnlEngine::getInstance().getEngine(); convolution_forward::primitive_desc fwdPrimitiveDescriptor; - if(hasBias) { + if(hasBias) fwdPrimitiveDescriptor = convolution_forward::primitive_desc( dnnlEngine, forwardMode, @@ -139,7 +138,7 @@ namespace { out.paddingDims, out.paddingDims ); - } else { + else fwdPrimitiveDescriptor = convolution_forward::primitive_desc( dnnlEngine, forwardMode, @@ -152,7 +151,6 @@ namespace { out.paddingDims, out.paddingDims ); - } out.fwdPrimDesc = std::move(fwdPrimitiveDescriptor); return out; @@ -173,9 +171,8 @@ Tensor OneDnnAutogradExtension::conv2d( const int groups, std::shared_ptr ) { - if(input.type() == fl::dtype::f16) { + if(input.type() == fl::dtype::f16) throw std::runtime_error("Half precision is not supported in CPU."); - } // flashlight input, weight, and output shapes in column-major: // - Input is WHCN @@ -252,9 +249,8 @@ Tensor OneDnnAutogradExtension::conv2d( ); // Output - adds a reorder after the conv if needed auto outputMemory = outputMemInit.getMemory(); - 
if(outputMemInit.getMemory().get_desc() != outputDesc) { + if(outputMemInit.getMemory().get_desc() != outputDesc) outputMemory = memory(outputDesc, dnnlEngine); - } // Create convolution std::shared_ptr conv; @@ -269,9 +265,8 @@ Tensor OneDnnAutogradExtension::conv2d( {DNNL_ARG_SRC, inputMemory}, {DNNL_ARG_WEIGHTS, weightsMemory}, {DNNL_ARG_DST, outputMemory}}; - if(hasBias) { + if(hasBias) convFwdArgs[DNNL_ARG_BIAS] = biasMemory.getMemory(); - } fwdArgs.push_back(convFwdArgs); // Add output reordering if needed @@ -370,9 +365,8 @@ Tensor OneDnnAutogradExtension::conv2dBackwardData( ); auto gradInputMemory = gradInputMemInit.getMemory(); // Don't reorder the gradient until after the conv - if(gradInputMemInit.getMemory().get_desc() != gradInputDesc) { + if(gradInputMemInit.getMemory().get_desc() != gradInputDesc) gradInputMemory = memory(gradInputDesc, dnnlEngineBwd); - } // Convolution backwards auto convBwdData = @@ -439,13 +433,12 @@ std::pair OneDnnAutogradExtension::conv2dBackwardFilterBias( Tensor gradBias; bool computeBiasGrad = !bias.isEmpty() && !conv2DData.biasMemDesc.is_zero(); - if(computeBiasGrad) { + if(computeBiasGrad) gradBias = Tensor(bias.shape(), bias.type()); - } // Weight backward descriptor convolution_backward_weights::primitive_desc bwdWeightPrimitiveDesc; - if(computeBiasGrad) { + if(computeBiasGrad) bwdWeightPrimitiveDesc = convolution_backward_weights::primitive_desc( dnnlEngineBwd, algorithm::convolution_direct, @@ -459,7 +452,7 @@ std::pair OneDnnAutogradExtension::conv2dBackwardFilterBias( conv2DData.paddingDims, conv2DData.fwdPrimDesc ); - } else { + else bwdWeightPrimitiveDesc = convolution_backward_weights::primitive_desc( dnnlEngineBwd, algorithm::convolution_direct, @@ -472,7 +465,6 @@ std::pair OneDnnAutogradExtension::conv2dBackwardFilterBias( conv2DData.paddingDims, conv2DData.fwdPrimDesc ); - } // Weight backward primitive descriptor auto bwdWeights = std::make_shared(bwdWeightPrimitiveDesc); @@ -506,9 +498,8 @@ std::pair 
OneDnnAutogradExtension::conv2dBackwardFilterBias( ); // Don't reorder the grads until after the conv bwd auto gradWeightsMemory = gradWeightsMemInit.getMemory(); - if(gradWeightsMemInit.getMemory().get_desc() != gradWeightsDesc) { + if(gradWeightsMemInit.getMemory().get_desc() != gradWeightsDesc) gradWeightsMemory = memory(gradWeightsDesc, dnnlEngineBwd); - } // Create the convolution backward weight std::unordered_map bwdConvWeightsArgs = { diff --git a/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.cpp b/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.cpp index 371c46c..6464062 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.cpp +++ b/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.cpp @@ -127,28 +127,24 @@ void executeNetwork( std::vector& net, std::vector>& netArgs ) { - if(net.size() != netArgs.size()) { + if(net.size() != netArgs.size()) throw std::invalid_argument( "executeNetwork - given different size nets and netArgs" ); - } // TODO{fl::Tensor}{macros} -- improve this to work with other backend interop // If on the CPU backend, there isn't a AF computation stream that facilitates // enforcing that inputs to computation are ready; we're required to wait // until all AF operations are done - if(FL_BACKEND_CPU) { + if(FL_BACKEND_CPU) fl::sync(); - } - for(size_t i = 0; i < net.size(); ++i) { + for(size_t i = 0; i < net.size(); ++i) net.at(i).execute(DnnlStream::getInstance().getStream(), netArgs.at(i)); - } // TODO{fl::Tensor}{macros} -- improve this to work with other backend interop - if(FL_BACKEND_CPU) { + if(FL_BACKEND_CPU) // Block the executing thread until the work is complete DnnlStream::getInstance().getStream().wait(); - } } dnnl::algorithm dnnlMapToPoolingMode(const PoolingMode mode) { diff --git a/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.h b/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.h index 5869c49..fc8c09d 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.h 
+++ b/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.h @@ -135,15 +135,14 @@ namespace detail { * Needs to be explicitly inlined due to a bug with DNNL. */ inline dnnl::memory::data_type dnnlMapToType(const fl::dtype t) { - if(t == fl::dtype::f16) { + if(t == fl::dtype::f16) return dnnl::memory::data_type::f16; - } else if(t == fl::dtype::f32) { + else if(t == fl::dtype::f32) return dnnl::memory::data_type::f32; - } else if(t == fl::dtype::f64) { + else if(t == fl::dtype::f64) throw std::invalid_argument("float64 is not supported by DNNL"); - } else { + else throw std::invalid_argument("data type not supported with DNNL"); - } } } // namespace detail diff --git a/flashlight/fl/autograd/tensor/backend/onednn/Pool2D.cpp b/flashlight/fl/autograd/tensor/backend/onednn/Pool2D.cpp index 5d976e2..c6523e1 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/Pool2D.cpp +++ b/flashlight/fl/autograd/tensor/backend/onednn/Pool2D.cpp @@ -91,9 +91,8 @@ Tensor OneDnnAutogradExtension::pool2d( ) { const bool train = (autogradPayload != nullptr); auto payload = std::make_shared(); - if(train) { + if(train) autogradPayload->data = payload; - } // inputX x inputY x channels x batch auto ix = input.dim(kWIdx); @@ -154,9 +153,8 @@ Tensor OneDnnAutogradExtension::pool2d( inputDesc ); payload->outputMemory = outputMemInit.getMemory(); - if(outputMemInit.getMemory().get_desc() != outputDesc) { + if(outputMemInit.getMemory().get_desc() != outputDesc) payload->outputMemory = memory(outputDesc, dnnlEngine); - } // Workspace and layer (only training mode requires a workspace) std::shared_ptr pooling; std::unordered_map fwdPoolingArgs; @@ -166,9 +164,8 @@ Tensor OneDnnAutogradExtension::pool2d( payload->workspace = memory(primDesc.workspace_desc(), dnnlEngine); pooling = std::make_shared(primDesc); fwdPoolingArgs[DNNL_ARG_WORKSPACE] = payload->workspace; - } else { + } else pooling = std::make_shared(primDesc); - } network.push_back(*pooling); fwdArgs.push_back(fwdPoolingArgs); @@ 
-200,11 +197,10 @@ Tensor OneDnnAutogradExtension::pool2dBackward( const PoolingMode mode, std::shared_ptr autogradPayload ) { - if(!autogradPayload) { + if(!autogradPayload) throw std::invalid_argument( "OneDnnAutogradExtension::pool2dBackward given null detail::AutogradPayload" ); - } auto payload = std::static_pointer_cast(autogradPayload->data); diff --git a/flashlight/fl/autograd/tensor/backend/onednn/RNN.cpp b/flashlight/fl/autograd/tensor/backend/onednn/RNN.cpp index e2b2836..63f2316 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/RNN.cpp +++ b/flashlight/fl/autograd/tensor/backend/onednn/RNN.cpp @@ -36,14 +36,13 @@ namespace { // LBR GRU requires switch the given the r, u, o gate order from cuDNN to u, // r, o as required by oneDNN (this from empirical verification) int weightsSize = d1 * d2; - if(weights.elements() != weightsSize * 3) { + if(weights.elements() != weightsSize * 3) throw std::invalid_argument( "RNN reorderLbrGruWeights given invalid weights tensor or dims - " "weights of size " + std::to_string(weights.elements()) + " which should be exactly " + std::to_string(weightsSize * 3) ); - } return fl::concatenate( 0, weights.flat(fl::range(weightsSize, 2 * weightsSize)), @@ -231,10 +230,9 @@ namespace { if(firstLayerDifferent) { out.bias1L = bias.flat(fl::range(biasSize / numLayers)); - if(numLayers > 1) { + if(numLayers > 1) // bias for the second --> last layer bias = bias.flat(fl::range(biasSize / numLayers, fl::end)); - } } out.bias = bias; @@ -305,9 +303,8 @@ namespace { auto y = Tensor({outSize, batchSize, seqLength}, input.type()); auto hy = Tensor({hiddenSize, batchSize, totalLayers}, input.type()); Tensor cy; - if(mode == RnnMode::LSTM) { + if(mode == RnnMode::LSTM) cy = Tensor(hy.shape(), input.type()); - } // Memory for forward auto tnc = dnnl::memory::format_tag::tnc; @@ -318,13 +315,12 @@ namespace { input.asContiguousTensor(), {inputDims}, tnc); const detail::DnnlMemoryWrapper outputMemInit(y, {outputDims}, tnc); 
detail::DnnlMemoryWrapper hiddenInMemInit; - if(!hiddenState.isEmpty()) { + if(!hiddenState.isEmpty()) hiddenInMemInit = detail::DnnlMemoryWrapper( hiddenState.asContiguousTensor(), {hDims}, ldnc ); - } const detail::DnnlMemoryWrapper hiddenOutMemInit(hy, {hDims}, ldnc); const detail::DnnlMemoryWrapper weightsInputMemRawInit( weightsInput.asContiguousTensor(), {weightsInputDims}, ldgoi); @@ -408,13 +404,12 @@ namespace { // which determines whether or not it's ok to return empty // descriptors if the array is empty detail::DnnlMemoryWrapper cellInMemInit; - if(!cellState.isEmpty()) { + if(!cellState.isEmpty()) cellInMemInit = detail::DnnlMemoryWrapper( cellState.asContiguousTensor(), {cDims}, ldnc ); - } // output cell state detail::DnnlMemoryWrapper cellOutMemInit(cy, cDims, ldnc); @@ -480,12 +475,10 @@ std::tuple OneDnnAutogradExtension::rnn( const float dropout, std::shared_ptr autogradPayload ) { - if(dropout > 0.0) { + if(dropout > 0.0) throw std::invalid_argument("onednn RNN: dropout > 0.0 unsupported"); - } - if(bidirectional) { + if(bidirectional) throw std::invalid_argument("onednn RNN: bidirectional not yet supported"); - } const bool train = (autogradPayload != nullptr); @@ -540,7 +533,7 @@ std::tuple OneDnnAutogradExtension::rnn( // that output as the input for layers [2, L]. Since the input size dim 0 // is now the hidden size, the primitive can fuse computation for // arbitrarily-many layers. - if(input.dim(0) == hiddenSize || numLayers == 1) { + if(input.dim(0) == hiddenSize || numLayers == 1) // Input and hidden size are the same, or we only have one layer, which // means we can call the impl as is and parse weights "normally" result = rnnImpl( @@ -560,7 +553,7 @@ std::tuple OneDnnAutogradExtension::rnn( kind, dropout ); - } else { + else { // We require more than one layer with different input and hidden states - // see the above. 
Seek to the first layer's hidden/cell state, weights, and // bias diff --git a/flashlight/fl/common/Defines.cpp b/flashlight/fl/common/Defines.cpp index 9e8ae0c..ec1b3c2 100644 --- a/flashlight/fl/common/Defines.cpp +++ b/flashlight/fl/common/Defines.cpp @@ -28,12 +28,11 @@ OptimMode& OptimMode::get() { OptimLevel OptimMode::toOptimLevel(const std::string& in) { auto l = kStringToOptimLevel.find(in); - if(l == kStringToOptimLevel.end()) { + if(l == kStringToOptimLevel.end()) throw std::invalid_argument( "OptimMode::toOptimLevel - no matching " "optim level for given string." ); - } return l->second; } diff --git a/flashlight/fl/common/DevicePtr.cpp b/flashlight/fl/common/DevicePtr.cpp index 2aef961..4536f23 100644 --- a/flashlight/fl/common/DevicePtr.cpp +++ b/flashlight/fl/common/DevicePtr.cpp @@ -14,22 +14,20 @@ namespace fl { DevicePtr::DevicePtr(const Tensor& in) : tensor_(std::make_unique(in.shallowCopy())) { - if(tensor_->isEmpty()) { + if(tensor_->isEmpty()) ptr_ = nullptr; - } else { - if(!tensor_->isContiguous()) { + else { + if(!tensor_->isContiguous()) throw std::invalid_argument( "can't get device pointer of non-contiguous Tensor" ); - } ptr_ = tensor_->device(); } } DevicePtr::~DevicePtr() { - if(ptr_ != nullptr) { + if(ptr_ != nullptr) tensor_->unlock(); - } } DevicePtr::DevicePtr(DevicePtr&& d) noexcept : tensor_(std::move(d.tensor_)), @@ -38,9 +36,8 @@ DevicePtr::DevicePtr(DevicePtr&& d) noexcept : tensor_(std::move(d.tensor_)), } DevicePtr& DevicePtr::operator=(DevicePtr&& other) noexcept { - if(ptr_ != nullptr) { + if(ptr_ != nullptr) tensor_->unlock(); - } tensor_ = std::move(other.tensor_); ptr_ = other.ptr_; other.ptr_ = nullptr; diff --git a/flashlight/fl/common/DynamicBenchmark.cpp b/flashlight/fl/common/DynamicBenchmark.cpp index 96797a6..dbbbd07 100644 --- a/flashlight/fl/common/DynamicBenchmark.cpp +++ b/flashlight/fl/common/DynamicBenchmark.cpp @@ -20,9 +20,9 @@ void DynamicBenchmark::audit( // Only run the benchmarking components if some 
options are yet to be // fully-timed and benchmark mode is on - otherwise, only run the passed // lambda - if(options_->timingsComplete() || !benchmarkMode_) { + if(options_->timingsComplete() || !benchmarkMode_) function(); - } else { + else { start(); function(); stop(incrementCount); diff --git a/flashlight/fl/common/DynamicBenchmark.h b/flashlight/fl/common/DynamicBenchmark.h index cc17ce8..d0fdebe 100644 --- a/flashlight/fl/common/DynamicBenchmark.h +++ b/flashlight/fl/common/DynamicBenchmark.h @@ -72,12 +72,11 @@ struct DynamicBenchmarkOptions : DynamicBenchmarkOptionsBase { */ DynamicBenchmarkOptions(std::vector options, size_t benchCount) : options_(options), benchCount_(benchCount) { - if(options_.empty()) { + if(options_.empty()) throw std::invalid_argument( "DynamicBenchmarkOptions: " "Options must be passed vector with at least one element" ); - } reset(); } @@ -105,22 +104,19 @@ struct DynamicBenchmarkOptions : DynamicBenchmarkOptionsBase { */ T updateState() { if(!timingsComplete_) { - for(size_t i = 0; i < options_.size(); ++i) { + for(size_t i = 0; i < options_.size(); ++i) if(counts_[i] < benchCount_) { currentOptionIdx_ = i; return options_[i]; } - } timingsComplete_ = true; // All options have been benchmarked with the max count - pick the one // with the lowest time size_t minTimeOptionIdx{0}; - for(size_t i = 0; i < options_.size(); ++i) { - if(times_[i] < times_[minTimeOptionIdx]) { + for(size_t i = 0; i < options_.size(); ++i) + if(times_[i] < times_[minTimeOptionIdx]) minTimeOptionIdx = i; - } - } currentOptionIdx_ = minTimeOptionIdx; } return options_[currentOptionIdx_]; @@ -155,17 +151,15 @@ struct DynamicBenchmarkOptions : DynamicBenchmarkOptionsBase { */ void accumulateTimeToCurrentOption(double time, bool incrementCount = true) override { - if(timingsComplete()) { + if(timingsComplete()) throw std::invalid_argument( "Options::accumulateTimeToCurrentOption: " "Tried to accumulate time when benchmarking is complete" ); - } updateState(); 
times_[currentOptionIdx_] += time; - if(incrementCount) { + if(incrementCount) counts_[currentOptionIdx_]++; - } } /** diff --git a/flashlight/fl/common/Histogram.cpp b/flashlight/fl/common/Histogram.cpp index abca254..aecc058 100644 --- a/flashlight/fl/common/Histogram.cpp +++ b/flashlight/fl/common/Histogram.cpp @@ -14,32 +14,30 @@ namespace fl { void shortFormatCount(std::stringstream& ss, size_t count) { constexpr size_t stringLen = 5; - if(count >= 10e13) { // >= 10 trillion + if(count >= 10e13) // >= 10 trillion ss << std::setw(stringLen - 1) << (count / (size_t) 10e12) << 't'; - } else if(count >= 10e10) { // >= 10 billion + else if(count >= 10e10) // >= 10 billion ss << std::setw(stringLen - 1) << (count / (size_t) 10e9) << 'b'; - } else if(count >= 10e7) { // >= 10 million + else if(count >= 10e7) // >= 10 million ss << std::setw(stringLen - 1) << (count / (size_t) 10e6) << 'm'; - } else if(count >= 10e4) { // >= 10 thousand + else if(count >= 10e4) // >= 10 thousand ss << std::setw(stringLen - 1) << (count / (size_t) 10e3) << 'k'; - } else { + else ss << std::setw(stringLen) << count; - } } void shortFormatMemory(std::stringstream& ss, size_t size) { constexpr size_t stringLen = 5; - if(size >= (1ULL << 43)) { // >= 8TB + if(size >= (1ULL << 43)) // >= 8TB ss << std::setw(stringLen - 1) << (size >> 40) << "T"; - } else if(size >= (1ULL << 33)) { // >= 8G B + else if(size >= (1ULL << 33)) // >= 8G B ss << std::setw(stringLen - 1) << (size >> 30) << "G"; - } else if(size >= (1ULL << 23)) { // >= 8M B + else if(size >= (1ULL << 23)) // >= 8M B ss << std::setw(stringLen - 1) << (size >> 20) << "M"; - } else if(size >= (1ULL << 13)) { // >= 8K B + else if(size >= (1ULL << 13)) // >= 8K B ss << std::setw(stringLen - 1) << (size >> 10) << "K"; - } else { + else ss << std::setw(stringLen) << size; - } } } // namespace fl diff --git a/flashlight/fl/common/Histogram.h b/flashlight/fl/common/Histogram.h index 8cc557b..4ddbb49 100644 --- 
a/flashlight/fl/common/Histogram.h +++ b/flashlight/fl/common/Histogram.h @@ -82,14 +82,11 @@ struct HistogramStats { template bool isAdditionSafe(T a, T b) { - if(a > (std::numeric_limits::max() - b)) { + if(a > (std::numeric_limits::max() - b)) return false; - } - if(std::is_signed::value) { - if(a < 0 && b < 0 && (a < (std::numeric_limits::min() - b))) { + if(std::is_signed::value) + if(a < 0 && b < 0 && (a < (std::numeric_limits::min() - b))) return false; - } - } return true; } @@ -110,17 +107,15 @@ HistogramStats FixedBucketSizeHistogram( T clipMinValueInclusive = std::numeric_limits::min(), T clipMaxValueExclusive = std::numeric_limits::max() ) { - if(!nBuckets) { + if(!nBuckets) throw std::invalid_argument( "FixedBucketSizeHistogram(nBuckets=0) nBuckets " "must be a positive integer" ); - } HistogramStats stats; - if(begin == end) { + if(begin == end) return stats; - } stats.min = std::numeric_limits::max(); stats.max = std::numeric_limits::min(); @@ -129,15 +124,13 @@ HistogramStats FixedBucketSizeHistogram( // Calculate min/max, sum, ands mean double simpleMovingAverage = 0.0; for(auto itr = begin; itr != end; ++itr) { - if((*itr < clipMinValueInclusive) || (*itr >= clipMaxValueExclusive)) { + if((*itr < clipMinValueInclusive) || (*itr >= clipMaxValueExclusive)) continue; - } if(!stats.sumOverflow) { - if(isAdditionSafe(stats.sum, *itr)) { + if(isAdditionSafe(stats.sum, *itr)) stats.sum += *itr; - } else { + else stats.sumOverflow = true; - } } stats.min = std::min(stats.min, *itr); @@ -161,9 +154,8 @@ HistogramStats FixedBucketSizeHistogram( // Calculate count per bucket stats.maxNumValuesPerBucket = 0; for(auto itr = begin; itr != end; ++itr) { - if(*itr < clipMinValueInclusive || *itr >= clipMaxValueExclusive) { + if(*itr < clipMinValueInclusive || *itr >= clipMaxValueExclusive) continue; - } double index = std::floor(static_cast(*itr - stats.min) / bucketWidth); size_t intIndex = std::min(static_cast(index), nBuckets - 1); @@ -203,9 +195,8 @@ 
std::string HistogramBucket::prettyString( fromatCountIntoStream(ss, count); ss << ": "; const double numTicks = static_cast(count) / countPerTick; - for(int i = 0; i < std::round(numTicks); ++i) { + for(int i = 0; i < std::round(numTicks); ++i) ss << "*"; - } return ss.str(); }; @@ -222,11 +213,10 @@ std::string HistogramStats::prettyString( ss << "] max_=["; fromatValuesIntoStream(ss, max); ss << "] sum=["; - if(sumOverflow) { + if(sumOverflow) ss << "overflow"; - } else { + else fromatCountIntoStream(ss, sum); - } ss << "] mean=["; fromatValuesIntoStream(ss, mean); ss << "] numValues=["; diff --git a/flashlight/fl/common/Logging.cpp b/flashlight/fl/common/Logging.cpp index 7b004cd..113930d 100644 --- a/flashlight/fl/common/Logging.cpp +++ b/flashlight/fl/common/Logging.cpp @@ -41,9 +41,8 @@ namespace { std::string getFileName(const std::string& path) { const size_t separatorIndex = path.rfind(kSeparator, path.length()); - if(separatorIndex == std::string::npos) { + if(separatorIndex == std::string::npos) return path; - } return path.substr(separatorIndex + 1, path.length() - separatorIndex); } @@ -62,9 +61,8 @@ namespace { constexpr size_t bufferSize = 50; char buffer[bufferSize]; const size_t nWrittenBytes = std::strftime(buffer, 30, "%m%d %T", timeinfo); - if(!nWrittenBytes) { + if(!nWrittenBytes) return "getTime() failed to format time"; - } const std::chrono::system_clock::time_point timeInSecondsResolution = std::chrono::system_clock::from_time_t(secondsSinceEpoc); @@ -95,9 +93,8 @@ namespace { ss << std::this_thread::get_id(); std::string threadId = ss.str(); - if(threadId.size() > maxThreadIdNumDigits) { + if(threadId.size() > maxThreadIdNumDigits) threadId = threadId.substr(threadId.size() - maxThreadIdNumDigits); - } (*outputStream) << dateTimeWithMicroSeconds() << ' ' @@ -140,9 +137,8 @@ Logging::~Logging() { stringStream_ << std::endl; (*outputStreamPtr_) << stringStream_.str(); outputStreamPtr_->flush(); - if(level_ == LogLevel::FATAL) { + if(level_ 
== LogLevel::FATAL) exit(-1); - } } } @@ -293,11 +289,9 @@ constexpr std::array flLogLevelNames = {"INFO", "WARNING", "ERROR", "FATAL", "DISABLED"}; std::string logLevelName(LogLevel level) { - for(int i = 0; i < flLogLevelValues.size(); ++i) { - if(level == flLogLevelValues.at(i)) { + for(int i = 0; i < flLogLevelValues.size(); ++i) + if(level == flLogLevelValues.at(i)) return flLogLevelNames.at(i); - } - } std::stringstream ss; ss << "logLevelName(level=" << static_cast(level) << ") invalid level. Level should be in the range [0.." @@ -306,11 +300,9 @@ std::string logLevelName(LogLevel level) { } LogLevel logLevelValue(const std::string& level) { - for(int i = 0; i < flLogLevelValues.size(); ++i) { - if(level == std::string(flLogLevelNames.at(i))) { + for(int i = 0; i < flLogLevelValues.size(); ++i) + if(level == std::string(flLogLevelNames.at(i))) return flLogLevelValues.at(i); - } - } std::stringstream ss; ss << "logLevelValue(level=" << level << ") invalid level. Level should be INFO, WARNING, ERROR or FATAL"; diff --git a/flashlight/fl/common/Logging.h b/flashlight/fl/common/Logging.h index d41f7fd..70045b2 100644 --- a/flashlight/fl/common/Logging.h +++ b/flashlight/fl/common/Logging.h @@ -150,9 +150,8 @@ class FL_API Logging { // Prints t to stdout along with context and sensible font color. template Logging && print(T & t) { - if(level_ <= Logging::maxLoggingLevel_) { + if(level_ <= Logging::maxLoggingLevel_) stringStream_ << t; - } return std::move(*this); } @@ -178,9 +177,8 @@ class FL_API VerboseLogging { // Prints t to stdout along with logging level and context. 
template VerboseLogging && print(T & t) { - if(level_ <= VerboseLogging::maxLoggingLevel_) { + if(level_ <= VerboseLogging::maxLoggingLevel_) stringStream_ << t; - } return std::move(*this); } diff --git a/flashlight/fl/common/Serialization-inl.h b/flashlight/fl/common/Serialization-inl.h index 47a1002..6dff6c9 100644 --- a/flashlight/fl/common/Serialization-inl.h +++ b/flashlight/fl/common/Serialization-inl.h @@ -61,9 +61,8 @@ namespace detail { // 1 argument, version-restricted. template void applyArchive(Archive& ar, const uint32_t version, Versioned varg) { - if(version >= varg.minVersion && version <= varg.maxVersion) { + if(version >= varg.minVersion && version <= varg.maxVersion) applyArchive(ar, version, std::forward(varg.ref)); - } } // 1 argument, with conversion, saving. @@ -73,11 +72,10 @@ namespace detail { typename T, std::enable_if_t::value, int> = 0> void applyArchive(Archive& ar, const uint32_t version, SerializeAs arg) { - if(arg.saveConverter) { + if(arg.saveConverter) applyArchive(ar, version, arg.saveConverter(arg.ref)); - } else { + else applyArchive(ar, version, static_cast(arg.ref)); - } } // 1 argument, with conversion, loading. @@ -90,11 +88,10 @@ namespace detail { using T0 = std::remove_reference_t; S s; applyArchive(ar, version, s); - if(arg.loadConverter) { + if(arg.loadConverter) arg.ref = arg.loadConverter(std::move(s)); - } else { + else arg.ref = static_cast(std::move(s)); - } } // 2+ arguments (recurse). @@ -206,11 +203,10 @@ void save( ) { const auto& tensor = tensor_.val; // TODO{fl::Tensor}{sparse} figure out what to do here... - if(tensor.isSparse()) { + if(tensor.isSparse()) throw cereal::Exception( "Serialzation of sparse Tensor is not supported yet!" 
); - } std::vector vec(tensor.bytes()); tensor.host(vec.data()); ar(tensor.shape(), tensor.type(), vec); diff --git a/flashlight/fl/common/Utils.cpp b/flashlight/fl/common/Utils.cpp index 453dc7c..165e48c 100644 --- a/flashlight/fl/common/Utils.cpp +++ b/flashlight/fl/common/Utils.cpp @@ -27,22 +27,19 @@ bool f16Supported() { } size_t divRoundUp(size_t numerator, size_t denominator) { - if(!numerator) { + if(!numerator) return 0; - } - if(!denominator) { + if(!denominator) throw std::invalid_argument( std::string("divRoundUp() zero denominator error") ); - } return (numerator + denominator - 1) / denominator; } namespace { std::string prettyStringMemorySizeUnits(size_t size) { - if(size == SIZE_MAX) { + if(size == SIZE_MAX) return "SIZE_MAX"; - } std::stringstream ss; bool isFirst = true; @@ -63,9 +60,8 @@ namespace { unit = "KB"; } if(size > 0) { - if(!isFirst) { + if(!isFirst) ss << '+'; - } isFirst = false; size_t nUnits = size >> shift; ss << nUnits << unit; @@ -77,9 +73,8 @@ namespace { } std::string prettyStringCountUnits(size_t count) { - if(count == SIZE_MAX) { + if(count == SIZE_MAX) return "SIZE_MAX"; - } std::stringstream ss; bool isFirst = true; @@ -100,9 +95,8 @@ namespace { unit = "k"; } if(count > 0) { - if(!isFirst) { + if(!isFirst) ss << '+'; - } isFirst = false; size_t nUnits = count / magnitude; ss << nUnits << unit; @@ -115,28 +109,24 @@ namespace { } // namespace std::string prettyStringMemorySize(size_t size) { - if(size == SIZE_MAX) { + if(size == SIZE_MAX) return "SIZE_MAX"; - } std::stringstream ss; ss << size; - if(size >= (1UL << 13)) { + if(size >= (1UL << 13)) ss << '(' << prettyStringMemorySizeUnits(size) << ')'; - } return ss.str(); } std::string prettyStringCount(size_t count) { - if(count == SIZE_MAX) { + if(count == SIZE_MAX) return "SIZE_MAX"; - } std::stringstream ss; ss << count; - if(count >= 1e3) { // >= 10 thousand + if(count >= 1e3) // >= 10 thousand ss << '(' << prettyStringCountUnits(count) << ')'; - } return ss.str(); } 
diff --git a/flashlight/fl/common/Utils.h b/flashlight/fl/common/Utils.h index 955824f..073aa49 100644 --- a/flashlight/fl/common/Utils.h +++ b/flashlight/fl/common/Utils.h @@ -52,28 +52,25 @@ typename std::invoke_result::type retryWithBackoff( Fn&& f, Args&&... args ) { - if(!(initial.count() >= 0.0)) { + if(!(initial.count() >= 0.0)) throw std::invalid_argument("retryWithBackoff: bad initial"); - } else if(!(factor >= 0.0)) { + else if(!(factor >= 0.0)) throw std::invalid_argument("retryWithBackoff: bad factor"); - } else if(maxIters <= 0) { + else if(maxIters <= 0) throw std::invalid_argument("retryWithBackoff: bad maxIters"); - } auto sleepSecs = initial.count(); for(int64_t i = 0; i < maxIters; ++i) { try { return f(std::forward(args)...); } catch(...) { - if(i >= maxIters - 1) { + if(i >= maxIters - 1) throw; - } } - if(sleepSecs > 0.0) { + if(sleepSecs > 0.0) /* sleep override */ std::this_thread::sleep_for( std::chrono::duration(std::min(1e7, sleepSecs)) ); - } sleepSecs *= factor; } throw std::logic_error("retryWithBackoff: hit unreachable"); diff --git a/flashlight/fl/common/WinUtility.cpp b/flashlight/fl/common/WinUtility.cpp index f7b7063..1c54380 100644 --- a/flashlight/fl/common/WinUtility.cpp +++ b/flashlight/fl/common/WinUtility.cpp @@ -16,14 +16,12 @@ namespace fl { namespace detail { std::wstring utf8ToWide(const std::string& utf8) { - if(utf8.empty()) { + if(utf8.empty()) return std::wstring(); - } int wideSize = MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, nullptr, 0); - if(wideSize == 0) { + if(wideSize == 0) throw std::runtime_error("Failed to convert UTF-8 to wide string"); - } std::wstring wide(wideSize - 1, 0); MultiByteToWideChar(CP_UTF8, 0, utf8.c_str(), -1, &wide[0], wideSize); @@ -32,9 +30,8 @@ namespace detail { std::string getWindowsErrorString() { DWORD error = GetLastError(); - if(error == 0) { + if(error == 0) return "No error"; - } LPWSTR messageBuffer = nullptr; FormatMessageW( @@ -74,9 +71,8 @@ namespace detail { ); } 
LocalFree(messageBuffer); - } else { + } else result = "Unknown error"; - } return result; } diff --git a/flashlight/fl/common/threadpool/ThreadPool.h b/flashlight/fl/common/threadpool/ThreadPool.h index f2077fa..d6138ff 100644 --- a/flashlight/fl/common/threadpool/ThreadPool.h +++ b/flashlight/fl/common/threadpool/ThreadPool.h @@ -77,12 +77,11 @@ inline ThreadPool::ThreadPool( size_t threads, const std::function& initFn /* = nullptr */ ) : stop(false) { - for(size_t id = 0; id < threads; ++id) { + for(size_t id = 0; id < threads; ++id) workers.emplace_back( [this, initFn, id] { - if(initFn) { + if(initFn) initFn(id); - } for(;;) { std::function task; @@ -91,9 +90,8 @@ inline ThreadPool::ThreadPool( this->condition.wait( lock, [this] { return this->stop || !this->tasks.empty(); }); - if(this->stop && this->tasks.empty()) { + if(this->stop && this->tasks.empty()) return; - } task = std::move(this->tasks.front()); this->tasks.pop(); } @@ -102,7 +100,6 @@ inline ThreadPool::ThreadPool( } } ); - } } template @@ -119,9 +116,8 @@ auto ThreadPool::enqueue(F&& f, Args&&... 
args) std::unique_lock lock(queue_mutex); // don't allow enqueueing after stopping the pool - if(stop) { + if(stop) throw std::runtime_error("enqueue on stopped ThreadPool"); - } tasks.emplace([task]() { (*task)(); }); } @@ -135,8 +131,7 @@ inline ThreadPool::~ThreadPool() { stop = true; } condition.notify_all(); - for(std::thread& worker : workers) { + for(std::thread& worker : workers) worker.join(); - } } } // namespace fl diff --git a/flashlight/fl/contrib/modules/AdaptiveEmbedding.cpp b/flashlight/fl/contrib/modules/AdaptiveEmbedding.cpp index 3438150..70ac896 100644 --- a/flashlight/fl/contrib/modules/AdaptiveEmbedding.cpp +++ b/flashlight/fl/contrib/modules/AdaptiveEmbedding.cpp @@ -23,9 +23,8 @@ AdaptiveEmbedding::AdaptiveEmbedding( ) : embeddingDim_(embeddingDim), cutoff_(cutoff), divValue_(divValue) { - if(cutoff_.empty()) { + if(cutoff_.empty()) throw std::invalid_argument("Invalid cutoff for AdaptiveEmbedding"); - } double stdv = std::sqrt(1.0 / static_cast(embeddingDim_)); // to be in agreement with the adaptive softmax to simplify // tied version of adaptive input and softmax @@ -62,12 +61,11 @@ AdaptiveEmbedding::AdaptiveEmbedding( } Variable AdaptiveEmbedding::forward(const Variable& input) { - if(input.ndim() != 2) { + if(input.ndim() != 2) throw std::invalid_argument( "AdaptiveEmbedding::forward - input must " "have 2 dimensions - expect T x B" ); - } auto flatInput = flat(input); std::vector indices; @@ -95,11 +93,10 @@ Variable AdaptiveEmbedding::forward(const Variable& input) { embeddings.push_back(tailEmbedding); } } - if(embeddings.empty()) { + if(embeddings.empty()) throw std::invalid_argument( "Invalid input, no positions in the AdaptiveEmbedding layer" ); - } Shape outShape({embeddingDim_, input.dim(0), input.dim(1)}); auto result = fl::concatenate(embeddings, 1); @@ -115,9 +112,8 @@ std::unique_ptr AdaptiveEmbedding::clone() const { std::string AdaptiveEmbedding::prettyString() const { std::ostringstream ss; ss << "AdaptiveEmbedding (dim: 
" << embeddingDim_ << "), (cutoff: "; - for(int i = 0; i < cutoff_.size() - 1; i++) { + for(int i = 0; i < cutoff_.size() - 1; i++) ss << cutoff_[i] << ", "; - } ss << cutoff_[cutoff_.size() - 1] << "), " << "(divValue: " << divValue_ << ")"; return ss.str(); diff --git a/flashlight/fl/contrib/modules/AsymmetricConv1D.cpp b/flashlight/fl/contrib/modules/AsymmetricConv1D.cpp index ebd70a8..4e20da8 100644 --- a/flashlight/fl/contrib/modules/AsymmetricConv1D.cpp +++ b/flashlight/fl/contrib/modules/AsymmetricConv1D.cpp @@ -14,16 +14,14 @@ namespace fl { void AsymmetricConv1D::checkParams() { - if(xPad_ != static_cast(PaddingMode::SAME) && xPad_ != 0) { + if(xPad_ != static_cast(PaddingMode::SAME) && xPad_ != 0) throw std::invalid_argument( "AsymmetricConv1D: invalid xPad_, now supports only '0' or 'SAME' " ); - } - if(futurePart_ < 0 || futurePart_ > 1) { + if(futurePart_ < 0 || futurePart_ > 1) throw std::invalid_argument( "AsymmetricConv1D: invalid futurePart_, should be in [0, 1]" ); - } } AsymmetricConv1D::AsymmetricConv1D( @@ -69,13 +67,12 @@ AsymmetricConv1D::AsymmetricConv1D( Variable AsymmetricConv1D::forward(const Variable& input) { auto px = fl::derivePadding(input.dim(0), xFilter_, xStride_, xPad_, xDilation_); - if(!(px >= 0)) { + if(!(px >= 0)) throw std::invalid_argument("invalid padding for AsymmetricConv1D"); - } Variable output; int cutPx = std::abs(2 * (0.5 - futurePart_)) * px; int asymmetryPx = px + cutPx; - if(bias_) { + if(bias_) output = conv2d( input, params_[0], @@ -88,7 +85,7 @@ Variable AsymmetricConv1D::forward(const Variable& input) { yDilation_, groups_ ); - } else { + else output = conv2d( input, params_[0], @@ -100,12 +97,10 @@ Variable AsymmetricConv1D::forward(const Variable& input) { yDilation_, groups_ ); - } - if(futurePart_ < 0.5) { + if(futurePart_ < 0.5) output = output(fl::range(0, output.dim(0) - 2 * cutPx)); - } else if(futurePart_ > 0.5) { + else if(futurePart_ > 0.5) output = output(fl::range(2 * cutPx, output.dim(0))); - } 
return output; } diff --git a/flashlight/fl/contrib/modules/Conformer.cpp b/flashlight/fl/contrib/modules/Conformer.cpp index 0629d8a..017e8d1 100644 --- a/flashlight/fl/contrib/modules/Conformer.cpp +++ b/flashlight/fl/contrib/modules/Conformer.cpp @@ -101,9 +101,8 @@ Conformer::Conformer( true, modelDim )) { - if(posEmbContextSize_ > 0) { + if(posEmbContextSize_ > 0) params_.push_back(uniform(2 * posEmbContextSize_ - 1, headDim, -0.1, 0.1)); - } createLayers(); } @@ -189,9 +188,8 @@ Variable Conformer::mhsa(const Variable& input, const Variable& inputPadMask) { auto v = transpose((*wv_)(normedInput), {1, 0, 2}); Variable mask, posEmb; - if(posEmbContextSize_ > 0) { + if(posEmbContextSize_ > 0) posEmb = tile(params_[0].astype(input.type()), {1, 1, nHeads_ * bsz}); - } fl::Variable padMask; // TODO{fl::Tensor}{resize} - emulate the ArrayFire resize operation for @@ -199,12 +197,11 @@ Variable Conformer::mhsa(const Variable& input, const Variable& inputPadMask) { if(!inputPadMask.isEmpty()) { auto padMaskArr = inputPadMask.tensor(); Shape newMaskShape = {input.dim(1), input.dim(2)}; - if(padMaskArr.elements() != newMaskShape.elements()) { + if(padMaskArr.elements() != newMaskShape.elements()) throw std::runtime_error( "Transformer::selfAttention - pad mask requires resize. " "This behavior will be fixed in a future release " ); - } padMaskArr = fl::reshape(padMaskArr, newMaskShape); padMask = fl::Variable(fl::log(padMaskArr), false); } @@ -241,27 +238,24 @@ Variable Conformer::conv(const Variable& _input) { } std::vector Conformer::forward(const std::vector& input) { - if(input.size() != 2) { + if(input.size() != 2) throw std::invalid_argument( "Invalid inputs for conformer block: there should be input " "and paddding mask (can be empty Variable)" ); - } auto x = input[0]; - if(x.ndim() != 3) { + if(x.ndim() != 3) throw std::invalid_argument( "Conformer::forward - input should be of 3 dimensions " "expects an input of size C x T x B - see documentation." 
); - } float pDropout = train_ ? pDropout_ : 0.0; float f = 1.0; - if(train_ && (fl::rand({1}).scalar() < pLayerDropout_)) { + if(train_ && (fl::rand({1}).scalar() < pLayerDropout_)) f = 0.0; - } // apply first feed-forward module auto ffn1 = dropout( (*w12_)( diff --git a/flashlight/fl/contrib/modules/PositionEmbedding.cpp b/flashlight/fl/contrib/modules/PositionEmbedding.cpp index b94078a..d965670 100644 --- a/flashlight/fl/contrib/modules/PositionEmbedding.cpp +++ b/flashlight/fl/contrib/modules/PositionEmbedding.cpp @@ -42,21 +42,19 @@ PositionEmbedding& PositionEmbedding::operator=( std::vector PositionEmbedding::forward( const std::vector& input ) { - if(input[0].ndim() != 3) { + if(input[0].ndim() != 3) throw std::invalid_argument( "PositionEmbedding::forward - expect a tensor with " "3 dimensions - C x T x B" ); - } int n = input[0].dim(1); Variable posEmb = tileAs( params_[0].astype(input[0].type())(fl::span, fl::range(0, n)), input[0]); - if(dropout_ > 0.0 && train_) { + if(dropout_ > 0.0 && train_) return {input[0] + dropout(posEmb, dropout_)}; - } else { + else return {input[0] + posEmb}; - } } std::vector PositionEmbedding::operator()( diff --git a/flashlight/fl/contrib/modules/RawWavSpecAugment.cpp b/flashlight/fl/contrib/modules/RawWavSpecAugment.cpp index 49b0666..26be879 100644 --- a/flashlight/fl/contrib/modules/RawWavSpecAugment.cpp +++ b/flashlight/fl/contrib/modules/RawWavSpecAugment.cpp @@ -41,33 +41,27 @@ RawWavSpecAugment::RawWavSpecAugment( rawWavHighFreqHz_(highFreqHz), rawWavSampleRate_(sampleRate), maxKernelSize_(maxKernelSize) { - if(numFreqMask_ > 0 && freqMaskF_ <= 0) { + if(numFreqMask_ > 0 && freqMaskF_ <= 0) throw std::invalid_argument("invalid arguments for frequency masking."); - } - if(numTimeMask_ > 0 && timeMaskT_ <= 0) { + if(numTimeMask_ > 0 && timeMaskT_ <= 0) throw std::invalid_argument("invalid arguments for time masking."); - } - if(numTimeMask_ > 0 && (timeMaskP_ <= 0 || timeMaskP_ > 1.0)) { + if(numTimeMask_ > 0 && 
(timeMaskP_ <= 0 || timeMaskP_ > 1.0)) throw std::invalid_argument("invalid arguments for time masking."); - } if( rawWavLowFreqHz_ < 0 || rawWavHighFreqHz_ < 0 || rawWavLowFreqHz_ >= rawWavHighFreqHz_ - ) { + ) throw std::invalid_argument( "invalid arguments for raw Wav high and low frequencies." ); - } - if(rawWavNMels_ <= 0) { + if(rawWavNMels_ <= 0) throw std::invalid_argument("invalid arguments for raw Wav nMels."); - } precomputeFilters(); } void RawWavSpecAugment::precomputeFilters() { - if(!lowPassFilters_.empty()) { + if(!lowPassFilters_.empty()) return; - } auto mel2hz = [](float mel) { return 700.0 * (std::pow(10, (mel / 2595.0)) - 1.0); }; @@ -84,9 +78,8 @@ void RawWavSpecAugment::precomputeFilters() { for(int index = 0; index <= rawWavNMels_; index++) { cutoff_.push_back(mel2hz(currentMel) / rawWavSampleRate_); currentMel += delta; - if(index > 0) { + if(index > 0) transBandKhz[index] = cutoff_[index - 1] / 4.; - } } transBandKhz[0] = transBandKhz[1]; ignoredLowPassFilters_ = 0; @@ -123,35 +116,30 @@ void RawWavSpecAugment::precomputeFilters() { filter->eval(); lowPassFilters_.push_back(filter); } - if(ignoredLowPassFilters_ >= lowPassFilters_.size()) { + if(ignoredLowPassFilters_ >= lowPassFilters_.size()) throw std::invalid_argument( "All low pass filters are ignored, too huge kernel for all frequencies" ); - } } Variable RawWavSpecAugment::forward(const Variable& input) { - if(input.isCalcGrad()) { + if(input.isCalcGrad()) throw std::invalid_argument( "input gradient calculation is not supported for RawWavSpecAugment." 
); - } - if(lowPassFilters_.empty()) { + if(lowPassFilters_.empty()) throw std::invalid_argument("invalid RawWavSpecAugment, filters are empty"); - } fl::Variable inputCast = detail::adjustInputType(input, "RawWavSpecAugment"); auto output = Variable(inputCast.tensor(), false); - if(!train_) { + if(!train_) return output; - } - if(input.ndim() != 3) { + if(input.ndim() != 3) throw std::invalid_argument( "RawWavSpecAugment::forward - invalid input shape: " "input is expected to be T x C x B" ); - } // input is expected T x C x B (mostly C=1) const Shape& inShape = inputCast.shape(); @@ -177,13 +165,12 @@ Variable RawWavSpecAugment::forward(const Variable& input) { auto numTimeSteps = inputCast.dim(0); // number of time steps // an upper bound on the time mask int T = std::min(timeMaskT_, static_cast(numTimeSteps * timeMaskP_)); - if(T > 0) { + if(T > 0) for(int i = 0; i < numTimeMask_; ++i) { auto t = generateRandomInt(0, T); auto t0 = generateRandomInt(0, numTimeSteps - t); opArr(fl::range(t0, t0 + t + 1)) = replaceVal; } - } return output; } diff --git a/flashlight/fl/contrib/modules/Residual.cpp b/flashlight/fl/contrib/modules/Residual.cpp index 528347d..51a84ac 100644 --- a/flashlight/fl/contrib/modules/Residual.cpp +++ b/flashlight/fl/contrib/modules/Residual.cpp @@ -18,18 +18,16 @@ std::unordered_set Residual::getProjectionsIndices() const { void Residual::addScale(int beforeLayer, float scale) { int nLayers = modules_.size() - projectionsIndices_.size(); - if(beforeLayer < 1 || beforeLayer > nLayers + 1) { + if(beforeLayer < 1 || beforeLayer > nLayers + 1) throw std::invalid_argument( "Residual: invalid layer index " + std::to_string(beforeLayer) + " before which apply the scaling" ); - } - if(scales_.find(beforeLayer - 1) != scales_.end()) { + if(scales_.find(beforeLayer - 1) != scales_.end()) throw std::invalid_argument( "Residual: scaling before layer " + std::to_string(beforeLayer) + " was already added; adding only once is allowed" ); - } 
scales_[beforeLayer - 1] = scale; } @@ -39,22 +37,20 @@ void Residual::checkShortcut(int fromLayer, int toLayer) { if( fromLayer < 0 || fromLayer >= nLayers || toLayer <= 0 || toLayer > nLayers + 2 || toLayer - fromLayer <= 1 - ) { + ) throw std::invalid_argument( "Residual: invalid skip connection; check fromLayer=" + std::to_string(fromLayer) + " and toLayer=" + std::to_string(toLayer) + " parameters. They are out of range of added layers" ); - } if( shortcut_.find(toLayer - 1) != shortcut_.end() && shortcut_[toLayer - 1].find(fromLayer) != shortcut_[toLayer - 1].end() - ) { + ) throw std::invalid_argument( "Residual: skip connection for fromLayer " + std::to_string(fromLayer) + " to toLayer " + std::to_string(toLayer) + " is already added" ); - } } void Residual::processShortcut( @@ -80,9 +76,8 @@ Variable Residual::applyScale(const Variable& input, const int layerIndex) { } std::vector Residual::forward(const std::vector& inputs) { - if(inputs.size() != 1) { + if(inputs.size() != 1) throw std::invalid_argument("Residual module expects only one input"); - } return {forward(inputs[0])}; } @@ -98,17 +93,15 @@ Variable Residual::forward(const Variable& input) { while(projectionsIndices_.find(moduleIndex) != projectionsIndices_.end()) { moduleIndex++; } - if(shortcut_.find(layerIndex) != shortcut_.end()) { + if(shortcut_.find(layerIndex) != shortcut_.end()) for(const auto& shortcut : shortcut_[layerIndex]) { Variable connectionOut = outputs[shortcut.first]; - if(shortcut.second != -1) { + if(shortcut.second != -1) connectionOut = modules_[shortcut.second] ->forward({outputs[shortcut.first]}) .front(); - } output = output + connectionOut.astype(output.type()); } - } output = modules_[moduleIndex] ->forward({applyScale(output, layerIndex)}) .front(); @@ -116,17 +109,15 @@ Variable Residual::forward(const Variable& input) { layerIndex++; moduleIndex++; } - if(shortcut_.find(nLayers) != shortcut_.end()) { + if(shortcut_.find(nLayers) != shortcut_.end()) for(const auto& 
shortcut : shortcut_[nLayers]) { Variable connectionOut = outputs[shortcut.first]; - if(shortcut.second != -1) { + if(shortcut.second != -1) connectionOut = modules_[shortcut.second] ->forward({outputs[shortcut.first]}) .front(); - } output = output + connectionOut.astype(output.type()); } - } return applyScale(output, nLayers); } @@ -135,11 +126,9 @@ std::string Residual::prettyString() const { // prepare inverted residual skip connection std::unordered_map> reverseShortcut; // start -> end - for(const auto& shortcut : shortcut_) { - for(const auto& value : shortcut.second) { + for(const auto& shortcut : shortcut_) + for(const auto& value : shortcut.second) reverseShortcut[value.first].insert({shortcut.first, value.second}); - } - } int nLayers = modules_.size() - projectionsIndices_.size(); int moduleIndex = -1, layerIndex = 0; @@ -147,9 +136,9 @@ std::string Residual::prettyString() const { while(layerIndex <= nLayers) { ss << "\n\tRes(" << layerIndex << "): "; - if(layerIndex == 0) { + if(layerIndex == 0) ss << "Input"; - } else { + else { while( projectionsIndices_.find(moduleIndex) != projectionsIndices_.end() @@ -160,9 +149,8 @@ std::string Residual::prettyString() const { } scaleIt = scales_.find(layerIndex); - if(scaleIt != scales_.end()) { + if(scaleIt != scales_.end()) ss << " with scale (before layer is applied) " << scaleIt->second << ";"; - } if( reverseShortcut.find(layerIndex) != reverseShortcut.end() @@ -170,15 +158,13 @@ std::string Residual::prettyString() const { ) { ss << "; skip connection to "; for(auto shortcut : reverseShortcut[layerIndex]) { - if(shortcut.first < nLayers) { + if(shortcut.first < nLayers) ss << "layer Res(" << shortcut.first + 1 << ")"; - } else { + else ss << "output"; - } - if(shortcut.second != -1) { + if(shortcut.second != -1) ss << " with transformation: " << modules_[shortcut.second]->prettyString() << ";"; - } ss << " "; } } @@ -187,9 +173,8 @@ std::string Residual::prettyString() const { } ss << "\n\tRes(" << 
nLayers + 1 << "): Output;"; scaleIt = scales_.find(nLayers + 1); - if(scaleIt != scales_.end()) { + if(scaleIt != scales_.end()) ss << " with scale (before layer is applied) " << scaleIt->second << ";"; - } return ss.str(); } diff --git a/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp b/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp index f35bf5e..a7e3417 100644 --- a/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp +++ b/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp @@ -60,13 +60,12 @@ SinusoidalPositionEmbedding& SinusoidalPositionEmbedding::operator=( std::vector SinusoidalPositionEmbedding::forward( const std::vector& input ) { - if(input[0].dim(0) != layerDim_) { + if(input[0].dim(0) != layerDim_) throw std::invalid_argument( "Input dimenstion " + std::to_string(input[0].dim(0)) + " and Embedding dimension " + std::to_string(layerDim_) + " are different." ); - } // Retrieve the number of tokens (positions) and the numeric type (floating // point precision). 
const int nPositions = input[0].dim(1); diff --git a/flashlight/fl/contrib/modules/SpecAugment.cpp b/flashlight/fl/contrib/modules/SpecAugment.cpp index 1643106..fb9e0fb 100644 --- a/flashlight/fl/contrib/modules/SpecAugment.cpp +++ b/flashlight/fl/contrib/modules/SpecAugment.cpp @@ -29,28 +29,23 @@ SpecAugment::SpecAugment( timeMaskP_(tMaskP), numTimeMask_(nTMask), maskStrategy_(mStrategy) { - if(numFreqMask_ > 0 && freqMaskF_ <= 0) { + if(numFreqMask_ > 0 && freqMaskF_ <= 0) throw std::invalid_argument("invalid arguments for frequency masking."); - } - if(numTimeMask_ > 0 && timeMaskT_ <= 0) { + if(numTimeMask_ > 0 && timeMaskT_ <= 0) throw std::invalid_argument("invalid arguments for time masking."); - } - if(numTimeMask_ > 0 && (timeMaskP_ <= 0 || timeMaskP_ > 1.0)) { + if(numTimeMask_ > 0 && (timeMaskP_ <= 0 || timeMaskP_ > 1.0)) throw std::invalid_argument("invalid arguments for time masking."); - } } Variable SpecAugment::forward(const Variable& input) { - if(input.isCalcGrad()) { + if(input.isCalcGrad()) throw std::invalid_argument( "input gradient calculation is not supported for SpecAugment." 
); - } auto output = Variable(input.tensor(), false); - if(!train_) { + if(!train_) return output; - } auto& opArr = output.tensor(); @@ -59,9 +54,8 @@ Variable SpecAugment::forward(const Variable& input) { : 0.0; auto numFreqChans = input.dim(1); // number of frequency channels - if(numFreqChans < freqMaskF_) { + if(numFreqChans < freqMaskF_) throw std::runtime_error("Invalid input frequency channels"); - } for(int i = 0; i < numFreqMask_; ++i) { auto f = generateRandomInt(0, freqMaskF_); auto f0 = generateRandomInt(0, numFreqChans - f); @@ -71,13 +65,12 @@ Variable SpecAugment::forward(const Variable& input) { auto numTimeSteps = input.dim(0); // number of time steps // an upper bound on the time mask int T = std::min(timeMaskT_, static_cast(numTimeSteps * timeMaskP_)); - if(T > 0) { + if(T > 0) for(int i = 0; i < numTimeMask_; ++i) { auto t = generateRandomInt(0, T); auto t0 = generateRandomInt(0, numTimeSteps - t); opArr(fl::range(t0, t0 + t + 1)) = replaceVal; } - } return output; } diff --git a/flashlight/fl/contrib/modules/TDSBlock.cpp b/flashlight/fl/contrib/modules/TDSBlock.cpp index 277b87d..791d0de 100644 --- a/flashlight/fl/contrib/modules/TDSBlock.cpp +++ b/flashlight/fl/contrib/modules/TDSBlock.cpp @@ -22,11 +22,10 @@ TDSBlock::TDSBlock( auto convPadding = static_cast(fl::PaddingMode::SAME); if(rightPadding != -1) { int totalPadding = kernelSize - 1; - if(rightPadding > totalPadding) { + if(rightPadding > totalPadding) throw std::invalid_argument( "right padding exceeds the 'SAME' padding required for TDSBlock" ); - } conv.add( Padding( {std::pair{totalPadding - rightPadding, rightPadding}}, @@ -40,37 +39,32 @@ TDSBlock::TDSBlock( conv.add(Dropout(dropout)); int linearDim = channels * width; - if(innerLinearDim == 0) { + if(innerLinearDim == 0) innerLinearDim = linearDim; - } Sequential fc; fc.add(Reorder({2, 1, 0, 3})); fc.add(View({linearDim, -1, 1, 0})); fc.add(Linear(linearDim, innerLinearDim)); fc.add(ReLU()); - if(dropout > 0) { + if(dropout > 
0) fc.add(Dropout(dropout)); - } fc.add(Linear(innerLinearDim, linearDim)); fc.add(View({channels, width, -1, 0})); fc.add(Reorder({2, 1, 0, 3})); - if(dropout > 0) { + if(dropout > 0) fc.add(Dropout(dropout)); - } add(std::move(conv)); - if(lNormIncludeTime) { + if(lNormIncludeTime) add(LayerNorm(std::vector{0, 1, 2})); - } else { + else add(LayerNorm(std::vector{1, 2})); - } add(std::move(fc)); - if(lNormIncludeTime) { + if(lNormIncludeTime) add(LayerNorm(std::vector{0, 1, 2})); - } else { + else add(LayerNorm(std::vector{1, 2})); - } } std::vector TDSBlock::forward(const std::vector& inputs) { diff --git a/flashlight/fl/contrib/modules/Transformer.cpp b/flashlight/fl/contrib/modules/Transformer.cpp index 43db939..0df4278 100644 --- a/flashlight/fl/contrib/modules/Transformer.cpp +++ b/flashlight/fl/contrib/modules/Transformer.cpp @@ -47,11 +47,10 @@ Transformer::Transformer( wf_(std::make_shared(transformerInitLinear(headDim * nHeads, modelDim))), norm1_(std::make_shared(std::vector({0, 3}))), norm2_(std::make_shared(std::vector({0, 3}))) { - if(bptt > 0) { + if(bptt > 0) params_.push_back( uniform(2 * bptt - 1, headDim, -0.1, 0.1, fl::dtype::f32, true) ); - } createLayers(); } @@ -128,14 +127,12 @@ Variable Transformer::selfAttention(const std::vector& input) { auto v = transpose((*wv_)(concatenate(inputWithState, 1)), {1, 0, 2}); Variable mask, posEmb; - if(bptt_ > 0) { + if(bptt_ > 0) posEmb = tile(params_[0].astype(encoderInput.type()), {1, 1, nHeads_ * bsz}); - } - if(useMask_ && encoderInput.dim(1) > 1) { + if(useMask_ && encoderInput.dim(1) > 1) // mask future if we use the previous state (then n is previous time) mask = getMask(n, input.size() == 3); - } int offset = (input.size() == 2) ? 
0 : n; @@ -146,12 +143,11 @@ Variable Transformer::selfAttention(const std::vector& input) { Shape newMaskShape = {encoderInput.dim(1), encoderInput.dim(2)}; // TODO{fl::Tensor}{resize} - emulate the ArrayFire resize operation for // transformer pad mask - if(padMaskArr.elements() != newMaskShape.elements()) { + if(padMaskArr.elements() != newMaskShape.elements()) throw std::runtime_error( "Transformer::selfAttention - pad mask requires resize. " "This behavior will be fixed in a future release " ); - } padMaskArr = fl::reshape(padMaskArr, newMaskShape); padMask = fl::Variable(fl::log(padMaskArr), false); } @@ -176,38 +172,34 @@ std::vector Transformer::forward(const std::vector& input) { // padMask should be empty if previous step is provided // padMask is expected to have "1" on the used positions and "0" on padded // positions - if(input.size() != 2) { + if(input.size() != 2) throw std::invalid_argument( "Invalid inputs for transformer block: there should be at least input and mask" ); - } const auto& x = input.at(input.size() - 2); - if(x.ndim() != 3) { + if(x.ndim() != 3) throw std::invalid_argument( "Transformer::forward - input should be of 3 dimensions " "expects an input of size C x T x B - see documentation." 
); - } if(!input.back().isEmpty()) { - if(input.back().ndim() < 2) { + if(input.back().ndim() < 2) throw std::invalid_argument( "Transformer::forward - invalid size for pad mask - " "must have at least two dimensions" ); - } else if(x.dim(2) != input.back().dim(1)) { + else if(x.dim(2) != input.back().dim(1)) throw std::invalid_argument( "Transformer::forward - invalid inputs for transformer:" " input and mask batch sizes are different" ); - } } float f = 1.0; - if(train_ && (fl::rand({1}).scalar() < pLayerdrop_)) { + if(train_ && (fl::rand({1}).scalar() < pLayerdrop_)) f = 0.0; - } if(preLN_) { auto h = (f * (*norm1_)(selfAttention(input))).astype(x.type()) + x; return {f* (*norm2_)(mlp(h)).astype(h.type()) + h}; diff --git a/flashlight/fl/dataset/BatchDataset.cpp b/flashlight/fl/dataset/BatchDataset.cpp index a271bd9..9234d3f 100644 --- a/flashlight/fl/dataset/BatchDataset.cpp +++ b/flashlight/fl/dataset/BatchDataset.cpp @@ -22,12 +22,10 @@ BatchDataset::BatchDataset( batchSize_(batchsize), batchPolicy_(policy), batchFns_(batchfns) { - if(!dataset_) { + if(!dataset_) throw std::invalid_argument("dataset to be batched is null"); - } - if(batchSize_ <= 0) { + if(batchSize_ <= 0) throw std::invalid_argument("invalid batch size"); - } preBatchSize_ = dataset_->size(); switch(batchPolicy_) { case BatchDatasetPolicy::INCLUDE_LAST: @@ -37,11 +35,10 @@ BatchDataset::BatchDataset( size_ = std::floor(static_cast(preBatchSize_) / batchSize_); break; case BatchDatasetPolicy::DIVISIBLE_ONLY: - if(size_ % batchSize_ != 0) { + if(size_ % batchSize_ != 0) throw std::invalid_argument( "dataset is not evenly divisible into batches" ); - } size_ = std::ceil(static_cast(preBatchSize_) / batchSize_); break; default: @@ -56,12 +53,10 @@ BatchDataset::BatchDataset( ) : dataset_(dataset), cumSumBatchSize_(batchSizes), batchFns_(batchfns) { - if(!dataset_) { + if(!dataset_) throw std::invalid_argument("dataset to be batched is null"); - } - if(cumSumBatchSize_.empty()) { + 
if(cumSumBatchSize_.empty()) throw std::invalid_argument("batch size vector should not be empty"); - } std::partial_sum( cumSumBatchSize_.begin(), cumSumBatchSize_.end(), diff --git a/flashlight/fl/dataset/BlobDataset.cpp b/flashlight/fl/dataset/BlobDataset.cpp index c84be6b..021ffbd 100644 --- a/flashlight/fl/dataset/BlobDataset.cpp +++ b/flashlight/fl/dataset/BlobDataset.cpp @@ -36,9 +36,8 @@ BlobDatasetEntry BlobDatasetEntryBuffer::get(const int64_t idx) const { e.type = static_cast(data_[dataIdx++]); unsigned numDims = data_[dataIdx++]; e.dims = Shape(std::vector(numDims)); - for(int i = 0; i < numDims; i++) { + for(int i = 0; i < numDims; i++) e.dims[i] = data_[dataIdx + i]; - } e.offset = data_[dataIdx + maxNDims_]; return e; } @@ -47,9 +46,8 @@ void BlobDatasetEntryBuffer::add(const BlobDatasetEntry& e) { data_.push_back(static_cast(e.type)); data_.push_back(static_cast(e.dims.ndim())); int i = 0; - for(; i < e.dims.ndim(); i++) { + for(; i < e.dims.ndim(); i++) data_.push_back(e.dims[i]); - } for(; i < maxNDims_; ++i) { data_.push_back(1); // placeholder dim } @@ -96,12 +94,11 @@ void BlobDataset::add(const std::vector& sample) { offsets_.push_back(entries_.size()); sizes_.push_back(sample.size()); for(const auto& tensor : sample) { - if(tensor.ndim() > maxNDims_) { + if(tensor.ndim() > maxNDims_) throw std::invalid_argument( "BlobDataset::add - no support for serialization of " "tensors with > 4 dimensions" ); - } BlobDatasetEntry e; e.type = tensor.type(); e.dims = tensor.shape(); @@ -119,14 +116,12 @@ void BlobDataset::add(const std::vector& sample) { void BlobDataset::add(const BlobDataset& blob, int64_t chunkSize) { std::lock_guard lock(mutex_); - if(chunkSize <= 0) { + if(chunkSize <= 0) throw std::runtime_error("chunkSize must be positive"); - } sizes_.insert(sizes_.end(), blob.sizes_.begin(), blob.sizes_.end()); std::vector offsets = blob.offsets_; - for(auto& offset : offsets) { + for(auto& offset : offsets) offset += entries_.size(); - } 
offsets_.insert(offsets_.end(), offsets.begin(), offsets.end()); for(int64_t i = 0; i < blob.entries_.size(); i++) { auto e = blob.entries_.get(i); @@ -145,12 +140,10 @@ void BlobDataset::add(const BlobDataset& blob, int64_t chunkSize) { this->writeData(indexOffset_, buffer.data(), size); this->indexOffset_ += size; }; - for(int64_t i = 0; i < nChunk; i++) { + for(int64_t i = 0; i < nChunk; i++) copyChunk(chunkSize); - } - if(remainCopySize > 0) { + if(remainCopySize > 0) copyChunk(remainCopySize); - } } std::vector BlobDataset::readRawArray( @@ -172,19 +165,17 @@ Tensor BlobDataset::readArray(const BlobDatasetEntry& e, int i) const { if(e.dims.elements() > 0) { auto buffer = readRawArray(e); auto keyval = hostTransforms_.find(i); - if(keyval == hostTransforms_.end()) { + if(keyval == hostTransforms_.end()) return Tensor::fromBuffer( e.dims, e.type, buffer.data(), MemoryLocation::Host ); - } else { + else return keyval->second(buffer.data(), e.dims, e.type); - } - } else { + } else return Tensor(); - } } void BlobDataset::writeArray(const BlobDatasetEntry& e, const Tensor& array) { @@ -224,9 +215,8 @@ void BlobDataset::readIndex() { int64_t magicNumberCheck = 0; int64_t offset = readData(0, (char*) &magicNumberCheck, sizeof(int64_t)); - if(magicNumber != magicNumberCheck) { + if(magicNumber != magicNumberCheck) throw std::runtime_error("BlobDataset::readIndex - not a fl::BlobDataset"); - } readData(offset, (char*) &indexOffset_, sizeof(int64_t)); offset = indexOffset_; @@ -256,9 +246,8 @@ void BlobDataset::setHostTransform( std::vector BlobDataset::getEntries(const int64_t idx) const { std::vector entries; - for(int64_t i = 0; i < sizes_.at(idx); i++) { + for(int64_t i = 0; i < sizes_.at(idx); i++) entries.push_back(entries_.get(offsets_.at(idx) + i)); - } return entries; } diff --git a/flashlight/fl/dataset/ConcatDataset.cpp b/flashlight/fl/dataset/ConcatDataset.cpp index e939808..b014a98 100644 --- a/flashlight/fl/dataset/ConcatDataset.cpp +++ 
b/flashlight/fl/dataset/ConcatDataset.cpp @@ -15,9 +15,8 @@ ConcatDataset::ConcatDataset( const std::vector>& datasets ) : datasets_(datasets), size_(0) { - if(datasets.empty()) { + if(datasets.empty()) throw std::invalid_argument("cannot concat 0 datasets"); - } cumulativedatasetsizes_.emplace_back(0); for(const auto& dataset : datasets_) { size_ += dataset->size(); diff --git a/flashlight/fl/dataset/Dataset.h b/flashlight/fl/dataset/Dataset.h index c18bbfb..ec710ad 100644 --- a/flashlight/fl/dataset/Dataset.h +++ b/flashlight/fl/dataset/Dataset.h @@ -82,9 +82,8 @@ class FL_API Dataset { protected: void checkIndexBounds(int64_t idx) const { - if(!(idx >= 0 && idx < size())) { + if(!(idx >= 0 && idx < size())) throw std::out_of_range("Dataset idx out of range"); - } } }; diff --git a/flashlight/fl/dataset/DatasetIterator.h b/flashlight/fl/dataset/DatasetIterator.h index 813ecf1..35bd79a 100644 --- a/flashlight/fl/dataset/DatasetIterator.h +++ b/flashlight/fl/dataset/DatasetIterator.h @@ -53,17 +53,15 @@ namespace detail { // Pre- and post-incrementable. DatasetIterator& operator++() { - if(++idx_ >= dataset_->size()) { + if(++idx_ >= dataset_->size()) idx_ = -1; - } return *this; } DatasetIterator operator++(int) { DatasetIterator tmp(*this); - if(++idx_ >= dataset_->size()) { + if(++idx_ >= dataset_->size()) idx_ = -1; - } return tmp; } diff --git a/flashlight/fl/dataset/FileBlobDataset.cpp b/flashlight/fl/dataset/FileBlobDataset.cpp index f1717ed..78c9eb7 100644 --- a/flashlight/fl/dataset/FileBlobDataset.cpp +++ b/flashlight/fl/dataset/FileBlobDataset.cpp @@ -20,9 +20,8 @@ FileBlobDataset::FileBlobDataset( | std::ios_base::binary; { std::ofstream fs(name_, (truncate ? 
mode_ | std::ios_base::trunc : mode_)); - if(!fs.is_open()) { + if(!fs.is_open()) throw std::runtime_error("could not open file " + name.string()); - } } readIndex(); } @@ -51,22 +50,18 @@ std::shared_ptr FileBlobDataset::getStream() const { while(i != std::end(allFileHandles_)) { auto ptr = i->lock(); if(ptr) { - if(threadFileHandles == ptr) { + if(threadFileHandles == ptr) match = true; - } ++i; - } else { + } else i = allFileHandles_.erase(i); - } } - if(!match) { + if(!match) allFileHandles_.push_back(threadFileHandles); - } } return fs; - } else { + } else return keyval->second; - } } int64_t FileBlobDataset::writeData( @@ -103,9 +98,8 @@ FileBlobDataset::~FileBlobDataset() { std::lock_guard lock(afhmutex_); for(auto& weakFileHandles : allFileHandles_) { auto fileHandles = weakFileHandles.lock(); - if(fileHandles) { + if(fileHandles) fileHandles->erase(reinterpret_cast(this)); - } } } diff --git a/flashlight/fl/dataset/MemoryBlobDataset.cpp b/flashlight/fl/dataset/MemoryBlobDataset.cpp index d2502ac..5b76bce 100644 --- a/flashlight/fl/dataset/MemoryBlobDataset.cpp +++ b/flashlight/fl/dataset/MemoryBlobDataset.cpp @@ -21,9 +21,8 @@ int64_t MemoryBlobDataset::writeData( int64_t size ) const { std::lock_guard lock(writeMutex_); - if(offset + size > data_.size()) { + if(offset + size > data_.size()) data_.resize(offset + size); - } std::memcpy(data_.data() + offset, data, size); return size; } diff --git a/flashlight/fl/dataset/MergeDataset.cpp b/flashlight/fl/dataset/MergeDataset.cpp index 8c75c25..801b8e9 100644 --- a/flashlight/fl/dataset/MergeDataset.cpp +++ b/flashlight/fl/dataset/MergeDataset.cpp @@ -14,16 +14,15 @@ MergeDataset::MergeDataset( const std::vector>& datasets ) : datasets_(datasets) { size_ = 0; - for(const auto& dataset : datasets_) { + for(const auto& dataset : datasets_) size_ = std::max(dataset->size(), size_); - } } std::vector MergeDataset::get(const int64_t idx) const { checkIndexBounds(idx); std::vector result; - for(const auto& dataset 
: datasets_) { + for(const auto& dataset : datasets_) if(idx < dataset->size()) { auto f = dataset->get(idx); result.insert( @@ -32,7 +31,6 @@ std::vector MergeDataset::get(const int64_t idx) const { std::make_move_iterator(f.end()) ); } - } return result; } diff --git a/flashlight/fl/dataset/PrefetchDataset.cpp b/flashlight/fl/dataset/PrefetchDataset.cpp index fe65c1d..ae0de80 100644 --- a/flashlight/fl/dataset/PrefetchDataset.cpp +++ b/flashlight/fl/dataset/PrefetchDataset.cpp @@ -22,15 +22,13 @@ PrefetchDataset::PrefetchDataset( numThreads_(numThreads), prefetchSize_(prefetchSize), curIdx_(-1) { - if(!dataset_) { + if(!dataset_) throw std::invalid_argument("dataset to be prefetched is null"); - } if( !(numThreads_ > 0 && prefetchSize_ > 0) && !(numThreads_ == 0 && prefetchSize_ == 0) - ) { + ) throw std::invalid_argument("invalid numThreads or prefetchSize"); - } if(numThreads_ > 0) { auto deviceId = fl::getDevice(); threadPool_ = std::make_unique( @@ -42,9 +40,8 @@ PrefetchDataset::PrefetchDataset( std::vector PrefetchDataset::get(int64_t idx) const { checkIndexBounds(idx); - if(numThreads_ == 0) { + if(numThreads_ == 0) return dataset_->get(idx); - } // remove from cache (if necessary) while(!prefetchCache_.empty() && idx != curIdx_) { @@ -55,9 +52,8 @@ std::vector PrefetchDataset::get(int64_t idx) const { // add to cache (if necessary) while(prefetchCache_.size() < prefetchSize_) { auto fetchIdx = idx + prefetchCache_.size(); - if(fetchIdx >= size()) { + if(fetchIdx >= size()) break; - } prefetchCache_.emplace( threadPool_->enqueue( [this, fetchIdx]() { return this->dataset_->get(fetchIdx); }) diff --git a/flashlight/fl/dataset/ResampleDataset.cpp b/flashlight/fl/dataset/ResampleDataset.cpp index 1d7e8cd..df6cf2f 100644 --- a/flashlight/fl/dataset/ResampleDataset.cpp +++ b/flashlight/fl/dataset/ResampleDataset.cpp @@ -23,9 +23,8 @@ std::vector makePermutationFromFn( int64_t size, const fl::Dataset::PermutationFunction& fn ) { - if(!fn) { + if(!fn) throw 
std::invalid_argument("PermutationFunction is null"); - } auto perm = makeIdentityPermutation(size); std::transform(perm.begin(), perm.end(), perm.begin(), fn); return perm; @@ -43,9 +42,8 @@ ResampleDataset::ResampleDataset( std::shared_ptr dataset, std::vector resamplevec ) : dataset_(dataset) { - if(!dataset_) { + if(!dataset_) throw std::invalid_argument("dataset to be resampled is null"); - } resample(std::move(resamplevec)); } diff --git a/flashlight/fl/dataset/ShuffleDataset.cpp b/flashlight/fl/dataset/ShuffleDataset.cpp index 008d461..148b343 100644 --- a/flashlight/fl/dataset/ShuffleDataset.cpp +++ b/flashlight/fl/dataset/ShuffleDataset.cpp @@ -27,12 +27,11 @@ void ShuffleDataset::resample() { // en.cppreference.com/w/cpp/algorithm/random_shuffle#Possible_implementation using distr_t = std::uniform_int_distribution; distr_t D; - for(int i = n - 1; i > 0; --i) { + for(int i = n - 1; i > 0; --i) std::swap( resampleVec_[i], resampleVec_[D(rng_, distr_t::param_type(0, i))] ); - } } void ShuffleDataset::setSeed(int seed) { diff --git a/flashlight/fl/dataset/SpanDataset.cpp b/flashlight/fl/dataset/SpanDataset.cpp index 5b7233f..0f9f19d 100644 --- a/flashlight/fl/dataset/SpanDataset.cpp +++ b/flashlight/fl/dataset/SpanDataset.cpp @@ -17,11 +17,10 @@ SpanDataset::SpanDataset( ) : dataset_(dataset), offset_(offset) { size_ = (length < 0) ? 
(dataset_->size() - offset_) : length; - if(size_ + offset_ > dataset_->size()) { + if(size_ + offset_ > dataset_->size()) throw std::out_of_range( "Dataset length out of range (larger than underlying dataset)" ); - } } std::vector SpanDataset::get(const int64_t idx) const { diff --git a/flashlight/fl/dataset/TensorDataset.cpp b/flashlight/fl/dataset/TensorDataset.cpp index 58576a1..e68cd64 100644 --- a/flashlight/fl/dataset/TensorDataset.cpp +++ b/flashlight/fl/dataset/TensorDataset.cpp @@ -16,15 +16,13 @@ namespace fl { TensorDataset::TensorDataset(const std::vector& dataTensors) : dataTensors_(dataTensors), size_(0) { - if(dataTensors_.empty()) { + if(dataTensors_.empty()) throw std::invalid_argument("no tensors passed to TensorDataset"); - } for(const auto& tensor : dataTensors_) { auto ndims = tensor.ndim(); - if(ndims == 0) { + if(ndims == 0) throw std::invalid_argument("tensor for TensorDataset can't be empty"); - } auto lastdim = ndims - 1; int64_t cursz = tensor.dim(lastdim); diff --git a/flashlight/fl/dataset/TransformDataset.cpp b/flashlight/fl/dataset/TransformDataset.cpp index ae35827..c89d7b1 100644 --- a/flashlight/fl/dataset/TransformDataset.cpp +++ b/flashlight/fl/dataset/TransformDataset.cpp @@ -16,9 +16,8 @@ TransformDataset::TransformDataset( const std::vector& transformfns ) : dataset_(dataset), transformFns_(transformfns) { - if(!dataset_) { + if(!dataset_) throw std::invalid_argument("dataset to be transformed is null"); - } } std::vector TransformDataset::get(const int64_t idx) const { @@ -27,9 +26,8 @@ std::vector TransformDataset::get(const int64_t idx) const { auto result = dataset_->get(idx); for(int64_t i = 0; i < result.size(); ++i) { - if(i >= transformFns_.size() || !transformFns_[i]) { + if(i >= transformFns_.size() || !transformFns_[i]) continue; - } result[i] = transformFns_[i](result[i]); } return result; diff --git a/flashlight/fl/dataset/Utils.cpp b/flashlight/fl/dataset/Utils.cpp index 8e298b0..2e0ca56 100644 --- 
a/flashlight/fl/dataset/Utils.cpp +++ b/flashlight/fl/dataset/Utils.cpp @@ -21,20 +21,17 @@ std::vector partitionByRoundRobin( int64_t batchSz /* = 1 */, bool allowEmpty /* = false */ ) { - if(partitionId < 0 || partitionId >= numPartitions) { + if(partitionId < 0 || partitionId >= numPartitions) throw std::invalid_argument( "invalid partitionId, numPartitions for partitionByRoundRobin" ); - } int64_t nSamplesPerGlobalBatch = numPartitions * batchSz; int64_t nGlobalBatches = numSamples / nSamplesPerGlobalBatch; bool includeLast = (numSamples % nSamplesPerGlobalBatch) >= numPartitions; - if(allowEmpty && (numSamples % nSamplesPerGlobalBatch) > 0) { + if(allowEmpty && (numSamples % nSamplesPerGlobalBatch) > 0) includeLast = true; - } - if(includeLast) { + if(includeLast) ++nGlobalBatches; - } std::vector outSamples; outSamples.reserve(nGlobalBatches * batchSz); @@ -46,17 +43,15 @@ std::vector partitionByRoundRobin( (numSamples - offset) / numPartitions; // min samples per proc int64_t remaining = (numSamples - offset) % numPartitions; offset += nCurSamples * partitionId; - if(partitionId < remaining) { + if(partitionId < remaining) nCurSamples += 1; - } offset += std::min(partitionId, remaining); } else { offset += batchSz * partitionId; nCurSamples = batchSz; } - for(int64_t b = 0; b < nCurSamples; ++b) { + for(int64_t b = 0; b < nCurSamples; ++b) outSamples.emplace_back(b + offset); - } } return outSamples; } @@ -68,16 +63,15 @@ std::pair, std::vector> dynamicPartitionByRoundRob int64_t maxSizePerBatch, bool allowEmpty /* = false */ ) { - if(partitionId < 0 || partitionId >= numPartitions) { + if(partitionId < 0 || partitionId >= numPartitions) throw std::invalid_argument( "[dynamicPartitionByRoundRobin] invalid partitionId, numPartitions" ); - } std::vector batchSizes, batchOffsets; int64_t sampleIdx = 0, batchStartSampleIdx = 0; float maxSampleLen = 0; while(sampleIdx < samplesSize.size()) { - if(samplesSize[sampleIdx] > maxSizePerBatch) { + 
if(samplesSize[sampleIdx] > maxSizePerBatch) throw std::invalid_argument( "[dynamicPartitionByRoundRobin] invalid samples length: each sample " "should have size <= maxSizePerBatch, either filter data or set larger maxSizePerBatch. " @@ -85,7 +79,6 @@ std::pair, std::vector> dynamicPartitionByRoundRob + std::to_string(maxSizePerBatch) + " sample size is " + std::to_string(samplesSize[sampleIdx]) ); - } float maxSampleLenOld = maxSampleLen; maxSampleLen = std::max(maxSampleLen, samplesSize[sampleIdx]); if( @@ -95,18 +88,16 @@ std::pair, std::vector> dynamicPartitionByRoundRob if( maxSampleLenOld * (sampleIdx - batchStartSampleIdx) > maxSizePerBatch - ) { + ) throw std::invalid_argument( "dynamicPartitionByRoundRobin is doing wrong packing" ); - } batchSizes.push_back(sampleIdx - batchStartSampleIdx); batchOffsets.push_back(batchStartSampleIdx); batchStartSampleIdx = sampleIdx; maxSampleLen = samplesSize[sampleIdx]; - } else { + } else sampleIdx++; - } } // process last batch with sampleIdx == numSamples, batchStartSampleIdx < // numSamples @@ -116,17 +107,15 @@ std::pair, std::vector> dynamicPartitionByRoundRob } int64_t nGlobalBatches = batchSizes.size() / numPartitions; - if(allowEmpty && (batchSizes.size() % numPartitions) > 0) { + if(allowEmpty && (batchSizes.size() % numPartitions) > 0) ++nGlobalBatches; - } std::vector outSamples, outBatchSizes; for(size_t i = 0; i < nGlobalBatches; i++) { int index = i * numPartitions + partitionId; if(index < batchSizes.size()) { outBatchSizes.emplace_back(batchSizes[index]); - for(int64_t b = 0; b < batchSizes[index]; ++b) { + for(int64_t b = 0; b < batchSizes[index]; ++b) outSamples.emplace_back(b + batchOffsets[index]); - } } } return {outSamples, outBatchSizes}; @@ -141,18 +130,15 @@ std::vector makeBatchFromRange( std::vector> buffer; for(int64_t batchidx = start; batchidx < end; ++batchidx) { auto fds = dataset->get(batchidx); - if(buffer.size() < fds.size()) { + if(buffer.size() < fds.size()) 
buffer.resize(fds.size()); - } - for(int64_t i = 0; i < fds.size(); ++i) { + for(int64_t i = 0; i < fds.size(); ++i) buffer[i].emplace_back(fds[i]); - } } std::vector result(buffer.size()); - for(int64_t i = 0; i < buffer.size(); ++i) { + for(int64_t i = 0; i < buffer.size(); ++i) result[i] = makeBatch(buffer[i], (i < batchFns.size()) ? batchFns[i] : nullptr); - } return result; } @@ -160,28 +146,23 @@ Tensor makeBatch( const std::vector& data, const Dataset::BatchFunction& batchFn ) { - if(batchFn) { + if(batchFn) return batchFn(data); - } // Using default batching function - if(data.empty()) { + if(data.empty()) return Tensor(); - } auto& dims = data[0].shape(); - for(const auto& d : data) { - if(d.shape() != dims) { + for(const auto& d : data) + if(d.shape() != dims) throw std::invalid_argument("dimension mismatch while batching dataset"); - } - } int ndims = (data[0].elements() > 1) ? dims.ndim() : 0; // TODO: expand this to > 4 given fl::Tensor - should work out of the box // by just removing this check? 
Possibly also change to ndims >= dims.ndims() - if(ndims >= 4) { + if(ndims >= 4) throw std::invalid_argument("# of dims must be < ndim - 1 for batching"); - } // Dimensions of the batched tensor std::vector batchDims = dims.get(); if(ndims + 1 > batchDims.size()) { diff --git a/flashlight/fl/distributed/DistributedApi.cpp b/flashlight/fl/distributed/DistributedApi.cpp index a32601e..6ba3e47 100644 --- a/flashlight/fl/distributed/DistributedApi.cpp +++ b/flashlight/fl/distributed/DistributedApi.cpp @@ -21,9 +21,8 @@ FL_API DistributedBackend distributedBackend() { } FL_API void allReduce(Variable& var, double scale /* = 1.0 */, bool async /* = false */) { - if(getWorldSize() > 1) { + if(getWorldSize() > 1) allReduce(var.tensor(), async); - } var.tensor() *= scale; } @@ -35,15 +34,12 @@ FL_API void allReduceMultiple( ) { // return a vector of pointers to avoid copying std::vector arrs; - for(auto& var : vars) { + for(auto& var : vars) arrs.push_back(&var.tensor()); - } - if(getWorldSize() > 1) { + if(getWorldSize() > 1) allReduceMultiple(arrs, async, contiguous); - } - for(auto& var : vars) { + for(auto& var : vars) var.tensor() *= scale; - } } FL_API void barrier() { diff --git a/flashlight/fl/distributed/FileStore.cpp b/flashlight/fl/distributed/FileStore.cpp index 8a24ccc..82627a5 100644 --- a/flashlight/fl/distributed/FileStore.cpp +++ b/flashlight/fl/distributed/FileStore.cpp @@ -33,20 +33,18 @@ void FileStore::set(const std::string& key, const std::vector& data) { // using an API that fails if the file exists (not provided by STL). If // created successfully, rename the temp file as below. 
std::ifstream ifs(path); - if(ifs.is_open()) { + if(ifs.is_open()) throw std::runtime_error( "FileStore set: file already exists: " + path.string() ); - } } { std::ofstream ofs(tmp, std::ios::out | std::ios::trunc); - if(!ofs.is_open()) { + if(!ofs.is_open()) throw std::runtime_error( "FileStore set: file create failed: " + tmp.string() ); - } ofs.write(data.data(), data.size()); } @@ -62,17 +60,15 @@ std::vector FileStore::get(const std::string& key) { wait(key); std::ifstream ifs(path, std::ios::in); - if(!ifs) { + if(!ifs) throw std::runtime_error( "FileStore get: file open failed: " + path.string() ); - } ifs.seekg(0, std::ios::end); size_t n = ifs.tellg(); - if(n == 0) { + if(n == 0) throw std::runtime_error("FileStore get: file is empty: " + path.string()); - } result.resize(n); ifs.seekg(0); ifs.read(result.data(), n); @@ -97,9 +93,8 @@ void FileStore::wait(const std::string& key) { const auto elapsed = std::chrono::duration_cast( std::chrono::steady_clock::now() - start ); - if(elapsed > FileStore::kDefaultTimeout) { + if(elapsed > FileStore::kDefaultTimeout) throw std::runtime_error("FileStore timed out for key: " + key); - } /* sleep override */ std::this_thread::sleep_for(std::chrono::milliseconds(10)); } diff --git a/flashlight/fl/distributed/LRUCache.h b/flashlight/fl/distributed/LRUCache.h index f13ad51..5e5ec38 100644 --- a/flashlight/fl/distributed/LRUCache.h +++ b/flashlight/fl/distributed/LRUCache.h @@ -39,9 +39,8 @@ namespace detail { map_.erase(dq_.back()); dq_.pop_back(); } - } else { + } else dq_.erase(map_[k].first); - } dq_.push_front(k); map_[k] = std::make_pair(dq_.begin(), std::move(v)); @@ -49,9 +48,9 @@ namespace detail { } inline V* get(K const& k) { - if(map_.find(k) == map_.end()) { + if(map_.find(k) == map_.end()) return nullptr; - } else { + else { // Move list node to front auto& it = map_[k].first; dq_.splice(dq_.begin(), dq_, it); diff --git a/flashlight/fl/distributed/backend/cpu/DistributedBackend.cpp 
b/flashlight/fl/distributed/backend/cpu/DistributedBackend.cpp index e3ebb6d..e68cde6 100644 --- a/flashlight/fl/distributed/backend/cpu/DistributedBackend.cpp +++ b/flashlight/fl/distributed/backend/cpu/DistributedBackend.cpp @@ -80,16 +80,14 @@ void distributedInit( return; } - if(initMethod != DistributedInit::MPI) { + if(initMethod != DistributedInit::MPI) throw std::runtime_error( "unsupported distributed init method for gloo backend" ); - } // using MPI - if(glooContext_ != nullptr) { + if(glooContext_ != nullptr) return; - } // TODO: ibverbs support. auto glooDev = gloo::transport::tcp::CreateDevice(""); @@ -100,25 +98,21 @@ void distributedInit( detail::DistributedInfo::getInstance().backend_ = DistributedBackend::GLOO; detail::DistributedInfo::getInstance().isInitialized_ = true; - if(glooContext_->rank == 0) { + if(glooContext_->rank == 0) std::cout << "Initialized Gloo successfully!\n"; - } } void allReduce(fl::Tensor& tensor, bool async /* = false */) { - if(!isDistributedInit()) { + if(!isDistributedInit()) throw std::runtime_error("distributed environment not initialized"); - } - if(async) { + if(async) throw std::runtime_error( "Asynchronous allReduce not yet supported for Gloo backend" ); - } size_t tensorSize = tensor.elements() * fl::getTypeSize(tensor.type()); - if(tensorSize > cacheTensor_.elements()) { + if(tensorSize > cacheTensor_.elements()) cacheTensor_ = fl::Tensor({static_cast(tensorSize)}, fl::dtype::b8); - } DevicePtr tensorPtr(tensor); DevicePtr cacheTensorPtr(cacheTensor_); memcpy(cacheTensorPtr.get(), tensorPtr.get(), tensorSize); @@ -159,15 +153,13 @@ void allReduceMultiple( bool async /* = false */, bool contiguous /* = false */ ) { - if(contiguous) { + if(contiguous) throw std::runtime_error( "contiguous allReduceMultiple is not yet supported for Gloo backend" ); - } - for(auto& tensor : tensors) { + for(auto& tensor : tensors) allReduce(*tensor, async); - } } void syncDistributed() { @@ -177,16 +169,14 @@ void syncDistributed() { 
} int getWorldRank() { - if(!isDistributedInit()) { + if(!isDistributedInit()) return 0; - } return detail::globalContext()->rank; } int getWorldSize() { - if(!isDistributedInit()) { + if(!isDistributedInit()) return 1; - } return detail::globalContext()->size; } } // namespace fl diff --git a/flashlight/fl/distributed/backend/cuda/DistributedBackend.cpp b/flashlight/fl/distributed/backend/cuda/DistributedBackend.cpp index 257bac8..ad09e2e 100644 --- a/flashlight/fl/distributed/backend/cuda/DistributedBackend.cpp +++ b/flashlight/fl/distributed/backend/cuda/DistributedBackend.cpp @@ -116,9 +116,8 @@ namespace detail { } // namespace detail void allReduce(Tensor& arr, bool async /* = false */) { - if(!isDistributedInit()) { + if(!isDistributedInit()) throw std::runtime_error("distributed environment not initialized"); - } ncclDataType_t type = detail::getNcclTypeForArray(arr); DevicePtr tensorPtr(arr); detail::allReduceCuda( @@ -137,16 +136,14 @@ void allReduceMultiple( bool contiguous /* = false */ ) { // Fast paths - if(arrs.empty()) { + if(arrs.empty()) return; - } if(!contiguous) { // Use nccl groups to do everything in a single kernel launch NCCLCHECK(ncclGroupStart()); - for(auto& arr : arrs) { + for(auto& arr : arrs) allReduce(*arr, async); - } NCCLCHECK(ncclGroupEnd()); return; } @@ -154,14 +151,12 @@ void allReduceMultiple( // We can only do a contiguous set reduction if all arrays in the set are of // the same type, else fail ncclDataType_t ncclType = detail::getNcclTypeForArray(*arrs[0]); - for(auto& arr : arrs) { - if(detail::getNcclTypeForArray(*arr) != ncclType) { + for(auto& arr : arrs) + if(detail::getNcclTypeForArray(*arr) != ncclType) throw std::runtime_error( "Cannot perform contiguous set allReduce on a set of tensors " "of different types" ); - } - } // Size of each element in each tensor in bytes size_t typeSize = fl::getTypeSize(arrs[0]->type()); @@ -178,11 +173,10 @@ void allReduceMultiple( // coalescing cache to the same size, if we're using 
contiguous sync, it // should never be larger since we flush if adding an additional buffer would // exceed the max cache size - if(totalEls * typeSize > DistributedConstants::kCoalesceCacheSize) { + if(totalEls * typeSize > DistributedConstants::kCoalesceCacheSize) throw std::runtime_error( "Total coalesce buffer size is larger than existing buffer size" ); - } auto& ncclContext = detail::NcclContext::getInstance(); const auto& workerStream = ncclContext.getWorkerStream(); @@ -220,11 +214,10 @@ void allReduceMultiple( // Block the worker stream's copy operations on allReduce operations that are // currently enqueued in the reduction stream - if(async) { + if(async) workerStream.relativeSync(ncclContext.getReductionStream()); - } else { + else relativeSync(workerStream, constTensors); - } // Enqueue operations in the stream to copy back to each respective array from // the coalesce buffer @@ -253,25 +246,22 @@ void syncDistributed() { const auto& activeCudaDevice = manager.getActiveDevice(DeviceType::CUDA); const auto& workerStream = ncclContext.getWorkerStream(); const auto& reductionStream = ncclContext.getReductionStream(); - for(const auto& stream : activeCudaDevice.getStreams()) { + for(const auto& stream : activeCudaDevice.getStreams()) if(stream.get() != &workerStream && stream.get() != &reductionStream) { stream->relativeSync(workerStream); stream->relativeSync(reductionStream); } - } } int getWorldRank() { - if(!isDistributedInit()) { + if(!isDistributedInit()) return 0; - } return detail::NcclContext::getInstance().getWorldRank(); } int getWorldSize() { - if(!isDistributedInit()) { + if(!isDistributedInit()) return 1; - } return detail::NcclContext::getInstance().getWorldSize(); } @@ -296,39 +286,35 @@ void distributedInit( ); detail::DistributedInfo::getInstance().initMethod_ = DistributedInit::FILE_SYSTEM; - } else { + } else throw std::runtime_error( "unsupported distributed init method for NCCL backend" ); - } 
detail::DistributedInfo::getInstance().isInitialized_ = true; detail::DistributedInfo::getInstance().backend_ = DistributedBackend::NCCL; - if(getWorldRank() == 0) { + if(getWorldRank() == 0) std::cout << "Initialized NCCL " << NCCL_MAJOR << "." << NCCL_MINOR << "." << NCCL_PATCH << " successfully!\n"; - } } namespace detail { void ncclCheck(ncclResult_t r) { - if(r == ncclSuccess) { + if(r == ncclSuccess) return; - } const char* err = ncclGetErrorString(r); - if(r == ncclInvalidArgument) { + if(r == ncclInvalidArgument) throw std::invalid_argument(err); - } else if(r == ncclInvalidUsage) { + else if(r == ncclInvalidUsage) throw std::logic_error(err); - } else { + else throw std::runtime_error(err); - } } void mpiCheck(int ec) { - if(ec == MPI_SUCCESS) { + if(ec == MPI_SUCCESS) return; - } else { + else { char buf[MPI_MAX_ERROR_STRING]; int resultlen; MPI_Error_string(ec, buf, &resultlen); @@ -346,22 +332,20 @@ namespace detail { ) { const CUDAStream* syncStream; auto& ncclContext = detail::NcclContext::getInstance(); - if(async) { + if(async) syncStream = &ncclContext.getReductionStream(); - } else { + else syncStream = bufferStream; - } // Synchronize with whatever CUDA stream is performing operations needed // pre-reduction. If we're in contiguous mode, we need the reduction stream to // wait for the copy in the worker stream to complete. If we're not in // CUDA stream. 
- if(contiguous) { + if(contiguous) // block future reduction stream ops on the copy-worker stream syncStream->relativeSync(ncclContext.getWorkerStream()); - } else if(async) { + else if(async) syncStream->relativeSync(*bufferStream); - } // don't synchronize streams if not async and not contiguous - the AF CUDA // stream does everything @@ -451,11 +435,10 @@ namespace detail { maxDevicePerNode == params.end() || !isNonNegativeInteger(maxDevicePerNode->second) || std::stoi(maxDevicePerNode->second) == 0 - ) { + ) throw std::invalid_argument( "invalid MaxDevicePerNode for NCCL initWithMPI" ); - } ncclUniqueId id; @@ -463,9 +446,8 @@ namespace detail { fl::setDevice(worldRank_ % std::stoi(maxDevicePerNode->second)); // get NCCL unique ID at rank 0 and broadcast it to all others - if(worldRank_ == 0) { + if(worldRank_ == 0) ncclGetUniqueId(&id); - } MPICHECK(MPI_Bcast((void*) &id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD)); // initializing NCCL @@ -482,14 +464,12 @@ namespace detail { auto filePath = params.find(DistributedConstants::kFilePath); auto maxDevicePerNode = params.find(DistributedConstants::kMaxDevicePerNode); - if(filePath == params.end() || filePath->second.empty()) { + if(filePath == params.end() || filePath->second.empty()) throw std::invalid_argument("invalid FilePath for NCCL initWithFileSystem"); - } - if(maxDevicePerNode == params.end()) { + if(maxDevicePerNode == params.end()) throw std::invalid_argument( "invalid MaxDevicePerNode for NCCL initWithFileSystem" ); - } worldRank_ = worldRank; worldSize_ = worldSize; @@ -499,9 +479,8 @@ namespace detail { fl::setDevice(worldRank_ % std::stoi(maxDevicePerNode->second)); // get NCCL unique ID at rank 0 and broadcast it to all others - if(worldRank_ == 0) { + if(worldRank_ == 0) ncclGetUniqueId(&id); - } auto fs = FileStore(filePath->second); if(worldRank_ == 0) { @@ -518,9 +497,8 @@ namespace detail { NCCLCHECK(ncclCommInitRank(&comm_, worldSize_, id, worldRank_)); // Remove the temporary file created 
for initialization - if(worldRank_ == 0) { + if(worldRank_ == 0) fs.clear(kNcclKey); - } createCudaResources(); } @@ -537,15 +515,14 @@ namespace detail { // default, as driver shutdown will clean up this memory anyways. #ifdef CUDA_CONTIGUOUS_BUFFER_FREE_ON_SHUTDOWN // Free the coalesce buffer if it was allocated - if(coalesceBuffer_ != nullptr) { + if(coalesceBuffer_ != nullptr) FL_CUDA_CHECK(cudaFree(coalesceBuffer_)); - } + #endif - if(DistributedInfo::getInstance().initMethod_ == DistributedInit::MPI) { + if(DistributedInfo::getInstance().initMethod_ == DistributedInit::MPI) // finalizing MPI MPICHECK(MPI_Finalize()); - } } } // namespace } // namespace detail diff --git a/flashlight/fl/distributed/backend/stub/DistributedBackend.cpp b/flashlight/fl/distributed/backend/stub/DistributedBackend.cpp index a828ba4..b4be54f 100644 --- a/flashlight/fl/distributed/backend/stub/DistributedBackend.cpp +++ b/flashlight/fl/distributed/backend/stub/DistributedBackend.cpp @@ -25,17 +25,15 @@ void distributedInit( std::cerr << "warning: fl::distributedInit() called more than once\n"; return; } - if(worldSize > 1 || worldRank > 0) { + if(worldSize > 1 || worldRank > 0) throw std::runtime_error("worldSize must be 1 with distributed stub"); - } detail::DistributedInfo::getInstance().backend_ = DistributedBackend::STUB; detail::DistributedInfo::getInstance().isInitialized_ = true; } void allReduce(Tensor& arr, bool async /* = false */) { - if(!isDistributedInit()) { + if(!isDistributedInit()) throw std::runtime_error("distributed environment not initialized"); - } throw std::runtime_error("allReduce not supported for stub backend"); } diff --git a/flashlight/fl/distributed/reducers/CoalescingReducer.cpp b/flashlight/fl/distributed/reducers/CoalescingReducer.cpp index fb7f66e..d2dd686 100644 --- a/flashlight/fl/distributed/reducers/CoalescingReducer.cpp +++ b/flashlight/fl/distributed/reducers/CoalescingReducer.cpp @@ -24,22 +24,20 @@ CoalescingReducer::~CoalescingReducer() { 
void CoalescingReducer::add(Variable& var) { // if this tensor would push the cache oversize, flush - if(currCacheSize_ + var.bytes() > cacheThresholdBytes_) { + if(currCacheSize_ + var.bytes() > cacheThresholdBytes_) flush(); - } // check if the tensor is larger than the cache. If so, reduce immediately // and don't copy-coalesce - if(var.bytes() > cacheThresholdBytes_) { + if(var.bytes() > cacheThresholdBytes_) allReduce(var, scale_, async_); - } else { + else { // if async, evaluating the JIT on the value upfront is more efficient than // evaluating the JIT for each Variable in the cache after we flush it, // since it more effectively facilitates overlapping compuation between the // AF and distributed compute streams. - if(async_) { + if(async_) var.eval(); - } // otherwise, add to cache cache_.push_back(var); currCacheSize_ += var.bytes(); @@ -58,9 +56,8 @@ void CoalescingReducer::flush() { } void CoalescingReducer::synchronize() { - if(async_ || contiguous_) { + if(async_ || contiguous_) syncDistributed(); - } } } // namespace fl diff --git a/flashlight/fl/distributed/reducers/InlineReducer.cpp b/flashlight/fl/distributed/reducers/InlineReducer.cpp index 8d0f444..bcd5d3c 100644 --- a/flashlight/fl/distributed/reducers/InlineReducer.cpp +++ b/flashlight/fl/distributed/reducers/InlineReducer.cpp @@ -13,9 +13,8 @@ namespace fl { InlineReducer::InlineReducer(double scale) : scale_(scale) {} void InlineReducer::add(Variable& var) { - if(getWorldSize() > 1) { + if(getWorldSize() > 1) allReduce(var.tensor()); - } var.tensor() *= scale_; } diff --git a/flashlight/fl/examples/AdaptiveClassification.cpp b/flashlight/fl/examples/AdaptiveClassification.cpp index 8236dd0..b23aede 100644 --- a/flashlight/fl/examples/AdaptiveClassification.cpp +++ b/flashlight/fl/examples/AdaptiveClassification.cpp @@ -51,9 +51,8 @@ int main(int /* unused */, const char** /* unused */) { const Tensor& out_ = label; fl::Timer s; for(int i = 0; i < nepochs; i++) { - if(i == warmup_epochs) { + 
if(i == warmup_epochs) s = fl::Timer::start(); - } /* Forward propagation */ result = model(input(in_)); diff --git a/flashlight/fl/examples/Benchmark.cpp b/flashlight/fl/examples/Benchmark.cpp index 0b2a906..dde158e 100644 --- a/flashlight/fl/examples/Benchmark.cpp +++ b/flashlight/fl/examples/Benchmark.cpp @@ -21,17 +21,15 @@ using namespace fl; double timeit(std::function fn) { // warmup - for(int i = 0; i < 10; ++i) { + for(int i = 0; i < 10; ++i) fn(); - } fl::sync(); int num_iters = 100; fl::sync(); auto start = fl::Timer::start(); - for(int i = 0; i < num_iters; i++) { + for(int i = 0; i < num_iters; i++) fn(); - } fl::sync(); return fl::Timer::stop(start) / num_iters; } diff --git a/flashlight/fl/examples/Classification.cpp b/flashlight/fl/examples/Classification.cpp index 3009064..fcd3307 100644 --- a/flashlight/fl/examples/Classification.cpp +++ b/flashlight/fl/examples/Classification.cpp @@ -49,9 +49,8 @@ int main(int /* unused */, const char** /* unused */) { const Tensor& out_ = label; fl::Timer s; for(int i = 0; i < nepochs; i++) { - if(i == warmup_epochs) { + if(i == warmup_epochs) s = fl::Timer::start(); - } /* Forward propagation */ result = model(input(in_)); diff --git a/flashlight/fl/examples/DistributedTraining.cpp b/flashlight/fl/examples/DistributedTraining.cpp index 00ce190..62cf20d 100644 --- a/flashlight/fl/examples/DistributedTraining.cpp +++ b/flashlight/fl/examples/DistributedTraining.cpp @@ -75,9 +75,8 @@ int main() { // Start training - if(isMaster) { + if(isMaster) std::cout << "[Multi-layer Perceptron] Started..." 
<< std::endl; - } const int nEpochs = 100; for(int e = 1; e <= nEpochs; ++e) { meter.reset(); @@ -104,13 +103,11 @@ int main() { auto mseArr = Tensor::fromBuffer({1}, mse.data(), MemoryLocation::Host); fl::allReduce(mseArr); - if(isMaster) { + if(isMaster) std::cout << "Epoch: " << e << " Mean Squared Error: " << mseArr.scalar() / worldSize << std::endl; - } } - if(isMaster) { + if(isMaster) std::cout << "[Multi-layer Perceptron] Done!" << std::endl; - } return 0; } diff --git a/flashlight/fl/examples/Mnist.cpp b/flashlight/fl/examples/Mnist.cpp index 62f1075..366cacb 100644 --- a/flashlight/fl/examples/Mnist.cpp +++ b/flashlight/fl/examples/Mnist.cpp @@ -84,9 +84,8 @@ std::pair load_dataset( int main(int argc, char** argv) { fl::init(); - if(argc != 2) { + if(argc != 2) throw std::runtime_error("You must pass a data directory."); - } fl::setSeed(1); std::string data_dir = argv[1]; @@ -225,17 +224,15 @@ Tensor load_data( const std::vector& dims ) { std::ifstream file(im_file, std::ios::binary); - if(!file.is_open()) { + if(!file.is_open()) throw std::runtime_error("[mnist:load_data] Can't find MNIST file."); - } read_int(file); // unused magic size_t elems = 1; for(auto d : dims) { int read_d = read_int(file); elems *= read_d; - if(read_d != d) { + if(read_d != d) throw std::runtime_error("[mnist:load_data] Unexpected MNIST dimension."); - } } std::vector data; diff --git a/flashlight/fl/examples/RnnClassification.cpp b/flashlight/fl/examples/RnnClassification.cpp index ec32d4a..2691f70 100644 --- a/flashlight/fl/examples/RnnClassification.cpp +++ b/flashlight/fl/examples/RnnClassification.cpp @@ -33,9 +33,8 @@ using namespace fl; // return a random int between [mini, maxi] int randi(int mini, int maxi) { - if(maxi < mini) { + if(maxi < mini) std::swap(maxi, mini); - } return rand() % (maxi - mini + 1) + mini; } @@ -52,18 +51,16 @@ class ClassificationDataset : public Dataset { auto fp = folder / (lang + ".txt"); std::cout << "Opening " << fp << std::endl; 
std::ifstream file(fp); - if(!file.is_open()) { + if(!file.is_open()) throw std::runtime_error("Can't open the input dataset file"); - } unsigned id = Label2Id.size(); Label2Id[lang] = id; Id2Label[id] = lang; names v; std::string line; while(std::getline(file, line)) { - if(line.empty()) { + if(line.empty()) continue; - } v.push_back(line); } totalExamples += v.size(); @@ -105,12 +102,10 @@ class ClassificationDataset : public Dataset { "Polish", "Italian", "Irish"}; - for(auto& l : lang) { + for(auto& l : lang) read(datasetPath, l); - } - for(auto& it : Id2Label) { + for(auto& it : Id2Label) std::cout << it.first << ":" << it.second << ", "; - } std::cout << std::endl; } @@ -342,19 +337,17 @@ int main(int argc, char** argv) { auto p = trainSet.getRandomExample(); unsigned pred = model.infer(p.first, h, c); unsigned correctPred = ClassificationDataset::Label2Id[p.second]; - if(pred == correctPred) { + if(pred == correctPred) ++numMatch; - } confusion(correctPred, pred) = confusion(correctPred, pred) + 1; } confusion = confusion / fl::tile(fl::sum(confusion, {1}), {1, nCategories}); // average std::cout << "Global accuracy=" << numMatch / nConfusion << "\t "; - for(unsigned i = 0; i < nCategories; ++i) { + for(unsigned i = 0; i < nCategories; ++i) std::cout << ClassificationDataset::Id2Label[i] << ":" << std::fixed << std::setprecision(2) << confusion(i, i).scalar() << " "; - } std::cout << std::endl; } // List of names not in the training dataset @@ -368,9 +361,8 @@ int main(int argc, char** argv) { {"Voltaire", "French"}, {"Pfeiffer", "German"}, {"Tambellini", "Italian"}}; - for(auto& p : quickList) { + for(auto& p : quickList) model.unittest(p.first, p.second); - } while(true) { std::string name; diff --git a/flashlight/fl/examples/RnnLm.cpp b/flashlight/fl/examples/RnnLm.cpp index 6d4e3d3..f72e055 100644 --- a/flashlight/fl/examples/RnnLm.cpp +++ b/flashlight/fl/examples/RnnLm.cpp @@ -135,9 +135,8 @@ class RnnLm : public Container { std::vector forward(const 
std::vector& inputs) override { auto inSz = inputs.size(); - if(inSz < 1 || inSz > 3) { + if(inSz < 1 || inSz > 3) throw std::invalid_argument("Invalid inputs size"); - } return rnn->forward(inputs); } @@ -183,9 +182,8 @@ class RnnLm : public Container { int main(int argc, char** argv) { fl::init(); - if(argc != 2) { + if(argc != 2) throw std::runtime_error("You must pass a data directory."); - } std::string data_dir = argv[1]; @@ -235,9 +233,8 @@ int main(int argc, char** argv) { Variable output, h, c; - if(e >= anneal_after_epoch) { + if(e >= anneal_after_epoch) opt.setLr(opt.getLr() / 2); - } for(auto& example : trainset) { std::tie(output, h, c) = model(noGrad(example[kInputIdx]), h, c); @@ -281,15 +278,13 @@ const std::string Preprocessor::eos = ""; Preprocessor::Preprocessor(std::string dataset_path) { std::ifstream file(dataset_path); - if(!file.is_open()) { + if(!file.is_open()) throw std::runtime_error("[Preprocessor::Preprocessor] Can't find file."); - } int v = 0; std::string word; while(file >> word) { - if(word_to_int.find(word) == word_to_int.end()) { + if(word_to_int.find(word) == word_to_int.end()) word_to_int[word] = v++; - } } word_to_int[eos] = v; } @@ -302,9 +297,8 @@ LMDataset::LMDataset( ) : time_steps(time_steps) { std::vector words; std::ifstream file(dataset_path); - if(!file.is_open()) { + if(!file.is_open()) throw std::runtime_error("[LMDataset::LMDataset] Can't find file."); - } std::string line; while(std::getline(file, line)) { diff --git a/flashlight/fl/examples/Xor.cpp b/flashlight/fl/examples/Xor.cpp index 7a80461..3640bc5 100644 --- a/flashlight/fl/examples/Xor.cpp +++ b/flashlight/fl/examples/Xor.cpp @@ -26,11 +26,10 @@ int main(int argc, const char** argv) { int optim_mode = 0; std::string optimizer_arg = std::string(argv[1]); - if(optimizer_arg == "--adam") { + if(optimizer_arg == "--adam") optim_mode = 1; - } else if(optimizer_arg == "--rmsprop") { + else if(optimizer_arg == "--rmsprop") optim_mode = 2; - } const int inputSize 
= 2; const int outputSize = 1; @@ -61,13 +60,12 @@ int main(int argc, const char** argv) { std::unique_ptr optim; - if(optimizer_arg == "--rmsprop") { + if(optimizer_arg == "--rmsprop") optim = std::make_unique(model.params(), lr); - } else if(optimizer_arg == "--adam") { + else if(optimizer_arg == "--adam") optim = std::make_unique(model.params(), lr); - } else { + else optim = std::make_unique(model.params(), lr, mu); - } Variable result, l; for(int i = 0; i < 1000; i++) { diff --git a/flashlight/fl/meter/AverageValueMeter.cpp b/flashlight/fl/meter/AverageValueMeter.cpp index 42d967a..db862b0 100644 --- a/flashlight/fl/meter/AverageValueMeter.cpp +++ b/flashlight/fl/meter/AverageValueMeter.cpp @@ -26,9 +26,8 @@ void AverageValueMeter::add(const double val, const double w /* = 1.0 */) { curWeightSum_ += w; curWeightSquaredSum_ += w * w; - if(curWeightSum_ == 0) { + if(curWeightSum_ == 0) return; - } curMean_ = curMean_ + w * (val - curMean_) / curWeightSum_; curMeanSquaredSum_ = @@ -40,9 +39,8 @@ void AverageValueMeter::add(const Tensor& vals) { curWeightSum_ += w; curWeightSquaredSum_ += w; - if(curWeightSum_ == 0) { + if(curWeightSum_ == 0) return; - } curMean_ = curMean_ + (fl::sum(vals).asScalar() - w * curMean_) / curWeightSum_; diff --git a/flashlight/fl/meter/CountMeter.cpp b/flashlight/fl/meter/CountMeter.cpp index 17a97ce..3dde5f1 100644 --- a/flashlight/fl/meter/CountMeter.cpp +++ b/flashlight/fl/meter/CountMeter.cpp @@ -15,9 +15,8 @@ namespace fl { CountMeter::CountMeter(int num) : counts_(num, 0) {} void CountMeter::add(int id, int64_t val) { - if(!(id >= 0 && id < counts_.size())) { + if(!(id >= 0 && id < counts_.size())) throw std::out_of_range("invalid id to update count for"); - } counts_[id] += val; } diff --git a/flashlight/fl/meter/EditDistanceMeter.cpp b/flashlight/fl/meter/EditDistanceMeter.cpp index e93f9b4..65d2ccd 100644 --- a/flashlight/fl/meter/EditDistanceMeter.cpp +++ b/flashlight/fl/meter/EditDistanceMeter.cpp @@ -25,16 +25,14 @@ void 
EditDistanceMeter::reset() { } void EditDistanceMeter::add(const Tensor& output, const Tensor& target) { - if(target.ndim() != 1) { + if(target.ndim() != 1) throw std::invalid_argument( "target must be 1-dimensional for EditDistanceMeter" ); - } - if(output.ndim() != 1) { + if(output.ndim() != 1) throw std::invalid_argument( "output must be 1-dimensional for EditDistanceMeter" ); - } int len1 = output.dim(0); int len2 = target.dim(0); diff --git a/flashlight/fl/meter/EditDistanceMeter.h b/flashlight/fl/meter/EditDistanceMeter.h index f98620c..5a0e079 100644 --- a/flashlight/fl/meter/EditDistanceMeter.h +++ b/flashlight/fl/meter/EditDistanceMeter.h @@ -130,9 +130,8 @@ class FL_API EditDistanceMeter { size_t len2 ) const { std::vector column(len1 + 1); - for(int i = 0; i <= len1; ++i) { + for(int i = 0; i <= len1; ++i) column[i].nins = i; - } auto curin2 = in2begin; for(int x = 1; x <= len2; x++) { @@ -150,18 +149,17 @@ class FL_API EditDistanceMeter { if( std::distance(possibilities.begin(), min_it) == 0 - ) { // deletion error + ) // deletion error ++column[y].ndel; - } else if( + else if( std::distance(possibilities.begin(), min_it) == 1) { // insertion // error column[y] = column[y - 1]; ++column[y].nins; } else { column[y] = lastdiagonal; - if(*curin1 != *curin2) { // substitution error + if(*curin1 != *curin2) // substitution error ++column[y].nsub; - } } lastdiagonal = olddiagonal; diff --git a/flashlight/fl/meter/FrameErrorMeter.cpp b/flashlight/fl/meter/FrameErrorMeter.cpp index 6262e37..f55630a 100644 --- a/flashlight/fl/meter/FrameErrorMeter.cpp +++ b/flashlight/fl/meter/FrameErrorMeter.cpp @@ -22,14 +22,12 @@ void FrameErrorMeter::reset() { } void FrameErrorMeter::add(const Tensor& output, const Tensor& target) { - if(output.shape() != target.shape()) { + if(output.shape() != target.shape()) throw std::invalid_argument("dimension mismatch in FrameErrorMeter"); - } - if(target.ndim() != 1) { + if(target.ndim() != 1) throw std::invalid_argument( 
"output/target must be 1-dimensional for FrameErrorMeter" ); - } sum_ += fl::countNonzero(output != target).scalar(); n_ += target.dim(0); diff --git a/flashlight/fl/meter/MSEMeter.cpp b/flashlight/fl/meter/MSEMeter.cpp index 0998406..54bc724 100644 --- a/flashlight/fl/meter/MSEMeter.cpp +++ b/flashlight/fl/meter/MSEMeter.cpp @@ -22,9 +22,8 @@ void MSEMeter::reset() { } void MSEMeter::add(const Tensor& output, const Tensor& target) { - if(output.ndim() != target.ndim()) { + if(output.ndim() != target.ndim()) throw std::invalid_argument("dimension mismatch in MSEMeter"); - } ++curN_; curValue_ = (curValue_ * (curN_ - 1) diff --git a/flashlight/fl/meter/TimeMeter.cpp b/flashlight/fl/meter/TimeMeter.cpp index d5bde72..765c914 100644 --- a/flashlight/fl/meter/TimeMeter.cpp +++ b/flashlight/fl/meter/TimeMeter.cpp @@ -32,16 +32,14 @@ double TimeMeter::value() const { std::chrono::system_clock::now() - start_; val += duration.count(); } - if(useUnit_) { + if(useUnit_) val = (curN_ > 0) ? (val / curN_) : 0.0; - } return val; } void TimeMeter::stop() { - if(isStopped_) { + if(isStopped_) return; - } std::chrono::duration duration = std::chrono::system_clock::now() - start_; curValue_ += duration.count(); @@ -49,9 +47,8 @@ void TimeMeter::stop() { } void TimeMeter::resume() { - if(!isStopped_) { + if(!isStopped_) return; - } start_ = std::chrono::system_clock::now(); isStopped_ = false; } diff --git a/flashlight/fl/meter/TopKMeter.cpp b/flashlight/fl/meter/TopKMeter.cpp index b24dce7..a54dd3d 100644 --- a/flashlight/fl/meter/TopKMeter.cpp +++ b/flashlight/fl/meter/TopKMeter.cpp @@ -17,14 +17,12 @@ TopKMeter::TopKMeter(const int k) : k_(k), n_(0) {}; void TopKMeter::add(const Tensor& output, const Tensor& target) { - if(output.dim(1) != target.dim(0)) { + if(output.dim(1) != target.dim(0)) throw std::invalid_argument("dimension mismatch in TopKMeter"); - } - if(target.ndim() != 1) { + if(target.ndim() != 1) throw std::invalid_argument( "output/target must be 1-dimensional for 
TopKMeter" ); - } Tensor maxVals, maxIds, match; topk(maxVals, maxIds, output, k_, 0); diff --git a/flashlight/fl/nn/DistributedUtils.cpp b/flashlight/fl/nn/DistributedUtils.cpp index cea102c..a27dadc 100644 --- a/flashlight/fl/nn/DistributedUtils.cpp +++ b/flashlight/fl/nn/DistributedUtils.cpp @@ -17,31 +17,26 @@ void distributeModuleGrads( std::shared_ptr module, std::shared_ptr reducer ) { - for(auto& param : module->params()) { + for(auto& param : module->params()) param.registerGradHook([reducer](Variable& grad) { reducer->add(grad); }); - } } void allReduceParameters(std::shared_ptr module) { - if(!module) { + if(!module) throw std::invalid_argument("null module passed to allReduceParameters"); - } double scale = 1.0 / getWorldSize(); - for(auto& param : module->params()) { + for(auto& param : module->params()) allReduce(param, scale); - } } void allReduceGradients( std::shared_ptr module, double scale /*= 1.0 */ ) { - if(!module) { + if(!module) throw std::invalid_argument("null module passed to allReduceGradients"); - } - for(auto& param : module->params()) { + for(auto& param : module->params()) allReduce(param.grad(), scale); - } ; } diff --git a/flashlight/fl/nn/Init.cpp b/flashlight/fl/nn/Init.cpp index 0e8d379..0ad37ac 100644 --- a/flashlight/fl/nn/Init.cpp +++ b/flashlight/fl/nn/Init.cpp @@ -71,9 +71,8 @@ namespace detail { } Tensor erfinv(const Tensor& y) { - if(fl::any(fl::abs(y) >= 1.).scalar()) { + if(fl::any(fl::abs(y) >= 1.).scalar()) throw std::runtime_error("[erfinv] input is out of range (-1, 1)"); - } double a[4] = {0.886226899, -1.645349621, 0.914624893, -0.140543331}; double b[4] = {-2.118377725, 1.442710462, -0.329097515, 0.012229801}; double c[4] = {-1.970840454, -1.624906493, 3.429567803, 1.641345311}; @@ -101,9 +100,8 @@ namespace detail { if( fl::any(fl::isnan(x)).asScalar() || fl::any(fl::isinf(x)).asScalar() - ) { + ) throw std::runtime_error("[erfinv] invalid result"); - } return x; } @@ -137,13 +135,12 @@ Variable constant(double 
val, const Shape& dims, fl::dtype type, bool calcGrad) Variable identity(int outputSize, int inputSize, fl::dtype type, bool calcGrad) { // TODO{fl::Tensor}{fixme} add non-square identity to API - if(inputSize != outputSize) { + if(inputSize != outputSize) throw std::invalid_argument( "identity - can't create tensor with " "different in and output size - only square identity " "tensors supported" ); - } return identity(Shape({inputSize, outputSize}), type, calcGrad); } diff --git a/flashlight/fl/nn/Utils.cpp b/flashlight/fl/nn/Utils.cpp index 3bf2640..fc9e8b2 100644 --- a/flashlight/fl/nn/Utils.cpp +++ b/flashlight/fl/nn/Utils.cpp @@ -18,9 +18,8 @@ namespace fl { int64_t numTotalParams(std::shared_ptr module) { int64_t params = 0; - for(auto& p : module->params()) { + for(auto& p : module->params()) params += p.elements(); - } return params; } @@ -29,16 +28,13 @@ bool allParamsClose( const Module& b, double absTolerance /* = 1e-5 */ ) { - if(a.params().size() != b.params().size()) { + if(a.params().size() != b.params().size()) return false; - } const auto aParams = a.params(); const auto bParams = b.params(); - for(int p = 0; p < aParams.size(); ++p) { - if(!allClose(aParams[p], bParams[p], absTolerance)) { + for(int p = 0; p < aParams.size(); ++p) + if(!allClose(aParams[p], bParams[p], absTolerance)) return false; - } - } return true; } @@ -89,11 +85,10 @@ namespace detail { int derivePadding(int inSz, int filterSz, int stride, int pad, int dilation) { if(pad == static_cast(PaddingMode::SAME)) { int newPad; - if(inSz % stride == 0) { + if(inSz % stride == 0) newPad = (filterSz - 1) * dilation - stride + 1; - } else { + else newPad = (filterSz - 1) * dilation - (inSz % stride) + 1; - } newPad = (newPad + 1) / 2; // equal pad on both sides return std::max(newPad, 0); } @@ -106,16 +101,13 @@ Tensor join( double padValue /* = 0.0 */, int batchDim /* = -1 */ ) { - if(inputs.empty()) { + if(inputs.empty()) return Tensor(); - } Dim maxNumDims = 0; - for(const auto& in : 
inputs) { - if(in.ndim() > maxNumDims) { + for(const auto& in : inputs) + if(in.ndim() > maxNumDims) maxNumDims = in.ndim(); - } - } // If the batch dim > the max number of dims, make those dims singleton int outNdims = std::max(batchDim + 1, static_cast(maxNumDims)); @@ -128,37 +120,31 @@ Tensor join( isEmpty = isEmpty && in.isEmpty(); for(int d = 0; d < in.ndim(); ++d) { maxDims[d] = std::max(maxDims[d], in.dim(d)); - if(in.type() != type) { + if(in.type() != type) throw std::invalid_argument( "join: all arrays should of same type for join" ); - } } } - if(batchDim < 0) { + if(batchDim < 0) batchDim = maxDims.ndim() - 1; - } - if(batchDim < maxDims.ndim() && maxDims[batchDim] > 1) { + if(batchDim < maxDims.ndim() && maxDims[batchDim] > 1) throw std::invalid_argument( "join: no singleton dim available for batching" ); - } maxDims[batchDim] = inputs.size(); - if(isEmpty) { + if(isEmpty) return Tensor(maxDims, type); - } auto padSeq = fl::full(maxDims, padValue, type); std::vector sel( std::max(maxNumDims, static_cast(batchDim + 1)), fl::span); for(int i = 0; i < inputs.size(); ++i) { - for(int d = 0; d < maxNumDims; ++d) { + for(int d = 0; d < maxNumDims; ++d) sel[d] = fl::range(inputs[i].dim(d)); - } sel[batchDim] = fl::range(i, i + 1); - if(!inputs[i].isEmpty()) { + if(!inputs[i].isEmpty()) padSeq(sel) = inputs[i]; - } } return padSeq; } diff --git a/flashlight/fl/nn/modules/AdaptiveSoftMax.cpp b/flashlight/fl/nn/modules/AdaptiveSoftMax.cpp index b73575f..8ab04ea 100644 --- a/flashlight/fl/nn/modules/AdaptiveSoftMax.cpp +++ b/flashlight/fl/nn/modules/AdaptiveSoftMax.cpp @@ -22,9 +22,8 @@ AdaptiveSoftMax::AdaptiveSoftMax( ) : UnaryModule(), cutoff_(cutoff), divValue_(divValue) { - if(cutoff_.empty()) { + if(cutoff_.empty()) throw std::invalid_argument("invalid cutoff for AdaptiveSoftMaxLoss"); - } int outputSize = cutoff_[0] + cutoff_.size() - 1; @@ -87,9 +86,8 @@ Variable AdaptiveSoftMax::forward(const Variable& inputs) { // input -- [C_in, .. 
, N] // return -- [C_out, .. , N] auto inputSize = inputs.dim(0); - if(inputSize != params_[0].dim(1)) { + if(inputSize != params_[0].dim(1)) throw std::invalid_argument("invalid input dimension for AdaptiveSoftMax"); - } auto inputsFlattened = moddims(inputs, {inputSize, -1}); auto headOutput = logSoftmax(matmul(params_[0], inputsFlattened), 0); @@ -105,11 +103,10 @@ Variable AdaptiveSoftMax::predict(const Variable& inputs) const { // input -- [C, .. , N] // return -- [1, .. , N] auto inputSize = inputs.dim(0); - if(inputSize != params_[0].dim(1)) { + if(inputSize != params_[0].dim(1)) throw std::invalid_argument( "invalid input dimension for AdaptiveSoftMaxLoss" ); - } auto inputsFlattened = moddims(inputs, {inputSize, -1}); auto headOutput = matmul(params_[0], inputsFlattened); @@ -150,9 +147,8 @@ std::unique_ptr AdaptiveSoftMax::clone() const { std::string AdaptiveSoftMax::prettyString() const { std::ostringstream ss; ss << "Adaptive Softmax ("; - for(int i = 0; i < cutoff_.size() - 1; i++) { + for(int i = 0; i < cutoff_.size() - 1; i++) ss << cutoff_[i] << ", "; - } ss << cutoff_[cutoff_.size() - 1] << ")"; return ss.str(); } diff --git a/flashlight/fl/nn/modules/BatchNorm.cpp b/flashlight/fl/nn/modules/BatchNorm.cpp index 4039ed9..30b7f43 100644 --- a/flashlight/fl/nn/modules/BatchNorm.cpp +++ b/flashlight/fl/nn/modules/BatchNorm.cpp @@ -75,11 +75,10 @@ Variable BatchNorm::forward(const Variable& input) { if(train_ && trackStats_) { ++numBatchesTracked_; - if(momentum_ < 0) { // cumulative moving average + if(momentum_ < 0) // cumulative moving average avgFactor = 1.0 / numBatchesTracked_; - } else { // exponential moving average + else // exponential moving average avgFactor = momentum_; - } } auto paramsType = @@ -118,9 +117,8 @@ std::string BatchNorm::prettyString() const { std::ostringstream ss; ss << "BatchNorm"; ss << " ( axis : { "; - for(auto x : featAxis_) { + for(auto x : featAxis_) ss << x << " "; - } ss << "}, size : " << featSize_ << " )"; return 
ss.str(); } diff --git a/flashlight/fl/nn/modules/Container.cpp b/flashlight/fl/nn/modules/Container.cpp index 876de5e..11fb591 100644 --- a/flashlight/fl/nn/modules/Container.cpp +++ b/flashlight/fl/nn/modules/Container.cpp @@ -50,29 +50,23 @@ std::vector Container::modules() const { void Container::train() { train_ = true; - for(int i = 0; i < params_.size(); ++i) { - if(childParamIdx_.find(i) == childParamIdx_.end()) { + for(int i = 0; i < params_.size(); ++i) + if(childParamIdx_.find(i) == childParamIdx_.end()) params_[i].setCalcGrad(true); - } - } - for(auto& module : modules_) { + for(auto& module : modules_) module->train(); - } } void Container::eval() { train_ = false; - for(int i = 0; i < params_.size(); ++i) { - if(childParamIdx_.find(i) == childParamIdx_.end()) { + for(int i = 0; i < params_.size(); ++i) + if(childParamIdx_.find(i) == childParamIdx_.end()) params_[i].setCalcGrad(false); - } - } - for(auto& module : modules_) { + for(auto& module : modules_) module->eval(); - } } void Container::setParams(const Variable& var, int position) { @@ -88,13 +82,11 @@ void Container::setParams(const Variable& var, int position) { std::string Container::prettyString() const { std::ostringstream ss; ss << " [input"; - for(int i = 0; i < modules_.size(); ++i) { + for(int i = 0; i < modules_.size(); ++i) ss << " -> (" << i << ")"; - } ss << " -> output]"; - for(int i = 0; i < modules_.size(); ++i) { + for(int i = 0; i < modules_.size(); ++i) ss << "\n\t(" << i << "): " << modules_[i]->prettyString(); - } return ss.str(); } @@ -102,20 +94,17 @@ Sequential::Sequential() = default; std::vector Sequential::forward(const std::vector& input) { auto output = input; - for(auto& module : modules_) { + for(auto& module : modules_) output = module->forward(output); - } return output; } Variable Sequential::forward(const Variable& input) { std::vector output = {input}; - for(auto& module : modules_) { + for(auto& module : modules_) output = module->forward(output); - } - 
if(output.size() != 1) { + if(output.size() != 1) throw std::invalid_argument("Module output size is not 1"); - } return output.front(); } @@ -127,13 +116,11 @@ std::string Sequential::prettyString() const { std::ostringstream ss; ss << "Sequential"; ss << " [input"; - for(int i = 0; i < modules_.size(); ++i) { + for(int i = 0; i < modules_.size(); ++i) ss << " -> (" << i << ")"; - } ss << " -> output]"; - for(int i = 0; i < modules_.size(); ++i) { + for(int i = 0; i < modules_.size(); ++i) ss << "\n\t(" << i << "): " << modules_[i]->prettyString(); - } return ss.str(); } diff --git a/flashlight/fl/nn/modules/Container.h b/flashlight/fl/nn/modules/Container.h index 0d326ba..7e38576 100644 --- a/flashlight/fl/nn/modules/Container.h +++ b/flashlight/fl/nn/modules/Container.h @@ -112,9 +112,8 @@ class FL_API Container : public Module { */ template void add(std::shared_ptr module) { - if(!module) { + if(!module) throw std::invalid_argument("can't add null Module to Container"); - } for(int i = 0; i < module->numParamTensors(); i++) { childParamIdx_[params_.size()] = std::make_tuple(static_cast(modules_.size()), i); params_.push_back(module->param(i)); diff --git a/flashlight/fl/nn/modules/Conv2D.cpp b/flashlight/fl/nn/modules/Conv2D.cpp index ab5f302..02f6bab 100644 --- a/flashlight/fl/nn/modules/Conv2D.cpp +++ b/flashlight/fl/nn/modules/Conv2D.cpp @@ -94,16 +94,14 @@ Conv2D::Conv2D( yDilation_(dy), bias_(true), groups_(groups) { - if(b.dim(2) != w.dim(3)) { + if(b.dim(2) != w.dim(3)) throw std::invalid_argument( "output channel dimension mismatch between Conv2D weight and bias" ); - } - if(b.elements() != b.dim(2)) { + if(b.elements() != b.dim(2)) throw std::invalid_argument( "only 3rd dimension of Conv2D bias may be non-singleton" ); - } } Conv2D::Conv2D(const Conv2D& other) : UnaryModule(other.copyParams()), @@ -143,11 +141,10 @@ Conv2D& Conv2D::operator=(const Conv2D& other) { Variable Conv2D::forward(const Variable& input) { auto px = derivePadding(input.dim(0), 
xFilter_, xStride_, xPad_, xDilation_); auto py = derivePadding(input.dim(1), yFilter_, yStride_, yPad_, yDilation_); - if(!(px >= 0 && py >= 0)) { + if(!(px >= 0 && py >= 0)) throw std::invalid_argument("invalid padding for Conv2D"); - } - if(bias_) { + if(bias_) return conv2d( input, params_[0].astype(input.type()), @@ -161,7 +158,7 @@ Variable Conv2D::forward(const Variable& input) { groups_, benchmarks_ ); - } else { + else return conv2d( input, params_[0].astype(input.type()), @@ -174,7 +171,6 @@ Variable Conv2D::forward(const Variable& input) { groups_, benchmarks_ ); - } } void Conv2D::initialize() { @@ -190,9 +186,8 @@ void Conv2D::initialize() { auto bs = uniform(Shape({1, 1, nOut_, 1}), -bound, bound, fl::dtype::f32, true); params_ = {wt, bs}; - } else { + } else params_ = {wt}; - } benchmarks_ = std::make_shared(); } @@ -206,25 +201,22 @@ std::string Conv2D::prettyString() const { ss << "Conv2D"; ss << " (" << nIn_ << "->" << nOut_ << ", " << xFilter_ << "x" << yFilter_ << ", " << xStride_ << "," << yStride_ << ", "; - if(xPad_ == static_cast(PaddingMode::SAME)) { + if(xPad_ == static_cast(PaddingMode::SAME)) ss << "SAME"; - } else { + else ss << xPad_; - } ss << ","; - if(yPad_ == static_cast(PaddingMode::SAME)) { + if(yPad_ == static_cast(PaddingMode::SAME)) ss << "SAME"; - } else { + else ss << yPad_; - } ss << ", " << xDilation_ << ", " << yDilation_; ss << ")"; - if(bias_) { + if(bias_) ss << " (with bias)"; - } else { + else ss << " (without bias)"; - } return ss.str(); } diff --git a/flashlight/fl/nn/modules/Dropout.cpp b/flashlight/fl/nn/modules/Dropout.cpp index b7a3ae2..8e40f6f 100644 --- a/flashlight/fl/nn/modules/Dropout.cpp +++ b/flashlight/fl/nn/modules/Dropout.cpp @@ -15,11 +15,10 @@ namespace fl { Dropout::Dropout(double drop_ratio) : ratio_(drop_ratio) {} Variable Dropout::forward(const Variable& input) { - if(train_) { + if(train_) return dropout(input, ratio_); - } else { + else return input; - } } std::unique_ptr Dropout::clone() 
const { diff --git a/flashlight/fl/nn/modules/LayerNorm.cpp b/flashlight/fl/nn/modules/LayerNorm.cpp index 4b2e0c1..e4dd845 100644 --- a/flashlight/fl/nn/modules/LayerNorm.cpp +++ b/flashlight/fl/nn/modules/LayerNorm.cpp @@ -33,11 +33,9 @@ LayerNorm::LayerNorm( ) : epsilon_(eps), affine_(affine), axisSize_(axisSize) { - for(int d = 0; d < kLnExpectedNumDims; ++d) { - if(std::find(axis.begin(), axis.end(), d) == axis.end()) { + for(int d = 0; d < kLnExpectedNumDims; ++d) + if(std::find(axis.begin(), axis.end(), d) == axis.end()) axisComplement_.push_back(d); - } - } initialize(); } @@ -48,16 +46,14 @@ Variable LayerNorm::forward(const Variable& _input) { // TODO: this is pretty ugly -- eventually fix this up if it can be avoided if(input.ndim() < kLnExpectedNumDims) { std::vector s = _input.shape().get(); - for(unsigned i = s.size(); i < kLnExpectedNumDims; ++i) { + for(unsigned i = s.size(); i < kLnExpectedNumDims; ++i) s.push_back(1); - } input = moddims(_input, Shape(s)); - } else if(input.ndim() > kLnExpectedNumDims) { + } else if(input.ndim() > kLnExpectedNumDims) throw std::invalid_argument( "LayerNorm::forward - input must be " + std::to_string(kLnExpectedNumDims) + " or fewer dimensions." 
); - } Variable dummyInMean, dummyInVar; @@ -70,18 +66,16 @@ Variable LayerNorm::forward(const Variable& _input) { auto minAxis = *std::min_element(axisComplement_.begin(), axisComplement_.end()); bool axesContinuous = (axisComplement_.size() == (maxAxis - minAxis + 1)); - if(axesContinuous) { + if(axesContinuous) inNormAxes = axisComplement_; - } else { + else { int i = 0; - for(int d = 0; d < input.ndim(); ++d) { + for(int d = 0; d < input.ndim(); ++d) if( std::find(axisComplement_.begin(), axisComplement_.end(), d) == axisComplement_.end() - ) { + ) reorderDims[i++] = d; - } - } for(auto n : axisComplement_) { inNormAxes.push_back(i); reorderDims[i++] = n; @@ -104,14 +98,12 @@ Variable LayerNorm::forward(const Variable& _input) { if(!axesContinuous) { std::vector> restoreDims; - for(size_t i = 0; i < reorderDims.ndim(); ++i) { + for(size_t i = 0; i < reorderDims.ndim(); ++i) restoreDims.emplace_back(reorderDims[i], i); - } std::sort(restoreDims.begin(), restoreDims.end()); Shape restoreDimsShape(std::vector(restoreDims.size())); - for(size_t i = 0; i < restoreDims.size(); ++i) { + for(size_t i = 0; i < restoreDims.size(); ++i) restoreDimsShape[i] = restoreDims[i].second; - } output = reorder(output, restoreDimsShape); } @@ -120,14 +112,12 @@ Variable LayerNorm::forward(const Variable& _input) { Variable bias = params_[1].astype(output.type()); if(axisSize_ != kLnVariableAxisSize) { Shape affineDims = input.shape(); - for(int ax : axisComplement_) { + for(int ax : axisComplement_) affineDims[ax] = 1; - } - if(affineDims.elements() != axisSize_) { + if(affineDims.elements() != axisSize_) throw std::invalid_argument( "[LayerNorm] Input size along the norm axis doesn't with axisSize." 
); - } weight = moddims(params_[0].astype(output.type()), affineDims); bias = moddims(params_[1].astype(output.type()), affineDims); } @@ -154,14 +144,12 @@ std::string LayerNorm::prettyString() const { std::ostringstream ss; ss << "LayerNorm"; ss << " ( axis : { "; - for(int d = 0; d < axisComplement_.size(); ++d) { + for(int d = 0; d < axisComplement_.size(); ++d) if( std::find(axisComplement_.begin(), axisComplement_.end(), d) == axisComplement_.end() - ) { + ) ss << d << " "; - } - } ss << "} , size : " << axisSize_ << ")"; return ss.str(); } diff --git a/flashlight/fl/nn/modules/Linear.cpp b/flashlight/fl/nn/modules/Linear.cpp index 5f86423..d6414bb 100644 --- a/flashlight/fl/nn/modules/Linear.cpp +++ b/flashlight/fl/nn/modules/Linear.cpp @@ -27,11 +27,10 @@ Linear::Linear(const Variable& w) : UnaryModule({w}), nIn_(w.dim(1)), nOut_(w.di Linear::Linear(const Variable& w, const Variable& b) : UnaryModule({w, b}), nIn_(w.dim(1)), nOut_(w.dim(0)), bias_(true) { - if(b.dim(0) != w.dim(0)) { + if(b.dim(0) != w.dim(0)) throw std::invalid_argument( "dimension mismatch between Linear weight and bias" ); - } } Linear::Linear(const Linear& other) : UnaryModule(other.copyParams()), @@ -51,13 +50,12 @@ Linear& Linear::operator=(const Linear& other) { } Variable Linear::forward(const Variable& input) { - if(bias_) { + if(bias_) return linear( input, params_[0].astype(input.type()), params_[1].astype(input.type()) ); - } return linear(input, params_[0].astype(input.type())); } @@ -71,9 +69,8 @@ void Linear::initialize() { double bound = std::sqrt(1.0 / fanIn); auto b = uniform(Shape({nOut_}), -bound, bound, fl::dtype::f32, true); params_ = {w, b}; - } else { + } else params_ = {w}; - } } std::unique_ptr Linear::clone() const { @@ -84,11 +81,10 @@ std::string Linear::prettyString() const { std::ostringstream ss; ss << "Linear"; ss << " (" << nIn_ << "->" << nOut_ << ")"; - if(bias_) { + if(bias_) ss << " (with bias)"; - } else { + else ss << " (without bias)"; - } return 
ss.str(); } diff --git a/flashlight/fl/nn/modules/Loss.cpp b/flashlight/fl/nn/modules/Loss.cpp index 81eb725..2725f04 100644 --- a/flashlight/fl/nn/modules/Loss.cpp +++ b/flashlight/fl/nn/modules/Loss.cpp @@ -18,14 +18,13 @@ Variable MeanSquaredError::forward( const Variable& inputs, const Variable& targets ) { - if(inputs.shape() != targets.shape()) { + if(inputs.shape() != targets.shape()) throw std::invalid_argument( "MeanSquaredError::forward - inputs and targets are of different" " sizes: {inputs: " + inputs.shape().toString() + ", targets: " + targets.shape().toString() + "}" ); - } auto df = inputs - targets; auto res = mean(flat(df * df), {0}); @@ -44,14 +43,13 @@ Variable MeanAbsoluteError::forward( const Variable& inputs, const Variable& targets ) { - if(inputs.shape() != targets.shape()) { + if(inputs.shape() != targets.shape()) throw std::invalid_argument( "MeanAbsoluteError::forward - inputs and targets are of different" " sizes: {inputs: " + inputs.shape().toString() + ", targets: " + targets.shape().toString() + "}" ); - } auto df = inputs - targets; return mean(flat(fl::abs(df)), {0}); @@ -119,9 +117,8 @@ Variable AdaptiveSoftMaxLoss::cast( const Shape& outDims, const Tensor& indices ) { - if(input.elements() != indices.elements()) { + if(input.elements() != indices.elements()) throw std::invalid_argument("AdaptiveSoftMaxLoss: input, indices mismatch"); - } Tensor output = fl::full(outDims, 0, input.type()); output(indices) = input.tensor().flatten(); auto inputDims = input.shape(); @@ -142,23 +139,20 @@ Variable AdaptiveSoftMaxLoss::forward( ) { // inputs: N x T x B // targets: T x B - if(inputs.ndim() != 3) { + if(inputs.ndim() != 3) throw std::invalid_argument( "AdaptiveSoftMaxLoss::forward expects input tensor with " "3 dimensions in N x T x B ordering." ); - } - if(targets.ndim() != 2) { + if(targets.ndim() != 2) throw std::invalid_argument( "AdaptiveSoftMaxLoss::forward expects target tensor with " "2 dimensions in T x B ordering." 
); - } - if(inputs.dim(1) != targets.dim(0)) { + if(inputs.dim(1) != targets.dim(0)) throw std::invalid_argument("AdaptiveSoftMaxLoss: length mismatch"); - } else if(inputs.dim(2) != targets.dim(1)) { + else if(inputs.dim(2) != targets.dim(1)) throw std::invalid_argument("AdaptiveSoftMaxLoss: batch size mismatch"); - } auto N = inputs.dim(0); auto T = inputs.dim(1); @@ -176,9 +170,8 @@ Variable AdaptiveSoftMaxLoss::forward( // Tail forwawrd for(int i = 0; i < cutoff.size() - 1; i++) { auto mask = (target >= cutoff[i]) && (target < cutoff[i + 1]); - if(!fl::any(mask.tensor()).scalar()) { + if(!fl::any(mask.tensor()).scalar()) continue; - } auto indicesArray = fl::nonzero(mask.tensor()); headTarget = @@ -206,9 +199,8 @@ Variable AdaptiveSoftMaxLoss::forward( ); // Reduce - if(reduction_ == ReduceMode::NONE) { + if(reduction_ == ReduceMode::NONE) return moddims(res, targets.shape()); - } res = sum(res, {0}); if(reduction_ == ReduceMode::MEAN) { auto denominator = @@ -235,9 +227,8 @@ std::string AdaptiveSoftMaxLoss::prettyString() const { std::ostringstream ss; auto cutoff = activation_->getCutoff(); ss << "Adaptive Softmax ("; - for(int i = 0; i < cutoff.size() - 1; i++) { + for(int i = 0; i < cutoff.size() - 1; i++) ss << cutoff[i] << ", "; - } ss << cutoff[cutoff.size() - 1] << ")"; return ss.str(); } diff --git a/flashlight/fl/nn/modules/Module.cpp b/flashlight/fl/nn/modules/Module.cpp index 41e1fb8..0cee8f0 100644 --- a/flashlight/fl/nn/modules/Module.cpp +++ b/flashlight/fl/nn/modules/Module.cpp @@ -19,46 +19,40 @@ Module::Module() = default; Module::Module(const std::vector& params) : params_(params.begin(), params.end()) {} Variable Module::param(int position) const { - if(!(position >= 0 && position < params_.size())) { + if(!(position >= 0 && position < params_.size())) throw std::out_of_range("Module param index out of range"); - } return params_[position]; } void Module::setParams(const Variable& var, int position) { - if(!(position >= 0 && position < 
params_.size())) { + if(!(position >= 0 && position < params_.size())) throw std::out_of_range("Module param index out of range"); - } params_[position] = var; } std::vector Module::copyParams() const { std::vector params; params.reserve(params_.size()); - for(const auto& param : params_) { + for(const auto& param : params_) params.emplace_back(param.copy()); - } return params; } void Module::train() { train_ = true; - for(auto& param : params_) { + for(auto& param : params_) param.setCalcGrad(true); - } } void Module::zeroGrad() { - for(auto& param : params_) { + for(auto& param : params_) param.zeroGrad(); - } } void Module::eval() { train_ = false; - for(auto& param : params_) { + for(auto& param : params_) param.setCalcGrad(false); - } } std::vector Module::params() const { @@ -80,9 +74,8 @@ UnaryModule::UnaryModule(const std::vector& params) : Module(params) { std::vector UnaryModule::forward( const std::vector& inputs ) { - if(inputs.size() != 1) { + if(inputs.size() != 1) throw std::invalid_argument("UnaryModule expects only one input"); - } return {forward(inputs[0])}; } @@ -97,9 +90,8 @@ BinaryModule::BinaryModule(const std::vector& params) : Module(params) std::vector BinaryModule::forward( const std::vector& inputs ) { - if(inputs.size() != 2) { + if(inputs.size() != 2) throw std::invalid_argument("BinaryModule expects two inputs"); - } return {forward(inputs[0], inputs[1])}; } diff --git a/flashlight/fl/nn/modules/Normalize.cpp b/flashlight/fl/nn/modules/Normalize.cpp index 53e9aeb..e66a402 100644 --- a/flashlight/fl/nn/modules/Normalize.cpp +++ b/flashlight/fl/nn/modules/Normalize.cpp @@ -32,9 +32,8 @@ std::string Normalize::prettyString() const { std::ostringstream ss; ss << "Normalize"; ss << " ( axis : { "; - for(auto d : axes_) { + for(auto d : axes_) ss << d << " "; - } ss << "} , p : " << p_; ss << ", eps : " << eps_; ss << ", value : " << value_; diff --git a/flashlight/fl/nn/modules/Padding.cpp b/flashlight/fl/nn/modules/Padding.cpp index 
54cfa3b..24bf9ed 100644 --- a/flashlight/fl/nn/modules/Padding.cpp +++ b/flashlight/fl/nn/modules/Padding.cpp @@ -25,9 +25,8 @@ std::unique_ptr Padding::clone() const { std::string Padding::prettyString() const { std::ostringstream ss; ss << "Padding (" << m_val << ", { "; - for(auto p : m_pad) { + for(auto p : m_pad) ss << "(" << p.first << ", " << p.second << "), "; - } ss << "})"; return ss.str(); } diff --git a/flashlight/fl/nn/modules/Pool2D.cpp b/flashlight/fl/nn/modules/Pool2D.cpp index 609c3cc..5d58d99 100644 --- a/flashlight/fl/nn/modules/Pool2D.cpp +++ b/flashlight/fl/nn/modules/Pool2D.cpp @@ -49,9 +49,8 @@ Variable Pool2D::forward(const Variable& input) { /* dilation= */ 1 ); - if(!(px >= 0 && py >= 0)) { + if(!(px >= 0 && py >= 0)) throw std::invalid_argument("invalid padding for Pool2D"); - } return pool2d(input, xFilter_, yFilter_, xStride_, yStride_, px, py, mode_); } @@ -76,17 +75,15 @@ std::string Pool2D::prettyString() const { } ss << " (" << xFilter_ << "x" << yFilter_ << ", " << xStride_ << "," << yStride_ << ", "; - if(xPad_ == static_cast(PaddingMode::SAME)) { + if(xPad_ == static_cast(PaddingMode::SAME)) ss << "SAME"; - } else { + else ss << xPad_; - } ss << ","; - if(yPad_ == static_cast(PaddingMode::SAME)) { + if(yPad_ == static_cast(PaddingMode::SAME)) ss << "SAME"; - } else { + else ss << yPad_; - } ss << ")"; return ss.str(); } diff --git a/flashlight/fl/nn/modules/RNN.cpp b/flashlight/fl/nn/modules/RNN.cpp index b425137..24d5c1d 100644 --- a/flashlight/fl/nn/modules/RNN.cpp +++ b/flashlight/fl/nn/modules/RNN.cpp @@ -69,9 +69,8 @@ void RNN::initialize() { } std::vector RNN::forward(const std::vector& inputs) { - if(inputs.empty() || inputs.size() > 3) { + if(inputs.empty() || inputs.size() > 3) throw std::invalid_argument("Invalid inputs size"); - } const auto& input = inputs[0]; const auto& hiddenState = inputs.size() >= 2 ? 
inputs[1] : Variable(); @@ -92,12 +91,10 @@ std::vector RNN::forward(const std::vector& inputs) { ); std::vector output(1, std::get<0>(rnnRes)); - if(inputs.size() >= 2) { + if(inputs.size() >= 2) output.push_back(std::get<1>(rnnRes)); - } - if(inputs.size() == 3) { + if(inputs.size() == 3) output.push_back(std::get<2>(rnnRes)); - } return output; } @@ -165,15 +162,12 @@ std::string RNN::prettyString() const { } int output_size = bidirectional_ ? 2 * hiddenSize_ : hiddenSize_; ss << " (" << inputSize_ << "->" << output_size << ")"; - if(numLayers_ > 1) { + if(numLayers_ > 1) ss << " (" << numLayers_ << "-layer)"; - } - if(bidirectional_) { + if(bidirectional_) ss << " (bidirectional)"; - } - if(dropProb_ > 0) { + if(dropProb_ > 0) ss << " (dropout=" << dropProb_ << ")"; - } return ss.str(); } diff --git a/flashlight/fl/nn/modules/Reorder.cpp b/flashlight/fl/nn/modules/Reorder.cpp index e609c9c..fee5ff6 100644 --- a/flashlight/fl/nn/modules/Reorder.cpp +++ b/flashlight/fl/nn/modules/Reorder.cpp @@ -18,12 +18,11 @@ namespace fl { Reorder::Reorder(Shape shape) : shape_(std::move(shape)) {} Variable Reorder::forward(const Variable& input) { - if(input.ndim() != shape_.ndim()) { + if(input.ndim() != shape_.ndim()) throw std::invalid_argument( "Reorder::forward - input tensor has different " "number of dimensions than reorder shape." 
); - } return reorder(input, shape_); } diff --git a/flashlight/fl/nn/modules/WeightNorm.cpp b/flashlight/fl/nn/modules/WeightNorm.cpp index 0df4f83..51d6436 100644 --- a/flashlight/fl/nn/modules/WeightNorm.cpp +++ b/flashlight/fl/nn/modules/WeightNorm.cpp @@ -30,14 +30,11 @@ WeightNorm& WeightNorm::operator=(const WeightNorm& other) { void WeightNorm::transformDims() { normDim_.clear(); int vNumdims = module_->param(0).ndim(); - if(dim_ < 0 || dim_ > vNumdims) { + if(dim_ < 0 || dim_ > vNumdims) throw std::invalid_argument("invalid dimension for WeightNorm"); - } - for(int i = 0; i < vNumdims; i++) { - if(i != dim_) { + for(int i = 0; i < vNumdims; i++) + if(i != dim_) normDim_.push_back(i); - } - } } void WeightNorm::computeWeight() { @@ -55,11 +52,10 @@ void WeightNorm::computeWeight() { nm = reorder(nm, {3, 0, 1, 2}); nm = norm(nm, {1}, /* p = */ 2, /* keepDims = */ true); nm = reorder(nm, {1, 2, 3, 0}); - } else { + } else throw std::invalid_argument( "Wrong dimension for Weight Norm: " + std::to_string(dim_) ); - } auto wt = v * tileAs(g / nm, v); module_->setParams(wt, 0); } @@ -72,28 +68,25 @@ void WeightNorm::initParams() { if(moduleParams.size() == 2) { auto& b = moduleParams[1]; params_ = {v, g, b}; - } else if(moduleParams.size() == 1) { + } else if(moduleParams.size() == 1) params_ = {v, g}; - } else { + else throw std::invalid_argument("WeightNorm only supports Linear and Conv2D"); - } } void WeightNorm::setParams(const Variable& var, int position) { Module::setParams(var, position); // it is necessary to copy all params to the parent module // due to copies stored in the parent module (not pointers) - if(position == 2) { + if(position == 2) module_->setParams(var, 1); - } else if(position <= 1) { + else if(position <= 1) computeWeight(); - } } std::vector WeightNorm::forward(const std::vector& inputs) { - if(train_) { + if(train_) computeWeight(); - } return module_->forward(inputs); } diff --git a/flashlight/fl/optim/AMSgradOptimizer.cpp 
b/flashlight/fl/optim/AMSgradOptimizer.cpp index f64152a..1b771cc 100644 --- a/flashlight/fl/optim/AMSgradOptimizer.cpp +++ b/flashlight/fl/optim/AMSgradOptimizer.cpp @@ -47,16 +47,14 @@ AMSgradOptimizer::AMSgradOptimizer( void AMSgradOptimizer::step() { for(size_t i = 0; i < parameters_.size(); i++) { - if(!parameters_[i].isGradAvailable()) { + if(!parameters_[i].isGradAvailable()) continue; - } const Tensor& grad = parameters_[i].grad().tensor(); Tensor& data = parameters_[i].tensor(); - if(wd_ != 0) { + if(wd_ != 0) data = data - wd_ * data; - } Tensor& biasedFirst = biasedFirst_[i]; Tensor& biasedSecond = biasedSecond_[i]; @@ -79,9 +77,8 @@ std::string AMSgradOptimizer::prettyString() const { std::ostringstream ss; ss << "AMSgrad from "; - if(wd_ != 0) { + if(wd_ != 0) ss << " (weight decay=" << wd_ << ")"; - } return ss.str(); } diff --git a/flashlight/fl/optim/AdadeltaOptimizer.cpp b/flashlight/fl/optim/AdadeltaOptimizer.cpp index 644aef2..5a265e8 100644 --- a/flashlight/fl/optim/AdadeltaOptimizer.cpp +++ b/flashlight/fl/optim/AdadeltaOptimizer.cpp @@ -39,17 +39,15 @@ AdadeltaOptimizer::AdadeltaOptimizer( void AdadeltaOptimizer::step() { for(size_t i = 0; i < parameters_.size(); i++) { - if(!parameters_[i].isGradAvailable()) { + if(!parameters_[i].isGradAvailable()) continue; - } const Tensor& grad = parameters_[i].grad().tensor(); Tensor& data = parameters_[i].tensor(); - if(wd_ != 0) { + if(wd_ != 0) // Weight decay term data = data - wd_ * data; - } Tensor& accGrad = accGrad_[i]; Tensor& accDelta = accDelta_[i]; @@ -71,13 +69,11 @@ std::string AdadeltaOptimizer::prettyString() const { std::ostringstream ss; ss << "Adadelta"; - if(wd_ != 0) { + if(wd_ != 0) ss << " (weight decay=" << wd_ << ")"; - } ss << " (rho=" << rho_ << ")"; - if(eps_ != 0) { + if(eps_ != 0) ss << " (epsilon=" << eps_ << ")"; - } return ss.str(); } diff --git a/flashlight/fl/optim/AdagradOptimizer.cpp b/flashlight/fl/optim/AdagradOptimizer.cpp index 9b2cc25..f0632e4 100644 --- 
a/flashlight/fl/optim/AdagradOptimizer.cpp +++ b/flashlight/fl/optim/AdagradOptimizer.cpp @@ -30,18 +30,16 @@ AdagradOptimizer::AdagradOptimizer( void AdagradOptimizer::step() { for(size_t i = 0; i < parameters_.size(); i++) { - if(!parameters_[i].isGradAvailable()) { + if(!parameters_[i].isGradAvailable()) continue; - } const Tensor& grad = parameters_[i].grad().tensor(); Tensor& data = parameters_[i].tensor(); Tensor& variance = variance_[i]; - if(wd_ != 0) { + if(wd_ != 0) // Weight decay term data = data - wd_ * data; - } variance = variance + grad * grad; fl::eval(variance); @@ -54,9 +52,8 @@ std::string AdagradOptimizer::prettyString() const { std::ostringstream ss; ss << "Adagrad"; - if(eps_ != 0) { + if(eps_ != 0) ss << " (epsilon=" << eps_ << ")"; - } return ss.str(); } diff --git a/flashlight/fl/optim/AdamOptimizer.cpp b/flashlight/fl/optim/AdamOptimizer.cpp index af8d0df..267d437 100644 --- a/flashlight/fl/optim/AdamOptimizer.cpp +++ b/flashlight/fl/optim/AdamOptimizer.cpp @@ -49,17 +49,15 @@ void AdamOptimizer::step() { float correctedLr = lr_ * std::sqrt(correctedBias2) / correctedBias1; for(size_t i = 0; i < parameters_.size(); i++) { - if(!parameters_[i].isGradAvailable()) { + if(!parameters_[i].isGradAvailable()) continue; - } const Tensor& grad = parameters_[i].grad().tensor(); Tensor& data = parameters_[i].tensor(); - if(wd_ != 0) { + if(wd_ != 0) // Weight decay term data = data - wd_ * lr_ * data; - } Tensor& biasedFirst = biasedFirst_[i]; Tensor& biasedSecond = biasedSecond_[i]; @@ -80,9 +78,8 @@ std::string AdamOptimizer::prettyString() const { std::ostringstream ss; ss << "Adam"; - if(wd_ != 0) { + if(wd_ != 0) ss << " (weight decay=" << wd_ << ")"; - } return ss.str(); } diff --git a/flashlight/fl/optim/NAGOptimizer.cpp b/flashlight/fl/optim/NAGOptimizer.cpp index ca33650..a627900 100644 --- a/flashlight/fl/optim/NAGOptimizer.cpp +++ b/flashlight/fl/optim/NAGOptimizer.cpp @@ -25,11 +25,10 @@ NAGOptimizer::NAGOptimizer( wd_(weightDecay), 
velocities_(), oldLr_(learningRate) { - if(momentum <= 0) { + if(momentum <= 0) throw std::runtime_error( "Invalid momentum for NAG optimizer, it should be > 0" ); - } velocities_.reserve(parameters.size()); for(const auto& parameter : parameters_) { velocities_.emplace_back(fl::full(parameter.shape(), 0, parameter.type())); @@ -41,17 +40,15 @@ void NAGOptimizer::step() { float correctedLr = lr_ / oldLr_; for(size_t i = 0; i < parameters_.size(); i++) { - if(!parameters_[i].isGradAvailable()) { + if(!parameters_[i].isGradAvailable()) continue; - } Tensor& grad = parameters_[i].grad().tensor(); Tensor& data = parameters_[i].tensor(); - if(wd_ != 0) { + if(wd_ != 0) // Weight decay term data = data * (1 - lr_ * wd_); - } Tensor& velocity = velocities_[i]; // this velocity corresponds to fairseq velocity * -1 velocity = mu_ * velocity * correctedLr + lr_ * grad; @@ -67,9 +64,8 @@ std::string NAGOptimizer::prettyString() const { std::ostringstream ss; ss << "NAG (lr=" << lr_ << " ); (previous lr=" << oldLr_ << ");"; - if(wd_ != 0) { + if(wd_ != 0) ss << " (weight decay=" << wd_ << ");"; - } ss << " (Nesterov momentum=" << mu_ << ")"; return ss.str(); } diff --git a/flashlight/fl/optim/NovogradOptimizer.cpp b/flashlight/fl/optim/NovogradOptimizer.cpp index dd7a298..9a075e1 100644 --- a/flashlight/fl/optim/NovogradOptimizer.cpp +++ b/flashlight/fl/optim/NovogradOptimizer.cpp @@ -42,9 +42,8 @@ NovogradOptimizer::NovogradOptimizer( void NovogradOptimizer::step() { for(size_t i = 0; i < parameters_.size(); i++) { - if(!parameters_[i].isGradAvailable()) { + if(!parameters_[i].isGradAvailable()) continue; - } const Tensor& grad = parameters_[i].grad().tensor(); Tensor& data = parameters_[i].tensor(); @@ -69,9 +68,8 @@ std::string NovogradOptimizer::prettyString() const { std::ostringstream ss; ss << "Novograd"; - if(wd_ != 0) { + if(wd_ != 0) ss << " (weight decay=" << wd_ << ")"; - } return ss.str(); } diff --git a/flashlight/fl/optim/Optimizers.cpp 
b/flashlight/fl/optim/Optimizers.cpp index bed73de..fc614c4 100644 --- a/flashlight/fl/optim/Optimizers.cpp +++ b/flashlight/fl/optim/Optimizers.cpp @@ -23,9 +23,8 @@ FirstOrderOptimizer::FirstOrderOptimizer( lr_(learningRate) {} void FirstOrderOptimizer::zeroGrad() { - for(auto& parameter : parameters_) { + for(auto& parameter : parameters_) parameter.zeroGrad(); - } } } // namespace fl diff --git a/flashlight/fl/optim/RMSPropOptimizer.cpp b/flashlight/fl/optim/RMSPropOptimizer.cpp index 18a4256..c613d05 100644 --- a/flashlight/fl/optim/RMSPropOptimizer.cpp +++ b/flashlight/fl/optim/RMSPropOptimizer.cpp @@ -29,9 +29,8 @@ RMSPropOptimizer::RMSPropOptimizer( wd_(weightDecay), first_(), second_() { - if(useFirst_) { + if(useFirst_) first_.reserve(parameters.size()); - } second_.reserve(parameters.size()); for(const auto& parameter : parameters_) { @@ -47,17 +46,15 @@ RMSPropOptimizer::RMSPropOptimizer( void RMSPropOptimizer::step() { for(size_t i = 0; i < parameters_.size(); i++) { - if(!parameters_[i].isGradAvailable()) { + if(!parameters_[i].isGradAvailable()) continue; - } const Tensor& grad = parameters_[i].grad().tensor(); Tensor& data = parameters_[i].tensor(); - if(wd_ != 0) { + if(wd_ != 0) // Weight decay term data = data - wd_ * data; - } Tensor& second = second_[i]; second = rho_ * second + (1 - rho_) * grad * grad; @@ -83,13 +80,11 @@ std::string RMSPropOptimizer::prettyString() const { std::ostringstream ss; ss << "RMSProp"; - if(wd_ != 0) { + if(wd_ != 0) ss << " (weight decay=" << wd_ << ")"; - } - if(useFirst_) { + if(useFirst_) ss << " (use first moment)"; - } return ss.str(); } diff --git a/flashlight/fl/optim/SGDOptimizer.cpp b/flashlight/fl/optim/SGDOptimizer.cpp index 6c31092..bbeebd1 100644 --- a/flashlight/fl/optim/SGDOptimizer.cpp +++ b/flashlight/fl/optim/SGDOptimizer.cpp @@ -37,17 +37,15 @@ SGDOptimizer::SGDOptimizer( void SGDOptimizer::step() { for(size_t i = 0; i < parameters_.size(); i++) { - if(!parameters_[i].isGradAvailable()) { + 
if(!parameters_[i].isGradAvailable()) continue; - } Tensor& grad = parameters_[i].grad().tensor(); Tensor& data = parameters_[i].tensor(); - if(wd_ != 0) { + if(wd_ != 0) // Weight decay term grad = grad + wd_ * data; - } if(mu_ != 0) { Tensor& velocity = velocities_[i]; @@ -55,12 +53,11 @@ void SGDOptimizer::step() { // Regular momentum velocity = mu_ * velocity + grad; fl::eval(velocity); - if(useNesterov_) { + if(useNesterov_) // Update for nesterov momentum grad += velocity * mu_; - } else { + else grad = velocity; - } } data = data - lr_ * grad; fl::eval(data); @@ -71,14 +68,12 @@ std::string SGDOptimizer::prettyString() const { std::ostringstream ss; ss << "SGD"; - if(wd_ != 0) { + if(wd_ != 0) ss << " (weight decay=" << wd_ << ")"; - } - if(useNesterov_ && mu_ != 0) { + if(useNesterov_ && mu_ != 0) ss << " (Nesterov momentum=" << mu_ << ")"; - } else if(mu_ != 0) { + else if(mu_ != 0) ss << " (momentum=" << mu_ << ")"; - } return ss.str(); } diff --git a/flashlight/fl/optim/Utils.cpp b/flashlight/fl/optim/Utils.cpp index a3b341b..1545100 100644 --- a/flashlight/fl/optim/Utils.cpp +++ b/flashlight/fl/optim/Utils.cpp @@ -16,21 +16,18 @@ namespace fl { double clipGradNorm(const std::vector& parameters, double maxNorm) { double gradNorm = 0.0; for(const auto& p : parameters) { - if(!p.isGradAvailable()) { + if(!p.isGradAvailable()) continue; - } const auto& grad = p.grad().tensor(); gradNorm += fl::sum(grad * grad).asScalar(); } gradNorm = std::sqrt(gradNorm); double scale = maxNorm / (gradNorm + 1e-6); - if(scale >= 1.0) { + if(scale >= 1.0) return gradNorm; - } for(auto& p : parameters) { - if(!p.isGradAvailable()) { + if(!p.isGradAvailable()) continue; - } p.grad().tensor() *= scale; } return gradNorm; diff --git a/flashlight/fl/runtime/CUDAStream.cpp b/flashlight/fl/runtime/CUDAStream.cpp index 6cc9eb4..1afef8d 100644 --- a/flashlight/fl/runtime/CUDAStream.cpp +++ b/flashlight/fl/runtime/CUDAStream.cpp @@ -65,13 +65,11 @@ std::shared_ptr 
CUDAStream::wrapUnmanaged( manager.getDevice(DeviceType::CUDA, deviceId).impl(); // satisfies assumptions of makeSharedAndRegister bool needDeviceSwitch = &oldActiveDevice != &device; - if(needDeviceSwitch) { + if(needDeviceSwitch) device.setActive(); - } auto streamPtr = makeSharedAndRegister(device, stream, /* managed */ false); - if(needDeviceSwitch) { + if(needDeviceSwitch) oldActiveDevice.setActive(); - } return streamPtr; } @@ -111,9 +109,8 @@ void CUDAStream::relativeSync(const CUDAStream& waitOn) const { auto& manager = DeviceManager::getInstance(); auto* oldActiveCUDADevice = &manager.getActiveDevice(DeviceType::CUDA); bool needDeviceSwitch = oldActiveCUDADevice != &device_; - if(needDeviceSwitch) { + if(needDeviceSwitch) device_.setActive(); - } // event and stream from same instance are guaranteed to have been created // from the same device FL_CUDA_CHECK(cudaEventRecord(waitOn.event_, waitOn.nativeStream_)); @@ -124,9 +121,8 @@ void CUDAStream::relativeSync(const CUDAStream& waitOn) const { 0 ) ); - if(needDeviceSwitch) { + if(needDeviceSwitch) oldActiveCUDADevice->setActive(); - } } cudaStream_t CUDAStream::handle() const { diff --git a/flashlight/fl/runtime/CUDAUtils.cpp b/flashlight/fl/runtime/CUDAUtils.cpp index ec1bcb3..6a8aa7a 100644 --- a/flashlight/fl/runtime/CUDAUtils.cpp +++ b/flashlight/fl/runtime/CUDAUtils.cpp @@ -24,9 +24,8 @@ std::unordered_map> createCUDADevices() { std::unordered_map> idToDevice; int numCudaDevices = 0; FL_CUDA_CHECK(cudaGetDeviceCount(&numCudaDevices)); - for(auto id = 0; id < numCudaDevices; id++) { + for(auto id = 0; id < numCudaDevices; id++) idToDevice.emplace(id, std::make_unique(id)); - } return idToDevice; } diff --git a/flashlight/fl/runtime/Device.cpp b/flashlight/fl/runtime/Device.cpp index c03fe5d..7a517df 100644 --- a/flashlight/fl/runtime/Device.cpp +++ b/flashlight/fl/runtime/Device.cpp @@ -27,18 +27,16 @@ const std::unordered_set>& Device::getStreams() const { } void Device::addStream(std::shared_ptr 
stream) { - if(&stream->device() != this) { + if(&stream->device() != this) throw std::runtime_error( "[Device::addStream] Must add stream to owner device" ); - } streams_.insert(stream); } void Device::sync() const { - for(const auto& stream : streams_) { + for(const auto& stream : streams_) stream->sync(); - } } void Device::addSetActiveCallback(std::function callback) { @@ -47,9 +45,8 @@ void Device::addSetActiveCallback(std::function callback) { void Device::setActive() const { setActiveImpl(); - for(auto& callback : setActiveCallbacks_) { + for(auto& callback : setActiveCallbacks_) callback(nativeId()); - } } int X64Device::nativeId() const { diff --git a/flashlight/fl/runtime/DeviceManager.cpp b/flashlight/fl/runtime/DeviceManager.cpp index 22b6bb7..2859dd5 100644 --- a/flashlight/fl/runtime/DeviceManager.cpp +++ b/flashlight/fl/runtime/DeviceManager.cpp @@ -48,11 +48,10 @@ void DeviceManager::enforceDeviceTypeAvailable( std::string_view errorPrefix, const DeviceType type ) const { - if(!isDeviceTypeAvailable(type)) { + if(!isDeviceTypeAvailable(type)) throw std::runtime_error( std::string(errorPrefix) + " device type unavailable" ); - } } DeviceManager& DeviceManager::getInstance() { @@ -74,9 +73,8 @@ std::vector DeviceManager::getDevicesOfType( ) { enforceDeviceTypeAvailable("[DeviceManager::getDevicesOfType]", type); std::vector devices; - for(auto&[_, device] : deviceTypeToInfo_.at(type)) { + for(auto&[_, device] : deviceTypeToInfo_.at(type)) devices.push_back(device.get()); - } return devices; } @@ -85,20 +83,18 @@ std::vector DeviceManager::getDevicesOfType( ) const { enforceDeviceTypeAvailable("[DeviceManager::getDevicesOfType]", type); std::vector devices; - for(auto&[_, device] : deviceTypeToInfo_.at(type)) { + for(auto&[_, device] : deviceTypeToInfo_.at(type)) devices.push_back(device.get()); - } return devices; } Device& DeviceManager::getDevice(const DeviceType type, int id) const { enforceDeviceTypeAvailable("[DeviceManager::getActiveDevice]", 
type); auto& idToDevice = deviceTypeToInfo_.at(type); - if(!idToDevice.contains(id)) { + if(!idToDevice.contains(id)) throw std::runtime_error( "[DeviceManager::getDevice] unknown device id" ); - } return *idToDevice.at(id); } diff --git a/flashlight/fl/runtime/Stream.cpp b/flashlight/fl/runtime/Stream.cpp index 2210e99..027806e 100644 --- a/flashlight/fl/runtime/Stream.cpp +++ b/flashlight/fl/runtime/Stream.cpp @@ -12,9 +12,8 @@ namespace fl { void Stream::relativeSync( const std::unordered_set& waitOns ) const { - for(const auto* waitOn : waitOns) { + for(const auto* waitOn : waitOns) this->relativeSync(*waitOn); - } } } // namespace fl diff --git a/flashlight/fl/runtime/Stream.h b/flashlight/fl/runtime/Stream.h index 52a824c..d6a048a 100644 --- a/flashlight/fl/runtime/Stream.h +++ b/flashlight/fl/runtime/Stream.h @@ -47,12 +47,11 @@ class FL_API Stream { */ template const T& impl() const { - if(T::type != type()) { + if(T::type != type()) throw std::invalid_argument( "[fl::Stream::impl] " "specified stream type doesn't match actual stream type." 
); - } return *(static_cast(this)); } diff --git a/flashlight/fl/tensor/Compute.cpp b/flashlight/fl/tensor/Compute.cpp index 0a981b2..ab28197 100644 --- a/flashlight/fl/tensor/Compute.cpp +++ b/flashlight/fl/tensor/Compute.cpp @@ -23,9 +23,8 @@ namespace { const std::vector& tensors ) { std::unordered_set uniqueStreams; - for(const auto& tensor : tensors) { + for(const auto& tensor : tensors) uniqueStreams.insert(&tensor.stream()); - } return uniqueStreams; } @@ -33,9 +32,8 @@ namespace { const std::vector& tensors ) { std::unordered_set uniqueStreams; - for(const auto& tensor : tensors) { + for(const auto& tensor : tensors) uniqueStreams.insert(&tensor->stream()); - } return uniqueStreams; } @@ -54,16 +52,14 @@ void sync(const int deviceId) { void sync(const std::unordered_set& types) { const auto& manager = DeviceManager::getInstance(); // TODO consider launching these `Device::sync` calls non-blockingly - for(const auto type : types) { + for(const auto type : types) manager.getActiveDevice(type).sync(); - } } void sync(const std::unordered_set& devices) { // TODO consider launching these `Device::sync` calls non-blockingly - for(const auto* device : devices) { + for(const auto* device : devices) device->sync(); - } } void relativeSync( @@ -71,24 +67,21 @@ void relativeSync( const std::vector& waitOns ) { // ensure computations are launched - for(const auto* tensor : waitOns) { + for(const auto* tensor : waitOns) tensor->backend().eval(*tensor); - } wait.relativeSync(tensorsToUniqueStreams(waitOns)); } void relativeSync(const Stream& wait, const std::vector& waitOns) { // ensure computations are launched - for(const auto& tensor : waitOns) { + for(const auto& tensor : waitOns) tensor.backend().eval(tensor); - } wait.relativeSync(tensorsToUniqueStreams(waitOns)); } void relativeSync(const std::vector& waits, const Stream& waitOn) { - for(const auto& stream : tensorsToUniqueStreams(waits)) { + for(const auto& stream : tensorsToUniqueStreams(waits)) 
stream->relativeSync(waitOn); - } } void eval(Tensor& tensor) { diff --git a/flashlight/fl/tensor/Index.cpp b/flashlight/fl/tensor/Index.cpp index e4073dc..c683c9f 100644 --- a/flashlight/fl/tensor/Index.cpp +++ b/flashlight/fl/tensor/Index.cpp @@ -33,9 +33,8 @@ const std::optional& range::end() const { } Dim range::endVal() const { - if(end_.has_value()) { + if(end_.has_value()) return end_.value(); - } throw std::runtime_error("[range::endVal] end is end_t"); } diff --git a/flashlight/fl/tensor/Shape.cpp b/flashlight/fl/tensor/Shape.cpp index 7fb21f9..69fe4b7 100644 --- a/flashlight/fl/tensor/Shape.cpp +++ b/flashlight/fl/tensor/Shape.cpp @@ -31,9 +31,8 @@ void Shape::checkDimsOrThrow(const size_t dim) const { } Dim Shape::elements() const { - if(dims_.empty()) { + if(dims_.empty()) return kEmptyShapeNumberOfElements; - } return std::accumulate(dims_.begin(), dims_.end(), static_cast(1), std::multiplies()); } @@ -84,9 +83,8 @@ std::vector& Shape::get() { std::string Shape::toString() const { std::stringstream ss; ss << "("; - for(size_t i = 0; i < ndim(); ++i) { + for(size_t i = 0; i < ndim(); ++i) ss << dim(i) << (i == ndim() - 1 ? 
"" : ", "); - } ss << ")"; return ss.str(); } diff --git a/flashlight/fl/tensor/TensorBackend.cpp b/flashlight/fl/tensor/TensorBackend.cpp index 114e4de..b66eae8 100644 --- a/flashlight/fl/tensor/TensorBackend.cpp +++ b/flashlight/fl/tensor/TensorBackend.cpp @@ -18,9 +18,8 @@ namespace detail { bool TensorBackend::isDataTypeSupported(const fl::dtype& dtype) const { bool supported = this->supportsDataType(dtype); - for(auto& p : extensions_) { + for(auto& p : extensions_) supported &= p.second->isDataTypeSupported(dtype); - } return supported; } diff --git a/flashlight/fl/tensor/TensorBackend.h b/flashlight/fl/tensor/TensorBackend.h index 70fc9e9..cd7c250 100644 --- a/flashlight/fl/tensor/TensorBackend.h +++ b/flashlight/fl/tensor/TensorBackend.h @@ -310,9 +310,8 @@ Tensor toTensorType(Tensor&& in) { // Fast path - backend is the same // TODO: make fl::TensorBackendType a static constexpr on the class as well so // as to not need to instantiate a backend to check the type - if(in.backendType() == T().backendType()) { + if(in.backendType() == T().backendType()) return std::move(in); - } // As per impl requirements, Tensor::device() should return a pointer to host // memory if the tensor resides on the host. 
diff --git a/flashlight/fl/tensor/TensorBase.cpp b/flashlight/fl/tensor/TensorBase.cpp index fcc4f0b..ed0d180 100644 --- a/flashlight/fl/tensor/TensorBase.cpp +++ b/flashlight/fl/tensor/TensorBase.cpp @@ -210,9 +210,8 @@ FL_CREATE_MEMORY_OPS(unsigned short); // void specializations template<> FL_API void* Tensor::device() const { - if(isEmpty()) { + if(isEmpty()) return nullptr; - } void* out; impl_->device(&out); return out; @@ -220,17 +219,15 @@ FL_API void* Tensor::device() const { template<> FL_API void Tensor::device(void** ptr) const { - if(isEmpty()) { + if(isEmpty()) return; - } impl_->device(ptr); } template<> FL_API void* Tensor::host() const { - if(isEmpty()) { + if(isEmpty()) return nullptr; - } void* out = reinterpret_cast(new char[bytes()]); impl_->host(out); return out; @@ -410,9 +407,8 @@ Tensor tile(const Tensor& tensor, const Shape& shape) { } Tensor concatenate(const std::vector& tensors, const unsigned axis) { - if(tensors.empty()) { + if(tensors.empty()) throw std::invalid_argument("concatenate: called on empty set of tensors"); - } // Check all backends match const TensorBackendType b = tensors.front().backendType(); @@ -424,11 +420,10 @@ Tensor concatenate(const std::vector& tensors, const unsigned axis) { return t.backendType() == b; } ); - if(!matches) { + if(!matches) throw std::invalid_argument( "concatenate: tried to concatenate tensors of different backends" ); - } return tensors.front().backend().concatenate(tensors, axis); } @@ -861,15 +856,12 @@ bool allClose( const fl::Tensor& b, const double absTolerance ) { - if(a.type() != b.type()) { + if(a.type() != b.type()) return false; - } - if(a.shape() != b.shape()) { + if(a.shape() != b.shape()) return false; - } - if(a.elements() == 0 && b.elements() == 0) { + if(a.elements() == 0 && b.elements() == 0) return true; - } return fl::amax(fl::abs(a - b)).astype(dtype::f64).scalar() < absTolerance; } diff --git a/flashlight/fl/tensor/TensorBase.h b/flashlight/fl/tensor/TensorBase.h index 
4f1f1bd..f16062d 100644 --- a/flashlight/fl/tensor/TensorBase.h +++ b/flashlight/fl/tensor/TensorBase.h @@ -547,9 +547,8 @@ class FL_API Tensor { */ template std::vector toHostVector() const { - if(isEmpty()) { + if(isEmpty()) return std::vector(); - } std::vector vec(this->elements()); host(vec.data()); return vec; @@ -1738,16 +1737,15 @@ FL_API std::ostream& operator<<(std::ostream& os, const TensorBackendType type); template Tensor to(Tensor&& t) { // Fast path -- types are the same - if(T::tensorBackendType == t.backendType()) { + if(T::tensorBackendType == t.backendType()) return std::move(t); - } - if(t.isSparse()) { + if(t.isSparse()) throw std::invalid_argument( "Tensor type conversion between sparse " "tensors not yet supported." ); - } else { + else // TODO: dynamically fix the memory location based on the type of // backend/where base memory is return Tensor( @@ -1758,7 +1756,6 @@ Tensor to(Tensor&& t) { MemoryLocation::Device ) ); - } } /** @} */ diff --git a/flashlight/fl/tensor/TensorExtension.cpp b/flashlight/fl/tensor/TensorExtension.cpp index feca51c..c4ef365 100644 --- a/flashlight/fl/tensor/TensorExtension.cpp +++ b/flashlight/fl/tensor/TensorExtension.cpp @@ -43,19 +43,17 @@ TensorExtensionCallback& TensorExtensionRegistrar::getTensorExtensionCreationFun TensorBackendType backend, TensorExtensionType extensionType ) { - if(extensions_.find(backend) == extensions_.end()) { + if(extensions_.find(backend) == extensions_.end()) throw std::invalid_argument( "TensorExtensionRegistrar::getTensorExtensionCreationFunc: " "no tensor extensions registered for given backend." 
); - } auto& _extensions = extensions_[backend]; - if(_extensions.find(extensionType) == _extensions.end()) { + if(_extensions.find(extensionType) == _extensions.end()) throw std::invalid_argument( "TensorExtensionRegistrar::getTensorExtensionCreationFunc: " "no tensor extensions registered for backend " + tensorBackendTypeToString(backend) ); - } return _extensions[extensionType]; } diff --git a/flashlight/fl/tensor/Types.cpp b/flashlight/fl/tensor/Types.cpp index 9667c10..1605312 100644 --- a/flashlight/fl/tensor/Types.cpp +++ b/flashlight/fl/tensor/Types.cpp @@ -74,9 +74,8 @@ const std::string& dtypeToString(dtype type) { } fl::dtype stringToDtype(const std::string& string) { - if(kStringToType.find(string) != kStringToType.end()) { + if(kStringToType.find(string) != kStringToType.end()) return kStringToType.at(string); - } throw std::invalid_argument("stringToDtype: Invalid input type: " + string); } diff --git a/flashlight/fl/tensor/backend/af/AdvancedIndex.cu b/flashlight/fl/tensor/backend/af/AdvancedIndex.cu index 84b729f..de2b472 100644 --- a/flashlight/fl/tensor/backend/af/AdvancedIndex.cu +++ b/flashlight/fl/tensor/backend/af/AdvancedIndex.cu @@ -39,9 +39,8 @@ template < class Float, class Index // the input and output tensors dim_t dims[4], strides[4]; dim_t outStrides[4]; - for(int i = 0; i < 4; i++) { + for(int i = 0; i < 4; i++) dims[i] = idxEnd[i] - idxStart[i]; - } strides[0] = 1; outStrides[0] = 1; // arrayfire dimensions are inverted compared to numpy @@ -72,9 +71,8 @@ template < class Float, class Index if(idxArr[i]) { auto idxArrPtr = (Index*) idxArr[i]; outIdx += idxArrPtr[index[i]] * outStrides[i]; - } else { + } else outIdx += (idxStart[i] + index[i]) * outStrides[i]; - } } // atomic addition is done to ensure correct // gradient computation for repeated indices @@ -96,15 +94,12 @@ namespace fl { auto inpType = inp.type(); auto outType = out.type(); - if((inpType != af::dtype::f32) && (inpType != af::dtype::f16)) { + if((inpType != 
af::dtype::f32) && (inpType != af::dtype::f16)) throw std::invalid_argument("Input type must be f16/f32"); - } - if((outType != af::dtype::f32) && (outType != af::dtype::f16)) { + if((outType != af::dtype::f32) && (outType != af::dtype::f16)) throw std::invalid_argument("Output type must be f16/f32"); - } - if(idxArr.size() != 4) { + if(idxArr.size() != 4) throw std::invalid_argument("Index array vector must be length 4"); - } af::dim4 idxPtr; // Extract raw device pointers for dimensions @@ -117,31 +112,26 @@ namespace fl { idxPtr[i] = 0; continue; } - if(validIndexTypes.find(idxArr[i].type()) == validIndexTypes.end()) { + if(validIndexTypes.find(idxArr[i].type()) == validIndexTypes.end()) throw std::invalid_argument( "Index type must be one of s32/s64/u32/u64, observed type is " + std::to_string(idxArr[i].type()) ); - } idxTypes.push_back(idxArr[i].type()); idxPtr[i] = (dim_t) (idxArr[i].device < void > ()); } - for(int i = 0; i + 1 < idxTypes.size(); i++) { - if(idxTypes[i] != idxTypes[i + 1]) { + for(int i = 0; i + 1 < idxTypes.size(); i++) + if(idxTypes[i] != idxTypes[i + 1]) throw std::invalid_argument( "Index type must be the same across all dimensions" ); - } - } af::array inpCast = inp; af::array outCast = out; - if(inpType == af::dtype::f16) { + if(inpType == af::dtype::f16) inpCast = inp.as(af::dtype::f32); - } - if(outType == af::dtype::f16) { + if(outType == af::dtype::f16) outCast = out.as(af::dtype::f32); - } void* inpRawPtr = inpCast.device < void > (); void* outRawPtr = outCast.device < void > (); @@ -155,7 +145,7 @@ namespace fl { void* arrIdxPtrDev = arrIdxPtr.device < void > (); cudaStream_t stream = afcu::getStream(af::getDevice()); - if(idxTypes.size() == 0 || idxTypes[0] == af::dtype::s32) { + if(idxTypes.size() == 0 || idxTypes[0] == af::dtype::s32) advancedIndexKernel < float, int32_t > << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( static_cast < const float* > (inpRawPtr), static_cast < const dim_t * > (arrIdxStartDev), @@ -163,7 +153,7 @@ 
namespace fl { static_cast < const dim_t * > (arrOutDimsDev), static_cast < const dim_t * > (arrIdxPtrDev), static_cast < float* > (outRawPtr)); - } else if(idxTypes[0] == af::dtype::s64) { + else if(idxTypes[0] == af::dtype::s64) advancedIndexKernel < float, int64_t > << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( static_cast < const float* > (inpRawPtr), static_cast < const dim_t * > (arrIdxStartDev), @@ -171,7 +161,7 @@ namespace fl { static_cast < const dim_t * > (arrOutDimsDev), static_cast < const dim_t * > (arrIdxPtrDev), static_cast < float* > (outRawPtr)); - } else if(idxTypes[0] == af::dtype::u32) { + else if(idxTypes[0] == af::dtype::u32) advancedIndexKernel < float, uint32_t > << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( static_cast < const float* > (inpRawPtr), static_cast < const dim_t * > (arrIdxStartDev), @@ -179,7 +169,7 @@ namespace fl { static_cast < const dim_t * > (arrOutDimsDev), static_cast < const dim_t * > (arrIdxPtrDev), static_cast < float* > (outRawPtr)); - } else if(idxTypes[0] == af::dtype::u64) { + else if(idxTypes[0] == af::dtype::u64) advancedIndexKernel < float, uint64_t > << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( static_cast < const float* > (inpRawPtr), static_cast < const dim_t * > (arrIdxStartDev), @@ -187,14 +177,12 @@ namespace fl { static_cast < const dim_t * > (arrOutDimsDev), static_cast < const dim_t * > (arrIdxPtrDev), static_cast < float* > (outRawPtr)); - } else { + else throw std::invalid_argument("Index type must be one of s32/s64/u32/u64"); - } - if(cudaPeekAtLastError() != cudaSuccess) { + if(cudaPeekAtLastError() != cudaSuccess) throw std::runtime_error( "ArrayFireTensor advancedIndex kernel CUDA failure" ); - } inpCast.unlock(); outCast.unlock(); @@ -202,14 +190,12 @@ namespace fl { arrIdxEnd.unlock(); arrOutDims.unlock(); arrIdxPtr.unlock(); - for(const auto& arr : idxArr) { + for(const auto& arr : idxArr) arr.unlock(); - } out = outCast; - if(outType == af::dtype::f16) { + if(outType == af::dtype::f16) out = 
outCast.as(af::dtype::f16); - } } } // namespace detail diff --git a/flashlight/fl/tensor/backend/af/ArrayFireBLAS.cpp b/flashlight/fl/tensor/backend/af/ArrayFireBLAS.cpp index bdbf9ad..610004b 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireBLAS.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireBLAS.cpp @@ -21,9 +21,8 @@ Tensor ArrayFireBackend::matmul( MatrixProperty rhsProp ) { unsigned numDims = std::max(lhs.ndim(), rhs.ndim()); - if((lhs.ndim() == 1 || rhs.ndim() == 1) && numDims > 1) { + if((lhs.ndim() == 1 || rhs.ndim() == 1) && numDims > 1) numDims -= 1; - } af::array lhsArray = toArray(lhs); af::array rhsArray = toArray(rhs); @@ -38,12 +37,10 @@ Tensor ArrayFireBackend::matmul( rhsProp = MatrixProperty::None; numDims = 1; } else { - if(rhs.ndim() == 1) { + if(rhs.ndim() == 1) rhsArray = af::moddims(toArray(rhs), {rhs.dim(0), 1}); - } - if(lhs.ndim() == 1) { + if(lhs.ndim() == 1) lhsArray = af::moddims(toArray(lhs), {1, lhs.dim(0)}); - } } auto arr = af::matmul( @@ -53,9 +50,8 @@ Tensor ArrayFireBackend::matmul( detail::flToAfMatrixProperty(rhsProp) ); - if(lhs.ndim() == 1 && rhs.ndim() == 2) { + if(lhs.ndim() == 1 && rhs.ndim() == 2) arr = af::moddims(arr, arr.dims(1)); - } return toTensor(std::move(arr), numDims); } diff --git a/flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp b/flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp index 9aceaa1..c5896bb 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp @@ -44,9 +44,8 @@ namespace { std::unordered_map>& afIdToStream ) { auto iter = afIdToStream.find(afId); - if(iter != afIdToStream.end()) { + if(iter != afIdToStream.end()) return *iter->second; - } #if FL_ARRAYFIRE_USE_CPU auto resIter = afIdToStream.emplace(afId, ArrayFireCPUStream::create()); @@ -78,9 +77,8 @@ ArrayFireBackend::ArrayFireBackend() { // AF race conditions when tearing down our custom memory manager. 
// TODO: remove this temporary workaround for crashes when using custom // opencl kernels. - if(FL_BACKEND_CUDA) { + if(FL_BACKEND_CUDA) MemoryManagerInstaller::installDefaultMemoryManager(); - } } ); @@ -170,45 +168,39 @@ void ArrayFireBackend::getMemMgrInfo( std::ostream* ostream ) { int deviceId = nativeIdToId_.at(nativeDeviceId); - if(ostream == nullptr) { + if(ostream == nullptr) throw std::invalid_argument( "ArrayFireBackend::getMemMgrInfo - got null ostream pointer" ); - } auto* curMemMgr = fl::MemoryManagerInstaller::currentlyInstalledMemoryManager(); - if(curMemMgr) { + if(curMemMgr) curMemMgr->printInfo(msg, deviceId, ostream); - } } void ArrayFireBackend::setMemMgrLogStream(std::ostream* stream) { - if(stream == nullptr) { + if(stream == nullptr) throw std::invalid_argument( "ArrayFireBackend::getMemMgrInfo - got null ostream pointer" ); - } auto* curMemMgr = fl::MemoryManagerInstaller::currentlyInstalledMemoryManager(); - if(curMemMgr) { + if(curMemMgr) curMemMgr->setLogStream(stream); - } } void ArrayFireBackend::setMemMgrLoggingEnabled(const bool enabled) { auto* curMemMgr = fl::MemoryManagerInstaller::currentlyInstalledMemoryManager(); - if(curMemMgr) { + if(curMemMgr) curMemMgr->setLoggingEnabled(enabled); - } } void ArrayFireBackend::setMemMgrFlushInterval(const size_t interval) { auto* curMemMgr = fl::MemoryManagerInstaller::currentlyInstalledMemoryManager(); - if(curMemMgr) { + if(curMemMgr) curMemMgr->setLogFlushInterval(interval); - } } /* -------------------------- Rand Functions -------------------------- */ @@ -320,11 +312,10 @@ void ArrayFireBackend::topk( const Dim axis, const SortMode sortMode ) { - if(axis != 0) { + if(axis != 0) throw std::invalid_argument( "ArrayFireTensor topk: operation only supported along zero axis." 
); - } af::array valuesArr, indicesArr; af::topk( valuesArr, @@ -344,12 +335,11 @@ Tensor ArrayFireBackend::sort( const Dim axis, const SortMode sortMode ) { - if(sortMode != SortMode::Descending && sortMode != SortMode::Ascending) { + if(sortMode != SortMode::Descending && sortMode != SortMode::Ascending) throw std::invalid_argument( "Cannot sort ArrayFire tensor with given SortMode: " "only Descending and Ascending supported." ); - } af::array values, indices; af::sort( @@ -369,12 +359,11 @@ void ArrayFireBackend::sort( const Dim axis, const SortMode sortMode ) { - if(sortMode != SortMode::Descending && sortMode != SortMode::Ascending) { + if(sortMode != SortMode::Descending && sortMode != SortMode::Ascending) throw std::invalid_argument( "Cannot sort ArrayFire tensor with given SortMode: " "only Descending and Ascending supported." ); - } af::array _values, _indices; af::sort( @@ -393,12 +382,11 @@ Tensor ArrayFireBackend::argsort( const Dim axis, const SortMode sortMode ) { - if(sortMode != SortMode::Descending && sortMode != SortMode::Ascending) { + if(sortMode != SortMode::Descending && sortMode != SortMode::Ascending) throw std::invalid_argument( "Cannot sort ArrayFire tensor with given SortMode: " "only Descending and Ascending supported." 
); - } af::array values, indices; af::sort( diff --git a/flashlight/fl/tensor/backend/af/ArrayFireBinaryOps.cpp b/flashlight/fl/tensor/backend/af/ArrayFireBinaryOps.cpp index 6223ec6..033fed2 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireBinaryOps.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireBinaryOps.cpp @@ -19,14 +19,12 @@ namespace { unsigned nDim = std::max(lhs.ndim(), rhs.ndim()); for(unsigned i = 0; i < nDim; ++i) { - if(i + 1 > lhs.ndim() || i + 1 > rhs.ndim()) { + if(i + 1 > lhs.ndim() || i + 1 > rhs.ndim()) // One Shape has more dimensions than the other - will broadcast to the // smaller tensor continue; - } - if(lhs[i] != rhs[i] && lhs[i] != 1 && rhs[i] != 1) { + if(lhs[i] != rhs[i] && lhs[i] != 1 && rhs[i] != 1) return false; - } } return true; } @@ -44,19 +42,18 @@ namespace { if( lhs.shape() == rhs.shape() || (lhs.elements() <= 1 && rhs.elements() <= 1) - ) { + ) return toTensor( func(toArray(lhs), toArray(rhs)), lhs.ndim() ); - } - if(canBroadcast(lhs.shape(), rhs.shape())) { + if(canBroadcast(lhs.shape(), rhs.shape())) return toTensor( af::batchFunc(toArray(lhs), toArray(rhs), func), std::max(lhs.ndim(), rhs.ndim()) ); - } else { + else { std::stringstream ss; ss << "doBinaryOpOrBroadcast: cannot perform operation " "or broadcasting with tensors of shapes " diff --git a/flashlight/fl/tensor/backend/af/ArrayFireReductions.cpp b/flashlight/fl/tensor/backend/af/ArrayFireReductions.cpp index 593289f..bd00f5f 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireReductions.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireReductions.cpp @@ -28,39 +28,33 @@ namespace { const bool keepDims = false ) { auto arr = input; - for(int dim : axes) { + for(int dim : axes) arr = func(arr, dim); - } return fl::detail::condenseIndices(arr, keepDims); } unsigned getReducedNumDims(unsigned inSize, unsigned axisSize, const bool keepDims) { - if(keepDims) { + if(keepDims) return inSize; - } else { - if(inSize < axisSize) { + else { + if(inSize < axisSize) 
return 0; - } else { + else return inSize - axisSize; - } } } bool isAllAxisReduction(const Tensor& input, const std::vector& axes) { - if(input.ndim() == 0 || axes.empty()) { + if(input.ndim() == 0 || axes.empty()) return true; - } - if(input.ndim() != axes.size()) { + if(input.ndim() != axes.size()) return false; - } // Check that all dims are present auto _axes = axes; std::sort(_axes.begin(), _axes.end()); - for(size_t i = 0; i < _axes.size(); ++i) { - if(_axes[i] != i) { + for(size_t i = 0; i < _axes.size(); ++i) + if(_axes[i] != i) return false; - } - } return true; } } // namespace @@ -70,7 +64,7 @@ Tensor ArrayFireBackend::amin( const std::vector& axes, const bool keepDims ) { - if(isAllAxisReduction(input, axes)) { + if(isAllAxisReduction(input, axes)) // Reduce along all axes returning a singleton tensor // TODO: modify this to af::min to take advantage of the // ArrayFire reduce_all kernels once available @@ -80,12 +74,11 @@ Tensor ArrayFireBackend::amin( ), /* numDims = */ 0 ); - } else { + else return toTensor( afReduceAxes(toArray(input), axes, af::min, keepDims), getReducedNumDims(input.ndim(), axes.size(), keepDims) ); - } } Tensor ArrayFireBackend::amax( @@ -93,7 +86,7 @@ Tensor ArrayFireBackend::amax( const std::vector& axes, const bool keepDims ) { - if(isAllAxisReduction(input, axes)) { + if(isAllAxisReduction(input, axes)) // Reduce along all axes returning a singleton tensor // TODO: modify this to af::max to take advantage of the // ArrayFire reduce_all kernels once available @@ -103,12 +96,11 @@ Tensor ArrayFireBackend::amax( ), /* numDims = */ 0 ); - } else { + else return toTensor( afReduceAxes(toArray(input), axes, af::max, keepDims), getReducedNumDims(input.ndim(), axes.size(), keepDims) ); - } } void ArrayFireBackend::min( @@ -152,7 +144,7 @@ Tensor ArrayFireBackend::sum( const std::vector& axes, const bool keepDims ) { - if(isAllAxisReduction(input, axes)) { + if(isAllAxisReduction(input, axes)) // Reduce along all axes returning a 
singleton tensor // TODO: modify this to af::sum to take advantage of the // ArrayFire reduce_all kernels once available @@ -162,12 +154,11 @@ Tensor ArrayFireBackend::sum( ), /* numDims = */ 0 ); - } else { + else return toTensor( afReduceAxes(toArray(input), axes, af::sum, keepDims), getReducedNumDims(input.ndim(), axes.size(), keepDims) ); - } } Tensor ArrayFireBackend::cumsum(const Tensor& input, const unsigned axis) { @@ -208,7 +199,7 @@ Tensor ArrayFireBackend::mean( const std::vector& axes, const bool keepDims ) { - if(isAllAxisReduction(input, axes)) { + if(isAllAxisReduction(input, axes)) // Reduce along all axes returning a singleton tensor // TODO: modify this to af::mean to take advantage of the // ArrayFire reduce_all kernels once available @@ -218,7 +209,7 @@ Tensor ArrayFireBackend::mean( ), /* numDims = */ 0 ); - } else { + else return toTensor( afReduceAxes( toArray(input), @@ -228,7 +219,6 @@ Tensor ArrayFireBackend::mean( ), getReducedNumDims(input.ndim(), axes.size(), keepDims) ); - } } Tensor ArrayFireBackend::median( @@ -245,7 +235,7 @@ Tensor ArrayFireBackend::median( af::constant(median, 1), /* numDims = */ 0 ); - } else { + } else return toTensor( afReduceAxes( toArray(input), @@ -255,7 +245,6 @@ Tensor ArrayFireBackend::median( ), getReducedNumDims(input.ndim(), axes.size(), keepDims) ); - } } Tensor ArrayFireBackend::var( @@ -273,12 +262,12 @@ Tensor ArrayFireBackend::var( if(isAllAxisReduction(input, axes)) { double out = af::var(toArray(input), biasMode); return toTensor(af::constant(out, 1), /* numDims = */ 0); - } else if(axes.size() == 1) { + } else if(axes.size() == 1) return toTensor( detail::condenseIndices(af::var(arr, biasMode, axes[0]), keepDims), getReducedNumDims(input.ndim(), axes.size(), keepDims) ); - } else { + else { auto meanArr = mean(input, axes, /* keepDims = */ true); auto x = af::batchFunc(arr, toArray(meanArr), af::operator-); @@ -287,12 +276,10 @@ Tensor ArrayFireBackend::var( int denominator = 1; auto dims = 
arr.dims(); - for(auto dim : axes) { + for(auto dim : axes) denominator *= dims[dim]; - } - if(bias) { + if(bias) denominator--; - } x = x / denominator; return toTensor( @@ -313,7 +300,7 @@ Tensor ArrayFireBackend::std( // TODO: update to af::stdev once specialization is available double out = af::stdev(toArray(input), biasMode); return toTensor(af::constant(out, 1), /* numDims = */ 0); - } else if(axes.size() == 1) { + } else if(axes.size() == 1) // Use arrayfire default for one dimension which may be optimized // TODO: update this? stddev is deprecated. return toTensor( @@ -323,7 +310,6 @@ Tensor ArrayFireBackend::std( ), getReducedNumDims(input.ndim(), axes.size(), keepDims) ); - } return this->sqrt(this->var(input, axes, /* bias = */ bias, keepDims)); } @@ -393,7 +379,7 @@ Tensor ArrayFireBackend::any( const std::vector& axes, const bool keepDims ) { - if(isAllAxisReduction(input, axes)) { + if(isAllAxisReduction(input, axes)) // Reduce along all axes returning a singleton tensor // TODO: modify this to af::anyTrue to take advantage of the // ArrayFire reduce_all kernels once available @@ -403,12 +389,11 @@ Tensor ArrayFireBackend::any( ), /* numDims = */ 0 ); - } else { + else return toTensor( afReduceAxes(toArray(input), axes, af::anyTrue, keepDims), getReducedNumDims(input.ndim(), axes.size(), keepDims) ); - } } Tensor ArrayFireBackend::all( @@ -416,7 +401,7 @@ Tensor ArrayFireBackend::all( const std::vector& axes, const bool keepDims ) { - if(isAllAxisReduction(input, axes)) { + if(isAllAxisReduction(input, axes)) // Reduce along all axes returning a singleton tensor // TODO: modify this to af::allTrue to take advantage of the // ArrayFire reduce_all kernels once available @@ -426,11 +411,10 @@ Tensor ArrayFireBackend::all( ), /* numDims = */ 0 ); - } else { + else return toTensor( afReduceAxes(toArray(input), axes, af::allTrue, keepDims), getReducedNumDims(input.ndim(), axes.size(), keepDims) ); - } } } // namespace fl diff --git 
a/flashlight/fl/tensor/backend/af/ArrayFireShapeAndIndex.cpp b/flashlight/fl/tensor/backend/af/ArrayFireShapeAndIndex.cpp index 48317c5..397cd75 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireShapeAndIndex.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireShapeAndIndex.cpp @@ -27,22 +27,21 @@ Tensor ArrayFireBackend::transpose( const Tensor& tensor, const Shape& axes /* = {} */ ) { - if(tensor.ndim() == 1) { + if(tensor.ndim() == 1) return tensor; - } else if( - tensor.ndim() == 2 && (axes.ndim() == 0 || axes == Shape({1, 0}))) { + else if( + tensor.ndim() == 2 && (axes.ndim() == 0 || axes == Shape({1, 0}))) // fastpath for matrices return toTensor( af::transpose(toArray(tensor)), tensor.ndim() ); - } else if(axes.ndim() == 0) { + else if(axes.ndim() == 0) { std::vector dims(AF_MAX_DIMS); std::iota(std::begin(dims), std::end(dims), 0); // Compute the reversed dimensions for as many ndims as are in the input - for(unsigned i = 0; i < tensor.ndim(); ++i) { + for(unsigned i = 0; i < tensor.ndim(); ++i) dims[i] = tensor.ndim() - 1 - i; - } // flip all dimensions return toTensor( @@ -50,29 +49,26 @@ Tensor ArrayFireBackend::transpose( tensor.ndim() ); } else { - if(axes.ndim() > AF_MAX_DIMS) { + if(axes.ndim() > AF_MAX_DIMS) throw std::invalid_argument( "ArrayFire tensor transpose was given " "permutation dims with > 4 axes" ); - } - if(axes.ndim() != tensor.ndim()) { + if(axes.ndim() != tensor.ndim()) throw std::invalid_argument( "ArrayFire tensor transpose axes don't match tensor's for " "permutation - axes must have the same number of " "dimensions as the tensor" ); - } // reorder based on specified dimensions std::vector d(AF_MAX_DIMS); std::iota(std::begin(d), std::end(d), 0); for(size_t i = 0; i < axes.ndim(); ++i) { - if(axes[i] > tensor.ndim() - 1) { + if(axes[i] > tensor.ndim() - 1) throw std::invalid_argument( "ArrayFireBackend::transpose - given dimension is larger " "than the number of dimensions in the tensor" ); - } d[i] = axes[i]; } @@ -129,9 
+125,8 @@ Tensor ArrayFireBackend::concatenate( } unsigned numDims = tensors[0].ndim(); - if(axis > std::max(numDims - 1, 0u)) { + if(axis > std::max(numDims - 1, 0u)) numDims = axis + 1; - } // All tensors have the same numdims else AF would throw return toTensor(std::move(out), numDims); @@ -149,11 +144,10 @@ Tensor ArrayFireBackend::pad( const std::vector>& padWidths, const PadType type ) { - if(padWidths.size() > AF_MAX_DIMS) { + if(padWidths.size() > AF_MAX_DIMS) throw std::invalid_argument( "ArrayFireBackend::pad - given padWidths for more than 4 dimensions" ); - } // convert ((begin_1, end_1), ..., (begin_k, end_k)) to ((begin_1, ..., // begin_k), (end_1, ..., end_k)) for ArrayFire diff --git a/flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp b/flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp index 9f06309..b405112 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp @@ -27,16 +27,14 @@ namespace fl { const af::array& toArray(const Tensor& tensor) { - if(tensor.backendType() != TensorBackendType::ArrayFire) { + if(tensor.backendType() != TensorBackendType::ArrayFire) throw std::invalid_argument("toArray: tensor is not ArrayFire-backed"); - } return tensor.getAdapter().getHandle(); } af::array& toArray(Tensor& tensor) { - if(tensor.backendType() != TensorBackendType::ArrayFire) { + if(tensor.backendType() != TensorBackendType::ArrayFire) throw std::invalid_argument("toArray: tensor is not ArrayFire-backed"); - } return tensor.getAdapter().getHandle(); } @@ -244,12 +242,11 @@ void ArrayFireTensor::unlock() { bool ArrayFireTensor::isLocked() { bool res; auto err = af_is_locked_array(&res, getHandle().get()); - if(err != AF_SUCCESS) { + if(err != AF_SUCCESS) throw std::runtime_error( "ArrayFireTensor::isLocked - af_is_locked_array returned error: " + std::to_string(err) ); - } return res; } @@ -273,12 +270,11 @@ Tensor ArrayFireTensor::astype(const dtype type) { } Tensor 
ArrayFireTensor::index(const std::vector& indices) { - if(indices.size() > AF_MAX_DIMS) { + if(indices.size() > AF_MAX_DIMS) throw std::invalid_argument( "ArrayFire-backed tensor was indexed with > 4 elements:" "ArrayFire tensors support up to 4 dimensions." ); - } // TODO: vet and stress test this a lot more/add proper support for // multi-tensor @@ -289,18 +285,17 @@ Tensor ArrayFireTensor::index(const std::vector& indices) { && indices.front().type() == detail::IndexType::Tensor && indices.front().get().elements() == getHandle().elements(); std::vector afIndices; - if(completeTensorIndex) { + if(completeTensorIndex) afIndices = {af::index(0)}; - } else { + else { afIndices = {af::span, af::span, af::span, af::span}; // implicit spans } - if(indices.size() > afIndices.size()) { + if(indices.size() > afIndices.size()) throw std::logic_error( "ArrayFireTensor::index internal error - passed indiecs is larger " "than the number of af indices" ); - } // Fill in corresponding index types for each af index std::vector indexTypes(afIndices.size()); @@ -310,9 +305,8 @@ Tensor ArrayFireTensor::index(const std::vector& indices) { afIndices[i] = detail::flToAfIndex(indices[i]); } // If we're adding implicit spans, fill those indexTypes in - for(; i < afIndices.size(); ++i) { + for(; i < afIndices.size(); ++i) indexTypes[i] = detail::IndexType::Span; - } getHandle(); // if this tensor was a view, run indexing and promote @@ -320,17 +314,14 @@ Tensor ArrayFireTensor::index(const std::vector& indices) { // Compute numDums for the new Tensor unsigned newNumDims = numDims(); - if(completeTensorIndex) { + if(completeTensorIndex) // TODO/FIXME: compute this based on the number of els in the indexing // tensor(s) newNumDims = 1; - } else { - for(const auto& type : indexTypes) { - if(type == detail::IndexType::Literal) { + else + for(const auto& type : indexTypes) + if(type == detail::IndexType::Literal) newNumDims--; - } - } - } newNumDims = std::max(newNumDims, 1u); // can never 
index to a 0 dim tensor return fl::Tensor( @@ -436,22 +427,20 @@ af::array ArrayFireTensor::adjustInPlaceOperandDims(const Tensor& operand) { if(indices_ && indices_.value().size() == 1) { // This case is only reachable via tensor-based indexing or indexing on a // tensor via Tensor::flat() - if(numDims_ != 1) { + if(numDims_ != 1) throw std::invalid_argument( "ArrayFireTensor::adjustInPlaceOperandDims " "index size was 1 but tensor has greater than 1 dimension." ); - } } else if(indices_ && !indices_.value().empty()) { // All other indexing operations const auto& indices = indices_.value(); const auto& indexTypes = indexTypes_.value(); - if(indices.size() != indexTypes.size()) { + if(indices.size() != indexTypes.size()) throw std::invalid_argument( "ArrayFireTensor adjustInPlaceOperandDims - passed indices" " and indexTypes are of different sizes." ); - } // If the dimensions being indexed are 1 and collapsing them yields the same // shape as the operand, we can safely moddims, the operand, else there's a @@ -459,14 +448,12 @@ af::array ArrayFireTensor::adjustInPlaceOperandDims(const Tensor& operand) { // {4, 5, 6, 7}(span, span, 5) --> {4, 5, 1, 7} --> {4, 5, 7} // {4, 5, 6, 7}(4) --> {1, 5, 1, 7} --> {5, 1, 7, 1} std::vector indicesToCompress; - for(unsigned i = 0; i < indices.size(); ++i) { + for(unsigned i = 0; i < indices.size(); ++i) // If an index literal, the corresponding dimension in the indexed array // is 1, then we indexed the input to a dim of 1, so we can condense that // index - if(indexTypes[i] == IndexType::Literal) { + if(indexTypes[i] == IndexType::Literal) indicesToCompress.push_back(i); - } - } af::dim4 condensedDims(1, 1, 1, 1); af::dim4 postIdxDims = preIdxDims; @@ -487,11 +474,10 @@ af::array ArrayFireTensor::adjustInPlaceOperandDims(const Tensor& operand) { dim_t size; AF_CHECK(af_get_elements(&size, indices[i].get().idx.arr)); postIdxDims[i] = size; - } else if(indexTypes[i] == IndexType::Range) { + } else if(indexTypes[i] == 
IndexType::Range) postIdxDims[i] = af::seq(indices[i].get().idx.seq).size; - } else if(indexTypes[i] == IndexType::Literal) { + else if(indexTypes[i] == IndexType::Literal) postIdxDims[i] = 1; - } } condensedDims[outDimIdx] = postIdxDims[i]; outDimIdx++; @@ -500,18 +486,16 @@ af::array ArrayFireTensor::adjustInPlaceOperandDims(const Tensor& operand) { // Can modify the operand to work with the proxy or array input only by // removing singleton dimensions - if(condensedDims == operandDims) { + if(condensedDims == operandDims) newDims = postIdxDims; - } else { + else throw std::invalid_argument( "ArrayFireTensor adjustInPlaceOperandDims: can't apply operation " "in-place to indexed ArrayFireTensor - dimensions don't match." ); - } - } else { + } else // No indexing so no change in dimensions required newDims = operandDims; - } // af::moddims involves an eval. This will be fixed in AF 3.8.1/3.8.2 bool doModdims = operandArr.dims() != newDims; @@ -542,11 +526,11 @@ ASSIGN_OP_LITERALS(assign, = ); void ArrayFireTensor::assign(const Tensor& tensor) { std::visit( [&tensor, this](auto&& arr) { - if(indices_) { + if(indices_) // If this is an indexing op, do as other in-place ops with lvalue // temporaries as a result of indexing do arr.get(*this) = this->adjustInPlaceOperandDims(tensor); - } else { + else { // Not an indexing op - just assign the tensor, but make sure to // update the number of dims arr.get(*this) = toArray(tensor); @@ -619,13 +603,11 @@ void ArrayFireTensor::inPlaceAdd(const Tensor& tensor) { }; unsigned i = 0; - for(; i < indices_.value().size(); ++i) { + for(; i < indices_.value().size(); ++i) idxFunc(indices_.value()[i], i); - } // The kernel needs to be padded with spans for remaining dims - for(; i < AF_MAX_DIMS; ++i) { + for(; i < AF_MAX_DIMS; ++i) idxFunc(af::span, i); - } fl::detail::advancedIndex( toArray(tensor), diff --git a/flashlight/fl/tensor/backend/af/ArrayFireUnaryOps.cpp b/flashlight/fl/tensor/backend/af/ArrayFireUnaryOps.cpp index 
8415b8d..c12e1a3 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireUnaryOps.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireUnaryOps.cpp @@ -97,11 +97,10 @@ Tensor ArrayFireBackend::roll( const int shift, const unsigned axis ) { - if(axis > AF_MAX_DIMS) { + if(axis > AF_MAX_DIMS) throw std::invalid_argument( "ArrayFireBackend::roll - given axis > 3 - unsupported" ); - } std::vector shifts(AF_MAX_DIMS, 0); shifts[axis] = shift; return toTensor( diff --git a/flashlight/fl/tensor/backend/af/Utils.cpp b/flashlight/fl/tensor/backend/af/Utils.cpp index 5f0bdfe..946d0d1 100644 --- a/flashlight/fl/tensor/backend/af/Utils.cpp +++ b/flashlight/fl/tensor/backend/af/Utils.cpp @@ -94,23 +94,20 @@ af_topk_function flToAfTopKSortMode(SortMode sortMode) { } af::dim4 flToAfDims(const Shape& shape) { - if(shape.ndim() > 4) { + if(shape.ndim() > 4) throw std::invalid_argument( "flToAfDims: ArrayFire shapes can't be more than 4 dimensions" ); - } af::dim4 out(1, 1, 1, 1); - for(size_t i = 0; i < shape.ndim(); ++i) { + for(size_t i = 0; i < shape.ndim(); ++i) out.dims[i] = shape.dim(i); - } return out; } void afToFlDims(const af::dim4& d, const unsigned numDims, Shape& s) { - if(numDims > AF_MAX_DIMS) { + if(numDims > AF_MAX_DIMS) throw std::invalid_argument("afToFlDims - numDims > AF_MAX_DIMS"); - } auto& storage = s.get(); @@ -129,9 +126,8 @@ void afToFlDims(const af::dim4& d, const unsigned numDims, Shape& s) { } storage.resize(numDims); - for(unsigned i = 0; i < numDims; ++i) { + for(unsigned i = 0; i < numDims; ++i) s[i] = d[i]; - } } Shape afToFlDims(const af::dim4& d, const unsigned numDims) { @@ -147,11 +143,10 @@ af::seq flRangeToAfSeq(const fl::range& range) { // There could be have other empty sequence representations, e.g., (0, -1) // for axis with 1 element. In those cases, AF will throw internally -- // we can't throw here because these cases axis-size dependent. 
- if(optEnd.has_value() && optEnd.value() == start) { + if(optEnd.has_value() && optEnd.value() == start) throw std::runtime_error( "flRangeToAfSeq: AF seq can't represent empty sequence" ); - } return af::seq(start, end, range.stride()); } @@ -173,20 +168,18 @@ af::index flToAfIndex(const fl::Index& idx) { } af::dim4 condenseDims(const af::dim4& dims) { - if(dims.elements() == 0) { + if(dims.elements() == 0) return af::dim4(0); - } // Find the condensed shape af::dim4 newDims(1, 1, 1, 1); unsigned newDimIdx = 0; - for(unsigned i = 0; i < AF_MAX_DIMS; ++i) { + for(unsigned i = 0; i < AF_MAX_DIMS; ++i) if(dims[i] != 1) { // found a non-1 dim size - populate newDims newDims[newDimIdx] = dims[i]; newDimIdx++; } - } return newDims; } @@ -197,13 +190,11 @@ af::array condenseIndices( const bool isFlat /* = false */ ) { // Fast path - return the Array as is if keepDims - don't consolidate - if(keepDims) { + if(keepDims) return arr; - } // Fast path - Array has zero elements or a dim of size zero - if(arr.elements() == 0) { + if(arr.elements() == 0) return arr; - } const af::dim4& dims = arr.dims(); af::dim4 newDims(1, 1, 1, 1); @@ -226,11 +217,10 @@ af::array condenseIndices( } // Only change dims if condensing is possible - if(newDims != arr.dims()) { + if(newDims != arr.dims()) return af::moddims(arr, newDims); - } else { + else return arr; - } } af_source flToAfLocation(Location location) { @@ -258,9 +248,8 @@ af::array fromFlData( af_source loc = detail::flToAfLocation(memoryLocation); // No or null buffer - if(!ptr) { + if(!ptr) return af::array(dims, afType); - } using af::dtype; switch(afType) { diff --git a/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp b/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp index 7dc12ce..91ddb82 100644 --- a/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp +++ b/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp @@ -39,30 +39,27 @@ namespace { constexpr double kMB = static_cast(1UL << 
20); size_t roundSize(size_t size) { - if(size < kMinBlockSize) { + if(size < kMinBlockSize) return kMinBlockSize; - } else { + else return kMinBlockSize * ((size + kMinBlockSize - 1) / kMinBlockSize); - } } size_t getAllocationSize(size_t size) { - if(size <= kSmallSize) { + if(size <= kSmallSize) return kSmallBuffer; - } else if(size < kMinLargeAlloc) { + else if(size < kMinLargeAlloc) return kLargeBuffer; - } else { + else return kRoundLarge * ((size + kRoundLarge - 1) / kRoundLarge); - } } static bool BlockComparator( const CachingMemoryManager::Block* a, const CachingMemoryManager::Block* b ) { - if(a->size_ != b->size_) { + if(a->size_ != b->size_) return a->size_ < b->size_; - } return (uintptr_t) a->ptr_ < (uintptr_t) b->ptr_; } @@ -110,12 +107,11 @@ CachingMemoryManager::CachingMemoryManager( getEnvAsBytesFromFloatMb(kMemRecyclingSize, recyclingSizeLimit_); splitSizeLimit_ = getEnvAsBytesFromFloatMb(kMemSplitSize, splitSizeLimit_); - for(int i = 0; i < numDevices; ++i) { + for(int i = 0; i < numDevices; ++i) deviceMemInfos_.emplace( i, std::make_unique(i) ); - } } void CachingMemoryManager::initialize() {} @@ -133,9 +129,8 @@ void CachingMemoryManager::shutdown() { } void CachingMemoryManager::addMemoryManagement(int device) { - if(deviceMemInfos_.find(device) != deviceMemInfos_.end()) { + if(deviceMemInfos_.find(device) != deviceMemInfos_.end()) return; - } deviceMemInfos_.emplace( device, std::make_unique(device) @@ -143,9 +138,8 @@ void CachingMemoryManager::addMemoryManagement(int device) { } void CachingMemoryManager::removeMemoryManagement(int device) { - if(deviceMemInfos_.find(device) == deviceMemInfos_.end()) { + if(deviceMemInfos_.find(device) == deviceMemInfos_.end()) return; - } deviceMemInfos_.erase(device); } @@ -158,12 +152,10 @@ void* CachingMemoryManager::alloc( auto& memoryInfo = getDeviceMemoryInfo(); std::lock_guard lock(memoryInfo.mutexAll_); size_t size = elementSize; - for(unsigned i = 0; i < ndims; ++i) { + for(unsigned i = 0; i < 
ndims; ++i) size *= dims[i]; - } - if(size == 0) { + if(size == 0) return nullptr; - } size = roundSize(size); const bool isSmallAlloc = (size <= kSmallSize); CachingMemoryManager::Block searchKey(size); @@ -203,9 +195,8 @@ void* CachingMemoryManager::alloc( remaining = block; block = new Block(size, block->ptr_); block->prev_ = remaining->prev_; - if(block->prev_) { + if(block->prev_) block->prev_->next_ = block; - } block->next_ = remaining; remaining->prev_ = block; @@ -222,22 +213,19 @@ void* CachingMemoryManager::alloc( } size_t CachingMemoryManager::allocated(void* ptr) { - if(!ptr) { + if(!ptr) return 0; - } auto& memoryInfo = getDeviceMemoryInfo(); std::lock_guard lock(memoryInfo.mutexAll_); auto it = memoryInfo.allocatedBlocks_.find(ptr); - if(it == memoryInfo.allocatedBlocks_.end()) { + if(it == memoryInfo.allocatedBlocks_.end()) return 0; - } return (it->second)->size_; } void CachingMemoryManager::unlock(void* ptr, bool userUnlock) { - if(!ptr) { + if(!ptr) return; - } auto& memoryInfo = getDeviceMemoryInfo(); std::lock_guard lock(memoryInfo.mutexAll_); auto it = memoryInfo.allocatedBlocks_.find(ptr); @@ -249,24 +237,21 @@ void CachingMemoryManager::unlock(void* ptr, bool userUnlock) { } CachingMemoryManager::Block* block = it->second; - if(userUnlock) { + if(userUnlock) block->userLock_ = false; - } else { + else block->managerLock_ = false; - } // Return early if either one is locked - if(block->inUse()) { + if(block->inUse()) return; - } memoryInfo.allocatedBlocks_.erase(it); freeBlock(block); } void CachingMemoryManager::freeBlock(CachingMemoryManager::Block* block) { - if(block->inUse()) { + if(block->inUse()) throw std::runtime_error("trying to free a block which is in use"); - } auto& memoryInfo = getDeviceMemoryInfo(); std::lock_guard lock(memoryInfo.mutexAll_); @@ -286,20 +271,17 @@ void CachingMemoryManager::tryMergeBlocks( CachingMemoryManager::Block* src, BlockSet& pool ) { - if(!src || src->inUse()) { + if(!src || src->inUse()) return; - } 
if(dst->prev_ == src) { dst->ptr_ = src->ptr_; dst->prev_ = src->prev_; - if(dst->prev_) { + if(dst->prev_) dst->prev_->next_ = dst; - } } else { dst->next_ = src->next_; - if(dst->next_) { + if(dst->next_) dst->next_->prev_ = dst; - } } dst->size_ += src->size_; pool.erase(src); @@ -357,9 +339,8 @@ void CachingMemoryManager::freeBlocks( ++it; blocks.erase(cur); delete block; - } else { + } else ++it; - } } } @@ -412,9 +393,8 @@ void CachingMemoryManager::printInfo( } void CachingMemoryManager::userLock(const void* ptr) { - if(!ptr) { + if(!ptr) return; - } auto& memoryInfo = getDeviceMemoryInfo(); std::lock_guard lock(memoryInfo.mutexAll_); @@ -425,9 +405,8 @@ void CachingMemoryManager::userLock(const void* ptr) { block->managerLock_ = false; block->userLock_ = true; memoryInfo.allocatedBlocks_[block->ptr_] = block; - } else { + } else it->second->userLock_ = true; - } } void CachingMemoryManager::userUnlock(const void* ptr) { @@ -435,26 +414,22 @@ void CachingMemoryManager::userUnlock(const void* ptr) { } bool CachingMemoryManager::isUserLocked(const void* ptr) { - if(!ptr) { + if(!ptr) return false; - } auto& memoryInfo = getDeviceMemoryInfo(); std::lock_guard lock(memoryInfo.mutexAll_); auto it = memoryInfo.allocatedBlocks_.find(const_cast(ptr)); - if(it == memoryInfo.allocatedBlocks_.end()) { + if(it == memoryInfo.allocatedBlocks_.end()) return false; - } return it->second->userLock_; } CachingMemoryManager::DeviceMemoryInfo& CachingMemoryManager::getDeviceMemoryInfo(int device /* = -1*/) { - if(device == -1) { + if(device == -1) device = this->deviceInterface->getActiveDeviceId(); - } auto it = deviceMemInfos_.find(device); - if(it == deviceMemInfos_.end() || !it->second) { + if(it == deviceMemInfos_.end() || !it->second) throw std::runtime_error("meminfo for the device doesn't exist"); - } return *(it->second); } } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp 
b/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp index 31be137..959f8f9 100644 --- a/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp +++ b/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp @@ -31,9 +31,8 @@ DefaultMemoryManager::MemoryInfo& DefaultMemoryManager::getCurrentMemoryInfo() { } void DefaultMemoryManager::cleanDeviceMemoryManager(int device) { - if(this->debugMode) { + if(this->debugMode) return; - } // This vector is used to store the pointers which will be deleted by // the memory manager. We are using this to avoid calling free while @@ -44,9 +43,8 @@ void DefaultMemoryManager::cleanDeviceMemoryManager(int device) { { std::lock_guard lock(this->memoryMutex); // Return if all buffers are locked - if(current.totalBuffers == current.lockBuffers) { + if(current.totalBuffers == current.lockBuffers) return; - } freePtrs.reserve(current.freeMap.size()); for(auto& kv : current.freeMap) { @@ -71,9 +69,8 @@ void DefaultMemoryManager::cleanDeviceMemoryManager(int device) { this->log(ss.str()); // Free memory outside of the lock - for(auto ptr : freePtrs) { + for(auto ptr : freePtrs) this->deviceInterface->nativeFree(ptr); - } } DefaultMemoryManager::DefaultMemoryManager( @@ -88,17 +85,14 @@ DefaultMemoryManager::DefaultMemoryManager( memory(numDevices) { // Check for environment variables // Debug mode - if(const char* c = std::getenv("AF_MEM_DEBUG")) { + if(const char* c = std::getenv("AF_MEM_DEBUG")) this->debugMode = (std::string(c) != "0"); - } - if(this->debugMode) { + if(this->debugMode) memStepSize = 1; - } // Max Buffer count - if(const char* c = std::getenv("AF_MAX_BUFFERS")) { + if(const char* c = std::getenv("AF_MAX_BUFFERS")) this->maxBuffers = std::max(1, std::stoi(std::string(c))); - } } void DefaultMemoryManager::initialize() { @@ -112,9 +106,8 @@ void DefaultMemoryManager::shutdown() { void DefaultMemoryManager::addMemoryManagement(int device) { // If there is a memory manager allocated for this device id, we 
might // as well use it and the buffers allocated for it - if(static_cast(device) < memory.size()) { + if(static_cast(device) < memory.size()) return; - } // Assuming, device need not be always the next device Lets resize to // current_size + device + 1 +1 is to account for device being 0-based @@ -123,9 +116,8 @@ void DefaultMemoryManager::addMemoryManagement(int device) { } void DefaultMemoryManager::removeMemoryManagement(int device) { - if((size_t) device >= memory.size()) { + if((size_t) device >= memory.size()) throw std::runtime_error("No matching device found"); - } // Do garbage collection for the device and leave the // MemoryInfo struct from the memory vector intact @@ -151,9 +143,8 @@ void* DefaultMemoryManager::alloc( const unsigned elementSize ) { size_t bytes = elementSize; - for(unsigned i = 0; i < ndims; ++i) { + for(unsigned i = 0; i < ndims; ++i) bytes *= dims[i]; - } void* ptr = nullptr; size_t allocBytes = @@ -170,9 +161,8 @@ void* DefaultMemoryManager::alloc( if( current.lockBytes >= current.maxBytes || current.totalBuffers >= this->maxBuffers - ) { + ) this->signalMemoryCleanup(); - } std::lock_guard lock(this->memoryMutex); free_iter iter = current.freeMap.find(allocBytes); @@ -215,22 +205,19 @@ void* DefaultMemoryManager::alloc( } size_t DefaultMemoryManager::allocated(void* ptr) { - if(!ptr) { + if(!ptr) return 0; - } MemoryInfo& current = this->getCurrentMemoryInfo(); locked_iter iter = current.lockedMap.find((void*) ptr); - if(iter == current.lockedMap.end()) { + if(iter == current.lockedMap.end()) return 0; - } return (iter->second).bytes; } void DefaultMemoryManager::unlock(void* ptr, bool userUnlock) { // Shortcut for empty arrays - if(!ptr) { + if(!ptr) return; - } // Frees the pointer outside the lock. 
uptr_t freedPtr( @@ -248,16 +235,14 @@ void DefaultMemoryManager::unlock(void* ptr, bool userUnlock) { return; } - if(userUnlock) { + if(userUnlock) (iter->second).userLock = false; - } else { + else (iter->second).managerLock = false; - } // Return early if either one is locked - if((iter->second).userLock || (iter->second).managerLock) { + if((iter->second).userLock || (iter->second).managerLock) return; - } size_t bytes = iter->second.bytes; current.lockBytes -= iter->second.bytes; @@ -270,9 +255,8 @@ void DefaultMemoryManager::unlock(void* ptr, bool userUnlock) { current.totalBuffers--; current.totalBytes -= iter->second.bytes; } - } else { + } else current.freeMap.at(bytes).emplace_back(ptr); - } current.lockedMap.erase(iter); } } @@ -287,11 +271,10 @@ float DefaultMemoryManager::getMemoryPressure() { if( current.lockBytes > current.maxBytes || current.lockBuffers > maxBuffers - ) { + ) return 1.0; - } else { + else return 0.0; - } } bool DefaultMemoryManager::jitTreeExceedsMemoryPressure(size_t bytes) { @@ -317,11 +300,10 @@ void DefaultMemoryManager::printInfo( for(auto& kv : current.lockedMap) { const char* statusMngr = "Yes"; const char* statusUser = "Unknown"; - if(kv.second.userLock) { + if(kv.second.userLock) statusUser = "Yes"; - } else { + else statusUser = " No"; - } const char* unit = "KB"; double size = static_cast(kv.second.bytes) / 1024; @@ -345,10 +327,9 @@ void DefaultMemoryManager::printInfo( unit = "MB"; } - for(auto& ptr : kv.second) { + for(auto& ptr : kv.second) ostream << "| " << ptr << " | " << size << " " << unit << " | " << statusMngr << " | " << statusUser << " |\n"; - } } ostream << "---------------------------------------------------------\n"; @@ -360,9 +341,9 @@ void DefaultMemoryManager::userLock(const void* ptr) { std::lock_guard lock(this->memoryMutex); locked_iter iter = current.lockedMap.find(const_cast(ptr)); - if(iter != current.lockedMap.end()) { + if(iter != current.lockedMap.end()) iter->second.userLock = true; - } else { 
+ else { LockedInfo info = {false, true, 100}; // This number is not relevant current.lockedMap[(void*) ptr] = info; @@ -377,11 +358,10 @@ bool DefaultMemoryManager::isUserLocked(const void* ptr) { MemoryInfo& current = this->getCurrentMemoryInfo(); std::lock_guard lock(this->memoryMutex); locked_iter iter = current.lockedMap.find(const_cast(ptr)); - if(iter != current.lockedMap.end()) { + if(iter != current.lockedMap.end()) return iter->second.userLock; - } else { + else return false; - } } size_t DefaultMemoryManager::getMemStepSize() { diff --git a/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.cpp b/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.cpp index cdafd2d..8aea02a 100644 --- a/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.cpp +++ b/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.cpp @@ -19,15 +19,13 @@ MemoryManagerAdapter::MemoryManagerAdapter( std::ostream* logStream ) : deviceInterface(itf), logStream_(logStream) { - if(!itf) { + if(!itf) throw std::invalid_argument( "MemoryManagerAdapter::MemoryManagerAdapter - " "memory manager device interface is null" ); - } - if(logStream_) { + if(logStream_) loggingEnabled_ = true; - } // Create handle and set payload to point to this instance AF_CHECK(af_create_memory_manager(&interface_)); @@ -59,12 +57,11 @@ void MemoryManagerAdapter::setLoggingEnabled(bool log) { } void MemoryManagerAdapter::setLogFlushInterval(size_t interval) { - if(interval < 1) { + if(interval < 1) throw std::invalid_argument( "MemoryManagerAdapter::setLogFlushInterval - " "flush interval must be great than zero." 
); - } logFlushInterval_ = interval; } diff --git a/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.h b/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.h index 7ece2fe..aa31e09 100644 --- a/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.h +++ b/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.h @@ -166,12 +166,11 @@ class MemoryManagerAdapter { template void MemoryManagerAdapter::log(std::string fname, Values... vs) { if(loggingEnabled_) { - if(!logStream_) { + if(!logStream_) throw std::runtime_error( "MemoryManagerAdapter::log: cannot write to logStream_" " - stream is invalid or uninitialized" ); - } logStreamBuffer_ << fname << " "; int unpack[]{0, (logStreamBuffer_ << std::to_string(vs) << " ", 0)...}; static_cast(unpack); diff --git a/flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.cpp b/flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.cpp index d6d6859..b34aa55 100644 --- a/flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.cpp +++ b/flashlight/fl/tensor/backend/af/mem/MemoryManagerInstaller.cpp @@ -33,20 +33,18 @@ MemoryManagerAdapter* MemoryManagerInstaller::getImpl( MemoryManagerInstaller::MemoryManagerInstaller( std::shared_ptr managerImpl ) : impl_(managerImpl) { - if(!impl_) { + if(!impl_) throw std::invalid_argument( "MemoryManagerInstaller::MemoryManagerInstaller - " "passed MemoryManagerAdapter is null" ); - } af_memory_manager itf = impl_->getHandle(); - if(!impl_->getHandle()) { + if(!impl_->getHandle()) throw std::invalid_argument( "MemoryManagerInstaller::MemoryManagerInstaller - " "passed MemoryManagerAdapter has null handle" ); - } // Set appropriate function pointers for each class method auto initializeFn = [](af_memory_manager manager) { diff --git a/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp b/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp index d0b1342..26961f6 100644 --- a/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp +++ 
b/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp @@ -143,9 +143,8 @@ TEST(AutogradBinaryOpsTest, Linear) { } TEST_F(AutogradTestF16, LinearF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } std::vector batchsizes = {1, 5}; const float scale = 4.0; // scale prevent grad underflow diff --git a/flashlight/fl/test/autograd/AutogradConv2DTest.cpp b/flashlight/fl/test/autograd/AutogradConv2DTest.cpp index 21329a5..0ade290 100644 --- a/flashlight/fl/test/autograd/AutogradConv2DTest.cpp +++ b/flashlight/fl/test/autograd/AutogradConv2DTest.cpp @@ -80,9 +80,8 @@ TEST(AutogradConv2DTest, Convolve) { } TEST_F(AutogradTestF16, ConvolveF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } const float scaleFactor = 10.0; // scale the input to prevent grad underflow auto in = diff --git a/flashlight/fl/test/autograd/AutogradNormalizationTest.cpp b/flashlight/fl/test/autograd/AutogradNormalizationTest.cpp index 51f69b7..5eb4099 100644 --- a/flashlight/fl/test/autograd/AutogradNormalizationTest.cpp +++ b/flashlight/fl/test/autograd/AutogradNormalizationTest.cpp @@ -104,9 +104,8 @@ TEST(AutogradNormalizationTest, BatchNormEvalModeOutputMultipleAxis) { auto input = Variable(fl::rand({13, 13, 4, 16}), false); auto nfeatures = 1; - for(auto ax : featAxes) { + for(auto ax : featAxes) nfeatures *= input.dim(ax); - } auto runningMean = Variable(fl::rand({nfeatures}, input.type()), false); auto runningVar = Variable(fl::rand({nfeatures}, input.type()), false); auto weight = Variable(fl::rand({nfeatures}, input.type()), false); @@ -203,9 +202,8 @@ TEST(AutogradNormalizationTest, BatchNormTrainModeOutputMultipleAxis) { auto input = Variable(fl::rand({13, 13, 4, 8}), true); auto nfeatures = 1; - for(auto ax : featAxes) { + for(auto ax : featAxes) nfeatures *= input.dim(ax); - } auto weight = Variable(fl::rand({nfeatures}), true); auto 
bias = Variable(fl::rand({nfeatures}), true); auto runningMean = Variable(fl::rand({nfeatures}), false); @@ -305,9 +303,8 @@ TEST(AutogradNormalizationTest, BatchNormJacobian) { } TEST_F(AutogradTestF16, BatchNormJacobianF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } // Jacobian Test with trainMode = true; @@ -372,9 +369,8 @@ TEST(AutogradNormalizationTest, BatchNormJacobianMultipleAxes) { std::vector featAxes = {0, 1, 2}; auto input = Variable(fl::rand({4, 4, 3, 4}, fl::dtype::f32), true); auto nfeatures = 1; - for(auto ax : featAxes) { + for(auto ax : featAxes) nfeatures *= input.dim(ax); - } auto runningMean = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); auto runningVar = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); auto weight = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); @@ -427,17 +423,15 @@ TEST(AutogradNormalizationTest, BatchNormJacobianMultipleAxes) { } TEST_F(AutogradTestF16, BatchNormJacobianMultipleAxesF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } // Jacobian Test with trainMode = true; std::vector featAxes = {0, 1, 2}; auto input = Variable(fl::rand({2, 2, 2, 1}, fl::dtype::f16), true); auto nfeatures = 1; - for(auto ax : featAxes) { + for(auto ax : featAxes) nfeatures *= input.dim(ax); - } auto runningMean = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); auto runningVar = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); auto weight = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); @@ -503,9 +497,8 @@ TEST(AutogradNormalizationTest, LayerNormJacobian) { std::vector featAxes = {0, 1, 2, 3}; auto input = Variable(fl::rand({7, 7, 3, 10}), true); auto nfeatures = 1; - for(auto ax : featAxes) { + for(auto ax : featAxes) nfeatures *= input.dim(ax); - } auto runningMean = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); auto runningVar 
= Variable(fl::rand({nfeatures}, fl::dtype::f32), false); auto weight = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); @@ -529,18 +522,16 @@ TEST(AutogradNormalizationTest, LayerNormJacobian) { } TEST_F(AutogradTestF16, LayerNormJacobianF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } std::vector featAxes = {0, 1, 2, 3}; const float inputScale = 4.0; // scale the input to prevent grad underflow auto input = Variable(inputScale * fl::rand({2, 2, 2, 4}, fl::dtype::f16), true); auto nfeatures = 1; - for(auto ax : featAxes) { + for(auto ax : featAxes) nfeatures *= input.dim(ax); - } auto runningMean = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); auto runningVar = Variable(fl::rand({nfeatures}, fl::dtype::f32), false); auto weight = Variable(fl::rand({nfeatures}, fl::dtype::f32), true); diff --git a/flashlight/fl/test/autograd/AutogradReductionTest.cpp b/flashlight/fl/test/autograd/AutogradReductionTest.cpp index c0d1690..c9425ed 100644 --- a/flashlight/fl/test/autograd/AutogradReductionTest.cpp +++ b/flashlight/fl/test/autograd/AutogradReductionTest.cpp @@ -20,9 +20,8 @@ using fl::detail::AutogradTestF16; TEST(AutogradReductionTest, Sum) { for(const bool keepDims : {false, true}) { Shape s = {6}; - if(keepDims) { + if(keepDims) s = {6, 1}; - } auto x = Variable(fl::rand(s), true); auto y = Variable(fl::rand({6, 3}), true); @@ -113,7 +112,7 @@ TEST(AutogradReductionTest, Mean) { TEST(AutogradReductionTest, Variance) { std::vector biased = {true, false}; - for(auto b : biased) { + for(auto b : biased) for(const bool keepDims : {false, true}) { auto x = Variable(fl::rand({5, 6, 7, 8}, fl::dtype::f64), true); @@ -134,7 +133,6 @@ TEST(AutogradReductionTest, Variance) { }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcVar, x, 1E-5, 1E-5)); } - } } TEST(AutogradReductionTest, Norm) { diff --git a/flashlight/fl/test/autograd/AutogradRnnTest.cpp 
b/flashlight/fl/test/autograd/AutogradRnnTest.cpp index 08337d2..1195583 100644 --- a/flashlight/fl/test/autograd/AutogradRnnTest.cpp +++ b/flashlight/fl/test/autograd/AutogradRnnTest.cpp @@ -184,48 +184,42 @@ void testRnnImpl(RnnMode mode, fl::dtype precision = fl::dtype::f64) { } TEST(AutogradRnnTest, Rnn) { - if(FL_BACKEND_CPU) { + if(FL_BACKEND_CPU) GTEST_SKIP() << "RNN gradient computation not yet supported on CPU"; - } testRnnImpl(RnnMode::TANH); } TEST(AutogradRnnTest, Lstm) { - if(FL_BACKEND_CPU) { + if(FL_BACKEND_CPU) GTEST_SKIP() << "RNN LSTM graident computation not yet supported on CPU"; - } testRnnImpl(RnnMode::LSTM); } TEST(AutogradRnnTest, Gru) { - if(FL_BACKEND_CPU) { + if(FL_BACKEND_CPU) GTEST_SKIP() << "RNN GRU graident computation not yet supported on CPU"; - } testRnnImpl(RnnMode::GRU); } TEST_F(AutogradTestF16, RnnF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } testRnnImpl(RnnMode::TANH, fl::dtype::f16); } TEST_F(AutogradTestF16, LstmF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } testRnnImpl(RnnMode::LSTM, fl::dtype::f16); } TEST_F(AutogradTestF16, GruF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } testRnnImpl(RnnMode::GRU, fl::dtype::f16); } diff --git a/flashlight/fl/test/autograd/AutogradTest.cpp b/flashlight/fl/test/autograd/AutogradTest.cpp index 46015a8..ffee958 100644 --- a/flashlight/fl/test/autograd/AutogradTest.cpp +++ b/flashlight/fl/test/autograd/AutogradTest.cpp @@ -32,9 +32,8 @@ TEST(AutogradTest, OperatorParenthesis) { } TEST(AutogradTest, AutogradOperatorTypeCompatibility) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } auto f16 = Variable(fl::rand({2, 2}, fl::dtype::f16), true); auto f32 = Variable(fl::rand({2, 2}, 
fl::dtype::f32), true); @@ -192,9 +191,8 @@ TEST(AutogradTest, AutogradOperatorTypeCompatibility) { } TEST(AutogradTest, CastingAsDifferentGradTypes) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } auto f32 = Variable(fl::rand({5, 5}), true); auto f16 = Variable(fl::rand({5, 5}, fl::dtype::f16), true); @@ -207,9 +205,8 @@ TEST(AutogradTest, CastingAsDifferentGradTypes) { } TEST(AutogradTest, CastingAs) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } auto var = Variable(fl::rand({5, 5}), true); auto varF16 = var.astype(fl::dtype::f16); @@ -219,9 +216,8 @@ TEST(AutogradTest, CastingAs) { } TEST(AutogradTest, CastingAsBackward) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } auto a = Variable(fl::rand({4, 4}, fl::dtype::f16), true); auto b = Variable(fl::rand({4, 4}, fl::dtype::f16), false); @@ -234,9 +230,8 @@ TEST(AutogradTest, CastingAsBackward) { } TEST(AutogradTest, CastingAsGrad) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } // compare to f32 case auto x = Variable(fl::full({5}, 2.0), true); @@ -361,9 +356,8 @@ TEST(AutogradTest, TileAs) { } TEST_F(AutogradTestF16, TileAsF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } auto x = Variable(fl::rand({5}, fl::dtype::f16), true); auto y = Variable(fl::rand({5, 2}, fl::dtype::f16), true); @@ -441,9 +435,8 @@ TEST(AutogradTest, Pooling) { } TEST_F(AutogradTestF16, PoolingF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } const float inputScale = 2.0; // scale the input to prevent grad underflow auto in = Variable(inputScale * fl::rand({3, 3, 1, 1}, fl::dtype::f16), 
true); @@ -470,10 +463,9 @@ TEST(AutogradTest, Embedding) { TEST(AutogradTest, GetAdvancedIndex) { // TODO: remove me - if(!FL_BACKEND_CUDA) { + if(!FL_BACKEND_CUDA) GTEST_SKIP() << "Advanced indexing operator unsupported for non-CUDA backends"; - } std::vector validIndexTypes = { fl::dtype::s32, fl::dtype::s64, fl::dtype::u32, fl::dtype::u64}; for(const auto& dtype : validIndexTypes) { @@ -500,13 +492,11 @@ TEST(AutogradTest, GetAdvancedIndex) { TEST(AutogradTest, GetAdvancedIndexF16) { // TODO: remove me - if(!FL_BACKEND_CUDA) { + if(!FL_BACKEND_CUDA) GTEST_SKIP() << "Advanced indexing operator unsupported for non-CUDA backends"; - } - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } std::vector validIndexTypes = { fl::dtype::s32, fl::dtype::s64, fl::dtype::u32, fl::dtype::u64}; for(const auto& dtype : validIndexTypes) { diff --git a/flashlight/fl/test/autograd/AutogradTestUtils.h b/flashlight/fl/test/autograd/AutogradTestUtils.h index 99f229a..6c4cf55 100644 --- a/flashlight/fl/test/autograd/AutogradTestUtils.h +++ b/flashlight/fl/test/autograd/AutogradTestUtils.h @@ -61,9 +61,8 @@ namespace detail { for(int i = 0; i < dout.elements(); ++i) { dout.tensor().flat(i) = 1; // element in 1D view input.zeroGrad(); - for(auto* var : zeroGradientVariables) { + for(auto* var : zeroGradientVariables) var->zeroGrad(); - } auto out = func(input); out.backward(dout); diff --git a/flashlight/fl/test/autograd/AutogradUnaryOpsTest.cpp b/flashlight/fl/test/autograd/AutogradUnaryOpsTest.cpp index 0ea5d22..dddba75 100644 --- a/flashlight/fl/test/autograd/AutogradUnaryOpsTest.cpp +++ b/flashlight/fl/test/autograd/AutogradUnaryOpsTest.cpp @@ -134,9 +134,8 @@ TEST(AutogradUnaryOpsTest, Softmax) { } TEST_F(AutogradTestF16, SoftmaxF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } auto in = Variable(fl::rand({3, 5, 1}, fl::dtype::f16), 
true); auto funcSm = [&](Variable& input) { return softmax(input, 0); }; @@ -152,9 +151,8 @@ TEST(AutogradUnaryOpsTest, LogSoftmax) { } TEST_F(AutogradTestF16, LogSoftmaxF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } auto in = Variable(fl::rand({3, 5, 1}, fl::dtype::f16), true); auto funcLsm = [&](Variable& input) { return logSoftmax(input, 0); }; diff --git a/flashlight/fl/test/common/DynamicBenchmarkTest.cpp b/flashlight/fl/test/common/DynamicBenchmarkTest.cpp index 06b811c..d22b1ec 100644 --- a/flashlight/fl/test/common/DynamicBenchmarkTest.cpp +++ b/flashlight/fl/test/common/DynamicBenchmarkTest.cpp @@ -36,9 +36,8 @@ TEST_F(DynamicBenchmark, OptionsStateBasic) { ASSERT_FALSE(options->timingsComplete()); ASSERT_EQ(options->currentOption(), 1); - for(size_t i = 0; i < maxCount * ops.size(); ++i) { + for(size_t i = 0; i < maxCount * ops.size(); ++i) options->accumulateTimeToCurrentOption(1); - } ASSERT_TRUE(options->timingsComplete()); ASSERT_EQ(options->currentOption(), 1); // best idx should never have changed } @@ -64,9 +63,9 @@ TEST_F(DynamicBenchmark, OptionsStateTimed) { for(size_t i = 0; i < maxCount * ops.size(); ++i) { // option 4 is faster - if(options->currentOption() == 4) { + if(options->currentOption() == 4) options->accumulateTimeToCurrentOption(1); - } else { + else { options->accumulateTimeToCurrentOption( 10 * (i + 1), /* incrementCount = */ false diff --git a/flashlight/fl/test/common/HistogramTest.cpp b/flashlight/fl/test/common/HistogramTest.cpp index 5bdaab9..57ecffc 100644 --- a/flashlight/fl/test/common/HistogramTest.cpp +++ b/flashlight/fl/test/common/HistogramTest.cpp @@ -30,9 +30,8 @@ TEST(FixedBucketSizeHistogram, NormalDistribution) { std::normal_distribution distribution(mean, stddev); std::vector data(nValues); - for(int i = 0; i < nValues; ++i) { + for(int i = 0; i < nValues; ++i) data[i] = distribution(generator); - } HistogramStats hist = 
FixedBucketSizeHistogram(data.begin(), data.end(), nBuckes); @@ -57,9 +56,8 @@ TEST(FixedBucketSizeHistogram, NormalDistribution) { // Verify bounds span the range. EXPECT_EQ(hist.buckets[0].startInclusive, hist.min); - for(int i = 0; i < (nBuckes - 1); ++i) { + for(int i = 0; i < (nBuckes - 1); ++i) EXPECT_EQ(hist.buckets[i + 1].startInclusive, hist.buckets[i].endExclusive); - } EXPECT_EQ(hist.buckets[nBuckes - 1].endExclusive, hist.max); std::cout << hist.prettyString() << std::endl; @@ -81,9 +79,8 @@ TEST(FixedBucketSizeHistogram, ExponentialDistribution) { std::exponential_distribution distribution(0.1); std::vector data(nValues); - for(int i = 0; i < nValues; ++i) { + for(int i = 0; i < nValues; ++i) data[i] = distribution(generator) * multiplier; - } HistogramStats hist = FixedBucketSizeHistogram(data.begin(), data.end(), nBuckes); @@ -97,15 +94,13 @@ TEST(FixedBucketSizeHistogram, ExponentialDistribution) { EXPECT_GT(hist.maxNumValuesPerBucket, nValues / nBuckes); // Verify exponential distribution. - for(int i = 0; i < (nBuckes - 1); ++i) { + for(int i = 0; i < (nBuckes - 1); ++i) EXPECT_GT(hist.buckets[i].count, hist.buckets[i + 1].count); - } // Verify bounds span the range. 
EXPECT_EQ(hist.buckets[0].startInclusive, hist.min); - for(int i = 0; i < (nBuckes - 1); ++i) { + for(int i = 0; i < (nBuckes - 1); ++i) EXPECT_EQ(hist.buckets[i + 1].startInclusive, hist.buckets[i].endExclusive); - } EXPECT_GE(hist.buckets[nBuckes - 1].endExclusive, hist.max); std::cout << hist.prettyString() << std::endl; diff --git a/flashlight/fl/test/common/LoggingTest.cpp b/flashlight/fl/test/common/LoggingTest.cpp index 6a8a784..ea9cdea 100644 --- a/flashlight/fl/test/common/LoggingTest.cpp +++ b/flashlight/fl/test/common/LoggingTest.cpp @@ -43,17 +43,15 @@ TEST(Logging, vlogOnOff) { // Prints to stderr EXPECT_THAT(stderrBuffer.str(), HasSubstr("vlog-0")); - if(i >= 1) { + if(i >= 1) EXPECT_THAT(stderrBuffer.str(), HasSubstr("vlog-1")); - } else { + else EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("vlog-1"))); - } - if(i >= 10) { + if(i >= 10) EXPECT_THAT(stderrBuffer.str(), HasSubstr("vlog-10")); - } else { + else EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("vlog-10"))); - } // Does not print to stdout EXPECT_THAT(stdoutBuffer.str(), Not(HasSubstr("vlog-0"))); @@ -95,27 +93,24 @@ TEST(Logging, logOnOff) { FL_LOG(fl::LogLevel::ERROR) << "log-error"; // Prints to stderr - if(l >= fl::LogLevel::INFO) { + if(l >= fl::LogLevel::INFO) EXPECT_THAT(stderrBuffer.str(), HasSubstr("log-info")); - } else { + else EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("log-info"))); - } - if(l >= fl::LogLevel::WARNING) { + if(l >= fl::LogLevel::WARNING) EXPECT_THAT(stderrBuffer.str(), HasSubstr("log-warning")); - } else { + else EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("log-warning"))); - } // Does not print to stdout EXPECT_THAT(stdoutBuffer.str(), Not(HasSubstr("log-info"))); EXPECT_THAT(stdoutBuffer.str(), Not(HasSubstr("log-warning"))); - if(l >= fl::LogLevel::ERROR) { + if(l >= fl::LogLevel::ERROR) EXPECT_THAT(stderrBuffer.str(), HasSubstr("log-error")); - } else { + else EXPECT_THAT(stderrBuffer.str(), Not(HasSubstr("log-error"))); - } } 
std::cout.rdbuf(origStdoutBuffer); diff --git a/flashlight/fl/test/common/UtilsTest.cpp b/flashlight/fl/test/common/UtilsTest.cpp index 055181e..282b3c2 100644 --- a/flashlight/fl/test/common/UtilsTest.cpp +++ b/flashlight/fl/test/common/UtilsTest.cpp @@ -19,11 +19,10 @@ using namespace fl; static std::function makeSucceedsAfterIters(int iters) { auto state = std::make_shared(0); return [state, iters]() { - if(++*state >= iters) { + if(++*state >= iters) return 42; - } else { + else throw std::runtime_error("bleh"); - } }; } @@ -32,14 +31,12 @@ static std::function makeSucceedsAfterMs(double ms) { auto state = std::make_shared>(); return [state, ms]() { auto now = steady_clock::now(); - if(state->time_since_epoch().count() == 0) { + if(state->time_since_epoch().count() == 0) *state = now; - } - if(now - *state >= duration(ms)) { + if(now - *state >= duration(ms)) return 42; - } else { + else throw std::runtime_error("bleh"); - } }; } @@ -90,15 +87,12 @@ TEST(SystemTest, RetryWithBackoff) { invalids.push_back(retryAsync(ms50, 2.0, 0, alwaysSucceeds)); invalids.push_back(retryAsync(ms50, 2.0, -1, alwaysSucceeds)); - for(auto& fut : goods) { + for(auto& fut : goods) ASSERT_EQ(fut.get(), 42); - } - for(auto& fut : bads) { + for(auto& fut : bads) ASSERT_THROW(fut.get(), std::runtime_error); - } - for(auto& fut : invalids) { + for(auto& fut : invalids) ASSERT_THROW(fut.get(), std::invalid_argument); - } // check special case promise / future auto alwaysSucceedsVoid = []() -> void {}; diff --git a/flashlight/fl/test/contrib/modules/ContribModuleTest.cpp b/flashlight/fl/test/contrib/modules/ContribModuleTest.cpp index 54b05c5..d6295ef 100644 --- a/flashlight/fl/test/contrib/modules/ContribModuleTest.cpp +++ b/flashlight/fl/test/contrib/modules/ContribModuleTest.cpp @@ -157,9 +157,9 @@ void transformerPadMaskFwd(bool isfp16) { ASSERT_EQ(output.dim(1), timesteps); ASSERT_EQ(output.dim(2), 2); - if(OptimMode::get().getOptimLevel() == OptimLevel::O3) { + 
if(OptimMode::get().getOptimLevel() == OptimLevel::O3) ASSERT_EQ(outputNoPad.type(), input.type()); - } else { + else { ASSERT_EQ(outputNoPad.type(), fl::dtype::f32); // result is upcast } @@ -196,9 +196,8 @@ TEST(ContribModuleTest, TransformerPadMaskFwd) { } TEST_F(ContribModuleTestF16, TransformerPadMaskFwd16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } transformerPadMaskFwd(true); } @@ -215,9 +214,9 @@ void transformerFwd(bool isfp16) { fl::Variable padMask; auto output = tr.forward({input, padMask}); - if(OptimMode::get().getOptimLevel() == OptimLevel::O3) { + if(OptimMode::get().getOptimLevel() == OptimLevel::O3) ASSERT_EQ(output[0].type(), input.type()); - } else { + else { ASSERT_EQ(output[0].type(), fl::dtype::f32); // result is upcast } @@ -237,9 +236,8 @@ TEST(ContribModuleTest, TransformerFwd) { } TEST_F(ContribModuleTestF16, TransformerFwdF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } transformerFwd(true); } @@ -254,9 +252,9 @@ void conformerFwd(bool isfp16) { auto input = Variable(fl::rand({c, timesteps, batchsize}, dtype), false); auto output = tr.forward({input, Variable()}); - if(OptimMode::get().getOptimLevel() == OptimLevel::O3) { + if(OptimMode::get().getOptimLevel() == OptimLevel::O3) ASSERT_EQ(output[0].type(), input.type()); - } else { + else { ASSERT_EQ(output[0].type(), fl::dtype::f32); // result is upcast } @@ -270,9 +268,8 @@ TEST(ContribModuleTest, ConformerFwd) { } TEST_F(ContribModuleTestF16, ConformerFwdF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } conformerFwd(true); } @@ -299,9 +296,8 @@ TEST(ContribModuleTest, PositionEmbeddingFwd) { } TEST_F(ContribModuleTestF16, PositionEmbeddingFwdF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on 
this device"; - } positionEmbeddingFwd(true); } @@ -321,9 +317,8 @@ void sinusoidalPositionEmbeddingFwd(bool isfp16) { ASSERT_EQ(output[0].dim(1), timesteps); ASSERT_EQ(output[0].dim(2), batchsize); auto castOutput = output[0].tensor(); - if(isfp16) { + if(isfp16) castOutput = output[0].astype(fl::dtype::f32).tensor(); - } ASSERT_TRUE((fl::amax(castOutput, {0})).scalar() <= 2); ASSERT_TRUE((fl::amin(castOutput, {0})).scalar() >= -2); } @@ -333,9 +328,8 @@ TEST(ContribModuleTest, SinusoidalPositionEmbeddingFwd) { } TEST_F(ContribModuleTestF16, SinusoidalPositionEmbeddingFwdF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } sinusoidalPositionEmbeddingFwd(true); } @@ -375,9 +369,8 @@ TEST(ContribModuleTest, TDSFwd) { } TEST_F(ContribModuleTestF16, TDSFwdF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } tdsFwd(true); } @@ -407,9 +400,8 @@ TEST(ContribModuleTest, StreamingTDSFwd) { } TEST_F(ContribModuleTestF16, StreamingTDSFwdF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } streamingTDSFwd(true); } @@ -426,13 +418,12 @@ TEST(ContribModuleTest, SpecAugmentFwd) { ASSERT_FALSE(fl::allClose(input, output)); // Every value of output is either 0 or input - for(int t = 0; t < T; ++t) { + for(int t = 0; t < T; ++t) for(int f = 0; f < F; ++f) { auto o = output.tensor()(t, f).scalar(); auto i = input.tensor()(t, f).scalar(); ASSERT_TRUE(o == i || o == 0); } - } // non-zero time frames are masked int tZeros = 0; @@ -492,9 +483,8 @@ TEST(ContribModuleTest, RawWavSpecAugmentFwd) { } TEST_F(ContribModuleTestF16, RawWavSpecAugmentFwdF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } computeRawWavSpecAug(true, 1e-2); } diff --git a/flashlight/fl/test/dataset/DatasetTest.cpp 
b/flashlight/fl/test/dataset/DatasetTest.cpp index f8250d2..258e53e 100644 --- a/flashlight/fl/test/dataset/DatasetTest.cpp +++ b/flashlight/fl/test/dataset/DatasetTest.cpp @@ -217,11 +217,10 @@ TEST(DatasetTest, FileBlobDataset) { std::vector sample; for(int64_t j = 0; j < i % 4; j++) { Tensor tensor; - if(j % 2 == 0) { + if(j % 2 == 0) tensor = fl::rand({100, 3, 100}); - } else { + else tensor = fl::rand({100, 200}); - } sample.push_back(tensor); } data.push_back(sample); @@ -236,12 +235,11 @@ TEST(DatasetTest, FileBlobDataset) { auto blobSample = blob.get(i); auto datSample = data.at(i); ASSERT_EQ(datSample.size(), blobSample.size()); - for(int64_t j = 0; j < blobSample.size(); j++) { + for(int64_t j = 0; j < blobSample.size(); j++) ASSERT_TRUE( fl::norm(datSample.at(j).flatten() - blobSample.at(j).flatten()) .scalar() <= 1e-05 ); - } } }; @@ -272,27 +270,22 @@ TEST(DatasetTest, FileBlobDataset) { check(blob); // check hostTransform - for(auto& vec : data) { - if(!vec.empty()) { + for(auto& vec : data) + if(!vec.empty()) vec[0] += 1; - } - } blob.setHostTransform( 0, [](void* ptr, fl::Shape size, fl::dtype /* type */) { float* ptrFl = (float*) ptr; - for(int64_t i = 0; i < size.elements(); i++) { + for(int64_t i = 0; i < size.elements(); i++) ptrFl[i] += 1; - } return Tensor::fromBuffer(size, ptrFl, MemoryLocation::Host); } - ); + ); check(blob); - for(auto& vec : data) { - if(!vec.empty()) { + for(auto& vec : data) + if(!vec.empty()) vec[0] -= 1; - } - } } // check tensor dim constraints @@ -322,15 +315,13 @@ TEST(DatasetTest, FileBlobDataset) { workers.emplace_back( [i, blob, nperworker, device, &thdata]() { fl::setDevice(device); - for(int j = 0; j < nperworker; j++) { + for(int j = 0; j < nperworker; j++) thdata[i * nperworker + j] = blob->get(i * nperworker + j); - } } ); } - for(int i = 0; i < nworker; i++) { + for(int i = 0; i < nworker; i++) workers[i].join(); - } ASSERT_EQ(data.size(), thdata.size()); for(int64_t i = 0; i < data.size(); i++) { auto 
thdataSample = thdata.at(i); @@ -349,9 +340,8 @@ TEST(DatasetTest, FileBlobDataset) { // multi-threaded write { // add an index - for(int i = 0; i < data.size(); i++) { + for(int i = 0; i < data.size(); i++) data[i].push_back(fl::full({1}, i, fl::dtype::f32)); - } { auto blob = std::make_shared( fs::temp_directory_path() / "data.blob", @@ -362,19 +352,16 @@ TEST(DatasetTest, FileBlobDataset) { const int nworker = 10; int nperworker = data.size() / nworker; auto device = fl::getDevice(); - for(int i = 0; i < nworker; i++) { + for(int i = 0; i < nworker; i++) workers.emplace_back( [i, blob, nperworker, device, &data]() { fl::setDevice(device); - for(int j = 0; j < nperworker; j++) { + for(int j = 0; j < nperworker; j++) blob->add(data[i * nperworker + j]); - } } ); - } - for(int i = 0; i < nworker; i++) { + for(int i = 0; i < nworker; i++) workers[i].join(); - } blob->writeIndex(); } { @@ -408,11 +395,10 @@ TEST(DatasetTest, MemoryBlobDataset) { std::vector sample; for(int64_t j = 0; j < i % 4; j++) { Tensor tensor; - if(j % 2 == 0) { + if(j % 2 == 0) tensor = fl::rand({100, 3, 100}); - } else { + else tensor = fl::rand({100, 200}); - } sample.push_back(tensor); } data.push_back(sample); @@ -427,12 +413,11 @@ TEST(DatasetTest, MemoryBlobDataset) { auto blobSample = blob.get(i); auto datSample = data.at(i); ASSERT_EQ(datSample.size(), blobSample.size()); - for(int64_t j = 0; j < blobSample.size(); j++) { + for(int64_t j = 0; j < blobSample.size(); j++) ASSERT_TRUE( fl::norm(datSample.at(j).flatten() - blobSample.at(j).flatten()) .scalar() <= 1e-05 ); - } } }; @@ -462,21 +447,18 @@ TEST(DatasetTest, MemoryBlobDataset) { check(blob); // check hostTransform - for(auto& vec : data) { - if(!vec.empty()) { + for(auto& vec : data) + if(!vec.empty()) vec[0] += 1; - } - } blob.setHostTransform( 0, [](void* ptr, fl::Shape size, fl::dtype /* type */) { float* ptrFl = (float*) ptr; - for(int64_t i = 0; i < size.elements(); i++) { + for(int64_t i = 0; i < size.elements(); i++) 
ptrFl[i] += 1; - } return Tensor::fromBuffer(size, ptrFl, MemoryLocation::Host); } - ); + ); check(blob); } @@ -491,15 +473,13 @@ TEST(DatasetTest, MemoryBlobDataset) { workers.emplace_back( [i, &blob, nperworker, device, &thdata]() { fl::setDevice(device); - for(int j = 0; j < nperworker; j++) { + for(int j = 0; j < nperworker; j++) thdata[i * nperworker + j] = blob.get(i * nperworker + j); - } } ); } - for(int i = 0; i < nworker; i++) { + for(int i = 0; i < nworker; i++) workers[i].join(); - } ASSERT_EQ(data.size(), thdata.size()); for(int64_t i = 0; i < data.size(); i++) { auto thdataSample = thdata.at(i); @@ -519,27 +499,23 @@ TEST(DatasetTest, MemoryBlobDataset) { { MemoryBlobDataset wblob; // add an index - for(int i = 0; i < data.size(); i++) { + for(int i = 0; i < data.size(); i++) data[i].push_back(fl::full({1}, i, fl::dtype::f32)); - } { std::vector workers; const int nworker = 10; int nperworker = data.size() / nworker; auto device = fl::getDevice(); - for(int i = 0; i < nworker; i++) { + for(int i = 0; i < nworker; i++) workers.emplace_back( [i, &wblob, nperworker, device, &data]() { fl::setDevice(device); - for(int j = 0; j < nperworker; j++) { + for(int j = 0; j < nperworker; j++) wblob.add(data[i * nperworker + j]); - } } ); - } - for(int i = 0; i < nworker; i++) { + for(int i = 0; i < nworker; i++) workers[i].join(); - } wblob.writeIndex(); } { @@ -580,9 +556,8 @@ TEST(DatasetTest, PrefetchDatasetCorrectness) { auto sample1 = transformDs->get(i); auto sample2 = prefetchDs->get(i); ASSERT_EQ(sample1.size(), sample2.size()); - for(int j = 0; j < sample1.size(); ++j) { + for(int j = 0; j < sample1.size(); ++j) ASSERT_TRUE(allClose(sample1[j], sample2[j])); - } } } @@ -602,9 +577,8 @@ TEST(DatasetTest, DISABLED_PrefetchDatasetPerformance) { ); auto start = std::chrono::high_resolution_clock::now(); - for(auto& sample : *transformDs) { + for(auto& sample : *transformDs) (void) sample; - } auto dur = std::chrono::duration_cast( 
std::chrono::high_resolution_clock::now() - start ); @@ -615,9 +589,8 @@ TEST(DatasetTest, DISABLED_PrefetchDatasetPerformance) { std::make_shared(transformDs, numthreads, numthreads); start = std::chrono::high_resolution_clock::now(); - for(auto& sample : *prefetchDs) { + for(auto& sample : *prefetchDs) (void) sample; - } dur = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - start ); diff --git a/flashlight/fl/test/distributed/AllReduceBenchmark.cpp b/flashlight/fl/test/distributed/AllReduceBenchmark.cpp index d080f6b..0d6fdec 100644 --- a/flashlight/fl/test/distributed/AllReduceBenchmark.cpp +++ b/flashlight/fl/test/distributed/AllReduceBenchmark.cpp @@ -31,9 +31,8 @@ int main() { auto wRank = getWorldRank(); auto wSize = getWorldSize(); - if(wRank == 0) { + if(wRank == 0) std::cout << "Running allreduce on " << wSize << " machines" << std::endl; - } const int kNumIters = 10000; std::vector sizes = {1, 2, 5}; @@ -52,19 +51,17 @@ int main() { times[i] = fl::Timer::stop(start); } auto timesAf = Tensor::fromVector({kNumIters}, times); - if(wRank == 0) { + if(wRank == 0) std::cout << "Size: " << size << " ; avg: " << fl::mean(timesAf).asScalar() * 1000 << "ms ; p50: " << fl::median(timesAf).asScalar() * 1000 << "ms" << std::endl; - } curMaxSize = std::max(curMaxSize, size); size *= multiplier; } - if(curMaxSize >= maxSize) { + if(curMaxSize >= maxSize) break; - } } return 0; } diff --git a/flashlight/fl/test/distributed/AllReduceTest.cpp b/flashlight/fl/test/distributed/AllReduceTest.cpp index ed05cf8..4f8e849 100644 --- a/flashlight/fl/test/distributed/AllReduceTest.cpp +++ b/flashlight/fl/test/distributed/AllReduceTest.cpp @@ -21,9 +21,8 @@ using namespace fl; TEST(Distributed, AllReduce) { - if(!isDistributedInit()) { + if(!isDistributedInit()) GTEST_SKIP() << "Distributed initialization failed or not enabled."; - } auto rank = getWorldRank(); auto size = getWorldSize(); @@ -37,9 +36,8 @@ TEST(Distributed, AllReduce) { } TEST(Distributed, 
InlineReducer) { - if(!isDistributedInit()) { + if(!isDistributedInit()) GTEST_SKIP() << "Distributed initialization failed or not enabled."; - } auto rank = getWorldRank(); auto size = getWorldSize(); @@ -57,9 +55,8 @@ TEST(Distributed, InlineReducer) { } TEST(Distributed, AllReduceAsync) { - if(!isDistributedInit()) { + if(!isDistributedInit()) GTEST_SKIP() << "Distributed initialization failed or not enabled."; - } auto rank = getWorldRank(); auto size = getWorldSize(); @@ -76,9 +73,8 @@ TEST(Distributed, AllReduceAsync) { } TEST(Distributed, AllReduceSetAsync) { - if(!isDistributedInit()) { + if(!isDistributedInit()) GTEST_SKIP() << "Distributed initialization failed or not enabled."; - } auto rank = getWorldRank(); auto size = getWorldSize(); @@ -88,29 +84,25 @@ TEST(Distributed, AllReduceSetAsync) { unsigned vSize = (1 << 20); std::vector vars; - for(size_t i = 0; i < 5; ++i) { + for(size_t i = 0; i < 5; ++i) vars.emplace_back(fl::full({vSize}, rank + 1, dtype::f32), false); - } allReduceMultiple(vars, 2.0, async, contiguous); syncDistributed(); float expected_val = size * (size + 1.0); - for(const auto& var : vars) { + for(const auto& var : vars) ASSERT_TRUE(fl::all(var.tensor() == expected_val).scalar()); - } // Exceed the size of the contiguous buffer without caching, and trigger a // contiguous sync with a tensor that is too large - for(size_t i = 0; i < 25; ++i) { + for(size_t i = 0; i < 25; ++i) vars.emplace_back(fl::full({vSize}, rank, dtype::f32), false); - } - if(size > 1) { + if(size > 1) ASSERT_THROW( allReduceMultiple(vars, 2.0, /*async=*/ true, /*contiguous=*/ true), std::runtime_error ); - } } TEST(Distributed, Barrier) { @@ -136,12 +128,11 @@ TEST(Distributed, Barrier) { // Delete files std::error_code errorCode; const bool status = fs::remove(file, errorCode); - if(!status) { + if(!status) throw std::runtime_error( "Barrier test cannot delete file: " + std::string(file) + " error: " + errorCode.message() ); - } barrier(); for(int i = 0; i < 
size; i++) { auto checkingFile = @@ -151,9 +142,8 @@ TEST(Distributed, Barrier) { } TEST(Distributed, CoalescingReducer) { - if(!isDistributedInit()) { + if(!isDistributedInit()) GTEST_SKIP() << "Distributed initialization failed or not enabled."; - } auto rank = getWorldRank(); auto size = getWorldSize(); @@ -166,15 +156,13 @@ TEST(Distributed, CoalescingReducer) { unsigned vSize = (1 << 20); std::vector vars; - for(size_t i = 0; i < 1000; ++i) { + for(size_t i = 0; i < 1000; ++i) vars.emplace_back(fl::full({vSize}, rank + 1, dtype::f32), false); - } for(size_t i = 0; i < vars.size(); ++i) { s->add(vars[i]); - if((i + 1) % 10 == 0) { + if((i + 1) % 10 == 0) s->finalize(); - } } float expected_val = size * (size + 1.0); diff --git a/flashlight/fl/test/nn/ModuleTest.cpp b/flashlight/fl/test/nn/ModuleTest.cpp index ee78fe3..3fd8d61 100644 --- a/flashlight/fl/test/nn/ModuleTest.cpp +++ b/flashlight/fl/test/nn/ModuleTest.cpp @@ -35,9 +35,8 @@ class ContainerTestClass : public Sequential { void copy(const ContainerTestClass& other) { auto orphanParamIdxMap = other.getOrphanedParamsIdxMap(); for(int i = -1; i < static_cast(other.modules_.size()); ++i) { - if(i >= 0) { + if(i >= 0) add(other.modules_[i]->clone()); - } auto [paramIter, pEnd] = orphanParamIdxMap.equal_range(i); for(; paramIter != pEnd; ++paramIter) { const auto& param = other.params_[paramIter->second]; @@ -144,9 +143,8 @@ TEST(ModuleTest, LinearFwd) { } TEST_F(ModuleTestF16, LinearFwdF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } int n_in = 2, n_out = 3, x = 4, batchsize = 2; auto wtVar = @@ -256,9 +254,8 @@ TEST(ModuleTest, GLUFwd) { } TEST_F(ModuleTestF16, GLUFwdF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } auto inVar = Variable( Tensor::fromVector({3, 2}, {0.8, 0.2, 0.2, 0.1, 0.5, 0.3}) @@ -338,9 +335,8 @@ TEST(ModuleTest, LogSoftmaxFwd) { } 
TEST_F(ModuleTestF16, LogSoftmaxFwdF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } auto inVar = Variable( Tensor::fromVector({3, 2}, {0.8, 0.2, 0.2, 0.1, 0.5, 0.3}) @@ -405,9 +401,8 @@ TEST(ModuleTest, ConvolutionFwd) { } TEST_F(ModuleTestF16, ConvolutionFwdF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } // test batching auto conv = Conv2D(30, 50, 9, 7, 2, 3, 3, 2, 1, 1, true, 1); @@ -469,9 +464,8 @@ TEST(ModuleTest, PoolingFwd) { } TEST_F(ModuleTestF16, PoolingFwdF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } // test batching auto pool = Pool2D(9, 7, 1, 1, PaddingMode::SAME, PaddingMode::SAME); @@ -504,12 +498,10 @@ TEST(ModuleTest, RNNFwd) { ); unsigned n_params = 51; auto w = Variable(fl::rand({1, 1, n_params}, fl::dtype::f32), true); - for(int i = 0; i < in.elements(); ++i) { + for(int i = 0; i < in.elements(); ++i) in.tensor().flat(i) = (i + 1) * 0.01; - } - for(int i = 0; i < w.elements(); ++i) { + for(int i = 0; i < w.elements(); ++i) w.tensor().flat(i) = (i + 1) * 0.01; - } auto rnn = RNN(input_size, hidden_size, num_layers, mode); rnn.setParams(w, 0); @@ -555,12 +547,10 @@ TEST(ModuleTest, LSTMFwd) { unsigned n_params = 920; auto w = Variable(fl::rand({1, 1, n_params}, fl::dtype::f32), true); - for(int i = 0; i < in.elements(); ++i) { + for(int i = 0; i < in.elements(); ++i) in.tensor().flat(i) = (i + 1) * 0.001; - } - for(int i = 0; i < w.elements(); ++i) { + for(int i = 0; i < w.elements(); ++i) w.tensor().flat(i) = (i + 1) * 0.001; - } auto rnn = RNN(input_size, hidden_size, num_layers, mode); rnn.setParams(w, 0); @@ -596,12 +586,10 @@ TEST(ModuleTest, GRUFwd) { unsigned n_params = 690; auto w = Variable(fl::rand({1, 1, n_params}, fl::dtype::f32), true); - for(int i = 0; i < in.elements(); ++i) { + for(int i = 0; i < 
in.elements(); ++i) in.tensor().flat(i) = (i + 1) * 0.001; - } - for(int i = 0; i < w.elements(); ++i) { + for(int i = 0; i < w.elements(); ++i) w.tensor().flat(i) = (i + 1) * 0.001; - } auto rnn = RNN(input_size, hidden_size, num_layers, mode); rnn.setParams(w, 0); @@ -623,9 +611,8 @@ TEST(ModuleTest, GRUFwd) { } TEST_F(ModuleTestF16, RNNFwdF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } auto mode = RnnMode::RELU; int num_layers = 2; @@ -640,12 +627,10 @@ TEST_F(ModuleTestF16, RNNFwdF16) { ); unsigned n_params = 51; auto w = Variable(fl::rand({1, 1, n_params}, fl::dtype::f16), true); - for(int i = 0; i < in.elements(); ++i) { + for(int i = 0; i < in.elements(); ++i) in.tensor().flat(i) = (i + 1) * 0.01; - } - for(int i = 0; i < w.elements(); ++i) { + for(int i = 0; i < w.elements(); ++i) w.tensor().flat(i) = (i + 1) * 0.01; - } auto rnn = RNN(input_size, hidden_size, num_layers, mode); rnn.setParams(w, 0); @@ -706,9 +691,8 @@ TEST(ModuleTest, DropoutFwd) { } TEST_F(ModuleTestF16, DropoutFwdF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } auto module = Dropout(0.5); // Train Mode @@ -799,9 +783,8 @@ TEST(ModuleTest, LayerNormFwd) { } TEST_F(ModuleTestF16, LayerNormFwdF16) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } double eps = 5E-2; std::vector feat_axes = {3}; @@ -863,9 +846,8 @@ TEST(ModuleTest, TransformFwd) { } TEST(ModuleTest, PrecisionCastFwd) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half precision not available on this device"; - } auto in = Variable(fl::full({3, 3}, 1.0), true); auto precisionCast = PrecisionCast(fl::dtype::f16); diff --git a/flashlight/fl/test/nn/NNSerializationTest.cpp b/flashlight/fl/test/nn/NNSerializationTest.cpp index 1a70ab3..ca5106f 100644 --- 
a/flashlight/fl/test/nn/NNSerializationTest.cpp +++ b/flashlight/fl/test/nn/NNSerializationTest.cpp @@ -41,9 +41,8 @@ auto filesizebytes = []() -> std::uintmax_t { auto paramsizebytes = [](const std::vector& parameters) { int64_t paramsize = 0; - for(const auto& param : parameters) { + for(const auto& param : parameters) paramsize += (param.elements() * fl::getTypeSize(param.type())); - } return paramsize; }; @@ -144,9 +143,8 @@ TEST(NNSerializationTest, BaseModule) { } TEST(NNSerializationTest, PrecisionCast) { - if(!fl::f16Supported()) { + if(!fl::f16Supported()) GTEST_SKIP() << "Half precision not available on this device"; - } auto in = input(fl::rand({8, 8})); auto precisionCast = std::make_shared(fl::dtype::f16); @@ -343,9 +341,8 @@ TEST(NNSerializationTest, ContainerBackward) { auto in = input(fl::rand({10, 10})); auto output = seq2->forward({in}).front(); output.backward(); - for(auto& p : seq2->params()) { + for(auto& p : seq2->params()) ASSERT_TRUE(p.isGradAvailable()); - } } TEST(NNSerializationTest, ContainerWithParams) { diff --git a/flashlight/fl/test/optim/OptimBenchmark.cpp b/flashlight/fl/test/optim/OptimBenchmark.cpp index 4562d39..e3ca550 100644 --- a/flashlight/fl/test/optim/OptimBenchmark.cpp +++ b/flashlight/fl/test/optim/OptimBenchmark.cpp @@ -22,17 +22,15 @@ using namespace fl; double timeit(std::function fn) { // warmup - for(int i = 0; i < 10; ++i) { + for(int i = 0; i < 10; ++i) fn(); - } fl::sync(); int num_iters = 100; fl::sync(); auto start = fl::Timer::start(); - for(int i = 0; i < num_iters; i++) { + for(int i = 0; i < num_iters; i++) fn(); - } fl::sync(); return fl::Timer::stop(start) / num_iters; } diff --git a/flashlight/fl/test/optim/OptimTest.cpp b/flashlight/fl/test/optim/OptimTest.cpp index 3951fc2..d7c8e1c 100644 --- a/flashlight/fl/test/optim/OptimTest.cpp +++ b/flashlight/fl/test/optim/OptimTest.cpp @@ -36,9 +36,8 @@ TEST(OptimTest, GradNorm) { } TEST(OptimTest, GradNormF16) { - if(!fl::f16Supported()) { + 
if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; - } std::vector parameters; for(int i = 0; i < 5; i++) { @@ -83,16 +82,14 @@ TEST(SerializationTest, OptimizerSerialize) { std::shared_ptr opt2; load(path, parameters2, opt2); - for(int i = 0; i < 5; i++) { + for(int i = 0; i < 5; i++) parameters2[i].addGrad(Variable(parameters[i].grad().tensor(), false)); - } opt->step(); opt2->step(); - for(int i = 0; i < 5; i++) { + for(int i = 0; i < 5; i++) ASSERT_TRUE(allClose(parameters[i].tensor(), parameters2[i].tensor())); - } opt = std::make_shared(parameters, 0.01f); opt->step(); @@ -104,16 +101,14 @@ TEST(SerializationTest, OptimizerSerialize) { ); load(path, parameters2, opt2); - for(int i = 0; i < 5; i++) { + for(int i = 0; i < 5; i++) parameters2[i].addGrad(Variable(parameters[i].grad().tensor(), false)); - } opt->step(); opt2->step(); - for(int i = 0; i < 5; i++) { + for(int i = 0; i < 5; i++) ASSERT_TRUE(allClose(parameters[i].tensor(), parameters2[i].tensor())); - } } int main(int argc, char** argv) { diff --git a/flashlight/fl/test/runtime/DeviceManagerTest.cpp b/flashlight/fl/test/runtime/DeviceManagerTest.cpp index a3c7ba4..8f25dcd 100644 --- a/flashlight/fl/test/runtime/DeviceManagerTest.cpp +++ b/flashlight/fl/test/runtime/DeviceManagerTest.cpp @@ -34,14 +34,13 @@ TEST(DeviceManagerTest, getDeviceCount) { // For now we always treat CPU as a single device ASSERT_EQ(manager.getDeviceCount(DeviceType::x64), 1); - if(manager.isDeviceTypeAvailable(DeviceType::CUDA)) { + if(manager.isDeviceTypeAvailable(DeviceType::CUDA)) ASSERT_NO_THROW(manager.getDeviceCount(DeviceType::CUDA)); - } else { + else ASSERT_THROW( manager.getDeviceCount(DeviceType::CUDA), std::runtime_error ); - } } TEST(DeviceManagerTest, getDevicesOfType) { @@ -50,16 +49,14 @@ TEST(DeviceManagerTest, getDevicesOfType) { ASSERT_EQ(manager.getDevicesOfType(DeviceType::x64).size(), 1); for(auto type : fl::getDeviceTypes()) { - 
if(manager.isDeviceTypeAvailable(DeviceType::CUDA)) { - for(auto device : manager.getDevicesOfType(type)) { + if(manager.isDeviceTypeAvailable(DeviceType::CUDA)) + for(auto device : manager.getDevicesOfType(type)) ASSERT_EQ(device->type(), type); - } - } else { + else ASSERT_THROW( manager.getDeviceCount(DeviceType::CUDA), std::runtime_error ); - } } } @@ -73,11 +70,10 @@ TEST(DeviceManagerTest, getDevice) { TEST(DeviceManagerTest, getActiveDevice) { auto& manager = DeviceManager::getInstance(); for(auto type : fl::getDeviceTypes()) { - if(manager.isDeviceTypeAvailable(type)) { + if(manager.isDeviceTypeAvailable(type)) ASSERT_EQ(manager.getActiveDevice(type).type(), type); - } else { + else ASSERT_THROW(manager.getActiveDevice(type), std::runtime_error); - } } } diff --git a/flashlight/fl/test/runtime/DeviceTest.cpp b/flashlight/fl/test/runtime/DeviceTest.cpp index 62fd078..44a1e1c 100644 --- a/flashlight/fl/test/runtime/DeviceTest.cpp +++ b/flashlight/fl/test/runtime/DeviceTest.cpp @@ -15,38 +15,32 @@ using fl::DeviceType; TEST(DeviceTest, type) { auto& manager = DeviceManager::getInstance(); - for(auto type : fl::getDeviceTypes()) { - if(manager.isDeviceTypeAvailable(type)) { - for(auto* device : manager.getDevicesOfType(type)) { + for(auto type : fl::getDeviceTypes()) + if(manager.isDeviceTypeAvailable(type)) + for(auto* device : manager.getDevicesOfType(type)) ASSERT_EQ(device->type(), type); - } - } - } } TEST(DeviceTest, nativeId) { const auto& manager = DeviceManager::getInstance(); - for(const auto* device : manager.getDevicesOfType(DeviceType::x64)) { + for(const auto* device : manager.getDevicesOfType(DeviceType::x64)) ASSERT_EQ(device->nativeId(), fl::kX64DeviceId); - } } TEST(DeviceTest, setActive) { auto& manager = DeviceManager::getInstance(); - for(auto type : fl::getDeviceTypes()) { - if(manager.isDeviceTypeAvailable(type)) { + for(auto type : fl::getDeviceTypes()) + if(manager.isDeviceTypeAvailable(type)) for(auto* device : 
manager.getDevicesOfType(type)) { device->setActive(); ASSERT_EQ(&manager.getActiveDevice(type), device); } - } - } } TEST(DeviceTest, addSetActiveCallback) { auto& manager = DeviceManager::getInstance(); - for(const auto type : fl::getDeviceTypes()) { - if(manager.isDeviceTypeAvailable(type)) { + for(const auto type : fl::getDeviceTypes()) + if(manager.isDeviceTypeAvailable(type)) for(auto* device : manager.getDevicesOfType(type)) { int count = 0; auto incCount = [&count](int) { count++; }; @@ -54,32 +48,23 @@ TEST(DeviceTest, addSetActiveCallback) { device->setActive(); ASSERT_EQ(count, 1); } - } - } } TEST(DeviceTest, sync) { const auto& manager = DeviceManager::getInstance(); - for(const auto type : fl::getDeviceTypes()) { - if(manager.isDeviceTypeAvailable(type)) { - for(const auto* device : manager.getDevicesOfType(type)) { + for(const auto type : fl::getDeviceTypes()) + if(manager.isDeviceTypeAvailable(type)) + for(const auto* device : manager.getDevicesOfType(type)) ASSERT_NO_THROW(device->sync()); - } - } - } } TEST(DeviceTest, getStream) { auto& manager = DeviceManager::getInstance(); - for(const auto type : fl::getDeviceTypes()) { - if(manager.isDeviceTypeAvailable(type)) { - for(const auto* device : manager.getDevicesOfType(type)) { - for(const auto& stream : device->getStreams()) { + for(const auto type : fl::getDeviceTypes()) + if(manager.isDeviceTypeAvailable(type)) + for(const auto* device : manager.getDevicesOfType(type)) + for(const auto& stream : device->getStreams()) ASSERT_EQ(&stream->device(), device); - } - } - } - } } int main(int argc, char** argv) { diff --git a/flashlight/fl/test/tensor/IndexTest.cpp b/flashlight/fl/test/tensor/IndexTest.cpp index 36eedc2..920de06 100644 --- a/flashlight/fl/test/tensor/IndexTest.cpp +++ b/flashlight/fl/test/tensor/IndexTest.cpp @@ -53,9 +53,8 @@ TEST(IndexTest, Type) { TEST(IndexTest, ArrayFireMaxIndex) { auto t = fl::full({2, 3, 4, 5}, 6.); - if(t.backendType() != TensorBackendType::ArrayFire) { + 
if(t.backendType() != TensorBackendType::ArrayFire) GTEST_SKIP() << "Default Tensor type isn't ArrayFire"; - } ASSERT_THROW(t(1, 2, 3, 4, 5), std::invalid_argument); } @@ -172,34 +171,29 @@ TEST(IndexTest, IndexInPlaceOps) { TEST(IndexTest, flat) { auto m = fl::rand({4, 6}); - for(unsigned i = 0; i < m.elements(); ++i) { + for(unsigned i = 0; i < m.elements(); ++i) ASSERT_TRUE(allClose(m.flat(i), m(i % 4, i / 4))); - } auto n = fl::rand({4, 6, 8}); - for(unsigned i = 0; i < n.elements(); ++i) { + for(unsigned i = 0; i < n.elements(); ++i) ASSERT_TRUE(allClose(n.flat(i), n(i % 4, (i / 4) % 6, (i / (4 * 6)) % 8))); - } auto a = fl::full({5, 6, 7, 8}, 9.); std::vector testIndices = {0, 1, 4, 11, 62, 104, 288}; - for(const int i : testIndices) { + for(const int i : testIndices) ASSERT_EQ(a.flat(i).scalar(), 9.); - } a.flat(8) = 5.; ASSERT_EQ(a.flat(8).scalar(), 5.); - for(const int i : testIndices) { + for(const int i : testIndices) a.flat(i) = i + 1; - } - for(const int i : testIndices) { + for(const int i : testIndices) ASSERT_EQ( a(i % 5, (i / 5) % 6, (i / (5 * 6)) % 7, (i / (5 * 6 * 7)) % 8) .scalar(), i + 1 ); - } // Tensor assignment a.flat(32) = fl::full({1}, 7.4); @@ -214,13 +208,12 @@ TEST(IndexTest, flat) { ASSERT_EQ(ref.shape(), Shape({(Dim) indexer.elements()})); a.flat(indexer) -= 10; ASSERT_TRUE(allClose(a.flat(indexer), ref - 10)); - for(const int i : testIndices) { + for(const int i : testIndices) ASSERT_EQ( a(i % 5, (i / 5) % 6, (i / (5 * 6)) % 7, (i / (5 * 6 * 7)) % 8) .scalar(), i + 1 - 10 ); - } // Range flat assignment auto rA = fl::rand({6}); @@ -238,14 +231,12 @@ TEST(IndexTest, TensorIndex) { std::vector idxs = {0, 1, 4, 9, 11, 13, 16, 91}; unsigned size = idxs.size(); auto indices = fl::full({size}, 0); - for(int i = 0; i < size; ++i) { + for(int i = 0; i < size; ++i) indices(i) = idxs[i]; - } auto a = fl::rand({100}); auto indexed = a(indices); - for(int i = 0; i < size; ++i) { + for(int i = 0; i < size; ++i) ASSERT_TRUE(allClose(indexed(i), 
a(idxs[i]))); - } a(indices) = 5.; ASSERT_TRUE(allClose(a(indices), fl::full({size}, 5.))); diff --git a/flashlight/fl/test/tensor/ShapeTest.cpp b/flashlight/fl/test/tensor/ShapeTest.cpp index 2f3d1e1..6a88f7d 100644 --- a/flashlight/fl/test/tensor/ShapeTest.cpp +++ b/flashlight/fl/test/tensor/ShapeTest.cpp @@ -24,9 +24,8 @@ TEST(ShapeTest, Basic) { } TEST(ShapeTest, ManyDims) { - if(Shape::kMaxDims <= 4) { + if(Shape::kMaxDims <= 4) GTEST_SKIP() << "Max shape dimensions is <= 4"; - } auto many = Shape({1, 2, 3, 4, 5, 6, 7}); ASSERT_EQ(many.ndim(), 7); ASSERT_EQ(many.dim(5), 6); diff --git a/flashlight/fl/test/tensor/TensorBLASTest.cpp b/flashlight/fl/test/tensor/TensorBLASTest.cpp index c39d6e6..df14ba8 100644 --- a/flashlight/fl/test/tensor/TensorBLASTest.cpp +++ b/flashlight/fl/test/tensor/TensorBLASTest.cpp @@ -26,13 +26,10 @@ TEST(TensorBLASTest, matmul) { auto out = fl::full({M, K}, 0.); - for(unsigned i = 0; i < M; ++i) { - for(unsigned j = 0; j < K; ++j) { - for(unsigned k = 0; k < N; ++k) { + for(unsigned i = 0; i < M; ++i) + for(unsigned j = 0; j < K; ++j) + for(unsigned k = 0; k < N; ++k) out(i, j) += lhs(i, k) * rhs(k, j); - } - } - } return out; }; diff --git a/flashlight/fl/test/tensor/TensorBaseTest.cpp b/flashlight/fl/test/tensor/TensorBaseTest.cpp index 1f14124..730dfd9 100644 --- a/flashlight/fl/test/tensor/TensorBaseTest.cpp +++ b/flashlight/fl/test/tensor/TensorBaseTest.cpp @@ -206,9 +206,8 @@ TEST(TensorBaseTest, ConstructFromData) { std::vector ascending = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; auto t = fl::Tensor::fromVector({3, 4}, ascending); ASSERT_EQ(t.type(), fl::dtype::f32); - for(int i = 0; i < ascending.size(); ++i) { + for(int i = 0; i < ascending.size(); ++i) ASSERT_FLOAT_EQ(t(i % 3, i / 3).scalar(), ascending[i]); - } // TODO: add fixtures/check stuff std::vector intV = {1, 2, 3}; @@ -336,9 +335,8 @@ TEST(TensorBaseTest, concatenate) { TEST(TensorBaseTest, nonzero) { std::vector idxs = {0, 1, 4, 9, 11, 23, 55, 82, 91}; auto a = 
fl::full({10, 10}, 1, fl::dtype::u32); - for(const auto idx : idxs) { + for(const auto idx : idxs) a(idx / 10, idx % 10) = 0; - } auto indices = fl::nonzero(a); int nnz = a.elements() - idxs.size(); ASSERT_EQ(indices.shape(), Shape({nnz})); @@ -458,9 +456,8 @@ TEST(TensorBaseTest, sort) { auto sorted = fl::sort(a, /* axis = */ 0, SortMode::Descending); Tensor expected({dims[0]}, a.type()); - for(int i = 0; i < dims[0]; ++i) { + for(int i = 0; i < dims[0]; ++i) expected(i) = dims[0] - i - 1; - } auto tiled = fl::tile(expected, {1, 2}); ASSERT_TRUE(allClose(sorted, tiled)); @@ -483,9 +480,8 @@ TEST(TensorBaseTest, argsort) { auto sorted = fl::argsort(a, /* axis = */ 0, SortMode::Descending); Tensor expected({dims[0]}, fl::dtype::u32); - for(int i = 0; i < dims[0]; ++i) { + for(int i = 0; i < dims[0]; ++i) expected(i) = dims[0] - i - 1; - } auto tiled = fl::tile(expected, {1, 2}); ASSERT_TRUE(allClose(sorted, tiled)); @@ -507,15 +503,14 @@ void assertScalarBehavior(fl::dtype type) { if( (type == fl::dtype::f16) || (type == fl::dtype::f32) || (type == fl::dtype::f64) - ) { + ) ASSERT_FLOAT_EQ(one.template scalar(), scalar) << "dtype: " << type << ", ScalarArgType: " << dtype_traits::getName(); - } else { + else ASSERT_EQ(one.template scalar(), scalar) << "dtype: " << type << ", ScalarArgType: " << dtype_traits::getName(); - } ScalarArgType val = static_cast(rand()); @@ -598,15 +593,13 @@ TEST(TensorBaseTest, host) { auto a = fl::rand({10, 10}); float* ptr = a.host(); - for(int i = 0; i < a.elements(); ++i) { + for(int i = 0; i < a.elements(); ++i) ASSERT_EQ(ptr[i], a.flatten()(i).scalar()); - } float* existingBuffer = new float[100]; a.host(existingBuffer); - for(int i = 0; i < a.elements(); ++i) { + for(int i = 0; i < a.elements(); ++i) ASSERT_EQ(existingBuffer[i], a.flatten()(i).scalar()); - } ASSERT_EQ(Tensor().host(), nullptr); } @@ -615,9 +608,8 @@ TEST(TensorBaseTest, toHostVector) { auto a = fl::rand({10, 10}); auto vec = a.toHostVector(); - for(int i = 0; i < 
a.elements(); ++i) { + for(int i = 0; i < a.elements(); ++i) ASSERT_EQ(vec[i], a.flatten()(i).scalar()); - } ASSERT_EQ(Tensor().toHostVector().size(), 0); } diff --git a/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp b/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp index ea63d03..e2d6254 100644 --- a/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp +++ b/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp @@ -331,10 +331,9 @@ TEST(TensorBinaryOpsTest, BinaryOperatorIncompatibleShapes) { << "dtype: " << type; // TODO ArrayFire needs software impl for fp16 modulo on CUDA backend; // bring this test back when supported. - if(type != dtype::f16) { + if(type != dtype::f16) ASSERT_THROW((void) Values(lhs % rhs), std::invalid_argument) << "dtype: " << type; - } // these operators are generally not well-defined for fps if(type != dtype::f16 && type != dtype::f32 && type != dtype::f64) { ASSERT_THROW((void) Values(lhs | rhs), std::invalid_argument) @@ -440,21 +439,20 @@ TEST(TensorBinaryOpsTest, broadcasting) { unsigned maxnDim = std::max(lhsShape.ndim(), rhsShape.ndim()); Shape outShape{std::vector(maxnDim)}; for(unsigned i = 0; i < maxnDim; ++i) { - if(i > lhsShape.ndim() - 1) { + if(i > lhsShape.ndim() - 1) outShape[i] = rhsShape[i]; - } else if(i > rhsShape.ndim() - 1) { + else if(i > rhsShape.ndim() - 1) outShape[i] = lhsShape[i]; - } else if(lhsShape[i] == 1) { + else if(lhsShape[i] == 1) outShape[i] = rhsShape[i]; - } else if(rhsShape[i] == 1) { + else if(rhsShape[i] == 1) outShape[i] = lhsShape[i]; - } else if(lhsShape[i] == rhsShape[i]) { + else if(lhsShape[i] == rhsShape[i]) outShape[i] = lhsShape[i]; - } else if(lhsShape[i] != rhsShape[i]) { + else if(lhsShape[i] != rhsShape[i]) throw std::runtime_error( "computeBroadcastShape - cannot broadcast shape" ); - } } return outShape; }; diff --git a/flashlight/fl/test/tensor/TensorExtensionTest.cpp b/flashlight/fl/test/tensor/TensorExtensionTest.cpp index 8f4f790..d60930f 100644 --- 
a/flashlight/fl/test/tensor/TensorExtensionTest.cpp +++ b/flashlight/fl/test/tensor/TensorExtensionTest.cpp @@ -57,9 +57,8 @@ TEST(TensorExtensionTest, TestExtension) { auto a = fl::rand({4, 5, 6}); // TODO: this test only works with the ArrayFire backend - gate accordingly - if(Tensor().backendType() != TensorBackendType::ArrayFire) { + if(Tensor().backendType() != TensorBackendType::ArrayFire) GTEST_SKIP() << "Flashlight not built with ArrayFire backend."; - } // TODO: add a fixture to check with available backends ASSERT_TRUE( diff --git a/flashlight/fl/test/tensor/TensorReductionTest.cpp b/flashlight/fl/test/tensor/TensorReductionTest.cpp index b29b59b..320356b 100644 --- a/flashlight/fl/test/tensor/TensorReductionTest.cpp +++ b/flashlight/fl/test/tensor/TensorReductionTest.cpp @@ -19,9 +19,8 @@ using namespace fl; TEST(TensorReductionTest, countNonzero) { std::vector idxs = {0, 3, 4, 7, 24, 78}; auto a = fl::full({10, 10}, 1, fl::dtype::u32); - for(const auto idx : idxs) { + for(const auto idx : idxs) a(idx / 10, idx % 10) = 0; - } ASSERT_TRUE( allClose( @@ -31,10 +30,9 @@ TEST(TensorReductionTest, countNonzero) { ); std::vector sizes(a.shape().dim(0)); - for(unsigned i = 0; i < a.shape().dim(0); ++i) { + for(unsigned i = 0; i < a.shape().dim(0); ++i) sizes[i] = a.shape().dim(0) - fl::sum(a(fl::span, i) == 0, {0}).scalar(); - } ASSERT_TRUE(allClose(Tensor::fromVector(sizes), Tensor::fromVector(sizes))); auto b = fl::full({2, 2, 2}, 1, fl::dtype::u32); @@ -152,16 +150,14 @@ TEST(TensorReductionTest, min) { fl::min(values, indices, in, 0); ASSERT_EQ(indices.shape(), Shape({in.dim(1)})); ASSERT_TRUE(allClose(indices, Tensor::fromVector({3}, {0, 1, 0}))); - for(unsigned i = 0; i < values.elements(); ++i) { + for(unsigned i = 0; i < values.elements(); ++i) ASSERT_TRUE(allClose(values.flat(i), in(fl::span, i)(indices(i)))); - } fl::min(values, indices, in, 1); ASSERT_EQ(indices.shape(), Shape({in.dim(0)})); ASSERT_TRUE(allClose(indices, Tensor::fromVector({2}, {0, 
1}))); - for(unsigned i = 0; i < values.elements(); ++i) { + for(unsigned i = 0; i < values.elements(); ++i) ASSERT_TRUE(allClose(values.flat(i), in(i)(indices(i)))); - } fl::min(values, indices, in, 0, /* keepDims = */ true); ASSERT_EQ(values.shape(), Shape({1, in.dim(1)})); @@ -176,16 +172,14 @@ TEST(TensorReductionTest, max) { fl::max(values, indices, in, 0); ASSERT_EQ(indices.shape(), Shape({in.dim(1)})); ASSERT_TRUE(allClose(indices, Tensor::fromVector({3}, {1, 0, 1}))); - for(unsigned i = 0; i < values.elements(); ++i) { + for(unsigned i = 0; i < values.elements(); ++i) ASSERT_TRUE(allClose(values.flat(i), in(fl::span, i)(indices(i)))); - } fl::max(values, indices, in, 1); ASSERT_EQ(indices.shape(), Shape({in.dim(0)})); ASSERT_TRUE(allClose(indices, Tensor::fromVector({2}, {1, 2}))); - for(unsigned i = 0; i < values.elements(); ++i) { + for(unsigned i = 0; i < values.elements(); ++i) ASSERT_TRUE(allClose(values.flat(i), in(i)(indices(i)))); - } fl::max(values, indices, in, 0, /* keepDims = */ true); ASSERT_EQ(values.shape(), Shape({1, in.dim(1)})); @@ -199,9 +193,8 @@ TEST(TensorReductionTest, cumsum) { auto a = fl::tile(fl::arange(1, max), {1, 2}); auto ref = fl::arange(1, max); - for(int i = 1; i < max - 1; ++i) { + for(int i = 1; i < max - 1; ++i) ref += fl::concatenate({fl::full({i}, 0), fl::arange(1, max - i)}); - } ASSERT_TRUE(allClose(fl::cumsum(a, 0), fl::tile(ref, {1, 2}))); ASSERT_TRUE( diff --git a/flashlight/fl/test/tensor/TensorUnaryOpsTest.cpp b/flashlight/fl/test/tensor/TensorUnaryOpsTest.cpp index 81a3f93..c38a171 100644 --- a/flashlight/fl/test/tensor/TensorUnaryOpsTest.cpp +++ b/flashlight/fl/test/tensor/TensorUnaryOpsTest.cpp @@ -100,16 +100,12 @@ TEST(TensorUnaryOpsTest, sign) { TEST(TensorUnaryOpsTest, tril) { auto checkSquareTril = [](const Dim dim, const Tensor& res, const Tensor& in) { - for(int i = 0; i < dim; ++i) { - for(int j = i + 1; j < dim; ++j) { + for(int i = 0; i < dim; ++i) + for(int j = i + 1; j < dim; ++j) ASSERT_EQ(res(i, 
j).scalar(), 0.); - } - } - for(int i = 0; i < dim; ++i) { - for(int j = 0; j < i; ++j) { + for(int i = 0; i < dim; ++i) + for(int j = 0; j < i; ++j) ASSERT_TRUE(allClose(res(i, j), in(i, j))); - } - } }; Dim dim = 10; auto t = fl::rand({dim, dim}); @@ -121,28 +117,23 @@ TEST(TensorUnaryOpsTest, tril) { Dim dim2 = 3; auto t2 = fl::rand({dim2, dim2, dim2}); auto out2 = fl::tril(t2); - for(unsigned i = 0; i < dim2; ++i) { + for(unsigned i = 0; i < dim2; ++i) checkSquareTril( dim2, out2(fl::span, fl::span, i), t2(fl::span, fl::span, i) ); - } } TEST(TensorUnaryOpsTest, triu) { auto checkSquareTriu = [](const Dim dim, const Tensor& res, const Tensor& in) { - for(unsigned i = 0; i < dim; ++i) { - for(unsigned j = i + 1; j < dim; ++j) { + for(unsigned i = 0; i < dim; ++i) + for(unsigned j = i + 1; j < dim; ++j) ASSERT_TRUE(allClose(res(i, j), in(i, j))); - } - } - for(unsigned i = 0; i < dim; ++i) { - for(unsigned j = 0; j < i; ++j) { + for(unsigned i = 0; i < dim; ++i) + for(unsigned j = 0; j < i; ++j) ASSERT_EQ(res(i, j).scalar(), 0.); - } - } }; int dim = 10; @@ -155,13 +146,12 @@ TEST(TensorUnaryOpsTest, triu) { int dim2 = 3; auto t2 = fl::rand({dim2, dim2, dim2}); auto out2 = fl::triu(t2); - for(int i = 0; i < dim2; ++i) { + for(int i = 0; i < dim2; ++i) checkSquareTriu( dim2, out2(fl::span, fl::span, i), t2(fl::span, fl::span, i) ); - } } TEST(TensorUnaryOpsTest, floor) { diff --git a/flashlight/fl/test/tensor/af/ArrayFireTensorBaseTest.cpp b/flashlight/fl/test/tensor/af/ArrayFireTensorBaseTest.cpp index af86e5d..38937ef 100644 --- a/flashlight/fl/test/tensor/af/ArrayFireTensorBaseTest.cpp +++ b/flashlight/fl/test/tensor/af/ArrayFireTensorBaseTest.cpp @@ -33,15 +33,12 @@ bool allClose( const af::array& b, double absTolerance = 1e-5 ) { - if(a.type() != b.type()) { + if(a.type() != b.type()) return false; - } - if(a.dims() != b.dims()) { + if(a.dims() != b.dims()) return false; - } - if(a.isempty() && b.isempty()) { + if(a.isempty() && b.isempty()) return true; - } 
return af::max(af::abs(a - b)) < absTolerance; } diff --git a/flashlight/fl/test/tensor/af/CachingMemoryManagerTest.cpp b/flashlight/fl/test/tensor/af/CachingMemoryManagerTest.cpp index 2414a14..eb29cb1 100644 --- a/flashlight/fl/test/tensor/af/CachingMemoryManagerTest.cpp +++ b/flashlight/fl/test/tensor/af/CachingMemoryManagerTest.cpp @@ -62,9 +62,8 @@ TEST_F(CachingMemoryManagerTest, DevicePtr) { // The CPU backend in AF allocates a buffer for empty arrays - see // https://github.com/arrayfire/arrayfire/issues/3058. When this is fixed, // this can be relaxed. - if(FL_BACKEND_CPU) { + if(FL_BACKEND_CPU) GTEST_SKIP() << "ArrayFire CPU backend allocates buffers for empty arrays"; - } // Empty array auto arr1 = af::array(0, 0, 0, 0, af::dtype::f32); @@ -107,11 +106,9 @@ TEST_F(CachingMemoryManagerTest, IndexedDevice) { std::vector in2(in.elements()); in.host(in2.data()); - for(int y = 0; y < nyo; y++) { - for(int x = 0; x < nxo; x++) { + for(int y = 0; y < nyo; y++) + for(int x = 0; x < nxo; x++) ASSERT_EQ(in1[(offy + y) * nx + offx + x], in2[y * nxo + x]); - } - } } TEST_F(CachingMemoryManagerTest, LargeNumberOfAllocs) { @@ -136,9 +133,8 @@ TEST_F(CachingMemoryManagerTest, OOM) { // depending on the drivers, afopencl does not seem to guarantee to send an // OOM signal. https://github.com/arrayfire/arrayfire/issues/2650 At the // moment, skipping afopencl. - if(b == AF_BACKEND_OPENCL) { + if(b == AF_BACKEND_OPENCL) GTEST_SKIP() << "Can't run test with the ArrayFire OpenCL backend"; - } af::array a; // N^3 tensor means about 3PB: expected to OOM on today's cuda GPU. 
const unsigned N = 99999; @@ -158,10 +154,9 @@ void testFragmentation( ) { af::Backend b = af::getActiveBackend(); - if(b != AF_BACKEND_CUDA) { + if(b != AF_BACKEND_CUDA) GTEST_SKIP() << "CachingMemoryManager fragmentation tests require CUDA backend"; - } const auto mms = deviceInterface_->getMaxMemorySize(0); const auto maxNumf32 = mms / sizeof(float); // AF f32 is supposed to be 32b @@ -179,13 +174,12 @@ void testFragmentation( try { a3 = af::array(.5f * maxNumf32); } catch(af::exception& ex) { - if(expectOOM) { + if(expectOOM) ASSERT_EQ(ex.err(), AF_ERR_NO_MEM); - } else { + else EXPECT_TRUE(false) << "CachingMemoryManagerTest fragmentaiton not supposed to throw: " << ex.what(); - } } } diff --git a/flashlight/fl/test/tensor/af/MemoryFrameworkTest.cpp b/flashlight/fl/test/tensor/af/MemoryFrameworkTest.cpp index ccc1f32..90af69c 100644 --- a/flashlight/fl/test/tensor/af/MemoryFrameworkTest.cpp +++ b/flashlight/fl/test/tensor/af/MemoryFrameworkTest.cpp @@ -56,15 +56,13 @@ class TestMemoryManager : public MemoryManagerAdapter { const unsigned elSize ) override { size_t size = elSize; - for(unsigned i = 0; i < ndims; ++i) { + for(unsigned i = 0; i < ndims; ++i) size *= dims[i]; - } void* ptr = nullptr; if(size > 0) { - if(lockedBytes >= maxBytes || totalBytes >= maxBuffers) { + if(lockedBytes >= maxBytes || totalBytes >= maxBuffers) signalMemoryCleanup(); - } ptr = this->deviceInterface->nativeAlloc(size); lockedPtrToSizeMap[ptr] = size; @@ -81,21 +79,18 @@ class TestMemoryManager : public MemoryManagerAdapter { } size_t allocated(void* ptr) override { - if(lockedPtrToSizeMap.find(ptr) == lockedPtrToSizeMap.end()) { + if(lockedPtrToSizeMap.find(ptr) == lockedPtrToSizeMap.end()) return 0; - } else { + else return lockedPtrToSizeMap[ptr]; - } } void unlock(void* ptr, bool userLock) override { - if(!ptr) { + if(!ptr) return; - } - if(lockedPtrToSizeMap.find(ptr) == lockedPtrToSizeMap.end()) { + if(lockedPtrToSizeMap.find(ptr) == lockedPtrToSizeMap.end()) return; - } // 
For testing, treat user-allocated and AF-allocated memory identically if(locked.find(ptr) != locked.end()) { @@ -107,17 +102,15 @@ class TestMemoryManager : public MemoryManagerAdapter { void signalMemoryCleanup() override { // Free unlocked memory std::vector freed; - for(auto& entry : lockedPtrToSizeMap) { + for(auto& entry : lockedPtrToSizeMap) if(!isUserLocked(entry.first)) { void* ptr = entry.first; this->deviceInterface->nativeFree(ptr); totalBytes -= lockedPtrToSizeMap[entry.first]; freed.push_back(entry.first); } - } - for(auto ptr : freed) { + for(auto ptr : freed) lockedPtrToSizeMap.erase(ptr); - } } void printInfo( @@ -145,11 +138,10 @@ class TestMemoryManager : public MemoryManagerAdapter { } float getMemoryPressure() override { - if(lockedBytes > maxBytes || totalBuffers > maxBuffers) { + if(lockedBytes > maxBytes || totalBuffers > maxBuffers) return 1.0; - } else { + else return 0.0; - } } bool jitTreeExceedsMemoryPressure(size_t bytes) override { @@ -325,9 +317,8 @@ TEST(MemoryFramework, AdapterInstallerDeviceInterfaceTest) { // The CPU backend in AF allocates a buffer for empty arrays - see // https://github.com/arrayfire/arrayfire/issues/3058. 
When this is fixed, // this can be relaxed/this test will pass - if(FL_BACKEND_CPU) { + if(FL_BACKEND_CPU) GTEST_SKIP() << "ArrayFire CPU backend allocates buffers for empty arrays"; - } std::stringstream logStream; std::stringstream mockLogStream; diff --git a/flashlight/fl/test/tensor/af/MemoryInitTest.cpp b/flashlight/fl/test/tensor/af/MemoryInitTest.cpp index 97bdd2d..1101bfa 100644 --- a/flashlight/fl/test/tensor/af/MemoryInitTest.cpp +++ b/flashlight/fl/test/tensor/af/MemoryInitTest.cpp @@ -17,9 +17,8 @@ using namespace fl; TEST(MemoryInitTest, DefaultManagerInitializesCorrectType) { - if(FL_BACKEND_CPU) { + if(FL_BACKEND_CPU) GTEST_SKIP() << "CachingMemoryManager is not used on CPU backend"; - } auto* manager = MemoryManagerInstaller::currentlyInstalledMemoryManager(); // A non-null value means that a) a custom memory manager has been installed // and b) that a CachingMemoryManager has been installed which is the desired diff --git a/flashlight/pkg/runtime/Runtime.cpp b/flashlight/pkg/runtime/Runtime.cpp index 148b8e8..dc06381 100644 --- a/flashlight/pkg/runtime/Runtime.cpp +++ b/flashlight/pkg/runtime/Runtime.cpp @@ -42,19 +42,16 @@ bool backwardWithScaling( std::shared_ptr reducer ) { auto scaledLoss = loss; - if(dynamicScaler) { + if(dynamicScaler) scaledLoss = dynamicScaler->scale(loss); - } scaledLoss.backward(); - if(reducer) { + if(reducer) reducer->finalize(); - } if(dynamicScaler) { - if(!dynamicScaler->unscale(params)) { + if(!dynamicScaler->unscale(params)) return false; - } dynamicScaler->update(); } diff --git a/flashlight/pkg/runtime/amp/DynamicScaler.cpp b/flashlight/pkg/runtime/amp/DynamicScaler.cpp index d9de6ac..dc57d57 100644 --- a/flashlight/pkg/runtime/amp/DynamicScaler.cpp +++ b/flashlight/pkg/runtime/amp/DynamicScaler.cpp @@ -28,24 +28,22 @@ fl::Variable DynamicScaler::scale(const fl::Variable& loss) { bool DynamicScaler::unscale(std::vector& params) { for(auto& p : params) { - if(!p.isGradAvailable()) { + if(!p.isGradAvailable()) // 
Add a dummy grad for params not used in the backwards pass p.addGrad(Variable(fl::full(p.shape(), 0., p.type()), false)); - } p.grad() = p.grad() / scaleFactor_; if(fl::isInvalidArray(p.grad().tensor())) { if(scaleFactor_ >= fl::kAmpMinimumScaleFactorValue) { scaleFactor_ = scaleFactor_ / 2.0f; FL_LOG(LogLevel::INFO) << "AMP: Scale factor decreased. New value:\t" << scaleFactor_; - } else { + } else FL_LOG(LogLevel::FATAL) << "Minimum loss scale reached: " << fl::kAmpMinimumScaleFactorValue << " with over/underflowing gradients. Lowering the " << "learning rate, using gradient clipping, or " << "increasing the batch size can help resolve " << "loss explosion."; - } successCounter_ = 0; return false; } @@ -56,9 +54,8 @@ bool DynamicScaler::unscale(std::vector& params) { } void DynamicScaler::update() { - if(scaleFactor_ >= maxScaleFactor_) { + if(scaleFactor_ >= maxScaleFactor_) return; - } if(scaleFactor_ == updateInterval_) { scaleFactor_ *= 2; diff --git a/flashlight/pkg/runtime/common/DistributedUtils.cpp b/flashlight/pkg/runtime/common/DistributedUtils.cpp index 627f4ea..a2ab6d9 100644 --- a/flashlight/pkg/runtime/common/DistributedUtils.cpp +++ b/flashlight/pkg/runtime/common/DistributedUtils.cpp @@ -17,7 +17,7 @@ void initDistributed( int maxDevicesPerNode, const std::string& rndvFilepath ) { - if(rndvFilepath.empty()) { + if(rndvFilepath.empty()) distributedInit( fl::DistributedInit::MPI, -1, // unused for MPI @@ -25,7 +25,7 @@ void initDistributed( {{fl::DistributedConstants::kMaxDevicePerNode, std::to_string(maxDevicesPerNode)}} ); - } else { + else distributedInit( fl::DistributedInit::FILE_SYSTEM, worldRank, @@ -34,7 +34,6 @@ void initDistributed( std::to_string(maxDevicesPerNode)}, {fl::DistributedConstants::kFilePath, rndvFilepath}} ); - } } Tensor allreduceGet(fl::AverageValueMeter& mtr) { @@ -68,9 +67,8 @@ Tensor allreduceGet(fl::TopKMeter& mtr) { void allreduceSet(fl::AverageValueMeter& mtr, Tensor& val) { mtr.reset(); auto valVec = 
val.toHostVector(); - if(valVec[2] != 0) { + if(valVec[2] != 0) valVec[0] /= valVec[2]; - } mtr.add(valVec[0], valVec[2]); } @@ -88,9 +86,8 @@ void allreduceSet(fl::EditDistanceMeter& mtr, Tensor& val) { void allreduceSet(fl::CountMeter& mtr, Tensor& val) { mtr.reset(); auto valVec = val.toHostVector(); - for(size_t i = 0; i < valVec.size(); ++i) { + for(size_t i = 0; i < valVec.size(); ++i) mtr.add(i, valVec[i]); - } } void allreduceSet(fl::TimeMeter& mtr, Tensor& val) { diff --git a/flashlight/pkg/runtime/common/DistributedUtils.h b/flashlight/pkg/runtime/common/DistributedUtils.h index 6a52b35..7798bae 100644 --- a/flashlight/pkg/runtime/common/DistributedUtils.h +++ b/flashlight/pkg/runtime/common/DistributedUtils.h @@ -46,9 +46,8 @@ namespace pkg { */ template void syncMeter(T& mtr) { - if(!fl::isDistributedInit()) { + if(!fl::isDistributedInit()) return; - } Tensor arr = allreduceGet(mtr); fl::allReduce(arr); allreduceSet(mtr, arr); diff --git a/flashlight/pkg/runtime/common/SequentialBuilder.cpp b/flashlight/pkg/runtime/common/SequentialBuilder.cpp index 9e0972e..36de0ed 100644 --- a/flashlight/pkg/runtime/common/SequentialBuilder.cpp +++ b/flashlight/pkg/runtime/common/SequentialBuilder.cpp @@ -37,14 +37,12 @@ std::shared_ptr buildSequentialModule( std::vector layers; { std::ifstream in(archfile); - if(!in) { + if(!in) throw std::runtime_error( "fl::pkg::runtime::buildSequentialModule given invalid arch filepath" ); - } - for(std::string str; std::getline(in, str);) { + for(std::string str; std::getline(in, str);) layers.emplace_back(str); - } } int numLinesParsed = 0; @@ -87,11 +85,10 @@ fl::Variable forwardSequentialModuleWithPadMask( for(auto& module : ntwrkSeq->modules()) { auto tr = std::dynamic_pointer_cast(module); auto cfr = std::dynamic_pointer_cast(module); - if(tr != nullptr || cfr != nullptr) { + if(tr != nullptr || cfr != nullptr) output = module->forward({output, fl::noGrad(padMask)}).front(); - } else { + else output = 
module->forward({output}).front(); - } } return output.astype(input.type()); } @@ -119,24 +116,20 @@ std::shared_ptr parseLines( /* ========== TRANSFORMATIONS ========== */ if((params[0] == "RO") || (params[0] == "V")) { - if(params.size() < 2) { + if(params.size() < 2) throw std::invalid_argument("Failed parsing - " + line); - } Shape shape(std::vector(params.size() - 1)); - for(unsigned i = 1; i < params.size(); ++i) { + for(unsigned i = 1; i < params.size(); ++i) shape[i - 1] = std::stoi(params[i]); - } - if(params[0] == "RO") { + if(params[0] == "RO") return std::make_shared(shape); - } else { + else return std::make_shared(shape); - } } if(params[0] == "PD") { - if(!inRange(4, params.size(), 10) || (params.size() & 1)) { + if(!inRange(4, params.size(), 10) || (params.size() & 1)) throw std::invalid_argument("Failed parsing - " + line); - } auto val = std::stod(params[1]); params.resize(10, "0"); std::vector> paddings = { @@ -151,9 +144,8 @@ std::shared_ptr parseLines( /* ========== TRANSFORMERS ========== */ if(params[0] == "TR") { - if(!inRange(6, params.size(), 9)) { + if(!inRange(6, params.size(), 9)) throw std::invalid_argument("Failed parsing - " + line); - } int modelDim = std::stoi(params[1]); int mlpDim = std::stoi(params[2]); int nHead = std::stoi(params[3]); @@ -176,9 +168,8 @@ std::shared_ptr parseLines( } if(params[0] == "CFR") { - if(!inRange(7, params.size(), 8)) { + if(!inRange(7, params.size(), 8)) throw std::invalid_argument("Failed parsing - " + line); - } int modelDim = std::stoi(params[1]); int mlpDim = std::stoi(params[2]); int nHead = std::stoi(params[3]); @@ -199,9 +190,8 @@ std::shared_ptr parseLines( } if(params[0] == "POSEMB") { - if(!inRange(3, params.size(), 4)) { + if(!inRange(3, params.size(), 4)) throw std::invalid_argument("Failed parsing - " + line); - } int layerDim = std::stoi(params[1]); int csz = std::stoi(params[2]); float dropout = (params.size() >= 4) ? 
std::stof(params[3]) : 0.0; @@ -209,9 +199,8 @@ std::shared_ptr parseLines( } if(params[0] == "SINPOSEMB") { - if(!inRange(2, params.size(), 3)) { + if(!inRange(2, params.size(), 3)) throw std::invalid_argument("Failed parsing - " + line); - } int layerDim = std::stoi(params[1]); float inputScale = (params.size() >= 3) ? std::stof(params[2]) : 1.0; return std::make_shared(layerDim, inputScale); @@ -220,9 +209,8 @@ std::shared_ptr parseLines( /* ========== CONVOLUTIONS ========== */ if(params[0] == "C" || params[0] == "C1") { - if(!inRange(5, params.size(), 7)) { + if(!inRange(5, params.size(), 7)) throw std::invalid_argument("Failed parsing - " + line); - } int cisz = std::stoi(params[1]); int cosz = std::stoi(params[2]); int cwx = std::stoi(params[3]); @@ -233,9 +221,8 @@ std::shared_ptr parseLines( } if(params[0] == "TDS") { - if(!inRange(4, params.size(), 8)) { + if(!inRange(4, params.size(), 8)) throw std::invalid_argument("Failed parsing - " + line); - } int cisz = std::stoi(params[1]); int cwx = std::stoi(params[2]); int freqdim = std::stoi(params[3]); @@ -256,9 +243,8 @@ std::shared_ptr parseLines( } if(params[0] == "AC") { - if(!inRange(5, params.size(), 8)) { + if(!inRange(5, params.size(), 8)) throw std::invalid_argument("Failed parsing - " + line); - } int cisz = std::stoi(params[1]); int cosz = std::stoi(params[2]); int cwx = std::stoi(params[3]); @@ -278,9 +264,8 @@ std::shared_ptr parseLines( } if(params[0] == "C2") { - if(!inRange(7, params.size(), 11)) { + if(!inRange(7, params.size(), 11)) throw std::invalid_argument("Failed parsing - " + line); - } int cisz = std::stoi(params[1]); int cosz = std::stoi(params[2]); int cwx = std::stoi(params[3]); @@ -308,9 +293,8 @@ std::shared_ptr parseLines( /* ========== LINEAR ========== */ if(params[0] == "L") { - if(!inRange(3, params.size(), 4)) { + if(!inRange(3, params.size(), 4)) throw std::invalid_argument("Failed parsing - " + line); - } int lisz = std::stoi(params[1]); int losz = std::stoi(params[2]); 
bool bias = (params.size() == 4) && params[3] == "0" ? false : true; @@ -320,56 +304,47 @@ std::shared_ptr parseLines( /* ========== EMBEDDING ========== */ if(params[0] == "E") { - if(params.size() != 3) { + if(params.size() != 3) throw std::invalid_argument("Failed parsing - " + line); - } int embsz = std::stoi(params[1]); int ntokens = std::stoi(params[2]); return std::make_shared(embsz, ntokens); } if(params[0] == "ADAPTIVEE") { - if(params.size() != 3) { + if(params.size() != 3) throw std::invalid_argument("Failed parsing - " + line); - } int embsz = std::stoi(params[1]); std::vector cutoffs; auto tokens = fl::lib::split(',', params[2], true); - for(const auto& token : tokens) { + for(const auto& token : tokens) cutoffs.push_back(std::stoi(fl::lib::trim(token))); - } - for(int i = 1; i < cutoffs.size(); ++i) { - if(cutoffs[i - 1] >= cutoffs[i]) { + for(int i = 1; i < cutoffs.size(); ++i) + if(cutoffs[i - 1] >= cutoffs[i]) throw std::invalid_argument("cutoffs must be strictly ascending"); - } - } return std::make_shared(embsz, cutoffs); } /* ========== NORMALIZATIONS ========== */ if(params[0] == "BN") { - if(!inRange(3, params.size(), 5)) { + if(!inRange(3, params.size(), 5)) throw std::invalid_argument("Failed parsing - " + line); - } int featSz = std::stoi(params[1]); std::vector featDims; - for(int i = 2; i < params.size(); ++i) { + for(int i = 2; i < params.size(); ++i) featDims.emplace_back(std::stoi(params[i])); - } return std::make_shared(featDims, featSz); } if(params[0] == "LN") { - if(!inRange(2, params.size(), 4)) { + if(!inRange(2, params.size(), 4)) throw std::invalid_argument("Failed parsing - " + line); - } std::vector featDims; - for(int i = 1; i < params.size(); ++i) { + for(int i = 1; i < params.size(); ++i) featDims.emplace_back(std::stoi(params[i])); - } - if(featDims == std::vector{3}) { - if(!inRange(7, params.size(), 11)) { + if(featDims == std::vector{3}) + if(!inRange(7, params.size(), 11)) throw std::invalid_argument( "Failed parsing 
- " "flashlight LayerNorm API for specifying `featAxes` is modified " @@ -377,24 +352,20 @@ std::shared_ptr parseLines( "specify LN 0 1 2 instead of LN 3. If you really know what you're " "doing, comment out this check and build again." ); - } - } return std::make_shared(featDims); } if(params[0] == "WN") { - if(params.size() < 3) { + if(params.size() < 3) throw std::invalid_argument("Failed parsing - " + line); - } int dim = std::stoi(params[1]); std::string childStr = fl::lib::join(" ", params.begin() + 2, params.end()); return std::make_shared(parseLine(childStr), dim); } if(params[0] == "DO") { - if(params.size() != 2) { + if(params.size() != 2) throw std::invalid_argument("Failed parsing - " + line); - } auto drpVal = std::stod(params[1]); return std::make_shared(drpVal); } @@ -402,9 +373,8 @@ std::shared_ptr parseLines( /* ========== POOLING ========== */ if((params[0] == "M") || (params[0] == "A")) { - if(params.size() < 5) { + if(params.size() < 5) throw std::invalid_argument("Failed parsing - " + line); - } int wx = std::stoi(params[1]); int wy = std::stoi(params[2]); int dx = std::stoi(params[3]); @@ -420,76 +390,66 @@ std::shared_ptr parseLines( /* ========== ACTIVATIONS ========== */ if(params[0] == "ELU") { - if(params.size() != 1) { + if(params.size() != 1) throw std::invalid_argument("Failed parsing - " + line); - } return std::make_shared(); } if(params[0] == "R") { - if(params.size() != 1) { + if(params.size() != 1) throw std::invalid_argument("Failed parsing - " + line); - } return std::make_shared(); } if(params[0] == "R6") { - if(params.size() != 1) { + if(params.size() != 1) throw std::invalid_argument("Failed parsing - " + line); - } return std::make_shared(); } if(params[0] == "PR") { - if(!inRange(1, params.size(), 3)) { + if(!inRange(1, params.size(), 3)) throw std::invalid_argument("Failed parsing - " + line); - } auto numParams = params.size() > 1 ? std::stoi(params[1]) : 1; auto initVal = params.size() > 2 ? 
std::stod(params[2]) : 0.25; return std::make_shared(numParams, initVal); } if(params[0] == "LG") { - if(params.size() != 1) { + if(params.size() != 1) throw std::invalid_argument("Failed parsing - " + line); - } return std::make_shared(); } if(params[0] == "HT") { - if(params.size() != 1) { + if(params.size() != 1) throw std::invalid_argument("Failed parsing - " + line); - } return std::make_shared(); } if(params[0] == "T") { - if(params.size() != 1) { + if(params.size() != 1) throw std::invalid_argument("Failed parsing - " + line); - } return std::make_shared(); } if(params[0] == "GLU") { - if(params.size() != 2) { + if(params.size() != 2) throw std::invalid_argument("Failed parsing - " + line); - } int dim = std::stoi(params[1]); return std::make_shared(dim); } if(params[0] == "LSM") { - if(params.size() != 2) { + if(params.size() != 2) throw std::invalid_argument("Failed parsing - " + line); - } int dim = std::stoi(params[1]); return std::make_shared(dim); } if(params[0] == "SH") { - if(!inRange(1, params.size(), 2)) { + if(!inRange(1, params.size(), 2)) throw std::invalid_argument("Failed parsing - " + line); - } auto beta = params.size() > 1 ? 
std::stof(params[1]) : 1.0; return std::make_shared(beta); } @@ -513,31 +473,27 @@ std::shared_ptr parseLines( }; if(params[0] == "RNN") { - if(params.size() < 3) { + if(params.size() < 3) throw std::invalid_argument("Failed parsing - " + line); - } return rnnLayer(params, RnnMode::RELU); } if(params[0] == "GRU") { - if(params.size() < 3) { + if(params.size() < 3) throw std::invalid_argument("Failed parsing - " + line); - } return rnnLayer(params, RnnMode::GRU); } if(params[0] == "LSTM") { - if(params.size() < 3) { + if(params.size() < 3) throw std::invalid_argument("Failed parsing - " + line); - } return rnnLayer(params, RnnMode::LSTM); } /* ========== Residual block ========== */ if(params[0] == "RES") { - if(params.size() <= 3) { + if(params.size() <= 3) throw std::invalid_argument("Failed parsing - " + line); - } auto residualBlock = [&](const std::vector& prms, int& numResLayerAndSkip) { @@ -548,37 +504,32 @@ std::shared_ptr parseLines( int numProjections = 0; for(int i = 1; i <= numResLayers + numSkipConnections; ++i) { - if(lineIdx + i + numProjections >= lines.size()) { + if(lineIdx + i + numProjections >= lines.size()) throw std::invalid_argument("Failed parsing Residual block"); - } const std::string& resLine = lines[lineIdx + i + numProjections]; auto resLinePrms = fl::lib::splitOnWhitespace(resLine, true); if(resLinePrms[0] == "SKIP") { - if(!inRange(3, resLinePrms.size(), 4)) { + if(!inRange(3, resLinePrms.size(), 4)) throw std::invalid_argument("Failed parsing - " + resLine); - } resPtr->addShortcut( std::stoi(resLinePrms[1]), std::stoi(resLinePrms[2]) ); - if(resLinePrms.size() == 4) { + if(resLinePrms.size() == 4) resPtr->addScale( std::stoi(resLinePrms[2]), std::stof(resLinePrms[3]) ); - } } else if(resLinePrms[0] == "SKIPL") { - if(!inRange(4, resLinePrms.size(), 5)) { + if(!inRange(4, resLinePrms.size(), 5)) throw std::invalid_argument("Failed parsing - " + resLine); - } int numProjectionLayers = std::stoi(resLinePrms[3]); auto projection = 
std::make_shared(); for(int j = 1; j <= numProjectionLayers; ++j) { - if(lineIdx + i + numProjections + j >= lines.size()) { + if(lineIdx + i + numProjections + j >= lines.size()) throw std::invalid_argument("Failed parsing Residual block"); - } projection->add(parseLine(lines[lineIdx + i + numProjections + j])); } resPtr->addShortcut( @@ -586,16 +537,14 @@ std::shared_ptr parseLines( std::stoi(resLinePrms[2]), projection ); - if(resLinePrms.size() == 5) { + if(resLinePrms.size() == 5) resPtr->addScale( std::stoi(resLinePrms[2]), std::stof(resLinePrms[4]) ); - } numProjections += numProjectionLayers; - } else { + } else resPtr->add(parseLine(resLine)); - } } numResLayerAndSkip = numResLayers + numSkipConnections + numProjections; @@ -603,28 +552,24 @@ std::shared_ptr parseLines( }; auto numBlocks = params.size() == 4 ? std::stoi(params.back()) : 1; - if(numBlocks <= 0) { + if(numBlocks <= 0) throw std::invalid_argument( "Invalid number of residual blocks: " + std::to_string(numBlocks) ); - } if(numBlocks > 1) { auto res = std::make_shared(); - for(int n = 0; n < numBlocks; ++n) { + for(int n = 0; n < numBlocks; ++n) res->add(residualBlock(params, numLinesParsed)); - } return res; - } else { + } else return residualBlock(params, numLinesParsed); - } } /* ========== Data Augmentation ========== */ if(params[0] == "SAUG") { - if(params.size() != 7) { + if(params.size() != 7) throw std::invalid_argument("Failed parsing - " + line); - } return std::make_shared( std::stoi(params[1]), std::stoi(params[2]), @@ -637,9 +582,8 @@ std::shared_ptr parseLines( /* ========== Precision Cast ========== */ if(params[0] == "PC") { - if(params.size() != 2) { + if(params.size() != 2) throw std::invalid_argument("Failed parsing - " + line); - } auto targetType = fl::stringToDtype(params[1]); return std::make_shared(targetType); } diff --git a/flashlight/pkg/runtime/common/Serializer.h b/flashlight/pkg/runtime/common/Serializer.h index da74b01..759922e 100644 --- 
a/flashlight/pkg/runtime/common/Serializer.h +++ b/flashlight/pkg/runtime/common/Serializer.h @@ -58,11 +58,10 @@ namespace pkg { ) { try { std::ofstream file(filepath, std::ios::binary); - if(!file.is_open()) { + if(!file.is_open()) throw std::runtime_error( "failed to open file for writing: " + filepath.string() ); - } cereal::BinaryOutputArchive ar(file); ar(version); ar(args...); @@ -77,11 +76,10 @@ namespace pkg { static void loadImpl(const fs::path& filepath, Args&... args) { try { std::ifstream file(filepath, std::ios::binary); - if(!file.is_open()) { + if(!file.is_open()) throw std::runtime_error( "failed to open file for reading: " + filepath.string() ); - } cereal::BinaryInputArchive ar(file); ar(args...); } catch(const std::exception& ex) { diff --git a/flashlight/pkg/runtime/test/common/SequentialBuilderTest.cpp b/flashlight/pkg/runtime/test/common/SequentialBuilderTest.cpp index d60d8da..da87793 100644 --- a/flashlight/pkg/runtime/test/common/SequentialBuilderTest.cpp +++ b/flashlight/pkg/runtime/test/common/SequentialBuilderTest.cpp @@ -22,9 +22,8 @@ fs::path archDir = ""; } // namespace TEST(SequentialBuilderTest, SeqModule) { - if(FL_BACKEND_CPU) { + if(FL_BACKEND_CPU) GTEST_SKIP() << "Bidirectional RNN not supported"; - } const fs::path archfile = archDir / "arch.txt"; int nchannel = 4; int nclass = 40; @@ -46,14 +45,12 @@ TEST(SequentialBuilderTest, SeqModule) { } TEST(SequentialBuilderTest, Serialization) { - if(FL_BACKEND_CPU) { + if(FL_BACKEND_CPU) GTEST_SKIP() << "Bidirectional RNN not supported"; - } char* user = getenv("USER"); std::string userstr = "unknown"; - if(user != nullptr) { + if(user != nullptr) userstr = std::string(user); - } const fs::path path = fs::temp_directory_path() / "test.mdl"; const fs::path archfile = archDir / "arch.txt"; diff --git a/flashlight/pkg/speech/audio/feature/Ceplifter.cpp b/flashlight/pkg/speech/audio/feature/Ceplifter.cpp index 91f3a28..f13af59 100644 --- 
a/flashlight/pkg/speech/audio/feature/Ceplifter.cpp +++ b/flashlight/pkg/speech/audio/feature/Ceplifter.cpp @@ -18,9 +18,8 @@ Ceplifter::Ceplifter(int numfilters, int lifterparam) : numFilters_(numfilters), lifterParam_(lifterparam), coefs_(numFilters_) { std::iota(coefs_.begin(), coefs_.end(), 0.0); - for(auto& c : coefs_) { + for(auto& c : coefs_) c = 1.0 + 0.5 * lifterParam_ * std::sin(M_PI * c / lifterParam_); - } } std::vector Ceplifter::apply(const std::vector& input) const { @@ -30,17 +29,15 @@ std::vector Ceplifter::apply(const std::vector& input) const { } void Ceplifter::applyInPlace(std::vector& input) const { - if(input.size() % numFilters_ != 0) { + if(input.size() % numFilters_ != 0) throw std::invalid_argument( "Ceplifter: input size is not divisible by numFilters" ); - } size_t n = 0; for(auto& in : input) { in *= coefs_[n++]; - if(n == numFilters_) { + if(n == numFilters_) n = 0; - } } } } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/Dct.cpp b/flashlight/pkg/speech/audio/feature/Dct.cpp index c88c339..18d2646 100644 --- a/flashlight/pkg/speech/audio/feature/Dct.cpp +++ b/flashlight/pkg/speech/audio/feature/Dct.cpp @@ -18,12 +18,10 @@ namespace fl::lib::audio { Dct::Dct(int numfilters, int numceps) : numFilters_(numfilters), numCeps_(numceps), dctMat_(numfilters * numceps) { - for(size_t f = 0; f < numFilters_; ++f) { - for(size_t c = 0; c < numCeps_; ++c) { + for(size_t f = 0; f < numFilters_; ++f) + for(size_t c = 0; c < numCeps_; ++c) dctMat_[f * numCeps_ + c] = std::sqrt(2.0 / numFilters_) * std::cos(M_PI * c * (f + 0.5) / numFilters_); - } - } } std::vector Dct::apply(const std::vector& input) const { diff --git a/flashlight/pkg/speech/audio/feature/Derivatives.cpp b/flashlight/pkg/speech/audio/feature/Derivatives.cpp index 084f845..f1d1008 100644 --- a/flashlight/pkg/speech/audio/feature/Derivatives.cpp +++ b/flashlight/pkg/speech/audio/feature/Derivatives.cpp @@ -20,15 +20,13 @@ std::vector Derivatives::apply( const 
std::vector& input, int numfeat ) const { - if(input.size() % numfeat != 0) { + if(input.size() % numfeat != 0) throw std::invalid_argument( "Derivatives: input size is not divisible by numFeatures" ); - } // Compute deltas - if(deltaWindow_ <= 0) { + if(deltaWindow_ <= 0) return input; - } auto deltas = computeDerivative(input, deltaWindow_, numfeat); size_t szMul = 2; @@ -56,13 +54,12 @@ std::vector Derivatives::apply( output.data() + curOutIdx + numfeat ); // copy double-deltas - if(accWindow_ > 0) { + if(accWindow_ > 0) std::copy( doubledeltas.data() + curInIdx, doubledeltas.data() + curInIdx + numfeat, output.data() + curOutIdx + 2 * numfeat ); - } } return output; } @@ -75,17 +72,15 @@ std::vector Derivatives::computeDerivative( int numframes = input.size() / numfeat; std::vector output(input.size(), 0.0); float denominator = (windowlen * (windowlen + 1) * (2 * windowlen + 1)) / 3.0; - for(size_t i = 0; i < numframes; ++i) { + for(size_t i = 0; i < numframes; ++i) for(size_t j = 0; j < numfeat; ++j) { size_t curIdx = i * numfeat + j; - for(size_t d = 1; d <= windowlen; ++d) { + for(size_t d = 1; d <= windowlen; ++d) output[curIdx] += d * (input[curIdx + std::min((numframes - i - 1), d) * numfeat] - input[curIdx - std::min(i, d) * numfeat]); - } output[curIdx] /= denominator; } - } return output; } } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/Dither.cpp b/flashlight/pkg/speech/audio/feature/Dither.cpp index 08a135f..9a7b923 100644 --- a/flashlight/pkg/speech/audio/feature/Dither.cpp +++ b/flashlight/pkg/speech/audio/feature/Dither.cpp @@ -22,8 +22,7 @@ std::vector Dither::apply(const std::vector& input) { void Dither::applyInPlace(std::vector& input) { std::uniform_real_distribution distribution(0.0, 1.0); - for(auto& i : input) { + for(auto& i : input) i += ditherVal_ * distribution(rng_); - } } } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/FeatureParams.h b/flashlight/pkg/speech/audio/feature/FeatureParams.h index 
f5562cb..c2a494e 100644 --- a/flashlight/pkg/speech/audio/feature/FeatureParams.h +++ b/flashlight/pkg/speech/audio/feature/FeatureParams.h @@ -161,9 +161,8 @@ namespace lib { int64_t numFrames(int64_t inSize) const { auto frameSize = numFrameSizeSamples(); auto frameStride = numFrameStrideSamples(); - if(frameStride <= 0 || inSize < frameSize) { + if(frameStride <= 0 || inSize < frameSize) return 0; - } return 1 + std::floor((inSize - frameSize) * 1.0 / frameStride); } }; diff --git a/flashlight/pkg/speech/audio/feature/Mfcc.cpp b/flashlight/pkg/speech/audio/feature/Mfcc.cpp index 838893d..dde015c 100644 --- a/flashlight/pkg/speech/audio/feature/Mfcc.cpp +++ b/flashlight/pkg/speech/audio/feature/Mfcc.cpp @@ -22,38 +22,34 @@ Mfcc::Mfcc(const FeatureParams& params) : Mfsc(params), std::vector Mfcc::apply(const std::vector& input) { auto frames = frameSignal(input, this->featParams_); - if(frames.empty()) { + if(frames.empty()) return {}; - } int nSamples = this->featParams_.numFrameSizeSamples(); int nFrames = frames.size() / nSamples; std::vector energy(nFrames); - if(this->featParams_.useEnergy && this->featParams_.rawEnergy) { + if(this->featParams_.useEnergy && this->featParams_.rawEnergy) for(size_t f = 0; f < nFrames; ++f) { auto begin = frames.data() + f * nSamples; energy[f] = std::log(std::inner_product(begin, begin + nSamples, begin, 0.0)); } - } auto mfscfeat = this->mfscImpl(frames); auto cep = dct_.apply(mfscfeat); ceplifter_.applyInPlace(cep); auto nFeat = this->featParams_.numCepstralCoeffs; if(this->featParams_.useEnergy) { - if(!this->featParams_.rawEnergy) { + if(!this->featParams_.rawEnergy) for(size_t f = 0; f < nFrames; ++f) { auto begin = frames.data() + f * nSamples; energy[f] = std::log(std::inner_product(begin, begin + nSamples, begin, 0.0)); } - } // Replace C0 with energy - for(size_t f = 0; f < nFrames; ++f) { + for(size_t f = 0; f < nFrames; ++f) cep[f * nFeat] = energy[f]; - } } return derivatives_.apply(cep, nFeat); } @@ -65,8 +61,7 @@ 
int Mfcc::outputSize(int inputSz) { void Mfcc::validateMfccParams() const { this->validatePowSpecParams(); this->validateMfscParams(); - if(this->featParams_.lifterParam < 0) { + if(this->featParams_.lifterParam < 0) throw std::invalid_argument("Mfcc: lifterparam must be nonnegative"); - } } } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/Mfsc.cpp b/flashlight/pkg/speech/audio/feature/Mfsc.cpp index 7669db1..db0fb5e 100644 --- a/flashlight/pkg/speech/audio/feature/Mfsc.cpp +++ b/flashlight/pkg/speech/audio/feature/Mfsc.cpp @@ -29,15 +29,14 @@ Mfsc::Mfsc(const FeatureParams& params) : PowerSpectrum(params), std::vector Mfsc::apply(const std::vector& input) { auto frames = frameSignal(input, this->featParams_); - if(frames.empty()) { + if(frames.empty()) return {}; - } int nSamples = this->featParams_.numFrameSizeSamples(); int nFrames = frames.size() / nSamples; std::vector energy(nFrames); - if(this->featParams_.useEnergy && this->featParams_.rawEnergy) { + if(this->featParams_.useEnergy && this->featParams_.rawEnergy) for(size_t f = 0; f < nFrames; ++f) { auto begin = frames.data() + f * nSamples; energy[f] = std::log( @@ -52,11 +51,10 @@ std::vector Mfsc::apply(const std::vector& input) { ) ); } - } auto mfscFeat = mfscImpl(frames); auto numFeat = this->featParams_.numFilterbankChans; if(this->featParams_.useEnergy) { - if(!this->featParams_.rawEnergy) { + if(!this->featParams_.rawEnergy) for(size_t f = 0; f < nFrames; ++f) { auto begin = frames.data() + f * nSamples; energy[f] = std::log( @@ -71,7 +69,6 @@ std::vector Mfsc::apply(const std::vector& input) { ) ); } - } std::vector newMfscFeat(mfscFeat.size() + nFrames); for(size_t f = 0; f < nFrames; ++f) { size_t start = f * numFeat; @@ -91,13 +88,12 @@ std::vector Mfsc::apply(const std::vector& input) { std::vector Mfsc::mfscImpl(std::vector& frames) { auto powspectrum = this->powSpectrumImpl(frames); - if(this->featParams_.usePower) { + if(this->featParams_.usePower) std::transform( 
powspectrum.begin(), powspectrum.end(), powspectrum.begin(), [](float x) { return x * x; }); - } auto triflt = triFltBank_.apply(powspectrum, this->featParams_.melFloor); std::transform( triflt.begin(), @@ -116,10 +112,9 @@ int Mfsc::outputSize(int inputSz) { void Mfsc::validateMfscParams() const { this->validatePowSpecParams(); - if(this->featParams_.numFilterbankChans <= 0) { + if(this->featParams_.numFilterbankChans <= 0) throw std::invalid_argument("Mfsc: numFilterbankChans must be positive"); - } else if(this->featParams_.melFloor <= 0.0) { + else if(this->featParams_.melFloor <= 0.0) throw std::invalid_argument("Mfsc: melfloor must be positive"); - } } } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/PowerSpectrum.cpp b/flashlight/pkg/speech/audio/feature/PowerSpectrum.cpp index 38c626b..0fdb667 100644 --- a/flashlight/pkg/speech/audio/feature/PowerSpectrum.cpp +++ b/flashlight/pkg/speech/audio/feature/PowerSpectrum.cpp @@ -47,9 +47,8 @@ PowerSpectrum::PowerSpectrum(const FeatureParams& params) : featParams_(params), std::vector PowerSpectrum::apply(const std::vector& input) { auto frames = frameSignal(input, featParams_); - if(frames.empty()) { + if(frames.empty()) return {}; - } return powSpectrumImpl(frames); } @@ -59,10 +58,9 @@ std::vector PowerSpectrum::powSpectrumImpl(std::vector& frames) { int nFft = featParams_.nFft(); int K = featParams_.filterFreqResponseLen(); - if(featParams_.ditherVal != 0.0) { + if(featParams_.ditherVal != 0.0) frames = dither_.apply(frames); - } - if(featParams_.zeroMeanFrame) { + if(featParams_.zeroMeanFrame) for(size_t f = 0; f < nFrames; ++f) { auto begin = frames.data() + f * nSamples; float mean = std::accumulate(begin, begin + nSamples, 0.0); @@ -73,10 +71,8 @@ std::vector PowerSpectrum::powSpectrumImpl(std::vector& frames) { begin, [mean](float x) { return x - mean; }); } - } - if(featParams_.preemCoef != 0) { + if(featParams_.preemCoef != 0) preEmphasis_.applyInPlace(frames); - } 
windowing_.applyInPlace(frames); std::vector dft(K * nFrames); for(size_t f = 0; f < nFrames; ++f) { @@ -93,12 +89,11 @@ std::vector PowerSpectrum::powSpectrumImpl(std::vector& frames) { outFftBuf_[2 * i + 1] = -outFftBuf_[2 * nFft - 2 * i + 1]; } - for(size_t i = 0; i < K; ++i) { + for(size_t i = 0; i < K; ++i) dft[f * K + i] = std::sqrt( outFftBuf_[2 * i] * outFftBuf_[2 * i] + outFftBuf_[2 * i + 1] * outFftBuf_[2 * i + 1] ); - } } } return dft; @@ -108,13 +103,12 @@ std::vector PowerSpectrum::batchApply( const std::vector& input, int batchSz ) { - if(batchSz <= 0) { + if(batchSz <= 0) throw std::invalid_argument("PowerSpectrum: negative batchSz"); - } else if(input.size() % batchSz != 0) { + else if(input.size() % batchSz != 0) throw std::invalid_argument( "PowerSpectrum: input size is not divisible by batchSz" ); - } int N = input.size() / batchSz; int outputSz = outputSize(N); std::vector feat(outputSz * batchSz); @@ -124,9 +118,8 @@ std::vector PowerSpectrum::batchApply( auto start = input.begin() + b * N; std::vector inputBuf(start, start + N); auto curFeat = apply(inputBuf); - if(outputSz != curFeat.size()) { + if(outputSz != curFeat.size()) throw std::logic_error("PowerSpectrum: apply() returned wrong size"); - } std::copy( curFeat.begin(), curFeat.end(), @@ -145,17 +138,16 @@ int PowerSpectrum::outputSize(int inputSz) { } void PowerSpectrum::validatePowSpecParams() const { - if(featParams_.samplingFreq <= 0) { + if(featParams_.samplingFreq <= 0) throw std::invalid_argument("PowerSpectrum: samplingFreq is negative"); - } else if(featParams_.frameSizeMs <= 0) { + else if(featParams_.frameSizeMs <= 0) throw std::invalid_argument("PowerSpectrum: frameSizeMs is negative"); - } else if(featParams_.frameStrideMs <= 0) { + else if(featParams_.frameStrideMs <= 0) throw std::invalid_argument("PowerSpectrum: frameStrideMs is negative"); - } else if(featParams_.numFrameSizeSamples() <= 0) { + else if(featParams_.numFrameSizeSamples() <= 0) throw 
std::invalid_argument("PowerSpectrum: frameSizeMs is too low"); - } else if(featParams_.numFrameStrideSamples() <= 0) { + else if(featParams_.numFrameStrideSamples() <= 0) throw std::invalid_argument("PowerSpectrum: frameStrideMs is too low"); - } } PowerSpectrum::~PowerSpectrum() { diff --git a/flashlight/pkg/speech/audio/feature/PreEmphasis.cpp b/flashlight/pkg/speech/audio/feature/PreEmphasis.cpp index 8b55836..807dd18 100644 --- a/flashlight/pkg/speech/audio/feature/PreEmphasis.cpp +++ b/flashlight/pkg/speech/audio/feature/PreEmphasis.cpp @@ -14,12 +14,10 @@ namespace fl::lib::audio { PreEmphasis::PreEmphasis(float alpha, int N) : preemCoef_(alpha), windowLength_(N) { - if(windowLength_ <= 1) { + if(windowLength_ <= 1) throw std::invalid_argument("PreEmphasis: windowLength must be > 1"); - } - if(preemCoef_ < 0.0 || preemCoef_ >= 1.0) { + if(preemCoef_ < 0.0 || preemCoef_ >= 1.0) throw std::invalid_argument("PreEmphasis: alpha must be in [0, 1)"); - } }; std::vector PreEmphasis::apply(const std::vector& input) const { @@ -29,18 +27,16 @@ std::vector PreEmphasis::apply(const std::vector& input) const { } void PreEmphasis::applyInPlace(std::vector& input) const { - if(input.size() % windowLength_ != 0) { + if(input.size() % windowLength_ != 0) throw std::invalid_argument( "PreEmphasis: input.size() not divisible by windowLength" ); - } size_t nframes = input.size() / windowLength_; for(size_t n = nframes; n > 0; --n) { size_t e = n * windowLength_ - 1; // end of current frame size_t s = (n - 1) * windowLength_; // start of current frame - for(size_t i = e; i > s; --i) { + for(size_t i = e; i > s; --i) input[i] -= (preemCoef_ * input[i - 1]); - } input[s] *= (1 - preemCoef_); } } diff --git a/flashlight/pkg/speech/audio/feature/SpeechUtils.cpp b/flashlight/pkg/speech/audio/feature/SpeechUtils.cpp index ff5c40a..0f1a59b 100644 --- a/flashlight/pkg/speech/audio/feature/SpeechUtils.cpp +++ b/flashlight/pkg/speech/audio/feature/SpeechUtils.cpp @@ -32,11 +32,9 @@ 
std::vector frameSignal( // not range -1..1, hence scale up here to match (approx) float scale = 32768.0; std::vector frames(numframes * frameSize); - for(size_t f = 0; f < numframes; ++f) { - for(size_t i = 0; i < frameSize; ++i) { + for(size_t f = 0; f < numframes; ++f) + for(size_t i = 0; i < frameSize; ++i) frames[f * frameSize + i] = scale * input[f * frameStride + i]; - } - } return frames; } @@ -49,9 +47,8 @@ std::vector cblasGemm( if( n <= 0 || k <= 0 || matA.empty() || (matA.size() % k != 0) || (matB.size() != n * k) - ) { + ) throw std::invalid_argument("cblasGemm: invalid arguments"); - } int m = matA.size() / k; diff --git a/flashlight/pkg/speech/audio/feature/TriFilterbank.cpp b/flashlight/pkg/speech/audio/feature/TriFilterbank.cpp index ba58d4d..3853415 100644 --- a/flashlight/pkg/speech/audio/feature/TriFilterbank.cpp +++ b/flashlight/pkg/speech/audio/feature/TriFilterbank.cpp @@ -35,20 +35,18 @@ TriFilterbank::TriFilterbank( float dwarp = (maxwarpfreq - minwarpfreq) / (numfilters + 1); std::vector f(numFilters_ + 2); - for(int i = 0; i < (numFilters_ + 2); ++i) { + for(int i = 0; i < (numFilters_ + 2); ++i) f[i] = warpedToHertzScale(i * dwarp + minwarpfreq, freqScale_) * (filterLen_ - 1) * 2.0 / samplingFreq_; - } float minH = 0.0; - for(size_t i = 0; i < filterLen_; ++i) { + for(size_t i = 0; i < filterLen_; ++i) for(size_t j = 0; j < numFilters_; ++j) { float hislope = (i - f[j]) / (f[j + 1] - f[j]); float loslope = (f[j + 2] - i) / (f[j + 2] - f[j + 1]); H_[i * numFilters_ + j] = std::max(std::min(hislope, loslope), minH); } - } } std::vector TriFilterbank::apply( diff --git a/flashlight/pkg/speech/audio/feature/Windowing.cpp b/flashlight/pkg/speech/audio/feature/Windowing.cpp index aa9d10d..c0de49a 100644 --- a/flashlight/pkg/speech/audio/feature/Windowing.cpp +++ b/flashlight/pkg/speech/audio/feature/Windowing.cpp @@ -17,20 +17,17 @@ namespace fl::lib::audio { Windowing::Windowing(int N, WindowType windowtype) : windowLength_(N), 
windowType_(windowtype), coefs_(N) { - if(windowLength_ <= 1) { + if(windowLength_ <= 1) throw std::invalid_argument("Windowing: windowLength must be > 1"); - } std::iota(coefs_.begin(), coefs_.end(), 0.0); switch(windowtype) { case WindowType::HAMMING: - for(auto& c : coefs_) { + for(auto& c : coefs_) c = 0.54 - 0.46 * std::cos(2 * M_PI * c / (N - 1)); - } break; case WindowType::HANNING: - for(auto& c : coefs_) { + for(auto& c : coefs_) c = 0.5 * (1.0 - std::cos(2 * M_PI * c / (N - 1))); - } break; default: throw std::invalid_argument("Windowing: unsupported window type"); @@ -44,17 +41,15 @@ std::vector Windowing::apply(const std::vector& input) const { } void Windowing::applyInPlace(std::vector& input) const { - if(input.size() % windowLength_ != 0) { + if(input.size() % windowLength_ != 0) throw std::invalid_argument( "Windowing: input size is not divisible by windowLength" ); - } size_t n = 0; for(auto& in : input) { in *= coefs_[n++]; - if(n == windowLength_) { + if(n == windowLength_) n = 0; - } } } } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/AdditiveNoise.cpp b/flashlight/pkg/speech/augmentation/AdditiveNoise.cpp index 31f1624..bf1cb4c 100644 --- a/flashlight/pkg/speech/augmentation/AdditiveNoise.cpp +++ b/flashlight/pkg/speech/augmentation/AdditiveNoise.cpp @@ -38,18 +38,16 @@ AdditiveNoise::AdditiveNoise( ) : conf_(config), rng_(seed) { std::ifstream listFile(conf_.listFilePath_); - if(!listFile) { + if(!listFile) throw std::runtime_error( "AdditiveNoise failed to open listFilePath_=" + conf_.listFilePath_ ); - } while(!listFile.eof()) { try { std::string filename; std::getline(listFile, filename); - if(!filename.empty()) { + if(!filename.empty()) noiseFiles_.push_back(filename); - } } catch(std::exception& ex) { throw std::runtime_error( "AdditiveNoise failed to read listFilePath_=" + conf_.listFilePath_ @@ -60,15 +58,13 @@ AdditiveNoise::AdditiveNoise( } void AdditiveNoise::apply(std::vector& signal) { - if(rng_.random() >= 
conf_.proba_) { + if(rng_.random() >= conf_.proba_) return; - } const float signalRms = rootMeanSquare(signal); const float snr = rng_.uniform(conf_.minSnr_, conf_.maxSnr_); const int nClips = rng_.randInt(conf_.nClipsMin_, conf_.nClipsMax_); - if(nClips == 0) { + if(nClips == 0) return; - } int augStart = rng_.randInt(0, signal.size() - 1); // overflow implies we start at the beginning again. int augEnd = augStart + conf_.ratio_ * signal.size(); @@ -78,23 +74,20 @@ void AdditiveNoise::apply(std::vector& signal) { auto curNoiseFileIdx = rng_.randInt(0, noiseFiles_.size() - 1); auto curNoise = loadSound(noiseFiles_[curNoiseFileIdx]); int shift = rng_.randInt(0, curNoise.size() - 1); - for(int j = augStart; j < augEnd; ++j) { + for(int j = augStart; j < augEnd; ++j) mixedNoise[j % mixedNoise.size()] += curNoise[(shift + j) % curNoise.size()]; - } } const float noiseRms = rootMeanSquare(mixedNoise); if(noiseRms > 0) { // https://en.wikipedia.org/wiki/Signal-to-noise_ratio const float noiseMult = (signalRms / (noiseRms * std::pow(10, snr / 20.0))); - for(int i = 0; i < signal.size(); ++i) { + for(int i = 0; i < signal.size(); ++i) signal[i] += mixedNoise[i] * noiseMult; - } - } else { + } else FL_LOG(fl::LogLevel::WARNING) << "AdditiveNoise::apply() invalid noiseRms=" << noiseRms; - } } } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/GaussianNoise.cpp b/flashlight/pkg/speech/augmentation/GaussianNoise.cpp index b276115..a1a5a0c 100644 --- a/flashlight/pkg/speech/augmentation/GaussianNoise.cpp +++ b/flashlight/pkg/speech/augmentation/GaussianNoise.cpp @@ -34,16 +34,14 @@ GaussianNoise::GaussianNoise( rng_(seed) {} void GaussianNoise::apply(std::vector& signal) { - if(rng_.random() >= conf_.proba_) { + if(rng_.random() >= conf_.proba_) return; - } const float signalRms = rootMeanSquare(signal); const float snr = rng_.uniform(conf_.minSnr_, conf_.maxSnr_); const float noiseMult = signalRms / std::pow(10, snr / 20.0); - for(int i = 0; i < signal.size(); 
++i) { + for(int i = 0; i < signal.size(); ++i) signal[i] += rng_.gaussian(0, noiseMult); - } } } // namespace fl diff --git a/flashlight/pkg/speech/augmentation/Reverberation.cpp b/flashlight/pkg/speech/augmentation/Reverberation.cpp index 7e2e03b..c7ca266 100644 --- a/flashlight/pkg/speech/augmentation/Reverberation.cpp +++ b/flashlight/pkg/speech/augmentation/Reverberation.cpp @@ -43,12 +43,10 @@ void ReverbEcho::applyReverb( // Add jitter noise for the delay float jitter = 1 + rng_.uniform(-conf_.jitter_, conf_.jitter_); size_t delay = 1 + int(jitter * firstDelay * conf_.sampleRate_); - if(delay > length - 1) { + if(delay > length - 1) break; - } - for(int j = 0; j < length - delay - 1; ++j) { + for(int j = 0; j < length - delay - 1; ++j) reverb[delay + j] += echo[j] * frac; - } // Add jitter noise for the attenuation jitter = 1 + rng_.uniform(-conf_.jitter_, conf_.jitter_); @@ -57,15 +55,13 @@ void ReverbEcho::applyReverb( frac *= attenuation; } } - for(int i = 0; i < length; ++i) { + for(int i = 0; i < length; ++i) source[i] += reverb[i]; - } } void ReverbEcho::apply(std::vector& sound) { - if(rng_.random() >= conf_.proba_) { + if(rng_.random() >= conf_.proba_) return; - } // Sample characteristics for the reverb float initial = rng_.uniform(conf_.initialMin_, conf_.initialMax_); float firstDelay = rng_.uniform(conf_.firstDelayMin_, conf_.firstDelayMax_); diff --git a/flashlight/pkg/speech/augmentation/SoundEffect.cpp b/flashlight/pkg/speech/augmentation/SoundEffect.cpp index 2137ff8..07f5912 100644 --- a/flashlight/pkg/speech/augmentation/SoundEffect.cpp +++ b/flashlight/pkg/speech/augmentation/SoundEffect.cpp @@ -16,9 +16,8 @@ namespace fl::pkg::speech::sfx { std::string SoundEffectChain::prettyString() const { std::stringstream ss; ss << '{' << std::endl; - for(const std::shared_ptr& sfx : soundEffects_) { + for(const std::shared_ptr& sfx : soundEffects_) ss << "{" << sfx->prettyString() << '}' << std::endl; - } ss << '}'; return ss.str(); } @@ -28,9 +27,8 
@@ void SoundEffectChain::add(std::shared_ptr SoundEffect) { } void SoundEffectChain::apply(std::vector& sound) { - for(std::shared_ptr& effect : soundEffects_) { + for(std::shared_ptr& effect : soundEffects_) effect->apply(sound); - } } bool SoundEffectChain::empty() { @@ -41,16 +39,14 @@ Normalize::Normalize(bool onlyIfTooHigh) : onlyIfTooHigh_(onlyIfTooHigh) {} void Normalize::apply(std::vector& sound) { float maxAbs = 0.0f; - for(float i : sound) { + for(float i : sound) maxAbs = std::fmax(maxAbs, std::fabs(i)); - } - if(!onlyIfTooHigh_ || maxAbs > 1.0f) { + if(!onlyIfTooHigh_ || maxAbs > 1.0f) std::transform( sound.begin(), sound.end(), sound.begin(), [maxAbs](float amp) -> float { return amp / maxAbs; }); - } } std::string Normalize::prettyString() const { diff --git a/flashlight/pkg/speech/augmentation/SoundEffectApply.cpp b/flashlight/pkg/speech/augmentation/SoundEffectApply.cpp index 533f4c0..323ce78 100644 --- a/flashlight/pkg/speech/augmentation/SoundEffectApply.cpp +++ b/flashlight/pkg/speech/augmentation/SoundEffectApply.cpp @@ -36,18 +36,15 @@ int main(int argc, char** argv) { + "--config=[path to config file]" ); - if(argc <= 1) { + if(argc <= 1) LOG(FATAL) << gflags::ProgramUsage(); - } gflags::ParseCommandLineFlags(&argc, &argv, false); - if(FLAGS_config.empty()) { + if(FLAGS_config.empty()) LOG(FATAL) << "flag --config must point to sound effect config file"; - } - if(FLAGS_input.empty()) { + if(FLAGS_input.empty()) LOG(FATAL) << "flag --input must point to input file"; - } auto sound = loadSound(FLAGS_input); auto info = loadSoundInfo(FLAGS_input); diff --git a/flashlight/pkg/speech/augmentation/SoundEffectConfig.cpp b/flashlight/pkg/speech/augmentation/SoundEffectConfig.cpp index 1fbc3f9..9144dfb 100644 --- a/flashlight/pkg/speech/augmentation/SoundEffectConfig.cpp +++ b/flashlight/pkg/speech/augmentation/SoundEffectConfig.cpp @@ -71,22 +71,21 @@ void serialize(Archive& ar, TimeStretch::Config& conf) { template void serialize(Archive& ar, 
SoundEffectConfig& conf) { ar(cereal::make_nvp("type", conf.type_)); - if(conf.type_ == kAdditiveNoise) { + if(conf.type_ == kAdditiveNoise) ar(cereal::make_nvp("additiveNoiseConfig", conf.additiveNoiseConfig_)); - } else if(conf.type_ == kAmplify) { + else if(conf.type_ == kAmplify) ar(cereal::make_nvp("amplifyConfig", conf.amplifyConfig_)); - } else if(conf.type_ == kNormalize) { + else if(conf.type_ == kNormalize) ar( cereal::make_nvp( "normalizeOnlyIfTooHigh", conf.normalizeOnlyIfTooHigh_ ) ); - } else if(conf.type_ == kReverbEcho) { + else if(conf.type_ == kReverbEcho) ar(cereal::make_nvp("reverbEchoConfig", conf.reverbEchoConfig_)); - } else if(conf.type_ == kTimeStretch) { + else if(conf.type_ == kTimeStretch) ar(cereal::make_nvp("timeStretchConfig", conf.timeStretchConfig_)); - } } } // namespace cereal @@ -134,25 +133,24 @@ std::shared_ptr createSoundEffect( ) { auto sfxChain = std::make_shared(); for(const SoundEffectConfig& conf : sfxConfigs) { - if(conf.type_ == kAdditiveNoise) { + if(conf.type_ == kAdditiveNoise) sfxChain->add( std::make_shared(conf.additiveNoiseConfig_, seed) ); - } else if(conf.type_ == kAmplify) { + else if(conf.type_ == kAmplify) sfxChain->add(std::make_shared(conf.amplifyConfig_)); - } else if(conf.type_ == kClampAmplitude) { + else if(conf.type_ == kClampAmplitude) sfxChain->add(std::make_shared()); - } else if(conf.type_ == kNormalize) { + else if(conf.type_ == kNormalize) sfxChain->add(std::make_shared(conf.normalizeOnlyIfTooHigh_)); - } else if(conf.type_ == kReverbEcho) { + else if(conf.type_ == kReverbEcho) sfxChain->add(std::make_shared(conf.reverbEchoConfig_, seed)); - } else if(conf.type_ == kTimeStretch) { + else if(conf.type_ == kTimeStretch) sfxChain->add( std::make_shared(conf.timeStretchConfig_, seed) ); - } else { + else LOG(FATAL) << "Invalid sound effect config type=" << conf.type_; - } } return sfxChain; } diff --git a/flashlight/pkg/speech/augmentation/SoundEffectUtil.cpp 
b/flashlight/pkg/speech/augmentation/SoundEffectUtil.cpp index 3685dd5..65cc940 100644 --- a/flashlight/pkg/speech/augmentation/SoundEffectUtil.cpp +++ b/flashlight/pkg/speech/augmentation/SoundEffectUtil.cpp @@ -17,9 +17,8 @@ RandomNumberGenerator::RandomNumberGenerator(int seed /* = 0 */) : randomEngine_ gaussianDist_(0, 1) {} int RandomNumberGenerator::randInt(int minVal, int maxVal) { - if(minVal > maxVal) { + if(minVal > maxVal) std::swap(minVal, maxVal); - } return randomEngine_() % (maxVal - minVal + 1) + minVal; } @@ -37,9 +36,8 @@ float RandomNumberGenerator::gaussian(float mean, float sigma) { float rootMeanSquare(const std::vector& signal) { float sumSquares = 0; - for(int i = 0; i < signal.size(); ++i) { + for(int i = 0; i < signal.size(); ++i) sumSquares += signal[i] * signal[i]; - } return std::sqrt(sumSquares / signal.size()); } @@ -58,9 +56,8 @@ std::vector genTestSinWave(size_t numSamples, size_t freq, size_t sampleR static_cast(sampleRate) / static_cast(freq); const float ratio = (2 * M_PI) / waveLenSamples; - for(size_t i = 0; i < numSamples; ++i) { + for(size_t i = 0; i < numSamples; ++i) output.at(i) = amplitude * std::sin(static_cast(i) * ratio); - } return output; } diff --git a/flashlight/pkg/speech/augmentation/TimeStretch.cpp b/flashlight/pkg/speech/augmentation/TimeStretch.cpp index 6d4f88d..cb6726c 100644 --- a/flashlight/pkg/speech/augmentation/TimeStretch.cpp +++ b/flashlight/pkg/speech/augmentation/TimeStretch.cpp @@ -28,9 +28,8 @@ TimeStretch::TimeStretch( } void TimeStretch::apply(std::vector& signal) { - if(rng_.random() >= conf_.proba_) { + if(rng_.random() >= conf_.proba_) return; - } const float factor = rng_.uniform(conf_.minFactor_, conf_.maxFactor_); sox_effect_t* e = sox_create_effect(stretchEffect_); std::string _factor = std::to_string(factor); diff --git a/flashlight/pkg/speech/common/ProducerConsumerQueue.h b/flashlight/pkg/speech/common/ProducerConsumerQueue.h index f88ad2a..8a0c71c 100644 --- 
a/flashlight/pkg/speech/common/ProducerConsumerQueue.h +++ b/flashlight/pkg/speech/common/ProducerConsumerQueue.h @@ -69,14 +69,12 @@ namespace lib { lock, [this]() { return !isFull() || isAddingFinished_; }); - if(isAddingFinished_) { + if(isAddingFinished_) return; - } queue_.push(std::move(unit)); - if(!isFull()) { + if(!isFull()) producerCondition_.notify_one(); - } consumerCondition_.notify_one(); } @@ -91,15 +89,13 @@ namespace lib { consumerCondition_.wait( lock, [this]() { return !isEmpty() || isAddingFinished_; }); - if(isEmpty()) { + if(isEmpty()) return false; - } unit = std::move(queue_.front()); queue_.pop(); - if(!isEmpty()) { + if(!isEmpty()) consumerCondition_.notify_one(); - } producerCondition_.notify_one(); return true; diff --git a/flashlight/pkg/speech/criterion/AutoSegmentationCriterion.h b/flashlight/pkg/speech/criterion/AutoSegmentationCriterion.h index d91ede8..9344931 100644 --- a/flashlight/pkg/speech/criterion/AutoSegmentationCriterion.h +++ b/flashlight/pkg/speech/criterion/AutoSegmentationCriterion.h @@ -28,9 +28,8 @@ namespace pkg { scaleMode_(scalemode), fac_(ForceAlignmentCriterion(N, scalemode)), fcc_(FullConnectionCriterion(N, scalemode)) { - if(N_ <= 0) { + if(N_ <= 0) throw std::invalid_argument("ASG: N is zero or negative."); - } fl::Variable transition(transdiag * fl::identity(N_), true); params_ = {transition}; syncTransitions(); @@ -45,9 +44,8 @@ namespace pkg { std::vector forward( const std::vector& inputs ) override { - if(inputs.size() != 2) { + if(inputs.size() != 2) throw std::invalid_argument("Invalid inputs size"); - } return { fcc_.forward(inputs[0], inputs[1]) - fac_.forward(inputs[0], inputs[1])}; diff --git a/flashlight/pkg/speech/criterion/ConnectionistTemporalClassificationCriterion.cpp b/flashlight/pkg/speech/criterion/ConnectionistTemporalClassificationCriterion.cpp index 3838c78..1d5c79e 100644 --- a/flashlight/pkg/speech/criterion/ConnectionistTemporalClassificationCriterion.cpp +++ 
b/flashlight/pkg/speech/criterion/ConnectionistTemporalClassificationCriterion.cpp @@ -25,9 +25,8 @@ struct CTCContext { Tensor logSoftmax(const Tensor& input, const int dim) { Tensor maxvals = fl::amax(input, {dim}, /* keepDims = */ true); Shape tiledims(std::vector(input.ndim(), 1)); - if(dim > 3) { + if(dim > 3) throw std::invalid_argument("logSoftmax: Dimension must be less than 3"); - } tiledims[dim] = input.dim(dim); // Compute log softmax. // Subtracting then adding maxvals is for numerical stability. @@ -79,12 +78,11 @@ Tensor ConnectionistTemporalClassificationCriterion::viterbiPathWithTarget( const Tensor& inputSizes /* = Tensor() */, const Tensor& targetSizes /* = Tensor() */ ) { - if(input.ndim() != 3) { + if(input.ndim() != 3) throw std::invalid_argument( "ConnectionistTemporalClassificationCriterion::viterbiPathWithTarget: " "expected input of shape {N x T x B}" ); - } int N = input.dim(0); int T = input.dim(1); int B = input.dim(2); @@ -122,27 +120,23 @@ void ConnectionistTemporalClassificationCriterion::validate( const Variable& input, const Variable& target ) { - if(input.isEmpty()) { + if(input.isEmpty()) throw std::invalid_argument("CTC: Input cannot be empty"); - } - if(target.ndim() < 2) { + if(target.ndim() < 2) throw std::invalid_argument( "CTC: Incorrect dimensions for target. Expected {L, B}, got " + target.shape().toString() ); - } - if(input.ndim() < 3) { + if(input.ndim() < 3) throw std::invalid_argument( "CTC: Incorrect dimensions for input. 
Expected {N, T, B}, got " + input.shape().toString() ); - } - if(input.dim(2) != target.dim(1)) { + if(input.dim(2) != target.dim(1)) throw std::invalid_argument( "CTC: Batchsize mismatch for input and target with dims " + input.shape().toString() + " and " + target.shape().toString() + ", respectively" ); - } } } // namespace fl diff --git a/flashlight/pkg/speech/criterion/CriterionUtils.cpp b/flashlight/pkg/speech/criterion/CriterionUtils.cpp index 1f0cd65..63c2fee 100644 --- a/flashlight/pkg/speech/criterion/CriterionUtils.cpp +++ b/flashlight/pkg/speech/criterion/CriterionUtils.cpp @@ -18,11 +18,9 @@ namespace fl::pkg::speech { int countRepeats(const int* labels, int len) { int r = 0; - for(int i = 1; i < len; ++i) { - if(labels[i] == labels[i - 1]) { + for(int i = 1; i < len; ++i) + if(labels[i] == labels[i - 1]) ++r; - } - } return r; } @@ -37,17 +35,16 @@ CriterionScaleMode getCriterionScaleMode( const std::string& onorm, bool sqnorm ) { - if(onorm == "none") { + if(onorm == "none") return CriterionScaleMode::NONE; - } else if(onorm == "input") { + else if(onorm == "input") return sqnorm ? CriterionScaleMode::INPUT_SZ_SQRT : CriterionScaleMode::INPUT_SZ; - } else if(onorm == "target") { + else if(onorm == "target") return sqnorm ? CriterionScaleMode::TARGET_SZ_SQRT : CriterionScaleMode::TARGET_SZ; - } else { + else throw std::invalid_argument("invalid onorm option"); - } } Variable getLinearTarget(const Variable& targetVar, int T) { @@ -63,14 +60,12 @@ Variable getLinearTarget(const Variable& targetVar, int T) { auto pNewTarget = newTarget.data() + b * T; int targetSize = std::min(T, fl::pkg::speech::getTargetSize(pTarget, L)); - if(targetSize == 0) { + if(targetSize == 0) // hacky way to make ASG think L == 0. 
std::fill(pNewTarget, pNewTarget + T, -1); - } else { - for(int t = 0; t < T; ++t) { + else + for(int t = 0; t < T; ++t) pNewTarget[t] = pTarget[t * targetSize / T]; - } - } } return Variable(Tensor::fromVector({T, B}, newTarget), false); } @@ -80,11 +75,10 @@ fl::Variable applySeq2SeqMask( const Tensor& targetClasses, int padValue ) { - if(input.shape() != targetClasses.shape()) { + if(input.shape() != targetClasses.shape()) throw std::runtime_error( "applySeq2SeqMask: input and mask should have the same dimentions." ); - } Tensor output = input.tensor(); Tensor mask = targetClasses == padValue; output(mask) = 0.; diff --git a/flashlight/pkg/speech/criterion/CriterionUtils.h b/flashlight/pkg/speech/criterion/CriterionUtils.h index 60aafc4..a1baac9 100644 --- a/flashlight/pkg/speech/criterion/CriterionUtils.h +++ b/flashlight/pkg/speech/criterion/CriterionUtils.h @@ -24,29 +24,23 @@ namespace pkg { template inline T logSumExp(T logA, T logB) { - if(logA < logB) { + if(logA < logB) std::swap(logA, logB); - } - if(logB == -std::numeric_limits::infinity()) { + if(logB == -std::numeric_limits::infinity()) return logA; - } return logA + std::log1p(std::exp(logB - logA)); } template inline T logSumExp(T logA, T logB, T logC) { - if(logA < logB) { + if(logA < logB) std::swap(logA, logB); - } - if(logA < logC) { + if(logA < logC) std::swap(logA, logC); - } - if(logB < logC) { + if(logB < logC) std::swap(logB, logC); - } - if(logC == -std::numeric_limits::infinity()) { + if(logC == -std::numeric_limits::infinity()) return logSumExp(logA, logB); - } return logA + std::log1p(std::exp(logB - logA) + std::exp(logC - logA)); } diff --git a/flashlight/pkg/speech/criterion/ForceAlignmentCriterion.cpp b/flashlight/pkg/speech/criterion/ForceAlignmentCriterion.cpp index a86bfe4..3580240 100644 --- a/flashlight/pkg/speech/criterion/ForceAlignmentCriterion.cpp +++ b/flashlight/pkg/speech/criterion/ForceAlignmentCriterion.cpp @@ -14,11 +14,10 @@ 
ForceAlignmentCriterion::ForceAlignmentCriterion( fl::lib::seq::CriterionScaleMode scalemode ) : N_(N), scaleMode_(scalemode) { - if(N_ <= 0) { + if(N_ <= 0) throw std::invalid_argument( "FAC: Size of transition matrix is less than 0" ); - } auto transition = fl::constant(0.0, {N_, N_}); params_ = {transition}; } diff --git a/flashlight/pkg/speech/criterion/FullConnectionCriterion.cpp b/flashlight/pkg/speech/criterion/FullConnectionCriterion.cpp index faa945b..cbabb47 100644 --- a/flashlight/pkg/speech/criterion/FullConnectionCriterion.cpp +++ b/flashlight/pkg/speech/criterion/FullConnectionCriterion.cpp @@ -18,11 +18,10 @@ FullConnectionCriterion::FullConnectionCriterion( fl::lib::seq::CriterionScaleMode scalemode ) : N_(N), scaleMode_(scalemode) { - if(N_ <= 0) { + if(N_ <= 0) throw std::invalid_argument( "FCC: Size of transition matrix is less than 0." ); - } auto transition = constant(0.0, {N_, N_}); params_ = {transition}; } diff --git a/flashlight/pkg/speech/criterion/LinearSegmentationCriterion.h b/flashlight/pkg/speech/criterion/LinearSegmentationCriterion.h index fffeeb9..e9db28c 100644 --- a/flashlight/pkg/speech/criterion/LinearSegmentationCriterion.h +++ b/flashlight/pkg/speech/criterion/LinearSegmentationCriterion.h @@ -32,9 +32,8 @@ namespace pkg { std::vector forward( const std::vector& inputs ) override { - if(inputs.size() != 2) { + if(inputs.size() != 2) throw std::invalid_argument("Invalid inputs size"); - } const auto& input = inputs[0]; const auto& target = inputs[1]; return AutoSegmentationCriterion::forward( diff --git a/flashlight/pkg/speech/criterion/Seq2SeqCriterion.cpp b/flashlight/pkg/speech/criterion/Seq2SeqCriterion.cpp index f85546b..8eea4c5 100644 --- a/flashlight/pkg/speech/criterion/Seq2SeqCriterion.cpp +++ b/flashlight/pkg/speech/criterion/Seq2SeqCriterion.cpp @@ -20,9 +20,8 @@ namespace fl::pkg::speech { namespace detail { Seq2SeqState concatState(std::vector& stateVec) { - if(stateVec.empty()) { + if(stateVec.empty()) throw 
std::runtime_error("Empty stateVec"); - } int nAttnRound = stateVec[0].hidden.size(); Seq2SeqState newState(nAttnRound); @@ -34,22 +33,19 @@ namespace detail { std::vector> hiddenVec(nAttnRound); std::vector summaryVec; for(auto& state : stateVec) { - if(state.step != newState.step) { + if(state.step != newState.step) throw std::runtime_error("step unmatched"); - } else if(state.isValid != newState.isValid) { + else if(state.isValid != newState.isValid) throw std::runtime_error("isValid unmatched"); - } alphaVec.push_back(state.alpha); - for(int i = 0; i < nAttnRound; i++) { + for(int i = 0; i < nAttnRound; i++) hiddenVec[i].push_back(state.hidden[i]); - } summaryVec.push_back(state.summary); } newState.alpha = concatenate(alphaVec, 2); - for(int i = 0; i < nAttnRound; i++) { + for(int i = 0; i < nAttnRound; i++) newState.hidden[i] = concatenate(hiddenVec[i], 1); - } newState.summary = concatenate(summaryVec, 2); return newState; } @@ -64,10 +60,9 @@ namespace detail { state.alpha(fl::span, fl::span, fl::range(batchIdx, batchIdx + 1)); newState.summary = state.summary(fl::span, fl::span, fl::range(batchIdx, batchIdx + 1)); - for(int i = 0; i < nAttnRound; i++) { + for(int i = 0; i < nAttnRound; i++) newState.hidden[i] = state.hidden[i](fl::span, fl::range(batchIdx, batchIdx + 1)); - } return newState; } } // namespace detail @@ -105,7 +100,7 @@ Seq2SeqCriterion::Seq2SeqCriterion( add(std::make_shared(hiddenDim, nClass_)); // 2. RNN - for(int i = 0; i < nAttnRound_; i++) { + for(int i = 0; i < nAttnRound_; i++) add( std::make_shared( hiddenDim, @@ -116,7 +111,6 @@ Seq2SeqCriterion::Seq2SeqCriterion( dropOut ) ); - } // 3. Linear add(std::make_shared(hiddenDim, nClass_)); @@ -124,9 +118,8 @@ Seq2SeqCriterion::Seq2SeqCriterion( // backward compatibility. // 4. Attention - for(int i = 0; i < nAttnRound_; i++) { + for(int i = 0; i < nAttnRound_; i++) add(attentions[i]); - } // 5. 
Initial hidden state params_.push_back(fl::uniform(Shape{hiddenDim}, -1e-1, 1e-1)); @@ -151,11 +144,10 @@ std::unique_ptr Seq2SeqCriterion::clone() const { std::vector Seq2SeqCriterion::forward( const std::vector& inputs ) { - if(inputs.size() < 2 || (inputs.size() > 4)) { + if(inputs.size() < 2 || (inputs.size() > 4)) throw std::invalid_argument( "Invalid inputs size; Seq2Seq criterion takes input, target, inputSizes [optional]" ); - } const auto& input = inputs[0]; const auto& target = inputs[1]; const auto& inputSizes = @@ -164,12 +156,11 @@ std::vector Seq2SeqCriterion::forward( inputs.size() == 3 ? Tensor() : inputs[3].tensor(); // 1 x B Variable out, alpha; - if(useSequentialDecoder_) { + if(useSequentialDecoder_) std::tie(out, alpha) = decoder(input, target, inputSizes, targetSizes); - } else { + else std::tie(out, alpha) = vectorizedDecoder(input, target, inputSizes, targetSizes); - } out = logSoftmax(out, 0); // C x U x B @@ -197,12 +188,11 @@ std::pair Seq2SeqCriterion::vectorizedDecoder( const Tensor& inputSizes, const Tensor& targetSizes ) { - if(target.ndim() != 2) { + if(target.ndim() != 2) throw std::invalid_argument( "Seq2SeqCriterion::vectorizedDecoder: " "target expects to be shape {U, B}" ); - } int U = target.dim(0); int B = target.dim(1); int T = input.dim(1); @@ -213,11 +203,11 @@ std::pair Seq2SeqCriterion::vectorizedDecoder( // Slice off eos auto y = target(fl::range(0, U - 1), fl::span); if(train_) { - if(samplingStrategy_ == fl::pkg::speech::kModelSampling) { + if(samplingStrategy_ == fl::pkg::speech::kModelSampling) throw std::logic_error( "vectorizedDecoder does not support model sampling" ); - } else if(samplingStrategy_ == fl::pkg::speech::kRandSampling) { + else if(samplingStrategy_ == fl::pkg::speech::kRandSampling) { auto mask = Variable( (fl::rand(y.shape()) * 100 <= pctTeacherForcing_).astype(y.type()), false @@ -242,10 +232,9 @@ std::pair Seq2SeqCriterion::vectorizedDecoder( hy = fl::transpose(hy, {0, 2, 1}); // H x B x U -> H x 
U x B Variable windowWeight; - if(window_ && (!train_ || trainWithWindow_)) { + if(window_ && (!train_ || trainWithWindow_)) windowWeight = window_->computeVectorizedWindow(U, T, B, inputSizes, targetSizes); - } std::tie(alpha, summaries) = attention(i)->forward( hy, @@ -278,9 +267,9 @@ std::pair Seq2SeqCriterion::decoder( std::tie(ox, state) = decodeStep(input, y, state, inputSizes, targetSizes, U); - if(!train_) { + if(!train_) y = target(fl::range(u, u + 1), fl::span); - } else if(samplingStrategy_ == fl::pkg::speech::kGumbelSampling) { + else if(samplingStrategy_ == fl::pkg::speech::kGumbelSampling) { double eps = 1e-7; auto gb = -log(-log((1 - 2 * eps) * fl::rand(ox.shape()) + eps)); ox = logSoftmax((ox + Variable(gb, false)) / gumbelTemperature_, 0); @@ -288,20 +277,19 @@ std::pair Seq2SeqCriterion::decoder( } else if( fl::all(fl::rand({1}) * 100 <= fl::full({1}, pctTeacherForcing_)) .asScalar() - ) { + ) y = target(fl::range(u, u + 1), fl::span); - } else if(samplingStrategy_ == fl::pkg::speech::kModelSampling) { + else if(samplingStrategy_ == fl::pkg::speech::kModelSampling) { Tensor maxIdx, maxValues; fl::max(maxValues, maxIdx, ox.tensor(), 0); y = Variable(maxIdx, false); - } else if(samplingStrategy_ == fl::pkg::speech::kRandSampling) { + } else if(samplingStrategy_ == fl::pkg::speech::kRandSampling) y = Variable( (fl::rand({1, target.dim(1)}) * (nClass_ - 1)).astype(fl::dtype::s32), false ); - } else { + else throw std::invalid_argument("Invalid sampling strategy"); - } outvec.push_back(ox); alphaVec.push_back(state.alpha); @@ -346,23 +334,19 @@ std::pair Seq2SeqCriterion::viterbiPathBase( ); fl::max(maxValues, maxIdx, ox.tensor(), 0); pred = maxIdx.asScalar(); - if(saveAttn) { + if(saveAttn) alphaVec.push_back(state.alpha); - } - if(pred == eos_) { + if(pred == eos_) break; - } y = constant(pred, {1}, fl::dtype::s32, false); maxPath.push_back(pred); } - if(saveAttn) { + if(saveAttn) alpha = concatenate(alphaVec, 0); - } - if(wasTrain) { + if(wasTrain) 
train(); - } Tensor vPath = maxPath.empty() ? Tensor() : Tensor::fromVector(maxPath); return std::make_pair(vPath, alpha); } @@ -405,9 +389,8 @@ std::vector Seq2SeqCriterion::beamSearch( std::vector prevScoreVec; for(auto& hypo : beam) { Variable y; - if(!hypo.path.empty()) { + if(!hypo.path.empty()) y = constant(hypo.path.back(), {1}, fl::dtype::s32, false); - } prevYVec.push_back(y); prevStateVec.push_back(hypo.state); prevScoreVec.push_back(hypo.score); @@ -468,16 +451,14 @@ std::vector Seq2SeqCriterion::beamSearch( path_, detail::selectState(state, hypIdx) ); - } else if(clsIdx != eos_) { + } else if(clsIdx != eos_) newBeam.emplace_back( scoreVec[indices[j]], path_, detail::selectState(state, hypIdx) ); - } - if(newBeam.size() >= beamSize) { + if(newBeam.size() >= beamSize) break; - } } beam.resize(newBeam.size()); beam = std::move(newBeam); @@ -494,15 +475,13 @@ std::vector Seq2SeqCriterion::beamSearch( // if lowest score in complete is better than best future hypo // then its not possible for any future hypothesis to replace existing // hypothesises in complete. - if(complete.back().score > beam[0].score) { + if(complete.back().score > beam[0].score) break; - } } } - if(wasTrain) { + if(wasTrain) train(); - } return complete.empty() ? 
beam : complete; } @@ -515,25 +494,22 @@ std::pair Seq2SeqCriterion::decodeStep( const Tensor& targetSizes, const int maxDecoderSteps ) const { - if(xEncoded.ndim() != 3) { + if(xEncoded.ndim() != 3) throw std::invalid_argument( "Seq2SeqCriterion::decodeStep: " "expected xEncoded to have at least three dimensions" ); - } Variable hy; - if(y.isEmpty()) { + if(y.isEmpty()) hy = tile(startEmbedding(), {1, 1, static_cast(xEncoded.dim(2))}); - } else if(train_ && samplingStrategy_ == fl::pkg::speech::kGumbelSampling) { + else if(train_ && samplingStrategy_ == fl::pkg::speech::kGumbelSampling) hy = linear(y, embedding()->param(0)); - } else { + else hy = embedding()->forward(y); - } - if(inputFeeding_ && !y.isEmpty()) { + if(inputFeeding_ && !y.isEmpty()) hy = hy + moddims(inState.summary, hy.shape()); - } hy = moddims(hy, {hy.dim(0), -1}); // H x B Seq2SeqState outState(nAttnRound_); @@ -552,7 +528,7 @@ std::pair Seq2SeqCriterion::decodeStep( // size) int batchsize = y.isEmpty() ? xEncoded.dim(2) : (y.ndim() < 2 ? 
1 : y.dim(1)); - if(window_ && (!train_ || trainWithWindow_)) { + if(window_ && (!train_ || trainWithWindow_)) // TODO fix for softpretrain where target size is used // for now force to xEncoded.dim(1) windowWeight = window_->computeWindow( @@ -564,7 +540,6 @@ std::pair Seq2SeqCriterion::decodeStep( inputSizes, targetSizes ); - } std::tie(outState.alpha, summaries) = attention(i)->forward( hy, xEncoded, @@ -593,13 +568,12 @@ std::pair>, std::vector> Seq2Seq // Batch Ys for(int i = 0; i < batchSize; i++) { - if(ys[i].isEmpty()) { + if(ys[i].isEmpty()) ys[i] = startEmbedding(); - } else { + else { ys[i] = embedding()->forward(ys[i]); - if(inputFeeding_) { + if(inputFeeding_) ys[i] = ys[i] + moddims(inStates[i]->summary, ys[i].shape()); - } } ys[i] = moddims(ys[i], {ys[i].dim(0), -1}); } @@ -614,29 +588,26 @@ std::pair>, std::vector> Seq2Seq for(int n = 0; n < nAttnRound_; n++) { /* (1) RNN forward */ - if(inStates[0]->hidden[n].isEmpty()) { + if(inStates[0]->hidden[n].isEmpty()) std::tie(yBatched, outStateBatched) = decodeRNN(n)->forward(yBatched, Variable()); - } else { - for(int i = 0; i < batchSize; i++) { + else { + for(int i = 0; i < batchSize; i++) statesVector[i] = inStates[i]->hidden[n]; - } Variable inStateHiddenBatched = concatenate(statesVector, 1).asContiguous(); std::tie(yBatched, outStateBatched) = decodeRNN(n)->forward(yBatched, inStateHiddenBatched); } - for(int i = 0; i < batchSize; i++) { + for(int i = 0; i < batchSize; i++) outstates[i]->hidden[n] = outStateBatched(fl::span, fl::range(i, i + 1)); - } /* (2) Attention forward */ - if(window_ && (!train_ || trainWithWindow_)) { + if(window_ && (!train_ || trainWithWindow_)) throw std::runtime_error( "Batched decoding does not support models with window" ); - } Variable summaries, alphaBatched; // NB: @@ -665,11 +636,10 @@ std::pair>, std::vector> Seq2Seq auto outBatched = linearOut()->forward(yBatched); outBatched = logSoftmax(outBatched / smoothingTemperature, 0); std::vector> out(batchSize); - 
for(int i = 0; i < batchSize; i++) { + for(int i = 0; i < batchSize; i++) out[i] = outBatched(fl::span, fl::range(i, i + 1)) .tensor() .toHostVector(); - } return std::make_pair(out, outstates); } @@ -680,20 +650,19 @@ void Seq2SeqCriterion::setUseSequentialDecoder() { (pctTeacherForcing_ < 100 && samplingStrategy_ == fl::pkg::speech::kModelSampling) || samplingStrategy_ == fl::pkg::speech::kGumbelSampling || inputFeeding_ - ) { + ) useSequentialDecoder_ = true; - } else if( + else if( std::dynamic_pointer_cast(attention(0)) || std::dynamic_pointer_cast(attention(0)) || std::dynamic_pointer_cast(attention(0)) - ) { + ) useSequentialDecoder_ = true; - } else if( + else if( window_ && trainWithWindow_ && std::dynamic_pointer_cast(window_) - ) { + ) useSequentialDecoder_ = true; - } } std::string Seq2SeqCriterion::prettyString() const { @@ -725,12 +694,11 @@ EmittingModelUpdateFunc buildSeq2SeqRnnUpdateFunction( const std::vector& /* prevHypBeamIdxs */, const std::vector& rawPrevStates, int& t) { - if(t == 0) { + if(t == 0) buf->input = fl::Variable( Tensor::fromBuffer({N, T}, emissions, MemoryLocation::Host), false ); - } int batchSize = rawY.size(); buf->prevStates.resize(0); buf->ys.resize(0); @@ -740,11 +708,10 @@ EmittingModelUpdateFunc buildSeq2SeqRnnUpdateFunction( Seq2SeqState* prevState = static_cast(rawPrevStates[i].get()); fl::Variable y; - if(t > 0) { + if(t > 0) y = fl::constant(rawY[i], {1}, fl::dtype::s32, false); - } else { + else prevState = &buf->dummyState; - } buf->ys.push_back(y); buf->prevStates.push_back(prevState); } @@ -764,11 +731,10 @@ EmittingModelUpdateFunc buildSeq2SeqRnnUpdateFunction( // Cast back to void* std::vector out; for(auto& os : outStates) { - if(os->isValid) { + if(os->isValid) out.push_back(os); - } else { + else out.push_back(nullptr); - } } return std::make_pair(amScores, out); }; diff --git a/flashlight/pkg/speech/criterion/TransformerCriterion.cpp b/flashlight/pkg/speech/criterion/TransformerCriterion.cpp index 
b696a9c..c998681 100644 --- a/flashlight/pkg/speech/criterion/TransformerCriterion.cpp +++ b/flashlight/pkg/speech/criterion/TransformerCriterion.cpp @@ -40,7 +40,7 @@ TransformerCriterion::TransformerCriterion( labelSmooth_(labelSmooth), pctTeacherForcing_(pctTeacherForcing) { add(std::make_shared(hiddenDim, nClass)); - for(size_t i = 0; i < nLayer_; i++) { + for(size_t i = 0; i < nLayer_; i++) add( std::make_shared( hiddenDim, @@ -53,7 +53,6 @@ TransformerCriterion::TransformerCriterion( true ) ); - } add(std::make_shared(hiddenDim, nClass)); add(attention); params_.push_back(fl::uniform(Shape{hiddenDim}, -1e-1, 1e-1)); @@ -68,12 +67,11 @@ std::unique_ptr TransformerCriterion::clone() const { std::vector TransformerCriterion::forward( const std::vector& inputs ) { - if(inputs.size() < 2 || inputs.size() > 4) { + if(inputs.size() < 2 || inputs.size() > 4) throw std::invalid_argument( "Invalid inputs size; Transformer criterion takes input," " target, inputSizes [optional], targetSizes [optional]" ); - } const Variable& input = inputs[0]; const Variable& target = inputs[1]; const auto& inputSizes = @@ -142,16 +140,14 @@ std::pair TransformerCriterion::vectorizedDecoder( Variable alpha, summaries; Variable padMask; // no mask, decoder is not looking into future - for(int i = 0; i < nLayer_; i++) { + for(int i = 0; i < nLayer_; i++) hy = layer(i)->forward(std::vector({hy, padMask})).front(); - } if(!input.isEmpty()) { Variable windowWeight; - if(window_ && (!train_ || trainWithWindow_)) { + if(window_ && (!train_ || trainWithWindow_)) windowWeight = window_->computeVectorizedWindow(U, T, B, inputSizes, targetSizes); - } std::tie(alpha, summaries) = attention()->forward( hy, @@ -198,17 +194,15 @@ std::pair TransformerCriterion::viterbiPathBase( maxIdx.host(&pred); // TODO: saveAttn - if(pred == eos_) { + if(pred == eos_) break; - } y = constant(pred, {1}, fl::dtype::s32, false); path.push_back(pred); } // TODO: saveAttn - if(wasTrain) { + if(wasTrain) train(); - } 
auto vPath = path.empty() ? Tensor() : Tensor::fromVector(path); return std::make_pair(vPath, alpha); @@ -221,11 +215,10 @@ std::pair TransformerCriterion::decodeStep( const Tensor& inputSizes ) const { Variable hy; - if(y.isEmpty()) { + if(y.isEmpty()) hy = tile(startEmbedding(), {1, 1, xEncoded.dim(2)}); - } else { + else hy = embedding()->forward(y); - } // TODO: inputFeeding @@ -248,7 +241,7 @@ std::pair TransformerCriterion::decodeStep( } Variable windowWeight, alpha, summary; - if(window_ && (!train_ || trainWithWindow_)) { + if(window_ && (!train_ || trainWithWindow_)) // TODO fix for softpretrain where target size is used // for now force to xEncoded.dim(1) windowWeight = window_->computeWindow( @@ -260,7 +253,6 @@ std::pair TransformerCriterion::decodeStep( inputSizes, Tensor() ); - } std::tie(alpha, summary) = attention()->forward( hy, @@ -287,11 +279,11 @@ std::pair>, std::vector> Transforme int B = ys.size(); for(int i = 0; i < B; i++) { - if(ys[i].isEmpty()) { + if(ys[i].isEmpty()) ys[i] = startEmbedding(); - } else { + else ys[i] = embedding()->forward(ys[i]); - } // TODO: input feeding + // TODO: input feeding ys[i] = moddims(ys[i], {ys[i].dim(0), 1, -1}); } Variable yBatched = concatenate(ys, 2); // D x 1 x B @@ -305,21 +297,18 @@ std::pair>, std::vector> Transforme Variable outStateBatched; for(int i = 0; i < nLayer_; i++) { if(inStates[0]->step == 0) { - for(int j = 0; j < B; j++) { + for(int j = 0; j < B; j++) outstates[j]->hidden.push_back(yBatched(fl::span, fl::span, j)); - } yBatched = layer(i)->forward(std::vector({yBatched})).front(); } else { std::vector statesVector(B); - for(int j = 0; j < B; j++) { + for(int j = 0; j < B; j++) statesVector[j] = inStates[j]->hidden[i]; - } Variable inStateHiddenBatched = concatenate(statesVector, 2); auto tmp = std::vector({inStateHiddenBatched, yBatched}); auto tmp2 = concatenate(tmp, 1); - for(int j = 0; j < B; j++) { + for(int j = 0; j < B; j++) outstates[j]->hidden.push_back(tmp2(fl::span, fl::span, 
j)); - } yBatched = layer(i)->forward(tmp).front(); } } @@ -334,9 +323,8 @@ std::pair>, std::vector> Transforme auto outBatched = linearOut()->forward(yBatched); outBatched = logSoftmax(outBatched / smoothingTemperature, 0); std::vector> out(B); - for(int i = 0; i < B; i++) { + for(int i = 0; i < B; i++) out[i] = outBatched(fl::span, i).tensor().toHostVector(); - } return std::make_pair(out, outstates); } @@ -362,12 +350,11 @@ EmittingModelUpdateFunc buildSeq2SeqTransformerUpdateFunction( const std::vector& /* prevHypBeamIdxs */, const std::vector& rawPrevStates, int& t) { - if(t == 0) { + if(t == 0) buf->input = fl::Variable( Tensor::fromBuffer({N, T}, emissions, MemoryLocation::Host), false ); - } int B = rawY.size(); std::vector out; std::vector> amScoresAll; @@ -385,18 +372,16 @@ EmittingModelUpdateFunc buildSeq2SeqTransformerUpdateFunction( buf->ys.resize(0); int end = start + step; - if(end > B) { + if(end > B) end = B; - } for(int i = start; i < end; i++) { TS2SState* prevState = static_cast(rawPrevStates[i].get()); fl::Variable y; - if(t > 0) { + if(t > 0) y = fl::constant(rawY[i], {1}, fl::dtype::s32, false); - } else { + else prevState = &buf->dummyState; - } buf->ys.push_back(y); buf->prevStates.push_back(prevState); } @@ -409,12 +394,10 @@ EmittingModelUpdateFunc buildSeq2SeqTransformerUpdateFunction( buf->attentionThreshold, buf->smoothingTemperature ); - for(auto& os : outStates) { + for(auto& os : outStates) out.push_back(os); - } - for(auto& s : amScores) { + for(auto& s : amScores) amScoresAll.push_back(s); - } // clean the previous state which is not needed anymore // to prevent from OOM for(int i = start; i < end; i++) { @@ -425,9 +408,8 @@ EmittingModelUpdateFunc buildSeq2SeqTransformerUpdateFunction( && (lastIndexOfStatePtr.find(prevState) == lastIndexOfStatePtr.end() || lastIndexOfStatePtr.find(prevState)->second == i) - ) { + ) prevState->hidden.clear(); - } } start += step; } diff --git 
a/flashlight/pkg/speech/criterion/attention/AttentionBase.h b/flashlight/pkg/speech/criterion/attention/AttentionBase.h index 4d76043..f46c783 100644 --- a/flashlight/pkg/speech/criterion/attention/AttentionBase.h +++ b/flashlight/pkg/speech/criterion/attention/AttentionBase.h @@ -22,11 +22,10 @@ namespace pkg { AttentionBase() {} std::vector forward(const std::vector& inputs) override { - if(inputs.size() != 3 && inputs.size() != 4 && inputs.size() != 5) { + if(inputs.size() != 3 && inputs.size() != 4 && inputs.size() != 5) throw std::invalid_argument( "Attention encoder-decoder: Invalid inputs size, should be 3, 4, or 5 arguments" ); - } auto logAttnWeight = inputs.size() == 4 ? inputs[3] : Variable(); auto xEncodedSizes = inputs.size() == 5 ? inputs[4] : Variable(); diff --git a/flashlight/pkg/speech/criterion/attention/ContentAttention.cpp b/flashlight/pkg/speech/criterion/attention/ContentAttention.cpp index 6cf50c8..d72fd11 100644 --- a/flashlight/pkg/speech/criterion/attention/ContentAttention.cpp +++ b/flashlight/pkg/speech/criterion/attention/ContentAttention.cpp @@ -28,27 +28,24 @@ std::pair ContentAttention::forwardBase( const Variable& xEncodedSizes ) { int dim = xEncoded.dim(0); - if(dim != (1 + ((keyValue_) ? 1 : 0)) * state.dim(0)) { + if(dim != (1 + ((keyValue_) ? 1 : 0)) * state.dim(0)) throw std::invalid_argument( "ContentAttention: Invalid dimension for content attention" ); - } auto keys = keyValue_ ? xEncoded(fl::range(0, dim / 2)) : xEncoded; auto values = keyValue_ ? 
xEncoded(fl::range(dim / 2, dim)) : xEncoded; // [targetlen, seqlen, batchsize] auto innerProd = matmulTN(state, keys) / std::sqrt(state.dim(0)); if(!logAttnWeight.isEmpty()) { - if(logAttnWeight.shape() != innerProd.shape()) { + if(logAttnWeight.shape() != innerProd.shape()) throw std::invalid_argument( "ContentAttention: logAttnWeight has wong dimentions" ); - } innerProd = innerProd + logAttnWeight; } Tensor padMask; - if(!xEncodedSizes.isEmpty()) { + if(!xEncodedSizes.isEmpty()) innerProd = maskAttention(innerProd, xEncodedSizes); - } // [targetlen, seqlen, batchsize] auto attention = softmax(innerProd, 1); // [hiddendim, targetlen, batchsize] @@ -96,17 +93,15 @@ std::pair NeuralContentAttention::forwardBase( // [targetlen, seqlen, batchsize] auto nnOut = moddims(module(0)->forward({hidden}).front(), {U, T, B}); if(!logAttnWeight.isEmpty()) { - if(logAttnWeight.shape() != nnOut.shape()) { + if(logAttnWeight.shape() != nnOut.shape()) throw std::invalid_argument( "ContentAttention: logAttnWeight has wong dimentions" ); - } nnOut = nnOut + logAttnWeight; } - if(!xEncodedSizes.isEmpty()) { + if(!xEncodedSizes.isEmpty()) nnOut = maskAttention(nnOut, xEncodedSizes); - } // [targetlen, seqlen, batchsize] auto attention = softmax(nnOut, 1); // [hiddendim, targetlen, batchsize] diff --git a/flashlight/pkg/speech/criterion/attention/LocationAttention.cpp b/flashlight/pkg/speech/criterion/attention/LocationAttention.cpp index 5756d1d..8838585 100644 --- a/flashlight/pkg/speech/criterion/attention/LocationAttention.cpp +++ b/flashlight/pkg/speech/criterion/attention/LocationAttention.cpp @@ -32,11 +32,10 @@ std::pair SimpleLocationAttention::forwardBase( const Variable& xEncodedSizes ) { int U = state.dim(1); - if(U > 1) { + if(U > 1) throw std::invalid_argument( prettyString() + " only works on single step forward" ); - } int T = xEncoded.dim(1); int B = xEncoded.dim(2); @@ -53,16 +52,14 @@ std::pair SimpleLocationAttention::forwardBase( } if(!logAttnWeight.isEmpty()) { - 
if(logAttnWeight.shape() != innerProd.shape()) { + if(logAttnWeight.shape() != innerProd.shape()) throw std::invalid_argument( "SimpleLocationAttention: logAttnWeight has wong dimentions" ); - } innerProd = innerProd + logAttnWeight; } - if(!xEncodedSizes.isEmpty()) { + if(!xEncodedSizes.isEmpty()) innerProd = maskAttention(innerProd, xEncodedSizes); - } // [1, seqlen, batchsize] auto attention = softmax(innerProd, 1); // [hiddendim, 1, batchsize] @@ -96,11 +93,10 @@ std::pair LocationAttention::forwardBase( const Variable& xEncodedSizes ) { int U = state.dim(1); - if(U > 1) { + if(U > 1) throw std::invalid_argument( prettyString() + " only works on single step forward" ); - } int H = xEncoded.dim(0); int T = xEncoded.dim(1); @@ -117,16 +113,14 @@ std::pair LocationAttention::forwardBase( } if(!logAttnWeight.isEmpty()) { - if(logAttnWeight.shape() != innerProd.shape()) { + if(logAttnWeight.shape() != innerProd.shape()) throw std::invalid_argument( "LocationAttention: logAttnWeight has wong dimentions" ); - } innerProd = innerProd + logAttnWeight; } - if(!xEncodedSizes.isEmpty()) { + if(!xEncodedSizes.isEmpty()) innerProd = maskAttention(innerProd, xEncodedSizes); - } // [1, seqlen, batchsize] auto attention = softmax(innerProd, 1); // [hiddendim, 1, batchsize] @@ -169,11 +163,10 @@ std::pair NeuralLocationAttention::forwardBase( const Variable& xEncodedSizes ) { int U = state.dim(1); - if(U > 1) { + if(U > 1) throw std::invalid_argument( prettyString() + " only works on single step forward" ); - } int T = xEncoded.dim(1); int B = xEncoded.dim(2); @@ -194,17 +187,15 @@ std::pair NeuralLocationAttention::forwardBase( auto nnOut = module(4)->forward({hidden}).front(); if(!logAttnWeight.isEmpty()) { - if(logAttnWeight.shape() != nnOut.shape()) { + if(logAttnWeight.shape() != nnOut.shape()) throw std::invalid_argument( "NeuralLocationAttention: logAttnWeight has wong dimentions" ); - } nnOut = nnOut + logAttnWeight; } - if(!xEncodedSizes.isEmpty()) { + 
if(!xEncodedSizes.isEmpty()) nnOut = maskAttention(nnOut, xEncodedSizes); - } // [1, seqlen, batchsize] auto attention = softmax(nnOut, 1); // [hiddendim, 1, batchsize] diff --git a/flashlight/pkg/speech/criterion/attention/MultiHeadAttention.cpp b/flashlight/pkg/speech/criterion/attention/MultiHeadAttention.cpp index 32354c6..e75ee77 100644 --- a/flashlight/pkg/speech/criterion/attention/MultiHeadAttention.cpp +++ b/flashlight/pkg/speech/criterion/attention/MultiHeadAttention.cpp @@ -23,9 +23,8 @@ MultiHeadContentAttention::MultiHeadContentAttention( ) : numHeads_(numHeads), keyValue_(keyValue), splitInput_(splitInput) { - if(splitInput && dim % numHeads != 0) { + if(splitInput && dim % numHeads != 0) throw std::invalid_argument("Invalid dimensions"); - } if(!splitInput) { add(Linear(dim, dim)); // query @@ -48,21 +47,19 @@ std::pair MultiHeadContentAttention::forwardBase( const Variable& logAttnWeight, const Variable& xEncodedSizes ) { - if(state.ndim() != 3) { + if(state.ndim() != 3) throw std::invalid_argument( "MultiHeadContentAttention::forwardBase: " "state input must be of shape {H, U, B}" ); - } int hEncode = xEncoded.dim(0); int T = xEncoded.dim(1); int hState = state.dim(0); int U = state.dim(1); int B = state.dim(2); auto hiddenDim = hState / numHeads_; - if(hEncode != (1 + keyValue_) * hState) { + if(hEncode != (1 + keyValue_) * hState) throw std::invalid_argument("Invalid input encoder dimension"); - } auto xEncodedKey = keyValue_ ? 
xEncoded(fl::arange(0, hEncode / 2), fl::span, fl::span) @@ -88,20 +85,18 @@ std::pair MultiHeadContentAttention::forwardBase( if(!logAttnWeight.isEmpty()) { auto tiledLogAttnWeight = tile(logAttnWeight, {1, 1, numHeads_}); - if(tiledLogAttnWeight.shape() != innerProd.shape()) { + if(tiledLogAttnWeight.shape() != innerProd.shape()) throw std::invalid_argument( "MultiHeadContentAttention: logAttnWeight has wong dimentions" ); - } innerProd = innerProd + tiledLogAttnWeight; } - if(!xEncodedSizes.isEmpty()) { + if(!xEncodedSizes.isEmpty()) innerProd = maskAttention( innerProd, moddims(tile(xEncodedSizes, {numHeads_, 1}), {1, B * numHeads_}) ); - } // [U, T, B * numHeads_] auto attention = softmax(innerProd, 1); diff --git a/flashlight/pkg/speech/criterion/attention/SoftPretrainWindow.cpp b/flashlight/pkg/speech/criterion/attention/SoftPretrainWindow.cpp index 719da8f..c0f82ad 100644 --- a/flashlight/pkg/speech/criterion/attention/SoftPretrainWindow.cpp +++ b/flashlight/pkg/speech/criterion/attention/SoftPretrainWindow.cpp @@ -25,13 +25,12 @@ Variable SoftPretrainWindow::compute( ) const { int decoderStepsDim = decoderSteps.dim(0); auto ts = fl::arange({decoderStepsDim, inputSteps, batchSize}, 1); - if(inputSizes.isEmpty() && targetSizes.isEmpty()) { + if(inputSizes.isEmpty() && targetSizes.isEmpty()) return Variable( -fl::power(ts - inputSteps / targetLen * decoderSteps, 2) / (2 * std_ * std_), false ); - } Tensor inputNotPaddedSize = computeInputNotPaddedSize( inputSizes, diff --git a/flashlight/pkg/speech/criterion/attention/WindowBase.cpp b/flashlight/pkg/speech/criterion/attention/WindowBase.cpp index 81b0958..ef2cf84 100644 --- a/flashlight/pkg/speech/criterion/attention/WindowBase.cpp +++ b/flashlight/pkg/speech/criterion/attention/WindowBase.cpp @@ -17,29 +17,26 @@ Tensor WindowBase::computeInputNotPaddedSize( bool doTile ) const { if(inputSizes.isEmpty()) { - if(doTile) { + if(doTile) return fl::full( {decoderStepsDim, inputSteps, batchSize}, inputSteps, 
fl::dtype::f32 ); - } else { + else return fl::full({1, 1, batchSize}, inputSteps, fl::dtype::f32); - } } - if(inputSizes.elements() != batchSize) { + if(inputSizes.elements() != batchSize) throw std::runtime_error( "Attention Window: wrong size of the input sizes vector, doesn't match with batchsize" ); - } Tensor inputNotPaddedSize = fl::ceil( inputSizes / fl::amax(inputSizes).asScalar() * inputSteps ); inputNotPaddedSize = fl::reshape(inputNotPaddedSize, {1, 1, batchSize}); - if(doTile) { + if(doTile) inputNotPaddedSize = fl::tile(inputNotPaddedSize, {decoderStepsDim, inputSteps, 1}); - } return inputNotPaddedSize; } @@ -50,18 +47,16 @@ Tensor WindowBase::computeTargetNotPaddedSize( int batchSize, int decoderStepsDim ) const { - if(targetSizes.isEmpty()) { + if(targetSizes.isEmpty()) return fl::full( {decoderStepsDim, inputSteps, batchSize}, targetLen, fl::dtype::f32 ); - } - if(targetSizes.elements() != batchSize) { + if(targetSizes.elements() != batchSize) throw std::runtime_error( "Window Attention: wrong size of the target sizes vector, doesn't match with batchsize" ); - } Tensor targetNotPaddedSize = fl::reshape( fl::ceil( targetSizes / fl::amax(targetSizes).asScalar() * targetLen diff --git a/flashlight/pkg/speech/criterion/backend/cpu/ConnectionistTemporalClassificationCriterion.cpp b/flashlight/pkg/speech/criterion/backend/cpu/ConnectionistTemporalClassificationCriterion.cpp index 3826908..e0e4012 100644 --- a/flashlight/pkg/speech/criterion/backend/cpu/ConnectionistTemporalClassificationCriterion.cpp +++ b/flashlight/pkg/speech/criterion/backend/cpu/ConnectionistTemporalClassificationCriterion.cpp @@ -22,9 +22,8 @@ namespace pkg { std::vector ConnectionistTemporalClassificationCriterion::forward( const std::vector& inputs ) { - if(inputs.size() != 2) { + if(inputs.size() != 2) throw std::invalid_argument("Invalid inputs size"); - } const auto& input = inputs[0]; const auto& target = inputs[1]; validate(input, target); @@ -93,46 +92,42 @@ namespace pkg { 
// base case alphas[0] = (start == 0) ? inputVec[N - 1] : NEG_INFINITY_FLT; - if(S != 1) { + if(S != 1) alphas[1] = inputVec[targetVec[0]]; - } for(int64_t t = 1; t < T; ++t) { // At each time frame t, only few states can be reached depending // on the labels, their ordering and the current time frame. if(T - t <= L + R) { - if(start & 1 && targetVec[start / 2] != targetVec[start / 2 + 1]) { + if(start & 1 && targetVec[start / 2] != targetVec[start / 2 + 1]) ++start; - } ++start; } if(t <= L + R) { if( end % 2 == 0 && end < 2 * L && (targetVec[end / 2 - 1] != targetVec[end / 2]) - ) { + ) ++end; - } ++end; } // Use dynamic programming to recursively compute alphas for(int64_t s = start; s < end; ++s) { int64_t ts = t * S + s; int64_t curLabel = t * N + ((s & 1) ? targetVec[s / 2] : N - 1); - if(s == 0) { + if(s == 0) alphas[ts] = alphas[ts - S]; - } else if( + else if( (s % 2 == 0) || s == 1 || targetVec[s / 2] == targetVec[s / 2 - 1] - ) { + ) alphas[ts] = fl::pkg::speech::logSumExp(alphas[ts - S], alphas[ts - S - 1]); - } else { + else alphas[ts] = fl::pkg::speech::logSumExp( alphas[ts - S], alphas[ts - S - 1], alphas[ts - S - 2] ); - } alphas[ts] += inputVec[curLabel]; } } @@ -180,9 +175,9 @@ namespace pkg { std::vector dAlphas(T * S, 0.0); // Compute dAlphas for the last timeframe - if(S == 1) { + if(S == 1) dAlphas[T * S - 1] = -1.0; - } else { + else fl::pkg::speech::dLogSumExp( alphas[T * S - 2], alphas[T * S - 1], @@ -190,7 +185,6 @@ namespace pkg { dAlphas[T * S - 1], -1.0 ); - } float gradScale = batchOutGrad[b] * batchScales[b]; for(int64_t t = T - 1; t >= 0; --t) { @@ -200,18 +194,16 @@ namespace pkg { if( start & 1 && start > 1 && targetVec[start / 2] != targetVec[start / 2 - 1] - ) { + ) --start; - } --start; } if(t < L + R) { if( end % 2 == 0 && (targetVec[end / 2 - 1] != targetVec[end / 2 - 2]) - ) { + ) --end; - } --end; } // Compute grad and dAlphas for (t-1)th frame using chain rule @@ -219,15 +211,14 @@ namespace pkg { int64_t ts = t * S + s; 
int64_t curLabel = t * N + ((s & 1) ? targetVec[s / 2] : N - 1); grad[curLabel] += dAlphas[ts] * gradScale; - if(t == 0) { + if(t == 0) continue; - } - if(s == 0) { + if(s == 0) dAlphas[ts - S] += dAlphas[ts]; - } else if( + else if( (s % 2 == 0) || s == 1 || targetVec[s / 2] == targetVec[s / 2 - 1] - ) { + ) fl::pkg::speech::dLogSumExp( alphas[ts - S], alphas[ts - S - 1], @@ -235,7 +226,7 @@ namespace pkg { dAlphas[ts - S - 1], dAlphas[ts] ); - } else { + else fl::pkg::speech::dLogSumExp( alphas[ts - S], alphas[ts - S - 1], @@ -245,7 +236,6 @@ namespace pkg { dAlphas[ts - S - 2], dAlphas[ts] ); - } } } } diff --git a/flashlight/pkg/speech/criterion/backend/cpu/CriterionUtils.cpp b/flashlight/pkg/speech/criterion/backend/cpu/CriterionUtils.cpp index b0a0784..a934292 100644 --- a/flashlight/pkg/speech/criterion/backend/cpu/CriterionUtils.cpp +++ b/flashlight/pkg/speech/criterion/backend/cpu/CriterionUtils.cpp @@ -22,13 +22,12 @@ namespace pkg { auto T = input.dim(1); auto N = input.dim(0); - if(N != trans.dim(0) || N != trans.dim(1)) { + if(N != trans.dim(0) || N != trans.dim(1)) throw std::invalid_argument("viterbiPath: mismatched dims"); - } else if(input.type() != fl::dtype::f32) { + else if(input.type() != fl::dtype::f32) throw std::invalid_argument("viterbiPath: input must be float32"); - } else if(trans.type() != fl::dtype::f32) { + else if(trans.type() != fl::dtype::f32) throw std::invalid_argument("viterbiPath: trans must be float32"); - } auto inputVec = input.toHostVector(); auto transVec = trans.toHostVector(); diff --git a/flashlight/pkg/speech/criterion/backend/cpu/ForceAlignmentCriterion.cpp b/flashlight/pkg/speech/criterion/backend/cpu/ForceAlignmentCriterion.cpp index ea23972..f0a7365 100644 --- a/flashlight/pkg/speech/criterion/backend/cpu/ForceAlignmentCriterion.cpp +++ b/flashlight/pkg/speech/criterion/backend/cpu/ForceAlignmentCriterion.cpp @@ -35,9 +35,8 @@ namespace pkg { int L, const std::shared_ptr& ctx ) { - if(gradVar.type() != 
fl::dtype::f32) { + if(gradVar.type() != fl::dtype::f32) throw std::invalid_argument("FAC: grad must be float32"); - } auto gradVec = gradVar.tensor().toHostVector(); std::vector inputGradVec(B * T * N); @@ -73,19 +72,18 @@ namespace pkg { int N = inputVar.dim(0); int L = targetVar.dim(0); - if(N != transVar.dim(0)) { + if(N != transVar.dim(0)) throw std::invalid_argument( "ForceAlignmentCriterion(cpu)::forward: input dim doesn't match N" ); - } else if(inputVar.type() != fl::dtype::f32) { + else if(inputVar.type() != fl::dtype::f32) throw std::invalid_argument( "ForceAlignmentCriterion(cpu)::forward: input must be float32" ); - } else if(targetVar.type() != fl::dtype::s32) { + else if(targetVar.type() != fl::dtype::s32) throw std::invalid_argument( "ForceAlignmentCriterion(cpu)::forward: target must be int32" ); - } const auto& targetSize = getTargetSizeArray(targetVar.tensor(), T); auto ctx = std::make_shared(); @@ -129,13 +127,12 @@ namespace pkg { int B = input.dim(2); // Batchsize int L = target.dim(0); // Target length - if(N != trans.dim(0)) { + if(N != trans.dim(0)) throw std::invalid_argument("FAC: input dim doesn't match N:"); - } else if(input.type() != fl::dtype::f32) { + else if(input.type() != fl::dtype::f32) throw std::invalid_argument("FAC: input must be float32"); - } else if(target.type() != fl::dtype::s32) { + else if(target.type() != fl::dtype::s32) throw std::invalid_argument("FAC: target must be int32"); - } const Tensor targetSize = getTargetSizeArray(target, T); std::shared_ptr ctx = std::make_shared(); std::vector inputVec = input.toHostVector(); diff --git a/flashlight/pkg/speech/criterion/backend/cpu/FullConnectionCriterion.cpp b/flashlight/pkg/speech/criterion/backend/cpu/FullConnectionCriterion.cpp index 7394b4c..bf53376 100644 --- a/flashlight/pkg/speech/criterion/backend/cpu/FullConnectionCriterion.cpp +++ b/flashlight/pkg/speech/criterion/backend/cpu/FullConnectionCriterion.cpp @@ -34,9 +34,8 @@ namespace pkg { int N, const 
std::shared_ptr& ctx ) { - if(gradVar.type() != fl::dtype::f32) { + if(gradVar.type() != fl::dtype::f32) throw std::invalid_argument("FCC: grad must be float32"); - } auto gradVec = gradVar.tensor().toHostVector(); std::vector inputGradVec(B * T * N); @@ -69,13 +68,12 @@ namespace pkg { int T = inputVar.dim(1); int N = inputVar.dim(0); - if(N != transVar.dim(0)) { + if(N != transVar.dim(0)) throw std::invalid_argument("FCC: input dim doesn't match N"); - } else if(inputVar.type() != fl::dtype::f32) { + else if(inputVar.type() != fl::dtype::f32) throw std::invalid_argument("FCC: input must be float32"); - } else if(targetVar.type() != fl::dtype::s32) { + else if(targetVar.type() != fl::dtype::s32) throw std::invalid_argument("FCC: target must be int32"); - } const auto& targetSize = getTargetSizeArray(targetVar.tensor(), T); auto ctx = std::make_shared(); diff --git a/flashlight/pkg/speech/criterion/backend/cuda/ConnectionistTemporalClassificationCriterion.cpp b/flashlight/pkg/speech/criterion/backend/cuda/ConnectionistTemporalClassificationCriterion.cpp index 98f3b81..3ce1bb8 100644 --- a/flashlight/pkg/speech/criterion/backend/cuda/ConnectionistTemporalClassificationCriterion.cpp +++ b/flashlight/pkg/speech/criterion/backend/cuda/ConnectionistTemporalClassificationCriterion.cpp @@ -25,20 +25,18 @@ namespace fl::pkg::speech { namespace { inline void throw_on_error(ctcStatus_t status, const char* message) { - if(status != CTC_STATUS_SUCCESS) { + if(status != CTC_STATUS_SUCCESS) throw std::runtime_error( message + (", stat = " + std::string(ctcGetStatusString(status))) ); - } } } // namespace std::vector ConnectionistTemporalClassificationCriterion::forward( const std::vector& inputs ) { - if(inputs.size() != 2) { + if(inputs.size() != 2) throw std::invalid_argument("Invalid inputs size"); - } const auto& input = fl::moddims(inputs[0], {0, 0, 0}); // remove trailing singleton dims const auto& target = inputs[1]; @@ -59,9 +57,8 @@ std::vector 
ConnectionistTemporalClassificationCriterion::forward( fl::transpose(input.tensor(), {0, 2, 1}); Tensor grad; - if(input.isCalcGrad()) { + if(input.isCalcGrad()) grad = fl::full(inputarr.shape(), 0.0, inputarr.type()); - } std::vector inputLengths(B, T); std::vector labels; @@ -110,9 +107,8 @@ std::vector ConnectionistTemporalClassificationCriterion::forward( L = std::min(L + R, T) - R; labelLengths.push_back(L); - for(int l = 0; l < L; ++l) { + for(int l = 0; l < L; ++l) labels.push_back(targetVec[l]); - } } Tensor batchScales = Tensor::fromVector({B}, batchScaleVec); diff --git a/flashlight/pkg/speech/criterion/backend/cuda/CriterionUtils.cpp b/flashlight/pkg/speech/criterion/backend/cuda/CriterionUtils.cpp index 6b73e80..bf1dd7e 100644 --- a/flashlight/pkg/speech/criterion/backend/cuda/CriterionUtils.cpp +++ b/flashlight/pkg/speech/criterion/backend/cuda/CriterionUtils.cpp @@ -22,28 +22,25 @@ using ViterbiPath = fl::lib::cuda::ViterbiPath; namespace fl::pkg::speech { Tensor viterbiPath(const Tensor& input, const Tensor& trans) { - if(input.ndim() != 3) { + if(input.ndim() != 3) throw std::invalid_argument( "Criterion viterbiPath expects input of shape {N, T, B}" ); - } - if(trans.ndim() != 2) { + if(trans.ndim() != 2) throw std::invalid_argument( "Criterion viterbiPath expects trans of shape {N, N}" ); - } auto B = input.dim(2); auto T = input.dim(1); auto N = input.dim(0); - if(N != trans.dim(0) || N != trans.dim(1)) { + if(N != trans.dim(0) || N != trans.dim(1)) throw std::invalid_argument("viterbiPath: mismatched dims"); - } else if(input.type() != fl::dtype::f32) { + else if(input.type() != fl::dtype::f32) throw std::invalid_argument("viterbiPath: input must be float32"); - } else if(trans.type() != fl::dtype::f32) { + else if(trans.type() != fl::dtype::f32) throw std::invalid_argument("viterbiPath: trans must be float32"); - } Tensor path({T, B}, fl::dtype::s32); Tensor workspace( diff --git 
a/flashlight/pkg/speech/criterion/backend/cuda/ForceAlignmentCriterion.cpp b/flashlight/pkg/speech/criterion/backend/cuda/ForceAlignmentCriterion.cpp index 4959fd3..463712b 100644 --- a/flashlight/pkg/speech/criterion/backend/cuda/ForceAlignmentCriterion.cpp +++ b/flashlight/pkg/speech/criterion/backend/cuda/ForceAlignmentCriterion.cpp @@ -31,14 +31,12 @@ static void backward( const Tensor& targetSize, Tensor& workspace ) { - if(gradVar.type() != fl::dtype::f32) { + if(gradVar.type() != fl::dtype::f32) throw std::invalid_argument("FAC: grad must be float32"); - } - if(inputs.size() != 2) { + if(inputs.size() != 2) throw std::invalid_argument( "ForceAlignmentCriterion backward expects two input args" ); - } const auto& grad = gradVar.tensor(); Tensor inputGrad({N, T, B}, fl::dtype::f32); @@ -80,13 +78,12 @@ Variable ForceAlignmentCriterion::forward( int N = inputVar.dim(0); int L = targetVar.dim(0); - if(N != transVar.dim(0)) { + if(N != transVar.dim(0)) throw std::invalid_argument("FAC: input dim doesn't match N"); - } else if(inputVar.type() != fl::dtype::f32) { + else if(inputVar.type() != fl::dtype::f32) throw std::invalid_argument("FAC: input must be float32"); - } else if(targetVar.type() != fl::dtype::s32) { + else if(targetVar.type() != fl::dtype::s32) throw std::invalid_argument("FAC: target must be int32"); - } const auto& input = inputVar.tensor(); const auto& target = targetVar.tensor(); @@ -134,12 +131,11 @@ Tensor ForceAlignmentCriterion::viterbiPath( const Tensor& input, const Tensor& target ) { - if(input.ndim() != 3) { + if(input.ndim() != 3) throw std::invalid_argument( "ForceAlignmentCriterion::viterbiPath: " "expects input with dimensions {N, T, B}" ); - } int N = input.dim(0); int T = input.dim(1); int B = input.dim(2); @@ -148,13 +144,12 @@ Tensor ForceAlignmentCriterion::viterbiPath( std::vector> bestPaths; const auto& transVar = param(0); - if(N != transVar.dim(0)) { + if(N != transVar.dim(0)) throw std::invalid_argument("FAC: input dim 
doesn't match N:"); - } else if(input.type() != fl::dtype::f32) { + else if(input.type() != fl::dtype::f32) throw std::invalid_argument("FAC: input must be float32"); - } else if(target.type() != fl::dtype::s32) { + else if(target.type() != fl::dtype::s32) throw std::invalid_argument("FAC: target must be int32"); - } const auto& targetSize = getTargetSizeArray(target, T); const auto& trans = transVar.tensor(); diff --git a/flashlight/pkg/speech/criterion/backend/cuda/FullConnectionCriterion.cpp b/flashlight/pkg/speech/criterion/backend/cuda/FullConnectionCriterion.cpp index a541d4e..6c63e8a 100644 --- a/flashlight/pkg/speech/criterion/backend/cuda/FullConnectionCriterion.cpp +++ b/flashlight/pkg/speech/criterion/backend/cuda/FullConnectionCriterion.cpp @@ -29,14 +29,12 @@ static void backward( const Tensor& trans, Tensor& workspace ) { - if(gradVar.type() != fl::dtype::f32) { + if(gradVar.type() != fl::dtype::f32) throw std::invalid_argument("FCC: grad must be float32"); - } - if(inputs.size() != 2) { + if(inputs.size() != 2) throw std::invalid_argument( "FullConnectionCriterion backward expects two input args" ); - } const auto& grad = gradVar.tensor(); Tensor inputGrad({N, T, B}, fl::dtype::f32); @@ -69,31 +67,28 @@ Variable FullConnectionCriterion::forward( const Variable& inputVar, const Variable& targetVar ) { - if(inputVar.ndim() != 3) { + if(inputVar.ndim() != 3) throw std::invalid_argument( "FullConnectionCriterion::forward: " "expects input with dimensions {N, T, B}" ); - } - if(targetVar.ndim() != 2) { + if(targetVar.ndim() != 2) throw std::invalid_argument( "FullConnectionCriterion::forward: " "expects target with dimensions {B, L}" ); - } const auto& transVar = param(0); int B = inputVar.dim(2); int T = inputVar.dim(1); int N = inputVar.dim(0); - if(N != transVar.dim(0)) { + if(N != transVar.dim(0)) throw std::invalid_argument("FCC: input dim doesn't match N"); - } else if(inputVar.type() != fl::dtype::f32) { + else if(inputVar.type() != fl::dtype::f32) 
throw std::invalid_argument("FCC: input must be float32"); - } else if(targetVar.type() != fl::dtype::s32) { + else if(targetVar.type() != fl::dtype::s32) throw std::invalid_argument("FCC: target must be int32"); - } const auto& input = inputVar.tensor(); const auto& target = targetVar.tensor(); diff --git a/flashlight/pkg/speech/data/FeatureTransforms.cpp b/flashlight/pkg/speech/data/FeatureTransforms.cpp index 097c151..fea4737 100644 --- a/flashlight/pkg/speech/data/FeatureTransforms.cpp +++ b/flashlight/pkg/speech/data/FeatureTransforms.cpp @@ -76,26 +76,22 @@ fl::Dataset::DataTransformFunction inputFeatures( return [featSz, spectralFeature, localNormCtx, sfxConf, sfxCounter]( void* data, Shape dims, fl::dtype type) { - if(type != fl::dtype::f32) { + if(type != fl::dtype::f32) throw std::invalid_argument("Invalid input type"); - } - if(dims.ndim() != 2) { + if(dims.ndim() != 2) throw std::invalid_argument( "'inputFeatures': Invalid input dims . Expected 2d array - Channels x T" ); - } auto channels = dims[0]; std::vector input(dims.elements()); std::copy_n(static_cast(data), input.size(), input.data()); - if(channels > 1) { + if(channels > 1) input = transpose2d(input, dims[1], channels); - } if(!sfxConf.empty() && sfxCounter->decrementAndCheck()) { - if(channels > 1) { + if(channels > 1) throw std::invalid_argument( "'inputFeatures': Invalid input dims. 
sound effect supports a single channel audio" ); - } thread_local auto seed = getSfxSeed(); thread_local std::shared_ptr sfx = sfx::createSoundEffect(sfxConf, seed); @@ -103,9 +99,9 @@ fl::Dataset::DataTransformFunction inputFeatures( } std::vector output; - if(spectralFeature) { + if(spectralFeature) output = spectralFeature->batchApply(input, channels); - } else { + else { // use raw audio output = input; // T X CHANNELS (Col Major) } @@ -114,12 +110,11 @@ fl::Dataset::DataTransformFunction inputFeatures( // Before: FEAT X FRAMES X CHANNELS (Col Major) output = transpose2d(output, T, featSz, channels); // After: FRAMES X FEAT X CHANNELS (Col Major) - if(localNormCtx.first > 0 || localNormCtx.second > 0) { + if(localNormCtx.first > 0 || localNormCtx.second > 0) output = localNormalize(output, localNormCtx.first, localNormCtx.second, T); - } else { + else output = normalize(output); - } return Tensor::fromBuffer( {static_cast(T), featSz, channels}, output.data(), @@ -154,27 +149,22 @@ fl::Dataset::DataTransformFunction targetFeatures( // add surround token at the beginning and end of target // only if begin/end tokens are not surround auto idx = tokenDict.getIndex(config.surround_); - if(tgtVec.empty() || tgtVec.back() != idx) { + if(tgtVec.empty() || tgtVec.back() != idx) tgtVec.emplace_back(idx); - } if(tgtVec.size() > 1 && tgtVec.front() != idx) { tgtVec.emplace_back(idx); std::rotate(tgtVec.begin(), tgtVec.end() - 1, tgtVec.end()); } } - if(config.replabel_ > 0) { + if(config.replabel_ > 0) tgtVec = packReplabels(tgtVec, tokenDict, config.replabel_); - } - if(config.criterion_ == kAsgCriterion) { + if(config.criterion_ == kAsgCriterion) dedup(tgtVec); - } - if(config.eosToken_) { + if(config.eosToken_) tgtVec.emplace_back(tokenDict.getIndex(kEosToken)); - } - if(tgtVec.empty()) { + if(tgtVec.empty()) // support empty target return Tensor(fl::dtype::s32); - } return Tensor::fromVector(tgtVec); }; } @@ -185,10 +175,9 @@ fl::Dataset::DataTransformFunction 
wordFeatures(const Dictionary& wrdDict) { static_cast(data), static_cast(data) + dims.elements()); auto words = splitOnWhitespace(transcript, true); auto wrdVec = wrdDict.mapEntriesToIndices(words); - if(wrdVec.empty()) { + if(wrdVec.empty()) // support empty target return Tensor(fl::dtype::s32); - } return Tensor::fromVector(wrdVec); }; } diff --git a/flashlight/pkg/speech/data/FeatureTransforms.h b/flashlight/pkg/speech/data/FeatureTransforms.h index 69223dc..ead4a3a 100644 --- a/flashlight/pkg/speech/data/FeatureTransforms.h +++ b/flashlight/pkg/speech/data/FeatureTransforms.h @@ -93,17 +93,14 @@ namespace pkg { int64_t inCol, int64_t inBatch = 1 ) { - if(in.size() != inRow * inCol * inBatch) { + if(in.size() != inRow * inCol * inBatch) throw std::invalid_argument("Invalid input size"); - } std::vector out(in.size()); for(size_t b = 0; b < inBatch; ++b) { int64_t start = b * inRow * inCol; - for(size_t c = 0; c < inCol; ++c) { - for(size_t r = 0; r < inRow; ++r) { + for(size_t c = 0; c < inCol; ++c) + for(size_t r = 0; r < inRow; ++r) out[start + c * inRow + r] = in[start + r * inCol + c]; - } - } } return out; } @@ -117,9 +114,8 @@ namespace pkg { int64_t batchSz = 1, double threshold = 0.0 ) { - if(in.empty()) { + if(in.empty()) return {}; - } int64_t perBatchSz = in.size() / batchSz; int64_t perFrameSz = perBatchSz / frameSz; auto out(in); @@ -150,9 +146,8 @@ namespace pkg { curFrame = 0; for(auto i = b * perBatchSz; i < (b + 1) * perBatchSz; ++i) { out[i] -= sum[curFrame]; - if(sum2[curFrame] > threshold) { + if(sum2[curFrame] > threshold) out[i] /= sum2[curFrame]; - } curFrame = (curFrame + 1) % frameSz; } } @@ -165,9 +160,8 @@ namespace pkg { int64_t batchSz = 1, double threshold = 0.0 ) { - if(in.empty()) { + if(in.empty()) return {}; - } auto out(in); int64_t perBatchSz = out.size() / batchSz; for(size_t b = 0; b < batchSz; ++b) { @@ -181,7 +175,7 @@ namespace pkg { [mean](T x) { return x - mean; }); T sq_sum = std::inner_product(start, start + 
perBatchSz, start, 0.0); T stddev = std::sqrt(sq_sum / perBatchSz); - if(stddev > threshold) { + if(stddev > threshold) std::transform( start, start + perBatchSz, @@ -190,7 +184,6 @@ namespace pkg { return x / stddev; } ); - } } return out; } diff --git a/flashlight/pkg/speech/data/ListFileDataset.cpp b/flashlight/pkg/speech/data/ListFileDataset.cpp index f5e55f5..075cfda 100644 --- a/flashlight/pkg/speech/data/ListFileDataset.cpp +++ b/flashlight/pkg/speech/data/ListFileDataset.cpp @@ -32,22 +32,19 @@ ListFileDataset::ListFileDataset( wrdFeatFunc_(wrdFeatFunc), numRows_(0) { std::ifstream inFile(filename); - if(!inFile) { + if(!inFile) throw std::invalid_argument("Unable to open file -" + filename); - } std::string line; while(std::getline(inFile, line)) { - if(line.empty()) { + if(line.empty()) continue; - } auto splits = splitOnWhitespace(line, true); - if(splits.size() < 3) { + if(splits.size() < 3) throw std::runtime_error( "File " + filename + " has invalid columns in line (expected 3 columns at least): " + line ); - } ids_.emplace_back(std::move(splits[kIdIdx])); inputs_.emplace_back(std::move(splits[kInIdx])); @@ -73,19 +70,18 @@ std::vector ListFileDataset::get(const int64_t idx) const { auto audio = loadAudio(inputs_[idx]); // channels x time Tensor input; - if(inFeatFunc_) { + if(inFeatFunc_) input = inFeatFunc_( static_cast(audio.first.data()), audio.second, fl::dtype::f32 ); - } else { + else input = Tensor::fromBuffer( {audio.second}, audio.first.data(), MemoryLocation::Host ); - } Tensor target; if(tgtFeatFunc_) { @@ -146,12 +142,10 @@ float ListFileDataset::getInputSize(const int64_t idx) const { int64_t ListFileDataset::getTargetSize(const int64_t idx) const { checkIndexBounds(idx); - if(targetSizesCache_[idx] >= 0) { + if(targetSizesCache_[idx] >= 0) return targetSizesCache_[idx]; - } - if(!tgtFeatFunc_) { + if(!tgtFeatFunc_) return 0; - } std::vector curTarget(targets_[idx].begin(), targets_[idx].end()); auto tgtSize = tgtFeatFunc_( 
static_cast(curTarget.data()), diff --git a/flashlight/pkg/speech/data/Sound.cpp b/flashlight/pkg/speech/data/Sound.cpp index 9a08d2a..b53841a 100644 --- a/flashlight/pkg/speech/data/Sound.cpp +++ b/flashlight/pkg/speech/data/Sound.cpp @@ -114,9 +114,8 @@ static sf_count_t sf_vio_ro_read(void* ptr, sf_count_t count, void* user_data) { std::istream* f = reinterpret_cast(user_data); f->read((char*) ptr, count); auto n = f->gcount(); - if(!f->good()) { + if(!f->good()) f->clear(); - } return n; } @@ -184,9 +183,8 @@ static sf_count_t sf_vio_wo_tell(void* user_data) { SoundInfo loadSoundInfo(const std::string& filename) { std::ifstream f(filename); - if(!f.is_open()) { + if(!f.is_open()) throw std::runtime_error("could not open file for read " + filename); - } return loadSoundInfo(f); } @@ -203,11 +201,10 @@ SoundInfo loadSoundInfo(std::istream& f) { /* mandatory */ info.format = 0; - if(!(file = sf_open_virtual(&vsf, SFM_READ, &info, &f))) { + if(!(file = sf_open_virtual(&vsf, SFM_READ, &info, &f))) throw std::runtime_error( "loadSoundInfo: unknown format or could not open stream" ); - } sf_close(file); @@ -221,9 +218,8 @@ SoundInfo loadSoundInfo(std::istream& f) { template std::vector loadSound(const std::string& filename) { std::ifstream f(filename); - if(!f.is_open()) { + if(!f.is_open()) throw std::runtime_error("could not open file " + filename); - } return loadSound(f); } @@ -239,35 +235,32 @@ std::vector loadSound(std::istream& f) { info.format = 0; - if(!(file = sf_open_virtual(&vsf, SFM_READ, &info, &f))) { + if(!(file = sf_open_virtual(&vsf, SFM_READ, &info, &f))) throw std::runtime_error( "loadSound: unknown format or could not open stream" ); - } std::vector in(info.frames * info.channels); sf_count_t nframe; - if(std::is_same::value) { + if(std::is_same::value) nframe = sf_readf_float(file, reinterpret_cast(in.data()), info.frames); - } else if(std::is_same::value) { + else if(std::is_same::value) nframe = sf_readf_double( file, 
reinterpret_cast(in.data()), info.frames ); - } else if(std::is_same::value) { + else if(std::is_same::value) nframe = sf_readf_int(file, reinterpret_cast(in.data()), info.frames); - } else if(std::is_same::value) { + else if(std::is_same::value) nframe = sf_readf_short(file, reinterpret_cast(in.data()), info.frames); - } else { + else throw std::logic_error("loadSound: called with unsupported T"); - } sf_close(file); - if(nframe != info.frames) { + if(nframe != info.frames) throw std::runtime_error("loadSound: read error"); - } return in; } @@ -281,9 +274,8 @@ void saveSound( SoundSubFormat subformat ) { std::ofstream f(filename); - if(!f.is_open()) { + if(!f.is_open()) throw std::runtime_error("could not open file for write " + filename); - } saveSound(f, input, samplerate, channels, format, subformat); } @@ -304,23 +296,20 @@ void saveSound( SNDFILE* file; SF_INFO info; - if(formats.find(format) == formats.end()) { + if(formats.find(format) == formats.end()) throw std::invalid_argument("saveSound: invalid format"); - } - if(subformats.find(subformat) == subformats.end()) { + if(subformats.find(subformat) == subformats.end()) throw std::invalid_argument("saveSound: invalid subformat"); - } info.channels = channels; info.samplerate = samplerate; info.format = formats.find(format)->second | subformats.find(subformat)->second; - if(!(file = sf_open_virtual(&vsf, SFM_WRITE, &info, &f))) { + if(!(file = sf_open_virtual(&vsf, SFM_WRITE, &info, &f))) throw std::runtime_error( "saveSound: invalid format or could not write stream" ); - } /* Circumvent a bug in Vorbis with large buffers */ sf_count_t remainCount = input.size() / channels; @@ -329,37 +318,36 @@ void saveSound( while(remainCount > 0) { sf_count_t writableCount = std::min(chunkSize, remainCount); sf_count_t writtenCount = 0; - if(std::is_same::value) { + if(std::is_same::value) writtenCount = sf_writef_float( file, const_cast(reinterpret_cast(input.data())) + offsetCount * channels, writableCount ); - } else 
if(std::is_same::value) { + else if(std::is_same::value) writtenCount = sf_writef_double( file, const_cast(reinterpret_cast(input.data())) + offsetCount * channels, writableCount ); - } else if(std::is_same::value) { + else if(std::is_same::value) writtenCount = sf_writef_int( file, const_cast(reinterpret_cast(input.data())) + offsetCount * channels, writableCount ); - } else if(std::is_same::value) { + else if(std::is_same::value) writtenCount = sf_writef_short( file, const_cast(reinterpret_cast(input.data())) + offsetCount * channels, writableCount ); - } else { + else throw std::logic_error("saveSound: called with unsupported T"); - } if(writtenCount != writableCount) { sf_close(file); throw std::runtime_error("saveSound: write error"); diff --git a/flashlight/pkg/speech/data/Utils.cpp b/flashlight/pkg/speech/data/Utils.cpp index 1798645..f51b167 100644 --- a/flashlight/pkg/speech/data/Utils.cpp +++ b/flashlight/pkg/speech/data/Utils.cpp @@ -33,38 +33,33 @@ std::vector wrd2Target( lit->second.size() > 1 && targetSamplePct > static_cast(std::rand()) / static_cast(RAND_MAX) - ) { + ) return lit->second[std::rand() % lit->second.size()]; - } else { + else return lit->second[0]; - } } std::vector word2tokens; if(fallback2LtrWordSepLeft || fallback2LtrWordSepRight) { - if(fallback2LtrWordSepLeft && !wordSeparator.empty()) { + if(fallback2LtrWordSepLeft && !wordSeparator.empty()) // add word separator at the beginning of fallback word word2tokens.push_back(wordSeparator); - } auto tokens = splitWrd(word); for(const auto& tkn : tokens) { - if(dict.contains(tkn)) { + if(dict.contains(tkn)) word2tokens.push_back(tkn); - } else if(!skipUnk) { + else if(!skipUnk) throw std::invalid_argument( "Unknown token '" + tkn + "' when falling back to letter target for the unknown word: " + word ); - } } - if(fallback2LtrWordSepRight && !wordSeparator.empty()) { + if(fallback2LtrWordSepRight && !wordSeparator.empty()) // add word separator at the end of fallback word 
word2tokens.push_back(wordSeparator); - } - } else if(!skipUnk) { + } else if(!skipUnk) throw std::invalid_argument("Unknown word in the lexicon: " + word); - } return word2tokens; } @@ -91,9 +86,8 @@ std::vector wrd2Target( skipUnk ); - if(w2tokens.empty()) { + if(w2tokens.empty()) continue; - } res.insert(res.end(), w2tokens.begin(), w2tokens.end()); } return res; @@ -104,23 +98,22 @@ std::pair getFeatureType( int channels, const fl::lib::audio::FeatureParams& featParams ) { - if(featuresType == kFeaturesPow) { + if(featuresType == kFeaturesPow) return std::make_pair( featParams.powSpecFeatSz(), FeatureType::POW_SPECTRUM ); - } else if(featuresType == kFeaturesMFSC) { + else if(featuresType == kFeaturesMFSC) return std::make_pair(featParams.mfscFeatSz(), FeatureType::MFSC); - } else if(featuresType == kFeaturesMFSC) { + else if(featuresType == kFeaturesMFSC) return std::make_pair(featParams.mfccFeatSz(), FeatureType::MFCC); - } else if(featuresType == kFeaturesRaw) { + else if(featuresType == kFeaturesRaw) return std::make_pair(channels, FeatureType::NONE); - } else { + else throw std::runtime_error( "Unsupported feature type for audio preprocessing '" + featuresType + "'" ); - } } } // namespace fl diff --git a/flashlight/pkg/speech/decoder/ConvLmModule.cpp b/flashlight/pkg/speech/decoder/ConvLmModule.cpp index 361113f..da938ae 100644 --- a/flashlight/pkg/speech/decoder/ConvLmModule.cpp +++ b/flashlight/pkg/speech/decoder/ConvLmModule.cpp @@ -24,32 +24,28 @@ GetConvLmScoreFunc buildGetConvLmScoreFunction( int sampleSize = -1, int batchSize = 1) { sampleSize = sampleSize > 0 ? sampleSize : inputs.size(); - if(sampleSize * batchSize > inputs.size()) { + if(sampleSize * batchSize > inputs.size()) throw std::invalid_argument( "[ConvLM] Incorrect sample size (" + std::to_string(sampleSize) + ") or batch size (" + std::to_string(batchSize) + ")." 
); - } Tensor inputData = Tensor::fromVector({sampleSize, batchSize}, inputs); fl::Variable output = network->forward({fl::input(inputData)})[0]; - if(fl::countNonzero(fl::isnan(output.tensor())).asScalar() != 0) { + if(fl::countNonzero(fl::isnan(output.tensor())).asScalar() != 0) throw std::runtime_error("[ConvLM] Encountered NaNs in propagation"); - } int32_t C = output.dim(0), T = output.dim(1), B = output.dim(2); - if(B != batchSize) { + if(B != batchSize) throw std::logic_error( "[ConvLM]: incorrect predictions: batch should be " + std::to_string(batchSize) + " but it is " + std::to_string(B) ); - } - if(batchSize != static_cast(lastTokenPositions.size())) { + if(batchSize != static_cast(lastTokenPositions.size())) throw std::logic_error( "[ConvLM]: incorrect postions for accessing: size should be " + std::to_string(batchSize) + " but it is " + std::to_string(lastTokenPositions.size()) ); - } // output (c, t, b) // set global indices: offset by channel Tensor globalIndices = fl::iota({C, 1}, {1, B}, fl::dtype::s32); diff --git a/flashlight/pkg/speech/decoder/DecodeMaster.cpp b/flashlight/pkg/speech/decoder/DecodeMaster.cpp index 8b89da1..876eb01 100644 --- a/flashlight/pkg/speech/decoder/DecodeMaster.cpp +++ b/flashlight/pkg/speech/decoder/DecodeMaster.cpp @@ -59,44 +59,39 @@ std::pair, fl::EditDistanceMeter wordEditDist, tokenEditDist; for(auto& sample : *predDataset) { - if(sample.size() <= kDMWordPredIdx) { + if(sample.size() <= kDMWordPredIdx) throw std::runtime_error( "computeMetrics: need token/word target to compute WER" ); - } auto predictionWrd = sample[kDMWordPredIdx]; auto targetWrd = sample[kDMWordTargetIdx]; auto prediction = sample[kDMTokenPredIdx]; auto target = sample[kDMTokenTargetIdx]; bool isPredictingWrd = !predictionWrd.isEmpty(); - if(prediction.ndim() > 2 || target.ndim() > 2) { + if(prediction.ndim() > 2 || target.ndim() > 2) throw std::runtime_error( "computeMetrics: expecting TxB for prediction and target" ); - } - if(isPredictingWrd 
&& (predictionWrd.ndim() > 2 || targetWrd.ndim() > 2)) { + if(isPredictingWrd && (predictionWrd.ndim() > 2 || targetWrd.ndim() > 2)) throw std::runtime_error( "computeMetrics: expecting TxB for prediction and target" ); - } if( !prediction.isEmpty() && !target.isEmpty() && (prediction.dim(1) != target.dim(1)) - ) { + ) throw std::runtime_error( "computeMetrics: prediction and target do not match" ); - } if( isPredictingWrd && !predictionWrd.isEmpty() && !targetWrd.isEmpty() && (predictionWrd.dim(1) != targetWrd.dim(1)) - ) { + ) throw std::runtime_error( "computeMetrics: prediction and target do not match" ); - } // token predictions and target std::vector predictionV = prediction.toHostVector(); std::vector targetV = target.toHostVector(); @@ -151,27 +146,24 @@ std::shared_ptr DecodeMaster::forward( auto emissionDataset = std::make_shared(); for(auto& batch : *ds) { Tensor output; - if(batch.empty()) { + if(batch.empty()) continue; - } - if(usePlugin_) { + if(usePlugin_) output = net_->forward( {fl::input(batch[kInputIdx]), fl::noGrad(batch[kDurationIdx])} ) .front() .tensor(); - } else { + else output = fl::pkg::runtime::forwardSequentialModuleWithPadMask( fl::input(batch[kInputIdx]), net_, batch[kDurationIdx] ) .tensor(); - } - if(output.ndim() > 3) { + if(output.ndim() > 3) throw std::runtime_error("output should be NxTxB"); - } Tensor tokenTarget = (batch.size() > kTargetIdx ? batch[kTargetIdx] : Tensor()); Tensor wordTarget = (batch.size() > kWordIdx ? 
batch[kWordIdx] : Tensor()); @@ -180,15 +172,13 @@ std::shared_ptr DecodeMaster::forward( if( !tokenTarget.isEmpty() && (tokenTarget.ndim() > 2 || tokenTarget.dim(1) != B) - ) { + ) throw std::runtime_error("token target should be LxB"); - } if( !wordTarget.isEmpty() && (wordTarget.ndim() > 2 || wordTarget.dim(1) != B) - ) { + ) throw std::runtime_error("word target should be LxB"); - } // todo s2s, if we pad only with -1 we will be good here (not pad with eos) for(int b = 0; b < B; b++) { std::vector res(4); @@ -213,9 +203,8 @@ std::shared_ptr DecodeMaster::decode( auto predDataset = std::make_shared(); for(auto& sample : *emissionDataset) { auto emission = sample[kDMTokenPredIdx]; - if(emission.ndim() > 2) { + if(emission.ndim() > 2) throw std::runtime_error("emission should be NxT"); - } std::vector emissionV(emission.elements()); emission.astype(fl::dtype::f32).host(emissionV.data()); auto results = diff --git a/flashlight/pkg/speech/decoder/DecodeUtils.cpp b/flashlight/pkg/speech/decoder/DecodeUtils.cpp index 37719d5..50ef5d7 100644 --- a/flashlight/pkg/speech/decoder/DecodeUtils.cpp +++ b/flashlight/pkg/speech/decoder/DecodeUtils.cpp @@ -22,9 +22,8 @@ std::shared_ptr buildTrie( const int wordSeparatorIdx, const int repLabel ) { - if(!(decoderType == "wrd" || useLexicon)) { + if(!(decoderType == "wrd" || useLexicon)) return nullptr; - } auto trie = std::make_shared( tokenDict.indexSize(), wordSeparatorIdx @@ -46,16 +45,15 @@ std::shared_ptr buildTrie( } // Smearing SmearingMode smearMode = SmearingMode::NONE; - if(smearing == "logadd") { + if(smearing == "logadd") smearMode = SmearingMode::LOGADD; - } else if(smearing == "max") { + else if(smearing == "max") smearMode = SmearingMode::MAX; - } else if(smearing != "none") { + else if(smearing != "none") throw std::runtime_error( "[buildTrie] Invalid smearing option, can be {logadd, max, none}, provided value is " + smearing ); - } trie->smear(smearMode); return trie; } diff --git 
a/flashlight/pkg/speech/decoder/PlGenerator.cpp b/flashlight/pkg/speech/decoder/PlGenerator.cpp index 76bf39f..86d5356 100644 --- a/flashlight/pkg/speech/decoder/PlGenerator.cpp +++ b/flashlight/pkg/speech/decoder/PlGenerator.cpp @@ -69,29 +69,25 @@ PlGenerator::PlGenerator( auto plEpochVec = lib::split(',', plEpoch, true); auto plRatioVec = lib::split(',', plRatio, true); - if(plEpochVec.size() != plRatioVec.size()) { + if(plEpochVec.size() != plRatioVec.size()) throw std::invalid_argument( "[PlGenerator] Size mismatch between pl_epoch and pl_ratio." ); - } plEpochs_.resize(plEpochVec.size()); - for(int i = 0; i < plEpochVec.size(); i++) { + for(int i = 0; i < plEpochVec.size(); i++) plEpochs_[i] = stoi(plEpochVec[i]); - } for(int i = 0; i < plEpochVec.size(); i++) { auto ratio = stof(plRatioVec[i]); - if(ratio < 0 || ratio > 1) { + if(ratio < 0 || ratio > 1) throw std::invalid_argument( "[PlGenerator] The value of pl_ratio should be in [0, 1]." ); - } - if(i > 0 && plEpochs_[i] <= plEpochs_[i - 1]) { + if(i > 0 && plEpochs_[i] <= plEpochs_[i - 1]) throw std::invalid_argument( "[PlGenerator] Elements in pl_epoch should be in ascendant order." 
); - } plUpdateMap_[plEpochs_[i]] = ratio; } @@ -109,18 +105,16 @@ PlGenerator::PlGenerator( allListDs.emplace_back(curListDs); } if(!allListDs.empty()) { - if(isMaster_) { + if(isMaster_) fs::create_directory(plDir_); - } fullUnsupDs_ = std::make_shared(allListDs); } } std::string PlGenerator::reloadPl(int curEpoch) const { int lastPlEpoch = findLastPlEpoch(curEpoch); - if(lastPlEpoch < 0) { + if(lastPlEpoch < 0) return ""; - } fs::path plDir = plDir_ / (kPlSubdirPrefix + std::to_string(lastPlEpoch)); @@ -147,12 +141,10 @@ std::string PlGenerator::regeneratePl( const std::shared_ptr criterion, const bool usePlugin /* = false */ ) const { - if(plUpdateMap_.find(curEpoch) == plUpdateMap_.end()) { + if(plUpdateMap_.find(curEpoch) == plUpdateMap_.end()) return ""; - } - if(!fullUnsupDs_) { + if(!fullUnsupDs_) throw std::runtime_error("No unlabeled data is provided"); - } logMaster( "[PlGenerator] Regenerating PL at epoch " + std::to_string(curEpoch) @@ -166,11 +158,10 @@ std::string PlGenerator::regeneratePl( // Pass. Allowing attempts from all processes to create the folder. } - if(!fs::is_directory(plDir)) { + if(!fs::is_directory(plDir)) throw std::runtime_error( "[PlGenerator] Failed to create " + plDir.string() ); - } /* 1. 
select data */ // shuffle @@ -202,9 +193,8 @@ std::string PlGenerator::regeneratePl( std::ofstream plStream(newPlFile); for(auto& sample : *selectedDs) { auto duration = sample[kDurationIdx].scalar(); - if(duration < minInputSize_ || duration > maxInputSize_) { + if(duration < minInputSize_ || duration > maxInputSize_) continue; - } std::vector words; if(useExistingPl_ && seedModelWER_ < currentModelWER_) { @@ -212,27 +202,25 @@ std::string PlGenerator::regeneratePl( words = tokenToWord_(tokenTarget, tokenDict_, false); } else { fl::Variable rawEmission; - if(usePlugin) { + if(usePlugin) rawEmission = ntwrk ->forward( {fl::input(sample[kInputIdx]), fl::noGrad(sample[kDurationIdx])} ) .front(); - } else { + else rawEmission = fl::pkg::runtime::forwardSequentialModuleWithPadMask( fl::input(sample[kInputIdx]), ntwrk, sample[kDurationIdx] ); - } auto tokenPrediction = criterion->viterbiPath(rawEmission.tensor()).toHostVector(); words = tokenToWord_(tokenPrediction, tokenDict_, true); } - if(words.size() < minTargetSize_ || words.size() > maxTargetSize_) { + if(words.size() < minTargetSize_ || words.size() > maxTargetSize_) continue; - } auto sampleId = readSampleIds(sample[kSampleIdx]).front(); auto inputPath = readSampleIds(sample[kPathIdx]).front(); @@ -260,12 +248,10 @@ std::shared_ptr PlGenerator::createTrainSet( int maxDurationPerBatch /* = 0 */ ) const { std::vector files; - for(const auto& file : lib::split(",", trainLists, true)) { + for(const auto& file : lib::split(",", trainLists, true)) files.emplace_back(trainDir / file); - } - for(int i = 0; i < worldSize_; i++) { + for(int i = 0; i < worldSize_; i++) files.emplace_back(trainUnsupDir / (std::to_string(i) + ".lst")); - } return createDataset( files, @@ -290,18 +276,16 @@ void PlGenerator::setModelWER(const float& wer) { int PlGenerator::findLastPlEpoch(int curEpoch) const { int lastPlEpoch = -1; for(const auto& i : plEpochs_) { - if(i > curEpoch) { + if(i > curEpoch) break; - } lastPlEpoch = i; } return 
lastPlEpoch; } void PlGenerator::logMaster(const std::string& message) const { - if(worldRank_ != 0) { + if(worldRank_ != 0) return; - } std::cerr << message << std::endl; } diff --git a/flashlight/pkg/speech/decoder/TranscriptionUtils.cpp b/flashlight/pkg/speech/decoder/TranscriptionUtils.cpp index 90ff9f1..508e06c 100644 --- a/flashlight/pkg/speech/decoder/TranscriptionUtils.cpp +++ b/flashlight/pkg/speech/decoder/TranscriptionUtils.cpp @@ -27,21 +27,17 @@ std::vector tknIdx2Ltr( auto token = d.getEntry(id); if(useWordPiece) { auto splitToken = splitWrd(token); - for(const auto& c : splitToken) { + for(const auto& c : splitToken) result.emplace_back(c); - } - } else { + } else result.emplace_back(token); - } } if(!result.empty() && !wordSep.empty()) { - if(result.front() == wordSep) { + if(result.front() == wordSep) result.erase(result.begin()); - } - if(!result.empty() && result.back() == wordSep) { + if(!result.empty() && result.back() == wordSep) result.pop_back(); - } } return result; @@ -59,13 +55,11 @@ std::vector tkn2Wrd( words.push_back(currentWord); currentWord = ""; } - } else { + } else currentWord += tkn; - } } - if(!currentWord.empty()) { + if(!currentWord.empty()) words.push_back(currentWord); - } return words; } @@ -74,9 +68,8 @@ std::vector wrdIdx2Wrd( const Dictionary& wordDict ) { std::vector words; - for(auto wrdIdx : input) { + for(auto wrdIdx : input) words.push_back(wordDict.getEntry(wrdIdx)); - } return words; } @@ -90,15 +83,12 @@ std::vector tknTarget2Ltr( const bool useWordPiece, const std::string& wordSep ) { - if(tokens.empty()) { + if(tokens.empty()) return std::vector{}; - } - if(isSeq2seqCrit) { - if(tokens.back() == tokenDict.getIndex(kEosToken)) { + if(isSeq2seqCrit) + if(tokens.back() == tokenDict.getIndex(kEosToken)) tokens.pop_back(); - } - } remapLabels(tokens, tokenDict, surround, isSeq2seqCrit, replabel); return tknIdx2Ltr(tokens, tokenDict, useWordPiece, wordSep); @@ -114,13 +104,11 @@ std::vector tknPrediction2Ltr( const 
bool useWordPiece, const std::string& wordSep ) { - if(tokens.empty()) { + if(tokens.empty()) return std::vector{}; - } - if(criterion == kCtcCriterion || criterion == kAsgCriterion) { + if(criterion == kCtcCriterion || criterion == kAsgCriterion) dedup(tokens); - } if(criterion == kCtcCriterion) { int blankIdx = tokenDict.getIndex(kBlankToken); tokens.erase( @@ -136,12 +124,11 @@ std::vector tknPrediction2Ltr( std::vector validateIdx(std::vector input, int unkIdx) { int newSize = 0; - for(int i = 0; i < input.size(); i++) { + for(int i = 0; i < input.size(); i++) if(input[i] >= 0 && input[i] != unkIdx) { input[newSize] = input[i]; newSize++; } - } input.resize(newSize); return input; diff --git a/flashlight/pkg/speech/decoder/TranscriptionUtils.h b/flashlight/pkg/speech/decoder/TranscriptionUtils.h index e8ce931..f95764d 100644 --- a/flashlight/pkg/speech/decoder/TranscriptionUtils.h +++ b/flashlight/pkg/speech/decoder/TranscriptionUtils.h @@ -84,28 +84,22 @@ namespace pkg { ) { labels.pop_back(); } - } else { + } else while(!labels.empty() && labels.back() == kTargetPadValue) { labels.pop_back(); } - } - if(replabel > 0) { + if(replabel > 0) labels = unpackReplabels(labels, dict, replabel); - } auto trimLabels = [&labels](int idx) { - if(!labels.empty() && labels.back() == idx) { + if(!labels.empty() && labels.back() == idx) labels.pop_back(); - } - if(!labels.empty() && labels.front() == idx) { + if(!labels.empty() && labels.front() == idx) labels.erase(labels.begin()); - } }; - if(dict.contains(kSilToken)) { + if(dict.contains(kSilToken)) trimLabels(dict.getIndex(kSilToken)); - } - if(!surround.empty()) { + if(!surround.empty()) trimLabels(dict.getIndex(surround)); - } }; } // namespace speech } // namespace pkg diff --git a/flashlight/pkg/speech/runtime/Attention.cpp b/flashlight/pkg/speech/runtime/Attention.cpp index fe72446..f72cee2 100644 --- a/flashlight/pkg/speech/runtime/Attention.cpp +++ b/flashlight/pkg/speech/runtime/Attention.cpp @@ -11,89 +11,87 @@ 
namespace fl::pkg::speech { std::shared_ptr createAttention() { std::shared_ptr attention; - if(FLAGS_attention == fl::pkg::speech::kContentAttention) { + if(FLAGS_attention == fl::pkg::speech::kContentAttention) attention = std::make_shared(); - } else if(FLAGS_attention == fl::pkg::speech::kKeyValueAttention) { + else if(FLAGS_attention == fl::pkg::speech::kKeyValueAttention) attention = std::make_shared(true); - } else if(FLAGS_attention == fl::pkg::speech::kNeuralContentAttention) { + else if(FLAGS_attention == fl::pkg::speech::kNeuralContentAttention) attention = std::make_shared(FLAGS_encoderdim); - } else if(FLAGS_attention == fl::pkg::speech::kSimpleLocationAttention) { + else if(FLAGS_attention == fl::pkg::speech::kSimpleLocationAttention) attention = std::make_shared(FLAGS_attnconvkernel); - } else if(FLAGS_attention == fl::pkg::speech::kLocationAttention) { + else if(FLAGS_attention == fl::pkg::speech::kLocationAttention) attention = std::make_shared( FLAGS_encoderdim, FLAGS_attnconvkernel ); - } else if(FLAGS_attention == fl::pkg::speech::kNeuralLocationAttention) { + else if(FLAGS_attention == fl::pkg::speech::kNeuralLocationAttention) attention = std::make_shared( FLAGS_encoderdim, FLAGS_attndim, FLAGS_attnconvchannel, FLAGS_attnconvkernel ); - } // is it fine for transformer criterion? - else if(FLAGS_attention == fl::pkg::speech::kMultiHeadContentAttention) { + // is it fine for transformer criterion? 
+ else if(FLAGS_attention == fl::pkg::speech::kMultiHeadContentAttention) attention = std::make_shared( FLAGS_encoderdim, FLAGS_numattnhead ); - } else if( - FLAGS_attention == fl::pkg::speech::kMultiHeadKeyValueContentAttention) { + else if( + FLAGS_attention == fl::pkg::speech::kMultiHeadKeyValueContentAttention) attention = std::make_shared( FLAGS_encoderdim, FLAGS_numattnhead, true ); - } else if(FLAGS_attention == fl::pkg::speech::kMultiHeadSplitContentAttention) { + else if(FLAGS_attention == fl::pkg::speech::kMultiHeadSplitContentAttention) attention = std::make_shared( FLAGS_encoderdim, FLAGS_numattnhead, false, true ); - } else if( + else if( FLAGS_attention == fl::pkg::speech::kMultiHeadKeyValueSplitContentAttention - ) { + ) attention = std::make_shared( FLAGS_encoderdim, FLAGS_numattnhead, true, true ); - } else { + else throw std::runtime_error("Unimplmented attention: " + FLAGS_attention); - } return attention; } std::shared_ptr createAttentionWindow() { std::shared_ptr window; - if(FLAGS_attnWindow == fl::pkg::speech::kNoWindow) { + if(FLAGS_attnWindow == fl::pkg::speech::kNoWindow) window = nullptr; - } else if(FLAGS_attnWindow == fl::pkg::speech::kMedianWindow) { + else if(FLAGS_attnWindow == fl::pkg::speech::kMedianWindow) window = std::make_shared( FLAGS_leftWindowSize, FLAGS_rightWindowSize ); - } else if(FLAGS_attnWindow == fl::pkg::speech::kStepWindow) { + else if(FLAGS_attnWindow == fl::pkg::speech::kStepWindow) window = std::make_shared( FLAGS_minsil, FLAGS_maxsil, FLAGS_minrate, FLAGS_maxrate ); - } else if(FLAGS_attnWindow == fl::pkg::speech::kSoftWindow) { + else if(FLAGS_attnWindow == fl::pkg::speech::kSoftWindow) window = std::make_shared( FLAGS_softwstd, FLAGS_softwrate, FLAGS_softwoffset ); - } else if(FLAGS_attnWindow == fl::pkg::speech::kSoftPretrainWindow) { + else if(FLAGS_attnWindow == fl::pkg::speech::kSoftPretrainWindow) window = std::make_shared(FLAGS_softwstd); - } else { + else throw std::runtime_error("Unimplmented window: 
" + FLAGS_attnWindow); - } return window; } } // namespace fl diff --git a/flashlight/pkg/speech/runtime/Helpers.cpp b/flashlight/pkg/speech/runtime/Helpers.cpp index ac5bbe6..966d32b 100644 --- a/flashlight/pkg/speech/runtime/Helpers.cpp +++ b/flashlight/pkg/speech/runtime/Helpers.cpp @@ -64,13 +64,12 @@ std::string serializeGflags(const std::string& separator /* = "\n" */) { gflags::GetAllFlags(&allFlags); std::string currVal; auto& deprecatedFlags = detail::getDeprecatedFlags(); - for(auto itr = allFlags.begin(); itr != allFlags.end(); ++itr) { + for(auto itr = allFlags.begin(); itr != allFlags.end(); ++itr) // Check if the flag is deprecated - if so, skip it if(deprecatedFlags.find(itr->name) == deprecatedFlags.end()) { gflags::GetCommandLineOption(itr->name.c_str(), &currVal); serialized << "--" << itr->name << "=" << currVal << separator; } - } return serialized.str(); } @@ -78,11 +77,9 @@ std::unordered_set getTrainEvalIds(int64_t dsSize, double pctTrainEval, std::mt19937_64 rng(seed); std::bernoulli_distribution dist(pctTrainEval / 100.0); std::unordered_set result; - for(int64_t i = 0; i < dsSize; ++i) { - if(dist(rng)) { + for(int64_t i = 0; i < dsSize; ++i) + if(dist(rng)) result.insert(i); - } - } return result; } @@ -121,20 +118,18 @@ std::shared_ptr createDataset( LOG(FATAL) << "EverstoreDataset not supported: " << "build with -DFL_BUILD_FB_DEPENDENCIES"; #endif - } else { + } else curListDs = std::make_shared( rootDir / path, inputTransform, targetTransform, wordTransform ); - } allListDs.emplace_back(curListDs); sizes.reserve(sizes.size() + curListDs->size()); - for(int64_t i = 0; i < curListDs->size(); ++i) { + for(int64_t i = 0; i < curListDs->size(); ++i) sizes.push_back(curListDs->getInputSize(i)); - } } // Order Dataset @@ -219,11 +214,10 @@ std::shared_ptr createDataset( fl::BatchDatasetPolicy::INCLUDE_LAST, batchFns ); - } else { + } else throw std::runtime_error( "Unsupported batching strategy '" + batchingStrategy + "'" ); - } } 
std::shared_ptr loadPrefetchDataset( @@ -232,16 +226,14 @@ std::shared_ptr loadPrefetchDataset( bool shuffle, int shuffleSeed /*= 0 */ ) { - if(shuffle) { + if(shuffle) dataset = std::make_shared(dataset, shuffleSeed); - } - if(prefetchThreads > 0) { + if(prefetchThreads > 0) dataset = std::make_shared( dataset, prefetchThreads, prefetchThreads /* prefetch size */ ); - } return dataset; } @@ -253,11 +245,10 @@ std::vector> parseValidSets( for(const auto& s : validSets) { // assume the format is tag:filepath auto ts = fl::lib::splitOnAnyOf(":", s); - if(ts.size() == 1) { + if(ts.size() == 1) validTagSets.emplace_back(s, s); - } else { + else validTagSets.emplace_back(ts[0], ts[1]); - } } return validTagSets; } diff --git a/flashlight/pkg/speech/runtime/Logger.cpp b/flashlight/pkg/speech/runtime/Logger.cpp index 4bee467..625100f 100644 --- a/flashlight/pkg/speech/runtime/Logger.cpp +++ b/flashlight/pkg/speech/runtime/Logger.cpp @@ -72,12 +72,11 @@ std::string getLogString( format("%5.2f", v.second.wrdEdit.errorRate()[0]) ); auto vDecoderIter = validDecoderWer.find(v.first); - if(vDecoderIter != validDecoderWer.end()) { + if(vDecoderIter != validDecoderWer.end()) insertItem( v.first + "-WER-decoded", format("%5.2f", vDecoderIter->second) ); - } } auto stats = meters.stats.value(); auto numsamples = std::max(stats[4], 1); @@ -87,11 +86,10 @@ std::string getLogString( auto tsztotal = stats[1]; auto tszmax = stats[3]; auto iszAvrFrames = isztotal / numsamples; - if(FLAGS_features_type != kFeaturesRaw) { + if(FLAGS_features_type != kFeaturesRaw) iszAvrFrames = iszAvrFrames / FLAGS_framestridems; - } else { + else iszAvrFrames = iszAvrFrames / 1000 * FLAGS_samplerate; - } insertItem("avg-isz", format("%03d", iszAvrFrames)); insertItem("avg-tsz", format("%03d", tsztotal / numsamples)); insertItem("max-tsz", format("%03d", tszmax)); @@ -114,9 +112,8 @@ void appendToLog(std::ofstream& logfile, const std::string& logstr) { auto write = [&]() { logfile.clear(); // reset flags 
logfile << logstr << std::endl; - if(!logfile) { + if(!logfile) throw std::runtime_error("appending to log failed"); - } }; retryWithBackoff(std::chrono::seconds(1), 1.0, 6, write); } diff --git a/flashlight/pkg/speech/runtime/Optimizer.cpp b/flashlight/pkg/speech/runtime/Optimizer.cpp index b4987af..d372ad6 100644 --- a/flashlight/pkg/speech/runtime/Optimizer.cpp +++ b/flashlight/pkg/speech/runtime/Optimizer.cpp @@ -18,11 +18,10 @@ std::shared_ptr initOptimizer( double momentum, double weightdecay ) { - if(nets.empty()) { + if(nets.empty()) throw std::invalid_argument( "[InitOptimizer]: No network for initializing the optimizer" ); - } std::vector params; for(const auto& n : nets) { @@ -31,9 +30,9 @@ std::shared_ptr initOptimizer( } std::shared_ptr opt; - if(optimizer == kSGDOptimizer) { + if(optimizer == kSGDOptimizer) opt = std::make_shared(params, lr, momentum, weightdecay); - } else if(optimizer == kAdamOptimizer) { + else if(optimizer == kAdamOptimizer) opt = std::make_shared( params, lr, @@ -42,7 +41,7 @@ std::shared_ptr initOptimizer( FLAGS_optimepsilon, weightdecay ); - } else if(optimizer == kRMSPropOptimizer) { + else if(optimizer == kRMSPropOptimizer) opt = std::make_shared( params, lr, @@ -50,7 +49,7 @@ std::shared_ptr initOptimizer( FLAGS_optimepsilon, weightdecay ); - } else if(optimizer == kAdadeltaOptimizer) { + else if(optimizer == kAdadeltaOptimizer) opt = std::make_shared( params, lr, @@ -58,10 +57,10 @@ std::shared_ptr initOptimizer( FLAGS_optimepsilon, weightdecay ); - } else if(optimizer == kAdagradOptimizer) { + else if(optimizer == kAdagradOptimizer) opt = std::make_shared(params, lr, FLAGS_optimepsilon); - } else if(optimizer == kAMSgradOptimizer) { + else if(optimizer == kAMSgradOptimizer) opt = std::make_shared( params, lr, @@ -71,7 +70,7 @@ std::shared_ptr initOptimizer( weightdecay ); - } else if(optimizer == kNovogradOptimizer) { + else if(optimizer == kNovogradOptimizer) opt = std::make_shared( params, lr, @@ -80,9 +79,8 @@ 
std::shared_ptr initOptimizer( FLAGS_optimepsilon, weightdecay ); - } else { + else LOG(FATAL) << "Optimizer option " << optimizer << " not implemented"; - } return opt; } diff --git a/flashlight/pkg/speech/test/audio/MfccTest.cpp b/flashlight/pkg/speech/test/audio/MfccTest.cpp index 310a26b..837353e 100644 --- a/flashlight/pkg/speech/test/audio/MfccTest.cpp +++ b/flashlight/pkg/speech/test/audio/MfccTest.cpp @@ -65,9 +65,8 @@ TEST(MfccTest, htkCompareTest) { // HTK keeps C0 at last position. adjust accordingly. auto featcopy(feat); for(int f = 0; f < numframes; ++f) { - for(int i = 1; i < 39; ++i) { + for(int i = 1; i < 39; ++i) feat[f * 39 + i - 1] = feat[f * 39 + i]; - } feat[f * 39 + 12] = featcopy[f * 39 + 0]; feat[f * 39 + 25] = featcopy[f * 39 + 13]; feat[f * 39 + 38] = featcopy[f * 39 + 26]; @@ -76,9 +75,8 @@ TEST(MfccTest, htkCompareTest) { for(int i = 0; i < feat.size(); ++i) { auto curdiff = std::abs(feat[i] - htkfeat[i]); sum += curdiff; - if(max < curdiff) { + if(max < curdiff) max = curdiff; - } } std::cerr << "| Max diff across all dimensions " << max << "\n"; // 0.325853 @@ -98,9 +96,9 @@ TEST(MfccTest, BatchingTest) { std::vector usePow = {true, false}; int numTrials = 3; - for(auto e : energies) { - for(auto r : rawEnergies) { - for(auto z : zMeans) { + for(auto e : energies) + for(auto r : rawEnergies) + for(auto z : zMeans) for(auto p : usePow) { featparams.useEnergy = e; featparams.rawEnergy = r; @@ -118,15 +116,11 @@ TEST(MfccTest, BatchingTest) { std::copy(input.begin(), input.begin() + curSz, curInput.begin()); auto curOutput = mfcc.apply(curInput); ASSERT_GT(curOutput.size(), 0); - for(int j = 0; j < curOutput.size(); ++j) { + for(int j = 0; j < curOutput.size(); ++j) ASSERT_NEAR(curOutput[j], output[j], 1E-4); - } } } } - } - } - } } TEST(MfccTest, BatchingTest2) { @@ -140,9 +134,9 @@ TEST(MfccTest, BatchingTest2) { std::vector zMeans = {true, false}; std::vector usePow = {true, false}; - for(auto e : energies) { - for(auto r : 
rawEnergies) { - for(auto z : zMeans) { + for(auto e : energies) + for(auto r : rawEnergies) + for(auto z : zMeans) for(auto p : usePow) { featparams.useEnergy = e; featparams.rawEnergy = r; @@ -164,14 +158,10 @@ TEST(MfccTest, BatchingTest2) { ); auto curOutput = mfcc.apply(curInput); ASSERT_EQ(curOutput.size(), perBatchOutSz); - for(int j = 0; j < curOutput.size(); ++j) { + for(int j = 0; j < curOutput.size(); ++j) ASSERT_NEAR(curOutput[j], output[j + i * perBatchOutSz], 1E-4); - } } } - } - } - } } TEST(MfccTest, EmptyTest) { @@ -195,9 +185,8 @@ TEST(MfccTest, ZeroInputTest) { Mfsc mfcc(params); auto input = std::vector(10000, 0.0); auto output = mfcc.apply(input); - for(auto o : output) { + for(auto o : output) ASSERT_NEAR(o, 0.0, 1E-4); - } } int main(int argc, char** argv) { diff --git a/flashlight/pkg/speech/test/audio/TestUtils.h b/flashlight/pkg/speech/test/audio/TestUtils.h index 3826a11..131cc25 100644 --- a/flashlight/pkg/speech/test/audio/TestUtils.h +++ b/flashlight/pkg/speech/test/audio/TestUtils.h @@ -13,14 +13,11 @@ template bool compareVec(std::vector A, std::vector B, float precision = 1E-5) { - if(A.size() != B.size()) { + if(A.size() != B.size()) return false; - } - for(std::size_t i = 0; i < A.size(); ++i) { - if(std::abs(A[i] - B[i]) > precision) { + for(std::size_t i = 0; i < A.size(); ++i) + if(std::abs(A[i] - B[i]) > precision) return false; - } - } return true; } @@ -37,10 +34,8 @@ std::vector randVec(std::size_t N, float min = -1.0, float max = 1.0) { template std::vector transposeVec(const std::vector& in, int inRow, int inCol) { std::vector out(inRow * inCol); - for(size_t r = 0; r < inRow; ++r) { - for(size_t c = 0; c < inCol; ++c) { + for(size_t r = 0; r < inRow; ++r) + for(size_t c = 0; c < inCol; ++c) out[c * inRow + r] = in[r * inCol + c]; - } - } return out; } diff --git a/flashlight/pkg/speech/test/augmentation/AdditiveNoiseTest.cpp b/flashlight/pkg/speech/test/augmentation/AdditiveNoiseTest.cpp index 9053771..b60b4fa 100644 --- 
a/flashlight/pkg/speech/test/augmentation/AdditiveNoiseTest.cpp +++ b/flashlight/pkg/speech/test/augmentation/AdditiveNoiseTest.cpp @@ -82,9 +82,8 @@ TEST(AdditiveNoise, Snr) { sfx.apply(augmented); std::vector extractNoise(augmented.size()); - for(int i = 0; i < extractNoise.size(); ++i) { + for(int i = 0; i < extractNoise.size(); ++i) extractNoise[i] = (augmented[i] - signal[i]); - } ASSERT_LE( signalToNoiseRatio(signal, extractNoise), diff --git a/flashlight/pkg/speech/test/augmentation/GaussianNoiseTest.cpp b/flashlight/pkg/speech/test/augmentation/GaussianNoiseTest.cpp index b83b05d..6dc2842 100644 --- a/flashlight/pkg/speech/test/augmentation/GaussianNoiseTest.cpp +++ b/flashlight/pkg/speech/test/augmentation/GaussianNoiseTest.cpp @@ -23,9 +23,8 @@ TEST(GaussianNoise, SnrCheck) { for(int r = 0; r < numTrys; ++r) { RandomNumberGenerator rng(r); std::vector signal(numSamples); - for(auto& i : signal) { + for(auto& i : signal) i = rng.random(); - } GaussianNoise::Config cfg; cfg.minSnr_ = 8; @@ -35,9 +34,8 @@ TEST(GaussianNoise, SnrCheck) { sfx.apply(signal); ASSERT_EQ(signal.size(), originalSignal.size()); std::vector noise(signal.size()); - for(int i = 0; i < noise.size(); ++i) { + for(int i = 0; i < noise.size(); ++i) noise[i] = signal[i] - originalSignal[i]; - } ASSERT_LE(signalToNoiseRatio(originalSignal, noise), cfg.maxSnr_ + tolerance); ASSERT_GE(signalToNoiseRatio(originalSignal, noise), cfg.minSnr_ - tolerance); } diff --git a/flashlight/pkg/speech/test/augmentation/ReverberationTest.cpp b/flashlight/pkg/speech/test/augmentation/ReverberationTest.cpp index 4107683..a641903 100644 --- a/flashlight/pkg/speech/test/augmentation/ReverberationTest.cpp +++ b/flashlight/pkg/speech/test/augmentation/ReverberationTest.cpp @@ -88,9 +88,8 @@ TEST(ReverbEcho, SinWaveReverb) { // Extract the noise and compare with input that is the source of that noise. 
std::vector noise(firstReverbIdx); - for(int k = firstReverbIdx; k < signal.size(); ++k) { + for(int k = firstReverbIdx; k < signal.size(); ++k) noise[k - firstReverbIdx] = signal[k] - input[k]; - } // Because we use very long rt60 and we use multiple repeasts, the reverb sum // can get to very high values. We normalize by mean of the abs diffs. float noiseSum = 0; diff --git a/flashlight/pkg/speech/test/common/ProducerConsumerQueueTest.cpp b/flashlight/pkg/speech/test/common/ProducerConsumerQueueTest.cpp index ad4ca1f..571834c 100644 --- a/flashlight/pkg/speech/test/common/ProducerConsumerQueueTest.cpp +++ b/flashlight/pkg/speech/test/common/ProducerConsumerQueueTest.cpp @@ -23,9 +23,8 @@ TEST(ProducerConsumerQueueTest, SingleThread) { ProducerConsumerQueue queue(10); // Producing - for(int i = 1; i <= 5; i++) { + for(int i = 1; i <= 5; i++) queue.add(i); - } queue.finishAdding(); // Consuming @@ -49,9 +48,8 @@ TEST(ProducerConsumerQueueTest, MultiThreads) { // Define producer and consumers auto produce = [nProducer, &queue](int tid) { - for(int i = tid; i < nElements; i += nProducer) { + for(int i = tid; i < nElements; i += nProducer) queue.add(i); - } }; auto consume = [&consumerResults, &queue](int tid) { @@ -63,29 +61,24 @@ TEST(ProducerConsumerQueueTest, MultiThreads) { // Run Test std::vector> producerFutures(nConsumer); - for(int i = 0; i < nProducer; i++) { + for(int i = 0; i < nProducer; i++) producerFutures[i] = std::async(std::launch::async, produce, i); - } std::vector> consumerFutures(nConsumer); - for(int i = 0; i < nConsumer; i++) { + for(int i = 0; i < nConsumer; i++) consumerFutures[i] = std::async(std::launch::async, consume, i); - } - for(int i = 0; i < nConsumer; i++) { + for(int i = 0; i < nConsumer; i++) producerFutures[i].wait(); - } queue.finishAdding(); - for(int i = 0; i < nConsumer; i++) { + for(int i = 0; i < nConsumer; i++) consumerFutures[i].wait(); - } // Check int predictSum = 0; - for(const auto& element : consumerResults) { + 
for(const auto& element : consumerResults) predictSum += element; - } ASSERT_EQ(predictSum, targetSum); } diff --git a/flashlight/pkg/speech/test/criterion/BenchmarkSeq2Seq.cpp b/flashlight/pkg/speech/test/criterion/BenchmarkSeq2Seq.cpp index b41b56a..6133031 100644 --- a/flashlight/pkg/speech/test/criterion/BenchmarkSeq2Seq.cpp +++ b/flashlight/pkg/speech/test/criterion/BenchmarkSeq2Seq.cpp @@ -39,9 +39,8 @@ void timeBeamSearch() { std::vector beamsizes = {1, 5, 10, 20}; for(auto b : beamsizes) { auto s = fl::Timer::start(); - for(int i = 0; i < iters; ++i) { + for(int i = 0; i < iters; ++i) seq2seq.beamPath(input, Tensor(), b); - } fl::sync(); auto e = fl::Timer::stop(s); std::cout << "Total time (beam size: " << b << ") " << std::setprecision(5) diff --git a/flashlight/pkg/speech/test/criterion/CompareASG.cpp b/flashlight/pkg/speech/test/criterion/CompareASG.cpp index f0f9c1e..adf49d1 100644 --- a/flashlight/pkg/speech/test/criterion/CompareASG.cpp +++ b/flashlight/pkg/speech/test/criterion/CompareASG.cpp @@ -84,13 +84,12 @@ void printDiscrepancies( // Check for NaN discrepancies manually. 
auto compareNaN = fl::isnan(compare); auto baselineNaN = fl::isnan(baseline); - if(fl::any(compareNaN && !baselineNaN).asScalar()) { + if(fl::any(compareNaN && !baselineNaN).asScalar()) std::cerr << " (warning: compare has NaNs where baseline does not)"; - } else if(fl::any(compareNaN && baselineNaN).asScalar()) { + else if(fl::any(compareNaN && baselineNaN).asScalar()) std::cerr << " (warning: both baseline and compare have NaNs)"; - } else if(fl::any(baselineNaN).asScalar()) { + else if(fl::any(baselineNaN).asScalar()) std::cerr << " (warning: baseline has NaNs where compare does not)"; - } std::cerr << std::endl; } @@ -98,38 +97,33 @@ void printDiscrepancies( int main(int argc, char** argv) { fl::init(); - if(argc < 2) { + if(argc < 2) usage(argv[0]); - } std::string command = argv[1]; if(command == "generate") { - if(argc != 3) { + if(argc != 3) usage(argv[0]); - } std::seed_seq seeds({rd(), rd(), rd(), rd()}); std::mt19937 rng(seeds); // generate random target sizes std::vector targetSize(B); - for(int b = 0; b < B; ++b) { + for(int b = 0; b < B; ++b) // ensure we have a sample with targetSize=1 and targetSize=L targetSize[b] = (b == B - 1) ? L : (b == B - 2) ? 
1 : (1 + rng() % L); - } std::shuffle(targetSize.begin(), targetSize.end(), rng); // generate random targets with the above sizes std::vector targetHost(B * L); for(int b = 0; b < B; ++b) { auto* targetCur = &targetHost[b * L]; - for(int i = 0; i < targetSize[b]; ++i) { + for(int i = 0; i < targetSize[b]; ++i) targetCur[i] = rng() % N; - } - for(int i = targetSize[b]; i < L; ++i) { + for(int i = targetSize[b]; i < L; ++i) targetCur[i] = -1; - } } uint64_t afSeed = rng(); @@ -143,9 +137,8 @@ int main(int argc, char** argv) { fl::save(argv[2], input, target, trans); std::cerr << "input generated" << std::endl; } else if(command == "baseline") { - if(argc != 4) { + if(argc != 4) usage(argv[0]); - } Tensor input, target, trans; fl::load(argv[2], input, target, trans); @@ -153,9 +146,8 @@ int main(int argc, char** argv) { fl::save(argv[3], out.loss, out.inputGrad, out.transGrad); std::cerr << "baseline saved" << std::endl; } else if(command == "compare") { - if(argc != 4) { + if(argc != 4) usage(argv[0]); - } Tensor input, target, trans; fl::load(argv[2], input, target, trans); @@ -166,7 +158,6 @@ int main(int argc, char** argv) { printDiscrepancies("loss: ", out.loss, out0.loss); printDiscrepancies("inputGrad: ", out.inputGrad, out0.inputGrad); printDiscrepancies("transGrad: ", out.transGrad, out0.transGrad); - } else { + } else usage(argv[0]); - } } diff --git a/flashlight/pkg/speech/test/criterion/CriterionTest.cpp b/flashlight/pkg/speech/test/criterion/CriterionTest.cpp index 33ba450..98fdddc 100644 --- a/flashlight/pkg/speech/test/criterion/CriterionTest.cpp +++ b/flashlight/pkg/speech/test/criterion/CriterionTest.cpp @@ -73,9 +73,8 @@ TEST(CriterionTest, CTCEmptyTarget) { // Subtle - related to memory manager initialization. Will be fixed in a // future version of ArrayFire after which time this can be removed. The test // passes/works properly in isolation. 
- if(FL_BACKEND_CPU) { + if(FL_BACKEND_CPU) GTEST_SKIP() << "Skipping test for CPU backend"; - } // Non-empty input, Empty target, batchsize > 0 auto input = Variable(Tensor({3, 2, 5}), true); @@ -138,9 +137,8 @@ TEST(CriterionTest, Batching) { auto t = fl::abs(fl::rand({L, B}, fl::dtype::s32)) % (N - 2); for(int i = 0; i < B; ++i) { int r = rand() % L; - if(r > 0) { + if(r > 0) t(fl::range(r, fl::end), i) = -1; - } } auto tgt = Variable(t.astype(fl::dtype::s32), false); auto l = ConnectionistTemporalClassificationCriterion( @@ -157,9 +155,8 @@ TEST(CriterionTest, Batching) { auto t = fl::abs(fl::rand({L, B}, fl::dtype::s32)) % (N - 2); for(int i = 0; i < B; ++i) { int r = rand() % L; - if(r > 0) { + if(r > 0) t(fl::range(r, fl::end), i) = -1; - } } auto tgt = Variable(t.astype(fl::dtype::s32), false); auto l = ConnectionistTemporalClassificationCriterion( @@ -262,9 +259,8 @@ TEST(CriterionTest, ViterbiPath) { // Test case: 1 auto in = fl::rand({4, 5, 1}); // All values < 1 std::array expectedpath1 = {3, 2, 0, 2, 2}; - for(int j = 0; j < 5; ++j) { + for(int j = 0; j < 5; ++j) in(expectedpath1[j], j) = 2; - } ConnectionistTemporalClassificationCriterion ctc; auto vpath1Arr = ctc.viterbiPath(in); Tensor expPath1Arr = Tensor::fromArray({5, 1}, expectedpath1); @@ -720,9 +716,8 @@ TEST(CriterionTest, ASGBatching) { auto t = fl::abs(fl::rand({L, B}, fl::dtype::s32)) % (N - 2); for(int i = 0; i < B; ++i) { int r = std::rand() % L; - if(r > 0) { + if(r > 0) t(fl::range(r, fl::end), i) = -1; - } } auto tgt = Variable(t.astype(fl::dtype::s32), false); auto l = AutoSegmentationCriterion(N, CriterionScaleMode::TARGET_SZ); @@ -809,9 +804,8 @@ TEST(CriterionTest, ASGCompareLua) { auto loss = asg({inputAf, targetAf}).front(); std::vector lossHost(B); loss.host(lossHost.data()); - for(int i = 0; i < B; i++) { + for(int i = 0; i < B; i++) ASSERT_NEAR(lossHost[i], expectedLoss[i], 1e-3); - } loss.backward(); auto inputGrad = inputAf.grad().tensor(); @@ -896,9 +890,8 @@ 
TEST(CriterionTest, LinSegCompareLua) { auto loss = linseg({inputAf, targetAf}).front(); std::vector lossHost(B); loss.host(lossHost.data()); - for(int i = 0; i < B; i++) { + for(int i = 0; i < B; i++) ASSERT_NEAR(lossHost[i], expectedLoss[i], 1e-3); - } loss.backward(); auto inputGrad = inputAf.grad().tensor(); @@ -910,9 +903,8 @@ TEST(CriterionTest, LinSegCompareLua) { TEST(CriterionTest, AsgSerialization) { char* user = getenv("USER"); std::string userstr = "unknown"; - if(user != nullptr) { + if(user != nullptr) userstr = std::string(user); - } const fs::path path = fs::temp_directory_path() / "test.mdl"; int N = 500; diff --git a/flashlight/pkg/speech/test/criterion/Seq2SeqTest.cpp b/flashlight/pkg/speech/test/criterion/Seq2SeqTest.cpp index 05b9039..a2c3dbc 100644 --- a/flashlight/pkg/speech/test/criterion/Seq2SeqTest.cpp +++ b/flashlight/pkg/speech/test/criterion/Seq2SeqTest.cpp @@ -18,9 +18,8 @@ using namespace fl; using namespace fl::pkg::speech; TEST(Seq2SeqTest, Seq2Seq) { - if(FL_BACKEND_CPU) { + if(FL_BACKEND_CPU) GTEST_SKIP() << "RNN gradient computation not supported for CPU backend"; - } int nclass = 40; int hiddendim = 256; int batchsize = 2; @@ -140,9 +139,8 @@ TEST(Seq2SeqTest, Seq2SeqBeamSearchViterbi) { auto viterbipath = seq2seq.viterbiPath(input); auto beampath = seq2seq.beamPath(input, Tensor(), 1); ASSERT_EQ(beampath.size(), viterbipath.elements()); - for(int idx = 0; idx < beampath.size(); idx++) { + for(int idx = 0; idx < beampath.size(); idx++) ASSERT_EQ(beampath[idx], viterbipath(idx).scalar()); - } } TEST(Seq2SeqTest, Seq2SeqMedianWindow) { @@ -166,9 +164,8 @@ TEST(Seq2SeqTest, Seq2SeqMedianWindow) { auto viterbipath = seq2seq.viterbiPath(input); auto beampath = seq2seq.beamPath(input, Tensor(), 1); ASSERT_EQ(beampath.size(), viterbipath.elements()); - for(int idx = 0; idx < beampath.size(); idx++) { + for(int idx = 0; idx < beampath.size(); idx++) ASSERT_EQ(beampath[idx], viterbipath(idx).scalar()); - } } TEST(Seq2SeqTest, 
Seq2SeqStepWindow) { @@ -192,9 +189,8 @@ TEST(Seq2SeqTest, Seq2SeqStepWindow) { auto viterbipath = seq2seq.viterbiPath(input); auto beampath = seq2seq.beamPath(input, Tensor(), 1); ASSERT_EQ(beampath.size(), viterbipath.elements()); - for(int idx = 0; idx < beampath.size(); idx++) { + for(int idx = 0; idx < beampath.size(); idx++) ASSERT_EQ(beampath[idx], viterbipath(idx).scalar()); - } } TEST(Seq2SeqTest, Seq2SeqStepWindowVectorized) { @@ -294,9 +290,8 @@ TEST(Seq2SeqTest, Seq2SeqMixedAttn) { TEST(Seq2SeqTest, Serialization) { char* user = getenv("USER"); std::string userstr = "unknown"; - if(user != nullptr) { + if(user != nullptr) userstr = std::string(user); - } const fs::path path = fs::temp_directory_path() / "test.mdl"; int N = 5, H = 8, B = 1, T = 10, U = 5, maxoutputlen = 100, nAttnRound = 2; @@ -409,10 +404,9 @@ TEST(Seq2SeqTest, BatchedDecoderStep) { ys.push_back(y); inStates[i].alpha = noGrad(fl::randn({1, T, 1}, fl::dtype::f32)); - for(int j = 0; j < nAttnRound; j++) { + for(int j = 0; j < nAttnRound; j++) inStates[i].hidden[j] = noGrad(fl::randn({H, 1, nRnnLayer}, fl::dtype::f32)); - } inStates[i].summary = noGrad(fl::randn({H, 1, 1}, fl::dtype::f32)); inStatePtrs[i] = &inStates[i]; @@ -437,11 +431,9 @@ TEST(Seq2SeqTest, BatchedDecoderStep) { seq2seq.decodeBatchStep(input, ys, inStatePtrs); // Check - for(int i = 0; i < B; i++) { - for(int j = 0; j < N; j++) { + for(int i = 0; i < B; i++) + for(int j = 0; j < N; j++) ASSERT_NEAR(single_scores[i][j], batched_scores[i][j], 1e-5); - } - } } } diff --git a/flashlight/pkg/speech/test/criterion/attention/AttentionTest.cpp b/flashlight/pkg/speech/test/criterion/attention/AttentionTest.cpp index be1345c..980c8da 100644 --- a/flashlight/pkg/speech/test/criterion/attention/AttentionTest.cpp +++ b/flashlight/pkg/speech/test/criterion/attention/AttentionTest.cpp @@ -49,9 +49,8 @@ bool jacobianTestImpl( for(int i = 0; i < dout.elements(); ++i) { dout.tensor().flat(i) = 1; // element in 1D view input.zeroGrad(); - 
for(auto* var : zeroGradientVariables) { + for(auto* var : zeroGradientVariables) var->zeroGrad(); - } auto out = func(input); out.backward(dout); @@ -157,13 +156,12 @@ TEST(AttentionTest, NeuralContentAttention) { ); ASSERT_EQ(alphas.shape(), Shape({U, T, B})); ASSERT_EQ(summaries.shape(), Shape({H, U, B})); - if(!currentPad.isEmpty()) { + if(!currentPad.isEmpty()) ASSERT_EQ( fl::countNonzero( alphas.tensor()(fl::span, fl::range(T - T / 2, T), 0) == 0) .scalar(), T / 2 * U); - } auto alphasum = sum(alphas.tensor(), {1}); auto ones = fl::full(alphasum.shape(), 1.0, alphasum.type()); ASSERT_TRUE(allClose(alphasum, ones, 1e-5)); @@ -210,8 +208,8 @@ TEST(AttentionTest, MultiHeadContentAttention) { Variable pad = Variable(Tensor::fromVector({1, B}, padRaw), false); std::vector padV = {Variable(), pad}; - for(const auto& currentPad : padV) { - for(bool keyValue : {true, false}) { + for(const auto& currentPad : padV) + for(bool keyValue : {true, false}) for(bool splitInput : {true, false}) { MultiHeadContentAttention attention(H, NH, keyValue, splitInput); @@ -229,20 +227,17 @@ TEST(AttentionTest, MultiHeadContentAttention) { ); ASSERT_EQ(alphas.shape(), Shape({U* NH, T, B})); ASSERT_EQ(summaries.shape(), Shape({H, U, B})); - if(!currentPad.isEmpty()) { + if(!currentPad.isEmpty()) ASSERT_EQ( fl::countNonzero( alphas.tensor()(fl::span, fl::range(T - T / 2, T), 0) == 0) .scalar(), T / 2 * U * NH); - } auto alphasum = sum(alphas.tensor(), {1}); auto ones = fl::full(alphasum.shape(), 1.0, alphasum.type()); ASSERT_TRUE(allClose(alphasum, ones, 1e-5)); } - } - } } TEST(AttentionTest, JacobianMaskAttention) { diff --git a/flashlight/pkg/speech/test/criterion/attention/WindowTest.cpp b/flashlight/pkg/speech/test/criterion/attention/WindowTest.cpp index f39a0a0..95ce674 100644 --- a/flashlight/pkg/speech/test/criterion/attention/WindowTest.cpp +++ b/flashlight/pkg/speech/test/criterion/attention/WindowTest.cpp @@ -308,7 +308,7 @@ TEST(WindowTest, SoftPretrainWindow) { // single 
step std::vector masks; - for(int step = 0; step < targetlen; ++step) { + for(int step = 0; step < targetlen; ++step) masks.emplace_back( window.computeWindow( inputAttn, @@ -318,7 +318,6 @@ TEST(WindowTest, SoftPretrainWindow) { batchsize ) ); - } auto maskS = concatenate(masks, 0); Tensor maxv, maxidx; max(maxv, maxidx, maskS.tensor()(fl::span, fl::span, 0), 1); diff --git a/flashlight/pkg/speech/test/data/FeaturizationTest.cpp b/flashlight/pkg/speech/test/data/FeaturizationTest.cpp index 902b986..3f82c99 100644 --- a/flashlight/pkg/speech/test/data/FeaturizationTest.cpp +++ b/flashlight/pkg/speech/test/data/FeaturizationTest.cpp @@ -27,14 +27,11 @@ using namespace fl::pkg::speech; namespace { template bool compareVec(std::vector A, std::vector B, float precision = 1E-5) { - if(A.size() != B.size()) { + if(A.size() != B.size()) return false; - } - for(std::size_t i = 0; i < A.size(); ++i) { - if(std::abs(A[i] - B[i]) > precision) { + for(std::size_t i = 0; i < A.size(); ++i) + if(std::abs(A[i] - B[i]) > precision) return false; - } - } return true; } @@ -126,7 +123,7 @@ TEST(FeaturizationTest, Transpose) { TEST(FeaturizationTest, localNormalize) { auto afNormalize = [](const Tensor& in, int64_t lw, int64_t rw) { auto out = in; - for(int64_t b = 0; b < in.dim(3); ++b) { + for(int64_t b = 0; b < in.dim(3); ++b) for(int64_t i = 0; i < in.dim(0); ++i) { int64_t b_idx = (i - lw > 0) ? (i - lw) : 0; int64_t e_idx = (in.dim(0) - 1 > i + rw) ? 
(i + rw) : (in.dim(0) - 1); @@ -136,11 +133,9 @@ TEST(FeaturizationTest, localNormalize) { auto stddev = fl::std(slice).scalar(); out(i, fl::span, fl::span, b) -= mean; - if(stddev > 0.0) { + if(stddev > 0.0) out(i, fl::span, fl::span, b) /= stddev; - } } - } return out; }; auto arr = fl::rand({47, 67, 2, 10}); // FRAMES X FEAT X CHANNELS X BATCHSIZE @@ -195,9 +190,8 @@ TEST(FeaturizationTest, TargetTknTestStandaloneSep) { std::vector resT = { "ab", "cd", "ef", "||", "ab", "cd", "||", "t", "r", "||"}; ASSERT_EQ(res.size(), resT.size()); - for(int index = 0; index < res.size(); index++) { + for(int index = 0; index < res.size(); index++) ASSERT_EQ(res[index], resT[index]); - } auto res2 = wrd2Target( words, @@ -213,9 +207,8 @@ TEST(FeaturizationTest, TargetTknTestStandaloneSep) { std::vector resT2 = { "ab", "cd", "ef", "||", "ab", "cd", "||", "||", "t", "r"}; ASSERT_EQ(res2.size(), resT2.size()); - for(int index = 0; index < res2.size(); index++) { + for(int index = 0; index < res2.size(); index++) ASSERT_EQ(res2[index], resT2[index]); - } } TEST(FeaturizationTest, TargetTknTestInsideSep) { @@ -249,9 +242,8 @@ TEST(FeaturizationTest, TargetTknTestInsideSep) { std::vector resT = { "_", "a", "f", "f", "_hel", "lo", "_ma", "ma", "_", "a", "f"}; ASSERT_EQ(res.size(), resT.size()); - for(int index = 0; index < res.size(); index++) { + for(int index = 0; index < res.size(); index++) ASSERT_EQ(res[index], resT[index]); - } auto res2 = wrd2Target( words, @@ -267,9 +259,8 @@ TEST(FeaturizationTest, TargetTknTestInsideSep) { std::vector resT2 = { "a", "f", "f", "_", "_hel", "lo", "_ma", "ma", "_", "a", "f"}; ASSERT_EQ(res.size(), resT2.size()); - for(int index = 0; index < res2.size(); index++) { + for(int index = 0; index < res2.size(); index++) ASSERT_EQ(res2[index], resT2[index]); - } } TEST(FeaturizationTest, WrdToTarget) { @@ -289,15 +280,11 @@ TEST(FeaturizationTest, WrdToTarget) { lexicon[kUnkToken] = {}; Dictionary dict; - for(const auto& l : lexicon) { - for(const 
auto& p : l.second) { - for(const auto& c : p) { - if(!dict.contains(c)) { + for(const auto& l : lexicon) + for(const auto& p : l.second) + for(const auto& c : p) + if(!dict.contains(c)) dict.addEntry(c); - } - } - } - } // NOTE: word separator has no effect when fallback2Ltr is false std::vector words = {"123", "456"}; @@ -349,9 +336,8 @@ TEST(FeaturizationTest, TargetToSingleLtr) { bool usewordpiece = true; Dictionary dict; - for(int i = 0; i < 10; ++i) { + for(int i = 0; i < 10; ++i) dict.addEntry(std::to_string(i), i); - } dict.addEntry("_", 10); dict.addEntry("23_", 230); dict.addEntry("456_", 4560); @@ -388,10 +374,9 @@ TEST(FeaturizationTest, inputFeaturizer) { inputFeatures(featParams, FeatureType::MFSC, {-1, -1}, {}); for(int size = 1; size < 10; ++size) { std::vector input(size * samplerate * channels); - for(int j = 0; j < input.size(); ++j) { + for(int j = 0; j < input.size(); ++j) // channel 1 is same as channel 2 input[j] = std::sin(2 * M_PI * (j / 2) / samplerate); - } int insize = size * samplerate; auto inArray = diff --git a/flashlight/pkg/speech/test/data/ListFileDatasetTest.cpp b/flashlight/pkg/speech/test/data/ListFileDatasetTest.cpp index 58566ed..ef73c54 100644 --- a/flashlight/pkg/speech/test/data/ListFileDatasetTest.cpp +++ b/flashlight/pkg/speech/test/data/ListFileDatasetTest.cpp @@ -29,26 +29,23 @@ auto letterToTarget = [](void* data, Shape dims, fl::dtype /* unused */) { std::string transcript( static_cast(data), static_cast(data) + dims.elements()); std::vector tgt; - for(auto c : transcript) { + for(auto c : transcript) tgt.push_back(static_cast(c)); - } return Tensor::fromVector(tgt); }; } // namespace TEST(ListFileDatasetTest, LoadData) { const fs::path dataPath = loadPath / "data.lst"; - if(!fs::exists(dataPath)) { + if(!fs::exists(dataPath)) throw std::runtime_error( "ListFileDatasetTest, LoadData - can't open test data.lst" ); - } std::vector data; { std::ifstream in(dataPath); - for(std::string s; std::getline(in, s);) { + 
for(std::string s; std::getline(in, s);) data.emplace_back(s); - } } const fs::path rootPath = fs::temp_directory_path() / "data.lst"; diff --git a/flashlight/pkg/speech/test/data/SoundTest.cpp b/flashlight/pkg/speech/test/data/SoundTest.cpp index 6f64719..3434da4 100644 --- a/flashlight/pkg/speech/test/data/SoundTest.cpp +++ b/flashlight/pkg/speech/test/data/SoundTest.cpp @@ -44,17 +44,15 @@ TEST(SoundTest, Mono) { // Double auto vecDouble = loadSound(audiopath); ASSERT_EQ(vecDouble.size(), info.channels * info.frames); - for(int64_t i = 0; i < vecDouble.size(); ++i) { + for(int64_t i = 0; i < vecDouble.size(); ++i) ASSERT_NEAR(vecDouble[i], data[i], 1E-8); - } // Float auto vecFloat = loadSound(audiopath); ASSERT_EQ(vecFloat.size(), info.channels * info.frames); - for(int64_t i = 0; i < vecFloat.size(); ++i) { + for(int64_t i = 0; i < vecFloat.size(); ++i) ASSERT_NEAR(vecFloat[i], data[i], 1E-6); - } // scale by max value for short std::transform( @@ -70,9 +68,8 @@ TEST(SoundTest, Mono) { auto vecShort = loadSound(audiopath); ASSERT_EQ(vecShort.size(), info.channels * info.frames); - for(int64_t i = 0; i < vecShort.size(); ++i) { + for(int64_t i = 0; i < vecShort.size(); ++i) ASSERT_NEAR(vecShort[i], data[i], 0.5); - } // scale by (max value for int64_t / max value of short) std::transform( @@ -86,9 +83,8 @@ TEST(SoundTest, Mono) { // Int auto vecInt = loadSound(audiopath); ASSERT_EQ(vecInt.size(), info.channels * info.frames); - for(int64_t i = 0; i < vecInt.size(); ++i) { + for(int64_t i = 0; i < vecInt.size(); ++i) ASSERT_NEAR(vecInt[i], data[i], 25); - } } TEST(SoundTest, Stereo) { @@ -106,9 +102,8 @@ TEST(SoundTest, Stereo) { auto data = loadData(datapath); ASSERT_EQ(data.size(), info.channels * info.frames); - for(int64_t i = 0; i < vecFloat.size(); ++i) { + for(int64_t i = 0; i < vecFloat.size(); ++i) ASSERT_NEAR(vecFloat[i], data[i], 1E-6); - } } TEST(SoundTest, OggReadWrite) { @@ -131,9 +126,8 @@ TEST(SoundTest, OggReadWrite) { ASSERT_EQ(vecFloat.size(), 
vecFloatOut.size()); - for(int64_t i = 0; i < vecFloat.size(); ++i) { + for(int64_t i = 0; i < vecFloat.size(); ++i) ASSERT_NEAR(vecFloat[i], vecFloatOut[i], 5E-3); - } } TEST(SoundTest, StreamReadWrite) { @@ -163,9 +157,8 @@ TEST(SoundTest, StreamReadWrite) { auto vecShortStream = loadSound(f); ASSERT_EQ(vecShort.size(), vecShortStream.size()); - for(int64_t i = 0; i < vecShort.size(); ++i) { + for(int64_t i = 0; i < vecShort.size(); ++i) ASSERT_EQ(vecShort[i], vecShortStream[i]); - } } int main(int argc, char** argv) { diff --git a/flashlight/pkg/speech/test/decoder/ConvLmModuleTest.cpp b/flashlight/pkg/speech/test/decoder/ConvLmModuleTest.cpp index 5c93709..331dbca 100644 --- a/flashlight/pkg/speech/test/decoder/ConvLmModuleTest.cpp +++ b/flashlight/pkg/speech/test/decoder/ConvLmModuleTest.cpp @@ -69,9 +69,8 @@ TEST(ConvLmModuleTest, GCNN14BCrossEntropy) { TEST(ConvLmModuleTest, SerializationGCNN14BAdaptiveSoftmax) { char* user = getenv("USER"); std::string userstr = "unknown"; - if(user != nullptr) { + if(user != nullptr) userstr = std::string(user); - } const fs::path path = fs::temp_directory_path() / "test.mdl"; const fs::path archfile = archDir / "gcnn_14B_lm_arch_as.txt"; diff --git a/flashlight/pkg/speech/test/runtime/RuntimeTest.cpp b/flashlight/pkg/speech/test/runtime/RuntimeTest.cpp index 44057a8..754ac6c 100644 --- a/flashlight/pkg/speech/test/runtime/RuntimeTest.cpp +++ b/flashlight/pkg/speech/test/runtime/RuntimeTest.cpp @@ -24,9 +24,8 @@ namespace { const fs::path kPath = fs::temp_directory_path() / "test.mdl"; bool afEqual(const fl::Variable& a, const fl::Variable& b) { - if(a.isCalcGrad() != b.isCalcGrad()) { + if(a.isCalcGrad() != b.isCalcGrad()) return false; - } return allClose(a.tensor(), b.tensor(), 1E-7); } @@ -68,13 +67,12 @@ TEST(RuntimeTest, LoadAndSave) { TEST(RuntimeTest, TestCleanFilepath) { auto s = cleanFilepath("timit/train.\\mymodel"); std::string sep(1, fs::path::preferred_separator); - if(sep == "/") { + if(sep == "/") 
ASSERT_EQ(s, "timit#train.\\mymodel"); - } else if(sep == "\\") { + else if(sep == "\\") ASSERT_EQ(s, "timit/train.#mymodel"); - } else { + else GTEST_SKIP() << "System uses a different separator"; - } } TEST(RuntimeTest, SpeechStatMeter) { diff --git a/flashlight/pkg/speech/third_party/warpctc/include/detail/cpu_ctc.h b/flashlight/pkg/speech/third_party/warpctc/include/detail/cpu_ctc.h index e9bb074..2b5b225 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/detail/cpu_ctc.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/detail/cpu_ctc.h @@ -30,11 +30,11 @@ class CpuCTC { workspace_(workspace) { #if defined(CTC_DISABLE_OMP) || defined(APPLE) #else - if(num_threads > 0) { + if(num_threads > 0) omp_set_num_threads(num_threads); - } else { + else num_threads_ = omp_get_max_threads(); - } + #endif }; @@ -212,13 +212,12 @@ void CpuCTC::softmax( const int* const input_lengths ) { #pragma omp parallel for - for(int mb = 0; mb < minibatch_; ++mb) { + for(int mb = 0; mb < minibatch_; ++mb) for(int c = 0; c < input_lengths[mb]; ++c) { int col_offset = (mb + minibatch_ * c) * alphabet_size_; ProbT max_activation = -std::numeric_limits::infinity(); - for(int r = 0; r < alphabet_size_; ++r) { + for(int r = 0; r < alphabet_size_; ++r) max_activation = std::max(max_activation, activations[r + col_offset]); - } ProbT denom = ProbT(0.); for(int r = 0; r < alphabet_size_; ++r) { @@ -226,11 +225,9 @@ void CpuCTC::softmax( denom += probs[r + col_offset]; } - for(int r = 0; r < alphabet_size_; ++r) { + for(int r = 0; r < alphabet_size_; ++r) probs[r + col_offset] /= denom; - } } - } } template @@ -281,9 +278,8 @@ std::tuple CpuCTC::cost_and_grad_kernel( ); ProbT diff = std::abs(llForward - llBackward); - if(diff > ctc_helper::threshold) { + if(diff > ctc_helper::threshold) over_threshold = true; - } return std::make_tuple(-llForward, over_threshold); } @@ -304,18 +300,15 @@ ProbT CpuCTC::compute_alphas( int start = (((S / 2) + repeats - T) < 0) ? 
0 : 1, end = S > 1 ? 2 : 1; - for(int i = start; i < end; ++i) { + for(int i = start; i < end; ++i) alphas[i] = std::log(probs[labels[i]]); - } for(int t = 1; t < T; ++t) { int remain = (S / 2) + repeats - (T - t); - if(remain >= 0) { + if(remain >= 0) start += s_inc[remain]; - } - if(t <= (S / 2) + repeats) { + if(t <= (S / 2) + repeats) end += e_inc[t - 1]; - } int startloop = start; int idx1 = t * S, idx2 = (t - 1) * S, idx3 = t * (alphabet_size_ * minibatch_); @@ -328,18 +321,16 @@ ProbT CpuCTC::compute_alphas( ProbT prev_sum = ctc_helper::log_plus()(alphas[i + idx2], alphas[(i - 1) + idx2]); // Skip two if not on blank and not on repeat. - if(labels[i] != blank_label_ && i != 1 && labels[i] != labels[i - 2]) { + if(labels[i] != blank_label_ && i != 1 && labels[i] != labels[i - 2]) prev_sum = ctc_helper::log_plus()(prev_sum, alphas[(i - 2) + idx2]); - } alphas[i + idx1] = prev_sum + std::log(probs[labels[i] + idx3]); } } ProbT loglike = ctc_helper::neg_inf(); - for(int i = start; i < end; ++i) { + for(int i = start; i < end; ++i) loglike = ctc_helper::log_plus()(loglike, alphas[i + (T - 1) * S]); - } return loglike; } @@ -389,25 +380,22 @@ ProbT CpuCTC::compute_betas_and_grad( if( output[i] == 0.0 || output[i] == ctc_helper::neg_inf() || probs[idx3] == 0.0 - ) { + ) grad[idx3] = probs[idx3]; - } else { + else grad[idx3] = probs[idx3] - std::exp( output[i] - std::log(probs[idx3]) - log_partition ); - } } // loop from the second to last column all the way to the left for(int t = T - 2; t >= 0; --t) { int remain = (S / 2) + repeats - (T - t); - if(remain >= -1) { + if(remain >= -1) start -= s_inc[remain + 1]; - } - if(t < (S / 2) + repeats) { + if(t < (S / 2) + repeats) end -= e_inc[t]; - } int endloop = end == S ? 
end - 1 : end; int idx1 = t * S, idx3 = t * (alphabet_size_ * minibatch_); @@ -417,9 +405,8 @@ ProbT CpuCTC::compute_betas_and_grad( for(int i = start; i < endloop; ++i) { ProbT next_sum = ctc_helper::log_plus()(betas[i], betas[(i + 1)]); // Skip two if not on blank and not on repeat. - if(labels[i] != blank_label_ && i != (S - 2) && labels[i] != labels[i + 2]) { + if(labels[i] != blank_label_ && i != (S - 2) && labels[i] != labels[i + 2]) next_sum = ctc_helper::log_plus()(next_sum, betas[(i + 2)]); - } betas[i] = next_sum + std::log(probs[labels[i] + idx3]); // compute alpha * beta in log space @@ -445,22 +432,20 @@ ProbT CpuCTC::compute_betas_and_grad( if( output[i] == 0.0 || output[i] == ctc_helper::neg_inf() || probs[idx3] == 0.0 - ) { + ) grad[idx3] = probs[idx3]; - } else { + else grad[idx3] = probs[idx3] - std::exp( output[i] - std::log(probs[idx3]) - log_partition ); - } ++idx3; } } ProbT loglike = ctc_helper::neg_inf(); - for(int i = start; i < end; ++i) { + for(int i = start; i < end; ++i) loglike = ctc_helper::log_plus()(loglike, betas[i]); - } return loglike; } @@ -481,9 +466,8 @@ ctcStatus_t CpuCTC::cost_and_grad( || flat_labels == nullptr || label_lengths == nullptr || input_lengths == nullptr - ) { + ) return CTC_STATUS_INVALID_VALUE; - } ProbT* probs = static_cast(workspace_); @@ -547,9 +531,8 @@ ctcStatus_t CpuCTC::score_forward( || flat_labels == nullptr || label_lengths == nullptr || input_lengths == nullptr - ) { + ) return CTC_STATUS_INVALID_VALUE; - } ProbT* probs = static_cast(workspace_); @@ -588,9 +571,9 @@ ctcStatus_t CpuCTC::score_forward( flat_labels + std::accumulate(label_lengths, label_lengths + mb, 0)); - if(L + ctcm.repeats > T) { + if(L + ctcm.repeats > T) costs[mb] = ProbT(0); - } else { + else costs[mb] = -compute_alphas( probs + mb * alphabet_size_, ctcm.repeats, @@ -601,7 +584,6 @@ ctcStatus_t CpuCTC::score_forward( ctcm.labels_w_blanks, ctcm.alphas ); - } } diff --git 
a/flashlight/pkg/speech/third_party/warpctc/include/detail/ctc_helper.h b/flashlight/pkg/speech/third_party/warpctc/include/detail/ctc_helper.h index 8653b2c..5b19371 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/detail/ctc_helper.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/detail/ctc_helper.h @@ -59,12 +59,10 @@ struct log_plus { typedef Res result_type; HOSTDEVICE Res operator()(const Arg1& p1, const Arg2& p2) { - if(p1 == neg_inf()) { + if(p1 == neg_inf()) return p2; - } - if(p2 == neg_inf()) { + if(p2 == neg_inf()) return p1; - } Res result = log1p(exp(-fabs(p1 - p2))) + maximum()(p1, p2); return result; } diff --git a/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc.h b/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc.h index 2c46d9d..5d8b25b 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc.h @@ -176,9 +176,8 @@ ctcStatus_t GpuCTC::setup_gpu_metadata( int repeat_counter = 0; - for(int i = 1; i < L; ++i) { + for(int i = 1; i < L; ++i) repeat_counter += (label_ptr[i] == label_ptr[i - 1]); - } repeats[j % cpu_buffer_size] = repeat_counter; const bool valid_label = ((L + repeat_counter) <= local_T); @@ -198,9 +197,8 @@ ctcStatus_t GpuCTC::setup_gpu_metadata( cudaMemcpyHostToDevice, stream_ ); - if(cuda_status != cudaSuccess) { + if(cuda_status != cudaSuccess) return CTC_STATUS_MEMOPS_FAILED; - } cuda_status = cudaMemcpyAsync( @@ -210,9 +208,8 @@ ctcStatus_t GpuCTC::setup_gpu_metadata( cudaMemcpyHostToDevice, stream_ ); - if(cuda_status != cudaSuccess) { + if(cuda_status != cudaSuccess) return CTC_STATUS_MEMOPS_FAILED; - } } S_ = 2 * S_ + 1; @@ -233,9 +230,8 @@ ctcStatus_t GpuCTC::setup_gpu_metadata( cudaMemcpyHostToDevice, stream_ ); - if(cuda_status != cudaSuccess) { + if(cuda_status != cudaSuccess) return CTC_STATUS_MEMOPS_FAILED; - } label_sizes_ = reinterpret_cast(static_cast(gpu_workspace_) @@ 
-248,9 +244,8 @@ ctcStatus_t GpuCTC::setup_gpu_metadata( cudaMemcpyHostToDevice, stream_ ); - if(cuda_status != cudaSuccess) { + if(cuda_status != cudaSuccess) return CTC_STATUS_MEMOPS_FAILED; - } labels_without_blanks_ = reinterpret_cast(static_cast(gpu_workspace_) @@ -263,9 +258,8 @@ ctcStatus_t GpuCTC::setup_gpu_metadata( cudaMemcpyHostToDevice, stream_ ); - if(cuda_status != cudaSuccess) { + if(cuda_status != cudaSuccess) return CTC_STATUS_MEMOPS_FAILED; - } labels_with_blanks_ = reinterpret_cast(static_cast(gpu_workspace_) @@ -307,13 +301,12 @@ ctcStatus_t GpuCTC::launch_alpha_beta_kernels( // away const int stride = minibatch_; - if(compute_alpha) { + if(compute_alpha) compute_alpha_kernel<< < grid_size, NT, 0, stream_ >> > (probs, label_sizes_, utt_length_, repeats_, labels_without_blanks_, label_offsets_, labels_with_blanks_, alphas_, nll_forward_, stride, out_dim_, S_, T_, blank_label_); - } if(compute_beta) { @@ -326,9 +319,8 @@ ctcStatus_t GpuCTC::launch_alpha_beta_kernels( } cudaError_t err = cudaGetLastError(); - if(err != cudaSuccess) { + if(err != cudaSuccess) return CTC_STATUS_EXECUTION_FAILED; - } return CTC_STATUS_SUCCESS; } @@ -343,9 +335,8 @@ ctcStatus_t GpuCTC::create_metadata_and_choose_config( // Setup the metadata for GPU ctcStatus_t status = setup_gpu_metadata(flat_labels, label_lengths, input_lengths); - if(status != CTC_STATUS_SUCCESS) { + if(status != CTC_STATUS_SUCCESS) return status; - } constexpr int num_configs = 12; @@ -357,16 +348,14 @@ ctcStatus_t GpuCTC::create_metadata_and_choose_config( best_config = 0; for(int i = 0; i < num_configs; ++i) { - if((config_NT[i] * config_VT[i]) >= S_) { + if((config_NT[i] * config_VT[i]) >= S_) break; - } else { + else best_config++; - } } - if(best_config >= num_configs) { + if(best_config >= num_configs) return CTC_STATUS_LABEL_LENGTH_TOO_LARGE; - } return CTC_STATUS_SUCCESS; } @@ -410,9 +399,8 @@ ctcStatus_t GpuCTC::compute_log_probs(const ProbT* const activations) { cudaMemcpyDeviceToDevice, 
stream_ ); - if(cuda_status != cudaSuccess) { + if(cuda_status != cudaSuccess) return CTC_STATUS_MEMOPS_FAILED; - } // Numerically stable SM ctcStatus_t ctc_status = @@ -424,9 +412,8 @@ ctcStatus_t GpuCTC::compute_log_probs(const ProbT* const activations) { 1, stream_ ); - if(ctc_status != CTC_STATUS_SUCCESS) { + if(ctc_status != CTC_STATUS_SUCCESS) return ctc_status; - } // Kernel launch to subtract maximum const int NT = 128; @@ -449,9 +436,8 @@ ctcStatus_t GpuCTC::compute_log_probs(const ProbT* const activations) { 1, stream_ ); - if(ctc_status != CTC_STATUS_SUCCESS) { + if(ctc_status != CTC_STATUS_SUCCESS) return ctc_status; - } // Kernel launch to calculate probabilities compute_log_probs_kernel<< < grid_size, NT, 0, stream_ >> @@ -480,14 +466,12 @@ ctcStatus_t GpuCTC::compute_cost_and_score( input_lengths, best_config ); - if(status != CTC_STATUS_SUCCESS) { + if(status != CTC_STATUS_SUCCESS) return status; - } status = compute_log_probs(activations); - if(status != CTC_STATUS_SUCCESS) { + if(status != CTC_STATUS_SUCCESS) return status; - } status = launch_gpu_kernels( probs_, @@ -497,9 +481,8 @@ ctcStatus_t GpuCTC::compute_cost_and_score( compute_betas_and_grad ); - if(status != CTC_STATUS_SUCCESS) { + if(status != CTC_STATUS_SUCCESS) return status; - } cudaError_t cuda_status_mem, cuda_status_sync; cuda_status_mem = cudaMemcpyAsync( @@ -510,9 +493,8 @@ ctcStatus_t GpuCTC::compute_cost_and_score( stream_ ); cuda_status_sync = cudaStreamSynchronize(stream_); - if(cuda_status_mem != cudaSuccess || cuda_status_sync != cudaSuccess) { + if(cuda_status_mem != cudaSuccess || cuda_status_sync != cudaSuccess) return CTC_STATUS_MEMOPS_FAILED; - } return CTC_STATUS_SUCCESS; } @@ -532,9 +514,8 @@ ctcStatus_t GpuCTC::cost_and_grad( || costs == nullptr || label_lengths == nullptr || input_lengths == nullptr - ) { + ) return CTC_STATUS_INVALID_VALUE; - } return compute_cost_and_score( activations, @@ -561,9 +542,8 @@ ctcStatus_t GpuCTC::score_forward( || costs == nullptr || 
label_lengths == nullptr || input_lengths == nullptr - ) { + ) return CTC_STATUS_INVALID_VALUE; - } return compute_cost_and_score( activations, diff --git a/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc_kernels.h b/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc_kernels.h index 3318e37..689384d 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc_kernels.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc_kernels.h @@ -35,9 +35,8 @@ struct CTASegReduce { for(int i = 0; i < VT; ++i) { int index = VT * tid + 1 + i; T next = keys[index]; - if(index == count || (index < count && key != next)) { + if(index == count || (index < count && key != next)) endFlags |= 1 << i; - } key = next; } @@ -52,14 +51,13 @@ struct CTASegReduce { // use indices as scratch space int outputPos = scan; #pragma unroll - for(int i = 0; i < VT; ++i) { + for(int i = 0; i < VT; ++i) if((endFlags >> i) & 1) { shared.indices[outputPos] = keys[VT * tid + i]; scanout[outputPos] = VT * tid + i; outputPos++; } - } __syncthreads(); @@ -73,9 +71,8 @@ struct CTASegReduce { // copy from the scratch space back into the keys #pragma unroll - for(int i = 0; i < VT; ++i) { + for(int i = 0; i < VT; ++i) keys[i * NT + tid] = shared.indices[i * NT + tid]; - } __syncthreads(); } @@ -119,9 +116,8 @@ void compute_alpha_kernel( const int NV = NT * VT; __shared__ int label[NV]; - if((L + repeats) > T) { + if((L + repeats) > T) return; - } // Generate labels with blanks from labels without blanks { @@ -131,9 +127,8 @@ void compute_alpha_kernel( labels_with_blanks[offset] = blank_label; labels_with_blanks[offset + 1] = labels_without_blanks[label_start_offset + idx]; } - if(tid == 0) { + if(tid == 0) labels_with_blanks[(blockIdx.x * S_memoffset) + 2 * L] = blank_label; - } } __syncthreads(); @@ -144,15 +139,13 @@ void compute_alpha_kernel( // Set the first row of alpha neg_inf - it is much more efficient to do it // here than outside 
#pragma unroll - for(int idx = tid; idx < min(S, NV); idx += blockDim.x) { + for(int idx = tid; idx < min(S, NV); idx += blockDim.x) alpha[idx] = ctc_helper::neg_inf(); - } // Load labels into shared memory #pragma unroll - for(int i = tid; i < S; i += NT) { + for(int i = tid; i < S; i += NT) label[i] = label_global[i]; - } __syncthreads(); @@ -160,9 +153,8 @@ void compute_alpha_kernel( int end = S > 1 ? 2 : 1; // Initialize the first row corresponding to t=0; - for(int i = tid; i < (end - start); i += blockDim.x) { + for(int i = tid; i < (end - start); i += blockDim.x) alpha[i + start] = probs[prob_offset + label[i + start]]; - } __syncthreads(); @@ -179,12 +171,11 @@ void compute_alpha_kernel( // This is the first column and in this case there is nothing left of it if(tid == 0) { - if(start == 0) { + if(start == 0) alpha[start_cur_row] = alpha[start_prev_row] + probs[prob_offset + start_prob_col + blank_label]; - } else if(start == 1) { + else if(start == 1) alpha[start_cur_row] = alpha[start_prev_row]; - } } __syncthreads(); @@ -202,9 +193,8 @@ void compute_alpha_kernel( if( (label[idx] != blank_label) && (idx != 1) && (label[idx] != label[idx - 2]) - ) { + ) prev_sum = log_plus_f(prev_sum, alpha[(idx - 2) + start_prev_row]); - } alpha[idx + start_cur_row] = prev_sum + probs[prob_offset + start_prob_col + label[idx]]; @@ -223,9 +213,8 @@ void compute_alpha_kernel( start = (val * (L != 0) + start); end = (val * (L != 0) + end); - for(int i = start; i < end; ++i) { + for(int i = start; i < end; ++i) loglike = log_plus_f(loglike, alpha[i + (T - 1) * S]); - } nll_forward[blockIdx.x] = -loglike; } @@ -287,18 +276,16 @@ void compute_betas_and_grad_kernel( ProbT beta_val[VT]; - if((L + repeats) > T) { + if((L + repeats) > T) return; - } int start = S > 1 ? (S - 2) : 0; int end = (L + repeats < T) ? S : S - 1; // Setup shared memory buffers #pragma unroll - for(int idx = tid; idx < NV; idx += NT) { + for(int idx = tid; idx < NV; idx += NT) label[idx] = (idx < S) ? 
label_global[idx] : INT_MAX; - } __syncthreads(); @@ -356,24 +343,21 @@ void compute_betas_and_grad_kernel( // Load labels back #pragma unroll - for(int idx = tid; idx < NV; idx += NT) { + for(int idx = tid; idx < NV; idx += NT) temp_buffer.beta[idx] = ctc_helper::neg_inf(); - } __syncthreads(); // Initialize the two rightmost values in the last row (assuming L non-zero) - for(int i = tid; i < (end - start); i += blockDim.x) { + for(int i = tid; i < (end - start); i += blockDim.x) temp_buffer.beta[i + start] = probs[prob_offset + (T - 1) * (out_dim * stride) + label[i + start]]; - } __syncthreads(); // Load output data in registers through the transpose trick - should really be a function #pragma unroll - for(int idx = tid; idx < S; idx += NT) { + for(int idx = tid; idx < S; idx += NT) output[idx] = alpha[idx + (T - 1) * S] + temp_buffer.beta[idx]; - } __syncthreads(); @@ -399,9 +383,8 @@ void compute_betas_and_grad_kernel( if( (label[idx] != blank_label) && (idx != (S - 2)) && (label[idx] != label[idx + 2]) - ) { + ) next_sum = log_plus_f(next_sum, temp_buffer.beta[idx + 2]); - } beta_val[i] = next_sum + probs[prob_offset + start_prob_col + label[idx]]; } @@ -410,24 +393,21 @@ void compute_betas_and_grad_kernel( // Initialize values for the rightmost column since there is nothing to the right // Update input buffer for next iteration - if((tid == 0) && (end == S)) { + if((tid == 0) && (end == S)) temp_buffer.beta[(S - 1)] = temp_buffer.beta[(S - 1)] + probs[prob_offset + start_prob_col + blank_label]; - } #pragma unroll - for(int idx = tid, i = 0; idx < (S - 1); idx += NT, i++) { + for(int idx = tid, i = 0; idx < (S - 1); idx += NT, i++) temp_buffer.beta[idx] = beta_val[i]; - } __syncthreads(); // Beta Computation done - add to alpha and update the gradient. 
Reload // the gradient back for segmented reduce later on #pragma unroll - for(int idx = tid; idx < S; idx += NT) { + for(int idx = tid; idx < S; idx += NT) output[idx] = alpha[idx + start_cur_row] + temp_buffer.beta[idx]; - } __syncthreads(); @@ -443,16 +423,14 @@ void compute_betas_and_grad_kernel( for(int idx = tid, j = 0; idx < uniquelabels; idx += blockDim.x, ++j) { accum[j] = ctc_helper::neg_inf(); - for(int i = seg_start[j]; i <= seg_end[j]; ++i) { + for(int i = seg_start[j]; i <= seg_end[j]; ++i) accum[j] = log_plus_f(accum[j], output[gather_indices[i]]); - } } __syncthreads(); // Write accumulated value into output since that is not used - for(int idx = tid, j = 0; idx < uniquelabels; idx += blockDim.x, ++j) { + for(int idx = tid, j = 0; idx < uniquelabels; idx += blockDim.x, ++j) output[idx] = accum[j]; - } __syncthreads(); for(int idx = tid; idx < out_dim; idx += blockDim.x) { @@ -470,10 +448,9 @@ void compute_betas_and_grad_kernel( if( (grad == 0.0) || (exp(probs[grads_offset]) == 0.0) || (grad == ctc_helper::neg_inf()) - ) {} else { + ) {} else grads[grads_offset] = exp(probs[grads_offset]) - exp(grad - probs[grads_offset] - log_partition); - } } __syncthreads(); @@ -489,9 +466,8 @@ void compute_betas_and_grad_kernel( end = (-val * (L != 0) + end); // Sum and return the leftmost one/two value(s) in first row - for(int i = start; i < end; ++i) { + for(int i = start; i < end; ++i) loglike = log_plus_f(loglike, temp_buffer.beta[i]); - } nll_backward[blockIdx.x] = -loglike; } diff --git a/flashlight/pkg/speech/third_party/warpctc/src/ctc_entrypoint.cu b/flashlight/pkg/speech/third_party/warpctc/src/ctc_entrypoint.cu index 04678e1..3962916 100644 --- a/flashlight/pkg/speech/third_party/warpctc/src/ctc_entrypoint.cu +++ b/flashlight/pkg/speech/third_party/warpctc/src/ctc_entrypoint.cu @@ -57,9 +57,8 @@ ctcStatus_t compute_ctc_loss( || workspace == nullptr || alphabet_size <= 0 || minibatch <= 0 - ) { + ) return CTC_STATUS_INVALID_VALUE; - } if(options.loc == 
CTC_CPU) { CpuCTC < float > ctc( @@ -70,7 +69,7 @@ ctcStatus_t compute_ctc_loss( options.blank_label ); - if(gradients != NULL) { + if(gradients != NULL) return ctc.cost_and_grad( activations, gradients, @@ -79,7 +78,7 @@ ctcStatus_t compute_ctc_loss( label_lengths, input_lengths ); - } else { + else return ctc.score_forward( activations, costs, @@ -87,7 +86,6 @@ ctcStatus_t compute_ctc_loss( label_lengths, input_lengths ); - } } else if(options.loc == CTC_GPU) { #ifdef __CUDACC__ GpuCTC < float > ctc( @@ -98,7 +96,7 @@ ctcStatus_t compute_ctc_loss( options.blank_label ); - if(gradients != NULL) { + if(gradients != NULL) return ctc.cost_and_grad( activations, gradients, @@ -107,7 +105,7 @@ ctcStatus_t compute_ctc_loss( label_lengths, input_lengths ); - } else { + else return ctc.score_forward( activations, costs, @@ -115,14 +113,13 @@ ctcStatus_t compute_ctc_loss( label_lengths, input_lengths ); - } + #else std::cerr << "GPU execution requested, but not compiled with GPU support" << std::endl; return CTC_STATUS_EXECUTION_FAILED; #endif - } else { + } else return CTC_STATUS_INVALID_VALUE; - } } @@ -141,9 +138,8 @@ ctcStatus_t get_workspace_size( || size_bytes == nullptr || alphabet_size <= 0 || minibatch <= 0 - ) { + ) return CTC_STATUS_INVALID_VALUE; - } // This is the max of all S and T for all examples in the minibatch. 
int maxL = *std::max_element(label_lengths, label_lengths + minibatch); diff --git a/flashlight/pkg/speech/third_party/warpctc/src/reduce.cu b/flashlight/pkg/speech/third_party/warpctc/src/reduce.cu index 051afe7..9b3deec 100644 --- a/flashlight/pkg/speech/third_party/warpctc/src/reduce.cu +++ b/flashlight/pkg/speech/third_party/warpctc/src/reduce.cu @@ -44,9 +44,8 @@ template < int NT, typename T, typename Rop T shuff; for(int offset = warp_size / 2; offset > 0; offset /= 2) { shuff = __shfl_down_sync(0xffffffff, x, offset); - if(tid + offset < count && tid < offset) { + if(tid + offset < count && tid < offset) x = g(x, shuff); - } } return x; } @@ -71,9 +70,8 @@ template < int NT, typename Iop, typename Rop, typename T T curr; // Each block works on a column - if(idx < num_rows) { + if(idx < num_rows) curr = f(input[idx + col * num_rows]); - } idx += NT; @@ -86,9 +84,8 @@ template < int NT, typename Iop, typename Rop, typename T curr = R::reduce(tid, curr, storage, num_rows, g); // Store result in out - if(tid == 0) { + if(tid == 0) output[col] = curr; - } } template < int NT, typename Iop, typename Rop, typename T @@ -122,9 +119,8 @@ template < int NT, typename Iop, typename Rop, typename T // Reduce if(threadIdx.y == 0 && row < num_rows) { #pragma unroll - for(int i = 1; i < warps_per_block && i < num_cols; ++i) { + for(int i = 1; i < warps_per_block && i < num_cols; ++i) curr = g(curr, s[i + threadIdx.x * warps_per_block]); - } output[row] = curr; } } @@ -175,9 +171,8 @@ template < typename T, typename Iof, typename Rof ReduceHelper::impl(f, g, input, output, rows, cols, axis, stream); cudaStreamSynchronize(stream); cudaError_t err = cudaGetLastError(); - if(err != cudaSuccess) { + if(err != cudaSuccess) return CTC_STATUS_EXECUTION_FAILED; - } return CTC_STATUS_SUCCESS; } diff --git a/flashlight/pkg/text/data/TextDataset.cpp b/flashlight/pkg/text/data/TextDataset.cpp index 6ae0bd7..1407c22 100644 --- a/flashlight/pkg/text/data/TextDataset.cpp +++ 
b/flashlight/pkg/text/data/TextDataset.cpp @@ -50,9 +50,8 @@ TextDataset::TextDataset( while(reader.hasNextLine()) { const auto currentEosPosition = data_.size() - 1; - if(!sentenceRanges.empty()) { + if(!sentenceRanges.empty()) sentenceRanges.back().second = currentEosPosition; - } const auto tokens = tokenizer.tokenize(reader.getLine()); const auto indices = dictionary.mapEntriesToIndices(tokens); @@ -65,18 +64,16 @@ TextDataset::TextDataset( data_.insert(data_.end(), indices.begin(), indices.end()); data_.push_back(eos); } - if(!sentenceRanges.empty()) { + if(!sentenceRanges.empty()) sentenceRanges.back().second = data_.size() - 1; - } } const int64_t nTokens = data_.size(); /* 2. Batchify */ - if(batchSize <= 0) { + if(batchSize <= 0) throw std::invalid_argument( "[TextDataset] BatchSize needs to be positive." ); - } if(sampleBreakMode == "none") { // Sentences are split into equal size (=`tokensPerSample`) @@ -100,7 +97,7 @@ TextDataset::TextDataset( // Sentences with length > `tokensPerSample` are skipped; // Total tokens per batch <= `batchSize` * `tokensPerSample` - if(useDynamicBatching) { + if(useDynamicBatching) // sorting samples by length in ascending order std::sort( sentenceRanges.begin(), @@ -110,7 +107,6 @@ TextDataset::TextDataset( return p1.second - p1.first < p2.second - p2.first; } ); - } std::vector batch; for(int64_t i = 0; i < sentenceRanges.size(); ++i) { @@ -120,25 +116,22 @@ TextDataset::TextDataset( batch.emplace_back(SamplePosition{startPoint, endPoint}); bool isFull; - if(useDynamicBatching) { + if(useDynamicBatching) isFull = sampleSize * (batch.size() + 1) > batchSize * tokensPerSample; - } else { + else isFull = batch.size() == batchSize; - } if(isFull) { batches_.push_back(std::move(batch)); batch = std::vector(); } } - if(!batch.empty()) { + if(!batch.empty()) batches_.push_back(std::move(batch)); - } - } else { + } else throw std::invalid_argument( "Invalid sampleBreakMode: should be none or eos, but it is given " + 
sampleBreakMode ); - } FL_LOG(LogLevel::INFO) << "[TextDataset] (" << reader.getRank() << "/" << reader.getTotalReaders() << ") Loaded " << nTokens @@ -153,9 +146,8 @@ int64_t TextDataset::size() const { std::vector TextDataset::get(const int64_t idx) const { const auto& batch = batches_[idx % size()]; int64_t maxLength = 0; - for(const auto& pos : batch) { + for(const auto& pos : batch) maxLength = std::max(maxLength, pos.last - pos.first + 1); - } std::vector buffer(batch.size() * maxLength, pad_); for(int64_t i = 0; i < batch.size(); ++i) { const auto& pos = batch[i]; @@ -174,9 +166,8 @@ std::vector TextDataset::get(const int64_t idx) const { void TextDataset::shuffle(uint64_t seed) { std::mt19937_64 rng(seed); // Deterministic method across compilers. - for(uint64_t i = size() - 1; i >= 1; --i) { + for(uint64_t i = size() - 1; i >= 1; --i) std::swap(batches_[i], batches_[rng() % (i + 1)]); - } } } // namespace fl diff --git a/flashlight/pkg/text/test/data/TextDatasetTest.cpp b/flashlight/pkg/text/test/data/TextDatasetTest.cpp index ec37c03..1e725d8 100644 --- a/flashlight/pkg/text/test/data/TextDatasetTest.cpp +++ b/flashlight/pkg/text/test/data/TextDatasetTest.cpp @@ -29,21 +29,18 @@ fs::path dataDir = ""; Dictionary createDictionary(const std::string& path) { Dictionary dictionary; std::ifstream stream(path); - if(!stream) { + if(!stream) throw std::runtime_error("createDictionary - invalid path"); - } std::string line; while(std::getline(stream, line)) { - if(line.empty()) { + if(line.empty()) continue; - } auto tkns = splitOnWhitespace(line, true); dictionary.addEntry(tkns.front()); } - if(!dictionary.isContiguous()) { + if(!dictionary.isContiguous()) throw std::runtime_error("Invalid dictionary_ format - not contiguous"); - } dictionary.setDefaultIndex(dictionary.getIndex(fl::lib::text::kUnkToken)); return dictionary; } diff --git a/flashlight/pkg/vision/common/BetaDistribution.h b/flashlight/pkg/vision/common/BetaDistribution.h index c0775a6..416b056 
100644 --- a/flashlight/pkg/vision/common/BetaDistribution.h +++ b/flashlight/pkg/vision/common/BetaDistribution.h @@ -140,11 +140,10 @@ namespace lib { if( std::getline(is, str, '(') && str == "~Beta" && is >> a && is.get() == ',' && is >> b && is.get() == ')' - ) { + ) beta = beta_distribution(a, b); - } else { + else is.setstate(std::ios::failbit); - } return is; } diff --git a/flashlight/pkg/vision/criterion/Hungarian.cpp b/flashlight/pkg/vision/criterion/Hungarian.cpp index 83faee3..9ff81d4 100644 --- a/flashlight/pkg/vision/criterion/Hungarian.cpp +++ b/flashlight/pkg/vision/criterion/Hungarian.cpp @@ -63,9 +63,8 @@ std::pair HungarianMatcher::matchBatch( const Tensor& targetClasses ) const { // Kind of a hack... - if(targetClasses.isEmpty()) { + if(targetClasses.isEmpty()) return {fl::fromScalar(0), fl::fromScalar(0)}; - } // Create an M X N cost matrix where M is the number of targets and N is the // number of preds diff --git a/flashlight/pkg/vision/criterion/HungarianImpl.cpp b/flashlight/pkg/vision/criterion/HungarianImpl.cpp index f2365d6..ba484e7 100644 --- a/flashlight/pkg/vision/criterion/HungarianImpl.cpp +++ b/flashlight/pkg/vision/criterion/HungarianImpl.cpp @@ -27,7 +27,7 @@ void findUncoveredZero( bool done = false; *row = -1; *col = -1; - for(int c = 0; c < ncols && !done; c++) { + for(int c = 0; c < ncols && !done; c++) for(int r = 0; r < nrows && !done; r++) { const float cost = costs[c * nrows + r]; if(cost == 0 && colCover[c] == 0 && rowCover[r] == 0) { @@ -36,14 +36,12 @@ void findUncoveredZero( done = true; } } - } } bool isStarInRow(int* marks, int row, int nrows, int ncols) { for(int c = 0; c < ncols; c++) { - if(marks[c * nrows + row] == Mark::Star) { + if(marks[c * nrows + row] == Mark::Star) return true; - } ; } return false; @@ -51,9 +49,8 @@ bool isStarInRow(int* marks, int row, int nrows, int ncols) { int findStarInRow(int* marks, int row, int nrows, int ncols) { for(int c = 0; c < ncols; c++) { - if(marks[c * nrows + row] == 
Mark::Star) { + if(marks[c * nrows + row] == Mark::Star) return c; - } ; } return -1; @@ -66,13 +63,11 @@ int stepOne(float* costs, const int nrows, const int ncols) { float min_val = std::numeric_limits::max(); for(int c = 0; c < ncols; c++) { float val = costs[c * nrows + r]; - if(val < min_val) { + if(val < min_val) min_val = val; - } } - for(int c = 0; c < ncols; c++) { + for(int c = 0; c < ncols; c++) costs[c * nrows + r] -= min_val; - } } return 2; } @@ -87,7 +82,7 @@ int stepTwo( const int nrows, const int ncols ) { - for(int r = 0; r < nrows; r++) { + for(int r = 0; r < nrows; r++) for(int c = 0; c < ncols; c++) { float cost = costs[c * nrows + r]; if(cost == 0.0 && rowCover[r] == 0 && colCover[c] == 0) { @@ -96,13 +91,10 @@ int stepTwo( colCover[c] = 1; } } - } - for(int r = 0; r < nrows; r++) { + for(int r = 0; r < nrows; r++) rowCover[r] = 0; - } - for(int c = 0; c < ncols; c++) { + for(int c = 0; c < ncols; c++) colCover[c] = 0; - } return 3; } @@ -114,23 +106,19 @@ int stepThree( int nrows, int ncols ) { - for(int r = 0; r < nrows; r++) { + for(int r = 0; r < nrows; r++) for(int c = 0; c < ncols; c++) { const int mark = marks[c * nrows + r]; - if(mark == 1) { + if(mark == 1) colCover[c] = 1; - } } - } int coveredCols = 0; - for(int c = 0; c < ncols; c++) { + for(int c = 0; c < ncols; c++) coveredCols += colCover[c]; - } - if(coveredCols == ncols || coveredCols >= nrows) { + if(coveredCols == ncols || coveredCols >= nrows) return 7; - } else { + else return 4; - } } // Find a noncovered zero and "prime it". 
If there are no uncovered zeros in the @@ -151,9 +139,9 @@ int stepFour( while(!done) { int row, col; findUncoveredZero(costs, colCover, rowCover, nrows, ncols, &row, &col); - if(row < 0 && col < 0) { + if(row < 0 && col < 0) return 6; - } else { + else { // "Prime it" marks[col * nrows + row] = Mark::Prime; if(isStarInRow(marks, row, nrows, ncols)) { @@ -172,20 +160,16 @@ int stepFour( } int findStarInCol(int* masks, int col, int nrows, int /*ncols*/) { - for(int r = 0; r < nrows; r++) { - if(masks[col * nrows + r] == 1) { + for(int r = 0; r < nrows; r++) + if(masks[col * nrows + r] == 1) return r; - } - } return -1; } int findPrimeInRow(int* masks, int row, int nrows, int ncols) { - for(int c = 0; c < ncols; c++) { - if(masks[c * nrows + row] == 2) { + for(int c = 0; c < ncols; c++) + if(masks[c * nrows + row] == 2) return c; - } - } return -1; } @@ -199,28 +183,23 @@ void augmentPaths( for(int p = 0; p < pathCount; p++) { int row = paths[p * 2]; int col = paths[p * 2 + 1]; - if(marks[col * nrows + row] == Mark::Star) { + if(marks[col * nrows + row] == Mark::Star) marks[col * nrows + row] = Mark::None; - } else { + else marks[col * nrows + row] = Mark::Star; - } } } void clearCover(int* cover, int n) { - for(int i = 0; i < n; i++) { + for(int i = 0; i < n; i++) cover[i] = 0; - } } void erasePrimes(int* marks, int nrows, int ncols) { - for(int c = 0; c < ncols; c++) { - for(int r = 0; r < nrows; r++) { - if(marks[c * nrows + r] == Mark::Prime) { + for(int c = 0; c < ncols; c++) + for(int r = 0; r < nrows; r++) + if(marks[c * nrows + r] == Mark::Prime) marks[c * nrows + r] = Mark::None; - } - } - } } int stepFive( @@ -246,9 +225,8 @@ int stepFive( pathCount += 1; path[(pathCount - 1) * 2] = r; path[(pathCount - 1) * 2 + 1] = path[(pathCount - 2) * 2 + 1]; - } else { + } else done = true; - } if(!done) { c = findPrimeInRow(marks, path[(pathCount - 1) * 2], nrows, ncols); pathCount += 1; @@ -271,16 +249,13 @@ float findSmallestNotCovered( int ncols ) { float minValue 
= std::numeric_limits::max(); - for(int c = 0; c < ncols; c++) { - for(int r = 0; r < nrows; r++) { + for(int c = 0; c < ncols; c++) + for(int r = 0; r < nrows; r++) if(colCover[c] == 0 && rowCover[r] == 0) { const float cost = costs[c * nrows + r]; - if(cost < minValue) { + if(cost < minValue) minValue = cost; - } } - } - } return minValue; } @@ -294,22 +269,19 @@ int stepSix( ) { float minVal = findSmallestNotCovered(costs, colCover, rowCover, nrows, ncols); - for(int c = 0; c < ncols; c++) { + for(int c = 0; c < ncols; c++) for(int r = 0; r < nrows; r++) { - if(rowCover[r] == 1) { + if(rowCover[r] == 1) costs[c * nrows + r] += minVal; - } - if(colCover[c] == 0) { + if(colCover[c] == 0) costs[c * nrows + r] -= minVal; - } } - } return 4; } void stepSeven(int* marks, int* rowIdxs, int* colIdxs, int M, int N) { int i = 0; - for(int r = 0; r < M; r++) { + for(int r = 0; r < M; r++) for(int c = 0; c < N; c++) { const int mark = marks[c * M + r]; if(mark == Mark::Star) { @@ -318,7 +290,6 @@ void stepSeven(int* marks, int* rowIdxs, int* colIdxs, int M, int N) { i += 1; } } - } }; } // namespace diff --git a/flashlight/pkg/vision/criterion/SetCriterion.cpp b/flashlight/pkg/vision/criterion/SetCriterion.cpp index 7ff6468..5ceef0d 100644 --- a/flashlight/pkg/vision/criterion/SetCriterion.cpp +++ b/flashlight/pkg/vision/criterion/SetCriterion.cpp @@ -22,11 +22,10 @@ using namespace fl; Tensor span(const Shape& inDims, const int index) { Shape dims(std::vector(std::max(inDims.ndim(), index + 1), 1)); - if(index > inDims.ndim() - 1) { + if(index > inDims.ndim() - 1) dims[index] = 1; - } else { + else dims[index] = inDims[index]; - } return fl::iota(dims); } @@ -36,20 +35,17 @@ Shape calcStrides(const Shape& dims) { Shape calcOutDims(const std::vector& coords) { unsigned maxNdim = 0; - for(const auto& coord : coords) { - if(coord.ndim() > maxNdim) { + for(const auto& coord : coords) + if(coord.ndim() > maxNdim) maxNdim = coord.ndim(); - } - } Shape oDims(std::vector(maxNdim, 
1)); for(const auto& coord : coords) { auto iDims = coord.shape(); for(int i = 0; i < coord.ndim(); i++) { - if(iDims[i] > 1 && oDims[i] == 1) { + if(iDims[i] > 1 && oDims[i] == 1) oDims[i] = iDims[i]; - } assert(iDims[i] == 1 || iDims[i] == oDims[i]); } } @@ -69,9 +65,8 @@ Tensor applyStrides(const std::vector& coords, const Shape& strides) { std::vector spanIfEmpty(const std::vector& coords, Shape dims) { std::vector result(coords.size()); - for(int i = 0; i < coords.size(); i++) { + for(int i = 0; i < coords.size(); i++) result[i] = (coords[i].isEmpty()) ? span(dims, i) : coords[i]; - } return result; } @@ -187,9 +182,8 @@ SetCriterion::LossDict SetCriterion::forward( [](int curr, const Variable& label) { return curr + label.dim(1); }); Tensor numBoxesArray = fl::fromScalar(numBoxes, fl::dtype::s32); - if(isDistributedInit()) { + if(isDistributedInit()) allReduce(numBoxesArray); - } numBoxes = numBoxesArray.scalar(); numBoxes = std::max(numBoxes / fl::getWorldSize(), 1); @@ -209,12 +203,10 @@ SetCriterion::LossDict SetCriterion::forward( indices, numBoxes ); - for(std::pair l : labelLoss) { + for(std::pair l : labelLoss) losses[l.first + "_" + std::to_string(i)] = l.second; - } - for(std::pair l : bboxLoss) { + for(std::pair l : bboxLoss) losses[l.first + "_" + std::to_string(i)] = l.second; - } } return losses; } @@ -228,11 +220,10 @@ SetCriterion::LossDict SetCriterion::lossBoxes( const int numBoxes ) { auto srcIdx = this->getSrcPermutationIdx(indices); - if(srcIdx.first.isEmpty()) { + if(srcIdx.first.isEmpty()) return { {"lossGiou", fl::Variable(fl::fromScalar(0, predBoxes.type()), false)}, {"lossBbox", fl::Variable(fl::fromScalar(0, predBoxes.type()), false)}}; - } auto colIdxs = fl::reshape(srcIdx.second, {1, srcIdx.second.dim(0)}); auto batchIdxs = fl::reshape(srcIdx.first, {1, srcIdx.first.dim(0)}); @@ -243,9 +234,8 @@ SetCriterion::LossDict SetCriterion::lossBoxes( for(const auto& idx : indices) { auto targetIdxs = idx.first; auto reordered = 
targetBoxes[i](fl::span, targetIdxs); - if(!reordered.isEmpty()) { + if(!reordered.isEmpty()) permuted.emplace_back(reordered); - } i += 1; } auto tgtBoxes = fl::concatenate(permuted, 1); diff --git a/flashlight/pkg/vision/dataset/BatchTransformDataset.h b/flashlight/pkg/vision/dataset/BatchTransformDataset.h index 9a010f9..8453b4d 100644 --- a/flashlight/pkg/vision/dataset/BatchTransformDataset.h +++ b/flashlight/pkg/vision/dataset/BatchTransformDataset.h @@ -37,12 +37,10 @@ namespace pkg { batchSize_(batchsize), batchPolicy_(policy), batchFn_(batchFn) { - if(!dataset_) { + if(!dataset_) throw std::invalid_argument("dataset to be batched is null"); - } - if(batchSize_ <= 0) { + if(batchSize_ <= 0) throw std::invalid_argument("invalid batch size"); - } preBatchSize_ = dataset_->size(); switch(batchPolicy_) { case BatchDatasetPolicy::INCLUDE_LAST: @@ -52,11 +50,10 @@ namespace pkg { size_ = std::floor(static_cast(preBatchSize_) / batchSize_); break; case BatchDatasetPolicy::DIVISIBLE_ONLY: - if(size_ % batchSize_ != 0) { + if(size_ % batchSize_ != 0) throw std::invalid_argument( "dataset is not evenly divisible into batches" ); - } size_ = std::ceil(static_cast(preBatchSize_) / batchSize_); break; default: @@ -67,9 +64,8 @@ namespace pkg { ~BatchTransformDataset() {} T get(const int64_t idx) { - if(!(idx >= 0 && idx < size())) { + if(!(idx >= 0 && idx < size())) throw std::out_of_range("Dataset idx out of range"); - } std::vector> buffer; int64_t start = batchSize_ * idx; @@ -77,12 +73,10 @@ namespace pkg { for(int64_t batchidx = start; batchidx < end; ++batchidx) { auto fds = dataset_->get(batchidx); - if(buffer.size() < fds.size()) { + if(buffer.size() < fds.size()) buffer.resize(fds.size()); - } - for(int64_t i = 0; i < fds.size(); ++i) { + for(int64_t i = 0; i < fds.size(); ++i) buffer[i].emplace_back(fds[i]); - } } return batchFn_(buffer); } diff --git a/flashlight/pkg/vision/dataset/BoxUtils.cpp b/flashlight/pkg/vision/dataset/BoxUtils.cpp index 
3a29eb9..39e1e76 100644 --- a/flashlight/pkg/vision/dataset/BoxUtils.cpp +++ b/flashlight/pkg/vision/dataset/BoxUtils.cpp @@ -58,16 +58,13 @@ Tensor flatten(const Tensor& x, int start, int stop) { auto dims = x.shape(); Shape newDims(std::vector(x.ndim(), 1)); int flattenedDims = 1; - for(int i = start; i <= stop; i++) { + for(int i = start; i <= stop; i++) flattenedDims = flattenedDims * dims[i]; - } - for(int i = 0; i < start; i++) { + for(int i = 0; i < start; i++) newDims[i] = dims[i]; - } newDims[start] = flattenedDims; - for(int i = start + 1; i < (x.ndim() - stop); i++) { + for(int i = start + 1; i < (x.ndim() - stop); i++) newDims[i] = dims[i + stop]; - } return fl::reshape(x, newDims); }; @@ -76,16 +73,13 @@ fl::Variable flatten(const fl::Variable& x, int start, int stop) { auto dims = x.shape(); Shape newDims(std::vector(n, 1)); int flattenedDims = 1; - for(int i = start; i <= stop; i++) { + for(int i = start; i <= stop; i++) flattenedDims = flattenedDims * dims[i]; - } - for(int i = 0; i < start; i++) { + for(int i = 0; i < start; i++) newDims[i] = dims[i]; - } newDims[start] = flattenedDims; - for(int i = start + 1; i < (n - stop); i++) { + for(int i = start + 1; i < (n - stop); i++) newDims[i] = dims[i + stop]; - } return moddims(x, newDims); }; @@ -108,11 +102,10 @@ fl::Variable boxArea(const fl::Variable& bboxes) { } Variable cartesian(const Variable& x, const Variable& y, batchFuncVar_t fn) { - if(x.ndim() != 3 || y.ndim() != 3) { + if(x.ndim() != 3 || y.ndim() != 3) throw std::invalid_argument( "vision::cartesian - x and y inputs must have 3 dimensions" ); - } assert(x.dim(2) == y.dim(2)); Shape yDims = {y.dim(0), 1, y.dim(1), y.dim(2)}; auto yMod = moddims(y, {y.dim(0), 1, y.dim(1), y.dim(2)}); @@ -126,11 +119,10 @@ Variable cartesian(const Variable& x, const Variable& y, batchFuncVar_t fn) { } Tensor cartesian(const Tensor& x, const Tensor& y, batchFuncArr_t fn) { - if(x.ndim() != 3 || y.ndim() != 3) { + if(x.ndim() != 3 || y.ndim() != 3) throw 
std::invalid_argument( "vision::cartesian - x and y inputs must have 3 dimensions" ); - } assert(x.dim(2) == y.dim(2)); Shape yDims = {y.dim(0), 1, y.dim(1), y.dim(2)}; auto yMod = fl::reshape(y, {y.dim(0), 1, y.dim(1), y.dim(2)}); @@ -145,12 +137,11 @@ std::tuple boxIou( const Tensor& bboxes1, const Tensor& bboxes2 ) { - if(bboxes1.ndim() != 3 || bboxes2.ndim() != 3) { + if(bboxes1.ndim() != 3 || bboxes2.ndim() != 3) throw std::invalid_argument( "vision::boxIou - bbox inputs must be of shape " "[4, N, B, ...] and [4, M, B, ...]" ); - } auto area1 = boxArea(bboxes1); auto area2 = boxArea(bboxes2); auto lt = cartesian( diff --git a/flashlight/pkg/vision/dataset/Coco.cpp b/flashlight/pkg/vision/dataset/Coco.cpp index 4f08247..6741d81 100644 --- a/flashlight/pkg/vision/dataset/Coco.cpp +++ b/flashlight/pkg/vision/dataset/Coco.cpp @@ -98,9 +98,8 @@ CocoDataset::CocoDataset( // Create vector of CocoDataSample which will be loaded into arrayfire arrays std::vector data; std::ifstream ifs(list_file); - if(!ifs) { + if(!ifs) throw std::runtime_error("Could not open list file: " + list_file); - } // We use tabs a deliminators between the filepath and each bbox // We use spaced to separate the different fields of the bbox const std::string delim = "\t"; @@ -160,10 +159,10 @@ CocoDataset::CocoDataset( ); const int maxSize = 1333; - if(val) { + if(val) ds = std::make_shared(ds, randomResize({800}, maxSize)); - } else { + else { std::vector scales = { 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800}; TransformAllFunction trainTransform = compose( @@ -203,9 +202,8 @@ CocoDataset::CocoDataset( } void CocoDataset::resample() { - if(shuffled_) { + if(shuffled_) shuffled_->resample(); - } } int64_t CocoDataset::size() const { return batched_->size(); diff --git a/flashlight/pkg/vision/dataset/CocoTransforms.cpp b/flashlight/pkg/vision/dataset/CocoTransforms.cpp index 0c05dbb..ce061b3 100644 --- a/flashlight/pkg/vision/dataset/CocoTransforms.cpp +++ 
b/flashlight/pkg/vision/dataset/CocoTransforms.cpp @@ -116,14 +116,12 @@ std::vector randomResize(std::vector inputs, int size, int maxsi if(maxSize > 0) { float minOriginalSize = std::min(w, h); float maxOriginalSize = std::max(w, h); - if(maxOriginalSize / minOriginalSize * size > maxSize) { + if(maxOriginalSize / minOriginalSize * size > maxSize) size = round(maxSize * minOriginalSize / maxOriginalSize); - } } - if((w <= h && w == size) || (h <= w && h == size)) { + if((w <= h && w == size) || (h <= w && h == size)) return std::make_pair(w, h); - } int ow, oh; if(w < h) { ow = size; @@ -238,20 +236,18 @@ TransformAllFunction randomResize(std::vector sizes, int maxsize) { TransformAllFunction randomHorizontalFlip(float p) { return [p](const std::vector& in) { - if(static_cast(std::rand()) / static_cast(RAND_MAX) > p) { + if(static_cast(std::rand()) / static_cast(RAND_MAX) > p) return hflip(in); - } else { + else return in; - } }; } TransformAllFunction compose(std::vector fns) { return [fns](const std::vector& in) { std::vector out = in; - for(const auto& fn : fns) { + for(const auto& fn : fns) out = fn(out); - } return out; }; } diff --git a/flashlight/pkg/vision/dataset/DistributedDataset.cpp b/flashlight/pkg/vision/dataset/DistributedDataset.cpp index 2980ec0..6d10248 100644 --- a/flashlight/pkg/vision/dataset/DistributedDataset.cpp +++ b/flashlight/pkg/vision/dataset/DistributedDataset.cpp @@ -26,9 +26,8 @@ DistributedDataset::DistributedDataset( int partitionSize = shuffle_->size() / worldSize; int leftOver = shuffle_->size() % worldSize; - if(worldRank < leftOver) { + if(worldRank < leftOver) partitionSize++; - } ds_ = std::make_shared(shuffle_, permfn, partitionSize); ds_ = std::make_shared(ds_, numThreads, prefetchSize); ds_ = std::make_shared(ds_, batchSize, batchPolicy); diff --git a/flashlight/pkg/vision/dataset/Imagenet.cpp b/flashlight/pkg/vision/dataset/Imagenet.cpp index a0c25f1..111ad7a 100644 --- a/flashlight/pkg/vision/dataset/Imagenet.cpp +++ 
b/flashlight/pkg/vision/dataset/Imagenet.cpp @@ -25,9 +25,8 @@ std::vector fileGlob(const std::string& pattern) { glob_t result; glob(pattern.c_str(), GLOB_TILDE, nullptr, &result); std::vector ret; - for(unsigned int i = 0; i < result.gl_pathc; ++i) { + for(unsigned int i = 0; i < result.gl_pathc; ++i) ret.emplace_back(result.gl_pathv[i]); - } globfree(&result); return ret; } @@ -41,31 +40,27 @@ std::unordered_map getImagenetLabels( std::unordered_map labels; std::vector lines; std::ifstream inFile(labelFile); - if(!inFile) { + if(!inFile) throw std::invalid_argument( "fl::pkg::vision::getImagenetLabels given invalid labelFile path" ); - } - for(std::string str; std::getline(inFile, str);) { + for(std::string str; std::getline(inFile, str);) lines.emplace_back(str); - } - if(lines.empty()) { + if(lines.empty()) throw std::runtime_error( "In function imagenetLabels: No lines in file:" + labelFile.string() ); - } for(int i = 0; i < lines.size(); i++) { std::string line = lines[i]; auto it = line.find(','); if(it != std::string::npos) { std::string label = line.substr(0, it); labels[label] = i; - } else { + } else throw std::runtime_error( "In function imagenetLabels: Invalid label format for line: " + line ); - } } return labels; } @@ -77,11 +72,10 @@ std::shared_ptr imagenetDataset( ) { std::vector filepaths = fileGlob(imgDir.string() + "/**/*.JPEG"); - if(filepaths.empty()) { + if(filepaths.empty()) throw std::runtime_error( "No images were found in imagenet directory: " + imgDir.string() ); - } // Create image dataset std::shared_ptr imageDataset = @@ -92,11 +86,10 @@ std::shared_ptr imagenetDataset( auto getLabelIdxs = [&labelMap](const std::string& s) -> uint64_t { std::string parentPath = s.substr(0, s.rfind('/')); std::string label = parentPath.substr(parentPath.rfind('/') + 1); - if(labelMap.find(label) != labelMap.end()) { + if(labelMap.find(label) != labelMap.end()) return labelMap.at(label); - } else { + else throw std::runtime_error("Label: " + label + " 
not found in label map"); - } return labelMap.at(label); }; diff --git a/flashlight/pkg/vision/dataset/Jpeg.cpp b/flashlight/pkg/vision/dataset/Jpeg.cpp index f2b74a6..cff76b9 100644 --- a/flashlight/pkg/vision/dataset/Jpeg.cpp +++ b/flashlight/pkg/vision/dataset/Jpeg.cpp @@ -38,9 +38,8 @@ Tensor loadJpeg(const std::string& fp, int desiredNumberOfChannels /* = 3 */) { stbi_image_free(img); // Then reorder to W X H X C return fl::transpose(result, {1, 2, 0}); - } else { + } else throw std::invalid_argument("Could not load from filepath" + fp); - } } std::shared_ptr jpegLoader(std::vector fps) { diff --git a/flashlight/pkg/vision/dataset/Transforms.cpp b/flashlight/pkg/vision/dataset/Transforms.cpp index a40b89f..7201171 100644 --- a/flashlight/pkg/vision/dataset/Transforms.cpp +++ b/flashlight/pkg/vision/dataset/Transforms.cpp @@ -101,9 +101,8 @@ Tensor colorEnhance(const Tensor& input, const float enhance) { Tensor autoContrast(const Tensor& input) { auto minPic = fl::amin(input); auto maxPic = fl::amax(input); - if(fl::all(minPic == maxPic).asScalar()) { + if(fl::all(minPic == maxPic).asScalar()) return input; - } auto scale = fl::tile(255. 
/ (maxPic - minPic), input.shape()); minPic = fl::tile(minPic, input.shape()); @@ -150,9 +149,8 @@ Tensor equalize(const Tensor& input) { } Tensor posterize(const Tensor& input, const int bitsToKeep) { - if(bitsToKeep < 1 || bitsToKeep > 8) { + if(bitsToKeep < 1 || bitsToKeep > 8) throw std::invalid_argument("bitsToKeep needs to be in [1, 8]"); - } uint8_t mask = ~((1 << (8 - bitsToKeep)) - 1); auto res = input.astype(fl::dtype::u8) && mask; return res.astype(input.type()); @@ -212,9 +210,8 @@ std::pair mixupBatch( // in : W x H x C x B // target: B x 1 auto targetOneHot = oneHot(target, numClasses, labelSmoothing); - if(lambda == 0) { + if(lambda == 0) return {input, targetOneHot}; - } // mix input auto inputFlipped = fl::flip(input, 3); @@ -239,9 +236,8 @@ std::pair cutmixBatch( // in : W x H x C x B // target: B x 1 auto targetOneHot = oneHot(target, numClasses, labelSmoothing); - if(lambda == 0) { + if(lambda == 0) return {input, targetOneHot}; - } // mix input auto inputFlipped = fl::flip(input, 3); @@ -280,9 +276,8 @@ ImageTransform resizeTransform(const uint64_t resize) { ImageTransform compose(std::vector transformfns) { return [transformfns](const Tensor& in) { Tensor out = in; - for(const auto& fn : transformfns) { + for(const auto& fn : transformfns) out = fn(out); - } return out; }; } @@ -348,11 +343,10 @@ ImageTransform randomCropTransform(const int tw, const int th) { Tensor out = in; const uint64_t w = in.dim(0); const uint64_t h = in.dim(1); - if(th > h || tw > w) { + if(th > h || tw > w) throw std::runtime_error( "Target th and target width are great the image size" ); - } const int x = std::rand() % (w - tw + 1); const int y = std::rand() % (h - th + 1); return crop(in, x, y, tw, th); @@ -383,9 +377,8 @@ ImageTransform randomEraseTransform( // follows: https://git.io/JY9R7 return [p, areaRatioMin, areaRatioMax, edgeRatioMin, edgeRatioMax]( const Tensor& in) { - if(p < randomFloat(0, 1)) { + if(p < randomFloat(0, 1)) return in; - } const float 
epsilon = 1e-7; const int w = in.dim(0); @@ -399,9 +392,8 @@ ImageTransform randomEraseTransform( std::exp(randomFloat(std::log(edgeRatioMin), std::log(edgeRatioMax))); const int maskW = std::round(std::sqrt(s * r)); const int maskH = std::round(std::sqrt(s / r)); - if(maskW >= w || maskH >= h) { + if(maskW >= w || maskH >= h) continue; - } const int x = static_cast(randomFloat(0, w - maskW - epsilon)); const int y = static_cast(randomFloat(0, h - maskH - epsilon)); @@ -425,9 +417,8 @@ ImageTransform randomAugmentationDeitTransform( return [p, n, fillImg](const Tensor& in) { auto res = in; for(int i = 0; i < n; i++) { - if(p < randomFloat(0, 1)) { + if(p < randomFloat(0, 1)) continue; - } int mode = std::floor(randomFloat(0, 15 - 1e-5)); if(mode == 0) { @@ -467,10 +458,10 @@ ImageTransform randomAugmentationDeitTransform( 1 + randomPerturbNegate(baseEnhance, -0.03, 0.03); res = colorEnhance(res, enhance); - } else if(mode == 6) { + } else if(mode == 6) // auto contrast res = autoContrast(res); - } else if(mode == 7) { + else if(mode == 7) { // contrast float baseEnhance = .8; float enhance = @@ -484,22 +475,22 @@ ImageTransform randomAugmentationDeitTransform( 1 + randomPerturbNegate(baseEnhance, -0.03, 0.03); res = brightnessEnhance(res, enhance); - } else if(mode == 9) { + } else if(mode == 9) // invert res = invert(res); - } else if(mode == 10) { + else if(mode == 10) // solarize res = solarize(res, 26.); - } else if(mode == 11) { + else if(mode == 11) // solarize add res = solarizeAdd(res, 128., 100.); - } else if(mode == 12) { + else if(mode == 12) // equalize res = equalize(res); - } else if(mode == 13) { + else if(mode == 13) // posterize res = posterize(res, 1); - } else if(mode == 14) { + else if(mode == 14) { // sharpness float baseEnhance = .5; float enhance = randomPerturbNegate(baseEnhance, -0.01, 0.01); diff --git a/flashlight/pkg/vision/models/Detr.cpp b/flashlight/pkg/vision/models/Detr.cpp index 8ec4411..0a0eadf 100644 --- 
a/flashlight/pkg/vision/models/Detr.cpp +++ b/flashlight/pkg/vision/models/Detr.cpp @@ -93,12 +93,11 @@ Detr::Detr( std::vector Detr::forward(const std::vector& input) { // input: {input, mask} - if(input.size() != 2) { + if(input.size() != 2) throw std::invalid_argument( "Detr takes 2 Variables as input but gets " + std::to_string(input.size()) ); - } auto feature = forwardBackbone(input.front()); return forwardTransformer({feature, input[1]}); } @@ -153,9 +152,8 @@ std::vector Detr::paramsWithoutBackbone() { childParams.push_back(bboxEmbed_->params()); childParams.push_back(queryEmbed_->params()); childParams.push_back(inputProj_->params()); - for(auto params : childParams) { + for(auto params : childParams) results.insert(results.end(), params.begin(), params.end()); - } return results; } diff --git a/flashlight/pkg/vision/models/Resnet.cpp b/flashlight/pkg/vision/models/Resnet.cpp index 8443e5c..1779839 100644 --- a/flashlight/pkg/vision/models/Resnet.cpp +++ b/flashlight/pkg/vision/models/Resnet.cpp @@ -38,12 +38,10 @@ ConvBnAct::ConvBnAct( const auto pad = PaddingMode::SAME; const bool bias = !bn; add(fl::Conv2D(inC, outC, kw, kh, sx, sy, pad, pad, 1, 1, bias)); - if(bn) { + if(bn) add(fl::BatchNorm(2, outC)); - } - if(act) { + if(act) add(fl::ReLU()); - } } ResNetBlock::ResNetBlock() = default; @@ -115,11 +113,10 @@ std::vector ResNetBottleneckBlock::forward( out = bn3->forward(out); std::vector shortcut; - if(modules().size() > 9) { + if(modules().size() > 9) shortcut = module(9)->forward(inputs); - } else { + else shortcut = inputs; - } return relu3->forward({out[0] + shortcut[0]}); } @@ -147,11 +144,10 @@ std::vector ResNetBlock::forward( out = bn2->forward(out); std::vector shortcut; - if(modules().size() > 6) { + if(modules().size() > 6) shortcut = module(6)->forward(inputs); - } else { + else shortcut = inputs; - } return relu2->forward({out[0] + shortcut[0]}); } @@ -171,9 +167,8 @@ ResNetBottleneckStage::ResNetBottleneckStage( 
add(ResNetBottleneckBlock(inC, outC, stride)); const int expansionFactor = 4; const int inPlanes = outC * expansionFactor; - for(int i = 1; i < numBlocks; i++) { + for(int i = 1; i < numBlocks; i++) add(ResNetBottleneckBlock(inPlanes, outC)); - } }; ResNetBottleneckStage::ResNetBottleneckStage() = default; @@ -187,9 +182,8 @@ ResNetStage::ResNetStage( const int stride ) { add(ResNetBlock(inC, outC, stride)); - for(int i = 1; i < numBlocks; i++) { + for(int i = 1; i < numBlocks; i++) add(ResNetBlock(outC, outC)); - } } std::shared_ptr resnet34() { auto model = std::make_shared(); diff --git a/flashlight/pkg/vision/models/ResnetFrozenBatchNorm.cpp b/flashlight/pkg/vision/models/ResnetFrozenBatchNorm.cpp index 05c9425..d230dfe 100644 --- a/flashlight/pkg/vision/models/ResnetFrozenBatchNorm.cpp +++ b/flashlight/pkg/vision/models/ResnetFrozenBatchNorm.cpp @@ -38,12 +38,10 @@ ConvFrozenBatchNormActivation::ConvFrozenBatchNormActivation( const auto pad = PaddingMode::SAME; const bool bias = !bn; add(fl::Conv2D(inC, outC, kw, kh, sx, sy, pad, pad, 1, 1, bias)); - if(bn) { + if(bn) add(fl::FrozenBatchNorm(2, outC)); - } - if(act) { + if(act) add(fl::ReLU()); - } } ResNetBlockFrozenBatchNorm::ResNetBlockFrozenBatchNorm() = default; @@ -120,11 +118,10 @@ std::vector ResNetBottleneckBlockFrozenBatchNorm::forward( out = bn3->forward(out); std::vector shortcut; - if(modules().size() > 9) { + if(modules().size() > 9) shortcut = module(9)->forward(inputs); - } else { + else shortcut = inputs; - } return relu3->forward({out[0] + shortcut[0]}); } @@ -152,11 +149,10 @@ std::vector ResNetBlockFrozenBatchNorm::forward( out = bn2->forward(out); std::vector shortcut; - if(modules().size() > 6) { + if(modules().size() > 6) shortcut = module(6)->forward(inputs); - } else { + else shortcut = inputs; - } return relu2->forward({out[0] + shortcut[0]}); } @@ -176,9 +172,8 @@ ResNetBottleneckStageFrozenBatchNorm::ResNetBottleneckStageFrozenBatchNorm( add(ResNetBottleneckBlockFrozenBatchNorm(inC, 
outC, stride)); const int expansionFactor = 4; const int inPlanes = outC * expansionFactor; - for(int i = 1; i < numBlocks; i++) { + for(int i = 1; i < numBlocks; i++) add(ResNetBottleneckBlockFrozenBatchNorm(inPlanes, outC)); - } }; ResNetBottleneckStageFrozenBatchNorm::ResNetBottleneckStageFrozenBatchNorm() = @@ -193,9 +188,8 @@ ResNetStageFrozenBatchNorm::ResNetStageFrozenBatchNorm( const int stride ) { add(ResNetBlockFrozenBatchNorm(inC, outC, stride)); - for(int i = 1; i < numBlocks; i++) { + for(int i = 1; i < numBlocks; i++) add(ResNetBlockFrozenBatchNorm(outC, outC)); - } } } // namespace fl diff --git a/flashlight/pkg/vision/models/ViT.cpp b/flashlight/pkg/vision/models/ViT.cpp index eec5098..aa5e443 100644 --- a/flashlight/pkg/vision/models/ViT.cpp +++ b/flashlight/pkg/vision/models/ViT.cpp @@ -125,14 +125,12 @@ std::vector ViT::forward( // Positional embedding auto posEmb = tile(params_[1], {1, 1, B}).astype(output.type()); output = output + posEmb; - if(train_) { + if(train_) output = dropout(output, pDropout_); - } // Transformers - for(int i = 0; i < nLayers_; ++i) { + for(int i = 0; i < nLayers_; ++i) output = transformers_[i]->forward({output}).front(); - } // Linear output = ln_->forward(output); // C x T x B @@ -146,9 +144,8 @@ std::string ViT::prettyString() const { std::ostringstream ss; ss << "ViT (" << nClasses_ << " classes) with " << nLayers_ << " Transformers:\n"; - for(const auto& transformers : transformers_) { + for(const auto& transformers : transformers_) ss << transformers->prettyString() << "\n"; - } return ss.str(); } diff --git a/flashlight/pkg/vision/nn/FrozenBatchNorm.cpp b/flashlight/pkg/vision/nn/FrozenBatchNorm.cpp index 99544af..7f1150f 100644 --- a/flashlight/pkg/vision/nn/FrozenBatchNorm.cpp +++ b/flashlight/pkg/vision/nn/FrozenBatchNorm.cpp @@ -59,9 +59,8 @@ void FrozenBatchNorm::setRunningVar(const fl::Variable& x) { } void FrozenBatchNorm::train() { - for(auto& param : params_) { + for(auto& param : params_) 
param.setCalcGrad(false); - } runningVar_.setCalcGrad(false); runningMean_.setCalcGrad(false); train_ = false; @@ -71,9 +70,8 @@ std::string FrozenBatchNorm::prettyString() const { std::ostringstream ss; ss << "FrozenBatchNorm"; ss << " ( axis : { "; - for(auto x : featAxis_) { + for(auto x : featAxis_) ss << x << " "; - } ss << "}, size : " << featSize_ << " )"; return ss.str(); } diff --git a/flashlight/pkg/vision/nn/PositionalEmbeddingSine.cpp b/flashlight/pkg/vision/nn/PositionalEmbeddingSine.cpp index 43799e5..0ef694f 100644 --- a/flashlight/pkg/vision/nn/PositionalEmbeddingSine.cpp +++ b/flashlight/pkg/vision/nn/PositionalEmbeddingSine.cpp @@ -34,9 +34,8 @@ PositionalEmbeddingSine::PositionalEmbeddingSine( normalize_(other.normalize_), scale_(other.scale_) { train_ = other.train_; - for(auto& mod : other.modules_) { + for(auto& mod : other.modules_) add(mod->clone()); - } } PositionalEmbeddingSine& PositionalEmbeddingSine::operator=( @@ -48,9 +47,8 @@ PositionalEmbeddingSine& PositionalEmbeddingSine::operator=( normalize_ = other.normalize_; scale_ = other.scale_; clear(); - for(auto& mod : other.modules_) { + for(auto& mod : other.modules_) add(mod->clone()); - } return *this; } diff --git a/flashlight/pkg/vision/nn/Transformer.cpp b/flashlight/pkg/vision/nn/Transformer.cpp index 3a1f4cb..5ce4a2f 100644 --- a/flashlight/pkg/vision/nn/Transformer.cpp +++ b/flashlight/pkg/vision/nn/Transformer.cpp @@ -67,10 +67,9 @@ fl::Variable transformerMultiheadAttention( auto scores = matmulTN(q, k); - if(!keyPaddingMask.isEmpty()) { + if(!keyPaddingMask.isEmpty()) scores = scores + tileAs(moddims(log(keyPaddingMask), {1, srcLen, 1, bsz}), scores); - } auto attn = dropout(softmax(scores, 1), pDropout); auto result = matmulNT(attn.astype(v.type()), v); @@ -251,9 +250,8 @@ Variable TransformerBaseLayer::withPosEmbed( const Variable& input, const Variable& pos ) { - if(pos.isEmpty()) { + if(pos.isEmpty()) return input; - } return input + pos; } @@ -395,9 +393,8 @@ Variable 
TransformerDecoderLayer::withPosEmbed( const Variable& input, const Variable& pos ) { - if(pos.isEmpty()) { + if(pos.isEmpty()) return input; - } return input + pos; } @@ -456,9 +453,8 @@ TransformerDecoder::TransformerDecoder( float pDropout ) { // TODO add norm - for(int i = 0; i < layers; i++) { + for(int i = 0; i < layers; i++) add(TransformerDecoderLayer(modelDim, mlpDim, nHeads, pDropout)); - } add(LayerNorm(std::vector{0}, 1e-5, true, modelDim)); } @@ -497,9 +493,8 @@ TransformerEncoder::TransformerEncoder( int32_t layers, float pDropout ) { - for(int i = 0; i < layers; i++) { + for(int i = 0; i < layers; i++) add(TransformerEncoderLayer(modelDim, mlpDim, nHeads, pDropout)); - } } std::vector TransformerEncoder::forward( @@ -507,9 +502,8 @@ std::vector TransformerEncoder::forward( ) { std::vector output = input; auto mods = modules(); - for(int i = 0; i < mods.size(); i++) { + for(int i = 0; i < mods.size(); i++) output = mods[i]->forward(output); - } return output; } @@ -577,12 +571,11 @@ std::vector Transformer::forward( Variable queryEmbed, Variable posEmbed ) { - if(src.ndim() != 4) { + if(src.ndim() != 4) throw std::invalid_argument( "vision::Transformer::forward - " "expect src to be of shape (W, H, C, B)." 
); - } assert(src.dim(2) == queryEmbed.dim(0)); int B = src.dim(3); diff --git a/flashlight/pkg/vision/nn/VisionTransformer.cpp b/flashlight/pkg/vision/nn/VisionTransformer.cpp index ac205a0..cca25b1 100644 --- a/flashlight/pkg/vision/nn/VisionTransformer.cpp +++ b/flashlight/pkg/vision/nn/VisionTransformer.cpp @@ -139,9 +139,8 @@ Variable VisionTransformer::selfAttention(const Variable& x) { } Variable VisionTransformer::dropPath(const Variable& x) { - if(!train_) { + if(!train_) return x; - } // https://git.io/JYOkq int C = x.dim(0); @@ -159,19 +158,17 @@ Variable VisionTransformer::dropPath(const Variable& x) { std::vector VisionTransformer::forward( const std::vector& inputs ) { - if(inputs.size() != 1) { + if(inputs.size() != 1) throw std::runtime_error("VisionTransformer forward, !1 input Variables"); - } auto x = inputs.front(); - if(x.ndim() != 3) { + if(x.ndim() != 3) throw std::invalid_argument( "VisionTransformer::forward - " "expected input with 3 dimensions - got input with " + std::to_string(x.ndim()) ); - } x = x + dropPath(selfAttention((*norm1_)(x))); x = x + dropPath(mlp((*norm2_)(x))); diff --git a/flashlight/pkg/vision/tensor/backend/af/ArrayFireVisionExtension.cpp b/flashlight/pkg/vision/tensor/backend/af/ArrayFireVisionExtension.cpp index ade467d..7aa0908 100644 --- a/flashlight/pkg/vision/tensor/backend/af/ArrayFireVisionExtension.cpp +++ b/flashlight/pkg/vision/tensor/backend/af/ArrayFireVisionExtension.cpp @@ -59,9 +59,8 @@ namespace detail { af::array res = input; const double delta = 1e-2; - if(!fillImg.isempty()) { + if(!fillImg.isempty()) res = res + delta; - } // Call the transform res = transformFunc(res, std::forward(args)...); @@ -171,16 +170,14 @@ Tensor ArrayFireVisionExtension::translate( ) { // If no output dims specified, AF expects 2D 0's which to discard OOB data Shape outputDims = outputDimsIn; - if(outputDimsIn.ndim() == 0) { + if(outputDimsIn.ndim() == 0) outputDims = Shape({0, 0}); - } - if(translation.ndim() != 2 || 
outputDims.ndim() != 2) { + if(translation.ndim() != 2 || outputDims.ndim() != 2) throw std::invalid_argument( "ArrayFireVisionExtension::shear - " "only 2D skews shapes and empty or 2D output shapes are supported" ); - } return toTensor( detail::addFillTensor( @@ -205,16 +202,14 @@ Tensor ArrayFireVisionExtension::translate( ) { // If no output dims specified, AF expects 2D 0's which to discard OOB data Shape outputDims = outputDimsIn; - if(outputDimsIn.ndim() == 0) { + if(outputDimsIn.ndim() == 0) outputDims = Shape({0, 0}); - } - if(translation.ndim() != 2 || outputDims.ndim() != 2) { + if(translation.ndim() != 2 || outputDims.ndim() != 2) throw std::invalid_argument( "ArrayFireVisionExtension::shear - " "only 2D skews shapes and empty or 2D output shapes are supported" ); - } af::dim4 _translations = detail::flToAfDims(translation); af::dim4 _outputDims = detail::flToAfInterpType(mode); @@ -239,16 +234,14 @@ Tensor ArrayFireVisionExtension::shear( ) { // If no output dims specified, AF expects 2D 0's which to discard OOB data Shape outputDims = outputDimsIn; - if(outputDimsIn.ndim() == 0) { + if(outputDimsIn.ndim() == 0) outputDims = Shape({0, 0}); - } - if(skews.size() != 2 || outputDims.ndim() != 2) { + if(skews.size() != 2 || outputDims.ndim() != 2) throw std::invalid_argument( "ArrayFireVisionExtension::shear - " "only 2D skews shapes and empty or 2D output shapes are supported" ); - } af::dim4 _outputDims = detail::flToAfDims(outputDims); @@ -276,16 +269,14 @@ Tensor ArrayFireVisionExtension::shear( ) { // If no output dims specified, AF expects 2D 0's which to discard OOB data Shape outputDims = outputDimsIn; - if(outputDimsIn.ndim() == 0) { + if(outputDimsIn.ndim() == 0) outputDims = Shape({0, 0}); - } - if(skews.size() != 2 || outputDims.ndim() != 2) { + if(skews.size() != 2 || outputDims.ndim() != 2) throw std::invalid_argument( "ArrayFireVisionExtension::shear - " "only 2D skews shapes and empty or 2D output shapes are supported" ); - } return 
toTensor( af::skew( diff --git a/flashlight/pkg/vision/test/criterion/HungarianTest.cpp b/flashlight/pkg/vision/test/criterion/HungarianTest.cpp index 17b946d..bfc49f6 100644 --- a/flashlight/pkg/vision/test/criterion/HungarianTest.cpp +++ b/flashlight/pkg/vision/test/criterion/HungarianTest.cpp @@ -15,23 +15,19 @@ TEST(HungarianTest, DiagnalAssignments) { int M = 4; // Rows int N = 4; // Columns std::vector costsVec(N * N); - for(int r = 0; r < M; r++) { - for(int c = 0; c < N; c++) { + for(int r = 0; r < M; r++) + for(int c = 0; c < N; c++) costsVec[r * N + c] = (1 + r) * (1 + c); - } - } std::vector expRowIdxs = {0, 1, 2, 3}; std::vector expColIdxs = {3, 2, 1, 0}; std::vector rowIdxs(N); std::vector colIdxs(M); hungarian(costsVec.data(), rowIdxs.data(), colIdxs.data(), M, N); - for(int r = 0; r < M; r++) { + for(int r = 0; r < M; r++) EXPECT_EQ(rowIdxs[r], expRowIdxs[r]) << "Assignment differs at index " << r; - } - for(int c = 0; c < N; c++) { + for(int c = 0; c < N; c++) EXPECT_EQ(rowIdxs[c], expRowIdxs[c]) << "Assignment differs at index " << c; - } } TEST(HungarianTest, FullPipelineFromWiki) { @@ -46,12 +42,10 @@ TEST(HungarianTest, FullPipelineFromWiki) { std::vector rowIdxs(N); std::vector colIdxs(M); hungarian(costsVec.data(), rowIdxs.data(), colIdxs.data(), M, N); - for(int r = 0; r < M; r++) { + for(int r = 0; r < M; r++) EXPECT_EQ(rowIdxs[r], expRowIdxs[r]) << "Assignment differs at index " << r; - } - for(int c = 0; c < N; c++) { + for(int c = 0; c < N; c++) EXPECT_EQ(rowIdxs[c], expRowIdxs[c]) << "Assignment differs at index " << c; - } } TEST(HungarianTest, FullPipelineSimple1) { @@ -72,12 +66,10 @@ TEST(HungarianTest, FullPipelineSimple1) { std::vector expAssignment = {0, 1, 0, 1, 0, 0, 0, 0, 1}; std::vector assignment(N * M); hungarian(costsVec.data(), assignment.data(), N, M); - for(int c = 0; c < N; c++) { - for(int r = 0; r < M; r++) { + for(int c = 0; c < N; c++) + for(int r = 0; r < M; r++) EXPECT_EQ(assignment[c * M + r], expAssignment[c * M 
+ r]) << "Assignment differs at row " << r << " and col " << c; - } - } } TEST(HungarianTest, FullPipelineSimple2) { @@ -89,12 +81,10 @@ TEST(HungarianTest, FullPipelineSimple2) { std::vector expAssignment = {0, 0, 1, 1, 0, 0, 0, 1, 0}; std::vector assignment(N * M); hungarian(costsVec.data(), assignment.data(), N, M); - for(int c = 0; c < N; c++) { - for(int r = 0; r < M; r++) { + for(int c = 0; c < N; c++) + for(int r = 0; r < M; r++) EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) << "Assignment differs at row " << r << " and col " << c; - } - } } TEST(HungarianTest, FullPipelineSimple3) { @@ -105,12 +95,10 @@ TEST(HungarianTest, FullPipelineSimple3) { std::vector expAssignment = {0, 0, 1, 0, 1, 0, 1, 0, 0}; std::vector assignment(N * M); hungarian(costsVec.data(), assignment.data(), N, M); - for(int c = 0; c < N; c++) { - for(int r = 0; r < M; r++) { + for(int c = 0; c < N; c++) + for(int r = 0; r < M; r++) EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) << "Assignment differs at row " << r << " and col " << c; - } - } } TEST(HungarianTest, FullPipelineSize6) { @@ -125,12 +113,10 @@ TEST(HungarianTest, FullPipelineSize6) { 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0}; std::vector assignment(N * M); hungarian(costsVec.data(), assignment.data(), N, M); - for(int c = 0; c < N; c++) { - for(int r = 0; r < M; r++) { + for(int c = 0; c < N; c++) + for(int r = 0; r < M; r++) EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) << "Assignment differs at row " << r << " and col " << c; - } - } } TEST(HungarianTest, 6x6Example2) { int M = 6; // Rows @@ -144,12 +130,10 @@ TEST(HungarianTest, 6x6Example2) { 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0}; std::vector assignment(N * M); hungarian(costsVec.data(), assignment.data(), N, M); - for(int c = 0; c < N; c++) { - for(int r = 0; r < M; r++) { + for(int c = 0; c < N; c++) + for(int r = 0; r < M; r++) EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) << "Assignment differs at row " << r << " and col " 
<< c; - } - } } TEST(HungarianTest, NonSquare2) { diff --git a/flashlight/pkg/vision/test/criterion/SetCriterionTest.cpp b/flashlight/pkg/vision/test/criterion/SetCriterionTest.cpp index b64fc6d..d0a84e1 100644 --- a/flashlight/pkg/vision/test/criterion/SetCriterionTest.cpp +++ b/flashlight/pkg/vision/test/criterion/SetCriterionTest.cpp @@ -21,12 +21,11 @@ std::unordered_map getLossWeights() { {"lossCe", 1.f}, {"lossGiou", 1.f}, {"lossBbox", 1.f}}; std::unordered_map lossWeights; - for(int i = 0; i < 6; i++) { + for(int i = 0; i < 6; i++) for(const auto& l : lossWeightsBase) { std::string key = l.first + "_" + std::to_string(i); lossWeights[key] = l.second; } - } return lossWeights; } diff --git a/uncrustify.cfg b/uncrustify.cfg index 8e5673b..66b3a26 100644 --- a/uncrustify.cfg +++ b/uncrustify.cfg @@ -163,6 +163,7 @@ sp_after_byref_func = force sp_before_byref_func = remove nl_collapse_empty_body = true nl_collapse_empty_body_functions = true +nl_create_if_one_liner = false nl_assign_leave_one_liners = true nl_class_leave_one_liners = true nl_enum_leave_one_liners = true @@ -244,8 +245,8 @@ pos_constr_colon = trail nl_constr_colon = remove nl_constr_init_args = force mod_full_brace_do = force -mod_full_brace_for = force -mod_full_brace_if = force +mod_full_brace_for = remove +mod_full_brace_if = remove mod_full_brace_while = force mod_full_brace_using = force mod_paren_on_return = remove From 9742f51da8b3a5b05c3610a11e57cd0488434abb Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Mon, 23 Feb 2026 21:39:58 +0100 Subject: [PATCH 18/24] more formatting improvements --- .github/workflows/check-formatting.yml | 7 +++---- cmake/utils/fm_target_utilities.cmake | 2 +- uncrustify.cfg | 20 +++++++++++++------- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/.github/workflows/check-formatting.yml b/.github/workflows/check-formatting.yml index 5593e16..7208784 100644 --- a/.github/workflows/check-formatting.yml +++ b/.github/workflows/check-formatting.yml 
@@ -10,10 +10,9 @@ permissions: { contents: read } # CONFIG # --------------------------------------------------------- env: - UNCRUSTIFY_CONFIG: "uncrustify.cfg" # Make sure this matches your config file's name/path + UNCRUSTIFY_CONFIG: "uncrustify.cfg" CHECK_PATH: "flashlight" - FILE_EXTENSIONS: "c|cpp|h|hpp|cu" - FILES_PER_THREAD: "40" + FILE_EXTENSIONS: "c|cpp|h|hpp|cu|cuh" # --------------------------------------------------------- # JOB @@ -36,4 +35,4 @@ jobs: -type f \ -regextype posix-extended \ -regex ".*\.(${{ env.FILE_EXTENSIONS }})$" \ - | xargs -r -P $(nproc) -n ${{ env.FILES_PER_THREAD }} uncrustify -c ${{ env.UNCRUSTIFY_CONFIG }} --check \ No newline at end of file + | uncrustify -q -c ${{ env.UNCRUSTIFY_CONFIG }} --check -F - \ No newline at end of file diff --git a/cmake/utils/fm_target_utilities.cmake b/cmake/utils/fm_target_utilities.cmake index 9403bfa..c3ad681 100644 --- a/cmake/utils/fm_target_utilities.cmake +++ b/cmake/utils/fm_target_utilities.cmake @@ -84,7 +84,7 @@ endfunction() #]] function(fm_glob_cpp OUT_VAR) - fm_glob(${OUT_VAR} ${ARGN} PATTERNS "*.cpp" "*.hpp" "*.inl" "*.h" "*.cu") + fm_glob(${OUT_VAR} ${ARGN} PATTERNS "*.cpp" "*.hpp" "*.inl" "*.h" "*.cu" "*.cuh") set(${OUT_VAR} ${${OUT_VAR}} PARENT_SCOPE) endfunction() diff --git a/uncrustify.cfg b/uncrustify.cfg index 66b3a26..d264622 100644 --- a/uncrustify.cfg +++ b/uncrustify.cfg @@ -1,3 +1,4 @@ +file_ext CPP .cu .cuh newlines = auto input_tab_size = 4 output_tab_size = 4 @@ -24,6 +25,7 @@ indent_func_proto_param = false indent_func_class_param = false indent_func_ctor_var_param = false indent_template_param = false +indent_cpp_lambda_body = true use_indent_func_call_param = true donot_indent_func_def_close_paren = true align_func_params = false @@ -68,6 +70,9 @@ sp_inside_angle = remove sp_inside_angle_empty = remove sp_angle_word = force sp_angle_shift = remove +sp_angle_paren = remove +sp_angle_paren_empty = remove +sp_after_angle = ignore sp_permit_cpp11_shift = true 
sp_before_sparen = remove sp_inside_sparen = remove @@ -161,9 +166,15 @@ sp_before_unnamed_byref = remove sp_after_byref = force sp_after_byref_func = force sp_before_byref_func = remove +sp_before_ellipsis = remove +sp_type_ellipsis = remove +sp_parameter_pack_ellipsis = remove +sp_ellipsis_parameter_pack = force +sp_ptr_type_ellipsis = remove nl_collapse_empty_body = true nl_collapse_empty_body_functions = true nl_create_if_one_liner = false +nl_create_func_def_one_liner = true nl_assign_leave_one_liners = true nl_class_leave_one_liners = true nl_enum_leave_one_liners = true @@ -212,8 +223,8 @@ nl_cpp_ldef_brace = remove nl_after_semicolon = true nl_after_brace_open = true nl_after_brace_close = true -nl_after_vbrace_close = true -nl_max = 3 +nl_after_vbrace_close = false +nl_max = 2 nl_before_access_spec = 2 nl_after_access_spec = 0 nl_template_class = force @@ -256,8 +267,3 @@ mod_case_brace = remove mod_remove_empty_return = true pp_indent = remove pp_indent_at_level = false -sp_before_ellipsis = remove -sp_type_ellipsis = remove -sp_parameter_pack_ellipsis = remove -sp_ellipsis_parameter_pack = force -sp_ptr_type_ellipsis = remove \ No newline at end of file From c8c65e54cc0ecea6b6df6988673f33fadbcfdc6b Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Mon, 23 Feb 2026 21:40:08 +0100 Subject: [PATCH 19/24] formatted more code according to style --- flashlight/fl/autograd/Functions.cpp | 36 +- flashlight/fl/autograd/Utils.cpp | 4 +- flashlight/fl/autograd/Variable.cpp | 16 +- .../fl/autograd/tensor/AutogradExtension.h | 4 +- .../tensor/backend/cudnn/CudnnUtils.cpp | 8 +- .../tensor/backend/onednn/DnnlUtils.cpp | 12 +- flashlight/fl/common/Defines.cpp | 8 +- flashlight/fl/common/Defines.h | 1 - flashlight/fl/common/DynamicBenchmark.cpp | 8 +- flashlight/fl/common/DynamicBenchmark.h | 4 +- flashlight/fl/common/Logging.cpp | 105 +- flashlight/fl/common/Logging.h | 16 +- flashlight/fl/common/Plugin.h | 4 +- flashlight/fl/common/Serialization-inl.h | 15 +- 
flashlight/fl/common/Utils.cpp | 4 +- .../fl/contrib/modules/PositionEmbedding.cpp | 4 +- flashlight/fl/contrib/modules/Residual.cpp | 3 +- .../modules/SinusoidalPositionEmbedding.cpp | 4 +- flashlight/fl/contrib/modules/Transformer.cpp | 8 +- flashlight/fl/dataset/BlobDataset.cpp | 20 +- flashlight/fl/dataset/Dataset.h | 8 +- flashlight/fl/dataset/MemoryBlobDataset.cpp | 4 +- flashlight/fl/dataset/ResampleDataset.cpp | 4 +- flashlight/fl/dataset/ShuffleDataset.cpp | 4 +- flashlight/fl/distributed/DistributedApi.cpp | 8 +- flashlight/fl/distributed/FileStore.cpp | 8 +- .../backend/cpu/DistributedBackend.cpp | 4 +- .../backend/cuda/DistributedBackend.cpp | 4 +- .../backend/stub/DistributedBackend.cpp | 8 +- flashlight/fl/examples/RnnLm.cpp | 12 +- flashlight/fl/meter/CountMeter.cpp | 4 +- flashlight/fl/meter/EditDistanceMeter.h | 4 +- flashlight/fl/meter/TimeMeter.cpp | 4 +- flashlight/fl/meter/TopKMeter.cpp | 4 +- flashlight/fl/nn/Init.cpp | 45 +- flashlight/fl/nn/modules/Activations.cpp | 40 +- flashlight/fl/nn/modules/Container.cpp | 4 +- flashlight/fl/nn/modules/Container.h | 4 +- flashlight/fl/nn/modules/Embedding.cpp | 4 +- flashlight/fl/nn/modules/Identity.cpp | 4 +- flashlight/fl/nn/modules/Loss.cpp | 10 +- flashlight/fl/nn/modules/Module.cpp | 12 +- flashlight/fl/nn/modules/Normalize.cpp | 4 +- flashlight/fl/nn/modules/Padding.cpp | 4 +- flashlight/fl/nn/modules/PrecisionCast.cpp | 7 +- flashlight/fl/nn/modules/RNN.cpp | 15 +- flashlight/fl/nn/modules/Transform.cpp | 4 +- flashlight/fl/optim/Optimizers.h | 4 +- flashlight/fl/runtime/CUDAStream.cpp | 4 +- flashlight/fl/runtime/CUDAUtils.cpp | 4 +- flashlight/fl/runtime/DeviceType.cpp | 4 +- flashlight/fl/runtime/SynchronousStream.cpp | 4 +- flashlight/fl/tensor/Compute.cpp | 20 +- flashlight/fl/tensor/DefaultTensorType.h | 1 - flashlight/fl/tensor/Index.h | 4 +- flashlight/fl/tensor/Random.cpp | 12 +- flashlight/fl/tensor/Shape.cpp | 4 +- flashlight/fl/tensor/TensorAdapter.cpp | 1 - 
flashlight/fl/tensor/TensorAdapter.h | 4 +- flashlight/fl/tensor/TensorBackend.cpp | 12 +- flashlight/fl/tensor/TensorBase.cpp | 216 +-- flashlight/fl/tensor/TensorBase.h | 20 +- flashlight/fl/tensor/TensorExtension.h | 4 +- flashlight/fl/tensor/Types.cpp | 4 +- .../fl/tensor/backend/af/AdvancedIndex.cpp | 4 +- .../fl/tensor/backend/af/AdvancedIndex.cu | 235 ++-- .../fl/tensor/backend/af/ArrayFireBackend.cpp | 12 +- .../fl/tensor/backend/af/ArrayFireTensor.cpp | 56 +- .../backend/af/mem/CachingMemoryManager.cpp | 16 +- .../backend/af/mem/DefaultMemoryManager.cpp | 16 +- .../backend/af/mem/MemoryManagerAdapter.cpp | 8 +- .../fl/tensor/backend/stub/StubBackend.cpp | 228 +--- .../fl/tensor/backend/stub/StubTensor.cpp | 72 +- .../autograd/AutogradNormalizationTest.cpp | 2 - .../fl/test/common/DynamicBenchmarkTest.cpp | 10 +- .../fl/test/common/SerializationTest.cpp | 3 +- flashlight/fl/test/common/UtilsTest.cpp | 8 +- flashlight/fl/test/dataset/DatasetTest.cpp | 56 +- flashlight/fl/test/nn/ModuleTest.cpp | 4 +- flashlight/fl/test/nn/NNSerializationTest.cpp | 4 +- flashlight/fl/test/runtime/CUDAStreamTest.cpp | 1 - flashlight/fl/test/tensor/TensorBaseTest.cpp | 1 - .../test/tensor/af/ArrayFireCPUStreamTest.cpp | 1 - .../tensor/af/ArrayFireTensorBaseTest.cpp | 6 +- .../fl/test/tensor/af/MemoryFrameworkTest.cpp | 40 +- .../pkg/runtime/common/DistributedUtils.cpp | 3 +- .../pkg/runtime/plugin/ModulePlugin.cpp | 4 +- .../pkg/speech/audio/feature/Derivatives.cpp | 1 - .../pkg/speech/augmentation/SoundEffect.cpp | 8 +- .../speech/augmentation/SoundEffectConfig.cpp | 1 - .../speech/augmentation/SoundEffectUtil.cpp | 4 +- .../pkg/speech/augmentation/SoxWrapper.h | 4 +- flashlight/pkg/speech/common/Flags.h | 1 - .../pkg/speech/criterion/Seq2SeqCriterion.cpp | 4 +- .../pkg/speech/criterion/Seq2SeqCriterion.h | 8 +- .../speech/criterion/TransformerCriterion.cpp | 4 +- .../criterion/attention/AttentionBase.h | 4 +- .../pkg/speech/data/FeatureTransforms.cpp | 6 +- 
.../pkg/speech/decoder/DecodeMaster.cpp | 8 +- flashlight/pkg/speech/decoder/PlGenerator.cpp | 4 +- .../pkg/speech/runtime/SpeechStatMeter.cpp | 4 +- .../test/augmentation/ReverberationTest.cpp | 4 +- .../speech/test/criterion/CriterionTest.cpp | 16 +- flashlight/pkg/speech/test/data/SoundTest.cpp | 8 +- .../moderngpu/include/device/ctamerge.cuh | 664 +++++---- .../moderngpu/include/device/ctascan.cuh | 466 +++---- .../moderngpu/include/device/ctasearch.cuh | 330 +++-- .../moderngpu/include/device/ctasegscan.cuh | 188 +-- .../moderngpu/include/device/devicetypes.cuh | 360 +++-- .../moderngpu/include/device/deviceutil.cuh | 149 ++- .../moderngpu/include/device/intrinsics.cuh | 393 +++--- .../moderngpu/include/device/loadstore.cuh | 1181 ++++++++++------- .../moderngpu/include/device/sortnetwork.cuh | 222 ++-- .../contrib/moderngpu/include/mgpudevice.cuh | 429 ++++-- .../contrib/moderngpu/include/util/static.h | 2 - .../speech/third_party/warpctc/include/ctc.h | 1 - .../warpctc/include/detail/cpu_ctc.h | 3 - .../warpctc/include/detail/ctc_helper.h | 4 +- .../warpctc/include/detail/gpu_ctc.h | 7 - .../warpctc/include/detail/gpu_ctc_kernels.h | 1 - .../third_party/warpctc/src/ctc_entrypoint.cu | 11 +- .../speech/third_party/warpctc/src/reduce.cu | 50 +- .../pkg/vision/common/BetaDistribution.h | 4 +- flashlight/pkg/vision/criterion/Hungarian.cpp | 4 +- .../pkg/vision/criterion/HungarianImpl.h | 1 - .../pkg/vision/criterion/SetCriterion.cpp | 7 +- flashlight/pkg/vision/dataset/BoxUtils.cpp | 6 +- flashlight/pkg/vision/dataset/Coco.cpp | 4 +- flashlight/pkg/vision/dataset/Coco.h | 8 +- .../pkg/vision/dataset/CocoTransforms.cpp | 14 +- flashlight/pkg/vision/dataset/Transforms.cpp | 77 +- flashlight/pkg/vision/models/Detr.cpp | 11 +- flashlight/pkg/vision/nn/FrozenBatchNorm.cpp | 8 +- .../pkg/vision/nn/PositionalEmbeddingSine.cpp | 4 +- flashlight/pkg/vision/nn/Transformer.cpp | 1 - 135 files changed, 3161 insertions(+), 3171 deletions(-) diff --git 
a/flashlight/fl/autograd/Functions.cpp b/flashlight/fl/autograd/Functions.cpp index 782b574..25722a8 100644 --- a/flashlight/fl/autograd/Functions.cpp +++ b/flashlight/fl/autograd/Functions.cpp @@ -107,9 +107,7 @@ namespace detail { ); } - bool areVariableTypesEqual(const Variable& a, const Variable& b) { - return a.type() == b.type(); - } + bool areVariableTypesEqual(const Variable& a, const Variable& b) { return a.type() == b.type(); } } // namespace detail @@ -133,9 +131,7 @@ Variable operator+(const Variable& lhs, const double& rhsVal) { return Variable(result, {lhs.withoutData()}, gradFunc); } -Variable operator+(const double& lhsVal, const Variable& rhs) { - return rhs + lhsVal; -} +Variable operator+(const double& lhsVal, const Variable& rhs) { return rhs + lhsVal; } Variable operator-(const Variable& lhs, const Variable& rhs) { FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); @@ -197,9 +193,7 @@ Variable operator*(const Variable& lhs, const double& rhsVal) { return Variable(result, {lhs.withoutData()}, gradFunc); } -Variable operator*(const double& lhsVal, const Variable& rhs) { - return rhs * lhsVal; -} +Variable operator*(const double& lhsVal, const Variable& rhs) { return rhs * lhsVal; } Variable operator/(const Variable& lhs, const Variable& rhs) { FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); @@ -352,9 +346,7 @@ Variable max(const Variable& lhs, const double& rhsVal) { return Variable(result, {lhs}, gradFunc); } -Variable max(const double& lhsVal, const Variable& rhs) { - return max(rhs, lhsVal); -} +Variable max(const double& lhsVal, const Variable& rhs) { return max(rhs, lhsVal); } Variable min(const Variable& lhs, const Variable& rhs) { FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs); @@ -384,9 +376,7 @@ Variable min(const Variable& lhs, const double& rhsVal) { return Variable(result, {lhs}, gradFunc); } -Variable min(const double& lhsVal, const Variable& rhs) { - return min(rhs, lhsVal); -} +Variable min(const double& lhsVal, const Variable& rhs) { return min(rhs, 
lhsVal); } Variable negate(const Variable& input) { auto result = (0.0 - input.tensor()).astype(input.type()); @@ -522,9 +512,7 @@ Variable sigmoid(const Variable& input) { return Variable(result, {input.withoutData()}, gradFunc); } -Variable swish(const Variable& input, double beta) { - return input * sigmoid(beta * input); -} +Variable swish(const Variable& input, double beta) { return input * sigmoid(beta * input); } Variable erf(const Variable& input) { auto result = fl::erf(FL_ADJUST_INPUT_TYPE(input.tensor())); @@ -578,9 +566,7 @@ Variable tileAs(const Variable& input, const Shape& rdims) { return Variable(result, {input.withoutData()}, gradFunc); } -Variable tileAs(const Variable& input, const Variable& reference) { - return tileAs(input, reference.shape()); -} +Variable tileAs(const Variable& input, const Variable& reference) { return tileAs(input, reference.shape()); } Variable sumAs(const Variable& input, const Shape& rdims) { auto result = detail::sumAs(FL_ADJUST_INPUT_TYPE(input.tensor()), rdims); @@ -592,9 +578,7 @@ Variable sumAs(const Variable& input, const Shape& rdims) { return Variable(result, {input.withoutData()}, gradFunc); } -Variable sumAs(const Variable& input, const Variable& reference) { - return sumAs(input, reference.shape()); -} +Variable sumAs(const Variable& input, const Variable& reference) { return sumAs(input, reference.shape()); } Variable concatenate(const std::vector& concatInputs, int dim) { if(concatInputs.empty()) @@ -1854,9 +1838,7 @@ Variable dropout(const Variable& input, double p) { return input; } -Variable relu(const Variable& input) { - return max(input, 0.0); -} +Variable relu(const Variable& input) { return max(input, 0.0); } Variable gelu(const Variable& in) { auto input = FL_ADJUST_INPUT_TYPE(in); diff --git a/flashlight/fl/autograd/Utils.cpp b/flashlight/fl/autograd/Utils.cpp index 1f958c8..fae41b4 100644 --- a/flashlight/fl/autograd/Utils.cpp +++ b/flashlight/fl/autograd/Utils.cpp @@ -15,8 +15,6 @@ bool allClose( 
const Variable& a, const Variable& b, double absTolerance /* = 1e-5 */ -) { - return allClose(a.tensor(), b.tensor(), absTolerance); -} +) { return allClose(a.tensor(), b.tensor(), absTolerance); } } // namespace fl diff --git a/flashlight/fl/autograd/Variable.cpp b/flashlight/fl/autograd/Variable.cpp index 5aa3db8..cd627b8 100644 --- a/flashlight/fl/autograd/Variable.cpp +++ b/flashlight/fl/autograd/Variable.cpp @@ -39,8 +39,8 @@ Variable::Variable( inputs.begin(), inputs.end(), [](const Variable& input) { - return input.isCalcGrad(); - } + return input.isCalcGrad(); + } ) ) { sharedGrad_->calcGrad = true; @@ -172,9 +172,7 @@ void Variable::eval() const { fl::eval(tensor()); } -void Variable::zeroGrad() { - sharedGrad_->grad.reset(); -} +void Variable::zeroGrad() { sharedGrad_->grad.reset(); } void Variable::setCalcGrad(bool calcGrad) { sharedGrad_->calcGrad = calcGrad; @@ -221,13 +219,9 @@ void Variable::addGrad(const Variable& childGrad) { } } -void Variable::registerGradHook(const GradHook& hook) { - sharedGrad_->onGradAvailable = hook; -} +void Variable::registerGradHook(const GradHook& hook) { sharedGrad_->onGradAvailable = hook; } -void Variable::clearGradHook() { - sharedGrad_->onGradAvailable = nullptr; -} +void Variable::clearGradHook() { sharedGrad_->onGradAvailable = nullptr; } void Variable::applyGradHook() { if(sharedGrad_->onGradAvailable) { diff --git a/flashlight/fl/autograd/tensor/AutogradExtension.h b/flashlight/fl/autograd/tensor/AutogradExtension.h index dfb96f7..6f5e1bd 100644 --- a/flashlight/fl/autograd/tensor/AutogradExtension.h +++ b/flashlight/fl/autograd/tensor/AutogradExtension.h @@ -50,9 +50,7 @@ class AutogradExtension : public TensorExtension { static constexpr TensorExtensionType extensionType = TensorExtensionType::Autograd; - virtual std::shared_ptr createBenchmarkOptions() { - return nullptr; - } + virtual std::shared_ptr createBenchmarkOptions() { return nullptr; } /**************************** Forward 
****************************/ virtual Tensor conv2d( diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp index bf4edeb..82cadcb 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp @@ -417,13 +417,9 @@ ConvDescriptor::~ConvDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyConvolutionDescriptor(descriptor)); } -cudnnHandle_t getCudnnHandle() { - return getActiveDeviceHandle().cudnnHandle; -} +cudnnHandle_t getCudnnHandle() { return getActiveDeviceHandle().cudnnHandle; } -const CUDAStream& getCudnnStream() { - return *getActiveDeviceHandle().stream; -} +const CUDAStream& getCudnnStream() { return *getActiveDeviceHandle().stream; } const void* kOne(const fl::dtype t) { switch(t) { diff --git a/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.cpp b/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.cpp index 6464062..94ddbc5 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.cpp +++ b/flashlight/fl/autograd/tensor/backend/onednn/DnnlUtils.cpp @@ -32,9 +32,7 @@ DnnlStream::DnnlStream(dnnl::engine engine) { #endif } -dnnl::stream& DnnlStream::getStream() { - return stream_; -} +dnnl::stream& DnnlStream::getStream() { return stream_; } DnnlStream& DnnlStream::getInstance() { static DnnlStream instance(DnnlEngine::getInstance().getEngine()); @@ -52,9 +50,7 @@ DnnlEngine::DnnlEngine() { #endif } -dnnl::engine& DnnlEngine::getEngine() { - return engine_; -} +dnnl::engine& DnnlEngine::getEngine() { return engine_; } DnnlEngine& DnnlEngine::getInstance() { static DnnlEngine instance; @@ -65,9 +61,7 @@ dnnl::memory::dims convertToDnnlDims(const std::vector& shape) { return dnnl::memory::dims(shape.begin(), shape.end()); } -dnnl::memory::dims convertShapeToDnnlDims(const Shape& shape) { - return convertToDnnlDims(shape.get()); -} +dnnl::memory::dims convertShapeToDnnlDims(const Shape& shape) { 
return convertToDnnlDims(shape.get()); } DnnlMemoryWrapper::DnnlMemoryWrapper( const Tensor& tensor, diff --git a/flashlight/fl/common/Defines.cpp b/flashlight/fl/common/Defines.cpp index ec1b3c2..04794af 100644 --- a/flashlight/fl/common/Defines.cpp +++ b/flashlight/fl/common/Defines.cpp @@ -13,13 +13,9 @@ namespace fl { -OptimLevel OptimMode::getOptimLevel() { - return optimLevel_; -} +OptimLevel OptimMode::getOptimLevel() { return optimLevel_; } -void OptimMode::setOptimLevel(OptimLevel level) { - optimLevel_ = level; -} +void OptimMode::setOptimLevel(OptimLevel level) { optimLevel_ = level; } OptimMode& OptimMode::get() { static OptimMode optimMode; diff --git a/flashlight/fl/common/Defines.h b/flashlight/fl/common/Defines.h index fe1fbbf..b0999a9 100644 --- a/flashlight/fl/common/Defines.h +++ b/flashlight/fl/common/Defines.h @@ -30,7 +30,6 @@ #define FL_DEPRECATED(msg) __attribute__((deprecated(msg))) #endif // defined(_WIN32) || defined(_MSC_VER) - namespace fl { /** diff --git a/flashlight/fl/common/DynamicBenchmark.cpp b/flashlight/fl/common/DynamicBenchmark.cpp index dbbbd07..507791d 100644 --- a/flashlight/fl/common/DynamicBenchmark.cpp +++ b/flashlight/fl/common/DynamicBenchmark.cpp @@ -40,12 +40,8 @@ void DynamicBenchmark::stop(bool incrementCount) { options_->accumulateTimeToCurrentOption(elapsedTime, incrementCount); } -void DynamicBenchmark::setBenchmarkMode(bool mode) { - benchmarkMode_ = mode; -} +void DynamicBenchmark::setBenchmarkMode(bool mode) { benchmarkMode_ = mode; } -bool DynamicBenchmark::getBenchmarkMode() { - return benchmarkMode_; -} +bool DynamicBenchmark::getBenchmarkMode() { return benchmarkMode_; } } // namespace fl diff --git a/flashlight/fl/common/DynamicBenchmark.h b/flashlight/fl/common/DynamicBenchmark.h index d0fdebe..762ff2f 100644 --- a/flashlight/fl/common/DynamicBenchmark.h +++ b/flashlight/fl/common/DynamicBenchmark.h @@ -129,9 +129,7 @@ struct DynamicBenchmarkOptions : DynamicBenchmarkOptionsBase { * * @return T the 
current option. */ - T currentOption() { - return updateState(); - } + T currentOption() { return updateState(); } /** * @return whether or not this options' timings are complete. diff --git a/flashlight/fl/common/Logging.cpp b/flashlight/fl/common/Logging.cpp index 113930d..fced169 100644 --- a/flashlight/fl/common/Logging.cpp +++ b/flashlight/fl/common/Logging.cpp @@ -96,7 +96,6 @@ namespace { if(threadId.size() > maxThreadIdNumDigits) threadId = threadId.substr(threadId.size() - maxThreadIdNumDigits); - (*outputStream) << dateTimeWithMicroSeconds() << ' ' << threadId << ' ' << getFileName(fullPath) << ':' << lineNumber << ' '; @@ -151,61 +150,33 @@ void Logging::setMaxLoggingLevel(LogLevel maxLoggingLevel) { } } -Logging&& operator<<(Logging&& log, const std::string& s) { - return std::move(log.print(s)); -} +Logging&& operator<<(Logging&& log, const std::string& s) { return std::move(log.print(s)); } -Logging&& operator<<(Logging&& log, const char* s) { - return std::move(log.print(s)); -} +Logging&& operator<<(Logging&& log, const char* s) { return std::move(log.print(s)); } -Logging&& operator<<(Logging&& log, const void* s) { - return std::move(log.print(s)); -} +Logging&& operator<<(Logging&& log, const void* s) { return std::move(log.print(s)); } -Logging&& operator<<(Logging&& log, char c) { - return std::move(log.print(c)); -} +Logging&& operator<<(Logging&& log, char c) { return std::move(log.print(c)); } -Logging&& operator<<(Logging&& log, unsigned char u) { - return std::move(log.print(u)); -} +Logging&& operator<<(Logging&& log, unsigned char u) { return std::move(log.print(u)); } -Logging&& operator<<(Logging&& log, int i) { - return std::move(log.print(i)); -} +Logging&& operator<<(Logging&& log, int i) { return std::move(log.print(i)); } -Logging&& operator<<(Logging&& log, unsigned int u) { - return std::move(log.print(u)); -} +Logging&& operator<<(Logging&& log, unsigned int u) { return std::move(log.print(u)); } -Logging&& operator<<(Logging&& 
log, long l) { - return std::move(log.print(l)); -} +Logging&& operator<<(Logging&& log, long l) { return std::move(log.print(l)); } -Logging&& operator<<(Logging&& log, long long l) { - return std::move(log.print(l)); -} +Logging&& operator<<(Logging&& log, long long l) { return std::move(log.print(l)); } -Logging&& operator<<(Logging&& log, unsigned long u) { - return std::move(log.print(u)); -} +Logging&& operator<<(Logging&& log, unsigned long u) { return std::move(log.print(u)); } -Logging&& operator<<(Logging&& log, unsigned long long u) { - return std::move(log.print(u)); -} +Logging&& operator<<(Logging&& log, unsigned long long u) { return std::move(log.print(u)); } -Logging&& operator<<(Logging&& log, float f) { - return std::move(log.print(f)); -} +Logging&& operator<<(Logging&& log, float f) { return std::move(log.print(f)); } -Logging&& operator<<(Logging&& log, double d) { - return std::move(log.print(d)); -} +Logging&& operator<<(Logging&& log, double d) { return std::move(log.print(d)); } -Logging&& operator<<(Logging&& log, bool b) { - return std::move(log.print(b)); -} +Logging&& operator<<(Logging&& log, bool b) { return std::move(log.print(b)); } VerboseLogging::VerboseLogging(int level, const char* fullPath, int lineNumber) : level_(level) { if(level_ <= VerboseLogging::maxLoggingLevel_) { @@ -231,53 +202,29 @@ void VerboseLogging::setMaxLoggingLevel(int maxLoggingLevel) { } } -VerboseLogging&& operator<<(VerboseLogging&& log, const std::string& s) { - return std::move(log.print(s)); -} +VerboseLogging&& operator<<(VerboseLogging&& log, const std::string& s) { return std::move(log.print(s)); } -VerboseLogging&& operator<<(VerboseLogging&& log, const char* s) { - return std::move(log.print(s)); -} +VerboseLogging&& operator<<(VerboseLogging&& log, const char* s) { return std::move(log.print(s)); } -VerboseLogging&& operator<<(VerboseLogging&& log, const void* s) { - return std::move(log.print(s)); -} +VerboseLogging&& operator<<(VerboseLogging&& 
log, const void* s) { return std::move(log.print(s)); } -VerboseLogging&& operator<<(VerboseLogging&& log, char c) { - return std::move(log.print(c)); -} +VerboseLogging&& operator<<(VerboseLogging&& log, char c) { return std::move(log.print(c)); } -VerboseLogging&& operator<<(VerboseLogging&& log, unsigned char u) { - return std::move(log.print(u)); -} +VerboseLogging&& operator<<(VerboseLogging&& log, unsigned char u) { return std::move(log.print(u)); } -VerboseLogging&& operator<<(VerboseLogging&& log, int i) { - return std::move(log.print(i)); -} +VerboseLogging&& operator<<(VerboseLogging&& log, int i) { return std::move(log.print(i)); } -VerboseLogging&& operator<<(VerboseLogging&& log, unsigned int u) { - return std::move(log.print(u)); -} +VerboseLogging&& operator<<(VerboseLogging&& log, unsigned int u) { return std::move(log.print(u)); } -VerboseLogging&& operator<<(VerboseLogging&& log, long l) { - return std::move(log.print(l)); -} +VerboseLogging&& operator<<(VerboseLogging&& log, long l) { return std::move(log.print(l)); } -VerboseLogging&& operator<<(VerboseLogging&& log, unsigned long u) { - return std::move(log.print(u)); -} +VerboseLogging&& operator<<(VerboseLogging&& log, unsigned long u) { return std::move(log.print(u)); } -VerboseLogging&& operator<<(VerboseLogging&& log, float f) { - return std::move(log.print(f)); -} +VerboseLogging&& operator<<(VerboseLogging&& log, float f) { return std::move(log.print(f)); } -VerboseLogging&& operator<<(VerboseLogging&& log, double d) { - return std::move(log.print(d)); -} +VerboseLogging&& operator<<(VerboseLogging&& log, double d) { return std::move(log.print(d)); } -VerboseLogging&& operator<<(VerboseLogging&& log, bool b) { - return std::move(log.print(b)); -} +VerboseLogging&& operator<<(VerboseLogging&& log, bool b) { return std::move(log.print(b)); } constexpr std::array flLogLevelValues = { fl::LogLevel::INFO, diff --git a/flashlight/fl/common/Logging.h b/flashlight/fl/common/Logging.h index 
70045b2..a9895f5 100644 --- a/flashlight/fl/common/Logging.h +++ b/flashlight/fl/common/Logging.h @@ -158,9 +158,7 @@ class FL_API Logging { // Overrides DEFAULT_MAX_FL_LOGGING_LEVEL value. static void setMaxLoggingLevel(LogLevel maxLoggingLevel); - static bool ifLog(LogLevel level) { - return maxLoggingLevel_ >= level; - } + static bool ifLog(LogLevel level) { return maxLoggingLevel_ >= level; } private: static LogLevel maxLoggingLevel_; @@ -185,9 +183,7 @@ class FL_API VerboseLogging { // Overrides DEFAULT_MAX_VERBOSE_FL_LOGGING_LEVEL value. static void setMaxLoggingLevel(int maxLoggingLevel); - static bool ifLog(int level) { - return maxLoggingLevel_ >= level; - } + static bool ifLog(int level) { return maxLoggingLevel_ >= level; } private: static int maxLoggingLevel_; @@ -215,9 +211,7 @@ FL_API Logging && operator<<(Logging && log, bool b); // Catch all designed mostly for stuff. template -Logging && operator<<(Logging&& log, const T& t) { - return log.print(t); -} +Logging && operator<<(Logging&& log, const T& t) { return log.print(t); } FL_API VerboseLogging && operator<<(VerboseLogging && log, const std::string& s); FL_API VerboseLogging && operator<<(VerboseLogging && log, const char* s); @@ -234,8 +228,6 @@ FL_API VerboseLogging && operator<<(VerboseLogging && log, bool b); // Catch all designed mostly for stuff. 
template -VerboseLogging && operator<<(VerboseLogging&& log, const T& t) { - return log.print(t); -} +VerboseLogging && operator<<(VerboseLogging&& log, const T& t) { return log.print(t); } } // namespace fl diff --git a/flashlight/fl/common/Plugin.h b/flashlight/fl/common/Plugin.h index d4d050b..63c4318 100644 --- a/flashlight/fl/common/Plugin.h +++ b/flashlight/fl/common/Plugin.h @@ -20,9 +20,7 @@ class FL_API Plugin { protected: template - T getSymbol(const std::string& symbol) { - return (T) getRawSymbol(symbol); - } + T getSymbol(const std::string& symbol) { return (T) getRawSymbol(symbol); } private: void* getRawSymbol(const std::string& symbol); diff --git a/flashlight/fl/common/Serialization-inl.h b/flashlight/fl/common/Serialization-inl.h index 6dff6c9..337c594 100644 --- a/flashlight/fl/common/Serialization-inl.h +++ b/flashlight/fl/common/Serialization-inl.h @@ -54,9 +54,7 @@ namespace detail { // 1 argument, general case. template - void applyArchive(Archive& ar, const uint32_t version, Arg&& arg) { - ar(std::forward(arg)); - } + void applyArchive(Archive& ar, const uint32_t version, Arg&& arg) { ar(std::forward(arg)); } // 1 argument, version-restricted. 
template @@ -109,13 +107,16 @@ namespace detail { } // namespace detail template -detail::Versioned versioned(T&& t, uint32_t minVersion, uint32_t maxVersion) { - return detail::Versioned{std::forward(t), minVersion, maxVersion}; +detail::Versioned versioned( + T&& t, + uint32_t minVersion, + uint32_t maxVersion +) { return detail::Versioned{std::forward(t), minVersion, maxVersion}; } template -detail::SerializeAs serializeAs(T&& t) { - return detail::SerializeAs{std::forward(t), nullptr, nullptr}; +detail::SerializeAs serializeAs(T&& t) { return detail::SerializeAs{std::forward(t), nullptr, nullptr}; } template diff --git a/flashlight/fl/common/Utils.cpp b/flashlight/fl/common/Utils.cpp index 165e48c..9cc9d12 100644 --- a/flashlight/fl/common/Utils.cpp +++ b/flashlight/fl/common/Utils.cpp @@ -22,9 +22,7 @@ namespace fl { -bool f16Supported() { - return defaultTensorBackend().isDataTypeSupported(fl::dtype::f16); -} +bool f16Supported() { return defaultTensorBackend().isDataTypeSupported(fl::dtype::f16); } size_t divRoundUp(size_t numerator, size_t denominator) { if(!numerator) diff --git a/flashlight/fl/contrib/modules/PositionEmbedding.cpp b/flashlight/fl/contrib/modules/PositionEmbedding.cpp index d965670..a95bb08 100644 --- a/flashlight/fl/contrib/modules/PositionEmbedding.cpp +++ b/flashlight/fl/contrib/modules/PositionEmbedding.cpp @@ -59,9 +59,7 @@ std::vector PositionEmbedding::forward( std::vector PositionEmbedding::operator()( const std::vector& input -) { - return forward(input); -} +) { return forward(input); } std::unique_ptr PositionEmbedding::clone() const { return std::make_unique(*this); diff --git a/flashlight/fl/contrib/modules/Residual.cpp b/flashlight/fl/contrib/modules/Residual.cpp index 51a84ac..5014961 100644 --- a/flashlight/fl/contrib/modules/Residual.cpp +++ b/flashlight/fl/contrib/modules/Residual.cpp @@ -57,8 +57,7 @@ void Residual::processShortcut( int fromLayer, int toLayer, int projectionIndex -) { - shortcut_[toLayer - 
1].insert({fromLayer, projectionIndex}); +) { shortcut_[toLayer - 1].insert({fromLayer, projectionIndex}); } void Residual::addShortcut(int fromLayer, int toLayer) { diff --git a/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp b/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp index a7e3417..6a6ac42 100644 --- a/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp +++ b/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp @@ -91,9 +91,7 @@ std::vector SinusoidalPositionEmbedding::forward( std::vector SinusoidalPositionEmbedding::operator()( const std::vector& input -) { - return forward(input); -} +) { return forward(input); } std::unique_ptr SinusoidalPositionEmbedding::clone() const { return std::make_unique(*this); diff --git a/flashlight/fl/contrib/modules/Transformer.cpp b/flashlight/fl/contrib/modules/Transformer.cpp index 0df4278..2ea494a 100644 --- a/flashlight/fl/contrib/modules/Transformer.cpp +++ b/flashlight/fl/contrib/modules/Transformer.cpp @@ -209,13 +209,9 @@ std::vector Transformer::forward(const std::vector& input) { } } -void Transformer::setDropout(float value) { - pDropout_ = value; -} +void Transformer::setDropout(float value) { pDropout_ = value; } -void Transformer::setLayerDropout(float value) { - pLayerdrop_ = value; -} +void Transformer::setLayerDropout(float value) { pLayerdrop_ = value; } std::unique_ptr Transformer::clone() const { return std::make_unique(*this); diff --git a/flashlight/fl/dataset/BlobDataset.cpp b/flashlight/fl/dataset/BlobDataset.cpp index 021ffbd..67ae662 100644 --- a/flashlight/fl/dataset/BlobDataset.cpp +++ b/flashlight/fl/dataset/BlobDataset.cpp @@ -18,17 +18,13 @@ const int64_t magicNumber = 0x31626f6c423a6c66; BlobDatasetEntryBuffer::BlobDatasetEntryBuffer() = default; -void BlobDatasetEntryBuffer::clear() { - data_.clear(); -} +void BlobDatasetEntryBuffer::clear() { data_.clear(); } int64_t BlobDatasetEntryBuffer::size() const { return data_.size() / 
nFieldPerEntry_; } -void BlobDatasetEntryBuffer::resize(int64_t size) { - data_.resize(size * nFieldPerEntry_); -} +void BlobDatasetEntryBuffer::resize(int64_t size) { data_.resize(size * nFieldPerEntry_); } BlobDatasetEntry BlobDatasetEntryBuffer::get(const int64_t idx) const { BlobDatasetEntry e; @@ -54,9 +50,7 @@ void BlobDatasetEntryBuffer::add(const BlobDatasetEntry& e) { data_.push_back(e.offset); } -char* BlobDatasetEntryBuffer::data() { - return (char*) data_.data(); -} +char* BlobDatasetEntryBuffer::data() { return (char*) data_.data(); } int64_t BlobDatasetEntryBuffer::bytes() const { return data_.size() * sizeof(int64_t); @@ -233,16 +227,12 @@ void BlobDataset::readIndex() { readData(offset, entries_.data(), entries_.bytes()); } -void BlobDataset::flush() { - flushData(); -} +void BlobDataset::flush() { flushData(); } void BlobDataset::setHostTransform( int field, std::function func -) { - hostTransforms_[field] = func; -} +) { hostTransforms_[field] = func; } std::vector BlobDataset::getEntries(const int64_t idx) const { std::vector entries; diff --git a/flashlight/fl/dataset/Dataset.h b/flashlight/fl/dataset/Dataset.h index ec710ad..afe6262 100644 --- a/flashlight/fl/dataset/Dataset.h +++ b/flashlight/fl/dataset/Dataset.h @@ -72,13 +72,9 @@ class FL_API Dataset { // Setup iterators using iterator = detail::DatasetIterator>; - iterator begin() { - return iterator(this); - } + iterator begin() { return iterator(this); } - iterator end() { - return iterator(); - } + iterator end() { return iterator(); } protected: void checkIndexBounds(int64_t idx) const { diff --git a/flashlight/fl/dataset/MemoryBlobDataset.cpp b/flashlight/fl/dataset/MemoryBlobDataset.cpp index 5b76bce..1f20d2f 100644 --- a/flashlight/fl/dataset/MemoryBlobDataset.cpp +++ b/flashlight/fl/dataset/MemoryBlobDataset.cpp @@ -40,9 +40,7 @@ const { return maxSize; } -void MemoryBlobDataset::flushData() { - std::lock_guard lock(writeMutex_); -} +void MemoryBlobDataset::flushData() { 
std::lock_guard lock(writeMutex_); } bool MemoryBlobDataset::isEmptyData() const { return data_.empty(); diff --git a/flashlight/fl/dataset/ResampleDataset.cpp b/flashlight/fl/dataset/ResampleDataset.cpp index df6cf2f..9f78416 100644 --- a/flashlight/fl/dataset/ResampleDataset.cpp +++ b/flashlight/fl/dataset/ResampleDataset.cpp @@ -55,9 +55,7 @@ ResampleDataset::ResampleDataset( dataset, makePermutationFromFn(n == -1 ? dataset->size() : n, fn)) {} -void ResampleDataset::resample(std::vector resamplevec) { - resampleVec_ = std::move(resamplevec); -} +void ResampleDataset::resample(std::vector resamplevec) { resampleVec_ = std::move(resamplevec); } std::vector ResampleDataset::get(const int64_t idx) const { checkIndexBounds(idx); diff --git a/flashlight/fl/dataset/ShuffleDataset.cpp b/flashlight/fl/dataset/ShuffleDataset.cpp index 148b343..eebe891 100644 --- a/flashlight/fl/dataset/ShuffleDataset.cpp +++ b/flashlight/fl/dataset/ShuffleDataset.cpp @@ -34,8 +34,6 @@ void ShuffleDataset::resample() { ); } -void ShuffleDataset::setSeed(int seed) { - rng_.seed(seed); -} +void ShuffleDataset::setSeed(int seed) { rng_.seed(seed); } } // namespace fl diff --git a/flashlight/fl/distributed/DistributedApi.cpp b/flashlight/fl/distributed/DistributedApi.cpp index 6ba3e47..5ced9b5 100644 --- a/flashlight/fl/distributed/DistributedApi.cpp +++ b/flashlight/fl/distributed/DistributedApi.cpp @@ -12,13 +12,9 @@ namespace fl { -FL_API bool isDistributedInit() { - return detail::DistributedInfo::getInstance().isInitialized_; -} +FL_API bool isDistributedInit() { return detail::DistributedInfo::getInstance().isInitialized_; } -FL_API DistributedBackend distributedBackend() { - return detail::DistributedInfo::getInstance().backend_; -} +FL_API DistributedBackend distributedBackend() { return detail::DistributedInfo::getInstance().backend_; } FL_API void allReduce(Variable& var, double scale /* = 1.0 */, bool async /* = false */) { if(getWorldSize() > 1) diff --git 
a/flashlight/fl/distributed/FileStore.cpp b/flashlight/fl/distributed/FileStore.cpp index 82627a5..9fd5883 100644 --- a/flashlight/fl/distributed/FileStore.cpp +++ b/flashlight/fl/distributed/FileStore.cpp @@ -100,12 +100,8 @@ void FileStore::wait(const std::string& key) { } } -fs::path FileStore::tmpPath(const std::string& name) { - return basePath_ / fs::path("." + encodeName(name)); -} +fs::path FileStore::tmpPath(const std::string& name) { return basePath_ / fs::path("." + encodeName(name)); } -fs::path FileStore::objectPath(const std::string& name) { - return basePath_ / fs::path(encodeName(name)); -} +fs::path FileStore::objectPath(const std::string& name) { return basePath_ / fs::path(encodeName(name)); } } // namespace fl diff --git a/flashlight/fl/distributed/backend/cpu/DistributedBackend.cpp b/flashlight/fl/distributed/backend/cpu/DistributedBackend.cpp index e68cde6..c73eaaf 100644 --- a/flashlight/fl/distributed/backend/cpu/DistributedBackend.cpp +++ b/flashlight/fl/distributed/backend/cpu/DistributedBackend.cpp @@ -45,9 +45,7 @@ namespace fl { namespace detail { - std::shared_ptr globalContext() { - return glooContext_; - } + std::shared_ptr globalContext() { return glooContext_; } template inline void allreduceGloo(T* ptr, size_t s) { diff --git a/flashlight/fl/distributed/backend/cuda/DistributedBackend.cpp b/flashlight/fl/distributed/backend/cuda/DistributedBackend.cpp index ad09e2e..df85152 100644 --- a/flashlight/fl/distributed/backend/cuda/DistributedBackend.cpp +++ b/flashlight/fl/distributed/backend/cuda/DistributedBackend.cpp @@ -363,9 +363,7 @@ namespace detail { } namespace { - ncclComm_t& NcclContext::getComm() { - return comm_; - } + ncclComm_t& NcclContext::getComm() { return comm_; } int NcclContext::getWorldSize() const { return worldSize_; diff --git a/flashlight/fl/distributed/backend/stub/DistributedBackend.cpp b/flashlight/fl/distributed/backend/stub/DistributedBackend.cpp index b4be54f..25d7f25 100644 --- 
a/flashlight/fl/distributed/backend/stub/DistributedBackend.cpp +++ b/flashlight/fl/distributed/backend/stub/DistributedBackend.cpp @@ -54,11 +54,7 @@ void syncDistributed() { ); } -int getWorldRank() { - return 0; -} +int getWorldRank() { return 0; } -int getWorldSize() { - return 1; -} +int getWorldSize() { return 1; } } // namespace fl diff --git a/flashlight/fl/examples/RnnLm.cpp b/flashlight/fl/examples/RnnLm.cpp index f72e055..b455b03 100644 --- a/flashlight/fl/examples/RnnLm.cpp +++ b/flashlight/fl/examples/RnnLm.cpp @@ -43,13 +43,9 @@ class Preprocessor { public: explicit Preprocessor(std::string dataset_path); - int to_int(std::string word) { - return word_to_int[word]; - } + int to_int(std::string word) { return word_to_int[word]; } - int vocab_size() { - return word_to_int.size(); - } + int vocab_size() { return word_to_int.size(); } static const std::string eos; @@ -159,9 +155,7 @@ class RnnLm : public Container { } std::tuple - operator()(const Variable& input, const Variable& h, const Variable& c) { - return forward(input, h, c); - } + operator()(const Variable& input, const Variable& h, const Variable& c) { return forward(input, h, c); } std::string prettyString() const override { return "RnnLm"; diff --git a/flashlight/fl/meter/CountMeter.cpp b/flashlight/fl/meter/CountMeter.cpp index 3dde5f1..05334fb 100644 --- a/flashlight/fl/meter/CountMeter.cpp +++ b/flashlight/fl/meter/CountMeter.cpp @@ -24,8 +24,6 @@ std::vector CountMeter::value() const { return counts_; } -void CountMeter::reset() { - std::fill(counts_.begin(), counts_.end(), 0); -} +void CountMeter::reset() { std::fill(counts_.begin(), counts_.end(), 0); } } // namespace fl diff --git a/flashlight/fl/meter/EditDistanceMeter.h b/flashlight/fl/meter/EditDistanceMeter.h index 5a0e079..6e0c05e 100644 --- a/flashlight/fl/meter/EditDistanceMeter.h +++ b/flashlight/fl/meter/EditDistanceMeter.h @@ -70,9 +70,7 @@ class FL_API EditDistanceMeter { ); /** Updates all the counters with an `ErrorState`. 
*/ - void add(const ErrorState& es, const int64_t n) { - add(n, es.ndel, es.nins, es.nsub); - } + void add(const ErrorState& es, const int64_t n) { add(n, es.ndel, es.nins, es.nsub); } /** Returns a vector of five values: * - `error rate`: \f$ \frac{(ndel + nins + nsub)}{n} \times 100.0 \f$ diff --git a/flashlight/fl/meter/TimeMeter.cpp b/flashlight/fl/meter/TimeMeter.cpp index 765c914..2f89e27 100644 --- a/flashlight/fl/meter/TimeMeter.cpp +++ b/flashlight/fl/meter/TimeMeter.cpp @@ -53,9 +53,7 @@ void TimeMeter::resume() { isStopped_ = false; } -void TimeMeter::incUnit(int64_t num) { - curN_ += num; -} +void TimeMeter::incUnit(int64_t num) { curN_ += num; } void TimeMeter::stopAndIncUnit(int64_t num) { stop(); diff --git a/flashlight/fl/meter/TopKMeter.cpp b/flashlight/fl/meter/TopKMeter.cpp index a54dd3d..399cff0 100644 --- a/flashlight/fl/meter/TopKMeter.cpp +++ b/flashlight/fl/meter/TopKMeter.cpp @@ -43,9 +43,7 @@ double TopKMeter::value() const { return (static_cast(correct_) / n_) * 100.0f; } -std::pair TopKMeter::getStats() { - return std::make_pair(correct_, n_); -} +std::pair TopKMeter::getStats() { return std::make_pair(correct_, n_); } void TopKMeter::set(int32_t correct, int32_t n) { n_ = n; diff --git a/flashlight/fl/nn/Init.cpp b/flashlight/fl/nn/Init.cpp index 0ad37ac..4644b33 100644 --- a/flashlight/fl/nn/Init.cpp +++ b/flashlight/fl/nn/Init.cpp @@ -107,17 +107,11 @@ namespace detail { } // namespace detail -Variable input(const Tensor& arr) { - return Variable(arr, false); -} +Variable input(const Tensor& arr) { return Variable(arr, false); } -Variable noGrad(const Tensor& arr) { - return Variable(arr, false); -} +Variable noGrad(const Tensor& arr) { return Variable(arr, false); } -Variable param(const Tensor& arr) { - return Variable(arr, true); -} +Variable param(const Tensor& arr) { return Variable(arr, true); } Variable constant( double val, @@ -125,8 +119,7 @@ Variable constant( int inputSize, fl::dtype type, bool calcGrad -) { - return 
constant(val, Shape({outputSize, inputSize}), type, calcGrad); +) { return constant(val, Shape({outputSize, inputSize}), type, calcGrad); } Variable constant(double val, const Shape& dims, fl::dtype type, bool calcGrad) { @@ -155,8 +148,7 @@ Variable uniform( double max, fl::dtype type, bool calcGrad -) { - return uniform(Shape({outputSize, inputSize}), min, max, type, calcGrad); +) { return uniform(Shape({outputSize, inputSize}), min, max, type, calcGrad); } Variable uniform( @@ -165,9 +157,7 @@ Variable uniform( double max, fl::dtype type, bool calcGrad -) { - return Variable(detail::uniform(dims, min, max, type), calcGrad); -} +) { return Variable(detail::uniform(dims, min, max, type), calcGrad); } Variable normal( int outputSize, @@ -176,8 +166,7 @@ Variable normal( double mean, fl::dtype type, bool calcGrad -) { - return normal(Shape({outputSize, inputSize}), stdv, mean, type, calcGrad); +) { return normal(Shape({outputSize, inputSize}), stdv, mean, type, calcGrad); } Variable normal( @@ -186,27 +175,21 @@ Variable normal( double mean, fl::dtype type, bool calcGrad -) { - return Variable(detail::normal(dims, stdv, mean, type), calcGrad); -} +) { return Variable(detail::normal(dims, stdv, mean, type), calcGrad); } Variable kaimingUniform( const Shape& shape, int fanIn, fl::dtype type /* = fl::dtype::f32 */, bool calcGrad /* = true */ -) { - return Variable(detail::kaimingUniform(shape, fanIn, type), calcGrad); -} +) { return Variable(detail::kaimingUniform(shape, fanIn, type), calcGrad); } Variable kaimingNormal( const Shape& shape, int fanIn, fl::dtype type /* = fl::dtype::f32 */, bool calcGrad /* = true */ -) { - return Variable(detail::kaimingNormal(shape, fanIn, type), calcGrad); -} +) { return Variable(detail::kaimingNormal(shape, fanIn, type), calcGrad); } Variable glorotUniform( const Shape& shape, @@ -214,9 +197,7 @@ Variable glorotUniform( int fanOut, fl::dtype type /* = fl::dtype::f32 */, bool calcGrad /* = true */ -) { - return 
Variable(detail::glorotUniform(shape, fanIn, fanOut, type), calcGrad); -} +) { return Variable(detail::glorotUniform(shape, fanIn, fanOut, type), calcGrad); } Variable glorotNormal( const Shape& shape, @@ -224,9 +205,7 @@ Variable glorotNormal( int fanOut, fl::dtype type /* = fl::dtype::f32 */, bool calcGrad /* = true */ -) { - return Variable(detail::glorotNormal(shape, fanIn, fanOut, type), calcGrad); -} +) { return Variable(detail::glorotNormal(shape, fanIn, fanOut, type), calcGrad); } Variable truncNormal( const Shape& shape, diff --git a/flashlight/fl/nn/modules/Activations.cpp b/flashlight/fl/nn/modules/Activations.cpp index c583f03..d37e197 100644 --- a/flashlight/fl/nn/modules/Activations.cpp +++ b/flashlight/fl/nn/modules/Activations.cpp @@ -15,9 +15,7 @@ namespace fl { Sigmoid::Sigmoid() = default; -Variable Sigmoid::forward(const Variable& input) { - return sigmoid(input); -} +Variable Sigmoid::forward(const Variable& input) { return sigmoid(input); } std::unique_ptr Sigmoid::clone() const { return std::make_unique(*this); @@ -29,9 +27,7 @@ std::string Sigmoid::prettyString() const { Log::Log() = default; -Variable Log::forward(const Variable& input) { - return log(input); -} +Variable Log::forward(const Variable& input) { return log(input); } std::unique_ptr Log::clone() const { return std::make_unique(*this); @@ -43,9 +39,7 @@ std::string Log::prettyString() const { Tanh::Tanh() = default; -Variable Tanh::forward(const Variable& input) { - return tanh(input); -} +Variable Tanh::forward(const Variable& input) { return tanh(input); } std::unique_ptr Tanh::clone() const { return std::make_unique(*this); @@ -57,9 +51,7 @@ std::string Tanh::prettyString() const { HardTanh::HardTanh() = default; -Variable HardTanh::forward(const Variable& input) { - return clamp(input, -1.0, 1.0); -} +Variable HardTanh::forward(const Variable& input) { return clamp(input, -1.0, 1.0); } std::unique_ptr HardTanh::clone() const { return std::make_unique(*this); @@ -71,9 +63,7 
@@ std::string HardTanh::prettyString() const { ReLU::ReLU() = default; -Variable ReLU::forward(const Variable& input) { - return max(input, 0.0); -} +Variable ReLU::forward(const Variable& input) { return max(input, 0.0); } std::unique_ptr ReLU::clone() const { return std::make_unique(*this); @@ -85,9 +75,7 @@ std::string ReLU::prettyString() const { ReLU6::ReLU6() = default; -Variable ReLU6::forward(const Variable& input) { - return clamp(input, 0.0, 6.0); -} +Variable ReLU6::forward(const Variable& input) { return clamp(input, 0.0, 6.0); } std::unique_ptr ReLU6::clone() const { return std::make_unique(*this); @@ -99,9 +87,7 @@ std::string ReLU6::prettyString() const { LeakyReLU::LeakyReLU(double slope) : mSlope_(slope) {} -Variable LeakyReLU::forward(const Variable& input) { - return max(input, mSlope_ * input); -} +Variable LeakyReLU::forward(const Variable& input) { return max(input, mSlope_ * input); } std::unique_ptr LeakyReLU::clone() const { return std::make_unique(*this); @@ -163,9 +149,7 @@ std::string ThresholdReLU::prettyString() const { GatedLinearUnit::GatedLinearUnit(int dim) : dim_(dim) {} -Variable GatedLinearUnit::forward(const Variable& input) { - return gatedlinearunit(input, dim_); -} +Variable GatedLinearUnit::forward(const Variable& input) { return gatedlinearunit(input, dim_); } std::unique_ptr GatedLinearUnit::clone() const { return std::make_unique(*this); @@ -177,9 +161,7 @@ std::string GatedLinearUnit::prettyString() const { LogSoftmax::LogSoftmax(int dim /* = 0 */) : dim_(dim) {} -Variable LogSoftmax::forward(const Variable& input) { - return logSoftmax(input, dim_); -} +Variable LogSoftmax::forward(const Variable& input) { return logSoftmax(input, dim_); } std::unique_ptr LogSoftmax::clone() const { return std::make_unique(*this); @@ -191,9 +173,7 @@ std::string LogSoftmax::prettyString() const { Swish::Swish(double beta /* = 1.0 */) : beta_(beta) {} -Variable Swish::forward(const Variable& input) { - return swish(input, beta_); -} 
+Variable Swish::forward(const Variable& input) { return swish(input, beta_); } std::unique_ptr Swish::clone() const { return std::make_unique(*this); diff --git a/flashlight/fl/nn/modules/Container.cpp b/flashlight/fl/nn/modules/Container.cpp index 11fb591..58afc7c 100644 --- a/flashlight/fl/nn/modules/Container.cpp +++ b/flashlight/fl/nn/modules/Container.cpp @@ -108,9 +108,7 @@ Variable Sequential::forward(const Variable& input) { return output.front(); } -Variable Sequential::operator()(const Variable& input) { - return this->forward(input); -} +Variable Sequential::operator()(const Variable& input) { return this->forward(input); } std::string Sequential::prettyString() const { std::ostringstream ss; diff --git a/flashlight/fl/nn/modules/Container.h b/flashlight/fl/nn/modules/Container.h index 7e38576..52cc6a2 100644 --- a/flashlight/fl/nn/modules/Container.h +++ b/flashlight/fl/nn/modules/Container.h @@ -100,9 +100,7 @@ class FL_API Container : public Module { * @param module the module to add. 
*/ template - void add(std::unique_ptr module) { - add(std::shared_ptr(std::move(module))); - } + void add(std::unique_ptr module) { add(std::shared_ptr(std::move(module))); } /** * Adds a module to `modules_`, and adds parameters to the container's diff --git a/flashlight/fl/nn/modules/Embedding.cpp b/flashlight/fl/nn/modules/Embedding.cpp index 8e2ec84..55c01a4 100644 --- a/flashlight/fl/nn/modules/Embedding.cpp +++ b/flashlight/fl/nn/modules/Embedding.cpp @@ -43,9 +43,7 @@ void Embedding::initialize() { params_ = {embeddings}; } -Variable Embedding::forward(const Variable& input) { - return embedding(input, params_[0]); -} +Variable Embedding::forward(const Variable& input) { return embedding(input, params_[0]); } std::unique_ptr Embedding::clone() const { return std::make_unique(*this); diff --git a/flashlight/fl/nn/modules/Identity.cpp b/flashlight/fl/nn/modules/Identity.cpp index eaaad1e..f35390c 100644 --- a/flashlight/fl/nn/modules/Identity.cpp +++ b/flashlight/fl/nn/modules/Identity.cpp @@ -9,9 +9,7 @@ namespace fl { -std::vector Identity::forward(const std::vector& inputs) { - return inputs; -}; +std::vector Identity::forward(const std::vector& inputs) { return inputs; }; std::unique_ptr Identity::clone() const { return std::make_unique(*this); diff --git a/flashlight/fl/nn/modules/Loss.cpp b/flashlight/fl/nn/modules/Loss.cpp index 2725f04..2e38298 100644 --- a/flashlight/fl/nn/modules/Loss.cpp +++ b/flashlight/fl/nn/modules/Loss.cpp @@ -66,16 +66,14 @@ std::string MeanAbsoluteError::prettyString() const { Variable BinaryCrossEntropy::forward( const Variable& inputs, const Variable& targets -) { - return mean(flat(binaryCrossEntropy(inputs, targets)), {0}); +) { return mean(flat(binaryCrossEntropy(inputs, targets)), {0}); } Variable BinaryCrossEntropy::forward( const Variable& inputs, const Variable& targets, const Variable& weights -) { - return mean(flat(weights * binaryCrossEntropy(inputs, targets)), {0}); +) { return mean(flat(weights * 
binaryCrossEntropy(inputs, targets)), {0}); } std::unique_ptr BinaryCrossEntropy::clone() const { @@ -89,9 +87,7 @@ std::string BinaryCrossEntropy::prettyString() const { Variable CategoricalCrossEntropy::forward( const Variable& inputs, const Variable& targets -) { - return categoricalCrossEntropy(inputs, targets, reduction_, ignoreIndex_); -} +) { return categoricalCrossEntropy(inputs, targets, reduction_, ignoreIndex_); } std::unique_ptr CategoricalCrossEntropy::clone() const { return std::make_unique(*this); diff --git a/flashlight/fl/nn/modules/Module.cpp b/flashlight/fl/nn/modules/Module.cpp index 0cee8f0..f8976e4 100644 --- a/flashlight/fl/nn/modules/Module.cpp +++ b/flashlight/fl/nn/modules/Module.cpp @@ -63,9 +63,7 @@ int Module::numParamTensors() const { return static_cast(params_.size()); } -std::vector Module::operator()(const std::vector& input) { - return this->forward(input); -} +std::vector Module::operator()(const std::vector& input) { return this->forward(input); } UnaryModule::UnaryModule() = default; @@ -79,9 +77,7 @@ std::vector UnaryModule::forward( return {forward(inputs[0])}; } -Variable UnaryModule::operator()(const Variable& input) { - return this->forward(input); -} +Variable UnaryModule::operator()(const Variable& input) { return this->forward(input); } BinaryModule::BinaryModule() = default; @@ -98,8 +94,6 @@ std::vector BinaryModule::forward( Variable BinaryModule::operator()( const Variable& input1, const Variable& input2 -) { - return this->forward(input1, input2); -} +) { return this->forward(input1, input2); } } // namespace fl diff --git a/flashlight/fl/nn/modules/Normalize.cpp b/flashlight/fl/nn/modules/Normalize.cpp index e66a402..6f64dd3 100644 --- a/flashlight/fl/nn/modules/Normalize.cpp +++ b/flashlight/fl/nn/modules/Normalize.cpp @@ -20,9 +20,7 @@ Normalize::Normalize( eps_(eps), value_(value) {} -Variable Normalize::forward(const Variable& input) { - return value_ * normalize(input, axes_, p_, eps_); -} +Variable 
Normalize::forward(const Variable& input) { return value_ * normalize(input, axes_, p_, eps_); } std::unique_ptr Normalize::clone() const { return std::make_unique(*this); diff --git a/flashlight/fl/nn/modules/Padding.cpp b/flashlight/fl/nn/modules/Padding.cpp index 24bf9ed..e1ef979 100644 --- a/flashlight/fl/nn/modules/Padding.cpp +++ b/flashlight/fl/nn/modules/Padding.cpp @@ -14,9 +14,7 @@ namespace fl { Padding::Padding(std::vector> padding, double val) : m_pad(std::move(padding)), m_val(val) {} -Variable Padding::forward(const Variable& input) { - return padding(input, m_pad, m_val); -} +Variable Padding::forward(const Variable& input) { return padding(input, m_pad, m_val); } std::unique_ptr Padding::clone() const { return std::make_unique(*this); diff --git a/flashlight/fl/nn/modules/PrecisionCast.cpp b/flashlight/fl/nn/modules/PrecisionCast.cpp index 8bfb5f4..1669a5b 100644 --- a/flashlight/fl/nn/modules/PrecisionCast.cpp +++ b/flashlight/fl/nn/modules/PrecisionCast.cpp @@ -24,13 +24,10 @@ std::vector PrecisionCast::forward( return outputs; } -Variable PrecisionCast::forward(const Variable& input) { - return forward(std::vector{input}).front(); +Variable PrecisionCast::forward(const Variable& input) { return forward(std::vector{input}).front(); } -Variable PrecisionCast::operator()(const Variable& input) { - return this->forward(input); -} +Variable PrecisionCast::operator()(const Variable& input) { return this->forward(input); } std::unique_ptr PrecisionCast::clone() const { return std::make_unique(*this); diff --git a/flashlight/fl/nn/modules/RNN.cpp b/flashlight/fl/nn/modules/RNN.cpp index 24d5c1d..3de1125 100644 --- a/flashlight/fl/nn/modules/RNN.cpp +++ b/flashlight/fl/nn/modules/RNN.cpp @@ -98,13 +98,10 @@ std::vector RNN::forward(const std::vector& inputs) { return output; } -Variable RNN::forward(const Variable& input) { - return forward(std::vector{input}).front(); +Variable RNN::forward(const Variable& input) { return 
forward(std::vector{input}).front(); } -Variable RNN::operator()(const Variable& input) { - return forward(input); -} +Variable RNN::operator()(const Variable& input) { return forward(input); } std::tuple RNN::forward( const Variable& input, @@ -117,9 +114,7 @@ std::tuple RNN::forward( std::tuple RNN::operator()( const Variable& input, const Variable& hidden_state -) { - return forward(input, hidden_state); -} +) { return forward(input, hidden_state); } std::tuple RNN::forward( const Variable& input, @@ -134,9 +129,7 @@ std::tuple RNN::operator()( const Variable& input, const Variable& hidden_state, const Variable& cell_state -) { - return forward(input, hidden_state, cell_state); -} +) { return forward(input, hidden_state, cell_state); } std::unique_ptr RNN::clone() const { return std::make_unique(*this); diff --git a/flashlight/fl/nn/modules/Transform.cpp b/flashlight/fl/nn/modules/Transform.cpp index ccf571e..0723e2a 100644 --- a/flashlight/fl/nn/modules/Transform.cpp +++ b/flashlight/fl/nn/modules/Transform.cpp @@ -17,9 +17,7 @@ Transform::Transform( ) : func_(func), name_(name) {} -Variable Transform::forward(const Variable& input) { - return func_(input); -} +Variable Transform::forward(const Variable& input) { return func_(input); } std::unique_ptr Transform::clone() const { return std::make_unique(*this); diff --git a/flashlight/fl/optim/Optimizers.h b/flashlight/fl/optim/Optimizers.h index fe44e50..70fa3fe 100644 --- a/flashlight/fl/optim/Optimizers.h +++ b/flashlight/fl/optim/Optimizers.h @@ -57,9 +57,7 @@ class FL_API FirstOrderOptimizer { } /** Set the learning rate. */ - void setLr(double lr) { - lr_ = lr; - } + void setLr(double lr) { lr_ = lr; } /** Zero the gradients for all the parameters being optimized. Typically * this will be called after every call to step(). 
diff --git a/flashlight/fl/runtime/CUDAStream.cpp b/flashlight/fl/runtime/CUDAStream.cpp index 1afef8d..7b95e9f 100644 --- a/flashlight/fl/runtime/CUDAStream.cpp +++ b/flashlight/fl/runtime/CUDAStream.cpp @@ -97,9 +97,7 @@ const CUDADevice& CUDAStream::device() const { return device_; } -CUDADevice& CUDAStream::device() { - return device_; -} +CUDADevice& CUDAStream::device() { return device_; } void CUDAStream::sync() const { FL_CUDA_CHECK(cudaStreamSynchronize(this->nativeStream_)); diff --git a/flashlight/fl/runtime/CUDAUtils.cpp b/flashlight/fl/runtime/CUDAUtils.cpp index 6a8aa7a..677c9ae 100644 --- a/flashlight/fl/runtime/CUDAUtils.cpp +++ b/flashlight/fl/runtime/CUDAUtils.cpp @@ -31,9 +31,7 @@ std::unordered_map> createCUDADevices() { namespace detail { - void check(cudaError_t err, const char* file, int line) { - check(err, "", file, line); - } + void check(cudaError_t err, const char* file, int line) { check(err, "", file, line); } void check(cudaError_t err, const char* prefix, const char* file, int line) { if(err != cudaSuccess) { diff --git a/flashlight/fl/runtime/DeviceType.cpp b/flashlight/fl/runtime/DeviceType.cpp index 2d923a8..94bb697 100644 --- a/flashlight/fl/runtime/DeviceType.cpp +++ b/flashlight/fl/runtime/DeviceType.cpp @@ -16,9 +16,7 @@ std::string deviceTypeToString(const DeviceType type) { } } -std::ostream& operator<<(std::ostream& os, const DeviceType& type) { - return os << deviceTypeToString(type); -} +std::ostream& operator<<(std::ostream& os, const DeviceType& type) { return os << deviceTypeToString(type); } const std::unordered_set& getDeviceTypes() { static std::unordered_set types = { diff --git a/flashlight/fl/runtime/SynchronousStream.cpp b/flashlight/fl/runtime/SynchronousStream.cpp index b3013c7..38a4bd8 100644 --- a/flashlight/fl/runtime/SynchronousStream.cpp +++ b/flashlight/fl/runtime/SynchronousStream.cpp @@ -9,9 +9,7 @@ namespace fl { -X64Device& SynchronousStream::device() { - return device_; -} +X64Device& 
SynchronousStream::device() { return device_; } const X64Device& SynchronousStream::device() const { return device_; diff --git a/flashlight/fl/tensor/Compute.cpp b/flashlight/fl/tensor/Compute.cpp index ab28197..1419533 100644 --- a/flashlight/fl/tensor/Compute.cpp +++ b/flashlight/fl/tensor/Compute.cpp @@ -39,9 +39,7 @@ namespace { } // namespace -void sync() { - DeviceManager::getInstance().getActiveDevice(fl::kDefaultDeviceType).sync(); -} +void sync() { DeviceManager::getInstance().getActiveDevice(fl::kDefaultDeviceType).sync(); } void sync(const int deviceId) { DeviceManager::getInstance() @@ -84,9 +82,7 @@ void relativeSync(const std::vector& waits, const Stream& waitOn) { stream->relativeSync(waitOn); } -void eval(Tensor& tensor) { - tensor.backend().eval(tensor); -} +void eval(Tensor& tensor) { tensor.backend().eval(tensor); } int getDevice() { return DeviceManager::getInstance() @@ -100,9 +96,7 @@ void setDevice(const int deviceId) { .setActive(); } -int getDeviceCount() { - return DeviceManager::getInstance().getDeviceCount(fl::kDefaultDeviceType); -} +int getDeviceCount() { return DeviceManager::getInstance().getDeviceCount(fl::kDefaultDeviceType); } namespace detail { @@ -110,13 +104,9 @@ namespace detail { const char* msg, const int deviceId, std::ostream* ostream /* = &std::cout */ - ) { - defaultTensorBackend().getMemMgrInfo(msg, deviceId, ostream); - } + ) { defaultTensorBackend().getMemMgrInfo(msg, deviceId, ostream); } - void setMemMgrLogStream(std::ostream* stream) { - defaultTensorBackend().setMemMgrLogStream(stream); - } + void setMemMgrLogStream(std::ostream* stream) { defaultTensorBackend().setMemMgrLogStream(stream); } void setMemMgrLoggingEnabled(const bool enabled) { defaultTensorBackend().setMemMgrLoggingEnabled(enabled); diff --git a/flashlight/fl/tensor/DefaultTensorType.h b/flashlight/fl/tensor/DefaultTensorType.h index 87ac531..a2a52ca 100644 --- a/flashlight/fl/tensor/DefaultTensorType.h +++ 
b/flashlight/fl/tensor/DefaultTensorType.h @@ -29,7 +29,6 @@ using DefaultTensorType_t = fl::ArrayFireTensor; using DefaultTensorBackend_t = fl::ArrayFireBackend; #define FL_DEFAULT_BACKEND_COMPILE_FLAG FL_USE_ARRAYFIRE - #elif FL_USE_TENSOR_STUB using DefaultTensorType_t = fl::StubTensor; using DefaultTensorBackend_t = fl::StubBackend; diff --git a/flashlight/fl/tensor/Index.h b/flashlight/fl/tensor/Index.h index d5f147a..f762c5b 100644 --- a/flashlight/fl/tensor/Index.h +++ b/flashlight/fl/tensor/Index.h @@ -161,9 +161,7 @@ struct FL_API Index { } template - T& get() { - return std::get(index_); - } + T& get() { return std::get(index_); } IndexVariant getVariant() const { return index_; diff --git a/flashlight/fl/tensor/Random.cpp b/flashlight/fl/tensor/Random.cpp index e4f69c1..f5ef58f 100644 --- a/flashlight/fl/tensor/Random.cpp +++ b/flashlight/fl/tensor/Random.cpp @@ -13,16 +13,10 @@ namespace fl { -void setSeed(const int seed) { - defaultTensorBackend().setSeed(seed); -} +void setSeed(const int seed) { defaultTensorBackend().setSeed(seed); } -Tensor randn(const Shape& shape, dtype type) { - return defaultTensorBackend().randn(shape, type); -} +Tensor randn(const Shape& shape, dtype type) { return defaultTensorBackend().randn(shape, type); } -Tensor rand(const Shape& shape, dtype type) { - return defaultTensorBackend().rand(shape, type); -} +Tensor rand(const Shape& shape, dtype type) { return defaultTensorBackend().rand(shape, type); } } // namespace fl diff --git a/flashlight/fl/tensor/Shape.cpp b/flashlight/fl/tensor/Shape.cpp index 69fe4b7..04b7afb 100644 --- a/flashlight/fl/tensor/Shape.cpp +++ b/flashlight/fl/tensor/Shape.cpp @@ -76,9 +76,7 @@ const std::vector& Shape::get() const { return dims_; } -std::vector& Shape::get() { - return dims_; -}; +std::vector& Shape::get() { return dims_; }; std::string Shape::toString() const { std::stringstream ss; diff --git a/flashlight/fl/tensor/TensorAdapter.cpp b/flashlight/fl/tensor/TensorAdapter.cpp index 
68d2c55..6564573 100644 --- a/flashlight/fl/tensor/TensorAdapter.cpp +++ b/flashlight/fl/tensor/TensorAdapter.cpp @@ -15,7 +15,6 @@ #include "flashlight/fl/tensor/TensorBackend.h" #include "flashlight/fl/tensor/TensorBase.h" - namespace fl::detail { DefaultTensorType& DefaultTensorType::getInstance() { diff --git a/flashlight/fl/tensor/TensorAdapter.h b/flashlight/fl/tensor/TensorAdapter.h index 189f807..43aab5d 100644 --- a/flashlight/fl/tensor/TensorAdapter.h +++ b/flashlight/fl/tensor/TensorAdapter.h @@ -23,9 +23,7 @@ namespace fl { * @return a Tensor containing the ArrayFire array */ template -Tensor toTensor(T&&... t) { - return Tensor(std::make_unique(std::forward(t)...)); -} +Tensor toTensor(T&&... t) { return Tensor(std::make_unique(std::forward(t)...)); } /** * The implementation interface for Flashlight Tensor backends. diff --git a/flashlight/fl/tensor/TensorBackend.cpp b/flashlight/fl/tensor/TensorBackend.cpp index b66eae8..59dd817 100644 --- a/flashlight/fl/tensor/TensorBackend.cpp +++ b/flashlight/fl/tensor/TensorBackend.cpp @@ -10,9 +10,7 @@ namespace fl { namespace detail { - bool areBackendsEqual(const Tensor& a, const Tensor& b) { - return a.backendType() == b.backendType(); - } + bool areBackendsEqual(const Tensor& a, const Tensor& b) { return a.backendType() == b.backendType(); } } // namespace detail @@ -63,17 +61,13 @@ Tensor TensorBackend::where( const Tensor& condition, const Tensor& x, const double& y -) { - return where(condition, x, full(condition.shape(), y, x.type())); -} +) { return where(condition, x, full(condition.shape(), y, x.type())); } Tensor TensorBackend::where( const Tensor& condition, const double& x, const Tensor& y -) { - return where(condition, full(condition.shape(), x, y.type()), y); -} +) { return where(condition, full(condition.shape(), x, y.type()), y); } Tensor TensorBackend::minimum(const Tensor& lhs, const double& rhs) { return minimum(lhs, full(lhs.shape(), rhs, dtype_traits::ctype)); diff --git 
a/flashlight/fl/tensor/TensorBase.cpp b/flashlight/fl/tensor/TensorBase.cpp index ed0d180..7fde6f5 100644 --- a/flashlight/fl/tensor/TensorBase.cpp +++ b/flashlight/fl/tensor/TensorBase.cpp @@ -27,9 +27,7 @@ namespace fl { Tensor::Tensor(std::unique_ptr adapter) : impl_(std::move(adapter)) {} -std::unique_ptr Tensor::releaseAdapter() { - return std::move(impl_); -} +std::unique_ptr Tensor::releaseAdapter() { return std::move(impl_); } Tensor::~Tensor() = default; @@ -259,9 +257,7 @@ const Stream& Tensor::stream() const { return impl_->stream(); } -void Tensor::setContext(void* context) { - impl_->setContext(context); -} +void Tensor::setContext(void* context) { impl_->setContext(context); } void* Tensor::getContext() const { return impl_->getContext(); @@ -365,9 +361,7 @@ FL_CREATE_FUN_LITERAL_TYPE(const short&); FL_CREATE_FUN_LITERAL_TYPE(const unsigned short&); #undef FL_CREATE_FUN_LITERAL_TYPE -Tensor identity(const Dim dim, const dtype type) { - return defaultTensorBackend().identity(dim, type); -} +Tensor identity(const Dim dim, const dtype type) { return defaultTensorBackend().identity(dim, type); } #define FL_ARANGE_FUN_DEF(TYPE) \ template<> FL_API Tensor arange(TYPE start, TYPE end, TYPE step, const dtype type) { \ @@ -394,17 +388,13 @@ Tensor iota(const Shape& dims, const Shape& tileDims, const dtype type) { /************************ Shaping and Indexing *************************/ -Tensor reshape(const Tensor& tensor, const Shape& shape) { - return tensor.backend().reshape(tensor, shape); -} +Tensor reshape(const Tensor& tensor, const Shape& shape) { return tensor.backend().reshape(tensor, shape); } Tensor transpose(const Tensor& tensor, const Shape& axes /* = {} */) { return tensor.backend().transpose(tensor, axes); } -Tensor tile(const Tensor& tensor, const Shape& shape) { - return tensor.backend().tile(tensor, shape); -} +Tensor tile(const Tensor& tensor, const Shape& shape) { return tensor.backend().tile(tensor, shape); } Tensor concatenate(const 
std::vector& tensors, const unsigned axis) { if(tensors.empty()) @@ -417,8 +407,8 @@ Tensor concatenate(const std::vector& tensors, const unsigned axis) { tensors.begin(), tensors.end(), [b](const Tensor& t) { - return t.backendType() == b; - } + return t.backendType() == b; + } ); if(!matches) throw std::invalid_argument( @@ -428,82 +418,46 @@ Tensor concatenate(const std::vector& tensors, const unsigned axis) { return tensors.front().backend().concatenate(tensors, axis); } -Tensor nonzero(const Tensor& tensor) { - return tensor.backend().nonzero(tensor); -} +Tensor nonzero(const Tensor& tensor) { return tensor.backend().nonzero(tensor); } Tensor pad( const Tensor& input, const std::vector>& padWidths, const PadType type -) { - return input.backend().pad(input, padWidths, type); -} +) { return input.backend().pad(input, padWidths, type); } /************************** Unary Operators ***************************/ -Tensor exp(const Tensor& tensor) { - return tensor.backend().exp(tensor); -} +Tensor exp(const Tensor& tensor) { return tensor.backend().exp(tensor); } -Tensor log(const Tensor& tensor) { - return tensor.backend().log(tensor); -} +Tensor log(const Tensor& tensor) { return tensor.backend().log(tensor); } -Tensor negative(const Tensor& tensor) { - return tensor.backend().negative(tensor); -} +Tensor negative(const Tensor& tensor) { return tensor.backend().negative(tensor); } -Tensor logicalNot(const Tensor& tensor) { - return tensor.backend().logicalNot(tensor); -} +Tensor logicalNot(const Tensor& tensor) { return tensor.backend().logicalNot(tensor); } -Tensor log1p(const Tensor& tensor) { - return tensor.backend().log1p(tensor); -} +Tensor log1p(const Tensor& tensor) { return tensor.backend().log1p(tensor); } -Tensor sin(const Tensor& tensor) { - return tensor.backend().sin(tensor); -} +Tensor sin(const Tensor& tensor) { return tensor.backend().sin(tensor); } -Tensor cos(const Tensor& tensor) { - return tensor.backend().cos(tensor); -} +Tensor cos(const 
Tensor& tensor) { return tensor.backend().cos(tensor); } -Tensor sqrt(const Tensor& tensor) { - return tensor.backend().sqrt(tensor); -} +Tensor sqrt(const Tensor& tensor) { return tensor.backend().sqrt(tensor); } -Tensor tanh(const Tensor& tensor) { - return tensor.backend().tanh(tensor); -} +Tensor tanh(const Tensor& tensor) { return tensor.backend().tanh(tensor); } -Tensor floor(const Tensor& tensor) { - return tensor.backend().floor(tensor); -} +Tensor floor(const Tensor& tensor) { return tensor.backend().floor(tensor); } -Tensor ceil(const Tensor& tensor) { - return tensor.backend().ceil(tensor); -} +Tensor ceil(const Tensor& tensor) { return tensor.backend().ceil(tensor); } -Tensor rint(const Tensor& tensor) { - return tensor.backend().rint(tensor); -} +Tensor rint(const Tensor& tensor) { return tensor.backend().rint(tensor); } -Tensor absolute(const Tensor& tensor) { - return tensor.backend().absolute(tensor); -} +Tensor absolute(const Tensor& tensor) { return tensor.backend().absolute(tensor); } -Tensor sigmoid(const Tensor& tensor) { - return tensor.backend().sigmoid(tensor); -} +Tensor sigmoid(const Tensor& tensor) { return tensor.backend().sigmoid(tensor); } -Tensor erf(const Tensor& tensor) { - return tensor.backend().erf(tensor); -} +Tensor erf(const Tensor& tensor) { return tensor.backend().erf(tensor); } -Tensor flip(const Tensor& tensor, const unsigned dim) { - return tensor.backend().flip(tensor, dim); -} +Tensor flip(const Tensor& tensor, const unsigned dim) { return tensor.backend().flip(tensor, dim); } Tensor clip(const Tensor& tensor, const Tensor& low, const Tensor& high) { FL_TENSOR_BACKENDS_MATCH_CHECK(tensor, low, high); @@ -528,25 +482,15 @@ Tensor roll(const Tensor& tensor, const int shift, const unsigned axis) { return tensor.backend().roll(tensor, shift, axis); } -Tensor isnan(const Tensor& tensor) { - return tensor.backend().isnan(tensor); -} +Tensor isnan(const Tensor& tensor) { return tensor.backend().isnan(tensor); } -Tensor 
isinf(const Tensor& tensor) { - return tensor.backend().isinf(tensor); -} +Tensor isinf(const Tensor& tensor) { return tensor.backend().isinf(tensor); } -Tensor sign(const Tensor& tensor) { - return tensor.backend().sign(tensor); -} +Tensor sign(const Tensor& tensor) { return tensor.backend().sign(tensor); } -Tensor tril(const Tensor& tensor) { - return tensor.backend().tril(tensor); -} +Tensor tril(const Tensor& tensor) { return tensor.backend().tril(tensor); } -Tensor triu(const Tensor& tensor) { - return tensor.backend().triu(tensor); -} +Tensor triu(const Tensor& tensor) { return tensor.backend().triu(tensor); } Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y) { FL_TENSOR_BACKENDS_MATCH_CHECK(condition, x, y); @@ -585,9 +529,7 @@ void sort( const Tensor& input, const Dim axis, const SortMode sortMode /* = SortMode::Descending */ -) { - return values.backend().sort(values, indices, input, axis, sortMode); -} +) { return values.backend().sort(values, indices, input, axis, sortMode); } Tensor argsort(const Tensor& input, const Dim axis, const SortMode sortMode) { return input.backend().argsort(input, axis, sortMode); @@ -666,34 +608,22 @@ Tensor maximum(const Tensor& lhs, const Tensor& rhs) { return lhs.backend().maximum(lhs, rhs); } -Tensor minimum(const Tensor& lhs, const double& rhs) { - return lhs.backend().minimum(lhs, rhs); -} +Tensor minimum(const Tensor& lhs, const double& rhs) { return lhs.backend().minimum(lhs, rhs); } -Tensor minimum(const double& lhs, const Tensor& rhs) { - return rhs.backend().minimum(lhs, rhs); -} +Tensor minimum(const double& lhs, const Tensor& rhs) { return rhs.backend().minimum(lhs, rhs); } -Tensor maximum(const Tensor& lhs, const double& rhs) { - return lhs.backend().maximum(lhs, rhs); -} +Tensor maximum(const Tensor& lhs, const double& rhs) { return lhs.backend().maximum(lhs, rhs); } -Tensor maximum(const double& lhs, const Tensor& rhs) { - return rhs.backend().maximum(lhs, rhs); -} +Tensor maximum(const 
double& lhs, const Tensor& rhs) { return rhs.backend().maximum(lhs, rhs); } Tensor power(const Tensor& lhs, const Tensor& rhs) { FL_TENSOR_BACKENDS_MATCH_CHECK(lhs, rhs); return lhs.backend().power(lhs, rhs); } -Tensor power(const Tensor& lhs, const double& rhs) { - return lhs.backend().power(lhs, rhs); -} +Tensor power(const Tensor& lhs, const double& rhs) { return lhs.backend().power(lhs, rhs); } -Tensor power(const double& lhs, const Tensor& rhs) { - return rhs.backend().power(lhs, rhs); -} +Tensor power(const double& lhs, const Tensor& rhs) { return rhs.backend().power(lhs, rhs); } /******************************* BLAS ********************************/ Tensor matmul( @@ -712,17 +642,13 @@ Tensor amin( const Tensor& input, const std::vector& axes /* = {} */, const bool keepDims /* = false */ -) { - return input.backend().amin(input, axes, keepDims); -} +) { return input.backend().amin(input, axes, keepDims); } Tensor amax( const Tensor& input, const std::vector& axes /* = {} */, const bool keepDims /* = false */ -) { - return input.backend().amax(input, axes, keepDims); -} +) { return input.backend().amax(input, axes, keepDims); } void min( Tensor& values, @@ -750,95 +676,71 @@ Tensor sum( const Tensor& input, const std::vector& axes /* = {} */, const bool keepDims /* = false */ -) { - return input.backend().sum(input, axes, keepDims); -} +) { return input.backend().sum(input, axes, keepDims); } -Tensor cumsum(const Tensor& input, const unsigned axis) { - return input.backend().cumsum(input, axis); -} +Tensor cumsum(const Tensor& input, const unsigned axis) { return input.backend().cumsum(input, axis); } Tensor argmax( const Tensor& input, const unsigned axis, const bool keepDims /* = false */ -) { - return input.backend().argmax(input, axis, keepDims); -} +) { return input.backend().argmax(input, axis, keepDims); } Tensor argmin( const Tensor& input, const unsigned axis, const bool keepDims /* = false */ -) { - return input.backend().argmin(input, axis, 
keepDims); -} +) { return input.backend().argmin(input, axis, keepDims); } Tensor mean( const Tensor& input, const std::vector& axes /* = {} */, const bool keepDims /* = false */ -) { - return input.backend().mean(input, axes, keepDims); -} +) { return input.backend().mean(input, axes, keepDims); } Tensor median( const Tensor& input, const std::vector& axes /* = {} */, const bool keepDims /* = false */ -) { - return input.backend().median(input, axes, keepDims); -} +) { return input.backend().median(input, axes, keepDims); } Tensor var( const Tensor& input, const std::vector& axes /* = {} */, const bool bias, const bool keepDims /* = false */ -) { - return input.backend().var(input, axes, bias, keepDims); -} +) { return input.backend().var(input, axes, bias, keepDims); } Tensor std( const Tensor& input, const std::vector& axes /* = {} */, const bool keepDims /* = false */ -) { - return input.backend().std(input, axes, keepDims); -} +) { return input.backend().std(input, axes, keepDims); } Tensor norm( const Tensor& input, const std::vector& axes /* = {} */, double p /* = 2 */, const bool keepDims /* = false */ -) { - return input.backend().norm(input, axes, p, keepDims); -} +) { return input.backend().norm(input, axes, p, keepDims); } Tensor countNonzero( const Tensor& input, const std::vector& axes /* = {} */, const bool keepDims /* = false */ -) { - return input.backend().countNonzero(input, axes, keepDims); -} +) { return input.backend().countNonzero(input, axes, keepDims); } Tensor any( const Tensor& input, const std::vector& axes /* = {} */, const bool keepDims /* = false */ -) { - return input.backend().any(input, axes, keepDims); -} +) { return input.backend().any(input, axes, keepDims); } Tensor all( const Tensor& input, const std::vector& axes /* = {} */, const bool keepDims /* = false */ -) { - return input.backend().all(input, axes, keepDims); -} +) { return input.backend().all(input, axes, keepDims); } /************************** Utilities 
***************************/ @@ -847,9 +749,7 @@ std::ostream& operator<<(std::ostream& ostr, const Tensor& t) { return ostr; } -void print(const Tensor& tensor) { - tensor.backend().print(tensor); -} +void print(const Tensor& tensor) { tensor.backend().print(tensor); } bool allClose( const fl::Tensor& a, @@ -890,17 +790,11 @@ std::ostream& operator<<(std::ostream& os, const TensorBackendType type) { namespace detail { - std::unique_ptr releaseAdapter(Tensor&& t) { - return t.releaseAdapter(); - } + std::unique_ptr releaseAdapter(Tensor&& t) { return t.releaseAdapter(); } - std::unique_ptr releaseAdapterUnsafe(Tensor& t) { - return t.releaseAdapter(); - } + std::unique_ptr releaseAdapterUnsafe(Tensor& t) { return t.releaseAdapter(); } - bool areTensorTypesEqual(const Tensor& a, const Tensor& b) { - return a.type() == b.type(); - } + bool areTensorTypesEqual(const Tensor& a, const Tensor& b) { return a.type() == b.type(); } } // namespace detail diff --git a/flashlight/fl/tensor/TensorBase.h b/flashlight/fl/tensor/TensorBase.h index f16062d..82ae03f 100644 --- a/flashlight/fl/tensor/TensorBase.h +++ b/flashlight/fl/tensor/TensorBase.h @@ -258,9 +258,7 @@ class FL_API Tensor { fl::dtype t, const uint8_t* ptr, Location memoryLocation - ) { - return Tensor(s, t, ptr, memoryLocation); - } + ) { return Tensor(s, t, ptr, memoryLocation); } /** * Deep-copies the tensor, including underlying data. @@ -844,9 +842,7 @@ FL_API Tensor pad( * @return a tensor with elements negated. 
*/ FL_API Tensor negative(const Tensor& tensor); -inline Tensor operator-(const Tensor& tensor) { - return negative(tensor); -} +inline Tensor operator-(const Tensor& tensor) { return negative(tensor); } /** * Performs element-wise logical-not on the elements of a tensor @@ -855,9 +851,7 @@ inline Tensor operator-(const Tensor& tensor) { * @return a tensor with element-wise logical not of the input */ FL_API Tensor logicalNot(const Tensor& tensor); -inline Tensor operator!(const Tensor& tensor) { - return logicalNot(tensor); -} +inline Tensor operator!(const Tensor& tensor) { return logicalNot(tensor); } /** * Compute the element-wise exponential of a tensor @@ -948,9 +942,7 @@ FL_API Tensor rint(const Tensor& tensor); FL_API Tensor absolute(const Tensor& tensor); // \copydoc absolute -inline Tensor abs(const Tensor& tensor) { - return absolute(tensor); -} +inline Tensor abs(const Tensor& tensor) { return absolute(tensor); } /** * Returns the element-wise sigmoid the input: @@ -1769,9 +1761,7 @@ namespace detail { const Tensor& a, const Tensor& b, const Args&... 
args - ) { - return areTensorTypesEqual(a, b) && areTensorTypesEqual(a, args...); - } + ) { return areTensorTypesEqual(a, b) && areTensorTypesEqual(a, args...); } } // namespace detail diff --git a/flashlight/fl/tensor/TensorExtension.h b/flashlight/fl/tensor/TensorExtension.h index adf8c13..998a886 100644 --- a/flashlight/fl/tensor/TensorExtension.h +++ b/flashlight/fl/tensor/TensorExtension.h @@ -109,9 +109,7 @@ bool registerTensorExtension(TensorBackendType backendType) { template class TensorExtension : public TensorExtensionBase { public: - static TensorExtensionType getExtensionType() { - return T::extensionType; - } + static TensorExtensionType getExtensionType() { return T::extensionType; } }; template diff --git a/flashlight/fl/tensor/Types.cpp b/flashlight/fl/tensor/Types.cpp index 1605312..8625593 100644 --- a/flashlight/fl/tensor/Types.cpp +++ b/flashlight/fl/tensor/Types.cpp @@ -69,9 +69,7 @@ size_t getTypeSize(dtype type) { } } -const std::string& dtypeToString(dtype type) { - return kTypeToString.at(type); -} +const std::string& dtypeToString(dtype type) { return kTypeToString.at(type); } fl::dtype stringToDtype(const std::string& string) { if(kStringToType.find(string) != kStringToType.end()) diff --git a/flashlight/fl/tensor/backend/af/AdvancedIndex.cpp b/flashlight/fl/tensor/backend/af/AdvancedIndex.cpp index b84ee1a..c3c5cea 100644 --- a/flashlight/fl/tensor/backend/af/AdvancedIndex.cpp +++ b/flashlight/fl/tensor/backend/af/AdvancedIndex.cpp @@ -20,9 +20,7 @@ namespace detail { const af::dim4& outDims, const std::vector& idxArr, af::array& out - ) { - throw std::runtime_error("gradAdvancedIndex not implemented for cpu"); - } + ) { throw std::runtime_error("gradAdvancedIndex not implemented for cpu"); } } // namespace detail } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/AdvancedIndex.cu b/flashlight/fl/tensor/backend/af/AdvancedIndex.cu index de2b472..171b528 100644 --- a/flashlight/fl/tensor/backend/af/AdvancedIndex.cu +++ 
b/flashlight/fl/tensor/backend/af/AdvancedIndex.cu @@ -19,15 +19,16 @@ #define GRID_SIZE 32 #define BLOCK_SIZE 256 -const std::unordered_set < af::dtype > validIndexTypes { +const std::unordered_set validIndexTypes { af::dtype::s32, af::dtype::s64, af::dtype::u32, af::dtype::u64 }; -template < class Float, class Index -> __global__ void advancedIndexKernel( +template +__global__ void advancedIndexKernel( const Float* inp, const dim_t* idxStart, const dim_t* idxEnd, @@ -81,122 +82,122 @@ template < class Float, class Index } namespace fl { - namespace detail { - - void advancedIndex( - const af::array& inp, - const af::dim4& idxStart, - const af::dim4& idxEnd, - const af::dim4& outDims, - const std::vector < af::array > &idxArr, - af::array& out - ) { - auto inpType = inp.type(); - auto outType = out.type(); - - if((inpType != af::dtype::f32) && (inpType != af::dtype::f16)) - throw std::invalid_argument("Input type must be f16/f32"); - if((outType != af::dtype::f32) && (outType != af::dtype::f16)) - throw std::invalid_argument("Output type must be f16/f32"); - if(idxArr.size() != 4) - throw std::invalid_argument("Index array vector must be length 4"); - - af::dim4 idxPtr; - // Extract raw device pointers for dimensions - // that have an array as af::index variable - - // Dtype checking - std::vector < af::dtype > idxTypes; - for(int i = 0; i < 4; i++) { - if(idxArr[i].isempty()) { - idxPtr[i] = 0; - continue; - } - if(validIndexTypes.find(idxArr[i].type()) == validIndexTypes.end()) - throw std::invalid_argument( - "Index type must be one of s32/s64/u32/u64, observed type is " - + std::to_string(idxArr[i].type()) - ); - idxTypes.push_back(idxArr[i].type()); - idxPtr[i] = (dim_t) (idxArr[i].device < void > ()); +namespace detail { + + void advancedIndex( + const af::array& inp, + const af::dim4& idxStart, + const af::dim4& idxEnd, + const af::dim4& outDims, + const std::vector& idxArr, + af::array& out + ) { + auto inpType = inp.type(); + auto outType = out.type(); + 
+ if((inpType != af::dtype::f32) && (inpType != af::dtype::f16)) + throw std::invalid_argument("Input type must be f16/f32"); + if((outType != af::dtype::f32) && (outType != af::dtype::f16)) + throw std::invalid_argument("Output type must be f16/f32"); + if(idxArr.size() != 4) + throw std::invalid_argument("Index array vector must be length 4"); + + af::dim4 idxPtr; + // Extract raw device pointers for dimensions + // that have an array as af::index variable + + // Dtype checking + std::vector idxTypes; + for(int i = 0; i < 4; i++) { + if(idxArr[i].isempty()) { + idxPtr[i] = 0; + continue; } - for(int i = 0; i + 1 < idxTypes.size(); i++) - if(idxTypes[i] != idxTypes[i + 1]) - throw std::invalid_argument( - "Index type must be the same across all dimensions" - ); - - af::array inpCast = inp; - af::array outCast = out; - if(inpType == af::dtype::f16) - inpCast = inp.as(af::dtype::f32); - if(outType == af::dtype::f16) - outCast = out.as(af::dtype::f32); - - void* inpRawPtr = inpCast.device < void > (); - void* outRawPtr = outCast.device < void > (); - af::array arrIdxPtr(4, idxPtr.get()); - af::array arrIdxEnd(4, idxEnd.get()); - af::array arrIdxStart(4, idxStart.get()); - af::array arrOutDims(4, outDims.get()); - void* arrIdxStartDev = arrIdxStart.device < void > (); - void* arrIdxEndDev = arrIdxEnd.device < void > (); - void* arrOutDimsDev = arrOutDims.device < void > (); - void* arrIdxPtrDev = arrIdxPtr.device < void > (); - - cudaStream_t stream = afcu::getStream(af::getDevice()); - if(idxTypes.size() == 0 || idxTypes[0] == af::dtype::s32) - advancedIndexKernel < float, int32_t > << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( - static_cast < const float* > (inpRawPtr), - static_cast < const dim_t * > (arrIdxStartDev), - static_cast < const dim_t * > (arrIdxEndDev), - static_cast < const dim_t * > (arrOutDimsDev), - static_cast < const dim_t * > (arrIdxPtrDev), - static_cast < float* > (outRawPtr)); - else if(idxTypes[0] == af::dtype::s64) - advancedIndexKernel < 
float, int64_t > << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( - static_cast < const float* > (inpRawPtr), - static_cast < const dim_t * > (arrIdxStartDev), - static_cast < const dim_t * > (arrIdxEndDev), - static_cast < const dim_t * > (arrOutDimsDev), - static_cast < const dim_t * > (arrIdxPtrDev), - static_cast < float* > (outRawPtr)); - else if(idxTypes[0] == af::dtype::u32) - advancedIndexKernel < float, uint32_t > << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( - static_cast < const float* > (inpRawPtr), - static_cast < const dim_t * > (arrIdxStartDev), - static_cast < const dim_t * > (arrIdxEndDev), - static_cast < const dim_t * > (arrOutDimsDev), - static_cast < const dim_t * > (arrIdxPtrDev), - static_cast < float* > (outRawPtr)); - else if(idxTypes[0] == af::dtype::u64) - advancedIndexKernel < float, uint64_t > << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( - static_cast < const float* > (inpRawPtr), - static_cast < const dim_t * > (arrIdxStartDev), - static_cast < const dim_t * > (arrIdxEndDev), - static_cast < const dim_t * > (arrOutDimsDev), - static_cast < const dim_t * > (arrIdxPtrDev), - static_cast < float* > (outRawPtr)); - else - throw std::invalid_argument("Index type must be one of s32/s64/u32/u64"); - if(cudaPeekAtLastError() != cudaSuccess) - throw std::runtime_error( - "ArrayFireTensor advancedIndex kernel CUDA failure" + if(validIndexTypes.find(idxArr[i].type()) == validIndexTypes.end()) + throw std::invalid_argument( + "Index type must be one of s32/s64/u32/u64, observed type is " + + std::to_string(idxArr[i].type()) ); - - inpCast.unlock(); - outCast.unlock(); - arrIdxStart.unlock(); - arrIdxEnd.unlock(); - arrOutDims.unlock(); - arrIdxPtr.unlock(); - for(const auto& arr : idxArr) - arr.unlock(); - - out = outCast; - if(outType == af::dtype::f16) - out = outCast.as(af::dtype::f16); + idxTypes.push_back(idxArr[i].type()); + idxPtr[i] = (dim_t) (idxArr[i].device()); } + for(int i = 0; i + 1 < idxTypes.size(); i++) + if(idxTypes[i] != idxTypes[i + 
1]) + throw std::invalid_argument( + "Index type must be the same across all dimensions" + ); + + af::array inpCast = inp; + af::array outCast = out; + if(inpType == af::dtype::f16) + inpCast = inp.as(af::dtype::f32); + if(outType == af::dtype::f16) + outCast = out.as(af::dtype::f32); + + void* inpRawPtr = inpCast.device(); + void* outRawPtr = outCast.device(); + af::array arrIdxPtr(4, idxPtr.get()); + af::array arrIdxEnd(4, idxEnd.get()); + af::array arrIdxStart(4, idxStart.get()); + af::array arrOutDims(4, outDims.get()); + void* arrIdxStartDev = arrIdxStart.device(); + void* arrIdxEndDev = arrIdxEnd.device(); + void* arrOutDimsDev = arrOutDims.device(); + void* arrIdxPtrDev = arrIdxPtr.device(); + + cudaStream_t stream = afcu::getStream(af::getDevice()); + if(idxTypes.size() == 0 || idxTypes[0] == af::dtype::s32) + advancedIndexKernel << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( + static_cast(inpRawPtr), + static_cast(arrIdxStartDev), + static_cast(arrIdxEndDev), + static_cast(arrOutDimsDev), + static_cast(arrIdxPtrDev), + static_cast(outRawPtr)); + else if(idxTypes[0] == af::dtype::s64) + advancedIndexKernel << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( + static_cast(inpRawPtr), + static_cast(arrIdxStartDev), + static_cast(arrIdxEndDev), + static_cast(arrOutDimsDev), + static_cast(arrIdxPtrDev), + static_cast(outRawPtr)); + else if(idxTypes[0] == af::dtype::u32) + advancedIndexKernel << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( + static_cast(inpRawPtr), + static_cast(arrIdxStartDev), + static_cast(arrIdxEndDev), + static_cast(arrOutDimsDev), + static_cast(arrIdxPtrDev), + static_cast(outRawPtr)); + else if(idxTypes[0] == af::dtype::u64) + advancedIndexKernel << < GRID_SIZE, BLOCK_SIZE, 0, stream >> > ( + static_cast(inpRawPtr), + static_cast(arrIdxStartDev), + static_cast(arrIdxEndDev), + static_cast(arrOutDimsDev), + static_cast(arrIdxPtrDev), + static_cast(outRawPtr)); + else + throw std::invalid_argument("Index type must be one of s32/s64/u32/u64"); + 
if(cudaPeekAtLastError() != cudaSuccess) + throw std::runtime_error( + "ArrayFireTensor advancedIndex kernel CUDA failure" + ); + + inpCast.unlock(); + outCast.unlock(); + arrIdxStart.unlock(); + arrIdxEnd.unlock(); + arrOutDims.unlock(); + arrIdxPtr.unlock(); + for(const auto& arr : idxArr) + arr.unlock(); + + out = outCast; + if(outType == af::dtype::f16) + out = outCast.as(af::dtype::f16); + } - } // namespace detail +} // namespace detail } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp b/flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp index c5896bb..968bf92 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireBackend.cpp @@ -133,9 +133,7 @@ TensorBackendType ArrayFireBackend::backendType() const { /* -------------------------- Compute Functions -------------------------- */ -void ArrayFireBackend::eval(const Tensor& tensor) { - af::eval(toArray(tensor)); -} +void ArrayFireBackend::eval(const Tensor& tensor) { af::eval(toArray(tensor)); } const Stream& ArrayFireBackend::getStreamOfArray( const af::array& arr @@ -205,9 +203,7 @@ void ArrayFireBackend::setMemMgrFlushInterval(const size_t interval) { /* -------------------------- Rand Functions -------------------------- */ -void ArrayFireBackend::setSeed(const int seed) { - af::setSeed(seed); -} +void ArrayFireBackend::setSeed(const int seed) { af::setSeed(seed); } Tensor ArrayFireBackend::randn(const Shape& shape, dtype type) { return toTensor( @@ -399,7 +395,5 @@ Tensor ArrayFireBackend::argsort( return toTensor(std::move(indices), input.ndim()); } -void ArrayFireBackend::print(const Tensor& tensor) { - af::print("ArrayFireTensor", toArray(tensor)); -} +void ArrayFireBackend::print(const Tensor& tensor) { af::print("ArrayFireTensor", toArray(tensor)); } } // namespace fl diff --git a/flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp b/flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp index b405112..ed581ae 
100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp @@ -126,9 +126,7 @@ af::array::array_proxy ArrayFireTensor::IndexedArrayComponent::get( } } -af::array& ArrayFireTensor::ArrayComponent::get(const ArrayFireTensor& inst) { - return *(inst.arrayHandle_); -} +af::array& ArrayFireTensor::ArrayComponent::get(const ArrayFireTensor& inst) { return *(inst.arrayHandle_); } const af::array& ArrayFireTensor::getHandle() const { return const_cast(this)->getHandle(); @@ -197,17 +195,11 @@ const Shape& ArrayFireTensor::shape() { return shape_; } -fl::dtype ArrayFireTensor::type() { - return detail::afToFlType(getHandle().type()); -} +fl::dtype ArrayFireTensor::type() { return detail::afToFlType(getHandle().type()); } -bool ArrayFireTensor::isSparse() { - return getHandle().issparse(); -} +bool ArrayFireTensor::isSparse() { return getHandle().issparse(); } -af::dtype ArrayFireTensor::afHandleType() { - return arrayHandle_->type(); -} +af::dtype ArrayFireTensor::afHandleType() { return arrayHandle_->type(); } Location ArrayFireTensor::location() { switch(af::getBackendId(getHandle())) { @@ -223,21 +215,13 @@ Location ArrayFireTensor::location() { } } -void ArrayFireTensor::scalar(void* out) { - AF_CHECK(af_get_scalar(out, getHandle().get())); -} +void ArrayFireTensor::scalar(void* out) { AF_CHECK(af_get_scalar(out, getHandle().get())); } -void ArrayFireTensor::device(void** out) { - AF_CHECK(af_get_device_ptr(out, getHandle().get())); -} +void ArrayFireTensor::device(void** out) { AF_CHECK(af_get_device_ptr(out, getHandle().get())); } -void ArrayFireTensor::host(void* out) { - AF_CHECK(af_get_data_ptr(out, getHandle().get())); -} +void ArrayFireTensor::host(void* out) { AF_CHECK(af_get_data_ptr(out, getHandle().get())); } -void ArrayFireTensor::unlock() { - AF_CHECK(af_unlock_array(getHandle().get())); -} +void ArrayFireTensor::unlock() { AF_CHECK(af_unlock_array(getHandle().get())); } bool 
ArrayFireTensor::isLocked() { bool res; @@ -250,13 +234,9 @@ bool ArrayFireTensor::isLocked() { return res; } -bool ArrayFireTensor::isContiguous() { - return af::isLinear(getHandle()); -} +bool ArrayFireTensor::isContiguous() { return af::isLinear(getHandle()); } -Shape ArrayFireTensor::strides() { - return detail::afToFlDims(af::getStrides(getHandle()), numDims()); -} +Shape ArrayFireTensor::strides() { return detail::afToFlDims(af::getStrides(getHandle()), numDims()); } const Stream& ArrayFireTensor::stream() const { // TODO indexing is unlikely to change the stream associated with a tensor. @@ -502,14 +482,14 @@ af::array ArrayFireTensor::adjustInPlaceOperandDims(const Tensor& operand) { return doModdims ? af::moddims(operandArr, newDims) : operandArr; } -#define ASSIGN_OP_TENSOR(FUN, AF_OP) \ - void ArrayFireTensor::FUN(const Tensor& tensor) { \ - std::visit( \ - [&tensor, this](auto&& arr) { \ - arr.get(*this) AF_OP this->adjustInPlaceOperandDims(tensor); \ - }, \ - handle_ \ - ); \ +#define ASSIGN_OP_TENSOR(FUN, AF_OP) \ + void ArrayFireTensor::FUN(const Tensor& tensor) { \ + std::visit( \ + [&tensor, this](auto&& arr) { \ + arr.get(*this) AF_OP this->adjustInPlaceOperandDims(tensor); \ + }, \ + handle_ \ + ); \ } #define ASSIGN_OP(FUN, AF_OP) \ diff --git a/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp b/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp index 91ddb82..e0b6244 100644 --- a/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp +++ b/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp @@ -116,17 +116,11 @@ CachingMemoryManager::CachingMemoryManager( void CachingMemoryManager::initialize() {} -void CachingMemoryManager::setRecyclingSizeLimit(size_t limit) { - recyclingSizeLimit_ = limit; -} +void CachingMemoryManager::setRecyclingSizeLimit(size_t limit) { recyclingSizeLimit_ = limit; } -void CachingMemoryManager::setSplitSizeLimit(size_t limit) { - splitSizeLimit_ = limit; -} +void 
CachingMemoryManager::setSplitSizeLimit(size_t limit) { splitSizeLimit_ = limit; } -void CachingMemoryManager::shutdown() { - signalMemoryCleanup(); -} +void CachingMemoryManager::shutdown() { signalMemoryCleanup(); } void CachingMemoryManager::addMemoryManagement(int device) { if(deviceMemInfos_.find(device) != deviceMemInfos_.end()) @@ -409,9 +403,7 @@ void CachingMemoryManager::userLock(const void* ptr) { it->second->userLock_ = true; } -void CachingMemoryManager::userUnlock(const void* ptr) { - this->unlock(const_cast(ptr), true); -} +void CachingMemoryManager::userUnlock(const void* ptr) { this->unlock(const_cast(ptr), true); } bool CachingMemoryManager::isUserLocked(const void* ptr) { if(!ptr) diff --git a/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp b/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp index 959f8f9..6ba6aa6 100644 --- a/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp +++ b/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp @@ -95,13 +95,9 @@ DefaultMemoryManager::DefaultMemoryManager( this->maxBuffers = std::max(1, std::stoi(std::string(c))); } -void DefaultMemoryManager::initialize() { - this->setMaxMemorySize(); -} +void DefaultMemoryManager::initialize() { this->setMaxMemorySize(); } -void DefaultMemoryManager::shutdown() { - signalMemoryCleanup(); -} +void DefaultMemoryManager::shutdown() { signalMemoryCleanup(); } void DefaultMemoryManager::addMemoryManagement(int device) { // If there is a memory manager allocated for this device id, we might @@ -350,9 +346,7 @@ void DefaultMemoryManager::userLock(const void* ptr) { } } -void DefaultMemoryManager::userUnlock(const void* ptr) { - this->unlock(const_cast(ptr), true); -} +void DefaultMemoryManager::userUnlock(const void* ptr) { this->unlock(const_cast(ptr), true); } bool DefaultMemoryManager::isUserLocked(const void* ptr) { MemoryInfo& current = this->getCurrentMemoryInfo(); @@ -379,9 +373,7 @@ size_t DefaultMemoryManager::getMaxBytes() { 
return this->getCurrentMemoryInfo().maxBytes; } -unsigned DefaultMemoryManager::getMaxBuffers() { - return this->maxBuffers; -} +unsigned DefaultMemoryManager::getMaxBuffers() { return this->maxBuffers; } bool DefaultMemoryManager::checkMemoryLimit() { const MemoryInfo& current = this->getCurrentMemoryInfo(); diff --git a/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.cpp b/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.cpp index 8aea02a..ecbb448 100644 --- a/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.cpp +++ b/flashlight/fl/tensor/backend/af/mem/MemoryManagerAdapter.cpp @@ -44,17 +44,13 @@ MemoryManagerAdapter::~MemoryManagerAdapter() { } } -void MemoryManagerAdapter::setLogStream(std::ostream* logStream) { - logStream_ = logStream; -} +void MemoryManagerAdapter::setLogStream(std::ostream* logStream) { logStream_ = logStream; } std::ostream* MemoryManagerAdapter::getLogStream() const { return logStream_; } -void MemoryManagerAdapter::setLoggingEnabled(bool log) { - loggingEnabled_ = log; -} +void MemoryManagerAdapter::setLoggingEnabled(bool log) { loggingEnabled_ = log; } void MemoryManagerAdapter::setLogFlushInterval(size_t interval) { if(interval < 1) diff --git a/flashlight/fl/tensor/backend/stub/StubBackend.cpp b/flashlight/fl/tensor/backend/stub/StubBackend.cpp index d3b6094..11403b6 100644 --- a/flashlight/fl/tensor/backend/stub/StubBackend.cpp +++ b/flashlight/fl/tensor/backend/stub/StubBackend.cpp @@ -71,17 +71,11 @@ void StubBackend::setMemMgrFlushInterval(const size_t /* interval */) { /* -------------------------- Rand Functions -------------------------- */ -void StubBackend::setSeed(const int /* seed */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +void StubBackend::setSeed(const int /* seed */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::randn(const Shape& /* shape */, dtype /* type */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::randn(const Shape& /* shape */, dtype /* type */) { 
FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::rand(const Shape& /* shape */, dtype /* type */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::rand(const Shape& /* shape */, dtype /* type */) { FL_STUB_BACKEND_UNIMPLEMENTED; } /* --------------------------- Tensor Operators --------------------------- */ @@ -116,40 +110,30 @@ FL_STUB_BACKEND_CREATE_FUN_LITERAL_DEF(const bool&); FL_STUB_BACKEND_CREATE_FUN_LITERAL_DEF(const short&); FL_STUB_BACKEND_CREATE_FUN_LITERAL_DEF(const unsigned short&); -Tensor StubBackend::identity(const Dim /* dim */, const dtype /* type */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::identity(const Dim /* dim */, const dtype /* type */) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::arange( const Shape& /* shape */, const Dim /* seqDim */, const dtype /* type */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::iota( const Shape& /* dims */, const Shape& /* tileDims */, const dtype /* type */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } /************************ Shaping and Indexing *************************/ Tensor StubBackend::reshape( const Tensor& /* tensor */, const Shape& /* shape */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::transpose( const Tensor& /* tensor */, const Shape& /* axes */ /* = {} */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::tile(const Tensor& /* tensor */, const Shape& /* shape */) { FL_STUB_BACKEND_UNIMPLEMENTED; @@ -158,83 +142,47 @@ Tensor StubBackend::tile(const Tensor& /* tensor */, const Shape& /* shape */) { Tensor StubBackend::concatenate( const std::vector& /* tensors */, const unsigned /* axis */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::nonzero(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor 
StubBackend::nonzero(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::pad( const Tensor& /* input */, const std::vector>& /* padWidths */, const PadType /* type */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } /************************** Unary Operators ***************************/ -Tensor StubBackend::exp(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::exp(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::log(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::log(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::negative(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::negative(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::logicalNot(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::logicalNot(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::log1p(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::log1p(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::sin(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::sin(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::cos(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::cos(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::sqrt(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::sqrt(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::tanh(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::tanh(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } 
-Tensor StubBackend::floor(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::floor(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::ceil(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::ceil(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::rint(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::rint(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::absolute(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::absolute(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::sigmoid(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::sigmoid(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::erf(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::erf(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::flip(const Tensor& /* tensor */, const unsigned /* dim */) { FL_STUB_BACKEND_UNIMPLEMENTED; @@ -244,45 +192,29 @@ Tensor StubBackend::clip( const Tensor& /* tensor */, const Tensor& /* low */, const Tensor& /* high */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::roll( const Tensor& /* tensor */, const int /* shift */, const unsigned /* axis */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::isnan(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::isnan(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::isinf(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::isinf(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::sign(const 
Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::sign(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::tril(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::tril(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::triu(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::triu(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::where( const Tensor& /* condition */, const Tensor& /* x */, const Tensor& /* y */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } void StubBackend::topk( Tensor& /* values */, @@ -291,17 +223,13 @@ void StubBackend::topk( const unsigned /* k */, const Dim /* axis */, const SortMode /* sortMode */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::sort( const Tensor& /* input */, const Dim /* axis */, const SortMode /* sortMode */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } void StubBackend::sort( Tensor& /* values */, @@ -309,17 +237,13 @@ void StubBackend::sort( const Tensor& /* input */, const Dim /* axis */, const SortMode /* sortMode */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::argsort( const Tensor& /* input */, const Dim /* axis */, const SortMode /* sortMode */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } /************************** Binary Operators ***************************/ #define FL_AF_BINARY_OP_TYPE_DEF(FUNC, OP, TYPE) \ @@ -396,9 +320,7 @@ Tensor StubBackend::maximum(const Tensor& /* lhs */, const Tensor& /* rhs */) { FL_STUB_BACKEND_UNIMPLEMENTED; } -Tensor StubBackend::power(const Tensor& /* lhs */, const Tensor& /* rhs */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +Tensor StubBackend::power(const Tensor& /* lhs */, const Tensor& 
/* rhs */) { FL_STUB_BACKEND_UNIMPLEMENTED; } /************************** BLAS ***************************/ @@ -407,9 +329,7 @@ Tensor StubBackend::matmul( const Tensor& /* rhs */, MatrixProperty /* lhsProp */, MatrixProperty /* rhsProp */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } /************************** Reductions ***************************/ @@ -417,17 +337,13 @@ Tensor StubBackend::amin( const Tensor& /* input */, const std::vector& /* axes */, const bool /* keepDims */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::amax( const Tensor& /* input */, const std::vector& /* axes */, const bool /* keepDims */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } void StubBackend::min( Tensor& /* values */, @@ -435,9 +351,7 @@ void StubBackend::min( const Tensor& /* input */, const unsigned /* axis */, const bool /* keepDims */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } void StubBackend::max( Tensor& /* values */, @@ -445,109 +359,81 @@ void StubBackend::max( const Tensor& /* input */, const unsigned /* axis */, const bool /* keepDims */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::sum( const Tensor& /* input */, const std::vector& /* axes */, const bool /* keepDims */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::cumsum( const Tensor& /* input */, const unsigned /* axis */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::argmax( const Tensor& /* input */, const unsigned /* axis */, const bool /* keepDims */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::argmin( const Tensor& /* input */, const unsigned /* axis */, const bool /* keepDims */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { 
FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::mean( const Tensor& /* input */, const std::vector& /* axes */, const bool /* keepDims */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::median( const Tensor& /* input */, const std::vector& /* axes */, const bool /* keepDims */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::var( const Tensor& /* input */, const std::vector& /* axes */, const bool /* bias */, const bool /* keepDims */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::std( const Tensor& /* input */, const std::vector& /* axes */, const bool /* keepDims */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::norm( const Tensor& /* input */, const std::vector& /* axes */, double /* p */ /* = 2 */, const bool /* keepDims */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::countNonzero( const Tensor& /* input */, const std::vector& /* axes */, const bool /* keepDims */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::any( const Tensor& /* input */, const std::vector& /* axes */, const bool /* keepDims */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } Tensor StubBackend::all( const Tensor& /* input */, const std::vector& /* axes */, const bool /* keepDims */ -) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +) { FL_STUB_BACKEND_UNIMPLEMENTED; } -void StubBackend::print(const Tensor& /* tensor */) { - FL_STUB_BACKEND_UNIMPLEMENTED; -} +void StubBackend::print(const Tensor& /* tensor */) { FL_STUB_BACKEND_UNIMPLEMENTED; } } // namespace fl diff --git a/flashlight/fl/tensor/backend/stub/StubTensor.cpp b/flashlight/fl/tensor/backend/stub/StubTensor.cpp index 3c2c6a9..d218e74 100644 --- a/flashlight/fl/tensor/backend/stub/StubTensor.cpp +++ 
b/flashlight/fl/tensor/backend/stub/StubTensor.cpp @@ -36,13 +36,9 @@ std::unique_ptr StubTensor::clone() const { FL_STUB_TENSOR_UNIMPLEMENTED; } -Tensor StubTensor::copy() { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +Tensor StubTensor::copy() { FL_STUB_TENSOR_UNIMPLEMENTED; } -Tensor StubTensor::shallowCopy() { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +Tensor StubTensor::shallowCopy() { FL_STUB_TENSOR_UNIMPLEMENTED; } TensorBackendType StubTensor::backendType() const { FL_STUB_TENSOR_UNIMPLEMENTED; @@ -52,61 +48,35 @@ TensorBackend& StubTensor::backend() const { FL_STUB_TENSOR_UNIMPLEMENTED; } -const Shape& StubTensor::shape() { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +const Shape& StubTensor::shape() { FL_STUB_TENSOR_UNIMPLEMENTED; } -fl::dtype StubTensor::type() { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +fl::dtype StubTensor::type() { FL_STUB_TENSOR_UNIMPLEMENTED; } -bool StubTensor::isSparse() { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +bool StubTensor::isSparse() { FL_STUB_TENSOR_UNIMPLEMENTED; } -Location StubTensor::location() { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +Location StubTensor::location() { FL_STUB_TENSOR_UNIMPLEMENTED; } -void StubTensor::scalar(void* /* out */) { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +void StubTensor::scalar(void* /* out */) { FL_STUB_TENSOR_UNIMPLEMENTED; } -void StubTensor::device(void** /* out */) { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +void StubTensor::device(void** /* out */) { FL_STUB_TENSOR_UNIMPLEMENTED; } -void StubTensor::host(void* /* out */) { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +void StubTensor::host(void* /* out */) { FL_STUB_TENSOR_UNIMPLEMENTED; } -void StubTensor::unlock() { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +void StubTensor::unlock() { FL_STUB_TENSOR_UNIMPLEMENTED; } -bool StubTensor::isLocked() { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +bool StubTensor::isLocked() { FL_STUB_TENSOR_UNIMPLEMENTED; } -bool StubTensor::isContiguous() { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +bool StubTensor::isContiguous() { FL_STUB_TENSOR_UNIMPLEMENTED; } -Shape 
StubTensor::strides() { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +Shape StubTensor::strides() { FL_STUB_TENSOR_UNIMPLEMENTED; } const Stream& StubTensor::stream() const { FL_STUB_TENSOR_UNIMPLEMENTED; } -Tensor StubTensor::astype(const dtype /* type */) { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +Tensor StubTensor::astype(const dtype /* type */) { FL_STUB_TENSOR_UNIMPLEMENTED; } -Tensor StubTensor::index(const std::vector& /* indices */) { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +Tensor StubTensor::index(const std::vector& /* indices */) { FL_STUB_TENSOR_UNIMPLEMENTED; } Tensor StubTensor::flatten() const { FL_STUB_TENSOR_UNIMPLEMENTED; @@ -116,9 +86,7 @@ Tensor StubTensor::flat(const Index& /* idx */) const { FL_STUB_TENSOR_UNIMPLEMENTED; } -Tensor StubTensor::asContiguousTensor() { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +Tensor StubTensor::asContiguousTensor() { FL_STUB_TENSOR_UNIMPLEMENTED; } void StubTensor::setContext(void* /* context */) { // Used to store arbitrary data on a Tensor - can be a noop. @@ -130,13 +98,9 @@ void* StubTensor::getContext() { FL_STUB_TENSOR_UNIMPLEMENTED; } -std::string StubTensor::toString() { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +std::string StubTensor::toString() { FL_STUB_TENSOR_UNIMPLEMENTED; } -std::ostream& StubTensor::operator<<(std::ostream& /* ostr */) { - FL_STUB_TENSOR_UNIMPLEMENTED; -} +std::ostream& StubTensor::operator<<(std::ostream& /* ostr */) { FL_STUB_TENSOR_UNIMPLEMENTED; } /******************** Assignment Operators ********************/ #define FL_STUB_TENSOR_ASSIGN_OP_TYPE(OP, TYPE) \ diff --git a/flashlight/fl/test/autograd/AutogradNormalizationTest.cpp b/flashlight/fl/test/autograd/AutogradNormalizationTest.cpp index 5eb4099..bf5e215 100644 --- a/flashlight/fl/test/autograd/AutogradNormalizationTest.cpp +++ b/flashlight/fl/test/autograd/AutogradNormalizationTest.cpp @@ -267,7 +267,6 @@ TEST(AutogradNormalizationTest, BatchNormJacobian) { ); }; - ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnIn, input, 1e-2, 1e-4, {&weight, 
&bias})); auto funcBnWt = [&](Variable& wt) { @@ -285,7 +284,6 @@ TEST(AutogradNormalizationTest, BatchNormJacobian) { }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcBnWt, weight, 1e-2, 1e-4, {&input, &bias})); - auto funcBnBs = [&](Variable& bs) { return batchnorm( input, diff --git a/flashlight/fl/test/common/DynamicBenchmarkTest.cpp b/flashlight/fl/test/common/DynamicBenchmarkTest.cpp index d22b1ec..f228634 100644 --- a/flashlight/fl/test/common/DynamicBenchmarkTest.cpp +++ b/flashlight/fl/test/common/DynamicBenchmarkTest.cpp @@ -141,11 +141,11 @@ TEST_F(DynamicBenchmark, DynamicBenchmarkMatmul) { ->currentOption(); dynamicBench->audit( [size]() { - auto a = fl::rand({size, size}); - auto b = fl::rand({size, size}); - auto c = fl::matmul(a, b); - fl::eval(c); - } + auto a = fl::rand({size, size}); + auto b = fl::rand({size, size}); + auto c = fl::matmul(a, b); + fl::eval(c); + } ); } auto ops = dynamicBench->getOptions>(); diff --git a/flashlight/fl/test/common/SerializationTest.cpp b/flashlight/fl/test/common/SerializationTest.cpp index 00cc2c4..27f8717 100644 --- a/flashlight/fl/test/common/SerializationTest.cpp +++ b/flashlight/fl/test/common/SerializationTest.cpp @@ -223,8 +223,7 @@ struct WeirdTransform { }; template -WeirdTransform weirdTransform(T&& t) { - return WeirdTransform{std::forward(t)}; +WeirdTransform weirdTransform(T&& t) { return WeirdTransform{std::forward(t)}; } struct SerializeViaTemporary { diff --git a/flashlight/fl/test/common/UtilsTest.cpp b/flashlight/fl/test/common/UtilsTest.cpp index 282b3c2..dd85550 100644 --- a/flashlight/fl/test/common/UtilsTest.cpp +++ b/flashlight/fl/test/common/UtilsTest.cpp @@ -23,7 +23,7 @@ static std::function makeSucceedsAfterIters(int iters) { return 42; else throw std::runtime_error("bleh"); - }; + }; } static std::function makeSucceedsAfterMs(double ms) { @@ -37,7 +37,7 @@ static std::function makeSucceedsAfterMs(double ms) { return 42; else throw std::runtime_error("bleh"); - }; + }; } template @@ 
-50,8 +50,8 @@ std::future::type> retryAsync( return std::async( std::launch::async, [ = ]() { - return retryWithBackoff(initial, factor, iters, f); - } + return retryWithBackoff(initial, factor, iters, f); + } ); } diff --git a/flashlight/fl/test/dataset/DatasetTest.cpp b/flashlight/fl/test/dataset/DatasetTest.cpp index 258e53e..8962f1d 100644 --- a/flashlight/fl/test/dataset/DatasetTest.cpp +++ b/flashlight/fl/test/dataset/DatasetTest.cpp @@ -276,12 +276,12 @@ TEST(DatasetTest, FileBlobDataset) { blob.setHostTransform( 0, [](void* ptr, fl::Shape size, fl::dtype /* type */) { - float* ptrFl = (float*) ptr; - for(int64_t i = 0; i < size.elements(); i++) - ptrFl[i] += 1; - return Tensor::fromBuffer(size, ptrFl, MemoryLocation::Host); - } - ); + float* ptrFl = (float*) ptr; + for(int64_t i = 0; i < size.elements(); i++) + ptrFl[i] += 1; + return Tensor::fromBuffer(size, ptrFl, MemoryLocation::Host); + } + ); check(blob); for(auto& vec : data) if(!vec.empty()) @@ -314,10 +314,10 @@ TEST(DatasetTest, FileBlobDataset) { auto device = fl::getDevice(); workers.emplace_back( [i, blob, nperworker, device, &thdata]() { - fl::setDevice(device); - for(int j = 0; j < nperworker; j++) - thdata[i * nperworker + j] = blob->get(i * nperworker + j); - } + fl::setDevice(device); + for(int j = 0; j < nperworker; j++) + thdata[i * nperworker + j] = blob->get(i * nperworker + j); + } ); } for(int i = 0; i < nworker; i++) @@ -355,10 +355,10 @@ TEST(DatasetTest, FileBlobDataset) { for(int i = 0; i < nworker; i++) workers.emplace_back( [i, blob, nperworker, device, &data]() { - fl::setDevice(device); - for(int j = 0; j < nperworker; j++) - blob->add(data[i * nperworker + j]); - } + fl::setDevice(device); + for(int j = 0; j < nperworker; j++) + blob->add(data[i * nperworker + j]); + } ); for(int i = 0; i < nworker; i++) workers[i].join(); @@ -453,12 +453,12 @@ TEST(DatasetTest, MemoryBlobDataset) { blob.setHostTransform( 0, [](void* ptr, fl::Shape size, fl::dtype /* type */) { - float* 
ptrFl = (float*) ptr; - for(int64_t i = 0; i < size.elements(); i++) - ptrFl[i] += 1; - return Tensor::fromBuffer(size, ptrFl, MemoryLocation::Host); - } - ); + float* ptrFl = (float*) ptr; + for(int64_t i = 0; i < size.elements(); i++) + ptrFl[i] += 1; + return Tensor::fromBuffer(size, ptrFl, MemoryLocation::Host); + } + ); check(blob); } @@ -472,10 +472,10 @@ TEST(DatasetTest, MemoryBlobDataset) { auto device = fl::getDevice(); workers.emplace_back( [i, &blob, nperworker, device, &thdata]() { - fl::setDevice(device); - for(int j = 0; j < nperworker; j++) - thdata[i * nperworker + j] = blob.get(i * nperworker + j); - } + fl::setDevice(device); + for(int j = 0; j < nperworker; j++) + thdata[i * nperworker + j] = blob.get(i * nperworker + j); + } ); } for(int i = 0; i < nworker; i++) @@ -509,10 +509,10 @@ TEST(DatasetTest, MemoryBlobDataset) { for(int i = 0; i < nworker; i++) workers.emplace_back( [i, &wblob, nperworker, device, &data]() { - fl::setDevice(device); - for(int j = 0; j < nperworker; j++) - wblob.add(data[i * nperworker + j]); - } + fl::setDevice(device); + for(int j = 0; j < nperworker; j++) + wblob.add(data[i * nperworker + j]); + } ); for(int i = 0; i < nworker; i++) workers[i].join(); diff --git a/flashlight/fl/test/nn/ModuleTest.cpp b/flashlight/fl/test/nn/ModuleTest.cpp index 3fd8d61..efb5c75 100644 --- a/flashlight/fl/test/nn/ModuleTest.cpp +++ b/flashlight/fl/test/nn/ModuleTest.cpp @@ -49,9 +49,7 @@ class ContainerTestClass : public Sequential { return std::make_unique(*this); } - void addParam(const Variable& param) { - params_.push_back(param); - } + void addParam(const Variable& param) { params_.push_back(param); } }; class ModuleTestF16 : public ::testing::Test { diff --git a/flashlight/fl/test/nn/NNSerializationTest.cpp b/flashlight/fl/test/nn/NNSerializationTest.cpp index ca5106f..fb86f5a 100644 --- a/flashlight/fl/test/nn/NNSerializationTest.cpp +++ b/flashlight/fl/test/nn/NNSerializationTest.cpp @@ -27,9 +27,7 @@ class ContainerTestClass 
: public Sequential { public: ContainerTestClass() = default; - void addParam(const Variable& param) { - params_.push_back(param); - } + void addParam(const Variable& param) { params_.push_back(param); } private: FL_SAVE_LOAD_WITH_BASE(Sequential) diff --git a/flashlight/fl/test/runtime/CUDAStreamTest.cpp b/flashlight/fl/test/runtime/CUDAStreamTest.cpp index 8a53e6b..0a5de30 100644 --- a/flashlight/fl/test/runtime/CUDAStreamTest.cpp +++ b/flashlight/fl/test/runtime/CUDAStreamTest.cpp @@ -46,7 +46,6 @@ TEST(CUDAStreamTest, createUnmanaged) { } } - TEST(CUDAStreamTest, unmanagedWrapper) { auto& manager = DeviceManager::getInstance(); int numCudaDevices = 0; diff --git a/flashlight/fl/test/tensor/TensorBaseTest.cpp b/flashlight/fl/test/tensor/TensorBaseTest.cpp index 730dfd9..22dd1b9 100644 --- a/flashlight/fl/test/tensor/TensorBaseTest.cpp +++ b/flashlight/fl/test/tensor/TensorBaseTest.cpp @@ -512,7 +512,6 @@ void assertScalarBehavior(fl::dtype type) { << "dtype: " << type << ", ScalarArgType: " << dtype_traits::getName(); - ScalarArgType val = static_cast(rand()); auto a = fl::full({5, 6}, val, type); diff --git a/flashlight/fl/test/tensor/af/ArrayFireCPUStreamTest.cpp b/flashlight/fl/test/tensor/af/ArrayFireCPUStreamTest.cpp index b59edc0..66a21db 100644 --- a/flashlight/fl/test/tensor/af/ArrayFireCPUStreamTest.cpp +++ b/flashlight/fl/test/tensor/af/ArrayFireCPUStreamTest.cpp @@ -17,7 +17,6 @@ using fl::Stream; using fl::StreamType; using fl::ArrayFireCPUStream; - int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); fl::init(); diff --git a/flashlight/fl/test/tensor/af/ArrayFireTensorBaseTest.cpp b/flashlight/fl/test/tensor/af/ArrayFireTensorBaseTest.cpp index 38937ef..503d8d5 100644 --- a/flashlight/fl/test/tensor/af/ArrayFireTensorBaseTest.cpp +++ b/flashlight/fl/test/tensor/af/ArrayFireTensorBaseTest.cpp @@ -164,9 +164,9 @@ TEST(ArrayFireTensorBaseTest, withTensorType) { Tensor t; fl::withTensorType( [&t]() { - t = fl::full({5, 5}, 6.); - t 
+= 1; - } + t = fl::full({5, 5}, 6.); + t += 1; + } ); ASSERT_TRUE(allClose(t, fl::full({5, 5}, 7.))); } diff --git a/flashlight/fl/test/tensor/af/MemoryFrameworkTest.cpp b/flashlight/fl/test/tensor/af/MemoryFrameworkTest.cpp index 90af69c..a53f920 100644 --- a/flashlight/fl/test/tensor/af/MemoryFrameworkTest.cpp +++ b/flashlight/fl/test/tensor/af/MemoryFrameworkTest.cpp @@ -191,15 +191,15 @@ class MockTestMemoryManager : public TestMemoryManager { ON_CALL(*this, initialize()).WillByDefault( Invoke( [this]() { - real_->initialize(); - } + real_->initialize(); + } ) ); ON_CALL(*this, shutdown()).WillByDefault( Invoke( [this]() { - real_->shutdown(); - } + real_->shutdown(); + } ) ); ON_CALL(*this, alloc(_, _, _, _)) @@ -210,45 +210,45 @@ class MockTestMemoryManager : public TestMemoryManager { const unsigned ndims, dim_t* dims, const unsigned elSize) { - return real_->alloc(userLock, ndims, dims, elSize); - } + return real_->alloc(userLock, ndims, dims, elSize); + } ) ); ON_CALL(*this, allocated(_)).WillByDefault( Invoke( [this](void* ptr) { - return real_->allocated(ptr); - } + return real_->allocated(ptr); + } ) ); ON_CALL(*this, unlock(_, _)) .WillByDefault( Invoke( [this](void* ptr, bool userLock) { - real_->unlock(ptr, userLock); - } + real_->unlock(ptr, userLock); + } ) ); ON_CALL(*this, signalMemoryCleanup()).WillByDefault( Invoke( [this]() { - real_->signalMemoryCleanup(); - } + real_->signalMemoryCleanup(); + } ) ); ON_CALL(*this, printInfo(_, _, _)) .WillByDefault( Invoke( [this](const char* msg, const int device, std::ostream* ostream) { - real_->printInfo(msg, device, ostream); - } + real_->printInfo(msg, device, ostream); + } ) ); ON_CALL(*this, userLock(_)).WillByDefault( Invoke( [this](const void* cPtr) { - real_->userLock(cPtr); - } + real_->userLock(cPtr); + } ) ); ON_CALL(*this, userUnlock(_)) @@ -263,16 +263,16 @@ class MockTestMemoryManager : public TestMemoryManager { ON_CALL(*this, getMemoryPressure()).WillByDefault( Invoke( [this]() { - return 
real_->getMemoryPressure(); - } + return real_->getMemoryPressure(); + } ) ); ON_CALL(*this, jitTreeExceedsMemoryPressure(_)) .WillByDefault( Invoke( [this](size_t bytes) { - return real_->jitTreeExceedsMemoryPressure(bytes); - } + return real_->jitTreeExceedsMemoryPressure(bytes); + } ) ); } diff --git a/flashlight/pkg/runtime/common/DistributedUtils.cpp b/flashlight/pkg/runtime/common/DistributedUtils.cpp index a2ab6d9..5ec02b9 100644 --- a/flashlight/pkg/runtime/common/DistributedUtils.cpp +++ b/flashlight/pkg/runtime/common/DistributedUtils.cpp @@ -54,8 +54,7 @@ Tensor allreduceGet(fl::CountMeter& mtr) { return Tensor::fromVector(mtrVal); } -Tensor allreduceGet(fl::TimeMeter& mtr) { - return fl::full({1}, mtr.value(), fl::dtype::f64); +Tensor allreduceGet(fl::TimeMeter& mtr) { return fl::full({1}, mtr.value(), fl::dtype::f64); } Tensor allreduceGet(fl::TopKMeter& mtr) { diff --git a/flashlight/pkg/runtime/plugin/ModulePlugin.cpp b/flashlight/pkg/runtime/plugin/ModulePlugin.cpp index 3528392..844fa8b 100644 --- a/flashlight/pkg/runtime/plugin/ModulePlugin.cpp +++ b/flashlight/pkg/runtime/plugin/ModulePlugin.cpp @@ -16,8 +16,6 @@ ModulePlugin::ModulePlugin(const std::string& name) : fl::Plugin(name) { std::shared_ptr ModulePlugin::arch( int64_t nFeatures, int64_t nClasses -) { - return std::shared_ptr(arch_(nFeatures, nClasses)); -} +) { return std::shared_ptr(arch_(nFeatures, nClasses)); } } // namespace fl diff --git a/flashlight/pkg/speech/audio/feature/Derivatives.cpp b/flashlight/pkg/speech/audio/feature/Derivatives.cpp index f1d1008..bb9d96f 100644 --- a/flashlight/pkg/speech/audio/feature/Derivatives.cpp +++ b/flashlight/pkg/speech/audio/feature/Derivatives.cpp @@ -10,7 +10,6 @@ #include #include - namespace fl::lib::audio { Derivatives::Derivatives(int deltawindow, int accwindow) : deltaWindow_(deltawindow), diff --git a/flashlight/pkg/speech/augmentation/SoundEffect.cpp b/flashlight/pkg/speech/augmentation/SoundEffect.cpp index 07f5912..647dcd6 100644 
--- a/flashlight/pkg/speech/augmentation/SoundEffect.cpp +++ b/flashlight/pkg/speech/augmentation/SoundEffect.cpp @@ -22,18 +22,14 @@ std::string SoundEffectChain::prettyString() const { return ss.str(); } -void SoundEffectChain::add(std::shared_ptr SoundEffect) { - soundEffects_.push_back(SoundEffect); -} +void SoundEffectChain::add(std::shared_ptr SoundEffect) { soundEffects_.push_back(SoundEffect); } void SoundEffectChain::apply(std::vector& sound) { for(std::shared_ptr& effect : soundEffects_) effect->apply(sound); } -bool SoundEffectChain::empty() { - return soundEffects_.empty(); -} +bool SoundEffectChain::empty() { return soundEffects_.empty(); } Normalize::Normalize(bool onlyIfTooHigh) : onlyIfTooHigh_(onlyIfTooHigh) {} diff --git a/flashlight/pkg/speech/augmentation/SoundEffectConfig.cpp b/flashlight/pkg/speech/augmentation/SoundEffectConfig.cpp index 9144dfb..c333f16 100644 --- a/flashlight/pkg/speech/augmentation/SoundEffectConfig.cpp +++ b/flashlight/pkg/speech/augmentation/SoundEffectConfig.cpp @@ -16,7 +16,6 @@ #include #include - using namespace ::fl::pkg::speech::sfx; namespace cereal { diff --git a/flashlight/pkg/speech/augmentation/SoundEffectUtil.cpp b/flashlight/pkg/speech/augmentation/SoundEffectUtil.cpp index 65cc940..5d3668b 100644 --- a/flashlight/pkg/speech/augmentation/SoundEffectUtil.cpp +++ b/flashlight/pkg/speech/augmentation/SoundEffectUtil.cpp @@ -22,9 +22,7 @@ int RandomNumberGenerator::randInt(int minVal, int maxVal) { return randomEngine_() % (maxVal - minVal + 1) + minVal; } -float RandomNumberGenerator::random() { - return uniformDist_(randomEngine_); -} +float RandomNumberGenerator::random() { return uniformDist_(randomEngine_); } float RandomNumberGenerator::uniform(float minVal, float maxVal) { return minVal + (maxVal - minVal) * uniformDist_(randomEngine_); diff --git a/flashlight/pkg/speech/augmentation/SoxWrapper.h b/flashlight/pkg/speech/augmentation/SoxWrapper.h index 717e4ca..acd9de6 100644 --- 
a/flashlight/pkg/speech/augmentation/SoxWrapper.h +++ b/flashlight/pkg/speech/augmentation/SoxWrapper.h @@ -82,9 +82,7 @@ namespace pkg { // when building sound effects without libsox. class SoxWrapper { public: - static SoxWrapper* instance(size_t sampleRate = 16000) { - return nullptr; - } + static SoxWrapper* instance(size_t sampleRate = 16000) { return nullptr; } void applyAndFreeEffect(std::vector& signal, sox_effect_t* effect) const {} }; diff --git a/flashlight/pkg/speech/common/Flags.h b/flashlight/pkg/speech/common/Flags.h index 1551b61..8e61da8 100644 --- a/flashlight/pkg/speech/common/Flags.h +++ b/flashlight/pkg/speech/common/Flags.h @@ -12,7 +12,6 @@ #include - namespace fl { namespace pkg { namespace speech { diff --git a/flashlight/pkg/speech/criterion/Seq2SeqCriterion.cpp b/flashlight/pkg/speech/criterion/Seq2SeqCriterion.cpp index 8eea4c5..7477bbf 100644 --- a/flashlight/pkg/speech/criterion/Seq2SeqCriterion.cpp +++ b/flashlight/pkg/speech/criterion/Seq2SeqCriterion.cpp @@ -304,9 +304,7 @@ std::pair Seq2SeqCriterion::decoder( Tensor Seq2SeqCriterion::viterbiPath( const Tensor& input, const Tensor& inputSizes /* = Tensor() */ -) { - return viterbiPathBase(input, inputSizes, false).first; -} +) { return viterbiPathBase(input, inputSizes, false).first; } std::pair Seq2SeqCriterion::viterbiPathBase( const Tensor& input, diff --git a/flashlight/pkg/speech/criterion/Seq2SeqCriterion.h b/flashlight/pkg/speech/criterion/Seq2SeqCriterion.h index c06a59a..92a9f34 100644 --- a/flashlight/pkg/speech/criterion/Seq2SeqCriterion.h +++ b/flashlight/pkg/speech/criterion/Seq2SeqCriterion.h @@ -158,13 +158,9 @@ namespace pkg { setUseSequentialDecoder(); } - void setGumbelTemperature(double temperature) { - gumbelTemperature_ = temperature; - } + void setGumbelTemperature(double temperature) { gumbelTemperature_ = temperature; } - void setLabelSmooth(double labelSmooth) { - labelSmooth_ = labelSmooth; - } + void setLabelSmooth(double labelSmooth) { labelSmooth_ = 
labelSmooth; } private: int eos_; diff --git a/flashlight/pkg/speech/criterion/TransformerCriterion.cpp b/flashlight/pkg/speech/criterion/TransformerCriterion.cpp index c998681..4ed9a3d 100644 --- a/flashlight/pkg/speech/criterion/TransformerCriterion.cpp +++ b/flashlight/pkg/speech/criterion/TransformerCriterion.cpp @@ -168,9 +168,7 @@ std::pair TransformerCriterion::vectorizedDecoder( Tensor TransformerCriterion::viterbiPath( const Tensor& input, const Tensor& inputSizes /* = Tensor() */ -) { - return viterbiPathBase(input, inputSizes, false).first; -} +) { return viterbiPathBase(input, inputSizes, false).first; } std::pair TransformerCriterion::viterbiPathBase( const Tensor& input, diff --git a/flashlight/pkg/speech/criterion/attention/AttentionBase.h b/flashlight/pkg/speech/criterion/attention/AttentionBase.h index f46c783..3027a8f 100644 --- a/flashlight/pkg/speech/criterion/attention/AttentionBase.h +++ b/flashlight/pkg/speech/criterion/attention/AttentionBase.h @@ -74,9 +74,7 @@ namespace pkg { const Variable& prevAttn, const Variable& logAttnWeight, const Variable& xEncodedSizes - ) { - return forwardBase(state, xEncoded, prevAttn, logAttnWeight, xEncodedSizes); - } + ) { return forwardBase(state, xEncoded, prevAttn, logAttnWeight, xEncodedSizes); } protected: /** diff --git a/flashlight/pkg/speech/data/FeatureTransforms.cpp b/flashlight/pkg/speech/data/FeatureTransforms.cpp index fea4737..01b0e7d 100644 --- a/flashlight/pkg/speech/data/FeatureTransforms.cpp +++ b/flashlight/pkg/speech/data/FeatureTransforms.cpp @@ -120,7 +120,7 @@ fl::Dataset::DataTransformFunction inputFeatures( output.data(), MemoryLocation::Host ); - }; + }; } // target @@ -166,7 +166,7 @@ fl::Dataset::DataTransformFunction targetFeatures( // support empty target return Tensor(fl::dtype::s32); return Tensor::fromVector(tgtVec); - }; + }; } fl::Dataset::DataTransformFunction wordFeatures(const Dictionary& wrdDict) { @@ -179,6 +179,6 @@ fl::Dataset::DataTransformFunction 
wordFeatures(const Dictionary& wrdDict) { // support empty target return Tensor(fl::dtype::s32); return Tensor::fromVector(wrdVec); - }; + }; } } // namespace fl diff --git a/flashlight/pkg/speech/decoder/DecodeMaster.cpp b/flashlight/pkg/speech/decoder/DecodeMaster.cpp index 876eb01..177a3fa 100644 --- a/flashlight/pkg/speech/decoder/DecodeMaster.cpp +++ b/flashlight/pkg/speech/decoder/DecodeMaster.cpp @@ -26,12 +26,8 @@ constexpr size_t kDMWordPredIdx = 3; using namespace fl; -Tensor removeNegative(const fl::Tensor& arr) { - return arr(arr >= 0); -} -Tensor removePad(const Tensor& arr, int32_t padIdx) { - return arr(arr != padIdx); -} +Tensor removeNegative(const fl::Tensor& arr) { return arr(arr >= 0); } +Tensor removePad(const Tensor& arr, int32_t padIdx) { return arr(arr != padIdx); } } // namespace // TODO threading? diff --git a/flashlight/pkg/speech/decoder/PlGenerator.cpp b/flashlight/pkg/speech/decoder/PlGenerator.cpp index 86d5356..1078f98 100644 --- a/flashlight/pkg/speech/decoder/PlGenerator.cpp +++ b/flashlight/pkg/speech/decoder/PlGenerator.cpp @@ -269,9 +269,7 @@ std::shared_ptr PlGenerator::createTrainSet( ); } -void PlGenerator::setModelWER(const float& wer) { - currentModelWER_ = wer; -} +void PlGenerator::setModelWER(const float& wer) { currentModelWER_ = wer; } int PlGenerator::findLastPlEpoch(int curEpoch) const { int lastPlEpoch = -1; diff --git a/flashlight/pkg/speech/runtime/SpeechStatMeter.cpp b/flashlight/pkg/speech/runtime/SpeechStatMeter.cpp index 1128294..fd6f3ac 100644 --- a/flashlight/pkg/speech/runtime/SpeechStatMeter.cpp +++ b/flashlight/pkg/speech/runtime/SpeechStatMeter.cpp @@ -13,9 +13,7 @@ SpeechStatMeter::SpeechStatMeter() { reset(); } -void SpeechStatMeter::reset() { - stats_.reset(); -} +void SpeechStatMeter::reset() { stats_.reset(); } void SpeechStatMeter::add(const Tensor& inputSizes, const Tensor& targetSizes) { int64_t curInputSz = fl::sum(inputSizes).asScalar(); diff --git 
a/flashlight/pkg/speech/test/augmentation/ReverberationTest.cpp b/flashlight/pkg/speech/test/augmentation/ReverberationTest.cpp index a641903..708857d 100644 --- a/flashlight/pkg/speech/test/augmentation/ReverberationTest.cpp +++ b/flashlight/pkg/speech/test/augmentation/ReverberationTest.cpp @@ -104,8 +104,8 @@ TEST(ReverbEcho, SinWaveReverb) { noise.end(), noise.begin(), [norm](float x) -> float { - return x / norm; - } + return x / norm; + } ); // To reduce test flakiness, we trim the edges of the noise and compare only diff --git a/flashlight/pkg/speech/test/criterion/CriterionTest.cpp b/flashlight/pkg/speech/test/criterion/CriterionTest.cpp index 98fdddc..4b09a28 100644 --- a/flashlight/pkg/speech/test/criterion/CriterionTest.cpp +++ b/flashlight/pkg/speech/test/criterion/CriterionTest.cpp @@ -193,8 +193,8 @@ TEST(CriterionTest, CTCCompareTensorflow) { input1.end(), input1.begin(), [](float p) -> float { - return log(p); - } + return log(p); + } ); std::array gradExpected1 = { -0.366234, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553, @@ -231,8 +231,8 @@ TEST(CriterionTest, CTCCompareTensorflow) { input2.end(), input2.begin(), [](float p) -> float { - return log(p); - } + return log(p); + } ); std::array gradExpected2 = { -0.69824, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508, @@ -491,8 +491,8 @@ TEST(CriterionTest, FCCCost) { input1.end(), input1.begin(), [](float p) -> float { - return log(p); - } + return log(p); + } ); std::array dummyTarget1 = {0, 0}; const int N1 = 2, L1 = 2, T1 = 3, B1 = 2; @@ -618,8 +618,8 @@ TEST(CriterionTest, ASGCost) { input1.end(), input1.begin(), [](float p) -> float { - return log(p); - } + return log(p); + } ); std::array target1 = {0, 1, 0, 1}; std::array trans1 = {}; diff --git a/flashlight/pkg/speech/test/data/SoundTest.cpp b/flashlight/pkg/speech/test/data/SoundTest.cpp index 3434da4..4ef712d 100644 --- a/flashlight/pkg/speech/test/data/SoundTest.cpp +++ b/flashlight/pkg/speech/test/data/SoundTest.cpp @@ -60,8 
+60,8 @@ TEST(SoundTest, Mono) { data.end(), data.begin(), [](double d) -> double { - return d * (1 << 15); - } + return d * (1 << 15); + } ); // Short @@ -77,8 +77,8 @@ TEST(SoundTest, Mono) { data.end(), data.begin(), [](double d) -> double { - return d * (1 << 16); - } + return d * (1 << 16); + } ); // Int auto vecInt = loadSound(audiopath); diff --git a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/ctamerge.cuh b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/ctamerge.cuh index cd00264..4ae844d 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/ctamerge.cuh +++ b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/ctamerge.cuh @@ -26,11 +26,11 @@ ******************************************************************************/ /****************************************************************************** - * - * Code and text by Sean Baxter, NVIDIA Research - * See http://nvlabs.github.io/moderngpu for repository and documentation. - * - ******************************************************************************/ +* +* Code and text by Sean Baxter, NVIDIA Research +* See http://nvlabs.github.io/moderngpu for repository and documentation. +* +******************************************************************************/ #pragma once @@ -44,27 +44,35 @@ namespace mgpu { // SerialMerge template -MGPU_DEVICE void SerialMerge(const T* keys_shared, int aBegin, int aEnd, - int bBegin, int bEnd, T* results, int* indices, Comp comp) { - - T aKey = keys_shared[aBegin]; - T bKey = keys_shared[bBegin]; - - #pragma unroll - for(int i = 0; i < VT; ++i) { - bool p; - if(RangeCheck) - p = (bBegin >= bEnd) || ((aBegin < aEnd) && !comp(bKey, aKey)); - else - p = !comp(bKey, aKey); - - results[i] = p ? aKey : bKey; - indices[i] = p ? 
aBegin : bBegin - !RangeCheck; - - if(p) aKey = keys_shared[++aBegin]; - else bKey = keys_shared[++bBegin]; - } - __syncthreads(); +MGPU_DEVICE void SerialMerge( + const T* keys_shared, + int aBegin, + int aEnd, + int bBegin, + int bEnd, + T* results, + int* indices, + Comp comp +) { + + T aKey = keys_shared[aBegin]; + T bKey = keys_shared[bBegin]; + +#pragma unroll + for(int i = 0; i < VT; ++i) { + bool p; + if(RangeCheck) + p = (bBegin >= bEnd) || ((aBegin < aEnd) && !comp(bKey, aKey)); + else + p = !comp(bKey, aKey); + + results[i] = p ? aKey : bKey; + indices[i] = p ? aBegin : bBegin - !RangeCheck; + + if(p) aKey = keys_shared[++aBegin]; + else bKey = keys_shared[++bBegin]; + } + __syncthreads(); } //////////////////////////////////////////////////////////////////////////////// @@ -73,107 +81,165 @@ MGPU_DEVICE void SerialMerge(const T* keys_shared, int aBegin, int aEnd, // Returns (offset of a, offset of b, length of list). MGPU_HOST_DEVICE int3 FindMergesortFrame(int coop, int block, int nv) { - // coop is the number of CTAs or threads cooperating to merge two lists into - // one. We round block down to the first CTA's ID that is working on this - // merge. - int start = ~(coop - 1) & block; - int size = nv * (coop>> 1); - return make_int3(nv * start, nv * start + size, size); + // coop is the number of CTAs or threads cooperating to merge two lists into + // one. We round block down to the first CTA's ID that is working on this + // merge. + int start = ~(coop - 1) & block; + int size = nv * (coop >> 1); + return make_int3(nv * start, nv * start + size, size); } // Returns (a0, a1, b0, b1) into mergesort input lists between mp0 and mp1. -MGPU_HOST_DEVICE int4 FindMergesortInterval(int3 frame, int coop, int block, - int nv, int count, int mp0, int mp1) { - - // Locate diag from the start of the A sublist. 
- int diag = nv * block - frame.x; - int a0 = frame.x + mp0; - int a1 = min(count, frame.x + mp1); - int b0 = min(count, frame.y + diag - mp0); - int b1 = min(count, frame.y + diag + nv - mp1); - - // The end partition of the last block for each merge operation is computed - // and stored as the begin partition for the subsequent merge. i.e. it is - // the same partition but in the wrong coordinate system, so its 0 when it - // should be listSize. Correct that by checking if this is the last block - // in this merge operation. - if(coop - 1 == ((coop - 1) & block)) { - a1 = min(count, frame.x + frame.z); - b1 = min(count, frame.y + frame.z); - } - return make_int4(a0, a1, b0, b1); +MGPU_HOST_DEVICE int4 FindMergesortInterval( + int3 frame, + int coop, + int block, + int nv, + int count, + int mp0, + int mp1 +) { + + // Locate diag from the start of the A sublist. + int diag = nv * block - frame.x; + int a0 = frame.x + mp0; + int a1 = min(count, frame.x + mp1); + int b0 = min(count, frame.y + diag - mp0); + int b1 = min(count, frame.y + diag + nv - mp1); + + // The end partition of the last block for each merge operation is computed + // and stored as the begin partition for the subsequent merge. i.e. it is + // the same partition but in the wrong coordinate system, so its 0 when it + // should be listSize. Correct that by checking if this is the last block + // in this merge operation. + if(coop - 1 == ((coop - 1) & block)) { + a1 = min(count, frame.x + frame.z); + b1 = min(count, frame.y + frame.z); + } + return make_int4(a0, a1, b0, b1); } //////////////////////////////////////////////////////////////////////////////// // ComputeMergeRange -MGPU_HOST_DEVICE int4 ComputeMergeRange(int aCount, int bCount, int block, - int coop, int NV, const int* mp_global) { - - // Load the merge paths computed by the partitioning kernel. 
- int mp0 = mp_global[block]; - int mp1 = mp_global[block + 1]; - int gid = NV * block; - - // Compute the ranges of the sources in global memory. - int4 range; - if(coop) { - int3 frame = FindMergesortFrame(coop, block, NV); - range = FindMergesortInterval(frame, coop, block, NV, aCount, mp0, - mp1); - } else { - range.x = mp0; // a0 - range.y = mp1; // a1 - range.z = gid - range.x; // b0 - range.w = min(aCount + bCount, gid + NV) - range.y; // b1 - } - return range; +MGPU_HOST_DEVICE int4 ComputeMergeRange( + int aCount, + int bCount, + int block, + int coop, + int NV, + const int* mp_global +) { + + // Load the merge paths computed by the partitioning kernel. + int mp0 = mp_global[block]; + int mp1 = mp_global[block + 1]; + int gid = NV * block; + + // Compute the ranges of the sources in global memory. + int4 range; + if(coop) { + int3 frame = FindMergesortFrame(coop, block, NV); + range = FindMergesortInterval( + frame, + coop, + block, + NV, + aCount, + mp0, + mp1 + ); + } else { + range.x = mp0; // a0 + range.y = mp1; // a1 + range.z = gid - range.x; // b0 + range.w = min(aCount + bCount, gid + NV) - range.y; // b1 + } + return range; } //////////////////////////////////////////////////////////////////////////////// // CTA mergesort support template -MGPU_DEVICE void CTABlocksortPass(T* keys_shared, int tid, int count, - int coop, T* keys, int* indices, Comp comp) { - - int list = ~(coop - 1) & tid; - int diag = min(count, VT * ((coop - 1) & tid)); - int start = VT * list; - int a0 = min(count, start); - int b0 = min(count, start + VT * (coop / 2)); - int b1 = min(count, start + VT * coop); - - int p = MergePath(keys_shared + a0, b0 - a0, - keys_shared + b0, b1 - b0, diag, comp); - - SerialMerge(keys_shared, a0 + p, b0, b0 + diag - p, b1, keys, - indices, comp); +MGPU_DEVICE void CTABlocksortPass( + T* keys_shared, + int tid, + int count, + int coop, + T* keys, + int* indices, + Comp comp +) { + + int list = ~(coop - 1) & tid; + int diag = min(count, VT * 
((coop - 1) & tid)); + int start = VT * list; + int a0 = min(count, start); + int b0 = min(count, start + VT * (coop / 2)); + int b1 = min(count, start + VT * coop); + + int p = MergePath( + keys_shared + a0, + b0 - a0, + keys_shared + b0, + b1 - b0, + diag, + comp + ); + + SerialMerge( + keys_shared, + a0 + p, + b0, + b0 + diag - p, + b1, + keys, + indices, + comp + ); } template -MGPU_DEVICE void CTABlocksortLoop(ValType threadValues[VT], - KeyType* keys_shared, ValType* values_shared, int tid, int count, - Comp comp) { - - #pragma unroll - for(int coop = 2; coop <= NT; coop *= 2) { - int indices[VT]; - KeyType keys[VT]; - CTABlocksortPass(keys_shared, tid, count, coop, keys, - indices, comp); - - if(HasValues) { - // Exchange the values through shared memory. - DeviceThreadToShared(threadValues, tid, values_shared); - DeviceGather(NT * VT, values_shared, indices, tid, - threadValues); - } - - // Store results in shared memory in sorted order. - DeviceThreadToShared(keys, tid, keys_shared); - } + typename Comp> +MGPU_DEVICE void CTABlocksortLoop( + ValType threadValues[VT], + KeyType* keys_shared, + ValType* values_shared, + int tid, + int count, + Comp comp +) { + +#pragma unroll + for(int coop = 2; coop <= NT; coop *= 2) { + int indices[VT]; + KeyType keys[VT]; + CTABlocksortPass( + keys_shared, + tid, + count, + coop, + keys, + indices, + comp + ); + + if(HasValues) { + // Exchange the values through shared memory. + DeviceThreadToShared(threadValues, tid, values_shared); + DeviceGather( + NT * VT, + values_shared, + indices, + tid, + threadValues + ); + } + + // Store results in shared memory in sorted order. + DeviceThreadToShared(keys, tid, keys_shared); + } } //////////////////////////////////////////////////////////////////////////////// @@ -182,115 +248,209 @@ MGPU_DEVICE void CTABlocksortLoop(ValType threadValues[VT], // count elements. 
template -MGPU_DEVICE void CTAMergesort(KeyType threadKeys[VT], ValType threadValues[VT], - KeyType* keys_shared, ValType* values_shared, int count, int tid, - Comp comp) { - - // Stable sort the keys in the thread. - if(VT * tid < count) { - if(Stable) - OddEvenTransposeSort(threadKeys, threadValues, comp); - else - OddEvenMergesort(threadKeys, threadValues, comp); - } - - // Store the locally sorted keys into shared memory. - DeviceThreadToShared(threadKeys, tid, keys_shared); - - // Recursively merge lists until the entire CTA is sorted. - CTABlocksortLoop(threadValues, keys_shared, - values_shared, tid, count, comp); + typename ValType, typename Comp> +MGPU_DEVICE void CTAMergesort( + KeyType threadKeys[VT], + ValType threadValues[VT], + KeyType* keys_shared, + ValType* values_shared, + int count, + int tid, + Comp comp +) { + + // Stable sort the keys in the thread. + if(VT * tid < count) { + if(Stable) + OddEvenTransposeSort(threadKeys, threadValues, comp); + else + OddEvenMergesort(threadKeys, threadValues, comp); + } + + // Store the locally sorted keys into shared memory. + DeviceThreadToShared(threadKeys, tid, keys_shared); + + // Recursively merge lists until the entire CTA is sorted. 
+ CTABlocksortLoop( + threadValues, + keys_shared, + values_shared, + tid, + count, + comp + ); } template -MGPU_DEVICE void CTAMergesortKeys(KeyType threadKeys[VT], - KeyType* keys_shared, int count, int tid, Comp comp) { - - int valuesTemp[VT]; - CTAMergesort(threadKeys, valuesTemp, keys_shared, - (int*)keys_shared, count, tid, comp); +MGPU_DEVICE void CTAMergesortKeys( + KeyType threadKeys[VT], + KeyType* keys_shared, + int count, + int tid, + Comp comp +) { + + int valuesTemp[VT]; + CTAMergesort( + threadKeys, + valuesTemp, + keys_shared, + (int*) keys_shared, + count, + tid, + comp + ); } template -MGPU_DEVICE void CTAMergesortPairs(KeyType threadKeys[VT], - ValType threadValues[VT], KeyType* keys_shared, ValType* values_shared, - int count, int tid, Comp comp) { - - CTAMergesort(threadKeys, threadValues, keys_shared, - values_shared, count, tid, comp); + typename Comp> +MGPU_DEVICE void CTAMergesortPairs( + KeyType threadKeys[VT], + ValType threadValues[VT], + KeyType* keys_shared, + ValType* values_shared, + int count, + int tid, + Comp comp +) { + + CTAMergesort( + threadKeys, + threadValues, + keys_shared, + values_shared, + count, + tid, + comp + ); } //////////////////////////////////////////////////////////////////////////////// // DeviceMergeKeysIndices template -MGPU_DEVICE void DeviceMergeKeysIndices(It1 a_global, int aCount, It2 b_global, - int bCount, int4 range, int tid, T* keys_shared, T* results, int* indices, - Comp comp) { - - int a0 = range.x; - int a1 = range.y; - int b0 = range.z; - int b1 = range.w; - - if(LoadExtended) { - bool extended = (a1 < aCount) && (b1 < bCount); - aCount = a1 - a0; - bCount = b1 - b0; - int aCount2 = aCount + (int)extended; - int bCount2 = bCount + (int)extended; - - // Load one element past the end of each input to avoid having to use - // range checking in the merge loop. 
- DeviceLoad2ToShared(a_global + a0, aCount2, - b_global + b0, bCount2, tid, keys_shared); - - // Run a Merge Path search for each thread's starting point. - int diag = VT * tid; - int mp = MergePath(keys_shared, aCount, - keys_shared + aCount2, bCount, diag, comp); - - // Compute the ranges of the sources in shared memory. - int a0tid = mp; - int b0tid = aCount2 + diag - mp; - if(extended) { - SerialMerge(keys_shared, a0tid, 0, b0tid, 0, results, - indices, comp); - } else { - int a1tid = aCount; - int b1tid = aCount2 + bCount; - SerialMerge(keys_shared, a0tid, a1tid, b0tid, b1tid, - results, indices, comp); - } - } else { - // Use the input intervals from the ranges between the merge path - // intersections. - aCount = a1 - a0; - bCount = b1 - b0; - - // Load the data into shared memory. - DeviceLoad2ToShared(a_global + a0, aCount, b_global + b0, - bCount, tid, keys_shared); - - // Run a merge path to find the start of the serial merge for each - // thread. - int diag = VT * tid; - int mp = MergePath(keys_shared, aCount, - keys_shared + aCount, bCount, diag, comp); - - // Compute the ranges of the sources in shared memory. - int a0tid = mp; - int a1tid = aCount; - int b0tid = aCount + diag - mp; - int b1tid = aCount + bCount; - - // Serial merge into register. - SerialMerge(keys_shared, a0tid, a1tid, b0tid, b1tid, results, - indices, comp); - } + typename T, typename Comp> +MGPU_DEVICE void DeviceMergeKeysIndices( + It1 a_global, + int aCount, + It2 b_global, + int bCount, + int4 range, + int tid, + T* keys_shared, + T* results, + int* indices, + Comp comp +) { + + int a0 = range.x; + int a1 = range.y; + int b0 = range.z; + int b1 = range.w; + + if(LoadExtended) { + bool extended = (a1 < aCount) && (b1 < bCount); + aCount = a1 - a0; + bCount = b1 - b0; + int aCount2 = aCount + (int) extended; + int bCount2 = bCount + (int) extended; + + // Load one element past the end of each input to avoid having to use + // range checking in the merge loop. 
+ DeviceLoad2ToShared( + a_global + a0, + aCount2, + b_global + b0, + bCount2, + tid, + keys_shared + ); + + // Run a Merge Path search for each thread's starting point. + int diag = VT * tid; + int mp = MergePath( + keys_shared, + aCount, + keys_shared + aCount2, + bCount, + diag, + comp + ); + + // Compute the ranges of the sources in shared memory. + int a0tid = mp; + int b0tid = aCount2 + diag - mp; + if(extended) + SerialMerge( + keys_shared, + a0tid, + 0, + b0tid, + 0, + results, + indices, + comp + ); + else { + int a1tid = aCount; + int b1tid = aCount2 + bCount; + SerialMerge( + keys_shared, + a0tid, + a1tid, + b0tid, + b1tid, + results, + indices, + comp + ); + } + } else { + // Use the input intervals from the ranges between the merge path + // intersections. + aCount = a1 - a0; + bCount = b1 - b0; + + // Load the data into shared memory. + DeviceLoad2ToShared( + a_global + a0, + aCount, + b_global + b0, + bCount, + tid, + keys_shared + ); + + // Run a merge path to find the start of the serial merge for each + // thread. + int diag = VT * tid; + int mp = MergePath( + keys_shared, + aCount, + keys_shared + aCount, + bCount, + diag, + comp + ); + + // Compute the ranges of the sources in shared memory. + int a0tid = mp; + int a1tid = aCount; + int b0tid = aCount + diag - mp; + int b1tid = aCount + bCount; + + // Serial merge into register. + SerialMerge( + keys_shared, + a0tid, + a1tid, + b0tid, + b1tid, + results, + indices, + comp + ); + } } //////////////////////////////////////////////////////////////////////////////// @@ -299,35 +459,67 @@ MGPU_DEVICE void DeviceMergeKeysIndices(It1 a_global, int aCount, It2 b_global, // enable calling from merge, mergesort, and locality sort. 
template -MGPU_DEVICE void DeviceMerge(KeysIt1 aKeys_global, ValsIt1 aVals_global, - int aCount, KeysIt2 bKeys_global, ValsIt2 bVals_global, int bCount, - int tid, int block, int4 range, KeyType* keys_shared, int* indices_shared, - KeysIt3 keys_global, ValsIt3 vals_global, Comp comp) { - - KeyType results[VT]; - int indices[VT]; - DeviceMergeKeysIndices(aKeys_global, aCount, - bKeys_global, bCount, range, tid, keys_shared, results, indices, comp); - - // Store merge results back to shared memory. - DeviceThreadToShared(results, tid, keys_shared); - - // Store merged keys to global memory. - aCount = range.y - range.x; - bCount = range.w - range.z; - DeviceSharedToGlobal(aCount + bCount, keys_shared, tid, - keys_global + NT * VT * block); - - // Copy the values. - if(HasValues) { - DeviceThreadToShared(indices, tid, indices_shared); - - DeviceTransferMergeValuesShared(aCount + bCount, - aVals_global + range.x, bVals_global + range.z, aCount, - indices_shared, tid, vals_global + NT * VT * block); - } + typename KeysIt2, typename KeysIt3, typename ValsIt1, typename ValsIt2, + typename KeyType, typename ValsIt3, typename Comp> +MGPU_DEVICE void DeviceMerge( + KeysIt1 aKeys_global, + ValsIt1 aVals_global, + int aCount, + KeysIt2 bKeys_global, + ValsIt2 bVals_global, + int bCount, + int tid, + int block, + int4 range, + KeyType* keys_shared, + int* indices_shared, + KeysIt3 keys_global, + ValsIt3 vals_global, + Comp comp +) { + + KeyType results[VT]; + int indices[VT]; + DeviceMergeKeysIndices( + aKeys_global, + aCount, + bKeys_global, + bCount, + range, + tid, + keys_shared, + results, + indices, + comp + ); + + // Store merge results back to shared memory. + DeviceThreadToShared(results, tid, keys_shared); + + // Store merged keys to global memory. + aCount = range.y - range.x; + bCount = range.w - range.z; + DeviceSharedToGlobal( + aCount + bCount, + keys_shared, + tid, + keys_global + NT * VT * block + ); + + // Copy the values. 
+ if(HasValues) { + DeviceThreadToShared(indices, tid, indices_shared); + + DeviceTransferMergeValuesShared( + aCount + bCount, + aVals_global + range.x, + bVals_global + range.z, + aCount, + indices_shared, + tid, + vals_global + NT * VT * block + ); + } } } // namespace mgpu diff --git a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/ctascan.cuh b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/ctascan.cuh index af88d9b..727d003 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/ctascan.cuh +++ b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/ctascan.cuh @@ -26,11 +26,11 @@ ******************************************************************************/ /****************************************************************************** - * - * Code and text by Sean Baxter, NVIDIA Research - * See http://nvlabs.github.io/moderngpu for repository and documentation. - * - ******************************************************************************/ +* +* Code and text by Sean Baxter, NVIDIA Research +* See http://nvlabs.github.io/moderngpu for repository and documentation. +* +******************************************************************************/ #pragma once @@ -43,114 +43,122 @@ namespace mgpu { //////////////////////////////////////////////////////////////////////////////// // CTAReduce -template > +template> struct CTAReduce { - typedef typename Op::first_argument_type T; - enum { Size = NT, Capacity = NT }; - struct Storage { T shared[Capacity]; }; - - MGPU_DEVICE static T Reduce(int tid, T x, Storage& storage, Op op = Op()) { - storage.shared[tid] = x; - __syncthreads(); - - // Fold the data in half with each pass. - #pragma unroll - for(int destCount = NT / 2; destCount >= 1; destCount /= 2) { - if(tid < destCount) { - // Read from the right half and store to the left half. 
- x = op(x, storage.shared[destCount + tid]); - storage.shared[tid] = x; - } - __syncthreads(); - } - T total = storage.shared[0]; - __syncthreads(); - return total; - } + typedef typename Op::first_argument_type T; + enum {Size = NT, Capacity = NT}; + struct Storage {T shared[Capacity];}; + + MGPU_DEVICE static T Reduce(int tid, T x, Storage& storage, Op op = Op()) { + storage.shared[tid] = x; + __syncthreads(); + + // Fold the data in half with each pass. +#pragma unroll + for(int destCount = NT / 2; destCount >= 1; destCount /= 2) { + if(tid < destCount) { + // Read from the right half and store to the left half. + x = op(x, storage.shared[destCount + tid]); + storage.shared[tid] = x; + } + __syncthreads(); + } + T total = storage.shared[0]; + __syncthreads(); + return total; + } }; #if __CUDA_ARCH__ >= 300 template -struct CTAReduce > { - typedef mgpu::plus Op; - typedef int T; - enum { Size = NT, Capacity = WARP_SIZE }; - struct Storage { int shared[Capacity]; }; - - MGPU_DEVICE static int Reduce(int tid, int x, Storage& storage, - Op op = Op()) { - - const int NumSections = WARP_SIZE; - const int SecSize = NT / NumSections; - int lane = (SecSize - 1) & tid; - int sec = tid / SecSize; - - // In the first phase, threads cooperatively find the reduction within - // their segment. The segments are SecSize threads (NT / WARP_SIZE) - // wide. - #pragma unroll - for(int offset = 1; offset < SecSize; offset *= 2) - x = shfl_add(x, offset, SecSize); - - // The last thread in each segment stores the local reduction to shared - // memory. - if(SecSize - 1 == lane) storage.shared[sec] = x; - __syncthreads(); - - // Reduce the totals of each input segment. The spine is WARP_SIZE - // threads wide. 
- if(tid < NumSections) { - x = storage.shared[tid]; - #pragma unroll - for(int offset = 1; offset < NumSections; offset *= 2) - x = shfl_add(x, offset, NumSections); - storage.shared[tid] = x; - } - __syncthreads(); - - int reduction = storage.shared[NumSections - 1]; - __syncthreads(); - - return reduction; - } +struct CTAReduce> { + typedef mgpu::plus Op; + typedef int T; + enum {Size = NT, Capacity = WARP_SIZE}; + struct Storage {int shared[Capacity];}; + + MGPU_DEVICE static int Reduce( + int tid, + int x, + Storage& storage, + Op op = Op() + ) { + + const int NumSections = WARP_SIZE; + const int SecSize = NT / NumSections; + int lane = (SecSize - 1) & tid; + int sec = tid / SecSize; + + // In the first phase, threads cooperatively find the reduction within + // their segment. The segments are SecSize threads (NT / WARP_SIZE) + // wide. +#pragma unroll + for(int offset = 1; offset < SecSize; offset *= 2) + x = shfl_add(x, offset, SecSize); + + // The last thread in each segment stores the local reduction to shared + // memory. + if(SecSize - 1 == lane) storage.shared[sec] = x; + __syncthreads(); + + // Reduce the totals of each input segment. The spine is WARP_SIZE + // threads wide. 
+ if(tid < NumSections) { + x = storage.shared[tid]; +#pragma unroll + for(int offset = 1; offset < NumSections; offset *= 2) + x = shfl_add(x, offset, NumSections); + storage.shared[tid] = x; + } + __syncthreads(); + + int reduction = storage.shared[NumSections - 1]; + __syncthreads(); + + return reduction; + } }; template -struct CTAReduce > { - typedef mgpu::maximum Op; - enum { Size = NT, Capacity = WARP_SIZE }; - struct Storage { int shared[Capacity]; }; - - MGPU_DEVICE static int Reduce(int tid, int x, Storage& storage, - Op op = Op()) { - - const int NumSections = WARP_SIZE; - const int SecSize = NT / NumSections; - int lane = (SecSize - 1) & tid; - int sec = tid / SecSize; - - #pragma unroll - for(int offset = 1; offset < SecSize; offset *= 2) - x = shfl_max(x, offset, SecSize); - - if(SecSize - 1 == lane) storage.shared[sec] = x; - __syncthreads(); - - if(tid < NumSections) { - x = storage.shared[tid]; - #pragma unroll - for(int offset = 1; offset < NumSections; offset *= 2) - x = shfl_max(x, offset, NumSections); - storage.shared[tid] = x; - } - __syncthreads(); - - int reduction = storage.shared[NumSections - 1]; - __syncthreads(); - - return reduction; - } +struct CTAReduce> { + typedef mgpu::maximum Op; + enum {Size = NT, Capacity = WARP_SIZE}; + struct Storage {int shared[Capacity];}; + + MGPU_DEVICE static int Reduce( + int tid, + int x, + Storage& storage, + Op op = Op() + ) { + + const int NumSections = WARP_SIZE; + const int SecSize = NT / NumSections; + int lane = (SecSize - 1) & tid; + int sec = tid / SecSize; + +#pragma unroll + for(int offset = 1; offset < SecSize; offset *= 2) + x = shfl_max(x, offset, SecSize); + + if(SecSize - 1 == lane) storage.shared[sec] = x; + __syncthreads(); + + if(tid < NumSections) { + x = storage.shared[tid]; +#pragma unroll + for(int offset = 1; offset < NumSections; offset *= 2) + x = shfl_max(x, offset, NumSections); + storage.shared[tid] = x; + } + __syncthreads(); + + int reduction = storage.shared[NumSections 
- 1]; + __syncthreads(); + + return reduction; + } }; #endif // __CUDA_ARCH__ >= 300 @@ -158,39 +166,46 @@ struct CTAReduce > { //////////////////////////////////////////////////////////////////////////////// // CTAScan -template > +template> struct CTAScan { - typedef typename Op::result_type T; - enum { Size = NT, Capacity = 2 * NT + 1 }; - struct Storage { T shared[Capacity]; }; - - MGPU_DEVICE static T Scan(int tid, T x, Storage& storage, T* total, - MgpuScanType type = MgpuScanTypeExc, T identity = (T)0, Op op = Op()) { - - storage.shared[tid] = x; - int first = 0; - __syncthreads(); - - #pragma unroll - for(int offset = 1; offset < NT; offset += offset) { - if(tid >= offset) - x = op(storage.shared[first + tid - offset], x); - first = NT - first; - storage.shared[first + tid] = x; - __syncthreads(); - } - *total = storage.shared[first + NT - 1]; - - if(MgpuScanTypeExc == type) - x = tid ? storage.shared[first + tid - 1] : identity; - - __syncthreads(); - return x; - } - MGPU_DEVICE static T Scan(int tid, T x, Storage& storage) { - T total; - return Scan(tid, x, storage, &total, MgpuScanTypeExc, (T)0, Op()); - } + typedef typename Op::result_type T; + enum {Size = NT, Capacity = 2 * NT + 1}; + struct Storage {T shared[Capacity];}; + + MGPU_DEVICE static T Scan( + int tid, + T x, + Storage& storage, + T* total, + MgpuScanType type = MgpuScanTypeExc, + T identity = (T) 0, + Op op = Op() + ) { + + storage.shared[tid] = x; + int first = 0; + __syncthreads(); + +#pragma unroll + for(int offset = 1; offset < NT; offset += offset) { + if(tid >= offset) + x = op(storage.shared[first + tid - offset], x); + first = NT - first; + storage.shared[first + tid] = x; + __syncthreads(); + } + *total = storage.shared[first + NT - 1]; + + if(MgpuScanTypeExc == type) + x = tid ? 
storage.shared[first + tid - 1] : identity; + + __syncthreads(); + return x; + } + MGPU_DEVICE static T Scan(int tid, T x, Storage& storage) { + T total; + return Scan(tid, x, storage, &total, MgpuScanTypeExc, (T) 0, Op()); + } }; //////////////////////////////////////////////////////////////////////////////// @@ -200,59 +215,66 @@ struct CTAScan { #if __CUDA_ARCH__ >= 300 template -struct CTAScan > { - typedef mgpu::plus Op; - enum { Size = NT, NumSegments = WARP_SIZE, SegSize = NT / NumSegments }; - enum { Capacity = NumSegments + 1 }; - struct Storage { int shared[Capacity + 1]; }; - - MGPU_DEVICE static int Scan(int tid, int x, Storage& storage, int* total, - MgpuScanType type = MgpuScanTypeExc, int identity = 0, Op op = Op()) { - - // Define WARP_SIZE segments that are NT / WARP_SIZE large. - // Each warp makes log(SegSize) shfl_add calls. - // The spine makes log(WARP_SIZE) shfl_add calls. - int lane = (SegSize - 1) & tid; - int segment = tid / SegSize; - - // Scan each segment using shfl_add. - int scan = x; - #pragma unroll - for(int offset = 1; offset < SegSize; offset *= 2) - scan = shfl_add(scan, offset, SegSize); - - // Store the reduction (last element) of each segment into storage. - if(SegSize - 1 == lane) storage.shared[segment] = scan; - __syncthreads(); - - // Warp 0 does a full shfl warp scan on the partials. The total is - // stored to shared[NumSegments]. (NumSegments = WARP_SIZE) - if(tid < NumSegments) { - int y = storage.shared[tid]; - int scan = y; - #pragma unroll - for(int offset = 1; offset < NumSegments; offset *= 2) - scan = shfl_add(scan, offset, NumSegments); - storage.shared[tid] = scan - y; - if(NumSegments - 1 == tid) storage.shared[NumSegments] = scan; - } - __syncthreads(); - - // Add the scanned partials back in and convert to exclusive scan. 
- scan += storage.shared[segment]; - if(MgpuScanTypeExc == type) { - scan -= x; - if(identity && !tid) scan = identity; - } - *total = storage.shared[NumSegments]; - __syncthreads(); - - return scan; - } - MGPU_DEVICE static int Scan(int tid, int x, Storage& storage) { - int total; - return Scan(tid, x, storage, &total, MgpuScanTypeExc, 0); - } +struct CTAScan> { + typedef mgpu::plus Op; + enum {Size = NT, NumSegments = WARP_SIZE, SegSize = NT / NumSegments}; + enum {Capacity = NumSegments + 1}; + struct Storage {int shared[Capacity + 1];}; + + MGPU_DEVICE static int Scan( + int tid, + int x, + Storage& storage, + int* total, + MgpuScanType type = MgpuScanTypeExc, + int identity = 0, + Op op = Op() + ) { + + // Define WARP_SIZE segments that are NT / WARP_SIZE large. + // Each warp makes log(SegSize) shfl_add calls. + // The spine makes log(WARP_SIZE) shfl_add calls. + int lane = (SegSize - 1) & tid; + int segment = tid / SegSize; + + // Scan each segment using shfl_add. + int scan = x; +#pragma unroll + for(int offset = 1; offset < SegSize; offset *= 2) + scan = shfl_add(scan, offset, SegSize); + + // Store the reduction (last element) of each segment into storage. + if(SegSize - 1 == lane) storage.shared[segment] = scan; + __syncthreads(); + + // Warp 0 does a full shfl warp scan on the partials. The total is + // stored to shared[NumSegments]. (NumSegments = WARP_SIZE) + if(tid < NumSegments) { + int y = storage.shared[tid]; + int scan = y; +#pragma unroll + for(int offset = 1; offset < NumSegments; offset *= 2) + scan = shfl_add(scan, offset, NumSegments); + storage.shared[tid] = scan - y; + if(NumSegments - 1 == tid) storage.shared[NumSegments] = scan; + } + __syncthreads(); + + // Add the scanned partials back in and convert to exclusive scan. 
+ scan += storage.shared[segment]; + if(MgpuScanTypeExc == type) { + scan -= x; + if(identity && !tid) scan = identity; + } + *total = storage.shared[NumSegments]; + __syncthreads(); + + return scan; + } + MGPU_DEVICE static int Scan(int tid, int x, Storage& storage) { + int total; + return Scan(tid, x, storage, &total, MgpuScanTypeExc, 0); + } }; #endif // __CUDA_ARCH__ >= 300 @@ -262,47 +284,47 @@ struct CTAScan > { template MGPU_DEVICE int CTABinaryScan(int tid, bool x, int* shared, int* total) { - const int NumWarps = NT / WARP_SIZE; - int warp = tid / WARP_SIZE; - int lane = (WARP_SIZE - 1); + const int NumWarps = NT / WARP_SIZE; + int warp = tid / WARP_SIZE; + int lane = (WARP_SIZE - 1); - // Store the bit totals for each warp. - uint bits = __ballot(x); - shared[warp] = popc(bits); - __syncthreads(); + // Store the bit totals for each warp. + uint bits = __ballot(x); + shared[warp] = popc(bits); + __syncthreads(); #if __CUDA_ARCH__ >= 300 - if(tid < NumWarps) { - int x = shared[tid]; - int scan = x; - #pragma unroll - for(int offset = 1; offset < NumWarps; offset *= 2) - scan = shfl_add(scan, offset, NumWarps); - shared[tid] = scan - x; - } - __syncthreads(); + if(tid < NumWarps) { + int x = shared[tid]; + int scan = x; +#pragma unroll + for(int offset = 1; offset < NumWarps; offset *= 2) + scan = shfl_add(scan, offset, NumWarps); + shared[tid] = scan - x; + } + __syncthreads(); #else - // Thread 0 scans warp totals. - if(!tid) { - int scan = 0; - #pragma unroll - for(int i = 0; i < NumWarps; ++i) { - int y = shared[i]; - shared[i] = scan; - scan += y; - } - shared[NumWarps] = scan; - } - __syncthreads(); + // Thread 0 scans warp totals. + if(!tid) { + int scan = 0; +#pragma unroll + for(int i = 0; i < NumWarps; ++i) { + int y = shared[i]; + shared[i] = scan; + scan += y; + } + shared[NumWarps] = scan; + } + __syncthreads(); #endif // __CUDA_ARCH__ >= 300 - // Add the warp scan back into the partials. 
- int scan = shared[warp] + __popc(bfe(bits, 0, lane)); - *total = shared[NumWarps]; - __syncthreads(); - return scan; + // Add the warp scan back into the partials. + int scan = shared[warp] + __popc(bfe(bits, 0, lane)); + *total = shared[NumWarps]; + __syncthreads(); + return scan; } } // namespace mgpu diff --git a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/ctasearch.cuh b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/ctasearch.cuh index a033aa0..77fc05a 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/ctasearch.cuh +++ b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/ctasearch.cuh @@ -26,11 +26,11 @@ ******************************************************************************/ /****************************************************************************** - * - * Code and text by Sean Baxter, NVIDIA Research - * See http://nvlabs.github.io/moderngpu for repository and documentation. - * - ******************************************************************************/ +* +* Code and text by Sean Baxter, NVIDIA Research +* See http://nvlabs.github.io/moderngpu for repository and documentation. +* +******************************************************************************/ #pragma once @@ -40,168 +40,214 @@ namespace mgpu { template -MGPU_HOST_DEVICE void BinarySearchIt(It data, int& begin, int& end, T key, - int shift, Comp comp) { - - IntT scale = (1<< shift) - 1; - int mid = (int)((begin + scale * end)>> shift); - - T key2 = data[mid]; - bool pred = (MgpuBoundsUpper == Bounds) ? 
- !comp(key, key2) : - comp(key2, key); - if(pred) begin = mid + 1; - else end = mid; + typename Comp> +MGPU_HOST_DEVICE void BinarySearchIt( + It data, + int& begin, + int& end, + T key, + int shift, + Comp comp +) { + + IntT scale = (1 << shift) - 1; + int mid = (int) ((begin + scale * end) >> shift); + + T key2 = data[mid]; + bool pred = (MgpuBoundsUpper == Bounds) + ? !comp(key, key2) + : comp(key2, key); + if(pred) begin = mid + 1; + else end = mid; } template -MGPU_HOST_DEVICE int BiasedBinarySearch(It data, int count, T key, int levels, - Comp comp) { - - int begin = 0; - int end = count; - - if(levels >= 4 && begin < end) - BinarySearchIt(data, begin, end, key, 9, comp); - if(levels >= 3 && begin < end) - BinarySearchIt(data, begin, end, key, 7, comp); - if(levels >= 2 && begin < end) - BinarySearchIt(data, begin, end, key, 5, comp); - if(levels >= 1 && begin < end) - BinarySearchIt(data, begin, end, key, 4, comp); - - while(begin < end) - BinarySearchIt(data, begin, end, key, 1, comp); - return begin; + typename Comp> +MGPU_HOST_DEVICE int BiasedBinarySearch( + It data, + int count, + T key, + int levels, + Comp comp +) { + + int begin = 0; + int end = count; + + if(levels >= 4 && begin < end) + BinarySearchIt(data, begin, end, key, 9, comp); + if(levels >= 3 && begin < end) + BinarySearchIt(data, begin, end, key, 7, comp); + if(levels >= 2 && begin < end) + BinarySearchIt(data, begin, end, key, 5, comp); + if(levels >= 1 && begin < end) + BinarySearchIt(data, begin, end, key, 4, comp); + + while(begin < end) { + BinarySearchIt(data, begin, end, key, 1, comp); + } + return begin; } template MGPU_HOST_DEVICE int BinarySearch(It data, int count, T key, Comp comp) { - int begin = 0; - int end = count; - while(begin < end) - BinarySearchIt(data, begin, end, key, 1, comp); - return begin; + int begin = 0; + int end = count; + while(begin < end) { + BinarySearchIt(data, begin, end, key, 1, comp); + } + return begin; } 
//////////////////////////////////////////////////////////////////////////////// // MergePath search template -MGPU_HOST_DEVICE int MergePath(It1 a, int aCount, It2 b, int bCount, int diag, - Comp comp) { - - typedef typename std::iterator_traits::value_type T; - int begin = max(0, diag - bCount); - int end = min(diag, aCount); - - while(begin < end) { - int mid = (begin + end)>> 1; - T aKey = a[mid]; - T bKey = b[diag - 1 - mid]; - bool pred = (MgpuBoundsUpper == Bounds) ? - comp(aKey, bKey) : - !comp(bKey, aKey); - if(pred) begin = mid + 1; - else end = mid; - } - return begin; +MGPU_HOST_DEVICE int MergePath( + It1 a, + int aCount, + It2 b, + int bCount, + int diag, + Comp comp +) { + + typedef typename std::iterator_traits::value_type T; + int begin = max(0, diag - bCount); + int end = min(diag, aCount); + + while(begin < end) { + int mid = (begin + end) >> 1; + T aKey = a[mid]; + T bKey = b[diag - 1 - mid]; + bool pred = (MgpuBoundsUpper == Bounds) + ? comp(aKey, bKey) + : !comp(bKey, aKey); + if(pred) begin = mid + 1; + else end = mid; + } + return begin; } - //////////////////////////////////////////////////////////////////////////////// // SegmentedMergePath search template -MGPU_HOST_DEVICE int SegmentedMergePath(InputIt keys, int aOffset, int aCount, - int bOffset, int bCount, int leftEnd, int rightStart, int diag, Comp comp) { - - // leftEnd and rightStart are defined from the origin, and diag is defined - // from aOffset. - // We only need to run a Merge Path search if the diagonal intersects the - // segment that strides the left and right halves (i.e. is between leftEnd - // and rightStart). 
- if(aOffset + diag <= leftEnd) return diag; - if(aOffset + diag >= rightStart) return aCount; - - bCount = min(bCount, rightStart - bOffset); - int begin = max(max(leftEnd - aOffset, 0), diag - bCount); - int end = min(diag, aCount); - - while(begin < end) { - int mid = (begin + end)>> 1; - int ai = aOffset + mid; - int bi = bOffset + diag - 1 - mid; - - bool pred = !comp(keys[bi], keys[ai]); - if(pred) begin = mid + 1; - else end = mid; - } - return begin; +MGPU_HOST_DEVICE int SegmentedMergePath( + InputIt keys, + int aOffset, + int aCount, + int bOffset, + int bCount, + int leftEnd, + int rightStart, + int diag, + Comp comp +) { + + // leftEnd and rightStart are defined from the origin, and diag is defined + // from aOffset. + // We only need to run a Merge Path search if the diagonal intersects the + // segment that strides the left and right halves (i.e. is between leftEnd + // and rightStart). + if(aOffset + diag <= leftEnd) return diag; + if(aOffset + diag >= rightStart) return aCount; + + bCount = min(bCount, rightStart - bOffset); + int begin = max(max(leftEnd - aOffset, 0), diag - bCount); + int end = min(diag, aCount); + + while(begin < end) { + int mid = (begin + end) >> 1; + int ai = aOffset + mid; + int bi = bOffset + diag - 1 - mid; + + bool pred = !comp(keys[bi], keys[ai]); + if(pred) begin = mid + 1; + else end = mid; + } + return begin; } //////////////////////////////////////////////////////////////////////////////// // BalancedPath search template -MGPU_HOST_DEVICE int2 BalancedPath(InputIt1 a, int aCount, InputIt2 b, - int bCount, int diag, int levels, Comp comp) { - - typedef typename std::iterator_traits::value_type T; - - int p = MergePath(a, aCount, b, bCount, diag, comp); - int aIndex = p; - int bIndex = diag - p; - - bool star = false; - if(bIndex < bCount) { - if(Duplicates) { - T x = b[bIndex]; - - // Search for the beginning of the duplicate run in both A and B. 
- // Because - int aStart = BiasedBinarySearch(a, aIndex, x, - levels, comp); - int bStart = BiasedBinarySearch(b, bIndex, x, - levels, comp); - - // The distance between the merge path and the lower_bound is the - // 'run'. We add up the a- and b- runs and evenly distribute them to - // get a stairstep path. - int aRun = aIndex - aStart; - int bRun = bIndex - bStart; - int xCount = aRun + bRun; - - // Attempt to advance b and regress a. - int bAdvance = max(xCount>> 1, bRun); - int bEnd = min(bCount, bStart + bAdvance + 1); - int bRunEnd = BinarySearch(b + bIndex, - bEnd - bIndex, x, comp) + bIndex; - bRun = bRunEnd - bStart; - - bAdvance = min(bAdvance, bRun); - int aAdvance = xCount - bAdvance; - - bool roundUp = (aAdvance == bAdvance + 1) && (bAdvance < bRun); - aIndex = aStart + aAdvance; - - if(roundUp) star = true; - } else { - if(aIndex && aCount) { - T aKey = a[aIndex - 1]; - T bKey = b[bIndex]; - - // If the last consumed element in A (aIndex - 1) is the same as - // the next element in B (bIndex), we're sitting at a starred - // partition. - if(!comp(aKey, bKey)) star = true; - } - } - } - return make_int2(aIndex, star); + typename Comp> +MGPU_HOST_DEVICE int2 BalancedPath( + InputIt1 a, + int aCount, + InputIt2 b, + int bCount, + int diag, + int levels, + Comp comp +) { + + typedef typename std::iterator_traits::value_type T; + + int p = MergePath(a, aCount, b, bCount, diag, comp); + int aIndex = p; + int bIndex = diag - p; + + bool star = false; + if(bIndex < bCount) { + if(Duplicates) { + T x = b[bIndex]; + + // Search for the beginning of the duplicate run in both A and B. + // Because + int aStart = BiasedBinarySearch( + a, + aIndex, + x, + levels, + comp + ); + int bStart = BiasedBinarySearch( + b, + bIndex, + x, + levels, + comp + ); + + // The distance between the merge path and the lower_bound is the + // 'run'. We add up the a- and b- runs and evenly distribute them to + // get a stairstep path. 
+ int aRun = aIndex - aStart; + int bRun = bIndex - bStart; + int xCount = aRun + bRun; + + // Attempt to advance b and regress a. + int bAdvance = max(xCount >> 1, bRun); + int bEnd = min(bCount, bStart + bAdvance + 1); + int bRunEnd = BinarySearch( + b + bIndex, + bEnd - bIndex, + x, + comp + ) + bIndex; + bRun = bRunEnd - bStart; + + bAdvance = min(bAdvance, bRun); + int aAdvance = xCount - bAdvance; + + bool roundUp = (aAdvance == bAdvance + 1) && (bAdvance < bRun); + aIndex = aStart + aAdvance; + + if(roundUp) star = true; + } else if(aIndex && aCount) { + T aKey = a[aIndex - 1]; + T bKey = b[bIndex]; + + // If the last consumed element in A (aIndex - 1) is the same as + // the next element in B (bIndex), we're sitting at a starred + // partition. + if(!comp(aKey, bKey)) star = true; + } + } + return make_int2(aIndex, star); } } // namespace mgpu diff --git a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/ctasegscan.cuh b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/ctasegscan.cuh index 1e0c8a5..5542214 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/ctasegscan.cuh +++ b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/ctasegscan.cuh @@ -26,11 +26,11 @@ ******************************************************************************/ /****************************************************************************** - * - * Code and text by Sean Baxter, NVIDIA Research - * See http://nvlabs.github.io/moderngpu for repository and documentation. - * - ******************************************************************************/ +* +* Code and text by Sean Baxter, NVIDIA Research +* See http://nvlabs.github.io/moderngpu for repository and documentation. 
+* +******************************************************************************/ #pragma once @@ -44,94 +44,108 @@ namespace mgpu { template MGPU_DEVICE int DeviceFindSegScanDelta(int tid, bool flag, int* delta_shared) { - const int NumWarps = NT / 32; - - int warp = tid / 32; - int lane = 31 & tid; - uint warpMask = 0xffffffff>> (31 - lane); // inclusive search - uint ctaMask = 0x7fffffff>> (31 - lane); // exclusive search - - uint warpBits = __ballot(flag); - delta_shared[warp] = warpBits; - __syncthreads(); - - if(tid < NumWarps) { - uint ctaBits = __ballot(0 != delta_shared[tid]); - int warpSegment = 31 - clz(ctaMask & ctaBits); - int start = (-1 != warpSegment) ? - (31 - clz(delta_shared[warpSegment]) + 32 * warpSegment) : 0; - delta_shared[NumWarps + tid] = start; - } - __syncthreads(); - - // Find the closest flag to the left of this thread within the warp. - // Include the flag for this thread. - int start = 31 - clz(warpMask & warpBits); - if(-1 != start) start += ~31 & tid; - else start = delta_shared[NumWarps + warp]; - __syncthreads(); - - return tid - start; + const int NumWarps = NT / 32; + + int warp = tid / 32; + int lane = 31 & tid; + uint warpMask = 0xffffffff >> (31 - lane); // inclusive search + uint ctaMask = 0x7fffffff >> (31 - lane); // exclusive search + + uint warpBits = __ballot(flag); + delta_shared[warp] = warpBits; + __syncthreads(); + + if(tid < NumWarps) { + uint ctaBits = __ballot(0 != delta_shared[tid]); + int warpSegment = 31 - clz(ctaMask & ctaBits); + int start = (-1 != warpSegment) + ? (31 - clz(delta_shared[warpSegment]) + 32 * warpSegment) : 0; + delta_shared[NumWarps + tid] = start; + } + __syncthreads(); + + // Find the closest flag to the left of this thread within the warp. + // Include the flag for this thread. 
+ int start = 31 - clz(warpMask & warpBits); + if(-1 != start) start += ~31 & tid; + else start = delta_shared[NumWarps + warp]; + __syncthreads(); + + return tid - start; } //////////////////////////////////////////////////////////////////////////////// // CTASegScan -template > +template> struct CTASegScan { - typedef _Op Op; - typedef typename Op::result_type T; - enum { NumWarps = NT / 32, Size = NT, Capacity = 2 * NT }; - union Storage { - int delta[NumWarps]; - T values[Capacity]; - }; - - // Each thread passes the reduction of the LAST SEGMENT that it covers. - // flag is set to true if there's at least one segment flag in the thread. - // SegScan returns the reduction of values for the first segment in this - // thread over the preceding threads. - // Return the value init for the first thread. - - // When scanning single elements per thread, interpret the flag as a BEGIN - // FLAG. If tid's flag is set, its value belongs to thread tid + 1, not - // thread tid. - - // The function returns the reduction of the last segment in the CTA. - - MGPU_DEVICE static T SegScanDelta(int tid, int tidDelta, T x, - Storage& storage, T* carryOut, T identity = (T)0, Op op = Op()) { - - // Run an inclusive scan - int first = 0; - storage.values[first + tid] = x; - __syncthreads(); - - #pragma unroll - for(int offset = 1; offset < NT; offset += offset) { - if(tidDelta >= offset) - x = op(storage.values[first + tid - offset], x); - first = NT - first; - storage.values[first + tid] = x; - __syncthreads(); - } - - // Get the exclusive scan. - x = tid ? storage.values[first + tid - 1] : identity; - *carryOut = storage.values[first + NT - 1]; - __syncthreads(); - return x; - } - - MGPU_DEVICE static T SegScan(int tid, T x, bool flag, Storage& storage, - T* carryOut, T identity = (T)0, Op op = Op()) { - - // Find the left-most thread that covers the first segment of this - // thread. 
- int tidDelta = DeviceFindSegScanDelta(tid, flag, storage.delta); - - return SegScanDelta(tid, tidDelta, x, storage, carryOut, identity, op); - } + typedef _Op Op; + typedef typename Op::result_type T; + enum {NumWarps = NT / 32, Size = NT, Capacity = 2 * NT}; + union Storage { + int delta[NumWarps]; + T values[Capacity]; + }; + + // Each thread passes the reduction of the LAST SEGMENT that it covers. + // flag is set to true if there's at least one segment flag in the thread. + // SegScan returns the reduction of values for the first segment in this + // thread over the preceding threads. + // Return the value init for the first thread. + + // When scanning single elements per thread, interpret the flag as a BEGIN + // FLAG. If tid's flag is set, its value belongs to thread tid + 1, not + // thread tid. + + // The function returns the reduction of the last segment in the CTA. + + MGPU_DEVICE static T SegScanDelta( + int tid, + int tidDelta, + T x, + Storage& storage, + T* carryOut, + T identity = (T) 0, + Op op = Op() + ) { + + // Run an inclusive scan + int first = 0; + storage.values[first + tid] = x; + __syncthreads(); + +#pragma unroll + for(int offset = 1; offset < NT; offset += offset) { + if(tidDelta >= offset) + x = op(storage.values[first + tid - offset], x); + first = NT - first; + storage.values[first + tid] = x; + __syncthreads(); + } + + // Get the exclusive scan. + x = tid ? storage.values[first + tid - 1] : identity; + *carryOut = storage.values[first + NT - 1]; + __syncthreads(); + return x; + } + + MGPU_DEVICE static T SegScan( + int tid, + T x, + bool flag, + Storage& storage, + T* carryOut, + T identity = (T) 0, + Op op = Op() + ) { + + // Find the left-most thread that covers the first segment of this + // thread. 
+ int tidDelta = DeviceFindSegScanDelta(tid, flag, storage.delta); + + return SegScanDelta(tid, tidDelta, x, storage, carryOut, identity, op); + } }; } // namespace mgpu diff --git a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/devicetypes.cuh b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/devicetypes.cuh index 0620f21..b09144d 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/devicetypes.cuh +++ b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/devicetypes.cuh @@ -1,6 +1,6 @@ /***************************************************************************** * Copyright (c) 2013, NVIDIA CORPORATION. All rights reserved. - * + * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright @@ -11,10 +11,10 @@ * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; @@ -26,17 +26,17 @@ ******************************************************************************/ /****************************************************************************** - * - * Code and text by Sean Baxter, NVIDIA Research - * See http://nvlabs.github.io/moderngpu for repository and documentation. - * - ******************************************************************************/ +* +* Code and text by Sean Baxter, NVIDIA Research +* See http://nvlabs.github.io/moderngpu for repository and documentation. +* +******************************************************************************/ #pragma once #if __CUDA_ARCH__ == 100 - #error "COMPUTE CAPABILITY 1.0 NOT SUPPORTED BY MPGU. TRY 2.0!" -#endif +#error "COMPUTE CAPABILITY 1.0 NOT SUPPORTED BY MPGU. TRY 2.0!" +#endif #include #include "../util/static.h" @@ -61,27 +61,27 @@ const int LOG_WARP_SIZE = 5; template struct less : public std::binary_function { - MGPU_HOST_DEVICE bool operator()(T a, T b) { return a < b; } + MGPU_HOST_DEVICE bool operator()(T a, T b) { return a < b; } }; template struct less_equal : public std::binary_function { - MGPU_HOST_DEVICE bool operator()(T a, T b) { return a <= b; } + MGPU_HOST_DEVICE bool operator()(T a, T b) { return a <= b; } }; template struct greater : public std::binary_function { - MGPU_HOST_DEVICE bool operator()(T a, T b) { return a > b; } + MGPU_HOST_DEVICE bool operator()(T a, T b) { return a > b; } }; template struct greater_equal : public std::binary_function { - MGPU_HOST_DEVICE bool operator()(T a, T b) { return a >= b; } + MGPU_HOST_DEVICE bool operator()(T a, T b) { return a >= b; } }; template struct equal_to : public std::binary_function { - MGPU_HOST_DEVICE bool operator()(T a, T b) { return a == b; } + MGPU_HOST_DEVICE bool operator()(T a, 
T b) { return a == b; } }; template struct not_equal_to : public std::binary_function { - MGPU_HOST_DEVICE bool operator()(T a, T b) { return a != b; } + MGPU_HOST_DEVICE bool operator()(T a, T b) { return a != b; } }; //////////////////////////////////////////////////////////////////////////////// @@ -89,275 +89,259 @@ struct not_equal_to : public std::binary_function { template struct plus : public std::binary_function { - MGPU_HOST_DEVICE T operator()(T a, T b) { return a + b; } + MGPU_HOST_DEVICE T operator()(T a, T b) { return a + b; } }; template struct minus : public std::binary_function { - MGPU_HOST_DEVICE T operator()(T a, T b) { return a - b; } + MGPU_HOST_DEVICE T operator()(T a, T b) { return a - b; } }; template struct multiplies : public std::binary_function { - MGPU_HOST_DEVICE T operator()(T a, T b) { return a * b; } + MGPU_HOST_DEVICE T operator()(T a, T b) { return a * b; } }; template struct modulus : public std::binary_function { - MGPU_HOST_DEVICE T operator()(T a, T b) { return a % b; } + MGPU_HOST_DEVICE T operator()(T a, T b) { return a % b; } }; template struct bit_or : public std::binary_function { - MGPU_HOST_DEVICE T operator()(T a, T b) { return a | b; } + MGPU_HOST_DEVICE T operator()(T a, T b) { return a | b; } }; template struct bit_and : public std::binary_function { - MGPU_HOST_DEVICE T operator()(T a, T b) { return a & b; } + MGPU_HOST_DEVICE T operator()(T a, T b) { return a & b; } }; template struct bit_xor : public std::binary_function { - MGPU_HOST_DEVICE T operator()(T a, T b) { return a ^ b; } + MGPU_HOST_DEVICE T operator()(T a, T b) { return a ^ b; } }; template struct maximum : public std::binary_function { - MGPU_HOST_DEVICE T operator()(T a, T b) { return max(a, b); } + MGPU_HOST_DEVICE T operator()(T a, T b) { return max(a, b); } }; template struct minimum : public std::binary_function { - MGPU_HOST_DEVICE T operator()(T a, T b) { return min(a, b); } + MGPU_HOST_DEVICE T operator()(T a, T b) { return min(a, b); } }; 
//////////////////////////////////////////////////////////////////////////////// template MGPU_HOST_DEVICE void swap(T& a, T& b) { - T c = a; - a = b; - b = c; + T c = a; + a = b; + b = c; } template struct DevicePair { - T x, y; + T x, y; }; template MGPU_HOST_DEVICE DevicePair MakeDevicePair(T x, T y) { - DevicePair p = { x, y }; - return p; + DevicePair p = {x, y}; + return p; } -template struct numeric_limits; -template<> struct numeric_limits { - MGPU_HOST_DEVICE static int min() { return INT_MIN; } - MGPU_HOST_DEVICE static int max() { return INT_MAX; } - MGPU_HOST_DEVICE static int lowest() { return INT_MIN; } - MGPU_HOST_DEVICE static int AddIdent() { return 0; } - MGPU_HOST_DEVICE static int MulIdent() { return 1; } +template +struct numeric_limits; +template<> +struct numeric_limits { + MGPU_HOST_DEVICE static int min() { return INT_MIN; } + MGPU_HOST_DEVICE static int max() { return INT_MAX; } + MGPU_HOST_DEVICE static int lowest() { return INT_MIN; } + MGPU_HOST_DEVICE static int AddIdent() { return 0; } + MGPU_HOST_DEVICE static int MulIdent() { return 1; } }; -template<> struct numeric_limits { - MGPU_HOST_DEVICE static long long min() { return LLONG_MIN; } - MGPU_HOST_DEVICE static long long max() { return LLONG_MAX; } - MGPU_HOST_DEVICE static long long lowest() { return LLONG_MIN; } - MGPU_HOST_DEVICE static long long AddIdent() { return 0; } - MGPU_HOST_DEVICE static long long MulIdent() { return 1; } +template<> +struct numeric_limits { + MGPU_HOST_DEVICE static long long min() { return LLONG_MIN; } + MGPU_HOST_DEVICE static long long max() { return LLONG_MAX; } + MGPU_HOST_DEVICE static long long lowest() { return LLONG_MIN; } + MGPU_HOST_DEVICE static long long AddIdent() { return 0; } + MGPU_HOST_DEVICE static long long MulIdent() { return 1; } }; -template<> struct numeric_limits { - MGPU_HOST_DEVICE static uint min() { return 0; } - MGPU_HOST_DEVICE static uint max() { return UINT_MAX; } - MGPU_HOST_DEVICE static uint lowest() { return 0; } 
- MGPU_HOST_DEVICE static uint AddIdent() { return 0; } - MGPU_HOST_DEVICE static uint MulIdent() { return 1; } +template<> +struct numeric_limits { + MGPU_HOST_DEVICE static uint min() { return 0; } + MGPU_HOST_DEVICE static uint max() { return UINT_MAX; } + MGPU_HOST_DEVICE static uint lowest() { return 0; } + MGPU_HOST_DEVICE static uint AddIdent() { return 0; } + MGPU_HOST_DEVICE static uint MulIdent() { return 1; } }; -template<> struct numeric_limits { - MGPU_HOST_DEVICE static unsigned long long min() { return 0; } - MGPU_HOST_DEVICE static unsigned long long max() { return ULLONG_MAX; } - MGPU_HOST_DEVICE static unsigned long long lowest() { return 0; } - MGPU_HOST_DEVICE static unsigned long long AddIdent() { return 0; } - MGPU_HOST_DEVICE static unsigned long long MulIdent() { return 1; } +template<> +struct numeric_limits { + MGPU_HOST_DEVICE static unsigned long long min() { return 0; } + MGPU_HOST_DEVICE static unsigned long long max() { return ULLONG_MAX; } + MGPU_HOST_DEVICE static unsigned long long lowest() { return 0; } + MGPU_HOST_DEVICE static unsigned long long AddIdent() { return 0; } + MGPU_HOST_DEVICE static unsigned long long MulIdent() { return 1; } }; -template<> struct numeric_limits { - MGPU_HOST_DEVICE static float min() { return FLT_MIN; } - MGPU_HOST_DEVICE static float max() { return FLT_MAX; } - MGPU_HOST_DEVICE static float lowest() { return -FLT_MAX; } - MGPU_HOST_DEVICE static float AddIdent() { return 0; } - MGPU_HOST_DEVICE static float MulIdent() { return 1; } +template<> +struct numeric_limits { + MGPU_HOST_DEVICE static float min() { return FLT_MIN; } + MGPU_HOST_DEVICE static float max() { return FLT_MAX; } + MGPU_HOST_DEVICE static float lowest() { return -FLT_MAX; } + MGPU_HOST_DEVICE static float AddIdent() { return 0; } + MGPU_HOST_DEVICE static float MulIdent() { return 1; } }; -template<> struct numeric_limits { - MGPU_HOST_DEVICE static double min() { return DBL_MIN; } - MGPU_HOST_DEVICE static double max() { return 
DBL_MAX; } - MGPU_HOST_DEVICE static double lowest() { return -DBL_MAX; } - MGPU_HOST_DEVICE static double AddIdent() { return 0; } - MGPU_HOST_DEVICE static double MulIdent() { return 1; } +template<> +struct numeric_limits { + MGPU_HOST_DEVICE static double min() { return DBL_MIN; } + MGPU_HOST_DEVICE static double max() { return DBL_MAX; } + MGPU_HOST_DEVICE static double lowest() { return -DBL_MAX; } + MGPU_HOST_DEVICE static double AddIdent() { return 0; } + MGPU_HOST_DEVICE static double MulIdent() { return 1; } }; - -MGPU_HOST_DEVICE int2 operator+(int2 a, int2 b) { - return make_int2(a.x + b.x, a.y + b.y); -} +MGPU_HOST_DEVICE int2 operator+(int2 a, int2 b) { return make_int2(a.x + b.x, a.y + b.y); } MGPU_HOST_DEVICE int2& operator+=(int2& a, int2 b) { - a = a + b; - return a; -} -MGPU_HOST_DEVICE int2 operator*(int2 a, int2 b) { - return make_int2(a.x * b.x, a.y * b.y); + a = a + b; + return a; } +MGPU_HOST_DEVICE int2 operator*(int2 a, int2 b) { return make_int2(a.x * b.x, a.y * b.y); } MGPU_HOST_DEVICE int2& operator*=(int2& a, int2 b) { - a = a * b; - return a; + a = a * b; + return a; } template MGPU_HOST_DEVICE T max(T a, T b) { #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 100) - return std::max(a, b); + return std::max(a, b); #else - return (a < b) ? b : a; + return (a < b) ? b : a; #endif } template MGPU_HOST_DEVICE T min(T a, T b) { #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 100) - return std::min(a, b); + return std::min(a, b); #else - return (b < a) ? b : a; + return (b < a) ? 
b : a; #endif } -MGPU_HOST_DEVICE int2 max(int2 a, int2 b) { - return make_int2(max(a.x, b.x), max(a.y, b.y)); -} +MGPU_HOST_DEVICE int2 max(int2 a, int2 b) { return make_int2(max(a.x, b.x), max(a.y, b.y)); } -MGPU_HOST_DEVICE int2 min(int2 a, int2 b) { - return make_int2(min(a.x, b.x), min(a.y, b.y)); -} +MGPU_HOST_DEVICE int2 min(int2 a, int2 b) { return make_int2(min(a.x, b.x), min(a.y, b.y)); } -template<> struct numeric_limits { - MGPU_HOST_DEVICE static int2 min() { return make_int2(INT_MIN, INT_MIN); } - MGPU_HOST_DEVICE static int2 max() { return make_int2(INT_MAX, INT_MAX); } - MGPU_HOST_DEVICE static int2 lowest() { - return make_int2(INT_MIN, INT_MIN); - } - MGPU_HOST_DEVICE static int2 AddIdent() { return make_int2(0, 0); } - MGPU_HOST_DEVICE static int2 MulIdent() { return make_int2(1, 1); } +template<> +struct numeric_limits { + MGPU_HOST_DEVICE static int2 min() { return make_int2(INT_MIN, INT_MIN); } + MGPU_HOST_DEVICE static int2 max() { return make_int2(INT_MAX, INT_MAX); } + MGPU_HOST_DEVICE static int2 lowest() { return make_int2(INT_MIN, INT_MIN); } + MGPU_HOST_DEVICE static int2 AddIdent() { return make_int2(0, 0); } + MGPU_HOST_DEVICE static int2 MulIdent() { return make_int2(1, 1); } }; template class constant_iterator : public std::iterator_traits { public: - MGPU_HOST_DEVICE constant_iterator(T value) : _value(value) { } - - MGPU_HOST_DEVICE T operator[](ptrdiff_t i) const { - return _value; - } - MGPU_HOST_DEVICE T operator*() const { - return _value; - } - MGPU_HOST_DEVICE constant_iterator operator+(ptrdiff_t diff) const { - return constant_iterator(_value); - } - MGPU_HOST_DEVICE constant_iterator operator-(ptrdiff_t diff) const { - return constant_iterator(_value); - } - MGPU_HOST_DEVICE constant_iterator& operator+=(ptrdiff_t diff) { - return *this; - } - MGPU_HOST_DEVICE constant_iterator& operator-=(ptrdiff_t diff) { - return *this; - } + MGPU_HOST_DEVICE constant_iterator(T value) : _value(value) {} + + MGPU_HOST_DEVICE T 
operator[](ptrdiff_t i) const { + return _value; + } + MGPU_HOST_DEVICE T operator*() const { + return _value; + } + MGPU_HOST_DEVICE constant_iterator operator+(ptrdiff_t diff) const { + return constant_iterator(_value); + } + MGPU_HOST_DEVICE constant_iterator operator-(ptrdiff_t diff) const { + return constant_iterator(_value); + } + MGPU_HOST_DEVICE constant_iterator& operator+=(ptrdiff_t diff) { return *this; } + MGPU_HOST_DEVICE constant_iterator& operator-=(ptrdiff_t diff) { return *this; } + private: - T _value; + T _value; }; template class counting_iterator : public std::iterator_traits { public: - MGPU_HOST_DEVICE counting_iterator(T value) : _value(value) { } - - MGPU_HOST_DEVICE T operator[](ptrdiff_t i) { - return _value + i; - } - MGPU_HOST_DEVICE T operator*() { - return _value; - } - MGPU_HOST_DEVICE counting_iterator operator+(ptrdiff_t diff) { - return counting_iterator(_value + diff); - } - MGPU_HOST_DEVICE counting_iterator operator-(ptrdiff_t diff) { - return counting_iterator(_value - diff); - } - MGPU_HOST_DEVICE counting_iterator& operator+=(ptrdiff_t diff) { - _value += diff; - return *this; - } - MGPU_HOST_DEVICE counting_iterator& operator-=(ptrdiff_t diff) { - _value -= diff; - return *this; - } + MGPU_HOST_DEVICE counting_iterator(T value) : _value(value) {} + + MGPU_HOST_DEVICE T operator[](ptrdiff_t i) { return _value + i; } + MGPU_HOST_DEVICE T operator*() { return _value; } + MGPU_HOST_DEVICE counting_iterator operator+(ptrdiff_t diff) { return counting_iterator(_value + diff); } + MGPU_HOST_DEVICE counting_iterator operator-(ptrdiff_t diff) { return counting_iterator(_value - diff); } + MGPU_HOST_DEVICE counting_iterator& operator+=(ptrdiff_t diff) { + _value += diff; + return *this; + } + MGPU_HOST_DEVICE counting_iterator& operator-=(ptrdiff_t diff) { + _value -= diff; + return *this; + } + private: - T _value; + T _value; }; template class step_iterator : public std::iterator_traits { public: - MGPU_HOST_DEVICE step_iterator(T 
base, T step) : - _base(base), _step(step), _offset(0) { } - - MGPU_HOST_DEVICE T operator[](ptrdiff_t i) { - return _base + (_offset + i) * _step; - } - MGPU_HOST_DEVICE T operator*() { - return _base + _offset * _step; - } - MGPU_HOST_DEVICE step_iterator operator+(ptrdiff_t diff) { - step_iterator it = *this; - it._offset += diff; - return it; - } - MGPU_HOST_DEVICE step_iterator operator-(ptrdiff_t diff) { - step_iterator it = *this; - it._offset -= diff; - return it; - } - MGPU_HOST_DEVICE step_iterator& operator+=(ptrdiff_t diff) { - _offset += diff; - return *this; - } - MGPU_HOST_DEVICE step_iterator& operator-=(ptrdiff_t diff) { - _offset -= diff; - return *this; - } + MGPU_HOST_DEVICE step_iterator(T base, T step) : _base(base), + _step(step), + _offset(0) {} + + MGPU_HOST_DEVICE T operator[](ptrdiff_t i) { return _base + (_offset + i) * _step; } + MGPU_HOST_DEVICE T operator*() { return _base + _offset * _step; } + MGPU_HOST_DEVICE step_iterator operator+(ptrdiff_t diff) { + step_iterator it = *this; + it._offset += diff; + return it; + } + MGPU_HOST_DEVICE step_iterator operator-(ptrdiff_t diff) { + step_iterator it = *this; + it._offset -= diff; + return it; + } + MGPU_HOST_DEVICE step_iterator& operator+=(ptrdiff_t diff) { + _offset += diff; + return *this; + } + MGPU_HOST_DEVICE step_iterator& operator-=(ptrdiff_t diff) { + _offset -= diff; + return *this; + } + private: - ptrdiff_t _offset; - T _base, _step; + ptrdiff_t _offset; + T _base, _step; }; } // namespace mgpu - template -MGPU_HOST_DEVICE mgpu::counting_iterator operator+(ptrdiff_t diff, - mgpu::counting_iterator it) { - return it + diff; -} +MGPU_HOST_DEVICE mgpu::counting_iterator operator+( + ptrdiff_t diff, + mgpu::counting_iterator it +) { return it + diff; } template -MGPU_HOST_DEVICE mgpu::counting_iterator operator-(ptrdiff_t diff, - mgpu::counting_iterator it) { - return it + (-diff); -} +MGPU_HOST_DEVICE mgpu::counting_iterator operator-( + ptrdiff_t diff, + 
mgpu::counting_iterator it +) { return it + (-diff); } template -MGPU_HOST_DEVICE mgpu::step_iterator operator+(ptrdiff_t diff, - mgpu::step_iterator it) { - return it + diff; -} +MGPU_HOST_DEVICE mgpu::step_iterator operator+( + ptrdiff_t diff, + mgpu::step_iterator it +) { return it + diff; } template -MGPU_HOST_DEVICE mgpu::step_iterator operator-(ptrdiff_t diff, - mgpu::step_iterator it) { - return it + (-diff); -} +MGPU_HOST_DEVICE mgpu::step_iterator operator-( + ptrdiff_t diff, + mgpu::step_iterator it +) { return it + (-diff); } diff --git a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/deviceutil.cuh b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/deviceutil.cuh index 8852768..419a518 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/deviceutil.cuh +++ b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/deviceutil.cuh @@ -26,11 +26,11 @@ ******************************************************************************/ /****************************************************************************** - * - * Code and text by Sean Baxter, NVIDIA Research - * See http://nvlabs.github.io/moderngpu for repository and documentation. - * - ******************************************************************************/ +* +* Code and text by Sean Baxter, NVIDIA Research +* See http://nvlabs.github.io/moderngpu for repository and documentation. +* +******************************************************************************/ #pragma once @@ -39,43 +39,41 @@ namespace mgpu { // Get the difference between two pointers in bytes. -MGPU_HOST_DEVICE ptrdiff_t PtrDiff(const void* a, const void* b) { - return (const byte*)b - (const byte*)a; -} +MGPU_HOST_DEVICE ptrdiff_t PtrDiff(const void* a, const void* b) { return (const byte*) b - (const byte*) a; } // Offset a pointer by i bytes. 
template -MGPU_HOST_DEVICE const T* PtrOffset(const T* p, ptrdiff_t i) { - return (const T*)((const byte*)p + i); -} +MGPU_HOST_DEVICE const T* PtrOffset(const T* p, ptrdiff_t i) { return (const T*) ((const byte*) p + i); } template -MGPU_HOST_DEVICE T* PtrOffset(T* p, ptrdiff_t i) { - return (T*)((byte*)p + i); -} +MGPU_HOST_DEVICE T* PtrOffset(T* p, ptrdiff_t i) { return (T*) ((byte*) p + i); } //////////////////////////////////////////////////////////////////////////////// // Task range support // Evenly distributes variable-length arrays over a fixed number of CTAs. MGPU_HOST int2 DivideTaskRange(int numItems, int numWorkers) { - div_t d = div(numItems, numWorkers); - return make_int2(d.quot, d.rem); + div_t d = div(numItems, numWorkers); + return make_int2(d.quot, d.rem); } MGPU_HOST_DEVICE int2 ComputeTaskRange(int block, int2 task) { - int2 range; - range.x = task.x * block; - range.x += min(block, task.y); - range.y = range.x + task.x + (block < task.y); - return range; + int2 range; + range.x = task.x * block; + range.x += min(block, task.y); + range.y = range.x + task.x + (block < task.y); + return range; } -MGPU_HOST_DEVICE int2 ComputeTaskRange(int block, int2 task, int blockSize, - int count) { - int2 range = ComputeTaskRange(block, task); - range.x *= blockSize; - range.y = min(count, range.y * blockSize); - return range; +MGPU_HOST_DEVICE int2 ComputeTaskRange( + int block, + int2 task, + int blockSize, + int count +) { + int2 range = ComputeTaskRange(block, task); + range.x *= blockSize; + range.y = min(count, range.y * blockSize); + return range; } //////////////////////////////////////////////////////////////////////////////// @@ -83,19 +81,22 @@ MGPU_HOST_DEVICE int2 ComputeTaskRange(int block, int2 task, int blockSize, // Input array flags is a bit array with 32 head flags per word. // ExtractThreadHeadFlags returns numBits flags starting at bit index. 
-MGPU_HOST_DEVICE uint DeviceExtractHeadFlags(const uint* flags, int index, - int numBits) { - - int index2 = index>> 5; - int shift = 31 & index; - uint headFlags = flags[index2]>> shift; - int shifted = 32 - shift; - - if(shifted < numBits) - // We also need to shift in the next set of bits. - headFlags = bfi(flags[index2 + 1], headFlags, shifted, shift); - headFlags &= (1<< numBits) - 1; - return headFlags; +MGPU_HOST_DEVICE uint DeviceExtractHeadFlags( + const uint* flags, + int index, + int numBits +) { + + int index2 = index >> 5; + int shift = 31 & index; + uint headFlags = flags[index2] >> shift; + int shifted = 32 - shift; + + if(shifted < numBits) + // We also need to shift in the next set of bits. + headFlags = bfi(flags[index2 + 1], headFlags, shifted, shift); + headFlags &= (1 << numBits) - 1; + return headFlags; } //////////////////////////////////////////////////////////////////////////////// @@ -105,39 +106,41 @@ MGPU_HOST_DEVICE uint DeviceExtractHeadFlags(const uint* flags, int index, // return packed words. template -MGPU_DEVICE uint DevicePackHeadFlags(uint threadBits, int tid, - uint* flags_shared) { - - const int WordCount = NT * VT / 32; - - // Each thread stores its thread bits to flags_shared[tid]. - flags_shared[tid] = threadBits; - __syncthreads(); - - uint packed = 0; - if(tid < WordCount) { - const int Items = MGPU_DIV_UP(32, VT); - int index = 32 * tid; - int first = index / VT; - int bit = 0; - - int rem = index - VT * first; - packed = flags_shared[first]>> rem; - bit = VT - rem; - ++first; - - #pragma unroll - for(int i = 0; i < Items; ++i) { - if(i < Items - 1 || bit < 32) { - uint x = flags_shared[first + i]; - if(bit < 32) packed |= x<< bit; - bit += VT; - } - } - } - __syncthreads(); - - return packed; +MGPU_DEVICE uint DevicePackHeadFlags( + uint threadBits, + int tid, + uint* flags_shared +) { + + const int WordCount = NT * VT / 32; + + // Each thread stores its thread bits to flags_shared[tid]. 
+ flags_shared[tid] = threadBits; + __syncthreads(); + + uint packed = 0; + if(tid < WordCount) { + const int Items = MGPU_DIV_UP(32, VT); + int index = 32 * tid; + int first = index / VT; + int bit = 0; + + int rem = index - VT * first; + packed = flags_shared[first] >> rem; + bit = VT - rem; + ++first; + +#pragma unroll + for(int i = 0; i < Items; ++i) + if(i < Items - 1 || bit < 32) { + uint x = flags_shared[first + i]; + if(bit < 32) packed |= x << bit; + bit += VT; + } + } + __syncthreads(); + + return packed; } } // namespace mgpu diff --git a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/intrinsics.cuh b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/intrinsics.cuh index 3f37978..ebb3d39 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/intrinsics.cuh +++ b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/intrinsics.cuh @@ -26,11 +26,11 @@ ******************************************************************************/ /****************************************************************************** - * - * Code and text by Sean Baxter, NVIDIA Research - * See http://nvlabs.github.io/moderngpu for repository and documentation. - * - ******************************************************************************/ +* +* Code and text by Sean Baxter, NVIDIA Research +* See http://nvlabs.github.io/moderngpu for repository and documentation. 
+* +******************************************************************************/ #include "devicetypes.cuh" @@ -41,40 +41,19 @@ namespace mgpu { -MGPU_HOST_DEVICE uint2 ulonglong_as_uint2(uint64 x) { - return *reinterpret_cast(&x); -} -MGPU_HOST_DEVICE uint64 uint2_as_ulonglong(uint2 x) { - return *reinterpret_cast(&x); -} - -MGPU_HOST_DEVICE int2 longlong_as_int2(int64 x) { - return *reinterpret_cast(&x); -} -MGPU_HOST_DEVICE int64 int2_as_longlong(int2 x) { - return *reinterpret_cast(&x); -} +MGPU_HOST_DEVICE uint2 ulonglong_as_uint2(uint64 x) { return *reinterpret_cast(&x); } +MGPU_HOST_DEVICE uint64 uint2_as_ulonglong(uint2 x) { return *reinterpret_cast(&x); } -MGPU_HOST_DEVICE int2 double_as_int2(double x) { - return *reinterpret_cast(&x); -} -MGPU_HOST_DEVICE double int2_as_double(int2 x) { - return *reinterpret_cast(&x); -} +MGPU_HOST_DEVICE int2 longlong_as_int2(int64 x) { return *reinterpret_cast(&x); } +MGPU_HOST_DEVICE int64 int2_as_longlong(int2 x) { return *reinterpret_cast(&x); } -MGPU_HOST_DEVICE void SetDoubleX(double& d, int x) { - reinterpret_cast(&d)[0] = x; -} -MGPU_HOST_DEVICE int GetDoubleX(double d) { - return double_as_int2(d).x; -} -MGPU_HOST_DEVICE void SetDoubleY(double& d, int y) { - reinterpret_cast(&d)[1] = y; -} -MGPU_HOST_DEVICE int GetDoubleY(double d) { - return double_as_int2(d).y; -} +MGPU_HOST_DEVICE int2 double_as_int2(double x) { return *reinterpret_cast(&x); } +MGPU_HOST_DEVICE double int2_as_double(int2 x) { return *reinterpret_cast(&x); } +MGPU_HOST_DEVICE void SetDoubleX(double& d, int x) { reinterpret_cast(&d)[0] = x; } +MGPU_HOST_DEVICE int GetDoubleX(double d) { return double_as_int2(d).x; } +MGPU_HOST_DEVICE void SetDoubleY(double& d, int y) { reinterpret_cast(&d)[1] = y; } +MGPU_HOST_DEVICE int GetDoubleY(double d) { return double_as_int2(d).y; } //////////////////////////////////////////////////////////////////////////////// // PTX for bfe and bfi @@ -82,85 +61,89 @@ MGPU_HOST_DEVICE int GetDoubleY(double d) { #if 
__CUDA_ARCH__ >= 200 MGPU_DEVICE uint bfe_ptx(uint x, uint bit, uint numBits) { - uint result; - asm("bfe.u32 %0, %1, %2, %3;" : - "=r"(result) : "r"(x), "r"(bit), "r"(numBits)); - return result; + uint result; + asm ("bfe.u32 %0, %1, %2, %3;" : + "=r" (result) : "r" (x), "r" (bit), "r" (numBits)); + return result; } - MGPU_DEVICE uint bfi_ptx(uint x, uint y, uint bit, uint numBits) { - uint result; - asm("bfi.b32 %0, %1, %2, %3, %4;" : - "=r"(result) : "r"(x), "r"(y), "r"(bit), "r"(numBits)); - return result; + uint result; + asm ("bfi.b32 %0, %1, %2, %3, %4;" : + "=r" (result) : "r" (x), "r" (y), "r" (bit), "r" (numBits)); + return result; } MGPU_DEVICE uint prmt_ptx(uint a, uint b, uint index) { - uint ret; - asm("prmt.b32 %0, %1, %2, %3;" : "=r"(ret) : "r"(a), "r"(b), "r"(index)); - return ret; + uint ret; + asm ("prmt.b32 %0, %1, %2, %3;" : "=r" (ret) : "r" (a), "r" (b), "r" (index)); + return ret; } #endif // __CUDA_ARCH__ >= 200 - //////////////////////////////////////////////////////////////////////////////// // shfl_up -__device__ __forceinline__ float shfl_up(float var, - unsigned int delta, int width = 32) { +__device__ __forceinline__ float shfl_up( + float var, + unsigned int delta, + int width = 32 +) { #if __CUDA_ARCH__ >= 300 - var = __shfl_up_sync(0xffffffff, var, delta, width); + var = __shfl_up_sync(0xffffffff, var, delta, width); #endif - return var; + return var; } -__device__ __forceinline__ double shfl_up(double var, - unsigned int delta, int width = 32) { +__device__ __forceinline__ double shfl_up( + double var, + unsigned int delta, + int width = 32 +) { #if __CUDA_ARCH__ >= 300 - int2 p = mgpu::double_as_int2(var); - p.x = __shfl_up_sync(0xffffffff, p.x, delta, width); - p.y = __shfl_up_sync(0xffffffff, p.y, delta, width); - var = mgpu::int2_as_double(p); + int2 p = mgpu::double_as_int2(var); + p.x = __shfl_up_sync(0xffffffff, p.x, delta, width); + p.y = __shfl_up_sync(0xffffffff, p.y, delta, width); + var = mgpu::int2_as_double(p); #endif 
- return var; + return var; } //////////////////////////////////////////////////////////////////////////////// // shfl_add MGPU_DEVICE int shfl_add(int x, int offset, int width = WARP_SIZE) { - int result = 0; + int result = 0; #if __CUDA_ARCH__ >= 300 - int mask = (WARP_SIZE - width)<< 8; - asm( - "{.reg .s32 r0;" - ".reg .pred p;" - "shfl.sync.up.b32 r0|p, %1, %2, %3, 0xffffffff;" - "@p add.s32 r0, r0, %4;" - "mov.s32 %0, r0; }" - : "=r"(result) : "r"(x), "r"(offset), "r"(mask), "r"(x)); + int mask = (WARP_SIZE - width) << 8; + asm ( + "{.reg .s32 r0;" + ".reg .pred p;" + "shfl.sync.up.b32 r0|p, %1, %2, %3, 0xffffffff;" + "@p add.s32 r0, r0, %4;" + "mov.s32 %0, r0; }" + : "=r" (result) : "r" (x), "r" (offset), "r" (mask), "r" (x)); #endif - return result; + return result; } MGPU_DEVICE int shfl_max(int x, int offset, int width = WARP_SIZE) { - int result = 0; + int result = 0; #if __CUDA_ARCH__ >= 300 - int mask = (WARP_SIZE - width)<< 8; - asm( - "{.reg .s32 r0;" - ".reg .pred p;" - "shfl.sync.up.b32 r0|p, %1, %2, %3, 0xffffffff;" - "@p max.s32 r0, r0, %4;" - "mov.s32 %0, r0; }" - : "=r"(result) : "r"(x), "r"(offset), "r"(mask), "r"(x)); + int mask = (WARP_SIZE - width) << 8; + asm ( + "{.reg .s32 r0;" + ".reg .pred p;" + "shfl.sync.up.b32 r0|p, %1, %2, %3, 0xffffffff;" + "@p max.s32 r0, r0, %4;" + "mov.s32 %0, r0; }" + : "=r" (result) : "r" (x), "r" (offset), "r" (mask), "r" (x)); #endif - return result; + return result; } //////////////////////////////////////////////////////////////////////////////// @@ -169,92 +152,92 @@ MGPU_DEVICE int shfl_max(int x, int offset, int width = WARP_SIZE) { // Reverse the bits in an integer. MGPU_HOST_DEVICE uint brev(uint x) { #if __CUDA_ARCH__ >= 200 - uint y = __brev(x); + uint y = __brev(x); #else - uint y = 0; - for(int i = 0; i < 32; ++i) - y |= (1 & (x>> i))<< (31 - i); + uint y = 0; + for(int i = 0; i < 32; ++i) + y |= (1 & (x >> i)) << (31 - i); #endif - return y; + return y; } // Count number of bits in a register. 
MGPU_HOST_DEVICE int popc(uint x) { #if __CUDA_ARCH__ >= 200 - return __popc(x); + return __popc(x); #else - int c; - for(c = 0; x; ++c) - x &= x - 1; - return c; + int c; + for(c = 0; x; ++c) + x &= x - 1; + return c; #endif } // Count leading zeros - start from most significant bit. MGPU_HOST_DEVICE int clz(int x) { #if __CUDA_ARCH__ >= 200 - return __clz(x); + return __clz(x); #else - for(int i = 31; i >= 0; --i) - if((1<< i) & x) return 31 - i; - return 32; + for(int i = 31; i >= 0; --i) + if((1 << i) & x)return 31 - i; + return 32; #endif } // Find first set - start from least significant bit. LSB is 1. ffs(0) is 0. MGPU_HOST_DEVICE int ffs(int x) { #if __CUDA_ARCH__ >= 200 - return __ffs(x); + return __ffs(x); #else - for(int i = 0; i < 32; ++i) - if((1<< i) & x) return i + 1; - return 0; + for(int i = 0; i < 32; ++i) + if((1 << i) & x)return i + 1; + return 0; #endif } MGPU_HOST_DEVICE uint bfe(uint x, uint bit, uint numBits) { #if __CUDA_ARCH__ >= 200 - return bfe_ptx(x, bit, numBits); + return bfe_ptx(x, bit, numBits); #else - return ((1<< numBits) - 1) & (x>> bit); + return ((1 << numBits) - 1) & (x >> bit); #endif } MGPU_HOST_DEVICE uint bfi(uint x, uint y, uint bit, uint numBits) { - uint result; + uint result; #if __CUDA_ARCH__ >= 200 - result = bfi_ptx(x, y, bit, numBits); + result = bfi_ptx(x, y, bit, numBits); #else - if(bit + numBits > 32) numBits = 32 - bit; - uint mask = ((1<< numBits) - 1)<< bit; - result = y & ~mask; - result |= mask & (x<< bit); + if(bit + numBits > 32) numBits = 32 - bit; + uint mask = ((1 << numBits) - 1) << bit; + result = y & ~mask; + result |= mask & (x << bit); #endif - return result; + return result; } MGPU_HOST_DEVICE uint prmt(uint a, uint b, uint index) { - uint result; + uint result; #if __CUDA_ARCH__ >= 200 - result = prmt_ptx(a, b, index); + result = prmt_ptx(a, b, index); #else - result = 0; - for(int i = 0; i < 4; ++i) { - uint sel = 0xf & (index>> (4 * i)); - uint x = ((7 & sel) > 3) ? 
b : a; - x = 0xff & (x>> (8 * (3 & sel))); - if(8 & sel) x = (128 & x) ? 0xff : 0; - result |= x<< (8 * i); - } + result = 0; + for(int i = 0; i < 4; ++i) { + uint sel = 0xf & (index >> (4 * i)); + uint x = ((7 & sel) > 3) ? b : a; + x = 0xff & (x >> (8 * (3 & sel))); + if(8 & sel) x = (128 & x) ? 0xff : 0; + result |= x << (8 * i); + } #endif - return result; + return result; } // Find log2(x) and optionally round up to the next integer logarithm. MGPU_HOST_DEVICE int FindLog2(int x, bool roundUp = false) { - int a = 31 - clz(x); - if(roundUp) a += !MGPU_IS_POW_2(x); - return a; + int a = 31 - clz(x); + if(roundUp) a += !MGPU_IS_POW_2(x); + return a; } //////////////////////////////////////////////////////////////////////////////// @@ -265,45 +248,45 @@ MGPU_HOST_DEVICE int FindLog2(int x, bool roundUp = false) { // Performs four byte-wise comparisons and returns 1 for each byte that // satisfies the conditional, and zero otherwise. MGPU_DEVICE uint vset4_lt_add_ptx(uint a, uint b, uint c) { - uint result; - asm("vset4.u32.u32.lt.add %0, %1, %2, %3;" : - "=r"(result) : "r"(a), "r"(b), "r"(c)); - return result; + uint result; + asm ("vset4.u32.u32.lt.add %0, %1, %2, %3;" : + "=r" (result) : "r" (a), "r" (b), "r" (c)); + return result; } MGPU_DEVICE uint vset4_eq_ptx(uint a, uint b) { - uint result; - asm("vset4.u32.u32.eq %0, %1, %2, %3;" : - "=r"(result) : "r"(a), "r"(b), "r"(0)); - return result; + uint result; + asm ("vset4.u32.u32.eq %0, %1, %2, %3;" : + "=r" (result) : "r" (a), "r" (b), "r" (0)); + return result; } #endif // __CUDA_ARCH__ >= 300 MGPU_HOST_DEVICE uint vset4_lt_add(uint a, uint b, uint c) { - uint result; + uint result; #if __CUDA_ARCH__ >= 300 - result = vset4_lt_add_ptx(a, b, c); + result = vset4_lt_add_ptx(a, b, c); #else - result = c; - if((0x000000ff & a) < (0x000000ff & b)) result += 0x00000001; - if((0x0000ff00 & a) < (0x0000ff00 & b)) result += 0x00000100; - if((0x00ff0000 & a) < (0x00ff0000 & b)) result += 0x00010000; - if((0xff000000 & 
a) < (0xff000000 & b)) result += 0x01000000; + result = c; + if((0x000000ff & a) < (0x000000ff & b)) result += 0x00000001; + if((0x0000ff00 & a) < (0x0000ff00 & b)) result += 0x00000100; + if((0x00ff0000 & a) < (0x00ff0000 & b)) result += 0x00010000; + if((0xff000000 & a) < (0xff000000 & b)) result += 0x01000000; #endif - return result; + return result; } MGPU_HOST_DEVICE uint vset4_eq(uint a, uint b) { - uint result; + uint result; #if __CUDA_ARCH__ >= 300 - result = vset4_eq_ptx(a, b); + result = vset4_eq_ptx(a, b); #else - result = 0; - if((0x000000ff & a) == (0x000000ff & b)) result = 0x00000001; - if((0x0000ff00 & a) == (0x0000ff00 & b)) result += 0x00000100; - if((0x00ff0000 & a) == (0x00ff0000 & b)) result += 0x00010000; - if((0xff000000 & a) == (0xff000000 & b)) result += 0x01000000; + result = 0; + if((0x000000ff & a) == (0x000000ff & b)) result = 0x00000001; + if((0x0000ff00 & a) == (0x0000ff00 & b)) result += 0x00000100; + if((0x00ff0000 & a) == (0x00ff0000 & b)) result += 0x00010000; + if((0xff000000 & a) == (0xff000000 & b)) result += 0x01000000; #endif - return result; + return result; } //////////////////////////////////////////////////////////////////////////////// @@ -311,10 +294,10 @@ MGPU_HOST_DEVICE uint vset4_eq(uint a, uint b) { MGPU_HOST_DEVICE uint umulhi(uint x, uint y) { #if __CUDA_ARCH__ >= 100 - return __umulhi(x, y); + return __umulhi(x, y); #else - uint64 product = (uint64)x * y; - return (uint)(product>> 32); + uint64 product = (uint64) x * y; + return (uint) (product >> 32); #endif } @@ -325,62 +308,58 @@ MGPU_HOST_DEVICE uint umulhi(uint x, uint y) { template struct IsLdgType { - enum { value = false }; + enum {value = false}; }; #define DEFINE_LDG_TYPE(T) \ - template<> struct IsLdgType { enum { value = true }; }; + template<> \ + struct IsLdgType {enum {value = true};}; template::value> struct LdgShim { - MGPU_DEVICE static T Ldg(const T* p) { - return *p; - } + MGPU_DEVICE static T Ldg(const T* p) { return *p; } }; #if 
__CUDA_ARCH__ >= 320 && __CUDA_ARCH__ < 400 - // List of __ldg-compatible types from sm_32_intrinsics.h. - DEFINE_LDG_TYPE(char) - DEFINE_LDG_TYPE(short) - DEFINE_LDG_TYPE(int) - DEFINE_LDG_TYPE(long long) - DEFINE_LDG_TYPE(char2) - DEFINE_LDG_TYPE(char4) - DEFINE_LDG_TYPE(short2) - DEFINE_LDG_TYPE(short4) - DEFINE_LDG_TYPE(int2) - DEFINE_LDG_TYPE(int4) - DEFINE_LDG_TYPE(longlong2) - - DEFINE_LDG_TYPE(unsigned char) - DEFINE_LDG_TYPE(unsigned short) - DEFINE_LDG_TYPE(unsigned int) - DEFINE_LDG_TYPE(unsigned long long) - DEFINE_LDG_TYPE(uchar2) - DEFINE_LDG_TYPE(uchar4) - DEFINE_LDG_TYPE(ushort2) - DEFINE_LDG_TYPE(ushort4) - DEFINE_LDG_TYPE(uint2) - DEFINE_LDG_TYPE(uint4) - DEFINE_LDG_TYPE(ulonglong2) - - DEFINE_LDG_TYPE(float) - DEFINE_LDG_TYPE(double) - DEFINE_LDG_TYPE(float2) - DEFINE_LDG_TYPE(float4) - DEFINE_LDG_TYPE(double2) - - template struct LdgShim { - MGPU_DEVICE static T Ldg(const T* p) { - return __ldg(p); - } - }; +// List of __ldg-compatible types from sm_32_intrinsics.h. +DEFINE_LDG_TYPE(char) +DEFINE_LDG_TYPE(short) +DEFINE_LDG_TYPE(int) +DEFINE_LDG_TYPE(long long) +DEFINE_LDG_TYPE(char2) +DEFINE_LDG_TYPE(char4) +DEFINE_LDG_TYPE(short2) +DEFINE_LDG_TYPE(short4) +DEFINE_LDG_TYPE(int2) +DEFINE_LDG_TYPE(int4) +DEFINE_LDG_TYPE(longlong2) + +DEFINE_LDG_TYPE(unsigned char) +DEFINE_LDG_TYPE(unsigned short) +DEFINE_LDG_TYPE(unsigned int) +DEFINE_LDG_TYPE(unsigned long long) +DEFINE_LDG_TYPE(uchar2) +DEFINE_LDG_TYPE(uchar4) +DEFINE_LDG_TYPE(ushort2) +DEFINE_LDG_TYPE(ushort4) +DEFINE_LDG_TYPE(uint2) +DEFINE_LDG_TYPE(uint4) +DEFINE_LDG_TYPE(ulonglong2) + +DEFINE_LDG_TYPE(float) +DEFINE_LDG_TYPE(double) +DEFINE_LDG_TYPE(float2) +DEFINE_LDG_TYPE(float4) +DEFINE_LDG_TYPE(double2) + +template +struct LdgShim { + MGPU_DEVICE static T Ldg(const T* p) { return __ldg(p); } +}; #endif template -MGPU_DEVICE T ldg(const T* p) { - return LdgShim::Ldg(p); -} +MGPU_DEVICE T ldg(const T* p) { return LdgShim::Ldg(p); } 
//////////////////////////////////////////////////////////////////////////////// @@ -388,23 +367,19 @@ MGPU_DEVICE T ldg(const T* p) { // Uses the method in Hacker's Delight (2nd edition) page 228. // Evaluates for denom > 1 and x < 2^31. struct FastDivide { - uint denom; - uint coef; - uint shift; - - MGPU_HOST_DEVICE uint Divide(uint x) { - return umulhi(x, coef)>> shift; - } - MGPU_HOST_DEVICE uint Modulus(uint x) { - return x - Divide(x) * denom; - } - - explicit FastDivide(uint denom_) { - denom = denom_; - uint p = 31 + FindLog2(denom, true); - coef = (uint)(((1ull<< p) + denom - 1) / denom); - shift = p - 32; - } + uint denom; + uint coef; + uint shift; + + MGPU_HOST_DEVICE uint Divide(uint x) { return umulhi(x, coef) >> shift; } + MGPU_HOST_DEVICE uint Modulus(uint x) { return x - Divide(x) * denom; } + + explicit FastDivide(uint denom_) { + denom = denom_; + uint p = 31 + FindLog2(denom, true); + coef = (uint) (((1ull << p) + denom - 1) / denom); + shift = p - 32; + } }; #pragma GCC diagnostic pop diff --git a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/loadstore.cuh b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/loadstore.cuh index fbe05b6..920369d 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/loadstore.cuh +++ b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/loadstore.cuh @@ -26,11 +26,11 @@ ******************************************************************************/ /****************************************************************************** - * - * Code and text by Sean Baxter, NVIDIA Research - * See http://nvlabs.github.io/moderngpu for repository and documentation. - * - ******************************************************************************/ +* +* Code and text by Sean Baxter, NVIDIA Research +* See http://nvlabs.github.io/moderngpu for repository and documentation. 
+* +******************************************************************************/ #pragma once @@ -44,149 +44,197 @@ namespace mgpu { // Cooperative load functions. template -MGPU_DEVICE void DeviceSharedToReg(InputIt data, int tid, T* reg, - bool sync) { +MGPU_DEVICE void DeviceSharedToReg( + InputIt data, + int tid, + T* reg, + bool sync +) { - #pragma unroll - for(int i = 0; i < VT; ++i) - reg[i] = data[NT * i + tid]; +#pragma unroll + for(int i = 0; i < VT; ++i) + reg[i] = data[NT * i + tid]; - if(sync) __syncthreads(); + if(sync) __syncthreads(); } template -MGPU_DEVICE void DeviceGlobalToRegPred(int count, InputIt data, int tid, - T* reg, bool sync) { - - // TODO: Attempt to issue 4 loads at a time. - #pragma unroll - for(int i = 0; i < VT; ++i) { - int index = NT * i + tid; - if(index < count) reg[i] = data[index]; - } - if(sync) __syncthreads(); +MGPU_DEVICE void DeviceGlobalToRegPred( + int count, + InputIt data, + int tid, + T* reg, + bool sync +) { + + // TODO: Attempt to issue 4 loads at a time. 
+#pragma unroll + for(int i = 0; i < VT; ++i) { + int index = NT * i + tid; + if(index < count) reg[i] = data[index]; + } + if(sync) __syncthreads(); } template -MGPU_DEVICE void DeviceGlobalToReg(int count, InputIt data, int tid, - T* reg, bool sync) { - - if(count >= NT * VT) { - #pragma unroll - for(int i = 0; i < VT; ++i) - reg[i] = data[NT * i + tid]; - } else - DeviceGlobalToRegPred(count, data, tid, reg, false); - if(sync) __syncthreads(); +MGPU_DEVICE void DeviceGlobalToReg( + int count, + InputIt data, + int tid, + T* reg, + bool sync +) { + + if(count >= NT * VT) { +#pragma unroll + for(int i = 0; i < VT; ++i) + reg[i] = data[NT * i + tid]; + } else + DeviceGlobalToRegPred(count, data, tid, reg, false); + if(sync) __syncthreads(); } template -MGPU_DEVICE void DeviceGlobalToReg2(int count, InputIt data, int tid, - T* reg, bool sync) { - - DeviceGlobalToReg(count, data, tid, reg, false); - #pragma unroll - for(int i = VT0; i < VT1; ++i) { - int index = NT * i + tid; - if(index < count) reg[i] = data[index]; - } - if(sync) __syncthreads(); +MGPU_DEVICE void DeviceGlobalToReg2( + int count, + InputIt data, + int tid, + T* reg, + bool sync +) { + + DeviceGlobalToReg(count, data, tid, reg, false); +#pragma unroll + for(int i = VT0; i < VT1; ++i) { + int index = NT * i + tid; + if(index < count) reg[i] = data[index]; + } + if(sync) __syncthreads(); } template -MGPU_DEVICE void DeviceGlobalToRegDefault(int count, InputIt data, int tid, - T* reg, T init, bool sync) { - - if(count >= NT * VT) { - #pragma unroll - for(int i = 0; i < VT; ++i) - reg[i] = data[NT * i + tid]; - } else { - #pragma unroll - for(int i = 0; i < VT; ++i) { - int index = NT * i + tid; - reg[i] = init; - if(index < count) reg[i] = data[index]; - } - } - if(sync) __syncthreads(); +MGPU_DEVICE void DeviceGlobalToRegDefault( + int count, + InputIt data, + int tid, + T* reg, + T init, + bool sync +) { + + if(count >= NT * VT) { +#pragma unroll + for(int i = 0; i < VT; ++i) + reg[i] = data[NT * i + 
tid]; + } else { +#pragma unroll + for(int i = 0; i < VT; ++i) { + int index = NT * i + tid; + reg[i] = init; + if(index < count) reg[i] = data[index]; + } + } + if(sync) __syncthreads(); } template -MGPU_DEVICE void DeviceGlobalToRegDefault2(int count, InputIt data, int tid, - T* reg, T init, bool sync) { - - DeviceGlobalToRegDefault(count, data, tid, reg, init, false); - #pragma unroll - for(int i = VT0; i < VT1; ++i) { - int index = NT * i + tid; - reg[i] = init; - if(index < count) reg[i] = data[index]; - } - if(sync) __syncthreads(); +MGPU_DEVICE void DeviceGlobalToRegDefault2( + int count, + InputIt data, + int tid, + T* reg, + T init, + bool sync +) { + + DeviceGlobalToRegDefault(count, data, tid, reg, init, false); +#pragma unroll + for(int i = VT0; i < VT1; ++i) { + int index = NT * i + tid; + reg[i] = init; + if(index < count) reg[i] = data[index]; + } + if(sync) __syncthreads(); } //////////////////////////////////////////////////////////////////////////////// template -MGPU_DEVICE void DeviceGlobalToThread(int count, InputIt data, int tid, - T* reg) { - - data += VT * tid; - if(count >= NT * VT) { - #pragma unroll - for(int i = 0; i < VT; ++i) - reg[i] = ldg(data + i); - } else { - count -= VT * tid; - #pragma unroll - for(int i = 0; i < VT; ++i) - if(i < count) reg[i] = ldg(data + i); - } +MGPU_DEVICE void DeviceGlobalToThread( + int count, + InputIt data, + int tid, + T* reg +) { + + data += VT * tid; + if(count >= NT * VT) { +#pragma unroll + for(int i = 0; i < VT; ++i) + reg[i] = ldg(data + i); + } else { + count -= VT * tid; +#pragma unroll + for(int i = 0; i < VT; ++i) + if(i < count) reg[i] = ldg(data + i); + } } template -MGPU_DEVICE void DeviceGlobalToThreadDefault(int count, InputIt data, int tid, - T* reg, T init) { - - data += VT * tid; - if(count >= NT * VT) { - #pragma unroll - for(int i = 0; i < VT; ++i) - reg[i] = ldg(data + i); - } else { - count -= VT * tid; - #pragma unroll - for(int i = 0; i < VT; ++i) - reg[i] = (i < count) ? 
ldg(data + i) : init; - } +MGPU_DEVICE void DeviceGlobalToThreadDefault( + int count, + InputIt data, + int tid, + T* reg, + T init +) { + + data += VT * tid; + if(count >= NT * VT) { +#pragma unroll + for(int i = 0; i < VT; ++i) + reg[i] = ldg(data + i); + } else { + count -= VT * tid; +#pragma unroll + for(int i = 0; i < VT; ++i) + reg[i] = (i < count) ? ldg(data + i) : init; + } } - //////////////////////////////////////////////////////////////////////////////// // Cooperative store functions. template -MGPU_DEVICE void DeviceRegToShared(const T* reg, int tid, - OutputIt dest, bool sync) { +MGPU_DEVICE void DeviceRegToShared( + const T* reg, + int tid, + OutputIt dest, + bool sync +) { - typedef typename std::iterator_traits::value_type T2; - #pragma unroll - for(int i = 0; i < VT; ++i) - dest[NT * i + tid] = (T2)reg[i]; + typedef typename std::iterator_traits::value_type T2; +#pragma unroll + for(int i = 0; i < VT; ++i) + dest[NT * i + tid] = (T2) reg[i]; - if(sync) __syncthreads(); + if(sync) __syncthreads(); } template -MGPU_DEVICE void DeviceRegToGlobal(int count, const T* reg, int tid, - OutputIt dest, bool sync) { - - #pragma unroll - for(int i = 0; i < VT; ++i) { - int index = NT * i + tid; - if(index < count) - dest[index] = reg[i]; - } - if(sync) __syncthreads(); +MGPU_DEVICE void DeviceRegToGlobal( + int count, + const T* reg, + int tid, + OutputIt dest, + bool sync +) { + +#pragma unroll + for(int i = 0; i < VT; ++i) { + int index = NT * i + tid; + if(index < count) + dest[index] = reg[i]; + } + if(sync) __syncthreads(); } //////////////////////////////////////////////////////////////////////////////// @@ -196,283 +244,364 @@ MGPU_DEVICE void DeviceRegToGlobal(int count, const T* reg, int tid, // unnecessary comparison logic. template -MGPU_DEVICE void DeviceMemToMem4(int count, InputIt source, int tid, - OutputIt dest, bool sync) { - - typedef typename std::iterator_traits::value_type T; - - T x[VT]; - const int Count = (VT < 4) ? 
VT : 4; - if(count >= NT * VT) { - #pragma unroll - for(int i = 0; i < Count; ++i) - x[i] = source[NT * i + tid]; - #pragma unroll - for(int i = 0; i < Count; ++i) - dest[NT * i + tid] = x[i]; - } else { - #pragma unroll - for(int i = 0; i < Count; ++i) { - int index = NT * i + tid; - if(index < count) - x[i] = source[NT * i + tid]; - } - #pragma unroll - for(int i = 0; i < Count; ++i) { - int index = NT * i + tid; - if(index < count) - dest[index] = x[i]; - } - } - if(sync) __syncthreads(); +MGPU_DEVICE void DeviceMemToMem4( + int count, + InputIt source, + int tid, + OutputIt dest, + bool sync +) { + + typedef typename std::iterator_traits::value_type T; + + T x[VT]; + const int Count = (VT < 4) ? VT : 4; + if(count >= NT * VT) { +#pragma unroll + for(int i = 0; i < Count; ++i) + x[i] = source[NT * i + tid]; +#pragma unroll + for(int i = 0; i < Count; ++i) + dest[NT * i + tid] = x[i]; + } else { +#pragma unroll + for(int i = 0; i < Count; ++i) { + int index = NT * i + tid; + if(index < count) + x[i] = source[NT * i + tid]; + } +#pragma unroll + for(int i = 0; i < Count; ++i) { + int index = NT * i + tid; + if(index < count) + dest[index] = x[i]; + } + } + if(sync) __syncthreads(); } template -MGPU_DEVICE void DeviceMemToMemLoop(int count, InputIt source, int tid, - OutputIt dest, bool sync) { - - for(int i = 0; i < count; i += 4 * NT) - DeviceMemToMem4(count - i, source + i, tid, dest + i, - false); - if(sync) __syncthreads(); +MGPU_DEVICE void DeviceMemToMemLoop( + int count, + InputIt source, + int tid, + OutputIt dest, + bool sync +) { + + for(int i = 0; i < count; i += 4 * NT) + DeviceMemToMem4( + count - i, + source + i, + tid, + dest + i, + false + ); + if(sync) __syncthreads(); } - //////////////////////////////////////////////////////////////////////////////// // Functions to copy between shared and global memory where the average case is // to transfer NT * VT elements. 
template -MGPU_DEVICE void DeviceSharedToGlobal(int count, const T* source, int tid, - OutputIt dest, bool sync) { - - typedef typename std::iterator_traits::value_type T2; - #pragma unroll - for(int i = 0; i < VT; ++i) { - int index = NT * i + tid; - if(index < count) dest[index] = (T2)source[index]; - } - if(sync) __syncthreads(); +MGPU_DEVICE void DeviceSharedToGlobal( + int count, + const T* source, + int tid, + OutputIt dest, + bool sync +) { + + typedef typename std::iterator_traits::value_type T2; +#pragma unroll + for(int i = 0; i < VT; ++i) { + int index = NT * i + tid; + if(index < count) dest[index] = (T2) source[index]; + } + if(sync) __syncthreads(); } template -MGPU_DEVICE void DeviceGlobalToShared(int count, InputIt source, int tid, - T* dest, bool sync) { +MGPU_DEVICE void DeviceGlobalToShared( + int count, + InputIt source, + int tid, + T* dest, + bool sync +) { - T reg[VT]; - DeviceGlobalToReg(count, source, tid, reg, false); - DeviceRegToShared(reg, tid, dest, sync); + T reg[VT]; + DeviceGlobalToReg(count, source, tid, reg, false); + DeviceRegToShared(reg, tid, dest, sync); } template -MGPU_DEVICE void DeviceGlobalToShared2(int count, InputIt source, int tid, - T* dest, bool sync) { +MGPU_DEVICE void DeviceGlobalToShared2( + int count, + InputIt source, + int tid, + T* dest, + bool sync +) { - T reg[VT1]; - DeviceGlobalToReg2(count, source, tid, reg, false); - DeviceRegToShared(reg, tid, dest, sync); + T reg[VT1]; + DeviceGlobalToReg2(count, source, tid, reg, false); + DeviceRegToShared(reg, tid, dest, sync); } - template -MGPU_DEVICE void DeviceGlobalToSharedDefault(int count, InputIt source, int tid, - T* dest, T init, bool sync) { +MGPU_DEVICE void DeviceGlobalToSharedDefault( + int count, + InputIt source, + int tid, + T* dest, + T init, + bool sync +) { - T reg[VT]; - DeviceGlobalToRegDefault(count, source, tid, reg, init, false); - DeviceRegToShared(reg, tid, dest, sync); + T reg[VT]; + DeviceGlobalToRegDefault(count, source, tid, reg, 
init, false); + DeviceRegToShared(reg, tid, dest, sync); } template -MGPU_DEVICE void DeviceGlobalToSharedDefault2(int count, InputIt data, int tid, - T* dest, T init, bool sync) { +MGPU_DEVICE void DeviceGlobalToSharedDefault2( + int count, + InputIt data, + int tid, + T* dest, + T init, + bool sync +) { - T reg[VT1]; - DeviceGlobalToRegDefault2(count, data, tid, reg, init, false); - DeviceRegToShared(reg, tid, dest, sync); + T reg[VT1]; + DeviceGlobalToRegDefault2(count, data, tid, reg, init, false); + DeviceRegToShared(reg, tid, dest, sync); } - //////////////////////////////////////////////////////////////////////////////// template -MGPU_DEVICE void DeviceGlobalToSharedLoop(int count, InputIt source, int tid, - T* dest, bool sync) { - - const int Granularity = MGPU_MIN(VT, 3); - DeviceGlobalToShared(count, source, tid, dest, false); - - int offset = Granularity * NT; - if(count > offset) - DeviceGlobalToShared(count - offset, - source + offset, tid, dest + offset, false); - - if(sync) __syncthreads(); - - /* - source += tid; - while(count > 0) { - T reg[Granularity]; - #pragma unroll - for(int i = 0; i < Granularity; ++i) { - int index = NT * i + tid; - if(index < count) - reg[i] = source[NT * i]; - } - DeviceRegToShared(reg, tid, dest, false); - source += Granularity * NT; - dest += Granularity * NT; - count -= Granularity * NT; - } - if(sync) __syncthreads();*/ +MGPU_DEVICE void DeviceGlobalToSharedLoop( + int count, + InputIt source, + int tid, + T* dest, + bool sync +) { + + const int Granularity = MGPU_MIN(VT, 3); + DeviceGlobalToShared(count, source, tid, dest, false); + + int offset = Granularity * NT; + if(count > offset) + DeviceGlobalToShared( + count - offset, + source + offset, + tid, + dest + offset, + false + ); + + if(sync) __syncthreads(); + + /* + source += tid; + while(count > 0) { + T reg[Granularity]; + #pragma unroll + for(int i = 0; i < Granularity; ++i) { + int index = NT * i + tid; + if(index < count) + reg[i] = source[NT * i]; + } + 
DeviceRegToShared(reg, tid, dest, false); + source += Granularity * NT; + dest += Granularity * NT; + count -= Granularity * NT; + } + if(sync) __syncthreads();*/ } template -MGPU_DEVICE void DeviceGlobalToGlobal(int count, InputIt source, int tid, - OutputIt dest, bool sync) { +MGPU_DEVICE void DeviceGlobalToGlobal( + int count, + InputIt source, + int tid, + OutputIt dest, + bool sync +) { - typedef typename std::iterator_traits::value_type T; - T values[VT]; - DeviceGlobalToReg(count, source, tid, values, false); - DeviceRegToGlobal(count, values, tid, dest, sync); + typedef typename std::iterator_traits::value_type T; + T values[VT]; + DeviceGlobalToReg(count, source, tid, values, false); + DeviceRegToGlobal(count, values, tid, dest, sync); } //////////////////////////////////////////////////////////////////////////////// // Transponse VT elements in NT threads (x) into thread-order registers (y) // using only NT * VT / 2 elements of shared memory. -//This function definitely has a bug, don't use!!! fix TODO(erich) +// This function definitely has a bug, don't use!!! 
fix TODO(erich) template MGPU_DEVICE void HalfSmemTranspose(const T* x, int tid, T* shared, T* y) { printf("HalfSmemTranspose has a bug, use WAR SmemTranpose or find bug before using in production"); - // Transpose the first half values (tid < NT / 2) - #pragma unroll - for(int i = 0; i <= VT / 2; ++i) - if(i < VT / 2 || tid < NT / 2) - shared[NT * i + tid] = x[i]; - __syncthreads(); - - if(tid < NT / 2) { - #pragma unroll - for(int i = 0; i < VT; ++i) - y[i] = shared[VT * tid + i]; - } - __syncthreads(); - - // Transpose the second half values (tid >= NT / 2) - #pragma unroll - for(int i = VT / 2; i < VT; ++i) - if(i > VT / 2 || tid >= NT / 2) - shared[NT * i - NT * VT / 2 + tid] = x[i]; - __syncthreads(); - - if(tid >= NT / 2) { - #pragma unroll - for(int i = 0; i < VT; ++i) - y[i] = shared[VT * tid + i - NT * VT / 2]; - } - __syncthreads(); + // Transpose the first half values (tid < NT / 2) +#pragma unroll + for(int i = 0; i <= VT / 2; ++i) + if(i < VT / 2 || tid < NT / 2) + shared[NT * i + tid] = x[i]; + __syncthreads(); + + if(tid < NT / 2) { +#pragma unroll + for(int i = 0; i < VT; ++i) + y[i] = shared[VT * tid + i]; + } + __syncthreads(); + + // Transpose the second half values (tid >= NT / 2) +#pragma unroll + for(int i = VT / 2; i < VT; ++i) + if(i > VT / 2 || tid >= NT / 2) + shared[NT * i - NT * VT / 2 + tid] = x[i]; + __syncthreads(); + + if(tid >= NT / 2) { +#pragma unroll + for(int i = 0; i < VT; ++i) + y[i] = shared[VT * tid + i - NT * VT / 2]; + } + __syncthreads(); } //////////////////////////////////////////////////////////////////////////////// // Gather/scatter functions template -MGPU_DEVICE void DeviceGather(int count, InputIt data, int indices[VT], - int tid, T* reg, bool sync) { - - if(count >= NT * VT) { - #pragma unroll - for(int i = 0; i < VT; ++i) - reg[i] = data[indices[i]]; - } else { - #pragma unroll - for(int i = 0; i < VT; ++i) { - int index = NT * i + tid; - if(index < count) - reg[i] = data[indices[i]]; - } - } - if(sync) 
__syncthreads(); +MGPU_DEVICE void DeviceGather( + int count, + InputIt data, + int indices[VT], + int tid, + T* reg, + bool sync +) { + + if(count >= NT * VT) { +#pragma unroll + for(int i = 0; i < VT; ++i) + reg[i] = data[indices[i]]; + } else { +#pragma unroll + for(int i = 0; i < VT; ++i) { + int index = NT * i + tid; + if(index < count) + reg[i] = data[indices[i]]; + } + } + if(sync) __syncthreads(); } template -MGPU_DEVICE void DeviceGatherDefault(int count, InputIt data, int indices[VT], - int tid, T* reg, T identity, bool sync) { - - if(count >= NT * VT) { - #pragma unroll - for(int i = 0; i < VT; ++i) - reg[i] = data[indices[i]]; - } else { - #pragma unroll - for(int i = 0; i < VT; ++i) { - int index = NT * i + tid; - reg[i] = (index < count) ? data[indices[i]] : identity; - } - } - if(sync) __syncthreads(); +MGPU_DEVICE void DeviceGatherDefault( + int count, + InputIt data, + int indices[VT], + int tid, + T* reg, + T identity, + bool sync +) { + + if(count >= NT * VT) { +#pragma unroll + for(int i = 0; i < VT; ++i) + reg[i] = data[indices[i]]; + } else { +#pragma unroll + for(int i = 0; i < VT; ++i) { + int index = NT * i + tid; + reg[i] = (index < count) ? 
data[indices[i]] : identity; + } + } + if(sync) __syncthreads(); } template -MGPU_DEVICE void DeviceScatter(int count, const T* reg, int tid, - int indices[VT], OutputIt data, bool sync) { - - if(count >= NT * VT) { - #pragma unroll - for(int i = 0; i < VT; ++i) - data[indices[i]] = reg[i]; - } else { - #pragma unroll - for(int i = 0; i < VT; ++i) { - int index = NT * i + tid; - if(index < count) - data[indices[i]] = reg[i]; - } - } - if(sync) __syncthreads(); +MGPU_DEVICE void DeviceScatter( + int count, + const T* reg, + int tid, + int indices[VT], + OutputIt data, + bool sync +) { + + if(count >= NT * VT) { +#pragma unroll + for(int i = 0; i < VT; ++i) + data[indices[i]] = reg[i]; + } else { +#pragma unroll + for(int i = 0; i < VT; ++i) { + int index = NT * i + tid; + if(index < count) + data[indices[i]] = reg[i]; + } + } + if(sync) __syncthreads(); } //////////////////////////////////////////////////////////////////////////////// // Cooperative transpose functions (strided to thread order) template -MGPU_DEVICE void DeviceThreadToShared(const T* threadReg, int tid, T* shared, - bool sync) { - - if(1 & VT) { - // Odd grain size. Store as type T. - #pragma unroll - for(int i = 0; i < VT; ++i) - shared[VT * tid + i] = threadReg[i]; - } else { - // Even grain size. Store as DevicePair. This lets us exploit the - // 8-byte shared memory mode on Kepler. - DevicePair* dest = (DevicePair*)(shared + VT * tid); - #pragma unroll - for(int i = 0; i < VT / 2; ++i) - dest[i] = MakeDevicePair(threadReg[2 * i], threadReg[2 * i + 1]); - } - if(sync) __syncthreads(); +MGPU_DEVICE void DeviceThreadToShared( + const T* threadReg, + int tid, + T* shared, + bool sync +) { + + if(1 & VT) { + // Odd grain size. Store as type T. +#pragma unroll + for(int i = 0; i < VT; ++i) + shared[VT * tid + i] = threadReg[i]; + } else { + // Even grain size. Store as DevicePair. This lets us exploit the + // 8-byte shared memory mode on Kepler. 
+ DevicePair* dest = (DevicePair*) (shared + VT * tid); +#pragma unroll + for(int i = 0; i < VT / 2; ++i) + dest[i] = MakeDevicePair(threadReg[2 * i], threadReg[2 * i + 1]); + } + if(sync) __syncthreads(); } template -MGPU_DEVICE void DeviceSharedToThread(const T* shared, int tid, T* threadReg, - bool sync) { - - if(1 & VT) { - #pragma unroll - for(int i = 0; i < VT; ++i) - threadReg[i] = shared[VT * tid + i]; - } else { - const DevicePair* source = (const DevicePair*)(shared + VT * tid); - #pragma unroll - for(int i = 0; i < VT / 2; ++i) { - DevicePair p = source[i]; - threadReg[2 * i] = p.x; - threadReg[2 * i + 1] = p.y; - } - } - if(sync) __syncthreads(); +MGPU_DEVICE void DeviceSharedToThread( + const T* shared, + int tid, + T* threadReg, + bool sync +) { + + if(1 & VT) { +#pragma unroll + for(int i = 0; i < VT; ++i) + threadReg[i] = shared[VT * tid + i]; + } else { + const DevicePair* source = (const DevicePair*) (shared + VT * tid); +#pragma unroll + for(int i = 0; i < VT / 2; ++i) { + DevicePair p = source[i]; + threadReg[2 * i] = p.x; + threadReg[2 * i + 1] = p.y; + } + } + if(sync) __syncthreads(); } //////////////////////////////////////////////////////////////////////////////// @@ -480,108 +609,155 @@ MGPU_DEVICE void DeviceSharedToThread(const T* shared, int tid, T* threadReg, // statement. template -MGPU_DEVICE void DeviceLoad2ToReg(const T* a_global, int aCount, - const T* b_global, int bCount, int tid, T* reg, bool sync) { - - int b0 = b_global - a_global - aCount; - int total = aCount + bCount; - if(total >= NT * VT0) { - #pragma unroll - for(int i = 0; i < VT0; ++i) { - int index = NT * i + tid; - reg[i] = a_global[index + ((index >= aCount) ? b0 : 0)]; - } - } else { - #pragma unroll - for(int i = 0; i < VT0; ++i) { - int index = NT * i + tid; - if(index < total) - reg[i] = a_global[index + ((index >= aCount) ? 
b0 : 0)]; - } - } - #pragma unroll - for(int i = VT0; i < VT1; ++i) { - int index = NT * i + tid; - if(index < total) - reg[i] = a_global[index + ((index >= aCount) ? b0 : 0)]; - } +MGPU_DEVICE void DeviceLoad2ToReg( + const T* a_global, + int aCount, + const T* b_global, + int bCount, + int tid, + T* reg, + bool sync +) { + + int b0 = b_global - a_global - aCount; + int total = aCount + bCount; + if(total >= NT * VT0) { +#pragma unroll + for(int i = 0; i < VT0; ++i) { + int index = NT * i + tid; + reg[i] = a_global[index + ((index >= aCount) ? b0 : 0)]; + } + } else { +#pragma unroll + for(int i = 0; i < VT0; ++i) { + int index = NT * i + tid; + if(index < total) + reg[i] = a_global[index + ((index >= aCount) ? b0 : 0)]; + } + } +#pragma unroll + for(int i = VT0; i < VT1; ++i) { + int index = NT * i + tid; + if(index < total) + reg[i] = a_global[index + ((index >= aCount) ? b0 : 0)]; + } } template -MGPU_DEVICE void DeviceLoad2ToShared(const T* a_global, int aCount, - const T* b_global, int bCount, int tid, T* shared, bool sync) { - - T reg[VT1]; - DeviceLoad2ToReg(a_global, aCount, b_global, bCount, tid, - reg, false); - DeviceRegToShared(reg, tid, shared, sync); +MGPU_DEVICE void DeviceLoad2ToShared( + const T* a_global, + int aCount, + const T* b_global, + int bCount, + int tid, + T* shared, + bool sync +) { + + T reg[VT1]; + DeviceLoad2ToReg( + a_global, + aCount, + b_global, + bCount, + tid, + reg, + false + ); + DeviceRegToShared(reg, tid, shared, sync); } //////////////////////////////////////////////////////////////////////////////// // DeviceLoad2 - load from pointers of different types. Uses two LD statements. 
template -MGPU_DEVICE void DeviceLoad2ToReg(InputIt1 a_global, int aCount, - InputIt2 b_global, int bCount, int tid, T* reg, bool sync) { - - b_global -= aCount; - int total = aCount + bCount; - if(total >= NT * VT0) { - #pragma unroll - for(int i = 0; i < VT0; ++i) { - int index = NT * i + tid; - if(index < aCount) reg[i] = a_global[index]; - else reg[i] = b_global[index]; - } - } else { - #pragma unroll - for(int i = 0; i < VT0; ++i) { - int index = NT * i + tid; - if(index < aCount) reg[i] = a_global[index]; - else if(index < total) reg[i] = b_global[index]; - } - } - #pragma unroll - for(int i = VT0; i < VT1; ++i) { - int index = NT * i + tid; - if(index < aCount) reg[i] = a_global[index]; - else if(index < total) reg[i] = b_global[index]; - } + typename T> +MGPU_DEVICE void DeviceLoad2ToReg( + InputIt1 a_global, + int aCount, + InputIt2 b_global, + int bCount, + int tid, + T* reg, + bool sync +) { + + b_global -= aCount; + int total = aCount + bCount; + if(total >= NT * VT0) { +#pragma unroll + for(int i = 0; i < VT0; ++i) { + int index = NT * i + tid; + if(index < aCount) reg[i] = a_global[index]; + else reg[i] = b_global[index]; + } + } else { +#pragma unroll + for(int i = 0; i < VT0; ++i) { + int index = NT * i + tid; + if(index < aCount) reg[i] = a_global[index]; + else if(index < total) reg[i] = b_global[index]; + } + } +#pragma unroll + for(int i = VT0; i < VT1; ++i) { + int index = NT * i + tid; + if(index < aCount) reg[i] = a_global[index]; + else if(index < total) reg[i] = b_global[index]; + } } template -MGPU_DEVICE void DeviceLoad2ToShared(InputIt1 a_global, int aCount, - InputIt2 b_global, int bCount, int tid, T* shared, bool sync) { - - T reg[VT1]; - DeviceLoad2ToReg(a_global, aCount, b_global, bCount, tid, - reg, false); - DeviceRegToShared(reg, tid, shared, sync); + typename T> +MGPU_DEVICE void DeviceLoad2ToShared( + InputIt1 a_global, + int aCount, + InputIt2 b_global, + int bCount, + int tid, + T* shared, + bool sync +) { + + T reg[VT1]; + 
DeviceLoad2ToReg( + a_global, + aCount, + b_global, + bCount, + tid, + reg, + false + ); + DeviceRegToShared(reg, tid, shared, sync); } - //////////////////////////////////////////////////////////////////////////////// // DeviceGatherGlobalToGlobal template -MGPU_DEVICE void DeviceGatherGlobalToGlobal(int count, InputIt data_global, - const int* indices_shared, int tid, OutputIt dest_global, bool sync) { - - typedef typename std::iterator_traits::value_type ValType; - ValType values[VT]; - - #pragma unroll - for(int i = 0; i < VT; ++i) { - int index = NT * i + tid; - if(index < count) { - int gather = indices_shared[index]; - values[i] = data_global[gather]; - } - } - if(sync) __syncthreads(); - DeviceRegToGlobal(count, values, tid, dest_global, false); +MGPU_DEVICE void DeviceGatherGlobalToGlobal( + int count, + InputIt data_global, + const int* indices_shared, + int tid, + OutputIt dest_global, + bool sync +) { + + typedef typename std::iterator_traits::value_type ValType; + ValType values[VT]; + +#pragma unroll + for(int i = 0; i < VT; ++i) { + int index = NT * i + tid; + if(index < count) { + int gather = indices_shared[index]; + values[i] = data_global[gather]; + } + } + if(sync) __syncthreads(); + DeviceRegToGlobal(count, values, tid, dest_global, false); } //////////////////////////////////////////////////////////////////////////////// @@ -590,85 +766,128 @@ MGPU_DEVICE void DeviceGatherGlobalToGlobal(int count, InputIt data_global, // output. Like DeviceGatherGlobalToGlobal, but for two arrays at once. template -MGPU_DEVICE void DeviceTransferMergeValuesReg(int count, InputIt1 a_global, - InputIt2 b_global, int bStart, const int* indices, int tid, - T* reg, bool sync) { - - b_global -= bStart; - if(count >= NT * VT) { - #pragma unroll - for(int i = 0; i < VT; ++i) { - reg[i] = (indices[i] < bStart) ? 
a_global[indices[i]] : - b_global[indices[i]]; - } - } else { - #pragma unroll - for(int i = 0; i < VT; ++i) { - int index = NT * i + tid; - if(index < count) - reg[i] = (indices[i] < bStart) ? a_global[indices[i]] : - b_global[indices[i]]; - } - } - if(sync) __syncthreads(); + typename T> +MGPU_DEVICE void DeviceTransferMergeValuesReg( + int count, + InputIt1 a_global, + InputIt2 b_global, + int bStart, + const int* indices, + int tid, + T* reg, + bool sync +) { + + b_global -= bStart; + if(count >= NT * VT) { +#pragma unroll + for(int i = 0; i < VT; ++i) + reg[i] = (indices[i] < bStart) ? a_global[indices[i]] + : b_global[indices[i]]; + } else { +#pragma unroll + for(int i = 0; i < VT; ++i) { + int index = NT * i + tid; + if(index < count) + reg[i] = (indices[i] < bStart) ? a_global[indices[i]] + : b_global[indices[i]]; + } + } + if(sync) __syncthreads(); } template -MGPU_DEVICE void DeviceTransferMergeValuesShared(int count, InputIt1 a_global, - InputIt2 b_global, int bStart, const int* indices_shared, int tid, - OutputIt dest_global, bool sync) { - - int indices[VT]; - DeviceSharedToReg(indices_shared, tid, indices); - - typedef typename std::iterator_traits::value_type ValType; - ValType reg[VT]; - DeviceTransferMergeValuesReg(count, a_global, b_global, bStart, - indices, tid, reg, sync); - DeviceRegToGlobal(count, reg, tid, dest_global, sync); + typename OutputIt> +MGPU_DEVICE void DeviceTransferMergeValuesShared( + int count, + InputIt1 a_global, + InputIt2 b_global, + int bStart, + const int* indices_shared, + int tid, + OutputIt dest_global, + bool sync +) { + + int indices[VT]; + DeviceSharedToReg(indices_shared, tid, indices); + + typedef typename std::iterator_traits::value_type ValType; + ValType reg[VT]; + DeviceTransferMergeValuesReg( + count, + a_global, + b_global, + bStart, + indices, + tid, + reg, + sync + ); + DeviceRegToGlobal(count, reg, tid, dest_global, sync); } template -MGPU_DEVICE void DeviceTransferMergeValuesReg(int count, const T* 
a_global, - const T* b_global, int bStart, const int* indices, int tid, T* reg, - bool sync) { - - int bOffset = (int)(b_global - a_global - bStart); - - if(count >= NT * VT) { - #pragma unroll - for(int i = 0; i < VT; ++i) { - int gather = indices[i]; - if(gather >= bStart) gather += bOffset; - reg[i] = a_global[gather]; - } - } else { - #pragma unroll - for(int i = 0; i < VT; ++i) { - int index = NT * i + tid; - int gather = indices[i]; - if(gather >= bStart) gather += bOffset; - if(index < count) - reg[i] = a_global[gather]; - } - } - if(sync) __syncthreads(); +MGPU_DEVICE void DeviceTransferMergeValuesReg( + int count, + const T* a_global, + const T* b_global, + int bStart, + const int* indices, + int tid, + T* reg, + bool sync +) { + + int bOffset = (int) (b_global - a_global - bStart); + + if(count >= NT * VT) { +#pragma unroll + for(int i = 0; i < VT; ++i) { + int gather = indices[i]; + if(gather >= bStart) gather += bOffset; + reg[i] = a_global[gather]; + } + } else { +#pragma unroll + for(int i = 0; i < VT; ++i) { + int index = NT * i + tid; + int gather = indices[i]; + if(gather >= bStart) gather += bOffset; + if(index < count) + reg[i] = a_global[gather]; + } + } + if(sync) __syncthreads(); } template -MGPU_DEVICE void DeviceTransferMergeValuesShared(int count, const T* a_global, - const T* b_global, int bStart, const int* indices_shared, int tid, - OutputIt dest_global, bool sync) { - - int indices[VT]; - DeviceSharedToReg(indices_shared, tid, indices); - - T reg[VT]; - DeviceTransferMergeValuesReg(count, a_global, b_global, bStart, - indices, tid, reg, sync); - DeviceRegToGlobal(count, reg, tid, dest_global, sync); +MGPU_DEVICE void DeviceTransferMergeValuesShared( + int count, + const T* a_global, + const T* b_global, + int bStart, + const int* indices_shared, + int tid, + OutputIt dest_global, + bool sync +) { + + int indices[VT]; + DeviceSharedToReg(indices_shared, tid, indices); + + T reg[VT]; + DeviceTransferMergeValuesReg( + count, + a_global, + 
b_global, + bStart, + indices, + tid, + reg, + sync + ); + DeviceRegToGlobal(count, reg, tid, dest_global, sync); } } // namespace mgpu diff --git a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/sortnetwork.cuh b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/sortnetwork.cuh index ceead3d..c4b6fd8 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/sortnetwork.cuh +++ b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/device/sortnetwork.cuh @@ -26,11 +26,11 @@ ******************************************************************************/ /****************************************************************************** - * - * Code and text by Sean Baxter, NVIDIA Research - * See http://nvlabs.github.io/moderngpu for repository and documentation. - * - ******************************************************************************/ +* +* Code and text by Sean Baxter, NVIDIA Research +* See http://nvlabs.github.io/moderngpu for repository and documentation. +* +******************************************************************************/ #pragma once @@ -46,51 +46,54 @@ namespace mgpu { // CUDA Compiler does not currently unroll these loops correctly. Write using // template loop unrolling. 
/* -template -MGPU_DEVICE void OddEvenTransposeSort(T* keys, V* values, Comp comp) { - #pragma unroll - for(int level = 0; level < VT; ++level) { - - #pragma unroll - for(int i = 1 & level; i < VT - 1; i += 2) { - if(comp(keys[i + 1], keys[i])) { - mgpu::swap(keys[i], keys[i + 1]); - mgpu::swap(values[i], values[i + 1]); - } - } - } -}*/ + template + MGPU_DEVICE void OddEvenTransposeSort(T* keys, V* values, Comp comp) { + #pragma unroll + for(int level = 0; level < VT; ++level) { + + #pragma unroll + for(int i = 1 & level; i < VT - 1; i += 2) { + if(comp(keys[i + 1], keys[i])) { + mgpu::swap(keys[i], keys[i + 1]); + mgpu::swap(values[i], values[i + 1]); + } + } + } + }*/ template struct OddEvenTransposeSortT { - // Sort segments marked by head flags. If the head flag between i and i + 1 - // is set (so that (2<< i) & flags is true), the values belong to different - // segments and are not swapped. - template - static MGPU_DEVICE void Sort(K* keys, V* values, int flags, Comp comp) { - #pragma unroll - for(int i = 1 & I; i < VT - 1; i += 2) - if((0 == ((2<< i) & flags)) && comp(keys[i + 1], keys[i])) { - mgpu::swap(keys[i], keys[i + 1]); - mgpu::swap(values[i], values[i + 1]); - } - OddEvenTransposeSortT::Sort(keys, values, flags, comp); - } + // Sort segments marked by head flags. If the head flag between i and i + 1 + // is set (so that (2<< i) & flags is true), the values belong to different + // segments and are not swapped. 
+ template + static MGPU_DEVICE void Sort(K* keys, V* values, int flags, Comp comp) { +#pragma unroll + for(int i = 1 & I; i < VT - 1; i += 2) + if((0 == ((2 << i) & flags)) && comp(keys[i + 1], keys[i])) { + mgpu::swap(keys[i], keys[i + 1]); + mgpu::swap(values[i], values[i + 1]); + } + OddEvenTransposeSortT::Sort(keys, values, flags, comp); + } }; -template struct OddEvenTransposeSortT { - template - static MGPU_DEVICE void Sort(K* keys, V* values, int flags, Comp comp) { } +template +struct OddEvenTransposeSortT { + template + static MGPU_DEVICE void Sort(K* keys, V* values, int flags, Comp comp) {} }; template MGPU_DEVICE void OddEvenTransposeSort(K* keys, V* values, Comp comp) { - OddEvenTransposeSortT<0, VT>::Sort(keys, values, 0, comp); + OddEvenTransposeSortT<0, VT>::Sort(keys, values, 0, comp); } template -MGPU_DEVICE void OddEvenTransposeSortFlags(K* keys, V* values, int flags, - Comp comp) { - OddEvenTransposeSortT<0, VT>::Sort(keys, values, flags, comp); -} +MGPU_DEVICE void OddEvenTransposeSortFlags( + K* keys, + V* values, + int flags, + Comp comp +) { OddEvenTransposeSortT<0, VT>::Sort(keys, values, flags, comp); } //////////////////////////////////////////////////////////////////////////////// // Batcher Odd-Even Mergesort network @@ -99,70 +102,95 @@ MGPU_DEVICE void OddEvenTransposeSortFlags(K* keys, V* values, int flags, template struct OddEvenMergesortT { - template - MGPU_DEVICE static void CompareAndSwap(K* keys, V* values, int flags, - int a, int b, Comp comp) { - if(b < Count) { - // Mask the bits between a and b. Any head flags in this interval - // means the keys are in different segments and must not be swapped. 
- const int Mask = ((2<< b) - 1) ^ ((2<< a) - 1); - if(!(Mask & flags) && comp(keys[b], keys[a])) { - mgpu::swap(keys[b], keys[a]); - mgpu::swap(values[b], values[a]); - } - } - } - - template - struct OddEvenMerge { - template - MGPU_DEVICE static void Merge(K* keys, V* values, int flags, - Comp comp) { - // Compare and swap - const int M = 2 * R; - OddEvenMerge::Merge(keys, values, flags, comp); - OddEvenMerge::Merge(keys, values, flags, comp); - - #pragma unroll - for(int i = Low2 + R; i + R < Low2 + Width; i += M) - CompareAndSwap(keys, values, flags, i, i + R, comp); - } - }; - template - struct OddEvenMerge { - template - MGPU_DEVICE static void Merge(K* keys, V* values, int flags, - Comp comp) { - CompareAndSwap(keys, values, flags, Low2, Low2 + R, comp); - } - }; - - template - MGPU_DEVICE static void Sort(K* keys, V* values, int flags, - Comp comp) { - - const int M = Width / 2; - OddEvenMergesortT::Sort(keys, values, flags, comp); - OddEvenMergesortT::Sort(keys, values, flags, comp); - OddEvenMerge<1, Low>::Merge(keys, values, flags, comp); - } + template + MGPU_DEVICE static void CompareAndSwap( + K* keys, + V* values, + int flags, + int a, + int b, + Comp comp + ) { + if(b < Count) { + // Mask the bits between a and b. Any head flags in this interval + // means the keys are in different segments and must not be swapped. 
+ const int Mask = ((2 << b) - 1) ^ ((2 << a) - 1); + if(!(Mask & flags) && comp(keys[b], keys[a])) { + mgpu::swap(keys[b], keys[a]); + mgpu::swap(values[b], values[a]); + } + } + } + + template < int R, int Low2, bool Recurse = 2 * R + struct OddEvenMerge { + template + MGPU_DEVICE static void Merge( + K* keys, + V* values, + int flags, + Comp comp + ) { + // Compare and swap + const int M = 2 * R; + OddEvenMerge::Merge(keys, values, flags, comp); + OddEvenMerge::Merge(keys, values, flags, comp); + +#pragma unroll + for(int i = Low2 + R; i + R < Low2 + Width; i += M) + CompareAndSwap(keys, values, flags, i, i + R, comp); + } + }; + template + struct OddEvenMerge { + template + MGPU_DEVICE static void Merge( + K* keys, + V* values, + int flags, + Comp comp + ) { CompareAndSwap(keys, values, flags, Low2, Low2 + R, comp); } + }; + + template + MGPU_DEVICE static void Sort( + K* keys, + V* values, + int flags, + Comp comp + ) { + + const int M = Width / 2; + OddEvenMergesortT::Sort(keys, values, flags, comp); + OddEvenMergesortT::Sort(keys, values, flags, comp); + OddEvenMerge<1, Low>::Merge(keys, values, flags, comp); + } }; -template struct OddEvenMergesortT<1, Low, Count> { - template - MGPU_DEVICE static void Sort(K* keys, V* values, int flags, - Comp comp) { } +template +struct OddEvenMergesortT<1, Low, Count> { + template + MGPU_DEVICE static void Sort( + K* keys, + V* values, + int flags, + Comp comp + ) {} }; template MGPU_DEVICE void OddEvenMergesort(K* keys, V* values, Comp comp) { - const int Width = 1<< sLogPow2::value; - OddEvenMergesortT::Sort(keys, values, 0, comp); + const int Width = 1 << sLogPow2::value; + OddEvenMergesortT::Sort(keys, values, 0, comp); } template -MGPU_DEVICE void OddEvenMergesortFlags(K* keys, V* values, int flags, - Comp comp) { - const int Width = 1<< sLogPow2::value; - OddEvenMergesortT::Sort(keys, values, flags, comp); +MGPU_DEVICE void OddEvenMergesortFlags( + K* keys, + V* values, + int flags, + Comp comp +) { + const int 
Width = 1 << sLogPow2::value; + OddEvenMergesortT::Sort(keys, values, flags, comp); } } // namespace mgpu diff --git a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/mgpudevice.cuh b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/mgpudevice.cuh index f5000c6..031a6c5 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/mgpudevice.cuh +++ b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/mgpudevice.cuh @@ -26,11 +26,11 @@ ******************************************************************************/ /****************************************************************************** - * - * Code and text by Sean Baxter, NVIDIA Research - * See http://nvlabs.github.io/moderngpu for repository and documentation. - * - ******************************************************************************/ +* +* Code and text by Sean Baxter, NVIDIA Research +* See http://nvlabs.github.io/moderngpu for repository and documentation. +* +******************************************************************************/ #pragma once @@ -43,127 +43,217 @@ namespace mgpu { // device/loadstore.cuh // For 0 <= i < VT: -// index = NT * i + tid; -// reg[i] = data[index]; +// index = NT * i + tid; +// reg[i] = data[index]; // Synchronize after load. template -MGPU_DEVICE void DeviceSharedToReg(InputIt data, int tid, T* reg, - bool sync = true); +MGPU_DEVICE void DeviceSharedToReg( + InputIt data, + int tid, + T* reg, + bool sync = true +); // For 0 <= i < VT: -// index = NT * i + tid; -// if(index < count) reg[i] = data[index]; +// index = NT * i + tid; +// if(index < count) reg[i] = data[index]; // No synchronize after load. 
template -MGPU_DEVICE void DeviceGlobalToReg(int count, InputIt data, int tid, - T* reg, bool sync = false); +MGPU_DEVICE void DeviceGlobalToReg( + int count, + InputIt data, + int tid, + T* reg, + bool sync = false +); template -MGPU_DEVICE void DeviceGlobalToRegDefault(int count, InputIt data, int tid, - T* reg, T init, bool sync = false); +MGPU_DEVICE void DeviceGlobalToRegDefault( + int count, + InputIt data, + int tid, + T* reg, + T init, + bool sync = false +); // For 0 <= i < VT: -// index = NT * i + tid; -// if(index < count) reg[i] = data[index]; +// index = NT * i + tid; +// if(index < count) reg[i] = data[index]; // No synchronize after load. template -MGPU_DEVICE void DeviceGlobalToReg(int count, InputIt data, int tid, - T* reg, bool sync = false); +MGPU_DEVICE void DeviceGlobalToReg( + int count, + InputIt data, + int tid, + T* reg, + bool sync = false +); // For 0 <= i < VT: -// index = NT * i + tid; -// if(index < count) reg[i] = data[index]; +// index = NT * i + tid; +// if(index < count) reg[i] = data[index]; // No synchronize after load. template -MGPU_DEVICE void DeviceGlobalToRegDefault2(int count, InputIt data, int tid, - T* reg, T init, bool sync = false); +MGPU_DEVICE void DeviceGlobalToRegDefault2( + int count, + InputIt data, + int tid, + T* reg, + T init, + bool sync = false +); // For 0 <= i < VT: -// index = NT * i + tid; -// if(index < count) reg[i] = data[index]; +// index = NT * i + tid; +// if(index < count) reg[i] = data[index]; // No synchronize after load. // No optimized code path for count < NV (smaller generated code). template -MGPU_DEVICE void DeviceGlobalToRegLoop(int count, InputIt data, int tid, - T* reg, bool sync = false); - +MGPU_DEVICE void DeviceGlobalToRegLoop( + int count, + InputIt data, + int tid, + T* reg, + bool sync = false +); // For 0 <= i < VT: -// index = VT * tid + i. -// if(index < count) reg[i] = data[index]; +// index = VT * tid + i. 
+// if(index < count) reg[i] = data[index]; // No synchronize after load. template -MGPU_DEVICE void DeviceGlobalToThread(int count, InputIt data, int tid, - T* reg); +MGPU_DEVICE void DeviceGlobalToThread( + int count, + InputIt data, + int tid, + T* reg +); template -MGPU_DEVICE void DeviceGlobalToThreadDefault(int count, InputIt data, int tid, - T* reg, T init); +MGPU_DEVICE void DeviceGlobalToThreadDefault( + int count, + InputIt data, + int tid, + T* reg, + T init +); // For 0 <= i < VT: -// index = NT * i + tid; -// if(index < count) data[index] = reg[i]; +// index = NT * i + tid; +// if(index < count) data[index] = reg[i]; // Synchronize after load. template -MGPU_DEVICE void DeviceRegToShared(const T* reg, int tid, OutputIt dest, - bool sync = true); +MGPU_DEVICE void DeviceRegToShared( + const T* reg, + int tid, + OutputIt dest, + bool sync = true +); // For 0 <= i < VT: -// index = NT * i + tid; -// if(index < count) data[index] = reg[i]; +// index = NT * i + tid; +// if(index < count) data[index] = reg[i]; // No synchronize after load. template -MGPU_DEVICE void DeviceRegToGlobal(int count, const T* reg, int tid, - OutputIt dest, bool sync = false); +MGPU_DEVICE void DeviceRegToGlobal( + int count, + const T* reg, + int tid, + OutputIt dest, + bool sync = false +); // For 0 <= index < count: -// dest[index] = source[index]; +// dest[index] = source[index]; // This function is intended to replace DeviceGlobalToShared in cases where // count is much less than NT * VT. template -MGPU_DEVICE void DeviceMemToMemLoop(int count, InputIt source, int tid, - OutputIt dest, bool sync = true); +MGPU_DEVICE void DeviceMemToMemLoop( + int count, + InputIt source, + int tid, + OutputIt dest, + bool sync = true +); // For 0 <= index < count: -// dest[index] = source[index]; +// dest[index] = source[index]; // Synchronize after store. 
template -MGPU_DEVICE void DeviceSharedToGlobal(int count, const T* source, int tid, - OutputIt dest, bool sync = true); +MGPU_DEVICE void DeviceSharedToGlobal( + int count, + const T* source, + int tid, + OutputIt dest, + bool sync = true +); // For 0 <= index < count: -// dest[index] = source[index]; +// dest[index] = source[index]; // Synchronize after store. template -MGPU_DEVICE void DeviceGlobalToShared(int count, InputIt source, int tid, - T* dest, bool sync = true); +MGPU_DEVICE void DeviceGlobalToShared( + int count, + InputIt source, + int tid, + T* dest, + bool sync = true +); template -MGPU_DEVICE void DeviceGlobalToShared2(int count, InputIt source, int tid, - T* dest, bool sync = true); +MGPU_DEVICE void DeviceGlobalToShared2( + int count, + InputIt source, + int tid, + T* dest, + bool sync = true +); // For 0 <= index < count: -// dest[index] = source[index]; +// dest[index] = source[index]; // Synchronize after store. // No optimized code path for count < NV (smaller generated code). template -MGPU_DEVICE void DeviceGlobalToSharedLoop(int count, InputIt source, int tid, - T* dest, bool sync = true); +MGPU_DEVICE void DeviceGlobalToSharedLoop( + int count, + InputIt source, + int tid, + T* dest, + bool sync = true +); template -MGPU_DEVICE void DeviceGlobalToSharedDefault(int count, InputIt source, int tid, - T* dest, T init, bool sync = true); +MGPU_DEVICE void DeviceGlobalToSharedDefault( + int count, + InputIt source, + int tid, + T* dest, + T init, + bool sync = true +); template -MGPU_DEVICE void DeviceGlobalToSharedDefault2(int count, InputIt source, - int tid, T* dest, T init, bool sync = true); +MGPU_DEVICE void DeviceGlobalToSharedDefault2( + int count, + InputIt source, + int tid, + T* dest, + T init, + bool sync = true +); // For 0 <= index < count: -// dest[index] = source[index]; +// dest[index] = source[index]; // No synchronize. 
template -MGPU_DEVICE void DeviceGlobalToGlobal(int count, InputIt source, int tid, - OutputIt dest, bool sync = false); +MGPU_DEVICE void DeviceGlobalToGlobal( + int count, + InputIt source, + int tid, + OutputIt dest, + bool sync = false +); // Transponse VT elements in NT threads (x) into thread-order registers (y) // using only NT * VT / 2 elements of shared memory. @@ -171,119 +261,204 @@ template MGPU_DEVICE void HalfSmemTranspose(const T* x, int tid, T* shared, T* y); // For 0 <= i < VT: -// index = NT * i + tid; -// if(index < count) -// gather = indices[index]; -// reg[i] = data[gather]; +// index = NT * i + tid; +// if(index < count) +// gather = indices[index]; +// reg[i] = data[gather]; // Synchronize after load. template -MGPU_DEVICE void DeviceGather(int count, InputIt data, int indices[VT], - int tid, T* reg, bool sync = true); +MGPU_DEVICE void DeviceGather( + int count, + InputIt data, + int indices[VT], + int tid, + T* reg, + bool sync = true +); template -MGPU_DEVICE void DeviceGatherDefault(int count, InputIt data, int indices[VT], - int tid, T* reg, T identity, bool sync = true); +MGPU_DEVICE void DeviceGatherDefault( + int count, + InputIt data, + int indices[VT], + int tid, + T* reg, + T identity, + bool sync = true +); // For 0 <= i < VT: -// index = NT * i + tid; -// if(index < count) -// scatter = indices[index]; -// data[scatter] = reg[i]; +// index = NT * i + tid; +// if(index < count) +// scatter = indices[index]; +// data[scatter] = reg[i]; // Synchronize after store. template -MGPU_DEVICE void DeviceScatter(int count, const T* reg, int tid, - int indices[VT], OutputIt data, bool sync = true); +MGPU_DEVICE void DeviceScatter( + int count, + const T* reg, + int tid, + int indices[VT], + OutputIt data, + bool sync = true +); // For 0 <= i < VT: -// shared[VT * tid + i] = threadReg[i]; +// shared[VT * tid + i] = threadReg[i]; // Synchronize after store. // Note this function moves data in THREAD ORDER. 
// (DeviceRegToShared moves data in STRIDED ORDER). template -MGPU_DEVICE void DeviceThreadToShared(const T* threadReg, int tid, T* shared, - bool sync = true); +MGPU_DEVICE void DeviceThreadToShared( + const T* threadReg, + int tid, + T* shared, + bool sync = true +); // For 0 <= i < VT: -// threadReg[i] = shared[VT * tid + i]; +// threadReg[i] = shared[VT * tid + i]; // Synchronize after load. // Note this function moves data in THREAD ORDER. // (DeviceSharedToReg moves data in STRIDED ORDER). template -MGPU_DEVICE void DeviceSharedToThread(const T* shared, int tid, T* threadReg, - bool sync = true); +MGPU_DEVICE void DeviceSharedToThread( + const T* shared, + int tid, + T* threadReg, + bool sync = true +); // For 0 <= index < aCount: -// shared[index] = a_global[index]; +// shared[index] = a_global[index]; // For 0 <= index < bCount: -// shared[aCount + index] = b_global[index]; +// shared[aCount + index] = b_global[index]; // VT0 is the lower-bound for predication-free execution: -// If count >= NT * VT0, a predication-free branch is taken. +// If count >= NT * VT0, a predication-free branch is taken. // VT1 is the upper-bound for loads: -// NT * VT1 must >= aCount + bCount. +// NT * VT1 must >= aCount + bCount. 
template -MGPU_DEVICE void DeviceLoad2ToReg(const T* a_global, int aCount, - const T* b_global, int bCount, int tid, T* reg, bool sync = false); +MGPU_DEVICE void DeviceLoad2ToReg( + const T* a_global, + int aCount, + const T* b_global, + int bCount, + int tid, + T* reg, + bool sync = false +); template -MGPU_DEVICE void DeviceLoad2ToShared(const T* a_global, int aCount, - const T* b_global, int bCount, int tid, T* shared, bool sync = true); +MGPU_DEVICE void DeviceLoad2ToShared( + const T* a_global, + int aCount, + const T* b_global, + int bCount, + int tid, + T* shared, + bool sync = true +); template -MGPU_DEVICE void DeviceLoad2ToReg(InputIt1 a_global, int aCount, - InputIt2 b_global, int bCount, int tid, T* reg, bool sync = false); + typename T> +MGPU_DEVICE void DeviceLoad2ToReg( + InputIt1 a_global, + int aCount, + InputIt2 b_global, + int bCount, + int tid, + T* reg, + bool sync = false +); template -MGPU_DEVICE void DeviceLoad2ToShared(InputIt1 a_global, int aCount, - InputIt2 b_global, int bCount, int tid, T* shared, bool sync = true); + typename T> +MGPU_DEVICE void DeviceLoad2ToShared( + InputIt1 a_global, + int aCount, + InputIt2 b_global, + int bCount, + int tid, + T* shared, + bool sync = true +); // For 0 <= i < VT -// index = NT * i + tid; -// if(index < count) -// gather = indices_shared[index]; -// dest_global[index] = data_global[gather]; +// index = NT * i + tid; +// if(index < count) +// gather = indices_shared[index]; +// dest_global[index] = data_global[gather]; // Synchronize after load. 
template -MGPU_DEVICE void DeviceGatherGlobalToGlobal(int count, InputIt data_global, - const int* indices_shared, int tid, OutputIt dest_global, - bool sync = true); +MGPU_DEVICE void DeviceGatherGlobalToGlobal( + int count, + InputIt data_global, + const int* indices_shared, + int tid, + OutputIt dest_global, + bool sync = true +); // For 0 <= i < VT -// index = NT * i + tid -// if(index < count) -// gather = indices[index]; -// if(gather < aCount) data = a_global[gather]; -// else data = b_global[gather - aCount]; -// dest_global[index] = data; +// index = NT * i + tid +// if(index < count) +// gather = indices[index]; +// if(gather < aCount) data = a_global[gather]; +// else data = b_global[gather - aCount]; +// dest_global[index] = data; // Synchronize after load. template -MGPU_DEVICE void DeviceTransferMergeValuesReg(int count, InputIt1 a_global, - InputIt2 b_global, int bStart, const int* indices, int tid, - T* reg, bool sync = false); + typename T> +MGPU_DEVICE void DeviceTransferMergeValuesReg( + int count, + InputIt1 a_global, + InputIt2 b_global, + int bStart, + const int* indices, + int tid, + T* reg, + bool sync = false +); template -MGPU_DEVICE void DeviceTransferMergeValuesShared(int count, InputIt1 a_global, - InputIt2 b_global, int bStart, const int* indices_shared, int tid, - OutputIt dest_global, bool sync = true); + typename OutputIt> +MGPU_DEVICE void DeviceTransferMergeValuesShared( + int count, + InputIt1 a_global, + InputIt2 b_global, + int bStart, + const int* indices_shared, + int tid, + OutputIt dest_global, + bool sync = true +); template -MGPU_DEVICE void DeviceTransferMergeValuesReg(int count, const T* a_global, - const T* b_global, int bStart, const int* indices, int tid, - T* reg, bool sync = false); +MGPU_DEVICE void DeviceTransferMergeValuesReg( + int count, + const T* a_global, + const T* b_global, + int bStart, + const int* indices, + int tid, + T* reg, + bool sync = false +); template -MGPU_DEVICE void 
DeviceTransferMergeValuesShared(int count, const T* a_global, - const T* b_global, int bStart, const int* indices_shared, int tid, - OutputIt dest_global, bool sync = true); - - +MGPU_DEVICE void DeviceTransferMergeValuesShared( + int count, + const T* a_global, + const T* b_global, + int bStart, + const int* indices_shared, + int tid, + OutputIt dest_global, + bool sync = true +); } // namespace mgpu - #include "device/loadstore.cuh" #include "device/ctasegscan.cuh" diff --git a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/util/static.h b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/util/static.h index 5aa1f37..20e705e 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/util/static.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/contrib/moderngpu/include/util/static.h @@ -66,7 +66,6 @@ namespace mgpu { - typedef unsigned char byte; typedef unsigned int uint; @@ -160,7 +159,6 @@ struct sAbs { enum {value = (X >= 0) ? X : -X}; }; - // Finds the number of powers of 2 in the prime factorization of X. template struct sNumFactorsOf2 { diff --git a/flashlight/pkg/speech/third_party/warpctc/include/ctc.h b/flashlight/pkg/speech/third_party/warpctc/include/ctc.h index d45ae5d..566ff24 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/ctc.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/ctc.h @@ -105,7 +105,6 @@ ctcStatus_t compute_ctc_loss( ctcOptions options ); - /** For a given set of labels and minibatch size return the required workspace * size. This will need to be allocated in the same memory space as your * probabilities. 
diff --git a/flashlight/pkg/speech/third_party/warpctc/include/detail/cpu_ctc.h b/flashlight/pkg/speech/third_party/warpctc/include/detail/cpu_ctc.h index 2b5b225..a77ab11 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/detail/cpu_ctc.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/detail/cpu_ctc.h @@ -12,7 +12,6 @@ #include "ctc_helper.h" - template class CpuCTC { public: @@ -50,7 +49,6 @@ class CpuCTC { const int* const input_lengths ); - ctcStatus_t score_forward( const ProbT* const activations, ProbT* costs, @@ -570,7 +568,6 @@ ctcStatus_t CpuCTC::score_forward( bytes_used + mb * per_minibatch_bytes, blank_label_, flat_labels + std::accumulate(label_lengths, label_lengths + mb, 0)); - if(L + ctcm.repeats > T) costs[mb] = ProbT(0); else diff --git a/flashlight/pkg/speech/third_party/warpctc/include/detail/ctc_helper.h b/flashlight/pkg/speech/third_party/warpctc/include/detail/ctc_helper.h index 5b19371..670d80f 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/detail/ctc_helper.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/detail/ctc_helper.h @@ -14,9 +14,7 @@ template HOSTDEVICE T neg_inf() { return -T(INFINITY); } -inline int div_up(int x, int y) { - return (x + y - 1) / y; -} +inline int div_up(int x, int y) { return (x + y - 1) / y; } template struct maximum { diff --git a/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc.h b/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc.h index 5d8b25b..b4e9135 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc.h @@ -84,7 +84,6 @@ class GpuCTC { bool compute_betas_and_grad ); - int out_dim_; // Number of characters plus blank int minibatch_; @@ -123,13 +122,11 @@ ctcStatus_t GpuCTC::setup_gpu_metadata( + gpu_bytes_used); gpu_bytes_used += minibatch_ * sizeof(ProbT); - nll_backward_ = reinterpret_cast(static_cast(gpu_workspace_) + 
gpu_bytes_used); gpu_bytes_used += minibatch_ * sizeof(ProbT); - repeats_ = reinterpret_cast(static_cast(gpu_workspace_) + gpu_bytes_used); @@ -140,7 +137,6 @@ ctcStatus_t GpuCTC::setup_gpu_metadata( + gpu_bytes_used); gpu_bytes_used += minibatch_ * sizeof(int); - // This is the max of all S and T for all valid examples in the minibatch. // A valid example is one for which L + repeats <= T S_ = 0; @@ -200,7 +196,6 @@ ctcStatus_t GpuCTC::setup_gpu_metadata( if(cuda_status != cudaSuccess) return CTC_STATUS_MEMOPS_FAILED; - cuda_status = cudaMemcpyAsync( &(label_offsets_[start_idx]), label_offsets, @@ -271,7 +266,6 @@ ctcStatus_t GpuCTC::setup_gpu_metadata( + gpu_bytes_used); gpu_bytes_used += (S_ * T_) * minibatch_ * sizeof(ProbT); - denoms_ = reinterpret_cast(static_cast(gpu_workspace_) + gpu_bytes_used); @@ -308,7 +302,6 @@ ctcStatus_t GpuCTC::launch_alpha_beta_kernels( labels_with_blanks_, alphas_, nll_forward_, stride, out_dim_, S_, T_, blank_label_); - if(compute_beta) { compute_betas_and_grad_kernel<< < grid_size, NT, 0, stream_ >> > (probs, label_sizes_, utt_length_, repeats_, diff --git a/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc_kernels.h b/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc_kernels.h index 689384d..29406f6 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc_kernels.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc_kernels.h @@ -5,7 +5,6 @@ #include "ctc_helper.h" - template struct CTASegReduce { diff --git a/flashlight/pkg/speech/third_party/warpctc/src/ctc_entrypoint.cu b/flashlight/pkg/speech/third_party/warpctc/src/ctc_entrypoint.cu index 3962916..e4567c3 100644 --- a/flashlight/pkg/speech/third_party/warpctc/src/ctc_entrypoint.cu +++ b/flashlight/pkg/speech/third_party/warpctc/src/ctc_entrypoint.cu @@ -9,12 +9,9 @@ #include "detail/gpu_ctc.h" #endif - extern "C" { -int get_warpctc_version() { - return 2; -} +int get_warpctc_version() { return 2; } const 
char* ctcGetStatusString(ctcStatus_t status) { switch(status) { @@ -36,7 +33,6 @@ const char* ctcGetStatusString(ctcStatus_t status) { } - ctcStatus_t compute_ctc_loss( const float* const activations, float* gradients, @@ -61,7 +57,7 @@ ctcStatus_t compute_ctc_loss( return CTC_STATUS_INVALID_VALUE; if(options.loc == CTC_CPU) { - CpuCTC < float > ctc( + CpuCTC ctc( alphabet_size, minibatch, workspace, @@ -88,7 +84,7 @@ ctcStatus_t compute_ctc_loss( ); } else if(options.loc == CTC_GPU) { #ifdef __CUDACC__ - GpuCTC < float > ctc( + GpuCTC ctc( alphabet_size, minibatch, workspace, @@ -122,7 +118,6 @@ ctcStatus_t compute_ctc_loss( return CTC_STATUS_INVALID_VALUE; } - ctcStatus_t get_workspace_size( const int* const label_lengths, const int* const input_lengths, diff --git a/flashlight/pkg/speech/third_party/warpctc/src/reduce.cu b/flashlight/pkg/speech/third_party/warpctc/src/reduce.cu index 9b3deec..61ce49a 100644 --- a/flashlight/pkg/speech/third_party/warpctc/src/reduce.cu +++ b/flashlight/pkg/speech/third_party/warpctc/src/reduce.cu @@ -15,11 +15,13 @@ const int warp_size = 32; -template < int NT, typename T, typename Rop -> struct CTAReduce; +template +struct CTAReduce; -template < int NT, typename T, typename Rop -> struct CTAReduce { +template +struct CTAReduce { enum {Size = NT, Capacity = NT}; struct Storage { T shared[Capacity]; @@ -51,8 +53,9 @@ template < int NT, typename T, typename Rop } }; -template < int NT, typename Iop, typename Rop, typename T -> __global__ void reduce_rows( +template +__global__ void reduce_rows( Iop f, Rop g, const T* input, @@ -61,7 +64,7 @@ template < int NT, typename Iop, typename Rop, typename T int num_cols ) { - typedef CTAReduce < NT, T, Rop > R; + typedef CTAReduce R; __shared__ typename R::Storage storage; int tid = threadIdx.x; @@ -74,7 +77,6 @@ template < int NT, typename Iop, typename Rop, typename T curr = f(input[idx + col * num_rows]); idx += NT; - while(idx < num_rows) { curr = g(curr, f(input[idx + col * 
num_rows])); idx += NT; @@ -88,8 +90,9 @@ template < int NT, typename Iop, typename Rop, typename T output[col] = curr; } -template < int NT, typename Iop, typename Rop, typename T -> __global__ void reduce_cols( +template +__global__ void reduce_cols( Iop f, Rop g, const T* input, @@ -127,8 +130,9 @@ template < int NT, typename Iop, typename Rop, typename T struct ReduceHelper { - template < typename T, typename Iof, typename Rof - > static void impl( + template + static void impl( Iof f, Rof g, const T* input, @@ -143,22 +147,22 @@ struct ReduceHelper { if(axis) { grid_size = num_cols; - reduce_rows < 128 > << < grid_size, 128, 0, stream >> + reduce_rows<128> << < grid_size, 128, 0, stream >> > (f, g, input, output, num_rows, num_cols); } else { dim3 tpb(warp_size, 128 / warp_size); grid_size = (num_cols + warp_size - 1) / warp_size; - reduce_cols < 128 > << < grid_size, tpb, 0, stream >> + reduce_cols<128> << < grid_size, tpb, 0, stream >> > (f, g, input, output, num_rows, num_cols); } } }; - -template < typename T, typename Iof, typename Rof -> ctcStatus_t reduce( +template +ctcStatus_t reduce( Iof f, Rof g, const T* input, @@ -186,8 +190,8 @@ ctcStatus_t reduce_negate( cudaStream_t stream ) { return reduce( - ctc_helper::negate < float > (), - ctc_helper::add < float > (), + ctc_helper::negate(), + ctc_helper::add(), input, output, rows, @@ -206,8 +210,8 @@ ctcStatus_t reduce_exp( cudaStream_t stream ) { return reduce( - ctc_helper::exponential < float > (), - ctc_helper::add < float > (), + ctc_helper::exponential(), + ctc_helper::add(), input, output, rows, @@ -226,8 +230,8 @@ ctcStatus_t reduce_max( cudaStream_t stream ) { return reduce( - ctc_helper::identity < float > (), - ctc_helper::maximum < float > (), + ctc_helper::identity(), + ctc_helper::maximum(), input, output, rows, diff --git a/flashlight/pkg/vision/common/BetaDistribution.h b/flashlight/pkg/vision/common/BetaDistribution.h index 416b056..58f9263 100644 --- 
a/flashlight/pkg/vision/common/BetaDistribution.h +++ b/flashlight/pkg/vision/common/BetaDistribution.h @@ -75,9 +75,7 @@ namespace lib { } template - result_type operator()(URNG& engine) { - return generate(engine, a_gamma, b_gamma); - } + result_type operator()(URNG& engine) { return generate(engine, a_gamma, b_gamma); } template result_type operator()(URNG& engine, const param_type& param) { diff --git a/flashlight/pkg/vision/criterion/Hungarian.cpp b/flashlight/pkg/vision/criterion/Hungarian.cpp index 9ff81d4..d281604 100644 --- a/flashlight/pkg/vision/criterion/Hungarian.cpp +++ b/flashlight/pkg/vision/criterion/Hungarian.cpp @@ -82,8 +82,8 @@ std::pair HungarianMatcher::matchBatch( predBoxes, targetBoxes, [](const Tensor& x, const Tensor& y) { - return fl::sum(fl::abs(x - y), {0}, /* keepDims = */ true); - } + return fl::sum(fl::abs(x - y), {0}, /* keepDims = */ true); + } ); costBbox = flatten(costBbox, 0, 1); diff --git a/flashlight/pkg/vision/criterion/HungarianImpl.h b/flashlight/pkg/vision/criterion/HungarianImpl.h index 9e687ac..9ca35db 100644 --- a/flashlight/pkg/vision/criterion/HungarianImpl.h +++ b/flashlight/pkg/vision/criterion/HungarianImpl.h @@ -23,7 +23,6 @@ namespace lib { */ void hungarian(float* costs, int* rowIdxs, int* colIdxs, int M, int N); - /* * Same as above except it will output an M X N assignment matrix where * assignments[m][n] == 1 means m and n are assigned. 
diff --git a/flashlight/pkg/vision/criterion/SetCriterion.cpp b/flashlight/pkg/vision/criterion/SetCriterion.cpp index 5ceef0d..1451680 100644 --- a/flashlight/pkg/vision/criterion/SetCriterion.cpp +++ b/flashlight/pkg/vision/criterion/SetCriterion.cpp @@ -29,8 +29,7 @@ Tensor span(const Shape& inDims, const int index) { return fl::iota(dims); } -Shape calcStrides(const Shape& dims) { - return {1, dims[0], dims[0] * dims[1], dims[0] * dims[1] * dims[2]}; +Shape calcStrides(const Shape& dims) { return {1, dims[0], dims[0] * dims[1], dims[0] * dims[1] * dims[2]}; }; Shape calcOutDims(const std::vector& coords) { @@ -299,9 +298,7 @@ SetCriterion::LossDict SetCriterion::lossLabels( return {{"lossCe", lossCe.astype(predLogits.type())}}; } -std::unordered_map SetCriterion::getWeightDict() { - return weightDict_; -} +std::unordered_map SetCriterion::getWeightDict() { return weightDict_; } std::pair SetCriterion::getTgtPermutationIdx( const std::vector>& indices diff --git a/flashlight/pkg/vision/dataset/BoxUtils.cpp b/flashlight/pkg/vision/dataset/BoxUtils.cpp index 39e1e76..e409c50 100644 --- a/flashlight/pkg/vision/dataset/BoxUtils.cpp +++ b/flashlight/pkg/vision/dataset/BoxUtils.cpp @@ -246,8 +246,10 @@ Tensor generalizedBoxIou(const Tensor& bboxes1, const Tensor& bboxes2) { return iou - (area - uni) / area; } -Variable l1Loss(const Variable& input, const Variable& target) { - return flatten(fl::sum(fl::abs(input - target), {0}), 0, 1); +Variable l1Loss( + const Variable& input, + const Variable& target +) { return flatten(fl::sum(fl::abs(input - target), {0}), 0, 1); } } // namespace fl diff --git a/flashlight/pkg/vision/dataset/Coco.cpp b/flashlight/pkg/vision/dataset/Coco.cpp index 6741d81..c2991ce 100644 --- a/flashlight/pkg/vision/dataset/Coco.cpp +++ b/flashlight/pkg/vision/dataset/Coco.cpp @@ -209,8 +209,6 @@ int64_t CocoDataset::size() const { return batched_->size(); } -CocoData CocoDataset::get(const uint64_t idx) { - return batched_->get(idx); -} +CocoData 
CocoDataset::get(const uint64_t idx) { return batched_->get(idx); } } // namespace fl diff --git a/flashlight/pkg/vision/dataset/Coco.h b/flashlight/pkg/vision/dataset/Coco.h index c1be7e4..030083c 100644 --- a/flashlight/pkg/vision/dataset/Coco.h +++ b/flashlight/pkg/vision/dataset/Coco.h @@ -53,13 +53,9 @@ namespace pkg { using iterator = detail::DatasetIterator; - iterator begin() { - return iterator(this); - } + iterator begin() { return iterator(this); } - iterator end() { - return iterator(); - } + iterator end() { return iterator(); } int64_t size() const; diff --git a/flashlight/pkg/vision/dataset/CocoTransforms.cpp b/flashlight/pkg/vision/dataset/CocoTransforms.cpp index ce061b3..d6f7ee3 100644 --- a/flashlight/pkg/vision/dataset/CocoTransforms.cpp +++ b/flashlight/pkg/vision/dataset/CocoTransforms.cpp @@ -17,9 +17,7 @@ namespace { -int randomInt(int min, int max) { - return std::rand() % (max - min + 1) + min; -} +int randomInt(int min, int max) { return std::rand() % (max - min + 1) + min; } } // namespace namespace fl::pkg::vision { @@ -198,14 +196,14 @@ TransformAllFunction Normalize( boxes, in[ClassesIdx]}; return outputs; - }; + }; } TransformAllFunction randomSelect(std::vector fns) { return [fns](const std::vector& in) { TransformAllFunction randomFunc = fns[std::rand() % fns.size()]; return randomFunc(in); - }; + }; }; TransformAllFunction randomSizeCrop(int minSize, int maxSize) { @@ -218,7 +216,7 @@ TransformAllFunction randomSizeCrop(int minSize, int maxSize) { const int x = std::rand() % (w - tw + 1); const int y = std::rand() % (h - th + 1); return crop(in, x, y, tw, th); - }; + }; }; TransformAllFunction randomResize(std::vector sizes, int maxsize) { @@ -240,7 +238,7 @@ TransformAllFunction randomHorizontalFlip(float p) { return hflip(in); else return in; - }; + }; } TransformAllFunction compose(std::vector fns) { @@ -249,7 +247,7 @@ TransformAllFunction compose(std::vector fns) { for(const auto& fn : fns) out = fn(out); return out; - }; + 
}; } } // namespace fl diff --git a/flashlight/pkg/vision/dataset/Transforms.cpp b/flashlight/pkg/vision/dataset/Transforms.cpp index 7201171..cc43f48 100644 --- a/flashlight/pkg/vision/dataset/Transforms.cpp +++ b/flashlight/pkg/vision/dataset/Transforms.cpp @@ -55,12 +55,21 @@ Tensor resizeSmallest(const Tensor& in, const int resize) { return fl::resize(in, {tw, th}, InterpolationMode::Bilinear); } -Tensor resize(const Tensor& in, const int resize) { - return fl::resize(in, {resize, resize}, InterpolationMode::Bilinear); +Tensor resize(const Tensor& in, const int resize) { return fl::resize( + in, + {resize, resize}, + InterpolationMode::Bilinear +); } Tensor crop(const Tensor& in, const int x, const int y, const int w, const int h) { - return in(fl::range(x, x + w), fl::range(y, y + h)); + return in( + fl::range(x, x + w), + fl::range( + y, + y + h + ) + ); } Tensor centerCrop(const Tensor& in, const int size) { @@ -75,20 +84,32 @@ Tensor rotate(const Tensor& input, const float theta, const Tensor& fillImg) { return fl::rotate(input, theta, fillImg); } -Tensor skewX(const Tensor& input, const float theta, const Tensor& fillImg) { - return fl::shear(input, {theta, 0}, {}, fillImg); +Tensor skewX( + const Tensor& input, + const float theta, + const Tensor& fillImg +) { return fl::shear(input, {theta, 0}, {}, fillImg); } -Tensor skewY(const Tensor& input, const float theta, const Tensor& fillImg) { - return fl::shear(input, {0, theta}, {}, fillImg); +Tensor skewY( + const Tensor& input, + const float theta, + const Tensor& fillImg +) { return fl::shear(input, {0, theta}, {}, fillImg); } -Tensor translateX(const Tensor& input, const int shift, const Tensor& fillImg) { - return fl::translate(input, {shift, 0}, {}, fillImg); +Tensor translateX( + const Tensor& input, + const int shift, + const Tensor& fillImg +) { return fl::translate(input, {shift, 0}, {}, fillImg); } -Tensor translateY(const Tensor& input, const int shift, const Tensor& fillImg) { - return 
fl::translate(input, {0, shift}, {}, fillImg); +Tensor translateY( + const Tensor& input, + const int shift, + const Tensor& fillImg +) { return fl::translate(input, {0, shift}, {}, fillImg); } Tensor colorEnhance(const Tensor& input, const float enhance) { @@ -115,13 +136,9 @@ Tensor contrastEnhance(const Tensor& input, const float enhance) { return meanPic + enhance * (input - meanPic); } -Tensor brightnessEnhance(const Tensor& input, const float enhance) { - return input * enhance; -} +Tensor brightnessEnhance(const Tensor& input, const float enhance) { return input * enhance; } -Tensor invert(const Tensor& input) { - return 255. - input; -} +Tensor invert(const Tensor& input) { return 255. - input; } Tensor solarize(const Tensor& input, const float threshold) { auto mask = (input < threshold); @@ -269,8 +286,9 @@ std::pair cutmixBatch( return {inputMixed, targetOneHotMixed}; } -ImageTransform resizeTransform(const uint64_t resize) { - return [resize](const Tensor& in) { return resizeSmallest(in, resize); }; +ImageTransform resizeTransform(const uint64_t resize) { return [resize](const Tensor& in) { + return resizeSmallest(in, resize); + }; } ImageTransform compose(std::vector transformfns) { @@ -279,11 +297,12 @@ ImageTransform compose(std::vector transformfns) { for(const auto& fn : transformfns) out = fn(out); return out; - }; + }; } -ImageTransform centerCropTransform(const int size) { - return [size](const Tensor& in) { return centerCrop(in, size); }; +ImageTransform centerCropTransform(const int size) { return [size](const Tensor& in) { + return centerCrop(in, size); + }; }; ImageTransform randomHorizontalFlipTransform(const float p) { @@ -295,7 +314,7 @@ ImageTransform randomHorizontalFlipTransform(const float p) { out = out(fl::range(w - 1, 1, -1)); } return out; - }; + }; }; ImageTransform randomResizeCropTransform( @@ -326,7 +345,7 @@ ImageTransform randomResizeCropTransform( } } return centerCrop(resizeSmallest(in, size), size); - }; + }; } 
ImageTransform randomResizeTransform(const int low, const int high) { @@ -335,7 +354,7 @@ ImageTransform randomResizeTransform(const int low, const int high) { static_cast(std::rand()) / static_cast(RAND_MAX); const int resize = low + (high - low) * scale; return resizeSmallest(in, resize); - }; + }; }; ImageTransform randomCropTransform(const int tw, const int th) { @@ -350,7 +369,7 @@ ImageTransform randomCropTransform(const int tw, const int th) { const int x = std::rand() % (w - tw + 1); const int y = std::rand() % (h - th + 1); return crop(in, x, y, tw, th); - }; + }; }; ImageTransform normalizeImage( @@ -364,7 +383,7 @@ ImageTransform normalizeImage( out = out - mean; out = out / std; return out; - }; + }; }; ImageTransform randomEraseTransform( @@ -403,7 +422,7 @@ ImageTransform randomEraseTransform( break; } return out; - }; + }; }; ImageTransform randomAugmentationDeitTransform( @@ -500,7 +519,7 @@ ImageTransform randomAugmentationDeitTransform( res = fl::clip(res, 0., 255.).astype(res.type()); } return res; - }; + }; } } // namespace fl diff --git a/flashlight/pkg/vision/models/Detr.cpp b/flashlight/pkg/vision/models/Detr.cpp index 0a0eadf..a7171a1 100644 --- a/flashlight/pkg/vision/models/Detr.cpp +++ b/flashlight/pkg/vision/models/Detr.cpp @@ -13,9 +13,7 @@ namespace { -double calculateGain(double negativeSlope) { - return std::sqrt(2.0 / (1 + std::pow(negativeSlope, 2))); -} +double calculateGain(double negativeSlope) { return std::sqrt(2.0 / (1 + std::pow(negativeSlope, 2))); } std::shared_ptr makeLinear(int inDim, int outDim) { int fanIn = inDim; @@ -102,8 +100,7 @@ std::vector Detr::forward(const std::vector& input) { return forwardTransformer({feature, input[1]}); } -Variable Detr::forwardBackbone(const Variable& input) { - return backbone_->forward({input})[1]; +Variable Detr::forwardBackbone(const Variable& input) { return backbone_->forward({input})[1]; } std::vector Detr::forwardTransformer( @@ -157,8 +154,6 @@ std::vector 
Detr::paramsWithoutBackbone() { return results; } -std::vector Detr::backboneParams() { - return backbone_->params(); -} +std::vector Detr::backboneParams() { return backbone_->params(); } } // namespace fl diff --git a/flashlight/pkg/vision/nn/FrozenBatchNorm.cpp b/flashlight/pkg/vision/nn/FrozenBatchNorm.cpp index 7f1150f..a2d15aa 100644 --- a/flashlight/pkg/vision/nn/FrozenBatchNorm.cpp +++ b/flashlight/pkg/vision/nn/FrozenBatchNorm.cpp @@ -50,13 +50,9 @@ Variable FrozenBatchNorm::forward(const Variable& input) { return (input * fl::tileAs(scale, input)) + fl::tileAs(bias, input); } -void FrozenBatchNorm::setRunningMean(const fl::Variable& x) { - runningMean_ = x; -} +void FrozenBatchNorm::setRunningMean(const fl::Variable& x) { runningMean_ = x; } -void FrozenBatchNorm::setRunningVar(const fl::Variable& x) { - runningVar_ = x; -} +void FrozenBatchNorm::setRunningVar(const fl::Variable& x) { runningVar_ = x; } void FrozenBatchNorm::train() { for(auto& param : params_) diff --git a/flashlight/pkg/vision/nn/PositionalEmbeddingSine.cpp b/flashlight/pkg/vision/nn/PositionalEmbeddingSine.cpp index 0ef694f..c56991c 100644 --- a/flashlight/pkg/vision/nn/PositionalEmbeddingSine.cpp +++ b/flashlight/pkg/vision/nn/PositionalEmbeddingSine.cpp @@ -117,8 +117,6 @@ std::vector PositionalEmbeddingSine::forward( std::vector PositionalEmbeddingSine::operator()( const std::vector& input -) { - return forward(input); -} +) { return forward(input); } } // namespace fl diff --git a/flashlight/pkg/vision/nn/Transformer.cpp b/flashlight/pkg/vision/nn/Transformer.cpp index 5ce4a2f..9ab7c5f 100644 --- a/flashlight/pkg/vision/nn/Transformer.cpp +++ b/flashlight/pkg/vision/nn/Transformer.cpp @@ -9,7 +9,6 @@ #include - using namespace fl; namespace { From 02fcd2c233ee711bb462a45c03a6c1de678a4cc1 Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Mon, 23 Feb 2026 21:41:26 +0100 Subject: [PATCH 20/24] for some reason something else was formatted again --- 
flashlight/pkg/vision/dataset/Transforms.cpp | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/flashlight/pkg/vision/dataset/Transforms.cpp b/flashlight/pkg/vision/dataset/Transforms.cpp index cc43f48..5519158 100644 --- a/flashlight/pkg/vision/dataset/Transforms.cpp +++ b/flashlight/pkg/vision/dataset/Transforms.cpp @@ -55,11 +55,12 @@ Tensor resizeSmallest(const Tensor& in, const int resize) { return fl::resize(in, {tw, th}, InterpolationMode::Bilinear); } -Tensor resize(const Tensor& in, const int resize) { return fl::resize( - in, - {resize, resize}, - InterpolationMode::Bilinear -); +Tensor resize(const Tensor& in, const int resize) { + return fl::resize( + in, + {resize, resize}, + InterpolationMode::Bilinear + ); } Tensor crop(const Tensor& in, const int x, const int y, const int w, const int h) { @@ -286,7 +287,8 @@ std::pair cutmixBatch( return {inputMixed, targetOneHotMixed}; } -ImageTransform resizeTransform(const uint64_t resize) { return [resize](const Tensor& in) { +ImageTransform resizeTransform(const uint64_t resize) { + return [resize](const Tensor& in) { return resizeSmallest(in, resize); }; } @@ -300,7 +302,8 @@ ImageTransform compose(std::vector transformfns) { }; } -ImageTransform centerCropTransform(const int size) { return [size](const Tensor& in) { +ImageTransform centerCropTransform(const int size) { + return [size](const Tensor& in) { return centerCrop(in, size); }; }; From 5606612d293b7d59184ea33aed90dddb9be4c4ac Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Mon, 23 Feb 2026 21:49:38 +0100 Subject: [PATCH 21/24] fixed cout indenting --- flashlight/fl/autograd/Functions.cpp | 2 +- flashlight/fl/autograd/Variable.cpp | 12 +++---- flashlight/fl/common/Histogram.h | 4 +-- flashlight/fl/common/Logging.cpp | 18 +++++----- .../fl/contrib/modules/AdaptiveEmbedding.cpp | 2 +- flashlight/fl/contrib/modules/Conformer.cpp | 14 ++++---- .../fl/contrib/modules/RawWavSpecAugment.cpp | 4 +-- 
flashlight/fl/contrib/modules/Residual.cpp | 2 +- .../modules/SinusoidalPositionEmbedding.cpp | 2 +- flashlight/fl/contrib/modules/TDSBlock.cpp | 2 +- flashlight/fl/contrib/modules/Transformer.cpp | 10 +++--- .../backend/cuda/DistributedBackend.cpp | 2 +- .../fl/examples/DistributedTraining.cpp | 2 +- flashlight/fl/examples/LinearRegression.cpp | 2 +- flashlight/fl/examples/Mnist.cpp | 8 ++--- flashlight/fl/examples/Perceptron.cpp | 2 +- flashlight/fl/examples/RnnClassification.cpp | 24 ++++++------- flashlight/fl/examples/RnnLm.cpp | 10 +++--- flashlight/fl/examples/Xor.cpp | 8 ++--- flashlight/fl/nn/modules/Conv2D.cpp | 2 +- flashlight/fl/nn/modules/Embedding.cpp | 2 +- flashlight/fl/nn/modules/Pool2D.cpp | 2 +- flashlight/fl/runtime/CUDAUtils.cpp | 2 +- flashlight/fl/runtime/Device.cpp | 4 +-- flashlight/fl/tensor/Shape.cpp | 4 +-- .../tensor/backend/af/ArrayFireBinaryOps.cpp | 4 +-- .../backend/af/mem/CachingMemoryManager.cpp | 34 +++++++++--------- .../backend/af/mem/DefaultMemoryManager.cpp | 12 +++---- .../test/autograd/AutogradBinaryOpsTest.cpp | 12 +++---- flashlight/fl/test/autograd/AutogradTest.cpp | 4 +-- .../test/distributed/AllReduceBenchmark.cpp | 8 ++--- .../fl/test/distributed/AllReduceTest.cpp | 4 +-- flashlight/fl/test/tensor/TensorBaseTest.cpp | 29 ++++++++------- .../fl/test/tensor/TensorBinaryOpsTest.cpp | 36 +++++++++---------- .../tensor/af/CachingMemoryManagerTest.cpp | 6 ++-- flashlight/pkg/runtime/Runtime.cpp | 2 +- flashlight/pkg/runtime/amp/DynamicScaler.cpp | 12 +++---- flashlight/pkg/runtime/common/Serializer.h | 4 +-- .../pkg/speech/augmentation/AdditiveNoise.cpp | 6 ++-- .../pkg/speech/augmentation/GaussianNoise.cpp | 2 +- .../pkg/speech/augmentation/Reverberation.cpp | 8 ++--- .../speech/augmentation/SoundEffectConfig.cpp | 4 +-- .../pkg/speech/augmentation/SoxWrapper.cpp | 8 ++--- .../pkg/speech/augmentation/TimeStretch.cpp | 4 +-- flashlight/pkg/speech/common/Flags.cpp | 20 +++++------ 
flashlight/pkg/speech/decoder/PlGenerator.cpp | 4 +-- flashlight/pkg/speech/runtime/Helpers.cpp | 2 +- flashlight/pkg/speech/test/audio/MfccTest.cpp | 2 +- .../speech/test/criterion/BenchmarkASG.cpp | 2 +- .../speech/test/criterion/BenchmarkCTC.cpp | 2 +- .../test/criterion/BenchmarkSeq2Seq.cpp | 4 +-- .../pkg/speech/test/criterion/CompareASG.cpp | 2 +- .../warpctc/include/detail/gpu_ctc.h | 8 ++--- .../speech/third_party/warpctc/src/reduce.cu | 4 +-- flashlight/pkg/text/data/TextDataset.cpp | 8 ++--- flashlight/pkg/vision/dataset/BoxUtils.cpp | 4 +-- flashlight/pkg/vision/models/ViT.cpp | 2 +- .../pkg/vision/nn/VisionTransformer.cpp | 8 ++--- .../pkg/vision/test/TransformerTest.cpp | 6 ++-- .../vision/test/criterion/HungarianTest.cpp | 10 +++--- uncrustify.cfg | 2 +- 61 files changed, 217 insertions(+), 218 deletions(-) diff --git a/flashlight/fl/autograd/Functions.cpp b/flashlight/fl/autograd/Functions.cpp index 25722a8..d597c01 100644 --- a/flashlight/fl/autograd/Functions.cpp +++ b/flashlight/fl/autograd/Functions.cpp @@ -36,7 +36,7 @@ namespace detail { if(rdims[i] % idimsSize != 0) { std::stringstream ss; ss << "Invalid dims for tileAs for input dims " << idims - << " to output dims " << rdims; + << " to output dims " << rdims; throw std::invalid_argument(ss.str()); } dims[i] = rdims[i] / idimsSize; diff --git a/flashlight/fl/autograd/Variable.cpp b/flashlight/fl/autograd/Variable.cpp index cd627b8..bd1fb6a 100644 --- a/flashlight/fl/autograd/Variable.cpp +++ b/flashlight/fl/autograd/Variable.cpp @@ -190,17 +190,17 @@ void Variable::addGrad(const Variable& childGrad) { if(childGrad.type() != this->type()) { std::stringstream ss; ss << "Variable::addGrad: attempted to add child gradient of type " - << childGrad.type() << " to a Variable of type " << this->type() - << ". You might be performing an operation with " - "two inputs of different types."; + << childGrad.type() << " to a Variable of type " << this->type() + << ". 
You might be performing an operation with " + "two inputs of different types."; throw std::invalid_argument(ss.str()); } if(childGrad.shape() != this->shape()) { std::stringstream ss; ss << "Variable::addGrad: given gradient has dimensions not equal " - "to this Variable's dimensions: this variable has shape " - << this->shape() << " whereas the child gradient has dimensions " - << childGrad.shape() << std::endl; + "to this Variable's dimensions: this variable has shape " + << this->shape() << " whereas the child gradient has dimensions " + << childGrad.shape() << std::endl; throw std::invalid_argument(ss.str()); } if(sharedGrad_->grad) diff --git a/flashlight/fl/common/Histogram.h b/flashlight/fl/common/Histogram.h index 4ddbb49..23be43b 100644 --- a/flashlight/fl/common/Histogram.h +++ b/flashlight/fl/common/Histogram.h @@ -208,7 +208,7 @@ std::string HistogramStats::prettyString( ) const { std::stringstream ss; ss << "HistogramStats{" - << " min=["; + << " min=["; fromatValuesIntoStream(ss, min); ss << "] max_=["; fromatValuesIntoStream(ss, max); @@ -230,7 +230,7 @@ std::string HistogramStats::prettyString( countPerTick, fromatCountIntoStream, fromatValuesIntoStream - ); + ); ss << std::endl; } } diff --git a/flashlight/fl/common/Logging.cpp b/flashlight/fl/common/Logging.cpp index fced169..d558886 100644 --- a/flashlight/fl/common/Logging.cpp +++ b/flashlight/fl/common/Logging.cpp @@ -97,8 +97,8 @@ namespace { threadId = threadId.substr(threadId.size() - maxThreadIdNumDigits); (*outputStream) << dateTimeWithMicroSeconds() << ' ' - << threadId << ' ' - << getFileName(fullPath) << ':' << lineNumber << ' '; + << threadId << ' ' + << getFileName(fullPath) << ':' << lineNumber << ' '; } } // namespace @@ -144,8 +144,8 @@ Logging::~Logging() { void Logging::setMaxLoggingLevel(LogLevel maxLoggingLevel) { if(maxLoggingLevel != Logging::maxLoggingLevel_) { std::cerr << "Logging::setMaxLoggingLevel(maxLoggingLevel=" - << logLevelName(maxLoggingLevel) << ") 
Logging::maxLoggingLevel_=" - << logLevelName(Logging::maxLoggingLevel_) << std::endl; + << logLevelName(maxLoggingLevel) << ") Logging::maxLoggingLevel_=" + << logLevelName(Logging::maxLoggingLevel_) << std::endl; Logging::maxLoggingLevel_ = maxLoggingLevel; } } @@ -196,8 +196,8 @@ VerboseLogging::~VerboseLogging() { void VerboseLogging::setMaxLoggingLevel(int maxLoggingLevel) { if(maxLoggingLevel != VerboseLogging::maxLoggingLevel_) { std::cerr << "VerboseLogging::setMaxLoggingLevel(maxLoggingLevel=" - << maxLoggingLevel << ") VerboseLogging::maxLoggingLevel_=" - << VerboseLogging::maxLoggingLevel_ << std::endl; + << maxLoggingLevel << ") VerboseLogging::maxLoggingLevel_=" + << VerboseLogging::maxLoggingLevel_ << std::endl; VerboseLogging::maxLoggingLevel_ = maxLoggingLevel; } } @@ -241,8 +241,8 @@ std::string logLevelName(LogLevel level) { return flLogLevelNames.at(i); std::stringstream ss; ss << "logLevelName(level=" << static_cast(level) - << ") invalid level. Level should be in the range [0.." - << (flLogLevelNames.size() - 1) << "]"; + << ") invalid level. Level should be in the range [0.." + << (flLogLevelNames.size() - 1) << "]"; throw std::invalid_argument(ss.str()); } @@ -252,7 +252,7 @@ LogLevel logLevelValue(const std::string& level) { return flLogLevelValues.at(i); std::stringstream ss; ss << "logLevelValue(level=" << level - << ") invalid level. Level should be INFO, WARNING, ERROR or FATAL"; + << ") invalid level. 
Level should be INFO, WARNING, ERROR or FATAL"; throw std::invalid_argument(ss.str()); } diff --git a/flashlight/fl/contrib/modules/AdaptiveEmbedding.cpp b/flashlight/fl/contrib/modules/AdaptiveEmbedding.cpp index 70ac896..4fe16aa 100644 --- a/flashlight/fl/contrib/modules/AdaptiveEmbedding.cpp +++ b/flashlight/fl/contrib/modules/AdaptiveEmbedding.cpp @@ -115,7 +115,7 @@ std::string AdaptiveEmbedding::prettyString() const { for(int i = 0; i < cutoff_.size() - 1; i++) ss << cutoff_[i] << ", "; ss << cutoff_[cutoff_.size() - 1] << "), " - << "(divValue: " << divValue_ << ")"; + << "(divValue: " << divValue_ << ")"; return ss.str(); } diff --git a/flashlight/fl/contrib/modules/Conformer.cpp b/flashlight/fl/contrib/modules/Conformer.cpp index 017e8d1..6559c08 100644 --- a/flashlight/fl/contrib/modules/Conformer.cpp +++ b/flashlight/fl/contrib/modules/Conformer.cpp @@ -293,13 +293,13 @@ std::unique_ptr Conformer::clone() const { std::string Conformer::prettyString() const { std::ostringstream ss; ss << "Conformer " - << "(modelDim: " << params_[1].dim(1) << "), " - << "(mlpDim: " << params_[1].dim(0) << "), " - << "(nHeads: " << nHeads_ << "), " - << "(pDropout: " << pDropout_ << "), " - << "(pLayerDropout: " << pLayerDropout_ << "), " - << "(posEmbContextSize: " << posEmbContextSize_ << "), " - << "(convKernelSize: " << convKernelSize_ << ") "; + << "(modelDim: " << params_[1].dim(1) << "), " + << "(mlpDim: " << params_[1].dim(0) << "), " + << "(nHeads: " << nHeads_ << "), " + << "(pDropout: " << pDropout_ << "), " + << "(pLayerDropout: " << pLayerDropout_ << "), " + << "(posEmbContextSize: " << posEmbContextSize_ << "), " + << "(convKernelSize: " << convKernelSize_ << ") "; return ss.str(); } diff --git a/flashlight/fl/contrib/modules/RawWavSpecAugment.cpp b/flashlight/fl/contrib/modules/RawWavSpecAugment.cpp index 26be879..99c91b4 100644 --- a/flashlight/fl/contrib/modules/RawWavSpecAugment.cpp +++ b/flashlight/fl/contrib/modules/RawWavSpecAugment.cpp @@ -88,8 +88,8 
@@ void RawWavSpecAugment::precomputeFilters() { int width = 2. / (1e-6 + transBandKhz[fidx]); if(width * 2 + 1 > maxKernelSize_) { FL_LOG(fl::LogLevel::INFO) - << "RawWavSpecAugment raw wave: frequency " << cutoff_[fidx] - << " will be skipped for eval, too large kernel"; + << "RawWavSpecAugment raw wave: frequency " << cutoff_[fidx] + << " will be skipped for eval, too large kernel"; lowPassFilters_.push_back(nullptr); ignoredLowPassFilters_++; continue; diff --git a/flashlight/fl/contrib/modules/Residual.cpp b/flashlight/fl/contrib/modules/Residual.cpp index 5014961..31f800b 100644 --- a/flashlight/fl/contrib/modules/Residual.cpp +++ b/flashlight/fl/contrib/modules/Residual.cpp @@ -163,7 +163,7 @@ std::string Residual::prettyString() const { ss << "output"; if(shortcut.second != -1) ss << " with transformation: " - << modules_[shortcut.second]->prettyString() << ";"; + << modules_[shortcut.second]->prettyString() << ";"; ss << " "; } } diff --git a/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp b/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp index 6a6ac42..22ddc2a 100644 --- a/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp +++ b/flashlight/fl/contrib/modules/SinusoidalPositionEmbedding.cpp @@ -100,7 +100,7 @@ std::unique_ptr SinusoidalPositionEmbedding::clone() const { std::string SinusoidalPositionEmbedding::prettyString() const { std::ostringstream ss; ss << "Sinusoidal Position Embedding Layer (embDim: " << layerDim_ - << "), (input scale " << inputScale_ << ")"; + << "), (input scale " << inputScale_ << ")"; return ss.str(); } diff --git a/flashlight/fl/contrib/modules/TDSBlock.cpp b/flashlight/fl/contrib/modules/TDSBlock.cpp index 791d0de..3720ca9 100644 --- a/flashlight/fl/contrib/modules/TDSBlock.cpp +++ b/flashlight/fl/contrib/modules/TDSBlock.cpp @@ -86,7 +86,7 @@ std::string TDSBlock::prettyString() const { int l2 = linW.dim(0); ss << "Time-Depth Separable Block ("; ss << kw << ", " << w << ", " << c << ") 
[" << l << " -> " << l2 << " -> " - << l << "]"; + << l << "]"; return ss.str(); } diff --git a/flashlight/fl/contrib/modules/Transformer.cpp b/flashlight/fl/contrib/modules/Transformer.cpp index 2ea494a..5eff578 100644 --- a/flashlight/fl/contrib/modules/Transformer.cpp +++ b/flashlight/fl/contrib/modules/Transformer.cpp @@ -220,11 +220,11 @@ std::unique_ptr Transformer::clone() const { std::string Transformer::prettyString() const { std::ostringstream ss; ss << "Transformer (nHeads: " << nHeads_ << "), " - << "(pDropout: " << pDropout_ << "), " - << "(pLayerdrop: " << pLayerdrop_ << "), " - << "(bptt: " << bptt_ << "), " - << "(useMask: " << useMask_ << "), " - << "(preLayerNorm: " << preLN_ << ")"; + << "(pDropout: " << pDropout_ << "), " + << "(pLayerdrop: " << pLayerdrop_ << "), " + << "(bptt: " << bptt_ << "), " + << "(useMask: " << useMask_ << "), " + << "(preLayerNorm: " << preLN_ << ")"; return ss.str(); } diff --git a/flashlight/fl/distributed/backend/cuda/DistributedBackend.cpp b/flashlight/fl/distributed/backend/cuda/DistributedBackend.cpp index df85152..a154f54 100644 --- a/flashlight/fl/distributed/backend/cuda/DistributedBackend.cpp +++ b/flashlight/fl/distributed/backend/cuda/DistributedBackend.cpp @@ -294,7 +294,7 @@ void distributedInit( detail::DistributedInfo::getInstance().backend_ = DistributedBackend::NCCL; if(getWorldRank() == 0) std::cout << "Initialized NCCL " << NCCL_MAJOR << "." << NCCL_MINOR << "." 
- << NCCL_PATCH << " successfully!\n"; + << NCCL_PATCH << " successfully!\n"; } namespace detail { diff --git a/flashlight/fl/examples/DistributedTraining.cpp b/flashlight/fl/examples/DistributedTraining.cpp index 62cf20d..56ae946 100644 --- a/flashlight/fl/examples/DistributedTraining.cpp +++ b/flashlight/fl/examples/DistributedTraining.cpp @@ -105,7 +105,7 @@ int main() { fl::allReduce(mseArr); if(isMaster) std::cout << "Epoch: " << e << " Mean Squared Error: " - << mseArr.scalar() / worldSize << std::endl; + << mseArr.scalar() / worldSize << std::endl; } if(isMaster) std::cout << "[Multi-layer Perceptron] Done!" << std::endl; diff --git a/flashlight/fl/examples/LinearRegression.cpp b/flashlight/fl/examples/LinearRegression.cpp index 43d374f..16f1d90 100644 --- a/flashlight/fl/examples/LinearRegression.cpp +++ b/flashlight/fl/examples/LinearRegression.cpp @@ -58,7 +58,7 @@ int main() { } std::cout << "Epoch: " << e - << " Mean Squared Error: " << error.scalar() << std::endl; + << " Mean Squared Error: " << error.scalar() << std::endl; } std::cout << "[Linear Regression] Done!" 
<< std::endl; diff --git a/flashlight/fl/examples/Mnist.cpp b/flashlight/fl/examples/Mnist.cpp index 366cacb..481e77e 100644 --- a/flashlight/fl/examples/Mnist.cpp +++ b/flashlight/fl/examples/Mnist.cpp @@ -182,9 +182,9 @@ int main(int argc, char** argv) { std::tie(val_loss, val_error) = eval_loop(model, valset); std::cout << "Epoch " << e << std::setprecision(3) - << ": Avg Train Loss: " << train_loss - << " Validation Loss: " << val_loss - << " Validation Error (%): " << val_error << std::endl; + << ": Avg Train Loss: " << train_loss + << " Validation Loss: " << val_loss + << " Validation Error (%): " << val_error << std::endl; } Tensor test_x; @@ -198,7 +198,7 @@ int main(int argc, char** argv) { double test_loss, test_error; std::tie(test_loss, test_error) = eval_loop(model, testset); std::cout << "Test Loss: " << test_loss << " Test Error (%): " << test_error - << std::endl; + << std::endl; return 0; } diff --git a/flashlight/fl/examples/Perceptron.cpp b/flashlight/fl/examples/Perceptron.cpp index 78407dc..affce6c 100644 --- a/flashlight/fl/examples/Perceptron.cpp +++ b/flashlight/fl/examples/Perceptron.cpp @@ -71,7 +71,7 @@ int main() { meter.add(l.scalar()); } std::cout << "Epoch: " << e << " Mean Squared Error: " << meter.value()[0] - << std::endl; + << std::endl; } std::cout << "[Multi-layer Perceptron] Done!" << std::endl; return 0; diff --git a/flashlight/fl/examples/RnnClassification.cpp b/flashlight/fl/examples/RnnClassification.cpp index 2691f70..ed16839 100644 --- a/flashlight/fl/examples/RnnClassification.cpp +++ b/flashlight/fl/examples/RnnClassification.cpp @@ -65,7 +65,7 @@ class ClassificationDataset : public Dataset { } totalExamples += v.size(); std::cout << "Found " << v.size() << " examples for category " << lang - << ". Total: " << totalExamples << std::endl; + << ". 
Total: " << totalExamples << std::endl; datasets[lang] = v; } @@ -161,7 +161,7 @@ class RnnClassifier : public Container { linear_(std::make_shared(hiddenSize, numClasses)), logsoftmax_(0) { std::cout << "Creating a RNN Classifier with vocab size: " << vocabSize - << " and num classes: " << numClasses << std::endl; + << " and num classes: " << numClasses << std::endl; createLayers(); } @@ -254,8 +254,8 @@ class RnnClassifier : public Container { const bool passes = p == expectedLabel; const std::string s = (passes ? "✓ " : "✗ "); std::cout << "input: " << std::setw(20) << input - << "\t expected: " << expectedLabel << "\t prediction: " << p - << "\t" << s << std::endl; + << "\t expected: " << expectedLabel << "\t prediction: " << p + << "\t" << s << std::endl; return passes; } @@ -269,14 +269,14 @@ class RnnClassifier : public Container { int main(int argc, char** argv) { fl::init(); std::cout << "RnnClassification (path to the data dir) (learning rate) (num " - "epochs) (hiddensize)" - << std::endl; + "epochs) (hiddensize)" + << std::endl; std::cout << "Dataset : https://download.pytorch.org/tutorial/data.zip" - << std::endl; + << std::endl; if(argc < 2) { std::cout << "To setup the dataset: " << std::endl; std::cout << "wget https://download.pytorch.org/tutorial/data.zip" - << std::endl; + << std::endl; std::cout << "unzip data.zip" << std::endl; std::cout << "./RnnClassification data/names" << std::endl; return 0; @@ -326,7 +326,7 @@ int main(int argc, char** argv) { double trainLoss = trainLossMeter.value()[0]; std::cout << "Epoch " << e + 1 << std::setprecision(3) - << " - Train Loss: " << trainLoss << std::endl; + << " - Train Loss: " << trainLoss << std::endl; // compute the accuracy confusion matrix: const unsigned nCategories = ClassificationDataset::Label2Id.size(); @@ -346,8 +346,8 @@ int main(int argc, char** argv) { std::cout << "Global accuracy=" << numMatch / nConfusion << "\t "; for(unsigned i = 0; i < nCategories; ++i) std::cout << 
ClassificationDataset::Id2Label[i] << ":" << std::fixed - << std::setprecision(2) << confusion(i, i).scalar() - << " "; + << std::setprecision(2) << confusion(i, i).scalar() + << " "; std::cout << std::endl; } // List of names not in the training dataset @@ -370,7 +370,7 @@ int main(int argc, char** argv) { std::cin >> name; Variable output, h, c; std::cout << ClassificationDataset::Id2Label[model.infer(name, h, c)] - << " ?" << std::endl; + << " ?" << std::endl; } std::cout << "Finished" << std::endl; return 0; diff --git a/flashlight/fl/examples/RnnLm.cpp b/flashlight/fl/examples/RnnLm.cpp index b455b03..9402fd9 100644 --- a/flashlight/fl/examples/RnnLm.cpp +++ b/flashlight/fl/examples/RnnLm.cpp @@ -253,17 +253,17 @@ int main(int argc, char** argv) { double iter_time = timer.value(); std::cout << "Epoch " << e + 1 << std::setprecision(3) - << " - Train Loss: " << train_loss - << " Validation Loss: " << val_loss - << " Validation Perplexity: " << std::exp(val_loss) - << " Time per iteration (ms): " << iter_time * 1000 << std::endl; + << " - Train Loss: " << train_loss + << " Validation Loss: " << val_loss + << " Validation Perplexity: " << std::exp(val_loss) + << " Time per iteration (ms): " << iter_time * 1000 << std::endl; } LMDataset testset(test_dir, batch_size, time_steps, preproc); double test_loss = eval_loop(testset); std::cout << " Test Loss: " << test_loss - << " Test Perplexity: " << std::exp(test_loss) << std::endl; + << " Test Perplexity: " << std::exp(test_loss) << std::endl; return 0; } diff --git a/flashlight/fl/examples/Xor.cpp b/flashlight/fl/examples/Xor.cpp index 3640bc5..d8b2b6d 100644 --- a/flashlight/fl/examples/Xor.cpp +++ b/flashlight/fl/examples/Xor.cpp @@ -99,11 +99,11 @@ int main(int argc, const char** argv) { // TODO: Use loss function Tensor diff = out - result.tensor(); std::cout << "Average Error at iteration (" << i + 1 - << ") : " << fl::mean(fl::abs(diff)).scalar() << "\n"; + << ") : " << fl::mean(fl::abs(diff)).scalar() << "\n"; 
std::cout << "Predicted\n" - << result.tensor() << std::endl - << "Expected\n" - << out << std::endl; + << result.tensor() << std::endl + << "Expected\n" + << out << std::endl; } } return 0; diff --git a/flashlight/fl/nn/modules/Conv2D.cpp b/flashlight/fl/nn/modules/Conv2D.cpp index 02f6bab..83ec842 100644 --- a/flashlight/fl/nn/modules/Conv2D.cpp +++ b/flashlight/fl/nn/modules/Conv2D.cpp @@ -200,7 +200,7 @@ std::string Conv2D::prettyString() const { std::ostringstream ss; ss << "Conv2D"; ss << " (" << nIn_ << "->" << nOut_ << ", " << xFilter_ << "x" << yFilter_ - << ", " << xStride_ << "," << yStride_ << ", "; + << ", " << xStride_ << "," << yStride_ << ", "; if(xPad_ == static_cast(PaddingMode::SAME)) ss << "SAME"; else diff --git a/flashlight/fl/nn/modules/Embedding.cpp b/flashlight/fl/nn/modules/Embedding.cpp index 55c01a4..3972ff8 100644 --- a/flashlight/fl/nn/modules/Embedding.cpp +++ b/flashlight/fl/nn/modules/Embedding.cpp @@ -52,7 +52,7 @@ std::unique_ptr Embedding::clone() const { std::string Embedding::prettyString() const { std::ostringstream ss; ss << "Embedding (embeddings: " << numEmbeddings_ - << ") (dim: " << embeddingDim_ << ")"; + << ") (dim: " << embeddingDim_ << ")"; return ss.str(); } diff --git a/flashlight/fl/nn/modules/Pool2D.cpp b/flashlight/fl/nn/modules/Pool2D.cpp index 5d58d99..87217f8 100644 --- a/flashlight/fl/nn/modules/Pool2D.cpp +++ b/flashlight/fl/nn/modules/Pool2D.cpp @@ -74,7 +74,7 @@ std::string Pool2D::prettyString() const { break; } ss << " (" << xFilter_ << "x" << yFilter_ << ", " << xStride_ << "," - << yStride_ << ", "; + << yStride_ << ", "; if(xPad_ == static_cast(PaddingMode::SAME)) ss << "SAME"; else diff --git a/flashlight/fl/runtime/CUDAUtils.cpp b/flashlight/fl/runtime/CUDAUtils.cpp index 677c9ae..d0c6076 100644 --- a/flashlight/fl/runtime/CUDAUtils.cpp +++ b/flashlight/fl/runtime/CUDAUtils.cpp @@ -37,7 +37,7 @@ namespace detail { if(err != cudaSuccess) { std::ostringstream ess; ess << prefix << '[' << file << ':' 
<< line - << "] CUDA error: " << cudaGetErrorString(err); + << "] CUDA error: " << cudaGetErrorString(err); throw std::runtime_error(ess.str()); } } diff --git a/flashlight/fl/runtime/Device.cpp b/flashlight/fl/runtime/Device.cpp index 7a517df..2149481 100644 --- a/flashlight/fl/runtime/Device.cpp +++ b/flashlight/fl/runtime/Device.cpp @@ -16,8 +16,8 @@ void deviceImplTypeCheck(DeviceType expect, DeviceType actual) { if(expect != actual) { std::ostringstream oss; oss << "[fl::Device::impl] " - << "specified device type: [" << expect << "] " - << "doesn't match actual device type: [" << actual << "]"; + << "specified device type: [" << expect << "] " + << "doesn't match actual device type: [" << actual << "]"; throw std::invalid_argument(oss.str()); } } diff --git a/flashlight/fl/tensor/Shape.cpp b/flashlight/fl/tensor/Shape.cpp index 04b7afb..c0108ec 100644 --- a/flashlight/fl/tensor/Shape.cpp +++ b/flashlight/fl/tensor/Shape.cpp @@ -24,8 +24,8 @@ void Shape::checkDimsOrThrow(const size_t dim) const { if(dim > ndim() - 1) { std::stringstream ss; ss << "Shape index " << std::to_string(dim) - << " out of bounds for shape with " << std::to_string(dims_.size()) - << " dimensions."; + << " out of bounds for shape with " << std::to_string(dims_.size()) + << " dimensions."; throw std::invalid_argument(ss.str()); } } diff --git a/flashlight/fl/tensor/backend/af/ArrayFireBinaryOps.cpp b/flashlight/fl/tensor/backend/af/ArrayFireBinaryOps.cpp index 033fed2..5c3fb57 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireBinaryOps.cpp +++ b/flashlight/fl/tensor/backend/af/ArrayFireBinaryOps.cpp @@ -56,8 +56,8 @@ namespace { else { std::stringstream ss; ss << "doBinaryOpOrBroadcast: cannot perform operation " - "or broadcasting with tensors of shapes " - << lhs.shape() << " and " << rhs.shape() << " - dimension mismatch."; + "or broadcasting with tensors of shapes " + << lhs.shape() << " and " << rhs.shape() << " - dimension mismatch."; throw std::invalid_argument(ss.str()); } } 
diff --git a/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp b/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp index e0b6244..2014b77 100644 --- a/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp +++ b/flashlight/fl/tensor/backend/af/mem/CachingMemoryManager.cpp @@ -86,7 +86,7 @@ namespace { return std::round(mb * kMB); } catch(std::exception& ex) { std::cerr << "getEnvAsBytesFromFloatMb: Invalid environment " - << "variable value: name=" << name << " value=" << env; + << "variable value: name=" << name << " value=" << env; throw ex; } } @@ -298,16 +298,16 @@ void CachingMemoryManager::mallocWithRetry(size_t size, void** ptr) { } catch(std::exception& ex) { // note: af exception inherits from std exception std::cerr << "Failed to allocate memory of size " << formatMemory(size) - << " (Device: " << memInfo.deviceId_ << ", Capacity: " - << formatMemory( + << " (Device: " << memInfo.deviceId_ << ", Capacity: " + << formatMemory( this->deviceInterface->getMaxMemorySize( memInfo.deviceId_ ) - ) - << ", Allocated: " - << formatMemory(memInfo.stats_.allocatedBytes_) - << ", Cached: " << formatMemory(memInfo.stats_.cachedBytes_) - << ") with error '" << ex.what() << "'" << std::endl; + ) + << ", Allocated: " + << formatMemory(memInfo.stats_.allocatedBytes_) + << ", Cached: " << formatMemory(memInfo.stats_.cachedBytes_) + << ") with error '" << ex.what() << "'" << std::endl; // note: converting here an af exception to std exception prevents to // catch the af error code at the user level. Rethrowing. 
throw; @@ -374,16 +374,16 @@ void CachingMemoryManager::printInfo( std::lock_guard lock(memInfo.mutexAll_); ostream << msg << "\nType: CachingMemoryManager" << std::endl - << "\nDevice: " << memInfo.deviceId_ << ", Capacity: " - << formatMemory( + << "\nDevice: " << memInfo.deviceId_ << ", Capacity: " + << formatMemory( this->deviceInterface->getMaxMemorySize(memInfo.deviceId_) - ) - << ", Allocated: " << formatMemory(memInfo.stats_.allocatedBytes_) - << ", Cached: " << formatMemory(memInfo.stats_.cachedBytes_) - << std::endl - << "\nTotal native calls: " << memInfo.stats_.totalNativeMallocs_ - << "(mallocs), " << memInfo.stats_.totalNativeFrees_ << "(frees)" - << std::endl; + ) + << ", Allocated: " << formatMemory(memInfo.stats_.allocatedBytes_) + << ", Cached: " << formatMemory(memInfo.stats_.cachedBytes_) + << std::endl + << "\nTotal native calls: " << memInfo.stats_.totalNativeMallocs_ + << "(mallocs), " << memInfo.stats_.totalNativeFrees_ << "(frees)" + << std::endl; } void CachingMemoryManager::userLock(const void* ptr) { diff --git a/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp b/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp index 6ba6aa6..3a7497c 100644 --- a/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp +++ b/flashlight/fl/tensor/backend/af/mem/DefaultMemoryManager.cpp @@ -65,7 +65,7 @@ void DefaultMemoryManager::cleanDeviceMemoryManager(int device) { std::stringstream ss; ss << "GC: Clearing " << freePtrs.size() << " buffers |" - << std::to_string(bytesFreed) << " bytes"; + << std::to_string(bytesFreed) << " bytes"; this->log(ss.str()); // Free memory outside of the lock @@ -288,9 +288,9 @@ void DefaultMemoryManager::printInfo( const MemoryInfo& current = this->getCurrentMemoryInfo(); ostream << msg << std::endl - << "---------------------------------------------------------\n" - << "| POINTER | SIZE | AF LOCK | USER LOCK |\n" - << "---------------------------------------------------------\n"; + << 
"---------------------------------------------------------\n" + << "| POINTER | SIZE | AF LOCK | USER LOCK |\n" + << "---------------------------------------------------------\n"; std::lock_guard lock(this->memoryMutex); for(auto& kv : current.lockedMap) { @@ -309,7 +309,7 @@ void DefaultMemoryManager::printInfo( } ostream << "| " << kv.first << " | " << size << " " << unit << " | " - << statusMngr << " | " << statusUser << " |\n"; + << statusMngr << " | " << statusUser << " |\n"; } for(auto& kv : current.freeMap) { @@ -325,7 +325,7 @@ void DefaultMemoryManager::printInfo( for(auto& ptr : kv.second) ostream << "| " << ptr << " | " << size << " " << unit << " | " - << statusMngr << " | " << statusUser << " |\n"; + << statusMngr << " | " << statusUser << " |\n"; } ostream << "---------------------------------------------------------\n"; diff --git a/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp b/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp index 26961f6..c6299cf 100644 --- a/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp +++ b/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp @@ -265,26 +265,26 @@ TEST(AutogradBinaryOpsTest, matmul) { // matmul auto funcMatmulLhs = [&](Variable& input) { return matmul(input, b); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulLhs, a, 1E-6)) - << "matmul lhs gradient: lhs " << a.shape() << " rhs " << b.shape(); + << "matmul lhs gradient: lhs " << a.shape() << " rhs " << b.shape(); auto funcMatmulRhs = [&](Variable& input) { return matmul(a, input); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulRhs, b, 1E-6)) - << "matmul rhs gradient: lhs " << a.shape() << " rhs " << b.shape(); + << "matmul rhs gradient: lhs " << a.shape() << " rhs " << b.shape(); // matmulTN auto funcMatmulTNLhs = [&](Variable& input) { return matmulTN(input, b); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulTNLhs, aT, 1E-6)) - << "matmulTN lhs gradient: lhs " << a.shape() << " rhs " << b.shape(); + << "matmulTN lhs 
gradient: lhs " << a.shape() << " rhs " << b.shape(); auto funcMatmulTNRhs = [&](Variable& input) { return matmulTN(aT, input); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulTNRhs, b, 1E-6)) - << "matmulTN rhs gradient: lhs " << a.shape() << " rhs " << b.shape(); + << "matmulTN rhs gradient: lhs " << a.shape() << " rhs " << b.shape(); // matmulNT auto funcMatmulNTLhs = [&](Variable& input) { return matmulNT(input, bT); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulNTLhs, a, 1E-6)) - << "matmulTN lhs gradient: lhs " << a.shape() << " rhs " << b.shape(); + << "matmulTN lhs gradient: lhs " << a.shape() << " rhs " << b.shape(); auto funcMatmulNTRhs = [&](Variable& input) { return matmulNT(a, input); }; ASSERT_TRUE(fl::detail::jacobianTestImpl(funcMatmulNTRhs, bT, 1E-6)) - << "matmulTN rhs gradient: lhs " << a.shape() << " rhs " << b.shape(); + << "matmulTN rhs gradient: lhs " << a.shape() << " rhs " << b.shape(); } } diff --git a/flashlight/fl/test/autograd/AutogradTest.cpp b/flashlight/fl/test/autograd/AutogradTest.cpp index ffee958..9636b39 100644 --- a/flashlight/fl/test/autograd/AutogradTest.cpp +++ b/flashlight/fl/test/autograd/AutogradTest.cpp @@ -465,7 +465,7 @@ TEST(AutogradTest, GetAdvancedIndex) { // TODO: remove me if(!FL_BACKEND_CUDA) GTEST_SKIP() - << "Advanced indexing operator unsupported for non-CUDA backends"; + << "Advanced indexing operator unsupported for non-CUDA backends"; std::vector validIndexTypes = { fl::dtype::s32, fl::dtype::s64, fl::dtype::u32, fl::dtype::u64}; for(const auto& dtype : validIndexTypes) { @@ -494,7 +494,7 @@ TEST(AutogradTest, GetAdvancedIndexF16) { // TODO: remove me if(!FL_BACKEND_CUDA) GTEST_SKIP() - << "Advanced indexing operator unsupported for non-CUDA backends"; + << "Advanced indexing operator unsupported for non-CUDA backends"; if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; std::vector validIndexTypes = { diff --git 
a/flashlight/fl/test/distributed/AllReduceBenchmark.cpp b/flashlight/fl/test/distributed/AllReduceBenchmark.cpp index 0d6fdec..8a672e0 100644 --- a/flashlight/fl/test/distributed/AllReduceBenchmark.cpp +++ b/flashlight/fl/test/distributed/AllReduceBenchmark.cpp @@ -53,10 +53,10 @@ int main() { auto timesAf = Tensor::fromVector({kNumIters}, times); if(wRank == 0) std::cout << "Size: " << size - << " ; avg: " << fl::mean(timesAf).asScalar() * 1000 - << "ms ; p50: " - << fl::median(timesAf).asScalar() * 1000 << "ms" - << std::endl; + << " ; avg: " << fl::mean(timesAf).asScalar() * 1000 + << "ms ; p50: " + << fl::median(timesAf).asScalar() * 1000 << "ms" + << std::endl; curMaxSize = std::max(curMaxSize, size); size *= multiplier; } diff --git a/flashlight/fl/test/distributed/AllReduceTest.cpp b/flashlight/fl/test/distributed/AllReduceTest.cpp index 4f8e849..ff7166c 100644 --- a/flashlight/fl/test/distributed/AllReduceTest.cpp +++ b/flashlight/fl/test/distributed/AllReduceTest.cpp @@ -187,8 +187,8 @@ int main(int argc, char** argv) { } catch(const std::exception& ex) { // Don't run the test if distributed initialization fails std::cerr - << "Distributed initialization failed; tests will be skipped. Reason: " - << ex.what() << std::endl; + << "Distributed initialization failed; tests will be skipped. 
Reason: " + << ex.what() << std::endl; } return RUN_ALL_TESTS(); diff --git a/flashlight/fl/test/tensor/TensorBaseTest.cpp b/flashlight/fl/test/tensor/TensorBaseTest.cpp index 22dd1b9..d6a4e1e 100644 --- a/flashlight/fl/test/tensor/TensorBaseTest.cpp +++ b/flashlight/fl/test/tensor/TensorBaseTest.cpp @@ -495,8 +495,8 @@ void assertScalarBehavior(fl::dtype type) { if(dtype_traits::fl_type != type) { ASSERT_THROW(one.template scalar(), std::invalid_argument) - << "dtype: " << type - << ", ScalarArgType: " << dtype_traits::getName(); + << "dtype: " << type + << ", ScalarArgType: " << dtype_traits::getName(); return; } @@ -505,19 +505,19 @@ void assertScalarBehavior(fl::dtype type) { || (type == fl::dtype::f64) ) ASSERT_FLOAT_EQ(one.template scalar(), scalar) - << "dtype: " << type - << ", ScalarArgType: " << dtype_traits::getName(); + << "dtype: " << type + << ", ScalarArgType: " << dtype_traits::getName(); else ASSERT_EQ(one.template scalar(), scalar) - << "dtype: " << type - << ", ScalarArgType: " << dtype_traits::getName(); + << "dtype: " << type + << ", ScalarArgType: " << dtype_traits::getName(); ScalarArgType val = static_cast(rand()); auto a = fl::full({5, 6}, val, type); ASSERT_TRUE(allClose(fl::full({1}, a.template scalar(), type), a(0, 0))) - << "dtype: " << type - << ", ScalarArgType: " << dtype_traits::getName(); + << "dtype: " << type + << ", ScalarArgType: " << dtype_traits::getName(); } TEST(TensorBaseTest, scalar) { @@ -570,13 +570,12 @@ TEST(TensorBaseTest, stream) { TEST(TensorBaseTest, asContiguousTensor) { auto t = fl::rand({5, 6, 7, 8}); - auto indexed = - t( - fl::range(1, 4, 2), - fl::range(0, 6, 2), - fl::range(0, 6, 3), - fl::range(0, 5, 3) - ); + auto indexed = t( + fl::range(1, 4, 2), + fl::range(0, 6, 2), + fl::range(0, 6, 3), + fl::range(0, 5, 3) + ); auto contiguous = indexed.asContiguousTensor(); std::vector strides; diff --git a/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp b/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp index 
e2d6254..d551fab 100644 --- a/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp +++ b/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp @@ -30,8 +30,8 @@ void assertTensorScalarBinop( auto result = op(in, scalar); auto expect = expectOut.astype(result.type()); ASSERT_TRUE(allClose(result, expect)) - << "in.type(): " << in.type() - << ", ScalarType: " << dtype_traits::getName(); + << "in.type(): " << in.type() + << ", ScalarType: " << dtype_traits::getName(); } template @@ -44,8 +44,8 @@ void assertScalarTensorBinop( auto result = op(scalar, in); auto expect = expectOut.astype(result.type()); ASSERT_TRUE(allClose(result, expect)) - << "ScalarType: " << dtype_traits::getName() - << ", in.type(): " << in.type(); + << "ScalarType: " << dtype_traits::getName() + << ", in.type(): " << in.type(); } template @@ -67,9 +67,9 @@ void assertCommutativeBinop( const Tensor& out ) { ASSERT_TRUE(allClose(op(in1, in2), out)) - << "in1.type(): " << in1.type() << ", in2.type(): " << in2.type(); + << "in1.type(): " << in1.type() << ", in2.type(): " << in2.type(); ASSERT_TRUE(allClose(op(in2, in1), out)) - << "in1.type(): " << in1.type() << ", in2.type(): " << in2.type(); + << "in1.type(): " << in1.type() << ", in2.type(): " << in2.type(); } void applyToAllFpDtypes(std::function func) { @@ -316,34 +316,34 @@ TEST(TensorBinaryOpsTest, BinaryOperatorIncompatibleShapes) { ASSERT_THROW((void) Values(lhs * rhs), std::invalid_argument) << "dtype: " << type; ASSERT_THROW((void) Values(lhs / rhs), std::invalid_argument) << "dtype: " << type; ASSERT_THROW((void) Values(lhs == rhs), std::invalid_argument) - << "dtype: " << type; + << "dtype: " << type; ASSERT_THROW((void) Values(lhs != rhs), std::invalid_argument) - << "dtype: " << type; + << "dtype: " << type; ASSERT_THROW((void) Values(lhs < rhs), std::invalid_argument) << "dtype: " << type; ASSERT_THROW((void) Values(lhs <= rhs), std::invalid_argument) - << "dtype: " << type; + << "dtype: " << type; ASSERT_THROW((void) Values(lhs > rhs), 
std::invalid_argument) << "dtype: " << type; ASSERT_THROW((void) Values(lhs >= rhs), std::invalid_argument) - << "dtype: " << type; + << "dtype: " << type; ASSERT_THROW((void) Values(lhs || rhs), std::invalid_argument) - << "dtype: " << type; + << "dtype: " << type; ASSERT_THROW((void) Values(lhs && rhs), std::invalid_argument) - << "dtype: " << type; + << "dtype: " << type; // TODO ArrayFire needs software impl for fp16 modulo on CUDA backend; // bring this test back when supported. if(type != dtype::f16) ASSERT_THROW((void) Values(lhs % rhs), std::invalid_argument) - << "dtype: " << type; + << "dtype: " << type; // these operators are generally not well-defined for fps if(type != dtype::f16 && type != dtype::f32 && type != dtype::f64) { ASSERT_THROW((void) Values(lhs | rhs), std::invalid_argument) - << "dtype: " << type; + << "dtype: " << type; ASSERT_THROW((void) Values(lhs ^ rhs), std::invalid_argument) - << "dtype: " << type; + << "dtype: " << type; ASSERT_THROW((void) Values(lhs << rhs), std::invalid_argument) - << "dtype: " << type; + << "dtype: " << type; ASSERT_THROW((void) Values(lhs >> rhs), std::invalid_argument) - << "dtype: " << type; + << "dtype: " << type; } }; @@ -474,7 +474,7 @@ TEST(TensorBinaryOpsTest, broadcasting) { std::stringstream ss; ss << "lhs: " << shapeData.lhs << " rhs: " << shapeData.rhs - << " function: " << funcp.second; + << " function: " << funcp.second; auto testData = ss.str(); ASSERT_EQ(actualOut.shape(), expectedShape) << testData; diff --git a/flashlight/fl/test/tensor/af/CachingMemoryManagerTest.cpp b/flashlight/fl/test/tensor/af/CachingMemoryManagerTest.cpp index eb29cb1..f548ad7 100644 --- a/flashlight/fl/test/tensor/af/CachingMemoryManagerTest.cpp +++ b/flashlight/fl/test/tensor/af/CachingMemoryManagerTest.cpp @@ -156,7 +156,7 @@ void testFragmentation( if(b != AF_BACKEND_CUDA) GTEST_SKIP() - << "CachingMemoryManager fragmentation tests require CUDA backend"; + << "CachingMemoryManager fragmentation tests require CUDA 
backend"; const auto mms = deviceInterface_->getMaxMemorySize(0); const auto maxNumf32 = mms / sizeof(float); // AF f32 is supposed to be 32b @@ -178,8 +178,8 @@ void testFragmentation( ASSERT_EQ(ex.err(), AF_ERR_NO_MEM); else EXPECT_TRUE(false) - << "CachingMemoryManagerTest fragmentaiton not supposed to throw: " - << ex.what(); + << "CachingMemoryManagerTest fragmentaiton not supposed to throw: " + << ex.what(); } } diff --git a/flashlight/pkg/runtime/Runtime.cpp b/flashlight/pkg/runtime/Runtime.cpp index dc06381..77e5416 100644 --- a/flashlight/pkg/runtime/Runtime.cpp +++ b/flashlight/pkg/runtime/Runtime.cpp @@ -19,7 +19,7 @@ constexpr size_t kRunFileNameIntWidth = 3; std::string getRunFile(const std::string& name, const int runidx, const fs::path& runpath) { std::stringstream ss; ss << std::setw(kRunFileNameIntWidth) << std::setfill('0') << runidx << "_" - << name; + << name; return runpath / ss.str(); }; diff --git a/flashlight/pkg/runtime/amp/DynamicScaler.cpp b/flashlight/pkg/runtime/amp/DynamicScaler.cpp index dc57d57..dcafee1 100644 --- a/flashlight/pkg/runtime/amp/DynamicScaler.cpp +++ b/flashlight/pkg/runtime/amp/DynamicScaler.cpp @@ -36,14 +36,14 @@ bool DynamicScaler::unscale(std::vector& params) { if(scaleFactor_ >= fl::kAmpMinimumScaleFactorValue) { scaleFactor_ = scaleFactor_ / 2.0f; FL_LOG(LogLevel::INFO) - << "AMP: Scale factor decreased. New value:\t" << scaleFactor_; + << "AMP: Scale factor decreased. New value:\t" << scaleFactor_; } else FL_LOG(LogLevel::FATAL) - << "Minimum loss scale reached: " << fl::kAmpMinimumScaleFactorValue - << " with over/underflowing gradients. Lowering the " - << "learning rate, using gradient clipping, or " - << "increasing the batch size can help resolve " - << "loss explosion."; + << "Minimum loss scale reached: " << fl::kAmpMinimumScaleFactorValue + << " with over/underflowing gradients. 
Lowering the " + << "learning rate, using gradient clipping, or " + << "increasing the batch size can help resolve " + << "loss explosion."; successCounter_ = 0; return false; } diff --git a/flashlight/pkg/runtime/common/Serializer.h b/flashlight/pkg/runtime/common/Serializer.h index 759922e..583b0d4 100644 --- a/flashlight/pkg/runtime/common/Serializer.h +++ b/flashlight/pkg/runtime/common/Serializer.h @@ -67,7 +67,7 @@ namespace pkg { ar(args...); } catch(const std::exception& ex) { FL_LOG(fl::LogLevel::ERROR) - << "Error while saving \"" << filepath << "\": " << ex.what() << "\n"; + << "Error while saving \"" << filepath << "\": " << ex.what() << "\n"; throw; } } @@ -84,7 +84,7 @@ namespace pkg { ar(args...); } catch(const std::exception& ex) { FL_LOG(fl::LogLevel::ERROR) << "Error while loading \"" << filepath - << "\": " << ex.what() << "\n"; + << "\": " << ex.what() << "\n"; throw; } } diff --git a/flashlight/pkg/speech/augmentation/AdditiveNoise.cpp b/flashlight/pkg/speech/augmentation/AdditiveNoise.cpp index bf1cb4c..8b5b566 100644 --- a/flashlight/pkg/speech/augmentation/AdditiveNoise.cpp +++ b/flashlight/pkg/speech/augmentation/AdditiveNoise.cpp @@ -21,8 +21,8 @@ namespace fl::pkg::speech::sfx { std::string AdditiveNoise::Config::prettyString() const { std::stringstream ss; ss << "AdditiveNoise::Config{ratio_=" << ratio_ << " minSnr_=" << minSnr_ - << " maxSnr_=" << maxSnr_ << " nClipsMin_=" << nClipsMin_ << " nClipsMax_" - << nClipsMax_ << " listFilePath_=" << listFilePath_ << '}'; + << " maxSnr_=" << maxSnr_ << " nClipsMin_=" << nClipsMin_ << " nClipsMax_" + << nClipsMax_ << " listFilePath_=" << listFilePath_ << '}'; return ss.str(); } @@ -87,7 +87,7 @@ void AdditiveNoise::apply(std::vector& signal) { signal[i] += mixedNoise[i] * noiseMult; } else FL_LOG(fl::LogLevel::WARNING) - << "AdditiveNoise::apply() invalid noiseRms=" << noiseRms; + << "AdditiveNoise::apply() invalid noiseRms=" << noiseRms; } } // namespace fl diff --git 
a/flashlight/pkg/speech/augmentation/GaussianNoise.cpp b/flashlight/pkg/speech/augmentation/GaussianNoise.cpp index a1a5a0c..b7663a9 100644 --- a/flashlight/pkg/speech/augmentation/GaussianNoise.cpp +++ b/flashlight/pkg/speech/augmentation/GaussianNoise.cpp @@ -17,7 +17,7 @@ namespace fl::pkg::speech::sfx { std::string GaussianNoise::Config::prettyString() const { std::stringstream ss; ss << "GaussianNoise::Config{minSnr_=" << minSnr_ << " maxSnr_=" << maxSnr_ - << '}'; + << '}'; return ss.str(); } diff --git a/flashlight/pkg/speech/augmentation/Reverberation.cpp b/flashlight/pkg/speech/augmentation/Reverberation.cpp index c7ca266..708dd02 100644 --- a/flashlight/pkg/speech/augmentation/Reverberation.cpp +++ b/flashlight/pkg/speech/augmentation/Reverberation.cpp @@ -77,10 +77,10 @@ std::string ReverbEcho::prettyString() const { std::string ReverbEcho::Config::prettyString() const { std::stringstream ss; ss << " proba_=" << proba_ << " initialMin_=" << initialMin_ - << " initialMax_=" << initialMax_ << " rt60Min_=" << rt60Min_ - << " rt60Max_=" << rt60Max_ << " firstDelayMin_=" << firstDelayMin_ - << " firstDelayMax_=" << firstDelayMax_ << " repeat_=" << repeat_ - << " jitter_=" << jitter_ << " sampleRate_=" << sampleRate_; + << " initialMax_=" << initialMax_ << " rt60Min_=" << rt60Min_ + << " rt60Max_=" << rt60Max_ << " firstDelayMin_=" << firstDelayMin_ + << " firstDelayMax_=" << firstDelayMax_ << " repeat_=" << repeat_ + << " jitter_=" << jitter_ << " sampleRate_=" << sampleRate_; return ss.str(); } diff --git a/flashlight/pkg/speech/augmentation/SoundEffectConfig.cpp b/flashlight/pkg/speech/augmentation/SoundEffectConfig.cpp index c333f16..1264e56 100644 --- a/flashlight/pkg/speech/augmentation/SoundEffectConfig.cpp +++ b/flashlight/pkg/speech/augmentation/SoundEffectConfig.cpp @@ -104,7 +104,7 @@ void writeSoundEffectConfigFile( } catch(std::exception& ex) { std::stringstream ss; ss << "writeSoundEffectConfigFile(filename=" << filename - << ") failed with 
error={" << ex.what() << "}"; + << ") failed with error={" << ex.what() << "}"; throw std::runtime_error(ss.str()); } } @@ -121,7 +121,7 @@ std::vector readSoundEffectConfigFile( } catch(std::exception& ex) { std::stringstream ss; ss << "readSoundEffectConfigFile(filename=" << filename - << ") failed with error={" << ex.what() << "}"; + << ") failed with error={" << ex.what() << "}"; throw std::runtime_error(ss.str()); } } diff --git a/flashlight/pkg/speech/augmentation/SoxWrapper.cpp b/flashlight/pkg/speech/augmentation/SoxWrapper.cpp index 67dc3e4..4d2e1dd 100644 --- a/flashlight/pkg/speech/augmentation/SoxWrapper.cpp +++ b/flashlight/pkg/speech/augmentation/SoxWrapper.cpp @@ -44,8 +44,8 @@ namespace { if(i != *isamp) { LOG(ERROR) << "outputFlow number of bytes written=" << i - << " expected=" << *isamp - << " priv->data->size()=" << priv->data->size(); + << " expected=" << *isamp + << " priv->data->size()=" << priv->data->size(); return SOX_EOF; } } @@ -203,7 +203,7 @@ namespace detail { if(status != SOX_SUCCESS) { std::stringstream ss; ss << file << ':' << line << "] libsox error: " << status - << " when executing: " << msg; + << " when executing: " << msg; LOG(ERROR) << ss.str(); throw std::runtime_error(ss.str()); } @@ -213,7 +213,7 @@ namespace detail { if(!ptr) { std::stringstream ss; ss << file << ':' << line - << "] libsox failed to allocate when executing: " << msg; + << "] libsox failed to allocate when executing: " << msg; LOG(ERROR) << ss.str(); throw std::runtime_error(ss.str()); } diff --git a/flashlight/pkg/speech/augmentation/TimeStretch.cpp b/flashlight/pkg/speech/augmentation/TimeStretch.cpp index cb6726c..bd40dcd 100644 --- a/flashlight/pkg/speech/augmentation/TimeStretch.cpp +++ b/flashlight/pkg/speech/augmentation/TimeStretch.cpp @@ -41,8 +41,8 @@ void TimeStretch::apply(std::vector& signal) { std::string TimeStretch::Config::prettyString() const { std::stringstream ss; ss << "TimeStretch::Config{minFactor_=" << minFactor_ - << " maxFactor_=" 
<< maxFactor_ << " proba_=" << proba_ - << " sampleRate_=" << sampleRate_ << '}'; + << " maxFactor_=" << maxFactor_ << " proba_=" << proba_ + << " sampleRate_=" << sampleRate_ << '}'; return ss.str(); } diff --git a/flashlight/pkg/speech/common/Flags.cpp b/flashlight/pkg/speech/common/Flags.cpp index 175904a..e45bf93 100644 --- a/flashlight/pkg/speech/common/Flags.cpp +++ b/flashlight/pkg/speech/common/Flags.cpp @@ -845,18 +845,18 @@ void handleDeprecatedFlags() { if(deprecatedFlagSet && newFlagSet) { // Use the new flag value std::cerr << "[WARNING] Both deprecated flag " << flagPair.first - << " and new flag " << flagPair.second - << " are set. Only the new flag will be " - << "serialized when the model saved." << std::endl; + << " and new flag " << flagPair.second + << " are set. Only the new flag will be " + << "serialized when the model saved." << std::endl; ; } else if(deprecatedFlagSet && !newFlagSet) { std::cerr - << "[WARNING] Usage of flag --" << flagPair.first - << " is deprecated and has been replaced with " - << "--" << flagPair.second - << ". Setting the new flag equal to the value of the deprecated flag." - << "The old flag will not be serialized when the model is saved." - << std::endl; + << "[WARNING] Usage of flag --" << flagPair.first + << " is deprecated and has been replaced with " + << "--" << flagPair.second + << ". Setting the new flag equal to the value of the deprecated flag." + << "The old flag will not be serialized when the model is saved." 
+ << std::endl; if( gflags::SetCommandLineOption( flagPair.second.c_str(), @@ -866,7 +866,7 @@ void handleDeprecatedFlags() { ) { std::stringstream ss; ss << "Failed to set new flag " << flagPair.second << " to value from " - << flagPair.first << "."; + << flagPair.first << "."; throw std::logic_error(ss.str()); } } diff --git a/flashlight/pkg/speech/decoder/PlGenerator.cpp b/flashlight/pkg/speech/decoder/PlGenerator.cpp index 1078f98..f73537d 100644 --- a/flashlight/pkg/speech/decoder/PlGenerator.cpp +++ b/flashlight/pkg/speech/decoder/PlGenerator.cpp @@ -225,8 +225,8 @@ std::string PlGenerator::regeneratePl( auto sampleId = readSampleIds(sample[kSampleIdx]).front(); auto inputPath = readSampleIds(sample[kPathIdx]).front(); plStream << sampleId << "\t" << inputPath << "\t" - << std::to_string(duration) << "\t" << lib::join(" ", words) - << std::endl; + << std::to_string(duration) << "\t" << lib::join(" ", words) + << std::endl; } plStream.close(); diff --git a/flashlight/pkg/speech/runtime/Helpers.cpp b/flashlight/pkg/speech/runtime/Helpers.cpp index 966d32b..a5a8f6d 100644 --- a/flashlight/pkg/speech/runtime/Helpers.cpp +++ b/flashlight/pkg/speech/runtime/Helpers.cpp @@ -116,7 +116,7 @@ std::shared_ptr createDataset( ); #else LOG(FATAL) << "EverstoreDataset not supported: " - << "build with -DFL_BUILD_FB_DEPENDENCIES"; + << "build with -DFL_BUILD_FB_DEPENDENCIES"; #endif } else curListDs = std::make_shared( diff --git a/flashlight/pkg/speech/test/audio/MfccTest.cpp b/flashlight/pkg/speech/test/audio/MfccTest.cpp index 837353e..b43d9c8 100644 --- a/flashlight/pkg/speech/test/audio/MfccTest.cpp +++ b/flashlight/pkg/speech/test/audio/MfccTest.cpp @@ -81,7 +81,7 @@ TEST(MfccTest, htkCompareTest) { std::cerr << "| Max diff across all dimensions " << max << "\n"; // 0.325853 std::cerr << "| Avg diff across all dimensions " << sum / feat.size() - << "\n"; // 0.00252719 + << "\n"; // 0.00252719 } TEST(MfccTest, BatchingTest) { diff --git 
a/flashlight/pkg/speech/test/criterion/BenchmarkASG.cpp b/flashlight/pkg/speech/test/criterion/BenchmarkASG.cpp index 8778d40..abbd6f8 100644 --- a/flashlight/pkg/speech/test/criterion/BenchmarkASG.cpp +++ b/flashlight/pkg/speech/test/criterion/BenchmarkASG.cpp @@ -49,6 +49,6 @@ int main() { fl::sync(); auto e = fl::Timer::stop(s); std::cout << "Total time (fwd+bwd pass) " << std::setprecision(5) - << e * 1000.0 / ntimes << " msec" << std::endl; + << e * 1000.0 / ntimes << " msec" << std::endl; return 0; } diff --git a/flashlight/pkg/speech/test/criterion/BenchmarkCTC.cpp b/flashlight/pkg/speech/test/criterion/BenchmarkCTC.cpp index 68a6d2c..b057121 100644 --- a/flashlight/pkg/speech/test/criterion/BenchmarkCTC.cpp +++ b/flashlight/pkg/speech/test/criterion/BenchmarkCTC.cpp @@ -53,6 +53,6 @@ int main() { fl::sync(); auto e = fl::Timer::stop(s); std::cout << "Total time (fwd+bwd pass) " << std::setprecision(5) - << e * 1000.0 / ntimes << " msec" << std::endl; + << e * 1000.0 / ntimes << " msec" << std::endl; return 0; } diff --git a/flashlight/pkg/speech/test/criterion/BenchmarkSeq2Seq.cpp b/flashlight/pkg/speech/test/criterion/BenchmarkSeq2Seq.cpp index 6133031..a10138f 100644 --- a/flashlight/pkg/speech/test/criterion/BenchmarkSeq2Seq.cpp +++ b/flashlight/pkg/speech/test/criterion/BenchmarkSeq2Seq.cpp @@ -44,7 +44,7 @@ void timeBeamSearch() { fl::sync(); auto e = fl::Timer::stop(s); std::cout << "Total time (beam size: " << b << ") " << std::setprecision(5) - << e * 1000.0 / iters << " msec" << std::endl; + << e * 1000.0 / iters << " msec" << std::endl; } } @@ -80,7 +80,7 @@ void timeForwardBackward() { fl::sync(); auto e = fl::Timer::stop(s); std::cout << "Total time (fwd+bwd pass) " << std::setprecision(5) - << e * 1000.0 / iters << " msec" << std::endl; + << e * 1000.0 / iters << " msec" << std::endl; } int main() { diff --git a/flashlight/pkg/speech/test/criterion/CompareASG.cpp b/flashlight/pkg/speech/test/criterion/CompareASG.cpp index adf49d1..f77332c 
100644 --- a/flashlight/pkg/speech/test/criterion/CompareASG.cpp +++ b/flashlight/pkg/speech/test/criterion/CompareASG.cpp @@ -80,7 +80,7 @@ void printDiscrepancies( const Tensor& baseline ) { std::cerr << prefix << "discrepancy=" << std::setprecision(17) - << discrepancy(compare, baseline); + << discrepancy(compare, baseline); // Check for NaN discrepancies manually. auto compareNaN = fl::isnan(compare); auto baselineNaN = fl::isnan(baseline); diff --git a/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc.h b/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc.h index b4e9135..15ec3ce 100644 --- a/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc.h +++ b/flashlight/pkg/speech/third_party/warpctc/include/detail/gpu_ctc.h @@ -297,14 +297,14 @@ ctcStatus_t GpuCTC::launch_alpha_beta_kernels( if(compute_alpha) compute_alpha_kernel<< < grid_size, NT, 0, stream_ >> - > (probs, label_sizes_, utt_length_, + > (probs, label_sizes_, utt_length_, repeats_, labels_without_blanks_, label_offsets_, labels_with_blanks_, alphas_, nll_forward_, stride, out_dim_, S_, T_, blank_label_); if(compute_beta) { compute_betas_and_grad_kernel<< < grid_size, NT, 0, stream_ >> - > (probs, label_sizes_, utt_length_, repeats_, + > (probs, label_sizes_, utt_length_, repeats_, labels_with_blanks_, alphas_, nll_forward_, nll_backward_, grads, stride, out_dim_, S_, T_, blank_label_); @@ -416,7 +416,7 @@ ctcStatus_t GpuCTC::compute_log_probs(const ProbT* const activations) { const int grid_size = ctc_helper::div_up(num_elements, NV); prepare_stable_SM_kernel << < grid_size, NT, 0, stream_ >> - > (ctc_helper::identity(), probs_, + > (ctc_helper::identity(), probs_, denoms_, out_dim_, num_elements); // Reduce along columns to calculate denominator @@ -434,7 +434,7 @@ ctcStatus_t GpuCTC::compute_log_probs(const ProbT* const activations) { // Kernel launch to calculate probabilities compute_log_probs_kernel<< < grid_size, NT, 0, stream_ >> - > 
(ctc_helper::logarithmic(), probs_, + > (ctc_helper::logarithmic(), probs_, denoms_, out_dim_, num_elements); return CTC_STATUS_SUCCESS; diff --git a/flashlight/pkg/speech/third_party/warpctc/src/reduce.cu b/flashlight/pkg/speech/third_party/warpctc/src/reduce.cu index 61ce49a..723b521 100644 --- a/flashlight/pkg/speech/third_party/warpctc/src/reduce.cu +++ b/flashlight/pkg/speech/third_party/warpctc/src/reduce.cu @@ -148,13 +148,13 @@ struct ReduceHelper { if(axis) { grid_size = num_cols; reduce_rows<128> << < grid_size, 128, 0, stream >> - > (f, g, input, output, num_rows, num_cols); + > (f, g, input, output, num_rows, num_cols); } else { dim3 tpb(warp_size, 128 / warp_size); grid_size = (num_cols + warp_size - 1) / warp_size; reduce_cols<128> << < grid_size, tpb, 0, stream >> - > (f, g, input, output, num_rows, num_cols); + > (f, g, input, output, num_rows, num_cols); } } diff --git a/flashlight/pkg/text/data/TextDataset.cpp b/flashlight/pkg/text/data/TextDataset.cpp index 1407c22..04e44ed 100644 --- a/flashlight/pkg/text/data/TextDataset.cpp +++ b/flashlight/pkg/text/data/TextDataset.cpp @@ -57,7 +57,7 @@ TextDataset::TextDataset( const auto indices = dictionary.mapEntriesToIndices(tokens); if(data_.size() + indices.size() > kMaxTokenInBuffer) { FL_LOG(LogLevel::INFO) - << "[TextDataset] stop loading at 10,000,000,000 tokens"; + << "[TextDataset] stop loading at 10,000,000,000 tokens"; break; } sentenceRanges.emplace_back(currentEosPosition, -1); @@ -134,9 +134,9 @@ TextDataset::TextDataset( ); FL_LOG(LogLevel::INFO) << "[TextDataset] (" << reader.getRank() << "/" - << reader.getTotalReaders() << ") Loaded " << nTokens - << " tokens, " << sentenceRanges.size() - << " sentences and " << size() << " batches"; + << reader.getTotalReaders() << ") Loaded " << nTokens + << " tokens, " << sentenceRanges.size() + << " sentences and " << size() << " batches"; } int64_t TextDataset::size() const { diff --git a/flashlight/pkg/vision/dataset/BoxUtils.cpp 
b/flashlight/pkg/vision/dataset/BoxUtils.cpp index e409c50..0d2943d 100644 --- a/flashlight/pkg/vision/dataset/BoxUtils.cpp +++ b/flashlight/pkg/vision/dataset/BoxUtils.cpp @@ -170,8 +170,8 @@ std::tuple boxIou( if(bboxes1.ndim() != 3 || bboxes2.ndim() != 3) { std::stringstream ss; ss << "vision::boxIou - bbox inputs must be of shape " - "[4, N, B] and [4, M, B]. Got boxes with dimensions " - << bboxes1.shape() << " and " << bboxes2.shape(); + "[4, N, B] and [4, M, B]. Got boxes with dimensions " + << bboxes1.shape() << " and " << bboxes2.shape(); throw std::invalid_argument(ss.str()); } auto area1 = boxArea(bboxes1); diff --git a/flashlight/pkg/vision/models/ViT.cpp b/flashlight/pkg/vision/models/ViT.cpp index aa5e443..33ac3b3 100644 --- a/flashlight/pkg/vision/models/ViT.cpp +++ b/flashlight/pkg/vision/models/ViT.cpp @@ -143,7 +143,7 @@ std::vector ViT::forward( std::string ViT::prettyString() const { std::ostringstream ss; ss << "ViT (" << nClasses_ << " classes) with " << nLayers_ - << " Transformers:\n"; + << " Transformers:\n"; for(const auto& transformers : transformers_) ss << transformers->prettyString() << "\n"; return ss.str(); diff --git a/flashlight/pkg/vision/nn/VisionTransformer.cpp b/flashlight/pkg/vision/nn/VisionTransformer.cpp index cca25b1..48f6645 100644 --- a/flashlight/pkg/vision/nn/VisionTransformer.cpp +++ b/flashlight/pkg/vision/nn/VisionTransformer.cpp @@ -178,10 +178,10 @@ std::vector VisionTransformer::forward( std::string VisionTransformer::prettyString() const { std::ostringstream ss; ss << "VisionTransformer (nHeads: " << nHeads_ << "), " - << "(modelDim_: " << modelDim_ << "), " - << "(mlpDim_: " << mlpDim_ << "), " - << "(pDropout: " << pDropout_ << "), " - << "(pLayerdrop: " << pLayerdrop_ << "), "; + << "(modelDim_: " << modelDim_ << "), " + << "(mlpDim_: " << mlpDim_ << "), " + << "(pDropout: " << pDropout_ << "), " + << "(pLayerdrop: " << pLayerdrop_ << "), "; return ss.str(); } diff --git 
a/flashlight/pkg/vision/test/TransformerTest.cpp b/flashlight/pkg/vision/test/TransformerTest.cpp index 6ebbc15..b02220f 100644 --- a/flashlight/pkg/vision/test/TransformerTest.cpp +++ b/flashlight/pkg/vision/test/TransformerTest.cpp @@ -221,11 +221,11 @@ TEST(Tranformer, Size) { }; auto output = tr(inputs)[0]; ASSERT_EQ(output.dim(0), C) - << "Transformer should return model dim as first dimension"; + << "Transformer should return model dim as first dimension"; ASSERT_EQ(output.dim(1), bbox_queries) - << "Transformer did not return the correct number of labels"; + << "Transformer did not return the correct number of labels"; ASSERT_EQ(output.dim(2), B) - << "Transformer did not return the correct number of batches"; + << "Transformer did not return the correct number of batches"; } TEST(Tranformer, Masked) { diff --git a/flashlight/pkg/vision/test/criterion/HungarianTest.cpp b/flashlight/pkg/vision/test/criterion/HungarianTest.cpp index bfc49f6..e681ca8 100644 --- a/flashlight/pkg/vision/test/criterion/HungarianTest.cpp +++ b/flashlight/pkg/vision/test/criterion/HungarianTest.cpp @@ -69,7 +69,7 @@ TEST(HungarianTest, FullPipelineSimple1) { for(int c = 0; c < N; c++) for(int r = 0; r < M; r++) EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) - << "Assignment differs at row " << r << " and col " << c; + << "Assignment differs at row " << r << " and col " << c; } TEST(HungarianTest, FullPipelineSimple2) { @@ -84,7 +84,7 @@ TEST(HungarianTest, FullPipelineSimple2) { for(int c = 0; c < N; c++) for(int r = 0; r < M; r++) EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) - << "Assignment differs at row " << r << " and col " << c; + << "Assignment differs at row " << r << " and col " << c; } TEST(HungarianTest, FullPipelineSimple3) { @@ -98,7 +98,7 @@ TEST(HungarianTest, FullPipelineSimple3) { for(int c = 0; c < N; c++) for(int r = 0; r < M; r++) EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) - << "Assignment differs at row " << r << " and 
col " << c; + << "Assignment differs at row " << r << " and col " << c; } TEST(HungarianTest, FullPipelineSize6) { @@ -116,7 +116,7 @@ TEST(HungarianTest, FullPipelineSize6) { for(int c = 0; c < N; c++) for(int r = 0; r < M; r++) EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) - << "Assignment differs at row " << r << " and col " << c; + << "Assignment differs at row " << r << " and col " << c; } TEST(HungarianTest, 6x6Example2) { int M = 6; // Rows @@ -133,7 +133,7 @@ TEST(HungarianTest, 6x6Example2) { for(int c = 0; c < N; c++) for(int r = 0; r < M; r++) EXPECT_EQ(assignment[c * M + r], expAssignment[c * M + r]) - << "Assignment differs at row " << r << " and col " << c; + << "Assignment differs at row " << r << " and col " << c; } TEST(HungarianTest, NonSquare2) { diff --git a/uncrustify.cfg b/uncrustify.cfg index d264622..3ba25cb 100644 --- a/uncrustify.cfg +++ b/uncrustify.cfg @@ -47,7 +47,7 @@ align_typedef_span = 0 align_right_cmt_span = 0 align_func_proto_span = 0 align_nl_cont = 2 -align_left_shift = false +align_left_shift = true sp_arith = force sp_assign = force sp_cpp_lambda_assign = force From e6db0d6e3715f8afe95b76261319ab674e72a06c Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Mon, 23 Feb 2026 21:54:40 +0100 Subject: [PATCH 22/24] using arch instead of ubuntu --- .github/workflows/check-formatting.yml | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/check-formatting.yml b/.github/workflows/check-formatting.yml index 7208784..9a4de25 100644 --- a/.github/workflows/check-formatting.yml +++ b/.github/workflows/check-formatting.yml @@ -12,7 +12,7 @@ permissions: { contents: read } env: UNCRUSTIFY_CONFIG: "uncrustify.cfg" CHECK_PATH: "flashlight" - FILE_EXTENSIONS: "c|cpp|h|hpp|cu|cuh" + FILE_EXTENSIONS: "c|cpp|h|hpp|cu" # --------------------------------------------------------- # JOB @@ -21,16 +21,18 @@ jobs: formatting-check: name: Format check runs-on: ubuntu-latest + container: 
archlinux:latest steps: + - name: Install Git and Uncrustify + run: pacman -Syu --noconfirm git uncrustify + - uses: actions/checkout@v4 - - - name: Install uncrustify - run: | - sudo apt-get update - sudo apt-get install -y uncrustify - name: Run uncrustify style check run: | + # Print the version to confirm you are on the absolute latest version + uncrustify --version + find ${{ env.CHECK_PATH }} \ -type f \ -regextype posix-extended \ From 1b30dfd37fe79de9e0963f74442252f98d135cca Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Mon, 23 Feb 2026 22:18:59 +0100 Subject: [PATCH 23/24] forgot to make uncrustify target optional --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 060360e..8e88b77 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -147,4 +147,4 @@ setup_install_targets(INSTALL_TARGETS ${INSTALLABLE_TARGETS}) include(fm_target_utilities) fm_glob_cpp(FM_CPP "flashlight/*") -fm_add_uncrustify_target(uncrustify-format ${FM_CPP}) +fm_add_uncrustify_target(uncrustify-format OPTIONAL ${FM_CPP}) From 75adc94a06b00de2d9f39f3dcfd211ce0cb39cbe Mon Sep 17 00:00:00 2001 From: Lukas Thomann Date: Mon, 23 Feb 2026 22:28:07 +0100 Subject: [PATCH 24/24] fixed another formatting bug --- flashlight/fl/autograd/Functions.cpp | 6 +- .../tensor/backend/cudnn/BatchNorm.cpp | 10 +- .../autograd/tensor/backend/cudnn/Conv2D.cpp | 25 ++-- .../backend/cudnn/CudnnAutogradExtension.cpp | 8 +- .../autograd/tensor/backend/cudnn/Pool2D.cpp | 10 +- .../tensor/backend/onednn/BatchNorm.cpp | 34 +++-- .../autograd/tensor/backend/onednn/Conv2D.cpp | 90 ++++++++----- .../autograd/tensor/backend/onednn/Pool2D.cpp | 30 +++-- .../fl/autograd/tensor/backend/onednn/RNN.cpp | 33 +++-- flashlight/fl/common/Logging.cpp | 3 +- flashlight/fl/common/Serialization-inl.h | 3 +- flashlight/fl/common/Types.h | 50 ++++--- flashlight/fl/examples/RnnClassification.cpp | 8 +- flashlight/fl/examples/RnnLm.cpp | 3 +- 
flashlight/fl/meter/EditDistanceMeter.h | 3 +- flashlight/fl/runtime/SynchronousStream.h | 8 +- .../fl/tensor/backend/af/ArrayFireBackend.h | 6 +- flashlight/fl/tensor/backend/af/Utils.cpp | 10 +- .../test/autograd/AutogradBinaryOpsTest.cpp | 8 +- .../autograd/AutogradNormalizationTest.cpp | 9 +- flashlight/fl/test/autograd/AutogradTest.cpp | 122 +++++++++++------- flashlight/fl/test/common/LoggingTest.cpp | 6 +- .../contrib/modules/ContribModuleTest.cpp | 18 +-- flashlight/fl/test/dataset/DatasetTest.cpp | 6 +- flashlight/fl/test/nn/ModuleTest.cpp | 99 ++++++++------ flashlight/fl/test/tensor/TensorBaseTest.cpp | 3 +- .../fl/test/tensor/TensorBinaryOpsTest.cpp | 20 ++- .../fl/test/tensor/af/MemoryFrameworkTest.cpp | 3 +- .../pkg/runtime/common/DistributedUtils.cpp | 18 ++- .../pkg/runtime/common/SequentialBuilder.cpp | 5 +- .../pkg/speech/augmentation/SoxWrapper.cpp | 12 +- .../criterion/AutoSegmentationCriterion.h | 3 +- .../pkg/speech/data/ListFileDataset.cpp | 3 +- flashlight/pkg/speech/data/Sound.cpp | 46 ++++--- .../pkg/speech/decoder/DecodeMaster.cpp | 15 ++- flashlight/pkg/speech/decoder/PlGenerator.cpp | 6 +- .../pkg/speech/test/audio/CeplifterTest.cpp | 9 +- flashlight/pkg/speech/test/audio/DctTest.cpp | 6 +- .../pkg/speech/test/audio/DerivativesTest.cpp | 9 +- .../pkg/speech/test/audio/PreEmphasisTest.cpp | 54 ++++---- .../speech/test/audio/TriFilterbankTest.cpp | 9 +- .../pkg/speech/test/audio/WindowingTest.cpp | 3 +- .../test/augmentation/TimeStretchTest.cpp | 3 +- .../speech/test/criterion/CriterionTest.cpp | 15 ++- .../pkg/speech/test/criterion/Seq2SeqTest.cpp | 9 +- .../speech/test/data/FeaturizationTest.cpp | 22 +++- flashlight/pkg/text/data/TextDataset.cpp | 10 +- .../pkg/vision/criterion/SetCriterion.cpp | 5 +- flashlight/pkg/vision/dataset/Coco.cpp | 33 +++-- .../pkg/vision/dataset/CocoTransforms.cpp | 18 ++- .../pkg/vision/test/TransformerTest.cpp | 6 +- flashlight/pkg/vision/test/TransformsTest.cpp | 6 +- 
.../vision/test/criterion/HungarianTest.cpp | 35 +++-- .../test/criterion/SetCriterionTest.cpp | 115 +++++++++++------ uncrustify.cfg | 5 +- 55 files changed, 716 insertions(+), 398 deletions(-) diff --git a/flashlight/fl/autograd/Functions.cpp b/flashlight/fl/autograd/Functions.cpp index d597c01..70b7a9e 100644 --- a/flashlight/fl/autograd/Functions.cpp +++ b/flashlight/fl/autograd/Functions.cpp @@ -178,8 +178,10 @@ Variable operator*(const Variable& lhs, const Variable& rhs) { }; return Variable( result, - {rhs.isCalcGrad() ? lhs : lhs.withoutData(), - lhs.isCalcGrad() ? rhs : rhs.withoutData()}, + { + rhs.isCalcGrad() ? lhs : lhs.withoutData(), + lhs.isCalcGrad() ? rhs : rhs.withoutData() + }, gradFunc ); } diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp index 25b4159..9f6b315 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp @@ -49,10 +49,12 @@ namespace { if(minAxis == 0) { modeOut = CUDNN_BATCHNORM_PER_ACTIVATION; inDescDimsOut = Shape( - {1, - 1, - nfeatures, - static_cast(input.elements() / nfeatures)} + { + 1, + 1, + nfeatures, + static_cast(input.elements() / nfeatures) + } ); wtDescDimsOut = Shape({1, 1, nfeatures}); } else { diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp index 84ea052..bb89e61 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp @@ -26,23 +26,32 @@ namespace { std::unordered_map kKernelModesToCudnnMathType = { {fl::CudnnAutogradExtension::KernelMode::F32, CUDNN_DEFAULT_MATH}, - {fl::CudnnAutogradExtension::KernelMode::F32_ALLOW_CONVERSION, - CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION}, - {fl::CudnnAutogradExtension::KernelMode::F16, - CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION}}; + { + 
fl::CudnnAutogradExtension::KernelMode::F32_ALLOW_CONVERSION, + CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION + }, + { + fl::CudnnAutogradExtension::KernelMode::F16, + CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION + } + }; const std::unordered_set kFwdPreferredAlgos = { CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, - CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED}; + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED + }; const std::unordered_set kBwdDataPreferredAlgos = - {CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, - CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED}; + { + CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED + }; const std::unordered_set kBwdFilterPreferredAlgos = { CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED}; + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED + }; constexpr size_t kWorkspaceSizeLimitBytes = 512 * 1024 * 1024; // 512 MB constexpr cudnnConvolutionFwdAlgo_t kFwdDefaultAlgo = diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp index 2560cd4..305a6cc 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp @@ -17,9 +17,11 @@ std::shared_ptr CudnnAutogradExtension::createBenchmarkOpt return std::make_shared( std::make_shared>( std::vector( - {KernelMode::F32, - KernelMode::F32_ALLOW_CONVERSION, - KernelMode::F16} + { + KernelMode::F32, + KernelMode::F32_ALLOW_CONVERSION, + KernelMode::F16 + } ), fl::kDynamicBenchmarkDefaultCount ) diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp index f9956c1..24b08c8 100644 --- a/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp +++ b/flashlight/fl/autograd/tensor/backend/cudnn/Pool2D.cpp @@ -37,10 +37,12 @@ Tensor CudnnAutogradExtension::pool2d( auto oy = 1 + 
(iy + 2 * py - wy) / sy; auto output = Tensor( - {ox, - oy, - input.ndim() < 3 ? 1 : input.dim(2), - input.ndim() < 4 ? 1 : input.dim(3)}, + { + ox, + oy, + input.ndim() < 3 ? 1 : input.dim(2), + input.ndim() < 4 ? 1 : input.dim(3) + }, input.type() ); auto outDesc = TensorDescriptor(output); diff --git a/flashlight/fl/autograd/tensor/backend/onednn/BatchNorm.cpp b/flashlight/fl/autograd/tensor/backend/onednn/BatchNorm.cpp index aca3ea5..b673281 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/BatchNorm.cpp +++ b/flashlight/fl/autograd/tensor/backend/onednn/BatchNorm.cpp @@ -45,20 +45,24 @@ namespace { Shape inDescDims; if(minAxis == 0) inDescDims = Shape( - {1, - 1, - nfeatures, - static_cast(input.elements() / nfeatures)} + { + 1, + 1, + nfeatures, + static_cast(input.elements() / nfeatures) + } ); else { int batchsz = 1; for(int i = maxAxis + 1; i < input.ndim(); ++i) batchsz *= input.dim(i); inDescDims = Shape( - {1, - static_cast(input.elements() / (nfeatures * batchsz)), - nfeatures, - batchsz} + { + 1, + static_cast(input.elements() / (nfeatures * batchsz)), + nfeatures, + batchsz + } ); } @@ -66,7 +70,8 @@ namespace { inDescDims[kBatchSizeIdx], inDescDims[kChannelSizeIdx], inDescDims[kHIdx], - inDescDims[kWIdx]}; + inDescDims[kWIdx] + }; return inputOutputDims; } @@ -185,7 +190,10 @@ Tensor OneDnnAutogradExtension::batchnorm( {DNNL_ARG_VARIANCE, varMemory.getMemory()}, {DNNL_ARG_DST, outputMemory.getMemory()}, {DNNL_ARG_SCALE, weightsMemory.getMemory()}, - {DNNL_ARG_SHIFT, biasMemory.getMemory()}}; + { + DNNL_ARG_SHIFT, biasMemory.getMemory() + } + }; // Execute std::vector network; @@ -267,7 +275,11 @@ std::tuple OneDnnAutogradExtension::batchnormBackward( {DNNL_ARG_DIFF_SRC, gradInputMem.getMemory()}, {DNNL_ARG_DIFF_DST, gradOutputMem.getMemory()}, {DNNL_ARG_DIFF_SCALE, gradWeightsMem.getMemory()}, - {DNNL_ARG_DIFF_SHIFT, gradBiasMem.getMemory()}}}; + { + DNNL_ARG_DIFF_SHIFT, gradBiasMem.getMemory() + } + } + }; 
networkBackwards.push_back(*bwdPrim); detail::executeNetwork(networkBackwards, bwdArgs); diff --git a/flashlight/fl/autograd/tensor/backend/onednn/Conv2D.cpp b/flashlight/fl/autograd/tensor/backend/onednn/Conv2D.cpp index 881e076..3596c84 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/Conv2D.cpp +++ b/flashlight/fl/autograd/tensor/backend/onednn/Conv2D.cpp @@ -74,31 +74,39 @@ namespace { OneDnnConv2DData out; // Create memory dims out.inputDims = detail::convertToDnnlDims( - {inputShape.dim(kIOBatchSizeIdx), - inputShape.dim(kIOChannelSizeIdx), - inputShape.dim(kHIdx), - inputShape.dim(kWIdx)} + { + inputShape.dim(kIOBatchSizeIdx), + inputShape.dim(kIOChannelSizeIdx), + inputShape.dim(kHIdx), + inputShape.dim(kWIdx) + } ); if(groups == 1) out.weightDims = detail::convertToDnnlDims( - {weightsShape.dim(kWeightOutputChannelSizeIdx), - inputShape.dim(kIOChannelSizeIdx), - weightsShape.dim(kHIdx), - weightsShape.dim(kWIdx)} + { + weightsShape.dim(kWeightOutputChannelSizeIdx), + inputShape.dim(kIOChannelSizeIdx), + weightsShape.dim(kHIdx), + weightsShape.dim(kWIdx) + } ); else out.weightDims = detail::convertToDnnlDims( - {groups, - weightsShape.dim(kWeightOutputChannelSizeIdx) / groups, - inputShape.dim(kIOChannelSizeIdx) / groups, - weightsShape.dim(kHIdx), - weightsShape.dim(kWIdx)} + { + groups, + weightsShape.dim(kWeightOutputChannelSizeIdx) / groups, + inputShape.dim(kIOChannelSizeIdx) / groups, + weightsShape.dim(kHIdx), + weightsShape.dim(kWIdx) + } ); out.outputDims = detail::convertToDnnlDims( - {inputShape.dim(kIOBatchSizeIdx), - weightsShape.dim(kWeightOutputChannelSizeIdx), - outputShape.dim(kHIdx), - outputShape.dim(kWIdx)} + { + inputShape.dim(kIOBatchSizeIdx), + weightsShape.dim(kWeightOutputChannelSizeIdx), + outputShape.dim(kHIdx), + outputShape.dim(kWIdx) + } ); out.biasDims = detail::convertToDnnlDims( {weightsShape.dim(kWeightOutputChannelSizeIdx)} @@ -183,14 +191,16 @@ Tensor OneDnnAutogradExtension::conv2d( // row major transposes along 
all axis into NCHW for the input and output // and OIHW for the weights auto output = Tensor( - {1 - + (input.dim(kWIdx) + (2 * px) - (1 + (weights.dim(kWIdx) - 1) * dx)) - / sx, - 1 - + (input.dim(kHIdx) + (2 * py) - (1 + (weights.dim(kHIdx) - 1) * dy)) - / sy, - weights.dim(kWeightOutputChannelSizeIdx), - input.dim(kIOBatchSizeIdx)}, + { + 1 + + (input.dim(kWIdx) + (2 * px) - (1 + (weights.dim(kWIdx) - 1) * dx)) + / sx, + 1 + + (input.dim(kHIdx) + (2 * py) - (1 + (weights.dim(kHIdx) - 1) * dy)) + / sy, + weights.dim(kWeightOutputChannelSizeIdx), + input.dim(kIOBatchSizeIdx) + }, input.type() ); auto hasBias = bias.elements() > 0; @@ -264,7 +274,10 @@ Tensor OneDnnAutogradExtension::conv2d( std::unordered_map convFwdArgs = { {DNNL_ARG_SRC, inputMemory}, {DNNL_ARG_WEIGHTS, weightsMemory}, - {DNNL_ARG_DST, outputMemory}}; + { + DNNL_ARG_DST, outputMemory + } + }; if(hasBias) convFwdArgs[DNNL_ARG_BIAS] = biasMemory.getMemory(); fwdArgs.push_back(convFwdArgs); @@ -274,7 +287,10 @@ Tensor OneDnnAutogradExtension::conv2d( network.push_back(dnnl::reorder(outputMemory, outputMemInit.getMemory())); fwdArgs.push_back( {{DNNL_ARG_FROM, outputMemory}, - {DNNL_ARG_TO, outputMemInit.getMemory()}} + { + DNNL_ARG_TO, outputMemInit.getMemory() + } + } ); } @@ -375,7 +391,10 @@ Tensor OneDnnAutogradExtension::conv2dBackwardData( bwdDataArgs.push_back( {{DNNL_ARG_DIFF_SRC, gradInputMemory}, {DNNL_ARG_WEIGHTS, weightsMemoryBackwards}, - {DNNL_ARG_DIFF_DST, gradOutputMemory}} + { + DNNL_ARG_DIFF_DST, gradOutputMemory + } + } ); networkBackwards.push_back(*convBwdData); @@ -386,7 +405,10 @@ Tensor OneDnnAutogradExtension::conv2dBackwardData( ); bwdDataArgs.push_back( {{DNNL_ARG_FROM, gradInputMemory}, - {DNNL_ARG_TO, gradInputMemInit.getMemory()}} + { + DNNL_ARG_TO, gradInputMemInit.getMemory() + } + } ); } @@ -505,7 +527,10 @@ std::pair OneDnnAutogradExtension::conv2dBackwardFilterBias( std::unordered_map bwdConvWeightsArgs = { {DNNL_ARG_SRC, inputMemoryBackwards}, 
{DNNL_ARG_DIFF_WEIGHTS, gradWeightsMemory}, - {DNNL_ARG_DIFF_DST, gradOutputMemory}}; + { + DNNL_ARG_DIFF_DST, gradOutputMemory + } + }; if(computeBiasGrad) { const detail::DnnlMemoryWrapper gradBiasMem( @@ -522,7 +547,10 @@ std::pair OneDnnAutogradExtension::conv2dBackwardFilterBias( ); bwdWeightsArgs.push_back( {{DNNL_ARG_FROM, gradWeightsMemory}, - {DNNL_ARG_TO, gradWeightsMemInit.getMemory()}} + { + DNNL_ARG_TO, gradWeightsMemInit.getMemory() + } + } ); } diff --git a/flashlight/fl/autograd/tensor/backend/onednn/Pool2D.cpp b/flashlight/fl/autograd/tensor/backend/onednn/Pool2D.cpp index c6523e1..ba26349 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/Pool2D.cpp +++ b/flashlight/fl/autograd/tensor/backend/onednn/Pool2D.cpp @@ -52,16 +52,20 @@ namespace { ) { DimsData d; d.inputDims = detail::convertToDnnlDims( - {input.dim(kBatchSizeIdx), - input.dim(kChannelSizeIdx), - input.dim(kHIdx), - input.dim(kWIdx)} + { + input.dim(kBatchSizeIdx), + input.dim(kChannelSizeIdx), + input.dim(kHIdx), + input.dim(kWIdx) + } ); d.outputDims = detail::convertToDnnlDims( - {input.dim(kBatchSizeIdx), - input.dim(kChannelSizeIdx), - output.dim(kHIdx), - output.dim(kWIdx)} + { + input.dim(kBatchSizeIdx), + input.dim(kChannelSizeIdx), + output.dim(kHIdx), + output.dim(kWIdx) + } ); d.windowDims = {wy, wx}; d.strideDims = {sy, sx}; @@ -176,7 +180,10 @@ Tensor OneDnnAutogradExtension::pool2d( ); fwdArgs.push_back( {{DNNL_ARG_FROM, payload->outputMemory}, - {DNNL_ARG_TO, outputMemInit.getMemory()}} + { + DNNL_ARG_TO, outputMemInit.getMemory() + } + } ); } @@ -248,7 +255,10 @@ Tensor OneDnnAutogradExtension::pool2dBackward( std::unordered_map bwdPoolingArgs = { {DNNL_ARG_DIFF_SRC, gradInputMemInit.getMemory()}, {DNNL_ARG_DIFF_DST, gradOutputMemory}, - {DNNL_ARG_WORKSPACE, payload->workspace}}; + { + DNNL_ARG_WORKSPACE, payload->workspace + } + }; bwdArgs.push_back(bwdPoolingArgs); networkBackward.push_back(poolBwd); diff --git 
a/flashlight/fl/autograd/tensor/backend/onednn/RNN.cpp b/flashlight/fl/autograd/tensor/backend/onednn/RNN.cpp index 63f2316..f866f1c 100644 --- a/flashlight/fl/autograd/tensor/backend/onednn/RNN.cpp +++ b/flashlight/fl/autograd/tensor/backend/onednn/RNN.cpp @@ -282,22 +282,28 @@ namespace { int seqLength = input.ndim() < 3 ? 1 : input.dim(2); dnnl::memory::dims inputDims = {seqLength, batchSize, inSize}; dnnl::memory::dims outputDims = { - seqLength, batchSize, hiddenSize* directionMult}; + seqLength, batchSize, hiddenSize* directionMult + }; auto dType = detail::dnnlMapToType(input.type()); int totalLayers = numLayers; int outSize = hiddenSize; dnnl::memory::dims hDims = { - totalLayers, directionMult, batchSize, hiddenSize}; + totalLayers, directionMult, batchSize, hiddenSize + }; dnnl::memory::dims cDims = { - totalLayers, directionMult, batchSize, hiddenSize}; + totalLayers, directionMult, batchSize, hiddenSize + }; int extraBias = mode == RnnMode::GRU ? 1 : 0; // for LBR GRU dnnl::memory::dims biasDims = { - numLayers, directionMult, numGates + extraBias, hiddenSize}; + numLayers, directionMult, numGates + extraBias, hiddenSize + }; // ldigo dnnl::memory::dims weightsInputDims = { - numLayers, directionMult, inSize, numGates, hiddenSize}; + numLayers, directionMult, inSize, numGates, hiddenSize + }; dnnl::memory::dims weightsHiddenDims = { - numLayers, directionMult, hiddenSize, numGates, hiddenSize}; + numLayers, directionMult, hiddenSize, numGates, hiddenSize + }; // Out tensors: output (y), hidden state output (hy), cell state output (cy) auto y = Tensor({outSize, batchSize, seqLength}, input.type()); @@ -354,7 +360,10 @@ namespace { {DNNL_ARG_WEIGHTS_ITER, weightsHiddenMemInit}, {DNNL_ARG_BIAS, biasMemInit.getMemory()}, {DNNL_ARG_DST_LAYER, outputMemInit.getMemory()}, - {DNNL_ARG_DST_ITER, hiddenOutMemInit.getMemory()}}; + { + DNNL_ARG_DST_ITER, hiddenOutMemInit.getMemory() + } + }; // Workspace memory, if needed dnnl::memory workspace; @@ -367,7 +376,10 
@@ namespace { ); fwdArgs.push_back( {{DNNL_ARG_FROM, weightsInputMemRawInit.getMemory()}, - {DNNL_ARG_TO, weightsInputMemInit}} + { + DNNL_ARG_TO, weightsInputMemInit + } + } ); // reorder iter weights network.push_back( @@ -375,7 +387,10 @@ namespace { ); fwdArgs.push_back( {{DNNL_ARG_FROM, weightsHiddenMemRawInit.getMemory()}, - {DNNL_ARG_TO, weightsHiddenMemInit}} + { + DNNL_ARG_TO, weightsHiddenMemInit + } + } ); // Initialize descriptors diff --git a/flashlight/fl/common/Logging.cpp b/flashlight/fl/common/Logging.cpp index d558886..ae7a729 100644 --- a/flashlight/fl/common/Logging.cpp +++ b/flashlight/fl/common/Logging.cpp @@ -231,7 +231,8 @@ constexpr std::array flLogLevelValues = { fl::LogLevel::WARNING, fl::LogLevel::ERROR, fl::LogLevel::FATAL, - fl::LogLevel::DISABLED}; + fl::LogLevel::DISABLED +}; constexpr std::array flLogLevelNames = {"INFO", "WARNING", "ERROR", "FATAL", "DISABLED"}; diff --git a/flashlight/fl/common/Serialization-inl.h b/flashlight/fl/common/Serialization-inl.h index 337c594..7777890 100644 --- a/flashlight/fl/common/Serialization-inl.h +++ b/flashlight/fl/common/Serialization-inl.h @@ -122,7 +122,8 @@ detail::SerializeAs detail::SerializeAs serializeAs(T&& t, SaveConvFn saveConverter, LoadConvFn loadConverter) { return detail::SerializeAs{ - std::forward(t), std::move(saveConverter), std::move(loadConverter)}; + std::forward(t), std::move(saveConverter), std::move(loadConverter) + }; } template diff --git a/flashlight/fl/common/Types.h b/flashlight/fl/common/Types.h index 11adfcd..dcb3168 100644 --- a/flashlight/fl/common/Types.h +++ b/flashlight/fl/common/Types.h @@ -23,27 +23,35 @@ namespace detail { const std::unordered_map> kOptimLevelTypeExclusionMappings = { {OptimLevel::DEFAULT, {}}, // unused - {OptimLevel::O1, - // Perform all operations in fp16 except for: - {"batchnorm", - "reciprocal", - "erf", - "exp", - "log", - "log1p", - "pow", - "sum", - "mean", - "var", - "norm", - "normalize", - "softmax", - "logSoftmax", - 
"categoricalCrossEntropy", - "gelu"}}, - {OptimLevel::O2, - // Perform all operations in fp16 except for: - {"batchnorm"}}, + { + OptimLevel::O1, + // Perform all operations in fp16 except for: + { + "batchnorm", + "reciprocal", + "erf", + "exp", + "log", + "log1p", + "pow", + "sum", + "mean", + "var", + "norm", + "normalize", + "softmax", + "logSoftmax", + "categoricalCrossEntropy", + "gelu" + } + }, + { + OptimLevel::O2, + // Perform all operations in fp16 except for: + { + "batchnorm" + } + }, {OptimLevel::O3, {}} // Perform all operations in f16 }; diff --git a/flashlight/fl/examples/RnnClassification.cpp b/flashlight/fl/examples/RnnClassification.cpp index ed16839..1919aa1 100644 --- a/flashlight/fl/examples/RnnClassification.cpp +++ b/flashlight/fl/examples/RnnClassification.cpp @@ -101,7 +101,8 @@ class ClassificationDataset : public Dataset { "French", "Polish", "Italian", - "Irish"}; + "Irish" + }; for(auto& l : lang) read(datasetPath, l); for(auto& it : Id2Label) @@ -360,7 +361,10 @@ int main(int argc, char** argv) { {"Washington", "English"}, {"Voltaire", "French"}, {"Pfeiffer", "German"}, - {"Tambellini", "Italian"}}; + { + "Tambellini", "Italian" + } + }; for(auto& p : quickList) model.unittest(p.first, p.second); diff --git a/flashlight/fl/examples/RnnLm.cpp b/flashlight/fl/examples/RnnLm.cpp index 9402fd9..a08d6ca 100644 --- a/flashlight/fl/examples/RnnLm.cpp +++ b/flashlight/fl/examples/RnnLm.cpp @@ -321,5 +321,6 @@ std::vector LMDataset::get(const int64_t idx) const { int end = (idx + 1) * time_steps; return { data(fl::span, fl::range(start, end)), - data(fl::span, fl::range(start, end))}; + data(fl::span, fl::range(start, end)) + }; } diff --git a/flashlight/fl/meter/EditDistanceMeter.h b/flashlight/fl/meter/EditDistanceMeter.h index 6e0c05e..a4b6153 100644 --- a/flashlight/fl/meter/EditDistanceMeter.h +++ b/flashlight/fl/meter/EditDistanceMeter.h @@ -141,7 +141,8 @@ class FL_API EditDistanceMeter { auto possibilities = { column[y].sum() + 1, 
column[y - 1].sum() + 1, - lastdiagonal.sum() + ((*curin1 == *curin2) ? 0 : 1)}; + lastdiagonal.sum() + ((*curin1 == *curin2) ? 0 : 1) + }; auto min_it = std::min_element(possibilities.begin(), possibilities.end()); if( diff --git a/flashlight/fl/runtime/SynchronousStream.h b/flashlight/fl/runtime/SynchronousStream.h index 8e0081b..72f78b0 100644 --- a/flashlight/fl/runtime/SynchronousStream.h +++ b/flashlight/fl/runtime/SynchronousStream.h @@ -18,9 +18,11 @@ namespace fl { */ class FL_API SynchronousStream : public StreamTrait { protected: - X64Device& device_{DeviceManager::getInstance() - .getActiveDevice(DeviceType::x64) - .impl()}; + X64Device& device_{ + DeviceManager::getInstance() + .getActiveDevice(DeviceType::x64) + .impl() + }; public: // prevent name hiding diff --git a/flashlight/fl/tensor/backend/af/ArrayFireBackend.h b/flashlight/fl/tensor/backend/af/ArrayFireBackend.h index 19af6f9..698fe33 100644 --- a/flashlight/fl/tensor/backend/af/ArrayFireBackend.h +++ b/flashlight/fl/tensor/backend/af/ArrayFireBackend.h @@ -37,8 +37,10 @@ class ArrayFireBackend : public TensorBackend { // NOTE using a `shared_ptr` to allow its capture in setActive callback; // see constructor for details. std::shared_ptr>> - afIdToStream_{std::make_shared< - std::unordered_map>>()}; + afIdToStream_{ + std::make_shared< + std::unordered_map>>() + }; // Intentionally private. Only one instance should exist/it should be accessed // via getInstance(). 
diff --git a/flashlight/fl/tensor/backend/af/Utils.cpp b/flashlight/fl/tensor/backend/af/Utils.cpp index 946d0d1..8b11f58 100644 --- a/flashlight/fl/tensor/backend/af/Utils.cpp +++ b/flashlight/fl/tensor/backend/af/Utils.cpp @@ -28,7 +28,10 @@ af::dtype flToAfType(fl::dtype type) { {fl::dtype::u8, af::dtype::u8}, {fl::dtype::u16, af::dtype::u16}, {fl::dtype::u32, af::dtype::u32}, - {fl::dtype::u64, af::dtype::u64}}; + { + fl::dtype::u64, af::dtype::u64 + } + }; return kFlashlightTypeToArrayFire.at(type); } @@ -45,7 +48,10 @@ fl::dtype afToFlType(af::dtype type) { {af::dtype::u8, fl::dtype::u8}, {af::dtype::u16, fl::dtype::u16}, {af::dtype::u32, fl::dtype::u32}, - {af::dtype::u64, fl::dtype::u64}}; + { + af::dtype::u64, fl::dtype::u64 + } + }; return kArrayFireTypeToFlashlight.at(type); } diff --git a/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp b/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp index c6299cf..844d41f 100644 --- a/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp +++ b/flashlight/fl/test/autograd/AutogradBinaryOpsTest.cpp @@ -83,7 +83,8 @@ TEST(AutogradBinaryOpsTest, CrossEntropy) { auto ignoreIdx = y(0, 0).scalar(); std::vector modes = { - ReduceMode::NONE, ReduceMode::SUM, ReduceMode::MEAN}; + ReduceMode::NONE, ReduceMode::SUM, ReduceMode::MEAN + }; for(auto mode : modes) { auto func = [&](Variable& input) { return categoricalCrossEntropy(input, y, mode); @@ -242,7 +243,10 @@ TEST(AutogradBinaryOpsTest, matmul) { {mkb2, kn}, {mkb2, knb2}, {mkb2b3, kn}, - {mkb2b3, knb2b3}}; + { + mkb2b3, knb2b3 + } + }; auto trFirstTwoDims = [](const Shape& in) -> Shape { Shape out = in; diff --git a/flashlight/fl/test/autograd/AutogradNormalizationTest.cpp b/flashlight/fl/test/autograd/AutogradNormalizationTest.cpp index bf5e215..f558268 100644 --- a/flashlight/fl/test/autograd/AutogradNormalizationTest.cpp +++ b/flashlight/fl/test/autograd/AutogradNormalizationTest.cpp @@ -124,7 +124,8 @@ TEST(AutogradNormalizationTest, 
BatchNormEvalModeOutputMultipleAxis) { )); for(int i = 0; i < nfeatures; ++i) { std::array sel = { - i % 13, (i / 13) % 13, (i / 13) / 13, fl::span}; + i % 13, (i / 13) % 13, (i / 13) / 13, fl::span + }; auto thisInput = input.tensor()(sel[0], sel[1], sel[2], sel[3]); auto thisMean = runningMean.tensor().flatten()(i).scalar(); auto thisVar = runningVar.tensor().flatten()(i).scalar(); @@ -152,7 +153,8 @@ TEST(AutogradNormalizationTest, BatchNormEvalModeOutputMultipleAxis) { )); for(int i = 0; i < nfeatures; ++i) { std::array sel = { - i % 13, (i / 13) % 13, (i / 13) / 13, fl::span}; + i % 13, (i / 13) % 13, (i / 13) / 13, fl::span + }; auto thisInput = input.tensor()(sel[0], sel[1], sel[2], sel[3]); auto thisMean = runningMean.tensor().flatten()(i).scalar(); auto thisVar = runningVar.tensor().flatten()(i).scalar(); @@ -228,7 +230,8 @@ TEST(AutogradNormalizationTest, BatchNormTrainModeOutputMultipleAxis) { for(int i = 0; i < nfeatures; ++i) { std::array sel = { - i % 13, (i / 13) % 13, (i / 13) / 13, fl::span}; + i % 13, (i / 13) % 13, (i / 13) / 13, fl::span + }; auto thisInput = input.tensor()(sel[0], sel[1], sel[2], sel[3]); auto thisMean = avg.tensor().flatten()(i).scalar(); auto thisVar = variance.tensor().flatten()(i).scalar(); diff --git a/flashlight/fl/test/autograd/AutogradTest.cpp b/flashlight/fl/test/autograd/AutogradTest.cpp index 9636b39..3d53c77 100644 --- a/flashlight/fl/test/autograd/AutogradTest.cpp +++ b/flashlight/fl/test/autograd/AutogradTest.cpp @@ -40,125 +40,147 @@ TEST(AutogradTest, AutogradOperatorTypeCompatibility) { // Binary operators EXPECT_THROW( - {auto res = f16 + f32; + { + auto res = f16 + f32; }, std::invalid_argument ); // + EXPECT_THROW( - {auto res = f16 - f32; + { + auto res = f16 - f32; }, std::invalid_argument ); // - EXPECT_THROW( - {auto res = f16 * f32; + { + auto res = f16 * f32; }, std::invalid_argument ); // * EXPECT_THROW( - {auto res = f16 / f32; + { + auto res = f16 / f32; }, std::invalid_argument ); /// EXPECT_THROW( 
- {auto res = f16 > f32; + { + auto res = f16 > f32; }, std::invalid_argument ); // > EXPECT_THROW( - {auto res = f16 < f32; + { + auto res = f16 < f32; }, std::invalid_argument ); // < EXPECT_THROW( - {auto res = f16 >= f32; + { + auto res = f16 >= f32; }, std::invalid_argument ); // >= EXPECT_THROW( - {auto res = f16 <= f32; + { + auto res = f16 <= f32; }, std::invalid_argument ); // <= EXPECT_THROW( - {auto res = f16 && f32; + { + auto res = f16 && f32; }, std::invalid_argument ); // && EXPECT_THROW( - {max(f16, f32); + { + max(f16, f32); }, std::invalid_argument ); // max EXPECT_THROW( - {min(f16, f32); + { + min(f16, f32); }, std::invalid_argument ); // min EXPECT_THROW( - {matmul(f16, f32); + { + matmul(f16, f32); }, std::invalid_argument ); // matmul EXPECT_THROW( - {matmulTN(f16, f32); + { + matmulTN(f16, f32); }, std::invalid_argument ); // matmulTN EXPECT_THROW( - {matmulNT(f16, f32); + { + matmulNT(f16, f32); }, std::invalid_argument ); // matmulNT EXPECT_NO_THROW( - {binaryCrossEntropy(f16, f32); + { + binaryCrossEntropy(f16, f32); } ); EXPECT_NO_THROW( - { - categoricalCrossEntropy( - Variable(fl::rand({7, 10, 4}, fl::dtype::f16), true), - Variable( - (fl::rand({10, 4}, fl::dtype::u32) % 7).astype(fl::dtype::s32), - false - ) - ); - } + { + categoricalCrossEntropy( + Variable(fl::rand({7, 10, 4}, fl::dtype::f16), true), + Variable( + (fl::rand({10, 4}, fl::dtype::u32) % 7).astype(fl::dtype::s32), + false + ) + ); + } ); EXPECT_NO_THROW( - {pool2d(f16, 1, 1, 1, 1, 1, 1); + { + pool2d(f16, 1, 1, 1, 1, 1, 1); } ); EXPECT_NO_THROW( - {embedding(f16, f32); + { + embedding(f16, f32); } ); // lookup is of a different type // Ternary operators auto f32_2 = Variable(fl::rand({2, 2}, fl::dtype::f32), true); auto f16_2 = Variable(fl::rand({2, 2}, fl::dtype::f16), true); EXPECT_THROW( - {linear(f16, f32, f16_2); + { + linear(f16, f32, f16_2); }, std::invalid_argument ); // linear EXPECT_THROW( - {linear(f16, f32, f32_2); + { + linear(f16, f32, f32_2); }, 
std::invalid_argument ); // linear auto w = Variable(fl::rand({1}, fl::dtype::f32), true); auto b = Variable(fl::rand({1}, fl::dtype::f32), true); EXPECT_THROW( - {batchnorm(f16, f32, f32_2, w, b, {1}, true, 0.01, 0.01); + { + batchnorm(f16, f32, f32_2, w, b, {1}, true, 0.01, 0.01); }, std::invalid_argument ); EXPECT_THROW( - {batchnorm(f16, f32, f16_2, w, b, {1}, true, 0.01, 0.01); + { + batchnorm(f16, f32, f16_2, w, b, {1}, true, 0.01, 0.01); }, std::invalid_argument ); EXPECT_THROW( - {conv2d(f16, f32, f16_2, 1, 1, 0, 0, 1, 1); + { + conv2d(f16, f32, f16_2, 1, 1, 0, 0, 1, 1); }, std::invalid_argument ); @@ -166,25 +188,26 @@ TEST(AutogradTest, AutogradOperatorTypeCompatibility) { auto f16_3 = Variable(fl::rand({2, 2, 3}, fl::dtype::f16), false); auto f16_4 = Variable(fl::rand({50}, fl::dtype::f16), false); EXPECT_THROW( - { - rnn( - f16_3, - Variable(Tensor(fl::dtype::f32), false), - Variable(Tensor(fl::dtype::f32), false), - f16_4, - 2, - 2, - RnnMode::LSTM, - true, - 0.0 - ); - }, + { + rnn( + f16_3, + Variable(Tensor(fl::dtype::f32), false), + Variable(Tensor(fl::dtype::f32), false), + f16_4, + 2, + 2, + RnnMode::LSTM, + true, + 0.0 + ); + }, std::invalid_argument ); // Variadic operators std::vector concatInputs = {f16, f32, f16_2, f32_2}; EXPECT_THROW( - {concatenate(concatInputs, 0); + { + concatenate(concatInputs, 0); }, std::invalid_argument ); @@ -198,7 +221,8 @@ TEST(AutogradTest, CastingAsDifferentGradTypes) { auto f16 = Variable(fl::rand({5, 5}, fl::dtype::f16), true); // Computing gradients with mixed types fails when the op is applied ASSERT_THROW( - {f32 + f16; + { + f32 + f16; }, std::invalid_argument ); @@ -467,7 +491,8 @@ TEST(AutogradTest, GetAdvancedIndex) { GTEST_SKIP() << "Advanced indexing operator unsupported for non-CUDA backends"; std::vector validIndexTypes = { - fl::dtype::s32, fl::dtype::s64, fl::dtype::u32, fl::dtype::u64}; + fl::dtype::s32, fl::dtype::s64, fl::dtype::u32, fl::dtype::u64 + }; for(const auto& dtype : validIndexTypes) 
{ auto x = Variable(fl::rand({20, 50, 40, 30}, fl::dtype::f32), true); Tensor a({6}, dtype); @@ -498,7 +523,8 @@ TEST(AutogradTest, GetAdvancedIndexF16) { if(!fl::f16Supported()) GTEST_SKIP() << "Half-precision not supported on this device"; std::vector validIndexTypes = { - fl::dtype::s32, fl::dtype::s64, fl::dtype::u32, fl::dtype::u64}; + fl::dtype::s32, fl::dtype::s64, fl::dtype::u32, fl::dtype::u64 + }; for(const auto& dtype : validIndexTypes) { auto x = Variable(fl::rand({20, 50, 40, 30}, fl::dtype::f16), true); Tensor a({6}, dtype); diff --git a/flashlight/fl/test/common/LoggingTest.cpp b/flashlight/fl/test/common/LoggingTest.cpp index ea9cdea..87f67c5 100644 --- a/flashlight/fl/test/common/LoggingTest.cpp +++ b/flashlight/fl/test/common/LoggingTest.cpp @@ -82,7 +82,8 @@ TEST(Logging, logOnOff) { fl::LogLevel::FATAL, fl::LogLevel::ERROR, fl::LogLevel::WARNING, - fl::LogLevel::INFO}; + fl::LogLevel::INFO + }; for(LogLevel l : logLevels) { stdoutBuffer.clear(); stderrBuffer.clear(); @@ -131,7 +132,8 @@ TEST(LoggingDeathTest, FatalOnOff) { Logging::setMaxLoggingLevel(fl::LogLevel::FATAL); EXPECT_DEATH_IF_SUPPORTED( - {FL_LOG(fl::LogLevel::FATAL) << "log-fatal"; + { + FL_LOG(fl::LogLevel::FATAL) << "log-fatal"; }, "" ); diff --git a/flashlight/fl/test/contrib/modules/ContribModuleTest.cpp b/flashlight/fl/test/contrib/modules/ContribModuleTest.cpp index d6295ef..8b55692 100644 --- a/flashlight/fl/test/contrib/modules/ContribModuleTest.cpp +++ b/flashlight/fl/test/contrib/modules/ContribModuleTest.cpp @@ -164,14 +164,16 @@ void transformerPadMaskFwd(bool isfp16) { } auto output1 = tr.forward( - {input1NoPad, - Variable( - padMask(fl::range(0, timesteps / 2))( - fl::span, - fl::range(0, 1) - ), - false - )} + { + input1NoPad, + Variable( + padMask(fl::range(0, timesteps / 2))( + fl::span, + fl::range(0, 1) + ), + false + ) + } ) .front(); auto output2 = diff --git a/flashlight/fl/test/dataset/DatasetTest.cpp b/flashlight/fl/test/dataset/DatasetTest.cpp index 
8962f1d..10d9c4c 100644 --- a/flashlight/fl/test/dataset/DatasetTest.cpp +++ b/flashlight/fl/test/dataset/DatasetTest.cpp @@ -21,7 +21,8 @@ using namespace fl; TEST(DatasetTest, TensorDataset) { std::vector tensormap = { - fl::rand({100, 200, 300}), fl::rand({150, 300})}; + fl::rand({100, 200, 300}), fl::rand({150, 300}) + }; TensorDataset tensords(tensormap); // Check `size` method @@ -150,7 +151,8 @@ TEST(DatasetTest, ResampleDataset) { TEST(DatasetTest, SpanDataset) { std::vector tensormap = { - fl::rand({100, 200, 300}), fl::rand({150, 300})}; + fl::rand({100, 200, 300}), fl::rand({150, 300}) + }; auto tensords = std::make_shared(tensormap); SpanDataset frontspands(tensords, 0, 13); diff --git a/flashlight/fl/test/nn/ModuleTest.cpp b/flashlight/fl/test/nn/ModuleTest.cpp index efb5c75..cb87cb9 100644 --- a/flashlight/fl/test/nn/ModuleTest.cpp +++ b/flashlight/fl/test/nn/ModuleTest.cpp @@ -117,8 +117,10 @@ TEST(ModuleTest, LinearFwd) { auto expected_outVar = Variable( Tensor::fromVector( {n_out, x, batchsize}, - {68, 22, 18, 48, 22, 14, 84, 26, 22, 66, 19, 17, - 150, 55, 41, 94, 41, 27, 130, 55, 37, 56, 24, 16} + { + 68, 22, 18, 48, 22, 14, 84, 26, 22, 66, 19, 17, + 150, 55, 41, 94, 41, 27, 130, 55, 37, 56, 24, 16 + } ), true ); @@ -130,8 +132,10 @@ TEST(ModuleTest, LinearFwd) { expected_outVar = Variable( Tensor::fromVector( {n_out, x, batchsize}, - {69, 24, 21, 49, 24, 17, 85, 28, 25, 67, 21, 20, - 151, 57, 44, 95, 43, 30, 131, 57, 40, 57, 26, 19} + { + 69, 24, 21, 49, 24, 17, 85, 28, 25, 67, 21, 20, + 151, 57, 44, 95, 43, 30, 131, 57, 40, 57, 26, 19 + } ), true ); @@ -159,8 +163,10 @@ TEST_F(ModuleTestF16, LinearFwdF16) { auto expected_outVar = Variable( Tensor::fromVector( {n_out, x, batchsize}, - {68, 22, 18, 48, 22, 14, 84, 26, 22, 66, 19, 17, - 150, 55, 41, 94, 41, 27, 130, 55, 37, 56, 24, 16} + { + 68, 22, 18, 48, 22, 14, 84, 26, 22, 66, 19, 17, + 150, 55, 41, 94, 41, 27, 130, 55, 37, 56, 24, 16 + } ) .astype(fl::dtype::f16), true @@ -176,8 +182,10 @@ 
TEST_F(ModuleTestF16, LinearFwdF16) { expected_outVar = Variable( Tensor::fromVector( {n_out, x, batchsize}, - {69, 24, 21, 49, 24, 17, 85, 28, 25, 67, 21, 20, - 151, 57, 44, 95, 43, 30, 131, 57, 40, 57, 26, 19} + { + 69, 24, 21, 49, 24, 17, 85, 28, 25, 67, 21, 20, + 151, 57, 44, 95, 43, 30, 131, 57, 40, 57, 26, 19 + } ) .astype(inVar.type()), true @@ -511,19 +519,21 @@ TEST(ModuleTest, RNNFwd) { auto expected_outVar = Variable( Tensor::fromVector( expected_dims, - {1.5418, 1.6389, 1.7361, 1.5491, 1.6472, 1.7452, 1.5564, - 1.6554, 1.7544, 1.5637, 1.6637, 1.7636, 1.5710, 1.6719, - 1.7728, 3.4571, 3.7458, 4.0345, 3.4761, 3.7670, 4.0578, - 3.4951, 3.7881, 4.0812, 3.5141, 3.8093, 4.1045, 3.5331, - 3.8305, 4.1278, 5.6947, 6.2004, 6.7060, 5.7281, 6.2373, - 6.7466, 5.7614, 6.2743, 6.7871, 5.7948, 6.3112, 6.8276, - 5.8282, 6.3482, 6.8681, 8.2005, 8.9458, 9.6911, 8.2500, - 9.0005, 9.7509, 8.2995, 9.0551, 9.8107, 8.3491, 9.1098, - 9.8705, 8.3986, 9.1645, 9.9303, 10.9520, 11.9587, 12.9655, - 11.0191, 12.0326, 13.0462, 11.0861, 12.1065, 13.1269, 11.1532, - 12.1804, 13.2075, 11.2203, 12.2543, 13.2882, 13.9432, 15.2333, - 16.5233, 14.0291, 15.3277, 16.6263, 14.1149, 15.4221, 16.7292, - 14.2008, 15.5165, 16.8322, 14.2866, 15.6109, 16.9351} + { + 1.5418, 1.6389, 1.7361, 1.5491, 1.6472, 1.7452, 1.5564, + 1.6554, 1.7544, 1.5637, 1.6637, 1.7636, 1.5710, 1.6719, + 1.7728, 3.4571, 3.7458, 4.0345, 3.4761, 3.7670, 4.0578, + 3.4951, 3.7881, 4.0812, 3.5141, 3.8093, 4.1045, 3.5331, + 3.8305, 4.1278, 5.6947, 6.2004, 6.7060, 5.7281, 6.2373, + 6.7466, 5.7614, 6.2743, 6.7871, 5.7948, 6.3112, 6.8276, + 5.8282, 6.3482, 6.8681, 8.2005, 8.9458, 9.6911, 8.2500, + 9.0005, 9.7509, 8.2995, 9.0551, 9.8107, 8.3491, 9.1098, + 9.8705, 8.3986, 9.1645, 9.9303, 10.9520, 11.9587, 12.9655, + 11.0191, 12.0326, 13.0462, 11.0861, 12.1065, 13.1269, 11.1532, + 12.1804, 13.2075, 11.2203, 12.2543, 13.2882, 13.9432, 15.2333, + 16.5233, 14.0291, 15.3277, 16.6263, 14.1149, 15.4221, 16.7292, + 14.2008, 15.5165, 16.8322, 
14.2866, 15.6109, 16.9351 + } ), true ); @@ -560,9 +570,11 @@ TEST(ModuleTest, LSTMFwd) { auto expected_outVar = Variable( Tensor::fromVector( expected_dims, - {0.7390, 0.7395, 0.7399, 0.7403, 0.7407, 0.7390, 0.7395, - 0.7399, 0.7403, 0.7407, 0.9617, 0.9618, 0.9619, 0.9619, - 0.962, 0.9617, 0.9618, 0.9619, 0.9619, 0.962} + { + 0.7390, 0.7395, 0.7399, 0.7403, 0.7407, 0.7390, 0.7395, + 0.7399, 0.7403, 0.7407, 0.9617, 0.9618, 0.9619, 0.9619, + 0.962, 0.9617, 0.9618, 0.9619, 0.9619, 0.962 + } ), true ); @@ -599,9 +611,11 @@ TEST(ModuleTest, GRUFwd) { auto expected_outVar = Variable( Tensor::fromVector( expected_dims, - {0.1430, 0.1425, 0.1419, 0.1413, 0.1408, 0.1430, 0.1425, - 0.1419, 0.1413, 0.1408, 0.2206, 0.2194, 0.2181, 0.2168, - 0.2155, 0.2206, 0.2194, 0.2181, 0.2168, 0.2155} + { + 0.1430, 0.1425, 0.1419, 0.1413, 0.1408, 0.1430, 0.1425, + 0.1419, 0.1413, 0.1408, 0.2206, 0.2194, 0.2181, 0.2168, + 0.2155, 0.2206, 0.2194, 0.2181, 0.2168, 0.2155 + } ), true ); @@ -639,19 +653,21 @@ TEST_F(ModuleTestF16, RNNFwdF16) { auto expected_outVar = Variable( Tensor::fromVector( expected_dims, - {1.5418, 1.6389, 1.7361, 1.5491, 1.6472, 1.7452, 1.5564, - 1.6554, 1.7544, 1.5637, 1.6637, 1.7636, 1.5710, 1.6719, - 1.7728, 3.4571, 3.7458, 4.0345, 3.4761, 3.7670, 4.0578, - 3.4951, 3.7881, 4.0812, 3.5141, 3.8093, 4.1045, 3.5331, - 3.8305, 4.1278, 5.6947, 6.2004, 6.7060, 5.7281, 6.2373, - 6.7466, 5.7614, 6.2743, 6.7871, 5.7948, 6.3112, 6.8276, - 5.8282, 6.3482, 6.8681, 8.2005, 8.9458, 9.6911, 8.2500, - 9.0005, 9.7509, 8.2995, 9.0551, 9.8107, 8.3491, 9.1098, - 9.8705, 8.3986, 9.1645, 9.9303, 10.9520, 11.9587, 12.9655, - 11.0191, 12.0326, 13.0462, 11.0861, 12.1065, 13.1269, 11.1532, - 12.1804, 13.2075, 11.2203, 12.2543, 13.2882, 13.9432, 15.2333, - 16.5233, 14.0291, 15.3277, 16.6263, 14.1149, 15.4221, 16.7292, - 14.2008, 15.5165, 16.8322, 14.2866, 15.6109, 16.9351} + { + 1.5418, 1.6389, 1.7361, 1.5491, 1.6472, 1.7452, 1.5564, + 1.6554, 1.7544, 1.5637, 1.6637, 1.7636, 1.5710, 1.6719, + 
1.7728, 3.4571, 3.7458, 4.0345, 3.4761, 3.7670, 4.0578, + 3.4951, 3.7881, 4.0812, 3.5141, 3.8093, 4.1045, 3.5331, + 3.8305, 4.1278, 5.6947, 6.2004, 6.7060, 5.7281, 6.2373, + 6.7466, 5.7614, 6.2743, 6.7871, 5.7948, 6.3112, 6.8276, + 5.8282, 6.3482, 6.8681, 8.2005, 8.9458, 9.6911, 8.2500, + 9.0005, 9.7509, 8.2995, 9.0551, 9.8107, 8.3491, 9.1098, + 9.8705, 8.3986, 9.1645, 9.9303, 10.9520, 11.9587, 12.9655, + 11.0191, 12.0326, 13.0462, 11.0861, 12.1065, 13.1269, 11.1532, + 12.1804, 13.2075, 11.2203, 12.2543, 13.2882, 13.9432, 15.2333, + 16.5233, 14.0291, 15.3277, 16.6263, 14.1149, 15.4221, 16.7292, + 14.2008, 15.5165, 16.8322, 14.2866, 15.6109, 16.9351 + } ), true ); @@ -982,7 +998,8 @@ TEST(ModuleTest, IdentityFwd) { auto module = Identity(); std::vector in = { Variable(fl::rand({1000, 1000}), true), - Variable(fl::rand({100, 100}), true)}; + Variable(fl::rand({100, 100}), true) + }; // Train Mode module.train(); diff --git a/flashlight/fl/test/tensor/TensorBaseTest.cpp b/flashlight/fl/test/tensor/TensorBaseTest.cpp index d6a4e1e..2a7e1b1 100644 --- a/flashlight/fl/test/tensor/TensorBaseTest.cpp +++ b/flashlight/fl/test/tensor/TensorBaseTest.cpp @@ -532,7 +532,8 @@ TEST(TensorBaseTest, scalar) { fl::dtype::u64, fl::dtype::f16, fl::dtype::f32, - fl::dtype::f64}; + fl::dtype::f64 + }; for(auto type : types) { assertScalarBehavior(type); assertScalarBehavior(type); diff --git a/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp b/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp index d551fab..14cfe15 100644 --- a/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp +++ b/flashlight/fl/test/tensor/TensorBinaryOpsTest.cpp @@ -399,7 +399,17 @@ TEST(TensorBinaryOpsTest, broadcasting) { {{1, 10}, {8, 10}, {8, 1}, {1, 1}}, {{2, 1, 5, 1}, {2, 3, 5, 3}, {1, 3, 1, 3}, {1, 1, 1, 1}}, {{3, 1, 2, 1}, {1, 4, 1, 5}, {1, 4, 1, 5}, {3, 1, 2, 1}}, - {{3, 2, 1}, {3, 1, 4, 1}, {1, 1, 4}, {1, 2, 1, 1}}}; + {{ + 3, 2, 1 + }, { + 3, 1, 4, 1 + }, { + 1, 1, 4 + }, { + 1, 2, 1, 1 + } + } + }; 
std::unordered_map functions = { {fl::minimum, "minimum"}, @@ -422,7 +432,10 @@ TEST(TensorBinaryOpsTest, broadcasting) { {fl::bitwiseOr, "bitwiseOr"}, {fl::bitwiseXor, "bitwiseXor"}, {fl::lShift, "lShift"}, - {fl::rShift, "rShift"}}; + { + fl::rShift, "rShift" + } + }; auto doBinaryOp = [](const Tensor& lhs, const Tensor& rhs, @@ -431,7 +444,8 @@ TEST(TensorBinaryOpsTest, broadcasting) { binaryOpFunc_t func) -> std::pair { assert(lhs.ndim() <= rhs.ndim()); return { - func(lhs, rhs), func(tile(lhs, tileShapeLhs), tile(rhs, tileShapeRhs))}; + func(lhs, rhs), func(tile(lhs, tileShapeLhs), tile(rhs, tileShapeRhs)) + }; }; auto computeBroadcastShape = [](const Shape& lhsShape, diff --git a/flashlight/fl/test/tensor/af/MemoryFrameworkTest.cpp b/flashlight/fl/test/tensor/af/MemoryFrameworkTest.cpp index a53f920..296e5fb 100644 --- a/flashlight/fl/test/tensor/af/MemoryFrameworkTest.cpp +++ b/flashlight/fl/test/tensor/af/MemoryFrameworkTest.cpp @@ -466,7 +466,8 @@ TEST(MemoryFramework, AdapterInstallerDeviceInterfaceTest) { "nativeFree", "nativeFree", "shutdown", - "shutdown"}; + "shutdown" + }; size_t idx = 0; for(std::string line; std::getline(logStream, line);) { EXPECT_EQ(line.substr(0, line.find(' ')), expectedLinePrefixes[idx]); diff --git a/flashlight/pkg/runtime/common/DistributedUtils.cpp b/flashlight/pkg/runtime/common/DistributedUtils.cpp index 5ec02b9..24512ac 100644 --- a/flashlight/pkg/runtime/common/DistributedUtils.cpp +++ b/flashlight/pkg/runtime/common/DistributedUtils.cpp @@ -22,17 +22,25 @@ void initDistributed( fl::DistributedInit::MPI, -1, // unused for MPI -1, // unused for MPI - {{fl::DistributedConstants::kMaxDevicePerNode, - std::to_string(maxDevicesPerNode)}} + {{ + fl::DistributedConstants::kMaxDevicePerNode, + std::to_string(maxDevicesPerNode) + } + } ); else distributedInit( fl::DistributedInit::FILE_SYSTEM, worldRank, worldSize, - {{fl::DistributedConstants::kMaxDevicePerNode, - std::to_string(maxDevicesPerNode)}, - 
{fl::DistributedConstants::kFilePath, rndvFilepath}} + {{ + fl::DistributedConstants::kMaxDevicePerNode, + std::to_string(maxDevicesPerNode) + }, + { + fl::DistributedConstants::kFilePath, rndvFilepath + } + } ); } diff --git a/flashlight/pkg/runtime/common/SequentialBuilder.cpp b/flashlight/pkg/runtime/common/SequentialBuilder.cpp index 36de0ed..462c7ce 100644 --- a/flashlight/pkg/runtime/common/SequentialBuilder.cpp +++ b/flashlight/pkg/runtime/common/SequentialBuilder.cpp @@ -136,7 +136,10 @@ std::shared_ptr parseLines( {std::stoi(params[2]), std::stoi(params[3])}, {std::stoi(params[4]), std::stoi(params[5])}, {std::stoi(params[6]), std::stoi(params[7])}, - {std::stoi(params[8]), std::stoi(params[9])}}; + { + std::stoi(params[8]), std::stoi(params[9]) + } + }; // TODO{fl::Tensor} -- rearrange arguments return std::make_shared(paddings, val); } diff --git a/flashlight/pkg/speech/augmentation/SoxWrapper.cpp b/flashlight/pkg/speech/augmentation/SoxWrapper.cpp index 4d2e1dd..3c55cb4 100644 --- a/flashlight/pkg/speech/augmentation/SoxWrapper.cpp +++ b/flashlight/pkg/speech/augmentation/SoxWrapper.cpp @@ -74,7 +74,8 @@ namespace { .channels = 1, // Sounds effects are limited to single channel .precision = 16, // Any valid value is ok here. 
.length = 0, - .mult = nullptr}; + .mult = nullptr + }; return sigInfo; } @@ -131,7 +132,8 @@ void SoxWrapper::addInput( /*drain=*/ inputDrain, /*stop=*/ nullptr, /*kill=*/ nullptr, - /*priv_size=*/ sizeof(SoxData)}; + /*priv_size=*/ sizeof(SoxData) + }; sox_effect_t* e = nullptr; FL_SOX_CHECK(e = sox_create_effect(&handler)); auto input = (SoxData*) e->priv; @@ -155,7 +157,8 @@ void SoxWrapper::addOutput( /*drain=*/ nullptr, /*stop=*/ nullptr, /*kill=*/ nullptr, - /*priv_size=*/ sizeof(SoxData)}; + /*priv_size=*/ sizeof(SoxData) + }; sox_effect_t* e = nullptr; FL_SOX_CHECK(e = sox_create_effect(&handler)); auto output = (SoxData*) e->priv; @@ -182,7 +185,8 @@ sox_effects_chain_t* SoxWrapper::createChain() const { .reverse_bytes = sox_option_no, .reverse_nibbles = sox_option_no, .reverse_bits = sox_option_no, - .opposite_endian = sox_false}; + .opposite_endian = sox_false + }; sox_effects_chain_t* chain = nullptr; FL_SOX_CHECK(chain = sox_create_effects_chain(&encoding, &encoding)); return chain; diff --git a/flashlight/pkg/speech/criterion/AutoSegmentationCriterion.h b/flashlight/pkg/speech/criterion/AutoSegmentationCriterion.h index 9344931..e7370ad 100644 --- a/flashlight/pkg/speech/criterion/AutoSegmentationCriterion.h +++ b/flashlight/pkg/speech/criterion/AutoSegmentationCriterion.h @@ -48,7 +48,8 @@ namespace pkg { throw std::invalid_argument("Invalid inputs size"); return { fcc_.forward(inputs[0], inputs[1]) - - fac_.forward(inputs[0], inputs[1])}; + - fac_.forward(inputs[0], inputs[1]) + }; } Tensor viterbiPath(const Tensor& input, const Tensor& inputSize = Tensor()) diff --git a/flashlight/pkg/speech/data/ListFileDataset.cpp b/flashlight/pkg/speech/data/ListFileDataset.cpp index 075cfda..c8c56eb 100644 --- a/flashlight/pkg/speech/data/ListFileDataset.cpp +++ b/flashlight/pkg/speech/data/ListFileDataset.cpp @@ -125,7 +125,8 @@ std::vector ListFileDataset::get(const int64_t idx) const { sampleIdx, samplePath, sampleDuration, - sampleTargetSize}; + 
sampleTargetSize + }; } std::pair, Shape> ListFileDataset::loadAudio( diff --git a/flashlight/pkg/speech/data/Sound.cpp b/flashlight/pkg/speech/data/Sound.cpp index b53841a..4375f65 100644 --- a/flashlight/pkg/speech/data/Sound.cpp +++ b/flashlight/pkg/speech/data/Sound.cpp @@ -49,7 +49,10 @@ const std::unordered_map formats{ {SoundFormat::WVE, SF_FORMAT_WVE}, {SoundFormat::OGG, SF_FORMAT_OGG}, {SoundFormat::MPC2K, SF_FORMAT_MPC2K}, - {SoundFormat::RF64, SF_FORMAT_RF64}}; + { + SoundFormat::RF64, SF_FORMAT_RF64 + } +}; const std::unordered_map subformats{ {SoundSubFormat::PCM_S8, SF_FORMAT_PCM_S8}, @@ -74,7 +77,10 @@ const std::unordered_map subformats{ {SoundSubFormat::DWVW_N, SF_FORMAT_DWVW_N}, {SoundSubFormat::DPCM_8, SF_FORMAT_DPCM_8}, {SoundSubFormat::DPCM_16, SF_FORMAT_DPCM_16}, - {SoundSubFormat::VORBIS, SF_FORMAT_VORBIS}}; + { + SoundSubFormat::VORBIS, SF_FORMAT_VORBIS + } +}; } // namespace namespace fl::pkg::speech { @@ -189,11 +195,13 @@ SoundInfo loadSoundInfo(const std::string& filename) { } SoundInfo loadSoundInfo(std::istream& f) { - SF_VIRTUAL_IO vsf = {sf_vio_ro_get_filelen, - sf_vio_ro_seek, - sf_vio_ro_read, - sf_vio_ro_write, - sf_vio_ro_tell}; + SF_VIRTUAL_IO vsf = { + sf_vio_ro_get_filelen, + sf_vio_ro_seek, + sf_vio_ro_read, + sf_vio_ro_write, + sf_vio_ro_tell + }; SNDFILE* file; SF_INFO info; @@ -225,11 +233,13 @@ std::vector loadSound(const std::string& filename) { template std::vector loadSound(std::istream& f) { - SF_VIRTUAL_IO vsf = {sf_vio_ro_get_filelen, - sf_vio_ro_seek, - sf_vio_ro_read, - sf_vio_ro_write, - sf_vio_ro_tell}; + SF_VIRTUAL_IO vsf = { + sf_vio_ro_get_filelen, + sf_vio_ro_seek, + sf_vio_ro_read, + sf_vio_ro_write, + sf_vio_ro_tell + }; SNDFILE* file; SF_INFO info; @@ -288,11 +298,13 @@ void saveSound( SoundFormat format, SoundSubFormat subformat ) { - SF_VIRTUAL_IO vsf = {sf_vio_wo_get_filelen, - sf_vio_wo_seek, - sf_vio_wo_read, - sf_vio_wo_write, - sf_vio_wo_tell}; + SF_VIRTUAL_IO vsf = { + sf_vio_wo_get_filelen, + 
sf_vio_wo_seek, + sf_vio_wo_read, + sf_vio_wo_write, + sf_vio_wo_tell + }; SNDFILE* file; SF_INFO info; diff --git a/flashlight/pkg/speech/decoder/DecodeMaster.cpp b/flashlight/pkg/speech/decoder/DecodeMaster.cpp index 177a3fa..c2d08d7 100644 --- a/flashlight/pkg/speech/decoder/DecodeMaster.cpp +++ b/flashlight/pkg/speech/decoder/DecodeMaster.cpp @@ -146,8 +146,10 @@ std::shared_ptr DecodeMaster::forward( continue; if(usePlugin_) output = net_->forward( - {fl::input(batch[kInputIdx]), - fl::noGrad(batch[kDurationIdx])} + { + fl::input(batch[kInputIdx]), + fl::noGrad(batch[kDurationIdx]) + } ) .front() .tensor(); @@ -248,7 +250,8 @@ std::shared_ptr TokenDecodeMaster::decode( .lmWeight = opt.lmWeight, .silScore = opt.silScore, .logAdd = opt.logAdd, - .criterionType = fl::lib::text::CriterionType::CTC}; + .criterionType = fl::lib::text::CriterionType::CTC + }; auto silIdx = tokenDict_.getIndex(opt.silToken); auto blankIdx = tokenDict_.getIndex(opt.blankToken); fl::lib::text::LexiconFreeDecoder decoder( @@ -271,7 +274,8 @@ std::shared_ptr TokenDecodeMaster::decode( .unkScore = opt.unkScore, .silScore = opt.silScore, .logAdd = opt.logAdd, - .criterionType = fl::lib::text::CriterionType::CTC}; + .criterionType = fl::lib::text::CriterionType::CTC + }; auto silIdx = tokenDict_.getIndex(opt.silToken); auto blankIdx = tokenDict_.getIndex(opt.blankToken); auto unkWordIdx = wordDict_.getIndex(fl::lib::text::kUnkToken); @@ -336,7 +340,8 @@ std::shared_ptr WordDecodeMaster::decode( .unkScore = opt.unkScore, .silScore = opt.silScore, .logAdd = opt.logAdd, - .criterionType = fl::lib::text::CriterionType::CTC}; + .criterionType = fl::lib::text::CriterionType::CTC + }; auto silIdx = tokenDict_.getIndex(opt.silToken); auto blankIdx = tokenDict_.getIndex(opt.blankToken); auto unkWordIdx = wordDict_.getIndex(opt.unkToken); diff --git a/flashlight/pkg/speech/decoder/PlGenerator.cpp b/flashlight/pkg/speech/decoder/PlGenerator.cpp index f73537d..00886e6 100644 --- 
a/flashlight/pkg/speech/decoder/PlGenerator.cpp +++ b/flashlight/pkg/speech/decoder/PlGenerator.cpp @@ -205,8 +205,10 @@ std::string PlGenerator::regeneratePl( if(usePlugin) rawEmission = ntwrk ->forward( - {fl::input(sample[kInputIdx]), - fl::noGrad(sample[kDurationIdx])} + { + fl::input(sample[kInputIdx]), + fl::noGrad(sample[kDurationIdx]) + } ) .front(); else diff --git a/flashlight/pkg/speech/test/audio/CeplifterTest.cpp b/flashlight/pkg/speech/test/audio/CeplifterTest.cpp index 9fdd255..d142c76 100644 --- a/flashlight/pkg/speech/test/audio/CeplifterTest.cpp +++ b/flashlight/pkg/speech/test/audio/CeplifterTest.cpp @@ -25,7 +25,8 @@ TEST(CeplifterTest, matlabCompareTest) { 1, 2.565463, 4.099058, 5.569565, 6.947048, 8.203468, 9.313245, 10.25378, 11.00595, 11.55442, 11.88803, 12, 11.88803, 11.55442, 11.00595, 10.25378, 9.313245, 8.203468, 6.947048, 5.569565, 4.099058, - 2.565463, 1.000000, -0.5654632, -2.0990581}; + 2.565463, 1.000000, -0.5654632, -2.0990581 + }; auto output1 = cep1.apply(input1); // Implementation should match with matlab for Test case 1. 
ASSERT_TRUE(compareVec(output1, matlaboutput1)); @@ -37,7 +38,8 @@ TEST(CeplifterTest, matlabCompareTest) { 4.798719, 1.701928, 2.926338, 1.119059, 3.756335, 1.275475, 2.529785, 3.495383, 4.454516, 4.796457, 2.736077, 0.6931222, 0.7464700, 1.287541, 4.203586, 1.271410, 4.071424, 1.217624, 4.646318, 1.749918, 0.9829762, - 1.255419, 3.080223, 2.366444, 1.758297, 4.154143}; + 1.255419, 3.080223, 2.366444, 1.758297, 4.154143 + }; std::vector matlaboutput2{ 3.82758300, 10.1608714, 3.75679389, 13.0039674, 14.1460142, 22.8717424, 26.4330877, 28.1219157, 9.76798039, 21.5785018, @@ -46,7 +48,8 @@ TEST(CeplifterTest, matlabCompareTest) { -13.7939251, -17.7481762, -19.3744501, -15.8776985, -5.52879248, -0.385065298, 0.746470000, 3.29037774, 16.9013608, 6.75156506, 25.8510797, 8.61786241, 34.6271852, 13.0414523, 6.95711783, - 7.97115128, 16.3568998, 9.51476285, 4.49341909, 4.15414300}; + 7.97115128, 16.3568998, 9.51476285, 4.49341909, 4.15414300 + }; auto output2 = cep2.apply(input2); // Implementation should match with matlab for Test case 2. 
ASSERT_TRUE(compareVec(output2, matlaboutput2)); diff --git a/flashlight/pkg/speech/test/audio/DctTest.cpp b/flashlight/pkg/speech/test/audio/DctTest.cpp index cf6cdae..c407ae2 100644 --- a/flashlight/pkg/speech/test/audio/DctTest.cpp +++ b/flashlight/pkg/speech/test/audio/DctTest.cpp @@ -35,12 +35,14 @@ TEST(DctTest, matlabCompareTest) { 4.798719, 1.701928, 2.926338, 1.119059, 3.756335, 1.275475, 2.529785, 3.495383, 4.454516, 4.796457, 2.736077, 0.6931222, 0.7464700, 1.287541, 4.203586, 1.271410, 4.071424, 1.217624, 4.646318, 1.749918, 0.9829762, - 1.255419, 3.080223, 2.366444, 1.758297, 4.154143}; + 1.255419, 3.080223, 2.366444, 1.758297, 4.154143 + }; std::vector matlaboutput2{ 23.03049, 0.7171224, 0.09039740, 0.5560513, 1.210070, -0.6701894, -0.7615307, 0.1116579, 1.157483, -2.012746, 2.964205, 2.444191, -0.4926429, -0.1332636, 1.275104, 0.2767147, 0.2781188, 2.661390, - -0.03644234, -2.326455, -0.1963445, -1.229159, 2.124846}; + -0.03644234, -2.326455, -0.1963445, -1.229159, 2.124846 + }; auto output2 = dct2.apply(input2); // Implementation should match with matlab for Test case 2. ASSERT_TRUE(compareVec(output2, matlaboutput2)); diff --git a/flashlight/pkg/speech/test/audio/DerivativesTest.cpp b/flashlight/pkg/speech/test/audio/DerivativesTest.cpp index 0b25fb0..0443de9 100644 --- a/flashlight/pkg/speech/test/audio/DerivativesTest.cpp +++ b/flashlight/pkg/speech/test/audio/DerivativesTest.cpp @@ -51,7 +51,8 @@ TEST(DerivativesTest, matlabCompareTest) { 0.5000000, 0.6666667, 0.8166667, 0.9333333, 1.0000000, 1.0000000, 1.0000000, 1.0000000, 0.9333333, 0.8166667, 0.6666667, 0.5000000, 0.0683333, 0.0780556, 0.0794444, 0.0725000, 0.0527778, 0.0180556, - -0.0180556, -0.0527778, -0.0725000, -0.0794444, -0.0780556, -0.0683333}; + -0.0180556, -0.0527778, -0.0725000, -0.0794444, -0.0780556, -0.0683333 + }; auto output1 = dev1.apply(input1, 1); // Implementation should match with matlab for Test case 1. 
ASSERT_TRUE(compareVec(output1, transposeVec(matlaboutput1, 3, 12))); @@ -64,7 +65,8 @@ TEST(DerivativesTest, matlabCompareTest) { 4.798719, 1.701928, 2.926338, 1.119059, 3.756335, 1.275475, 2.529785, 3.495383, 4.454516, 4.796457, 2.736077, 0.6931222, 0.7464700, 1.287541, 4.203586, 1.271410, 4.071424, 1.217624, 4.646318, 1.749918, 0.9829762, - 1.255419, 3.080223, 2.366444, 1.758297, 4.154143}; + 1.255419, 3.080223, 2.366444, 1.758297, 4.154143 + }; std::vector matlaboutput2{ 3.827583, 3.975999, 0.9343630, 2.448821, 2.227931, 3.231565, 3.546824, 3.773433, 1.380125, 3.398513, 3.275490, 0.8130586, @@ -85,7 +87,8 @@ TEST(DerivativesTest, matlabCompareTest) { 0.0005322, -0.0002259, -0.0021479, -0.0036494, -0.0056067, -0.0061552, -0.0065865, -0.0057748, -0.0041803, -0.0013468, 0.0018477, 0.0064985, 0.0102782, 0.0132019, 0.0138463, 0.0156723, 0.0171224, 0.0170120, - 0.0159708, 0.0139536, 0.0118158, 0.0081756, 0.0046038, 0.0015992}; + 0.0159708, 0.0139536, 0.0118158, 0.0081756, 0.0046038, 0.0015992 + }; auto output2 = dev2.apply(input2, 1); // Implementation should match with matlab for Test case 2. 
ASSERT_TRUE(compareVec(output2, transposeVec(matlaboutput2, 3, 40))); diff --git a/flashlight/pkg/speech/test/audio/PreEmphasisTest.cpp b/flashlight/pkg/speech/test/audio/PreEmphasisTest.cpp index 131602f..4ab468c 100644 --- a/flashlight/pkg/speech/test/audio/PreEmphasisTest.cpp +++ b/flashlight/pkg/speech/test/audio/PreEmphasisTest.cpp @@ -17,37 +17,43 @@ using fl::lib::audio::PreEmphasis; TEST(PreEmphasisTest, matlabCompareTest) { int N = 8; PreEmphasis preemphasis1d(0.95, N); - std::vector input{0.098589, - 0.715877, - 0.750572, - 0.787636, - 0.116829, - 0.242914, - 0.327526, - 0.410389}; + std::vector input{ + 0.098589, + 0.715877, + 0.750572, + 0.787636, + 0.116829, + 0.242914, + 0.327526, + 0.410389 + }; // ndim = 1 - std::vector matlaboutput1d{0.004929, - 0.622218, - 0.070489, - 0.074592, - -0.631425, - 0.131927, - 0.096757, - 0.099240}; + std::vector matlaboutput1d{ + 0.004929, + 0.622218, + 0.070489, + 0.074592, + -0.631425, + 0.131927, + 0.096757, + 0.099240 + }; auto output1d = preemphasis1d.apply(input); // Implementation should match with matlab. ASSERT_TRUE(compareVec(output1d, matlaboutput1d)); // ndim = 2 PreEmphasis preemphasis2d(0.95, N / 2); - std::vector matlaboutput2d{0.004929, - 0.622218, - 0.070489, - 0.074592, - 0.005841, - 0.131927, - 0.096757, - 0.099240}; + std::vector matlaboutput2d{ + 0.004929, + 0.622218, + 0.070489, + 0.074592, + 0.005841, + 0.131927, + 0.096757, + 0.099240 + }; auto output2d = preemphasis2d.apply(input); // Implementation should match with matlab. 
ASSERT_TRUE(compareVec(output2d, matlaboutput2d)); diff --git a/flashlight/pkg/speech/test/audio/TriFilterbankTest.cpp b/flashlight/pkg/speech/test/audio/TriFilterbankTest.cpp index 5bcf84e..556af2c 100644 --- a/flashlight/pkg/speech/test/audio/TriFilterbankTest.cpp +++ b/flashlight/pkg/speech/test/audio/TriFilterbankTest.cpp @@ -30,7 +30,8 @@ TEST(TriFilterbankTest, matlabCompareTest) { 0, 0, 0, 0.763933, 0.236067, 0, 0, 0, 0, 0, 0, 0, 0, 0.082177, 0.917823, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.532067, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0}; + 0, 0 + }; auto outputfbank1 = triflt1.filterbank(); // Implementation should match with matlab for Test case 1. ASSERT_TRUE(compareVec(outputfbank1, matlabfbank1)); @@ -42,12 +43,14 @@ TEST(TriFilterbankTest, matlabCompareTest) { 0.438744, 0.381558, 0.765516, 0.795199, 0.186872, 0.489764, 0.445586, 0.646313, 0.709364, 0.754686, 0.276025, 0.679702, 0.655098, 0.162611, 0.118997, 0.498364, 0.959743, 0.340385, 0.585267, 0.223811, 0.751267, - 0.255095, 0.505957, 0.699076, 0.890903, 0.959291}; + 0.255095, 0.505957, 0.699076, 0.890903, 0.959291 + }; std::vector matlabop2{ 0.578693, 0.131362, 0.301871, 0.426760, 0.523461, 0.0338169, 0.285265, 0.311304, 0.424245, 0.714087, 0.680402, 0.267582, 0.526783, 0.612373, 0.814208, 0.962699, 0.620225, 0.907083, - 0.326320, 0.879130, 1.07004, 0.844134, 0.957356}; + 0.326320, 0.879130, 1.07004, 0.844134, 0.957356 + }; auto output2 = triflt2.apply(input2); // Implementation should match with matlab for Test case 2. 
diff --git a/flashlight/pkg/speech/test/audio/WindowingTest.cpp b/flashlight/pkg/speech/test/audio/WindowingTest.cpp index 2750c7b..f473087 100644 --- a/flashlight/pkg/speech/test/audio/WindowingTest.cpp +++ b/flashlight/pkg/speech/test/audio/WindowingTest.cpp @@ -41,7 +41,8 @@ TEST(WindowingTest, hanningCoeffsTest) { 0.00000, 0.01024, 0.04052, 0.08962, 0.15552, 0.23552, 0.32635, 0.42429, 0.52532, 0.62533, 0.72020, 0.80605, 0.87938, 0.93717, 0.97707, 0.99743, 0.99743, 0.97707, 0.93717, 0.87938, 0.80605, 0.72020, 0.62533, 0.52532, - 0.42429, 0.32635, 0.23552, 0.15552, 0.08962, 0.04052, 0.01024, 0.00000}; + 0.42429, 0.32635, 0.23552, 0.15552, 0.08962, 0.04052, 0.01024, 0.00000 + }; std::vector input(N, 1.0); auto output = hannwindow.apply(input); // Hamming window coefficients should match with matlab implementation. diff --git a/flashlight/pkg/speech/test/augmentation/TimeStretchTest.cpp b/flashlight/pkg/speech/test/augmentation/TimeStretchTest.cpp index 7a758e2..a2f23dd 100644 --- a/flashlight/pkg/speech/test/augmentation/TimeStretchTest.cpp +++ b/flashlight/pkg/speech/test/augmentation/TimeStretchTest.cpp @@ -37,7 +37,8 @@ TEST(TimeStretch, SinWave) { std::vector augmented = signal; TimeStretch::Config conf = { - .proba_ = 1.0, .minFactor_ = factor, .maxFactor_ = factor}; + .proba_ = 1.0, .minFactor_ = factor, .maxFactor_ = factor + }; TimeStretch sfx(conf); sfx.apply(augmented); diff --git a/flashlight/pkg/speech/test/criterion/CriterionTest.cpp b/flashlight/pkg/speech/test/criterion/CriterionTest.cpp index 4b09a28..5d488d6 100644 --- a/flashlight/pkg/speech/test/criterion/CriterionTest.cpp +++ b/flashlight/pkg/speech/test/criterion/CriterionTest.cpp @@ -201,7 +201,8 @@ TEST(CriterionTest, CTCCompareTensorflow) { 0.111121, -0.411608, 0.278779, 0.0055756, 0.00569609, 0.010436, 0.0357786, 0.633813, -0.678582, 0.00249248, 0.00272882, 0.0037688, 0.0663296, -0.356151, 0.280111, 0.00283995, 0.0035545, 0.00331533, - -0.541765, 0.396634, 0.123377, 0.00648837, 0.00903441, 
0.00623107}; + -0.541765, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107 + }; auto ctc1 = ConnectionistTemporalClassificationCriterion(); auto input1af = Variable(Tensor::fromArray({N1, T1, 1}, input1), true); @@ -331,7 +332,8 @@ TEST(CriterionTest, ASGAlternatingBlanks) { 0x1.9bc5aep-1, 0x1.3c7dacp-1, 0x1.3e2852p-1, 0x1.6699f4p-1, 0x1.095a5p+0, 0x1.1840bcp-1, 0x1.465a4ep-1, 0x1.2c4cacp-1, 0x1.754998p-1, 0x1.cb6698p-2, -0x1.1cadcp+0, 0x1.757b88p-2, - 0x1.3dec32p+0, 0x1.320fp+0, -0x1.9eb1a4p-1, -0x1.e43beap-2}; + 0x1.3dec32p+0, 0x1.320fp+0, -0x1.9eb1a4p-1, -0x1.e43beap-2 + }; Tensor x = Tensor::fromVector(Shape({C, T, B}), xV); Tensor y = fl::full({mL* 2 + 1, B}, -1, fl::dtype::s32); int L; @@ -485,7 +487,8 @@ TEST(CriterionTest, CTCViterbiPathConstrainedBeginAndEndWithSpace) { TEST(CriterionTest, FCCCost) { // Test case: 1 std::array input1 = { - 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0}; + 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0 + }; std::transform( input1.begin(), input1.end(), @@ -560,7 +563,8 @@ TEST(CriterionTest, FCCJacobian) { TEST(CriterionTest, FACCost) { // Test case: 1 std::array input1 = { - 1.0, 0.0, 0.0, 1.0, 0.5, 0.5, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0}; + 1.0, 0.0, 0.0, 1.0, 0.5, 0.5, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0 + }; std::array target1 = {0, 1, 0, 1}; const int N1 = 2, L1 = 2, T1 = 3, B1 = 2; @@ -612,7 +616,8 @@ TEST(CriterionTest, ASGCost) { // Test case: 1 constexpr int N1 = 2, L1 = 2, T1 = 3, B1 = 2; std::array input1 = { - 1.0, 0.0, 0.0, 1.0, 0.5, 0.5, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0}; + 1.0, 0.0, 0.0, 1.0, 0.5, 0.5, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0 + }; std::transform( input1.begin(), input1.end(), diff --git a/flashlight/pkg/speech/test/criterion/Seq2SeqTest.cpp b/flashlight/pkg/speech/test/criterion/Seq2SeqTest.cpp index a2c3dbc..a54c7b8 100644 --- a/flashlight/pkg/speech/test/criterion/Seq2SeqTest.cpp +++ b/flashlight/pkg/speech/test/criterion/Seq2SeqTest.cpp @@ -263,8 +263,10 @@ TEST(Seq2SeqTest, 
Seq2SeqMixedAttn) { N - 2, N - 1, maxoutputlen, - {std::make_shared(), - std::make_shared(H, nHead)}, + { + std::make_shared(), + std::make_shared(H, nHead) + }, std::make_shared(1, 20, 2.2, 5.8), false, 100, @@ -387,7 +389,8 @@ TEST(Seq2SeqTest, BatchedDecoderStep) { nRnnLayer, nAttnRound, 0.0 - )}; + ) + }; for(auto& seq2seq : criterions) { seq2seq.eval(); diff --git a/flashlight/pkg/speech/test/data/FeaturizationTest.cpp b/flashlight/pkg/speech/test/data/FeaturizationTest.cpp index 3f82c99..29b0bb1 100644 --- a/flashlight/pkg/speech/test/data/FeaturizationTest.cpp +++ b/flashlight/pkg/speech/test/data/FeaturizationTest.cpp @@ -142,7 +142,10 @@ TEST(FeaturizationTest, localNormalize) { auto arrVec = arr.toHostVector(); std::vector> ctx = { - {0, 0}, {1, 1}, {2, 2}, {4, 4}, {1024, 1024}, {10, 0}, {2, 12}}; + {0, 0}, {1, 1}, {2, 2}, {4, 4}, {1024, 1024}, {10, 0}, { + 2, 12 + } + }; for(auto c : ctx) { auto arrVecNrm = localNormalize( @@ -188,7 +191,8 @@ TEST(FeaturizationTest, TargetTknTestStandaloneSep) { ); std::vector resT = { - "ab", "cd", "ef", "||", "ab", "cd", "||", "t", "r", "||"}; + "ab", "cd", "ef", "||", "ab", "cd", "||", "t", "r", "||" + }; ASSERT_EQ(res.size(), resT.size()); for(int index = 0; index < res.size(); index++) ASSERT_EQ(res[index], resT[index]); @@ -205,7 +209,8 @@ TEST(FeaturizationTest, TargetTknTestStandaloneSep) { ); std::vector resT2 = { - "ab", "cd", "ef", "||", "ab", "cd", "||", "||", "t", "r"}; + "ab", "cd", "ef", "||", "ab", "cd", "||", "||", "t", "r" + }; ASSERT_EQ(res2.size(), resT2.size()); for(int index = 0; index < res2.size(); index++) ASSERT_EQ(res2[index], resT2[index]); @@ -240,7 +245,8 @@ TEST(FeaturizationTest, TargetTknTestInsideSep) { ); std::vector resT = { - "_", "a", "f", "f", "_hel", "lo", "_ma", "ma", "_", "a", "f"}; + "_", "a", "f", "f", "_hel", "lo", "_ma", "ma", "_", "a", "f" + }; ASSERT_EQ(res.size(), resT.size()); for(int index = 0; index < res.size(); index++) ASSERT_EQ(res[index], resT[index]); @@ -257,7 
+263,8 @@ TEST(FeaturizationTest, TargetTknTestInsideSep) { ); std::vector resT2 = { - "a", "f", "f", "_", "_hel", "lo", "_ma", "ma", "_", "a", "f"}; + "a", "f", "f", "_", "_hel", "lo", "_ma", "ma", "_", "a", "f" + }; ASSERT_EQ(res.size(), resT2.size()); for(int index = 0; index < res2.size(); index++) ASSERT_EQ(res2[index], resT2[index]); @@ -403,7 +410,10 @@ TEST(FeaturizationTest, targetFeaturizer) { tokenDict.addEntry(kEosToken); auto lexicon = getLexicon(); std::vector> targets = { - {'a', 'b', 'c', 'c', 'c'}, {'b', 'c', 'd', 'd'}}; + {'a', 'b', 'c', 'c', 'c'}, { + 'b', 'c', 'd', 'd' + } + }; TargetGenerationConfig targetGenConfig( "", diff --git a/flashlight/pkg/text/data/TextDataset.cpp b/flashlight/pkg/text/data/TextDataset.cpp index 04e44ed..ee2e604 100644 --- a/flashlight/pkg/text/data/TextDataset.cpp +++ b/flashlight/pkg/text/data/TextDataset.cpp @@ -157,10 +157,12 @@ std::vector TextDataset::get(const int64_t idx) const { sizeof(int) * (pos.last - pos.first + 1) ); } - return {Tensor::fromVector( - {maxLength, static_cast(batch.size())}, - buffer - )}; + return { + Tensor::fromVector( + {maxLength, static_cast(batch.size())}, + buffer + ) + }; } void TextDataset::shuffle(uint64_t seed) { diff --git a/flashlight/pkg/vision/criterion/SetCriterion.cpp b/flashlight/pkg/vision/criterion/SetCriterion.cpp index 1451680..abaaa7d 100644 --- a/flashlight/pkg/vision/criterion/SetCriterion.cpp +++ b/flashlight/pkg/vision/criterion/SetCriterion.cpp @@ -222,7 +222,10 @@ SetCriterion::LossDict SetCriterion::lossBoxes( if(srcIdx.first.isEmpty()) return { {"lossGiou", fl::Variable(fl::fromScalar(0, predBoxes.type()), false)}, - {"lossBbox", fl::Variable(fl::fromScalar(0, predBoxes.type()), false)}}; + { + "lossBbox", fl::Variable(fl::fromScalar(0, predBoxes.type()), false) + } + }; auto colIdxs = fl::reshape(srcIdx.second, {1, srcIdx.second.dim(0)}); auto batchIdxs = fl::reshape(srcIdx.first, {1, srcIdx.first.dim(0)}); diff --git a/flashlight/pkg/vision/dataset/Coco.cpp 
b/flashlight/pkg/vision/dataset/Coco.cpp index c2991ce..4feb563 100644 --- a/flashlight/pkg/vision/dataset/Coco.cpp +++ b/flashlight/pkg/vision/dataset/Coco.cpp @@ -70,7 +70,8 @@ CocoData cocoBatchFunc(const std::vector>& batches) { makeBatch(batches[ImageIdIdx]), makeBatch(batches[OriginalSizeIdx]), batches[BboxesIdx], - batches[ClassesIdx]}; + batches[ClassesIdx] + }; } int64_t getImageId(const std::string& fp) { @@ -154,7 +155,8 @@ CocoDataset::CocoDataset( } // image, size, imageId, original_size return std::vector{ - image, targetSize, imageId, targetSize, bboxes, classes}; + image, targetSize, imageId, targetSize, bboxes, classes + }; } ); @@ -164,17 +166,24 @@ CocoDataset::CocoDataset( std::make_shared(ds, randomResize({800}, maxSize)); else { std::vector scales = { - 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800}; + 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 + }; TransformAllFunction trainTransform = compose( - {randomHorizontalFlip(0.5), - randomSelect( - {randomResize(scales, maxSize), - compose( - {randomResize({400, 500, 600}, -1), - randomSizeCrop(384, 600), - randomResize(scales, 1333)} - )} - )} + { + randomHorizontalFlip(0.5), + randomSelect( + { + randomResize(scales, maxSize), + compose( + { + randomResize({400, 500, 600}, -1), + randomSizeCrop(384, 600), + randomResize(scales, 1333) + } + ) + } + ) + } ); ds = std::make_shared(ds, trainTransform); diff --git a/flashlight/pkg/vision/dataset/CocoTransforms.cpp b/flashlight/pkg/vision/dataset/CocoTransforms.cpp index d6f7ee3..50095ed 100644 --- a/flashlight/pkg/vision/dataset/CocoTransforms.cpp +++ b/flashlight/pkg/vision/dataset/CocoTransforms.cpp @@ -58,7 +58,8 @@ std::vector crop(const std::vector& in, int x, int y, int tw, in in[ImageIdIdx], in[OriginalSizeIdx], croppedBoxes, - labels}; + labels + }; }; std::vector hflip(const std::vector& in) { @@ -81,7 +82,8 @@ std::vector hflip(const std::vector& in) { in[ImageIdIdx], in[OriginalSizeIdx], bboxes, - in[ClassesIdx]}; + 
in[ClassesIdx] + }; } std::vector normalize(const std::vector& in) { @@ -103,7 +105,8 @@ std::vector normalize(const std::vector& in) { in[ImageIdIdx], in[OriginalSizeIdx], boxes, - in[ClassesIdx]}; + in[ClassesIdx] + }; } std::vector randomResize(std::vector inputs, int size, int maxsize) { @@ -148,7 +151,8 @@ std::vector randomResize(std::vector inputs, int size, int maxsi const float ratioHeight = float(resizedDims[1]) / float(originalDims[1]); const std::vector resizeVector = { - ratioWidth, ratioHeight, ratioWidth, ratioHeight}; + ratioWidth, ratioHeight, ratioWidth, ratioHeight + }; Tensor resizedArray = Tensor::fromVector(resizeVector); boxes = boxes * resizedArray; } @@ -161,7 +165,8 @@ std::vector randomResize(std::vector inputs, int size, int maxsi inputs[ImageIdIdx], inputs[OriginalSizeIdx], boxes, - inputs[ClassesIdx]}; + inputs[ClassesIdx] + }; } TransformAllFunction Normalize( @@ -194,7 +199,8 @@ TransformAllFunction Normalize( in[ImageIdIdx], in[OriginalSizeIdx], boxes, - in[ClassesIdx]}; + in[ClassesIdx] + }; return outputs; }; } diff --git a/flashlight/pkg/vision/test/TransformerTest.cpp b/flashlight/pkg/vision/test/TransformerTest.cpp index b02220f..ee37ade 100644 --- a/flashlight/pkg/vision/test/TransformerTest.cpp +++ b/flashlight/pkg/vision/test/TransformerTest.cpp @@ -257,7 +257,8 @@ TEST(Tranformer, Masked) { Variable(fl::rand({maskW, maskH, C, B}), false), // input Projection Variable(fl::full({maskW, maskH, 1, B}, 1), false), // mask Variable(fl::rand({C, bbox_queries}), false), // query_embed - nonMaskPos}; + nonMaskPos + }; auto nonMaskOutput = tr(nonMaskInput)[0]; auto nonMaskedSrc = fl::rand({W, H, C, B}); @@ -270,6 +271,7 @@ TEST(Tranformer, Masked) { Variable(nonMaskedSrc, false), // input Projection Variable(mask, false), // mask nonMaskInput[2], // query_embed - maskPos}; + maskPos + }; auto maskOutput = tr(maskInput)[0]; } diff --git a/flashlight/pkg/vision/test/TransformsTest.cpp b/flashlight/pkg/vision/test/TransformsTest.cpp 
index 2a5ee6a..3bb6a7c 100644 --- a/flashlight/pkg/vision/test/TransformsTest.cpp +++ b/flashlight/pkg/vision/test/TransformsTest.cpp @@ -30,7 +30,8 @@ TEST(Crop, CropBasic) { Tensor(), Tensor(), Tensor::fromVector({4, 2}, bboxesVector), - fl::full({1, 2}, 0.)}; + fl::full({1, 2}, 0.) + }; // Crop from x, y (10, 10), with target heigh and width to be ten std::vector out = fl::pkg::vision::crop(in, 10, 5, 20, 25); @@ -74,7 +75,8 @@ TEST(Crop, CropClip) { Tensor(), Tensor(), Tensor::fromVector({numElementsPerBoxes, numBoxes}, bboxesVector), - fl::iota({1, 3})}; + fl::iota({1, 3}) + }; // Crop from x, y (10, 10), with target heigh and width to be ten std::vector out = fl::pkg::vision::crop(in, 5, 5, 100, 100); diff --git a/flashlight/pkg/vision/test/criterion/HungarianTest.cpp b/flashlight/pkg/vision/test/criterion/HungarianTest.cpp index e681ca8..f2e38be 100644 --- a/flashlight/pkg/vision/test/criterion/HungarianTest.cpp +++ b/flashlight/pkg/vision/test/criterion/HungarianTest.cpp @@ -76,7 +76,8 @@ TEST(HungarianTest, FullPipelineSimple2) { int M = 3; // Rows int N = 3; // Columns std::vector costsVec = { - 2500, 4000, 2000, 4000, 6000, 4000, 3500, 3500, 2500}; + 2500, 4000, 2000, 4000, 6000, 4000, 3500, 3500, 2500 + }; std::vector expAssignment = {0, 0, 1, 1, 0, 0, 0, 1, 0}; std::vector assignment(N * M); @@ -104,13 +105,17 @@ TEST(HungarianTest, FullPipelineSimple3) { TEST(HungarianTest, FullPipelineSize6) { int M = 6; // Rows int N = 6; // Columns - std::vector costsVec = {7, 9, 3, 7, 8, 4, 2, 6, 8, 9, 4, 2, - 1, 9, 3, 4, 7, 9, 9, 5, 1, 2, 4, 3, - 4, 5, 8, 2, 8, 1, 4, 2, 9, 3, 2, 9}; + std::vector costsVec = { + 7, 9, 3, 7, 8, 4, 2, 6, 8, 9, 4, 2, + 1, 9, 3, 4, 7, 9, 9, 5, 1, 2, 4, 3, + 4, 5, 8, 2, 8, 1, 4, 2, 9, 3, 2, 9 + }; - std::vector expAssignment = {0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, - 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, - 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0}; + std::vector expAssignment = { + 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 
1, 0, 0, + 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0 + }; std::vector assignment(N * M); hungarian(costsVec.data(), assignment.data(), N, M); for(int c = 0; c < N; c++) @@ -121,13 +126,17 @@ TEST(HungarianTest, FullPipelineSize6) { TEST(HungarianTest, 6x6Example2) { int M = 6; // Rows int N = 6; // Columns - std::vector costsVec = {7, 9, 3, 7, 8, 4, 2, 6, 8, 9, 4, 2, - 1, 9, 3, 4, 7, 9, 1, 3, 4, 8, 2, 7, - 4, 5, 8, 2, 8, 1, 4, 2, 9, 3, 2, 9}; + std::vector costsVec = { + 7, 9, 3, 7, 8, 4, 2, 6, 8, 9, 4, 2, + 1, 9, 3, 4, 7, 9, 1, 3, 4, 8, 2, 7, + 4, 5, 8, 2, 8, 1, 4, 2, 9, 3, 2, 9 + }; - std::vector expAssignment = {0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, - 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0}; + std::vector expAssignment = { + 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, + 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0 + }; std::vector assignment(N * M); hungarian(costsVec.data(), assignment.data(), N, M); for(int c = 0; c < N; c++) diff --git a/flashlight/pkg/vision/test/criterion/SetCriterionTest.cpp b/flashlight/pkg/vision/test/criterion/SetCriterionTest.cpp index d0a84e1..17b3c95 100644 --- a/flashlight/pkg/vision/test/criterion/SetCriterionTest.cpp +++ b/flashlight/pkg/vision/test/criterion/SetCriterionTest.cpp @@ -18,7 +18,10 @@ using namespace fl::pkg::vision; std::unordered_map getLossWeights() { const std::unordered_map lossWeightsBase = { - {"lossCe", 1.f}, {"lossGiou", 1.f}, {"lossBbox", 1.f}}; + {"lossCe", 1.f}, {"lossGiou", 1.f}, { + "lossBbox", 1.f + } + }; std::unordered_map lossWeights; for(int i = 0; i < 6; i++) @@ -46,15 +49,19 @@ TEST(SetCriterion, PytorchRepro) { auto predLogits = fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); - std::vector targetBoxes = {fl::Variable( - Tensor::fromVector({4, numTargets, numBatches}, targetBoxesVec), - false - )}; + std::vector targetBoxes = { + fl::Variable( + Tensor::fromVector({4, numTargets, numBatches}, targetBoxesVec), + false + ) + }; - 
std::vector targetClasses = {fl::Variable( - Tensor::fromVector({numTargets, numBatches}, targetClassVec), - false - )}; + std::vector targetClasses = { + fl::Variable( + Tensor::fromVector({numTargets, numBatches}, targetClassVec), + false + ) + }; auto matcher = HungarianMatcher(1, 1, 1); auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); @@ -79,15 +86,19 @@ TEST(SetCriterion, PytorchReproMultiplePreds) { auto predLogits = fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); - std::vector targetBoxes = {fl::Variable( - Tensor::fromVector({4, numTargets, numBatches}, targetBoxesVec), - false - )}; + std::vector targetBoxes = { + fl::Variable( + Tensor::fromVector({4, numTargets, numBatches}, targetBoxesVec), + false + ) + }; - std::vector targetClasses = {fl::Variable( - Tensor::fromVector({1, numTargets, numBatches}, targetClassVec), - false - )}; + std::vector targetClasses = { + fl::Variable( + Tensor::fromVector({1, numTargets, numBatches}, targetClassVec), + false + ) + }; auto matcher = HungarianMatcher(1, 1, 1); auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); @@ -120,15 +131,19 @@ TEST(SetCriterion, PytorchReproMultipleTargets) { auto predLogits = fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); - std::vector targetBoxes = {fl::Variable( - Tensor::fromVector({4, numTargets, numBatches}, targetBoxesVec), - false - )}; + std::vector targetBoxes = { + fl::Variable( + Tensor::fromVector({4, numTargets, numBatches}, targetBoxesVec), + false + ) + }; - std::vector targetClasses = {fl::Variable( - Tensor::fromVector({numTargets, numBatches}, targetClassVec), - false - )}; + std::vector targetClasses = { + fl::Variable( + Tensor::fromVector({numTargets, numBatches}, targetClassVec), + false + ) + }; auto matcher = HungarianMatcher(1, 1, 1); auto 
crit = SetCriterion(80, matcher, getLossWeights(), 0.0); auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); @@ -143,7 +158,8 @@ TEST(SetCriterion, PytorchReproNoPerfectMatch) { std::vector predBoxesVec = {2, 2, 3, 3, 1, 1, 2, 2}; std::vector targetBoxesVec = { - 0.9, 0.8, 1.9, 1.95, 1.9, 1.95, 2.9, 2.95}; + 0.9, 0.8, 1.9, 1.95, 1.9, 1.95, 2.9, 2.95 + }; // std::vector predLogitsVec((numClasses + 1) * numPreds * numPreds, // 0.0); @@ -157,15 +173,19 @@ TEST(SetCriterion, PytorchReproNoPerfectMatch) { auto predLogits = fl::Variable(fl::full({numClasses + 1, numPreds, numBatches}, 1), true); - std::vector targetBoxes = {fl::Variable( - Tensor::fromVector({4, numTargets, numBatches}, targetBoxesVec), - false - )}; + std::vector targetBoxes = { + fl::Variable( + Tensor::fromVector({4, numTargets, numBatches}, targetBoxesVec), + false + ) + }; - std::vector targetClasses = {fl::Variable( - Tensor::fromVector({numTargets, numBatches}, targetClassVec), - false - )}; + std::vector targetClasses = { + fl::Variable( + Tensor::fromVector({numTargets, numBatches}, targetClassVec), + false + ) + }; auto matcher = HungarianMatcher(1, 1, 1); auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); @@ -307,7 +327,8 @@ TEST(SetCriterion, PytorchReproBatching) { fl::Variable( Tensor::fromVector({4, numTargets, numPreds}, targetBoxesVec2), false - )}; + ) + }; std::vector targetClasses = { fl::Variable( @@ -317,7 +338,8 @@ TEST(SetCriterion, PytorchReproBatching) { fl::Variable( Tensor::fromVector({numTargets, numPreds, 1}, targetClassVec), false - )}; + ) + }; auto matcher = HungarianMatcher(1, 1, 1); auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); @@ -330,7 +352,8 @@ TEST(SetCriterion, DifferentNumberOfLabels) { const int numPreds = 2; const int numBatches = 2; std::vector predBoxesVec = 
{ - 2, 2, 3, 3, 1, 1, 2, 2, 2, 2, 3, 3, 1, 1, 2, 2}; + 2, 2, 3, 3, 1, 1, 2, 2, 2, 2, 3, 3, 1, 1, 2, 2 + }; std::vector targetBoxesVec1 = { 1, @@ -364,11 +387,13 @@ TEST(SetCriterion, DifferentNumberOfLabels) { std::vector targetBoxes = { fl::Variable(Tensor::fromVector({4, 2, 1}, targetBoxesVec1), false), - fl::Variable(Tensor::fromVector({4, 1, 1}, targetBoxesVec2), false)}; + fl::Variable(Tensor::fromVector({4, 1, 1}, targetBoxesVec2), false) + }; std::vector targetClasses = { fl::Variable(fl::full({2, 1, 1}, 1), false), - fl::Variable(fl::full({1, 1, 1}, 1), false)}; + fl::Variable(fl::full({1, 1, 1}, 1), false) + }; auto matcher = HungarianMatcher(1, 1, 1); auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); @@ -380,8 +405,10 @@ TEST(SetCriterion, DifferentNumberOfLabelsClass) { const int numClasses = 80; const int numPreds = 3; const int numBatches = 2; - std::vector predBoxesVec = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + std::vector predBoxesVec = { + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + }; std::vector targetBoxesVec1 = {1, 1, 1, 1, 1, 1, 1, 1}; @@ -404,11 +431,13 @@ TEST(SetCriterion, DifferentNumberOfLabelsClass) { std::vector targetBoxes = { fl::Variable(Tensor::fromVector({4, 2, 1}, targetBoxesVec1), false), - fl::Variable(Tensor::fromVector({4, 1, 1}, targetBoxesVec2), false)}; + fl::Variable(Tensor::fromVector({4, 1, 1}, targetBoxesVec2), false) + }; std::vector targetClasses = { fl::Variable(fl::iota({2}), false), - fl::Variable(fl::full({1, 1, 1}, 9), false)}; + fl::Variable(fl::full({1, 1, 1}, 9), false) + }; auto matcher = HungarianMatcher(1, 1, 1); auto crit = SetCriterion(80, matcher, getLossWeights(), 0.0); auto loss = crit.forward(predBoxes, predLogits, targetBoxes, targetClasses); diff --git a/uncrustify.cfg b/uncrustify.cfg index 3ba25cb..022145b 100644 --- a/uncrustify.cfg +++ 
b/uncrustify.cfg @@ -30,6 +30,7 @@ use_indent_func_call_param = true donot_indent_func_def_close_paren = true align_func_params = false indent_paren_close = 2 +indent_paren_open_brace = true indent_align_paren = false indent_paren_after_func_def = false indent_paren_after_func_decl = false @@ -234,13 +235,15 @@ nl_template_func = force nl_template_func_decl = force nl_template_func_def = force nl_template_var = remove +nl_type_brace_init_lst_close = force +nl_type_brace_init_lst_open = force nl_func_decl_start_multi_line = true nl_func_def_start_multi_line = true nl_func_decl_args_multi_line = true nl_func_def_args_multi_line = true nl_func_decl_end_multi_line = true nl_func_def_end_multi_line = true -nl_func_call_start_multi_line = true +nl_func_call_start_multi_line = false nl_func_call_args_multi_line = true nl_func_call_end_multi_line = true pos_arith = lead