Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/package' into add-workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
micedre committed Jan 8, 2025
2 parents 6c39972 + b3c7cd9 commit 02d26f5
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 112 deletions.
4 changes: 2 additions & 2 deletions notebooks/experiments.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -520,7 +520,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
"version": "3.12.7"
}
},
"nbformat": 4,
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ ipywidgets
seaborn
ruff>=0.7.1
pre-commit
pytest
182 changes: 98 additions & 84 deletions tests/test_all.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import unittest
import pytest
from pathlib import Path

import numpy as np
Expand All @@ -7,95 +7,109 @@
from sklearn.preprocessing import LabelEncoder

from torchFastText import torchFastText
from torchFastText.preprocess import clean_text_feature

source_path = Path(__file__).resolve()
source_dir = source_path.parent


class tftTest(unittest.TestCase):
def __init__(self, methodName="runTest"):
super().__init__(methodName)
data = {
'Catégorie': ['Politique', 'Politique', 'Politique', 'Politique', 'Politique', 'Politique', 'Politique', 'Politique',
'International', 'International', 'International', 'International', 'International', 'International', 'International', 'International',
'Célébrités', 'Célébrités', 'Célébrités', 'Célébrités', 'Célébrités', 'Célébrités', 'Célébrités', 'Célébrités',
'Sport', 'Sport', 'Sport', 'Sport', 'Sport', 'Sport', 'Sport', 'Sport'],
'Titre': [
"Nouveau budget présenté par le gouvernement",
"Élections législatives : les principaux candidats en lice",
"Réforme de la santé : les réactions des syndicats",
"Nouvelle loi sur l'éducation : les points clés",
"Les impacts des élections municipales sur la politique nationale",
"Réforme des retraites : les enjeux et débats",
"Nouveau plan de relance économique annoncé",
"La gestion de la crise climatique par le gouvernement",
"Accord climatique mondial : les engagements renouvelés",
"Conflit au Moyen-Orient : nouvelles tensions",
"Économie mondiale : les prévisions pour 2025",
"Sommet international sur la paix : les résultats",
"Répercussions des nouvelles sanctions économiques",
"Les négociations commerciales entre les grandes puissances",
"Les défis de la diplomatie moderne",
"Les conséquences du Brexit sur l'Europe",
"La dernière interview de [Nom de la célébrité]",
"Les révélations de [Nom de la célébrité] sur sa vie privée",
"Le retour sur scène de [Nom de la célébrité]",
"La nouvelle romance de [Nom de la célébrité]",
"Les scandales récents dans l'industrie du divertissement",
"Les projets humanitaires de [Nom de la célébrité]",
"La carrière impressionnante de [Nom de la célébrité]",
"Les derniers succès cinématographiques de [Nom de la célébrité]",
"Le championnat du monde de football : les favoris",
"Record battu par [Nom de l'athlète] lors des Jeux Olympiques",
"La finale de la Coupe de France : qui remportera le trophée?",
"Les transferts les plus chers de la saison",
"Les performances des athlètes français aux championnats du monde",
"Les nouveaux talents à surveiller dans le monde du sport",
"L'impact de la technologie sur les sports traditionnels",
"Les grandes compétitions sportives de l'année à venir"
]
}

df = pd.DataFrame(data)
labelEncoder = LabelEncoder()
y = labelEncoder.fit_transform(df['Catégorie'])
self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(df['Titre'], y, test_size=0.1, stratify=y)



