[OPIK-524] Add endpoint to get experiment output columns (#840)
thiagohora authored Dec 10, 2024
1 parent 222c25c commit a4c741d
Showing 12 changed files with 466 additions and 47 deletions.
37 changes: 37 additions & 0 deletions apps/opik-backend/src/main/java/com/comet/opik/api/Column.java
@@ -0,0 +1,37 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonValue;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Builder;
import lombok.RequiredArgsConstructor;

import java.util.Set;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record Column(String name, Set<ColumnType> types, String filterFieldPrefix) {

@JsonProperty("filterField")
@Schema(accessMode = Schema.AccessMode.READ_ONLY, description = "The field to use for filtering", name = "filterField")
public String filterField() {
return "%s.%s".formatted(filterFieldPrefix, name);
}

@RequiredArgsConstructor
public enum ColumnType {
STRING("string"),
NUMBER("number"),
OBJECT("object"),
BOOLEAN("boolean"),
ARRAY("array"),
NULL("null");

@JsonValue
private final String value;
}
}
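
The filter field is derived rather than stored: filterField() joins the prefix and the column name with a dot. A minimal sketch of what that yields, using illustrative values ("answer" and "output" are not taken from the commit):

import com.comet.opik.api.Column;

import java.util.Set;

class ColumnFilterFieldDemo {
    public static void main(String[] args) {
        // Illustrative column; types holds every JSON type observed for the key.
        Column column = Column.builder()
                .name("answer")
                .types(Set.of(Column.ColumnType.STRING))
                .filterFieldPrefix("output")
                .build();

        // filterField() formats "<prefix>.<name>", so this prints "output.answer".
        System.out.println(column.filterField());
    }
}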
apps/opik-backend/src/main/java/com/comet/opik/api/DatasetItem.java
@@ -3,15 +3,13 @@
import com.comet.opik.api.validate.DatasetItemInputValidation;
import com.comet.opik.api.validate.SourceValidation;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonValue;
import com.fasterxml.jackson.annotation.JsonView;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Builder;
import lombok.RequiredArgsConstructor;

import java.time.Instant;
import java.util.List;
@@ -56,21 +54,6 @@ public record DatasetItemPage(
@JsonView({DatasetItem.View.Public.class}) long total,
@JsonView({DatasetItem.View.Public.class}) Set<Column> columns) implements Page<DatasetItem> {

public record Column(String name, Set<ColumnType> types) {

@RequiredArgsConstructor
public enum ColumnType {
STRING("string"),
NUMBER("number"),
OBJECT("object"),
BOOLEAN("boolean"),
ARRAY("array"),
NULL("null");

@JsonValue
private final String value;
}
}
}

public static class View {
apps/opik-backend/src/main/java/com/comet/opik/api/PageColumns.java
@@ -0,0 +1,17 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import lombok.Builder;

import java.util.List;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record PageColumns(List<Column> columns) {
public static PageColumns empty() {
return new PageColumns(List.of());
}
}
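
PageColumns is the wire envelope returned by the new endpoint. A rough serialization sketch, assuming plain Jackson with the annotations above (the column values are illustrative, and exact property order may differ):

import com.comet.opik.api.Column;
import com.comet.opik.api.PageColumns;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.util.List;
import java.util.Set;

class PageColumnsDemo {
    public static void main(String[] args) throws Exception {
        PageColumns page = PageColumns.builder()
                .columns(List.of(Column.builder()
                        .name("answer")
                        .types(Set.of(Column.ColumnType.STRING))
                        .filterFieldPrefix("output")
                        .build()))
                .build();

        // Prints something like:
        // {"columns":[{"name":"answer","types":["string"],"filter_field_prefix":"output","filterField":"output.answer"}]}
        System.out.println(new ObjectMapper().writeValueAsString(page));
    }
}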
@@ -7,15 +7,12 @@
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.annotation.Nullable;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.Pattern;
import lombok.Builder;

import java.time.Instant;
import java.util.List;
import java.util.UUID;

import static com.comet.opik.utils.ValidationUtils.NULL_OR_NOT_BLANK;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
// This annotation is used to specify the strategy to be used for naming of properties for the annotated type. Required so that OpenAPI schema generation uses snake_case
@@ -16,7 +16,8 @@ public Response toResponse(InvalidFormatException exception) {
log.info("Deserialization exception: {}", exception.getMessage());
int endIndex = errorMessage.indexOf(": Failed to deserialize");
return Response.status(Response.Status.BAD_REQUEST)
.entity(new ErrorMessage(List.of(endIndex == -1 ? "Unable to process JSON" : errorMessage.substring(0, endIndex))))
.entity(new ErrorMessage(
List.of(endIndex == -1 ? "Unable to process JSON" : errorMessage.substring(0, endIndex))))
.build();
}
}
@@ -12,6 +12,7 @@
import com.comet.opik.api.DatasetItemsDelete;
import com.comet.opik.api.DatasetUpdate;
import com.comet.opik.api.ExperimentItem;
import com.comet.opik.api.PageColumns;
import com.comet.opik.api.filter.ExperimentsComparisonFilter;
import com.comet.opik.api.filter.FiltersFactory;
import com.comet.opik.api.resources.v1.priv.validate.ExperimentParamsValidator;
@@ -63,7 +64,9 @@

import java.net.URI;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Predicate;

import static com.comet.opik.api.Dataset.DatasetPage;
import static com.comet.opik.utils.AsyncUtils.setRequestContext;
@@ -399,4 +402,33 @@ public Response findDatasetItemsWithExperimentItems(
return Response.ok(datasetItemPage).build();
}

@GET
@Path("/{id}/items/experiments/items/output/columns")
@Operation(operationId = "getDatasetItemsOutputColumns", summary = "Get dataset items output columns", description = "Get dataset items output columns", responses = {
@ApiResponse(responseCode = "200", description = "Dataset item output columns", content = @Content(schema = @Schema(implementation = PageColumns.class)))
})
public Response getDatasetItemsOutputColumns(
@PathParam("id") @NotNull UUID datasetId,
@QueryParam("experiment_ids") String experimentIdsQueryParam) {

var experimentIds = Optional.ofNullable(experimentIdsQueryParam)
.filter(Predicate.not(String::isEmpty))
.map(ExperimentParamsValidator::getExperimentIds)
.orElse(null);

String workspaceId = requestContext.get().getWorkspaceId();

log.info("Finding traces output columns by datasetId '{}', experimentIds '{}', on workspaceId '{}'",
datasetId, experimentIds, workspaceId);

PageColumns columns = itemService.getOutputColumns(datasetId, experimentIds)
.contextWrite(ctx -> setRequestContext(ctx, requestContext))
.block();

log.info("Found traces output columns by datasetId '{}', experimentIds '{}', on workspaceId '{}'",
datasetId, experimentIds, workspaceId);

return Response.ok(columns).build();
}

}
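
For reference, a hypothetical JAX-RS client call against the new route. The host, the /v1/private/datasets base path, and the JSON-array format of experiment_ids are assumptions for illustration; none of them are visible in this diff:

import com.comet.opik.api.PageColumns;

import jakarta.ws.rs.client.Client;
import jakarta.ws.rs.client.ClientBuilder;

class OutputColumnsClientDemo {
    public static void main(String[] args) {
        Client client = ClientBuilder.newClient();
        PageColumns columns = client
                .target("http://localhost:8080/v1/private/datasets") // assumed base path
                .path("{id}/items/experiments/items/output/columns")
                .resolveTemplate("id", "00000000-0000-0000-0000-000000000000") // placeholder dataset UUID
                // Assumed format: a JSON array of experiment UUIDs; omit the
                // query param entirely to scan all experiments of the dataset.
                .queryParam("experiment_ids", "[\"00000000-0000-0000-0000-000000000000\"]")
                .request()
                .get(PageColumns.class);
        System.out.println(columns.columns());
        client.close();
    }
}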
apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemDAO.java
@@ -1,6 +1,7 @@
package com.comet.opik.domain;

import com.clickhouse.client.ClickHouseException;
import com.comet.opik.api.Column;
import com.comet.opik.api.DatasetItem;
import com.comet.opik.api.DatasetItemSearchCriteria;
import com.comet.opik.domain.filter.FilterQueryBuilder;
@@ -17,6 +18,7 @@
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.stringtemplate.v4.ST;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
Expand All @@ -29,7 +31,6 @@
import java.util.UUID;

import static com.comet.opik.api.DatasetItem.DatasetItemPage;
import static com.comet.opik.api.DatasetItem.DatasetItemPage.Column;
import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToFlux;
import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToMono;
import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.Segment;
@@ -57,6 +58,8 @@ public interface DatasetItemDAO {
Mono<List<WorkspaceAndResourceId>> getDatasetItemWorkspace(Set<UUID> datasetItemIds);

Flux<DatasetItemSummary> findDatasetItemSummaryByDatasetIds(Set<UUID> datasetIds);

Mono<List<Column>> getOutputColumns(UUID datasetId, Set<UUID> experimentIds);
}

@Singleton
@@ -471,6 +474,62 @@ AND id IN (SELECT trace_id FROM experiment_items_final)
;
""";

private static final String SELECT_DATASET_EXPERIMENT_ITEMS_COLUMNS_BY_DATASET_ID = """
SELECT
arrayFold(
(acc, x) -> mapFromArrays(
arrayMap(key -> key, arrayDistinct(arrayConcat(mapKeys(acc), mapKeys(x)))),
arrayMap(
key -> arrayDistinct(arrayConcat(acc[key], x[key])),
arrayDistinct(arrayConcat(mapKeys(acc), mapKeys(x)))
)
),
arrayDistinct(
arrayFlatten(
groupArray(
arrayMap(
key -> map(key, [toString(JSONType(JSONExtractRaw(output, key)))]),
JSONExtractKeys(output)
)
)
)
),
CAST(map(), 'Map(String, Array(String))')
) AS columns
FROM (
SELECT
id
FROM dataset_items
WHERE workspace_id = :workspace_id
AND dataset_id = :dataset_id
ORDER BY id DESC, last_updated_at DESC
LIMIT 1 BY id
) as di
INNER JOIN (
SELECT
ei.id,
ei.trace_id,
ei.dataset_item_id
FROM experiment_items ei
WHERE workspace_id = :workspace_id
<if(experiment_ids)>
AND experiment_id in :experiment_ids
<endif>
ORDER BY id DESC, last_updated_at DESC
LIMIT 1 BY id
) as ei ON ei.dataset_item_id = di.id
INNER JOIN (
SELECT
id,
output
FROM traces
WHERE workspace_id = :workspace_id
ORDER BY id DESC, last_updated_at DESC
LIMIT 1 BY id
) as t ON t.id = ei.trace_id
;
""";

private final @NonNull TransactionTemplateAsync asyncTemplate;
private final @NonNull FilterQueryBuilder filterQueryBuilder;

@@ -634,6 +693,29 @@ public Flux<DatasetItemSummary> findDatasetItemSummaryByDatasetIds(Set<UUID> dat
});
}

@Override
public Mono<List<Column>> getOutputColumns(@NonNull UUID datasetId, Set<UUID> experimentIds) {
return asyncTemplate.nonTransaction(connection -> {

ST template = new ST(SELECT_DATASET_EXPERIMENT_ITEMS_COLUMNS_BY_DATASET_ID);

if (CollectionUtils.isNotEmpty(experimentIds)) {
template.add("experiment_ids", experimentIds);
}

var statement = connection.createStatement(template.render())
.bind("dataset_id", datasetId);

if (CollectionUtils.isNotEmpty(experimentIds)) {
statement.bind("experiment_ids", experimentIds.toArray(UUID[]::new));
}

return makeMonoContextAware(bindWorkspaceIdToMono(statement))
.flatMap(result -> DatasetItemResultMapper.mapColumns(result, "output"))
.map(List::copyOf);
});
}

@Override
@WithSpan
public Mono<Long> delete(@NonNull List<UUID> ids) {
@@ -673,7 +755,7 @@ public Mono<DatasetItemPage> getItems(@NonNull UUID datasetId, int page, int siz
.bind("workspace_id", workspaceId)
.execute())
.doFinally(signalType -> endSegment(segmentCount))
.flatMap(DatasetItemResultMapper::mapCountAndColumns)
.flatMap(results -> DatasetItemResultMapper.mapCountAndColumns(results, "data"))
.reduce(DatasetItemResultMapper::groupResults)
.flatMap(result -> {

@@ -802,7 +884,7 @@ private Mono<Set<Column>> mapColumnsField(DatasetItemSearchCriteria datasetItemS
bindWorkspaceIdToMono(
connection.createStatement(SELECT_DATASET_ITEMS_COLUMNS_BY_DATASET_ID)
.bind("datasetId", datasetItemSearchCriteria.datasetId())))
.flatMap(DatasetItemResultMapper::mapColumns))
.flatMap(result -> DatasetItemResultMapper.mapColumns(result, "data")))
.doFinally(signalType -> endSegment(segment));
}
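
The arrayFold in SELECT_DATASET_EXPERIMENT_ITEMS_COLUMNS_BY_DATASET_ID above deserves a gloss: each joined trace row contributes a map from top-level output key to the JSON types seen for that key, and the fold unions the type sets per key across all rows. The same merge in plain Java, as a sketch (the type names are illustrative JSONType outputs, not an exhaustive list):

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

class ColumnMergeSketch {
    // Per-row maps of output key -> JSON types; the fold unions type sets per key.
    static Map<String, Set<String>> mergeColumns(List<Map<String, Set<String>>> rows) {
        Map<String, Set<String>> acc = new HashMap<>();
        for (Map<String, Set<String>> row : rows) {
            row.forEach((key, types) ->
                    acc.computeIfAbsent(key, k -> new HashSet<>()).addAll(types));
        }
        return acc;
    }

    public static void main(String[] args) {
        // Two traces whose output JSON disagrees on the type of "score".
        var merged = mergeColumns(List.of(
                Map.of("answer", Set.of("String"), "score", Set.of("Int64")),
                Map.of("score", Set.of("Double"))));
        System.out.println(merged); // e.g. {answer=[String], score=[Int64, Double]}
    }
}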

apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemResultMapper.java
@@ -1,5 +1,6 @@
package com.comet.opik.domain;

import com.comet.opik.api.Column;
import com.comet.opik.api.DatasetItem;
import com.comet.opik.api.DatasetItemSource;
import com.comet.opik.api.ExperimentItem;
@@ -27,8 +28,7 @@
import java.util.UUID;
import java.util.stream.Collectors;

import static com.comet.opik.api.DatasetItem.DatasetItemPage.Column;
import static com.comet.opik.api.DatasetItem.DatasetItemPage.Column.ColumnType;
import static com.comet.opik.api.Column.ColumnType;
import static com.comet.opik.utils.ValidationUtils.CLICKHOUSE_FIXED_STRING_UUID_FIELD_NULL_VALUE;
import static java.util.function.Predicate.not;
import static java.util.stream.Collectors.toMap;
@@ -97,12 +97,14 @@ static Map.Entry<Long, Set<Column>> groupResults(Map.Entry<Long, Set<Column>> re
return Map.entry(result1.getKey() + result2.getKey(), Sets.union(result1.getValue(), result2.getValue()));
}

private static Set<Column> mapColumnsField(Map<String, String[]> row) {
private static Set<Column> mapColumnsField(Map<String, String[]> row, String filterField) {
return Optional.ofNullable(row).orElse(Map.of())
.entrySet()
.stream()
.map(columnArray -> new Column(columnArray.getKey(),
Set.of(mapColumnType(columnArray.getValue()))))
.map(columnArray -> Column.builder().name(columnArray.getKey())
.types(Set.of(mapColumnType(columnArray.getValue())))
.filterFieldPrefix(filterField)
.build())
.collect(Collectors.toSet());
}

@@ -210,11 +212,11 @@ static String getOrDefault(UUID value) {
return Optional.ofNullable(value).map(UUID::toString).orElse("");
}

static Publisher<Map.Entry<Long, Set<Column>>> mapCountAndColumns(Result result) {
static Publisher<Map.Entry<Long, Set<Column>>> mapCountAndColumns(Result result, String filterFieldPrefix) {
return result.map((row, rowMetadata) -> {
Long count = extractCountFromResult(row);
Map<String, String[]> columnsMap = extractColumnsField(row);
return Map.entry(count, mapColumnsField(columnsMap));
return Map.entry(count, mapColumnsField(columnsMap, filterFieldPrefix));
});
}

@@ -223,17 +225,17 @@ private static Long extractCountFromResult(Row row) {
}

private static Map<String, String[]> extractColumnsField(Row row) {
return (Map<String, String[]>) row.get("columns", Map.class);
return row.get("columns", Map.class);
}

static Publisher<Long> mapCount(Result result) {
return result.map((row, rowMetadata) -> extractCountFromResult(row));
}

static Mono<Set<Column>> mapColumns(Result result) {
static Mono<Set<Column>> mapColumns(Result result, String filterFieldPrefix) {
return Mono.from(result.map((row, rowMetadata) -> {
Map<String, String[]> columnsMap = extractColumnsField(row);
return DatasetItemResultMapper.mapColumnsField(columnsMap);
return DatasetItemResultMapper.mapColumnsField(columnsMap, filterFieldPrefix);
}));
}
}