From 483a438130b3687bc77fa3e537df4f184c59a24b Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 21 Jan 2026 19:48:27 +0100 Subject: [PATCH 1/2] Add tests/fixes for inplace rules --- .../MatrixAlgebraKitMooncakeExt.jl | 162 ++++++++++++++++- test/mooncake.jl | 171 +++++++++++------- 2 files changed, 258 insertions(+), 75 deletions(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index f6feda8b..7fe18e07 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt using Mooncake using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive using MatrixAlgebraKit -using MatrixAlgebraKit: inv_safe, diagview, copy_input +using MatrixAlgebraKit: inv_safe, diagview, copy_input, zero! using MatrixAlgebraKit: qr_pullback!, lq_pullback! using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback! @@ -52,11 +52,11 @@ for (f!, f, pb, adj) in ( $f!(A, args, Mooncake.primal(alg_dalg)) function $adj(::NoRData) copy!(A, Ac) - $pb(dA, A, (arg1, arg2), (darg1, darg2)) copy!(arg1, arg1c) copy!(arg2, arg2c) - MatrixAlgebraKit.zero!(darg1) - MatrixAlgebraKit.zero!(darg2) + $pb(dA, A, (arg1, arg2), (darg1, darg2)) + zero!(darg1) + zero!(darg2) return NoRData(), NoRData(), NoRData(), NoRData() end return args_dargs, $adj @@ -76,8 +76,8 @@ for (f!, f, pb, adj) in ( arg1, darg1 = arrayify(arg1, darg1_) arg2, darg2 = arrayify(arg2, darg2_) $pb(dA, A, (arg1, arg2), (darg1, darg2)) - MatrixAlgebraKit.zero!(darg1) - MatrixAlgebraKit.zero!(darg2) + zero!(darg1) + zero!(darg2) return NoRData(), NoRData(), NoRData() end return output_codual, $adj @@ -99,8 +99,8 @@ for (f!, f, pb, adj) in ( $f!(A, arg, Mooncake.primal(alg_dalg)) function $adj(::NoRData) copy!(A, Ac) - $pb(dA, A, arg, darg) copy!(arg, argc) + $pb(dA, A, arg, darg) MatrixAlgebraKit.zero!(darg) return NoRData(), NoRData(), NoRData(), NoRData() end @@ -137,6 +137,7 @@ for (f!, f, f_full, pb, adj) in ( copy!(D, diagview(DV[1])) V = DV[2] function $adj(::NoRData) + copy!(D, diagview(DV[1])) $pb(dA, A, DV, dD) MatrixAlgebraKit.zero!(dD) return NoRData(), NoRData(), NoRData(), NoRData() @@ -163,12 +164,43 @@ for (f!, f, f_full, pb, adj) in ( end end -for (f, f_ne, pb, adj) in ( - (:eig_trunc, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_adjoint), - (:eigh_trunc, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_adjoint), +for (f!, f, f_ne!, f_ne, pb, adj) in ( + (:eig_trunc!, :eig_trunc, :eig_trunc_no_error!, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_adjoint), + (:eigh_trunc!, :eigh_trunc, :eigh_trunc_no_error!, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_adjoint), ) @eval begin + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual) + # compute primal + A, dA = arrayify(A_dA) + DV = Mooncake.primal(DV_dDV) + dDV = Mooncake.tangent(DV_dDV) + Ac = copy(A) + DVc = copy.(DV) + alg = Mooncake.primal(alg_dalg) + output = $f!(A, DV, alg) + # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal + # of ComplexF32) into the correct **forwards** data type (since we are now in the forward + # pass). For many types this is done automatically when the forward step returns, but + # not for nested structs with various fields (like Diagonal{Complex}) + output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + function $adj(dy::Tuple{NoRData, NoRData, T}) where {T <: Real} + copy!(A, Ac) + copy!(DV[1], DVc[1]) + copy!(DV[2], DVc[2]) + Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual) + dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual) + abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error" + D′, dD′ = arrayify(Dtrunc, dDtrunc_) + V′, dV′ = arrayify(Vtrunc, dVtrunc_) + $pb(dA, A, (D′, V′), (dD′, dV′)) + MatrixAlgebraKit.zero!(dD) + MatrixAlgebraKit.zero!(dV) + return NoRData(), NoRData(), NoRData() + end + return output_codual, $adj + end function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -192,7 +224,37 @@ for (f, f_ne, pb, adj) in ( end return output_codual, $adj end + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f_ne!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual) + # compute primal + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + DV = Mooncake.primal(DV_dDV) + dDV = Mooncake.tangent(DV_dDV) + Ac = copy(A) + DVc = copy.(DV) + output = $f_ne(A, DV, alg) + # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal + # of ComplexF32) into the correct **forwards** data type (since we are now in the forward + # pass). For many types this is done automatically when the forward step returns, but + # not for nested structs with various fields (like Diagonal{Complex}) + output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + function $adj(::NoRData) + copy!(A, Ac) + copy!(DV[1], DVc[1]) + copy!(DV[2], DVc[2]) + Dtrunc, Vtrunc = Mooncake.primal(output_codual) + dDtrunc_, dVtrunc_ = Mooncake.tangent(output_codual) + D′, dD′ = arrayify(Dtrunc, dDtrunc_) + V′, dV′ = arrayify(Vtrunc, dVtrunc_) + $pb(dA, A, (D′, V′), (dD′, dV′)) + MatrixAlgebraKit.zero!(dD) + MatrixAlgebraKit.zero!(dV) + return NoRData(), NoRData(), NoRData() + end + return output_codual, $adj + end function Mooncake.rrule!!(::CoDual{typeof($f_ne)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -232,9 +294,13 @@ for (f!, f) in ( U, dU = arrayify(USVᴴ[1], dUSVᴴ[1]) S, dS = arrayify(USVᴴ[2], dUSVᴴ[2]) Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) + USVᴴc = copy.(USVᴴ) output = $f!(A, Mooncake.primal(alg_dalg)) function svd_adjoint(::NoRData) copy!(A, Ac) + copy!(U, USVᴴc[1]) + copy!(S, USVᴴc[2]) + copy!(Vᴴ, USVᴴc[3]) if $(f! == svd_compact!) svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) else # full @@ -301,6 +367,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua function svd_vals_adjoint(::NoRData) svd_vals_pullback!(dA, A, USVᴴ, dS) MatrixAlgebraKit.zero!(dS) + copy!(S, diagview(USVᴴ[2])) return NoRData(), NoRData(), NoRData(), NoRData() end return S_dS, svd_vals_adjoint @@ -326,6 +393,44 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co return S_codual, svd_vals_adjoint end +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) + # compute primal + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + Ac = copy(A) + USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) + dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ) + U, dU = arrayify(USVᴴ[1], dUSVᴴ[1]) + S, dS = arrayify(USVᴴ[2], dUSVᴴ[2]) + Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) + USVᴴc = copy.(USVᴴ) + output = svd_trunc!(A, USVᴴ, alg) + # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal + # of ComplexF32) into the correct **forwards** data type (since we are now in the forward + # pass). For many types this is done automatically when the forward step returns, but + # not for nested structs with various fields (like Diagonal{Complex}) + output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real} + copy!(A, Ac) + copy!(U, USVᴴc[1]) + copy!(S, USVᴴc[2]) + copy!(Vᴴ, USVᴴc[3]) + Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual) + dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual) + abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc does not yet support non-zero tangent for the truncation error" + U′, dU′ = arrayify(Utrunc, dUtrunc_) + S′, dS′ = arrayify(Strunc, dStrunc_) + Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_) + svd_trunc_pullback!(dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′)) + MatrixAlgebraKit.zero!(dU) + MatrixAlgebraKit.zero!(dS) + MatrixAlgebraKit.zero!(dVᴴ) + return NoRData(), NoRData(), NoRData() + end + return output_codual, svd_trunc_adjoint +end + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal @@ -355,6 +460,43 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C return output_codual, svd_trunc_adjoint end +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) + # compute primal + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + Ac = copy(A) + USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) + dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ) + U, dU = arrayify(USVᴴ[1], dUSVᴴ[1]) + S, dS = arrayify(USVᴴ[2], dUSVᴴ[2]) + Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) + USVᴴc = copy.(USVᴴ) + output = svd_trunc_no_error!(A, USVᴴ, alg) + # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal + # of ComplexF32) into the correct **forwards** data type (since we are now in the forward + # pass). For many types this is done automatically when the forward step returns, but + # not for nested structs with various fields (like Diagonal{Complex}) + output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + function svd_trunc_adjoint(::NoRData) + copy!(A, Ac) + copy!(U, USVᴴc[1]) + copy!(S, USVᴴc[2]) + copy!(Vᴴ, USVᴴc[3]) + Utrunc, Strunc, Vᴴtrunc = Mooncake.primal(output_codual) + dUtrunc_, dStrunc_, dVᴴtrunc_ = Mooncake.tangent(output_codual) + U′, dU′ = arrayify(Utrunc, dUtrunc_) + S′, dS′ = arrayify(Strunc, dStrunc_) + Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_) + svd_trunc_pullback!(dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′)) + MatrixAlgebraKit.zero!(dU) + MatrixAlgebraKit.zero!(dS) + MatrixAlgebraKit.zero!(dVᴴ) + return NoRData(), NoRData(), NoRData() + end + return output_codual, svd_trunc_adjoint +end + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal diff --git a/test/mooncake.jl b/test/mooncake.jl index 760102b1..13ccc048 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -22,26 +22,43 @@ make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), T...) make_mooncake_fdata(x) = make_mooncake_tangent(x) make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),)) -ETs = (Float32, ComplexF64) - +ETs = (Float64, ComplexF64) # no `alg` argument -function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) +function _get_copying_derivative(f, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) dA_copy = make_mooncake_tangent(copy(ΔA)) A_copy = copy(A) dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy)) + copy_out, copy_pb!! = rrule(Mooncake.CoDual(f, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy)) + if args isa Tuple + copyto!.(Mooncake.tangent(copy_out), dargs_copy) + else + copyto!(Mooncake.tangent(copy_out), dargs_copy) + end + @test Mooncake.primal(copy_out) ≈ args copy_pb!!(rdata) - return dA_copy + return dA_copy, Mooncake.tangent(copy_out) end # `alg` argument -function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) +function _get_copying_derivative(f, rrule, A, ΔA, args, Δargs, alg, rdata) dA_copy = make_mooncake_tangent(copy(ΔA)) A_copy = copy(A) dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData())) + copy_out, copy_pb!! = rrule(Mooncake.CoDual(f, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(alg, Mooncake.NoFData())) + if args isa Tuple + copyto!.(Mooncake.tangent(copy_out), dargs_copy) + else + copyto!(Mooncake.tangent(copy_out), dargs_copy) + end + if args isa Tuple + for (arg, out) in zip(args, Mooncake.primal(copy_out)) + @test out ≈ arg + end + else + @test Mooncake.primal(copy_out) ≈ args + end copy_pb!!(rdata) - return dA_copy + return dA_copy, Mooncake.tangent(copy_out) end function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata) @@ -60,13 +77,13 @@ function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata) inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) end inplace_pb!!(rdata) - return dA_inplace + return dA_inplace, dargs_inplace end function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) dA_inplace = make_mooncake_tangent(copy(ΔA)) A_inplace = copy(A) - dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) + dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(Δargs) : make_mooncake_fdata(Δargs) # not every f! has a handwritten rrule!! inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) @@ -79,7 +96,7 @@ function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) end inplace_pb!!(rdata) - return dA_inplace + return dA_inplace, dargs_inplace end """ @@ -100,37 +117,39 @@ The arguments to this function are: - `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do) """ function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) - f_c = isnothing(alg) ? (A, args) -> f!(MatrixAlgebraKit.copy_input(f, A), args) : (A, args, alg) -> f!(MatrixAlgebraKit.copy_input(f, A), args, alg) - sig = isnothing(alg) ? Tuple{typeof(f_c), typeof(A), typeof(args)} : Tuple{typeof(f_c), typeof(A), typeof(args), typeof(alg)} + sig = isnothing(alg) ? Tuple{typeof(f), typeof(A)} : Tuple{typeof(f), typeof(A), typeof(alg)} rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) rrule = Mooncake.build_rrule(rvs_interp, sig) ΔA = randn(rng, eltype(A), size(A)) - dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) - dA_inplace = _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) + copy_args = isa(args, Tuple) ? copy.(args) : copy(args) + inplace_args = isa(args, Tuple) ? copy.(args) : copy(args) + dA_copy, dargs_copy = _get_copying_derivative(f, rrule, A, ΔA, copy_args, Δargs, alg, rdata) + dA_inplace, dargs_inplace = _get_inplace_derivative(f!, A, ΔA, inplace_args, Δargs, alg, rdata) dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2] dA_copy_ = Mooncake.arrayify(A, dA_copy)[2] @test dA_inplace_ ≈ dA_copy_ + @test copy_args == inplace_args + @test dargs_copy == dargs_inplace return end @timedtestset "QR AD Rules with eltype $T" for T in ETs rng = StableRNG(12345) m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) + @testset "size ($m, $n)" for n in (17,) # m, 23) atol = rtol = m * n * precision(T) A = randn(rng, T, m, n) minmn = min(m, n) @testset for alg in ( LAPACK_HouseholderQR(), - LAPACK_HouseholderQR(; positive = true), + #LAPACK_HouseholderQR(; positive = true), ) @testset "qr_compact" begin QR = qr_compact(A, alg) - Q = randn(rng, T, m, minmn) - R = randn(rng, T, minmn, n) - Mooncake.TestUtils.test_rule(rng, qr_compact, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) + Q, R = QR + Mooncake.TestUtils.test_rule(rng, qr_compact, A, alg; mode = Mooncake.ReverseMode, atol, rtol) test_pullbacks_match(rng, qr_compact!, qr_compact, A, (Q, R), (randn(rng, T, m, minmn), randn(rng, T, minmn, n)), alg) end @testset "qr_null" begin @@ -138,10 +157,10 @@ end ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) N = qr_null(A, alg) dN = make_mooncake_tangent(copy(ΔN)) - Mooncake.TestUtils.test_rule(rng, qr_null, A, alg; mode = Mooncake.ReverseMode, output_tangent = dN, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, qr_null, A, alg; mode = Mooncake.ReverseMode, output_tangent = dN, atol, rtol) test_pullbacks_match(rng, qr_null!, qr_null, A, N, ΔN, alg) end - @testset "qr_full" begin + #=@testset "qr_full" begin Q, R = qr_full(A, alg) Q1 = view(Q, 1:m, 1:minmn) ΔQ = randn(rng, T, m, m) @@ -151,7 +170,8 @@ end dQ = make_mooncake_tangent(copy(ΔQ)) dR = make_mooncake_tangent(copy(ΔR)) dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR) - Mooncake.TestUtils.test_rule(rng, qr_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dQR, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, qr_full!, A, (Q, R), alg; mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) + Mooncake.TestUtils.test_rule(rng, qr_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) test_pullbacks_match(rng, qr_full!, qr_full, A, (Q, R), (ΔQ, ΔR), alg) end @testset "qr_compact - rank-deficient A" begin @@ -169,13 +189,14 @@ end dQ = make_mooncake_tangent(copy(ΔQ)) dR = make_mooncake_tangent(copy(ΔR)) dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR) - Mooncake.TestUtils.test_rule(rng, qr_compact, Ard, alg; mode = Mooncake.ReverseMode, output_tangent = dQR, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, qr_compact!, Ard, QR, alg; mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) + Mooncake.TestUtils.test_rule(rng, qr_compact, Ard, alg; mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, (Q, R), (ΔQ, ΔR), alg) - end + end=# end end end - +#= @timedtestset "LQ AD Rules with eltype $T" for T in ETs rng = StableRNG(12345) m = 19 @@ -189,7 +210,8 @@ end ) @testset "lq_compact" begin L, Q = lq_compact(A, alg) - Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, lq_compact!, A, (L, Q), alg; mode = Mooncake.ReverseMode, atol, rtol) + Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; mode = Mooncake.ReverseMode, atol, rtol) test_pullbacks_match(rng, lq_compact!, lq_compact, A, (L, Q), (randn(rng, T, m, minmn), randn(rng, T, minmn, n)), alg) end @testset "lq_null" begin @@ -197,7 +219,8 @@ end ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q Nᴴ = randn(rng, T, max(0, n - minmn), n) dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; mode = Mooncake.ReverseMode, output_tangent = dNᴴ, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, lq_null!, A, Nᴴ, alg; mode = Mooncake.ReverseMode, output_tangent = dNᴴ, atol, rtol) + Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; mode = Mooncake.ReverseMode, output_tangent = dNᴴ, atol, rtol) test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ, alg) end @testset "lq_full" begin @@ -210,7 +233,8 @@ end dL = make_mooncake_tangent(ΔL) dQ = make_mooncake_tangent(ΔQ) dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ) - Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, lq_full!, A, (L, Q), alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, atol, rtol) + Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, atol, rtol) test_pullbacks_match(rng, lq_full!, lq_full, A, (L, Q), (ΔL, ΔQ), alg) end @testset "lq_compact - rank-deficient A" begin @@ -227,7 +251,8 @@ end dL = make_mooncake_tangent(ΔL) dQ = make_mooncake_tangent(ΔQ) dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ) - Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, lq_compact!, Ard, (L, Q), alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, atol, rtol) + Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, atol, rtol) test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, (L, Q), (ΔL, ΔQ), alg) end end @@ -243,7 +268,7 @@ end D, V = DV Ddiag = diagview(D) ΔV = randn(rng, complex(T), m, m) - ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol = atol) + ΔV = remove_eiggauge_dependence!(ΔV, D, V) ΔD = randn(rng, complex(T), m, m) ΔD2 = Diagonal(randn(rng, complex(T), m)) @@ -256,11 +281,13 @@ end #LAPACK_Expert(), # expensive on CI ) @testset "eig_full" begin - Mooncake.TestUtils.test_rule(rng, eig_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, eig_full!, A, DV, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, atol, rtol) + Mooncake.TestUtils.test_rule(rng, eig_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, atol, rtol) test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg) end @testset "eig_vals" begin - Mooncake.TestUtils.test_rule(rng, eig_vals, A, alg; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, eig_vals!, A, D, alg; mode = Mooncake.ReverseMode, atol, rtol) + Mooncake.TestUtils.test_rule(rng, eig_vals, A, alg; mode = Mooncake.ReverseMode, atol, rtol) test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg) end @testset "eig_trunc" begin @@ -274,10 +301,12 @@ end dDtrunc = make_mooncake_tangent(ΔDtrunc) dVtrunc = make_mooncake_tangent(ΔVtrunc) dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, eig_trunc!, A, DV, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) - Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error!, A, DV, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) end truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) @@ -289,10 +318,12 @@ end dDtrunc = make_mooncake_tangent(ΔDtrunc) dVtrunc = make_mooncake_tangent(ΔVtrunc) dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, eig_trunc!, A, DV, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) - Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error!, A, DV, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) end end @@ -351,7 +382,7 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc_no_error), A) = MatrixAlgeb D, V = eigh_full(A) Ddiag = diagview(D) ΔV = randn(rng, T, m, m) - ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol) + ΔV = remove_eighgauge_dependence!(ΔV, D, V) ΔD = randn(rng, real(T), m, m) ΔD2 = Diagonal(randn(rng, real(T), m)) dD = make_mooncake_tangent(ΔD2) @@ -364,11 +395,11 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc_no_error), A) = MatrixAlgeb #LAPACK_MultipleRelativelyRobustRepresentations(), # expensive on CI ) @testset "eigh_full" begin - Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol, rtol) test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg) end @testset "eigh_vals" begin - Mooncake.TestUtils.test_rule(rng, copy_eigh_vals, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, copy_eigh_vals, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol, rtol) test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg) end @testset "eigh_trunc" begin @@ -382,10 +413,10 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc_no_error), A) = MatrixAlgeb dDtrunc = make_mooncake_tangent(ΔDtrunc) dVtrunc = make_mooncake_tangent(ΔVtrunc) dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol, is_primitive = false) test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol, is_primitive = false) test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) end truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) @@ -397,10 +428,10 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc_no_error), A) = MatrixAlgeb dDtrunc = make_mooncake_tangent(ΔDtrunc) dVtrunc = make_mooncake_tangent(ΔVtrunc) dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol, is_primitive = false) test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol, is_primitive = false) test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) end end @@ -423,12 +454,13 @@ end ΔS2 = Diagonal(randn(rng, real(T), minmn)) ΔVᴴ = randn(rng, T, minmn, n) U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) dS = make_mooncake_tangent(ΔS2) dU = make_mooncake_tangent(ΔU) dVᴴ = make_mooncake_tangent(ΔVᴴ) dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dU, dS, dVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_compact, A, alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, svd_compact!, A, (U, S, Vᴴ), alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) + Mooncake.TestUtils.test_rule(rng, svd_compact, A, alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) test_pullbacks_match(rng, svd_compact!, svd_compact, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), alg) end @testset "svd_full" begin @@ -437,7 +469,7 @@ end ΔS2 = Diagonal(randn(rng, real(T), minmn)) ΔVᴴ = randn(rng, T, minmn, n) U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) ΔUfull = zeros(T, m, m) ΔSfull = zeros(real(T), m, n) ΔVᴴfull = zeros(T, n, n) @@ -449,12 +481,14 @@ end dU = make_mooncake_tangent(ΔUfull) dVᴴ = make_mooncake_tangent(ΔVᴴfull) dUSVᴴ = Mooncake.build_tangent(typeof((ΔUfull, ΔSfull, ΔVᴴfull)), dU, dS, dVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, svd_full!, A, (U, S, Vᴴ), alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) + Mooncake.TestUtils.test_rule(rng, svd_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) test_pullbacks_match(rng, svd_full!, svd_full, A, (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull), alg) end @testset "svd_vals" begin - Mooncake.TestUtils.test_rule(rng, svd_vals, A, alg; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) S = svd_vals(A, alg) + Mooncake.TestUtils.test_rule(rng, svd_vals!, A, S, alg; mode = Mooncake.ReverseMode, atol, rtol) + Mooncake.TestUtils.test_rule(rng, svd_vals, A, alg; mode = Mooncake.ReverseMode, atol, rtol) test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, randn(rng, real(T), minmn), alg) end @testset "svd_trunc" begin @@ -464,7 +498,7 @@ end ΔS = randn(rng, real(T), minmn, minmn) ΔS2 = Diagonal(randn(rng, real(T), minmn)) ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) truncalg = TruncatedAlgorithm(alg, truncrank(r)) ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) Strunc = Diagonal(diagview(S)[ind]) @@ -478,10 +512,12 @@ end dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) ϵ = zero(real(T)) dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, svd_trunc!, A, USVᴴ, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol, rtol) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol, rtol) test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error!, A, USVᴴ, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg) end @testset "trunctol" begin @@ -490,7 +526,7 @@ end ΔS = randn(rng, real(T), minmn, minmn) ΔS2 = Diagonal(randn(rng, real(T), minmn)) ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) Strunc = Diagonal(diagview(S)[ind]) @@ -504,10 +540,12 @@ end dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) ϵ = zero(real(T)) dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, svd_trunc!, A, USVᴴ, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol, rtol) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol, rtol) test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error!, A, USVᴴ, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg) end end @@ -529,11 +567,13 @@ end ) if m >= n WP = left_polar(A, alg) - Mooncake.TestUtils.test_rule(rng, left_polar, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, left_polar!, A, WP, alg; mode = Mooncake.ReverseMode, atol, rtol) + Mooncake.TestUtils.test_rule(rng, left_polar, A, alg; mode = Mooncake.ReverseMode, atol, rtol) test_pullbacks_match(rng, left_polar!, left_polar, A, WP, (randn(rng, T, m, n), randn(rng, T, n, n)), alg) elseif m <= n PWᴴ = right_polar(A, alg) - Mooncake.TestUtils.test_rule(rng, right_polar, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, right_polar!, A, PWᴴ, alg; mode = Mooncake.ReverseMode, atol, rtol) + Mooncake.TestUtils.test_rule(rng, right_polar, A, alg; mode = Mooncake.ReverseMode, atol, rtol) test_pullbacks_match(rng, right_polar!, right_polar, A, PWᴴ, (randn(rng, T, m, m), randn(rng, T, m, n)), alg) end end @@ -562,36 +602,37 @@ MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_ A = randn(rng, T, m, n) VC = left_orth(A) CVᴴ = right_orth(A) - Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) test_pullbacks_match(rng, left_orth!, left_orth, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) - Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) test_pullbacks_match(rng, right_orth!, right_orth, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) - Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) if m >= n - Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) end N = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) dN = make_mooncake_tangent(ΔN) - Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dN) + Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false, output_tangent = dN) test_pullbacks_match(rng, ((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN) - Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) if m <= n - Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) end Nᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dNᴴ) + Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false, output_tangent = dNᴴ) test_pullbacks_match(rng, ((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) end end +=# From 8852ca97dccc2aef9bf77e65555382c397d2ab04 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 26 Jan 2026 20:56:11 +0100 Subject: [PATCH 2/2] Fix arg copying order --- .../MatrixAlgebraKitMooncakeExt.jl | 100 +++++++++--------- 1 file changed, 51 insertions(+), 49 deletions(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 7fe18e07..96a34523 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -52,9 +52,9 @@ for (f!, f, pb, adj) in ( $f!(A, args, Mooncake.primal(alg_dalg)) function $adj(::NoRData) copy!(A, Ac) + $pb(dA, A, (arg1, arg2), (darg1, darg2)) copy!(arg1, arg1c) copy!(arg2, arg2c) - $pb(dA, A, (arg1, arg2), (darg1, darg2)) zero!(darg1) zero!(darg2) return NoRData(), NoRData(), NoRData(), NoRData() @@ -99,9 +99,9 @@ for (f!, f, pb, adj) in ( $f!(A, arg, Mooncake.primal(alg_dalg)) function $adj(::NoRData) copy!(A, Ac) - copy!(arg, argc) $pb(dA, A, arg, darg) - MatrixAlgebraKit.zero!(darg) + copy!(arg, argc) + zero!(darg) return NoRData(), NoRData(), NoRData(), NoRData() end return arg_darg, $adj @@ -114,7 +114,7 @@ for (f!, f, pb, adj) in ( function $adj(::NoRData) arg, darg = arrayify(output_codual) $pb(dA, A, arg, darg) - MatrixAlgebraKit.zero!(darg) + zero!(darg) return NoRData(), NoRData(), NoRData() end return output_codual, $adj @@ -132,14 +132,15 @@ for (f!, f, f_full, pb, adj) in ( # compute primal A, dA = arrayify(A_dA) D, dD = arrayify(D_dD) + Dc = copy(D) # update primal DV = $f_full(A, Mooncake.primal(alg_dalg)) copy!(D, diagview(DV[1])) V = DV[2] function $adj(::NoRData) - copy!(D, diagview(DV[1])) $pb(dA, A, DV, dD) - MatrixAlgebraKit.zero!(dD) + copy!(D, Dc) + zero!(dD) return NoRData(), NoRData(), NoRData(), NoRData() end return D_dD, $adj @@ -156,7 +157,7 @@ for (f!, f, f_full, pb, adj) in ( function $adj(::NoRData) D, dD = arrayify(output_codual) $pb(dA, A, DV, dD) - MatrixAlgebraKit.zero!(dD) + zero!(dD) return NoRData(), NoRData(), NoRData() end return output_codual, $adj @@ -187,16 +188,16 @@ for (f!, f, f_ne!, f_ne, pb, adj) in ( output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) function $adj(dy::Tuple{NoRData, NoRData, T}) where {T <: Real} copy!(A, Ac) - copy!(DV[1], DVc[1]) - copy!(DV[2], DVc[2]) Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual) dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual) abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error" D′, dD′ = arrayify(Dtrunc, dDtrunc_) V′, dV′ = arrayify(Vtrunc, dVtrunc_) $pb(dA, A, (D′, V′), (dD′, dV′)) - MatrixAlgebraKit.zero!(dD) - MatrixAlgebraKit.zero!(dV) + copy!(DV[1], DVc[1]) + copy!(DV[2], DVc[2]) + zero!(dD) + zero!(dV) return NoRData(), NoRData(), NoRData() end return output_codual, $adj @@ -218,8 +219,8 @@ for (f!, f, f_ne!, f_ne, pb, adj) in ( D, dD = arrayify(Dtrunc, dDtrunc_) V, dV = arrayify(Vtrunc, dVtrunc_) $pb(dA, A, (D, V), (dD, dV)) - MatrixAlgebraKit.zero!(dD) - MatrixAlgebraKit.zero!(dV) + zero!(dD) + zero!(dV) return NoRData(), NoRData(), NoRData() end return output_codual, $adj @@ -242,15 +243,15 @@ for (f!, f, f_ne!, f_ne, pb, adj) in ( output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) function $adj(::NoRData) copy!(A, Ac) - copy!(DV[1], DVc[1]) - copy!(DV[2], DVc[2]) Dtrunc, Vtrunc = Mooncake.primal(output_codual) dDtrunc_, dVtrunc_ = Mooncake.tangent(output_codual) D′, dD′ = arrayify(Dtrunc, dDtrunc_) V′, dV′ = arrayify(Vtrunc, dVtrunc_) $pb(dA, A, (D′, V′), (dD′, dV′)) - MatrixAlgebraKit.zero!(dD) - MatrixAlgebraKit.zero!(dV) + copy!(DV[1], DVc[1]) + copy!(DV[2], DVc[2]) + zero!(dD) + zero!(dV) return NoRData(), NoRData(), NoRData() end return output_codual, $adj @@ -271,8 +272,8 @@ for (f!, f, f_ne!, f_ne, pb, adj) in ( D, dD = arrayify(Dtrunc, dDtrunc_) V, dV = arrayify(Vtrunc, dVtrunc_) $pb(dA, A, (D, V), (dD, dV)) - MatrixAlgebraKit.zero!(dD) - MatrixAlgebraKit.zero!(dV) + zero!(dD) + zero!(dV) return NoRData(), NoRData(), NoRData() end return output_codual, $adj @@ -298,9 +299,6 @@ for (f!, f) in ( output = $f!(A, Mooncake.primal(alg_dalg)) function svd_adjoint(::NoRData) copy!(A, Ac) - copy!(U, USVᴴc[1]) - copy!(S, USVᴴc[2]) - copy!(Vᴴ, USVᴴc[3]) if $(f! == svd_compact!) svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) else # full @@ -313,9 +311,12 @@ for (f!, f) in ( vdVᴴ = view(dVᴴ, 1:minmn, :) svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) end - MatrixAlgebraKit.zero!(dU) - MatrixAlgebraKit.zero!(dS) - MatrixAlgebraKit.zero!(dVᴴ) + copy!(U, USVᴴc[1]) + copy!(S, USVᴴc[2]) + copy!(Vᴴ, USVᴴc[3]) + zero!(dU) + zero!(dS) + zero!(dVᴴ) return NoRData(), NoRData(), NoRData(), NoRData() end return CoDual(output, dUSVᴴ), svd_adjoint @@ -347,9 +348,9 @@ for (f!, f) in ( vdVᴴ = view(dVᴴ, 1:minmn, :) svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) end - MatrixAlgebraKit.zero!(dU) - MatrixAlgebraKit.zero!(dS) - MatrixAlgebraKit.zero!(dVᴴ) + zero!(dU) + zero!(dS) + zero!(dVᴴ) return NoRData(), NoRData(), NoRData() end return USVᴴ_codual, svd_adjoint @@ -362,12 +363,13 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua # compute primal A, dA = arrayify(A_dA) S, dS = arrayify(S_dS) + Sc = copy(S) USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) copy!(S, diagview(USVᴴ[2])) function svd_vals_adjoint(::NoRData) svd_vals_pullback!(dA, A, USVᴴ, dS) - MatrixAlgebraKit.zero!(dS) - copy!(S, diagview(USVᴴ[2])) + zero!(dS) + copy!(S, Sc) return NoRData(), NoRData(), NoRData(), NoRData() end return S_dS, svd_vals_adjoint @@ -387,7 +389,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co function svd_vals_adjoint(::NoRData) S, dS = arrayify(S_codual) svd_vals_pullback!(dA, A, USVᴴ, dS) - MatrixAlgebraKit.zero!(dS) + zero!(dS) return NoRData(), NoRData(), NoRData() end return S_codual, svd_vals_adjoint @@ -413,9 +415,6 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real} copy!(A, Ac) - copy!(U, USVᴴc[1]) - copy!(S, USVᴴc[2]) - copy!(Vᴴ, USVᴴc[3]) Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual) dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual) abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc does not yet support non-zero tangent for the truncation error" @@ -423,9 +422,12 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS S′, dS′ = arrayify(Strunc, dStrunc_) Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_) svd_trunc_pullback!(dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′)) - MatrixAlgebraKit.zero!(dU) - MatrixAlgebraKit.zero!(dS) - MatrixAlgebraKit.zero!(dVᴴ) + copy!(U, USVᴴc[1]) + copy!(S, USVᴴc[2]) + copy!(Vᴴ, USVᴴc[3]) + zero!(dU) + zero!(dS) + zero!(dVᴴ) return NoRData(), NoRData(), NoRData() end return output_codual, svd_trunc_adjoint @@ -452,9 +454,9 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C S, dS = arrayify(Strunc, dStrunc_) Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_) svd_trunc_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) - MatrixAlgebraKit.zero!(dU) - MatrixAlgebraKit.zero!(dS) - MatrixAlgebraKit.zero!(dVᴴ) + zero!(dU) + zero!(dS) + zero!(dVᴴ) return NoRData(), NoRData(), NoRData() end return output_codual, svd_trunc_adjoint @@ -480,18 +482,18 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, US output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) function svd_trunc_adjoint(::NoRData) copy!(A, Ac) - copy!(U, USVᴴc[1]) - copy!(S, USVᴴc[2]) - copy!(Vᴴ, USVᴴc[3]) Utrunc, Strunc, Vᴴtrunc = Mooncake.primal(output_codual) dUtrunc_, dStrunc_, dVᴴtrunc_ = Mooncake.tangent(output_codual) U′, dU′ = arrayify(Utrunc, dUtrunc_) S′, dS′ = arrayify(Strunc, dStrunc_) Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_) svd_trunc_pullback!(dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′)) - MatrixAlgebraKit.zero!(dU) - MatrixAlgebraKit.zero!(dS) - MatrixAlgebraKit.zero!(dVᴴ) + copy!(U, USVᴴc[1]) + copy!(S, USVᴴc[2]) + copy!(Vᴴ, USVᴴc[3]) + zero!(dU) + zero!(dS) + zero!(dVᴴ) return NoRData(), NoRData(), NoRData() end return output_codual, svd_trunc_adjoint @@ -517,9 +519,9 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al S, dS = arrayify(Strunc, dStrunc_) Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_) svd_trunc_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) - MatrixAlgebraKit.zero!(dU) - MatrixAlgebraKit.zero!(dS) - MatrixAlgebraKit.zero!(dVᴴ) + zero!(dU) + zero!(dS) + zero!(dVᴴ) return NoRData(), NoRData(), NoRData() end return output_codual, svd_trunc_adjoint