diff --git a/BUILD.bazel b/BUILD.bazel
index 885bb665..50db088e 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -387,6 +387,7 @@ cc_library(
         "ops/sum-inl.h",
         "ops/fp_arith-inl.h",
         "ops/ops-inl.h",
+        "ops/fast_ops-inl.h",
     ],
     deps = [
         ":allocator",
diff --git a/ops/fast_ops-inl.h b/ops/fast_ops-inl.h
new file mode 100644
index 00000000..1120d9a1
--- /dev/null
+++ b/ops/fast_ops-inl.h
@@ -0,0 +1,87 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Include guard for non-SIMD code.
+#ifndef THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_INL_H_
+#define THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_INL_H_
+
+#include <stddef.h>
+
+#include "ops/ops.h"
+#include "util/threading_context.h"
+#include "util/zones.h"
+#include "hwy/base.h"
+
+#endif  // THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_INL_H_
+
+// Include guard for (potentially) SIMD code.
+#if defined(THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_TOGGLE) == \
+    defined(HWY_TARGET_TOGGLE)
+#ifdef THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_TOGGLE
+#undef THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_TOGGLE
+#else
+#define THIRD_PARTY_GEMMA_CPP_OPS_FAST_OPS_TOGGLE
+#endif
+
+#include "compression/compress-inl.h"
+#include "hwy/contrib/math/fast_math-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace gcpp {
+namespace HWY_NAMESPACE {
+namespace hn = hwy::HWY_NAMESPACE;
+
+// We use the tanh approximation for gelu (also used in training).
+// gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
+//         = 0.5 * x * (1 + tanh(x * (sqrt(2/π) + sqrt(2/π) * 0.044715 * x^2)))
+//         = 0.5 * x * (1 + tanh(x * (0.79788 + 0.035677 * x^2)))
+//         = x * (0.5 + 0.5 * tanh(x * (0.79788 + 0.035677 * x^2)))
+//
+// This uses hn::FastTanh from
+// third_party/highway/hwy/contrib/math/fast_math-inl.h
+//
+// `d` is the f32 vector tag, `v` the input vector. Returns the approximate
+// gelu of `v`, evaluated with the last form of the formula above.
+template <class D, HWY_IF_F32_D(D)>
+HWY_INLINE hn::Vec<D> FastGelu(D d, hn::Vec<D> v) {
+  const hn::Vec<D> kMul = hn::Set(d, 0.03567740813636141f);
+  const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
+  const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
+
+  const hn::Vec<D> v2 = hn::Mul(v, v);
+  const hn::Vec<D> arg = hn::Mul(v, hn::MulAdd(kMul, v2, kSqrt2OverPi));
+  const hn::Vec<D> cdf = hn::MulAdd(kHalf, hn::FastTanh(d, arg), kHalf);
+  return hn::Mul(v, cdf);
+}
+
+// Array overload: passes `x` and `size` to DecompressAndCompressInplace along
+// with an f32 tag and the vector FastGelu above as the per-vector callback.
+// Activation already has a profiler zone.
+template <typename T>
+static HWY_NOINLINE HWY_MAYBE_UNUSED void FastGelu(T* HWY_RESTRICT x,
+                                                   size_t size) {
+  namespace hn = hwy::HWY_NAMESPACE;
+  using DF = hn::ScalableTag<float>;
+  using VF = hn::Vec<DF>;
+  DecompressAndCompressInplace(
+      DF(), x, size, [](DF d, VF v) HWY_ATTR -> VF { return FastGelu(d, v); });
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace gcpp
+HWY_AFTER_NAMESPACE();
+
+#endif  // NOLINT
diff --git a/ops/ops_test.cc b/ops/ops_test.cc
index 0f83df15..3abb5b8e 100644
--- a/ops/ops_test.cc
+++ b/ops/ops_test.cc
@@ -48,6 +48,7 @@
 // After highway.h
 #include "compression/test_util-inl.h"
 #include "ops/ops-inl.h"
+#include "ops/fast_ops-inl.h"
 #include "hwy/tests/test_util-inl.h"
 
 HWY_BEFORE_NAMESPACE();
@@ -466,6 +467,34 @@ static HWY_NOINLINE void TestAllGelu() {
   ForeachActivationType1<TestGelu>(hn::ScalableTag<float>());
 }
 
+// Compares the array FastGelu against a scalar tanh-based reference for inputs
+// in [-15, 15] (step 0.1), with a looser tolerance for bf16 than for float.
+struct TestFastGelu {
+  template <typename T, class D>
+  void operator()(T, D) const {
+    std::vector<T> values;
+    for (int i = -150; i <= 150; ++i) {
+      values.push_back(hwy::ConvertScalarTo<T>(.1f * i));
+    }
+    std::vector<T> result = values;
+    gcpp::HWY_NAMESPACE::FastGelu(result.data(), result.size());
+
+    for (size_t i = 0; i < values.size(); i++) {
+      const float max_error = IsBF16<T>() ? 0.02f : 0.002f;
+      const float x = hwy::ConvertScalarTo<float>(values[i]);
+      const float actual = hwy::ConvertScalarTo<float>(result[i]);
+      const float expected =
+          x * (0.5f + 0.5f * tanh(x * (0.79788f + 0.035677f * x * x)));
+      EXPECT_NEAR(expected, actual, max_error)
+          << (IsBF16<T>() ? "bf16" : "float");
+    }
+  }
+};
+
+static HWY_NOINLINE void TestAllFastGelu() {
+  ForeachActivationType1<TestFastGelu>(hn::ScalableTag<float>());
+}
+
 static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
     const float mul, float* HWY_RESTRICT x, const size_t dim_qkv,
     const float* HWY_RESTRICT inv_timescale, const int pos) {
@@ -818,6 +847,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmaxState);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSigmoid);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllGelu);
+HWY_EXPORT_AND_TEST_P(OpsTest, TestAllFastGelu);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestRopeAndMulBy);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNorm);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllRMSNormInplace);