diff --git a/flexnlp/notebooks/FederatedSS.py b/flexnlp/notebooks/FederatedSS.py
index c3ee55a..faf9d91 100644
--- a/flexnlp/notebooks/FederatedSS.py
+++ b/flexnlp/notebooks/FederatedSS.py
@@ -38,8 +38,8 @@
 dataset_id = "embedding-data/QQP_triplets"
 # dataset_id = "embedding-data/sentence-compression"
 
-dataset = load_dataset(dataset_id, split=['train[:1%]'])[0]
-
+data = load_dataset(dataset_id, split=['train[:1%]'])[0].train_test_split(test_size=0.1)
+dataset, test_dataset = data['train'], data['test']
 print(f"- The {dataset_id} dataset has {dataset.num_rows} examples.")
 print(f"- Each example is a {type(dataset[0])} with a {type(dataset[0]['set'])} as value.")
 print(f"- Examples look like this: {dataset[0]}")
@@ -110,7 +110,9 @@ def create_input_examples_for_training(X_data_as_list, X_test_as_list):
 def train(client_flex_model: FlexModel, client_data: Dataset):
     print("Training client")
     model = client_flex_model['model']
-    # client_train_dataset = client_data.to_numpy()
+    sentences = ['This is an example sentence', 'Each sentence is converted']
+    encodings = model.encode(sentences)
+    print(f"Old encodings: {encodings}")
     X_data = client_data.X_data.tolist()
     tam_train = int(len(X_data) * 0.75)
     X_data, X_test = X_data[:tam_train], X_data[tam_train:]
@@ -124,7 +126,7 @@ def train(client_flex_model: FlexModel, client_data: Dataset):
         evaluator=evaluator,
         evaluation_steps=1000,
     )
-    model.evaluate(evaluator, 'model_evaluation')
+    # model.evaluate(evaluator, 'model_evaluation')
     sentences = ['This is an example sentence', 'Each sentence is converted']
     encodings = model.encode(sentences)
     print(f"New encodings: {encodings}")
@@ -137,6 +139,27 @@ def train(client_flex_model: FlexModel, client_data: Dataset):
 
 aggregators.map(set_aggregated_weights_pt, servers)
 
+def create_input_examples_for_testing(X_test_as_list):
+    """Create the list of InputExample objects used to evaluate the global model at server level.
+
+    Args:
+        X_test_as_list (list): List containing the examples. Each example is a dict
+            with the following keys: query, pos, neg.
+    """
+    return [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_test_as_list]
+
+test_dataset = Dataset.from_huggingface_dataset(test_dataset, X_columns=['set'])
+
+@evaluate_server_model
+def evaluate_global_model(server_flex_model: FlexModel, test_data=None):
+    X_test = create_input_examples_for_testing(X_test_as_list=test_data.X_data.tolist())
+    model = server_flex_model["model"]
+    evaluator = TripletEvaluator.from_input_examples(X_test)
+    model.evaluate(evaluator, 'server_evaluation')
+    print("Model evaluation saved to file.")
+
+servers.map(evaluate_global_model, test_data=test_dataset)
+
 def train_n_rounds(n_rounds, clients_per_round=2):
     pool = FlexPool.client_server_pool(fed_dataset=flex_dataset, init_func=build_server_model)
     for i in range(n_rounds):
@@ -153,6 +176,7 @@ def train_n_rounds(n_rounds, clients_per_round=2):
         pool.aggregators.map(fed_avg)
         # The aggregator send its aggregated weights to the server
         pool.aggregators.map(set_aggregated_weights_pt, pool.servers)
+        pool.servers.map(evaluate_global_model, test_data=test_dataset)
 
 # Train the model for n_rounds
 # train_n_rounds(5)