Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import org.openmined.syft.domain.TrainingParameters
import org.openmined.syft.execution.JobStatusMessage
import org.openmined.syft.execution.SyftJob
import org.openmined.syft.execution.TrainingState
import org.openmined.syft.execution.checkpoint.JsonCheckPointSerializer

@ExperimentalCoroutinesApi
@ExperimentalUnsignedTypes
Expand All @@ -32,9 +33,10 @@ class TrainingTask(
private val modelVersion: String
) {
private val syftWorker = Syft.getInstance(configuration, authToken)
private lateinit var mnistJob: SyftJob

suspend fun runTask(logger: MnistLogger) {
val mnistJob = syftWorker.newJob(modelName, modelVersion)
mnistJob = syftWorker.newJob(modelName, modelVersion)
val statusPublisher = PublishProcessor.create<Result>()

logger.postLog("Processing $modelName $modelVersion")
Expand Down Expand Up @@ -77,6 +79,37 @@ class TrainingTask(
}
}

/**
 * Stops the current training job and forwards every state emitted by the job's stop flow
 * to [processTrainingState].
 *
 * `mnistJob` is a `lateinit` property that is only assigned inside [runTask]. If it has not
 * been assigned yet, the request is logged and ignored instead of failing on the
 * uninitialised property.
 *
 * @param logger receives the log lines and the training states produced while stopping.
 */
suspend fun stopTask(logger: MnistLogger) {
    if (!::mnistJob.isInitialized) {
        logger.postLog("No training job to stop!")
        return
    }

    logger.postLog("Stopping training!")
    mnistJob.stop().collect { state ->
        // Training states are processed on the main dispatcher.
        withContext(Dispatchers.Main) {
            processTrainingState(state, logger)
        }
    }

    // mnistJob.save().collect { saveState -> processTrainingState(saveState, logger) }

    logger.postLog("Training stopped!")
}

/**
 * Resumes training on the current job and forwards every emitted training state to
 * [processTrainingState], then logs the elapsed wall-clock time.
 *
 * `mnistJob` is a `lateinit` property that is only assigned inside [runTask]. If it has not
 * been assigned yet, the request is logged and ignored, and the content state is left
 * untouched, instead of failing on the uninitialised property.
 *
 * @param logger receives the log lines, the content state and the training states.
 */
suspend fun resumeTask(logger: MnistLogger) {
    if (!::mnistJob.isInitialized) {
        logger.postLog("No training job to resume!")
        return
    }

    val startTime = System.currentTimeMillis()
    logger.postLog("Resuming training!")
    logger.postState(ContentState.Training)

    val resumedStates = mnistJob.resume(
        mnistJob.jobModel.plans,
        dataLoader,
        generateTrainingParameters()
    )
    resumedStates.collect { state ->
        // Training states are processed on the main dispatcher.
        withContext(Dispatchers.Main) {
            processTrainingState(state, logger)
        }
    }

    logger.postLog("Training Finished after resumed in ${System.currentTimeMillis() - startTime} ms")
}

private suspend fun executeTraining(
logger: MnistLogger,
mnistJob: SyftJob,
Expand All @@ -89,7 +122,8 @@ class TrainingTask(
mnistJob.train(requestResult.plans,
requestResult.clientConfig!!,
dataLoader,
generateTrainingParameters()
generateTrainingParameters(),
JsonCheckPointSerializer()
).collect {
// collect happens in IO Dispatcher. Change context to process the training state.
withContext(Dispatchers.Main) {
Expand Down Expand Up @@ -123,6 +157,18 @@ class TrainingTask(
is TrainingState.Complete -> {
logger.postLog("Training completed!")
}
is TrainingState.Stop -> {
logger.postLog("Training stopped!")
}
is TrainingState.Resume -> {
logger.postLog("Training resumed!")
}
is TrainingState.Save -> {
logger.postLog("Model checkpoint created at ${trainingState.path}")
}
is TrainingState.Load -> {
logger.postLog("model checkpoint loaded!")
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class MnistActivity : AppCompatActivity() {

private lateinit var binding: ActivityMnistBinding
private lateinit var viewModel: MnistActivityViewModel
private var trainingButtonToggle = false

override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
Expand All @@ -54,6 +55,17 @@ class MnistActivity : AppCompatActivity() {

binding.buttonFirst.setOnClickListener { launchForegroundCycle() }
binding.buttonSecond.setOnClickListener { launchBackgroundCycle() }
// Toggles between stopping and resuming training.
// trainingButtonToggle starts as false while the button is initially labelled with
// R.string.stop_training, so a false flag has to mean "the next click stops training";
// a true flag means training was stopped by this button and the next click resumes it.
// NOTE(review): this assumes training is already running when the button is first
// clicked — confirm against where the button is made available.
binding.buttonTraining.setOnClickListener {
    if (trainingButtonToggle) {
        viewModel.resumeTraining()
        binding.buttonTraining.text = getString(R.string.stop_training)
    } else {
        viewModel.stopTraining()
        binding.buttonTraining.text = getString(R.string.resume_training)
    }
    trainingButtonToggle = !trainingButtonToggle
}

viewModel.processState.observe(
this,
Expand All @@ -77,7 +89,7 @@ class MnistActivity : AppCompatActivity() {

private fun launchForegroundCycle() {
val config = SyftConfiguration.builder(this, viewModel.baseUrl)
// .setMessagingClient(SyftConfiguration.NetworkingClients.HTTP)
.setMessagingClient(SyftConfiguration.NetworkingClients.HTTP)
.setCacheTimeout(0L)
.disableBatteryCheck()
.build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ class MnistActivityViewModel(
}
}

/**
 * Asks the current training task to stop. The work is launched in [viewModelScope] on the
 * IO dispatcher, with this view model passed as the task's logger.
 *
 * When no training task exists the request is logged and ignored rather than failing on a
 * null task.
 */
fun stopTraining() {
    viewModelScope.launch(Dispatchers.IO) {
        val task = trainingTask
        if (task == null) {
            this@MnistActivityViewModel.postLog("No training task to stop!")
            return@launch
        }
        task.stopTask(this@MnistActivityViewModel)
    }
}

/**
 * Asks the current training task to resume. The work is launched in [viewModelScope] on the
 * IO dispatcher, with this view model passed as the task's logger.
 *
 * When no training task exists the request is logged and ignored rather than failing on a
 * null task.
 */
fun resumeTraining() {
    viewModelScope.launch(Dispatchers.IO) {
        val task = trainingTask
        if (task == null) {
            this@MnistActivityViewModel.postLog("No training task to resume!")
            return@launch
        }
        task.resumeTask(this@MnistActivityViewModel)
    }
}

fun disposeTraining() {
compositeDisposable.clear()
trainingTask?.disposeTraining()
Expand All @@ -102,7 +114,6 @@ class MnistActivityViewModel(
workerRepository.getWorkInfo(it)
}


fun submitJob(): LiveData<WorkInfo> {
val requestId = workerRepository.getRunningWorkStatus()
?: workerRepository.submitJob(authToken, baseUrl)
Expand Down
21 changes: 17 additions & 4 deletions demo-app/src/main/res/layout/activity_mnist.xml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
<LinearLayout
android:id="@+id/scrollArea"
android:layout_width="match_parent"
android:layout_height="200dp"
android:layout_height="210dp"
android:gravity="center_vertical"
android:orientation="horizontal"
android:padding="10dp"
Expand Down Expand Up @@ -113,7 +113,7 @@
android:id="@+id/button_first"
style="@android:style/Widget.DeviceDefault.Button"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_height="40dp"
android:background="@drawable/button_solid"
android:padding="10dp"
android:text="@string/start_foreground"
Expand All @@ -125,7 +125,7 @@
android:id="@+id/button_second"
style="@android:style/Widget.DeviceDefault.Button"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_height="40dp"
android:layout_marginTop="10dp"
android:background="@drawable/button_solid"
android:padding="10dp"
Expand All @@ -134,11 +134,24 @@
android:textSize="12sp"
android:textStyle="bold" />

<Button
android:id="@+id/button_training"
style="@android:style/Widget.DeviceDefault.Button"
android:layout_width="match_parent"
android:layout_height="40dp"
android:layout_marginTop="10dp"
android:background="@drawable/button_solid"
android:padding="10dp"
android:text="@string/stop_training"
android:textColor="@color/white"
android:textSize="12sp"
android:textStyle="bold" />

<Button
android:id="@+id/button_cancel"
style="@android:style/Widget.DeviceDefault.Button"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_height="40dp"
android:layout_marginTop="10dp"
android:background="@drawable/button_solid"
android:backgroundTint="@color/red"
Expand Down
2 changes: 2 additions & 0 deletions demo-app/src/main/res/values/strings.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
<string name="start_background">Start Background</string>
<string name="close_background">Close Background</string>
<string name="start_foreground">Start Foreground</string>
<string name="stop_training">Stop Training</string>
<string name="resume_training">Resume Training</string>
<string name="enter_url">Enter PyGrid server URL:</string>
<string name="error_url">url is not valid</string>
<string name="title_activity_work_info">Running Job Info</string>
Expand Down
1 change: 1 addition & 0 deletions syft/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,5 @@ dependencies {
testImplementation 'org.jetbrains.kotlinx:kotlinx-coroutines-test:1.4.1'
testImplementation "androidx.arch.core:core-testing:2.1.0"
testImplementation 'app.cash.turbine:turbine:0.2.1'
testImplementation 'org.json:json:20201115'
}
136 changes: 123 additions & 13 deletions syft/src/main/java/org/openmined/syft/execution/SyftJob.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.async
import kotlinx.coroutines.cancel
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.cancellable
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.flow
import org.openmined.syft.Syft
import org.openmined.syft.data.loader.DataLoader
Expand All @@ -17,6 +21,8 @@ import org.openmined.syft.domain.JobRepository
import org.openmined.syft.domain.OutputParamType
import org.openmined.syft.domain.SyftConfiguration
import org.openmined.syft.domain.TrainingParameters
import org.openmined.syft.execution.checkpoint.CheckPoint
import org.openmined.syft.execution.checkpoint.CheckPointSerializer
import org.openmined.syft.networking.datamodels.ClientConfig
import org.openmined.syft.networking.datamodels.syft.CycleResponseData
import org.openmined.syft.proto.SyftModel
Expand Down Expand Up @@ -90,6 +96,11 @@ class SyftJob internal constructor(
internal val model = SyftModel(jobModel.modelName, jobModel.version)
private var requestKey = ""
private val jobScope = CoroutineScope(Dispatchers.IO)
private var checkPoint: CheckPoint? = null
private var checkPointSerializer: CheckPointSerializer<*>? = null
internal lateinit var clientConfig: ClientConfig
internal var currentStep = 0
private var currentTrainingState: TrainingState? = null

/**
* Starts the job by asking syft worker to request for cycle.
Expand Down Expand Up @@ -125,25 +136,59 @@ class SyftJob internal constructor(
plans: ConcurrentHashMap<String, Plan>,
clientConfig: ClientConfig,
dataLoader: DataLoader,
trainingParameters: TrainingParameters
trainingParameters: TrainingParameters,
checkPointSerializer: CheckPointSerializer<*>? = null
): Flow<TrainingState> = flow {

plans["training_plan"]?.let { plan ->
this@SyftJob.checkPointSerializer = checkPointSerializer
this@SyftJob.clientConfig = clientConfig

// TODO What do we do with this? Should all clients be forced to use "batch_size"?
val batchSize = (clientConfig.planArgs["batch_size"]
?: error("batch_size doesn't exist")).toInt()
// TODO What do we do with this? Should all clients be forced to use "batch_size"?
val batchSize = (clientConfig.planArgs["batch_size"]
?: error("batch_size doesn't exist")).toInt()

val batchIValue = IValue.from(
Tensor.fromBlob(longArrayOf(batchSize.toLong()), longArrayOf(1))
)
repeat(clientConfig.properties.maxUpdates) { step ->
val batchIValue = IValue.from(
Tensor.fromBlob(longArrayOf(batchSize.toLong()), longArrayOf(1))
)

emit(TrainingState.Epoch(step + 1))
dataLoader.reset()
// TODO We should check requirements before arriving to this point
val steps = clientConfig.properties.maxUpdates

// TODO We should check requirements before arriving to this point
// emit(TrainingState.Error(IllegalStateException("No params in the model")))
val modelParams = model.paramArray ?: emptyArray()
val modelParams = model.paramArray ?: emptyArray()
trainingLoop(
plans,
clientConfig,
dataLoader,
trainingParameters,
modelParams,
batchSize,
steps
).cancellable().collect {
if (shouldCancelTrainingLoop()) {
currentCoroutineContext().cancel()
}
emit(it)
}
}

@ExperimentalStdlibApi
private fun trainingLoop(
plans: ConcurrentHashMap<String, Plan>,
clientConfig: ClientConfig,
dataLoader: DataLoader,
trainingParameters: TrainingParameters,
modelParams: Array<Tensor>,
batchSize: Int,
steps: Int
): Flow<TrainingState> = flow {
plans["training_plan"]?.let { plan ->

repeat(steps) { step ->
currentStep = step + 1
emit(TrainingState.Epoch(currentStep))
dataLoader.reset()

val paramIValue = IValue.listFrom(*modelParams)

for (batchData in dataLoader) {
Expand Down Expand Up @@ -207,6 +252,71 @@ class SyftJob internal constructor(
}
}

/**
 * Returns a flow that, when collected, marks this job as stopped, emits
 * [TrainingState.Stop], and — if a checkpoint serializer was supplied — captures a
 * checkpoint of this job and hands it to the serializer.
 *
 * Setting the current state to [TrainingState.Stop] is what [shouldCancelTrainingLoop]
 * checks for. If no serializer is set, no checkpoint is created.
 *
 * @return a flow emitting a single [TrainingState.Stop].
 */
fun stop(): Flow<TrainingState> = flow {
    currentTrainingState = TrainingState.Stop
    emit(TrainingState.Stop)
    checkPointSerializer?.let { serializer ->
        // Keep the fresh checkpoint in a local so the value passed on is known to be
        // non-null without asserting on the mutable property.
        val snapshot = CheckPoint.fromJob(this@SyftJob)
        checkPoint = snapshot
        serializer.serialize(snapshot)
    }
}

/**
 * Returns a flow that, when collected, captures a checkpoint of this job, saves it through
 * the checkpoint serializer and emits [TrainingState.Save] carrying the serializer's result.
 *
 * If no serializer is set, nothing is saved and the flow completes without emitting.
 *
 * @param path destination passed to the serializer; defaults to a timestamped
 * `checkpoint-<millis>` entry under `config.filesDir`.
 * @param overwrite passed through to the serializer unchanged.
 * @return a flow emitting at most one [TrainingState.Save].
 */
fun save(
    path: String = "${config.filesDir}/checkpoint-${System.currentTimeMillis()}",
    overwrite: Boolean = false
): Flow<TrainingState> = flow {
    checkPointSerializer?.let { serializer ->
        // Keep the fresh checkpoint in a local so the value passed on is known to be
        // non-null without asserting on the mutable property.
        val snapshot = CheckPoint.fromJob(this@SyftJob)
        checkPoint = snapshot
        val result = serializer.save(snapshot, path, overwrite)
        emit(TrainingState.Save(result))
    }
}

/**
 * Returns a flow that, when collected, continues training from a checkpoint.
 *
 * If [checkPointFilePath] is given, [TrainingState.Load] is emitted and, when a checkpoint
 * serializer is set, the checkpoint is loaded from that path. If a checkpoint is then
 * available (just loaded, or kept from an earlier [stop] / [save]), [TrainingState.Resume] is
 * emitted and the training loop is run for `steps - currentStep` steps taken from the
 * checkpoint, re-emitting every state it produces. Without a checkpoint nothing is resumed.
 *
 * While collecting, the current coroutine context is cancelled as soon as
 * [shouldCancelTrainingLoop] reports true.
 *
 * @param plans plans handed to the training loop.
 * @param dataLoader data source handed to the training loop.
 * @param trainingParameters parameters handed to the training loop.
 * @param checkPointFilePath optional location of a checkpoint to load before resuming.
 * @return a flow of the training states produced while loading and resuming.
 * @throws IllegalStateException if the checkpoint has no client config, no `batch_size`
 * plan argument, or no model parameters.
 */
@ExperimentalStdlibApi
fun resume(
    plans: ConcurrentHashMap<String, Plan>,
    dataLoader: DataLoader,
    trainingParameters: TrainingParameters,
    checkPointFilePath: String? = null
): Flow<TrainingState> = flow {
    if (checkPointFilePath != null) {
        currentTrainingState = TrainingState.Load
        emit(TrainingState.Load)
        checkPointSerializer?.let { serializer ->
            checkPoint = serializer.load(checkPointFilePath)
        }
    }

    checkPoint?.let { restored ->
        // Leaving the Stop state here is what allows the loop below to run again.
        currentTrainingState = TrainingState.Resume
        emit(TrainingState.Resume)

        val restoredConfig = checkNotNull(restored.clientConfig) {
            "checkpoint has no client config"
        }
        val batchSize = (restoredConfig.planArgs["batch_size"]
            ?: error("batch_size doesn't exist")).toInt()
        val restoredParams = checkNotNull(restored.modelParams) {
            "checkpoint has no model params"
        }
        val remainingSteps = restored.steps - restored.currentStep

        trainingLoop(
            plans,
            restoredConfig,
            dataLoader,
            trainingParameters,
            restoredParams,
            batchSize,
            remainingSteps
        ).cancellable().collect { trainingState ->
            if (shouldCancelTrainingLoop()) {
                currentCoroutineContext().cancel()
            }
            emit(trainingState)
        }
    }
}

/** Reports whether the job's current training state is [TrainingState.Stop]. */
private fun shouldCancelTrainingLoop(): Boolean =
    currentTrainingState == TrainingState.Stop

/**
* This method is called by [Syft Worker][org.openmined.syft.Syft] on being accepted by PyGrid into a cycle
* @param responseData The training parameters and requestKey returned by PyGrid
Expand Down
Loading