Skip to content

Commit

Permalink
Fixed some bugs in the adapter and in the notebook.
Browse files Browse the repository at this point in the history
  • Loading branch information
cristinazuhe committed Jan 24, 2024
1 parent 03d42c0 commit b39504a
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -209,24 +209,6 @@
"source": [
"from flexnlp.utils.adapters import ss_triplet_input_adapter\n",
"\n",
"def create_input_examples_for_training(X_data_as_list, X_test_as_list):\n",
" \"\"\"Function to create a DataLoader to train/finetune the model at client level.\n",
" This function also creates the dev_examples list.\n",
" Args:\n",
" X_data_as_list (list): List containing the examples. Each example is a dict\n",
" with the following keys: query, pos, neg.\n",
" X_test_as_list (list): List containing the test/validation examples. Each\n",
" example is a dict with the following keys: query, pos, neg.\n",
" Returns:\n",
" train_dataloader (DataLoader): A DataLoader for training the model, batched and shuffled.\n",
" dev_examples (list[InputExample]): A list of InputExample objects from SentenceTransformers so it can be\n",
" used with a TripletEvaluator.\n",
"\n",
" \"\"\"\n",
" train_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_data_as_list]\n",
" dev_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_test_as_list]\n",
" return DataLoader(train_examples, shuffle=True, batch_size=16), dev_examples\n",
"\n",
"def train(client_flex_model: FlexModel, client_data):\n",
" print(\"Training client\")\n",
" model = client_flex_model['model']\n",
Expand All @@ -236,15 +218,14 @@
" X_data = client_data.X_data.tolist()\n",
" tam_train = int(len(X_data) * 0.75)\n",
" X_data, X_test = X_data[:tam_train], X_data[tam_train:]\n",
" train_dataloader, dev_examples = create_input_examples_for_training(X_data_as_list=X_data, X_test_as_list=X_test)\n",
" train_dataloader, _ = ss_triplet_input_adapter(X_data, X_test)\n",
" train_loss = losses.TripletLoss(model=model)\n",
" # evaluator = TripletEvaluator.from_input_examples(dev_examples)\n",
" warmup_steps = int(len(train_dataloader) * 1 * 0.1) #10% of train data\n",
" model.fit(train_objectives=[(train_dataloader, train_loss)],\n",
" epochs=1,\n",
" warmup_steps=warmup_steps,\n",
" evaluator=evaluator,\n",
" # evaluator=evaluator,\n",
" evaluation_steps=1000,\n",
" )\n",
" # model.evaluate(evaluator, 'model_evaluation')\n",
Expand Down Expand Up @@ -332,23 +313,11 @@
"from flex.pool import evaluate_server_model\n",
"from flex.data import Dataset\n",
"\n",
"def create_input_examples_for_testing(X_test_as_list):\n",
" \"\"\"Function to create an InputExample to evaluate the model at server level\n",
"\n",
" Args:\n",
" X_test_as_list (list): List containing the examples. Each example is a dict\n",
" with the following keys: query, pos, neg.\n",
" Returns:\n",
" test_examples (List[InputExample]): A list with the input formatted to be used\n",
" for the TripletEvaluator.\n",
" \"\"\"\n",
" return [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_test_as_list]\n",
"\n",
"test_dataset = Dataset.from_huggingface_dataset(test_dataset, X_columns=['set'])\n",
"\n",
"@evaluate_server_model\n",
"def evaluate_global_model(server_flex_model: FlexModel, test_data=None):\n",
" X_test = create_input_examples_for_testing(X_test_as_list=test_dataset.X_data.tolist())\n",
" _, X_test = ss_triplet_input_adapter(X_test_as_list=test_dataset.X_data.tolist(), train=False)\n",
" model = server_flex_model[\"model\"]\n",
" evaluator = TripletEvaluator.from_input_examples(X_test)\n",
" model.evaluate(evaluator, 'server_evaluation')\n",
Expand Down Expand Up @@ -434,7 +403,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.11.3"
},
"orig_nbformat": 4
},
Expand Down
4 changes: 2 additions & 2 deletions flexnlp/notebooks/FederatedSS.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def train(client_flex_model: FlexModel, client_data: Dataset):
X_data = client_data.X_data.tolist()
tam_train = int(len(X_data) * 0.75)
X_data, X_test = X_data[:tam_train], X_data[tam_train:]
train_dataloader, dev_examples = ss_triplet_input_adapter(X_data_as_list=X_data, X_test_as_list=X_test)
train_dataloader, dev_examples = ss_triplet_input_adapter(X_train_as_list=X_data, X_test_as_list=X_test)
train_loss = losses.TripletLoss(model=model)
evaluator = TripletEvaluator.from_input_examples(dev_examples)
warmup_steps = int(len(train_dataloader) * 1 * 0.1) #10% of train data
Expand Down Expand Up @@ -153,7 +153,7 @@ def create_input_examples_for_testing(X_test_as_list):

@evaluate_server_model
def evaluate_global_model(server_flex_model: FlexModel, test_data=None):
X_test = ss_triplet_input_adapter(X_test_as_list=test_dataset.X_data.tolist(), train=False)
_, X_test = ss_triplet_input_adapter(X_test_as_list=test_dataset.X_data.tolist(), train=False)
model = server_flex_model["model"]
evaluator = TripletEvaluator.from_input_examples(X_test)
model.evaluate(evaluator, 'server_evaluation')
Expand Down
2 changes: 1 addition & 1 deletion flexnlp/utils/adapters/ss_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def ss_triplet_input_adapter(X_train_as_list: list = None, X_test_as_list: list
train_examples = None
dev_examples = None
if train and len(X_train_as_list) > 1:
train_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_data_as_list]
train_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_train_as_list]
train_examples = DataLoader(train_examples, shuffle=shuffle, batch_size=batch_size)
if test and len(X_test_as_list) > 1:
dev_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_test_as_list]
Expand Down

0 comments on commit b39504a

Please sign in to comment.