Skip to content

Commit

Permalink
[OPIK-442] Add filtering for spans and traces based on cost/model/pro…
Browse files Browse the repository at this point in the history
…vider (#723)
  • Loading branch information
BorisTkachenko authored Nov 26, 2024
1 parent 4f28c95 commit 3e8d7a9
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ public interface Field {
String INPUT_QUERY_PARAM = "input";
String OUTPUT_QUERY_PARAM = "output";
String METADATA_QUERY_PARAM = "metadata";
String MODEL_QUERY_PARAM = "model";
String PROVIDER_QUERY_PARAM = "provider";
String TOTAL_ESTIMATED_COST_QUERY_PARAM = "total_estimated_cost";
String TAGS_QUERY_PARAM = "tags";
String USAGE_COMPLETION_TOKENS_QUERY_PARAM = "usage.completion_tokens";
String USAGE_PROMPT_TOKENS_QUERY_PARAM = "usage.prompt_tokens";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ public enum SpanField implements Field {
INPUT(INPUT_QUERY_PARAM, FieldType.STRING),
OUTPUT(OUTPUT_QUERY_PARAM, FieldType.STRING),
METADATA(METADATA_QUERY_PARAM, FieldType.DICTIONARY),
MODEL(MODEL_QUERY_PARAM, FieldType.STRING),
PROVIDER(PROVIDER_QUERY_PARAM, FieldType.STRING),
TOTAL_ESTIMATED_COST(TOTAL_ESTIMATED_COST_QUERY_PARAM, FieldType.NUMBER),
TAGS(TAGS_QUERY_PARAM, FieldType.LIST),
USAGE_COMPLETION_TOKENS(USAGE_COMPLETION_TOKENS_QUERY_PARAM, FieldType.NUMBER),
USAGE_PROMPT_TOKENS(USAGE_PROMPT_TOKENS_QUERY_PARAM, FieldType.NUMBER),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public enum TraceField implements Field {
INPUT(INPUT_QUERY_PARAM, FieldType.STRING),
OUTPUT(OUTPUT_QUERY_PARAM, FieldType.STRING),
METADATA(METADATA_QUERY_PARAM, FieldType.DICTIONARY),
TOTAL_ESTIMATED_COST(TOTAL_ESTIMATED_COST_QUERY_PARAM, FieldType.NUMBER),
TAGS(TAGS_QUERY_PARAM, FieldType.LIST),
USAGE_COMPLETION_TOKENS(USAGE_COMPLETION_TOKENS_QUERY_PARAM, FieldType.NUMBER),
USAGE_PROMPT_TOKENS(USAGE_PROMPT_TOKENS_QUERY_PARAM, FieldType.NUMBER),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@ WHERE created_at BETWEEN toStartOfDay(yesterday()) AND toStartOfDay(today())
FROM (
SELECT
t.id,
sumMap(s.usage) as usage
sumMap(s.usage) as usage,
sum(s.total_estimated_cost) as total_estimated_cost
FROM (
SELECT
id
Expand Down Expand Up @@ -408,7 +409,8 @@ AND id in (
LEFT JOIN (
SELECT
trace_id,
usage
usage,
total_estimated_cost
FROM spans
WHERE workspace_id = :workspace_id
AND project_id = :project_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ public class FilterQueryBuilder {
private static final String INPUT_ANALYTICS_DB = "input";
private static final String OUTPUT_ANALYTICS_DB = "output";
private static final String METADATA_ANALYTICS_DB = "metadata";
private static final String MODEL_ANALYTICS_DB = "model";
private static final String PROVIDER_ANALYTICS_DB = "provider";
private static final String TOTAL_ESTIMATED_COST_ANALYTICS_DB = "total_estimated_cost";
private static final String TAGS_ANALYTICS_DB = "tags";
private static final String USAGE_COMPLETION_TOKENS_ANALYTICS_DB = "usage['completion_tokens']";
private static final String USAGE_PROMPT_TOKENS_ANALYTICS_DB = "usage['prompt_tokens']";
Expand Down Expand Up @@ -95,6 +98,7 @@ public class FilterQueryBuilder {
.put(TraceField.INPUT, INPUT_ANALYTICS_DB)
.put(TraceField.OUTPUT, OUTPUT_ANALYTICS_DB)
.put(TraceField.METADATA, METADATA_ANALYTICS_DB)
.put(TraceField.TOTAL_ESTIMATED_COST, TOTAL_ESTIMATED_COST_ANALYTICS_DB)
.put(TraceField.TAGS, TAGS_ANALYTICS_DB)
.put(TraceField.USAGE_COMPLETION_TOKENS, USAGE_COMPLETION_TOKENS_ANALYTICS_DB)
.put(TraceField.USAGE_PROMPT_TOKENS, USAGE_PROMPT_TOKENS_ANALYTICS_DB)
Expand All @@ -111,6 +115,9 @@ public class FilterQueryBuilder {
.put(SpanField.INPUT, INPUT_ANALYTICS_DB)
.put(SpanField.OUTPUT, OUTPUT_ANALYTICS_DB)
.put(SpanField.METADATA, METADATA_ANALYTICS_DB)
.put(SpanField.MODEL, MODEL_ANALYTICS_DB)
.put(SpanField.PROVIDER, PROVIDER_ANALYTICS_DB)
.put(SpanField.TOTAL_ESTIMATED_COST, TOTAL_ESTIMATED_COST_ANALYTICS_DB)
.put(SpanField.TAGS, TAGS_ANALYTICS_DB)
.put(SpanField.USAGE_COMPLETION_TOKENS, USAGE_COMPLETION_TOKENS_ANALYTICS_DB)
.put(SpanField.USAGE_PROMPT_TOKENS, USAGE_PROMPT_TOKENS_ANALYTICS_DB)
Expand Down Expand Up @@ -139,6 +146,7 @@ public class FilterQueryBuilder {
.add(TraceField.USAGE_COMPLETION_TOKENS)
.add(TraceField.USAGE_PROMPT_TOKENS)
.add(TraceField.USAGE_TOTAL_TOKENS)
.add(TraceField.TOTAL_ESTIMATED_COST)
.build()),
FilterStrategy.SPAN, EnumSet.copyOf(ImmutableSet.<SpanField>builder()
.add(SpanField.ID)
Expand All @@ -148,6 +156,9 @@ public class FilterQueryBuilder {
.add(SpanField.INPUT)
.add(SpanField.OUTPUT)
.add(SpanField.METADATA)
.add(SpanField.MODEL)
.add(SpanField.PROVIDER)
.add(SpanField.TOTAL_ESTIMATED_COST)
.add(SpanField.TAGS)
.add(SpanField.USAGE_COMPLETION_TOKENS)
.add(SpanField.USAGE_PROMPT_TOKENS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,55 @@ void getByProjectName__whenFilterIdAndNameEqual__thenReturnSpansFiltered() {
getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans, unexpectedSpans, apiKey);
}

@ParameterizedTest
@MethodSource
void getByProjectName__whenFilterByCorrespondingField__thenReturnSpansFiltered(SpanField filterField, Operator filterOperator, String filterValue) {
String workspaceName = UUID.randomUUID().toString();
String workspaceId = UUID.randomUUID().toString();
String apiKey = UUID.randomUUID().toString();
String model = "gpt-3.5-turbo-1106";

mockTargetWorkspace(apiKey, workspaceName, workspaceId);

var projectName = generator.generate().toString();
var unexpectedSpans = PodamFactoryUtils.manufacturePojoList(podamFactory, Span.class)
.stream()
.map(span -> span.toBuilder()
.projectId(null)
.projectName(projectName)
.feedbackScores(null)
.build())
.collect(Collectors.toCollection(ArrayList::new));
unexpectedSpans.forEach(unexpectedSpan -> createAndAssert(unexpectedSpan, apiKey, workspaceName));

var expectedSpans = List.of(podamFactory.manufacturePojo(Span.class).toBuilder()
.projectId(null)
.projectName(projectName)
.model(model)
.usage(Map.of("completion_tokens", Math.abs(podamFactory.manufacturePojo(Integer.class)),
"prompt_tokens", Math.abs(podamFactory.manufacturePojo(Integer.class))))
.feedbackScores(null)
.build());
expectedSpans.forEach(
expectedSpan -> createAndAssert(expectedSpan, apiKey, workspaceName));

// Check that it's filtered by cost
var filters = List.of(
SpanFilter.builder()
.field(filterField)
.operator(filterOperator)
.value(filterField == SpanField.PROVIDER ? expectedSpans.getFirst().provider() : filterValue)
.build());
getAndAssertPage(workspaceName, projectName, filters, expectedSpans, expectedSpans, unexpectedSpans, apiKey);
}

static Stream<Arguments> getByProjectName__whenFilterByCorrespondingField__thenReturnSpansFiltered() {
return Stream.of(
Arguments.of(SpanField.TOTAL_ESTIMATED_COST, Operator.GREATER_THAN, "0"),
Arguments.of(SpanField.MODEL, Operator.EQUAL, "gpt-3.5-turbo-1106"),
Arguments.of(SpanField.PROVIDER, Operator.EQUAL, null));
}

@Test
void getByProjectName__whenFilterNameEqual__thenReturnSpansFiltered() {
String workspaceName = UUID.randomUUID().toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.testcontainers.clickhouse.ClickHouseContainer;
import org.testcontainers.containers.MySQLContainer;
import org.testcontainers.lifecycle.Startables;
Expand Down Expand Up @@ -1522,6 +1523,51 @@ void getByProjectName__whenFilterOutputEqual__thenReturnTracesFiltered() {
getAndAssertPage(workspaceName, projectName, filters, traces, expectedTraces, unexpectedTraces, apiKey);
}

@Test
void getByProjectName__whenFilterTotalEstimatedCostGreaterThen__thenReturnTracesFiltered() {
var workspaceName = RandomStringUtils.randomAlphanumeric(10);
var workspaceId = UUID.randomUUID().toString();
var apiKey = UUID.randomUUID().toString();

mockTargetWorkspace(apiKey, workspaceName, workspaceId);

var projectName = RandomStringUtils.randomAlphanumeric(10);
var traces = PodamFactoryUtils.manufacturePojoList(factory, Trace.class)
.stream()
.map(trace -> trace.toBuilder()
.projectId(null)
.projectName(projectName)
.usage(null)
.feedbackScores(null)
.build())
.collect(Collectors.toCollection(ArrayList::new));
traces.forEach(trace -> create(trace, apiKey, workspaceName));
var unexpectedTraces = traces.subList(1, traces.size());

var spans = PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream()
.map(spanInStream -> spanInStream.toBuilder()
.projectName(projectName)
.traceId(traces.getFirst().id())
.usage(Map.of("completion_tokens", Math.abs(factory.manufacturePojo(Integer.class)),
"prompt_tokens", Math.abs(factory.manufacturePojo(Integer.class))))
.model("gpt-3.5-turbo-1106")
.build())
.collect(Collectors.toList());

batchCreateSpansAndAssert(spans, apiKey, workspaceName);

var expectedTrace = traces.getFirst().toBuilder()
.usage(aggregateSpansUsage(spans))
.build();

var filters = List.of(TraceFilter.builder()
.field(TraceField.TOTAL_ESTIMATED_COST)
.operator(Operator.GREATER_THAN)
.value("0")
.build());
getAndAssertPage(workspaceName, projectName, filters, traces, List.of(expectedTrace), unexpectedTraces, apiKey);
}

@Test
void getByProjectName__whenFilterMetadataEqualString__thenReturnTracesFiltered() {
var workspaceName = RandomStringUtils.randomAlphanumeric(10);
Expand Down Expand Up @@ -3245,7 +3291,7 @@ void getTraceWithUsage() {
}

@ParameterizedTest
@MethodSource
@ValueSource(strings = {"gpt-3.5-turbo-1106", "unknown-model"})
void getTraceWithCost(String model) {
var projectName = RandomStringUtils.randomAlphanumeric(10);
var trace = factory.manufacturePojo(Trace.class)
Expand All @@ -3266,30 +3312,20 @@ void getTraceWithCost(String model) {
.build())
.collect(Collectors.toList());

var usage = spans.stream()
.flatMap(span -> span.usage().entrySet().stream())
.map(entry -> new AbstractMap.SimpleEntry<>(entry.getKey(), Long.valueOf(entry.getValue())))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, Long::sum));

BigDecimal traceExpectedCost = spans.stream()
.map(span -> ModelPrice.fromString(span.model()).calculateCost(span.usage()))
.reduce(BigDecimal.ZERO, BigDecimal::add);
var usage = aggregateSpansUsage(spans);
BigDecimal traceExpectedCost = aggregateSpansCost(spans);

batchCreateSpansAndAssert(spans, API_KEY, TEST_WORKSPACE);

var projectId = getProjectId(projectName, TEST_WORKSPACE, API_KEY);
trace = trace.toBuilder().id(id).usage(usage).build();
Trace createdTrace = getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE);
assertThat(traceExpectedCost.compareTo(BigDecimal.ZERO) == 0
? createdTrace.totalEstimatedCost() == null
: traceExpectedCost.compareTo(createdTrace.totalEstimatedCost()) == 0)
.isEqualTo(true);
}

static Stream<Arguments> getTraceWithCost() {
return Stream.of(
Arguments.of("gpt-3.5-turbo-1106"),
Arguments.of("unknown-model"));
assertThat(createdTrace.totalEstimatedCost())
.usingRecursiveComparison(RecursiveComparisonConfiguration.builder()
.withComparatorForType(BigDecimal::compareTo, BigDecimal.class)
.build())
.isEqualTo(traceExpectedCost.compareTo(BigDecimal.ZERO) == 0 ? null : traceExpectedCost);
}

@Test
Expand Down Expand Up @@ -7560,4 +7596,17 @@ private int setupTracesForWorkspace(String workspaceName, String workspaceId, St

return traces.size();
}

private Map<String, Long> aggregateSpansUsage(List<Span> spans) {
return spans.stream()
.flatMap(span -> span.usage().entrySet().stream())
.map(entry -> new AbstractMap.SimpleEntry<>(entry.getKey(), Long.valueOf(entry.getValue())))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, Long::sum));
}

private BigDecimal aggregateSpansCost(List<Span> spans) {
return spans.stream()
.map(span -> ModelPrice.fromString(span.model()).calculateCost(span.usage()))
.reduce(BigDecimal.ZERO, BigDecimal::add);
}
}

0 comments on commit 3e8d7a9

Please sign in to comment.