Skip to content

Commit

Permalink
JBAI-4915 [core] cache clearing after each cycle of testing
Browse files Browse the repository at this point in the history
  • Loading branch information
cupertank committed Jun 18, 2024
1 parent f6d9269 commit c84525f
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import io.kinference.runners.AccuracyRunner
import io.kinference.runners.PerformanceRunner
import io.kinference.utils.*

object KITestEngine : TestEngine<KIONNXData<*>>(KIEngine) {
object KITestEngine : TestEngine<KIONNXData<*>>(KIEngine), Cacheable {
override fun checkEquals(expected: KIONNXData<*>, actual: KIONNXData<*>, delta: Double) {
KIAssertions.assertEquals(expected, actual, delta)
}
Expand All @@ -26,6 +26,10 @@ object KITestEngine : TestEngine<KIONNXData<*>>(KIEngine) {
}
}

override fun clearCache() {
KIEngine.clearCache()
}

val KIAccuracyRunner = AccuracyRunner(KITestEngine)
val KIPerformanceRunner = PerformanceRunner(KITestEngine)
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ interface MemoryProfileable {
fun allocatedMemory(): Int
}

interface Cacheable {
fun clearCache()
}

expect object TestLoggerFactory {
fun create(name: String): KILogger
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,13 @@ class AccuracyRunner<T : ONNXData<*, *>>(private val testEngine: TestEngine<T>)
}
inputs.forEach { it.close() }
expectedOutputs.forEach { it.close() }

if (testEngine is Cacheable) {
testEngine.clearCache()
}
}
model.close()

if (testEngine is MemoryProfileable) {
assertEquals(0, testEngine.allocatedMemory(), "Memory leak found after model dispose")
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package io.kinference.runners

import io.kinference.TestEngine
import io.kinference.TestLoggerFactory
import io.kinference.*
import io.kinference.data.ONNXData
import io.kinference.data.ONNXDataType
import io.kinference.model.Model
Expand Down Expand Up @@ -103,6 +102,10 @@ class PerformanceRunner<T : ONNXData<*, *>>(private val engine: TestEngine<T>) {
}

inputs.forEach { it.close() }

if (engine is Cacheable) {
engine.clearCache()
}
}

model.close()
Expand Down

0 comments on commit c84525f

Please sign in to comment.