Skip to content

Commit

Permalink
Refactor Ollama usage metadata to add embedding support
Browse files Browse the repository at this point in the history
 - Extend the OllamaApi.EmbeddingsResponse with total_duration,
   load_duration and prompt_eval_count fields
 - Rename OllamaUsage to OllamaChatUsage for clarity
 - Add OllamaEmbeddingUsage to track embedding-specific usage metrics
 - Update OllamaEmbeddingModel to use OllamaEmbeddingUsage
 - Extend EmbeddingsResponse with additional metadata fields
 - Update tests to reflect new usage tracking for embeddings

 Resolves #1536
  • Loading branch information
tzolov committed Oct 14, 2024
1 parent fe09b4d commit 6d38c85
Show file tree
Hide file tree
Showing 8 changed files with 96 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCall;
import org.springframework.ai.ollama.api.OllamaApi.Message.ToolCallFunction;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.metadata.OllamaUsage;
import org.springframework.ai.ollama.metadata.OllamaChatUsage;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
Expand Down Expand Up @@ -177,7 +177,7 @@ && isToolCall(response, Set.of("stop"))) {
public static ChatResponseMetadata from(OllamaApi.ChatResponse response) {
Assert.notNull(response, "OllamaApi.ChatResponse must not be null");
return ChatResponseMetadata.builder()
.withUsage(OllamaUsage.from(response))
.withUsage(OllamaChatUsage.from(response))
.withModel(response.model())
.withKeyValue("created-at", response.createdAt())
.withKeyValue("eval-duration", response.evalDuration())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.metadata.OllamaEmbeddingUsage;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

Expand Down Expand Up @@ -125,7 +126,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {
.toList();

EmbeddingResponseMetadata embeddingResponseMetadata = new EmbeddingResponseMetadata(response.model(),
new EmptyUsage());
OllamaEmbeddingUsage.from(response));

EmbeddingResponse embeddingResponse = new EmbeddingResponse(embeddings, embeddingResponseMetadata);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,11 @@ public record EmbeddingResponse(
@JsonInclude(Include.NON_NULL)
public record EmbeddingsResponse(
@JsonProperty("model") String model,
@JsonProperty("embeddings") List<float[]> embeddings) {
@JsonProperty("embeddings") List<float[]> embeddings,
@JsonProperty("total_duration") Long totalDuration,
@JsonProperty("load_duration") Long loadDuration,
@JsonProperty("prompt_eval_count") Integer promptEvalCount) {

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,18 @@
* @see Usage
* @author Fu Cheng
*/
public class OllamaUsage implements Usage {
public class OllamaChatUsage implements Usage {

protected static final String AI_USAGE_STRING = "{ promptTokens: %1$d, generationTokens: %2$d, totalTokens: %3$d }";

public static OllamaUsage from(OllamaApi.ChatResponse response) {
public static OllamaChatUsage from(OllamaApi.ChatResponse response) {
Assert.notNull(response, "OllamaApi.ChatResponse must not be null");
return new OllamaUsage(response);
return new OllamaChatUsage(response);
}

private final OllamaApi.ChatResponse response;

public OllamaUsage(OllamaApi.ChatResponse response) {
public OllamaChatUsage(OllamaApi.ChatResponse response) {
this.response = response;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.ollama.metadata;

import java.util.Optional;

import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse;
import org.springframework.util.Assert;

/**
 * {@link Usage} implementation for {@literal Ollama} embeddings.
 *
 * <p>
 * The prompt token count is read once from the response's
 * {@code promptEvalCount()} when the instance is created; the generation token
 * count is always reported as zero.
 *
 * @see Usage
 * @author Christian Tzolov
 */
public class OllamaEmbeddingUsage implements Usage {

	protected static final String AI_USAGE_STRING = "{ promptTokens: %1$d, generationTokens: %2$d, totalTokens: %3$d }";

	/**
	 * Creates the usage information for the given embeddings response.
	 * @param response the Ollama embeddings response, must not be {@code null}
	 * @return a new {@link OllamaEmbeddingUsage} backed by the response's prompt
	 * evaluation count
	 * @throws IllegalArgumentException if {@code response} is {@code null}
	 */
	public static OllamaEmbeddingUsage from(EmbeddingsResponse response) {
		// The null check lives in the constructor so both entry points agree.
		return new OllamaEmbeddingUsage(response);
	}

	// Assigned once in the constructor and never null afterwards.
	private final Long promptTokens;

	/**
	 * Creates the usage information for the given embeddings response.
	 * @param response the Ollama embeddings response, must not be {@code null}
	 * @throws IllegalArgumentException if {@code response} is {@code null}
	 */
	public OllamaEmbeddingUsage(EmbeddingsResponse response) {
		Assert.notNull(response, "OllamaApi.EmbeddingsResponse must not be null");
		// A response without a prompt_eval_count is treated as zero prompt tokens.
		this.promptTokens = Optional.ofNullable(response.promptEvalCount()).map(Integer::longValue).orElse(0L);
	}

	/**
	 * Returns the number of prompt tokens captured from the response.
	 * @return the prompt token count, or {@code 0} when the response carried none
	 */
	@Override
	public Long getPromptTokens() {
		return this.promptTokens;
	}

	/**
	 * Returns the number of generated tokens.
	 * @return always {@code 0}
	 */
	@Override
	public Long getGenerationTokens() {
		return 0L;
	}

	/**
	 * Renders the prompt, generation and total token counts using
	 * {@link #AI_USAGE_STRING}.
	 * @return the formatted usage string
	 */
	@Override
	public String toString() {
		// getTotalTokens() is not declared in this class; presumably inherited from
		// Usage.
		return AI_USAGE_STRING.formatted(getPromptTokens(), getGenerationTokens(), getTotalTokens());
	}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
*/
package org.springframework.ai.ollama;

import static org.assertj.core.api.Assertions.assertThat;

import java.io.IOException;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.jupiter.api.BeforeAll;
Expand All @@ -24,32 +29,22 @@
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaApiIT;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import org.testcontainers.junit.jupiter.Testcontainers;

import java.io.IOException;
import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;

@SpringBootTest
@DisabledIf("isDisabled")
@Testcontainers
class OllamaEmbeddingModelIT extends BaseOllamaIT {

private static final String MODEL = OllamaModel.MISTRAL.getName();
private static final String MODEL = "mxbai-embed-large";

private static final Log logger = LogFactory.getLog(OllamaApiIT.class);

// @Container
// static OllamaContainer ollamaContainer = new
// OllamaContainer(OllamaImage.DEFAULT_IMAGE);

static String baseUrl = "http://localhost:11434";

@BeforeAll
Expand All @@ -75,8 +70,10 @@ void embeddings() {
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
assertThat(embeddingResponse.getMetadata().getModel()).isEqualTo(MODEL);
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4);
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4);

assertThat(embeddingModel.dimensions()).isEqualTo(4096);
assertThat(embeddingModel.dimensions()).isEqualTo(1024);
}

@SpringBootConfiguration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,25 +54,25 @@ public class OllamaEmbeddingModelTests {
public void options() {

when(ollamaApi.embed(embeddingsRequestCaptor.capture()))
.thenReturn(
new EmbeddingsResponse("RESPONSE_MODEL_NAME", List.of(new float[]{1f, 2f, 3f}, new float[]{4f, 5f, 6f})))
.thenReturn(new EmbeddingsResponse("RESPONSE_MODEL_NAME",
List.of(new float[] { 1f, 2f, 3f }, new float[] { 4f, 5f, 6f }), 0L, 0L, 0))
.thenReturn(new EmbeddingsResponse("RESPONSE_MODEL_NAME2",
List.of(new float[]{7f, 8f, 9f}, new float[]{10f, 11f, 12f})));
List.of(new float[] { 7f, 8f, 9f }, new float[] { 10f, 11f, 12f }), 0L, 0L, 0));

// Tests default options
var defaultOptions = OllamaOptions.builder().withModel("DEFAULT_MODEL").build();

var embeddingModel = new OllamaEmbeddingModel(ollamaApi, defaultOptions);

EmbeddingResponse response = embeddingModel
.call(new EmbeddingRequest(List.of("Input1", "Input2", "Input3"), EmbeddingOptionsBuilder.builder().build()));
EmbeddingResponse response = embeddingModel.call(
new EmbeddingRequest(List.of("Input1", "Input2", "Input3"), EmbeddingOptionsBuilder.builder().build()));

assertThat(response.getResults()).hasSize(2);
assertThat(response.getResults().get(0).getIndex()).isEqualTo(0);
assertThat(response.getResults().get(0).getOutput()).isEqualTo(new float[]{1f, 2f, 3f});
assertThat(response.getResults().get(0).getOutput()).isEqualTo(new float[] { 1f, 2f, 3f });
assertThat(response.getResults().get(0).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
assertThat(response.getResults().get(1).getIndex()).isEqualTo(1);
assertThat(response.getResults().get(1).getOutput()).isEqualTo(new float[]{4f, 5f, 6f});
assertThat(response.getResults().get(1).getOutput()).isEqualTo(new float[] { 4f, 5f, 6f });
assertThat(response.getResults().get(1).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
assertThat(response.getMetadata().getModel()).isEqualTo("RESPONSE_MODEL_NAME");

Expand All @@ -94,10 +94,10 @@ public void options() {

assertThat(response.getResults()).hasSize(2);
assertThat(response.getResults().get(0).getIndex()).isEqualTo(0);
assertThat(response.getResults().get(0).getOutput()).isEqualTo(new float[]{7f, 8f, 9f});
assertThat(response.getResults().get(0).getOutput()).isEqualTo(new float[] { 7f, 8f, 9f });
assertThat(response.getResults().get(0).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
assertThat(response.getResults().get(1).getIndex()).isEqualTo(1);
assertThat(response.getResults().get(1).getOutput()).isEqualTo(new float[]{10f, 11f, 12f});
assertThat(response.getResults().get(1).getOutput()).isEqualTo(new float[] { 10f, 11f, 12f });
assertThat(response.getResults().get(1).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
assertThat(response.getMetadata().getModel()).isEqualTo("RESPONSE_MODEL_NAME2");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ public void embedText() {
assertThat(response).isNotNull();
assertThat(response.embeddings()).hasSize(1);
assertThat(response.embeddings().get(0)).hasSize(3200);
assertThat(response.model()).isEqualTo(MODEL);
assertThat(response.promptEvalCount()).isEqualTo(5);
assertThat(response.loadDuration()).isGreaterThan(1);
assertThat(response.totalDuration()).isGreaterThan(1);

}

}

0 comments on commit 6d38c85

Please sign in to comment.