-
-
Notifications
You must be signed in to change notification settings - Fork 4
/
utils.py
129 lines (97 loc) · 3.39 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import importlib.util
import json
from enum import Enum
from pathlib import Path
from syftbox.lib import Client
class ParticipantStateCols(Enum):
    """Column labels of the per-participant status table.

    Each member's value is the human-readable label used as a dictionary key
    in the participant records written by ``create_participant_json_file``
    and updated by ``update_json``.
    """

    EMAIL = "Email"  # identifies the participant record
    FL_CLIENT_INSTALLED = "Fl Client Installed"  # initialised to False
    PROJECT_APPROVED = "Project Approved"  # initialised to False
    ADDED_PRIVATE_DATA = "Added Private Data"  # initialised to False
    ROUND = "Round (current/total)"  # initialised to "0/<total_rounds>"
    MODEL_TRAINING_PROGRESS = "Training Progress"  # initialised to "N/A"
def has_empty_dirs(directory: Path):
    """Report whether any immediate subdirectory of *directory* is empty.

    Only direct children are inspected: regular files are skipped and
    nested directories are not searched recursively. Scanning stops at the
    first empty subdirectory found.
    """
    for child in directory.iterdir():
        if child.is_dir() and is_dir_empty(child):
            return True
    return False
def is_dir_empty(directory: Path):
    """Return True when *directory* has no entries at all (files or folders)."""
    # Pulling a single entry is enough to decide; the rest is never listed.
    first_entry = next(directory.iterdir(), None)
    return first_entry is None
