Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions ai-core/src/main/scala/wvlet/ai/core/weaver/CaseClassWeaver.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package wvlet.ai.core.weaver

import scala.deriving.Mirror // Keep Mirror for `m`
// erasedValue, summonInline, constValue, error are no longer needed here
import wvlet.ai.core.msgpack.spi.{Packer, Unpacker}

// Removed duplicate ObjectWeaver trait.
// The canonical one is in ObjectWeaver.scala

/**
* Custom exception for errors occurring during weaver packing.
* @param message
* A description of the error.
* @param cause
* The underlying cause of the error, if any.
*/
case class WeaverPackingException(message: String, cause: Throwable = null)
extends RuntimeException(message, cause)

// Companion object removed for this attempt

// Constructor now accepts elementWeavers. Mirror m is still needed for fromProduct.
class CaseClassWeaver[A](private val elementWeavers: List[ObjectWeaver[?]])(using
m: Mirror.ProductOf[A]
) extends ObjectWeaver[A]:

// Internal buildWeavers and elementWeavers val are removed.

override def pack(packer: Packer, v: A, config: WeaverConfig): Unit =
val product = v.asInstanceOf[Product]
if product.productArity != elementWeavers.size then
throw WeaverPackingException(
s"Element count mismatch. Expected: ${elementWeavers.size}, Got: ${product.productArity}"
)
packer.packArrayHeader(elementWeavers.size)

product
.productIterator
.zip(elementWeavers)
.foreach { case (elemValue, weaver) =>
(weaver.asInstanceOf[ObjectWeaver[Any]]).pack(packer, elemValue, config)
}

override def unpack(unpacker: Unpacker, context: WeaverContext): Unit =
val numElements = unpacker.unpackArrayHeader
if numElements != elementWeavers.size then
context.setError(
new IllegalArgumentException(
s"Element count mismatch. Expected: ${elementWeavers.size}, Got: ${numElements}"
)
)
// This point is for future consideration of schema evolution or robust error recovery.
// For now, strict element count matching is enforced.
return

val elements = new Array[Any](elementWeavers.size)
var i = 0
var failed = false

while i < elementWeavers.size && !failed do
val weaver = elementWeavers(i)
val elementContext = WeaverContext(context.config)
// Assuming weaver is ObjectWeaver[?] so direct call is not possible without cast
// However, the element type is unknown here to do a safe cast.
// This part of unpack will need careful handling if we stick to List[ObjectWeaver[?]]
(weaver.asInstanceOf[ObjectWeaver[Any]]).unpack(unpacker, elementContext)

if elementContext.hasError then
context.setError(
new RuntimeException(
s"Failed to unpack element $i: ${elementContext.getError.get.getMessage}",
elementContext.getError.get
)
)
failed = true
else
elements(i) = elementContext.getLastValue
i += 1

if !failed then
try
val instance = m.fromProduct(
new Product:
override def productArity: Int = elements.length
override def productElement(n: Int): Any = elements(n)
override def canEqual(that: Any): Boolean =
that.isInstanceOf[Product] && that.asInstanceOf[Product].productArity == productArity
)
context.setObject(instance)
catch
case e: Throwable =>
context.setError(new RuntimeException("Failed to instantiate case class from product", e))
// Closing brace for try-catch
// Closing brace for if (!failed)
// If failed, context will already have an error set.
// Closing brace for unpack method
end unpack

// Closing brace for CaseClassWeaver class

end CaseClassWeaver
18 changes: 18 additions & 0 deletions ai-core/src/main/scala/wvlet/ai/core/weaver/ObjectWeaver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package wvlet.ai.core.weaver

import wvlet.ai.core.msgpack.spi.{MessagePack, MsgPack, Packer, Unpacker}
import wvlet.ai.core.weaver.codec.{JSONWeaver, PrimitiveWeaver}
import scala.deriving.Mirror
import scala.compiletime.{constValue, summonInline}

trait ObjectWeaver[A]:
def weave(v: A, config: WeaverConfig = WeaverConfig()): MsgPack = toMsgPack(v, config)
Expand Down Expand Up @@ -61,3 +63,19 @@ object ObjectWeaver:
): A = weaver.fromJson(json, config)

export PrimitiveWeaver.given

private inline def buildWeaverList[ElemTypes <: Tuple](
idx: Int
): List[ObjectWeaver[?]] = // Removed inline from idx
inline if idx >= constValue[Tuple.Size[ElemTypes]] then // Base case: index out of bounds
Nil
else
// Summons ObjectWeaver for the element type at the current index
val headWeaver = summonInline[ObjectWeaver[Tuple.Elem[ElemTypes, idx.type]]]
headWeaver :: buildWeaverList[ElemTypes](idx + 1) // Recursive call

inline given [A](using m: Mirror.ProductOf[A]): ObjectWeaver[A] =
val weavers = buildWeaverList[m.MirroredElemTypes](0)
new CaseClassWeaver[A](weavers)(using m)

end ObjectWeaver
Original file line number Diff line number Diff line change
Expand Up @@ -610,4 +610,25 @@ object PrimitiveWeaver:
u.skipValue
context.setError(new IllegalArgumentException(s"Cannot convert ${other} to ListMap"))

