From d5a77bed3a05324ef7f6f6580457689334b411b7 Mon Sep 17 00:00:00 2001 From: Ryan Levy Date: Mon, 14 Oct 2024 16:30:19 -0400 Subject: [PATCH 1/4] Port functions from ITensorInfiniteMPS --- src/itensor.jl | 1 + src/lib/ITensorMPS/src/mps.jl | 2 ++ src/lib/SmallStrings/src/smallstring.jl | 27 +++++++++++++++++++ src/lib/TagSets/src/TagSets.jl | 1 + src/tensor_operations/matrix_decomposition.jl | 21 +++++++++++++++ 5 files changed, 52 insertions(+) diff --git a/src/itensor.jl b/src/itensor.jl index e092cf12da..f94f7808bf 100644 --- a/src/itensor.jl +++ b/src/itensor.jl @@ -840,6 +840,7 @@ dirs(A::ITensor, is) = dirs(inds(A), is) # TODO: add isdiag(::Tensor) to NDTensors isdiag(T::ITensor)::Bool = (storage(T) isa Diag || storage(T) isa DiagBlockSparse) +LinearAlgebra.isdiag(T::ITensor) = isdiag(T) diaglength(T::ITensor) = diaglength(tensor(T)) diff --git a/src/lib/ITensorMPS/src/mps.jl b/src/lib/ITensorMPS/src/mps.jl index ef462a8462..5b648c8493 100644 --- a/src/lib/ITensorMPS/src/mps.jl +++ b/src/lib/ITensorMPS/src/mps.jl @@ -1032,3 +1032,5 @@ end function expect(psi::MPS, op1::Matrix{<:Number}, ops::Matrix{<:Number}...; kwargs...) return expect(psi, (op1, ops...); kwargs...) end + +Base.getindex(ψ::MPS, r::UnitRange{Int}) = MPS([ψ[n] for n in r]) diff --git a/src/lib/SmallStrings/src/smallstring.jl b/src/lib/SmallStrings/src/smallstring.jl index ac660c5d72..d724348644 100644 --- a/src/lib/SmallStrings/src/smallstring.jl +++ b/src/lib/SmallStrings/src/smallstring.jl @@ -102,6 +102,33 @@ end Base.:(==)(s1::SmallString, s2::SmallString) = (s1.data == s2.data) Base.isless(s1::SmallString, s2::SmallString) = isless(s1.data, s2.data) +maxlength(s::SmallString) = length(s.data) + +function Base.length(s::SmallString) + n = 1 + while n <= maxlength(s) && s[n] != zero(eltype(s)) + n += 1 + end + return n - 1 +end + +Base.lastindex(s::SmallString) = length(s) +Base.getindex(s::SmallString, r::UnitRange) = SmallString([s[n] for n in r]) + + +# TODO: make this work directly on a Tag, without converting +# to String +function Base.parse(::Type{T}, s::SmallString) where {T<:Integer} + return parse(T, string(s)) +end + +function Base.startswith(s::SmallString, subtag::SmallString) + for n in 1:length(subtag) + s[n] ≠ subtag[n] && return false + end + return true +end + ######################################################## # Here are alternative SmallString comparison implementations # diff --git a/src/lib/TagSets/src/TagSets.jl b/src/lib/TagSets/src/TagSets.jl index 45d3152bf7..44331a7d29 100644 --- a/src/lib/TagSets/src/TagSets.jl +++ b/src/lib/TagSets/src/TagSets.jl @@ -215,6 +215,7 @@ data(T::TagSet) = T.data Base.length(T::TagSet) = T.length Base.@propagate_inbounds Base.getindex(T::TagSet, n::Integer) = SmallString(data(T)[n]) Base.copy(ts::TagSet) = TagSet(data(ts), length(ts)) +Base.keys(ts::TagSet) = Base.OneTo(length(ts)) function Base.:(==)(ts1::TagSet, ts2::TagSet) l1 = length(ts1) diff --git a/src/tensor_operations/matrix_decomposition.jl b/src/tensor_operations/matrix_decomposition.jl index 8983997d1a..ae31b4e98f 100644 --- a/src/tensor_operations/matrix_decomposition.jl +++ b/src/tensor_operations/matrix_decomposition.jl @@ -590,6 +590,27 @@ function sqrt_decomp(D::ITensor, u::Index, v::Index) return sqrtDL, prime(δᵤᵥ), sqrtDR end +# Take the square root of T assuming it is Hermitian +# TODO: add more general index structures +function Base.sqrt(T::ITensor; ishermitian=true, atol=1e-15) + @assert ishermitian + # TODO diagonal version + #if isdiag(T) && order(T) == 2 + # return itensor(sqrt(tensor(T))) + #end + D, U = eigen(T; ishermitian=ishermitian) + sqrtD = D + for n in 1:mindim(D) + Dnn = D[n, n] + if Dnn < 0 && abs(Dnn) < atol + sqrtD[n, n] = 0 + else + sqrtD[n, n] = sqrt(Dnn) + end + end + return U' * sqrtD * dag(U) +end + function factorize_svd( A::ITensor, Linds...; From d982ac444a490fbc7443b912040e599fc5060fb1 Mon Sep 17 00:00:00 2001 From: Ryan Levy Date: Mon, 14 Oct 2024 16:34:02 -0400 Subject: [PATCH 2/4] Formatting --- src/lib/SmallStrings/src/smallstring.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lib/SmallStrings/src/smallstring.jl b/src/lib/SmallStrings/src/smallstring.jl index d724348644..2492569ee0 100644 --- a/src/lib/SmallStrings/src/smallstring.jl +++ b/src/lib/SmallStrings/src/smallstring.jl @@ -115,7 +115,6 @@ end Base.lastindex(s::SmallString) = length(s) Base.getindex(s::SmallString, r::UnitRange) = SmallString([s[n] for n in r]) - # TODO: make this work directly on a Tag, without converting # to String function Base.parse(::Type{T}, s::SmallString) where {T<:Integer} From 20277cd77d09aeeb145a114da04f9a08bde63693 Mon Sep 17 00:00:00 2001 From: Ryan Levy Date: Mon, 14 Oct 2024 17:44:41 -0400 Subject: [PATCH 3/4] Remove ported function --- src/lib/ITensorMPS/src/mps.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/lib/ITensorMPS/src/mps.jl b/src/lib/ITensorMPS/src/mps.jl index 5b648c8493..ef462a8462 100644 --- a/src/lib/ITensorMPS/src/mps.jl +++ b/src/lib/ITensorMPS/src/mps.jl @@ -1032,5 +1032,3 @@ end function expect(psi::MPS, op1::Matrix{<:Number}, ops::Matrix{<:Number}...; kwargs...) return expect(psi, (op1, ops...); kwargs...) end - -Base.getindex(ψ::MPS, r::UnitRange{Int}) = MPS([ψ[n] for n in r]) From a17b39b79fb86f2d105fb09596d5a17159dbb823 Mon Sep 17 00:00:00 2001 From: Ryan Levy Date: Mon, 21 Oct 2024 07:56:03 -0400 Subject: [PATCH 4/4] Better sqrt decomp --- src/tensor_operations/matrix_decomposition.jl | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/tensor_operations/matrix_decomposition.jl b/src/tensor_operations/matrix_decomposition.jl index ae31b4e98f..4637f3fe80 100644 --- a/src/tensor_operations/matrix_decomposition.jl +++ b/src/tensor_operations/matrix_decomposition.jl @@ -599,15 +599,7 @@ function Base.sqrt(T::ITensor; ishermitian=true, atol=1e-15) # return itensor(sqrt(tensor(T))) #end D, U = eigen(T; ishermitian=ishermitian) - sqrtD = D - for n in 1:mindim(D) - Dnn = D[n, n] - if Dnn < 0 && abs(Dnn) < atol - sqrtD[n, n] = 0 - else - sqrtD[n, n] = sqrt(Dnn) - end - end + sqrtD = map_diag(x -> x < 0 && abs(x) < atol ? 0 : sqrt(x), D) return U' * sqrtD * dag(U) end