From 23bc031e45d1fac0bbc839e524f85c740ac48b66 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Mon, 16 May 2022 11:21:11 -0400 Subject: [PATCH] [ITensors] [ENHANCMENT] Introduce `removeqn` function for removing a specified quantum number (#915) --- src/exports.jl | 1 + src/index.jl | 15 +++++ src/indexset.jl | 7 ++- src/itensor.jl | 10 ++++ src/mps/abstractmps.jl | 5 ++ src/qn/qn.jl | 24 +++++++- src/qn/qnindex.jl | 32 ++++++++++ src/qn/qnindexset.jl | 2 - src/qn/qnitensor.jl | 17 ++++++ test/ITensorChainRules/test_optimization.jl | 2 +- test/qnitensor.jl | 66 +++++++++++++++++++++ 11 files changed, 175 insertions(+), 6 deletions(-) diff --git a/src/exports.jl b/src/exports.jl index 0acb928828..e854e08b0e 100644 --- a/src/exports.jl +++ b/src/exports.jl @@ -69,6 +69,7 @@ export plev, prime, removetags, + removeqn, removeqns, replacetags, replacetags!, diff --git a/src/index.jl b/src/index.jl index 3b9705d4c5..43eb36ba72 100644 --- a/src/index.jl +++ b/src/index.jl @@ -517,6 +517,21 @@ Removes the QNs from the Index, if it has any. """ removeqns(i::Index) = i +""" + removeqn(::Index, qn_name::String) + +Remove the specified QN from the Index, if it has any. +""" +removeqn(i::Index, qn_name::String) = i + +""" + mergeblocks(::Index) + +Merge the contiguous QN blocks if they have the same +quantum numbers. +""" +mergeblocks(i::Index) = i + # Keep partial backwards compatibility by defining IndexVal as follows: const IndexVal{IndexT} = Pair{IndexT,Int} diff --git a/src/indexset.jl b/src/indexset.jl index 99fc79bcbc..bddeeef2a4 100644 --- a/src/indexset.jl +++ b/src/indexset.jl @@ -1,4 +1,3 @@ - # Represents a static order of an ITensor @eval struct Order{N} (OrderT::Type{<:Order})() = $(Expr(:new, :OrderT)) @@ -596,7 +595,11 @@ end swapind(is::Indices, i1::Index, i2::Index) = swapinds(is, (i1,), (i2,)) -removeqns(is::Indices) = is +removeqns(is::Indices) = map(removeqns, is) +function removeqn(is::Indices, qn_name::String; mergeblocks=true) + return map(i -> removeqn(i, qn_name; mergeblocks), is) +end +mergeblocks(is::Indices) = map(mergeblocks, is) # Permute is1 to be in the order of is2 # This is helpful when is1 and is2 have different directions, and diff --git a/src/itensor.jl b/src/itensor.jl index b7e73b1aad..2d06a47a8e 100644 --- a/src/itensor.jl +++ b/src/itensor.jl @@ -674,6 +674,8 @@ function dense(A::ITensor) return setinds(itensor(dense(tensor(A))), removeqns(inds(A))) end +removeqns(T::ITensor) = dense(T) + denseblocks(D::ITensor) = itensor(denseblocks(tensor(D))) """ @@ -1072,6 +1074,14 @@ it may return a Cartesian range. """ eachindex(A::ITensor) = eachindex(tensor(A)) +""" + eachindval(A::ITensor) + +Create an iterable object for visiting each element of the ITensor `A` (including structually +zero elements for sparse tensors) in terms of pairs of indices and values. +""" +eachindval(T::ITensor) = eachindval(inds(T)) + """ iterate(A::ITensor, args...) diff --git a/src/mps/abstractmps.jl b/src/mps/abstractmps.jl index 5467a37326..c60fe1bbf4 100644 --- a/src/mps/abstractmps.jl +++ b/src/mps/abstractmps.jl @@ -2194,6 +2194,11 @@ function splitblocks(::typeof(linkinds), M::AbstractMPS; tol=0) return splitblocks!(linkinds, copy(M); tol=0) end +removeqns(M::AbstractMPS) = map(removeqns, M; set_limits=false) +function removeqn(M::AbstractMPS, qn_name::String) + return map(m -> removeqn(m, qn_name), M; set_limits=false) +end + # # Broadcasting # diff --git a/src/qn/qn.jl b/src/qn/qn.jl index 8c2d1e04ec..85d03447b8 100644 --- a/src/qn/qn.jl +++ b/src/qn/qn.jl @@ -1,4 +1,3 @@ - struct QNVal name::SmallString val::Int @@ -357,6 +356,29 @@ function have_same_mods(qn1::QN, qn2::QN) return true end +function removeqn(qn::QN, qn_name::String) + ss_qn_name = SmallString(qn_name) + + # Find the location of the QNVal to remove + n_qn = nothing + for n in 1:length(qn) + qnval = qn[n] + if name(qnval) == ss_qn_name + n_qn = n + end + end + if isnothing(n_qn) + return qn + end + + qn_data = data(qn) + for j in n_qn:(length(qn) - 1) + qn_data = setindex(qn_data, qn_data[j + 1], j) + end + qn_data = setindex(qn_data, QNVal(), length(qn)) + return QN(qn_data) +end + function show(io::IO, q::QN) print(io, "QN(") Na = nactive(q) diff --git a/src/qn/qnindex.jl b/src/qn/qnindex.jl index e38503990a..db139b33e1 100644 --- a/src/qn/qnindex.jl +++ b/src/qn/qnindex.jl @@ -37,6 +37,10 @@ function (qn1::QNBlock + qn2::QNBlock) return QNBlock(qn(qn1), blockdim(qn1) + blockdim(qn2)) end +function removeqn(qn_block::QNBlock, qn_name::String) + return removeqn(qn(qn_block), qn_name) => blockdim(qn_block) +end + function -(qns::QNBlocks) qns_new = copy(qns) for i in 1:length(qns_new) @@ -45,6 +49,30 @@ function -(qns::QNBlocks) return qns_new end +function mergeblocks(qns::QNBlocks) + qnsC = [qns[1]] + + # Which block this is, after combining + block_count = 1 + for i in 2:nblocks(qns) + if qn(qns[i]) == qn(qns[i - 1]) + qnsC[block_count] += qns[i] + else + push!(qnsC, qns[i]) + block_count += 1 + end + end + return qnsC +end + +function removeqn(space::QNBlocks, qn_name::String; mergeblocks=true) + space = QNBlocks([removeqn(qn_block, qn_name) for qn_block in space]) + if mergeblocks + space = ITensors.mergeblocks(space) + end + return space +end + """ A QN Index is an Index with QN block storage instead of just an integer dimension. The QN block storage is a @@ -396,6 +424,10 @@ function combineblocks(i::QNIndex) end removeqns(i::QNIndex) = setdir(setspace(i, dim(i)), Neither) +function removeqn(i::QNIndex, qn_name::String; mergeblocks=true) + return setspace(i, removeqn(space(i), qn_name; mergeblocks)) +end +mergeblocks(i::QNIndex) = setspace(i, mergeblocks(space(i))) function addqns(i::Index, qns::QNBlocks; dir::Arrow=Out) @assert dim(i) == dim(qns) diff --git a/src/qn/qnindexset.jl b/src/qn/qnindexset.jl index 4a7ebed730..22bd8dfdff 100644 --- a/src/qn/qnindexset.jl +++ b/src/qn/qnindexset.jl @@ -29,8 +29,6 @@ function nzdiagblocks(qn::QN, inds::Indices) return blocks end -removeqns(is::QNIndices) = map(i -> removeqns(i), is) - anyfermionic(is::Indices) = any(isfermionic, is) allfermionic(is::Indices) = all(isfermionic, is) diff --git a/src/qn/qnitensor.jl b/src/qn/qnitensor.jl index 7302147bb7..18a729a657 100644 --- a/src/qn/qnitensor.jl +++ b/src/qn/qnitensor.jl @@ -467,6 +467,9 @@ function δ_split(i1::Index, i2::Index) end function splitblocks(A::ITensor, is=inds(A); tol=0) + if !hasqns(A) + return A + end isA = filterinds(A; inds=is) for i in isA i_split = splitblocks(i) @@ -481,3 +484,17 @@ function splitblocks(A::ITensor, is=inds(A); tol=0) A = dropzeros(A; tol=tol) return A end + +function removeqn(T::ITensor, qn_name::String; mergeblocks=true) + if !hasqns(T) + return T + end + inds_R = removeqn(inds(T), qn_name; mergeblocks) + R = ITensor(inds_R) + for iv in eachindex(T) + if !iszero(T[iv]) + R[iv] = T[iv] + end + end + return R +end diff --git a/test/ITensorChainRules/test_optimization.jl b/test/ITensorChainRules/test_optimization.jl index 81681c41b8..a990e45531 100644 --- a/test/ITensorChainRules/test_optimization.jl +++ b/test/ITensorChainRules/test_optimization.jl @@ -124,7 +124,7 @@ include("utils/circuit.jl") end @testset "State preparation (MPS)" begin - for gate in ["Ry"]#="Rx", =# + for gate in ["Ry"] #="Rx", =# nsites = 4 # Number of sites nlayers = 2 # Layers of gates in the ansatz gradtol = 1e-3 # Tolerance for stopping gradient descent diff --git a/test/qnitensor.jl b/test/qnitensor.jl index a76732d73d..0cca94acab 100644 --- a/test/qnitensor.jl +++ b/test/qnitensor.jl @@ -1779,6 +1779,72 @@ Random.seed!(1234) # increase the number of blocks of A's storage @test length(ITensors.blockoffsets(ITensors.tensor(A))) == 1 end + + @testset "removeqns and removeqn" begin + s = siteind("Electron"; conserve_qns=true) + T = op("c†↑", s) + + @test hasqns(s) + @test hasqns(T) + @test qn(s, 1) == QN(("Nf", 0, -1), ("Sz", 0)) + @test qn(s, 2) == QN(("Nf", 1, -1), ("Sz", 1)) + @test qn(s, 3) == QN(("Nf", 1, -1), ("Sz", -1)) + @test qn(s, 4) == QN(("Nf", 2, -1), ("Sz", 0)) + @test blockdim(s, 1) == 1 + @test blockdim(s, 2) == 1 + @test blockdim(s, 3) == 1 + @test blockdim(s, 4) == 1 + @test nblocks(s) == 4 + @test dim(s) == 4 + + s1 = removeqns(s) + T1 = removeqns(T) + @test !hasqns(s1) + @test !hasqns(T1) + @test nblocks(s1) == 1 + @test dim(s1) == 4 + for I in eachindex(T1) + @test T1[I] == T[I] + end + + s2 = removeqn(s, "Sz") + T2 = removeqn(T, "Sz") + @test hasqns(s2) + @test hasqns(T2) + @test nnzblocks(T2) == 2 + @test nblocks(s2) == 3 + @test nblocks(T2) == (3, 3) + @test qn(s2, 1) == QN(("Nf", 0, -1)) + @test qn(s2, 2) == QN(("Nf", 1, -1)) + @test qn(s2, 3) == QN(("Nf", 2, -1)) + @test blockdim(s2, 1) == 1 + @test blockdim(s2, 2) == 2 + @test blockdim(s2, 3) == 1 + @test dim(s2) == 4 + for I in eachindex(T2) + @test T2[I] == T[I] + end + + s3 = removeqn(s, "Nf") + T3 = removeqn(T, "Nf") + @test hasqns(s3) + @test hasqns(T3) + @test nnzblocks(T3) == 2 + @test nblocks(s3) == 4 + @test nblocks(T3) == (4, 4) + @test qn(s3, 1) == QN(("Sz", 0)) + @test qn(s3, 2) == QN(("Sz", 1)) + @test qn(s3, 3) == QN(("Sz", -1)) + @test qn(s3, 4) == QN(("Sz", 0)) + @test blockdim(s3, 1) == 1 + @test blockdim(s3, 2) == 1 + @test blockdim(s3, 3) == 1 + @test blockdim(s3, 4) == 1 + @test dim(s3) == 4 + for I in eachindex(T3) + @test T3[I] == T[I] + end + end end nothing