Skip to content

Commit 78516d8

Browse files
jjrenliwt31
authored andcommitted
optimize mpo construction
1 parent b80d923 commit 78516d8

File tree

3 files changed

+40
-61
lines changed

3 files changed

+40
-61
lines changed

renormalizer/mps/mpo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,12 +271,12 @@ def __init__(self, model: Model = None, terms: Union[Op, List[Op]] = None, offse
271271
if len(terms) == 0:
272272
raise ValueError("Terms all have factor 0.")
273273

274-
table, factor = _terms_to_table(model, terms, -self.offset)
274+
table, primary_ops, factor = _terms_to_table(model, terms, -self.offset)
275275

276276
self.dtype = factor.dtype
277277

278278
self.symbolic_mpo, self.qn, self.qntot, self.qnidx, self.symbolic_out_ops_list, self.primary_ops \
279-
= construct_symbolic_mpo(table, factor, algo=algo)
279+
= construct_symbolic_mpo(table, primary_ops, factor, algo=algo)
280280
# from renormalizer.mps.symbolic_mpo import _format_symbolic_mpo
281281
# print(_format_symbolic_mpo(mpo_symbol))
282282
self.model = model

renormalizer/mps/symbolic_mpo.py

Lines changed: 36 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
OpTuple = namedtuple("OpTuple", ["symbol", "qn", "factor"])
2020

2121

22-
def construct_symbolic_mpo(table, factor, algo="Hopcroft-Karp"):
22+
def construct_symbolic_mpo(table, primary_ops, factor, algo="Hopcroft-Karp"):
2323
r"""
2424
A General Compact (Symbolic) MPO Construction Routine
2525
@@ -102,17 +102,18 @@ def construct_symbolic_mpo(table, factor, algo="Hopcroft-Karp"):
102102
The local mpo is the transformation matrix between 0'',1'' to 0'''
103103
"""
104104

105-
qn_size = len(table[0][0].qn)
105+
qn_size = len(primary_ops[0].qn)
106+
106107
# Simplest case. Cut to the chase
107-
if len(table) == 1:
108+
if table.shape[0] == 1:
108109
# The first layer: number of sites. The middle array: in and out virtual bond
109110
# the 4th layer: operator sums
110111
mpo: List[np.ndarray[List[Op]]] = []
111112
mpoqn = [np.zeros((1, qn_size), dtype=int)]
112-
primary_ops = list(set(table[0]))
113113
op2idx = dict(zip(primary_ops, range(len(primary_ops))))
114114
out_ops_list: List[List[OpTuple]] = [[OpTuple([0], qn=0, factor=1)]]
115-
for op in table[0]:
115+
for idx in table[0]:
116+
op = primary_ops[idx]
116117
mo = np.full((1, 1), None)
117118
mo[0][0] = [op]
118119
mpo.append(mo)
@@ -130,14 +131,10 @@ def construct_symbolic_mpo(table, factor, algo="Hopcroft-Karp"):
130131
return mpo, mpoqn, qntot, qnidx, out_ops_list, primary_ops
131132

132133
logger.debug(f"symbolic mpo algorithm: {algo}")
133-
logger.debug(f"Input operator terms: {len(table)}")
134-
135-
table, factor, primary_ops = _transform_table(table, factor)
136134

137135
# add the first and last column for convenience
138136
ta = np.zeros((table.shape[0], 1), dtype=np.uint16)
139137
table = np.concatenate((ta, table, ta), axis=1)
140-
logger.debug(f"After combination of the same terms: {table.shape[0]}")
141138

142139
# 0 represents the identity symbol. Identity might not present
143140
# in `primary_ops` but the algorithm still works.
@@ -364,22 +361,38 @@ def _terms_to_table(model: Model, terms: List[Op], const: float):
364361

365362
table = []
366363
factor_list = []
367-
364+
365+
primary_ops_eachsite = []
366+
primary_ops = []
367+
368+
index = 0
369+
368370
dummy_table_entry = []
369371
for b in model.basis:
370372
if b.multi_dof:
371373
dof = b.dof[0]
372374
else:
373375
dof = b.dof
374376
op = Op.identity(dof, qn_size=model.qn_size)
375-
dummy_table_entry.append(op)
377+
378+
primary_ops_eachsite.append({op:index})
379+
primary_ops.append(op)
380+
dummy_table_entry.append(index)
381+
index += 1
382+
376383
for op in terms:
377384
elem_ops, factor = op.split_elementary(model.dof_to_siteidx)
378385
table_entry = dummy_table_entry.copy()
386+
379387
for elem_op in elem_ops:
380388
# it is ensured in `elem_op` every symbol is on the same site
381389
site_idx = model.dof_to_siteidx[elem_op.dofs[0]]
382-
table_entry[site_idx] = elem_op
390+
if elem_op not in primary_ops_eachsite[site_idx].keys():
391+
primary_ops_eachsite[site_idx][elem_op] = index
392+
primary_ops.append(elem_op)
393+
index += 1
394+
table_entry[site_idx] = primary_ops_eachsite[site_idx][elem_op]
395+
383396
table.append(table_entry)
384397
factor_list.append(factor)
385398

@@ -389,10 +402,19 @@ def _terms_to_table(model: Model, terms: List[Op], const: float):
389402
factor_list.append(const)
390403
table.append(table_entry)
391404

392-
factor_list = np.array(factor_list)
405+
factor = np.array(factor_list)
393406
logger.debug(f"# of operator terms: {len(table)}")
407+
408+
# use np.uint32, np.uint16 to save memory
409+
max_uint16 = np.iinfo(np.uint16).max
410+
assert len(primary_ops) < max_uint16
411+
412+
table = np.array(table, dtype=np.uint16)
413+
logger.debug(f"Input operator terms: {table.shape[0]}")
414+
table, factor = _deduplicate_table(table, factor)
415+
logger.debug(f"After combination of the same terms: {table.shape[0]}")
394416

395-
return table, factor_list
417+
return table, primary_ops, factor
396418

397419

398420
def _deduplicate_table(table, factor):
@@ -416,48 +438,6 @@ def _deduplicate_table(table, factor):
416438
return new_table, factor
417439

418440

419-
def _transform_table(table, factor):
420-
"""Transforms the table to integer table and combine duplicate terms."""
421-
422-
# use np.uint32, np.uint16 to save memory
423-
max_uint16 = np.iinfo(np.uint16).max
424-
425-
# translate the symbolic operator table to an easy to manipulate numpy array
426-
table = np.array(table)
427-
# unique operators with DoF names taken into consideration
428-
# The inclusion of DoF names is necessary for multi-dof basis.
429-
unique_op = OrderedDict.fromkeys(table.ravel())
430-
# Convert Set(table.ravel()) to List will change the Op order in list, OrderedDict made reproducible
431-
unique_op = list(unique_op.keys())
432-
433-
# check the index of different operators could be represented with np.uint16
434-
assert len(unique_op) < max_uint16
435-
436-
# Construct mapping from easy-to-manipulate integer to actual Op
437-
primary_ops = list(unique_op)
438-
439-
op2idx = dict(zip(unique_op, range(len(unique_op))))
440-
new_table = np.vectorize(op2idx.get)(table).astype(np.uint16)
441-
442-
del unique_op
443-
444-
if __debug__:
445-
qn_size = len(table[0][0].qn)
446-
qn_table = np.array([[x.qn for x in ta] for ta in table])
447-
factor_table = np.array([[x.factor for x in ta] for ta in table])
448-
for idx in range(len(primary_ops)):
449-
coord = np.nonzero(new_table == idx)
450-
# check that op with the same symbol has the same factor and qn
451-
for j in range(qn_size):
452-
assert np.unique(qn_table[:, :, j][coord]).size == 1
453-
assert np.all(factor_table[coord] == factor_table[coord][0])
454-
455-
del factor_table, qn_table
456-
457-
table, factor = _deduplicate_table(new_table, factor)
458-
459-
return table, factor, primary_ops
460-
461441

462442
# translate the numbers into symbolic Matrix Operator
463443
def compose_symbolic_mo(in_ops, out_ops, primary_ops):

renormalizer/tn/symbolic_ttno.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from renormalizer import Op, Model
77
from renormalizer.model.basis import BasisSet
88
from renormalizer.tn.treebase import BasisTree
9-
from renormalizer.mps.symbolic_mpo import _terms_to_table, _transform_table, _construct_symbolic_mpo_one_site, OpTuple
9+
from renormalizer.mps.symbolic_mpo import _terms_to_table, _construct_symbolic_mpo_one_site, OpTuple
1010

1111

1212
logger = logging.getLogger(__name__)
@@ -57,8 +57,7 @@ def construct_symbolic_ttno(tn: BasisTree, terms: List[Op], const: float = 0, al
5757
basis = list(chain(*[n.basis_sets for n in nodes]))
5858
model = Model(basis, [])
5959
qn_size = model.qn_size
60-
table, factor = _terms_to_table(model, terms, const)
61-
table, factor, primary_ops = _transform_table(table, factor)
60+
table, primary_ops, factor = _terms_to_table(model, terms, const)
6261

6362
dummy_in_ops = [[OpTuple([0], qn=np.zeros(qn_size, dtype=int), factor=1)]]
6463
out_ops: List[List[OpTuple]]

0 commit comments

Comments
 (0)