@pytest.fixture(scope='session', autouse=True)
def data():
data = {
'Catégorie': ['Politique', 'Politique', 'Politique', 'Politique', 'Politique', 'Politique', 'Politique', 'Politique',
'International', 'International', 'International', 'International', 'International', 'International', 'International', 'International',
'Célébrités', 'Célébrités', 'Célébrités', 'Célébrités', 'Célébrités', 'Célébrités', 'Célébrités', 'Célébrités',
'Sport', 'Sport', 'Sport', 'Sport', 'Sport', 'Sport', 'Sport', 'Sport'],
'Titre': [
"Nouveau budget présenté par le gouvernement",
"Élections législatives : les principaux candidats en lice",
"Réforme de la santé : les réactions des syndicats",
"Nouvelle loi sur l'éducation : les points clés",
"Les impacts des élections municipales sur la politique nationale",
"Réforme des retraites : les enjeux et débats",
"Nouveau plan de relance économique annoncé",
"La gestion de la crise climatique par le gouvernement",
"Accord climatique mondial : les engagements renouvelés",
"Conflit au Moyen-Orient : nouvelles tensions",
"Économie mondiale : les prévisions pour 2025",
"Sommet international sur la paix : les résultats",
"Répercussions des nouvelles sanctions économiques",
"Les négociations commerciales entre les grandes puissances",
"Les défis de la diplomatie moderne",
"Les conséquences du Brexit sur l'Europe",
"La dernière interview de [Nom de la célébrité]",
"Les révélations de [Nom de la célébrité] sur sa vie privée",
"Le retour sur scène de [Nom de la célébrité]",
"La nouvelle romance de [Nom de la célébrité]",
"Les scandales récents dans l'industrie du divertissement",
"Les projets humanitaires de [Nom de la célébrité]",
"La carrière impressionnante de [Nom de la célébrité]",
"Les derniers succès cinématographiques de [Nom de la célébrité]",
"Le championnat du monde de football : les favoris",
"Record battu par [Nom de l'athlète] lors des Jeux Olympiques",
"La finale de la Coupe de France : qui remportera le trophée?",
"Les transferts les plus chers de la saison",
"Les performances des athlètes français aux championnats du monde",
"Les nouveaux talents à surveiller dans le monde du sport",
"L'impact de la technologie sur les sports traditionnels",
"Les grandes compétitions sportives de l'année à venir"
]
}
df = pd.DataFrame(data)
labelEncoder = LabelEncoder()
y = labelEncoder.fit_transform(df['Catégorie'])
df['Titre_cleaned'] = clean_text_feature(df['Titre'])
X_train, X_test, y_train, y_test = train_test_split(df['Titre_cleaned'], y, test_size=0.1, stratify=y)
return X_train, X_test, y_train, y_test

def test_train_no_categorical_variables(self):
num_buckets = 4
embedding_dim = 10
min_count = 1
min_n = 3
max_n = 6
len_word_ngrams = 10
sparse = False
self.torchfasttext = torchFastText(
num_buckets=num_buckets,
embedding_dim=embedding_dim,
min_count=min_count,
min_n=min_n,
max_n=max_n,
len_word_ngrams=len_word_ngrams,
sparse=sparse,
)
self.torchfasttext.train(
np.asarray(self.X_train),
np.asarray(self.y_train),
np.asarray(self.X_test),
np.asarray(self.y_test),
num_epochs=2,
batch_size=32,
lr=0.001
)
self.assertTrue(True, msg="Training Validated")
#print(self.torchfasttext.predict(np.asarray(["Star John elected president"]), 3))
@pytest.fixture(scope='session', autouse=True)
def model():
num_buckets = 4
embedding_dim = 10
min_count = 1
min_n = 2
max_n = 5
len_word_ngrams = 2
sparse = False
return torchFastText(
num_buckets=num_buckets,
embedding_dim=embedding_dim,
min_count=min_count,
min_n=min_n,
max_n=max_n,
len_word_ngrams=len_word_ngrams,
sparse=sparse,
)

#self.assertTrue(True, msg="Predicted")



