Commit 397323b

Add README to run the code
1 parent da7f446 commit 397323b

3 files changed (+90 / -83 lines)

README.md

Lines changed: 22 additions & 1 deletion

The note "The code and the README will be updated in the coming days." is removed from the introduction, and usage instructions are added. The updated README reads:

# SplitEasy

This repository contains the code for our paper [SplitEasy: A Practical Approach for Training ML models on Mobile Devices in a split second](https://arxiv.org/abs/2011.04232).

The entire code is written using React Native. For general guidelines on React Native apps, follow this [link](https://developers.facebook.com/docs/react-native/getting-started/).

There are three parts to this code: loading the dataset on the mobile device, running the models on the mobile device, and running the models on the server.

### Loading the dataset on the mobile device

The data files will be available at this [link](). Move them to your assets folder. The JSON files are split into 5 parts of 25 images each, so that the images can be loaded onto the phone.

Once the data is in the assets folder, go to *App.js* and ensure that you load the *LoadImageNet* component. The app will take some time to load on your device; once it has loaded, click the load data button. The data being loaded is logged in your terminal. Then change the i value on line 58 of *imagenet_load.js* to i+25 and change the file name to the next file. Repeat this process 5 times and the data is loaded on your phone (see the sketch below).

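*imagenet_load.js* is not part of this commit, so the snippet below is only a hypothetical sketch of the repeated edit described above; the variable and file names are assumptions, not the actual code.

```js
// Hypothetical sketch of the manual edit in imagenet_load.js (names assumed).
// Pass 1: start at image 0 and load the first 25-image part.
const i = 0;
const part = require('./assets/imagenet_part_1.json');

// Pass 2: change i to i + 25 and point require() at the next part, then reload
// and press the load data button again. Repeat until all 5 parts are on the phone.
// const i = 25;
// const part = require('./assets/imagenet_part_2.json');
```
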
### Running the models on the mobile device

The component to load in your *App.js* now is the *SplitNet* component. The main change to make here is to add the server URL in the *App.js* file and the *final_implementation.js* file (the *src/App.js* diff below shows the line to change, with an example after it). Based on the model, load the appropriate files for ModelA and ModelC; the files are available at this [link](). The component has a button for training the model, and once you click it the values are logged in your terminal.

### Running the Server Code

The server code should preferably be run on a server with GPU support. The script requires the argument *--model_name*, which specifies the architecture to use (see the usage note after the *server/restserver.py* diff below). Once the server is running, use its IP address to run the JavaScript code.

server/restserver.py

Lines changed: 67 additions & 81 deletions
```diff
@@ -1,8 +1,4 @@
-#cloud required imports
 import os
-# from firebase_admin import credentials, initialize_app
-
-#custom imports
 import ast
 from flask import Flask , jsonify, request, Response
 from flask_socketio import SocketIO, emit
@@ -18,34 +14,18 @@
 import mgzip
 import inception
 import inception_resnet
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model_name", type=str)
 
-# class SplitB(nn.Module):
-#     def __init__(self):
-#         super(ModelB, self).__init__()
-
-#         self.conv = nn.Sequential(
-#             nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
-#             nn.MaxPool2d(kernel_size=2,stride=2),
-#             nn.ReLU(inplace=True),
-#             nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
-#             nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False),
-#             nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
-#             nn.MaxPool2d(kernel_size=2, stride=2),
-#             nn.ReLU(inplace=True))
-
-#         self.linear = nn.Sequential(
-#             nn.Linear(256*3*3, 1024,bias=False),
-#             nn.ReLU(inplace=True),
-#             nn.Linear(1024, 512, bias=False),
-#             nn.ReLU(inplace=True)
-#         )
-#     def forward(self, x):
-#         output = self.conv(x)
-#         output = output.view(-1, 256*3*3)
-#         output = self.linear(output)
-
-#         return output
-class SplitB(nn.Module):
+class Identity(nn.Module):
+    def __init__(self):
+        super().__init__()
+    def forward(self, x):
+        return x
+
+class ResNetSplitB(nn.Module):
     def __init__(self):
-        super(SplitB, self).__init__()
+        super().__init__()
         resnet_model = models.resnet50(pretrained=False)
@@ -66,67 +46,59 @@ def forward(self, x):
         output = output.view(-1, 2048)
         return output
 
-# class SplitB(nn.Module):
-#     def __init__(self):
-#         super(SplitB, self).__init__()
+class VGGSplitB(nn.Module):
+    def __init__(self):
+        super().__init__()
 
-#         conv_layers = [module for module in models.vgg19(pretrained=False).features.modules() if type(module) != nn.Sequential][3:]
-#         avgpool = [module for module in models.vgg19(pretrained=False).avgpool.modules() if type(module) != nn.Sequential]
-#         self.conv_model = nn.Sequential(*conv_layers,
-#                                         *avgpool)
+        conv_layers = [module for module in models.vgg19(pretrained=False).features.modules() if type(module) != nn.Sequential][3:]
+        avgpool = [module for module in models.vgg19(pretrained=False).avgpool.modules() if type(module) != nn.Sequential]
+        self.conv_model = nn.Sequential(*conv_layers,
+                                        *avgpool)
 
-#         classifier = [module for module in models.vgg19(pretrained=False).classifier.modules() if type(module) != nn.Sequential][:-1]
-#         self.fc_model = nn.Sequential(*classifier)
-#     def forward(self, x):
-#         output = self.conv_model(x)
-#         output = output.view(-1, 25088)
-#         output = self.fc_model(output)
-#         return output
+        classifier = [module for module in models.vgg19(pretrained=False).classifier.modules() if type(module) != nn.Sequential][:-1]
+        self.fc_model = nn.Sequential(*classifier)
+    def forward(self, x):
+        output = self.conv_model(x)
+        output = output.view(-1, 25088)
+        output = self.fc_model(output)
+        return output
 
-class Identity(nn.Module):
+class InceptionV3SplitB(nn.Module):
     def __init__(self):
         super().__init__()
+        self.model = inception.inceptionv3()
     def forward(self, x):
-        return x
+        output = self.model(x)
+        output = output.view(-1, 2048)
+        return output
 
-# class SplitB(nn.Module):
-#     def __init__(self):
-#         super().__init__()
-#         self.model = inception.inceptionv3()
-#     def forward(self, x):
-#         output = self.model(x)
-#         output = output.view(-1, 2048)
-#         return output
-
-# class SplitB(nn.Module):
-#     def __init__(self):
-#         super().__init__()
-#         self.model = inception_resnet.inceptionresnetv2(pretrained=False, num_classes=1000)
-#     def forward(self, x):
-#         output = self.model(x)
-#         output = output.view(-1, 1536)
-#         return output
-
-# class SplitB(nn.Module):
-#     def __init__(self):
-#         super().__init__()
-#         self.features = models.densenet121().features
-#         self.features.conv0 = Identity()
-#         self.features.norm0 = Identity()
-#         self.features.relu0 = Identity()
-#         self.features.pool0 = Identity()
-#     def forward(self, x):
-#         output = self.features(x)
-#         output = nn.functional.adaptive_avg_pool2d(output, (1, 1))
-#         output = output.view(-1, 1024)
-#         return output
+class InceptionResNetSplitB(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.model = inception_resnet.inceptionresnetv2(pretrained=False, num_classes=1000)
+    def forward(self, x):
+        output = self.model(x)
+        output = output.view(-1, 1536)
+        return output
+
+class DenseNetSplitB(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.features = models.densenet121().features
+        self.features.conv0 = Identity()
+        self.features.norm0 = Identity()
+        self.features.relu0 = Identity()
+        self.features.pool0 = Identity()
+    def forward(self, x):
+        output = self.features(x)
+        output = nn.functional.adaptive_avg_pool2d(output, (1, 1))
+        output = output.view(-1, 1024)
+        return output
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 app = Flask(__name__)
 socketio = SocketIO(app, cors_allowed_origins="*", ping_timeout=180, ping_interval=10)
-model = SplitB()
-model.to(torch.float32)
-model.to(device)
+
 loss_fn = nn.MSELoss()
 optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-3)
 output_dict_B = []
@@ -207,4 +179,18 @@ def back_propagation():
 
 port = int(os.environ.get('PORT', 8080))
 if __name__=="__main__":
+    args = parser.parse_args()
+    if args.model_name == "resnet":
+        model = ResNetSplitB()
+    elif args.model_name == "inception":
+        model = InceptionV3SplitB()
+    elif args.model_name == "inception_resnet":
+        model = InceptionResNetSplitB()
+    elif args.model_name == "densenet":
+        model = DenseNetSplitB()
+    elif args.model_name == "vgg":
+        model = VGGSplitB()
+
+    model.to(torch.float32)
+    model.to(device)
     socketio.run(app, host='0.0.0.0', debug=True, port=port)
```
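
As committed, the accepted *--model_name* values are `resnet`, `inception`, `inception_resnet`, `densenet`, and `vgg`, so launching the ResNet variant looks like `python restserver.py --model_name resnet` (run from the directory containing *restserver.py*); the server listens on port 8080 unless a `PORT` environment variable is set. One caveat: the Adam optimizer is still constructed at module level from `model.parameters()`, while `model` is now only instantiated inside the `if __name__=="__main__":` block, so the optimizer presumably needs to be created after the selected model (next to the `model.to(...)` calls) for the script to start cleanly.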

src/App.js

Lines changed: 1 addition & 1 deletion
```diff
@@ -35,7 +35,7 @@ import io from 'socket.io-client';
 
 
 const App: () => React$Node = () => {
-  const server_url = 'http://52.194.230.228:8080';
+  const server_url = '....';
   const socket = new io(server_url, {
     query: 'b64=1',
     pingTimeout: 360000
```
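
The placeholder above is where your own server address goes. A minimal sketch, assuming the Flask-SocketIO server from this commit is reachable at 192.0.2.10 on its default port 8080 (the IP is purely illustrative):

```js
// src/App.js -- point the React Native client at your own server.
// 192.0.2.10 is a placeholder address; port 8080 matches the default
// PORT used by server/restserver.py.
const server_url = 'http://192.0.2.10:8080';
```

Per the README above, the same URL also needs to be set in *final_implementation.js*.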
