From 4fd5436459a6408a51529609ea0aefbcff84c4d7 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 May 2026 13:16:53 +0200 Subject: [PATCH 1/3] Loosen type rules for Mooncake --- ext/VectorInterfaceMooncakeExt.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ext/VectorInterfaceMooncakeExt.jl b/ext/VectorInterfaceMooncakeExt.jl index 57c26b8..896d7c5 100644 --- a/ext/VectorInterfaceMooncakeExt.jl +++ b/ext/VectorInterfaceMooncakeExt.jl @@ -23,8 +23,8 @@ _needs_tangent(::Type{T}) where {T <: Number} = # scale # ----- -@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractArray, Number} -function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number}) +@is_primitive DefaultCtx Tuple{typeof(scale!), Any, Number} +function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual, α_Δα::CoDual{<:Number}) # prepare arguments C, ΔC = arrayify(C_ΔC) α = primal(α_Δα) @@ -43,7 +43,7 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArra return C_ΔC, scale_pullback end -function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractArray}, α_Δα::Dual{<:Number}) +function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual, α_Δα::Dual{<:Number}) # prepare arguments C, ΔC = arrayify(C_ΔC) α, Δα = extract(α_Δα) @@ -58,9 +58,9 @@ function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractArray}, return C_ΔC end -@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractArray, AbstractArray, Number} +@is_primitive DefaultCtx Tuple{typeof(scale!), Any, Any, Number} -function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArray}, A_ΔA::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number}) +function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual, A_ΔA::CoDual, α_Δα::CoDual{<:Number}) # prepare arguments C, ΔC = arrayify(C_ΔC) A, ΔA = arrayify(A_ΔA) @@ -81,7 +81,7 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractArra return C_ΔC, scale_pullback end -function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractArray}, A_ΔA::Dual{<:AbstractArray}, α_Δα::Dual{<:Number}) +function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual, A_ΔA::Dual, α_Δα::Dual{<:Number}) # prepare arguments C, ΔC = arrayify(C_ΔC) A, ΔA = arrayify(A_ΔA) @@ -96,9 +96,9 @@ end # add # --- -@is_primitive DefaultCtx Tuple{typeof(add!), AbstractArray, AbstractArray, Number, Number} +@is_primitive DefaultCtx Tuple{typeof(add!), Any, Any, Number, Number} -function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractArray}, A_ΔA::CoDual{<:AbstractArray}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}) +function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual, A_ΔA::CoDual, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}) # prepare arguments C, ΔC = arrayify(C_ΔC) A, ΔA = arrayify(A_ΔA) @@ -123,7 +123,7 @@ function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractArray} return C_ΔC, add_pullback end -function Mooncake.frule!!(::Dual{typeof(add!)}, C_ΔC::Dual{<:AbstractArray}, A_ΔA::Dual{<:AbstractArray}, α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number}) +function Mooncake.frule!!(::Dual{typeof(add!)}, C_ΔC::Dual, A_ΔA::Dual, α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number}) # prepare arguments C, ΔC = arrayify(C_ΔC) A, ΔA = arrayify(A_ΔA) @@ -140,9 +140,9 @@ end # inner # ----- -@is_primitive DefaultCtx Tuple{typeof(inner), AbstractArray, AbstractArray} +@is_primitive DefaultCtx Tuple{typeof(inner), Any, Any} -function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractArray}, B_ΔB::CoDual{<:AbstractArray}) +function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual, B_ΔB::CoDual) # prepare arguments A, ΔA = arrayify(A_ΔA) B, ΔB = arrayify(B_ΔB) @@ -159,7 +159,7 @@ function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractArray return CoDual(s, NoFData()), inner_pullback end -function Mooncake.frule!!(::Dual{typeof(inner)}, A_ΔA::Dual{<:AbstractArray}, B_ΔB::Dual{<:AbstractArray}) +function Mooncake.frule!!(::Dual{typeof(inner)}, A_ΔA::Dual, B_ΔB::Dual) # prepare arguments A, ΔA = arrayify(A_ΔA) B, ΔB = arrayify(B_ΔB) From 2b056ef6abda967cf6470f2242c73f31391f08dd Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 May 2026 13:37:09 +0200 Subject: [PATCH 2/3] Arrayify methods --- test/mooncake.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/mooncake.jl b/test/mooncake.jl index d8466d1..3f44ca3 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -4,6 +4,7 @@ using VectorInterface using VectorInterface: MinimalMVec, MinimalSVec, MinimalVec using Test, TestExtras using Mooncake +import Mooncake: arrayify using Random rng = Random.default_rng() @@ -11,6 +12,13 @@ rng = Random.default_rng() precision(::Type{T}) where {T <: Union{Float32, ComplexF32}} = sqrt(eps(Float32)) precision(::Type{T}) where {T <: Union{Float64, ComplexF64}} = sqrt(eps(Float64)) +function Mooncake.arrayify(A_dA::Mooncake.CoDual{<:MinimalVec}) + return (Mooncake.primal(A_dA).vec, Mooncake.tangent(A_dA).data.vec) +end +function Mooncake.arrayify(A_dA::Mooncake.Dual{<:MinimalVec}) + return (Mooncake.primal(A_dA).vec, Mooncake.tangent(A_dA).fields.vec) +end + eltypes = (Float32, Float64, ComplexF64) @testset "scale ($T)" for T in eltypes From c6c813c261ca9c88f23989daca6014c229dfba62 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 May 2026 15:37:01 +0200 Subject: [PATCH 3/3] Tighten isprimitive for AbstractArray --- ext/VectorInterfaceMooncakeExt.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/VectorInterfaceMooncakeExt.jl b/ext/VectorInterfaceMooncakeExt.jl index 896d7c5..124583f 100644 --- a/ext/VectorInterfaceMooncakeExt.jl +++ b/ext/VectorInterfaceMooncakeExt.jl @@ -23,7 +23,7 @@ _needs_tangent(::Type{T}) where {T <: Number} = # scale # ----- -@is_primitive DefaultCtx Tuple{typeof(scale!), Any, Number} +@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractArray, Number} function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual, α_Δα::CoDual{<:Number}) # prepare arguments C, ΔC = arrayify(C_ΔC) @@ -58,7 +58,7 @@ function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual, α_Δα::Dual{<:N return C_ΔC end -@is_primitive DefaultCtx Tuple{typeof(scale!), Any, Any, Number} +@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractArray, AbstractArray, Number} function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual, A_ΔA::CoDual, α_Δα::CoDual{<:Number}) # prepare arguments @@ -96,7 +96,7 @@ end # add # --- -@is_primitive DefaultCtx Tuple{typeof(add!), Any, Any, Number, Number} +@is_primitive DefaultCtx Tuple{typeof(add!), AbstractArray, AbstractArray, Number, Number} function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual, A_ΔA::CoDual, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}) # prepare arguments @@ -140,7 +140,7 @@ end # inner # ----- -@is_primitive DefaultCtx Tuple{typeof(inner), Any, Any} +@is_primitive DefaultCtx Tuple{typeof(inner), AbstractArray, AbstractArray} function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual, B_ΔB::CoDual) # prepare arguments