Skip to content

Commit

Permalink
🚀 Multiple updates
Browse files Browse the repository at this point in the history
Integrate hydra for better config file management.
Beautify codes.
Add md5 of data partitioning to avoid duplicate partition.
  • Loading branch information
KarhouTam committed Aug 10, 2024
1 parent b36c57b commit c68aba8
Show file tree
Hide file tree
Showing 53 changed files with 1,597 additions and 372 deletions.
1,182 changes: 1,180 additions & 2 deletions .environment/poetry.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions .environment/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ url = "https://mirrors.sustech.edu.cn/pypi/simple"
priority = "primary"

[tool.poetry.dependencies]
python = ">=3.10, <3.12"
python = ">=3.10, <=3.12"
torch = "2.2.0"
torchvision = "^0.17.0"
torchaudio = "^2.2.0"
Expand All @@ -28,9 +28,10 @@ scikit-learn = "^1.5.0"
faiss-cpu = "^1.7.4"
pynvml = "^11.5.0"
PyYAML = "^6.0.1"
ray = "^2.24.0"
ray = { extras = ["default"], version = "2.32.0" }
tensorboard = "^2.16.2"
cvxpy = "^1.5.1"
hydra-core = "^1.3.2"

[build-system]
requires = ["poetry-core"]
Expand Down
3 changes: 2 additions & 1 deletion .environment/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ pynvml
PyYAML
ray[default]
tensorboard
cvxpy
cvxpy
hydra-core
32 changes: 18 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,23 @@ About methods of generating federated dastaset, go check [`data/README.md`](data


### Step 2. Run Experiment
❗ Method name should be identical to the `.py` file name in `src/server`.

```sh
python main.py <method> [your_config_file.yml] [method_args...]
python main.py [--config-path, --config-name] [method=<METHOD_NAME> args...]
```
- `method`: The algorithm's name, e.g., `method=fedavg`. ❗ Method name should be identical to the `.py` file name in `src/server`.
- `--config-path`: Relative path to the directory of the config file. Defaults to `config`.
- `--config-name`: Name of `.yaml` config file (w/o the `.yaml` extension). Defaults to `defaults`, which points to `config/defaults.yaml`.

Such as running FedAvg with all defaults.
```sh
python main.py fedavg
python main.py method=fedavg
```
Defaults are set in [`src/utils/constants.py`](src/utils/constants.py)
Defaults are set in both [`config/defaults.yaml`](config/defaults.yaml) and [`src/utils/constants.py`](src/utils/constants.py).

### How To Customize FL method Arguments 🤖
- By modifying config file
- By explicitly setting in CLI, e.g., `python main.py fedprox config/my_cfg.yml --mu 0.01`.
- By explicitly setting in CLI, e.g., `python main.py --config-name my_cfg.yaml method=fedprox fedprox.mu=0.01`.
- By modifying the default value in `src/utils/constants.py/DEFAULT_COMMON_ARGS` or `get_hyperparams()` of the method

⚠ For the same FL method argument, the priority of argument setting is **CLI > Config file > Default value**.
Expand All @@ -170,18 +173,17 @@ class FedProxServer(FedAvgServer):
return parser.parse_args(args_list)

```
and your `.yml` config file has
and your `.yaml` config file has
```yaml
# your_config.yml
# config/your_config.yaml
...
fedprox:
mu: 0.01
```
```shell
python main.py fedprox # fedprox.mu = 1
python main.py fedprox your_config.yml # fedprox.mu = 0.01
python main.py fedprox your_config.yml --mu 10 # fedprox.mu = 10
python main.py fedprox # fedprox.mu = 1
python main.py fedprox --config-name your_config # fedprox.mu = 0.01
```

### Monitor 📈
Expand All @@ -190,7 +192,7 @@ FL-bench supports `visdom` and `tensorboard`.
#### Activate
**👀 NOTE:** You needs to launch `visdom` / `tensorboard` server by yourself.
```yaml
# your config_file.yml
# your_config.yaml
common:
...
visible: tensorboard # options: [null, visdom, tensorboard]
Expand All @@ -210,7 +212,7 @@ common:
This feature can **vastly improve your training efficiency**. At the same time, this feature is user-friendly and easy to use!!!
### Activate (What You ONLY Need To Do)
```yaml
# your_config_file.yml
# your_config.yaml
mode: parallel
parallel:
num_workers: 2 # any positive integer that larger than 1
Expand All @@ -225,7 +227,7 @@ ray start --head [OPTIONS]
```
👀 **NOTE:** You need to keep `num_cpus: null` and `num_gpus: null` in your config file for connecting to a existing `Ray` cluster.
```yaml
# your_config_file.yml
# your_config_file.yaml
# Connect to an existing Ray cluster in localhost.
mode: parallel
parallel:
Expand All @@ -245,13 +247,15 @@ All common arguments have their default value. Go check [`DEFAULT_COMMON_ARGS`](

⚠ Common arguments cannot be set via CLI.

You can also write your own `.yml` config file. I offer you a [template](config/template.yml) in `config` and recommend you to save your config files there also.
You can also write your own `.yaml` config file. I offer you a [template](config/template.yaml) in `config` and recommend you to save your config files there also.

One example: `python main.py fedavg config/template.yaml [cli_method_args...]`

About the default values of specific FL method arguments, go check corresponding `FL-bench/src/server/<method>.py` for the full details.
| Arguments | Type | Description |
| ---------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `--config-path` | `str` | The directory of config files. Defaults to `config`, means `./config`. |
| `--config-name` | `str` | The name of config file (w/o the `.yaml` extension). Defaults to `defaults`, which points to `config/defaults.yaml`. |
| `dataset` | `str` | The name of dataset that experiment run on. |
| `model` | `str` | The model backbone experiment used. |
| `seed` | `int` | Random seed for running experiment. |
Expand Down
3 changes: 3 additions & 0 deletions config/template.yml → config/defaults.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Full explaination are listed on README.md

method: fedavg

mode: parallel # [serial, parallel]

parallel: # It's fine to keep these configs.
Expand Down Expand Up @@ -67,3 +69,4 @@ pfedsim:
# ...

# NOTE: For those unmentioned arguments, the default values are set in `get_hyperparams()` in `class <method>Server` in `src/server/<method>.py`

11 changes: 7 additions & 4 deletions data/tune_ratios_manually.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
"import numpy as np\n",
"import pickle\n",
"import json\n",
"import hashlib\n",
"\n",
"dataset = \"cifar10\"\n",
"\n",
"new_testset_ratio = 0.5\n",
"new_valset_ratio = 0\n",
"\n",
"partition = pickle.load(open(f'{dataset}/partition.pkl', 'rb'))\n",
"partition = pickle.load(open(f\"{dataset}/partition.pkl\", \"rb\"))\n",
"\n",
"for i in range(len(partition[\"data_indices\"])):\n",
" indices = np.concatenate(\n",
Expand All @@ -40,12 +41,14 @@
" new_testset_size + new_valset_size :\n",
" ]\n",
"\n",
"pickle.dump(partition, open(f'{dataset}/partition.pkl', 'wb'))\n",
"pickle.dump(partition, open(f\"{dataset}/partition.pkl\", \"wb\"))\n",
"\n",
"args = json.load(open(f'{dataset}/args.json', 'r'))\n",
"args = json.load(open(f\"{dataset}/args.json\", \"r\"))\n",
"args[\"test_ratio\"] = new_testset_ratio\n",
"args[\"val_ratio\"] = new_valset_ratio\n",
"json.dump(args, open(f'{dataset}/args.json', 'w'), indent=4)"
"json.dump(args, open(f\"{dataset}/args.json\", \"w\"), indent=4)\n",
"with open(f\"{dataset}/partition_md5.txt\", \"w\") as f:\n",
" f.write(hashlib.md5(json.dumps(args.__dict__).encode()).hexdigest())"
]
}
],
Expand Down
11 changes: 6 additions & 5 deletions data/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import pickle
from argparse import Namespace
from pathlib import Path
from typing import List

import numpy as np
import pandas as pd
import torch
import torchvision
from omegaconf import DictConfig
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
Expand Down Expand Up @@ -312,6 +312,8 @@ def __init__(
split = args.emnist_split
elif isinstance(args, dict):
split = args["emnist_split"]
elif isinstance(args, DictConfig):
split = args.emnist_split
train_part = torchvision.datasets.EMNIST(
root, split=split, train=True, download=True
)
Expand Down Expand Up @@ -386,6 +388,8 @@ def __init__(
super_class = args.super_class
elif isinstance(args, dict):
super_class = args["super_class"]
elif isinstance(args, DictConfig):
super_class = args.super_class

if super_class:
# super_class: [sub_classes]
Expand Down Expand Up @@ -580,10 +584,7 @@ def __init__(
self.classes = list(range(len(metadata["classes"])))
self.targets = torch.load(targets_path)
self.pre_transform = transforms.Compose(
[
transforms.Resize(metadata["image_size"]),
transforms.ToTensor(),
]
[transforms.Resize(metadata["image_size"]), transforms.ToTensor()]
)
self.test_data_transform = test_data_transform
self.test_target_transform = test_target_transform
Expand Down
11 changes: 11 additions & 0 deletions generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import pickle
import random
import hashlib
from copy import deepcopy
from collections import Counter
from argparse import ArgumentParser
Expand Down Expand Up @@ -38,6 +39,13 @@ def main(args):
if not os.path.isdir(dataset_root):
os.mkdir(dataset_root)

if os.path.isfile(dataset_root / "partition_md5.txt"):
with open(dataset_root / "partition_md5.txt", "r") as f:
md5 = f.read()
if md5 == hashlib.md5(json.dumps(args.__dict__).encode()).hexdigest():
print("Partition file already exists. Skip partitioning.")
return

client_num = args.client_num
partition = {"separation": None, "data_indices": [[] for _ in range(client_num)]}
# x: num of samples,
Expand Down Expand Up @@ -318,6 +326,9 @@ def _idx_2_domain_label(index):
with open(dataset_root / "args.json", "w") as f:
json.dump(prune_args(args), f, indent=4)

with open(dataset_root / "partition_md5.txt", "w") as f:
f.write(hashlib.md5(json.dumps(args.__dict__).encode()).hexdigest())


if __name__ == "__main__":
parser = ArgumentParser()
Expand Down
80 changes: 32 additions & 48 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,23 @@
import inspect
from pathlib import Path

import yaml
import pynvml
import hydra
from omegaconf import DictConfig

from src.server.fedavg import FedAvgServer

FLBENCH_ROOT = Path(__file__).parent.absolute()
if FLBENCH_ROOT not in sys.path:
sys.path.append(FLBENCH_ROOT.as_posix())


from src.utils.tools import parse_args

if __name__ == "__main__":
if len(sys.argv) < 2:
raise RuntimeError(
"No method is specified. Run like `python main.py <method> [config_file_relative_path] [cli_method_args ...]`,",
"e.g., python main.py fedavg config/template.yml`",
)

