@@ -9011,7 +9011,8 @@ def insert_compressor_between_regions(

         # then form the 'oblique' projectors
         Pl, Pr = decomp.compute_oblique_projectors(
-            Rl, Rr,
+            Rl,
+            Rr,
             max_bond=max_bond,
             cutoff=cutoff,
             **compress_opts,
@@ -9389,7 +9390,7 @@ def randomize(self, dtype=None, seed=None, inplace=False, **randn_opts):

     randomize_ = functools.partialmethod(randomize, inplace=True)

-    def strip_exponent(self, tid_or_tensor, value=None):
+    def strip_exponent(self, tid_or_tensor, value=None, check_zero=False):
         """Scale the elements of tensor corresponding to ``tid`` so that the
         norm of the array is some value, which defaults to ``1``. The log of
         the scaling factor, base 10, is then accumulated in the ``exponent``
@@ -9401,6 +9402,11 @@ def strip_exponent(self, tid_or_tensor, value=None):
             The tensor identifier or actual tensor.
         value : None or float, optional
             The value to scale the norm of the tensor to.
+        check_zero : bool, optional
+            Whether to check if the tensor has zero norm and in that case do
+            nothing, since the `exponent` would be -inf. Off by default to
+            avoid data dependent computational graphs when tracing and
+            computing gradients etc.
         """
         if (value is None) or (value is True):
             value = 1.0
@@ -9411,6 +9417,10 @@ def strip_exponent(self, tid_or_tensor, value=None):
             t = self.tensor_map[tid_or_tensor]

         stripped_factor = t.norm() / value
+
+        if check_zero and (stripped_factor == 0.0):
+            return
+
         t.modify(apply=lambda data: data / stripped_factor)
         self.exponent = self.exponent + do("log10", stripped_factor)

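A minimal sketch of what the new `check_zero` guard in `strip_exponent` avoids; this is not part of the diff, the two-tensor network below is illustrative and assumes quimb's public `Tensor` / `TensorNetwork` API:

    import numpy as np
    import quimb.tensor as qtn

    # illustrative network: one large-magnitude tensor, one exactly-zero tensor
    tn = qtn.TensorNetwork([
        qtn.Tensor(1e3 * np.ones((2, 2)), inds=("a", "b")),
        qtn.Tensor(np.zeros((2, 2)), inds=("b", "c")),
    ])

    # scale each tensor to unit norm, accumulating log10 of the factors
    for tid in tuple(tn.tensor_map):
        tn.strip_exponent(tid, value=1.0, check_zero=True)

    # the zero tensor is skipped, so the exponent stays finite; without
    # check_zero its data would be divided by 0.0 (nan) and it would
    # contribute log10(0.0) = -inf to tn.exponent
    print(tn.exponent)  # ~3.3, i.e. log10 of the first tensor's norm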
@@ -9425,7 +9435,7 @@ def distribute_exponent(self):
         # reset the exponent to zero
         self.exponent = 0.0

-    def equalize_norms(self, value=None, inplace=False):
+    def equalize_norms(self, value=None, check_zero=False, inplace=False):
         """Make the Frobenius norm of every tensor in this TN equal without
         changing the overall value if ``value=None``, or set the norm of every
         tensor to ``value`` by scalar multiplication only.
@@ -9436,6 +9446,11 @@ def equalize_norms(self, value=None, inplace=False):
             Set the norm of each tensor to this value specifically. If supplied
             the change in overall scaling will be accumulated in
             ``tn.exponent`` in the form of a base 10 power.
+        check_zero : bool, optional
+            Whether, if and when equalizing norms, to check if tensors have
+            zero norm and in that case do nothing, since the `exponent` would
+            be -inf. Off by default to avoid data dependent computational
+            graphs when tracing and computing gradients etc.
         inplace : bool, optional
             Whether to perform the norm equalization inplace or not.

@@ -9446,7 +9461,7 @@ def equalize_norms(self, value=None, inplace=False):
         tn = self if inplace else self.copy()

         for tid in tn.tensor_map:
-            tn.strip_exponent(tid, value=value)
+            tn.strip_exponent(tid, value=value, check_zero=check_zero)

         if value is None:
             tn.distribute_exponent()
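A short usage sketch (again illustrative, not from the diff) of the network-level wrapper with ``value=None``, which balances all norms while preserving the overall value by redistributing the collected exponent:

    import numpy as np
    import quimb.tensor as qtn

    # two tensors with wildly different scales but an O(1) contracted value
    tn = qtn.TensorNetwork([
        qtn.Tensor(1e-8 * np.random.randn(2, 2), inds=("a", "b")),
        qtn.Tensor(1e+8 * np.random.randn(2, 2), inds=("b", "a")),
    ])
    before = tn.contract()

    # value=None: strip each norm into tn.exponent, then redistribute it
    # evenly over the tensors; check_zero=True is harmless here and only
    # matters if a tensor is exactly zero
    tn.equalize_norms_(check_zero=True)
    print(np.allclose(before, tn.contract()))  # True, and all norms now equal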
@@ -9591,6 +9606,7 @@ def rank_simplify(
         equalize_norms=False,
         cache=None,
         max_combinations=500,
+        check_zero=False,
         inplace=False,
     ):
         """Simplify this tensor network by performing contractions that don't
@@ -9607,6 +9623,11 @@ def rank_simplify(
             exponent in ``tn.exponent``.
         cache : None or set
             Persistent cache used to mark already checked tensors.
+        check_zero : bool, optional
+            Whether, if and when equalizing norms, to check if tensors have
+            zero norm and in that case do nothing, since the `exponent` would
+            be -inf. Off by default to avoid data dependent computational
+            graphs when tracing and computing gradients etc.
         inplace : bool, optional
             Whether to perform the rand reduction inplace.

@@ -9752,18 +9773,24 @@ def rank_weight(ind):
             tn |= tab

             if equalize_norms:
-                tn.strip_exponent(tab, equalize_norms)
+                tn.strip_exponent(tab, equalize_norms, check_zero=check_zero)

             for ix in out_ab:
                 # now we need to check outputs indices again
                 queue.add(ix)

         if scalars:
             if equalize_norms:
+                # move overall scaling factor into exponent, absorb phase
                 signs = []
                 for s in scalars:
-                    signs.append(s / do("abs", s))
-                    tn.exponent += do("log10", do("abs", s))
+                    sa = do("abs", s)
+                    if check_zero and (sa == 0.0):
+                        # whole contraction is zero
+                        signs = [0.0]
+                        break
+                    signs.append(s / sa)
+                    tn.exponent += do("log10", sa)
                 scalars = signs

         if tn.num_tensors:
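The new scalar handling above splits each collected scalar into a phase (kept in ``signs``) and a magnitude (absorbed into the exponent), short-circuiting if any scalar is exactly zero. A standalone sketch of just that logic in plain Python, with illustrative names:

    import math

    def absorb_scalars(scalars, check_zero=False):
        # mirror of the loop in the diff: collect phases, accumulate log10
        # of magnitudes, optionally bail out early if the product is zero
        signs = []
        exponent = 0.0
        for s in scalars:
            sa = abs(s)
            if check_zero and sa == 0.0:
                # whole contraction is zero
                return [0.0], exponent
            signs.append(s / sa)
            exponent += math.log10(sa)
        return signs, exponent

    print(absorb_scalars([-2.0, 5.0]))       # ([-1.0, 1.0], 1.0), log10(2 * 5) = 1
    print(absorb_scalars([3.0, 0.0], True))  # ([0.0], log10(3)), zero result flagged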
@@ -10023,6 +10050,7 @@ def split_simplify(
         atol=1e-12,
         equalize_norms=False,
         cache=None,
+        check_zero=False,
         inplace=False,
         **split_opts,
     ):
@@ -10039,6 +10067,11 @@ def split_simplify(
             exponent in ``tn.exponent``.
         cache : None or set
             Persistent cache used to mark already checked tensors.
+        check_zero : bool, optional
+            Whether, if and when equalizing norms, to check if tensors have
+            zero norm and in that case do nothing, since the `exponent` would
+            be -inf. Off by default to avoid data dependent computational
+            graphs when tracing and computing gradients etc.
         inplace, bool, optional
             Whether to perform the split simplification inplace.
         """
@@ -10075,8 +10108,12 @@ def split_simplify(
                     tn |= tr

                     if equalize_norms:
-                        tn.strip_exponent(tl, equalize_norms)
-                        tn.strip_exponent(tr, equalize_norms)
+                        tn.strip_exponent(
+                            tl, equalize_norms, check_zero=check_zero
+                        )
+                        tn.strip_exponent(
+                            tr, equalize_norms, check_zero=check_zero
+                        )

                 else:
                     cache.add(cache_key)
@@ -10093,6 +10130,7 @@ def pair_simplify(
         cache=None,
         equalize_norms=False,
         max_combinations=500,
+        check_zero=False,
         inplace=False,
         **split_opts,
     ):
@@ -10180,8 +10218,8 @@ def gen_pairs():

             tensor_fuse_squeeze(tl, tr)
             if equalize_norms:
-                tn.strip_exponent(tl, equalize_norms)
-                tn.strip_exponent(tr, equalize_norms)
+                tn.strip_exponent(tl, equalize_norms, check_zero=check_zero)
+                tn.strip_exponent(tr, equalize_norms, check_zero=check_zero)

             queue.extend(tl.inds)
             queue.extend(tr.inds)
@@ -10199,6 +10237,7 @@ def loop_simplify(
         loops=None,
         cache=None,
         equalize_norms=False,
+        check_zero=False,
         inplace=False,
         **split_opts,
     ):
@@ -10218,6 +10257,11 @@ def loop_simplify(
         cache : set, optional
             For performance reasons can supply a cache for already checked
             loops.
+        check_zero : bool, optional
+            Whether, if and when equalizing norms, to check if tensors have
+            zero norm and in that case do nothing, since the `exponent` would
+            be -inf. Off by default to avoid data dependent computational
+            graphs when tracing and computing gradients etc.
         inplace : bool, optional
             Whether to replace the loops inplace.
         split_opts
@@ -10298,8 +10342,8 @@ def loop_simplify(

             tensor_fuse_squeeze(tl, tr)
             if equalize_norms:
-                tn.strip_exponent(tl, equalize_norms)
-                tn.strip_exponent(tr, equalize_norms)
+                tn.strip_exponent(tl, equalize_norms, check_zero=check_zero)
+                tn.strip_exponent(tr, equalize_norms, check_zero=check_zero)

         return tn

@@ -10312,13 +10356,14 @@ def full_simplify(
         atol=1e-12,
         equalize_norms=False,
         cache=None,
-        inplace=False,
-        progbar=False,
         rank_simplify_opts=None,
         loop_simplify_opts=None,
         split_simplify_opts=None,
         custom_methods=(),
         split_method="svd",
+        check_zero=True,
+        inplace=False,
+        progbar=False,
     ):
         """Perform a series of tensor network 'simplifications' in a loop until
         there is no more reduction in the number of tensors or indices. Note
@@ -10357,6 +10402,9 @@ def full_simplify(
         cache : None or set
             A persistent cache for each simplification process to mark
             already processed tensors.
+        check_zero : bool, optional
+            Whether to check if tensors have zero norm and in that case do
+            nothing if and when equalizing norms, rather than generating a NaN.
         progbar : bool, optional
             Show a live progress bar of the simplification process.
         inplace : bool, optional
@@ -10422,6 +10470,7 @@ def full_simplify(
                     output_inds=ix_o,
                     cache=cache,
                     equalize_norms=equalize_norms,
+                    check_zero=check_zero,
                     **rank_simplify_opts,
                 )
             elif meth == "A":
@@ -10435,6 +10484,7 @@ def full_simplify(
                     atol=atol,
                     cache=cache,
                     equalize_norms=equalize_norms,
+                    check_zero=check_zero,
                     **split_simplify_opts,
                 )
             elif meth == "L":
@@ -10443,6 +10493,7 @@ def full_simplify(
                     cutoff=atol,
                     cache=cache,
                     equalize_norms=equalize_norms,
+                    check_zero=check_zero,
                     **loop_simplify_opts,
                 )
             elif meth == "P":
@@ -10451,6 +10502,7 @@ def full_simplify(
                     cutoff=atol,
                     cache=cache,
                     equalize_norms=equalize_norms,
+                    check_zero=check_zero,
                     **loop_simplify_opts,
                 )
             else:
@@ -10462,9 +10514,10 @@ def full_simplify(
         if equalize_norms:
             if equalize_norms is True:
                 # this also redistributes the collected exponents
-                tn.equalize_norms_()
+                value = None
             else:
-                tn.equalize_norms_(value=equalize_norms)
+                value = equalize_norms
+            tn.equalize_norms_(value=value, check_zero=check_zero)

         if progbar:
             pbar.close()
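End to end, the flag threads through the user-facing simplifiers, where it now defaults to ``True``. A hypothetical usage sketch, not from the diff; the random MPS and the ``equalize_norms=1.0`` choice are illustrative:

    import quimb.tensor as qtn

    # a closed norm network <psi|psi> for a random MPS
    psi = qtn.MPS_rand_state(10, bond_dim=4)
    norm_tn = psi.H & psi

    # equalize_norms=1.0 strips every tensor to unit norm as it simplifies,
    # collecting the overall scale in .exponent; check_zero=True guards the
    # stripping against exactly-zero tensors (nan data / -inf exponent)
    simplified = norm_tn.full_simplify(equalize_norms=1.0, check_zero=True)

    # recover the actual scalar value from the stripped exponent
    value = simplified.contract() * 10**simplified.exponent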
@@ -10594,6 +10647,7 @@ def compress_simplify(
         max_simplification_iterations=100,
         converged_tol=0.01,
         equalize_norms=True,
+        check_zero=True,
         progbar=False,
         inplace=False,
         **full_simplify_opts,
@@ -10606,6 +10660,7 @@ def compress_simplify(
         simplify_opts = {
             "atol": atol,
             "equalize_norms": equalize_norms,
+            "check_zero": check_zero,
             "progbar": progbar,
             "output_inds": output_inds,
             "cache": set(),