Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,4 @@

#### For the benchmarking framework, please visit the `public` folder. The revamped version is in the `src` folder.

## This is an open source effort for benchmarking split inference. If you would like to be a contributor in this benchmarking effort then please reach out to abhi24@mit.edu

### Contributors to the codebase so far -
#### 1. Abhishek Singh
#### 2. Justin Yu
#### 3. John Mose
#### 4. Rohan Sukumaran
#### 5. Emily Zhang
#### 6. Ethan Garza
## This is an open source effort for benchmarking split inference.
23 changes: 14 additions & 9 deletions src/algos/input_model_optimization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from utils.config_utils import process_config
from utils.config_utils import combine_configs, process_config
from algos.uniform_noise import UniformNoise
from algos.nopeek import NoPeek
from algos.simba_algo import SimbaAttack
Expand Down Expand Up @@ -31,21 +31,24 @@ def initialize(self, config):
self.attribute = config["attribute"]
self.obf_model_name = config["target_model"]

self.img_size = config["img_size"]
# load obfuscator model
target_exp_config = json.load(open(config["target_model_config"])) #config_loader(config["model_config"])
system_config = json.load(open("./configs/system_config.json")) #config_loader(config["model_config"])
target_config = process_config(system_config, target_exp_config)
target_exp_config["client"]["challenge"] = True
target_config = process_config(combine_configs(system_config, target_exp_config))
self.target_config = target_config

from interface import load_algo
self.obf_model = load_algo(target_config, self.utils)

wts_path = self.target_config["model_path"] + "/client_model.pt"
wts = torch.load(wts_path)
if isinstance(self.obf_model.client_model, torch.nn.DataParallel): # type: ignore
self.obf_model.client_model.module.load_state_dict(wts)
else:
self.obf_model.client_model.load_state_dict(wts)
if not config["target_model"] == "gaussian_blur":
wts_path = config["target_model_path"]
wts = torch.load(wts_path)
if isinstance(self.obf_model.client_model, torch.nn.DataParallel): # type: ignore
self.obf_model.client_model.module.load_state_dict(wts)
else:
self.obf_model.client_model.load_state_dict(wts)

self.obf_model.enable_logs(False)
self.obf_model.set_detached(False)
Expand Down Expand Up @@ -131,7 +134,9 @@ def forward(self, items):
rand_inp = rand_inp.to(self.utils.device)

optim.zero_grad()
ys = gen_model(rand_inp)[:,:,:128, :128]
ys = gen_model(rand_inp)
# resize the width and height to self.img_size
ys = torch.nn.functional.interpolate(ys, size=(self.img_size, self.img_size), mode='bilinear', align_corners=True)
out = self.obf_model({"x": ys})
loss = self.loss_fn(out, z)

Expand Down
24 changes: 14 additions & 10 deletions src/algos/input_optimization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from utils.config_utils import process_config
from utils.config_utils import combine_configs, process_config
from algos.uniform_noise import UniformNoise
from algos.nopeek import NoPeek
from algos.simba_algo import SimbaAttack
Expand All @@ -16,21 +16,24 @@ def initialize(self, config):
self.attribute = config["attribute"]
self.model_name = config["target_model"]

self.img_size = config["img_size"]
# load obfuscator model
target_exp_config = json.load(open(config["target_model_config"])) #config_loader(config["model_config"])
system_config = json.load(open("./configs/system_config.json")) #config_loader(config["model_config"])
target_config = process_config(system_config, target_exp_config)
target_exp_config["client"]["challenge"] = True
target_config = process_config(combine_configs(system_config, target_exp_config))
self.target_config = target_config

from interface import load_algo
self.model = load_algo(target_config, self.utils)

wts_path = self.target_config["model_path"] + "/client_model.pt"
wts = torch.load(wts_path)
if isinstance(self.model.client_model, torch.nn.DataParallel): # type: ignore
self.model.client_model.module.load_state_dict(wts)
else:
self.model.client_model.load_state_dict(wts)

