Skip to content

Commit

Permalink
Modified clip_avg_f
Browse files Browse the repository at this point in the history
  • Loading branch information
cristinazuhe committed Jan 26, 2024
1 parent 4289a8e commit d41e94e
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions flexnlp/pool/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ def clip_avg_f(aggregate_weights_as_list: list, clip_threshold: float = 0.9):
agg_weights = []
for layer_index in range(n_layers):
weights_per_layer = []
for client_weights in aggregate_weights_as_list:
w = tl.tensor(client_weights[layer_index])
weights_per_layer.append(w)
for w in aggregate_weights_as_list:
# w = tl.tensor(client_weights[layer_index])
weights_per_layer.append(w[layer_index])
weights_per_layer = tl.stack(weights_per_layer)
clip_thresh = np.percentile(weights_per_layer, clip_threshold*100, axis=0)
clip_thresh = tl.tensor(np.percentile(weights_per_layer, clip_threshold*100, axis=0))
sum_clipped_layer = tl.mean(tl.clip(weights_per_layer, -clip_thresh, clip_thresh), axis=0)
agg_weights.append(sum_clipped_layer)
return agg_weights
Expand Down

0 comments on commit d41e94e

Please sign in to comment.