method_name = sys.argv[1]

config_file_path = None
cli_method_args = []
if len(sys.argv) > 2:
if ".yaml" in sys.argv[2] or ".yml" in sys.argv[2]: # ***.yml or ***.yaml
config_file_path = Path(sys.argv[2]).absolute()
cli_method_args = sys.argv[3:]
else:
cli_method_args = sys.argv[2:]

@hydra.main(config_path="config", config_name="defaults", version_base=None)
def main(config: DictConfig):
method_name = config.method.lower()

try:
fl_method_server_module = importlib.import_module(f"src.server.{method_name}")
except:
Expand All @@ -47,19 +35,7 @@

get_method_hyperparams_func = getattr(server_class, f"get_hyperparams", None)

config_file_args = None
if config_file_path is not None and os.path.isfile(config_file_path):
with open(config_file_path, "r") as f:
try:
config_file_args = yaml.safe_load(f)
except:
raise TypeError(
f"Config file's type should be yaml, now is {config_file_path}"
)

ARGS = parse_args(
config_file_args, method_name, get_method_hyperparams_func, cli_method_args
)
config = parse_args(config, method_name, get_method_hyperparams_func)

# target method is not inherited from FedAvgServer
if server_class.__bases__[0] != FedAvgServer and server_class != FedAvgServer:
Expand All @@ -68,22 +44,21 @@
get_parent_method_hyperparams_func = getattr(
parent_server_class, f"get_hyperparams", None
)
# class name: ***Server, only want ***
# class name: <METHOD_NAME>Server, only want <METHOD_NAME>
parent_method_name = parent_server_class.__name__.lower()[:-6]
# extract the hyperparams of parent method
PARENT_ARGS = parse_args(
config_file_args,
parent_method_name,
get_parent_method_hyperparams_func,
cli_method_args,
# extract the hyperparameters of the parent method
parent_config = parse_args(
config, parent_method_name, get_parent_method_hyperparams_func
)
setattr(
config, parent_method_name, getattr(parent_config, parent_method_name)
)
setattr(ARGS, parent_method_name, getattr(PARENT_ARGS, parent_method_name))

if ARGS.mode == "parallel":
if config.mode == "parallel":
import ray

num_available_gpus = ARGS.parallel.num_gpus
num_available_cpus = ARGS.parallel.num_cpus
num_available_gpus = config.parallel.num_gpus
num_available_cpus = config.parallel.num_cpus
if num_available_gpus is None:
pynvml.nvmlInit()
num_total_gpus = pynvml.nvmlDeviceGetCount()
Expand All @@ -97,7 +72,7 @@
num_available_cpus = os.cpu_count()
try:
ray.init(
address=ARGS.parallel.ray_cluster_addr,
address=config.parallel.ray_cluster_addr,
namespace=method_name,
num_cpus=num_available_cpus,
num_gpus=num_available_gpus,
Expand All @@ -107,14 +82,23 @@
# have existing cluster
# then no pass num_cpus and num_gpus
ray.init(
address=ARGS.parallel.ray_cluster_addr,
address=config.parallel.ray_cluster_addr,
namespace=method_name,
ignore_reinit_error=True,
)

cluster_resources = ray.cluster_resources()
ARGS.parallel.num_cpus = cluster_resources["CPU"]
ARGS.parallel.num_gpus = cluster_resources["GPU"]
config.parallel.num_cpus = cluster_resources["CPU"]
config.parallel.num_gpus = cluster_resources["GPU"]

server = server_class(args=ARGS)
server = server_class(args=config)
server.run()


if __name__ == "__main__":
# For gather the Fl-bench logs and hydra logs
# Otherwise the hydra logs are stored in ./outputs/...
sys.argv.append(
"hydra.run.dir=./out/${method}/${common.dataset}/${now:%Y-%m-%d-%H-%M-%S}"
)
main()
Loading

0 comments on commit c68aba8

Please sign in to comment.