Skip to content

Commit

Permalink
Merge branch 'master' into master-kotlin18
Browse files Browse the repository at this point in the history
# Conflicts:
#	build.gradle.kts
  • Loading branch information
dmitriyb committed Oct 7, 2024
2 parents 5b44daf + 92b4073 commit d7d7ecd
Show file tree
Hide file tree
Showing 32 changed files with 723 additions and 73 deletions.
33 changes: 19 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ it is highly recommended to use KInference TensorFlow.js backend instead for mor
KInference Core dependency coordinates:
```kotlin
dependencies {
api("io.kinference", "inference-core", "0.2.23")
api("io.kinference", "inference-core", "0.2.24")
}
```

Expand All @@ -67,7 +67,7 @@ This backend is recommended for JavaScript projects.
TensorFlow.js backend dependency coordinates:
```kotlin
dependencies {
api("io.kinference", "inference-tfjs", "0.2.23")
api("io.kinference", "inference-tfjs", "0.2.24")
}
```

Expand All @@ -81,14 +81,14 @@ To check on the system requirements, visit the following [link](https://onnxrunt
ONNXRuntime CPU backend dependency coordinates:
```kotlin
dependencies {
api("io.kinference", "inference-ort", "0.2.23")
api("io.kinference", "inference-ort", "0.2.24")
}
```

ONNXRuntime GPU backend dependency coordinates:
```kotlin
dependencies {
api("io.kinference", "inference-ort-gpu", "0.2.23")
api("io.kinference", "inference-ort-gpu", "0.2.24")
}
```

Expand All @@ -104,7 +104,7 @@ Array adapter for the [kmath](https://github.com/SciProgCentre/kmath) library th
Dependency coordinates:
```kotlin
dependencies {
api("io.kinference", "adapter-kmath-{backend_name}", "0.2.23")
api("io.kinference", "adapter-kmath-{backend_name}", "0.2.24")
}
```

Expand All @@ -114,12 +114,12 @@ Array adapter for the [multik](https://github.com/Kotlin/multik) library that wo
Dependency coordinates:
```kotlin
dependencies {
api("io.kinference", "adapter-multik-{backend_name}", "0.2.23")
api("io.kinference", "adapter-multik-{backend_name}", "0.2.24")
}
```

## Getting started
Let us now walk through how to get started with KInference. The latest version of KInference is *0.2.23*
Let us now walk through how to get started with KInference. The latest version of KInference is *0.2.24*

### Setup dependencies repository

Expand All @@ -142,7 +142,7 @@ To enable the backend, you can add the chosen KInference runtime as a dependency

```kotlin
dependencies {
api("io.kinference", "inference-core", "0.2.23")
api("io.kinference", "inference-core", "0.2.24")
}
```

Expand All @@ -160,29 +160,34 @@ kotlin {
sourceSets {
val commonMain by getting {
dependencies {
api("io.kinference:inference-api:0.2.23")
api("io.kinference:ndarray-api:0.2.23")
api("io.kinference:inference-api:0.2.24")
api("io.kinference:ndarray-api:0.2.24")
}
}

val jvmMain by getting {
dependencies {
api("io.kinference:inference-core:0.2.23")
api("io.kinference:inference-core:0.2.24")
}
}

val jsMain by getting {
dependencies {
api("io.kinference:inference-tfjs:0.2.23")
api("io.kinference:inference-tfjs:0.2.24")
}
}
}
}
```

## Examples
You can find several KInference usage examples in [this repository](https://github.com/JetBrains-Research/kinference-examples).
The repository has examples of multi-backend project configuration and sharing KInference-related code between the modules.
The [examples module](https://github.com/JetBrains-Research/kinference/tree/master/examples) contains examples of solving classification tasks
(cats vs dogs) and text generation.
Different backends are used in the examples.
Models for the examples were selected from the [ONNX Model Zoo](https://github.com/onnx/models).
Running the examples does not require converting models to different opsets.
However, if you need to run a model with operator versions not supported by KInference,
you can refer to [Convert guide](https://github.com/OpenPPL/ppl.nn/blob/master/docs/en/onnx-model-opset-convert-guide.md).

## Want to know more?
KInference API itself is widely documented, so you can explore its code and interfaces to get to know KInference better.
Expand Down
25 changes: 14 additions & 11 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import org.jetbrains.kotlin.gradle.targets.js.yarn.YarnLockMismatchReport
import org.jetbrains.kotlin.gradle.targets.js.yarn.YarnPlugin
import org.jetbrains.kotlin.gradle.targets.js.yarn.YarnRootExtension
import org.jetbrains.kotlin.gradle.tasks.KotlinCompilationTask
import org.jetbrains.kotlin.utils.addToStdlib.applyIf

group = "io.kinference"
version = "0.2.23-kotlin18"
version = "0.2.24-kotlin18"

plugins {
alias(libs.plugins.kotlin.multiplatform) apply false
Expand All @@ -35,21 +36,23 @@ subprojects {

apply {
plugin("org.jetbrains.kotlin.multiplatform")

plugin("maven-publish")
plugin("idea")
}


publishing {
repositories {
maven {
name = "SpacePackages"
url = uri("https://packages.jetbrains.team/maven/p/ki/maven")
applyIf(path != ":examples") {
apply(plugin = "maven-publish")

publishing {
repositories {
maven {
name = "SpacePackages"
url = uri("https://packages.jetbrains.team/maven/p/ki/maven")

credentials {
username = System.getenv("JB_SPACE_CLIENT_ID")
password = System.getenv("JB_SPACE_CLIENT_SECRET")
credentials {
username = System.getenv("JB_SPACE_CLIENT_ID")
password = System.getenv("JB_SPACE_CLIENT_SECRET")
}
}
}
}
Expand Down
33 changes: 33 additions & 0 deletions examples/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Examples module build script.
// Inherits group/version from the root project so the examples are versioned
// together with the rest of KInference.
group = rootProject.group
version = rootProject.version

kotlin {
    // The examples target the JVM only (ORT backend and DJL tokenizers are JVM libraries).
    jvm()

    sourceSets {
        jvmMain {
            dependencies {
                // KInference runtime: core inference engine plus the ONNXRuntime backend.
                api(project(":inference:inference-api"))
                api(project(":inference:inference-core"))
                api(project(":inference:inference-ort"))
                api(project(":serialization:serializer-protobuf"))
                api(project(":utils:utils-common"))

                // NDArray abstractions used to build model inputs / read outputs.
                api(project(":ndarray:ndarray-api"))
                api(project(":ndarray:ndarray-core"))

                // KotlinDL provides ready-made datasets (e.g. cats-vs-dogs) for the examples.
                implementation("org.jetbrains.kotlinx:kotlin-deeplearning-api:0.5.2")
                implementation("org.jetbrains.kotlinx:kotlin-deeplearning-dataset:0.5.2") // Dataset support

                // Ktor client is used to download model files at example startup.
                implementation("io.ktor:ktor-client-core:2.3.12")
                implementation("io.ktor:ktor-client-cio:2.3.12") // JVM Engine

                // Simple SLF4J binding so example runs produce readable log output.
                api("org.slf4j:slf4j-api:2.0.9")
                api("org.slf4j:slf4j-simple:2.0.9")

                // DJL HuggingFace tokenizers for the text-generation example.
                implementation("ai.djl:api:0.28.0")
                implementation("ai.djl.huggingface:tokenizers:0.28.0")
            }
        }
    }
}
96 changes: 96 additions & 0 deletions examples/src/jvmMain/kotlin/io/kinference/examples/Utils.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package io.kinference.examples

import io.kinference.core.KIONNXData
import io.kinference.core.data.tensor.KITensor
import io.kinference.ndarray.arrays.LongNDArray
import io.kinference.ndarray.arrays.NumberNDArrayCore
import io.ktor.client.HttpClient
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.request.prepareRequest
import io.ktor.client.statement.bodyAsChannel
import io.ktor.util.cio.writeChannel
import io.ktor.utils.io.copyAndClose
import java.io.File

/**
* Directory used to store cached files.
*
* This variable combines the user's current working directory
* with a "cache" subdirectory to create the path for storing cache files.
* It is used in various functions to check for existing files or directories,
* create new ones if they do not exist, and manage the caching of downloaded files.
*/
val cacheDirectory = System.getProperty("user.dir") + "/.cache/"

/**
* Downloads a file from the given URL and saves it with the specified file name.
*
* Checks if the directory specified by `cacheDirectory` exists.
* If not, it creates the directory. If the file already exists,
* the download is skipped. Otherwise, the file is downloaded
* using an HTTP client with a 10-minute timeout setting.
*
* @param url The URL from which to download the file.
* @param fileName The name to use for the downloaded file.
* @param timeout Optional timeout duration for the download request, in milliseconds.
* Defaults to 600,000 milliseconds (10 minutes).
* Increase the timeout if you are not sure that download for the particular model with fit into the default timeout.
*/
/**
 * Downloads a file from the given URL and saves it under [cacheDirectory] with the specified file name.
 *
 * Creates the cache directory if it does not exist. If the target file already
 * exists, the download is skipped. Otherwise the file is streamed to disk via a
 * Ktor HTTP client configured with the given request timeout. The client is
 * always closed, even when the download fails.
 *
 * @param url The URL from which to download the file.
 * @param fileName The name to use for the downloaded file.
 * @param timeout Optional timeout duration for the download request, in milliseconds.
 *                Defaults to 600,000 milliseconds (10 minutes).
 *                Increase the timeout if you are not sure that the download of the particular model will fit into the default timeout.
 */
suspend fun downloadFile(url: String, fileName: String, timeout: Long = 600_000) {
    // Ensure the predefined path is treated as a directory
    val directory = File(cacheDirectory)

    // Check if the directory exists, if not create it
    if (!directory.exists()) {
        println("Predefined directory doesn't exist. Creating directory at $cacheDirectory.")
        directory.mkdirs() // Create the directory if it doesn't exist
    }

    // Skip the download when the file is already cached
    val file = File(directory, fileName)
    if (file.exists()) {
        println("File already exists at ${file.absolutePath}. Skipping download.")
        return // Exit the function if the file exists
    }

    // use { } guarantees the client is closed even if the request or the copy throws
    // (the original closed it only on the success path, leaking the client on failure)
    HttpClient {
        install(HttpTimeout) {
            requestTimeoutMillis = timeout
        }
    }.use { client ->
        // Stream the response body straight into the target file
        client.prepareRequest(url).execute { response ->
            response.bodyAsChannel().copyAndClose(file.writeChannel())
        }
    }
}

/**
* Extracts the token ID with the highest probability from the output tensor.
*
* @param output A map containing the output tensors identified by their names.
* @param tokensSize The number of tokens in the sequence.
* @param outputName The name of the tensor containing the logits.
* @return The ID of the top token.
*/
/**
 * Extracts the token ID with the highest probability from the output tensor.
 *
 * Slices the logits of the last token in the sequence, applies softmax over the
 * vocabulary axis, and returns the index of the single most probable token.
 * The vocabulary size is taken from the tensor's last dimension, so the function
 * works for any model, not only ones with a 50257-entry vocabulary.
 *
 * @param output A map containing the output tensors identified by their names.
 * @param tokensSize The number of tokens in the sequence.
 * @param outputName The name of the tensor containing the logits.
 * @return The ID of the top token.
 * @throws IllegalStateException if [outputName] is not present in [output].
 */
suspend fun extractTopToken(output: Map<String, KIONNXData<*>>, tokensSize: Int, outputName: String): Long {
    // Fail with a descriptive message instead of a bare NPE (was: output[outputName]!!)
    val logits = checkNotNull(output[outputName]) { "Tensor '$outputName' is missing from the model output" } as KITensor

    // Derive the vocabulary size from the tensor shape instead of hard-coding 50257
    val vocabSize = logits.data.shape.last()

    // Slice out the logits of the last token: first batch, first head, last token, whole vocab
    val lastTokenLogits = logits.data.slice(
        starts = intArrayOf(0, 0, tokensSize - 1, 0),
        ends = intArrayOf(1, 1, tokensSize, vocabSize),
        steps = intArrayOf(1, 1, 1, 1) // Step of 1 for each dimension
    ) as NumberNDArrayCore

    // Convert logits to probabilities over the vocabulary axis
    val probabilities = lastTokenLogits.softmax(axis = -1)

    // topK returns (values, indices); only the index of the single best token is needed
    val (_, indices) = probabilities.topK(
        axis = -1,    // Apply top-k along the last dimension (vocabulary size)
        k = 1,        // Retrieve the top 1 element
        largest = true,  // We want the largest probabilities (most probable tokens)
        sorted = false   // Sorting is unnecessary since we are only retrieving the top 1
    )

    return (indices as LongNDArray)[intArrayOf(0, 0, 0, 0)]
}
Loading

0 comments on commit d7d7ecd

Please sign in to comment.