Skip to content

Commit

Permalink
Added the train/test split of the data, and the evaluation function f…
Browse files Browse the repository at this point in the history
…or the model at the server side. TODO: Create a notebook with the code done and the task is done.
  • Loading branch information
cristinazuhe committed Jan 22, 2024
1 parent 6bc637d commit 1672f4f
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions flexnlp/notebooks/FederatedSS.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
dataset_id = "embedding-data/QQP_triplets"
# dataset_id = "embedding-data/sentence-compression"

dataset = load_dataset(dataset_id, split=['train[:1%]'])[0]

data = load_dataset(dataset_id, split=['train[:1%]'])[0].train_test_split(test_size=0.1)
dataset, test_dataset = data['train'], data['test']
print(f"- The {dataset_id} dataset has {dataset.num_rows} examples.")
print(f"- Each example is a {type(dataset[0])} with a {type(dataset[0]['set'])} as value.")
print(f"- Examples look like this: {dataset[0]}")
Expand Down Expand Up @@ -110,7 +110,9 @@ def create_input_examples_for_training(X_data_as_list, X_test_as_list):
def train(client_flex_model: FlexModel, client_data: Dataset):
print("Training client")
model = client_flex_model['model']
# client_train_dataset = client_data.to_numpy()
sentences = ['This is an example sentence', 'Each sentence is converted']
encodings = model.encode(sentences)
print(f"Old encodings: {encodings}")
X_data = client_data.X_data.tolist()
tam_train = int(len(X_data) * 0.75)
X_data, X_test = X_data[:tam_train], X_data[tam_train:]
Expand All @@ -124,7 +126,7 @@ def train(client_flex_model: FlexModel, client_data: Dataset):
evaluator=evaluator,
evaluation_steps=1000,
)
model.evaluate(evaluator, 'model_evaluation')
# model.evaluate(evaluator, 'model_evaluation')
sentences = ['This is an example sentence', 'Each sentence is converted']
encodings = model.encode(sentences)
print(f"New encodings: {encodings}")
Expand All @@ -137,6 +139,27 @@ def train(client_flex_model: FlexModel, client_data: Dataset):

aggregators.map(set_aggregated_weights_pt, servers)

def create_input_examples_for_testing(X_test_as_list):
"""Function to create a DataLoader to train/finetune the model at client level
Args:
X_test_as_list (list): List containing the examples. Each example is a dict
with the following keys: query, pos, neg.
"""
return [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_test_as_list]

test_dataset = Dataset.from_huggingface_dataset(test_dataset, X_columns=['set'])

@evaluate_server_model
def evaluate_global_model(server_flex_model: FlexModel, test_data=None):
X_test = create_input_examples_for_testing(X_test_as_list=test_dataset.X_data.tolist())
model = server_flex_model["model"]
evaluator = TripletEvaluator.from_input_examples(X_test)
model.evaluate(evaluator, 'server_evaluation')
print("Model evaluation saved to file.")

servers.map(evaluate_global_model, test_data=test_dataset)

def train_n_rounds(n_rounds, clients_per_round=2):
pool = FlexPool.client_server_pool(fed_dataset=flex_dataset, init_func=build_server_model)
for i in range(n_rounds):
Expand All @@ -153,6 +176,7 @@ def train_n_rounds(n_rounds, clients_per_round=2):
pool.aggregators.map(fed_avg)
# The aggregator send its aggregated weights to the server
pool.aggregators.map(set_aggregated_weights_pt, pool.servers)
servers.map(evaluate_global_model, test_data=test_dataset)

# Train the model for n_rounds
# train_n_rounds(5)
Expand Down

0 comments on commit 1672f4f

Please sign in to comment.