diff --git a/.gitignore b/.gitignore index fec8e55f..57036226 100644 --- a/.gitignore +++ b/.gitignore @@ -43,6 +43,8 @@ coverage.xml .pytest_cache/ cover/ *test*.sh +tests/concepts/checkpoints/ +tests/FreeMono.ttf # Environments .env diff --git a/README.md b/README.md index 56fc5746..39582f56 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,9 @@ + + + PyLint @@ -41,8 +44,13 @@ Feature Visualization · Metrics + . + Example-based

+> [!IMPORTANT] +> With the release of Keras 3.X since TensorFlow 2.16, some methods may not function as expected. We are actively working on a fix. In the meantime, we recommend using TensorFlow 2.15 or earlier versions for optimal compatibility. + The library is composed of several modules, the _Attributions Methods_ module implements various methods (e.g Saliency, Grad-CAM, Integrated-Gradients...), with explanations, examples and links to official papers. The _Feature Visualization_ module allows to see how neural networks build their understanding of images by finding inputs that maximize neurons, channels, layers or compositions of these elements. The _Concepts_ module allows you to extract human concepts from a model and to test their usefulness with respect to a class. @@ -54,6 +62,9 @@ Finally, the _Metrics_ module covers the current metrics used in explainability.
+> [!NOTE] +> We are proud to announce the release of the _Example-based_ module! This module is dedicated to methods that explain a model by retrieving relevant examples from a dataset. It includes methods that belong to different families: similar examples, contrastive (counter-factuals and semi-factuals) examples, and prototypes (as concepts based methods have a dedicated sections). + ## 🔥 Tutorials
@@ -110,6 +121,8 @@ Finally, the _Metrics_ module covers the current metrics used in explainability.

+- [**Example-based Methods**: Getting started](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) + You can find a certain number of [**other practical tutorials just here**](https://github.com/deel-ai/xplique/blob/master/TUTORIALS.md). This section is actively developed and more contents will be included. We will try to cover all the possible usage of the library, feel free to contact us if you have any suggestions or recommendations towards tutorials you would like to see. @@ -361,6 +374,28 @@ TF : Tensorflow compatible
+Even though we are only at the early stages, we have also recently added an [Example-based methods](api/example_based/api_example_based/) module. Do not hesitate to give us feedback! Currently, the methods available are summarized in the following table: + +
+Table of example-based methods available + +| Method | Family | Documentation | Tutorial | +| --- | --- | --- | --- | +| `SimilarExamples` | Similar Examples | [SimilarExamples](../similar_examples/similar_examples/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | +| `Cole` | Similar Examples | [Cole](../similar_examples/cole/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | +| | | | +| `NaiveCounterFactuals` | Counter Factuals | [NaiveCounterFactuals](../counterfactuals/naive_counter_factuals/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | +| `LabelAwareCounterFactuals` | Counter Factuals | [LabelAwareCounterFactuals](../counterfactuals/label_aware_counter_factuals/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | +|||| +| `KLEORSimMiss` | Semi Factuals | [KLEOR](../semifactuals/kleor/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | +| `KLEORGlobalSim` | Semi Factuals | [KLEOR](../semifactuals/kleor/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | +|||| +| `ProtoGreedy` | Prototypes | [ProtoGreedy](../prototypes/proto_greedy/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OI3oa884GwGbXlzn3Y9NH-1j4cSaQb0w) | +| `ProtoDash` | Prototypes | [ProtoDash](../prototypes/proto_dash/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OI3oa884GwGbXlzn3Y9NH-1j4cSaQb0w) | +| `MMDCritic` | Prototypes | [MMDCritic](../prototypes/mmd_critic/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OI3oa884GwGbXlzn3Y9NH-1j4cSaQb0w) | + +
+ ## 👍 Contributing Feel free to propose your ideas or come and contribute with us on the Xplique toolbox! We have a specific document where we describe in a simple way how to make your first pull request: [just here](https://github.com/deel-ai/xplique/blob/master/CONTRIBUTING.md). diff --git a/TUTORIALS.md b/TUTORIALS.md index 759964ec..e72aaf95 100644 --- a/TUTORIALS.md +++ b/TUTORIALS.md @@ -20,6 +20,8 @@ Here is the lists of the available tutorial for now: | Metrics | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1WEpVpFSq-oL1Ejugr8Ojb3tcbqXIOPBg) | | Concept Activation Vectors | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1iuEz46ZjgG97vTBH8p-vod3y14UETvVE) | | Feature Visualization | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1st43K9AH-UL4eZM1S4QdyrOi7Epa5K8v) | +| Example-Based Methods | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | +| Prototypes | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OI3oa884GwGbXlzn3Y9NH-1j4cSaQb0w) | ## Attributions @@ -74,3 +76,10 @@ Here is the lists of the available tutorial for now: | :------------------------------------- | :-----------------------------------------------------------------------------------------------------------------------------------------------------: | | Feature Visualization: Getting started | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1st43K9AH-UL4eZM1S4QdyrOi7Epa5K8v) | | Modern Feature Visualization: MaCo | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1l0kag1o-qMY4NCbWuAwnuzkzd9sf92ic) | + +## Example-Based Methods + +| **Tutorial Name** | Notebook | +| :------------------------------------- | :-----------------------------------------------------------------------------------------------------------------------------------------------------: | +| Example-Based Methods: Getting started | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | +| Example-based: Prototypes | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OI3oa884GwGbXlzn3Y9NH-1j4cSaQb0w) | \ No newline at end of file diff --git a/docs/api/example_based/api_example_based.md b/docs/api/example_based/api_example_based.md new file mode 100644 index 00000000..6fbb3b8a --- /dev/null +++ b/docs/api/example_based/api_example_based.md @@ -0,0 +1,129 @@ +# API: Example-based + +- [**Example-based Methods**: Getting started](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) +- [**Example-based: Prototypes**](https://colab.research.google.com/drive/1OI3oa884GwGbXlzn3Y9NH-1j4cSaQb0w) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OI3oa884GwGbXlzn3Y9NH-1j4cSaQb0w) + +## Context ## + +!!! quote + While saliency maps have stolen the show for the last few years in the XAI field, their ability to reflect models' internal processes has been questioned. Although less in the spotlight, example-based XAI methods have continued to improve. It encompasses methods that use samples as explanations for a machine learning model's predictions. This aligns with the psychological mechanisms of human reasoning and makes example-based explanations natural and intuitive for users to understand. Indeed, humans learn and reason by forming mental representations of concepts based on examples. + + -- [Natural Example-Based Explainability: a Survey (2023)](https://arxiv.org/abs/2309.03234)[^1] + +As mentioned by our team members in the quote above, example-based methods are an alternative to saliency maps and can be more aligned with some users' expectations. Thus, we have been working on implementing some of those methods in Xplique that have been put aside in the previous developments. + +While not being exhaustive we tried to cover a range of methods that are representative of the field and that belong to different families: similar examples, contrastive (counter-factuals and semi-factuals) examples, and prototypes (as concepts based methods have a dedicated sections). + +At present, we made the following choices: +- Focus on methods that are natural example methods (post-hoc and non-generative, see the paper above for more details). +- Try to unify the four families of approaches with a common API. + +!!! info + We are in the early stages of development and are looking for feedback on the API design and the methods we have chosen to implement. Also, we are counting on the community to furnish the collection of methods available. If you are willing to contribute reach us on the [GitHub](https://github.com/deel-ai/xplique) repository (with an issue, pull request, ...). + +## Common API ## + +```python +projection = ProjectionMethod(model) + +explainer = ExampleMethod( + cases_dataset=cases_dataset, + k=k, + projection=projection, + case_returns=case_returns, + distance=distance, +) + +outputs_dict = explainer.explain(inputs, targets) +``` + +We tried to keep the API as close as possible to the one of the attribution methods to keep a consistent experience for the users. + +The `BaseExampleMethod` is an abstract base class designed for example-based methods used to explain classification models. It provides examples from a dataset (usually the training dataset) to help understand a model's predictions. Examples are projected from the input space to a search space using a [projection function](#projections). The projection function defines the search space. Then, examples are selected using a [search method](#search-methods) within the search space. For all example-based methods, one can define the `distance` that will be used by the search method. + +We can broadly categorize example-based methods into four families: similar examples, counter-factuals, semi-factuals, and prototypes. + +- **Similar Examples**: This method involves finding instances in the dataset that are similar to a given instance. The similarity is often determined based on the feature space, and these examples can help in understanding the model's decision by showing what other data points resemble the instance in question. +- **Counter Factuals**: Counterfactual explanations identify the minimal changes needed to an instance's features to change the model's prediction to a different, specified outcome. They help answer "what-if" scenarios by showing how altering certain aspects of the input would lead to a different decision. +- **Semi Factuals**: Semifactual explanations describe hypothetical situations where most features of an instance remain the same except for one or a few features, without changing the overall outcome. They highlight which features could vary without altering the prediction. +- **Prototypes**: Prototypes are representative examples from the dataset that summarize typical cases within a certain category or cluster. They act as archetypal instances that the model uses to make predictions, providing a reference point for understanding model behavior. Additional documentation can be found in the [Prototypes API documentation](../prototypes/api_prototypes/). + +??? abstract "Table of example-based methods available" + + | Method | Family | Documentation | Tutorial | + | --- | --- | --- | --- | + | `SimilarExamples` | Similar Examples | [SimilarExamples](../similar_examples/similar_examples/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | + | `Cole` | Similar Examples | [Cole](../similar_examples/cole/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | + | | | | + | `NaiveCounterFactuals` | Counter Factuals | [NaiveCounterFactuals](../counterfactuals/naive_counter_factuals/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | + | `LabelAwareCounterFactuals` | Counter Factuals | [LabelAwareCounterFactuals](../counterfactuals/label_aware_counter_factuals/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | + |||| + | `KLEORSimMiss` | Semi Factuals | [KLEOR](../semifactuals/kleor/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | + | `KLEORGlobalSim` | Semi Factuals | [KLEOR](../semifactuals/kleor/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | + |||| + | `ProtoGreedy` | Prototypes | [ProtoGreedy](../prototypes/proto_greedy/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OI3oa884GwGbXlzn3Y9NH-1j4cSaQb0w) | + | `ProtoDash` | Prototypes | [ProtoDash](../prototypes/proto_dash/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OI3oa884GwGbXlzn3Y9NH-1j4cSaQb0w) | + | `MMDCritic` | Prototypes | [MMDCritic](../prototypes/mmd_critic/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OI3oa884GwGbXlzn3Y9NH-1j4cSaQb0w) | + +### Parameters ### + +`DatasetOrTensor = Union[tf.Tensor, np.ndarray, "torch.Tensor", tf.data.Dataset, "torch.utils.data.DataLoader"]` + +- **cases_dataset** (`DatasetOrTensor`): The dataset used to train the model, examples are extracted from this dataset. All datasets (cases, labels, and targets) should be of the same type. Supported types are: `tf.data.Dataset`, `torch.utils.data.DataLoader`, `tf.Tensor`, `np.ndarray`, `torch.Tensor`. For datasets with multiple columns, the first column is assumed to be the cases. While the second column is assumed to be the labels, and the third the targets. Warning: datasets tend to reshuffle at each iteration, ensure the datasets are not reshuffle as we use index in the dataset. +- **labels_dataset** (`Optional[DatasetOrTensor]`): Labels associated with the examples in the cases dataset. It should have the same type as `cases_dataset`. +- **targets_dataset** (`Optional[DatasetOrTensor]`): Targets associated with the `cases_dataset` for dataset projection, often the one-hot encoding of a model's predictions. See `projection` for detail. It should have the same type as `cases_dataset`. It is not be necessary for all projections. Furthermore, projections which requires it compute it internally by default. +- **k** (`int`): The number of examples to retrieve per input. +- **projection** (`Union[Projection, Callable]`): A projection or callable function that projects samples from the input space to the search space. The search space should be relevant for the model. (see [Projections](#projections)) +- **case_returns** (`Union[List[str], str]`): Elements to return in `self.explain()`. Default is `"examples"`. `"all"` indicates that every possible output should be returned. +- **batch_size** (`Optional[int]`): Number of samples processed simultaneously for projection and search. Ignored if `cases_dataset` is a batched `tf.data.Dataset` or a batched `torch.utils.data.DataLoader` is provided. + +!!!tips + If the elements of your dataset are tuples (cases, labels), you can pass this dataset directly to the `cases_dataset`. + +!!!tips + Apart from contrastive explanations, in the case of classification, the built-in [Projections](#projections) compute `targets` online and the `targets_dataset` is not necessary. + +### Properties ### + +- **search_method_class** (`Type[BaseSearchMethod]`): Abstract property to define the search method class to use. Must be implemented in subclasses. (see [Search Methods](#search-methods)) +- **k** (`int`): Getter and setter for the `k` parameter. +- **returns** (`Union[List[str], str]`): Getter and setter for the `returns` parameter. Defines the elements to return in `self.explain()`. + +### `explain(self, inputs, targets)` ### + +Returns the relevant examples to explain the (inputs, targets). Projects inputs using `self.projection` and finds examples using the `self.search_method`. + +- **inputs** (`Union[tf.Tensor, np.ndarray]`): Input samples to be explained. Shape: (n, ...) where n is the number of samples. +- **targets** (`Optional[Union[tf.Tensor, np.ndarray]]`): Targets associated with the `inputs` for projection. Shape: (n, nb_classes) where n is the number of samples and nb_classes is the number of classes. Not used in all projection. Used in contrastive methods to know the predicted classes of the provided samples. + +**Returns:** Dictionary with elements listed in `self.returns`. + +!!!info + The `__call__` method is an alias for the `explain` method. + +## Projections ## +Projections are functions that map input samples to a search space where examples are retrieved with a `search_method`. The search space should be relevant for the model (e.g. projecting the inputs into the latent space of the model). + +!!!info + If one decides to use the identity function as a projection, the search space will be the input space, thus rather explaining the dataset than the model. + +The `Projection` class is a base class for projections. It involves two parts: `space_projection` and `weights`. The samples are first projected to a new space and then weighted. + +!!!warning + If both parts are `None`, the projection acts as an identity function. In general, we advise that one part should involve the model to ensure meaningful distance calculations with respect to the model. + +To know more about projections and their importance, you can refer to the [Projections](../../projections/) section. + +## Search Methods ## + +!!!info + The search methods are hidden to the user and only used internally. However, they help to understand how the API works. + +Search methods are used to retrieve examples from the `cases_dataset` that are relevant to the input samples. + +!!!warning + In an search method, the `cases_dataset` is the dataset that has been projected with a `Projection` object (see the previous section). The search methods are used to find examples in this projected space. + +Each example-based method has its own search method. The search method is defined in the `search_method_class` property of the `ExampleMethod` class. + +[^1]: [Natural Example-Based Explainability: a Survey (2023)](https://arxiv.org/abs/2309.03234) \ No newline at end of file diff --git a/docs/api/example_based/counterfactuals/label_aware_counter_factuals.md b/docs/api/example_based/counterfactuals/label_aware_counter_factuals.md new file mode 100644 index 00000000..95989d53 --- /dev/null +++ b/docs/api/example_based/counterfactuals/label_aware_counter_factuals.md @@ -0,0 +1,68 @@ +# Label Aware Counterfactuals + + + + [View colab tutorial](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | + + + [View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/counterfactuals.py) | +📰 [Paper](https://www.semanticscholar.org/paper/Nearest-unlike-neighbor-(NUN)%3A-an-aid-to-decision-Dasarathy/48c1a310f655b827e5e7d712c859b25a4e3c0902) + +!!!note + The paper referenced here is not exactly the one we implemented. However, it is probably the closest in essence of what we implemented. + +In contrast to the [Naive Counterfactuals](../../counterfactuals/naive_counter_factuals/) approach, the Label Aware CounterFactuals leverage an *a priori* knowledge of the Counterfactuals' (CFs) targets to guide the search for the CFs (*e.g.* one is looking for a CF of the digit 8 in MNIST dataset within the digit 0 instances). + +!!!warning + Consequently, for this class, when a user call the `explain` method, the user is expected to provide both the `targets` corresponding to the input samples and `cf_expected_classes` a one-hot encoding of the label expected for the CFs. But in most cases, the `targets` can be set to `None` as they are computed internally by projections. + +!!!info + One can use the `Projection` object to compute the distances between the samples (e.g. search for the CF in the latent space of a model). + +## Example + +```python +from xplique.example_based import LabelAwareCounterFactuals +from xplique.example_based.projections import LatentSpaceProjection + +# load the training dataset and the model +cases_dataset = ... # load the training dataset +targets_dataset = ... # load the one-hot encoding of predicted labels of the training dataset +model = ... + +# load the test samples +test_samples = ... # load the test samples to search for +test_cf_expacted_classes = ... # WARNING: provide the one-hot encoding of the expected label of the CFs + +# parameters +k = 5 # number of example for each input +case_returns = "all" # elements returned by the explain function +distance = "euclidean" +latent_layer = "last_conv" # where to split your model for the projection + +# construct a projection with your model +projection = LatentSpaceProjection(model, latent_layer=latent_layer) + +# instantiate the LabelAwareCounterfactuals object +lacf = LabelAwareCounterFactuals( + cases_dataset=cases_dataset, + targets_dataset=targets_dataset, + k=k, + projection=projection, + case_returns=case_returns, + distance=distance, +) + +# search the CFs for the test samples +output_dict = lacf.explain( + inputs=test_samples, + targets=None, # not necessary for this projection + cf_expected_classes=test_cf_expacted_classes, +) +``` + +## Notebooks + +- [**Example-based Methods**: Getting started](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) + +{{xplique.example_based.counterfactuals.LabelAwareCounterFactuals}} \ No newline at end of file diff --git a/docs/api/example_based/counterfactuals/naive_counter_factuals.md b/docs/api/example_based/counterfactuals/naive_counter_factuals.md new file mode 100644 index 00000000..3b5dd600 --- /dev/null +++ b/docs/api/example_based/counterfactuals/naive_counter_factuals.md @@ -0,0 +1,67 @@ +# Naive CounterFactuals + + + + [View colab tutorial](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | + + + [View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/counterfactuals.py) | +📰 [Paper](https://www.semanticscholar.org/paper/Nearest-unlike-neighbor-(NUN)%3A-an-aid-to-decision-Dasarathy/48c1a310f655b827e5e7d712c859b25a4e3c0902) + +!!!note + The paper referenced here is not exactly the one we implemented as we use a "naive" version of it. However, it is probably the closest in essence of what we implemented. + +We define here a "naive" counterfactual method that is based on the Nearest Unlike Neighbor (NUN) concept introduced by Dasarathy in 1991[^1]. In essence, the NUN of a sample $(x, y)$ is the closest sample in the training dataset which has a different label than $y$. + +Thus, in this naive approach to counterfactuals, we yield the $k$ nearest training instances that have a different label than the target of the input sample in a greedy fashion. + +As it is mentioned in the [API documentation](../../api_example_based/), by setting a `Projection` object, one will map the inputs to a space where the distance function is meaningful. + +## Example + +```python +from xplique.example_based import NaiveCounterFactuals +from xplique.example_based.projections import LatentSpaceProjection + +# load the training dataset and the model +cases_dataset = ... # load the training dataset +targets_dataset = ... # load the one-hot encoding of predicted labels of the training dataset +model = ... + +# load the test samples +test_samples = ... # load the test samples to search for +test_targets = ... # compute a one hot encoding of the model's prediction on the samples + +# parameters +k = 5 # number of example for each input +case_returns = "all" # elements returned by the explain function +distance = "euclidean" +latent_layer = "last_conv" # where to split your model for the projection + +# construct a projection with your model +projection = LatentSpaceProjection(model, latent_layer=latent_layer) + +# instantiate the NaiveCounterFactuals object +ncf = NaiveCounterFactuals( + cases_dataset=cases_dataset, + targets_dataset=targets_dataset, + k=k, + projection=projection, + case_returns=case_returns, + distance=distance, +) + +# search the CFs for the test samples +output_dict = ncf.explain( + inputs=test_samples, + targets=test_targets, +) +``` + +## Notebooks + +- [**Example-based Methods**: Getting started](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) + +{{xplique.example_based.counterfactuals.NaiveCounterFactuals}} + +[^1] [Nearest unlike neighbor (NUN): an aid to decision making](https://www.semanticscholar.org/paper/Nearest-unlike-neighbor-(NUN)%3A-an-aid-to-decision-Dasarathy/48c1a310f655b827e5e7d712c859b25a4e3c0902) \ No newline at end of file diff --git a/docs/api/example_based/projections.md b/docs/api/example_based/projections.md new file mode 100644 index 00000000..ea34720d --- /dev/null +++ b/docs/api/example_based/projections.md @@ -0,0 +1,74 @@ +# Projections + +In example-based explainability, one often needs to define a notion of similarity (distance) between samples. However, the original feature space may not be the most suitable space to define this similarity. For instance, in the case of images, two images can be very similar in terms of their pixel values but very different in terms of their semantic content. In addition, computing distances in the original feature space does not take into account the model's whatsoever, questioning the explainability of the method. + +To address these issues, one can project the samples into a new space where the distances between samples are more meaningful with respect to the model's decision. Two approaches are commonly used to define this projection space: (1) use a latent space and (2) use a feature weighting scheme. + +Consequently, we defined the general `Projection` class that will be used as a base class for all projection methods. This class allows one to use one or both of the aforementioned approaches. Indeed, one can instantiate a `Projection` object with a `space_projection` method, that define a projection from the feature space to a space of interest, and a`get_weights` method, that defines the feature weighting scheme. The `Projection` class will then project a sample with the `space_projection` method and weight the projected sample's features with the `get_weights` method. + +In addition, we provide concrete implementations of the `Projection` class: `LatentSpaceProjection`, `AttributionProjection`, and `HadamardProjection`. + +{{xplique.example_based.projections.Projection}} + +!!!info + The `__call__` method is an alias for the `project` method. + +## Defining a custom projection + +To define a custom projection, one needs to implement the `space_projection` and/or `get_weights` methods. The `space_projection` method should return the projected sample, and the `get_weights` method should return the weights of the features of the projected sample. + +!!!info + The `get_weights` method should take as input the original sample once it has been projected using the `space_projection` method. + +For the sake of clarity, we provide an example of a custom projection that projects the samples into a latent space (the final convolution block of the ResNet50 model) and weights the features with the gradients of the model's output with respect to the inputs once they have gone through the layers until the final convolutional layer. + +```python +import tensorflow as tf +from xplique.attributions import Saliency +from xplique.example_based.projections import Projection + +# load the model +model = tf.keras.applications.ResNet50(weights="imagenet", include_top=True) + +latent_layer = model.get_layer("conv5_block3_out") # output of the final convolutional block +features_extractor = tf.keras.Model( + model.input, latent_layer.output, name="features_extractor" +) + +# reconstruct the second part of the InceptionV3 model +second_input = tf.keras.Input(shape=latent_layer.output.shape[1:]) + +x = second_input +layer_found = False +for layer in model.layers: + if layer_found: + x = layer(x) + if layer == latent_layer: + layer_found = True + +predictor = tf.keras.Model( + inputs=second_input, + outputs=x, + name="predictor" +) + +# build the custom projection +space_projection = features_extractor +get_weights = Saliency(predictor) + +custom_projection = Projection(space_projection=space_projection, get_weights=get_weights, mappable=False) + +# build random samples +rdm_imgs = tf.random.normal((5, 224, 224, 3)) +rdm_targets = tf.random.uniform(shape=[5], minval=0, maxval=1000, dtype=tf.int32) +rdm_targets = tf.one_hot(rdm_targets, depth=1000) + +# project the samples +projected_samples = custom_projection(rdm_imgs, rdm_targets) +``` + +{{xplique.example_based.projections.LatentSpaceProjection}} + +{{xplique.example_based.projections.AttributionProjection}} + +{{xplique.example_based.projections.HadamardProjection}} \ No newline at end of file diff --git a/docs/api/example_based/prototypes/api_prototypes.md b/docs/api/example_based/prototypes/api_prototypes.md new file mode 100644 index 00000000..b7fc8e04 --- /dev/null +++ b/docs/api/example_based/prototypes/api_prototypes.md @@ -0,0 +1,104 @@ +# Prototypes +A prototype in AI explainability is a representative example from the data that shows how the model makes decisions ([Poché et al., 2023](https://hal.science/hal-04117520/document)). It helps explain a prediction by pointing to a similar example the model learned from, making the decision more understandable. Imagine you're training a model to recognize dogs. After the model learns, you can ask it to show a "prototype" for the dog category, which would be an actual image from the training set that best represents what a typical dog looks like. + +!!!info + Using the identity projection, one is looking for the **dataset prototypes**. In contrast, using the latent space of a model as a projection, one is looking for **prototypes relevant for the model**. + +## Common API ## + +```python +# only for model explanations, define a projection based on the model +projection = ProjectionMethod(model) + +# construct the explainer (it computes the global prototypes) +explainer = PrototypesMethod( + cases_dataset=cases_dataset, + nb_global_prototypes=nb_global_prototypes, + nb_local_prototypes=nb_local_prototypes, + projection=projection, + case_returns=case_returns, + distance=distance, +) + +# compute global explanation +global_prototypes_dict = explainer.get_global_prototypes() + +# compute local explanation +local_prototypes_dict = explainer(inputs) + +``` + +??? abstract "Table of methods available" + + The following prototypes methods are implemented: + + | Method Name and Documentation link | **Tutorial** | Available with TF | Available with PyTorch* | + |:-------------------------------------- | :----------------------: | :---------------: | :---------------------: | + | [ProtoGreedy](../proto_greedy/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ | + | [ProtoDash](../proto_dash/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ | + | [MMDCritic](../mmd_critic/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | ✔ | ✔ | + +!!!info + Prototypes, share a common API with other example-based methods. Thus, to understand some parameters, we recommend reading the [dedicated documentation](../../api_example_based/). + +## Specificity of prototypes + +The search method class related to a `Prototypes` class includes the following additional parameters: + +- `nb_global_prototypes` which represents the total number of prototypes desired to represent the entire dataset. +- `nb_local_prototypes` which represents the number of prototypes closest to the input and allows for a local explanation. This attribute is equivalent to $k$ in the other exemple based methods. + +- `kernel_fn`, and `gamma` which are related to the [kernel](#how-to-choose-the-kernel) used to compute the [MMD distance](#what-is-mmd). + +The prototype class has a `get_global_prototypes()` method, which calculates all the prototypes in the base dataset; these are called the global prototypes. The `explain` method then provides a local explanation, i.e., finds the prototypes closest to the input given as a parameter. + +## Implemented methods + +The library implements three methods, `MMDCritic`, `ProtoGreedy` and `ProtoDash` from **Data summarization with knapsack constraint** [(Lin et al., 2011)](https://aclanthology.org/P11-1052.pdf). This class of prototype methods involves finding a subset of prototypes $\mathcal{P}$ that maximizes the coverage set function $F(\mathcal{P})$ under the constraint that its selection cost $C(\mathcal{P})$ (e.g., the number of selected prototypes $|\mathcal{P}|= nb\_global\_prototypes$) should be less than a given budget. +Submodularity and monotonicity of $F(\mathcal{P})$ are necessary to guarantee that a greedy algorithm has a constant factor guarantee of optimality [(Lin et al., 2011)](https://aclanthology.org/P11-1052.pdf). + +### Method comparison + +- Compared to `MMDCritic`, both `ProtoGreedy` and `Protodash` additionally determine the weights for each of the selected prototypes. +- `ProtoGreedy` and `Protodash` works for any symmetric positive definite kernel which is not the case for `MMDCritic`. +- `MMDCritic` and `ProtoGreedy` select the next element that maximizes the increment of the scoring function while `Protodash` maximizes a tight lower bound on the increment of the scoring function (it maximizes the gradient of $F(\mathcal{P},w)$). +- `ProtoDash` is much faster than `ProtoGreedy` without compromising on the quality of the solution. The complexity of `ProtoGreedy` is $O(n(n+m^4))$ comparing to $O(n(n+m^2)+m^4)$ for `ProtoDash`. +- The approximation guarantee for `ProtoGreedy` is $(1-e^{-\gamma})$, where $\gamma$ is submodularity ratio of $F(\mathcal{P})$, comparing to $(1-e^{-1})$ for `MMDCritic`. + +### What is MMD? + +The commonality among these three methods is their utilization of the Maximum Mean Discrepancy (MMD) statistic as a measure of similarity between points and potential prototypes. MMD is a statistic for comparing two distributions (similar to KL-divergence). However, it is a non-parametric statistic, i.e., it does not assume a specific parametric form for the probability distributions being compared. It is defined as follows: + +$$ +\begin{align*} +\text{MMD}(P, Q) &= \left\| \mathbb{E}_{X \sim P}[\varphi(X)] - \mathbb{E}_{Y \sim Q}[\varphi(Y)] \right\|_\mathcal{H} +\end{align*} +$$ + +where $\varphi(\cdot)$ is a mapping function of the data points. If we want to consider all orders of moments of the distributions, the mapping vectors $\varphi(X)$ and $\varphi(Y)$ will be infinite-dimensional. Thus, we cannot calculate them directly. However, if we have a kernel that gives the same result as the inner product of these two mappings in Hilbert space ($k(x, y) = \langle \varphi(x), \varphi(y) \rangle_\mathcal{H}$), then the $MMD^2$ can be computed using only the kernel and without explicitly using $\varphi(X)$ and $\varphi(Y)$ (this is called the kernel trick): + +$$ +\begin{align*} +\text{MMD}^2(P, Q) &= \langle \mathbb{E}_{X \sim P}[\varphi(X)], \mathbb{E}_{X' \sim P}[\varphi(X')] \rangle_\mathcal{H} + \langle \mathbb{E}_{Y \sim Q}[\varphi(Y)], \mathbb{E}_{Y' \sim Q}[\varphi(Y')] \rangle_\mathcal{H} \\ +&\quad - 2\langle \mathbb{E}_{X \sim P}[\varphi(X)], \mathbb{E}_{Y \sim Q}[\varphi(Y)] \rangle_\mathcal{H} \\ +&= \mathbb{E}_{X, X' \sim P}[k(X, X')] + \mathbb{E}_{Y, Y' \sim Q}[k(Y, Y')] - 2\mathbb{E}_{X \sim P, Y \sim Q}[k(X, Y)] +\end{align*} +$$ + +### How to choose the kernel ? +The choice of the kernel for selecting prototypes depends on the specific problem and the characteristics of your data. Several kernels can be used, including: + +- Gaussian +- Laplace +- Polynomial +- Linear... + +If we consider any exponential kernel (Gaussian kernel, Laplace, ...), we automatically consider all the moments for the distribution, as the Taylor expansion of the exponential considers infinite-order moments. It is better to use a non-linear kernel to capture non-linear relationships in your data. If the problem is linear, it is better to choose a linear kernel such as the dot product kernel, since it is computationally efficient and often requires fewer hyperparameters to tune. + +!!!warning + For `MMDCritic`, the kernel must satisfy a condition ensuring the submodularity of the set function (the Gaussian kernel respects this constraint). In contrast, for `ProtoDash` and `ProtoGreedy`, any kernel can be used, as these methods rely on weak submodularity instead of full submodularity. + +!!!info + The default kernel used is Gaussian kernel. This kernel distance assigns higher similarity to points that are close in feature space and gradually decreases similarity as points move further apart. It is a good choice when your data has complexity. However, it can be sensitive to the choice of hyperparameters, such as the width $\sigma$ of the Gaussian kernel, which may need to be carefully fine-tuned. + + diff --git a/docs/api/example_based/prototypes/mmd_critic.md b/docs/api/example_based/prototypes/mmd_critic.md new file mode 100644 index 00000000..e4f9d33a --- /dev/null +++ b/docs/api/example_based/prototypes/mmd_critic.md @@ -0,0 +1,77 @@ +# MMDCritic + + + +[View colab tutorial](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | + + +[View source](https://github.com/deel-ai/xplique/blob/antonin/example-based-merge/xplique/example_based/search_methods/proto_greedy_search.py) | +📰 [Paper](https://proceedings.neurips.cc/paper_files/paper/2016/file/5680522b8e2bb01943234bce7bf84534-Paper.pdf) + +`MMDCritic` finds prototypes and criticisms by maximizing two separate objectives based on the Maximum Mean Discrepancy (MMD). + +!!! quote + MMD-critic uses the MMD statistic as a measure of similarity between points and potential prototypes, and + efficiently selects prototypes that maximize the statistic. In addition to prototypes, MMD-critic selects criticism samples i.e. samples that are not well-explained by the prototypes using a regularized witness function score. + + -- [Efficient Data Representation by Selecting Prototypes with Importance Weights (2019).](https://arxiv.org/abs/1707.01212) + +First, to find prototypes $\mathcal{P}$, a greedy algorithm is used to maximize $F(\mathcal{P})$ s.t. $|\mathcal{P}| \le m_p$ where $F(\mathcal{P})$ is defined as: +\begin{equation} + F(\mathcal{P})=\frac{2}{|\mathcal{P}|\cdot n}\sum_{i,j=1}^{|\mathcal{P}|,n}\kappa(p_i,x_j)-\frac{1}{|\mathcal{P}|^2}\sum_{i,j=1}^{|\mathcal{P}|}\kappa(p_i,p_j), +\end{equation} +where $m_p$ the number of prototypes to be found. They used diagonal dominance conditions on the kernel to ensure monotonocity and submodularity of $F(\mathcal{P})$. + +Second, to find criticisms $\mathcal{C}$, the same greedy algorithm is used to select points that maximize another objective function $J(\mathcal{C})$. + +!!!warning + For `MMDCritic`, the kernel must satisfy a condition that ensures the submodularity of the set function. The Gaussian kernel meets this requirement and it is recommended. If you wish to choose a different kernel, it must satisfy the condition described by [Kim et al., 2016](https://proceedings.neurips.cc/paper_files/paper/2016/file/5680522b8e2bb01943234bce7bf84534-Paper.pdf). + + +## Example + +```python +from xplique.example_based import MMDCritic +from xplique.example_based.projections import LatentSpaceProjection + +# load the training dataset and the model +cases_dataset = ... # load the training dataset +model = ... + +# load the test samples +test_samples = ... # load the test samples to search for + +# parameters +case_returns = "all" # elements returned by the explain function +latent_layer = "last_conv" # where to split your model for the projection +nb_global_prototypes = 5 +nb_local_prototypes = 1 +kernel_fn = None # the default rbf kernel will be used, the distance will be based on this + +# construct a projection with your model +projection = LatentSpaceProjection(model, latent_layer=latent_layer) + +mmd = MMDCritic( + cases_dataset=cases_dataset, + nb_global_prototypes=nb_global_prototypes, + nb_local_prototypes=nb_local_prototypes, + projection=projection, + case_returns=case_returns, +) + +# compute global explanation +global_prototypes = mmd.get_global_prototypes() + +# compute local explanation +local_prototypes = mmd.explain(test_samples) +``` + +## Notebooks + +- [**Example-based: Prototypes**](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) + + +{{xplique.example_based.prototypes.MMDCritic}} + +[^1]: [Visual Explanations from Deep Networks via Gradient-based Localization (2016).](https://arxiv.org/abs/1610.02391) + diff --git a/docs/api/example_based/prototypes/proto_dash.md b/docs/api/example_based/prototypes/proto_dash.md new file mode 100644 index 00000000..83d78573 --- /dev/null +++ b/docs/api/example_based/prototypes/proto_dash.md @@ -0,0 +1,79 @@ +# ProtoDash + + + +[View colab tutorial](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | + + +[View source](https://github.com/deel-ai/xplique/blob/antonin/example-based-merge/xplique/example_based/search_methods/proto_greedy_search.py) | +📰 [Paper](https://arxiv.org/abs/1707.01212) + +`ProtoDash` associated non-negative weights to prototypes which are indicative of their importance. This approach allows for identifying both prototypes and criticisms (the least weighted examples among prototypes) by maximmizing the same weighted objective function. + +!!! quote + Our work notably generalizes the recent work + by [Kim et al. (2016)](../mmd_critic/) where in addition to selecting prototypes, we + also associate non-negative weights which are indicative of their + importance. This extension provides a single coherent framework + under which both prototypes and criticisms (i.e. outliers) can be + found. Furthermore, our framework works for any symmetric + positive definite kernel thus addressing one of the key open + questions laid out in [Kim et al. (2016)](../mmd_critic/). + + -- [Efficient Data Representation by Selecting Prototypes with Importance Weights (2019).](https://arxiv.org/abs/1707.01212) + +More precisely, the weighted objective $F(\mathcal{P},w)$ is defined as: +\begin{equation} +F(\mathcal{P},w)=\frac{2}{n}\sum_{i,j=1}^{|\mathcal{P}|,n}w_i\kappa(p_i,x_j)-\sum_{i,j=1}^{|\mathcal{P}|}w_iw_j\kappa(p_i,p_j), +\end{equation} +where $w$ are non-negative weights for each prototype. The problem then consist on finding a subset $\mathcal{P}$ with a corresponding $w$ that maximizes $J(\mathcal{P}) \equiv \max_{w:supp(w)\in \mathcal{P},w\ge 0} J(\mathcal{P},w)$ s.t. $|\mathcal{P}| \leq m=m_p+m_c$. + +!!!info + For ProtoDash, any kernel can be used, as these methods rely on weak submodularity instead of full submodularity. + +## Example + +```python +from xplique.example_based import ProtoDash +from xplique.example_based.projections import LatentSpaceProjection + +# load the training dataset and the model +cases_dataset = ... # load the training dataset +model = ... + +# load the test samples +test_samples = ... # load the test samples to search for + +# parameters +case_returns = "all" # elements returned by the explain function +latent_layer = "last_conv" # where to split your model for the projection +nb_global_prototypes = 5 +nb_local_prototypes = 1 +kernel_fn = None # the default rbf kernel will be used, the distance will be based on this + +# construct a projection with your model +projection = LatentSpaceProjection(model, latent_layer=latent_layer) + +protodash = ProtoDash( + cases_dataset=cases_dataset, + nb_global_prototypes=nb_global_prototypes, + nb_local_prototypes=nb_local_prototypes, + projection=projection, + case_returns=case_returns, +) + +# compute global explanation +global_prototypes = protodash.get_global_prototypes() + +# compute local explanation +local_prototypes = protodash.explain(test_samples) +``` + +## Notebooks + +- [**Example-based: Prototypes**](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) + + +{{xplique.example_based.prototypes.ProtoDash}} + +[^1]: [Visual Explanations from Deep Networks via Gradient-based Localization (2016).](https://arxiv.org/abs/1610.02391) diff --git a/docs/api/example_based/prototypes/proto_greedy.md b/docs/api/example_based/prototypes/proto_greedy.md new file mode 100644 index 00000000..0cb861d4 --- /dev/null +++ b/docs/api/example_based/prototypes/proto_greedy.md @@ -0,0 +1,80 @@ +# ProtoGreedy + + + +[View colab tutorial](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) | + + +[View source](https://github.com/deel-ai/xplique/blob/antonin/example-based-merge/xplique/example_based/search_methods/proto_greedy_search.py) | +📰 [Paper](https://arxiv.org/abs/1707.01212) + +`ProtoGreedy` associated non-negative weights to prototypes which are indicative of their importance. This approach allows for identifying both prototypes and criticisms (the least weighted examples among prototypes) by maximizing the same weighted objective function. + +!!! quote + Our work notably generalizes the recent work + by [Kim et al. (2016)](../mmd_critic/) where in addition to selecting prototypes, we + also associate non-negative weights which are indicative of their + importance. This extension provides a single coherent framework + under which both prototypes and criticisms (i.e. outliers) can be + found. Furthermore, our framework works for any symmetric + positive definite kernel thus addressing one of the key open + questions laid out in [Kim et al. (2016)](../mmd_critic/). + + -- [Efficient Data Representation by Selecting Prototypes with Importance Weights (2019).](https://arxiv.org/abs/1707.01212) + +More precisely, the weighted objective $F(\mathcal{P},w)$ is defined as: +\begin{equation} +F(\mathcal{P},w)=\frac{2}{n}\sum_{i,j=1}^{|\mathcal{P}|,n}w_i\kappa(p_i,x_j)-\sum_{i,j=1}^{|\mathcal{P}|}w_iw_j\kappa(p_i,p_j), +\end{equation} +where $w$ are non-negative weights for each prototype. The problem then consist on finding a subset $\mathcal{P}$ with a corresponding $w$ that maximizes $J(\mathcal{P}) \equiv \max_{w:supp(w)\in \mathcal{P},w\ge 0} J(\mathcal{P},w)$ s.t. $|\mathcal{P}| \leq m=m_p+m_c$. + +!!!info + For ProtoGreedy, any kernel can be used, as these methods rely on weak submodularity instead of full submodularity. + + +## Example + +```python +from xplique.example_based import ProtoGreedy +from xplique.example_based.projections import LatentSpaceProjection + +# load the training dataset and the model +cases_dataset = ... # load the training dataset +model = ... + +# load the test samples +test_samples = ... # load the test samples to search for + +# parameters +case_returns = "all" # elements returned by the explain function +latent_layer = "last_conv" # where to split your model for the projection +nb_global_prototypes = 5 +nb_local_prototypes = 1 +kernel_fn = None # the default rbf kernel will be used, the distance will be based on this + +# construct a projection with your model +projection = LatentSpaceProjection(model, latent_layer=latent_layer) + +protogreedy = ProtoGreedy( + cases_dataset=cases_dataset, + nb_global_prototypes=nb_global_prototypes, + nb_local_prototypes=nb_local_prototypes, + projection=projection, + case_returns=case_returns, +) + +# compute global explanation +global_prototypes = protogreedy.get_global_prototypes() + +# compute local explanation +local_prototypes = protogreedy.explain(test_samples) +``` + +## Notebooks + +- [**Example-based: Prototypes**](https://colab.research.google.com/drive/1-bUvXxzWrBqLLfS_4TvErcEfyzymTVGz) + + +{{xplique.example_based.prototypes.ProtoGreedy}} + +[^1]: [Visual Explanations from Deep Networks via Gradient-based Localization (2016).](https://arxiv.org/abs/1610.02391) diff --git a/docs/api/example_based/semifactuals/kleor.md b/docs/api/example_based/semifactuals/kleor.md new file mode 100644 index 00000000..9162f65d --- /dev/null +++ b/docs/api/example_based/semifactuals/kleor.md @@ -0,0 +1,88 @@ +# KLEOR + + + + [View colab tutorial](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | + + + [View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/semifactuals.py) | +📰 [Paper](https://www.researchgate.net/publication/220106308_KLEOR_A_Knowledge_Lite_Approach_to_Explanation_Oriented_Retrieval) + +KLEOR for Knowledge-Light Explanation-Oriented Retrieval was introduced by Cummins & Bridge in 2006. It is a method that use counterfactuals, Nearest Unlike Neighbor (NUN), to guide the selection of a semi-factual (SF) example. + +Given a distance function $dist$, the NUN of a sample $(x, y)$ is the closest sample in the training dataset which has a different label than $y$. + +The KLEOR method actually have three variants including: + +- The Sim-Miss approach +- The Global-Sim approach + +In the Sim-Miss approach, the SF of the sample $(x,y)$ is the closest training sample from the corresponding NUN which has the same label as $y$. + +Denoting the training dataset as $\mathcal{D}$: + +$$Sim-Miss(x, y, NUN(x,y), \mathcal{D}) = arg \\ min_{(x',y') \in \mathcal{D} \\ | \\ y'=y} dist(x', NUN(x,y))$$ + +In the Global-Sim approach, they add an additional constraint that the SF should lie between the sample $(x,y)$ and the NUN that is: $dist(x, SF) < dist(x, NUN(x,y))$. + +We extended to the $k$ nearest neighbors of the NUN for both approaches. + +!!!info + In our implementation, we rather consider the labels predicted by the model $\hat{y}$ (*i.e.* the targets) rather than $y$! + +!!!tips + As KLEOR methods use counterfactuals, they can also return them. Therefore, it is possible to obtain both semi-factuals and counterfactuals with an unique method. To do so "nuns" and "nuns_labels" should be added to the `cases_returns` list. + +## Examples + +```python +from xplique.example_based import KLEORGlobalSim # or KLEORSimMiss +from xplique.example_based.projections import LatentSpaceProjection + +# load the training dataset and the model +cases_dataset = ... # load the training dataset +targets_dataset = ... # load the one-hot encoding of predicted labels of the training dataset +model = ... + +# load the test samples +test_samples = ... # load the test samples to search for +test_targets = ... # compute a one hot encoding of the model's prediction on the samples + +# parameters +k = 1 # number of example for each input +case_returns = "all" # elements returned by the explain function +distance = "euclidean" +latent_layer = "last_conv" # where to split your model for the projection + +# construct a projection with your model +projection = LatentSpaceProjection(model, latent_layer=latent_layer) + +# instantiate the KLEORGlobalSim object (could be KLEORSimMiss, the code do not change) +sf_explainer = KLEORGlobalSim( + cases_dataset=cases_dataset, + targets_dataset=targets_dataset, + k=k, + projection=projection, + case_returns=case_returns, + distance=distance, +) + +# search the SFs for the test samples +sf_output_dict = sf_explainer.explain( + inputs=test_samples, + targets=test_targets, +) + +# get the semi-factuals +semifactuals = sf_output_dict["examples"] + +# get the counterfactuals +counterfactuals = sf_output_dict["nuns"] +``` + +## Notebooks + +- [**Example-based Methods**: Getting started](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) + +{{xplique.example_based.semifactuals.KLEORSimMiss}} +{{xplique.example_based.semifactuals.KLEORGlobalSim}} \ No newline at end of file diff --git a/docs/api/example_based/similar_examples/cole.md b/docs/api/example_based/similar_examples/cole.md new file mode 100644 index 00000000..63fac794 --- /dev/null +++ b/docs/api/example_based/similar_examples/cole.md @@ -0,0 +1,78 @@ +# COLE: Contributions Oriented Local Explanations + + + + [View colab tutorial](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | + + + [View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/similar_examples.py) | +📰 [Paper](https://researchrepository.ucd.ie/handle/10197/11064) + +COLE for Contributions Oriented Local Explanations was introduced by Kenny & Keane in 2019. + +!!! quote + Our method COLE is based on the premise that the contributions of features in a model’s classification represent the most sensible basis to inform case-based explanations. + + -- [COLE paper](https://researchrepository.ucd.ie/handle/10197/11064)[^1] + +The core idea of the COLE approach is to use [attribution maps](../../../attributions/api_attributions/) to define a relevant search space for the K-Nearest Neighbors (KNN) search. + +More specifically, the COLE approach is based on the following steps: + +- (1) Given an input sample $x$, compute the attribution map $A(x)$ + +- (2) Consider the projection space defined by: $p: x \rightarrow A(x) \odot x$ ($\odot$ denotes the element-wise product) + +- (3) Perform a KNN search in the projection space to find the most similar training samples + +!!! info + In the original paper, the authors focused on Multi-Layer Perceptrons (MLP) and three attribution methods (Hadamard, LPR, Integrated Gradient, and DeepLift). We decided to implement a COLE method that generalizes to a more broader range of Neural Networks and attribution methods (see [API Attributions documentation](../../../attributions/api_attributions/) to see the list of methods available). + +!!! tips + The original paper shown that the hadamard product between the latent space and the gradient was the best method. Hence we optimized the code for this method. Setting the `attribution_method` argument to `"gradient"` will run much faster. + +## Example + +```python +from xplique.example_based import Cole + +# load the training dataset and the model +cases_dataset = ... # load the training dataset +model = ... # load the model + +# load the test samples +test_samples = ... # load the test samples to search for + +# parameters +k = 3 +case_returns = "all" # elements returned by the explain function +distance = "euclidean" +attribution_method = "gradient", +latent_layer = "last_conv" # where to split your model for the projection + +# instantiate the Cole object +cole = Cole( + cases_dataset=cases_dataset, + model=model, + k=k, + attribution_method=attribution_method, + latent_layer=latent_layer, + case_returns=case_returns, + distance=distance, +) + +# search the most similar samples with the COLE method +similar_samples = cole.explain( + inputs=test_samples, + targets=None, # not necessary with default operator, they are computed internally +) +``` + +## Notebooks + +- [**Example-based Methods**: Getting started](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) + +{{xplique.example_based.similar_examples.Cole}} + +[^1]: [Twin-Systems to Explain Artificial Neural Networks using Case-Based Reasoning: +Comparative Tests of Feature-Weighting Methods in ANN-CBR Twins for XAI (2019)](https://researchrepository.ucd.ie/handle/10197/11064) \ No newline at end of file diff --git a/docs/api/example_based/similar_examples/similar_examples.md b/docs/api/example_based/similar_examples/similar_examples.md new file mode 100644 index 00000000..a36eadc4 --- /dev/null +++ b/docs/api/example_based/similar_examples/similar_examples.md @@ -0,0 +1,57 @@ +# Similar-Examples + + + + [View colab tutorial](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | + + + [View source](https://github.com/deel-ai/xplique/blob/master/xplique/example_based/similar_examples.py) + +We designate here as *Similar Examples* all methods that given an input sample, search for the most similar **training** samples given a distance function `distance`. Furthermore, one can define the search space using a `projection` function (see [Projections](../../projections/)). This function should map an input sample to the search space where the distance function is defined and meaningful (**e.g.** the latent space of a Convolutional Neural Network). +Then, a K-Nearest Neighbors (KNN) search is performed to find the most similar samples in the search space. + +## Example + +```python +from xplique.example_based import SimilarExamples + +cases_dataset = ... # load the training dataset +targets = ... # load the one-hot encoding of predicted labels of the training dataset + +# parameters +k = 5 +distance = "euclidean" +case_returns = ["examples", "nuns"] + +# define the projection function +def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None): + ''' + Example of projection, + inputs are the elements to project. + targets are optional parameters to orientate the projection. + ''' + projected_inputs = # do some magic on inputs, it should use the model. + return projected_inputs + +# instantiate the SimilarExamples object +sim_ex = SimilarExamples( + cases_dataset=cases_dataset, + targets_dataset=targets, + k=k, + projection=custom_projection, + distance=distance, +) + +# load the test samples and targets +test_samples = ... # load the test samples to search for +test_targets = ... # load the one-hot encoding of the test samples' predictions + +# search the most similar samples with the SimilarExamples method +similar_samples = sim_ex.explain(test_samples, test_targets) +``` + +# Notebooks + +- [**Example-based Methods**: Getting started](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) + +{{xplique.example_based.similar_examples.SimilarExamples}} \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index 55320f23..0e78c08a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -7,6 +7,9 @@ + + + PyLint @@ -41,8 +44,13 @@ Feature Visualization · Metrics + . + Example-based +!!! warning + With the release of Keras 3.X since TensorFlow 2.16, some methods may not function as expected. We are actively working on a fix. In the meantime, we recommend using TensorFlow 2.15 or earlier versions for optimal compatibility. + The library is composed of several modules, the _Attributions Methods_ module implements various methods (e.g Saliency, Grad-CAM, Integrated-Gradients...), with explanations, examples and links to official papers. The _Feature Visualization_ module allows to see how neural networks build their understanding of images by finding inputs that maximize neurons, channels, layers or compositions of these elements. The _Concepts_ module allows you to extract human concepts from a model and to test their usefulness with respect to a class. @@ -54,6 +62,9 @@ Finally, the _Metrics_ module covers the current metrics used in explainability.
+!!! info "🔔 **New Module Available!**" + We are proud to announce the release of the _Example-based_ module! This module is dedicated to methods that explain a model by retrieving relevant examples from a dataset. It includes methods that belong to different families: similar examples, contrastive (counter-factuals and semi-factuals) examples, and prototypes (as concepts based methods have a dedicated sections). + ## 🔥 Tutorials ??? example "We propose some Hands-on tutorials to get familiar with the library and its api" @@ -109,6 +120,7 @@ Finally, the _Metrics_ module covers the current metrics used in explainability.

- [**Modern Feature Visualization with MaCo**: Getting started](https://colab.research.google.com/drive/1l0kag1o-qMY4NCbWuAwnuzkzd9sf92ic) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1l0kag1o-qMY4NCbWuAwnuzkzd9sf92ic) + - [**Example-based Methods**: Getting started](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) You can find a certain number of [**other practical tutorials just here**](tutorials/). This section is actively developed and more contents will be included. We will try to cover all the possible usage of the library, feel free to contact us if you have any suggestions or recommendations towards tutorials you would like to see. @@ -333,6 +345,24 @@ There are 4 modules in Xplique, [Attribution methods](api/attributions/api_attri TF : Tensorflow compatible +Even though we are only at the early stages, we have also recently added an [Example-based methods](api/example_based/api_example_based/) module. Do not hesitate to give us feedback! Currently, the methods available are summarized in the following table: + +??? abstract "Table of example-based methods available" + + | Method | Family | Documentation | Tutorial | + | --- | --- | --- | --- | + | `SimilarExamples` | Similar Examples | [SimilarExamples](../similar_examples/similar_examples/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | + | `Cole` | Similar Examples | [Cole](../similar_examples/cole/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | + | | | | + | `NaiveCounterFactuals` | Counter Factuals | [NaiveCounterFactuals](../counterfactuals/naive_counter_factuals/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | + | `LabelAwareCounterFactuals` | Counter Factuals | [LabelAwareCounterFactuals](../counterfactuals/label_aware_counter_factuals/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | + |||| + | `KLEORSimMiss` | Semi Factuals | [KLEOR](../semifactuals/kleor/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | + | `KLEORGlobalSim` | Semi Factuals | [KLEOR](../semifactuals/kleor/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | + |||| + | `ProtoGreedy` | Prototypes | [ProtoGreedy](../prototypes/proto_greedy/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OI3oa884GwGbXlzn3Y9NH-1j4cSaQb0w) | + | `ProtoDash` | Prototypes | [ProtoDash](../prototypes/proto_dash/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OI3oa884GwGbXlzn3Y9NH-1j4cSaQb0w) | + | `MMDCritic` | Prototypes | [MMDCritic](../prototypes/mmd_critic/) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OI3oa884GwGbXlzn3Y9NH-1j4cSaQb0w) | ## 👍 Contributing diff --git a/docs/tutorials.md b/docs/tutorials.md index 38957e89..7dda10c1 100644 --- a/docs/tutorials.md +++ b/docs/tutorials.md @@ -20,6 +20,8 @@ Here is the lists of the availables tutorial for now: | Metrics | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1WEpVpFSq-oL1Ejugr8Ojb3tcbqXIOPBg) | | Concept Activation Vectors | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1iuEz46ZjgG97vTBH8p-vod3y14UETvVE) | | Feature Visualization | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1st43K9AH-UL4eZM1S4QdyrOi7Epa5K8v) | +| Example-Based Methods | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | +| Prototypes | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OI3oa884GwGbXlzn3Y9NH-1j4cSaQb0w) | ## Attributions @@ -79,3 +81,10 @@ Here is the lists of the availables tutorial for now: | :------------------------------------- | :-----------------------------------------------------------------------------------------------------------------------------------------------------: | | Feature Visualization: Getting started | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1st43K9AH-UL4eZM1S4QdyrOi7Epa5K8v) | | Modern Feature Visualization: MaCo | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1l0kag1o-qMY4NCbWuAwnuzkzd9sf92ic) | + +## Example-Based Methods + +| **Tutorial Name** | Notebook | +| :------------------------------------- | :-----------------------------------------------------------------------------------------------------------------------------------------------------: | +| Example-Based Methods: Getting started | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gA7mhWhWzdKholZWkTvAg4FzFnzS8NHF) | +| Example-based: Prototypes | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OI3oa884GwGbXlzn3Y9NH-1j4cSaQb0w) | \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 6b20f7f5..ee5082d3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -42,6 +42,22 @@ nav: - Cav: api/concepts/cav.md - Tcav: api/concepts/tcav.md - Craft: api/concepts/craft.md + - Example based: + - API Description: api/example_based/api_example_based.md + - Similar Examples: + - SimilarExamples: api/example_based/similar_examples/similar_examples.md + - Cole: api/example_based/similar_examples/cole.md + - Counterfactuals: + - LabelAwareCounterFactuals: api/example_based/counterfactuals/label_aware_counter_factuals.md + - NaiveCounterFactuals: api/example_based/counterfactuals/naive_counter_factuals.md + - Semifactuals: + - Kleor: api/example_based/semifactuals/kleor.md + - Prototypes: + - API Description: api/example_based/prototypes/api_prototypes.md + - ProtoGreedy: api/example_based/prototypes/proto_greedy.md + - ProtoDash: api/example_based/prototypes/proto_dash.md + - MMDCritic: api/example_based/prototypes/mmd_critic.md + - Projections: api/example_based/projections.md - Feature visualization: - Modern Feature Visualization (MaCo): api/feature_viz/maco.md - Feature visualization: api/feature_viz/feature_viz.md @@ -89,8 +105,8 @@ markdown_extensions: custom_checkbox: true clickable_checkbox: true - pymdownx.emoji: - emoji_index: !!python/name:materialx.emoji.twemoji - emoji_generator: !!python/name:materialx.emoji.to_svg + emoji_index: !!python/name:material.extensions.emoji.twemoji + emoji_generator: !!python/name:material.extensions.emoji.to_svg extra: version: diff --git a/setup.cfg b/setup.cfg index 8430629c..93dfc177 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 1.3.3 +current_version = 1.4.0 commit = True tag = False @@ -16,6 +16,7 @@ disable = E1120, # see pylint#3613 E1101, # pylint misses members set dynamically C3001, # lambda function as variable + R0917, # too-many-positional-arguments - TODO: fix this when breaking retrocompatibility [pylint.FORMAT] max-line-length = 100 @@ -26,6 +27,7 @@ min-similarity-lines = 6 ignore-comments = yes ignore-docstrings = yes ignore-imports = no +ignore-signatures = yes [tox:tox] envlist = py{37,38,39,310}-lint, py{37,38,39,310}-tf{22,25,28,211}, py{38,39,310}-tf{25,28,211}-torch{111,113,200} @@ -47,7 +49,7 @@ deps = tf211: tensorflow ~= 2.11.0,<2.16 -rrequirements.txt commands = - pytest --cov=xplique --ignore=xplique/wrappers/pytorch.py --ignore=tests/wrappers/test_pytorch_wrapper.py --ignore=tests/concepts/test_craft_torch.py {posargs} + pytest --cov=xplique --ignore=xplique/wrappers/pytorch.py --ignore=tests/wrappers/test_pytorch_wrapper.py --ignore=tests/concepts/test_craft_torch.py --ignore=tests/example_based/test_torch.py {posargs} [testenv:py{38,39,310}-tf{25,28,211}-torch{111,113,200}] deps = @@ -61,7 +63,7 @@ deps = torch200: torch -rrequirements.txt commands = - pytest --cov=xplique/wrappers/pytorch tests/wrappers/test_pytorch_wrapper.py tests/concepts/test_craft_torch.py + pytest --cov=xplique/wrappers/pytorch tests/wrappers/test_pytorch_wrapper.py tests/concepts/test_craft_torch.py tests/example_based/test_torch.py [mypy] check_untyped_defs = True diff --git a/setup.py b/setup.py index 08a132ef..dc28d072 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="Xplique", - version="1.3.3", + version="1.4.0", description="Explanations toolbox for Tensorflow 2", long_description=README, long_description_content_type="text/markdown", diff --git a/tests/concepts/test_craft_tf.py b/tests/concepts/test_craft_tf.py index 7d940c0a..d1389273 100644 --- a/tests/concepts/test_craft_tf.py +++ b/tests/concepts/test_craft_tf.py @@ -1,15 +1,12 @@ import numpy as np import tensorflow as tf -import random import pytest -import os -from tensorflow.keras.models import Sequential -from tensorflow.keras.layers import Dense, Conv2D, Activation, Flatten, Input -from tensorflow.keras.optimizers import Adam + +from tensorflow.keras.layers import Input from xplique.concepts import CraftTf as Craft -from ..utils import generate_data, generate_model, generate_txt_images_data -from ..utils import download_file + +from ..utils import generate_data, generate_model def test_shape(): @@ -100,172 +97,3 @@ def test_wrong_layers(): number_of_concepts = number_of_concepts, patch_size = patch_size, batch_size = 64) - -def test_classifier(): - """ Check the Craft results on a small fake dataset """ - - input_shape = (64, 64, 3) - nb_labels = 3 - nb_samples = 200 - - # Create a dataset of 'ABC', 'BCD', 'CDE' images - x, y, nb_samples, _ = generate_txt_images_data(input_shape, nb_labels, nb_samples) - - # train a small classifier on the dataset - def create_classifier_model(input_shape=(64, 64, 3), output_shape=10): - model = Sequential() - model.add(Input(shape=input_shape)) - model.add(Conv2D(6, kernel_size=(2, 2))) - model.add(Activation('relu')) - model.add(Conv2D(6, kernel_size=(2, 2))) - model.add(Activation('relu')) - model.add(Conv2D(6, kernel_size=(2, 2))) - model.add(Activation('relu', name='relu')) - model.add(Flatten()) - model.add(Dense(output_shape)) - model.add(Activation('softmax')) - opt = Adam(learning_rate=0.005) - model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy']) - - return model - - model = create_classifier_model(input_shape, nb_labels) - - tf.random.set_seed(0) - np.random.seed(0) - random.seed(0) - - # Retrieve checkpoints - checkpoint_path = "tests/concepts/checkpoints/classifier_test_craft_tf.ckpt" - if not os.path.exists(f"{checkpoint_path}.index"): - os.makedirs("tests/concepts/checkpoints/", exist_ok=True) - identifier = "1NLA7x2EpElzEEmyvFQhD6VS6bMwS_bCs" - download_file(identifier, f"{checkpoint_path}.index") - - identifier = "1wDi-y9b-3I_a-ZtqRlfuib-D7Ox4j8pX" - download_file(identifier, f"{checkpoint_path}.data-00000-of-00001") - - model.load_weights(checkpoint_path) - - acc = np.sum(np.argmax(model(x), axis=1) == np.argmax(y, axis=1)) / nb_samples - assert acc == 1.0 - - # cut the model in two parts (as explained in the paper) - # first part is g(.) our 'input_to_latent' model, second part is h(.) our 'latent_to_logit' model - cut_layer = model.get_layer('relu') - g = tf.keras.Model(model.inputs, cut_layer.output) - h = tf.keras.Model(Input(tensor=cut_layer.output), model.outputs) - - assert np.all(g(x) >= 0.0) - - # Init Craft on the full dataset - craft = Craft(input_to_latent_model = g, - latent_to_logit_model = h, - number_of_concepts = 3, - patch_size = 12, - batch_size = 32) - - # Expected best crop for class 0 (ABC) is AB - AB_str = """ - 0 0 0 0 0 0 0 0 0 0 0 0 - 0 0 0 0 0 0 0 0 0 0 0 0 - 0 0 0 0 1 1 1 1 1 1 1 1 - 0 0 0 0 0 0 1 0 0 0 1 1 - 1 0 0 0 0 0 1 0 0 0 0 1 - 1 0 0 0 0 0 1 0 0 0 1 1 - 1 0 0 0 0 0 1 1 1 1 1 1 - 1 1 0 0 0 0 1 0 0 0 0 1 - 0 1 0 0 0 0 1 0 0 0 0 0 - 0 1 1 0 0 0 1 0 0 0 0 1 - 1 1 1 1 1 1 1 1 1 1 1 1 - 0 0 0 0 0 0 0 0 0 0 0 0 - """ - AB = np.genfromtxt(AB_str.splitlines()) - - # Expected best crop for class 1 (BCD) is BC - BC_str = """ - 0 0 0 0 0 0 0 0 0 0 0 0 - 0 0 0 0 0 0 0 0 0 0 0 0 - 0 0 0 0 0 0 0 0 0 0 0 0 - 1 1 1 1 1 1 0 0 0 0 1 1 - 1 0 0 0 1 1 0 0 0 1 1 0 - 1 0 0 0 0 1 0 0 0 1 0 0 - 1 0 0 0 1 1 0 0 0 1 0 0 - 1 1 1 1 1 1 0 0 0 1 0 0 - 1 0 0 0 0 1 1 0 0 1 0 0 - 1 0 0 0 0 0 1 0 0 1 0 0 - 1 0 0 0 0 1 1 0 0 1 1 0 - 1 1 1 1 1 1 0 0 0 0 1 1 - """ - BC = np.genfromtxt(BC_str.splitlines()) - - # Expected best crop for class 2 (CDE) is DE - DE_str = """ - 0 0 0 0 0 0 0 0 0 0 0 0 - 1 0 0 1 1 1 1 1 1 1 1 0 - 1 1 0 0 0 1 0 0 0 0 1 0 - 0 1 0 0 0 1 0 0 0 0 1 0 - 0 1 1 0 0 1 0 0 1 0 0 0 - 0 1 1 0 0 1 1 1 1 0 0 0 - 0 1 1 0 0 1 0 0 1 0 0 0 - 0 1 0 0 0 1 0 0 0 0 1 1 - 1 1 0 0 0 1 0 0 0 0 1 1 - 1 0 0 1 1 1 1 1 1 1 1 1 - 0 0 0 0 0 0 0 0 0 0 0 0 - 0 0 0 0 0 0 0 0 0 0 0 0 - """ - DE = np.genfromtxt(DE_str.splitlines()) - - DE2_str = """ - 0 0 0 0 0 0 0 0 0 0 0 Z - 1 1 1 1 0 0 1 1 1 1 1 1 - 0 0 1 1 1 0 0 0 1 0 0 0 - 0 0 0 0 1 0 0 0 1 0 0 0 - 0 0 0 0 1 1 0 0 1 0 0 1 - 0 0 0 0 1 1 0 0 1 1 1 1 - 0 0 0 0 1 1 0 0 1 0 0 1 - 0 0 0 0 1 0 0 0 1 0 0 0 - 0 0 1 1 1 0 0 0 1 0 0 0 - 1 1 1 1 0 0 1 1 1 1 1 1 - 0 0 0 0 0 0 0 0 0 0 0 0 - 0 0 0 0 0 0 0 0 0 0 0 0 - """ - DE2 = np.genfromtxt(DE2_str.splitlines()) - - expected_best_crops = [[AB], [BC], [DE, DE2]] - expected_best_crops_names = ['AB', 'BC', 'DE'] - - # Run 3 Craft studies on each class, and in each case check if the best crop is the expected one - class_check = [False, False, False] - for class_id in range(3): - # Focus on class class_id - # Selecting subset for class {class_id} : {labels_str[class_id]}' - x_subset = x[np.argmax(y, axis=1)==class_id,:,:,:] - - # fit craft on the selected class - crops, crops_u, w = craft.fit(x_subset, class_id) - - # compute importances - importances = craft.estimate_importance() - assert importances[0] > 0.8 - - # find the best crop and compare it to the expected best crop - most_important_concepts = np.argsort(importances)[::-1] - - # Find the best crop for the most important concept - c_id = most_important_concepts[0] - best_crops_ids = np.argsort(crops_u[:, c_id])[::-1] - best_crop = np.array(crops)[best_crops_ids[0]] - - # Compare this best crop to the expectation - predicted_best_crop = np.where(best_crop.sum(axis=2) > 0.25, 1, 0) - for expected_best_crop in expected_best_crops[class_id]: - expected_best_crop = expected_best_crop.astype(np.uint8) - - comparison = predicted_best_crop == expected_best_crop - acc = np.sum(comparison) / len(comparison.ravel()) - check = acc > 0.9 - if check: - class_check[class_id] = True - break - assert np.all(class_check) diff --git a/tests/concepts/test_craft_torch.py b/tests/concepts/test_craft_torch.py index 0ef214e7..ef46843c 100644 --- a/tests/concepts/test_craft_torch.py +++ b/tests/concepts/test_craft_torch.py @@ -1,14 +1,10 @@ import numpy as np -import os import torch import torch.nn as nn import torch.nn.functional as F import pytest -import random from xplique.concepts import CraftTorch as Craft -from ..utils import generate_txt_images_data -from ..utils import download_file def generate_torch_data(x_shape=(3, 32, 32), num_labels=10, samples=100): x = torch.tensor(np.random.rand(samples, *x_shape).astype(np.float32)) @@ -133,177 +129,3 @@ def test_wrong_layers(): number_of_concepts = number_of_concepts, patch_size = patch_size, batch_size = 64) - -def test_classifier(): - """ Check the Craft results on a small fake dataset """ - - input_shape = (64, 64, 3) - nb_labels = 3 - nb_samples = 200 - - torch.manual_seed(0) - torch.use_deterministic_algorithms(True) - random.seed(0) - np.random.seed(0) - - # Create a dataset of 'ABC', 'BCD', 'CDE' images - x, y, nb_samples, _ = generate_txt_images_data(input_shape, nb_labels, nb_samples) - x = np.moveaxis(x, -1, 1) # reorder the axis to match torch format - x, y = torch.Tensor(x), torch.Tensor(y) - - # train a small classifier on the dataset - def create_torch_classifier_model(input_shape=(3, 64, 64), output_shape=10): - flatten_size = 6*(input_shape[1]-3)*(input_shape[2]-3) - model = nn.Sequential( - nn.Conv2d(3, 6, kernel_size=(2, 2)), - nn.ReLU(), - nn.Conv2d(6, 6, kernel_size=(2, 2)), - nn.ReLU(), - nn.Conv2d(6, 6, kernel_size=(2, 2)), - nn.ReLU(), - nn.Flatten(1, -1), - # nn.Dropout(p=0.2), - nn.Linear(flatten_size, output_shape)) - for layer in model: - if isinstance(layer, nn.Conv2d): - nn.init.kaiming_uniform_(layer.weight, mode='fan_in', nonlinearity='relu') - layer.bias.data.fill_(0.01) - elif isinstance(layer, nn.Linear): - nn.init.xavier_normal_(layer.weight) - layer.bias.data.fill_(0.01) - return model - - model = create_torch_classifier_model((input_shape[-1], *input_shape[0:2]), nb_labels) - - # Retrieve checkpoints - checkpoint_path = "tests/concepts/checkpoints/classifier_test_craft_torch.ckpt" - if not os.path.exists(checkpoint_path): - os.makedirs("tests/concepts/checkpoints/", exist_ok=True) - identifier = "1vz6hMibMEN6_t9yAY9SS4iaMY8G8aAPQ" - download_file(identifier, checkpoint_path) - model.load_state_dict(torch.load(checkpoint_path)) - - # check accuracy - model.eval() - acc = torch.sum(torch.argmax(model(x), axis=1) == torch.argmax(y, axis=1))/len(y) - assert acc > 0.9 - - # cut pytorch model - g = nn.Sequential(*(list(model.children())[:6])) # input to penultimate layer - h = nn.Sequential(*(list(model.children())[6:])) # penultimate layer to logits - assert torch.all(g(x) >= 0.0) - - # Init Craft on the full dataset - craft = Craft(input_to_latent_model = g, - latent_to_logit_model = h, - number_of_concepts = 3, - patch_size = 12, - batch_size = 32, - device='cpu') - - # Expected best crop for class 0 (ABC) is AB - AB_str = """ - 0 0 0 0 0 0 0 0 0 0 0 0 - 0 0 0 0 0 0 0 0 0 0 0 0 - 0 1 1 1 1 1 1 1 1 0 0 0 - 0 0 0 1 0 0 0 1 1 0 0 1 - 0 0 0 1 0 0 0 0 1 0 0 1 - 0 0 0 1 0 0 0 1 1 0 0 1 - 0 0 0 1 1 1 1 1 1 0 0 1 - 0 0 0 1 0 0 0 0 1 1 0 1 - 0 0 0 1 0 0 0 0 0 1 0 1 - 0 0 0 1 0 0 0 0 1 1 0 1 - 1 1 1 1 1 1 1 1 1 0 0 0 - 0 0 0 0 0 0 0 0 0 0 0 0 - """ - AB = np.genfromtxt(AB_str.splitlines()) - - # Expected best crop for class 1 (BCD) is BC - BC_str = """ - 1 1 1 1 1 1 0 0 0 0 1 1 - 1 0 0 0 1 1 0 0 0 1 1 0 - 1 0 0 0 0 1 0 0 0 1 0 0 - 1 0 0 0 1 1 0 0 0 1 0 0 - 1 1 1 1 1 1 0 0 0 1 0 0 - 1 0 0 0 0 1 1 0 0 1 0 0 - 1 0 0 0 0 0 1 0 0 1 0 0 - 1 0 0 0 0 1 1 0 0 1 1 0 - 1 1 1 1 1 1 0 0 0 0 1 1 - 0 0 0 0 0 0 0 0 0 0 0 0 - 0 0 0 0 0 0 0 0 0 0 0 0 - 0 0 0 0 0 0 0 0 0 0 0 0 - """ - BC = np.genfromtxt(BC_str.splitlines()) - - # Expected best crop for class 2 (CDE) is DE - DE_str = """ - 0 0 0 0 0 0 0 0 0 0 0 0 - 1 0 0 1 1 1 1 1 1 1 1 0 - 1 1 0 0 0 1 0 0 0 0 1 0 - 0 1 0 0 0 1 0 0 0 0 1 0 - 0 1 1 0 0 1 0 0 1 0 0 0 - 0 1 1 0 0 1 1 1 1 0 0 0 - 0 1 1 0 0 1 0 0 1 0 0 0 - 0 1 0 0 0 1 0 0 0 0 1 1 - 1 1 0 0 0 1 0 0 0 0 1 1 - 1 0 0 1 1 1 1 1 1 1 1 1 - 0 0 0 0 0 0 0 0 0 0 0 0 - 0 0 0 0 0 0 0 0 0 0 0 0 - """ - DE = np.genfromtxt(DE_str.splitlines()) - - DE2_str = """ - 0 0 0 0 0 0 0 0 0 0 0 0 - 1 1 1 1 0 0 1 1 1 1 1 1 - 0 0 1 1 1 0 0 0 1 0 0 0 - 0 0 0 0 1 0 0 0 1 0 0 0 - 0 0 0 0 1 1 0 0 1 0 0 1 - 0 0 0 0 1 1 0 0 1 1 1 1 - 0 0 0 0 1 1 0 0 1 0 0 1 - 0 0 0 0 1 0 0 0 1 0 0 0 - 0 0 1 1 1 0 0 0 1 0 0 0 - 1 1 1 1 0 0 1 1 1 1 1 1 - 0 0 0 0 0 0 0 0 0 0 0 0 - 0 0 0 0 0 0 0 0 0 0 0 0 - """ - DE2 = np.genfromtxt(DE2_str.splitlines()) - - expected_best_crops = [[AB], [BC], [DE, DE2]] - expected_best_crops_names = ['AB', 'BC', 'DE'] - - # Run 3 Craft studies on each class, and in each case check if the best crop is the expected one - class_check = [False, False, False] - for class_id in range(3): - # Focus on class class_id - # Selecting subset for class {class_id} : {labels_str[class_id]}' - x_subset = x[np.argmax(y, axis=1)==class_id,:,:,:] - - # fit craft on the selected class - crops, crops_u, w = craft.fit(x_subset, class_id) - - # compute importances - importances = craft.estimate_importance() - assert np.all(importances >= 0) - - # find the best crop and compare it to the expected best crop - most_important_concepts = np.argsort(importances)[::-1] - - # Find the best crop for the most important concept - c_id = most_important_concepts[0] - best_crops_ids = np.argsort(crops_u[:, c_id])[::-1] - best_crop = np.array(crops)[best_crops_ids[0]] - best_crop = np.moveaxis(best_crop, 0, -1) - - # Compare this best crop to the expectation - predicted_best_crop = np.where(best_crop.sum(axis=2) > 0.25, 1, 0) - - # Comparison between expected: - for expected_best_crop in expected_best_crops[class_id]: - expected_best_crop = expected_best_crop.astype(np.uint8) - comparison = predicted_best_crop == expected_best_crop - acc = np.sum(comparison) / len(comparison.ravel()) - check = acc > 0.9 - if check: - class_check[class_id] = True - break - assert np.all(class_check) diff --git a/tests/example_based/__init__.py b/tests/example_based/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/example_based/test_cole.py b/tests/example_based/test_cole.py new file mode 100644 index 00000000..f94ea24a --- /dev/null +++ b/tests/example_based/test_cole.py @@ -0,0 +1,216 @@ +""" +Test Cole +""" +import os + +import sys + +sys.path.append(os.getcwd()) + +import numpy as np +import tensorflow as tf + +from xplique.commons.operators_operations import gradients_predictions +from xplique.attributions import Occlusion, Saliency +from xplique.example_based import Cole, SimilarExamples +from xplique.example_based.projections import Projection + +from tests.utils import ( + generate_data, + generate_model, + almost_equal, + generate_timeseries_model, +) + + +def get_setup(input_shape, nb_samples=10, nb_labels=10): + """ + Generate data and model for Cole + """ + # Data generation + x_train = tf.stack( + [i * tf.ones(input_shape, tf.float32) for i in range(nb_samples)] + ) + x_test = x_train[1:-1] + y_train = tf.one_hot(tf.range(len(x_train)) % nb_labels, depth=nb_labels) + y_test = y_train[1:-1] + + # Model generation + model = generate_model(input_shape, nb_labels) + + return model, x_train, x_test, y_train, y_test + + +def test_cole_attribution(): + """ + Test Cole attribution projection. + It should be the same as a manual projection. + Test that the distance has an impact. + """ + # Setup + nb_samples = 50 + input_shape = (5, 5) + nb_labels = 10 + k = 3 + x_train = tf.random.uniform( + (nb_samples,) + input_shape, minval=-1, maxval=1, seed=0 + ) + x_test = tf.random.uniform((nb_samples,) + input_shape, minval=-1, maxval=1, seed=2) + labels = tf.one_hot( + indices=tf.repeat(input=tf.range(nb_labels), repeats=[nb_samples // nb_labels]), + depth=nb_labels, + ) + y_train = labels + y_test = tf.random.shuffle(labels, seed=1) + + # Model generation + model = generate_timeseries_model(input_shape, nb_labels) + + # Cole with attribution method constructor + method_constructor = Cole( + cases_dataset=x_train, + targets_dataset=y_train, + k=k, + batch_size=7, + distance="euclidean", + model=model, + attribution_method=Saliency, + ) + + # Cole with attribution explain batch gradient is overwritten for test purpose, do not copy! + explainer = Saliency(model) + explainer.batch_gradient = \ + lambda model, inputs, targets, batch_size:\ + explainer.gradient(model, inputs, targets) + projection = Projection(get_weights=explainer) + + euclidean_dist = lambda x, z: tf.sqrt(tf.reduce_sum(tf.square(x - z), axis=-1)) + method_call = SimilarExamples( + cases_dataset=x_train, + targets_dataset=y_train, + k=k, + distance=euclidean_dist, + projection=projection, + ) + + method_different_distance = Cole( + cases_dataset=x_train, + targets_dataset=y_train, + k=k, + batch_size=2, + distance="cosine", # infinity norm based distance + model=model, + attribution_method=Saliency, + ) + + # Generate explanation + examples_constructor = method_constructor.explain(x_test, y_test)["examples"] + examples_call = method_call.explain(x_test, y_test)["examples"] + examples_different_distance = method_different_distance(x_test, y_test)["examples"] + + # Verifications + # Shape should be (n, k, h, w, c) + assert examples_constructor.shape == (len(x_test), k) + input_shape + assert examples_call.shape == (len(x_test), k) + input_shape + assert examples_different_distance.shape == (len(x_test), k) + input_shape + + # both methods should be the same + assert almost_equal(examples_constructor, examples_call) + + # a different distance should give different results + assert not almost_equal(examples_constructor, examples_different_distance) + + +def test_cole_hadamard(): + """ + Test Cole with Hadamard projection. + It should be the same as a manual projection. + """ + # Setup + input_shape = (7, 7, 3) + nb_samples = 10 + nb_labels = 2 + k = 3 + model, x_train, x_test, y_train, y_test =\ + get_setup(input_shape, nb_samples=nb_samples, nb_labels=nb_labels) + + # Cole with Hadamard projection constructor + method_constructor = Cole( + cases_dataset=x_train, + targets_dataset=y_train, + k=k, + batch_size=7, + distance="euclidean", + model=model, + projection_method="gradient", + ) + + # Cole with Hadamard projection explain batch gradient is overwritten for test purpose, do not copy! + weights_extraction = lambda inputs, targets: gradients_predictions(model, inputs, targets) + projection = Projection(get_weights=weights_extraction) + + euclidean_dist = lambda x, z: tf.sqrt(tf.reduce_sum(tf.square(x - z), axis=-1)) + method_call = SimilarExamples( + cases_dataset=x_train, + targets_dataset=y_train, + k=k, + distance=euclidean_dist, + projection=projection, + ) + + # Generate explanation + examples_constructor = method_constructor.explain(x_test, y_test)["examples"] + examples_call = method_call.explain(x_test, y_test)["examples"] + + # Verifications + # Shape should be (n, k, h, w, c) + assert examples_constructor.shape == (len(x_test), k) + input_shape + assert examples_call.shape == (len(x_test), k) + input_shape + + # both methods should be the same + assert almost_equal(examples_constructor, examples_call) + + +def test_cole_splitting(): + """ + Test Cole with a `latent_layer` provided. + It should split the model. + """ + # Setup + nb_samples = 10 + input_shape = (6, 6, 3) + nb_labels = 5 + k = 1 + x_train = tf.random.uniform((nb_samples,) + input_shape, minval=0, maxval=1) + x_test = tf.random.uniform((nb_samples,) + input_shape, minval=0, maxval=1) + labels = tf.one_hot( + indices=tf.repeat(input=tf.range(nb_labels), repeats=[nb_samples // nb_labels]), + depth=nb_labels, + ) + y_train = labels + y_test = tf.random.shuffle(labels) + + # Model generation + model = generate_model(input_shape, nb_labels) + + # Cole with attribution method constructor + method = Cole( + cases_dataset=x_train, + targets_dataset=y_train, + k=k, + case_returns=["examples", "include_inputs"], + model=model, + latent_layer="last_conv", + attribution_method=Occlusion, + patch_size=2, + patch_stride=1, + ) + + # Generate explanation + outputs = method.explain(x_test, y_test) + examples = outputs["examples"] + + # Verifications + # Shape should be (n, k, h, w, c) + nb_samples_test = x_test.shape[0] + assert examples.shape == (nb_samples_test, k + 1) + input_shape diff --git a/tests/example_based/test_contrastive.py b/tests/example_based/test_contrastive.py new file mode 100644 index 00000000..d136d2c4 --- /dev/null +++ b/tests/example_based/test_contrastive.py @@ -0,0 +1,318 @@ +""" +Tests for the contrastive methods. +""" +import tensorflow as tf +import numpy as np + +from xplique.example_based import NaiveCounterFactuals, LabelAwareCounterFactuals, KLEORGlobalSim, KLEORSimMiss +from xplique.example_based.projections import Projection, LatentSpaceProjection + +from ..utils import generate_data, generate_model + + +def test_naive_counter_factuals(): + """ + """ + # setup the tests + cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32) + cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32) + + cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2) + cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2) + + inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32) + targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32) + + projection = Projection(space_projection=lambda inputs: inputs) + + # build the NaiveCounterFactuals object + counter_factuals = NaiveCounterFactuals( + cases_dataset, + cases_targets_dataset, + k=2, + projection=projection, + case_returns=["examples", "indices", "distances", "include_inputs"], + batch_size=2 + ) + + mask = counter_factuals.filter_fn(inputs, cases, targets, cases_targets) + assert mask.shape == (inputs.shape[0], cases.shape[0]) + + expected_mask = tf.constant([ + [False, True, True, False, True], + [True, False, False, True, False], + [True, False, False, True, False]], dtype=tf.bool) + assert tf.reduce_all(tf.equal(mask, expected_mask)) + + return_dict = counter_factuals(inputs, targets) + assert set(return_dict.keys()) == set(["examples", "indices", "distances"]) + + examples = return_dict["examples"] + distances = return_dict["distances"] + indices = return_dict["indices"] + + assert examples.shape == (3, 3, 2) # (n, k+1, W) + assert distances.shape == (3, 2) # (n, k) + assert indices.shape == (3, 2, 2) # (n, k, 2) + + expected_examples = tf.constant([ + [[1.5, 2.5], [2., 3.], [3., 4.]], + [[2.5, 3.5], [1., 2.], [4., 5.]], + [[4.5, 5.5], [4., 5.], [1., 2.]]], dtype=tf.float32) + assert tf.reduce_all(tf.equal(examples, expected_examples)) + + expected_distances = tf.constant([[np.sqrt(2*0.5**2), np.sqrt(2*1.5**2)], [np.sqrt(2*1.5**2), np.sqrt(2*1.5**2)], [np.sqrt(2*0.5**2), np.sqrt(2*3.5**2)]], dtype=tf.float32) + assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5) + + expected_indices = tf.constant([[[0, 1], [1, 0]],[[0, 0], [1, 1]],[[1, 1], [0, 0]]], dtype=tf.int32) + assert tf.reduce_all(tf.equal(indices, expected_indices)) + + +def test_label_aware_cf(): + """ + Test suite for the LabelAwareCounterFactuals class + """ + # Same tests as the previous one but with the LabelAwareCounterFactuals class + # thus we only needs to use cf_targets = 1 - targets of the previous tests + cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32) + cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32) + + cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2) + cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2) + + inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32) + # cf_targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32) + cf_expected_classes = tf.constant([[1, 0], [0, 1], [0, 1]], dtype=tf.float32) + + projection = Projection(space_projection=lambda inputs: inputs) + + # build the LabelAwareCounterFactuals object + counter_factuals = LabelAwareCounterFactuals( + cases_dataset=cases_dataset, + targets_dataset=cases_targets_dataset, + k=1, + projection=projection, + case_returns=["examples", "indices", "distances", "include_inputs"], + batch_size=2 + ) + + mask = counter_factuals.filter_fn(inputs, cases, cf_expected_classes, cases_targets) + assert mask.shape == (inputs.shape[0], cases.shape[0]) + + expected_mask = tf.constant([ + [False, True, True, False, True], + [True, False, False, True, False], + [True, False, False, True, False]], dtype=tf.bool) + assert tf.reduce_all(tf.equal(mask, expected_mask)) + + return_dict = counter_factuals(inputs, targets=None, cf_expected_classes=cf_expected_classes) + assert set(return_dict.keys()) == set(["examples", "indices", "distances"]) + + examples = return_dict["examples"] + distances = return_dict["distances"] + indices = return_dict["indices"] + + assert examples.shape == (3, 2, 2) # (n, k+1, W) + assert distances.shape == (3, 1) # (n, k) + assert indices.shape == (3, 1, 2) # (n, k, 2) + + expected_examples = tf.constant([ + [[1.5, 2.5], [2., 3.]], + [[2.5, 3.5], [1., 2.]], + [[4.5, 5.5], [4., 5.]]], dtype=tf.float32) + assert tf.reduce_all(tf.equal(examples, expected_examples)) + + expected_distances = tf.constant([[np.sqrt(2*0.5**2)], [np.sqrt(2*1.5**2)], [np.sqrt(2*0.5**2)]], dtype=tf.float32) + assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5) + + expected_indices = tf.constant([[[0, 1]],[[0, 0]],[[1, 1]]], dtype=tf.int32) + assert tf.reduce_all(tf.equal(indices, expected_indices)) + + # Now let's dive when multiple classes are available in 1D + cases = tf.constant([[1.], [2.], [3.], [4.], [5.], [6.], [7.], [8.], [9.], [10.]], dtype=tf.float32) + cases_targets = tf.constant([[0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=tf.float32) + + cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2) + cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2) + + counter_factuals = LabelAwareCounterFactuals( + cases_dataset=cases_dataset, + targets_dataset=cases_targets_dataset, + k=1, + projection=projection, + case_returns=["examples", "indices", "distances", "include_inputs"], + batch_size=2 + ) + + inputs = tf.constant([[1.5], [2.5], [4.5], [6.5], [8.5]], dtype=tf.float32) + cf_expected_classes = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0]], dtype=tf.float32) + + mask = counter_factuals.filter_fn(inputs, cases, cf_expected_classes, cases_targets) + assert mask.shape == (inputs.shape[0], cases.shape[0]) + + expected_mask = tf.constant([ + [False, True, False, True, True, False, False, False, False, True], + [True, False, False, False, False, False, True, False, True, False], + [False, False, True, False, False, True, False, True, False, False], + [False, False, True, False, False, True, False, True, False, False], + [True, False, False, False, False, False, True, False, True, False]], dtype=tf.bool) + assert tf.reduce_all(tf.equal(mask, expected_mask)) + + return_dict = counter_factuals(inputs, cf_expected_classes=cf_expected_classes) + assert set(return_dict.keys()) == set(["examples", "indices", "distances"]) + + examples = return_dict["examples"] + distances = return_dict["distances"] + indices = return_dict["indices"] + + assert examples.shape == (5, 2, 1) # (n, k+1, W) + assert distances.shape == (5, 1) # (n, k) + assert indices.shape == (5, 1, 2) # (n, k, 2) + + expected_examples = tf.constant([ + [[1.5], [2.]], + [[2.5], [1.]], + [[4.5], [3.]], + [[6.5], [6.]], + [[8.5], [9.]]], dtype=tf.float32) + assert tf.reduce_all(tf.equal(examples, expected_examples)) + + expected_distances = tf.constant([[np.sqrt(0.5**2)], [np.sqrt(1.5**2)], [np.sqrt(1.5**2)], [np.sqrt(0.5**2)], [np.sqrt(0.5**2)]], dtype=tf.float32) + assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5) + + expected_indices = tf.constant([[[0, 1]],[[0, 0]],[[1, 0]],[[2, 1]],[[4, 0]]], dtype=tf.int32) + assert tf.reduce_all(tf.equal(indices, expected_indices)) + + +def test_kleor(): + """ + Test suite for the Kleor class + """ + # setup the tests + cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32) + cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32) + + cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2) + cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2) + + inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32) + targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32) + + projection = Projection(space_projection=lambda inputs: inputs) + + # start when strategy is sim_miss + kleor_sim_miss = KLEORSimMiss( + cases_dataset=cases_dataset, + targets_dataset=cases_targets_dataset, + k=1, + projection=projection, + case_returns=["examples", "indices", "distances", "include_inputs", "nuns"], + batch_size=2, + ) + + return_dict = kleor_sim_miss(inputs, targets) + assert set(return_dict.keys()) == set(["examples", "indices", "distances", "nuns"]) + + examples = return_dict["examples"] + distances = return_dict["distances"] + indices = return_dict["indices"] + nuns = return_dict["nuns"] + + expected_nuns = tf.constant([ + [[2., 3.]], + [[1., 2.]], + [[4., 5.]]], dtype=tf.float32) + assert tf.reduce_all(tf.equal(nuns, expected_nuns)) + + assert examples.shape == (3, 2, 2) # (n, k+1, W) + assert distances.shape == (3, 1) # (n, k) + assert indices.shape == (3, 1, 2) # (n, k, 2) + + expected_examples = tf.constant([ + [[1.5, 2.5], [1., 2.]], + [[2.5, 3.5], [2., 3.]], + [[4.5, 5.5], [3., 4.]]], dtype=tf.float32) + assert tf.reduce_all(tf.equal(examples, expected_examples)) + + expected_distances = tf.constant([[np.sqrt(2*0.5**2)], [np.sqrt(2*0.5**2)], [np.sqrt(2*1.5**2)]], dtype=tf.float32) + assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5) + + expected_indices = tf.constant([[[0, 0]],[[0, 1]],[[1, 0]]], dtype=tf.int32) + assert tf.reduce_all(tf.equal(indices, expected_indices)) + + # now strategy is global_sim + kleor_global_sim = KLEORGlobalSim( + cases_dataset, + cases_targets_dataset, + k=1, + projection=projection, + case_returns=["examples", "indices", "distances", "include_inputs", "nuns"], + batch_size=2, + ) + + return_dict = kleor_global_sim(inputs, targets) + assert set(return_dict.keys()) == set(["examples", "indices", "distances", "nuns"]) + + nuns = return_dict["nuns"] + assert tf.reduce_all(tf.equal(nuns, expected_nuns)) + + examples = return_dict["examples"] + distances = return_dict["distances"] + indices = return_dict["indices"] + + assert examples.shape == (3, 2, 2) # (n, k+1, W) + assert distances.shape == (3, 1) # (n, k) + assert indices.shape == (3, 1, 2) # (n, k, 2) + + expected_indices = tf.constant([[[-1, -1]],[[0, 1]],[[-1, -1]]], dtype=tf.int32) + assert tf.reduce_all(tf.equal(indices, expected_indices)) + + expected_distances = tf.constant([[np.inf], [np.sqrt(2*0.5**2)], [np.inf]], dtype=tf.float32) + # create masks for inf values + inf_mask_dist = tf.math.is_inf(distances) + inf_mask_expected_distances = tf.math.is_inf(expected_distances) + assert tf.reduce_all(tf.equal(inf_mask_dist, inf_mask_expected_distances)) + assert tf.reduce_all( + tf.abs(tf.where(inf_mask_dist, 0.0, distances) - tf.where(inf_mask_expected_distances, 0.0, expected_distances) + ) < 1e-5) + + expected_examples = tf.constant([ + [[1.5, 2.5], [np.inf, np.inf]], + [[2.5, 3.5], [2., 3.]], + [[4.5, 5.5], [np.inf, np.inf]]], dtype=tf.float32) + # mask for inf values + inf_mask_examples = tf.math.is_inf(examples) + inf_mask_expected_examples = tf.math.is_inf(expected_examples) + assert tf.reduce_all(tf.equal(inf_mask_examples, inf_mask_expected_examples)) + assert tf.reduce_all( + tf.abs(tf.where(inf_mask_examples, 0.0, examples) - tf.where(inf_mask_expected_examples, 0.0, expected_examples) + ) < 1e-5) + + +def test_contrastive_with_projection(): + input_shapes = [(28, 28, 1), (32, 32, 3)] + nb_labels = 10 + nb_samples = 50 + + for input_shape in input_shapes: + features, labels = generate_data(input_shape, nb_labels, nb_samples) + model = generate_model(input_shape, nb_labels) + + projection = LatentSpaceProjection(model, latent_layer=-1) + + for contrastive_method_class in [NaiveCounterFactuals, LabelAwareCounterFactuals, + KLEORGlobalSim, KLEORSimMiss]: + contrastive_method = contrastive_method_class( + features, + labels, + k=1, + projection=projection, + case_returns=["examples", "indices", "distances", "include_inputs"], + batch_size=7 + ) + + if isinstance(contrastive_method, LabelAwareCounterFactuals): + cf_expected_classes = tf.one_hot(tf.argmax(labels, axis=-1) + 1 % nb_labels, nb_labels) + contrastive_method(features, targets=labels, cf_expected_classes=cf_expected_classes) + else: + contrastive_method(features, targets=labels) \ No newline at end of file diff --git a/tests/example_based/test_datasets_harmonization.py b/tests/example_based/test_datasets_harmonization.py new file mode 100644 index 00000000..70d5efab --- /dev/null +++ b/tests/example_based/test_datasets_harmonization.py @@ -0,0 +1,229 @@ +import pytest +import tensorflow as tf +import numpy as np + + +from xplique.example_based.datasets_operations.tf_dataset_operations import are_dataset_first_elems_equal, is_batched +from xplique.example_based.datasets_operations.harmonize import split_tf_dataset, harmonize_datasets + + +def generate_tf_dataset(n_samples=100, n_features=10, n_labels=1, n_targets=None, batch_size=None): + """ + Utility function to generate TensorFlow datasets for testing. + """ + cases = np.random.random((n_samples, n_features, n_features)).astype(np.float32) + labels = np.random.randint(0, n_labels, size=(n_samples,)).astype(np.int64) + + if n_targets is not None: + targets = np.random.random((n_samples, n_targets)).astype(np.float32) + dataset = tf.data.Dataset.from_tensor_slices((cases, labels, targets)) + else: + dataset = tf.data.Dataset.from_tensor_slices((cases, labels)) + + if batch_size is not None: + dataset = dataset.batch(batch_size) + + return dataset + + +def test_split_tf_dataset_two_columns(): + dataset = generate_tf_dataset(n_samples=100, n_features=5, n_labels=2, batch_size=8) + + cases, labels, targets = split_tf_dataset(dataset) + + assert labels is not None, "Labels dataset should not be None for a 2-column dataset." + assert targets is None, "Targets dataset should be None for a 2-column dataset." + + for case_h, label_h, (case, label) in zip(cases, labels, dataset): + assert len(case_h.shape) == 3 and case_h.shape[1:] == (5, 5) + assert len(label_h.shape) == 1 + assert np.allclose(case_h, case), "Cases should match the original dataset." + assert np.allclose(label_h, label), "Labels should match the original dataset." + + +def test_split_tf_dataset_three_columns(): + dataset = generate_tf_dataset(n_samples=100, n_features=5, n_labels=2, n_targets=2, batch_size=8) + + cases, labels, targets = split_tf_dataset(dataset) + + assert labels is not None, "Labels dataset should not be None for a 3-column dataset." + assert targets is not None, "Targets dataset should not be None for a 3-column dataset." + + for case_h, label_h, target_h, (case, label, target) in zip(cases, labels, targets, dataset): + assert len(case_h.shape) == 3 and case_h.shape[1:] == (5, 5) + assert len(label_h.shape) == 1 + assert len(target_h.shape) == 2 and target_h.shape[1] == 2 + assert np.allclose(case_h, case), "Cases should match the original dataset." + assert np.allclose(label_h, label), "Labels should match the original dataset." + assert np.allclose(target_h, target), "Targets should match the original dataset." + + +def test_harmonize_datasets_with_tf_dataset(): + nb_features = 5 + nb_labels = 3 + dataset = generate_tf_dataset(n_samples=100, n_features=nb_features, n_labels=nb_labels) + batch_size = 10 + + assert not is_batched(dataset), "Dataset should not be batched." + + cases, labels, targets, batch_size_out = harmonize_datasets(dataset, batch_size=batch_size) + batched_dataset = dataset.batch(10) + + assert is_batched(cases), "Cases dataset should be batched." + assert is_batched(labels), "Labels dataset should be batched." + + assert cases is not None, "Cases dataset should not be None." + assert labels is not None, "Labels dataset should not be None." + assert targets is None, "Targets dataset should be None for a 2-column input dataset." + assert batch_size_out == batch_size, "Output batch size should match the input batch size." + + for case_h, label_h, (case, label) in zip(cases, labels, batched_dataset): + assert len(case_h.shape) == 3 and case_h.shape[1:] == (nb_features, nb_features) + assert len(label_h.shape) == 1 + assert np.allclose(case_h, case), "Cases should match the original dataset." + assert np.allclose(label_h, label), "Labels should match the original dataset." + + +def test_harmonize_datasets_with_tf_dataset_three_columns(): + batch_size = 10 + dataset = generate_tf_dataset(n_samples=100, n_features=10, n_labels=1, n_targets=1, batch_size=batch_size) + + cases, labels, targets, batch_size_out = harmonize_datasets(dataset, batch_size=batch_size) + + assert cases is not None, "Cases dataset should not be None." + assert labels is not None, "Labels dataset should not be None." + assert targets is not None, "Targets dataset should not be None for a 3-column input dataset." + assert batch_size_out == batch_size, "Output batch size should match the input batch size." + + +def test_harmonize_datasets_with_numpy(): + cases = np.random.random((100, 10)).astype(np.float32) + labels = np.random.randint(0, 2, size=(100, 1)).astype(np.int64) + batch_size = 10 + + cases_out, labels_out, targets_out, batch_size_out = harmonize_datasets(cases, labels, batch_size=batch_size) + + assert targets_out is None, "Targets should be None when not provided." + assert batch_size_out == batch_size, "Output batch size should match the input batch size." + + for case, label in zip(cases_out, labels_out): + assert case.shape == (batch_size, cases.shape[1]), "Each case should have the same shape as the input cases." + assert label.shape == (batch_size, labels.shape[1]), "Each label should have the same shape as the input labels." + break + + +def test_inputs_combinations(): + """ + Test management of dataset init inputs + """ + + tf_tensor = tf.reshape(tf.range(90, dtype=tf.float32), (10, 3, 3)) + np_array = np.array(tf_tensor) + tf_dataset = tf.data.Dataset.from_tensor_slices(tf_tensor) + + tf_dataset_b3 = tf_dataset.batch(3) + tf_dataset_b5 = tf_dataset.batch(5) + + tf_one_shuffle = tf_dataset.shuffle(32, 0, reshuffle_each_iteration=False).batch(4) + + # Method initialization that should work + cases_dataset, labels_dataset, targets_dataset, batch_size = harmonize_datasets(tf_dataset_b3, None, tf_dataset_b3) + assert are_dataset_first_elems_equal(cases_dataset, tf_dataset_b3) + assert are_dataset_first_elems_equal(labels_dataset, None) + assert are_dataset_first_elems_equal(targets_dataset, tf_dataset_b3) + assert batch_size == 3 + + cases_dataset, labels_dataset, targets_dataset, batch_size = harmonize_datasets(tf_tensor, tf_tensor, None, batch_size=5) + assert are_dataset_first_elems_equal(cases_dataset, tf_dataset_b5) + assert are_dataset_first_elems_equal(labels_dataset, tf_dataset_b5) + assert are_dataset_first_elems_equal(targets_dataset, None) + assert batch_size == 5 + + cases_dataset, labels_dataset, targets_dataset, batch_size =\ + harmonize_datasets(tf.data.Dataset.zip((tf_dataset_b5, tf_dataset_b5)), None, tf_dataset_b5) + assert are_dataset_first_elems_equal(cases_dataset, tf_dataset_b5) + assert are_dataset_first_elems_equal(labels_dataset, tf_dataset_b5) + assert are_dataset_first_elems_equal(targets_dataset, tf_dataset_b5) + assert batch_size == 5 + + cases_dataset, labels_dataset, targets_dataset, batch_size =\ + harmonize_datasets(tf.data.Dataset.zip((tf_dataset_b5, tf_dataset_b5, tf_dataset_b5))) + assert are_dataset_first_elems_equal(cases_dataset, tf_dataset_b5) + assert are_dataset_first_elems_equal(labels_dataset, tf_dataset_b5) + assert are_dataset_first_elems_equal(targets_dataset, tf_dataset_b5) + assert batch_size == 5 + + cases_dataset, labels_dataset, targets_dataset, batch_size =\ + harmonize_datasets(tf.data.Dataset.zip((tf_one_shuffle, tf_one_shuffle))) + assert are_dataset_first_elems_equal(cases_dataset, tf_one_shuffle) + assert are_dataset_first_elems_equal(labels_dataset, tf_one_shuffle) + assert are_dataset_first_elems_equal(targets_dataset, None) + assert batch_size == 4 + + cases_dataset, labels_dataset, targets_dataset, batch_size = harmonize_datasets(tf_one_shuffle) + assert are_dataset_first_elems_equal(cases_dataset, tf_one_shuffle) + assert are_dataset_first_elems_equal(labels_dataset, None) + assert are_dataset_first_elems_equal(targets_dataset, None) + assert batch_size == 4 + + +def test_error_raising(): + """ + Test management of dataset init inputs + """ + + tf_tensor = tf.reshape(tf.range(90, dtype=tf.float32), (10, 3, 3)) + np_array = np.array(tf_tensor) + tf_dataset = tf.data.Dataset.from_tensor_slices(tf_tensor) + too_short_np_array = np_array[:3] + too_long_tf_dataset = tf_dataset.concatenate(tf_dataset) + + tf_dataset_b3 = tf_dataset.batch(3) + tf_dataset_b5 = tf_dataset.batch(5) + too_long_tf_dataset_b5 = too_long_tf_dataset.batch(5) + too_long_tf_dataset_b10 = too_long_tf_dataset.batch(10) + + tf_shuffled = tf_dataset.shuffle(32, 0).batch(4) + + # Method initialization that should not work + # not input + with pytest.raises(TypeError): + harmonize_datasets() + + # shuffled + with pytest.raises(AssertionError): + harmonize_datasets(tf_shuffled) + + # mismatching types + with pytest.raises(AssertionError): + harmonize_datasets(tf_dataset, tf_tensor) + with pytest.raises(AssertionError): + harmonize_datasets(tf.data.Dataset.zip((tf_dataset_b5, tf_dataset_b5)), np_array) + with pytest.raises(AssertionError): + harmonize_datasets(tf_dataset_b3, too_short_np_array) + with pytest.raises(AssertionError): + harmonize_datasets(tf_dataset, None, too_long_tf_dataset) + + # not batched and no batch size provided + with pytest.raises(AssertionError): + harmonize_datasets(tf.data.Dataset.from_tensor_slices((tf_tensor, tf_tensor)), tf_dataset,) + + # not matching batch sizes + with pytest.raises(AssertionError): + harmonize_datasets(tf_dataset_b3, tf_dataset_b5,) + with pytest.raises(AssertionError): + harmonize_datasets(too_long_tf_dataset_b10, tf_dataset_b5,) + + # mismatching cardinality + with pytest.raises(AssertionError): + harmonize_datasets(tf_dataset_b5, too_long_tf_dataset_b5,) + + # multiple datasets for labels or targets + with pytest.raises(AssertionError): + harmonize_datasets(tf.data.Dataset.zip((tf_dataset_b5, tf_dataset_b5)), tf_dataset_b5,) + with pytest.raises(AssertionError): + harmonize_datasets( + tf.data.Dataset.zip((tf_dataset_b5, tf_dataset_b5, tf_dataset_b5)), + None, + tf_dataset_b5, + ) diff --git a/tests/example_based/test_kleor.py b/tests/example_based/test_kleor.py new file mode 100644 index 00000000..cd2cd333 --- /dev/null +++ b/tests/example_based/test_kleor.py @@ -0,0 +1,212 @@ +""" +Tests for the contrastive methods. +""" +import tensorflow as tf +import numpy as np + +from xplique.example_based.search_methods import KLEORSimMissSearch, KLEORGlobalSimSearch + +def test_kleor_base_and_sim_miss(): + """ + Test suite for both the BaseKLEOR and KLEORSimMiss class. Indeed, the KLEORSimMiss class is a subclass of the + BaseKLEOR class with a very basic implementation of the only abstract method (identity function). + """ + # setup the tests + cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32) + cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32) + + cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2) + cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2) + + inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32) + targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32) + + # build the kleor object + kleor = KLEORSimMissSearch(cases_dataset, cases_targets_dataset, k=1, search_returns=["examples", "indices", "distances", "include_inputs", "nuns"], batch_size=2) + + # test the _filter_fn method + fake_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32) + fake_cases_targets = tf.constant([[0, 1], [1, 0], [0, 1], [1, 0], [1, 0]], dtype=tf.float32) + # the mask should be True when the targets are the same i.e we keep those cases + expected_mask = tf.constant([[True, False, True, False, False], + [False, True, False, True, True], + [False, True, False, True, True], + [True, False, True, False, False], + [False, True, False, True, True]], dtype=tf.bool) + mask = kleor._filter_fn(inputs, cases, fake_targets, fake_cases_targets) + assert tf.reduce_all(tf.equal(mask, expected_mask)) + + # test the _filter_fn_nun method, this time the mask should be True when the targets are different + expected_mask = tf.constant([[False, True, False, True, True], + [True, False, True, False, False], + [True, False, True, False, False], + [False, True, False, True, True], + [True, False, True, False, False]], dtype=tf.bool) + mask = kleor._filter_fn_nun(inputs, cases, fake_targets, fake_cases_targets) + assert tf.reduce_all(tf.equal(mask, expected_mask)) + + # test the _get_nuns method + nuns, _, nuns_distances = kleor._get_nuns(inputs, targets) + expected_nuns = tf.constant([ + [[2., 3.]], + [[1., 2.]], + [[4., 5.]]], dtype=tf.float32) + expected_nuns_distances = tf.constant([ + [np.sqrt(2*0.5**2)], + [np.sqrt(2*1.5**2)], + [np.sqrt(2*0.5**2)]], dtype=tf.float32) + assert tf.reduce_all(tf.equal(nuns, expected_nuns)) + assert tf.reduce_all(tf.abs(nuns_distances - expected_nuns_distances) < 1e-5) + + # test the _initialize_search method + sf_indices, input_sf_distances, nun_sf_distances, batch_indices = kleor._initialize_search(inputs) + assert sf_indices.shape == (3, 1, 2) # (n, k, 2) + assert input_sf_distances.shape == (3, 1) # (n, k) + assert nun_sf_distances.shape == (3, 1) # (n, k) + assert batch_indices.shape == (3, 2) # (n, bs) + expected_sf_indices = tf.constant([[[-1, -1]],[[-1, -1]],[[-1, -1]]], dtype=tf.int32) + assert tf.reduce_all(tf.equal(sf_indices, expected_sf_indices)) + assert tf.reduce_all(tf.math.is_inf(input_sf_distances)) + assert tf.reduce_all(tf.math.is_inf(nun_sf_distances)) + expected_batch_indices = tf.constant([[0, 1], [0, 1], [0, 1]], dtype=tf.int32) + assert tf.reduce_all(tf.equal(batch_indices, expected_batch_indices)) + + # test the kneighbors method + input_sf_distances, sf_indices, nuns, _, __ = kleor.kneighbors(inputs, targets) + + assert input_sf_distances.shape == (3, 1) # (n, k) + assert sf_indices.shape == (3, 1, 2) # (n, k, 2) + assert nuns.shape == (3, 1, 2) # (n, k, 2) + + assert tf.reduce_all(tf.equal(nuns, expected_nuns)) + + expected_distances = tf.constant([[np.sqrt(2*0.5**2)], [np.sqrt(2*0.5**2)], [np.sqrt(2*1.5**2)]], dtype=tf.float32) + assert tf.reduce_all(tf.abs(input_sf_distances - expected_distances) < 1e-5) + + expected_indices = tf.constant([[[0, 0]],[[0, 1]],[[1, 0]]], dtype=tf.int32) + assert tf.reduce_all(tf.equal(sf_indices, expected_indices)) + + # test the find_examples method + return_dict = kleor.find_examples(inputs, targets) + assert set(return_dict.keys()) == set(["examples", "indices", "distances", "nuns"]) + + examples = return_dict["examples"] + distances = return_dict["distances"] + indices = return_dict["indices"] + nuns = return_dict["nuns"] + + assert tf.reduce_all(tf.equal(nuns, expected_nuns)) + assert tf.reduce_all(tf.equal(expected_indices, indices)) + assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5) + + expected_examples = tf.constant([ + [[1.5, 2.5], [1., 2.]], + [[2.5, 3.5], [2., 3.]], + [[4.5, 5.5], [3., 4.]]], dtype=tf.float32) + assert tf.reduce_all(tf.equal(examples, expected_examples)) + +def test_kleor_global_sim(): + """ + Test suite for the KleorGlobalSim class. As only the kneighbors, format_output are impacted by the + _additionnal_filtering method we test those 3 methods. + """ + # setup the tests + cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32) + cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32) + + cases_dataset = tf.data.Dataset.from_tensor_slices(cases).batch(2) + cases_targets_dataset = tf.data.Dataset.from_tensor_slices(cases_targets).batch(2) + + inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32) + targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32) + + # build the kleor object + kleor = KLEORGlobalSimSearch(cases_dataset, cases_targets_dataset, k=1, search_returns=["examples", "indices", "distances", "include_inputs", "nuns"], batch_size=2) + + # test the _additionnal_filtering method + # (n, bs) + fake_nun_sf_distances = tf.constant([[1., 2.], [2., 3.], [3., 4.]]) + # (n, bs) + fake_input_sf_distances = tf.constant([[2., 1.], [3., 2.], [2., 5.]]) + # (n,1) + fake_nuns_input_distances = tf.constant([[3.], [1.], [4.]]) + # the expected filtering should be such that we keep the distance of a sf candidates + # when the input is closer to the sf than the nun, otherwise we set it to infinity + expected_nun_sf_distances = tf.constant([[1., 2.], [np.inf, np.inf], [3., np.inf]], dtype=tf.float32) + expected_input_sf_distances = tf.constant([[2., 1.], [np.inf, np.inf], [2., np.inf]], dtype=tf.float32) + + nun_sf_distances, input_sf_distances = kleor._additional_filtering(fake_nun_sf_distances, fake_input_sf_distances, fake_nuns_input_distances) + assert nun_sf_distances.shape == (3, 2) + assert input_sf_distances.shape == (3, 2) + + inf_mask_expected_nun_sf = tf.math.is_inf(expected_nun_sf_distances) + inf_mask_nun_sf = tf.math.is_inf(nun_sf_distances) + assert tf.reduce_all(tf.equal(inf_mask_expected_nun_sf, inf_mask_nun_sf)) + assert tf.reduce_all( + tf.abs(tf.where(inf_mask_nun_sf, 0.0, nun_sf_distances) - tf.where(inf_mask_expected_nun_sf, 0.0, expected_nun_sf_distances) + ) < 1e-5) + + inf_mask_expected_input_sf = tf.math.is_inf(expected_input_sf_distances) + inf_mask_input_sf = tf.math.is_inf(input_sf_distances) + assert tf.reduce_all(tf.equal(inf_mask_expected_input_sf, inf_mask_input_sf)) + assert tf.reduce_all( + tf.abs(tf.where(inf_mask_input_sf, 0.0, input_sf_distances) - tf.where(inf_mask_expected_input_sf, 0.0, expected_input_sf_distances) + ) < 1e-5) + + # test the kneighbors method + input_sf_distances, sf_indices, nuns, _, __ = kleor.kneighbors(inputs, targets) + + expected_nuns = tf.constant([ + [[2., 3.]], + [[1., 2.]], + [[4., 5.]]], dtype=tf.float32) + assert tf.reduce_all(tf.equal(nuns, expected_nuns)) + + assert input_sf_distances.shape == (3, 1) # (n, k) + assert sf_indices.shape == (3, 1, 2) # (n, k, 2) + + expected_indices = tf.constant([[[-1, -1]],[[0, 1]],[[-1, -1]]], dtype=tf.int32) + assert tf.reduce_all(tf.equal(sf_indices, expected_indices)) + + expected_distances = tf.constant([[kleor.fill_value], [np.sqrt(2*0.5**2)], [kleor.fill_value]], dtype=tf.float32) + + # create masks for inf values + inf_mask_input = tf.math.is_inf(input_sf_distances) + inf_mask_expected = tf.math.is_inf(expected_distances) + assert tf.reduce_all(tf.equal(inf_mask_input, inf_mask_expected)) + + # compare finite values + assert tf.reduce_all( + tf.abs(tf.where(inf_mask_input, 0.0, input_sf_distances) - tf.where(inf_mask_expected, 0.0, expected_distances) + ) < 1e-5) + + # test the find_examples + return_dict = kleor.find_examples(inputs, targets) + + indices = return_dict["indices"] + nuns = return_dict["nuns"] + distances = return_dict["distances"] + examples = return_dict["examples"] + + assert tf.reduce_all(tf.equal(nuns, expected_nuns)) + assert tf.reduce_all(tf.equal(expected_indices, indices)) + + # create masks for inf values + inf_mask_dist = tf.math.is_inf(distances) + assert tf.reduce_all(tf.equal(inf_mask_dist, inf_mask_expected)) + assert tf.reduce_all( + tf.abs(tf.where(inf_mask_dist, 0.0, distances) - tf.where(inf_mask_expected, 0.0, expected_distances) + ) < 1e-5) + + expected_examples = tf.constant([ + [[1.5, 2.5], [np.inf, np.inf]], + [[2.5, 3.5], [2., 3.]], + [[4.5, 5.5], [np.inf, np.inf]]], dtype=tf.float32) + + # mask for inf values + inf_mask_examples = tf.math.is_inf(examples) + inf_mask_expected_examples = tf.math.is_inf(expected_examples) + assert tf.reduce_all(tf.equal(inf_mask_examples, inf_mask_expected_examples)) + assert tf.reduce_all( + tf.abs(tf.where(inf_mask_examples, 0.0, examples) - tf.where(inf_mask_expected_examples, 0.0, expected_examples) + ) < 1e-5) diff --git a/tests/example_based/test_knn.py b/tests/example_based/test_knn.py new file mode 100644 index 00000000..4a9df427 --- /dev/null +++ b/tests/example_based/test_knn.py @@ -0,0 +1,509 @@ +""" +Test the different search methods. +""" +import pytest +import numpy as np +import tensorflow as tf + +from ..utils import almost_equal + +from xplique.example_based.search_methods import BaseKNN, KNN, FilterKNN, ORDER + +def get_setup(input_shape, nb_samples=10, nb_labels=10): + """ + Generate data and model for SimilarExamples + """ + # Data generation + x_train = tf.stack( + [i * tf.ones(input_shape, tf.float32) for i in range(nb_samples)] + ) + x_test = x_train[1:-1] + y_train = tf.range(len(x_train), dtype=tf.float32) % nb_labels + + return x_train, x_test, y_train + +class MockKNN(BaseKNN): + """ + Mock KNN class for testing the find_examples method + """ + def kneighbors(self, inputs, targets): + """ + Define a mock kneighbors method for testing the find_examples method of + the base class. + """ + best_distances = tf.random.normal((inputs.shape[0], self.k), dtype=tf.float32) + best_indices= tf.random.uniform((inputs.shape[0], self.k, 2), maxval=self.k, dtype=tf.int32) + return best_distances, best_indices + +def same_target_filter(inputs, cases, targets, cases_targets): + """ + Filter function that returns a boolean mask with true when point-wise inputs and cases + have the same target. + """ + # get the labels predicted by the model + # (n, ) + predicted_labels = tf.argmax(targets, axis=-1) + + # for each input, if the target label is the same as the predicted label + # the mask as a True value and False otherwise + label_targets = tf.argmax(cases_targets, axis=-1) # (bs,) + mask = tf.equal(tf.expand_dims(predicted_labels, axis=1), label_targets) #(n, bs) + return mask + +def test_base_init(): + """ + Test the initialization of the base KNN class (not the super). + Check if it raises the relevant errors when the input is invalid. + """ + base_knn = MockKNN( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + k=2, + search_returns='distances', + ) + assert base_knn.order == ORDER.ASCENDING + assert base_knn.fill_value == np.inf + + # Test with reverse order + order = ORDER.DESCENDING + base_knn = MockKNN( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + k=2, + search_returns='distances', + order=order + ) + assert base_knn.order == order + assert base_knn.fill_value == -np.inf + + # Test with invalid order + with pytest.raises(AssertionError): + base_knn = MockKNN( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + k=2, + search_returns='distances', + order='invalid' + ) + +def test_base_find_examples(): + """ + Test the find_examples method of the base KNN class. + """ + returns = ["examples", "indices", "distances"] + mock_knn = MockKNN( + tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.], [10., 11., 12.]], dtype=tf.float32), + k = 2, + search_returns = returns, + ) + + inputs = tf.random.normal((5, 3), dtype=tf.float32) + return_dict = mock_knn.find_examples(inputs) + assert set(return_dict.keys()) == set(returns) + assert return_dict["examples"].shape == (5, 2, 3) + assert return_dict["indices"].shape == (5, 2, 2) + assert return_dict["distances"].shape == (5, 2) + + returns = ["examples", "include_inputs"] + mock_knn = MockKNN( + tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.], [10., 11., 12.]], dtype=tf.float32), + k = 2, + search_returns = returns, + ) + return_dict = mock_knn.find_examples(inputs) + assert return_dict["examples"].shape == (5, 3, 3) + + mock_knn = MockKNN( + tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.], [10., 11., 12.]], dtype=tf.float32), + k = 2, + ) + return_dict = mock_knn.find_examples(inputs) + assert return_dict["examples"].shape == (5, 2, 3) + +def test_knn_init(): + """ + Test the initialization of the KNN class which are not linked to the super class. + """ + cases_dataset = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.], [10., 11., 12.]], dtype=tf.float32) + x1 = tf.random.normal((1, 3), dtype=tf.float32) + x2 = tf.random.normal((3, 3), dtype=tf.float32) + + # Test with distances that are compatible with tf.norm + distances = ["euclidean", 1, 2, np.inf, 5] + for distance in distances: + knn = KNN( + cases_dataset, + k=2, + search_returns='distances', + distance=distance, + ) + assert tf.reduce_all(tf.equal(knn.distance_fn(x1, x2), tf.norm(x1 - x2, ord=distance, axis=-1))) + + # Test with a custom distance function + def custom_distance(x1, x2): + return tf.reduce_sum(tf.abs(x1 - x2), axis=-1) + knn = KNN( + cases_dataset, + k=2, + search_returns='distances', + distance=custom_distance, + ) + assert tf.reduce_all(tf.equal(knn.distance_fn(x1, x2), custom_distance(x1, x2))) + + # Test with invalid distance + invalid_distances = [None, "invalid", 0.5] + for distance in invalid_distances: + with pytest.raises(AttributeError): + knn = KNN( + cases_dataset, + k=2, + search_returns='distances', + distance=distance, + ) + +def test_knn_compute_distances(): + """ + Test the private method _compute_distances_fn of the KNN class. + """ + # Test with input and cases being 1D + knn = KNN( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + k=2, + distance='euclidean', + order=ORDER.ASCENDING + ) + x1 = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=tf.float32) + x2 = tf.constant([[7.0, 8.0], [9.0, 10.0]], dtype=tf.float32) + + expected_distance = tf.constant( + [ + [np.sqrt(72), np.sqrt(128)], + [np.sqrt(32), np.sqrt(72)], + [np.sqrt(8), np.sqrt(32)] + ], dtype=tf.float32 + ) + + distances = knn._crossed_distances_fn(x1, x2) + assert distances.shape == (x1.shape[0], x2.shape[0]) + assert almost_equal(distances, expected_distance, epsilon=1e-5) + + # Test with higher dimensions + data = np.array([ + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + [[10, 11, 12], [13, 14, 15], [16, 17, 18]] + ]) + + knn = KNN( + data, + k=2, + distance="euclidean", + order=ORDER.ASCENDING + ) + + x1 = tf.constant( + [ + [[1, 2, 3],[4, 5, 6],[7, 8, 9]], + [[10, 11, 12], [13, 14, 15], [16, 17, 18]], + [[19, 20, 21], [22, 23, 24], [25, 26, 27]] + ], dtype=tf.float32 + ) + + x2 = tf.constant( + [ + [[28, 29, 30], [31, 32, 33], [34, 35, 36]], + [[37, 38, 39], [40, 41, 42], [43, 44, 45]], + ], dtype=tf.float32 + ) + + expected_distance = tf.constant( + [[np.sqrt(9)*27, np.sqrt(9)*36], + [np.sqrt(9)*18, np.sqrt(9)*27], + [np.sqrt(9)*9, np.sqrt(9)*18]], dtype=tf.float32) + + distances = knn._crossed_distances_fn(x1, x2) + assert distances.shape == (x1.shape[0], x2.shape[0]) + assert almost_equal(distances, expected_distance) + + +def test_knn_kneighbors(): + """ + Test the kneighbors method of the KNN class. + """ + # Test with input and cases being 1D + cases = tf.constant([[1.], [2.], [3.], [4.], [5.]], dtype=tf.float32) + inputs = tf.constant([[1.5], [2.5], [4.5]], dtype=tf.float32) + knn = KNN( + cases, + k=2, + batch_size=2, + distance="euclidean", + ) + + distances, indices = knn.kneighbors(inputs) + assert distances.shape == (3, 2) + assert indices.shape == (3, 2, 2) + assert almost_equal(distances, tf.constant([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]], dtype=tf.float32)) + assert almost_equal(indices, tf.constant([[[0, 0], [0, 1]],[[0, 1], [1, 0]],[[1, 1], [2, 0]]], dtype=tf.int32)) + + # Test with reverse order + knn = KNN( + cases, + k=2, + batch_size=2, + distance="euclidean", + order=ORDER.DESCENDING + ) + + distances, indices = knn.kneighbors(inputs) + assert distances.shape == (3, 2) + assert indices.shape == (3, 2, 2) + assert almost_equal(distances, tf.constant([[3.5, 2.5], [2.5, 1.5], [3.5, 2.5]], dtype=tf.float32)) + assert almost_equal(indices, tf.constant([[[2, 0], [1, 1]],[[2, 0], [0, 0]],[[0, 0], [0, 1]]], dtype=tf.int32)) + + # Test with input and cases being 2D + cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32) + inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32) + knn = KNN( + cases, + k=2, + batch_size=2, + distance="euclidean", + ) + + distances, indices = knn.kneighbors(inputs) + assert distances.shape == (3, 2) + assert indices.shape == (3, 2, 2) + assert almost_equal(distances, tf.constant([[np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(0.5)]], dtype=tf.float32)) + assert almost_equal(indices, tf.constant([[[0, 0], [0, 1]],[[0, 1], [1, 0]],[[1, 1], [2, 0]]], dtype=tf.int32)) + + # Test with reverse order + knn = KNN( + cases, + k=2, + batch_size=2, + distance="euclidean", + order=ORDER.DESCENDING + ) + + distances, indices = knn.kneighbors(inputs) + assert distances.shape == (3, 2) + assert indices.shape == (3, 2, 2) + expected_distances = tf.constant([[np.sqrt(2*3.5**2), np.sqrt(2*2.5**2)], [np.sqrt(2*2.5**2), np.sqrt(2*1.5**2)], [np.sqrt(2*3.5**2), np.sqrt(2*2.5**2)]], dtype=tf.float32) + assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5) + assert almost_equal(indices, tf.constant([[[2, 0], [1, 1]],[[2, 0], [0, 0]],[[0, 0], [0, 1]]], dtype=tf.int32)) + +def test_filter_knn_compute_distances(): + """ + Test the private method _compute_distances_fn of the FilterKNN class. + """ + # Test in Low dimension + knn = FilterKNN( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + k=2, + distance='euclidean', + order=ORDER.ASCENDING + ) + x1 = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=tf.float32) + x2 = tf.constant([[7.0, 8.0], [9.0, 10.0]], dtype=tf.float32) + expected_distance = tf.constant( + [ + [np.sqrt(72), np.sqrt(128)], + [np.sqrt(32), np.sqrt(72)], + [np.sqrt(8), np.sqrt(32)] + ], dtype=tf.float32 + ) + mask = tf.ones((x1.shape[0], x2.shape[0]), dtype=tf.bool) + distances = knn._crossed_distances_fn(x1, x2, mask) + assert distances.shape == (x1.shape[0], x2.shape[0]) + assert almost_equal(distances, expected_distance, epsilon=1e-5) + + mask = tf.constant([[True, False], [False, True], [True, True]], dtype=tf.bool) + expected_distance = tf.constant([[np.sqrt(72), np.inf], [np.inf, np.sqrt(72)], [np.sqrt(8), np.sqrt(32)]], dtype=tf.float32) + distances = knn._crossed_distances_fn(x1, x2, mask) + assert np.allclose(distances, expected_distance, equal_nan=True) + assert np.array_equal(distances == np.inf, expected_distance == np.inf) + assert np.array_equal(distances == -np.inf, expected_distance == -np.inf) + + # Test with higher dimensions + data = np.array([ + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + [[10, 11, 12], [13, 14, 15], [16, 17, 18]] + ]) + + knn = FilterKNN( + data, + k=2, + distance="euclidean", + order=ORDER.ASCENDING + ) + + x1 = tf.constant( + [ + [[1, 2, 3],[4, 5, 6],[7, 8, 9]], + [[10, 11, 12], [13, 14, 15], [16, 17, 18]], + [[19, 20, 21], [22, 23, 24], [25, 26, 27]] + ], dtype=tf.float32 + ) + + x2 = tf.constant( + [ + [[28, 29, 30], [31, 32, 33], [34, 35, 36]], + [[37, 38, 39], [40, 41, 42], [43, 44, 45]], + ], dtype=tf.float32 + ) + + expected_distance = tf.constant( + [[np.sqrt(9)*27, np.sqrt(9)*36], + [np.sqrt(9)*18, np.sqrt(9)*27], + [np.sqrt(9)*9, np.sqrt(9)*18]], dtype=tf.float32) + + mask = tf.ones((x1.shape[0], x2.shape[0]), dtype=tf.bool) + distances = knn._crossed_distances_fn(x1, x2, mask) + assert distances.shape == (x1.shape[0], x2.shape[0]) + assert almost_equal(distances, expected_distance) + + mask = tf.constant([[True, False], [False, True], [True, True]], dtype=tf.bool) + expected_distance = tf.constant([[np.sqrt(9)*27, np.inf], [np.inf, np.sqrt(9)*27], [np.sqrt(9)*9, np.sqrt(9)*18]], dtype=tf.float32) + distances = knn._crossed_distances_fn(x1, x2, mask) + assert distances.shape == (x1.shape[0], x2.shape[0]) + assert np.allclose(distances, expected_distance, equal_nan=True) + assert np.array_equal(distances == np.inf, expected_distance == np.inf) + assert np.array_equal(distances == -np.inf, expected_distance == -np.inf) + +def test_filter_knn_kneighbors(): + """ + """ + # Test with input and cases being 1D + cases = tf.constant([[1.], [2.], [3.], [4.], [5.]], dtype=tf.float32) + inputs = tf.constant([[1.5], [2.5], [4.5]], dtype=tf.float32) + ## default filter and default order + knn = FilterKNN( + cases, + k=2, + batch_size=2, + distance="euclidean", + ) + + distances, indices = knn.kneighbors(inputs) + assert distances.shape == (3, 2) + assert indices.shape == (3, 2, 2) + assert almost_equal(distances, tf.constant([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]], dtype=tf.float32)) + assert almost_equal(indices, tf.constant([[[0, 0], [0, 1]],[[0, 1], [1, 0]],[[1, 1], [2, 0]]], dtype=tf.int32)) + + cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32) + targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32) + + ## add a filter that is not the default + knn = FilterKNN( + cases, + targets_dataset=cases_targets, + k=2, + batch_size=2, + distance="euclidean", + filter_fn=same_target_filter + ) + distances, indices = knn.kneighbors(inputs, targets) + assert distances.shape == (3, 2) + assert indices.shape == (3, 2, 2) + assert almost_equal(distances, tf.constant([[0.5, 2.5], [0.5, 0.5], [0.5, 1.5]], dtype=tf.float32)) + assert almost_equal(indices, tf.constant([[[0, 0], [1, 1]],[[0, 1], [1, 0]],[[2, 0], [1, 0]]], dtype=tf.int32)) + + ## test with reverse order + knn = FilterKNN( + cases, + k=2, + batch_size=2, + distance="euclidean", + order=ORDER.DESCENDING + ) + + distances, indices = knn.kneighbors(inputs, targets) + assert distances.shape == (3, 2) + assert indices.shape == (3, 2, 2) + expected_distances = tf.constant([[3.5, 2.5], [2.5, 1.5], [3.5, 2.5]], dtype=tf.float32) + assert almost_equal(distances, expected_distances) + assert almost_equal(indices, tf.constant([[[2, 0], [1, 1]],[[2, 0], [0, 0]],[[0, 0], [0, 1]]], dtype=tf.int32)) + + ## add a filter that is not the default one and reverse order + knn = FilterKNN( + cases, + targets_dataset=cases_targets, + k=2, + batch_size=2, + distance="euclidean", + order=ORDER.DESCENDING, + filter_fn=same_target_filter + ) + + distances, indices = knn.kneighbors(inputs, targets) + assert distances.shape == (3, 2) + assert indices.shape == (3, 2, 2) + assert almost_equal(distances, tf.constant([[2.5, 0.5], [2.5, 0.5], [2.5, 1.5]], dtype=tf.float32)) + assert almost_equal(indices, tf.constant([[[1, 1], [0, 0]],[[2, 0], [0, 1]],[[0, 1], [1, 0]]], dtype=tf.int32)) + + # Test with input and cases being 2D + cases = tf.constant([[1., 2.], [2., 3.], [3., 4.], [4., 5.], [5., 6.]], dtype=tf.float32) + inputs = tf.constant([[1.5, 2.5], [2.5, 3.5], [4.5, 5.5]], dtype=tf.float32) + ## default filter and default order + knn = FilterKNN( + cases, + k=2, + batch_size=2, + distance="euclidean", + ) + + distances, indices = knn.kneighbors(inputs) + assert distances.shape == (3, 2) + assert indices.shape == (3, 2, 2) + assert almost_equal(distances, tf.constant([[np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(0.5)]], dtype=tf.float32)) + assert almost_equal(indices, tf.constant([[[0, 0], [0, 1]],[[0, 1], [1, 0]],[[1, 1], [2, 0]]], dtype=tf.int32)) + + cases_targets = tf.constant([[0, 1], [1, 0], [1, 0], [0, 1], [1, 0]], dtype=tf.float32) + targets = tf.constant([[0, 1], [1, 0], [1, 0]], dtype=tf.float32) + ## add a filter that is not the default + knn = FilterKNN( + cases, + targets_dataset=cases_targets, + k=2, + batch_size=2, + distance="euclidean", + filter_fn=same_target_filter + ) + + distances, indices = knn.kneighbors(inputs, targets) + assert distances.shape == (3, 2) + assert indices.shape == (3, 2, 2) + expected_distances = tf.constant([[np.sqrt(0.5), np.sqrt(2*2.5**2)], [np.sqrt(0.5), np.sqrt(0.5)], [np.sqrt(0.5), np.sqrt(2*1.5**2)],], dtype=tf.float32) + assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5) + assert almost_equal(indices, tf.constant([[[0, 0], [1, 1]],[[0, 1], [1, 0]],[[2, 0], [1, 0]]], dtype=tf.int32)) + + ## test with reverse order and default filter + knn = FilterKNN( + cases, + k=2, + batch_size=2, + distance="euclidean", + order=ORDER.DESCENDING + ) + + distances, indices = knn.kneighbors(inputs) + assert distances.shape == (3, 2) + assert indices.shape == (3, 2, 2) + expected_distances = tf.constant([[np.sqrt(2*3.5**2), np.sqrt(2*2.5**2)], [np.sqrt(2*2.5**2), np.sqrt(2*1.5**2)], [np.sqrt(2*3.5**2), np.sqrt(2*2.5**2)]], dtype=tf.float32) + assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5) + assert almost_equal(indices, tf.constant([[[2, 0], [1, 1]],[[2, 0], [0, 0]],[[0, 0], [0, 1]]], dtype=tf.int32)) + + ## add a filter that is not the default one and reverse order + knn = FilterKNN( + cases, + targets_dataset=cases_targets, + k=2, + batch_size=2, + distance="euclidean", + order=ORDER.DESCENDING, + filter_fn=same_target_filter + ) + + distances, indices = knn.kneighbors(inputs, targets) + assert distances.shape == (3, 2) + assert indices.shape == (3, 2, 2) + expected_distances = tf.constant([[np.sqrt(2*2.5**2), np.sqrt(0.5)], [np.sqrt(2*2.5**2), np.sqrt(0.5)], [np.sqrt(2*2.5**2), np.sqrt(2*1.5**2)]], dtype=tf.float32) + assert tf.reduce_all(tf.abs(distances - expected_distances) < 1e-5) + assert almost_equal(indices, tf.constant([[[1, 1], [0, 0]],[[2, 0], [0, 1]],[[0, 1], [1, 0]]], dtype=tf.int32)) diff --git a/tests/example_based/test_projections.py b/tests/example_based/test_projections.py new file mode 100644 index 00000000..9da17c68 --- /dev/null +++ b/tests/example_based/test_projections.py @@ -0,0 +1,226 @@ +import numpy as np +import tensorflow as tf +from tensorflow.keras.layers import ( + Dense, + Conv2D, + Conv1D, + Activation, + Dropout, + Flatten, + MaxPooling2D, + Input, +) + +from xplique.commons.operators import predictions_operator +from xplique.attributions import Saliency +from xplique.example_based.projections import Projection, AttributionProjection, LatentSpaceProjection, HadamardProjection +from xplique.example_based.projections.commons import model_splitting, target_free_classification_operator + +from ..utils import almost_equal + + +def get_setup(input_shape, nb_samples=10, nb_labels=2): + """ + Generate data and model for SimilarExamples + """ + # Data generation + x_train = tf.stack( + [i * tf.ones(input_shape, tf.float32) for i in range(nb_samples)] + ) + x_test = x_train[1:-1] + y_train = tf.one_hot(tf.range(len(x_train)) % nb_labels, nb_labels) + + return x_train, x_test, y_train + + +def _generate_model(input_shape=(32, 32, 3), output_shape=2): + model = tf.keras.Sequential() + model.add(Input(shape=input_shape)) + model.add(Conv2D(4, kernel_size=(2, 2), activation="relu", name="conv2d_1")) + model.add(Conv2D(4, kernel_size=(2, 2), activation="relu", name="conv2d_2")) + model.add(MaxPooling2D(pool_size=(2, 2))) + model.add(Dropout(0.25)) + model.add(Flatten()) + model.add(Dense(output_shape, name="dense")) + model.add(Activation("softmax", name="softmax")) + model.compile(loss="categorical_crossentropy", optimizer="sgd") + + return model + + +def test_simple_projection_mapping(): + """ + Test if a simple projection can be mapped. + """ + # Setup + input_shape = (7, 7, 3) + nb_samples = 10 + nb_labels = 2 + x_train, _, y_train = get_setup(input_shape, nb_samples=nb_samples, nb_labels=nb_labels) + + weights = tf.random.uniform((input_shape[0], input_shape[1], 1), minval=0, maxval=1) + + space_projection = lambda x, y=None: tf.nn.max_pool2d(x, ksize=3, strides=1, padding="SAME") + + projection = Projection(get_weights=weights, space_projection=space_projection, mappable=True) + + # Generate tf.data.Dataset from numpy + train_dataset = tf.data.Dataset.from_tensor_slices(x_train).batch(3) + targets_dataset = tf.data.Dataset.from_tensor_slices(y_train).batch(3) + + # Apply the projection by mapping the dataset + projected_train_dataset = projection.project_dataset(train_dataset, targets_dataset) + + # Apply the projection by iterating over the dataset + projection.mappable = False + projected_train_dataset = projection.project_dataset(train_dataset, targets_dataset) + + +def test_model_splitting(): + """ + Test if projected samples have the expected values + """ + x_train = np.reshape(np.arange(0, 100), (10, 10)) + + model = tf.keras.Sequential() + model.add(Input(shape=(10,))) + model.add(Dense(10, name="dense1")) + model.add(Dense(1, name="dense2")) + model.compile(loss="categorical_crossentropy", optimizer="sgd") + + model.get_layer("dense1").set_weights([np.eye(10) * np.sign(np.arange(-4.5, 5.5)), np.zeros(10)]) + model.get_layer("dense2").set_weights([np.ones((10, 1)), np.zeros(1)]) + + # Split the model + _, _ = model_splitting(model, latent_layer=-1) + _, _ = model_splitting(model, latent_layer="dense2") + features_extractor, predictor = model_splitting(model, latent_layer="dense1") + + assert almost_equal(predictor(features_extractor(x_train)).numpy(), model(x_train)) + + +def test_latent_space_projection_mapping(): + """ + Test if the latent space projection can be mapped. + """ + # Setup + input_shape = (7, 7, 3) + nb_samples = 10 + nb_labels = 2 + x_train, _, y_train = get_setup(input_shape, nb_samples=nb_samples, nb_labels=nb_labels) + + model = _generate_model(input_shape=input_shape, output_shape=nb_labels) + + projection = LatentSpaceProjection(model, "last_conv") + + # Generate tf.data.Dataset from numpy + train_dataset = tf.data.Dataset.from_tensor_slices(x_train).batch(3) + targets_dataset = tf.data.Dataset.from_tensor_slices(y_train).batch(3) + + # Apply the projection by mapping the dataset + projected_train_dataset = projection.project_dataset(train_dataset, targets_dataset) + projected_train_dataset = projection._map_project_dataset(train_dataset, targets_dataset) + projected_train_dataset = projection._loop_project_dataset(train_dataset, targets_dataset) + + +def test_hadamard_projection_mapping(): + """ + Test if the hadamard projection can be mapped. + """ + # Setup + input_shape = (7, 7, 3) + nb_samples = 10 + nb_labels = 2 + x_train, _, y_train = get_setup(input_shape, nb_samples=nb_samples, nb_labels=nb_labels) + + model = _generate_model(input_shape=input_shape, output_shape=nb_labels) + + projection = HadamardProjection(model, "last_conv") + + # Generate tf.data.Dataset from numpy + train_dataset = tf.data.Dataset.from_tensor_slices(x_train).batch(3) + targets_dataset = tf.data.Dataset.from_tensor_slices(y_train).batch(3) + + # Apply the projection by mapping the dataset + projected_train_dataset = projection.project_dataset(train_dataset, targets_dataset) + projected_train_dataset = projection._map_project_dataset(train_dataset, targets_dataset) + projected_train_dataset = projection._loop_project_dataset(train_dataset, targets_dataset) + + +def test_attribution_projection_mapping(): + """ + Test if the attribution projection can be mapped. + """ + # Setup + input_shape = (7, 7, 3) + nb_samples = 10 + nb_labels = 2 + x_train, _, y_train = get_setup(input_shape, nb_samples=nb_samples, nb_labels=nb_labels) + + model = _generate_model(input_shape=input_shape, output_shape=nb_labels) + + projection = AttributionProjection(model, attribution_method=Saliency, latent_layer="last_conv") + + # Generate tf.data.Dataset from numpy + train_dataset = tf.data.Dataset.from_tensor_slices(x_train).batch(3) + targets_dataset = tf.data.Dataset.from_tensor_slices(y_train).batch(3) + + # Apply the projection by mapping the dataset + projected_train_dataset = projection.project_dataset(train_dataset, targets_dataset) + + +def test_from_splitted_model(): + """ + Test the other way of constructing the projection. + """ + latent_width = 8 + nb_samples = 15 + input_features = 10 + output_features = 3 + x_train = np.reshape(np.arange(0, nb_samples * input_features), (nb_samples, input_features)) + tf_x_train = tf.convert_to_tensor(x_train, dtype=tf.float32) + + train_dataset = tf.data.Dataset.from_tensor_slices(x_train).batch(3) + + model1 = tf.keras.Sequential() + model1.add(Input(shape=(input_features,))) + model1.add(Dense(latent_width, name="dense1")) + model1.compile(loss="mean_absolute_error", optimizer="sgd") + + model2 = tf.keras.Sequential() + model2.add(Input(shape=(latent_width,))) + model2.add(Dense(output_features, name="dense2")) + model2.compile(loss="categorical_crossentropy", optimizer="sgd") + + assert model1(x_train).shape == (nb_samples, latent_width) + assert model2(model1(x_train)).shape == (nb_samples, output_features) + + # test LatentSpaceProjection from splitted model + projection = LatentSpaceProjection(model=model1, latent_layer=None, mappable=True) + projected_train_dataset = projection.project_dataset(train_dataset) + + # test HadamardProjection from splitted model + projection = HadamardProjection(features_extractor=model1, predictor=model2, mappable=True) + projected_train_dataset = projection.project_dataset(train_dataset) + + +def test_target_free_classification_operator(): + """ + Test if the target free classification operator works as expected. + """ + nb_classes = 5 + x_train = np.reshape(np.arange(0, 100), (10, 10)) + + model = tf.keras.Sequential() + model.add(Input(shape=(10,))) + model.add(Dense(10, name="dense1")) + model.add(Dense(nb_classes, name="dense2")) + model.compile(loss="categorical_crossentropy", optimizer="sgd") + + preds = model(x_train) + targets = tf.one_hot(tf.argmax(preds, axis=1), nb_classes) + + scores1 = target_free_classification_operator(model, x_train) + scores2 = predictions_operator(model, x_train, targets) + + assert almost_equal(scores1, scores2) diff --git a/tests/example_based/test_prototypes.py b/tests/example_based/test_prototypes.py new file mode 100644 index 00000000..b163c97b --- /dev/null +++ b/tests/example_based/test_prototypes.py @@ -0,0 +1,242 @@ +""" +Test Prototypes +""" +import os +import sys + +sys.path.append(os.getcwd()) + +import tensorflow as tf + +from xplique.example_based import Prototypes, ProtoGreedy, ProtoDash, MMDCritic +from xplique.example_based.projections import Projection, LatentSpaceProjection + +from tests.utils import almost_equal, get_gaussian_data, generate_model + + +def test_prototypes_global_explanations_basic(): + """ + Test prototypes shapes and uniqueness. + """ + # Setup + k = 2 + nb_prototypes = 5 + nb_classes = 3 + gamma = 0.026 + batch_size = 8 + + x_train, y_train = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=20, n_dims=3) + x_test, y_test = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=6, n_dims=3) + + for method_class in [ProtoGreedy, ProtoDash, MMDCritic]: + # compute general prototypes + method = method_class( + cases_dataset=x_train, + labels_dataset=y_train, + nb_local_prototypes=k, + batch_size=batch_size, + case_returns=["examples", "distances", "labels", "indices"], + distance="euclidean", + nb_global_prototypes=nb_prototypes, + gamma=gamma, + ) + + # ====================== + # Test global prototypes + + # extract prototypes + prototypes = method.prototypes + prototypes_indices = method.prototypes_indices + prototypes_labels = method.prototypes_labels + prototypes_weights = method.prototypes_weights + + # check shapes + assert prototypes.shape == (nb_prototypes,) + x_train.shape[1:] + assert prototypes_indices.shape == (nb_prototypes, 2) + assert prototypes_labels.shape == (nb_prototypes,) + assert prototypes_weights.shape == (nb_prototypes,) + + # check uniqueness + flatten_indices = prototypes_indices[:, 0] * batch_size + prototypes_indices[:, 1] + assert len(tf.unique(flatten_indices)[0]) == nb_prototypes + + # for each prototype + for i in range(nb_prototypes): + # check prototypes are in the dataset and correspond to the index + assert tf.reduce_all(tf.equal(prototypes[i], x_train[flatten_indices[i]])) + + # same for labels + assert tf.reduce_all(tf.equal(prototypes_labels[i], y_train[flatten_indices[i]])) + + # check indices are in the dataset + assert flatten_indices[i] >= 0 and flatten_indices[i] < x_train.shape[0] + + # ===================== + # Test local prototypes + + # compute local explanations + outputs = method.explain(x_test) + examples = outputs["examples"] + distances = outputs["distances"] + labels = outputs["labels"] + indices = outputs["indices"] + + # check shapes + assert examples.shape == (x_test.shape[0], k) + x_train.shape[1:] + assert distances.shape == (x_test.shape[0], k) + assert labels.shape == (x_test.shape[0], k) + assert indices.shape == (x_test.shape[0], k, 2) + + assert tf.reduce_all(indices[:, :, 0] >= 0) + assert tf.reduce_all(indices[:, :, 0] < (1 + x_train.shape[0] // batch_size)) + assert tf.reduce_all(indices[:, :, 1] >= 0) + assert tf.reduce_all(indices[:, :, 1] < batch_size) + flatten_indices = indices[:, :, 0] * batch_size + indices[:, :, 1] + + # for each sample + for i in range(x_test.shape[0]): + # check first closest prototype label is the same as the sample label + assert tf.reduce_all(tf.equal(labels[i, 0], y_test[i])) + + for j in range(k): + # check prototypes are in the dataset and correspond to the index + assert tf.reduce_all(tf.equal(examples[i, j], x_train[flatten_indices[i, j]])) + + # same for labels + assert tf.reduce_all(tf.equal(labels[i, j], y_train[flatten_indices[i, j]])) + + +def test_prototypes_global_sanity_check(): + """ + Test prototypes global explanations sanity checks. + + Check: For n separated gaussians, + for n requested prototypes, + there should be 1 prototype per gaussian. + """ + # Setup + k = 2 + nb_prototypes = 3 + gamma = 0.026 + + x_train, y_train = get_gaussian_data(nb_classes=nb_prototypes, nb_samples_class=5, n_dims=3) + + for method_class in [MMDCritic, ProtoDash, ProtoGreedy]: + # compute general prototypes + method = method_class( + cases_dataset=x_train, + labels_dataset=y_train, + nb_local_prototypes=k, + batch_size=8, + nb_global_prototypes=nb_prototypes, + gamma=gamma, + ) + # extract prototypes + prototypes_labels = method.get_global_prototypes()["prototypes_labels"] + + # check 1 + assert len(tf.unique(prototypes_labels)[0]) == nb_prototypes + + +def test_prototypes_with_projection(): + """ + Test prototypes shapes and uniqueness. + """ + # Setup + k = 2 + nb_prototypes = 10 + nb_classes = 2 + gamma = 0.026 + batch_size = 8 + + x_train, y_train = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=20, n_dims=3) + x_test, y_test = get_gaussian_data(nb_classes=nb_classes, nb_samples_class=6, n_dims=3) + + # [10, 10, 10] -> [15, 15] + # [20, 20, 20] -> [30, 30] + # [30, 30, 30] -> [45, 45] + weights = tf.constant([[1.0, 0.0], + [0.5, 0.5], + [0.0, 1.0],], + dtype=tf.float32) + + weighted_projection = Projection( + space_projection=lambda inputs, targets=None: inputs @ weights + ) + + for method_class in [ProtoGreedy, ProtoDash, MMDCritic]: + # compute general prototypes + method = method_class( + cases_dataset=x_train, + labels_dataset=y_train, + nb_local_prototypes=k, + projection=weighted_projection, + batch_size=batch_size, + case_returns=["examples", "distances", "labels", "indices"], + nb_global_prototypes=nb_prototypes, + gamma=gamma, + ) + + # ====================== + # Test global prototypes + + # extract prototypes + prototypes = method.prototypes + prototypes_indices = method.prototypes_indices + prototypes_labels = method.prototypes_labels + prototypes_weights = method.prototypes_weights + + # check shapes + assert prototypes.shape == (nb_prototypes,) + x_train.shape[1:] + assert prototypes_indices.shape == (nb_prototypes, 2) + assert prototypes_labels.shape == (nb_prototypes,) + assert prototypes_weights.shape == (nb_prototypes,) + + # check uniqueness + flatten_indices = prototypes_indices[:, 0] * batch_size + prototypes_indices[:, 1] + assert len(tf.unique(flatten_indices)[0]) == nb_prototypes + + # for each prototype + for i in range(nb_prototypes): + # check prototypes are in the dataset and correspond to the index + assert tf.reduce_all(tf.equal(prototypes[i], x_train[flatten_indices[i]])) + + # same for labels + assert tf.reduce_all(tf.equal(prototypes_labels[i], y_train[flatten_indices[i]])) + + # check indices are in the dataset + assert flatten_indices[i] >= 0 and flatten_indices[i] < x_train.shape[0] + + # ===================== + # Test local prototypes + + # compute local explanations + outputs = method.explain(x_test) + examples = outputs["examples"] + distances = outputs["distances"] + labels = outputs["labels"] + indices = outputs["indices"] + + # check shapes + assert examples.shape == (x_test.shape[0], k) + x_train.shape[1:] + assert distances.shape == (x_test.shape[0], k) + assert labels.shape == (x_test.shape[0], k) + assert indices.shape == (x_test.shape[0], k, 2) + + assert tf.reduce_all(indices[:, :, 0] >= 0) + assert tf.reduce_all(indices[:, :, 0] < (1 + x_train.shape[0] // batch_size)) + assert tf.reduce_all(indices[:, :, 1] >= 0) + assert tf.reduce_all(indices[:, :, 1] < batch_size) + flatten_indices = indices[:, :, 0] * batch_size + indices[:, :, 1] + + # for each sample + for i in range(x_test.shape[0]): + # check first closest prototype label is the same as the sample label + assert tf.reduce_all(tf.equal(labels[i], y_test[i])) + + for j in range(k): + # check prototypes are in the dataset and correspond to the index + assert tf.reduce_all(tf.equal(examples[i, j], x_train[flatten_indices[i, j]])) + + # same for labels + assert tf.reduce_all(tf.equal(labels[i, j], y_train[flatten_indices[i, j]])) diff --git a/tests/example_based/test_similar_examples.py b/tests/example_based/test_similar_examples.py new file mode 100644 index 00000000..5d990fad --- /dev/null +++ b/tests/example_based/test_similar_examples.py @@ -0,0 +1,189 @@ +""" +Test Cole +""" +import os +import sys + +sys.path.append(os.getcwd()) + +import numpy as np +import tensorflow as tf + +from xplique.example_based import SimilarExamples +from xplique.example_based.projections import Projection + +from tests.utils import almost_equal + + +def get_setup(input_shape, nb_samples=10, nb_labels=10): + """ + Generate data and model for SimilarExamples + """ + # Data generation + x_train = tf.stack( + [i * tf.ones(input_shape, tf.float32) for i in range(nb_samples)] + ) + x_test = x_train[1:-1] + y_train = tf.range(len(x_train), dtype=tf.float32) % nb_labels + + return x_train, x_test, y_train + + + +def test_similar_examples_basic(): + """ + Test the SimilarExamples with an identity projection. + """ + # Setup + input_shape = (4, 4, 1) + k = 3 + x_train, x_test, _ = get_setup(input_shape) + + identity_projection = Projection( + space_projection=lambda inputs, targets=None: inputs + ) + + # Method initialization + method = SimilarExamples( + cases_dataset=x_train, + projection=identity_projection, + k=k, + batch_size=3, + distance="euclidean", + ) + + # Generate explanation + examples = method.explain(x_test)["examples"] + + # Verifications + # Shape should be (n, k, h, w, c) + assert examples.shape == (len(x_test), k) + input_shape + + for i in range(len(x_test)): + # test examples: + assert almost_equal(examples[i, 0], x_train[i + 1]) + assert almost_equal(examples[i, 1], x_train[i + 2]) or almost_equal( + examples[i, 1], x_train[i] + ) + assert almost_equal(examples[i, 2], x_train[i]) or almost_equal( + examples[i, 2], x_train[i + 2] + ) + + +def test_similar_examples_return_multiple_elements(): + """ + Test the returns attribute. + Test modifying k. + """ + # Setup + input_shape = (5, 5, 1) + k = 3 + x_train, x_test, y_train = get_setup(input_shape) + + nb_samples_test = len(x_test) + assert nb_samples_test + 2 == len(y_train) + + identity_projection = Projection( + space_projection=lambda inputs, targets=None: inputs + ) + + # Method initialization + method = SimilarExamples( + cases_dataset=x_train, + labels_dataset=y_train, + projection=identity_projection, + k=1, + batch_size=3, + distance="euclidean", + ) + + method.returns = "all" + method.k = k + + # Generate explanation + method_output = method.explain(x_test) + + assert isinstance(method_output, dict) + + examples = method_output["examples"] + distances = method_output["distances"] + labels = method_output["labels"] + + # test every outputs shape (with the include inputs) + assert examples.shape == (nb_samples_test, k + 1) + input_shape + # the inputs distance ae zero and indices do not exist + assert distances.shape == (nb_samples_test, k) + assert labels.shape == (nb_samples_test, k) + + for i in range(nb_samples_test): + # test examples: + assert almost_equal(examples[i, 0], x_test[i]) + assert almost_equal(examples[i, 1], x_train[i + 1]) + assert almost_equal(examples[i, 2], x_train[i + 2]) or almost_equal( + examples[i, 2], x_train[i] + ) + assert almost_equal(examples[i, 3], x_train[i]) or almost_equal( + examples[i, 3], x_train[i + 2] + ) + + # test distances + assert almost_equal(distances[i, 0], 0) + assert almost_equal(distances[i, 1], np.sqrt(np.prod(input_shape))) + assert almost_equal(distances[i, 2], np.sqrt(np.prod(input_shape))) + + # test labels + assert almost_equal(labels[i, 0], y_train[i + 1]) + assert almost_equal(labels[i, 1], y_train[i]) or almost_equal( + labels[i, 1], y_train[i + 2] + ) + assert almost_equal(labels[i, 2], y_train[i]) or almost_equal( + labels[i, 2], y_train[i + 2] + ) + + +def test_similar_examples_weighting(): + """ + Test the application of the projection weighting. + """ + # Setup + input_shape = (4, 4, 1) + nb_samples = 10 + k = 3 + x_train, x_test, y_train = get_setup(input_shape, nb_samples) + + # Define the weighing function + weights = np.zeros(x_train[0].shape) + weights[1] = np.ones(weights[1].shape) + + # create huge noise on non interesting features + noise = np.random.uniform(size=x_train.shape, low=-100, high=100) + x_train = np.float32(weights * np.array(x_train) + (1 - weights) * noise) + + weighting_function = Projection(get_weights=weights) + + method = SimilarExamples( + cases_dataset=x_train, + labels_dataset=np.array(y_train), + projection=weighting_function, + k=k, + batch_size=5, + distance="euclidean", + ) + + # Generate explanation + examples = method.explain(x_test)["examples"] + + # Verifications + # Shape should be (n, k, h, w, c) + nb_samples_test = x_test.shape[0] + assert examples.shape == (nb_samples_test, k) + input_shape + + for i in range(nb_samples_test): + # test examples: + assert almost_equal(examples[i, 0], x_train[i + 1]) + assert almost_equal(examples[i, 1], x_train[i + 2]) or almost_equal( + examples[i, 1], x_train[i] + ) + assert almost_equal(examples[i, 2], x_train[i]) or almost_equal( + examples[i, 2], x_train[i + 2] + ) diff --git a/tests/example_based/test_tf_dataset_operation.py b/tests/example_based/test_tf_dataset_operation.py new file mode 100644 index 00000000..a8c92ee5 --- /dev/null +++ b/tests/example_based/test_tf_dataset_operation.py @@ -0,0 +1,194 @@ +""" +Test operations on tf datasets +""" +import os +import sys + +sys.path.append(os.getcwd()) + +import pytest + +import numpy as np +import tensorflow as tf + +from xplique.example_based.datasets_operations.tf_dataset_operations import * +from xplique.example_based.datasets_operations.tf_dataset_operations import _almost_equal + + +def datasets_are_equal(dataset_1, dataset_2): + """ + Iterate over the datasets and compare the elements + """ + for elem_1, elem_2 in zip(dataset_1, dataset_2): + if not _almost_equal(elem_1, elem_2): + return False + return True + + +def test_are_dataset_first_elems_equal(): + """ + Verify that the function is able to compare the first element of datasets + """ + tf_dataset_up = tf.data.Dataset.from_tensor_slices( + tf.reshape(tf.range(90), (10, 3, 3)) + ) + tf_dataset_up_small = tf.data.Dataset.from_tensor_slices( + tf.reshape(tf.range(45), (5, 3, 3)) + ) + tf_dataset_down = tf.data.Dataset.from_tensor_slices( + tf.reshape(tf.range(90, 0, -1), (10, 3, 3)) + ) + + zipped = tf.data.Dataset.zip((tf_dataset_up, tf_dataset_up)) + zipped_batched_in = tf.data.Dataset.zip( + (tf_dataset_up.batch(3), tf_dataset_up.batch(3)) + ) + + assert are_dataset_first_elems_equal(tf_dataset_up, tf_dataset_up) + assert are_dataset_first_elems_equal(tf_dataset_up.batch(3), tf_dataset_up.batch(3)) + assert are_dataset_first_elems_equal(tf_dataset_up, tf_dataset_up_small) + assert are_dataset_first_elems_equal( + tf_dataset_up.batch(3), tf_dataset_up_small.batch(3) + ) + assert are_dataset_first_elems_equal(zipped, zipped) + assert are_dataset_first_elems_equal(zipped.batch(3), zipped.batch(3)) + assert are_dataset_first_elems_equal(zipped_batched_in, zipped_batched_in) + assert not are_dataset_first_elems_equal(tf_dataset_up, zipped) + assert not are_dataset_first_elems_equal(tf_dataset_up.batch(3), zipped.batch(3)) + assert not are_dataset_first_elems_equal(tf_dataset_up.batch(3), zipped_batched_in) + assert not are_dataset_first_elems_equal(tf_dataset_up, tf_dataset_down) + assert not are_dataset_first_elems_equal( + tf_dataset_up.batch(3), tf_dataset_down.batch(3) + ) + + +def test_is_shuffled(): + """ + Verify the function is able to detect dataset that do not provide stable order of elements + """ + # test with non-shuffled datasets + tf_dataset = tf.data.Dataset.from_tensor_slices( + tf.reshape(tf.range(900), (100, 3, 3)) + ) + zipped = tf.data.Dataset.zip((tf_dataset, tf_dataset)) + tf_mapped = tf_dataset.map(lambda x: x) + + assert not is_shuffled(tf_dataset) + assert not is_shuffled(tf_dataset.batch(3)) + assert not is_shuffled(zipped) + assert not is_shuffled(zipped.batch(3)) + assert not is_shuffled(tf_mapped) + + # test with shuffled datasets + tf_shuffled_once = tf_dataset.shuffle(3, reshuffle_each_iteration=False) + tf_shuffled_once_zipped = tf.data.Dataset.zip((tf_shuffled_once, tf_shuffled_once)) + tf_shuffled_once_mapped = tf_shuffled_once.map(lambda x: x) + + assert not is_shuffled(tf_shuffled_once) + assert not is_shuffled(tf_shuffled_once.batch(3)) + assert not is_shuffled(tf_shuffled_once_zipped) + assert not is_shuffled(tf_shuffled_once_zipped.batch(3)) + assert not is_shuffled(tf_shuffled_once_mapped) + + # test with reshuffled datasets + tf_reshuffled = tf_dataset.shuffle(3, reshuffle_each_iteration=True) + tf_reshuffled_zipped = tf.data.Dataset.zip((tf_reshuffled, tf_reshuffled)) + tf_reshuffled_mapped = tf_reshuffled.map(lambda x: x) + + assert is_shuffled(tf_reshuffled) + assert is_shuffled(tf_reshuffled.batch(3)) + assert is_shuffled(tf_reshuffled_zipped) + assert is_shuffled(tf_reshuffled_zipped.batch(3)) + assert is_shuffled(tf_reshuffled_mapped) + + +def test_batch_size_matches(): + """ + Test that the function is able to detect incoherence between dataset and batch_size + """ + tf_dataset = tf.data.Dataset.from_tensor_slices( + tf.reshape(tf.range(90), (10, 3, 3)) + ) + tf_b1 = tf_dataset.batch(1) + tf_b2 = tf_dataset.batch(2) + tf_b5 = tf_dataset.batch(5) + tf_b25 = tf_b5.batch(2) + tf_b52 = tf_b2.batch(5) + tf_b32 = tf_dataset.batch(32) + + tf_b5_shuffled = tf_b5.shuffle(3) + tf_b5_zipped = tf.data.Dataset.zip((tf_b5, tf_b5)) + tf_b5_mapped = tf_b5.map(lambda x: x) + + assert batch_size_matches(tf_b1, 1) + assert batch_size_matches(tf_b2, 2) + assert batch_size_matches(tf_b5, 5) + assert batch_size_matches(tf_b25, 2) + assert batch_size_matches(tf_b52, 5) + assert batch_size_matches(tf_b32, 10) + assert batch_size_matches(tf_b5_shuffled, 5) + assert batch_size_matches(tf_b5_zipped, 5) + assert batch_size_matches(tf_b5_mapped, 5) + + assert not batch_size_matches(tf_b1, 2) + assert not batch_size_matches(tf_b2, 1) + assert not batch_size_matches(tf_b5, 2) + assert not batch_size_matches(tf_b25, 5) + assert not batch_size_matches(tf_b52, 2) + assert not batch_size_matches(tf_b32, 5) + assert not batch_size_matches(tf_b5_shuffled, 2) + assert not batch_size_matches(tf_b5_zipped, 2) + assert not batch_size_matches(tf_b5_mapped, 2) + + +def test_sanitize_dataset(): + """ + Test that verifies that the function harmonize inputs into datasets + """ + tf_tensor = tf.reshape(tf.range(90), (10, 3, 3)) + np_array = np.array(tf_tensor) + tf_dataset = tf.data.Dataset.from_tensor_slices(tf_tensor) + tf_dataset_b4 = tf_dataset.batch(4) + tf_dataset_b4_mapped = tf_dataset_b4.map(lambda x: x).prefetch(2) + + # test sanitize_dataset do not destroy the dataset + assert sanitize_dataset(None, 1) is None + assert datasets_are_equal(sanitize_dataset(tf_dataset_b4, 4), tf_dataset_b4) + assert datasets_are_equal(sanitize_dataset(tf_dataset_b4_mapped, 4), tf_dataset_b4) + + # test convertion to tf dataset + assert datasets_are_equal(sanitize_dataset(np_array, 4), tf_dataset_b4) + assert datasets_are_equal(sanitize_dataset(tf_tensor, 4), tf_dataset_b4) + assert datasets_are_equal(sanitize_dataset(tf_dataset, 4), tf_dataset_b4) + + # test catch assertion errors + with pytest.raises(AssertionError): + sanitize_dataset(tf_dataset.shuffle(2).batch(4), 4) + with pytest.raises(AssertionError): + sanitize_dataset(tf_dataset_b4, 3) + with pytest.raises(AssertionError): + sanitize_dataset(tf_dataset_b4, 4, 4) + with pytest.raises(AssertionError): + sanitize_dataset(np_array[:6], 4, 4) + + +def test_dataset_gather(): + """ + Test dataset gather function + """ + # (5, 2, 3, 3) + tf_dataset = tf.data.Dataset.from_tensor_slices( + tf.reshape(tf.range(90), (10, 3, 3)) + ).batch(2) + + indices_1 = np.array([[[0, 0], [1, 1]], [[2, 1], [0, 0]]]) + # (2, 2, 3, 3) + results_1 = dataset_gather(tf_dataset, indices_1) + assert np.all(tf.shape(results_1).numpy() == np.array([2, 2, 3, 3])) + assert _almost_equal(results_1[0, 0], results_1[1, 1]) + + indices_2 = tf.constant([[[1, 1]]]) + # (1, 1, 3, 3) + results_2 = dataset_gather(tf_dataset, indices_2) + assert np.all(tf.shape(results_2).numpy() == np.array([1, 1, 3, 3])) + assert _almost_equal(results_1[0, 1], results_2[0, 0]) diff --git a/tests/example_based/test_torch.py b/tests/example_based/test_torch.py new file mode 100644 index 00000000..2737988e --- /dev/null +++ b/tests/example_based/test_torch.py @@ -0,0 +1,448 @@ +""" +Test example-based methods with PyTorch models and datasets. +""" + +import pytest + +import numpy as np +import tensorflow as tf +import torch +from torch import nn +import torch.nn.functional as F +from torch.utils.data import TensorDataset, DataLoader + +from xplique.example_based import ( + SimilarExamples, Cole, MMDCritic, ProtoDash, ProtoGreedy, + NaiveCounterFactuals, LabelAwareCounterFactuals, KLEORGlobalSim, KLEORSimMiss, +) +from xplique.example_based.projections import Projection, LatentSpaceProjection, HadamardProjection +from xplique.example_based.projections.commons import model_splitting + +from xplique.example_based.datasets_operations.tf_dataset_operations import are_dataset_first_elems_equal +from xplique.example_based.datasets_operations.harmonize import harmonize_datasets + +from tests.utils import almost_equal + + +def get_setup(input_shape, nb_samples=10, nb_labels=10): + """ + Generate data and model for SimilarExamples + """ + # Data generation + x_train = torch.stack( + [i * torch.ones(input_shape, dtype=torch.float32) for i in range(nb_samples)] + ) + y_train = torch.arange(len(x_train), dtype=torch.int64) % nb_labels + train_targets = F.one_hot(y_train, num_classes=nb_labels).to(torch.float32) + + x_test = x_train[1:-1] # Exclude the first and last elements + test_targets = train_targets[1:-1] # Exclude the first and last elements + + return x_train, x_test, y_train, train_targets, test_targets + + +def create_cnn_model(input_shape, output_shape): + in_channels, height, width = input_shape + + kernel_size = 3 + padding = 1 + stride = 1 + + # Calculate the flattened size after the convolutional layers and pooling + def conv_output_size(in_size): + return (in_size - kernel_size + 2 * padding) // stride + 1 + + height_after_conv1 = conv_output_size(height) // 2 # After first conv and pooling + height_after_conv2 = conv_output_size(height_after_conv1) // 2 # After second conv and pooling + + width_after_conv1 = conv_output_size(width) // 2 # After first conv and pooling + width_after_conv2 = conv_output_size(width_after_conv1) // 2 # After second conv and pooling + + flat_size = 8 * height_after_conv2 * width_after_conv2 # 8 is the number of filters in the last conv layer + + model = nn.Sequential( + # Convolutional layer 1 + nn.Conv2d(in_channels=in_channels, out_channels=4, kernel_size=kernel_size, padding=padding), # 4 filters + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), # Pooling layer (2x2) + + # Convolutional layer 2 + nn.Conv2d(in_channels=4, out_channels=8, kernel_size=kernel_size, padding=padding), # 8 filters + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2), # Pooling layer (2x2) + + # Flatten layer + nn.Flatten(), + + # Fully connected layer 1 + nn.Linear(flat_size, 16), + nn.ReLU(), + + # Output layer + nn.Linear(16, output_shape) + ) + + # Initialize all weights to ones + for layer in model: + if isinstance(layer, (nn.Conv2d, nn.Linear)): + nn.init.constant_(layer.weight, 1.0) # Set all weights to ones + if layer.bias is not None: + nn.init.constant_(layer.bias, 0.0) # Optionally set all biases to zero + + return model + + +def test_harmonize_datasets_with_torch(): + import torch + + cases = torch.rand(100, 10) + labels = torch.randint(0, 2, (100, 1)) + batch_size = 10 + + cases_out, labels_out, targets_out, batch_size_out = harmonize_datasets(cases, labels, batch_size=batch_size) + + assert targets_out is None, "Targets should be None when not provided." + assert batch_size_out == batch_size, "Output batch size should match the input batch size." + + for case, label in zip(cases_out, labels_out): + assert case.shape == (batch_size, cases.shape[1]), "Each case should have the same shape as the input cases." + assert label.shape == (batch_size, labels.shape[1]), "Each label should have the same shape as the input labels." + break + + +def test_inputs_combinations(): + """ + Test management of dataset init inputs + """ + + tf_tensor = tf.reshape(tf.range(90, dtype=tf.float32), (10, 3, 3)) + np_array = np.array(tf_tensor) + tf_dataset = tf.data.Dataset.from_tensor_slices(tf_tensor) + + tf_dataset_b3 = tf_dataset.batch(3) + tf_dataset_b5 = tf_dataset.batch(5) + + torch_tensor = torch.tensor(np_array) + torch_dataset = TensorDataset(torch_tensor) + zipped2 = TensorDataset(torch_tensor, torch_tensor) + zipped3 = TensorDataset(torch_tensor, torch_tensor, torch_tensor) + torch_dataloader_b3 = DataLoader(torch_dataset, batch_size=3, shuffle=False) + torch_dataloader_b5 = DataLoader(torch_dataset, batch_size=5, shuffle=False) + torch_zipped2_dataloader_b5 = DataLoader(zipped2, batch_size=5, shuffle=False) + torch_zipped3_dataloader_b3 = DataLoader(zipped3, batch_size=3, shuffle=False) + + # Method initialization that should work + cases_dataset, labels_dataset, targets_dataset, batch_size =\ + harmonize_datasets(torch_dataloader_b3, None, torch_dataloader_b3) + assert are_dataset_first_elems_equal(cases_dataset, tf_dataset_b3) + assert are_dataset_first_elems_equal(labels_dataset, None) + assert are_dataset_first_elems_equal(targets_dataset, tf_dataset_b3) + assert batch_size == 3 + + cases_dataset, labels_dataset, targets_dataset, batch_size =\ + harmonize_datasets(torch_tensor, torch_tensor, None, batch_size=5) + assert are_dataset_first_elems_equal(cases_dataset, tf_dataset_b5) + assert are_dataset_first_elems_equal(labels_dataset, tf_dataset_b5) + assert are_dataset_first_elems_equal(targets_dataset, None) + assert batch_size == 5 + + cases_dataset, labels_dataset, targets_dataset, batch_size =\ + harmonize_datasets(torch_zipped2_dataloader_b5, None, torch_dataloader_b5) + assert are_dataset_first_elems_equal(cases_dataset, tf_dataset_b5) + assert are_dataset_first_elems_equal(labels_dataset, tf_dataset_b5) + assert are_dataset_first_elems_equal(targets_dataset, tf_dataset_b5) + assert batch_size == 5 + + cases_dataset, labels_dataset, targets_dataset, batch_size =\ + harmonize_datasets(torch_zipped3_dataloader_b3, batch_size=3) + assert are_dataset_first_elems_equal(cases_dataset, tf_dataset_b3) + assert are_dataset_first_elems_equal(labels_dataset, tf_dataset_b3) + assert are_dataset_first_elems_equal(targets_dataset, tf_dataset_b3) + assert batch_size == 3 + + + +def test_error_raising(): + """ + Test management of dataset init inputs + """ + + tf_tensor = tf.reshape(tf.range(90, dtype=tf.float32), (10, 3, 3)) + tf_dataset = tf.data.Dataset.from_tensor_slices(tf_tensor) + + torch_tensor = torch.reshape(torch.arange(90, dtype=torch.float32), (10, 3, 3)) + np_array = np.array(torch_tensor) + + torch_dataset = TensorDataset(torch_tensor) + torch_dataloader = DataLoader(torch_dataset, batch_size=None, shuffle=False) + torch_shuffled = DataLoader(torch_dataset, batch_size=4, shuffle=True) + torch_dataloader_b3 = DataLoader(torch_dataset, batch_size=3, shuffle=False) + torch_dataloader_b5 = DataLoader(torch_dataset, batch_size=5, shuffle=False) + + zipped2 = TensorDataset(torch_tensor, torch_tensor) + zipped3 = TensorDataset(torch_tensor, torch_tensor, torch_tensor) + torch_zipped2_dataloader_b5 = DataLoader(zipped2, batch_size=5, shuffle=False) + torch_zipped3_dataloader_b3 = DataLoader(zipped3, batch_size=3, shuffle=False) + + too_long_torch_tensor = torch.cat([torch_tensor, torch_tensor], dim=0) + too_long_torch_dataset = TensorDataset(too_long_torch_tensor) + too_long_torch_dataloader_b10 = DataLoader(too_long_torch_dataset, batch_size=10, shuffle=False) + + + # Method initialization that should not work + + # not input + with pytest.raises(TypeError): + harmonize_datasets() + + # shuffled + with pytest.raises(AssertionError): + harmonize_datasets(torch_shuffled,) + + # mismatching types + with pytest.raises(AssertionError): + harmonize_datasets(torch_dataloader_b3, torch_tensor,) + with pytest.raises(AssertionError): + harmonize_datasets(torch_tensor, tf_tensor,) + with pytest.raises(AssertionError): + harmonize_datasets(np_array, torch_tensor,) + with pytest.raises(AssertionError): + harmonize_datasets(np_array, torch_dataloader_b3,) + with pytest.raises(AssertionError): + harmonize_datasets(tf_dataset, torch_dataloader_b3,) + with pytest.raises(AssertionError): + harmonize_datasets(torch_zipped2_dataloader_b5, tf_tensor,) + + # labels or targets zipped + with pytest.raises(AssertionError): + harmonize_datasets(torch_dataloader_b5, torch_zipped2_dataloader_b5,) + with pytest.raises(AssertionError): + harmonize_datasets(torch_dataloader_b3, None, torch_zipped3_dataloader_b3,) + + # not batched and no batch size provided + with pytest.raises(AssertionError): + harmonize_datasets(torch_dataloader,) + + # not matching batch sizes + with pytest.raises(AssertionError): + harmonize_datasets(torch_dataloader_b3, torch_dataloader_b5,) + with pytest.raises(AssertionError): + harmonize_datasets(torch_zipped2_dataloader_b5, None, torch_dataloader_b3,) + + with pytest.raises(AssertionError): + harmonize_datasets( + too_long_torch_dataloader_b10, + too_long_torch_dataloader_b10, + torch_dataloader_b5, + ) + + # multiple datasets for labels or targets + with pytest.raises(AssertionError): + harmonize_datasets(torch_zipped2_dataloader_b5, torch_dataloader_b5,) + with pytest.raises(AssertionError): + harmonize_datasets(torch_zipped3_dataloader_b3, None, torch_dataloader_b3,) + + +def test_torch_model_splitting(): + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + n_sample = 10 + torch_input_shape = (3, 32, 32) + input_shape = (32, 32, 3) + nb_labels = 10 + + model = create_cnn_model(input_shape=torch_input_shape, output_shape=nb_labels) + + # generate data + np_data = np.random.rand(n_sample, *input_shape).astype(np.float32) + + # inference with the initial model + model.eval() + model.to(device) + torch_data = torch.tensor(np_data, device=device) + with torch.no_grad(): + torch_channel_first_data = torch_data.permute(0, 3, 1, 2) + np_predictions_1 = model(torch_channel_first_data).cpu().numpy() + + assert np_predictions_1.shape == (n_sample, nb_labels) + + # test splitting support different types + _, _ = model_splitting(model, "flatten1") + _, _ = model_splitting(model, -2) + features_extractor, predictor = model_splitting(model, "last_conv") + + assert isinstance(features_extractor, tf.keras.Model) + assert isinstance(predictor, tf.keras.Model) + + + # inference with the splitted model + tf_data = tf.convert_to_tensor(np_data) + features = features_extractor(tf_data) + tf_predictions = predictor(features) + np_predictions_2 = tf_predictions.numpy() + + assert tf_predictions.shape == (n_sample, nb_labels) + assert np.allclose(np_predictions_1, np_predictions_2, atol=1e-5) + + +def test_similar_examples_basic(): + """ + Test the SimilarExamples with an identity projection. + """ + input_shape = (4, 4, 1) + k = 3 + batch_size = 4 + + x_train, x_test, y_train, _, _ = get_setup(input_shape) + + torch_dataset = TensorDataset(x_train, y_train) + torch_dataloader = DataLoader(torch_dataset, batch_size=batch_size, shuffle=False) + + identity_projection = Projection( + space_projection=lambda inputs, targets=None: inputs + ) + + # Method initialization + method = SimilarExamples( + cases_dataset=torch_dataloader, + projection=identity_projection, + k=k, + batch_size=batch_size, + distance="euclidean", + case_returns=["examples", "labels"], + ) + + # Generate explanation + outputs = method.explain(x_test) + examples = outputs["examples"] + labels = outputs["labels"] + + # Verifications + # Shape should be (n, k, h, w, c) + assert examples.shape == (len(x_test), k) + input_shape + + for i in range(len(x_test)): + # test examples: + assert almost_equal(np.array(examples[i, 0]), np.array(x_train[i + 1])) + assert almost_equal(np.array(examples[i, 1]), np.array(x_train[i + 2]))\ + or almost_equal(np.array(examples[i, 1]), np.array(x_train[i])) + assert almost_equal(np.array(examples[i, 2]), np.array(x_train[i]))\ + or almost_equal(np.array(examples[i, 2]), np.array(x_train[i + 2])) + + # test labels: + assert almost_equal(np.array(labels[i, 0]), np.array(y_train[i + 1])) + assert almost_equal(np.array(labels[i, 1]), np.array(y_train[i + 2]))\ + or almost_equal(np.array(labels[i, 1]), np.array(y_train[i])) + assert almost_equal(np.array(labels[i, 2]), np.array(y_train[i]))\ + or almost_equal(np.array(labels[i, 2]), np.array(y_train[i + 2])) + + +def test_similar_examples_with_splitting(): + """ + Test the SimilarExamples with an identity projection. + """ + # Setup + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + nb_samples = 10 + torch_input_shape = (3, 32, 32) + input_shape = (32, 32, 3) + nb_labels = 10 + k = 3 + batch_size = 4 + + x_train, x_test, y_train, _, _ = get_setup(input_shape, nb_samples, nb_labels) + torch_dataset = TensorDataset(x_train, y_train) + torch_dataloader = DataLoader(torch_dataset, batch_size=batch_size, shuffle=False) + + model = create_cnn_model(input_shape=torch_input_shape, output_shape=nb_labels) + projection = LatentSpaceProjection(model, "last_conv", device=device) + + # Method initialization + method = SimilarExamples( + cases_dataset=torch_dataloader, + projection=projection, + k=k, + batch_size=batch_size, + distance="euclidean", + case_returns=["examples", "labels"], + ) + + # Generate explanation + outputs = method.explain(x_test) + examples = outputs["examples"] + labels = outputs["labels"] + + # Verifications + # Shape should be (n, k, h, w, c) + assert examples.shape == (len(x_test), k) + input_shape + + for i in range(len(x_test)): + # test examples: + assert almost_equal(np.array(examples[i, 0]), np.array(x_train[i + 1])) + assert almost_equal(np.array(examples[i, 1]), np.array(x_train[i + 2]))\ + or almost_equal(np.array(examples[i, 1]), np.array(x_train[i])) + assert almost_equal(np.array(examples[i, 2]), np.array(x_train[i]))\ + or almost_equal(np.array(examples[i, 2]), np.array(x_train[i + 2])) + + # test labels: + assert almost_equal(np.array(labels[i, 0]), np.array(y_train[i + 1])) + assert almost_equal(np.array(labels[i, 1]), np.array(y_train[i + 2]))\ + or almost_equal(np.array(labels[i, 1]), np.array(y_train[i])) + assert almost_equal(np.array(labels[i, 2]), np.array(y_train[i]))\ + or almost_equal(np.array(labels[i, 2]), np.array(y_train[i + 2])) + + +def test_all_methods_with_torch(): + # Setup + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + nb_samples = 13 + torch_input_shape = (3, 32, 32) + input_shape = (32, 32, 3) + nb_labels = 5 + batch_size = 4 + + x_train, x_test, y_train, train_targets, test_targets = get_setup(input_shape, nb_samples, nb_labels) + torch_dataset = TensorDataset(x_train, y_train) + torch_dataloader = DataLoader(torch_dataset, batch_size=batch_size, shuffle=False) + targets_dataloader = DataLoader(TensorDataset(train_targets), batch_size=batch_size, shuffle=False) + + model = create_cnn_model(input_shape=torch_input_shape, output_shape=nb_labels) + projection = HadamardProjection(model, "last_conv", device=device) + + methods = [SimilarExamples, Cole, MMDCritic, ProtoDash, ProtoGreedy, + NaiveCounterFactuals, LabelAwareCounterFactuals, KLEORGlobalSim, KLEORSimMiss,] + + for method_class in methods: + if method_class == Cole: + method = method_class( + cases_dataset=torch_dataloader, + targets_dataset=targets_dataloader, + case_returns="all", + model=model, + latent_layer="last_conv", + device=device, + ) + else: + method = method_class( + cases_dataset=torch_dataloader, + targets_dataset=targets_dataloader, + projection=projection, + case_returns="all", + ) + + # Generate explanation + if method_class == LabelAwareCounterFactuals: + outputs = method.explain(x_test, cf_expected_classes=test_targets) + elif method_class in [NaiveCounterFactuals, KLEORGlobalSim, KLEORSimMiss]: + outputs = method.explain(x_test, targets=test_targets) + else: + outputs = method.explain(x_test, targets=None) + + examples = outputs["examples"] + labels = outputs["labels"] + + assert examples.shape == (len(x_test), 2) + input_shape + assert labels.shape == (len(x_test), 1) diff --git a/tests/plots/test_image_example_based_plot.py b/tests/plots/test_image_example_based_plot.py new file mode 100644 index 00000000..3d2ee11d --- /dev/null +++ b/tests/plots/test_image_example_based_plot.py @@ -0,0 +1,88 @@ +""" +Test Cole +""" +import os +import sys + +sys.path.append(os.getcwd()) + +import tensorflow as tf + +from xplique.attributions import Occlusion + +from xplique.example_based import Cole +from xplique.plots.image import plot_examples + +from tests.utils import ( + generate_model, +) + + +def get_setup(input_shape, nb_samples=10, nb_labels=10): + """ + Generate data and model for Cole + """ + # Data generation + x_train = tf.stack( + [i * tf.ones(input_shape, tf.float32) for i in range(nb_samples)] + ) + x_test = x_train[1:-1] + y_train = tf.one_hot(tf.range(len(x_train)) % nb_labels, depth=nb_labels) + + # Model generation + model = generate_model(input_shape, nb_labels) + + return model, x_train, x_test, y_train + + +def test_plot_cole_spliting(): + """ + Test examples plot function. + """ + # Setup + nb_samples = 10 + input_shape = (6, 6, 3) + nb_labels = 5 + k = 1 + x_train = tf.random.uniform((nb_samples,) + input_shape, minval=0, maxval=1) + x_test = tf.random.uniform((nb_samples,) + input_shape, minval=0, maxval=1) + labels = tf.one_hot( + indices=tf.repeat(input=tf.range(nb_labels), repeats=[nb_samples // nb_labels]), + depth=nb_labels, + ) + y_train = labels + y_test = tf.random.shuffle(labels) + + # Model generation + model = generate_model(input_shape, nb_labels) + + # Cole with attribution method constructor + method = Cole( + cases_dataset=x_train, + labels_dataset=tf.argmax(y_train, axis=1), + targets_dataset=y_train, + k=k, + case_returns="all", + model=model, + latent_layer="last_conv", + attribution_method=Occlusion, + patch_size=2, + patch_stride=1, + ) + + # Generate explanation + outputs = method.explain(x_test, y_test) + + # get predictions on examples + predicted_labels = tf.map_fn( + fn=lambda x: tf.cast(tf.argmax(model(x), axis=1), tf.int32), + elems=outputs["examples"], + fn_output_signature=tf.int32, + ) + + # test plot + plot_examples( + test_labels=tf.argmax(y_test, axis=1), + predicted_labels=predicted_labels, + **outputs + ) diff --git a/tests/utils.py b/tests/utils.py index 67cf0e36..483d716b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,8 +7,6 @@ from tensorflow.keras.layers import (Dense, Conv1D, Conv2D, Activation, GlobalAveragePooling1D, Dropout, Flatten, MaxPooling2D, Input, Reshape) from tensorflow.keras.utils import to_categorical -from PIL import Image, ImageDraw, ImageFont -import urllib.request import requests def generate_data(x_shape=(32, 32, 3), num_labels=10, samples=100): @@ -31,6 +29,14 @@ def generate_model(input_shape=(32, 32, 3), output_shape=10): return model +def generate_agnostic_model(input_shape=(3,), nb_labels=3): + model = Sequential() + model.add(Input(input_shape)) + model.add(Flatten()) + model.add(Dense(nb_labels)) + + return model + def generate_timeseries_model(input_shape=(20, 10), output_shape=10): model = Sequential() model.add(Input(shape=input_shape)) @@ -161,64 +167,6 @@ def model_with_random_nb_boxes(input): return model_with_random_nb_boxes return valid_model -def generate_txt_images_data(x_shape=(32, 32, 3), num_labels=10, samples=100): - """ - Generate an image dataset composed of white texts over black background. - The texts are words of 3 successive letters, the number of classes is set by the - parameter num_labels. The location of the text in the image is cycling over the - image dimensions. - Ex: with num_labels=3, the 3 classes will be 'ABC', 'BCD' and 'CDE'. - - """ - all_labels_str = "".join([chr(lab_idx) for lab_idx in range(65, 65+num_labels+2)]) # ABCDEF - labels_str = [all_labels_str[i:i+3] for i in range(len(all_labels_str) - 2)] # ['ABC', 'BCD', 'CDE', 'DEF'] - - def create_image_from_txt(image_shape, txt, offset_x, offset_y): - # Get a Pillow font (OS independant) - try: - fnt = ImageFont.truetype("FreeMono.ttf", 16) - except OSError: - # dl the font it is it not in the system - url = "https://github.com/python-pillow/Pillow/raw/main/Tests/fonts/FreeMono.ttf" - urllib.request.urlretrieve(url, "tests/FreeMono.ttf") - fnt = ImageFont.truetype("tests/FreeMono.ttf", 16) - - # Make a black image and draw the input text in white at the location offset_x, offset_y - rgb = (len(image_shape) == 3 and image_shape[2] > 1) - if rgb: - image = Image.new("RGB", (image_shape[0], image_shape[1]), (0, 0, 0)) - else: - # grayscale - image = Image.new("L", (image_shape[0], image_shape[1]), 0) - d = ImageDraw.Draw(image) - d.text((offset_x, offset_y), txt, font=fnt, fill='white') - return image - - x = np.empty((samples, *x_shape)).astype(np.float32) - y = np.empty(samples) - - # Iterate over the samples and generate images of labels shifted by increasing offsets - offset_x_max = x_shape[0] - 25 - offset_y_max = x_shape[1] - 10 - - current_label_id = 0 - offset_x = offset_y = 0 - for i in range(samples): - image = create_image_from_txt(x_shape, txt=labels_str[current_label_id], offset_x=offset_x, offset_y=offset_y) - image = np.reshape(image, x_shape) - x[i] = np.array(image).astype(np.float32)/255.0 - y[i] = current_label_id - - # cycle labels - current_label_id = (current_label_id + 1) % num_labels - offset_x = (offset_x + 1) % offset_x_max - offset_y = ((i+2) % offset_y_max) - if offset_y > offset_y_max: - break - x = x[0:i] - y = y[0:i] - return x, to_categorical(y, num_labels), i, labels_str - def download_file(identifier: str, destination: str): """ @@ -242,3 +190,21 @@ def download_file(identifier: str, for chunk in response.iter_content(chunk_size): if chunk: file.write(chunk) + +def get_gaussian_data(nb_classes=3, nb_samples_class=20, n_dims=1): + tf.random.set_seed(42) + + sigma = 1 + mu = [10 * (id + 1) for id in range(nb_classes)] + + X = tf.concat([ + tf.random.normal(shape=(nb_samples_class, n_dims), mean=mu[i], stddev=sigma, dtype=tf.float32) + for i in range(nb_classes) + ], axis=0) + + y = tf.concat([ + tf.ones(shape=(nb_samples_class), dtype=tf.int32) * i + for i in range(nb_classes) + ], axis=0) + + return(X, y) diff --git a/xplique/__init__.py b/xplique/__init__.py index 32ee5166..9438ef46 100644 --- a/xplique/__init__.py +++ b/xplique/__init__.py @@ -6,12 +6,13 @@ techniques """ -__version__ = '1.3.3' +__version__ = '1.4.0' from . import attributions +from . import commons from . import concepts +from . import example_based from . import features_visualizations -from . import commons from . import plots from .commons import Tasks diff --git a/xplique/commons/__init__.py b/xplique/commons/__init__.py index 94237f90..7439c846 100644 --- a/xplique/commons/__init__.py +++ b/xplique/commons/__init__.py @@ -2,10 +2,10 @@ Utility classes and functions """ -from .data_conversion import tensor_sanitize, numpy_sanitize +from .data_conversion import tensor_sanitize, numpy_sanitize, sanitize_inputs_targets from .model_override import guided_relu_policy, deconv_relu_policy, override_relu_gradient, \ find_layer, open_relu_policy -from .tf_operations import repeat_labels, batch_tensor +from .tf_operations import repeat_labels, batch_tensor, get_device from .callable_operations import predictions_one_hot_callable from .operators_operations import (Tasks, get_operator, check_operator, operator_batching, get_inference_function, get_gradient_functions) diff --git a/xplique/commons/data_conversion.py b/xplique/commons/data_conversion.py index 517f86ad..d5126db6 100644 --- a/xplique/commons/data_conversion.py +++ b/xplique/commons/data_conversion.py @@ -5,11 +5,11 @@ import tensorflow as tf import numpy as np -from ..types import Union, Optional, Tuple +from ..types import Union, Optional, Tuple, Callable def tensor_sanitize(inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray], - targets: Optional[Union[tf.Tensor, np.ndarray]]) -> Tuple[tf.Tensor, tf.Tensor]: + targets: Union[tf.Tensor, np.ndarray]) -> Tuple[tf.Tensor, tf.Tensor]: """ Ensure the output as tf.Tensor, accept various inputs format including: tf.Tensor, List, numpy array, tf.data.Dataset (when label = None). @@ -66,3 +66,35 @@ def numpy_sanitize(inputs: Union[tf.data.Dataset, tf.Tensor, np.ndarray], """ inputs, targets = tensor_sanitize(inputs, targets) return inputs.numpy(), targets.numpy() + + +def sanitize_inputs_targets(explanation_method: Callable): + """ + Wrap a method explanation function to ensure tf.Tensor as inputs and targets. + But targets may be None. + + explanation_method + Function to wrap, should return an tf.tensor. + """ + def sanitize(self, + inputs: Union[tf.Tensor, np.array], + targets: Optional[Union[tf.Tensor, np.array]] = None, + *args, + **kwargs + ): + # pylint: disable=keyword-arg-before-vararg + # ensure we have tf.tensor + inputs = tf.cast(inputs, tf.float32) + if targets is not None: + targets = tf.cast(targets, tf.float32) + + if args: + args = [tf.cast(arg, tf.float32) for arg in args] + + if kwargs: + kwargs = {key: tf.cast(value, tf.float32) for key, value in kwargs.items()} + + # then enter the explanation function + return explanation_method(self, inputs, targets, *args, **kwargs) + + return sanitize diff --git a/xplique/commons/tf_operations.py b/xplique/commons/tf_operations.py index 1d6e5fae..3831b41f 100644 --- a/xplique/commons/tf_operations.py +++ b/xplique/commons/tf_operations.py @@ -54,3 +54,28 @@ def batch_tensor(tensors: Union[Tuple, tf.Tensor], dataset = dataset.batch(batch_size) return dataset + + +def get_device(device: Optional[str] = None) -> str: + """ + Gets the name of the device to use. If there are any available GPUs, it will use the first one + in the system, otherwise, it will use the CPU. + + Parameters + ---------- + device + A string specifying the device on which to run the computations. If None, it will search + for available GPUs, and if none are found, it will return the first CPU. + + Returns + ------- + device + A string with the name of the device on which to run the computations. + """ + if device is not None: + return device + + physical_devices = tf.config.list_physical_devices('GPU') + if physical_devices is None or len(physical_devices) == 0: + return 'cpu:0' + return 'GPU:0' diff --git a/xplique/example_based/__init__.py b/xplique/example_based/__init__.py new file mode 100644 index 00000000..fa83c1ba --- /dev/null +++ b/xplique/example_based/__init__.py @@ -0,0 +1,8 @@ +""" +Example-based methods available +""" + +from .similar_examples import SimilarExamples, Cole +from .prototypes import Prototypes, ProtoGreedy, ProtoDash, MMDCritic +from .counterfactuals import NaiveCounterFactuals, LabelAwareCounterFactuals +from .semifactuals import KLEORGlobalSim, KLEORSimMiss diff --git a/xplique/example_based/base_example_method.py b/xplique/example_based/base_example_method.py new file mode 100644 index 00000000..700ab60f --- /dev/null +++ b/xplique/example_based/base_example_method.py @@ -0,0 +1,278 @@ +""" +Base model for example-based +""" + +from abc import ABC, abstractmethod +import warnings + +import tensorflow as tf +import numpy as np + +from ..types import Callable, Dict, List, Optional, Type, Union, DatasetOrTensor + +from ..commons import sanitize_inputs_targets +from .datasets_operations.harmonize import harmonize_datasets +from .datasets_operations.tf_dataset_operations import dataset_gather +from .search_methods import BaseSearchMethod +from .projections import Projection + +from .search_methods.base import _sanitize_returns + + +class BaseExampleMethod(ABC): + """ + Base class for natural example-based methods explaining classification models. + An example-based method is a method that explains a model's predictions by providing + examples from the cases_dataset (usually the training dataset). The examples are selected with + the help of a search method that performs a search in the search space. The search space is + defined with the help of a projection function that projects the cases_dataset + and the (inputs, targets) to explain into a space where the search method is relevant. + + Parameters + ---------- + cases_dataset + The dataset used to train the model, examples are extracted from this dataset. + All datasets (cases, labels, and targets) should be of the same type. + Supported types are: `tf.data.Dataset`, `torch.utils.data.DataLoader`, + `tf.Tensor`, `np.ndarray`, `torch.Tensor`. + For datasets with multiple columns, the first column is assumed to be the cases. + While the second column is assumed to be the labels, and the third the targets. + Warning: datasets tend to reshuffle at each iteration, ensure the datasets are + not reshuffle as we use index in the dataset. + labels_dataset + Labels associated with the examples in the `cases_dataset`. + It should have the same type as `cases_dataset`. + targets_dataset + Targets associated with the `cases_dataset` for dataset projection, + oftentimes the one-hot encoding of a model's predictions. See `projection` for detail. + It should have the same type as `cases_dataset`. + It is not be necessary for all projections. + Furthermore, projections which requires it compute it internally by default. + k + The number of examples to retrieve per input. + projection + Projection or Callable that project samples from the input space to the search space. + The search space should be a space where distances are relevant for the model. + It should not be `None`, otherwise, the model is not involved thus not explained. + + Example of Callable: + ``` + def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None): + ''' + Example of projection, + inputs are the elements to project. + targets are optional parameters to orientated the projection. + ''' + projected_inputs = # do some magic on inputs, it should use the model. + return projected_inputs + ``` + case_returns + String or list of string with the elements to return in `self.explain()`. + See the returns property for details. + batch_size + Number of samples treated simultaneously for projection and search. + Ignored if `cases_dataset` is a batched `tf.data.Dataset` or + a batched `torch.utils.data.DataLoader` is provided. + """ + # pylint: disable=too-many-instance-attributes + _returns_possibilities = ["examples", "distances", "labels", "include_inputs"] + + def __init__( + self, + cases_dataset: DatasetOrTensor, + labels_dataset: Optional[DatasetOrTensor] = None, + targets_dataset: Optional[DatasetOrTensor] = None, + k: int = 1, + projection: Union[Projection, Callable] = None, + case_returns: Union[List[str], str] = "examples", + batch_size: Optional[int] = None, + ): + # set attributes + self.cases_dataset, self.labels_dataset, self.targets_dataset, self.batch_size =\ + harmonize_datasets(cases_dataset, labels_dataset, targets_dataset, batch_size) + + self._search_returns = ["indices", "distances"] + + # check projection + if isinstance(projection, Projection): + self.projection = projection + elif hasattr(projection, "__call__"): + self.projection = Projection(get_weights=None, space_projection=projection) + elif projection is None: + warnings.warn( + "Example-based methods without projection will not explain the model."\ + + "To explain the model, consider using projections like the LatentSpaceProjection." + ) + self.projection = Projection(get_weights=None, space_projection=None) + else: + raise AttributeError( + f"projection should be a `Projection` or a `Callable`, not a {type(projection)}" + ) + + # project dataset + self.projected_cases_dataset = self.projection.project_dataset(self.cases_dataset, + self.targets_dataset) + + # set properties + self.k = k + if self.labels_dataset is None\ + and ("labels" in case_returns or case_returns in ["all", "labels"]): + raise AttributeError( + "The method cannot return labels without a label dataset." + ) + self.returns = case_returns + + # temporary value for the search method + self.search_method = None + + @property + @abstractmethod + def search_method_class(self) -> Type[BaseSearchMethod]: + """ + When inheriting from `BaseExampleMethod`, one should define the search method class to use. + """ + raise NotImplementedError + + @property + def k(self) -> int: + """Getter for the k parameter.""" + return self._k + + @k.setter + def k(self, k: int): + """Setter for the k parameter.""" + assert isinstance(k, int) and k >= 1, f"k should be an int >= 1 and not {k}" + self._k = k + + try: + self.search_method.k = k + except AttributeError: + pass + + @property + def returns(self) -> Union[List[str], str]: + """Getter for the returns parameter.""" + return self._returns + + @returns.setter + def returns(self, returns: Union[List[str], str]): + """ + Setter for the returns parameter used to define returned elements in `self.explain()`. + + Parameters + ---------- + returns + Most elements are useful in `xplique.plots.plot_examples()`. + `returns` can be set to 'all' for all possible elements to be returned. + - 'examples' correspond to the expected examples, + the inputs may be included in first position. (n, k(+1), ...) + - 'distances' the distances between the inputs and the corresponding examples. + They are associated to the examples. (n, k, ...) + - 'labels' if provided through `dataset_labels`, + they are the labels associated with the examples. (n, k, ...) + - 'include_inputs' specify if inputs should be included in the returned elements. + Note that it changes the number of returned elements from k to k+1. + """ + default = "examples" + self._returns = _sanitize_returns(returns, self._returns_possibilities, default) + + @sanitize_inputs_targets + def explain( + self, + inputs: Union[tf.Tensor, np.ndarray], + targets: Optional[Union[tf.Tensor, np.ndarray]] = None, + ): + """ + Return the relevant examples to explain the (inputs, targets). + It projects inputs with `self.projection` in the search space + and find examples with the `self.search_method`. + + Parameters + ---------- + inputs + Tensor or Array. Input samples to be explained. + Expected shape among (N, W), (N, T, W), (N, W, H, C). + More information in the documentation. + targets + Targets associated to the `inputs` for projection. + Shape: (n, nb_classes) where n is the number of samples and + nb_classes is the number of classes. + It is used in the `projection`. But `projection` can compute it internally. + + Returns + ------- + return_dict + Dictionary with listed elements in `self.returns`. + The elements that can be returned are defined with the `_returns_possibilities` + static attribute of the class. + """ + # project inputs into the search space + projected_inputs = self.projection(inputs, targets) + + # look for relevant elements in the search space + search_output = self.search_method.find_examples(projected_inputs, targets) + + # manage returned elements + return self.format_search_output(search_output, inputs) + + def __call__( + self, + inputs: Union[tf.Tensor, np.ndarray], + targets: Optional[Union[tf.Tensor, np.ndarray]] = None, + ): + """explain() alias""" + return self.explain(inputs, targets) + + def format_search_output( + self, + search_output: Dict[str, tf.Tensor], + inputs: Union[tf.Tensor, np.ndarray], + ): + """ + Format the output of the `search_method` to match the expected returns in `self.returns`. + + Parameters + ---------- + search_output + Dictionary with the required outputs from the `search_method`. + inputs + Tensor or Array. Input samples to be explained. + Expected shape among (N, W), (N, T, W), (N, W, H, C). + # targets + # Targets associated to the cases_dataset for dataset projection. + # See `projection` for details. + + Returns + ------- + return_dict + Dictionary with listed elements in `self.returns`. + The elements that can be returned are defined with the `_returns_possibilities` + static attribute of the class. + """ + # initialize return dictionary + return_dict = {} + + # gather examples, labels, and targets from the example's indices of the search output + examples = dataset_gather(self.cases_dataset, search_output["indices"]) + examples_labels = dataset_gather(self.labels_dataset, search_output["indices"]) + + # add examples and weights + if "examples" in self.returns: # or "weights" in self.returns: + if "include_inputs" in self.returns: + # include inputs + inputs = tf.expand_dims(inputs, axis=1) + examples = tf.concat([inputs, examples], axis=1) + return_dict["examples"] = examples + + # add indices, distances, and labels + if "indices" in self.returns: + return_dict["indices"] = search_output["indices"] + if "distances" in self.returns: + return_dict["distances"] = search_output["distances"] + if "labels" in self.returns: + assert (examples_labels is not None),\ + "The method cannot return labels without a label dataset. "\ + + "Either remove 'labels' from `case_returns` or provide a `labels_dataset`." + return_dict["labels"] = examples_labels + + return return_dict diff --git a/xplique/example_based/counterfactuals.py b/xplique/example_based/counterfactuals.py new file mode 100644 index 00000000..a85f24e4 --- /dev/null +++ b/xplique/example_based/counterfactuals.py @@ -0,0 +1,303 @@ +""" +Implementation of both counterfactuals and semi factuals methods for classification tasks. +""" +import numpy as np +import tensorflow as tf + +from ..commons import sanitize_inputs_targets +from ..types import Callable, List, Optional, Union, DatasetOrTensor + +from .base_example_method import BaseExampleMethod +from .search_methods import ORDER, FilterKNN +from .projections import Projection + + +class NaiveCounterFactuals(BaseExampleMethod): + """ + This class allows to search for counterfactuals by searching for the closest sample to + a query in a projection space that do not have the same model's prediction. + It is a naive approach as it follows a greedy approach. + + Parameters + ---------- + cases_dataset + The dataset used to train the model, examples are extracted from this dataset. + All datasets (cases, labels, and targets) should be of the same type. + Supported types are: `tf.data.Dataset`, `torch.utils.data.DataLoader`, + `tf.Tensor`, `np.ndarray`, `torch.Tensor`. + For datasets with multiple columns, the first column is assumed to be the cases. + While the second column is assumed to be the labels, and the third the targets. + Warning: datasets tend to reshuffle at each iteration, ensure the datasets are + not reshuffle as we use index in the dataset. + targets_dataset + Targets associated with the `cases_dataset` for dataset projection, + oftentimes the one-hot encoding of a model's predictions. See `projection` for detail. + They are also used to know the prediction of the model on the dataset. + It should have the same type as `cases_dataset`. + labels_dataset + Labels associated with the examples in the `cases_dataset`. + It should have the same type as `cases_dataset`. + k + The number of examples to retrieve per input. + projection + Projection or Callable that project samples from the input space to the search space. + The search space should be a space where distances are relevant for the model. + It should not be `None`, otherwise, the model is not involved thus not explained. + + Example of Callable: + ``` + def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None): + ''' + Example of projection, + inputs are the elements to project. + targets are optional parameters to orientated the projection. + ''' + projected_inputs = # do some magic on inputs, it should use the model. + return projected_inputs + ``` + case_returns + String or list of string with the elements to return in `self.explain()`. + See the base class returns property for more details. + batch_size + Number of samples treated simultaneously for projection and search. + Ignored if `cases_dataset` is a batched `tf.data.Dataset` or + a batched `torch.utils.data.DataLoader` is provided. + distance + Distance function for examples search. It can be an integer, a string in + {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable, + by default "euclidean". + """ + # pylint: disable=duplicate-code + + def __init__( + self, + cases_dataset: DatasetOrTensor, + targets_dataset: DatasetOrTensor, + labels_dataset: Optional[DatasetOrTensor] = None, + k: int = 1, + projection: Union[Projection, Callable] = None, + case_returns: Union[List[str], str] = "examples", + batch_size: Optional[int] = None, + distance: Union[int, str, Callable] = "euclidean", + ): + super().__init__( + cases_dataset=cases_dataset, + labels_dataset=labels_dataset, + targets_dataset=targets_dataset, + k=k, + projection=projection, + case_returns=case_returns, + batch_size=batch_size, + ) + + # initiate search_method + self.search_method = self.search_method_class( + cases_dataset=self.projected_cases_dataset, + targets_dataset=self.targets_dataset, + k=self.k, + search_returns=self._search_returns, + batch_size=self.batch_size, + distance=distance, + filter_fn=self.filter_fn, + order=ORDER.ASCENDING + ) + + @property + def search_method_class(self): + """ + This property defines the search method class to use for the search. + In this case, it is the FilterKNN that is an efficient KNN search method + ignoring non-acceptable cases, thus not considering them in the search. + """ + return FilterKNN + + + def filter_fn(self, _, __, targets, cases_targets) -> tf.Tensor: + """ + Filter function to mask the cases for which the model's prediction + is different from the model's prediction on the inputs. + """ + # get the labels predicted by the model + # (n, ) + predicted_labels = tf.argmax(targets, axis=-1) + + # for each input, if the target label is the same as the predicted label + # the mask as a True value and False otherwise + label_targets = tf.argmax(cases_targets, axis=-1) # (bs,) + mask = tf.not_equal(tf.expand_dims(predicted_labels, axis=1), label_targets) #(n, bs) + return mask + + +class LabelAwareCounterFactuals(BaseExampleMethod): + """ + This method will search the counterfactuals of a query within an expected class. + This class should be provided with the query when calling the explain method. + + Parameters + ---------- + cases_dataset + The dataset used to train the model, examples are extracted from this dataset. + All datasets (cases, labels, and targets) should be of the same type. + Supported types are: `tf.data.Dataset`, `torch.utils.data.DataLoader`, + `tf.Tensor`, `np.ndarray`, `torch.Tensor`. + For datasets with multiple columns, the first column is assumed to be the cases. + While the second column is assumed to be the labels, and the third the targets. + Warning: datasets tend to reshuffle at each iteration, ensure the datasets are + not reshuffle as we use index in the dataset. + targets_dataset + Targets associated with the `cases_dataset` for dataset projection, + oftentimes the one-hot encoding of a model's predictions. See `projection` for detail. + They are also used to know the prediction of the model on the dataset. + It should have the same type as `cases_dataset`. + labels_dataset + Labels associated with the examples in the `cases_dataset`. + It should have the same type as `cases_dataset`. + k + The number of examples to retrieve per input. + projection + Projection or Callable that project samples from the input space to the search space. + The search space should be a space where distances are relevant for the model. + It should not be `None`, otherwise, the model is not involved thus not explained. + + Example of Callable: + ``` + def custom_projection(inputs: tf.Tensor, np.ndarray): + ''' + Example of projection, + inputs are the elements to project. + ''' + projected_inputs = # do some magic on inputs, it should use the model. + return projected_inputs + ``` + case_returns + String or list of string with the elements to return in `self.explain()`. + See the base class returns property for more details. + batch_size + Number of samples treated simultaneously for projection and search. + Ignored if `cases_dataset` is a batched `tf.data.Dataset` or + a batched `torch.utils.data.DataLoader` is provided. + distance + Distance for the FilterKNN search method. + Distance function for examples search. It can be an integer, a string in + {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable, + by default "euclidean". + """ + # pylint: disable=duplicate-code + + def __init__( + self, + cases_dataset: DatasetOrTensor, + targets_dataset: DatasetOrTensor, + labels_dataset: Optional[DatasetOrTensor] = None, + k: int = 1, + projection: Union[Projection, Callable] = None, + case_returns: Union[List[str], str] = "examples", + batch_size: Optional[int] = None, + distance: Union[int, str, Callable] = "euclidean", + ): + + super().__init__( + cases_dataset=cases_dataset, + labels_dataset=labels_dataset, + targets_dataset=targets_dataset, + k=k, + projection=projection, + case_returns=case_returns, + batch_size=batch_size, + ) + + # initiate search_method + self.search_method = self.search_method_class( + cases_dataset=self.projected_cases_dataset, + targets_dataset=self.targets_dataset, + k=self.k, + search_returns=self._search_returns, + batch_size=self.batch_size, + distance=distance, + filter_fn=self.filter_fn, + order=ORDER.ASCENDING + ) + + @property + def search_method_class(self): + """ + This property defines the search method class to use for the search. + In this case, it is the FilterKNN that is an efficient KNN search method ignoring + non-acceptable cases, thus not considering them in the search. + """ + return FilterKNN + + + def filter_fn(self, _, __, cf_expected_classes, cases_targets) -> tf.Tensor: + """ + Filter function to mask the cases for which the target is different from + the target(s) expected for the counterfactuals. + + Parameters + ---------- + cf_expected_classes + The one-hot encoding of the target class for the counterfactuals. + cases_targets + The one-hot encoding of the target class for the cases. + """ + cases_predicted_labels = tf.argmax(cases_targets, axis=-1) + cf_label_targets = tf.argmax(cf_expected_classes, axis=-1) + mask = tf.equal(tf.expand_dims(cf_label_targets, axis=1), cases_predicted_labels) + return mask + + @sanitize_inputs_targets + def explain( + self, + inputs: Union[tf.Tensor, np.ndarray], + targets: Optional[Union[tf.Tensor, np.ndarray]] = None, + cf_expected_classes: Union[tf.Tensor, np.ndarray] = None, + ): + """ + Return the relevant CF examples to explain the inputs. + The CF examples are searched within cases + for which the target is the one provided in `cf_targets`. + It projects inputs with `self.projection` in the search space and + find examples with the `self.search_method`. + + Parameters + ---------- + inputs + Tensor or Array. Input samples to be explained. + Expected shape among (N, W), (N, T, W), (N, W, H, C). + More information in the documentation. + targets + Tensor or Array. One-hot encoded labels or regression target (e.g {+1, -1}), + one for each sample. If not provided, the model's predictions are used. + Targets associated to the `inputs` for projection. + Shape: (n, nb_classes) where n is the number of samples and + nb_classes is the number of classes. + It is used in the `projection`. But `projection` can compute it internally. + cf_expected_classes + Tensor or Array. One-hot encoding of the target class for the counterfactuals. + + Returns + ------- + return_dict + Dictionary with listed elements in `self.returns`. + The elements that can be returned are defined with the `_returns_possibilities` + static attribute of the class. + """ + assert cf_expected_classes is not None, "cf_expected_classes should be provided." + + # project inputs into the search space + projected_inputs = self.projection(inputs, targets) + + # look for relevant elements in the search space + search_output = self.search_method(projected_inputs, cf_expected_classes) + + # manage returned elements + return self.format_search_output(search_output, inputs) + + def __call__( + self, + inputs: Union[tf.Tensor, np.ndarray], + targets: Optional[Union[tf.Tensor, np.ndarray]] = None, + cf_expected_classes: Union[tf.Tensor, np.ndarray] = None, + ): + """explain() alias""" + return self.explain(inputs, targets, cf_expected_classes) diff --git a/xplique/example_based/datasets_operations/__init__.py b/xplique/example_based/datasets_operations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/xplique/example_based/datasets_operations/convert_torch_to_tf.py b/xplique/example_based/datasets_operations/convert_torch_to_tf.py new file mode 100644 index 00000000..9f07fbb6 --- /dev/null +++ b/xplique/example_based/datasets_operations/convert_torch_to_tf.py @@ -0,0 +1,187 @@ +""" +Set of functions to convert `torch.utils.data.DataLoader` and `torch.Tensor` to `tf.data.Dataset` +""" +from typing import Optional, Tuple + +import tensorflow as tf +import torch + + +def convert_column_dataloader_to_tf_dataset( + dataloader: torch.utils.data.DataLoader, + elements_shape: Tuple[int], + column_index: Optional[int] = None, + ) -> tf.data.Dataset: + """ + Converts a PyTorch torch.utils.data.DataLoader to a TensorFlow Dataset. + + Parameters + ---------- + dataloader + The DataLoader to convert. + elements_shape + The shape of the elements in the DataLoader. + column_index + The index of the column to convert. + If `None`, the entire DataLoader is converted. + + Returns + ------- + dataset + The converted dataset. + """ + + # make generator from dataloader + if column_index is None: + def generator(): + for elements in dataloader: + yield tf.cast(elements.numpy(), tf.float32) + else: + def generator(): + for elements in dataloader: + yield tf.cast(elements[column_index].numpy(), tf.float32) + + # create tf dataset from generator + dataset = tf.data.Dataset.from_generator( + generator, + output_signature=tf.TensorSpec(shape=elements_shape, dtype=tf.float32), + ) + + return dataset + + +def split_and_convert_column_dataloader( + cases_dataset: torch.utils.data.DataLoader, + labels_dataset: Optional[torch.utils.data.DataLoader] = None, + targets_dataset: Optional[torch.utils.data.DataLoader] = None, + ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]: + """ + Splits a PyTorch DataLoader into cases, labels, and targets datasets. + The DataLoader is splitted only if it has multiple columns. + If the DataLoader has 2 columns, the second column is assumed to be the labels. + If the DataLoader has several columns but labels and targets are provided, + there is a conflict and an error is raised. + The splitted parts are then converted to TensorFlow datasets. + + Parameters + ---------- + cases_dataset + The dataset to split. + labels_dataset + Labels associated with the cases in the `cases_dataset`. + If this function is called, it should be `None`. + targets_dataset + Targets associated with the cases in the `cases_dataset`. + If this function is called and `cases_dataset` has 3 columns, it should be `None`. + + Returns + ------- + cases_dataset + The dataset used to train the model. + labels_dataset + Labels associated with the `cases_dataset`. + targets_dataset + Targets associated with the `cases_dataset`. + """ + # pylint: disable=too-many-branches + first_cases = next(iter(cases_dataset)) + + if not isinstance(first_cases, (tuple, list)): + # the cases dataset only has one column + + # manage cases dataset + cases_shape = (None,) + first_cases.shape[1:] + new_cases_dataset = convert_column_dataloader_to_tf_dataset(cases_dataset, cases_shape) + + else: + # manage cases dataset + cases_shape = (None,) + first_cases[0].shape[1:] + new_cases_dataset = convert_column_dataloader_to_tf_dataset( + cases_dataset, cases_shape, column_index=0) + + if len(first_cases) >= 2: + # the cases dataset has two columns + assert labels_dataset is None, ( + "The second column of `cases_dataset` is assumed to be the labels. "\ + + "Hence, `labels_dataset` should be empty." + ) + + # manage labels dataset (extract them from the second column of `cases_dataset`) + labels_shape = (None,) + first_cases[1].shape[1:] + labels_dataset = convert_column_dataloader_to_tf_dataset( + cases_dataset, labels_shape, column_index=1) + + if len(first_cases) == 3: + # the cases dataset has three columns + assert targets_dataset is None, ( + "The second and third columns of `cases_dataset` are assumed to be the labels "\ + "and targets. Hence, `labels_dataset` and `targets_dataset` should be empty." + ) + # manage targets dataset (extract them from the third column of `cases_dataset`) + targets_shape = (None,) + first_cases[2].shape[1:] + targets_dataset = convert_column_dataloader_to_tf_dataset( + cases_dataset, targets_shape, column_index=2) + + elif len(first_cases) > 3: + raise AttributeError( + "`cases_dataset` cannot have more than 3 columns, " + + f"{len(first_cases)} were detected." + ) + + # manage labels datasets + if labels_dataset is not None: + if isinstance(labels_dataset, tf.data.Dataset): + pass + elif isinstance(labels_dataset, torch.utils.data.DataLoader): + first_labels = next(iter(labels_dataset)) + if isinstance(first_labels, (tuple, list)): + assert len(first_labels) == 1, ( + "The `labels_dataset` should only have one column. " + + f"{len(first_labels)} were detected." + ) + labels_shape = (None,) + first_labels[0].shape[1:] + labels_dataset = convert_column_dataloader_to_tf_dataset( + labels_dataset, labels_shape, column_index=0 + ) + else: + labels_shape = (None,) + first_labels.shape[1:] + labels_dataset = convert_column_dataloader_to_tf_dataset( + labels_dataset, labels_shape + ) + else: + raise AttributeError( + "The `labels_dataset` should be a PyTorch DataLoader or a TensorFlow Dataset. " + + f"{type(labels_dataset)} was detected." + ) + else: + labels_dataset = None + + # manage targets datasets + if targets_dataset is not None: + if isinstance(targets_dataset, tf.data.Dataset): + pass + elif isinstance(targets_dataset, torch.utils.data.DataLoader): + first_targets = next(iter(targets_dataset)) + if isinstance(first_targets, (tuple, list)): + assert len(first_targets) == 1, ( + "The `targets_dataset` should only have one column. " + + f"{len(first_targets)} were detected." + ) + targets_shape = (None,) + first_targets[0].shape[1:] + targets_dataset = convert_column_dataloader_to_tf_dataset( + targets_dataset, targets_shape, column_index=0 + ) + else: + targets_shape = (None,) + first_targets.shape[1:] + targets_dataset = convert_column_dataloader_to_tf_dataset( + targets_dataset, targets_shape + ) + else: + raise AttributeError( + "The `labels_dataset` should be a PyTorch DataLoader or a TensorFlow Dataset. " + + f"{type(labels_dataset)} was detected." + ) + else: + targets_dataset = None + + return new_cases_dataset, labels_dataset, targets_dataset diff --git a/xplique/example_based/datasets_operations/harmonize.py b/xplique/example_based/datasets_operations/harmonize.py new file mode 100644 index 00000000..01581b17 --- /dev/null +++ b/xplique/example_based/datasets_operations/harmonize.py @@ -0,0 +1,235 @@ +""" +Allow Example-based methods to work with different types of datasets and tensors. +""" + + +import math + +import numpy as np +import tensorflow as tf + +from ...types import Optional, Tuple, DatasetOrTensor +from .tf_dataset_operations import sanitize_dataset, is_batched + + +def split_tf_dataset(cases_dataset: tf.data.Dataset, + labels_dataset: Optional[tf.data.Dataset] = None, + targets_dataset: Optional[tf.data.Dataset] = None + ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]: + """ + Splits a TensorFlow dataset into cases, labels, and targets datasets. + The dataset is splitted only if it has multiple columns. + If the dataset has 2 columns, the second column is assumed to be the labels. + If the dataset has several columns but labels and targets are provided, + there is a conflict and an error is raised. + + Parameters + ---------- + cases_dataset + The dataset to split. + labels_dataset + Labels associated with the cases in the `cases_dataset`. + If this function is called, it should be `None`. + targets_dataset + Targets associated with the cases in the `cases_dataset`. + If this function is called and `cases_dataset` has 3 columns, it should be `None`. + + Returns + ------- + cases_dataset + The dataset used to train the model. + labels_dataset + Labels associated with the `cases_dataset`. + targets_dataset + Targets associated with the `cases_dataset`. + """ + + assert isinstance(cases_dataset, tf.data.Dataset), ( + f"The dataset should be a `tf.data.Dataset`, got {type(cases_dataset)}." + ) + + if isinstance(cases_dataset.element_spec, tuple): + if len(cases_dataset.element_spec) == 2: + assert labels_dataset is None, ( + "The second column of `cases_dataset` is assumed to be the labels. "\ + + "Hence, `labels_dataset` should be empty." + ) + labels_dataset = cases_dataset.map(lambda x, y: y) + cases_dataset = cases_dataset.map(lambda x, y: x) + elif len(cases_dataset.element_spec) == 3: + assert labels_dataset is None and targets_dataset is None, ( + "The second and third columns of `cases_dataset` are assumed to be the labels "\ + "and targets. Hence, `labels_dataset` and `targets_dataset` should be empty." + ) + targets_dataset = cases_dataset.map(lambda x, y, t: t) + labels_dataset = cases_dataset.map(lambda x, y, t: y) + cases_dataset = cases_dataset.map(lambda x, y, t: x) + else: + raise AttributeError( + "`cases_dataset` cannot have more than 3 columns, " + + f"{len(cases_dataset.element_spec)} were detected." + ) + + return cases_dataset, labels_dataset, targets_dataset + + +def harmonize_datasets( + cases_dataset: DatasetOrTensor, + labels_dataset: Optional[DatasetOrTensor] = None, + targets_dataset: Optional[DatasetOrTensor] = None, + batch_size: Optional[int] = None, + ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset, int]: + """ + Harmonizes the provided datasets, transforming them to tf.data.Dataset if necessary. + Datasets are also checked in case they are shuffled or do not match in batch_size. + If the datasets have multiple columns, the function will split them into cases, + labels, and targets datasets based on the number of columns. + + This function supports both TensorFlow and PyTorch datasets. + + Parameters + ---------- + cases_dataset + The dataset used to train the model, examples are extracted from this dataset. + All datasets (cases, labels, and targets) should be of the same type. + Supported types are: `tf.data.Dataset`, `torch.utils.data.DataLoader`, + `tf.Tensor`, `np.ndarray`, `torch.Tensor`. + For datasets with multiple columns, the first column is assumed to be the cases. + While the second column is assumed to be the labels, and the third the targets. + Warning: datasets tend to reshuffle at each iteration, ensure the datasets are + not reshuffle as we use index in the dataset. + labels_dataset + Labels associated with the examples in the `cases_dataset`. + It should have the same type as `cases_dataset`. + targets_dataset + Targets associated with the `cases_dataset` for dataset projection, + oftentimes the one-hot encoding of a model's predictions. See `projection` for detail. + It should have the same type as `cases_dataset`. + It is not be necessary for all projections. + Furthermore, projections which requires it compute it internally by default. + batch_size : Optional[int] + Number of samples treated simultaneously when using the datasets. + It should match the batch size of the datasets if they are batched. + + Returns + ------- + cases_dataset + The harmonized dataset used to train the model. + labels_dataset + Harmonized labels associated with the `cases_dataset`. + targets_dataset + Harmonized targets associated with the `cases_dataset`. + batch_size : int + Number of samples treated simultaneously when using the datasets. + """ + # pylint: disable=too-many-statements + # pylint: disable=too-many-branches + # Ensure the datasets are of the same type + if labels_dataset is not None: + if isinstance(cases_dataset, tf.data.Dataset): + assert isinstance(labels_dataset, tf.data.Dataset), ( + "The labels_dataset should be a `tf.data.Dataset` if the cases_dataset is." + ) + assert not isinstance(labels_dataset.element_spec, tuple), ( + "The labels_dataset should only have one column." + ) + else: + assert isinstance(cases_dataset, type(labels_dataset)), ( + "The cases_dataset and labels_dataset should be of the same type."\ + + f"Got {type(cases_dataset)} and {type(labels_dataset)}." + ) + if targets_dataset is not None: + if isinstance(cases_dataset, tf.data.Dataset): + assert isinstance(targets_dataset, tf.data.Dataset), ( + "The targets_dataset should be a `tf.data.Dataset` if the cases_dataset is." + ) + assert not isinstance(targets_dataset.element_spec, tuple), ( + "The targets_dataset should only have one column." + ) + else: + assert isinstance(cases_dataset, type(targets_dataset)), ( + "The cases_dataset and targets_dataset should be of the same type."\ + + f"Got {type(cases_dataset)} and {type(targets_dataset)}." + ) + + # Determine batch size and cardinality based on the dataset type + # for torch elements, convert them to numpy arrays or tf datasets + if isinstance(cases_dataset, tf.data.Dataset): + # compute batch size and cardinality + if is_batched(cases_dataset): + if isinstance(cases_dataset.element_spec, tuple): + batch_size = tf.shape(next(iter(cases_dataset))[0])[0].numpy() + else: + batch_size = tf.shape(next(iter(cases_dataset)))[0].numpy() + else: + assert batch_size is not None, ( + "The dataset is not batched, hence a `batch_size` should be provided." + ) + cases_dataset = cases_dataset.batch(batch_size) + cardinality = cases_dataset.cardinality().numpy() + + # handle multi-column datasets + if isinstance(cases_dataset.element_spec, tuple): + # split dataset if `cases_dataset` has multiple columns + cases_dataset, labels_dataset, targets_dataset =\ + split_tf_dataset(cases_dataset, labels_dataset, targets_dataset) + elif isinstance(cases_dataset, (np.ndarray, tf.Tensor)): + # compute batch size and cardinality + if batch_size is None: + # no batching, one batch encompass all the dataset + batch_size = cases_dataset.shape[0] + else: + batch_size = min(batch_size, cases_dataset.shape[0]) + cardinality = math.ceil(cases_dataset.shape[0] / batch_size) + + # tensors will be converted to tf.data.Dataset via the snitize function + else: + error_message = "Unknown cases dataset type, should be in: [tf.data.Dataset, tf.Tensor, "\ + + "np.ndarray, torch.Tensor, torch.utils.data.DataLoader]. "\ + + f"But got {type(cases_dataset)} instead." + # try to import torch and torch.utils.data.DataLoader to treat possible input types + try: + # pylint: disable=import-outside-toplevel + import torch + from .convert_torch_to_tf import split_and_convert_column_dataloader + except ImportError as exc: + raise AttributeError(error_message) from exc + + if isinstance(cases_dataset, torch.Tensor): + # compute batch size and cardinality + if batch_size is None: + # no batching, one batch encompass all the dataset + batch_size = cases_dataset.shape[0] + else: + batch_size = min(batch_size, cases_dataset.shape[0]) + cardinality = math.ceil(cases_dataset.shape[0] / batch_size) + + # convert torch tensor to numpy array + cases_dataset = cases_dataset.cpu().numpy() + if labels_dataset is not None: + labels_dataset = labels_dataset.cpu().numpy() + if targets_dataset is not None: + targets_dataset = targets_dataset.cpu().numpy() + + # tensors will be converted to tf.data.Dataset via the snitize function + elif isinstance(cases_dataset, torch.utils.data.DataLoader): + batch_size = cases_dataset.batch_size + cardinality = len(cases_dataset) + cases_dataset, labels_dataset, targets_dataset =\ + split_and_convert_column_dataloader(cases_dataset, labels_dataset, targets_dataset) + else: + raise AttributeError(error_message) + + # Sanitize datasets to ensure they are in the correct format + cases_dataset = sanitize_dataset(cases_dataset, batch_size, cardinality) + labels_dataset = sanitize_dataset(labels_dataset, batch_size, cardinality) + targets_dataset = sanitize_dataset(targets_dataset, batch_size, cardinality) + + # Prefetch datasets + cases_dataset = cases_dataset.prefetch(tf.data.AUTOTUNE) + if labels_dataset is not None: + labels_dataset = labels_dataset.prefetch(tf.data.AUTOTUNE) + if targets_dataset is not None: + targets_dataset = targets_dataset.prefetch(tf.data.AUTOTUNE) + + return cases_dataset, labels_dataset, targets_dataset, batch_size diff --git a/xplique/example_based/datasets_operations/tf_dataset_operations.py b/xplique/example_based/datasets_operations/tf_dataset_operations.py new file mode 100644 index 00000000..b975dca7 --- /dev/null +++ b/xplique/example_based/datasets_operations/tf_dataset_operations.py @@ -0,0 +1,298 @@ +""" +Set of functions to manipulated `tf.data.Dataset` +""" +from itertools import product + +import numpy as np +import tensorflow as tf + +from ...types import Optional, Union + + +def _almost_equal(arr1, arr2, epsilon=1e-6): + """Ensure two array are almost equal at an epsilon""" + return np.shape(arr1) == np.shape(arr2) and np.sum(np.abs(arr1 - arr2)) < epsilon + + +def are_dataset_first_elems_equal(dataset1: Optional[tf.data.Dataset] = None, + dataset2: Optional[tf.data.Dataset] = None, + ) -> bool: + """ + Test if the first batch of elements of two datasets are the same. + It is used to verify equality between datasets in a lazy way. + + Parameters + ---------- + dataset1 + First `tf.data.Dataset` to compare. + dataset2 + Second `tf.data.Dataset` to compare. + + Returns + ------- + test_result + Boolean value of the equality. + """ + if dataset1 is None: + return dataset2 is None + + if dataset2 is None: + return False + + next1 = next(iter(dataset1)) + next2 = next(iter(dataset2)) + if isinstance(next1, tuple): + next1 = next1[0] + if isinstance(next2, tuple): + next2 = next2[0] + else: + return False + + return _almost_equal(next1, next2) + + +def is_batched(dataset: tf.data.Dataset) -> bool: + """ + Check if a TensorFlow dataset is batched. + + Parameters + ---------- + dataset : tf.data.Dataset + The dataset to check. + + Returns + ------- + bool + True if the dataset is batched, False otherwise. + """ + # Extract the element_spec + spec = dataset.element_spec + + # Handle datasets with tuple or dict structures + if isinstance(spec, (tuple, dict)): + # Check if any part of the element_spec is batched + if isinstance(spec, tuple): + return all(s.shape[0] is None for s in spec) + if isinstance(spec, dict): + return all(s.shape[0] is None for s in spec.values()) + else: + # Check if the first dimension is None (indicating batching) + return spec.shape[0] is None + + # If we reach here, it's not batched + return False + + +def is_shuffled(dataset: Optional[tf.data.Dataset]) -> bool: + """ + Test if the provided dataset reshuffle at each iteration. + Tensorflow do not provide clean way to verify it, + hence we draw two times the first element and compare it. + It may not always detect shuffled datasets, but this is enough of a safety net. + + Parameters + ---------- + dataset + Tensorflow dataset to test. + + Returns + ------- + test_result + Boolean value of the test. + """ + if are_dataset_first_elems_equal(dataset, dataset): + # test a second time to minimize the risk of false positive + return not are_dataset_first_elems_equal(dataset, dataset) + return True + + +def batch_size_matches(dataset: Optional[tf.data.Dataset], batch_size: int) -> bool: + """ + Test if batch size of a tensorflow dataset matches the expected one. + Tensorflow do not provide clean way to verify it, + hence we draw a batch and check its first dimension. + It may fail in some really precise cases, but this is enough of a safety net. + + Parameters + ---------- + dataset + Tensorflow dataset to test. + batch_size + The expected batch size of the dataset. + + Returns + ------- + test_result + Boolean value of the test. + """ + if dataset is None: + # ignored + return True + + if not is_batched(dataset): + return False + + first_item = next(iter(dataset)) + if isinstance(first_item, tuple): + return tf.reduce_all( + [tf.shape(item)[0].numpy() == batch_size for item in first_item] + ) + return tf.shape(first_item)[0].numpy() == batch_size + + +def sanitize_dataset( + dataset: Union[tf.data.Dataset, tf.Tensor, np.array], + batch_size: int, + cardinality: Optional[int] = None, +) -> Optional[tf.data.Dataset]: + """ + Function to ensure input dataset match expected format. + It also transforms tensors in `tf.data.Dataset` and also verify the properties. + This function verify that datasets do not reshuffle at each iteration and + that their batch and cardinality match the expected ones. + Note that, that Tensorflow do not provide easy way to make those tests, hence, + for cost constraints, our tests are not perfect. + + Parameters + ---------- + dataset + Tensorflow dataset to verify or tensor to transform in `tf.data.Dataset` and verify. + batch_size + The expected batch size used either to verify the input dataset + or batch the transformed tensor. + cardinality + Expected number of batch in the dataset or batched transformed tensor. + + Returns + ------- + dataset + Verified dataset or transformed tensor. In both case a `tf.data.Dataset`, + that does not reshuffle at each iteration and + with batch size and cardinality matching the expected ones. + """ + if dataset is not None: + if isinstance(dataset, tf.data.Dataset): + assert not is_shuffled(dataset), ( + "Datasets should not be shuffled, " + + "the order of the element should stay the same at each iteration." + ) + if not is_batched(dataset): + dataset = dataset.batch(batch_size) + else: + assert batch_size_matches( + dataset, batch_size + ), "The batch size should match between datasets." + elif isinstance(dataset, (tf.Tensor, np.ndarray)): + dataset = tf.data.Dataset.from_tensor_slices(dataset).batch(batch_size) + else: + raise ValueError( + "The input dataset should be a `tf.data.Dataset`, a `tf.Tensor` or a `np.array`. " + + f"Received {type(dataset)}." + ) + + if cardinality is not None and cardinality > 0: + dataset_cardinality = dataset.cardinality().numpy() + if dataset_cardinality > 0: + assert dataset_cardinality == cardinality, ( + "The number of batch should match between datasets. " + + f"Received {dataset.cardinality().numpy()} vs {cardinality}. " + + "You may have provided non-batched datasets "\ + + "or datasets with different lengths." + ) + else: + # negative cardinality means unknown cardinality, + # it will be the case for datasets created from generator, thus torch converted ones + pass + + return dataset + + +def dataset_gather(dataset: tf.data.Dataset, indices: tf.Tensor) -> tf.Tensor: + """ + Imitation of `tf.gather` for `tf.data.Dataset`, + it extracts elements from `dataset` at the given indices. + We could see it as returning the `indices` tensor + where each index was replaced by the corresponding element in `dataset`. + The aim is to use it in the `example_based` module to extract examples form the cases dataset. + Hence, `indices` expect dimensions of (n, k, 2), + where n represent the number of inputs and k the number of corresponding examples. + Here indices for each element are encoded by two values, + the batch index and the index of the element in the batch. + + Example of application + ``` + >>> dataset = tf.data.Dataset.from_tensor_slices( + ... tf.reshape(tf.range(20), (-1, 2, 2)) + ... ).batch(3) # shape=(None, 2, 2) + >>> indices = tf.constant([[[0, 0]], [[1, 0]]]) # shape=(2, 1, 2) + >>> dataset_gather(dataset, indices) + + ``` + + Parameters + ---------- + dataset + The dataset from which to extract elements. + indices + Tensor of indices of elements to extract from the `dataset`. + `indices` should be of dimensions (n, k, 2), + this is to match the format of indices in the `example_based` module. + Indeed, n represent the number of inputs and k the number of corresponding examples. + The index of each element is encoded by two values, + the batch index and the index of the element in the batch. + + Returns + ------- + results + A tensor with the extracted elements from the `dataset`. + The shape of the tensor is (n, k, ...), + where ... is the shape of the elements in the `dataset`. + """ + if dataset is None: + return None + + if len(indices.shape) != 3 or indices.shape[-1] != 2: + raise ValueError( + "Indices should have dimensions (n, k, 2), "\ + + "where n represent the number of inputs and k the number of corresponding examples. "\ + + "The index of each element is encoded by two values, "\ + + "the batch index and the index of the element in the batch. "\ + + f"Received {indices.shape}." + ) + + example = next(iter(dataset)) + + if dataset.element_spec.dtype in ['uint8', 'int8', 'int16', 'int32', 'int64']: + results = tf.fill(dims=indices.shape[:-1] + example[0].shape, + value=tf.constant(-1, dtype=dataset.element_spec.dtype)) + else: + results = tf.fill(dims=indices.shape[:-1] + example[0].shape, + value=tf.constant(np.inf, dtype=dataset.element_spec.dtype)) + + nb_results = product(indices.shape[:-1]) + current_nb_results = 0 + + for i, batch in enumerate(dataset): + # check if the batch is interesting + if not tf.reduce_any(indices[..., 0] == i): + continue + + # extract pertinent elements + pertinent_indices_location = tf.where(indices[..., 0] == i) + samples_index = tf.gather_nd(indices[..., 1], pertinent_indices_location) + samples = tf.gather(batch, samples_index) + + # put them at the right place in results + for location, sample in zip(pertinent_indices_location, samples): + results = tf.tensor_scatter_nd_update(results, [location], [sample]) + current_nb_results += 1 + + # test if results are filled to break the loop + if current_nb_results == nb_results: + break + + return results diff --git a/xplique/example_based/projections/__init__.py b/xplique/example_based/projections/__init__.py new file mode 100644 index 00000000..4b33a895 --- /dev/null +++ b/xplique/example_based/projections/__init__.py @@ -0,0 +1,8 @@ +""" +Projections +""" + +from .attributions import AttributionProjection +from .base import Projection +from .hadamard import HadamardProjection +from .latent_space import LatentSpaceProjection diff --git a/xplique/example_based/projections/attributions.py b/xplique/example_based/projections/attributions.py new file mode 100644 index 00000000..78207ede --- /dev/null +++ b/xplique/example_based/projections/attributions.py @@ -0,0 +1,87 @@ +""" +Attribution, a projection from example based module +""" +import warnings + +import tensorflow as tf + +from xplique.types import Optional + +from ...attributions.base import BlackBoxExplainer +from ...attributions import Saliency +from ...types import Union, Optional + +from .base import Projection +from .commons import model_splitting, target_free_classification_operator + + +class AttributionProjection(Projection): + """ + Projection build on an attribution function to provide local projections. + This class is used as the projection of the `Cole` similar examples method. + + Depending on the `latent_layer`, the model will be splitted between + the feature extractor and the predictor. + The feature extractor will become the `space_projection()` method, then + the predictor will be used to build the attribution method explain, and + its `explain()` method will become the `get_weights()` method. + + If no `latent_layer` is provided, the model is not splitted, + the `space_projection()` is the identity function, and + the attributions (`get_weights()`) are compute on the whole model. + + Parameters + ---------- + model + The model from which we want to obtain explanations. + latent_layer + Layer used to split the model, the first part will be used for projection and + the second to compute the attributions. By default, the model is not split. + For such split, the `model` should be a `tf.keras.Model`. + + If an `int` is provided it will be interpreted as a layer index. + If a `string` is provided it will look for the layer name. + + The method as described in the paper apply the separation on the last convolutional layer. + To do so, the `"last_conv"` parameter will extract it. + Otherwise, `-1` could be used for the last layer before softmax. + attribution_method + Class of the attribution method to use for projection. + It should inherit from `xplique.attributions.base.BlackBoxExplainer`. + Ignored if a projection is given. + attribution_kwargs + Parameters to be passed at the construction of the `attribution_method`. + """ + + def __init__( + self, + model: Union[tf.keras.Model, 'torch.nn.Module'], + attribution_method: BlackBoxExplainer = Saliency, + latent_layer: Optional[Union[str, int]] = None, + **attribution_kwargs + ): + self.attribution_method = attribution_method + + if latent_layer is None: + # no split + self.latent_layer = None + space_projection = None + self.predictor = model + else: + # split the model if a latent_layer is provided + space_projection, self.predictor = model_splitting(model, latent_layer) + + # change default operator + if "operator" not in attribution_kwargs or attribution_kwargs["operator"] is None: + warnings.warn("No operator provided, using standard classification operator. "\ + + "For non-classification tasks, please specify an operator.") + attribution_kwargs["operator"] = target_free_classification_operator + + # compute attributions + get_weights = self.attribution_method(self.predictor, **attribution_kwargs) + + # set methods + super().__init__(get_weights=get_weights, + space_projection=space_projection, + mappable=False, + requires_targets=True) diff --git a/xplique/example_based/projections/base.py b/xplique/example_based/projections/base.py new file mode 100644 index 00000000..1d3a1345 --- /dev/null +++ b/xplique/example_based/projections/base.py @@ -0,0 +1,257 @@ +""" +Base projection for similar examples in example based module +""" + +import warnings + +import tensorflow as tf +import numpy as np + +from ...commons import sanitize_inputs_targets, get_device +from ...types import Callable, Union, Optional + + +class Projection(): + """ + Base class used by `BaseExampleMethod` to project samples to a meaningful space + for the model to explain. + + Projection have two parts a `space_projection` and `weights`, to apply a projection, + the samples are first projected to a new space and then weighted. + Either the `space_projection` or the `weights` could be `None` but, + if both are, the projection is an identity function. + + At least one of the two part should include the model in the computation + for distance between projected elements to make sense for the model. + + Note that the cost of this projection should be limited + as it will be applied to all samples of the train dataset. + + Parameters + ---------- + get_weights + Either a Tensor or a Callable. + - In the case of a Tensor, weights are applied in the projected space. + - In the case of a callable, a function is expected. + It should take inputs and targets as parameters and return the weights (Tensor). + Weights should have the same shape as the input (possible difference on channels). + The inputs of `get_weights()` correspond to the projected inputs. + + Example of `get_weights()` function: + ``` + def get_weights_example(projected_inputs: Union(tf.Tensor, np.ndarray), + targets: Optional[Union[tf.Tensor, np.ndarray]] = None): + ''' + Example of function to get weights, + projected_inputs are the elements for which weights are computed. + targets are optional additional parameters for weights computation. + ''' + weights = ... # do some magic with inputs and targets, it should use the model. + return weights + ``` + space_projection + Callable that take samples and return a Tensor in the projected space. + An example of projected space is the latent space of a model. See `LatentSpaceProjection` + device + Device to use for the projection, if None, use the default device. + mappable + If True, the projection can be applied to a `tf.data.Dataset` through `Dataset.map`. + Otherwise, the dataset projection will be done through a loop. + It is not the case for wrapped PyTorch models. + If you encounter errors in the `project_dataset` method, you can set it to `False`. + """ + + def __init__(self, + get_weights: Optional[Union[Callable, tf.Tensor, np.ndarray]] = None, + space_projection: Optional[Callable] = None, + device: Optional[str] = None, + mappable: bool = False, + requires_targets: bool = False): + if get_weights is None and space_projection is None: + warnings.warn( + "At least one of `get_weights` and `space_projection`" + + "should not be `None`. Otherwise the projection is an identity function." + ) + + self.requires_targets = requires_targets + + # set get_weights + if get_weights is None: + # no weights + self.get_weights = lambda inputs, _: tf.ones(tf.shape(inputs)) + elif isinstance(get_weights, (tf.Tensor, np.ndarray)): + # weights is a tensor + if isinstance(get_weights, np.ndarray): + weights = tf.convert_to_tensor(get_weights, dtype=tf.float32) + mappable = False + else: + weights = get_weights + + # define a function that returns the weights + self.get_weights = lambda inputs, _: tf.repeat(tf.expand_dims(weights, axis=0), + tf.shape(inputs)[0], + axis=0) + elif hasattr(get_weights, "__call__"): + # weights is a function + self.get_weights = get_weights + else: + raise TypeError( + f"`get_weights` should be `Callable` or a Tensor, not a {type(get_weights)}" + ) + + # set space_projection + if space_projection is None: + self.space_projection = lambda inputs: inputs + elif hasattr(space_projection, "__call__"): + self.space_projection = space_projection + else: + raise TypeError( + f"`space_projection` should be a `Callable`, not a {type(space_projection)}" + ) + + self.mappable = mappable + + # set device + self.device = get_device(device) + + @sanitize_inputs_targets + def project( + self, + inputs: Union[tf.Tensor, np.ndarray], + targets: Optional[Union[tf.Tensor, np.ndarray]] = None, + ): + """ + Project samples in a space meaningful for the model, + either by weights the inputs, projecting in a latent space or both. + This function should be called at the init and for each explanation. + + Parameters + ---------- + inputs + Tensor or Array. Input samples to be explained. + Expected shape among (N, W), (N, T, W), (N, W, H, C). + More information in the documentation. + targets + Additional parameter for `self.get_weights` function. + + Returns + ------- + projected_samples + The samples projected in the new space. + """ + with tf.device(self.device): + projected_inputs = self.space_projection(inputs) + weights = self.get_weights(projected_inputs, targets) + weighted_projected_inputs = tf.multiply(weights, projected_inputs) + return weighted_projected_inputs + + def __call__( + self, + inputs: Union[tf.Tensor, np.ndarray], + targets: Optional[Union[tf.Tensor, np.ndarray]] = None, + ): + """project alias""" + return self.project(inputs, targets) + + def project_dataset( + self, + cases_dataset: tf.data.Dataset, + targets_dataset: Optional[tf.data.Dataset] = None, + ) -> Optional[tf.data.Dataset]: + """ + Apply the projection to a dataset through `Dataset.map` + + Parameters + ---------- + cases_dataset + Dataset of samples to be projected. + targets_dataset + Dataset of targets for the samples. + + Returns + ------- + projected_dataset + The projected dataset. + """ + if self.requires_targets and targets_dataset is None: + warnings.warn( + "The projection requires `targets` but `targets_dataset` is not provided. "\ + +"`targets` will be computed online, assuming a classification setting. "\ + +"Hence, online `targets` will be the predicted class one-hot-encoding. "\ + +"If this is not the expected behavior, please provide a `targets_dataset`.") + + if self.mappable: + return self._map_project_dataset(cases_dataset, targets_dataset) + return self._loop_project_dataset(cases_dataset, targets_dataset) + + def _map_project_dataset( + self, + cases_dataset: tf.data.Dataset, + targets_dataset: Optional[tf.data.Dataset] = None, + ) -> Optional[tf.data.Dataset]: + """ + Apply the projection to a dataset through `Dataset.map` + + Parameters + ---------- + cases_dataset + Dataset of samples to be projected. + targets_dataset + Dataset of targets for the samples. + + Returns + ------- + projected_dataset + The projected dataset. + """ + # project dataset, note that projection is done at iteration time + if targets_dataset is None: + projected_cases_dataset = cases_dataset.map(self.project) + else: + # in case targets are provided, we zip the datasets and project them together + projected_cases_dataset = tf.data.Dataset.zip( + (cases_dataset, targets_dataset) + ).map(self.project) + + return projected_cases_dataset + + def _loop_project_dataset( + self, + cases_dataset: tf.data.Dataset, + targets_dataset: tf.data.Dataset, + ) -> tf.data.Dataset: + """ + Apply the projection to a dataset without `Dataset.map`. + Because some projections are not compatible with a `tf.data.Dataset.map`. + For example, the attribution methods, because they create a `tf.data.Dataset` for batching, + however doing so inside a `Dataset.map` is not recommended. + + Parameters + ---------- + cases_dataset + Dataset of samples to be projected. + targets_dataset + Dataset of targets for the samples. + + Returns + ------- + projected_dataset + The projected dataset. + """ + projected_cases_dataset = [] + batch_size = tf.shape(next(iter(cases_dataset)))[0].numpy() + + # iteratively project the dataset + if targets_dataset is None: + for inputs in cases_dataset: + projected_cases_dataset.append(self.project(inputs, None)) + else: + # in case targets are provided, we zip the datasets and project them together + for inputs, targets in tf.data.Dataset.zip((cases_dataset, targets_dataset)): + projected_cases_dataset.append(self.project(inputs, targets)) + + projected_cases_dataset = tf.concat(projected_cases_dataset, axis=0) + projected_cases_dataset = tf.data.Dataset.from_tensor_slices(projected_cases_dataset) + projected_cases_dataset = projected_cases_dataset.batch(batch_size) + + return projected_cases_dataset diff --git a/xplique/example_based/projections/commons.py b/xplique/example_based/projections/commons.py new file mode 100644 index 00000000..e5a93ac6 --- /dev/null +++ b/xplique/example_based/projections/commons.py @@ -0,0 +1,246 @@ +""" +Commons for projections +""" +import warnings + +import tensorflow as tf + +from ...commons import find_layer +from ...types import Callable, Union, Optional, Tuple + + +def model_splitting( + model: Union[tf.keras.Model, 'torch.nn.Module'], + latent_layer: Union[str, int], + device: Union["torch.device", str] = None, + ) -> Tuple[Union[tf.keras.Model, 'torch.nn.Module'], Union[tf.keras.Model, 'torch.nn.Module']]: + """ + Split the model into two parts, before and after the `latent_layer`. + The parts will respectively be called `features_extractor` and `predictor`. + + Parameters + ---------- + model + Model to split. + latent_layer + Layer used to split the `model`. + + If an `int` is provided it will be interpreted as a layer index. + If a `string` is provided it will look for the layer name. + + To separate after the last convolution, `"last_conv"` can be used. + Otherwise, `-1` could be used for the last layer before softmax. + device + Device to use for the projection, if None, use the default device. + Only used for PyTorch models. Ignored for TensorFlow models. + + Returns + ------- + features_extractor + Model used to project the inputs. + predictor + Model used to compute the attributions. + latent_layer + Layer used to split the `model`. + """ + if isinstance(model, tf.keras.Model): + return _tf_model_splitting(model, latent_layer) + try: + return _torch_model_splitting(model, latent_layer, device) + except ImportError as exc: + raise AttributeError( + "Unknown model type, should be either `tf.keras.Model` or `torch.nn.Module`. "\ + +f"But got {type(model)} instead.") from exc + + +def _tf_model_splitting(model: tf.keras.Model, + latent_layer: Union[str, int], + ) -> Tuple[tf.keras.Model, tf.keras.Model]: + """ + Split the model into two parts, before and after the `latent_layer`. + The parts will respectively be called `features_extractor` and `predictor`. + + Parameters + ---------- + model + Model to split. + latent_layer + Layer used to split the `model`. + + If an `int` is provided it will be interpreted as a layer index. + If a `string` is provided it will look for the layer name. + + To separate after the last convolution, `"last_conv"` can be used. + Otherwise, `-1` could be used for the last layer before softmax. + + Returns + ------- + features_extractor + Model used to project the inputs. + predictor + Model used to compute the attributions. + latent_layer + Layer used to split the `model`. + """ + + warnings.warn( + "Automatically splitting the provided TensorFlow model into two parts. "\ + +"This splitting is not robust to all models. "\ + +"It is recommended to split the model manually. "\ + +"Then the splitted parts can be provided at the method initialization.") + + if latent_layer == "last_conv": + latent_layer = next( + layer for layer in model.layers[::-1] if hasattr(layer, "filters") + ) + else: + latent_layer = find_layer(model, latent_layer) + + features_extractor = tf.keras.Model( + model.input, latent_layer.output, name="features_extractor" + ) + second_input = tf.keras.Input(shape=latent_layer.output_shape[1:]) + + # Reconstruct the second part of the model + new_input = second_input + layer_found = False + for layer in model.layers: + if layer_found: + new_input = layer(new_input) + if layer == latent_layer: + layer_found = True + + # Create the second part of the model (predictor) + predictor = tf.keras.Model( + inputs=second_input, + outputs=new_input, + name="predictor" + ) + + return features_extractor, predictor + + +def _torch_model_splitting( + model: 'torch.nn.Module', + latent_layer: Union[str, int], + device: Union["torch.device", str] = None, + ) -> Tuple['torch.nn.Module', 'torch.nn.Module']: + """ + Split the model into two parts, before and after the `latent_layer`. + The parts will respectively be called `features_extractor` and `predictor`. + + Parameters + ---------- + model + Model to split. + latent_layer + Layer used to split the `model`. + + If an `int` is provided it will be interpreted as a layer index. + If a `string` is provided it will look for the layer name. + + To separate after the last convolution, `"last_conv"` can be used. + Otherwise, `-1` could be used for the last layer before softmax. + Device to use for the projection, if None, use the default device. + + Returns + ------- + features_extractor + Model used to project the inputs. + predictor + Model used to compute the attributions. + latent_layer + Layer used to split the `model`. + """ + # pylint: disable=import-outside-toplevel + import torch + from torch import nn + from ...wrappers import TorchWrapper + + warnings.warn( + "Automatically splitting the provided PyTorch model into two parts. "\ + +"This splitting is based on `model.named_children()`. "\ + +"If the model cannot be reconstructed via sub-modules, errors are to be expected. "\ + +"It is recommended to split the model manually and wrap it with `TorchWrapper`. "\ + +"Then the wrapped parts can be provided at the method initialization.") + + if device is None: + warnings.warn( + "No device provided for the projection, using 'cuda' if available, else 'cpu'." + ) + device = "cuda" if torch.cuda.is_available() else "cpu" + + first_model = nn.Sequential() + second_model = nn.Sequential() + split_flag = False + + if isinstance(latent_layer, int) and latent_layer < 0: + latent_layer = len(list(model.children())) + latent_layer + + for layer_index, (name, module) in enumerate(model.named_children()): + if latent_layer in [layer_index, name]: + split_flag = True + + if not split_flag: + first_model.add_module(name, module) + else: + second_model.add_module(name, module) + + # Define forward function for the first model + def first_model_forward(new_input): + for module in first_model: + new_input = module(new_input) + return new_input + + # Define forward function for the second model + def second_model_forward(new_input): + for module in second_model: + new_input = module(new_input) + return new_input + + # Set the forward functions for the models + first_model.forward = first_model_forward + second_model.forward = second_model_forward + + # Wrap models to obtain tensorflow ones + first_model.eval() + wrapped_first_model = TorchWrapper(first_model, device=device) + second_model.eval() + wrapped_second_model = TorchWrapper(second_model, device=device) + + return wrapped_first_model, wrapped_second_model + + +@tf.function +def target_free_classification_operator(model: Callable, + inputs: tf.Tensor, + targets: Optional[tf.Tensor] = None) -> tf.Tensor: + """ + Compute predictions scores, only for the label class, for a batch of samples. + It has the same behavior as `Tasks.CLASSIFICATION` operator + but computes targets at the same time if not provided. + Targets are a mask with 1 on the predicted class and 0 elsewhere. + This operator should only be used for classification tasks. + + + Parameters + ---------- + model + Model used for computing predictions. + inputs + Input samples to be explained. + targets + One-hot encoded labels or regression target (e.g {+1, -1}), one for each sample. + + Returns + ------- + scores + Predictions scores computed, only for the label class. + """ + predictions = model(inputs) + + # the condition is always the same, hence this should not affect the graph + if targets is None: + targets = tf.one_hot(tf.argmax(predictions, axis=-1), predictions.shape[-1]) + + return tf.reduce_sum(predictions * targets, axis=-1) diff --git a/xplique/example_based/projections/hadamard.py b/xplique/example_based/projections/hadamard.py new file mode 100644 index 00000000..fe71f465 --- /dev/null +++ b/xplique/example_based/projections/hadamard.py @@ -0,0 +1,123 @@ +""" +Attribution, a projection from example based module +""" +import warnings + +import tensorflow as tf +from xplique.types import Optional + +from ...commons import get_gradient_functions +from ...types import Union, Optional, OperatorSignature + +from .base import Projection +from .commons import model_splitting, target_free_classification_operator + + +class HadamardProjection(Projection): + """ + Projection build on an the latent space and the gradient. + This class is used as the projection of the `Cole` similar examples method. + + Depending on the `latent_layer`, the model will be splitted between + the feature extractor and the predictor. + The feature extractor will become the `space_projection()` method, then + the predictor will be used to build the attribution method explain, and + its `explain()` method will become the `get_weights()` method. + + If no `latent_layer` is provided, the model is not splitted, + the `space_projection()` is the identity function, and + the attributions (`get_weights()`) are compute on the whole model. + + Parameters + ---------- + model + The model from which we want to obtain explanations. + It can be splitted manually outside of the projection and provided as two models: + the `feature_extractor` and the `predictor`. In this case, `model` should be `None`. + It is recommended to split it manually. + latent_layer + Layer used to split the model, the first part will be used for projection and + the second to compute the attributions. By default, the model is not split. + For such split, the `model` should be a `tf.keras.Model`. + Ignored if `model` is `None`, hence if a splitted model is provided through: + the `feature_extractor` and the `predictor`. + + If an `int` is provided it will be interpreted as a layer index. + If a `string` is provided it will look for the layer name. + + The method as described in the paper apply the separation on the last convolutional layer. + To do so, the `"last_conv"` parameter will extract it. + Otherwise, `-1` could be used for the last layer before softmax. + operator + Operator to use to compute the explanation, if None use standard predictions. + The default operator is the classification operator with online targets computations. + For more information, refer to the Attribution documentation. + device + Device to use for the projection, if None, use the default device. + Only used for PyTorch models. Ignored for TensorFlow models. + features_extractor + The feature extraction part of the model. Mapping inputs to the latent space. + Used to provided the first part of a splitted model. + It cannot be provided if a `model` is provided. It should be provided with a `predictor`. + predictor + The prediction part of the model. Mapping the latent space to the outputs. + Used to provided the second part of a splitted model. + It cannot be provided if a `model` is provided. + It should be provided with a `features_extractor`. + mappable + If the model parts can be placed in a `tf.data.Dataset` mapping function. + It is not the case for wrapped PyTorch models. + If you encounter errors in the `project_dataset` method, you can set it to `False`. + Used only for a splitted model. Thgus if `model` is `None`. + """ + def __init__( + self, + model: Optional[Union[tf.keras.Model, 'torch.nn.Module']] = None, + latent_layer: Optional[Union[str, int]] = None, + operator: Optional[OperatorSignature] = None, + device: Union["torch.device", str] = None, + features_extractor: Optional[tf.keras.Model] = None, + predictor: Optional[tf.keras.Model] = None, + mappable: bool = True, + ): + if model is None: + assert features_extractor is not None and predictor is not None,\ + "If no model is provided, the features_extractor and predictor should be provided." + + assert isinstance(features_extractor, tf.keras.Model)\ + and isinstance(predictor, tf.keras.Model),\ + "The features_extractor and predictor should be tf.keras.Model."\ + + "The xplique.wrappers.TorchWrapper can be used for PyTorch models." + else: + assert features_extractor is None and predictor is None,\ + "If a model is provided, the features_extractor and predictor cannot be provided." + + if latent_layer is None: + # no split + self.latent_layer = None + features_extractor = None + predictor = model + else: + # split the model if a latent_layer is provided + features_extractor, predictor = model_splitting(model, + latent_layer=latent_layer, + device=device) + + mappable = isinstance(model, tf.keras.Model) + + if operator is None: + warnings.warn("No operator provided, using standard classification operator. "\ + + "For non-classification tasks, please specify an operator.") + operator = target_free_classification_operator + + # the weights are given by the gradient of the operator based on the predictor + gradients, _ = get_gradient_functions(predictor, operator) + get_weights = lambda inputs, targets: gradients(predictor, inputs, targets) + + # set methods + super().__init__( + get_weights=get_weights, + space_projection=features_extractor, + mappable=mappable, + requires_targets=True + ) diff --git a/xplique/example_based/projections/latent_space.py b/xplique/example_based/projections/latent_space.py new file mode 100644 index 00000000..0d7a8db8 --- /dev/null +++ b/xplique/example_based/projections/latent_space.py @@ -0,0 +1,61 @@ +""" +Custom, a projection from example based module +""" + +import tensorflow as tf + +from ...types import Union + +from .base import Projection +from .commons import model_splitting + + +class LatentSpaceProjection(Projection): + """ + Projection that project inputs in the model latent space. + It does not have weighting. + + Parameters + ---------- + model + The model from which we want to obtain explanations. + It will be splitted if a `latent_layer` is provided. + Otherwise, it should be a `tf.keras.Model`. + It is recommended to split it manually and provide the first part of the model directly. + latent_layer + Layer used to split the `model`. + + If an `int` is provided it will be interpreted as a layer index. + If a `string` is provided it will look for the layer name. + + To separate after the last convolution, `"last_conv"` can be used. + Otherwise, `-1` could be used for the last layer before softmax. + device + Device to use for the projection, if None, use the default device. + Only used for PyTorch models. Ignored for TensorFlow models. + mappable + Used only if not `latent_layer` is provided. Thus if the model is already splitted. + If the model can be placed in a `tf.data.Dataset` mapping function. + It is not the case for wrapped PyTorch models. + If you encounter errors in the `project_dataset` method, you can set it to `False`. + """ + + def __init__(self, + model: Union[tf.keras.Model, 'torch.nn.Module'], + latent_layer: Union[str, int] = -1, + device: Union["torch.device", str] = None, + mappable: bool = True, + ): + if latent_layer is None: + assert isinstance(model, tf.keras.Model),\ + "If no latent_layer is provided, the model should be a tf.keras.Model." + features_extractor = model + else: + features_extractor, _ = model_splitting(model, latent_layer=latent_layer, device=device) + mappable = isinstance(model, tf.keras.Model) + + super().__init__( + space_projection=features_extractor, + mappable=mappable, + requires_targets=False + ) diff --git a/xplique/example_based/prototypes.py b/xplique/example_based/prototypes.py new file mode 100644 index 00000000..5e058c0d --- /dev/null +++ b/xplique/example_based/prototypes.py @@ -0,0 +1,273 @@ +""" +Base model for prototypes +""" + +from abc import ABC, abstractmethod + +import tensorflow as tf +import numpy as np + +from ..types import Callable, Dict, List, Optional, Type, Union, DatasetOrTensor + +from .datasets_operations.tf_dataset_operations import dataset_gather + +from .search_methods import ProtoGreedySearch, MMDCriticSearch, ProtoDashSearch +from .search_methods import KNN, ORDER +from .projections import Projection +from .base_example_method import BaseExampleMethod + + +class Prototypes(BaseExampleMethod, ABC): + """ + Base class for prototypes. + + Parameters + ---------- + cases_dataset + The dataset used to train the model, examples are extracted from this dataset. + All datasets (cases, labels, and targets) should be of the same type. + Supported types are: `tf.data.Dataset`, `torch.utils.data.DataLoader`, + `tf.Tensor`, `np.ndarray`, `torch.Tensor`. + For datasets with multiple columns, the first column is assumed to be the cases. + While the second column is assumed to be the labels, and the third the targets. + Warning: datasets tend to reshuffle at each iteration, ensure the datasets are + not reshuffle as we use index in the dataset. + labels_dataset + Labels associated with the examples in the `cases_dataset`. + It should have the same type as `cases_dataset`. + targets_dataset + Targets associated with the `cases_dataset` for dataset projection, + oftentimes the one-hot encoding of a model's predictions. See `projection` for detail. + It should have the same type as `cases_dataset`. + It is not be necessary for all projections. + Furthermore, projections which requires it compute it internally by default. + nb_global_prototypes + Number of prototypes to select to explain the dataset or the model. + They define the number of elements returned by the `get_global_prototypes` method. + They have a huge impact on the computation time of the method. + nb_local_prototypes + Number of prototypes to select to explain the decision of the model on given inputs. + They define the number of elements returned by the `explain` method. + (Calling this method do not make sens if `projection` is `None`.) + projection + Projection or Callable that project samples from the input space to the search space. + The search space should be a space where distance make sense for the model. + The output of the projection should be a two dimensional tensor. (nb_samples, nb_features). + If `projection` is `None`, the model is not explained and prototypes represent the dataset. + + Example of Callable: + ``` + def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None): + ''' + Example of projection, + inputs are the elements to project. + targets are optional parameters to orientated the projection. + ''' + projected_inputs = # do some magic on inputs, it should use the model. + return projected_inputs + ``` + case_returns + String or list of string with the elements to return in `self.explain()`. + See `self.set_returns()` for detail. + In the case of prototypes, the indices returned by local search are + the indices of the prototypes in the list of prototypes. + To obtain the indices of the prototypes in the dataset, use `get_global_prototypes`. + batch_size + Number of samples treated simultaneously for projection and search. + Ignored if `tf.data.Dataset` are provided (these are supposed to be batched). + distance + Distance function for examples search. It can be an integer, a string in + {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable. + By default a distance function based on the kernel_fn is used. + kernel_fn : Callable, optional + Kernel function, by default the rbf kernel. + This function must only use TensorFlow operations. + gamma : float, optional + Parameter that determines the spread of the rbf kernel, defaults to 1.0 / n_features. + """ + # pylint: disable=duplicate-code + + def __init__( + self, + cases_dataset: DatasetOrTensor, + labels_dataset: Optional[DatasetOrTensor] = None, + targets_dataset: Optional[DatasetOrTensor] = None, + nb_global_prototypes: int = 1, + nb_local_prototypes: int = 1, + projection: Union[Projection, Callable] = None, + case_returns: Union[List[str], str] = "examples", + batch_size: Optional[int] = None, + distance: Optional[Union[int, str, Callable]] = None, + kernel_fn: callable = None, + gamma: float = None + ): + # set common example-based parameters + super().__init__( + cases_dataset=cases_dataset, + labels_dataset=labels_dataset, + targets_dataset=targets_dataset, + k=nb_local_prototypes, + projection=projection, + case_returns=case_returns, + batch_size=batch_size, + ) + + # initiate search_method and search global prototypes + self.global_prototypes_search_method = self.search_method_class( + cases_dataset=self.projected_cases_dataset, + batch_size=self.batch_size, + nb_prototypes=nb_global_prototypes, + kernel_fn=kernel_fn, + gamma=gamma + ) + + # get global prototypes through the indices found by the search method + self.get_global_prototypes() + + # set knn for local explanations + self.search_method = KNN( + cases_dataset=self.global_prototypes_search_method.prototypes, + search_returns=self._search_returns, + k=self.k, + batch_size=self.batch_size, + distance=self.global_prototypes_search_method._get_distance_fn(distance), + order=ORDER.ASCENDING, + ) + + @property + @abstractmethod + def search_method_class(self) -> Type[ProtoGreedySearch]: + raise NotImplementedError + + def get_global_prototypes(self) -> Dict[str, tf.Tensor]: + """ + Provide the global prototypes computed at the initialization. + Prototypes and their labels are extracted from the indices. + The weights of the prototypes and their indices are also returned. + + Returns + ------- + prototypes_dict : Dict[str, tf.Tensor] + A dictionary with the following + - 'prototypes': The prototypes found by the method. + - 'prototype_labels': The labels of the prototypes. + - 'prototype_weights': The weights of the prototypes. + - 'prototype_indices': The indices of the prototypes. + """ + # pylint: disable=access-member-before-definition + if not hasattr(self, "prototypes") or self.prototypes is None: + assert self.global_prototypes_search_method is not None, ( + "global_prototypes_search_method is not initialized" + ) + assert self.global_prototypes_search_method.prototypes_indices is not None, ( + "prototypes_indices are not initialized" + ) + + # (nb_prototypes, 2) + self.prototypes_indices = self.global_prototypes_search_method.prototypes_indices + indices = self.prototypes_indices[tf.newaxis, ...] + + # (nb_prototypes, ...) + self.prototypes = dataset_gather(self.cases_dataset, indices)[0] + + # (nb_prototypes,) + if self.labels_dataset is not None: + self.prototypes_labels = dataset_gather(self.labels_dataset, indices)[0] + else: + self.prototypes_labels = None + + # (nb_prototypes,) + self.prototypes_weights = self.global_prototypes_search_method.prototypes_weights + + return { + "prototypes": self.prototypes, + "prototypes_labels": self.prototypes_labels, + "prototypes_weights": self.prototypes_weights, + "prototypes_indices": self.prototypes_indices, + } + + def format_search_output( + self, + search_output: Dict[str, tf.Tensor], + inputs: Union[tf.Tensor, np.ndarray], + ): + """ + Format the output of the `search_method` to match the expected returns in `self.returns`. + + Parameters + ---------- + search_output + Dictionary with the required outputs from the `search_method`. + inputs + Tensor or Array. Input samples to be explained. + Expected shape among (N, W), (N, T, W), (N, W, H, C). + + Returns + ------- + return_dict + Dictionary with listed elements in `self.returns`. + The elements that can be returned are defined with the `_returns_possibilities` + static attribute of the class. + """ + # initialize return dictionary + return_dict = {} + + # indices in the list of prototypes + # (n, k) + flatten_indices = search_output["indices"][:, :, 0] * self.batch_size\ + + search_output["indices"][:, :, 1] + flatten_indices = tf.reshape(flatten_indices, [-1]) + + # add examples and weights + if "examples" in self.returns: # or "weights" in self.returns: + # (n * k, ...) + examples = tf.gather(params=self.prototypes, indices=flatten_indices) + # (n, k, ...) + examples = tf.reshape(examples, (inputs.shape[0], self.k) + examples.shape[1:]) + if "include_inputs" in self.returns: + # include inputs + inputs = tf.expand_dims(inputs, axis=1) + examples = tf.concat([inputs, examples], axis=1) + return_dict["examples"] = examples + + # add indices, distances, and labels + if "indices" in self.returns: + # convert indices in the list of prototypes to indices in the dataset + # (n * k, 2) + indices = tf.gather(params=self.prototypes_indices, indices=flatten_indices) + # (n, k, 2) + return_dict["indices"] = tf.reshape(indices, (inputs.shape[0], self.k, 2)) + if "distances" in self.returns: + return_dict["distances"] = search_output["distances"] + if "labels" in self.returns: + assert ( + self.prototypes_labels is not None + ), "The method cannot return labels without a label dataset." + + # (n * k,) + labels = tf.gather(params=self.prototypes_labels, indices=flatten_indices) + # (n, k) + return_dict["labels"] = tf.reshape(labels, (inputs.shape[0], self.k)) + + return return_dict + + +class ProtoGreedy(Prototypes): + # pylint: disable=missing-class-docstring + @property + def search_method_class(self) -> Type[ProtoGreedySearch]: + return ProtoGreedySearch + + +class MMDCritic(Prototypes): + # pylint: disable=missing-class-docstring + @property + def search_method_class(self) -> Type[ProtoGreedySearch]: + return MMDCriticSearch + + +class ProtoDash(Prototypes): + # pylint: disable=missing-class-docstring + @property + def search_method_class(self) -> Type[ProtoGreedySearch]: + return ProtoDashSearch diff --git a/xplique/example_based/search_methods/__init__.py b/xplique/example_based/search_methods/__init__.py new file mode 100644 index 00000000..24a2e14c --- /dev/null +++ b/xplique/example_based/search_methods/__init__.py @@ -0,0 +1,12 @@ +""" +Search methods +""" + +from .base import BaseSearchMethod, ORDER + +# from .sklearn_knn import SklearnKNN +from .proto_greedy_search import ProtoGreedySearch +from .proto_dash_search import ProtoDashSearch +from .mmd_critic_search import MMDCriticSearch +from .knn import BaseKNN, KNN, FilterKNN +from .kleor import KLEORSimMissSearch, KLEORGlobalSimSearch diff --git a/xplique/example_based/search_methods/base.py b/xplique/example_based/search_methods/base.py new file mode 100644 index 00000000..4ad07867 --- /dev/null +++ b/xplique/example_based/search_methods/base.py @@ -0,0 +1,177 @@ +""" +Base search method for example-based module +""" +from enum import Enum +from abc import ABC, abstractmethod + +import tensorflow as tf +import numpy as np + +from ...types import Union, Optional, List +from ..datasets_operations.tf_dataset_operations import sanitize_dataset + +class ORDER(Enum): + """ + Enumeration for the two types of ordering for the sorting function. + ASCENDING puts the elements with the smallest value first. + DESCENDING puts the elements with the largest value first. + """ + ASCENDING = 1 + DESCENDING = 2 + +def _sanitize_returns( + returns: Optional[Union[List[str], str]] = None, + possibilities: List[str] = None, + default: Union[List[str], str] = None + ) -> List[str]: + """ + It cleans the `returns` parameter. + Results is either a sublist of possibilities or a value among possibilities. + + Parameters + ---------- + returns + The value to verify and put to the `instance.returns` attribute. + possibilities + List of possible unit values for `instance.returns`. + default + Value in case `returns` is None. + + Returns + ------- + returns + The cleaned `returns` value. + """ + if possibilities is None: + possibilities = ["examples"] + if default is None: + default = ["examples"] + + if returns is None: + returns = default + elif isinstance(returns, str): + if returns == "all": + returns = possibilities + elif returns in possibilities: + returns = [returns] + else: + raise ValueError(f"{returns} should belong to {possibilities}") + elif isinstance(returns, list): + pass # already in the right format. + else: + raise ValueError(f"{returns} should either be `str` or `List[str]`") + + return returns + + +class BaseSearchMethod(ABC): + """ + Base class for the example-based search methods. This class is abstract. + It should be inherited by the search methods that are used to find examples in a dataset. + It also defines the interface for the search methods. + + Parameters + ---------- + cases_dataset + The dataset containing the examples to search in. + `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it. + Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not + the case for your dataset, otherwise, examples will not make sense. + k + The number of examples to retrieve at each call. + search_returns + String or list of string with the elements to return in `self.find_examples()`. + It should be a subset of `self._returns_possibilities` or `"all"`. + See self.returns setter for more detail. + batch_size + Number of samples treated simultaneously. + It should match the batch size of the cases_dataset in the case of a `tf.data.Dataset`. + """ + _returns_possibilities = ["examples", "indices", "distances", "include_inputs"] + + def __init__( + self, + cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray], + k: int = 1, + search_returns: Optional[Union[List[str], str]] = None, + batch_size: Optional[int] = 32, + ): + + # set batch size + if isinstance(cases_dataset, tf.data.Dataset): + self.batch_size = tf.shape(next(iter(cases_dataset)))[0].numpy() + else: + self.batch_size = batch_size + + self.cases_dataset = sanitize_dataset(cases_dataset, self.batch_size) + + self.k = k + self.returns = search_returns + + @property + def k(self) -> int: + """Getter for the k parameter.""" + return self._k + + @k.setter + def k(self, k: int): + """Setter for the k parameter.""" + assert isinstance(k, int) and k >= 1, f"k should be an int >= 1 and not {k}" + self._k = k + + @property + def returns(self) -> Union[List[str], str]: + """Getter for the returns parameter.""" + return self._returns + + @returns.setter + def returns(self, returns: Union[List[str], str]): + """ + Setter for the returns parameter used to define returned elements in `self.explain()`. + + Parameters + ---------- + returns + Most elements are useful in `xplique.plots.plot_examples()`. + `returns` can be set to 'all' for all possible elements to be returned. + - 'examples' correspond to the expected examples, + the inputs may be included in first position. (n, k(+1), ...) + - 'distances' the distances between the inputs and the corresponding examples. + They are associated to the examples. (n, k, ...) + - 'labels' if provided through `dataset_labels`, + they are the labels associated with the examples. (n, k, ...) + - 'include_inputs' specify if inputs should be included in the returned elements. + Note that it changes the number of returned elements from k to k+1. + """ + default = "examples" + self._returns = _sanitize_returns(returns, self._returns_possibilities, default) + + @abstractmethod + def find_examples(self, + inputs: Union[tf.Tensor, np.ndarray], + targets: Optional[Union[tf.Tensor, np.ndarray]] = None) -> dict: + """ + Search the samples to return as examples. Called by the explain methods. + It may also return the indices corresponding to the samples, + based on `self.returns` value. + + Parameters + ---------- + inputs + Tensor or Array. Input samples to be explained. + Expected shape among (N, W), (N, T, W), (N, W, H, C). + targets + Tensor or Array. Target of the samples to be explained. + + Returns + ------- + return_dict + Dictionary containing the elements to return which are specified in `self.returns`. + """ + raise NotImplementedError() + + def __call__(self, + inputs: Union[tf.Tensor, np.ndarray], + targets: Optional[Union[tf.Tensor, np.ndarray]] = None) -> dict: + """find_samples() alias""" + return self.find_examples(inputs, targets) diff --git a/xplique/example_based/search_methods/common.py b/xplique/example_based/search_methods/common.py new file mode 100644 index 00000000..5f2a23d4 --- /dev/null +++ b/xplique/example_based/search_methods/common.py @@ -0,0 +1,145 @@ +""" +Common functions for search methods. +""" +# pylint: disable=invalid-name + +import numpy as np +import tensorflow as tf + +from ...types import Callable, Union + + +def _manhattan_distance(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor: + """ + Compute the Manhattan distance between two vectors. + + Parameters + ---------- + x1 : tf.Tensor + First vector. + x2 : tf.Tensor + Second vector. + + Returns + ------- + tf.Tensor + Manhattan distance between the two vectors. + """ + return tf.reduce_sum(tf.abs(x1 - x2), axis=-1) + + +def _euclidean_distance(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor: + """ + Compute the Euclidean distance between two vectors. + + Parameters + ---------- + x1 : tf.Tensor + First vector. + x2 : tf.Tensor + Second vector. + + Returns + ------- + tf.Tensor + Euclidean distance between the two vectors. + """ + return tf.norm(x1 - x2, ord="euclidean", axis=-1) + + +def _cosine_distance(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor: + """ + Compute the cosine distance between two vectors. + + Parameters + ---------- + x1 : tf.Tensor + First vector. + x2 : tf.Tensor + Second vector. + + Returns + ------- + tf.Tensor + Cosine distance between the two vectors. + """ + return 1 - tf.reduce_sum(x1 * x2, axis=-1) / ( + tf.norm(x1, axis=-1) * tf.norm(x2, axis=-1) + ) + + +def _chebyshev_distance(x1: tf.Tensor, x2: tf.Tensor) -> tf.Tensor: + """ + Compute the Chebyshev distance between two vectors. + + Parameters + ---------- + x1 : tf.Tensor + First vector. + x2 : tf.Tensor + Second vector. + + Returns + ------- + tf.Tensor + Chebyshev distance between the two vectors. + """ + return tf.reduce_max(tf.abs(x1 - x2), axis=-1) + + +def _minkowski_distance(x1: tf.Tensor, x2: tf.Tensor, p: int) -> tf.Tensor: + """ + Compute the Minkowski distance between two vectors. + + Parameters + ---------- + x1 : tf.Tensor + First vector. + x2 : tf.Tensor + Second vector. + p : int + Order of the Minkowski distance. + + Returns + ------- + tf.Tensor + Minkowski distance between the two vectors. + """ + return tf.norm(x1 - x2, ord=p, axis=-1) + + +_distances = { + "manhattan": _manhattan_distance, + "euclidean": _euclidean_distance, + "cosine": _cosine_distance, + "chebyshev": _chebyshev_distance, + "inf": _chebyshev_distance, +} + + +def get_distance_function(distance: Union[int, str, Callable] = "euclidean",) -> Callable: + """ + Function to obtain a distance function from different inputs. + + Parameters + ---------- + distance : Union[int, str, Callable], optional + Distance function to use. It can be an integer, a string in + {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable, + by default "euclidean". + """ + # set distance function + if hasattr(distance, "__call__"): + return distance + if isinstance(distance, str) and distance in _distances: + return _distances[distance] + if isinstance(distance, int): + return lambda x1, x2: _minkowski_distance(x1, x2, p=distance) + if distance == np.inf: + return _chebyshev_distance + + raise AttributeError( + "The distance parameter is expected to be either a Callable, "\ + + f" an integer, 'inf', or a string in {_distances.keys()}. "\ + + f"But a {type(distance)} was received, with value {distance}." + ) diff --git a/xplique/example_based/search_methods/kleor.py b/xplique/example_based/search_methods/kleor.py new file mode 100644 index 00000000..7698b561 --- /dev/null +++ b/xplique/example_based/search_methods/kleor.py @@ -0,0 +1,355 @@ +""" +Define the KLEOR search method. +""" +from abc import abstractmethod, ABC + +import numpy as np +import tensorflow as tf + +from ..datasets_operations.tf_dataset_operations import dataset_gather +from ...types import Callable, List, Union, Optional, Tuple + +from .base import ORDER +from .knn import FilterKNN + +class BaseKLEORSearch(FilterKNN, ABC): + """ + Base class for the KLEOR search methods. + In those methods, one should first retrieve the Nearest Unlike Neighbor (NUN) + which is the closest example to the query that has a different prediction than the query. + Then, the method search for the K-Nearest Neighbors (KNN) + of the NUN that have the same prediction as the query. + + Depending on the KLEOR method some additional condition for the search are added. + See the specific KLEOR method for more details. + + Parameters + ---------- + cases_dataset + The dataset used to search the examples. + `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it. + Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not + the case for your dataset, otherwise, examples will not make sense. + targets_dataset + Targets are expected to be the one-hot encoding of the model's predictions + for the samples in cases_dataset. + `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it. + Batch size and cardinality of other datasets should match `cases_dataset`. + Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not + the case for your dataset, otherwise, examples will not make sense. + k + The number of examples to retrieve per input. + search_returns + String or list of string with the elements to return in `self.find_examples()`. + It should be a subset of `self._returns_possibilities`. + batch_size + Number of samples treated simultaneously. + distance + Distance function for examples search. It can be an integer, a string in + {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable, + by default "euclidean". + """ + def __init__( + self, + cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray], + targets_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray], + k: int = 1, + search_returns: Optional[Union[List[str], str]] = None, + batch_size: Optional[int] = 32, + distance: Union[int, str, Callable] = "euclidean", + ): + super().__init__( + cases_dataset = cases_dataset, + targets_dataset=targets_dataset, + k=k, + search_returns=search_returns, + batch_size=batch_size, + distance=distance, + order=ORDER.ASCENDING, + filter_fn=self._filter_fn, + ) + + # search method for the Nearest Unlike Neighbors + self.search_nuns = FilterKNN( + cases_dataset=cases_dataset, + targets_dataset=targets_dataset, + k=1, + search_returns=["indices", "distances"], + batch_size=batch_size, + distance=distance, + order = ORDER.ASCENDING, + filter_fn=self._filter_fn_nun, + ) + + def find_examples(self, + inputs: Union[tf.Tensor, np.ndarray], + targets: Optional[Union[tf.Tensor, np.ndarray]] = None) -> dict: + """ + Search the samples to return as examples. Called by the explain methods. + It may also return the indices corresponding to the samples, + based on `return_indices` value. + + Parameters + ---------- + inputs + Tensor or Array. Input samples to be explained. + Assumed to have been already projected. + Expected shape among (N, W), (N, T, W), (N, W, H, C). + targets + Tensor or Array. Target of the samples to be explained. + + Returns + ------- + return_dict + Dictionary containing the elements to return which are specified in `self.returns`. + """ + # compute neighbors + examples_distances, examples_indices, nuns, nuns_indices, nuns_sf_distances =\ + self.kneighbors(inputs, targets) + + # build return dict + return_dict = self._build_return_dict(inputs, examples_distances, examples_indices) + + # add the nuns if needed + if "nuns" in self.returns: + return_dict["nuns"] = nuns + + if "dist_to_nuns" in self.returns: + return_dict["dist_to_nuns"] = nuns_sf_distances + + if "nuns_indices" in self.returns: + return_dict["nuns_indices"] = nuns_indices + + return return_dict + + def _filter_fn(self, _, __, targets, cases_targets) -> tf.Tensor: + """ + Filter function to mask the cases + for which the prediction is the same as the predicted label on the inputs. + """ + # get the labels predicted by the model + # (n, ) + predicted_labels = tf.argmax(targets, axis=-1) + label_targets = tf.argmax(cases_targets, axis=-1) + # for each input, if the target label is the same as the cases label + # the mask as a True value and False otherwise + mask = tf.equal(tf.expand_dims(predicted_labels, axis=1), label_targets) + return mask + + def _filter_fn_nun(self, _, __, targets, cases_targets) -> tf.Tensor: + """ + Filter function to mask the cases for which the label is different from the predicted + label on the inputs. + """ + # get the labels predicted by the model + # (n, ) + predicted_labels = tf.argmax(targets, axis=-1) + label_targets = tf.argmax(cases_targets, axis=-1) # (bs,) + # for each input, if the target label is the same as the predicted label + # the mask as a False value and True otherwise + mask = tf.not_equal(tf.expand_dims(predicted_labels, axis=1), label_targets) #(n, bs) + return mask + + def _get_nuns(self, + inputs: Union[tf.Tensor, np.ndarray], + targets: Union[tf.Tensor, np.ndarray]) -> Tuple[tf.Tensor, tf.Tensor]: + """ + Get the Nearest Unlike Neighbors and their distance to the related input. + """ + nuns_dict = self.search_nuns(inputs, targets) + nuns_indices, nuns_distances = nuns_dict["indices"], nuns_dict["distances"] + nuns = dataset_gather(self.cases_dataset, nuns_indices) + return nuns, nuns_indices, nuns_distances + + def kneighbors(self, + inputs: Union[tf.Tensor, np.ndarray], + targets: Union[tf.Tensor, np.ndarray] + ) -> Tuple[tf.Tensor, tf.Tensor]: + """ + Compute the k SF to each tensor of `inputs` in `self.cases_dataset`. + Here `self.cases_dataset` is a `tf.data.Dataset`, hence, computations are done by batches. + + Parameters + ---------- + inputs + Tensor or Array. Input samples on which knn are computed. + Expected shape among (N, W), (N, T, W), (N, W, H, C). + More information in the documentation. + targets + Tensor or Array. Target of the samples to be explained. + + Returns + ------- + input_sf_distances + Tensor of distances between the SFs and the inputs with dimension (n, k). + The n inputs times their k-SF. + sf_indices + Tensor of indices of the SFs in `self.cases_dataset` with dimension (n, k, 2). + Where, n represent the number of inputs and k the number of corresponding SFs. + The index of each element is encoded by two values, + the batch index and the index of the element in the batch. + Those indices can be used through `xplique.commons.tf_dataset_operation.dataset_gather`. + nuns + Tensor of Nearest Unlike Neighbors with dimension (n, 1, ...). + The n inputs times their NUN. + nuns_indices + Tensor of indices of the NUN in `self.cases_dataset` with dimension (n, 1, 2). + Where, n represent the number of inputs. + The index of each element is encoded by two values, + the batch index and the index of the element in the batch. + Those indices can be used through `xplique.commons.tf_dataset_operation.dataset_gather`. + nun_sf_distances + Tensor of distances between the SFs and the NUN with dimension (n, k). + The n NUNs times the k-SF. + """ + # pylint: disable=signature-differs + # pylint: disable=duplicate-code + # get the Nearest Unlike Neighbors and their distance to the related input + nuns, nuns_indices, nuns_input_distances = self._get_nuns(inputs, targets) + + # initialize the search for the KLEOR semi-factual methods + sf_indices, input_sf_distances, nun_sf_distances, batch_indices =\ + self._initialize_search(inputs) + + # iterate on batches + for batch_index, (cases, cases_targets) in\ + enumerate(zip(self.cases_dataset, self.targets_dataset)): + # add new elements + # (n, current_bs, 2) + indices = batch_indices[:, : tf.shape(cases)[0]] + new_indices = tf.stack( + [tf.fill(indices.shape, tf.cast(batch_index, tf.int32)), indices], axis=-1 + ) + + # get filter masks + # (n, current_bs) + filter_mask = self.filter_fn(inputs, cases, targets, cases_targets) + + # compute distances + # (n, current_bs) + b_nun_sf_distances = self._crossed_distances_fn(nuns, cases, mask=filter_mask) + b_input_sf_distances = self._crossed_distances_fn(inputs, cases, mask=filter_mask) + + # additional filtering + b_nun_sf_distances, b_input_sf_distances = self._additional_filtering( + b_nun_sf_distances, b_input_sf_distances, nuns_input_distances + ) + # concatenate distances and indices + # (n, k+curent_bs, 2) + concatenated_indices = tf.concat([sf_indices, new_indices], axis=1) + # (n, k+curent_bs) + concatenated_nun_sf_distances = tf.concat( + [nun_sf_distances, b_nun_sf_distances], + axis=1 + ) + concatenated_input_sf_distances = tf.concat( + [input_sf_distances, b_input_sf_distances], + axis=1 + ) + + # sort according to the smallest distances between sf and nun + # (n, k) + sort_order = tf.argsort( + concatenated_nun_sf_distances, axis=1, direction=self.order.name.upper() + )[:, : self.k] + + sf_indices.assign( + tf.gather(concatenated_indices, sort_order, axis=1, batch_dims=1) + ) + nun_sf_distances.assign( + tf.gather(concatenated_nun_sf_distances, sort_order, axis=1, batch_dims=1) + ) + input_sf_distances.assign( + tf.gather(concatenated_input_sf_distances, sort_order, axis=1, batch_dims=1) + ) + + return input_sf_distances, sf_indices, nuns, nuns_indices, nun_sf_distances + + def _initialize_search(self, + inputs: Union[tf.Tensor, np.ndarray] + ) -> Tuple[tf.Variable, tf.Variable, tf.Variable, tf.Tensor]: + """ + Initialize the search for the KLEOR semi-factual methods. + """ + nb_inputs = tf.shape(inputs)[0] + + # sf_indices shape (n, k, 2) + sf_indices = tf.Variable(tf.fill((nb_inputs, self.k, 2), -1)) + # (n, k) + input_sf_distances = tf.Variable(tf.fill((nb_inputs, self.k), self.fill_value)) + nun_sf_distances = tf.Variable(tf.fill((nb_inputs, self.k), self.fill_value)) + # (n, bs) + batch_indices = tf.expand_dims(tf.range(self.batch_size, dtype=tf.int32), axis=0) + batch_indices = tf.tile(batch_indices, multiples=(nb_inputs, 1)) + return sf_indices, input_sf_distances, nun_sf_distances, batch_indices + + @abstractmethod + def _additional_filtering(self, + nun_sf_distances: tf.Tensor, + input_sf_distances: tf.Tensor, + nuns_input_distances: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: + """ + Additional filtering to apply to the distances. + """ + raise NotImplementedError + +class KLEORSimMissSearch(BaseKLEORSearch): + """ + The KLEORSimMiss method search for Semi-Factuals examples + by searching for the Nearest Unlike Neighbor (NUN) of the query. + The NUN is the closest example to the query that has a different prediction than the query. + Then, the method search for the K-Nearest Neighbors (KNN) + of the NUN that have the same prediction as the query. + """ + def _additional_filtering(self, + nun_sf_distances: tf.Tensor, + input_sf_distances: tf.Tensor, + nuns_input_distances: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: + """ + No additional filtering for the KLEORSimMiss method. + """ + return nun_sf_distances, input_sf_distances + +class KLEORGlobalSimSearch(BaseKLEORSearch): + """ + The KLEORGlobalSim method search for Semi-Factuals examples + by searching for the Nearest Unlike Neighbor (NUN) of the query. + The NUN is the closest example to the query that has a different prediction than the query. + Then, the method search for the K-Nearest Neighbors (KNN) + of the NUN that have the same prediction as the query. + + In addition, for a SF candidate to be considered, + the SF should be closer to the query than the NUN + (i.e. the SF should be 'between' the input and its NUN). + This condition is added to the search. + """ + def _additional_filtering(self, + nun_sf_distances: tf.Tensor, + input_sf_distances: tf.Tensor, + nuns_input_distances: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: + """ + Filter the distances to keep only the SF that are 'between' the input and its NUN. + + Parameters + ---------- + nun_sf_distances + Distances between the SF and the NUN. + input_sf_distances + Distances between the SF and the input. + nuns_input_distances + Distances between the input and the NUN. + + Returns + ------- + nun_sf_distances + Filtered distances between the SF and the NUN. + input_sf_distances + Filtered distances between the SF and the input. + """ + # filter non acceptable cases, i.e. cases for which the distance to the input is greater + # than the distance between the input and its nun + # (n, current_bs) + mask = tf.less(input_sf_distances, nuns_input_distances) + nun_sf_distances = tf.where(mask, nun_sf_distances, self.fill_value) + input_sf_distances = tf.where(mask, input_sf_distances, self.fill_value) + return nun_sf_distances, input_sf_distances diff --git a/xplique/example_based/search_methods/knn.py b/xplique/example_based/search_methods/knn.py new file mode 100644 index 00000000..8f688217 --- /dev/null +++ b/xplique/example_based/search_methods/knn.py @@ -0,0 +1,543 @@ +""" +KNN online search method in example-based module +""" +from abc import abstractmethod +import inspect + +import numpy as np +import tensorflow as tf + +from ..datasets_operations.tf_dataset_operations import dataset_gather, sanitize_dataset +from ...types import Callable, List, Union, Optional, Tuple + +from .base import BaseSearchMethod, ORDER +from .common import get_distance_function + +class BaseKNN(BaseSearchMethod): + """ + Base class for the KNN search methods. + It is an abstract class that should be inherited by a specific KNN method. + + Parameters + ---------- + cases_dataset + The dataset containing the examples to search in. + `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it. + Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not + the case for your dataset, otherwise, examples will not make sense. + k + The number of examples to retrieve. + search_returns + String or list of string with the elements to return in `self.find_examples()`. + It should be a subset of `self._returns_possibilities`. + batch_size + Number of samples treated simultaneously. + It should match the batch size of the `cases_dataset` in the case of a `tf.data.Dataset`. + order + The order of the distances, either `ORDER.ASCENDING` or `ORDER.DESCENDING`. + Default is `ORDER.ASCENDING`. + ASCENDING means that the smallest distances are the best, + DESCENDING means that the biggest distances are the best. + """ + def __init__( + self, + cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray], + k: int = 1, + search_returns: Optional[Union[List[str], str]] = None, + batch_size: Optional[int] = 32, + order: ORDER = ORDER.ASCENDING, + ): + super().__init__( + cases_dataset=cases_dataset, + k=k, + search_returns=search_returns, + batch_size=batch_size, + ) + # set order + assert isinstance(order, ORDER),\ + f"order should be an instance of ORDER and not {type(order)}" + self.order = order + # fill value + self.fill_value = np.inf if self.order == ORDER.ASCENDING else -np.inf + + @abstractmethod + def kneighbors(self, + inputs: Union[tf.Tensor, np.ndarray], + targets: Optional[Union[tf.Tensor, np.ndarray]] = None + ) -> Tuple[tf.Tensor, tf.Tensor]: + """ + Compute the k-nearest neighbors to each tensor of `inputs` in `self.cases_dataset`. + Here `self.cases_dataset` is a `tf.data.Dataset`, hence, computations are done by batches. + + Parameters + ---------- + inputs + Tensor or Array. Input samples on which knn are computed. + Expected shape among (N, W), (N, T, W), (N, W, H, C). + More information in the documentation. + targets + Tensor or Array. Target of the samples to be explained. + + Returns + ------- + best_distances + Tensor of distances between the knn and the inputs with dimension (n, k). + The n inputs times their k-nearest neighbors. + best_indices + Tensor of indices of the knn in `self.cases_dataset` with dimension (n, k, 2). + Where, n represent the number of inputs and k the number of corresponding examples. + The index of each element is encoded by two values, + the batch index and the index of the element in the batch. + Those indices can be used through: + `xplique.example_based.datasets_operations.tf_dataset_operation.dataset_gather`. + """ + raise NotImplementedError + + def find_examples(self, + inputs: Union[tf.Tensor, np.ndarray], + targets: Optional[Union[tf.Tensor, np.ndarray]] = None + ) -> dict: + """ + Search the samples to return as examples. Called by the explain methods. + It may also return the indices corresponding to the samples, + based on `return_indices` value. + + Parameters + ---------- + inputs + Tensor or Array. Input samples to be explained. + Assumed to have been already projected. + Expected shape among (N, W), (N, T, W), (N, W, H, C). + targets + Tensor or Array. Target of the samples to be explained. + + Returns + ------- + return_dict + Dictionary containing the elements to return which are specified in `self.returns`. + """ + # compute neighbors + examples_distances, examples_indices = self.kneighbors(inputs, targets) + + # build the return dict + return_dict = self._build_return_dict(inputs, examples_distances, examples_indices) + + return return_dict + + def _build_return_dict(self, + inputs: Union[tf.Tensor, np.ndarray], + examples_distances: tf.Tensor, + examples_indices: tf.Tensor + ) -> dict: + """ + Build the return dict based on the `self.returns` values. + It builds the return dict with the value in the subset of + ['examples', 'include_inputs', 'indices', 'distances'] which is commonly shared. + + Parameters + ---------- + inputs + Tensor or Array. Input samples to be explained. + Assumed to have been already projected. + Expected shape among (N, W), (N, T, W), (N, W, H, C). + examples_distances + Tensor of distances between the knn and the inputs with dimension (n, k). + The n inputs times their k-nearest neighbors. + examples_indices + Tensor of indices of the knn in `self.cases_dataset` with dimension (n, k, 2). + Where, n represent the number of inputs and k the number of corresponding examples. + The index of each element is encoded by two values, + the batch index and the index of the element in the batch. + Those indices can be used through: + `xplique.example_based.datasets_operations.tf_dataset_operation.dataset_gather`. + + Returns + ------- + return_dict + Dictionary containing the elements to return which are specified in `self.returns`. + """ + # Set values in return dict + return_dict = {} + if "examples" in self.returns: + return_dict["examples"] = dataset_gather(self.cases_dataset, examples_indices) + if "include_inputs" in self.returns: + inputs = tf.expand_dims(inputs, axis=1) + return_dict["examples"] = tf.concat( + [inputs, return_dict["examples"]], axis=1 + ) + if "indices" in self.returns: + return_dict["indices"] = examples_indices + if "distances" in self.returns: + return_dict["distances"] = examples_distances + + return return_dict + +class KNN(BaseKNN): + """ + KNN method to search examples. Based on `sklearn.neighbors.NearestNeighbors`. + The kneighbors method is implemented in a batched way to handle large datasets. + + Parameters + ---------- + cases_dataset + The dataset containing the examples to search in. + `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it. + Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not + the case for your dataset, otherwise, examples will not make sense. + k + The number of examples to retrieve. + search_returns + String or list of string with the elements to return in `self.find_examples()`. + It should be a subset of `self._returns_possibilities`. + batch_size + Number of samples treated simultaneously. + It should match the batch size of the `cases_dataset` in the case of a `tf.data.Dataset`. + order + The order of the distances, either `ORDER.ASCENDING` or `ORDER.DESCENDING`. + Default is `ORDER.ASCENDING`. + ASCENDING means that the smallest distances are the best, + DESCENDING means that the biggest distances are the best. + distance + Distance function for examples search. It can be an integer, a string in + {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable, + by default "euclidean". + """ + def __init__( + self, + cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray], + k: int = 1, + search_returns: Optional[Union[List[str], str]] = None, + batch_size: Optional[int] = 32, + distance: Union[int, str, Callable] = "euclidean", + order: ORDER = ORDER.ASCENDING, + ): + super().__init__( + cases_dataset=cases_dataset, + k=k, + search_returns=search_returns, + batch_size=batch_size, + order=order, + ) + + # set distance function + self.distance_fn = get_distance_function(distance) + + @tf.function + def _crossed_distances_fn(self, x1, x2) -> tf.Tensor: + """ + Element-wise distance computation between two tensors. + It has been vectorized to handle batches of inputs and cases. + + Parameters + ---------- + x1 + Tensor. Input samples of shape (n, ...). + x2 + Tensor. Cases samples of shape (m, ...). + + Returns + ------- + distances + Tensor of distances between the inputs and the cases with dimension (n, m). + """ + # pylint: disable=invalid-name + n = x1.shape[0] + m = x2.shape[0] + x2 = tf.expand_dims(x2, axis=0) + x2 = tf.repeat(x2, n, axis=0) + # reshape for broadcasting + x1 = tf.reshape(x1, (n, 1, -1)) + x2 = tf.reshape(x2, (n, m, -1)) + def compute_distance(args): + a, b = args + return self.distance_fn(a, b) + args = (x1, x2) + # Use vectorized_map to apply compute_distance element-wise + distances = tf.vectorized_map(compute_distance, args) + return distances + + def kneighbors(self, + inputs: Union[tf.Tensor, np.ndarray], + _ = None, + ) -> Tuple[tf.Tensor, tf.Tensor]: + """ + Compute the k-neareast neighbors to each tensor of `inputs` in `self.cases_dataset`. + Here `self.cases_dataset` is a `tf.data.Dataset`, hence, computations are done by batches. + + Parameters + ---------- + inputs + Tensor or Array. Input samples on which knn are computed. + Expected shape among (N, W), (N, T, W), (N, W, H, C). + More information in the documentation. + targets + Tensor or Array. Target of the samples to be explained. + + Returns + ------- + best_distances + Tensor of distances between the knn and the inputs with dimension (n, k). + The n inputs times their k-nearest neighbors. + best_indices + Tensor of indices of the knn in `self.cases_dataset` with dimension (n, k, 2). + Where, n represent the number of inputs and k the number of corresponding examples. + The index of each element is encoded by two values, + the batch index and the index of the element in the batch. + Those indices can be used through: + `xplique.example_based.datasets_operations.tf_dataset_operation.dataset_gather`. + """ + nb_inputs = tf.shape(inputs)[0] + + # initialize + # (n, k, 2) + best_indices = tf.Variable(tf.fill((nb_inputs, self.k, 2), -1)) + # (n, k) + best_distances = tf.Variable(tf.fill((nb_inputs, self.k), self.fill_value)) + # (n, bs) + batch_indices = tf.expand_dims(tf.range(self.batch_size, dtype=tf.int32), axis=0) + batch_indices = tf.tile(batch_indices, multiples=(nb_inputs, 1)) + + # iterate on batches + for batch_index, cases in enumerate(self.cases_dataset): + # add new elements + # (n, current_bs, 2) + indices = batch_indices[:, : tf.shape(cases)[0]] + new_indices = tf.stack( + [tf.fill(indices.shape, tf.cast(batch_index, tf.int32)), indices], axis=-1 + ) + + # compute distances + # (n, current_bs) + distances = self._crossed_distances_fn(inputs, cases) + + # (n, k+curent_bs, 2) + concatenated_indices = tf.concat([best_indices, new_indices], axis=1) + # (n, k+curent_bs) + concatenated_distances = tf.concat([best_distances, distances], axis=1) + + # sort all + # (n, k) + sort_order = tf.argsort( + concatenated_distances, axis=1, direction=self.order.name.upper() + )[:, : self.k] + + best_indices.assign( + tf.gather(concatenated_indices, sort_order, axis=1, batch_dims=1) + ) + best_distances.assign( + tf.gather(concatenated_distances, sort_order, axis=1, batch_dims=1) + ) + + return best_distances, best_indices + +class FilterKNN(BaseKNN): + """ + KNN method to search examples. Based on `sklearn.neighbors.NearestNeighbors`. + The kneighbors method is implemented in a batched way to handle large datasets. + In addition, a filter function is used to select the elements to compute the distances, + thus reducing the computational cost of the distance computation + (worth if the computation of the filter is low and the matrix of distances is sparse). + + Parameters + ---------- + cases_dataset + The dataset containing the examples to search in. + `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it. + Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not + the case for your dataset, otherwise, examples will not make sense. + targets_dataset + Targets are expected to be the one-hot encoding of the model's predictions + for the samples in cases_dataset. + `tf.data.Dataset` are assumed to be batched as tensorflow provide no method to verify it. + Batch size and cardinality of other datasets should match `cases_dataset`. + Be careful, `tf.data.Dataset` are often reshuffled at each iteration, be sure that it is not + the case for your dataset, otherwise, examples will not make sense. + k + The number of examples to retrieve. + search_returns + String or list of string with the elements to return in `self.find_examples()`. + It should be a subset of `self._returns_possibilities`. + batch_size + Number of samples treated simultaneously. + It should match the batch size of the `cases_dataset` in the case of a `tf.data.Dataset`. + order + The order of the distances, either `ORDER.ASCENDING` or `ORDER.DESCENDING`. + Default is `ORDER.ASCENDING`. + ASCENDING means that the smallest distances are the best, + DESCENDING means that the biggest distances are the best. + distance + Distance function for examples search. It can be an integer, a string in + {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable, + by default "euclidean". + filter_fn + A Callable that takes as inputs the inputs, their targets, + the cases and their targets and returns a boolean mask of shape (n, m) + where n is the number of inputs and m the number of cases. + This boolean mask is used to choose between which inputs and cases to compute the distances. + """ + def __init__( + self, + cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray], + targets_dataset: Optional[Union[tf.data.Dataset, tf.Tensor, np.ndarray]] = None, + k: int = 1, + search_returns: Optional[Union[List[str], str]] = None, + batch_size: Optional[int] = 32, + distance: Union[int, str, Callable] = "euclidean", + order: ORDER = ORDER.ASCENDING, + filter_fn: Optional[Callable] = None, + ): + # pylint: disable=invalid-name + super().__init__( + cases_dataset=cases_dataset, + k=k, + search_returns=search_returns, + batch_size=batch_size, + order=order, + ) + + # set distance function + if hasattr(distance, "__call__"): + self.distance_fn = distance + else: + base_distance_fn = get_distance_function(distance) + self.distance_fn = lambda x1, x2, m:\ + tf.where(m, base_distance_fn(x1, x2), self.fill_value) + + if filter_fn is None: + filter_fn = lambda x, z, y, t: tf.ones((tf.shape(x)[0], tf.shape(z)[0]), dtype=tf.bool) + elif hasattr(filter_fn, "__call__"): + filter_fn_signature = inspect.signature(filter_fn) + assert len(filter_fn_signature.parameters) == 4,\ + f"filter_fn should take 4 parameters, not {len(filter_fn_signature.parameters)}" + else: + raise TypeError( + f"filter_fn should be Callable, not {type(filter_fn)}" + ) + self.filter_fn = filter_fn + + # set targets_dataset + if targets_dataset is not None: + self.targets_dataset = sanitize_dataset(targets_dataset, self.batch_size) + else: + # make an iterable of None + self.targets_dataset = [None]*len(cases_dataset) + + @tf.function + def _crossed_distances_fn(self, x1, x2, mask): + """ + Element-wise distance computation between two tensors with a mask. + It has been vectorized to handle batches of inputs and cases. + + Parameters + ---------- + x1 + Tensor. Input samples of shape (n, ...). + x2 + Tensor. Cases samples of shape (m, ...). + mask + Tensor. Boolean mask of shape (n, m). + It is used to filter the elements for which the distance is computed. + + Returns + ------- + distances + Tensor of distances between the inputs and the cases with dimension (n, m). + """ + # pylint: disable=invalid-name + n = x1.shape[0] + m = x2.shape[0] + x2 = tf.expand_dims(x2, axis=0) + x2 = tf.repeat(x2, n, axis=0) + # reshape for broadcasting + x1 = tf.reshape(x1, (n, 1, -1)) + x2 = tf.reshape(x2, (n, m, -1)) + def compute_distance(args): + a, b, mask = args + return self.distance_fn(a, b, mask) + args = (x1, x2, mask) + # Use vectorized_map to apply compute_distance element-wise + distances = tf.vectorized_map(compute_distance, args) + return distances + + def kneighbors(self, + inputs: Union[tf.Tensor, np.ndarray], + targets: Optional[Union[tf.Tensor, np.ndarray]] = None + ) -> Tuple[tf.Tensor, tf.Tensor]: + """ + Compute the k-neareast neighbors to each tensor of `inputs` in `self.cases_dataset`. + Here `self.cases_dataset` is a `tf.data.Dataset`, hence, computations are done by batches. + In addition, a filter function is used to select the elements to compute the distances, + thus reducing the computational cost of the distance computation + (worth if the computation of the filter is low and the matrix of distances is sparse). + + Parameters + ---------- + inputs + Tensor or Array. Input samples on which knn are computed. + Expected shape among (N, W), (N, T, W), (N, W, H, C). + More information in the documentation. + targets + Tensor or Array. Target of the samples to be explained. + + Returns + ------- + best_distances + Tensor of distances between the knn and the inputs with dimension (n, k). + The n inputs times their k-nearest neighbors. + best_indices + Tensor of indices of the knn in `self.cases_dataset` with dimension (n, k, 2). + Where, n represent the number of inputs and k the number of corresponding examples. + The index of each element is encoded by two values, + the batch index and the index of the element in the batch. + Those indices can be used through: + `xplique.example_based.datasets_operations.tf_dataset_operation.dataset_gather`. + """ + nb_inputs = tf.shape(inputs)[0] + + # initialiaze + # (n, k, 2) + best_indices = tf.Variable(tf.fill((nb_inputs, self.k, 2), -1)) + # (n, k) + best_distances = tf.Variable(tf.fill((nb_inputs, self.k), self.fill_value)) + # (n, bs) + batch_indices = tf.expand_dims(tf.range(self.batch_size, dtype=tf.int32), axis=0) + batch_indices = tf.tile(batch_indices, multiples=(nb_inputs, 1)) + + # iterate on batches + for batch_index, (cases, cases_targets) in\ + enumerate(zip(self.cases_dataset, self.targets_dataset)): + # add new elements + # (n, current_bs, 2) + indices = batch_indices[:, : tf.shape(cases)[0]] + new_indices = tf.stack( + [tf.fill(indices.shape, tf.cast(batch_index, tf.int32)), indices], axis=-1 + ) + + # get filter masks + # (n, current_bs) + filter_mask = self.filter_fn(inputs, cases, targets, cases_targets) + + # compute distances + # (n, current_bs) + distances = self._crossed_distances_fn(inputs, cases, mask=filter_mask) + + # (n, k+curent_bs, 2) + concatenated_indices = tf.concat([best_indices, new_indices], axis=1) + # (n, k+curent_bs) + concatenated_distances = tf.concat([best_distances, distances], axis=1) + + # sort all + # (n, k) + sort_order = tf.argsort( + concatenated_distances, axis=1, direction=self.order.name.upper() + )[:, : self.k] + + best_indices.assign( + tf.gather(concatenated_indices, sort_order, axis=1, batch_dims=1) + ) + best_distances.assign( + tf.gather(concatenated_distances, sort_order, axis=1, batch_dims=1) + ) + + return best_distances, best_indices + \ No newline at end of file diff --git a/xplique/example_based/search_methods/mmd_critic_search.py b/xplique/example_based/search_methods/mmd_critic_search.py new file mode 100644 index 00000000..778b79d9 --- /dev/null +++ b/xplique/example_based/search_methods/mmd_critic_search.py @@ -0,0 +1,108 @@ +""" +MMDCritic search method in example-based module +""" + +import tensorflow as tf + +from ...types import Tuple + +from .proto_greedy_search import ProtoGreedySearch + + +class MMDCriticSearch(ProtoGreedySearch): + """ + MMDCritic method to search prototypes. + + References: + .. [#] `Been Kim, Rajiv Khanna, Oluwasanmi Koyejo, + "Examples are not enough, learn to criticize! criticism for interpretability" + `_ + + Parameters + ---------- + cases_dataset + The dataset used to train the model, examples are extracted from the dataset. + For natural example-based methods it is the train dataset. + batch_size + Number of samples treated simultaneously. + It should match the batch size of the `cases_dataset` in the case of a `tf.data.Dataset`. + nb_prototypes : int + Number of prototypes to find. + kernel_fn : Callable, optional + Kernel function, by default the rbf kernel. + This function must only use TensorFlow operations. + gamma : float, optional + Parameter that determines the spread of the rbf kernel, defaults to 1.0 / n_features. + """ + + def _compute_batch_objectives(self, + candidates_kernel_diag: tf.Tensor, + candidates_kernel_col_means: tf.Tensor, + selection_kernel_col_means: tf.Tensor, + candidates_selection_kernel: tf.Tensor, + selection_selection_kernel: tf.Tensor + ) -> Tuple[tf.Tensor, tf.Tensor]: + """ + Compute the objective function and corresponding weights + for a given set of selected prototypes and a candidate. + + Here, we have a special case of protogreedy where we give equal weights to all prototypes, + the objective here is simplified to speed up processing + + Find argmax_{c} F(S ∪ c) - F(S) + ≡ + Find argmax_{c} F(S ∪ c) + ≡ + Find argmax_{c} (sum1 - sum2) + where: sum1 = (2 / n) * ∑[i=1 to n] κ(x_i, c) + sum2 = 1/(|S|+1) [κ(c, c) + 2 * ∑[j=1 to |S|] κ(x_j, c)] + + Parameters + ---------- + candidates_kernel_diag : Tensor + Diagonal values of the kernel matrix between the candidates and themselves. Shape (bc,). + candidates_kernel_col_means : Tensor + Column means of the kernel matrix, subset for the candidates. Shape (bc,). + selection_kernel_col_means : Tensor + Column means of the kernel matrix, subset for the selected prototypes. Shape (|S|,). + candidates_selection_kernel : Tensor + Kernel matrix between the candidates and the selected prototypes. Shape (bc, |S|). + selection_selection_kernel : Tensor + Kernel matrix between the selected prototypes. Shape (|S|, |S|). + + Returns + ------- + objectives + Tensor that contains the computed objective values for each candidate. Shape (bc,). + objectives_weights + Tensor that contains the computed objective weights for each candidate. + Shape (bc, |S|+1). + """ + + nb_candidates = tf.shape(candidates_kernel_diag)[0] + + # (bc,) - 2 * ∑[i=1 to n] κ(x_i, c) + sum1 = 2 * candidates_kernel_col_means + + if candidates_selection_kernel is None: + extended_nb_selected = 1 + + # (bc,) - κ(c, c) + sum2 = candidates_kernel_diag + else: + extended_nb_selected = tf.shape(selection_kernel_col_means)[0] + 1 + + # (bc,) - κ(c, c) + 2 * ∑[j=1 to |S|] κ(x_j, c) + # the second term is 0 when the selection is empty + sum2 = candidates_kernel_diag + 2 * tf.reduce_sum(candidates_selection_kernel, axis=1) + + # (bc,) - 1/(|S|+1) [κ(c, c) + 2 * ∑[j=1 to |S|] κ(x_j, c)] + sum2 /= tf.cast(extended_nb_selected, tf.float32) + + # (bc,) + objectives = sum1 - sum2 + + # (bc, |S|+1) - ones (the weights are normalized later) + objectives_weights = tf.ones((nb_candidates, extended_nb_selected), dtype=tf.float32) + + return objectives, objectives_weights diff --git a/xplique/example_based/search_methods/proto_dash_search.py b/xplique/example_based/search_methods/proto_dash_search.py new file mode 100644 index 00000000..ba26f838 --- /dev/null +++ b/xplique/example_based/search_methods/proto_dash_search.py @@ -0,0 +1,244 @@ +""" +ProtoDash search method in example-based module +""" + +import numpy as np +from scipy.optimize import minimize +import tensorflow as tf + +from ...types import Union, Optional, Tuple + +from .proto_greedy_search import ProtoGreedySearch + + +class Optimizer(): + """ + Class to solve the quadratic problem: + F(S) ≡ max_{w:supp(w)∈ S, w ≥ 0} l(w), + where l(w) = w^T * μ_p - 1/2 * w^T * K * w + + Parameters + ---------- + initial_weights : Tensor + Initial weight vector. + min_weight : float, optional + Lower bound on weight. Default is 0. + max_weight : float, optional + Upper bound on weight. Default is 10000. + """ + + def __init__( + self, + initial_weights: Union[tf.Tensor, np.ndarray], + min_weight: float = 0, + max_weight: float = 10000 + ): + self.initial_weights = initial_weights + self.min_weight = min_weight + self.max_weight = max_weight + self.bounds = [(min_weight, max_weight)] * initial_weights.shape[0] + self.objective_fn = lambda w, u, K: - (w @ u - 0.5 * w @ K @ w) + + def optimize(self, u, K): + """ + Perform optimization to find the optimal values of the weight vector (w) + and the corresponding objective function value. + + Parameters + ---------- + u : Tensor + Mean similarity of each prototype. + K : Tensor + The kernel matrix. + + Returns + ------- + best_weights : Tensor + The optimal value of the weight vector (w). + best_objective : Tensor + The value of the objective function corresponding to the best_weights. + """ + # pylint: disable=invalid-name + + u = u.numpy() + K = K.numpy() + + result = minimize(self.objective_fn, self.initial_weights, args=(u, K), + method='SLSQP', bounds=self.bounds, options={'disp': False}) + + # Get the best weights + best_weights = result.x + best_weights = tf.expand_dims(tf.convert_to_tensor(best_weights, dtype=tf.float32), axis=0) + + # Get the best objective + best_objective = -result.fun + best_objective = tf.expand_dims(tf.convert_to_tensor(best_objective, dtype=tf.float32), + axis=0) + + assert tf.reduce_all(best_weights >= 0) + + return best_weights, best_objective + + +class ProtoDashSearch(ProtoGreedySearch): + """ + Protodash method for searching prototypes. + + References: + .. [#] `Karthik S. Gurumoorthy, Amit Dhurandhar, Guillermo Cecchi, + "ProtoDash: Fast Interpretable Prototype Selection" + `_ + + Parameters + ---------- + cases_dataset + The dataset used to train the model, examples are extracted from the dataset. + For natural example-based methods it is the train dataset. + batch_size + Number of samples treated simultaneously. + It should match the batch size of the `cases_dataset` in the case of a `tf.data.Dataset`. + nb_prototypes : int + Number of prototypes to find. + kernel_fn : Callable, optional + Kernel function, by default the rbf kernel. + This function must only use TensorFlow operations. + gamma : float, optional + Parameter that determines the spread of the rbf kernel, defaults to 1.0 / n_features. + exact_selection_weights_update : bool, optional + Wether to use an exact method to update selection weights, by default False. + Exact method is based on a scipy optimization, + while the other is based on a tensorflow inverse operation. + """ + + def __init__( + self, + cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray], + batch_size: Optional[int] = 32, + nb_prototypes: int = 1, + kernel_fn: callable = None, + gamma: float = None, + exact_selection_weights_update: bool = False, + ): + + self.exact_selection_weights_update = exact_selection_weights_update + + super().__init__( + cases_dataset=cases_dataset, + batch_size=batch_size, + nb_prototypes=nb_prototypes, + kernel_fn=kernel_fn, + gamma=gamma + ) + + def _update_selection_weights(self, + selection_kernel_col_means: tf.Tensor, + selection_selection_kernel: tf.Tensor, + best_diag: tf.Tensor, + best_objective: tf.Tensor + ) -> tf.Tensor: + """ + Update the selection weights based on the given parameters. + Pursuant to Lemma IV.4: + If best_gradient ≤ 0, then + ζ(S∪{best_sample_index}) = ζ(S) and specifically, w_{best_sample_index} = 0. + Otherwise, the stationarity and complementary slackness KKT conditions + entails that w_{best_sample_index} = best_gradient / κ(best_sample_index, best_sample_index) + + Parameters + ---------- + selection_kernel_col_means : Tensor + Column means of the kernel matrix computed from the selected prototypes. Shape (|S|,). + selection_selection_kernel : Tensor + Kernel matrix computed from the selected prototypes. Shape (|S|, |S|). + best_diag : tf.Tensor + The diagonal element of the kernel matrix corresponding to the lastly added prototype. + Shape (1,). + best_objective : tf.Tensor + The computed objective function value of the lastly added prototype. Shape (1,). + Used to initialize the weights for the exact weights update. + + """ + # pylint: disable=invalid-name + nb_selected = selection_kernel_col_means.shape[0] + + if best_objective <= 0: + self.prototypes_weights[nb_selected - 1].assign(0) + else: + # (|S|,) + u = selection_kernel_col_means + + # (|S|, |S|) + K = selection_selection_kernel + + if self.exact_selection_weights_update: + # initialize the weights + best_objective_diag = best_objective / best_diag + self.prototypes_weights[nb_selected - 1].assign(best_objective_diag) + + # optimize the weights + opt = Optimizer(self.prototypes_weights[:nb_selected]) + optimized_weights, _ = opt.optimize(u[:, tf.newaxis], K) + + # update the weights + self.prototypes_weights[:nb_selected].assign(tf.squeeze(optimized_weights, axis=0)) + else: + # We added epsilon to the diagonal of K to ensure that K is invertible + # (|S|, |S|) + K_inv = tf.linalg.inv(K + ProtoDashSearch.EPSILON * tf.eye(K.shape[-1])) + + # use w* = K^-1 * u as the optimal weights + # (|S|,) + selection_weights = tf.linalg.matvec(K_inv, u) + selection_weights = tf.maximum(selection_weights, 0) + + # update the weights + self.prototypes_weights[:nb_selected].assign(selection_weights) + + def _compute_batch_objectives(self, + candidates_kernel_diag: tf.Tensor, + candidates_kernel_col_means: tf.Tensor, + selection_kernel_col_means: tf.Tensor, + candidates_selection_kernel: tf.Tensor, + selection_selection_kernel: tf.Tensor + ) -> Tuple[tf.Tensor, tf.Tensor]: + """ + Compute the objective function and corresponding weights + for a given set of selected prototypes and a candidate. + Calculate the gradient of l(w) = w^T * μ_p - 1/2 * w^T * K * w + w.r.t w, on the optimal weight point ζ^(S) + g = ∇l(ζ^(S)) = μ_p - K * ζ^(S) + g is computed for each candidate c + + Parameters + ---------- + candidates_kernel_diag : Tensor + Diagonal values of the kernel matrix between the candidates and themselves. Shape (bc,). + candidates_kernel_col_means : Tensor + Column means of the kernel matrix, subset for the candidates. Shape (bc,). + selection_kernel_col_means : Tensor + Column means of the kernel matrix, subset for the selected prototypes. Shape (|S|,). + candidates_selection_kernel : Tensor + Kernel matrix between the candidates and the selected prototypes. Shape (bc, |S|). + selection_selection_kernel : Tensor + Kernel matrix between the selected prototypes. Shape (|S|, |S|). + + Returns + ------- + objectives + Tensor that contains the computed objective values for each candidate. Shape (bc,). + objectives_weights + No weights are returned in this case. It is set to None. + The weights are computed and updated in the `_update_selection_weights` method. + """ + # pylint: disable=invalid-name + + if candidates_selection_kernel is None: + # (bc,) + # S = ∅ and ζ^(∅) = 0, g = ∇l(ζ^(∅)) = μ_p + objectives = candidates_kernel_col_means + else: + # (bc,) - g = μ_p - K * ζ^(S) + objectives = candidates_kernel_col_means - tf.linalg.matvec(candidates_selection_kernel, + selection_kernel_col_means) + + return objectives, None diff --git a/xplique/example_based/search_methods/proto_greedy_search.py b/xplique/example_based/search_methods/proto_greedy_search.py new file mode 100644 index 00000000..988f8656 --- /dev/null +++ b/xplique/example_based/search_methods/proto_greedy_search.py @@ -0,0 +1,539 @@ +""" +ProtoGreedy search method in example-based module +""" + +import numpy as np +import tensorflow as tf + +from ...types import Callable, Union, Optional, Tuple + +from ..datasets_operations.tf_dataset_operations import sanitize_dataset + +from .common import get_distance_function + + +class ProtoGreedySearch(): + """ + ProtoGreedy method for searching prototypes. + + References: + .. [#] `Karthik S. Gurumoorthy, Amit Dhurandhar, Guillermo Cecchi, + "ProtoDash: Fast Interpretable Prototype Selection" + `_ + + Parameters + ---------- + cases_dataset + The dataset used to train the model, examples are extracted from the dataset. + For natural example-based methods it is the train dataset. + batch_size + Number of samples treated simultaneously. + It should match the batch size of the `cases_dataset` in the case of a `tf.data.Dataset`. + nb_prototypes : int + Number of prototypes to find. + kernel_fn : Callable, optional + Kernel function, by default the rbf kernel. + The overall method will be much faster if the provided function is a `tf.function`. + gamma : float, optional + Parameter that determines the spread of the rbf kernel, defaults to 1.0 / n_features. + """ + # pylint: disable=too-many-instance-attributes + + # Avoid zero division during procedure. (the value is not important, as if the denominator is + # zero, then the nominator will also be zero). + EPSILON = 1e-6 + + def __init__( + self, + cases_dataset: Union[tf.data.Dataset, tf.Tensor, np.ndarray], + batch_size: Optional[int] = 32, + nb_prototypes: int = 1, + kernel_fn: callable = None, + gamma: float = None + ): + # pylint: disable=duplicate-code + # set batch size + if isinstance(cases_dataset, tf.data.Dataset): + self.batch_size = tf.shape(next(iter(cases_dataset)))[0].numpy() + else: + self.batch_size = batch_size + + self.cases_dataset = sanitize_dataset(cases_dataset, self.batch_size) + + # set kernel function + if kernel_fn is None: + # define kernel fn to default rbf kernel + self.__set_default_kernel_fn(self.cases_dataset, gamma) + elif hasattr(kernel_fn, "__call__"): + # the kernel_fn is a callable the output is converted to a tensor for consistency + self.kernel_fn = lambda x1, x2: tf.convert_to_tensor(kernel_fn(x1, x2)) + else: + raise AttributeError( + "The kernel_fn parameter is expected to be None or a Callable"\ + +f"but {kernel_fn} was received."\ + ) + + # compute the sum of the columns and the diagonal values of the kernel matrix of the dataset + self.__set_kernel_matrix_column_means_and_diagonal() + + # compute the prototypes in the latent space + self.find_global_prototypes(nb_prototypes) + + def _get_distance_fn(self, distance: Optional[Union[int, str, Callable]]) -> Callable: + """ + Get the distance function for examples search. + Function called through the Prototypes class. + The distance function is used to search for the closest examples to the prototypes. + + Parameters + ---------- + distance + Distance function for examples search. It can be an integer, a string in + {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable. + + Returns + ------- + Callable + Distance function for examples search. + """ + # pylint: disable=invalid-name + if distance is None: + def kernel_induced_distance(x1, x2): + def dist(x): + x = tf.expand_dims(x, axis=0) + return tf.sqrt( + self.kernel_fn(x1, x1) - 2 * self.kernel_fn(x1, x) + self.kernel_fn(x, x) + ) + distance = tf.map_fn(dist, x2) + return tf.squeeze(distance, axis=[1, 2]) + return kernel_induced_distance + + return get_distance_function(distance) + + def __set_default_kernel_fn(self, + cases_dataset: tf.data.Dataset, + gamma: float = None, + ) -> None: + """ + Set the default kernel function. + + Parameters + ---------- + cases_dataset : tf.data.Dataset + The dataset used to train the model, examples are extracted from the dataset. + The shape are extracted from the dataset, it is necessary for optimal performance, + and to set the default gamma value. + gamma : float, optional + Parameter that determines the spread of the rbf kernel, defaults to 1.0 / n_features. + """ + cases_shape = cases_dataset.element_spec.shape + self.nb_features = cases_shape[-1] + + # elements should be batched tabular data + assert len(cases_shape) == 2,\ + "Prototypes' searches expects 2D data, (nb_samples, nb_features), but got "+\ + f"{cases_shape}. Please verify your projection "+\ + "if you provided a custom one. If you use a splitted model, "+\ + "make sure the output of the first part of the model is flattened." + + if gamma is None: + if cases_dataset is None: + raise ValueError( + "For the default kernel_fn, the default gamma value requires samples shape." + ) + gamma = 1.0 / self.nb_features + + gamma = tf.constant(gamma, dtype=tf.float32) + + # created inside a function for gamma to be a constant and prevent graph retracing + @tf.function(input_signature=[ + tf.TensorSpec(shape=cases_shape, dtype=tf.float32, name="tensor_1"), + tf.TensorSpec(shape=cases_shape, dtype=tf.float32, name="tensor_2") + ]) + def rbf_kernel(tensor_1: tf.Tensor, tensor_2: tf.Tensor,) -> tf.Tensor: + """ + Compute the rbf kernel matrix between two sets of samples. + + Parameters + ---------- + tensor_1 + The first set of samples of shape (n, d). + tensor_2 + The second set of samples of shape (m, d). + + Returns + ------- + Tensor + The rbf kernel matrix of shape (n, m). + """ + + # (n, m, d) + pairwise_diff = tensor_1[:, tf.newaxis, :] - tensor_2[tf.newaxis, :, :] + + # (n, m) + pairwise_sq_dist = tf.reduce_sum(tf.square(pairwise_diff), axis=-1) + kernel_matrix = tf.exp(-gamma * pairwise_sq_dist) + + return kernel_matrix + + self.kernel_fn = rbf_kernel + + def __set_kernel_matrix_column_means_and_diagonal(self) -> None: + """ + Compute the sum of the columns and the diagonal values of the kernel matrix of the dataset. + Results are stored in the object. + + Parameters + ---------- + cases_dataset : tf.data.Dataset + The kernel matrix is computed between the cases of this dataset. + kernel_fn : Callable + Kernel function to compute the kernel matrix between two sets of samples. + """ + # Compute the sum of the columns and the diagonal values of the kernel matrix of the dataset + # We take advantage of the symmetry of this matrix to traverse only its lower triangle + col_sums = [] + diag = [] + row_sums = [0] # first batch has no row sums and not computed, 0 is a placeholder + nb_samples = 0 + + for batch_col_index, batch_col_cases in enumerate(self.cases_dataset): + + batch_col_sums = tf.zeros((batch_col_cases.shape[0]), dtype=tf.float32) + + for batch_row_index, batch_row_cases in enumerate(self.cases_dataset): + # ignore batches that are above the diagonal + if batch_col_index > batch_row_index: + continue + + # Compute the kernel matrix between the two batches + # (n_b_row, n_b_col) + batch_kernel = self.kernel_fn(batch_row_cases, batch_col_cases) + + # increment the column sums + # (n_b_col,) + batch_col_sums = batch_col_sums + tf.reduce_sum(batch_kernel, axis=0) + + # current pair of batches is on the diagonal + if batch_col_index == batch_row_index: + # stock the diagonal values + diag.append(tf.linalg.diag_part(batch_kernel)) + + # complete the column sums with the row sums when the batch is on the diagonal + # (n_b_col,) + batch_col_sums = batch_col_sums + row_sums[batch_row_index] + continue + + # increment the row sums + # (n_b_row,) + current_batch_row_sums = tf.reduce_sum(batch_kernel, axis=1) + if batch_col_index == 0: + row_sums.append(current_batch_row_sums) + else: + row_sums[batch_row_index] += current_batch_row_sums + + col_sums.append(batch_col_sums) + nb_samples += batch_col_cases.shape[0] + + # pad the last batch to have the same size as the others + col_sums[-1] = tf.pad(col_sums[-1], [[0, self.batch_size - col_sums[-1].shape[0]]]) + + # (nb, b) + self.kernel_col_means = tf.stack(col_sums, axis=0) / tf.cast(nb_samples, dtype=tf.float32) + + # pad the last batch to have the same size as the others + diag[-1] = tf.pad(diag[-1], [[0, self.batch_size - diag[-1].shape[0]]]) + + # (nb, b) + self.kernel_diag = tf.stack(diag, axis=0) + + def _compute_batch_objectives(self, + candidates_kernel_diag: tf.Tensor, + candidates_kernel_col_means: tf.Tensor, + selection_kernel_col_means: tf.Tensor, + candidates_selection_kernel: tf.Tensor, + selection_selection_kernel: tf.Tensor + ) -> Tuple[tf.Tensor, tf.Tensor]: + """ + Compute the objective function and corresponding weights + for a given set of selected prototypes and a batch of candidates. + + Here, we have a special case of protogreedy where we give equal weights to all prototypes, + the objective here is simplified to speed up processing. + + Find argmax_{c} F(S ∪ c) - F(S) + ≡ + Find argmax_{c} F(S ∪ c) + ≡ + Find argmax_{c} max_{w} (w^T mu_p) - (w^T K w) / 2 + + w*, the optimal objective weights, is computed as follows: w* = K^-1 mu_p + + where: + - mu_p is the column means of the kernel matrix + - K is the kernel matrix + + Parameters + ---------- + candidates_kernel_diag : Tensor + Diagonal values of the kernel matrix between the candidates and themselves. Shape (bc,). + candidates_kernel_col_means : Tensor + Column means of the kernel matrix, subset for the candidates. Shape (bc,). + selection_kernel_col_means : Tensor + Column means of the kernel matrix, subset for the selected prototypes. Shape (|S|,). + candidates_selection_kernel : Tensor + Kernel matrix between the candidates and the selected prototypes. Shape (bc, |S|). + selection_selection_kernel : Tensor + Kernel matrix between the selected prototypes. Shape (|S|, |S|). + + Returns + ------- + objectives + Tensor that contains the computed objective values for each candidate. Shape (bc,). + objectives_weights + Tensor that contains the computed objective weights for each candidate. + Shape (bc, |S|+1). + """ + # pylint: disable=invalid-name + # construct the kernel matrix for (S ∪ c) for each candidate (S is the selection) + # (bc, |S| + 1, |S| + 1) + if candidates_selection_kernel is None: + # no selected prototypes yet, S = {} + # (bc, 1, 1) + K = candidates_kernel_diag[:, tf.newaxis, tf.newaxis] + else: + # repeat the selection-selection kernel for each candidate + # (bc, |S|, |S|) + selection_selection_kernel = tf.tile( + tf.expand_dims(selection_selection_kernel, 0), + [candidates_selection_kernel.shape[0], 1, 1] + ) + + # add candidates-selection kernel row to the selection-selection kernel matrix + # (bc, |S| + 1, |S|) + extended_selection_selection_kernel = tf.concat( + [ + selection_selection_kernel, + candidates_selection_kernel[:, tf.newaxis, :] + ], + axis=1 + ) + + # create the extended column for the candidates with the diagonal values + # (bc, |S| + 1) + extended_candidates_selection_kernel = tf.concat( + [ + candidates_selection_kernel, + candidates_kernel_diag[:, tf.newaxis] + ], + axis=1 + ) + + # add the extended column for the candidates to the extended selection-selection kernel + # (bc, |S| + 1, |S| + 1) + K = tf.concat( + [ + extended_selection_selection_kernel, + extended_candidates_selection_kernel[:, :, tf.newaxis], + ], + axis=2 + ) + + # (bc, |S|) - extended selected kernel col means + selection_kernel_col_means = tf.tile( + selection_kernel_col_means[tf.newaxis, :], + multiples=[candidates_kernel_col_means.shape[0], 1] + ) + + # (bc, |S| + 1) - mu_p + candidates_selection_kernel_col_means = tf.concat( + [ + selection_kernel_col_means, + candidates_kernel_col_means[:, tf.newaxis]], + axis=1 + ) + + # compute the optimal objective weights for each candidate in the batch + # (bc, |S| + 1, |S| + 1) - K^-1 + K_inv = tf.linalg.inv(K + ProtoGreedySearch.EPSILON * tf.eye(K.shape[-1])) + + # (bc, |S| + 1) - w* = K^-1 mu_p + objectives_weights = tf.einsum("bsp,bp->bs", K_inv, candidates_selection_kernel_col_means) + objectives_weights = tf.maximum(objectives_weights, 0) + + # (bc,) - (w*^T mu_p) + weights_mu_p = tf.einsum("bp,bp->b", + objectives_weights, candidates_selection_kernel_col_means) + + # (bc,) - (w*^T K w*) + weights_K_weights = tf.einsum("bs,bsp,bp->b", + objectives_weights, K, objectives_weights) + + # (bc,) - (w*^T mu_p) - (w*^T K w*) / 2 + objectives = weights_mu_p - 0.5 * weights_K_weights + + return objectives, objectives_weights + + def find_global_prototypes(self, nb_prototypes: int): + """ + Search for global prototypes and their corresponding weights. + Iteratively select the best prototype candidate and add it to the selection. + The selected candidate is the one with the highest objective function value. + + The indices, weights, and cases of the selected prototypes are stored in the object. + + Parameters + ---------- + nb_prototypes : int + Number of global prototypes to find. + """ + # pylint: disable=too-many-statements + assert 0 < nb_prototypes, "`nb_prototypes` should be between at least 1." + + # initialize variables with placeholders + # final prototypes variables + # (np, 2) - final prototypes indices + self.prototypes_indices = tf.Variable(tf.fill((nb_prototypes, 2), -1)) + # (np,) - final prototypes weights + self.prototypes_weights = tf.Variable(tf.zeros((nb_prototypes,), dtype=tf.float32)) + # (np, d) - final prototypes cases + self.prototypes = tf.Variable(tf.zeros((nb_prototypes, self.nb_features), dtype=tf.float32)) + + # kernel matrix variables + # (np, np) - kernel matrix between selected prototypes + selection_selection_kernel = tf.Variable(tf.zeros((nb_prototypes, nb_prototypes), + dtype=tf.float32)) + # (nb, b, np) - kernel matrix between samples and selected prototypes + samples_selection_kernel = tf.Variable(tf.zeros((*self.kernel_diag.shape, nb_prototypes))) + + # (nb, b) - mask encoding the selected prototypes + mask_of_selected = tf.Variable(tf.fill(self.kernel_diag.shape, False)) + + # (np,) - selected column means + selection_kernel_col_means = tf.Variable(tf.zeros((nb_prototypes,), dtype=tf.float32)) + + # iterate till we find all the prototypes + for nb_selected in range(nb_prototypes): + # initialize + best_objective = tf.constant(-np.inf, dtype=tf.float32) + + # iterate over the batches + for batch_index, cases in enumerate(self.cases_dataset): + # (b,) + candidates_batch_mask = tf.math.logical_not(mask_of_selected[batch_index]) + + # last batch, pad with False + if cases.shape[0] < self.batch_size: + candidates_batch_mask = tf.math.logical_and( + candidates_batch_mask, tf.range(self.batch_size) < cases.shape[0] + ) + + # no candidates in the batch skipping + if not tf.reduce_any(candidates_batch_mask): + continue + + # compute the kernel matrix between the last selected prototypes and the candidates + if nb_selected > 0: + # (b,) + batch_samples_last_selection_kernel = self.kernel_fn( + cases, last_selected + )[:, 0] + samples_selection_kernel[batch_index, :cases.shape[0], nb_selected - 1].assign( + batch_samples_last_selection_kernel + ) + + # (b, |S|) + batch_candidates_selection_kernel =\ + samples_selection_kernel[batch_index, :cases.shape[0], :nb_selected] + # (bc, |S|) + batch_candidates_selection_kernel = tf.boolean_mask( + tensor=batch_candidates_selection_kernel, + mask=candidates_batch_mask[:cases.shape[0]], + axis=0, + ) + + else: + batch_candidates_selection_kernel = None + + # extract kernel values for the batch + # (bc,) + batch_candidates_kernel_diag = self.kernel_diag[batch_index][candidates_batch_mask] + # (bc,) + batch_candidates_kernel_col_means =\ + self.kernel_col_means[batch_index][candidates_batch_mask] + + # compute the objectives for the batch + # (bc,), (bc, |S| + 1) + objectives, objectives_weights = self._compute_batch_objectives( + batch_candidates_kernel_diag, + batch_candidates_kernel_col_means, + selection_kernel_col_means[:nb_selected], + batch_candidates_selection_kernel, + selection_selection_kernel[:nb_selected, :nb_selected], + ) + + # select the best candidate in the batch + objectives_argmax = tf.argmax(objectives) + batch_best_objective = tf.gather(objectives, objectives_argmax) + + if batch_best_objective > best_objective: + best_objective = batch_best_objective + best_batch_index = batch_index + best_index = tf.range(self.batch_size)[candidates_batch_mask][objectives_argmax] + best_case = cases[best_index] + if objectives_weights is not None: + best_weights = objectives_weights[objectives_argmax] + + # update the selected prototypes + # pylint: disable=unknown-option-value + # pylint: disable=possibly-used-before-assignment + last_selected = best_case[tf.newaxis, :] + mask_of_selected[best_batch_index, best_index].assign(True) + self.prototypes_indices[nb_selected].assign([best_batch_index, best_index]) + self.prototypes[nb_selected].assign(best_case) + + # update selected-selected kernel matrix (S = S ∪ c) + selection_selection_kernel[nb_selected, nb_selected].assign( + self.kernel_diag[best_batch_index, best_index] + ) + if nb_selected > 0: + # (|S|,) + new_selected = samples_selection_kernel[best_batch_index, best_index, :nb_selected] + + # add the new row and column to the selected-selected kernel matrix + selection_selection_kernel[nb_selected, :nb_selected].assign( + new_selected + ) + selection_selection_kernel[:nb_selected, nb_selected].assign( + new_selected + ) + + # update the selected column means + selection_kernel_col_means[nb_selected].assign( + self.kernel_col_means[best_batch_index, best_index] + ) + + # update the selected weights + if not hasattr(self, "_update_selection_weights"): + # pylint: disable=used-before-assignment + self.prototypes_weights[:nb_selected + 1].assign(best_weights) + else: + self._update_selection_weights( + selection_kernel_col_means[:nb_selected + 1], + selection_selection_kernel[:nb_selected + 1, :nb_selected + 1], + self.kernel_diag[best_batch_index, best_index], + best_objective, + ) + + # normalize the weights + self.prototypes_weights.assign( + self.prototypes_weights / tf.reduce_sum(self.prototypes_weights) + ) + + # convert variables to tensors + self.prototypes_indices = tf.convert_to_tensor(self.prototypes_indices) + self.prototypes = tf.convert_to_tensor(self.prototypes) + self.prototypes_weights = tf.convert_to_tensor(self.prototypes_weights) + + assert tf.reduce_sum(tf.cast(mask_of_selected, tf.int32)) == nb_prototypes,\ + "The number of prototypes found is not equal to the number of prototypes expected." diff --git a/xplique/example_based/semifactuals.py b/xplique/example_based/semifactuals.py new file mode 100644 index 00000000..1df6f0bb --- /dev/null +++ b/xplique/example_based/semifactuals.py @@ -0,0 +1,229 @@ +""" +Implementation of semi factuals methods for classification tasks. +""" +import numpy as np +import tensorflow as tf + +from ..types import Callable, List, Optional, Union, Dict, DatasetOrTensor + +from .datasets_operations.tf_dataset_operations import dataset_gather + +from .base_example_method import BaseExampleMethod +from .search_methods import KLEORSimMissSearch, KLEORGlobalSimSearch +from .projections import Projection + +from .search_methods.base import _sanitize_returns + + +class KLEORBase(BaseExampleMethod): + """ + Base class for KLEOR methods. KLEOR methods search Semi-Factuals examples. + In those methods, one should first retrieve the Nearest Unlike Neighbor (NUN) + which is the closest example to the query that has a different prediction than the query. + Then, the method search for the K-Nearest Neighbors (KNN) of the NUN + that have the same prediction as the query. + + All the searches are done in a projection space where distances are relevant for the model. + The projection space is defined by the `projection` method. + + Depending on the KLEOR method some additional condition for the search are added. + See the specific KLEOR method for more details. + + Parameters + ---------- + cases_dataset + The dataset used to train the model, examples are extracted from this dataset. + All datasets (cases, labels, and targets) should be of the same type. + Supported types are: `tf.data.Dataset`, `torch.utils.data.DataLoader`, + `tf.Tensor`, `np.ndarray`, `torch.Tensor`. + For datasets with multiple columns, the first column is assumed to be the cases. + While the second column is assumed to be the labels, and the third the targets. + Warning: datasets tend to reshuffle at each iteration, ensure the datasets are + not reshuffle as we use index in the dataset. + targets_dataset + Targets associated with the `cases_dataset` for dataset projection, + oftentimes the one-hot encoding of a model's predictions. See `projection` for detail. + They are also used to know the prediction of the model on the dataset. + It should have the same type as `cases_dataset`. + labels_dataset + Labels associated with the examples in the `cases_dataset`. + It should have the same type as `cases_dataset`. + k + The number of examples to retrieve per input. + projection + Projection or Callable that project samples from the input space to the search space. + The search space should be a space where distances are relevant for the model. + It should not be `None`, otherwise, the model is not involved thus not explained. + + Example of Callable: + ``` + def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None): + ''' + Example of projection, + inputs are the elements to project. + targets are optional parameters to orientated the projection. + ''' + projected_inputs = # do some magic on inputs, it should use the model. + return projected_inputs + ``` + case_returns + String or list of string with the elements to return in `self.explain()`. + See the base class returns property for more details. + batch_size + Number of samples treated simultaneously for projection and search. + Ignored if `cases_dataset` is a batched `tf.data.Dataset` or + a batched `torch.utils.data.DataLoader` is provided. + distance + Distance for the FilterKNN search method. + Distance function for examples search. It can be an integer, a string in + {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable, + by default "euclidean". + """ + # pylint: disable=duplicate-code + + _returns_possibilities = [ + "examples", "weights", "distances", "labels", "include_inputs", + "nuns", "nuns_indices", "dist_to_nuns", "nuns_labels" + ] + + def __init__( + self, + cases_dataset: DatasetOrTensor, + targets_dataset: DatasetOrTensor, + labels_dataset: Optional[DatasetOrTensor] = None, + k: int = 1, + projection: Union[Projection, Callable] = None, + case_returns: Union[List[str], str] = "examples", + batch_size: Optional[int] = None, + distance: Union[int, str, Callable] = "euclidean", + ): + + super().__init__( + cases_dataset=cases_dataset, + labels_dataset=labels_dataset, + targets_dataset=targets_dataset, + k=k, + projection=projection, + case_returns=case_returns, + batch_size=batch_size, + ) + + # initiate search_method + self.search_method = self.search_method_class( + cases_dataset=self.projected_cases_dataset, + targets_dataset=self.targets_dataset, + k=self.k, + search_returns=self._search_returns, + batch_size=self.batch_size, + distance=distance, + ) + + @property + def returns(self) -> Union[List[str], str]: + """Override the Base class returns' parameter.""" + return self._returns + + @returns.setter + def returns(self, returns: Union[List[str], str]): + """ + Set the returns parameter. The returns parameter is a string + or a list of string with the elements to return in `self.explain()`. + Possibly returned elements are defined with `_returns_possibilities` static attribute. + """ + default = "examples" + self._returns = _sanitize_returns(returns, self._returns_possibilities, default) + self._search_returns = ["indices", "distances"] + + if "nuns" in self._returns: + self._search_returns.append("nuns_indices") + elif "nuns_indices" in self._returns: + self._search_returns.append("nuns_indices") + elif "nuns_labels" in self._returns: + self._search_returns.append("nuns_indices") + + if "dist_to_nuns" in self._returns: + self._search_returns.append("dist_to_nuns") + + try: + self.search_method.returns = self._search_returns + except AttributeError: + pass + + def format_search_output( + self, + search_output: Dict[str, tf.Tensor], + inputs: Union[tf.Tensor, np.ndarray], + ): + """ + Format the output of the `search_method` to match the expected returns in `self.returns`. + + Parameters + ---------- + search_output + Dictionary with the required outputs from the `search_method`. + inputs + Tensor or Array. Input samples to be explained. + Expected shape among (N, W), (N, T, W), (N, W, H, C). + + Returns + ------- + return_dict + Dictionary with listed elements in `self.returns`. + The elements that can be returned are defined with the `_returns_possibilities` + static attribute of the class. + """ + return_dict = super().format_search_output(search_output, inputs) + if "nuns" in self.returns: + return_dict["nuns"] = dataset_gather(self.cases_dataset, search_output["nuns_indices"]) + if "nuns_labels" in self.returns: + return_dict["nuns_labels"] = dataset_gather(self.labels_dataset, + search_output["nuns_indices"]) + if "nuns_indices" in self.returns: + return_dict["nuns_indices"] = search_output["nuns_indices"] + if "dist_to_nuns" in self.returns: + return_dict["dist_to_nuns"] = search_output["dist_to_nuns"] + return return_dict + + +class KLEORSimMiss(KLEORBase): + """ + The KLEORSimMiss method search for Semi-Factuals examples + by searching for the Nearest Unlike Neighbor (NUN) of the query. + The NUN is the closest example to the query that has a different prediction than the query. + Then, the method search for the K-Nearest Neighbors (KNN) of the NUN + that have the same prediction as the query. + + The search is done in a projection space where distances are relevant for the model. + The projection space is defined by the `projection` method. + """ + @property + def search_method_class(self): + """ + This property defines the search method class to use for the search. + In this case, it is the KLEORSimMissSearch. + """ + return KLEORSimMissSearch + +class KLEORGlobalSim(KLEORBase): + """ + The KLEORGlobalSim method search for Semi-Factuals examples + by searching for the Nearest Unlike Neighbor (NUN) of the query. + The NUN is the closest example to the query that has a different prediction than the query. + Then, the method search for the K-Nearest Neighbors (KNN) of the NUN + that have the same prediction as the query. + + In addition, for a SF candidate to be considered, + the SF should be closer to the query than the NUN in the projection space + (i.e. the SF should be 'between' the input and its NUN). + This condition is added to the search. + + The search is done in a projection space where distances are relevant for the model. + The projection space is defined by the `projection` method. + """ + @property + def search_method_class(self): + """ + This property defines the search method class to use for the search. In this case, it is the + KLEORGlobalSimSearch. + """ + return KLEORGlobalSimSearch diff --git a/xplique/example_based/similar_examples.py b/xplique/example_based/similar_examples.py new file mode 100644 index 00000000..8d0fb756 --- /dev/null +++ b/xplique/example_based/similar_examples.py @@ -0,0 +1,220 @@ +""" +Base model for example-based +""" +import tensorflow as tf + +from ..attributions.base import BlackBoxExplainer +from ..types import Callable, List, Optional, Type, Union, DatasetOrTensor + +from .search_methods import KNN, BaseSearchMethod, ORDER +from .projections import Projection, AttributionProjection, HadamardProjection +from .base_example_method import BaseExampleMethod + + +class SimilarExamples(BaseExampleMethod): + """ + Class for similar example-based method. This class allows to search the k Nearest Neighbor + of an input in the projected space (defined by the projection method) + using the distance defined by the distance method provided. + + Parameters + ---------- + cases_dataset + The dataset used to train the model, examples are extracted from this dataset. + All datasets (cases, labels, and targets) should be of the same type. + Supported types are: `tf.data.Dataset`, `torch.utils.data.DataLoader`, + `tf.Tensor`, `np.ndarray`, `torch.Tensor`. + For datasets with multiple columns, the first column is assumed to be the cases. + While the second column is assumed to be the labels, and the third the targets. + Warning: datasets tend to reshuffle at each iteration, ensure the datasets are + not reshuffle as we use index in the dataset. + labels_dataset + Labels associated with the examples in the `cases_dataset`. + It should have the same type as `cases_dataset`. + targets_dataset + Targets associated with the `cases_dataset` for dataset projection, + oftentimes the one-hot encoding of a model's predictions. See `projection` for detail. + It should have the same type as `cases_dataset`. + It is not be necessary for all projections. + Furthermore, projections which requires it compute it internally by default. + k + The number of examples to retrieve per input. + projection + Projection or Callable that project samples from the input space to the search space. + The search space should be a space where distances are relevant for the model. + It should not be `None`, otherwise, the model is not involved thus not explained. + + Example of Callable: + ``` + def custom_projection(inputs: tf.Tensor, np.ndarray, targets: tf.Tensor, np.ndarray = None): + ''' + Example of projection, + inputs are the elements to project. + targets are optional parameters to orientated the projection. + ''' + projected_inputs = # do some magic on inputs, it should use the model. + return projected_inputs + ``` + case_returns + String or list of string with the elements to return in `self.explain()`. + See the base class returns property for more details. + batch_size + Number of samples treated simultaneously for projection and search. + Ignored if `cases_dataset` is a batched `tf.data.Dataset` or + a batched `torch.utils.data.DataLoader` is provided. + distance + Distance for the knn search method. It can be an integer, a string in + {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable, + by default "euclidean". + """ + def __init__( + self, + cases_dataset: DatasetOrTensor, + labels_dataset: Optional[DatasetOrTensor] = None, + targets_dataset: Optional[DatasetOrTensor] = None, + k: int = 1, + projection: Union[Projection, Callable] = None, + case_returns: Union[List[str], str] = "examples", + batch_size: Optional[int] = None, + distance: Union[int, str, Callable] = "euclidean", + ): + super().__init__( + cases_dataset=cases_dataset, + labels_dataset=labels_dataset, + targets_dataset=targets_dataset, + k=k, + projection=projection, + case_returns=case_returns, + batch_size=batch_size, + ) + + # initiate search_method + self.search_method = self.search_method_class( + cases_dataset=self.projected_cases_dataset, + search_returns=self._search_returns, + k=self.k, + batch_size=self.batch_size, + distance=distance, + order=ORDER.ASCENDING, + ) + + @property + def search_method_class(self) -> Type[BaseSearchMethod]: + return KNN + + +class Cole(SimilarExamples): + """ + Cole is a similar examples method that gives the most similar examples + to a query in some specific projection space. + Cole uses the model to build a search space so that distances are meaningful for the model. + It uses attribution methods to weight inputs. + Those attributions may be computed in the latent space for high-dimensional data like images. + + It is an implementation of a method proposed by Kenny et Keane in 2019, + Twin-Systems to Explain Artificial Neural Networks using Case-Based Reasoning: + https://researchrepository.ucd.ie/handle/10197/11064 + + Parameters + ---------- + cases_dataset + The dataset used to train the model, examples are extracted from this dataset. + All datasets (cases, labels, and targets) should be of the same type. + Supported types are: `tf.data.Dataset`, `torch.utils.data.DataLoader`, + `tf.Tensor`, `np.ndarray`, `torch.Tensor`. + For datasets with multiple columns, the first column is assumed to be the cases. + While the second column is assumed to be the labels, and the third the targets. + Warning: datasets tend to reshuffle at each iteration, ensure the datasets are + not reshuffle as we use index in the dataset. + labels_dataset + Labels associated with the examples in the `cases_dataset`. + It should have the same type as `cases_dataset`. + targets_dataset + Targets associated with the `cases_dataset` for dataset projection, + oftentimes the one-hot encoding of a model's predictions. See `projection` for detail. + It should have the same type as `cases_dataset`. + It is not be necessary for all projections. + Furthermore, projections which requires it compute it internally by default. + k + The number of examples to retrieve per input. + distance + Distance function for examples search. It can be an integer, a string in + {"manhattan", "euclidean", "cosine", "chebyshev", "inf"}, or a Callable, + by default "euclidean". + case_returns + String or list of string with the elements to return in `self.explain()`. + See the base class returns property for details. + batch_size + Number of samples treated simultaneously for projection and search. + Ignored if `cases_dataset` is a batched `tf.data.Dataset` or + a batched `torch.utils.data.DataLoader` is provided. + latent_layer + Layer used to split the model, the first part will be used for projection and + the second to compute the attributions. By default, the model is not split. + For such split, the `model` should be a `tf.keras.Model`. + + If an `int` is provided it will be interpreted as a layer index. + If a `string` is provided it will look for the layer name. + + The method as described in the paper apply the separation on the last convolutional layer. + To do so, the `"last_conv"` parameter will extract it. + Otherwise, `-1` could be used for the last layer before softmax. + attribution_method + Class of the attribution method to use for projection. + It should inherit from `xplique.attributions.base.BlackBoxExplainer`. + It can also be `"gradient"` to make the hadamard product between with the gradient. + It was deemed the best method in the original paper, and we optimized it for speed. + By default, it is set to `"gradient"`. + attribution_kwargs + Parameters to be passed for the construction of the `attribution_method`. + """ + def __init__( + self, + cases_dataset: DatasetOrTensor, + model: Union[tf.keras.Model, 'torch.nn.Module'], + labels_dataset: Optional[DatasetOrTensor] = None, + targets_dataset: Optional[DatasetOrTensor] = None, + k: int = 1, + distance: Union[str, Callable] = "euclidean", + case_returns: Optional[Union[List[str], str]] = "examples", + batch_size: Optional[int] = None, + latent_layer: Optional[Union[str, int]] = None, + attribution_method: Union[str, Type[BlackBoxExplainer]] = "gradient", + **attribution_kwargs, + ): + assert targets_dataset is not None + + # build the corresponding projection + if isinstance(attribution_method, str) and attribution_method.lower() == "gradient": + + operator = attribution_kwargs.get("operator", None) + + projection = HadamardProjection( + model=model, + latent_layer=latent_layer, + operator=operator, + ) + elif issubclass(attribution_method, BlackBoxExplainer): + # build attribution projection + projection = AttributionProjection( + model=model, + attribution_method=attribution_method, + latent_layer=latent_layer, + **attribution_kwargs, + ) + else: + raise ValueError( + "`attribution_method` should be 'gradient' or a subclass of BlackBoxExplainer, " +\ + f"not {attribution_method}" + ) + + super().__init__( + cases_dataset=cases_dataset, + targets_dataset=targets_dataset, + labels_dataset=labels_dataset, + projection=projection, + k=k, + case_returns=case_returns, + batch_size=batch_size, + distance=distance, + ) diff --git a/xplique/plots/__init__.py b/xplique/plots/__init__.py index 12e25eae..c7037f6a 100644 --- a/xplique/plots/__init__.py +++ b/xplique/plots/__init__.py @@ -1,6 +1,6 @@ """ Utility functions to visualize explanations """ -from .image import plot_attributions, plot_attribution, plot_maco +from .image import plot_attributions, plot_attribution, plot_maco, plot_examples from .tabular import plot_feature_impact, plot_mean_feature_impact, summary_plot_tabular from .timeseries import plot_timeseries_attributions diff --git a/xplique/plots/image.py b/xplique/plots/image.py index ca69b87d..aafccfc6 100644 --- a/xplique/plots/image.py +++ b/xplique/plots/image.py @@ -171,7 +171,7 @@ def plot_attributions( cols Number of columns. img_size - Size of each subplots (in inch), considering we keep aspect ratio + Size of each subplots (in inch), considering we keep aspect ratio. plot_kwargs Additional parameters passed to `plt.imshow()`. """ @@ -230,3 +230,102 @@ def plot_maco(image, alpha, percentile_image=1.0, percentile_alpha=80): plt.imshow(np.concatenate([image, alpha], -1)) plt.axis('off') + + +def plot_examples( + examples: np.ndarray, + distances: float = None, + labels: np.ndarray = None, + test_labels: np.ndarray = None, + predicted_labels: np.ndarray = None, + img_size: float = 2., +): + """ + This function is for image data, it show the returns of the explain function. + + Parameters + --------- + examples + Represente the k nearest neighbours of the input. (n, k+1, h, w, c) + distances + Distance between input data and examples. + labels + Labels of the examples. + labels_test + Corresponding to labels of the dataset test. + predicted_labels + Predicted labels of the examples. + img_size: + Size of each subplots (in inch), considering we keep aspect ratio + """ + # pylint: disable=too-many-arguments + if distances is not None: + assert examples.shape[0] == distances.shape[0],\ + "Number of samples treated should match between examples and distances." + assert examples.shape[1] == distances.shape[1] + 1,\ + "Number of distances for each input must correspond to the number of examples -1." + if labels is not None: + assert examples.shape[0] == labels.shape[0],\ + "Number of samples treated should match between examples and labels." + assert examples.shape[1] == labels.shape[1] + 1,\ + "Number of labels for each input must correspond to the number of examples -1." + + # number of rows depends if weights are provided + rows_by_input = 1 + rows = rows_by_input * examples.shape[0] + cols = examples.shape[1] + # get width and height of our images + l_width, l_height = examples.shape[2:4] + + # define the figure margin, width, height in inch + margin = 0.3 + spacing = 0.3 + figwidth = cols * img_size + (cols-1) * spacing + 2 * margin + figheight = rows * img_size * l_height/l_width + (rows-1) * spacing + 2 * margin + + left = margin/figwidth + bottom = margin/figheight + + fig = plt.figure() + fig.set_size_inches(figwidth, figheight) + + fig.subplots_adjust( + left = left, + bottom = bottom, + right = 1.-left, + top = 1.-bottom, + wspace = spacing/img_size, + hspace= spacing/img_size * l_width/l_height + ) + + # configure the grid to show all results + plt.rcParams["figure.autolayout"] = True + plt.rcParams["figure.figsize"] = [3 * examples.shape[1], 4] + + # loop to organize and show all results + for i in range(examples.shape[0]): + for k in range(examples.shape[1]): + plt.subplot(rows, cols, rows_by_input * i * cols + k + 1) + + # set title + if k == 0: + title = "Original image" + title += f"\nGround Truth: {test_labels[i]}" if test_labels is not None else "" + title += f"\nPrediction: {predicted_labels[i, k]}"\ + if predicted_labels is not None else "" + else: + title = f"Example {k}" + title += f"\nGround Truth: {labels[i, k-1]}" if labels is not None else "" + title += f"\nPrediction: {predicted_labels[i, k]}"\ + if predicted_labels is not None else "" + title += f"\nDistance: {distances[i, k-1]:.4f}" if distances is not None else "" + plt.title(title) + + # plot image + img = _normalize(examples[i, k]) + if img.shape[-1] == 1: + plt.imshow(img[:,:,0], cmap="gray") + else: + plt.imshow(img) + plt.axis("off") + fig.tight_layout() diff --git a/xplique/types/__init__.py b/xplique/types/__init__.py index 52cca202..1f04d319 100644 --- a/xplique/types/__init__.py +++ b/xplique/types/__init__.py @@ -2,5 +2,5 @@ Typing module """ -from typing import Union, Tuple, List, Callable, Dict, Optional, Any -from .custom_type import OperatorSignature +from typing import Union, Tuple, List, Callable, Dict, Optional, Any, Type +from .custom_type import OperatorSignature, DatasetOrTensor diff --git a/xplique/types/custom_type.py b/xplique/types/custom_type.py index 0562621a..4a27e8ce 100644 --- a/xplique/types/custom_type.py +++ b/xplique/types/custom_type.py @@ -1,7 +1,13 @@ """ Module for custom types or signature """ -from typing import Callable +from typing import Callable, TypeVar + +import numpy as np import tensorflow as tf OperatorSignature = Callable[[tf.keras.Model, tf.Tensor, tf.Tensor], float] + +DatasetOrTensor = TypeVar("DatasetOrTensor", + tf.Tensor, np.ndarray, "torch.Tensor", + tf.data.Dataset, "torch.utils.data.DataLoader")