Skip to content

Commit

Permalink
Implement deriveSequence for SchemaEncoder
Browse files Browse the repository at this point in the history
  • Loading branch information
grouzen committed Dec 5, 2023
1 parent 373a437 commit e905b53
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import zio.Chunk

object Schemas {

trait Def[Self <: Def[_]] {
abstract class Def[Self <: Def[_]] {

def named(name: String): Type

Expand Down Expand Up @@ -38,10 +38,10 @@ object Schemas {
this.copy(length = len)

def required: PrimitiveDef =
this.copy(isOptional = true)
this.copy(isOptional = false)

def optional: PrimitiveDef =
this.copy(isOptional = false)
this.copy(isOptional = true)

}

Expand All @@ -55,11 +55,30 @@ object Schemas {
}

def required: RecordDef =
this.copy(isOptional = true)
this.copy(isOptional = false)

def optional: RecordDef =
this.copy(isOptional = true)

}

case class ListDef(
element: Type,
isOptional: Boolean = false
) extends Def[ListDef] {

def named(name: String): Type =
Types
.list(repetition(isOptional))
.element(element)
.named(name)

def required: ListDef =
this.copy(isOptional = false)

def optional: ListDef =
this.copy(isOptional = true)

}

def repetition(optional: Boolean): Repetition =
Expand All @@ -77,5 +96,6 @@ object Schemas {
val uuid: PrimitiveDef = PrimitiveDef(FIXED_LEN_BYTE_ARRAY, uuidType()).length(16)

def record(fields: Chunk[Type]): RecordDef = RecordDef(fields)
def list(element: Type): ListDef = ListDef(element)

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,8 @@ object SchemaEncoderDeriver {
summoned: => Option[SchemaEncoder[A]]
): SchemaEncoder[A] = new SchemaEncoder[A] {

private def fieldType[A1](
name0: String,
schema0: Schema[A1],
encoder: SchemaEncoder[_]
) = {
val fieldOptional = schema0 match {
case _: Schema.Optional[_] => true
case _ => false
}

encoder.asInstanceOf[SchemaEncoder[A1]].encode(schema0, name0, fieldOptional)
}
private def fieldType[A1](name0: String, schema0: Schema[A1], encoder: SchemaEncoder[_]) =
encoder.asInstanceOf[SchemaEncoder[A1]].encode(schema0, name0, isSchemaOptional(schema0))

override def encode(schema: Schema[A], name: String, optional: Boolean): Type = {
val fieldTypes = record.fields.zip(fields.map(_.unwrap)).map { case (field, encoder) =>
Expand Down Expand Up @@ -82,7 +72,13 @@ object SchemaEncoderDeriver {
sequence: Schema.Sequence[C[A], A, _],
inner: => SchemaEncoder[A],
summoned: => Option[SchemaEncoder[C[A]]]
): SchemaEncoder[C[A]] = ???
): SchemaEncoder[C[A]] = new SchemaEncoder[C[A]] {
override def encode(schema: Schema[C[A]], name: String, optional: Boolean): Type =
Schemas
.list(inner.encode(sequence.elementSchema, "element", isSchemaOptional(sequence.elementSchema)))
.optionality(optional)
.named(name)
}

override def deriveMap[K, V](
map: Schema.Map[K, V],
Expand All @@ -98,8 +94,13 @@ object SchemaEncoderDeriver {
summoned: => Option[SchemaEncoder[B]]
): SchemaEncoder[B] = ???

}
}.cached

val summoned: Deriver[SchemaEncoder] = default.autoAcceptSummoned

private def isSchemaOptional(schema: Schema[_]): Boolean =
schema match {
case _: Schema.Optional[_] => true
case _ => false
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ object SchemaEncoderDeriverSpec extends ZIOSpecDefault {
implicit val schema: Schema[Record] = DeriveSchema.gen[Record]
}

private def encode[A](encoder: SchemaEncoder[_], schema: Schema[A], name: String, optional: Boolean) =
encoder.asInstanceOf[SchemaEncoder[A]].encode(schema, name, optional)

override def spec: Spec[TestEnvironment with Scope, Any] =
suite("SchemaEncoderDeriverSpec")(
test("primitive") {
def encode[A](encoder: SchemaEncoder[_], schema: Schema[A], name: String, optional: Boolean) =
encoder.asInstanceOf[SchemaEncoder[A]].encode(schema, name, optional)

def named(defs: List[PrimitiveDef], names: List[String]) =
defs.zip(names).map { case (schemaDef, name) =>
schemaDef.named(name)
Expand Down Expand Up @@ -87,9 +87,10 @@ object SchemaEncoderDeriverSpec extends ZIOSpecDefault {
.reduce(_ && _)
},
test("record") {
val name = "record"
val encoder = Derive.derive[SchemaEncoder, Record](SchemaEncoderDeriver.default)
val tpeOptional = encoder.encode(Record.schema, "record", optional = true)
val tpeRequired = encoder.encode(Record.schema, "record", optional = false)
val tpeOptional = encoder.encode(Record.schema, name, optional = true)
val tpeRequired = encoder.encode(Record.schema, name, optional = false)
val schemaDef = Schemas.record(
Chunk(
Schemas.int.required.named("a"),
Expand All @@ -98,9 +99,69 @@ object SchemaEncoderDeriverSpec extends ZIOSpecDefault {
)

assertTrue(
tpeOptional == schemaDef.optional.named("record"),
tpeRequired == schemaDef.required.named("record")
tpeOptional == schemaDef.optional.named(name),
tpeRequired == schemaDef.required.named(name)
)
},
test("sequence") {
val name = "mylist"
val encoders: List[SchemaEncoder[_]] =
List(
Derive.derive[SchemaEncoder, List[String]](SchemaEncoderDeriver.default),
Derive.derive[SchemaEncoder, List[Boolean]](SchemaEncoderDeriver.default),
Derive.derive[SchemaEncoder, List[Byte]](SchemaEncoderDeriver.default),
Derive.derive[SchemaEncoder, List[Short]](SchemaEncoderDeriver.default),
Derive.derive[SchemaEncoder, List[Int]](SchemaEncoderDeriver.default),
Derive.derive[SchemaEncoder, List[Long]](SchemaEncoderDeriver.default),
Derive.derive[SchemaEncoder, List[UUID]](SchemaEncoderDeriver.default),
Derive.derive[SchemaEncoder, List[Option[String]]](SchemaEncoderDeriver.default),
Derive.derive[SchemaEncoder, List[Option[Boolean]]](SchemaEncoderDeriver.default),
Derive.derive[SchemaEncoder, List[Option[Byte]]](SchemaEncoderDeriver.default),
Derive.derive[SchemaEncoder, List[Option[Short]]](SchemaEncoderDeriver.default),
Derive.derive[SchemaEncoder, List[Option[Int]]](SchemaEncoderDeriver.default),
Derive.derive[SchemaEncoder, List[Option[Long]]](SchemaEncoderDeriver.default),
Derive.derive[SchemaEncoder, List[Option[UUID]]](SchemaEncoderDeriver.default)
)
val schemas: List[Schema[_]] =
List(
Schema.list[String],
Schema.list[Int],
Schema.list[Option[String]],
Schema.list[Option[Int]]
)
val elements =
List(
Schemas.string,
Schemas.boolean,
Schemas.byte,
Schemas.short,
Schemas.int,
Schemas.long,
Schemas.uuid
)
val schemaDefs =
(elements.map(_.required) ++ elements.map(_.optional))
.map(_.named("element"))
.map(Schemas.list)
val expectedOptional =
schemaDefs.map(_.optional.named(name))
val expectedRequired =
schemaDefs.map(_.required.named(name))

encoders
.zip(schemas)
.zip(expectedOptional)
.zip(expectedRequired)
.map { case (((encoder, schema), expOptional), expRequired) =>
val tpeOptional = encode(encoder, schema, name, optional = true)
val tpeRequired = encode(encoder, schema, name, optional = false)

assertTrue(
tpeOptional == expOptional,
tpeRequired == expRequired
)
}
.reduce(_ && _)
}
)

Expand Down

0 comments on commit e905b53

Please sign in to comment.