Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Enzyme = "0.13.118"
EnzymeTestUtils = "0.2.5"
JET = "0.9, 0.10"
LinearAlgebra = "1"
Mooncake = "0.4.183"
Mooncake = "0.4.195"
ParallelTestRunner = "2"
Random = "1"
SafeTestsets = "0.1"
Expand Down
9 changes: 8 additions & 1 deletion ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank
using AMDGPU
using LinearAlgebra
using LinearAlgebra: BlasFloat
Expand Down Expand Up @@ -171,4 +171,11 @@ end
MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) =
MatrixAlgebraKit._ind_intersect(collect(A), collect(B))

# Sylvester solve for ROCm matrices: the three operands are gathered into plain
# arrays with `collect`, handed to `sylvester`, and the solution is converted back
# into a `ROCArray`.
function _sylvester(A::AnyROCMatrix, B::AnyROCMatrix, C::AnyROCMatrix)
    hA, hB, hC = collect(A), collect(B), collect(C)
    return ROCArray(sylvester(hA, hB, hC))
end

# Index of the last entry of `S` that is at least `rank_atol`. `findlast` yields
# `nothing` when no entry qualifies (including an empty `S`); that case is mapped to
# `0` so the result is always an integer.
svd_rank(S::AnyROCVector, rank_atol) = something(findlast(s -> s ≥ rank_atol, S), 0)

end
11 changes: 10 additions & 1 deletion ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester, svd_rank
using CUDA, CUDA.CUBLAS
using CUDA: i32
using LinearAlgebra
Expand Down Expand Up @@ -195,4 +195,13 @@ end
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) =
MatrixAlgebraKit._ind_intersect(collect(A), collect(B))

# Sylvester solve for CUDA matrices: the three operands are gathered into plain
# arrays with `collect`, handed to `sylvester`, and the solution is converted back
# into a `CuArray`.
# Tracking issue for adding a native sylvester to CUDA:
# https://github.com/JuliaGPU/CUDA.jl/issues/3021
function _sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix)
    hA, hB, hC = collect(A), collect(B), collect(C)
    return CuArray(sylvester(hA, hB, hC))
end

# Index of the last entry of `S` that is at least `rank_atol`. `findlast` yields
# `nothing` when no entry qualifies (including an empty `S`); that case is mapped to
# `0` so the result is always an integer.
svd_rank(S::AnyCuVector, rank_atol) = something(findlast(s -> s ≥ rank_atol, S), 0)

end
21 changes: 21 additions & 0 deletions ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ for eig in (:eig, :eigh)
eig_t! = Symbol(eig, "_trunc!")
eig_t_pb = Symbol(eig, "_trunc_pullback")
_make_eig_t_pb = Symbol("_make_", eig_t_pb)
eig_t_ne! = Symbol(eig, "_trunc_no_error!")
eig_t_ne_pb = Symbol(eig, "_trunc_no_error_pullback")
_make_eig_t_ne_pb = Symbol("_make_", eig_t_ne_pb)
eig_v = Symbol(eig, "_vals")
eig_v! = Symbol(eig_v, "!")
eig_v_pb = Symbol(eig_v, "_pullback")
Expand Down Expand Up @@ -136,6 +139,24 @@ for eig in (:eig, :eigh)
end
return $eig_t_pb
end
# Reverse rule for `$eig_t_ne!` (`eig_trunc_no_error!` / `eigh_trunc_no_error!`) called
# with a `TruncatedAlgorithm`. The decomposition is computed with `alg.alg` on `Ac`
# (the result of `copy_input`), then truncated according to `alg.trunc`. The pullback
# is built from the untruncated `DV` together with the retained indices `ind`.
# NOTE(review): `$eig_f` / `$eig_f!` are bound outside this excerpt — presumably the
# full decomposition and its in-place variant; confirm.
function ChainRulesCore.rrule(::typeof($eig_t_ne!), A, DV, alg::TruncatedAlgorithm)
Ac = copy_input($eig_f, A)
# `DV` is rebound to the untruncated decomposition result
DV = $(eig_f!)(Ac, DV, alg.alg)
# truncation is dispatched on `$eig_t!`; `ind` records which entries were kept
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
return DV′, $(_make_eig_t_ne_pb)(A, DV, ind)
end
# Build the pullback closure for `$eig_t_ne!` from the primal input `A`, the
# untruncated decomposition `DV` and the retained indices `ind`. The returned closure
# yields one tangent per `rrule` argument, in the order (function, `A`, `DV`, `alg`).
function $(_make_eig_t_ne_pb)(A, DV, ind)
# General cotangent: accumulate into a freshly zeroed `ΔA` via `$eig_pb!`, after
# unthunking both components of the incoming cotangent.
function $eig_t_ne_pb(ΔDV)
ΔA = zero(A)
ΔD, ΔV = ΔDV
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.((ΔD, ΔV)), ind)
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
# Shortcut when both cotangent components are `ZeroTangent`: no `ΔA` is allocated.
function $eig_t_ne_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
end
return $eig_t_ne_pb
end
function ChainRulesCore.rrule(::typeof($eig_v!), A, D, alg)
DV = $eig_f(A, alg)
function $eig_v_pb(ΔD)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt
using Mooncake
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
using MatrixAlgebraKit
using MatrixAlgebraKit: inv_safe, diagview, copy_input
using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
Expand All @@ -18,14 +18,16 @@ Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.N
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any}
# Reverse rule for `copy_input`. The primal copy `Ac` is wrapped by
# `Mooncake.zero_fcodual`, and the tangent `dAc` of that wrapper is captured by the
# pullback, which adds it into the tangent of the original input `A`.
# All three reverse-data slots (function, `f`, `A`) are `NoRData`.
function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
    Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA))
    # a single wrapper supplies both the returned CoDual and the tangent seen by the
    # pullback, so no separate tangent needs to be allocated
    Ac_dAc = Mooncake.zero_fcodual(Ac)
    dAc = Mooncake.tangent(Ac_dAc)
    function copy_input_pb(::NoRData)
        Mooncake.increment!!(Mooncake.tangent(A_dA), dAc)
        return NoRData(), NoRData(), NoRData()
    end
    return Ac_dAc, copy_input_pb
