diff --git a/README.md b/README.md index 2e46f11..f55162c 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,140 @@ -# cem -Concept Embedding Models Pytorch Implementation +# Concept Embedding Models + +This repository contains the official PyTorch implementation of our work +*"Concept Embedding Models"* accepted at **NeurIPS 2022**. For details on our +model and motivation, please refer to our official [paper](TODO). + +# Model + +![CEM Architecture](figures/cem.png) + +[Concept Bottleneck Models (CBMs)](https://arxiv.org/abs/2007.04612) have recently gained attention as +high-performing and interpretable neural architectures that can explain their +predictions using a set of human-understandable high-level concepts. +Nevertheless, the need for a strict activation bottleneck as part of the +architecture, as well as the fact that one requires the set of concept +annotations used during training to be fully descriptive of the downstream +task of interest, are constraints that force CBMs to trade downstream +performance for interpretability purposes. This severely limits their +applicability in real-world applications, where data rarely comes with +concept annotations that are fully descriptive of any task of interest. + + +In our work, we propose Concept Embedding Models (CEMs) to tackle these two big +challenges. Our neural architecture expands a CBM's bottleneck and allows the +information related to unseen concepts to flow as part of the model's +bottleneck. We achieve this by learning a high-dimensional representation +(i.e., a *concept embedding*) for each concept provided during training. Naively +extending the bottleneck, however, may directly impede the use of test-time +*concept interventions* where one can correct a mispredicted concept in order +to improve the end model's downstream performance. This is a crucial element +motivating the creation of traditional CBMs and therefore is a highly desirable +feature. 
Therefore, in order to use concept embeddings in the bottleneck while +still permitting effective test-time interventions, CEMs +construct each concept's representation as a linear combination of two +concept embeddings, where each embedding has fixed semantics. Specifically, +we learn an embedding to represent the "active" state of a concept and one +to represent the "inactive" state of a concept, allowing us to select +between these two produced embeddings at test time to then intervene on a +concept and improve downstream performance. Our entire architecture is +visualized in the figure above and formally described in our paper. + +# Usage + +In this repository, we include a standalone PyTorch implementation of CEM +which can be easily trained from scratch given a set of samples annotated with +a downstream task and a set of binary concepts. In order to use our implementation, +however, you first need to install all our code's requirements (listed in +`requirements.txt`). We provide an automatic mechanism for this installation using +Python's setup process with our standalone `setup.py`. To install our package, +therefore, you only need to run: +```bash +$ python setup.py install +``` + +After this command has terminated successfully, you should be able to import +`cem` as a package and use it to train a CEM object as follows: +```python +import pytorch_lightning as pl +from cem.models.cem import ConceptEmbeddingModel + +##### +# Define your dataset +##### + +train_dl = ... +val_dl = ... 
+ +##### +# Construct the model +##### + +cem_model = ConceptEmbeddingModel( + n_concepts=n_concepts, # Number of training-time concepts + n_tasks=n_tasks, # Number of output labels + emb_size=16, + concept_loss_weight=0.1, + learning_rate=1e-3, + optimizer="adam", + c_extractor_arch=latent_code_generator_model, # Replace this appropriately + training_intervention_prob=0.25, # RandInt probability +) + +##### +# Train it +##### + +trainer = pl.Trainer( + gpus=1, + max_epochs=100, + check_val_every_n_epoch=5, +) +# train_dl and val_dl are datasets previously built... +trainer.fit(cem_model, train_dl, val_dl) +``` + +# Experiment Reproducibility + +To reproduce the experiments discussed in our paper, please use the scripts +in the `experiments` directory after installing the `cem` package as indicated +above. For example, to run our experiments on the DOT dataset (see our paper), +you can execute the following command: + +```bash +$ python experiments/synthetic_datasets_experiments.py dot -o dot_results/ +``` +This should generate a summary of all the results after execution has +terminated and dump all results/trained models/logs into the given +output directory (`dot_results/` in this case). 
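The embedding mixing described in the Model section can be illustrated with a minimal, framework-free sketch. It is not the code shipped in the `cem` package: the function names `mix_concept_embedding` and `intervene` are hypothetical, and the weights `p` and `1 - p` are an assumption, since the text above only states that each concept's representation is a linear combination of its "active" and "inactive" embeddings.

```python
from typing import List, Optional, Sequence


def mix_concept_embedding(
    p_active: float,
    emb_active: Sequence[float],
    emb_inactive: Sequence[float],
) -> List[float]:
    """Blend the "active" and "inactive" embeddings of a single concept.

    ASSUMPTION: the mixing weights are p_active and (1 - p_active).
    """
    if len(emb_active) != len(emb_inactive):
        raise ValueError("Both embeddings must have the same size.")
    return [
        p_active * a + (1.0 - p_active) * b
        for a, b in zip(emb_active, emb_inactive)
    ]


def intervene(p_predicted: float, true_value: Optional[int] = None) -> float:
    """Return the ground-truth concept value if known, else the prediction."""
    if true_value is None:
        return p_predicted
    return float(true_value)


# Hypothetical two-dimensional embeddings for one concept.
emb_on = [1.0, 2.0]
emb_off = [-1.0, 0.0]

# No intervention: the concept probability is 0.5, so both embeddings
# contribute equally.
mixed = mix_concept_embedding(0.5, emb_on, emb_off)

# Intervention: an expert marks the concept as active, so the "active"
# embedding is selected outright.
corrected = mix_concept_embedding(
    intervene(0.5, true_value=1), emb_on, emb_off
)

print(mixed)      # [0.0, 1.0]
print(corrected)  # [1.0, 2.0]
```

Because each of the two embeddings keeps a fixed meaning, correcting a concept at test time only requires changing the mixing weight rather than the embeddings themselves.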
+ + +# Citation +If you would like to cite this repository, or the accompanying paper, please +use the following citation: + +``` +@article{DBLP:journals/corr/abs-2111-12628, + author = {Mateo Espinosa Zarlenga and + Pietro Barbiero and + Gabriele Ciravegna and + Giuseppe Marra and + Francesco Giannini and + Michelangelo Diligenti and + Zohreh Shams and + Frederic Precioso and + Stefano Melacci and + Adrian Weller and + Pietro Lio and + Mateja Jamnik}, + title = {Concept Embedding Models}, + journal = {CoRR}, + volume = {abs/TODO}, + year = {2021}, + url = {https://arxiv.org/abs/TODO}, + eprinttype = {arXiv}, + eprint = {TODO}, + timestamp = {TODO}, + biburl = {https://dblp.org/rec/journals/corr/abs-TODO.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} +``` diff --git a/experiments/__init__.py b/experiments/__init__.py new file mode 100644 index 0000000..6c808df --- /dev/null +++ b/experiments/__init__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- +# @Author: Mateo Espinosa Zarlenga +# @Date: 2022-09-19 18:28:17 +# @Last Modified by: Mateo Espinosa Zarlenga +# @Last Modified time: 2022-09-19 18:28:17 diff --git a/experiments/celeba_emb_size_ablation.py b/experiments/celeba_emb_size_ablation.py new file mode 100644 index 0000000..dec1905 --- /dev/null +++ b/experiments/celeba_emb_size_ablation.py @@ -0,0 +1,606 @@ +import argparse +import copy +import joblib +import numpy as np +import os +import torch +import torchvision + +from pathlib import Path +from pytorch_lightning import seed_everything +from torchvision import transforms + +import cem.experiments.celeba_experiments as celeba_experiments +import cem.train.training as training +import cem.train.utils as utils + + +def main( + rerun=False, + result_dir='results/celeba_emb_size_ablation/', + project_name='', + activation_freq=0, + num_workers=8, + single_frequency_epochs=0, + global_params=None, + data_root=celeba_experiments.CELEBA_ROOT, +): + seed_everything(42) + # parameters for 
data, model, and training + og_config = dict( + cv=5, + max_epochs=200, + patience=15, + batch_size=512, + num_workers=num_workers, + emb_size=16, + extra_dims=0, + concept_loss_weight=1, + normalize_loss=False, + learning_rate=0.005, + weight_decay=4e-05, + weight_loss=False, + pretrain_model=True, + c_extractor_arch="resnet34", + optimizer="sgd", + bool=False, + early_stopping_monitor="val_loss", + early_stopping_mode="min", + early_stopping_delta=0.0, + image_size=64, + num_classes=1000, + top_k_accuracy=[3, 5, 10], + save_model=True, + use_imbalance=True, + use_binary_vector_class=True, + num_concepts=6, + label_binary_width=1, + label_dataset_subsample=12, + num_hidden_concepts=2, + selected_concepts=False, + + momentum=0.9, + shared_prob_gen=False, + sigmoidal_prob=False, + sigmoidal_embedding=False, + training_intervention_prob=0.0, + embeding_activation=None, + concat_prob=False, + ) + + utils.extend_with_global_params(og_config, global_params or []) + use_binary_vector_class = og_config.get('use_binary_vector_class', False) + if use_binary_vector_class: + # Now reload by transform the labels accordingly + width = og_config.get('label_binary_width', 5) + def _binarize(concepts, selected, width): + result = [] + binary_repr = [] + concepts = concepts[selected] + for i in range(0, concepts.shape[-1], width): + binary_repr.append( + str(int(np.sum(concepts[i : i + width]) > 0)) + ) + return int("".join(binary_repr), 2) + + celeba_train_data = torchvision.datasets.CelebA( + root=data_root, + split='all', + download=True, + target_transform=lambda x: x[0].long() - 1, + target_type=['attr'], + ) + + concept_freq = np.sum( + celeba_train_data.attr.cpu().detach().numpy(), + axis=0 + ) / celeba_train_data.attr.shape[0] + print("Concept frequency is:", concept_freq) + sorted_concepts = list(map( + lambda x: x[0], + sorted(enumerate(np.abs(concept_freq - 0.5)), key=lambda x: x[1]), + )) + num_concepts = og_config.get( + 'num_concepts', + 
celeba_train_data.attr.shape[-1], + ) + concept_idxs = sorted_concepts[:num_concepts] + concept_idxs = sorted(concept_idxs) + if og_config.get('num_hidden_concepts', 0): + num_hidden = og_config.get('num_hidden_concepts', 0) + hidden_concepts = sorted( + sorted_concepts[ + num_concepts:min( + (num_concepts + num_hidden), + len(sorted_concepts) + ) + ] + ) + else: + hidden_concepts = [] + print("Selecting concepts:", concept_idxs) + print("\tAnd hidden concepts:", hidden_concepts) + celeba_train_data = torchvision.datasets.CelebA( + root=data_root, + split='all', + download=True, + transform=transforms.Compose([ + transforms.Resize(og_config['image_size']), + transforms.CenterCrop(og_config['image_size']), + transforms.ToTensor(), + transforms.ConvertImageDtype(torch.float32), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ]), + target_transform=lambda x: [ + torch.tensor( + _binarize( + x[1].cpu().detach().numpy(), + selected=(concept_idxs + hidden_concepts), + width=width, + ), + dtype=torch.long, + ), + x[1][concept_idxs].float(), + ], + target_type=['identity', 'attr'], + ) + label_remap = {} + vals, counts = np.unique( + list(map( + lambda x: _binarize( + x.cpu().detach().numpy(), + selected=(concept_idxs + hidden_concepts), + width=width, + ), + celeba_train_data.attr + )), + return_counts=True, + ) + for i, label in enumerate(vals): + label_remap[label] = i + + celeba_train_data = torchvision.datasets.CelebA( + root=data_root, + split='all', + download=True, + transform=transforms.Compose([ + transforms.Resize(og_config['image_size']), + transforms.CenterCrop(og_config['image_size']), + transforms.ToTensor(), + transforms.ConvertImageDtype(torch.float32), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ]), + target_transform=lambda x: [ + torch.tensor( + label_remap[_binarize( + x[1].cpu().detach().numpy(), + selected=(concept_idxs + hidden_concepts), + width=width, + )], + dtype=torch.long, + ), + x[1][concept_idxs].float(), + ], + 
target_type=['identity', 'attr'], + ) + num_classes = len(label_remap) + + # And subsample to reduce its massive size + factor = og_config.get('label_dataset_subsample', 1) + if factor != 1: + train_idxs = np.random.choice( + np.arange(0, len(celeba_train_data)), + replace=False, + size=len(celeba_train_data)//factor, + ) + print("Subsampling to", len(train_idxs), "elements.") + celeba_train_data = torch.utils.data.Subset( + celeba_train_data, + train_idxs, + ) + else: + concept_selection = list(range(0, len(CONCEPT_SEMANTICS))) + if og_config.get('selected_concepts', False): + concept_selection = SELECTED_CONCEPTS + celeba_train_data = torchvision.datasets.CelebA( + root=data_root, + split='all', + download=True, + target_transform=lambda x: x[0].long() - 1, + target_type=['identity'], + ) + vals, counts = np.unique( + celeba_train_data.identity, + return_counts=True, + ) + sorted_labels = list(map( + lambda x: x[0], + sorted(zip(vals, counts), key=lambda x: -x[1]) + )) + print( + "Selecting", + og_config['num_classes'], + "out of", + len(vals), + "classes", + ) + if result_dir: + Path(result_dir).mkdir(parents=True, exist_ok=True) + np.save( + os.path.join( + result_dir, + f"selected_top_{og_config['num_classes']}_labels.npy", + ), + sorted_labels[:og_config['num_classes']], + ) + label_remap = {} + for i, label in enumerate(sorted_labels[:og_config['num_classes']]): + label_remap[label] = i + print("len(label_remap) =", len(label_remap)) + + # Now reload by transform the labels accordingly + celeba_train_data = torchvision.datasets.CelebA( + root=data_root, + split='all', + download=True, + transform=transforms.Compose([ + transforms.Resize(og_config['image_size']), + transforms.CenterCrop(og_config['image_size']), + transforms.ToTensor(), + transforms.ConvertImageDtype(torch.float32), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ]), + target_transform=lambda x: [ + torch.tensor( + # If it is not in our map, then we make it be the token label + # 
og_config['num_classes'] which will be removed afterwards + label_remap.get( + x[0].cpu().detach().item() - 1, + og_config['num_classes'] + ), + dtype=torch.long, + ), + x[1][concept_selection].float(), + ], + target_type=['identity', 'attr'], + ) + num_classes = og_config['num_classes'] + + train_idxs = np.where( + list(map( + lambda x: x.cpu().detach().item() - 1 in label_remap, + celeba_train_data.identity + )) + )[0] + celeba_train_data = torch.utils.data.Subset( + celeba_train_data, + train_idxs, + ) + total_samples = len(celeba_train_data) + train_samples = int(0.7 * total_samples) + test_samples = int(0.2 * total_samples) + val_samples = total_samples - test_samples - train_samples + print( + f"Data split is: {total_samples} = {train_samples} (train) + " + f"{test_samples} (test) + {val_samples} (validation)" + ) + celeba_train_data, celeba_test_data, celeba_val_data = \ + torch.utils.data.random_split( + celeba_train_data, + [train_samples, test_samples, val_samples], + ) + train_dl = torch.utils.data.DataLoader( + celeba_train_data, + batch_size=og_config['batch_size'], + shuffle=True, + num_workers=og_config['num_workers'], + ) + test_dl = torch.utils.data.DataLoader( + celeba_test_data, + batch_size=og_config['batch_size'], + shuffle=False, + num_workers=og_config['num_workers'], + ) + val_dl = torch.utils.data.DataLoader( + celeba_val_data, + batch_size=og_config['batch_size'], + shuffle=False, + num_workers=og_config['num_workers'], + ) + + if result_dir and activation_freq: + # Then let's save the testing data for further analysis later on + out_acts_save_dir = os.path.join(result_dir, "test_embedding_acts") + Path(out_acts_save_dir).mkdir(parents=True, exist_ok=True) + + for (ds, name) in [ + (test_dl, "test"), + (val_dl, "val"), + ]: + x_total = [] + y_total = [] + c_total = [] + for x, (y, c) in ds: + x_total.append(x.cpu().detach()) + y_total.append(y.cpu().detach()) + c_total.append(c.cpu().detach()) + x_inputs = np.concatenate(x_total, axis=0) + 
print(f"x_{name}.shape =", x_inputs.shape) + np.save(os.path.join(out_acts_save_dir, f"x_{name}.npy"), x_inputs) + + y_inputs = np.concatenate(y_total, axis=0) + print(f"y_{name}.shape =", y_inputs.shape) + np.save(os.path.join(out_acts_save_dir, f"y_{name}.npy"), y_inputs) + + c_inputs = np.concatenate(c_total, axis=0) + print(f"c_{name}.shape =", c_inputs.shape) + np.save(os.path.join(out_acts_save_dir, f"c_{name}.npy"), c_inputs) + + label_set = set() + sample = next(iter(train_dl)) + real_sample = [] + for derp in sample: + if isinstance(derp, list): + real_sample += derp + else: + real_sample.append(derp) + sample = real_sample + print("Sample has", len(sample), "elements.") + for i, derp in enumerate(sample): + print("Element", i, "has shape", derp.shape, "and type", derp.dtype) + + print("Training sample shape is:", sample[0].shape) + print("Training label shape is:", sample[1].shape) + print("Training concept shape is:", sample[2].shape) + + n_concepts, n_tasks = sample[2].shape[-1], num_classes + + attribute_count = np.zeros((n_concepts,)) + samples_seen = 0 + for i, (_, (y, c)) in enumerate(train_dl): + print("\rIn batch", i, "we have seen", len(label_set), "classes") + c = c.cpu().detach().numpy() + attribute_count += np.sum(c, axis=0) + samples_seen += c.shape[0] + for l in y.reshape(-1).cpu().detach(): + label_set.add(l.item()) + + print("Found a total of", len(label_set), "classes") + if og_config.get("use_imbalance", False): + imbalance = samples_seen / attribute_count - 1 + else: + imbalance = None + print("Imbalance:", imbalance) + + os.makedirs(result_dir, exist_ok=True) + results = {} + for split in range(og_config["cv"]): + for emb_size in [1, 2, 4, 6, 8, 16, 32, 64]: + if emb_size not in results: + results[emb_size] = {} + if f'{split}' not in results[emb_size]: + results[emb_size][f'{split}'] = {} + print( + f'Experiment {split+1}/{og_config["cv"]} with emb_size', + emb_size, + ) + + # Trial period for mixture embedding model + config = 
copy.deepcopy(og_config) + config["architecture"] = "MixtureEmbModel" + config["extra_name"] = ( + f"SharedProb_AdaptiveDropout_NoProbConcat_" + f"emb_size_{emb_size}" + ) + config["shared_prob_gen"] = True + config["sigmoidal_prob"] = False + config["sigmoidal_embedding"] = False + config['training_intervention_prob'] = 0.25 + config['concat_prob'] = False + config['emb_size'] = emb_size + config["embeding_activation"] = "leakyrelu" + mixed_emb_shared_prob_model, mixed_emb_shared_prob_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[emb_size][f'{split}'], + config, + mixed_emb_shared_prob_model, + mixed_emb_shared_prob_test_results, + ) + + # Train fuzzy CBM with extra capacity + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["bool"] = False + config["extra_dims"] = (emb_size - 1) * n_concepts + config["extra_name"] = ( + f"FuzzyExtraCapacity_Logit_emb_size_{emb_size}" + ) + config["bottleneck_nonlinear"] = "leakyrelu" + config["sigmoidal_extra_capacity"] = False + config["sigmoidal_prob"] = False + config['emb_size'] = emb_size + extra_fuzzy_logit_model, extra_fuzzy_logit_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[emb_size][f'{split}'], + config, + extra_fuzzy_logit_model, + 
extra_fuzzy_logit_test_results, + ) + + # train vanilla model with more capacity (i.e., no concept + # supervision) but with ReLU activation + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["extra_name"] = ( + f"NoConceptSupervisionReLU_ExtraCapacity_emb_size_{emb_size}" + ) + config["bool"] = False + config["extra_dims"] = (emb_size - 1) * n_concepts + config["bottleneck_nonlinear"] = "leakyrelu" + config["concept_loss_weight"] = 0 + config['emb_size'] = emb_size + extra_vanilla_relu_model, extra_vanilla_relu_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[emb_size][f'{split}'], + config, + extra_vanilla_relu_model, + extra_vanilla_relu_test_results, + ) + + # save results + joblib.dump(results, os.path.join(result_dir, f'results.joblib')) + return results + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description=( + 'Runs embedding ablation study in CelebA dataset.' + ), + ) + parser.add_argument( + '--project_name', + default='', + help=( + "Project name used for Weights & Biases monitoring. If not " + "provided, we assume no W&B logging is done." + ), + metavar="name", + + ) + + parser.add_argument( + '--output_dir', + '-o', + default='results/celeba_emb_size_ablation/', + help=( + "directory where we will dump our experiment's results. If not " + "given, then we will use results/celeba_emb_size_ablation/." 
+ ), + metavar="path", + + ) + parser.add_argument( + '--rerun', + '-r', + default=False, + action="store_true", + help=( + "If set, then we will force a rerun of the entire experiment even if " + "valid results are found in the provided output directory. Note that " + "this may overwrite any previous results, so use with care." + ), + + ) + parser.add_argument( + '--activation_freq', + default=0, + help=( + 'how frequently, in terms of epochs, should we store the ' + 'embedding activations for our validation set. By default we will ' + 'not store any activations.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + '--single_frequency_epochs', + default=0, + help=( + 'how frequently, in terms of epochs, should we store the ' + 'embedding activations for our validation set. By default we ' + 'will not store any activations.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + '--num_workers', + default=8, + help=( + 'number of workers used for data feeders. Do not use more workers ' + 'than cores in the machine.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + "-d", + "--debug", + action="store_true", + default=False, + help="starts debug mode in our program.", + ) + parser.add_argument( + '-p', + '--param', + action='append', + nargs=2, + metavar=('param_name=value'), + help=( + 'Allows the passing of a config param that will overwrite ' + 'anything passed as part of the config file itself.' 
+ ), + default=[], + ) + args = parser.parse_args() + if args.project_name: + # Lazy import to avoid importing unless necessary + import wandb + main( + rerun=args.rerun, + result_dir=args.output_dir, + project_name=args.project_name, + activation_freq=args.activation_freq, + num_workers=args.num_workers, + single_frequency_epochs=args.single_frequency_epochs, + global_params=args.param, + ) diff --git a/experiments/celeba_experiments.py b/experiments/celeba_experiments.py new file mode 100644 index 0000000..a443383 --- /dev/null +++ b/experiments/celeba_experiments.py @@ -0,0 +1,826 @@ +import argparse +import copy +import joblib +import numpy as np +import os +import torch +import torchvision + +from pathlib import Path +from pytorch_lightning import seed_everything +from torchvision import transforms + +import cem.train.training as training +import cem.train.utils as utils + +############################################################################### +## GLOBAL VARIABLES +############################################################################### + +SELECTED_CONCEPTS = [ + 2, + 4, + 6, + 7, + 8, + 9, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 32, + 33, + 39, +] + +CONCEPT_SEMANTICS = [ + '5_o_Clock_Shadow', + 'Arched_Eyebrows', + 'Attractive', + 'Bags_Under_Eyes', + 'Bald', + 'Bangs', + 'Big_Lips', + 'Big_Nose', + 'Black_Hair', + 'Blond_Hair', + 'Blurry', + 'Brown_Hair', + 'Bushy_Eyebrows', + 'Chubby', + 'Double_Chin', + 'Eyeglasses', + 'Goatee', + 'Gray_Hair', + 'Heavy_Makeup', + 'High_Cheekbones', + 'Male', + 'Mouth_Slightly_Open', + 'Mustache', + 'Narrow_Eyes', + 'No_Beard', + 'Oval_Face', + 'Pale_Skin', + 'Pointy_Nose', + 'Receding_Hairline', + 'Rosy_Cheeks', + 'Sideburns', + 'Smiling', + 'Straight_Hair', + 'Wavy_Hair', + 'Wearing_Earrings', + 'Wearing_Hat', + 'Wearing_Lipstick', + 'Wearing_Necklace', + 'Wearing_Necktie', + 'Young', +] + +# IMPORTANT NOTE: THIS DATASET NEEDS TO 
BE DOWNLOADED FIRST BEFORE BEING ABLE +# TO RUN ANY CELEBA EXPERIMENTS!! +# Instructions on how to download it can be found +# in https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html +CELEBA_ROOT = 'data/celeba' + + +############################################################################### +## MAIN EXPERIMENT LOOP +############################################################################### + + +def main( + rerun=False, + result_dir='results/celeba/', + project_name='', + activation_freq=0, + num_workers=8, + single_frequency_epochs=0, + global_params=None, + save_model=True, + data_root=CELEBA_ROOT, +): + seed_everything(42) + # parameters for data, model, and training + og_config = dict( + cv=5, + max_epochs=200, + patience=15, + batch_size=512, + num_workers=num_workers, + emb_size=16, + extra_dims=0, + concept_loss_weight=1, + normalize_loss=False, + learning_rate=0.005, + weight_decay=4e-05, + weight_loss=False, + pretrain_model=True, + c_extractor_arch="resnet34", + optimizer="sgd", + bool=False, + early_stopping_monitor="val_loss", + early_stopping_mode="min", + early_stopping_delta=0.0, + image_size=64, + num_classes=1000, + top_k_accuracy=[3, 5, 10], + save_model=True, + use_imbalance=True, + use_binary_vector_class=True, + num_concepts=6, + label_binary_width=1, + label_dataset_subsample=12, + num_hidden_concepts=2, + selected_concepts=False, + + momentum=0.9, + shared_prob_gen=False, + sigmoidal_prob=False, + sigmoidal_embedding=False, + training_intervention_prob=0.0, + embeding_activation=None, + concat_prob=False, + ) + + utils.extend_with_global_params(og_config, global_params or []) + use_binary_vector_class = og_config.get('use_binary_vector_class', False) + if use_binary_vector_class: + # Now reload by transforming the labels accordingly + width = og_config.get('label_binary_width', 5) + def _binarize(concepts, selected, width): + result = [] + binary_repr = [] + concepts = concepts[selected] + for i in range(0, concepts.shape[-1], width): + 
binary_repr.append( + str(int(np.sum(concepts[i : i + width]) > 0)) + ) + return int("".join(binary_repr), 2) + + celeba_train_data = torchvision.datasets.CelebA( + root=data_root, + split='all', + download=True, + target_transform=lambda x: x[0].long() - 1, + target_type=['attr'], + ) + + concept_freq = np.sum( + celeba_train_data.attr.cpu().detach().numpy(), + axis=0 + ) / celeba_train_data.attr.shape[0] + print("Concept frequency is:", concept_freq) + sorted_concepts = list(map( + lambda x: x[0], + sorted(enumerate(np.abs(concept_freq - 0.5)), key=lambda x: x[1]), + )) + num_concepts = og_config.get( + 'num_concepts', + celeba_train_data.attr.shape[-1], + ) + concept_idxs = sorted_concepts[:num_concepts] + concept_idxs = sorted(concept_idxs) + if og_config.get('num_hidden_concepts', 0): + num_hidden = og_config.get('num_hidden_concepts', 0) + hidden_concepts = sorted( + sorted_concepts[ + num_concepts:min( + (num_concepts + num_hidden), + len(sorted_concepts) + ) + ] + ) + else: + hidden_concepts = [] + print("Selecting concepts:", concept_idxs) + print("\tAnd hidden concepts:", hidden_concepts) + celeba_train_data = torchvision.datasets.CelebA( + root=data_root, + split='all', + download=True, + transform=transforms.Compose([ + transforms.Resize(og_config['image_size']), + transforms.CenterCrop(og_config['image_size']), + transforms.ToTensor(), + transforms.ConvertImageDtype(torch.float32), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ]), + target_transform=lambda x: [ + torch.tensor( + _binarize( + x[1].cpu().detach().numpy(), + selected=(concept_idxs + hidden_concepts), + width=width, + ), + dtype=torch.long, + ), + x[1][concept_idxs].float(), + ], + target_type=['identity', 'attr'], + ) + label_remap = {} + vals, counts = np.unique( + list(map( + lambda x: _binarize( + x.cpu().detach().numpy(), + selected=(concept_idxs + hidden_concepts), + width=width, + ), + celeba_train_data.attr + )), + return_counts=True, + ) + for i, label in 
enumerate(vals): + label_remap[label] = i + + celeba_train_data = torchvision.datasets.CelebA( + root=data_root, + split='all', + download=True, + transform=transforms.Compose([ + transforms.Resize(og_config['image_size']), + transforms.CenterCrop(og_config['image_size']), + transforms.ToTensor(), + transforms.ConvertImageDtype(torch.float32), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ]), + target_transform=lambda x: [ + torch.tensor( + label_remap[_binarize( + x[1].cpu().detach().numpy(), + selected=(concept_idxs + hidden_concepts), + width=width, + )], + dtype=torch.long, + ), + x[1][concept_idxs].float(), + ], + target_type=['identity', 'attr'], + ) + num_classes = len(label_remap) + + # And subsample to reduce its massive size + factor = og_config.get('label_dataset_subsample', 1) + if factor != 1: + train_idxs = np.random.choice( + np.arange(0, len(celeba_train_data)), + replace=False, + size=len(celeba_train_data)//factor, + ) + print("Subsampling to", len(train_idxs), "elements.") + celeba_train_data = torch.utils.data.Subset( + celeba_train_data, + train_idxs, + ) + else: + concept_selection = list(range(0, len(CONCEPT_SEMANTICS))) + if og_config.get('selected_concepts', False): + concept_selection = SELECTED_CONCEPTS + celeba_train_data = torchvision.datasets.CelebA( + root=data_root, + split='all', + download=True, + target_transform=lambda x: x[0].long() - 1, + target_type=['identity'], + ) + vals, counts = np.unique( + celeba_train_data.identity, + return_counts=True, + ) + sorted_labels = list(map( + lambda x: x[0], + sorted(zip(vals, counts), key=lambda x: -x[1]) + )) + print( + "Selecting", + og_config['num_classes'], + "out of", + len(vals), + "classes", + ) + if result_dir: + Path(result_dir).mkdir(parents=True, exist_ok=True) + np.save( + os.path.join( + result_dir, + f"selected_top_{og_config['num_classes']}_labels.npy", + ), + sorted_labels[:og_config['num_classes']], + ) + label_remap = {} + for i, label in 
enumerate(sorted_labels[:og_config['num_classes']]): + label_remap[label] = i + print("len(label_remap) =", len(label_remap)) + + # Now reload by transform the labels accordingly + celeba_train_data = torchvision.datasets.CelebA( + root=data_root, + split='all', + download=True, + transform=transforms.Compose([ + transforms.Resize(og_config['image_size']), + transforms.CenterCrop(og_config['image_size']), + transforms.ToTensor(), + transforms.ConvertImageDtype(torch.float32), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ]), + target_transform=lambda x: [ + torch.tensor( + # If it is not in our map, then we make it be the token label + # og_config['num_classes'] which will be removed afterwards + label_remap.get( + x[0].cpu().detach().item() - 1, + og_config['num_classes'] + ), + dtype=torch.long, + ), + x[1][concept_selection].float(), + ], + target_type=['identity', 'attr'], + ) + num_classes = og_config['num_classes'] + + train_idxs = np.where( + list(map( + lambda x: x.cpu().detach().item() - 1 in label_remap, + celeba_train_data.identity + )) + )[0] + celeba_train_data = torch.utils.data.Subset( + celeba_train_data, + train_idxs, + ) + total_samples = len(celeba_train_data) + train_samples = int(0.7 * total_samples) + test_samples = int(0.2 * total_samples) + val_samples = total_samples - test_samples - train_samples + print( + f"Data split is: {total_samples} = {train_samples} (train) + " + f"{test_samples} (test) + {val_samples} (validation)" + ) + celeba_train_data, celeba_test_data, celeba_val_data = \ + torch.utils.data.random_split( + celeba_train_data, + [train_samples, test_samples, val_samples], + ) + train_dl = torch.utils.data.DataLoader( + celeba_train_data, + batch_size=og_config['batch_size'], + shuffle=True, + num_workers=og_config['num_workers'], + ) + test_dl = torch.utils.data.DataLoader( + celeba_test_data, + batch_size=og_config['batch_size'], + shuffle=False, + num_workers=og_config['num_workers'], + ) + val_dl = 
torch.utils.data.DataLoader( + celeba_val_data, + batch_size=og_config['batch_size'], + shuffle=False, + num_workers=og_config['num_workers'], + ) + + if result_dir and activation_freq: + # Then let's save the testing data for further analysis later on + out_acts_save_dir = os.path.join(result_dir, "test_embedding_acts") + Path(out_acts_save_dir).mkdir(parents=True, exist_ok=True) + + for (ds, name) in [ + (test_dl, "test"), + (val_dl, "val"), + ]: + x_total = [] + y_total = [] + c_total = [] + for x, (y, c) in ds: + x_total.append(x.cpu().detach()) + y_total.append(y.cpu().detach()) + c_total.append(c.cpu().detach()) + x_inputs = np.concatenate(x_total, axis=0) + print(f"x_{name}.shape =", x_inputs.shape) + np.save(os.path.join(out_acts_save_dir, f"x_{name}.npy"), x_inputs) + + y_inputs = np.concatenate(y_total, axis=0) + print(f"y_{name}.shape =", y_inputs.shape) + np.save(os.path.join(out_acts_save_dir, f"y_{name}.npy"), y_inputs) + + c_inputs = np.concatenate(c_total, axis=0) + print(f"c_{name}.shape =", c_inputs.shape) + np.save(os.path.join(out_acts_save_dir, f"c_{name}.npy"), c_inputs) + + label_set = set() + sample = next(iter(train_dl)) + real_sample = [] + for derp in sample: + if isinstance(derp, list): + real_sample += derp + else: + real_sample.append(derp) + sample = real_sample + print("Sample has", len(sample), "elements.") + for i, derp in enumerate(sample): + print("Element", i, "has shape", derp.shape, "and type", derp.dtype) + + print("Training sample shape is:", sample[0].shape) + print("Training label shape is:", sample[1].shape) + print("Training concept shape is:", sample[2].shape) + + n_concepts, n_tasks = sample[2].shape[-1], num_classes + + attribute_count = np.zeros((n_concepts,)) + samples_seen = 0 + for i, (_, (y, c)) in enumerate(train_dl): + print("\rIn batch", i, "we have seen", len(label_set), "classes") + c = c.cpu().detach().numpy() + attribute_count += np.sum(c, axis=0) + samples_seen += c.shape[0] + for l in 
y.reshape(-1).cpu().detach(): + label_set.add(l.item()) + + print("Found a total of", len(label_set), "classes") + if og_config.get("use_imbalance", False): + imbalance = samples_seen / attribute_count - 1 + else: + imbalance = None + print("Imbalance:", imbalance) + + os.makedirs(result_dir, exist_ok=True) + + results = {} + for split in range(og_config["cv"]): + print(f'Experiment {split+1}/{og_config["cv"]}') + results[f'{split}'] = {} + + # Trial period for mixture embedding model + config = copy.deepcopy(og_config) + config["architecture"] = "MixtureEmbModel" + config["extra_name"] = f"SharedProb_AdaptiveDropout_NoProbConcat" + config["shared_prob_gen"] = True + config["sigmoidal_prob"] = True + config["sigmoidal_embedding"] = False + config['training_intervention_prob'] = 0.25 + config['concat_prob'] = False + config['emb_size'] = config['emb_size'] + config["embeding_activation"] = "leakyrelu" + mixed_emb_shared_prob_model, mixed_emb_shared_prob_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[f'{split}'], + config, + mixed_emb_shared_prob_model, + mixed_emb_shared_prob_test_results, + ) + + # Trial period for mixture embedding model + config = copy.deepcopy(og_config) + config["architecture"] = "MixtureEmbModel" + config["extra_name"] = f"SharedProb_Adaptive_NoProbConcat" + config["shared_prob_gen"] = True + config["sigmoidal_prob"] = True + config["sigmoidal_embedding"] = False + config['training_intervention_prob'] = 0.0 + config['concat_prob'] = False + config['emb_size'] = config['emb_size'] + config["embeding_activation"] = "leakyrelu" + mixed_emb_shared_prob_model, 
mixed_emb_shared_prob_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[f'{split}'], + config, + mixed_emb_shared_prob_model, + mixed_emb_shared_prob_test_results, + ) + + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["bool"] = False + config["extra_dims"] = (config['emb_size'] - 1) * n_concepts + config["extra_name"] = f"FuzzyExtraCapacity_Logit" + config["bottleneck_nonlinear"] = "leakyrelu" + config["sigmoidal_extra_capacity"] = False + config["sigmoidal_prob"] = False + extra_fuzzy_logit_model, extra_fuzzy_logit_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[f'{split}'], + config, + extra_fuzzy_logit_model, + extra_fuzzy_logit_test_results, + ) + + # fuzzy model + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["extra_name"] = f"Fuzzy" + config["bool"] = False + config["extra_dims"] = 0 + config["sigmoidal_extra_capacity"] = False + config["sigmoidal_prob"] = True + extra_fuzzy_logit_model, extra_fuzzy_logit_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + 
seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[f'{split}'], + config, + extra_fuzzy_logit_model, + extra_fuzzy_logit_test_results, + ) + + # sequential and independent models + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["extra_name"] = f"" + config["sigmoidal_prob"] = True + ind_model, ind_test_results, seq_model, seq_test_results = \ + training.train_independent_and_sequential_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + config["architecture"] = "IndependentConceptBottleneckModel" + training.update_statistics( + results[f'{split}'], + config, + ind_model, + ind_test_results, + ) + + config["architecture"] = "SequentialConceptBottleneckModel" + training.update_statistics( + results[f'{split}'], + config, + seq_model, + seq_test_results, + ) + + # train model *without* embeddings (concepts are just *Boolean* scalars) + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["extra_name"] = f"Bool" + config["bool"] = True + bool_model, bool_test_results = training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + imbalance=imbalance, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + ) + training.update_statistics( + results[f'{split}'], + config, + bool_model, + bool_test_results, + ) + + # train vanilla model with more capacity (i.e., no concept 
supervision) + # but with ReLU activation + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["extra_name"] = f"NoConceptSupervisionReLU_ExtraCapacity" + config["bool"] = False + config["extra_dims"] = (config['emb_size'] - 1) * n_concepts + config["bottleneck_nonlinear"] = "leakyrelu" + config["concept_loss_weight"] = 0 + extra_vanilla_relu_model, extra_vanilla_relu_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[f'{split}'], + config, + extra_vanilla_relu_model, + extra_vanilla_relu_test_results, + ) + + # save results + joblib.dump(results, os.path.join(result_dir, f'results.joblib')) + + return results + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description=( + 'Runs concept embedding experiment in CelebA dataset.' + ), + ) + parser.add_argument( + '--project_name', + default='', + help=( + "Project name used for Weights & Biases monitoring. If not " + "provided, then we will assume we do not run a w&b project." + ), + metavar="name", + + ) + + parser.add_argument( + '--output_dir', + '-o', + default='results/celeba/', + help=( + "directory where we will dump our experiment's results. If not " + "given, then we will use ./results/celeba/." + ), + metavar="path", + + ) + parser.add_argument( + '--rerun', + '-r', + default=False, + action="store_true", + help=( + "If set, then we will force a rerun of the entire experiment even if " + "valid results are found in the provided output directory. Note that " + "this may overwrite any previous results, so use with care." 
+ ), + + ) + parser.add_argument( + '--activation_freq', + default=0, + help=( + 'how frequently, in terms of epochs, should we store the embedding activations for our ' + 'validation set. By default we will not store any activations.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + '--single_frequency_epochs', + default=0, + help=( + 'how frequently, in terms of epochs, should we store the embedding activations for our ' + 'validation set. By default we will not store any activations.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + '--num_workers', + default=12, + help=( + 'number of workers used for data feeders. Do not use more workers ' + 'than cores in the machine.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + '--data_root', + default=CELEBA_ROOT, + help=( + 'directory containing the CelebA dataset.' + ), + metavar='path', + type=str, + ) + parser.add_argument( + "-d", + "--debug", + action="store_true", + default=False, + help="starts debug mode in our program.", + ) + parser.add_argument( + "--no_save_model", + action="store_true", + default=False, + help="whether or not we will save the fully trained models.", + ) + parser.add_argument( + '-p', + '--param', + action='append', + nargs=2, + metavar=('param_name=value'), + help=( + 'Allows the passing of a config param that will overwrite ' + 'anything passed as part of the config file itself.' 
+ ), + default=[], + ) + args = parser.parse_args() + main( + data_root=args.data_root, + rerun=args.rerun, + result_dir=args.output_dir, + project_name=args.project_name, + activation_freq=args.activation_freq, + num_workers=args.num_workers, + single_frequency_epochs=args.single_frequency_epochs, + global_params=args.param, + save_model=(not args.no_save_model), + ) + diff --git a/experiments/cub_emb_size_ablation.py b/experiments/cub_emb_size_ablation.py new file mode 100644 index 0000000..8055aca --- /dev/null +++ b/experiments/cub_emb_size_ablation.py @@ -0,0 +1,441 @@ +import argparse +import copy +import joblib +import numpy as np +import os +import torch + +from CUB200.cub_loader import load_data, find_class_imbalance +from pathlib import Path +from pytorch_lightning import seed_everything + +import cem.experiments.cub_experiments as cub +import cem.train.training as training +import cem.train.utils as utils + +def main( + rerun=False, + result_dir='results/cub_emb_size_ablation/', + project_name='', + activation_freq=0, + num_workers=8, + single_frequency_epochs=0, + global_params=None, + data_root=cub.CUB_DIR, +): + seed_everything(42) + # parameters for data, model, and training + og_config = dict( + cv=5, + max_epochs=300, + patience=15, + batch_size=128, + num_workers=num_workers, + emb_size=16, + extra_dims=0, + concept_loss_weight=5, + normalize_loss=False, + learning_rate=0.01, + weight_decay=4e-05, + scheduler_step=20, + weight_loss=True, + pretrain_model=True, + c_extractor_arch="resnet34", + optimizer="sgd", + bool=False, + early_stopping_monitor="val_loss", + early_stopping_mode="min", + early_stopping_delta=0.0, + # By default we start with 25% of the concepts in the bottleneck + sampling_percent=0.25, + + momentum=0.9, + shared_prob_gen=False, + sigmoidal_prob=False, + sigmoidal_embedding=False, + training_intervention_prob=0.0, + embeding_activation=None, + concat_prob=False, + ) + + utils.extend_with_global_params(og_config, 
global_params or []) + train_data_path = os.path.join(cub.BASE_DIR, 'train.pkl') + if og_config['weight_loss']: + imbalance = find_class_imbalance(train_data_path, True) + else: + imbalance = None + + val_data_path = train_data_path.replace('train.pkl', 'val.pkl') + test_data_path = train_data_path.replace('train.pkl', 'test.pkl') + sampling_percent = og_config.get("sampling_percent", 1) + n_concepts, n_tasks = 112, 200 + + if sampling_percent != 1: + # Do the subsampling + new_n_concepts = int(np.ceil(n_concepts * sampling_percent)) + selected_concepts_file = os.path.join( + result_dir, + f"selected_concepts_sampling_{sampling_percent}.npy", + ) + if (not rerun) and os.path.exists(selected_concepts_file): + selected_concepts = np.load(selected_concepts_file) + else: + selected_concepts = sorted( + np.random.permutation(n_concepts)[:new_n_concepts] + ) + np.save(selected_concepts_file, selected_concepts) + print("\t\tSelected concepts:", selected_concepts) + def subsample_transform(sample): + if isinstance(sample, list): + sample = np.array(sample) + return sample[selected_concepts] + + if og_config['weight_loss']: + imbalance = np.array(imbalance)[selected_concepts] + + train_dl = load_data( + pkl_paths=[train_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=cub.CUB_DIR, + num_workers=og_config['num_workers'], + concept_transform=subsample_transform, + ) + val_dl = load_data( + pkl_paths=[val_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=cub.CUB_DIR, + num_workers=og_config['num_workers'], + concept_transform=subsample_transform, + ) + test_dl = load_data( + pkl_paths=[test_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + 
image_dir='images', + resampling=False, + root_dir=cub.CUB_DIR, + num_workers=og_config['num_workers'], + concept_transform=subsample_transform, + ) + # And set the right number of concepts to be used + n_concepts = new_n_concepts + else: + train_dl = load_data( + pkl_paths=[train_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=cub.CUB_DIR, + num_workers=og_config['num_workers'], + ) + val_dl = load_data( + pkl_paths=[val_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=cub.CUB_DIR, + num_workers=og_config['num_workers'], + ) + test_dl = load_data( + pkl_paths=[test_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=cub.CUB_DIR, + num_workers=og_config['num_workers'], + ) + + if result_dir and activation_freq: + # Then let's save the testing data for further analysis later on + out_acts_save_dir = os.path.join(result_dir, "test_embedding_acts") + Path(out_acts_save_dir).mkdir(parents=True, exist_ok=True) + for (ds, name) in [ + (test_dl, "test"), + (val_dl, "val"), + ]: + x_total = [] + y_total = [] + c_total = [] + for x, y, c in ds: + x_total.append(x.cpu().detach()) + y_total.append(y.cpu().detach()) + c_total.append(c.cpu().detach()) + x_inputs = np.concatenate(x_total, axis=0) + print(f"x_{name}.shape =", x_inputs.shape) + np.save(os.path.join(out_acts_save_dir, f"x_{name}.npy"), x_inputs) + + y_inputs = np.concatenate(y_total, axis=0) + print(f"y_{name}.shape =", y_inputs.shape) + np.save(os.path.join(out_acts_save_dir, f"y_{name}.npy"), y_inputs) + + c_inputs = np.concatenate(c_total, axis=0) + print(f"c_{name}.shape =", c_inputs.shape) + np.save(os.path.join(out_acts_save_dir, 
f"c_{name}.npy"), c_inputs) + + sample = next(iter(train_dl)) + n_concepts, n_tasks = sample[2].shape[-1], 200 + + print("Training sample shape is:", sample[0].shape) + print("Training label shape is:", sample[1].shape) + print("Training concept shape is:", sample[2].shape) + os.makedirs(result_dir, exist_ok=True) + results = {} + + for split in range(og_config["cv"]): + for emb_size in [1, 2, 4, 6, 8, 16, 32, 64]: + if emb_size not in results: + results[emb_size] = {} + if f'{split}' not in results[emb_size]: + results[emb_size][f'{split}'] = {} + print( + f'Experiment {split+1}/{og_config["cv"]} with emb_size', + emb_size, + ) + + # Trial period for mixture embedding model + config = copy.deepcopy(og_config) + config["architecture"] = "MixtureEmbModel" + config["extra_name"] = ( + f"SharedProb_AdaptiveDropout_NoProbConcat_emb_size_{emb_size}" + ) + config["shared_prob_gen"] = True + config["sigmoidal_prob"] = False + config["sigmoidal_embedding"] = False + config['training_intervention_prob'] = 0.25 + config['concat_prob'] = False + config['emb_size'] = emb_size + config["embeding_activation"] = "leakyrelu" + mixed_emb_shared_prob_model, mixed_emb_shared_prob_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[emb_size][f'{split}'], + config, + mixed_emb_shared_prob_model, + mixed_emb_shared_prob_test_results, + ) + + # Train fuzzy CBM with extra capacity + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["bool"] = False + config["extra_dims"] = (emb_size - 1) * n_concepts + config["extra_name"] = ( + f"FuzzyExtraCapacity_Logit_emb_size_{emb_size}" + ) + 
config["bottleneck_nonlinear"] = "leakyrelu" + config["sigmoidal_extra_capacity"] = False + config["sigmoidal_prob"] = False + config['emb_size'] = emb_size + extra_fuzzy_logit_model, extra_fuzzy_logit_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[emb_size][f'{split}'], + config, + extra_fuzzy_logit_model, + extra_fuzzy_logit_test_results, + ) + + # train vanilla model with more capacity (i.e., no concept + # supervision) but with ReLU activation + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["extra_name"] = ( + f"NoConceptSupervisionReLU_ExtraCapacity_emb_size_{emb_size}" + ) + config["bool"] = False + config["extra_dims"] = (emb_size - 1) * n_concepts + config["bottleneck_nonlinear"] = "leakyrelu" + config["concept_loss_weight"] = 0 + config['emb_size'] = emb_size + extra_vanilla_relu_model, extra_vanilla_relu_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[emb_size][f'{split}'], + config, + extra_vanilla_relu_model, + extra_vanilla_relu_test_results, + ) + + # save results + joblib.dump(results, os.path.join(result_dir, f'results.joblib')) + return results + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description=( + 'Runs concept embedding experiment in CUB dataset.' 
+ ), + ) + parser.add_argument( + '--project_name', + default='', + help=( + "Project name used for Weights & Biases monitoring. If not " + "provided, then we will assume no W&B is used for logging." + ), + metavar="name", + + ) + + parser.add_argument( + '--output_dir', + '-o', + default='results/cub_emb_size_ablation/', + help=( + "directory where we will dump our experiment's results. If not " + "given, then we will use ./results/cub_emb_size_ablation/." + ), + metavar="path", + + ) + parser.add_argument( + '--rerun', + '-r', + default=False, + action="store_true", + help=( + "If set, then we will force a rerun of the entire experiment " + "even if valid results are found in the provided output " + "directory. Note that this may overwrite any previous results, " + "so use with care." + ), + + ) + parser.add_argument( + '--activation_freq', + default=0, + help=( + 'how frequently, in terms of epochs, should we store the ' + 'embedding activations for our validation set. By default we will ' + 'not store any activations.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + '--single_frequency_epochs', + default=0, + help=( + 'how frequently, in terms of epochs, should we store the ' + 'embedding activations for our validation set. By default we will ' + 'not store any activations.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + '--num_workers', + default=8, + help=( + 'number of workers used for data feeders. Do not use more workers ' + 'than cores in the machine.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + "-d", + "--debug", + action="store_true", + default=False, + help="starts debug mode in our program.", + ) + parser.add_argument( + '-p', + '--param', + action='append', + nargs=2, + metavar=('param_name=value'), + help=( + 'Allows the passing of a config param that will overwrite ' + 'anything passed as part of the config file itself.' 
+ ), + default=[], + ) + args = parser.parse_args() + if args.project_name: + # Lazy import to avoid importing unless necessary + import wandb + main( + rerun=args.rerun, + result_dir=args.output_dir, + project_name=args.project_name, + activation_freq=args.activation_freq, + num_workers=args.num_workers, + single_frequency_epochs=args.single_frequency_epochs, + global_params=args.param, + ) +# hyperparameter_sweep() diff --git a/experiments/cub_experiments.py b/experiments/cub_experiments.py new file mode 100644 index 0000000..cc08770 --- /dev/null +++ b/experiments/cub_experiments.py @@ -0,0 +1,574 @@ +import argparse +import copy +import joblib +import numpy as np +import os +import torch + +from CUB200.cub_loader import load_data, find_class_imbalance +from pathlib import Path +from pytorch_lightning import seed_everything + +import cem.train.training as training +import cem.train.utils as utils + +################################################################################ +## GLOBAL CUB VARIABLES +################################################################################ + +# IMPORTANT NOTE: THIS DATASET NEEDS TO BE DOWNLOADED FIRST BEFORE BEING ABLE +# TO RUN ANY CUB EXPERIMENTS!! 
+# Instructions on how to download it can be found +# in the original CBM paper's repository +# found here: https://github.com/yewsiang/ConceptBottleneck +CUB_DIR = 'cem/data/CUB200/' +BASE_DIR = os.path.join(CUB_DIR, 'class_attr_data_10') + +################################################################################ +## MAIN FUNCTION +################################################################################ + + +def main( + rerun=False, + result_dir='results/cub/', + project_name='', + activation_freq=0, + num_workers=8, + single_frequency_epochs=0, + global_params=None, +): + seed_everything(42) + # parameters for data, model, and training + og_config = dict( + cv=5, + max_epochs=300, + patience=15, + batch_size=128, + num_workers=num_workers, + emb_size=16, + extra_dims=0, + concept_loss_weight=5, + normalize_loss=False, + learning_rate=0.01, + weight_decay=4e-05, + weight_loss=True, + pretrain_model=True, + c_extractor_arch="resnet34", + optimizer="sgd", + bool=False, + early_stopping_monitor="val_loss", + early_stopping_mode="min", + early_stopping_delta=0.0, + sampling_percent=1, + + momentum=0.9, + shared_prob_gen=False, + sigmoidal_prob=False, + sigmoidal_embedding=False, + training_intervention_prob=0.0, + embeding_activation=None, + concat_prob=False, + ) + utils.extend_with_global_params(og_config, global_params or []) + + train_data_path = os.path.join(BASE_DIR, 'train.pkl') + if og_config['weight_loss']: + imbalance = find_class_imbalance(train_data_path, True) + else: + imbalance = None + + val_data_path = train_data_path.replace('train.pkl', 'val.pkl') + test_data_path = train_data_path.replace('train.pkl', 'test.pkl') + sampling_percent = og_config.get("sampling_percent", 1) + n_concepts, n_tasks = 112, 200 + if sampling_percent != 1: + # Do the subsampling + new_n_concepts = int(np.ceil(n_concepts * sampling_percent)) + selected_concepts_file = os.path.join( + result_dir, + f"selected_concepts_sampling_{sampling_percent}.npy", + ) + if 
(not rerun) and os.path.exists(selected_concepts_file): + selected_concepts = np.load(selected_concepts_file) + else: + selected_concepts = sorted( + np.random.permutation(n_concepts)[:new_n_concepts] + ) + np.save(selected_concepts_file, selected_concepts) + print("\t\tSelected concepts:", selected_concepts) + def subsample_transform(sample): + if isinstance(sample, list): + sample = np.array(sample) + return sample[selected_concepts] + + if og_config['weight_loss']: + imbalance = np.array(imbalance)[selected_concepts] + + train_dl = load_data( + pkl_paths=[train_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=CUB_DIR, + num_workers=og_config['num_workers'], + concept_transform=subsample_transform, + ) + val_dl = load_data( + pkl_paths=[val_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=CUB_DIR, + num_workers=og_config['num_workers'], + concept_transform=subsample_transform, + ) + test_dl = load_data( + pkl_paths=[test_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=CUB_DIR, + num_workers=og_config['num_workers'], + concept_transform=subsample_transform, + ) + + # And set the right number of concepts to be used + n_concepts = new_n_concepts + else: + train_dl = load_data( + pkl_paths=[train_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=CUB_DIR, + num_workers=og_config['num_workers'], + ) + val_dl = load_data( + pkl_paths=[val_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + 
n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=CUB_DIR, + num_workers=og_config['num_workers'], + ) + test_dl = load_data( + pkl_paths=[test_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=CUB_DIR, + num_workers=og_config['num_workers'], + ) + + if result_dir and activation_freq: + # Then let's save the testing data for further analysis later on + out_acts_save_dir = os.path.join(result_dir, "test_embedding_acts") + Path(out_acts_save_dir).mkdir(parents=True, exist_ok=True) + for (ds, name) in [ + (test_dl, "test"), + (val_dl, "val"), + ]: + x_total = [] + y_total = [] + c_total = [] + for x, y, c in ds: + x_total.append(x.cpu().detach()) + y_total.append(y.cpu().detach()) + c_total.append(c.cpu().detach()) + x_inputs = np.concatenate(x_total, axis=0) + print(f"x_{name}.shape =", x_inputs.shape) + np.save(os.path.join(out_acts_save_dir, f"x_{name}.npy"), x_inputs) + + y_inputs = np.concatenate(y_total, axis=0) + print(f"y_{name}.shape =", y_inputs.shape) + np.save(os.path.join(out_acts_save_dir, f"y_{name}.npy"), y_inputs) + + c_inputs = np.concatenate(c_total, axis=0) + print(f"c_{name}.shape =", c_inputs.shape) + np.save(os.path.join(out_acts_save_dir, f"c_{name}.npy"), c_inputs) + + sample = next(iter(train_dl)) + n_concepts, n_tasks = sample[2].shape[-1], 200 + + print("Training sample shape is:", sample[0].shape) + print("Training label shape is:", sample[1].shape) + print("Training concept shape is:", sample[2].shape) + + os.makedirs(result_dir, exist_ok=True) + + results = {} + for split in range(og_config["cv"]): + print(f'Experiment {split+1}/{og_config["cv"]}') + results[f'{split}'] = {} + + config = copy.deepcopy(og_config) + config["architecture"] = "MixtureEmbModel" + config["extra_name"] = f"SharedProb_AdaptiveDropout_NoProbConcat" + config["shared_prob_gen"] = True + config["sigmoidal_prob"] = True + 
config["sigmoidal_embedding"] = False + config['training_intervention_prob'] = 0.25 + config['concat_prob'] = False + config['emb_size'] = config['emb_size'] + config["embeding_activation"] = "leakyrelu" + mixed_emb_shared_prob_model, mixed_emb_shared_prob_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[f'{split}'], + config, + mixed_emb_shared_prob_model, + mixed_emb_shared_prob_test_results, + ) + + config = copy.deepcopy(og_config) + config["architecture"] = "MixtureEmbModel" + config["extra_name"] = f"SharedProb_Adaptive_NoProbConcat" + config["shared_prob_gen"] = True + config["sigmoidal_prob"] = True + config["sigmoidal_embedding"] = False + config['training_intervention_prob'] = 0.0 + config['concat_prob'] = False + config['emb_size'] = config['emb_size'] + config["embeding_activation"] = "leakyrelu" + mixed_emb_shared_prob_model, mixed_emb_shared_prob_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[f'{split}'], + config, + mixed_emb_shared_prob_model, + mixed_emb_shared_prob_test_results, + ) + + # train model *without* embeddings but with extra capacity (concepts + # are just *fuzzy* scalars and the model also has some extra capacity). 
+ config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["bool"] = False + config["extra_dims"] = (config['emb_size'] - 1) * n_concepts + config["extra_name"] = f"FuzzyExtraCapacity_Logit" + config["bottleneck_nonlinear"] = "leakyrelu" + config["sigmoidal_extra_capacity"] = False + config["sigmoidal_prob"] = False + extra_fuzzy_logit_model, extra_fuzzy_logit_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[f'{split}'], + config, + extra_fuzzy_logit_model, + extra_fuzzy_logit_test_results, + ) + + # fuzzy model + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["extra_name"] = f"Fuzzy" + config["bool"] = False + config["extra_dims"] = 0 + config["sigmoidal_extra_capacity"] = False + config["sigmoidal_prob"] = True + extra_fuzzy_logit_model, extra_fuzzy_logit_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[f'{split}'], + config, + extra_fuzzy_logit_model, + extra_fuzzy_logit_test_results, + ) + + # Bool model + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["extra_name"] = f"Bool" + config["bool"] = True + bool_model, bool_test_results = training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + 
train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + imbalance=imbalance, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + ) + training.update_statistics( + results[f'{split}'], + config, + bool_model, + bool_test_results, + ) + + + # sequential and independent models + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["extra_name"] = f"" + config["sigmoidal_prob"] = True + ind_model, ind_test_results, seq_model, seq_test_results = \ + training.train_independent_and_sequential_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + config["architecture"] = "IndependentConceptBottleneckModel" + training.update_statistics( + results[f'{split}'], + config, + ind_model, + ind_test_results, + ) + + config["architecture"] = "SequentialConceptBottleneckModel" + training.update_statistics( + results[f'{split}'], + config, + seq_model, + seq_test_results, + ) + + # train vanilla model with more capacity (i.e., no concept supervision) + # but with ReLU activation + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["extra_name"] = f"NoConceptSupervisionReLU_ExtraCapacity" + config["bool"] = False + config["extra_dims"] = (config['emb_size'] - 1) * n_concepts + config["bottleneck_nonlinear"] = "leakyrelu" + config["concept_loss_weight"] = 0 + extra_vanilla_relu_model, extra_vanilla_relu_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + 
result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[f'{split}'], + config, + extra_vanilla_relu_model, + extra_vanilla_relu_test_results, + ) + + # save results + joblib.dump(results, os.path.join(result_dir, f'results.joblib')) + + return results + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description=( + 'Runs concept embedding experiment in CUB dataset.' + ), + ) + parser.add_argument( + '--project_name', + default='cub_rerun_concept_training', + help=( + "Project name used for Weights & Biases monitoring. If not " + "provided, then we will assume it is 'cub_rerun_concept_training'." + ), + metavar="name", + + ) + + parser.add_argument( + '--output_dir', + '-o', + default='results/cub_rerun/', + help=( + "directory where we will dump our experiment's results. If not " + "given, then we will use ./results/cub_rerun/." + ), + metavar="path", + + ) + + parser.add_argument( + '--rerun', + '-r', + default=False, + action="store_true", + help=( + "If set, then we will force a rerun of the entire experiment even " + "if valid results are found in the provided output directory. " + "Note that this may overwrite any previous results, so use " + "with care." + ), + + ) + parser.add_argument( + '--activation_freq', + default=0, + help=( + 'how frequently, in terms of epochs, should we store the ' + 'embedding activations for our validation set. By default we will ' + 'not store any activations.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + '--single_frequency_epochs', + default=0, + help=( + 'how frequently, in terms of epochs, should we store the ' + 'embedding activations for our validation set. By default we will ' + 'not store any activations.' 
+ ), + metavar='N', + type=int, + ) + parser.add_argument( + '--num_workers', + default=8, + help=( + 'number of workers used for data feeders. Do not use more workers ' + 'than cores in the machine.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + "-d", + "--debug", + action="store_true", + default=False, + help="starts debug mode in our program.", + ) + parser.add_argument( + '-p', + '--param', + action='append', + nargs=2, + metavar=('param_name=value'), + help=( + 'Allows the passing of a config param that will overwrite ' + 'anything passed as part of the config file itself.' + ), + default=[], + ) + args = parser.parse_args() + if args.project_name: + # Lazy import to avoid importing unless necessary + import wandb + main( + rerun=args.rerun, + result_dir=args.output_dir, + project_name=args.project_name, + activation_freq=args.activation_freq, + num_workers=args.num_workers, + single_frequency_epochs=args.single_frequency_epochs, + global_params=args.param, + ) diff --git a/experiments/cub_randint_ablation.py b/experiments/cub_randint_ablation.py new file mode 100644 index 0000000..59d7887 --- /dev/null +++ b/experiments/cub_randint_ablation.py @@ -0,0 +1,361 @@ +import argparse +import copy +import joblib +import numpy as np +import os +import torch + +from CUB200.cub_loader import load_data, find_class_imbalance +from pathlib import Path +from pytorch_lightning import seed_everything + +import cem.experiments.cub_experiments as cub +import cem.train.training as training +import cem.train.utils as utils + +def main( + rerun=False, + result_dir='results/cub_randint_ablation/', + project_name='', + activation_freq=0, + num_workers=8, + single_frequency_epochs=0, + global_params=None, +): + seed_everything(42) + # parameters for data, model, and training + og_config = dict( + cv=5, + max_epochs=300, + patience=15, + batch_size=128, + num_workers=num_workers, + emb_size=16, + extra_dims=0, + concept_loss_weight=5, + normalize_loss=False, + 
learning_rate=0.01, + weight_decay=4e-05, + scheduler_step=20, + weight_loss=True, + pretrain_model=True, + c_extractor_arch="resnet34", + optimizer="sgd", + bool=False, + early_stopping_monitor="val_loss", + early_stopping_mode="min", + early_stopping_delta=0.0, + sampling_percent=1, + + momentum=0.9, + shared_prob_gen=False, + sigmoidal_prob=False, + sigmoidal_embedding=False, + training_intervention_prob=0.0, + embeding_activation=None, + concat_prob=False, + ) + + utils.extend_with_global_params(og_config, global_params or []) + train_data_path = os.path.join(cub.BASE_DIR, 'train.pkl') + if og_config['weight_loss']: + imbalance = find_class_imbalance(train_data_path, True) + else: + imbalance = None + + val_data_path = train_data_path.replace('train.pkl', 'val.pkl') + test_data_path = train_data_path.replace('train.pkl', 'test.pkl') + sampling_percent = og_config.get("sampling_percent", 1) + n_concepts, n_tasks = 112, 200 + + if sampling_percent != 1: + # Do the subsampling + new_n_concepts = int(np.ceil(n_concepts * sampling_percent)) + selected_concepts_file = os.path.join( + result_dir, + f"selected_concepts_sampling_{sampling_percent}.npy", + ) + if (not rerun) and os.path.exists(selected_concepts_file): + selected_concepts = np.load(selected_concepts_file) + else: + selected_concepts = sorted( + np.random.permutation(n_concepts)[:new_n_concepts] + ) + np.save(selected_concepts_file, selected_concepts) + print("\t\tSelected concepts:", selected_concepts) + def subsample_transform(sample): + if isinstance(sample, list): + sample = np.array(sample) + return sample[selected_concepts] + + if og_config['weight_loss']: + imbalance = np.array(imbalance)[selected_concepts] + + train_dl = load_data( + pkl_paths=[train_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=cub.CUB_DIR, + num_workers=og_config['num_workers'], + 
concept_transform=subsample_transform, + ) + val_dl = load_data( + pkl_paths=[val_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=cub.CUB_DIR, + num_workers=og_config['num_workers'], + concept_transform=subsample_transform, + ) + test_dl = load_data( + pkl_paths=[test_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=cub.CUB_DIR, + num_workers=og_config['num_workers'], + concept_transform=subsample_transform, + ) + # And set the right number of concepts to be used + n_concepts = new_n_concepts + else: + train_dl = load_data( + pkl_paths=[train_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=cub.CUB_DIR, + num_workers=og_config['num_workers'], + ) + val_dl = load_data( + pkl_paths=[val_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=cub.CUB_DIR, + num_workers=og_config['num_workers'], + ) + test_dl = load_data( + pkl_paths=[test_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=cub.CUB_DIR, + num_workers=og_config['num_workers'], + ) + + if result_dir and activation_freq: + # Then let's save the testing data for further analysis later on + out_acts_save_dir = os.path.join(result_dir, "test_embedding_acts") + Path(out_acts_save_dir).mkdir(parents=True, exist_ok=True) + for (ds, name) in [ + (test_dl, "test"), + (val_dl, "val"), + ]: + x_total = [] + y_total = [] + c_total = [] + for x, y, c in ds: + 
x_total.append(x.cpu().detach()) + y_total.append(y.cpu().detach()) + c_total.append(c.cpu().detach()) + x_inputs = np.concatenate(x_total, axis=0) + print(f"x_{name}.shape =", x_inputs.shape) + np.save(os.path.join(out_acts_save_dir, f"x_{name}.npy"), x_inputs) + + y_inputs = np.concatenate(y_total, axis=0) + print(f"y_{name}.shape =", y_inputs.shape) + np.save(os.path.join(out_acts_save_dir, f"y_{name}.npy"), y_inputs) + + c_inputs = np.concatenate(c_total, axis=0) + print(f"c_{name}.shape =", c_inputs.shape) + np.save(os.path.join(out_acts_save_dir, f"c_{name}.npy"), c_inputs) + + sample = next(iter(train_dl)) + n_concepts, n_tasks = sample[2].shape[-1], 200 + + print("Training sample shape is:", sample[0].shape) + print("Training label shape is:", sample[1].shape) + print("Training concept shape is:", sample[2].shape) + os.makedirs(result_dir, exist_ok=True) + results = {} + + for prob in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99, 1.0]: + results[prob] = {} + for split in range(og_config["cv"]): + print(f'Experiment {split+1}/{og_config["cv"]} with prob', prob) + results[prob][f'{split}'] = {} + + # Trial period for mixture embedding model + config = copy.deepcopy(og_config) + config["architecture"] = "MixtureEmbModel" + config["extra_name"] = ( + f"SharedProb_AdaptiveDropout_NoProbConcat_prob_{prob}" + ) + config["shared_prob_gen"] = True + config["sigmoidal_prob"] = False + config["sigmoidal_embedding"] = False + config['training_intervention_prob'] = prob + config['concat_prob'] = False + config['emb_size'] = config['emb_size'] + mixed_emb_shared_prob_model, mixed_emb_shared_prob_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + 
training.update_statistics( + results[prob][f'{split}'], + config, + mixed_emb_shared_prob_model, + mixed_emb_shared_prob_test_results, + ) + + # save results + joblib.dump(results, os.path.join(result_dir, f'results.joblib')) + + return results + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description=( + 'Runs ablation study for RandInt in CUB dataset.' + ), + ) + parser.add_argument( + '--project_name', + default='', + help=( + "Project name used for Weights & Biases monitoring. If not " + "provided, then we will assume no W&B logging is used." + ), + metavar="name", + + ) + + parser.add_argument( + '--output_dir', + '-o', + default='results/cub_randint_ablation/', + help=( + "directory where we will dump our experiment's results. If not " + "given, then we will use ./results/cub_randint_ablation/." + ), + metavar="path", + + ) + parser.add_argument( + '--rerun', + '-r', + default=False, + action="store_true", + help=( + "If set, then we will force a rerun of the entire experiment even " + "if valid results are found in the provided output directory. " + "Note that this may overwrite any previous results, so use with " + "care." + ), + + ) + parser.add_argument( + '--activation_freq', + default=0, + help=( + 'how frequently, in terms of epochs, should we store the ' + 'embedding activations for our validation set. By default we ' + 'will not store any activations.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + '--single_frequency_epochs', + default=0, + help=( + 'how frequently, in terms of epochs, should we store the ' + 'embedding activations for our validation set. By default we will ' + 'not store any activations.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + '--num_workers', + default=8, + help=( + 'number of workers used for data feeders. Do not use more workers ' + 'than cores in the machine.' 
+ ), + metavar='N', + type=int, + ) + parser.add_argument( + "-d", + "--debug", + action="store_true", + default=False, + help="starts debug mode in our program.", + ) + parser.add_argument( + '-p', + '--param', + action='append', + nargs=2, + metavar=('param_name=value'), + help=( + 'Allows the passing of a config param that will overwrite ' + 'anything passed as part of the config file itself.' + ), + default=[], + ) + args = parser.parse_args() + if args.project_name: + # Lazy import to avoid importing unless necessary + import wandb + main( + rerun=args.rerun, + result_dir=args.output_dir, + project_name=args.project_name, + activation_freq=args.activation_freq, + num_workers=args.num_workers, + single_frequency_epochs=args.single_frequency_epochs, + global_params=args.param, + ) diff --git a/experiments/cub_subsample_experiment.py b/experiments/cub_subsample_experiment.py new file mode 100644 index 0000000..c834537 --- /dev/null +++ b/experiments/cub_subsample_experiment.py @@ -0,0 +1,460 @@ +import argparse +import copy +import joblib +import numpy as np +import os +import torch + +from data.CUB200.cub_loader import load_data, find_class_imbalance +from pathlib import Path +from pytorch_lightning import seed_everything + +import cem.experiments.cub_experiments as cub +import cem.train.training as training +import cem.train.utils as utils + + +def main( + rerun=False, + result_dir='results/cub_subsample/', + project_name='', + save_models=True, + activation_freq=0, + single_frequency_epochs=0, + global_params=None, + num_workers=8, +): + seed_everything(42) + # parameters for data, model, and training + og_config = dict( + cv=5, + max_epochs=300, + patience=15, + batch_size=128, + num_workers=num_workers, + emb_size=16, + extra_dims=0, + concept_loss_weight=5, + normalize_loss=False, + learning_rate=0.01, + weight_decay=4e-05, + scheduler_step=20, + weight_loss=True, + pretrain_model=True, + c_extractor_arch="resnet34", + optimizer="sgd", + bool=False, + 
early_stopping_monitor="val_loss", + early_stopping_mode="min", + early_stopping_delta=0.0, + corr_thresh=0.5, + dense_corr_thresh=0.25, + sampling_percent=1, + sampling_percents=[0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 1], + + momentum=0.9, + shared_prob_gen=False, + sigmoidal_prob=False, + sigmoidal_embedding=False, + training_intervention_prob=0.0, + embeding_activation=None, + concat_prob=False, + ) + + train_data_path = os.path.join(cub.BASE_DIR, 'train.pkl') + if og_config['weight_loss']: + og_imbalance = find_class_imbalance(train_data_path, True) + else: + og_imbalance = None + utils.extend_with_global_params(og_config, global_params or []) + + val_data_path = train_data_path.replace('train.pkl', 'val.pkl') + test_data_path = train_data_path.replace('train.pkl', 'test.pkl') + n_concepts, n_tasks = 112, 200 + + os.makedirs(result_dir, exist_ok=True) + joblib.dump( + og_config, + os.path.join(result_dir, f'experiment_config.joblib'), + ) + + if result_dir and activation_freq: + # Then let's save the testing data for further analysis later on + out_acts_save_dir = os.path.join(result_dir, "test_embedding_acts") + Path(out_acts_save_dir).mkdir(parents=True, exist_ok=True) + + results = {} + for sampling_percent in og_config['sampling_percents']: + print( + f"Training model by subsampling {sampling_percent *100}% of " + f"concepts" + ) + results[sampling_percent] = {} + new_n_concepts = int(np.ceil(n_concepts * sampling_percent)) + for split in range(og_config["cv"]): + print( + f'\tExperiment {split+1}/{og_config["cv"]} with sampling ' + f'rate {sampling_percent *100}% and {new_n_concepts} concepts' + ) + results[sampling_percent][f'{split}'] = {} + + # Do the subsampling + selected_concepts_file = os.path.join( + result_dir, + ( + f"selected_concepts_" + f"sampling_{sampling_percent}_fold_{split}.npy" + ), + ) + if (not rerun) and os.path.exists(selected_concepts_file): + selected_concepts = np.load(selected_concepts_file) + else: + if sampling_percent != 1: + 
selected_concepts = np.random.permutation( + n_concepts + )[:new_n_concepts] + else: + # Then simply select them all in their original order + selected_concepts = np.arange(new_n_concepts) + np.save(selected_concepts_file, selected_concepts) + print("\t\tSelected concepts:", selected_concepts) + def subsample_transform(sample): + if isinstance(sample, list): + sample = np.array(sample) + return sample[selected_concepts] + + if og_config['weight_loss']: + imbalance = np.array(og_imbalance)[selected_concepts] + else: + # No class imbalance was computed, so there is nothing to subsample + imbalance = None + + train_dl = load_data( + pkl_paths=[train_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=cub.CUB_DIR, + num_workers=og_config['num_workers'], + concept_transform=subsample_transform, + ) + val_dl = load_data( + pkl_paths=[val_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=cub.CUB_DIR, + num_workers=og_config['num_workers'], + concept_transform=subsample_transform, + ) + test_dl = load_data( + pkl_paths=[test_data_path], + use_attr=True, + no_img=False, + batch_size=og_config['batch_size'], + uncertain_label=False, + n_class_attr=2, + image_dir='images', + resampling=False, + root_dir=cub.CUB_DIR, + num_workers=og_config['num_workers'], + concept_transform=subsample_transform, + ) + + sample = next(iter(train_dl)) + print("Training sample shape is:", sample[0].shape) + print("Training label shape is:", sample[1].shape) + print("Training concept shape is:", sample[2].shape) + + + # train vanilla model with more capacity (i.e., no concept + # supervision) but with ReLU activation + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["extra_name"] = ( + f"NoConceptSupervisionReLU_ExtraCapacity_" + 
f"subsample_{sampling_percent}" + ) + config["sampling_percent"] = sampling_percent + config["bool"] = False + config["extra_dims"] = config['emb_size'] * new_n_concepts + config["bottleneck_nonlinear"] = "relu" + config["concept_loss_weight"] = 0 + extra_vanilla_relu_model, extra_vanilla_relu_test_results = \ + training.train_model( + n_concepts=new_n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[sampling_percent][f'{split}'], + config, + extra_vanilla_relu_model, + extra_vanilla_relu_test_results, + ) + + # fuzzy model + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["extra_name"] = f"Fuzzy_subsample_{sampling_percent}" + config["sampling_percent"] = sampling_percent + config["bool"] = False + config["extra_dims"] = 0 + config["sigmoidal_extra_capacity"] = False + config["sigmoidal_prob"] = True + extra_fuzzy_logit_model, extra_fuzzy_logit_test_results = \ + training.train_model( + n_concepts=new_n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[sampling_percent][f'{split}'], + config, + extra_fuzzy_logit_model, + extra_fuzzy_logit_test_results, + ) + + # train model *without* embeddings but with extra capacity. 
+ config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["bool"] = False + config["extra_dims"] = config['emb_size'] * new_n_concepts + config["sampling_percent"] = sampling_percent + config["extra_name"] = ( + f"FuzzyExtraCapacity_Logit_subsample_{sampling_percent}" + ) + config["sigmoidal_extra_capacity"] = False + config["sigmoidal_prob"] = False + extra_fuzzy_logit_model, extra_fuzzy_logit_test_results = \ + training.train_model( + n_concepts=new_n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[sampling_percent][f'{split}'], + config, + extra_fuzzy_logit_model, + extra_fuzzy_logit_test_results, + ) + + # train model *without* embeddings (concepts are just *Boolean* + # scalars) + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["extra_name"] = f"Bool_subsample_{sampling_percent}" + config["bool"] = True + config["sampling_percent"] = sampling_percent + config["selected_concepts"] = selected_concepts + bool_model, bool_test_results = training.train_model( + n_concepts=new_n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + imbalance=imbalance, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + save_model=save_models, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + ) + training.update_statistics( + results[sampling_percent][f'{split}'], + config, + bool_model, + bool_test_results, + save_model=save_models, + ) + + config = copy.deepcopy(og_config) + config["architecture"] = "MixtureEmbModel" + config["extra_name"] = ( + 
f"SharedProb_AdaptiveDropout_NoProbConcat_" + f"subsample_{sampling_percent}" + ) + config["sampling_percent"] = sampling_percent + config["shared_prob_gen"] = True + config["sigmoidal_prob"] = True + config["sigmoidal_embedding"] = False + config['training_intervention_prob'] = 0.25 + config['concat_prob'] = False + config['emb_size'] = config['emb_size'] + config["embeding_activation"] = "leakyrelu" + mixed_emb_shared_prob_model, mixed_emb_shared_prob_test_results = \ + training.train_model( + n_concepts=new_n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + imbalance=imbalance, + ) + training.update_statistics( + results[sampling_percent][f'{split}'], + config, + mixed_emb_shared_prob_model, + mixed_emb_shared_prob_test_results, + ) + + # save results + joblib.dump(results, os.path.join(result_dir, f'results.joblib')) + + return results + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description=( + 'Runs concept subsampling experiment in CUB dataset.' + ), + ) + parser.add_argument( + '--project_name', + default='', + help=( + "Project name used for Weights & Biases monitoring. If not " + "provided, then we will assume no W&B logging is done." + ), + metavar="name", + + ) + + parser.add_argument( + '--output_dir', + '-o', + default='results/cub_subsample/', + help=( + "directory where we will dump our experiment's results. If not " + "given, then we will use ./results/cub_subsample/." + ), + metavar="path", + + ) + parser.add_argument( + '--rerun', + '-r', + default=False, + action="store_true", + help=( + "If set, then we will force a rerun of the entire experiment even " + "if valid results are found in the provided output directory. 
" + "Note that this may overwrite and previous results, so use with " + "care." + ), + + ) + parser.add_argument( + '--activation_freq', + default=0, + help=( + 'how frequently, in terms of epochs, should we store the ' + 'embedding activations for our validation set. By default we will ' + 'not store any activations.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + '--single_frequency_epochs', + default=0, + help=( + 'how frequently, in terms of epochs, should we store the ' + 'embedding activations for our validation set. By default we will ' + 'not store any activations.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + '--num_workers', + default=12, + help=( + 'number of workers used for data feeders. Do not use more workers ' + 'than cores in the machine.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + "-d", + "--debug", + action="store_true", + default=False, + help="starts debug mode in our program.", + ) + parser.add_argument( + '-p', + '--param', + action='append', + nargs=2, + metavar=('param_name=value'), + help=( + 'Allows the passing of a config param that will overwrite ' + 'anything passed as part of the config file itself.' 
+ ), + default=[], + ) + args = parser.parse_args() + main( + rerun=args.rerun, + result_dir=args.output_dir, + project_name=args.project_name, + activation_freq=args.activation_freq, + num_workers=args.num_workers, + single_frequency_epochs=args.single_frequency_epochs, + global_params=args.param, + ) diff --git a/experiments/synthetic_datasets_experiments.py b/experiments/synthetic_datasets_experiments.py new file mode 100644 index 0000000..58ae727 --- /dev/null +++ b/experiments/synthetic_datasets_experiments.py @@ -0,0 +1,501 @@ +import argparse +import copy +import joblib +import numpy as np +import os +import pytorch_lightning as pl +import torch + +from pathlib import Path +from pytorch_lightning import seed_everything + +import cem.train.training as training +import cem.train.utils as utils + +################################################################################ +## DATASET GENERATORS +################################################################################ + + +def generate_xor_data(size): + # sample from a uniform distribution over [0, 1) + x = np.random.uniform(0, 1, (size, 2)) + c = np.stack([ + x[:, 0] > 0.5, + x[:, 1] > 0.5, + ]).T + y = np.logical_xor(c[:, 0], c[:, 1]) + + x = torch.FloatTensor(x) + c = torch.FloatTensor(c) + y = torch.FloatTensor(y) + return x, c, y + + +def generate_trig_data(size): + h = np.random.normal(0, 2, (size, 3)) + x, y, z = h[:, 0], h[:, 1], h[:, 2] + + # raw features + input_features = np.stack([ + np.sin(x) + x, + np.cos(x) + x, + np.sin(y) + y, + np.cos(y) + y, + np.sin(z) + z, + np.cos(z) + z, + x ** 2 + y ** 2 + z ** 2, + ]).T + + # concepts + concepts = np.stack([ + x > 0, + y > 0, + z > 0, + ]).T + + # task + downstream_task = (x + y + z) > 1 + + input_features = torch.FloatTensor(input_features) + concepts = torch.FloatTensor(concepts) + downstream_task = torch.FloatTensor(downstream_task) + return input_features, concepts, downstream_task + + +def generate_dot_data(size): + # sample from normal distribution + 
emb_size = 2 + v1 = np.random.randn(size, emb_size) * 2 + v2 = np.ones(emb_size) + v3 = np.random.randn(size, emb_size) * 2 + v4 = -np.ones(emb_size) + x = np.hstack([v1+v3, v1-v3]) + c = np.stack([ + np.dot(v1, v2).ravel() > 0, + np.dot(v3, v4).ravel() > 0, + ]).T + y = ((v1*v3).sum(axis=-1) > 0).astype(np.int64) + + x = torch.FloatTensor(x) + c = torch.FloatTensor(c) + y = torch.Tensor(y) + return x, c, y + + +################################################################################ +## MAIN PROGRAM +################################################################################ + +def main( + dataset, + result_dir, + rerun=False, + project_name='', + activation_freq=0, + single_frequency_epochs=0, + global_params=None, +): + seed_everything(42) + # parameters for data, model, and training + og_config = dict( + cv=5, + dataset_size=3000, + max_epochs=500, + patience=15, + batch_size=256, + num_workers=8, + emb_size=128, + extra_dims=0, + concept_loss_weight=1, + normalize_loss=False, + learning_rate=0.01, + weight_decay=0, + scheduler_step=20, + weight_loss=False, + optimizer="adam", + bool=False, + early_stopping_monitor="val_loss", + early_stopping_mode="min", + early_stopping_delta=0.0, + masked=False, + check_val_every_n_epoch=30, + linear_c2y=True, + embeding_activation="leakyrelu", + + momentum=0.9, + shared_prob_gen=False, + sigmoidal_prob=False, + sigmoidal_embedding=False, + training_intervention_prob=0.0, + concat_prob=False, + ) + + if dataset == "xor": + generate_data = generate_xor_data + elif dataset in ["trig", "trigonometry"]: + generate_data = generate_trig_data + elif dataset in ["vector", "dot"]: + generate_data = generate_dot_data + else: + raise ValueError(f"Unsupported dataset {dataset}") + + utils.extend_with_global_params(og_config, global_params or []) + dataset_size = og_config['dataset_size'] + batch_size = og_config["batch_size"] + x, c, y = generate_data(int(dataset_size * 0.7)) + train_data = torch.utils.data.TensorDataset(x, 
y, c) + train_dl = torch.utils.data.DataLoader(train_data, batch_size=batch_size) + dataset = dataset.lower() + + x_test, c_test, y_test = generate_data(int(dataset_size * 0.2)) + test_data = torch.utils.data.TensorDataset(x_test, y_test, c_test) + test_dl = torch.utils.data.DataLoader(test_data, batch_size=batch_size) + + x_val, c_val, y_val = generate_data(int(dataset_size * 0.1)) + val_data = torch.utils.data.TensorDataset(x_val, y_val, c_val) + val_dl = torch.utils.data.DataLoader(val_data, batch_size=batch_size) + + if result_dir and activation_freq: + # Then let's save the testing data for further analysis later on + out_acts_save_dir = os.path.join(result_dir, "test_embedding_acts") + Path(out_acts_save_dir).mkdir(parents=True, exist_ok=True) + np.save(os.path.join(out_acts_save_dir, "x_test.npy"), x_test) + np.save(os.path.join(out_acts_save_dir, "y_test.npy"), y_test) + np.save(os.path.join(out_acts_save_dir, "c_test.npy"), c_test) + np.save(os.path.join(out_acts_save_dir, "x_val.npy"), x_val) + np.save(os.path.join(out_acts_save_dir, "y_val.npy"), y_val) + np.save(os.path.join(out_acts_save_dir, "c_val.npy"), c_val) + + sample = next(iter(train_dl)) + n_features, n_concepts, n_tasks = ( + sample[0].shape[-1], + sample[2].shape[-1], + 1, + ) + + # And make the concept extractor architecture + def c_extractor_arch(output_dim): + return torch.nn.Sequential(*[ + torch.nn.Linear(n_features, 128), + torch.nn.LeakyReLU(), + torch.nn.Linear(128, 128), + torch.nn.LeakyReLU(), + torch.nn.Linear(128, output_dim), + ]) + og_config['c_extractor_arch'] = c_extractor_arch + + # Batches are ordered (x, y, c), matching the TensorDataset above + print("Training sample shape is:", sample[0].shape) + print("Training label shape is:", sample[1].shape) + print("Training concept shape is:", sample[2].shape) + + os.makedirs(result_dir, exist_ok=True) + + results = {} + for split in range(og_config["cv"]): + print(f'Experiment {split+1}/{og_config["cv"]}') + results[f'{split}'] = {} + + # train model *without* embeddings (concepts are just 
*fuzzy* scalars) + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["bool"] = False + config["extra_name"] = "Fuzzy" + config["concept_loss_weight"] = config.get( + "cbm_concept_loss_weight", + config["concept_loss_weight"], + ) + fuzzy_model, fuzzy_test_results = training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + ) + training.update_statistics( + results[f'{split}'], + config, + fuzzy_model, + fuzzy_test_results, + ) + + # train mixture embedding model *with* training-time interventions + # (training_intervention_prob = 0.25) + config = copy.deepcopy(og_config) + config["architecture"] = "MixtureEmbModel" + config["extra_name"] = "SharedProb_AdaptiveDropout_NoProbConcat" + config["shared_prob_gen"] = True + config["sigmoidal_prob"] = True + config["sigmoidal_embedding"] = False + config['training_intervention_prob'] = 0.25 + config['concat_prob'] = False + mixed_emb_shared_prob_model, mixed_emb_shared_prob_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + ) + training.update_statistics( + results[f'{split}'], + config, + mixed_emb_shared_prob_model, + mixed_emb_shared_prob_test_results, + ) + + # train mixture embedding model *without* training-time interventions + # (training_intervention_prob = 0.0) + config = copy.deepcopy(og_config) + config["architecture"] = "MixtureEmbModel" + config["extra_name"] = "SharedProb_Adaptive_NoProbConcat" + config["shared_prob_gen"] = True + config["sigmoidal_prob"] = True + config["sigmoidal_embedding"] = 
False + config['training_intervention_prob'] = 0.0 + config['concat_prob'] = False + config['emb_size'] = config['emb_size'] + mixed_emb_shared_prob_model, mixed_emb_shared_prob_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + ) + training.update_statistics( + results[f'{split}'], + config, + mixed_emb_shared_prob_model, + mixed_emb_shared_prob_test_results, + ) + + # train model *without* embeddings but with extra capacity + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["bool"] = False + config["extra_dims"] = (config['emb_size'] - 1) * n_concepts + config["extra_name"] = "FuzzyExtraCapacity_LogitOnlyProb" + config["bottleneck_nonlinear"] = "leakyrelu" + config["sigmoidal_extra_capacity"] = False + config["sigmoidal_prob"] = True + extra_fuzzy_logit_model, extra_fuzzy_logit_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + ) + training.update_statistics( + results[f'{split}'], + config, + extra_fuzzy_logit_model, + extra_fuzzy_logit_test_results, + ) + + # train vanilla model with more capacity (i.e., no concept supervision) + # but with ReLU activation + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["bool"] = False + config["extra_dims"] = (config['emb_size'] - 1) * n_concepts + config["bottleneck_nonlinear"] = "leakyrelu" + config["extra_name"] = "NoConceptSupervisionReLU_ExtraCapacity" 
+ config["concept_loss_weight"] = 0 + config["sigmoidal_extra_capacity"] = False + config["sigmoidal_prob"] = False + extra_vanilla_relu_model, extra_vanilla_relu_test_results = \ + training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + ) + training.update_statistics( + results[f'{split}'], + config, + extra_vanilla_relu_model, + extra_vanilla_relu_test_results, + ) + + # train model *without* embeddings (concepts are just *Boolean* scalars) + config = copy.deepcopy(og_config) + config["architecture"] = "ConceptBottleneckModel" + config["extra_name"] = "Bool" + config["bool"] = True + if "cbm_bool_concept_loss_weight" in config: + config["concept_loss_weight"] = config[ + "cbm_bool_concept_loss_weight" + ] + else: + config["concept_loss_weight"] = config.get( + "cbm_concept_loss_weight", + config["concept_loss_weight"], + ) + bool_model, bool_test_results = training.train_model( + n_concepts=n_concepts, + n_tasks=n_tasks, + config=config, + train_dl=train_dl, + val_dl=val_dl, + test_dl=test_dl, + split=split, + result_dir=result_dir, + rerun=rerun, + project_name=project_name, + seed=split, + activation_freq=activation_freq, + single_frequency_epochs=single_frequency_epochs, + ) + training.update_statistics( + results[f'{split}'], + config, + bool_model, + bool_test_results, + ) + + # save results + joblib.dump(results, os.path.join(result_dir, f'results.joblib')) + return results + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description=( + 'Runs concept embedding experiment in our synthetic datasets.' + ), + ) + parser.add_argument( + 'dataset', + help=( + "Dataset to be used. One of xor, trig, dot." 
+ ), + metavar="name", + + ) + parser.add_argument( + '--project_name', + default='', + help=( + "Project name used for Weights & Biases monitoring. If not " + "provided, then we will assume we will not be using wandb " + "for logging." + ), + metavar="name", + + ) + + parser.add_argument( + '--output_dir', + '-o', + default='results/synthetic/', + help=( + "Directory where we will dump our experiment's results. If not " + "given, then we will use results/synthetic/." + ), + metavar="path", + + ) + parser.add_argument( + '--rerun', + '-r', + default=False, + action="store_true", + help=( + "If set, then we will force a rerun of the entire experiment even " + "if valid results are found in the provided output directory. " + "Note that this may overwrite any previous results, so use with " + "care." + ), + + ) + parser.add_argument( + '--activation_freq', + default=0, + help=( + 'How frequently, in terms of epochs, should we store the ' + 'embedding activations for our validation set. By default we will ' + 'not store any activations.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + '--single_frequency_epochs', + default=0, + help=( + 'How many epochs we will monitor using an equivalent frequency of 1.' + ), + metavar='N', + type=int, + ) + parser.add_argument( + "-d", + "--debug", + action="store_true", + default=False, + help="Starts debug mode in our program.", + ) + parser.add_argument( + '-p', + '--param', + action='append', + nargs=2, + metavar=('param_name', 'value'), + help=( + 'Allows the passing of a config param that will overwrite ' + 'anything passed as part of the config file itself.' 
+ ), + default=[], + ) + args = parser.parse_args() + main( + dataset=args.dataset, + rerun=args.rerun, + result_dir=args.output_dir, + project_name=args.project_name, + activation_freq=args.activation_freq, + single_frequency_epochs=args.single_frequency_epochs, + global_params=args.param + ) diff --git a/figures/cem.png b/figures/cem.png new file mode 100644 index 0000000..139a2ba Binary files /dev/null and b/figures/cem.png differ
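A note on the `-p/--param` flag defined in the script above: because it combines `action='append'` with `nargs=2`, each occurrence must be given as two separate tokens (`-p emb_size 64`, not `-p emb_size=64`), and `args.param` ends up as a list of `[name, value]` string pairs. The following is a minimal sketch of that behaviour using only `argparse`. How `utils.extend_with_global_params` actually consumes these pairs is not shown in this diff, so `apply_overrides` below is a hypothetical stand-in and not the repository's implementation.

```python
import argparse


def build_parser():
    # Mirrors only the -p/--param flag from the script; other flags omitted.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-p',
        '--param',
        action='append',
        nargs=2,
        metavar=('param_name', 'value'),
        default=[],
    )
    return parser


def apply_overrides(config, params):
    # Hypothetical helper (assumption): copies each (name, value) pair into
    # the config as-is. Values stay strings because argparse does no casting
    # here; the real helper may well convert types.
    for name, value in params:
        config[name] = value
    return config


args = build_parser().parse_args(
    ['-p', 'emb_size', '64', '-p', 'max_epochs', '10']
)
config = apply_overrides({'emb_size': 128}, args.param)
print(args.param)
print(config)
```

Since the values arrive as strings, any override of a numeric setting such as `emb_size` needs an explicit cast somewhere before it reaches the model.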