if __name__ == "__main__":
unittest.main()
def test_model_initialization(model, data):
assert isinstance(model, torchFastText)
assert model.num_buckets == 4
assert model.embedding_dim == 10
assert model.min_count == 1
assert model.min_n == 2
assert model.max_n == 5
assert model.len_word_ngrams == 2
assert not model.sparse
X_train, X_test, y_train, y_test = data
model.train(
np.asarray(X_train),
np.asarray(y_train),
np.asarray(X_test),
np.asarray(y_test),
num_epochs=1,
batch_size=32,
lr=0.001,
num_workers=4
)
assert True, "Training completed without errors"
tokenizer = model.tokenizer
tokenized_text_tokens, tokenized_text, id_to_token_dicts, token_to_id_dicts= tokenizer.tokenize(["Nouveau budget présenté par le gouvernement"])
assert isinstance(tokenized_text, list)
assert len(tokenized_text) > 0
#assert "gouvern </s>" in tokenized_text_tokens[0]
predictions, confidence, all_scores, all_scores_letters = model.predict_and_explain(np.asarray(["Nouveau budget présenté par le gouvernement"]), 2)
assert predictions.shape == (1, 2)
# "predictions" contains the predicted class for each input text, in int format. Need to decode back to have the string format

17 changes: 11 additions & 6 deletions torchFastText/model/pytorch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,17 @@ def predict(

x = padded_batch

other_features = []
for i, categorical_variable in enumerate(categorical_variables):
other_features.append(
torch.tensor(categorical_variable).reshape(batch_size, -1).to(torch.int64)
)

if not self.no_cat_var:
other_features = []
for i, categorical_variable in enumerate(categorical_variables):
other_features.append(
torch.tensor(categorical_variable).reshape(batch_size, -1).to(torch.int64)
)

other_features = torch.stack(other_features).reshape(batch_size, -1).long()
other_features = torch.stack(other_features).reshape(batch_size, -1).long()
else:
other_features = torch.empty(batch_size)

pred = self(
x, other_features
Expand Down Expand Up @@ -301,6 +305,7 @@ def predict_and_explain(self, text, categorical_variables, top_k=1, n=5, cutoff=
all_attr,
id_to_token_dicts,
token_to_id_dicts,
min_n=self.tokenizer.min_n,
padding_index=2009603,
end_of_string_index=0,
)
Expand Down
34 changes: 21 additions & 13 deletions torchFastText/torchFastText.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,10 +428,13 @@ def validate(self, X, Y, batch_size=256, num_workers=12):
text, categorical_variables, no_cat_var = check_X(X)
y = check_Y(Y)

if categorical_variables.shape[1] != self.num_categorical_features:
raise Exception(
f"X must have the same number of categorical variables as the training data ({self.num_categorical_features})."
)
if categorical_variables is not None:
if categorical_variables.shape[1] != self.num_categorical_features:
raise Exception(
f"X must have the same number of categorical variables as the training data ({self.num_categorical_features})."
)
else:
assert self.pytorch_model.no_cat_var == True

self.pytorch_model.to(X.device)

Expand Down Expand Up @@ -462,10 +465,13 @@ def predict(self, X, top_k=1):

# checking right format for inputs
text, categorical_variables, no_cat_var = check_X(X)
if categorical_variables.shape[1] != self.num_categorical_features:
raise Exception(
f"X must have the same number of categorical variables as the training data ({self.num_categorical_features})."
)
if categorical_variables is not None:
if categorical_variables.shape[1] != self.num_categorical_features:
raise Exception(
f"X must have the same number of categorical variables as the training data ({self.num_categorical_features})."
)
else:
assert self.pytorch_model.no_cat_var == True

return self.pytorch_model.predict(text, categorical_variables, top_k=top_k)

Expand All @@ -475,11 +481,13 @@ def predict_and_explain(self, X, top_k=1):

# checking right format for inputs
text, categorical_variables, no_cat_var = check_X(X)

if categorical_variables.shape[1] != self.num_categorical_features:
raise Exception(
f"X must have the same number of categorical variables as the training data ({self.num_categorical_features})."
)
if categorical_variables is not None:
if categorical_variables.shape[1] != self.num_categorical_features:
raise Exception(
f"X must have the same number of categorical variables as the training data ({self.num_categorical_features})."
)
else:
assert self.pytorch_model.no_cat_var == True

return self.pytorch_model.predict_and_explain(text, categorical_variables, top_k=top_k)

Expand Down
24 changes: 17 additions & 7 deletions torchFastText/utilities/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Utility functions.
"""

import warnings
import difflib
from difflib import SequenceMatcher

Expand Down Expand Up @@ -80,7 +80,7 @@ def map_processed_to_original(processed_words, original_words, n=1, cutoff=0.9):
return word_mapping


def test_end_of_word(all_processed_words, word, target_token, next_token):
def test_end_of_word(all_processed_words, word, target_token, next_token, min_n):
flag = False
if target_token[-1] == ">":
if next_token[0] == "<":
Expand All @@ -90,15 +90,15 @@ def test_end_of_word(all_processed_words, word, target_token, next_token):
flag = False
if next_token[1] != word[0]:
flag = True
if len(next_token) == 3:
if len(next_token) == min_n:
flag = True
if next_token in all_processed_words:
flag = True

return flag


def match_word_to_token_indexes(sentence, tokenized_sentence_tokens):
def match_word_to_token_indexes(sentence, tokenized_sentence_tokens, min_n):
"""
Match words to token indexes in a sentence.
Expand All @@ -116,7 +116,7 @@ def match_word_to_token_indexes(sentence, tokenized_sentence_tokens):
processed_sentence = clean_text_feature([sentence], remove_stop_words=False)[0]
processed_words = processed_sentence.split()
# we know the tokens are in the right order
for index_word, word in enumerate(processed_sentence.split()):
for index_word, word in enumerate(processed_words):
if word not in res:
res[word] = []

Expand All @@ -128,8 +128,16 @@ def match_word_to_token_indexes(sentence, tokenized_sentence_tokens):
word,
tokenized_sentence_tokens[pointer_token],
tokenized_sentence_tokens[pointer_token + 1],
min_n=min_n
):
pointer_token += 1
if pointer_token == len(tokenized_sentence_tokens)-1:
warnings.warn("Error in the tokenization of the sentence")
# workaround to avoid error: each word is asociated to regular ranges
chunck = len(tokenized_sentence_tokens) // len(processed_words)
for idx, word in enumerate(processed_words):
res[word] = range(idx * chunck, min((idx + 1) * chunck, len(tokenized_sentence_tokens)))
return res

pointer_token += 1
end = pointer_token
Expand Down Expand Up @@ -168,6 +176,7 @@ def compute_preprocessed_word_score(
scores,
id_to_token_dicts,
token_to_id_dicts,
min_n,
padding_index=2009603,
end_of_string_index=0,
):
Expand All @@ -193,7 +202,7 @@ def compute_preprocessed_word_score(

for idx, sentence in enumerate(preprocessed_text):
tokenized_sentence_tokens = tokenized_text_tokens[idx] # sentence level, List[str]
word_to_token_idx = match_word_to_token_indexes(sentence, tokenized_sentence_tokens)
word_to_token_idx = match_word_to_token_indexes(sentence, tokenized_sentence_tokens, min_n)
score_sentence_topk = scores[idx] # torch.Tensor, token scores, (top_k, seq_len)

# Calculate the score for each token and map to words
Expand Down Expand Up @@ -325,7 +334,8 @@ def explain_continuous(

for pos, token in enumerate(original_to_token[original_word]):
pos_token = original_to_token_idxs[original_word][pos]
tok = preprocess_token(token)[0]
#tok = preprocess_token(token)[0]
tok = preprocess_token(token)
score_token = all_attr[idx, k, pos_token].item()

# Embed the token at the right indexes of the word
Expand Down

0 comments on commit 02d26f5

Please sign in to comment.