dataset.py
import torch
from torch.utils.data import random_split, DataLoader
from torchvision.transforms import ToTensor, Normalize, Compose
from torchvision.datasets import MNIST

def get_mnist(data_path: str = "./data"):
    """Download MNIST and apply a minimal transformation."""
    # ToTensor() scales pixels to [0, 1]; the Normalize constants are the
    # standard MNIST per-channel mean and std
    tr = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    trainset = MNIST(data_path, train=True, download=True, transform=tr)
    testset = MNIST(data_path, train=False, download=True, transform=tr)
    return trainset, testset

def prepare_dataset(num_partitions: int, batch_size: int, val_ratio: float = 0.1):
    """Download MNIST and generate IID partitions."""

    # download MNIST in case it's not already on the system
    trainset, testset = get_mnist()

    # split trainset into `num_partitions` trainsets (one per client)
    # figure out the number of training examples per partition
    num_images = len(trainset) // num_partitions

    # a list of partition lengths (all partitions are of equal size); any
    # remainder left by the integer division goes to the last partition so
    # the lengths sum to len(trainset), as random_split requires
    partition_len = [num_images] * num_partitions
    partition_len[-1] += len(trainset) - num_images * num_partitions

    # split randomly. This returns a list of trainsets, each with `num_images`
    # training examples. Note this is the simplest way of splitting this
    # dataset. A more realistic (but more challenging) partitioning would
    # induce heterogeneity across partitions, for example: each client getting
    # a different number of training examples, or each client having a
    # different distribution over the labels (maybe even some clients not
    # having a single training example for certain classes). If you are
    # curious, look up Dirichlet (LDA) or pathological dataset partitioning
    # in FL; a place to start is https://arxiv.org/abs/1909.06335, and a
    # hedged sketch of the Dirichlet approach appears at the bottom of this file.
    trainsets = random_split(
        trainset, partition_len, torch.Generator().manual_seed(2023)
    )

    # create dataloaders with train+val support
    trainloaders = []
    valloaders = []
    # for each train set, put aside some training examples for validation
    for trainset_ in trainsets:
        num_total = len(trainset_)
        num_val = int(val_ratio * num_total)
        num_train = num_total - num_val

        for_train, for_val = random_split(
            trainset_, [num_train, num_val], torch.Generator().manual_seed(2023)
        )

        # construct data loaders and append them to their respective lists.
        # In this way, the i-th client will get the i-th element in the
        # trainloaders list and the i-th element in the valloaders list
        trainloaders.append(
            DataLoader(for_train, batch_size=batch_size, shuffle=True, num_workers=2)
        )
        valloaders.append(
            DataLoader(for_val, batch_size=batch_size, shuffle=False, num_workers=2)
        )

    # We leave the test set intact (i.e. we don't partition it). It stays on
    # the server side and will be used to evaluate the performance of the
    # global model after each round. Please note that a more realistic setting
    # would instead use a validation set on the server for this purpose and
    # only use the test set after the final round. Also, in some settings
    # (especially outside simulation) it might not be feasible to construct a
    # validation set on the server side, so evaluating the global model can
    # only be done by the clients (see the comment in main.py above the
    # strategy definition for more details on this).
    testloader = DataLoader(testset, batch_size=128)

    return trainloaders, valloaders, testloader
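

# --- Optional: non-IID partitioning sketch ----------------------------------
# The comment above `random_split` in `prepare_dataset` mentions Dirichlet
# (LDA) partitioning as a more realistic alternative. What follows is a
# minimal sketch of that idea, not part of the original pipeline: it assumes
# the dataset exposes integer labels via `dataset.targets` (true for
# torchvision's MNIST) and the helper name `dirichlet_partition` is purely
# illustrative.
import numpy as np
from torch.utils.data import Subset


def dirichlet_partition(dataset, num_partitions: int, alpha: float = 0.5, seed: int = 2023):
    """Split `dataset` into label-skewed partitions using a Dirichlet prior.

    Lower `alpha` means more heterogeneous label distributions per client.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(dataset.targets)
    num_classes = int(labels.max()) + 1
    # shuffled indices of each class
    class_indices = [rng.permutation(np.where(labels == c)[0]) for c in range(num_classes)]
    partition_indices = [[] for _ in range(num_partitions)]
    for idx in class_indices:
        # sample, for this class, the share of examples each client receives
        proportions = rng.dirichlet([alpha] * num_partitions)
        # turn the shares into split points inside this class' index array
        split_points = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for part, shard in zip(partition_indices, np.split(idx, split_points)):
            part.extend(shard.tolist())
    return [Subset(dataset, indices) for indices in partition_indices]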
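

# A quick usage sketch to sanity-check the partitioning. The numbers below
# (100 clients, batch size 32) are illustrative only, not values taken from
# the rest of the project.
if __name__ == "__main__":
    trainloaders, valloaders, testloader = prepare_dataset(
        num_partitions=100, batch_size=32
    )
    # with the default val_ratio=0.1, each of the 100 clients gets
    # 540 training and 60 validation images out of MNIST's 60k train split
    print(f"clients: {len(trainloaders)}")
    print(f"client 0 train size: {len(trainloaders[0].dataset)}")
    print(f"client 0 val size: {len(valloaders[0].dataset)}")
    print(f"server test size: {len(testloader.dataset)}")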