Fix list converter
grouzen committed Dec 11, 2023
1 parent 7480a82 commit 12751a7
Showing 3 changed files with 80 additions and 62 deletions.
Value.scala
@@ -155,7 +155,8 @@ object Value {
recordConsumer.endGroup()
}

override def put(name: String, value: Value): MapValue = ???
override def put(name: String, value: Value): MapValue =
this.copy(values = Map(Value.string(name) -> value))
// this.copy(values = values.updated(name, value))
}

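For context on the `MapValue.put` change above: the committed line replaces the backing map with a single entry, while the commented-out alternative accumulates entries with `updated`. A minimal, self-contained sketch contrasting the two behaviours (a hypothetical `MapValue` keyed by `String` with `Int` values, not the library's actual types):

```scala
// Contrast of the two put strategies shown in the diff above (hypothetical types).
object MapPutSketch extends App {

  final case class MapValue(values: Map[String, Int]) {
    // Replaces the whole map with a single entry, as in the committed line.
    def putReplace(name: String, value: Int): MapValue =
      copy(values = Map(name -> value))

    // Accumulates entries, as in the commented-out alternative.
    def putUpdate(name: String, value: Int): MapValue =
      copy(values = values.updated(name, value))
  }

  val empty = MapValue(Map.empty)

  println(empty.putReplace("a", 1).putReplace("b", 2).values) // Map(b -> 2)
  println(empty.putUpdate("a", 1).putUpdate("b", 2).values)   // Map(a -> 1, b -> 2)
}
```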
GroupValueConverter.scala
@@ -3,12 +3,16 @@ package me.mnedokushev.zio.apache.parquet.core.hadoop
import me.mnedokushev.zio.apache.parquet.core.Value
import me.mnedokushev.zio.apache.parquet.core.Value.{ GroupValue, PrimitiveValue }
import org.apache.parquet.io.api.{ Binary, Converter, GroupConverter, PrimitiveConverter }
import org.apache.parquet.schema.{ GroupType, LogicalTypeAnnotation, Type }
import org.apache.parquet.schema.Type.Repetition
import org.apache.parquet.schema.{ GroupType, LogicalTypeAnnotation }
import zio.Chunk

import scala.jdk.CollectionConverters._

abstract class GroupValueConverter[V <: GroupValue[V]](schema: GroupType) extends GroupConverter { parent =>
abstract class GroupValueConverter[V <: GroupValue[V]](
schema: GroupType,
parent: Option[GroupValueConverter[_]] = None
) extends GroupConverter { self =>

def get: V =
this.groupValue
@@ -19,22 +23,27 @@ abstract class GroupValueConverter[V <: GroupValue[V]](schema: GroupType) extend
protected var groupValue: V = _

private val converters: Chunk[Converter] =
Chunk.fromIterable(schema.getFields.asScala.toList.map(fromSchema))

private def fromSchema(schema0: Type) = {
val name = schema0.getName

schema0.getLogicalTypeAnnotation match {
case _ if schema0.isPrimitive =>
primitive(name)
case _: LogicalTypeAnnotation.ListLogicalTypeAnnotation =>
GroupValueConverter.list(schema0.asGroupType(), name, parent)
case _: LogicalTypeAnnotation.MapLogicalTypeAnnotation =>
GroupValueConverter.map(schema0.asGroupType(), name, parent)
case _ =>
GroupValueConverter.record(schema0.asGroupType(), name, parent)
}
}
Chunk.fromIterable(
schema.getFields.asScala.toList.map { schema0 =>
val name = schema0.getName

schema0.getLogicalTypeAnnotation match {
case _ if schema0.isPrimitive =>
primitive(name)
case _: LogicalTypeAnnotation.ListLogicalTypeAnnotation =>
list(schema0.asGroupType(), name)
case _: LogicalTypeAnnotation.MapLogicalTypeAnnotation =>
map(schema0.asGroupType(), name)
case _ =>
val name = schema0.getName
val repetition = schema0.getRepetition

val p = if (name == "list" && repetition == Repetition.REPEATED) Some(this) else None

record(schema0.asGroupType(), name, p)
}
}
)

override def getConverter(fieldIndex: Int): Converter =
converters(fieldIndex)
@@ -43,81 +52,86 @@ abstract class GroupValueConverter[V <: GroupValue[V]](schema: GroupType) extend
new PrimitiveConverter {

override def addBinary(value: Binary): Unit =
parent.groupValue = parent.groupValue.put(name, PrimitiveValue.BinaryValue(value))
parent.getOrElse(self).put(name, PrimitiveValue.BinaryValue(value))

override def addBoolean(value: Boolean): Unit =
parent.groupValue = parent.groupValue.put(name, PrimitiveValue.BooleanValue(value))
parent.getOrElse(self).put(name, PrimitiveValue.BooleanValue(value))

override def addDouble(value: Double): Unit =
parent.groupValue = parent.groupValue.put(name, PrimitiveValue.DoubleValue(value))
parent.getOrElse(self).put(name, PrimitiveValue.DoubleValue(value))

override def addFloat(value: Float): Unit =
parent.groupValue = parent.groupValue.put(name, PrimitiveValue.FloatValue(value))
parent.getOrElse(self).put(name, PrimitiveValue.FloatValue(value))

override def addInt(value: Int): Unit =
parent.groupValue = parent.groupValue.put(name, PrimitiveValue.Int32Value(value))
parent.getOrElse(self).put(name, PrimitiveValue.Int32Value(value))

override def addLong(value: Long): Unit =
parent.groupValue = parent.groupValue.put(name, PrimitiveValue.Int64Value(value))
parent.getOrElse(self).put(name, PrimitiveValue.Int64Value(value))

}

}

object GroupValueConverter {

def root(schema: GroupType): GroupValueConverter[GroupValue.RecordValue] =
new GroupValueConverter[GroupValue.RecordValue](schema) {

override def start(): Unit =
this.groupValue = Value.record(
schema.getFields.asScala.toList.map(_.getName -> Value.nil).toMap
)

override def end(): Unit = ()
}

def record[V <: GroupValue[V]](
private def record(
schema: GroupType,
name: String,
parent: GroupValueConverter[V]
): GroupValueConverter[GroupValue.RecordValue] =
new GroupValueConverter[GroupValue.RecordValue](schema) {

override def start(): Unit =
this.groupValue = Value.record(Map.empty)

override def end(): Unit =
parent.put(name, this.groupValue)

}
parent: Option[GroupValueConverter[_]]
): GroupValueConverter[GroupValue.RecordValue] = parent match {
case Some(_) =>
new GroupValueConverter[GroupValue.RecordValue](schema, parent) {
override def start(): Unit = ()
override def end(): Unit = ()
}
case _ =>
new GroupValueConverter[GroupValue.RecordValue](schema, parent) {

override def start(): Unit =
this.groupValue = Value.record(Map.empty)

override def end(): Unit =
self.put(name, this.groupValue)

}
}

def list[V <: GroupValue[V]](
private def list(
schema: GroupType,
name: String,
parent: GroupValueConverter[V]
name: String
): GroupValueConverter[GroupValue.ListValue] =
new GroupValueConverter[GroupValue.ListValue](schema) {

override def start(): Unit =
this.groupValue = Value.list(Chunk.empty)

override def end(): Unit =
parent.put(name, this.groupValue)
self.put(name, this.groupValue)
}

def map[V <: GroupValue[V]](
private def map(
schema: GroupType,
name: String,
parent: GroupValueConverter[V]
name: String
): GroupValueConverter[GroupValue.MapValue] =
new GroupValueConverter[GroupValue.MapValue](schema) {

override def start(): Unit =
this.groupValue = Value.map(Map.empty)

override def end(): Unit =
parent.put(name, this.groupValue)
self.put(name, this.groupValue)
}

}

object GroupValueConverter {

def root(schema: GroupType): GroupValueConverter[GroupValue.RecordValue] =
new GroupValueConverter[GroupValue.RecordValue](schema) {

override def start(): Unit =
this.groupValue = Value.record(
schema.getFields.asScala.toList.map(_.getName -> Value.nil).toMap
)

override def end(): Unit = ()
}

}
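The converter change above hinges on Parquet's three-level LIST encoding: a LIST-annotated group wraps a repeated inner group conventionally named `list`, which holds the element. The new optional `parent` makes the converter built for that repeated wrapper transparent, so element values are routed into the enclosing list converter via `parent.getOrElse(self)`. A minimal, self-contained sketch of that routing (hypothetical `GroupCollector`/`IntValue` names, not the library's API):

```scala
// Routing rule from the diff above: a collector created with a parent forwards
// every value to it instead of accumulating its own group (hypothetical types).
object ListRoutingSketch extends App {

  sealed trait Value
  final case class IntValue(i: Int) extends Value

  final class GroupCollector(parent: Option[GroupCollector] = None) {
    var values: Vector[Value] = Vector.empty

    // Accumulate a value in this group.
    def put(name: String, value: Value): Unit =
      values = values :+ value

    // Primitive callbacks go to the parent when one is present, otherwise to this group,
    // mirroring `parent.getOrElse(self).put(...)` in the converter.
    def addInt(name: String, i: Int): Unit =
      parent.getOrElse(this).put(name, IntValue(i))
  }

  val listGroup     = new GroupCollector()                         // the LIST-annotated group
  val repeatedGroup = new GroupCollector(parent = Some(listGroup)) // the repeated "list" wrapper

  repeatedGroup.addInt("element", 1)
  repeatedGroup.addInt("element", 2)

  println(listGroup.values)     // Vector(IntValue(1), IntValue(2)) - elements land in the list group
  println(repeatedGroup.values) // Vector() - the repeated wrapper stays transparent
}
```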
ParquetIOSpec.scala
@@ -15,7 +15,7 @@ object ParquetIOSpec extends ZIOSpecDefault {
val tmpCrcPath = tmpDir / ".parquet-writer-spec.parquet.crc"
val tmpPath = tmpDir / tmpFile

case class Record(a: Int, b: String, c: Option[Long])
case class Record(a: Int, b: String, c: Option[Long], d: List[Int])
object Record {
implicit val schema: Schema[Record] =
DeriveSchema.gen[Record]
@@ -30,7 +30,10 @@ object ParquetIOSpec extends ZIOSpecDefault {
override def spec: Spec[TestEnvironment with Scope, Any] =
suite("ParquetIOSpec")(
test("write and read") {
val payload = Chunk(Record(1, "foo", None), Record(2, "bar", Some(3L)))
val payload = Chunk(
Record(1, "foo", None, List(1, 2)),
Record(2, "bar", Some(3L), List.empty)
)

for {
writer <- ZIO.service[ParquetWriter[Record]]
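The new `d: List[Int]` field is what exercises the fixed list converter in this round-trip test. Under the Parquet format's standard LIST representation it maps to a group like the sketch below; the exact optionality of `d` and `element` depends on how the schema derivation treats non-optional fields:

```
optional group d (LIST) {
  repeated group list {
    required int32 element;
  }
}
```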
