Added adapter for the sentence_transformer package
AlArgente committed Jan 24, 2024
1 parent b0cb028 commit 03d42c0
Showing 5 changed files with 60 additions and 7 deletions.
@@ -198,7 +198,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to prepare the data for the model. In this case, we have to create a DataLoader with de ´InputExample´ format from **SentenceTransformers**. "
"We need to prepare the data for the model. In this case, we have to create a DataLoader with de ´InputExample´ format from **SentenceTransformers**. We have commented the evaluator of the model in the clients, but we keep it on the server side. "
]
},
{
@@ -207,12 +207,21 @@
"metadata": {},
"outputs": [],
"source": [
"def create_input_examples_for_training(X_data_as_list, X_test_as_list):\n",
" \"\"\"Function to create a DataLoader to train/finetune the model at client level\n",
"from flexnlp.utils.adapters import ss_triplet_input_adapter\n",
"\n",
"def create_input_examples_for_training(X_data_as_list, X_test_as_list):\n",
" \"\"\"Function to create a DataLoader to train/finetune the model at client level.\n",
" This function also create a dev_example \n",
" Args:\n",
" X_data_as_list (list): List containing the examples. Each example is a dict\n",
" with the following keys: query, pos, neg.\n",
" X_test_as_list (list): List containing the test/validation examples. Each\n",
" example is a dict with the following keys: query, pos, neg.\n",
" Returns:\n",
" train_dataloader (DataLoader): A DataLoader for training the model, batched and shuffled.\n",
" dev_examples (InputExamples): An InputExample object from SentenceTransformers so it can be\n",
" used on the a TripleEvaluator.\n",
"\n",
" \"\"\"\n",
" train_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_data_as_list]\n",
" dev_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_test_as_list]\n",
@@ -228,8 +237,9 @@
" tam_train = int(len(X_data) * 0.75)\n",
" X_data, X_test = X_data[:tam_train], X_data[tam_train:]\n",
" train_dataloader, dev_examples = create_input_examples_for_training(X_data_as_list=X_data, X_test_as_list=X_test)\n",
" train_dataloader, _ = ss_triplet_input_adapter(X_data, X_test)\n",
" train_loss = losses.TripletLoss(model=model)\n",
" evaluator = TripletEvaluator.from_input_examples(dev_examples)\n",
" evaluator = TripletEvaluator.from_input_examples(dev_examples)\n",
" warmup_steps = int(len(train_dataloader) * 1 * 0.1) #10% of train data\n",
" model.fit(train_objectives=[(train_dataloader, train_loss)],\n",
" epochs=1,\n",
@@ -323,11 +333,14 @@
"from flex.data import Dataset\n",
"\n",
"def create_input_examples_for_testing(X_test_as_list):\n",
" \"\"\"Function to create a DataLoader to train/finetune the model at client level\n",
" \"\"\"Function to create an InputExample to evaluate the model at server level\n",
"\n",
" Args:\n",
" X_test_as_list (list): List containing the examples. Each example is a dict\n",
" with the following keys: query, pos, neg.\n",
" Returns:\n",
" test_examples (List[InputExample]): A list with the input formatted to be used\n",
" for the TripletEvaluator.\n",
" \"\"\"\n",
" return [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_test_as_list]\n",
"\n",
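As an aside, a minimal sketch of the triplet format the notebook works with and how it is mapped to SentenceTransformers `InputExample` objects (the `query`/`pos`/`neg` keys come from the docstrings above; the concrete strings below are made up for illustration):

```python
from sentence_transformers import InputExample
from torch.utils.data import DataLoader

# Illustrative triplet records in the query/pos/neg dict format described above.
triplets = [
    {"query": "capital of France",
     "pos": ["Paris is the capital of France."],
     "neg": ["Berlin is the capital of Germany."]},
    {"query": "largest ocean",
     "pos": ["The Pacific Ocean is the largest ocean."],
     "neg": ["The Sahara is a desert in Africa."]},
]

# Same mapping used in create_input_examples_for_training: one InputExample per triplet,
# then a shuffled DataLoader that can be passed to SentenceTransformer.fit.
examples = [InputExample(texts=[t["query"], t["pos"][0], t["neg"][0]]) for t in triplets]
train_dataloader = DataLoader(examples, shuffle=True, batch_size=16)
```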
5 changes: 3 additions & 2 deletions flexnlp/notebooks/FederatedSS.py
Expand Up @@ -22,6 +22,7 @@
from flex.pool import fed_avg
from flex.pool import set_aggregated_weights_pt
from flex.pool import evaluate_server_model
from flexnlp.utils.adapters import ss_triplet_input_adapter

device = (
"cuda"
@@ -116,7 +117,7 @@ def train(client_flex_model: FlexModel, client_data: Dataset):
X_data = client_data.X_data.tolist()
tam_train = int(len(X_data) * 0.75)
X_data, X_test = X_data[:tam_train], X_data[tam_train:]
train_dataloader, dev_examples = create_input_examples_for_training(X_data_as_list=X_data, X_test_as_list=X_test)
train_dataloader, dev_examples = ss_triplet_input_adapter(X_data_as_list=X_data, X_test_as_list=X_test)
train_loss = losses.TripletLoss(model=model)
evaluator = TripletEvaluator.from_input_examples(dev_examples)
warmup_steps = int(len(train_dataloader) * 1 * 0.1) #10% of train data
@@ -152,7 +153,7 @@ def create_input_examples_for_testing(X_test_as_list):

@evaluate_server_model
def evaluate_global_model(server_flex_model: FlexModel, test_data=None):
X_test = create_input_examples_for_testing(X_test_as_list=test_dataset.X_data.tolist())
_, X_test = ss_triplet_input_adapter(X_test_as_list=test_dataset.X_data.tolist(), train=False)
model = server_flex_model["model"]
evaluator = TripletEvaluator.from_input_examples(X_test)
model.evaluate(evaluator, 'server_evaluation')
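For context, a small sketch of the server-side evaluation path that `evaluate_global_model` now follows, assuming a plain SentenceTransformer checkpoint and the same triplet dict format as above (the model name and data here are illustrative, not part of the commit):

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import TripletEvaluator

from flexnlp.utils.adapters import ss_triplet_input_adapter

# Illustrative test triplets; in the script these come from test_dataset.X_data.tolist().
test_triplets = [
    {"query": "capital of France",
     "pos": ["Paris is the capital of France."],
     "neg": ["Berlin is the capital of Germany."]},
    {"query": "largest ocean",
     "pos": ["The Pacific Ocean is the largest ocean."],
     "neg": ["The Sahara is a desert in Africa."]},
]

# The adapter returns (train_dataloader, dev_examples); with train=False only the
# dev_examples are built.
_, dev_examples = ss_triplet_input_adapter(X_test_as_list=test_triplets, train=False)

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative checkpoint
evaluator = TripletEvaluator.from_input_examples(dev_examples, name="server_evaluation")
# Reports how often the query embedding is closer to the positive than to the negative.
model.evaluate(evaluator)
```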
1 change: 1 addition & 0 deletions flexnlp/utils/__init__.py
@@ -6,3 +6,4 @@
# from flexnlp.utils.collators import classification_sampler
# from flexnlp.utils.collators import ClassificationCollator
from flexnlp.utils import collators
from flexnlp.utils import adapters
5 changes: 5 additions & 0 deletions flexnlp/utils/adapters/__init__.py
@@ -0,0 +1,5 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from flexnlp.utils.adapters.ss_adapters import ss_triplet_input_adapter
33 changes: 33 additions & 0 deletions flexnlp/utils/adapters/ss_adapters.py
@@ -0,0 +1,33 @@
def ss_triplet_input_adapter(X_train_as_list: list = None, X_test_as_list: list = None,
                             batch_size=16, shuffle=True, train=True, test=True):
    """Function that adapts the input from a Triplet Dataset for use with a
    SentenceTransformer model.
    The function checks that data is provided before producing an output.
    Args:
        X_train_as_list (list, optional): Training examples; each example is a dict
            with the keys query, pos, neg. Defaults to None.
        X_test_as_list (list, optional): Test/validation examples in the same
            query/pos/neg format. Defaults to None.
        batch_size (int, optional): Batch size for the training DataLoader. Defaults to 16.
        shuffle (bool, optional): Whether to shuffle the training DataLoader. Defaults to True.
        train (bool, optional): Whether to build the training DataLoader. Defaults to True.
        test (bool, optional): Whether to build the dev InputExamples. Defaults to True.
    Returns:
        tuple: (train_dataloader, dev_examples), where train_dataloader is a DataLoader of
            InputExample triplets (or None) and dev_examples is a list of InputExample
            triplets (or None).
    """
    if not X_train_as_list and not X_test_as_list:
        raise ValueError("No data given. Please provide data for train or test.")
    if not train and not test:
        raise ValueError("train or test parameters must be True in order to give an output.")

    # Imported here so sentence_transformers and torch are only required when the adapter is used.
    from sentence_transformers import InputExample
    from torch.utils.data import DataLoader

    train_dataloader = None
    dev_examples = None
    if train and X_train_as_list and len(X_train_as_list) > 1:
        train_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_train_as_list]
        train_dataloader = DataLoader(train_examples, shuffle=shuffle, batch_size=batch_size)
    if test and X_test_as_list and len(X_test_as_list) > 1:
        dev_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_test_as_list]

    return train_dataloader, dev_examples
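
A usage sketch of the adapter on the client side, assuming the same triplet dict format and an illustrative SentenceTransformer checkpoint (none of the names or data below are part of the commit):

```python
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.evaluation import TripletEvaluator

from flexnlp.utils.adapters import ss_triplet_input_adapter

# Illustrative client data: query/pos/neg dicts split into train/dev triplets.
train_triplets = [
    {"query": "capital of France",
     "pos": ["Paris is the capital of France."],
     "neg": ["Berlin is the capital of Germany."]},
    {"query": "largest ocean",
     "pos": ["The Pacific Ocean is the largest ocean."],
     "neg": ["The Sahara is a desert in Africa."]},
]
dev_triplets = train_triplets  # placeholder; a real split would use held-out data

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative checkpoint
train_dataloader, dev_examples = ss_triplet_input_adapter(
    X_train_as_list=train_triplets, X_test_as_list=dev_triplets, batch_size=16)

train_loss = losses.TripletLoss(model=model)
evaluator = TripletEvaluator.from_input_examples(dev_examples)
warmup_steps = int(len(train_dataloader) * 0.1)  # 10% of the training batches

model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=evaluator,
          epochs=1,
          warmup_steps=warmup_steps)
```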
