Skip to content
This repository was archived by the owner on Sep 27, 2022. It is now read-only.

Commit 0d7a99a

Browse files
committed
Fix for loading of single sample
1 parent 7166500 commit 0d7a99a

File tree

3 files changed

+40
-2
lines changed

3 files changed

+40
-2
lines changed

PrognosAIs/IO/LabelParser.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,18 @@ def get_labels(self) -> list:
5454
Returns:
5555
labels: List of labels
5656
"""
57+
labels = np.squeeze(self.label_data.values)
58+
59+
if isinstance(labels, np.ndarray) and labels.size > 1:
60+
labels = labels.tolist()
61+
elif isinstance(labels, np.ndarray):
62+
# Otherwise if it is 1 element it will remove the list,
63+
# and return only a string
64+
labels = [labels.tolist()]
65+
else:
66+
labels = [labels]
5767

58-
return np.squeeze(self.label_data.values).tolist()
68+
return labels
5969

6070
def get_samples(self) -> list:
6171
"""Get all labels of all samples
@@ -65,7 +75,14 @@ def get_samples(self) -> list:
6575
Returns:
6676
samples: List of samples
6777
"""
68-
return np.squeeze(self.label_data.index).tolist()
78+
79+
samples = np.squeeze(self.label_data.index)
80+
81+
if isinstance(samples, np.ndarray):
82+
samples = samples.tolist()
83+
else:
84+
samples = [samples]
85+
return samples
6986

7087
def get_data(self) -> dict:
7188
"""Get all data from the label file
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Image Label_1
2+
./PrognosAIs/tests/test_data/NPZ_Data/Samples/Test_sample_1.npz ./PrognosAIs/tests/test_data/NPZ_Data/Samples/Test_sample_1_label.npz

tests/test_labelparser_new.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,22 @@ def test_one_hot_encoding_missing_labels():
1515
result = label_parser.get_labels_from_category("Label_1")
1616

1717
assert result == pytest.approx(np.asarray([[0, 1], [-1, -1], [1, 0], [0, 1]]))
18+
19+
def test_sample_loading_single_sample():
20+
label_file = os.path.join(FIXTURE_DIR, "labels_file_single_class_single_sample.txt")
21+
label_parser = LabelParser.LabelLoader(label_file)
22+
23+
result = label_parser.get_samples()
24+
25+
assert isinstance(result, list)
26+
assert len(result) == 1
27+
28+
def test_label_loading_single_sample():
29+
label_file = os.path.join(FIXTURE_DIR, "labels_file_single_class_single_sample.txt")
30+
label_parser = LabelParser.LabelLoader(label_file)
31+
32+
result = label_parser.get_labels()
33+
34+
assert isinstance(result, list)
35+
assert len(result) == 1
36+

0 commit comments

Comments
 (0)