Skip to content

Commit

Permalink
Replace the Embedding format from List<Double> to float[]
Browse files Browse the repository at this point in the history
 - Adjust all affected classes including the Document.
 - Update docs.

Related to #405
  • Loading branch information
tzolov authored and markpollack committed Aug 13, 2024
1 parent 656fa8b commit d538e00
Show file tree
Hide file tree
Showing 67 changed files with 442 additions and 412 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,21 @@
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

import java.util.ArrayList;
import java.util.List;

/**
* Azure Open AI Embedding Model implementation.
*
* @author Mark Pollack
* @author Christian Tzolov
* @author Thomas Vitale
* @since 1.0.0
*/
public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel {

private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiEmbeddingModel.class);
Expand Down Expand Up @@ -64,13 +74,17 @@ public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode me
}

@Override
public List<Double> embed(Document document) {
public float[] embed(Document document) {
logger.debug("Retrieving embeddings");

EmbeddingResponse response = this
.call(new EmbeddingRequest(List.of(document.getFormattedContent(this.metadataMode)), null));
logger.debug("Embeddings retrieved");
return response.getResults().stream().map(embedding -> embedding.getOutput()).flatMap(List::stream).toList();

if (CollectionUtils.isEmpty(response.getResults())) {
return new float[0];
}
return response.getResults().get(0).getOutput();
}

@Override
Expand Down Expand Up @@ -108,8 +122,7 @@ private List<Embedding> generateEmbeddingList(List<EmbeddingItem> nativeData) {
for (EmbeddingItem nativeDatum : nativeData) {
List<Float> nativeDatumEmbedding = nativeDatum.getEmbedding();
int nativeIndex = nativeDatum.getPromptIndex();
Embedding embedding = new Embedding(nativeDatumEmbedding.stream().map(f -> f.doubleValue()).toList(),
nativeIndex);
Embedding embedding = new Embedding(EmbeddingUtils.toPrimitive(nativeDatumEmbedding), nativeIndex);
data.add(embedding);
}
return data;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,33 +66,8 @@ public BedrockCohereEmbeddingModel(CohereEmbeddingBedrockApi cohereEmbeddingBedr
this.defaultOptions = options;
}

// /**
// * Cohere Embedding API input types.
// * @param inputType the input type to use.
// * @return this client.
// */
// public BedrockCohereEmbeddingModel withInputType(CohereEmbeddingRequest.InputType
// inputType) {
// this.inputType = inputType;
// return this;
// }

// /**
// * Specifies how the API handles inputs longer than the maximum token length. If you
// specify LEFT or RIGHT, the
// * model discards the input until the remaining input is exactly the maximum input
// token length for the model.
// * @param truncate the truncate option to use.
// * @return this client.
// */
// public BedrockCohereEmbeddingModel withTruncate(CohereEmbeddingRequest.Truncate
// truncate) {
// this.truncate = truncate;
// return this;
// }

@Override
public List<Double> embed(Document document) {
public float[] embed(Document document) {
return embed(document.getContent());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public enum Truncate {
@JsonInclude(Include.NON_NULL)
public record CohereEmbeddingResponse(
@JsonProperty("id") String id,
@JsonProperty("embeddings") List<List<Double>> embeddings,
@JsonProperty("embeddings") List<float[]> embeddings,
@JsonProperty("texts") List<String> texts,
@JsonProperty("response_type") String responseType,
// For future use: Currently bedrock doesn't return invocationMetrics for the cohere embedding model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public BedrockTitanEmbeddingModel withInputType(InputType inputType) {
}

@Override
public List<Double> embed(Document document) {
public float[] embed(Document document) {
return embed(document.getContent());
}

Expand All @@ -87,16 +87,13 @@ public EmbeddingResponse call(EmbeddingRequest request) {
"Titan Embedding does not support batch embedding. Will make multiple API calls to embed(Document)");
}

List<List<Double>> embeddingList = new ArrayList<>();
List<Embedding> embeddings = new ArrayList<>();
var indexCounter = new AtomicInteger(0);
for (String inputContent : request.getInstructions()) {
var apiRequest = createTitanEmbeddingRequest(inputContent, request.getOptions());
TitanEmbeddingResponse response = this.embeddingApi.embedding(apiRequest);
embeddingList.add(response.embedding());
embeddings.add(new Embedding(response.embedding(), indexCounter.getAndIncrement()));
}
var indexCounter = new AtomicInteger(0);
List<Embedding> embeddings = embeddingList.stream()
.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
.toList();
return new EmbeddingResponse(embeddings);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ public TitanEmbeddingRequest build() {
*/
@JsonInclude(Include.NON_NULL)
public record TitanEmbeddingResponse(
@JsonProperty("embedding") List<Double> embedding,
@JsonProperty("embedding") float[] embedding,
@JsonProperty("inputTextTokenCount") Integer inputTextTokenCount,
@JsonProperty("message") Object message) {
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public MiniMaxEmbeddingModel(MiniMaxApi miniMaxApi, MetadataMode metadataMode, M
}

@Override
public List<Double> embed(Document document) {
public float[] embed(Document document) {
Assert.notNull(document, "Document must not be null");
return this.embed(document.getFormattedContent(this.metadataMode));
}
Expand Down Expand Up @@ -137,7 +137,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {

List<Embedding> embeddings = new ArrayList<>();
for (int i = 0; i < apiEmbeddingResponse.vectors().size(); i++) {
List<Double> vector = apiEmbeddingResponse.vectors().get(i);
float[] vector = apiEmbeddingResponse.vectors().get(i);
embeddings.add(new Embedding(vector, i));
}
return new EmbeddingResponse(embeddings, metadata);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ public EmbeddingRequest(List<String> texts, EmbeddingType type) {
*/
@JsonInclude(Include.NON_NULL)
public record EmbeddingList(
@JsonProperty("vectors") List<List<Double>> vectors,
@JsonProperty("vectors") List<float[]> vectors,
@JsonProperty("model") String model,
@JsonProperty("total_tokens") Integer totalTokens) {
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ public void miniMaxChatStreamNonTransientError() {
@Test
public void miniMaxEmbeddingTransientError() {

EmbeddingList expectedEmbeddings = new EmbeddingList(List.of(List.of(9.9, 8.8)), "model", 10);
EmbeddingList expectedEmbeddings = new EmbeddingList(List.of(new float[] { 9.9f, 8.8f }), "model", 10);

when(miniMaxApi.embeddings(isA(EmbeddingRequest.class)))
.thenThrow(new TransientAiException("Transient Error 1"))
Expand All @@ -168,7 +168,7 @@ public void miniMaxEmbeddingTransientError() {
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null));

assertThat(result).isNotNull();
assertThat(result.getResult().getOutput()).isEqualTo(List.of(9.9, 8.8));
assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f });
assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {
}

@Override
public List<Double> embed(Document document) {
public float[] embed(Document document) {
Assert.notNull(document, "Document must not be null");
return this.embed(document.getFormattedContent(this.metadataMode));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ public record Usage(
public record Embedding(
// @formatter:off
@JsonProperty("index") Integer index,
@JsonProperty("embedding") List<Double> embedding,
@JsonProperty("embedding") float[] embedding,
@JsonProperty("object") String object) {
// @formatter:on

Expand All @@ -207,7 +207,7 @@ public record Embedding(
* @param embedding The embedding vector, which is a list of floats. The length of
* vector depends on the model.
*/
public Embedding(Integer index, List<Double> embedding) {
public Embedding(Integer index, float[] embedding) {
this(index, embedding, "embedding");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
*/
package org.springframework.ai.mistralai;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.when;

import java.util.List;
import java.util.Optional;

Expand All @@ -23,8 +28,6 @@
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.mistralai.api.MistralAiApi;
Expand All @@ -45,10 +48,7 @@
import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.when;
import reactor.core.publisher.Flux;

/**
* @author Christian Tzolov
Expand Down Expand Up @@ -166,7 +166,7 @@ public void mistralAiChatStreamNonTransientError() {
public void mistralAiEmbeddingTransientError() {

EmbeddingList<Embedding> expectedEmbeddings = new EmbeddingList<>("list",
List.of(new Embedding(0, List.of(9.9, 8.8))), "model", new MistralAiApi.Usage(10, 10, 10));
List.of(new Embedding(0, new float[] { 9.9f, 8.8f })), "model", new MistralAiApi.Usage(10, 10, 10));

when(mistralAiApi.embeddings(isA(EmbeddingRequest.class)))
.thenThrow(new TransientAiException("Transient Error 1"))
Expand All @@ -177,7 +177,7 @@ public void mistralAiEmbeddingTransientError() {
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null));

assertThat(result).isNotNull();
assertThat(result.getResult().getOutput()).isEqualTo(List.of(9.9, 8.8));
assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f });
assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions) {
}

@Override
public List<Double> embed(Document document) {
public float[] embed(Document document) {
return embed(document.getContent());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ public EmbeddingRequest(String model, String prompt) {
@Deprecated(since = "1.0.0-M2", forRemoval = true)
@JsonInclude(Include.NON_NULL)
public record EmbeddingResponse(
@JsonProperty("embedding") List<Double> embedding) {
@JsonProperty("embedding") List<Float> embedding) {
}


Expand All @@ -764,7 +764,7 @@ public record EmbeddingResponse(
@JsonInclude(Include.NON_NULL)
public record EmbeddingsResponse(
@JsonProperty("model") String model,
@JsonProperty("embeddings") List<List<Double>> embeddings) {
@JsonProperty("embeddings") List<float[]> embeddings) {
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ public void options() {

when(ollamaApi.embed(embeddingsRequestCaptor.capture()))
.thenReturn(
new EmbeddingsResponse("RESPONSE_MODEL_NAME", List.of(List.of(1d, 2d, 3d), List.of(4d, 5d, 6d))))
new EmbeddingsResponse("RESPONSE_MODEL_NAME", List.of(new float[]{1f, 2f, 3f}, new float[]{4f, 5f, 6f})))
.thenReturn(new EmbeddingsResponse("RESPONSE_MODEL_NAME2",
List.of(List.of(7d, 8d, 9d), List.of(10d, 11d, 12d))));
List.of(new float[]{7f, 8f, 9f}, new float[]{10f, 11f, 12f})));

// Tests default options
var defaultOptions = OllamaOptions.builder().withModel("DEFAULT_MODEL").build();
Expand All @@ -69,10 +69,10 @@ public void options() {

assertThat(response.getResults()).hasSize(2);
assertThat(response.getResults().get(0).getIndex()).isEqualTo(0);
assertThat(response.getResults().get(0).getOutput()).isEqualTo(List.of(1d, 2d, 3d));
assertThat(response.getResults().get(0).getOutput()).isEqualTo(new float[]{1f, 2f, 3f});
assertThat(response.getResults().get(0).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
assertThat(response.getResults().get(1).getIndex()).isEqualTo(1);
assertThat(response.getResults().get(1).getOutput()).isEqualTo(List.of(4d, 5d, 6d));
assertThat(response.getResults().get(1).getOutput()).isEqualTo(new float[]{4f, 5f, 6f});
assertThat(response.getResults().get(1).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
assertThat(response.getMetadata().getModel()).isEqualTo("RESPONSE_MODEL_NAME");

Expand All @@ -94,10 +94,10 @@ public void options() {

assertThat(response.getResults()).hasSize(2);
assertThat(response.getResults().get(0).getIndex()).isEqualTo(0);
assertThat(response.getResults().get(0).getOutput()).isEqualTo(List.of(7d, 8d, 9d));
assertThat(response.getResults().get(0).getOutput()).isEqualTo(new float[]{7f, 8f, 9f});
assertThat(response.getResults().get(0).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
assertThat(response.getResults().get(1).getIndex()).isEqualTo(1);
assertThat(response.getResults().get(1).getOutput()).isEqualTo(List.of(10d, 11d, 12d));
assertThat(response.getResults().get(1).getOutput()).isEqualTo(new float[]{10f, 11f, 12f});
assertThat(response.getResults().get(1).getMetadata()).isEqualTo(EmbeddingResultMetadata.EMPTY);
assertThat(response.getMetadata().getModel()).isEqualTo("RESPONSE_MODEL_NAME2");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ public OpenAiEmbeddingModel(OpenAiApi openAiApi, MetadataMode metadataMode, Open
}

@Override
public List<Double> embed(Document document) {
public float[] embed(Document document) {
Assert.notNull(document, "Document must not be null");
return this.embed(document.getFormattedContent(this.metadataMode));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,7 @@ public String getValue() {
@JsonInclude(Include.NON_NULL)
public record Embedding(// @formatter:off
@JsonProperty("index") Integer index,
@JsonProperty("embedding") List<Double> embedding,
@JsonProperty("embedding") float[] embedding,
@JsonProperty("object") String object) {// @formatter:on

/**
Expand All @@ -1112,7 +1112,7 @@ public record Embedding(// @formatter:off
* @param embedding The embedding vector, which is a list of floats. The length of
* vector depends on the model.
*/
public Embedding(Integer index, List<Double> embedding) {
public Embedding(Integer index, float[] embedding) {
this(index, embedding, "embedding");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ public void openAiChatStreamNonTransientError() {
public void openAiEmbeddingTransientError() {

EmbeddingList<Embedding> expectedEmbeddings = new EmbeddingList<>("list",
List.of(new Embedding(0, List.of(9.9, 8.8))), "model", new OpenAiApi.Usage(10, 10, 10));
List.of(new Embedding(0, new float[] { 9.9f, 8.8f })), "model", new OpenAiApi.Usage(10, 10, 10));

when(openAiApi.embeddings(isA(EmbeddingRequest.class))).thenThrow(new TransientAiException("Transient Error 1"))
.thenThrow(new TransientAiException("Transient Error 2"))
Expand All @@ -207,7 +207,7 @@ public void openAiEmbeddingTransientError() {
.call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null));

assertThat(result).isNotNull();
assertThat(result.getResult().getOutput()).isEqualTo(List.of(9.9, 8.8));
assertThat(result.getResult().getOutput()).isEqualTo(new float[] { 9.9f, 8.8f });
assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
}
Expand Down
Loading

0 comments on commit d538e00

Please sign in to comment.