diff --git a/tests/concepts/test_craft_tf.py b/tests/concepts/test_craft_tf.py
index 7d940c0a..d1389273 100644
--- a/tests/concepts/test_craft_tf.py
+++ b/tests/concepts/test_craft_tf.py
@@ -1,15 +1,12 @@
 import numpy as np
 import tensorflow as tf
-import random
 import pytest
-import os
-from tensorflow.keras.models import Sequential
-from tensorflow.keras.layers import Dense, Conv2D, Activation, Flatten, Input
-from tensorflow.keras.optimizers import Adam
+
+from tensorflow.keras.layers import Input
 
 from xplique.concepts import CraftTf as Craft
-from ..utils import generate_data, generate_model, generate_txt_images_data
-from ..utils import download_file
+
+from ..utils import generate_data, generate_model
 
 
 def test_shape():
@@ -100,172 +97,3 @@ def test_wrong_layers():
                       number_of_concepts = number_of_concepts,
                       patch_size = patch_size,
                       batch_size = 64)
-
-def test_classifier():
-    """ Check the Craft results on a small fake dataset """
-
-    input_shape = (64, 64, 3)
-    nb_labels = 3
-    nb_samples = 200
-
-    # Create a dataset of 'ABC', 'BCD', 'CDE' images
-    x, y, nb_samples, _ = generate_txt_images_data(input_shape, nb_labels, nb_samples)
-
-    # train a small classifier on the dataset
-    def create_classifier_model(input_shape=(64, 64, 3), output_shape=10):
-        model = Sequential()
-        model.add(Input(shape=input_shape))
-        model.add(Conv2D(6, kernel_size=(2, 2)))
-        model.add(Activation('relu'))
-        model.add(Conv2D(6, kernel_size=(2, 2)))
-        model.add(Activation('relu'))
-        model.add(Conv2D(6, kernel_size=(2, 2)))
-        model.add(Activation('relu', name='relu'))
-        model.add(Flatten())
-        model.add(Dense(output_shape))
-        model.add(Activation('softmax'))
-        opt = Adam(learning_rate=0.005)
-        model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
-
-        return model
-
-    model = create_classifier_model(input_shape, nb_labels)
-
-    tf.random.set_seed(0)
-    np.random.seed(0)
-    random.seed(0)
-
-    # Retrieve checkpoints
-    checkpoint_path = "tests/concepts/checkpoints/classifier_test_craft_tf.ckpt"
-    if not os.path.exists(f"{checkpoint_path}.index"):
-        os.makedirs("tests/concepts/checkpoints/", exist_ok=True)
-        identifier = "1NLA7x2EpElzEEmyvFQhD6VS6bMwS_bCs"
-        download_file(identifier, f"{checkpoint_path}.index")
-
-        identifier = "1wDi-y9b-3I_a-ZtqRlfuib-D7Ox4j8pX"
-        download_file(identifier, f"{checkpoint_path}.data-00000-of-00001")
-
-    model.load_weights(checkpoint_path)
-
-    acc = np.sum(np.argmax(model(x), axis=1) == np.argmax(y, axis=1)) / nb_samples
-    assert acc == 1.0
-
-    # cut the model in two parts (as explained in the paper)
-    # first part is g(.) our 'input_to_latent' model, second part is h(.) our 'latent_to_logit' model
-    cut_layer = model.get_layer('relu')
-    g = tf.keras.Model(model.inputs, cut_layer.output)
-    h = tf.keras.Model(Input(tensor=cut_layer.output), model.outputs)
-
-    assert np.all(g(x) >= 0.0)
-
-    # Init Craft on the full dataset
-    craft = Craft(input_to_latent_model = g,
-                  latent_to_logit_model = h,
-                  number_of_concepts = 3,
-                  patch_size = 12,
-                  batch_size = 32)
-
-    # Expected best crop for class 0 (ABC) is AB
-    AB_str = """
-        0 0 0 0 0 0 0 0 0 0 0 0
-        0 0 0 0 0 0 0 0 0 0 0 0
-        0 0 0 0 1 1 1 1 1 1 1 1
-        0 0 0 0 0 0 1 0 0 0 1 1
-        1 0 0 0 0 0 1 0 0 0 0 1
-        1 0 0 0 0 0 1 0 0 0 1 1
-        1 0 0 0 0 0 1 1 1 1 1 1
-        1 1 0 0 0 0 1 0 0 0 0 1
-        0 1 0 0 0 0 1 0 0 0 0 0
-        0 1 1 0 0 0 1 0 0 0 0 1
-        1 1 1 1 1 1 1 1 1 1 1 1
-        0 0 0 0 0 0 0 0 0 0 0 0
-        """
-    AB = np.genfromtxt(AB_str.splitlines())
-
-    # Expected best crop for class 1 (BCD) is BC
-    BC_str = """
-        0 0 0 0 0 0 0 0 0 0 0 0
-        0 0 0 0 0 0 0 0 0 0 0 0
-        0 0 0 0 0 0 0 0 0 0 0 0
-        1 1 1 1 1 1 0 0 0 0 1 1
-        1 0 0 0 1 1 0 0 0 1 1 0
-        1 0 0 0 0 1 0 0 0 1 0 0
-        1 0 0 0 1 1 0 0 0 1 0 0
-        1 1 1 1 1 1 0 0 0 1 0 0
-        1 0 0 0 0 1 1 0 0 1 0 0
-        1 0 0 0 0 0 1 0 0 1 0 0
-        1 0 0 0 0 1 1 0 0 1 1 0
-        1 1 1 1 1 1 0 0 0 0 1 1
-        """
-    BC = np.genfromtxt(BC_str.splitlines())
-
-    # Expected best crop for class 2 (CDE) is DE
-    DE_str = """
-        0 0 0 0 0 0 0 0 0 0 0 0
-        1 0 0 1 1 1 1 1 1 1 1 0
-        1 1 0 0 0 1 0 0 0 0 1 0
-        0 1 0 0 0 1 0 0 0 0 1 0
-        0 1 1 0 0 1 0 0 1 0 0 0
-        0 1 1 0 0 1 1 1 1 0 0 0
-        0 1 1 0 0 1 0 0 1 0 0 0
-        0 1 0 0 0 1 0 0 0 0 1 1
-        1 1 0 0 0 1 0 0 0 0 1 1
-        1 0 0 1 1 1 1 1 1 1 1 1
-        0 0 0 0 0 0 0 0 0 0 0 0
-        0 0 0 0 0 0 0 0 0 0 0 0
-        """
-    DE = np.genfromtxt(DE_str.splitlines())
-
-    DE2_str = """
-        0 0 0 0 0 0 0 0 0 0 0 0
-        1 1 1 1 0 0 1 1 1 1 1 1
-        0 0 1 1 1 0 0 0 1 0 0 0
-        0 0 0 0 1 0 0 0 1 0 0 0
-        0 0 0 0 1 1 0 0 1 0 0 1
-        0 0 0 0 1 1 0 0 1 1 1 1
-        0 0 0 0 1 1 0 0 1 0 0 1
-        0 0 0 0 1 0 0 0 1 0 0 0
-        0 0 1 1 1 0 0 0 1 0 0 0
-        1 1 1 1 0 0 1 1 1 1 1 1
-        0 0 0 0 0 0 0 0 0 0 0 0
-        0 0 0 0 0 0 0 0 0 0 0 0
-        """
-    DE2 = np.genfromtxt(DE2_str.splitlines())
-
-    expected_best_crops = [[AB], [BC], [DE, DE2]]
-    expected_best_crops_names = ['AB', 'BC', 'DE']
-
-    # Run 3 Craft studies on each class, and in each case check if the best crop is the expected one
-    class_check = [False, False, False]
-    for class_id in range(3):
-        # Focus on class class_id
-        # Selecting subset for class {class_id} : {labels_str[class_id]}
-        x_subset = x[np.argmax(y, axis=1)==class_id,:,:,:]
-
-        # fit craft on the selected class
-        crops, crops_u, w = craft.fit(x_subset, class_id)
-
-        # compute importances
-        importances = craft.estimate_importance()
-        assert importances[0] > 0.8
-
-        # find the best crop and compare it to the expected best crop
-        most_important_concepts = np.argsort(importances)[::-1]
-
-        # Find the best crop for the most important concept
-        c_id = most_important_concepts[0]
-        best_crops_ids = np.argsort(crops_u[:, c_id])[::-1]
-        best_crop = np.array(crops)[best_crops_ids[0]]
-
-        # Compare this best crop to the expectation
-        predicted_best_crop = np.where(best_crop.sum(axis=2) > 0.25, 1, 0)
-        for expected_best_crop in expected_best_crops[class_id]:
-            expected_best_crop = expected_best_crop.astype(np.uint8)
-
-            comparison = predicted_best_crop == expected_best_crop
-            acc = np.sum(comparison) / len(comparison.ravel())
-            check = acc > 0.9
-            if check:
-                class_check[class_id] = True
-                break
-    assert np.all(class_check)
diff --git a/tests/concepts/test_craft_torch.py b/tests/concepts/test_craft_torch.py
index 0ef214e7..ef46843c 100644
--- a/tests/concepts/test_craft_torch.py
+++ b/tests/concepts/test_craft_torch.py
@@ -1,14 +1,10 @@
 import numpy as np
