diff --git a/Project.toml b/Project.toml
index bc8b170..28cb8eb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -23,12 +23,14 @@ Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
 
 [extensions]
 TensorOperationsBumperExt = "Bumper"
 TensorOperationsChainRulesCoreExt = "ChainRulesCore"
 TensorOperationsMooncakeExt = "Mooncake"
+TensorOperationsEnzymeExt = ["Enzyme", "ChainRulesCore"]
 TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]
 
 [compat]
@@ -38,6 +40,8 @@ CUDA = "5"
 ChainRulesCore = "1"
 ChainRulesTestUtils = "1"
 DynamicPolynomials = "0.5, 0.6"
+Enzyme = "0.13.115"
+EnzymeTestUtils = "0.2"
 LRUCache = "1"
 LinearAlgebra = "1.6"
 Logging = "1.6"
@@ -59,8 +63,11 @@ julia = "1.10"
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -68,4 +75,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
 
 [targets]
-test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake"]
+test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "ChainRulesCore", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake", "Enzyme", "EnzymeTestUtils"]
diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl
new file mode 100644
index 0000000..8b5106f
--- /dev/null
+++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl
@@ -0,0 +1,187 @@
+module TensorOperationsEnzymeExt
+
+using TensorOperations
+using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator
+using VectorInterface
+using TupleTools
+using Enzyme, ChainRulesCore
+using Enzyme.EnzymeCore
+using Enzyme.EnzymeCore: EnzymeRules
+
+@inline EnzymeRules.inactive(::typeof(TensorOperations.tensorfree!), ::Any) = true
+Enzyme.@import_rrule(typeof(TensorOperations.tensoralloc), Any, Any, Any, Any)
+
+@inline EnzymeRules.inactive_type(v::Type{<:AbstractBackend}) = true
+@inline EnzymeRules.inactive_type(v::Type{DefaultAllocator}) = true
+@inline EnzymeRules.inactive_type(v::Type{<:CUDAAllocator}) = true
+@inline EnzymeRules.inactive_type(v::Type{ManualAllocator}) = true
+@inline EnzymeRules.inactive_type(v::Type{<:Index2Tuple}) = true
+
+function EnzymeRules.augmented_primal(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(TensorOperations.tensorcontract!)},
+        ::Type{RT},
+        C_dC::Annotation{<:AbstractArray{TC}},
+        A_dA::Annotation{<:AbstractArray{TA}},
+        pA_dpA::Const{<:Index2Tuple},
+        conjA_dconjA::Const{Bool},
+        B_dB::Annotation{<:AbstractArray{TB}},
+        pB_dpB::Const{<:Index2Tuple},
+        conjB_dconjB::Const{Bool},
+        pAB_dpAB::Const{<:Index2Tuple},
+        α_dα::Annotation{Tα},
+        β_dβ::Annotation{Tβ},
+        ba_dba::Const...,
+    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number}
+    # form caches if needed
+    cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
+    cache_B = EnzymeRules.overwritten(config)[6] ? copy(B_dB.val) : nothing
+    cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val
+    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
+    TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val, B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α_dα.val, β_dβ.val, ba...)
+    primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing
+    shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
+    return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B, cache_C))
+end
+
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        func::Const{typeof(TensorOperations.tensorcontract!)},
+        ::Type{RT},
+        cache,
+        C_dC::Annotation{<:AbstractArray{TC}},
+        A_dA::Annotation{<:AbstractArray{TA}},
+        pA_dpA::Const{<:Index2Tuple},
+        conjA_dconjA::Const{Bool},
+        B_dB::Annotation{<:AbstractArray{TB}},
+        pB_dpB::Const{<:Index2Tuple},
+        conjB_dconjB::Const{Bool},
+        pAB_dpAB::Const{<:Index2Tuple},
+        α_dα::Annotation{Tα},
+        β_dβ::Annotation{Tβ},
+        ba_dba::Const...,
+    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number}
+    cache_A, cache_B, cache_C = cache
+    Aval = something(cache_A, A_dA.val)
+    Bval = something(cache_B, B_dB.val)
+    Cval = cache_C
+    # good way to check that we don't use it accidentally when we should not be needing it?
+    dC = C_dC.dval
+    dA = A_dA.dval
+    dB = B_dB.dval
+    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
+    α = α_dα.val
+    β = β_dβ.val
+    pA, pB, pAB, conjA, conjB = getfield.((pA_dpA, pB_dpB, pAB_dpAB, conjA_dconjA, conjB_dconjB), :val)
+    dC, dA, dB, dα, dβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, β, ba...)
+    return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)...
+end
+
+function EnzymeRules.augmented_primal(
+        config::EnzymeRules.RevConfigWidth{1},
+        ::Annotation{typeof(tensoradd!)},
+        ::Type{RT},
+        C_dC::Annotation{<:AbstractArray{TC}},
+        A_dA::Annotation{<:AbstractArray{TA}},
+        pA_dpA::Const{<:Index2Tuple},
+        conjA_dconjA::Const{Bool},
+        α_dα::Annotation{Tα},
+        β_dβ::Annotation{Tβ},
+        ba_dba::Const...,
+    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
+    # form caches if needed
+    cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
+    cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val
+    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
+    α = α_dα.val
+    β = β_dβ.val
+    conjA = conjA_dconjA.val
+    TensorOperations.tensoradd!(C_dC.val, A_dA.val, pA_dpA.val, conjA, α, β, ba...)
+    primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing
+    shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
+    return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C))
+end
+
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        ::Annotation{typeof(tensoradd!)},
+        ::Type{RT},
+        cache,
+        C_dC::Annotation{<:AbstractArray{TC}},
+        A_dA::Annotation{<:AbstractArray{TA}},
+        pA_dpA::Const{<:Index2Tuple},
+        conjA_dconjA::Const{Bool},
+        α_dα::Annotation{Tα},
+        β_dβ::Annotation{Tβ},
+        ba_dba::Const...,
+    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
+    cache_A, cache_C = cache
+    Aval = something(cache_A, A_dA.val)
+    Cval = cache_C
+    pA = pA_dpA.val
+    conjA = conjA_dconjA.val
+    α = α_dα.val
+    β = β_dβ.val
+    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
+    dC = C_dC.dval
+    dA = A_dA.dval
+    dC, dA, dα, dβ = TensorOperations.tensoradd_pullback!(dC, dA, Cval, Aval, pA, conjA, α, β, ba...)
+    return nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)...
+end
+
+function EnzymeRules.augmented_primal(
+        config::EnzymeRules.RevConfigWidth{1},
+        ::Annotation{typeof(tensortrace!)},
+        ::Type{RT},
+        C_dC::Annotation{<:AbstractArray{TC}},
+        A_dA::Annotation{<:AbstractArray{TA}},
+        p_dp::Const{<:Index2Tuple},
+        q_dq::Const{<:Index2Tuple},
+        conjA_dconjA::Const{Bool},
+        α_dα::Annotation{Tα},
+        β_dβ::Annotation{Tβ},
+        ba_dba::Const...,
+    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
+    # form caches if needed
+    cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
+    cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val
+    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
+    α = α_dα.val
+    β = β_dβ.val
+    conjA = conjA_dconjA.val
+    TensorOperations.tensortrace!(C_dC.val, A_dA.val, p_dp.val, q_dq.val, conjA, α, β, ba...)
+    primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing
+    shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
+    return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C))
+end
+
+function EnzymeRules.reverse(
+        config::EnzymeRules.RevConfigWidth{1},
+        ::Annotation{typeof(tensortrace!)},
+        ::Type{RT},
+        cache,
+        C_dC::Annotation{<:AbstractArray{TC}},
+        A_dA::Annotation{<:AbstractArray{TA}},
+        p_dp::Const{<:Index2Tuple},
+        q_dq::Const{<:Index2Tuple},
+        conjA_dconjA::Const{Bool},
+        α_dα::Annotation{Tα},
+        β_dβ::Annotation{Tβ},
+        ba_dba::Const...,
+    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
+    cache_A, cache_C = cache
+    Aval = something(cache_A, A_dA.val)
+    Cval = cache_C
+    p = p_dp.val
+    q = q_dq.val
+    conjA = conjA_dconjA.val
+    α = α_dα.val
+    β = β_dβ.val
+    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
+    dC = C_dC.dval
+    dA = A_dA.dval
+    dC, dA, dα, dβ = TensorOperations.tensortrace_pullback!(dC, dA, Cval, Aval, p, q, conjA, α, β, ba...)
+    return nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)...
+end
+
+end
diff --git a/test/enzyme.jl b/test/enzyme.jl
new file mode 100644
index 0000000..e31926f
--- /dev/null
+++ b/test/enzyme.jl
@@ -0,0 +1,106 @@
+using TensorOperations, VectorInterface
+using Enzyme, ChainRulesCore, EnzymeTestUtils
+
+@testset "tensorcontract!" begin
+    pAB = ((3, 2, 4, 1), ())
+    pA = ((2, 4, 5), (1, 3))
+    pB = ((2, 1), (3,))
+    @testset "($T₁, $T₂)" for (T₁, T₂) in (
+            (Float64, Float64),
+            (Float32, Float64),
+            (ComplexF64, ComplexF64),
+            (Float64, ComplexF64),
+            (ComplexF64, Float64),
+        )
+        T = promote_type(T₁, T₂)
+        atol = max(precision(T₁), precision(T₂))
+        rtol = max(precision(T₁), precision(T₂))
+
+        A = rand(T₁, (2, 3, 4, 2, 5))
+        B = rand(T₂, (4, 2, 3))
+        C = rand(T, (5, 2, 3, 3))
+        zero_αβs = ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)))
+        αβs = (T == T₁ == T₂ == Float64) ? vcat(zero_αβs..., (randn(T), randn(T))) : ((randn(T), randn(T)),)
+        # test zeros only once to avoid wasteful tests
+        @testset for (α, β) in αβs
+            Tα = α === Zero() ? Const : Active
+            Tβ = β === Zero() ? Const : Active
+            test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol)
+            test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol)
+            test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol)
+
+            test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol)
+            test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol)
+
+        end
+    end
+end
+
+@testset "tensoradd!" begin
+    pA = ((2, 1, 4, 3, 5), ())
+    @testset "($T₁, $T₂)" for (T₁, T₂) in (
+            (Float64, Float64),
+            (Float32, Float64),
+            (ComplexF64, ComplexF64),
+            (Float64, ComplexF64),
+        )
+        T = promote_type(T₁, T₂)
+        atol = max(precision(T₁), precision(T₂))
+        rtol = max(precision(T₁), precision(T₂))
+
+        A = rand(T₁, (2, 3, 4, 2, 1))
+        C = rand(T₂, size.(Ref(A), pA[1]))
+        zero_αβs = ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)))
+        αβs = (T == T₁ == T₂ == Float64) ? vcat(zero_αβs..., (randn(T), randn(T))) : ((randn(T), randn(T)),)
+        # test zeros only once to avoid wasteful tests
+        @testset for (α, β) in αβs
+            Tα = α === Zero() ? Const : Active
+            Tβ = β === Zero() ? Const : Active
+            test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol)
+            test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol)
+
+            test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol)
+            test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol)
+        end
+    end
+end
+
+@testset "tensortrace!" begin
+    p = ((3, 5, 2), ())
+    q = ((1,), (4,))
+    @testset "($T₁, $T₂)" for (T₁, T₂) in
+        (
+            (Float64, Float64),
+            (Float32, Float64),
+            (ComplexF64, ComplexF64),
+            (Float64, ComplexF64),
+        )
+        T = promote_type(T₁, T₂)
+        atol = max(precision(T₁), precision(T₂))
+        rtol = max(precision(T₁), precision(T₂))
+
+        A = rand(T₁, (2, 3, 4, 2, 5))
+        C = rand(T₂, size.(Ref(A), p[1]))
+        zero_αβs = ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)))
+        αβs = (T == T₁ == T₂ == Float64) ? vcat(zero_αβs..., (randn(T), randn(T))) : ((randn(T), randn(T)),)
+        # test zeros only once to avoid wasteful tests
+        @testset for (α, β) in αβs
+            Tα = α === Zero() ? Const : Active
+            Tβ = β === Zero() ? Const : Active
+            test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol)
+            test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol)
+
+            test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol)
+            test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol)
+        end
+    end
+end
+
+@testset "tensorscalar ($T)" for T in (Float32, Float64, ComplexF64)
+    atol = precision(T)
+    rtol = precision(T)
+
+    C = Array{T, 0}(undef, ())
+    fill!(C, rand(T))
+    test_reverse(tensorscalar, Active, (C, Duplicated); atol, rtol)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index f67fbe6..6c74557 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -15,7 +15,6 @@ precision(::Type{<:Union{Float64, Complex{Float64}}}) = 1.0e-8
 # specific ones
 is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
 if !is_buildkite
-
     @testset "tensoropt" verbose = true begin
         include("tensoropt.jl")
     end
@@ -37,6 +36,12 @@ if !is_buildkite
     @testset "mooncake" verbose = false begin
         include("mooncake.jl")
     end
+    # mystery segfault on 1.10 for now
+    @static if VERSION >= v"1.11.0"
+        @testset "enzyme" verbose = false begin
+            include("enzyme.jl")
+        end
+    end
 end
 
 if is_buildkite
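
Usage sketch for reviewers (illustrative only, not part of the patch): once both TensorOperations and Enzyme are loaded, the extension activates and the rules above can be driven directly through Enzyme.autodiff, with the same Duplicated/Const/Active annotation pattern that test/enzyme.jl exercises via EnzymeTestUtils.test_reverse. The example below differentiates a plain matrix product through tensorcontract!; the index-tuple convention shown (pA splits A into open/contracted indices, pB into contracted/open, pAB permutes the result into C) is my reading of the tensorcontract! docstring and should be double-checked against the TensorOperations documentation.

    using TensorOperations, Enzyme

    A = randn(2, 3); B = randn(3, 4)
    C = zeros(2, 4)
    dA = zero(A); dB = zero(B)
    dC = ones(2, 4)   # cotangent seed for C

    # C[i, j] = A[i, k] * B[k, j]; arrays are Duplicated so gradients are
    # accumulated into their shadows, α and β are Active so their gradients
    # are returned by autodiff.
    Enzyme.autodiff(
        Reverse, tensorcontract!, Const,
        Duplicated(C, dC),
        Duplicated(A, dA), Const(((1,), (2,))), Const(false),
        Duplicated(B, dB), Const(((1,), (2,))), Const(false),
        Const(((1, 2), ())),
        Active(1.0), Active(0.0),
    )
    # dA and dB now hold the gradients of sum(C .* seed) with respect to A and B.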