diff --git a/README.md b/README.md index d436397d4..7ef16a522 100755 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ The following is a summary of the commonly used NLP scenarios covered in the rep |-------------------------| ------------------- |-------|---| |Text Classification |BERT, XLNet, RoBERTa| Text classification is a supervised learning method of learning and predicting the category or the class of a document given its text content. |English, Hindi, Arabic| |Named Entity Recognition |BERT| Named entity recognition (NER) is the task of classifying words or key phrases of a text into predefined entities of interest. |English| -|Text Summarization|BERTSumExt
BERTSumAbs
UniLM (s2s-ft)
MiniLM |Text summarization is a language generation task of summarizing the input text into a shorter paragraph of text.|English +|Text Summarization|BERTSumExt
BERTSumAbs
UniLM (s2s-ft)
MiniLM
T5
BART|Text summarization is a language generation task of summarizing the input text into a shorter paragraph of text.|English
 |Entailment |BERT, XLNet, RoBERTa| Textual entailment is the task of classifying the binary relation between two natural-language texts, *text* and *hypothesis*, to determine if the *text* agrees with the *hypothesis* or not. |English|
 |Question Answering |BiDAF, BERT, XLNet| Question answering (QA) is the task of retrieving or generating a valid answer for a given query in natural language, provided with a passage related to the query. |English|
 |Sentence Similarity |BERT, GenSen| Sentence similarity is the process of computing a similarity score given a pair of text documents. |English|
diff --git a/examples/text_summarization/abstractive_summarization_bartt5_cnndm.ipynb b/examples/text_summarization/abstractive_summarization_bartt5_cnndm.ipynb
new file mode 100644
index 000000000..194e2a1d2
--- /dev/null
+++ b/examples/text_summarization/abstractive_summarization_bartt5_cnndm.ipynb
@@ -0,0 +1,965 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Copyright (c) Microsoft Corporation. All rights reserved.\n",
+    "\n",
+    "Licensed under the MIT License."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Abstractive Summarization on the CNN/DM Dataset using Transformers\n",
+    "\n",
+    "\n",
+    "### Summary\n",
+    "\n",
+    "This notebook demonstrates how to fine-tune transformer models like [BART](https://arxiv.org/abs/1910.13461) and [T5](https://arxiv.org/abs/1910.10683) together with HuggingFace's [transformers library](https://github.com/huggingface/transformers) for abstractive text summarization. Utility functions and classes in the NLP Best Practices repo are used to facilitate data preprocessing, model training, model scoring, result postprocessing, and model evaluation.\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n",
+    "### Before You Start\n",
+    "\n",
+    "Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of steps. If QUICK_RUN = True, the notebook takes about 5 minutes to run on a VM with one Tesla K80 GPU with 12GB of GPU memory. If QUICK_RUN = False, it takes around 15 minutes for data preprocessing, 15 minutes for fine-tuning, and 3 hours for running evaluation on the whole CNN/DM test dataset.\n",
+    "\n",
+    "### Additional Notes\n",
+    "\n",
+    "* **ROUGE Evaluation**: To run ROUGE evaluation, please refer to the compute_rouge_perl section of [summarization_evaluation.ipynb](./summarization_evaluation.ipynb) for setup.\n",
+    "\n",
+    "* **Distributed Training**:\n",
+    "Please note that a Jupyter notebook can only use PyTorch [DataParallel](https://pytorch.org/docs/master/nn.html#dataparallel). Faster training and larger batch sizes can be achieved with PyTorch [DistributedDataParallel](https://pytorch.org/docs/master/notes/ddp.html) (DDP). Script [extractive_summarization_cnndm_distributed_train.py](./extractive_summarization_cnndm_distributed_train.py) shows an example of how to use DDP; a minimal launch sketch follows in the next cell.\n",
+    "\n"
+   ]
+  },
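+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The following cell is an illustrative sketch of a single-node DDP launch, not meant to be run inside this notebook. The `run` function, address, and port below are assumptions for illustration; in practice each worker builds its own `AbstractiveSummarizer` and calls `fit` with `local_rank`, `rank`, and `world_size` set, as the distributed training script does."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative only: DDP is normally launched from a script, not a notebook.\n",
+    "import torch\n",
+    "import torch.distributed as dist\n",
+    "import torch.multiprocessing as mp\n",
+    "\n",
+    "def run(local_rank, world_size):\n",
+    "    # one process per GPU; NCCL is the usual backend for multi-GPU training\n",
+    "    dist.init_process_group(\n",
+    "        backend=\"nccl\",\n",
+    "        init_method=\"tcp://127.0.0.1:29500\",\n",
+    "        rank=local_rank,\n",
+    "        world_size=world_size,\n",
+    "    )\n",
+    "    torch.cuda.set_device(local_rank)\n",
+    "    # each worker would construct its own summarizer here and call, e.g.,\n",
+    "    # summarizer.fit(..., local_rank=local_rank, rank=local_rank, world_size=world_size)\n",
+    "    dist.destroy_process_group()\n",
+    "\n",
+    "# mp.spawn(run, args=(torch.cuda.device_count(),), nprocs=torch.cuda.device_count())"
+   ]
+  },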
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "tags": [
+     "parameters"
+    ]
+   },
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "\n",
+    "%autoreload 2\n",
+    "## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of steps.\n",
+    "QUICK_RUN = False"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Configuration\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/site-packages/dask/dataframe/utils.py:15: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.\n",
+      "  import pandas.util.testing as tm\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "import shutil\n",
+    "import sys\n",
+    "from tempfile import TemporaryDirectory\n",
+    "import time\n",
+    "import torch\n",
+    "\n",
+    "nlp_path = os.path.abspath(\"../../\")\n",
+    "if nlp_path not in sys.path:\n",
+    "    sys.path.insert(0, nlp_path)\n",
+    "\n",
+    "from utils_nlp.dataset.cnndm import CNNDMSummarizationDataset\n",
+    "from utils_nlp.eval import compute_rouge_python, compute_rouge_perl\n",
+    "from utils_nlp.models.transformers.abstractive_summarization_bartt5 import (\n",
+    "    AbstractiveSummarizer)\n",
+    "\n",
+    "from utils_nlp.models.transformers.datasets import SummarizationDataset\n",
+    "import nltk\n",
+    "from nltk import tokenize\n",
+    "\n",
+    "import pandas as pd\n",
+    "import scrapbook as sb\n",
+    "import pprint\n",
+    "start_time = time.time()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "\n",
+    "### Configuration: choose the transformer model to be used"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Several pretrained models have been made available by [Hugging Face](https://github.com/huggingface/transformers). For abstractive summarization, the following pretrained models are supported."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
model_name
0bart-large
1bart-large-mnli
2bart-large-cnn
3bart-large-xsum
4t5-small
5t5-base
6t5-large
7t5-3b
8t5-11b
\n", + "
" + ], + "text/plain": [ + " model_name\n", + "0 bart-large\n", + "1 bart-large-mnli\n", + "2 bart-large-cnn\n", + "3 bart-large-xsum\n", + "4 t5-small\n", + "5 t5-base\n", + "6 t5-large\n", + "7 t5-3b\n", + "8 t5-11b" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame({\"model_name\": AbstractiveSummarizer.list_supported_models()})" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "# Transformer model being used\n", + "# MODEL_NAME = \"bart-large\"\n", + "MODEL_NAME = \"t5-small\"\n", + "# notebook parameters\n", + "# the cache data path during find tuning\n", + "CACHE_DIR = \"./t5_cache\" #TemporaryDirectory().name\n", + "summarizer = AbstractiveSummarizer(MODEL_NAME, cache_dir=CACHE_DIR)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data Preprocessing\n", + "\n", + "The dataset we used for this notebook is CNN/DM dataset which contains the documents and accompanying questions from the news articles of CNN and Daily mail. The highlights in each article are used as summary. The dataset consits of ~289K training examples, ~11K valiation examples and ~11K test examples. The code in following cell will download the CNN/DM dataset listed at https://github.com/harvardnlp/sent-summary/.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "# the data path used to save the downloaded data file\n", + "DATA_PATH = \"./bartt5_cnndm\" #TemporaryDirectory().name\n", + "# The number of lines at the head of data file used for preprocessing. -1 means all the lines.\n", + "TOP_N = 100\n", + "if not QUICK_RUN:\n", + " TOP_N = -1" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "train_dataset, test_dataset = CNNDMSummarizationDataset(top_n=TOP_N, local_cache_path=DATA_PATH, sent_split=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\"" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_dataset[0]['tgt_txt']" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "287227\n", + "11490\n" + ] + } + ], + "source": [ + "print(len(train_dataset))\n", + "print(len(test_dataset))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Preprocess the data." 
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from multiprocessing import Pool\n",
+    "def preprocess(summarizer, input_data_list, num_workers=50, chunk_size=100, internal_batch_size=5e3):\n",
+    "    \"\"\" Preprocess the data for abstractive summarization.\n",
+    "\n",
+    "    Args:\n",
+    "        input_data_list (list of dictionary): input list where each item is\n",
+    "            a dictionary with string fields \"src\" and \"tgt\".\n",
+    "        num_workers (int, optional): The number of workers in the pool.\n",
+    "            Defaults to 50.\n",
+    "        chunk_size (int, optional): The number of examples a worker processes\n",
+    "            at a time. Defaults to 100.\n",
+    "        internal_batch_size (int, optional): The number of examples the pool\n",
+    "            processes per batch. Defaults to 5000. Reduce this number if you\n",
+    "            see a segmentation fault.\n",
+    "\n",
+    "    Returns:\n",
+    "        list of dictionaries with additional fields \"source_ids\",\n",
+    "        \"source_mask\" and \"target_ids\".\n",
+    "    \"\"\"\n",
+    "    i = 0\n",
+    "    temp_dir = TemporaryDirectory().name\n",
+    "    os.makedirs(temp_dir, mode=0o777, exist_ok=False)\n",
+    "    temp_file = \".temp_preprocess\"\n",
+    "    processed_length = 0\n",
+    "    result = []\n",
+    "    print(len(input_data_list))\n",
+    "    pool = Pool(num_workers, initializer=summarizer.processor.initializer)\n",
+    "    while processed_length < len(input_data_list):\n",
+    "        max_length = int(min(processed_length+internal_batch_size, len(input_data_list)))\n",
+    "        temp = []\n",
+    "        for j in range(processed_length, max_length):\n",
+    "            temp.append(input_data_list[j])\n",
+    "        result_generator = pool.imap(summarizer.processor.encode_example, temp, chunk_size)\n",
+    "        # save each encoded batch to disk so memory usage stays bounded\n",
+    "        torch.save(list(result_generator), os.path.join(temp_dir, temp_file+str(i)))\n",
+    "        i += 1\n",
+    "        processed_length = max_length\n",
+    "        #print(processed_length)\n",
+    "\n",
+    "    pool.close()\n",
+    "    pool.join()\n",
+    "    result = []\n",
+    "    total_batch_number = i\n",
+    "    for i in range(total_batch_number):\n",
+    "        result.extend(torch.load(os.path.join(temp_dir, temp_file+str(i))))\n",
+    "    if os.path.exists(temp_dir):\n",
+    "        shutil.rmtree(temp_dir, ignore_errors=True)\n",
+    "    return result\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "287227\n",
+      "CPU times: user 1min 43s, sys: 57.7 s, total: 2min 40s\n",
+      "Wall time: 13min 22s\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%time\n",
+    "# abs_sum_train = summarizer.processor.preprocess(train_dataset)\n",
+    "abs_sum_train = preprocess(summarizer, train_dataset)\n",
+    "# torch.save(abs_sum_train, os.path.join(DATA_PATH, \"train_{0}_full.pt\".format(MODEL_NAME)))\n",
+    "# abs_sum_train = torch.load(os.path.join(DATA_PATH, \"train_{0}_full.pt\".format(MODEL_NAME)))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# abs_sum_train = torch.load(os.path.join(DATA_PATH, \"train_{0}_full.pt\".format(MODEL_NAME)))\n",
+    "\n",
+    "# torch.save(abs_sum_test, os.path.join(DATA_PATH, \"test_{0}_full.pt\".format(MODEL_NAME)))\n",
+    "# abs_sum_test = torch.load(os.path.join(DATA_PATH, \"test_{0}_full.pt\".format(MODEL_NAME)))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "11490\n",
+      "CPU times: user 4.37 s, sys: 5.82 s, total: 10.2 s\n",
+      "Wall time: 37.6 s\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%time\n",
+    "# abs_sum_test = summarizer.processor.preprocess(test_dataset)\n",
+    "abs_sum_test = preprocess(summarizer, test_dataset)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "287227\n",
+      "11490\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(len(abs_sum_train))\n",
+    "print(len(abs_sum_test))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Inspect Data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "dict_keys(['source_ids', 'source_mask', 'target_ids'])"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "abs_sum_train[0].keys()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'source_ids': tensor([21603,    10,  6005,  ...,     3,    31,     7]),\n",
+       " 'source_mask': tensor([1, 1, 1,  ..., 1, 1, 1]),\n",
+       " 'target_ids': tensor([19367,     3,  1092,    16, 11171,    16,  1337,  3690,    33,   629,\n",
+       "                           26,    30,     8,    96, 11821,  1501,    96,  5191,     3,   849,\n",
+       "                         1926,    90,    99,   348,   845,   167,    33,   132,    38,     3,\n",
+       "                            9,   741,    13,    96,  1792,   179,  3110,   106,   725,    96,\n",
+       "                          298,     3,    75,    29,    29,  8108,  3064,     3,     6,  1868,\n",
+       "                        14314,     7,     3,    10,    96,     3,    23,   183,     8,   520,\n",
+       "                           13,     8,  2753,    96,    90,    99,   348,   845,     8,   358,\n",
+       "                           19,    73,  4998,    11,     3,    88,     3,    31,     7,  6237,\n",
+       "                           21,   483,     3,     5,     0,     0,     0,     0,     0,     0,\n",
+       "                            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
+       "                            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
+       "                            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
+       "                            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
+       "                            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
+       "                            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
+       "                            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
+       "                            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
+       "                            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
+       "                            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
+       "                            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
+       "                            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
+       "                            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
+       "                            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
+       "                            0,     0,     0,     0,     0,     0,     0,     0,     0,     0])}"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "abs_sum_train[0]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Fine-tune Model\n",
+    "To start model fine-tuning, we need to specify the parameters as follows."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {
+    "tags": [
+     "parameters"
+    ]
+   },
+   "outputs": [],
+   "source": [
+    "BATCH_SIZE_PER_GPU = 4\n",
+    "GRADIENT_ACCUMULATION_STEPS = 1\n",
+    "MAX_POS_LENGTH = 512\n",
+    "\n",
+    "# number of GPUs used for training\n",
+    "NUM_GPUS = torch.cuda.device_count()\n",
+    "\n",
+    "\n",
+    "# Learning rate\n",
+    "LEARNING_RATE=3e-5\n",
+    "MAX_GRAD_NORM=0.1\n",
+    "\n",
+    "# How often training statistics are reported; the unit is steps.\n",
+    "REPORT_EVERY=100\n",
+    "SAVE_EVERY=1000\n",
+    "\n",
+    "# total number of steps for training\n",
+    "MAX_STEPS=100\n",
+    "# number of steps for warm up\n",
+    "WARMUP_STEPS=5e1\n",
+    "    \n",
+    "if not QUICK_RUN:\n",
+    "    MAX_STEPS=1000\n",
+    "    WARMUP_STEPS=5e2\n",
+    "    \n",
+    "# inference parameters\n",
+    "TEST_PER_GPU_BATCH_SIZE = 96\n",
+    "BEAM_SIZE = 3\n",
+    "    "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "timestamp: 21/05/2020 14:55:00, average loss: 2.950966, time duration: 91.118061,\n",
+      "    number of examples in current reporting: 400, step 100\n",
+      "    out of total 1000\n",
+      "timestamp: 21/05/2020 14:56:28, average loss: 2.496747, time duration: 87.725076,\n",
+      "    number of examples in current reporting: 400, step 200\n",
+      "    out of total 1000\n",
+      "timestamp: 21/05/2020 14:57:57, average loss: 2.232086, time duration: 88.453045,\n",
+      "    number of examples in current reporting: 400, step 300\n",
+      "    out of total 1000\n",
+      "timestamp: 21/05/2020 14:59:25, average loss: 2.104675, time duration: 88.361590,\n",
+      "    number of examples in current reporting: 400, step 400\n",
+      "    out of total 1000\n",
+      "timestamp: 21/05/2020 15:00:53, average loss: 1.996439, time duration: 88.524355,\n",
+      "    number of examples in current reporting: 400, step 500\n",
+      "    out of total 1000\n",
+      "timestamp: 21/05/2020 15:02:22, average loss: 1.918175, time duration: 89.008806,\n",
+      "    number of examples in current reporting: 400, step 600\n",
+      "    out of total 1000\n",
+      "timestamp: 21/05/2020 15:03:52, average loss: 1.981885, time duration: 89.146901,\n",
+      "    number of examples in current reporting: 400, step 700\n",
+      "    out of total 1000\n",
+      "timestamp: 21/05/2020 15:05:20, average loss: 1.892776, time duration: 88.711287,\n",
+      "    number of examples in current reporting: 400, step 800\n",
+      "    out of total 1000\n",
+      "timestamp: 21/05/2020 15:06:49, average loss: 1.842356, time duration: 88.697730,\n",
+      "    number of examples in current reporting: 400, step 900\n",
+      "    out of total 1000\n",
+      "timestamp: 21/05/2020 15:08:17, average loss: 1.920532, time duration: 88.288858,\n",
+      "    number of examples in current reporting: 400, step 1000\n",
+      "    out of total 1000\n",
+      "./t5_cache\n",
+      "saving through pytorch to ./t5_cache/t5-small_step_1000.pt\n",
+      "saving through pytorch to ./t5_cache/fine_tuned/abssum_t5-small.pt\n"
+     ]
+    }
+   ],
+   "source": [
+    "\n",
+    "summarizer.fit(\n",
+    "    abs_sum_train,\n",
+    "    num_gpus=NUM_GPUS,\n",
+    "    batch_size=BATCH_SIZE_PER_GPU*NUM_GPUS,\n",
+    "    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n",
+    "    max_steps=MAX_STEPS,\n",
+    "    learning_rate=LEARNING_RATE,\n",
+    "    max_grad_norm=MAX_GRAD_NORM,\n",
+    "    warmup_steps=WARMUP_STEPS,\n",
+    "    verbose=True,\n",
+    "    report_every=REPORT_EVERY,\n",
+    "    )\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
torch\\nmodel_path = os.path.join(\\n CACHE_DIR,\\n \"abssum_modelname_{0}_steps_{1}.pt\".format(\\n MODEL_NAME, MAX_STEPS\\n ))\\nsummarizer.save_model(global_step=MAX_STEPS, full_name=model_path)\\n\\nsummarizer = AbstractiveSummarizer(MODEL_NAME, cache_dir=CACHE_DIR)\\nsummarizer.model.load_state_dict(torch.load(model_path, map_location=\"cpu\")[\\'model\\'])\\n'" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# save a finetuned model and load a previous saved model\n", + "\"\"\"\n", + "import torch\n", + "model_path = os.path.join(\n", + " CACHE_DIR,\n", + " \"abssum_modelname_{0}_steps_{1}.pt\".format(\n", + " MODEL_NAME, MAX_STEPS\n", + " ))\n", + "summarizer.save_model(global_step=MAX_STEPS, full_name=model_path)\n", + "\n", + "summarizer = AbstractiveSummarizer(MODEL_NAME, cache_dir=CACHE_DIR)\n", + "summarizer.model.load_state_dict(torch.load(model_path, map_location=\"cpu\")['model'])\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model Evaluation\n", + "\n", + "[ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)), or Recall-Oriented Understudy for Gisting Evaluation has been commonly used for evaluating text summarization. \n", + "For the settings in this notebook with QUICK_RUN=False, you should get ROUGE scores close to the following numbers:\n", + "\n", + "``\n", + "{'rouge-1': {'f': 0.3532833731474843,\n", + " 'p': 0.5062112092750258,\n", + " 'r': 0.2854026986121758},\n", + " 'rouge-2': {'f': 0.1627400891022247,\n", + " 'p': 0.23802173638805246,\n", + " 'r': 0.13034686738843493},\n", + " 'rouge-l': {'f': 0.2587374492685969,\n", + " 'p': 0.3710902340617733,\n", + " 'r': 0.20909466938819835}}\n", + " `` " + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "source = []\n", + "target = []\n", + "\n", + " \n", + "for i in test_dataset:\n", + " source.append(i[\"src_txt\"]) \n", + " target.append(i['tgt'].replace(\"\",\"\").replace(\"\", \"\").replace(\"\\n\", \"\")) " + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r", + "Generating summary: 0%| | 0/120 [00:00", "", line) return line +def _remove_tags(line): + line = re.sub(r"", "", line) + # change to + # pyrouge test requires as sentence splitter + line = re.sub(r"", "", line) + return line + def _target_sentence_tokenization(line): return line.split("") @@ -71,7 +78,7 @@ def CNNDMSummarizationDataset(*args, **kwargs): URLS = ["https://s3.amazonaws.com/opennmt-models/Summary/cnndm.tar.gz"] def _setup_datasets( - url, top_n=-1, local_cache_path=".data", prepare_extractive=True + url, top_n=-1, local_cache_path=".data", sent_split=True, prepare_extractive=True ): FILE_NAME = "cnndm.tar.gz" maybe_download(url, FILE_NAME, local_cache_path) @@ -86,6 +93,30 @@ def _setup_datasets( test_source_file = fname if fname.endswith("test.txt.tgt.tagged"): test_target_file = fname + if not sent_split: + return ( + SummarizationDataset( + train_source_file, + target_file=train_target_file, + source_preprocessing=[_clean,], + target_preprocessing=[ + _clean, + _remove_tags, + ], + top_n=top_n + ), + SummarizationDataset( + test_source_file, + source_preprocessing=[_clean,], + target_preprocessing=[ + _clean, + _remove_tags, + ], + target_file=test_target_file, + top_n=top_n + ), + + ) if prepare_extractive: diff --git 
diff --git a/utils_nlp/models/transformers/abstractive_summarization_bartt5.py b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py
new file mode 100644
index 000000000..1fa9b71f8
--- /dev/null
+++ b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py
@@ -0,0 +1,744 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+# This script reuses some code from https://github.com/huggingface/transformers/
+
+
+import functools
+import logging
+import os
+import pickle
+from tqdm import tqdm
+import torch
+from torch.utils.data import (
+    DataLoader,
+    SequentialSampler,
+    RandomSampler,
+)
+from torch.utils.data.distributed import DistributedSampler
+import torch.multiprocessing
+from torch import nn
+
+
+torch.multiprocessing.set_sharing_strategy("file_system")
+
+from transformers import (
+    AutoConfig,
+    AutoModelWithLMHead,
+    AutoTokenizer,
+    BartForConditionalGeneration,
+    BART_PRETRAINED_MODEL_ARCHIVE_MAP,
+    T5ForConditionalGeneration,
+    T5_PRETRAINED_MODEL_ARCHIVE_MAP,
+)
+from transformers.tokenization_utils import trim_batch
+
+from utils_nlp.common.pytorch_utils import (
+    compute_training_steps,
+    get_device,
+    move_model_to_device,
+    parallelize_model,
+)
+from utils_nlp.models.transformers.common import Transformer
+
+
+MODEL_MODES = {
+    "language-modeling": AutoModelWithLMHead,
+}
+
+MODEL_CLASS = {}
+MODEL_CLASS.update(
+    {k: BartForConditionalGeneration for k in BART_PRETRAINED_MODEL_ARCHIVE_MAP}
+)
+MODEL_CLASS.update(
+    {k: T5ForConditionalGeneration for k in T5_PRETRAINED_MODEL_ARCHIVE_MAP}
+)
+
+logger = logging.getLogger(__name__)
+
+
+class Predictor(nn.Module):
+    """
+    Predictor which can run on multiple GPUs.
+
+    Args:
+        model (AbstractiveSummarizer): the summarizer model which will
+            be used for prediction.
+        min_length (int): the minimum generated summary length.
+        max_length (int): the maximum generated summary length.
+        kwargs (dict): Additional kwargs that will be forwarded to the
+            model's generate function. Please consult the arguments of
+            `PreTrainedModel.generate`.
+
+    """
+
+    def __init__(self, model, min_length=55, max_length=140, **kwargs):
+        super(Predictor, self).__init__()
+        self.model = model.module if hasattr(model, "module") else model
+        self.min_length = min_length
+        self.max_length = max_length
+        self.config = kwargs
+
+    def forward(self, src, src_mask):
+        """ Generate sequences for models with a LM head.
+
+        Args:
+            src (torch.LongTensor of shape `(batch_size, sequence_length)`):
+                The sequences used as prompts for the generation.
+            src_mask (torch.LongTensor of the same shape as `src`):
+                Mask to avoid performing attention on padding token indices.
+                Mask values selected in ``[0, 1]``:
+                ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
+        """
+
+        device = src.device
+        with torch.no_grad():
+            summaries = self.model.generate(
+                input_ids=src,
+                attention_mask=src_mask,
+                min_length=self.min_length,
+                max_length=self.max_length,
+                **self.config,
+            )
+            # right-pad every generated sequence to max_length so the
+            # per-GPU outputs can be gathered into a single tensor
+            predictions = torch.tensor(
+                [
+                    i.tolist()[0 : self.max_length]
+                    + [0] * (self.max_length - i.size()[0])
+                    for i in summaries
+                ],
+                device=device,
+            )
+
+        return predictions
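+
+
+# Note (illustrative): `validate` below is intended to be wrapped in a closure
+# and passed to `AbstractiveSummarizer.fit` as `validation_function`, e.g.
+#     validation_function=lambda s: validate(s, validation_dataset)
+# `fit` invokes it at every report interval; `validation_dataset` here stands
+# for any preprocessed dataset slice held out for validation.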
+def validate(summarizer, validate_dataset, num_gpus=1, TOP_N=2):
+    """ Validation function to be used optionally in fine-tuning.
+
+    Args:
+        summarizer (AbstractiveSummarizer): The summarizer under fine-tuning.
+        validate_dataset (SummarizationDataset): dataset for validation.
+        num_gpus (int, optional): number of GPUs used for validation.
+            Defaults to 1.
+        TOP_N (int, optional): the number of examples used from
+            validate_dataset. Defaults to 2.
+
+    Returns:
+        None.
+    """
+    shortened_dataset = validate_dataset[0:TOP_N]
+    # batch a few examples and compute the model loss directly
+    a = summarizer.processor.collate_fn(shortened_dataset, "cuda:0", True)
+    c = summarizer.processor.get_inputs(
+        a, "cuda:0", summarizer.model_name, summarizer.tokenizer, True
+    )
+
+    output = summarizer.model(**c)
+    generated_summaries = summarizer.predict(
+        shortened_dataset, num_gpus=num_gpus, batch_size=TOP_N
+    )
+    print("validation loss is {}".format(output[0]))
+    print("prediction is {}".format(generated_summaries[0]))
+
+
+class SummarizationProcessor:
+    """ Class for preprocessing abstractive summarization data for BART/T5 models.
+
+    Args:
+        tokenizer (AutoTokenizer): tokenizer for the model used for preprocessing.
+        config (AutoConfig): config for the model used for preprocessing.
+        max_source_length (int, optional): Maximum number of tokens in the
+            source. Defaults to 1024.
+        max_target_length (int, optional): Maximum number of tokens in the
+            target. Defaults to 140.
+
+    """
+
+    def __init__(
+        self, tokenizer, config, max_source_length=1024, max_target_length=140,
+    ):
+
+        self.tokenizer = tokenizer
+        self.config = config
+
+        self.prefix = config.prefix
+        self.with_target = False
+        self.max_source_length = max_source_length
+        self.max_target_length = max_target_length
+
+    def initializer(self):
+        # bind the tokenizer to a module-level global so encode_example
+        # can be used as a multiprocessing worker function
+        global tokenizer
+        tokenizer = self.tokenizer
+
+    def encode_example(
+        self, example,
+    ):
+        """ Preprocess a single data example for abstractive summarization.
+
+        Args:
+            example (dict): a data item with string fields "src" and "tgt".
+
+        Returns:
+            a dictionary with fields "source_ids",
+            "source_mask" and "target_ids".
+
+        """
+
+        global tokenizer
+        result = {}
+        prefix = self.prefix
+        max_source_length = self.max_source_length
+        max_target_length = self.max_target_length
+        pad_to_max_length = True
+        return_tensors = "pt"
+
+        tokenized_source = tokenizer.batch_encode_plus(
+            [prefix + example["src"]],
+            max_length=max_source_length,
+            pad_to_max_length=pad_to_max_length,
+            return_tensors=return_tensors,
+        )
+
+        source_ids = tokenized_source["input_ids"].squeeze()
+        src_mask = tokenized_source["attention_mask"].squeeze()
+        result["source_ids"] = source_ids
+        result["source_mask"] = src_mask
+        if "tgt" in example:
+            tokenized_target = tokenizer.batch_encode_plus(
+                [example["tgt"]],
+                max_length=max_target_length,
+                pad_to_max_length=pad_to_max_length,
+                return_tensors=return_tensors,
+            )
+            target_ids = tokenized_target["input_ids"].squeeze()
+            result["target_ids"] = target_ids
+        return result
+
+    def preprocess(self, input_data_list):
+        """ Preprocess a list of data examples for abstractive summarization.
+
+        Args:
+            input_data_list (list of dictionary): input list where each item is
+                a dictionary with string fields "src" and "tgt".
+
+        Returns:
+            list of dictionaries with fields "source_ids",
+            "source_mask" and "target_ids".
+        """
+
+        result = []
+        for i in input_data_list:
+            result.append(self.encode_example(i))
+        return result
+
+    @staticmethod
+    def get_inputs(batch, device, model_name, tokenizer=None, train_mode=True):
+        """
+        Creates an input dictionary given a model name.
+
+        Args:
+            batch (dict): A batch containing source ids and source masks.
+                If train_mode is True, it also contains the target ids.
+            device (torch.device): A PyTorch device.
+            model_name (str, optional): Model name used to format the inputs.
+            tokenizer (AutoTokenizer, optional): tokenizer whose pad_token_id
+                will be used for processing.
+            train_mode (bool, optional): Training mode flag.
+                Defaults to True.
+
+        Returns:
+            dict: Dictionary containing source ids and attention masks.
+                Decoder input ids and LM labels are only returned when
+                train_mode is True.
+        """
+
+        pad_token_id = tokenizer.pad_token_id
+        if not train_mode:
+            source_ids, source_mask = batch["source_ids"], batch["source_mask"]
+            return {
+                "input_ids": source_ids,
+                "attention_mask": source_mask,
+            }
+
+        else:
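+            # Teacher forcing (explanatory note): decoder inputs are the
+            # target shifted right (final token dropped) and LM labels are the
+            # target shifted left; label positions holding padding are set to
+            # -100 so the cross-entropy loss ignores them.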
+            y = trim_batch(batch["target_ids"], pad_token_id)
+            source_ids, source_mask = trim_batch(
+                batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"]
+            )
+            y_ids = y[:, :-1].contiguous()
+            lm_labels = y[:, 1:].clone()
+            lm_labels[y[:, 1:] == pad_token_id] = -100
+
+            return {
+                "input_ids": source_ids,
+                "attention_mask": source_mask,
+                "decoder_input_ids": y_ids,
+                "lm_labels": lm_labels,
+            }
+
+    def collate_fn(self, batch, device, train_mode=False):
+        """ Collate function for the data loader: stacks the examples of a
+        batch and trims excess padding, batch by batch, so full-length
+        tensors never need to be kept in memory.
+
+        Args:
+            batch (list of dictionary): input data to be loaded.
+            device (torch.device): A PyTorch device.
+            train_mode (bool, optional): Training mode flag.
+                Defaults to False.
+
+        Returns:
+            dict: a dictionary containing source ids and source masks.
+                If train_mode is True, it also contains the target ids.
+        """
+
+        input_ids = torch.stack([x["source_ids"] for x in batch])
+        masks = torch.stack([x["source_mask"] for x in batch])
+        pad_token_id = self.tokenizer.pad_token_id
+        source_ids, source_mask = trim_batch(
+            input_ids, pad_token_id, attention_mask=masks
+        )
+        if train_mode:
+            target_ids = torch.stack([x["target_ids"] for x in batch])
+            y = trim_batch(target_ids, pad_token_id)
+            return {
+                "source_ids": source_ids.to(device),
+                "source_mask": source_mask.to(device),
+                "target_ids": y.to(device),
+            }
+        else:
+            return {
+                "source_ids": source_ids.to(device),
+                "source_mask": source_mask.to(device),
+            }
" + "Call 'AbstractiveSummarizer.list_supported_models()' to" + "get all supported model " + "names.".format(value) + ) + + self.config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir,) + self.config.output_past = True # to enable num_beams greater than 1 + task_specific_params = self.config.task_specific_params + if task_specific_params is not None: + self.config.update(task_specific_params.get("summarization", {})) + self.config.update({"max_length": max_target_length}) + self.config.update({"attention_dropout": 0.1}) + + self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir,) + + self.processor = SummarizationProcessor( + self.tokenizer, self.config, max_source_length, max_target_length + ) + + self._model_name = model_name + self.model = MODEL_MODES["language-modeling"].from_pretrained( + self.model_name, config=self.config, cache_dir=cache_dir, + ) + + self.cache_dir = cache_dir + self.max_source_length = max_source_length + self.max_target_length = max_target_length + + self.amp = None + self.optimizer = None + self.scheduler = None + + @staticmethod + def list_supported_models(): + return list(MODEL_CLASS) + + def fit( + self, + train_dataset, + num_gpus=None, + gpu_ids=None, + batch_size=4, + local_rank=-1, + max_steps=5e4, + warmup_steps=2e3, + learning_rate=0.002, + weight_decay=0.01, + adam_epsilon=1e-8, + max_grad_norm=1.0, + gradient_accumulation_steps=1, + report_every=10, + save_every=1000, + verbose=True, + seed=None, + fp16=False, + fp16_opt_level="O2", + world_size=1, + rank=0, + validation_function=None, + checkpoint=None, + **kwargs, + ): + """ + Fine-tune pre-trained transofmer models for extractive summarization. + + Args: + train_dataset (SummarizationDataset): Training dataset. + num_gpus (int, optional): The number of GPUs to use. If None, all + available GPUs will be used. If set to 0 or GPUs are + not available, CPU device will be used. Defaults to None. + gpu_ids (list): List of GPU IDs to be used. + If set to None, the first num_gpus GPUs will be used. + Defaults to None. + batch_size (int, optional): Maximum number of examples in each batch. + local_rank (int, optional): Local_rank for distributed training on GPUs. + Local rank means the ranking of the current GPU device on the current + node. Defaults to -1, which means non-distributed training. + max_steps (int, optional): Maximum number of training steps. + Defaults to 5e4. + warmup_steps (int, optional): Number of steps taken to increase + learning rate from 0 to `learning_rate`. Defaults to 2e3. + learning_rate (float, optional): Learning rate of the optimizer. + Defaults to 0.002. + weight_decay (float, optional): Weight decay to apply after each parameter + update. Defaults to 0.01. + adam_epsilon (float, optional): Epsilon of the AdamW optimizer. + Defaults to 1e-8. + max_grad_norm (float, optional): Maximum gradient norm for + gradient clipping. Defaults to 0. + gradient_accumulation_steps (int, optional): Number of batches to accumulate + gradients on between each model parameter update. Defaults to 1. + report_every (int, optional): The interval by steps to print out the + training log. Defaults to 10. + save_every (int, optional): The interval by steps to save the finetuned + model. Defaults to 100. + verbose (bool, optional): Whether to print out the training log. + Defaults to True. + seed (int, optional): Random seed used to improve reproducibility. + Defaults to None. + fp16 (bool, optional): Whether to use mixed precision training. + Defaults to False. 
+            fp16_opt_level (str, optional): optimization level, refer to
+                https://nvidia.github.io/apex/amp.html#opt-levels for details.
+                Value choices are: "O0", "O1", "O2", "O3". Defaults to "O2".
+            world_size (int, optional): Total number of GPUs that will be used.
+                Defaults to 1.
+            rank (int, optional): Global rank of the current GPU in distributed
+                training. It's calculated with the rank of the current node in the
+                cluster/world and the `local_rank` of the device in the current node.
+                See an example in :file: `examples/text_summarization/
+                abstractive_summarization_bertsum_cnndm_distributed_train.py`.
+                Defaults to 0.
+            validation_function (function, optional): function used during fitting
+                to validate the performance. Defaults to None.
+            checkpoint (str, optional): file path for a checkpoint based on which the
+                training continues. Defaults to None.
+        """
+
+        # move model to devices
+        checkpoint_state_dict = None
+        if checkpoint:
+            # checkpoint should have "model", "optimizer", "amp"
+            checkpoint_state_dict = torch.load(checkpoint, map_location="cpu")
+
+        # init optimizer
+        device, num_gpus, amp = self.prepare_model_and_optimizer(
+            num_gpus=num_gpus,
+            gpu_ids=gpu_ids,
+            local_rank=local_rank,
+            fp16=fp16,
+            fp16_opt_level=fp16_opt_level,
+            weight_decay=weight_decay,
+            learning_rate=learning_rate,
+            adam_epsilon=adam_epsilon,
+            checkpoint_state_dict=checkpoint_state_dict,
+        )
+
+        self.amp = amp
+
+        global_step = 0
+        if (
+            checkpoint_state_dict
+            and "global_step" in checkpoint_state_dict
+            and checkpoint_state_dict["global_step"]
+        ):
+            global_step = checkpoint_state_dict["global_step"] / world_size
+            print("global_step is {}".format(global_step))
+
+        self.scheduler = Transformer.get_default_scheduler(
+            optimizer=self.optimizer,
+            warmup_steps=warmup_steps,
+            num_training_steps=max_steps,
+        )
+        if global_step > 0:
+            self.scheduler.load_state_dict(checkpoint_state_dict["lr_scheduler"])
+
+        if local_rank == -1:
+            sampler = RandomSampler(train_dataset)
+        else:
+            sampler = DistributedSampler(
+                train_dataset, num_replicas=world_size, rank=rank
+            )
+
+        def collate_fn(data):
+            return self.processor.collate_fn(data, device, train_mode=True)
+
+        train_dataloader = DataLoader(
+            train_dataset,
+            sampler=sampler,
+            batch_size=batch_size,
+            collate_fn=collate_fn,
+        )
+
+        # compute the max number of training steps
+        max_steps = compute_training_steps(
+            train_dataloader,
+            max_steps=max_steps,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+        )
+
+        get_inputs = functools.partial(
+            self.processor.get_inputs, tokenizer=self.processor.tokenizer
+        )
+        super().fine_tune(
+            train_dataloader=train_dataloader,
+            get_inputs=get_inputs,
+            device=device,
+            num_gpus=num_gpus,
+            max_steps=max_steps,
+            global_step=global_step,
+            max_grad_norm=max_grad_norm,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            verbose=verbose,
+            seed=seed,
+            report_every=report_every,
+            save_every=save_every,
+            optimizer=self.optimizer,
+            scheduler=self.scheduler,
+            fp16=fp16,
+            amp=amp,
+            validation_function=validation_function,
+        )
+
+        # release GPU memories
+        self.model.cpu()
+        torch.cuda.empty_cache()
+
+        self.save_model(global_step=max_steps)
+
+    def predict(
+        self,
+        test_dataset,
+        num_gpus=None,
+        gpu_ids=None,
+        local_rank=-1,
+        batch_size=16,
+        min_length=56,
+        max_length=140,
+        num_beams=4,
+        length_penalty=2.0,
+        no_repeat_ngram_size=3,
+        early_stopping=True,
+        fp16=False,
+        verbose=True,
+        checkpoint=None,
+        **predictor_kwargs,
+    ):
+        """
+        Predict summaries for the input dataset.
+
+        Args:
+            test_dataset (SummarizationDataset): Dataset for which the summaries
+                are to be predicted.
+            num_gpus (int, optional): The number of GPUs used in prediction.
+                If None, all available GPUs will be used. Defaults to None.
+            gpu_ids (list): List of GPU IDs to be used.
+                If set to None, the first num_gpus GPUs will be used.
+                Defaults to None.
+            local_rank (int, optional): Local rank of the device in distributed
+                inferencing. Defaults to -1, which means non-distributed inferencing.
+            batch_size (int, optional): The number of test examples in each batch.
+                Defaults to 16.
+            min_length (int, optional): Minimum number of tokens in the output sequence.
+                Defaults to 56.
+            max_length (int, optional): Maximum number of tokens in the output
+                sequence. Defaults to 140.
+            num_beams (int, optional): Beam size for beam search. Defaults to 4.
+            length_penalty (float, optional): Exponential penalty to the length.
+                Defaults to 2.0.
+            no_repeat_ngram_size (int, optional): If set to int >0, all ngrams of size
+                `no_repeat_ngram_size` can only occur once in the generated summary.
+                Defaults to 3.
+            early_stopping (bool, optional): If set to `True`, beam search is stopped
+                when at least `num_beams` sentences are finished per batch.
+                Defaults to True.
+            fp16 (bool, optional): Whether to use a half-precision model for prediction.
+                Defaults to False.
+            verbose (bool, optional): Whether to print out the prediction progress.
+                Defaults to True.
+            checkpoint (str, optional): Path of a checkpoint file whose "model"
+                state dict is loaded before prediction. Defaults to None.
+            predictor_kwargs (dict, optional): Additional kwargs that will be forwarded
+                to `Predictor`. Please consult the arguments of
+                `PreTrainedModel.generate`.
+
+        Returns:
+            list of str: The generated summaries.
+
+        """
+
+        device, num_gpus = get_device(
+            num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank
+        )
+        model = move_model_to_device(self.model, device)
+
+        checkpoint_state_dict = None
+        if checkpoint:
+            # checkpoint should have "model", "optimizer", "amp"
+            checkpoint_state_dict = torch.load(checkpoint, map_location="cpu")
+            model.load_state_dict(checkpoint_state_dict["model"])
+
+        model.eval()
+
+        model = parallelize_model(
+            model, device, num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank,
+        )
+
+        if fp16:
+            model = model.half()
+
+        test_sampler = SequentialSampler(test_dataset)
+
+        def collate_fn(data):
+            return self.processor.collate_fn(data, device, train_mode=False)
+
+        test_dataloader = DataLoader(
+            test_dataset,
+            sampler=test_sampler,
+            batch_size=batch_size,
+            collate_fn=collate_fn,
+        )
+        print("dataset length is {}".format(len(test_dataset)))
+
+        predictor = Predictor(
+            model,
+            min_length,
+            max_length,
+            num_beams=num_beams,
+            length_penalty=length_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
+            early_stopping=early_stopping,
+            **predictor_kwargs,
+        )
+
+        # move model to devices
+        def this_model_move_callback(model, device):
+            model = move_model_to_device(model, device)
+            return parallelize_model(
+                model, device, num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank
+            )
+
+        predictor = this_model_move_callback(predictor, device)
+
+        generated_summaries = []
+
+        for batch in tqdm(
+            test_dataloader, desc="Generating summary", disable=not verbose
+        ):
+            input_ids, masks = trim_batch(
+                batch["source_ids"],
+                self.tokenizer.pad_token_id,
+                attention_mask=batch["source_mask"],
+            )
+            summaries = predictor(input_ids, masks)
+            decoded_summaries = [
+                self.tokenizer.decode(
+                    g, skip_special_tokens=True, clean_up_tokenization_spaces=True
+                )
+                for g in summaries
+            ]
+
+            generated_summaries.extend(decoded_summaries)
+
+            # release GPU memories
+            # self.model.cpu()
+            del batch
+            torch.cuda.empty_cache()
+
+        return generated_summaries
+
+    def save_model(self, global_step=None, full_name=None):
+        """
+        Save the trained model.
+
+        Args:
+            global_step (int, optional): The number of steps that the model has been
+                fine-tuned for. Defaults to None.
+            full_name (str, optional): File name to save the model's `state_dict()`.
+                If it's None, the model is going to be saved under the "fine_tuned"
+                folder of the cache directory of the object. Defaults to None.
+        """
+        model_to_save = (
+            self.model.module if hasattr(self.model, "module") else self.model
+        )  # Take care of distributed/parallel training
+
+        if full_name is None:
+            output_model_dir = os.path.join(self.cache_dir, "fine_tuned")
+            os.makedirs(self.cache_dir, exist_ok=True)
+            os.makedirs(output_model_dir, exist_ok=True)
+            full_name = os.path.join(
+                output_model_dir, "abssum_{}.pt".format(self.model_name)
+            )
+        else:
+            path, filename = os.path.split(full_name)
+            print(path)
+            os.makedirs(path, exist_ok=True)
+
+        checkpoint = {
+            "optimizer": self.optimizer.state_dict() if self.optimizer else None,
+            "lr_scheduler": self.scheduler.state_dict() if self.scheduler else None,
+            "model": model_to_save.state_dict(),
+            "amp": self.amp.state_dict() if self.amp else None,
+            "global_step": global_step,
+            "max_source_length": self.max_source_length,
+            "max_target_length": self.max_target_length,
+        }
+
+        logger.info("Saving model checkpoint to %s", full_name)
+        try:
+            print("saving through pytorch to {}".format(full_name))
+            torch.save(checkpoint, full_name)
+        except OSError:
+            # fall back to pickle if torch.save cannot write the file
+            print("saving as pickle")
+            pickle.dump(checkpoint, open(full_name, "wb"))
diff --git a/utils_nlp/models/transformers/common.py b/utils_nlp/models/transformers/common.py
index cbe845c5f..ec193e7ed 100755
--- a/utils_nlp/models/transformers/common.py
+++ b/utils_nlp/models/transformers/common.py
@@ -232,7 +232,7 @@ def fine_tune(
             epoch_iterator = tqdm(
                 train_dataloader,
                 desc="Iteration",
-                disable=local_rank not in [-1, 0] or not verbose,
+                disable=True,  # local_rank not in [-1, 0] or not verbose
             )
             for step, batch in enumerate(epoch_iterator):
                 inputs = get_inputs(batch, device, self.model_name)
@@ -291,6 +291,10 @@ def fine_tune(
                         )
                         logger.info(log_line)
                         print(log_line)
+                        if validation_function:
+                            validation_log = validation_function(self)
+                            logger.info(validation_log)
+                            print(validation_log)
                         accum_loss = 0
                         train_size = 0
                         start = end
@@ -318,10 +322,6 @@ def fine_tune(
                             self.cache_dir, f"{self.model_name}_step_{global_step}.pt"
                         )
                         self.save_model(global_step, saved_model_path)
-                    if validation_function:
-                        validation_log = validation_function(self)
-                        logger.info(validation_log)
-                        print(validation_log)
                     if global_step > max_steps:
                         epoch_iterator.close()
                         break
diff --git a/utils_nlp/models/transformers/datasets.py b/utils_nlp/models/transformers/datasets.py
index 0c659f190..e21e7d95b 100644
--- a/utils_nlp/models/transformers/datasets.py
+++ b/utils_nlp/models/transformers/datasets.py
@@ -519,7 +519,7 @@ def parallel_preprocess(
                 word_tokenize=word_tokenize,
             ),
             input_data,
-            chunksize=min(1, int(len(input_data) / num_pool)),
+            chunksize=max(1, int(len(input_data) / num_pool)),
         )
     p.close()
     p.join()
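The `chunksize` fix above (and the identical one below) addresses the same bug: `min(1, n)` is 1 for any non-empty input, so each worker received one example at a time and paid inter-process overhead per example, whereas `max(1, n)` hands each worker a single large chunk. A quick sketch of the arithmetic, with made-up sizes:

```python
# Made-up sizes, for illustration only.
num_pool = 8
len_input = 100_000

old_chunksize = min(1, int(len_input / num_pool))  # always 1 for non-empty input
new_chunksize = max(1, int(len_input / num_pool))  # 12500: one big chunk per worker

print(old_chunksize, new_chunksize)  # -> 1 12500
```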
diff --git a/utils_nlp/models/transformers/extractive_summarization.py b/utils_nlp/models/transformers/extractive_summarization.py
index 2753685df..e7cb7d84f 100644
--- a/utils_nlp/models/transformers/extractive_summarization.py
+++ b/utils_nlp/models/transformers/extractive_summarization.py
@@ -302,7 +302,7 @@ def parallel_preprocess(input_data, preprocess, num_pool=-1):
     p = Pool(num_pool)
     results = p.map(
-        preprocess, input_data, chunksize=min(1, int(len(input_data) / num_pool)),
+        preprocess, input_data, chunksize=max(1, int(len(input_data) / num_pool)),
     )
     p.close()
     p.join()
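For reference, a compact sketch of the end-to-end API these diffs add, condensed from the notebook above; the paths and hyperparameter values are illustrative, not recommendations:

```python
# Condensed, illustrative usage of the new AbstractiveSummarizer API.
from utils_nlp.dataset.cnndm import CNNDMSummarizationDataset
from utils_nlp.models.transformers.abstractive_summarization_bartt5 import (
    AbstractiveSummarizer,
)

# Load CNN/DM without sentence splitting (the new sent_split=False path).
train_dataset, test_dataset = CNNDMSummarizationDataset(
    top_n=100, local_cache_path="./bartt5_cnndm", sent_split=False
)

summarizer = AbstractiveSummarizer("t5-small", cache_dir="./t5_cache")

# In-process encoding; the notebook's multiprocessing helper is faster at scale.
summarizer.processor.initializer()  # binds the tokenizer encode_example expects
train_data = summarizer.processor.preprocess(train_dataset)
test_data = summarizer.processor.preprocess(test_dataset)

summarizer.fit(train_data, num_gpus=1, batch_size=4, max_steps=100, warmup_steps=50)
summaries = summarizer.predict(test_data, num_gpus=1, batch_size=16, num_beams=3)
print(summaries[0])
```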