diff --git a/symbolic_evaluation/__pycache__/cobweb_symbolic.cpython-39.pyc b/symbolic_evaluation/__pycache__/cobweb_symbolic.cpython-39.pyc
new file mode 100644
index 0000000..459122f
Binary files /dev/null and b/symbolic_evaluation/__pycache__/cobweb_symbolic.cpython-39.pyc differ
diff --git a/symbolic_evaluation/cobweb_symbolic.py b/symbolic_evaluation/cobweb_symbolic.py
new file mode 100644
index 0000000..304526c
--- /dev/null
+++ b/symbolic_evaluation/cobweb_symbolic.py
@@ -0,0 +1,108 @@
+import torch
+import numpy as np
+from torchvision import transforms, datasets
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+import untils
+from cobweb.cobweb_continuous import CobwebContinuousTree
+import json
+import base64
+import matplotlib.pyplot as plt
+from io import BytesIO
+from copy import deepcopy
+
+class CobwebSymbolic():
+ def __init__(self, input_dim, depth=5):
+ self.input_dim = input_dim
+ self.depth = depth
+ self.tree = CobwebContinuousTree(size=self.input_dim, covar_from=2, depth=self.depth, branching_factor=10)
+
+ def train(self, train_data, epochs=10):
+ train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
+ for epoch in range(epochs):
+ for (x, y) in tqdm(train_loader):
+ x = x.view(-1).numpy()
+ self.tree.ifit(x)
+
+ def save_tree_to_json(self, filename):
+ # Convert tree to serializable format
+ def convert_node_to_dict(node):
+ if node is None:
+ return None
+
+ node_dict = {
+ "mean": node.mean.tolist() if isinstance(node.mean, np.ndarray) else node.mean,
+ "sum_sq": node.sum_sq.tolist() if isinstance(node.sum_sq, np.ndarray) else node.sum_sq,
+ "count": node.count.tolist() if isinstance(node.count, np.ndarray) else node.count,
+ "children": []
+ }
+
+ if hasattr(node, 'children'):
+ for child in node.children:
+ node_dict["children"].append(convert_node_to_dict(child))
+
+ return node_dict
+
+ # Convert the tree to a dictionary
+ tree_dict = convert_node_to_dict(self.tree.root)
+
+ # Save to file
+ with open(filename, 'w') as f:
+ json.dump(tree_dict, f)
+
+ def load_tree_in_torch(self, filename):
+ with open(filename, 'r') as f:
+ temp_tree = json.load(f)
+
+ pq = [temp_tree]
+
+ while True:
+ curr = pq.pop(0)
+ curr["mean"] = torch.tensor(curr["mean"])
+ curr["sum_sq"] = torch.tensor(curr["sum_sq"])
+ curr["count"] = torch.tensor(curr["count"])
+ curr["logvar"] = torch.log(curr["sum_sq"] / curr["count"])
+
+ if "children" not in curr or not curr["children"]:
+ break
+
+ for child in curr["children"]:
+ pq.append(child)
+
+ self.tree = temp_tree
+
+ def tensor_to_base64(self, tensor, shape, cmap="gray", normalize=False):
+ array = tensor.numpy().reshape(shape)
+ if normalize:
+ plt.imshow(array, cmap=cmap, aspect="auto")
+ else:
+ plt.imshow(array, cmap=cmap, aspect="auto", vmin=0, vmax=1)
+
+ plt.axis("off")
+
+ buf = BytesIO()
+ plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
+ plt.close()
+
+ buf.seek(0)
+ return base64.b64encode(buf.getvalue()).decode("utf-8")
+
+ def viz_cobweb_tree(self, viz_filename):
+ temp_tree = deepcopy(self.tree)
+ pq = [temp_tree]
+ while True:
+ curr = pq.pop(0)
+ curr["image"] = self.tensor_to_base64(torch.tensor(curr["mean"]), (28, 28), cmap="inferno", normalize=True)
+ curr.pop("mean")
+ curr.pop("sum_sq")
+ curr.pop("count")
+ curr.pop("logvar")
+
+ if "children" not in curr or not curr["children"]:
+ break
+
+ for child in curr["children"]:
+ pq.append(child)
+
+ with open(f'{viz_filename}.json', 'w') as f:
+ json.dump(temp_tree, f)
\ No newline at end of file
diff --git a/symbolic_evaluation/evaluate_cobweb_symbolic.py b/symbolic_evaluation/evaluate_cobweb_symbolic.py
new file mode 100644
index 0000000..4e054bc
--- /dev/null
+++ b/symbolic_evaluation/evaluate_cobweb_symbolic.py
@@ -0,0 +1,537 @@
+ import torch
+import numpy as np
+from torchvision import transforms, datasets
+from torch.utils.data import DataLoader
+import matplotlib.pyplot as plt
+from sklearn.metrics import accuracy_score, confusion_matrix
+import seaborn as sns
+from tqdm import tqdm
+import json
+import os
+from cobweb_symbolic import CobwebSymbolic
+import untils
+
+# Helper functions for evaluation
+def safe_categorize(model, sample):
+ """
+ Safely categorize a sample using the model tree, with error handling.
+
+ Args:
+ model: Trained CobwebSymbolic model
+ sample: Input sample with label
+
+ Returns:
+ Node in the tree or None if categorization failed
+ """
+ try:
+ # Method 1: Try using _cobweb_categorize directly
+ if hasattr(model.tree, '_cobweb_categorize'):
+ return model.tree._cobweb_categorize(sample)
+ # Method 2: Try using categorize method
+ elif hasattr(model.tree, 'categorize'):
+ return model.tree.categorize(sample)
+ # Method 3: Try calling predict (which might use cobweb internally)
+ elif hasattr(model, 'predict'):
+ prediction = model.predict(sample[:-1]) # Exclude label for predict
+ return {"prediction": prediction}
+ else:
+ print(" No categorization method found")
+ return None
+ except Exception as e:
+ print(f" Categorization error: {e}")
+ return None
+
+def visualize_confusion_matrix(test_labels, predictions, title):
+ """
+ Visualize confusion matrix for model evaluation.
+
+ Args:
+ test_labels: True labels
+ predictions: Predicted labels
+ title: Title for the confusion matrix plot
+ """
+ plt.figure(figsize=(10, 8))
+ cm = confusion_matrix(test_labels, predictions)
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
+ plt.xlabel('Predicted Label')
+ plt.ylabel('True Label')
+ plt.title(title)
+ plt.show()
+
+def preprocess_dataset(dataset_name, split_classes=[0, 1, 2, 3]):
+ """
+ Load and preprocess dataset with symbolic feature encoding.
+
+ Args:
+ dataset_name: 'mnist' or 'cifar10'
+ split_classes: List of classes to use
+
+ Returns:
+ train_loader, test_loader, image_shape
+ """
+ print(f"Loading and preprocessing {dataset_name} dataset...")
+
+ if dataset_name.lower() == 'mnist':
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.1307,), (0.3081,))
+ ])
+
+ train_dataset = datasets.MNIST('data/MNIST', train=True, download=True, transform=transform)
+ test_dataset = datasets.MNIST('data/MNIST', train=False, download=True, transform=transform)
+ image_shape = (28, 28)
+
+ elif dataset_name.lower() == 'cifar10':
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
+ ])
+
+ train_dataset = datasets.CIFAR10('data/CIFAR10', train=True, download=True, transform=transform)
+ test_dataset = datasets.CIFAR10('data/CIFAR10', train=False, download=True, transform=transform)
+ image_shape = (32, 32, 3)
+
+ else:
+ raise ValueError(f"Unsupported dataset: {dataset_name}")
+
+ # Filter for specified classes
+ train_dataset = untils.filter_by_label(train_dataset, split_classes, rename_labels=True)
+ test_dataset = untils.filter_by_label(test_dataset, split_classes, rename_labels=True)
+
+ # Create data loaders
+ train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
+ test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
+
+ # Print class distribution
+ class_counts = {}
+ for _, label in train_dataset:
+ label_value = label.item() if hasattr(label, 'item') else int(label)
+ class_counts[label_value] = class_counts.get(label_value, 0) + 1
+
+ print(f"Training set class distribution:")
+ for class_idx, count in sorted(class_counts.items()):
+ print(f" Class {class_idx}: {count} samples")
+
+ return train_loader, test_loader, image_shape
+
+def train_symbolic_cobweb(train_loader, image_shape, depth=4, epochs=1):
+ """
+ Train a symbolic CobWeb model incrementally.
+
+ Args:
+ train_loader: DataLoader with training data
+ image_shape: Shape of input images
+ depth: Maximum depth of the CobWeb tree
+ epochs: Number of training epochs
+
+ Returns:
+ trained model
+ """
+ print("Training symbolic CobWeb model...")
+
+ if len(image_shape) == 2:
+ input_dim = image_shape[0] * image_shape[1]
+ else:
+ input_dim = image_shape[0] * image_shape[1] * image_shape[2]
+
+ # Initialize model
+ model = CobwebSymbolic(input_dim=input_dim, depth=depth)
+
+ # Train incrementally
+ for epoch in range(epochs):
+ print(f"Epoch {epoch+1}/{epochs}")
+ for data, label in tqdm(train_loader, desc="Training"):
+ # Flatten the input data
+ x = data.view(-1).numpy()
+ # Get label
+ # y = label.item() if hasattr(label, 'item') else int(label)
+
+ # Include label at the end of the input vector
+ # x_with_label = np.concatenate([x, np.array([y])])
+
+ # Incrementally fit the tree
+ model.tree.ifit(x_with_label)
+
+ return model
+
+def save_model_and_results(model, results, filename):
+ """
+ Save the trained model and evaluation results.
+
+ Args:
+ model: Trained CobwebSymbolic model
+ results: Evaluation results dictionary
+ filename: Base filename for saving
+ """
+ # Create directory if it doesn't exist
+ os.makedirs('results', exist_ok=True)
+
+ # Save model
+ model_file = f"results/{filename}_model.json"
+ try:
+ model.save_tree_to_json(model_file)
+ print(f"Model saved to {model_file}")
+ except Exception as e:
+ print(f"Error saving model: {e}")
+
+ # Save results
+ results_file = f"results/{filename}_results.json"
+
+ # Make all results JSON serializable
+ def make_serializable(obj):
+ if isinstance(obj, np.ndarray):
+ return obj.tolist()
+ elif isinstance(obj, (np.float64, np.float32, np.int64, np.int32)):
+ return float(obj)
+ elif isinstance(obj, dict):
+ return {str(k): make_serializable(v) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ return [make_serializable(item) for item in obj]
+ else:
+ return obj
+
+ serializable_results = make_serializable(results)
+
+ # Write to file
+ try:
+ with open(results_file, 'w') as f:
+ json.dump(serializable_results, f)
+ print(f"Results saved to {results_file}")
+ except Exception as e:
+ print(f"Error saving results: {e}")
+ print("Saving simplified results instead...")
+
+ # Fallback to basic results
+ simple_results = {}
+ for method, result in results.items():
+ if isinstance(result, dict) and "accuracy" in result:
+ simple_results[method] = {"accuracy": float(result["accuracy"])}
+ elif isinstance(result, dict) and "error" in result:
+ simple_results[method] = {"error": str(result["error"])}
+ else:
+ simple_results[method] = {"result": "unknown"}
+
+ with open(results_file, 'w') as f:
+ json.dump(simple_results, f)
+
+# Evaluation methods
+def evaluate_with_averaging(model, train_loader, test_loader, image_shape, num_classes,
+ samples_per_class, visualize=True):
+ """
+ Evaluation using sample points averaging: Select samples from each class,
+ find nodes in the tree, average them, and classify test samples.
+
+ Args:
+ model: Trained CobwebSymbolic model
+ train_loader: DataLoader with training data
+ test_loader: DataLoader with test data
+ image_shape: Shape of input images
+ num_classes: Number of classes in the dataset
+ samples_per_class: Number of sample points to take per class
+ visualize: Whether to generate visualizations
+
+ Returns:
+ Dictionary of evaluation metrics
+ """
+ print(f"Evaluating with averaging approach (samples_per_class={samples_per_class})...")
+
+ # 1. Collect training samples by class
+ class_samples = {cls: [] for cls in range(num_classes)}
+
+ for data, label in tqdm(train_loader, desc="Collecting samples by class"):
+ x = data.view(-1).numpy()
+ y = label.item() if hasattr(label, 'item') else int(label)
+
+ if y < num_classes: # Ensure valid class
+ class_samples[y].append(x)
+
+ # 2. Select samples for each class
+ selected_samples = {}
+ for cls in range(num_classes):
+ if class_samples[cls]:
+ # If samples_per_class is -1, use all samples
+ if samples_per_class == -1:
+ selected_samples[cls] = class_samples[cls]
+ print(f"Class {cls}: Using all {len(selected_samples[cls])} samples")
+ # Otherwise, select random samples if enough are available
+ elif len(class_samples[cls]) >= samples_per_class:
+ indices = np.random.choice(len(class_samples[cls]), samples_per_class, replace=False)
+ selected_samples[cls] = [class_samples[cls][i] for i in indices]
+ print(f"Class {cls}: Selected {len(selected_samples[cls])} samples")
+ else:
+ # Use all available samples if less than requested
+ selected_samples[cls] = class_samples[cls]
+ print(f"Class {cls}: Using all {len(selected_samples[cls])} samples (fewer than requested)")
+ else:
+ print(f"Warning: No samples found for class {cls}")
+ selected_samples[cls] = []
+
+ # 3. Calculate average for each class
+ class_averages = {}
+ for cls in range(num_classes):
+ if selected_samples[cls]:
+ # Calculate average sample
+ avg_sample = np.mean(selected_samples[cls], axis=0)
+
+ # Store average representation
+ class_averages[cls] = avg_sample
+ print(f"Class {cls}: Calculated average from {len(selected_samples[cls])} samples")
+ else:
+ print(f"Warning: Cannot calculate average for class {cls}")
+
+ # Exit if no valid classes found
+ if len(class_averages) == 0:
+ print("No valid class averages found, cannot proceed with evaluation")
+ return {"accuracy": 0.0, "error": "No valid class averages"}
+
+ # 4. Visualize class averages
+ if visualize:
+ plt.figure(figsize=(15, 5))
+ plt.suptitle(f"Class Averages (samples_per_class={samples_per_class})", fontsize=16)
+
+ for cls in range(num_classes):
+ if cls in class_averages:
+ plt.subplot(1, num_classes, cls + 1)
+
+ # Reshape average sample
+ if len(image_shape) == 2:
+ img = class_averages[cls].reshape(image_shape)
+ plt.imshow(img, cmap='gray')
+ else:
+ img = class_averages[cls].reshape(image_shape)
+ plt.imshow(img)
+
+ plt.title(f"Class {cls}")
+ plt.axis('off')
+
+ plt.tight_layout(rect=[0, 0, 1, 0.95])
+ plt.show()
+
+ # 5. Evaluate on test data
+ test_labels = []
+ predictions = []
+
+ for data, label in tqdm(test_loader, desc="Evaluating test data"):
+ x = data.view(-1).numpy()
+ y = label.item() if hasattr(label, 'item') else int(label)
+ test_labels.append(y)
+
+ # Find closest class average
+ best_class = None
+ best_distance = float('inf')
+
+ for cls, avg_sample in class_averages.items():
+ try:
+ dist = np.linalg.norm(x - avg_sample)
+ if dist < best_distance:
+ best_distance = dist
+ best_class = cls
+ except ValueError:
+ # Skip if shapes don't match
+ continue
+
+ # Use most common class as fallback
+ if best_class is None:
+ best_class = np.argmax(np.bincount(test_labels))
+
+ predictions.append(best_class)
+
+ # 6. Calculate accuracy and visualize results
+ test_labels = np.array(test_labels)
+ predictions = np.array(predictions)
+ accuracy = accuracy_score(test_labels, predictions)
+
+ if visualize:
+ visualize_confusion_matrix(
+ test_labels, predictions, f'Confusion Matrix (Averaging, samples={samples_per_class})'
+ )
+
+ print(f"Averaging Approach (samples={samples_per_class}) Accuracy: {accuracy:.4f}")
+
+ return {
+ "accuracy": accuracy,
+ "confusion_matrix": confusion_matrix(test_labels, predictions).tolist()
+ }
+
+def evaluate_class_centroids(model, train_loader, test_loader, image_shape, num_classes, visualize=True):
+ """
+ Evaluate using centroids (average of all samples) from each class.
+
+ Args:
+ model: Trained CobwebSymbolic model
+ train_loader: DataLoader with training data
+ test_loader: DataLoader with test data
+ image_shape: Shape of input images
+ num_classes: Number of classes in the dataset
+ visualize: Whether to generate visualizations
+
+ Returns:
+ Dictionary of evaluation metrics
+ """
+ print("Evaluating with class centroids approach...")
+
+ # 1. Collect all training samples by class
+ class_samples = {cls: [] for cls in range(num_classes)}
+
+ for data, label in tqdm(train_loader, desc="Collecting all class samples"):
+ x = data.view(-1).numpy()
+ y = label.item() if hasattr(label, 'item') else int(label)
+
+ if y < num_classes: # Ensure valid class
+ class_samples[y].append(x)
+
+ # Print stats about collected samples
+ for cls in range(num_classes):
+ print(f"Class {cls}: Collected {len(class_samples[cls])} samples")
+
+ # 2. Calculate class centroids (average of all samples in each class)
+ class_centroids = {}
+ for cls in range(num_classes):
+ if len(class_samples[cls]) > 0:
+ class_centroids[cls] = np.mean(class_samples[cls], axis=0)
+
+ # Visualize class centroids
+ if visualize:
+ plt.figure(figsize=(15, 5))
+ plt.suptitle("Class Centroids (All Data)", fontsize=16)
+
+ for cls in range(num_classes):
+ if cls in class_centroids:
+ plt.subplot(1, num_classes, cls + 1)
+
+ # Reshape centroid
+ if len(image_shape) == 2:
+ img = class_centroids[cls].reshape(image_shape)
+ plt.imshow(img, cmap='gray')
+ else:
+ img = class_centroids[cls].reshape(image_shape)
+ plt.imshow(img)
+
+ plt.title(f"Class {cls} Centroid")
+ plt.axis('off')
+
+ plt.tight_layout(rect=[0, 0, 1, 0.95])
+ plt.show()
+
+ # 3. Evaluate on test data using centroids
+ test_labels = []
+ predictions = []
+
+ for data, label in tqdm(test_loader, desc="Evaluating with centroids"):
+ x = data.view(-1).numpy()
+ y = label.item() if hasattr(label, 'item') else int(label)
+ test_labels.append(y)
+
+ # Find closest centroid
+ best_class = None
+ best_distance = float('inf')
+
+ for cls, centroid in class_centroids.items():
+ try:
+ dist = np.linalg.norm(x - centroid)
+ if dist < best_distance:
+ best_distance = dist
+ best_class = cls
+ except ValueError:
+ continue
+
+ # Use most common class as fallback
+ if best_class is None:
+ best_class = np.argmax(np.bincount(test_labels))
+
+ predictions.append(best_class)
+
+ # 4. Calculate accuracy
+ test_labels = np.array(test_labels)
+ predictions = np.array(predictions)
+
+ accuracy = accuracy_score(test_labels, predictions)
+
+ # 5. Visualize results
+ if visualize:
+ visualize_confusion_matrix(
+ test_labels, predictions, 'Confusion Matrix (Class Centroids)'
+ )
+
+ print(f"Class Centroids Accuracy: {accuracy:.4f}")
+
+ return {
+ "accuracy": accuracy,
+ "confusion_matrix": confusion_matrix(test_labels, predictions).tolist()
+ }
+
+def main():
+ """Main function to run the symbolic CobWeb evaluation."""
+ # Configuration parameters
+ datasets = ['mnist', 'cifar10']
+ split_classes = [0, 1, 2, 3] # First 4 classes
+
+ # Default parameters
+ default_depth = 5
+ default_epochs = 3
+
+ # Sample sizes to test
+ sample_sizes = [30, 100, 500, 5000, -1] # -1 means use all available samples
+
+ for dataset_name in datasets:
+ print(f"\n\n{'='*50}")
+ print(f"Evaluating symbolic CobWeb on {dataset_name.upper()}")
+ print(f"{'='*50}\n")
+
+ # 1. Load and preprocess dataset
+ train_loader, test_loader, image_shape = preprocess_dataset(dataset_name, split_classes)
+
+ # 2. Train model
+ model = train_symbolic_cobweb(train_loader, image_shape, depth=default_depth, epochs=default_epochs)
+
+ # 3. Evaluate with different approaches
+ results = {}
+
+ # Test different sample sizes for averaging approach
+ for samples in sample_sizes:
+ method_name = f"averaging_samples_{samples}" if samples != -1 else "averaging_all_samples"
+ print(f"\nEvaluating averaging approach with {samples if samples != -1 else 'all'} samples per class...")
+
+ try:
+ results[method_name] = evaluate_with_averaging(
+ model, train_loader, test_loader, image_shape, len(split_classes),
+ samples_per_class=samples, visualize=True
+ )
+ except Exception as e:
+ print(f"Error with averaging evaluation (samples={samples}): {e}")
+ results[method_name] = {"accuracy": 0.0, "error": str(e)}
+
+ # Evaluate with class centroids approach
+ print("\nEvaluating with class centroids approach...")
+ try:
+ results["class_centroids"] = evaluate_class_centroids(
+ model, train_loader, test_loader, image_shape, len(split_classes), visualize=True
+ )
+ except Exception as e:
+ print(f"Error with class centroids evaluation: {e}")
+ results["class_centroids"] = {"error": str(e)}
+
+ # 4. Save model and results
+ save_model_and_results(model, results, f"{dataset_name}_symbolic_evaluation")
+
+ # 5. Print summary
+ print("\n" + "="*50)
+ print(f"SUMMARY OF RESULTS FOR {dataset_name.upper()}:")
+ print("="*50)
+
+ for samples in sample_sizes:
+ method_name = f"averaging_samples_{samples}" if samples != -1 else "averaging_all_samples"
+ if "error" not in results[method_name]:
+ print(f"Averaging ({samples if samples != -1 else 'all'} samples): {results[method_name]['accuracy']:.4f}")
+ else:
+ print(f"Averaging ({samples if samples != -1 else 'all'} samples): Failed")
+
+ if "error" not in results["class_centroids"]:
+ print(f"Class Centroids: {results['class_centroids']['accuracy']:.4f}")
+ else:
+ print(f"Class Centroids: Failed")
+
+ print("\n\n")
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/treevae/README.md b/treevae/README.md
new file mode 100644
index 0000000..badd37d
--- /dev/null
+++ b/treevae/README.md
@@ -0,0 +1,35 @@
+# Tree Variational Autoencoders
+This is the PyTorch repository for the NeurIPS 2023 Spotlight Publication (https://neurips.cc/virtual/2023/poster/71188).
+
+TreeVAE is a new generative method that learns the optimal tree-based posterior distribution of latent variables to capture the hierarchical structures present in the data. It adapts the architecture to discover the optimal tree for encoding dependencies between latent variables. TreeVAE optimizes the balance between shared and specialized architecture, enhancing the learning and adaptation capabilities of generative models.
+An example of a tree learned by TreeVAE is depicted in the figure below. Each edge and each split are encoded by neural networks, while the circles depict latent variables. Each sample is associated with a probability distribution over different paths of the discovered tree. The resulting tree thus organizes the data into an interpretable hierarchical structure in an unsupervised fashion, optimizing the amount of shared information between samples. In CIFAR-10, for example, the method divides the vehicles and animals into two different subtrees and similar groups (such as planes and ships) share common ancestors.
+
+
+For running TreeVAE:
+
+1. Create a new environment with the ```treevae.yml``` or ```minimal_requirements.txt``` file.
+2. Select the dataset you wish to use by changing the default config_name in the main.py parser.
+3. Potentially adapt default configuration in the config of the selected dataset (config/data_name.yml), the full set of config parameters with their explanations can be found in ```config/mnist.yml```.
+4. For Weights & Biases support, set project & entity in ```train/train.py``` and change the value of ```wandb_logging``` to ```online``` in the config file.
+5. Run ```main.py```.
+
+For exploring TreeVAE results (including the discovered tree, the generation of new images, the clustering performances and much more) we created a jupyter notebook (```tree_exploration.ipynb```):
+1. Run the steps above by setting ```save_model=True```.
+2. Copy the experiment path where the model is saved (it will be printed out).
+3. Open ```tree_exploration.ipynb```, replace the experiment path with yours, and have fun exploring the model!
+
+DISCLAIMER: This PyTorch repository was thoroughly debugged and tested, however, please note that the experiments of the submission were performed using the repository with the Tensorflow code (https://github.com/lauramanduchi/treevae-tensorflow).
+
+## Citing
+To cite TreeVAE please use the following BibTEX entries:
+
+```
+@inproceedings{
+manduchi2023tree,
+title={Tree Variational Autoencoders},
+author={Laura Manduchi and Moritz Vandenhirtz and Alain Ryser and Julia E Vogt},
+booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
+year={2023},
+url={https://openreview.net/forum?id=adq0oXb9KM}
+}
+```
diff --git a/treevae/configs/celeba.yml b/treevae/configs/celeba.yml
new file mode 100644
index 0000000..3495f5e
--- /dev/null
+++ b/treevae/configs/celeba.yml
@@ -0,0 +1,39 @@
+run_name: 'celeba'
+
+data:
+ data_name: 'celeba'
+ num_clusters_data: 1
+
+training:
+ num_epochs: 150
+ num_epochs_smalltree: 150
+ num_epochs_intermediate_fulltrain: 0
+ num_epochs_finetuning: 0
+ batch_size: 256
+ lr: 0.001
+ weight_decay: 0.00001
+ decay_lr: 0.1
+ decay_stepsize: 100
+ decay_kl: 0.01
+ kl_start: 0.01
+
+ inp_shape: 12288
+ latent_dim: [64,64,64,64,64,64]
+ mlp_layers: [4096, 512, 512, 512, 512, 512]
+ initial_depth: 1
+ activation: 'mse'
+ encoder: 'cnn2'
+ grow: True
+ prune: True
+ num_clusters_tree: 10
+ augment: True
+ augmentation_method: 'InfoNCE,instancewise_full'
+ aug_decisions_weight: 100
+ compute_ll: False
+
+globals:
+ wandb_logging: 'disabled'
+ eager_mode: False
+ seed: 42
+ save_model: True
+ config_name: 'celeba'
\ No newline at end of file
diff --git a/treevae/configs/cifar10.yml b/treevae/configs/cifar10.yml
new file mode 100644
index 0000000..3bf8b90
--- /dev/null
+++ b/treevae/configs/cifar10.yml
@@ -0,0 +1,39 @@
+run_name: 'cifar10'
+
+data:
+ data_name: 'cifar10'
+ num_clusters_data: 10
+
+training:
+ num_epochs: 150
+ num_epochs_smalltree: 150
+ num_epochs_intermediate_fulltrain: 0
+ num_epochs_finetuning: 0
+ batch_size: 256
+ lr: 0.001
+ weight_decay: 0.00001
+ decay_lr: 0.1
+ decay_stepsize: 100
+ decay_kl: 0.01
+ kl_start: 0.01
+
+ inp_shape: 3072
+ latent_dim: [64,64,64,64,64,64]
+ mlp_layers: [4096, 512, 512, 512, 512, 512]
+ initial_depth: 1
+ activation: 'mse'
+ encoder: 'cnn2'
+ grow: True
+ prune: True
+ num_clusters_tree: 10
+ augment: True
+ augmentation_method: 'InfoNCE,instancewise_full'
+ aug_decisions_weight: 100
+ compute_ll: False
+
+globals:
+ wandb_logging: 'disabled'
+ eager_mode: False
+ seed: 42
+ save_model: False
+ config_name: 'cifar10'
diff --git a/treevae/configs/cifar100.yml b/treevae/configs/cifar100.yml
new file mode 100644
index 0000000..a44955c
--- /dev/null
+++ b/treevae/configs/cifar100.yml
@@ -0,0 +1,39 @@
+run_name: 'cifar100'
+
+data:
+ data_name: 'cifar100'
+ num_clusters_data: 20
+
+training:
+ num_epochs: 150
+ num_epochs_smalltree: 150
+ num_epochs_intermediate_fulltrain: 0
+ num_epochs_finetuning: 0
+ batch_size: 256
+ lr: 0.001
+ weight_decay: 0.00001
+ decay_lr: 0.1
+ decay_stepsize: 100
+ decay_kl: 0.01
+ kl_start: 0.01
+
+ inp_shape: 3072
+ latent_dim: [64,64,64,64,64,64]
+ mlp_layers: [4096, 512, 512, 512, 512, 512]
+ initial_depth: 1
+ activation: 'mse'
+ encoder: 'cnn2'
+ grow: True
+ prune: True
+ num_clusters_tree: 20
+ augment: True
+ augmentation_method: 'InfoNCE,instancewise_full'
+ aug_decisions_weight: 100
+ compute_ll: False
+
+globals:
+ wandb_logging: 'disabled'
+ eager_mode: False
+ seed: 42
+ save_model: False
+ config_name: 'cifar100'
diff --git a/treevae/configs/fmnist.yml b/treevae/configs/fmnist.yml
new file mode 100644
index 0000000..56c2459
--- /dev/null
+++ b/treevae/configs/fmnist.yml
@@ -0,0 +1,39 @@
+run_name: 'fmnist'
+
+data:
+ data_name: 'fmnist'
+ num_clusters_data: 10
+
+training:
+ num_epochs: 150
+ num_epochs_smalltree: 150
+ num_epochs_intermediate_fulltrain: 80
+ num_epochs_finetuning: 200
+ batch_size: 256
+ lr: 0.001
+ weight_decay: 0.00001
+ decay_lr: 0.1
+ decay_stepsize: 100
+ decay_kl: 0.001
+ kl_start: 0.0
+
+ inp_shape: 784
+ latent_dim: [8, 8, 8, 8, 8, 8]
+ mlp_layers: [128, 128, 128, 128, 128, 128]
+ initial_depth: 1
+ activation: "sigmoid"
+ encoder: 'cnn1'
+ grow: True
+ prune: True
+ num_clusters_tree: 10
+ compute_ll: False
+ augment: False
+ augmentation_method: 'simple'
+ aug_decisions_weight: 1
+
+globals:
+ wandb_logging: 'disabled'
+ eager_mode: True
+ seed: 42
+ save_model: False
+ config_name: 'fmnist'
\ No newline at end of file
diff --git a/treevae/configs/mnist.yml b/treevae/configs/mnist.yml
new file mode 100644
index 0000000..327ef89
--- /dev/null
+++ b/treevae/configs/mnist.yml
@@ -0,0 +1,39 @@
+run_name: 'mnist' # name of the run
+
+data:
+ data_name: 'mnist' # name of the dataset
+ num_clusters_data: 10 # number of true clusters in the data (if known), this is used only for evaluation purposes
+
+training:
+ num_epochs: 150 # number of epochs to train the initial tree
+ num_epochs_smalltree: 150 # number of epochs to train the sub-tree during growing
+ num_epochs_intermediate_fulltrain: 80 # number of epochs to train the full tree during growing
+ num_epochs_finetuning: 200 # number of epochs to train the final tree
+ batch_size: 256 # batch size
+ lr: 0.001 # learning rate
+ weight_decay: 0.00001 # optimizer weight decay
+ decay_lr: 0.1 # learning rate decay
+ decay_stepsize: 100 # number of epochs after which learning rate decays
+ decay_kl: 0.001 # KL-annealing weight increase per epoch (capped at 1)
+ kl_start: 0.0 # KL-annealing weight initialization
+
+ inp_shape: 784 # The total dimensions of the input data (if rgb images of 32x32 then 32x32x3)
+ latent_dim: [8, 8, 8, 8, 8, 8] # A list of latent dimensions for each depth of the tree from the bottom to the root, last value is the dimensionality of the root node
+ mlp_layers: [128, 128, 128, 128, 128, 128] # A list of hidden units number for the MLP transformations for each depth of the tree from bottom to root
+ initial_depth: 1 # The initial depth of the tree (root has depth 0 and a root with two leaves has depth 1)
+ activation: "sigmoid" # The name of the activation function for the reconstruction loss [sigmoid, mse]
+ encoder: 'cnn1' # Type of encoder/decoder used
+ grow: True # Whether to grow the tree
+ prune: True # Whether to prune the tree of empty leaves
+ num_clusters_tree: 10 # The maximum number of leaves of the final tree
+ compute_ll: False # Whether to compute the log-likelihood estimation at the end of the training (it might take some time)
+ augment: False # Whether to use contrastive learning through augmentation
+ augmentation_method: 'simple' # The type of augmentation method used if augment is True
+ aug_decisions_weight: 1 # The weight of the contrastive losses
+
+globals:
+ wandb_logging: 'disabled' # Whether to log to wandb [online, offline, disabled]
+ eager_mode: True # Whether to run in eager or graph mode
+ seed: 42 # Random seed
+ save_model: False # Whether to save the model. Set to True for inspecting models in notebook
+ config_name: 'mnist'
diff --git a/treevae/configs/news20.yml b/treevae/configs/news20.yml
new file mode 100644
index 0000000..987a8fc
--- /dev/null
+++ b/treevae/configs/news20.yml
@@ -0,0 +1,39 @@
+run_name: 'news20'
+
+data:
+ data_name: 'news20'
+ num_clusters_data: 20
+
+training:
+ num_epochs: 150
+ num_epochs_smalltree: 150
+ num_epochs_intermediate_fulltrain: 80
+ num_epochs_finetuning: 200
+ batch_size: 256
+ lr: 0.001
+ weight_decay: 0.00001
+ decay_lr: 0.1
+ decay_stepsize: 100
+ decay_kl: 0.001
+ kl_start: 0.0
+
+ inp_shape: 2000
+ latent_dim: [4, 4, 4, 4, 4, 4, 4]
+ mlp_layers: [128, 128, 128, 128, 128, 128, 128]
+ initial_depth: 1
+ activation: "sigmoid"
+ encoder: 'mlp'
+ grow: True
+ prune: True
+ num_clusters_tree: 20
+ compute_ll: False
+ augment: False
+ augmentation_method: 'simple'
+ aug_decisions_weight: 1
+
+globals:
+ wandb_logging: 'disabled'
+ eager_mode: True
+ seed: 42
+ save_model: False
+ config_name: 'news20'
\ No newline at end of file
diff --git a/treevae/configs/omniglot.yml b/treevae/configs/omniglot.yml
new file mode 100644
index 0000000..65f55ba
--- /dev/null
+++ b/treevae/configs/omniglot.yml
@@ -0,0 +1,41 @@
+run_name: 'omniglot'
+
+data:
+ data_name: 'omniglot'
+ num_clusters_data: 5
+ path: 'datasets/omniglot'
+
+training:
+ num_epochs: 150
+ num_epochs_smalltree: 150
+ num_epochs_intermediate_fulltrain: 80
+ num_epochs_finetuning: 200
+ batch_size: 256
+ lr: 0.001
+ weight_decay: 0.00001
+ decay_lr: 0.1
+ decay_stepsize: 100
+ decay_kl: 0.001
+ kl_start: 0.001
+
+ inp_shape: 784
+ latent_dim: [8, 8, 8, 8, 8, 8]
+ mlp_layers: [128, 128, 128, 128, 128, 128]
+ initial_depth: 1
+ activation: "sigmoid"
+ encoder: 'cnn_omni'
+ grow: True
+ prune: True
+ num_clusters_tree: 5
+ compute_ll: False
+ augment: True
+ augmentation_method: 'simple'
+ aug_decisions_weight: 1
+
+
+globals:
+ wandb_logging: 'online'
+ eager_mode: True
+ seed: 42
+ save_model: False
+ config_name: 'omniglot'
\ No newline at end of file
diff --git a/treevae/evaluate_treevae.ipynb b/treevae/evaluate_treevae.ipynb
new file mode 100644
index 0000000..9f0d4fe
--- /dev/null
+++ b/treevae/evaluate_treevae.ipynb
@@ -0,0 +1,152 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import torch\n",
+ "import numpy as np\n",
+ "import yaml\n",
+ "from sklearn.metrics import accuracy_score, normalized_mutual_info_score\n",
+ "from sklearn.linear_model import LogisticRegression\n",
+ "\n",
+ "from models.model import TreeVAE\n",
+ "from utils.model_utils import construct_tree_fromnpy\n",
+ "from utils.data_utils import get_data, get_gen\n",
+ "from utils.training_utils import predict\n",
+ "from utils.utils import cluster_acc\n",
+ "\n",
+ "checkpoint_path = 'models/experiments/mnist/20231025-175819_d6be9' # ← 请替换为你的路径\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "\n",
+ "with open(os.path.join(checkpoint_path, \"config.yaml\"), 'r') as f:\n",
+ " configs = yaml.load(f, Loader=yaml.Loader)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = TreeVAE(**configs['training'])\n",
+ "model.load_state_dict(torch.load(os.path.join(checkpoint_path, \"model_weights.pt\"), map_location=device), strict=True)\n",
+ "\n",
+ "tree_structure = np.load(os.path.join(checkpoint_path, \"data_tree.npy\"), allow_pickle=True)\n",
+ "model = construct_tree_fromnpy(model, tree_structure, configs)\n",
+ "\n",
+ "model.to(device)\n",
+ "model.eval()\n",
+ "print(\"model loaded\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trainset, trainset_eval, testset = get_data(configs)\n",
+ "gen_train_eval = get_gen(trainset_eval, configs, validation=True, shuffle=False)\n",
+ "gen_test = get_gen(testset, configs, validation=True, shuffle=False)\n",
+ "\n",
+ "y_train = trainset_eval.dataset.targets[trainset_eval.indices].numpy()\n",
+ "y_test = testset.dataset.targets[testset.indices].numpy()\n",
+ "\n",
+ "print(f\"data loaded | Train: {len(y_train)}, Test: {len(y_test)}\")\n",
+ "\n",
+ "train_acc = cluster_acc(y_train, trainset_eval.dataset.targets[trainset_eval.indices].numpy())\n",
+ "test_acc = cluster_acc(y_test, testset.dataset.targets[testset.indices].numpy())\n",
+ "\n",
+ "print(f\"cluster accuracy | Train: {train_acc:.3f}, Test: {test_acc:.3f}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "_, y_pred = predict(gen_test, model, device)\n",
+ "\n",
+ "acc = cluster_acc(y_test, y_pred.numpy(), return_index=False)\n",
+ "nmi = normalized_mutual_info_score(y_test, y_pred.numpy())\n",
+ "\n",
+ "print(f\"Clustering ACC: {acc:.4f}\")\n",
+ "print(f\"NMI: {nmi:.4f}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "z_train = predict(gen_train_eval, model, device, 'bottom_up')[-1].cpu().numpy()\n",
+ "z_test = predict(gen_test, model, device, 'bottom_up')[-1].cpu().numpy()\n",
+ "\n",
+ "clf = LogisticRegression(max_iter=1000)\n",
+ "clf.fit(z_train, y_train)\n",
+ "y_lp_pred = clf.predict(z_test)\n",
+ "\n",
+ "lp_acc = accuracy_score(y_test, y_lp_pred)\n",
+ "print(f\"Linear Probe Accuracy: {lp_acc:.4f}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prob_train = predict(gen_train_eval, model, device, 'prob_leaves')\n",
+ "prob_test = predict(gen_test, model, device, 'prob_leaves')\n",
+ "\n",
+ "leaf_test = prob_test.argmax(axis=1)\n",
+ "\n",
+ "def compute_dp_score(labels, leaves):\n",
+ " count = 0\n",
+ " total = 0\n",
+ " for i in range(len(labels)):\n",
+ " for j in range(i+1, len(labels)):\n",
+ " if labels[i] == labels[j]:\n",
+ " total += 1\n",
+ " if leaves[i] == leaves[j]:\n",
+ " count += 1\n",
+ " return count / total if total > 0 else 0.0\n",
+ "\n",
+ "dp_score = compute_dp_score(y_test, leaf_test)\n",
+ "print(f\"Decision Path Agreement: {dp_score:.4f}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"======== TreeVAE Final Evaluation ========\")\n",
+ "print(f\"Clustering ACC : {acc:.4f}\")\n",
+ "print(f\"NMI : {nmi:.4f}\")\n",
+ "print(f\"Linear Probe ACC : {lp_acc:.4f}\")\n",
+ "print(f\"Decision Path Agree : {dp_score:.4f}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "new_env",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.9.20"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/treevae/main.py b/treevae/main.py
new file mode 100644
index 0000000..f66eb1b
--- /dev/null
+++ b/treevae/main.py
@@ -0,0 +1,62 @@
+"""
+Runs the treeVAE model.
+"""
+import argparse
+from pathlib import Path
+import distutils
+
+from train.train import run_experiment
+from utils.utils import prepare_config
+
+
+def main():
+ project_dir = Path(__file__).absolute().parent
+ print("Project directory:", project_dir)
+
+ parser = argparse.ArgumentParser()
+
+ # Model parameters
+ parser.add_argument('--data_name', type=str, help='the dataset')
+ parser.add_argument('--num_epochs', type=int, help='the number of training epochs')
+ parser.add_argument('--num_epochs_finetuning', type=int, help='the number of finetuning epochs')
+ parser.add_argument('--num_epochs_intermediate_fulltrain', type=int, help='the number of finetuning epochs during training')
+ parser.add_argument('--num_epochs_smalltree', type=int, help='the number of sub-tree training epochs')
+
+ parser.add_argument('--num_clusters_data', type=int, help='the number of clusters in the data')
+ parser.add_argument('--num_clusters_tree', type=int, help='the max number of leaves of the tree')
+
+ parser.add_argument('--kl_start', type=float, nargs='?', const=0.,
+ help='initial KL divergence from where annealing starts')
+ parser.add_argument('--decay_kl', type=float, help='KL divergence annealing')
+ parser.add_argument('--latent_dim', type=str, help='specifies the latent dimensions of the tree')
+ parser.add_argument('--mlp_layers', type=str, help='specifies how many layers should the MLPs have')
+
+ parser.add_argument('--grow', type=lambda x: bool(distutils.util.strtobool(x)), help='whether to grow the tree')
+ parser.add_argument('--augment', type=lambda x: bool(distutils.util.strtobool(x)), help='augment images or not')
+ parser.add_argument('--augmentation_method', type=str, help='none vs simple augmentation vs contrastive approaches')
+ parser.add_argument('--aug_decisions_weight', type=float,
+ help='weight of similarity regularizer for augmented images')
+ parser.add_argument('--compute_ll', type=lambda x: bool(distutils.util.strtobool(x)),
+ help='whether to compute the log-likelihood')
+
+ # Other parameters
+ parser.add_argument('--save_model', type=lambda x: bool(distutils.util.strtobool(x)),
+ help='specifies if the model should be saved')
+ parser.add_argument('--eager_mode', type=lambda x: bool(distutils.util.strtobool(x)),
+ help='specifies if the model should be run in graph or eager mode')
+ parser.add_argument('--num_workers', type=int, help='number of workers in dataloader')
+ parser.add_argument('--seed', type=int, help='random number generator seed')
+ parser.add_argument('--wandb_logging', type=str, help='online, disabled, offline enables logging in wandb')
+
+ # Specify config name
+ parser.add_argument('--config_name', default='mnist', type=str,
+ choices=['mnist', 'fmnist', 'news20', 'omniglot', 'cifar10', 'cifar100', 'celeba'],
+ help='the override file name for config.yml')
+
+ args = parser.parse_args()
+ configs = prepare_config(args, project_dir)
+ run_experiment(configs)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/treevae/minimal_requirements.txt b/treevae/minimal_requirements.txt
new file mode 100644
index 0000000..3427b94
--- /dev/null
+++ b/treevae/minimal_requirements.txt
@@ -0,0 +1,14 @@
+attrs==22.1.0
+cudatoolkit==11.7 # conda install cudatoolkit=11.7 -c pytorch -c nvidia
+cudnn==8.9.2.26
+numpy==1.23.4
+Pillow==9.3.0
+python==3.9.15 # Do this first
+PyYAML==6.0
+scikit_learn==1.1.3
+scipy==1.10.0
+torch==2.0.1 # conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
+torchmetrics==1.0.3
+torchvision==0.15.2
+tqdm==4.65.0
+wandb==0.13.5
\ No newline at end of file
diff --git a/treevae/models/__pycache__/losses.cpython-39.pyc b/treevae/models/__pycache__/losses.cpython-39.pyc
new file mode 100644
index 0000000..8d562e0
Binary files /dev/null and b/treevae/models/__pycache__/losses.cpython-39.pyc differ
diff --git a/treevae/models/__pycache__/model.cpython-39.pyc b/treevae/models/__pycache__/model.cpython-39.pyc
new file mode 100644
index 0000000..7a45c98
Binary files /dev/null and b/treevae/models/__pycache__/model.cpython-39.pyc differ
diff --git a/treevae/models/__pycache__/networks.cpython-39.pyc b/treevae/models/__pycache__/networks.cpython-39.pyc
new file mode 100644
index 0000000..70157cb
Binary files /dev/null and b/treevae/models/__pycache__/networks.cpython-39.pyc differ
diff --git a/treevae/models/losses.py b/treevae/models/losses.py
new file mode 100644
index 0000000..29db55d
--- /dev/null
+++ b/treevae/models/losses.py
@@ -0,0 +1,34 @@
+"""
+Loss functions for the reconstruction term of the ELBO.
+"""
+import torch
+import torch.nn.functional as F
+
+
+def loss_reconstruction_binary(x, x_decoded_mean, weights):
+ x = torch.flatten(x, start_dim=1)
+ x_decoded_mean = [torch.flatten(decoded_leaf, start_dim=1) for decoded_leaf in x_decoded_mean]
+ loss = torch.sum(
+ torch.stack([weights[i] *
+ F.binary_cross_entropy(input = x_decoded_mean[i], target = x, reduction='none').sum(dim=-1)
+ for i in range(len(x_decoded_mean))], dim=-1), dim=-1)
+ return loss
+
+def loss_reconstruction_mse(x, x_decoded_mean, weights):
+ x = torch.flatten(x, start_dim=1)
+ x_decoded_mean = [torch.flatten(decoded_leaf, start_dim=1) for decoded_leaf in x_decoded_mean]
+ loss = torch.sum(
+ torch.stack([weights[i] *
+ F.mse_loss(input = x_decoded_mean[i], target = x, reduction='none').sum(dim=-1)
+ for i in range(len(x_decoded_mean))], dim=-1), dim=-1)
+ return loss
+
+def loss_reconstruction_cov_mse_eval(x, x_decoded_mean, weights):
+ # NOTE Only use for evaluation purposes, as the clamping stops gradient flow
+ # NOTE WE ASSUME IDENTITY MATRIX BECAUSE WE ASSUME THIS IMPLICITLY WHEN ONLY OPTIMIZING MSE
+ scale = torch.diag(torch.ones_like(x_decoded_mean[0]))
+ logpx = torch.zeros_like(weights[0])
+ for i in range(len(x_decoded_mean)):
+ x_dist = torch.distributions.multivariate_normal.MultivariateNormal(loc=torch.clamp(x_decoded_mean[i],0,1), covariance_matrix=scale)
+ logpx = logpx + weights[i] * x_dist.log_prob(x)
+ return logpx
diff --git a/treevae/models/model.py b/treevae/models/model.py
new file mode 100644
index 0000000..53d3a24
--- /dev/null
+++ b/treevae/models/model.py
@@ -0,0 +1,602 @@
+"""
+TreeVAE model.
+"""
+import torch
+import torch.nn as nn
+import torch.distributions as td
+from utils.model_utils import construct_tree, compute_posterior
+from models.networks import get_encoder, get_decoder, MLP, Router, Dense
+from models.losses import loss_reconstruction_binary, loss_reconstruction_mse
+from utils.model_utils import return_list_tree
+from utils.training_utils import calc_aug_loss
+
+class TreeVAE(nn.Module):
+ """
+ A class used to represent a tree-based VAE.
+
+ TreeVAE specifies a Variational Autoencoder with a tree structure posterior distribution of latent variables.
+ It is defined by a bottom-up chain of deterministic transformations that from the input x compute the root
+ representation of the data, and a probabilistic top-down architecture which takes the form of a tree. The
+ top down tree is described by the probability distribution of its node (which depends on their parents) and the
+ probability distribution of the decisions (what is the probability of following a certain path in the tree).
+ Each node of the tree is described by the class Node in utils.model_utils.
+
+ Attributes
+ ----------
+ activation : str
+ The name of the activation function for the reconstruction loss [sigmoid, mse]
+ loss : models.losses
+ The loss function used by the decoder to reconstruct the input
+ alpha : float
+ KL-annealing weight initialization
+ encoded_sizes : list
+ A list of latent dimensions for each depth of the tree from the bottom to the root
+ hidden_layers : float
+ A list of hidden units number for the MLP transformations for each depth of the tree from bottom to root
+ depth : int
+ The depth of the tree (root has depth 0 and a root with two leaves has depth 1)
+ inp_shape : int
+ The total dimensions of the input data (if images of 32x32 then 32x32x3)
+ augment : bool
+ Whether to use contrastive learning through augmentation, if False no augmentation is used
+ augmentation_method : str
+ The type of augmentation method used
+ aug_decisions_weight : str
+ The weight of the contrastive loss used in the decisions
+ return_x : float
+ Whether to return the input in the return dictionary of the forward method
+ return_elbo : float
+ Whether to return the sample-specific elbo in the return dictionary of the forward method
+ return_bottomup : float
+ Whether to return the list of bottom-up transformations (including encoder)
+ bottom_up : str
+ The list of bottom-up transformations [encoder, MLP, MLP, ...] up to the root
+ contrastive_mlp : list
+ The list of transformations from the bottom-up embeddings to the latent spaces
+ in which the contrastive losses are applied
+ transformations : list
+ List of transformations (MLPs) associated with each node of the tree from root to bottom (left to right)
+ denses : list
+ List of dense layers for the sharing of top-down and bottom-up (MLPs) associated with each node of the tree
+ from root to bottom (left to right).
+ decisions : list
+ List of decisions associated with each node of the tree from root to bottom (left to right)
+ decoders : list
+ List of decoders one for each leaf
+ decisions_q : str
+ List of decisions of the bottom-up associated with each node of the tree from root to bottom (left to right)
+ tree : utils.model_utils.Node
+ The root node of the tree
+
+ Methods
+ -------
+ forward(x)
+ Compute the forward pass of the treeVAE model and return a dictionary of losses and optional outputs
+ (like input, bottom-up and sample-specific elbo) when needed.
+ compute_leaves()
+ Return a list of leaf-nodes from left to right of the current tree (self.tree).
+ compute_depth()
+ Calculate the depth of the given tree (self.tree).
+ attach_smalltree(node, small_model)
+ Attach a sub tree (small_model) to the given node of the current tree.
+ compute_reconstruction(x)
+ Given the input x, it computes the reconstructions.
+ generate_images(n_samples, device)
+ Generate n_samples new images by sampling from the root and propagating through the entire tree.
+ """
+
+ def __init__(self, **kwargs):
+ """
+ Parameters
+ ----------
+ kwargs : dict
+ A dictionary of attributes (see config file).
+ """
+ super(TreeVAE, self).__init__()
+ self.kwargs = kwargs
+
+ self.activation = self.kwargs['activation']
+ if self.activation == "sigmoid":
+ self.loss = loss_reconstruction_binary
+ elif self.activation == "mse":
+ self.loss = loss_reconstruction_mse
+ else:
+ raise NotImplementedError
+ # KL-annealing weight initialization
+ self.alpha = torch.tensor(self.kwargs['kl_start'])
+
+ # saving important variables to initialize the tree
+ self.encoded_sizes = self.kwargs['latent_dim']
+ self.hidden_layers = self.kwargs['mlp_layers']
+ # check that the number of layers for bottom up is equal to top down
+ if len(self.encoded_sizes) != len(self.hidden_layers):
+ raise ValueError('Model is mispecified!!')
+ self.depth = self.kwargs['initial_depth']
+ self.inp_shape = self.kwargs['inp_shape']
+ self.augment = self.kwargs['augment']
+ self.augmentation_method = self.kwargs['augmentation_method']
+ self.aug_decisions_weight = self.kwargs['aug_decisions_weight']
+ self.return_x = torch.tensor([False])
+ self.return_bottomup = torch.tensor([False])
+ self.return_elbo = torch.tensor([False])
+
+ # bottom up: the inference chain that from input computes the d units till the root
+ if self.activation == "mse":
+ size = int((self.inp_shape / 3)**0.5)
+ encoder = get_encoder(architecture=self.kwargs['encoder'], encoded_size=self.hidden_layers[0],
+ x_shape=self.inp_shape, size=size)
+ else:
+ encoder = get_encoder(architecture=self.kwargs['encoder'], encoded_size=self.hidden_layers[0],
+ x_shape=self.inp_shape)
+
+ self.bottom_up = nn.ModuleList([encoder])
+ for i in range(1, len(self.hidden_layers)):
+ self.bottom_up.append(MLP(self.hidden_layers[i-1], self.encoded_sizes[i], self.hidden_layers[i]))
+
+ # MLP's if we use contrastive loss on d's
+ if len([i for i in self.augmentation_method if i in ['instancewise_first', 'instancewise_full']]) > 0:
+ self.contrastive_mlp = nn.ModuleList([])
+ for i in range(0, len(self.hidden_layers)):
+ self.contrastive_mlp.append(MLP(input_size=self.hidden_layers[i], encoded_size=self.encoded_sizes[i], hidden_unit=min(self.hidden_layers)))
+
+ # top down: the generative model that from x computes the prior prob of all nodes from root till leaves
+ # it has a tree structure which is constructed by passing a list of transformations and routers from root to
+ # leaves visiting nodes layer-wise from left to right
+ # N.B. root has None as transformation and leaves have None as routers
+ # the encoded sizes and layers are reversed from bottom up
+ # e.g. for bottom up [MLP(256, 32), MLP(128, 16), MLP(64, 8)] the list of top-down transformations are
+ # [None, MLP(16, 64), MLP(16, 64), MLP(32, 128), MLP(32, 128), MLP(32, 128), MLP(32, 128)]
+
+ # select the top down generative networks
+ encoded_size_gen = self.encoded_sizes[-(self.depth+1):] # e.g. encoded_sizes 32,16,8, depth 1
+ encoded_size_gen = encoded_size_gen[::-1] # encoded_size_gen = 16,8 => 8,16
+ layers_gen = self.hidden_layers[-(self.depth+1):] # e.g. encoded_sizes 256,128,64, depth 1
+ layers_gen = layers_gen[::-1] # encoded_size_gen = 128,64 => 64,128
+
+ # add root transformation and dense layer, the dense layer is layer that connects the bottom-up with the nodes
+ self.transformations = nn.ModuleList([None])
+ self.denses = nn.ModuleList([Dense(layers_gen[0], encoded_size_gen[0])])
+ # attach the rest of transformations and dense layers for each node
+ for i in range(self.depth):
+ for j in range(2 ** (i + 1)):
+ self.transformations.append(MLP(encoded_size_gen[i], encoded_size_gen[i+1], layers_gen[i])) # MLP from depth i to i+1
+ self.denses.append(Dense(layers_gen[i+1], encoded_size_gen[i+1])) # Dense at depth i+1 from bottom-up to top-down
+
+ # compute the list of decisions for both bottom-up (decisions_q) and top-down (decisions)
+ # for each node of the tree
+ self.decisions = nn.ModuleList([])
+ self.decisions_q = nn.ModuleList([])
+ for i in range(self.depth):
+ for _ in range(2 ** i):
+ self.decisions.append(Router(encoded_size_gen[i], hidden_units=layers_gen[i])) # Router at node of depth i
+ self.decisions_q.append(Router(layers_gen[i], hidden_units=layers_gen[i]))
+ # the leaves do not have decisions (we set it to None)
+ for _ in range(2 ** (self.depth)):
+ self.decisions.append(None)
+ self.decisions_q.append(None)
+
+ # compute the list of decoders to attach to each node, note that internal nodes do not have a decoder
+ # e.g. for a tree with depth 2: decoders = [None, None, None, Dec, Dec, Dec, Dec]
+ self.decoders = nn.ModuleList([None for i in range(self.depth) for j in range(2 ** i)])
+ for _ in range(2 ** (self.depth)):
+ self.decoders.append(get_decoder(architecture=self.kwargs['encoder'], input_shape=encoded_size_gen[-1],
+ output_shape=self.inp_shape, activation=self.activation))
+
+ # construct the tree
+ self.tree = construct_tree(transformations=self.transformations, routers=self.decisions,
+ routers_q=self.decisions_q, denses=self.denses, decoders=self.decoders)
+
+ def forward(self, x):
+ """
+ Forward pass of the treeVAE model.
+
+ Parameters
+ ----------
+ x : tensor
+ Input data (batch-size, input-size)
+
+ Returns
+ -------
+ dict
+ a dictionary
+ {'rec_loss': reconstruction loss,
+ 'kl_root': the KL loss of the root,
+ 'kl_decisions': the KL loss of the decisions,
+ 'kl_nodes': the KL loss of the nodes,
+ 'aug_decisions': the weighted contrastive loss,
+ 'p_c_z': the probability of each sample to be assigned to each leaf with size: #samples x #leaves,
+ 'node_leaves': a list of leaf nodes, each one described by a dictionary
+ {'prob': sample-wise probability of reaching the node, 'z_sample': sampled leaf embedding}
+ }
+ """
+ # Small constant to prevent numerical instability
+ epsilon = 1e-7
+ device = x.device
+
+ # compute deterministic bottom up
+ d = x
+ encoders = []
+ emb_contr = []
+
+ for i in range(0, len(self.hidden_layers)):
+ d, _, _ = self.bottom_up[i](d)
+ # store bottom-up embeddings for top-down
+ encoders.append(d)
+
+ # pass through contrastive MLP's if contrastive learning is selected
+ if 'instancewise_full' in self.augmentation_method:
+ _, emb_c, _ = self.contrastive_mlp[i](d)
+ emb_contr.append(emb_c)
+ elif 'instancewise_first' in self.augmentation_method:
+ if i == 0:
+ _, emb_c, _ = self.contrastive_mlp[i](d)
+ emb_contr.append(emb_c)
+
+ # create a list of nodes of the tree that need to be processed, self.tree is the root of the tree
+ list_nodes = [{'node': self.tree, 'depth': 0, 'prob': torch.ones(x.size(0), device=device), 'z_parent_sample': None}]
+ # initializate KL losses
+ kl_nodes_tot = torch.zeros(len(x), device=device)
+ kl_decisions_tot = torch.zeros(len(x), device=device)
+ aug_decisions_loss = torch.zeros(1, device=device)
+ leaves_prob = []
+ reconstructions = []
+ node_leaves = []
+
+ # iterates over all nodes in the tree
+ while len(list_nodes) != 0:
+ # store info regarding the current node
+ current_node = list_nodes.pop(0)
+ node, depth_level, prob = current_node['node'], current_node['depth'], current_node['prob']
+ z_parent_sample = current_node['z_parent_sample']
+ # access deterministic bottom up mu and sigma hat (computed above)
+ d = encoders[-(1+depth_level)]
+ z_mu_q_hat, z_sigma_q_hat = node.dense(d)
+
+ # here we are in the root
+ if depth_level == 0:
+ # the root has a standard gaussian prior
+ z_mu_p, z_sigma_p = torch.zeros_like(z_mu_q_hat, device=device), torch.ones_like(z_sigma_q_hat, device=device)
+ z_p = td.Independent(td.Normal(z_mu_p, torch.sqrt(z_sigma_p + epsilon)), 1)
+ # the samples z (from q(z|x)) is the top layer of deterministic bottom-up
+ z_mu_q, z_sigma_q = z_mu_q_hat, z_sigma_q_hat
+
+ # otherwise we are in the rest of the nodes of the tree
+ else:
+ # the generative probability distribution of internal nodes is a gaussian with mu and sigma that are
+ # the outputs of the top-down network conditioned on the sampled parent
+ _, z_mu_p, z_sigma_p = node.transformation(z_parent_sample)
+ z_p = td.Independent(td.Normal(z_mu_p, torch.sqrt(z_sigma_p + epsilon)), 1)
+ # to avoid posterior collapse there is a share of information between the bottom-up and top-down
+ z_mu_q, z_sigma_q = compute_posterior(z_mu_q_hat, z_mu_p, z_sigma_q_hat, z_sigma_p)
+
+ # compute sample z using mu_q and sigma_q
+ z = td.Independent(td.Normal(z_mu_q, torch.sqrt(z_sigma_q + epsilon)), 1)
+ z_sample = z.rsample()
+
+ # compute KL node
+ kl_node = prob * td.kl_divergence(z, z_p)
+ kl_node = torch.clamp(kl_node, min=-1, max=1000)
+
+ if depth_level == 0:
+ kl_root = kl_node
+ else:
+ kl_nodes_tot += kl_node
+
+ # if there is a router (i.e. decision probability) then we are in the internal nodes (not leaves)
+ if node.router is not None:
+ # compute the probability of the sample to go to the left child
+ prob_child_left = node.router(z_sample).squeeze()
+ prob_child_left_q = node.routers_q(d).squeeze()
+
+ # compute the KL of the decisions
+ kl_decisions = prob_child_left_q * (epsilon + prob_child_left_q / (prob_child_left + epsilon)).log() + \
+ (1 - prob_child_left_q) * (epsilon + (1 - prob_child_left_q) / (1 - prob_child_left + epsilon)).log()
+ kl_decisions = prob * kl_decisions
+ kl_decisions_tot += kl_decisions
+
+ # compute the contrastive loss of the embeddings and the decisions
+ if self.training is True and self.augment is True and 'simple' not in self.augmentation_method:
+ if depth_level == 0:
+ # compute the contrastive loss for all the bottom-up representations
+ aug_decisions_loss += calc_aug_loss(prob_parent=prob, prob_router=prob_child_left_q, augmentation_methods=self.augmentation_method, emb_contr=emb_contr)
+ else:
+ # compute the contrastive loss for the decisions
+ aug_decisions_loss += calc_aug_loss(prob_parent=prob, prob_router=prob_child_left_q, augmentation_methods=self.augmentation_method, emb_contr=[])
+
+ # we are not in a leaf, so we have to add the left and right child to the list
+ prob_node_left, prob_node_right = prob * prob_child_left_q, prob * (1 - prob_child_left_q)
+ node_left, node_right = node.left, node.right
+ list_nodes.append(
+ {'node': node_left, 'depth': depth_level + 1, 'prob': prob_node_left, 'z_parent_sample': z_sample})
+ list_nodes.append({'node': node_right, 'depth': depth_level + 1, 'prob': prob_node_right,
+ 'z_parent_sample': z_sample})
+
+ # if there is a decoder then we are in one of the leaf
+ elif node.decoder is not None:
+ # if we are in a leaf we need to store the prob of reaching that leaf and compute reconstructions
+ # as the nodes are explored left to right, these probabilities will be also ordered left to right
+ leaves_prob.append(prob)
+ dec = node.decoder
+ reconstructions.append(dec(z_sample))
+ node_leaves.append({'prob': prob, 'z_sample': z_sample})
+
+ # here we are in an internal node with pruned leaves and thus only have one child
+ elif node.router is None and node.decoder is None:
+ node_left, node_right = node.left, node.right
+ child = node_left if node_left is not None else node_right
+ list_nodes.append(
+ {'node': child, 'depth': depth_level + 1, 'prob': prob, 'z_parent_sample': z_sample})
+
+ kl_nodes_loss = torch.clamp(torch.mean(kl_nodes_tot), min=-10, max=1e10)
+ kl_decisions_loss = torch.mean(kl_decisions_tot)
+ kl_root_loss = torch.mean(kl_root)
+
+ # p_c_z is the probability of reaching a leaf and is of shape [batch_size, num_clusters]
+ p_c_z = torch.cat([prob.unsqueeze(-1) for prob in leaves_prob], dim=-1)
+
+ rec_losses = self.loss(x, reconstructions, leaves_prob)
+ rec_loss = torch.mean(rec_losses, dim=0)
+
+ return_dict = {
+ 'rec_loss': rec_loss,
+ 'kl_root': kl_root_loss,
+ 'kl_decisions': kl_decisions_loss,
+ 'kl_nodes': kl_nodes_loss,
+ 'aug_decisions': self.aug_decisions_weight * aug_decisions_loss,
+ 'p_c_z': p_c_z,
+ 'node_leaves': node_leaves,
+ }
+
+ if self.return_elbo:
+ return_dict['elbo_samples'] = kl_nodes_tot + kl_decisions_tot + kl_root + rec_losses
+
+ if self.return_bottomup:
+ return_dict['bottom_up'] = encoders
+
+ if self.return_x:
+ return_dict['input'] = x
+
+ return return_dict
+
+
+ def compute_leaves(self):
+ """
+ Computes the leaves of the tree
+
+ Returns
+ -------
+ list
+ A list of the leaves from left to right.
+ A leaf is defined by a dictionary: {'node': leaf node, 'depth': depth of the leaf node}.
+ A leaf node is defined by the class Node in utils.model_utils.
+ """
+ # iterate over all nodes in the tree to find the leaves
+ list_nodes = [{'node': self.tree, 'depth': 0}]
+ nodes_leaves = []
+ while len(list_nodes) != 0:
+ current_node = list_nodes.pop(0)
+ node, depth_level = current_node['node'], current_node['depth']
+ if node.router is not None:
+ node_left, node_right = node.left, node.right
+ list_nodes.append(
+ {'node': node_left, 'depth': depth_level + 1})
+ list_nodes.append({'node': node_right, 'depth': depth_level + 1})
+ elif node.router is None and node.decoder is None:
+ # we are in an internal node with pruned leaves and thus only have one child
+ node_left, node_right = node.left, node.right
+ child = node_left if node_left is not None else node_right
+ list_nodes.append({'node': child, 'depth': depth_level + 1})
+ else:
+ nodes_leaves.append(current_node)
+ return nodes_leaves
+
+
+ def compute_depth(self):
+ """
+ Computes the depth of the tree
+
+ Returns
+ -------
+ int
+ The depth of the tree (the root has depth 0 and a root with two leaves had depth 1).
+ """
+ # computes depth of the tree
+ nodes_leaves = self.compute_leaves()
+ d = []
+ for i in range(len(nodes_leaves)):
+ d.append(nodes_leaves[i]['depth'])
+ return max(d)
+
+ def attach_smalltree(self, node, small_model):
+ """
+ Attach a trained small tree of the class SmallTreeVAE (models.model_smalltree) to the given node of the full
+ TreeVAE. The small tree has one root and two leaves. It does not return anything but changes self.tree
+
+ Parameters
+ ----------
+ node : utils.model_utils.Node
+ The selected node of TreeVAE where to attach the sub-tree, which was trained separately.
+ small_model: models.model_smalltree.SmallTreeVAE
+ The sub-tree with one root and two leaves that needs to be attached to TreeVAE.
+ """
+ assert node.left is None and node.right is None
+ node.router = small_model.decision
+ node.routers_q = small_model.decision_q
+ node.decoder = None
+ for j in range(2):
+ dense = small_model.denses[j]
+ transformation = small_model.transformations[j]
+ decoder = small_model.decoders[j]
+ # insert each leaf of the small tree as child of the node of TreeVAE
+ node.insert(transformation, None, None, dense, decoder)
+
+ # once the small tree is attached we re-compute the list of transformations, routers etc
+ transformations, routers, denses, decoders, routers_q = return_list_tree(self.tree)
+
+ # we then need to re-initialize the parameters of TreeVAE
+ self.decisions_q = routers_q
+ self.transformations = transformations
+ self.decisions = routers
+ self.denses = denses
+ self.decoders = decoders
+ self.depth = self.compute_depth()
+ return
+
+
+ def compute_reconstruction(self, x):
+ """
+ Given the input x, it computes the reconstructions.
+
+ Parameters
+ ----------
+ x: Tensor
+ Input data.
+
+ Returns
+ -------
+ Tensor
+ The reconstructions of the input data by computing a forward pass of the model.
+ List
+ A list of leaf nodes, each one described by a dictionary
+ {'prob': sample-wise probability of reaching the node, 'z_sample': sampled leaf embedding}
+ """
+ assert self.training is False
+ epsilon = 1e-7
+ device = x.device
+
+ # compute deterministic bottom up
+ d = x
+ encoders = []
+
+ for i in range(0, len(self.hidden_layers)):
+ d, _, _ = self.bottom_up[i](d)
+ # store the bottom-up layers for the top down computation
+ encoders.append(d)
+
+ # create a list of nodes of the tree that need to be processed
+ list_nodes = [{'node': self.tree, 'depth': 0, 'prob': torch.ones(x.size(0), device=device), 'z_parent_sample': None}]
+
+ # initializate KL losses
+ leaves_prob = []
+ reconstructions = []
+ node_leaves = []
+
+ # iterate over the nodes
+ while len(list_nodes) != 0:
+
+ # store info regarding the current node
+ current_node = list_nodes.pop(0)
+ node, depth_level, prob = current_node['node'], current_node['depth'], current_node['prob']
+ z_parent_sample = current_node['z_parent_sample']
+ # access deterministic bottom up mu and sigma hat (computed above)
+ d = encoders[-(1+depth_level)]
+ z_mu_q_hat, z_sigma_q_hat = node.dense(d)
+
+ if depth_level == 0:
+ z_mu_q, z_sigma_q = z_mu_q_hat, z_sigma_q_hat
+ else:
+ # the generative mu and sigma is the output of the top-down network given the sampled parent
+ _, z_mu_p, z_sigma_p = node.transformation(z_parent_sample)
+ z_mu_q, z_sigma_q = compute_posterior(z_mu_q_hat, z_mu_p, z_sigma_q_hat, z_sigma_p)
+
+ # compute sample z using mu_q and sigma_q
+ z = td.Independent(td.Normal(z_mu_q, torch.sqrt(z_sigma_q + epsilon)), 1)
+ z_sample = z.rsample()
+
+ # if we are in the internal nodes (not leaves)
+ if node.router is not None:
+
+ prob_child_left_q = node.routers_q(d).squeeze()
+
+ # we are not in a leaf, so we have to add the left and right child to the list
+ prob_node_left, prob_node_right = prob * prob_child_left_q, prob * (1 - prob_child_left_q)
+
+ node_left, node_right = node.left, node.right
+ list_nodes.append(
+ {'node': node_left, 'depth': depth_level + 1, 'prob': prob_node_left, 'z_parent_sample': z_sample})
+ list_nodes.append({'node': node_right, 'depth': depth_level + 1, 'prob': prob_node_right,
+ 'z_parent_sample': z_sample})
+
+ elif node.decoder is not None:
+ # if we are in a leaf we need to store the prob of reaching that leaf and compute reconstructions
+ # as the nodes are explored left to right, these probabilities will be also ordered left to right
+ leaves_prob.append(prob)
+ dec = node.decoder
+ reconstructions.append(dec(z_sample))
+ node_leaves.append({'prob': prob, 'z_sample': z_sample})
+
+ elif node.router is None and node.decoder is None:
+ # We are in an internal node with pruned leaves and thus only have one child
+ node_left, node_right = node.left, node.right
+ child = node_left if node_left is not None else node_right
+ list_nodes.append(
+ {'node': child, 'depth': depth_level + 1, 'prob': prob, 'z_parent_sample': z_sample})
+
+ return reconstructions, node_leaves
+
+ def generate_images(self, n_samples, device):
+ """
+ Generate K x n_samples new images by sampling from the root and propagating through the entire tree.
+ For each sample the method generates K images, where K is the number of leaves.
+
+ Parameters
+ ----------
+ n_samples: int
+ Number of generated samples the function should output.
+ device: torch.device
+ Either cpu or gpu
+
+ Returns
+ -------
+ list
+ A list of K tensors containing the leaf-specific generations obtained by sampling from the root and
+ propagating through the entire tree, where K is the number of leaves.
+ Tensor
+ The probability of each generated sample to be assigned to each leaf with size: #samples x #leaves,
+ """
+ assert self.training is False
+ epsilon = 1e-7
+ sizes = self.encoded_sizes
+ list_nodes = [{'node': self.tree, 'depth': 0, 'prob': torch.ones(n_samples, device=device), 'z_parent_sample': None}]
+ leaves_prob = []
+ reconstructions = []
+ while len(list_nodes) != 0:
+ current_node = list_nodes.pop(0)
+ node, depth_level, prob = current_node['node'], current_node['depth'], current_node['prob']
+ z_parent_sample = current_node['z_parent_sample']
+
+ if depth_level == 0:
+ z_mu_p, z_sigma_p = torch.zeros([n_samples, sizes[-1]], device=device), torch.ones([n_samples, sizes[-1]], device=device)
+ z_p = td.Independent(td.Normal(z_mu_p, torch.sqrt(z_sigma_p+epsilon)), 1)
+ z_sample = z_p.rsample()
+
+ else:
+ _, z_mu_p, z_sigma_p = node.transformation(z_parent_sample)
+ z_p = td.Independent(td.Normal(z_mu_p, torch.sqrt(z_sigma_p+epsilon)), 1)
+ z_sample = z_p.rsample()
+
+ if node.router is not None:
+ prob_child_left = node.router(z_sample).squeeze()
+ prob_node_left, prob_node_right = prob * prob_child_left, prob * (
+ 1 - prob_child_left)
+ node_left, node_right = node.left, node.right
+ list_nodes.append(
+ {'node': node_left, 'depth': depth_level + 1, 'prob': prob_node_left, 'z_parent_sample': z_sample})
+ list_nodes.append({'node': node_right, 'depth': depth_level + 1, 'prob': prob_node_right,
+ 'z_parent_sample': z_sample})
+
+ elif node.decoder is not None:
+ # here we are in a leaf node and we attach the corresponding generations
+ leaves_prob.append(prob)
+ dec = node.decoder
+ reconstructions.append(dec(z_sample))
+
+ elif node.router is None and node.decoder is None:
+ # We are in an internal node with pruned leaves and thus only have one child
+ node_left, node_right = node.left, node.right
+ child = node_left if node_left is not None else node_right
+ list_nodes.append(
+ {'node': child, 'depth': depth_level + 1, 'prob': prob, 'z_parent_sample': z_sample})
+ p_c_z = torch.cat([prob.unsqueeze(-1) for prob in leaves_prob], dim=-1)
+
+ return reconstructions, p_c_z
diff --git a/treevae/models/model_smalltree.py b/treevae/models/model_smalltree.py
new file mode 100644
index 0000000..9e23d63
--- /dev/null
+++ b/treevae/models/model_smalltree.py
@@ -0,0 +1,179 @@
+"""
+SmallTreeVAE model (used for the growing procedure of TreeVAE).
+"""
+import torch
+import torch.nn as nn
+import torch.distributions as td
+from models.networks import get_decoder, MLP, Router, Dense
+from utils.model_utils import compute_posterior
+from models.losses import loss_reconstruction_binary, loss_reconstruction_mse
+from utils.training_utils import calc_aug_loss
+
+class SmallTreeVAE(nn.Module):
+ """
+ A class used to represent a sub-tree VAE with one root and two children.
+
+ SmallTreeVAE specifies a sub-tree of TreeVAE with one root and two children. It is used in the
+ growing procedure of TreeVAE. At each growing step a new SmallTreeVAE is attached to a leaf of TreeVAE and
+ trained separately to reduce computational time.
+
+ Attributes
+ ----------
+ activation : str
+ The name of the activation function for the reconstruction loss [sigmoid, mse]
+ loss : models.losses
+ The loss function used by the decoder to reconstruct the input
+ alpha : float
+ KL-annealing weight initialization
+ depth : int
+ The depth at which the sub-tree will be attached (root has depth 0 and a root with two leaves has depth 1)
+ inp_shape : int
+ The total dimensions of the input data (if images of 32x32 then 32x32x3)
+ augment : bool
+ Whether to use contrastive learning through augmentation, if False no augmentation is used
+ augmentation_method : str
+ The type of augmentation method used
+ aug_decisions_weight : str
+ The weight of the contrastive loss used in the decisions
+ denses : nn.ModuleList
+ List of dense layers for the sharing of top-down and bottom-up (MLPs) associated with each of the two leaf
+ node of the tree from left to right.
+ transformations : nn.ModuleList
+ List of transformations (MLPs) associated with each of the two leaf node of the sub-tree from left to right
+ decision : Router
+ The decision associated with the root of the sub-tree.
+ decoders : nn.ModuleList
+ List of two decoders one for each leaf of the sub-tree
+ decision_q : str
+ The decision of the bottom-up associated with the root of the sub-tree
+
+ Methods
+ -------
+ forward(x)
+ Compute the forward pass of the SmallTreeVAE model and return a dictionary of losses.
+ """
+ def __init__(self, depth, **kwargs):
+ """
+ Parameters
+ ----------
+ depth: int
+ The depth at which the sub-tree will be attached to TreeVAE
+ kwargs : dict
+ A dictionary of attributes (see config file).
+ """
+ super(SmallTreeVAE, self).__init__()
+ self.kwargs = kwargs
+
+ self.activation = self.kwargs['activation']
+ if self.activation == "sigmoid":
+ self.loss = loss_reconstruction_binary
+ elif self.activation == "mse":
+ self.loss = loss_reconstruction_mse
+ else:
+ raise NotImplementedError
+ # KL-annealing weight initialization
+ self.alpha=self.kwargs['kl_start']
+
+ encoded_sizes = self.kwargs['latent_dim']
+ hidden_layers = self.kwargs['mlp_layers']
+ self.depth = depth
+ encoded_size_gen = encoded_sizes[-(self.depth+1):-(self.depth-1)] # e.g. encoded_size_gen = 32,16, depth 2
+ self.encoded_size = encoded_size_gen[::-1] # self.encoded_size = 32,16 => 16,32
+ layers_gen = hidden_layers[-(self.depth+1):-(self.depth-1)] # e.g. encoded_sizes 256,128,64, depth 2
+ self.hidden_layer = layers_gen[::-1] # encoded_size_gen = 256,128 => 128,256
+
+ self.inp_shape = self.kwargs['inp_shape']
+ self.augment = self.kwargs['augment']
+ self.augmentation_method = self.kwargs['augmentation_method']
+ self.aug_decisions_weight = self.kwargs['aug_decisions_weight']
+
+ self.denses = nn.ModuleList([Dense(self.hidden_layer[1], self.encoded_size[1]) for _ in range(2)])
+ self.transformations = nn.ModuleList([MLP(self.encoded_size[0], self.encoded_size[1], self.hidden_layer[0]) for _ in range(2)])
+ self.decision = Router(self.encoded_size[0], hidden_units=self.hidden_layer[0])
+ self.decision_q = Router(self.hidden_layer[0], hidden_units=self.hidden_layer[0])
+ self.decoders = nn.ModuleList([get_decoder(architecture=self.kwargs['encoder'], input_shape=self.encoded_size[1],
+ output_shape=self.inp_shape, activation=self.activation) for _ in range(2)])
+
+ def forward(self, x, z_parent, p, bottom_up):
+ """
+ Forward pass of the SmallTreeVAE model.
+
+ Parameters
+ ----------
+ x : tensor
+ Input data (batch-size, input-size)
+ z_parent: tensor
+ The embeddings of the parent of the two children of SmallTreeVAE (which are the embeddings of the TreeVAE
+ leaf where the SmallTreeVAE will be attached)
+ p: list
+ Probabilities of falling into the selected TreeVAE leaf where the SmallTreeVAE will be attached
+ bottom_up: list
+ The list of bottom-up transformations [encoder, MLP, MLP, ...] up to the root
+
+ Returns
+ -------
+ dict
+ a dictionary
+ {'rec_loss': reconstruction loss,
+ 'kl_decisions': the KL loss of the decisions,
+ 'kl_nodes': the KL loss of the nodes,
+ 'aug_decisions': the weighted contrastive loss,
+ 'p_c_z': the probability of each sample to be assigned to each leaf with size: #samples x #leaves,
+ }
+ """
+ epsilon = 1e-7 # Small constant to prevent numerical instability
+ device = x.device
+
+ # Extract relevant bottom-up
+ d_q = bottom_up[-self.depth]
+ d = bottom_up[-self.depth - 1]
+
+ prob_child_left = self.decision(z_parent).squeeze()
+ prob_child_left_q = self.decision_q(d_q).squeeze()
+ leaves_prob = [p * prob_child_left_q, p * (1 - prob_child_left_q)]
+
+ kl_decisions = prob_child_left_q * torch.log(epsilon + prob_child_left_q / (prob_child_left + epsilon)) +\
+ (1 - prob_child_left_q) * torch.log(epsilon + (1 - prob_child_left_q) /
+ (1 - prob_child_left + epsilon))
+ kl_decisions = torch.mean(p * kl_decisions)
+
+ # Contrastive loss
+ aug_decisions_loss = torch.zeros(1, device=device)
+ if self.training is True and self.augment is True and 'simple' not in self.augmentation_method:
+ aug_decisions_loss += calc_aug_loss(prob_parent=p, prob_router=prob_child_left_q,
+ augmentation_methods=self.augmentation_method)
+
+ reconstructions = []
+ kl_nodes = torch.zeros(1, device=device)
+ for i in range(2):
+ # Compute posterior parameters
+ z_mu_q_hat, z_sigma_q_hat = self.denses[i](d)
+ _, z_mu_p, z_sigma_p = self.transformations[i](z_parent)
+ z_p = td.Independent(td.Normal(z_mu_p, torch.sqrt(z_sigma_p+epsilon)), 1)
+ z_mu_q, z_sigma_q = compute_posterior(z_mu_q_hat, z_mu_p, z_sigma_q_hat, z_sigma_p)
+
+ # Compute sample z using mu_q and sigma_q
+ z_q = td.Independent(td.Normal(z_mu_q, torch.sqrt(z_sigma_q + epsilon)), 1)
+ z_sample = z_q.rsample()
+
+ # Compute KL node
+ kl_node = torch.mean(leaves_prob[i] * td.kl_divergence(z_q, z_p))
+ kl_nodes += kl_node
+
+ reconstructions.append(self.decoders[i](z_sample))
+
+ kl_nodes_loss = torch.clamp(kl_nodes, min=-10, max=1e10)
+
+ # Probability of falling in each leaf
+ p_c_z = torch.cat([prob.unsqueeze(-1) for prob in leaves_prob], dim=-1)
+
+ rec_losses = self.loss(x, reconstructions, leaves_prob)
+ rec_loss = torch.mean(rec_losses, dim=0)
+
+ return {
+ 'rec_loss': rec_loss,
+ 'kl_decisions': kl_decisions,
+ 'kl_nodes': kl_nodes_loss,
+ 'aug_decisions': self.aug_decisions_weight * aug_decisions_loss,
+ 'p_c_z': p_c_z,
+ }
diff --git a/treevae/models/networks.py b/treevae/models/networks.py
new file mode 100644
index 0000000..17b238f
--- /dev/null
+++ b/treevae/models/networks.py
@@ -0,0 +1,386 @@
+"""
+Encoder, decoder, transformation, router, and dense layer architectures.
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def actvn(x):
+ return F.leaky_relu(x, negative_slope=0.3)
+
+class EncoderSmall(nn.Module):
+ def __init__(self, input_shape, output_shape):
+ super(EncoderSmall, self).__init__()
+
+ self.dense1 = nn.Linear(in_features=input_shape, out_features=4*output_shape, bias=False)
+ self.bn1 = nn.BatchNorm1d(4*output_shape)
+ self.dense2 = nn.Linear(in_features=4*output_shape, out_features=4*output_shape, bias=False)
+ self.bn2 = nn.BatchNorm1d(4*output_shape)
+ self.dense3 = nn.Linear(in_features=4*output_shape, out_features=2*output_shape, bias=False)
+ self.bn3 = nn.BatchNorm1d(2*output_shape)
+ self.dense4 = nn.Linear(in_features=2*output_shape, out_features=output_shape, bias=False)
+ self.bn4 = nn.BatchNorm1d(output_shape)
+
+ def forward(self, inputs):
+ x = self.dense1(inputs)
+ x = self.bn1(x)
+ x = actvn(x)
+ x = self.dense2(x)
+ x = self.bn2(x)
+ x = actvn(x)
+ x = self.dense3(x)
+ x = self.bn3(x)
+ x = actvn(x)
+ x = self.dense4(x)
+ x = self.bn4(x)
+ x = actvn(x)
+ return x, None, None
+
+class DecoderSmall(nn.Module):
+ def __init__(self, input_shape, output_shape, activation):
+ super(DecoderSmall, self).__init__()
+ self.activation = activation
+ self.dense1 = nn.Linear(in_features=input_shape, out_features=128, bias=False)
+ self.bn1 = nn.BatchNorm1d(128)
+ self.dense2 = nn.Linear(in_features=128, out_features=256, bias=False)
+ self.bn2 = nn.BatchNorm1d(256)
+ self.dense3 = nn.Linear(in_features=256, out_features=512, bias=False)
+ self.bn3 = nn.BatchNorm1d(512)
+ self.dense4 = nn.Linear(in_features=512, out_features=512, bias=False)
+ self.bn4 = nn.BatchNorm1d(512)
+ self.dense5 = nn.Linear(in_features=512, out_features=output_shape, bias=True)
+
+ def forward(self, inputs):
+ x = self.dense1(inputs)
+ x = self.bn1(x)
+ x = actvn(x)
+ x = self.dense2(x)
+ x = self.bn2(x)
+ x = actvn(x)
+ x = self.dense3(x)
+ x = self.bn3(x)
+ x = actvn(x)
+ x = self.dense4(x)
+ x = self.bn4(x)
+ x = actvn(x)
+ x = self.dense5(x)
+ if self.activation == "sigmoid":
+ x = torch.sigmoid(x)
+ return x
+
+
+class EncoderSmallCnn(nn.Module):
+ def __init__(self, encoded_size):
+ super(EncoderSmallCnn, self).__init__()
+ n_maps_output = encoded_size//4
+ self.cnn0 = nn.Conv2d(in_channels=1, out_channels=n_maps_output//4, kernel_size=3, stride=2, padding=0, bias=False)
+ self.cnn1 = nn.Conv2d(in_channels=n_maps_output//4, out_channels=n_maps_output//2, kernel_size=3, stride=2, padding=0, bias=False)
+ self.cnn2 = nn.Conv2d(in_channels=n_maps_output//2, out_channels=n_maps_output, kernel_size=3, stride=2, padding=0, bias=False)
+ self.bn0 = nn.BatchNorm2d(n_maps_output//4)
+ self.bn1 = nn.BatchNorm2d(n_maps_output//2)
+ self.bn2 = nn.BatchNorm2d(n_maps_output)
+
+ def forward(self, x):
+ x = self.cnn0(x)
+ x = self.bn0(x)
+ x = actvn(x)
+ x = self.cnn1(x)
+ x = self.bn1(x)
+ x = actvn(x)
+ x = self.cnn2(x)
+ x = self.bn2(x)
+ x = actvn(x)
+ x = x.view(x.size(0), -1)
+ return x, None, None
+
+class DecoderSmallCnn(nn.Module):
+ def __init__(self, input_shape, activation):
+ super(DecoderSmallCnn, self).__init__()
+ self.activation = activation
+ self.dense = nn.Linear(in_features=input_shape, out_features=3 * 3 * 32, bias=False)
+ self.bn = nn.BatchNorm1d(3 * 3 * 32)
+ self.bn1 = nn.BatchNorm2d(16)
+ self.bn2 = nn.BatchNorm2d(8)
+ self.cnn1 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2, bias=False)
+ self.cnn2 = nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
+ self.cnn3 = nn.ConvTranspose2d(in_channels=8, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True)
+
+ def forward(self, inputs):
+ x = self.dense(inputs)
+ x = self.bn(x)
+ x = actvn(x)
+ x = x.view(-1, 32, 3, 3)
+ x = self.cnn1(x)
+ x = self.bn1(x)
+ x = actvn(x)
+ x = self.cnn2(x)
+ x = self.bn2(x)
+ x = actvn(x)
+ x = self.cnn3(x)
+ if self.activation == 'sigmoid':
+ x = torch.sigmoid(x)
+ return x
+
+
+class EncoderOmniglot(nn.Module):
+ def __init__(self, encoded_size):
+ super(EncoderOmniglot, self).__init__()
+ self.cnns = nn.ModuleList([
+ nn.Conv2d(in_channels=1, out_channels=encoded_size//4, kernel_size=4, stride=1, padding=1, bias=False),
+ nn.Conv2d(in_channels=encoded_size//4, out_channels=encoded_size//4, kernel_size=4, stride=2, padding=1, bias=False),
+ nn.Conv2d(in_channels=encoded_size//4, out_channels=encoded_size//2, kernel_size=4, stride=1, padding=1, bias=False),
+ nn.Conv2d(in_channels=encoded_size//2, out_channels=encoded_size//2, kernel_size=4, stride=2, padding=1, bias=False),
+ nn.Conv2d(in_channels=encoded_size//2, out_channels=encoded_size, kernel_size=4, stride=1, padding=1, bias=False),
+ nn.Conv2d(in_channels=encoded_size, out_channels=encoded_size, kernel_size=5, bias=False)
+ ])
+ self.bns = nn.ModuleList([
+ nn.BatchNorm2d(encoded_size//4),
+ nn.BatchNorm2d(encoded_size//4),
+ nn.BatchNorm2d(encoded_size//2),
+ nn.BatchNorm2d(encoded_size//2),
+ nn.BatchNorm2d(encoded_size),
+ nn.BatchNorm2d(encoded_size)
+ ])
+
+ def forward(self, x):
+ for i in range(len(self.cnns)):
+ x = self.cnns[i](x)
+ x = self.bns[i](x)
+ x = actvn(x)
+ x = x.view(x.size(0), -1)
+ return x, None, None
+
+class DecoderOmniglot(nn.Module):
+ def __init__(self, input_shape, activation):
+ super(DecoderOmniglot, self).__init__()
+ self.activation = activation
+ self.dense = nn.Linear(in_features=input_shape, out_features=2 * 2 * 128, bias=False)
+ self.cnns = nn.ModuleList([
+ nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=5, stride=2, bias=False),
+ nn.Conv2d(in_channels=128, out_channels=64, kernel_size=4, stride=1, padding=1, bias=False),
+ nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=0, bias=False),
+ nn.Conv2d(in_channels=64, out_channels=32, kernel_size=4, stride=1, padding=1, bias=False),
+ nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=0, output_padding=1, bias=False)
+ ])
+ self.cnns.append(nn.Conv2d(in_channels=32, out_channels=1, kernel_size=4, stride=1, padding=1, bias=True))
+ self.bn = nn.BatchNorm1d(2 * 2 * 128)
+ self.bns = nn.ModuleList([
+ nn.BatchNorm2d(128),
+ nn.BatchNorm2d(64),
+ nn.BatchNorm2d(64),
+ nn.BatchNorm2d(32),
+ nn.BatchNorm2d(32)
+ ])
+
+ def forward(self, inputs):
+ x = self.dense(inputs)
+ x = self.bn(x)
+ x = actvn(x)
+ x = x.view(-1, 128, 2, 2)
+ for i in range(len(self.bns)):
+ x = self.cnns[i](x)
+ x = self.bns[i](x)
+ x = actvn(x)
+ x = self.cnns[-1](x)
+ if self.activation == "sigmoid":
+ x = torch.sigmoid(x)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, fin, fout, fhidden=None, is_bias=True):
+ super(ResnetBlock, self).__init__()
+
+ self.learned_shortcut = (fin != fout)
+ self.fin = fin
+ self.fout = fout
+ if fhidden is None:
+ self.fhidden = min(fin, fout)
+ else:
+ self.fhidden = fhidden
+
+ # Submodules
+ self.conv_0 = nn.Conv2d(in_channels=fin, out_channels=self.fhidden, kernel_size=3, stride=1, padding=1)
+ self.conv_1 = nn.Conv2d(in_channels=self.fhidden, out_channels=self.fout, kernel_size=3, stride=1, padding=1, bias=is_bias)
+ if self.learned_shortcut:
+ self.conv_s = nn.Conv2d(in_channels=fin, out_channels=self.fout, kernel_size=1, stride=1, padding=0, bias=False)
+ self.bn0 = nn.BatchNorm2d(self.fin)
+ self.bn1 = nn.BatchNorm2d(self.fhidden)
+
+ def forward(self, x):
+ x_s = self._shortcut(x)
+ dx = self.conv_0(actvn(self.bn0(x)))
+ dx = self.conv_1(actvn(self.bn1(dx)))
+ out = x_s + 0.1 * dx
+ return out
+
+ def _shortcut(self, x):
+ if self.learned_shortcut:
+ x_s = self.conv_s(x)
+ else:
+ x_s = x
+ return x_s
+
+class Resnet_Encoder(nn.Module):
+ def __init__(self, s0=2, nf=8, nf_max=256, size=32):
+ super(Resnet_Encoder, self).__init__()
+
+ self.s0 = s0
+ self.nf = nf
+ self.nf_max = nf_max
+ self.size = size
+
+ # Submodules
+ nlayers = int(torch.log2(torch.tensor(size / s0).float()))
+ self.nf0 = min(nf_max, nf * 2 ** nlayers)
+
+ blocks = [
+ ResnetBlock(nf, nf)
+ ]
+
+ for i in range(nlayers):
+ nf0 = min(nf * 2 ** i, nf_max)
+ nf1 = min(nf * 2 ** (i + 1), nf_max)
+ blocks += [
+ nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
+ ResnetBlock(nf0, nf1),
+ ]
+
+ self.conv_img = nn.Conv2d(3, 1 * nf, kernel_size=3, padding=1)
+
+ self.resnet = nn.Sequential(*blocks)
+
+ self.bn0 = nn.BatchNorm2d(self.nf0)
+
+
+ def forward(self, x):
+ out = self.conv_img(x)
+ out = self.resnet(out)
+ out = actvn(self.bn0(out))
+ out = out.view(out.size(0), -1)
+ return out, None, None
+
+class Resnet_Decoder(nn.Module):
+ def __init__(self, s0=2, nf=8, nf_max=256, ndim=64, activation='sigmoid', size=32):
+ super(Resnet_Decoder, self).__init__()
+
+ self.s0 = s0
+ self.nf = nf
+ self.nf_max = nf_max
+ self.activation = activation
+
+ # Submodules
+ nlayers = int(torch.log2(torch.tensor(size / s0).float()))
+ self.nf0 = min(nf_max, nf * 2 ** nlayers)
+
+ self.fc = nn.Linear(ndim, self.nf0 * s0 * s0)
+
+ blocks = []
+ for i in range(nlayers):
+ nf0 = min(nf * 2 ** (nlayers - i), nf_max)
+ nf1 = min(nf * 2 ** (nlayers - i - 1), nf_max)
+ blocks += [
+ ResnetBlock(nf0, nf1),
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+ ]
+ blocks += [
+ ResnetBlock(nf, nf),
+ ]
+ self.resnet = nn.Sequential(*blocks)
+
+ self.bn0 = nn.BatchNorm2d(nf)
+ self.conv_img = nn.ConvTranspose2d(nf, 3, kernel_size=3, padding=1)
+
+
+ def forward(self, z):
+ out = self.fc(z)
+ out = out.view(-1, self.nf0, self.s0, self.s0)
+ out = self.resnet(out)
+ out = self.conv_img(actvn(self.bn0(out)))
+ if self.activation == 'sigmoid':
+ out = torch.sigmoid(out)
+ return out
+
+
+# Small branch transformation
+class MLP(nn.Module):
+ def __init__(self, input_size, encoded_size, hidden_unit):
+ super(MLP, self).__init__()
+ self.dense1 = nn.Linear(input_size, hidden_unit, bias=False)
+ self.bn1 = nn.BatchNorm1d(hidden_unit)
+ self.mu = nn.Linear(hidden_unit, encoded_size)
+ self.sigma = nn.Linear(hidden_unit, encoded_size)
+
+ def forward(self, inputs):
+ x = self.dense1(inputs)
+ x = self.bn1(x)
+ x = actvn(x)
+ mu = self.mu(x)
+ sigma = F.softplus(self.sigma(x))
+ return x, mu, sigma
+
+
+class Dense(nn.Module):
+ def __init__(self, input_size, encoded_size):
+ super(Dense, self).__init__()
+ self.mu = nn.Linear(input_size, encoded_size)
+ self.sigma = nn.Linear(input_size, encoded_size)
+
+ def forward(self, inputs):
+ x = inputs
+ mu = self.mu(x)
+ sigma = F.softplus(self.sigma(x))
+ return mu, sigma
+
+
+class Router(nn.Module):
+ def __init__(self, input_size, hidden_units=128):
+ super(Router, self).__init__()
+ self.dense1 = nn.Linear(input_size, hidden_units, bias=False)
+ self.dense2 = nn.Linear(hidden_units, hidden_units, bias=False)
+ self.bn1 = nn.BatchNorm1d(hidden_units)
+ self.bn2 = nn.BatchNorm1d(hidden_units)
+ self.dense3 = nn.Linear(hidden_units, 1)
+
+ def forward(self, inputs, return_last_layer=False):
+ x = self.dense1(inputs)
+ x = self.bn1(x)
+ x = actvn(x)
+ x = self.dense2(x)
+ x = self.bn2(x)
+ x = actvn(x)
+ d = F.sigmoid(self.dense3(x))
+ if return_last_layer:
+ return d, x
+ else:
+ return d
+
+
+def get_encoder(architecture, encoded_size, x_shape, size=None):
+ if architecture == 'mlp':
+ encoder = EncoderSmall(input_shape=x_shape, output_shape=encoded_size)
+ elif architecture == 'cnn1':
+ encoder = EncoderSmallCnn(encoded_size)
+ elif architecture == 'cnn2':
+ encoder = Resnet_Encoder(s0=4, nf=32, nf_max=256, size=size)
+ elif architecture == 'cnn_omni':
+ encoder = EncoderOmniglot(encoded_size)
+ else:
+ raise ValueError('The encoder architecture is mispecified.')
+ return encoder
+
+
+def get_decoder(architecture, input_shape, output_shape, activation):
+ if architecture == 'mlp':
+ decoder = DecoderSmall(input_shape, output_shape, activation)
+ elif architecture == 'cnn1':
+ decoder = DecoderSmallCnn(input_shape, activation)
+ elif architecture == 'cnn2':
+ size = int((output_shape/3)**0.5)
+ decoder = Resnet_Decoder(s0=4, nf=32, nf_max=256, ndim = input_shape, activation=activation, size=size)
+ elif architecture == 'cnn_omni':
+ decoder = DecoderOmniglot(input_shape, activation)
+ else:
+ raise ValueError('The decoder architecture is mispecified.')
+ return decoder
diff --git a/treevae/train/train.py b/treevae/train/train.py
new file mode 100644
index 0000000..b94ec00
--- /dev/null
+++ b/treevae/train/train.py
@@ -0,0 +1,87 @@
+"""
+Run training and validation functions of TreeVAE.
+"""
+import time
+from pathlib import Path
+import wandb
+import uuid
+import os
+import torch
+
+from utils.data_utils import get_data
+from utils.utils import reset_random_seeds
+from train.train_tree import run_tree
+from train.validate_tree import val_tree
+
+
+def run_experiment(configs):
+ """
+ Run the experiments for TreeVAE as defined in the config setting. This method will set up the device, the correct
+ experimental paths, initialize Wandb for tracking, generate the dataset, train and grow the TreeVAE model, and
+ finally it will validate the result. All final results and validations will be stored in Wandb, while the most
+ important ones will be also printed out in the terminal. If specified, the model will also be saved for further
+ exploration using the Jupyter Notebook: tree_exploration.ipynb.
+
+ Parameters
+ ----------
+ configs: dict
+ The config setting for training and validating TreeVAE defined in configs or in the command line.
+ """
+ # Setting device on GPU if available, else CPU
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # Additional info when using cuda
+ if device.type == 'cuda':
+ print("Using", torch.cuda.get_device_name(0))
+ else:
+ print("No GPU available")
+
+ # Set paths
+ project_dir = Path(__file__).absolute().parent
+ timestr = time.strftime("%Y%m%d-%H%M%S")
+ ex_name = "{}_{}".format(str(timestr), uuid.uuid4().hex[:5])
+ experiment_path = configs['globals']['results_dir'] / configs['data']['data_name'] / ex_name
+ experiment_path.mkdir(parents=True)
+ os.makedirs(os.path.join(project_dir, '../models/logs', ex_name))
+ print("Experiment path: ", experiment_path)
+
+ # Wandb
+ os.environ['WANDB_CACHE_DIR'] = os.path.join(project_dir, '../wandb', '.cache', 'wandb')
+ os.environ["WANDB_SILENT"] = "true"
+
+ # ADD YOUR WANDB ENTITY
+ wandb.init(
+ project="treevae",
+ entity="test",
+ config=configs,
+ mode=configs['globals']['wandb_logging']
+ )
+
+ if configs['globals']['wandb_logging'] in ['online', 'disabled']:
+ wandb.run.name = wandb.run.name.split("-")[-1] + "-"+ configs['run_name']
+ elif configs['globals']['wandb_logging'] == 'offline':
+ wandb.run.name = configs['run_name']
+ else:
+ raise ValueError('wandb needs to be set to online, offline or disabled.')
+
+ # Reproducibility
+ reset_random_seeds(configs['globals']['seed'])
+
+ # Generate a new dataset each run
+ trainset, trainset_eval, testset = get_data(configs)
+
+ # Run the full training of treeVAE model, including the growing of the tree
+ model = run_tree(trainset, trainset_eval, testset, device, configs)
+
+ # Save model
+ if configs['globals']['save_model']:
+ print("\nSaving weights at ", experiment_path)
+ torch.save(model.state_dict(), experiment_path / 'model_weights.pt')
+
+ # Evaluation of TreeVAE
+ print("\n" * 2)
+ print("Evaluation")
+ print("\n" * 2)
+ val_tree(trainset_eval, testset, model, device, experiment_path, configs)
+ wandb.finish(quiet=True)
+ return
diff --git a/treevae/train/train_tree.py b/treevae/train/train_tree.py
new file mode 100644
index 0000000..bbfbbad
--- /dev/null
+++ b/treevae/train/train_tree.py
@@ -0,0 +1,247 @@
+"""
+Training function of TreeVAE and SmallTreeVAE.
+"""
+import wandb
+import numpy as np
+import gc
+import torch
+import torch.optim as optim
+
+from utils.training_utils import train_one_epoch, validate_one_epoch, AnnealKLCallback, Custom_Metrics, \
+ get_ind_small_tree, compute_growing_leaf, compute_pruning_leaf, get_optimizer, predict
+from utils.data_utils import get_gen
+from utils.model_utils import return_list_tree, construct_data_tree
+from models.model import TreeVAE
+from models.model_smalltree import SmallTreeVAE
+
+
+def run_tree(trainset, trainset_eval, testset, device, configs):
+ """
+ Run the TreeVAE model as defined in the config setting. The method will first train a TreeVAE model with initial
+ depth defined in config (initial_depth). After training TreeVAE for epochs=num_epochs, if grow=True then it will
+ start the iterative growing schedule. At each step, a SmallTreeVAE will be trained for num_epochs_smalltree and
+ attached to the selected leaf of TreeVAE. The resulting TreeVAE will then grow at each step and will be finetuned
+ throughout the growing procedure for num_epochs_intermediate_fulltrain and at the end of the growing procedure for
+ num_epochs_finetuning.
+
+ Parameters
+ ----------
+ trainset: torch.utils.data.Dataset
+ The train dataset
+ trainset_eval: torch.utils.data.Dataset
+ The validation dataset
+ testset: torch.utils.data.Dataset
+ The test dataset
+ device: torch.device
+ The device in which to validate the model
+ configs: dict
+ The config setting for training and validating TreeVAE defined in configs or in the command line
+
+ Returns
+ ------
+ models.model.TreeVAE
+ The trained TreeVAE model
+ """
+
+ graph_mode = not configs['globals']['eager_mode']
+ gen_train = get_gen(trainset, configs, validation=False, shuffle=True)
+ gen_train_eval = get_gen(trainset_eval, configs, validation=True, shuffle=False)
+ gen_test = get_gen(testset, configs, validation=True, shuffle=False)
+ _ = gc.collect()
+
+ # Define model & optimizer
+ model = TreeVAE(**configs['training'])
+ model.to(device)
+
+ if graph_mode:
+ model = torch.compile(model)
+
+ optimizer = get_optimizer(model, configs)
+
+ # Initialize schedulers
+ lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=configs['training']['decay_stepsize'],
+ gamma=configs['training']['decay_lr'])
+ alpha_scheduler = AnnealKLCallback(model, decay=configs['training']['decay_kl'],
+ start=configs['training']['kl_start'])
+
+ # Initialize Metrics
+ metrics_calc_train = Custom_Metrics(device).to(device)
+ metrics_calc_val = Custom_Metrics(device).to(device)
+
+ ################################# TRAINING TREEVAE with depth defined in config #################################
+
+ # Training the initial tree
+ for epoch in range(configs['training']['num_epochs']): # loop over the dataset multiple times
+ train_one_epoch(gen_train, model, optimizer, metrics_calc_train, epoch, device)
+ validate_one_epoch(gen_test, model, metrics_calc_val, epoch, device)
+ lr_scheduler.step()
+ alpha_scheduler.on_epoch_end(epoch)
+ _ = gc.collect()
+
+ ################################# GROWING THE TREE #################################
+
+ # Start the growing loop of the tree
+ # Compute metrics and set node.expand False for the nodes that should not grow
+ # This loop goes layer-wise
+ grow = configs['training']['grow']
+ initial_depth = configs['training']['initial_depth']
+ max_depth = len(configs['training']['mlp_layers']) - 1
+ if initial_depth >= max_depth:
+ grow = False
+ growing_iterations = 0
+ while grow and growing_iterations < 150:
+
+ # full model finetuning during growing after every 3 splits
+ if configs['training']['num_epochs_intermediate_fulltrain']>0:
+ if growing_iterations != 0 and growing_iterations % 3 == 0:
+ # Initialize optimizer and schedulers
+ optimizer = get_optimizer(model, configs)
+ lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=configs['training']['decay_stepsize'],
+ gamma=configs['training']['decay_lr'])
+ alpha_scheduler = AnnealKLCallback(model, decay=configs['training']['decay_kl'],
+ start=configs['training']['kl_start'])
+
+ # Training the initial split
+ print('\nTree intermediate finetuning\n')
+ for epoch in range(configs['training']['num_epochs_intermediate_fulltrain']):
+ train_one_epoch(gen_train, model, optimizer, metrics_calc_train, epoch, device)
+ validate_one_epoch(gen_test, model, metrics_calc_val, epoch, device)
+ lr_scheduler.step()
+ alpha_scheduler.on_epoch_end(epoch)
+ _ = gc.collect()
+
+ # extract information of leaves
+ node_leaves_train = predict(gen_train_eval, model, device, 'node_leaves')
+ node_leaves_test = predict(gen_test, model, device, 'node_leaves')
+
+ # compute which leaf to grow and split
+ ind_leaf, leaf, n_effective_leaves = compute_growing_leaf(gen_train_eval, model, node_leaves_train, max_depth,
+ configs['training']['batch_size'],
+ max_leaves=configs['training']['num_clusters_tree'])
+ if ind_leaf == None:
+ break
+ else:
+ print('\nGrowing tree: Leaf %d at depth %d\n' % (ind_leaf, leaf['depth']))
+ depth, node = leaf['depth'], leaf['node']
+
+ # get subset of data that has high prob. of falling in subtree
+ ind_train = get_ind_small_tree(node_leaves_train[ind_leaf], n_effective_leaves)
+ ind_test = get_ind_small_tree(node_leaves_test[ind_leaf], n_effective_leaves)
+ gen_train_small = get_gen(trainset, configs, shuffle=True, smalltree=True, smalltree_ind=ind_train)
+ gen_test_small = get_gen(testset, configs, shuffle=False, validation=True, smalltree=True,
+ smalltree_ind=ind_test)
+
+ # preparation for the smalltree training
+ # initialize the smalltree
+ small_model = SmallTreeVAE(depth=depth+1, **configs['training'])
+ small_model.to(device)
+ if graph_mode:
+ small_model = torch.compile(small_model)
+
+ # Optimizer for smalltree
+ optimizer = get_optimizer(small_model, configs)
+
+ # Initialize schedulers
+ lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=configs['training']['decay_stepsize'],
+ gamma=configs['training']['decay_lr'])
+ alpha_scheduler = AnnealKLCallback(small_model, decay=configs['training']['decay_kl'],
+ start=configs['training']['kl_start'])
+
+ # Training the smalltree subsplit
+ for epoch in range(configs['training']['num_epochs_smalltree']):
+ train_one_epoch(gen_train_small, model, optimizer, metrics_calc_train, epoch, device, train_small_tree=True,
+ small_model=small_model, ind_leaf=ind_leaf)
+ validate_one_epoch(gen_test_small, model, metrics_calc_val, epoch, device, train_small_tree=True,
+ small_model=small_model, ind_leaf=ind_leaf)
+ lr_scheduler.step()
+ alpha_scheduler.on_epoch_end(epoch)
+ _ = gc.collect()
+
+ # attach smalltree to full tree by assigning decisions and adding new children nodes to full tree
+ model.attach_smalltree(node, small_model)
+
+ # Check if reached the max number of effective leaves before finetuning unnecessarily
+ if n_effective_leaves + 1 == configs['training']['num_clusters_tree']:
+ node_leaves_train = predict(gen_train_eval, model, device, 'node_leaves')
+ _, _, max_growth = compute_growing_leaf(gen_train_eval, model, node_leaves_train, max_depth,
+ configs['training']['batch_size'],
+ max_leaves=configs['training']['num_clusters_tree'], check_max=True)
+ if max_growth is True:
+ break
+
+ growing_iterations += 1
+
+ # The growing loop of the tree is concluded!
+ # check whether we need to prune the final tree and log pre-pruning dendrogram
+ prune = configs['training']['prune']
+ if prune:
+ node_leaves_test, prob_leaves_test = predict(gen_test, model, device, 'node_leaves', 'prob_leaves')
+ if len(node_leaves_test)<2:
+ prune = False
+ else:
+ print('\nStarting pruning!\n')
+ yy = np.squeeze(np.argmax(prob_leaves_test, axis=-1))
+ y_test = testset.dataset.targets[testset.indices]
+ data_tree = construct_data_tree(model, y_predicted=yy, y_true=y_test, n_leaves=len(node_leaves_test),
+ data_name=configs['data']['data_name'])
+
+ table = wandb.Table(columns=["node_id", "node_name", "parent", "size"], data=data_tree)
+ fields = {"node_name": "node_name", "node_id": "node_id", "parent": "parent", "size": "size"}
+ dendro = wandb.plot_table(vega_spec_name="stacey/flat_tree", data_table=table, fields=fields)
+ wandb.log({"dendogram_pre_pruned": dendro})
+
+ # prune the tree
+ while prune:
+ # check pruning conditions
+ node_leaves_train = predict(gen_train_eval, model, device, 'node_leaves')
+ ind_leaf, leaf = compute_pruning_leaf(model, node_leaves_train)
+
+ if ind_leaf == None:
+ print('\nPruning finished!\n')
+ break
+ else:
+ # prune leaves and internal nodes without children
+ print(f'\nPruning leaf {ind_leaf}!\n')
+ current_node = leaf['node']
+ while all(child is None for child in [current_node.left, current_node.right]):
+ if current_node.parent is not None:
+ parent = current_node.parent
+ # root does not get pruned
+ else:
+ break
+ parent.prune_child(current_node)
+ current_node = parent
+
+
+ # reinitialize model
+ transformations, routers, denses, decoders, routers_q = return_list_tree(model.tree)
+ model.decisions_q = routers_q
+ model.transformations = transformations
+ model.decisions = routers
+ model.denses = denses
+ model.decoders = decoders
+ model.depth = model.compute_depth()
+ _ = gc.collect()
+
+ ################################# FULL MODEL FINETUNING #################################
+
+
+ print('\n*****************model depth %d******************\n' % (model.depth))
+ print('\n*****************model finetuning******************\n')
+
+ # Initialize optimizer and schedulers
+ optimizer = get_optimizer(model, configs)
+ lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=configs['training']['decay_stepsize'], gamma=configs['training']['decay_lr'])
+ alpha_scheduler = AnnealKLCallback(model, decay=max(0.01,1/max(1,configs['training']['num_epochs_finetuning']-1)), start=configs['training']['kl_start'])
+ # finetune the full tree
+ print('\nTree final finetuning\n')
+ for epoch in range(configs['training']['num_epochs_finetuning']): # loop over the dataset multiple times
+ train_one_epoch(gen_train, model, optimizer, metrics_calc_train, epoch, device)
+ validate_one_epoch(gen_test, model, metrics_calc_val, epoch, device)
+ lr_scheduler.step()
+ alpha_scheduler.on_epoch_end(epoch)
+ _ = gc.collect()
+
+ return model
+
+
diff --git a/treevae/train/validate_tree.py b/treevae/train/validate_tree.py
new file mode 100644
index 0000000..efbbbc4
--- /dev/null
+++ b/treevae/train/validate_tree.py
@@ -0,0 +1,192 @@
+import wandb
+import numpy as np
+from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score
+import gc
+import yaml
+import torch
+import scipy
+from tqdm import tqdm
+
+from utils.data_utils import get_gen
+from utils.utils import cluster_acc, dendrogram_purity, leaf_purity
+from utils.training_utils import compute_leaves, validate_one_epoch, Custom_Metrics, predict
+from utils.model_utils import construct_data_tree
+from models.losses import loss_reconstruction_cov_mse_eval
+
+
+def val_tree(trainset, testset, model, device, experiment_path, configs):
+ """
+ Run the validation of a trained instance of TreeVAE on both the train and test datasets. All final results and
+ validations will be stored in Wandb, while the most important ones will be also printed out in the terminal.
+
+ Parameters
+ ----------
+ trainset: torch.utils.data.Dataset
+ The train dataset
+ testset: torch.utils.data.Dataset
+ The test dataset
+ model: models.model.TreeVAE
+ The trained TreeVAE model
+ device: torch.device
+ The device in which to validate the model
+ experiment_path: str
+ The experimental path where to store the tree
+ configs: dict
+ The config setting for training and validating TreeVAE defined in configs or in the command line
+ """
+
+ ############ Training set performance ############
+
+ # get the data loader
+ gen_train_eval = get_gen(trainset, configs, validation=True, shuffle=False)
+ y_train = trainset.dataset.targets[trainset.indices].numpy()
+ # compute the leaf probabilities
+ prob_leaves_train = predict(gen_train_eval, model, device, 'prob_leaves')
+ _ = gc.collect()
+ # compute the predicted cluster
+ y_train_pred = np.squeeze(np.argmax(prob_leaves_train, axis=-1)).numpy()
+ # compute clustering metrics
+ acc, idx = cluster_acc(y_train, y_train_pred, return_index=True)
+ nmi = normalized_mutual_info_score(y_train, y_train_pred)
+ ari = adjusted_rand_score(y_train, y_train_pred)
+ wandb.log({"Train Accuracy": acc, "Train Normalized Mutual Information": nmi, "Train Adjusted Rand Index": ari})
+ # compute confusion matrix
+ swap = dict(zip(range(len(idx)), idx))
+ y_wandb = np.array([swap[i] for i in y_train_pred], dtype=np.uint8)
+ wandb.log({"Train_confusion_matrix":
+ wandb.plot.confusion_matrix(probs=None, y_true=y_train, preds=y_wandb, class_names=range(len(idx)))})
+
+ ############ Test set performance ############
+
+ # get the data loader
+ gen_test = get_gen(testset, configs, validation=True, shuffle=False)
+ y_test = testset.dataset.targets[testset.indices].numpy()
+ # compute one validation pass through the test set to log losses
+ metrics_calc_test = Custom_Metrics(device)
+ validate_one_epoch(gen_test, model, metrics_calc_test, 0, device, test=True)
+ _ = gc.collect()
+ # predict the leaf probabilities and the leaves
+ node_leaves_test, prob_leaves_test = predict(gen_test, model, device, 'node_leaves', 'prob_leaves')
+ _ = gc.collect()
+ # compute the predicted cluster
+ y_test_pred = np.squeeze(np.argmax(prob_leaves_test, axis=-1)).numpy()
+ # Calculate clustering metrics
+ acc, idx = cluster_acc(y_test, y_test_pred, return_index=True)
+ nmi = normalized_mutual_info_score(y_test, y_test_pred)
+ ari = adjusted_rand_score(y_test, y_test_pred)
+ wandb.log({"Test Accuracy": acc, "Test Normalized Mutual Information": nmi, "Test Adjusted Rand Index": ari})
+ # Calculate confusion matrix
+ swap = dict(zip(range(len(idx)), idx))
+ y_wandb = np.array([swap[i] for i in y_test_pred], dtype=np.uint8)
+ wandb.log({"Test_confusion_matrix": wandb.plot.confusion_matrix(probs=None,
+ y_true=y_test, preds=y_wandb,
+ class_names=range(len(idx)))})
+
+ # Determine indices of samples that fall into each leaf for Dendogram Purity & Leaf Purity
+ leaves = compute_leaves(model.tree)
+ ind_samples_of_leaves = []
+ for i in range(len(leaves)):
+ ind_samples_of_leaves.append([leaves[i]['node'], np.where(y_test_pred == i)[0]])
+ # Calculate leaf and dedrogram purity
+ dp = dendrogram_purity(model.tree, y_test, ind_samples_of_leaves)
+ lp = leaf_purity(model.tree, y_test, ind_samples_of_leaves)
+ # Note: Only comparable DP & LP values wrt baselines if they have the same n_leaves for all methods
+ wandb.log({"Test Dendrogram Purity": dp, "Test Leaf Purity": lp})
+
+ # Save the tree structure of TreeVAE and log it
+ data_tree = construct_data_tree(model, y_predicted=y_test_pred, y_true=y_test, n_leaves=len(node_leaves_test),
+ data_name=configs['data']['data_name'])
+
+ if configs['globals']['save_model']:
+ with open(experiment_path / 'data_tree.npy', 'wb') as save_file:
+ np.save(save_file, data_tree)
+ with open(experiment_path / 'config.yaml', 'w', encoding='utf8') as outfile:
+ yaml.dump(configs, outfile, default_flow_style=False, allow_unicode=True)
+
+ table = wandb.Table(columns=["node_id", "node_name", "parent", "size"], data=data_tree)
+ fields = {"node_name": "node_name", "node_id": "node_id", "parent": "parent", "size": "size"}
+ dendro = wandb.plot_table(vega_spec_name="stacey/flat_tree", data_table=table, fields=fields)
+ wandb.log({"dendogram_final": dendro})
+
+ # Printing important results
+ print(np.unique(y_test_pred, return_counts=True))
+ print("Accuracy:", acc)
+ print("Normalized Mutual Information:", nmi)
+ print("Adjusted Rand Index:", ari)
+ print("Dendrogram Purity:", dp)
+ print("Leaf Purity:", lp)
+ print("Digits", np.unique(y_test))
+
+ # Compute the log-likehood of the test data
+ # ATTENTION it might take a while! If not interested disable the setting in configs
+ if configs['training']['compute_ll']:
+ compute_likelihood(testset, model, device, configs)
+ return
+
+
+def compute_likelihood(testset, model, device, configs):
+ """
+ Compute the approximated log-likelihood calculated using 1000 importance-weighted samples.
+
+ Parameters
+ ----------
+ testset: torch.utils.data.Dataset
+ The test dataset
+ model: models.model.TreeVAE
+ The trained TreeVAE model
+ device: torch.device
+ The device in which to validate the model
+ configs: dict
+ The config setting for training and validating TreeVAE defined in configs or in the command line
+ """
+ ESTIMATION_SAMPLES = 1000
+ gen_test = get_gen(testset, configs, validation=True, shuffle=False)
+ print('\nComputing the log likelihood.... it might take a while.')
+ if configs['training']['activation'] == 'sigmoid':
+ elbo = np.zeros((len(testset), ESTIMATION_SAMPLES))
+ for j in tqdm(range(ESTIMATION_SAMPLES)):
+ elbo[:, j] = predict(gen_test, model, device, 'elbo')
+ _ = gc.collect()
+ elbo_new = elbo[:, :ESTIMATION_SAMPLES]
+ log_likel = np.log(1 / ESTIMATION_SAMPLES) + scipy.special.logsumexp(-elbo_new, axis=1)
+ marginal_log_likelihood = np.sum(log_likel) / len(testset)
+ wandb.log({"test log-likelihood": marginal_log_likelihood})
+ print("Test log-likelihood", marginal_log_likelihood)
+ output_elbo, output_rec_loss = predict(gen_test, model, device, 'elbo', 'rec_loss')
+ print('Test ELBO:', -torch.mean(output_elbo))
+ print('Test Reconstruction Loss:', torch.mean(output_rec_loss))
+
+ elif configs['training']['activation'] == 'mse':
+ # Correct calculation of ELBO and Loglikelihood for 3channel images without assuming diagonal gaussian for
+ # reconstruction
+ old_loss = model.loss
+ model.loss = loss_reconstruction_cov_mse_eval
+ # Note that for comparability to other papers, one might want to add Uniform(0,1) noise to the input images
+ # (in 0,255), to go from the discrete to the assumed continuous inputs
+ # x_test_elbo = x_test * 255
+ # x_test_elbo = (x_test_elbo + tfd.Uniform().sample(x_test_elbo.shape)) / 256
+ output_elbo, output_rec_loss = predict(gen_test, model, device, 'elbo', 'rec_loss')
+ nelbo = torch.mean(output_elbo)
+ nelbo_bpd = nelbo / (torch.log(torch.tensor(2)) * configs['training']['inp_shape']) + 8 # Add 8 to account normalizing of inputs
+ model.loss = old_loss
+ elbo = np.zeros((len(testset), ESTIMATION_SAMPLES))
+ for j in range(ESTIMATION_SAMPLES):
+ # x_test_elbo = x_test * 255
+ # x_test_elbo = (x_test_elbo + tfd.Uniform().sample(x_test_elbo.shape)) / 256
+ output_elbo = predict(gen_test, model, device, 'elbo')
+ elbo[:, j] = output_elbo
+ # Change to bpd
+ elbo_new = elbo[:, :ESTIMATION_SAMPLES]
+ log_likel = np.log(1 / ESTIMATION_SAMPLES) + scipy.special.logsumexp(-elbo_new, axis=1)
+ marginal_log_likelihood = np.sum(log_likel) / len(testset)
+ marginal_log_likelihood = marginal_log_likelihood / (
+ torch.log(torch.tensor(2)) * configs['training']['inp_shape']) - 8
+ wandb.log({"test log-likelihood": marginal_log_likelihood})
+ print('Test Log-Likelihood Bound:', marginal_log_likelihood)
+ print('Test ELBO:', -nelbo_bpd)
+ print('Test Reconstruction Loss:',
+ torch.mean(output_rec_loss) / (torch.log(torch.tensor(2)) * configs['training']['inp_shape']) + 8)
+ model.loss = old_loss
+ else:
+ raise NotImplementedError
+ return
diff --git a/treevae/tree_exploration.ipynb b/treevae/tree_exploration.ipynb
new file mode 100644
index 0000000..7edef2c
--- /dev/null
+++ b/treevae/tree_exploration.ipynb
@@ -0,0 +1,1025 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# TreeVAE"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Table Of Contents\n",
+ "1. [Data Loading](#section_1)\n",
+ "2. [Generations](#section_2)\n",
+ "3. [Reconstructions](#section_3)\n",
+ "4. [Tree and Representation Analysis](#section_4)\n",
+ "5. [CelebA Attributes](#section_5)\n",
+ "\n",
+ "This is the notebook for analyzing and visualizing the trees learnt by TreeVAE. \n",
+ "\n",
+ "Trees can be learnt by running main.py and stored by setting the option save_model to True in the config file.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Data Loading"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Always execute this section first. This section loads the data and model and computes the NMI to ensure that the model was loaded correctly. Make sure to set the path in the second cell to the specific model that you want to analyze."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import torch\n",
+ "import torchvision\n",
+ "from matplotlib import pyplot as plt\n",
+ "from models.model import TreeVAE\n",
+ "import scipy\n",
+ "import os\n",
+ "import yaml\n",
+ "import gc\n",
+ "from tqdm import tqdm\n",
+ "from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score\n",
+ "from pathlib import Path\n",
+ "from utils.utils import reset_random_seeds, display_image\n",
+ "from utils.data_utils import get_data, get_gen\n",
+ "from utils.training_utils import compute_leaves, predict, move_to\n",
+ "from train.validate_tree import compute_likelihood\n",
+ "from models.model_smalltree import SmallTreeVAE\n",
+ "from models.losses import loss_reconstruction_binary, loss_reconstruction_mse, loss_reconstruction_cov_mse_eval\n",
+ "from utils.model_utils import Node, construct_tree_fromnpy, return_list_tree, construct_data_tree, construct_tree_fromnpy\n",
+ "from utils.plotting_utils import plot_tree_graph, get_node_embeddings, draw_tree_with_scatter_plots"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "path = 'models/experiments/'\n",
+ "ex_path = 'mnist/20231025-175819_d6be9' # INSERT YOUR PATH HERE\n",
+ "checkpoint_path = path+ex_path\n",
+ "with open(checkpoint_path + \"/config.yaml\", 'r') as stream:\n",
+ " configs = yaml.load(stream,Loader=yaml.Loader)\n",
+ "print(configs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load Data\n",
+ "trainset, trainset_eval, testset = get_data(configs)\n",
+ "gen_train = get_gen(trainset, configs, validation=False, shuffle=False)\n",
+ "gen_train_eval = get_gen(trainset_eval, configs, validation=True, shuffle=False)\n",
+ "gen_test = get_gen(testset, configs, validation=True, shuffle=False)\n",
+ "gen_train_eval_iter = iter(gen_train_eval)\n",
+ "gen_test_iter = iter(gen_test)\n",
+ "y_train = trainset_eval.dataset.targets[trainset_eval.indices].numpy()\n",
+ "y_test = testset.dataset.targets[testset.indices].numpy()\n",
+ "\n",
+ "# Load Model\n",
+ "n_d = configs['training']['num_clusters_tree']\n",
+ "model = TreeVAE(**configs['training'])\n",
+ "data_tree = np.load(checkpoint_path+'/data_tree.npy', allow_pickle=True)\n",
+ "model = construct_tree_fromnpy(model, data_tree, configs)\n",
+ "if not configs['globals']['eager_mode']:\n",
+ " model = torch.compile(model)\n",
+ "model.load_state_dict(torch.load(checkpoint_path+'/model_weights.pt'),strict=True)\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "model.to(device)\n",
+ "model.eval()\n",
+ "\n",
+ "plot_tree_graph(data_tree)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Compute Train NMI\n",
+ "prob_leaves = predict(gen_train_eval, model, device,'prob_leaves')\n",
+ "y = np.squeeze(np.argmax(prob_leaves, axis=-1))\n",
+ "print('Train NMI:',normalized_mutual_info_score(y, np.squeeze(y_train)))\n",
+ "\n",
+ "tot_counts = []\n",
+ "print(\" Leaf\", np.arange(10))\n",
+ "for i in np.unique(y_test):\n",
+ " list_y_hat, counts = np.unique(y[np.squeeze(y_train)==i], return_counts=True)\n",
+ " for j in range(n_d):\n",
+ " if j not in list_y_hat:\n",
+ " list_y_hat = np.insert(list_y_hat, j, j)\n",
+ " counts = np.insert(counts, j, 0)\n",
+ " tot_counts.append(counts)\n",
+ " print(f\"Class {i:<10}\", list_y_hat, counts)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# Compute Test NMI\n",
+ "prob_leaves = predict(gen_test, model, device,'prob_leaves')\n",
+ "y = np.squeeze(np.argmax(prob_leaves, axis=-1))\n",
+ "print('Test NMI:', normalized_mutual_info_score(y, np.squeeze(y_test)))\n",
+ "\n",
+ "tot_counts = []\n",
+ "print(\" Leaf\", np.arange(10))\n",
+ "for i in np.unique(y_test):\n",
+ " list_y_hat, counts = np.unique(y[np.squeeze(y_test)==i], return_counts=True)\n",
+ " for j in range(n_d):\n",
+ " if j not in list_y_hat:\n",
+ " list_y_hat = np.insert(list_y_hat, j, j)\n",
+ " counts = np.insert(counts, j, 0)\n",
+ " tot_counts.append(counts)\n",
+ " print(f\"Class {i:<10}\", list_y_hat, counts)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Generations"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This section is concerned with unconditionally generating new samples as opposed to reconstructing existing data points."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Clusterwise generations"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here, given one unconditional random sampling from the root, we visualize the generations for each leaf. That is, each row corresponds to one sample and each column corresponds to one leaf. Above each generation, we provide the probability of falling into the respective leaf for this sample. \n",
+ "\n",
+ "This way of visualization can provide insights on the characteristics each leaf is associated with. Observe that the generations differ across the leaves, as each leaf decodes the sample in the style of the cluster that it learnt. It is likely that cluster-differences are observed more strongly than in the reconstructions' section, as here, we have no guiding information from the bottom-up."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "n_imgs = 15\n",
+ "with torch.no_grad():\n",
+ " reconstructions, p_c_z = model.generate_images(n_imgs, device)\n",
+ "reconstructions = move_to(reconstructions, 'cpu')\n",
+ "for i in range(n_imgs):\n",
+ " fig, axs = plt.subplots(1, n_d, figsize=(15, 15))\n",
+ " for c in range(n_d):\n",
+ " axs[c].imshow(display_image(reconstructions[c][i]), cmap=plt.get_cmap('gray'))\n",
+ " axs[c].set_title(f\"L{c}: \" + f\"p=%.2f\" % torch.round(p_c_z[i][c],decimals=2))\n",
+ " axs[c].axis('off')\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Generate new images according to cluster assignment"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this subsection, given a leaf, we store the first 100 generations, for which this leaf is their most likely cluster assignment. This allows us to gain insights on the cluster and characteristics that each leaf learnt."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Here, we store generations for each leaf simultaneously until\n",
+ "# every leaf has n_imgs associated generations, or we iterated through max_iter batches.\n",
+ "n_imgs = configs['training']['batch_size']\n",
+ "max_iter = 200\n",
+ "\n",
+ "with torch.no_grad():\n",
+ " reconstructions, p_c_z = model.generate_images(n_imgs, device)\n",
+ "reconstructions = move_to(reconstructions, 'cpu')\n",
+ "clusterwise_reconst = [torch.zeros_like(reconstructions[0][0:2]) for i in range(len(reconstructions))]\n",
+ "n_iter=0\n",
+ "while min([clusterwise_reconst[leaf_ind].shape[0] for leaf_ind in range(len(reconstructions))]) < n_imgs+2 and n_iter < max_iter:\n",
+ " for i in range(n_imgs):\n",
+ " leaf_ind = torch.argmax(p_c_z[i])\n",
+ " if clusterwise_reconst[leaf_ind].shape[0] < n_imgs+2:\n",
+ " clusterwise_reconst[leaf_ind] = torch.vstack([clusterwise_reconst[leaf_ind], reconstructions[leaf_ind][i].unsqueeze(0)])\n",
+ " with torch.no_grad():\n",
+ " reconstructions, p_c_z = model.generate_images(n_imgs, device)\n",
+ " reconstructions = move_to(reconstructions, 'cpu')\n",
+ " n_iter += 1\n",
+ " if n_iter %10 == 0:\n",
+ " print(n_iter)\n",
+ "for i in range(len(reconstructions)):\n",
+ " clusterwise_reconst[i] = clusterwise_reconst[i][2:,:]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# For each leaf, we visualize n_grid x n_grid generations, \n",
+ "# which have highest probability of being assigned to this cluster\n",
+ "n_leaves = len(clusterwise_reconst)\n",
+ "n_grid = min(5,int((clusterwise_reconst[leaf_ind].shape[0])**.5))\n",
+ "\n",
+ "k=0\n",
+ "for l in range(n_leaves):\n",
+ " fig, axs = plt.subplots(n_grid, n_grid, figsize=(4,4))\n",
+ " i=0\n",
+ " for a in range(n_grid):\n",
+ " for b in range(n_grid):\n",
+ " axs[a,b].set_axis_off()\n",
+ " axs[a,b].imshow(display_image(clusterwise_reconst[k][i]), cmap=plt.get_cmap('gray'))\n",
+ " i+=1\n",
+ " fig.suptitle(f\"Leaf {k} samples\",fontsize=25)\n",
+ " fig.tight_layout()\n",
+ " fig.subplots_adjust(top=0.87)\n",
+ " k+=1\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# For closer inspection, one can select a specific leaf by leaf_ind and investigate more generations.\n",
+ "leaf_ind = 0\n",
+ " \n",
+ "n_grid = int((clusterwise_reconst[leaf_ind].shape[0])**.5)\n",
+ "fig, axs = plt.subplots(n_grid, n_grid, figsize=(15,15))\n",
+ "\n",
+ "i=0\n",
+ "for a in range(n_grid):\n",
+ " for b in range(n_grid):\n",
+ " axs[a,b].set_axis_off()\n",
+ " axs[a,b].imshow(display_image(clusterwise_reconst[leaf_ind][i]), cmap=plt.get_cmap('gray'))\n",
+ " i+=1\n",
+ "fig.suptitle(f\"Leaf {leaf_ind} samples\",fontsize=25)\n",
+ "fig.tight_layout()\n",
+ "fig.subplots_adjust(top=0.95)\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Reconstructions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This section is concerned with computing reconstructions of input samples."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Clusterwise reconstructions"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here, given one input image, we visualize the reconstructions for each leaf. That is, each row corresponds to one input image and each column corresponds to one leaf. Above each reconstruction, we provide the probability of falling into the respective leaf for this sample. \n",
+ "\n",
+ "This way of visualization can provide insights on the characteristics each leaf is associated with. Observe that the reconstructions differ across the leaves, as each leaf reconstructs the image in the style of the cluster that it learnt."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# Training Set\n",
+ "gen_train_eval_iter = iter(gen_train_eval)\n",
+ "inputs, labels = next(gen_train_eval_iter)\n",
+ "\n",
+ "\n",
+ "inputs_gpu, labels_gpu = inputs.to(device), labels.to(device)\n",
+ "with torch.no_grad():\n",
+ " reconstructions_gpu, node_leaves_gpu = model.compute_reconstruction(inputs_gpu)\n",
+ "reconstructions = move_to(reconstructions_gpu, 'cpu')\n",
+ "node_leaves = move_to(node_leaves_gpu, 'cpu')\n",
+ "\n",
+ "\n",
+ "for i in range(10):\n",
+ " print(\"Class:\", labels[i].item())\n",
+ " fig, axs = plt.subplots(1, n_d+1, figsize=(15, 15))\n",
+ " axs[n_d].imshow(display_image(inputs[i]), cmap=plt.get_cmap('gray'))\n",
+ " axs[n_d].set_title(\"Original\")\n",
+ " axs[n_d].axis('off')\n",
+ " for c in range(n_d):\n",
+ " axs[c].imshow(display_image(reconstructions[c][i]), cmap=plt.get_cmap('gray'))\n",
+ " axs[c].set_title(f\"L{c}: \" + f\"p=%.2f\" % torch.round(node_leaves[c]['prob'][i],decimals=2))\n",
+ " axs[c].axis('off')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# Test Set\n",
+ "gen_test_iter = iter(gen_test)\n",
+ "inputs, labels = next(gen_test_iter)\n",
+ "inputs_gpu, labels_gpu = inputs.to(device), labels.to(device)\n",
+ "with torch.no_grad():\n",
+ " reconstructions_gpu, node_leaves_gpu = model.compute_reconstruction(inputs_gpu)\n",
+ "reconstructions = move_to(reconstructions_gpu, 'cpu')\n",
+ "node_leaves = move_to(node_leaves_gpu, 'cpu')\n",
+ "\n",
+ "\n",
+ "for i in range(10):\n",
+ " print(\"Class:\", labels[i].item())\n",
+ " fig, axs = plt.subplots(1, n_d+1, figsize=(15, 15))\n",
+ " axs[n_d].imshow(display_image(inputs[i]), cmap=plt.get_cmap('gray'))\n",
+ " axs[n_d].set_title(\"Original\")\n",
+ " axs[n_d].axis('off')\n",
+ " for c in range(n_d):\n",
+ " axs[c].imshow(display_image(reconstructions[c][i]), cmap=plt.get_cmap('gray'))\n",
+ " axs[c].set_title(f\"L{c}: \" + f\"p=%.2f\" % torch.round(node_leaves[c]['prob'][i],decimals=2))\n",
+ " axs[c].axis('off')\n",
+ "\n",
+ " plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Group reconstructions according to cluster assignment"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this subsection, given a leaf, we store the reconstructions of the first 100 samples, for which this leaf is their most likely cluster assignment. This allows us to visualize for each leaf, which samples fall into it, in order to gain insights on the cluster that each leaf learnt."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Test Set\n",
+ "# Here, we store samples for each leaf simultaneously by iterating through the training set until\n",
+ "# every leaf has n_imgs associated samples, or we iterated through max_iter batches.\n",
+ "max_iter = 100\n",
+ "n_imgs = configs['training']['batch_size']\n",
+ "\n",
+ "n_iter=0\n",
+ "gen_test_iter = iter(gen_test)\n",
+ "inputs, labels = next(gen_test_iter)\n",
+ "inputs_gpu, labels_gpu = inputs.to(device), labels.to(device)\n",
+ "with torch.no_grad():\n",
+ " reconstructions_gpu, node_leaves_gpu = model.compute_reconstruction(inputs_gpu)\n",
+ "reconstructions = move_to(reconstructions_gpu, 'cpu')\n",
+ "node_leaves = move_to(node_leaves_gpu, 'cpu')\n",
+ "p_c_z = torch.stack([node_leaves[i]['prob'] for i in range(len(node_leaves))],1)\n",
+ "clusterwise_reconst = [torch.zeros_like(reconstructions[0][0:2]) for i in range(len(reconstructions))]\n",
+ "while min([clusterwise_reconst[leaf_ind].shape[0] for leaf_ind in range(len(reconstructions))]) < n_imgs+2 and n_iter < max_iter:\n",
+ " n_iter += 1\n",
+ " if n_iter %10 == 0:\n",
+ " print(n_iter)\n",
+ " for i in range(n_imgs):\n",
+ " leaf_ind = p_c_z[i].numpy().argmax()\n",
+ " if clusterwise_reconst[leaf_ind].shape[0] < n_imgs+2:\n",
+ " clusterwise_reconst[leaf_ind] = torch.vstack([clusterwise_reconst[leaf_ind], reconstructions[leaf_ind][i].unsqueeze(0)])\n",
+ " inputs, labels = next(gen_test_iter)\n",
+ " inputs_gpu, labels_gpu = inputs.to(device), labels.to(device)\n",
+ " with torch.no_grad():\n",
+ " reconstructions_gpu, node_leaves_gpu = model.compute_reconstruction(inputs_gpu)\n",
+ " reconstructions = move_to(reconstructions_gpu, 'cpu')\n",
+ " node_leaves = move_to(node_leaves_gpu, 'cpu')\n",
+ " p_c_z = torch.stack([node_leaves[i]['prob'] for i in range(len(node_leaves))],1)\n",
+ "for i in range(len(reconstructions)):\n",
+ " clusterwise_reconst[i] = clusterwise_reconst[i][2:,:]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# For each leaf, we visualize n_grid x n_grid reconstructions of samples, \n",
+ "# which have highest probability of being assigned to this cluster\n",
+ "n_leaves = len(clusterwise_reconst)\n",
+ "n_grid = min(5,int((clusterwise_reconst[leaf_ind].shape[0])**.5))\n",
+ "\n",
+ "k=0\n",
+ "for l in range(n_leaves):\n",
+ " fig, axs = plt.subplots(n_grid, n_grid, figsize=(4,4))\n",
+ " i=0\n",
+ " for a in range(n_grid):\n",
+ " for b in range(n_grid):\n",
+ " axs[a,b].set_axis_off()\n",
+ " axs[a,b].imshow(display_image(clusterwise_reconst[k][i]), cmap=plt.get_cmap('gray'))\n",
+ " i+=1\n",
+ " fig.suptitle(f\"Leaf {k} samples\",fontsize=25)\n",
+ " fig.tight_layout()\n",
+ " fig.subplots_adjust(top=0.87)\n",
+ " k+=1\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# For closer inspection, one can select a specific leaf by leaf_ind and investigate more reconstructions.\n",
+ "leaf_ind = 0\n",
+ " \n",
+ "n_grid = int((clusterwise_reconst[leaf_ind].shape[0])**.5)\n",
+ "fig, axs = plt.subplots(n_grid, n_grid, figsize=(15,15))\n",
+ "\n",
+ "i=0\n",
+ "for a in range(n_grid):\n",
+ " for b in range(n_grid):\n",
+ " axs[a,b].set_axis_off()\n",
+ " axs[a,b].imshow(display_image(clusterwise_reconst[leaf_ind][i]), cmap=plt.get_cmap('gray'))\n",
+ " i+=1\n",
+ "fig.suptitle(f\"Leaf {leaf_ind} samples\",fontsize=25)\n",
+ "fig.tight_layout()\n",
+ "fig.subplots_adjust(top=0.95)\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. Tree and Representation Analysis "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this section we explore the structure of the learnt tree as well as the representations"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Tree embeddings"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Below, we visualize the learnt embeddings by performing PCA on each node. Set use_pca to False if you want to directly see the first two dimensions without dimensionality reduction."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Do you want to look at pca embeddings or learnt representations\n",
+ "use_pca = True\n",
+ "\n",
+ "# pick data loader\n",
+ "data_loader = gen_test\n",
+ "\n",
+ "\n",
+ "# each entry in node_embeddings is a dictionary with keys 'prob' and 'z_sample' for each leaf\n",
+ "nb_nodes = len(data_tree)\n",
+ "node_embeddings = [{'prob': [], 'z_sample': []} for _ in range(nb_nodes)]\n",
+ "label_list = []\n",
+ "\n",
+ "# iterate over test data points\n",
+ "for inputs, labels in tqdm(data_loader):\n",
+ " inputs_gpu, labels_gpu = inputs.to(device), labels.to(device)\n",
+ "\n",
+ " label_list.append(labels)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " node_info = get_node_embeddings(model, inputs_gpu)\n",
+ " node_info = move_to(node_info, 'cpu')\n",
+ "\n",
+ " # for each node, append the probability and z_sample to the list\n",
+ "\n",
+ " k = 0 # need this variable to skip \"no digits\" nodes\n",
+ " for i in range(nb_nodes):\n",
+ " j = i - k\n",
+ " if data_tree[i][1] == 'no digits':\n",
+ " k += 1\n",
+ " continue\n",
+ "\n",
+ " node_embeddings[i]['prob'].append(node_info[j]['prob'].numpy())\n",
+ " node_embeddings[i]['z_sample'].append(node_info[j]['z_sample'].numpy())\n",
+ "\n",
+ "# flatten the lists\n",
+ "k = 0\n",
+ "for i in range(nb_nodes):\n",
+ " if data_tree[i][1] == 'no digits':\n",
+ " node_embeddings[i]['prob'] = []\n",
+ " node_embeddings[i]['z_sample'] = []\n",
+ " continue\n",
+ " \n",
+ " node_embeddings[i]['prob'] = np.concatenate(node_embeddings[i]['prob'])\n",
+ " node_embeddings[i]['z_sample'] = np.concatenate(node_embeddings[i]['z_sample'])\n",
+ "\n",
+ "label_list = np.concatenate(label_list)\n",
+ "\n",
+ "# Draw the tree graph with scatter plots as nodes and arrows for edges\n",
+ "draw_tree_with_scatter_plots(data_tree, node_embeddings, label_list, pca = use_pca)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Leaf embeddings"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Below, we visualize the learnt leaf embeddings after performing PCA. This allows for a closer inspection of the leaf embeddings, which are also visualized in the tree above."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# get leaf embeddings for each test data point\n",
+ "gen_test_iter = iter(gen_test)\n",
+ "inputs, labels = next(gen_test_iter)\n",
+ "inputs_gpu, labels_gpu = inputs.to(device), labels.to(device)\n",
+ "with torch.no_grad():\n",
+ " reconstructions_gpu, node_leaves_gpu = model.compute_reconstruction(inputs_gpu)\n",
+ "reconstructions = move_to(reconstructions_gpu, 'cpu')\n",
+ "node_leaves = move_to(node_leaves_gpu, 'cpu')\n",
+ "\n",
+ "# each entry in node_leaves is a dictionary with keys 'prob' and 'z_sample' for each leaf\n",
+ "node_leaves = [{'prob': [], 'z_sample': []} for _ in range(n_d)]\n",
+ "label_list = []\n",
+ "\n",
+ "for inputs, labels in tqdm(gen_test):\n",
+ " inputs_gpu, labels_gpu = inputs.to(device), labels.to(device)\n",
+ "\n",
+ " label_list.append(labels)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " _, node_leaves_gpu = model.compute_reconstruction(inputs_gpu)\n",
+ " node_leaves_cpu = move_to(node_leaves_gpu, 'cpu')\n",
+ " \n",
+ " # for each leaf, append the probability and z_sample to the list\n",
+ " for i in range(n_d):\n",
+ " node_leaves[i]['prob'].append(node_leaves_cpu[i]['prob'].numpy())\n",
+ " node_leaves[i]['z_sample'].append(node_leaves_cpu[i]['z_sample'].numpy())\n",
+ "\n",
+ "# flatten the lists\n",
+ "for i in range(n_d):\n",
+ " node_leaves[i]['prob'] = np.concatenate(node_leaves[i]['prob'])\n",
+ " node_leaves[i]['z_sample'] = np.concatenate(node_leaves[i]['z_sample'])\n",
+ "\n",
+ "label_list = np.concatenate(label_list)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# visualize z_sample for each leaf, do PCA and plot in 2D\n",
+ "from sklearn.decomposition import PCA\n",
+ "\n",
+ "# PCA on node_leaves['z_sample']\n",
+ "colors = label_list\n",
+ "plt.figure(figsize=(20, 10))\n",
+ "\n",
+ "for i in range(n_d):\n",
+ " z_sample = node_leaves[i]['z_sample']\n",
+ " weights = node_leaves[i]['prob']\n",
+ "\n",
+ " pca = PCA(n_components=2)\n",
+ " z_sample_pca = pca.fit_transform(z_sample)\n",
+ "\n",
+ " plt.subplot(2, -(-len(node_leaves)//2), i+1)\n",
+ " plt.scatter(z_sample_pca[:, 0], z_sample_pca[:, 1], c=colors, cmap='tab10', alpha=weights)\n",
+ " plt.title(f\"Leaf {i}\")\n",
+ " plt.colorbar()\n",
+ " plt.xlabel(\"PC1\")\n",
+ " plt.ylabel(\"PC2\")\n",
+ "\n",
+ "plt.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. CelebA attributes"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This section is designated for analyzing the learnt splits and clusters of datasets without ground truth cluster labels, but various attributes. It is designed with a focus on CelebA."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "assert configs['data']['data_name'] == 'celeba'\n",
+ "import pandas as pd\n",
+ "data_dir = './data/celeba/'\n",
+ "attr = pd.read_csv(data_dir+'/list_attr_celeba.txt', sep=\"\\s+\", skiprows=1)\n",
+ "y_test = attr[182637:]\n",
+ "y_train = attr[:162770]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Calculate cluster-matching attributes"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Preprocessing step where we store for every node, the indeces of the test samples, whose most likely path went through said node"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Change to leafwise view of samples\n",
+ "prob_leaves = predict(gen_test, model, device,'prob_leaves')\n",
+ "y = np.squeeze(np.argmax(prob_leaves, axis=-1))\n",
+ "sample_ind = []\n",
+ "for i in range(len(np.unique(y))):\n",
+ " sample_ind.append([])\n",
+ "for i in np.unique(y):\n",
+ " sample_ind[i] = np.where(y==i)[0]\n",
+ " \n",
+ "# Fill all internal nodes and create datatree with corresponding samples\n",
+ "data_tree_ids = []\n",
+ "for i in range(len(data_tree)):\n",
+ " data_tree_ids.append([i,[]])\n",
+ "for listnode in reversed(data_tree_ids):\n",
+ " i = listnode[0]\n",
+ " if data_tree[i][3] == 1:\n",
+ " # If leaf, just copy samples from above\n",
+ " data_tree_ids[i][1] = sample_ind[i-(len(data_tree_ids)-len(sample_ind))]\n",
+ " else:\n",
+ " # If internal node, take samples from children\n",
+ " children = []\n",
+ " for j in range(len(data_tree)):\n",
+ " if data_tree[j][2] == i:\n",
+ " children.append(j)\n",
+ " assert len(children)==2\n",
+ " data_tree_ids[i][1] = np.sort(np.concatenate((data_tree_ids[children[0]][1],data_tree_ids[children[1]][1])))\n",
+ " \n",
+ " \n",
+ " \n",
+ "# Final ID-tree, where for each node, we store which test sample went through it\n",
+ "data_tree_ids"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For each split, we additionally store the five attribute that correlate most highly with the the split. This gives an intuition on what attributes the split is based on, i.e. which characteristics the split differentiates between.\n",
+ "\n",
+ "Note that for CelebA, the \"ground truth\" attributes are in our opinion not the most descriptive ones regarding overall image&cluster impression and focus sometimes on details, on which we don't pick up on."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# Highest correlated features per split\n",
+ "data_tree_new = data_tree.copy()\n",
+ "for i in range(len(data_tree_ids)):\n",
+ " in_leaf = False\n",
+ " node_ind = data_tree_ids[i][1]\n",
+ " # Samples in node before split\n",
+ " node_samples = y_test.iloc[node_ind]\n",
+ " # Split of samples\n",
+ " node_split = np.zeros(len(y_test))\n",
+ " children = []\n",
+ " for j in range(len(data_tree)):\n",
+ " if data_tree[j][2] == i:\n",
+ " children.append(j)\n",
+ " if children == []:\n",
+ " in_leaf = True\n",
+ " else:\n",
+ " in_leaf = False\n",
+ " if not in_leaf: \n",
+ " child_left = children[0]\n",
+ " node_split[data_tree_ids[child_left][1]] = 1\n",
+ " node_split = node_split[node_ind]\n",
+ " # Store corr coefficients\n",
+ " corr = np.corrcoef(np.concatenate((np.array(node_samples),np.expand_dims(node_split,1)),1).T)[len(y_test.columns),0:len(y_test.columns)]\n",
+ " data_tree_ids[i].append(corr)\n",
+ " \n",
+ " # Store 5 strongest correlations\n",
+ " ind = np.abs(corr).argsort()[-5:][::-1]\n",
+ " features = y_test.columns[ind].tolist()\n",
+ " for k in range(len(ind)):\n",
+ " if corr[ind[k]] < 0:\n",
+ " features[k] = 'not ' + features[k]\n",
+ " features[k] = features[k] + ' ({})'.format(round(corr[ind[k]], 2))\n",
+ " data_tree_ids[i].append(features)\n",
+ " \n",
+ "data_tree_ids"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a summary, for each attribute, we print the split that has the highest correlation with it. This gives an intuition on what internal node differentiates the most according to a given attribute."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Attributewise node with highest correlation (i.e. internal node that was splitting attribute the most)\n",
+ "attr_maxnode = y_test.columns.tolist()\n",
+ "for i in range(len(y_test.columns)):\n",
+ " attrcorr = []\n",
+ " for node in range(len(data_tree_ids)):\n",
+ " if len(data_tree_ids[node])==len(data_tree_ids[0]):\n",
+ " attrcorr.append(data_tree_ids[node][2][i])\n",
+ " attrcorr = np.array(attrcorr)\n",
+ " if len(np.argwhere(np.isnan(attrcorr)).squeeze(1))>0:\n",
+ " attrcorr[np.argwhere(np.isnan(attrcorr)).squeeze(1)] = 0\n",
+ " ind = np.argmax(np.abs(attrcorr))\n",
+ " attr_maxnode[i] = attr_maxnode[i] + \": \" + f'{ind}' ' ({})'.format(round(attrcorr[ind], 2))\n",
+ "attr_maxnode"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Evaluation of clusters"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We can analyze the clustering quality according to certain attributes. To do this, in the second cell, pick the indeces of the attributes, whose intersections you want to determine as ground truth clusterings. Then, the NMI is calculated for treating the selected attributes' intersections as ground-truth clusters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "y_test.columns"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Pick labels here\n",
+ "label_ind = [2,20,39]\n",
+ "print([attr_maxnode[i] for i in label_ind])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "if len(label_ind)==2:\n",
+ " label_dict = {\n",
+ " (-1, -1): 0,\n",
+ " (-1, 1): 1,\n",
+ " (1, -1): 2,\n",
+ " (1, 1): 3\n",
+ " }\n",
+ "else:\n",
+ " label_dict = {\n",
+ " (-1, -1, -1): 0,\n",
+ " (-1, -1, 1): 1,\n",
+ " (-1, 1, -1): 2,\n",
+ " (-1, 1, 1): 3,\n",
+ " (1, -1, -1): 4,\n",
+ " (1, -1, 1): 5,\n",
+ " (1, 1, -1): 6,\n",
+ " (1, 1, 1): 7\n",
+ " }\n",
+ "selected_classes = np.array(y_test.iloc[:, label_ind])\n",
+ "selected_classes = [tuple([x for x in a]) for a in selected_classes]\n",
+ "label_true = [label_dict[sample_labels] for sample_labels in selected_classes]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print('NMI:')\n",
+ "normalized_mutual_info_score(y, label_true)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Create attribute-wise percentage table for leaves"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This subsection presents the frequency of the attributes for each leaf. The numbers indicate the percentage of samples assigned to a given leaf, that contain a certain attribute. For example: 67% of all people assigned to leaf 3 are blonde."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "leaf_attr = []\n",
+ "n_leaves = len(np.unique(y))\n",
+ "for i in range(1,1+n_leaves):\n",
+ " data_tree_ids[-i].append((y_test.iloc[data_tree_ids[-i][1]] == 1).mean())\n",
+ " leaf_attr.append((y_test.iloc[data_tree_ids[-i][1]] == 1).mean())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "leaf_attr_table = pd.DataFrame(np.stack(leaf_attr)[::-1])\n",
+ "leaf_attr_table.columns = y_test.columns\n",
+ "\n",
+ "pd.set_option('display.max_columns', None)\n",
+ "leaf_attr_table.round(3)*100"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here, one can create new attributes by combining previous attributes"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "new_vars =[]\n",
+ "for i in range(1,1+n_leaves):\n",
+ " temp = y_test.iloc[data_tree_ids[-i][1]] == 1\n",
+ " temp['Hair_Loss'] = np.clip(temp['Bald'] + temp['Receding_Hairline'],0,1)\n",
+ " temp['Dark_Hair'] = np.clip(temp['Brown_Hair'] + temp['Black_Hair'],0,1)\n",
+ " temp['Happy'] = np.clip(temp['Smiling'] + temp['Mouth_Slightly_Open'],0,1)\n",
+ " temp['Light_Hair'] = np.clip(temp['Blond_Hair'] + temp['Gray_Hair'],0,1)\n",
+ " temp['Beard'] = np.clip(temp['5_o_Clock_Shadow'] + 1-temp['No_Beard'],0,1)\n",
+ "\n",
+ " new_vars.append([temp['Hair_Loss'].mean(),temp['Dark_Hair'].mean(),temp['Happy'].mean(),temp['Light_Hair'].mean(),temp['Beard'].mean()])\n",
+ " \n",
+ "new_vars_table = pd.DataFrame(np.stack(new_vars)[::-1])\n",
+ "new_vars_table.columns = ['Hair_Loss','Dark_Hair','Happy','Light_Hair','Beard']\n",
+ "new_vars_table.round(3)*100"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "new_env",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.20"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/treevae/treevae.png b/treevae/treevae.png
new file mode 100644
index 0000000..5c1fd07
Binary files /dev/null and b/treevae/treevae.png differ
diff --git a/treevae/treevae.yml b/treevae/treevae.yml
new file mode 100644
index 0000000..04b8a84
--- /dev/null
+++ b/treevae/treevae.yml
@@ -0,0 +1,187 @@
+name: treevae
+channels:
+ - anaconda
+ - defaults
+ - conda-forge
+ - bioconda
+dependencies:
+ - _libgcc_mutex=0.1=conda_forge
+ - _openmp_mutex=4.5=2_kmp_llvm
+ - anyio=3.5.0=py311h06a4308_0
+ - appdirs=1.4.4=pyh9f0ad1d_0
+ - argon2-cffi=21.3.0=pyhd3eb1b0_0
+ - argon2-cffi-bindings=21.2.0=py311h5eee18b_0
+ - asttokens=2.0.5=pyhd3eb1b0_0
+ - attrs=23.1.0=py311h06a4308_0
+ - backcall=0.2.0=pyhd3eb1b0_0
+ - bleach=4.1.0=pyhd3eb1b0_0
+ - brotlipy=0.7.0=py311h5eee18b_1002
+ - bzip2=1.0.8=h7f98852_4
+ - ca-certificates=2023.08.22=h06a4308_0
+ - certifi=2023.7.22=py311h06a4308_0
+ - cffi=1.15.1=py311h5eee18b_3
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
+ - click=8.1.7=unix_pyh707e725_0
+ - comm=0.1.2=py311h06a4308_0
+ - cryptography=41.0.3=py311hdda0065_0
+ - cuda-version=11.8=h70ddcb2_2
+ - cudatoolkit=11.8.0=h4ba93d1_12
+ - cudnn=8.8.0.121=h838ba91_3
+ - debugpy=1.6.7=py311h6a678d5_0
+ - decorator=5.1.1=pyhd3eb1b0_0
+ - defusedxml=0.7.1=pyhd3eb1b0_0
+ - docker-pycreds=0.4.0=py_0
+ - entrypoints=0.4=py311h06a4308_0
+ - executing=0.8.3=pyhd3eb1b0_0
+ - filelock=3.12.4=pyhd8ed1ab_0
+ - freetype=2.12.1=h4a9f257_0
+ - gitdb=4.0.10=pyhd8ed1ab_0
+ - gitpython=3.1.37=pyhd8ed1ab_0
+ - gmp=6.2.1=h58526e2_0
+ - gmpy2=2.1.2=py311h6a5fa03_1
+ - icu=73.2=h59595ed_0
+ - idna=3.4=py311h06a4308_0
+ - ipykernel=6.25.0=py311h92b7b1e_0
+ - ipython=8.15.0=py311h06a4308_0
+ - ipython_genutils=0.2.0=pyhd3eb1b0_1
+ - ipywidgets=8.0.4=py311h06a4308_0
+ - jedi=0.18.1=py311h06a4308_1
+ - jinja2=3.1.2=pyhd8ed1ab_1
+ - joblib=1.2.0=py311h06a4308_0
+ - jsonschema=4.17.3=py311h06a4308_0
+ - jupyter=1.0.0=pyhd8ed1ab_10
+ - jupyter_client=8.1.0=py311h06a4308_0
+ - jupyter_console=6.6.3=py311h06a4308_0
+ - jupyter_core=5.3.0=py311h06a4308_0
+ - jupyter_server=1.13.5=pyhd3eb1b0_0
+ - jupyterlab_widgets=3.0.5=py311h06a4308_0
+ - lcms2=2.15=hb7c19ff_3
+ - ld_impl_linux-64=2.40=h41732ed_0
+ - lerc=4.0.0=h27087fc_0
+ - libabseil=20230802.1=cxx17_h59595ed_0
+ - libblas=3.9.0=19_linux64_openblas
+ - libcblas=3.9.0=19_linux64_openblas
+ - libdeflate=1.19=hd590300_0
+ - libexpat=2.5.0=hcb278e6_1
+ - libffi=3.4.2=h7f98852_5
+ - libgcc-ng=13.2.0=h807b86a_2
+ - libgfortran-ng=13.2.0=h69a702a_2
+ - libgfortran5=13.2.0=ha4646dd_2
+ - libhwloc=2.9.3=default_h554bfaf_1009
+ - libiconv=1.17=h166bdaf_0
+ - libjpeg-turbo=3.0.0=hd590300_1
+ - liblapack=3.9.0=19_linux64_openblas
+ - libmagma=2.7.1=hc72dce7_6
+ - libmagma_sparse=2.7.1=h8354cda_6
+ - libnsl=2.0.1=hd590300_0
+ - libopenblas=0.3.24=pthreads_h413a1c8_0
+ - libpng=1.6.39=h5eee18b_0
+ - libprotobuf=4.24.3=hf27288f_1
+ - libsodium=1.0.18=h7b6447c_0
+ - libsqlite=3.43.2=h2797004_0
+ - libstdcxx-ng=13.2.0=h7e041cc_2
+ - libtiff=4.6.0=ha9c0a0a_2
+ - libuuid=2.38.1=h0b41bf4_0
+ - libuv=1.46.0=hd590300_0
+ - libwebp-base=1.3.2=h5eee18b_0
+ - libxcb=1.15=h7f8727e_0
+ - libxml2=2.11.5=h232c23b_1
+ - libzlib=1.2.13=hd590300_5
+ - llvm-openmp=17.0.2=h4dfa4b3_0
+ - magma=2.7.1=ha770c72_6
+ - markupsafe=2.1.3=py311h459d7ec_1
+ - matplotlib-inline=0.1.6=py311h06a4308_0
+ - mkl=2022.2.1=h84fe81f_16997
+ - mpc=1.3.1=hfe3b2da_0
+ - mpfr=4.2.0=hb012696_0
+ - mpmath=1.3.0=pyhd8ed1ab_0
+ - nbclassic=0.5.5=py311h06a4308_0
+ - nbformat=5.9.2=py311h06a4308_0
+ - nccl=2.19.3.1=h6103f9b_0
+ - ncurses=6.4=hcb278e6_0
+ - nest-asyncio=1.5.6=py311h06a4308_0
+ - networkx=3.1=pyhd8ed1ab_0
+ - notebook=6.5.4=py311h06a4308_0
+ - notebook-shim=0.2.2=py311h06a4308_0
+ - numpy=1.26.0=py311h64a7726_0
+ - openjpeg=2.5.0=h488ebb8_3
+ - openssl=3.1.3=hd590300_0
+ - packaging=23.1=py311h06a4308_0
+ - pandoc=2.12=h06a4308_3
+ - pandocfilters=1.5.0=pyhd3eb1b0_0
+ - parso=0.8.3=pyhd3eb1b0_0
+ - pathtools=0.1.2=py_1
+ - pexpect=4.8.0=pyhd3eb1b0_3
+ - pickleshare=0.7.5=pyhd3eb1b0_1003
+ - pillow=10.1.0=py311ha6c5da5_0
+ - pip=23.3=pyhd8ed1ab_0
+ - platformdirs=3.10.0=py311h06a4308_0
+ - prometheus_client=0.14.1=py311h06a4308_0
+ - prompt-toolkit=3.0.36=py311h06a4308_0
+ - prompt_toolkit=3.0.36=hd3eb1b0_0
+ - protobuf=4.24.3=py311h46cbc50_1
+ - psutil=5.9.5=py311h459d7ec_1
+ - ptyprocess=0.7.0=pyhd3eb1b0_2
+ - pure_eval=0.2.2=pyhd3eb1b0_0
+ - pycparser=2.21=pyhd3eb1b0_0
+ - pygments=2.15.1=py311h06a4308_1
+ - pyopenssl=23.2.0=py311h06a4308_0
+ - pyrsistent=0.18.0=py311h5eee18b_0
+ - pysocks=1.7.1=py311h06a4308_0
+ - python=3.11.6=hab00c5b_0_cpython
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
+ - python-fastjsonschema=2.16.2=py311h06a4308_0
+ - python_abi=3.11=4_cp311
+ - pytorch=2.0.0=cuda112py311hf4e4fe6_303
+ - pytorch-gpu=2.0.0=cuda112py311h398211c_303
+ - pyyaml=6.0.1=py311h459d7ec_1
+ - pyzmq=25.1.0=py311h6a678d5_0
+ - qtconsole-base=5.4.4=pyha770c72_0
+ - qtpy=2.4.0=pyhd8ed1ab_0
+ - readline=8.2=h8228510_1
+ - requests=2.31.0=py311h06a4308_0
+ - scikit-learn=1.2.2=py311h6a678d5_1
+ - scipy=1.11.3=py311h64a7726_1
+ - send2trash=1.8.0=pyhd3eb1b0_1
+ - sentry-sdk=1.32.0=pyhd8ed1ab_0
+ - setproctitle=1.3.3=py311h459d7ec_0
+ - setuptools=68.2.2=pyhd8ed1ab_0
+ - six=1.16.0=pyh6c4a22f_0
+ - sleef=3.5.1=h9b69904_2
+ - smmap=3.0.5=pyh44b312d_0
+ - sniffio=1.2.0=py311h06a4308_1
+ - stack_data=0.2.0=pyhd3eb1b0_0
+ - sympy=1.12=pypyh9d50eac_103
+ - tbb=2021.10.0=h00ab1b0_1
+ - terminado=0.17.1=py311h06a4308_0
+ - testpath=0.6.0=py311h06a4308_0
+ - threadpoolctl=2.2.0=pyh0d69192_0
+ - tk=8.6.13=h2797004_0
+ - torchmetrics=0.11.4=py311h92b7b1e_1
+ - torchvision=0.15.2=cuda112py311h3f38234_2
+ - tornado=6.3.3=py311h5eee18b_0
+ - tqdm=4.65.0=py311h92b7b1e_0
+ - traitlets=5.7.1=py311h06a4308_0
+ - typing-extensions=4.8.0=hd8ed1ab_0
+ - typing_extensions=4.8.0=pyha770c72_0
+ - tzdata=2023c=h71feb2d_0
+ - urllib3=1.26.16=py311h06a4308_0
+ - wandb=0.15.12=pyhd8ed1ab_0
+ - wcwidth=0.2.5=pyhd3eb1b0_0
+ - webencodings=0.5.1=py311h06a4308_1
+ - websocket-client=0.58.0=py311h06a4308_4
+ - wheel=0.41.2=pyhd8ed1ab_0
+ - widgetsnbextension=4.0.5=py311h06a4308_0
+ - xz=5.2.6=h166bdaf_0
+ - yaml=0.2.5=h7b6447c_0
+ - zeromq=4.3.4=h2531618_0
+ - zlib=1.2.13=hd590300_5
+ - zstd=1.5.5=hfc55251_0
+ - pip:
+ - beautifulsoup4==4.12.2
+ - jupyterlab-pygments==0.2.2
+ - mistune==3.0.2
+ - nbclient==0.8.0
+ - nbconvert==7.9.2
+ - soupsieve==2.5
+ - tinycss2==1.2.1
diff --git a/treevae/utils/__pycache__/model_utils.cpython-39.pyc b/treevae/utils/__pycache__/model_utils.cpython-39.pyc
new file mode 100644
index 0000000..b40f1e3
Binary files /dev/null and b/treevae/utils/__pycache__/model_utils.cpython-39.pyc differ
diff --git a/treevae/utils/__pycache__/training_utils.cpython-39.pyc b/treevae/utils/__pycache__/training_utils.cpython-39.pyc
new file mode 100644
index 0000000..56841bb
Binary files /dev/null and b/treevae/utils/__pycache__/training_utils.cpython-39.pyc differ
diff --git a/treevae/utils/data_utils.py b/treevae/utils/data_utils.py
new file mode 100644
index 0000000..a001f4f
--- /dev/null
+++ b/treevae/utils/data_utils.py
@@ -0,0 +1,406 @@
+"""
+Utility functions for data loading.
+"""
+import os
+import torch
+import torchvision
+import torchvision.transforms as T
+import numpy as np
+from torch.utils.data import TensorDataset, DataLoader, Subset, ConcatDataset
+from PIL import Image
+from sklearn.datasets import fetch_20newsgroups
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.model_selection import train_test_split
+from utils.utils import reset_random_seeds
+
+def get_data(configs):
+ """Compute and process the data specified in the configs file.
+
+ Parameters
+ ----------
+ configs : dict
+ A dictionary of config settings, where the data_name, the number of clusters in the data and augmentation
+ details are specified.
+
+ Returns
+ ------
+ list
+ A list of three tensor datasets: trainset, trainset_eval, testset
+ """
+ data_name = configs['data']['data_name']
+ augment = configs['training']['augment']
+ augmentation_method = configs['training']['augmentation_method']
+ n_classes = configs['data']['num_clusters_data']
+
+ data_path = './data/'
+
+ if data_name == 'mnist':
+ reset_random_seeds(configs['globals']['seed'])
+ full_trainset = torchvision.datasets.MNIST(root=data_path, train=True, download=True, transform=T.ToTensor())
+ full_testset = torchvision.datasets.MNIST(root=data_path, train=False, download=True, transform=T.ToTensor())
+
+ # get only num_clusters digits
+ indx_train, indx_test = select_subset(full_trainset.targets, full_testset.targets, n_classes)
+ trainset = Subset(full_trainset, indx_train)
+ trainset_eval = Subset(full_trainset, indx_train)
+ testset = Subset(full_testset, indx_test)
+
+
+ elif data_name == 'fmnist':
+ reset_random_seeds(configs['globals']['seed'])
+ full_trainset = torchvision.datasets.FashionMNIST(root=data_path, train=True, download=True, transform=T.ToTensor())
+ full_testset = torchvision.datasets.FashionMNIST(root=data_path, train=False, download=True, transform=T.ToTensor())
+
+ # get only num_clusters digits
+ indx_train, indx_test = select_subset(full_trainset.targets, full_testset.targets, n_classes)
+ trainset = Subset(full_trainset, indx_train)
+ trainset_eval = Subset(full_trainset, indx_train)
+ testset = Subset(full_testset, indx_test)
+
+
+ elif data_name == 'news20':
+ reset_random_seeds(configs['globals']['seed'])
+ newsgroups_train = fetch_20newsgroups(subset='train')
+ newsgroups_test = fetch_20newsgroups(subset='test')
+ vectorizer = TfidfVectorizer(max_features=2000, dtype=np.float32)
+ x_train = torch.from_numpy(vectorizer.fit_transform(newsgroups_train.data).toarray())
+ x_test = torch.from_numpy(vectorizer.transform(newsgroups_test.data).toarray())
+ y_train = torch.from_numpy(newsgroups_train.target)
+ y_test = torch.from_numpy(newsgroups_test.target)
+
+ # get only num_clusters digits
+ indx_train, indx_test = select_subset(y_train, y_test, n_classes)
+ trainset = Subset(TensorDataset(x_train, y_train), indx_train)
+ trainset_eval = Subset(TensorDataset(x_train, y_train), indx_train)
+ testset = Subset(TensorDataset(x_test, y_test), indx_test)
+ trainset.dataset.targets = torch.tensor(trainset.dataset.tensors[1])
+ trainset_eval.dataset.targets = torch.tensor(trainset_eval.dataset.tensors[1])
+ testset.dataset.targets = torch.tensor(testset.dataset.tensors[1])
+
+
+ elif data_name == 'omniglot':
+ reset_random_seeds(configs['globals']['seed'])
+
+ transform_eval = T.Compose([
+ T.ToTensor(),
+ T.Resize([28,28], antialias=True),
+ ])
+
+ if augment and augmentation_method == ['simple']:
+ transform = T.Compose([
+ T.ToTensor(),
+ T.Resize([28,28], antialias=True),
+ T.RandomAffine(degrees=10, translate=(1/28, 1/28), scale=(0.9, 1.1), shear=0.01, fill=1),
+ ])
+ elif augment is False:
+ transform = transform_eval
+ else:
+ raise NotImplementedError
+
+ # Download the datasets and apply transformations
+ trainset_premerge = torchvision.datasets.Omniglot(root=data_path, background=True, download=True, transform=transform)
+ testset_premerge = torchvision.datasets.Omniglot(root=data_path, background=False, download=True, transform=transform)
+ trainset_premerge_eval = torchvision.datasets.Omniglot(root=data_path, background=True, download=True, transform=transform_eval)
+ testset_premerge_eval = torchvision.datasets.Omniglot(root=data_path, background=False, download=True, transform=transform_eval)
+
+ # Get the corresponding labels y_train and y_test
+ y_train_ind = torch.tensor([sample[1] for sample in trainset_premerge])
+ y_test_ind = torch.tensor([sample[1] for sample in testset_premerge])
+
+ # Create a list of all alphabet labels from both datasets
+ alphabets = trainset_premerge._alphabets + testset_premerge._alphabets
+
+ # Replace character labels by alphabet labels
+ y_train_pre = []
+ y_test_pre = []
+ for value in y_train_ind:
+ alphabet = trainset_premerge._characters[value].split("/")[0]
+ alphabet_ind = alphabets.index(alphabet)
+ y_train_pre.append(alphabet_ind)
+ for value in y_test_ind:
+ alphabet = testset_premerge._characters[value].split("/")[0]
+ alphabet_ind = alphabets.index(alphabet)
+ y_test_pre.append(alphabet_ind)
+
+ y = np.array(y_train_pre + y_test_pre)
+
+ # Select alphabets
+ num_clusters = n_classes
+ if num_clusters !=50:
+ alphabets_selected = get_selected_omniglot_alphabets()[:num_clusters]
+ alphabets_ind = []
+ for i in alphabets_selected:
+ alphabets_ind.append(alphabets.index(i))
+ else:
+ alphabets_ind = np.arange(50)
+
+ indx = np.array([], dtype=int)
+ for i in range(num_clusters):
+ indx = np.append(indx, np.where(y == alphabets_ind[i])[0])
+ indx = np.sort(indx)
+
+ # Split and stratify by digits
+ digits_label = torch.concatenate([y_train_ind, y_test_ind+len(torch.unique(y_train_ind))])
+ indx_train, indx_test = train_test_split(indx, test_size=0.2, random_state=configs['globals']['seed'], stratify=digits_label[indx])
+ indx_train = np.sort(indx_train)
+ indx_test = np.sort(indx_test)
+
+ # Define alphabets as labels
+ y = y+50
+ for idx, alphabet in enumerate(alphabets_ind):
+ y[y==alphabet+50] = idx
+
+ # Define mapping from digit to label
+ mapping_train = []
+ for value in torch.unique(y_train_ind):
+ alphabet = trainset_premerge._characters[value].split("/")[0]
+ alphabet_ind = alphabets.index(alphabet)
+ mapping_train.append(alphabet_ind)
+ mapping_test = []
+ for value in torch.unique(y_test_ind):
+ alphabet = testset_premerge._characters[value].split("/")[0]
+ alphabet_ind = alphabets.index(alphabet)
+ mapping_test.append(alphabet_ind)
+
+ custom_target_transform_train = T.Lambda(lambda y: mapping_train[y])
+ custom_target_transform_test = T.Lambda(lambda y: mapping_test[y])
+
+ trainset_premerge.target_transform = custom_target_transform_train
+ trainset_premerge_eval.target_transform = custom_target_transform_train
+ testset_premerge.target_transform = custom_target_transform_test
+ testset_premerge_eval.target_transform = custom_target_transform_test
+
+ # Define datasets
+ fullset = ConcatDataset([trainset_premerge, testset_premerge])
+ fullset_eval = ConcatDataset([trainset_premerge_eval, testset_premerge_eval])
+ fullset.targets = torch.from_numpy(y)
+ fullset_eval.targets = torch.from_numpy(y)
+ trainset = Subset(fullset, indx_train)
+ trainset_eval = Subset(fullset_eval, indx_train)
+ testset = Subset(fullset_eval, indx_test)
+
+
+
+ elif data_name in ['cifar10', 'cifar100', 'cifar10_vehicles', 'cifar10_animals']:
+ reset_random_seeds(configs['globals']['seed'])
+ aug_strength = 0.5
+
+
+ transform_eval = T.Compose([
+ T.ToTensor(),
+ ])
+
+ if augment is True:
+ aug_transforms = T.Compose([
+ T.RandomResizedCrop(32, interpolation=Image.BICUBIC, scale=(0.2, 1.0)),
+ T.RandomHorizontalFlip(),
+ T.RandomApply([T.ColorJitter(0.8 * aug_strength, 0.8 * aug_strength, 0.8 * aug_strength, 0.2 * aug_strength)], p=0.8),
+ T.RandomGrayscale(p=0.2),
+ T.ToTensor(),
+ ])
+ if augmentation_method == ['simple']:
+ transform = aug_transforms
+ else:
+ transform = ContrastiveTransformations(aug_transforms, n_views=2)
+ else:
+ transform = transform_eval
+
+ if data_name == 'cifar100':
+ if n_classes==20:
+ dataset = CIFAR100Coarse
+ else:
+ dataset = torchvision.datasets.CIFAR100
+ else:
+ dataset = torchvision.datasets.CIFAR10
+
+ full_trainset = dataset(root=data_path, train=True, download=True, transform=transform)
+ full_trainset_eval = dataset(root=data_path, train=True, download=True, transform=transform_eval)
+ full_testset = dataset(root=data_path, train=False, download=True, transform=transform_eval)
+
+ if data_name == 'cifar10_vehicles':
+ indx_train = [index for index, value in enumerate(full_trainset.targets) if value in (0, 1, 8, 9)]
+ indx_test = [index for index, value in enumerate(full_testset.targets) if value in (0, 1, 8, 9)]
+ elif data_name == 'cifar10_animals':
+ indx_train = [index for index, value in enumerate(full_trainset.targets) if value not in (0, 1, 8, 9)]
+ indx_test = [index for index, value in enumerate(full_testset.targets) if value not in (0, 1, 8, 9)]
+ else:
+ indx_train, indx_test = select_subset(full_trainset.targets, full_testset.targets, n_classes)
+
+ trainset = Subset(full_trainset, indx_train)
+ trainset_eval = Subset(full_trainset_eval, indx_train)
+ testset = Subset(full_testset, indx_test)
+
+ trainset.dataset.targets = torch.tensor(trainset.dataset.targets)
+ trainset_eval.dataset.targets = torch.tensor(trainset_eval.dataset.targets)
+ testset.dataset.targets = torch.tensor(testset.dataset.targets)
+
+ elif data_name == 'celeba':
+ reset_random_seeds(configs['globals']['seed'])
+ aug_strength = 0.25
+
+ # Slightly different reshaping from TF implementation to be inline with WAE
+ transform_eval = T.Compose([
+ T.Lambda(lambda x: T.functional.crop(x, left=15, top=40, width=148, height=148)),
+ T.Resize([64,64], antialias=True),
+ T.ToTensor(),
+ ])
+ if augment is True:
+ aug_transforms = T.Compose([
+ T.Lambda(lambda x: T.functional.crop(x, left=15, top=40, width=148, height=148)),
+ T.Resize([64,64], antialias=True),
+ T.RandomResizedCrop(64, interpolation=Image.BICUBIC, scale = (3/4,1), ratio=(4/5,5/4)),
+ T.RandomHorizontalFlip(),
+ T.RandomApply([T.ColorJitter(0.8 * aug_strength, 0.8 * aug_strength, 0.8 * aug_strength)], p=0.8),
+ T.ToTensor(),
+ ])
+ if augmentation_method == ['simple']:
+ transform = aug_transforms
+ else:
+ transform = ContrastiveTransformations(aug_transforms, n_views=2)
+ else:
+ transform = transform_eval
+
+
+ full_trainset = torchvision.datasets.CelebA(root=data_path, split='train', target_type='attr', target_transform=lambda y: 0, download=True, transform=transform)
+ full_trainset_eval = torchvision.datasets.CelebA(root=data_path, split='train', target_type='attr', target_transform=lambda y: 0, download=True, transform=transform_eval)
+ full_testset = torchvision.datasets.CelebA(root=data_path, split='test', target_type='attr', target_transform=lambda y: 0, download=True, transform=transform_eval)
+
+ indx_train = np.arange(len(full_trainset))
+ indx_test = np.arange(len(full_testset))
+
+ trainset = Subset(full_trainset, indx_train)
+ trainset_eval = Subset(full_trainset_eval, indx_train)
+ testset = Subset(full_testset, indx_test)
+ trainset.dataset.targets = torch.zeros(trainset.dataset.attr.shape[0], dtype=torch.int8)
+ trainset_eval.dataset.targets = torch.zeros(trainset.dataset.attr.shape[0], dtype=torch.int8)
+ testset.dataset.targets = torch.zeros(trainset.dataset.attr.shape[0], dtype=torch.int8)
+
+ else:
+ raise NotImplementedError('This dataset is not supported!')
+
+ assert trainset.__class__ == testset.__class__ == trainset_eval.__class__ == Subset
+ return trainset, trainset_eval, testset
+
+
+def get_gen(dataset, configs, validation=False, shuffle=True, smalltree=False, smalltree_ind=None):
+ """Given the dataset and a config file, it will output the DataLoader for training.
+
+ Parameters
+ ----------
+ dataset : torch.dataset
+ A tensor dataset.
+ configs : dict
+ A dictionary of config settings.
+ validation : bool, optional
+ If set to True it will not drop the last batch, during training it is preferrable to drop the last batch if it
+ has a different shape to avoid changing the batch normalization statistics.
+ shuffle : bool, optional
+ Whether to shuffle the dataset at every epoch.
+ smalltree : bool, optional
+ Whether the method should output the DataLoader for the small tree training, where a subset of training inputs
+ are used.
+ smalltree_ind : list
+ For training the small tree during the growing strategy of TreeVAE, only a subset of training inputs will be
+ used for efficiency.
+
+ Returns
+ ------
+ DataLoader
+ The dataloader of the provided dataset.
+ """
+ batch_size = configs['training']['batch_size']
+ drop_last = not validation
+ try:
+ num_workers = configs['parser']['num_workers']
+ except:
+ num_workers = 6
+
+ if smalltree:
+ dataset = Subset(dataset, smalltree_ind)
+
+ # Call the DataLoader when contrastive learning is used
+ if configs['training']['augment'] and configs['training']['augmentation_method'] != ['simple'] and not validation:
+ # As one datapoint leads to two samples, we have to half the batch size to retain same number of samples per batch
+ assert batch_size % 2 == 0
+ batch_size = batch_size // 2
+ data_gen = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True,
+ persistent_workers=True, collate_fn=custom_collate_fn, drop_last=drop_last)
+
+ # Call the DataLoader without contrastive learning
+ else:
+ data_gen = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True,
+ persistent_workers=True, drop_last=drop_last)
+ return data_gen
+
+
+def select_subset(y_train, y_test, num_classes):
+ # Select a random subset of labels where the number of different labels equal num_classes.
+ digits = np.random.choice([i for i in range(len(np.unique(y_train)))], size=num_classes, replace=False)
+ indx_train = np.array([], dtype=int)
+ indx_test = np.array([], dtype=int)
+ for i in range(num_classes):
+ indx_train = np.append(indx_train, np.where(y_train == digits[i])[0])
+ indx_test = np.append(indx_test, np.where(y_test == digits[i])[0])
+ return np.sort(indx_train), np.sort(indx_test)
+
+
+def custom_collate_fn(batch):
+ # Concatenate the augmented versions
+ batch = torch.utils.data.default_collate(batch)
+ batch[0] = batch[0].transpose(1, 0).reshape(-1,*batch[0].shape[2:])
+ batch[1] = batch[1].repeat(2)
+ return batch
+
+
+class ContrastiveTransformations(object):
+
+ def __init__(self, base_transforms, n_views=2):
+ self.base_transforms = base_transforms
+ self.n_views = n_views
+
+ def __call__(self, x):
+ return torch.stack([self.base_transforms(x) for i in range(self.n_views)],dim=0)
+
+
+def get_selected_omniglot_alphabets():
+ return ['Braille', 'Glagolitic', 'Old_Church_Slavonic_(Cyrillic)', 'Oriya', 'Bengali']
+
+
+class CIFAR100Coarse(torchvision.datasets.CIFAR100):
+ def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
+ super(CIFAR100Coarse, self).__init__(root, train, transform, target_transform, download)
+
+ # update labels
+ coarse_labels = np.array([ 4, 1, 14, 8, 0, 6, 7, 7, 18, 3,
+ 3, 14, 9, 18, 7, 11, 3, 9, 7, 11,
+ 6, 11, 5, 10, 7, 6, 13, 15, 3, 15,
+ 0, 11, 1, 10, 12, 14, 16, 9, 11, 5,
+ 5, 19, 8, 8, 15, 13, 14, 17, 18, 10,
+ 16, 4, 17, 4, 2, 0, 17, 4, 18, 17,
+ 10, 3, 2, 12, 12, 16, 12, 1, 9, 19,
+ 2, 10, 0, 1, 16, 12, 9, 13, 15, 13,
+ 16, 19, 2, 4, 6, 19, 5, 5, 8, 19,
+ 18, 1, 2, 15, 6, 0, 17, 8, 14, 13])
+ self.targets = coarse_labels[self.targets]
+
+ # update classes
+ self.classes = [['beaver', 'dolphin', 'otter', 'seal', 'whale'],
+ ['aquarium_fish', 'flatfish', 'ray', 'shark', 'trout'],
+ ['orchid', 'poppy', 'rose', 'sunflower', 'tulip'],
+ ['bottle', 'bowl', 'can', 'cup', 'plate'],
+ ['apple', 'mushroom', 'orange', 'pear', 'sweet_pepper'],
+ ['clock', 'keyboard', 'lamp', 'telephone', 'television'],
+ ['bed', 'chair', 'couch', 'table', 'wardrobe'],
+ ['bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach'],
+ ['bear', 'leopard', 'lion', 'tiger', 'wolf'],
+ ['bridge', 'castle', 'house', 'road', 'skyscraper'],
+ ['cloud', 'forest', 'mountain', 'plain', 'sea'],
+ ['camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo'],
+ ['fox', 'porcupine', 'possum', 'raccoon', 'skunk'],
+ ['crab', 'lobster', 'snail', 'spider', 'worm'],
+ ['baby', 'boy', 'girl', 'man', 'woman'],
+ ['crocodile', 'dinosaur', 'lizard', 'snake', 'turtle'],
+ ['hamster', 'mouse', 'rabbit', 'shrew', 'squirrel'],
+ ['maple_tree', 'oak_tree', 'palm_tree', 'pine_tree', 'willow_tree'],
+ ['bicycle', 'bus', 'motorcycle', 'pickup_truck', 'train'],
+ ['lawn_mower', 'rocket', 'streetcar', 'tank', 'tractor']]
\ No newline at end of file
diff --git a/treevae/utils/model_utils.py b/treevae/utils/model_utils.py
new file mode 100644
index 0000000..4581caf
--- /dev/null
+++ b/treevae/utils/model_utils.py
@@ -0,0 +1,238 @@
+"""
+Utility functions for model.
+"""
+import numpy as np
+import torch.nn as nn
+
+def compute_posterior(mu_q, mu_p, sigma_q, sigma_p):
+ epsilon = 1e-7
+ z_sigma_q = 1 / (1 / (sigma_q + epsilon) + 1 / (sigma_p + epsilon))
+ z_mu_q = (mu_q / (sigma_q + epsilon) +
+ mu_p / (sigma_p + epsilon)) * z_sigma_q
+ return z_mu_q, z_sigma_q
+
+
+def construct_tree(transformations, routers, routers_q, denses, decoders):
+ """
+ Construct the tree by passing a list of transformations and routers from root to leaves visiting nodes
+ layer-wise from left to right
+
+ :param transformations: list of transformations to attach to the nodes of the tree
+ :param routers: list of decisions to attach to the nodes of the tree
+ :param denses: list of dense network that from d of the bottom up compute node-specific q
+ :param decoders: list of decoders to attach to the nodes, they should be set to None except the leaves
+ :return: the root of the tree
+ """
+ if len(transformations) != len(routers) and len(transformations) != len(denses) \
+ and len(transformations) != len(decoders):
+ raise ValueError('Len transformation is different than len routers in constructing the tree.')
+ root = Node(transformation=transformations[0], router=routers[0], routers_q=routers_q[0], dense=denses[0], decoder=decoders[0])
+ for i in range(1, len(transformations)):
+ root.insert(transformation=transformations[i], router=routers[i], routers_q=routers_q[i], dense=denses[i], decoder=decoders[i])
+ return root
+
+
+class Node:
+ def __init__(self, transformation, router, routers_q, dense, decoder=None, expand=True):
+ self.left = None
+ self.right = None
+ self.parent = None
+ self.transformation = transformation
+ self.dense = dense
+ self.router = router
+ self.routers_q = routers_q
+ self.decoder = decoder
+ self.expand = expand
+
+ def insert(self, transformation=None, router=None, routers_q=None, dense=None, decoder=None):
+ queue = []
+ node = self
+ queue.append(node)
+ while len(queue) > 0:
+ node = queue.pop(0)
+ if node.expand:
+ if node.left is None:
+ node.left = Node(transformation, router, routers_q, dense, decoder)
+ node.left.parent = node
+ return
+ elif node.right is None:
+ node.right = Node(transformation, router, routers_q, dense, decoder)
+ node.right.parent = node
+ return
+ else:
+ queue.append(node.left)
+ queue.append(node.right)
+ print('\nAttention node has not been inserted!\n')
+ return
+
+ def prune_child(self, child):
+ if child is self.left:
+ self.left = None
+ self.router = None
+
+ elif child is self.right:
+ self.right = None
+ self.router = None
+
+ else:
+ raise ValueError("This is not my child! (Node is not a child of this parent.)")
+
+def return_list_tree(root):
+ list_nodes = [root]
+ denses = []
+ transformations = []
+ routers = []
+ routers_q = []
+ decoders = []
+ while len(list_nodes) != 0:
+ current_node = list_nodes.pop(0)
+ denses.append(current_node.dense)
+ transformations.append(current_node.transformation)
+ routers.append(current_node.router)
+ routers_q.append(current_node.routers_q)
+ decoders.append(current_node.decoder)
+ if current_node.router is not None:
+ node_left, node_right = current_node.left, current_node.right
+ list_nodes.append(node_left)
+ list_nodes.append(node_right)
+ elif current_node.router is None and current_node.decoder is None:
+ # We are in an internal node with pruned leaves and thus only have one child
+ node_left, node_right = current_node.left, current_node.right
+ child = node_left if node_left is not None else node_right
+ list_nodes.append(child)
+ return nn.ModuleList(transformations), nn.ModuleList(routers), nn.ModuleList(denses), nn.ModuleList(decoders), nn.ModuleList(routers_q)
+
+
+def construct_tree_fromnpy(model, data_tree, configs):
+ from models.model_smalltree import SmallTreeVAE
+ nodes = {0: {'node': model.tree, 'depth': 0}}
+
+ for i in range(1, len(data_tree)-1):
+ node_left = data_tree[i]
+ node_right = data_tree[i + 1]
+ id_node_left = node_left[0]
+ id_node_right = node_right[0]
+
+ if node_left[2] == node_right[2]:
+ id_parent = node_left[2]
+
+ parent = nodes[id_parent]
+ node = parent['node']
+ depth = parent['depth']
+
+ new_depth = depth + 1
+
+ small_model = SmallTreeVAE(new_depth+1, **configs['training'])
+
+ node.router = small_model.decision
+ node.routers_q = small_model.decision_q
+
+ node.decoder = None
+ n = []
+ for j in range(2):
+ dense = small_model.denses[j]
+ transformation = small_model.transformations[j]
+ decoder = small_model.decoders[j]
+ n.append(Node(transformation, None, None, dense, decoder))
+
+ node.left = n[0]
+ node.right = n[1]
+
+ nodes[id_node_left] = {'node': node.left, 'depth': new_depth}
+ nodes[id_node_right] = {'node': node.right, 'depth': new_depth}
+ elif data_tree[i][2] != data_tree[i - 1][2]: # Internal node w/ 1 child only
+ id_parent = node_left[2]
+
+ parent = nodes[id_parent]
+ node = parent['node']
+ depth = parent['depth']
+
+ new_depth = depth + 1
+
+ small_model = SmallTreeVAE(new_depth+1, **configs['training'])
+
+ node.router = None
+ node.routers_q = None
+
+ node.decoder = None
+ n = []
+ for j in range(1):
+ dense = small_model.denses[j]
+ transformation = small_model.transformations[j]
+ decoder = small_model.decoders[j]
+ n.append(Node(transformation, None, None, dense, decoder))
+
+ node.left = n[0]
+ nodes[id_node_left] = {'node': node.left, 'depth': new_depth}
+
+ transformations, routers, denses, decoders, routers_q = return_list_tree(model.tree)
+ model.decisions_q = routers_q
+ model.transformations = transformations
+ model.decisions = routers
+ model.denses = denses
+ model.decoders = decoders
+ model.depth = model.compute_depth()
+ return model
+
+
+def construct_data_tree(model, y_predicted, y_true, n_leaves, data_name):
+ list_nodes = [{'node':model.tree, 'id': 0, 'parent':None}]
+ data = []
+ i = 0
+ labels = [i for i in range(n_leaves)]
+ while len(list_nodes) != 0:
+ current_node = list_nodes.pop(0)
+ if current_node['node'].router is not None:
+ data.append([current_node['id'], str(current_node['id']), current_node['parent'], 10])
+ node_left, node_right = current_node['node'].left, current_node['node'].right
+ i += 1
+ list_nodes.append({'node':node_left, 'id': i, 'parent': current_node['id']})
+ i += 1
+ list_nodes.append({'node':node_right, 'id': i, 'parent': current_node['id']})
+ elif current_node['node'].router is None and current_node['node'].decoder is None:
+ # We are in an internal node with pruned leaves and will only add the non-pruned leaves
+ data.append([current_node['id'], str(current_node['id']), current_node['parent'], 10])
+ node_left, node_right = current_node['node'].left, current_node['node'].right
+ child = node_left if node_left is not None else node_right
+ i += 1
+ list_nodes.append({'node': child, 'id': i, 'parent': current_node['id']})
+ else:
+ y_leaf = labels.pop(0)
+ ind = np.where(y_predicted == y_leaf)[0]
+ digits, counts = np.unique(y_true[ind], return_counts=True)
+ tot = len(ind)
+ if tot == 0:
+ name = 'no digits'
+ else:
+ counts = np.round(counts / np.sum(counts), 2)
+ ind = np.where(counts > 0.1)[0]
+ name = ' '
+ for j in ind:
+ if data_name == 'fmnist':
+ items = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker',
+ 'Bag', 'Boot']
+ name = name + str(items[digits[j]]) + ': ' + str(counts[j]) + ' '
+ elif data_name == 'cifar10':
+ items = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship',
+ 'truck']
+ name = name + str(items[digits[j]]) + ': ' + str(counts[j]) + ' '
+ elif data_name == 'news20':
+ items = ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware',
+ 'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale','rec.autos',
+ 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt',
+ 'sci.electronics', 'sci.med', 'sci.space', 'soc.religion.christian',
+ 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc',
+ 'talk.religion.misc']
+ name = name + str(items[digits[j]]) + ': ' + str(counts[j]) + ' '
+ elif data_name == 'omniglot':
+ from utils.data_utils import get_selected_omniglot_alphabets
+ items = get_selected_omniglot_alphabets()
+ if np.unique(y_true).shape[0]>len(items):
+ items=np.arange(50)
+
+ name = name + items[digits[j]] + ': ' + str(counts[j]) + ' '
+ else:
+ name = name + str(digits[j]) + ': ' + str(counts[j]) + ' '
+ name = name + 'tot ' + str(tot)
+ data.append([current_node['id'], name, current_node['parent'], 1])
+ return data
diff --git a/treevae/utils/plotting_utils.py b/treevae/utils/plotting_utils.py
new file mode 100644
index 0000000..622cc7b
--- /dev/null
+++ b/treevae/utils/plotting_utils.py
@@ -0,0 +1,301 @@
+import numpy as np
+import torch
+import torch.distributions as td
+from matplotlib import pyplot as plt
+from utils.model_utils import construct_tree, compute_posterior
+import re
+import networkx as nx
+from sklearn.decomposition import PCA
+
+
+
+def hierarchy_pos(G, root, levels=None, width=1., height=1.):
+ '''
+ Encodes the hierarchy for the tree layout in a graph.
+ From https://stackoverflow.com/questions/29586520/can-one-get-hierarchical-graphs-from-networkx-with-python-3
+ If there is a cycle that is reachable from root, then this will see infinite recursion.
+ G: the graph
+ root: the root node
+ levels: a dictionary
+ key: level number (starting from 0)
+ value: number of nodes in this level
+ width: horizontal space allocated for drawing
+ height: vertical space allocated for drawing'''
+ TOTAL = "total"
+ CURRENT = "current"
+ def make_levels(levels, node=root, currentLevel=0, parent=None):
+ """Compute the number of nodes for each level
+ """
+ if not currentLevel in levels:
+ levels[currentLevel] = {TOTAL : 0, CURRENT : 0}
+ levels[currentLevel][TOTAL] += 1
+ neighbors = G.neighbors(node)
+ for neighbor in neighbors:
+ if not neighbor == parent:
+ levels = make_levels(levels, neighbor, currentLevel + 1, node)
+ return levels
+
+ def make_pos(pos, node=root, currentLevel=0, parent=None, vert_loc=0):
+ dx = 1/levels[currentLevel][TOTAL]
+ left = dx/2
+ pos[node] = ((left + dx*levels[currentLevel][CURRENT])*width, vert_loc)
+ levels[currentLevel][CURRENT] += 1
+ neighbors = G.neighbors(node)
+ for neighbor in neighbors:
+ if not neighbor == parent:
+ pos = make_pos(pos, neighbor, currentLevel + 1, node, vert_loc-vert_gap)
+ return pos
+ if levels is None:
+ levels = make_levels({})
+ else:
+ levels = {l:{TOTAL: levels[l], CURRENT:0} for l in levels}
+ vert_gap = height / (max([l for l in levels])+1)
+ return make_pos({})
+
+
+def plot_tree_graph(data):
+
+ # get a '/n' before every 'tot' in each second entry of data
+ data = data.copy()
+ for d in data:
+ if d[3] == 1:
+ #d[1] = d[1].replace('tot', '\ntot')
+ pattern = r'(\w+:\s\d+\.\d+|\d+:\s\d+\.\d+|\w+\s\d+|\d+\s\d+|\w+:\s\d+|\d+:\s\d+|\w+:\s\d+\s\w+|\d+:\s\d+\s\w+|\w+\s\d+\s\w+|\d+\s\d+\s\w+|\w+:\s\d+\.\d+\s\w+|\d+:\s\d+\.\d+\s\w+)'
+
+ # Split the string using the regular expression pattern
+ result = re.findall(pattern, d[1])
+
+ # Join the resulting list to format it as desired
+ d[1] = '\n'.join(result)
+
+ # Create a directed graph
+ G = nx.DiGraph()
+
+ # Add nodes and edges to the graph
+ for node in data:
+ node_id, label, parent_id, node_type = node
+ G.add_node(node_id, label=label, node_type=node_type)
+ if parent_id is not None:
+ G.add_edge(parent_id, node_id)
+
+ # Get positions of graph nodes
+ pos = hierarchy_pos(G, 0, levels=None, width=1, height=1)
+
+ # get the labels of the nodes
+ labels = nx.get_node_attributes(G, 'label')
+
+ # Initialize node color and size lists
+ node_colors = []
+ node_sizes = []
+
+ # Iterate through nodes to set colors and sizes
+ for node_id, node_data in G.nodes(data=True):
+ if G.out_degree(node_id) == 0: # Leaf nodes have out-degree 0
+ node_colors.append('lightgreen')
+ node_sizes.append(4000)
+
+ else:
+ node_colors.append('lightblue')
+ node_sizes.append(1000)
+
+ # Draw the graph with different node properties
+ tree = plt.figure(figsize=(10, 5))
+ nx.draw(G, pos=pos, labels=labels, with_labels=True, node_size=node_sizes, node_color=node_colors, font_size=7)
+
+ plt.show()
+
+
+
+def get_node_embeddings(model, x):
+ assert model.training == False
+ epsilon = 1e-7
+ device = x.device
+
+ # compute deterministic bottom up
+ d = x
+ encoders = []
+
+ for i in range(0, len(model.hidden_layers)):
+ d, _, _ = model.bottom_up[i](d)
+ # store the bottom-up layers for the top-down computation
+ encoders.append(d)
+
+ # Create a list to store node information
+ node_info_list = []
+
+ # Create a list of nodes of the tree that need to be processed
+ list_nodes = [{'node': model.tree, 'depth': 0, 'prob': torch.ones(x.size(0), device=device), 'z_parent_sample': None}]
+
+ while len(list_nodes) != 0:
+ # Store info regarding the current node
+ current_node = list_nodes.pop(0)
+ node, depth_level, prob = current_node['node'], current_node['depth'], current_node['prob']
+ z_parent_sample = current_node['z_parent_sample']
+
+ # Access deterministic bottom-up mu and sigma hat (computed above)
+ d = encoders[-(1 + depth_level)]
+ z_mu_q_hat, z_sigma_q_hat = node.dense(d)
+
+ if depth_level == 0:
+ z_mu_q, z_sigma_q = z_mu_q_hat, z_sigma_q_hat
+ else:
+ # The generative mu and sigma are the output of the top-down network given the sampled parent
+ _, z_mu_p, z_sigma_p = node.transformation(z_parent_sample)
+ z_mu_q, z_sigma_q = compute_posterior(z_mu_q_hat, z_mu_p, z_sigma_q_hat, z_sigma_p)
+
+ # Compute sample z using mu_q and sigma_q
+ z = td.Independent(td.Normal(z_mu_q, torch.sqrt(z_sigma_q + epsilon)), 1)
+ z_sample = z.rsample()
+
+ # Store information in the list
+ node_info = {'prob': prob, 'z_sample': z_sample}
+ node_info_list.append(node_info)
+
+ if node.router is not None:
+ prob_child_left_q = node.routers_q(d).squeeze()
+
+ # We are not in a leaf, so we have to add the left and right child to the list
+ prob_node_left, prob_node_right = prob * prob_child_left_q, prob * (1 - prob_child_left_q)
+
+ node_left, node_right = node.left, node.right
+ list_nodes.append(
+ {'node': node_left, 'depth': depth_level + 1, 'prob': prob_node_left, 'z_parent_sample': z_sample})
+ list_nodes.append({'node': node_right, 'depth': depth_level + 1, 'prob': prob_node_right,
+ 'z_parent_sample': z_sample})
+
+ elif node.decoder is None and (node.left is not None or node.right is not None):
+ # We are in an internal node with pruned leaves and thus only have one child
+ node_left, node_right = node.left, node.right
+ child = node_left if node_left is not None else node_right
+ list_nodes.append(
+ {'node': child, 'depth': depth_level + 1, 'prob': prob, 'z_parent_sample': z_sample})
+
+ return node_info_list
+
+
+
+# Create a function to draw scatter plots as nodes
+def draw_scatter_node(node_id, node_embeddings, colors, ax, pca = True):
+
+ # if list is empty --> node has been pruned
+ if node_embeddings[node_id]['z_sample'] == []:
+ # return empty plot
+ ax.set_title(f"Node {node_id}")
+ ax.set_xticks([])
+ ax.set_yticks([])
+ return
+
+ z_sample = node_embeddings[node_id]['z_sample']
+ weights = node_embeddings[node_id]['prob']
+
+ if pca:
+ pca_fit = PCA(n_components=2)
+ z_sample = pca_fit.fit_transform(z_sample)
+
+
+ ax.scatter(z_sample[:, 0], z_sample[:, 1], c=colors, cmap='tab10', alpha=weights, s = 0.25)
+ ax.set_title(f"Node {node_id}")
+ # no ticks
+ ax.set_xticks([])
+ ax.set_yticks([])
+
+
+def splits_to_right_and_left(node_id, data):
+ # Initialize splits to right and left to 0
+ splits_to_right = 0
+ splits_to_left = 0
+
+ # root node
+
+ while True:
+ # root node
+ if node_id == 0:
+ return splits_to_left, splits_to_right
+
+ # previous node has same parent
+ elif data[node_id-1][2] == data[node_id][2]:
+ splits_to_right += 1
+ node_id = data[node_id][2]
+
+ else:
+ splits_to_left += 1
+ node_id = data[node_id][2]
+
+
+def get_depth(node_id, data):
+ # Initialize the depth to 0
+ depth = 0
+
+ # Find the node in the data list
+ node = next(node for node in data if node[0] == node_id)
+
+ # Recursively calculate the depth
+ if node[2] is not None:
+ depth = 1 + get_depth(node[2], data)
+
+ return depth
+
+
+# Create the tree graph with scatter plots as nodes
+def draw_tree_with_scatter_plots(data, node_embeddings, label_list, pca = True):
+
+ # Create a directed graph
+ G = nx.DiGraph()
+
+ # Add nodes and edges to the graph
+ for node in data:
+ node_id, label, parent_id, node_type = node
+ G.add_node(node_id, label=label, node_type=node_type)
+ if parent_id is not None:
+ G.add_edge(parent_id, node_id)
+
+ # Get positions of graph nodes
+ pos = hierarchy_pos(G, 0, levels=None, width=1, height=1)
+
+ # get the labels of the nodes
+ labels = nx.get_node_attributes(G, 'label')
+
+
+ fig, ax = plt.subplots(figsize=(20, 10))
+
+ for node_id, node_data in G.nodes(data=True):
+ x, y = pos[node_id]
+
+ # Create a subplot for each node, centered on the node
+ sub_ax = fig.add_axes([x, y+0.9, 0.1, 0.1])
+ draw_scatter_node(node_id, node_embeddings, label_list, sub_ax, pca)
+
+ # Draw the lines between above nodes, need to consider the position of the subplots
+
+ # first need a list of edges in the order of the nodes and the positions of the nodes
+ # Calculate the positions of the connection lines
+ # offset by -0.05 for each left split and by +0.05 for each right split
+
+ node_positions = {}
+
+ for node in data:
+ node_id, label, parent_id, node_type = node
+ x, y = pos[node_id]
+ depth = get_depth(node_id, data)
+ splits_to_left, splits_to_right = splits_to_right_and_left(node_id, data)
+
+ # calculate the position of the node
+ x = x - splits_to_left * 0.05 + splits_to_right * 0.05 + 0.05
+ y = y + 1.1 - depth * 0.05
+
+ node_positions[node_id] = (x, y)
+
+ # draw the connection lines
+ if parent_id is not None:
+ x_parent, y_parent = node_positions[parent_id]
+ ax.plot([x_parent, x], [y_parent, y], color='black', alpha=0.5)
+
+
+ # Set the limits of the plot
+ ax.set_ylim(0, 1)
+ ax.set_xlim(0, 1)
+ ax.axis('off')
+
+ plt.show()
+
+
diff --git a/treevae/utils/training_utils.py b/treevae/utils/training_utils.py
new file mode 100644
index 0000000..2bc940b
--- /dev/null
+++ b/treevae/utils/training_utils.py
@@ -0,0 +1,512 @@
+"""
+Utility functions for training.
+"""
+import torch
+import math
+import numpy as np
+import gc
+import wandb
+from tqdm import tqdm
+import torch.optim as optim
+from torchmetrics import Metric
+from sklearn.metrics.cluster import normalized_mutual_info_score
+from utils.utils import cluster_acc
+from torch.utils.data import TensorDataset
+
+
+def train_one_epoch(train_loader, model, optimizer, metrics_calc, epoch_idx, device, train_small_tree=False,
+ small_model=None, ind_leaf=None):
+ """
+ Train TreeVAE or SmallTreeVAE model for one epoch.
+
+ Parameters
+ ----------
+ train_loader: DataLoader
+ The train data loader
+ model: models.model.TreeVAE
+ The TreeVAE model
+ optimizer: optim
+ The optimizer for training the model
+ metrics_calc: Metric
+ The metrics to keep track while training
+ epoch_idx: int
+ The current epoch
+ device: torch.device
+ The device in which to validate the model
+ train_small_tree: bool
+ If set to True, then the subtree (small_model) will be trained (and afterwords attached to model)
+ small_model: models.model.SmallTreeVAE
+ The SmallTreeVAE model (which is then attached to a selected leaf of TreeVAE)
+ ind_leaf: int
+ The index of the TreeVAE leaf where the small_model will be attached
+ """
+ if train_small_tree:
+ # if we train the small tree, then the full tree is frozen
+ model.eval()
+ small_model.train()
+ model.return_bottomup[0] = True
+ model.return_x[0] = True
+ alpha = small_model.alpha
+ else:
+ # otherwise we are training the full tree
+ model.train()
+ alpha = model.alpha
+
+ metrics_calc.reset()
+
+ for batch_idx, batch in enumerate(tqdm(train_loader)):
+ inputs, labels = batch
+ inputs, labels = inputs.to(device), labels.to(device)
+ # Zero your gradients for every batch
+ optimizer.zero_grad()
+
+ # Make predictions for this batch
+ if train_small_tree:
+ # Gradient-free pass of full tree
+ with torch.no_grad():
+ outputs_full = model(inputs)
+ x, node_leaves, bottom_up = outputs_full['input'], outputs_full['node_leaves'], outputs_full['bottom_up']
+ # Passing through subtree for updating its parameters
+ outputs = small_model(x, node_leaves[ind_leaf]['z_sample'], node_leaves[ind_leaf]['prob'], bottom_up)
+ outputs['kl_root'] = torch.tensor(0., device=device)
+ else:
+ outputs = model(inputs)
+
+ # Compute the loss and its gradients
+ rec_loss = outputs['rec_loss']
+ kl_losses = outputs['kl_root'] + outputs['kl_decisions'] + outputs['kl_nodes']
+ loss_value = rec_loss + alpha * kl_losses + outputs['aug_decisions']
+ loss_value.backward()
+
+ # Adjust learning weights
+ optimizer.step()
+
+ # Store metrics
+ # Note that y_pred is used for computing nmi.
+ # During subtree training, the nmi is calculated relative to only the subtree.
+ y_pred = outputs['p_c_z'].argmax(dim=-1)
+ metrics_calc.update(loss_value, outputs['rec_loss'], outputs['kl_decisions'], outputs['kl_nodes'],
+ outputs['kl_root'], outputs['aug_decisions'],
+ (1 - torch.mean(y_pred.float()) if outputs['p_c_z'].shape[1] <= 2 else torch.tensor(0.,
+ device=device)),
+ labels, y_pred)
+
+ if train_small_tree:
+ model.return_bottomup[0] = False
+ model.return_x[0] = False
+
+ # Calculate and log metrics
+ metrics = metrics_calc.compute()
+ metrics['alpha'] = alpha
+ wandb.log({'train': metrics})
+ prints = f"Epoch {epoch_idx}, Train : "
+ for key, value in metrics.items():
+ prints += f"{key}: {value:.3f} "
+ print(prints)
+ metrics_calc.reset()
+ _ = gc.collect()
+ return
+
+
+def validate_one_epoch(test_loader, model, metrics_calc, epoch_idx, device, test=False, train_small_tree=False,
+ small_model=None, ind_leaf=None):
+ model.eval()
+ if train_small_tree:
+ small_model.eval()
+ model.return_bottomup[0] = True
+ model.return_x[0] = True
+ alpha = small_model.alpha
+ else:
+ alpha = model.alpha
+
+ metrics_calc.reset()
+
+ with torch.no_grad():
+ for batch_idx, batch in enumerate(tqdm(test_loader)):
+ inputs, labels = batch
+ inputs, labels = inputs.to(device), labels.to(device)
+ # Make predictions for this batch
+ if train_small_tree:
+ # Sass of full tree
+ outputs_full = model(inputs)
+ x, node_leaves, bottom_up = outputs_full['input'], outputs_full['node_leaves'], outputs_full[
+ 'bottom_up']
+ # Passing through subtree
+ outputs = small_model(x, node_leaves[ind_leaf]['z_sample'], node_leaves[ind_leaf]['prob'], bottom_up)
+ outputs['kl_root'] = torch.tensor(0., device=device)
+ else:
+ outputs = model(inputs)
+
+ # Compute the loss and its gradients
+ rec_loss = outputs['rec_loss']
+ kl_losses = outputs['kl_root'] + outputs['kl_decisions'] + outputs['kl_nodes']
+ loss_value = rec_loss + alpha * kl_losses + outputs['aug_decisions']
+
+ # Store metrics
+ y_pred = outputs['p_c_z'].argmax(dim=-1)
+ metrics_calc.update(loss_value, outputs['rec_loss'], outputs['kl_decisions'], outputs['kl_nodes'],
+ outputs['kl_root'],
+ outputs['aug_decisions'], (
+ 1 - torch.mean(outputs['p_c_z'].argmax(dim=-1).float()) if outputs['p_c_z'].shape[
+ 1] <= 2 else torch.tensor(
+ 0., device=device)), labels, y_pred)
+
+ if train_small_tree:
+ model.return_bottomup[0] = False
+ model.return_x[0] = False
+
+ # Calculate and log metrics
+ metrics = metrics_calc.compute()
+ if not test:
+ wandb.log({'validation': metrics})
+ prints = f"Epoch {epoch_idx}, Validation: "
+ else:
+ wandb.log({'test': metrics})
+ prints = f"Test: "
+ for key, value in metrics.items():
+ prints += f"{key}: {value:.3f} "
+ print(prints)
+ metrics_calc.reset()
+ _ = gc.collect()
+ return
+
+
+def predict(loader, model, device, *return_flags):
+ model.eval()
+
+ if 'bottom_up' in return_flags:
+ model.return_bottomup[0] = True
+ if 'X_aug' in return_flags:
+ model.return_x[0] = True
+ if 'elbo' in return_flags:
+ model.return_elbo[0] = True
+
+ results = {name: [] for name in return_flags}
+ # Create a dictionary to map return flags to corresponding functions
+ return_functions = {
+ 'node_leaves': lambda: move_to(outputs['node_leaves'], 'cpu'),
+ 'bottom_up': lambda: move_to(outputs['bottom_up'], 'cpu'),
+ 'prob_leaves': lambda: move_to(outputs['p_c_z'], 'cpu'),
+ 'X_aug': lambda: move_to(outputs['input'], 'cpu'),
+ 'y': lambda: labels,
+ 'elbo': lambda: move_to(outputs['elbo_samples'], 'cpu'),
+ 'rec_loss': lambda: move_to(outputs['rec_loss'], 'cpu')
+ }
+
+ with torch.no_grad():
+ for batch_idx, (inputs, labels) in enumerate(tqdm(loader)):
+ inputs = inputs.to(device)
+ # Make predictions for this batch
+ outputs = model(inputs)
+
+ for return_flag in return_flags:
+ results[return_flag].append(return_functions[return_flag]())
+
+ for return_flag in return_flags:
+ if return_flag == 'bottom_up':
+ bottom_up = results[return_flag]
+ results[return_flag] = [torch.cat([sublist[i] for sublist in bottom_up], dim=0) for i in
+ range(len(bottom_up[0]))]
+ elif return_flag == 'node_leaves':
+ node_leaves_combined = []
+ node_leaves = results[return_flag]
+ for i in range(len(node_leaves[0])):
+ node_leaves_combined.append(dict())
+ for key in node_leaves[0][i].keys():
+ node_leaves_combined[i][key] = torch.cat([sublist[i][key] for sublist in node_leaves], dim=0)
+ results[return_flag] = node_leaves_combined
+ elif return_flag == 'rec_loss':
+ results[return_flag] = torch.stack(results[return_flag], dim=0)
+ else:
+ results[return_flag] = torch.cat(results[return_flag], dim=0)
+
+ if 'bottom_up' in return_flags:
+ model.return_bottomup[0] = False
+ if 'X_aug' in return_flags:
+ model.return_x[0] = False
+ if 'elbo' in return_flags:
+ model.return_elbo[0] = False
+
+ if len(return_flags) == 1:
+ return list(results.values())[0]
+ else:
+ return tuple(results.values())
+
+
+def move_to(obj, device):
+ if torch.is_tensor(obj):
+ return obj.to(device)
+ elif isinstance(obj, dict):
+ res = {}
+ for k, v in obj.items():
+ res[k] = move_to(v, device)
+ return res
+ elif isinstance(obj, list):
+ res = []
+ for v in obj:
+ res.append(move_to(v, device))
+ return res
+ elif isinstance(obj, tuple):
+ res = tuple(tensor.to(device) for tensor in obj)
+ return res
+ else:
+ raise TypeError("Invalid type for move_to")
+
+
+class AnnealKLCallback:
+ def __init__(self, model, decay=0.01, start=0.):
+ self.decay = decay
+ self.start = start
+ self.model = model
+ self.model.alpha = torch.tensor(min(1, start))
+
+ def on_epoch_end(self, epoch, logs=None):
+ value = self.start + (epoch + 1) * self.decay
+ self.model.alpha = torch.tensor(min(1, value))
+
+
+class Decay():
+ def __init__(self, lr=0.001, drop=0.1, epochs_drop=50):
+ self.lr = lr
+ self.drop = drop
+ self.epochs_drop = epochs_drop
+
+ def learning_rate_scheduler(self, epoch):
+ initial_lrate = self.lr
+ drop = self.drop
+ epochs_drop = self.epochs_drop
+ lrate = initial_lrate * math.pow(drop, math.floor((1 + epoch) / epochs_drop))
+ return lrate
+
+
+def calc_aug_loss(prob_parent, prob_router, augmentation_methods, emb_contr=[]):
+ aug_decisions_loss = torch.zeros(1, device=prob_parent.device)
+ prob_parent = prob_parent.detach()
+
+ # Get router probabilities of X' and X''
+ p1, p2 = prob_router[:len(prob_router) // 2], prob_router[len(prob_router) // 2:]
+ # Perform invariance regularization
+ for aug_method in augmentation_methods:
+ # Perform invariance regularization in the decisions
+ if aug_method == 'InfoNCE':
+ p1_normed = torch.nn.functional.normalize(torch.stack([p1, 1 - p1], 1), dim=1)
+ p2_normed = torch.nn.functional.normalize(torch.stack([p2, 1 - p2], 1), dim=1)
+ pair_sim = torch.exp(torch.sum(p1_normed * p2_normed, dim=1))
+ p_normed = torch.cat([p1_normed, p2_normed], dim=0)
+ matrix_sim = torch.exp(torch.matmul(p_normed, p_normed.t()))
+ norm_factor = torch.sum(matrix_sim, dim=1) - torch.diag(matrix_sim)
+ pair_sim = pair_sim.repeat(2) # storing sim for X' and X''
+ info_nce_sample = -torch.log(pair_sim / norm_factor)
+ info_nce = torch.sum(prob_parent * info_nce_sample) / torch.sum(prob_parent)
+ aug_decisions_loss += info_nce
+ # Perform invariance regularization in the bottom-up embeddings
+ elif aug_method == 'instancewise_full':
+ looplen = len(emb_contr)
+ for i in range(looplen):
+ temp_instance = 0.5
+ emb = emb_contr[i]
+ emb1, emb2 = emb[:len(emb) // 2], emb[len(emb) // 2:]
+ emb1_normed = torch.nn.functional.normalize(emb1, dim=1)
+ emb2_normed = torch.nn.functional.normalize(emb2, dim=1)
+ pair_sim = torch.exp(torch.sum(emb1_normed * emb2_normed, dim=1) / temp_instance)
+ emb_normed = torch.cat([emb1_normed, emb2_normed], dim=0)
+ matrix_sim = torch.exp(torch.matmul(emb_normed, emb_normed.t()) / temp_instance)
+ norm_factor = torch.sum(matrix_sim, dim=1) - torch.diag(matrix_sim)
+ pair_sim = pair_sim.repeat(2) # storing sim for X' and X''
+ info_nce_sample = -torch.log(pair_sim / norm_factor)
+ info_nce = torch.mean(info_nce_sample)
+ info_nce = info_nce / looplen
+ aug_decisions_loss += info_nce
+ else:
+ raise NotImplementedError
+
+ return aug_decisions_loss
+
+
+def get_ind_small_tree(node_leaves, n_effective_leaves):
+ prob = node_leaves['prob']
+ ind = np.where(prob >= min(1 / n_effective_leaves, 0.5))[0] # To circumvent problems with n_effective_leaves==1
+ return ind
+
+
+def compute_leaves(tree):
+ list_nodes = [{'node': tree, 'depth': 0}]
+ nodes_leaves = []
+ while len(list_nodes) != 0:
+ current_node = list_nodes.pop(0)
+ node, depth_level = current_node['node'], current_node['depth']
+ if node.router is not None:
+ node_left, node_right = node.left, node.right
+ list_nodes.append(
+ {'node': node_left, 'depth': depth_level + 1})
+ list_nodes.append({'node': node_right, 'depth': depth_level + 1})
+ elif node.router is None and node.decoder is None:
+ # We are in an internal node with pruned leaves and thus only have one child
+ node_left, node_right = node.left, node.right
+ child = node_left if node_left is not None else node_right
+ list_nodes.append(
+ {'node': child, 'depth': depth_level + 1})
+ else:
+ nodes_leaves.append(current_node)
+ return nodes_leaves
+
+
+def compute_growing_leaf(loader, model, node_leaves, max_depth, batch_size, max_leaves, check_max=False):
+ """
+ Compute the leaf of the TreeVAE model that should be further split.
+
+ Parameters
+ ----------
+ loader: DataLoader
+ The data loader used to compute the leaf
+ model: models.model.TreeVAE
+ The TreeVAE model
+ node_leaves: list
+ A list of leaf nodes, each one described by a dictionary
+ {'prob': sample-wise probability of reaching the node, 'z_sample': sampled leaf embedding}
+ max_depth: int
+ The maximum depth of the tree
+ batch_size: int
+ The batch size
+ max_leaves: int
+ The maximum number of leaves of the tree
+ check_max: bool
+ Whether to check that we reached the maximum number of leaves
+ Returns
+ ------
+ list: List containing:
+ ind_leaf: index of the selected leaf
+ leaf: the selected leaf
+ n_effective_leaves: number of leaves that are not empty
+ """
+
+ # count effective number of leaves (non empty leaves)
+ weights = [node_leaves[i]['prob'] for i in range(len(node_leaves))]
+ weights_summed = [weights[i].sum() for i in range(len(weights))]
+ n_effective_leaves = len(np.where(weights_summed / np.sum(weights_summed) >= 0.01)[0])
+ print("\nNumber of effective leaves: ", n_effective_leaves)
+
+ # grow until reaching required n_effective_leaves
+ if n_effective_leaves >= max_leaves:
+ print('\nReached maximum number of leaves\n')
+ return None, None, True
+
+ elif check_max:
+ return None, None, False
+
+ else:
+ leaves = compute_leaves(model.tree)
+ n_samples = []
+ if loader.dataset.dataset.__class__ is TensorDataset:
+ y_train = loader.dataset.dataset.tensors[1][loader.dataset.indices]
+ else:
+ y_train = loader.dataset.dataset.targets[loader.dataset.indices]
+ # Calculating ground-truth nodes-to-split for logging and model development
+ # NOTE: labels are used to evaluate leaf metrics, they are not used to select the leaf
+ for i in range(len(node_leaves)):
+ depth, node = leaves[i]['depth'], leaves[i]['node']
+ if not node.expand:
+ continue
+ ind = get_ind_small_tree(node_leaves[i], n_effective_leaves)
+ y_train_small = y_train[ind]
+ # printing distribution of ground-truth classes in leaves
+ print(f"Leaf {i}: ", np.unique(y_train_small, return_counts=True))
+ n_samples.append(len(y_train_small))
+
+ # Highest number of samples indicates splitting
+ split_values = n_samples
+ ind_leaves = np.argsort(np.array(split_values))
+ ind_leaves = ind_leaves[::-1]
+
+ print("Ranking of leaves to split: ", ind_leaves)
+ for i in ind_leaves:
+ if n_samples[i] < batch_size:
+ wandb.log({'Skipped Split': 1})
+ print("We don't split leaves with fewer samples than batch size")
+ continue
+ elif leaves[i]['depth'] == max_depth or not leaves[i]['node'].expand:
+ leaves[i]['node'].expand = False
+ print('\nReached maximum architecture\n')
+ print('\n!!ATTENTION!! architecture is not deep enough\n')
+ break
+ else:
+ ind_leaf = i
+ leaf = leaves[ind_leaf]
+ print(f'\nSplitting leaf {ind_leaf}\n')
+ return ind_leaf, leaf, n_effective_leaves
+
+ return None, None, n_effective_leaves
+
+
+def compute_pruning_leaf(model, node_leaves_train):
+ leaves = compute_leaves(model.tree)
+ n_leaves = len(node_leaves_train)
+ weights = [node_leaves_train[i]['prob'] for i in range(n_leaves)]
+
+ # Assign each sample to a leaf by argmax(weights)
+ max_indeces = np.array([np.argmax(col) for col in zip(*weights)])
+
+ n_samples = []
+ for i in range(n_leaves):
+ print(f"Leaf {i}: ", sum(max_indeces == i), "samples")
+ n_samples.append(sum(max_indeces == i))
+
+ # Prune leaves with less than 1% of all samples
+ ind_leaf = np.argmin(n_samples)
+ if n_samples[ind_leaf] < 0.01 * sum(n_samples):
+ leaf = leaves[ind_leaf]
+ return ind_leaf, leaf
+ else:
+ return None, None
+
+
+def get_optimizer(model, configs):
+ optimizer = optim.Adam(params=model.parameters(), lr=configs['training']['lr'],
+ weight_decay=configs['training']['weight_decay'])
+ return optimizer
+
+
+class Custom_Metrics(Metric):
+ def __init__(self, device):
+ super().__init__()
+ self.add_state("loss_value", default=torch.tensor(0., device=device))
+ self.add_state("rec_loss", default=torch.tensor(0., device=device))
+ self.add_state("kl_root", default=torch.tensor(0., device=device))
+ self.add_state("kl_decisions", default=torch.tensor(0., device=device))
+ self.add_state("kl_nodes", default=torch.tensor(0., device=device))
+ self.add_state("aug_decisions", default=torch.tensor(0., device=device))
+ self.add_state("perc_samples", default=torch.tensor(0., device=device))
+ self.add_state("y_true", default=[])
+ self.add_state("y_pred", default=[])
+ self.add_state("n_samples", default=torch.tensor(0, dtype=torch.int, device=device))
+
+ def update(self, loss_value: torch.Tensor, rec_loss: torch.Tensor, kl_decisions: torch.Tensor,
+ kl_nodes: torch.Tensor, kl_root: torch.Tensor, aug_decisions: torch.Tensor, perc_samples: torch.Tensor,
+ y_true: torch.Tensor, y_pred: torch.Tensor):
+ assert y_true.shape == y_pred.shape
+
+ n_samples = y_true.numel()
+ self.n_samples += n_samples
+ self.loss_value += loss_value.item() * n_samples
+ self.rec_loss += rec_loss.item() * n_samples
+ self.kl_root += kl_root.item() * n_samples
+ self.kl_decisions += kl_decisions.item() * n_samples
+ self.kl_nodes += kl_nodes.item() * n_samples
+ self.aug_decisions += aug_decisions.item() * n_samples
+ self.perc_samples += perc_samples.item() * n_samples
+ self.y_true.append(y_true)
+ self.y_pred.append(y_pred)
+
+ def compute(self):
+ self.y_true = torch.cat(self.y_true, dim=0)
+ self.y_pred = torch.cat(self.y_pred, dim=0)
+ nmi = normalized_mutual_info_score(self.y_true.cpu().numpy(), self.y_pred.cpu().numpy())
+ acc = cluster_acc(self.y_true.cpu().numpy(), self.y_pred.cpu().numpy(), return_index=False)
+
+ metrics = dict({'loss_value': self.loss_value / self.n_samples, 'rec_loss': self.rec_loss / self.n_samples,
+ 'kl_decisions': self.kl_decisions / self.n_samples, 'kl_root': self.kl_root / self.n_samples,
+ 'kl_nodes': self.kl_nodes / self.n_samples,
+ 'aug_decisions': self.aug_decisions / self.n_samples,
+ 'perc_samples': self.perc_samples / self.n_samples, 'nmi': nmi, 'accuracy': acc})
+
+ return metrics
diff --git a/treevae/utils/utils.py b/treevae/utils/utils.py
new file mode 100644
index 0000000..7bd7940
--- /dev/null
+++ b/treevae/utils/utils.py
@@ -0,0 +1,214 @@
+"""
+General utility functions.
+"""
+
+import numpy as np
+from scipy.optimize import linear_sum_assignment as linear_assignment
+from scipy.special import comb
+import torch
+import os
+import random
+from pathlib import Path
+import yaml
+
+
+def cluster_acc(y_true, y_pred, return_index=False):
+ """
+ Calculate clustering accuracy.
+ # Arguments
+ y: true labels, numpy.array with shape `(n_samples,)`
+ y_pred: predicted labels, numpy.array with shape `(n_samples,)`
+ # Return
+ accuracy, in [0,1]
+ """
+ y_true = y_true.astype(np.int64)
+ assert y_pred.size == y_true.size
+ D = max(y_pred.astype(int).max(), y_true.astype(int).max()) + 1
+ w = np.zeros((int(D), (D)), dtype=np.int64)
+ for i in range(y_pred.size):
+ w[int(y_pred[i]), int(y_true[i])] += 1
+ ind = np.array(linear_assignment(w.max() - w))
+ if return_index:
+ assert all(ind[0] == range(len(ind[0]))) # Assert rows don't change order
+ cluster_acc = sum(w[ind[0], ind[1]]) * 1.0 / y_pred.size
+ return cluster_acc, ind[1]
+ else:
+ return sum([w[ind[0,i], ind[1,i]] for i in range(len(ind[0]))]) * 1.0 / y_pred.size
+
+
+def reset_random_seeds(seed):
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.backends.cudnn.deterministic = True
+ # No determinism as nn.Upsample has no deterministic implementation
+ #torch.use_deterministic_algorithms(True)
+ torch.backends.cudnn.benchmark = False
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+
+
+def merge_yaml_args(configs, args):
+ arg_dict = args.__dict__
+ configs['parser'] = dict()
+ for key, value in arg_dict.items():
+ flag = True
+ # Replace/Create values in config if they are defined by arg in parser.
+ if arg_dict[key] is not None:
+ for key_config in configs.keys():
+ # If value of config is dict itself, then search key-value pairs inside this dict for matching the arg
+ if type(configs[key_config]) is dict:
+ for key2, value2 in configs[key_config].items():
+ if key == key2:
+ configs[key_config][key2] = value
+ flag = False
+ # If value of config is not a dict, check whether key matches to the arg
+ else:
+ if key == key_config:
+ configs[key_config] = value
+ flag = False
+ # Break out of loop if key got replaced
+ if flag == False:
+ break
+ # If arg does not match any keys of config, define a new key
+ else:
+ print("Could not find this key in config, therefore adding it:", key)
+ configs['parser'][key] = arg_dict[key]
+ return configs
+
+
+def prepare_config(args, project_dir):
+ # Load config
+ data_name = args.config_name +'.yml'
+ config_path = project_dir / 'configs' / data_name
+
+ with config_path.open(mode='r') as yamlfile:
+ configs = yaml.safe_load(yamlfile)
+
+ # Override config if args in parser
+ configs = merge_yaml_args(configs, args)
+ if isinstance(configs['training']['latent_dim'], str):
+ a = configs['training']['latent_dim'].split(",")
+ configs['training']['latent_dim'] = [int(i) for i in a]
+ if isinstance(configs['training']['mlp_layers'], str):
+ a = configs['training']['mlp_layers'].split(",")
+ configs['training']['mlp_layers'] = [int(i) for i in a]
+
+ a = configs['training']['augmentation_method'].split(",")
+ configs['training']['augmentation_method'] = [str(i) for i in a]
+
+
+
+ configs['globals']['results_dir'] = os.path.join(project_dir, 'models/experiments')
+ configs['globals']['results_dir'] = Path(configs['globals']['results_dir']).absolute()
+
+ # Prepare for passing x' and x'' through model by setting batch size to an even number
+ if configs['training']['augment'] is True and configs['training']['augmentation_method'] != ['simple'] and configs['training']['batch_size'] % 2 != 0:
+ configs['training']['batch_size'] += 1
+
+
+ return configs
+
+def count_values_in_sequence(seq):
+ from collections import defaultdict
+ res = defaultdict(lambda : 0)
+ for key in seq:
+ res[key] += 1
+ return dict(res)
+
+
+def dendrogram_purity(tree_root, ground_truth, ind_samples_of_leaves):
+ total_per_label_frequencies = count_values_in_sequence(ground_truth)
+ total_per_label_pairs_count = {k: comb(v, 2, True) for k, v in total_per_label_frequencies.items()}
+ total_n_of_pairs = sum(total_per_label_pairs_count.values())
+ one_div_total_n_of_pairs = 1. / total_n_of_pairs
+ purity = 0.
+
+ def calculate_purity(node, level):
+ nonlocal purity
+ if node.decoder:
+ # Match node to leaf samples
+ ind_leaf = np.where([node == ind_samples_of_leaves[ind_leaf][0] for ind_leaf in range(len(ind_samples_of_leaves))])[0].item()
+ ind_samples_of_leaf = ind_samples_of_leaves[ind_leaf][1]
+ node_total_dp_count = len(ind_samples_of_leaf)
+ # Count how many samples of given leaf fall into which ground-truth class (-> For treevae make use of ground_truth(to which class a sample belongs)&yy(into which leaf a sample falls))
+ node_per_label_frequencies = count_values_in_sequence(
+ [ground_truth[id] for id in ind_samples_of_leaf])
+ # From above, deduct how many pairs will fall into same leaf
+ node_per_label_pairs_count = {k: comb(v, 2, True) for k, v in node_per_label_frequencies.items()}
+
+ elif node.router is None and node.decoder is None:
+ # We are in an internal node with pruned leaves and thus only have one child. Therefore no prunity calculation here!
+ node_left, node_right = node.left, node.right
+ child = node_left if node_left is not None else node_right
+ node_per_label_frequencies, node_total_dp_count = calculate_purity(child, level + 1)
+ return node_per_label_frequencies, node_total_dp_count
+
+ else:
+ # it is an inner splitting node
+ left_child_per_label_freq, left_child_total_dp_count = calculate_purity(node.left, level + 1)
+ right_child_per_label_freq, right_child_total_dp_count = calculate_purity(node.right, level + 1)
+ node_total_dp_count = left_child_total_dp_count + right_child_total_dp_count
+ # Count how many samples of given internal node fall into which ground-truth class (=sum of their children's values)
+ node_per_label_frequencies = {k: left_child_per_label_freq.get(k, 0) + right_child_per_label_freq.get(k, 0) \
+ for k in set(left_child_per_label_freq) | set(right_child_per_label_freq)}
+
+ # Class-wisedly count how many pairs of samples of a class will have this node as least common ancestor (=mult. of their children's values, bcs this is all possible pairs coming from different sides)
+ node_per_label_pairs_count = {k: left_child_per_label_freq.get(k) * right_child_per_label_freq.get(k) \
+ for k in set(left_child_per_label_freq) & set(right_child_per_label_freq)}
+
+ # Given the class-wise number of pairs with given node as least common ancestor node, calculate their purity
+ for label, pair_count in node_per_label_pairs_count.items():
+ label_freq = node_per_label_frequencies[label]
+ label_pairs = node_per_label_pairs_count[label]
+ purity += one_div_total_n_of_pairs * label_freq / node_total_dp_count * label_pairs # (1/n_all_pairs) * purity(=n_samples_of_this_class_in_node/n_samples) * n_class_pairs_with_this_node_being_least_common_ancestor(this last term represents sum over pairs with this node being least common ancestor)
+ return node_per_label_frequencies, node_total_dp_count
+
+ calculate_purity(tree_root, 0)
+ return purity
+
+
+def leaf_purity(tree_root, ground_truth, ind_samples_of_leaves):
+ values = [] # purity rate per leaf
+ weights = [] # n_samples per leaf
+ # For each leaf calculate the maximum over classes for in-leaf purity (i.e. majority class / n_samples_in_leaf)
+ def get_leaf_purities(node):
+ nonlocal values
+ nonlocal weights
+ if node.decoder:
+ ind_leaf = np.where([node == ind_samples_of_leaves[ind_leaf][0] for ind_leaf in range(len(ind_samples_of_leaves))])[0].item()
+ ind_samples_of_leaf = ind_samples_of_leaves[ind_leaf][1]
+ node_total_dp_count = len(ind_samples_of_leaf)
+ node_per_label_counts = count_values_in_sequence(
+ [ground_truth[id] for id in ind_samples_of_leaf])
+ if node_total_dp_count > 0:
+ purity_rate = max(node_per_label_counts.values()) / node_total_dp_count
+ else:
+ purity_rate = 1.0
+ values.append(purity_rate)
+ weights.append(node_total_dp_count)
+ elif node.router is None and node.decoder is None:
+ # We are in an internal node with pruned leaves and thus only have one child.
+ node_left, node_right = node.left, node.right
+ child = node_left if node_left is not None else node_right
+ get_leaf_purities(child)
+ else:
+ get_leaf_purities(node.left)
+ get_leaf_purities(node.right)
+
+ get_leaf_purities(tree_root)
+ assert len(values) == len(ind_samples_of_leaves), "Didn't iterate through all leaves"
+ # Return mean leaf_purity
+ return np.average(values, weights=weights)
+
+def display_image(image):
+ assert image.dim() == 3
+ if image.size()[0] == 1:
+ return torch.clamp(image.squeeze(0),0,1)
+ elif image.size()[0] == 3:
+ return torch.clamp(image.permute(1, 2, 0),0,1)
+ elif image.size()[-1] == 3:
+ return torch.clamp(image,0,1)
+ else:
+ raise NotImplementedError
\ No newline at end of file