Skip to content

Commit dea5531

Browse files
committed
Support exhaustive list of primitive types in SchemaEncoder
1 parent d02815c commit dea5531

File tree

3 files changed

+105
-31
lines changed

3 files changed

+105
-31
lines changed

modules/core/src/main/scala/me/mnedokushev/zio/apache/parquet/core/Schemas.scala

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,19 @@ object Schemas {
2222

2323
case class PrimitiveDef(
2424
typeName: PrimitiveTypeName,
25-
annotation: LogicalTypeAnnotation,
25+
annotation: Option[LogicalTypeAnnotation] = None,
2626
isOptional: Boolean = false,
2727
length: Int = 0
2828
) extends Def[PrimitiveDef] {
2929

30-
def named(name: String): Type =
31-
Types
32-
.primitive(typeName, repetition(isOptional))
33-
.as(annotation)
30+
def named(name: String): Type = {
31+
val builder = Types.primitive(typeName, repetition(isOptional))
32+
33+
annotation
34+
.fold(builder)(builder.as)
3435
.length(length)
3536
.named(name)
37+
}
3638

3739
def length(len: Int): PrimitiveDef =
3840
this.copy(length = len)
@@ -104,13 +106,36 @@ object Schemas {
104106
import PrimitiveTypeName._
105107
import LogicalTypeAnnotation._
106108

107-
val string: PrimitiveDef = PrimitiveDef(BINARY, stringType())
108-
val boolean: PrimitiveDef = PrimitiveDef(INT32, intType(8, false))
109-
val byte: PrimitiveDef = PrimitiveDef(INT32, intType(8, false))
110-
val short: PrimitiveDef = PrimitiveDef(INT32, intType(16, true))
111-
val int: PrimitiveDef = PrimitiveDef(INT32, intType(32, true))
112-
val long: PrimitiveDef = PrimitiveDef(INT64, intType(64, true))
113-
val uuid: PrimitiveDef = PrimitiveDef(FIXED_LEN_BYTE_ARRAY, uuidType()).length(16)
109+
def enum0: PrimitiveDef = PrimitiveDef(BINARY, Some(enumType()))
110+
val string: PrimitiveDef = PrimitiveDef(BINARY, Some(stringType()))
111+
val boolean: PrimitiveDef = PrimitiveDef(INT32, Some(intType(8, false)))
112+
val byte: PrimitiveDef = PrimitiveDef(INT32, Some(intType(8, false)))
113+
val short: PrimitiveDef = PrimitiveDef(INT32, Some(intType(16, true)))
114+
val int: PrimitiveDef = PrimitiveDef(INT32, Some(intType(32, true)))
115+
val long: PrimitiveDef = PrimitiveDef(INT64, Some(intType(64, true)))
116+
val float: PrimitiveDef = PrimitiveDef(FLOAT)
117+
val double: PrimitiveDef = PrimitiveDef(DOUBLE)
118+
val binary: PrimitiveDef = PrimitiveDef(BINARY)
119+
val char: PrimitiveDef = PrimitiveDef(INT32, Some(intType(8, false)))
120+
val uuid: PrimitiveDef = PrimitiveDef(FIXED_LEN_BYTE_ARRAY, Some(uuidType())).length(16)
121+
// DECIMAL(precision = 11, scale = 2) stored in an INT64.
// Note: LogicalTypeAnnotation.decimalType takes (scale, precision), in that order.
val bigDecimal: PrimitiveDef = PrimitiveDef(INT64, Some(decimalType(2, 11)))
122+
val bigInteger: PrimitiveDef = PrimitiveDef(BINARY)
123+
val dayOfWeek: PrimitiveDef = PrimitiveDef(INT32, Some(intType(8, false)))
124+
val monthType: PrimitiveDef = PrimitiveDef(INT32, Some(intType(8, false)))
125+
val monthDay: PrimitiveDef = PrimitiveDef(INT32, Some(intType(8, false)))
126+
val period: PrimitiveDef = PrimitiveDef(FIXED_LEN_BYTE_ARRAY).length(12)
127+
val year: PrimitiveDef = PrimitiveDef(INT32, Some(intType(16, false)))
128+
val yearMonth: PrimitiveDef = PrimitiveDef(INT32, Some(intType(32, false)))
129+
val zoneId: PrimitiveDef = PrimitiveDef(BINARY, Some(stringType()))
130+
val zoneOffset: PrimitiveDef = PrimitiveDef(BINARY, Some(stringType()))
131+
val duration: PrimitiveDef = PrimitiveDef(INT64, Some(intType(64, false)))
132+
val instant: PrimitiveDef = PrimitiveDef(INT64, Some(intType(64, false)))
133+
val localDate: PrimitiveDef = PrimitiveDef(INT32, Some(dateType()))
134+
val localTime: PrimitiveDef = PrimitiveDef(INT32, Some(timeType(true, TimeUnit.MILLIS)))
135+
val localDateTime: PrimitiveDef = PrimitiveDef(INT64, Some(timestampType(true, TimeUnit.MILLIS)))
136+
val offsetTime: PrimitiveDef = PrimitiveDef(INT32, Some(timeType(false, TimeUnit.MILLIS)))
137+
val offsetDateTime: PrimitiveDef = PrimitiveDef(INT64, Some(timestampType(false, TimeUnit.MILLIS)))
138+
val zonedDateTime: PrimitiveDef = PrimitiveDef(INT64, Some(timestampType(false, TimeUnit.MILLIS)))
114139

115140
def record(fields: Chunk[Type]): RecordDef = RecordDef(fields)
116141
def list(element: Type): ListDef = ListDef(element)

modules/core/src/main/scala/me/mnedokushev/zio/apache/parquet/core/codec/SchemaEncoderDeriver.scala

Lines changed: 67 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package me.mnedokushev.zio.apache.parquet.core.codec
22

33
import me.mnedokushev.zio.apache.parquet.core.Schemas
4+
import me.mnedokushev.zio.apache.parquet.core.Schemas.PrimitiveDef
45
import org.apache.parquet.schema.Type
56
import zio.Chunk
67
import zio.schema.{ Deriver, Schema, StandardType }
@@ -33,33 +34,81 @@ object SchemaEncoderDeriver {
3334
summoned: => Option[SchemaEncoder[A]]
3435
): SchemaEncoder[A] = new SchemaEncoder[A] {
3536
override def encode(schema: Schema[A], name: String, optional: Boolean): Type =
36-
Schemas.string.optionality(optional).named(name)
37+
Schemas.enum0.optionality(optional).named(name)
3738
}
3839

3940
override def derivePrimitive[A](
4041
st: StandardType[A],
4142
summoned: => Option[SchemaEncoder[A]]
4243
): SchemaEncoder[A] =
4344
new SchemaEncoder[A] {
44-
override def encode(schema: Schema[A], name: String, optional: Boolean): Type =
45+
override def encode(schema: Schema[A], name: String, optional: Boolean): Type = {
46+
def tpe(prim: PrimitiveDef) =
47+
prim.optionality(optional).named(name)
48+
4549
st match {
46-
case StandardType.StringType =>
47-
Schemas.string.optionality(optional).named(name)
48-
case StandardType.BoolType =>
49-
Schemas.boolean.optionality(optional).named(name)
50-
case StandardType.ByteType =>
51-
Schemas.byte.optionality(optional).named(name)
52-
case StandardType.ShortType =>
53-
Schemas.short.optionality(optional).named(name)
54-
case StandardType.IntType =>
55-
Schemas.int.optionality(optional).named(name)
56-
case StandardType.LongType =>
57-
Schemas.long.optionality(optional).named(name)
58-
// TODO: add the other types
59-
case StandardType.UUIDType =>
60-
Schemas.uuid.optionality(optional).named(name)
61-
case _ => ???
50+
case StandardType.StringType =>
51+
tpe(Schemas.string)
52+
case StandardType.BoolType =>
53+
tpe(Schemas.boolean)
54+
case StandardType.ByteType =>
55+
tpe(Schemas.byte)
56+
case StandardType.ShortType =>
57+
tpe(Schemas.short)
58+
case StandardType.IntType =>
59+
tpe(Schemas.int)
60+
case StandardType.LongType =>
61+
tpe(Schemas.long)
62+
case StandardType.FloatType =>
63+
tpe(Schemas.float)
64+
case StandardType.DoubleType =>
65+
tpe(Schemas.double)
66+
case StandardType.BinaryType =>
67+
tpe(Schemas.binary)
68+
case StandardType.CharType =>
69+
tpe(Schemas.char)
70+
case StandardType.UUIDType =>
71+
tpe(Schemas.uuid)
72+
case StandardType.BigDecimalType =>
73+
tpe(Schemas.bigDecimal)
74+
case StandardType.BigIntegerType =>
75+
tpe(Schemas.bigInteger)
76+
case StandardType.DayOfWeekType =>
77+
tpe(Schemas.dayOfWeek)
78+
case StandardType.MonthType =>
79+
tpe(Schemas.monthType)
80+
case StandardType.MonthDayType =>
81+
tpe(Schemas.monthDay)
82+
case StandardType.PeriodType =>
83+
tpe(Schemas.period)
84+
case StandardType.YearType =>
85+
tpe(Schemas.year)
86+
case StandardType.YearMonthType =>
87+
tpe(Schemas.yearMonth)
88+
case StandardType.ZoneIdType =>
89+
tpe(Schemas.zoneId)
90+
case StandardType.ZoneOffsetType =>
91+
tpe(Schemas.zoneOffset)
92+
case StandardType.DurationType =>
93+
tpe(Schemas.duration)
94+
case StandardType.InstantType =>
95+
tpe(Schemas.instant)
96+
case StandardType.LocalDateType =>
97+
tpe(Schemas.localDate)
98+
case StandardType.LocalTimeType =>
99+
tpe(Schemas.localTime)
100+
case StandardType.LocalDateTimeType =>
101+
tpe(Schemas.localDateTime)
102+
case StandardType.OffsetTimeType =>
103+
tpe(Schemas.offsetTime)
104+
case StandardType.OffsetDateTimeType =>
105+
tpe(Schemas.offsetDateTime)
106+
case StandardType.ZonedDateTimeType =>
107+
tpe(Schemas.zonedDateTime)
108+
case StandardType.UnitType =>
109+
throw EncoderError("Unit type is unsupported")
62110
}
111+
}
63112
}
64113

65114
override def deriveOption[A](

modules/core/src/test/scala/me/mnedokushev/zio/apache/parquet/core/codec/SchemaEncoderDeriverSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ object SchemaEncoderDeriverSpec extends ZIOSpecDefault {
191191
val encoder = Derive.derive[SchemaEncoder, MyEnum](SchemaEncoderDeriver.default)
192192
val tpe = encoder.encode(Schema[MyEnum], name, optional = true)
193193

194-
assertTrue(tpe == Schemas.string.optional.named(name))
194+
assertTrue(tpe == Schemas.enum0.optional.named(name))
195195
}
196196
// test("summoned") {
197197
// // @nowarn annotation is needed to avoid having 'variable is not used' compiler error

0 commit comments

Comments
 (0)