end

Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(initialize_output), Any, Any, Any}
# two-argument in-place factorizations like LQ, QR, EIG
for (f!, f, pb, adj) in (
(:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint),
Expand Down
1 change: 1 addition & 0 deletions src/common/defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ default_pullback_degeneracy_atol(A) = eps(norm(A, Inf))^(3 / 4)
Default tolerance for deciding what values should be considered equal to 0.
"""
default_pullback_rank_atol(A) = eps(norm(A, Inf))^(3 / 4)
# `Diagonal` input: derive the tolerance from the diagonal entries (`diagview(A)`)
# instead of from the `Diagonal` wrapper itself.
default_pullback_rank_atol(A::Diagonal) = default_pullback_rank_atol(diagview(A))

"""
default_hermitian_tol(A)
Expand Down
3 changes: 3 additions & 0 deletions src/common/pullbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ function iszerotangent end

# By default a cotangent is not considered zero; `nothing` is treated as a zero
# tangent.
iszerotangent(::Any) = false
iszerotangent(::Nothing) = true

# Fallback: forward to `LinearAlgebra.sylvester`. Having a package-owned entry point
# lets more specific methods be added for other array types.
_sylvester(A, B, C) = LinearAlgebra.sylvester(A, B, C)
24 changes: 15 additions & 9 deletions src/pullbacks/eig.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
"""
    check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol)

Warn when the `eig` cotangents are sensitive to the gauge choice.

The entries of `VᴴΔV` at positions `(i, j)` where `abs(D[j] - D[i]) < degeneracy_atol`
(which always includes the diagonal) are collected, and a warning is emitted unless
their norm is at most `gauge_atol`. `D` is expected to be a vector of eigenvalues.
Returns `nothing`.

# Keywords
- `degeneracy_atol`: threshold below which two eigenvalues count as degenerate;
  defaults to `default_pullback_degeneracy_atol(D)`.
- `gauge_atol`: largest tolerated norm of the gauge-dependent part; defaults to
  `default_pullback_gauge_atol(VᴴΔV)`.
"""
function check_eig_cotangents(
        D, VᴴΔV;
        degeneracy_atol::Real = default_pullback_degeneracy_atol(D),
        gauge_atol::Real = default_pullback_gauge_atol(VᴴΔV)
    )
    # pairwise eigenvalue differences; `true` marks (near-)degenerate pairs
    mask = abs.(transpose(D) .- D) .< degeneracy_atol
    Δgauge = norm(view(VᴴΔV, mask))
    Δgauge ≤ gauge_atol ||
        @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    return
end

"""
eig_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV, [ind];
Expand Down Expand Up @@ -41,10 +53,7 @@ function eig_pullback!(
length(indV) == pV || throw(DimensionMismatch())
mul!(view(VᴴΔV, :, indV), V', ΔV)

mask = abs.(transpose(D) .- D) .< degeneracy_atol
Δgauge = norm(view(VᴴΔV, mask), Inf)
Δgauge ≤ gauge_atol ||
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol)

VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))

Expand Down Expand Up @@ -129,10 +138,7 @@ function eig_trunc_pullback!(
if !iszerotangent(ΔV)
(n, p) == size(ΔV) || throw(DimensionMismatch())
VᴴΔV = V' * ΔV
mask = abs.(transpose(D) .- D) .< degeneracy_atol
Δgauge = norm(view(VᴴΔV, mask), Inf)
Δgauge ≤ gauge_atol ||
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol)

ΔVperp = ΔV - V * inv(G) * VᴴΔV
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
Expand All @@ -150,7 +156,7 @@ function eig_trunc_pullback!(
# add contribution from orthogonal complement
PA = A - (A * V) / V
Y = mul!(ΔVperp, PA', Z, 1, 1)
X = sylvester(PA', -Dmat', Y)
X = _sylvester(PA', -Dmat', Y)
Z .+= X

if eltype(ΔA) <: Real
Expand Down
24 changes: 15 additions & 9 deletions src/pullbacks/eigh.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
"""
    check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol)

