Skip to content

Commit

Permalink
Change JsonOutputParser to MapOutputParser
Browse files Browse the repository at this point in the history
  • Loading branch information
markpollack committed Aug 19, 2023
1 parent 8e276ca commit f2f81cd
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import org.springframework.messaging.converter.MessageConverter;

public abstract class AbstractMessageConverterOutputParser implements OutputParser<Object> {
public abstract class AbstractMessageConverterOutputParser<T> implements OutputParser<T> {

private MessageConverter messageConverter;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
package org.springframework.ai.parser;

import com.fasterxml.jackson.databind.JsonNode;
import org.springframework.messaging.Message;
import org.springframework.messaging.converter.MappingJackson2MessageConverter;
import org.springframework.messaging.converter.MessageConverter;
import org.springframework.messaging.support.MessageBuilder;

import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

/**
* Uses Jackson
*/
public class JsonOutputParser extends AbstractMessageConverterOutputParser {
public class MapOutputParser extends AbstractMessageConverterOutputParser<Map<String, Object>> {

private Class dataType;

public JsonOutputParser(Class dataType) {
public MapOutputParser() {
super(new MappingJackson2MessageConverter());
this.dataType = dataType;
}

@Override
public Object parse(String text) {
public Map<String, Object> parse(String text) {
Message<?> message = MessageBuilder.withPayload(text.getBytes(StandardCharsets.UTF_8)).build();
return getMessageConverter().fromMessage(message, dataType);
return (Map) getMessageConverter().fromMessage(message, HashMap.class);
}

@Override
Expand All @@ -32,7 +31,7 @@ public String getFormat() {
The data structure for the JSON should match this Java class: %s
Do not include any explanations, only provide a RFC8259 compliant JSON response following this format without deviation.
""";
return String.format(raw, dataType.getCanonicalName());
return String.format(raw, "java.util.HashMap");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import org.springframework.ai.client.AiResponse;
import org.springframework.ai.client.Generation;
import org.springframework.ai.openai.testutils.AbstractIntegrationTest;
import org.springframework.ai.parser.JsonOutputParser;
import org.springframework.ai.parser.ListOutputParser;
import org.springframework.ai.parser.MapOutputParser;
import org.springframework.ai.prompt.Prompt;
import org.springframework.ai.prompt.PromptTemplate;
import org.springframework.ai.prompt.SystemPromptTemplate;
Expand All @@ -17,7 +17,6 @@
import org.springframework.core.io.Resource;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -64,8 +63,8 @@ void outputParser() {
}

@Test
void jsonOutputParser() {
JsonOutputParser outputParser = new JsonOutputParser(HashMap.class);
void mapOutputParser() {
MapOutputParser outputParser = new MapOutputParser();

String format = outputParser.getFormat();
String template = """
Expand All @@ -77,10 +76,8 @@ void jsonOutputParser() {
Prompt prompt = new Prompt(promptTemplate.createMessage());
Generation generation = openAiClient.generate(prompt).getGeneration();

Object result = outputParser.parse(generation.getText());
System.out.println(result);
assertThat(result).isNotNull();
assertThat(((Map) result).get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
Map<String, Object> result = outputParser.parse(generation.getText());
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));

}

Expand Down

0 comments on commit f2f81cd

Please sign in to comment.