diff --git a/flexnlp/notebooks/FederatedSS.py b/flexnlp/notebooks/FederatedSS.py index f595657..c3ee55a 100644 --- a/flexnlp/notebooks/FederatedSS.py +++ b/flexnlp/notebooks/FederatedSS.py @@ -96,7 +96,7 @@ def copy_server_model_to_clients(server_flex_model: FlexModel): # Prepare data for training phase -def create_input_examples_for_training(X_data_as_list): +def create_input_examples_for_training(X_data_as_list, X_test_as_list): """Function to create a DataLoader to train/finetune the model at client level Args: @@ -104,13 +104,17 @@ def create_input_examples_for_training(X_data_as_list): with the following keys: query, pos, neg. """ train_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_data_as_list] - return DataLoader(train_examples, shuffle=True, batch_size=16), train_examples + dev_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_test_as_list] + return DataLoader(train_examples, shuffle=True, batch_size=16), dev_examples def train(client_flex_model: FlexModel, client_data: Dataset): print("Training client") model = client_flex_model['model'] # client_train_dataset = client_data.to_numpy() - train_dataloader, dev_examples = create_input_examples_for_training(X_data_as_list=client_data.X_data.tolist()) + X_data = client_data.X_data.tolist() + tam_train = int(len(X_data) * 0.75) + X_data, X_test = X_data[:tam_train], X_data[tam_train:] + train_dataloader, dev_examples = create_input_examples_for_training(X_data_as_list=X_data, X_test_as_list=X_test) train_loss = losses.TripletLoss(model=model) evaluator = TripletEvaluator.from_input_examples(dev_examples) warmup_steps = int(len(train_dataloader) * 1 * 0.1) #10% of train data @@ -120,6 +124,7 @@ def train(client_flex_model: FlexModel, client_data: Dataset): evaluator=evaluator, evaluation_steps=1000, ) + model.evaluate(evaluator, 'model_evaluation') sentences = ['This is an example sentence', 'Each sentence is converted'] encodings = model.encode(sentences) print(f"New encodings: {encodings}")