if not config["target_model"] == "gaussian_blur":
wts_path = config["target_model_path"]
wts = torch.load(wts_path)
if isinstance(self.model.client_model, torch.nn.DataParallel): # type: ignore
self.model.client_model.module.load_state_dict(wts)
else:
self.model.client_model.load_state_dict(wts)

self.metric = MetricLoader(data_range=1)

Expand Down Expand Up @@ -89,7 +92,8 @@ def forward(self, items):

for _ in range(self.iters):
optim.zero_grad()
out = self.model({"x": ys})
ys = torch.nn.functional.interpolate(ys, size=(self.img_size, self.img_size), mode='bilinear', align_corners=True)
out = self.model({"x": ys})
loss = self.loss_fn(out, z)

ssim = self.metric.ssim(img, ys)
Expand Down
13 changes: 9 additions & 4 deletions src/algos/siamese_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def __init__(self, config, utils) -> None:
self.initialize(config)

def initialize(self, config):
if config.get("challenge", False):
self.is_attacked = True
else:
self.is_attacked = False
self.client_model = self.init_client_model(config)
self.put_on_gpus()
self.utils.register_model("client_model", self.client_model)
Expand All @@ -47,11 +51,12 @@ def initialize(self, config):

def forward(self, items):
x = items["x"]
pred_lbls = items["pred_lbls"]
self.z = self.client_model(x)
self.contrastive_loss = self.loss(self.z, pred_lbls)
self.utils.logger.add_entry(self.mode + "/" + self.ct_loss_tag,
self.contrastive_loss.item())
if not self.is_attacked:
pred_lbls = items["pred_lbls"]
self.contrastive_loss = self.loss(self.z, pred_lbls)
self.utils.logger.add_entry(self.mode + "/" + self.ct_loss_tag,
self.contrastive_loss.item())
# z will be detached to prevent any grad flow from the client
z = self.z.detach()
z.requires_grad = True
Expand Down
3 changes: 3 additions & 0 deletions src/algos/supervised_decoder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
from algos.simba_algo import SimbaAttack
from models.image_decoder import Decoder
from utils.metrics import MetricLoader
Expand All @@ -11,6 +12,7 @@ def __init__(self, config, utils):
def initialize(self, config):
self.attribute = config["attribute"]
self.metric = MetricLoader()
self.img_size = config["img_size"]
if self.attribute == "data":
self.loss_tag = "recons_loss"

Expand Down Expand Up @@ -48,6 +50,7 @@ def initialize(self, config):
def forward(self, items):
    """Decode the intermediate activations and score them against the input.

    Reads ``items["z"]`` and ``items["x"]``. Stores the decoder output,
    resized to ``self.img_size`` x ``self.img_size``, on
    ``self.reconstruction``; the reference input on ``self.orig``; and the
    value of ``self.loss_fn(self.reconstruction, self.orig)`` on ``self.loss``.
    """
    z = items["z"]
    self.reconstruction = self.model(z)
    # The resize has to act on the decoder output: the earlier statement read
    # an undefined local ``ys`` and raised NameError on every call.
    # NOTE(review): assumes the intent is to bring the reconstruction to the
    # configured image size before comparing it with items["x"] -- confirm.
    self.reconstruction = torch.nn.functional.interpolate(
        self.reconstruction,
        size=(self.img_size, self.img_size),
        mode='bilinear',
        align_corners=True,
    )
    self.orig = items["x"]

    self.loss = self.loss_fn(self.reconstruction, self.orig)
