Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ TensorKitMooncakeExt = "Mooncake"
[workspace]
projects = ["test", "docs"]

[sources]
VectorInterface = {url = "https://github.com/QuantumKitHub/VectorInterface.jl", rev = "main"}

[compat]
Adapt = "4"
AMDGPU = "2"
Expand Down
12 changes: 11 additions & 1 deletion ext/TensorKitMooncakeExt/tangent.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Arrayify is needed to make MatrixAlgebraKit function properly -
# it turns coduals into argument types that MAK knows how to handle.
Mooncake.arrayify(A_dA::CoDual{<:TensorMap}) = arrayify(primal(A_dA), tangent(A_dA))
Mooncake.arrayify(A_dA::Dual{<:TensorMap}) = arrayify(primal(A_dA), tangent(A_dA))
Mooncake.arrayify(A::TensorMap, dA::TensorMap) = (A, dA)

Mooncake.arrayify(A_dA::CoDual{<:DiagonalTensorMap}) = arrayify(primal(A_dA), tangent(A_dA))
Mooncake.arrayify(A_dA::Dual{<:DiagonalTensorMap}) = arrayify(primal(A_dA), tangent(A_dA))
Mooncake.arrayify(A::DiagonalTensorMap, dA::DiagonalTensorMap) = (A, dA)

function Mooncake.arrayify(Aᴴ_ΔAᴴ::CoDual{<:TK.AdjointTensorMap})
Expand All @@ -14,6 +16,14 @@ function Mooncake.arrayify(Aᴴ_ΔAᴴ::CoDual{<:TK.AdjointTensorMap})
return A', ΔA'
end

function Mooncake.arrayify(Aᴴ_ΔAᴴ::Dual{<:TK.AdjointTensorMap})
Aᴴ = Mooncake.primal(Aᴴ_ΔAᴴ)
ΔAᴴ = Mooncake.tangent(Aᴴ_ΔAᴴ)
A_ΔA = Dual(Aᴴ', ΔAᴴ.fields.parent)
A, ΔA = arrayify(A_ΔA)
return A', ΔA'
end

# Define the tangent type of a TensorMap to be TensorMap itself.
# This has a number of benefits, but also correctly alters the
# inner product when dealing with non-abelian symmetries.
Expand Down Expand Up @@ -181,7 +191,7 @@ _field_symbol(t, ::Val{F}) where {F} = _field_symbol(t, F)

# frules
_frule_getfield_common(t_dt::Dual{<:DiagOrTensorMap}, field_sym::Symbol) =
Dual(getfield(primal(t), field_sym), field_sym === :data ? tangent(t).data : NoFData())
Dual(getfield(primal(t_dt), field_sym), field_sym === :data ? tangent(t_dt).data : NoTangent())

Mooncake.frule!!(::Dual{typeof(Mooncake.lgetfield)}, t_dt::Dual{<:DiagOrTensorMap}, f_df::Dual) =
_frule_getfield_common(t_dt, _field_symbol(primal(t_dt), primal(f_df)))
Expand Down
1 change: 1 addition & 0 deletions ext/TensorKitMooncakeExt/utility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ Mooncake.tangent_type(::Type{<:HomSpace}) = Mooncake.NoTangent

@zero_derivative DefaultCtx Tuple{typeof(TensorKit.select), HomSpace, Index2Tuple}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.flip), HomSpace, Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.adjoint), HomSpace}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.permute), HomSpace, Index2Tuple}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.braid), HomSpace, Index2Tuple, IndexTuple}
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.compose), HomSpace, HomSpace}
Expand Down
93 changes: 4 additions & 89 deletions ext/TensorKitMooncakeExt/vectorinterface.jl
Original file line number Diff line number Diff line change
@@ -1,89 +1,4 @@
@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, Number}

function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
α = primal(α_Δα)

# primal call
C_cache = copy(C)
scale!(C, α)

function scale_pullback(::NoRData)
copy!(C, C_cache)
Δαr = _needs_tangent(α) ? project_scalar(α, inner(C, ΔC)) : NoRData()
scale!(ΔC, conj(α))
return NoRData(), NoRData(), Δαr
end

return C_ΔC, scale_pullback
end

@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, AbstractTensorMap, Number}

function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
A, ΔA = arrayify(A_ΔA)
α = primal(α_Δα)

# primal call
C_cache = copy(C)
scale!(C, A, α)

function scale_pullback(::NoRData)
copy!(C, C_cache)
add!(ΔA, ΔC, conj(α))
Δαr = _needs_tangent(α) ? project_scalar(α, inner(A, ΔC)) : NoRData()
zerovector!(ΔC)
return NoRData(), NoRData(), NoRData(), Δαr
end

return C_ΔC, scale_pullback
end

@is_primitive DefaultCtx ReverseMode Tuple{typeof(add!), AbstractTensorMap, AbstractTensorMap, Number, Number}

function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
A, ΔA = arrayify(A_ΔA)
α = primal(α_Δα)
β = primal(β_Δβ)

# primal call
C_cache = copy(C)
add!(C, A, α, β)

function add_pullback(::NoRData)
copy!(C, C_cache)

Δαr = _needs_tangent(α) ? project_scalar(α, inner(A, ΔC)) : NoRData()
Δβr = _needs_tangent(β) ? project_scalar(β, inner(C, ΔC)) : NoRData()
add!(ΔA, ΔC, conj(α))
scale!(ΔC, conj(β))

return NoRData(), NoRData(), NoRData(), Δαr, Δβr
end

return C_ΔC, add_pullback
end

@is_primitive DefaultCtx ReverseMode Tuple{typeof(inner), AbstractTensorMap, AbstractTensorMap}

function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractTensorMap}, B_ΔB::CoDual{<:AbstractTensorMap})
# prepare arguments
A, ΔA = arrayify(A_ΔA)
B, ΔB = arrayify(B_ΔB)

# primal call
s = inner(A, B)

function inner_pullback(Δs)
add!(ΔA, B, conj(Δs))
add!(ΔB, A, Δs)
return NoRData(), NoRData(), NoRData()
end

return CoDual(s, NoFData()), inner_pullback
end
@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractTensorMap, Number}
@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractTensorMap, AbstractTensorMap, Number}
@is_primitive DefaultCtx Tuple{typeof(add!), AbstractTensorMap, AbstractTensorMap, Number, Number}
@is_primitive DefaultCtx Tuple{typeof(inner), AbstractTensorMap, AbstractTensorMap}
35 changes: 16 additions & 19 deletions test/mooncake/vectorinterface.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
using Test, TestExtras
using TensorKit
using TensorOperations
using Mooncake
using VectorInterface, Mooncake
using Random


mode = Mooncake.ReverseMode
rng = Random.default_rng()

spacelist = ad_spacelist(fast_tests)
Expand All @@ -17,20 +15,19 @@ eltypes = (Float64, ComplexF64)

C = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')
A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')
α = randn(T)
β = randn(T)

Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol, mode)
Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol, mode)
Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol, mode)
Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol, mode)
Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol, mode)
Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol, mode)

Mooncake.TestUtils.test_rule(rng, add!, C, A; atol, rtol, mode, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, add!, C, A, α; atol, rtol, mode, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol, mode)

Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol, mode)
Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol, mode)
for α in (randn(T), One(), Zero()), β in (randn(T), One(), Zero())
Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol, is_primitive = false)

Mooncake.TestUtils.test_rule(rng, add!, C, A; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, add!, C, A, α; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol, is_primitive = false)

Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol, is_primitive = false)
Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol, is_primitive = false)
end
end
Loading