Skip to content

Commit

Permalink
🐛 Fix a simple bug.
Browse files Browse the repository at this point in the history
  • Loading branch information
futabato committed Nov 5, 2024
1 parent 4327a9f commit 9a943a7
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions src/federatedlearning/reputation/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import torch
import torch.nn as nn
import torchvision
from federatedlearning.models.cnn import CNNMnist
from models.resnet import ResNet18
from nptyping import DataFrame
Expand Down Expand Up @@ -225,8 +226,10 @@ def monitor_time_series_convergence(
global_model = CNNMnist(cfg=cfg)
local_model = CNNMnist(cfg=cfg)
elif cfg.train.dataset == "cifar":
global_model = ResNet18()
local_model = ResNet18()
global_model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
global_model.fc = torch.nn.Linear(global_model.fc.in_features, 10)
local_model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
local_model.fc = torch.nn.Linear(local_model.fc.in_features, 10)
# Check if it's not the first round
if round > 0:
# Load the local model weights for the current client and round
Expand Down Expand Up @@ -279,8 +282,18 @@ def monitor_time_series_similarity(
previous_local_model = CNNMnist(cfg)
current_local_model = CNNMnist(cfg)
elif cfg.train.dataset == "cifar":
previous_local_model = ResNet18()
current_local_model = ResNet18()
previous_local_model = torchvision.models.resnet18(
weights="IMAGENET1K_V1"
)
previous_local_model.fc = torch.nn.Linear(
previous_local_model.fc.in_features, 10
)
current_local_model = torchvision.models.resnet18(
weights="IMAGENET1K_V1"
)
current_local_model.fc = torch.nn.Linear(
current_local_model.fc.in_features, 10
)

is_reliable: bool = True
if round > 0:
Expand Down Expand Up @@ -322,7 +335,8 @@ def monitor_trust_scored_clustering(
if cfg.train.dataset == "mnist":
local_model = CNNMnist(cfg=cfg)
elif cfg.train.dataset == "cifar":
local_model = ResNet18()
local_model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
local_model.fc = torch.nn.Linear(local_model.fc.in_features, 10)

print(f"[TrustScoredClustering] {selected_client_idx=}")
print(f"[TrustScoredClustering] {byzantine_clients=}")
Expand Down

0 comments on commit 9a943a7

Please sign in to comment.