Skip to content

Commit

Permalink
Polishing (#6)
Browse files Browse the repository at this point in the history
* Adjust variable names

* Add test case for optional string codec (WIP)

* Refactor validateSchema

* Exhaustive list of primitive types in decoder

* Exhaustive list of primitive types in encoder (WIP)

* Finish exhaustive list of primitive types in encoder
  • Loading branch information
grouzen authored Nov 25, 2023
1 parent 3a2dc93 commit 6e5946f
Show file tree
Hide file tree
Showing 11 changed files with 1,012 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,20 @@ object SchemaEncoder {
}

/**
 * Builds the Arrow `Field` describing a ZIO Schema primitive type.
 *
 * @param name         name given to the resulting field
 * @param standardType ZIO Schema primitive type to translate into an Arrow type
 * @param nullable     nullability flag, forwarded unchanged to the `field` helper
 * @return the Arrow field produced by `field(name, arrowType, nullable)`
 * @throws EncoderError when `standardType` has no Arrow mapping in this method
 */
private def encodePrimitive[A](name: String, standardType: StandardType[A], nullable: Boolean): Field = {
  // Binds the name and nullability once so each case only has to state the Arrow type.
  def namedField(arrowType: ArrowType) =
    field(name, arrowType, nullable)

  standardType match {
    case StandardType.IntType =>
      namedField(new ArrowType.Int(32, true))
    case StandardType.LongType =>
      namedField(new ArrowType.Int(64, true))
    case StandardType.FloatType =>
      // A JVM Float is a 32-bit IEEE 754 value, i.e. Arrow's SINGLE precision.
      // HALF would declare a 16-bit column, which does not match a Float value.
      namedField(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))
    case StandardType.DoubleType =>
      namedField(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))
    case StandardType.StringType =>
      namedField(new ArrowType.Utf8)
    case other =>
      throw EncoderError(s"Unsupported ZIO Schema StandardType $other")
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,25 @@ import org.apache.arrow.vector.complex.{ ListVector, StructVector }
import zio._
import zio.schema._

import java.nio.ByteBuffer
import java.time.{
DayOfWeek,
Instant,
LocalDate,
LocalDateTime,
LocalTime,
Month,
MonthDay,
OffsetDateTime,
OffsetTime,
Period,
Year,
YearMonth,
ZoneId,
ZoneOffset,
ZonedDateTime
}
import java.util.UUID
import scala.annotation.tailrec
import scala.collection.immutable.ListMap
import scala.util.control.NonFatal
Expand Down Expand Up @@ -127,58 +146,108 @@ object ValueVectorDecoder {
}

/**
 * Decodes the value at the reader's current position according to `schema`.
 *
 * @param name   when defined, the child reader is looked up by this name via
 *               `reader.reader(name)`; when empty, `reader.reader()` is used
 * @param schema ZIO Schema describing the value to decode
 * @param reader Arrow reader from which the child reader is obtained
 * @return the decoded value as a `DynamicValue`
 * @throws DecoderError for schema kinds other than primitive, record, sequence or lazy
 */
@tailrec
private[codec] def decodeSchema[A](name: Option[String], schema: Schema[A], reader: FieldReader): DynamicValue = {
  // Resolved eagerly, before dispatching on the schema, exactly once per call.
  val childReader: FieldReader = name match {
    case Some(fieldName) => reader.reader(fieldName)
    case None            => reader.reader()
  }

  schema match {
    case Schema.Primitive(standardType, _) =>
      decodePrimitive(standardType, childReader)
    case record: Schema.Record[A] =>
      decodeCaseClass(record.fields, childReader)
    case Schema.Sequence(elementSchema, _, _, _, _) =>
      decodeSequence(elementSchema, childReader)
    case deferred: Schema.Lazy[_] =>
      // Recurse with the ORIGINAL reader: the recursive call resolves the child
      // reader from `name` again, so passing `childReader` would descend twice.
      decodeSchema(name, deferred.schema, reader)
    case other =>
      throw DecoderError(s"Unsupported ZIO Schema type $other")
  }
}

/**
 * Decodes a record by decoding each of its fields by name from `reader`.
 *
 * @param fields schema fields of the record, decoded in the order given
 * @param reader Arrow reader positioned on the record; each field is looked up by its name
 * @return a `DynamicValue.Record` tagged `TypeId.Structural` whose entries follow the order of `fields`
 */
private[codec] def decodeCaseClass[A](fields: Chunk[Schema.Field[A, _]], reader: FieldReader): DynamicValue = {
  val decodedFields = fields.map { case Schema.Field(fieldName, fieldSchema, _, _, _, _) =>
    val decoded: DynamicValue = decodeSchema(Some(fieldName), fieldSchema, reader)

    fieldName -> decoded
  }

  // ListMap keeps the insertion order of the name/value pairs.
  val values = ListMap(decodedFields: _*)

  DynamicValue.Record(TypeId.Structural, values)
}

/**
 * Decodes every element the reader yields into a `DynamicValue.Sequence`.
 *
 * The reader is advanced with `next()` until it returns false. Positions for which
 * `isSet` is false are skipped, so they do not appear in the resulting sequence.
 *
 * @param schema ZIO Schema of a single element
 * @param reader Arrow reader iterating over the elements
 * @return the decoded elements, in iteration order
 */
private[codec] def decodeSequence[A](schema: Schema[A], reader: FieldReader): DynamicValue = {
  val elements = ChunkBuilder.make[DynamicValue]()

  // Advances the reader one position per step and collects only the set elements.
  @tailrec
  def drain(): Unit =
    if (reader.next()) {
      if (reader.isSet)
        elements.addOne(decodeSchema(None, schema, reader))
      drain()
    }

  drain()

  DynamicValue.Sequence(elements.result())
}

/**
 * Reads one primitive value from `reader` and wraps it, together with its
 * `StandardType`, in a `DynamicValue.Primitive`.
 *
 * Layouts read by this method for the non-trivial types:
 *  - UUID: a byte array holding two consecutive longs, passed to `new UUID(_, _)` in read order
 *  - BigInteger: a byte array handed to the `BigInteger(byte[])` constructor
 *  - MonthDay / YearMonth: one long split into two ints (bytes 0-3, then bytes 4-7)
 *  - Period: a byte array holding three ints at offsets 0, 4 and 8
 *  - Duration / Instant: a long interpreted as milliseconds
 *  - zone and date/time types: text parsed with the corresponding `of` / `parse` method
 *
 * @param standardType ZIO Schema primitive type selecting how the value is read
 * @param reader       Arrow reader positioned on the value
 * @return the decoded primitive
 * @throws DecoderError when `standardType` is not handled by any case below
 */
private[codec] def decodePrimitive[A](standardType: StandardType[A], reader: FieldReader): DynamicValue = {
  // Each helper performs a fresh read on every use; none of them caches a value.
  def text: String            = reader.readText().toString
  def byteBuffer: ByteBuffer  = ByteBuffer.wrap(reader.readByteArray())
  def longAsBytes: ByteBuffer = ByteBuffer.allocate(8).putLong(reader.readLong())

  standardType match {
    case tpe: StandardType.StringType.type =>
      DynamicValue.Primitive[String](text, tpe)
    case tpe: StandardType.BoolType.type =>
      DynamicValue.Primitive[Boolean](reader.readBoolean(), tpe)
    case tpe: StandardType.ByteType.type =>
      DynamicValue.Primitive[Byte](reader.readByte(), tpe)
    case tpe: StandardType.ShortType.type =>
      DynamicValue.Primitive[Short](reader.readShort(), tpe)
    case tpe: StandardType.IntType.type =>
      DynamicValue.Primitive[Int](reader.readInteger(), tpe)
    case tpe: StandardType.LongType.type =>
      DynamicValue.Primitive[Long](reader.readLong(), tpe)
    case tpe: StandardType.FloatType.type =>
      DynamicValue.Primitive[Float](reader.readFloat(), tpe)
    case tpe: StandardType.DoubleType.type =>
      DynamicValue.Primitive[Double](reader.readDouble(), tpe)
    case tpe: StandardType.BinaryType.type =>
      DynamicValue.Primitive[Chunk[Byte]](Chunk.fromArray(reader.readByteArray()), tpe)
    case tpe: StandardType.CharType.type =>
      DynamicValue.Primitive[Char](reader.readCharacter(), tpe)
    case tpe: StandardType.UUIDType.type =>
      // Two relative reads: the first long, then the long that follows it.
      val bits = byteBuffer
      DynamicValue.Primitive[UUID](new UUID(bits.getLong, bits.getLong), tpe)
    case tpe: StandardType.BigDecimalType.type =>
      DynamicValue.Primitive[java.math.BigDecimal](reader.readBigDecimal(), tpe)
    case tpe: StandardType.BigIntegerType.type =>
      DynamicValue.Primitive[java.math.BigInteger](new java.math.BigInteger(reader.readByteArray()), tpe)
    case tpe: StandardType.DayOfWeekType.type =>
      DynamicValue.Primitive[DayOfWeek](DayOfWeek.of(reader.readInteger()), tpe)
    case tpe: StandardType.MonthType.type =>
      DynamicValue.Primitive[Month](Month.of(reader.readInteger()), tpe)
    case tpe: StandardType.MonthDayType.type =>
      // Absolute reads: first int is the month, second is the day of month.
      val packed = longAsBytes
      DynamicValue.Primitive[MonthDay](MonthDay.of(packed.getInt(0), packed.getInt(4)), tpe)
    case tpe: StandardType.PeriodType.type =>
      // Absolute reads: years, months, days.
      val parts = byteBuffer
      DynamicValue.Primitive[Period](Period.of(parts.getInt(0), parts.getInt(4), parts.getInt(8)), tpe)
    case tpe: StandardType.YearType.type =>
      DynamicValue.Primitive[Year](Year.of(reader.readInteger()), tpe)
    case tpe: StandardType.YearMonthType.type =>
      // Absolute reads: first int is the year, second is the month.
      val packed = longAsBytes
      DynamicValue.Primitive[YearMonth](YearMonth.of(packed.getInt(0), packed.getInt(4)), tpe)
    case tpe: StandardType.ZoneIdType.type =>
      DynamicValue.Primitive[ZoneId](ZoneId.of(text), tpe)
    case tpe: StandardType.ZoneOffsetType.type =>
      DynamicValue.Primitive[ZoneOffset](ZoneOffset.of(text), tpe)
    case tpe: StandardType.DurationType.type =>
      DynamicValue.Primitive[Duration](Duration.fromMillis(reader.readLong()), tpe)
    case tpe: StandardType.InstantType.type =>
      DynamicValue.Primitive[Instant](Instant.ofEpochMilli(reader.readLong()), tpe)
    case tpe: StandardType.LocalDateType.type =>
      DynamicValue.Primitive[LocalDate](LocalDate.parse(text), tpe)
    case tpe: StandardType.LocalTimeType.type =>
      DynamicValue.Primitive[LocalTime](LocalTime.parse(text), tpe)
    case tpe: StandardType.LocalDateTimeType.type =>
      DynamicValue.Primitive[LocalDateTime](LocalDateTime.parse(text), tpe)
    case tpe: StandardType.OffsetTimeType.type =>
      DynamicValue.Primitive[OffsetTime](OffsetTime.parse(text), tpe)
    case tpe: StandardType.OffsetDateTimeType.type =>
      DynamicValue.Primitive[OffsetDateTime](OffsetDateTime.parse(text), tpe)
    case tpe: StandardType.ZonedDateTimeType.type =>
      DynamicValue.Primitive[ZonedDateTime](ZonedDateTime.parse(text), tpe)
    case other =>
      throw DecoderError(s"Unsupported ZIO Schema type $other")
  }
}

Expand Down
Loading

0 comments on commit 6e5946f

Please sign in to comment.