Skip to content

Commit

Permalink
feat: JSON Masking
Browse files Browse the repository at this point in the history
  • Loading branch information
jamfor352 committed Sep 19, 2024
1 parent b6468bf commit 3437ae0
Show file tree
Hide file tree
Showing 24 changed files with 582 additions and 88 deletions.
6 changes: 5 additions & 1 deletion src/main/java/org/akhq/configs/DataMasking.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,9 @@
@ConfigurationProperties("akhq.security.data-masking")
@Data
public class DataMasking {
List<DataMaskingFilter> filters = new ArrayList<>();
List<RegexFilter> filters = new ArrayList<>();
DataMaskingMode mode = DataMaskingMode.REGEX; // set this by default to REGEX for backwards compatibility for current users who haven't defined this property.
List<JsonMaskingFilter> jsonFilters = new ArrayList<>();
String jsonMaskReplacement = "xxxx";
boolean cachingEnabled = false;
}
12 changes: 12 additions & 0 deletions src/main/java/org/akhq/configs/DataMaskingMode.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package org.akhq.configs;

public enum DataMaskingMode {
// Use the existing regex-based filtering
REGEX,
// Use filtering where by default all fields of all records are masked, with fields to unmask defined in allowlists
JSON_MASK_BY_DEFAULT,
// Use filtering where by default no fields of any records are masked, with fields to mask explicitly denied
JSON_SHOW_BY_DEFAULT,
// No masker at all, best performance
NONE
}
14 changes: 14 additions & 0 deletions src/main/java/org/akhq/configs/JsonMaskingFilter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package org.akhq.configs;

import io.micronaut.context.annotation.EachProperty;
import lombok.Data;

import java.util.List;

