Skip to content

Commit 7d8bbbb

Browse files
Refactor test folder mirroring the main package
1 parent a1ae2af commit 7d8bbbb

File tree

66 files changed

+4805
-8573
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+4805
-8573
lines changed

tests/__init__.py

Whitespace-only changes.
File renamed without changes.

tests/test_nn_modules_high_learner_comprehensive.py renamed to tests/nn/modules/high/base/test_learner.py

Lines changed: 122 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""
2-
Comprehensive tests for torch_concepts.nn.modules.high.base.learner
2+
Comprehensive tests for torch_concepts.nn.modules.high
33
4-
Tests the BaseLearner class with metrics setup, optimizer configuration,
5-
and loss computation for binary and categorical concepts.
4+
Tests high-level model modules (CBM, CEM, CGM, etc.).
65
"""
76
import unittest
87
import torch
@@ -11,7 +10,6 @@
1110
from torch_concepts.annotations import Annotations, AxisAnnotation
1211
from torch_concepts.distributions import Delta
1312
from torch_concepts.nn.modules.high.base.learner import BaseLearner
14-
from torchmetrics import Accuracy, MeanSquaredError
1513

1614

1715
class MockLearner(BaseLearner):
@@ -621,5 +619,125 @@ def test_instantiate_metric_dict_with_non_dict(self):
621619
self.assertEqual(metrics, {})
622620

623621

622+
class TestHighLevelModels(unittest.TestCase):
623+
"""Test high-level model architectures."""
624+
625+
def setUp(self):
626+
"""Set up common test fixtures."""
627+
# Create simple annotations for testing
628+
concept_labels = ['color', 'shape', 'size']
629+
task_labels = ['class1', 'class2']
630+
self.annotations = Annotations({
631+
1: AxisAnnotation(labels=concept_labels + task_labels)
632+
})
633+
self.variable_distributions = {
634+
'color': Delta,
635+
'shape': Delta,
636+
'size': Delta,
637+
'class1': Delta,
638+
'class2': Delta
639+
}
640+
641+
def test_cbm_placeholder(self):
642+
"""Placeholder test for CBM model."""
643+
# CBM requires complex setup with inference strategies
644+
# This is a placeholder to ensure the test file runs
645+
self.assertTrue(True)
646+
647+
def test_cem_placeholder(self):
648+
"""Placeholder test for CEM model."""
649+
# CEM requires complex setup with embeddings
650+
# This is a placeholder to ensure the test file runs
651+
self.assertTrue(True)
652+
653+
654+
class TestBatchValidation(unittest.TestCase):
655+
"""Test batch structure validation in BaseLearner."""
656+
657+
def setUp(self):
658+
"""Create a mock learner instance for testing unpack_batch."""
659+
# Create a mock learner that implements both _check_batch and unpack_batch
660+
self.learner = type('MockLearner', (), {})()
661+
# Bind both methods from BaseLearner
662+
self.learner._check_batch = BaseLearner._check_batch.__get__(self.learner)
663+
self.learner.unpack_batch = BaseLearner.unpack_batch.__get__(self.learner)
664+
665+
def test_valid_batch_structure(self):
666+
"""Test that valid batch structure is accepted."""
667+
valid_batch = {
668+
'inputs': torch.randn(4, 10),
669+
'concepts': torch.randn(4, 2)
670+
}
671+
inputs, concepts, transforms = self.learner.unpack_batch(valid_batch)
672+
self.assertIsNotNone(inputs)
673+
self.assertIsNotNone(concepts)
674+
self.assertEqual(transforms, {})
675+
676+
def test_batch_with_transforms(self):
677+
"""Test that batch with transforms is handled correctly."""
678+
batch_with_transforms = {
679+
'inputs': torch.randn(4, 10),
680+
'concepts': torch.randn(4, 2),
681+
'transforms': {'scaler': 'some_transform'}
682+
}
683+
inputs, concepts, transforms = self.learner.unpack_batch(batch_with_transforms)
684+
self.assertIsNotNone(inputs)
685+
self.assertIsNotNone(concepts)
686+
self.assertEqual(transforms, {'scaler': 'some_transform'})
687+
688+
def test_missing_inputs_key(self):
689+
"""Test that missing 'inputs' key raises KeyError."""
690+
invalid_batch = {
691+
'concepts': torch.randn(4, 2)
692+
}
693+
with self.assertRaises(KeyError) as context:
694+
self.learner.unpack_batch(invalid_batch)
695+
self.assertIn('inputs', str(context.exception))
696+
self.assertIn("missing required keys", str(context.exception))
697+
698+
def test_missing_concepts_key(self):
699+
"""Test that missing 'concepts' key raises KeyError."""
700+
invalid_batch = {
701+
'inputs': torch.randn(4, 10)
702+
}
703+
with self.assertRaises(KeyError) as context:
704+
self.learner.unpack_batch(invalid_batch)
705+
self.assertIn('concepts', str(context.exception))
706+
self.assertIn("missing required keys", str(context.exception))
707+
708+
def test_missing_both_keys(self):
709+
"""Test that missing both required keys raises KeyError."""
710+
invalid_batch = {
711+
'data': torch.randn(4, 10)
712+
}
713+
with self.assertRaises(KeyError) as context:
714+
self.learner.unpack_batch(invalid_batch)
715+
self.assertIn("missing required keys", str(context.exception))
716+
717+
def test_non_dict_batch(self):
718+
"""Test that non-dict batch raises TypeError."""
719+
invalid_batch = torch.randn(4, 10)
720+
with self.assertRaises(TypeError) as context:
721+
self.learner.unpack_batch(invalid_batch)
722+
self.assertIn("Expected batch to be a dict", str(context.exception))
723+
724+
def test_tuple_batch(self):
725+
"""Test that tuple batch raises TypeError."""
726+
invalid_batch = (torch.randn(4, 10), torch.randn(4, 2))
727+
with self.assertRaises(TypeError) as context:
728+
self.learner.unpack_batch(invalid_batch)
729+
self.assertIn("Expected batch to be a dict", str(context.exception))
730+
731+
def test_empty_dict_batch(self):
732+
"""Test that empty dict raises KeyError with helpful message."""
733+
invalid_batch = {}
734+
with self.assertRaises(KeyError) as context:
735+
self.learner.unpack_batch(invalid_batch)
736+
self.assertIn("missing required keys", str(context.exception))
737+
self.assertIn("Found keys: []", str(context.exception))
738+
739+
624740
if __name__ == '__main__':
625741
unittest.main()
742+
743+
File renamed without changes.
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""
2+
Comprehensive tests for torch_concepts.nn.modules.low.encoders
3+
4+
Tests all encoder modules (linear, exogenous, selector, stochastic).
5+
"""
6+
import unittest
7+
import torch
8+
import torch.nn as nn
9+
from torch_concepts.nn.modules.low.encoders.exogenous import LinearZU
10+
11+
12+
class TestLinearZU(unittest.TestCase):
13+
"""Test LinearZU."""
14+
15+
def test_initialization(self):
16+
"""Test encoder initialization."""
17+
encoder = LinearZU(
18+
in_features=128,
19+
out_features=10,
20+
exogenous_size=16
21+
)
22+
self.assertEqual(encoder.in_features, 128)
23+
self.assertEqual(encoder.out_features, 10)
24+
self.assertEqual(encoder.exogenous_size, 16)
25+
26+
def test_forward_shape(self):
27+
"""Test forward pass output shape."""
28+
encoder = LinearZU(
29+
in_features=64,
30+
out_features=5,
31+
exogenous_size=8
32+
)
33+
embeddings = torch.randn(4, 64)
34+
output = encoder(embeddings)
35+
self.assertEqual(output.shape, (4, 5, 8))
36+
37+
def test_gradient_flow(self):
38+
"""Test gradient flow through encoder."""
39+
encoder = LinearZU(
40+
in_features=32,
41+
out_features=3,
42+
exogenous_size=4
43+
)
44+
embeddings = torch.randn(2, 32, requires_grad=True)
45+
output = encoder(embeddings)
46+
loss = output.sum()
47+
loss.backward()
48+
self.assertIsNotNone(embeddings.grad)
49+
50+
def test_different_embedding_sizes(self):
51+
"""Test various embedding sizes."""
52+
for emb_size in [4, 8, 16, 32]:
53+
encoder = LinearZU(
54+
in_features=64,
55+
out_features=5,
56+
exogenous_size=emb_size
57+
)
58+
embeddings = torch.randn(2, 64)
59+
output = encoder(embeddings)
60+
self.assertEqual(output.shape, (2, 5, emb_size))
61+
62+
def test_encoder_output_dimension(self):
63+
"""Test output dimension calculation."""
64+
encoder = LinearZU(
65+
in_features=128,
66+
out_features=10,
67+
exogenous_size=16
68+
)
69+
self.assertEqual(encoder.out_endogenous_dim, 10)
70+
self.assertEqual(encoder.out_encoder_dim, 10 * 16)
71+
72+
def test_leaky_relu_activation(self):
73+
"""Test that LeakyReLU is applied."""
74+
encoder = LinearZU(
75+
in_features=32,
76+
out_features=3,
77+
exogenous_size=4
78+
)
79+
embeddings = torch.randn(2, 32)
80+
output = encoder(embeddings)
81+
# Output should have passed through LeakyReLU
82+
self.assertIsNotNone(output)
83+
84+
85+
if __name__ == '__main__':
86+
unittest.main()
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""
2+
Comprehensive tests for torch_concepts.nn.modules.low.encoders
3+
4+
Tests all encoder modules (linear, exogenous, selector, stochastic).
5+
"""
6+
import unittest
7+
import torch
8+
import torch.nn as nn
9+
from torch_concepts.nn.modules.low.encoders.linear import LinearZC, LinearUC
10+
11+
12+
class TestLinearZC(unittest.TestCase):
13+
"""Test LinearZC."""
14+
15+
def test_initialization(self):
16+
"""Test encoder initialization."""
17+
encoder = LinearZC(
18+
in_features=128,
19+
out_features=10
20+
)
21+
self.assertEqual(encoder.in_features, 128)
22+
self.assertEqual(encoder.out_features, 10)
23+
self.assertIsInstance(encoder.encoder, nn.Sequential)
24+
25+
def test_forward_shape(self):
26+
"""Test forward pass output shape."""
27+
encoder = LinearZC(
28+
in_features=128,
29+
out_features=10
30+
)
31+
embeddings = torch.randn(4, 128)
32+
output = encoder(embeddings)
33+
self.assertEqual(output.shape, (4, 10))
34+
35+
def test_gradient_flow(self):
36+
"""Test gradient flow through encoder."""
37+
encoder = LinearZC(
38+
in_features=64,
39+
out_features=5
40+
)
41+
embeddings = torch.randn(2, 64, requires_grad=True)
42+
output = encoder(embeddings)
43+
loss = output.sum()
44+
loss.backward()
45+
self.assertIsNotNone(embeddings.grad)
46+
47+
def test_batch_processing(self):
48+
"""Test different batch sizes."""
49+
encoder = LinearZC(
50+
in_features=32,
51+
out_features=5
52+
)
53+
for batch_size in [1, 4, 8]:
54+
embeddings = torch.randn(batch_size, 32)
55+
output = encoder(embeddings)
56+
self.assertEqual(output.shape, (batch_size, 5))
57+
58+
def test_with_bias_false(self):
59+
"""Test encoder without bias."""
60+
encoder = LinearZC(
61+
in_features=32,
62+
out_features=5,
63+
bias=False
64+
)
65+
embeddings = torch.randn(2, 32)
66+
output = encoder(embeddings)
67+
self.assertEqual(output.shape, (2, 5))
68+
69+
70+
class TestLinearUC(unittest.TestCase):
71+
"""Test LinearUC."""
72+
73+
def test_initialization(self):
74+
"""Test encoder initialization."""
75+
encoder = LinearUC(
76+
in_features_exogenous=16,
77+
n_exogenous_per_concept=2
78+
)
79+
self.assertEqual(encoder.n_exogenous_per_concept, 2)
80+
81+
def test_forward_shape(self):
82+
"""Test forward pass output shape."""
83+
encoder = LinearUC(
84+
in_features_exogenous=8,
85+
n_exogenous_per_concept=2
86+
)
87+
# Input shape: (batch, concepts, in_features * n_exogenous_per_concept)
88+
exog = torch.randn(4, 5, 16) # 8 * 2 = 16
89+
output = encoder(exog)
90+
self.assertEqual(output.shape, (4, 5))
91+
92+
def test_single_exogenous_per_concept(self):
93+
"""Test with single exogenous per concept."""
94+
encoder = LinearUC(
95+
in_features_exogenous=10,
96+
n_exogenous_per_concept=1
97+
)
98+
exog = torch.randn(3, 4, 10)
99+
output = encoder(exog)
100+
self.assertEqual(output.shape, (3, 4))
101+
102+
def test_gradient_flow(self):
103+
"""Test gradient flow."""
104+
encoder = LinearUC(
105+
in_features_exogenous=8,
106+
n_exogenous_per_concept=2
107+
)
108+
exog = torch.randn(2, 3, 16, requires_grad=True)
109+
output = encoder(exog)
110+
loss = output.sum()
111+
loss.backward()
112+
self.assertIsNotNone(exog.grad)
113+
114+
115+
if __name__ == '__main__':
116+
unittest.main()

0 commit comments

Comments
 (0)