diff --git a/ext/TensorKitChainRulesCoreExt/linalg.jl b/ext/TensorKitChainRulesCoreExt/linalg.jl index fd27d410c..605231883 100644 --- a/ext/TensorKitChainRulesCoreExt/linalg.jl +++ b/ext/TensorKitChainRulesCoreExt/linalg.jl @@ -80,7 +80,7 @@ function ChainRulesCore.rrule( end function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap) - tr_pullback(Δtr) = NoTangent(), Δtr * id(domain(A)) + tr_pullback(Δtr) = NoTangent(), scale!!(id!(similar(A)), unthunk(Δtr)) return tr(A), tr_pullback end diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 6e6cce626..12281d758 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -306,14 +306,14 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one)) return Base.$fname(codomain ← domain) end function Base.$fname( - ::Type{T}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain) - ) where {T, S <: IndexSpace} - return Base.$fname(T, codomain ← domain) + ::Type{TorA}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain) + ) where {TorA, S <: IndexSpace} + return Base.$fname(TorA, codomain ← domain) end Base.$fname(V::TensorMapSpace) = Base.$fname(Float64, V) - function Base.$fname(::Type{T}, V::TensorMapSpace) where {T} - t = TensorMap{T}(undef, V) - fill!(t, $felt(T)) + function Base.$fname(::Type{TorA}, V::TensorMapSpace) where {TorA} + t = tensormaptype(spacetype(V), numout(V), numin(V), TorA)(undef, V) + fill!(t, $felt(scalartype(t))) return t end end