diff --git a/.gitignore b/.gitignore index 44b34da..9bf0d92 100644 --- a/.gitignore +++ b/.gitignore @@ -163,4 +163,5 @@ data/ # Experimental files notebooks/.vector_cache/* -notebooks/my_awesome_qa_model/runs* \ No newline at end of file +notebooks/my_awesome_qa_model/runs/* +flexnlp/notebooks/my_awesome_qa_model/runs/* \ No newline at end of file diff --git a/flexnlp/notebooks/Federated IMDb PT using FLExible with a GRU.ipynb b/flexnlp/notebooks/Federated IMDb PT using FLExible with a GRU.ipynb index 119aaa3..1149c97 100644 --- a/flexnlp/notebooks/Federated IMDb PT using FLExible with a GRU.ipynb +++ b/flexnlp/notebooks/Federated IMDb PT using FLExible with a GRU.ipynb @@ -39,6 +39,13 @@ "print(device)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Cargamos el dataset" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/flexnlp/notebooks/Federated QA with Hugginface using FLEXIBLE.ipynb b/flexnlp/notebooks/Federated QA with Hugginface using FLEXIBLE.ipynb new file mode 100644 index 0000000..d2fa8ce --- /dev/null +++ b/flexnlp/notebooks/Federated QA with Hugginface using FLEXIBLE.ipynb @@ -0,0 +1,650 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "import torch\n", + "import torch.nn as nn\n", + "from datasets import load_dataset\n", + "from datasets import Dataset as HFDataset\n", + "from transformers import AutoTokenizer\n", + "from transformers import DefaultDataCollator\n", + "from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer\n", + "import collections\n", + "import numpy as np\n", + "import evaluate" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "device = (\n", + " \"cuda\"\n", + " if torch.cuda.is_available()\n", + " else \"mps\"\n", + " if torch.backends.mps.is_available()\n", + " else \"cpu\"\n", + ")\n", + "\n", + "print(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load el dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load a percentage of squal\n", + "squad = load_dataset(\"squad\", split=\"train[:1%]\")\n", + "# Split 80% train, 20% test\n", + "squad = squad.train_test_split(test_size=0.2)\n", + "print(squad)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Preprocess" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_checkpoint = \"distilbert-base-uncased\"\n", + "#model_checkpoint = \"bert-base-cased\"\n", + "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n", + "\n", + "max_length = 384\n", + "stride = 128\n", + "\n", + "\n", + "def preprocess_training_examples_as_lists(examples, answers_examples):\n", + " \"\"\"\n", + " Function that preprocess the data that comes as a list \n", + " instead as a Dataset type.\n", + " Args:\n", + " examples (list[list]): List of lists containg the examples to\n", + " preprocess. ['id', 'title', 'context', 'question']\n", + " answers (list[str]): List containing the answers\n", + " \"\"\"\n", + " questions = [q[3].strip() for q in examples]\n", + " contexts = [c[2] for c in examples]\n", + " inputs = tokenizer(\n", + " questions,\n", + " # examples[\"context\"],\n", + " contexts,\n", + " max_length=max_length,\n", + " truncation=\"only_second\",\n", + " stride=stride,\n", + " return_overflowing_tokens=True,\n", + " return_offsets_mapping=True,\n", + " padding=\"max_length\",\n", + " )\n", + "\n", + " offset_mapping = inputs.pop(\"offset_mapping\")\n", + " sample_map = inputs.pop(\"overflow_to_sample_mapping\")\n", + " # answers = examples[\"answers\"]\n", + " answers = [answers_examples[1][i] for i in range(len(answers_examples[1]))]\n", + " start_positions = []\n", + " end_positions = []\n", + "\n", + " for i, offset in enumerate(offset_mapping):\n", + " sample_idx = sample_map[i]\n", + " answer = answers[sample_idx]\n", + " start_char = answer[\"answer_start\"][0]\n", + " end_char = answer[\"answer_start\"][0] + len(answer[\"text\"][0])\n", + " sequence_ids = inputs.sequence_ids(i)\n", + "\n", + " # Find the start and end of the context\n", + " idx = 0\n", + " while sequence_ids[idx] != 1:\n", + " idx += 1\n", + " context_start = idx\n", + " while sequence_ids[idx] == 1:\n", + " idx += 1\n", + " context_end = idx - 1\n", + "\n", + " # If the answer is not fully inside the context, label is (0, 0)\n", + " if offset[context_start][0] > start_char or offset[context_end][1] < end_char:\n", + " start_positions.append(0)\n", + " end_positions.append(0)\n", + " else:\n", + " # Otherwise it's the start and end token positions\n", + " idx = context_start\n", + " while idx <= context_end and offset[idx][0] <= start_char:\n", + " idx += 1\n", + " start_positions.append(idx - 1)\n", + "\n", + " idx = context_end\n", + " while idx >= context_start and offset[idx][1] >= end_char:\n", + " idx -= 1\n", + " end_positions.append(idx + 1)\n", + "\n", + " inputs[\"start_positions\"] = start_positions\n", + " inputs[\"end_positions\"] = end_positions\n", + " return HFDataset.from_dict(inputs)\n", + "\n", + "def preprocess_validation_examples(examples):\n", + " questions = [q.strip() for q in examples[\"question\"]]\n", + " inputs = tokenizer(\n", + " questions,\n", + " examples[\"context\"],\n", + " max_length=max_length,\n", + " truncation=\"only_second\",\n", + " stride=stride,\n", + " return_overflowing_tokens=True,\n", + " return_offsets_mapping=True,\n", + " padding=\"max_length\",\n", + " )\n", + "\n", + " sample_map = inputs.pop(\"overflow_to_sample_mapping\")\n", + " example_ids = []\n", + "\n", + " for i in range(len(inputs[\"input_ids\"])):\n", + " sample_idx = sample_map[i]\n", + " example_ids.append(examples[\"id\"][sample_idx])\n", + "\n", + " sequence_ids = inputs.sequence_ids(i)\n", + " offset = inputs[\"offset_mapping\"][i]\n", + " inputs[\"offset_mapping\"][i] = [\n", + " o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)\n", + " ]\n", + "\n", + " inputs[\"example_id\"] = example_ids\n", + " return inputs\n", + "\"\"\"\n", + "train_dataset = squad[\"train\"].map(\n", + " preprocess_training_examples,\n", + " batched=True,\n", + " remove_columns=squad[\"train\"].column_names,\n", + ")\n", + "\"\"\"\n", + "test_dataset = squad[\"test\"].map(\n", + " preprocess_validation_examples,\n", + " batched=True,\n", + " remove_columns=squad[\"test\"].column_names,\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = squad[\"train\"]\n", + "print(train_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# test_dataset = squad[\"test\"]\n", + "print(test_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# From centralized data to federated data\n", + "\n", + "First we're going to federate the dataset using the FedDataDristibution class, that has functions to load multiple datasets from deep learning libraries such as PyTorch or TensorFlow. In this notebook we are using PyTorch, so we need to use the functions from the PyTorch ecosystem, and for the text datasets, we need to use the function `from_config_with_huggingface_dataset`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flex.data import FedDatasetConfig, FedDataDistribution\n", + "\n", + "config = FedDatasetConfig(seed=0)\n", + "config.n_clients = 2\n", + "config.replacement = False # ensure that clients do not share any data\n", + "config.client_names = ['client1', 'client2'] # Optional\n", + "flex_dataset = FedDataDistribution.from_config_with_huggingface_dataset(data=train_dataset, config=config,\n", + " X_columns=['id', 'title', 'context', 'question'],\n", + " label_columns=['answers']\n", + " # X_columns=['input_ids', 'attention_mask'],\n", + " # label_columns=['start_positions', 'end_positions']\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We may also want to use the FLEXible dataset for the test data, so we just use da function `from_huggingface_dataset` in the Dataset class." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flex.data import Dataset\n", + "\n", + "test_dataset = Dataset.from_huggingface_dataset(test_dataset,\n", + " X_columns=['input_ids', 'attention_mask'])\n", + " # label_columns=['start_positions', 'end_positions'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2) Federate a model with FLEXible.\n", + "\n", + "Once we've federated the dataset, it's time to create the FlexPool. The FlexPool class is the one that simulates the real-time scenario for federated learning, so it is in charge of the communications across actors. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flex.model import FlexModel\n", + "from flex.pool import FlexPool\n", + "\n", + "from flex.pool.decorators import init_server_model\n", + "from flex.pool.decorators import deploy_server_model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this notebook we are going to simulate a client-server architecture, which we can easily build using the FlexPool class, using the function `client_server_architecture`. This function needs a FlexDataset, which we already have prepared, and a function to initialize the server model, which we have to create.\n", + "\n", + "The model we are going to use is a simple LSTM, which will have the embeddings, the LSTM, a Linear layer and the output layer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@init_server_model\n", + "def build_server_model():\n", + " server_flex_model = FlexModel()\n", + "\n", + " server_flex_model['model'] = AutoModelForQuestionAnswering.from_pretrained(\"distilbert-base-uncased\")\n", + " # Required to store this for later stages of the FL training process\n", + " server_flex_model['training_args'] = TrainingArguments(\n", + " output_dir=\"my_awesome_qa_model\",\n", + " evaluation_strategy=\"epoch\",\n", + " learning_rate=2e-5,\n", + " per_device_train_batch_size=16,\n", + " per_device_eval_batch_size=16,\n", + " num_train_epochs=3,\n", + " weight_decay=0.01,\n", + " )\n", + "\n", + " # server_flex_model['trainer'] = trainer\n", + "\n", + " return server_flex_model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once we've defined the function to initialize the server model, we can create the FlexPool using the function `client_server_architecture`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "flex_pool = FlexPool.client_server_pool(fed_dataset=flex_dataset, init_func=build_server_model)\n", + "\n", + "clients = flex_pool.clients\n", + "servers = flex_pool.servers\n", + "aggregators = flex_pool.aggregators\n", + "\n", + "print(f\"Number of nodes in the pool {len(flex_pool)}: {len(servers)} server plus {len(clients)} clients. The server is also an aggregator\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can use the decorator `deploy_server_model` to create a custom function that deploys our server model, or we can use the primitive `deploy_server_model_pt` to deploy the server model to the clients." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flex.pool import deploy_server_model, deploy_server_model_pt\n", + "\n", + "@deploy_server_model\n", + "def copy_server_model_to_clients(server_flex_model: FlexModel):\n", + " return deepcopy(server_flex_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "servers.map(copy_server_model_to_clients, clients) # Using the function created with the decorator\n", + "# servers.map(deploy_server_model_pt, clients) # Using the primitive function" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As text needs to be preprocessed and batched on the clients, we can do it on the train function.\n", + "\n", + "As we have preprocesed the text before federating the data, and we are using the `Trainer` class from the Transformers library, we can train the client's models using the `train` function from the `Trainer` class" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import Dataset as TorchDataset\n", + "\n", + "class QADataset(TorchDataset):\n", + " def __init__(self, encodings, labels) -> None:\n", + " self.encodings = encodings\n", + " self.labels = labels" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Train each client's model\n", + "def train(client_flex_model: FlexModel, client_data: Dataset):\n", + " print(\"Training client\")\n", + " model = client_flex_model['model']\n", + " training_args = client_flex_model['training_args']\n", + " # client_train_dataset = client_data.to_numpy()\n", + " X_data = client_data.X_data.tolist()\n", + " y_data = client_data.to_list()\n", + " client_train_dataset = preprocess_training_examples_as_lists(examples=X_data, answers_examples=y_data)\n", + " trainer = Trainer(\n", + " model = model,\n", + " args=training_args,\n", + " train_dataset=client_train_dataset,\n", + " tokenizer=tokenizer,\n", + " )\n", + " trainer.train()\n", + "\n", + "clients.map(train)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After training the model, we have to aggregate the weights from the clients model in order to update the global model. To to so, we are going to use the primitive `collect_clients_weights_pt`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flex.pool import collect_clients_weights_pt\n", + "\n", + "aggregators.map(collect_clients_weights_pt, clients)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once the weights are aggregated, we aggregate them. In this notebook we use the FedAvg method that is already implemented in FLEXible." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flex.pool import fed_avg\n", + "\n", + "aggregators.map(fed_avg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The function `set_aggregated_weights_pt` sed the aggregated weights to the server model to update it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flex.pool import set_aggregated_weights_pt\n", + "\n", + "aggregators.map(set_aggregated_weights_pt, servers)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now it's turn to evaluate the global model. To do so, we have to create a function using the decoratod `evaluate_server_model`. \n", + "\n", + "In question answering we have to postprocess the predictions obtained, so we have created the function `compute_metrics` that will give us the performance of the model. Here we use the trainer function too. To do so, we creater a trainer instance in the server's FlexModel." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n_best = 20\n", + "max_answer_length = 30\n", + "predicted_answers = []\n", + "metric = evaluate.load(\"squad\")\n", + "\n", + "def compute_metrics(start_logits, end_logits, features, examples):\n", + " example_to_features = collections.defaultdict(list)\n", + " for idx, feature in enumerate(features):\n", + " example_to_features[feature[\"example_id\"]].append(idx)\n", + "\n", + " predicted_answers = []\n", + " for example in tqdm(examples):\n", + " example_id = example[\"id\"]\n", + " context = example[\"context\"]\n", + " answers = []\n", + "\n", + " # Loop through all features associated with that example\n", + " for feature_index in example_to_features[example_id]:\n", + " start_logit = start_logits[feature_index]\n", + " end_logit = end_logits[feature_index]\n", + " offsets = features[feature_index][\"offset_mapping\"]\n", + "\n", + " start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()\n", + " end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()\n", + " for start_index in start_indexes:\n", + " for end_index in end_indexes:\n", + " # Skip answers that are not fully in the context\n", + " if offsets[start_index] is None or offsets[end_index] is None:\n", + " continue\n", + " # Skip answers with a length that is either < 0 or > max_answer_length\n", + " if (\n", + " end_index < start_index\n", + " or end_index - start_index + 1 > max_answer_length\n", + " ):\n", + " continue\n", + "\n", + " answer = {\n", + " \"text\": context[offsets[start_index][0] : offsets[end_index][1]],\n", + " \"logit_score\": start_logit[start_index] + end_logit[end_index],\n", + " }\n", + " answers.append(answer)\n", + "\n", + " # Select the answer with the best score\n", + " if len(answers) > 0:\n", + " best_answer = max(answers, key=lambda x: x[\"logit_score\"])\n", + " predicted_answers.append(\n", + " {\"id\": example_id, \"prediction_text\": best_answer[\"text\"]}\n", + " )\n", + " else:\n", + " predicted_answers.append({\"id\": example_id, \"prediction_text\": \"\"})\n", + "\n", + " theoretical_answers = [{\"id\": ex[\"id\"], \"answers\": ex[\"answers\"]} for ex in examples]\n", + " return metric.compute(predictions=predicted_answers, references=theoretical_answers)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: Test the training phase on GPU before evaluating the model\n", + "from flex.pool import evaluate_server_model\n", + "\n", + "\n", + "@evaluate_server_model\n", + "def evaluate_global_model(server_flex_model: FlexModel, test_data=None):\n", + " model = server_flex_model[\"model\"]\n", + " training_args = server_flex_model[\"training_args\"]\n", + " trainer = Trainer(\n", + " model = model,\n", + " args=training_args,\n", + " train_dataset=test_data,\n", + " # eval_dataset=validation_dataset,\n", + " tokenizer=tokenizer,\n", + " # data_collator=data_collator,\n", + " )\n", + " predictions, _, _ = trainer.predict(test_data)\n", + " start_logits, end_logits = predictions\n", + " print(\"Server metrics:\", compute_metrics(start_logits, end_logits, test_data, squad[\"test\"]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "servers.map(evaluate_global_model, test_data=test_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run the federated learning experiment for a few rounds\n", + "\n", + "Now, we can summarize the steps provided above and run the federated experiment for multiple rounds:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def train_n_rounds(n_rounds, clients_per_round=2): \n", + " pool = FlexPool.client_server_pool(fed_dataset=flex_dataset, init_func=build_server_model)\n", + " for i in range(n_rounds):\n", + " print(f\"\\nRunning round: {i+1} of {n_rounds}\")\n", + " selected_clients_pool = pool.clients.select(clients_per_round)\n", + " selected_clients = selected_clients_pool.clients\n", + " print(f\"Selected clients for this round: {len(selected_clients)}\")\n", + " # Deploy the server model to the selected clients\n", + " pool.servers.map(deploy_server_model_pt, selected_clients)\n", + " # Each selected client trains her model\n", + " selected_clients.map(train)\n", + " # The aggregador collects weights from the selected clients and aggregates them\n", + " pool.aggregators.map(collect_clients_weights_pt, selected_clients)\n", + " pool.aggregators.map(fed_avg)\n", + " # The aggregator send its aggregated weights to the server\n", + " pool.aggregators.map(set_aggregated_weights_pt, pool.servers)\n", + " pool.servers.map(evaluate_global_model, test_data=test_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# train_n_rounds(5)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "flexible", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/flexnlp/notebooks/Federated_QA.py b/flexnlp/notebooks/Federated_QA.py new file mode 100644 index 0000000..2620345 --- /dev/null +++ b/flexnlp/notebooks/Federated_QA.py @@ -0,0 +1,380 @@ +from copy import deepcopy +from tqdm.auto import tqdm +import torch +import torch.nn as nn +from datasets import load_dataset +from datasets import Dataset as HFDataset +from transformers import AutoTokenizer +from transformers import DefaultDataCollator +from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer +import collections +import numpy as np +import evaluate + +# FLEXible imports +from flex.data import FedDatasetConfig, FedDataDistribution +from flex.data import Dataset +from flex.model import FlexModel +from flex.pool import FlexPool +from flex.pool.decorators import init_server_model +from flex.pool.decorators import deploy_server_model +from flex.pool import deploy_server_model, deploy_server_model_pt +from torch.utils.data import Dataset as TorchDataset +from flex.pool import collect_clients_weights_pt +from flex.pool import fed_avg +from flex.pool import set_aggregated_weights_pt +from flex.pool import evaluate_server_model + + +device = ( + "cuda" + if torch.cuda.is_available() + else "mps" + if torch.backends.mps.is_available() + else "cpu" +) + +print(device) + +# Load the dataset +# Load a percentage of squal +squad = load_dataset("squad", split="train[:1%]") +# Split 80% train, 20% test +squad = squad.train_test_split(test_size=0.9) +print(squad) + +# Preprocess functions + +model_checkpoint = "distilbert-base-uncased" +#model_checkpoint = "bert-base-cased" +tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) + +max_length = 512 +stride = 128 + + +def preprocess_training_examples(examples): + questions = [q.strip() for q in examples["question"]] + inputs = tokenizer( + questions, + examples["context"], + max_length=max_length, + truncation="only_second", + stride=stride, + return_overflowing_tokens=True, + return_offsets_mapping=True, + padding="max_length", + ) + + offset_mapping = inputs.pop("offset_mapping") + sample_map = inputs.pop("overflow_to_sample_mapping") + answers = examples["answers"] + start_positions = [] + end_positions = [] + + for i, offset in enumerate(offset_mapping): + sample_idx = sample_map[i] + answer = answers[sample_idx] + start_char = answer["answer_start"][0] + end_char = answer["answer_start"][0] + len(answer["text"][0]) + sequence_ids = inputs.sequence_ids(i) + + # Find the start and end of the context + idx = 0 + while sequence_ids[idx] != 1: + idx += 1 + context_start = idx + while sequence_ids[idx] == 1: + idx += 1 + context_end = idx - 1 + + # If the answer is not fully inside the context, label is (0, 0) + if offset[context_start][0] > start_char or offset[context_end][1] < end_char: + start_positions.append(0) + end_positions.append(0) + else: + # Otherwise it's the start and end token positions + idx = context_start + while idx <= context_end and offset[idx][0] <= start_char: + idx += 1 + start_positions.append(idx - 1) + + idx = context_end + while idx >= context_start and offset[idx][1] >= end_char: + idx -= 1 + end_positions.append(idx + 1) + + inputs["start_positions"] = start_positions + inputs["end_positions"] = end_positions + return inputs + +def preprocess_training_examples_as_lists(examples, answers_examples): + """ + Function that preprocess the data that comes as a list + instead as a Dataset type. + Args: + examples (list[list]): List of lists containg the examples to + preprocess. ['id', 'title', 'context', 'question'] + answers (list[str]): List containing the answers + """ + questions = [q[3].strip() for q in examples] + contexts = [c[2] for c in examples] + inputs = tokenizer( + questions, + # examples["context"], + contexts, + max_length=max_length, + truncation="only_second", + stride=stride, + return_overflowing_tokens=True, + return_offsets_mapping=True, + padding="max_length", + ) + + offset_mapping = inputs.pop("offset_mapping") + sample_map = inputs.pop("overflow_to_sample_mapping") + # answers = examples["answers"] + answers = [answers_examples[1][i] for i in range(len(answers_examples[1]))] + start_positions = [] + end_positions = [] + + for i, offset in enumerate(offset_mapping): + sample_idx = sample_map[i] + answer = answers[sample_idx] + start_char = answer["answer_start"][0] + end_char = answer["answer_start"][0] + len(answer["text"][0]) + sequence_ids = inputs.sequence_ids(i) + + # Find the start and end of the context + idx = 0 + while sequence_ids[idx] != 1: + idx += 1 + context_start = idx + while sequence_ids[idx] == 1: + idx += 1 + context_end = idx - 1 + + # If the answer is not fully inside the context, label is (0, 0) + if offset[context_start][0] > start_char or offset[context_end][1] < end_char: + start_positions.append(0) + end_positions.append(0) + else: + # Otherwise it's the start and end token positions + idx = context_start + while idx <= context_end and offset[idx][0] <= start_char: + idx += 1 + start_positions.append(idx - 1) + + idx = context_end + while idx >= context_start and offset[idx][1] >= end_char: + idx -= 1 + end_positions.append(idx + 1) + + inputs["start_positions"] = start_positions + inputs["end_positions"] = end_positions + return HFDataset.from_dict(inputs) + + +def preprocess_validation_examples(examples): + questions = [q.strip() for q in examples["question"]] + inputs = tokenizer( + questions, + examples["context"], + max_length=max_length, + truncation="only_second", + stride=stride, + return_overflowing_tokens=True, + return_offsets_mapping=True, + padding="max_length", + ) + + sample_map = inputs.pop("overflow_to_sample_mapping") + example_ids = [] + + for i in range(len(inputs["input_ids"])): + sample_idx = sample_map[i] + example_ids.append(examples["id"][sample_idx]) + + sequence_ids = inputs.sequence_ids(i) + offset = inputs["offset_mapping"][i] + inputs["offset_mapping"][i] = [ + o if sequence_ids[k] == 1 else None for k, o in enumerate(offset) + ] + + inputs["example_id"] = example_ids + return inputs + + +train_dataset_processed = squad["train"].map( + preprocess_training_examples, + batched=True, + remove_columns=squad["train"].column_names, +) + + +train_dataset = squad["train"] + +test_dataset = squad["test"].map( + preprocess_validation_examples, + batched=True, + remove_columns=squad["test"].column_names, +) + +# From centralized to federated data + +config = FedDatasetConfig(seed=0) +config.n_clients = 2 +config.replacement = False # ensure that clients do not share any data +config.client_names = ['client1', 'client2'] # Optional +flex_dataset = FedDataDistribution.from_config_with_huggingface_dataset(data=train_dataset, config=config, + X_columns=['id', 'title', 'context', 'question'], + label_columns=['answers'] + # X_columns=['input_ids', 'attention_mask'], + # label_columns=['start_positions', 'end_positions'] + ) + +# Init the server model and deploy it +@init_server_model +def build_server_model(): + server_flex_model = FlexModel() + + server_flex_model['model'] = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-uncased") + # Required to store this for later stages of the FL training process + server_flex_model['training_args'] = TrainingArguments( + output_dir="my_awesome_qa_model", + # evaluation_strategy="epoch", + learning_rate=2e-5, + per_device_train_batch_size=16, + per_device_eval_batch_size=16, + num_train_epochs=3, + weight_decay=0.01, + ) + + # server_flex_model['trainer'] = trainer + + return server_flex_model + +flex_pool = FlexPool.client_server_pool(fed_dataset=flex_dataset, init_func=build_server_model) + +clients = flex_pool.clients +servers = flex_pool.servers +aggregators = flex_pool.aggregators + +print(f"Number of nodes in the pool {len(flex_pool)}: {len(servers)} server plus {len(clients)} clients. The server is also an aggregator") + +# Deploy the model +@deploy_server_model +def copy_server_model_to_clients(server_flex_model: FlexModel): + return deepcopy(server_flex_model) + +servers.map(copy_server_model_to_clients, clients) # Using the function created with the decorator +# servers.map(deploy_server_model_pt, clients) # Using the primitive function + +# Train each client's model +def train(client_flex_model: FlexModel, client_data: Dataset): + print("Training client") + model = client_flex_model['model'] + training_args = client_flex_model['training_args'] + # client_train_dataset = client_data.to_numpy() + X_data = client_data.X_data.tolist() + y_data = client_data.to_list() + client_train_dataset = preprocess_training_examples_as_lists(examples=X_data, answers_examples=y_data) + # breakpoint() + trainer = Trainer( + model = model, + args=training_args, + train_dataset=client_train_dataset, + # eval_dataset=validation_dataset, + tokenizer=tokenizer, + # data_collator=data_collator, + ) + trainer.train() + +clients.map(train) + +aggregators.map(collect_clients_weights_pt, clients) + +aggregators.map(fed_avg) + +aggregators.map(set_aggregated_weights_pt, servers) + +# TODO: Add the evaluate function +n_best = 20 +max_answer_length = 30 +predicted_answers = [] +metric = evaluate.load("squad") + +def compute_metrics(start_logits, end_logits, features, examples): + example_to_features = collections.defaultdict(list) + for idx, feature in enumerate(features): + example_to_features[feature["example_id"]].append(idx) + + predicted_answers = [] + for example in tqdm(examples): + example_id = example["id"] + context = example["context"] + answers = [] + + # Loop through all features associated with that example + for feature_index in example_to_features[example_id]: + start_logit = start_logits[feature_index] + end_logit = end_logits[feature_index] + offsets = features[feature_index]["offset_mapping"] + + start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist() + end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist() + for start_index in start_indexes: + for end_index in end_indexes: + # Skip answers that are not fully in the context + if offsets[start_index] is None or offsets[end_index] is None: + continue + # Skip answers with a length that is either < 0 or > max_answer_length + if ( + end_index < start_index + or end_index - start_index + 1 > max_answer_length + ): + continue + + answer = { + "text": context[offsets[start_index][0] : offsets[end_index][1]], + "logit_score": start_logit[start_index] + end_logit[end_index], + } + answers.append(answer) + + # Select the answer with the best score + if len(answers) > 0: + best_answer = max(answers, key=lambda x: x["logit_score"]) + predicted_answers.append( + {"id": example_id, "prediction_text": best_answer["text"]} + ) + else: + predicted_answers.append({"id": example_id, "prediction_text": ""}) + + theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples] + return metric.compute(predictions=predicted_answers, references=theoretical_answers) + +@evaluate_server_model +def evaluate_global_model(server_flex_model: FlexModel, test_data=None): + model = server_flex_model["model"] + training_args = server_flex_model["training_args"] + trainer = Trainer( + model = model, + args=training_args, + train_dataset=test_data, + # eval_dataset=validation_dataset, + tokenizer=tokenizer, + # data_collator=data_collator, + ) + predictions, _, _ = trainer.predict(test_data) + start_logits, end_logits = predictions + print(len(predictions[0])) + return compute_metrics(start_logits, end_logits, test_data, squad["test"]) + +# predictions, _, _ = trainer.predict(validation_dataset) +# start_logits, end_logits = predictions + +# print(len(predictions[0])) +# compute_metrics(start_logits, end_logits, validation_dataset, squad["test"]) +metrics = servers.map(evaluate_global_model, test_data=test_dataset) + +print(metrics) diff --git a/setup.py b/setup.py index 94e57c3..f2a0e6a 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,8 @@ "torchtext", "portalocker", "torchdata", + "datasets", + "transformers" ], extras_require={ "tensorflow": TF_requires,