Notebook ready. Collators and adapters ready too. It would be ok to add 2 aggregators normally used on NLP problems.
cristinazuhe committed Jan 25, 2024
1 parent 8ca6d59 commit aa9407c
Showing 1 changed file with 9 additions and 40 deletions.
49 changes: 9 additions & 40 deletions notebooks/Federated IMDb PT using FLExible with a GRU.ipynb
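The core of this change swaps the notebook's hand-written `collate_batch`, `preprocess_text` and `batch_sampler_v2` helpers for the library collator `basic_collate_pad_sequence_classification` passed to the `DataLoader`. For orientation, a pad-sequence collator for text classification typically looks like the minimal sketch below; this is an illustration only, not the library's implementation, and `pad_index` is assumed to be the vocabulary id of the `<pad>` token used in the notebook.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_pad_sequence_classification(batch, pad_index=0):
    """Minimal sketch of a pad-sequence collator for classification.

    Illustration only -- not the flex implementation of
    basic_collate_pad_sequence_classification. Each batch item is assumed
    to be (token_ids, label); pad_index is assumed to be the id of "<pad>".
    """
    texts, labels = [], []
    for token_ids, label in batch:
        texts.append(torch.as_tensor(token_ids, dtype=torch.int64))
        labels.append(int(label))
    # Pad every sequence to the length of the longest one in the batch.
    padded = pad_sequence(texts, padding_value=pad_index, batch_first=True)
    return padded, torch.tensor(labels, dtype=torch.int64)
```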
@@ -431,36 +431,10 @@
"\n",
" return string.strip().lower()\n",
"\n",
"def collate_batch(batch):\n",
" def preprocess_text(text):\n",
" text_transform = lambda x: [vocabulary[\"<pad>\"]]+[vocabulary[token] for token in spacy_tokenizer(x)]+[vocabulary[\"<pad>\"]]\n",
" return list(text_transform(clean_str(text)))\n",
" label_list, text_list = [], []\n",
" for (_text, _label) in batch:\n",
" label_transform = lambda x: int(x) - 1\n",
" label_list.append(label_transform(_label))\n",
" processed_text = torch.tensor(preprocess_text(_text))\n",
" text_list.append(processed_text)\n",
" label_list = torch.tensor(label_list, dtype=torch.int64)\n",
" return pad_sequence(text_list, padding_value=pad_index, batch_first=True), label_list\n",
"\n",
"def preprocess_text(text):\n",
" text_transform = lambda x: [vocabulary[\"<pad>\"]]+[vocabulary[token] for token in spacy_tokenizer(x)]+[vocabulary[\"<pad>\"]]\n",
" return list(text_transform(clean_str(text)))\n",
"\n",
"def batch_sampler_v2(batch_size, indices):\n",
" random.shuffle(indices)\n",
" pooled_indices = []\n",
" # create pool of indices with similar lengths \n",
" for i in range(0, len(indices), batch_size * 100):\n",
" pooled_indices.extend(sorted(indices[i:i + batch_size * 100], key=lambda x: x[1]))\n",
"\n",
" pooled_indices = [x[0] for x in pooled_indices]\n",
"\n",
" # yield indices for current batch\n",
" for i in range(0, len(pooled_indices), batch_size):\n",
" yield pooled_indices[i:i + batch_size]\n",
"\n",
"def train(client_flex_model: FlexModel, client_data: Dataset):\n",
" X_data, y_data = client_data.to_list()\n",
" if 'train_indices' not in client_flex_model:\n",
@@ -476,28 +450,20 @@
" label_list = [label_transform(_label) for _label in y_data]\n",
"\n",
" client_data = client_data.from_array(X_data, label_list)\n",
"\n",
" # batch_size=BATCH_SIZE, shuffle=True, # No es necesario usarlo porque usamos el batch_sampler\n",
" client_dataloader = DataLoader(client_data, collate_fn=basic_collate_pad_sequence_classification, batch_size=BATCH_SIZE,\n",
" shuffle=True)\n",
" #  batch_sampler=batch_sampler_v2(BATCH_SIZE, train_indices))\n",
" model = client_flex_model[\"model\"]\n",
" # lr = 0.001\n",
" optimizer = client_flex_model['optimizer_func'](model.parameters(), lr=0.01, **client_flex_model[\"optimizer_kwargs\"])\n",
" model = model.train()\n",
" model = model.to(device)\n",
" criterion = client_flex_model[\"criterion\"]\n",
" # Al usar batch_sampler, hay que recargar el DataLoader en cada epoch.\n",
" for _ in tqdm(range(NUM_EPOCHS)):\n",
" # client_dataloader = DataLoader(client_data, collate_fn=collate_batch,\n",
" # batch_sampler=batch_sampler_v2(BATCH_SIZE, train_indices))\n",
" losses = []\n",
" total_acc, total_count = 0, 0\n",
" for texts, labels in client_dataloader:\n",
" optimizer.zero_grad()\n",
" texts, labels = texts.to(device), labels.to(device)\n",
" predicted_labels = model(texts).squeeze(dim=0)\n",
" # pred = pred.squeeze(dim=0)\n",
" loss = criterion(predicted_labels, labels)\n",
" if predicted_labels.isnan().any():\n",
" print(f\"Text in batch: {texts}\")\n",
@@ -603,17 +569,13 @@
" total_count = 0\n",
" model = model.to(device)\n",
" criterion=server_flex_model['criterion']\n",
" # get test data as a torchvision object\n",
" # test_dataloader = DataLoader(test_data, batch_size=256, shuffle=True, pin_memory=False, collate_fn=collate_batch)\n",
" # Prepare the test data for the prediction\n",
" X_data, y_data = test_data.X_data.tolist(), test_data.y_data.tolist()\n",
" X_data = [preprocess_text(text) for text in X_data]\n",
" label_transform = lambda x: int(x) - 1\n",
" label_list = [label_transform(_label) for _label in y_data]\n",
" # test_indices = [(i, len(tokenizer(s[0]))) for i, s in enumerate(X_data)]\n",
" test_data = test_data.from_array(X_data, label_list)\n",
" test_dataloader = DataLoader(test_data, batch_size=256, shuffle=True, pin_memory=False, collate_fn=basic_collate_pad_sequence_classification)\n",
" # test_dataloader = DataLoader(test_data, collate_fn=basic_collate_pad_sequence_classification,\n",
" #  batch_sampler=batch_sampler_v2(BATCH_SIZE, test_indices))\n",
" losses = []\n",
" with torch.no_grad():\n",
" for data, target in test_dataloader:\n",
@@ -641,6 +603,13 @@
"metrics = servers.map(evaluate_global_model, test_data=test_dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Show the metrics after evaluating the model."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -681,7 +650,7 @@
" pool.aggregators.map(fed_avg)\n",
" # The aggregator send its aggregated weights to the server\n",
" pool.aggregators.map(set_aggregated_weights_pt, pool.servers)\n",
" metrics = pool.servers.map(evaluate_global_model, test_data=test_imdb_dataset)\n",
" metrics = pool.servers.map(evaluate_global_model, test_data=test_dataset)\n",
" loss, acc = metrics[0]\n",
" print(f\"Server: Test acc: {acc:.4f}, test loss: {loss:.4f}\")"
]
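The commit message also mentions adding two aggregators commonly used on NLP problems, on top of the `fed_avg` call already mapped over `pool.aggregators` in the training loop above. As a rough reference, FedAvg-style aggregation reduces to averaging each layer's weights across the collected client updates; the sketch below is an illustration under that assumption, not the flex library's `fed_avg`.

```python
import torch

def fed_avg_sketch(client_weight_lists):
    """Minimal sketch of FedAvg-style aggregation (illustration only).

    Assumes each client contribution is a list of tensors with matching
    shapes, one entry per model layer.
    """
    n_clients = len(client_weight_lists)
    aggregated = []
    # Average each layer's weights element-wise across all clients.
    for layer_weights in zip(*client_weight_lists):
        stacked = torch.stack([torch.as_tensor(w) for w in layer_weights])
        aggregated.append(stacked.sum(dim=0) / n_clients)
    return aggregated
```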