Expand Down
3 changes: 2 additions & 1 deletion src/configs/split_inference.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"experiment_type": "time_profiling",
"method": "split_inference",
"client": {"model_name": "resnet18", "split_layer": 6,
"client": {"model_name": "resnet18", "split_layer": 1,
"pretrained": false, "optimizer": "adam", "lr": 3e-4},
"server": {"model_name": "resnet18", "split_layer": 6, "logits": 2, "pretrained": false,
"lr": 3e-4, "optimizer": "adam"},
Expand Down
4 changes: 2 additions & 2 deletions src/configs/system_config.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"dataset_path": "/home/justinyu/fairface/",
"experiments_folder": "/home/justinyu/experiments/",
"dataset_path": "/u/abhi24/Datasets/Faces/fairface/",
"experiments_folder": "/u/abhi24/Workspace/simba/experiments/",
"gpu_devices": [1, 3]
}
64 changes: 40 additions & 24 deletions src/data/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,15 @@ def load_label(self):

def __getitem__(self, index):
filepath = self.indicies[index]

# Added all cases - Train, valid, challenge and when no arg is specified
if self.config["train"] is True:
filename = "train/"+str(filepath)+".jpg"
elif self.config["challenge"] is True: # check if this is actually present in the config file. If not, lets add it - (Rohan)
filename = "challenge/"+str(filepath)+".jpg"
elif self.config["val"] is True:
filename = "val/"+str(filepath)+".jpg"
else:
filename = filepath.split('/')[-1].split('.')[0]

# if self.config["train"] is True:
# filename = "train/"+str(filepath)+".jpg"
# elif self.config.get("challenge", False): # check if this is actually present in the config file. If not, lets add it - (Rohan)
# filename = "challenge/"+str(filepath)+".jpg"
# else:
# filename = "val/"+str(filepath)+".jpg"
filename = str(filepath)

img = self.load_image(filepath)
img = self.transforms(img)
pred_label = self.load_label(filepath, "pred")
Expand All @@ -111,9 +109,11 @@ def __getitem__(self, index):
else:
privacy_label = self.load_label(filepath, "privacy")
privacy_label = self.to_tensor(privacy_label)
# print(img.shape, pred_label.shape, privacy_label.shape, filepath, filename)
sample = {'img': img, 'prediction_label': pred_label,
'private_label': privacy_label,
'filepath': filepath, 'filename': filename}
return sample

def __len__(self):
    """Return the number of entries in ``self.indicies``."""
    return len(self.indicies)
Expand Down Expand Up @@ -211,8 +211,10 @@ def __init__(self, config):
self.data_to_run_on = None

if config["train"] is True:
config["path"] += "/train"
self.data_to_run_on = self.train_dict
else:
config["path"] += "/val"
self.data_to_run_on = self.val_dict

super(Cifar10, self).__init__(config)
Expand Down Expand Up @@ -244,10 +246,7 @@ def load_label(self, filepath, label_type):
except:
return 1, 1




class CelebA(datasets.CelebA):
class CelebA(datasets.CelebA, BaseDataset):
def __init__(self, config):
config = deepcopy(config)
data_split = "train" if config["train"] else "valid"
Expand All @@ -262,10 +261,10 @@ def __init__(self, config):
'wavy_hair': 33,
'big_nose': 7,
'mouth_open': 21}
if self.prediction_attribute in self.attr_indices.keys():
if self.prediction_attribute in self.attr_indices.keys() or self.prediction_attribute == 'data':
target_pred = 'attr'
else:
raise ValueError("Prediction Attribute {} is not supported.".format(self.prediction_attribute))
# else:
# raise ValueError("Prediction Attribute {} is not supported.".format(self.prediction_attribute))
if self.protected_attribute in self.attr_indices.keys():
target_protect = 'attr'
target_type = [target_pred, target_protect]
Expand Down Expand Up @@ -313,10 +312,16 @@ def __init__(self, config):
self.format = "pt" # hardcoded for now
self.set_filepaths(config["challenge_dir"])
self.protected_attribute = config["protected_attribute"]

