Skip to content

Commit

Permalink
Merge pull request #1666 from NNPDF/cutsinfoextralabels
Browse files Browse the repository at this point in the history
Fix extra labels slicing when cuts are given
  • Loading branch information
scarlehoff authored Jan 20, 2023
2 parents 3dfa046 + 679471e commit f6c49ae
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
12 changes: 9 additions & 3 deletions validphys2/src/validphys/plotoptions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,11 @@ def get_info(data, *, normalize=False, cuts=None, use_plotfiles=True):
if cuts is None:
if isinstance(data, DataSetSpec):
cuts = data.cuts.load() if data.cuts else None
elif isinstance(cuts, (Cuts, InternalCutsWrapper)):
elif hasattr(cuts, 'load'):
cuts = cuts.load()
elif not cuts:
cuts = None

if cuts is not None and not len(cuts):
raise NotImplementedError("No point passes the cuts. Cannot retieve info")

if isinstance(data, DataSetSpec):
data = data.commondata
Expand Down Expand Up @@ -175,6 +176,11 @@ def from_commondata(cls, commondata, cuts=None, normalize=False):

kinlabels = commondata.plot_kinlabels
kinlabels = plot_params['kinematics_override'].new_labels(*kinlabels)
if "extra_labels" in plot_params and cuts is not None:
cut_extra_labels ={
k: [v[i] for i in cuts] for k, v in plot_params["extra_labels"].items()
}
plot_params["extra_labels"] = cut_extra_labels

return cls(kinlabels=kinlabels, **plot_params)

Expand Down
23 changes: 18 additions & 5 deletions validphys2/src/validphys/tests/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,24 @@
#The sorted is to appease hypothesis
dss = sorted(l.available_datasets - {"PDFEVOLTEST"})

class MockCuts():
def __init__(self, arr):
self.arr = arr
def load(self):
return self.arr


@composite
def commodata_and_cuts(draw):
def commondata_and_cuts(draw):
cd = l.check_commondata(draw(sampled_from(dss)))
ndata = cd.metadata.ndata
#TODO: Maybe upgrade to this
#https://github.com/HypothesisWorks/hypothesis/issues/1115
mask = sorted(draw(sets(sampled_from(range(ndata)))))
# Get a cut mask with at least one selected datapoint
masks = sets(sampled_from(range(ndata)), min_size=1)
mask = sorted(draw(masks))
return cd, mask


@given(arg=commodata_and_cuts())
@given(arg=commondata_and_cuts())
@settings(deadline=None)
def test_rebuild_commondata_without_cuts(tmp_path_factory, arg):
# We need to create a new directory for each call of the test
Expand Down Expand Up @@ -60,6 +66,13 @@ def test_rebuild_commondata_without_cuts(tmp_path_factory, arg):
nocuts[cuts] = False
assert (lncd.get_cv()[nocuts] == 0).all()

@given(inp=commondata_and_cuts())
def test_kitable_with_cuts(inp):
cd, cuts = inp
info = get_info(cd, cuts=cuts)
tb = kitable(cd, info, cuts=MockCuts(cuts))
assert len(tb) == len(cuts)

def test_load_fit():
assert l.check_fit(FIT)
with pytest.raises(FitNotFound):
Expand Down

0 comments on commit f6c49ae

Please sign in to comment.