Skip to content

Commit

Permalink
Experimental reflective inference of json schema from a case class - …
Browse files Browse the repository at this point in the history
…example: CreateChatCompletionJsonForCaseClass
  • Loading branch information
peterbanda committed Sep 17, 2024
1 parent 452eaaf commit 5e89f4b
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package io.cequence.openaiscala.service

import io.cequence.openaiscala.OpenAIScalaClientException
import io.cequence.openaiscala.domain.JsonSchema

import scala.reflect.runtime.universe._
import io.cequence.openaiscala.service.ReflectionUtil._

// This is experimental and subject to change
trait JsonSchemaReflectionHelper {

def jsonSchemaFor[T: TypeTag](
dateAsNumber: Boolean = false,
useRuntimeMirror: Boolean = false
): JsonSchema = {
val mirror = if (useRuntimeMirror) runtimeMirror(getClass.getClassLoader) else typeTag[T].mirror
asJsonSchema(typeOf[T], mirror, dateAsNumber)
}

private def asJsonSchema(
typ: Type,
mirror: Mirror,
dateAsNumber: Boolean = false
): JsonSchema =
typ match {
// number
case t
if t matches (typeOf[Int], typeOf[Long], typeOf[Byte], typeOf[Double], typeOf[
Float
], typeOf[BigDecimal], typeOf[BigInt]) =>
JsonSchema.Number()

// boolean
case t if t matches typeOf[Boolean] =>
JsonSchema.Boolean()

// string
case t if t matches (typeOf[String], typeOf[java.util.UUID]) =>
JsonSchema.String()

// enum
case t if t subMatches (typeOf[Enumeration#Value], typeOf[Enum[_]]) =>
JsonSchema.String()

// date
case t if t matches (typeOf[java.util.Date], typeOf[org.joda.time.DateTime]) =>
if (dateAsNumber) JsonSchema.Number() else JsonSchema.String()

// array/seq
case t if t subMatches (typeOf[Seq[_]], typeOf[Set[_]], typeOf[Array[_]]) =>
val innerType = t.typeArgs.head
val itemsSchema = asJsonSchema(innerType, mirror, dateAsNumber)
JsonSchema.Array(itemsSchema)

case t if isCaseClass(t) =>
caseClassAsJsonSchema(t, mirror, dateAsNumber)

// map - TODO
case t if t subMatches (typeOf[Map[String, _]]) =>
throw new OpenAIScalaClientException(
"JSON schema reflection doesn't support 'Map' type."
)

// either value - TODO
case t if t matches typeOf[Either[_, _]] =>
throw new OpenAIScalaClientException(
"JSON schema reflection doesn't support 'Either' type."
)

// otherwise
case _ =>
val typeName =
if (typ <:< typeOf[Option[_]])
s"Option[${typ.typeArgs.head.typeSymbol.fullName}]"
else
typ.typeSymbol.fullName
throw new OpenAIScalaClientException(s"Type ${typeName} unknown.")
}

private def caseClassAsJsonSchema(
typ: Type,
mirror: Mirror,
dateAsNumber: Boolean
): JsonSchema = {
val memberNamesAndTypes = getCaseClassMemberNamesAndTypes(typ)

val fieldSchemas = memberNamesAndTypes.toSeq.map {
case (fieldName: String, memberType: Type) =>
val fieldSchema = asJsonSchema(memberType, mirror, dateAsNumber)
(fieldName, fieldSchema, memberType.isOption())
}

val required = fieldSchemas.collect { case (fieldName, _, false) => fieldName }
val properties = fieldSchemas.map { case (fieldName, schema, _) => (fieldName, schema) }

JsonSchema.Object(
properties.toMap,
required
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package io.cequence.openaiscala.service

import scala.reflect.runtime.universe._

object ReflectionUtil {

implicit class InfixOp(val typ: Type) {

private val optionInnerType =
if (typ <:< typeOf[Option[_]])
Some(typ.typeArgs.head)
else
None

def matches(types: Type*): Boolean =
types.exists(typ =:= _) ||
(optionInnerType.isDefined && types.exists(optionInnerType.get =:= _))

def subMatches(types: Type*): Boolean =
types.exists(typ <:< _) ||
(optionInnerType.isDefined && types.exists(optionInnerType.get <:< _))

def isOption(): Boolean =
typ <:< typeOf[Option[_]]
}

def isCaseClass(runType: Type): Boolean =
runType.members.exists(m => m.isMethod && m.asMethod.isCaseAccessor)

def shortName(symbol: Symbol): String = {
val paramFullName = symbol.fullName
paramFullName.substring(paramFullName.lastIndexOf('.') + 1, paramFullName.length)
}

def getCaseClassMemberNamesAndTypes(
runType: Type
): Traversable[(String, Type)] =
runType.decls.sorted.collect {
case m: MethodSymbol if m.isCaseAccessor => (shortName(m), m.returnType)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package io.cequence.openaiscala.examples
import io.cequence.openaiscala.domain._
import io.cequence.openaiscala.examples.fixtures.TestFixtures
import io.cequence.openaiscala.service.OpenAIServiceConsts
import play.api.libs.json.Json

import scala.concurrent.Future

Expand All @@ -17,9 +18,10 @@ object CreateChatCompletionJson extends Example with TestFixtures with OpenAISer
service
.createChatCompletion(
messages = messages,
settings = DefaultSettings.createJsonChatCompletion(capitalsSchema)
settings = DefaultSettings.createJsonChatCompletion(capitalsSchemaDef1)
)
.map { content =>
printMessageContent(content)
.map { response =>
val json = Json.parse(messageContent(response))
println(Json.prettyPrint(json))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package io.cequence.openaiscala.examples

import io.cequence.openaiscala.domain._
import io.cequence.openaiscala.domain.settings.JsonSchemaDef
import io.cequence.openaiscala.examples.fixtures.TestFixtures
import io.cequence.openaiscala.service.{JsonSchemaReflectionHelper, OpenAIServiceConsts}
import play.api.libs.json.Json

import scala.concurrent.Future

// experimental
object CreateChatCompletionJsonForCaseClass extends Example with TestFixtures with JsonSchemaReflectionHelper with OpenAIServiceConsts {

private val messages = Seq(
SystemMessage(capitalsPrompt),
UserMessage("List only african countries")
)

// Case class(es)
private case class CapitalsResponse(
countries: Seq[Country],
)

private case class Country(
country: String,
capital: String
)

// json schema def
private val jsonSchemaDef: JsonSchemaDef = JsonSchemaDef(
name = "capitals_response",
strict = true,
// reflective json schema for case class
structure = jsonSchemaFor[CapitalsResponse]()
)

override protected def run: Future[_] =
service
.createChatCompletion(
messages = messages,
settings = DefaultSettings.createJsonChatCompletion(jsonSchemaDef)
)
.map { response =>
val json = Json.parse(messageContent(response))
println(Json.prettyPrint(json))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ object CreateChatCompletionStreamedJson
service
.createChatCompletionStreamed(
messages = messages,
settings = DefaultSettings.createJsonChatCompletion(capitalsSchema)
settings = DefaultSettings.createJsonChatCompletion(capitalsSchemaDef1)
)
.runWith(
Sink.foreach { completion =>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,46 @@
package io.cequence.openaiscala.examples.fixtures

import io.cequence.openaiscala.domain.settings.JsonSchema
import io.cequence.openaiscala.domain.JsonSchema
import io.cequence.openaiscala.domain.settings.JsonSchemaDef
import org.slf4j.LoggerFactory

trait TestFixtures {

val logger = LoggerFactory.getLogger(getClass)

val capitalsPrompt = "Give me the most populous capital cities in JSON format."

val capitalsSchema = JsonSchema(
name = "capitals_response",
strict = true,
structure = capitalsSchemaStructure
val capitalsSchemaDef1 = capitalsSchemaDefAux(Left(capitalsSchema1))

val capitalsSchemaDef2 = capitalsSchemaDefAux(Right(capitalsSchema2))

def capitalsSchemaDefAux(schema: Either[JsonSchema, Map[String, Any]]) =
JsonSchemaDef(
name = "capitals_response",
strict = true,
structure = schema
)

lazy protected val capitalsSchema1 = JsonSchema.Object(
properties = Map(
"countries" -> JsonSchema.Array(
items = JsonSchema.Object(
properties = Map(
"country" -> JsonSchema.String(
description = Some("The name of the country")
),
"capital" -> JsonSchema.String(
description = Some("The capital city of the country")
)
),
required = Seq("country", "capital")
)
)
),
required = Seq("countries")
)

lazy private val capitalsSchemaStructure = Map(
lazy protected val capitalsSchema2 = Map(
"type" -> "object",
"properties" -> Map(
"countries" -> Map(
Expand All @@ -35,5 +63,4 @@ trait TestFixtures {
),
"required" -> Seq("countries")
)

}

0 comments on commit 5e89f4b

Please sign in to comment.