Skip to content

Commit

Permalink
JBAI-5829 [examples] Added GPT-2 example using ORTEngine for text generation.
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitriyb committed Sep 27, 2024
1 parent 5710839 commit a38c555
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 169 deletions.
23 changes: 23 additions & 0 deletions examples/src/jvmMain/kotlin/io/kinference/examples/Utils.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package io.kinference.examples

import io.kinference.core.KIONNXData
import io.kinference.core.data.tensor.KITensor
import io.kinference.ndarray.arrays.LongNDArray
import io.kinference.ndarray.arrays.NumberNDArrayCore
import io.ktor.client.HttpClient
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.request.prepareRequest
Expand Down Expand Up @@ -39,3 +43,22 @@ suspend fun downloadFile(url: String, outputPath: String) {

client.close()
}

/**
 * Extracts the id of the most probable next token from a language model's logits output.
 *
 * Slices the logits tensor down to the row for the last token of the current context
 * (assumes a [batch, 1, seqLen, vocabSize] layout — matches this GPT-2 model's output),
 * applies softmax over the vocabulary axis and greedily takes the top-1 entry.
 *
 * @param output model outputs keyed by tensor name
 * @param tokensSize number of tokens currently in the context; row `tokensSize - 1` is sliced
 * @param outputName name of the logits tensor inside [output]
 * @param vocabSize vocabulary size of the model; defaults to GPT-2's 50257
 * @return id of the most probable next token
 * @throws IllegalArgumentException if [outputName] is not present in [output]
 */
suspend fun extractTopToken(
    output: Map<String, KIONNXData<*>>,
    tokensSize: Int,
    outputName: String,
    vocabSize: Int = 50257,
): Long {
    val logits = requireNotNull(output[outputName]) { "Output tensor '$outputName' not found" } as KITensor

    // Keep only the logits row for the last context token: shape [1, 1, 1, vocabSize].
    val lastTokenLogits = logits.data.slice(
        starts = intArrayOf(0, 0, tokensSize - 1, 0), // first batch, last token, start of vocab
        ends = intArrayOf(1, 1, tokensSize, vocabSize), // one token step, whole vocabulary
        steps = intArrayOf(1, 1, 1, 1)                  // contiguous slice on every axis
    ) as NumberNDArrayCore

    // Softmax over the vocabulary axis, then a greedy top-1 pick.
    val probabilities = lastTokenLogits.softmax(axis = -1)
    val topK = probabilities.topK(
        axis = -1,      // along the vocabulary dimension
        k = 1,          // greedy decoding: only the single best token
        largest = true, // highest probability
        sorted = false  // no need to sort a single element
    )

    // topK.second holds the indices; index [0,0,0,0] is the chosen token id.
    return (topK.second as LongNDArray)[intArrayOf(0, 0, 0, 0)]
}
60 changes: 60 additions & 0 deletions examples/src/jvmMain/kotlin/io/kinference/examples/lm/KIMain.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package io.kinference.examples.lm

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer
import io.kinference.core.KIEngine
import io.kinference.core.data.tensor.asTensor
import io.kinference.examples.downloadFile
import io.kinference.examples.extractTopToken
import io.kinference.examples.resourcesPath
import io.kinference.ndarray.arrays.LongNDArray
import io.kinference.ndarray.arrays.NDArrayCore
import io.kinference.utils.CommonDataLoader
import io.kinference.utils.PredictionConfigs
import io.kinference.utils.inlines.InlineInt
import okio.Path.Companion.toPath

// Constants for input and output tensor names used in the GPT-2 model
private const val INPUT_TENSOR_NAME = "input1"
private const val OUTPUT_TENSOR_NAME = "output1" // We use only logits tensor

/**
 * GPT-2 text-generation demo on the KInference core (KI) engine.
 *
 * Downloads the ONNX model, tokenizes a fixed prompt, then autoregressively
 * generates 34 tokens with greedy decoding, printing each decoded token as it
 * is produced.
 */
suspend fun main() {
    val sourceUrl = "https://github.com/onnx/models/raw/main/validated/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.onnx"
    val modelName = "gpt2-lm-head-10"

    println("Downloading model from: $sourceUrl")
    downloadFile(sourceUrl, "$resourcesPath/$modelName.onnx")

    val modelBytes = CommonDataLoader.bytes("${resourcesPath}/$modelName.onnx".toPath())

    println("Loading model...")
    val model = KIEngine.loadModel(modelBytes, optimize = true, predictionConfig = PredictionConfigs.DefaultAutoAllocator)

    val tokenizer = HuggingFaceTokenizer.newInstance("gpt2", mapOf("modelMaxLength" to "1024"))
    val testString = "Neurogenesis is most active during embryonic development and is responsible for producing " +
        "all the various types of neurons of the organism, but it continues throughout adult life " +
        "in a variety of organisms. Once born, neurons do not divide (see mitosis), and many will " +
        "live the lifespan of the animal, except under extraordinary and usually pathogenic circumstances."
    val promptIds = tokenizer.encode(testString).ids
    val promptLength = promptIds.size

    val stepsToGenerate = 34
    val generated = LongArray(stepsToGenerate)

    // Wrap the prompt ids as a [1, 1, promptLength] tensor — the model's expected layout.
    var context = LongNDArray(1, promptLength) { i: InlineInt -> promptIds[i.value] }.unsqueeze(0).clone()

    print("Here goes the test text for generation:\n$testString")

    repeat(stepsToGenerate) { step ->
        val feed = listOf((context as NDArrayCore).asTensor(INPUT_TENSOR_NAME))
        val prediction = model.predict(feed)

        generated[step] = extractTopToken(prediction, promptLength + step, OUTPUT_TENSOR_NAME)

        // Append the freshly generated token so the next iteration sees the extended context.
        val nextToken = LongNDArray(1, 1) { _: InlineInt -> generated[step] }
        context = context.concat(listOf(nextToken.unsqueeze(0)), axis = -1)
        print(tokenizer.decode(longArrayOf(generated[step])))
    }
    println("\n\nDone")
}
169 changes: 0 additions & 169 deletions examples/src/jvmMain/kotlin/io/kinference/examples/lm/Main.kt

