-
-
Notifications
You must be signed in to change notification settings - Fork 80
/
main.py
65 lines (51 loc) · 2.15 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import importlib
import sys
import inspect
from pathlib import Path
import hydra
from omegaconf import DictConfig
from src.server.fedavg import FedAvgServer
# Absolute path of the directory containing this script (the FL-bench root).
FLBENCH_ROOT = Path(__file__).parent.absolute()
# Make the FL-bench root importable (for the `src.*` imports).
# sys.path holds strings, so compare against the same string form we append;
# testing the Path object itself never matches and would add a duplicate
# entry every time this module is (re)loaded.
# NOTE(review): the `src.server.fedavg` import above runs before this setup;
# it resolves when the script is launched directly because Python places the
# script's directory at sys.path[0].
if FLBENCH_ROOT.as_posix() not in sys.path:
    sys.path.append(FLBENCH_ROOT.as_posix())
from src.utils.tools import parse_args
def _load_server_class(method_name: str) -> type:
    """Import ``src.server.<method_name>`` and return its server class.

    The server class is the class in that module whose name, lower-cased,
    equals ``<method_name>server`` (e.g. ``FedAvgServer`` for ``fedavg``).

    Args:
        method_name: lower-cased method name taken from the config.

    Returns:
        The matching server class.

    Raises:
        ImportError: if the module cannot be imported (the underlying error
            is chained as ``__cause__``), or if it defines no matching class.
    """
    module_path = f"src.server.{method_name}"
    try:
        module = importlib.import_module(module_path)
    except ImportError as err:
        raise ImportError(f"Can't import `{module_path}`.") from err

    expected_name = method_name + "server"
    server_class = next(
        (
            member
            for member_name, member in inspect.getmembers(module, inspect.isclass)
            if member_name.lower() == expected_name
        ),
        None,
    )
    if server_class is None:
        raise ImportError(
            f"`{module_path}` defines no class named `{expected_name}` "
            "(case-insensitive)."
        )
    return server_class


@hydra.main(config_path="config", config_name="defaults", version_base=None)
def main(config: DictConfig):
    """Build the selected FL method's server from ``config`` and run it.

    ``config.method`` (case-insensitive) selects ``src.server.<method>``. The
    method's server class may provide ``get_hyperparams``, which is forwarded
    to ``parse_args`` to produce the final config.

    Args:
        config: Hydra-composed configuration; must provide ``method``.

    Raises:
        ImportError: if the method's module or server class cannot be found.
    """
    method_name = config.method.lower()
    server_class = _load_server_class(method_name)

    # Method-specific hyperparameter hook (optional; None when absent).
    get_method_hyperparams_func = getattr(server_class, "get_hyperparams", None)
    config = parse_args(config, method_name, get_method_hyperparams_func)

    # When the method's server does not derive *directly* from FedAvgServer
    # (i.e. it extends another method's server), also parse the parent
    # method's hyperparameters and store them on `config` under the parent
    # method's name.
    parent_server_class = server_class.__bases__[0]
    if (
        server_class is not FedAvgServer
        and parent_server_class is not FedAvgServer
        and hasattr(parent_server_class, "get_hyperparams")
    ):
        get_parent_method_hyperparams_func = parent_server_class.get_hyperparams
        # Server classes are named <MethodName>Server; keep only <methodname>.
        parent_class_name = parent_server_class.__name__.lower()
        parent_method_name = (
            parent_class_name[: -len("server")]
            if parent_class_name.endswith("server")
            else parent_class_name
        )
        parent_config = parse_args(
            config, parent_method_name, get_parent_method_hyperparams_func
        )
        setattr(
            config, parent_method_name, getattr(parent_config, parent_method_name)
        )

    server = server_class(args=config)
    server.run()
if __name__ == "__main__":
    # Point Hydra's run directory at ./out/<method>/<dataset>/<timestamp> so
    # the FL-bench logs and the Hydra logs are gathered in one place;
    # otherwise Hydra stores its logs under ./outputs/...
    # Only add the override when the user has not already supplied their own
    # `hydra.run.dir` (possibly prefixed with `+`/`++`); appending a second
    # value would conflict with or silently clobber the explicit one.
    user_set_run_dir = any(
        arg.lstrip("+").startswith("hydra.run.dir=") for arg in sys.argv[1:]
    )
    if not user_set_run_dir:
        sys.argv.append(
            "hydra.run.dir=./out/${method}/${dataset.name}/${now:%Y-%m-%d-%H-%M-%S}"
        )
    main()