This project implements a Transformer model for machine translation from English to French using PyTorch. The code is split into modular files for better readability and maintainability.
- Prerequisites
- File Structure
- Dataset
- Setup and Installation
- Training the Model
- Testing the Model
- Loading the Pre-trained Model
- Implementation Assumptions
- References

**Prerequisites**

- Python 3.7 or higher
- PyTorch 1.7 or higher
- NLTK
- NumPy
- Matplotlib

**File Structure**

- `train.py`: Main script to train the Transformer model.
- `test.py`: Script to test the pre-trained model on the test set.
- `encoder.py`: Contains the `Encoder` class and related components.
- `decoder.py`: Contains the `Decoder` class and related components.
- `model.py`: Defines the `Transformer` model combining encoder and decoder.
- `utils.py`: Includes helper functions and classes such as `Vocabulary`, `Dataset`, and training utilities.
- `data/`: Directory containing the dataset files (`train.en`, `train.fr`, `dev.en`, `dev.fr`, `test.en`, `test.fr`).
- `transformer.pt`: The saved pre-trained model file (if available).

**Dataset**

The dataset should be organized as follows:

    data/
    ├── train.en   # English training data
    ├── train.fr   # French training data
    ├── dev.en     # English validation data
    ├── dev.fr     # French validation data
    ├── test.en    # English test data
    └── test.fr    # French test data

Ensure that the dataset files are preprocessed (tokenized and cleaned) and aligned line by line between the source and target languages.
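
Because every later step relies on line-by-line alignment, it can be worth verifying it before training. The following is a minimal sketch using only the standard library; `check_alignment` and `check_dataset` are hypothetical helper names, not functions from this repository, and the check only compares line counts and flags empty lines (it cannot tell whether a pair is actually a correct translation).

```python
import os


def read_lines(path):
    """Read a text file as a list of lines without trailing newlines."""
    with open(path, encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]


def check_alignment(src_path, tgt_path):
    """Return the number of sentence pairs, or raise ValueError on a mismatch."""
    src_lines = read_lines(src_path)
    tgt_lines = read_lines(tgt_path)
    if len(src_lines) != len(tgt_lines):
        raise ValueError(
            f"{src_path} has {len(src_lines)} lines but "
            f"{tgt_path} has {len(tgt_lines)}"
        )
    # An empty line on either side usually means the pair was lost in preprocessing.
    for number, (src, tgt) in enumerate(zip(src_lines, tgt_lines), start=1):
        if not src.strip() or not tgt.strip():
            raise ValueError(f"Empty sentence at line {number} of {src_path} / {tgt_path}")
    return len(src_lines)


def check_dataset(data_dir, splits=("train", "dev", "test")):
    """Check every split in data_dir and return {split: number_of_pairs}."""
    return {
        split: check_alignment(
            os.path.join(data_dir, f"{split}.en"),
            os.path.join(data_dir, f"{split}.fr"),
        )
        for split in splits
    }
```

Calling `check_dataset('path/to/your/data')` returns the number of sentence pairs per split, or raises a `ValueError` naming the first pair of files that disagree.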

**Setup and Installation**

- Create a Virtual Environment (Optional but Recommended)

      python -m venv venv
      source venv/bin/activate  # On Windows use venv\Scripts\activate

- Install Dependencies

      pip install torch nltk numpy matplotlib

- Download NLTK Data

  In your Python environment, download the NLTK tokenizer:

      import nltk
      nltk.download('punkt')

  If `word_tokenize` later raises a `LookupError`, download the resource named in the error message as well.

**Training the Model**

To train the Transformer model from scratch:

- Update Data Path

  In `train.py`, update the `DATA_PATH` variable to point to your dataset directory:

      DATA_PATH = 'path/to/your/data'

- Run the Training Script

      python train.py

  This script will:

  - Load and preprocess the data.
  - Build the vocabularies.
  - Initialize and train the Transformer model.
  - Save the trained model to `transformer.pt`.
  - Plot and save the training and validation loss curves as `loss_plot.png`.

- Adjust Hyperparameters (Optional)

  You can adjust the hyperparameters in `train.py` to experiment with different settings:

      NUM_EPOCHS = 15
      BATCH_SIZE = 128
      LEARNING_RATE = 0.0001
      # ... and others

**Testing the Model**

To evaluate the model on the test set:

- Ensure the Pre-trained Model is Available

  Download `transformer.pt` (see the link under Pre-trained Model), or train the model first to produce it.

- Update Data Path

  In `test.py`, update the `DATA_PATH` variable to point to your dataset directory:

      DATA_PATH = 'path/to/your/data'

- Run the Testing Script

      python test.py

  This script will:

  - Load the test data and vocabularies.
  - Load the pre-trained model.
  - Generate translations for the test set.
  - Calculate BLEU scores and save them to `testbleu.txt`.
  - Plot and save the BLEU score distribution as `bleu_distribution.png`.
  - Display sample translations and their BLEU scores.
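
To give a feel for what a per-sentence BLEU score measures, here is a minimal standard-library sketch. It is not the code used by `test.py`: the function name `sentence_bleu_sketch`, the single-reference setup, the uniform weights over 1- to 4-grams, and the add-one smoothing for n > 1 are all assumptions made for illustration, so its numbers will not necessarily match those in `testbleu.txt`.

```python
import math
from collections import Counter


def ngram_counts(tokens, n):
    """Count the n-grams of a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def sentence_bleu_sketch(reference, hypothesis, max_n=4):
    """Sentence-level BLEU against one reference, add-one smoothed for n > 1."""
    if not hypothesis:
        return 0.0
    log_precisions = []
    for n in range(1, max_n + 1):
        hyp_counts = ngram_counts(hypothesis, n)
        ref_counts = ngram_counts(reference, n)
        # Clipped matches: an n-gram is credited at most as often as it
        # appears in the reference.
        matches = sum(min(count, ref_counts[gram]) for gram, count in hyp_counts.items())
        total = sum(hyp_counts.values())
        if n > 1:
            # Assumed smoothing: add one to numerator and denominator so a
            # short sentence without 4-gram matches does not score zero.
            matches += 1
            total += 1
        if matches == 0:
            return 0.0  # no unigram overlap at all
        log_precisions.append(math.log(matches / total))
    # Brevity penalty: penalize hypotheses shorter than the reference.
    if len(hypothesis) > len(reference):
        brevity_penalty = 1.0
    else:
        brevity_penalty = math.exp(1 - len(reference) / len(hypothesis))
    return brevity_penalty * math.exp(sum(log_precisions) / max_n)
```

Both arguments are lists of tokens, for example `sentence_bleu_sketch("le chat est sur le tapis".split(), "le chat dort".split())`. An exact match scores 1.0, and a hypothesis sharing no words with the reference scores 0.0.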

**Loading the Pre-trained Model**

If you have a pre-trained model file `transformer.pt`, you can use it without retraining:

- Place the Model File

  Ensure `transformer.pt` is in the same directory as your scripts.

- Run the Testing Script

      python test.py

  The script will automatically load the model and proceed with evaluation.

**Implementation Assumptions**

- Data Alignment: The source and target datasets are assumed to be aligned line by line.
- Tokenization: Basic tokenization is performed using NLTK's `word_tokenize`.
- Vocabulary Threshold: Words occurring fewer than two times are treated as `<UNK>`.
- Sequence Length: The maximum sequence length is set to 60 tokens. Longer sequences are truncated.
- Special Tokens: The following special tokens are used:
  - `<PAD>`: Padding token (index 0)
  - `<SOS>`: Start-of-sentence token (index 1)
  - `<EOS>`: End-of-sentence token (index 2)
  - `<UNK>`: Unknown word token (index 3)
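
The assumptions above can be illustrated with a small standard-library sketch. This is not the `Vocabulary` class from `utils.py`; `build_vocab`, `encode`, and `pad_batch` are hypothetical names, and whether the 60-token limit is applied before or after adding `<SOS>`/`<EOS>` is an assumption here (the sketch truncates the word tokens first).

```python
from collections import Counter

# Special tokens and their indices, as listed above.
PAD, SOS, EOS, UNK = "<PAD>", "<SOS>", "<EOS>", "<UNK>"
MAX_LEN = 60


def build_vocab(tokenized_sentences, min_freq=2):
    """Map each word seen at least min_freq times to an index."""
    counts = Counter(tok for sent in tokenized_sentences for tok in sent)
    word2idx = {PAD: 0, SOS: 1, EOS: 2, UNK: 3}
    for word, freq in counts.items():
        if freq >= min_freq and word not in word2idx:
            word2idx[word] = len(word2idx)
    return word2idx


def encode(tokens, word2idx, max_len=MAX_LEN):
    """Truncate, map rare or unseen words to <UNK>, and wrap in <SOS>/<EOS>."""
    # Assumption: the limit applies to word tokens, before special tokens are added.
    tokens = tokens[:max_len]
    body = [word2idx.get(tok, word2idx[UNK]) for tok in tokens]
    return [word2idx[SOS]] + body + [word2idx[EOS]]


def pad_batch(sequences, pad_index=0):
    """Right-pad every sequence with <PAD> up to the longest one in the batch."""
    longest = max(len(seq) for seq in sequences)
    return [seq + [pad_index] * (longest - len(seq)) for seq in sequences]
```

For example, with a corpus of `[["the", "cat", "sat"], ["the", "dog", "sat"]]`, only `the` and `sat` reach the frequency threshold, so `cat`, `dog`, and any unseen word are encoded as index 3.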

**References**

- Vaswani, A., Shazeer, N., Parmar, N., et al. (2017). Attention Is All You Need. Advances in Neural Information Processing Systems.

**Pre-trained Model**

- Download Pre-trained Model: https://iiitaphyd-my.sharepoint.com/:u:/g/personal/shivashankar_gande_students_iiit_ac_in/EXs3y2gJx8lMm4qeUhvVC54Bixd4y2jqY1YGDOwfhajV2Q?e=20drhX