Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 152 additions & 10 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt
using Mooncake
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
using MatrixAlgebraKit
using MatrixAlgebraKit: inv_safe, diagview, copy_input
using MatrixAlgebraKit: inv_safe, diagview, copy_input, zero!
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
Expand Down Expand Up @@ -52,11 +52,11 @@ for (f!, f, pb, adj) in (
$f!(A, args, Mooncake.primal(alg_dalg))
function $adj(::NoRData)
copy!(A, Ac)
$pb(dA, A, (arg1, arg2), (darg1, darg2))
copy!(arg1, arg1c)
copy!(arg2, arg2c)
MatrixAlgebraKit.zero!(darg1)
MatrixAlgebraKit.zero!(darg2)
$pb(dA, A, (arg1, arg2), (darg1, darg2))
zero!(darg1)
zero!(darg2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems like a lot more sensible order. I assume this was working before because arg were not modified in between the forward and the backward pass?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think so

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thus the value of testing this a little more directly

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although wait, I am now confused. The copy of arg is made before the primal call. So this restores the state of arg to that before it got the actual output values assigned into it, and uses these values in the pullback call. How can this then yield the correct result?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to have the actual computed values of arg available at the time of calling the pullback (assuming that someone might have modified the return values after the primal call), while also being able to restore the state of arg to value before the primal call, it seems like we need to independent copies of arg. One before calling the primal (to restore after the pullback call), and one copy after calling the primal (as a cache to be used in the pullback, in case someone would be destroying the values) in between primal and pullback call. Is this correct?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new order is zeroing out the derivative of the arg, not its primal value. Is that the source of the issue?

Copy link
Member

@Jutho Jutho Jan 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. That was always there, except that before it was MatrixAlgebraKit.zero! and now you've imported zero!. I agree this needs to be there, but this has not changed. Before the pattern was:

Ac = copy(A)
argc = copy(arg) # value before calling primal
primal_call!(...) # destroy A and store result in arg

pullback_closure
    copy!(A, Ac) # restore A from Ac
    pullback(A, arg, ...) # assume arg is still containing correct result
    copy!(arg, arc) # restore arg to whatever value it contained before it contained the result
    zero!(darg)
end

whereas now line 2 and 3 have of pullback_closure have been flipped, which I don't see how this can work.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok, sorry, it's the GitHub UI confusing me once again

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it again be that we are not using A in the pullback, so the ordering is irrelevant here? I can flip it back again and add a comment for clarity

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the question is about the value of arg inside the pullback, which you don't actually wan't to restore to it's (typically undef) value before you call the pullback function

return NoRData(), NoRData(), NoRData(), NoRData()
end
return args_dargs, $adj
Expand All @@ -76,8 +76,8 @@ for (f!, f, pb, adj) in (
arg1, darg1 = arrayify(arg1, darg1_)
arg2, darg2 = arrayify(arg2, darg2_)
$pb(dA, A, (arg1, arg2), (darg1, darg2))
MatrixAlgebraKit.zero!(darg1)
MatrixAlgebraKit.zero!(darg2)
zero!(darg1)
zero!(darg2)
return NoRData(), NoRData(), NoRData()
end
return output_codual, $adj
Expand All @@ -99,8 +99,8 @@ for (f!, f, pb, adj) in (
$f!(A, arg, Mooncake.primal(alg_dalg))
function $adj(::NoRData)
copy!(A, Ac)
$pb(dA, A, arg, darg)
copy!(arg, argc)
$pb(dA, A, arg, darg)
MatrixAlgebraKit.zero!(darg)
return NoRData(), NoRData(), NoRData(), NoRData()
end
Expand Down Expand Up @@ -137,6 +137,7 @@ for (f!, f, f_full, pb, adj) in (
copy!(D, diagview(DV[1]))
V = DV[2]
function $adj(::NoRData)
copy!(D, diagview(DV[1]))
$pb(dA, A, DV, dD)
MatrixAlgebraKit.zero!(dD)
return NoRData(), NoRData(), NoRData(), NoRData()
Expand All @@ -163,12 +164,43 @@ for (f!, f, f_full, pb, adj) in (
end
end

for (f, f_ne, pb, adj) in (
(:eig_trunc, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_adjoint),
(:eigh_trunc, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_adjoint),
for (f!, f, f_ne!, f_ne, pb, adj) in (
(:eig_trunc!, :eig_trunc, :eig_trunc_no_error!, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_adjoint),
(:eigh_trunc!, :eigh_trunc, :eigh_trunc_no_error!, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_adjoint),
)
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
DV = Mooncake.primal(DV_dDV)
dDV = Mooncake.tangent(DV_dDV)
Ac = copy(A)
DVc = copy.(DV)
alg = Mooncake.primal(alg_dalg)
output = $f!(A, DV, alg)
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
# pass). For many types this is done automatically when the forward step returns, but
# not for nested structs with various fields (like Diagonal{Complex})
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
function $adj(dy::Tuple{NoRData, NoRData, T}) where {T <: Real}
copy!(A, Ac)
copy!(DV[1], DVc[1])
copy!(DV[2], DVc[2])
Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual)
dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual)
abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error"
D′, dD′ = arrayify(Dtrunc, dDtrunc_)
V′, dV′ = arrayify(Vtrunc, dVtrunc_)
$pb(dA, A, (D′, V′), (dD′, dV′))
MatrixAlgebraKit.zero!(dD)
MatrixAlgebraKit.zero!(dV)
return NoRData(), NoRData(), NoRData()
end
return output_codual, $adj
end
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
Expand All @@ -192,7 +224,37 @@ for (f, f_ne, pb, adj) in (
end
return output_codual, $adj
end
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f_ne!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
alg = Mooncake.primal(alg_dalg)
DV = Mooncake.primal(DV_dDV)
dDV = Mooncake.tangent(DV_dDV)
Ac = copy(A)
DVc = copy.(DV)
output = $f_ne(A, DV, alg)
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
# pass). For many types this is done automatically when the forward step returns, but
# not for nested structs with various fields (like Diagonal{Complex})
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
function $adj(::NoRData)
copy!(A, Ac)
copy!(DV[1], DVc[1])
copy!(DV[2], DVc[2])
Dtrunc, Vtrunc = Mooncake.primal(output_codual)
dDtrunc_, dVtrunc_ = Mooncake.tangent(output_codual)
D′, dD′ = arrayify(Dtrunc, dDtrunc_)
V′, dV′ = arrayify(Vtrunc, dVtrunc_)
$pb(dA, A, (D′, V′), (dD′, dV′))
MatrixAlgebraKit.zero!(dD)
MatrixAlgebraKit.zero!(dV)
return NoRData(), NoRData(), NoRData()
end
return output_codual, $adj
end
function Mooncake.rrule!!(::CoDual{typeof($f_ne)}, A_dA::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
Expand Down Expand Up @@ -232,9 +294,13 @@ for (f!, f) in (
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
USVᴴc = copy.(USVᴴ)
output = $f!(A, Mooncake.primal(alg_dalg))
function svd_adjoint(::NoRData)
copy!(A, Ac)
copy!(U, USVᴴc[1])
copy!(S, USVᴴc[2])
copy!(Vᴴ, USVᴴc[3])
if $(f! == svd_compact!)
svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
else # full
Expand Down Expand Up @@ -301,6 +367,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
function svd_vals_adjoint(::NoRData)
svd_vals_pullback!(dA, A, USVᴴ, dS)
MatrixAlgebraKit.zero!(dS)
copy!(S, diagview(USVᴴ[2]))
return NoRData(), NoRData(), NoRData(), NoRData()
end
return S_dS, svd_vals_adjoint
Expand All @@ -326,6 +393,44 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
return S_codual, svd_vals_adjoint
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
alg = Mooncake.primal(alg_dalg)
Ac = copy(A)
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
USVᴴc = copy.(USVᴴ)
output = svd_trunc!(A, USVᴴ, alg)
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
# pass). For many types this is done automatically when the forward step returns, but
# not for nested structs with various fields (like Diagonal{Complex})
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real}
copy!(A, Ac)
copy!(U, USVᴴc[1])
copy!(S, USVᴴc[2])
copy!(Vᴴ, USVᴴc[3])
Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual)
dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual)
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc does not yet support non-zero tangent for the truncation error"
U′, dU′ = arrayify(Utrunc, dUtrunc_)
S′, dS′ = arrayify(Strunc, dStrunc_)
Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_)
svd_trunc_pullback!(dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′))
MatrixAlgebraKit.zero!(dU)
MatrixAlgebraKit.zero!(dS)
MatrixAlgebraKit.zero!(dVᴴ)
return NoRData(), NoRData(), NoRData()
end
return output_codual, svd_trunc_adjoint
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
# compute primal
Expand Down Expand Up @@ -355,6 +460,43 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
return output_codual, svd_trunc_adjoint
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
alg = Mooncake.primal(alg_dalg)
Ac = copy(A)
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
USVᴴc = copy.(USVᴴ)
output = svd_trunc_no_error!(A, USVᴴ, alg)
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
# pass). For many types this is done automatically when the forward step returns, but
# not for nested structs with various fields (like Diagonal{Complex})
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
function svd_trunc_adjoint(::NoRData)
copy!(A, Ac)
copy!(U, USVᴴc[1])
copy!(S, USVᴴc[2])
copy!(Vᴴ, USVᴴc[3])
Utrunc, Strunc, Vᴴtrunc = Mooncake.primal(output_codual)
dUtrunc_, dStrunc_, dVᴴtrunc_ = Mooncake.tangent(output_codual)
U′, dU′ = arrayify(Utrunc, dUtrunc_)
S′, dS′ = arrayify(Strunc, dStrunc_)
Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_)
svd_trunc_pullback!(dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′))
MatrixAlgebraKit.zero!(dU)
MatrixAlgebraKit.zero!(dS)
MatrixAlgebraKit.zero!(dVᴴ)
return NoRData(), NoRData(), NoRData()
end
return output_codual, svd_trunc_adjoint
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual)
# compute primal
Expand Down
Loading