diff --git a/examples/federated_learning/surface_defect_detection/aggregation_worker/aggregate.py b/examples/federated_learning/surface_defect_detection/aggregation_worker/aggregate.py index 84ac6f68b..31e01a744 100644 --- a/examples/federated_learning/surface_defect_detection/aggregation_worker/aggregate.py +++ b/examples/federated_learning/surface_defect_detection/aggregation_worker/aggregate.py @@ -26,12 +26,9 @@ def run_server(): participants_count = int(Context.get_parameters( "participants_count", 1 )) - agg_ip = Context.get_parameters("AGG_IP", "0.0.0.0") - agg_port = int(Context.get_parameters("AGG_PORT", "7363")) + server = AggregationServer( aggregation=aggregation_algorithm, - host=agg_ip, - http_port=agg_port, exit_round=exit_round, ws_size=20 * 1024 * 1024, participants_count=participants_count diff --git a/lib/sedna/service/server/aggregation.py b/lib/sedna/service/server/aggregation.py index 3529d6b48..5a7bb3091 100644 --- a/lib/sedna/service/server/aggregation.py +++ b/lib/sedna/service/server/aggregation.py @@ -26,6 +26,7 @@ from starlette.types import ASGIApp, Receive, Scope, Send from sedna.common.log import LOGGER +from sedna.common.config import Context from sedna.common.utils import get_host_ip from sedna.common.class_factory import ClassFactory, ClassType from sedna.algorithms.aggregation import AggClient @@ -211,12 +212,14 @@ def __init__( self, aggregation: str, host: str = None, - http_port: int = 7363, + http_port: int = None, exit_round: int = 1, participants_count: int = 1, ws_size: int = 10 * 1024 * 1024): if not host: - host = get_host_ip() + host = Context.get_parameters("AGG_BIND_IP", get_host_ip()) + if not http_port: + http_port = int(Context.get_parameters("AGG_BIND_PORT", 7363)) super( AggregationServer, self).__init__(