
Commit 6b7731b

update add_distribution_to_annotations to work with dicts and GroupConfig
1 parent 3c877db commit 6b7731b
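
In short, as the diffs below show, variable_distributions can now be supplied either as a plain dict mapping variable names to torch distribution classes, or as a torch_concepts.GroupConfig that assigns one distribution per group of variables. A minimal sketch of the two shapes (the concept names and the GroupConfig keyword arguments are illustrative, inferred from the example scripts and the YAML keys in this commit, not confirmed API):

from torch.distributions import Bernoulli, RelaxedBernoulli, RelaxedOneHotCategorical
from torch_concepts import GroupConfig

# 1) Plain dict: one distribution class per variable name
#    ("c1"/"c2" are placeholder concept names; 'xor' is the task name used in the examples)
variable_distributions = {"c1": Bernoulli, "c2": Bernoulli, "xor": Bernoulli}

# 2) GroupConfig: one distribution class per variable group,
#    mirroring the YAML keys `binary` / `categorical` below
variable_distributions = GroupConfig(
    binary=RelaxedBernoulli,
    categorical=RelaxedOneHotCategorical,
)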

File tree

9 files changed: +259 -383 lines changed

conceptarium/conf/model/_commons.yaml

Lines changed: 6 additions & 13 deletions
@@ -17,19 +17,12 @@ latent_encoder_kwargs:
 # Concept distribution configs
 # =============================================================
 variable_distributions:
-  discrete_card1:
-    path: "torch.distributions.RelaxedBernoulli"
-    kwargs:
-      temperature: 0.1
-  discrete_cardn:
-    path: "torch.distributions.RelaxedOneHotCategorical"
-    kwargs:
-      temperature: 0.1
-      # num_classes: to be set dynamically for each concept
-  continuous_card1:
-    path: "torch_concepts.distributions.Delta"
-  continuous_cardn:
-    path: "torch_concepts.distributions.Delta"
+  _target_: "torch_concepts.GroupConfig"
+  binary: "torch.distributions.RelaxedBernoulli"
+  categorical: "torch.distributions.RelaxedOneHotCategorical"
+  # TODO: handle kwargs
+  # continuous:
+  #   ... not supported yet


 # =============================================================
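
The new block follows Hydra's `_target_` convention, so it would presumably be materialized with hydra.utils.instantiate. A minimal sketch under that assumption (how GroupConfig resolves the class-path strings, and whether this file is loaded standalone rather than composed into a larger config, is not shown in this commit):

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Build the GroupConfig declared under `variable_distributions` (path is illustrative)
cfg = OmegaConf.load("conceptarium/conf/model/_commons.yaml")
variable_distributions = instantiate(cfg.variable_distributions)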

examples/utilization/2_model/5_torch_training.py

Lines changed: 13 additions & 30 deletions
@@ -13,14 +13,15 @@

 import torch
 from torch import nn
-from torch_concepts import Annotations, AxisAnnotation
-from torch_concepts.nn import ConceptBottleneckModel, ConceptLoss
-from torch_concepts.data.datasets import ToyDataset
 from torch.distributions import Bernoulli

+from torch_concepts.nn import ConceptBottleneckModel
+from torch_concepts.data.datasets import ToyDataset
+
 from torchmetrics.classification import BinaryAccuracy


+
 def main():
     # Set random seed for reproducibility
     torch.manual_seed(42)
@@ -33,12 +34,10 @@ def main():
     n_samples = 1000
     dataset = ToyDataset(dataset='xor', seed=42, n_gen=n_samples)
     x_train = dataset.input_data
-    concept_idx = list(dataset.graph.edge_index[0].unique().numpy())
-    task_idx = list(dataset.graph.edge_index[1].unique().numpy())
-    c_train = dataset.concepts[:, concept_idx]
-    y_train = dataset.concepts[:, task_idx]
-    concept_names = [dataset.concept_names[i] for i in concept_idx]
-    task_names = [dataset.concept_names[i] for i in task_idx]
+    c_train = dataset.concepts[:, :2]
+    y_train = dataset.concepts[:, 2:]
+    concept_names = dataset.concept_names[:2]
+    task_names = dataset.concept_names[2:]

     n_features = x_train.shape[1]
     n_concepts = c_train.shape[1]
@@ -49,41 +48,25 @@ def main():
     print(f"Tasks: {n_tasks} - {task_names}")
     print(f"Training samples: {n_samples}")

-    # For binary concepts, we can use simple labels
-    concept_annotations = Annotations({
-        1: AxisAnnotation(
-            labels=tuple(concept_names + task_names),
-            metadata={
-                concept_names[0]: {
-                    'type': 'discrete',
-                    'distribution': Bernoulli
-                },
-                concept_names[1]: {
-                    'type': 'discrete',
-                    'distribution': Bernoulli
-                },
-                task_names[0]: {
-                    'type': 'discrete',
-                    'distribution': Bernoulli
-                }
-            }
-        )
-    })
+    concept_annotations = dataset.annotations

     print(f"Concept axis labels: {concept_annotations[1].labels}")
     print(f"Concept types: {[concept_annotations[1].metadata[name]['type'] for name in concept_names]}")
     print(f"Concept cardinalities: {concept_annotations[1].cardinalities}")
-    print(f"Concept distributions: {[concept_annotations[1].metadata[name]['distribution'] for name in concept_names]}")

     # Init model
     print("\n" + "=" * 60)
     print("Step 2: Initialize ConceptBottleneckModel")
     print("=" * 60)

+    # Define variable distributions as Bernoulli
+    variable_distributions = {name: Bernoulli for name in concept_names + task_names}
+
     # Initialize the CBM
     model = ConceptBottleneckModel(
         input_size=n_features,
         annotations=concept_annotations,
+        variable_distributions=variable_distributions,
         task_names=task_names,
         latent_encoder_kwargs={'hidden_size': 16, 'n_layers': 1}
     )

examples/utilization/2_model/6_lightning_training.py

Lines changed: 53 additions & 87 deletions
@@ -12,32 +12,15 @@
 """

 import torch
-from torch.utils.data import Dataset, DataLoader
-from torch_concepts import Annotations, AxisAnnotation
 from torch_concepts.nn import ConceptBottleneckModel
 from torch_concepts.data.datasets import ToyDataset
+from torch_concepts.data.base.datamodule import ConceptDataModule
 from torch.distributions import Bernoulli

 from torchmetrics.classification import BinaryAccuracy

 from pytorch_lightning import Trainer

-class ConceptDataset(Dataset):
-    """Custom dataset that returns batches in the format expected by ConceptBottleneckModel."""
-
-    def __init__(self, x, c, y):
-        self.x = x
-        self.concepts = torch.cat([c, y], dim=1)
-
-    def __len__(self):
-        return len(self.x)
-
-    def __getitem__(self, idx):
-        return {
-            'inputs': {'x': self.x[idx]},
-            'concepts': {'c': self.concepts[idx]},
-        }
-
 def main():
     # Set random seed for reproducibility
     torch.manual_seed(42)
@@ -47,63 +30,39 @@ def main():
     print("Step 1: Generate toy XOR dataset")
     print("=" * 60)

-    n_samples = 1000
+    n_samples = 10000
+    batch_size = 2048
     dataset = ToyDataset(dataset='xor', seed=42, n_gen=n_samples)
-    x_train = dataset.input_data
-    concept_idx = list(dataset.graph.edge_index[0].unique().numpy())
-    task_idx = list(dataset.graph.edge_index[1].unique().numpy())
-    c_train = dataset.concepts[:, concept_idx]
-    y_train = dataset.concepts[:, task_idx]
-    concept_names = [dataset.concept_names[i] for i in concept_idx]
-    task_names = [dataset.concept_names[i] for i in task_idx]
-
-    n_features = x_train.shape[1]
-    n_concepts = c_train.shape[1]
-    n_tasks = y_train.shape[1]
-
-    print(f"Input features: {n_features}")
-    print(f"Concepts: {n_concepts} - {concept_names}")
-    print(f"Tasks: {n_tasks} - {task_names}")
-    print(f"Training samples: {n_samples}")
+    datamodule = ConceptDataModule(dataset=dataset,
+                                   batch_size=batch_size,
+                                   val_size=0.1,
+                                   test_size=0.2)
+    annotations = dataset.annotations
+    concept_names = annotations.get_axis_annotation(1).labels

-    # For binary concepts, we can use simple labels
-    concept_annotations = Annotations({
-        1: AxisAnnotation(
-            labels=tuple(concept_names + task_names),
-            cardinalities=[1]*(n_concepts + n_tasks),
-            metadata={
-                concept_names[0]: {
-                    'type': 'discrete',
-                    'distribution': Bernoulli
-                },
-                concept_names[1]: {
-                    'type': 'discrete',
-                    'distribution': Bernoulli
-                },
-                task_names[0]: {
-                    'type': 'discrete',
-                    'distribution': Bernoulli
-                }
-            }
-        )
-    })
-
-    print(f"Concept axis labels: {concept_annotations[1].labels}")
-    print(f"Concept types: {[concept_annotations[1].metadata[name]['type'] for name in concept_names]}")
-    print(f"Concept cardinalities: {concept_annotations[1].cardinalities}")
-    print(f"Concept distributions: {[concept_annotations[1].metadata[name]['distribution'] for name in concept_names]}")
+    n_features = dataset.input_data.shape[1]
+    n_concepts = 2
+    n_tasks = 1

+    print(f"Input features: {n_features}")
+    print(f"Concepts: {n_concepts} - {concept_names[:2]}")
+    print(f"Tasks: {n_tasks} - {concept_names[2]}")
+    print(f"Training samples: {n_samples}")

     # Init model
     print("\n" + "=" * 60)
     print("Step 2: Initialize ConceptBottleneckModel")
     print("=" * 60)

+    # Define variable distributions as Bernoulli
+    variable_distributions = {name: Bernoulli for name in concept_names}
+
     # Initialize the CBM
     model = ConceptBottleneckModel(
         input_size=n_features,
-        annotations=concept_annotations,
-        task_names=task_names,
+        annotations=annotations,
+        variable_distributions=variable_distributions,
+        task_names=['xor'],
         latent_encoder_kwargs={'hidden_size': 16, 'n_layers': 1},
         # Specify loss and optimizer to abilitate training with lightning
         loss=torch.nn.BCEWithLogitsLoss(),
@@ -121,11 +80,10 @@ def main():
     print("Step 3: Test forward pass")
     print("=" * 60)

-    batch_size = 8
-    x_batch = x_train[:batch_size]
+    x_batch = dataset.input_data[:batch_size]

     # Forward pass
-    query = list(concept_names) + list(task_names)
+    query = concept_names
     print(f"Query variables: {query}")

     with torch.no_grad():
@@ -136,43 +94,51 @@ def main():
         print(f"Expected output dim: {n_concepts + n_tasks}")


-    # Test forward pass
+    # Test lightning training
     print("\n" + "=" * 60)
     print("Step 4: Training loop with lightning")
     print("=" * 60)

-    trainer = Trainer(
-        max_epochs=500,
-        log_every_n_steps=10
-    )
-
-    # Create dataset and dataloader
-    train_dataset = ConceptDataset(x_train, c_train, y_train)
-    train_dataloader = DataLoader(train_dataset, batch_size=1000, shuffle=False)
+    trainer = Trainer(max_epochs=100)

     model.train()
-    trainer.fit(model, train_dataloaders=train_dataloader)
+    trainer.fit(model, datamodule=datamodule)

     # Evaluate
     print("\n" + "=" * 60)
-    print("Step 5: Evaluation")
+    print("Step 5: Evaluation with standard torch metrics")
     print("=" * 60)

     concept_acc_fn = BinaryAccuracy()
     task_acc_fn = BinaryAccuracy()

     model.eval()
+    concept_acc_sum = 0.0
+    task_acc_sum = 0.0
+    num_batches = 0
+
     with torch.no_grad():
-        endogenous = model(x_train, query=query)
-        c_pred = endogenous[:, :n_concepts]
-        y_pred = endogenous[:, n_concepts:]
-
-        # Compute accuracy using BinaryAccuracy
-        concept_acc = concept_acc_fn(c_pred, c_train.int()).item()
-        task_acc = task_acc_fn(y_pred, y_train.int()).item()
-
-        print(f"Concept accuracy: {concept_acc:.4f}")
-        print(f"Task accuracy: {task_acc:.4f}")
+        test_loader = datamodule.test_dataloader()
+        for batch in test_loader:
+            endogenous = model(batch['inputs']['x'], query=query)
+            c_pred = endogenous[:, :n_concepts]
+            y_pred = endogenous[:, n_concepts:]
+
+            c_true = batch['concepts']['c'][:, :n_concepts]
+            y_true = batch['concepts']['c'][:, n_concepts:]
+
+            concept_acc = concept_acc_fn(c_pred, c_true.int()).item()
+            task_acc = task_acc_fn(y_pred, y_true.int()).item()
+
+            concept_acc_sum += concept_acc
+            task_acc_sum += task_acc
+            num_batches += 1
+
+    avg_concept_acc = concept_acc_sum / num_batches if num_batches > 0 else 0.0
+    avg_task_acc = task_acc_sum / num_batches if num_batches > 0 else 0.0
+
+    print(f"Average concept accuracy: {avg_concept_acc:.4f}")
+    print(f"Average task accuracy: {avg_task_acc:.4f}")

 if __name__ == "__main__":
     main()
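
A note on the new evaluation loop: averaging per-batch accuracies weights a possibly smaller final batch the same as the full ones. A minimal alternative sketch that lets torchmetrics accumulate across batches instead, reusing model, datamodule, query, n_concepts, and the imports from the example above:

# Let BinaryAccuracy accumulate statistics itself, then compute once at the end
concept_acc_fn = BinaryAccuracy()
task_acc_fn = BinaryAccuracy()

with torch.no_grad():
    for batch in datamodule.test_dataloader():
        endogenous = model(batch['inputs']['x'], query=query)
        concept_acc_fn.update(endogenous[:, :n_concepts],
                              batch['concepts']['c'][:, :n_concepts].int())
        task_acc_fn.update(endogenous[:, n_concepts:],
                           batch['concepts']['c'][:, n_concepts:].int())

print(f"Concept accuracy: {concept_acc_fn.compute():.4f}")
print(f"Task accuracy: {task_acc_fn.compute():.4f}")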