def read_json(data_path: Path):
    """Load and return the JSON document stored at *data_path*.

    The file is decoded as UTF-8, the encoding RFC 8259 requires for JSON,
    rather than the platform's locale default. This means non-ASCII content
    reads the same on every OS.

    Raises:
        FileNotFoundError: if *data_path* does not exist.
        UnicodeDecodeError: if the file is not valid UTF-8.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(data_path, encoding="utf-8") as fp:
        return json.load(fp)
def save_json(data: dict, data_path: Path):
    """Serialise *data* as 4-space-indented JSON to *data_path*, overwriting it.

    The file is opened with an explicit UTF-8 encoding, which matches
    ``read_json``, instead of relying on the locale default.
    """
    with open(data_path, "w", encoding="utf-8") as fp:
        json.dump(data, fp, indent=4)
def create_participant_json_file(
    participants: list, total_rounds: int, output_path: Path
):
    """Write the initial status record of every participant to *output_path*.

    Each entry of *participants* becomes one record keyed by the
    ``ParticipantStateCols`` labels. The record holds that entry as the email,
    ``False`` for the three setup flags, ``"0/<total_rounds>"`` for the round
    counter and ``"N/A"`` for training progress. Records keep the order of
    *participants*.
    """
    cols = ParticipantStateCols
    records = [
        {
            cols.EMAIL.value: email,
            cols.FL_CLIENT_INSTALLED.value: False,
            cols.PROJECT_APPROVED.value: False,
            cols.ADDED_PRIVATE_DATA.value: False,
            cols.ROUND.value: f"0/{total_rounds}",
            cols.MODEL_TRAINING_PROGRESS.value: "N/A",
        }
        for email in participants
    ]
    save_json(data=records, data_path=output_path)
def update_json(
    data_path: Path,
    participant_email: str,
    column_name: ParticipantStateCols,
    column_val: str,
):
    """Set one status column for a participant in the JSON table at *data_path*.

    Args:
        data_path: JSON file holding a list of participant records, as
            written by ``create_participant_json_file``.
        participant_email: value compared against each record's email column;
            every matching record is updated.
        column_name: column to update, either a ``ParticipantStateCols``
            member or its label string (e.g. ``"Project Approved"``).
        column_val: value stored as-is in that column.

    Unknown columns are ignored and the file is neither read nor written.
    Otherwise the file is rewritten, even when no record matches.
    """
    # Normalise through value lookup: ``ParticipantStateCols(member)`` returns
    # the member itself, a label string maps to its member, and anything else
    # raises ValueError. An ``in ParticipantStateCols`` check is not portable:
    # it raises TypeError for non-members on Python 3.8-3.11 and, on 3.12+,
    # accepts bare label strings, which have no ``.value``.
    try:
        column = ParticipantStateCols(column_name)
    except ValueError:
        return

    participant_history = read_json(data_path=data_path)
    for participant in participant_history:
        if participant[ParticipantStateCols.EMAIL.value] == participant_email:
            participant[column.value] = column_val
    save_json(participant_history, data_path)
def load_model_class(model_path: Path, model_class_name: str) -> type:
    """Execute the Python source file *model_path* and return one of its attributes.

    Args:
        model_path: path to a Python source file (a ``str`` is also accepted).
            The module is named after the file stem and is not registered in
            ``sys.modules``.
        model_class_name: name of the class to fetch from the loaded module.

    Returns:
        The object bound to *model_class_name* in the freshly executed module.

    Raises:
        ImportError: if importlib cannot build a spec/loader for the path
            (for example, a file suffix it does not recognise).
        AttributeError: if the module has no attribute *model_class_name*.
        Errors raised while reading or executing the file propagate unchanged.

    SECURITY: the file's top-level code runs with the full privileges of the
    current process. Only load model files from a trusted source.
    """
    model_path = Path(model_path)
    spec = importlib.util.spec_from_file_location(model_path.stem, model_path)
    # spec_from_file_location returns None when no loader matches the path.
    # Without this check the failure would surface as an opaque AttributeError
    # inside module_from_spec.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a Python module from {model_path}")
    model_arch = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(model_arch)
    return getattr(model_arch, model_class_name)
def get_all_directories(path: Path) -> list:
    """
    Collect the immediate subdirectories of the given path.

    Non-directory entries are dropped. The result keeps ``iterdir`` order,
    which is not sorted.
    """
    return list(filter(lambda entry: entry.is_dir(), path.iterdir()))
def get_network_participants(client: Client):
    """Return the folder names found directly under ``client.datasites``.

    Every immediate subdirectory counts as one participant, except the
    ``apps`` and ``.syft`` folders. Non-directory entries are skipped.
    ``client.datasites`` is assumed to be a ``pathlib.Path`` (TODO confirm
    against ``syftbox.lib.Client``).
    """
    # Compare the entry's *name*, not the path object: a Path never equals a
    # str, so a membership test on the Path itself would exclude nothing.
    exclude_dir = {"apps", ".syft"}
    return [
        entry.name
        for entry in client.datasites.iterdir()
        if entry.is_dir() and entry.name not in exclude_dir
    ]
def validate_launch_config(fl_config: Path) -> bool:
    """
    Validates the `fl_config.json` file.

    The file must parse as a JSON object that contains every key the
    launcher requires.

    Args:
        fl_config: path to the config file (a ``str`` is also accepted).

    Returns:
        True when the config is valid.

    Raises:
        ValueError: if the file is not valid JSON, its top level is not a
            JSON object, or a required key is missing.
        FileNotFoundError: if the file does not exist (raised when
            ``read_json`` opens it).
    """
    config_path = Path(fl_config)
    try:
        config = read_json(config_path)
    except json.JSONDecodeError as err:
        raise ValueError(f"Invalid JSON format in {config_path.resolve()}") from err

    # With a JSON array or string at the top level, the ``in`` checks below
    # would test list elements or substrings instead of keys. Reject it.
    if not isinstance(config, dict):
        raise ValueError(
            f"Expected a JSON object in {config_path.resolve()}, "
            f"got {type(config).__name__}"
        )

    required_keys = [
        "project_name",
        "aggregator",
        "participants",
        "model_arch",
        "model_weight",
        "model_class_name",
        "rounds",
        "epoch",
        "test_dataset",
        "learning_rate",
    ]
    for key in required_keys:
        if key not in config:
            raise ValueError(f"Required key {key} is missing in fl_config.json")
    return True