@EachProperty("jsonfilters")
@Data
public class JsonMaskingFilter {
String description = "UNKNOWN";
String topic = "UNKNOWN";
List<String> keys = List.of("UNKNOWN");
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

@EachProperty("filters")
@Data
public class DataMaskingFilter {
public class RegexFilter {
String description;
String searchRegex;
String replacement;
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/org/akhq/models/Record.java
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ public void setTruncated(Boolean truncated) {
this.truncated = truncated;
}

public void setTopic(Topic topic) {
this.topic = topic;
}

private String convertToString(byte[] payload, String schemaId, boolean isKey) {
if (payload == null) {
return null;
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/org/akhq/repositories/RecordRepository.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.akhq.modules.schemaregistry.RecordWithSchemaSerializerFactory;
import org.akhq.utils.AvroToJsonSerializer;
import org.akhq.utils.Debug;
import org.akhq.utils.MaskingUtils;
import org.akhq.utils.Masker;
import org.apache.kafka.clients.admin.DeletedRecords;
import org.apache.kafka.clients.admin.RecordsToDelete;
import org.apache.kafka.clients.consumer.*;
Expand Down Expand Up @@ -79,7 +79,7 @@ public class RecordRepository extends AbstractRepository {
private AvroWireFormatConverter avroWireFormatConverter;

@Inject
private MaskingUtils maskingUtils;
private Masker masker;

@Value("${akhq.topic-data.poll-timeout:10000}")
protected int pollTimeout;
Expand Down Expand Up @@ -453,7 +453,7 @@ private ConsumerRecords<byte[], byte[]> poll(KafkaConsumer<byte[], byte[]> consu
private Record newRecord(ConsumerRecord<byte[], byte[]> record, String clusterId, Topic topic) {
SchemaRegistryType schemaRegistryType = this.schemaRegistryRepository.getSchemaRegistryType(clusterId);
SchemaRegistryClient client = this.kafkaModule.getRegistryClient(clusterId);
return maskingUtils.maskRecord(new Record(
return masker.maskRecord(new Record(
client,
record,
this.schemaRegistryRepository.getSchemaRegistryType(clusterId),
Expand All @@ -473,7 +473,7 @@ private Record newRecord(ConsumerRecord<byte[], byte[]> record, String clusterId
private Record newRecord(ConsumerRecord<byte[], byte[]> record, BaseOptions options, Topic topic) {
SchemaRegistryType schemaRegistryType = this.schemaRegistryRepository.getSchemaRegistryType(options.clusterId);
SchemaRegistryClient client = this.kafkaModule.getRegistryClient(options.clusterId);
return maskingUtils.maskRecord(new Record(
return masker.maskRecord(new Record(
client,
record,
schemaRegistryType,
Expand Down
19 changes: 16 additions & 3 deletions src/main/java/org/akhq/utils/AvroSerializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.time.format.DateTimeFormatterBuilder;
import java.time.temporal.ChronoField;
import java.util.*;
import java.util.stream.Collectors;

Expand All @@ -45,7 +46,19 @@ public class AvroSerializer {
private static final TimeConversions.LocalTimestampMillisConversion LOCAL_TIMESTAMP_MILLIS_CONVERSION = new TimeConversions.LocalTimestampMillisConversion();

protected static final String DATE_FORMAT = "yyyy-MM-dd[XXX]";
protected static final String TIME_FORMAT = "HH:mm[:ss][.SSSSSS][XXX]";
protected static final DateTimeFormatter TIME_FORMATTER = new DateTimeFormatterBuilder()
.appendPattern("HH:mm")
.optionalStart()
.appendPattern(":ss")
.optionalEnd()
.optionalStart()
.appendFraction(ChronoField.NANO_OF_SECOND, 1, 9, true)
.optionalEnd()
.optionalStart()
.appendPattern("XXX")
.optionalEnd()
.toFormatter();

protected static final DateTimeFormatter DATETIME_FORMAT = new DateTimeFormatterBuilder()
.append(DateTimeFormatter.ISO_LOCAL_DATE_TIME)
.optionalStart()
Expand Down Expand Up @@ -323,7 +336,7 @@ protected static Instant parseDateTime(String data) {
private static Long timeMicrosSerializer(Object data, Schema schema, Schema.Type primitiveType, LogicalType logicalType) {
LocalTime value;
if (data instanceof String) {
value = LocalTime.parse((String) data, DateTimeFormatter.ofPattern(AvroSerializer.TIME_FORMAT));
value = LocalTime.parse((String) data, TIME_FORMATTER);
} else {
value = (LocalTime) data;
}
Expand All @@ -339,7 +352,7 @@ private static Integer timeMillisSerializer(Object data, Schema schema, Schema.T
LocalTime value;

if (data instanceof String) {
value = LocalTime.parse((String) data, DateTimeFormatter.ofPattern(AvroSerializer.TIME_FORMAT));
value = LocalTime.parse((String) data, TIME_FORMATTER);
} else {
value = (LocalTime) data;
}
Expand Down
65 changes: 65 additions & 0 deletions src/main/java/org/akhq/utils/JsonMaskByDefaultMasker.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package org.akhq.utils;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import org.akhq.configs.JsonMaskingFilter;
import org.akhq.models.Record;

import java.util.List;
import java.util.Map;

@RequiredArgsConstructor
public class JsonMaskByDefaultMasker implements Masker {

private final List<JsonMaskingFilter> jsonMaskingFilters;
private final String jsonMaskReplacement;

public Record maskRecord(Record record) {
try {
if(record.getValue().trim().startsWith("{") && record.getValue().trim().endsWith("}")) {
JsonMaskingFilter foundFilter = null;
for (JsonMaskingFilter filter : jsonMaskingFilters) {
if (record.getTopic().getName().equalsIgnoreCase(filter.getTopic())) {
foundFilter = filter;
}
}
if (foundFilter != null) {
return applyMasking(record, foundFilter.getKeys());
} else {
return applyMasking(record, List.of());
}
} else {
return record;
}
} catch (Exception e) {
LOG.error("Error masking record", e);
return record;
}
}

@SneakyThrows
private Record applyMasking(Record record, List<String> keys) {
JsonObject jsonElement = JsonParser.parseString(record.getValue()).getAsJsonObject();
maskAllExcept(jsonElement, keys);
record.setValue(jsonElement.toString());
return record;
}

private void maskAllExcept(JsonObject node, List<String> keys) {
if (node.isJsonObject()) {
JsonObject objectNode = node.getAsJsonObject();
for(Map.Entry<String, JsonElement> entry : objectNode.entrySet()) {
if(entry.getValue().isJsonObject()) {
maskAllExcept(entry.getValue().getAsJsonObject(), keys);
} else {
if(!keys.contains(entry.getKey())) {
objectNode.addProperty(entry.getKey(), jsonMaskReplacement);
}
}
}
}
}
}
64 changes: 64 additions & 0 deletions src/main/java/org/akhq/utils/JsonShowByDefaultMasker.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package org.akhq.utils;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import org.akhq.configs.JsonMaskingFilter;
import org.akhq.models.Record;

import java.util.List;

@RequiredArgsConstructor
public class JsonShowByDefaultMasker implements Masker {

private final List<JsonMaskingFilter> jsonMaskingFilters;
private final String jsonMaskReplacement;

public Record maskRecord(Record record) {
try {
if(record.getValue().trim().startsWith("{") && record.getValue().trim().endsWith("}")) {
JsonMaskingFilter foundFilter = null;
for (JsonMaskingFilter filter : jsonMaskingFilters) {
if (record.getTopic().getName().equalsIgnoreCase(filter.getTopic())) {
foundFilter = filter;
}
}
if (foundFilter != null) {
return applyMasking(record, foundFilter.getKeys());
} else {
return record;
}
} else {
return record;
}
} catch (Exception e) {
LOG.error("Error masking record", e);
return record;
}
}

@SneakyThrows
private Record applyMasking(Record record, List<String> keys) {
JsonObject jsonElement = JsonParser.parseString(record.getValue()).getAsJsonObject();
for(String key : keys) {
maskField(jsonElement, key.split("\\."), 0);
}
record.setValue(jsonElement.toString());
return record;
}

private void maskField(JsonObject node, String[] keys, int index) {
if (index == keys.length - 1) {
if (node.has(keys[index])) {
node.addProperty(keys[index], jsonMaskReplacement);
}
} else {
JsonElement childNode = node.get(keys[index]);
if (childNode != null && childNode.isJsonObject()) {
maskField(childNode.getAsJsonObject(), keys, index + 1);
}
}
}
}
12 changes: 12 additions & 0 deletions src/main/java/org/akhq/utils/Masker.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package org.akhq.utils;

import org.akhq.models.Record;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public interface Masker {

Logger LOG = LoggerFactory.getLogger(Masker.class);

Record maskRecord(Record record);
}
22 changes: 22 additions & 0 deletions src/main/java/org/akhq/utils/MaskerFactory.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package org.akhq.utils;

import io.micronaut.context.annotation.Bean;
import io.micronaut.context.annotation.Factory;
import org.akhq.configs.DataMasking;

@Factory
public class MaskerFactory {

@Bean
public Masker createMaskingUtil(DataMasking dataMasking) {
if(dataMasking == null) {
return new NoOpMasker();
}
return switch(dataMasking.getMode()) {
case REGEX -> new RegexMasker(dataMasking.getFilters());
case JSON_MASK_BY_DEFAULT -> new JsonMaskByDefaultMasker(dataMasking.getJsonFilters(), dataMasking.getJsonMaskReplacement());
case JSON_SHOW_BY_DEFAULT -> new JsonShowByDefaultMasker(dataMasking.getJsonFilters(), dataMasking.getJsonMaskReplacement());
case NONE -> new NoOpMasker();
};
}
}
11 changes: 11 additions & 0 deletions src/main/java/org/akhq/utils/NoOpMasker.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package org.akhq.utils;

import org.akhq.models.Record;

public class NoOpMasker implements Masker {

@Override
public Record maskRecord(Record record) {
return record;
}
}
Original file line number Diff line number Diff line change
@@ -1,27 +1,22 @@
package org.akhq.utils;

import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import org.akhq.configs.DataMasking;
import org.akhq.configs.DataMaskingFilter;
import lombok.RequiredArgsConstructor;
import org.akhq.configs.RegexFilter;
import org.akhq.models.Record;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Singleton
public class MaskingUtils {
private static final Logger LOG = LoggerFactory.getLogger(MaskingUtils.class);
import java.util.List;

@Inject
DataMasking dataMasking;
@RequiredArgsConstructor
public class RegexMasker implements Masker {

public Record maskRecord(Record record) {
LOG.trace("masking record");
private final List<RegexFilter> filters;

@Override
public Record maskRecord(Record record) {
String value = record.getValue();
String key = record.getKey();

for (DataMaskingFilter filter : dataMasking.getFilters()) {
for (RegexFilter filter : filters) {
if (value != null) {
value = value.replaceAll(filter.getSearchRegex(), filter.getReplacement());
}
Expand All @@ -35,4 +30,4 @@ public Record maskRecord(Record record) {

return record;
}
}
}
29 changes: 29 additions & 0 deletions src/test/java/org/akhq/utils/DefaultMaskerSettingTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package org.akhq.utils;

import io.micronaut.test.extensions.junit5.annotation.MicronautTest;
import jakarta.inject.Inject;
import org.akhq.configs.DataMasking;
import org.junit.jupiter.api.Test;

import static org.akhq.configs.DataMaskingMode.REGEX;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;

@MicronautTest
public class DefaultMaskerSettingTest {

@Inject
DataMasking dataMasking;

@Inject
Masker masker;

@Test
void defaultValuesShouldUseRegexForBackwardsCompatibility() {
assertEquals(
REGEX,
dataMasking.getMode()
);
assertInstanceOf(RegexMasker.class, masker);
}
}
Loading

0 comments on commit 3437ae0

Please sign in to comment.