From c10a5bd504e1c94decd2f3a6583ac9facb985d7f Mon Sep 17 00:00:00 2001 From: innovation-cat Date: Wed, 16 Jun 2021 21:22:16 +0800 Subject: [PATCH] fix bugs --- chapter03_Python_image_classification/client.py | 7 ++++--- chapter15_Backdoor_Attack/client.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/chapter03_Python_image_classification/client.py b/chapter03_Python_image_classification/client.py index 0a6ac44..dc03ce7 100644 --- a/chapter03_Python_image_classification/client.py +++ b/chapter03_Python_image_classification/client.py @@ -6,7 +6,7 @@ def __init__(self, conf, model, train_dataset, id = -1): self.conf = conf - self.local_model = model + self.local_model = models.get_model(self.conf["model_name"]) self.client_id = id @@ -25,10 +25,10 @@ def local_train(self, model): for name, param in model.state_dict().items(): self.local_model.state_dict()[name].copy_(param.clone()) - + #print(id(model)) optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.conf['lr'], momentum=self.conf['momentum']) - + #print(id(self.local_model)) self.local_model.train() for e in range(self.conf["local_epochs"]): @@ -49,6 +49,7 @@ def local_train(self, model): diff = dict() for name, data in self.local_model.state_dict().items(): diff[name] = (data - model.state_dict()[name]) + #print(diff[name]) return diff \ No newline at end of file diff --git a/chapter15_Backdoor_Attack/client.py b/chapter15_Backdoor_Attack/client.py index 6d44267..3edde45 100644 --- a/chapter15_Backdoor_Attack/client.py +++ b/chapter15_Backdoor_Attack/client.py @@ -9,7 +9,7 @@ def __init__(self, conf, model, train_dataset, id = -1): self.conf = conf - self.local_model = model + self.local_model = models.get_model(self.conf["model_name"]) self.client_id = id