 Tests for ForwardInference engine.
 """
 import unittest
+from copy import deepcopy
+
 import torch
 import torch.nn as nn
-from torch.distributions import Bernoulli, Categorical
+from torch.distributions import Bernoulli, Categorical, RelaxedBernoulli, RelaxedOneHotCategorical
+from torch_concepts.data.datasets import ToyDataset
 
-from torch_concepts import InputVariable, EndogenousVariable
+from torch_concepts import InputVariable, EndogenousVariable, Annotations, AxisAnnotation, ConceptGraph
+from torch_concepts.nn import (AncestralSamplingInference, WANDAGraphLearner, GraphModel, LazyConstructor,
+                               LinearZU, LinearUC, HyperLinearCUC)
 from torch_concepts.nn.modules.mid.models.variable import Variable
 from torch_concepts.nn.modules.mid.models.cpd import ParametricCPD
 from torch_concepts.nn.modules.mid.models.probabilistic_model import ProbabilisticModel
@@ -356,40 +361,79 @@ def test_missing_factor(self): |
         with self.assertRaises(RuntimeError):
             inference.predict(external_inputs)
 
-    def test_complex_multi_level_hierarchy(self):
-        """Test complex multi-level hierarchy."""
-        # Level 0: latent
-        input_var = Variable('input', parents=[], distribution=Delta, size=10)
-
-        # Level 1: A, B
-        var_a = Variable('A', parents=[input_var], distribution=Bernoulli, size=1)
-        var_b = Variable('B', parents=[input_var], distribution=Categorical, size=3)
-
-        # Level 2: C (depends on A and B)
-        var_c = Variable('C', parents=[var_a, var_b], distribution=Bernoulli, size=1)
-
-        # Level 3: D (depends on C)
-        var_d = Variable('D', parents=[var_c], distribution=Bernoulli, size=1)
-
-        latent_factor = ParametricCPD('input', parametrization=nn.Identity())
-        cpd_a = ParametricCPD('A', parametrization=nn.Linear(10, 1))
-        cpd_b = ParametricCPD('B', parametrization=nn.Linear(10, 3))
-        cpd_c = ParametricCPD('C', parametrization=nn.Linear(4, 1))  # 1 + 3 inputs
-        cpd_d = ParametricCPD('D', parametrization=nn.Linear(1, 1))
-
-        pgm = ProbabilisticModel(
-            variables=[input_var, var_a, var_b, var_c, var_d],
-            parametric_cpds=[latent_factor, cpd_a, cpd_b, cpd_c, cpd_d]
-        )
-
-        inference = SimpleForwardInference(pgm)
-
-        self.assertEqual(len(inference.levels), 4)
-
-        external_inputs = {'input': torch.randn(4, 10)}
-        results = inference.predict(external_inputs)
-
-        self.assertEqual(len(results), 5)
+    def test_unroll_pgm(self):
+        """Unrolling the learned graph into a fixed PGM preserves prediction shapes."""
+        latent_dims = 20
+        n_epochs = 1000
+        n_samples = 1000
+        concept_reg = 0.5
+
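+        # XOR toy data: edge sources in the dataset graph index the concept
+        # columns, edge targets index the task columns.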
+        dataset = ToyDataset(dataset='xor', seed=42, n_gen=n_samples)
+        x_train = dataset.input_data
+        concept_idx = list(dataset.graph.edge_index[0].unique().numpy())
+        task_idx = list(dataset.graph.edge_index[1].unique().numpy())
+        c_train = dataset.concepts[:, concept_idx]
+        y_train = dataset.concepts[:, task_idx]
+
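+        # Tasks become exact copies of the three concepts, giving the learner a
+        # known ground truth to recover; the categorical xor column is one-hot
+        # encoded to match its cardinality of 2.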
+        c_train = torch.cat([c_train, y_train], dim=1)
+        y_train = deepcopy(c_train)
+        cy_train = torch.cat([c_train, y_train], dim=1)
+        c_train_one_hot = torch.cat(
+            [cy_train[:, :2], torch.nn.functional.one_hot(cy_train[:, 2].long(), num_classes=2).float()], dim=1)
+        cy_train_one_hot = torch.cat([c_train_one_hot, c_train_one_hot], dim=1)
+
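+        # Annotate axis 1 with all six variables; the relaxed distributions keep
+        # ancestral sampling differentiable.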
+        concept_names = ['c1', 'c2', 'xor']
+        task_names = ['c1_copy', 'c2_copy', 'xor_copy']
+        cardinalities = [1, 1, 2, 1, 1, 2]
+        metadata = {
+            'c1': {'distribution': RelaxedBernoulli, 'type': 'binary', 'description': 'Concept 1'},
+            'c2': {'distribution': RelaxedBernoulli, 'type': 'binary', 'description': 'Concept 2'},
+            'xor': {'distribution': RelaxedOneHotCategorical, 'type': 'categorical', 'description': 'XOR Task'},
+            'c1_copy': {'distribution': RelaxedBernoulli, 'type': 'binary', 'description': 'Concept 1 Copy'},
+            'c2_copy': {'distribution': RelaxedBernoulli, 'type': 'binary', 'description': 'Concept 2 Copy'},
+            'xor_copy': {'distribution': RelaxedOneHotCategorical, 'type': 'categorical',
+                         'description': 'XOR Task Copy'},
+        }
+        annotations = Annotations(
+            {1: AxisAnnotation(concept_names + task_names, cardinalities=cardinalities, metadata=metadata)})
+
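+        # Ground-truth adjacency over [c1, c2, xor, c1_copy, c2_copy, xor_copy]:
+        # reading rows as parents, each concept points at the copies of the
+        # other two concepts, never at its own copy.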
+        model_graph = ConceptGraph(torch.tensor([[0, 0, 0, 0, 1, 1],
+                                                 [0, 0, 0, 1, 0, 1],
+                                                 [0, 0, 0, 1, 1, 0],
+                                                 [0, 0, 0, 0, 0, 0],
+                                                 [0, 0, 0, 0, 0, 0],
+                                                 [0, 0, 0, 0, 0, 0]]),
+                                   list(annotations.get_axis_annotation(1).labels))
+
+        # Probabilistic model: a shared encoder feeds a GraphModel, which builds
+        # per-variable CPDs lazily (via LazyConstructor) from graph + annotations.
+        encoder = torch.nn.Sequential(torch.nn.Linear(x_train.shape[1], latent_dims), torch.nn.LeakyReLU())
+        concept_model = GraphModel(model_graph=model_graph,
+                                   input_size=latent_dims,
+                                   annotations=annotations,
+                                   source_exogenous=LazyConstructor(LinearZU, exogenous_size=11),
+                                   internal_exogenous=LazyConstructor(LinearZU, exogenous_size=7),
+                                   encoder=LazyConstructor(LinearUC),
+                                   predictor=LazyConstructor(HyperLinearCUC, embedding_size=20))
+
+        # Graph learner: WANDA proposes edges from concepts to tasks.
+        graph_learner = WANDAGraphLearner(concept_names, task_names)
+
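+        # First pass: ancestral sampling on the original model, with the graph
+        # learner supplying the edge structure.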
+        inference_engine = AncestralSamplingInference(concept_model.probabilistic_model, graph_learner,
+                                                      temperature=0.1)
+        query_concepts = ["c1", "c2", "xor", "c1_copy", "c2_copy", "xor_copy"]
+
+        emb = encoder(x_train)
+        cy_pred_before_unrolling = inference_engine.query(query_concepts, evidence={'input': emb}, debug=True)
+
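+        # Unroll: materialize the learned structure into a standalone
+        # probabilistic model with fixed edges.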
+        concept_model_new = inference_engine.unrolled_probabilistic_model()
+
+        # Keep only the query concepts that survive in the unrolled model...
+        query_concepts = [c for c in query_concepts if c in inference_engine.available_query_vars]
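+        # ...and restore a stable ordering: each task copy sorts by the index of
+        # its source concept.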
+        concept_idx = {v: i for i, v in enumerate(concept_names)}
+        reverse_c2t_mapping = dict(zip(task_names, concept_names))
+        query_concepts = sorted(query_concepts,
+                                key=lambda x: concept_idx[x] if x in concept_idx else concept_idx[reverse_c2t_mapping[x]])
+
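+        # Second pass: the unrolled model needs no graph learner, and the query
+        # output should match the one-hot label shape.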
+        inference_engine = AncestralSamplingInference(concept_model_new, temperature=0.1)
+        cy_pred_after_unrolling = inference_engine.query(query_concepts, evidence={'input': emb}, debug=True)
+
+        self.assertEqual(cy_pred_after_unrolling.shape, c_train_one_hot.shape)
 
 
 if __name__ == '__main__':
|