Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions ext/VectorInterfaceMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ _needs_tangent(::Type{T}) where {T <: Number} =
# scale
# -----
@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractArray, Number}
function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number})
function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual, α_Δα::CoDual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
α = primal(α_Δα)
Expand All @@ -43,7 +43,7 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArra
return C_ΔC, scale_pullback
end

function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractArray}, α_Δα::Dual{<:Number})
function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual, α_Δα::Dual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
α, Δα = extract(α_Δα)
Expand All @@ -60,7 +60,7 @@ end

@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractArray, AbstractArray, Number}

function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArray}, A_ΔA::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number})
function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual, A_ΔA::CoDual, α_Δα::CoDual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
A, ΔA = arrayify(A_ΔA)
Expand All @@ -81,7 +81,7 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArra
return C_ΔC, scale_pullback
end

function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractArray}, A_ΔA::Dual{<:AbstractArray}, α_Δα::Dual{<:Number})
function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual, A_ΔA::Dual, α_Δα::Dual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
A, ΔA = arrayify(A_ΔA)
Expand All @@ -98,7 +98,7 @@ end

@is_primitive DefaultCtx Tuple{typeof(add!), AbstractArray, AbstractArray, Number, Number}

function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractArray}, A_ΔA::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number})
function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual, A_ΔA::CoDual, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
A, ΔA = arrayify(A_ΔA)
Expand All @@ -123,7 +123,7 @@ function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractArray}
return C_ΔC, add_pullback
end

function Mooncake.frule!!(::Dual{typeof(add!)}, C_ΔC::Dual{<:AbstractArray}, A_ΔA::Dual{<:AbstractArray}, α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number})
function Mooncake.frule!!(::Dual{typeof(add!)}, C_ΔC::Dual, A_ΔA::Dual, α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number})
# prepare arguments
C, ΔC = arrayify(C_ΔC)
A, ΔA = arrayify(A_ΔA)
Expand All @@ -142,7 +142,7 @@ end

@is_primitive DefaultCtx Tuple{typeof(inner), AbstractArray, AbstractArray}

function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractArray}, B_ΔB::CoDual{<:AbstractArray})
function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual, B_ΔB::CoDual)
# prepare arguments
A, ΔA = arrayify(A_ΔA)
B, ΔB = arrayify(B_ΔB)
Expand All @@ -159,7 +159,7 @@ function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractArray
return CoDual(s, NoFData()), inner_pullback
end

function Mooncake.frule!!(::Dual{typeof(inner)}, A_ΔA::Dual{<:AbstractArray}, B_ΔB::Dual{<:AbstractArray})
function Mooncake.frule!!(::Dual{typeof(inner)}, A_ΔA::Dual, B_ΔB::Dual)
# prepare arguments
A, ΔA = arrayify(A_ΔA)
B, ΔB = arrayify(B_ΔB)
Expand Down
8 changes: 8 additions & 0 deletions test/mooncake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,21 @@ using VectorInterface
using VectorInterface: MinimalMVec, MinimalSVec, MinimalVec
using Test, TestExtras
using Mooncake
import Mooncake: arrayify
using Random

rng = Random.default_rng()

precision(::Type{T}) where {T <: Union{Float32, ComplexF32}} = sqrt(eps(Float32))
precision(::Type{T}) where {T <: Union{Float64, ComplexF64}} = sqrt(eps(Float64))

function Mooncake.arrayify(A_dA::Mooncake.CoDual{<:MinimalVec})
return (Mooncake.primal(A_dA).vec, Mooncake.tangent(A_dA).data.vec)
end
function Mooncake.arrayify(A_dA::Mooncake.Dual{<:MinimalVec})
return (Mooncake.primal(A_dA).vec, Mooncake.tangent(A_dA).fields.vec)
end

eltypes = (Float32, Float64, ComplexF64)

@testset "scale ($T)" for T in eltypes
Expand Down
Loading