Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix crash on DataNull value supplied to python function #108

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import io.axual.ksml.data.exception.ExecutionException;
import io.axual.ksml.data.mapper.DataTypeSchemaMapper;
import io.axual.ksml.data.object.DataNull;
import io.axual.ksml.data.object.DataObject;
import io.axual.ksml.data.object.DataStruct;
import io.axual.ksml.data.schema.DataField;
Expand Down Expand Up @@ -55,6 +56,10 @@ public DataObjectDeserializer(DataType type, DataTypeSchemaMapper dataTypeDataSc

@Override
public DataObject deserialize(String topic, byte[] data) {
if(data == null || data.length == 0) {
return DataNull.INSTANCE;
}

final var wrapper = jsonDeserializer.deserialize(topic, data);
if (wrapper == null) {
throw new ExecutionException("Retrieved unexpected null wrapper from state store " + topic);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,19 @@
* =========================LICENSE_END==================================
*/

import org.apache.kafka.common.serialization.Serializer;

import java.util.List;

import io.axual.ksml.data.exception.ExecutionException;
import io.axual.ksml.data.mapper.DataSchemaMapper;
import io.axual.ksml.data.object.DataNull;
import io.axual.ksml.data.object.DataObject;
import io.axual.ksml.data.object.DataStruct;
import io.axual.ksml.data.schema.DataField;
import io.axual.ksml.data.schema.StructSchema;
import io.axual.ksml.data.type.DataType;
import lombok.Getter;
import org.apache.kafka.common.serialization.Serializer;

import java.util.List;

import static io.axual.ksml.data.parser.schema.DataSchemaDSL.DATA_OBJECT_TYPE_NAME;
import static io.axual.ksml.data.parser.schema.DataSchemaDSL.DATA_SCHEMA_NAMESPACE;
Expand All @@ -56,6 +58,9 @@ public DataObjectSerializer(DataType type, DataSchemaMapper<DataType> dataTypeSc

@Override
public byte[] serialize(String topic, DataObject data) {
if (data == null || data instanceof DataNull) {
return new byte[0];
}
if (!expectedType.isAssignableFrom(data.type())) {
throw new ExecutionException("Incorrect type passed in: expected=" + expectedType + ", got " + data.type());
}
Expand Down
22 changes: 16 additions & 6 deletions ksml-data/src/main/java/io/axual/ksml/data/serde/UnionSerde.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@
* =========================LICENSE_END==================================
*/

import io.axual.ksml.data.exception.ExecutionException;
import io.axual.ksml.data.object.DataObject;
import io.axual.ksml.data.type.DataType;
import io.axual.ksml.data.type.UnionType;
import org.apache.kafka.common.serialization.Deserializer;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.common.serialization.Serializer;
Expand All @@ -32,6 +28,12 @@
import java.util.List;
import java.util.Map;

import io.axual.ksml.data.exception.ExecutionException;
import io.axual.ksml.data.object.DataNull;
import io.axual.ksml.data.object.DataObject;
import io.axual.ksml.data.type.DataType;
import io.axual.ksml.data.type.UnionType;

public class UnionSerde implements Serde<Object> {
private record PossibleType(DataType type, Serializer<Object> serializer,
Deserializer<Object> deserializer) {
Expand All @@ -52,6 +54,10 @@ public void configure(Map<String, ?> configs, boolean isKey) {

@Override
public byte[] serialize(String topic, Object data) {
if (data == null || data instanceof DataNull) {
return new byte[0];
}

for (PossibleType possibleType : possibleTypes) {
// Check if we are serializing a DataObject. If so, then check compatibility
// using its own data type, else check compatibility with the Java native type.
Expand All @@ -65,7 +71,7 @@ public byte[] serialize(String topic, Object data) {
}
}
}
throw new ExecutionException("Can not serialize object as union alternative: " + (data != null ? data.getClass().getSimpleName() : "null"));
throw new ExecutionException("Can not serialize object as union alternative: " + data.getClass().getSimpleName());
}
}

Expand All @@ -79,6 +85,10 @@ public void configure(Map<String, ?> configs, boolean isKey) {

@Override
public Object deserialize(String topic, byte[] data) {
if (data == null || data.length == 0) {
return DataNull.INSTANCE;
}

for (PossibleType possibleType : possibleTypes) {
try {
Object result = possibleType.deserializer.deserialize(topic, data);
Expand All @@ -89,7 +99,7 @@ public Object deserialize(String topic, byte[] data) {
// Not properly deserialized, so ignore and try next alternative
}
}
throw new ExecutionException("Can not deserialize data as union possible dataType" + (data != null ? data : "null"));
throw new ExecutionException("Can not deserialize data as union possible dataType" + possibleTypes);
}
}

Expand Down
21 changes: 13 additions & 8 deletions ksml/src/main/java/io/axual/ksml/python/PythonFunction.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@
*/


import org.apache.kafka.streams.processor.StateStore;
import org.graalvm.polyglot.Value;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

import io.axual.ksml.data.exception.ExecutionException;
import io.axual.ksml.data.mapper.DataObjectConverter;
import io.axual.ksml.data.object.DataNull;
Expand All @@ -33,14 +42,6 @@
import io.axual.ksml.store.StateStores;
import io.axual.ksml.user.UserFunction;
import lombok.extern.slf4j.Slf4j;
import org.apache.kafka.streams.processor.StateStore;
import org.graalvm.polyglot.Value;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

import static io.axual.ksml.data.notation.UserType.DEFAULT_NOTATION;

Expand Down Expand Up @@ -86,6 +87,10 @@ public DataObject call(StateStores stores, DataObject... parameters) {
}
// Validate the parameter types
for (int index = 0; index < parameters.length; index++) {
if(parameters[index] instanceof DataNull){
// A DataNull should always be acceptable as a function argument
continue;
}
if (!this.parameters[index].type().isAssignableFrom(parameters[index])) {
throw new TopologyException("User function \"" + name + "\" expects parameter " + (index + 1) + " (\"" + this.parameters[index].name() + "\") to be " + this.parameters[index].type() + ", but " + parameters[index].type() + " was passed in");
}
Expand Down
79 changes: 57 additions & 22 deletions ksml/src/test/java/io/axual/ksml/python/PythonFunctionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,20 @@
* =========================LICENSE_END==================================
*/

import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;

import io.axual.ksml.data.notation.UserType;
import io.axual.ksml.data.notation.binary.BinaryNotation;
import io.axual.ksml.data.object.DataInteger;
import io.axual.ksml.data.object.DataObject;
import io.axual.ksml.data.object.DataPrimitive;
import io.axual.ksml.data.object.DataNull;
import io.axual.ksml.data.object.DataString;
import io.axual.ksml.definition.FunctionDefinition;
import io.axual.ksml.definition.ParameterDefinition;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;

public class PythonFunctionTest {
final PythonContext context = new PythonContext();
Expand All @@ -42,14 +45,15 @@ public class PythonFunctionTest {
@ParameterizedTest
@CsvSource({"1, 2, 3", "100,100,200", "100, -1, 99", "99, -100, -1"})
void testAdditionExpression(Integer i1, Integer i2, Integer sum) {
FunctionDefinition adderDef = FunctionDefinition.as("adder", params, null, null, "one + two", resultType, null);
PythonFunction adder = PythonFunction.forFunction(context, "test", "adder", adderDef);
final var adderDef = FunctionDefinition.as("adder", params, null, null, "one + two", resultType, null);
final var adder = PythonFunction.forFunction(context, "test", "adder", adderDef);

DataObject arg1 = new DataInteger(i1);
DataObject arg2 = new DataInteger(i2);
final var arg1 = new DataInteger(i1);
final var arg2 = new DataInteger(i2);

DataPrimitive result = (DataPrimitive) adder.call(arg1, arg2);
assertEquals(sum, result.value());
final var result = adder.call(arg1, arg2);
assertInstanceOf(DataInteger.class, result);
assertEquals(sum, ((DataInteger)result).value());
}

/**
Expand All @@ -63,14 +67,15 @@ def myAddFunc(one, two):
return one + two

""";
FunctionDefinition adderDef = FunctionDefinition.as("adder", params, null, pythonCode.split("\n"), "myAddFunc(one, two)", resultType, null);
PythonFunction adder = PythonFunction.forFunction(context, "test", "adder", adderDef);
final var adderDef = FunctionDefinition.as("adder", params, null, pythonCode.split("\n"), "myAddFunc(one, two)", resultType, null);
final var adder = PythonFunction.forFunction(context, "test", "adder", adderDef);

DataObject arg1 = new DataInteger(i1);
DataObject arg2 = new DataInteger(i2);
final var arg1 = new DataInteger(i1);
final var arg2 = new DataInteger(i2);

DataPrimitive result = (DataPrimitive) adder.call(arg1, arg2);
assertEquals(sum, result.value());
final var result = adder.call(arg1, arg2);
assertInstanceOf(DataInteger.class, result);
assertEquals(sum, ((DataInteger)result).value());
}

/**
Expand All @@ -84,13 +89,43 @@ def myAddFunc(one, two):
return one + two

""";
FunctionDefinition adderDef = FunctionDefinition.as("adder", params, pythonCode.split("\n"), null, "myAddFunc(one, two)", resultType, null);
PythonFunction adder = PythonFunction.forFunction(context, "test", "adder", adderDef);
final var adderDef = FunctionDefinition.as("adder", params, pythonCode.split("\n"), null, "myAddFunc(one, two)", resultType, null);
final var adder = PythonFunction.forFunction(context, "test", "adder", adderDef);

final var arg1 = new DataInteger(i1);
final var arg2 = new DataInteger(i2);

final var result = adder.call(arg1, arg2);
assertInstanceOf(DataInteger.class, result);
assertEquals(sum, ((DataInteger)result).value());
}

@Test
/**
* Test that Null Key/Values are accepted as parameters
*/
void testNullKeyValue() {
final var stringResultType = new UserType(BinaryNotation.NOTATION_NAME, DataString.DATATYPE);
final var concatDef = FunctionDefinition.as("concat", params, null, null, "str(one is None) + ' ' + str(two is None)", stringResultType, null);
final var concat = PythonFunction.forFunction(context, "test", "adder", concatDef);

final var nullArg = DataNull.INSTANCE;
final var nonNullArg = new DataInteger(1);

final var expectedResultNullKey = "True False";
var resultNullKey = concat.call(nullArg, nonNullArg);
assertInstanceOf(DataString.class, resultNullKey);
assertEquals(expectedResultNullKey, ((DataString) resultNullKey).value());

final var expectedResultNullValue = "False True";
var resultNullValue = concat.call(nonNullArg, nullArg);
assertInstanceOf(DataString.class, resultNullValue);
assertEquals(expectedResultNullValue, ((DataString) resultNullValue).value());

DataObject arg1 = new DataInteger(i1);
DataObject arg2 = new DataInteger(i2);
final var expectedResultNullKeyValue = "True True";
var resultNullKeyValue = concat.call(nullArg, nullArg);
assertInstanceOf(DataString.class, resultNullKeyValue);
assertEquals(expectedResultNullKeyValue, ((DataString) resultNullKeyValue).value());

DataPrimitive result = (DataPrimitive) adder.call(arg1, arg2);
assertEquals(sum, result.value());
}
}
Loading