-import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import pytest
-import random
 
 from xplique.concepts import CraftTorch as Craft
-from ..utils import generate_txt_images_data
-from ..utils import download_file
 
 def generate_torch_data(x_shape=(3, 32, 32), num_labels=10, samples=100):
     x = torch.tensor(np.random.rand(samples, *x_shape).astype(np.float32))
@@ -133,177 +129,3 @@ def test_wrong_layers():
                       number_of_concepts = number_of_concepts,
                       patch_size = patch_size,
                       batch_size = 64)
-
-def test_classifier():
-    """ Check the Craft results on a small fake dataset """
-
-    input_shape = (64, 64, 3)
-    nb_labels = 3
-    nb_samples = 200
-
-    torch.manual_seed(0)
-    torch.use_deterministic_algorithms(True)
-    random.seed(0)
-    np.random.seed(0)
-
-    # Create a dataset of 'ABC', 'BCD', 'CDE' images
-    x, y, nb_samples, _ = generate_txt_images_data(input_shape, nb_labels, nb_samples)
-    x = np.moveaxis(x, -1, 1)  # reorder the axis to match torch format
-    x, y = torch.Tensor(x), torch.Tensor(y)
-
-    # train a small classifier on the dataset
-    def create_torch_classifier_model(input_shape=(3, 64, 64), output_shape=10):
-        flatten_size = 6*(input_shape[1]-3)*(input_shape[2]-3)
-        model = nn.Sequential(
-            nn.Conv2d(3, 6, kernel_size=(2, 2)),
-            nn.ReLU(),
-            nn.Conv2d(6, 6, kernel_size=(2, 2)),
-            nn.ReLU(),
-            nn.Conv2d(6, 6, kernel_size=(2, 2)),
-            nn.ReLU(),
-            nn.Flatten(1, -1),
-            # nn.Dropout(p=0.2),
-            nn.Linear(flatten_size, output_shape))
-        for layer in model:
-            if isinstance(layer, nn.Conv2d):
-                nn.init.kaiming_uniform_(layer.weight, mode='fan_in', nonlinearity='relu')
-                layer.bias.data.fill_(0.01)
-            elif isinstance(layer, nn.Linear):
-                nn.init.xavier_normal_(layer.weight)
-                layer.bias.data.fill_(0.01)
-        return model
-
-    model = create_torch_classifier_model((input_shape[-1], *input_shape[0:2]), nb_labels)
-
-    # Retrieve checkpoints
-    checkpoint_path = "tests/concepts/checkpoints/classifier_test_craft_torch.ckpt"
-    if not os.path.exists(checkpoint_path):
-        os.makedirs("tests/concepts/checkpoints/", exist_ok=True)
-        identifier = "1vz6hMibMEN6_t9yAY9SS4iaMY8G8aAPQ"
-        download_file(identifier, checkpoint_path)
-    model.load_state_dict(torch.load(checkpoint_path))
-
-    # check accuracy
-    model.eval()
-    acc = torch.sum(torch.argmax(model(x), axis=1) == torch.argmax(y, axis=1))/len(y)
-    assert acc > 0.9
-
-    # cut pytorch model
-    g = nn.Sequential(*(list(model.children())[:6]))  # input to penultimate layer
-    h = nn.Sequential(*(list(model.children())[6:]))  # penultimate layer to logits
-    assert torch.all(g(x) >= 0.0)
-
-    # Init Craft on the full dataset
-    craft = Craft(input_to_latent_model = g,
-                  latent_to_logit_model = h,
-                  number_of_concepts = 3,
-                  patch_size = 12,
-                  batch_size = 32,
-                  device='cpu')
-
-    # Expected best crop for class 0 (ABC) is AB
-    AB_str = """
-        0 0 0 0 0 0 0 0 0 0 0 0
-        0 0 0 0 0 0 0 0 0 0 0 0
-        0 1 1 1 1 1 1 1 1 0 0 0
-        0 0 0 1 0 0 0 1 1 0 0 1
-        0 0 0 1 0 0 0 0 1 0 0 1
-        0 0 0 1 0 0 0 1 1 0 0 1
-        0 0 0 1 1 1 1 1 1 0 0 1
-        0 0 0 1 0 0 0 0 1 1 0 1
-        0 0 0 1 0 0 0 0 0 1 0 1
-        0 0 0 1 0 0 0 0 1 1 0 1
-        1 1 1 1 1 1 1 1 1 0 0 0
-        0 0 0 0 0 0 0 0 0 0 0 0
-        """
-    AB = np.genfromtxt(AB_str.splitlines())
-
-    # Expected best crop for class 1 (BCD) is BC
-    BC_str = """
-        1 1 1 1 1 1 0 0 0 0 1 1
-        1 0 0 0 1 1 0 0 0 1 1 0
-        1 0 0 0 0 1 0 0 0 1 0 0
-        1 0 0 0 1 1 0 0 0 1 0 0
-        1 1 1 1 1 1 0 0 0 1 0 0
-        1 0 0 0 0 1 1 0 0 1 0 0
-        1 0 0 0 0 0 1 0 0 1 0 0
-        1 0 0 0 0 1 1 0 0 1 1 0
-        1 1 1 1 1 1 0 0 0 0 1 1
-        0 0 0 0 0 0 0 0 0 0 0 0
-        0 0 0 0 0 0 0 0 0 0 0 0
-        0 0 0 0 0 0 0 0 0 0 0 0
-        """
-    BC = np.genfromtxt(BC_str.splitlines())
-
-    # Expected best crop for class 2 (CDE) is DE
-    DE_str = """
-        0 0 0 0 0 0 0 0 0 0 0 0
-        1 0 0 1 1 1 1 1 1 1 1 0
-        1 1 0 0 0 1 0 0 0 0 1 0
-        0 1 0 0 0 1 0 0 0 0 1 0
-        0 1 1 0 0 1 0 0 1 0 0 0
-        0 1 1 0 0 1 1 1 1 0 0 0
-        0 1 1 0 0 1 0 0 1 0 0 0
-        0 1 0 0 0 1 0 0 0 0 1 1
-        1 1 0 0 0 1 0 0 0 0 1 1
-        1 0 0 1 1 1 1 1 1 1 1 1
-        0 0 0 0 0 0 0 0 0 0 0 0
-        0 0 0 0 0 0 0 0 0 0 0 0
-        """
-    DE = np.genfromtxt(DE_str.splitlines())
-
-    DE2_str = """
-        0 0 0 0 0 0 0 0 0 0 0 0
-        1 1 1 1 0 0 1 1 1 1 1 1
-        0 0 1 1 1 0 0 0 1 0 0 0
-        0 0 0 0 1 0 0 0 1 0 0 0
-        0 0 0 0 1 1 0 0 1 0 0 1
-        0 0 0 0 1 1 0 0 1 1 1 1
-        0 0 0 0 1 1 0 0 1 0 0 1
-        0 0 0 0 1 0 0 0 1 0 0 0
-        0 0 1 1 1 0 0 0 1 0 0 0
-        1 1 1 1 0 0 1 1 1 1 1 1
-        0 0 0 0 0 0 0 0 0 0 0 0
-        0 0 0 0 0 0 0 0 0 0 0 0
-        """
-    DE2 = np.genfromtxt(DE2_str.splitlines())
-
-    expected_best_crops = [[AB], [BC], [DE, DE2]]
-    expected_best_crops_names = ['AB', 'BC', 'DE']
-
-    # Run 3 Craft studies on each class, and in each case check if the best crop is the expected one
-    class_check = [False, False, False]
-    for class_id in range(3):
-        # Focus on class class_id
-        # Selecting subset for class {class_id} : {labels_str[class_id]}
-        x_subset = x[np.argmax(y, axis=1)==class_id,:,:,:]
-
-        # fit craft on the selected class
-        crops, crops_u, w = craft.fit(x_subset, class_id)
-
-        # compute importances
-        importances = craft.estimate_importance()
-        assert np.all(importances >= 0)
-
-        # find the best crop and compare it to the expected best crop
-        most_important_concepts = np.argsort(importances)[::-1]
-
-        # Find the best crop for the most important concept
-        c_id = most_important_concepts[0]
-        best_crops_ids = np.argsort(crops_u[:, c_id])[::-1]
-        best_crop = np.array(crops)[best_crops_ids[0]]
-        best_crop = np.moveaxis(best_crop, 0, -1)
-
-        # Compare this best crop to the expectation
-        predicted_best_crop = np.where(best_crop.sum(axis=2) > 0.25, 1, 0)
-
-        # Comparison with each acceptable expected crop:
-        for expected_best_crop in expected_best_crops[class_id]:
-            expected_best_crop = expected_best_crop.astype(np.uint8)
-            comparison = predicted_best_crop == expected_best_crop
-            acc = np.sum(comparison) / len(comparison.ravel())
-            check = acc > 0.9
-            if check:
-                class_check[class_id] = True
-                break
-    assert np.all(class_check)
diff --git a/tests/utils.py b/tests/utils.py
index 1e32bc8a..483d716b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,18 +1,12 @@
 import signal, time
 import numpy as np
