# Concept Embedding Models
[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/mateoespinosa/cem/blob/main/LICENSE) [![Python 3.7+](https://img.shields.io/badge/python-3.7+-green.svg)](https://www.python.org/downloads/release/python-370/) [![CEM Paper](https://img.shields.io/badge/-CEM%20Paper-red)](https://arxiv.org/abs/2209.09056) [![IntCEM Paper](https://img.shields.io/badge/-IntCEM%20Paper-red)](https://arxiv.org/abs/2309.16928) [![Poster](https://img.shields.io/badge/-Poster-yellow)](https://github.com/mateoespinosa/cem/blob/main/media/poster.pdf) [![Slides](https://img.shields.io/badge/-Slides-lightblue)](https://github.com/mateoespinosa/cem/blob/main/media/slides.pptx)



This repository contains the official PyTorch implementation of our papers
[*"Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off"*](https://arxiv.org/abs/2209.09056),
accepted and presented at **NeurIPS 2022**, and [*"Learning to Receive Help: Intervention-Aware Concept Embedding Models"*](https://arxiv.org/abs/2309.16928),
accepted and presented as a **spotlight paper** at **NeurIPS 2023**.

The first paper was authored by [Mateo Espinosa Zarlenga<sup>\*</sup>](https://mateoespinosa.github.io/),
[Pietro Barbiero<sup>\*</sup>](https://www.pietrobarbiero.eu/),
[Gabriele Ciravegna](https://sailab.diism.unisi.it/people/gabriele-ciravegna/),
[Giuseppe Marra](https://www.giuseppemarra.com/),
Francesco Giannini,
Michelangelo Diligenti,
[Zohreh Shams](https://zohrehshams.com/),
[Frederic Precioso](https://www.i3s.unice.fr/~precioso/),
[Stefano Melacci](https://scholar.google.com/citations?user=_HHu1MQAAAAJ&hl=en),
[Adrian Weller](http://mlg.eng.cam.ac.uk/adrian/),
[Pietro Lio](https://www.cl.cam.ac.uk/~pl219/), and
[Mateja Jamnik](https://www.cl.cam.ac.uk/~mj201/).

The second paper was authored by [Mateo Espinosa Zarlenga<sup>\*</sup>](https://mateoespinosa.github.io/),
[Katie Collins](https://collinskatie.github.io/),
[Krishnamurthy (Dj) Dvijotham](https://dj-research.netlify.app/),
[Adrian Weller](http://mlg.eng.cam.ac.uk/adrian/),
[Zohreh Shams](https://zohrehshams.com/), and
[Mateja Jamnik](https://www.cl.cam.ac.uk/~mj201/).

## TL;DR

### Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off

![CEM Architecture](figures/cem_white_background.png)


We propose **Concept Embedding Models (CEMs)**, a novel family of concept-based
interpretable neural architectures that can achieve high task performance while
remaining robust to the lack of complete concept annotations in the task of
interest. CEMs allow effective test-time concept interventions, enabling them
to drastically improve their task performance in a human-in-the-loop setting.

### Learning to Receive Help: Intervention-Aware Concept Embedding Models

![IntCEM Architecture](figures/intcem_white_background.jpg)

In this paper we argue that previous concept-based architectures, including
CEMs to some extent, lack training incentives that encourage them to improve
their performance when they are intervened on. Yet, at test time we expect
these models to correctly take in expert feedback and improve their performance.
Here, we propose **Intervention-aware Concept Embedding Models (IntCEMs)** as
an alternative training framework for CEMs that exposes CEMs to informative
trajectories of interventions during training, so that the model incurs a
larger penalty when it mispredicts its task after several interventions have
been performed than when no interventions are performed.
Through this process, we concurrently learn a CEM that is more receptive to
test-time interventions and an intervention policy that suggests which concepts
one should intervene on next to yield the largest decrease in the model's
uncertainty.

## Abstracts

### Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off (NeurIPS 2022)

Deploying AI-powered systems requires trustworthy models supporting effective
human interactions, going beyond raw prediction accuracy. Concept bottleneck
models promote trustworthiness by conditioning classification tasks on an
intermediate level of human-like concepts. This enables human interventions
which can correct mispredicted concepts to improve the model's performance.
However, existing concept bottleneck models are unable to find optimal
compromises between high task accuracy, robust concept-based explanations, and
effective interventions on concepts, particularly in real-world conditions
where complete and accurate concept supervisions are scarce. To address this,
we propose Concept Embedding Models, a novel family of concept bottleneck
models which goes beyond the current accuracy-vs-interpretability trade-off by
learning interpretable high-dimensional concept representations. Our
experiments demonstrate that Concept Embedding Models (1) attain better or
competitive task accuracy w.r.t. standard concept bottleneck models, (2)
provide concept representations capturing meaningful semantics including and
beyond their ground truth labels, (3) support test-time concept interventions
whose effect in test accuracy surpasses that in standard concept bottleneck
models, and (4) scale to real-world conditions where complete concept
supervisions are scarce.

### Learning to Receive Help: Intervention-Aware Concept Embedding Models (NeurIPS 2023)

Concept Bottleneck Models (CBMs) tackle the opacity of neural architectures by
constructing and explaining their predictions using a set of high-level
concepts. A special property of these models is that they permit concept
interventions, wherein users can correct mispredicted concepts and thus improve
the model's performance. Recent work, however, has shown that intervention
efficacy can be highly dependent on the order in which concepts are intervened
on and on the model's architecture and training hyperparameters. We argue that
this is rooted in a CBM's lack of train-time incentives for the model to be
appropriately receptive to concept interventions. To address this, we propose
Intervention-aware Concept Embedding models (IntCEMs), a novel CBM-based
architecture and training paradigm that improves a model's receptiveness to
test-time interventions. Our model learns a concept intervention policy in an
end-to-end fashion from where it can sample meaningful intervention trajectories
at train-time. This conditions IntCEMs to effectively select and receive concept
interventions when deployed at test-time. Our experiments show that IntCEMs
significantly outperform state-of-the-art concept-interpretable models when
provided with test-time concept interventions, demonstrating the effectiveness
of our approach.

# Models

[Concept Bottleneck Models (CBMs)](https://arxiv.org/abs/2007.04612) have recently gained attention as
high-performing and interpretable neural architectures that can explain their
predictions using a set of human-understandable high-level concepts. However,
CBMs typically trade task performance for interpretability, as their bottleneck
constrains how much information the label predictor can use. **Concept
Embedding Models (CEMs)** go beyond this trade-off by learning, for each
concept, a pair of high-dimensional embeddings: one representing the concept
when it is active and one representing it when it is inactive. Each concept's
predicted probability is used to mix this pair into a single per-concept
embedding that is passed to the label predictor. This design means users can
switch between these two produced embeddings at test-time to then intervene in
a concept and improve downstream performance. Our entire architecture is
visualized in the figure above and formally described in our paper.
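
As a rough sketch of this mixing step (illustrative only; the function and
variable names below are our own and not the library's internals):

```python
import torch

def mix_concept_embeddings(c_pos_emb, c_neg_emb, c_prob):
    # c_pos_emb: (batch, n_concepts, emb_size) embeddings for active concepts
    # c_neg_emb: (batch, n_concepts, emb_size) embeddings for inactive concepts
    # c_prob:    (batch, n_concepts) predicted concept probabilities in [0, 1]
    p = c_prob.unsqueeze(-1)  # broadcast over the embedding dimension
    return p * c_pos_emb + (1.0 - p) * c_neg_emb

# A test-time intervention on concept i clamps c_prob[:, i] to its ground
# truth value (1.0 or 0.0) before mixing, which swaps in the corresponding
# active/inactive embedding and propagates the correction downstream.
```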


# Installation

You can locally install this package by first cloning this repository and then
installing it together with its dependencies. Once this is done, you should be
able to import `cem` directly from your Python shell.
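
The following is a minimal sketch of that flow (assuming the repository's
standard `setup.py`-based install):

```bash
$ git clone https://github.com/mateoespinosa/cem
$ cd cem
$ python setup.py install
```

You can then verify that the installation succeeded by running:

```python
import cem
```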

## High-level Usage
In this repository, we include a standalone PyTorch implementation of Concept
Embedding Models (CEMs), Intervention-aware Concept Embedding Models (IntCEMs),
and several variants of Concept Bottleneck Models (CBMs).
These models can be easily trained from scratch given a set of samples annotated
with a downstream task and a set of binary concepts. In order to use our
implementation, however, you first need to install all our code's requirements
(listed in `requirements.txt`) by following the installation instructions
above.
After installation has been completed, you should be able to import
`cem` as a package and use it to train a CEM as follows:
`cem` as a package and use it to train a **CEM** as follows:

```python
import pytorch_lightning as pl
from cem.models.cem import ConceptEmbeddingModel

#####
# Define your pytorch dataset objects
#####

train_dl = ...
val_dl = ...

#####
# Construct the model
#####

cem_model = ConceptEmbeddingModel(
    n_concepts=n_concepts,  # Number of training-time concepts
    n_tasks=n_tasks,  # Number of output labels
    emb_size=16,
    concept_loss_weight=0.1,
    learning_rate=1e-3,
    optimizer="adam",
    c_extractor_arch=latent_code_generator_model,  # Replace this appropriately
    training_intervention_prob=0.25,  # RandInt probability
)

#####
# Train it
#####

trainer = pl.Trainer(
    accelerator="gpu",  # or "cpu" if no GPU available
    devices="auto",
    max_epochs=100,
    check_val_every_n_epoch=5,
)
# train_dl and val_dl are datasets previously built...
trainer.fit(cem_model, train_dl, val_dl)
```

Similarly, you can train an **IntCEM** as follows:

```python
import pytorch_lightning as pl
from cem.models.intcbm import IntAwareConceptEmbeddingModel

#####
# Define your pytorch dataset objects
#####

train_dl = ...
val_dl = ...

#####
# Construct the model
#####

intcem_model = IntAwareConceptEmbeddingModel(
    n_concepts=n_concepts,  # Number of training-time concepts
    n_tasks=n_tasks,  # Number of output labels
    emb_size=16,
    concept_loss_weight=0.1,
    learning_rate=1e-3,
    optimizer="adam",
    c_extractor_arch=latent_code_generator_model,  # Replace this appropriately
    training_intervention_prob=0.25,  # RandInt probability
    intervention_task_discount=1.1,  # Penalty factor "gamma" for misprediction after interventions
    intervention_weight=1,  # The weight lambda_roll of the intervention loss in IntCEM's training objective
)

#####
# Train it
#####

trainer = pl.Trainer(
    accelerator="gpu",  # or "cpu" if no GPU available
    devices="auto",
    max_epochs=100,
    check_val_every_n_epoch=5,
)
# train_dl and val_dl are datasets previously built...
trainer.fit(intcem_model, train_dl, val_dl)
```

For a **full example** showing how to generate a dataset and configure a CEM
**for training on your own custom dataset**, please see our [Dot example notebook](https://github.com/mateoespinosa/cem/blob/main/examples/dot_cem_train_walkthrough.ipynb)
for a step-by-step walkthrough on how to set things up for your own work.
The same setup used in this notebook can be used for training IntCEMs or
CBMs as defined in this library.

## Included Models

Besides CEMs and IntCEMs, this repository also includes a PyTorch implementation of
Concept Bottleneck Models (CBMs) and several of their variants.
These models should be trainable out of the box if one follows the same steps
used for training a CEM/IntCEM as indicated above.

You can import CBMs by including
```python
from cem.models.cbm import ConceptBottleneckModel
```
in your Python source file.

---

Our **CEM module** takes the following initialization arguments (among others):
- `n_concepts` (int): The number of concepts given at training time.
- `n_tasks` (int): The number of output labels/tasks.
- `emb_size` (int): The size of each concept's embedding.
- `concept_loss_weight` (float): The weight of the concept-prediction loss in
the model's overall training objective.
- `learning_rate` (float): The learning rate to use during training.
- `optimizer` (str): The optimizer to use during training (e.g., "adam").
- `c_extractor_arch` (Callable): The backbone model used to generate the
latent code from which concept embeddings are constructed.
- `training_intervention_prob` (float): The probability of performing random
training-time interventions (RandInt).
- `active_intervention_values` (List[float]): The values used for each concept
when intervened on as active (i.e., intervening on concept c_i to 1 would
imply setting its corresponding predicted concept to
active_intervention_values[i]). If not given, then we will assume that we use
`1` for all concepts.
- `inactive_intervention_values` (List[float]): The values used for each
concept when intervened on as inactive (i.e., intervening on
concept c_i to 0 would imply setting its corresponding
predicted concept to inactive_intervention_values[i]). If not given,
then we will assume that we use `0` for all concepts.
- `intervention_policy` (an instance of InterventionPolicy as defined in cem.interventions.intervention_policy):
An optional intervention policy to be used when intervening on a test
batch sample x (first argument), with corresponding true concepts c
(second argument), and true labels y (third argument). The policy must
select, for each sample, the concepts to be intervened on.
- `top_k_accuracy` (List[int]): An optional list of k values for which we will
log the model's top-k task accuracy. Useful metrics to keep track of
during training/testing when the number of tasks is high.
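
As an illustration, a minimal custom policy could look as follows (a sketch
only; the exact `InterventionPolicy` interface in
`cem.interventions.intervention_policy` may differ, so treat the signature
below as an assumption):

```python
import numpy as np

class RandomInterventionPolicy:
    """Intervene on a fixed number of randomly chosen concepts per sample."""

    def __init__(self, n_concepts, num_interventions=5):
        self.n_concepts = n_concepts
        self.num_interventions = num_interventions

    def __call__(self, x, c, y):
        # Return a binary mask with shape (batch_size, n_concepts) selecting
        # the concepts whose predictions will be replaced by their ground
        # truth values.
        mask = np.zeros((x.shape[0], self.n_concepts), dtype=np.int64)
        for i in range(x.shape[0]):
            chosen = np.random.choice(
                self.n_concepts, self.num_interventions, replace=False
            )
            mask[i, chosen] = 1
        return mask
```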


Notice that the **[CBM module](https://github.com/mateoespinosa/cem/blob/main/cem/models/cbm.py) takes similar arguments**, albeit with some extra ones
and some whose semantics differ slightly (e.g., x2c_model goes directly
from the input to the bottleneck).

Similarly, the **[IntCEM module](https://github.com/mateoespinosa/cem/blob/main/cem/models/intcbm.py)**
takes the same arguments as its CEM counterpart with the following additional
arguments:
- `intervention_task_discount` (float): The penalty applied for mispredicting
the task label after some interventions have been performed, relative to
mispredicting it when no interventions have been performed. This is what we
call `gamma` in the paper. Defaults to 1.1.
- `intervention_weight` (float): Weight to be used for the intervention policy
loss during training. This is what we call `lambda_roll` in the paper.
Defaults to 1.
- `concept_map` (Dict[Any, List[int]]): A map between concept group names (e.g.,
"wing_colour") and the list of concept indices that make up the group.
If concept groups are known, and we want to learn an intervention policy that
acts on groups rather than on individual concepts, then this dictionary
should be provided both before training and at intervention time (see the
example after this list). Defaults to every concept being its own group.
- `use_concept_groups` (bool): Set this to true if `concept_map` is provided and
you want interventions to be done on entire groups of concepts at a time
rather than on individual concepts. Defaults to True.
- `num_rollouts` (int): The number of Monte Carlo rollouts we will perform when
sampling trajectories (i.e., the number of trajectories one will sample per
training step). Defaults to 1.
- `max_horizon` (int): The final maximum number of interventions to be made on
a single training trajectory. Defaults to 6.
- `initial_horizon` (int): The initial maximum number of interventions to be
made on a single training trajectory. Defaults to 2.
- `horizon_rate` (float): The factor by which the maximum trajectory horizon
`T_max` grows on every training step. The horizon starts at `initial_horizon`
and increases by a factor of `horizon_rate` per training step until it reaches
`max_horizon`. Defaults to 1.005.
- `int_model_layers` (List[int]): The sizes of the layers to be used for the
MLP forming IntCEM's learnable policy `psi`. Defaults to
- `int_model_use_bn` (bool): Whether or not we use batch normalization between
each layer of the intervention policy model `psi`. Defaults to True.
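
For instance, a hypothetical `concept_map` grouping bird-related concepts
could look as follows (the group names and indices here are purely
illustrative):

```python
concept_map = {
    "wing_colour": [0, 1, 2],  # indices of all wing-colour concepts
    "beak_shape": [3, 4],
    "size": [5, 6, 7],
}
```

This dictionary can then be passed to `IntAwareConceptEmbeddingModel` via its
`concept_map` argument, together with `use_concept_groups=True`, so that
sampled intervention trajectories act on whole groups rather than on
individual concepts.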

# Experiment Reproducibility

Before running any experiments, make sure the datasets you wish to use have
been downloaded and that their locations are correctly indicated in the
corresponding experiment config file or via the `DATASET_DIR` environment
variable.

## Running Experiments

To reproduce the experiments discussed in our papers, please use the scripts
in the `experiments` directory after installing the `cem` package as
indicated above. For example, to run our experiments on the DOT dataset
(see our paper), you can execute the following command:

```bash
$ python experiments/run_experiments.py -c experiments/configs/dot.yaml
```
This should generate a summary of all the results after execution has
terminated in the form of a table and should dump all results/trained
models/logs into the given output directory (`dot_results/` in this case).

Similarly, you can recreate our `CUB`, `CelebA`, and `MNIST-Add` experiments (or
those on any other synthetic dataset) by running

```bash
$ python experiments/run_experiments.py -c experiments/configs/{cub/celeba/mnist_add}.yaml
```

These scripts will also run the intervention experiments and generate the test
accuracies for all models as one intervenes on an increasing number of concepts.
Once an experiment is over, the script will dump all testing results/statistics
summarized over 5 random initializations in a single `results.joblib` dictionary
which you can then analyze.

At its top level, this dictionary is indexed by seed id (e.g., '0', '1', etc.).
Each seed id then maps to a dictionary from run names (e.g., 'CEM') to a
dictionary containing all computed test and validation metrics.
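
For example, a minimal sketch of how one could inspect this dictionary (the
output path and the `test_acc_y` metric name below are illustrative
assumptions, not guaranteed keys):

```python
import joblib

results = joblib.load("dot_results/results.joblib")
for seed, runs in results.items():  # top level: seed ids (e.g., '0', '1', ...)
    for run_name, metrics in runs.items():  # e.g., 'CEM' -> metrics dict
        # "test_acc_y" is a hypothetical metric key used for illustration
        print(seed, run_name, metrics.get("test_acc_y"))
```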


# Citation

If you would like to cite this repository, or the accompanying papers, we would
appreciate it if you use the following citations for the CEM and IntCEM papers
(although if IntCEMs are not used, only the CEM citation is necessary):

### CEM Paper
```
@article{EspinosaZarlenga2022cem,
title={Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off},
author={
Espinosa Zarlenga, Mateo and Barbiero, Pietro and Ciravegna, Gabriele and
Marra, Giuseppe and Giannini, Francesco and Diligenti, Michelangelo and
Shams, Zohreh and Precioso, Frederic and Melacci, Stefano and
Weller, Adrian and Lio, Pietro and Jamnik, Mateja
},
journal={Advances in Neural Information Processing Systems},
volume={35},
year={2022}
}
```

### IntCEM Paper
```
@article{EspinosaZarlenga2023intcem,
title={Learning to Receive Help: Intervention-Aware Concept Embedding Models},
author={
Espinosa Zarlenga, Mateo and Collins, Katie and Dvijotham,
Krishnamurthy and Weller, Adrian and Shams, Zohreh and Jamnik, Mateja
},
journal={Advances in Neural Information Processing Systems},
volume={36},
year={2023}
}
```
