Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vkorukanti committed Jun 20, 2023
1 parent e11e4fb commit 1d4f986
Show file tree
Hide file tree
Showing 9 changed files with 168 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import java.util.Comparator;

import io.delta.kernel.internal.expressions.CastingComparator;

/**
* A {@link BinaryOperator} that compares the left and right {@link Expression}s and evaluates to a
* boolean value.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,49 @@
* limitations under the License.
*/

package io.delta.kernel.expressions;
package io.delta.kernel.internal.expressions;

import java.util.Comparator;

import io.delta.kernel.types.BinaryType;
import io.delta.kernel.types.BooleanType;
import io.delta.kernel.types.ByteType;
import io.delta.kernel.types.DataType;
import io.delta.kernel.types.DateType;
import io.delta.kernel.types.DoubleType;
import io.delta.kernel.types.IntegerType;
import io.delta.kernel.types.FloatType;
import io.delta.kernel.types.LongType;
import io.delta.kernel.types.ShortType;
import io.delta.kernel.types.StringType;
import io.delta.kernel.types.TimestampType;

// TODO: exclude from public interfaces (move to "internal" somewhere?)
public class CastingComparator<T extends Comparable<T>> implements Comparator<Object> {

public static Comparator<Object> forDataType(DataType dataType) {
if (dataType instanceof IntegerType) {
return new CastingComparator<Integer>();
}

if (dataType instanceof BooleanType) {
return new CastingComparator<Boolean>();
}

if (dataType instanceof FloatType) {
} else if (dataType instanceof ByteType) {
return new CastingComparator<Byte>();
} else if (dataType instanceof ShortType) {
return new CastingComparator<Short>();
} else if (dataType instanceof IntegerType) {
return new CastingComparator<Integer>();
} else if (dataType instanceof LongType) {
return new CastingComparator<Long>();
}

if (dataType instanceof StringType) {
} else if (dataType instanceof FloatType) {
return new CastingComparator<Float>();
} else if (dataType instanceof DoubleType) {
return new CastingComparator<Double>();
} else if (dataType instanceof StringType) {
return new CastingComparator<String>();
} else if (dataType instanceof DateType) {
// Date value is accessed as integer (number of days since epoch).
// This may change in the future.
return new CastingComparator<Integer>();
} else if (dataType instanceof TimestampType) {
// Timestamp value is accessed as long (epoch seconds). This may change in the future.
return new CastingComparator<Long>();
}

throw new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.delta.kernel.types;
package io.delta.kernel.internal.types;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
Expand All @@ -26,6 +27,17 @@
import io.delta.kernel.data.ColumnVector;
import io.delta.kernel.data.ColumnarBatch;
import io.delta.kernel.data.Row;
import io.delta.kernel.types.ArrayType;
import io.delta.kernel.types.BasePrimitiveType;
import io.delta.kernel.types.BooleanType;
import io.delta.kernel.types.DataType;
import io.delta.kernel.types.DecimalType;
import io.delta.kernel.types.MapType;
import io.delta.kernel.types.MixedDataType;
import io.delta.kernel.types.StringType;
import io.delta.kernel.types.StructField;
import io.delta.kernel.types.StructType;
import io.delta.kernel.utils.CloseableIterator;
import io.delta.kernel.utils.Utils;

/**
Expand Down Expand Up @@ -68,12 +80,15 @@ public static StructType fromJson(JsonHandler jsonHandler, String serializedStru
*/
private static StructType parseStructType(JsonHandler jsonHandler, String serializedStructType)
{
Row row = parse(jsonHandler, serializedStructType, STRUCT_TYPE_SCHEMA);
final List<Row> fields = row.getArray(0);
return new StructType(
fields.stream()
.map(field -> parseStructField(jsonHandler, field))
.collect(Collectors.toList()));
Function<Row, StructType> evalMethod = (row) -> {
final List<Row> fields = row.getArray(0);
return new StructType(
fields.stream()
.map(field -> parseStructField(jsonHandler, field))
.collect(Collectors.toList()));
};
return parseAndEvalSingleRow(
jsonHandler, serializedStructType, STRUCT_TYPE_SCHEMA, evalMethod);
}

/**
Expand All @@ -82,7 +97,8 @@ private static StructType parseStructType(JsonHandler jsonHandler, String serial
private static StructField parseStructField(JsonHandler jsonHandler, Row row)
{
String name = row.getString(0);
DataType type = parseDataType(jsonHandler, row, 1);
String serializedDataType = row.getString(1);
DataType type = parseDataType(jsonHandler, serializedDataType);
boolean nullable = row.getBoolean(2);
Map<String, String> metadata = row.getMap(3);

Expand All @@ -92,92 +108,107 @@ private static StructField parseStructField(JsonHandler jsonHandler, Row row)
/**
* Utility method to parse a {@link DataType} from its serialized string form.
*/
private static DataType parseDataType(JsonHandler jsonHandler, Row row, int ordinal)
private static DataType parseDataType(JsonHandler jsonHandler, String serializedDataType)
{
final String typeName = row.getString(ordinal);

if (BasePrimitiveType.isPrimitiveType(typeName)) {
return BasePrimitiveType.createPrimitive(typeName);
if (BasePrimitiveType.isPrimitiveType(serializedDataType)) {
return BasePrimitiveType.createPrimitive(serializedDataType);
}

// Check if it is decimal type
if (typeName.startsWith("decimal")) {
if (typeName.equalsIgnoreCase("decimal")) {
if (serializedDataType.startsWith("decimal")) {
if (serializedDataType.equalsIgnoreCase("decimal")) {
return DecimalType.USER_DEFAULT;
}

// parse the precision and scale
Matcher matcher = DECIMAL_TYPE_PATTERN.matcher(typeName);
Matcher matcher = DECIMAL_TYPE_PATTERN.matcher(serializedDataType);
if (!matcher.matches()) {
throw new IllegalArgumentException("Invalid decimal type format: " + typeName);
throw new IllegalArgumentException(
"Invalid decimal type format: " + serializedDataType);
}
return new DecimalType(
Integer.valueOf(matcher.group("precision")),
Integer.valueOf(matcher.group("scale")));
}
// This must be a complex type which is described as a JSON object.

Optional<ArrayType> arrayType = parseAsArrayType(jsonHandler, typeName);
Optional<ArrayType> arrayType = parseAsArrayType(jsonHandler, serializedDataType);
if (arrayType.isPresent()) {
return arrayType.get();
}

Optional<MapType> mapType = parseAsMapType(jsonHandler, typeName);
Optional<MapType> mapType = parseAsMapType(jsonHandler, serializedDataType);
if (mapType.isPresent()) {
return mapType.get();
}

return parseStructType(jsonHandler, typeName);
return parseStructType(jsonHandler, serializedDataType);
}

private static Optional<ArrayType> parseAsArrayType(JsonHandler jsonHandler, String json)
{
Row row = parse(jsonHandler, json, ARRAY_TYPE_SCHEMA);
if (!"array".equalsIgnoreCase(row.getString(0))) {
return Optional.empty();
}
Function<Row, Optional<ArrayType>> evalMethod = (row) -> {
if (!"array".equalsIgnoreCase(row.getString(0))) {
return Optional.empty();
}

if (row.isNullAt(1) || row.isNullAt(2)) {
throw new IllegalArgumentException("invalid array serialized format: " + json);
}
if (row.isNullAt(1) || row.isNullAt(2)) {
throw new IllegalArgumentException("invalid array serialized format: " + json);
}

// Now parse the element type and create an array data type object
DataType elementType = parseDataType(jsonHandler, row.getString(1));
boolean containsNull = row.getBoolean(2);

// Now parse the element type and create an array data type object
DataType elementType = parseDataType(jsonHandler, row, 1);
boolean containsNull = row.getBoolean(2);
return Optional.of(new ArrayType(elementType, containsNull));
};

return Optional.of(new ArrayType(elementType, containsNull));
return parseAndEvalSingleRow(jsonHandler, json, ARRAY_TYPE_SCHEMA, evalMethod);
}

private static Optional<MapType> parseAsMapType(JsonHandler jsonHandler, String json)
{
Row row = parse(jsonHandler, json, MAP_TYPE_SCHEMA);
if (!"map".equalsIgnoreCase(row.getString(0))) {
return Optional.empty();
}
Function<Row, Optional<MapType>> evalMethod = (row -> {
if (!"map".equalsIgnoreCase(row.getString(0))) {
return Optional.empty();
}

if (row.isNullAt(1) || row.isNullAt(2) || row.isNullAt(3)) {
throw new IllegalArgumentException("invalid map serialized format: " + json);
}
if (row.isNullAt(1) || row.isNullAt(2) || row.isNullAt(3)) {
throw new IllegalArgumentException("invalid map serialized format: " + json);
}

// Now parse the key and value types and create a map data type object
DataType keyType = parseDataType(jsonHandler, row, 1);
DataType valueType = parseDataType(jsonHandler, row, 2);
boolean valueContainsNull = row.getBoolean(3);
// Now parse the key and value types and create a map data type object
DataType keyType = parseDataType(jsonHandler, row.getString(1));
DataType valueType = parseDataType(jsonHandler, row.getString(2));
boolean valueContainsNull = row.getBoolean(3);

return Optional.of(new MapType(keyType, valueType, valueContainsNull));
return Optional.of(new MapType(keyType, valueType, valueContainsNull));
});

return parseAndEvalSingleRow(jsonHandler, json, MAP_TYPE_SCHEMA, evalMethod);
}

/**
* Helper method to parse a single JSON string into a row and evaluate the given function on it.
*/
private static Row parse(JsonHandler jsonHandler, String jsonString, StructType outputSchema)
private static <R> R parseAndEvalSingleRow(
JsonHandler jsonHandler,
String jsonString,
StructType outputSchema,
Function<Row, R> evalFunction)
{
ColumnVector columnVector = Utils.singletonColumnVector(jsonString);
ColumnarBatch result = jsonHandler.parseJson(columnVector, outputSchema);

assert result.getSize() == 1;

return result.getRows().next();
CloseableIterator<Row> rows = result.getRows();
try {
return evalFunction.apply(rows.next());
}
finally {
Utils.safeClose(rows);
}
}

/**
Expand All @@ -200,17 +231,17 @@ private static Row parse(JsonHandler jsonHandler, String jsonString, StructType
/**
* Example Array Type in serialized format
* {
* "type" : "array",
* "elementType" : {
* "type" : "struct",
* "fields" : [ {
* "name" : "d",
* "type" : "integer",
* "nullable" : false,
* "metadata" : { }
* } ]
* },
* "containsNull" : true
* "type" : "array",
* "elementType" : {
* "type" : "struct",
* "fields" : [ {
* "name" : "d",
* "type" : "integer",
* "nullable" : false,
* "metadata" : { }
* } ]
* },
* "containsNull" : true
* }
*/
private static StructType ARRAY_TYPE_SCHEMA =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@
/**
* Base class for all primitive types {@link DataType}.
*/
class BasePrimitiveType extends DataType
public abstract class BasePrimitiveType extends DataType
{
/**
* Create a primitive type {@link DataType}
*
* @param primitiveTypeName Primitive type name.
* @return The {@link DataType} registered under the given primitive type name.
*/
protected static DataType createPrimitive(String primitiveTypeName)
public static DataType createPrimitive(String primitiveTypeName)
{
return Optional.ofNullable(nameToPrimitiveTypeMap.get(primitiveTypeName))
.orElseThrow(
Expand All @@ -42,13 +42,13 @@ protected static DataType createPrimitive(String primitiveTypeName)
/**
* Is the given type name a primitive type?
*/
protected static boolean isPrimitiveType(String typeName)
public static boolean isPrimitiveType(String typeName)
{
return nameToPrimitiveTypeMap.containsKey(typeName);
}

/** For testing only */
protected static List<DataType> getAllPrimitiveTypes() {
public static List<DataType> getAllPrimitiveTypes() {
return nameToPrimitiveTypeMap.values().stream().collect(Collectors.toList());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ public final class DecimalType extends DataType

public DecimalType(int precision, int scale)
{
if (precision < 0 || precision > 38 || scale < 0 || scale > 38 || scale > precision) {
throw new IllegalArgumentException(String.format(
"Invalid precision and scale combo (%d, %d). They should be in the range [0, 38] " +
"and scale can not be more than the precision.", precision, scale));
}
this.precision = precision;
this.scale = scale;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
* }
* </pre>
* <p>
* `map` type column schema is serailized as:
* `struct` type column schema is serialized as:
* <pre>
* {
* "type" : "struct",
Expand Down
14 changes: 13 additions & 1 deletion kernel/kernel-api/src/main/java/io/delta/kernel/utils/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ public String getString(int rowId) {

/**
* Utility method to get the physical schema from the scan state {@link Row} returned by
* {@link Scan#getScanState(TableClient)}}.
* {@link Scan#getScanState(TableClient)}.
*
* @param scanState Scan state {@link Row}
* @return Physical schema to read from the data files.
Expand All @@ -148,4 +148,16 @@ public static FileStatus getFileStatus(Row scanFileInfo) {

return FileStatus.of(path, size, 0);
}

/**
 * Close the given iterator, rethrowing any checked {@link IOException} raised by
 * {@code close()} as an unchecked {@link RuntimeException}. This lets callers release the
 * iterator from a {@code finally} block without declaring or handling {@link IOException}.
 *
 * @param i1 Iterator to close. A {@code null} value is ignored.
 * @throws RuntimeException wrapping the {@link IOException} if closing the iterator fails.
 */
public static void safeClose(CloseableIterator<?> i1) {
    if (i1 == null) {
        // Nothing to release. Returning quietly avoids a NullPointerException that would
        // otherwise mask the original failure when this is called from a finally block.
        return;
    }
    try {
        i1.close();
    } catch (IOException ioe) {
        // Keep the original exception as the cause so the failure remains diagnosable.
        throw new RuntimeException(ioe);
    }
}
}
Loading

0 comments on commit 1d4f986

Please sign in to comment.