From f64f61c1c55c2a629f22c0dd1e8f77b1b9c409b7 Mon Sep 17 00:00:00 2001 From: therealansh Date: Fri, 31 May 2024 13:32:18 +0530 Subject: [PATCH] fix(tensor): add validation to check if input can be converted to desired tensor shape --- .../kotlinx/dl/api/inference/TensorFlowInferenceModel.kt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/TensorFlowInferenceModel.kt b/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/TensorFlowInferenceModel.kt index 691b711be..4b3b4d7f4 100644 --- a/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/TensorFlowInferenceModel.kt +++ b/tensorflow/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/TensorFlowInferenceModel.kt @@ -111,6 +111,9 @@ public open class TensorFlowInferenceModel( internal fun FloatData.toTensor(): Tensor { val preparedData = serializeToBuffer(floats) + require(preparedData.remaining() == shape.dims().reduce(Long::times).toInt()) { + // TODO: add more details about the shape and how to fix it. + } return Tensor.create(longArrayOf(1L, *shape.dims()), preparedData) }