From d41e94ec290c4979f19fed5335fd803ec898bf2c Mon Sep 17 00:00:00 2001 From: cristinazuhe Date: Fri, 26 Jan 2024 11:36:37 +0100 Subject: [PATCH] Modified clip_avg_f --- flexnlp/pool/aggregators.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flexnlp/pool/aggregators.py b/flexnlp/pool/aggregators.py index 8263fc3..918df3b 100644 --- a/flexnlp/pool/aggregators.py +++ b/flexnlp/pool/aggregators.py @@ -15,11 +15,11 @@ def clip_avg_f(aggregate_weights_as_list: list, clip_threshold: float = 0.9): agg_weights = [] for layer_index in range(n_layers): weights_per_layer = [] - for client_weights in aggregate_weights_as_list: - w = tl.tensor(client_weights[layer_index]) - weights_per_layer.append(w) + for w in aggregate_weights_as_list: + # w = tl.tensor(client_weights[layer_index]) + weights_per_layer.append(w[layer_index]) weights_per_layer = tl.stack(weights_per_layer) - clip_thresh = np.percentile(weights_per_layer, clip_threshold*100, axis=0) + clip_thresh = tl.tensor(np.percentile(weights_per_layer, clip_threshold*100, axis=0)) sum_clipped_layer = tl.mean(tl.clip(weights_per_layer, -clip_thresh, clip_thresh), axis=0) agg_weights.append(sum_clipped_layer) return agg_weights