Skip to content

Commit ef20c93

Browse files
Add more tests to increase coverage
1 parent 6c1c5e2 commit ef20c93

25 files changed

+6220
-12
lines changed

tests/test_annotations_comprehensive.py

Lines changed: 425 additions & 0 deletions
Large diffs are not rendered by default.

tests/test_annotations_extended.py

Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
"""Extended tests for torch_concepts.annotations module to improve coverage."""
2+
3+
import pytest
4+
import torch
5+
from torch_concepts.annotations import AxisAnnotation, Annotations
6+
7+
8+
class TestAxisAnnotationExtended:
9+
"""Extended tests for AxisAnnotation class to improve coverage."""
10+
11+
def test_cardinality_mismatch_with_states(self):
12+
"""Test that mismatched cardinalities and states raise error."""
13+
with pytest.raises(ValueError, match="don't match inferred cardinalities"):
14+
AxisAnnotation(
15+
labels=['a', 'b'],
16+
states=[['x', 'y'], ['p', 'q', 'r']],
17+
cardinalities=[2, 2] # Should be [2, 3] based on states
18+
)
19+
20+
def test_metadata_validation_non_dict(self):
21+
"""Test that non-dict metadata raises error."""
22+
with pytest.raises(ValueError, match="metadata must be a dictionary"):
23+
AxisAnnotation(
24+
labels=['a', 'b'],
25+
metadata="invalid" # Should be dict
26+
)
27+
28+
def test_metadata_validation_missing_label(self):
29+
"""Test that metadata missing a label raises error."""
30+
with pytest.raises(ValueError, match="Metadata missing for label"):
31+
AxisAnnotation(
32+
labels=['a', 'b', 'c'],
33+
metadata={'a': {}, 'b': {}} # Missing 'c'
34+
)
35+
36+
def test_has_metadata_with_key(self):
37+
"""Test has_metadata method with specific key."""
38+
axis = AxisAnnotation(
39+
labels=['a', 'b'],
40+
metadata={'a': {'type': 'binary'}, 'b': {'type': 'binary'}}
41+
)
42+
assert axis.has_metadata('type') is True
43+
assert axis.has_metadata('missing_key') is False
44+
45+
def test_has_metadata_none(self):
46+
"""Test has_metadata when metadata is None."""
47+
axis = AxisAnnotation(labels=['a', 'b'])
48+
assert axis.has_metadata('any_key') is False
49+
50+
def test_groupby_metadata_labels_layout(self):
51+
"""Test groupby_metadata with labels layout."""
52+
axis = AxisAnnotation(
53+
labels=['a', 'b', 'c', 'd'],
54+
metadata={
55+
'a': {'group': 'A'},
56+
'b': {'group': 'A'},
57+
'c': {'group': 'B'},
58+
'd': {'group': 'B'}
59+
}
60+
)
61+
result = axis.groupby_metadata('group', layout='labels')
62+
assert result == {'A': ['a', 'b'], 'B': ['c', 'd']}
63+
64+
def test_groupby_metadata_indices_layout(self):
65+
"""Test groupby_metadata with indices layout."""
66+
axis = AxisAnnotation(
67+
labels=['a', 'b', 'c'],
68+
metadata={
69+
'a': {'group': 'X'},
70+
'b': {'group': 'Y'},
71+
'c': {'group': 'X'}
72+
}
73+
)
74+
result = axis.groupby_metadata('group', layout='indices')
75+
assert result == {'X': [0, 2], 'Y': [1]}
76+
77+
def test_groupby_metadata_invalid_layout(self):
78+
"""Test groupby_metadata with invalid layout raises error."""
79+
axis = AxisAnnotation(
80+
labels=['a', 'b'],
81+
metadata={'a': {'g': '1'}, 'b': {'g': '2'}}
82+
)
83+
with pytest.raises(ValueError, match="Unknown layout"):
84+
axis.groupby_metadata('g', layout='invalid')
85+
86+
def test_groupby_metadata_none(self):
87+
"""Test groupby_metadata when metadata is None."""
88+
axis = AxisAnnotation(labels=['a', 'b'])
89+
result = axis.groupby_metadata('any_key')
90+
assert result == {}
91+
92+
def test_get_index_not_found(self):
93+
"""Test get_index with non-existent label."""
94+
axis = AxisAnnotation(labels=['a', 'b', 'c'])
95+
with pytest.raises(ValueError, match="Label 'z' not found"):
96+
axis.get_index('z')
97+
98+
def test_get_label_out_of_range(self):
99+
"""Test get_label with out-of-range index."""
100+
axis = AxisAnnotation(labels=['a', 'b'])
101+
with pytest.raises(IndexError, match="Index 5 out of range"):
102+
axis.get_label(5)
103+
104+
def test_getitem_out_of_range(self):
105+
"""Test __getitem__ with out-of-range index."""
106+
axis = AxisAnnotation(labels=['a', 'b'])
107+
with pytest.raises(IndexError, match="Index 10 out of range"):
108+
_ = axis[10]
109+
110+
def test_get_total_cardinality_nested(self):
111+
"""Test get_total_cardinality for nested axis."""
112+
axis = AxisAnnotation(
113+
labels=['a', 'b', 'c'],
114+
cardinalities=[2, 3, 4]
115+
)
116+
assert axis.get_total_cardinality() == 9
117+
118+
def test_get_total_cardinality_not_nested(self):
119+
"""Test get_total_cardinality for non-nested axis."""
120+
axis = AxisAnnotation(labels=['a', 'b', 'c'])
121+
assert axis.get_total_cardinality() == 3
122+
123+
def test_to_dict_with_all_fields(self):
124+
"""Test to_dict with all fields populated."""
125+
axis = AxisAnnotation(
126+
labels=['a', 'b'],
127+
states=[['0', '1'], ['x', 'y', 'z']],
128+
metadata={'a': {'type': 'binary'}, 'b': {'type': 'categorical'}}
129+
)
130+
result = axis.to_dict()
131+
132+
assert result['labels'] == ['a', 'b']
133+
assert result['states'] == [['0', '1'], ['x', 'y', 'z']]
134+
assert result['cardinalities'] == [2, 3]
135+
assert result['is_nested'] is True
136+
assert result['metadata'] == {'a': {'type': 'binary'}, 'b': {'type': 'categorical'}}
137+
138+
def test_from_dict_reconstruction(self):
139+
"""Test from_dict reconstructs AxisAnnotation correctly."""
140+
original = AxisAnnotation(
141+
labels=['x', 'y'],
142+
cardinalities=[2, 3],
143+
metadata={'x': {'info': 'test'}, 'y': {'info': 'test2'}}
144+
)
145+
146+
data = original.to_dict()
147+
reconstructed = AxisAnnotation.from_dict(data)
148+
149+
assert reconstructed.labels == original.labels
150+
assert reconstructed.cardinalities == original.cardinalities
151+
assert reconstructed.is_nested == original.is_nested
152+
assert reconstructed.metadata == original.metadata
153+
154+
def test_subset_basic(self):
155+
"""Test subset method with valid labels."""
156+
axis = AxisAnnotation(
157+
labels=['a', 'b', 'c', 'd'],
158+
cardinalities=[1, 2, 3, 1]
159+
)
160+
161+
subset = axis.subset(['b', 'd'])
162+
163+
assert subset.labels == ['b', 'd']
164+
assert subset.cardinalities == [2, 1]
165+
166+
def test_subset_with_metadata(self):
167+
"""Test subset preserves metadata."""
168+
axis = AxisAnnotation(
169+
labels=['a', 'b', 'c'],
170+
metadata={'a': {'x': 1}, 'b': {'x': 2}, 'c': {'x': 3}}
171+
)
172+
173+
subset = axis.subset(['a', 'c'])
174+
175+
assert subset.labels == ['a', 'c']
176+
assert subset.metadata == {'a': {'x': 1}, 'c': {'x': 3}}
177+
178+
def test_subset_missing_labels(self):
179+
"""Test subset with non-existent labels raises error."""
180+
axis = AxisAnnotation(labels=['a', 'b', 'c'])
181+
182+
with pytest.raises(ValueError, match="Unknown labels for subset"):
183+
axis.subset(['a', 'z'])
184+
185+
def test_subset_preserves_order(self):
186+
"""Test subset preserves the requested label order."""
187+
axis = AxisAnnotation(labels=['a', 'b', 'c', 'd'])
188+
189+
subset = axis.subset(['d', 'b', 'a'])
190+
191+
assert subset.labels == ['d', 'b', 'a']
192+
193+
def test_union_with_no_overlap(self):
194+
"""Test union_with with no overlapping labels."""
195+
axis1 = AxisAnnotation(labels=['a', 'b'])
196+
axis2 = AxisAnnotation(labels=['c', 'd'])
197+
198+
union = axis1.union_with(axis2)
199+
200+
assert union.labels == ['a', 'b', 'c', 'd']
201+
202+
def test_union_with_overlap(self):
203+
"""Test union_with with overlapping labels."""
204+
axis1 = AxisAnnotation(labels=['a', 'b', 'c'])
205+
axis2 = AxisAnnotation(labels=['b', 'c', 'd'])
206+
207+
union = axis1.union_with(axis2)
208+
209+
assert union.labels == ['a', 'b', 'c', 'd']
210+
211+
def test_union_with_metadata_merge(self):
212+
"""Test union_with merges metadata with left-win."""
213+
axis1 = AxisAnnotation(
214+
labels=['a', 'b'],
215+
metadata={'a': {'x': 1}, 'b': {'x': 2}}
216+
)
217+
axis2 = AxisAnnotation(
218+
labels=['b', 'c'],
219+
metadata={'b': {'x': 999}, 'c': {'x': 3}}
220+
)
221+
222+
union = axis1.union_with(axis2)
223+
224+
# Left-win: 'b' should keep metadata from axis1
225+
assert union.metadata['a'] == {'x': 1}
226+
assert union.metadata['b'] == {'x': 2}
227+
assert union.metadata['c'] == {'x': 3}
228+
229+
def test_write_once_labels_attribute(self):
230+
"""Test that labels attribute is write-once."""
231+
axis = AxisAnnotation(labels=['a', 'b'])
232+
233+
with pytest.raises(AttributeError, match="write-once and already set"):
234+
axis.labels = ['x', 'y']
235+
236+
def test_write_once_states_attribute(self):
237+
"""Test that states attribute is write-once."""
238+
axis = AxisAnnotation(labels=['a', 'b'], cardinalities=[2, 3])
239+
240+
with pytest.raises(AttributeError, match="write-once and already set"):
241+
axis.states = [['0', '1'], ['0', '1', '2']]
242+
243+
def test_metadata_can_be_modified(self):
244+
"""Test that metadata can be modified after creation."""
245+
axis = AxisAnnotation(labels=['a', 'b'])
246+
247+
# Metadata is not write-once, so this should work
248+
axis.metadata = {'a': {'test': 1}, 'b': {'test': 2}}
249+
assert axis.metadata is not None
250+
251+
252+
class TestAnnotationsExtended:
253+
"""Extended tests for Annotations class to improve coverage."""
254+
255+
def test_annotations_with_dict_input(self):
256+
"""Test Annotations with dict input."""
257+
axis0 = AxisAnnotation(labels=['batch'])
258+
axis1 = AxisAnnotation(labels=['a', 'b', 'c'])
259+
260+
annotations = Annotations({0: axis0, 1: axis1})
261+
262+
assert 0 in annotations._axis_annotations
263+
assert 1 in annotations._axis_annotations
264+
265+
def test_annotations_with_list_input(self):
266+
"""Test Annotations with list input."""
267+
axis0 = AxisAnnotation(labels=['a', 'b'])
268+
axis1 = AxisAnnotation(labels=['x', 'y', 'z'])
269+
270+
annotations = Annotations([axis0, axis1])
271+
272+
assert len(annotations._axis_annotations) == 2
273+
assert annotations._axis_annotations[0].labels == ['a', 'b']
274+
assert annotations._axis_annotations[1].labels == ['x', 'y', 'z']
275+
276+
def test_annotations_getitem(self):
277+
"""Test Annotations __getitem__ method."""
278+
axis = AxisAnnotation(labels=['a', 'b', 'c'])
279+
annotations = Annotations({1: axis})
280+
281+
retrieved = annotations[1]
282+
assert retrieved.labels == ['a', 'b', 'c']
283+
284+
def test_annotations_setitem(self):
285+
"""Test Annotations __setitem__ method."""
286+
annotations = Annotations({})
287+
axis = AxisAnnotation(labels=['x', 'y'])
288+
289+
annotations[2] = axis
290+
291+
assert annotations[2].labels == ['x', 'y']
292+
293+
def test_annotations_len(self):
294+
"""Test Annotations __len__ method."""
295+
axis0 = AxisAnnotation(labels=['a'])
296+
axis1 = AxisAnnotation(labels=['b'])
297+
axis2 = AxisAnnotation(labels=['c'])
298+
299+
annotations = Annotations({0: axis0, 1: axis1, 2: axis2})
300+
301+
assert len(annotations) == 3
302+
303+
def test_annotations_iter(self):
304+
"""Test Annotations __iter__ method."""
305+
axis0 = AxisAnnotation(labels=['a'])
306+
axis1 = AxisAnnotation(labels=['b'])
307+
308+
annotations = Annotations({0: axis0, 1: axis1})
309+
310+
axes = list(annotations)
311+
assert len(axes) == 2
312+
313+
def test_annotations_contains(self):
314+
"""Test Annotations __contains__ method."""
315+
axis = AxisAnnotation(labels=['a', 'b'])
316+
annotations = Annotations({1: axis})
317+
318+
assert 1 in annotations
319+
assert 0 not in annotations
320+
assert 5 not in annotations
321+

0 commit comments

Comments
 (0)