Skip to content

Commit

Permalink
[FLINK-34996][Connectors/Kafka] Allow custom Serializer/Deserializer …
Browse files Browse the repository at this point in the history
…initialization and remove mockito.
  • Loading branch information
hugogu committed Apr 4, 2024
1 parent 68da758 commit 2631db7
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 108 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,10 @@ class KafkaSerializerWrapper<IN> implements SerializationSchema<IN> {
@SuppressWarnings("unchecked")
@Override
public void open(InitializationContext context) throws Exception {
final ClassLoader userCodeClassLoader = context.getUserCodeClassLoader().asClassLoader();
final ClassLoader userCodeClassLoader = selectClassLoader(context);
try (TemporaryClassLoaderContext ignored =
TemporaryClassLoaderContext.of(userCodeClassLoader)) {
serializer =
InstantiationUtil.instantiate(
serializerClass.getName(),
Serializer.class,
userCodeClassLoader);
initializeSerializer(userCodeClassLoader);

if (serializer instanceof Configurable) {
((Configurable) serializer).configure(config);
Expand All @@ -88,4 +84,20 @@ public byte[] serialize(IN element) {
checkState(serializer != null, "Call open() once before trying to serialize elements.");
return serializer.serialize(topicSelector.apply(element), element);
}

/**
* Selects the class loader to be used when instantiating the serializer.
* Using a class loader with user code allows users to customize the serializer.
*/
protected ClassLoader selectClassLoader(InitializationContext context) {
return context.getUserCodeClassLoader().asClassLoader();
}

protected void initializeSerializer(ClassLoader classLoader) throws Exception {
serializer =
InstantiationUtil.instantiate(
serializerClass.getName(),
Serializer.class,
classLoader);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,10 @@ class KafkaValueOnlyDeserializerWrapper<T> implements KafkaRecordDeserialization
@Override
@SuppressWarnings("unchecked")
public void open(DeserializationSchema.InitializationContext context) throws Exception {
ClassLoader userCodeClassLoader = context.getUserCodeClassLoader().asClassLoader();
ClassLoader userCodeClassLoader = selectClassLoader(context);
try (TemporaryClassLoaderContext ignored =
TemporaryClassLoaderContext.of(userCodeClassLoader)) {
deserializer =
(Deserializer<T>)
InstantiationUtil.instantiate(
deserializerClass.getName(),
Deserializer.class,
userCodeClassLoader);
initializeDeserializer(userCodeClassLoader);

if (deserializer instanceof Configurable) {
((Configurable) deserializer).configure(config);
Expand Down Expand Up @@ -103,4 +98,21 @@ public void deserialize(ConsumerRecord<byte[], byte[]> record, Collector<T> coll
public TypeInformation<T> getProducedType() {
return TypeExtractor.createTypeInfo(Deserializer.class, deserializerClass, 0, null, null);
}

/**
* Selects the class loader to be used when instantiating the deserializer.
* Using a class loader with user code allows users to customize the deserializer.
*/
protected ClassLoader selectClassLoader(DeserializationSchema.InitializationContext context) {
return context.getUserCodeClassLoader().asClassLoader();
}

protected void initializeDeserializer(ClassLoader classLoader) throws Exception {
deserializer =
(Deserializer<T>)
InstantiationUtil.instantiate(
deserializerClass.getName(),
Deserializer.class,
classLoader);
}
}
Original file line number Diff line number Diff line change
@@ -1,28 +1,58 @@
package org.apache.flink.connector.kafka.sink;

import org.apache.flink.streaming.connectors.kafka.testutils.SerializationTestBase;
import org.apache.flink.api.common.serialization.SerializationSchema;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
import org.apache.flink.util.FlinkUserCodeClassLoaders;
import org.apache.flink.util.SimpleUserCodeClassLoader;
import org.apache.flink.util.UserCodeClassLoader;

import org.apache.kafka.common.serialization.StringSerializer;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.junit.MockitoJUnitRunner;

import static org.mockito.Mockito.when;
import java.net.URL;

@RunWith(MockitoJUnitRunner.class)
public class KafkaSerializerWrapperTest extends SerializationTestBase {
@Override
protected void setupContext() {
when(serializationContext.getUserCodeClassLoader()).thenReturn(userCodeClassLoader);
}
import static org.junit.Assert.assertEquals;

/**
* Tests for {@link KafkaSerializerWrapper}.
*/
public class KafkaSerializerWrapperTest {
@Test
public void testUserCodeClassLoaderIsUsed() throws Exception {
final KafkaSerializerWrapper<String> wrapper =
new KafkaSerializerWrapper<>(StringSerializer.class, true, (value) -> "topic");
final KafkaSerializerWrapperCaptureForTest wrapper = new KafkaSerializerWrapperCaptureForTest();
final ClassLoader classLoader = FlinkUserCodeClassLoaders.childFirst(
new URL[0], getClass().getClassLoader(), new String[0], throwable -> {}, true);
wrapper.open(new SerializationSchema.InitializationContext() {
@Override
public MetricGroup getMetricGroup() {
return new UnregisteredMetricsGroup();
}

@Override
public UserCodeClassLoader getUserCodeClassLoader() {
return SimpleUserCodeClassLoader.create(classLoader);
}
});

assertEquals(classLoader, wrapper.getClassLoaderUsed());
}

static class KafkaSerializerWrapperCaptureForTest extends KafkaSerializerWrapper<String> {
private ClassLoader classLoaderUsed;

KafkaSerializerWrapperCaptureForTest() {
super(StringSerializer.class, true, (value) -> "topic");
}

public ClassLoader getClassLoaderUsed() {
return classLoaderUsed;
}

testUserClassLoaderIsUsedWhen(() -> {
wrapper.open(serializationContext);
return null;
}, new StringSerializer());
@Override
protected void initializeSerializer(ClassLoader classLoader) throws Exception {
classLoaderUsed = classLoader;
super.initializeSerializer(classLoader);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,32 +1,60 @@
package org.apache.flink.connector.kafka.source.reader.deserializer;

import org.apache.flink.streaming.connectors.kafka.testutils.SerializationTestBase;
import org.apache.flink.api.common.serialization.DeserializationSchema;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
import org.apache.flink.util.FlinkUserCodeClassLoaders;
import org.apache.flink.util.SimpleUserCodeClassLoader;
import org.apache.flink.util.UserCodeClassLoader;

import org.apache.kafka.common.serialization.StringDeserializer;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.junit.MockitoJUnitRunner;

import java.net.URL;
import java.util.HashMap;
import java.util.Map;

import static org.mockito.Mockito.when;

@RunWith(MockitoJUnitRunner.class)
public class KafkaValueOnlyDeserializerWrapperTest extends SerializationTestBase {
@Override
protected void setupContext() {
when(deserializationContext.getUserCodeClassLoader()).thenReturn(userCodeClassLoader);
}
import static org.junit.Assert.assertEquals;

/**
* Tests for {@link KafkaValueOnlyDeserializerWrapper}.
*/
public class KafkaValueOnlyDeserializerWrapperTest {
@Test
public void testUserCodeClassLoaderIsUsed() throws Exception {
final Map<String, String> config = new HashMap<>();
final KafkaValueOnlyDeserializerWrapper<String> wrapper =
new KafkaValueOnlyDeserializerWrapper<>(StringDeserializer.class, config);

testUserClassLoaderIsUsedWhen(() -> {
wrapper.open(deserializationContext);
return null;
}, new StringDeserializer());
final KafkaValueOnlyDeserializerWrapperCaptureForTest wrapper =
new KafkaValueOnlyDeserializerWrapperCaptureForTest();
final ClassLoader classLoader = FlinkUserCodeClassLoaders.childFirst(
new URL[0], getClass().getClassLoader(), new String[0], throwable -> {}, true);
wrapper.open(new DeserializationSchema.InitializationContext() {
@Override
public MetricGroup getMetricGroup() {
return new UnregisteredMetricsGroup();
}

@Override
public UserCodeClassLoader getUserCodeClassLoader() {
return SimpleUserCodeClassLoader.create(classLoader);
}
});

assertEquals(classLoader, wrapper.getClassLoaderUsed());
}

static class KafkaValueOnlyDeserializerWrapperCaptureForTest extends KafkaValueOnlyDeserializerWrapper<String> {
private ClassLoader classLoaderUsed;

KafkaValueOnlyDeserializerWrapperCaptureForTest() {
super(StringDeserializer.class, new HashMap<>());
}

public ClassLoader getClassLoaderUsed() {
return classLoaderUsed;
}

@Override
protected void initializeDeserializer(ClassLoader classLoader) throws Exception {
classLoaderUsed = classLoader;
super.initializeDeserializer(classLoader);
}
}
}

This file was deleted.

This file was deleted.

0 comments on commit 2631db7

Please sign in to comment.