Official code repository of the paper titled BoK: Introducing Bag-of-Keywords Loss for Interpretable Dialogue Response Generation.
Requires Python 3.11 or later.
❱❱❱ pip install -r requirements.txt
Set up Perl. Download meteor-1.5.tar.gz and extract it inside the 3rdparty directory.
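The extraction step can also be done from Python. The following is a minimal sketch, assuming meteor-1.5.tar.gz has already been downloaded to a local path; the helper name `unpack_meteor` is hypothetical and not part of this repository.

```python
import tarfile
from pathlib import Path


def unpack_meteor(archive: str, dest: str = "3rdparty") -> Path:
    """Extract a locally downloaded METEOR archive into `dest`.

    `archive` is the path to meteor-1.5.tar.gz on disk (assumption: the
    file was downloaded beforehand; no download URL is assumed here).
    """
    dest_dir = Path(dest)
    dest_dir.mkdir(parents=True, exist_ok=True)  # create 3rdparty/ if missing
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(dest_dir)
    return dest_dir
```

Calling `unpack_meteor("meteor-1.5.tar.gz")` from the repository root would place the extracted files under 3rdparty.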
Download the following datasets.
- DailyDialog: Download link http://yanran.li/files/ijcnlp_dailydialog.zip
- PersonaChat: Download data using ParlAI (https://parl.ai/docs/tasks.html#persona-chat)
Set the dataset paths correctly in the following files: DialoGPT/create_data.py and T5/create_data.py
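As an illustration of the DailyDialog download step, here is a hedged sketch that fetches the archive from the link above and unpacks it. The destination directory `data/dailydialog` and both function names are assumptions made for this example; the path variables that create_data.py expects are not shown here, so they still have to be set by hand.

```python
import urllib.request
import zipfile
from pathlib import Path

# URL taken from the dataset list above.
DAILYDIALOG_URL = "http://yanran.li/files/ijcnlp_dailydialog.zip"


def extract_zip(zip_path, dest):
    """Unpack `zip_path` into `dest` and return the archive's member names."""
    dest_dir = Path(dest)
    dest_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)
        return zf.namelist()


def fetch_dailydialog(dest="data/dailydialog"):
    """Download the DailyDialog archive and unpack it into `dest`.

    `dest` is a hypothetical location; use whatever directory you then
    reference in DialoGPT/create_data.py and T5/create_data.py.
    """
    dest_dir = Path(dest)
    dest_dir.mkdir(parents=True, exist_ok=True)
    zip_path = dest_dir / "ijcnlp_dailydialog.zip"
    urllib.request.urlretrieve(DAILYDIALOG_URL, str(zip_path))
    return extract_zip(zip_path, dest_dir)
```

`fetch_dailydialog()` needs network access; `extract_zip` can be used on its own if the archive was downloaded manually.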
Train GPT2 and T5 by running the train.py script provided in the respective directories.
- With BoK loss
❱❱❱ python train.py -path=<model_dir> -src_file=train.py -dt=dd/pc -key
- With BoW loss
❱❱❱ python train.py -path=<model_dir> -src_file=train.py -dt=dd/pc -key -all
- Basic Model
❱❱❱ python train.py -path=<model_dir> -src_file=train.py -dt=dd/pc
Generate dialogues for DailyDialog and PersonaChat test data.
- For base model (without BoK/BoW loss)
❱❱❱ python generate.py -path=<model_dir> -dt=dd/pc
- For models using BoK/BoW loss (only response generation)
❱❱❱ python generate.py -path=<model_dir> -dt=dd/pc -key
- For models using BoK/BoW loss (response generation + top-k token prediction)
❱❱❱ python generate_predict.py -path=<model_dir> -dt=dd/pc -key
Post-process the generated and reference files. This step is required only for the DailyDialog dataset. Download multi-reference test data for DailyDialog from this link.
❱❱❱ python post_process_dailydialog.py -in=<file_name>
Compute the metrics.
❱❱❱ python compute_metrics.py -in=<result_path> -hyp=<hyp_file>
Note: <result_path> is the directory that contains the <hyp_file> and the <ref_file>.
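The training, generation, and metric commands above differ only in a few flags, so they can be assembled programmatically. The sketch below builds the argument lists without running anything. It assumes that `-dt=dd/pc` means passing either `dd` (DailyDialog) or `pc` (PersonaChat); the function names and the `bok`/`bow`/`base` labels are hypothetical and exist only in this example.

```python
# Loss variant -> extra train.py flags, as listed in the training step.
LOSS_FLAGS = {
    "bok": ["-key"],          # Bag-of-Keywords loss
    "bow": ["-key", "-all"],  # Bag-of-Words loss
    "base": [],               # basic model, no auxiliary loss
}
DATASETS = ("dd", "pc")  # assumed readings of -dt=dd/pc


def _check(dataset, loss):
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset}")
    if loss not in LOSS_FLAGS:
        raise ValueError(f"unknown loss variant: {loss}")


def train_cmd(model_dir, dataset, loss):
    """Command for train.py, run from the DialoGPT or T5 directory."""
    _check(dataset, loss)
    return ["python", "train.py", f"-path={model_dir}",
            "-src_file=train.py", f"-dt={dataset}", *LOSS_FLAGS[loss]]


def generate_cmd(model_dir, dataset, loss, predict=False):
    """Command for generate.py, or generate_predict.py if predict=True.

    BoK and BoW models both take -key at generation time; the base
    model takes no extra flag and has no token-prediction variant.
    """
    _check(dataset, loss)
    if predict and loss == "base":
        raise ValueError("token prediction requires a BoK/BoW model")
    script = "generate_predict.py" if predict else "generate.py"
    flags = [] if loss == "base" else ["-key"]
    return ["python", script, f"-path={model_dir}", f"-dt={dataset}", *flags]


def eval_cmds(dataset, gen_file, result_path, hyp_file):
    """Post-processing (DailyDialog only) followed by metric computation."""
    _check(dataset, "base")
    cmds = []
    if dataset == "dd":
        cmds.append(["python", "post_process_dailydialog.py", f"-in={gen_file}"])
    cmds.append(["python", "compute_metrics.py",
                 f"-in={result_path}", f"-hyp={hyp_file}"])
    return cmds
```

Each returned list can be passed to `subprocess.run(cmd, check=True, cwd=...)`, with `cwd` set to the directory that holds the script in question.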
Follow the Dial-M repo to train (or download) the Dial-M model.
Run the evaluation script.
❱❱❱ python eval_dialm.py -path=<output_dialm> -dt=dd/pc -out=<out_dir> -out=<model_dir> -lbl=<output_label>
Note: out is the path of the trained model and lbl is the label that was used to generate the output by running the generate.py script.
Follow USL-H to compute the metrics.