Skip to content

Commit

Permalink
JBAI-5829 [examples] Refactored resource path usage to cache directory.
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitriyb committed Sep 30, 2024
1 parent 5b2e755 commit 740da0f
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 26 deletions.
52 changes: 42 additions & 10 deletions examples/src/jvmMain/kotlin/io/kinference/examples/Utils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,70 @@ import io.ktor.util.cio.writeChannel
import io.ktor.utils.io.copyAndClose
import java.io.File

/**
 * Directory used to store cached files.
 *
 * Resolves to a `cache` subdirectory under the user's current working
 * directory. Download helpers check this location for already-fetched
 * files, create it on demand, and write newly downloaded files into it.
 */
val cacheDirectory: String = "${System.getProperty("user.dir")}/cache/"

/**
 * Downloads a file from the given URL and saves it with the specified file name
 * inside [cacheDirectory].
 *
 * Creates [cacheDirectory] if it does not exist. If the target file already
 * exists, the download is skipped. The transfer is streamed into a temporary
 * `.part` file and only renamed to the final name once it completes, so an
 * interrupted download never leaves a partial file that a later run would
 * mistake for a complete one.
 *
 * @param url The URL from which to download the file.
 * @param fileName The name to use for the downloaded file.
 * @param timeout Optional timeout duration for the download request, in milliseconds.
 *                Defaults to 600,000 milliseconds (10 minutes).
 *                Increase the timeout if you are not sure that the download for the
 *                particular model will fit into the default timeout.
 * @throws IllegalStateException if the completed download cannot be moved into place.
 */
suspend fun downloadFile(url: String, fileName: String, timeout: Long = 600_000) {
    // Ensure the predefined path is treated as a directory
    val directory = File(cacheDirectory)

    // Check if the directory exists, if not create it
    if (!directory.exists()) {
        println("Predefined directory doesn't exist. Creating directory at $cacheDirectory.")
        directory.mkdirs() // Create the directory if it doesn't exist
    }

    // Check if the file already exists
    val file = File(directory, fileName)
    if (file.exists()) {
        println("File already exists at ${file.absolutePath}. Skipping download.")
        return // Exit the function if the file exists
    }

    // Stream into a temporary file first; rename to the final name only on success.
    val tempFile = File(directory, "$fileName.part")

    // Create an instance of HttpClient with custom timeout settings
    val client = HttpClient {
        install(HttpTimeout) {
            requestTimeoutMillis = timeout
        }
    }

    try {
        // Download the file and write to the temporary path
        client.prepareRequest(url).execute { response ->
            response.bodyAsChannel().copyAndClose(tempFile.writeChannel())
        }
        check(tempFile.renameTo(file)) {
            "Failed to move downloaded file to ${file.absolutePath}"
        }
    } catch (e: Exception) {
        tempFile.delete() // Remove the partial download so the next run retries cleanly
        throw e
    } finally {
        // Always release the client — the original leaked it when the request threw
        client.close()
    }
}

