Skip to content

Commit

Permalink
OpenAi: Add support for structured outputs and JSON schema
Browse files Browse the repository at this point in the history
- Added support for OpenAI's structured outputs feature, which allows specifying a JSON schema for the model to match
- Introduced new record to configure the desired response format
- Added support for configuring the response format via application properties or the chat options builder
- Extend teh BeanOutputConverter to help generate JSON schema from a target domain object and convert the response.
- Added comprehensive tests to cover the new response format functionality

Resolves #1196
  • Loading branch information
tzolov authored and markpollack committed Aug 8, 2024
1 parent 866b262 commit 91afed5
Show file tree
Hide file tree
Showing 7 changed files with 634 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
Expand Down Expand Up @@ -521,7 +522,53 @@ public static Object FUNCTION(String functionName) {
*/
@JsonInclude(Include.NON_NULL)
public record ResponseFormat(
@JsonProperty("type") String type) {
@JsonProperty("type") Type type,
@JsonProperty("json_schema") JsonSchema jsonSchema ) {

public enum Type {
/**
* Enables JSON mode, which guarantees the message
* the model generates is valid JSON.
*/
@JsonProperty("json_object")
JSON_OBJECT,

/**
* Enables Structured Outputs which guarantees the model
* will match your supplied JSON schema.
*/
@JsonProperty("json_schema")
JSON_SCHEMA
}

@JsonInclude(Include.NON_NULL)
public record JsonSchema(
@JsonProperty("name") String name,
@JsonProperty("schema") Map<String, Object> schema,
@JsonProperty("strict") Boolean strict) {

public JsonSchema(String name, String schema) {
this(name, ModelOptionsUtils.jsonToMap(schema), true);
}

public JsonSchema(String name, String schema, Boolean strict) {
this(StringUtils.hasText(name)? name : "custom_response_format_schema", ModelOptionsUtils.jsonToMap(schema), strict);
}
}

public ResponseFormat(Type type) {
this(type, (JsonSchema) null);
}

public ResponseFormat(Type type, String jsonSchena) {
this(type, "custom_response_format_schema", jsonSchena, true);
}

@ConstructorBinding
public ResponseFormat(Type type, String name, String schema, Boolean strict) {
this(type, StringUtils.hasText(schema)? new JsonSchema(name, schema, strict): null);
}

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,43 +15,46 @@
*/
package org.springframework.ai.openai.chat;

import com.fasterxml.jackson.core.JacksonException;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import static org.assertj.core.api.Assertions.assertThat;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat;
import org.springframework.ai.openai.api.OpenAiApi.ChatModel;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;

import static org.assertj.core.api.Assertions.assertThat;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JacksonException;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;

/**
* @author Christian Tzolov
*/
@SpringBootTest(classes = OpenAiChatModel2IT.Config.class)
@SpringBootTest(classes = OpenAiChatModelResponseFormatIT.Config.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class OpenAiChatModel2IT {
public class OpenAiChatModelResponseFormatIT {

private final Logger logger = LoggerFactory.getLogger(getClass());

@Autowired
private OpenAiChatModel openAiChatModel;

@Test
void responseFormatTest() throws JsonMappingException, JsonProcessingException {
void jsonObject() throws JsonMappingException, JsonProcessingException {

// 400 - ResponseError[error=Error[message='json' is not one of ['json_object',
// 'text'] -
Expand All @@ -64,7 +67,50 @@ void responseFormatTest() throws JsonMappingException, JsonProcessingException {

Prompt prompt = new Prompt("List 8 planets. Use JSON response",
OpenAiChatOptions.builder()
.withResponseFormat(new ChatCompletionRequest.ResponseFormat("json_object"))
.withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_OBJECT))
.build());

ChatResponse response = this.openAiChatModel.call(prompt);

assertThat(response).isNotNull();

String content = response.getResult().getOutput().getContent();

logger.info("Response content: {}", content);

assertThat(isValidJson(content)).isTrue();
}

@Test
void jsonSchema() throws JsonMappingException, JsonProcessingException {

var jsonSchema = """
{
"type": "object",
"properties": {
"steps": {
"type": "array",
"items": {
"type": "object",
"properties": {
"explanation": { "type": "string" },
"output": { "type": "string" }
},
"required": ["explanation", "output"],
"additionalProperties": false
}
},
"final_answer": { "type": "string" }
},
"required": ["steps", "final_answer"],
"additionalProperties": false
}
""";

Prompt prompt = new Prompt("how can I solve 8x + 7 = -23",
OpenAiChatOptions.builder()
.withModel(ChatModel.GPT_4_O_MINI)
.withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_SCHEMA, jsonSchema))
.build());

ChatResponse response = this.openAiChatModel.call(prompt);
Expand All @@ -78,6 +124,47 @@ void responseFormatTest() throws JsonMappingException, JsonProcessingException {
assertThat(isValidJson(content)).isTrue();
}

@Test
void jsonSchemaBeanConverter() throws JsonMappingException, JsonProcessingException {

record MathReasoning(@JsonProperty(required = true, value = "steps") Steps steps,
@JsonProperty(required = true, value = "final_answer") String finalAnswer) {

record Steps(@JsonProperty(required = true, value = "items") Items[] items) {

record Items(@JsonProperty(required = true, value = "explanation") String explanation,
@JsonProperty(required = true, value = "output") String output) {
}
}
}

var outputConverter = new BeanOutputConverter<>(MathReasoning.class);

var jsonSchema1 = outputConverter.getJsonSchema();

System.out.println(jsonSchema1);

Prompt prompt = new Prompt("how can I solve 8x + 7 = -23",
OpenAiChatOptions.builder()
.withModel(ChatModel.GPT_4_O_MINI)
.withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_SCHEMA, jsonSchema1))
.build());

ChatResponse response = this.openAiChatModel.call(prompt);

assertThat(response).isNotNull();

String content = response.getResult().getOutput().getContent();

logger.info("Response content: {}", content);

MathReasoning mathReasoning = outputConverter.convert(content);

System.out.println(mathReasoning);

assertThat(isValidJson(content)).isTrue();
}

private static ObjectMapper MAPPER = new ObjectMapper().enable(DeserializationFeature.FAIL_ON_TRAILING_TOKENS);

public static boolean isValidJson(String json) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectWriter;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.github.victools.jsonschema.generator.Option;
import com.github.victools.jsonschema.generator.SchemaGenerator;
import com.github.victools.jsonschema.generator.SchemaGeneratorConfig;
import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder;
import com.github.victools.jsonschema.module.jackson.JacksonModule;
import com.github.victools.jsonschema.module.jackson.JacksonOption;

/**
* An implementation of {@link StructuredOutputConverter} that transforms the LLM output
Expand Down Expand Up @@ -140,9 +142,10 @@ private BeanOutputConverter(TypeReference<T> typeRef, ObjectMapper objectMapper)
* Generates the JSON schema for the target type.
*/
private void generateSchema() {
JacksonModule jacksonModule = new JacksonModule();
JacksonModule jacksonModule = new JacksonModule(JacksonOption.RESPECT_JSONPROPERTY_REQUIRED);
SchemaGeneratorConfigBuilder configBuilder = new SchemaGeneratorConfigBuilder(DRAFT_2020_12, PLAIN_JSON)
.with(jacksonModule);
.with(jacksonModule)
.with(Option.FORBIDDEN_ADDITIONAL_PROPERTIES_BY_DEFAULT);
SchemaGeneratorConfig config = configBuilder.build();
SchemaGenerator generator = new SchemaGenerator(config);
JsonNode jsonNode = generator.generateSchema(this.typeRef.getType());
Expand Down Expand Up @@ -205,4 +208,12 @@ public String getFormat() {
return String.format(template, this.jsonSchema);
}

/**
* Provides the generated JSON schema for the target type.
* @return The generated JSON schema.
*/
public String getJsonSchema() {
return this.jsonSchema;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ public void formatClassType() {
"someString" : {
"type" : "string"
}
}
},
"additionalProperties" : false
}```
""");
}
Expand All @@ -156,7 +157,8 @@ public void formatTypeReference() {
"someString" : {
"type" : "string"
}
}
},
"additionalProperties" : false
}```
""");
}
Expand All @@ -181,7 +183,8 @@ public void formatTypeReferenceArray() {
"someString" : {
"type" : "string"
}
}
},
"additionalProperties" : false
}
}```
""");
Expand All @@ -199,7 +202,8 @@ public void formatClassTypeWithAnnotations() {
"type" : "string",
"description" : "string_property_description"
}
}
},
"additionalProperties" : false
}```
""");
}
Expand All @@ -217,7 +221,8 @@ public void formatTypeReferenceWithAnnotations() {
"type" : "string",
"description" : "string_property_description"
}
}
},
"additionalProperties" : false
}```
""");
}
Expand Down
Loading

0 comments on commit 91afed5

Please sign in to comment.