From ee594f93fbfc2d0a34f3651b2bc03a3e94e78308 Mon Sep 17 00:00:00 2001 From: AlArgente Date: Tue, 23 Jan 2024 08:56:23 +0100 Subject: [PATCH] Added the notebook for the Federated Semantic Search task. --- ... SentenceTransformers using FLEXible.ipynb | 430 ++++++++++++++++++ 1 file changed, 430 insertions(+) create mode 100644 flexnlp/notebooks/Federated SS with SentenceTransformers using FLEXible.ipynb diff --git a/flexnlp/notebooks/Federated SS with SentenceTransformers using FLEXible.ipynb b/flexnlp/notebooks/Federated SS with SentenceTransformers using FLEXible.ipynb new file mode 100644 index 0000000..f785638 --- /dev/null +++ b/flexnlp/notebooks/Federated SS with SentenceTransformers using FLEXible.ipynb @@ -0,0 +1,430 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training a Semantic Similarity/Semantic Search with Sentence Transformers using FLEXible." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "from datasets import load_dataset\n", + "from sentence_transformers import SentenceTransformer, models # , util\n", + "from sentence_transformers import InputExample\n", + "from sentence_transformers import losses\n", + "from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, TripletEvaluator\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.utils.data import DataLoader\n", + "from datasets import Dataset as HFDataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "device = (\n", + " \"cuda\"\n", + " if torch.cuda.is_available()\n", + " else \"mps\"\n", + " if torch.backends.mps.is_available()\n", + " else \"cpu\"\n", + ")\n", + "\n", + "print(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load el dataset\n", + "\n", + "First we load the dataset. As there isn't federated datasets for this task, it is needed to load a centralized dataset and federate it. In this tutorial we are using the ´embedding-data/QQP_triplets´ dataset from **Huggigface Datasets**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the dataset\n", + "dataset_id = \"embedding-data/QQP_triplets\"\n", + "# dataset_id = \"embedding-data/sentence-compression\"\n", + "\n", + "data = load_dataset(dataset_id, split=['train[:1%]'])[0].train_test_split(test_size=0.1)\n", + "dataset, test_dataset = data['train'], data['test']\n", + "print(f\"- The {dataset_id} dataset has {dataset.num_rows} examples.\")\n", + "print(f\"- Each example is a {type(dataset[0])} with a {type(dataset[0]['set'])} as value.\")\n", + "print(f\"- Examples look like this: {dataset[0]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# From centralized data to federated data\n", + "\n", + "First we're going to federate the dataset using the FedDataDristibution class, that has functions to load multiple datasets from deep learning libraries such as PyTorch or TensorFlow. In this notebook we are using PyTorch, so we need to use the functions from the PyTorch ecosystem, and for the text datasets, we need to use the function `from_config_with_torchtext_dataset`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flex.data import FedDatasetConfig, FedDataDistribution\n", + "\n", + "config = FedDatasetConfig(seed=0)\n", + "config.n_clients = 2\n", + "config.replacement = False # ensure that clients do not share any data\n", + "config.client_names = ['client1', 'client2'] # Optional\n", + "flex_dataset = FedDataDistribution.from_config_with_huggingface_dataset(data=dataset, config=config,\n", + " X_columns=['set'], # 'title', 'context', 'question'],\n", + " label_columns=['set']\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2) Federate a model with FLEXible.\n", + "\n", + "Once we've federated the dataset, it's time to create the FlexPool. The FlexPool class is the one that simulates the real-time scenario for federated learning, so it is in charge of the communications across actors. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flex.model import FlexModel\n", + "from flex.pool import FlexPool\n", + "\n", + "from flex.pool.decorators import init_server_model\n", + "from flex.pool.decorators import deploy_server_model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this notebook we are going to simulate a client-server architecture, which we can easily build using the FlexPool class, using the function `client_server_architecture`. This function needs a FlexDataset, which we already have prepared, and a function to initialize the server model, which we have to create.\n", + "\n", + "The model we are going to use is a simple LSTM, which will have the embeddings, the LSTM, a Linear layer and the output layer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@init_server_model\n", + "def build_server_model():\n", + " server_flex_model = FlexModel()\n", + " word_embedding_model = models.Transformer('distilroberta-base')\n", + " pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())\n", + " server_flex_model['model'] = SentenceTransformer(modules=[word_embedding_model, pooling_model])\n", + "\n", + " return server_flex_model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once we've defined the function to initialize the server model, we can create the FlexPool using the function `client_server_architecture`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "flex_pool = FlexPool.client_server_pool(fed_dataset=flex_dataset, init_func=build_server_model)\n", + "\n", + "clients = flex_pool.clients\n", + "servers = flex_pool.servers\n", + "aggregators = flex_pool.aggregators\n", + "\n", + "print(f\"Number of nodes in the pool {len(flex_pool)}: {len(servers)} server plus {len(clients)} clients. The server is also an aggregator\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can use the decorator `deploy_server_model` to create a custom function that deploys our server model, or we can use the primitive `deploy_server_model_pt` to deploy the server model to the clients." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flex.pool import deploy_server_model, deploy_server_model_pt\n", + "\n", + "@deploy_server_model\n", + "def copy_server_model_to_clients(server_flex_model: FlexModel):\n", + " return deepcopy(server_flex_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "servers.map(copy_server_model_to_clients, clients) # Using the function created with the decorator\n", + "# servers.map(deploy_server_model_pt, clients) # Using the primitive function" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We need to prepare the data for the model. In this case, we have to create a DataLoader with de ´InputExample´ format from **SentenceTransformers**. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_input_examples_for_training(X_data_as_list, X_test_as_list):\n", + " \"\"\"Function to create a DataLoader to train/finetune the model at client level\n", + "\n", + " Args:\n", + " X_data_as_list (list): List containing the examples. Each example is a dict\n", + " with the following keys: query, pos, neg.\n", + " \"\"\"\n", + " train_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_data_as_list]\n", + " dev_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_test_as_list]\n", + " return DataLoader(train_examples, shuffle=True, batch_size=16), dev_examples\n", + "\n", + "def train(client_flex_model: FlexModel, client_data):\n", + " print(\"Training client\")\n", + " model = client_flex_model['model']\n", + " sentences = ['This is an example sentence', 'Each sentence is converted']\n", + " encodings = model.encode(sentences)\n", + " print(f\"Old encodings: {encodings}\")\n", + " X_data = client_data.X_data.tolist()\n", + " tam_train = int(len(X_data) * 0.75)\n", + " X_data, X_test = X_data[:tam_train], X_data[tam_train:]\n", + " train_dataloader, dev_examples = create_input_examples_for_training(X_data_as_list=X_data, X_test_as_list=X_test)\n", + " train_loss = losses.TripletLoss(model=model)\n", + " evaluator = TripletEvaluator.from_input_examples(dev_examples)\n", + " warmup_steps = int(len(train_dataloader) * 1 * 0.1) #10% of train data\n", + " model.fit(train_objectives=[(train_dataloader, train_loss)],\n", + " epochs=1,\n", + " warmup_steps=warmup_steps,\n", + " evaluator=evaluator,\n", + " evaluation_steps=1000,\n", + " )\n", + " # model.evaluate(evaluator, 'model_evaluation')\n", + " sentences = ['This is an example sentence', 'Each sentence is converted']\n", + " encodings = model.encode(sentences)\n", + " print(f\"New encodings: {encodings}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "clients.map(train)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After training the model, we have to aggregate the weights from the clients model in order to update the global model. To to so, we are going to use the primitive `collect_clients_weights_pt`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flex.pool import collect_clients_weights_pt\n", + "\n", + "aggregators.map(collect_clients_weights_pt, clients)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once the weights are aggregated, we aggregate them. In this notebook we use the FedAvg method that is already implemented in FLEXible." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flex.pool import fed_avg\n", + "\n", + "aggregators.map(fed_avg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The function `set_aggregated_weights_pt` sed the aggregated weights to the server model to update it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flex.pool import set_aggregated_weights_pt\n", + "\n", + "aggregators.map(set_aggregated_weights_pt, servers)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now it's turn to evaluate the global model. To do so, we have to create a function using the decoratod `evaluate_server_model`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from flex.pool import evaluate_server_model\n", + "from flex.data import Dataset\n", + "\n", + "def create_input_examples_for_testing(X_test_as_list):\n", + " \"\"\"Function to create a DataLoader to train/finetune the model at client level\n", + "\n", + " Args:\n", + " X_test_as_list (list): List containing the examples. Each example is a dict\n", + " with the following keys: query, pos, neg.\n", + " \"\"\"\n", + " return [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_test_as_list]\n", + "\n", + "test_dataset = Dataset.from_huggingface_dataset(test_dataset, X_columns=['set'])\n", + "\n", + "@evaluate_server_model\n", + "def evaluate_global_model(server_flex_model: FlexModel, test_data=None):\n", + " X_test = create_input_examples_for_testing(X_test_as_list=test_dataset.X_data.tolist())\n", + " model = server_flex_model[\"model\"]\n", + " evaluator = TripletEvaluator.from_input_examples(X_test)\n", + " model.evaluate(evaluator, 'server_evaluation')\n", + " print(\"Model evaluation saved to file.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "servers.map(evaluate_global_model, test_data=test_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run the federated learning experiment for a few rounds\n", + "\n", + "Now, we can summarize the steps provided above and run the federated experiment for multiple rounds:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def train_n_rounds(n_rounds, clients_per_round=2): \n", + " pool = FlexPool.client_server_pool(fed_dataset=flex_dataset, init_func=build_server_model)\n", + " for i in range(n_rounds):\n", + " print(f\"\\nRunning round: {i+1} of {n_rounds}\")\n", + " selected_clients_pool = pool.clients.select(clients_per_round)\n", + " selected_clients = selected_clients_pool.clients\n", + " print(f\"Selected clients for this round: {len(selected_clients)}\")\n", + " # Deploy the server model to the selected clients\n", + " pool.servers.map(deploy_server_model_pt, selected_clients)\n", + " # Each selected client trains her model\n", + " selected_clients.map(train)\n", + " # The aggregador collects weights from the selected clients and aggregates them\n", + " pool.aggregators.map(collect_clients_weights_pt, selected_clients)\n", + " pool.aggregators.map(fed_avg)\n", + " # The aggregator send its aggregated weights to the server\n", + " pool.aggregators.map(set_aggregated_weights_pt, pool.servers)\n", + " servers.map(evaluate_global_model, test_data=test_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Train the model for n_rounds\n", + "train_n_rounds(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# End\n", + "\n", + "Congratulations, you've just trained a **SentenceTransformers** model using FLEXible!" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "flexible", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}