diff --git a/src/common/gauge.jl b/src/common/gauge.jl index 84b5f7688..016b60ad8 100644 --- a/src/common/gauge.jl +++ b/src/common/gauge.jl @@ -10,7 +10,8 @@ is real and positive. # Helper functions _argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x))) -_largest(x, y) = abs(x) < abs(y) ? y : x +_largest(x::Real, y::Real) = abs(x) < abs(y) ? y : x +_largest(x::Complex, y::Complex) = abs2(x) < abs2(y) ? y : x function gaugefix!(::typeof(qr_householder!), Q, R, Rd) ax = Base.OneTo(length(Rd)) @@ -67,12 +68,10 @@ end function gaugefix!(::Union{typeof(svd_compact!), typeof(svd_trunc!)}, U, Vᴴ) @assert axes(U, 2) == axes(Vᴴ, 1) - for j in axes(U, 2) - u = view(U, :, j) - v = view(Vᴴ, j, :) - s = sign(_argmaxabs(u)) - u .*= conj(s) - v .*= s - end + signs = reduce(_largest, U; dims = 1, init = zero(eltype(U))) + @. signs = sign(signs) + signs_t = transpose(signs) + @. U = U * conj(signs) + @. Vᴴ = signs_t * Vᴴ return (U, Vᴴ) end