Merge pull request #16 from InseeFrLab/package
Package - Ready to release
meilame-tayebjee authored Dec 17, 2024
2 parents 6d0acae + d878765 commit cc5bdc2
Showing 15 changed files with 631 additions and 1,072 deletions.
89 changes: 82 additions & 7 deletions README.md
@@ -1,15 +1,90 @@
# torch-FastText: Efficient text classification with PyTorch
# torchFastText: Efficient text classification with PyTorch

This repository provides a PyTorch-based package of [the fastText architecture](https://github.com/facebookresearch/) [1].
A flexible PyTorch implementation of FastText for text classification with support for categorical features.

#### Installation
## Features

```bash
pip install git+https://github.com/inseefrlab/torch-fasttext@package
```
- Supports text classification with FastText architecture
- Handles both text and categorical features
- N-gram tokenization
- Flexible optimizer and scheduler options
- GPU and CPU support
- Model checkpointing and early stopping
- Prediction and model explanation capabilities

## Installation

#### References
```bash
pip install torchFastText
```

## Key Components

- `build()`: Constructs the FastText model architecture
- `train()`: Trains the model with built-in callbacks and logging
- `predict()`: Generates class predictions
- `predict_and_explain()`: Provides predictions with feature attributions
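
A minimal sketch of how these pieces fit together. The `predict` and `predict_and_explain` calls mirror the experiments notebook; the `build()` call shown here is an assumed signature and may differ:

```python
from torchFastText import torchFastText

model = torchFastText(
    num_buckets=1000000, embedding_dim=100, min_count=5,
    min_n=3, max_n=6, len_word_ngrams=True, sparse=True
)

# Assumed signature: constructs the tokenizer and the underlying PyTorch
# model from the training data, without launching a training run.
model.build(train_data, train_labels)

# Top-5 class predictions and their confidence scores
pred, conf = model.predict(test_data, top_k=5)

# Predictions plus word- and letter-level feature attributions
pred, conf, word_scores, letter_scores = model.predict_and_explain(test_data)
```

Here `train_data`, `train_labels` and `test_data` are arrays in the format described after the Quick Start below.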

## Subpackages

- `preprocess`: Utilities to preprocess text input, using the `nltk` and `unidecode` libraries.
- `explainability`: Simple methods to visualize feature attributions at word and letter levels, using the `captum` library.

Run `pip install torchFastText[preprocess]` or `pip install torchFastText[explainability]` to install these optional dependencies.
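
As used in the experiments notebook, the optional subpackages can be imported as follows:

```python
# Text cleaning helper (requires the [preprocess] extra)
from torchFastText.preprocess import clean_text_feature

# Attribution plots at word and letter level (requires the [explainability] extra)
from torchFastText.explainability.visualisation import (
    visualize_letter_scores,
    visualize_word_scores,
)
```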


## Quick Start

```python
from torchFastText import torchFastText

# Initialize the model
model = torchFastText(
    num_buckets=1000000,
    embedding_dim=100,
    min_count=5,
    min_n=3,
    max_n=6,
    len_word_ngrams=True,
    sparse=True
)

# Train the model
model.train(
    X_train=train_data,
    y_train=train_labels,
    X_val=val_data,
    y_val=val_labels,
    num_epochs=10,
    batch_size=64
)

# Make predictions
predictions = model.predict(test_data)
```

where `train_data` is an array of shape $(N, d)$: the first column contains the text as strings, and the remaining columns contain the tokenized (integer-encoded) categorical variables.

Please make sure that every possible label appears at least once in `y_train`.
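
For illustration only, a toy `train_data` with one text column and two categorical columns could look like this (a hypothetical example; in the experiments notebook the categorical codes are produced with scikit-learn's `LabelEncoder`):

```python
import numpy as np

# Each row: raw text first, then the integer-encoded categorical features
train_data = np.array(
    [
        ["bakery and pastry shop", 2, 0],
        ["hair salon", 1, 3],
        ["software consulting", 0, 1],
    ],
    dtype=object,
)
train_labels = np.array([0, 1, 2])  # every class appears at least once
```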

## Dependencies

- PyTorch Lightning
- NumPy

## Documentation

For detailed usage and examples, please refer to the [experiments notebook](notebooks/experiments.ipynb). Use `pip install -r requirements.txt` after cloning the repository to install the necessary dependencies (some are specific to the notebook).

## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

## License

MIT

## References

Inspired by the original FastText paper [1] and implementation.

[1] A. Joulin, E. Grave, P. Bojanowski, T. Mikolov, [*Bag of Tricks for Efficient Text Classification*](https://arxiv.org/abs/1607.01759)

122 changes: 111 additions & 11 deletions notebooks/experiments.ipynb
@@ -13,19 +13,24 @@
"import pandas as pd\n",
"import pyarrow.parquet as pq\n",
"import s3fs\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"sys.path.append(\"../\")\n",
"from torchFastText import torchFastText\n",
"from torchFastText.preprocess import (\n",
" clean_text_feature,\n",
" stratified_split_rare_labels,\n",
")\n",
"from torchFastText.preprocess import clean_text_feature\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Some utils functions that will help us format our dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -108,7 +113,37 @@
" df = categorize_surface(df, \"SRF\", like_sirene_3=True)\n",
" df = df[[text_feature, \"EVT\", \"CJ\", \"NAT\", \"TYP\", \"SRF\", \"CRT\", label_col]]\n",
"\n",
" return df, les"
" return df, les\n",
"\n",
"\n",
"def stratified_split_rare_labels(X, y, test_size=0.2, min_train_samples=1):\n",
" # Get unique labels and their frequencies\n",
" unique_labels, label_counts = np.unique(y, return_counts=True)\n",
"\n",
" # Separate rare and common labels\n",
" rare_labels = unique_labels[label_counts == 1]\n",
"\n",
" # Create initial mask for rare labels to go into training set\n",
" rare_label_mask = np.isin(y, rare_labels)\n",
"\n",
" # Separate data into rare and common label datasets\n",
" X_rare = X[rare_label_mask]\n",
" y_rare = y[rare_label_mask]\n",
" X_common = X[~rare_label_mask]\n",
" y_common = y[~rare_label_mask]\n",
"\n",
" # Split common labels stratified\n",
" X_common_train, X_common_test, y_common_train, y_common_test = train_test_split(\n",
" X_common, y_common, test_size=test_size, stratify=y_common\n",
" )\n",
"\n",
" # Combine rare labels with common labels split\n",
" X_train = np.concatenate([X_rare, X_common_train])\n",
" y_train = np.concatenate([y_rare, y_common_train])\n",
" X_test = X_common_test\n",
" y_test = y_common_test\n",
"\n",
" return X_train, X_test, y_train, y_test"
]
},
{
@@ -136,7 +171,7 @@
" )\n",
" .read_pandas()\n",
" .to_pandas()\n",
").sample(frac=0.01).fillna(np.nan)"
").sample(frac=0.001).fillna(np.nan)"
]
},
{
@@ -236,7 +271,7 @@
"source": [
"Put the columns in the right format:\n",
" - First column contains the processed text (str)\n",
" - Next ones contain the \"tokenized\" categorical variables in int format"
" - Next ones contain the \"tokenized\" categorical (discrete) variables in int format"
]
},
{
@@ -245,7 +280,7 @@
"metadata": {},
"outputs": [],
"source": [
"df, _ = clean_and_tokenize_df(df, text_feature=\"libelle_processed\") # NE PAS OUBLIER DE REMETYTRE PROCESSEd\n",
"df, _ = clean_and_tokenize_df(df, text_feature=\"libelle_processed\")\n",
"X = df[[\"libelle_processed\", \"EVT\", \"CJ\", \"NAT\", \"TYP\", \"CRT\", \"SRF\"]].values\n",
"y = df[\"apet_finale\"].values\n",
"print(X)\n",
@@ -277,7 +312,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# (Optional) Build the torch-fastText model (without training it)"
"# Build the torch-fastText model (without training it)"
]
},
{
@@ -318,7 +353,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We build the model using the training data. We have now access to the PyTorch model and a tokenizer."
"We build the model using the training data. We have now access to the tokenizer, the PyTorch model as well as a PyTorch Lightning module ready to be trained."
]
},
{
@@ -349,6 +384,23 @@
"This step is useful to initialize the full torchFastText model without training it, if needed for some reason. But if it is not necessary, and we could have directly launched the training (building is then handled automatically if necessary)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can play with the tokenizer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence = [\"lorem ipsum dolor sit amet\"]\n",
"print(model.tokenizer.tokenize(sentence)[2])"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -400,7 +452,55 @@
"metadata": {},
"outputs": [],
"source": [
"model.load_from_checkpoint(model.best_model_path)"
"model.load_from_checkpoint(model.best_model_path) # or any other checkpoint path (string)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Make predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = [\"coiffeur, boulangerie, pâtisserie\"]\n",
"X= np.array([[text[0], 0, 0, 0, 0, 0, 0]]) # our new entry\n",
"TOP_K = 5\n",
"\n",
"pred, conf = model.predict(X, top_k=TOP_K)\n",
"pred_naf = encoder.inverse_transform(pred.reshape(-1))\n",
"subset = naf2008.set_index(\"code\").loc[np.flip(pred_naf)]\n",
"\n",
"for i in range(TOP_K-1, -1, -1):\n",
" print(f\"Prediction: {pred_naf[i]}, confidence: {conf[0, i]}, description: {subset['libelle'][pred_naf[i]]}\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Explainability"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torchFastText.explainability.visualisation import (\n",
" visualize_letter_scores,\n",
" visualize_word_scores,\n",
")\n",
"\n",
"pred, conf, all_scores, all_scores_letters = model.predict_and_explain(X)\n",
"visualize_word_scores(all_scores, text, pred_naf.reshape(1, -1))\n",
"visualize_letter_scores(all_scores_letters, text, pred_naf.reshape(1, -1))"
]
}
],