From aa9407c0555d47f355f8660c32f370886b94df37 Mon Sep 17 00:00:00 2001
From: cristinazuhe
Date: Thu, 25 Jan 2024 09:31:10 +0100
Subject: [PATCH] Notebook ready. Collators and adapters ready too. It would
 be good to add two aggregators commonly used in NLP problems.

---
 ...ed IMDb PT using FLExible with a GRU.ipynb | 49 ++++---------
 1 file changed, 9 insertions(+), 40 deletions(-)

diff --git a/notebooks/Federated IMDb PT using FLExible with a GRU.ipynb b/notebooks/Federated IMDb PT using FLExible with a GRU.ipynb
index b2a550a..3a09787 100644
--- a/notebooks/Federated IMDb PT using FLExible with a GRU.ipynb
+++ b/notebooks/Federated IMDb PT using FLExible with a GRU.ipynb
@@ -431,36 +431,10 @@
     "\n",
     "    return string.strip().lower()\n",
     "\n",
-    "def collate_batch(batch):\n",
-    "    def preprocess_text(text):\n",
-    "        text_transform = lambda x: [vocabulary[\"<bos>\"]]+[vocabulary[token] for token in spacy_tokenizer(x)]+[vocabulary[\"<eos>\"]]\n",
-    "        return list(text_transform(clean_str(text)))\n",
-    "    label_list, text_list = [], []\n",
-    "    for (_text, _label) in batch:\n",
-    "        label_transform = lambda x: int(x) - 1\n",
-    "        label_list.append(label_transform(_label))\n",
-    "        processed_text = torch.tensor(preprocess_text(_text))\n",
-    "        text_list.append(processed_text)\n",
-    "    label_list = torch.tensor(label_list, dtype=torch.int64)\n",
-    "    return pad_sequence(text_list, padding_value=pad_index, batch_first=True), label_list\n",
-    "\n",
     "def preprocess_text(text):\n",
     "    text_transform = lambda x: [vocabulary[\"<bos>\"]]+[vocabulary[token] for token in spacy_tokenizer(x)]+[vocabulary[\"<eos>\"]]\n",
     "    return list(text_transform(clean_str(text)))\n",
     "\n",
-    "def batch_sampler_v2(batch_size, indices):\n",
-    "    random.shuffle(indices)\n",
-    "    pooled_indices = []\n",
-    "    # create pool of indices with similar lengths \n",
-    "    for i in range(0, len(indices), batch_size * 100):\n",
-    "        pooled_indices.extend(sorted(indices[i:i + batch_size * 100], key=lambda x: x[1]))\n",
-    "\n",
-    "    pooled_indices = [x[0] for x in pooled_indices]\n",
-    "\n",
-    "    # yield indices for current batch\n",
-    "    for i in range(0, len(pooled_indices), batch_size):\n",
-    "        yield pooled_indices[i:i + batch_size]\n",
-    "\n",
     "def train(client_flex_model: FlexModel, client_data: Dataset):\n",
     "    X_data, y_data = client_data.to_list()\n",
     "    if 'train_indices' not in client_flex_model:\n",
@@ -476,28 +450,20 @@
     "    label_list = [label_transform(_label) for _label in y_data]\n",
     "\n",
     "    client_data = client_data.from_array(X_data, label_list)\n",
-    "\n",
-    "    # batch_size=BATCH_SIZE, shuffle=True, # No es necesario usarlo porque usamos el batch_sampler\n",
     "    client_dataloader = DataLoader(client_data, collate_fn=basic_collate_pad_sequence_classification, batch_size=BATCH_SIZE,\n",
     "                                   shuffle=True)\n",
-    "    #                              batch_sampler=batch_sampler_v2(BATCH_SIZE, train_indices))\n",
     "    model = client_flex_model[\"model\"]\n",
-    "    # lr = 0.001\n",
     "    optimizer = client_flex_model['optimizer_func'](model.parameters(), lr=0.01, **client_flex_model[\"optimizer_kwargs\"])\n",
     "    model = model.train()\n",
     "    model = model.to(device)\n",
     "    criterion = client_flex_model[\"criterion\"]\n",
-    "    # Al usar batch_sampler, hay que recargar el DataLoader en cada epoch.\n",
     "    for _ in tqdm(range(NUM_EPOCHS)):\n",
-    "        # client_dataloader = DataLoader(client_data, collate_fn=collate_batch,\n",
-    "        #                                batch_sampler=batch_sampler_v2(BATCH_SIZE, train_indices))\n",
     "        losses = []\n",
     "        total_acc, total_count = 0, 0\n",
     "        for texts, labels in client_dataloader:\n",
     "            optimizer.zero_grad()\n",
     "            texts, labels = texts.to(device), labels.to(device)\n",
     "            predicted_labels = model(texts).squeeze(dim=0)\n",
-    "            # pred = pred.squeeze(dim=0)\n",
     "            loss = criterion(predicted_labels, labels)\n",
     "            if predicted_labels.isnan().any():\n",
     "                print(f\"Text in batch: {texts}\")\n",
@@ -603,17 +569,13 @@
     "    total_count = 0\n",
     "    model = model.to(device)\n",
     "    criterion=server_flex_model['criterion']\n",
-    "    # get test data as a torchvision object\n",
-    "    # test_dataloader = DataLoader(test_data, batch_size=256, shuffle=True, pin_memory=False, collate_fn=collate_batch)\n",
+    "    # Prepare the test data for prediction\n",
     "    X_data, y_data = test_data.X_data.tolist(), test_data.y_data.tolist()\n",
     "    X_data = [preprocess_text(text) for text in X_data]\n",
     "    label_transform = lambda x: int(x) - 1\n",
     "    label_list = [label_transform(_label) for _label in y_data]\n",
-    "    # test_indices = [(i, len(tokenizer(s[0]))) for i, s in enumerate(X_data)]\n",
     "    test_data = test_data.from_array(X_data, label_list)\n",
     "    test_dataloader = DataLoader(test_data, batch_size=256, shuffle=True, pin_memory=False, collate_fn=basic_collate_pad_sequence_classification)\n",
-    "    # test_dataloader = DataLoader(test_data, collate_fn=basic_collate_pad_sequence_classification,\n",
-    "    #                              batch_sampler=batch_sampler_v2(BATCH_SIZE, test_indices))\n",
     "    losses = []\n",
     "    with torch.no_grad():\n",
     "        for data, target in test_dataloader:\n",
@@ -641,6 +603,13 @@
     "metrics = servers.map(evaluate_global_model, test_data=test_dataset)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Show the metrics after evaluating the model."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -681,7 +650,7 @@
     "    pool.aggregators.map(fed_avg)\n",
     "    # The aggregator send its aggregated weights to the server\n",
     "    pool.aggregators.map(set_aggregated_weights_pt, pool.servers)\n",
-    "    metrics = pool.servers.map(evaluate_global_model, test_data=test_imdb_dataset)\n",
+    "    metrics = pool.servers.map(evaluate_global_model, test_data=test_dataset)\n",
     "    loss, acc = metrics[0]\n",
     "    print(f\"Server: Test acc: {acc:.4f}, test loss: {loss:.4f}\")"
    ]
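
Note on the collator swap: the hand-written collate_batch deleted above is replaced by the library's basic_collate_pad_sequence_classification. For reference, a minimal sketch of what such a collator does, reconstructed from the deleted code; the body below is an illustrative assumption, not the library's actual implementation, and the pad_index default stands in for the notebook's pad_index variable:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    def basic_collate_pad_sequence_classification(batch, pad_index=0):
        # batch: iterable of (token_id_list, int_label) pairs, as produced
        # by preprocess_text plus the shifted labels in train/evaluate.
        texts, labels = zip(*batch)
        # Token id lists -> tensors, padded to the longest sequence in the batch.
        texts = [torch.as_tensor(t, dtype=torch.int64) for t in texts]
        padded = pad_sequence(texts, padding_value=pad_index, batch_first=True)
        return padded, torch.tensor(labels, dtype=torch.int64)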
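On the pending work from the subject line (two aggregators commonly used in NLP problems): coordinate-wise median and trimmed mean are the usual robust choices. A minimal sketch, assuming each client contributes a list of per-layer numpy arrays, the same shape of input that fed_avg consumes; wiring them into the pool through FLEXible's custom-aggregator mechanism is assumed, not shown:

    import numpy as np

    def fed_median(list_of_weights):
        # list_of_weights: one entry per client, each a list of layer arrays.
        # The coordinate-wise median is robust to a few outlier clients.
        return [np.median(np.stack(layer), axis=0)
                for layer in zip(*list_of_weights)]

    def fed_trimmed_mean(list_of_weights, trim_ratio=0.1):
        # Drop the lowest and highest trim_ratio fraction of each coordinate
        # across clients, then average the rest (assumes enough clients that
        # 2 * trim_ratio * n_clients < n_clients).
        def trimmed(stacked):
            k = int(trim_ratio * stacked.shape[0])
            return np.sort(stacked, axis=0)[k:stacked.shape[0] - k].mean(axis=0)
        return [trimmed(np.stack(layer)) for layer in zip(*list_of_weights)]

Either one would slot into the training loop in place of fed_avg, e.g. pool.aggregators.map(fed_median), once registered the same way fed_avg is.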