
Commit 69f28e5

TN equalize_norms add check_zero option and turn on for simplifying

1 parent 1dc14d5 commit 69f28e5

1 file changed: quimb/tensor/tensor_core.py (+72, -17 lines)
@@ -9011,7 +9011,8 @@ def insert_compressor_between_regions(
 
         # then form the 'oblique' projectors
         Pl, Pr = decomp.compute_oblique_projectors(
-            Rl, Rr,
+            Rl,
+            Rr,
             max_bond=max_bond,
             cutoff=cutoff,
             **compress_opts,
@@ -9389,7 +9390,7 @@ def randomize(self, dtype=None, seed=None, inplace=False, **randn_opts):
 
     randomize_ = functools.partialmethod(randomize, inplace=True)
 
-    def strip_exponent(self, tid_or_tensor, value=None):
+    def strip_exponent(self, tid_or_tensor, value=None, check_zero=False):
         """Scale the elements of tensor corresponding to ``tid`` so that the
         norm of the array is some value, which defaults to ``1``. The log of
         the scaling factor, base 10, is then accumulated in the ``exponent``
@@ -9401,6 +9402,11 @@ def strip_exponent(self, tid_or_tensor, value=None):
             The tensor identifier or actual tensor.
         value : None or float, optional
             The value to scale the norm of the tensor to.
+        check_zero : bool, optional
+            Whether to check if the tensor has zero norm and in that case do
+            nothing, since the `exponent` would be -inf. Off by default to
+            avoid data dependent computational graphs when tracing and
+            computing gradients etc.
         """
         if (value is None) or (value is True):
             value = 1.0
@@ -9411,6 +9417,10 @@ def strip_exponent(self, tid_or_tensor, value=None):
             t = self.tensor_map[tid_or_tensor]
 
         stripped_factor = t.norm() / value
+
+        if check_zero and (stripped_factor == 0.0):
+            return
+
         t.modify(apply=lambda data: data / stripped_factor)
         self.exponent = self.exponent + do("log10", stripped_factor)
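
In practice the new flag simply short-circuits the scaling when a tensor's norm is exactly zero. A minimal sketch of the behaviour, using a throwaway two-tensor network built purely for illustration:

import numpy as np
import quimb.tensor as qtn

ta = qtn.Tensor(np.random.rand(2, 2), inds=('a', 'b'))
tb = qtn.Tensor(np.zeros((2, 2)), inds=('b', 'c'))  # zero norm
tn = ta & tb

# without check_zero, the zero tensor is divided by 0.0, producing
# nan data and exponent -inf; with it, that tensor is left untouched
for tid in list(tn.tensor_map):
    tn.strip_exponent(tid, value=1.0, check_zero=True)

print(tn.exponent)  # log10 of the nonzero tensor's norm only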

@@ -9425,7 +9435,7 @@ def distribute_exponent(self):
         # reset the exponent to zero
         self.exponent = 0.0
 
-    def equalize_norms(self, value=None, inplace=False):
+    def equalize_norms(self, value=None, check_zero=False, inplace=False):
         """Make the Frobenius norm of every tensor in this TN equal without
         changing the overall value if ``value=None``, or set the norm of every
         tensor to ``value`` by scalar multiplication only.
@@ -9436,6 +9446,11 @@ def equalize_norms(self, value=None, inplace=False):
             Set the norm of each tensor to this value specifically. If supplied
             the change in overall scaling will be accumulated in
             ``tn.exponent`` in the form of a base 10 power.
+        check_zero : bool, optional
+            Whether, if and when equalizing norms, to check if tensors have
+            zero norm and in that case do nothing, since the `exponent` would
+            be -inf. Off by default to avoid data dependent computational
+            graphs when tracing and computing gradients etc.
         inplace : bool, optional
             Whether to perform the norm equalization inplace or not.
@@ -9446,7 +9461,7 @@ def equalize_norms(self, value=None, inplace=False):
         tn = self if inplace else self.copy()
 
         for tid in tn.tensor_map:
-            tn.strip_exponent(tid, value=value)
+            tn.strip_exponent(tid, value=value, check_zero=check_zero)
 
         if value is None:
             tn.distribute_exponent()
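
With the flag threaded through, equalize_norms can now be applied safely to networks that might contain exactly-zero tensors. A short usage sketch:

# redistribute norms evenly without changing the overall value
tn2 = tn.equalize_norms(check_zero=True)

# or set every tensor's norm to exactly 1.0, with the total scaling
# accumulated as a base-10 power in tn2.exponent
tn2 = tn.equalize_norms(value=1.0, check_zero=True)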
@@ -9591,6 +9606,7 @@ def rank_simplify(
         equalize_norms=False,
         cache=None,
         max_combinations=500,
+        check_zero=False,
         inplace=False,
     ):
         """Simplify this tensor network by performing contractions that don't
@@ -9607,6 +9623,11 @@ def rank_simplify(
             exponent in ``tn.exponent``.
         cache : None or set
             Persistent cache used to mark already checked tensors.
+        check_zero : bool, optional
+            Whether, if and when equalizing norms, to check if tensors have
+            zero norm and in that case do nothing, since the `exponent` would
+            be -inf. Off by default to avoid data dependent computational
+            graphs when tracing and computing gradients etc.
         inplace : bool, optional
             Whether to perform the rank reduction inplace.
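
The "data dependent computational graphs" caveat repeated in these docstrings is the reason the flag defaults to False at this level: under a tracer the branch condition is not a concrete value. A minimal illustration of the failure mode, assuming jax is installed (the function here is hypothetical, not part of quimb):

import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    nrm = jnp.linalg.norm(x)
    if nrm == 0.0:  # branching on a traced value raises an error at
        return x    # trace time, regardless of the actual data
    return x / nrm

# normalize(jnp.ones(3))  # raises a tracer-concretization error under jit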
@@ -9752,18 +9773,24 @@ def rank_weight(ind):
             tn |= tab
 
             if equalize_norms:
-                tn.strip_exponent(tab, equalize_norms)
+                tn.strip_exponent(tab, equalize_norms, check_zero=check_zero)
 
             for ix in out_ab:
                 # now we need to check output indices again
                 queue.add(ix)
 
         if scalars:
             if equalize_norms:
+                # move overall scaling factor into exponent, absorb phase
                 signs = []
                 for s in scalars:
-                    signs.append(s / do("abs", s))
-                    tn.exponent += do("log10", do("abs", s))
+                    sa = do("abs", s)
+                    if check_zero and (sa == 0.0):
+                        # whole contraction is zero
+                        signs = [0.0]
+                        break
+                    signs.append(s / sa)
+                    tn.exponent += do("log10", sa)
                 scalars = signs
 
         if tn.num_tensors:
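
The scalar-handling hunk above is easier to read in isolation. A self-contained restatement of its logic (absorb_scalars is a hypothetical name, not part of quimb; do is autoray's backend-agnostic dispatcher, which quimb uses throughout):

from autoray import do

def absorb_scalars(scalars, exponent, check_zero=False):
    # keep only each scalar's phase/sign as data, moving its magnitude
    # into a cumulative base-10 exponent
    signs = []
    for s in scalars:
        sa = do("abs", s)
        if check_zero and (sa == 0.0):
            # one zero factor makes the whole contraction zero
            return [0.0], exponent
        signs.append(s / sa)
        exponent = exponent + do("log10", sa)
    return signs, exponent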
@@ -10023,6 +10050,7 @@ def split_simplify(
         atol=1e-12,
         equalize_norms=False,
         cache=None,
+        check_zero=False,
         inplace=False,
         **split_opts,
     ):
@@ -10039,6 +10067,11 @@ def split_simplify(
             exponent in ``tn.exponent``.
         cache : None or set
             Persistent cache used to mark already checked tensors.
+        check_zero : bool, optional
+            Whether, if and when equalizing norms, to check if tensors have
+            zero norm and in that case do nothing, since the `exponent` would
+            be -inf. Off by default to avoid data dependent computational
+            graphs when tracing and computing gradients etc.
         inplace : bool, optional
             Whether to perform the split simplification inplace.
         """
@@ -10075,8 +10108,12 @@ def split_simplify(
                 tn |= tr
 
                 if equalize_norms:
-                    tn.strip_exponent(tl, equalize_norms)
-                    tn.strip_exponent(tr, equalize_norms)
+                    tn.strip_exponent(
+                        tl, equalize_norms, check_zero=check_zero
+                    )
+                    tn.strip_exponent(
+                        tr, equalize_norms, check_zero=check_zero
+                    )
 
             else:
                 cache.add(cache_key)
@@ -10093,6 +10130,7 @@ def pair_simplify(
         cache=None,
         equalize_norms=False,
         max_combinations=500,
+        check_zero=False,
         inplace=False,
         **split_opts,
     ):
@@ -10180,8 +10218,8 @@ def gen_pairs():
 
             tensor_fuse_squeeze(tl, tr)
             if equalize_norms:
-                tn.strip_exponent(tl, equalize_norms)
-                tn.strip_exponent(tr, equalize_norms)
+                tn.strip_exponent(tl, equalize_norms, check_zero=check_zero)
+                tn.strip_exponent(tr, equalize_norms, check_zero=check_zero)
 
             queue.extend(tl.inds)
             queue.extend(tr.inds)
@@ -10199,6 +10237,7 @@ def loop_simplify(
         loops=None,
         cache=None,
         equalize_norms=False,
+        check_zero=False,
         inplace=False,
         **split_opts,
     ):
@@ -10218,6 +10257,11 @@ def loop_simplify(
         cache : set, optional
            For performance reasons can supply a cache for already checked
            loops.
+        check_zero : bool, optional
+            Whether, if and when equalizing norms, to check if tensors have
+            zero norm and in that case do nothing, since the `exponent` would
+            be -inf. Off by default to avoid data dependent computational
+            graphs when tracing and computing gradients etc.
         inplace : bool, optional
             Whether to replace the loops inplace.
         split_opts
@@ -10298,8 +10342,8 @@ def loop_simplify(
 
             tensor_fuse_squeeze(tl, tr)
             if equalize_norms:
-                tn.strip_exponent(tl, equalize_norms)
-                tn.strip_exponent(tr, equalize_norms)
+                tn.strip_exponent(tl, equalize_norms, check_zero=check_zero)
+                tn.strip_exponent(tr, equalize_norms, check_zero=check_zero)
 
         return tn

@@ -10312,13 +10356,14 @@ def full_simplify(
         atol=1e-12,
         equalize_norms=False,
         cache=None,
-        inplace=False,
-        progbar=False,
         rank_simplify_opts=None,
         loop_simplify_opts=None,
         split_simplify_opts=None,
         custom_methods=(),
         split_method="svd",
+        check_zero=True,
+        inplace=False,
+        progbar=False,
     ):
         """Perform a series of tensor network 'simplifications' in a loop until
         there is no more reduction in the number of tensors or indices. Note
@@ -10357,6 +10402,9 @@ def full_simplify(
         cache : None or set
             A persistent cache for each simplification process to mark
             already processed tensors.
+        check_zero : bool, optional
+            Whether to check if tensors have zero norm and in that case do
+            nothing if and when equalizing norms, rather than generating a NaN.
         progbar : bool, optional
             Show a live progress bar of the simplification process.
         inplace : bool, optional
@@ -10422,6 +10470,7 @@ def full_simplify(
                     output_inds=ix_o,
                     cache=cache,
                     equalize_norms=equalize_norms,
+                    check_zero=check_zero,
                     **rank_simplify_opts,
                 )
             elif meth == "A":
@@ -10435,6 +10484,7 @@ def full_simplify(
                     atol=atol,
                     cache=cache,
                     equalize_norms=equalize_norms,
+                    check_zero=check_zero,
                     **split_simplify_opts,
                 )
             elif meth == "L":
@@ -10443,6 +10493,7 @@ def full_simplify(
                     cutoff=atol,
                     cache=cache,
                     equalize_norms=equalize_norms,
+                    check_zero=check_zero,
                     **loop_simplify_opts,
                 )
             elif meth == "P":
@@ -10451,6 +10502,7 @@ def full_simplify(
                     cutoff=atol,
                     cache=cache,
                     equalize_norms=equalize_norms,
+                    check_zero=check_zero,
                     **loop_simplify_opts,
                 )
             else:
@@ -10462,9 +10514,10 @@ def full_simplify(
         if equalize_norms:
             if equalize_norms is True:
                 # this also redistributes the collected exponents
-                tn.equalize_norms_()
+                value = None
             else:
-                tn.equalize_norms_(value=equalize_norms)
+                value = equalize_norms
+            tn.equalize_norms_(value=value, check_zero=check_zero)
 
         if progbar:
             pbar.close()
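
Unlike the lower-level methods, full_simplify turns the check on by default (check_zero=True), trading traceability for robustness in the common interactive case. For example:

# equalize-norm steps inside the loop now skip zero-norm tensors
# rather than propagating nan / -inf through tn.exponent
tns = tn.full_simplify(equalize_norms=True)

# pass check_zero=False to restore trace/grad-friendly behaviour
tns = tn.full_simplify(equalize_norms=True, check_zero=False)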
@@ -10594,6 +10647,7 @@ def compress_simplify(
         max_simplification_iterations=100,
         converged_tol=0.01,
         equalize_norms=True,
+        check_zero=True,
         progbar=False,
         inplace=False,
         **full_simplify_opts,
@@ -10606,6 +10660,7 @@ def compress_simplify(
         simplify_opts = {
             "atol": atol,
             "equalize_norms": equalize_norms,
+            "check_zero": check_zero,
             "progbar": progbar,
             "output_inds": output_inds,
             "cache": set(),
