Skip to content

Commit 325476c

Browse files
Clients load data directly from CSVs. (madgik#484)
* Added a retry mechanism for the connection of Flower clients to the Flower server. * Updated data processing for clients so they load data from CSV files instead of the database.
1 parent 1385732 commit 325476c

28 files changed

+1796
-1250
lines changed

.github/workflows/algorithm_validation_tests.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,8 @@ jobs:
245245
with:
246246
run: cat /tmp/exareme2/localworker1.out
247247

248-
- name: Run Flower algorithm validation tests
249-
run: poetry run pytest tests/algorithm_validation_tests/flower/test_logistic_regression.py -n 2 --verbosity=4 --reruns 6 --reruns-delay 5
250-
251248
- name: Run Exareme2 algorithm validation tests
252249
run: poetry run pytest tests/algorithm_validation_tests/exareme2/ --verbosity=4 -n 16 -k "input1 and not input1-" # run tests 10-19
250+
251+
- name: Run Flower algorithm validation tests
252+
run: poetry run pytest tests/algorithm_validation_tests/flower/test_logistic_regression.py -n 2 --verbosity=4 --reruns 6 --reruns-delay 5

exareme2/algorithms/flower/flower_data_processing.py renamed to exareme2/algorithms/flower/inputdata_preprocessing.py

+17-25
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Optional
66

77
import pandas as pd
8-
import pymonetdb
98
import requests
109
from flwr.common.logger import FLOWER_LOGGER
1110
from pydantic import BaseModel
@@ -29,37 +28,30 @@ class Inputdata(BaseModel):
2928
x: Optional[List[str]]
3029

3130

32-
def fetch_data(data_model, datasets, from_db=False) -> pd.DataFrame:
33-
return (
34-
_fetch_data_from_db(data_model, datasets)
35-
if from_db
36-
else _fetch_data_from_csv(data_model, datasets)
37-
)
31+
def fetch_client_data(inputdata) -> pd.DataFrame:
32+
FLOWER_LOGGER.error(f"BROOO {os.getenv('CSV_PATHS')}")
33+
dataframes = [
34+
pd.read_csv(f"{os.getenv('DATA_PATH')}{csv_path}")
35+
for csv_path in os.getenv("CSV_PATHS").split(",")
36+
]
37+
df = pd.concat(dataframes, ignore_index=True)
38+
df = df[df["dataset"].isin(inputdata.datasets)]
39+
return df[inputdata.x + inputdata.y]
3840

3941

40-
def _fetch_data_from_db(data_model, datasets) -> pd.DataFrame:
41-
query = f'SELECT * FROM "{data_model}"."primary_data"'
42-
conn = pymonetdb.connect(
43-
hostname=os.getenv("MONETDB_IP"),
44-
port=int(os.getenv("MONETDB_PORT")),
45-
username=os.getenv("MONETDB_USERNAME"),
46-
password=os.getenv("MONETDB_PASSWORD"),
47-
database=os.getenv("MONETDB_DB"),
42+
def fetch_server_data(inputdata) -> pd.DataFrame:
43+
data_folder = Path(
44+
f"{os.getenv('DATA_PATH')}/{inputdata.data_model.split(':')[0]}_v_0_1"
4845
)
49-
df = pd.read_sql(query, conn)
50-
conn.close()
51-
df = df[df["dataset"].isin(datasets)]
52-
return df
53-
54-
55-
def _fetch_data_from_csv(data_model, datasets) -> pd.DataFrame:
56-
data_folder = Path(f"{os.getenv('DATA_PATH')}/{data_model.split(':')[0]}_v_0_1")
46+
print(f"Loading data from folder: {data_folder}")
5747
dataframes = [
5848
pd.read_csv(data_folder / f"{dataset}.csv")
59-
for dataset in datasets
49+
for dataset in inputdata.datasets
6050
if (data_folder / f"{dataset}.csv").exists()
6151
]
62-
return pd.concat(dataframes, ignore_index=True)
52+
df = pd.concat(dataframes, ignore_index=True)
53+
df = df[df["dataset"].isin(inputdata.datasets)]
54+
return df[inputdata.x + inputdata.y]
6355

6456

6557
def preprocess_data(inputdata, full_data):

exareme2/algorithms/flower/logistic_regression/client.py

+26-7
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
import os
2+
import time
23
import warnings
4+
from math import log2
35

46
import flwr as fl
7+
from flwr.common.logger import FLOWER_LOGGER
58
from sklearn.linear_model import LogisticRegression
69
from sklearn.metrics import log_loss
710
from utils import get_model_parameters
811
from utils import set_initial_params
912
from utils import set_model_params
1013

11-
from exareme2.algorithms.flower.flower_data_processing import fetch_data
12-
from exareme2.algorithms.flower.flower_data_processing import get_input
13-
from exareme2.algorithms.flower.flower_data_processing import preprocess_data
14+
from exareme2.algorithms.flower.inputdata_preprocessing import fetch_client_data
15+
from exareme2.algorithms.flower.inputdata_preprocessing import get_input
16+
from exareme2.algorithms.flower.inputdata_preprocessing import preprocess_data
1417

1518

1619
class LogisticRegressionClient(fl.client.NumPyClient):
@@ -39,11 +42,27 @@ def evaluate(self, parameters, config):
3942
if __name__ == "__main__":
4043
model = LogisticRegression(penalty="l2", max_iter=1, warm_start=True)
4144
inputdata = get_input()
42-
full_data = fetch_data(inputdata.data_model, inputdata.datasets, from_db=True)
45+
full_data = fetch_client_data(inputdata)
4346
X_train, y_train = preprocess_data(inputdata, full_data)
4447
set_initial_params(model, X_train, full_data, inputdata)
4548

4649
client = LogisticRegressionClient(model, X_train, y_train)
47-
fl.client.start_client(
48-
server_address=os.environ["SERVER_ADDRESS"], client=client.to_client()
49-
)
50+
51+
attempts = 0
52+
max_attempts = int(log2(int(os.environ["TIMEOUT"])))
53+
while True:
54+
try:
55+
fl.client.start_client(
56+
server_address=os.environ["SERVER_ADDRESS"], client=client.to_client()
57+
)
58+
FLOWER_LOGGER.debug("Connection successful on attempt", attempts + 1)
59+
break
60+
except Exception as e:
61+
FLOWER_LOGGER.warning(
62+
f"Connection with the server failed. Attempt {attempts + 1} failed: {e}"
63+
)
64+
time.sleep(pow(2, attempts))
65+
attempts += 1
66+
if attempts >= max_attempts:
67+
FLOWER_LOGGER.error("Could not establish connection to the server.")
68+
raise e

exareme2/algorithms/flower/logistic_regression/server.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
from utils import set_initial_params
77
from utils import set_model_params
88

9-
from exareme2.algorithms.flower.flower_data_processing import fetch_data
10-
from exareme2.algorithms.flower.flower_data_processing import get_input
11-
from exareme2.algorithms.flower.flower_data_processing import post_result
12-
from exareme2.algorithms.flower.flower_data_processing import preprocess_data
9+
from exareme2.algorithms.flower.inputdata_preprocessing import fetch_server_data
10+
from exareme2.algorithms.flower.inputdata_preprocessing import get_input
11+
from exareme2.algorithms.flower.inputdata_preprocessing import post_result
12+
from exareme2.algorithms.flower.inputdata_preprocessing import preprocess_data
1313

1414
# TODO: NUM_OF_ROUNDS should become a parameter of the algorithm and be set on the AlgorithmRequestDTO
1515
NUM_OF_ROUNDS = 5
@@ -35,7 +35,7 @@ def evaluate(server_round, parameters, config):
3535
if __name__ == "__main__":
3636
model = LogisticRegression()
3737
inputdata = get_input()
38-
full_data = fetch_data(inputdata.data_model, inputdata.datasets)
38+
full_data = fetch_server_data(inputdata)
3939
X_train, y_train = preprocess_data(inputdata, full_data)
4040
set_initial_params(model, X_train, full_data, inputdata)
4141
strategy = fl.server.strategy.FedAvg(

exareme2/algorithms/flower/mnist_logistic_regression/server.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sklearn.linear_model import LogisticRegression
66
from sklearn.metrics import log_loss
77

8-
from exareme2.algorithms.flower.flower_data_processing import post_result
8+
from exareme2.algorithms.flower.inputdata_preprocessing import post_result
99
from exareme2.algorithms.flower.mnist_logistic_regression import utils
1010

1111
NUM_OF_ROUNDS = 5

exareme2/controller/celery/tasks_handler.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -298,13 +298,15 @@ def queue_healthcheck_task(
298298
)
299299

300300
def start_flower_client(
301-
self, request_id, algorithm_name, server_address
301+
self, request_id, algorithm_name, server_address, csv_paths, execution_timeout
302302
) -> WorkerTaskResult:
303303
return self._queue_task(
304304
task_signature=TASK_SIGNATURES["start_flower_client"],
305305
request_id=request_id,
306306
algorithm_name=algorithm_name,
307307
server_address=server_address,
308+
csv_paths=csv_paths,
309+
execution_timeout=execution_timeout,
308310
)
309311

310312
def start_flower_server(

exareme2/controller/quart/endpoints.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,14 @@ async def get_datasets() -> dict:
3232

3333
@algorithms.route("/datasets_locations", methods=["GET"])
3434
async def get_datasets_locations() -> dict:
35-
return get_worker_landscape_aggregator().get_datasets_locations().datasets_locations
35+
return {
36+
data_model: {
37+
dataset: info.worker_id for dataset, info in datasets_location.items()
38+
}
39+
for data_model, datasets_location in get_worker_landscape_aggregator()
40+
.get_datasets_locations()
41+
.datasets_locations.items()
42+
}
3643

3744

3845
@algorithms.route("/cdes_metadata", methods=["GET"])

exareme2/controller/services/flower/controller.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import asyncio
2+
import warnings
3+
from typing import Dict
24
from typing import List
35

6+
from exareme2.controller import config as ctrl_config
47
from exareme2.controller import logger as ctrl_logger
58
from exareme2.controller.federation_info_logs import log_experiment_execution
69
from exareme2.controller.services.flower.tasks_handler import FlowerTasksHandler
@@ -52,10 +55,16 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
5255
request_id = algorithm_request_dto.request_id
5356
context_id = UIDGenerator().get_a_uid()
5457
logger = ctrl_logger.get_request_logger(request_id)
55-
workers_info = self._get_workers_info_by_dataset(
58+
csv_paths_per_worker_id: Dict[
59+
str, List[str]
60+
] = self.worker_landscape_aggregator.get_csv_paths_per_worker_id(
5661
algorithm_request_dto.inputdata.data_model,
5762
algorithm_request_dto.inputdata.datasets,
5863
)
64+
workers_info = [
65+
self.worker_landscape_aggregator.get_worker_info(worker_id)
66+
for worker_id in csv_paths_per_worker_id
67+
]
5968
task_handlers = [
6069
self._create_worker_tasks_handler(request_id, worker)
6170
for worker in workers_info
@@ -87,7 +96,10 @@ async def exec_algorithm(self, algorithm_name, algorithm_request_dto):
8796
)
8897
clients_pids = {
8998
handler.start_flower_client(
90-
algorithm_name, str(server_address)
99+
algorithm_name,
100+
str(server_address),
101+
csv_paths_per_worker_id[handler.worker_id],
102+
ctrl_config.flower_execution_timeout,
91103
): handler
92104
for handler in task_handlers
93105
}
@@ -127,15 +139,3 @@ async def _cleanup(
127139
server_task_handler.stop_flower_server(server_pid, algorithm_name)
128140
for pid, handler in clients_pids.items():
129141
handler.stop_flower_client(pid, algorithm_name)
130-
131-
def _get_workers_info_by_dataset(self, data_model, datasets) -> List[WorkerInfo]:
132-
"""Retrieves worker information for those handling the specified datasets."""
133-
worker_ids = (
134-
self.worker_landscape_aggregator.get_worker_ids_with_any_of_datasets(
135-
data_model, datasets
136-
)
137-
)
138-
return [
139-
self.worker_landscape_aggregator.get_worker_info(worker_id)
140-
for worker_id in worker_ids
141-
]

exareme2/controller/services/flower/tasks_handler.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,15 @@ def worker_id(self) -> str:
2929
def worker_data_address(self) -> str:
3030
return self._db_address
3131

32-
def start_flower_client(self, algorithm_name, server_address) -> int:
32+
def start_flower_client(
33+
self, algorithm_name, server_address, csv_paths, execution_timeout
34+
) -> int:
3335
return self._worker_tasks_handler.start_flower_client(
34-
self._request_id, algorithm_name, server_address
36+
self._request_id,
37+
algorithm_name,
38+
server_address,
39+
csv_paths,
40+
execution_timeout,
3541
).get(timeout=self._tasks_timeout)
3642

3743
def start_flower_server(

exareme2/controller/services/worker_landscape_aggregator/worker_info_tasks_handler.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from exareme2.controller.celery.tasks_handler import WorkerTasksHandler
55
from exareme2.worker_communication import CommonDataElements
66
from exareme2.worker_communication import DataModelAttributes
7+
from exareme2.worker_communication import DatasetsInfoPerDataModel
78
from exareme2.worker_communication import WorkerInfo
89

910

@@ -23,10 +24,11 @@ def get_worker_info_task(self) -> WorkerInfo:
2324
).get(self._tasks_timeout)
2425
return WorkerInfo.parse_raw(result)
2526

26-
def get_worker_datasets_per_data_model_task(self) -> Dict[str, Dict[str, str]]:
27-
return self._worker_tasks_handler.queue_worker_datasets_per_data_model_task(
27+
def get_worker_datasets_per_data_model_task(self) -> DatasetsInfoPerDataModel:
28+
result = self._worker_tasks_handler.queue_worker_datasets_per_data_model_task(
2829
self._request_id
2930
).get(self._tasks_timeout)
31+
return DatasetsInfoPerDataModel.parse_raw(result)
3032

3133
def get_data_model_cdes_task(self, data_model: str) -> CommonDataElements:
3234
result = self._worker_tasks_handler.queue_data_model_cdes_task(

0 commit comments

Comments
 (0)