@@ -117,7 +117,7 @@ def __init__(self, out_edges, in_edges, nodes=None, copy=True):
117
117
# dynamically searching for them when necessary. This is because
118
118
# searching all nodes in a large graph can be quite expensive while
119
119
# keeping track of them with network operations is straightforward.
120
- self .nodes = (
120
+ self ._nodes = (
121
121
set (nodes ) if nodes else tn .reachable (self .in_edges + self .out_edges )
122
122
)
123
123
@@ -129,7 +129,7 @@ def __init__(self, out_edges, in_edges, nodes=None, copy=True):
129
129
130
130
if copy :
131
131
node_dict , edge_dict = tn .copy (self .nodes )
132
- self .nodes = set (node_dict [n ] for n in self .nodes )
132
+ self ._nodes = set (node_dict [n ] for n in self .nodes )
133
133
self .in_edges = [edge_dict [e ] for e in self .in_edges ]
134
134
self .out_edges = [edge_dict [e ] for e in self .out_edges ]
135
135
@@ -203,7 +203,7 @@ def _fast_constructor(cls, out_edges, in_edges, nodes):
203
203
out = cls .__new__ (cls )
204
204
out .in_edges = in_edges
205
205
out .out_edges = out_edges
206
- out .nodes = nodes
206
+ out ._nodes = nodes
207
207
208
208
return out
209
209
@@ -267,7 +267,7 @@ def adjoint(self):
267
267
268
268
return Network ._fast_constructor (out_edges , in_edges , nodes )
269
269
270
- def contract (self , contractor = greedy , final_edge_order = None ):
270
+ def contract (self , contractor = greedy , copy = True ):
271
271
"""Return the contracted version of the tensor network.
272
272
273
273
Parameters
@@ -277,7 +277,8 @@ def contract(self, contractor=greedy, final_edge_order=None):
277
277
``tensornetwork.contractor.greedy``, which uses the greedy
278
278
algorithm from `opt_einsum` to determine a contraction order.
279
279
280
- final_edge_order: iterable of tensornetwork.Edges
280
+ copy: bool
281
+ Default True. If False, perform the operation in-place.
281
282
282
283
Returns
283
284
-------
@@ -289,19 +290,20 @@ def contract(self, contractor=greedy, final_edge_order=None):
289
290
tensornetwork.contractor: This module contains other functions that
290
291
can be used instead of ``greedy``.
291
292
"""
292
- nodes_dict , edges_dict = tn .copy (self .nodes )
293
-
294
- in_edges = [edges_dict [e ] for e in self .in_edges ]
295
- out_edges = [edges_dict [e ] for e in self .out_edges ]
296
- nodes = set (nodes_dict [n ] for n in self .nodes if n in nodes_dict )
293
+ if copy :
294
+ out = self .copy ()
295
+ return out .contract (contractor , copy = False )
297
296
298
- if final_edge_order is not None :
299
- final_edge_order = [edges_dict [e ] for e in final_edge_order ]
300
- nodes = set ([contractor (nodes , output_edge_order = final_edge_order )])
301
- else :
302
- nodes = set ([contractor (nodes , ignore_edge_order = True )])
297
+ nodes_dict , edges_dict = tn .copy (self .nodes )
298
+ nodes = set (
299
+ [contractor (self .nodes , output_edge_order = self .out_edges + self .in_edges )]
300
+ )
301
+ self ._nodes = nodes
302
+ return self
303
303
304
- return Network ._fast_constructor (out_edges , in_edges , nodes )
304
+ @property
305
+ def nodes (self ):
306
+ return self ._nodes
305
307
306
308
def to_array (self , contractor = greedy ):
307
309
"""Returns a 2D array that represents the contraction of the tensor
@@ -323,7 +325,7 @@ def to_array(self, contractor=greedy):
323
325
The final tensor representing the operator.
324
326
"""
325
327
final_edge_order = self .out_edges + self .in_edges
326
- network = self .contract (contractor , final_edge_order = final_edge_order )
328
+ network = self .contract (contractor )
327
329
nodes = network .nodes
328
330
if len (nodes ) != 1 :
329
331
raise ValueError (
@@ -390,30 +392,30 @@ def from_2d_array(cls, array):
390
392
391
393
if len (shape ) == 1 :
392
394
node = tn .Node (array )
393
- return Network (node [:], [])
395
+ return cls (node [:], [])
394
396
395
397
if len (shape ) == 0 :
396
398
node = tn .Node (array )
397
- return Network ([], [], [node ])
399
+ return cls ([], [], [node ])
398
400
399
401
if array .shape [0 ] == 1 and array .shape [1 ] != 1 :
400
402
array = array .reshape (array .shape [1 ])
401
403
node = tn .Node (array )
402
- return Network ([], node [:])
404
+ return cls ([], node [:])
403
405
404
406
elif array .shape [0 ] != 1 and array .shape [1 ] == 1 :
405
407
array = array .reshape (array .shape [0 ])
406
408
node = tn .Node (array )
407
- return Network (node [:], [])
409
+ return cls (node [:], [])
408
410
409
411
elif array .shape [0 ] == 1 and array .shape [1 ] == 1 :
410
412
array = array .reshape (())
411
413
node = tn .Node (array )
412
- return Network ([], [], nodes = [node ])
414
+ return cls ([], [], nodes = [node ])
413
415
414
416
else :
415
417
node = tn .Node (array )
416
- return Network (node [0 :1 ], node [1 :])
418
+ return cls (node [0 :1 ], node [1 :])
417
419
418
420
def partial_trace (self , subsystems_to_trace_out ):
419
421
"""NOT IMPLEMENTED YET.
0 commit comments