From 03d42c01241bc78a02fff5a46d603d194ef09e62 Mon Sep 17 00:00:00 2001 From: AlArgente Date: Wed, 24 Jan 2024 08:18:36 +0100 Subject: [PATCH] Added adapter for the sentence_transformer package --- ... SentenceTransformers using FLEXible.ipynb | 23 ++++++++++--- flexnlp/notebooks/FederatedSS.py | 5 +-- flexnlp/utils/__init__.py | 1 + flexnlp/utils/adapters/__init__.py | 5 +++ flexnlp/utils/adapters/ss_adapters.py | 33 +++++++++++++++++++ 5 files changed, 60 insertions(+), 7 deletions(-) create mode 100644 flexnlp/utils/adapters/__init__.py create mode 100644 flexnlp/utils/adapters/ss_adapters.py diff --git a/flexnlp/notebooks/Federated SS with SentenceTransformers using FLEXible.ipynb b/flexnlp/notebooks/Federated SS with SentenceTransformers using FLEXible.ipynb index 0908342..63f3e6b 100644 --- a/flexnlp/notebooks/Federated SS with SentenceTransformers using FLEXible.ipynb +++ b/flexnlp/notebooks/Federated SS with SentenceTransformers using FLEXible.ipynb @@ -198,7 +198,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We need to prepare the data for the model. In this case, we have to create a DataLoader with de ´InputExample´ format from **SentenceTransformers**. " + "We need to prepare the data for the model. In this case, we have to create a DataLoader with de ´InputExample´ format from **SentenceTransformers**. We have commented the evaluator of the model in the clients, but we keep it on the server side. " ] }, { @@ -207,12 +207,21 @@ "metadata": {}, "outputs": [], "source": [ - "def create_input_examples_for_training(X_data_as_list, X_test_as_list):\n", - " \"\"\"Function to create a DataLoader to train/finetune the model at client level\n", + "from flexnlp.utils.adapters import ss_triplet_input_adapter\n", "\n", + "def create_input_examples_for_training(X_data_as_list, X_test_as_list):\n", + " \"\"\"Function to create a DataLoader to train/finetune the model at client level.\n", + " This function also create a dev_example \n", " Args:\n", " X_data_as_list (list): List containing the examples. Each example is a dict\n", " with the following keys: query, pos, neg.\n", + " X_test_as_list (list): List containing the test/validation examples. Each\n", + " example is a dict with the following keys: query, pos, neg.\n", + " Returns:\n", + " train_dataloader (DataLoader): A DataLoader for training the model, batched and shuffled.\n", + " dev_examples (InputExamples): An InputExample object from SentenceTransformers so it can be\n", + " used on the a TripleEvaluator.\n", + "\n", " \"\"\"\n", " train_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_data_as_list]\n", " dev_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_test_as_list]\n", @@ -228,8 +237,9 @@ " tam_train = int(len(X_data) * 0.75)\n", " X_data, X_test = X_data[:tam_train], X_data[tam_train:]\n", " train_dataloader, dev_examples = create_input_examples_for_training(X_data_as_list=X_data, X_test_as_list=X_test)\n", + " train_dataloader, _ = ss_triplet_input_adapter(X_data, X_test)\n", " train_loss = losses.TripletLoss(model=model)\n", - " evaluator = TripletEvaluator.from_input_examples(dev_examples)\n", + " # evaluator = TripletEvaluator.from_input_examples(dev_examples)\n", " warmup_steps = int(len(train_dataloader) * 1 * 0.1) #10% of train data\n", " model.fit(train_objectives=[(train_dataloader, train_loss)],\n", " epochs=1,\n", @@ -323,11 +333,14 @@ "from flex.data import Dataset\n", "\n", "def create_input_examples_for_testing(X_test_as_list):\n", - " \"\"\"Function to create a DataLoader to train/finetune the model at client level\n", + " \"\"\"Function to create an InputExample to evaluate the model at server level\n", "\n", " Args:\n", " X_test_as_list (list): List containing the examples. Each example is a dict\n", " with the following keys: query, pos, neg.\n", + " Returns:\n", + " test_examples (List[InputExample]): A list with the input formatted to be used\n", + " for the TripletEvaluator.\n", " \"\"\"\n", " return [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_test_as_list]\n", "\n", diff --git a/flexnlp/notebooks/FederatedSS.py b/flexnlp/notebooks/FederatedSS.py index faf9d91..0547c7a 100644 --- a/flexnlp/notebooks/FederatedSS.py +++ b/flexnlp/notebooks/FederatedSS.py @@ -22,6 +22,7 @@ from flex.pool import fed_avg from flex.pool import set_aggregated_weights_pt from flex.pool import evaluate_server_model +from flexnlp.utils.adapters import ss_triplet_input_adapter device = ( "cuda" @@ -116,7 +117,7 @@ def train(client_flex_model: FlexModel, client_data: Dataset): X_data = client_data.X_data.tolist() tam_train = int(len(X_data) * 0.75) X_data, X_test = X_data[:tam_train], X_data[tam_train:] - train_dataloader, dev_examples = create_input_examples_for_training(X_data_as_list=X_data, X_test_as_list=X_test) + train_dataloader, dev_examples = ss_triplet_input_adapter(X_data_as_list=X_data, X_test_as_list=X_test) train_loss = losses.TripletLoss(model=model) evaluator = TripletEvaluator.from_input_examples(dev_examples) warmup_steps = int(len(train_dataloader) * 1 * 0.1) #10% of train data @@ -152,7 +153,7 @@ def create_input_examples_for_testing(X_test_as_list): @evaluate_server_model def evaluate_global_model(server_flex_model: FlexModel, test_data=None): - X_test = create_input_examples_for_testing(X_test_as_list=test_dataset.X_data.tolist()) + X_test = ss_triplet_input_adapter(X_test_as_list=test_dataset.X_data.tolist(), train=False) model = server_flex_model["model"] evaluator = TripletEvaluator.from_input_examples(X_test) model.evaluate(evaluator, 'server_evaluation') diff --git a/flexnlp/utils/__init__.py b/flexnlp/utils/__init__.py index 768a6b3..6e859e2 100644 --- a/flexnlp/utils/__init__.py +++ b/flexnlp/utils/__init__.py @@ -6,3 +6,4 @@ # from flexnlp.utils.collators import classification_sampler # from flexnlp.utils.collators import ClassificationCollator from flexnlp.utils import collators +from flexnlp.utils import adapters diff --git a/flexnlp/utils/adapters/__init__.py b/flexnlp/utils/adapters/__init__.py new file mode 100644 index 0000000..78cda9c --- /dev/null +++ b/flexnlp/utils/adapters/__init__.py @@ -0,0 +1,5 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from flexnlp.utils.adapters.ss_adapters import ss_triplet_input_adapter diff --git a/flexnlp/utils/adapters/ss_adapters.py b/flexnlp/utils/adapters/ss_adapters.py new file mode 100644 index 0000000..218aefd --- /dev/null +++ b/flexnlp/utils/adapters/ss_adapters.py @@ -0,0 +1,33 @@ +def ss_triplet_input_adapter(X_train_as_list: list = None, X_test_as_list: list = None, + batch_size=16, shuffle=True, train=True, test=True): + """Function that adapt the input from a Triplet Dataset to use within a + SentenceTransformer's model. + + The method ensures that the data is provived in order to give an output. + Args: + X_train_as_list (list, optional): _description_. Defaults to None. + X_test_as_list (list, optional): _description_. Defaults to None. + batch_size (int, optional): _description_. Defaults to 16. + shuffle (bool, optional): _description_. Defaults to True. + train (bool, optional): _description_. Defaults to True. + test (bool, optional): _description_. Defaults to True. + Returns: + tuple: + """ + if not X_train_as_list and not X_test_as_list: + raise ValueError("No data given. Please provide data for train or test.") + if not train and not test: + raise ValueError("train or test parameters must be true in order to give an output.") + + from sentence_transformers import InputExample + from torch.utils.data import DataLoader + + train_examples = None + dev_examples = None + if train and len(X_train_as_list) > 1: + train_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_data_as_list] + train_examples = DataLoader(train_examples, shuffle=shuffle, batch_size=batch_size) + if test and len(X_test_as_list) > 1: + dev_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_test_as_list] + + return train_examples, dev_examples \ No newline at end of file