diff --git a/README.md b/README.md
index 5397bab..ae30169 100644
--- a/README.md
+++ b/README.md
@@ -2,12 +2,9 @@
 [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/mateoespinosa/cem/blob/main/LICENSE) [![Python 3.7+](https://img.shields.io/badge/python-3.7+-green.svg)](https://www.python.org/downloads/release/python-370/) [![CEM Paper](https://img.shields.io/badge/-CEM%20Paper-red)](https://arxiv.org/abs/2209.09056) [![IntCEM Paper](https://img.shields.io/badge/-IntCEM%20Paper-red)](https://arxiv.org/abs/2309.16928) [![Poster](https://img.shields.io/badge/-Poster-yellow)](https://github.com/mateoespinosa/cem/blob/main/media/poster.pdf) [![Slides](https://img.shields.io/badge/-Slides-lightblue)](https://github.com/mateoespinosa/cem/blob/main/media/slides.pptx)
-This repository contains the official Pytorch implementation of our papers
-[*"Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off"*](https://arxiv.org/abs/2209.09056),
-accepted and presented at **NeurIPS 2022**, and [*"Learning to Receive Help: Intervention-Aware Concept Embedding Models"*](https://arxiv.org/abs/2309.16928),
-accepted and presented as a **spotlight paper** at **NeurIPS 2023**.
-
-This first paper was done by [Mateo Espinosa Zarlenga\*](https://mateoespinosa.github.io/),
+This repository contains the official Pytorch implementation of our two papers:
+- [*"Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off"*](https://arxiv.org/abs/2209.09056),
+accepted and presented at **NeurIPS 2022**. This paper was the work of [Mateo Espinosa Zarlenga\*](https://mateoespinosa.github.io/),
 [Pietro Barbiero\*](https://www.pietrobarbiero.eu/),
 [Gabriele Ciravegna](https://sailab.diism.unisi.it/people/gabriele-ciravegna/),
 [Giuseppe Marra](https://www.giuseppemarra.com/),
@@ -19,14 +16,17 @@ This first paper was done by [Mateo Espinosa Zarlenga\*](https://mate
 [Adrian Weller](http://mlg.eng.cam.ac.uk/adrian/),
 [Pietro Lio](https://www.cl.cam.ac.uk/~pl219/),
 and [Mateja Jamnik](https://www.cl.cam.ac.uk/~mj201/).
-
-The second paper was done by [Mateo Espinosa Zarlenga\*](https://mateoespinosa.github.io/),
+- [*"Learning to Receive Help: Intervention-Aware Concept Embedding Models"*](https://arxiv.org/abs/2309.16928),
+accepted and presented as a **spotlight paper** at **NeurIPS 2023**. The paper
+was the work of [Mateo Espinosa Zarlenga\*](https://mateoespinosa.github.io/),
 [Katie Collins](https://collinskatie.github.io/),
 [Krishnamurthy (Dj) Dvijotham](https://dj-research.netlify.app/),
 [Adrian Weller](http://mlg.eng.cam.ac.uk/adrian/),
 [Zohreh Shams](https://zohrehshams.com/),
 and [Mateja Jamnik](https://www.cl.cam.ac.uk/~mj201/)
 
+
+
 ## TL;DR
 
 ### Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off
@@ -161,12 +161,12 @@ import cem
 ## High-level Usage
 In this repository, we include a standalone Pytorch implementation of Concept
 Embedding Models (CEMs), Intervention-aware Concept Embedding Models (IntCEMs),
-and seveeal variants of Concept Bottleneck Models (CBMs).
+and several variants of Concept Bottleneck Models (CBMs).
 These models can be easily trained from scratch given a set of samples annotated
 with a downstream task and a set of binary concepts. In order to use our
 implementation, however, you first need to install all our code's requirements
 (listed in `requirements.txt`) by following the installation instructions
-above,
+above.
 After installation has been completed, you should be able to import `cem` as a
 package and use it to train a **CEM** as follows:
 
@@ -273,7 +273,7 @@ You can import CBMs by including
 from cem.models.cbm import ConceptBottleneckModel
 ```
 
----
+## Class Arguments
 
 Our **CEM module** takes the following initialization arguments:
 - `n_concepts` (int): The number of concepts given at training time.
diff --git a/experiments/configs/celeba.yaml b/experiments/configs/celeba.yaml
index 823b955..a64842b 100644
--- a/experiments/configs/celeba.yaml
+++ b/experiments/configs/celeba.yaml
@@ -8,7 +8,7 @@ shared_params:
  image_size: 64
  num_classes: 1000
  batch_size: 512
- root_dir: /homes/me466/data/
+ root_dir: data/
  use_imbalance: True
  use_binary_vector_class: True
  num_concepts: 6
diff --git a/experiments/configs/cub.yaml b/experiments/configs/cub.yaml
index 6b8c318..aa524b0 100644
--- a/experiments/configs/cub.yaml
+++ b/experiments/configs/cub.yaml
@@ -9,7 +9,7 @@ shared_params:
  batch_size: 256
 
  # DATASET VARIABLES
- root_dir: /auto/homes/me466/data/CUB200/
+ root_dir: data/CUB200/
  sampling_percent: 1
  test_subsampling: 1
  weight_loss: True
diff --git a/experiments/configs/cub_incomplete.yaml b/experiments/configs/cub_incomplete.yaml
index 2e11193..8d63410 100644
--- a/experiments/configs/cub_incomplete.yaml
+++ b/experiments/configs/cub_incomplete.yaml
@@ -9,7 +9,7 @@ shared_params:
  batch_size: 256
 
  # DATASET VARIABLES
- root_dir: /auto/homes/me466/data/CUB200/
+ root_dir: data/CUB200/
  sampling_percent: 0.25 # [IMPORTANT] Select only a quarter of all concepts!
  sampling_groups: True
  test_subsampling: 1
diff --git a/experiments/configs/dot.yaml b/experiments/configs/dot.yaml
index 15e7ddd..8433206 100644
--- a/experiments/configs/dot.yaml
+++ b/experiments/configs/dot.yaml
@@ -8,7 +8,7 @@ shared_params:
  dataset_size: 3000
  batch_size: 256
  num_workers: 8
- root_dir: data/dot
+ root_dir: data/
  sampling_percent: 1
  test_subsampling: 1
 
diff --git a/experiments/configs/mnist_add.yaml b/experiments/configs/mnist_add.yaml
index 04dce5a..ea4f9e3 100644
--- a/experiments/configs/mnist_add.yaml
+++ b/experiments/configs/mnist_add.yaml
@@ -6,7 +6,7 @@ shared_params:
  # Dataset Configuration
  dataset_config:
  dataset: mnist_add
- root_dir: /anfs/bigdisc/me466/
+ root_dir: data/
  num_workers: 8
  batch_size: 2048
  num_operands: 12
diff --git a/experiments/configs/mnist_add_incomplete.yaml b/experiments/configs/mnist_add_incomplete.yaml
index 1314020..cb8c897 100644
--- a/experiments/configs/mnist_add_incomplete.yaml
+++ b/experiments/configs/mnist_add_incomplete.yaml
@@ -6,7 +6,7 @@ shared_params:
  # Dataset Configuration
  dataset_config:
  dataset: mnist_add
- root_dir: /anfs/bigdisc/me466/
+ root_dir: data/
  num_workers: 8
  batch_size: 2048
  num_operands: 12
diff --git a/figures/intcem_white_background.jpg b/figures/intcem_white_background.jpg
new file mode 100644
index 0000000..7484a18
Binary files /dev/null and b/figures/intcem_white_background.jpg differ