Warn when the `eigh` cotangents are sensitive to the gauge choice.

The entries of `aVᴴΔV` at positions `(i, j)` where `abs(D[j]' - D[i]) < degeneracy_atol`
(which always includes the diagonal) are collected, and a warning is emitted unless
their norm is at most `gauge_atol`. `D` is expected to be a vector of eigenvalues.
Returns `nothing`.

# Keywords
- `degeneracy_atol`: threshold below which two eigenvalues count as degenerate;
  defaults to `default_pullback_degeneracy_atol(D)`.
- `gauge_atol`: largest tolerated norm of the gauge-dependent part; defaults to
  `default_pullback_gauge_atol(aVᴴΔV)`.
"""
function check_eigh_cotangents(
        D, aVᴴΔV;
        degeneracy_atol::Real = default_pullback_degeneracy_atol(D),
        gauge_atol::Real = default_pullback_gauge_atol(aVᴴΔV)
    )
    # pairwise eigenvalue differences; `true` marks (near-)degenerate pairs
    mask = abs.(D' .- D) .< degeneracy_atol
    Δgauge = norm(view(aVᴴΔV, mask))
    Δgauge ≤ gauge_atol ||
        @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    return
end

"""
eigh_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV, [ind];
Expand Down Expand Up @@ -42,10 +54,7 @@ function eigh_pullback!(
mul!(view(VᴴΔV, :, indV), V', ΔV)
aVᴴΔV = project_antihermitian(VᴴΔV) # can't use in-place or recycling doesn't work

mask = abs.(D' .- D) .< degeneracy_atol
Δgauge = norm(view(aVᴴΔV, mask))
Δgauge ≤ gauge_atol ||
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol)

aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol)

Expand Down Expand Up @@ -120,10 +129,7 @@ function eigh_trunc_pullback!(
VᴴΔV = V' * ΔV
aVᴴΔV = project_antihermitian!(VᴴΔV)

mask = abs.(D' .- D) .< degeneracy_atol
Δgauge = norm(view(aVᴴΔV, mask))
Δgauge ≤ gauge_atol ||
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol)

aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol)

Expand All @@ -138,7 +144,7 @@ function eigh_trunc_pullback!(
# add contribution from orthogonal complement
W = qr_null(V)
WᴴΔV = W' * ΔV
X = sylvester(W' * A * W, -Dmat, WᴴΔV)
X = _sylvester(W' * A * W, -Dmat, WᴴΔV)
Z = mul!(Z, W, X, 1, 1)

# put everything together: symmetrize for hermitian case
Expand Down
80 changes: 49 additions & 31 deletions src/pullbacks/lq.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,41 @@
# Gauge check for the `lq` cotangents in the rank-deficient case (`minmn > p`).
# Collects the larger of the norms of the trailing rows of `ΔQ` and of the trailing
# block of `ΔL` (each only when that cotangent is not a zero tangent) and warns unless
# it is at most `gauge_atol`. Always returns `nothing`.
function check_lq_cotangents(
        L, Q, ΔL, ΔQ, minmn::Int, p::Int;
        gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
    )
    # nothing to verify unless A is rank-deficient
    minmn > p || return

    Δgauge = abs(zero(eltype(Q)))
    if !iszerotangent(ΔQ)
        # here the number of Householder reflections changes under small
        # variations, so the remaining rows of ΔQ should vanish for a
        # gauge-invariant cost function
        trailing_rows = (p + 1):size(Q, 1)
        Δgauge = max(Δgauge, norm(view(ΔQ, trailing_rows, :)))
    end
    if !iszerotangent(ΔL)
        trailing_block = view(ΔL, (p + 1):size(L, 1), (p + 1):minmn)
        Δgauge = max(Δgauge, norm(trailing_block))
    end
    if !(Δgauge ≤ gauge_atol)
        @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    end
    return
end

# Gauge check for the extra rows `Q2` produced by `lq_full` when A has full rank.
#
# Because Q is unitary, the projection of ΔQ2 onto Q1 (passed in as
# `ΔQ2Q1ᴴ = ΔQ2 * Q1'` by the caller) carries gauge-invariant information. With a
# fixed number of Householder reflections in the full-rank case, Q is expected to
# rotate smoothly, and it might even be possible to predict how the full Q2 changes;
# this is omitted for now, and the remainder `ΔQ2 - ΔQ2Q1ᴴ * Q1` is treated as a
# gauge-dependent quantity. A warning is emitted unless its `Inf`-norm is at most
# `gauge_atol`. `ΔQ2` is left untouched (the remainder is formed in a copy).
# NOTE(review): the default tolerance is derived from `Q1`, while the sibling checks
# derive theirs from the cotangent — confirm this is intended.
function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = default_pullback_gauge_atol(Q1))
    remainder = mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1)
    Δgauge = norm(remainder, Inf)
    if !(Δgauge ≤ gauge_atol)
        @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    end
    return
end

"""
lq_pullback!(
ΔA, A, LQ, ΔLQ;
Expand Down Expand Up @@ -36,23 +74,7 @@ function lq_pullback!(
ΔA1 = view(ΔA, 1:p, :)
ΔA2 = view(ΔA, (p + 1):m, :)

if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
# in this case the number Householder reflections will
# change upon small variations, and all of the remaining
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
end
if !iszerotangent(ΔL)
ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn)
Δgauge = max(Δgauge, norm(ΔL22, Inf))
end
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end
check_lq_cotangents(L, Q, ΔL, ΔQ, minmn, p; gauge_atol)

ΔQ̃ = zero!(similar(Q, (p, n)))
if !iszerotangent(ΔQ)
Expand All @@ -61,17 +83,8 @@ function lq_pullback!(
if p < size(Q, 1)
Q2 = view(Q, (p + 1):size(Q, 1), :)
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
# in the case where A is full rank, but there are more columns in Q than in A
# (the case of `qr_full`), there is gauge-invariant information in the
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
# matrix. As the number of Householder reflections is in fixed in the full rank
# case, Q is expected to rotate smoothly (we might even be able to predict) also
# how the full Q2 will change, but this we omit for now, and we consider
# Q2' * ΔQ2 as a gauge dependent quantity.
ΔQ2Q1ᴴ = ΔQ2 * Q1'
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol)
ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1)
end
end
Expand Down Expand Up @@ -102,6 +115,14 @@ function lq_pullback!(
return ΔA
end

# Gauge check for the `lq_null` cotangent: the product `Nᴴ * ΔNᴴ'` is passed through
# `project_antihermitian!`, and a warning is emitted unless the norm of the result is
# at most `gauge_atol`. Always returns `nothing`.
function check_lq_null_cotangents(Nᴴ, ΔNᴴ; gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ))
    NᴴΔN = Nᴴ * ΔNᴴ'
    Δgauge = norm(project_antihermitian!(NᴴΔN))
    if !(Δgauge ≤ gauge_atol)
        @warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    end
    return
end

"""
lq_null_pullback!(
ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ;
Expand All @@ -118,10 +139,7 @@ function lq_null_pullback!(
gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ)
)
if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0
aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ')
Δgauge = norm(aNᴴΔN)
Δgauge ≤ gauge_atol ||
@warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
check_lq_null_cotangents(Nᴴ, ΔNᴴ; gauge_atol)
L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here?
X = ldiv!(LowerTriangular(L)', Q * ΔNᴴ')
ΔA = mul!(ΔA, X, Nᴴ, -1, 1)
Expand Down
4 changes: 2 additions & 2 deletions src/pullbacks/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
M = zero(P)
!iszerotangent(ΔW) && mul!(M, W', ΔW, 1, 1)
!iszerotangent(ΔP) && mul!(M, ΔP, P, -1, 1)
C = sylvester(P, P, M' - M)
C = _sylvester(P, P, M' - M)
C .+= ΔP
ΔA = mul!(ΔA, W, C, 1, 1)
if !iszerotangent(ΔW)
Expand Down Expand Up @@ -46,7 +46,7 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...
M = zero(P)
!iszerotangent(ΔWᴴ) && mul!(M, ΔWᴴ, Wᴴ', 1, 1)
!iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1)
C = sylvester(P, P, M' - M)
C = _sylvester(P, P, M' - M)
C .+= ΔP
ΔA = mul!(ΔA, C, Wᴴ, 1, 1)
if !iszerotangent(ΔWᴴ)
Expand Down
Loading