Check this blog entry for more information.
This project is an implementation of a Federated Learning system consisting of a Parameter Server and an Android application that can be used as a client.
Both components are implemented in Kotlin and use DL4J as the Machine Learning framework.
Enjoy!
Training a machine learning model requires data. The more we have, the better (well... not always, but let's allow some simplifications). However, data is not cheap and, more importantly, it can contain sensitive and personal information.
Recent developments in privacy, in the form of new laws such as the GDPR, and users' and citizens' growing awareness of the value of their data are creating a need for techniques that enforce more privacy.
Though techniques such as anonymisation can greatly help with the privacy issue, the fact that all the data is sent to a central location to train the machine learning models remains a cause for concern.
Federated Learning turns the training of Machine Learning models upside down by letting the devices on the edge participate in it.
Instead of sending the clients' data to a centralised location, Federated Learning sends the model to the devices participating in the federation. Each device then re-trains the model (using Transfer Learning) with its local data.
And the data, your data, never leaves the device, be it your phone, your laptop or your IoT gadget.
Very briefly, the process of training a model goes as follows:
- The server opens a new round of training
- The clients that are going to participate in the training round download the latest version of the model from the server
- Using their local data, each client updates the model
- Those updates are sent to the server
- The server gathers all updates and applies Federated Averaging to improve the shared model
- The shared model is now ready for all clients to use
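The round described above can be simulated in a few lines. This is a minimal sketch, not the project's code: it assumes a toy model y = w * x with a single weight instead of a CIFAR-10 network, and the names `ClientUpdate`, `trainLocally`, `federatedAverage` and `runRound` are hypothetical. Each client starts from the downloaded global weight, runs a few epochs of gradient descent on its own data only, and reports its new weight plus its sample count; the server then takes the sample-weighted mean.

```kotlin
// Hypothetical update a client sends back: its locally trained weight and how
// many samples it trained on (the data itself never leaves the client).
data class ClientUpdate(val weight: Double, val numSamples: Int)

// Hypothetical local training for the toy model y = w * x: full-batch gradient
// descent on the mean squared error, using only this client's data.
fun trainLocally(
    globalWeight: Double,
    xs: DoubleArray,
    ys: DoubleArray,
    epochs: Int = 5,
    learningRate: Double = 0.1
): ClientUpdate {
    var w = globalWeight
    repeat(epochs) {
        // dL/dw for L = mean((w*x - y)^2) is mean(2 * (w*x - y) * x)
        var grad = 0.0
        for (i in xs.indices) grad += 2 * (w * xs[i] - ys[i]) * xs[i]
        w -= learningRate * grad / xs.size
    }
    return ClientUpdate(w, xs.size)
}

// Federated Averaging: each client's weight counts in proportion to its samples.
fun federatedAverage(updates: List<ClientUpdate>): Double {
    val totalSamples = updates.sumOf { it.numSamples }
    return updates.sumOf { it.weight * it.numSamples / totalSamples }
}

// One round: every client downloads the current global weight, trains locally
// and sends an update; the server averages the updates into the new shared model.
fun runRound(globalWeight: Double, clients: List<Pair<DoubleArray, DoubleArray>>): Double {
    val updates = clients.map { (xs, ys) -> trainLocally(globalWeight, xs, ys) }
    return federatedAverage(updates)
}
```

Calling `runRound` repeatedly, feeding each result back in as the next global weight, makes the shared weight converge towards the value that fits all clients' data, even though the server only ever sees weights and sample counts.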
To demonstrate how Federated Learning works, I have implemented a system based on CIFAR-10, a well-known image classification dataset.
The architecture makes it possible to remove the UI part of the Android app and reuse the rest, with little effort, in any other type of client that supports Kotlin.
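One way to get that kind of reuse is to keep the client's training logic free of Android imports and let the UI plug in only through a callback. The sketch below illustrates the idea under that assumption; `ParameterServerApi`, `LocalTrainer` and `FederatedClient` are hypothetical names, not the project's actual classes, and parameters are plain `DoubleArray`s rather than DL4J models.

```kotlin
// Hypothetical abstraction over communication with the Parameter Server.
interface ParameterServerApi {
    fun downloadModel(): DoubleArray
    fun uploadUpdate(parameters: DoubleArray, numSamples: Int)
}

// Hypothetical abstraction over local training (DL4J in the real project).
interface LocalTrainer {
    val numSamples: Int
    fun train(parameters: DoubleArray): DoubleArray
}

// Platform-agnostic client logic: nothing here depends on Android, so any
// Kotlin client can reuse it. A UI only hooks in through the optional callback.
class FederatedClient(
    private val server: ParameterServerApi,
    private val trainer: LocalTrainer,
    private val onStatus: (String) -> Unit = {}
) {
    fun participateInRound() {
        onStatus("Downloading model")
        val global = server.downloadModel()
        onStatus("Training locally")
        val updated = trainer.train(global)
        onStatus("Uploading update")
        server.uploadUpdate(updated, trainer.numSamples)
    }
}
```

On Android, `onStatus` could update a view; a headless client would simply omit it.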
The server performs the averaging once it has received a minimum number of updates, applying Federated Averaging as defined in [1].
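For reference, the usual form of the Federated Averaging update is a sample-weighted mean of the client models. The notation here is chosen for illustration: K clients take part in round t, client k trained on n_k local samples and returned parameters w^k_{t+1}.

```latex
w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{n} \, w_{t+1}^{k}, \qquad n = \sum_{k=1}^{K} n_k
```

For example, with two clients holding n_1 = 100 and n_2 = 300 samples and returning a parameter value of 0.2 and 0.6 respectively, the averaged value is 0.25 · 0.2 + 0.75 · 0.6 = 0.05 + 0.45 = 0.5. Clients with more data pull the shared model further towards their local result.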
See FederatedAveragingStrategy.kt:
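```kotlin
import java.io.ByteArrayOutputStream
import org.deeplearning4j.util.ModelSerializer
import org.nd4j.linalg.factory.Nd4j

// Excerpt: repository, layerIndex and processSingleUpdate are defined elsewhere in the file (not shown).
override fun processUpdates(): ByteArrayOutputStream {
    // Total number of samples reported by all client updates
    val totalSamples = repository.getTotalSamples()
    // Current version of the shared model
    val model = ModelSerializer.restoreMultiLayerNetwork(repository.retrieveModel())
    // Only the parameters of the layer at layerIndex are replaced by the average
    val shape = model.getLayer(layerIndex).params().shape()
    // Fold every client update into a running sum, starting from zeros of the same shape
    val sumUpdates = repository.listClientUpdates().fold(Nd4j.zeros(shape[0], shape[1])) { sum, next ->
        processSingleUpdate(next, totalSamples, sum)
    }
    model.getLayer(layerIndex).setParams(sumUpdates)
    // Serialize the updated model (saveUpdater = true keeps the updater state) and store it
    val outputStream = ByteArrayOutputStream()
    ModelSerializer.writeModel(model, outputStream, true)
    repository.storeModel(outputStream.toByteArray())
    return outputStream
}
```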
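The excerpt does not show `processSingleUpdate` or where the minimum-updates check happens. The following sketch is a hypothetical, dependency-free version of both, assuming each stored update carries its parameters and sample count and that `processSingleUpdate` adds the client's parameters scaled by its share of `totalSamples`, which is what the Federated Averaging formula requires. `StoredUpdate` and `minUpdates` are illustrative names, and plain `DoubleArray`s stand in for ND4J arrays.

```kotlin
// Hypothetical stand-in for one stored client update.
data class StoredUpdate(val parameters: DoubleArray, val numSamples: Int)

// Hypothetical equivalent of processSingleUpdate: add this client's parameters,
// scaled by its share of the total samples, to the running sum.
fun processSingleUpdate(update: StoredUpdate, totalSamples: Int, sumUpdates: DoubleArray): DoubleArray {
    val weight = update.numSamples.toDouble() / totalSamples
    return DoubleArray(sumUpdates.size) { i -> sumUpdates[i] + weight * update.parameters[i] }
}

// Averages only once at least minUpdates updates have arrived; returns null
// otherwise so the round can stay open and keep collecting updates.
fun processUpdates(updates: List<StoredUpdate>, minUpdates: Int): DoubleArray? {
    if (updates.isEmpty() || updates.size < minUpdates) return null
    val totalSamples = updates.sumOf { it.numSamples }
    val size = updates.first().parameters.size
    return updates.fold(DoubleArray(size)) { sum, next -> processSingleUpdate(next, totalSamples, sum) }
}
```

Because the weights n_k / n sum to one, starting the fold from zeros yields the weighted mean directly, with no final division step.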