Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[extensions]
TensorOperationsBumperExt = "Bumper"
TensorOperationsChainRulesCoreExt = "ChainRulesCore"
TensorOperationsMooncakeExt = "Mooncake"
TensorOperationsEnzymeExt = ["Enzyme", "ChainRulesCore"]
TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]

[compat]
Expand All @@ -38,6 +40,8 @@ CUDA = "5"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
DynamicPolynomials = "0.5, 0.6"
Enzyme = "0.13.115"
EnzymeTestUtils = "0.2"
LRUCache = "1"
LinearAlgebra = "1.6"
Logging = "1.6"
Expand All @@ -59,13 +63,16 @@ julia = "1.10"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[targets]
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake"]
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "ChainRulesCore", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake", "Enzyme", "EnzymeTestUtils"]
187 changes: 187 additions & 0 deletions ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
module TensorOperationsEnzymeExt

using TensorOperations
using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator
using VectorInterface
using TupleTools
using Enzyme, ChainRulesCore
using Enzyme.EnzymeCore
using Enzyme.EnzymeCore: EnzymeRules

@inline EnzymeRules.inactive(::typeof(TensorOperations.tensorfree!), ::Any) = true
Enzyme.@import_rrule(typeof(TensorOperations.tensoralloc), Any, Any, Any, Any)

@inline EnzymeRules.inactive_type(v::Type{<:AbstractBackend}) = true
@inline EnzymeRules.inactive_type(v::Type{DefaultAllocator}) = true
@inline EnzymeRules.inactive_type(v::Type{<:CUDAAllocator}) = true
@inline EnzymeRules.inactive_type(v::Type{ManualAllocator}) = true
@inline EnzymeRules.inactive_type(v::Type{<:Index2Tuple}) = true

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(TensorOperations.tensorcontract!)},
::Type{RT},
C_dC::Annotation{<:AbstractArray{TC}},
A_dA::Annotation{<:AbstractArray{TA}},
pA_dpA::Const{<:Index2Tuple},
conjA_dconjA::Const{Bool},
B_dB::Annotation{<:AbstractArray{TB}},
pB_dpB::Const{<:Index2Tuple},
conjB_dconjB::Const{Bool},
pAB_dpAB::Const{<:Index2Tuple},
α_dα::Annotation{Tα},
β_dβ::Annotation{Tβ},
ba_dba::Const...,
) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number}
# form caches if needed
cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
cache_B = EnzymeRules.overwritten(config)[6] ? copy(B_dB.val) : nothing
cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val
ba = map(ba_ -> getfield(ba_, :val), ba_dba)
TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val, B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α_dα.val, β_dβ.val, ba...)
primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B, cache_C))
end

function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(TensorOperations.tensorcontract!)},
::Type{RT},
cache,
C_dC::Annotation{<:AbstractArray{TC}},
A_dA::Annotation{<:AbstractArray{TA}},
pA_dpA::Const{<:Index2Tuple},
conjA_dconjA::Const{Bool},
B_dB::Annotation{<:AbstractArray{TB}},
pB_dpB::Const{<:Index2Tuple},
conjB_dconjB::Const{Bool},
pAB_dpAB::Const{<:Index2Tuple},
α_dα::Annotation{Tα},
β_dβ::Annotation{Tβ},
ba_dba::Const...,
) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number}
cache_A, cache_B, cache_C = cache
Aval = something(cache_A, A_dA.val)
Bval = something(cache_B, B_dB.val)
Cval = cache_C
# good way to check that we don't use it accidentally when we should not be needing it?
dC = C_dC.dval
dA = A_dA.dval
dB = B_dB.dval
ba = map(ba_ -> getfield(ba_, :val), ba_dba)
α = α_dα.val
β = β_dβ.val
pA, pB, pAB, conjA, conjB = getfield.((pA_dpA, pB_dpB, pAB_dpAB, conjA_dconjA, conjB_dconjB), :val)
dC, dA, dB, dα, dβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, β, ba...)
return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)...
end

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
::Annotation{typeof(tensoradd!)},
::Type{RT},
C_dC::Annotation{<:AbstractArray{TC}},
A_dA::Annotation{<:AbstractArray{TA}},
pA_dpA::Const{<:Index2Tuple},
conjA_dconjA::Const{Bool},
α_dα::Annotation{Tα},
β_dβ::Annotation{Tβ},
ba_dba::Const...,
) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
# form caches if needed
cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val
ba = map(ba_ -> getfield(ba_, :val), ba_dba)
α = α_dα.val
β = β_dβ.val
conjA = conjA_dconjA.val
TensorOperations.tensoradd!(C_dC.val, A_dA.val, pA_dpA.val, conjA, α, β, ba...)
primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C))
end