This file was deleted.

74 changes: 74 additions & 0 deletions examples/src/jvmMain/kotlin/io/kinference/examples/lm/ORTMain.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package io.kinference.examples.lm

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer
import io.kinference.core.data.tensor.KITensor
import io.kinference.core.data.tensor.asTensor
import io.kinference.examples.downloadFile
import io.kinference.examples.extractTopToken
import io.kinference.examples.resourcesPath
import io.kinference.ndarray.arrays.FloatNDArray
import io.kinference.ndarray.arrays.FloatNDArray.Companion.invoke
import io.kinference.ort.ORTData
import io.kinference.ort.ORTEngine
import io.kinference.ort.data.tensor.ORTTensor
import io.kinference.utils.CommonDataLoader
import io.kinference.utils.inlines.InlineInt
import io.kinference.utils.toIntArray
import okio.Path.Companion.toPath

// Constants for input and output tensor names used in the GPT-2 model
private const val INPUT_TENSOR_NAME = "input1"
private const val OUTPUT_TENSOR_NAME = "output1" // We use only logits tensor

/**
 * GPT-2 text-generation demo on the ONNX Runtime (ORT) backed engine.
 *
 * Downloads the ONNX model, tokenizes a fixed prompt, then autoregressively
 * generates 34 tokens with greedy decoding. ORT outputs are converted to
 * KI tensors so the shared [extractTopToken] helper can be reused.
 */
suspend fun main() {
    val sourceUrl = "https://github.com/onnx/models/raw/main/validated/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.onnx"
    val modelName = "gpt2-lm-head-10"

    println("Downloading model from: $sourceUrl")
    downloadFile(sourceUrl, "$resourcesPath/$modelName.onnx")

    val modelBytes = CommonDataLoader.bytes("${resourcesPath}/$modelName.onnx".toPath())

    println("Loading model...")
    val model = ORTEngine.loadModel(modelBytes)

    val tokenizer = HuggingFaceTokenizer.newInstance("gpt2", mapOf("modelMaxLength" to "1024"))
    val testString = "Neurogenesis is most active during embryonic development and is responsible for producing " +
        "all the various types of neurons of the organism, but it continues throughout adult life " +
        "in a variety of organisms. Once born, neurons do not divide (see mitosis), and many will " +
        "live the lifespan of the animal, except under extraordinary and usually pathogenic circumstances."
    val promptIds = tokenizer.encode(testString).ids
    val promptLength = promptIds.size

    val stepsToGenerate = 34
    val generated = LongArray(stepsToGenerate)

    // Wrap the prompt ids as a [1, 1, promptLength] tensor — the model's expected layout.
    var context = ORTTensor(promptIds, longArrayOf(1, 1, promptLength.toLong())).clone(INPUT_TENSOR_NAME)

    print("Here goes the test text for generation:\n$testString")

    repeat(stepsToGenerate) { step ->
        val prediction = model.predict(listOf(context))

        generated[step] = extractTopToken(convertToKITensorMap(prediction), promptLength + step, OUTPUT_TENSOR_NAME)

        // Rebuild the context tensor from the prompt plus everything generated so far.
        val extendedIds = promptIds + generated.slice(0..step)
        context = ORTTensor(extendedIds, longArrayOf(1, 1, promptLength + step + 1L), INPUT_TENSOR_NAME)
        print(tokenizer.decode(longArrayOf(generated[step])))
    }
    println("\n\nDone")
}

/**
 * Converts ORT backend outputs into KInference-core tensors so that
 * engine-agnostic helpers (e.g. [extractTopToken]) can consume them.
 *
 * Each ORT tensor is read out as a float array and re-wrapped into a
 * [FloatNDArray] of the same shape; map keys (tensor names) are preserved.
 * NOTE: assumes every output is a float tensor — true for this model's logits.
 */
private suspend fun convertToKITensorMap(outputs: Map<String, ORTData<*>>): Map<String, KITensor> =
    // mapValues keeps the keys and avoids the intermediate pair list of map{}.toMap().
    outputs.mapValues { (name, data) ->
        val ortTensor = data as ORTTensor
        val values = ortTensor.toFloatArray()
        val shape = ortTensor.shape.toIntArray()
        FloatNDArray(shape) { idx: InlineInt -> values[idx.value] }.asTensor(name)
    }

0 comments on commit a38c555

Please sign in to comment.