Skip to content

Commit

Permalink
Implement fit algorithm for sum of MPSs
Browse files Browse the repository at this point in the history
  • Loading branch information
shinaoka committed Jul 22, 2024
1 parent 4eed68b commit 3bea7e4
Show file tree
Hide file tree
Showing 9 changed files with 329 additions and 24 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[targets]
test = ["Test", "Random", "Aqua", "JET"]
test = ["Test", "Random", "Aqua", "JET", "StableRNGs"]
32 changes: 17 additions & 15 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
using FastMPOContractions
using Documenter

DocMeta.setdocmeta!(FastMPOContractions, :DocTestSetup, :(using FastMPOContractions); recursive=true)
DocMeta.setdocmeta!(
FastMPOContractions,
:DocTestSetup,
:(using FastMPOContractions);
recursive = true,
)

makedocs(;
modules=[FastMPOContractions],
authors="Hiroshi Shinaoka <h.shinaoka@gmail.com> and contributors",
sitename="FastMPOContractions.jl",
format=Documenter.HTML(;
canonical="https://github.com/tensor4all/FastMPOContractions.jl",
edit_link="main",
assets=String[]),
pages=[
"Home" => "index.md",
])

deploydocs(;
repo="github.com/tensor4all/FastMPOContractions.jl.git",
devbranch="main",
modules = [FastMPOContractions],
authors = "Hiroshi Shinaoka <h.shinaoka@gmail.com> and contributors",
sitename = "FastMPOContractions.jl",
format = Documenter.HTML(;
canonical = "https://github.com/tensor4all/FastMPOContractions.jl",
edit_link = "main",
assets = String[],
),
pages = ["Home" => "index.md"],
)

deploydocs(; repo = "github.com/tensor4all/FastMPOContractions.jl.git", devbranch = "main")
4 changes: 3 additions & 1 deletion src/FastMPOContractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ module FastMPOContractions
using StaticArrays

using ITensors
import ITensors.ITensorMPS: AbstractMPS, sim!, setleftlim!, setrightlim!, check_hascommoninds
import ITensors.ITensorMPS:
AbstractMPS, sim!, setleftlim!, setrightlim!, check_hascommoninds

using ITensorTDVP

include("densitymatrix.jl")
include("fitalgorithm.jl")
include("util.jl")
include("contractMPO.jl")
include("fitalgorithm_sum.jl")