-import matplotlib.pyplot as plt
 from sklearn.linear_model import LinearRegression
-from sklearn.datasets import load_svmlight_file
-from pathlib import Path
-from math import ceil
 import tensorflow as tf
 from tensorflow.keras.models import Sequential, Model
 from tensorflow.keras.layers import (Dense, Conv1D, Conv2D, Activation,
                                      GlobalAveragePooling1D, Dropout, Flatten,
                                      MaxPooling2D, Input, Reshape)
 from tensorflow.keras.utils import to_categorical
-from PIL import Image, ImageDraw, ImageFont
-import urllib.request
 import requests
 
 def generate_data(x_shape=(32, 32, 3), num_labels=10, samples=100):
@@ -173,64 +167,6 @@ def model_with_random_nb_boxes(input):
         return model_with_random_nb_boxes
     return valid_model
 
-def generate_txt_images_data(x_shape=(32, 32, 3), num_labels=10, samples=100):
-    """
-    Generate an image dataset composed of white texts over black background.
-    The texts are words of 3 successive letters; the number of classes is set by the
-    parameter num_labels. The location of the text in the image cycles over the
-    image dimensions.
-    Ex: with num_labels=3, the 3 classes will be 'ABC', 'BCD' and 'CDE'.
-
-    """
-    all_labels_str = "".join([chr(lab_idx) for lab_idx in range(65, 65+num_labels+2)])  # ABCDEF
-    labels_str = [all_labels_str[i:i+3] for i in range(len(all_labels_str) - 2)]  # ['ABC', 'BCD', 'CDE', 'DEF']
-
-    def create_image_from_txt(image_shape, txt, offset_x, offset_y):
-        # Get a Pillow font (OS independent)
-        try:
-            fnt = ImageFont.truetype("FreeMono.ttf", 16)
-        except OSError:
-            # download the font if it is not present on the system
-            url = "https://github.com/python-pillow/Pillow/raw/main/Tests/fonts/FreeMono.ttf"
-            urllib.request.urlretrieve(url, "tests/FreeMono.ttf")
-            fnt = ImageFont.truetype("tests/FreeMono.ttf", 16)
-
-        # Make a black image and draw the input text in white at the location (offset_x, offset_y)
-        rgb = (len(image_shape) == 3 and image_shape[2] > 1)
-        if rgb:
-            image = Image.new("RGB", (image_shape[0], image_shape[1]), (0, 0, 0))
-        else:
-            # grayscale
-            image = Image.new("L", (image_shape[0], image_shape[1]), 0)
-        d = ImageDraw.Draw(image)
-        d.text((offset_x, offset_y), txt, font=fnt, fill='white')
-        return image
-
-    x = np.empty((samples, *x_shape)).astype(np.float32)
-    y = np.empty(samples)
-
-    # Iterate over the samples and generate images of labels shifted by increasing offsets
-    offset_x_max = x_shape[0] - 25
-    offset_y_max = x_shape[1] - 10
-
-    current_label_id = 0
-    offset_x = offset_y = 0
-    for i in range(samples):
-        image = create_image_from_txt(x_shape, txt=labels_str[current_label_id], offset_x=offset_x, offset_y=offset_y)
-        image = np.reshape(image, x_shape)
-        x[i] = np.array(image).astype(np.float32)/255.0
-        y[i] = current_label_id
-
-        # cycle labels
-        current_label_id = (current_label_id + 1) % num_labels
-        offset_x = (offset_x + 1) % offset_x_max
-        offset_y = ((i+2) % offset_y_max)
-        if offset_y > offset_y_max:
-            break
-    x = x[0:i]
-    y = y[0:i]
-    return x, to_categorical(y, num_labels), i, labels_str
-
 
 def download_file(identifier: str, destination: str):
     """