From 2d2abc5650b8fbf982300f659fa0bb7c98510976 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 30 Jan 2026 06:39:48 -0500 Subject: [PATCH 1/6] Small fixes for upstream + CUDA --- ext/TensorKitCUDAExt/cutensormap.jl | 5 +++++ ext/TensorKitChainRulesCoreExt/linalg.jl | 2 +- src/tensors/tensor.jl | 17 +++++++++++------ 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/ext/TensorKitCUDAExt/cutensormap.jl b/ext/TensorKitCUDAExt/cutensormap.jl index 3274e654a..8475c0a2e 100644 --- a/ext/TensorKitCUDAExt/cutensormap.jl +++ b/ext/TensorKitCUDAExt/cutensormap.jl @@ -37,6 +37,11 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one)) fill!(t, $felt(T)) return t end + function Base.$fname(::Type{TorA}, V::TensorMapSpace) where {TorA <: CuArray} + t = tensormaptype(spacetype(V), numout(V), numin(V), TorA)(undef, V) + fill!(t, $felt(eltype(TorA))) + return t + end end end diff --git a/ext/TensorKitChainRulesCoreExt/linalg.jl b/ext/TensorKitChainRulesCoreExt/linalg.jl index fd27d410c..28a85752a 100644 --- a/ext/TensorKitChainRulesCoreExt/linalg.jl +++ b/ext/TensorKitChainRulesCoreExt/linalg.jl @@ -80,7 +80,7 @@ function ChainRulesCore.rrule( end function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap) - tr_pullback(Δtr) = NoTangent(), Δtr * id(domain(A)) + tr_pullback(Δtr) = NoTangent(), Δtr * id(storagetype(A), domain(A)) return tr(A), tr_pullback end diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 6e6cce626..f39078b4c 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -306,14 +306,19 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one)) return Base.$fname(codomain ← domain) end function Base.$fname( - ::Type{T}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain) - ) where {T, S <: IndexSpace} - return Base.$fname(T, codomain ← domain) + ::Type{TorA}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain) + ) where {TorA, S <: IndexSpace} + return Base.$fname(TorA, codomain ← domain) + end + function Base.$fname( + ::Type{T}, ::Type{TorA}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain) + ) where {T, TorA, S <: IndexSpace} + return Base.$fname(TorA, codomain ← domain) end Base.$fname(V::TensorMapSpace) = Base.$fname(Float64, V) - function Base.$fname(::Type{T}, V::TensorMapSpace) where {T} - t = TensorMap{T}(undef, V) - fill!(t, $felt(T)) + function Base.$fname(::Type{TorA}, V::TensorMapSpace) where {TorA} + t = tensormaptype(spacetype(V), numout(V), numin(V), TorA)(undef, V) + fill!(t, $felt(TorA)) return t end end From cf6f75f615e60840a9fab0d2d587b161cebeb858 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 6 Feb 2026 15:37:31 +0100 Subject: [PATCH 2/6] Update ext/TensorKitChainRulesCoreExt/linalg.jl Co-authored-by: Lukas Devos --- ext/TensorKitChainRulesCoreExt/linalg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TensorKitChainRulesCoreExt/linalg.jl b/ext/TensorKitChainRulesCoreExt/linalg.jl index 28a85752a..670eb8fc3 100644 --- a/ext/TensorKitChainRulesCoreExt/linalg.jl +++ b/ext/TensorKitChainRulesCoreExt/linalg.jl @@ -80,7 +80,7 @@ function ChainRulesCore.rrule( end function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap) - tr_pullback(Δtr) = NoTangent(), Δtr * id(storagetype(A), domain(A)) + tr_pullback(Δtr) = NoTangent(), scale!!(id!(similar(A)), Δtr) return tr(A), tr_pullback end From 65959790ec1e570c6f4171464475e9f66114c343 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 6 Feb 2026 15:37:37 +0100 Subject: [PATCH 3/6] Update src/tensors/tensor.jl Co-authored-by: Lukas Devos --- src/tensors/tensor.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index f39078b4c..432f50d6c 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -318,7 +318,7 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one)) Base.$fname(V::TensorMapSpace) = Base.$fname(Float64, V) function Base.$fname(::Type{TorA}, V::TensorMapSpace) where {TorA} t = tensormaptype(spacetype(V), numout(V), numin(V), TorA)(undef, V) - fill!(t, $felt(TorA)) + fill!(t, $felt(scalartype(t))) return t end end From de4eb2b497a08f1c3856790705d3ec4eeaa4130f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 6 Feb 2026 15:37:43 +0100 Subject: [PATCH 4/6] Update ext/TensorKitCUDAExt/cutensormap.jl Co-authored-by: Lukas Devos --- ext/TensorKitCUDAExt/cutensormap.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/ext/TensorKitCUDAExt/cutensormap.jl b/ext/TensorKitCUDAExt/cutensormap.jl index 8475c0a2e..3274e654a 100644 --- a/ext/TensorKitCUDAExt/cutensormap.jl +++ b/ext/TensorKitCUDAExt/cutensormap.jl @@ -37,11 +37,6 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one)) fill!(t, $felt(T)) return t end - function Base.$fname(::Type{TorA}, V::TensorMapSpace) where {TorA <: CuArray} - t = tensormaptype(spacetype(V), numout(V), numin(V), TorA)(undef, V) - fill!(t, $felt(eltype(TorA))) - return t - end end end From cbe6e38a3f3ef57000014b3881047ca0d8060f59 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 6 Feb 2026 09:40:54 -0500 Subject: [PATCH 5/6] Remove unneeded method --- src/tensors/tensor.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 432f50d6c..12281d758 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -310,11 +310,6 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one)) ) where {TorA, S <: IndexSpace} return Base.$fname(TorA, codomain ← domain) end - function Base.$fname( - ::Type{T}, ::Type{TorA}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain) - ) where {T, TorA, S <: IndexSpace} - return Base.$fname(TorA, codomain ← domain) - end Base.$fname(V::TensorMapSpace) = Base.$fname(Float64, V) function Base.$fname(::Type{TorA}, V::TensorMapSpace) where {TorA} t = tensormaptype(spacetype(V), numout(V), numin(V), TorA)(undef, V) From ad9a3241e832c00cacb302ec22a1bc18645eaa39 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 6 Feb 2026 11:04:04 -0500 Subject: [PATCH 6/6] Missing unthunk --- ext/TensorKitChainRulesCoreExt/linalg.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TensorKitChainRulesCoreExt/linalg.jl b/ext/TensorKitChainRulesCoreExt/linalg.jl index 670eb8fc3..605231883 100644 --- a/ext/TensorKitChainRulesCoreExt/linalg.jl +++ b/ext/TensorKitChainRulesCoreExt/linalg.jl @@ -80,7 +80,7 @@ function ChainRulesCore.rrule( end function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap) - tr_pullback(Δtr) = NoTangent(), scale!!(id!(similar(A)), Δtr) + tr_pullback(Δtr) = NoTangent(), scale!!(id!(similar(A)), unthunk(Δtr)) return tr(A), tr_pullback end