Parquet codec tests fix (#4698)
Parquet codec tests fix

Signed-off-by: Krishna Kondaka <krishkdk@dev-dsk-krishkdk-2c-bd29c437.us-west-2.amazon.com>
kkondaka authored Jul 2, 2024
1 parent c6ca7ad commit a466013
Showing 1 changed file with 74 additions and 23 deletions.
@@ -6,15 +6,18 @@

import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.ParquetReadOptions;
import org.apache.parquet.column.page.PageReadStore;
import org.apache.parquet.example.data.Group;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.example.data.simple.SimpleGroup;
import org.apache.parquet.example.data.simple.convert.GroupRecordConverter;
import org.apache.parquet.hadoop.ParquetFileReader;
import org.apache.parquet.hadoop.metadata.ParquetMetadata;
import org.apache.parquet.hadoop.util.HadoopInputFile;
import org.apache.parquet.io.ColumnIOFactory;
import org.apache.parquet.io.LocalInputFile;
import org.apache.parquet.io.MessageColumnIO;
import org.apache.parquet.io.RecordReader;
import org.apache.parquet.schema.MessageType;
@@ -39,12 +42,11 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.FileInputStream;
import java.io.InputStream;
import java.io.IOException;
import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.Collections;
@@ -59,6 +61,7 @@

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -114,11 +117,12 @@ void test_happy_case(final int numberOfRecords) throws Exception {
parquetOutputCodec.writeEvent(event, outputStream);
}
parquetOutputCodec.complete(outputStream);
List<Map<String, Object>> actualRecords = createParquetRecordsList(new ByteArrayInputStream(tempFile.toString().getBytes()));
List<Map<String, Object>> actualRecords = createParquetRecordsList(new FileInputStream(tempFile));
int index = 0;
assertThat(inputMaps.size(), equalTo(actualRecords.size()));
for (final Map<String, Object> actualMap : actualRecords) {
assertThat(actualMap, notNullValue());
Map expectedMap = generateRecords(numberOfRecords).get(index);
Map expectedMap = inputMaps.get(index);
assertThat(expectedMap, Matchers.equalTo(actualMap));
index++;
}
@@ -141,14 +145,16 @@ void test_happy_case_nullable_records(final int numberOfRecords) throws Exception {
parquetOutputCodec.writeEvent(event, outputStream);
}
parquetOutputCodec.complete(outputStream);
List<Map<String, Object>> actualRecords = createParquetRecordsList(new ByteArrayInputStream(tempFile.toString().getBytes()));
List<Map<String, Object>> actualRecords = createParquetRecordsList(new FileInputStream(tempFile));
int index = 0;
assertThat(inputMaps.size(), equalTo(actualRecords.size()));
for (final Map<String, Object> actualMap : actualRecords) {
assertThat(actualMap, notNullValue());
Map expectedMap = generateRecords(numberOfRecords).get(index);
Map expectedMap = inputMaps.get(index);
assertThat(expectedMap, Matchers.equalTo(actualMap));
index++;
}
outputStream.close();
tempFile.delete();
}

@@ -167,11 +173,12 @@ void test_happy_case_nullable_records_with_empty_maps(final int numberOfRecords)
parquetOutputCodec.writeEvent(event, outputStream);
}
parquetOutputCodec.complete(outputStream);
List<Map<String, Object>> actualRecords = createParquetRecordsList(new ByteArrayInputStream(tempFile.toString().getBytes()));
List<Map<String, Object>> actualRecords = createParquetRecordsList(new FileInputStream(tempFile));
int index = 0;
assertThat(inputMaps.size(), equalTo(actualRecords.size()));
for (final Map<String, Object> actualMap : actualRecords) {
assertThat(actualMap, notNullValue());
Map expectedMap = generateRecords(numberOfRecords).get(index);
Map expectedMap = inputMaps.get(index);
assertThat(expectedMap, Matchers.equalTo(actualMap));
index++;
}
@@ -193,6 +200,9 @@ void writeEvent_includes_record_when_field_does_not_exist_in_user_supplied_schema
final Event eventWithInvalidField = mock(Event.class);
final String invalidFieldName = UUID.randomUUID().toString();
Map<String, Object> mapWithInvalid = generateRecords(1).get(0);
Map<String, Object> mapWithoutInvalid = mapWithInvalid.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
mapWithInvalid.put(invalidFieldName, UUID.randomUUID().toString());
when(eventWithInvalidField.toMap()).thenReturn(mapWithInvalid);
final ParquetOutputCodec objectUnderTest = createObjectUnderTest();
@@ -204,12 +214,12 @@ void writeEvent_includes_record_when_field_does_not_exist_in_user_supplied_schema
objectUnderTest.writeEvent(eventWithInvalidField, outputStream);

objectUnderTest.complete(outputStream);
List<Map<String, Object>> actualRecords = createParquetRecordsList(new ByteArrayInputStream(tempFile.toString().getBytes()));
List<Map<String, Object>> actualRecords = createParquetRecordsList(new FileInputStream(tempFile));
int index = 0;
for (final Map<String, Object> actualMap : actualRecords) {
assertThat(actualMap, notNullValue());
Map expectedMap = generateRecords(1).get(index);
assertThat(expectedMap, Matchers.equalTo(actualMap));
assertThat(mapWithInvalid, not(Matchers.equalTo(actualMap)));
assertThat(mapWithoutInvalid, Matchers.equalTo(actualMap));
index++;
}
}
@@ -550,12 +560,34 @@ private static Schema createStandardInnerSchemaForNestedRecord(
return assembler.endRecord();
}

private List<Map<String, Object>> createParquetRecordsList(final InputStream inputStream) throws IOException {
private List<String> extractStringList(SimpleGroup group, String fieldName) {
int fieldIndex = group.getType().getFieldIndex(fieldName);
int repetitionCount = group.getGroup(fieldIndex, 0).getFieldRepetitionCount(0);
List<String> resultList = new ArrayList<>();
for (int i = 0; i < repetitionCount; i++) {
resultList.add(group.getGroup(fieldIndex, 0).getString(0, i));
}
return resultList;
}

private Map<String, Object> extractNestedGroup(SimpleGroup group, String fieldName) {

Map<String, Object> resultMap = new HashMap<>();
int fieldIndex = group.getType().getFieldIndex(fieldName);
int f1 = group.getGroup(fieldIndex, 0).getType().getFieldIndex("firstFieldInNestedRecord");
resultMap.put("firstFieldInNestedRecord", group.getGroup(fieldIndex, 0).getString(f1,0));
int f2 = group.getGroup(fieldIndex, 0).getType().getFieldIndex("secondFieldInNestedRecord");
resultMap.put("secondFieldInNestedRecord", group.getGroup(fieldIndex, 0).getInteger(f2,0));

return resultMap;
}

private List<Map<String, Object>> createParquetRecordsList(final InputStream inputStream) throws IOException, RuntimeException {

final File tempFile = new File(tempDirectory, FILE_NAME);
Files.copy(inputStream, tempFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
List<Map<String, Object>> actualRecordList = new ArrayList<>();
try (final ParquetFileReader parquetFileReader = new ParquetFileReader(new LocalInputFile(Path.of(tempFile.toURI())), ParquetReadOptions.builder().build())) {
try (ParquetFileReader parquetFileReader = new ParquetFileReader(HadoopInputFile.fromPath(new Path(tempFile.toURI()), new Configuration()), ParquetReadOptions.builder().build())) {
final ParquetMetadata footer = parquetFileReader.getFooter();
final MessageType schema = createdParquetSchema(footer);
PageReadStore pages;
@@ -566,15 +598,34 @@ private List<Map<String, Object>> createParquetRecordsList(final InputStream inputStream
final RecordReader<Group> recordReader = columnIO.getRecordReader(pages, new GroupRecordConverter(schema));
for (int row = 0; row < rows; row++) {
final Map<String, Object> eventData = new HashMap<>();
int fieldIndex = 0;
final SimpleGroup simpleGroup = (SimpleGroup) recordReader.read();
final GroupType groupType = simpleGroup.getType();


for (Type field : schema.getFields()) {
try {
eventData.put(field.getName(), simpleGroup.getValueToString(fieldIndex, 0));
} catch (Exception parquetException) {
LOG.error("Failed to parse Parquet", parquetException);
Object value;
int fieldIndex = groupType.getFieldIndex(field.getName());
if (simpleGroup.getFieldRepetitionCount(fieldIndex) == 0) {
continue;
}
switch (field.getName()) {
case "name": value = simpleGroup.getString(fieldIndex, 0);
break;
case "age": value = simpleGroup.getInteger(fieldIndex, 0);
break;
case "myLong": value = simpleGroup.getLong(fieldIndex, 0);
break;
case "myFloat": value = simpleGroup.getFloat(fieldIndex, 0);
break;
case "myDouble": value = simpleGroup.getDouble(fieldIndex, 0);
break;
case "myArray": value = extractStringList(simpleGroup, "myArray");
break;
case "nestedRecord": value = extractNestedGroup(simpleGroup, "nestedRecord");
break;
default: throw new IllegalArgumentException("Unknown field");
}
fieldIndex++;
eventData.put(field.getName(), value);
}
actualRecordList.add((HashMap) eventData);
}
@@ -590,4 +641,4 @@ private List<Map<String, Object>> createParquetRecordsList(final InputStream inputStream
private MessageType createdParquetSchema(ParquetMetadata parquetMetadata) {
return parquetMetadata.getFileMetaData().getSchema();
}
}
}
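For context, here is a minimal standalone sketch of the read-back pattern the updated test relies on: opening a written Parquet file through HadoopInputFile and ParquetFileReader, then materializing each row as a Group via GroupRecordConverter. The class name, the /tmp/example.parquet path, and the println output are illustrative assumptions, not part of this commit.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.ParquetReadOptions;
import org.apache.parquet.column.page.PageReadStore;
import org.apache.parquet.example.data.Group;
import org.apache.parquet.example.data.simple.convert.GroupRecordConverter;
import org.apache.parquet.hadoop.ParquetFileReader;
import org.apache.parquet.hadoop.util.HadoopInputFile;
import org.apache.parquet.io.ColumnIOFactory;
import org.apache.parquet.io.MessageColumnIO;
import org.apache.parquet.io.RecordReader;
import org.apache.parquet.schema.MessageType;

import java.io.IOException;

public class ParquetReadBackSketch {
    public static void main(String[] args) throws IOException {
        // Hypothetical input path; the test writes and reads back its own temp file instead.
        Path parquetPath = new Path("/tmp/example.parquet");
        try (ParquetFileReader reader = new ParquetFileReader(
                HadoopInputFile.fromPath(parquetPath, new Configuration()),
                ParquetReadOptions.builder().build())) {
            // The schema is taken from the file footer, as in createParquetRecordsList.
            MessageType schema = reader.getFooter().getFileMetaData().getSchema();
            PageReadStore pages;
            while ((pages = reader.readNextRowGroup()) != null) {
                long rows = pages.getRowCount();
                MessageColumnIO columnIO = new ColumnIOFactory().getColumnIO(schema);
                RecordReader<Group> recordReader =
                        columnIO.getRecordReader(pages, new GroupRecordConverter(schema));
                for (long row = 0; row < rows; row++) {
                    // Each Group is one record; printing it is for illustration only.
                    Group record = recordReader.read();
                    System.out.println(record);
                }
            }
        }
    }
}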
