From ce044bf6168a51839007203d90e42f5ca773a0f1 Mon Sep 17 00:00:00 2001
From: Tianyi Zhu
Date: Mon, 31 Mar 2025 19:42:14 -0400
Subject: [PATCH 1/2] Add files via upload

---
 .../cobweb_symbolic.cpython-39.pyc            | Bin 0 -> 3706 bytes
 symbolic_evaluation/cobweb_symbolic.py        | 108 ++++
 .../evaluate_cobweb_symbolic.py               | 537 ++++++++++++++++++
 3 files changed, 645 insertions(+)
 create mode 100644 symbolic_evaluation/__pycache__/cobweb_symbolic.cpython-39.pyc
 create mode 100644 symbolic_evaluation/cobweb_symbolic.py
 create mode 100644 symbolic_evaluation/evaluate_cobweb_symbolic.py

diff --git a/symbolic_evaluation/__pycache__/cobweb_symbolic.cpython-39.pyc b/symbolic_evaluation/__pycache__/cobweb_symbolic.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..459122f7e70f718217d6deb9ce1a893eb5e2f4f6
GIT binary patch
literal 3706
[base85-encoded .pyc payload omitted]

literal 0
HcmV?d00001

diff --git a/symbolic_evaluation/cobweb_symbolic.py b/symbolic_evaluation/cobweb_symbolic.py
new file mode 100644
index 0000000..304526c
--- /dev/null
+++ b/symbolic_evaluation/cobweb_symbolic.py
@@ -0,0 +1,108 @@
+import torch
+import numpy as np
+from torchvision import transforms, datasets
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+import untils
+from cobweb.cobweb_continuous import CobwebContinuousTree
+import json
+import base64
+import matplotlib.pyplot as plt
+from io import BytesIO
+from copy import deepcopy
+
+class CobwebSymbolic():
+    def __init__(self, input_dim, depth=5):
+        self.input_dim = input_dim
+        self.depth = depth
+        self.tree = CobwebContinuousTree(size=self.input_dim, covar_from=2, depth=self.depth, branching_factor=10)
+
+    def train(self, train_data, epochs=10):
+        train_loader = DataLoader(train_data, batch_size=1, shuffle=True)
+        for epoch in range(epochs):
+            for (x, y) in tqdm(train_loader):
+                x = x.view(-1).numpy()
+                self.tree.ifit(x)
+
+    def save_tree_to_json(self, filename):
+        # Convert tree to serializable format
+        def convert_node_to_dict(node):
+            if node is None:
+                return None
+
+            node_dict = {
+                "mean": node.mean.tolist() if isinstance(node.mean, np.ndarray) else node.mean,
+                "sum_sq": node.sum_sq.tolist() if isinstance(node.sum_sq, np.ndarray) else node.sum_sq,
+                "count": node.count.tolist() if isinstance(node.count, np.ndarray) else node.count,
+                "children": []
+            }
+
+            if hasattr(node, 'children'):
+                for child in node.children:
+                    node_dict["children"].append(convert_node_to_dict(child))
+
+            return node_dict
+
+        # Convert the tree to a dictionary
+        tree_dict = convert_node_to_dict(self.tree.root)
+
+        # Save to file
+        with open(filename, 'w') as f:
+            json.dump(tree_dict, f)
+
+    def load_tree_in_torch(self, filename):
+        with open(filename, 'r') as f:
+            temp_tree = json.load(f)
+
+        # Breadth-first traversal over every node. Loop on `while pq` (rather
+        # than `while True` with a break at the first leaf) so that nodes
+        # queued after the first leaf are still converted.
+        pq = [temp_tree]
+
+        while pq:
+            curr = pq.pop(0)
+            curr["mean"] = torch.tensor(curr["mean"])
+            curr["sum_sq"] = torch.tensor(curr["sum_sq"])
+            curr["count"] = torch.tensor(curr["count"])
+            curr["logvar"] = torch.log(curr["sum_sq"] / curr["count"])
+
+            if "children" not in curr or not curr["children"]:
+                continue
+
+            for child in curr["children"]:
+                pq.append(child)
+
+        self.tree = temp_tree
+
+    def tensor_to_base64(self, tensor, shape, cmap="gray", normalize=False):
+        array = tensor.numpy().reshape(shape)
+        if normalize:
+            plt.imshow(array, cmap=cmap, aspect="auto")
+        else:
+            plt.imshow(array, cmap=cmap, aspect="auto", vmin=0, vmax=1)
+
+        plt.axis("off")
+
+        buf = BytesIO()
+        plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
+        plt.close()
+
+        buf.seek(0)
+        return base64.b64encode(buf.getvalue()).decode("utf-8")
+
+    def viz_cobweb_tree(self, viz_filename):
+        # Assumes load_tree_in_torch() has been called, so self.tree is a
+        # nested dict with "mean"/"sum_sq"/"count"/"logvar" entries.
+        temp_tree = deepcopy(self.tree)
+        pq = [temp_tree]
+        while pq:
+            curr = pq.pop(0)
+            curr["image"] = self.tensor_to_base64(torch.tensor(curr["mean"]), (28, 28), cmap="inferno", normalize=True)
+            curr.pop("mean")
+            curr.pop("sum_sq")
+            curr.pop("count")
+            curr.pop("logvar")
+
+            if "children" not in curr or not curr["children"]:
+                continue
+
+            for child in curr["children"]:
+                pq.append(child)
+
+        with open(f'{viz_filename}.json', 'w') as f:
+            json.dump(temp_tree, f)
\ No newline at end of file
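A note on intended use: the class above is only exercised indirectly by the evaluation script that follows. A minimal end-to-end sketch, assuming the external `cobweb` package (providing `CobwebContinuousTree`) is installed and 28x28 MNIST inputs, might look like this; it is illustrative, not part of the patch:

```python
# Illustrative driver for CobwebSymbolic -- not part of this patch.
from torchvision import datasets, transforms

from cobweb_symbolic import CobwebSymbolic

train_data = datasets.MNIST('data/MNIST', train=True, download=True,
                            transform=transforms.ToTensor())

model = CobwebSymbolic(input_dim=28 * 28, depth=5)
model.train(train_data, epochs=1)       # incremental ifit over flattened images
model.save_tree_to_json('tree.json')    # dump node statistics as plain JSON
model.load_tree_in_torch('tree.json')   # reload as nested dicts of tensors (+ logvar)
model.viz_cobweb_tree('tree_viz')       # requires the torch-dict form loaded above
```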
diff --git a/symbolic_evaluation/evaluate_cobweb_symbolic.py b/symbolic_evaluation/evaluate_cobweb_symbolic.py
new file mode 100644
index 0000000..4e054bc
--- /dev/null
+++ b/symbolic_evaluation/evaluate_cobweb_symbolic.py
@@ -0,0 +1,537 @@
+import torch
+import numpy as np
+from torchvision import transforms, datasets
+from torch.utils.data import DataLoader
+import matplotlib.pyplot as plt
+from sklearn.metrics import accuracy_score, confusion_matrix
+import seaborn as sns
+from tqdm import tqdm
+import json
+import os
+from cobweb_symbolic import CobwebSymbolic
+import untils
+
+# Helper functions for evaluation
+def safe_categorize(model, sample):
+    """
+    Safely categorize a sample using the model tree, with error handling.
+
+    Args:
+        model: Trained CobwebSymbolic model
+        sample: Input sample with label
+
+    Returns:
+        Node in the tree or None if categorization failed
+    """
+    try:
+        # Method 1: Try using _cobweb_categorize directly
+        if hasattr(model.tree, '_cobweb_categorize'):
+            return model.tree._cobweb_categorize(sample)
+        # Method 2: Try using categorize method
+        elif hasattr(model.tree, 'categorize'):
+            return model.tree.categorize(sample)
+        # Method 3: Try calling predict (which might use cobweb internally)
+        elif hasattr(model, 'predict'):
+            prediction = model.predict(sample[:-1])  # Exclude label for predict
+            return {"prediction": prediction}
+        else:
+            print("  No categorization method found")
+            return None
+    except Exception as e:
+        print(f"  Categorization error: {e}")
+        return None
+
+def visualize_confusion_matrix(test_labels, predictions, title):
+    """
+    Visualize the confusion matrix for model evaluation.
+
+    Args:
+        test_labels: True labels
+        predictions: Predicted labels
+        title: Title for the confusion matrix plot
+    """
+    plt.figure(figsize=(10, 8))
+    cm = confusion_matrix(test_labels, predictions)
+    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
+    plt.xlabel('Predicted Label')
+    plt.ylabel('True Label')
+    plt.title(title)
+    plt.show()
+
+def preprocess_dataset(dataset_name, split_classes=[0, 1, 2, 3]):
+    """
+    Load and preprocess a dataset with symbolic feature encoding.
+
+    Args:
+        dataset_name: 'mnist' or 'cifar10'
+        split_classes: List of classes to use
+
+    Returns:
+        train_loader, test_loader, image_shape
+    """
+    print(f"Loading and preprocessing {dataset_name} dataset...")
+
+    if dataset_name.lower() == 'mnist':
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize((0.1307,), (0.3081,))
+        ])
+
+        train_dataset = datasets.MNIST('data/MNIST', train=True, download=True, transform=transform)
+        test_dataset = datasets.MNIST('data/MNIST', train=False, download=True, transform=transform)
+        image_shape = (28, 28)
+
+    elif dataset_name.lower() == 'cifar10':
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
+        ])
+
+        train_dataset = datasets.CIFAR10('data/CIFAR10', train=True, download=True, transform=transform)
+        test_dataset = datasets.CIFAR10('data/CIFAR10', train=False, download=True, transform=transform)
+        image_shape = (32, 32, 3)
+
+    else:
+        raise ValueError(f"Unsupported dataset: {dataset_name}")
+
+    # Filter for specified classes
+    train_dataset = untils.filter_by_label(train_dataset, split_classes, rename_labels=True)
+    test_dataset = untils.filter_by_label(test_dataset, split_classes, rename_labels=True)
+
+    # Create data loaders
+    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
+    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
+
+    # Print class distribution
+    class_counts = {}
+    for _, label in train_dataset:
+        label_value = label.item() if hasattr(label, 'item') else int(label)
+        class_counts[label_value] = class_counts.get(label_value, 0) + 1
+
+    print("Training set class distribution:")
+    for class_idx, count in sorted(class_counts.items()):
+        print(f"  Class {class_idx}: {count} samples")
+
+    return train_loader, test_loader, image_shape
+
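+# NOTE: `untils.filter_by_label` comes from the local `untils` module, which is
+# not included in this patch. From the call sites above it must keep only the
+# samples whose label is in `split_classes` and, when rename_labels=True,
+# remap the surviving labels to 0..k-1 (e.g. a thin wrapper around
+# torch.utils.data.Subset). Any such stand-in is an assumption, not repo code.
+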
+def train_symbolic_cobweb(train_loader, image_shape, depth=4, epochs=1):
+    """
+    Train a symbolic CobWeb model incrementally.
+
+    Args:
+        train_loader: DataLoader with training data
+        image_shape: Shape of input images
+        depth: Maximum depth of the CobWeb tree
+        epochs: Number of training epochs
+
+    Returns:
+        trained model
+    """
+    print("Training symbolic CobWeb model...")
+
+    if len(image_shape) == 2:
+        input_dim = image_shape[0] * image_shape[1]
+    else:
+        input_dim = image_shape[0] * image_shape[1] * image_shape[2]
+
+    # Initialize model
+    model = CobwebSymbolic(input_dim=input_dim, depth=depth)
+
+    # Train incrementally
+    for epoch in range(epochs):
+        print(f"Epoch {epoch+1}/{epochs}")
+        for data, label in tqdm(train_loader, desc="Training"):
+            # Flatten the input data
+            x = data.view(-1).numpy()
+            # Get label
+            # y = label.item() if hasattr(label, 'item') else int(label)
+
+            # Include label at the end of the input vector
+            # x_with_label = np.concatenate([x, np.array([y])])
+
+            # Incrementally fit the tree. Training is unsupervised: the label
+            # is not appended (the x_with_label variant above is disabled).
+            model.tree.ifit(x)
+
+    return model
+
+def save_model_and_results(model, results, filename):
+    """
+    Save the trained model and evaluation results.
+
+    Args:
+        model: Trained CobwebSymbolic model
+        results: Evaluation results dictionary
+        filename: Base filename for saving
+    """
+    # Create directory if it doesn't exist
+    os.makedirs('results', exist_ok=True)
+
+    # Save model
+    model_file = f"results/{filename}_model.json"
+    try:
+        model.save_tree_to_json(model_file)
+        print(f"Model saved to {model_file}")
+    except Exception as e:
+        print(f"Error saving model: {e}")
+
+    # Save results
+    results_file = f"results/{filename}_results.json"
+
+    # Make all results JSON serializable
+    def make_serializable(obj):
+        if isinstance(obj, np.ndarray):
+            return obj.tolist()
+        elif isinstance(obj, (np.float64, np.float32, np.int64, np.int32)):
+            return float(obj)
+        elif isinstance(obj, dict):
+            return {str(k): make_serializable(v) for k, v in obj.items()}
+        elif isinstance(obj, list):
+            return [make_serializable(item) for item in obj]
+        else:
+            return obj
+
+    serializable_results = make_serializable(results)
+
+    # Write to file
+    try:
+        with open(results_file, 'w') as f:
+            json.dump(serializable_results, f)
+        print(f"Results saved to {results_file}")
+    except Exception as e:
+        print(f"Error saving results: {e}")
+        print("Saving simplified results instead...")
+
+        # Fallback to basic results
+        simple_results = {}
+        for method, result in results.items():
+            if isinstance(result, dict) and "accuracy" in result:
+                simple_results[method] = {"accuracy": float(result["accuracy"])}
+            elif isinstance(result, dict) and "error" in result:
+                simple_results[method] = {"error": str(result["error"])}
+            else:
+                simple_results[method] = {"result": "unknown"}
+
+        with open(results_file, 'w') as f:
+            json.dump(simple_results, f)
+
+# Evaluation methods
+def evaluate_with_averaging(model, train_loader, test_loader, image_shape, num_classes,
+                            samples_per_class, visualize=True):
+    """
+    Evaluation using sample-point averaging: select samples from each class,
+    average them, and classify test samples by distance to the class averages.
+    (The Cobweb tree itself is not queried by this method.)
+
+    Args:
+        model: Trained CobwebSymbolic model
+        train_loader: DataLoader with training data
+        test_loader: DataLoader with test data
+        image_shape: Shape of input images
+        num_classes: Number of classes in the dataset
+        samples_per_class: Number of sample points to take per class
+        visualize: Whether to generate visualizations
+
+    Returns:
+        Dictionary of evaluation metrics
+    """
+    print(f"Evaluating with averaging approach (samples_per_class={samples_per_class})...")
+
+    # 1. Collect training samples by class
+    class_samples = {cls: [] for cls in range(num_classes)}
+
+    for data, label in tqdm(train_loader, desc="Collecting samples by class"):
+        x = data.view(-1).numpy()
+        y = label.item() if hasattr(label, 'item') else int(label)
+
+        if y < num_classes:  # Ensure valid class
+            class_samples[y].append(x)
+
+    # 2. Select samples for each class
+    selected_samples = {}
+    for cls in range(num_classes):
+        if class_samples[cls]:
+            # If samples_per_class is -1, use all samples
+            if samples_per_class == -1:
+                selected_samples[cls] = class_samples[cls]
+                print(f"Class {cls}: Using all {len(selected_samples[cls])} samples")
+            # Otherwise, select random samples if enough are available
+            elif len(class_samples[cls]) >= samples_per_class:
+                indices = np.random.choice(len(class_samples[cls]), samples_per_class, replace=False)
+                selected_samples[cls] = [class_samples[cls][i] for i in indices]
+                print(f"Class {cls}: Selected {len(selected_samples[cls])} samples")
+            else:
+                # Use all available samples if fewer than requested
+                selected_samples[cls] = class_samples[cls]
+                print(f"Class {cls}: Using all {len(selected_samples[cls])} samples (fewer than requested)")
+        else:
+            print(f"Warning: No samples found for class {cls}")
+            selected_samples[cls] = []
+
+    # 3. Calculate average for each class
+    class_averages = {}
+    for cls in range(num_classes):
+        if selected_samples[cls]:
+            # Calculate average sample
+            avg_sample = np.mean(selected_samples[cls], axis=0)
+
+            # Store average representation
+            class_averages[cls] = avg_sample
+            print(f"Class {cls}: Calculated average from {len(selected_samples[cls])} samples")
+        else:
+            print(f"Warning: Cannot calculate average for class {cls}")
+
+    # Exit if no valid classes found
+    if len(class_averages) == 0:
+        print("No valid class averages found, cannot proceed with evaluation")
+        return {"accuracy": 0.0, "error": "No valid class averages"}
+
+    # 4. Visualize class averages
+    if visualize:
+        plt.figure(figsize=(15, 5))
+        plt.suptitle(f"Class Averages (samples_per_class={samples_per_class})", fontsize=16)
+
+        for cls in range(num_classes):
+            if cls in class_averages:
+                plt.subplot(1, num_classes, cls + 1)
+
+                # Reshape average sample
+                if len(image_shape) == 2:
+                    img = class_averages[cls].reshape(image_shape)
+                    plt.imshow(img, cmap='gray')
+                else:
+                    img = class_averages[cls].reshape(image_shape)
+                    plt.imshow(img)
+
+                plt.title(f"Class {cls}")
+                plt.axis('off')
+
+        plt.tight_layout(rect=[0, 0, 1, 0.95])
+        plt.show()
+
+    # 5. Evaluate on test data
+    test_labels = []
+    predictions = []
+
+    for data, label in tqdm(test_loader, desc="Evaluating test data"):
+        x = data.view(-1).numpy()
+        y = label.item() if hasattr(label, 'item') else int(label)
+        test_labels.append(y)
+
+        # Find closest class average
+        best_class = None
+        best_distance = float('inf')
+
+        for cls, avg_sample in class_averages.items():
+            try:
+                dist = np.linalg.norm(x - avg_sample)
+                if dist < best_distance:
+                    best_distance = dist
+                    best_class = cls
+            except ValueError:
+                # Skip if shapes don't match
+                continue
+
+        # Use most common class as fallback
+        if best_class is None:
+            best_class = np.argmax(np.bincount(test_labels))
+
+        predictions.append(best_class)
+
+    # 6. Calculate accuracy and visualize results
+    test_labels = np.array(test_labels)
+    predictions = np.array(predictions)
+    accuracy = accuracy_score(test_labels, predictions)
+
+    if visualize:
+        visualize_confusion_matrix(
+            test_labels, predictions, f'Confusion Matrix (Averaging, samples={samples_per_class})'
+        )
+
+    print(f"Averaging Approach (samples={samples_per_class}) Accuracy: {accuracy:.4f}")
+
+    return {
+        "accuracy": accuracy,
+        "confusion_matrix": confusion_matrix(test_labels, predictions).tolist()
+    }
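+
+# For reference: the per-sample loop in step 5 above implements a
+# nearest-average (1-nearest-centroid) decision rule. An equivalent vectorized
+# NumPy form, shown only as an illustrative sketch, would be:
+#
+#     classes = sorted(class_averages)
+#     means = np.stack([class_averages[c] for c in classes])    # (k, d)
+#     d2 = ((X[:, None, :] - means[None, :, :]) ** 2).sum(-1)   # (n, k)
+#     preds = np.array(classes)[d2.argmin(1)]
+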
+def evaluate_class_centroids(model, train_loader, test_loader, image_shape, num_classes, visualize=True):
+    """
+    Evaluate using centroids (average of all samples) from each class.
+
+    Args:
+        model: Trained CobwebSymbolic model
+        train_loader: DataLoader with training data
+        test_loader: DataLoader with test data
+        image_shape: Shape of input images
+        num_classes: Number of classes in the dataset
+        visualize: Whether to generate visualizations
+
+    Returns:
+        Dictionary of evaluation metrics
+    """
+    print("Evaluating with class centroids approach...")
+
+    # 1. Collect all training samples by class
+    class_samples = {cls: [] for cls in range(num_classes)}
+
+    for data, label in tqdm(train_loader, desc="Collecting all class samples"):
+        x = data.view(-1).numpy()
+        y = label.item() if hasattr(label, 'item') else int(label)
+
+        if y < num_classes:  # Ensure valid class
+            class_samples[y].append(x)
+
+    # Print stats about collected samples
+    for cls in range(num_classes):
+        print(f"Class {cls}: Collected {len(class_samples[cls])} samples")
+
+    # 2. Calculate class centroids (average of all samples in each class)
+    class_centroids = {}
+    for cls in range(num_classes):
+        if len(class_samples[cls]) > 0:
+            class_centroids[cls] = np.mean(class_samples[cls], axis=0)
+
+    # Visualize class centroids
+    if visualize:
+        plt.figure(figsize=(15, 5))
+        plt.suptitle("Class Centroids (All Data)", fontsize=16)
+
+        for cls in range(num_classes):
+            if cls in class_centroids:
+                plt.subplot(1, num_classes, cls + 1)
+
+                # Reshape centroid
+                if len(image_shape) == 2:
+                    img = class_centroids[cls].reshape(image_shape)
+                    plt.imshow(img, cmap='gray')
+                else:
+                    img = class_centroids[cls].reshape(image_shape)
+                    plt.imshow(img)
+
+                plt.title(f"Class {cls} Centroid")
+                plt.axis('off')
+
+        plt.tight_layout(rect=[0, 0, 1, 0.95])
+        plt.show()
+
+    # 3. Evaluate on test data using centroids
+    test_labels = []
+    predictions = []
+
+    for data, label in tqdm(test_loader, desc="Evaluating with centroids"):
+        x = data.view(-1).numpy()
+        y = label.item() if hasattr(label, 'item') else int(label)
+        test_labels.append(y)
+
+        # Find closest centroid
+        best_class = None
+        best_distance = float('inf')
+
+        for cls, centroid in class_centroids.items():
+            try:
+                dist = np.linalg.norm(x - centroid)
+                if dist < best_distance:
+                    best_distance = dist
+                    best_class = cls
+            except ValueError:
+                continue
+
+        # Use most common class as fallback
+        if best_class is None:
+            best_class = np.argmax(np.bincount(test_labels))
+
+        predictions.append(best_class)
+
+    # 4. Calculate accuracy
+    test_labels = np.array(test_labels)
+    predictions = np.array(predictions)
+
+    accuracy = accuracy_score(test_labels, predictions)
+
+    # 5. Visualize results
+    if visualize:
+        visualize_confusion_matrix(
+            test_labels, predictions, 'Confusion Matrix (Class Centroids)'
+        )
+
+    print(f"Class Centroids Accuracy: {accuracy:.4f}")
+
+    return {
+        "accuracy": accuracy,
+        "confusion_matrix": confusion_matrix(test_labels, predictions).tolist()
+    }
+
+def main():
+    """Main function to run the symbolic CobWeb evaluation."""
+    # Configuration parameters
+    dataset_names = ['mnist', 'cifar10']  # avoids shadowing torchvision's `datasets`
+    split_classes = [0, 1, 2, 3]  # First 4 classes
+
+    # Default parameters
+    default_depth = 5
+    default_epochs = 3
+
+    # Sample sizes to test
+    sample_sizes = [30, 100, 500, 5000, -1]  # -1 means use all available samples
+
+    for dataset_name in dataset_names:
+        print(f"\n\n{'='*50}")
+        print(f"Evaluating symbolic CobWeb on {dataset_name.upper()}")
+        print(f"{'='*50}\n")
+
+        # 1. Load and preprocess dataset
+        train_loader, test_loader, image_shape = preprocess_dataset(dataset_name, split_classes)
+
+        # 2. Train model
+        model = train_symbolic_cobweb(train_loader, image_shape, depth=default_depth, epochs=default_epochs)
+
+        # 3. Evaluate with different approaches
+        results = {}
+
+        # Test different sample sizes for the averaging approach
+        for samples in sample_sizes:
+            method_name = f"averaging_samples_{samples}" if samples != -1 else "averaging_all_samples"
+            print(f"\nEvaluating averaging approach with {samples if samples != -1 else 'all'} samples per class...")
+
+            try:
+                results[method_name] = evaluate_with_averaging(
+                    model, train_loader, test_loader, image_shape, len(split_classes),
+                    samples_per_class=samples, visualize=True
+                )
+            except Exception as e:
+                print(f"Error with averaging evaluation (samples={samples}): {e}")
+                results[method_name] = {"accuracy": 0.0, "error": str(e)}
+
+        # Evaluate with the class centroids approach
+        print("\nEvaluating with class centroids approach...")
+        try:
+            results["class_centroids"] = evaluate_class_centroids(
+                model, train_loader, test_loader, image_shape, len(split_classes), visualize=True
+            )
+        except Exception as e:
+            print(f"Error with class centroids evaluation: {e}")
+            results["class_centroids"] = {"error": str(e)}
+
+        # 4. Save model and results
+        save_model_and_results(model, results, f"{dataset_name}_symbolic_evaluation")
+
+        # 5. Print summary
+        print("\n" + "="*50)
+        print(f"SUMMARY OF RESULTS FOR {dataset_name.upper()}:")
+        print("="*50)
+
+        for samples in sample_sizes:
+            method_name = f"averaging_samples_{samples}" if samples != -1 else "averaging_all_samples"
+            if "error" not in results[method_name]:
+                print(f"Averaging ({samples if samples != -1 else 'all'} samples): {results[method_name]['accuracy']:.4f}")
+            else:
+                print(f"Averaging ({samples if samples != -1 else 'all'} samples): Failed")
+
+        if "error" not in results["class_centroids"]:
+            print(f"Class Centroids: {results['class_centroids']['accuracy']:.4f}")
+        else:
+            print("Class Centroids: Failed")
+
+        print("\n\n")
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
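After a run, the metrics written by `save_model_and_results` can be inspected directly. A small illustrative snippet (the file name follows the `{dataset}_symbolic_evaluation` convention used in `main()`):

```python
# Illustrative: read back the metrics written by save_model_and_results.
import json

with open('results/mnist_symbolic_evaluation_results.json') as f:
    results = json.load(f)

for method, res in results.items():
    print(method, res.get('accuracy', res.get('error')))
```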
From bf87dabb40abc109188c110dd00b48598155044f Mon Sep 17 00:00:00 2001
From: Tianyi Zhu
Date: Wed, 30 Apr 2025 01:54:28 -0400
Subject: [PATCH 2/2] Add files via upload

---
 treevae/README.md                                   |   35 +
 treevae/configs/celeba.yml                          |   39 +
 treevae/configs/cifar10.yml                         |   39 +
 treevae/configs/cifar100.yml                        |   39 +
 treevae/configs/fmnist.yml                          |   39 +
 treevae/configs/mnist.yml                           |   39 +
 treevae/configs/news20.yml                          |   39 +
 treevae/configs/omniglot.yml                        |   41 +
 treevae/evaluate_treevae.ipynb                      |  152 +++
 treevae/main.py                                     |   62 +
 treevae/minimal_requirements.txt                    |   14 +
 .../models/__pycache__/losses.cpython-39.pyc        |  Bin 0 -> 1978 bytes
 .../models/__pycache__/model.cpython-39.pyc         |  Bin 0 -> 16955 bytes
 .../__pycache__/networks.cpython-39.pyc             |  Bin 0 -> 12118 bytes
 treevae/models/losses.py                            |   34 +
 treevae/models/model.py                             |  602 ++++++++++
 treevae/models/model_smalltree.py                   |  179 +++
 treevae/models/networks.py                          |  386 +++++++
 treevae/train/train.py                              |   87 ++
 treevae/train/train_tree.py                         |  247 ++++
 treevae/train/validate_tree.py                      |  192 +++
 treevae/tree_exploration.ipynb                      | 1025 +++++++++++++++++
 treevae/treevae.png                                 |  Bin 0 -> 1109860 bytes
 treevae/treevae.yml                                 |  187 +++
 .../__pycache__/model_utils.cpython-39.pyc          |  Bin 0 -> 6584 bytes
 .../__pycache__/training_utils.cpython-39.pyc       |  Bin 0 -> 15518 bytes
 treevae/utils/data_utils.py                         |  406 +++++++
 treevae/utils/model_utils.py                        |  238 ++++
 treevae/utils/plotting_utils.py                     |  301 +++++
 treevae/utils/training_utils.py                     |  512 ++++++++
 treevae/utils/utils.py                              |  214 ++++
 31 files changed, 5148 insertions(+)
 create mode 100644 treevae/README.md
 create mode 100644 treevae/configs/celeba.yml
 create mode 100644 treevae/configs/cifar10.yml
 create mode 100644 treevae/configs/cifar100.yml
 create mode 100644 treevae/configs/fmnist.yml
 create mode 100644 treevae/configs/mnist.yml
 create mode 100644 treevae/configs/news20.yml
 create mode 100644 treevae/configs/omniglot.yml
 create mode 100644 treevae/evaluate_treevae.ipynb
 create mode 100644 treevae/main.py
 create mode 100644 treevae/minimal_requirements.txt
 create mode 100644 treevae/models/__pycache__/losses.cpython-39.pyc
 create mode 100644 treevae/models/__pycache__/model.cpython-39.pyc
 create mode 100644 treevae/models/__pycache__/networks.cpython-39.pyc
 create mode 100644 treevae/models/losses.py
 create mode 100644 treevae/models/model.py
 create mode 100644 treevae/models/model_smalltree.py
 create mode 100644 treevae/models/networks.py
 create mode 100644 treevae/train/train.py
 create mode 100644 treevae/train/train_tree.py
 create mode 100644 treevae/train/validate_tree.py
 create mode 100644 treevae/tree_exploration.ipynb
 create mode 100644 treevae/treevae.png
 create mode 100644 treevae/treevae.yml
 create mode 100644 treevae/utils/__pycache__/model_utils.cpython-39.pyc
 create mode 100644 treevae/utils/__pycache__/training_utils.cpython-39.pyc
 create mode 100644 treevae/utils/data_utils.py
 create mode 100644 treevae/utils/model_utils.py
 create mode 100644 treevae/utils/plotting_utils.py
 create mode 100644 treevae/utils/training_utils.py
 create mode 100644 treevae/utils/utils.py

diff --git a/treevae/README.md b/treevae/README.md
new file mode 100644
index 0000000..badd37d
--- /dev/null
+++ b/treevae/README.md
@@ -0,0 +1,35 @@
+# Tree Variational Autoencoders
+This is the PyTorch repository for the NeurIPS 2023 Spotlight publication (https://neurips.cc/virtual/2023/poster/71188).
+
+TreeVAE is a new generative method that learns the optimal tree-based posterior distribution of latent variables to capture the hierarchical structures present in the data. It adapts its architecture to discover the optimal tree for encoding dependencies between latent variables. TreeVAE optimizes the balance between shared and specialized architecture, enhancing the learning and adaptation capabilities of generative models.
+An example of a tree learned by TreeVAE is depicted in the figure below. Each edge and each split are encoded by neural networks, while the circles depict latent variables. Each sample is associated with a probability distribution over the different paths of the discovered tree. The resulting tree thus organizes the data into an interpretable hierarchical structure in an unsupervised fashion, optimizing the amount of shared information between samples. In CIFAR-10, for example, the method divides vehicles and animals into two different subtrees, and similar groups (such as planes and ships) share common ancestors.
+
+![Alt text](https://github.com/lauramanduchi/treevae/blob/main/treevae.png?raw=true)
+
+To run TreeVAE:
+
+1. Create a new environment with the ```treevae.yml``` or ```minimal_requirements.txt``` file.
+2. Select the dataset you wish to use by changing the default config_name in the main.py parser.
+3. Optionally adapt the default configuration of the selected dataset (configs/data_name.yml); the full set of config parameters, with explanations, can be found in ```configs/mnist.yml```.
+4. For Weights & Biases support, set project & entity in ```train/train.py``` and change the value of ```wandb_logging``` to ```online``` in the config file.
+5. Run ```main.py```.
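+
+For example, to train on CIFAR-10 while overriding the number of epochs from the command line (both flags are defined in ```main.py```):
+
+```
+python main.py --config_name cifar10 --num_epochs 100
+```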
+
+To explore TreeVAE results (including the discovered tree, the generation of new images, the clustering performance, and much more) we created a Jupyter notebook (```tree_exploration.ipynb```):
+1. Run the steps above with ```save_model=True```.
+2. Copy the experiment path where the model is saved (it will be printed out).
+3. Open ```tree_exploration.ipynb```, replace the experiment path with yours, and have fun exploring the model!
+
+DISCLAIMER: This PyTorch repository was thoroughly debugged and tested; however, the experiments of the submission were performed using the repository with the TensorFlow code (https://github.com/lauramanduchi/treevae-tensorflow).
+
+## Citing
+To cite TreeVAE please use the following BibTeX entry:
+
+```
+@inproceedings{
+manduchi2023tree,
+title={Tree Variational Autoencoders},
+author={Laura Manduchi and Moritz Vandenhirtz and Alain Ryser and Julia E Vogt},
+booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
+year={2023},
+url={https://openreview.net/forum?id=adq0oXb9KM}
+}
+```
diff --git a/treevae/configs/celeba.yml b/treevae/configs/celeba.yml
new file mode 100644
index 0000000..3495f5e
--- /dev/null
+++ b/treevae/configs/celeba.yml
@@ -0,0 +1,39 @@
+run_name: 'celeba'
+
+data:
+  data_name: 'celeba'
+  num_clusters_data: 1
+
+training:
+  num_epochs: 150
+  num_epochs_smalltree: 150
+  num_epochs_intermediate_fulltrain: 0
+  num_epochs_finetuning: 0
+  batch_size: 256
+  lr: 0.001
+  weight_decay: 0.00001
+  decay_lr: 0.1
+  decay_stepsize: 100
+  decay_kl: 0.01
+  kl_start: 0.01
+
+  inp_shape: 12288
+  latent_dim: [64, 64, 64, 64, 64, 64]
+  mlp_layers: [4096, 512, 512, 512, 512, 512]
+  initial_depth: 1
+  activation: 'mse'
+  encoder: 'cnn2'
+  grow: True
+  prune: True
+  num_clusters_tree: 10
+  augment: True
+  augmentation_method: 'InfoNCE,instancewise_full'
+  aug_decisions_weight: 100
+  compute_ll: False
+
+globals:
+  wandb_logging: 'disabled'
+  eager_mode: False
+  seed: 42
+  save_model: True
+  config_name: 'celeba'
\ No newline at end of file
diff --git a/treevae/configs/cifar10.yml b/treevae/configs/cifar10.yml
new file mode 100644
index 0000000..3bf8b90
--- /dev/null
+++ b/treevae/configs/cifar10.yml
@@ -0,0 +1,39 @@
+run_name: 'cifar10'
+
+data:
+  data_name: 'cifar10'
+  num_clusters_data: 10
+
+training:
+  num_epochs: 150
+  num_epochs_smalltree: 150
+  num_epochs_intermediate_fulltrain: 0
+  num_epochs_finetuning: 0
+  batch_size: 256
+  lr: 0.001
+  weight_decay: 0.00001
+  decay_lr: 0.1
+  decay_stepsize: 100
+  decay_kl: 0.01
+  kl_start: 0.01
+
+  inp_shape: 3072
+  latent_dim: [64, 64, 64, 64, 64, 64]
+  mlp_layers: [4096, 512, 512, 512, 512, 512]
+  initial_depth: 1
+  activation: 'mse'
+  encoder: 'cnn2'
+  grow: True
+  prune: True
+  num_clusters_tree: 10
+  augment: True
+  augmentation_method: 'InfoNCE,instancewise_full'
+  aug_decisions_weight: 100
+  compute_ll: False
+
+globals:
+  wandb_logging: 'disabled'
+  eager_mode: False
+  seed: 42
+  save_model: False
+  config_name: 'cifar10'
diff --git a/treevae/configs/cifar100.yml b/treevae/configs/cifar100.yml
new file mode 100644
index 0000000..a44955c
--- /dev/null
+++ b/treevae/configs/cifar100.yml
@@ -0,0 +1,39 @@
+run_name: 'cifar100'
+
+data:
+  data_name: 'cifar100'
+  num_clusters_data: 20
+
+training:
+  num_epochs: 150
+  num_epochs_smalltree: 150
+  num_epochs_intermediate_fulltrain: 0
+  num_epochs_finetuning: 0
+  batch_size: 256
+  lr: 0.001
+  weight_decay: 0.00001
+  decay_lr: 0.1
+  decay_stepsize: 100
+  decay_kl: 0.01
+  kl_start: 0.01
+
+  inp_shape: 3072
+  latent_dim: [64, 64, 64, 64, 64, 64]
+  mlp_layers: [4096, 512, 512, 512, 512, 512]
+  initial_depth: 1
+  activation: 'mse'
+  encoder: 'cnn2'
+  grow: True
+  prune: True
+  num_clusters_tree: 20
+  augment: True
+  augmentation_method: 'InfoNCE,instancewise_full'
+  aug_decisions_weight: 100
+  compute_ll: False
+
+globals:
+  wandb_logging: 'disabled'
+  eager_mode: False
+  seed: 42
+  save_model: False
+  config_name: 'cifar100'
diff --git a/treevae/configs/fmnist.yml b/treevae/configs/fmnist.yml
new file mode 100644
index 0000000..56c2459
--- /dev/null
+++ b/treevae/configs/fmnist.yml
@@ -0,0 +1,39 @@
+run_name: 'fmnist'
+
+data:
+  data_name: 'fmnist'
+  num_clusters_data: 10
+
+training:
+  num_epochs: 150
+  num_epochs_smalltree: 150
+  num_epochs_intermediate_fulltrain: 80
+  num_epochs_finetuning: 200
+  batch_size: 256
+  lr: 0.001
+  weight_decay: 0.00001
+  decay_lr: 0.1
+  decay_stepsize: 100
+  decay_kl: 0.001
+  kl_start: 0.0
+
+  inp_shape: 784
+  latent_dim: [8, 8, 8, 8, 8, 8]
+  mlp_layers: [128, 128, 128, 128, 128, 128]
+  initial_depth: 1
+  activation: "sigmoid"
+  encoder: 'cnn1'
+  grow: True
+  prune: True
+  num_clusters_tree: 10
+  compute_ll: False
+  augment: False
+  augmentation_method: 'simple'
+  aug_decisions_weight: 1
+
+globals:
+  wandb_logging: 'disabled'
+  eager_mode: True
+  seed: 42
+  save_model: False
+  config_name: 'fmnist'
\ No newline at end of file
diff --git a/treevae/configs/mnist.yml b/treevae/configs/mnist.yml
new file mode 100644
index 0000000..327ef89
--- /dev/null
+++ b/treevae/configs/mnist.yml
@@ -0,0 +1,39 @@
+run_name: 'mnist' # name of the run
+
+data:
+  data_name: 'mnist' # name of the dataset
+  num_clusters_data: 10 # number of true clusters in the data (if known); used only for evaluation purposes
+
+training:
+  num_epochs: 150 # number of epochs to train the initial tree
+  num_epochs_smalltree: 150 # number of epochs to train the sub-tree during growing
+  num_epochs_intermediate_fulltrain: 80 # number of epochs to train the full tree during growing
+  num_epochs_finetuning: 200 # number of epochs to train the final tree
+  batch_size: 256 # batch size
+  lr: 0.001 # learning rate
+  weight_decay: 0.00001 # optimizer weight decay
+  decay_lr: 0.1 # learning rate decay
+  decay_stepsize: 100 # number of epochs after which the learning rate decays
+  decay_kl: 0.001 # KL-annealing weight increase per epoch (capped at 1)
+  kl_start: 0.0 # KL-annealing weight initialization
+
+  inp_shape: 784 # total dimensionality of the input data (e.g. 32x32x3 for 32x32 RGB images)
+  latent_dim: [8, 8, 8, 8, 8, 8] # list of latent dimensions for each depth of the tree, from the bottom to the root; the last value is the dimensionality of the root node
+  mlp_layers: [128, 128, 128, 128, 128, 128] # list of hidden-unit counts for the MLP transformations at each depth of the tree, from bottom to root
+  initial_depth: 1 # initial depth of the tree (the root has depth 0; a root with two leaves has depth 1)
+  activation: "sigmoid" # name of the activation function for the reconstruction loss [sigmoid, mse]
+  encoder: 'cnn1' # type of encoder/decoder used
+  grow: True # whether to grow the tree
+  prune: True # whether to prune empty leaves from the tree
+  num_clusters_tree: 10 # maximum number of leaves of the final tree
+  compute_ll: False # whether to compute the log-likelihood estimate at the end of training (it might take some time)
+  augment: False # whether to use contrastive learning through augmentation
+  augmentation_method: 'simple' # type of augmentation method used if augment is True
+  aug_decisions_weight: 1 # weight of the contrastive losses
+
+globals:
+  wandb_logging: 'disabled' # whether to log to wandb [online, offline, disabled]
+  eager_mode: True # whether to run in eager or graph mode
+  seed: 42 # random seed
+  save_model: False # whether to save the model; set to True to inspect models in the notebook
+  config_name: 'mnist'
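Since ```configs/mnist.yml``` above is the documented reference for every config field, here is a short, illustrative sketch of how such a file is consumed; the real logic lives in `prepare_config` in `utils/utils.py`, whose body is not included in this excerpt, so treat this as an assumption:

```python
# Illustrative only; approximates what prepare_config presumably does.
import yaml

with open('configs/mnist.yml') as f:
    cfg = yaml.safe_load(f)

cfg['training']['num_epochs'] = 10   # e.g. a CLI override
print(cfg['run_name'], cfg['training']['latent_dim'])
```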
diff --git a/treevae/configs/news20.yml b/treevae/configs/news20.yml
new file mode 100644
index 0000000..987a8fc
--- /dev/null
+++ b/treevae/configs/news20.yml
@@ -0,0 +1,39 @@
+run_name: 'news20'
+
+data:
+  data_name: 'news20'
+  num_clusters_data: 20
+
+training:
+  num_epochs: 150
+  num_epochs_smalltree: 150
+  num_epochs_intermediate_fulltrain: 80
+  num_epochs_finetuning: 200
+  batch_size: 256
+  lr: 0.001
+  weight_decay: 0.00001
+  decay_lr: 0.1
+  decay_stepsize: 100
+  decay_kl: 0.001
+  kl_start: 0.0
+
+  inp_shape: 2000
+  latent_dim: [4, 4, 4, 4, 4, 4, 4]
+  mlp_layers: [128, 128, 128, 128, 128, 128, 128]
+  initial_depth: 1
+  activation: "sigmoid"
+  encoder: 'mlp'
+  grow: True
+  prune: True
+  num_clusters_tree: 20
+  compute_ll: False
+  augment: False
+  augmentation_method: 'simple'
+  aug_decisions_weight: 1
+
+globals:
+  wandb_logging: 'disabled'
+  eager_mode: True
+  seed: 42
+  save_model: False
+  config_name: 'news20'
\ No newline at end of file
diff --git a/treevae/configs/omniglot.yml b/treevae/configs/omniglot.yml
new file mode 100644
index 0000000..65f55ba
--- /dev/null
+++ b/treevae/configs/omniglot.yml
@@ -0,0 +1,41 @@
+run_name: 'omniglot'
+
+data:
+  data_name: 'omniglot'
+  num_clusters_data: 5
+  path: 'datasets/omniglot'
+
+training:
+  num_epochs: 150
+  num_epochs_smalltree: 150
+  num_epochs_intermediate_fulltrain: 80
+  num_epochs_finetuning: 200
+  batch_size: 256
+  lr: 0.001
+  weight_decay: 0.00001
+  decay_lr: 0.1
+  decay_stepsize: 100
+  decay_kl: 0.001
+  kl_start: 0.001
+
+  inp_shape: 784
+  latent_dim: [8, 8, 8, 8, 8, 8]
+  mlp_layers: [128, 128, 128, 128, 128, 128]
+  initial_depth: 1
+  activation: "sigmoid"
+  encoder: 'cnn_omni'
+  grow: True
+  prune: True
+  num_clusters_tree: 5
+  compute_ll: False
+  augment: True
+  augmentation_method: 'simple'
+  aug_decisions_weight: 1
+
+
+globals:
+  wandb_logging: 'online'
+  eager_mode: True
+  seed: 42
+  save_model: False
+  config_name: 'omniglot'
\ No newline at end of file
diff --git a/treevae/evaluate_treevae.ipynb b/treevae/evaluate_treevae.ipynb
new file mode 100644
index 0000000..9f0d4fe
--- /dev/null
+++ b/treevae/evaluate_treevae.ipynb
@@ -0,0 +1,152 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import torch\n",
+    "import numpy as np\n",
+    "import yaml\n",
+    "from sklearn.metrics import accuracy_score, normalized_mutual_info_score\n",
+    "from sklearn.linear_model import LogisticRegression\n",
+    "\n",
+    "from models.model import TreeVAE\n",
+    "from utils.model_utils import construct_tree_fromnpy\n",
+    "from utils.data_utils import get_data, get_gen\n",
+    "from utils.training_utils import predict\n",
+    "from utils.utils import cluster_acc\n",
+    "\n",
+    "checkpoint_path = 'models/experiments/mnist/20231025-175819_d6be9' # ← replace with your checkpoint path\n",
+    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+    "\n",
+    "with open(os.path.join(checkpoint_path, \"config.yaml\"), 'r') as f:\n",
+    "    configs = yaml.load(f, Loader=yaml.Loader)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = TreeVAE(**configs['training'])\n",
+    "model.load_state_dict(torch.load(os.path.join(checkpoint_path, \"model_weights.pt\"), map_location=device), strict=True)\n",
+    "\n",
+    "tree_structure = np.load(os.path.join(checkpoint_path, \"data_tree.npy\"), allow_pickle=True)\n",
+    "model = construct_tree_fromnpy(model, tree_structure, configs)\n",
+    "\n",
+    "model.to(device)\n",
+    "model.eval()\n",
+    "print(\"model loaded\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainset, trainset_eval, testset = get_data(configs)\n",
+    "gen_train_eval = get_gen(trainset_eval, configs, validation=True, shuffle=False)\n",
+    "gen_test = get_gen(testset, configs, validation=True, shuffle=False)\n",
+    "\n",
+    "y_train = trainset_eval.dataset.targets[trainset_eval.indices].numpy()\n",
+    "y_test = testset.dataset.targets[testset.indices].numpy()\n",
+    "\n",
+    "print(f\"data loaded | Train: {len(y_train)}, Test: {len(y_test)}\")\n",
+    "\n",
+    "train_acc = cluster_acc(y_train, trainset_eval.dataset.targets[trainset_eval.indices].numpy())\n",
+    "test_acc = cluster_acc(y_test, testset.dataset.targets[testset.indices].numpy())\n",
+    "\n",
+    "print(f\"cluster accuracy | Train: {train_acc:.3f}, Test: {test_acc:.3f}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "_, y_pred = predict(gen_test, model, device)\n",
+    "\n",
+    "acc = cluster_acc(y_test, y_pred.numpy(), return_index=False)\n",
+    "nmi = normalized_mutual_info_score(y_test, y_pred.numpy())\n",
+    "\n",
+    "print(f\"Clustering ACC: {acc:.4f}\")\n",
+    "print(f\"NMI: {nmi:.4f}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "z_train = predict(gen_train_eval, model, device, 'bottom_up')[-1].cpu().numpy()\n",
+    "z_test = predict(gen_test, model, device, 'bottom_up')[-1].cpu().numpy()\n",
+    "\n",
+    "clf = LogisticRegression(max_iter=1000)\n",
+    "clf.fit(z_train, y_train)\n",
+    "y_lp_pred = clf.predict(z_test)\n",
+    "\n",
+    "lp_acc = accuracy_score(y_test, y_lp_pred)\n",
+    "print(f\"Linear Probe Accuracy: {lp_acc:.4f}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prob_train = predict(gen_train_eval, model, device, 'prob_leaves')\n",
+    "prob_test = predict(gen_test, model, device, 'prob_leaves')\n",
+    "\n",
+    "leaf_test = prob_test.argmax(axis=1)\n",
+    "\n",
+    "def compute_dp_score(labels, leaves):\n",
+    "    count = 0\n",
+    "    total = 0\n",
+    "    for i in range(len(labels)):\n",
+    "        for j in range(i+1, len(labels)):\n",
+    "            if labels[i] == labels[j]:\n",
+    "                total += 1\n",
+    "                if leaves[i] == leaves[j]:\n",
+    "                    count += 1\n",
+    "    return count / total if total > 0 else 0.0\n",
+    "\n",
+    "dp_score = compute_dp_score(y_test, leaf_test)\n",
+    "print(f\"Decision Path Agreement: {dp_score:.4f}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(\"======== TreeVAE Final Evaluation ========\")\n",
+    "print(f\"Clustering ACC      : {acc:.4f}\")\n",
+    "print(f\"NMI                 : {nmi:.4f}\")\n",
+    "print(f\"Linear Probe ACC    : {lp_acc:.4f}\")\n",
+    "print(f\"Decision Path Agree : {dp_score:.4f}\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "new_env",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.9.20"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
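To make the decision-path metric in the notebook concrete: `compute_dp_score` returns the fraction of same-label pairs that also land in the same leaf. A toy check with hand-picked, illustrative values:

```python
# Toy example for compute_dp_score (defined in the notebook above).
labels = [0, 0, 0, 1]   # three same-label pairs: (0,1), (0,2), (1,2)
leaves = [5, 5, 7, 5]   # only pair (0,1) shares a leaf
# => decision-path score = 1/3
assert abs(compute_dp_score(labels, leaves) - 1 / 3) < 1e-9
```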
diff --git a/treevae/main.py b/treevae/main.py
new file mode 100644
index 0000000..f66eb1b
--- /dev/null
+++ b/treevae/main.py
@@ -0,0 +1,62 @@
+"""
+Runs the TreeVAE model.
+"""
+import argparse
+from pathlib import Path
+import distutils.util  # `import distutils` alone does not expose distutils.util
+
+from train.train import run_experiment
+from utils.utils import prepare_config
+
+
+def main():
+    project_dir = Path(__file__).absolute().parent
+    print("Project directory:", project_dir)
+
+    parser = argparse.ArgumentParser()
+
+    # Model parameters
+    parser.add_argument('--data_name', type=str, help='the dataset')
+    parser.add_argument('--num_epochs', type=int, help='the number of training epochs')
+    parser.add_argument('--num_epochs_finetuning', type=int, help='the number of finetuning epochs')
+    parser.add_argument('--num_epochs_intermediate_fulltrain', type=int, help='the number of full-tree finetuning epochs during training')
+    parser.add_argument('--num_epochs_smalltree', type=int, help='the number of sub-tree training epochs')
+
+    parser.add_argument('--num_clusters_data', type=int, help='the number of clusters in the data')
+    parser.add_argument('--num_clusters_tree', type=int, help='the max number of leaves of the tree')
+
+    parser.add_argument('--kl_start', type=float, nargs='?', const=0.,
+                        help='initial KL divergence weight from which annealing starts')
+    parser.add_argument('--decay_kl', type=float, help='KL divergence annealing rate')
+    parser.add_argument('--latent_dim', type=str, help='specifies the latent dimensions of the tree')
+    parser.add_argument('--mlp_layers', type=str, help='specifies how many layers the MLPs should have')
+
+    parser.add_argument('--grow', type=lambda x: bool(distutils.util.strtobool(x)), help='whether to grow the tree')
+    parser.add_argument('--augment', type=lambda x: bool(distutils.util.strtobool(x)), help='whether to augment images')
+    parser.add_argument('--augmentation_method', type=str, help='none vs. simple augmentation vs. contrastive approaches')
+    parser.add_argument('--aug_decisions_weight', type=float,
+                        help='weight of the similarity regularizer for augmented images')
+    parser.add_argument('--compute_ll', type=lambda x: bool(distutils.util.strtobool(x)),
+                        help='whether to compute the log-likelihood')
+
+    # Other parameters
+    parser.add_argument('--save_model', type=lambda x: bool(distutils.util.strtobool(x)),
+                        help='specifies if the model should be saved')
+    parser.add_argument('--eager_mode', type=lambda x: bool(distutils.util.strtobool(x)),
+                        help='specifies if the model should be run in graph or eager mode')
+    parser.add_argument('--num_workers', type=int, help='number of workers in the dataloader')
+    parser.add_argument('--seed', type=int, help='random number generator seed')
+    parser.add_argument('--wandb_logging', type=str, help='wandb logging mode: online, offline, or disabled')
+
+    # Specify config name
+    parser.add_argument('--config_name', default='mnist', type=str,
+                        choices=['mnist', 'fmnist', 'news20', 'omniglot', 'cifar10', 'cifar100', 'celeba'],
+                        help='the override file name for config.yml')
+
+    args = parser.parse_args()
+    configs = prepare_config(args, project_dir)
+    run_experiment(configs)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/treevae/minimal_requirements.txt b/treevae/minimal_requirements.txt
new file mode 100644
index 0000000..3427b94
--- /dev/null
+++ b/treevae/minimal_requirements.txt
@@ -0,0 +1,14 @@
+attrs==22.1.0
+cudatoolkit==11.7 # conda install cudatoolkit=11.7 -c pytorch -c nvidia
+cudnn==8.9.2.26
+numpy==1.23.4
+Pillow==9.3.0
+python==3.9.15 # Do this first
+PyYAML==6.0
+scikit_learn==1.1.3
+scipy==1.10.0
+torch==2.0.1 # conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
+torchmetrics==1.0.3
+torchvision==0.15.2
+tqdm==4.65.0
+wandb==0.13.5
\ No newline at end of file
diff --git a/treevae/models/__pycache__/losses.cpython-39.pyc b/treevae/models/__pycache__/losses.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8d562e0a671179efea54da37522d47eb677439cd
GIT binary patch
literal 1978
[base85-encoded .pyc payload omitted]

literal 0
HcmV?d00001

diff --git a/treevae/models/__pycache__/model.cpython-39.pyc b/treevae/models/__pycache__/model.cpython-39.pyc
new file mode 100644
GIT binary patch
literal 16955
[base85-encoded .pyc payload omitted]

literal 0
HcmV?d00001

diff --git a/treevae/models/__pycache__/networks.cpython-39.pyc b/treevae/models/__pycache__/networks.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..70157cb089fa640dc83b13c9cd284c2e26b7779a
GIT binary patch
literal 12118
[base85-encoded .pyc payload omitted; the patch text is truncated here, before the remaining files listed in the diffstat]