
Commit

Merge pull request #47 from futabato/feature/issue-46
Integrate RabbitMQ
futabato authored Oct 21, 2024
2 parents 4a903ff + 4cb896e commit af7f2dc
Showing 9 changed files with 373 additions and 2 deletions.
14 changes: 13 additions & 1 deletion compose.yml
@@ -18,4 +18,16 @@ services:
    volumes:
      - ./:/workspace
    stdin_open: true
    tty: true
  rabbitmq:
    # With a "-management" image tag, the Management Plugin (a web UI) becomes available on port 15672
    # Management Plugin documentation: https://www.rabbitmq.com/management.html
    # The image specified here is the latest available at the time of writing
    image: rabbitmq:3-management
    container_name: rabbitmq
    ports:
      - 5672:5672
      - 15672:15672
    # Persist RabbitMQ data
    volumes:
      - ./docker/rabbitmq/data:/var/lib/rabbitmq
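As a quick connectivity check (not part of this commit), a minimal pika smoke test against this service might look like the sketch below. It assumes the broker's default guest/guest credentials, which the client code later in this commit also uses, and that it runs in a container on the same compose network so the hostname rabbitmq resolves.

import pika

# Minimal connectivity check (assumption: run inside a container on the
# same compose network; from the host, use host="localhost" instead)
connection = pika.BlockingConnection(
    pika.ConnectionParameters(
        host="rabbitmq",
        credentials=pika.PlainCredentials("guest", "guest"),
    )
)
print("Connected:", connection.is_open)
connection.close()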
1 change: 1 addition & 0 deletions config/client/client_1.yaml
@@ -0,0 +1 @@
client_id: 1
1 change: 1 addition & 0 deletions config/client/client_2.yaml
@@ -0,0 +1 @@
client_id: 2
1 change: 1 addition & 0 deletions config/client/default.yaml
@@ -0,0 +1 @@
client_id: 0
1 change: 1 addition & 0 deletions config/default.yaml
@@ -2,6 +2,7 @@ defaults:
  - mlflow: default
  - train: default
  - federatedlearning: default
  - client: default
  - override hydra/sweeper: optuna
hydra:
  run:
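With the new client config group in the defaults list, a specific client configuration is presumably selected at launch through Hydra's group-override syntax, e.g.

python src/federatedlearning/client/consumer.py client=client_1

which loads config/client/client_1.yaml in place of config/client/default.yaml (the entry point path is taken from this commit's file list; the exact launch command is an assumption).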
18 changes: 17 additions & 1 deletion poetry.lock


1 change: 1 addition & 0 deletions pyproject.toml
@@ -32,6 +32,7 @@ types-pyyaml = "^6.0.12.20240311"
tslearn = "^0.6.3"
matplotlib-fontja = "^1.0.0"
japanize-matplotlib = "^1.1.3"
pika = "^1.3.2"

[tool.poetry.group.dev.dependencies]
pyproject-flake8 = "^6.0.0.post1"
137 changes: 137 additions & 0 deletions src/federatedlearning/client/consumer.py
@@ -0,0 +1,137 @@
import copy
import pickle

import hydra
import pika
import torch
import torch.nn as nn
from federatedlearning.client.training import LocalUpdate
from federatedlearning.datasets.common import get_dataset
from federatedlearning.models.cnn import CNNMnist
from omegaconf import DictConfig


class FLClient:
def __init__(
self,
cfg: DictConfig,
host: str = "rabbitmq",
local_queue: str = "local_model_queue",
        exchange_name: str = "global_model_exchange",
username: str = "guest",
password: str = "guest",
):
self.host = host
self.credentials = pika.PlainCredentials(
username=username, password=password
)
self.local_queue = local_queue
        self.exchange_name = exchange_name
self.connection = None
self.channel = None

        # Experiment configuration
self.cfg = cfg
self.client_id = cfg.client.client_id
# Load the dataset and partition it according to the client groups
(
self.train_dataset,
self.test_dataset,
self.client_groups,
) = get_dataset(self.cfg)
# Determine the computing device (GPU or CPU)
self.device: torch.device = (
torch.device(f"cuda:{cfg.train.gpu}")
if cfg.train.gpu is not None and cfg.train.gpu >= 0
else torch.device("cpu")
)
self.num_epochs = 3

self._connect()

def _connect(self):
self.connection = pika.BlockingConnection(
pika.ConnectionParameters(self.host, credentials=self.credentials)
)
self.channel = self.connection.channel()
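        # Declare an exclusive, auto-named queue and bind it to the global
        # model exchange so this client receives every broadcast model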
self.result = self.channel.queue_declare(queue="", exclusive=True)
self.global_queue = self.result.method.queue
self.channel.queue_bind(
exchange=self.exchange_name, queue=self.global_queue
)

def receive_global_model(self):
def callback(ch, method, properties, body):
round = properties.headers.get("round")
state_dict = pickle.loads(body)
global_model = CNNMnist(self.cfg)
global_model.load_state_dict(state_dict)
global_model.to(self.device)
print(" [x] Received initial global model")

            # Train locally
local_model = self.local_train(global_model, round)

            # Send the local model (updated weights) to the server
self.send_local_model(model=local_model, client_id=self.client_id)

self.channel.basic_consume(
queue=self.global_queue,
on_message_callback=callback,
auto_ack=True,
)
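        # NOTE: with auto_ack=True, messages are acknowledged on delivery,
        # so a crash during local training loses the in-flight message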
print(" [*] Waiting for initial global model. To exit press CTRL+C")
try:
self.channel.start_consuming()
except KeyboardInterrupt:
self.stop_consuming()

def local_train(self, global_model: nn.Module, round: int):
print("[x] Training model locally...")
local_model = LocalUpdate(
cfg=self.cfg,
dataset=self.train_dataset,
client_id=self.client_id,
idxs=self.client_groups[self.client_id],
)
weight, loss = local_model.update_weights(
model=copy.deepcopy(global_model), global_round=round
)

return weight

    def send_local_model(self, model: dict, client_id: int):
        # "model" here is the state_dict of weights returned by local_train
        serialized_model = pickle.dumps(model)
headers = {"client_id": client_id}
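        # Publishing to the default exchange ("") routes the message
        # directly to the queue named by routing_key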
self.channel.basic_publish(
exchange="",
routing_key=self.local_queue,
body=serialized_model,
properties=pika.BasicProperties(headers=headers),
)
print(f" [x] Sent updated local model from client {client_id}")

def stop_consuming(self):
if self.channel:
self.channel.stop_consuming()

def close(self):
if self.connection:
self.connection.close()


@hydra.main(
version_base="1.1", config_path="/workspace/config", config_name="default"
)
def main(cfg: DictConfig):
client = FLClient(cfg)
try:
        # Wait for and receive the global model from the server
client.receive_global_model()
finally:
        # Close the connection on exit
client.close()


if __name__ == "__main__":
main()
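The server-side producer this client pairs with is not part of this diff. As a rough sketch of the counterpart that FLClient expects, one might write something like the following; the fanout exchange type, the elided aggregation step, and the blocking single-threaded loop are all assumptions, with the queue and exchange names simply mirroring the client's defaults.

import pickle

import pika

# Hypothetical server-side counterpart (not in this commit): broadcasts the
# global model to every client and collects their local updates.
connection = pika.BlockingConnection(
    pika.ConnectionParameters(
        "rabbitmq", credentials=pika.PlainCredentials("guest", "guest")
    )
)
channel = connection.channel()

# A fanout exchange delivers a copy to every bound queue; the client binds
# one exclusive queue per process, so each client receives the broadcast.
channel.exchange_declare(
    exchange="global_model_exchange", exchange_type="fanout"
)
# Shared queue the clients publish their local weights to.
channel.queue_declare(queue="local_model_queue")


def broadcast_global_model(state_dict: dict, round_number: int) -> None:
    channel.basic_publish(
        exchange="global_model_exchange",
        routing_key="",  # ignored by fanout exchanges
        body=pickle.dumps(state_dict),
        properties=pika.BasicProperties(headers={"round": round_number}),
    )


def on_local_model(ch, method, properties, body):
    client_id = properties.headers.get("client_id")
    local_weights = pickle.loads(body)
    print(f" [x] Received {len(local_weights)} tensors from client {client_id}")
    # ... aggregate here (e.g., FedAvg) and broadcast the next round ...


channel.basic_consume(
    queue="local_model_queue",
    on_message_callback=on_local_model,
    auto_ack=True,
)
channel.start_consuming()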