Skip to content

Commit

Permalink
Rename SimpleTensorNetwork => TensorNetwork
Browse files Browse the repository at this point in the history
  • Loading branch information
shinaoka committed Nov 15, 2024
1 parent bf7a420 commit 2703cfb
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 23 deletions.
36 changes: 17 additions & 19 deletions src/simpletensornetwork.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
"""
Generic tensor network data structure
"""
mutable struct SimpleTensorNetwork <: AbstractDataGraph{Int,IndexedArray,IndexedArray}
mutable struct TensorNetwork <: AbstractDataGraph{Int,IndexedArray,IndexedArray}
# data_graph: (undirected) graph of the tensor network
# An integer is assigned to each vertex (starting from 1 and increasing one by one).
# We can place an IndexedArray at each vertex of the graph, and an edge between two vertices.
# But, the latter is not supported by the current implementation of SimpleTensorNetworks.jl.
# This may be useful for supporting the Vidal notation.
data_graph::DataGraph{Int,IndexedArray,IndexedArray,NamedGraph{Int},NamedEdge{Int}}

function SimpleTensorNetwork(
function TensorNetwork(
dg::DataGraph{Int,IndexedArray,IndexedArray,NamedGraph{Int},NamedEdge{Int}},
)
is_connected(dg) ||
error("SimpleTensorNetwork is only supported for a connected graph.")
is_connected(dg) || error("TensorNetwork is only supported for a connected graph.")
new(dg)
end
end

function SimpleTensorNetwork(ts::AbstractVector{<:AbstractIndexedArray})
function TensorNetwork(ts::AbstractVector{<:AbstractIndexedArray})
g = NamedGraph(collect(eachindex(ts)))
dg = DataGraph{Int,IndexedArray,IndexedArray,NamedGraph{Int},NamedEdge{Int}}(g)

Expand All @@ -32,34 +31,33 @@ function SimpleTensorNetwork(ts::AbstractVector{<:AbstractIndexedArray})
end
end
end
tn = SimpleTensorNetwork(dg)
tn = TensorNetwork(dg)
return tn
end

data_graph(tn::SimpleTensorNetwork) = getfield(tn, :data_graph)
data_graph_type(TN::Type{<:SimpleTensorNetwork}) = fieldtype(TN, :data_graph)
DataGraphs.underlying_graph(tn::SimpleTensorNetwork) = underlying_graph(data_graph(tn))
DataGraphs.underlying_graph_type(TN::Type{<:SimpleTensorNetwork}) =
data_graph(tn::TensorNetwork) = getfield(tn, :data_graph)
data_graph_type(TN::Type{<:TensorNetwork}) = fieldtype(TN, :data_graph)
DataGraphs.underlying_graph(tn::TensorNetwork) = underlying_graph(data_graph(tn))
DataGraphs.underlying_graph_type(TN::Type{<:TensorNetwork}) =
fieldtype(data_graph_type(TN), :underlying_graph)
DataGraphs.vertex_data(graph::SimpleTensorNetwork, args...) =
DataGraphs.vertex_data(graph::TensorNetwork, args...) =
vertex_data(data_graph(graph), args...)
DataGraphs.edge_data(graph::SimpleTensorNetwork, args...) =
edge_data(data_graph(graph), args...)
DataGraphs.edge_data(graph::TensorNetwork, args...) = edge_data(data_graph(graph), args...)

function Base.setindex!(tn::SimpleTensorNetwork, t::AbstractIndexedArray, v::Int)
function Base.setindex!(tn::TensorNetwork, t::AbstractIndexedArray, v::Int)
tn.data_graph[v] = t
end

Base.getindex(tn::SimpleTensorNetwork, v::Int) = tn.data_graph[v]
Base.getindex(tn::TensorNetwork, v::Int) = tn.data_graph[v]

"""
Return if a tensor network `tn` has a cycle. If it has not a cycle, `tn` is a tree tensor network.
"""
Graphs.is_cyclic(tn::SimpleTensorNetwork) =
Graphs.is_cyclic(tn::TensorNetwork) =
Graphs.is_cyclic(tn.data_graph.underlying_graph.position_graph)


Graphs.has_edge(tn::SimpleTensorNetwork, e::NamedEdge) = Graphs.has_edge(tn.data_graph, e)
Graphs.has_edge(tn::TensorNetwork, e::NamedEdge) = Graphs.has_edge(tn.data_graph, e)

"""
Contract all the tensors in a tensor network `tn` and return the result.
Expand All @@ -68,7 +66,7 @@ This function works only for tree tensor networks, i.e., `is_cyclic(tn) == false
root_vertex: The vertex to start the contraction. The default is 1.
"""
function complete_contraction(tn::SimpleTensorNetwork; root_vertex::Int = 1)
function complete_contraction(tn::TensorNetwork; root_vertex::Int = 1)
!Graphs.is_cyclic(tn) ||
error("complete_contraction is not supported only for a tree tensor network.")
res = tn[root_vertex]
Expand All @@ -84,7 +82,7 @@ Contract all the tensors in a subtree of a tensor network `tn` and return the re
The subtree is defined by a vertex `v` and its parent vertex `parent_v`.
Note that `parent_v` is not included in the subtree.
"""
function _contract_subtree(tn::SimpleTensorNetwork, v::Int, parent_v::Union{Int,Nothing})
function _contract_subtree(tn::TensorNetwork, v::Int, parent_v::Union{Int,Nothing})
res = tn[v]
for nv in neighbors(tn.data_graph, v)
if nv != parent_v
Expand Down
7 changes: 3 additions & 4 deletions test/simpletensornetwork_test.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
@testitem "simpletensornetwork.jl" begin
import SimpleTensorNetworks:
Index, dim, IndexedArray, indices, permute, SimpleTensorNetwork
import SimpleTensorNetworks: Index, dim, IndexedArray, indices, permute, TensorNetwork
import Graphs: is_connected, has_edge

@testset "Construction from IndexedArray objects" begin
Expand All @@ -14,7 +13,7 @@
t3 = IndexedArray(rand(2, 2), [c, d])


tn = SimpleTensorNetwork([t1, t2, t3])
tn = TensorNetwork([t1, t2, t3])

@test has_edge(tn, 1 => 2)
@test has_edge(tn, 2 => 1)
Expand All @@ -34,7 +33,7 @@
t3 = IndexedArray(rand(2), [b])
t4 = IndexedArray(rand(2), [c])

tn = SimpleTensorNetwork([t1, t2, t3, t4])
tn = TensorNetwork([t1, t2, t3, t4])
@test only(SimpleTensorNetworks.complete_contraction(tn; root_vertex = 1))
only(t1 * t2 * t3 * t4)
end
Expand Down

0 comments on commit 2703cfb

Please sign in to comment.