From b3c7cd93b9825f21d398c1ce13b285d4faa5624b Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Tue, 7 Jan 2025 11:19:02 +0000 Subject: [PATCH] modified: notebooks/experiments.ipynb modified: requirements.txt new file: tests/test_all.py modified: torchFastText/model/pytorch_model.py modified: torchFastText/torchFastText.py modified: torchFastText/utilities/utils.py --- notebooks/experiments.ipynb | 4 +- requirements.txt | 1 + tests/test_all.py | 115 +++++++++++++++++++++++++++ torchFastText/model/pytorch_model.py | 17 ++-- torchFastText/torchFastText.py | 34 +++++--- torchFastText/utilities/utils.py | 24 ++++-- 6 files changed, 167 insertions(+), 28 deletions(-) create mode 100644 tests/test_all.py diff --git a/notebooks/experiments.ipynb b/notebooks/experiments.ipynb index d17579e..4b96fd7 100644 --- a/notebooks/experiments.ipynb +++ b/notebooks/experiments.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -520,7 +520,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.6" + "version": "3.12.7" } }, "nbformat": 4, diff --git a/requirements.txt b/requirements.txt index e4de6bc..798a9d7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,4 @@ ipywidgets seaborn ruff>=0.7.1 pre-commit +pytest diff --git a/tests/test_all.py b/tests/test_all.py new file mode 100644 index 0000000..9456504 --- /dev/null +++ b/tests/test_all.py @@ -0,0 +1,115 @@ +import pytest +from pathlib import Path + +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import LabelEncoder + +from torchFastText import torchFastText +from torchFastText.preprocess import clean_text_feature + +source_path = Path(__file__).resolve() +source_dir = source_path.parent + + +@pytest.fixture(scope='session', autouse=True) +def data(): + data = { + 'Catégorie': ['Politique', 'Politique', 'Politique', 'Politique', 'Politique', 'Politique', 'Politique', 'Politique', + 'International', 'International', 'International', 'International', 'International', 'International', 'International', 'International', + 'Célébrités', 'Célébrités', 'Célébrités', 'Célébrités', 'Célébrités', 'Célébrités', 'Célébrités', 'Célébrités', + 'Sport', 'Sport', 'Sport', 'Sport', 'Sport', 'Sport', 'Sport', 'Sport'], + 'Titre': [ + "Nouveau budget présenté par le gouvernement", + "Élections législatives : les principaux candidats en lice", + "Réforme de la santé : les réactions des syndicats", + "Nouvelle loi sur l'éducation : les points clés", + "Les impacts des élections municipales sur la politique nationale", + "Réforme des retraites : les enjeux et débats", + "Nouveau plan de relance économique annoncé", + "La gestion de la crise climatique par le gouvernement", + "Accord climatique mondial : les engagements renouvelés", + "Conflit au Moyen-Orient : nouvelles tensions", + "Économie mondiale : les prévisions pour 2025", + "Sommet international sur la paix : les résultats", + "Répercussions des nouvelles sanctions économiques", + "Les négociations commerciales entre les grandes puissances", + "Les défis de la diplomatie moderne", + "Les conséquences du Brexit sur l'Europe", + "La dernière interview de [Nom de la célébrité]", + "Les révélations de [Nom de la célébrité] sur sa vie privée", + "Le retour sur scène de [Nom de la célébrité]", + "La nouvelle romance de [Nom de la célébrité]", + "Les scandales récents dans l'industrie du divertissement", + "Les projets humanitaires de [Nom de la célébrité]", + "La carrière impressionnante de [Nom de la célébrité]", + "Les derniers succès cinématographiques de [Nom de la célébrité]", + "Le championnat du monde de football : les favoris", + "Record battu par [Nom de l'athlète] lors des Jeux Olympiques", + "La finale de la Coupe de France : qui remportera le trophée?", + "Les transferts les plus chers de la saison", + "Les performances des athlètes français aux championnats du monde", + "Les nouveaux talents à surveiller dans le monde du sport", + "L'impact de la technologie sur les sports traditionnels", + "Les grandes compétitions sportives de l'année à venir" + ] + } + df = pd.DataFrame(data) + labelEncoder = LabelEncoder() + y = labelEncoder.fit_transform(df['Catégorie']) + df['Titre_cleaned'] = clean_text_feature(df['Titre']) + X_train, X_test, y_train, y_test = train_test_split(df['Titre_cleaned'], y, test_size=0.1, stratify=y) + return X_train, X_test, y_train, y_test + +@pytest.fixture(scope='session', autouse=True) +def model(): + num_buckets = 4 + embedding_dim = 10 + min_count = 1 + min_n = 2 + max_n = 5 + len_word_ngrams = 2 + sparse = False + return torchFastText( + num_buckets=num_buckets, + embedding_dim=embedding_dim, + min_count=min_count, + min_n=min_n, + max_n=max_n, + len_word_ngrams=len_word_ngrams, + sparse=sparse, + ) + + + +def test_model_initialization(model, data): + assert isinstance(model, torchFastText) + assert model.num_buckets == 4 + assert model.embedding_dim == 10 + assert model.min_count == 1 + assert model.min_n == 2 + assert model.max_n == 5 + assert model.len_word_ngrams == 2 + assert not model.sparse + X_train, X_test, y_train, y_test = data + model.train( + np.asarray(X_train), + np.asarray(y_train), + np.asarray(X_test), + np.asarray(y_test), + num_epochs=1, + batch_size=32, + lr=0.001, + num_workers=4 + ) + assert True, "Training completed without errors" + tokenizer = model.tokenizer + tokenized_text_tokens, tokenized_text, id_to_token_dicts, token_to_id_dicts= tokenizer.tokenize(["Nouveau budget présenté par le gouvernement"]) + assert isinstance(tokenized_text, list) + assert len(tokenized_text) > 0 + #assert "gouvern " in tokenized_text_tokens[0] + predictions, confidence, all_scores, all_scores_letters = model.predict_and_explain(np.asarray(["Nouveau budget présenté par le gouvernement"]), 2) + assert predictions.shape == (1, 2) + # "predictions" contains the predicted class for each input text, in int format. Need to decode back to have the string format + \ No newline at end of file diff --git a/torchFastText/model/pytorch_model.py b/torchFastText/model/pytorch_model.py index 00c55ca..96df5a7 100644 --- a/torchFastText/model/pytorch_model.py +++ b/torchFastText/model/pytorch_model.py @@ -206,13 +206,17 @@ def predict( x = padded_batch - other_features = [] - for i, categorical_variable in enumerate(categorical_variables): - other_features.append( - torch.tensor(categorical_variable).reshape(batch_size, -1).to(torch.int64) - ) + + if not self.no_cat_var: + other_features = [] + for i, categorical_variable in enumerate(categorical_variables): + other_features.append( + torch.tensor(categorical_variable).reshape(batch_size, -1).to(torch.int64) + ) - other_features = torch.stack(other_features).reshape(batch_size, -1).long() + other_features = torch.stack(other_features).reshape(batch_size, -1).long() + else: + other_features = torch.empty(batch_size) pred = self( x, other_features @@ -301,6 +305,7 @@ def predict_and_explain(self, text, categorical_variables, top_k=1, n=5, cutoff= all_attr, id_to_token_dicts, token_to_id_dicts, + min_n=self.tokenizer.min_n, padding_index=2009603, end_of_string_index=0, ) diff --git a/torchFastText/torchFastText.py b/torchFastText/torchFastText.py index 8cfc147..7420ead 100644 --- a/torchFastText/torchFastText.py +++ b/torchFastText/torchFastText.py @@ -428,10 +428,13 @@ def validate(self, X, Y, batch_size=256, num_workers=12): text, categorical_variables, no_cat_var = check_X(X) y = check_Y(Y) - if categorical_variables.shape[1] != self.num_categorical_features: - raise Exception( - f"X must have the same number of categorical variables as the training data ({self.num_categorical_features})." - ) + if categorical_variables is not None: + if categorical_variables.shape[1] != self.num_categorical_features: + raise Exception( + f"X must have the same number of categorical variables as the training data ({self.num_categorical_features})." + ) + else: + assert self.pytorch_model.no_cat_var == True self.pytorch_model.to(X.device) @@ -462,10 +465,13 @@ def predict(self, X, top_k=1): # checking right format for inputs text, categorical_variables, no_cat_var = check_X(X) - if categorical_variables.shape[1] != self.num_categorical_features: - raise Exception( - f"X must have the same number of categorical variables as the training data ({self.num_categorical_features})." - ) + if categorical_variables is not None: + if categorical_variables.shape[1] != self.num_categorical_features: + raise Exception( + f"X must have the same number of categorical variables as the training data ({self.num_categorical_features})." + ) + else: + assert self.pytorch_model.no_cat_var == True return self.pytorch_model.predict(text, categorical_variables, top_k=top_k) @@ -475,11 +481,13 @@ def predict_and_explain(self, X, top_k=1): # checking right format for inputs text, categorical_variables, no_cat_var = check_X(X) - - if categorical_variables.shape[1] != self.num_categorical_features: - raise Exception( - f"X must have the same number of categorical variables as the training data ({self.num_categorical_features})." - ) + if categorical_variables is not None: + if categorical_variables.shape[1] != self.num_categorical_features: + raise Exception( + f"X must have the same number of categorical variables as the training data ({self.num_categorical_features})." + ) + else: + assert self.pytorch_model.no_cat_var == True return self.pytorch_model.predict_and_explain(text, categorical_variables, top_k=top_k) diff --git a/torchFastText/utilities/utils.py b/torchFastText/utilities/utils.py index abbc38f..447960a 100644 --- a/torchFastText/utilities/utils.py +++ b/torchFastText/utilities/utils.py @@ -1,7 +1,7 @@ """ Utility functions. """ - +import warnings import difflib from difflib import SequenceMatcher @@ -80,7 +80,7 @@ def map_processed_to_original(processed_words, original_words, n=1, cutoff=0.9): return word_mapping -def test_end_of_word(all_processed_words, word, target_token, next_token): +def test_end_of_word(all_processed_words, word, target_token, next_token, min_n): flag = False if target_token[-1] == ">": if next_token[0] == "<": @@ -90,7 +90,7 @@ def test_end_of_word(all_processed_words, word, target_token, next_token): flag = False if next_token[1] != word[0]: flag = True - if len(next_token) == 3: + if len(next_token) == min_n: flag = True if next_token in all_processed_words: flag = True @@ -98,7 +98,7 @@ def test_end_of_word(all_processed_words, word, target_token, next_token): return flag -def match_word_to_token_indexes(sentence, tokenized_sentence_tokens): +def match_word_to_token_indexes(sentence, tokenized_sentence_tokens, min_n): """ Match words to token indexes in a sentence. @@ -116,7 +116,7 @@ def match_word_to_token_indexes(sentence, tokenized_sentence_tokens): processed_sentence = clean_text_feature([sentence], remove_stop_words=False)[0] processed_words = processed_sentence.split() # we know the tokens are in the right order - for index_word, word in enumerate(processed_sentence.split()): + for index_word, word in enumerate(processed_words): if word not in res: res[word] = [] @@ -128,8 +128,16 @@ def match_word_to_token_indexes(sentence, tokenized_sentence_tokens): word, tokenized_sentence_tokens[pointer_token], tokenized_sentence_tokens[pointer_token + 1], + min_n=min_n ): pointer_token += 1 + if pointer_token == len(tokenized_sentence_tokens)-1: + warnings.warn("Error in the tokenization of the sentence") + # workaround to avoid error: each word is asociated to regular ranges + chunck = len(tokenized_sentence_tokens) // len(processed_words) + for idx, word in enumerate(processed_words): + res[word] = range(idx * chunck, min((idx + 1) * chunck, len(tokenized_sentence_tokens))) + return res pointer_token += 1 end = pointer_token @@ -168,6 +176,7 @@ def compute_preprocessed_word_score( scores, id_to_token_dicts, token_to_id_dicts, + min_n, padding_index=2009603, end_of_string_index=0, ): @@ -193,7 +202,7 @@ def compute_preprocessed_word_score( for idx, sentence in enumerate(preprocessed_text): tokenized_sentence_tokens = tokenized_text_tokens[idx] # sentence level, List[str] - word_to_token_idx = match_word_to_token_indexes(sentence, tokenized_sentence_tokens) + word_to_token_idx = match_word_to_token_indexes(sentence, tokenized_sentence_tokens, min_n) score_sentence_topk = scores[idx] # torch.Tensor, token scores, (top_k, seq_len) # Calculate the score for each token and map to words @@ -325,7 +334,8 @@ def explain_continuous( for pos, token in enumerate(original_to_token[original_word]): pos_token = original_to_token_idxs[original_word][pos] - tok = preprocess_token(token)[0] + #tok = preprocess_token(token)[0] + tok = preprocess_token(token) score_token = all_attr[idx, k, pos_token].item() # Embed the token at the right indexes of the word