1212"""
1313
1414import torch
15- from torch .utils .data import Dataset , DataLoader
16- from torch_concepts import Annotations , AxisAnnotation
1715from torch_concepts .nn import ConceptBottleneckModel
1816from torch_concepts .data .datasets import ToyDataset
17+ from torch_concepts .data .base .datamodule import ConceptDataModule
1918from torch .distributions import Bernoulli
2019
2120from torchmetrics .classification import BinaryAccuracy
2221
2322from pytorch_lightning import Trainer
2423
-class ConceptDataset(Dataset):
-    """Custom dataset that returns batches in the format expected by ConceptBottleneckModel."""
-
-    def __init__(self, x, c, y):
-        self.x = x
-        self.concepts = torch.cat([c, y], dim=1)
-
-    def __len__(self):
-        return len(self.x)
-
-    def __getitem__(self, idx):
-        return {
-            'inputs': {'x': self.x[idx]},
-            'concepts': {'c': self.concepts[idx]},
-        }
-
 def main():
     # Set random seed for reproducibility
     torch.manual_seed(42)
@@ -47,63 +30,39 @@ def main():
     print("Step 1: Generate toy XOR dataset")
     print("=" * 60)
 
-    n_samples = 1000
+    n_samples = 10000
+    batch_size = 2048
     dataset = ToyDataset(dataset='xor', seed=42, n_gen=n_samples)
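+    # The 'xor' toy problem pairs two binary concepts with a single task
+    # label (their XOR), which is what the slicing in Steps 3-5 assumes.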
-    x_train = dataset.input_data
-    concept_idx = list(dataset.graph.edge_index[0].unique().numpy())
-    task_idx = list(dataset.graph.edge_index[1].unique().numpy())
-    c_train = dataset.concepts[:, concept_idx]
-    y_train = dataset.concepts[:, task_idx]
-    concept_names = [dataset.concept_names[i] for i in concept_idx]
-    task_names = [dataset.concept_names[i] for i in task_idx]
-
-    n_features = x_train.shape[1]
-    n_concepts = c_train.shape[1]
-    n_tasks = y_train.shape[1]
-
-    print(f"Input features: {n_features}")
-    print(f"Concepts: {n_concepts} - {concept_names}")
-    print(f"Tasks: {n_tasks} - {task_names}")
-    print(f"Training samples: {n_samples}")
+    datamodule = ConceptDataModule(dataset=dataset,
+                                   batch_size=batch_size,
+                                   val_size=0.1,
+                                   test_size=0.2)
+    annotations = dataset.annotations
+    concept_names = annotations.get_axis_annotation(1).labels
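+    # Axis 1 annotates the variable dimension: for this dataset the labels
+    # list the two concepts first and the 'xor' task last, which is the
+    # ordering the prints and slicing below rely on.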
 
-    # For binary concepts, we can use simple labels
-    concept_annotations = Annotations({
-        1: AxisAnnotation(
-            labels=tuple(concept_names + task_names),
-            cardinalities=[1] * (n_concepts + n_tasks),
-            metadata={
-                concept_names[0]: {
-                    'type': 'discrete',
-                    'distribution': Bernoulli
-                },
-                concept_names[1]: {
-                    'type': 'discrete',
-                    'distribution': Bernoulli
-                },
-                task_names[0]: {
-                    'type': 'discrete',
-                    'distribution': Bernoulli
-                }
-            }
-        )
-    })
-
-    print(f"Concept axis labels: {concept_annotations[1].labels}")
-    print(f"Concept types: {[concept_annotations[1].metadata[name]['type'] for name in concept_names]}")
-    print(f"Concept cardinalities: {concept_annotations[1].cardinalities}")
-    print(f"Concept distributions: {[concept_annotations[1].metadata[name]['distribution'] for name in concept_names]}")
+    n_features = dataset.input_data.shape[1]
+    n_concepts = 2
+    n_tasks = 1
 
+    print(f"Input features: {n_features}")
+    print(f"Concepts: {n_concepts} - {concept_names[:2]}")
+    print(f"Tasks: {n_tasks} - {concept_names[2]}")
+    print(f"Training samples: {n_samples}")
 
     # Init model
     print("\n" + "=" * 60)
     print("Step 2: Initialize ConceptBottleneckModel")
     print("=" * 60)
 
+    # Define variable distributions as Bernoulli
+    variable_distributions = {name: Bernoulli for name in concept_names}
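+    # All variables here (both concepts and the task) are binary, so a
+    # Bernoulli head fits each one; other torch.distributions could
+    # presumably be plugged in for non-binary variables.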
59+
10260 # Initialize the CBM
10361 model = ConceptBottleneckModel (
10462 input_size = n_features ,
105- annotations = concept_annotations ,
106- task_names = task_names ,
63+ annotations = annotations ,
64+ variable_distributions = variable_distributions ,
65+ task_names = ['xor' ],
10766 latent_encoder_kwargs = {'hidden_size' : 16 , 'n_layers' : 1 },
10867 # Specify loss and optimizer to abilitate training with lightning
10968 loss = torch .nn .BCEWithLogitsLoss (),
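+        # BCEWithLogitsLoss applies the sigmoid internally, so the model is
+        # presumably expected to emit raw logits rather than probabilities.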
@@ -121,11 +80,10 @@ def main():
     print("Step 3: Test forward pass")
     print("=" * 60)
 
-    batch_size = 8
-    x_batch = x_train[:batch_size]
+    x_batch = dataset.input_data[:batch_size]
 
     # Forward pass
-    query = list(concept_names) + list(task_names)
+    query = concept_names
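+    # Querying all annotated variables returns a single tensor with the
+    # predictions laid out in label order: the two concepts first, then the
+    # 'xor' task, matching the slicing used in Step 5.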
     print(f"Query variables: {query}")
 
     with torch.no_grad():
@@ -136,43 +94,51 @@ def main():
     print(f"Expected output dim: {n_concepts + n_tasks}")
 
 
-    # Test forward pass
+    # Test training with Lightning
     print("\n" + "=" * 60)
     print("Step 4: Training loop with Lightning")
     print("=" * 60)
 
-    trainer = Trainer(
-        max_epochs=500,
-        log_every_n_steps=10
-    )
-
-    # Create dataset and dataloader
-    train_dataset = ConceptDataset(x_train, c_train, y_train)
-    train_dataloader = DataLoader(train_dataset, batch_size=1000, shuffle=False)
+    trainer = Trainer(max_epochs=100)
 
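+    # Lightning pulls the train and validation dataloaders straight from the
+    # ConceptDataModule, so no manual Dataset/DataLoader wiring is needed.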
     model.train()
-    trainer.fit(model, train_dataloaders=train_dataloader)
+    trainer.fit(model, datamodule=datamodule)
 
     # Evaluate
     print("\n" + "=" * 60)
-    print("Step 5: Evaluation")
+    print("Step 5: Evaluation with standard torchmetrics")
     print("=" * 60)
 
     concept_acc_fn = BinaryAccuracy()
     task_acc_fn = BinaryAccuracy()
 
     model.eval()
+    concept_acc_sum = 0.0
+    task_acc_sum = 0.0
+    num_batches = 0
+
     with torch.no_grad():
-        endogenous = model(x_train, query=query)
-        c_pred = endogenous[:, :n_concepts]
-        y_pred = endogenous[:, n_concepts:]
-
-        # Compute accuracy using BinaryAccuracy
-        concept_acc = concept_acc_fn(c_pred, c_train.int()).item()
-        task_acc = task_acc_fn(y_pred, y_train.int()).item()
-
-        print(f"Concept accuracy: {concept_acc:.4f}")
-        print(f"Task accuracy: {task_acc:.4f}")
+        test_loader = datamodule.test_dataloader()
+        for batch in test_loader:
+            endogenous = model(batch['inputs']['x'], query=query)
+            c_pred = endogenous[:, :n_concepts]
+            y_pred = endogenous[:, n_concepts:]
+
+            c_true = batch['concepts']['c'][:, :n_concepts]
+            y_true = batch['concepts']['c'][:, n_concepts:]
+
+            concept_acc = concept_acc_fn(c_pred, c_true.int()).item()
+            task_acc = task_acc_fn(y_pred, y_true.int()).item()
+
+            concept_acc_sum += concept_acc
+            task_acc_sum += task_acc
+            num_batches += 1
+
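+    # Note: this averages per-batch accuracies with equal weight, so a smaller
+    # final batch counts the same as a full one; good enough for a demo, while
+    # BinaryAccuracy's update()/compute() cycle would weight every sample equally.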
+    avg_concept_acc = concept_acc_sum / num_batches if num_batches > 0 else 0.0
+    avg_task_acc = task_acc_sum / num_batches if num_batches > 0 else 0.0
+
+    print(f"Average concept accuracy: {avg_concept_acc:.4f}")
+    print(f"Average task accuracy: {avg_task_acc:.4f}")
 
 if __name__ == "__main__":
     main()