end
2 changes: 1 addition & 1 deletion src/contractMPO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ function contract_mpo_mpo(M1::MPO, M2::MPO; alg::String = "densitymatrix", kwarg
error("Unknown algorithm: $alg")
end

end
end
10 changes: 7 additions & 3 deletions src/fitalgorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Contract M1 and M2, and return the result as an MPO.
function contract_fit(M1::MPO, M2::MPO; init = nothing, kwargs...)::MPO
M2_ = MPS([M2[v] for v in eachindex(M2)])
if init === nothing
init_MPO::MPO = ITensors.contract(M1, M2; alg="zipup", kwargs...)
init_MPO::MPO = ITensors.contract(M1, M2; alg = "zipup", kwargs...)
init = MPS([init_MPO[v] for v in eachindex(init_MPO)])
else
init = MPS([init[v] for v in eachindex(M2)])
Expand Down Expand Up @@ -58,6 +58,10 @@ function contract_fit(A::MPO, psi0::MPS; init_mps = psi0, nsweeps = 1, kwargs...

reduced_operator = ITensorTDVP.ReducedContractProblem(psi0, A)
return ITensorTDVP.alternating_update(
reduced_operator, init_mps; updater=ITensorTDVP.contract_operator_state_updater, nsweeps=nsweeps, kwargs...
)
reduced_operator,
init_mps;
updater = ITensorTDVP.contract_operator_state_updater,
nsweeps = nsweeps,
kwargs...,
)
end
242 changes: 242 additions & 0 deletions src/fitalgorithm_sum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
using ITensors.ITensorMPS: ITensorMPS, AbstractProjMPO, MPO, MPS
using ITensors.ITensorMPS: linkinds, replaceinds
using ITensors: ITensors, OneITensor
import ITensorTDVP: alternating_update, rproj, lproj

"""
A ReducedFitProblem represents the projection
of an MPS `input_state` onto the basis of a different MPS `state`.
`state` may be an approximation of `input_state`.
```
*--*--*- -*--*--*--*--*--* <state|
| | | | | | | | | | |
o--o--o- -o--o--o--o--o--o |input_state>
```
"""
mutable struct ReducedFitProblem <: AbstractProjMPO
lpos::Int
rpos::Int
nsite::Int
input_state::MPS
environments::Vector{ITensor}
end

function ReducedFitProblem(input_state::MPS)
lpos = 0
rpos = length(input_state) + 1
nsite = 2
environments = Vector{ITensor}(undef, length(input_state))
return ReducedFitProblem(lpos, rpos, nsite, input_state, environments)
end

function lproj(P::ReducedFitProblem)::Union{ITensor,OneITensor}
(P.lpos <= 0) && return OneITensor()
return P.environments[P.lpos]
end

function rproj(P::ReducedFitProblem)::Union{ITensor,OneITensor}
(P.rpos >= length(P) + 1) && return OneITensor()
return P.environments[P.rpos]
end


function Base.copy(reduced_operator::ReducedFitProblem)
return ReducedFitProblem(
reduced_operator.lpos,
reduced_operator.rpos,
reduced_operator.nsite,
copy(reduced_operator.input_state),
copy(reduced_operator.environments),
)
end

Base.length(reduced_operator::ReducedFitProblem) = length(reduced_operator.input_state)

function ITensorMPS.set_nsite!(reduced_operator::ReducedFitProblem, nsite)
reduced_operator.nsite = nsite
return reduced_operator
end

function ITensorMPS.makeL!(reduced_operator::ReducedFitProblem, state::MPS, k::Int)
# Save the last `L` that is made to help with caching
# for DiskProjMPO
ll = reduced_operator.lpos
if ll k
# Special case when nothing has to be done.
# Still need to change the position if lproj is
# being moved backward.
reduced_operator.lpos = k
return nothing
end
# Make sure ll is at least 0 for the generic logic below
ll = max(ll, 0)
L = lproj(reduced_operator)
while ll < k
L = L * reduced_operator.input_state[ll+1] * dag(state[ll+1])
reduced_operator.environments[ll+1] = L
ll += 1
end
# Needed when moving lproj backward.
reduced_operator.lpos = k
return reduced_operator
end

function ITensorMPS.makeR!(reduced_operator::ReducedFitProblem, state::MPS, k::Int)
# Save the last `R` that is made to help with caching
# for DiskProjMPO
rl = reduced_operator.rpos
if rl k
# Special case when nothing has to be done.
# Still need to change the position if rproj is
# being moved backward.
reduced_operator.rpos = k
return nothing
end
N = length(state)
# Make sure rl is no bigger than `N + 1` for the generic logic below
rl = min(rl, N + 1)
R = rproj(reduced_operator)
while rl > k
R = R * reduced_operator.input_state[rl-1] * dag(state[rl-1])
reduced_operator.environments[rl-1] = R
rl -= 1
end
reduced_operator.rpos = k
return reduced_operator
end


struct ReducedFitMPSsProblem <: AbstractProjMPO
problems::Vector{ReducedFitProblem}
coeffs::Vector{<:Number}
end

function ReducedFitMPSsProblem(
input_states::AbstractVector{MPS},
coeffs::AbstractVector{<:Number},
)
ReducedFitMPSsProblem(ReducedFitProblem.(input_states), coeffs)
end

function Base.copy(reduced_operator::ReducedFitMPSsProblem)
return ReducedFitMPSsProblem(reduced_operator.problems, reduced_operator.coeffs)
end

function Base.getproperty(reduced_operator::ReducedFitMPSsProblem, sym::Symbol)
if sym === :nsite
return getfield(reduced_operator, :problems)[1].nsite
end
return getfield(reduced_operator, sym)
end


Base.length(reduced_operator::ReducedFitMPSsProblem) = length(reduced_operator.problems[1])

function ITensorMPS.set_nsite!(reduced_operator::ReducedFitMPSsProblem, nsite)
for p in reduced_operator.problems
ITensorMPS.set_nsite!(p, nsite)
end
return reduced_operator
end

function ITensorMPS.makeL!(reduced_operator::ReducedFitMPSsProblem, state::MPS, k::Int)
for p in reduced_operator.problems
ITensorMPS.makeL!(p, state, k)
end
return reduced_operator
end


function ITensorMPS.makeR!(reduced_operator::ReducedFitMPSsProblem, state::MPS, k::Int)
for p in reduced_operator.problems
ITensorMPS.makeR!(p, state, k)
end
return reduced_operator
end



function _contract(P::ReducedFitProblem, v::ITensor)::ITensor
itensor_map = Union{ITensor,OneITensor}[lproj(P)]
push!(itensor_map, rproj(P))

# Reverse the contraction order of the map if
# the first tensor is a scalar (for example we
# are at the left edge of the system)
if dim(first(itensor_map)) == 1
reverse!(itensor_map)
end

# Apply the map
Hv = v
for it in itensor_map
Hv *= it
end
return Hv
end

function contract_operator_state_updater(operator::ReducedFitProblem, init; internal_kwargs)
state = ITensor(true)
for j = (operator.lpos+1):(operator.rpos-1)
state *= operator.input_state[j]
end
state = _contract(operator, state)
return state, (;)
end

function contract_operator_state_updater(
operator::ReducedFitMPSsProblem,
init;
internal_kwargs,
)
states = ITensor[]
for (p, coeff) in zip(operator.problems, operator.coeffs)
res = contract_operator_state_updater(p, init; internal_kwargs)
push!(states, coeff * res[1])
end
return sum(states), (;)
end


function contract_fit(input_state::MPS, init::MPS; coeff::Number = 1, kwargs...)
links = ITensors.sim.(linkinds(init))
init = replaceinds(linkinds, init, links)
reduced_operator = ReducedFitProblem(input_state)
return alternating_update(
reduced_operator,
init;
updater = contract_operator_state_updater,
kwargs...,
)
end


function fit(
input_states::AbstractVector{MPS},
init::MPS;
coeffs::AbstractVector{<:Number} = ones(Int, length(input_states)),
kwargs...,
)
links = ITensors.sim.(linkinds(init))
init = replaceinds(linkinds, init, links)
reduced_operator = ReducedFitMPSsProblem(input_states, coeffs)
return alternating_update(
reduced_operator,
init;
updater = contract_operator_state_updater,
kwargs...,
)
end

function fit(
input_states::AbstractVector{MPO},
init::MPO;
coeffs::AbstractVector{<:Number} = ones(Int, length(input_states)),
kwargs...,
)
:MPO
to_mps::MPO) = MPS([x for x in Ψ])

res = fit(to_mps.(input_states), to_mps(init); coeffs = coeffs, kwargs...)
return MPO([x for x in res])
end
11 changes: 8 additions & 3 deletions test/_util.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
using ITensors
using Random

function _random_mpo(sites::Vector{Vector{Index{T}}}; linkdims = 1) where {T}
_random_mpo(Random.GLOBAL_RNG, sites; linkdims = linkdims)
end

function _random_mpo(rng, sites::Vector{Vector{Index{T}}}; linkdims = 1) where {T}
N = length(sites)
links = [Index(linkdims, "Link,n=$n") for n = 1:N-1]
M = MPO(N)
M[1] = random_itensor(sites[1]..., links[1])
M[N] = random_itensor(links[N-1], sites[N]...)
M[1] = random_itensor(rng, sites[1]..., links[1])
M[N] = random_itensor(rng, links[N-1], sites[N]...)
for n = 2:N-1
M[n] = random_itensor(links[n-1], sites[n]..., links[n])
M[n] = random_itensor(rng, links[n-1], sites[n]..., links[n])
end
return M
end
Loading

0 comments on commit 3bea7e4

Please sign in to comment.