diff --git a/inference/inference-core/src/commonTest/kotlin/io/kinference/KITestEngine.kt b/inference/inference-core/src/commonTest/kotlin/io/kinference/KITestEngine.kt index 506c381e2..6e2366737 100644 --- a/inference/inference-core/src/commonTest/kotlin/io/kinference/KITestEngine.kt +++ b/inference/inference-core/src/commonTest/kotlin/io/kinference/KITestEngine.kt @@ -9,7 +9,7 @@ import io.kinference.runners.AccuracyRunner import io.kinference.runners.PerformanceRunner import io.kinference.utils.* -object KITestEngine : TestEngine>(KIEngine) { +object KITestEngine : TestEngine>(KIEngine), Cacheable { override fun checkEquals(expected: KIONNXData<*>, actual: KIONNXData<*>, delta: Double) { KIAssertions.assertEquals(expected, actual, delta) } @@ -26,6 +26,10 @@ object KITestEngine : TestEngine>(KIEngine) { } } + override fun clearCache() { + KIEngine.clearCache() + } + val KIAccuracyRunner = AccuracyRunner(KITestEngine) val KIPerformanceRunner = PerformanceRunner(KITestEngine) } diff --git a/utils/utils-testing/src/commonMain/kotlin/io.kinference/TestEngine.kt b/utils/utils-testing/src/commonMain/kotlin/io.kinference/TestEngine.kt index 3f0c7bf05..559e7ba0c 100644 --- a/utils/utils-testing/src/commonMain/kotlin/io.kinference/TestEngine.kt +++ b/utils/utils-testing/src/commonMain/kotlin/io.kinference/TestEngine.kt @@ -23,6 +23,10 @@ interface MemoryProfileable { fun allocatedMemory(): Int } +interface Cacheable { + fun clearCache() +} + expect object TestLoggerFactory { fun create(name: String): KILogger } diff --git a/utils/utils-testing/src/commonMain/kotlin/io.kinference/runners/AccuracyRunner.kt b/utils/utils-testing/src/commonMain/kotlin/io.kinference/runners/AccuracyRunner.kt index 98ad2caec..f090dfca3 100644 --- a/utils/utils-testing/src/commonMain/kotlin/io.kinference/runners/AccuracyRunner.kt +++ b/utils/utils-testing/src/commonMain/kotlin/io.kinference/runners/AccuracyRunner.kt @@ -103,8 +103,13 @@ class AccuracyRunner>(private val testEngine: TestEngine) } inputs.forEach { it.close() } expectedOutputs.forEach { it.close() } + + if (testEngine is Cacheable) { + testEngine.clearCache() + } } model.close() + if (testEngine is MemoryProfileable) { assertEquals(0, testEngine.allocatedMemory(), "Memory leak found after model dispose") } diff --git a/utils/utils-testing/src/commonMain/kotlin/io.kinference/runners/PerformanceRunner.kt b/utils/utils-testing/src/commonMain/kotlin/io.kinference/runners/PerformanceRunner.kt index 4b6ac1537..8210af0b5 100644 --- a/utils/utils-testing/src/commonMain/kotlin/io.kinference/runners/PerformanceRunner.kt +++ b/utils/utils-testing/src/commonMain/kotlin/io.kinference/runners/PerformanceRunner.kt @@ -1,7 +1,6 @@ package io.kinference.runners -import io.kinference.TestEngine -import io.kinference.TestLoggerFactory +import io.kinference.* import io.kinference.data.ONNXData import io.kinference.data.ONNXDataType import io.kinference.model.Model @@ -103,6 +102,10 @@ class PerformanceRunner>(private val engine: TestEngine) { } inputs.forEach { it.close() } + + if (engine is Cacheable) { + engine.clearCache() + } } model.close()