function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
::Annotation{typeof(tensoradd!)},
::Type{RT},
cache,
C_dC::Annotation{<:AbstractArray{TC}},
A_dA::Annotation{<:AbstractArray{TA}},
pA_dpA::Const{<:Index2Tuple},
conjA_dconjA::Const{Bool},
α_dα::Annotation{Tα},
β_dβ::Annotation{Tβ},
ba_dba::Const...,
) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
cache_A, cache_C = cache
Aval = something(cache_A, A_dA.val)
Cval = cache_C
pA = pA_dpA.val
conjA = conjA_dconjA.val
α = α_dα.val
β = β_dβ.val
ba = map(ba_ -> getfield(ba_, :val), ba_dba)
dC = C_dC.dval
dA = A_dA.dval
dC, dA, dα, dβ = TensorOperations.tensoradd_pullback!(dC, dA, Cval, Aval, pA, conjA, α, β, ba...)
return nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)...
end

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
::Annotation{typeof(tensortrace!)},
::Type{RT},
C_dC::Annotation{<:AbstractArray{TC}},
A_dA::Annotation{<:AbstractArray{TA}},
p_dp::Const{<:Index2Tuple},
q_dq::Const{<:Index2Tuple},
conjA_dconjA::Const{Bool},
α_dα::Annotation{Tα},
β_dβ::Annotation{Tβ},
ba_dba::Const...,
) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
# form caches if needed
cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val
ba = map(ba_ -> getfield(ba_, :val), ba_dba)
α = α_dα.val
β = β_dβ.val
conjA = conjA_dconjA.val
TensorOperations.tensortrace!(C_dC.val, A_dA.val, p_dp.val, q_dq.val, conjA, α, β, ba...)
primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C))
end

function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
::Annotation{typeof(tensortrace!)},
::Type{RT},
cache,
C_dC::Annotation{<:AbstractArray{TC}},
A_dA::Annotation{<:AbstractArray{TA}},
p_dp::Const{<:Index2Tuple},
q_dq::Const{<:Index2Tuple},
conjA_dconjA::Const{Bool},
α_dα::Annotation{Tα},
β_dβ::Annotation{Tβ},
ba_dba::Const...,
) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
cache_A, cache_C = cache
Aval = something(cache_A, A_dA.val)
Cval = cache_C
p = p_dp.val
q = q_dq.val
conjA = conjA_dconjA.val
α = α_dα.val
β = β_dβ.val
ba = map(ba_ -> getfield(ba_, :val), ba_dba)
dC = C_dC.dval
dA = A_dA.dval
dC, dA, dα, dβ = TensorOperations.tensortrace_pullback!(dC, dA, Cval, Aval, p, q, conjA, α, β, ba...)
return nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)...
end

end
106 changes: 106 additions & 0 deletions test/enzyme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
using TensorOperations, VectorInterface
using Enzyme, ChainRulesCore, EnzymeTestUtils

