From 9a943a7e88954c8a949a4ee6315d6889606a0336 Mon Sep 17 00:00:00 2001 From: futabato <01futabato10@gmail.com> Date: Tue, 5 Nov 2024 14:54:28 +0900 Subject: [PATCH] :bug: Fix a simple bug. --- .../reputation/monitoring.py | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/src/federatedlearning/reputation/monitoring.py b/src/federatedlearning/reputation/monitoring.py index 29acdd8..76dc78a 100644 --- a/src/federatedlearning/reputation/monitoring.py +++ b/src/federatedlearning/reputation/monitoring.py @@ -5,6 +5,7 @@ import numpy as np import torch import torch.nn as nn +import torchvision from federatedlearning.models.cnn import CNNMnist from models.resnet import ResNet18 from nptyping import DataFrame @@ -225,8 +226,10 @@ def monitor_time_series_convergence( global_model = CNNMnist(cfg=cfg) local_model = CNNMnist(cfg=cfg) elif cfg.train.dataset == "cifar": - global_model = ResNet18() - local_model = ResNet18() + global_model = torchvision.models.resnet18(weights="IMAGENET1K_V1") + global_model.fc = torch.nn.Linear(global_model.fc.in_features, 10) + local_model = torchvision.models.resnet18(weights="IMAGENET1K_V1") + local_model.fc = torch.nn.Linear(local_model.fc.in_features, 10) # Check if it's not the first round if round > 0: # Load the local model weights for the current client and round @@ -279,8 +282,18 @@ def monitor_time_series_similarity( previous_local_model = CNNMnist(cfg) current_local_model = CNNMnist(cfg) elif cfg.train.dataset == "cifar": - previous_local_model = ResNet18() - current_local_model = ResNet18() + previous_local_model = torchvision.models.resnet18( + weights="IMAGENET1K_V1" + ) + previous_local_model.fc = torch.nn.Linear( + previous_local_model.fc.in_features, 10 + ) + current_local_model = torchvision.models.resnet18( + weights="IMAGENET1K_V1" + ) + current_local_model.fc = torch.nn.Linear( + current_local_model.fc.in_features, 10 + ) is_reliable: bool = True if round > 0: @@ -322,7 +335,8 @@ def monitor_trust_scored_clustering( if cfg.train.dataset == "mnist": local_model = CNNMnist(cfg=cfg) elif cfg.train.dataset == "cifar": - local_model = ResNet18() + local_model = torchvision.models.resnet18(weights="IMAGENET1K_V1") + local_model.fc = torch.nn.Linear(local_model.fc.in_features, 10) print(f"[TrustScoredClustering] {selected_client_idx=}") print(f"[TrustScoredClustering] {byzantine_clients=}")