diff --git a/README.md b/README.md index b209fb9db..b4f196305 100644 --- a/README.md +++ b/README.md @@ -181,8 +181,13 @@ kotlin { ``` ## Examples -You can find several KInference usage examples in [this repository](https://github.com/JetBrains-Research/kinference-examples). -The repository has examples of multi-backend project configuration and sharing KInference-related code between the modules. +The [examples module](https://github.com/JetBrains-Research/kinference/tree/master/examples) contains examples of solving classification tasks +(cats vs dogs) and text generation. +Different backends are used in the examples. +Models for the examples were selected from the [ONNX Model Zoo](https://github.com/onnx/models). +Running the examples does not require converting models to different opsets. +However, if you need to run a model with operator versions not supported by KInference, +you can refer to [Convert guide](https://github.com/OpenPPL/ppl.nn/blob/master/docs/en/onnx-model-opset-convert-guide.md). ## Want to know more? KInference API itself is widely documented, so you can explore its code and interfaces to get to know KInference better. diff --git a/build.gradle.kts b/build.gradle.kts index 7c543737d..3caf12538 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -6,6 +6,7 @@ import org.jetbrains.kotlin.gradle.targets.js.yarn.YarnLockMismatchReport import org.jetbrains.kotlin.gradle.targets.js.yarn.YarnPlugin import org.jetbrains.kotlin.gradle.targets.js.yarn.YarnRootExtension import org.jetbrains.kotlin.gradle.tasks.KotlinCompilationTask +import org.jetbrains.kotlin.utils.addToStdlib.applyIf group = "io.kinference" version = "0.2.22" @@ -35,21 +36,23 @@ subprojects { apply { plugin("org.jetbrains.kotlin.multiplatform") - - plugin("maven-publish") plugin("idea") } - publishing { - repositories { - maven { - name = "SpacePackages" - url = uri("https://packages.jetbrains.team/maven/p/ki/maven") + applyIf(path != ":examples") { + apply(plugin = "maven-publish") + + publishing { + repositories { + maven { + name = "SpacePackages" + url = uri("https://packages.jetbrains.team/maven/p/ki/maven") - credentials { - username = System.getenv("JB_SPACE_CLIENT_ID") - password = System.getenv("JB_SPACE_CLIENT_SECRET") + credentials { + username = System.getenv("JB_SPACE_CLIENT_ID") + password = System.getenv("JB_SPACE_CLIENT_SECRET") + } } } } diff --git a/examples/build.gradle.kts b/examples/build.gradle.kts new file mode 100644 index 000000000..69891a917 --- /dev/null +++ b/examples/build.gradle.kts @@ -0,0 +1,33 @@ +group = rootProject.group +version = rootProject.version + +kotlin { + jvm() + + sourceSets { + jvmMain { + dependencies { + api(project(":inference:inference-api")) + api(project(":inference:inference-core")) + api(project(":inference:inference-ort")) + api(project(":serialization:serializer-protobuf")) + api(project(":utils:utils-common")) + + api(project(":ndarray:ndarray-api")) + api(project(":ndarray:ndarray-core")) + + implementation("org.jetbrains.kotlinx:kotlin-deeplearning-api:0.5.2") + implementation("org.jetbrains.kotlinx:kotlin-deeplearning-dataset:0.5.2") // Dataset support + + implementation("io.ktor:ktor-client-core:2.3.12") + implementation("io.ktor:ktor-client-cio:2.3.12") // JVM Engine + + api("org.slf4j:slf4j-api:2.0.9") + api("org.slf4j:slf4j-simple:2.0.9") + + implementation("ai.djl:api:0.28.0") + implementation("ai.djl.huggingface:tokenizers:0.28.0") + } + } + } +} diff --git a/examples/src/jvmMain/kotlin/io/kinference/examples/Utils.kt b/examples/src/jvmMain/kotlin/io/kinference/examples/Utils.kt new file mode 100644 index 000000000..eb9cbcd99 --- /dev/null +++ b/examples/src/jvmMain/kotlin/io/kinference/examples/Utils.kt @@ -0,0 +1,96 @@ +package io.kinference.examples + +import io.kinference.core.KIONNXData +import io.kinference.core.data.tensor.KITensor +import io.kinference.ndarray.arrays.LongNDArray +import io.kinference.ndarray.arrays.NumberNDArrayCore +import io.ktor.client.HttpClient +import io.ktor.client.plugins.HttpTimeout +import io.ktor.client.request.prepareRequest +import io.ktor.client.statement.bodyAsChannel +import io.ktor.util.cio.writeChannel +import io.ktor.utils.io.copyAndClose +import java.io.File + +/** + * Directory used to store cached files. + * + * This variable combines the user's current working directory + * with a "cache" subdirectory to create the path for storing cache files. + * It is used in various functions to check for existing files or directories, + * create new ones if they do not exist, and manage the caching of downloaded files. + */ +val cacheDirectory = System.getProperty("user.dir") + "/.cache/" + +/** + * Downloads a file from the given URL and saves it with the specified file name. + * + * Checks if the directory specified by `cacheDirectory` exists. + * If not, it creates the directory. If the file already exists, + * the download is skipped. Otherwise, the file is downloaded + * using an HTTP client with a 10-minute timeout setting. + * + * @param url The URL from which to download the file. + * @param fileName The name to use for the downloaded file. + * @param timeout Optional timeout duration for the download request, in milliseconds. + * Defaults to 600,000 milliseconds (10 minutes). + * Increase the timeout if you are not sure that download for the particular model with fit into the default timeout. + */ +suspend fun downloadFile(url: String, fileName: String, timeout: Long = 600_000) { + // Ensure the predefined path is treated as a directory + val directory = File(cacheDirectory) + + // Check if the directory exists, if not create it + if (!directory.exists()) { + println("Predefined directory doesn't exist. Creating directory at $cacheDirectory.") + directory.mkdirs() // Create the directory if it doesn't exist + } + + // Check if the file already exists + val file = File(directory, fileName) + if (file.exists()) { + println("File already exists at ${file.absolutePath}. Skipping download.") + return // Exit the function if the file exists + } + + // Create an instance of HttpClient with custom timeout settings + val client = HttpClient { + install(HttpTimeout) { + requestTimeoutMillis = timeout + } + } + + // Download the file and write to the specified output path + client.prepareRequest(url).execute { response -> + response.bodyAsChannel().copyAndClose(file.writeChannel()) + } + + client.close() +} + +/** + * Extracts the token ID with the highest probability from the output tensor. + * + * @param output A map containing the output tensors identified by their names. + * @param tokensSize The number of tokens in the sequence. + * @param outputName The name of the tensor containing the logits. + * @return The ID of the top token. + */ +suspend fun extractTopToken(output: Map>, tokensSize: Int, outputName: String): Long { + val logits = output[outputName]!! as KITensor + val sliced = logits.data.slice( + starts = intArrayOf(0, 0, tokensSize - 1, 0), // First batch, first element in the second dimension, last token, first vocab entry + ends = intArrayOf(1, 1, tokensSize, 50257), // Same batch, same second dimension, one token step, whole vocab (50257) + steps = intArrayOf(1, 1, 1, 1) // Step of 1 for each dimension + ) as NumberNDArrayCore + val softmax = sliced.softmax(axis = -1) + val topK = softmax.topK( + axis = -1, // Apply top-k along the last dimension (vocabulary size) + k = 1, // Retrieve the top 1 element + largest = true, // We want the largest probabilities (most probable tokens) + sorted = false // Sorting is unnecessary since we are only retrieving the top 1 + ) + val tokenId = (topK.second as LongNDArray)[intArrayOf(0, 0, 0, 0)] + + return tokenId +} diff --git a/examples/src/jvmMain/kotlin/io/kinference/examples/classification/KIClassificationMain.kt b/examples/src/jvmMain/kotlin/io/kinference/examples/classification/KIClassificationMain.kt new file mode 100644 index 000000000..00ace87bc --- /dev/null +++ b/examples/src/jvmMain/kotlin/io/kinference/examples/classification/KIClassificationMain.kt @@ -0,0 +1,122 @@ +package io.kinference.examples.classification + +import io.kinference.core.KIEngine +import io.kinference.core.data.tensor.KITensor +import io.kinference.core.data.tensor.asTensor +import io.kinference.examples.downloadFile +import io.kinference.examples.cacheDirectory +import io.kinference.ndarray.arrays.* +import io.kinference.ndarray.arrays.FloatNDArray.Companion.invoke +import io.kinference.utils.CommonDataLoader +import io.kinference.utils.PredictionConfigs +import io.kinference.utils.inlines.InlineInt +import okio.Path.Companion.toPath +import org.jetbrains.kotlinx.dl.api.preprocessing.pipeline +import org.jetbrains.kotlinx.dl.dataset.OnFlyImageDataset +import org.jetbrains.kotlinx.dl.dataset.embedded.dogsCatsSmallDatasetPath +import org.jetbrains.kotlinx.dl.dataset.generator.FromFolders +import org.jetbrains.kotlinx.dl.impl.inference.imagerecognition.InputType +import org.jetbrains.kotlinx.dl.impl.preprocessing.* +import org.jetbrains.kotlinx.dl.impl.preprocessing.image.* +import java.awt.image.BufferedImage +import java.io.File +import kotlin.collections.mutableMapOf + +// Constants for input and output tensor names used in the CaffeNet model +private const val INPUT_TENSOR_NAME = "data_0" +private const val OUTPUT_TENSOR_NAME = "prob_1" + +// Preprocessing pipeline for input images using KotlinDL +private val preprocessing = pipeline() + .resize { + outputWidth = 224 + outputHeight = 224 + interpolation = InterpolationType.BILINEAR + } + .convert { colorMode = ColorMode.BGR } + .toFloatArray { } + .call(InputType.CAFFE.preprocessing()) + +// Path to the small dataset of dogs vs cats images (100 images) +private val dogsVsCatsDatasetPath = dogsCatsSmallDatasetPath() + +/** + * Creates a Map of input tensors categorized by their respective classes (e.g., "cat" and "dog"). + * + * This function reads images from the dataset, preprocesses them, + * transposes the tensors to the required format, and groups them + * based on their class label. + * + * @return A Map where the keys are the class labels (e.g., "cat" and "dog"), + * and the values are lists of KITensor objects representing the input tensors + * for each class. + */ +private suspend fun createInputs(): Map> { + val dataset = OnFlyImageDataset.create( + File(dogsVsCatsDatasetPath), + FromFolders(mapping = mapOf("cat" to 0, "dog" to 1)), + preprocessing + ).shuffle() + + + val tensorShape = intArrayOf(1, 224, 224, 3) // Original tensor shape is [batch, width, height, channel] + val permuteAxis = intArrayOf(0, 3, 1, 2) // Permutations for shape [batch, channel, width, height] + val inputTensors = mutableMapOf>() + + for (i in 0 until dataset.xSize()) { + val inputData = dataset.getX(i) + val inputClass = if (dataset.getY(i).toInt() == 0) "cat" else "dog" + val floatNDArray = FloatNDArray(tensorShape) { index: InlineInt -> inputData[index.value] } // Create an NDArray from the image data + val inputTensor = floatNDArray.transpose(permuteAxis).asTensor(INPUT_TENSOR_NAME) // Transpose and create a tensor from the NDArray + inputTensors.putIfAbsent(inputClass, mutableListOf()) + inputTensors[inputClass]!!.add(inputTensor) + } + + return inputTensors +} + +/** + * Displays the top 5 predictions with their corresponding labels and scores. + * + * @param predictions The predicted scores in a multidimensional array format. + * @param classLabels The list of class labels corresponding to the predictions. + * @param originalClass The actual class label of the instance being predicted. + */ +private fun displayTopPredictions(predictions: FloatNDArray, classLabels: List, originalClass: String) { + val predictionArray = predictions.array.toArray() + val indexedScores = predictionArray.withIndex().sortedByDescending { it.value }.take(5) + + println("\nOriginal class: $originalClass") + println("Top 5 predictions:") + for ((index, score) in indexedScores) { + val predictedClassLabel = if (index in classLabels.indices) classLabels[index] else "Unknown" + println("${predictedClassLabel}: ${"%.2f".format(score * 100)}%") + } +} + +suspend fun main() { + val modelUrl = "https://github.com/onnx/models/raw/main/validated/vision/classification/caffenet/model/caffenet-12.onnx" + val synsetUrl = "https://s3.amazonaws.com/onnx-model-zoo/synset.txt" + val modelName = "CaffeNet" + + println("Downloading model from: $modelUrl") + downloadFile(modelUrl, "$modelName.onnx") + println("Downloading synset from: $synsetUrl") + downloadFile(synsetUrl, "synset.txt") + + val classLabels = File("$cacheDirectory/synset.txt").readLines() + + println("Loading model...") + val model = KIEngine.loadModel("$cacheDirectory/$modelName.onnx".toPath(), optimize = true, predictionConfig = PredictionConfigs.DefaultAutoAllocator) + println("Creating inputs...") + val inputTensors = createInputs() + + println("Starting inference...") + inputTensors.forEach { dataClass -> + dataClass.value.forEach { tensor -> + val actualOutputs = model.predict(listOf(tensor)) + val predictions = actualOutputs[OUTPUT_TENSOR_NAME]?.data as FloatNDArray + displayTopPredictions(predictions, classLabels, dataClass.key) + } + } +} diff --git a/examples/src/jvmMain/kotlin/io/kinference/examples/classification/ORTClassificationMain.kt b/examples/src/jvmMain/kotlin/io/kinference/examples/classification/ORTClassificationMain.kt new file mode 100644 index 000000000..c8f3b0d5c --- /dev/null +++ b/examples/src/jvmMain/kotlin/io/kinference/examples/classification/ORTClassificationMain.kt @@ -0,0 +1,121 @@ +package io.kinference.examples.classification + +import io.kinference.examples.downloadFile +import io.kinference.examples.cacheDirectory +import io.kinference.ndarray.arrays.* +import io.kinference.ndarray.arrays.FloatNDArray.Companion.invoke +import io.kinference.ort.ORTEngine +import io.kinference.ort.data.tensor.ORTTensor +import io.kinference.utils.CommonDataLoader +import io.kinference.utils.inlines.InlineInt +import io.kinference.utils.toLongArray +import okio.Path.Companion.toPath +import org.jetbrains.kotlinx.dl.api.preprocessing.pipeline +import org.jetbrains.kotlinx.dl.dataset.OnFlyImageDataset +import org.jetbrains.kotlinx.dl.dataset.embedded.dogsCatsSmallDatasetPath +import org.jetbrains.kotlinx.dl.dataset.generator.FromFolders +import org.jetbrains.kotlinx.dl.impl.inference.imagerecognition.InputType +import org.jetbrains.kotlinx.dl.impl.preprocessing.* +import org.jetbrains.kotlinx.dl.impl.preprocessing.image.* +import java.awt.image.BufferedImage +import java.io.File +import kotlin.collections.mutableMapOf + +// Constants for input and output tensor names used in the CaffeNet model +private const val INPUT_TENSOR_NAME = "data_0" +private const val OUTPUT_TENSOR_NAME = "prob_1" + +// Preprocessing pipeline for input images using KotlinDL +private val preprocessing = pipeline() + .resize { + outputWidth = 224 + outputHeight = 224 + interpolation = InterpolationType.BILINEAR + } + .convert { colorMode = ColorMode.BGR } + .toFloatArray { } + .call(InputType.CAFFE.preprocessing()) + +// Path to the small dataset of dogs vs cats images (100 images) +private val dogsVsCatsDatasetPath = dogsCatsSmallDatasetPath() + +/** + * Creates a Map of input tensors categorized by their respective classes (e.g., "cat" and "dog"). + * + * This function reads images from the dataset, preprocesses them, + * transposes the tensors to the required format, and groups them + * based on their class label. + * + * @return A Map where the keys are the class labels (e.g., "cat" and "dog"), + * and the values are lists of KITensor objects representing the input tensors + * for each class. + */ +private suspend fun createInputs(): Map> { + val dataset = OnFlyImageDataset.create( + File(dogsVsCatsDatasetPath), + FromFolders(mapping = mapOf("cat" to 0, "dog" to 1)), + preprocessing + ).shuffle() + + + val tensorShape = intArrayOf(1, 224, 224, 3) // Original tensor shape is [batch, width, height, channel] + val permuteAxis = intArrayOf(0, 3, 1, 2) // Permutations for shape [batch, channel, width, height] + val inputTensors = mutableMapOf>() + + for (i in 0 until dataset.xSize()) { + val inputData = dataset.getX(i) + val inputClass = if (dataset.getY(i).toInt() == 0) "cat" else "dog" + val floatNDArray = FloatNDArray(tensorShape) { index: InlineInt -> inputData[index.value] }.transpose(permuteAxis) // Create an NDArray from the image data + val inputTensor = ORTTensor(floatNDArray.array.toArray(), floatNDArray.shape.toLongArray(), INPUT_TENSOR_NAME) // Transpose and create a tensor from the NDArray + inputTensors.putIfAbsent(inputClass, mutableListOf()) + inputTensors[inputClass]!!.add(inputTensor) + } + + return inputTensors +} + +/** + * Displays the top 5 predictions with their corresponding labels and scores. + * + * @param predictions The predicted scores in a multidimensional array format. + * @param classLabels The list of class labels corresponding to the predictions. + * @param originalClass The actual class label of the instance being predicted. + */ +private fun displayTopPredictions(predictions: ORTTensor, classLabels: List, originalClass: String) { + val predictionArray = predictions.toFloatArray() + val indexedScores = predictionArray.withIndex().sortedByDescending { it.value }.take(5) + + println("\nOriginal class: $originalClass") + println("Top 5 predictions:") + for ((index, score) in indexedScores) { + val predictedClassLabel = if (index in classLabels.indices) classLabels[index] else "Unknown" + println("${predictedClassLabel}: ${"%.2f".format(score * 100)}%") + } +} + +suspend fun main() { + val modelUrl = "https://github.com/onnx/models/raw/main/validated/vision/classification/caffenet/model/caffenet-12.onnx" + val synsetUrl = "https://s3.amazonaws.com/onnx-model-zoo/synset.txt" + val modelName = "CaffeNet" + + println("Downloading model from: $modelUrl") + downloadFile(modelUrl, "$modelName.onnx") + println("Downloading synset from: $synsetUrl") + downloadFile(synsetUrl, "synset.txt") + + val classLabels = File("$cacheDirectory/synset.txt").readLines() + + println("Loading model...") + val model = ORTEngine.loadModel("$cacheDirectory/$modelName.onnx".toPath()) + println("Creating inputs...") + val inputTensors = createInputs() + + println("Starting inference...") + inputTensors.forEach { dataClass -> + dataClass.value.forEach { tensor -> + val actualOutputs = model.predict(listOf(tensor)) + val predictions = actualOutputs[OUTPUT_TENSOR_NAME]!! as ORTTensor + displayTopPredictions(predictions, classLabels, dataClass.key) + } + } +} diff --git a/examples/src/jvmMain/kotlin/io/kinference/examples/lm/KIGPT2Main.kt b/examples/src/jvmMain/kotlin/io/kinference/examples/lm/KIGPT2Main.kt new file mode 100644 index 000000000..81e106ee2 --- /dev/null +++ b/examples/src/jvmMain/kotlin/io/kinference/examples/lm/KIGPT2Main.kt @@ -0,0 +1,58 @@ +package io.kinference.examples.lm + +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer +import io.kinference.core.KIEngine +import io.kinference.core.data.tensor.asTensor +import io.kinference.examples.downloadFile +import io.kinference.examples.extractTopToken +import io.kinference.examples.cacheDirectory +import io.kinference.ndarray.arrays.LongNDArray +import io.kinference.ndarray.arrays.NDArrayCore +import io.kinference.utils.CommonDataLoader +import io.kinference.utils.PredictionConfigs +import io.kinference.utils.inlines.InlineInt +import okio.Path.Companion.toPath + +// Constants for input and output tensor names used in the GPT-2 model +private const val INPUT_TENSOR_NAME = "input1" +private const val OUTPUT_TENSOR_NAME = "output1" // We use only logits tensor + +suspend fun main() { + val modelUrl = "https://github.com/onnx/models/raw/main/validated/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.onnx" + val modelName = "gpt2-lm-head-10" + + println("Downloading model from: $modelUrl") + downloadFile(modelUrl, "$modelName.onnx") //GPT-2 from model zoo is around 650 Mb, adjust your timeout if needed + + println("Loading model...") + val model = KIEngine.loadModel("$cacheDirectory/$modelName.onnx".toPath(), optimize = true, predictionConfig = PredictionConfigs.DefaultAutoAllocator) + + val tokenizer = HuggingFaceTokenizer.newInstance("gpt2", mapOf("modelMaxLength" to "1024")) + val testString = "Neurogenesis is most active during embryonic development and is responsible for producing " + + "all the various types of neurons of the organism, but it continues throughout adult life " + + "in a variety of organisms. Once born, neurons do not divide (see mitosis), and many will " + + "live the lifespan of the animal, except under extraordinary and usually pathogenic circumstances." + val encoded = tokenizer.encode(testString) + val tokens = encoded.ids + val tokensSize = tokens.size + + val predictionLength = 34 + val outputTokens = LongArray(predictionLength) { 0 } + + val input = LongNDArray(1, tokensSize) { idx: InlineInt -> tokens[idx.value] }.unsqueeze(0) + var currentContext = input.clone() + + print("Here goes the test text for generation:\n$testString") + + for (idx in 0 until predictionLength) { + val inputTensor = listOf((currentContext as NDArrayCore).asTensor(INPUT_TENSOR_NAME)) + val output = model.predict(inputTensor) + + outputTokens[idx] = extractTopToken(output, tokensSize + idx, OUTPUT_TENSOR_NAME) + + val newTokenArray = LongNDArray(1, 1) { _: InlineInt -> outputTokens[idx] } + currentContext = currentContext.concat(listOf(newTokenArray.unsqueeze(0)), axis = -1) + print(tokenizer.decode(longArrayOf(outputTokens[idx]))) + } + println("\n\nDone") +} diff --git a/examples/src/jvmMain/kotlin/io/kinference/examples/lm/ORTGPT2Main.kt b/examples/src/jvmMain/kotlin/io/kinference/examples/lm/ORTGPT2Main.kt new file mode 100644 index 000000000..dd0634131 --- /dev/null +++ b/examples/src/jvmMain/kotlin/io/kinference/examples/lm/ORTGPT2Main.kt @@ -0,0 +1,72 @@ +package io.kinference.examples.lm + +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer +import io.kinference.core.data.tensor.KITensor +import io.kinference.core.data.tensor.asTensor +import io.kinference.examples.downloadFile +import io.kinference.examples.extractTopToken +import io.kinference.examples.cacheDirectory +import io.kinference.ndarray.arrays.FloatNDArray +import io.kinference.ndarray.arrays.FloatNDArray.Companion.invoke +import io.kinference.ort.ORTData +import io.kinference.ort.ORTEngine +import io.kinference.ort.data.tensor.ORTTensor +import io.kinference.utils.CommonDataLoader +import io.kinference.utils.inlines.InlineInt +import io.kinference.utils.toIntArray +import okio.Path.Companion.toPath + +// Constants for input and output tensor names used in the GPT-2 model +private const val INPUT_TENSOR_NAME = "input1" +private const val OUTPUT_TENSOR_NAME = "output1" // We use only logits tensor + +suspend fun main() { + val modelUrl = "https://github.com/onnx/models/raw/main/validated/text/machine_comprehension/gpt-2/model/gpt2-lm-head-10.onnx" + val modelName = "gpt2-lm-head-10" + + println("Downloading model from: $modelUrl") + downloadFile(modelUrl, "$modelName.onnx") //GPT-2 from model zoo is around 650 Mb, adjust your timeout if needed + + println("Loading model...") + val model = ORTEngine.loadModel("$cacheDirectory/$modelName.onnx".toPath()) + + val tokenizer = HuggingFaceTokenizer.newInstance("gpt2", mapOf("modelMaxLength" to "1024")) + val testString = "Neurogenesis is most active during embryonic development and is responsible for producing " + + "all the various types of neurons of the organism, but it continues throughout adult life " + + "in a variety of organisms. Once born, neurons do not divide (see mitosis), and many will " + + "live the lifespan of the animal, except under extraordinary and usually pathogenic circumstances." + val encoded = tokenizer.encode(testString) + val tokens = encoded.ids + val tokensSize = tokens.size + + val predictionLength = 34 + val outputTokens = LongArray(predictionLength) { 0 } + + val input = ORTTensor(tokens, longArrayOf(1, 1, tokensSize.toLong())) + var currentContext = input.clone(INPUT_TENSOR_NAME) + + print("Here goes the test text for generation:\n$testString") + + for (idx in 0 until predictionLength) { + val inputTensor = listOf(currentContext) + val output = model.predict(inputTensor) + + outputTokens[idx] = extractTopToken(convertToKITensorMap(output), tokensSize + idx, OUTPUT_TENSOR_NAME) + + val newTokenArray = tokens + outputTokens.slice(IntRange(0, idx)) + currentContext = ORTTensor(newTokenArray, longArrayOf(1, 1, tokensSize + idx + 1L), INPUT_TENSOR_NAME) + print(tokenizer.decode(longArrayOf(outputTokens[idx]))) + } + println("\n\nDone") +} + +private suspend fun convertToKITensorMap(outputs: Map>): Map { + return outputs.map { (name, ortTensor) -> + val ortTensor = ortTensor as ORTTensor + val data = ortTensor.toFloatArray() + val shape = ortTensor.shape.toIntArray() + val ndArray = FloatNDArray(shape) { idx: InlineInt -> data[idx.value] } + val kiTensor = ndArray.asTensor(name) + return@map name to kiTensor + }.toMap() +} diff --git a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/Gemm.kt b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/Gemm.kt index ed2646f2c..5a9c7558f 100644 --- a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/Gemm.kt +++ b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/Gemm.kt @@ -15,24 +15,65 @@ import io.kinference.protobuf.message.AttributeProto import io.kinference.protobuf.message.TensorProto sealed class Gemm(name: String, info: OperatorInfo, attributes: Map>, inputs: List, outputs: List) : Operator(name, info, attributes, inputs, outputs) { + private val alpha: Double by attribute { it: Number -> it.toDouble() } + private val beta: Double by attribute { it: Number -> it.toDouble() } + + private val transA: Boolean by attribute { it: Number -> it.toInt() != 0 } + private val transB: Boolean by attribute { it: Number -> it.toInt() != 0 } + companion object { private val DEFAULT_VERSION = VersionInfo(sinceVersion = 11) operator fun invoke(name: String, version: Int?, attributes: Map>, inputs: List, outputs: List) = when (version ?: DEFAULT_VERSION.sinceVersion) { + in GemmVer9.VERSION.asRange() -> GemmVer9(name, attributes, inputs, outputs) in GemmVer11.VERSION.asRange() -> GemmVer11(name, attributes, inputs, outputs) else -> error("Unsupported version of Gemm operator: $version") } } -} + protected suspend fun getDest(array: NDArrayCore, type: DataType, targetShape: IntArray): MutableNDArrayCore { + if (array.shape.contentEquals(targetShape)) return array.toMutable() -class GemmVer11(name: String, attributes: Map>, inputs: List, outputs: List) : Gemm(name, INFO, attributes, inputs, outputs) { - private val alpha: Double by attribute { it: Number -> it.toDouble() } - private val beta: Double by attribute { it: Number -> it.toDouble() } + val dstArray = allocateNDArray(type, Strides(targetShape)) as MutableNumberNDArrayCore + val unsqueezedShape = unsqueezeFirst(array.shape, targetShape.size) - private val transA: Boolean by attribute { it: Number -> it.toInt() != 0 } - private val transB: Boolean by attribute { it: Number -> it.toInt() != 0 } + if (targetShape[1] != unsqueezedShape[1] && unsqueezedShape[1] == 1) { + val targetBlockSize = targetShape[1] + for (i in 0 until unsqueezedShape[0]) { + val dstOffsetBase = i * targetBlockSize + dstArray.fillByArrayValue(array, i, dstOffsetBase, dstOffsetBase + targetBlockSize) + } + } else { + dstArray.copyFrom(0, array) + } + + for (i in 1 until targetShape[0]) dstArray.copyFrom(i * targetShape[1], dstArray, 0, targetShape[1]) + return dstArray + } + + protected suspend fun > apply(inputs: List, optionalBias: Boolean): List { + val a = inputs[0]!!.data as NumberNDArrayCore + val b = inputs[1]!!.data as NumberNDArrayCore + val m = if (!transA) a.shape[0] else a.shape[1] + val n = if (!transB) b.shape[1] else b.shape[0] + val k = if (!transA) a.shape[1] else a.shape[0] + + val targetShape = intArrayOf(m, n) + val bias = if (optionalBias) { + inputs.getOrNull(2)?.data ?: allocateNDArray(a.type, targetShape) + } else { + inputs[2]!!.data + } as NumberNDArrayCore + + val c = getDest(bias, a.type, intArrayOf(m, n)) + gemm(m, n, k, alpha, a, b, beta, c, transposeA = transA, transposeB = transB) + + return listOf(c.asTensor()) + } +} + +class GemmVer9(name: String, attributes: Map>, inputs: List, outputs: List) : Gemm(name, INFO, attributes, inputs, outputs) { companion object { private val TYPE_CONSTRAINTS = setOf( TensorProto.DataType.FLOAT16, @@ -55,47 +96,53 @@ class GemmVer11(name: String, attributes: Map>, inputs: L private val INPUTS_INFO = listOf( IOInfo(0, TYPE_CONSTRAINTS, "A", optional = false), IOInfo(1, TYPE_CONSTRAINTS, "B", optional = false), - IOInfo(2, TYPE_CONSTRAINTS, "C", optional = true) + IOInfo(2, TYPE_CONSTRAINTS, "C", optional = false) ) private val OUTPUTS_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "Y", optional = false)) - internal val VERSION = VersionInfo(sinceVersion = 11) + internal val VERSION = VersionInfo(sinceVersion = 9, untilVersion = 11) private val INFO = OperatorInfo("Gemm", ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN) - - private suspend fun getDest(array: NDArrayCore?, type: DataType, targetShape: IntArray): MutableNDArrayCore { - if (array == null) return allocateNDArray(type, Strides(targetShape)) - if (array.shape.contentEquals(targetShape)) return array.toMutable() - - val dstArray = allocateNDArray(type, Strides(targetShape)) as MutableNumberNDArrayCore - val unsqueezedShape = unsqueezeFirst(array.shape, targetShape.size) - - if (targetShape[1] != unsqueezedShape[1] && unsqueezedShape[1] == 1) { - val targetBlockSize = targetShape[1] - for (i in 0 until unsqueezedShape[0]) { - val dstOffsetBase = i * targetBlockSize - dstArray.fillByArrayValue(array, i, dstOffsetBase, dstOffsetBase + targetBlockSize) - } - } else { - dstArray.copyFrom(0, array) - } - - for (i in 1 until targetShape[0]) dstArray.copyFrom(i * targetShape[1], dstArray, 0, targetShape[1]) - return dstArray - } } override suspend fun > apply(contexts: Contexts, inputs: List): List { - val a = inputs[0]!!.data as NumberNDArrayCore - val b = inputs[1]!!.data as NumberNDArrayCore + return apply>(inputs, INPUTS_INFO[2].optional) + } +} - val m = if (!transA) a.shape[0] else a.shape[1] - val n = if (!transB) b.shape[1] else b.shape[0] - val k = if (!transA) a.shape[1] else a.shape[0] +class GemmVer11(name: String, attributes: Map>, inputs: List, outputs: List) : Gemm(name, INFO, attributes, inputs, outputs) { + companion object { + private val TYPE_CONSTRAINTS = setOf( + TensorProto.DataType.FLOAT16, + TensorProto.DataType.FLOAT, + TensorProto.DataType.DOUBLE, + TensorProto.DataType.UINT32, + TensorProto.DataType.UINT64, + TensorProto.DataType.INT32, + TensorProto.DataType.INT64, + TensorProto.DataType.BFLOAT16 + ) - val c = getDest(inputs.getOrNull(2)?.data, a.type, intArrayOf(m, n)) - gemm(m, n, k, alpha, a, b, beta, c, transposeA = transA, transposeB = transB) + private val ATTRIBUTES_INFO = listOf( + AttributeInfo("alpha", setOf(AttributeProto.AttributeType.FLOAT), false, 1.0), + AttributeInfo("beta", setOf(AttributeProto.AttributeType.FLOAT), false, 1.0), + AttributeInfo("transA", setOf(AttributeProto.AttributeType.INT), false, 0), + AttributeInfo("transB", setOf(AttributeProto.AttributeType.INT), false, 0) + ) - return listOf(c.asTensor()) + private val INPUTS_INFO = listOf( + IOInfo(0, TYPE_CONSTRAINTS, "A", optional = false), + IOInfo(1, TYPE_CONSTRAINTS, "B", optional = false), + IOInfo(2, TYPE_CONSTRAINTS, "C", optional = true) + ) + + private val OUTPUTS_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "Y", optional = false)) + + internal val VERSION = VersionInfo(sinceVersion = 11) + private val INFO = OperatorInfo("Gemm", ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN) + } + + override suspend fun > apply(contexts: Contexts, inputs: List): List { + return apply>(inputs, INPUTS_INFO[2].optional) } } diff --git a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/LRN.kt b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/LRN.kt index 0ff15f088..1bae19370 100644 --- a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/LRN.kt +++ b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/math/LRN.kt @@ -16,17 +16,17 @@ import io.kinference.protobuf.message.TensorProto sealed class LRN(name: String, info: OperatorInfo, attributes: Map>, inputs: List, outputs: List) : Operator(name, info, attributes, inputs, outputs) { companion object { - private val DEFAULT_VERSION = VersionInfo(sinceVersion = 13) // last version. Other versions: 1. + private val DEFAULT_VERSION = VersionInfo(sinceVersion = 1) operator fun invoke(name: String, version: Int?, attributes: Map>, inputs: List, outputs: List) = when (version ?: DEFAULT_VERSION.sinceVersion) { - in LRN13.VERSION.asRange() -> LRN13(name, attributes, inputs, outputs) + in LRN1.VERSION.asRange() -> LRN1(name, attributes, inputs, outputs) else -> error("Unsupported version of LRN operator: $version") } } } -class LRN13(name: String, attributes: Map>, inputs: List, outputs: List) : +class LRN1(name: String, attributes: Map>, inputs: List, outputs: List) : LRN(name, INFO, attributes, inputs, outputs) { companion object { private val TYPE_CONSTRAINTS = setOf( @@ -51,7 +51,7 @@ class LRN13(name: String, attributes: Map>, inputs: List< IOInfo(0, TYPE_CONSTRAINTS, "Y", optional = false, differentiable = true) ) - internal val VERSION = VersionInfo(sinceVersion = 13) + internal val VERSION = VersionInfo(sinceVersion = 1) private val INFO = OperatorInfo("LRN", ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN) } diff --git a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/tensor/Dropout.kt b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/tensor/Dropout.kt index 904fb01be..4b3c13ddf 100644 --- a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/tensor/Dropout.kt +++ b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/tensor/Dropout.kt @@ -14,17 +14,17 @@ import io.kinference.utils.inlines.InlineInt sealed class Dropout(name: String, info: OperatorInfo, attributes: Map>, inputs: List, outputs: List) : Operator(name, info, attributes, inputs, outputs) { companion object { - private val DEFAULT_VERSION = VersionInfo(sinceVersion = 13) // last version. Other versions: 1, 6, 7, 10, 12. + private val DEFAULT_VERSION = VersionInfo(sinceVersion = 12) // last version. Other versions: 1, 6, 7, 10. operator fun invoke(name: String, version: Int?, attributes: Map>, inputs: List, outputs: List) = when (version ?: DEFAULT_VERSION.sinceVersion) { - in Dropout13.VERSION.asRange() -> Dropout13(name, attributes, inputs, outputs) + in Dropout12.VERSION.asRange() -> Dropout12(name, attributes, inputs, outputs) else -> error("Unsupported version of Dropout operator: $version") } } } -class Dropout13(name: String, attributes: Map>, inputs: List, outputs: List) : +class Dropout12(name: String, attributes: Map>, inputs: List, outputs: List) : Dropout(name, INFO, attributes, inputs, outputs) { companion object { private val TYPE_CONSTRAINTS_T = setOf( @@ -59,7 +59,7 @@ class Dropout13(name: String, attributes: Map>, inputs: L IOInfo(1, TYPE_CONSTRAINTS_T2, "mask", optional = true, differentiable = false) ) - internal val VERSION = VersionInfo(sinceVersion = 13) + internal val VERSION = VersionInfo(sinceVersion = 12) private val INFO = OperatorInfo("Dropout", ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN) } diff --git a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/tensor/Squeeze.kt b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/tensor/Squeeze.kt index 38d521038..e329b239f 100644 --- a/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/tensor/Squeeze.kt +++ b/inference/inference-core/src/jvmMain/kotlin/io/kinference.core/operators/tensor/Squeeze.kt @@ -5,6 +5,7 @@ import io.kinference.core.data.tensor.KITensor import io.kinference.core.data.tensor.asTensor import io.kinference.data.ONNXData import io.kinference.graph.Contexts +import io.kinference.ndarray.arrays.LongNDArray import io.kinference.operator.* import io.kinference.protobuf.message.AttributeProto import io.kinference.utils.toIntArray @@ -15,7 +16,8 @@ sealed class Squeeze(name: String, info: OperatorInfo, attributes: Map>, inputs: List, outputs: List) = when (version ?: DEFAULT_VERSION.sinceVersion) { in SqueezeVer1.VERSION.asRange() -> SqueezeVer1(name, attributes, inputs, outputs) - else -> error("Unsupported version of Constant operator: $version") + in SqueezeVer13.VERSION.asRange() -> SqueezeVer13(name, attributes, inputs, outputs) + else -> error("Unsupported version of Squeeze operator: $version") } } } @@ -44,3 +46,29 @@ class SqueezeVer1(name: String, attributes: Map>, inputs: return listOf(inputs.first()!!.data.toMutable().squeeze(*squeezeAxes).asTensor()) } } + +class SqueezeVer13(name: String, attributes: Map>, inputs: List, outputs: List) : Squeeze(name, INFO, attributes, inputs, outputs) { + companion object { + private val TYPE_CONSTRAINTS = ALL_DATA_TYPES + + private val ATTRIBUTES_INFO = listOf() + + private val INPUTS_INFO = listOf( + IOInfo(0, TYPE_CONSTRAINTS, "data", optional = false), + IOInfo(1, INT_DATA_TYPES, "axes", optional = true) + ) + + private val OUTPUTS_INFO = listOf(IOInfo(0, TYPE_CONSTRAINTS, "squeezed", optional = false)) + + internal val VERSION = VersionInfo(sinceVersion = 13, untilVersion = 21) + private val INFO = OperatorInfo("Squeeze", ATTRIBUTES_INFO, INPUTS_INFO, OUTPUTS_INFO, VERSION, OperatorInfo.DEFAULT_DOMAIN) + } + + override suspend fun > apply(contexts: Contexts, inputs: List): List { + val axes = inputs.getOrNull(1)?.data as? LongNDArray? + if (axes != null && axes.rank != 1) error("Axes attribute must be an 1D tensor") + + val squeezeAxes = axes?.array?.toArray()?.toIntArray() ?: IntArray(0) + return listOf(inputs.first()!!.data.toMutable().squeeze(*squeezeAxes).asTensor()) + } +} diff --git a/inference/inference-core/src/jvmTest/kotlin/io/kinference/operators/operations/SqueezeTest.kt b/inference/inference-core/src/jvmTest/kotlin/io/kinference/operators/operations/SqueezeTest.kt index b0856866d..f4d9981ea 100644 --- a/inference/inference-core/src/jvmTest/kotlin/io/kinference/operators/operations/SqueezeTest.kt +++ b/inference/inference-core/src/jvmTest/kotlin/io/kinference/operators/operations/SqueezeTest.kt @@ -4,8 +4,22 @@ import io.kinference.KITestEngine.KIAccuracyRunner import io.kinference.utils.TestRunner import kotlin.test.Test -class SqueezeTest { - private fun getTargetPath(dirName: String) = "squeeze/$dirName/" +class SqueezeVer1Test { + private fun getTargetPath(dirName: String) = "squeeze/v1/$dirName/" + + @Test + fun test_squeeze() = TestRunner.runTest { + KIAccuracyRunner.runFromResources(getTargetPath("test_squeeze")) + } + + @Test + fun test_squeeze_with_negative_axes() = TestRunner.runTest { + KIAccuracyRunner.runFromResources(getTargetPath("test_squeeze_negative_axes")) + } +} + +class SqueezeVer13Test { + private fun getTargetPath(dirName: String) = "squeeze/v13/$dirName/" @Test fun test_squeeze() = TestRunner.runTest { diff --git a/settings.gradle.kts b/settings.gradle.kts index a0fda8249..2b93d5119 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -28,6 +28,8 @@ include(":adapters:kmath:adapter-kmath-core") include(":adapters:kmath:adapter-kmath-ort") include(":adapters:kmath:adapter-kmath-ort-gpu") +include(":examples") + pluginManagement { repositories { diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/test_squeeze/descriptor.txt b/utils/utils-testing/src/commonMain/resources/squeeze/v1/test_squeeze/descriptor.txt similarity index 100% rename from utils/utils-testing/src/commonMain/resources/squeeze/test_squeeze/descriptor.txt rename to utils/utils-testing/src/commonMain/resources/squeeze/v1/test_squeeze/descriptor.txt diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/test_squeeze/model.onnx b/utils/utils-testing/src/commonMain/resources/squeeze/v1/test_squeeze/model.onnx similarity index 100% rename from utils/utils-testing/src/commonMain/resources/squeeze/test_squeeze/model.onnx rename to utils/utils-testing/src/commonMain/resources/squeeze/v1/test_squeeze/model.onnx diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/test_squeeze/test_data_set_0/input_0.pb b/utils/utils-testing/src/commonMain/resources/squeeze/v1/test_squeeze/test_data_set_0/input_0.pb similarity index 100% rename from utils/utils-testing/src/commonMain/resources/squeeze/test_squeeze/test_data_set_0/input_0.pb rename to utils/utils-testing/src/commonMain/resources/squeeze/v1/test_squeeze/test_data_set_0/input_0.pb diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/test_squeeze/test_data_set_0/output_0.pb b/utils/utils-testing/src/commonMain/resources/squeeze/v1/test_squeeze/test_data_set_0/output_0.pb similarity index 100% rename from utils/utils-testing/src/commonMain/resources/squeeze/test_squeeze/test_data_set_0/output_0.pb rename to utils/utils-testing/src/commonMain/resources/squeeze/v1/test_squeeze/test_data_set_0/output_0.pb diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/test_squeeze_negative_axes/descriptor.txt b/utils/utils-testing/src/commonMain/resources/squeeze/v1/test_squeeze_negative_axes/descriptor.txt similarity index 100% rename from utils/utils-testing/src/commonMain/resources/squeeze/test_squeeze_negative_axes/descriptor.txt rename to utils/utils-testing/src/commonMain/resources/squeeze/v1/test_squeeze_negative_axes/descriptor.txt diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/test_squeeze_negative_axes/model.onnx b/utils/utils-testing/src/commonMain/resources/squeeze/v1/test_squeeze_negative_axes/model.onnx similarity index 100% rename from utils/utils-testing/src/commonMain/resources/squeeze/test_squeeze_negative_axes/model.onnx rename to utils/utils-testing/src/commonMain/resources/squeeze/v1/test_squeeze_negative_axes/model.onnx diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/test_squeeze_negative_axes/test_data_set_0/input_0.pb b/utils/utils-testing/src/commonMain/resources/squeeze/v1/test_squeeze_negative_axes/test_data_set_0/input_0.pb similarity index 100% rename from utils/utils-testing/src/commonMain/resources/squeeze/test_squeeze_negative_axes/test_data_set_0/input_0.pb rename to utils/utils-testing/src/commonMain/resources/squeeze/v1/test_squeeze_negative_axes/test_data_set_0/input_0.pb diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/test_squeeze_negative_axes/test_data_set_0/output_0.pb b/utils/utils-testing/src/commonMain/resources/squeeze/v1/test_squeeze_negative_axes/test_data_set_0/output_0.pb similarity index 100% rename from utils/utils-testing/src/commonMain/resources/squeeze/test_squeeze_negative_axes/test_data_set_0/output_0.pb rename to utils/utils-testing/src/commonMain/resources/squeeze/v1/test_squeeze_negative_axes/test_data_set_0/output_0.pb diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze/descriptor.txt b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze/descriptor.txt new file mode 100644 index 000000000..d0d4ef393 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze/descriptor.txt @@ -0,0 +1,3 @@ +test_data_set_0/input_0.pb +test_data_set_0/input_1.pb +test_data_set_0/output_0.pb diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze/model.onnx b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze/model.onnx new file mode 100644 index 000000000..26798f82e --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze/model.onnx @@ -0,0 +1,19 @@ + backend-test:o + +x +axesy"Squeeze test_squeezeZ +x + + + + +Z +axes + + +b +y + + + +B \ No newline at end of file diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze/test_data_set_0/input_0.pb b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze/test_data_set_0/input_0.pb new file mode 100644 index 000000000..09d5a14a4 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze/test_data_set_0/input_0.pb @@ -0,0 +1 @@ +BxJx?h>z?j@$ ?.z8s?bhdӽ9>(>%?^B?0= B>]ת>=?RiJ>Z/d#S'?K]?=C@(Hm;= ?2??>>Ec! >*z??Oƾmǚ6&õgڿ?xFKྙ[ G?4οYL=e> kQN>.:=ݚ>b"6 \ No newline at end of file diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze/test_data_set_0/input_1.pb b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze/test_data_set_0/input_1.pb new file mode 100644 index 000000000..ec9874a7b Binary files /dev/null and b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze/test_data_set_0/input_1.pb differ diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze/test_data_set_0/output_0.pb b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze/test_data_set_0/output_0.pb new file mode 100644 index 000000000..ad67ee7f4 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze/test_data_set_0/output_0.pb @@ -0,0 +1 @@ +ByJx?h>z?j@$ ?.z8s?bhdӽ9>(>%?^B?0= B>]ת>=?RiJ>Z/d#S'?K]?=C@(Hm;= ?2??>>Ec! >*z??Oƾmǚ6&õgڿ?xFKྙ[ G?4οYL=e> kQN>.:=ݚ>b"6 \ No newline at end of file diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze_negative_axes/descriptor.txt b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze_negative_axes/descriptor.txt new file mode 100644 index 000000000..d0d4ef393 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze_negative_axes/descriptor.txt @@ -0,0 +1,3 @@ +test_data_set_0/input_0.pb +test_data_set_0/input_1.pb +test_data_set_0/output_0.pb diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze_negative_axes/model.onnx b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze_negative_axes/model.onnx new file mode 100644 index 000000000..e148d48ca --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze_negative_axes/model.onnx @@ -0,0 +1,19 @@ + backend-test:} + +x +axesy"Squeezetest_squeeze_negative_axesZ +x + + + + +Z +axes + + +b +y + + + +B \ No newline at end of file diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze_negative_axes/test_data_set_0/input_0.pb b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze_negative_axes/test_data_set_0/input_0.pb new file mode 100644 index 000000000..55c2bdcf0 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze_negative_axes/test_data_set_0/input_0.pb @@ -0,0 +1 @@ +BxJz?j@$ ?.z8s?bhdӽ9>(>%?^B?0= B> \ No newline at end of file diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze_negative_axes/test_data_set_0/input_1.pb b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze_negative_axes/test_data_set_0/input_1.pb new file mode 100644 index 000000000..2f4bbd39a --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze_negative_axes/test_data_set_0/input_1.pb @@ -0,0 +1 @@ +BaxesJ \ No newline at end of file diff --git a/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze_negative_axes/test_data_set_0/output_0.pb b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze_negative_axes/test_data_set_0/output_0.pb new file mode 100644 index 000000000..ed4b2d0f6 --- /dev/null +++ b/utils/utils-testing/src/commonMain/resources/squeeze/v13/test_squeeze_negative_axes/test_data_set_0/output_0.pb @@ -0,0 +1 @@ +ByJz?j@$ ?.z8s?bhdӽ9>(>%?^B?0= B> \ No newline at end of file