@testset "tensorcontract!" begin
pAB = ((3, 2, 4, 1), ())
pA = ((2, 4, 5), (1, 3))
pB = ((2, 1), (3,))
@testset "($T₁, $T₂)" for (T₁, T₂) in (
(Float64, Float64),
(Float32, Float64),
(ComplexF64, ComplexF64),
(Float64, ComplexF64),
(ComplexF64, Float64),
)
T = promote_type(T₁, T₂)
atol = max(precision(T₁), precision(T₂))
rtol = max(precision(T₁), precision(T₂))

A = rand(T₁, (2, 3, 4, 2, 5))
B = rand(T₂, (4, 2, 3))
C = rand(T, (5, 2, 3, 3))
zero_αβs = ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)))
αβs = (T == T₁ == T₂ == Float64) ? vcat(zero_αβs..., (randn(T), randn(T))) : ((randn(T), randn(T)),)
# test zeros only once to avoid wasteful tests
@testset for (α, β) in αβs
Tα = α === Zero() ? Const : Active
Tβ = β === Zero() ? Const : Active
test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol)
test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol)
test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol)

test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol)
test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol)

end
end
end

@testset "tensoradd!" begin
pA = ((2, 1, 4, 3, 5), ())
@testset "($T₁, $T₂)" for (T₁, T₂) in (
(Float64, Float64),
(Float32, Float64),
(ComplexF64, ComplexF64),
(Float64, ComplexF64),
)
T = promote_type(T₁, T₂)
atol = max(precision(T₁), precision(T₂))
rtol = max(precision(T₁), precision(T₂))

A = rand(T₁, (2, 3, 4, 2, 1))
C = rand(T₂, size.(Ref(A), pA[1]))
zero_αβs = ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)))
αβs = (T == T₁ == T₂ == Float64) ? vcat(zero_αβs..., (randn(T), randn(T))) : ((randn(T), randn(T)),)
# test zeros only once to avoid wasteful tests
@testset for (α, β) in αβs
Tα = α === Zero() ? Const : Active
Tβ = β === Zero() ? Const : Active
test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol)
test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol)

test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol)
test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol)
end
end
end

@testset "tensortrace!" begin
p = ((3, 5, 2), ())
q = ((1,), (4,))
@testset "($T₁, $T₂)" for (T₁, T₂) in
(
(Float64, Float64),
(Float32, Float64),
(ComplexF64, ComplexF64),
(Float64, ComplexF64),
)
T = promote_type(T₁, T₂)
atol = max(precision(T₁), precision(T₂))
rtol = max(precision(T₁), precision(T₂))

A = rand(T₁, (2, 3, 4, 2, 5))
C = rand(T₂, size.(Ref(A), p[1]))
zero_αβs = ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)))
αβs = (T == T₁ == T₂ == Float64) ? vcat(zero_αβs..., (randn(T), randn(T))) : ((randn(T), randn(T)),)
# test zeros only once to avoid wasteful tests
@testset for (α, β) in αβs
Tα = α === Zero() ? Const : Active
Tβ = β === Zero() ? Const : Active
test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol)
test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol)

test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol)
test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol)
end
end
end

@testset "tensorscalar ($T)" for T in (Float32, Float64, ComplexF64)
atol = precision(T)
rtol = precision(T)

C = Array{T, 0}(undef, ())
fill!(C, rand(T))
test_reverse(tensorscalar, Active, (C, Duplicated); atol, rtol)
end
7 changes: 6 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ precision(::Type{<:Union{Float64, Complex{Float64}}}) = 1.0e-8
# specific ones
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
if !is_buildkite

@testset "tensoropt" verbose = true begin
include("tensoropt.jl")
end
Expand All @@ -37,6 +36,12 @@ if !is_buildkite
@testset "mooncake" verbose = false begin
include("mooncake.jl")
end
# mystery segfault on 1.10 for now
@static if VERSION >= v"1.11.0"
@testset "enzyme" verbose = false begin
include("enzyme.jl")
end
end
end

if is_buildkite
Expand Down
Loading