-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Experimental reflective inference of json schema from a case class - …
…example: CreateChatCompletionJsonForCaseClass
- Loading branch information
1 parent
452eaaf
commit 5e89f4b
Showing
6 changed files
with
229 additions
and
11 deletions.
There are no files selected for viewing
101 changes: 101 additions & 0 deletions
101
openai-core/src/main/scala/io/cequence/openaiscala/service/JsonSchemaReflectionHelper.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
package io.cequence.openaiscala.service | ||
|
||
import io.cequence.openaiscala.OpenAIScalaClientException | ||
import io.cequence.openaiscala.domain.JsonSchema | ||
|
||
import scala.reflect.runtime.universe._ | ||
import io.cequence.openaiscala.service.ReflectionUtil._ | ||
|
||
// This is experimental and subject to change | ||
trait JsonSchemaReflectionHelper { | ||
|
||
def jsonSchemaFor[T: TypeTag]( | ||
dateAsNumber: Boolean = false, | ||
useRuntimeMirror: Boolean = false | ||
): JsonSchema = { | ||
val mirror = if (useRuntimeMirror) runtimeMirror(getClass.getClassLoader) else typeTag[T].mirror | ||
asJsonSchema(typeOf[T], mirror, dateAsNumber) | ||
} | ||
|
||
private def asJsonSchema( | ||
typ: Type, | ||
mirror: Mirror, | ||
dateAsNumber: Boolean = false | ||
): JsonSchema = | ||
typ match { | ||
// number | ||
case t | ||
if t matches (typeOf[Int], typeOf[Long], typeOf[Byte], typeOf[Double], typeOf[ | ||
Float | ||
], typeOf[BigDecimal], typeOf[BigInt]) => | ||
JsonSchema.Number() | ||
|
||
// boolean | ||
case t if t matches typeOf[Boolean] => | ||
JsonSchema.Boolean() | ||
|
||
// string | ||
case t if t matches (typeOf[String], typeOf[java.util.UUID]) => | ||
JsonSchema.String() | ||
|
||
// enum | ||
case t if t subMatches (typeOf[Enumeration#Value], typeOf[Enum[_]]) => | ||
JsonSchema.String() | ||
|
||
// date | ||
case t if t matches (typeOf[java.util.Date], typeOf[org.joda.time.DateTime]) => | ||
if (dateAsNumber) JsonSchema.Number() else JsonSchema.String() | ||
|
||
// array/seq | ||
case t if t subMatches (typeOf[Seq[_]], typeOf[Set[_]], typeOf[Array[_]]) => | ||
val innerType = t.typeArgs.head | ||
val itemsSchema = asJsonSchema(innerType, mirror, dateAsNumber) | ||
JsonSchema.Array(itemsSchema) | ||
|
||
case t if isCaseClass(t) => | ||
caseClassAsJsonSchema(t, mirror, dateAsNumber) | ||
|
||
// map - TODO | ||
case t if t subMatches (typeOf[Map[String, _]]) => | ||
throw new OpenAIScalaClientException( | ||
"JSON schema reflection doesn't support 'Map' type." | ||
) | ||
|
||
// either value - TODO | ||
case t if t matches typeOf[Either[_, _]] => | ||
throw new OpenAIScalaClientException( | ||
"JSON schema reflection doesn't support 'Either' type." | ||
) | ||
|
||
// otherwise | ||
case _ => | ||
val typeName = | ||
if (typ <:< typeOf[Option[_]]) | ||
s"Option[${typ.typeArgs.head.typeSymbol.fullName}]" | ||
else | ||
typ.typeSymbol.fullName | ||
throw new OpenAIScalaClientException(s"Type ${typeName} unknown.") | ||
} | ||
|
||
private def caseClassAsJsonSchema( | ||
typ: Type, | ||
mirror: Mirror, | ||
dateAsNumber: Boolean | ||
): JsonSchema = { | ||
val memberNamesAndTypes = getCaseClassMemberNamesAndTypes(typ) | ||
|
||
val fieldSchemas = memberNamesAndTypes.toSeq.map { | ||
case (fieldName: String, memberType: Type) => | ||
val fieldSchema = asJsonSchema(memberType, mirror, dateAsNumber) | ||
(fieldName, fieldSchema, memberType.isOption()) | ||
} | ||
|
||
val required = fieldSchemas.collect { case (fieldName, _, false) => fieldName } | ||
val properties = fieldSchemas.map { case (fieldName, schema, _) => (fieldName, schema) } | ||
|
||
JsonSchema.Object( | ||
properties.toMap, | ||
required | ||
) | ||
} | ||
} |
41 changes: 41 additions & 0 deletions
41
openai-core/src/main/scala/io/cequence/openaiscala/service/ReflectionUtil.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
package io.cequence.openaiscala.service | ||
|
||
import scala.reflect.runtime.universe._ | ||
|
||
object ReflectionUtil { | ||
|
||
implicit class InfixOp(val typ: Type) { | ||
|
||
private val optionInnerType = | ||
if (typ <:< typeOf[Option[_]]) | ||
Some(typ.typeArgs.head) | ||
else | ||
None | ||
|
||
def matches(types: Type*): Boolean = | ||
types.exists(typ =:= _) || | ||
(optionInnerType.isDefined && types.exists(optionInnerType.get =:= _)) | ||
|
||
def subMatches(types: Type*): Boolean = | ||
types.exists(typ <:< _) || | ||
(optionInnerType.isDefined && types.exists(optionInnerType.get <:< _)) | ||
|
||
def isOption(): Boolean = | ||
typ <:< typeOf[Option[_]] | ||
} | ||
|
||
def isCaseClass(runType: Type): Boolean = | ||
runType.members.exists(m => m.isMethod && m.asMethod.isCaseAccessor) | ||
|
||
def shortName(symbol: Symbol): String = { | ||
val paramFullName = symbol.fullName | ||
paramFullName.substring(paramFullName.lastIndexOf('.') + 1, paramFullName.length) | ||
} | ||
|
||
def getCaseClassMemberNamesAndTypes( | ||
runType: Type | ||
): Traversable[(String, Type)] = | ||
runType.decls.sorted.collect { | ||
case m: MethodSymbol if m.isCaseAccessor => (shortName(m), m.returnType) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
47 changes: 47 additions & 0 deletions
47
...rc/main/scala/io/cequence/openaiscala/examples/CreateChatCompletionJsonForCaseClass.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
package io.cequence.openaiscala.examples | ||
|
||
import io.cequence.openaiscala.domain._ | ||
import io.cequence.openaiscala.domain.settings.JsonSchemaDef | ||
import io.cequence.openaiscala.examples.fixtures.TestFixtures | ||
import io.cequence.openaiscala.service.{JsonSchemaReflectionHelper, OpenAIServiceConsts} | ||
import play.api.libs.json.Json | ||
|
||
import scala.concurrent.Future | ||
|
||
// experimental | ||
object CreateChatCompletionJsonForCaseClass extends Example with TestFixtures with JsonSchemaReflectionHelper with OpenAIServiceConsts { | ||
|
||
private val messages = Seq( | ||
SystemMessage(capitalsPrompt), | ||
UserMessage("List only african countries") | ||
) | ||
|
||
// Case class(es) | ||
private case class CapitalsResponse( | ||
countries: Seq[Country], | ||
) | ||
|
||
private case class Country( | ||
country: String, | ||
capital: String | ||
) | ||
|
||
// json schema def | ||
private val jsonSchemaDef: JsonSchemaDef = JsonSchemaDef( | ||
name = "capitals_response", | ||
strict = true, | ||
// reflective json schema for case class | ||
structure = jsonSchemaFor[CapitalsResponse]() | ||
) | ||
|
||
override protected def run: Future[_] = | ||
service | ||
.createChatCompletion( | ||
messages = messages, | ||
settings = DefaultSettings.createJsonChatCompletion(jsonSchemaDef) | ||
) | ||
.map { response => | ||
val json = Json.parse(messageContent(response)) | ||
println(Json.prettyPrint(json)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters