From fbfb889dbeb6de51c36f36950ff2f66dfe4cc24a Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 23 Dec 2025 02:55:12 -0500 Subject: [PATCH 01/18] Use Testsuite for AD tests --- Project.toml | 2 +- .../MatrixAlgebraKitAMDGPUExt.jl | 9 +- .../MatrixAlgebraKitCUDAExt.jl | 20 +- ext/MatrixAlgebraKitChainRulesCoreExt.jl | 21 + .../MatrixAlgebraKitMooncakeExt.jl | 8 +- src/common/defaults.jl | 1 + src/common/pullbacks.jl | 3 + src/pullbacks/eig.jl | 24 +- src/pullbacks/eigh.jl | 24 +- src/pullbacks/lq.jl | 80 ++- src/pullbacks/polar.jl | 2 +- src/pullbacks/qr.jl | 78 ++- src/pullbacks/svd.jl | 24 +- test/ad_utils.jl | 62 -- test/chainrules.jl | 592 +---------------- test/enzyme.jl | 507 +------------- test/mooncake.jl | 618 +----------------- test/testsuite/TestSuite.jl | 4 + test/testsuite/ad_utils.jl | 423 ++++++++++++ test/testsuite/chainrules.jl | 612 +++++++++++++++++ test/testsuite/enzyme.jl | 459 +++++++++++++ test/testsuite/mooncake.jl | 481 ++++++++++++++ 22 files changed, 2231 insertions(+), 1823 deletions(-) delete mode 100644 test/ad_utils.jl create mode 100644 test/testsuite/ad_utils.jl create mode 100644 test/testsuite/chainrules.jl create mode 100644 test/testsuite/enzyme.jl create mode 100644 test/testsuite/mooncake.jl diff --git a/Project.toml b/Project.toml index dbf692a5..048e597f 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ Enzyme = "0.13.118" EnzymeTestUtils = "0.2.5" JET = "0.9, 0.10" LinearAlgebra = "1" -Mooncake = "0.4.183" +Mooncake = "0.4.195" ParallelTestRunner = "2" Random = "1" SafeTestsets = "0.1" diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 0ca43183..8a7c2ef2 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -7,7 +7,7 @@ using MatrixAlgebraKit: diagview, sign_safe using MatrixAlgebraKit: LQViaTransposedQR, TruncationStrategy, 
NoTruncation, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj! -import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx! +import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank using AMDGPU using LinearAlgebra using LinearAlgebra: BlasFloat @@ -171,4 +171,11 @@ end MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) +function _sylvester(A::AnyROCMatrix, B::AnyROCMatrix, C::AnyROCMatrix) + hX = sylvester(collect(A), collect(B), collect(C)) + return ROCArray(hX) +end + +svd_rank(S::AnyROCVector, rank_atol) = findlast(s -> s ≥ rank_atol, S) + end diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 8bb09db1..ccc03a56 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -3,11 +3,11 @@ module MatrixAlgebraKitCUDAExt using MatrixAlgebraKit using MatrixAlgebraKit: @algdef, Algorithm, check_input using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! -using MatrixAlgebraKit: diagview, sign_safe +using MatrixAlgebraKit: diagview, sign_safe, default_pullback_gauge_atol, default_pullback_rank_atol using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev! -import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd! 
+import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester, svd_rank using CUDA, CUDA.CUBLAS using CUDA: i32 using LinearAlgebra @@ -195,4 +195,20 @@ end MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) +MatrixAlgebraKit.default_pullback_rank_atol(A::AnyCuArray) = eps(norm(CuArray(A), Inf))^(3 / 4) +MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray) = MatrixAlgebraKit.iszerotangent(A) ? 0 : eps(norm(CuArray(A), Inf))^(3 / 4) +function MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray, As...) + As′ = filter(!MatrixAlgebraKit.iszerotangent, (A, As...)) + return isempty(As′) ? 0 : eps(norm(CuArray.(As′), Inf))^(3 / 4) +end + +function _sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix) + # https://github.com/JuliaGPU/CUDA.jl/issues/3021 + # to add native sylvester to CUDA + hX = sylvester(collect(A), collect(B), collect(C)) + return CuArray(hX) +end + +svd_rank(S::AnyCuVector, rank_atol) = findlast(s -> s ≥ rank_atol, S) + end diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index c2de1758..400b2a79 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -95,6 +95,9 @@ for eig in (:eig, :eigh) eig_t! = Symbol(eig, "_trunc!") eig_t_pb = Symbol(eig, "_trunc_pullback") _make_eig_t_pb = Symbol("_make_", eig_t_pb) + eig_t_ne! = Symbol(eig, "_trunc_no_error!") + eig_t_ne_pb = Symbol(eig, "_trunc_no_error_pullback") + _make_eig_t_ne_pb = Symbol("_make_", eig_t_ne_pb) eig_v = Symbol(eig, "_vals") eig_v! 
= Symbol(eig_v, "!") eig_v_pb = Symbol(eig_v, "_pullback") @@ -136,6 +139,24 @@ for eig in (:eig, :eigh) end return $eig_t_pb end + function ChainRulesCore.rrule(::typeof($eig_t_ne!), A, DV, alg::TruncatedAlgorithm) + Ac = copy_input($eig_f, A) + DV = $(eig_f!)(Ac, DV, alg.alg) + DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc) + return DV′, $(_make_eig_t_ne_pb)(A, DV, ind) + end + function $(_make_eig_t_ne_pb)(A, DV, ind) + function $eig_t_ne_pb(ΔDV) + ΔA = zero(A) + ΔD, ΔV = ΔDV + MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.((ΔD, ΔV)), ind) + return NoTangent(), ΔA, ZeroTangent(), NoTangent() + end + function $eig_t_ne_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful? + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + end + return $eig_t_ne_pb + end function ChainRulesCore.rrule(::typeof($eig_v!), A, D, alg) DV = $eig_f(A, alg) function $eig_v_pb(ΔD) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index f6feda8b..217a48c2 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt using Mooncake using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive using MatrixAlgebraKit -using MatrixAlgebraKit: inv_safe, diagview, copy_input +using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output using MatrixAlgebraKit: qr_pullback!, lq_pullback! using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback! 
@@ -18,14 +18,16 @@ Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.N @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any} function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual) Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA)) - dAc = Mooncake.zero_tangent(Ac) + Ac_dAc = Mooncake.zero_fcodual(Ac) + dAc = Mooncake.tangent(Ac_dAc) function copy_input_pb(::NoRData) Mooncake.increment!!(Mooncake.tangent(A_dA), dAc) return NoRData(), NoRData(), NoRData() end - return CoDual(Ac, dAc), copy_input_pb + return Ac_dAc, copy_input_pb end +Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(initialize_output), Any, Any, Any} # two-argument in-place factorizations like LQ, QR, EIG for (f!, f, pb, adj) in ( (:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint), diff --git a/src/common/defaults.jl b/src/common/defaults.jl index dad16376..bc4160a1 100644 --- a/src/common/defaults.jl +++ b/src/common/defaults.jl @@ -34,6 +34,7 @@ default_pullback_degeneracy_atol(A) = eps(norm(A, Inf))^(3 / 4) Default tolerance for deciding what values should be considered equal to 0. 
""" default_pullback_rank_atol(A) = eps(norm(A, Inf))^(3 / 4) +default_pullback_rank_atol(A::Diagonal) = default_pullback_rank_atol(diagview(A)) """ default_hermitian_tol(A) diff --git a/src/common/pullbacks.jl b/src/common/pullbacks.jl index e45fbd94..4fe853cd 100644 --- a/src/common/pullbacks.jl +++ b/src/common/pullbacks.jl @@ -10,3 +10,6 @@ function iszerotangent end iszerotangent(::Any) = false iszerotangent(::Nothing) = true + +# fallback +_sylvester(A, B, C) = LinearAlgebra.sylvester(A, B, C) diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index 6b89b64f..a03eb3c4 100644 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -1,3 +1,15 @@ +function check_eig_cotangents( + D, VᴴΔV; + degeneracy_atol::Real = default_pullback_rank_atol(D), + gauge_atol::Real = default_pullback_gauge_atol(VᴴΔV) + ) + mask = abs.(transpose(D) .- D) .< degeneracy_atol + Δgauge = norm(view(VᴴΔV, mask)) + Δgauge ≤ gauge_atol || + @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return +end + """ eig_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV, [ind]; @@ -41,10 +53,7 @@ function eig_pullback!( length(indV) == pV || throw(DimensionMismatch()) mul!(view(VᴴΔV, :, indV), V', ΔV) - mask = abs.(transpose(D) .- D) .< degeneracy_atol - Δgauge = norm(view(VᴴΔV, mask), Inf) - Δgauge ≤ gauge_atol || - @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol) VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) @@ -129,10 +138,7 @@ function eig_trunc_pullback!( if !iszerotangent(ΔV) (n, p) == size(ΔV) || throw(DimensionMismatch()) VᴴΔV = V' * ΔV - mask = abs.(transpose(D) .- D) .< degeneracy_atol - Δgauge = norm(view(VᴴΔV, mask), Inf) - Δgauge ≤ gauge_atol || - @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol) ΔVperp = ΔV - V * inv(G) * VᴴΔV VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, 
degeneracy_atol)) @@ -150,7 +156,7 @@ function eig_trunc_pullback!( # add contribution from orthogonal complement PA = A - (A * V) / V Y = mul!(ΔVperp, PA', Z, 1, 1) - X = sylvester(PA', -Dmat', Y) + X = _sylvester(PA', -Dmat', Y) Z .+= X if eltype(ΔA) <: Real diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index 11171685..db78bd6e 100644 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -1,3 +1,15 @@ +function check_eigh_cotangents( + D, aVᴴΔV; + degeneracy_atol::Real = default_pullback_rank_atol(D), + gauge_atol::Real = default_pullback_gauge_atol(aVᴴΔV) + ) + mask = abs.(D' .- D) .< degeneracy_atol + Δgauge = norm(view(aVᴴΔV, mask)) + Δgauge ≤ gauge_atol || + @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return +end + """ eigh_pullback!( ΔA::AbstractMatrix, A, DV, ΔDV, [ind]; @@ -42,10 +54,7 @@ function eigh_pullback!( mul!(view(VᴴΔV, :, indV), V', ΔV) aVᴴΔV = project_antihermitian(VᴴΔV) # can't use in-place or recycling doesn't work - mask = abs.(D' .- D) .< degeneracy_atol - Δgauge = norm(view(aVᴴΔV, mask)) - Δgauge ≤ gauge_atol || - @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol) aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol) @@ -120,10 +129,7 @@ function eigh_trunc_pullback!( VᴴΔV = V' * ΔV aVᴴΔV = project_antihermitian!(VᴴΔV) - mask = abs.(D' .- D) .< degeneracy_atol - Δgauge = norm(view(aVᴴΔV, mask)) - Δgauge ≤ gauge_atol || - @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol) aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol) @@ -138,7 +144,7 @@ function eigh_trunc_pullback!( # add contribution from orthogonal complement W = qr_null(V) WᴴΔV = W' * ΔV - X = sylvester(W' * A * W, -Dmat, WᴴΔV) + X = _sylvester(W' * A * W, -Dmat, WᴴΔV) Z = mul!(Z, W, X, 1, 1) # put everything together: symmetrize for hermitian case diff --git 
a/src/pullbacks/lq.jl b/src/pullbacks/lq.jl index b30fe198..790ee744 100644 --- a/src/pullbacks/lq.jl +++ b/src/pullbacks/lq.jl @@ -1,3 +1,41 @@ +function check_lq_cotangents( + L, Q, ΔL, ΔQ, minmn::Int, p::Int; + gauge_atol::Real = default_pullback_gauge_atol(ΔQ) + ) + if minmn > p # case where A is rank-deficient + Δgauge = abs(zero(eltype(Q))) + if !iszerotangent(ΔQ) + # in this case the number Householder reflections will + # change upon small variations, and all of the remaining + # columns of ΔQ should be zero for a gauge-invariant + # cost function + ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :) + Δgauge = max(Δgauge, norm(ΔQ2)) + end + if !iszerotangent(ΔL) + ΔL22 = view(ΔL, (p + 1):size(L, 1), (p + 1):minmn) + Δgauge = max(Δgauge, norm(ΔL22)) + end + Δgauge ≤ gauge_atol || + @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + end + return +end + +function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = default_pullback_gauge_atol(Q1)) + # in the case where A is full rank, but there are more columns in Q than in A + # (the case of `lq_full`), there is gauge-invariant information in the + # projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary + # matrix. As the number of Householder reflections is in fixed in the full rank + # case, Q is expected to rotate smoothly (we might even be able to predict) also + # how the full Q2 will change, but this we omit for now, and we consider + # Q2' * ΔQ2 as a gauge dependent quantity. 
+ Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf) + Δgauge ≤ gauge_atol || + @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return +end + """ lq_pullback!( ΔA, A, LQ, ΔLQ; @@ -36,23 +74,7 @@ function lq_pullback!( ΔA1 = view(ΔA, 1:p, :) ΔA2 = view(ΔA, (p + 1):m, :) - if minmn > p # case where A is rank-deficient - Δgauge = abs(zero(eltype(Q))) - if !iszerotangent(ΔQ) - # in this case the number Householder reflections will - # change upon small variations, and all of the remaining - # columns of ΔQ should be zero for a gauge-invariant - # cost function - ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :) - Δgauge = max(Δgauge, norm(ΔQ2, Inf)) - end - if !iszerotangent(ΔL) - ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn) - Δgauge = max(Δgauge, norm(ΔL22, Inf)) - end - Δgauge ≤ gauge_atol || - @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - end + check_lq_cotangents(L, Q, ΔL, ΔQ, minmn, p; gauge_atol) ΔQ̃ = zero!(similar(Q, (p, n))) if !iszerotangent(ΔQ) @@ -61,17 +83,8 @@ function lq_pullback!( if p < size(Q, 1) Q2 = view(Q, (p + 1):size(Q, 1), :) ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :) - # in the case where A is full rank, but there are more columns in Q than in A - # (the case of `qr_full`), there is gauge-invariant information in the - # projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary - # matrix. As the number of Householder reflections is in fixed in the full rank - # case, Q is expected to rotate smoothly (we might even be able to predict) also - # how the full Q2 will change, but this we omit for now, and we consider - # Q2' * ΔQ2 as a gauge dependent quantity. 
ΔQ2Q1ᴴ = ΔQ2 * Q1' - Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf) - Δgauge ≤ gauge_atol || - @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol) ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1) end end @@ -102,6 +115,14 @@ function lq_pullback!( return ΔA end +function check_lq_null_cotangents(Nᴴ, ΔNᴴ; gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ)) + aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ') + Δgauge = norm(aNᴴΔN) + Δgauge ≤ gauge_atol || + @warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return +end + """ lq_null_pullback!( ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ; @@ -118,10 +139,7 @@ function lq_null_pullback!( gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ) ) if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0 - aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ') - Δgauge = norm(aNᴴΔN) - Δgauge ≤ gauge_atol || - @warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)" + check_lq_null_cotangents(Nᴴ, ΔNᴴ; gauge_atol) L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here? X = ldiv!(LowerTriangular(L)', Q * ΔNᴴ') ΔA = mul!(ΔA, X, Nᴴ, -1, 1) diff --git a/src/pullbacks/polar.jl b/src/pullbacks/polar.jl index 1c6de509..8ada8575 100644 --- a/src/pullbacks/polar.jl +++ b/src/pullbacks/polar.jl @@ -16,7 +16,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...) 
M = zero(P) !iszerotangent(ΔW) && mul!(M, W', ΔW, 1, 1) !iszerotangent(ΔP) && mul!(M, ΔP, P, -1, 1) - C = sylvester(P, P, M' - M) + C = _sylvester(P, P, M' - M) C .+= ΔP ΔA = mul!(ΔA, W, C, 1, 1) if !iszerotangent(ΔW) diff --git a/src/pullbacks/qr.jl b/src/pullbacks/qr.jl index 888029be..d92878bd 100644 --- a/src/pullbacks/qr.jl +++ b/src/pullbacks/qr.jl @@ -1,3 +1,38 @@ +function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Real = default_pullback_gauge_atol(ΔQ)) + if minmn > p # case where A is rank-deficient + Δgauge = abs(zero(eltype(Q))) + if !iszerotangent(ΔQ) + # in this case the number Householder reflections will + # change upon small variations, and all of the remaining + # columns of ΔQ should be zero for a gauge-invariant + # cost function + ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) + Δgauge = max(Δgauge, norm(ΔQ2, Inf)) + end + if !iszerotangent(ΔR) + ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2)) + Δgauge = max(Δgauge, norm(ΔR22, Inf)) + end + Δgauge ≤ gauge_atol || + @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + end + return +end + +function check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol::Real = default_pullback_gauge_atol(ΔQ2)) + # in the case where A is full rank, but there are more columns in Q than in A + # (the case of `qr_full`), there is gauge-invariant information in the + # projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary + # matrix. As the number of Householder reflections is in fixed in the full rank + # case, Q is expected to rotate smoothly (we might even be able to predict) also + # how the full Q2 will change, but this we omit for now, and we consider + # Q2' * ΔQ2 as a gauge dependent quantity. 
+ Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) + Δgauge ≤ gauge_atol || + @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return +end + """ qr_pullback!( ΔA, A, QR, ΔQR; @@ -37,23 +72,7 @@ function qr_pullback!( ΔA1 = view(ΔA, :, 1:p) ΔA2 = view(ΔA, :, (p + 1):n) - if minmn > p # case where A is rank-deficient - Δgauge = abs(zero(eltype(Q))) - if !iszerotangent(ΔQ) - # in this case the number Householder reflections will - # change upon small variations, and all of the remaining - # columns of ΔQ should be zero for a gauge-invariant - # cost function - ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) - Δgauge = max(Δgauge, norm(ΔQ2, Inf)) - end - if !iszerotangent(ΔR) - ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):n) - Δgauge = max(Δgauge, norm(ΔR22, Inf)) - end - Δgauge ≤ gauge_atol || - @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - end + check_qr_cotangents(Q, R, ΔQ, ΔR, minmn, p; gauge_atol) ΔQ̃ = zero!(similar(Q, (m, p))) if !iszerotangent(ΔQ) @@ -61,17 +80,8 @@ function qr_pullback!( if p < size(Q, 2) Q2 = view(Q, :, (p + 1):size(Q, 2)) ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2)) - # in the case where A is full rank, but there are more columns in Q than in A - # (the case of `qr_full`), there is gauge-invariant information in the - # projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary - # matrix. As the number of Householder reflections is in fixed in the full rank - # case, Q is expected to rotate smoothly (we might even be able to predict) also - # how the full Q2 will change, but this we omit for now, and we consider - # Q2' * ΔQ2 as a gauge dependent quantity. 
Q1dΔQ2 = Q1' * ΔQ2 - Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf) - Δgauge ≤ gauge_atol || - @warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol) ΔQ̃ = mul!(ΔQ̃, Q2, Q1dΔQ2', -1, 1) end end @@ -102,6 +112,14 @@ function qr_pullback!( return ΔA end +function check_qr_null_cotangents(N, ΔN; gauge_atol::Real = default_pullback_gauge_atol(ΔN)) + aNᴴΔN = project_antihermitian!(N' * ΔN) + Δgauge = norm(aNᴴΔN) + Δgauge ≤ gauge_atol || + @warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return +end + """ qr_null_pullback!( ΔA::AbstractMatrix, A, N, ΔN; @@ -118,11 +136,7 @@ function qr_null_pullback!( gauge_atol::Real = default_pullback_gauge_atol(ΔN) ) if !iszerotangent(ΔN) && size(N, 2) > 0 - aNᴴΔN = project_antihermitian!(N' * ΔN) - Δgauge = norm(aNᴴΔN) - Δgauge ≤ gauge_atol || - @warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)" - + check_qr_null_cotangents(N, ΔN; gauge_atol) Q, R = qr_compact(A; positive = true) X = rdiv!(ΔN' * Q, UpperTriangular(R)') ΔA = mul!(ΔA, N, X, -1, 1) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 1608343e..9c131464 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -1,3 +1,13 @@ +svd_rank(S, rank_atol) = searchsortedlast(S, rank_atol; rev = true) + +function check_svd_cotangents(aUΔU, Sr, aVΔV; degeneracy_atol = default_pullback_rank_atol(Sr), gauge_atol = default_pullback_gauge_atol(aUΔU, aVΔV)) + mask = abs.(Sr' .- Sr) .< degeneracy_atol + Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) + Δgauge ≤ gauge_atol || + @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + return +end + """ svd_pullback!( ΔA, A, USVᴴ, ΔUSVᴴ, [ind]; @@ -33,7 +43,7 @@ function svd_pullback!( minmn = min(m, n) S = diagview(Smat) length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)")) - r = 
searchsortedlast(S, rank_atol; rev = true) # rank + r = svd_rank(S, rank_atol) Ur = view(U, :, 1:r) Vᴴr = view(Vᴴ, 1:r, :) Sr = view(S, 1:r) @@ -70,10 +80,7 @@ function svd_pullback!( aVΔV = project_antihermitian!(VΔV) # check whether cotangents arise from gauge-invariance objective function - mask = abs.(Sr' .- Sr) .< degeneracy_atol - Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) - Δgauge ≤ gauge_atol || - @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + check_svd_cotangents(aUΔU, Sr, aVΔV; degeneracy_atol, gauge_atol) UdΔAV = (aUΔU .+ aVΔV) .* inv_safe.(Sr' .- Sr, degeneracy_atol) .+ (aUΔU .- aVΔV) .* inv_safe.(Sr' .+ Sr, degeneracy_atol) @@ -169,10 +176,7 @@ function svd_trunc_pullback!( aVΔV = project_antihermitian!(VΔV) # check whether cotangents arise from gauge-invariance objective function - mask = abs.(S' .- S) .< degeneracy_atol - Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf) - Δgauge ≤ gauge_atol || - @warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + check_svd_cotangents(aUΔU, S, aVΔV; degeneracy_atol, gauge_atol) UdΔAV = (aUΔU .+ aVΔV) .* inv_safe.(S' .- S, degeneracy_atol) .+ (aUΔU .- aVΔV) .* inv_safe.(S' .+ S, degeneracy_atol) @@ -205,7 +209,7 @@ function svd_trunc_pullback!( else fill!(view(rhs, m̃ .+ (1:ñ), :), 0) end - XY = sylvester(ÃÃ, -Smat, rhs) + XY = _sylvester(ÃÃ, -Smat, rhs) X = view(XY, 1:m̃, :) Y = view(XY, m̃ .+ (1:ñ), :) ΔA = mul!(ΔA, Ũ, X * Vᴴ, 1, 1) diff --git a/test/ad_utils.jl b/test/ad_utils.jl deleted file mode 100644 index fccc6c00..00000000 --- a/test/ad_utils.jl +++ /dev/null @@ -1,62 +0,0 @@ -function remove_svdgauge_dependence!( - ΔU, ΔVᴴ, U, S, Vᴴ; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S) - ) - gaugepart = mul!(U' * ΔU, Vᴴ, ΔVᴴ', true, true) - gaugepart = project_antihermitian!(gaugepart) - gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 - mul!(ΔU, U, gaugepart, -1, 1) - return ΔU, ΔVᴴ -end 
-function remove_eiggauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S) - ) - gaugepart = V' * ΔV - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V / (V' * V), gaugepart, -1, 1) - return ΔV -end -function remove_eighgauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S) - ) - gaugepart = V' * ΔV - gaugepart = project_antihermitian!(gaugepart) - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V, gaugepart, -1, 1) - return ΔV -end -function stabilize_eigvals!(D::AbstractVector) - absD = abs.(D) - p = invperm(sortperm(absD)) # rank of abs(D) - # account for exact degeneracies in absolute value when having complex conjugate pairs - for i in 1:(length(D) - 1) - if absD[i] == absD[i + 1] # conjugate pairs will appear sequentially - p[p .>= p[i + 1]] .-= 1 # lower the rank of all higher ones - end - end - n = maximum(p) - # rescale eigenvalues so that they lie on distinct radii in the complex plane - # that are chosen randomly in non-overlapping intervals [10 * k/n, 10 * (k+0.5)/n)] for k=1,...,n - radii = 10 .* ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n - for i in 1:length(D) - D[i] = sign(D[i]) * radii[p[i]] - end - return D -end -function make_eig_matrix(rng, T, n) - A = randn(rng, T, n, n) - D, V = eig_full(A) - stabilize_eigvals!(diagview(D)) - Ac = V * D * inv(V) - return (T <: Real) ? 
real(Ac) : Ac -end -function make_eigh_matrix(rng, T, n) - A = project_hermitian!(randn(rng, T, n, n)) - D, V = eigh_full(A) - stabilize_eigvals!(diagview(D)) - return project_hermitian!(V * D * V') -end - -precision(::Type{T}) where {T <: Number} = sqrt(eps(real(T))) diff --git a/test/chainrules.jl b/test/chainrules.jl index a8b2fd3b..c0ab618a 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,590 +1,18 @@ using MatrixAlgebraKit using Test -using TestExtras -using StableRNGs -using ChainRulesCore, ChainRulesTestUtils, Zygote -using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD -using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! -include("ad_utils.jl") +#BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI -for f in - ( - :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, - :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, - :svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals, - :left_polar, :right_polar, - ) - copy_f = Symbol(:copy_, f) - f! 
= Symbol(f, '!') - _hermitian = startswith(string(f), "eigh") - @eval begin - function $copy_f(input, alg) - if $_hermitian - input = (input + input') / 2 - end - return $f(input, alg) - end - function ChainRulesCore.rrule(::typeof($copy_f), input, alg) - output = MatrixAlgebraKit.initialize_output($f!, input, alg) - if $_hermitian - input = (input + input') / 2 - else - input = copy(input) - end - output, pb = ChainRulesCore.rrule($f!, input, output, alg) - return output, x -> (NoTangent(), pb(x)[2], NoTangent()) - end - end -end - -@timedtestset "QR AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - # qr_compact - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - alg = LAPACK_HouseholderQR(; positive = true) - Q, R = copy_qr_compact(A, alg) - ΔQ = randn(rng, T, m, minmn) - ΔR = randn(rng, T, minmn, n) - ΔR2 = UpperTriangular(randn(rng, T, minmn, minmn)) - ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) - test_rrule( - copy_qr_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔQ, ΔR), atol = atol, rtol = rtol - ) - test_rrule( - copy_qr_null, A, alg ⊢ NoTangent(); - output_tangent = ΔN, atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, qr_compact, A; - fkwargs = (; positive = true), output_tangent = (ΔQ, ΔR), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, first ∘ qr_compact, A; - fkwargs = (; positive = true), output_tangent = ΔQ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, last ∘ qr_compact, A; - fkwargs = (; positive = true), output_tangent = ΔR, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, qr_null, A; - fkwargs = (; positive = true), output_tangent = ΔN, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = 
false - ) - # qr_full - Q, R = copy_qr_full(A, alg) - Q1 = view(Q, 1:m, 1:minmn) - ΔQ = randn(rng, T, m, m) - ΔQ2 = view(ΔQ, :, (minmn + 1):m) - mul!(ΔQ2, Q1, Q1' * ΔQ2) - ΔR = randn(rng, T, m, n) - test_rrule( - copy_qr_full, A, alg ⊢ NoTangent(); - output_tangent = (ΔQ, ΔR), atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, qr_full, A; - fkwargs = (; positive = true), output_tangent = (ΔQ, ΔR), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - if m > n - _, null_pb = Zygote.pullback(qr_null, A, alg) - @test_logs (:warn,) null_pb(randn(rng, T, m, max(0, m - minmn))) - _, full_pb = Zygote.pullback(qr_full, A, alg) - @test_logs (:warn,) full_pb((randn(rng, T, m, m), randn(rng, T, m, n))) - end - # rank-deficient A - r = minmn - 5 - A = randn(rng, T, m, r) * randn(rng, T, r, n) - Q, R = qr_compact(A, alg) - ΔQ = randn(rng, T, m, minmn) - Q1 = view(Q, 1:m, 1:r) - Q2 = view(Q, 1:m, (r + 1):minmn) - ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) - ΔQ2 .= 0 - ΔR = randn(rng, T, minmn, n) - view(ΔR, (r + 1):minmn, :) .= 0 - test_rrule( - copy_qr_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔQ, ΔR), atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, qr_compact, A; - fkwargs = (; positive = true), output_tangent = (ΔQ, ΔR), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end -end - -@timedtestset "LQ AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - # lq_compact - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - alg = LAPACK_HouseholderLQ(; positive = true) - L, Q = copy_lq_compact(A, alg) - ΔL = randn(rng, T, m, minmn) - ΔQ = randn(rng, T, minmn, n) - ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q - test_rrule( - copy_lq_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔL, ΔQ), atol = atol, rtol = rtol 
- ) - test_rrule( - copy_lq_null, A, alg ⊢ NoTangent(); - output_tangent = ΔNᴴ, atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, lq_compact, A; - fkwargs = (; positive = true), output_tangent = (ΔL, ΔQ), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, first ∘ lq_compact, A; - fkwargs = (; positive = true), output_tangent = ΔL, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, last ∘ lq_compact, A; - fkwargs = (; positive = true), output_tangent = ΔQ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, lq_null, A; - fkwargs = (; positive = true), output_tangent = ΔNᴴ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - # lq_full - L, Q = copy_lq_full(A, alg) - Q1 = view(Q, 1:minmn, 1:n) - ΔQ = randn(rng, T, n, n) - ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) - mul!(ΔQ2, ΔQ2 * Q1', Q1) - ΔL = randn(rng, T, m, n) - test_rrule( - copy_lq_full, A, alg ⊢ NoTangent(); - output_tangent = (ΔL, ΔQ), atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, lq_full, A; - fkwargs = (; positive = true), output_tangent = (ΔL, ΔQ), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - if m < n - Nᴴ, null_pb = Zygote.pullback(lq_null, A, alg) - @test_logs (:warn,) null_pb(randn(rng, T, max(0, n - minmn), n)) - _, full_pb = Zygote.pullback(lq_full, A, alg) - @test_logs (:warn,) full_pb((randn(rng, T, m, n), randn(rng, T, n, n))) - end - # rank-deficient A - r = minmn - 5 - A = randn(rng, T, m, r) * randn(rng, T, r, n) - L, Q = lq_compact(A, alg) - ΔL = randn(rng, T, m, minmn) - ΔQ = randn(rng, T, minmn, n) - Q1 = view(Q, 1:r, 1:n) - Q2 = view(Q, (r + 1):minmn, 1:n) - ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) - ΔQ2 .= 0 - view(ΔL, :, (r + 1):minmn) .= 0 - test_rrule( - copy_lq_compact, A, alg ⊢ NoTangent(); - 
output_tangent = (ΔL, ΔQ), atol = atol, rtol = rtol - ) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, lq_compact, A; - fkwargs = (; positive = true), output_tangent = (ΔL, ΔQ), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end -end - -@timedtestset "EIG AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = make_eig_matrix(rng, T, m) - D, V = eig_full(A) - Ddiag = diagview(D) - ΔV = randn(rng, complex(T), m, m) - ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, complex(T), m, m) - ΔD2 = Diagonal(randn(rng, complex(T), m)) - for alg in (LAPACK_Simple(), LAPACK_Expert()) - test_rrule( - copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol - ) - test_rrule( - copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol - ) - test_rrule( - copy_eig_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, rtol - ) - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - test_rrule( - copy_eig_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) - dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(Ddiag[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - test_rrule( - copy_eig_trunc, A, truncalg ⊢ NoTangent(); - 
output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) - dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - # Zygote part - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, eig_full, A; - output_tangent = (ΔD, ΔV), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, eig_full, A; - output_tangent = (ΔD2, ΔV), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, first ∘ eig_full, A; - output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, last ∘ eig_full, A; - output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, eig_vals, A; - output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false - ) -end - -@timedtestset "EIGH AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = make_eigh_matrix(rng, T, m) - D, V = eigh_full(A) - Ddiag = diagview(D) - ΔV = randn(rng, T, m, m) - ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, real(T), m, m) - ΔD2 = Diagonal(randn(rng, real(T), m)) - for alg in ( - LAPACK_QRIteration(), LAPACK_DivideAndConquer(), LAPACK_Bisection(), - LAPACK_MultipleRelativelyRobustRepresentations(), - ) - # copy_eigh_full includes a projector onto the Hermitian part of the matrix - test_rrule( - copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol - ) - test_rrule( - copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol - ) - test_rrule( - copy_eigh_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, 
rtol - ) - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - test_rrule( - copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) - dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - test_rrule( - copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) - dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - # Zygote part - config = Zygote.ZygoteRuleConfig() - # eigh_full does not include a projector onto the Hermitian part of the matrix - test_rrule( - config, eigh_full ∘ Matrix ∘ Hermitian, A; - output_tangent = (ΔD, ΔV), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, eigh_full ∘ Matrix ∘ Hermitian, A; - output_tangent = (ΔD2, ΔV), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, first ∘ eigh_full ∘ Matrix ∘ Hermitian, A; - output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, last ∘ eigh_full ∘ Matrix ∘ 
Hermitian, A; - output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, eigh_vals ∘ Matrix ∘ Hermitian, A; - output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - eigh_trunc2(A; kwargs...) = eigh_trunc(Matrix(Hermitian(A)); kwargs...) - for r in 1:4:m - trunc = truncrank(r; by = real) - ind = MatrixAlgebraKit.findtruncated(Ddiag, trunc) - test_rrule( - config, eigh_trunc2, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end - trunc = trunctol(; rtol = 1 / 2) - ind = MatrixAlgebraKit.findtruncated(Ddiag, trunc) - test_rrule( - config, eigh_trunc2, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) -end - -@timedtestset "SVD AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - U, S, Vᴴ = svd_compact(A) - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - for alg in (LAPACK_QRIteration(), LAPACK_DivideAndConquer()) - test_rrule( - copy_svd_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔU, ΔS, ΔVᴴ), atol = atol, rtol = rtol - ) - test_rrule( - copy_svd_compact, A, alg ⊢ NoTangent(); - output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol - ) - test_rrule( - copy_svd_vals, A, alg ⊢ NoTangent(); - output_tangent = diagview(ΔS), atol, rtol - ) - for r in 1:4:minmn - truncalg = TruncatedAlgorithm(alg, truncrank(r)) - ind = 
MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - test_rrule( - copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - test_rrule( - copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) - dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - test_rrule( - copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), - atol = atol, rtol = rtol - ) - test_rrule( - copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), - atol = atol, rtol = rtol - ) - dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) - dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) - @test isapprox(dA1, dA2; atol = atol, rtol = rtol) - end - # Zygote part - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, svd_compact, A; - output_tangent = (ΔU, ΔS, ΔVᴴ), atol = atol, rtol = rtol, - rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, svd_compact, A; - output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol, - rrule_f 
= rrule_via_ad, check_inferred = false - ) - test_rrule( - config, svd_vals, A; - output_tangent = diagview(ΔS), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - for r in 1:4:minmn - trunc = truncrank(r) - ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) - test_rrule( - config, svd_trunc, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, svd_trunc_no_error, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end - trunc = trunctol(; atol = S[1, 1] / 2) - ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) - test_rrule( - config, svd_trunc, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, svd_trunc_no_error, A; - fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end -end - -@timedtestset "Polar AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - for alg in PolarViaSVD.((LAPACK_QRIteration(), LAPACK_DivideAndConquer())) - m >= n && - test_rrule(copy_left_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) - m <= n && - test_rrule(copy_right_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) - end - # Zygote part - config = Zygote.ZygoteRuleConfig() - m >= n && test_rrule( - config, left_polar, A; - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - m <= n && test_rrule( - config, 
right_polar, A; - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - end -end - -@timedtestset "Orth and null with eltype $T" for T in (Float64, ComplexF64, Float32) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - config = Zygote.ZygoteRuleConfig() - test_rrule( - config, left_orth, A; - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, left_orth, A; - fkwargs = (; alg = :qr), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - m >= n && - test_rrule( - config, left_orth, A; - fkwargs = (; alg = :polar), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - - ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) - test_rrule( - config, left_null, A; - fkwargs = (; alg = :qr), output_tangent = ΔN, atol = atol, rtol = rtol, - rrule_f = rrule_via_ad, check_inferred = false - ) +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite - test_rrule( - config, right_orth, A; - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - test_rrule( - config, right_orth, A; fkwargs = (; alg = :lq), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) - m <= n && - test_rrule( - config, right_orth, A; fkwargs = (; alg = :polar), - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" - ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] - test_rrule( - config, right_null, A; - fkwargs = (; alg = :lq), output_tangent = ΔNᴴ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) +m = 19 +for T in BLASFloats, n in (17, m, 23) + TestSuite.seed_rng!(123) + if !is_buildkite # doesn't work on GPU + TestSuite.test_chainrules(T, (m, n); atol = m * 
n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) end end diff --git a/test/enzyme.jl b/test/enzyme.jl index 0330c5b0..28ff7454 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -1,497 +1,30 @@ using MatrixAlgebraKit using Test -using TestExtras -using StableRNGs -using ChainRulesCore -using Enzyme, EnzymeTestUtils -using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD -using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul!, BlasFloat -using GenericLinearAlgebra, GenericSchur +using LinearAlgebra: Diagonal +using CUDA, AMDGPU -# https://github.com/EnzymeAD/Enzyme.jl/issues/2888, -# test_reverse doesn't work with BigFloat +BLASFloats = (ComplexF64,) # full suite is too expensive on CI +GenericFloats = (BigFloat,) +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite -ETs = @static if VERSION < v"1.12.0" - (ComplexF64, BigFloat) -else - (ComplexF64,) -end -include("ad_utils.jl") -function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; ȳ = copy.(Δargs), return_act = Duplicated) - ΔA = randn(rng, eltype(A), size(A)...) - A_ΔA() = Duplicated(copy(A), copy(ΔA)) - function args_Δargs() - if isnothing(args) - return Const(args) - elseif args isa Tuple && all(isnothing, args) - return Const(args) - else - return Duplicated(copy.(args), copy.(Δargs)) - end - end - copy_activities = isnothing(alg) ? (Const(f), A_ΔA()) : (Const(f), A_ΔA(), Const(alg)) - inplace_activities = isnothing(alg) ? (Const(f!), A_ΔA(), args_Δargs()) : (Const(f!), A_ΔA(), args_Δargs(), Const(alg)) - - mode = EnzymeTestUtils.set_runtime_activity(ReverseSplitWithPrimal, false) - c_act = Const(EnzymeTestUtils.call_with_kwargs) - forward_copy, reverse_copy = autodiff_thunk( - mode, typeof(c_act), return_act, typeof(Const(())), map(typeof, copy_activities)... - ) - forward_inplace, reverse_inplace = autodiff_thunk( - mode, typeof(c_act), return_act, typeof(Const(())), map(typeof, inplace_activities)... 
- ) - copy_tape, copy_y_ad, copy_shadow_result = forward_copy(c_act, Const(()), copy_activities...) - inplace_tape, inplace_y_ad, inplace_shadow_result = forward_inplace(c_act, Const(()), inplace_activities...) - if !(copy_shadow_result === nothing) - flush(stdout) - EnzymeTestUtils.map_fields_recursive(copyto!, copy_shadow_result, copy.(ȳ)) - end - if !(inplace_shadow_result === nothing) - EnzymeTestUtils.map_fields_recursive(copyto!, inplace_shadow_result, copy.(ȳ)) - end - dx_copy_ad = only(reverse_copy(c_act, Const(()), copy_activities..., copy_tape)) - dx_inplace_ad = only(reverse_inplace(c_act, Const(()), inplace_activities..., inplace_tape)) - # check all returned derivatives between copy & inplace - for (i, (copy_act_i, inplace_act_i)) in enumerate(zip(copy_activities[2:end], inplace_activities[2:end])) - if copy_act_i isa Duplicated && inplace_act_i isa Duplicated - msg_deriv = "shadow derivative for argument $(i - 1) should match between copy and inplace" - EnzymeTestUtils.test_approx(copy_act_i.dval, inplace_act_i.dval, msg_deriv) - end - end - return -end +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" -@timedtestset "QR AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - A = randn(rng, T, m, n) - atol = rtol = m * n * precision(T) - minmn = min(m, n) - alg = MatrixAlgebraKit.default_qr_algorithm(A) - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - @testset "qr_compact" begin - ΔQR = (randn(rng, T, m, minmn), randn(rng, T, minmn, n)) - Q, R = qr_compact(A, alg) - QR = MatrixAlgebraKit.initialize_output(qr_compact!, A, alg) - fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - T <: BlasFloat && test_reverse(qr_compact, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = ΔQR, fdm = fdm) - test_pullbacks_match(rng, qr_compact!, qr_compact, A, QR, ΔQR, alg) - end - @testset "qr_null" begin - Q, R = qr_compact(A, alg) - N = zeros(T, m, max(0, m - minmn)) - ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) - T <: BlasFloat && test_reverse(qr_null, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = ΔN) - test_pullbacks_match(rng, qr_null!, qr_null, A, N, ΔN, alg) - end - @testset "qr_full" begin - Q, R = qr_full(A, alg) - Q1 = view(Q, 1:m, 1:minmn) - ΔQ = randn(rng, T, m, m) - ΔQ2 = view(ΔQ, :, (minmn + 1):m) - mul!(ΔQ2, Q1, Q1' * ΔQ2) - ΔR = randn(rng, T, m, n) - fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - T <: BlasFloat && test_reverse(qr_full, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = (ΔQ, ΔR), fdm = fdm) - test_pullbacks_match(rng, qr_full!, qr_full, A, (Q, R), (ΔQ, ΔR), alg) - end - @testset "qr_compact - rank-deficient A" begin - r = minmn - 5 - Ard = randn(rng, T, m, r) * randn(rng, T, r, n) - Q, R = qr_compact(Ard, alg) - ΔQ = randn(rng, T, m, minmn) - Q1 = view(Q, 1:m, 1:r) - Q2 = view(Q, 1:m, (r + 1):minmn) - ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) - ΔQ2 .= 0 - ΔR = randn(rng, T, minmn, n) - view(ΔR, (r + 1):minmn, :) .= 0 - fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - T <: BlasFloat && test_reverse(qr_compact, RT, (Ard, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = (ΔQ, ΔR), fdm = fdm) - test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, (Q, R), (ΔQ, ΔR), alg) - end +m = 19 +for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) + TestSuite.seed_rng!(123) + if T <: BLASFloats + if CUDA.functional() + TestSuite.test_enzyme(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + #n == m && TestSuite.test_enzyme(Diagonal{T, CuVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) end - end -end - -@timedtestset "LQ AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - minmn = min(m, n) - A = randn(rng, T, m, n) - alg = MatrixAlgebraKit.default_lq_algorithm(A) - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - @testset "lq_compact" begin - ΔL = randn(rng, T, m, minmn) - ΔQ = randn(rng, T, minmn, n) - L, Q = lq_compact(A, alg) - fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - T <: BlasFloat && test_reverse(lq_compact, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = (ΔL, ΔQ), fdm = fdm) - test_pullbacks_match(rng, lq_compact!, lq_compact, A, (L, Q), (ΔL, ΔQ), alg) - end - @testset "lq_null" begin - L, Q = lq_compact(A, alg) - ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q - Nᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q - T <: BlasFloat && test_reverse(lq_null, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = ΔNᴴ) - # runtime activity problems here with BigFloat - T <: BlasFloat && test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ, alg) - end - @testset "lq_full" begin - L, Q = lq_full(A, alg) - Q1 = view(Q, 1:minmn, 1:n) - ΔQ = randn(rng, T, n, n) - ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) - mul!(ΔQ2, ΔQ2 * Q1', Q1) - ΔL = randn(rng, T, m, n) - fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - T <: BlasFloat && test_reverse(lq_full, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = (ΔL, ΔQ), fdm = fdm) - test_pullbacks_match(rng, lq_full!, lq_full, A, (L, Q), (ΔL, ΔQ), alg) - end - @testset "lq_compact -- rank-deficient A" begin - r = minmn - 5 - Ard = randn(rng, T, m, r) * randn(rng, T, r, n) - L, Q = lq_compact(Ard, alg) - ΔL = randn(rng, T, m, minmn) - ΔQ = randn(rng, T, minmn, n) - Q1 = view(Q, 1:r, 1:n) - Q2 = view(Q, (r + 1):minmn, 1:n) - ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) - ΔQ2 .= 0 - view(ΔL, :, (r + 1):minmn) .= 0 - fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - T <: BlasFloat && test_reverse(lq_compact, RT, (Ard, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = (ΔL, ΔQ), fdm = fdm) - test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, (L, Q), (ΔL, ΔQ), alg) - end + if AMDGPU.functional() + TestSuite.test_enzyme(ROCMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + #TestSuite.test_enzyme(Diagonal{T, ROCVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) end end -end - -@timedtestset "EIG AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = make_eig_matrix(rng, T, m) - D, V = eig_full(A) - Ddiag = diagview(D) - ΔV = randn(rng, complex(T), m, m) - ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, complex(T), m, m) - ΔD2 = Diagonal(randn(rng, complex(T), m)) - fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - alg = MatrixAlgebraKit.default_eig_algorithm(A) - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - if T <: BlasFloat - test_reverse(eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm = fdm) - test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg) - test_reverse(eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag), fdm = fdm) - test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg) - else - test_pullbacks_match(rng, eig_full!, eig_full, A, (nothing, nothing), (nothing, nothing), alg; ȳ = (ΔD2, ΔV)) - test_pullbacks_match(rng, eig_vals!, eig_vals, A, nothing, nothing, alg; ȳ = ΔD2.diag) - end - end - @testset "eig_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - ind = MatrixAlgebraKit.findtruncated(diagview(D), truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - if T <: BlasFloat - test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm) - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc)) - else - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = (ΔDtrunc, ΔVtrunc)) - end - end - truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(Ddiag[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - if T <: BlasFloat - 
test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm) - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg; ȳ = (ΔDtrunc, ΔVtrunc)) - else - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg; ȳ = (ΔDtrunc, ΔVtrunc)) - end - end -end - - -function copy_eigh_full(A, alg) - A = (A + A') / 2 - return eigh_full(A, alg) -end - -function copy_eigh_full!(A, DV::Tuple, alg::MatrixAlgebraKit.AbstractAlgorithm) - A = (A + A') / 2 - return eigh_full!(A, DV, alg) -end - -function copy_eigh_vals(A; kwargs...) - A = (A + A') / 2 - return eigh_vals(A; kwargs...) -end - -function copy_eigh_vals!(A, D; kwargs...) - A = (A + A') / 2 - return eigh_vals!(A, D; kwargs...) -end - -function copy_eigh_vals(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_vals(A, alg; kwargs...) -end - -function copy_eigh_vals!(A, D, alg; kwargs...) - A = (A + A') / 2 - return eigh_vals!(A, D, alg; kwargs...) -end - -function copy_eigh_trunc_no_error(A, alg) - A = (A + A') / 2 - return eigh_trunc_no_error(A, alg) -end - -function copy_eigh_trunc_no_error!(A, DV, alg) - A = (A + A') / 2 - return eigh_trunc_no_error!(A, DV, alg) -end - -@timedtestset "EIGH AD Rules with eltype $T" for T in filter(T -> <:(T, BlasFloat), ETs) - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = make_eigh_matrix(rng, T, m) - D, V = eigh_full(A) - D2 = Diagonal(D) - ΔV = randn(rng, T, m, m) - ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, real(T), m, m) - ΔD2 = Diagonal(randn(rng, real(T), m)) - fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - alg = MatrixAlgebraKit.default_eigh_algorithm(A) - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - T <: BlasFloat && test_reverse(copy_eigh_full, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm = fdm) - T <: BlasFloat && test_reverse(copy_eigh_full!, RT, (copy(A), TA), ((D, V), TA), (alg, Const); atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm = fdm) - test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg) - T <: BlasFloat && test_reverse(copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag), fdm = fdm) - test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg) - end - @testset "eigh_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - for r in 1:4:m - Ddiag = diagview(D) - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - T <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm) - test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT) - end - Ddiag = diagview(D) - truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - T <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol = 
atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm) - test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT) - end -end - -@timedtestset "SVD AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - alg = MatrixAlgebraKit.default_svd_algorithm(A) - minmn = min(m, n) - fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - @testset "svd_compact" begin - U, S, Vᴴ = svd_compact(A) - ΔU = randn(rng, T, m, minmn) - ΔS = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - if T <: BlasFloat - test_reverse(svd_compact, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = (ΔU, ΔS, ΔVᴴ), fdm = fdm) - test_pullbacks_match(rng, svd_compact!, svd_compact, A, (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), alg) - else - USVᴴ = MatrixAlgebraKit.initialize_output(svd_compact!, A, alg) - test_pullbacks_match(rng, svd_compact!, svd_compact, A, USVᴴ, (nothing, nothing, nothing), alg; ȳ = (ΔU, ΔS, ΔVᴴ)) - end - end - @testset "svd_full" begin - U, S, Vᴴ = svd_compact(A) - ΔU = randn(rng, T, m, minmn) - ΔS = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - ΔUfull = zeros(T, m, m) - ΔSfull = zeros(real(T), m, n) - ΔVᴴfull = zeros(T, n, n) - U, S, Vᴴ = svd_full(A) - view(ΔUfull, :, 1:minmn) .= ΔU - view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴ - diagview(ΔSfull)[1:minmn] .= diagview(ΔS) - if T <: BlasFloat - test_reverse(svd_full, RT, (A, TA); fkwargs = (alg = alg,), atol = 
atol, rtol = rtol, output_tangent = (ΔUfull, ΔSfull, ΔVᴴfull), fdm = fdm) - test_pullbacks_match(rng, svd_full!, svd_full, A, (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull), alg) - else - USVᴴ = MatrixAlgebraKit.initialize_output(svd_full!, A, alg) - test_pullbacks_match(rng, svd_full!, svd_full, A, USVᴴ, (nothing, nothing, nothing), alg; ȳ = (ΔUfull, ΔSfull, ΔVᴴfull)) - end - end - @testset "svd_vals" begin - S = svd_vals(A) - ΔS = randn(rng, real(T), minmn) - if T <: BlasFloat - test_reverse(svd_vals, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), output_tangent = ΔS, fdm = fdm) - test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, ΔS, alg) - else - S = MatrixAlgebraKit.initialize_output(svd_vals!, A, alg) - test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, nothing, alg; ȳ = ΔS) - end - end - end - @testset "svd_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - for r in 1:4:minmn - U, S, Vᴴ = svd_compact(A) - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - truncalg = TruncatedAlgorithm(alg, truncrank(r)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - if T <: BlasFloat - test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm) - test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) - else - test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (nothing, nothing, nothing), (nothing, nothing, nothing), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) - end - end - U, S, Vᴴ = svd_compact(A) 
- ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - if T <: BlasFloat - test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm) - test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) - else - test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (nothing, nothing, nothing), (nothing, nothing, nothing), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) - end - end - end -end - -# GLA works with polar, but these tests -# segfault because of Sylvester + BigFloat -@timedtestset "Polar AD Rules with eltype $T" for T in filter(T -> <:(T, BlasFloat), ETs) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - alg = MatrixAlgebraKit.default_polar_algorithm(A) - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - if m >= n - WP = left_polar(A; alg = alg) - W, P = WP - ΔWP = randn(rng, T, size(W)...), randn(rng, T, size(P)...) - T <: BlasFloat && test_reverse(left_polar, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol) - test_pullbacks_match(rng, left_polar!, left_polar, A, WP, ΔWP, alg) - elseif m <= n - PWᴴ = right_polar(A; alg = alg) - P, Wᴴ = PWᴴ - ΔPWᴴ = randn(rng, T, size(P)...), randn(rng, T, size(Wᴴ)...) 
- T <: BlasFloat && test_reverse(right_polar, RT, (A, TA), (alg, Const); atol = atol, rtol = rtol) - test_pullbacks_match(rng, right_polar!, right_polar, A, PWᴴ, ΔPWᴴ, alg) - end - end - end -end - -# GLA not working with orthnull yet -@timedtestset "Orth and null with eltype $T" for T in filter(T -> <:(T, BlasFloat), ETs) - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - @testset "left_orth" begin - @testset for alg in (:polar, :qr) - n > m && alg == :polar && continue - VC = left_orth(A; alg = alg) - V, C = VC - ΔV = randn(rng, T, size(V)...) - ΔC = randn(rng, T, size(C)...) - fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - T <: BlasFloat && test_reverse(left_orth, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), fdm = fdm) - left_orth_alg!(A, VC) = left_orth!(A, VC; alg = alg) - left_orth_alg(A) = left_orth(A; alg = alg) - test_pullbacks_match(rng, left_orth_alg!, left_orth_alg, A, (V, C), (ΔV, ΔC)) - end - end - @testset "right_orth" begin - @testset for alg in (:polar, :lq) - n < m && alg == :polar && continue - CVᴴ = right_orth(A; alg = alg) - C, Vᴴ = CVᴴ - ΔC = randn(rng, T, size(C)...) - ΔVᴴ = randn(rng, T, size(Vᴴ)...) - fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - T <: BlasFloat && test_reverse(right_orth, RT, (A, TA); atol = atol, rtol = rtol, fkwargs = (alg = alg,), fdm = fdm) - right_orth_alg!(A, CVᴴ) = right_orth!(A, CVᴴ; alg = alg) - right_orth_alg(A) = right_orth(A; alg = alg) - test_pullbacks_match(rng, right_orth_alg!, right_orth_alg, A, (C, Vᴴ), (ΔC, ΔVᴴ)) - end - end - @testset "left_null" begin - ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) - N = similar(ΔN) - left_null_qr!(A, N) = left_null!(A, N; alg = :qr) - left_null_qr(A) = left_null(A; alg = :qr) - T <: BlasFloat && test_reverse(left_null_qr, RT, (A, TA); output_tangent = ΔN, atol = atol, rtol = rtol) - test_pullbacks_match(rng, left_null_qr!, left_null_qr, A, N, ΔN) - end - @testset "right_null" begin - ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] - Nᴴ = similar(ΔNᴴ) - right_null_lq!(A, Nᴴ) = right_null!(A, Nᴴ; alg = :lq) - right_null_lq(A) = right_null(A; alg = :lq) - T <: BlasFloat && test_reverse(right_null_lq, RT, (A, TA); output_tangent = ΔNᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, right_null_lq!, right_null_lq, A, Nᴴ, ΔNᴴ) - end - end + if !is_buildkite + TestSuite.test_enzyme(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + #n == m && TestSuite.test_enzyme(Diagonal{T, Vector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) end end diff --git a/test/mooncake.jl b/test/mooncake.jl index 760102b1..ea5bbf65 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -1,597 +1,29 @@ using MatrixAlgebraKit using Test -using TestExtras -using StableRNGs -using Mooncake, Mooncake.TestUtils -using Mooncake: rrule!! -using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD, eigh_trunc -using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! 
- -include("ad_utils.jl") - -make_mooncake_tangent(ΔAelem::T) where {T <: Complex} = Mooncake.build_tangent(T, real(ΔAelem), imag(ΔAelem)) -make_mooncake_tangent(ΔA::Matrix{<:Real}) = ΔA -make_mooncake_tangent(ΔA::Vector{<:Real}) = ΔA -make_mooncake_tangent(ΔA::Matrix{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) -make_mooncake_tangent(ΔA::Vector{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) -make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Real} = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) -make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Complex} = Mooncake.build_tangent(typeof(ΔD), map(make_mooncake_tangent, diagview(ΔD))) - -make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), T...) - -make_mooncake_fdata(x) = make_mooncake_tangent(x) -make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),)) - -ETs = (Float32, ComplexF64) - -# no `alg` argument -function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) - dA_copy = make_mooncake_tangent(copy(ΔA)) - A_copy = copy(A) - dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy)) - copy_pb!!(rdata) - return dA_copy -end - -# `alg` argument -function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) - dA_copy = make_mooncake_tangent(copy(ΔA)) - A_copy = copy(A) - dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - copy_out, copy_pb!! 
= rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData())) - copy_pb!!(rdata) - return dA_copy -end - -function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata) - dA_inplace = make_mooncake_tangent(copy(ΔA)) - A_inplace = copy(A) - dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - # not every f! has a handwritten rrule!! - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} - has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) - if has_handwritten_rule - inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) - else - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) - inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) - end - inplace_pb!!(rdata) - return dA_inplace -end - -function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) - dA_inplace = make_mooncake_tangent(copy(ΔA)) - A_inplace = copy(A) - dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) - # not every f! has a handwritten rrule!! - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} - has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) - if has_handwritten_rule - inplace_out, inplace_pb!! 
= Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) - else - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) - inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) - end - inplace_pb!!(rdata) - return dA_inplace -end - -""" - test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) - -Compare the result of running the *in-place, mutating* function `f!`'s reverse rule -with the result of running its *non-mutating* partner function `f`'s reverse rule. -We must compare directly because many of the mutating functions modify `A` as a -scratch workspace, making testing `f!` against finite differences infeasible. - -The arguments to this function are: - - `f!` the mutating, in-place version of the function (accepts `args` for the function result) - - `f` the non-mutating version of the function (does not accept `args` for the function result) - - `A` the input matrix to factorize - - `args` preallocated output for `f!` (e.g. `Q` and `R` matrices for `qr_compact!`) - - `Δargs` precomputed derivatives of `args` for pullbacks of `f` and `f!`, to ensure they receive the same input - - `alg` optional algorithm keyword argument - - `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do) -""" -function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) - f_c = isnothing(alg) ? 
(A, args) -> f!(MatrixAlgebraKit.copy_input(f, A), args) : (A, args, alg) -> f!(MatrixAlgebraKit.copy_input(f, A), args, alg) - sig = isnothing(alg) ? Tuple{typeof(f_c), typeof(A), typeof(args)} : Tuple{typeof(f_c), typeof(A), typeof(args), typeof(alg)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - rrule = Mooncake.build_rrule(rvs_interp, sig) - ΔA = randn(rng, eltype(A), size(A)) - - dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) - dA_inplace = _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) - - dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2] - dA_copy_ = Mooncake.arrayify(A, dA_copy)[2] - @test dA_inplace_ ≈ dA_copy_ - return -end - -@timedtestset "QR AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - @testset for alg in ( - LAPACK_HouseholderQR(), - LAPACK_HouseholderQR(; positive = true), - ) - @testset "qr_compact" begin - QR = qr_compact(A, alg) - Q = randn(rng, T, m, minmn) - R = randn(rng, T, minmn, n) - Mooncake.TestUtils.test_rule(rng, qr_compact, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_compact!, qr_compact, A, (Q, R), (randn(rng, T, m, minmn), randn(rng, T, minmn, n)), alg) - end - @testset "qr_null" begin - Q, R = qr_compact(A, alg) - ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) - N = qr_null(A, alg) - dN = make_mooncake_tangent(copy(ΔN)) - Mooncake.TestUtils.test_rule(rng, qr_null, A, alg; mode = Mooncake.ReverseMode, output_tangent = dN, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_null!, qr_null, A, N, ΔN, alg) - end - @testset "qr_full" begin - Q, R = qr_full(A, alg) - Q1 = view(Q, 1:m, 1:minmn) - ΔQ = randn(rng, T, m, m) - ΔQ2 = view(ΔQ, :, (minmn + 1):m) - mul!(ΔQ2, Q1, Q1' * ΔQ2) - ΔR = randn(rng, T, m, n) - dQ = 
make_mooncake_tangent(copy(ΔQ)) - dR = make_mooncake_tangent(copy(ΔR)) - dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR) - Mooncake.TestUtils.test_rule(rng, qr_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dQR, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_full!, qr_full, A, (Q, R), (ΔQ, ΔR), alg) - end - @testset "qr_compact - rank-deficient A" begin - r = minmn - 5 - Ard = randn(rng, T, m, r) * randn(rng, T, r, n) - Q, R = qr_compact(Ard, alg) - QR = (Q, R) - ΔQ = randn(rng, T, m, minmn) - Q1 = view(Q, 1:m, 1:r) - Q2 = view(Q, 1:m, (r + 1):minmn) - ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) - ΔQ2 .= 0 - ΔR = randn(rng, T, minmn, n) - view(ΔR, (r + 1):minmn, :) .= 0 - dQ = make_mooncake_tangent(copy(ΔQ)) - dR = make_mooncake_tangent(copy(ΔR)) - dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR) - Mooncake.TestUtils.test_rule(rng, qr_compact, Ard, alg; mode = Mooncake.ReverseMode, output_tangent = dQR, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, (Q, R), (ΔQ, ΔR), alg) - end - end - end -end - -@timedtestset "LQ AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - @testset for alg in ( - LAPACK_HouseholderLQ(), - LAPACK_HouseholderLQ(; positive = true), - ) - @testset "lq_compact" begin - L, Q = lq_compact(A, alg) - Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_compact!, lq_compact, A, (L, Q), (randn(rng, T, m, minmn), randn(rng, T, minmn, n)), alg) - end - @testset "lq_null" begin - L, Q = lq_compact(A, alg) - ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q - Nᴴ = randn(rng, T, max(0, n - minmn), n) - dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; mode 
= Mooncake.ReverseMode, output_tangent = dNᴴ, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ, alg) - end - @testset "lq_full" begin - L, Q = lq_full(A, alg) - Q1 = view(Q, 1:minmn, 1:n) - ΔQ = randn(rng, T, n, n) - ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) - mul!(ΔQ2, ΔQ2 * Q1', Q1) - ΔL = randn(rng, T, m, n) - dL = make_mooncake_tangent(ΔL) - dQ = make_mooncake_tangent(ΔQ) - dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ) - Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_full!, lq_full, A, (L, Q), (ΔL, ΔQ), alg) - end - @testset "lq_compact - rank-deficient A" begin - r = minmn - 5 - Ard = randn(rng, T, m, r) * randn(rng, T, r, n) - L, Q = lq_compact(Ard, alg) - ΔL = randn(rng, T, m, minmn) - ΔQ = randn(rng, T, minmn, n) - Q1 = view(Q, 1:r, 1:n) - Q2 = view(Q, (r + 1):minmn, 1:n) - ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) - ΔQ2 .= 0 - view(ΔL, :, (r + 1):minmn) .= 0 - dL = make_mooncake_tangent(ΔL) - dQ = make_mooncake_tangent(ΔQ) - dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ) - Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, (L, Q), (ΔL, ΔQ), alg) - end - end - end -end - -@timedtestset "EIG AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = make_eig_matrix(rng, T, m) - DV = eig_full(A) - D, V = DV - Ddiag = diagview(D) - ΔV = randn(rng, complex(T), m, m) - ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, complex(T), m, m) - ΔD2 = Diagonal(randn(rng, complex(T), m)) - - dD = make_mooncake_tangent(ΔD2) - dV = make_mooncake_tangent(ΔV) - dDV = Mooncake.build_tangent(typeof((ΔD2, ΔV)), dD, dV) - # compute the dA 
corresponding to the above dD, dV - @testset for alg in ( - LAPACK_Simple(), - #LAPACK_Expert(), # expensive on CI - ) - @testset "eig_full" begin - Mooncake.TestUtils.test_rule(rng, eig_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg) - end - @testset "eig_vals" begin - Mooncake.TestUtils.test_rule(rng, eig_vals, A, alg; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg) - end - @testset "eig_trunc" begin - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - dDtrunc = make_mooncake_tangent(ΔDtrunc) - dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) - Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) - end - truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = 
Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - dDtrunc = make_mooncake_tangent(ΔDtrunc) - dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) - Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) - end - end -end - -function copy_eigh_full(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_full(A, alg; kwargs...) -end - -function copy_eigh_full!(A, DV, alg; kwargs...) - A = (A + A') / 2 - return eigh_full!(A, DV, alg; kwargs...) -end - -function copy_eigh_vals(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_vals(A, alg; kwargs...) -end - -function copy_eigh_vals!(A, D, alg; kwargs...) - A = (A + A') / 2 - return eigh_vals!(A, D, alg; kwargs...) -end - -function copy_eigh_trunc(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc(A, alg; kwargs...) -end - -function copy_eigh_trunc!(A, DV, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc!(A, DV, alg; kwargs...) -end - -function copy_eigh_trunc_no_error(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc_no_error(A, alg; kwargs...) -end - -function copy_eigh_trunc_no_error!(A, DV, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc_no_error!(A, DV, alg; kwargs...) 
-end - -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A) -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A) -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) -MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc_no_error), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) - -@timedtestset "EIGH AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - atol = rtol = m * m * precision(T) - A = make_eigh_matrix(rng, T, m) - D, V = eigh_full(A) - Ddiag = diagview(D) - ΔV = randn(rng, T, m, m) - ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol) - ΔD = randn(rng, real(T), m, m) - ΔD2 = Diagonal(randn(rng, real(T), m)) - dD = make_mooncake_tangent(ΔD2) - dV = make_mooncake_tangent(ΔV) - dDV = Mooncake.build_tangent(typeof((ΔD2, ΔV)), dD, dV) - @testset for alg in ( - LAPACK_QRIteration(), - #LAPACK_DivideAndConquer(), - #LAPACK_Bisection(), - #LAPACK_MultipleRelativelyRobustRepresentations(), # expensive on CI - ) - @testset "eigh_full" begin - Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg) - end - @testset "eigh_vals" begin - Mooncake.TestUtils.test_rule(rng, copy_eigh_vals, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg) - end - @testset "eigh_trunc" begin - for r in 1:4:m - truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - dDtrunc = make_mooncake_tangent(ΔDtrunc) - dVtrunc = 
make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) - end - truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) - ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) - Dtrunc = Diagonal(diagview(D)[ind]) - Vtrunc = V[:, ind] - ΔDtrunc = Diagonal(diagview(ΔD2)[ind]) - ΔVtrunc = ΔV[:, ind] - dDtrunc = make_mooncake_tangent(ΔDtrunc) - dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), 
(ΔD2, ΔV), truncalg) - end - end -end - -@timedtestset "SVD AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - minmn = min(m, n) - @testset for alg in ( - LAPACK_QRIteration(), - #LAPACK_DivideAndConquer(), # expensive on CI - ) - @testset "svd_compact" begin - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - dS = make_mooncake_tangent(ΔS2) - dU = make_mooncake_tangent(ΔU) - dVᴴ = make_mooncake_tangent(ΔVᴴ) - dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dU, dS, dVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_compact, A, alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_compact!, svd_compact, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), alg) - end - @testset "svd_full" begin - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - ΔUfull = zeros(T, m, m) - ΔSfull = zeros(real(T), m, n) - ΔVᴴfull = zeros(T, n, n) - U, S, Vᴴ = svd_full(A) - view(ΔUfull, :, 1:minmn) .= ΔU - view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴ - diagview(ΔSfull)[1:minmn] .= diagview(ΔS2) - dS = make_mooncake_tangent(ΔSfull) - dU = make_mooncake_tangent(ΔUfull) - dVᴴ = make_mooncake_tangent(ΔVᴴfull) - dUSVᴴ = Mooncake.build_tangent(typeof((ΔUfull, ΔSfull, ΔVᴴfull)), dU, dS, dVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_full!, svd_full, A, (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull), alg) - end 
- @testset "svd_vals" begin - Mooncake.TestUtils.test_rule(rng, svd_vals, A, alg; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) - S = svd_vals(A, alg) - test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, randn(rng, real(T), minmn), alg) - end - @testset "svd_trunc" begin - @testset for r in 1:4:minmn - U, S, Vᴴ = svd_compact(A) - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - truncalg = TruncatedAlgorithm(alg, truncrank(r)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - dStrunc = make_mooncake_tangent(ΔStrunc) - dUtrunc = make_mooncake_tangent(ΔUtrunc) - dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) - ϵ = zero(real(T)) - dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg) - end - @testset "trunctol" begin - U, S, Vᴴ = svd_compact(A) - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = 
remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) - truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] - ΔVᴴtrunc = ΔVᴴ[ind, :] - dStrunc = make_mooncake_tangent(ΔStrunc) - dUtrunc = make_mooncake_tangent(ΔUtrunc) - dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) - ϵ = zero(real(T)) - dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) - dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg) - end - end - end - end -end - -@timedtestset "Polar AD Rules with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - @testset for alg in PolarViaSVD.( - ( - LAPACK_QRIteration(), - #LAPACK_DivideAndConquer(), # expensive on CI - ) - ) - if m >= n - WP = left_polar(A, alg) - Mooncake.TestUtils.test_rule(rng, left_polar, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, left_polar!, left_polar, A, WP, (randn(rng, T, m, n), randn(rng, T, n, n)), alg) - elseif m <= n - PWᴴ = right_polar(A, alg) - Mooncake.TestUtils.test_rule(rng, 
right_polar, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) - test_pullbacks_match(rng, right_polar!, right_polar, A, PWᴴ, (randn(rng, T, m, m), randn(rng, T, m, n)), alg) - end - end - end -end - -left_orth_qr(X) = left_orth(X; alg = :qr) -left_orth_polar(X) = left_orth(X; alg = :polar) -left_null_qr(X) = left_null(X; alg = :qr) -right_orth_lq(X) = right_orth(X; alg = :lq) -right_orth_polar(X) = right_orth(X; alg = :polar) -right_null_lq(X) = right_null(X; alg = :lq) - -MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A) -MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A) -MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A) -MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A) -MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A) -MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A) - -@timedtestset "Orth and null with eltype $T" for T in ETs - rng = StableRNG(12345) - m = 19 - @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) - VC = left_orth(A) - CVᴴ = right_orth(A) - Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, left_orth!, left_orth, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) - Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, right_orth!, right_orth, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) - - Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive 
= false) - test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) - if m >= n - Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) - end - - N = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) - ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) - dN = make_mooncake_tangent(ΔN) - Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dN) - test_pullbacks_match(rng, ((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN) - - Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) - - if m <= n - Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) - end - - Nᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] - ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] - dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dNᴴ) - test_pullbacks_match(rng, ((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) +using LinearAlgebra: 
Diagonal +using CUDA, AMDGPU + +#BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) + TestSuite.seed_rng!(123) + if CUDA.functional() + TestSuite.test_mooncake(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + #n == m && TestSuite.test_mooncake(Diagonal{T, CuVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) + end + #=if AMDGPU.functional() + TestSuite.test_mooncake(ROCMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + TestSuite.test_mooncake(Diagonal{T, ROCVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) + end=# # not yet supported + if !is_buildkite + TestSuite.test_mooncake(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) + #n == m && TestSuite.test_mooncake(Diagonal{T, Vector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) end end diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index 2f3fde50..28833557 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -84,6 +84,8 @@ function instantiate_unitary(T, A::ROCMatrix{<:Complex}, sz) end instantiate_unitary(::Type{<:Diagonal}, A, sz) = Diagonal(fill!(similar(parent(A), eltype(A), sz), one(eltype(A)))) +include("ad_utils.jl") + include("qr.jl") include("lq.jl") include("polar.jl") @@ -93,5 +95,7 @@ include("eig.jl") include("eigh.jl") include("orthnull.jl") include("svd.jl") +include("mooncake.jl") +include("chainrules.jl") end diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl new file mode 100644 index 
00000000..31fb8ca1 --- /dev/null +++ b/test/testsuite/ad_utils.jl @@ -0,0 +1,423 @@ +function remove_svdgauge_dependence!( + ΔU, ΔVᴴ, U, S, Vᴴ; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S) + ) + gaugepart = mul!(U' * ΔU, Vᴴ, ΔVᴴ', true, true) + gaugepart = project_antihermitian!(gaugepart) + gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 + mul!(ΔU, U, gaugepart, -1, 1) + return ΔU, ΔVᴴ +end +function remove_eiggauge_dependence!( + ΔV, D, V; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) + ) + gaugepart = V' * ΔV + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end +function remove_eighgauge_dependence!( + ΔV, D, V; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D) + ) + gaugepart = V' * ΔV + gaugepart = project_antihermitian!(gaugepart) + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V, gaugepart, -1, 1) + return ΔV +end + +function stabilize_eigvals!(D::AbstractVector) + absD = collect(abs.(D)) + p = invperm(sortperm(collect(absD))) # rank of abs(D) + # account for exact degeneracies in absolute value when having complex conjugate pairs + for i in 1:(length(D) - 1) + if absD[i] == absD[i + 1] # conjugate pairs will appear sequentially + p[p .>= p[i + 1]] .-= 1 # lower the rank of all higher ones + end + end + n = maximum(p) + # rescale eigenvalues so that they lie on distinct radii in the complex plane + # that are chosen randomly in non-overlapping intervals [10 * k/n, 10 * (k+0.5)/n) for k=1,...,n + radii = 10 .* ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n + hD = sign.(collect(D)) .* radii[p] + copyto!(D, hD) + return D +end +function make_eig_matrix(T, sz) + A = instantiate_matrix(T, sz) + D, V = eig_full(A) + stabilize_eigvals!(diagview(D)) + Ac = V * D * inv(V) + Af = (eltype(T) <: Real) ? 
real(Ac) : Ac + if T <: Diagonal + copyto!(diagview(A), diagview(Af)) + else + copyto!(A, Af) + end + return A +end +function make_eigh_matrix(T, sz) + A = project_hermitian!(instantiate_matrix(T, sz)) + D, V = eigh_full(A) + stabilize_eigvals!(diagview(D)) + return project_hermitian!(V * D * V') +end + +function ad_qr_compact_setup(A) + m, n = size(A) + minmn = min(m, n) + QR = qr_compact(A) + T = eltype(A) + ΔQ = randn!(similar(A, T, m, minmn)) + ΔR = randn!(similar(A, T, minmn, n)) + return QR, (ΔQ, ΔR) +end + +function ad_qr_compact_setup(A::Diagonal) + m, n = size(A) + minmn = min(m, n) + QR = qr_compact(A) + T = eltype(A) + ΔQ = Diagonal(randn!(similar(A.diag, T, m))) + ΔR = Diagonal(randn!(similar(A.diag, T, m))) + return QR, (ΔQ, ΔR) +end + +function ad_qr_null_setup(A) + m, n = size(A) + minmn = min(m, n) + Q, R = qr_compact(A) + T = eltype(A) + ΔN = Q * randn!(similar(A, T, minmn, max(0, m - minmn))) + N = qr_null(A) + return N, ΔN +end + +function ad_qr_full_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + Q, R = qr_full(A) + Q1 = view(Q, 1:m, 1:minmn) + ΔQ = randn!(similar(A, T, m, m)) + ΔQ2 = view(ΔQ, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + ΔR = randn!(similar(A, T, m, n)) + return (Q, R), (ΔQ, ΔR) +end + +ad_qr_full_setup(A::Diagonal) = ad_qr_compact_setup(A) + +function ad_qr_rank_deficient_compact_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + r = minmn - 5 + Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) + Q, R = qr_compact(Ard) + QR = (Q, R) + ΔQ = randn!(similar(A, T, m, minmn)) + Q1 = view(Q, 1:m, 1:r) + Q2 = view(Q, 1:m, (r + 1):minmn) + ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) + MatrixAlgebraKit.zero!(ΔQ2) + ΔR = randn!(similar(A, T, minmn, n)) + view(ΔR, (r + 1):minmn, :) .= 0 + return (Q, R), (ΔQ, ΔR) +end + +function ad_qr_rank_deficient_compact_setup(A::Diagonal) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + r = minmn - 5 + Ard_ = randn!(similar(A, T, m)) + 
MatrixAlgebraKit.zero!(view(Ard_, (r + 1):m)) + Ard = Diagonal(Ard_) + Q, R = qr_compact(Ard) + ΔQ = Diagonal(randn!(similar(diagview(A), T, m))) + ΔR = Diagonal(randn!(similar(diagview(A), T, m))) + MatrixAlgebraKit.zero!(view(diagview(ΔQ), (r + 1):m)) + MatrixAlgebraKit.zero!(view(diagview(ΔR), (r + 1):m)) + return (Q, R), (ΔQ, ΔR) +end + +function ad_lq_compact_setup(A) + m, n = size(A) + minmn = min(m, n) + LQ = lq_compact(A) + T = eltype(A) + ΔL = randn!(similar(A, T, m, minmn)) + ΔQ = randn!(similar(A, T, minmn, n)) + return LQ, (ΔL, ΔQ) +end +ad_lq_compact_setup(A::Diagonal) = ad_qr_compact_setup(A) + +function ad_lq_null_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + L, Q = lq_compact(A) + ΔNᴴ = randn!(similar(A, T, max(0, n - minmn), minmn)) * Q + Nᴴ = randn!(similar(A, T, max(0, n - minmn), n)) + return Nᴴ, ΔNᴴ +end + +function ad_lq_full_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + L, Q = lq_full(A) + Q1 = view(Q, 1:minmn, 1:n) + ΔQ = randn!(similar(A, T, n, n)) + ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) + ΔQ2 .= (ΔQ2 * Q1') * Q1 + ΔL = randn!(similar(A, T, m, n)) + return (L, Q), (ΔL, ΔQ) +end +ad_lq_full_setup(A::Diagonal) = ad_qr_full_setup(A) + +function ad_lq_rank_deficient_compact_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + r = minmn - 5 + Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) + L, Q = lq_compact(Ard) + ΔL = randn!(similar(A, T, m, minmn)) + ΔQ = randn!(similar(A, T, minmn, n)) + Q1 = view(Q, 1:r, 1:n) + Q2 = view(Q, (r + 1):minmn, 1:n) + ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) + ΔQ2 .= 0 + view(ΔL, :, (r + 1):minmn) .= 0 + return (L, Q), (ΔL, ΔQ) +end +ad_lq_rank_deficient_compact_setup(A::Diagonal) = ad_qr_rank_deficient_compact_setup(A) + +function ad_eig_full_setup(A) + m, n = size(A) + T = eltype(A) + DV = eig_full(A) + D, V = DV + Ddiag = diagview(D) + ΔV = randn!(similar(A, complex(T), m, m)) + ΔV = remove_eiggauge_dependence!(ΔV, D, V) + ΔD = randn!(similar(A, 
complex(T), m, m)) + ΔD2 = Diagonal(randn!(similar(A, complex(T), m))) + return DV, (ΔD, ΔV), (ΔD2, ΔV) +end + +function ad_eig_full_setup(A::Diagonal) + m, n = size(A) + T = complex(eltype(A)) + DV = eig_full(A) + D, V = DV + ΔV = randn!(similar(A.diag, T, m, m)) + ΔV = remove_eiggauge_dependence!(ΔV, D, V) + ΔD = Diagonal(randn!(similar(A.diag, T, m))) + ΔD2 = Diagonal(randn!(similar(A.diag, T, m))) + return DV, (ΔD, ΔV), (ΔD2, ΔV) +end + +function ad_eig_vals_setup(A) + m, n = size(A) + T = complex(eltype(A)) + D = eig_vals(A) + ΔD = randn!(similar(A, complex(T), m)) + return D, ΔD +end + +function ad_eig_vals_setup(A::Diagonal) + m, n = size(A) + T = complex(eltype(A)) + D = eig_vals(A) + ΔD = randn!(similar(A.diag, T, m)) + return D, ΔD +end + +function ad_eig_trunc_setup(A, truncalg) + DV, ΔDV, ΔD2V = ad_eig_full_setup(A) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + Dtrunc = Diagonal(diagview(DV[1])[ind]) + Vtrunc = DV[2][:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2V[1])[ind]) + ΔVtrunc = ΔDV[2][:, ind] + return DV, (Dtrunc, Vtrunc), ΔD2V, (ΔDtrunc, ΔVtrunc) +end + +function ad_eigh_full_setup(A) + m, n = size(A) + T = eltype(A) + DV = eigh_full(A) + D, V = DV + Ddiag = diagview(D) + ΔV = randn!(similar(A, T, m, m)) + ΔV = remove_eighgauge_dependence!(ΔV, D, V) + ΔD = randn!(similar(A, real(T), m, m)) + ΔD2 = Diagonal(randn!(similar(A, real(T), m))) + return DV, (ΔD, ΔV), (ΔD2, ΔV) +end + +function ad_eigh_vals_setup(A) + m, n = size(A) + T = eltype(A) + D = eigh_vals(A) + ΔD = randn!(similar(A, real(T), m)) + return D, ΔD +end + +function ad_eigh_trunc_setup(A, truncalg) + DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + Dtrunc = Diagonal(diagview(DV[1])[ind]) + Vtrunc = DV[2][:, ind] + ΔDtrunc = Diagonal(diagview(ΔD2V[1])[ind]) + ΔVtrunc = ΔDV[2][:, ind] + return DV, (Dtrunc, Vtrunc), ΔD2V, (ΔDtrunc, ΔVtrunc) +end + +function ad_svd_compact_setup(A) + m, n = size(A) + T 
= eltype(A) + minmn = min(m, n) + ΔU = randn!(similar(A, T, m, minmn)) + ΔS = randn!(similar(A, real(T), minmn, minmn)) + ΔS2 = Diagonal(randn!(similar(A, real(T), minmn))) + ΔVᴴ = randn!(similar(A, T, minmn, n)) + U, S, Vᴴ = svd_compact(A) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), (ΔU, ΔS2, ΔVᴴ) +end + +function ad_svd_compact_setup(A::Diagonal) + m, n = size(A) + T = eltype(A) + minmn = min(m, n) + ΔU = randn!(similar(A.diag, T, m, n)) + ΔS = Diagonal(randn!(similar(A.diag, real(T), minmn))) + ΔS2 = Diagonal(randn!(similar(A.diag, real(T), minmn))) + ΔVᴴ = randn!(similar(A.diag, T, m, n)) + U, S, Vᴴ = svd_compact(A) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), (ΔU, ΔS2, ΔVᴴ) +end + +function ad_svd_full_setup(A) + m, n = size(A) + T = eltype(A) + minmn = min(m, n) + ΔU = randn!(similar(A, T, m, minmn)) + ΔS = randn!(similar(A, real(T), minmn, minmn)) + ΔS2 = Diagonal(randn!(similar(A, real(T), minmn))) + ΔVᴴ = randn!(similar(A, T, minmn, n)) + U, S, Vᴴ = svd_compact(A) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + ΔUfull = similar(A, T, m, m) + ΔUfull .= zero(T) + ΔSfull = similar(A, real(T), m, n) + ΔSfull .= zero(real(T)) + ΔVᴴfull = similar(A, T, n, n) + ΔVᴴfull .= zero(T) + U, S, Vᴴ = svd_full(A) + view(ΔUfull, :, 1:minmn) .= ΔU + view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴ + diagview(ΔSfull)[1:minmn] .= diagview(ΔS2) + return (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull) +end + +ad_svd_full_setup(A::Diagonal) = ad_svd_compact_setup(A) + +function ad_svd_vals_setup(A) + m, n = size(A) + minmn = min(m, n) + T = eltype(A) + S = svd_vals(A) + ΔS = randn!(similar(A, real(T), minmn)) + return S, ΔS +end + +function ad_svd_trunc_setup(A, truncalg) + USVᴴ, ΔUSVᴴ, ΔUS2Vᴴ = ad_svd_compact_setup(A) + ind = MatrixAlgebraKit.findtruncated(diagview(USVᴴ[2]), truncalg.trunc) + Strunc = Diagonal(diagview(USVᴴ[2])[ind]) + Utrunc = USVᴴ[1][:, ind] + Vᴴtrunc = USVᴴ[3][ind, :] + 
ΔStrunc = Diagonal(diagview(ΔUS2Vᴴ[2])[ind]) + ΔUtrunc = ΔUSVᴴ[1][:, ind] + ΔVᴴtrunc = ΔUSVᴴ[3][ind, :] + return USVᴴ, ΔUS2Vᴴ, (ΔUtrunc, ΔStrunc, ΔVᴴtrunc) +end + +function ad_left_polar_setup(A) + m, n = size(A) + T = eltype(A) + WP = left_polar(A) + ΔWP = (randn!(similar(A, T, m, n)), randn!(similar(A, T, n, n))) + return WP, ΔWP +end + +function ad_left_polar_setup(A::Diagonal) + m, n = size(A) + T = eltype(A) + WP = left_polar(A) + ΔWP = (Diagonal(randn!(similar(A.diag))), randn!(similar(WP[2]))) + return WP, ΔWP +end + +function ad_right_polar_setup(A) + m, n = size(A) + T = eltype(A) + PWᴴ = right_polar(A) + ΔPWᴴ = (randn!(similar(A, T, m, m)), randn!(similar(A, T, m, n))) + return PWᴴ, ΔPWᴴ +end +function ad_right_polar_setup(A::Diagonal) + m, n = size(A) + T = eltype(A) + PWᴴ = right_polar(A) + ΔPWᴴ = (randn!(similar(PWᴴ[1])), Diagonal(randn!(similar(A.diag)))) + return PWᴴ, ΔPWᴴ +end + +function ad_left_orth_setup(A) + m, n = size(A) + T = eltype(A) + VC = left_orth(A) + ΔVC = (randn!(similar(A, T, size(VC[1])...)), randn!(similar(A, T, size(VC[2])...))) + return VC, ΔVC +end +function ad_left_orth_setup(A::Diagonal) + m, n = size(A) + T = eltype(A) + VC = left_orth(A) + ΔVC = (Diagonal(randn!(similar(A.diag, T, m))), Diagonal(randn!(similar(A.diag, T, m)))) + return VC, ΔVC +end + +function ad_left_null_setup(A) + m, n = size(A) + T = eltype(A) + N = left_orth(A; alg = :qr)[1] * randn!(similar(A, T, min(m, n), m - min(m, n))) + ΔN = left_orth(A; alg = :qr)[1] * randn!(similar(A, T, min(m, n), m - min(m, n))) + return N, ΔN +end + +function ad_right_orth_setup(A) + m, n = size(A) + T = eltype(A) + CVᴴ = right_orth(A) + ΔCVᴴ = (randn!(similar(A, T, size(CVᴴ[1])...)), randn!(similar(A, T, size(CVᴴ[2])...))) + return CVᴴ, ΔCVᴴ +end +ad_right_orth_setup(A::Diagonal) = ad_left_orth_setup(A) + +function ad_right_null_setup(A) + m, n = size(A) + T = eltype(A) + Nᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2] + ΔNᴴ = 
randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2] + return Nᴴ, ΔNᴴ +end diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl new file mode 100644 index 00000000..b4126c59 --- /dev/null +++ b/test/testsuite/chainrules.jl @@ -0,0 +1,612 @@ +using MatrixAlgebraKit +using ChainRulesCore, ChainRulesTestUtils, Zygote +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD +using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! + +for f in + ( + :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, + :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, + :eig_trunc_no_error, :eigh_trunc_no_error, + :svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals, + :left_polar, :right_polar, + ) + copy_f = Symbol(:cr_copy_, f) + f! = Symbol(f, '!') + _hermitian = startswith(string(f), "eigh") + @eval begin + function $copy_f(input, alg) + if $_hermitian + input = (input + input') / 2 + end + return $f(input, alg) + end + function ChainRulesCore.rrule(::typeof($copy_f), input, alg) + output = MatrixAlgebraKit.initialize_output($f!, input, alg) + if $_hermitian + input = (input + input') / 2 + else + input = copy(input) + end + output, pb = ChainRulesCore.rrule($f!, input, output, alg) + return output, x -> (NoTangent(), pb(x)[2], NoTangent()) + end + end +end + +function test_chainrules(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Chainrules AD $summary_str" begin + test_chainrules_qr(T, sz; kwargs...) + test_chainrules_lq(T, sz; kwargs...) + if length(sz) == 1 || sz[1] == sz[2] + test_chainrules_eig(T, sz; kwargs...) + test_chainrules_eigh(T, sz; kwargs...) + end + test_chainrules_svd(T, sz; kwargs...) + test_chainrules_polar(T, sz; kwargs...) + test_chainrules_orthnull(T, sz; kwargs...) + end +end + +function test_chainrules_qr( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "QR ChainRules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_qr_algorithm(A) + @testset "qr_compact" begin + QR, ΔQR = ad_qr_compact_setup(A) + ΔQ, ΔR = ΔQR + test_rrule( + cr_copy_qr_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔQR, atol = atol, rtol = rtol + ) + test_rrule( + config, qr_compact, A; + fkwargs = (; positive = true), output_tangent = ΔQR, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, first ∘ qr_compact, A; + fkwargs = (; positive = true), output_tangent = ΔQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, last ∘ qr_compact, A; + fkwargs = (; positive = true), output_tangent = ΔR, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "qr_null" begin + N, ΔN = ad_qr_null_setup(A) + test_rrule( + cr_copy_qr_null, A, alg ⊢ NoTangent(); + output_tangent = ΔN, atol = atol, rtol = rtol + ) + test_rrule( + config, qr_null, A; + fkwargs = (; positive = true), output_tangent = ΔN, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + m, n = size(A) + end + @testset "qr_full" begin + QR, ΔQR = ad_qr_full_setup(A) + test_rrule( + cr_copy_qr_full, A, alg ⊢ NoTangent(); + output_tangent = ΔQR, atol = atol, rtol = rtol + ) + test_rrule( + config, qr_full, A; + fkwargs = (; positive = true), output_tangent = ΔQR, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + m, n = size(A) + end + @testset "qr_compact - rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard) + ΔQ, ΔR = ΔQR + test_rrule( + cr_copy_qr_compact, Ard, alg ⊢ NoTangent(); + output_tangent = ΔQR, atol = atol, rtol = 
rtol + ) + test_rrule( + config, qr_compact, Ard; + fkwargs = (; positive = true), output_tangent = ΔQR, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end +end + +function test_chainrules_lq( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "LQ Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_lq_algorithm(A) + @testset "lq_compact" begin + LQ, ΔLQ = ad_lq_compact_setup(A) + ΔL, ΔQ = ΔLQ + test_rrule( + cr_copy_lq_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔLQ, atol = atol, rtol = rtol + ) + test_rrule( + config, lq_compact, A; + fkwargs = (; positive = true), output_tangent = ΔLQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, first ∘ lq_compact, A; + fkwargs = (; positive = true), output_tangent = ΔL, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, last ∘ lq_compact, A; + fkwargs = (; positive = true), output_tangent = ΔQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "lq_null" begin + Nᴴ, ΔNᴴ = ad_lq_null_setup(A) + test_rrule( + cr_copy_lq_null, A, alg ⊢ NoTangent(); + output_tangent = ΔNᴴ, atol = atol, rtol = rtol + ) + test_rrule( + config, lq_null, A; + fkwargs = (; positive = true), output_tangent = ΔNᴴ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "lq_full" begin + LQ, ΔLQ = ad_lq_full_setup(A) + test_rrule( + cr_copy_lq_full, A, alg ⊢ NoTangent(); + output_tangent = ΔLQ, atol = atol, rtol = rtol + ) + test_rrule( + config, lq_full, A; + fkwargs = (; positive = true), output_tangent = ΔLQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "lq_compact - 
rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(Ard) + test_rrule( + cr_copy_lq_compact, Ard, alg ⊢ NoTangent(); + output_tangent = ΔLQ, atol = atol, rtol = rtol + ) + test_rrule( + config, lq_compact, Ard; + fkwargs = (; positive = true), output_tangent = ΔLQ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end +end + +function test_chainrules_eig( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "EIG Chainrules AD rules $summary_str" begin + A = make_eig_matrix(T, sz) + m = size(A, 1) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_eig_algorithm(A) + @testset "eig_full" begin + DV, ΔDV, ΔD2V = ad_eig_full_setup(A) + ΔD, ΔV = ΔDV + test_rrule( + cr_copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = ΔDV, atol, rtol + ) + test_rrule( + cr_copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = ΔD2V, atol, rtol + ) + test_rrule( + config, eig_full, A, alg ⊢ NoTangent(); + output_tangent = ΔDV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, eig_full, A, alg ⊢ NoTangent(); + output_tangent = ΔD2V, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, first ∘ eig_full, A, alg ⊢ NoTangent(); + output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, last ∘ eig_full, A, alg ⊢ NoTangent(); + output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "eig_vals" begin + D, ΔD = ad_eig_vals_setup(A) + test_rrule( + cr_copy_eig_vals, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol, rtol + ) + test_rrule( + config, eig_vals, A, alg ⊢ NoTangent(); + output_tangent = ΔD, atol, rtol, rrule_f 
= rrule_via_ad, check_inferred = false + ) + end + @testset "eig_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + test_rrule( + cr_copy_eig_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_eig_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔDVtrunc, atol = atol, rtol = rtol + ) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + test_rrule( + cr_copy_eig_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_eig_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔDVtrunc, atol = atol, rtol = rtol + ) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + end + end +end + +function test_chainrules_eigh( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "EIGH ChainRules AD rules $summary_str" begin + A = make_eigh_matrix(T, sz) + m = size(A, 1) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_eigh_algorithm(A) + # cr_copy_eigh_xxxx includes a projector onto the Hermitian part of the matrix + @testset "eigh_full" begin + DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) + ΔD, ΔV = ΔDV + test_rrule( + cr_copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = ΔDV, atol, rtol + ) + test_rrule( + cr_copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = ΔD2V, atol, rtol + ) + # eigh_full does not include a projector onto the Hermitian part of the matrix + test_rrule( + config, eigh_full ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔDV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, eigh_full ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔD2V, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, first ∘ eigh_full ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔD, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, last ∘ eigh_full ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "eigh_vals" begin + D, ΔD = ad_eigh_vals_setup(A) + test_rrule( + cr_copy_eigh_vals, A, alg ⊢ NoTangent(); output_tangent = ΔD, atol, rtol + ) + test_rrule( + config, eigh_vals ∘ Matrix ∘ Hermitian, A; + output_tangent = ΔD, atol, rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "eigh_trunc" begin + eigh_trunc2(A; kwargs...) = eigh_trunc(Matrix(Hermitian(A)); kwargs...) + eigh_trunc_no_error2(A; kwargs...) = eigh_trunc_no_error(Matrix(Hermitian(A)); kwargs...) 
+ for r in 1:4:m + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + test_rrule( + cr_copy_eigh_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_eigh_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔDVtrunc, atol = atol, rtol = rtol + ) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + trunc = truncrank(r; by = real) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), trunc) + truncalg = TruncatedAlgorithm(alg, trunc) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + test_rrule( + config, eigh_trunc2, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, eigh_trunc_no_error2, A; + fkwargs = (; trunc = trunc), + output_tangent = ΔDVtrunc, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + D, ΔD = ad_eigh_vals_setup(A / 2) + truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, D) / 2)) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + test_rrule( + cr_copy_eigh_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_eigh_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔDVtrunc, atol = atol, rtol = rtol + ) + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test isapprox(dA1, dA2; atol = atol, 
rtol = rtol) + trunc = trunctol(; rtol = 1 / 2) + truncalg = TruncatedAlgorithm(alg, trunc) + DV, DVtrunc, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + test_rrule( + config, eigh_trunc2, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔDVtrunc..., zero(real(T))), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, eigh_trunc_no_error2, A; + fkwargs = (; trunc = trunc), + output_tangent = ΔDVtrunc, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end +end + +function test_chainrules_svd( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "SVD Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + minmn = min(size(A)...) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_svd_algorithm(A) + @testset "svd_compact" begin + USV, ΔUSVᴴ, ΔUS2Vᴴ = ad_svd_compact_setup(A) + test_rrule( + cr_copy_svd_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔUSVᴴ, atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_svd_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔUS2Vᴴ, atol = atol, rtol = rtol + ) + test_rrule( + config, svd_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔUSVᴴ, atol = atol, rtol = rtol, + rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, svd_compact, A, alg ⊢ NoTangent(); + output_tangent = ΔUS2Vᴴ, atol = atol, rtol = rtol, + rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "svd_vals" begin + S, ΔS = ad_svd_vals_setup(A) + test_rrule( + cr_copy_svd_vals, A, alg ⊢ NoTangent(); + output_tangent = ΔS, atol, rtol + ) + test_rrule( + config, svd_vals, A, alg ⊢ NoTangent(); + output_tangent = ΔS, atol, rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "svd_trunc" begin + @testset for r in 1:4:minmn + 
truncalg = TruncatedAlgorithm(alg, truncrank(r)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + test_rrule( + cr_copy_svd_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔUSVᴴtrunc, + atol = atol, rtol = rtol + ) + U, S, Vᴴ = USVᴴ + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, USVᴴ, ΔUSVᴴtrunc, ind) + dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), ΔUSVᴴtrunc) + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + trunc = truncrank(r) + ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) + test_rrule( + config, svd_trunc, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, svd_trunc_no_error, A; + fkwargs = (; trunc = trunc), + output_tangent = ΔUSVᴴtrunc, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + S, ΔS = ad_svd_vals_setup(A) + truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + test_rrule( + cr_copy_svd_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + atol = atol, rtol = rtol + ) + test_rrule( + cr_copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = ΔUSVᴴtrunc, atol = atol, rtol = rtol + ) + U, S, Vᴴ = USVᴴ + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, USVᴴ, ΔUSVᴴtrunc, ind) + 
dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), ΔUSVᴴtrunc) + @test isapprox(dA1, dA2; atol = atol, rtol = rtol) + trunc = trunctol(; atol = S[1, 1] / 2) + ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) + test_rrule( + config, svd_trunc, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔUSVᴴtrunc..., zero(real(T))), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, svd_trunc_no_error, A; + fkwargs = (; trunc = trunc), + output_tangent = ΔUSVᴴtrunc, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end +end + +function test_chainrules_polar( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Polar Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + config = Zygote.ZygoteRuleConfig() + alg = MatrixAlgebraKit.default_polar_algorithm(A) + @testset "left_polar" begin + if m >= n + test_rrule(cr_copy_left_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) + test_rrule( + config, left_polar, A, alg ⊢ NoTangent(); + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end + @testset "right_polar" begin + if m <= n + test_rrule(cr_copy_right_polar, A, alg ⊢ NoTangent(); atol = atol, rtol = rtol) + test_rrule( + config, right_polar, A, alg ⊢ NoTangent(); + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end + end +end + +function test_chainrules_orthnull( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "Orthnull Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + config = Zygote.ZygoteRuleConfig() + N, ΔN = ad_left_null_setup(A) + Nᴴ, ΔNᴴ = ad_right_null_setup(A) + test_rrule( + config, left_orth, A; + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, left_orth, A; + fkwargs = (; alg = :qr), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + m >= n && + test_rrule( + config, left_orth, A; + fkwargs = (; alg = :polar), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, left_null, A; + fkwargs = (; alg = :qr), output_tangent = ΔN, atol = atol, rtol = rtol, + rrule_f = rrule_via_ad, check_inferred = false + ) + + test_rrule( + config, right_orth, A; + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, right_orth, A; fkwargs = (; alg = :lq), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + m <= n && + test_rrule( + config, right_orth, A; fkwargs = (; alg = :polar), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + test_rrule( + config, right_null, A; + fkwargs = (; alg = :lq), output_tangent = ΔNᴴ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end +end diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl new file mode 100644 index 00000000..d10bc0fc --- /dev/null +++ b/test/testsuite/enzyme.jl @@ -0,0 +1,459 @@ +using TestExtras +using MatrixAlgebraKit +using Enzyme, EnzymeTestUtils +using MatrixAlgebraKit: diagview, TruncatedAlgorithm +using LinearAlgebra: Diagonal, Hermitian, mul!, BlasFloat +using GenericLinearAlgebra, GenericSchur + +function enz_copy_eigh_full(A, alg) + A = (A + A') / 2 + return eigh_full(A, alg) +end + +function enz_copy_eigh_full!(A, DV::Tuple, 
alg::MatrixAlgebraKit.AbstractAlgorithm) + A = (A + A') / 2 + return eigh_full!(A, DV, alg) +end + +function enz_copy_eigh_vals(A; kwargs...) + A = (A + A') / 2 + return eigh_vals(A; kwargs...) +end + +function enz_copy_eigh_vals!(A, D; kwargs...) + A = (A + A') / 2 + return eigh_vals!(A, D; kwargs...) +end + +function enz_copy_eigh_vals(A, alg; kwargs...) + A = (A + A') / 2 + return eigh_vals(A, alg; kwargs...) +end + +function enz_copy_eigh_vals!(A, D, alg; kwargs...) + A = (A + A') / 2 + return eigh_vals!(A, D, alg; kwargs...) +end + +function enz_copy_eigh_trunc_no_error(A, alg) + A = (A + A') / 2 + return eigh_trunc_no_error(A, alg) +end + +function enz_copy_eigh_trunc_no_error!(A, DV, alg) + A = (A + A') / 2 + return eigh_trunc_no_error!(A, DV, alg) +end + +function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; ȳ = copy.(Δargs), return_act = Duplicated) + ΔA = randn(rng, eltype(A), size(A)...) + A_ΔA() = Duplicated(copy(A), copy(ΔA)) + function args_Δargs() + if isnothing(args) + return Const(args) + elseif args isa Tuple && all(isnothing, args) + return Const(args) + else + return Duplicated(copy.(args), copy.(Δargs)) + end + end + copy_activities = isnothing(alg) ? (Const(f), A_ΔA()) : (Const(f), A_ΔA(), Const(alg)) + inplace_activities = isnothing(alg) ? (Const(f!), A_ΔA(), args_Δargs()) : (Const(f!), A_ΔA(), args_Δargs(), Const(alg)) + + mode = EnzymeTestUtils.set_runtime_activity(ReverseSplitWithPrimal, false) + c_act = Const(EnzymeTestUtils.call_with_kwargs) + forward_copy, reverse_copy = autodiff_thunk( + mode, typeof(c_act), return_act, typeof(Const(())), map(typeof, copy_activities)... + ) + forward_inplace, reverse_inplace = autodiff_thunk( + mode, typeof(c_act), return_act, typeof(Const(())), map(typeof, inplace_activities)... + ) + copy_tape, copy_y_ad, copy_shadow_result = forward_copy(c_act, Const(()), copy_activities...) 
+ inplace_tape, inplace_y_ad, inplace_shadow_result = forward_inplace(c_act, Const(()), inplace_activities...) + if !(copy_shadow_result === nothing) + flush(stdout) + EnzymeTestUtils.map_fields_recursive(copyto!, copy_shadow_result, copy.(ȳ)) + end + if !(inplace_shadow_result === nothing) + EnzymeTestUtils.map_fields_recursive(copyto!, inplace_shadow_result, copy.(ȳ)) + end + dx_copy_ad = only(reverse_copy(c_act, Const(()), copy_activities..., copy_tape)) + dx_inplace_ad = only(reverse_inplace(c_act, Const(()), inplace_activities..., inplace_tape)) + # check all returned derivatives between copy & inplace + for (i, (copy_act_i, inplace_act_i)) in enumerate(zip(copy_activities[2:end], inplace_activities[2:end])) + if copy_act_i isa Duplicated && inplace_act_i isa Duplicated + msg_deriv = "shadow derivative for argument $(i - 1) should match between copy and inplace" + EnzymeTestUtils.test_approx(copy_act_i.dval, inplace_act_i.dval, msg_deriv) + end + end + return +end + +function test_enzyme(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Enzyme AD $summary_str" begin + test_enzyme_qr(T, sz; kwargs...) + test_enzyme_lq(T, sz; kwargs...) + if length(sz) == 1 || sz[1] == sz[2] + test_enzyme_eig(T, sz; kwargs...) + test_enzyme_eigh(T, sz; kwargs...) + end + test_enzyme_svd(T, sz; kwargs...) + if eltype(T) <: BlasFloat + test_enzyme_polar(T, sz; kwargs...) + test_enzyme_orthnull(T, sz; kwargs...) + end + end +end + +function test_enzyme_qr( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "QR Enzyme AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + fdm = T <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + alg = MatrixAlgebraKit.default_qr_algorithm(A) + @testset "qr_compact" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + QR, ΔQR = ad_qr_compact_setup(A) + eltype(T) <: BlasFloat && test_reverse(qr_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) + test_pullbacks_match(rng, qr_compact!, qr_compact, A, QR, ΔQR, alg) + end + end + @testset "qr_null" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + N, ΔN = ad_qr_null_setup(A) + eltype(T) <: BlasFloat && test_reverse(qr_null, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔN) + test_pullbacks_match(rng, qr_null!, qr_null, A, N, ΔN, alg) + end + end + @testset "qr_full" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + QR, ΔQR = ad_qr_full_setup(A) + eltype(T) <: BlasFloat && test_reverse(qr_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = (ΔQ, ΔR), fdm) + test_pullbacks_match(rng, qr_full!, qr_full, A, (Q, R), (ΔQ, ΔR), alg) + end + end + @testset "qr_compact - rank-deficient A" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard) + eltype(T) <: BlasFloat && test_reverse(qr_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = (ΔQ, ΔR), fdm) + test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, (Q, R), (ΔQ, ΔR), alg) + end + end + end +end + +function test_enzyme_lq( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "LQ Enzyme AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.default_lq_algorithm(A) + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + @testset "lq_compact" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + LQ, ΔLQ = ad_lq_compact_setup(A) + eltype(T) <: BlasFloat && test_reverse(lq_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = (ΔL, ΔQ), fdm) + test_pullbacks_match(rng, lq_compact!, lq_compact, A, (L, Q), (ΔL, ΔQ), alg) + end + end + @testset "lq_null" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + Nᴴ, ΔNᴴ = ad_lq_null_setup(A) + eltype(T) <: BlasFloat && test_reverse(lq_null, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔNᴴ) + test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ, alg) + end + end + @testset "lq_full" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + LQ, ΔLQ = ad_lq_full_setup(A) + eltype(T) <: BlasFloat && test_reverse(lq_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = (ΔL, ΔQ), fdm) + test_pullbacks_match(rng, lq_full!, lq_full, A, (L, Q), (ΔL, ΔQ), alg) + end + end + @testset "lq_compact -- rank-deficient A" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(Ard) + eltype(T) <: BlasFloat && test_reverse(lq_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = (ΔL, ΔQ), fdm) + test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, (L, Q), (ΔL, ΔQ), alg) + end + end + end +end + +function test_enzyme_eig( + T::Type, sz; + atol::Real = 0, rtol::Real = 
precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "EIG Enzyme AD rules $summary_str" begin + A = make_eig_matrix(T, sz) + m = size(A, 1) + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + alg = MatrixAlgebraKit.default_eig_algorithm(A) + @testset "eig_full" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + DV, ΔDV, ΔD2V = ad_eig_full_setup(A) + if eltype(T) <: BlasFloat + test_reverse(eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm) + test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg) + else + test_pullbacks_match(rng, eig_full!, eig_full, A, (nothing, nothing), (nothing, nothing), alg; ȳ = (ΔD2, ΔV)) + end + end + end + @testset "eig_vals" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + D, ΔD = ad_eig_vals_setup(A) + if eltype(T) <: BlasFloat + test_reverse(eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = copy(ΔD2.diag), fdm) + test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg) + else + test_pullbacks_match(rng, eig_vals!, eig_vals, A, nothing, nothing, alg; ȳ = ΔD2.diag) + end + end + end + @testset "eig_trunc" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + for r in 1:4:m + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(r; by = abs)) + DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + if eltype(T) <: BlasFloat + test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm) + test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc)) + else + test_pullbacks_match(rng, eig_trunc_no_error!, 
eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = (ΔDtrunc, ΔVtrunc)) + end + end + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real)) + DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + if eltype(T) <: BlasFloat + test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm) + test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc)) + else + test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = (ΔDtrunc, ΔVtrunc)) + end + end + end + end +end + +function test_enzyme_eigh( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "EIGH Enzyme AD rules $summary_str" begin + A = make_eigh_matrix(T, sz) + m = size(A, 1) + alg = MatrixAlgebraKit.default_eigh_algorithm(A) + fdm = eltype(T) <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + @testset "eigh_full" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + if eltype(T) <: BlasFloat + test_reverse(copy_eigh_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm) + test_reverse(copy_eigh_full!, RT, (copy(A), TA), ((D, V), TA), (alg, Const); atol, rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm) + end + test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg) + end + end + @testset "eigh_vals" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + eltype(T) <: BlasFloat && test_reverse(copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = copy(ΔD2.diag), fdm) + test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg) + end + end + @testset "eigh_trunc" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + for r in 1:4:m + Ddiag = diagview(D) + truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) + DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + eltype(T) <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm) + test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT) + end + D = eigh_vals(A / 2) + truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, D) / 2)) + DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + eltype(T) <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm) + test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, 
ΔVtrunc), return_act = RT) + end + end + end +end + +function test_enzyme_svd( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "SVD Enzyme AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + minmn = min(size(A)...) + alg = MatrixAlgebraKit.default_svd_algorithm(A) + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + @testset "svd_compact" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + USVᴴ, _, ΔUSVᴴ = ad_svd_compact_setup(A) + if eltype(T) <: BlasFloat + test_reverse(svd_compact, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = (ΔU, ΔS, ΔVᴴ), fdm) + test_pullbacks_match(rng, svd_compact!, svd_compact, A, (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), alg) + else + USVᴴ = MatrixAlgebraKit.initialize_output(svd_compact!, A, alg) + test_pullbacks_match(rng, svd_compact!, svd_compact, A, USVᴴ, (nothing, nothing, nothing), alg; ȳ = (ΔU, ΔS, ΔVᴴ)) + end + end + end + @testset "svd_full" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) + if eltype(T) <: BlasFloat + test_reverse(svd_full, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔUSVᴴ, fdm) + test_pullbacks_match(rng, svd_full!, svd_full, A, USVᴴ, ΔUSVᴴ, alg) + else + USVᴴ = MatrixAlgebraKit.initialize_output(svd_full!, A, alg) + test_pullbacks_match(rng, svd_full!, svd_full, A, USVᴴ, (nothing, nothing, nothing), alg; ȳ = ΔUSVᴴ) + end + end + end + @testset "svd_vals" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + S, ΔS = ad_svd_vals_setup(A) + if eltype(T) <: BlasFloat + test_reverse(svd_vals, RT, (A, TA); atol, rtol, fkwargs = (alg = alg,), output_tangent = ΔS, fdm) + test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, 
ΔS, alg) + else + S = MatrixAlgebraKit.initialize_output(svd_vals!, A, alg) + test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, nothing, alg; ȳ = ΔS) + end + end + end + @testset "svd_trunc" begin + S, ΔS = ad_svd_vals_setup(A) + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + for r in 1:4:minmn + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), truncrank(r)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + if eltype(T) <: BlasFloat + test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) + test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg, ȳ = ΔUSVᴴtrunc) + else + test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (nothing, nothing, nothing), (nothing, nothing, nothing), truncalg, ȳ = ΔUSVᴴtrunc) + end + end + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), trunctol(atol = S[1, 1] / 2)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + if eltype(T) <: BlasFloat + test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) + test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg, ȳ = ΔUSVᴴtrunc) + else + test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (nothing, nothing, nothing), (nothing, nothing, nothing), truncalg, ȳ = ΔUSVᴴtrunc) + end + end + end + end +end + +# GLA works with polar, but these tests +# segfault because of Sylvester + BigFloat +function test_enzyme_polar( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "Polar Enzyme AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + alg = MatrixAlgebraKit.default_polar_algorithm(A) + @testset "left_polar" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + if m >= n + WP, ΔWP = ad_left_polar_setup(A) + eltype(T) <: BlasFloat && test_reverse(left_polar, RT, (A, TA), (alg, Const); atol, rtol) + test_pullbacks_match(rng, left_polar!, left_polar, A, WP, ΔWP, alg) + end + end + end + @testset "right_polar" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + if m <= n + PWᴴ, ΔPWᴴ = ad_right_polar_setup(A) + eltype(T) <: BlasFloat && test_reverse(right_polar, RT, (A, TA), (alg, Const); atol, rtol) + test_pullbacks_match(rng, right_polar!, right_polar, A, PWᴴ, ΔPWᴴ, alg) + end + end + end + end +end + +function test_enzyme_orthnull( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Orthnull Enzyme AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + VC, ΔVC = ad_left_orth_setup(A) + CVᴴ, ΔCVᴴ = ad_right_orth_setup(A) + fdm = eltype(T) <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + @testset "left_orth" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset for alg in (:polar, :qr) + n > m && alg == :polar && continue + eltype(T) <: BlasFloat && test_reverse(left_orth, RT, (A, TA); atol, rtol, fkwargs = (alg = alg,), fdm) + left_orth_alg!(A, VC) = left_orth!(A, VC; alg = alg) + left_orth_alg(A) = left_orth(A; alg = alg) + test_pullbacks_match(rng, left_orth_alg!, left_orth_alg, A, VC, ΔVC) + end + end + end + N, ΔN = ad_left_null_setup(A) + @testset "left_null" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + left_null_qr!(A, N) = left_null!(A, N; alg = :qr) + left_null_qr(A) = left_null(A; alg = :qr) + eltype(T) <: BlasFloat && test_reverse(left_null_qr, RT, (A, TA); output_tangent = ΔN, atol, rtol) + test_pullbacks_match(rng, left_null_qr!, left_null_qr, A, N, ΔN) + end + end + @testset "right_orth" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + @testset for alg in (:polar, :lq) + n < m && alg == :polar && continue + eltype(T) <: BlasFloat && test_reverse(right_orth, RT, (A, TA); atol, rtol, fkwargs = (alg = alg,), fdm) + right_orth_alg!(A, CVᴴ) = right_orth!(A, CVᴴ; alg = alg) + right_orth_alg(A) = right_orth(A; alg = alg) + test_pullbacks_match(rng, right_orth_alg!, right_orth_alg, A, CVᴴ, ΔCVᴴ) + end + end + end + Nᴴ, ΔNᴴ = ad_right_null_setup(A) + @testset "right_null" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + right_null_lq!(A, Nᴴ) = right_null!(A, Nᴴ; alg = :lq) + right_null_lq(A) = right_null(A; alg = :lq) + eltype(T) <: BlasFloat && test_reverse(right_null_lq, RT, (A, TA); output_tangent = ΔNᴴ, atol, rtol) + test_pullbacks_match(rng, right_null_lq!, right_null_lq, A, Nᴴ, ΔNᴴ) + end + end + end +end diff --git 
a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl new file mode 100644 index 00000000..0ea1b018 --- /dev/null +++ b/test/testsuite/mooncake.jl @@ -0,0 +1,481 @@ +using TestExtras +using MatrixAlgebraKit +using Mooncake, Mooncake.TestUtils +using Mooncake: rrule!! +using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD, eigh_trunc +using LinearAlgebra: BlasFloat +using GenericLinearAlgebra + +function mc_copy_eigh_full(A; kwargs...) + A = (A + A') / 2 + return eigh_full(A; kwargs...) +end + +function mc_copy_eigh_full!(A, DV; kwargs...) + A = (A + A') / 2 + return eigh_full!(A, DV; kwargs...) +end + +function mc_copy_eigh_vals(A; kwargs...) + A = (A + A') / 2 + return eigh_vals(A; kwargs...) +end + +function mc_copy_eigh_vals!(A, D; kwargs...) + A = (A + A') / 2 + return eigh_vals!(A, D; kwargs...) +end + +function mc_copy_eigh_trunc(A, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc(A, alg; kwargs...) +end + +function mc_copy_eigh_trunc!(A, DV, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc!(A, DV, alg; kwargs...) +end + +function mc_copy_eigh_trunc_no_error(A, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc_no_error(A, alg; kwargs...) +end + +function mc_copy_eigh_trunc_no_error!(A, DV, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc_no_error!(A, DV, alg; kwargs...) 
+end + +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A) +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A) +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc_no_error), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) + +make_mooncake_tangent(ΔAelem::T) where {T <: Real} = ΔAelem +make_mooncake_tangent(ΔAelem::T) where {T <: Complex} = Mooncake.build_tangent(T, real(ΔAelem), imag(ΔAelem)) +make_mooncake_tangent(ΔA::AbstractMatrix{<:Real}) = ΔA +make_mooncake_tangent(ΔA::AbstractVector{<:Real}) = ΔA +make_mooncake_tangent(ΔA::AbstractMatrix{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) +make_mooncake_tangent(ΔA::AbstractVector{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) +make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Real} = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) +make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Complex} = Mooncake.build_tangent(typeof(ΔD), map(make_mooncake_tangent, diagview(ΔD))) + +make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), make_mooncake_tangent.(T)...) + +make_mooncake_fdata(x) = make_mooncake_tangent(x) +make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),)) +make_mooncake_fdata(x::Tuple) = map(make_mooncake_fdata, x) + +# no `alg` argument +function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) + dA_copy = make_mooncake_fdata(copy(ΔA)) + A_copy = copy(A) + dargs_copy = make_mooncake_fdata(deepcopy(Δargs)) + copy_out, copy_pb!! 
= rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy)) + copy_pb!!(rdata) + return dA_copy +end + +# `alg` argument +function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) + dA_copy = make_mooncake_fdata(copy(ΔA)) + A_copy = copy(A) + dargs_copy = make_mooncake_fdata(deepcopy(Δargs)) + copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData())) + copy_pb!!(rdata) + return dA_copy +end + +function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata) + dA_inplace = make_mooncake_fdata(copy(ΔA)) + A_inplace = copy(A) + dargs_inplace = make_mooncake_fdata(deepcopy(Δargs)) + # not every f! has a handwritten rrule!! + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} + has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) + if has_handwritten_rule + inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) + else + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) + inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) + end + inplace_pb!!(rdata) + return dA_inplace +end + +function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) + dA_inplace = make_mooncake_fdata(copy(ΔA)) + A_inplace = copy(A) + dargs_inplace = make_mooncake_fdata(deepcopy(Δargs)) + # not every f! has a handwritten rrule!! + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} + has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) + if has_handwritten_rule + inplace_out, inplace_pb!! 
= Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) + else + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) + inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) + end + inplace_pb!!(rdata) + return dA_inplace +end + +""" + test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) + +Compare the result of running the *in-place, mutating* function `f!`'s reverse rule +with the result of running its *non-mutating* partner function `f`'s reverse rule. +We must compare directly because many of the mutating functions modify `A` as a +scratch workspace, making testing `f!` against finite differences infeasible. + +The arguments to this function are: + - `f!` the mutating, in-place version of the function (accepts `args` for the function result) + - `f` the non-mutating version of the function (does not accept `args` for the function result) + - `A` the input matrix to factorize + - `args` preallocated output for `f!` (e.g. `Q` and `R` matrices for `qr_compact!`) + - `Δargs` precomputed derivatives of `args` for pullbacks of `f` and `f!`, to ensure they receive the same input + - `alg` optional algorithm keyword argument + - `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do) +""" +function test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) + f_c = isnothing(alg) ? (A, args) -> f!(MatrixAlgebraKit.copy_input(f, A), args) : (A, args, alg) -> f!(MatrixAlgebraKit.copy_input(f, A), args, alg) + sig = isnothing(alg) ? 
Tuple{typeof(f_c), typeof(A), typeof(args)} : Tuple{typeof(f_c), typeof(A), typeof(args), typeof(alg)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + rrule = Mooncake.build_rrule(rvs_interp, sig) + ΔA = isa(A, Diagonal) ? Diagonal(randn!(similar(A.diag))) : randn!(similar(A)) + + dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) + dA_inplace = _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) + + dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2] + dA_copy_ = Mooncake.arrayify(A, dA_copy)[2] + @test dA_inplace_ ≈ dA_copy_ + return +end + +function test_mooncake(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Mooncake AD $summary_str" begin + test_mooncake_qr(T, sz; kwargs...) + test_mooncake_lq(T, sz; kwargs...) + if length(sz) == 1 || sz[1] == sz[2] + test_mooncake_eig(T, sz; kwargs...) + test_mooncake_eigh(T, sz; kwargs...) + end + test_mooncake_svd(T, sz; kwargs...) + test_mooncake_polar(T, sz; kwargs...) + # doesn't work for Diagonals yet? + if T <: Number + test_mooncake_orthnull(T, sz; kwargs...) + end + end +end + +function test_mooncake_qr( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "QR Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + @testset "qr_compact" begin + QR, ΔQR = ad_qr_compact_setup(A) + dQR = make_mooncake_tangent(ΔQR) + Mooncake.TestUtils.test_rule(rng, qr_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol = atol, rtol = rtol) + test_pullbacks_match(qr_compact!, qr_compact, A, QR, ΔQR) + end + @testset "qr_null" begin + N, ΔN = ad_qr_null_setup(A) + dN = make_mooncake_tangent(copy(ΔN)) + Mooncake.TestUtils.test_rule(rng, qr_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dN, atol = atol, rtol = rtol) + test_pullbacks_match(qr_null!, qr_null, A, N, ΔN) + end + @testset "qr_full" begin + QR, ΔQR = ad_qr_full_setup(A) + dQR = make_mooncake_tangent(ΔQR) + Mooncake.TestUtils.test_rule(rng, qr_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol = atol, rtol = rtol) + test_pullbacks_match(qr_full!, qr_full, A, QR, ΔQR) + end + @testset "qr_compact - rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard) + dQR = make_mooncake_tangent(ΔQR) + Mooncake.TestUtils.test_rule(rng, qr_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol = atol, rtol = rtol) + test_pullbacks_match(qr_compact!, qr_compact, Ard, QR, ΔQR) + end + end +end + +function test_mooncake_lq( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "LQ Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + @testset "lq_compact" begin + LQ, ΔLQ = ad_lq_compact_setup(A) + Mooncake.TestUtils.test_rule(rng, lq_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(lq_compact!, lq_compact, A, LQ, ΔLQ) + end + @testset "lq_null" begin + Nᴴ, ΔNᴴ = ad_lq_null_setup(A) + dNᴴ = make_mooncake_tangent(ΔNᴴ) + Mooncake.TestUtils.test_rule(rng, lq_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dNᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(lq_null!, lq_null, A, Nᴴ, ΔNᴴ) + end + @testset "lq_full" begin + LQ, ΔLQ = ad_lq_full_setup(A) + dLQ = make_mooncake_tangent(ΔLQ) + Mooncake.TestUtils.test_rule(rng, lq_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol = atol, rtol = rtol) + test_pullbacks_match(lq_full!, lq_full, A, LQ, ΔLQ) + end + @testset "lq_compact - rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(Ard) + dLQ = make_mooncake_tangent(ΔLQ) + Mooncake.TestUtils.test_rule(rng, lq_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol = atol, rtol = rtol) + test_pullbacks_match(lq_compact!, lq_compact, Ard, LQ, ΔLQ) + end + end +end + +function test_mooncake_eig( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "EIG Mooncake AD rules $summary_str" begin + A = make_eig_matrix(T, sz) + m = size(A, 1) + @testset "eig_full" begin + DV, ΔDV, ΔD2V = ad_eig_full_setup(A) + dDV = make_mooncake_tangent(ΔD2V) + Mooncake.TestUtils.test_rule(rng, eig_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dDV, atol = atol, rtol = rtol) + test_pullbacks_match(eig_full!, eig_full, A, DV, ΔD2V) + end + @testset "eig_vals" begin + D, ΔD = ad_eig_vals_setup(A) + dD = make_mooncake_tangent(ΔD) + Mooncake.TestUtils.test_rule(rng, eig_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dD, atol = atol, rtol = rtol) + test_pullbacks_match(eig_vals!, eig_vals, A, D, ΔD) + end + @testset "eig_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(r; by = abs)) + DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + ϵ = zero(real(eltype(T))) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol) + test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T))))) + dDVtrunc = make_mooncake_tangent(ΔDVtrunc) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol) + test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg) + end + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real)) + DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + ϵ = zero(real(eltype(T))) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol) + 
test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T))))) + dDVtrunc = make_mooncake_tangent(ΔDVtrunc) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol) + test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg) + end + end +end + +function test_mooncake_eigh( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "EIGH Mooncake AD rules $summary_str" begin + A = make_eigh_matrix(T, sz) + m = size(A, 1) + @testset "eigh_full" begin + DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) + dDV = make_mooncake_tangent(ΔD2V) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_full, A; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) + test_pullbacks_match(mc_copy_eigh_full!, mc_copy_eigh_full, A, DV, ΔD2V) + end + @testset "eigh_vals" begin + D, ΔD = ad_eigh_vals_setup(A) + dD = make_mooncake_tangent(ΔD) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_vals, A; mode = Mooncake.ReverseMode, output_tangent = dD, is_primitive = false, atol = atol, rtol = rtol) + test_pullbacks_match(mc_copy_eigh_vals!, mc_copy_eigh_vals, A, D, ΔD) + end + @testset "eigh_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), truncrank(r; by = abs)) + DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + ϵ = zero(real(eltype(T))) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T))))) + dDVtrunc = 
make_mooncake_tangent(ΔDVtrunc) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg) + end + D = eigh_vals(A / 2) + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), trunctol(; atol = maximum(abs, D) / 2)) + DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + ϵ = zero(real(eltype(T))) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T))))) + dDVtrunc = make_mooncake_tangent(ΔDVtrunc) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg) + end + end +end + +function test_mooncake_svd( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "SVD Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + minmn = min(size(A)...) 
+ @testset "svd_compact" begin + USVᴴ, _, ΔUSVᴴ = ad_svd_compact_setup(A) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) + Mooncake.TestUtils.test_rule(rng, svd_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(svd_compact!, svd_compact, A, USVᴴ, ΔUSVᴴ) + end + @testset "svd_full" begin + USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) + Mooncake.TestUtils.test_rule(rng, svd_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(svd_full!, svd_full, A, USVᴴ, ΔUSVᴴ) + end + @testset "svd_vals" begin + S, ΔS = ad_svd_vals_setup(A) + Mooncake.TestUtils.test_rule(rng, svd_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(svd_vals!, svd_vals, A, S, ΔS) + end + @testset "svd_trunc" begin + S, ΔS = ad_svd_vals_setup(A) + @testset for r in 1:4:minmn + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), truncrank(r)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + ϵ = zero(real(eltype(T))) + dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) + test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T))))) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg) + end + @testset "trunctol" begin + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), trunctol(atol = S[1, 1] / 2)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = 
ad_svd_trunc_setup(A, truncalg) + ϵ = zero(real(eltype(T))) + dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) + test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T))))) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg) + end + end + end +end + +function test_mooncake_polar( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Polar Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + @testset "left_polar" begin + if m >= n + WP, ΔWP = ad_left_polar_setup(A) + Mooncake.TestUtils.test_rule(rng, left_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(left_polar!, left_polar, A, WP, ΔWP) + end + end + @testset "right_polar" begin + if m <= n + PWᴴ, ΔPWᴴ = ad_right_polar_setup(A) + Mooncake.TestUtils.test_rule(rng, right_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + test_pullbacks_match(right_polar!, right_polar, A, PWᴴ, ΔPWᴴ) + end + end + end +end + +left_orth_qr(X) = left_orth(X; alg = :qr) +left_orth_polar(X) = left_orth(X; alg = :polar) +left_null_qr(X) = left_null(X; alg = :qr) +right_orth_lq(X) = right_orth(X; alg = :lq) +right_orth_polar(X) = right_orth(X; alg = :polar) +right_null_lq(X) = right_null(X; alg = :lq) + +MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A) +MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), 
A) = MatrixAlgebraKit.copy_input(left_orth, A) +MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A) +MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A) +MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A) +MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A) + +function test_mooncake_orthnull( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Orthnull Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + VC, ΔVC = ad_left_orth_setup(A) + CVᴴ, ΔCVᴴ = ad_right_orth_setup(A) + Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(left_orth!, left_orth, A, VC, ΔVC) + Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(right_orth!, right_orth, A, CVᴴ, ΔCVᴴ) + + Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, ΔVC) + if m >= n + Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, ΔVC) + end + + N, ΔN = ad_left_null_setup(A) + dN = make_mooncake_tangent(ΔN) + Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dN) + test_pullbacks_match(((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN) + + Mooncake.TestUtils.test_rule(rng, 
right_orth_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, ΔCVᴴ) + + if m <= n + Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, ΔCVᴴ) + end + + Nᴴ, ΔNᴴ = ad_right_null_setup(A) + dNᴴ = make_mooncake_tangent(ΔNᴴ) + Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dNᴴ) + test_pullbacks_match(((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) + end +end From c5db5b5accd4e145ba9f7d71dde8109edff33698 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 23 Jan 2026 12:05:52 +0100 Subject: [PATCH 02/18] Get rid of custom GPU tolerances --- ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index ccc03a56..fb67149e 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -3,7 +3,7 @@ module MatrixAlgebraKitCUDAExt using MatrixAlgebraKit using MatrixAlgebraKit: @algdef, Algorithm, check_input using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! 
-using MatrixAlgebraKit: diagview, sign_safe, default_pullback_gauge_atol, default_pullback_rank_atol +using MatrixAlgebraKit: diagview, sign_safe using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev! @@ -195,13 +195,6 @@ end MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) -MatrixAlgebraKit.default_pullback_rank_atol(A::AnyCuArray) = eps(norm(CuArray(A), Inf))^(3 / 4) -MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray) = MatrixAlgebraKit.iszerotangent(A) ? 0 : eps(norm(CuArray(A), Inf))^(3 / 4) -function MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray, As...) - As′ = filter(!MatrixAlgebraKit.iszerotangent, (A, As...)) - return isempty(As′) ? 0 : eps(norm(CuArray.(As′), Inf))^(3 / 4) -end - function _sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix) # https://github.com/JuliaGPU/CUDA.jl/issues/3021 # to add native sylvester to CUDA From c12f5581ddf6b969dc5b673197987b2c8701ea9e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 23 Jan 2026 14:52:15 -0500 Subject: [PATCH 03/18] Fixes for enzyme --- src/pullbacks/polar.jl | 2 +- test/enzyme.jl | 4 +-- test/runtests.jl | 8 ++--- test/testsuite/TestSuite.jl | 1 + test/testsuite/enzyme.jl | 64 +++++++++++++++++++------------------ 5 files changed, 41 insertions(+), 38 deletions(-) diff --git a/src/pullbacks/polar.jl b/src/pullbacks/polar.jl index 8ada8575..4d498da0 100644 --- a/src/pullbacks/polar.jl +++ b/src/pullbacks/polar.jl @@ -46,7 +46,7 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs... 
M = zero(P) !iszerotangent(ΔWᴴ) && mul!(M, ΔWᴴ, Wᴴ', 1, 1) !iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1) - C = sylvester(P, P, M' - M) + C = _sylvester(P, P, M' - M) C .+= ΔP ΔA = mul!(ΔA, C, Wᴴ, 1, 1) if !iszerotangent(ΔWᴴ) diff --git a/test/enzyme.jl b/test/enzyme.jl index 28ff7454..bfef3577 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (ComplexF64,) # full suite is too expensive on CI +BLASFloats = (Float64,) # full suite is too expensive on CI GenericFloats = (BigFloat,) @isdefined(TestSuite) || include("testsuite/TestSuite.jl") using .TestSuite @@ -13,7 +13,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(123) - if T <: BLASFloats + if T ∈ BLASFloats if CUDA.functional() TestSuite.test_enzyme(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) #n == m && TestSuite.test_enzyme(Diagonal{T, CuVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) diff --git a/test/runtests.jl b/test/runtests.jl index 28f220f6..7190882b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,4 @@ -using ParallelTestRunner +#=using ParallelTestRunner using MatrixAlgebraKit # Start with autodiscovered tests @@ -22,8 +22,6 @@ if filter_tests!(testsuite, args) delete!(testsuite, "algorithms") delete!(testsuite, "truncate") delete!(testsuite, "gen_eig") - delete!(testsuite, "mooncake") - delete!(testsuite, "enzyme") delete!(testsuite, "chainrules") delete!(testsuite, "codequality") else @@ -32,4 +30,6 @@ if filter_tests!(testsuite, args) end end -runtests(MatrixAlgebraKit, args; testsuite) +runtests(MatrixAlgebraKit, args; testsuite)=# +include("enzyme.jl") +include("mooncake.jl") diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index 28833557..12653096 100644 --- a/test/testsuite/TestSuite.jl 
+++ b/test/testsuite/TestSuite.jl @@ -96,6 +96,7 @@ include("eigh.jl") include("orthnull.jl") include("svd.jl") include("mooncake.jl") +include("enzyme.jl") include("chainrules.jl") end diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl index d10bc0fc..dde02841 100644 --- a/test/testsuite/enzyme.jl +++ b/test/testsuite/enzyme.jl @@ -46,7 +46,7 @@ function enz_copy_eigh_trunc_no_error!(A, DV, alg) end function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; ȳ = copy.(Δargs), return_act = Duplicated) - ΔA = randn(rng, eltype(A), size(A)...) + ΔA = randn!(similar(A)) A_ΔA() = Duplicated(copy(A), copy(ΔA)) function args_Δargs() if isnothing(args) @@ -143,8 +143,8 @@ function test_enzyme_qr( r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard) - eltype(T) <: BlasFloat && test_reverse(qr_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = (ΔQ, ΔR), fdm) - test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, (Q, R), (ΔQ, ΔR), alg) + eltype(T) <: BlasFloat && test_reverse(qr_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) + test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, QR, ΔQR, alg) end end end @@ -163,8 +163,8 @@ function test_enzyme_lq( @testset "lq_compact" begin @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) LQ, ΔLQ = ad_lq_compact_setup(A) - eltype(T) <: BlasFloat && test_reverse(lq_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = (ΔL, ΔQ), fdm) - test_pullbacks_match(rng, lq_compact!, lq_compact, A, (L, Q), (ΔL, ΔQ), alg) + eltype(T) <: BlasFloat && test_reverse(lq_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) + test_pullbacks_match(rng, lq_compact!, lq_compact, A, LQ, ΔLQ, alg) end end @testset "lq_null" begin @@ -177,8 +177,8 @@ function test_enzyme_lq( @testset "lq_full" begin @testset "reverse: RT $RT, TA $TA" for RT in 
(Duplicated,), TA in (Duplicated,) LQ, ΔLQ = ad_lq_full_setup(A) - eltype(T) <: BlasFloat && test_reverse(lq_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = (ΔL, ΔQ), fdm) - test_pullbacks_match(rng, lq_full!, lq_full, A, (L, Q), (ΔL, ΔQ), alg) + eltype(T) <: BlasFloat && test_reverse(lq_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) + test_pullbacks_match(rng, lq_full!, lq_full, A, LQ, ΔLQ, alg) end end @testset "lq_compact -- rank-deficient A" begin @@ -187,8 +187,8 @@ function test_enzyme_lq( r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(Ard) - eltype(T) <: BlasFloat && test_reverse(lq_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = (ΔL, ΔQ), fdm) - test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, (L, Q), (ΔL, ΔQ), alg) + eltype(T) <: BlasFloat && test_reverse(lq_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) + test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, LQ, ΔLQ, alg) end end end @@ -209,8 +209,8 @@ function test_enzyme_eig( @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) DV, ΔDV, ΔD2V = ad_eig_full_setup(A) if eltype(T) <: BlasFloat - test_reverse(eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm) - test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg) + test_reverse(eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔD2V, fdm) + test_pullbacks_match(rng, eig_full!, eig_full, A, DV, ΔD2V, alg) else test_pullbacks_match(rng, eig_full!, eig_full, A, (nothing, nothing), (nothing, nothing), alg; ȳ = (ΔD2, ΔV)) end @@ -221,9 +221,9 @@ function test_enzyme_eig( D, ΔD = ad_eig_vals_setup(A) if eltype(T) <: BlasFloat test_reverse(eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = copy(ΔD2.diag), fdm) - test_pullbacks_match(rng, 
eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg) + test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD.diag, alg) else - test_pullbacks_match(rng, eig_vals!, eig_vals, A, nothing, nothing, alg; ȳ = ΔD2.diag) + test_pullbacks_match(rng, eig_vals!, eig_vals, A, nothing, nothing, alg; ȳ = ΔD.diag) end end end @@ -233,19 +233,19 @@ function test_enzyme_eig( truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(r; by = abs)) DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) if eltype(T) <: BlasFloat - test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm) - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc)) + test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) + test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg, ȳ = ΔDVtrunc) else - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = (ΔDtrunc, ΔVtrunc)) + test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = ΔDVtrunc) end end truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real)) DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) if eltype(T) <: BlasFloat - test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm) - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc)) + test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) + test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg, ȳ = ΔDVtrunc) else - test_pullbacks_match(rng, eig_trunc_no_error!, 
eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = (ΔDtrunc, ΔVtrunc)) + test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = ΔDVtrunc) end end end @@ -265,17 +265,19 @@ function test_enzyme_eigh( fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) @testset "eigh_full" begin @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) if eltype(T) <: BlasFloat - test_reverse(copy_eigh_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm) - test_reverse(copy_eigh_full!, RT, (copy(A), TA), ((D, V), TA), (alg, Const); atol, rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm) + test_reverse(copy_eigh_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD2V, fdm) + test_reverse(copy_eigh_full!, RT, (A, TA), ((D, V), TA), (alg, Const); atol, rtol, output_tangent = ΔD2V, fdm) end - test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg) + test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, DV, ΔD2V, alg) end end @testset "eigh_vals" begin @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - eltype(T) <: BlasFloat && test_reverse(copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = copy(ΔD2.diag), fdm) - test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg) + D, ΔD = ad_eigh_vals_setup(A) + eltype(T) <: BlasFloat && test_reverse(copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔD, fdm) + test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D, ΔD, alg) end end @testset "eigh_trunc" begin @@ -284,14 +286,14 @@ function test_enzyme_eigh( Ddiag = diagview(D) truncalg = TruncatedAlgorithm(alg, 
truncrank(r; by = abs)) DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) - eltype(T) <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm) - test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT) + eltype(T) <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) + test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, DV, ΔD2V, truncalg, ȳ = ΔDVtrunc, return_act = RT) end D = eigh_vals(A / 2) truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, D) / 2)) DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) - eltype(T) <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm) - test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT) + eltype(T) <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) + test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, DV, ΔD2V, truncalg, ȳ = ΔDVtrunc, return_act = RT) end end end @@ -312,11 +314,11 @@ function test_enzyme_svd( @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) USVᴴ, _, ΔUSVᴴ = ad_svd_compact_setup(A) if eltype(T) <: BlasFloat - test_reverse(svd_compact, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = (ΔU, ΔS, ΔVᴴ), fdm) - test_pullbacks_match(rng, svd_compact!, svd_compact, A, (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ), alg) + test_reverse(svd_compact, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔUSVᴴ, fdm) + test_pullbacks_match(rng, svd_compact!, svd_compact, A, USVᴴ, ΔUSVᴴ, alg) 
else USVᴴ = MatrixAlgebraKit.initialize_output(svd_compact!, A, alg) - test_pullbacks_match(rng, svd_compact!, svd_compact, A, USVᴴ, (nothing, nothing, nothing), alg; ȳ = (ΔU, ΔS, ΔVᴴ)) + test_pullbacks_match(rng, svd_compact!, svd_compact, A, USVᴴ, (nothing, nothing, nothing), alg; ȳ = ΔUSVᴴ) end end end From e18caf6a23a1b85a1811a9f681b2cf78a5b26d36 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 26 Jan 2026 09:17:16 -0500 Subject: [PATCH 04/18] Fix some remaining typos and turn off GPU --- test/enzyme.jl | 4 +- test/testsuite/enzyme.jl | 81 +++++++++++++++++++++------------------- 2 files changed, 44 insertions(+), 41 deletions(-) diff --git a/test/enzyme.jl b/test/enzyme.jl index bfef3577..e7c5db37 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -13,7 +13,7 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(123) - if T ∈ BLASFloats + #=if T ∈ BLASFloats if CUDA.functional() TestSuite.test_enzyme(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) #n == m && TestSuite.test_enzyme(Diagonal{T, CuVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) @@ -22,7 +22,7 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.test_enzyme(ROCMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) #TestSuite.test_enzyme(Diagonal{T, ROCVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) end - end + end=# if !is_buildkite TestSuite.test_enzyme(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) #n == m && TestSuite.test_enzyme(Diagonal{T, Vector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl index dde02841..6f2d138f 100644 --- a/test/testsuite/enzyme.jl +++ 
b/test/testsuite/enzyme.jl @@ -45,7 +45,8 @@ function enz_copy_eigh_trunc_no_error!(A, DV, alg) return eigh_trunc_no_error!(A, DV, alg) end -function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; ȳ = copy.(Δargs), return_act = Duplicated) +# necessary due to name conflict with Mooncake +function enz_test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; ȳ = copy.(Δargs), return_act = Duplicated) ΔA = randn!(similar(A)) A_ΔA() = Duplicated(copy(A), copy(ΔA)) function args_Δargs() @@ -106,6 +107,8 @@ function test_enzyme(T::Type, sz; kwargs...) end end +is_cpu(A) = typeof(parent(A)) <: Array + function test_enzyme_qr( T::Type, sz; atol::Real = 0, rtol::Real = precision(T), @@ -120,21 +123,21 @@ function test_enzyme_qr( @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) QR, ΔQR = ad_qr_compact_setup(A) eltype(T) <: BlasFloat && test_reverse(qr_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) - test_pullbacks_match(rng, qr_compact!, qr_compact, A, QR, ΔQR, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, qr_compact!, qr_compact, A, QR, ΔQR, alg) end end @testset "qr_null" begin @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) N, ΔN = ad_qr_null_setup(A) eltype(T) <: BlasFloat && test_reverse(qr_null, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔN) - test_pullbacks_match(rng, qr_null!, qr_null, A, N, ΔN, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, qr_null!, qr_null, A, N, ΔN, alg) end end @testset "qr_full" begin @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) QR, ΔQR = ad_qr_full_setup(A) - eltype(T) <: BlasFloat && test_reverse(qr_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = (ΔQ, ΔR), fdm) - test_pullbacks_match(rng, qr_full!, qr_full, A, (Q, R), (ΔQ, ΔR), alg) + eltype(T) <: BlasFloat && test_reverse(qr_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) + is_cpu(A) && 
enz_test_pullbacks_match(rng, qr_full!, qr_full, A, QR, ΔQR, alg) end end @testset "qr_compact - rank-deficient A" begin @@ -144,7 +147,7 @@ function test_enzyme_qr( Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard) eltype(T) <: BlasFloat && test_reverse(qr_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) - test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, QR, ΔQR, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, QR, ΔQR, alg) end end end @@ -164,21 +167,21 @@ function test_enzyme_lq( @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) LQ, ΔLQ = ad_lq_compact_setup(A) eltype(T) <: BlasFloat && test_reverse(lq_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) - test_pullbacks_match(rng, lq_compact!, lq_compact, A, LQ, ΔLQ, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, lq_compact!, lq_compact, A, LQ, ΔLQ, alg) end end @testset "lq_null" begin @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) Nᴴ, ΔNᴴ = ad_lq_null_setup(A) eltype(T) <: BlasFloat && test_reverse(lq_null, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔNᴴ) - test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ, alg) end end @testset "lq_full" begin @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) LQ, ΔLQ = ad_lq_full_setup(A) eltype(T) <: BlasFloat && test_reverse(lq_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) - test_pullbacks_match(rng, lq_full!, lq_full, A, LQ, ΔLQ, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, lq_full!, lq_full, A, LQ, ΔLQ, alg) end end @testset "lq_compact -- rank-deficient A" begin @@ -188,7 +191,7 @@ function test_enzyme_lq( Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) LQ, ΔLQ = 
ad_lq_rank_deficient_compact_setup(Ard) eltype(T) <: BlasFloat && test_reverse(lq_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) - test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, LQ, ΔLQ, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, LQ, ΔLQ, alg) end end end @@ -210,9 +213,9 @@ function test_enzyme_eig( DV, ΔDV, ΔD2V = ad_eig_full_setup(A) if eltype(T) <: BlasFloat test_reverse(eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔD2V, fdm) - test_pullbacks_match(rng, eig_full!, eig_full, A, DV, ΔD2V, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, eig_full!, eig_full, A, DV, ΔD2V, alg) else - test_pullbacks_match(rng, eig_full!, eig_full, A, (nothing, nothing), (nothing, nothing), alg; ȳ = (ΔD2, ΔV)) + is_cpu(A) && enz_test_pullbacks_match(rng, eig_full!, eig_full, A, (nothing, nothing), (nothing, nothing), alg; ȳ = (ΔD2, ΔV)) end end end @@ -220,10 +223,10 @@ function test_enzyme_eig( @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) D, ΔD = ad_eig_vals_setup(A) if eltype(T) <: BlasFloat - test_reverse(eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = copy(ΔD2.diag), fdm) - test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD.diag, alg) + test_reverse(eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔD, fdm) + is_cpu(A) && enz_test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD, alg) else - test_pullbacks_match(rng, eig_vals!, eig_vals, A, nothing, nothing, alg; ȳ = ΔD.diag) + is_cpu(A) && enz_test_pullbacks_match(rng, eig_vals!, eig_vals, A, nothing, nothing, alg; ȳ = ΔD) end end end @@ -234,18 +237,18 @@ function test_enzyme_eig( DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) if eltype(T) <: BlasFloat test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) - test_pullbacks_match(rng, eig_trunc_no_error!, 
eig_trunc_no_error, A, DV, ΔDV, truncalg, ȳ = ΔDVtrunc) + is_cpu(A) && enz_test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg, ȳ = ΔDVtrunc) else - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = ΔDVtrunc) + is_cpu(A) && enz_test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = ΔDVtrunc) end end truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real)) DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) if eltype(T) <: BlasFloat test_reverse(eig_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg, ȳ = ΔDVtrunc) + is_cpu(A) && enz_test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg, ȳ = ΔDVtrunc) else - test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = ΔDVtrunc) + is_cpu(A) && enz_test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (nothing, nothing), (nothing, nothing), truncalg, ȳ = ΔDVtrunc) end end end @@ -270,14 +273,14 @@ function test_enzyme_eigh( test_reverse(copy_eigh_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD2V, fdm) test_reverse(copy_eigh_full!, RT, (A, TA), ((D, V), TA), (alg, Const); atol, rtol, output_tangent = ΔD2V, fdm) end - test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, DV, ΔD2V, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, DV, ΔD2V, alg) end end @testset "eigh_vals" begin @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) D, ΔD = ad_eigh_vals_setup(A) eltype(T) <: BlasFloat && test_reverse(copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent 
= ΔD, fdm) - test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D, ΔD, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D, ΔD, alg) end end @testset "eigh_trunc" begin @@ -287,13 +290,13 @@ function test_enzyme_eigh( truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) eltype(T) <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) - test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, DV, ΔD2V, truncalg, ȳ = ΔDVtrunc, return_act = RT) + is_cpu(A) && enz_test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, DV, ΔD2V, truncalg, ȳ = ΔDVtrunc, return_act = RT) end D = eigh_vals(A / 2) truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, D) / 2)) DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) eltype(T) <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) - test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, DV, ΔD2V, truncalg, ȳ = ΔDVtrunc, return_act = RT) + is_cpu(A) && enz_test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, DV, ΔD2V, truncalg, ȳ = ΔDVtrunc, return_act = RT) end end end @@ -315,10 +318,10 @@ function test_enzyme_svd( USVᴴ, _, ΔUSVᴴ = ad_svd_compact_setup(A) if eltype(T) <: BlasFloat test_reverse(svd_compact, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔUSVᴴ, fdm) - test_pullbacks_match(rng, svd_compact!, svd_compact, A, USVᴴ, ΔUSVᴴ, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, svd_compact!, svd_compact, A, USVᴴ, ΔUSVᴴ, alg) else USVᴴ = MatrixAlgebraKit.initialize_output(svd_compact!, A, alg) - test_pullbacks_match(rng, svd_compact!, svd_compact, A, USVᴴ, (nothing, nothing, nothing), alg; ȳ = ΔUSVᴴ) + is_cpu(A) && 
enz_test_pullbacks_match(rng, svd_compact!, svd_compact, A, USVᴴ, (nothing, nothing, nothing), alg; ȳ = ΔUSVᴴ) end end end @@ -327,10 +330,10 @@ function test_enzyme_svd( USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) if eltype(T) <: BlasFloat test_reverse(svd_full, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔUSVᴴ, fdm) - test_pullbacks_match(rng, svd_full!, svd_full, A, USVᴴ, ΔUSVᴴ, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, svd_full!, svd_full, A, USVᴴ, ΔUSVᴴ, alg) else USVᴴ = MatrixAlgebraKit.initialize_output(svd_full!, A, alg) - test_pullbacks_match(rng, svd_full!, svd_full, A, USVᴴ, (nothing, nothing, nothing), alg; ȳ = ΔUSVᴴ) + is_cpu(A) && enz_test_pullbacks_match(rng, svd_full!, svd_full, A, USVᴴ, (nothing, nothing, nothing), alg; ȳ = ΔUSVᴴ) end end end @@ -339,10 +342,10 @@ function test_enzyme_svd( S, ΔS = ad_svd_vals_setup(A) if eltype(T) <: BlasFloat test_reverse(svd_vals, RT, (A, TA); atol, rtol, fkwargs = (alg = alg,), output_tangent = ΔS, fdm) - test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, ΔS, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, ΔS, alg) else S = MatrixAlgebraKit.initialize_output(svd_vals!, A, alg) - test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, nothing, alg; ȳ = ΔS) + is_cpu(A) && enz_test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, nothing, alg; ȳ = ΔS) end end end @@ -354,18 +357,18 @@ function test_enzyme_svd( USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) if eltype(T) <: BlasFloat test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) - test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg, ȳ = ΔUSVᴴtrunc) + is_cpu(A) && enz_test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg, ȳ = ΔUSVᴴtrunc) else - test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (nothing, nothing, nothing), (nothing, nothing, nothing), 
truncalg, ȳ = ΔUSVᴴtrunc) + is_cpu(A) && enz_test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (nothing, nothing, nothing), (nothing, nothing, nothing), truncalg, ȳ = ΔUSVᴴtrunc) end end truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), trunctol(atol = S[1, 1] / 2)) USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) if eltype(T) <: BlasFloat test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) - test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg, ȳ = ΔUSVᴴtrunc) + is_cpu(A) && enz_test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg, ȳ = ΔUSVᴴtrunc) else - test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (nothing, nothing, nothing), (nothing, nothing, nothing), truncalg, ȳ = ΔUSVᴴtrunc) + is_cpu(A) && enz_test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (nothing, nothing, nothing), (nothing, nothing, nothing), truncalg, ȳ = ΔUSVᴴtrunc) end end end @@ -389,7 +392,7 @@ function test_enzyme_polar( if m >= n WP, ΔWP = ad_left_polar_setup(A) eltype(T) <: BlasFloat && test_reverse(left_polar, RT, (A, TA), (alg, Const); atol, rtol) - test_pullbacks_match(rng, left_polar!, left_polar, A, WP, ΔWP, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, left_polar!, left_polar, A, WP, ΔWP, alg) end end end @@ -398,7 +401,7 @@ function test_enzyme_polar( if m <= n PWᴴ, ΔPWᴴ = ad_right_polar_setup(A) eltype(T) <: BlasFloat && test_reverse(right_polar, RT, (A, TA), (alg, Const); atol, rtol) - test_pullbacks_match(rng, right_polar!, right_polar, A, PWᴴ, ΔPWᴴ, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, right_polar!, right_polar, A, PWᴴ, ΔPWᴴ, alg) end end end @@ -424,7 +427,7 @@ function test_enzyme_orthnull( eltype(T) <: BlasFloat && test_reverse(left_orth, RT, (A, TA); atol, rtol, fkwargs = (alg = alg,), fdm) left_orth_alg!(A, VC) = left_orth!(A, VC; 
alg = alg) left_orth_alg(A) = left_orth(A; alg = alg) - test_pullbacks_match(rng, left_orth_alg!, left_orth_alg, A, VC, ΔVC) + is_cpu(A) && enz_test_pullbacks_match(rng, left_orth_alg!, left_orth_alg, A, VC, ΔVC) end end end @@ -434,7 +437,7 @@ function test_enzyme_orthnull( left_null_qr!(A, N) = left_null!(A, N; alg = :qr) left_null_qr(A) = left_null(A; alg = :qr) eltype(T) <: BlasFloat && test_reverse(left_null_qr, RT, (A, TA); output_tangent = ΔN, atol, rtol) - test_pullbacks_match(rng, left_null_qr!, left_null_qr, A, N, ΔN) + is_cpu(A) && enz_test_pullbacks_match(rng, left_null_qr!, left_null_qr, A, N, ΔN) end end @testset "right_orth" begin @@ -444,7 +447,7 @@ function test_enzyme_orthnull( eltype(T) <: BlasFloat && test_reverse(right_orth, RT, (A, TA); atol, rtol, fkwargs = (alg = alg,), fdm) right_orth_alg!(A, CVᴴ) = right_orth!(A, CVᴴ; alg = alg) right_orth_alg(A) = right_orth(A; alg = alg) - test_pullbacks_match(rng, right_orth_alg!, right_orth_alg, A, CVᴴ, ΔCVᴴ) + is_cpu(A) && enz_test_pullbacks_match(rng, right_orth_alg!, right_orth_alg, A, CVᴴ, ΔCVᴴ) end end end @@ -454,7 +457,7 @@ function test_enzyme_orthnull( right_null_lq!(A, Nᴴ) = right_null!(A, Nᴴ; alg = :lq) right_null_lq(A) = right_null(A; alg = :lq) eltype(T) <: BlasFloat && test_reverse(right_null_lq, RT, (A, TA); output_tangent = ΔNᴴ, atol, rtol) - test_pullbacks_match(rng, right_null_lq!, right_null_lq, A, Nᴴ, ΔNᴴ) + is_cpu(A) && enz_test_pullbacks_match(rng, right_null_lq!, right_null_lq, A, Nᴴ, ΔNᴴ) end end end From 628efdb0db1a9649539f394a53cb66f9cd1f7842 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 26 Jan 2026 09:35:51 -0500 Subject: [PATCH 05/18] Fix runtests --- test/runtests.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 7190882b..cadaacab 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,4 @@ -#=using ParallelTestRunner +using ParallelTestRunner using MatrixAlgebraKit # Start with 
autodiscovered tests @@ -30,6 +30,4 @@ if filter_tests!(testsuite, args) end end -runtests(MatrixAlgebraKit, args; testsuite)=# -include("enzyme.jl") -include("mooncake.jl") +runtests(MatrixAlgebraKit, args; testsuite) From 90f5ce1bf3c7e207e4fae49dd54c462ccb0d574c Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 26 Jan 2026 10:17:13 -0500 Subject: [PATCH 06/18] Turn off GPU for Mooncake too --- test/mooncake.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/mooncake.jl b/test/mooncake.jl index ea5bbf65..d2f54ece 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -14,11 +14,11 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) TestSuite.seed_rng!(123) - if CUDA.functional() + #=if CUDA.functional() TestSuite.test_mooncake(CuMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) #n == m && TestSuite.test_mooncake(Diagonal{T, CuVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) end - #=if AMDGPU.functional() + if AMDGPU.functional() TestSuite.test_mooncake(ROCMatrix{T}, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) TestSuite.test_mooncake(Diagonal{T, ROCVector{T}}, m; atol = m * TestSuite.precision(T), rtol = m * TestSuite.precision(T)) end=# # not yet supported From f5be465c117e6ad3862f936322e58951adaa0d26 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 26 Jan 2026 17:18:14 +0100 Subject: [PATCH 07/18] Remove deletion of nonexistent file --- test/runtests.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index cadaacab..d517f5be 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,7 +9,6 @@ filter!(!(startswith("testsuite") ∘ first), testsuite) # remove utils delete!(testsuite, "utilities") -delete!(testsuite, "ad_utils") delete!(testsuite, "linearmap") # Parse arguments From 
08e0cfbf76435ad79fbda389f8aa69de00cad941 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 26 Jan 2026 20:51:15 +0100 Subject: [PATCH 08/18] More Enzyme fix --- test/testsuite/enzyme.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl index 6f2d138f..77b6b056 100644 --- a/test/testsuite/enzyme.jl +++ b/test/testsuite/enzyme.jl @@ -285,6 +285,7 @@ function test_enzyme_eigh( end @testset "eigh_trunc" begin @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + D = eigh_vals(A / 2) for r in 1:4:m Ddiag = diagview(D) truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) @@ -292,7 +293,6 @@ function test_enzyme_eigh( eltype(T) <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) is_cpu(A) && enz_test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, DV, ΔD2V, truncalg, ȳ = ΔDVtrunc, return_act = RT) end - D = eigh_vals(A / 2) truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, D) / 2)) DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) eltype(T) <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) From a84f8c76839a79898cc68982450b5ff104518d5d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 27 Jan 2026 07:20:28 +0100 Subject: [PATCH 09/18] More bad variables --- test/testsuite/enzyme.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl index 77b6b056..7f39c249 100644 --- a/test/testsuite/enzyme.jl +++ b/test/testsuite/enzyme.jl @@ -287,7 +287,6 @@ function test_enzyme_eigh( @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) D = eigh_vals(A / 2) for r in 1:4:m - Ddiag = diagview(D) truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) DV, _, ΔDV, ΔDVtrunc = 
ad_eigh_trunc_setup(A, truncalg) eltype(T) <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) From 1fcc8b2b64aa11894d893c7b1db222441ae28b77 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 27 Jan 2026 08:06:12 +0100 Subject: [PATCH 10/18] More missing stuff --- test/testsuite/enzyme.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl index 7f39c249..267f8b04 100644 --- a/test/testsuite/enzyme.jl +++ b/test/testsuite/enzyme.jl @@ -270,17 +270,17 @@ function test_enzyme_eigh( @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) if eltype(T) <: BlasFloat - test_reverse(copy_eigh_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD2V, fdm) - test_reverse(copy_eigh_full!, RT, (A, TA), ((D, V), TA), (alg, Const); atol, rtol, output_tangent = ΔD2V, fdm) + test_reverse(enz_copy_eigh_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD2V, fdm) + test_reverse(enz_copy_eigh_full!, RT, (A, TA), ((D, V), TA), (alg, Const); atol, rtol, output_tangent = ΔD2V, fdm) end - is_cpu(A) && enz_test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, DV, ΔD2V, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, enz_copy_eigh_full!, copy_eigh_full, A, DV, ΔD2V, alg) end end @testset "eigh_vals" begin @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) D, ΔD = ad_eigh_vals_setup(A) - eltype(T) <: BlasFloat && test_reverse(copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔD, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D, ΔD, alg) + eltype(T) <: BlasFloat && test_reverse(enz_copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔD, fdm) + is_cpu(A) && enz_test_pullbacks_match(rng, 
enz_copy_eigh_vals!, copy_eigh_vals, A, D, ΔD, alg) end end @testset "eigh_trunc" begin @@ -289,13 +289,13 @@ function test_enzyme_eigh( for r in 1:4:m truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) - eltype(T) <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, DV, ΔD2V, truncalg, ȳ = ΔDVtrunc, return_act = RT) + eltype(T) <: BlasFloat && test_reverse(enz_copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) + is_cpu(A) && enz_test_pullbacks_match(rng, enz_copy_eigh_trunc_no_error!, enz_copy_eigh_trunc_no_error, A, DV, ΔD2V, truncalg, ȳ = ΔDVtrunc, return_act = RT) end truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, D) / 2)) DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) - eltype(T) <: BlasFloat && test_reverse(copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, DV, ΔD2V, truncalg, ȳ = ΔDVtrunc, return_act = RT) + eltype(T) <: BlasFloat && test_reverse(enz_copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) + is_cpu(A) && enz_test_pullbacks_match(rng, enz_copy_eigh_trunc_no_error!, enz_copy_eigh_trunc_no_error, A, DV, ΔD2V, truncalg, ȳ = ΔDVtrunc, return_act = RT) end end end From e1bc26624017a75523e7efeeeb8d1ab33032189b Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 27 Jan 2026 08:48:42 +0100 Subject: [PATCH 11/18] Yet another stupid typo --- test/testsuite/enzyme.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl index 267f8b04..899a8280 100644 --- 
a/test/testsuite/enzyme.jl +++ b/test/testsuite/enzyme.jl @@ -290,12 +290,12 @@ function test_enzyme_eigh( truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) eltype(T) <: BlasFloat && test_reverse(enz_copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, enz_copy_eigh_trunc_no_error!, enz_copy_eigh_trunc_no_error, A, DV, ΔD2V, truncalg, ȳ = ΔDVtrunc, return_act = RT) + is_cpu(A) && enz_test_pullbacks_match(rng, enz_copy_eigh_trunc_no_error!, enz_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg, ȳ = ΔDVtrunc, return_act = RT) end truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, D) / 2)) DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) eltype(T) <: BlasFloat && test_reverse(enz_copy_eigh_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔDVtrunc, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, enz_copy_eigh_trunc_no_error!, enz_copy_eigh_trunc_no_error, A, DV, ΔD2V, truncalg, ȳ = ΔDVtrunc, return_act = RT) + is_cpu(A) && enz_test_pullbacks_match(rng, enz_copy_eigh_trunc_no_error!, enz_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg, ȳ = ΔDVtrunc, return_act = RT) end end end From 864b8b7f5da4736f0e9f9c7fb13bc97155a656bd Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 27 Jan 2026 09:36:59 +0100 Subject: [PATCH 12/18] ANOTHER ONE --- test/testsuite/enzyme.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl index 899a8280..235623fe 100644 --- a/test/testsuite/enzyme.jl +++ b/test/testsuite/enzyme.jl @@ -271,7 +271,7 @@ function test_enzyme_eigh( DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) if eltype(T) <: BlasFloat test_reverse(enz_copy_eigh_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD2V, fdm) - test_reverse(enz_copy_eigh_full!, RT, (A, TA), ((D, V), TA), (alg, Const); atol, 
rtol, output_tangent = ΔD2V, fdm) + test_reverse(enz_copy_eigh_full!, RT, (A, TA), (DV, TA), (alg, Const); atol, rtol, output_tangent = ΔD2V, fdm) end is_cpu(A) && enz_test_pullbacks_match(rng, enz_copy_eigh_full!, copy_eigh_full, A, DV, ΔD2V, alg) end From e4503b2aaa9ce15efd0ff7f60e75efc363d0d773 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 27 Jan 2026 11:15:11 +0100 Subject: [PATCH 13/18] Fix other typos --- test/testsuite/enzyme.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl index 235623fe..14723680 100644 --- a/test/testsuite/enzyme.jl +++ b/test/testsuite/enzyme.jl @@ -97,7 +97,8 @@ function test_enzyme(T::Type, sz; kwargs...) test_enzyme_lq(T, sz; kwargs...) if length(sz) == 1 || sz[1] == sz[2] test_enzyme_eig(T, sz; kwargs...) - test_enzyme_eigh(T, sz; kwargs...) + # missing Enzyme rule + eltype(T) <: BlasFloat && test_enzyme_eigh(T, sz; kwargs...) end test_enzyme_svd(T, sz; kwargs...) 
if eltype(T) <: BlasFloat @@ -215,7 +216,7 @@ function test_enzyme_eig( test_reverse(eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔD2V, fdm) is_cpu(A) && enz_test_pullbacks_match(rng, eig_full!, eig_full, A, DV, ΔD2V, alg) else - is_cpu(A) && enz_test_pullbacks_match(rng, eig_full!, eig_full, A, (nothing, nothing), (nothing, nothing), alg; ȳ = (ΔD2, ΔV)) + is_cpu(A) && enz_test_pullbacks_match(rng, eig_full!, eig_full, A, (nothing, nothing), (nothing, nothing), alg; ȳ = ΔD2V) end end end @@ -224,7 +225,7 @@ function test_enzyme_eig( D, ΔD = ad_eig_vals_setup(A) if eltype(T) <: BlasFloat test_reverse(eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔD, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, eig_vals!, eig_vals, A, D, ΔD, alg) else is_cpu(A) && enz_test_pullbacks_match(rng, eig_vals!, eig_vals, A, nothing, nothing, alg; ȳ = ΔD) end @@ -273,14 +274,14 @@ function test_enzyme_eigh( test_reverse(enz_copy_eigh_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔD2V, fdm) test_reverse(enz_copy_eigh_full!, RT, (A, TA), (DV, TA), (alg, Const); atol, rtol, output_tangent = ΔD2V, fdm) end - is_cpu(A) && enz_test_pullbacks_match(rng, enz_copy_eigh_full!, copy_eigh_full, A, DV, ΔD2V, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, enz_copy_eigh_full!, enz_copy_eigh_full, A, DV, ΔD2V, alg) end end @testset "eigh_vals" begin @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) D, ΔD = ad_eigh_vals_setup(A) eltype(T) <: BlasFloat && test_reverse(enz_copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol, rtol, output_tangent = ΔD, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, enz_copy_eigh_vals!, copy_eigh_vals, A, D, ΔD, alg) + is_cpu(A) && enz_test_pullbacks_match(rng, enz_copy_eigh_vals!, enz_copy_eigh_vals, A, D, ΔD, alg) end end @testset "eigh_trunc" begin From 
e0705e1572a5cd6f25a80445bed21ecaa820dfc2 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 27 Jan 2026 12:17:10 +0100 Subject: [PATCH 14/18] Turn off BigFloats for now --- test/enzyme.jl | 4 ++-- test/runtests.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/enzyme.jl b/test/enzyme.jl index e7c5db37..217be8c1 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -3,8 +3,8 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float64,) # full suite is too expensive on CI -GenericFloats = (BigFloat,) +BLASFloats = (Float64,ComplexF64) # full suite is too expensive on CI +GenericFloats = () #(BigFloat,) @isdefined(TestSuite) || include("testsuite/TestSuite.jl") using .TestSuite diff --git a/test/runtests.jl b/test/runtests.jl index d517f5be..6833e6ba 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,7 +25,7 @@ if filter_tests!(testsuite, args) delete!(testsuite, "codequality") else is_apple_ci = Sys.isapple() && get(ENV, "CI", "false") == "true" - (Sys.iswindows() || is_apple_ci) && delete!(testsuite, "enzyme") + is_apple_ci && delete!(testsuite, "enzyme") end end From 9d1e0ffe541fe09298236cce2e7d744a96981ceb Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 27 Jan 2026 12:24:14 +0100 Subject: [PATCH 15/18] Format --- test/enzyme.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/enzyme.jl b/test/enzyme.jl index 217be8c1..bfd83fc1 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float64,ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () #(BigFloat,) @isdefined(TestSuite) || include("testsuite/TestSuite.jl") using .TestSuite From a11f506106251a3ca0cd4cefcf9d14f2f353dff8 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 27 Jan 2026 13:34:02 +0100 Subject: [PATCH 16/18] No windows --- 
test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 6833e6ba..d517f5be 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,7 +25,7 @@ if filter_tests!(testsuite, args) delete!(testsuite, "codequality") else is_apple_ci = Sys.isapple() && get(ENV, "CI", "false") == "true" - is_apple_ci && delete!(testsuite, "enzyme") + (Sys.iswindows() || is_apple_ci) && delete!(testsuite, "enzyme") end end From 9f86477177706e58f8deaf7f8509074cce589c6d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 27 Jan 2026 10:28:09 -0500 Subject: [PATCH 17/18] Two types is too expensive --- test/enzyme.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/enzyme.jl b/test/enzyme.jl index bfd83fc1..9ffe1501 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI +BLASFloats = (ComplexF64,) # full suite is too expensive on CI GenericFloats = () #(BigFloat,) @isdefined(TestSuite) || include("testsuite/TestSuite.jl") using .TestSuite From 44764984622412ef14f1451ca9daa3231becb20c Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 28 Jan 2026 09:35:09 +0100 Subject: [PATCH 18/18] Don't run AD tests on macOS CI at all --- test/runtests.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index d517f5be..9b180a90 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,7 +25,12 @@ if filter_tests!(testsuite, args) delete!(testsuite, "codequality") else is_apple_ci = Sys.isapple() && get(ENV, "CI", "false") == "true" - (Sys.iswindows() || is_apple_ci) && delete!(testsuite, "enzyme") + if is_apple_ci + delete!(testsuite, "enzyme") + delete!(testsuite, "mooncake") + delete!(testsuite, "chainrules") + end + Sys.iswindows() && delete!(testsuite, "enzyme") end end