public code release
qiujiaming315 committed Jul 31, 2022
0 parents commit 6c1fcd7
Showing 52 changed files with 4,407 additions and 0 deletions.
9 changes: 9 additions & 0 deletions LICENSE
@@ -0,0 +1,9 @@
MIT License

Copyright (c) 2022 Jiaming Qiu, Ruiqi Wang, Ayan Chakrabarti, Roch Guérin, Chenyang Lu

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
72 changes: 72 additions & 0 deletions README.md
@@ -0,0 +1,72 @@
# Adaptive Edge Offloading for Image Classification Under Rate Limit

<img src="img/system_overview.png" alt="system_overview" width="200"/>

## Overview

This repo provides a Python implementation for:

Jiaming Qiu, Ruiqi Wang, Ayan Chakrabarti, Roch Guérin, and Chenyang Lu, **"Adaptive Edge Offloading for Image Classification Under Rate Limit"**
[paper]
[slides]
[code]

This work is an extension of **"Real-time Edge Classification: Optimal Offloading under Token Bucket Constraints"**
[[paper]](https://ieeexplore.ieee.org/document/9708981)
[[code]](https://github.com/ayanc/edgeml.mdp)

## Requirements

We recommend a recent [Anaconda](https://www.anaconda.com/products/individual) distribution of Python 3.7+ with `numpy`, `tensorflow 2`, and `matplotlib` installed.
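
If you want to verify the environment first, a quick check like the one below (not part of the repository) simply confirms that the expected packages import and reports their versions:

``` python
# Quick environment check (illustrative, not part of the repository).
import numpy as np
import tensorflow as tf
import matplotlib

print("numpy", np.__version__)
print("tensorflow", tf.__version__)  # should report a 2.x release
print("matplotlib", matplotlib.__version__)
```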

## Reproducing Results

To reproduce the results presented in the paper, you may download and extract our pre-computed [data](https://arxiv.org/abs/2010.13737) merged from multiple simulations. Extract the files into the `viz/data/` sub-directory, which allows you to run the notebooks in the visualization step directly.
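
If you prefer to script the extraction, a minimal sketch follows; the archive name `precomputed_data.zip` is only a placeholder for whatever file you download:

``` python
# Extract the downloaded archive of pre-computed results into viz/data/.
# "precomputed_data.zip" is a placeholder name; substitute the file you downloaded.
import pathlib
import zipfile

pathlib.Path("viz/data").mkdir(parents=True, exist_ok=True)
with zipfile.ZipFile("precomputed_data.zip") as archive:
    archive.extractall("viz/data/")
```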

## Model Training and Simulation

#### Data Preparation

You will need to first download an `npz` [file](https://arxiv.org/abs/2010.13737) containing the pre-computed offloading metrics for the ILSVRC validation set with the weak and strong classifier pair we reported in the paper (VGG-style 16-layer vs. [OFA](https://github.com/mit-han-lab/once-for-all) 595M FLOPs). Put the file in the `data/` sub-directory. Check our [previous work](https://github.com/ayanc/edgeml.mdp) for more details on computing offloading metrics with your own dataset and classifier pair.
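
To sanity-check the download before training, a small sketch like the one below lists the arrays in the file; the key names follow what `lib/data.py` reads (per-image offloading metrics and weak/strong classifier costs for the training and test splits):

``` python
# Inspect the pre-computed offloading metrics (key names follow lib/data.py).
import numpy as np

npz = np.load("data/ofa_top5.npz")
print(list(npz.keys()))  # expected: metric_tr, metric_ts, wcost_tr, wcost_ts, scost_tr, scost_ts
print("training images:", npz["metric_tr"].shape, "test images:", npz["metric_ts"].shape)
# The offloading reward of an image is its weak-classifier cost minus its strong-classifier cost.
print("mean training reward:", np.mean(npz["wcost_tr"] - npz["scost_tr"]))
```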

#### Training

`train.py` in the root directory is the main script for training the DQN model. Here is a typical example run:

``` shell
# Start running 3k Q-value iterations with a learning rate of 10^-3.
./train.py data/ofa_top5.npz wts --tr 0.25 --tb 2 --stype 1 --itype 1 --lr 3 --maxiter 3000
# Reward and loss have saturated. Drop the learning rate to 10^-4 and run another 1k iterations.
./train.py data/ofa_top5.npz wts --tr 0.25 --tb 2 --stype 1 --itype 1 --lr 4 --maxiter 4000
```

Running the script with `-h` gives you a detailed list of parameters.

#### Simulation

After you have trained the DQN model, its weights are stored in the specified directory (`wts/` in the example above). The model is then ready for simulation on the test sequences with `simulate.py`. Make sure you point the script to the right directory to load the model, and use consistent token bucket and sequence configurations throughout training and simulation. For example, to run a simulation with the model trained above, you should use:

``` shell
# Set the model directory to wts/ and use the same configuration.
./simulate.py data/ofa_top5.npz wts save --tr 0.25 --tb 2 --stype 1 --itype 1
```

The simulation results (the average loss of the DQN, MDP, Baseline, and Lower Bound policies, as well as that of the weak and strong classifiers) are stored in the specified directory (`save/` in the example above). You can compare the performance of different policies, or of the same policy across different configurations, by training DQN models and running simulations for each configuration.
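
The exact file layout under `save/` depends on the script; as an illustration only (the file name and array names below are hypothetical), comparing policies amounts to loading the stored average losses and printing them side by side:

``` python
# Illustrative comparison of stored simulation results.
# "results.npz" and the array names are hypothetical; check the actual
# output of simulate.py for the real layout.
import numpy as np

results = np.load("save/results.npz")
for policy in ("dqn", "mdp", "baseline", "lower_bound", "weak", "strong"):
    if policy in results:
        print(f"{policy:>12s}: average loss = {float(results[policy]):.4f}")
```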

#### Library

The actual library for loading data, generating sequences, computing policies, and running simulations is in the [lib](lib/) directory. Please take a look at its [README.md](lib/README.md) file for documentation.

## Visualization

We provide separate Jupyter notebooks in `viz/notebook/` to visualize the (downloaded or generated) results and produce the figures included in the paper.

## License

This code is being released under the [MIT License](LICENSE). Note that the offloading metrics are computed following the procedure developed in our [previous work](https://github.com/ayanc/edgeml.mdp), and the OFA result files were generated from the [models and code](https://github.com/mit-han-lab/once-for-all) provided by its authors.

## Acknowledgments

This work was partially supported by the National Science Foundation under grants no. CPS-1646579 and CNS-2006530, and the Fullgraf Foundation. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the authors, and do not necessarily reflect the views of the National Science Foundation and the Fullgraf Foundation.

Ayan Chakrabarti contributed extensively to this implementation in collaboration with Jiaming Qiu.
Binary file added img/system_overview.png
13 changes: 13 additions & 0 deletions lib/README.md
@@ -0,0 +1,13 @@
## Code

Various parts of the code are factored into different modules based on functionality; a minimal usage sketch follows the list below.

- `utils.py`: Utility functions for printing messages, detecting a stop signal (`Ctrl+C`), and reading and writing checkpoints.
- `genseq.py`: Functions for generating synthetic sequences of image offloading metrics and inter-arrival times using different temporal models. Note that the offloading metric sequences are represented with image indices from the raw dataset to save memory.
- `data.py`: Code to load the data from the raw `npz` file, and a tuple sampler that randomly samples tuples of offloading metrics and inter-arrival times (for the current and the next state) and reward.
- `bucket.py`: Functions for various token bucket operations:
  - Convert the rate and bucket depth to scaled integer parameters.
  - Utilities to interpret a flat vector output by the neural network as a table of Q(n, a) values, and to correctly compute `max_a' Q(n'=(n,a), a')`.
  - Code for computing the i.i.d. MDP policy.
- `model.py`: Functions to create different kinds of Keras models, including one built from an i.i.d. policy threshold vector.
- `bstream.py`: Generates compiled TensorFlow `@tf.function`s that sequentially simulate offloading decision making with a model and a token bucket.
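
As a rough illustration of how these pieces fit together (run from the repository root; the data path and bucket parameters are only examples), a minimal sketch:

``` python
# Minimal sketch of how the modules fit together (paths and parameters are only examples).
from lib.data import load_base
from lib.bucket import getqpm, qflatlen

# Load normalized offloading metrics and rewards for the train/test splits.
(tr_m, tr_r), (ts_m, ts_r) = load_base("data/ofa_top5.npz")
print("training images:", tr_m.shape[0], "test images:", ts_m.shape[0])

# Scale a token bucket with rate 0.25 and depth 2 into integer parameters (q, p, m).
qpm = getqpm(0.25, 2.0)
print("scaled bucket parameters (q, p, m):", qpm)    # (1, 4, 8) for this rate/depth
print("flat Q-value vector length:", qflatlen(qpm))  # 13 for these parameters
```
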
Empty file added lib/__init__.py
71 changes: 71 additions & 0 deletions lib/bstream.py
@@ -0,0 +1,71 @@
import numpy as np
import tensorflow as tf

"""Simulating sequences with Tensorflow models."""


def modelsfunc(model, qpm_, nstream, nhist1, nhist2):
"""
Create a tensorflow compiled function to run simulation for a model.
:param model: tensorflow model.
:param qpm_: Bucket parameters from getqpm(rate, bdepth).
    :param nstream: Number of parallel streams this function will be called on.
:param nhist1: Number of historical offloading metric values the model needs.
    :param nhist2: Number of historical inter-arrival times the model needs
    (when negative, the model needs no inter-arrival time input).
    :return: a function that you should call as func(tf.data.Dataset, int) to get the average loss,
    where the second argument is the number of elements in the dataset.
    You can use data2tf to get a tuple corresponding to these arguments:
simfunc = modelsfunc(...)
tfdset = data2tf(...)
avg_loss = simfunc(*tfdset)
"""
qpm = tf.constant(qpm_)
idx = tf.range(nstream)
intnum = tf.constant(nhist2 + 1)
# Create tensorflow variables to store the total reward and the current state.
tot_gain = tf.Variable(0, trainable=False, dtype=tf.float64)
nstate = tf.Variable([qpm_[2]] * nstream, trainable=False, dtype=tf.int32)
mhist = tf.Variable(np.zeros((nstream, nhist1 + 1)), trainable=False, dtype=tf.float32)
ihist = tf.Variable(np.zeros((nstream, max(nhist2 + 1, 1))), trainable=False, dtype=tf.int32)

@tf.function
def simloop(dataset):
for (_m, _r, _i) in dataset:
            # Update the history windows of offloading metrics and inter-arrival times.
mhist.assign(tf.concat((tf.reshape(_m, (-1, 1)), mhist[:, :-1]), 1))
qipt = tf.concat((mhist, tf.cast(ihist, tf.float32)[:, :intnum]), 1)
ihist.assign(tf.concat((tf.reshape(_i, (-1, 1)), ihist[:, :-1]), 1))
qvals = model(qipt)
# Determine the offloading decision with the predicted Q-values, and update token bucket state.
deci = (qvals[:, (qpm[1] - qpm[0]):(qpm[2] - qpm[0] + 1)] <= qvals[:, (qpm[2] - qpm[0] + 1):])
ifsend = tf.gather_nd(deci, tf.stack([idx, tf.maximum(tf.constant(0), nstate - qpm[1])], 1))
ifsend = tf.logical_and(ifsend, nstate >= qpm[1])
nstate.assign(tf.minimum(qpm[2], nstate - tf.where(ifsend, qpm[1], 0) + qpm[0] * _i))
# Update the cumulative offloading reward.
igain = tf.reduce_sum(tf.cast(tf.where(ifsend, _r, 0), tf.float64))
tot_gain.assign_add(igain)

def mstream(dataset, sz_=1.0):
# Clear the variables and run the simulation loop.
tot_gain.assign(np.float64(0))
nstate.assign([qpm_[2]] * nstream)
mhist.assign(np.zeros((nstream, nhist1 + 1), np.float32))
ihist.assign(np.zeros((nstream, max(nhist2 + 1, 1)), np.int32))
simloop(dataset)
return tot_gain.numpy() / sz_

return mstream


def data2tf(dset, iset, metrics, rewards):
"""Convert dataset from numpy arrays to a (tf.Dataset, int:size) tuple."""
metrics, rewards = metrics.astype(np.float32), rewards.astype(np.float32)
metrics, rewards = metrics[dset.T], rewards[dset.T]
dataset = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices(metrics),
tf.data.Dataset.from_tensor_slices(rewards),
tf.data.Dataset.from_tensor_slices(iset.T)))
return dataset, np.size(dset)
100 changes: 100 additions & 0 deletions lib/bucket.py
@@ -0,0 +1,100 @@
import numpy as np

"""Functions for dealing with token buckets."""


def getqpm(rate, bdepth, maxp=100):
"""Return the scaled token bucket parameters (q,p,m) (integer) such that q/p ~ rate, m/p ~ bdepth."""
denom = np.arange(maxp, dtype=np.int64) + 1
rerr, berr = denom * rate, denom * bdepth
err = (rerr - np.floor(rerr) + berr - np.floor(berr)) / denom
_p = denom[np.argmin(err)]
_q = np.int64(np.floor(rate * _p))
_m = np.int64(np.floor(bdepth * _p))
gcd = np.gcd(np.gcd(_q, _p), _m)
return _q // gcd, _p // gcd, _m // gcd


def getvpidx(rate, bdepth):
"""Get the token numbers for indices of value and policy vectors."""
qpm = getqpm(rate, bdepth)
vidx = np.arange(qpm[0], qpm[2] + 1, dtype=np.float64) / qpm[1]
pidx = np.arange(qpm[1], qpm[2] + 1, dtype=np.float64) / qpm[1]
return vidx, pidx


def qflatlen(qpm):
"""Return size of vector to represent Q values of all [n,a]."""
return 2 * (qpm[2] + 1) - qpm[0] - qpm[1]


def qflat2dec(qpm, qflat):
"""Turn Q value matrix into a decision matrix."""
qflat = qflat[:, (qpm[1] - qpm[0]):].reshape((-1, 2, qpm[2] + 1 - qpm[1]))
return (qflat[:, 1, :] >= qflat[:, 0, :]).astype(np.uint8)


def qflatprev(qpm, qprime, rew, discount, trans):
"""Compute Q[n,a] = ar + max_a' Q[n'=(n.a), a']."""
a0len = qpm[2] + 1 - qpm[0]
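    # Layout of qprime (see qflatlen and qflat2dec): the first a0len = m+1-q columns hold
    # Q(n, a=0) for n = q..m, and the remaining m+1-p columns hold Q(n, a=1) for n = p..m.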
qmax = qprime[:, :a0len]
qmax[:, (qpm[1] - qpm[0]):] = np.maximum(qmax[:, (qpm[1] - qpm[0]):], qprime[:, a0len:])

if len(trans.shape) > 1:
qtrans = np.concatenate((trans[qpm[0]:, :], trans[:(a0len - qpm[1] + qpm[0]), :]), axis=0)
qtrans = np.transpose(qtrans)
qtgt = np.matmul(qmax, qtrans)
else:
qtgt = np.zeros_like(qprime)
trans = trans.astype(np.int32)
idx1 = qpm[0] * trans[:, np.newaxis] + np.arange(a0len)
idx1 = np.minimum(idx1, a0len - 1)
idx2 = qpm[0] * (trans - 1)[:, np.newaxis] + np.arange(a0len - qpm[1] + qpm[0])
idx2 = np.minimum(idx2, a0len - 1)
qtgt[:, :a0len] = qmax[np.arange(qprime.shape[0])[:, np.newaxis], idx1]
qtgt[:, a0len:] = qmax[np.arange(qprime.shape[0])[:, np.newaxis], idx2]
qnew = discount * qtgt
qnew[:, a0len:] += rew[:, np.newaxis]
return qnew


def iidpolicy(qpm, metrics, rewards, trans, discount=0.9999, itparam=(1e4, 1e-6)):
"""
Find optimal policy thresholds assuming iid offloading metrics and periodic image arrival.
:param qpm: Bucket parameters from getqpm(rate, bdepth).
:param metrics: Training set of offloading metrics.
:param rewards: Training set of offloading rewards.
:param trans: The transition matrix for token bucket states.
:param discount: Discount factor to apply (default 0.9999).
    :param itparam: (maximum number of value iterations, convergence tolerance relative to the metric range).
    :return: MDP policy thresholds.
"""

# Sort metrics and compute F(theta) and G(theta)
idx = np.argsort(-metrics)
metrics, rewards = np.float64(metrics[idx]), np.float64(rewards[idx])
gtheta = np.cumsum(rewards) / len(rewards)
ftheta = np.float64(np.arange(1, len(rewards) + 1)) / len(rewards)
fgt = (np.reshape(ftheta, (-1, 1)), np.reshape(gtheta, (-1, 1)))

thresh = (np.amax(metrics) - np.amin(metrics)) * itparam[1]

# Do value iterations
value, policy = np.zeros((qpm[2] - qpm[0] + 1), np.float64), None
for i in range(int(itparam[0])):
vprev, pprev = value.copy(), policy

        # If n < p (fewer than one full token), can't send.
value[:(qpm[1] - qpm[0])] = discount * np.dot(trans[qpm[0]:qpm[1]], vprev)

        # If n >= p (at least one full token), choose between sending and not sending:
vnosend = np.dot(trans[qpm[1]:], vprev)
vsend = np.dot(trans[:(qpm[2] - qpm[1] + 1)], vprev)
score = fgt[1] + discount * (fgt[0] * vsend + (1 - fgt[0]) * vnosend)
value[(qpm[1] - qpm[0]):] = np.amax(score, 0)
policy = metrics[np.argmax(score, 0)]

if i > 0:
if np.max(np.abs(policy - pprev)) < thresh:
break

return policy
59 changes: 59 additions & 0 deletions lib/data.py
@@ -0,0 +1,59 @@
import numpy as np

"""Utility functions for handling data."""


def load_base(fname):
"""
Load individual image data from npz file.
:param fname: Name of npz file.
:return: (train_metrics, train_rewards), (test_metrics, test_rewards).
"""
# Load the offloading metrics for the training and test sets.
npz = np.load(fname)
tr_m, ts_m = npz['metric_tr'], npz['metric_ts']
# Normalize the offloading metrics.
_mu, _sd = np.mean(tr_m), np.std(tr_m)
tr_m = (tr_m - _mu) / _sd
ts_m = (ts_m - _mu) / _sd
# Compute the offloading reward based on the costs.
tr_r = npz['wcost_tr'] - npz['scost_tr']
ts_r = npz['wcost_ts'] - npz['scost_ts']
return (tr_m, tr_r), (ts_m, ts_r)


class DataTuples:
"""Sample tuples (segments) from a dataset."""

def __init__(self, dset, iset, metrics, rewards, nhist1, nhist2, ntuples):
self.metrics, self.rewards = metrics, rewards
self.nhist1, self.nhist2 = nhist1, nhist2
self.ntuples = ntuples
self.nseq, self.lseq = np.shape(dset)
self.dset, self.iset = dset.flatten(), iset.flatten()

def sample(self):
"""Sample n tuples of (curm, reward, nextm)."""
# Randomly select n indexes from the sequences as the end points of the segments.
idx0 = np.random.randint(0, self.nseq, size=(self.ntuples,))
idx1 = np.random.randint(0, self.lseq - 1, size=(self.ntuples,))
# Retrieve the offloading rewards.
reward = self.rewards[self.dset[idx0 * self.lseq + idx1]]
# Retrieve the offloading metrics for the current and the next states.
idxh = idx1[:, np.newaxis] - np.arange(self.nhist1 + 1)
curm = self.dset[np.maximum(0, idx0[:, np.newaxis] * self.lseq + idxh)]
curm = self.metrics[curm]
curm[idxh < 0] = 0
nextm = self.metrics[self.dset[idx0 * self.lseq + idx1 + 1]][:, np.newaxis]
nextm = np.concatenate((nextm, curm[:, :-1]), -1)
# Retrieve the inter-arrival times for the current and the next states.
if self.nhist2 >= 0:
idxg = idx1[:, np.newaxis] - np.arange(self.nhist2 + 1) - 1
curi = self.iset[np.maximum(0, idx0[:, np.newaxis] * self.lseq + idxg)]
curi[idxg < 0] = 0
curm = np.concatenate((curm, curi), axis=1)
nexti = self.iset[idx0 * self.lseq + idx1][:, np.newaxis]
# Concatenate the inter-arrival times to the offloading metrics.
nexti = np.concatenate((nexti, curi[:, :-1]), -1)
nextm = np.concatenate((nextm, nexti), axis=1)
return curm, reward, nextm