From 12751a7fe99c80d6aa16da1f679e3a9ad998f0f6 Mon Sep 17 00:00:00 2001 From: Michael Nedokushev Date: Mon, 11 Dec 2023 20:28:48 +0000 Subject: [PATCH] Fix list converter --- .../zio/apache/parquet/core/Value.scala | 3 +- .../core/hadoop/GroupValueConverter.scala | 132 ++++++++++-------- .../parquet/core/hadoop/ParquetIOSpec.scala | 7 +- 3 files changed, 80 insertions(+), 62 deletions(-) diff --git a/modules/core/src/main/scala/me/mnedokushev/zio/apache/parquet/core/Value.scala b/modules/core/src/main/scala/me/mnedokushev/zio/apache/parquet/core/Value.scala index 244b22f..aa266fa 100644 --- a/modules/core/src/main/scala/me/mnedokushev/zio/apache/parquet/core/Value.scala +++ b/modules/core/src/main/scala/me/mnedokushev/zio/apache/parquet/core/Value.scala @@ -155,7 +155,8 @@ object Value { recordConsumer.endGroup() } - override def put(name: String, value: Value): MapValue = ??? + override def put(name: String, value: Value): MapValue = + this.copy(values = Map(Value.string(name) -> value)) // this.copy(values = values.updated(name, value)) } diff --git a/modules/core/src/main/scala/me/mnedokushev/zio/apache/parquet/core/hadoop/GroupValueConverter.scala b/modules/core/src/main/scala/me/mnedokushev/zio/apache/parquet/core/hadoop/GroupValueConverter.scala index d30a679..3a90152 100644 --- a/modules/core/src/main/scala/me/mnedokushev/zio/apache/parquet/core/hadoop/GroupValueConverter.scala +++ b/modules/core/src/main/scala/me/mnedokushev/zio/apache/parquet/core/hadoop/GroupValueConverter.scala @@ -3,12 +3,16 @@ package me.mnedokushev.zio.apache.parquet.core.hadoop import me.mnedokushev.zio.apache.parquet.core.Value import me.mnedokushev.zio.apache.parquet.core.Value.{ GroupValue, PrimitiveValue } import org.apache.parquet.io.api.{ Binary, Converter, GroupConverter, PrimitiveConverter } -import org.apache.parquet.schema.{ GroupType, LogicalTypeAnnotation, Type } +import org.apache.parquet.schema.Type.Repetition +import org.apache.parquet.schema.{ GroupType, LogicalTypeAnnotation } import zio.Chunk import scala.jdk.CollectionConverters._ -abstract class GroupValueConverter[V <: GroupValue[V]](schema: GroupType) extends GroupConverter { parent => +abstract class GroupValueConverter[V <: GroupValue[V]]( + schema: GroupType, + parent: Option[GroupValueConverter[_]] = None +) extends GroupConverter { self => def get: V = this.groupValue @@ -19,22 +23,27 @@ abstract class GroupValueConverter[V <: GroupValue[V]](schema: GroupType) extend protected var groupValue: V = _ private val converters: Chunk[Converter] = - Chunk.fromIterable(schema.getFields.asScala.toList.map(fromSchema)) - - private def fromSchema(schema0: Type) = { - val name = schema0.getName - - schema0.getLogicalTypeAnnotation match { - case _ if schema0.isPrimitive => - primitive(name) - case _: LogicalTypeAnnotation.ListLogicalTypeAnnotation => - GroupValueConverter.list(schema0.asGroupType(), name, parent) - case _: LogicalTypeAnnotation.MapLogicalTypeAnnotation => - GroupValueConverter.map(schema0.asGroupType(), name, parent) - case _ => - GroupValueConverter.record(schema0.asGroupType(), name, parent) - } - } + Chunk.fromIterable( + schema.getFields.asScala.toList.map { schema0 => + val name = schema0.getName + + schema0.getLogicalTypeAnnotation match { + case _ if schema0.isPrimitive => + primitive(name) + case _: LogicalTypeAnnotation.ListLogicalTypeAnnotation => + list(schema0.asGroupType(), name) + case _: LogicalTypeAnnotation.MapLogicalTypeAnnotation => + map(schema0.asGroupType(), name) + case _ => + val name = schema0.getName + val repetition = schema0.getRepetition + + val p = if (name == "list" && repetition == Repetition.REPEATED) Some(this) else None + + record(schema0.asGroupType(), name, p) + } + } + ) override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) @@ -43,59 +52,50 @@ abstract class GroupValueConverter[V <: GroupValue[V]](schema: GroupType) extend new PrimitiveConverter { override def addBinary(value: Binary): Unit = - parent.groupValue = parent.groupValue.put(name, PrimitiveValue.BinaryValue(value)) + parent.getOrElse(self).put(name, PrimitiveValue.BinaryValue(value)) override def addBoolean(value: Boolean): Unit = - parent.groupValue = parent.groupValue.put(name, PrimitiveValue.BooleanValue(value)) + parent.getOrElse(self).put(name, PrimitiveValue.BooleanValue(value)) override def addDouble(value: Double): Unit = - parent.groupValue = parent.groupValue.put(name, PrimitiveValue.DoubleValue(value)) + parent.getOrElse(self).put(name, PrimitiveValue.DoubleValue(value)) override def addFloat(value: Float): Unit = - parent.groupValue = parent.groupValue.put(name, PrimitiveValue.FloatValue(value)) + parent.getOrElse(self).put(name, PrimitiveValue.FloatValue(value)) override def addInt(value: Int): Unit = - parent.groupValue = parent.groupValue.put(name, PrimitiveValue.Int32Value(value)) + parent.getOrElse(self).put(name, PrimitiveValue.Int32Value(value)) override def addLong(value: Long): Unit = - parent.groupValue = parent.groupValue.put(name, PrimitiveValue.Int64Value(value)) + parent.getOrElse(self).put(name, PrimitiveValue.Int64Value(value)) } -} - -object GroupValueConverter { - - def root(schema: GroupType): GroupValueConverter[GroupValue.RecordValue] = - new GroupValueConverter[GroupValue.RecordValue](schema) { - - override def start(): Unit = - this.groupValue = Value.record( - schema.getFields.asScala.toList.map(_.getName -> Value.nil).toMap - ) - - override def end(): Unit = () - } - - def record[V <: GroupValue[V]]( + private def record( schema: GroupType, name: String, - parent: GroupValueConverter[V] - ): GroupValueConverter[GroupValue.RecordValue] = - new GroupValueConverter[GroupValue.RecordValue](schema) { - - override def start(): Unit = - this.groupValue = Value.record(Map.empty) - - override def end(): Unit = - parent.put(name, this.groupValue) - - } + parent: Option[GroupValueConverter[_]] + ): GroupValueConverter[GroupValue.RecordValue] = parent match { + case Some(_) => + new GroupValueConverter[GroupValue.RecordValue](schema, parent) { + override def start(): Unit = () + override def end(): Unit = () + } + case _ => + new GroupValueConverter[GroupValue.RecordValue](schema, parent) { + + override def start(): Unit = + this.groupValue = Value.record(Map.empty) + + override def end(): Unit = + self.put(name, this.groupValue) + + } + } - def list[V <: GroupValue[V]]( + private def list( schema: GroupType, - name: String, - parent: GroupValueConverter[V] + name: String ): GroupValueConverter[GroupValue.ListValue] = new GroupValueConverter[GroupValue.ListValue](schema) { @@ -103,13 +103,12 @@ object GroupValueConverter { this.groupValue = Value.list(Chunk.empty) override def end(): Unit = - parent.put(name, this.groupValue) + self.put(name, this.groupValue) } - def map[V <: GroupValue[V]]( + private def map( schema: GroupType, - name: String, - parent: GroupValueConverter[V] + name: String ): GroupValueConverter[GroupValue.MapValue] = new GroupValueConverter[GroupValue.MapValue](schema) { @@ -117,7 +116,22 @@ object GroupValueConverter { this.groupValue = Value.map(Map.empty) override def end(): Unit = - parent.put(name, this.groupValue) + self.put(name, this.groupValue) + } + +} + +object GroupValueConverter { + + def root(schema: GroupType): GroupValueConverter[GroupValue.RecordValue] = + new GroupValueConverter[GroupValue.RecordValue](schema) { + + override def start(): Unit = + this.groupValue = Value.record( + schema.getFields.asScala.toList.map(_.getName -> Value.nil).toMap + ) + + override def end(): Unit = () } } diff --git a/modules/core/src/test/scala/me/mnedokushev/zio/apache/parquet/core/hadoop/ParquetIOSpec.scala b/modules/core/src/test/scala/me/mnedokushev/zio/apache/parquet/core/hadoop/ParquetIOSpec.scala index e4a39e6..a9e0b64 100644 --- a/modules/core/src/test/scala/me/mnedokushev/zio/apache/parquet/core/hadoop/ParquetIOSpec.scala +++ b/modules/core/src/test/scala/me/mnedokushev/zio/apache/parquet/core/hadoop/ParquetIOSpec.scala @@ -15,7 +15,7 @@ object ParquetIOSpec extends ZIOSpecDefault { val tmpCrcPath = tmpDir / ".parquet-writer-spec.parquet.crc" val tmpPath = tmpDir / tmpFile - case class Record(a: Int, b: String, c: Option[Long]) + case class Record(a: Int, b: String, c: Option[Long], d: List[Int]) object Record { implicit val schema: Schema[Record] = DeriveSchema.gen[Record] @@ -30,7 +30,10 @@ object ParquetIOSpec extends ZIOSpecDefault { override def spec: Spec[TestEnvironment with Scope, Any] = suite("ParquetIOSpec")( test("write and read") { - val payload = Chunk(Record(1, "foo", None), Record(2, "bar", Some(3L))) + val payload = Chunk( + Record(1, "foo", None, List(1, 2)), + Record(2, "bar", Some(3L), List.empty) + ) for { writer <- ZIO.service[ParquetWriter[Record]]