Skip to content

Commit

Permalink
[OPIK-751] Use Mustache for online scoring (#1043)
Browse files Browse the repository at this point in the history
* Triggering LLM calls to score after Traces are received

* Adapting LLM provider to the new format.
Fixing missing traceId in the score when we moved into the batch storing.

* PR change requests

* running spotless

* removing ununsed imports

* [OPIK-751] Use mustache for online scoring

* Fix format

* Add edge case tests

---------

Co-authored-by: Daniel Augusto <daniela@comet.com>
  • Loading branch information
thiagohora and ldaugusto authored Jan 15, 2025
1 parent 69f85b0 commit de08570
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.comet.opik.api.FeedbackScoreBatchItem;
import com.comet.opik.api.ScoreSource;
import com.comet.opik.api.Trace;
import com.comet.opik.utils.MustacheUtils;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand All @@ -26,7 +27,6 @@
import lombok.Builder;
import lombok.experimental.UtilityClass;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.text.StringSubstitutor;

import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -103,14 +103,11 @@ static List<ChatMessage> renderMessages(List<LlmAsJudgeMessage> templateMessages
.collect(
Collectors.toMap(MessageVariableMapping::variableName, MessageVariableMapping::valueToReplace));

// will convert all '{{key}}' into 'value'
// TODO: replace with Mustache Java to be in confirm with FE
var templateRenderer = new StringSubstitutor(replacements, "{{", "}}");

// render the message templates from evaluator rule
return templateMessages.stream()
.map(templateMessage -> {
var renderedMessage = templateRenderer.replace(templateMessage.content());
// will convert all '{{key}}' into 'value'
var renderedMessage = MustacheUtils.render(templateMessage.content(), replacements);

return switch (templateMessage.role()) {
case USER -> UserMessage.from(renderedMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import com.comet.opik.api.PromptVersion.PromptVersionPage;
import com.comet.opik.api.error.EntityAlreadyExistsException;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.utils.MustacheVariableExtractor;
import com.comet.opik.utils.MustacheUtils;
import com.google.inject.ImplementedBy;
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.inject.Inject;
Expand Down Expand Up @@ -396,7 +396,7 @@ private Set<String> getVariables(String template) {
return null;
}

return MustacheVariableExtractor.extractVariables(template);
return MustacheUtils.extractVariables(template);
}

private EntityAlreadyExistsException newConflict(String alreadyExists) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.comet.opik.api.PromptVersion;
import com.comet.opik.utils.JsonUtils;
import com.comet.opik.utils.MustacheVariableExtractor;
import com.comet.opik.utils.MustacheUtils;
import com.fasterxml.jackson.databind.JsonNode;
import org.jdbi.v3.core.mapper.ColumnMapper;
import org.jdbi.v3.core.statement.StatementContext;
Expand Down Expand Up @@ -35,7 +35,7 @@ private PromptVersion mapObject(JsonNode jsonNode) {
.template(jsonNode.get("template").asText())
.metadata(jsonNode.get("metadata"))
.changeDescription(jsonNode.get("change_description").asText())
.variables(MustacheVariableExtractor.extractVariables(jsonNode.get("template").asText()))
.variables(MustacheUtils.extractVariables(jsonNode.get("template").asText()))
.createdAt(Instant.from(FORMATTER.parse(jsonNode.get("created_at").asText())))
.createdBy(jsonNode.get("created_by").asText())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@
import com.github.mustachejava.codes.ValueCode;
import lombok.experimental.UtilityClass;

import java.io.IOException;
import java.io.StringReader;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.io.Writer;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

@UtilityClass
public class MustacheVariableExtractor {
public class MustacheUtils {

public static final MustacheFactory MF = new DefaultMustacheFactory();

Expand All @@ -31,6 +36,18 @@ public static Set<String> extractVariables(String template) {
return variables;
}

public static String render(String template, Map<String, ?> context) {

Mustache mustache = MF.compile(new StringReader(template), "template");

try (Writer writer = mustache.execute(new StringWriter(), context)) {
writer.flush();
return writer.toString();
} catch (IOException e) {
throw new UncheckedIOException("Failed to render template", e);
}
}

private static void collectVariables(Code[] codes, Set<String> variables) {
for (Code code : codes) {
if (Objects.requireNonNull(code) instanceof ValueCode valueCode) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.mockito.junit.jupiter.MockitoExtension;
import uk.co.jemos.podam.api.PodamFactory;

import java.math.BigDecimal;
Expand All @@ -42,7 +43,9 @@
@Slf4j
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@DisplayName("LlmAsJudge Message Render")
public class OnlineScoringEngineTest {
@ExtendWith(MockitoExtension.class)
class OnlineScoringEngineTest {

@Mock
AutomationRuleEvaluatorService ruleEvaluatorService;
@Mock
Expand Down Expand Up @@ -98,14 +101,37 @@ public class OnlineScoringEngineTest {
}
""".formatted(outputStr).trim();

String edgeCaseTemplate = "Summary: {{summary}}\\nInstruction: {{ instruction }}\\n\\n";
String testEvaluatorEdgeCase = """
{
"model": { "name": "gpt-4o", "temperature": 0.3 },
"messages": [
{ "role": "USER", "content": "%s" },
{ "role": "SYSTEM", "content": "You're a helpful AI, be cordial." }
],
"variables": {
"summary": "input.questions.question1",
"instruction": "output.output",
"nonUsed": "input.questions.question2",
"toFail1": "metadata.nonexistent.path"
},
"schema": [
{ "name": "Relevance", "type": "INTEGER", "description": "Relevance of the summary" },
{ "name": "Conciseness", "type": "DOUBLE", "description": "Conciseness of the summary" },
{ "name": "Technical Accuracy", "type": "BOOLEAN", "description": "Technical accuracy of the summary" }
]
}
"""
.formatted(edgeCaseTemplate).trim();

private ObjectMapper mapper = new ObjectMapper();

@BeforeEach
void setUp() throws JsonProcessingException {
MockitoAnnotations.openMocks(this);
Mockito.doNothing().when(eventBus).register(Mockito.any());
onlineScoringEventListener = new OnlineScoringEventListener(eventBus, ruleEvaluatorService,
aiProxyService, feedbackScoreService);

var mapper = new ObjectMapper();
evaluatorCode = mapper.readValue(testEvaluator, AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode.class);
trace = Trace.builder().input(mapper.readTree(input)).output(mapper.readTree(output)).build();
}
Expand Down Expand Up @@ -238,4 +264,25 @@ void testParseResponseIntoFeedbacks(String aiMessage, Integer expectedResults) {

}
}

@Test
@DisplayName("render a message template with edge cases")
void testRenderEdgeCaseTemplate() throws JsonProcessingException {

var evaluatorEdgeCase = mapper.readValue(testEvaluatorEdgeCase,
AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode.class);

var renderedMessages = OnlineScoringEngine.renderMessages(evaluatorEdgeCase.messages(),
evaluatorEdgeCase.variables(), trace);

assertThat(renderedMessages).hasSize(2);

var userMessage = renderedMessages.get(0);
assertThat(userMessage.getClass()).isEqualTo(UserMessage.class);
assertThat(((UserMessage) userMessage).singleText()).contains(summaryStr);
assertThat(((UserMessage) userMessage).singleText()).contains(outputStr);

var systemMessage = renderedMessages.get(1);
assertThat(systemMessage.getClass()).isEqualTo(SystemMessage.class);
}
}

0 comments on commit de08570

Please sign in to comment.