Skip to content

Commit

Permalink
[NA] Bulk insert improvements (#179)
Browse files Browse the repository at this point in the history
* [NA] Bulk insert improvements

* Add fix for dataset items and experiment item insertion
  • Loading branch information
thiagohora authored Sep 4, 2024
1 parent 126ef51 commit 43c044b
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import com.comet.opik.api.ExperimentItem;
import com.comet.opik.api.FeedbackScore;
import com.comet.opik.api.ScoreSource;
import com.comet.opik.infrastructure.BulkConfig;
import com.comet.opik.infrastructure.BulkOperationsConfig;
import com.comet.opik.infrastructure.db.TransactionTemplate;
import com.comet.opik.utils.JsonUtils;
import com.fasterxml.jackson.databind.JsonNode;
Expand Down Expand Up @@ -127,7 +127,7 @@ INSERT INTO dataset_items (
:expectedOutput<item.index> AS expected_output,
:metadata<item.index> AS metadata,
now64(9) AS created_at,
:workspace_id<item.index> AS workspace_id,
:workspace_id AS workspace_id,
:createdBy<item.index> AS created_by,
:lastUpdatedBy<item.index> AS last_updated_by
<if(item.hasNext)>
Expand All @@ -139,6 +139,14 @@ LEFT JOIN (
SELECT
*
FROM dataset_items
WHERE id IN (
<items:{item |
:id<item.index>
<if(item.hasNext)>
,
<endif>
}>
)
ORDER BY last_updated_at DESC
LIMIT 1 BY id
) AS old
Expand Down Expand Up @@ -379,7 +387,7 @@ LEFT JOIN (
""";

private final @NonNull TransactionTemplate asyncTemplate;
private final @NonNull @Config("bulkOperations") BulkConfig bulkConfig;
private final @NonNull @Config("bulkOperations") BulkOperationsConfig bulkConfig;

@Override
public Mono<Long> save(@NonNull UUID datasetId, @NonNull List<DatasetItem> items) {
Expand All @@ -388,10 +396,10 @@ public Mono<Long> save(@NonNull UUID datasetId, @NonNull List<DatasetItem> items
return Mono.empty();
}

return inset(datasetId, items);
return insert(datasetId, items);
}

private Mono<Long> inset(UUID datasetId, List<DatasetItem> items) {
private Mono<Long> insert(UUID datasetId, List<DatasetItem> items) {
List<List<DatasetItem>> batches = Lists.partition(items, bulkConfig.getSize());

return Flux.fromIterable(batches)
Expand All @@ -401,7 +409,7 @@ private Mono<Long> inset(UUID datasetId, List<DatasetItem> items) {

private Mono<Long> mapAndInsert(UUID datasetId, List<DatasetItem> items, Connection connection) {

List<QueryItem> queryItems = getQueryItemPlaceHolder(items);
List<QueryItem> queryItems = getQueryItemPlaceHolder(items.size());

var template = new ST(INSERT_DATASET_ITEM)
.add("items", queryItems);
Expand All @@ -412,6 +420,8 @@ private Mono<Long> mapAndInsert(UUID datasetId, List<DatasetItem> items, Connect

return makeMonoContextAware((userName, workspaceName, workspaceId) -> {

statement.bind("workspace_id", workspaceId);

int i = 0;
for (DatasetItem item : items) {
statement.bind("id" + i, item.id());
Expand All @@ -422,7 +432,6 @@ private Mono<Long> mapAndInsert(UUID datasetId, List<DatasetItem> items, Connect
statement.bind("input" + i, getOrDefault(item.input()));
statement.bind("expectedOutput" + i, getOrDefault(item.expectedOutput()));
statement.bind("metadata" + i, getOrDefault(item.metadata()));
statement.bind("workspace_id" + i, workspaceId);
statement.bind("createdBy" + i,userName);
statement.bind("lastUpdatedBy" + i, userName);
i++;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.comet.opik.domain;

import com.comet.opik.api.ExperimentItem;
import com.comet.opik.infrastructure.BulkConfig;
import com.comet.opik.infrastructure.BulkOperationsConfig;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import io.r2dbc.spi.Connection;
Expand Down Expand Up @@ -79,7 +79,7 @@ INSERT INTO experiment_items (
:experiment_id<item.index> AS experiment_id,
:dataset_item_id<item.index> AS dataset_item_id,
:trace_id<item.index> AS trace_id,
:workspace_id<item.index> AS workspace_id,
:workspace_id AS workspace_id,
:created_by<item.index> AS created_by,
:last_updated_by<item.index> AS last_updated_by
<if(item.hasNext)>
Expand All @@ -91,6 +91,14 @@ LEFT JOIN (
SELECT
id, workspace_id
FROM experiment_items
WHERE id IN (
<items:{item |
:id<item.index>
<if(item.hasNext)>
,
<endif>
}>
)
ORDER BY last_updated_at DESC
LIMIT 1 BY id
) AS old
Expand Down Expand Up @@ -131,7 +139,7 @@ LEFT JOIN (
""";

private final @NonNull ConnectionFactory connectionFactory;
private final @NonNull @Config("bulkOperations") BulkConfig bulkConfig;
private final @NonNull @Config("bulkOperations") BulkOperationsConfig bulkConfig;

public Flux<ExperimentSummary> findExperimentSummaryByDatasetIds(Collection<UUID> datasetIds) {

Expand Down Expand Up @@ -169,7 +177,7 @@ public Mono<Long> insert(@NonNull Set<ExperimentItem> experimentItems) {

private Mono<Long> insert(Collection<ExperimentItem> experimentItems, Connection connection) {

List<QueryItem> queryItems = getQueryItemPlaceHolder(experimentItems);
List<QueryItem> queryItems = getQueryItemPlaceHolder(experimentItems.size());

var template = new ST(INSERT)
.add("items", queryItems);
Expand All @@ -180,13 +188,14 @@ private Mono<Long> insert(Collection<ExperimentItem> experimentItems, Connection

return makeMonoContextAware((userName, workspaceName, workspaceId) -> {

statement.bind("workspace_id", workspaceId);

int index = 0;
for (ExperimentItem item : experimentItems) {
statement.bind("id" + index, item.id());
statement.bind("experiment_id" + index, item.experimentId());
statement.bind("dataset_item_id" + index, item.datasetItemId());
statement.bind("trace_id" + index, item.traceId());
statement.bind("workspace_id" + index, workspaceId);
statement.bind("created_by" + index, userName);
statement.bind("last_updated_by" + index, userName);
index++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@

import com.fasterxml.jackson.annotation.JsonProperty;
import jakarta.validation.Valid;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotNull;
import lombok.Data;

@Data
public class BulkConfig {
public class BulkOperationsConfig {

@Valid
@JsonProperty
private Integer size;

@NotNull
private int size;
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@ public class OpikConfiguration extends Configuration {

@Valid
@NotNull @JsonProperty
private BulkConfig bulkOperations = new BulkConfig();
private BulkOperationsConfig bulkOperations = new BulkOperationsConfig();
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package com.comet.opik.infrastructure.db;

import com.comet.opik.infrastructure.BulkOperationsConfig;
import com.comet.opik.infrastructure.DatabaseAnalyticsFactory;
import com.comet.opik.infrastructure.OpikConfiguration;
import com.google.inject.Provides;
import io.r2dbc.spi.ConnectionFactory;
import jakarta.inject.Named;
import jakarta.inject.Singleton;
import ru.vyarus.dropwizard.guice.module.support.DropwizardAwareModule;
import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;

public class DatabaseAnalyticsModule extends DropwizardAwareModule<OpikConfiguration> {

Expand Down Expand Up @@ -38,4 +40,10 @@ public TransactionTemplate getTransactionTemplate(ConnectionFactory connectionFa
return new TransactionTemplateImpl(connectionFactory);
}

/**
 * Exposes the {@code bulkOperations} section of the application configuration
 * as a directly injectable {@link BulkOperationsConfig} singleton, so consumers
 * can depend on the type alone instead of repeating the {@code @Config} qualifier.
 *
 * @param bulkConfig the config instance bound by guicey from the
 *                   {@code bulkOperations} YAML section
 * @return the same singleton instance, re-published without a qualifier
 */
@Provides
@Singleton
public BulkOperationsConfig bulkOperation(@Config("bulkOperations") BulkOperationsConfig bulkConfig) {
    return bulkConfig;
}

}
Original file line number Diff line number Diff line change
@@ -1,29 +1,26 @@
package com.comet.opik.utils;

import java.util.Collection;
import lombok.RequiredArgsConstructor;

import java.util.List;
import java.util.stream.IntStream;

public class TemplateUtils {

    // Utility class — not meant to be instantiated.
    private TemplateUtils() {
    }

    /**
     * Placeholder descriptor for one item of a StringTemplate bulk-insert query.
     * {@code index} is the positional suffix used in bind-parameter names
     * (e.g. {@code :id0}, {@code :id1}); {@code hasNext} tells the template
     * whether to emit a trailing separator after this item.
     */
    public static class QueryItem {
        public final int index;
        public final boolean hasNext;

        public QueryItem(int index, boolean hasNext) {
            this.index = index;
            this.hasNext = hasNext;
        }
    }

    /**
     * Builds the list of per-item placeholders for a bulk statement.
     *
     * @param size number of items to generate placeholders for
     * @return {@code size} placeholders with ascending indexes, where only the
     *         last one has {@code hasNext == false}; an immutable empty list
     *         when {@code size <= 0}
     */
    public static List<QueryItem> getQueryItemPlaceHolder(int size) {

        // size <= 0 yields an empty range anyway; returning early keeps the
        // degenerate case explicit and cheap.
        if (size <= 0) {
            return List.of();
        }

        return IntStream.range(0, size)
                .mapToObj(i -> new QueryItem(i, i < size - 1))
                .toList();
    }
}

0 comments on commit 43c044b

Please sign in to comment.