inline given optionWeaver[T](using elementWeaver: => ObjectWeaver[T]): ObjectWeaver[Option[T]] =
new ObjectWeaver[Option[T]]:
override def pack(p: Packer, v: Option[T], config: WeaverConfig): Unit =
v match
case Some(value) =>
elementWeaver.pack(p, value, config)
case None =>
p.packNil // Corrected: removed parentheses

override def unpack(u: Unpacker, context: WeaverContext): Unit =
if u.tryUnpackNil then
context.setObject(None)
else
// Need a fresh context for the element, in case of error or nested structures
val elementContext = WeaverContext(context.config)
elementWeaver.unpack(u, elementContext)
if elementContext.hasError then
context.setError(elementContext.getError.get)
else
context.setObject(Some(elementContext.getLastValue.asInstanceOf[T]))

end PrimitiveWeaver
98 changes: 98 additions & 0 deletions ai-core/src/test/scala/wvlet/ai/core/weaver/WeaverTest.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
package wvlet.ai.core.weaver

import wvlet.airspec.AirSpec
import wvlet.ai.core.weaver.ObjectWeaver // Ensure ObjectWeaver is imported if not already fully covered
import scala.jdk.CollectionConverters.*

// Define case classes for testing
case class SimpleCase(i: Int, s: String, b: Boolean)
case class NestedCase(name: String, simple: SimpleCase)
case class OptionCase(id: Int, opt: Option[String])
case class SeqCase(key: String, values: Seq[Int])

class WeaverTest extends AirSpec:

test("weave int") {
Expand Down Expand Up @@ -497,4 +504,95 @@ class WeaverTest extends AirSpec:
result.get.getMessage.contains("Cannot convert") shouldBe true
}

// Tests for SimpleCase
test("weave SimpleCase") {
val v = SimpleCase(10, "test case", true)
val msgpack = ObjectWeaver.weave(v)
val v2 = ObjectWeaver.unweave[SimpleCase](msgpack)
v shouldBe v2
}

test("SimpleCase toJson") {
val v = SimpleCase(20, "json test", false)
val json = ObjectWeaver.toJson(v)
val v2 = ObjectWeaver.fromJson[SimpleCase](json)
v shouldBe v2
}

// Tests for NestedCase
test("weave NestedCase") {
val v = NestedCase("nested", SimpleCase(30, "inner", true))
val msgpack = ObjectWeaver.weave(v)
val v2 = ObjectWeaver.unweave[NestedCase](msgpack)
v shouldBe v2
}

test("NestedCase toJson") {
val v = NestedCase("nested json", SimpleCase(40, "inner json", false))
val json = ObjectWeaver.toJson(v)
val v2 = ObjectWeaver.fromJson[NestedCase](json)
v shouldBe v2
}

// Tests for OptionCase
test("weave OptionCase with Some") {
val v = OptionCase(50, Some("option value"))
val msgpack = ObjectWeaver.weave(v)
val v2 = ObjectWeaver.unweave[OptionCase](msgpack)
v shouldBe v2
}

test("OptionCase toJson with Some") {
val v = OptionCase(60, Some("option json"))
val json = ObjectWeaver.toJson(v)
val v2 = ObjectWeaver.fromJson[OptionCase](json)
v shouldBe v2
}

test("weave OptionCase with None") {
val v = OptionCase(70, None)
val msgpack = ObjectWeaver.weave(v)
val v2 = ObjectWeaver.unweave[OptionCase](msgpack)
v shouldBe v2
}

test("OptionCase toJson with None") {
val v = OptionCase(80, None)
val json = ObjectWeaver.toJson(v)
// Check against expected JSON for None, as direct None might be ambiguous for fromJson
// Depending on JSON library, None might be represented as null or omitted
// Assuming it's represented as null or handled by the weaver
val v2 = ObjectWeaver.fromJson[OptionCase](json)
v shouldBe v2
}

// Tests for SeqCase
test("weave SeqCase with non-empty Seq") {
val v = SeqCase("seq test", Seq(1, 2, 3, 4))
val msgpack = ObjectWeaver.weave(v)
val v2 = ObjectWeaver.unweave[SeqCase](msgpack)
v shouldBe v2
}

test("SeqCase toJson with non-empty Seq") {
val v = SeqCase("seq json", Seq(5, 6, 7))
val json = ObjectWeaver.toJson(v)
val v2 = ObjectWeaver.fromJson[SeqCase](json)
v shouldBe v2
}

test("weave SeqCase with empty Seq") {
val v = SeqCase("empty seq", Seq.empty[Int])
val msgpack = ObjectWeaver.weave(v)
val v2 = ObjectWeaver.unweave[SeqCase](msgpack)
v shouldBe v2
}

test("SeqCase toJson with empty Seq") {
val v = SeqCase("empty seq json", Seq.empty[Int])
val json = ObjectWeaver.toJson(v)
val v2 = ObjectWeaver.fromJson[SeqCase](json)
v shouldBe v2
}

end WeaverTest
Loading