Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
innovation-cat committed Jun 16, 2021
1 parent c645c25 commit c10a5bd
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions chapter03_Python_image_classification/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def __init__(self, conf, model, train_dataset, id = -1):

self.conf = conf

self.local_model = model
self.local_model = models.get_model(self.conf["model_name"])

self.client_id = id

Expand All @@ -25,10 +25,10 @@ def local_train(self, model):
for name, param in model.state_dict().items():
self.local_model.state_dict()[name].copy_(param.clone())


#print(id(model))
optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.conf['lr'],
momentum=self.conf['momentum'])

#print(id(self.local_model))
self.local_model.train()
for e in range(self.conf["local_epochs"]):

Expand All @@ -49,6 +49,7 @@ def local_train(self, model):
diff = dict()
for name, data in self.local_model.state_dict().items():
diff[name] = (data - model.state_dict()[name])
#print(diff[name])

return diff

2 changes: 1 addition & 1 deletion chapter15_Backdoor_Attack/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def __init__(self, conf, model, train_dataset, id = -1):

self.conf = conf

self.local_model = model
self.local_model = models.get_model(self.conf["model_name"])

self.client_id = id

Expand Down

0 comments on commit c10a5bd

Please sign in to comment.