Commit a1ae2af

Add test for unrolling pgm
1 parent: ef20c93

File tree: 1 file changed (+80 −36)

tests/test_nn_modules_mid_inference.py

Lines changed: 80 additions & 36 deletions
@@ -4,11 +4,16 @@
 Tests for ForwardInference engine.
 """
 import unittest
+from copy import deepcopy
+
 import torch
 import torch.nn as nn
-from torch.distributions import Bernoulli, Categorical
+from torch.distributions import Bernoulli, Categorical, RelaxedBernoulli, RelaxedOneHotCategorical
+from torch_concepts.data.datasets import ToyDataset
 
-from torch_concepts import InputVariable, EndogenousVariable
+from torch_concepts import InputVariable, EndogenousVariable, Annotations, AxisAnnotation, ConceptGraph
+from torch_concepts.nn import AncestralSamplingInference, WANDAGraphLearner, GraphModel, LazyConstructor, LinearZU, \
+    LinearUC, HyperLinearCUC
 from torch_concepts.nn.modules.mid.models.variable import Variable
 from torch_concepts.nn.modules.mid.models.cpd import ParametricCPD
 from torch_concepts.nn.modules.mid.models.probabilistic_model import ProbabilisticModel
@@ -356,40 +361,79 @@ def test_missing_factor(self):
         with self.assertRaises(RuntimeError):
             inference.predict(external_inputs)
 
-    def test_complex_multi_level_hierarchy(self):
-        """Test complex multi-level hierarchy."""
-        # Level 0: latent
-        input_var = Variable('input', parents=[], distribution=Delta, size=10)
-
-        # Level 1: A, B
-        var_a = Variable('A', parents=[input_var], distribution=Bernoulli, size=1)
-        var_b = Variable('B', parents=[input_var], distribution=Categorical, size=3)
-
-        # Level 2: C (depends on A and B)
-        var_c = Variable('C', parents=[var_a, var_b], distribution=Bernoulli, size=1)
-
-        # Level 3: D (depends on C)
-        var_d = Variable('D', parents=[var_c], distribution=Bernoulli, size=1)
-
-        latent_factor = ParametricCPD('input', parametrization=nn.Identity())
-        cpd_a = ParametricCPD('A', parametrization=nn.Linear(10, 1))
-        cpd_b = ParametricCPD('B', parametrization=nn.Linear(10, 3))
-        cpd_c = ParametricCPD('C', parametrization=nn.Linear(4, 1))  # 1 + 3 inputs
-        cpd_d = ParametricCPD('D', parametrization=nn.Linear(1, 1))
-
-        pgm = ProbabilisticModel(
-            variables=[input_var, var_a, var_b, var_c, var_d],
-            parametric_cpds=[latent_factor, cpd_a, cpd_b, cpd_c, cpd_d]
-        )
-
-        inference = SimpleForwardInference(pgm)
-
-        self.assertEqual(len(inference.levels), 4)
-
-        external_inputs = {'input': torch.randn(4, 10)}
-        results = inference.predict(external_inputs)
-
-        self.assertEqual(len(results), 5)
+    def test_unroll_pgm(self):
+        latent_dims = 20
+        n_epochs = 1000
+        n_samples = 1000
+        concept_reg = 0.5
+
+        dataset = ToyDataset(dataset='xor', seed=42, n_gen=n_samples)
+        x_train = dataset.input_data
+        concept_idx = list(dataset.graph.edge_index[0].unique().numpy())
+        task_idx = list(dataset.graph.edge_index[1].unique().numpy())
+        c_train = dataset.concepts[:, concept_idx]
+        y_train = dataset.concepts[:, task_idx]
+
+        c_train = torch.cat([c_train, y_train], dim=1)
+        y_train = deepcopy(c_train)
+        cy_train = torch.cat([c_train, y_train], dim=1)
+        c_train_one_hot = torch.cat(
+            [cy_train[:, :2], torch.nn.functional.one_hot(cy_train[:, 2].long(), num_classes=2).float()], dim=1)
+        cy_train_one_hot = torch.cat([c_train_one_hot, c_train_one_hot], dim=1)
+
+        concept_names = ['c1', 'c2', 'xor']
+        task_names = ['c1_copy', 'c2_copy', 'xor_copy']
+        cardinalities = [1, 1, 2, 1, 1, 2]
+        metadata = {
+            'c1': {'distribution': RelaxedBernoulli, 'type': 'binary', 'description': 'Concept 1'},
+            'c2': {'distribution': RelaxedBernoulli, 'type': 'binary', 'description': 'Concept 2'},
+            'xor': {'distribution': RelaxedOneHotCategorical, 'type': 'categorical', 'description': 'XOR Task'},
+            'c1_copy': {'distribution': RelaxedBernoulli, 'type': 'binary', 'description': 'Concept 1 Copy'},
+            'c2_copy': {'distribution': RelaxedBernoulli, 'type': 'binary', 'description': 'Concept 2 Copy'},
+            'xor_copy': {'distribution': RelaxedOneHotCategorical, 'type': 'categorical',
+                         'description': 'XOR Task Copy'},
+        }
+        annotations = Annotations(
+            {1: AxisAnnotation(concept_names + task_names, cardinalities=cardinalities, metadata=metadata)})
+
+        model_graph = ConceptGraph(torch.tensor([[0, 0, 0, 0, 1, 1],
+                                                 [0, 0, 0, 1, 0, 1],
+                                                 [0, 0, 0, 1, 1, 0],
+                                                 [0, 0, 0, 0, 0, 0],
+                                                 [0, 0, 0, 0, 0, 0],
+                                                 [0, 0, 0, 0, 0, 0]]), list(annotations.get_axis_annotation(1).labels))
+
+        # ProbabilisticModel Initialization
+        encoder = torch.nn.Sequential(torch.nn.Linear(x_train.shape[1], latent_dims), torch.nn.LeakyReLU())
+        concept_model = GraphModel(model_graph=model_graph,
+                                   input_size=latent_dims,
+                                   annotations=annotations,
+                                   source_exogenous=LazyConstructor(LinearZU, exogenous_size=11),
+                                   internal_exogenous=LazyConstructor(LinearZU, exogenous_size=7),
+                                   encoder=LazyConstructor(LinearUC),
+                                   predictor=LazyConstructor(HyperLinearCUC, embedding_size=20))
+
+        # graph learning init
+        graph_learner = WANDAGraphLearner(concept_names, task_names)
+
+        inference_engine = AncestralSamplingInference(concept_model.probabilistic_model, graph_learner, temperature=0.1)
+        query_concepts = ["c1", "c2", "xor", "c1_copy", "c2_copy", "xor_copy"]
+
+        emb = encoder(x_train)
+        cy_pred_before_unrolling = inference_engine.query(query_concepts, evidence={'input': emb}, debug=True)
+
+        concept_model_new = inference_engine.unrolled_probabilistic_model()
+
+        # identify available query concepts in the unrolled model
+        query_concepts = [c for c in query_concepts if c in inference_engine.available_query_vars]
+        concept_idx = {v: i for i, v in enumerate(concept_names)}
+        reverse_c2t_mapping = dict(zip(task_names, concept_names))
+        query_concepts = sorted(query_concepts, key=lambda x: concept_idx[x] if x in concept_idx else concept_idx[reverse_c2t_mapping[x]])
+
+        inference_engine = AncestralSamplingInference(concept_model_new, temperature=0.1)
+        cy_pred_after_unrolling = inference_engine.query(query_concepts, evidence={'input': emb}, debug=True)
+
+        self.assertTrue(cy_pred_after_unrolling.shape == c_train_one_hot.shape)
 
 
 if __name__ == '__main__':
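
In short, the new test exercises an unroll-and-requery pattern: query the original probabilistic model by ancestral sampling, unroll it into a standalone ProbabilisticModel via the inference engine, then requery the unrolled model and check the prediction shape. The sketch below distills that pattern. It is a minimal illustration, not part of the commit; it assumes the torch_concepts API exactly as used in the diff above, and the helper name unroll_and_requery and its local variable names are hypothetical.

# Minimal sketch (not from the commit) of the unroll-and-requery pattern
# exercised by test_unroll_pgm. Assumes the torch_concepts API as used in
# the diff above; unroll_and_requery and its local names are illustrative.
from torch_concepts.nn import AncestralSamplingInference, WANDAGraphLearner


def unroll_and_requery(concept_model, encoder, x, concept_names, task_names):
    # Ancestral-sampling inference over the original model, with a graph
    # learner attached so concept-task structure can be learned.
    engine = AncestralSamplingInference(
        concept_model.probabilistic_model,
        WANDAGraphLearner(concept_names, task_names),
        temperature=0.1,
    )
    emb = encoder(x)
    query = concept_names + task_names
    preds_before = engine.query(query, evidence={'input': emb}, debug=True)

    # Unrolling bakes the learned structure into a new ProbabilisticModel;
    # some variables may no longer be queryable, so filter the query set.
    unrolled = engine.unrolled_probabilistic_model()
    query = [c for c in query if c in engine.available_query_vars]

    # Requery the unrolled model, this time without a graph learner.
    engine = AncestralSamplingInference(unrolled, temperature=0.1)
    preds_after = engine.query(query, evidence={'input': emb}, debug=True)
    return preds_before, preds_after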
