Skip to content

Commit

Permalink
Modal antes de iniciar benchmarking e conversa
Browse files Browse the repository at this point in the history
  • Loading branch information
ThalesBezerra21 committed Jul 31, 2024
1 parent 8df4042 commit 029de51
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import java.util.Locale
val benchmarkingModelsLabels = listOf(
//"llama",
"gemma",
//"qwen",
"qwen",
)

class AppViewModel(application: Application) : AndroidViewModel(application) {
Expand Down
30 changes: 2 additions & 28 deletions android/MLCChat/app/src/main/java/ai/mlc/mlcchat/DownloadView.kt
Original file line number Diff line number Diff line change
Expand Up @@ -26,35 +26,10 @@ import kotlinx.coroutines.delay
@Composable
fun DownloadView(
modifier: Modifier = Modifier,
appViewModel: AppViewModel,
onDownloadsFinished: (() -> Unit)? = null
pendingModels: List<AppViewModel.ModelState>,
numModels: Int,
) {

var pendingModels by remember {
mutableStateOf(appViewModel.benchmarkingModels)
}

val numModels = appViewModel.benchmarkingModels.size

LaunchedEffect(Unit) {
while(pendingModels.isNotEmpty()){
delay(100)

val modelState = pendingModels[0]

if(modelState.modelInitState.value == ModelInitState.Finished){
pendingModels = pendingModels.subList(1, pendingModels.size)
continue
}

if(modelState.modelInitState.value !== ModelInitState.Downloading){
modelState.handleStart()
}
}
if(onDownloadsFinished !== null)
onDownloadsFinished()
}

Column (
modifier = modifier
.background(MaterialTheme.colorScheme.primary),
Expand All @@ -70,7 +45,6 @@ fun DownloadView(
text = "Downloading model ${numModels-pendingModels.size+1} of $numModels",
color = MaterialTheme.colorScheme.onPrimary,
style = MaterialTheme.typography.titleMedium,
//fontWeight = FontWeight.Light
)
Spacer(modifier = Modifier.height(15.dp))
Column (
Expand Down
167 changes: 125 additions & 42 deletions android/MLCChat/app/src/main/java/ai/mlc/mlcchat/HomeView.kt
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
package ai.mlc.mlcchat

import ai.mlc.mlcchat.components.AccordionItem
import ai.mlc.mlcchat.components.AccordionText
import ai.mlc.mlcchat.components.Chip
import ai.mlc.mlcchat.components.LoadingTopBottomIndicator
import ai.mlc.mlcchat.hooks.useModal
import android.content.Context
import androidx.compose.foundation.Image
import androidx.compose.foundation.background
Expand All @@ -23,12 +21,10 @@ import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.BarChart
import androidx.compose.material.icons.filled.Chat
import androidx.compose.material3.Button
import androidx.compose.material3.AlertDialog
import androidx.compose.material3.ButtonDefaults
import androidx.compose.material3.Icon
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
Expand Down Expand Up @@ -64,30 +60,27 @@ fun HomeView(

val isIdleMeasured = true

var showStartDownloadDialog by remember { mutableStateOf(false) }
var downloadingModels by remember { mutableStateOf(false) }
val (startBenchmarking) = useStartBenchmarking(
onStart = { navController.navigate("benchmarking") }
)

fun startBenchmarking() {
navController.navigate("benchmarking")
}
val (startConversation) = useStartConversation(
onStart = { navController.navigate("home") }
)

val (isDownloading, pendingModels, numModels, startDownload) = useDownloadModels(
viewModel = appViewModel,
onFinish = { startBenchmarking() }
)

fun openDownloadDialog() {
fun initBenchmarkingFlux() {
if(!appViewModel.allBenchmarkingModelsReady()){
showStartDownloadDialog = true
startDownload()
}else{
startBenchmarking()
}
}

fun startDownload() {
downloadingModels = true
}

fun onDownloadsFinished() {
downloadingModels = false
startBenchmarking()
}

Column (
modifier = Modifier
.fillMaxSize(),
Expand Down Expand Up @@ -117,22 +110,21 @@ fun HomeView(

LargeRoundedButton(
icon = Icons.Default.BarChart,
onClick = { openDownloadDialog() },
enabled = !downloadingModels && isIdleMeasured,
onClick = { initBenchmarkingFlux() },
enabled = !isDownloading && isIdleMeasured,
text = "Start benchmarking"
)

Spacer(modifier = Modifier.height(15.dp))

LargeRoundedButton(
icon = Icons.Default.Chat,
onClick = { navController.navigate("home") },
enabled = !downloadingModels && isIdleMeasured,
onClick = { startConversation() },
enabled = !isDownloading && isIdleMeasured,
text = "Chat with LLMs"
)

}

Column(
modifier = Modifier
.weight(1f),
Expand All @@ -149,32 +141,123 @@ fun HomeView(
)
}

if(downloadingModels) {
if(isDownloading) {
DownloadView(
modifier = Modifier
.fillMaxSize(),
appViewModel = appViewModel,
onDownloadsFinished = { onDownloadsFinished() }
pendingModels = pendingModels,
numModels = numModels
)
}
}

}
}
}

if(showStartDownloadDialog) {
AlertDialog(
title = { Text(text = "Download Models") },
text = { Text(text = "To start the benchmarking, we need to download the LLM models.\n\nThis may take some time and require a large download.\n\nDo you want to continue?") },
onDismissRequest = { showStartDownloadDialog = false },
confirmButton = {
TextButton(onClick = { startDownload(); showStartDownloadDialog = false }) { Text("Download") }
},
dismissButton = {
TextButton(onClick = { showStartDownloadDialog = false }) { Text("Cancel") }
}
)
data class DownloadModelsActions(
val isDownloading: Boolean,
val pendingModels: List<AppViewModel.ModelState>,
val numModels: Int,
val startDownload: () -> Unit,
)

@Composable
fun useDownloadModels(
viewModel: AppViewModel,
onFinish: () -> Unit
): DownloadModelsActions {

var isDownloading by remember { mutableStateOf(false) }

var pendingModels by remember {
mutableStateOf(viewModel.benchmarkingModels)
}

val numModels = viewModel.benchmarkingModels.size

LaunchedEffect(isDownloading) {

if(!isDownloading)
return@LaunchedEffect

while(pendingModels.isNotEmpty()){
delay(100)

val modelState = pendingModels[0]

if(modelState.modelInitState.value == ModelInitState.Finished){
pendingModels = pendingModels.subList(1, pendingModels.size)
continue
}

if(modelState.modelInitState.value !== ModelInitState.Downloading){
modelState.handleStart()
}
}
isDownloading = false
onFinish()
}

val (showDownloadModal) = useModal(
title = "Download Models",
text = "To start the benchmarking, we need to download the LLM models.\n" +
"\n" +
"This may take some time and require a large download.\n" +
"\n" +
"Do you want to continue?",
onConfirm = { isDownloading = true },
confirmLabel = "Download"
)

return DownloadModelsActions(
isDownloading = isDownloading,
startDownload = showDownloadModal,
pendingModels = pendingModels,
numModels = numModels
)
}

data class StartBenchmarkActions(
val startBenchmarking: () -> Unit
)

@Composable
fun useStartBenchmarking(
onStart: () -> Unit
): StartBenchmarkActions {

val (showStartBenchmarkingModal) = useModal(
title = "Warning",
text = "The execution of LLMs on Android devices can be very taxing, and can cause crashes, especially on devices with less than 8GB of RAM.",
onConfirm = { onStart() },
confirmLabel = "Continue"
)

return StartBenchmarkActions(
startBenchmarking = showStartBenchmarkingModal
)
}

data class StartConversationActions(
val startConversation: () -> Unit
)

@Composable
fun useStartConversation(
onStart: () -> Unit
): StartConversationActions {

val (showStartConversationModal) = useModal(
title = "Warning",
text = "The execution of LLMs on Android devices can be very taxing, and can cause crashes, especially on devices with less than 8GB of RAM.\n"
+ "\nThe custom conversation option has the same restrictions as the default benchmarking",
onConfirm = { onStart() },
confirmLabel = "Continue"
)

return StartConversationActions(
startConversation = showStartConversationModal
)
}

@Composable
Expand Down
54 changes: 54 additions & 0 deletions android/MLCChat/app/src/main/java/ai/mlc/mlcchat/hooks/useModal.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package ai.mlc.mlcchat.hooks

import androidx.compose.material3.AlertDialog
import androidx.compose.material3.Text
import androidx.compose.material3.TextButton
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue

data class ModalActions(
val show: () -> Unit,
val hide: () -> Unit
)

@Composable
fun useModal(
title: String,
text: String,
onConfirm: () -> Unit,
confirmLabel: String,
dismissLabel: String = "Cancel"
): ModalActions {

var visible by remember { mutableStateOf(false) }

fun show() {
visible = true
}

fun hide() {
visible = false
}

if(visible) {
AlertDialog(
title = { Text(text = title) },
text = { Text(text = text) },
onDismissRequest = ::hide,
confirmButton = {
TextButton(onClick = { onConfirm(); hide() }) { Text(confirmLabel) }
},
dismissButton = {
TextButton(onClick = ::hide) { Text(dismissLabel) }
}
)
}

return ModalActions(
show = ::show,
hide = ::hide
)
}

0 comments on commit 029de51

Please sign in to comment.