Skip to content

Commit e8223e6

Browse files
Merge pull request #10 from AGaliciaMartinez/tensor_train
Added the tensor_train class.
2 parents 408c9f4 + 0986bef commit e8223e6

File tree

6 files changed

+558
-26
lines changed

6 files changed

+558
-26
lines changed

src/qutip_tensornetwork/core/data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
from .tensor import *
55
from .adjoint import *
66
from .mul import *
7+
from .tensor_train import *

src/qutip_tensornetwork/core/data/network.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def __init__(self, out_edges, in_edges, nodes=None, copy=True):
117117
# dynamically searching for them when necessary. This is because
118118
# searching all nodes in a large graph can be quite expensive while
119119
# keeping track of them with network operations is straightforward.
120-
self.nodes = (
120+
self._nodes = (
121121
set(nodes) if nodes else tn.reachable(self.in_edges + self.out_edges)
122122
)
123123

@@ -129,7 +129,7 @@ def __init__(self, out_edges, in_edges, nodes=None, copy=True):
129129

130130
if copy:
131131
node_dict, edge_dict = tn.copy(self.nodes)
132-
self.nodes = set(node_dict[n] for n in self.nodes)
132+
self._nodes = set(node_dict[n] for n in self.nodes)
133133
self.in_edges = [edge_dict[e] for e in self.in_edges]
134134
self.out_edges = [edge_dict[e] for e in self.out_edges]
135135

@@ -203,7 +203,7 @@ def _fast_constructor(cls, out_edges, in_edges, nodes):
203203
out = cls.__new__(cls)
204204
out.in_edges = in_edges
205205
out.out_edges = out_edges
206-
out.nodes = nodes
206+
out._nodes = nodes
207207

208208
return out
209209

@@ -267,7 +267,7 @@ def adjoint(self):
267267

268268
return Network._fast_constructor(out_edges, in_edges, nodes)
269269

270-
def contract(self, contractor=greedy, final_edge_order=None):
270+
def contract(self, contractor=greedy, copy=True):
271271
"""Return the contracted version of the tensor network.
272272
273273
Parameters
@@ -277,7 +277,8 @@ def contract(self, contractor=greedy, final_edge_order=None):
277277
``tensornetwork.contractor.greedy``, which uses the greedy
278278
algorithm from `opt_einsum` to determine a contraction order.
279279
280-
final_edge_order: iterable of tensornetwork.Edges
280+
copy: bool
281+
Default True. If False, perform the operation in-place.
281282
282283
Returns
283284
-------
@@ -289,19 +290,20 @@ def contract(self, contractor=greedy, final_edge_order=None):
289290
tensornetwork.contractor: This module contains other functions that
290291
can be used instead of ``greedy``.
291292
"""
292-
nodes_dict, edges_dict = tn.copy(self.nodes)
293-
294-
in_edges = [edges_dict[e] for e in self.in_edges]
295-
out_edges = [edges_dict[e] for e in self.out_edges]
296-
nodes = set(nodes_dict[n] for n in self.nodes if n in nodes_dict)
293+
if copy:
294+
out = self.copy()
295+
return out.contract(contractor, copy=False)
297296

298-
if final_edge_order is not None:
299-
final_edge_order = [edges_dict[e] for e in final_edge_order]
300-
nodes = set([contractor(nodes, output_edge_order=final_edge_order)])
301-
else:
302-
nodes = set([contractor(nodes, ignore_edge_order=True)])
297+
nodes_dict, edges_dict = tn.copy(self.nodes)
298+
nodes = set(
299+
[contractor(self.nodes, output_edge_order=self.out_edges + self.in_edges)]
300+
)
301+
self._nodes = nodes
302+
return self
303303

304-
return Network._fast_constructor(out_edges, in_edges, nodes)
304+
@property
305+
def nodes(self):
306+
return self._nodes
305307

306308
def to_array(self, contractor=greedy):
307309
"""Returns a 2D array that represents the contraction of the tensor
@@ -323,7 +325,7 @@ def to_array(self, contractor=greedy):
323325
The final tensor representing the operator.
324326
"""
325327
final_edge_order = self.out_edges + self.in_edges
326-
network = self.contract(contractor, final_edge_order=final_edge_order)
328+
network = self.contract(contractor)
327329
nodes = network.nodes
328330
if len(nodes) != 1:
329331
raise ValueError(
@@ -390,30 +392,30 @@ def from_2d_array(cls, array):
390392

391393
if len(shape) == 1:
392394
node = tn.Node(array)
393-
return Network(node[:], [])
395+
return cls(node[:], [])
394396

395397
if len(shape) == 0:
396398
node = tn.Node(array)
397-
return Network([], [], [node])
399+
return cls([], [], [node])
398400

399401
if array.shape[0] == 1 and array.shape[1] != 1:
400402
array = array.reshape(array.shape[1])
401403
node = tn.Node(array)
402-
return Network([], node[:])
404+
return cls([], node[:])
403405

404406
elif array.shape[0] != 1 and array.shape[1] == 1:
405407
array = array.reshape(array.shape[0])
406408
node = tn.Node(array)
407-
return Network(node[:], [])
409+
return cls(node[:], [])
408410

409411
elif array.shape[0] == 1 and array.shape[1] == 1:
410412
array = array.reshape(())
411413
node = tn.Node(array)
412-
return Network([], [], nodes=[node])
414+
return cls([], [], nodes=[node])
413415

414416
else:
415417
node = tn.Node(array)
416-
return Network(node[0:1], node[1:])
418+
return cls(node[0:1], node[1:])
417419

418420
def partial_trace(self, subsystems_to_trace_out):
419421
"""NOT IMPLEMENTED YET.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .tensor_train import *

0 commit comments

Comments
 (0)