From 313e580df37f921bb7e5739d3c935a64aacc585f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 19 May 2026 16:35:05 +0200 Subject: [PATCH 1/6] Forward rules for vector interface --- ext/TensorKitMooncakeExt/tangent.jl | 10 +++ ext/TensorKitMooncakeExt/vectorinterface.jl | 67 +++++++++++++++++++-- test/mooncake/vectorinterface.jl | 34 +++++------ 3 files changed, 89 insertions(+), 22 deletions(-) diff --git a/ext/TensorKitMooncakeExt/tangent.jl b/ext/TensorKitMooncakeExt/tangent.jl index 7c1d7110a..151244b94 100644 --- a/ext/TensorKitMooncakeExt/tangent.jl +++ b/ext/TensorKitMooncakeExt/tangent.jl @@ -1,9 +1,11 @@ # Arrayify is needed to make MatrixAlgebraKit function properly - # it turns coduals into argument types that MAK knows how to handle. Mooncake.arrayify(A_dA::CoDual{<:TensorMap}) = arrayify(primal(A_dA), tangent(A_dA)) +Mooncake.arrayify(A_dA::Dual{<:TensorMap}) = arrayify(primal(A_dA), tangent(A_dA)) Mooncake.arrayify(A::TensorMap, dA::TensorMap) = (A, dA) Mooncake.arrayify(A_dA::CoDual{<:DiagonalTensorMap}) = arrayify(primal(A_dA), tangent(A_dA)) +Mooncake.arrayify(A_dA::Dual{<:DiagonalTensorMap}) = arrayify(primal(A_dA), tangent(A_dA)) Mooncake.arrayify(A::DiagonalTensorMap, dA::DiagonalTensorMap) = (A, dA) function Mooncake.arrayify(Aᴴ_ΔAᴴ::CoDual{<:TK.AdjointTensorMap}) @@ -14,6 +16,14 @@ function Mooncake.arrayify(Aᴴ_ΔAᴴ::CoDual{<:TK.AdjointTensorMap}) return A', ΔA' end +function Mooncake.arrayify(Aᴴ_ΔAᴴ::Dual{<:TK.AdjointTensorMap}) + Aᴴ = Mooncake.primal(Aᴴ_ΔAᴴ) + ΔAᴴ = Mooncake.tangent(Aᴴ_ΔAᴴ) + A_ΔA = Dual(Aᴴ', ΔAᴴ.fields.parent) + A, ΔA = arrayify(A_ΔA) + return A', ΔA' +end + # Define the tangent type of a TensorMap to be TensorMap itself. # This has a number of benefits, but also correctly alters the # inner product when dealing with non-abelian symmetries. diff --git a/ext/TensorKitMooncakeExt/vectorinterface.jl b/ext/TensorKitMooncakeExt/vectorinterface.jl index a6f2db85f..4bbf09d12 100644 --- a/ext/TensorKitMooncakeExt/vectorinterface.jl +++ b/ext/TensorKitMooncakeExt/vectorinterface.jl @@ -1,4 +1,4 @@ -@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, Number} +@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractTensorMap, Number} function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}) # prepare arguments @@ -19,7 +19,22 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTens return C_ΔC, scale_pullback end -@is_primitive DefaultCtx ReverseMode Tuple{typeof(scale!), AbstractTensorMap, AbstractTensorMap, Number} +function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractTensorMap}, α_Δα::Dual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + α, Δα = Mooncake.extract(α_Δα) + + if !isa(Δα, Mooncake.NoTangent) + add!(ΔC, C, Δα, α) + else + scale!(ΔC, α) + end + scale!(C, α) + + return C_ΔC +end + +@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractTensorMap, AbstractTensorMap, Number} function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}) # prepare arguments @@ -42,7 +57,21 @@ function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTens return C_ΔC, scale_pullback end -@is_primitive DefaultCtx ReverseMode Tuple{typeof(add!), AbstractTensorMap, AbstractTensorMap, Number, Number} +function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractTensorMap}, A_ΔA::Dual{<:AbstractTensorMap}, α_Δα::Dual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + α, Δα = Mooncake.extract(α_Δα) + + scale!(ΔC, ΔA, α) + if !isa(Δα, Mooncake.NoTangent) + add!(ΔC, A, Δα, One()) + end + scale!(C, A, α) + return C_ΔC +end + +@is_primitive DefaultCtx Tuple{typeof(add!), AbstractTensorMap, AbstractTensorMap, Number, Number} function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}) # prepare arguments @@ -69,7 +98,26 @@ function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractTensor return C_ΔC, add_pullback end -@is_primitive DefaultCtx ReverseMode Tuple{typeof(inner), AbstractTensorMap, AbstractTensorMap} +function Mooncake.frule!!(::Dual{typeof(add!)}, C_ΔC::Dual{<:AbstractTensorMap}, A_ΔA::Dual{<:AbstractTensorMap}, α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number}) + # prepare arguments + C, ΔC = arrayify(C_ΔC) + A, ΔA = arrayify(A_ΔA) + α, Δα = Mooncake.extract(α_Δα) + β, Δβ = Mooncake.extract(β_Δβ) + add!(ΔC, ΔA, α, β) + if isa(Δβ, Mooncake.NoTangent) && !isa(Δα, Mooncake.NoTangent) + add!(ΔC, A, Δα, One()) + elseif isa(Δα, Mooncake.NoTangent) && !isa(Δβ, Mooncake.NoTangent) + add!(ΔC, C, Δβ, One()) + elseif !isa(Δα, Mooncake.NoTangent) && !isa(Δβ, Mooncake.NoTangent) + add!(ΔC, A, Δα, One()) + add!(ΔC, C, Δβ, One()) + end + add!(C, A, α, β) + return C_ΔC +end + +@is_primitive DefaultCtx Tuple{typeof(inner), AbstractTensorMap, AbstractTensorMap} function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractTensorMap}, B_ΔB::CoDual{<:AbstractTensorMap}) # prepare arguments @@ -87,3 +135,14 @@ function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractTenso return CoDual(s, NoFData()), inner_pullback end + +function Mooncake.frule!!(::Dual{typeof(inner)}, A_ΔA::Dual{<:AbstractTensorMap}, B_ΔB::Dual{<:AbstractTensorMap}) + # prepare arguments + A, ΔA = arrayify(A_ΔA) + B, ΔB = arrayify(B_ΔB) + + s = inner(A, B) + Δs = inner(A, ΔB) + inner(ΔA, B) + + return Dual(s, Δs) +end diff --git a/test/mooncake/vectorinterface.jl b/test/mooncake/vectorinterface.jl index 5d10101da..7ffbf0735 100644 --- a/test/mooncake/vectorinterface.jl +++ b/test/mooncake/vectorinterface.jl @@ -3,9 +3,8 @@ using TensorKit using TensorOperations using Mooncake using Random +using VectorInterface - -mode = Mooncake.ReverseMode rng = Random.default_rng() spacelist = ad_spacelist(fast_tests) @@ -17,20 +16,19 @@ eltypes = (Float64, ComplexF64) C = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])') A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])') - α = randn(T) - β = randn(T) - - Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol, mode) - - Mooncake.TestUtils.test_rule(rng, add!, C, A; atol, rtol, mode, is_primitive = false) - Mooncake.TestUtils.test_rule(rng, add!, C, A, α; atol, rtol, mode, is_primitive = false) - Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol, mode) - - Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol, mode) - Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol, mode) + for α in (randn(T), One(), Zero()), β in (randn(T), One(), Zero()) + Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol) + Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol) + Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol) + Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol) + Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol) + Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol) + + Mooncake.TestUtils.test_rule(rng, add!, C, A; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!, C, A, α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol) + + Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol) + Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol) + end end From 71775bd4f78c4e03e7e97190036c6059739cf03f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 May 2026 07:10:58 +0200 Subject: [PATCH 2/6] Try falling back to new VI rules --- Project.toml | 3 + .../TensorKitMooncakeExt.jl | 1 - ext/TensorKitMooncakeExt/vectorinterface.jl | 148 ------------------ 3 files changed, 3 insertions(+), 149 deletions(-) delete mode 100644 ext/TensorKitMooncakeExt/vectorinterface.jl diff --git a/Project.toml b/Project.toml index af43fc925..87b612470 100644 --- a/Project.toml +++ b/Project.toml @@ -38,6 +38,9 @@ TensorKitMooncakeExt = "Mooncake" [workspace] projects = ["test", "docs"] +[sources] +VectorInterface = {url = "https://github.com/Jutho/VectorInterface.jl", rev = "main"} + [compat] Adapt = "4" AMDGPU = "2" diff --git a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl index d436173e2..5ecba854c 100644 --- a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl +++ b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl @@ -17,7 +17,6 @@ include("utility.jl") include("tangent.jl") include("linalg.jl") include("indexmanipulations.jl") -include("vectorinterface.jl") include("tensoroperations.jl") include("planaroperations.jl") include("factorizations.jl") diff --git a/ext/TensorKitMooncakeExt/vectorinterface.jl b/ext/TensorKitMooncakeExt/vectorinterface.jl deleted file mode 100644 index 4bbf09d12..000000000 --- a/ext/TensorKitMooncakeExt/vectorinterface.jl +++ /dev/null @@ -1,148 +0,0 @@ -@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractTensorMap, Number} - -function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}) - # prepare arguments - C, ΔC = arrayify(C_ΔC) - α = primal(α_Δα) - - # primal call - C_cache = copy(C) - scale!(C, α) - - function scale_pullback(::NoRData) - copy!(C, C_cache) - Δαr = _needs_tangent(α) ? project_scalar(α, inner(C, ΔC)) : NoRData() - scale!(ΔC, conj(α)) - return NoRData(), NoRData(), Δαr - end - - return C_ΔC, scale_pullback -end - -function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractTensorMap}, α_Δα::Dual{<:Number}) - # prepare arguments - C, ΔC = arrayify(C_ΔC) - α, Δα = Mooncake.extract(α_Δα) - - if !isa(Δα, Mooncake.NoTangent) - add!(ΔC, C, Δα, α) - else - scale!(ΔC, α) - end - scale!(C, α) - - return C_ΔC -end - -@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractTensorMap, AbstractTensorMap, Number} - -function Mooncake.rrule!!(::CoDual{typeof(scale!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}) - # prepare arguments - C, ΔC = arrayify(C_ΔC) - A, ΔA = arrayify(A_ΔA) - α = primal(α_Δα) - - # primal call - C_cache = copy(C) - scale!(C, A, α) - - function scale_pullback(::NoRData) - copy!(C, C_cache) - add!(ΔA, ΔC, conj(α)) - Δαr = _needs_tangent(α) ? project_scalar(α, inner(A, ΔC)) : NoRData() - zerovector!(ΔC) - return NoRData(), NoRData(), NoRData(), Δαr - end - - return C_ΔC, scale_pullback -end - -function Mooncake.frule!!(::Dual{typeof(scale!)}, C_ΔC::Dual{<:AbstractTensorMap}, A_ΔA::Dual{<:AbstractTensorMap}, α_Δα::Dual{<:Number}) - # prepare arguments - C, ΔC = arrayify(C_ΔC) - A, ΔA = arrayify(A_ΔA) - α, Δα = Mooncake.extract(α_Δα) - - scale!(ΔC, ΔA, α) - if !isa(Δα, Mooncake.NoTangent) - add!(ΔC, A, Δα, One()) - end - scale!(C, A, α) - return C_ΔC -end - -@is_primitive DefaultCtx Tuple{typeof(add!), AbstractTensorMap, AbstractTensorMap, Number, Number} - -function Mooncake.rrule!!(::CoDual{typeof(add!)}, C_ΔC::CoDual{<:AbstractTensorMap}, A_ΔA::CoDual{<:AbstractTensorMap}, α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number}) - # prepare arguments - C, ΔC = arrayify(C_ΔC) - A, ΔA = arrayify(A_ΔA) - α = primal(α_Δα) - β = primal(β_Δβ) - - # primal call - C_cache = copy(C) - add!(C, A, α, β) - - function add_pullback(::NoRData) - copy!(C, C_cache) - - Δαr = _needs_tangent(α) ? project_scalar(α, inner(A, ΔC)) : NoRData() - Δβr = _needs_tangent(β) ? project_scalar(β, inner(C, ΔC)) : NoRData() - add!(ΔA, ΔC, conj(α)) - scale!(ΔC, conj(β)) - - return NoRData(), NoRData(), NoRData(), Δαr, Δβr - end - - return C_ΔC, add_pullback -end - -function Mooncake.frule!!(::Dual{typeof(add!)}, C_ΔC::Dual{<:AbstractTensorMap}, A_ΔA::Dual{<:AbstractTensorMap}, α_Δα::Dual{<:Number}, β_Δβ::Dual{<:Number}) - # prepare arguments - C, ΔC = arrayify(C_ΔC) - A, ΔA = arrayify(A_ΔA) - α, Δα = Mooncake.extract(α_Δα) - β, Δβ = Mooncake.extract(β_Δβ) - add!(ΔC, ΔA, α, β) - if isa(Δβ, Mooncake.NoTangent) && !isa(Δα, Mooncake.NoTangent) - add!(ΔC, A, Δα, One()) - elseif isa(Δα, Mooncake.NoTangent) && !isa(Δβ, Mooncake.NoTangent) - add!(ΔC, C, Δβ, One()) - elseif !isa(Δα, Mooncake.NoTangent) && !isa(Δβ, Mooncake.NoTangent) - add!(ΔC, A, Δα, One()) - add!(ΔC, C, Δβ, One()) - end - add!(C, A, α, β) - return C_ΔC -end - -@is_primitive DefaultCtx Tuple{typeof(inner), AbstractTensorMap, AbstractTensorMap} - -function Mooncake.rrule!!(::CoDual{typeof(inner)}, A_ΔA::CoDual{<:AbstractTensorMap}, B_ΔB::CoDual{<:AbstractTensorMap}) - # prepare arguments - A, ΔA = arrayify(A_ΔA) - B, ΔB = arrayify(B_ΔB) - - # primal call - s = inner(A, B) - - function inner_pullback(Δs) - add!(ΔA, B, conj(Δs)) - add!(ΔB, A, Δs) - return NoRData(), NoRData(), NoRData() - end - - return CoDual(s, NoFData()), inner_pullback -end - -function Mooncake.frule!!(::Dual{typeof(inner)}, A_ΔA::Dual{<:AbstractTensorMap}, B_ΔB::Dual{<:AbstractTensorMap}) - # prepare arguments - A, ΔA = arrayify(A_ΔA) - B, ΔB = arrayify(B_ΔB) - - s = inner(A, B) - Δs = inner(A, ΔB) + inner(ΔA, B) - - return Dual(s, Δs) -end From 10af68dd231ce8f8b2eb0cfd5eda72badedeba95 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 May 2026 11:25:51 +0200 Subject: [PATCH 3/6] Fix tangent typo --- ext/TensorKitMooncakeExt/tangent.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TensorKitMooncakeExt/tangent.jl b/ext/TensorKitMooncakeExt/tangent.jl index 151244b94..ac63fe8c3 100644 --- a/ext/TensorKitMooncakeExt/tangent.jl +++ b/ext/TensorKitMooncakeExt/tangent.jl @@ -191,7 +191,7 @@ _field_symbol(t, ::Val{F}) where {F} = _field_symbol(t, F) # frules _frule_getfield_common(t_dt::Dual{<:DiagOrTensorMap}, field_sym::Symbol) = - Dual(getfield(primal(t), field_sym), field_sym === :data ? tangent(t).data : NoFData()) + Dual(getfield(primal(t_dt), field_sym), field_sym === :data ? tangent(t_dt).data : NoFData()) Mooncake.frule!!(::Dual{typeof(Mooncake.lgetfield)}, t_dt::Dual{<:DiagOrTensorMap}, f_df::Dual) = _frule_getfield_common(t_dt, _field_symbol(primal(t_dt), primal(f_df))) From 6a021c4c80445bf968c21e33698804ab4a7753b3 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 May 2026 14:49:25 +0200 Subject: [PATCH 4/6] Some small fixes --- Project.toml | 2 +- ext/TensorKitMooncakeExt/tangent.jl | 2 +- ext/TensorKitMooncakeExt/utility.jl | 1 + test/mooncake/vectorinterface.jl | 21 ++++++++++----------- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index 87b612470..c1366192d 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ TensorKitMooncakeExt = "Mooncake" projects = ["test", "docs"] [sources] -VectorInterface = {url = "https://github.com/Jutho/VectorInterface.jl", rev = "main"} +VectorInterface = {url = "https://github.com/kshyatt/VectorInterface.jl", rev = "ksh/mooncake_loosen"} [compat] Adapt = "4" diff --git a/ext/TensorKitMooncakeExt/tangent.jl b/ext/TensorKitMooncakeExt/tangent.jl index ac63fe8c3..776591248 100644 --- a/ext/TensorKitMooncakeExt/tangent.jl +++ b/ext/TensorKitMooncakeExt/tangent.jl @@ -191,7 +191,7 @@ _field_symbol(t, ::Val{F}) where {F} = _field_symbol(t, F) # frules _frule_getfield_common(t_dt::Dual{<:DiagOrTensorMap}, field_sym::Symbol) = - Dual(getfield(primal(t_dt), field_sym), field_sym === :data ? tangent(t_dt).data : NoFData()) + Dual(getfield(primal(t_dt), field_sym), field_sym === :data ? tangent(t_dt).data : NoTangent()) Mooncake.frule!!(::Dual{typeof(Mooncake.lgetfield)}, t_dt::Dual{<:DiagOrTensorMap}, f_df::Dual) = _frule_getfield_common(t_dt, _field_symbol(primal(t_dt), primal(f_df))) diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl index ceb32d867..64ad6520d 100644 --- a/ext/TensorKitMooncakeExt/utility.jl +++ b/ext/TensorKitMooncakeExt/utility.jl @@ -67,6 +67,7 @@ Mooncake.tangent_type(::Type{<:HomSpace}) = Mooncake.NoTangent @zero_derivative DefaultCtx Tuple{typeof(TensorKit.select), HomSpace, Index2Tuple} @zero_derivative DefaultCtx Tuple{typeof(TensorKit.flip), HomSpace, Any} +@zero_derivative DefaultCtx Tuple{typeof(TensorKit.adjoint), HomSpace} @zero_derivative DefaultCtx Tuple{typeof(TensorKit.permute), HomSpace, Index2Tuple} @zero_derivative DefaultCtx Tuple{typeof(TensorKit.braid), HomSpace, Index2Tuple, IndexTuple} @zero_derivative DefaultCtx Tuple{typeof(TensorKit.compose), HomSpace, HomSpace} diff --git a/test/mooncake/vectorinterface.jl b/test/mooncake/vectorinterface.jl index 7ffbf0735..900acbae6 100644 --- a/test/mooncake/vectorinterface.jl +++ b/test/mooncake/vectorinterface.jl @@ -1,9 +1,8 @@ using Test, TestExtras using TensorKit using TensorOperations -using Mooncake +using VectorInterface, Mooncake using Random -using VectorInterface rng = Random.default_rng() @@ -17,18 +16,18 @@ eltypes = (Float64, ComplexF64) C = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])') A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])') for α in (randn(T), One(), Zero()), β in (randn(T), One(), Zero()) - Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol) - Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol) - Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol) - Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol) - Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol) - Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol) + Mooncake.TestUtils.test_rule(rng, scale!, C, α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, scale!, C', α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, scale!, C, A, α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, scale!, C', A', α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, scale!, copy(C'), A', α; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, scale!, C', copy(A'), α; atol, rtol, is_primitive = false) Mooncake.TestUtils.test_rule(rng, add!, C, A; atol, rtol, is_primitive = false) Mooncake.TestUtils.test_rule(rng, add!, C, A, α; atol, rtol, is_primitive = false) - Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol) + Mooncake.TestUtils.test_rule(rng, add!, C, A, α, β; atol, rtol, is_primitive = false) - Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol) - Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol) + Mooncake.TestUtils.test_rule(rng, inner, C, A; atol, rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, inner, C', A'; atol, rtol, is_primitive = false) end end From ddaa985312aefc10842dfe132aca2c346ec9c8dd Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 May 2026 15:42:39 +0200 Subject: [PATCH 5/6] Restore primitive markers --- ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl | 1 + ext/TensorKitMooncakeExt/vectorinterface.jl | 4 ++++ 2 files changed, 5 insertions(+) create mode 100644 ext/TensorKitMooncakeExt/vectorinterface.jl diff --git a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl index 5ecba854c..d436173e2 100644 --- a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl +++ b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl @@ -17,6 +17,7 @@ include("utility.jl") include("tangent.jl") include("linalg.jl") include("indexmanipulations.jl") +include("vectorinterface.jl") include("tensoroperations.jl") include("planaroperations.jl") include("factorizations.jl") diff --git a/ext/TensorKitMooncakeExt/vectorinterface.jl b/ext/TensorKitMooncakeExt/vectorinterface.jl new file mode 100644 index 000000000..260f32a01 --- /dev/null +++ b/ext/TensorKitMooncakeExt/vectorinterface.jl @@ -0,0 +1,4 @@ +@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractTensorMap, Number} +@is_primitive DefaultCtx Tuple{typeof(scale!), AbstractTensorMap, AbstractTensorMap, Number} +@is_primitive DefaultCtx Tuple{typeof(add!), AbstractTensorMap, AbstractTensorMap, Number, Number} +@is_primitive DefaultCtx Tuple{typeof(inner), AbstractTensorMap, AbstractTensorMap} From 5b8b266de27f08b6d4664df345bf679c20742cd6 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 21 May 2026 18:22:54 +0200 Subject: [PATCH 6/6] Update sources --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c1366192d..eeae411ae 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,7 @@ TensorKitMooncakeExt = "Mooncake" projects = ["test", "docs"] [sources] -VectorInterface = {url = "https://github.com/kshyatt/VectorInterface.jl", rev = "ksh/mooncake_loosen"} +VectorInterface = {url = "https://github.com/QuantumKitHub/VectorInterface.jl", rev = "main"} [compat] Adapt = "4"