Skip to content

Commit

Permalink
add test for sum_rules=TSR
Browse files Browse the repository at this point in the history
  • Loading branch information
scarlehoff committed Apr 16, 2024
1 parent b29956f commit 47a077f
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
8 changes: 7 additions & 1 deletion n3fit/src/n3fit/layers/msr_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ def __init__(self, mode: str = "ALL", replica_seeds=None, **kwargs):
else:
raise ValueError(f"Mode {mode} not accepted for sum rules")

self._replicas = len(replica_seeds)
if replica_seeds is None:
self._replicas = 1
else:
self._replicas = len(replica_seeds)

indices = []
self.divisor_indices = []
Expand All @@ -119,6 +122,9 @@ def __init__(self, mode: str = "ALL", replica_seeds=None, **kwargs):
[np.repeat(VSR_CONSTANTS[c], self._replicas) for c in VSR_COMPONENTS]
)
if self._tsr_enabled:
if replica_seeds is None:
raise ValueError("To use sum_rules=TSR a list of seeds must be provided")

self.divisor_indices += [IDX[TSR_DENOMINATORS[c]] for c in TSR_COMPONENTS]
indices += [IDX[c] for c in TSR_COMPONENTS]

Expand Down
4 changes: 2 additions & 2 deletions n3fit/src/n3fit/tests/regressions/quickcard.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ theory:

############################################################
genrep: True # on = generate MC replicas, False = use real data
trvlseed: 5
trvlseed: 3
nnseed: 2
mcseed: 1
nnseed: 20

load: "weights.weights.h5"

Expand Down
16 changes: 16 additions & 0 deletions n3fit/src/n3fit/tests/test_msr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest

from n3fit.backends import operations as op
from n3fit.layers import MSR_Normalization
Expand Down Expand Up @@ -76,3 +77,18 @@ def test_vsr():
]
)
np.testing.assert_allclose(output, known_output, rtol=1e-5)


def test_tsr():
"""Test the sum rules used in polarized fits"""
with pytest.raises(ValueError):
# Check that seeds _are_ needed
layer = MSR_Normalization(mode='TSR')

layer = MSR_Normalization(mode="TSR", replica_seeds=[3])
output = apply_layer_to_fixed_input(layer)
# They should _all_ be 1.0 except for entries 9 and 10
known_output = np.ones((1, 1, 14))
known_output[0, 0, 9] = 1.1133982
known_output[0, 0, 10] = -0.9901034
np.testing.assert_allclose(output, known_output, rtol=1e-5)

0 comments on commit 47a077f

Please sign in to comment.