/**
* Extracts the token ID with the highest probability from the output tensor.
*
* @param output A map containing the output tensors identified by their names.
* @param tokensSize The number of tokens in the sequence.
* @param outputName The name of the tensor containing the logits.
* @return The ID of the top token.
*/
suspend fun extractTopToken(output: Map<String, KIONNXData<*>>, tokensSize: Int, outputName: String): Long {
val logits = output[outputName]!! as KITensor
val sliced = logits.data.slice(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import io.kinference.core.KIEngine
import io.kinference.core.data.tensor.KITensor
import io.kinference.core.data.tensor.asTensor
import io.kinference.examples.downloadFile
import io.kinference.examples.resourcesPath
import io.kinference.examples.cacheDirectory
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.arrays.FloatNDArray.Companion.invoke
import io.kinference.utils.CommonDataLoader
Expand Down Expand Up @@ -100,12 +100,12 @@ suspend fun main() {
val modelName = "CaffeNet"

println("Downloading model from: $modelUrl")
downloadFile(modelUrl, "$resourcesPath/$modelName.onnx")
downloadFile(modelUrl, "$modelName.onnx")
println("Downloading synset from: $synsetUrl")
downloadFile(synsetUrl, "$resourcesPath/synset.txt")
downloadFile(synsetUrl, "synset.txt")

val modelBytes = CommonDataLoader.bytes("$resourcesPath/$modelName.onnx".toPath())
val classLabels = File("$resourcesPath/synset.txt").readLines()
val modelBytes = CommonDataLoader.bytes("$cacheDirectory/$modelName.onnx".toPath())
val classLabels = File("$cacheDirectory/synset.txt").readLines()

println("Loading model...")
val model = KIEngine.loadModel(modelBytes, optimize = true, predictionConfig = PredictionConfigs.DefaultAutoAllocator)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.kinference.examples.classification

import io.kinference.examples.downloadFile
import io.kinference.examples.resourcesPath
import io.kinference.examples.cacheDirectory
import io.kinference.ndarray.arrays.*
import io.kinference.ndarray.arrays.FloatNDArray.Companion.invoke
import io.kinference.ort.ORTEngine
Expand Down Expand Up @@ -99,12 +99,12 @@ suspend fun main() {
val modelName = "CaffeNet"

println("Downloading model from: $modelUrl")
downloadFile(modelUrl, "$resourcesPath/$modelName.onnx")
downloadFile(modelUrl, "$modelName.onnx")
println("Downloading synset from: $synsetUrl")
downloadFile(synsetUrl, "$resourcesPath/synset.txt")
downloadFile(synsetUrl, "synset.txt")

val modelBytes = CommonDataLoader.bytes("$resourcesPath/$modelName.onnx".toPath())
val classLabels = File("$resourcesPath/synset.txt").readLines()
val modelBytes = CommonDataLoader.bytes("$cacheDirectory/$modelName.onnx".toPath())
val classLabels = File("$cacheDirectory/synset.txt").readLines()

println("Loading model...")
val model = ORTEngine.loadModel(modelBytes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import io.kinference.core.KIEngine
import io.kinference.core.data.tensor.asTensor
import io.kinference.examples.downloadFile
import io.kinference.examples.extractTopToken
import io.kinference.examples.resourcesPath
import io.kinference.examples.cacheDirectory
import io.kinference.ndarray.arrays.LongNDArray
import io.kinference.ndarray.arrays.NDArrayCore
import io.kinference.utils.CommonDataLoader
Expand All @@ -22,9 +22,9 @@ suspend fun main() {
val modelName = "gpt2-lm-head-10"

println("Downloading model from: $modelUrl")
downloadFile(modelUrl, "$resourcesPath/$modelName.onnx")
downloadFile(modelUrl, "$modelName.onnx") //GPT-2 from model zoo is around 650 Mb, adjust your timeout if needed

val modelBytes = CommonDataLoader.bytes("${resourcesPath}/$modelName.onnx".toPath())
val modelBytes = CommonDataLoader.bytes("${cacheDirectory}/$modelName.onnx".toPath())

println("Loading model...")
val model = KIEngine.loadModel(modelBytes, optimize = true, predictionConfig = PredictionConfigs.DefaultAutoAllocator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import io.kinference.core.data.tensor.KITensor
import io.kinference.core.data.tensor.asTensor
import io.kinference.examples.downloadFile
import io.kinference.examples.extractTopToken
import io.kinference.examples.resourcesPath
import io.kinference.examples.cacheDirectory
import io.kinference.ndarray.arrays.FloatNDArray
import io.kinference.ndarray.arrays.FloatNDArray.Companion.invoke
import io.kinference.ort.ORTData
Expand All @@ -25,9 +25,9 @@ suspend fun main() {
val modelName = "gpt2-lm-head-10"

println("Downloading model from: $modelUrl")
downloadFile(modelUrl, "$resourcesPath/$modelName.onnx")
downloadFile(modelUrl, "$modelName.onnx") //GPT-2 from model zoo is around 650 Mb, adjust your timeout if needed

val modelBytes = CommonDataLoader.bytes("${resourcesPath}/$modelName.onnx".toPath())
val modelBytes = CommonDataLoader.bytes("${cacheDirectory}/$modelName.onnx".toPath())

println("Loading model...")
val model = ORTEngine.loadModel(modelBytes)
Expand Down

0 comments on commit 740da0f

Please sign in to comment.