self.config = config
self.transforms = config["transforms"]
if config["dataset"] == "fairface":
self.dataset_obj = FairFace(config)
elif config["dataset"] == "celeba":
self.dataset_obj = CelebA(config)
self.dataset_obj.format = "jpg"
elif config["dataset"] == "cifar10":
self.dataset_obj = Cifar10(config)
self.dataset_obj.format = "jpg"
else:
print("not implemented yet")
exit()
Expand All @@ -334,11 +339,22 @@ def get_imgpath(self, filename):
""" The challenge folder only consists of filename
but the corresponding file in the dataset is obtained here
"""
filename = "/" + filename + "." + self.dataset_obj.format
l = list(filter(lambda x: x.endswith(filename),
self.dataset_obj.filepaths))
assert len(l) == 1
return l[0]
if self.config["dataset"] == "celeba":
filename = os.path.join(self.dataset_obj.root, self.dataset_obj.base_folder, "img_align_celeba", filename + ".jpg")
return filename
elif self.config["dataset"] == "fairface":
filename = "/" + filename + "." + self.dataset_obj.format
l = list(filter(lambda x: x.endswith(filename),
self.dataset_obj.filepaths))
assert len(l) == 1
return l[0]
elif self.config["dataset"] == "cifar10":
filename = self.dataset_obj.indicies[int(filename)]
return filename
else:
print("not implemented yet", self.config["dataset"])
exit()


def __getitem__(self, index):
filepath = self.filepaths[index]
Expand All @@ -347,7 +363,7 @@ def __getitem__(self, index):
imgpath = self.get_imgpath(filename)
img = self.load_image(imgpath)
if self.protected_attribute == "data":
privacy_label = self.dataset_obj.transforms(img)
privacy_label = self.transforms(img)
else:
privacy_label = self.load_label(imgpath, "privacy")
privacy_label = self.to_tensor(privacy_label)
Expand Down
20 changes: 13 additions & 7 deletions src/data/download_pytorch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ def translate_into_images(dataset, output_dir, prefix=None):
new_dir = os.path.join(output_dir, prefix)
if not os.path.exists(new_dir):
os.mkdir(new_dir)
skip=False
if not os.path.exists(output_dir):
os.mkdir(output_dir)
else:
if os.listdir(output_dir):
skip=True
num_of_images = dataset.data.shape[0]
filenames = []
filepaths = []
Expand All @@ -31,8 +35,10 @@ def translate_into_images(dataset, output_dir, prefix=None):
if prefix:
filename = os.path.join(prefix, filename)
filepath = os.path.join(output_dir, filename)
img = Image.fromarray(dataset.data[i])
img.save(filepath)
if not skip:
print(filepath)
img = Image.fromarray(dataset.data[i])
img.save(filepath)
filenames.append(filename)
filepaths.append(filepath)
return filenames, filepaths
Expand Down Expand Up @@ -92,27 +98,27 @@ def load_cifar_as_dict(output_dir):

trainset, valset = load_CIFAR10_dataset(output_dir)

train_files, train_filepaths = translate_into_images(trainset, output_dir, "train")
val_files, val_filepaths = translate_into_images(valset, output_dir, "val")
# train_files, train_filepaths = translate_into_images(trainset, output_dir, "train")
# val_files, val_filepaths = translate_into_images(valset, output_dir, "val")
train_labels = load_labels(trainset)
val_labels = load_labels(valset)

train_dict = dict()
train_dict["set"] = trainset
train_dict["file"] = train_files
# train_dict["file"] = train_files
train_dict["animated"] = map_class_to_animated(trainset)
train_dict["class"] = train_labels

val_dict = dict()
val_dict["set"] = valset
val_dict["file"] = val_files
# val_dict["file"] = val_files
val_dict["animated"] = map_class_to_animated(valset)
val_dict["class"] = val_labels

return train_dict, val_dict


if __name__ == '__main__':
    # Store the dataset under a path relative to the current working
    # directory. A machine-specific absolute base_dir used to be assigned
    # first and was immediately overwritten, so only the effective value
    # is kept.
    base_dir = "./datasets/"
    output_dir = base_dir + "cifar10"
    save_and_organize_cifar10_dataset(output_dir)
Loading