Skip to content

Commit

Permalink
Solved bug on clip_avg
Browse files Browse the repository at this point in the history
  • Loading branch information
cristinazuhe committed Jan 26, 2024
1 parent 0e6a02b commit 4289a8e
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 25 deletions.
24 changes: 12 additions & 12 deletions flexnlp/pool/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import tensorly as tl
from flex.pool.decorators import aggregate_weights
from flex.pool.decorators import set_tensorly_backend
from flex.pool.aggregators import set_tensorly_backend


def clip_avg_f(aggregate_weights_as_list: list, clip_threshold: float = 0.9):
Expand All @@ -19,20 +19,20 @@ def clip_avg_f(aggregate_weights_as_list: list, clip_threshold: float = 0.9):
w = tl.tensor(client_weights[layer_index])
weights_per_layer.append(w)
weights_per_layer = tl.stack(weights_per_layer)
clip_threshold = np.quantile(weights_per_layer, clip_threshold)
sum_clipped_layer = tl.sum(tl.clip(weights_per_layer, -clip_threshold, clip_threshold), axis=0)
clip_thresh = np.percentile(weights_per_layer, clip_threshold*100, axis=0)
sum_clipped_layer = tl.mean(tl.clip(weights_per_layer, -clip_thresh, clip_thresh), axis=0)
agg_weights.append(sum_clipped_layer)
return agg_weights

@aggregate_weights
def clip_avg(aggregate_weights_as_list: list, clip_threshold: float = 0.9):
"""Aggregate the weights using the clip average method.
This function calculates the quantile of the weights of each layer and
def clip_avg(aggregated_weights_as_list: list, clip_threshold: float = 0.9):
"""Aggregate the weights using the clip average aggregation method.
This function calculates the percentile of the weights of each layer and
then clips the weights to the interval [-quantile, quantile].
Args:
aggregate_weights_as_list (list): List of weights to aggregate.
clip_threshold (float, optional): Quantile threshold to apply to each
aggregated_weights_as_list (list): List of weights to aggregate.
clip_threshold (float, optional): Percentile threshold to apply to each
layer. Defaults to 0.9.
Returns:
Expand All @@ -43,15 +43,15 @@ def clip_avg(aggregate_weights_as_list: list, clip_threshold: float = 0.9):
aggregator = flex.pool.aggregators
server = flex.pool.servers
clip_threshold = 0.98 # quantile to clip the weights
clip_threshold = 0.9 # percentile to clip the weights
aggregator.map(server, clip_avg, clip_threshold)
Example of use using the FlePool without separating server
and aggregator, and following a client-server architecture:
from flex.pool.primitives import clip_avg
clip_threshold = 0.98 # quantile to clip the weights
clip_threshold = 0.9 # percentile to clip the weights
flex_pool.aggregators.map(flex_pool.servers, clip_avg, clip_threshold=clip_threshold)
"""
set_tensorly_backend()
return clip_avg_f(aggregate_weights_as_list, clip_threshold)
set_tensorly_backend(aggregated_weights_as_list)
return clip_avg_f(aggregated_weights_as_list, clip_threshold)
2 changes: 1 addition & 1 deletion flexnlp/utils/adapters/ss_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ def ss_triplet_input_adapter(X_train_as_list: list = None, X_test_as_list: list
if test and len(X_test_as_list) > 1:
dev_examples = [InputExample(texts=[example['query'], example['pos'][0], example['neg'][0]]) for example in X_test_as_list]

return train_examples, dev_examples
return train_examples, dev_examples
15 changes: 3 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
from setuptools import find_packages, setup


TF_requires = ["tensorflow<2.11", # https://github.com/tensorflow/tensorflow/issues/58973
"tensorflow_datasets",
"tensorflow_hub"
]

PT_requires = ["torch",
"torchvision",
"torchtext",
"torchdata",
"portalocker",
]

HF_requires = ["datasets"]

setup(
name="flexnlp",
version="0.0.1",
Expand All @@ -38,19 +31,17 @@
"portalocker",
"torchdata",
"datasets",
"transformers"
"transformers",
"sentence_transformers",
"sentencepiece",
],
extras_require={
"tensorflow": TF_requires,
"pytorch": PT_requires,
"hugginface": HF_requires,
"develop": ["pytest",
"pytest-cov",
"pytest-xdist",
"coverage",
"jinja2",
*TF_requires,
*HF_requires
],
},
python_requires=">=3.8.10",
Expand Down

0 comments on commit 4289a8e

Please sign in to comment.