From bbda095da3e0476e109d29ab14fbb56e956a0242 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 14 Jan 2026 15:32:57 +0100 Subject: [PATCH 01/13] Add Enzyme rules --- Project.toml | 9 +- .../TensorOperationsEnzymeExt.jl | 197 ++++++++++++++++++ test/enzyme.jl | 93 +++++++++ test/mooncake.jl | 1 + test/runtests.jl | 7 +- 5 files changed, 305 insertions(+), 2 deletions(-) create mode 100644 ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl create mode 100644 test/enzyme.jl diff --git a/Project.toml b/Project.toml index 8ff7ca9..973924c 100644 --- a/Project.toml +++ b/Project.toml @@ -24,12 +24,14 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" [extensions] TensorOperationsBumperExt = "Bumper" TensorOperationsChainRulesCoreExt = "ChainRulesCore" TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"] TensorOperationsMooncakeExt = "Mooncake" +TensorOperationsEnzymeExt = ["Enzyme", "ChainRulesCore"] [compat] Aqua = "0.6, 0.7, 0.8" @@ -38,6 +40,8 @@ CUDA = "5" ChainRulesCore = "1" ChainRulesTestUtils = "1" DynamicPolynomials = "0.5, 0.6" +Enzyme = "0.13.115" +EnzymeTestUtils = "0.2" LRUCache = "1" LinearAlgebra = "1.6" Logging = "1.6" @@ -59,8 +63,11 @@ julia = "1.10" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -68,4 +75,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" [targets] -test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake"] +test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "ChainRulesCore", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake", "Enzyme", "EnzymeTestUtils"] diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl new file mode 100644 index 0000000..1b5cf3c --- /dev/null +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -0,0 +1,197 @@ +module TensorOperationsEnzymeExt + +using TensorOperations +using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator +using VectorInterface +using TupleTools +using Enzyme, ChainRulesCore +using Enzyme.EnzymeCore +using Enzyme.EnzymeCore: EnzymeRules + +@inline EnzymeRules.inactive(::typeof(TensorOperations.tensorfree!), ::Any) = true +Enzyme.@import_rrule(typeof(TensorOperations.tensoralloc), Any, Any, Any, Any) + +@inline EnzymeRules.inactive_type(v::Type{<:AbstractBackend}) = true +@inline EnzymeRules.inactive_type(v::Type{DefaultAllocator}) = true +@inline EnzymeRules.inactive_type(v::Type{<:CUDAAllocator}) = true +@inline EnzymeRules.inactive_type(v::Type{ManualAllocator}) = true +@inline EnzymeRules.inactive_type(v::Type{<:Index2Tuple}) = true + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorOperations.tensorcontract!)}, + ::Type{RT}, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + pA_dpA::Const{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + B_dB::Annotation{<:AbstractArray{TB}}, + pB_dpB::Const{<:Index2Tuple}, + conjB_dconjB::Const{Bool}, + pAB_dpAB::Const{<:Index2Tuple}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number} + # form caches if needed + cache_A = !isa(A_dA, Const) && EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing + cache_B = !isa(B_dB, Const) && EnzymeRules.overwritten(config)[6] ? copy(B_dB.val) : nothing + cache_C = copy(C_dC.val) # do we need to do this, if we don't need the primal? + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val, B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α_dα.val, β_dβ.val, ba...) + primal = if EnzymeRules.needs_primal(config) + C_dC.val + else + nothing + end + shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B, cache_C)) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorOperations.tensorcontract!)}, + ::Type{RT}, + cache, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + pA_dpA::Const{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + B_dB::Annotation{<:AbstractArray{TB}}, + pB_dpB::Const{<:Index2Tuple}, + conjB_dconjB::Const{Bool}, + pAB_dpAB::Const{<:Index2Tuple}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number} + cache_A, cache_B, cache_C = cache + Aval = something(cache_A, A_dA.val) + Bval = something(cache_B, B_dB.val) + Cval = cache_C + dC = C_dC.dval + dA = A_dA.dval + dB = B_dB.dval + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + α = α_dα.val + β = β_dβ.val + dC, dA, dB, dα, dβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, Cval, Aval, pA_dpA.val, conjA_dconjA.val, Bval, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α, β, ba...) + return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)... +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + ::Annotation{typeof(tensoradd!)}, + ::Type{RT}, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + pA_dpA::Const{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} + # form caches if needed + cache_A = !isa(A_dA, Const) && EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing + cache_C = copy(C_dC.val) + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + α = α_dα.val + β = β_dβ.val + conjA = conjA_dconjA.val + TensorOperations.tensoradd!(C_dC.val, A_dA.val, pA_dpA.val, conjA, α, β, ba...) + primal = if EnzymeRules.needs_primal(config) + C_dC.val + else + nothing + end + shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C)) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + ::Annotation{typeof(tensoradd!)}, + ::Type{RT}, + cache, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + pA_dpA::Const{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} + cache_A, cache_C = cache + Aval = something(cache_A, A_dA.val) + Cval = cache_C + pA = pA_dpA.val + conjA = conjA_dconjA.val + α = α_dα.val + β = β_dβ.val + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + dC = C_dC.dval + dA = A_dA.dval + dC, dA, dα, dβ = TensorOperations.tensoradd_pullback!(dC, dA, Cval, Aval, pA, conjA, α, β, ba...) + return nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)... +end + +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + ::Annotation{typeof(tensortrace!)}, + ::Type{RT}, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + p_dp::Const{<:Index2Tuple}, + q_dq::Const{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} + # form caches if needed + cache_A = !isa(A_dA, Const) && EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing + cache_C = copy(C_dC.val) + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + α = α_dα.val + β = β_dβ.val + conjA = conjA_dconjA.val + TensorOperations.tensortrace!(C_dC.val, A_dA.val, p_dp.val, q_dq.val, conjA, α, β, ba...) + primal = if EnzymeRules.needs_primal(config) + C_dC.val + else + nothing + end + shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C)) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + ::Annotation{typeof(tensortrace!)}, + ::Type{RT}, + cache, + C_dC::Annotation{<:AbstractArray{TC}}, + A_dA::Annotation{<:AbstractArray{TA}}, + p_dp::Const{<:Index2Tuple}, + q_dq::Const{<:Index2Tuple}, + conjA_dconjA::Const{Bool}, + α_dα::Annotation{Tα}, + β_dβ::Annotation{Tβ}, + ba_dba::Const..., + ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} + cache_A, cache_C = cache + Aval = something(cache_A, A_dA.val) + Cval = cache_C + p = p_dp.val + q = q_dq.val + conjA = conjA_dconjA.val + α = α_dα.val + β = β_dβ.val + ba = map(ba_ -> getfield(ba_, :val), ba_dba) + dC = C_dC.dval + dA = A_dA.dval + dC, dA, dα, dβ = TensorOperations.tensortrace_pullback!(dC, dA, Cval, Aval, p, q, conjA, α, β, ba...) + return nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)... +end + +end diff --git a/test/enzyme.jl b/test/enzyme.jl new file mode 100644 index 0000000..2852b5d --- /dev/null +++ b/test/enzyme.jl @@ -0,0 +1,93 @@ +using TensorOperations, VectorInterface +using Enzyme, ChainRulesCore, EnzymeTestUtils + +@testset "tensorcontract! ($T₁, $T₂)" for (T₁, T₂) in + ( + (Float64, Float64), + (Float32, Float64), + (ComplexF64, ComplexF64), + (Float64, ComplexF64), + (ComplexF64, Float64), + ) + T = promote_type(T₁, T₂) + atol = max(precision(T₁), precision(T₂)) + rtol = max(precision(T₁), precision(T₂)) + + pAB = ((3, 2, 4, 1), ()) + pA = ((2, 4, 5), (1, 3)) + pB = ((2, 1), (3,)) + + A = rand(T₁, (2, 3, 4, 2, 5)) + B = rand(T₂, (4, 2, 3)) + C = rand(T, (5, 2, 3, 3)) + @testset for (α, β) in ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)), (randn(T), randn(T))) + Tα = α === Zero() ? Const : Active + Tβ = β === Zero() ? Const : Active + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + end +end + +@testset "tensoradd! ($T₁, $T₂)" for (T₁, T₂) in ( + (Float64, Float64), + (Float32, Float64), + (ComplexF64, ComplexF64), + (Float64, ComplexF64), + ) + T = promote_type(T₁, T₂) + atol = max(precision(T₁), precision(T₂)) + rtol = max(precision(T₁), precision(T₂)) + + pA = ((2, 1, 4, 3, 5), ()) + A = rand(T₁, (2, 3, 4, 2, 1)) + C = rand(T₂, size.(Ref(A), pA[1])) + @testset for (α, β) in ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)), (randn(T), randn(T))) + Tα = α === Zero() ? Const : Active + Tβ = β === Zero() ? Const : Active + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol) + + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + end +end + +@testset "tensortrace! ($T₁, $T₂)" for (T₁, T₂) in + ( + (Float64, Float64), + (Float32, Float64), + (ComplexF64, ComplexF64), + (Float64, ComplexF64), + ) + T = promote_type(T₁, T₂) + atol = max(precision(T₁), precision(T₂)) + rtol = max(precision(T₁), precision(T₂)) + + p = ((3, 5, 2), ()) + q = ((1,), (4,)) + A = rand(T₁, (2, 3, 4, 2, 5)) + C = rand(T₂, size.(Ref(A), p[1])) + @testset for (α, β) in ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)), (randn(T), randn(T))) + Tα = α === Zero() ? Const : Active + Tβ = β === Zero() ? Const : Active + + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol) + + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + end +end + +@testset "tensorscalar ($T)" for T in (Float32, Float64, ComplexF64) + atol = precision(T) + rtol = precision(T) + + C = Array{T, 0}(undef, ()) + fill!(C, rand(T)) + test_reverse(tensorscalar, Active, (C, Duplicated); atol, rtol) +end diff --git a/test/mooncake.jl b/test/mooncake.jl index 1790ba4..9b53117 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -14,6 +14,7 @@ is_primitive = false (Float32, Float64), (ComplexF64, ComplexF64), (Float64, ComplexF64), + (ComplexF64, Float64), ) T = promote_type(T₁, T₂) atol = max(precision(T₁), precision(T₂)) diff --git a/test/runtests.jl b/test/runtests.jl index f67fbe6..6c74557 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,7 +15,6 @@ precision(::Type{<:Union{Float64, Complex{Float64}}}) = 1.0e-8 # specific ones is_buildkite = get(ENV, "BUILDKITE", "false") == "true" if !is_buildkite - @testset "tensoropt" verbose = true begin include("tensoropt.jl") end @@ -37,6 +36,12 @@ if !is_buildkite @testset "mooncake" verbose = false begin include("mooncake.jl") end + # mystery segfault on 1.10 for now + @static if VERSION >= v"1.11.0" + @testset "enzyme" verbose = false begin + include("enzyme.jl") + end + end end if is_buildkite From a3ebb5d33f6f115de843309969980cf7eec3d744 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sun, 25 Jan 2026 16:29:45 +0100 Subject: [PATCH 02/13] Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho --- ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index 1b5cf3c..6918770 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -149,8 +149,8 @@ function EnzymeRules.augmented_primal( ba_dba::Const..., ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} # form caches if needed - cache_A = !isa(A_dA, Const) && EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing - cache_C = copy(C_dC.val) + cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing + cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : nothing ba = map(ba_ -> getfield(ba_, :val), ba_dba) α = α_dα.val β = β_dβ.val From 03851f9a531045e93d0fb3e37f44b845322e6db5 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sun, 25 Jan 2026 16:29:54 +0100 Subject: [PATCH 03/13] Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho --- ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index 6918770..f9fec8c 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -92,8 +92,8 @@ function EnzymeRules.augmented_primal( ba_dba::Const..., ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} # form caches if needed - cache_A = !isa(A_dA, Const) && EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing - cache_C = copy(C_dC.val) + cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing + cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : nothing ba = map(ba_ -> getfield(ba_, :val), ba_dba) α = α_dα.val β = β_dβ.val From 5d7419fc18bce280c5f25c781c38d11ff26b71af Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sun, 25 Jan 2026 16:30:11 +0100 Subject: [PATCH 04/13] Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho --- ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index f9fec8c..102fb4d 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -75,7 +75,8 @@ function EnzymeRules.reverse( ba = map(ba_ -> getfield(ba_, :val), ba_dba) α = α_dα.val β = β_dβ.val - dC, dA, dB, dα, dβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, Cval, Aval, pA_dpA.val, conjA_dconjA.val, Bval, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α, β, ba...) + pA, pB, pAB, conjA, conjB = getfield.((pA_dpA, pB_dpB, pAB_dpAB, conjA_dconjA, conjB_dconjB), :val) + dC, dA, dB, dα, dβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, β, ba...) return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)... end From 42f4b5ce40055b240e75becc5f04587bfe48eea9 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sun, 25 Jan 2026 16:30:27 +0100 Subject: [PATCH 05/13] Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho --- ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index 102fb4d..1df464f 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -34,8 +34,8 @@ function EnzymeRules.augmented_primal( ba_dba::Const..., ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number} # form caches if needed - cache_A = !isa(A_dA, Const) && EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing - cache_B = !isa(B_dB, Const) && EnzymeRules.overwritten(config)[6] ? copy(B_dB.val) : nothing + cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing + cache_B = EnzymeRules.overwritten(config)[6] ? copy(B_dB.val) : nothing cache_C = copy(C_dC.val) # do we need to do this, if we don't need the primal? ba = map(ba_ -> getfield(ba_, :val), ba_dba) TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val, B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α_dα.val, β_dβ.val, ba...) From 48186e005c729bee81b1b4685d00fd1a97e14e4e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sun, 25 Jan 2026 16:30:42 +0100 Subject: [PATCH 06/13] Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho --- ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index 1df464f..e159c66 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -68,7 +68,8 @@ function EnzymeRules.reverse( cache_A, cache_B, cache_C = cache Aval = something(cache_A, A_dA.val) Bval = something(cache_B, B_dB.val) - Cval = cache_C + Cval = cache_C # might be nothing if iszero(β) + # good way to check that we don't use it accidentally when we should not be needing it? dC = C_dC.dval dA = A_dA.dval dB = B_dB.dval From 8fad856107efbdbebd47b9f38b71b7f29ac68aff Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sun, 25 Jan 2026 16:30:56 +0100 Subject: [PATCH 07/13] Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho --- ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index e159c66..6e52379 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -36,7 +36,7 @@ function EnzymeRules.augmented_primal( # form caches if needed cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing cache_B = EnzymeRules.overwritten(config)[6] ? copy(B_dB.val) : nothing - cache_C = copy(C_dC.val) # do we need to do this, if we don't need the primal? + cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : nothing ba = map(ba_ -> getfield(ba_, :val), ba_dba) TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val, B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α_dα.val, β_dβ.val, ba...) primal = if EnzymeRules.needs_primal(config) From db094a9633e9c9cd6c5dab118cb36ff338f79aa3 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 26 Jan 2026 14:12:05 +0100 Subject: [PATCH 08/13] Fix cache and simplify tests --- .../TensorOperationsEnzymeExt.jl | 11 +- test/enzyme.jl | 137 ++++++++++-------- test/runtests.jl | 4 +- 3 files changed, 84 insertions(+), 68 deletions(-) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index 6e52379..107583f 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -36,7 +36,8 @@ function EnzymeRules.augmented_primal( # form caches if needed cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing cache_B = EnzymeRules.overwritten(config)[6] ? copy(B_dB.val) : nothing - cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : nothing + cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val + #cache_C = copy(C_dC.val) ba = map(ba_ -> getfield(ba_, :val), ba_dba) TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val, B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α_dα.val, β_dβ.val, ba...) primal = if EnzymeRules.needs_primal(config) @@ -68,7 +69,7 @@ function EnzymeRules.reverse( cache_A, cache_B, cache_C = cache Aval = something(cache_A, A_dA.val) Bval = something(cache_B, B_dB.val) - Cval = cache_C # might be nothing if iszero(β) + Cval = cache_C # good way to check that we don't use it accidentally when we should not be needing it? dC = C_dC.dval dA = A_dA.dval @@ -95,7 +96,8 @@ function EnzymeRules.augmented_primal( ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} # form caches if needed cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing - cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : nothing + cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val + #cache_C = copy(C_dC.val) ba = map(ba_ -> getfield(ba_, :val), ba_dba) α = α_dα.val β = β_dβ.val @@ -152,7 +154,8 @@ function EnzymeRules.augmented_primal( ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} # form caches if needed cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing - cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : nothing + cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val + #cache_C = copy(C_dC.val) ba = map(ba_ -> getfield(ba_, :val), ba_dba) α = α_dα.val β = β_dβ.val diff --git a/test/enzyme.jl b/test/enzyme.jl index 2852b5d..e31926f 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -1,85 +1,98 @@ using TensorOperations, VectorInterface using Enzyme, ChainRulesCore, EnzymeTestUtils -@testset "tensorcontract! ($T₁, $T₂)" for (T₁, T₂) in - ( - (Float64, Float64), - (Float32, Float64), - (ComplexF64, ComplexF64), - (Float64, ComplexF64), - (ComplexF64, Float64), - ) - T = promote_type(T₁, T₂) - atol = max(precision(T₁), precision(T₂)) - rtol = max(precision(T₁), precision(T₂)) - +@testset "tensorcontract!" begin pAB = ((3, 2, 4, 1), ()) pA = ((2, 4, 5), (1, 3)) pB = ((2, 1), (3,)) + @testset "($T₁, $T₂)" for (T₁, T₂) in ( + (Float64, Float64), + (Float32, Float64), + (ComplexF64, ComplexF64), + (Float64, ComplexF64), + (ComplexF64, Float64), + ) + T = promote_type(T₁, T₂) + atol = max(precision(T₁), precision(T₂)) + rtol = max(precision(T₁), precision(T₂)) + + A = rand(T₁, (2, 3, 4, 2, 5)) + B = rand(T₂, (4, 2, 3)) + C = rand(T, (5, 2, 3, 3)) + zero_αβs = ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T))) + αβs = (T == T₁ == T₂ == Float64) ? vcat(zero_αβs..., (randn(T), randn(T))) : ((randn(T), randn(T)),) + # test zeros only once to avoid wasteful tests + @testset for (α, β) in αβs + Tα = α === Zero() ? Const : Active + Tβ = β === Zero() ? Const : Active + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) - A = rand(T₁, (2, 3, 4, 2, 5)) - B = rand(T₂, (4, 2, 3)) - C = rand(T, (5, 2, 3, 3)) - @testset for (α, β) in ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)), (randn(T), randn(T))) - Tα = α === Zero() ? Const : Active - Tβ = β === Zero() ? Const : Active - test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) - test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) - test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) - test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) - test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + end end end -@testset "tensoradd! ($T₁, $T₂)" for (T₁, T₂) in ( - (Float64, Float64), - (Float32, Float64), - (ComplexF64, ComplexF64), - (Float64, ComplexF64), - ) - T = promote_type(T₁, T₂) - atol = max(precision(T₁), precision(T₂)) - rtol = max(precision(T₁), precision(T₂)) - +@testset "tensoradd!" begin pA = ((2, 1, 4, 3, 5), ()) - A = rand(T₁, (2, 3, 4, 2, 1)) - C = rand(T₂, size.(Ref(A), pA[1])) - @testset for (α, β) in ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)), (randn(T), randn(T))) - Tα = α === Zero() ? Const : Active - Tβ = β === Zero() ? Const : Active - test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol) - test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol) + @testset "($T₁, $T₂)" for (T₁, T₂) in ( + (Float64, Float64), + (Float32, Float64), + (ComplexF64, ComplexF64), + (Float64, ComplexF64), + ) + T = promote_type(T₁, T₂) + atol = max(precision(T₁), precision(T₂)) + rtol = max(precision(T₁), precision(T₂)) - test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) - test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + A = rand(T₁, (2, 3, 4, 2, 1)) + C = rand(T₂, size.(Ref(A), pA[1])) + zero_αβs = ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T))) + αβs = (T == T₁ == T₂ == Float64) ? vcat(zero_αβs..., (randn(T), randn(T))) : ((randn(T), randn(T)),) + # test zeros only once to avoid wasteful tests + @testset for (α, β) in αβs + Tα = α === Zero() ? Const : Active + Tβ = β === Zero() ? Const : Active + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol) + + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + end end end -@testset "tensortrace! ($T₁, $T₂)" for (T₁, T₂) in - ( - (Float64, Float64), - (Float32, Float64), - (ComplexF64, ComplexF64), - (Float64, ComplexF64), - ) - T = promote_type(T₁, T₂) - atol = max(precision(T₁), precision(T₂)) - rtol = max(precision(T₁), precision(T₂)) - +@testset "tensortrace!" begin p = ((3, 5, 2), ()) q = ((1,), (4,)) - A = rand(T₁, (2, 3, 4, 2, 5)) - C = rand(T₂, size.(Ref(A), p[1])) - @testset for (α, β) in ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)), (randn(T), randn(T))) - Tα = α === Zero() ? Const : Active - Tβ = β === Zero() ? Const : Active + @testset "($T₁, $T₂)" for (T₁, T₂) in + ( + (Float64, Float64), + (Float32, Float64), + (ComplexF64, ComplexF64), + (Float64, ComplexF64), + ) + T = promote_type(T₁, T₂) + atol = max(precision(T₁), precision(T₂)) + rtol = max(precision(T₁), precision(T₂)) - test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol) - test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol) + A = rand(T₁, (2, 3, 4, 2, 5)) + C = rand(T₂, size.(Ref(A), p[1])) + zero_αβs = ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T))) + αβs = (T == T₁ == T₂ == Float64) ? vcat(zero_αβs..., (randn(T), randn(T))) : ((randn(T), randn(T)),) + # test zeros only once to avoid wasteful tests + @testset for (α, β) in αβs + Tα = α === Zero() ? Const : Active + Tβ = β === Zero() ? Const : Active + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol) - test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) - test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + end end end diff --git a/test/runtests.jl b/test/runtests.jl index 6c74557..17927e4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,7 +15,7 @@ precision(::Type{<:Union{Float64, Complex{Float64}}}) = 1.0e-8 # specific ones is_buildkite = get(ENV, "BUILDKITE", "false") == "true" if !is_buildkite - @testset "tensoropt" verbose = true begin + #=@testset "tensoropt" verbose = true begin include("tensoropt.jl") end @testset "auxiliary" verbose = true begin @@ -35,7 +35,7 @@ if !is_buildkite end @testset "mooncake" verbose = false begin include("mooncake.jl") - end + end=# # mystery segfault on 1.10 for now @static if VERSION >= v"1.11.0" @testset "enzyme" verbose = false begin From 83ac936df1b57e0e081e883a03acfb14dc83af9a Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 26 Jan 2026 14:25:07 +0100 Subject: [PATCH 09/13] Re-enable all tests --- ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl | 3 --- test/runtests.jl | 4 ++-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index 107583f..d8b70fa 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -37,7 +37,6 @@ function EnzymeRules.augmented_primal( cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing cache_B = EnzymeRules.overwritten(config)[6] ? copy(B_dB.val) : nothing cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val - #cache_C = copy(C_dC.val) ba = map(ba_ -> getfield(ba_, :val), ba_dba) TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val, B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α_dα.val, β_dβ.val, ba...) primal = if EnzymeRules.needs_primal(config) @@ -97,7 +96,6 @@ function EnzymeRules.augmented_primal( # form caches if needed cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val - #cache_C = copy(C_dC.val) ba = map(ba_ -> getfield(ba_, :val), ba_dba) α = α_dα.val β = β_dβ.val @@ -155,7 +153,6 @@ function EnzymeRules.augmented_primal( # form caches if needed cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val - #cache_C = copy(C_dC.val) ba = map(ba_ -> getfield(ba_, :val), ba_dba) α = α_dα.val β = β_dβ.val diff --git a/test/runtests.jl b/test/runtests.jl index 17927e4..6c74557 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,7 +15,7 @@ precision(::Type{<:Union{Float64, Complex{Float64}}}) = 1.0e-8 # specific ones is_buildkite = get(ENV, "BUILDKITE", "false") == "true" if !is_buildkite - #=@testset "tensoropt" verbose = true begin + @testset "tensoropt" verbose = true begin include("tensoropt.jl") end @testset "auxiliary" verbose = true begin @@ -35,7 +35,7 @@ if !is_buildkite end @testset "mooncake" verbose = false begin include("mooncake.jl") - end=# + end # mystery segfault on 1.10 for now @static if VERSION >= v"1.11.0" @testset "enzyme" verbose = false begin From 100bc258b8207ea18f75530935ae12dafb0acb59 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 27 Jan 2026 11:59:50 +0100 Subject: [PATCH 10/13] Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho --- ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index d8b70fa..ad80822 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -39,11 +39,7 @@ function EnzymeRules.augmented_primal( cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val ba = map(ba_ -> getfield(ba_, :val), ba_dba) TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val, B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α_dα.val, β_dβ.val, ba...) - primal = if EnzymeRules.needs_primal(config) - C_dC.val - else - nothing - end + primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B, cache_C)) end From eb3754b59655ce5fd972d6c7e0e5604d42d1c82a Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 27 Jan 2026 11:59:59 +0100 Subject: [PATCH 11/13] Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho --- ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index ad80822..43f458f 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -154,11 +154,7 @@ function EnzymeRules.augmented_primal( β = β_dβ.val conjA = conjA_dconjA.val TensorOperations.tensortrace!(C_dC.val, A_dA.val, p_dp.val, q_dq.val, conjA, α, β, ba...) - primal = if EnzymeRules.needs_primal(config) - C_dC.val - else - nothing - end + primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C)) end From 3a6d6ae45af57d95b75d14e324cfd29b926dec8f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 27 Jan 2026 12:00:12 +0100 Subject: [PATCH 12/13] Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho --- ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index 43f458f..8b5106f 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -97,11 +97,7 @@ function EnzymeRules.augmented_primal( β = β_dβ.val conjA = conjA_dconjA.val TensorOperations.tensoradd!(C_dC.val, A_dA.val, pA_dpA.val, conjA, α, β, ba...) - primal = if EnzymeRules.needs_primal(config) - C_dC.val - else - nothing - end + primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C)) end From 12cf6977befbd9b3d733716c0284c7ab9659443e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 27 Jan 2026 12:22:29 +0100 Subject: [PATCH 13/13] Remove irrelevant Mooncake test --- test/mooncake.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/mooncake.jl b/test/mooncake.jl index 9b53117..1790ba4 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -14,7 +14,6 @@ is_primitive = false (Float32, Float64), (ComplexF64, ComplexF64), (Float64, ComplexF64), - (ComplexF64, Float64), ) T = promote_type(T₁, T₂) atol = max(precision(T₁), precision(T₂))