diff --git a/validphys2/src/validphys/plotoptions/core.py b/validphys2/src/validphys/plotoptions/core.py index d96aef7290..88258b9961 100644 --- a/validphys2/src/validphys/plotoptions/core.py +++ b/validphys2/src/validphys/plotoptions/core.py @@ -57,10 +57,11 @@ def get_info(data, *, normalize=False, cuts=None, use_plotfiles=True): if cuts is None: if isinstance(data, DataSetSpec): cuts = data.cuts.load() if data.cuts else None - elif isinstance(cuts, (Cuts, InternalCutsWrapper)): + elif hasattr(cuts, 'load'): cuts = cuts.load() - elif not cuts: - cuts = None + + if cuts is not None and not len(cuts): + raise NotImplementedError("No point passes the cuts. Cannot retieve info") if isinstance(data, DataSetSpec): data = data.commondata @@ -175,6 +176,11 @@ def from_commondata(cls, commondata, cuts=None, normalize=False): kinlabels = commondata.plot_kinlabels kinlabels = plot_params['kinematics_override'].new_labels(*kinlabels) + if "extra_labels" in plot_params and cuts is not None: + cut_extra_labels ={ + k: [v[i] for i in cuts] for k, v in plot_params["extra_labels"].items() + } + plot_params["extra_labels"] = cut_extra_labels return cls(kinlabels=kinlabels, **plot_params) diff --git a/validphys2/src/validphys/tests/test_loader.py b/validphys2/src/validphys/tests/test_loader.py index decf0a020c..9d23248675 100644 --- a/validphys2/src/validphys/tests/test_loader.py +++ b/validphys2/src/validphys/tests/test_loader.py @@ -17,18 +17,24 @@ #The sorted is to appease hypothesis dss = sorted(l.available_datasets - {"PDFEVOLTEST"}) +class MockCuts(): + def __init__(self, arr): + self.arr = arr + def load(self): + return self.arr + @composite -def commodata_and_cuts(draw): +def commondata_and_cuts(draw): cd = l.check_commondata(draw(sampled_from(dss))) ndata = cd.metadata.ndata - #TODO: Maybe upgrade to this - #https://github.com/HypothesisWorks/hypothesis/issues/1115 - mask = sorted(draw(sets(sampled_from(range(ndata))))) + # Get a cut mask with at least one selected datapoint + masks = sets(sampled_from(range(ndata)), min_size=1) + mask = sorted(draw(masks)) return cd, mask -@given(arg=commodata_and_cuts()) +@given(arg=commondata_and_cuts()) @settings(deadline=None) def test_rebuild_commondata_without_cuts(tmp_path_factory, arg): # We need to create a new directory for each call of the test @@ -60,6 +66,13 @@ def test_rebuild_commondata_without_cuts(tmp_path_factory, arg): nocuts[cuts] = False assert (lncd.get_cv()[nocuts] == 0).all() +@given(inp=commondata_and_cuts()) +def test_kitable_with_cuts(inp): + cd, cuts = inp + info = get_info(cd, cuts=cuts) + tb = kitable(cd, info, cuts=MockCuts(cuts)) + assert len(tb) == len(cuts) + def test_load_fit(): assert l.check_fit(FIT) with pytest.raises(FitNotFound):