Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Mmasoud1 committed Nov 19, 2024
1 parent 0290ccc commit 038050e
Show file tree
Hide file tree
Showing 4 changed files with 331 additions and 271 deletions.
7 changes: 7 additions & 0 deletions app/code/aggregator/gradient_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ def average_gradients(self, gradients_list):
# Convert gradients to numpy arrays and perform averaging
n = len(gradients_list)

# Ensure each gradient in gradients_list is in NumPy format
gradients_list = [
[np.array(grad) if isinstance(grad, list) else grad for grad in gradients]
for gradients in gradients_list
]


# Initialize Empty Arrays: sum_arrays is created as a list of arrays with the same shape as the gradients.
# These arrays will accumulate the gradients from all clients.
sum_arrays = [np.zeros_like(arr) for arr in gradients_list[0]]
Expand Down
8 changes: 4 additions & 4 deletions app/code/executor/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def split_dataset(self):
train_data, valid_data, infer_data = torch.utils.data.random_split(self, [train_size, valid_size, self.len - train_size - valid_size])
return train_data, valid_data, infer_data

def get_loaders(self, batch_size=1, shuffle=True):
def get_loaders(self, batch_size=1, shuffle=True, num_workers=0):
train_data, valid_data, infer_data = self.split_dataset()
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=shuffle)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=False)
infer_loader = torch.utils.data.DataLoader(infer_data, batch_size=batch_size, shuffle=False)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
infer_loader = torch.utils.data.DataLoader(infer_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
return train_loader, valid_loader, infer_loader
Loading

0 comments on commit 038050e

Please sign in to comment.