diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index bffdbe2c74..5829d4e18a 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; import java.util.ArrayList; @@ -92,6 +93,7 @@ * @author Thomas Vitale * @author luocongqiu * @author timostark + * @author Soby Chacko * @see ChatModel * @see com.azure.ai.openai.OpenAIClient */ @@ -456,6 +458,18 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions, mergedAzureOptions.setModel(fromAzureOptions.getModel() != null ? fromAzureOptions.getModel() : toSpringAiOptions.getDeploymentName()); + mergedAzureOptions + .setSeed(fromAzureOptions.getSeed() != null ? fromAzureOptions.getSeed() : toSpringAiOptions.getSeed()); + + mergedAzureOptions.setLogprobs((fromAzureOptions.isLogprobs() != null && fromAzureOptions.isLogprobs()) + || (toSpringAiOptions.isLogprobs() != null && toSpringAiOptions.isLogprobs())); + + mergedAzureOptions.setTopLogprobs(fromAzureOptions.getTopLogprobs() != null ? fromAzureOptions.getTopLogprobs() + : toSpringAiOptions.getTopLogProbs()); + + mergedAzureOptions.setEnhancements(fromAzureOptions.getEnhancements() != null + ? 
fromAzureOptions.getEnhancements() : toSpringAiOptions.getEnhancements()); + return mergedAzureOptions; } @@ -520,6 +534,22 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions, mergedAzureOptions.setResponseFormat(toAzureResponseFormat(fromSpringAiOptions.getResponseFormat())); } + if (fromSpringAiOptions.getSeed() != null) { + mergedAzureOptions.setSeed(fromSpringAiOptions.getSeed()); + } + + if (fromSpringAiOptions.isLogprobs() != null) { + mergedAzureOptions.setLogprobs(fromSpringAiOptions.isLogprobs()); + } + + if (fromSpringAiOptions.getTopLogProbs() != null) { + mergedAzureOptions.setTopLogprobs(fromSpringAiOptions.getTopLogProbs()); + } + + if (fromSpringAiOptions.getEnhancements() != null) { + mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements()); + } + return mergedAzureOptions; } @@ -566,6 +596,19 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) { if (fromOptions.getResponseFormat() != null) { copyOptions.setResponseFormat(fromOptions.getResponseFormat()); } + if (fromOptions.getSeed() != null) { + copyOptions.setSeed(fromOptions.getSeed()); + } + + copyOptions.setLogprobs(fromOptions.isLogprobs()); + + if (fromOptions.getTopLogprobs() != null) { + copyOptions.setTopLogprobs(fromOptions.getTopLogprobs()); + } + + if (fromOptions.getEnhancements() != null) { + copyOptions.setEnhancements(fromOptions.getEnhancements()); + } return copyOptions; } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index 6b85eeb966..5faa64ebb1 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the 
original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; import java.util.ArrayList; @@ -21,6 +22,7 @@ import java.util.Map; import java.util.Set; +import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -40,6 +42,7 @@ * * @author Christian Tzolov * @author Thomas Vitale + * @author Soby Chacko */ @JsonInclude(Include.NON_NULL) public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptions { @@ -165,6 +168,37 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio @JsonIgnore private Boolean proxyToolCalls; + /** + * Seed value for deterministic sampling such that the same seed and parameters return + * the same result. + */ + @JsonProperty(value = "seed") + private Long seed; + + /** + * Whether to return log probabilities of the output tokens or not. If true, returns + * the log probabilities of each output token returned in the `content` of `message`. + * This option is currently not available on the `gpt-4-vision-preview` model. + */ + @JsonProperty(value = "log_probs") + private Boolean logprobs; + + /* + * An integer between 0 and 5 specifying the number of most likely tokens to return at + * each token position, each with an associated log probability. `logprobs` must be + * set to `true` if this parameter is used. + */ + @JsonProperty(value = "top_log_probs") + private Integer topLogProbs; + + /* + * If provided, the configuration options for available Azure OpenAI chat + * enhancements. 
+ */ + @NestedConfigurationProperty + @JsonIgnore + private AzureChatEnhancementConfiguration enhancements; + public static Builder builder() { return new Builder(); } @@ -259,6 +293,30 @@ public Builder withProxyToolCalls(Boolean proxyToolCalls) { return this; } + public Builder withSeed(Long seed) { + Assert.notNull(seed, "seed must not be null"); + this.options.seed = seed; + return this; + } + + public Builder withLogprobs(Boolean logprobs) { + Assert.notNull(logprobs, "logprobs must not be null"); + this.options.logprobs = logprobs; + return this; + } + + public Builder withTopLogprobs(Integer topLogprobs) { + Assert.notNull(topLogprobs, "topLogprobs must not be null"); + this.options.topLogProbs = topLogprobs; + return this; + } + + public Builder withEnhancements(AzureChatEnhancementConfiguration enhancements) { + Assert.notNull(enhancements, "enhancements must not be null"); + this.options.enhancements = enhancements; + return this; + } + public AzureOpenAiChatOptions build() { return this.options; } @@ -404,6 +462,38 @@ public Integer getTopK() { return null; } + public Long getSeed() { + return this.seed; + } + + public void setSeed(Long seed) { + this.seed = seed; + } + + public Boolean isLogprobs() { + return this.logprobs; + } + + public void setLogprobs(Boolean logprobs) { + this.logprobs = logprobs; + } + + public Integer getTopLogProbs() { + return this.topLogProbs; + } + + public void setTopLogProbs(Integer topLogProbs) { + this.topLogProbs = topLogProbs; + } + + public AzureChatEnhancementConfiguration getEnhancements() { + return this.enhancements; + } + + public void setEnhancements(AzureChatEnhancementConfiguration enhancements) { + this.enhancements = enhancements; + } + @Override public Boolean getProxyToolCalls() { return this.proxyToolCalls; @@ -432,6 +522,10 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) 
.withFunctions(fromOptions.getFunctions()) .withResponseFormat(fromOptions.getResponseFormat()) + .withSeed(fromOptions.getSeed()) + .withLogprobs(fromOptions.isLogprobs()) + .withTopLogprobs(fromOptions.getTopLogProbs()) + .withEnhancements(fromOptions.getEnhancements()) .build(); } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java index 6e8d8bd531..f7edea989b 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,9 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.springframework.ai.azure.openai; import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; +import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration; import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat; import org.junit.jupiter.api.Test; @@ -34,6 +37,7 @@ /** * @author Christian Tzolov + * @author Soby Chacko */ public class AzureChatCompletionsOptionsTests { @@ -42,6 +46,9 @@ public void createRequestWithChatOptions() { OpenAIClient mockClient = Mockito.mock(OpenAIClient.class); + AzureChatEnhancementConfiguration mockAzureChatEnhancementConfiguration = Mockito + .mock(AzureChatEnhancementConfiguration.class); + var defaultOptions = AzureOpenAiChatOptions.builder() .withDeploymentName("DEFAULT_MODEL") .withTemperature(66.6) @@ -53,6 +60,10 @@ public void createRequestWithChatOptions() { .withStop(List.of("foo", "bar")) .withTopP(0.69) .withUser("user") + .withSeed(123L) + .withLogprobs(true) + .withTopLogprobs(5) + .withEnhancements(mockAzureChatEnhancementConfiguration) .withResponseFormat(AzureOpenAiResponseFormat.TEXT) .build(); @@ -72,8 +83,15 @@ public void createRequestWithChatOptions() { assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar")); assertThat(requestOptions.getTopP()).isEqualTo(0.69); assertThat(requestOptions.getUser()).isEqualTo("user"); + assertThat(requestOptions.getSeed()).isEqualTo(123L); + assertThat(requestOptions.isLogprobs()).isTrue(); + assertThat(requestOptions.getTopLogprobs()).isEqualTo(5); + assertThat(requestOptions.getEnhancements()).isEqualTo(mockAzureChatEnhancementConfiguration); assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsTextResponseFormat.class); + AzureChatEnhancementConfiguration anotherMockAzureChatEnhancementConfiguration = Mockito + .mock(AzureChatEnhancementConfiguration.class); + var runtimeOptions = 
AzureOpenAiChatOptions.builder() .withDeploymentName("PROMPT_MODEL") .withTemperature(99.9) @@ -85,6 +103,10 @@ public void createRequestWithChatOptions() { .withStop(List.of("foo", "bar")) .withTopP(0.111) .withUser("user2") + .withSeed(1234L) + .withLogprobs(true) + .withTopLogprobs(4) + .withEnhancements(anotherMockAzureChatEnhancementConfiguration) .withResponseFormat(AzureOpenAiResponseFormat.JSON) .build(); @@ -102,6 +124,10 @@ public void createRequestWithChatOptions() { assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar")); assertThat(requestOptions.getTopP()).isEqualTo(0.111); assertThat(requestOptions.getUser()).isEqualTo("user2"); + assertThat(requestOptions.getSeed()).isEqualTo(1234L); + assertThat(requestOptions.isLogprobs()).isTrue(); + assertThat(requestOptions.getTopLogprobs()).isEqualTo(4); + assertThat(requestOptions.getEnhancements()).isEqualTo(anotherMockAzureChatEnhancementConfiguration); assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsJsonResponseFormat.class); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index b88f795063..4eb00c374e 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -936,12 +936,28 @@ public record TopLogProbs(// @formatter:off * @param promptTokens Number of tokens in the prompt. * @param totalTokens Total number of tokens used in the request (prompt + * completion). 
+ * @param completionTokenDetails Breakdown of tokens used in a completion */ @JsonInclude(Include.NON_NULL) public record Usage(// @formatter:off @JsonProperty("completion_tokens") Integer completionTokens, @JsonProperty("prompt_tokens") Integer promptTokens, - @JsonProperty("total_tokens") Integer totalTokens) {// @formatter:on + @JsonProperty("total_tokens") Integer totalTokens, + @JsonProperty("completion_tokens_details") CompletionTokenDetails completionTokenDetails) {// @formatter:on + + public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens) { + this(completionTokens, promptTokens, totalTokens, null); + } + + /** + * Breakdown of tokens used in a completion + * + * @param reasoningTokens Number of tokens generated by the model for reasoning. + */ + @JsonInclude(Include.NON_NULL) + public record CompletionTokenDetails(// @formatter:off + @JsonProperty("reasoning_tokens") Integer reasoningTokens) {// @formatter:on + } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java index 821a0325d0..add5d896b5 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java @@ -58,6 +58,12 @@ public Long getGenerationTokens() { return generationTokens != null ? generationTokens.longValue() : 0; } + public Long getReasoningTokens() { + OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = getUsage().completionTokenDetails(); + Integer reasoningTokens = completionTokenDetails != null ? completionTokenDetails.reasoningTokens() : null; + return reasoningTokens != null ? 
reasoningTokens.longValue() : 0; + } + @Override public Long getTotalTokens() { Integer totalTokens = getUsage().totalTokens(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java index 58c378f35b..b9215b4c3d 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java @@ -54,4 +54,28 @@ void whenTotalTokensIsNull() { assertThat(usage.getTotalTokens()).isEqualTo(300); } + @Test + void whenCompletionTokenDetailsIsNull() { + OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null); + OpenAiUsage usage = OpenAiUsage.from(openAiUsage); + assertThat(usage.getTotalTokens()).isEqualTo(300); + assertThat(usage.getReasoningTokens()).isEqualTo(0); + } + + @Test + void whenReasoningTokensIsNull() { + OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, + new OpenAiApi.Usage.CompletionTokenDetails(null)); + OpenAiUsage usage = OpenAiUsage.from(openAiUsage); + assertThat(usage.getReasoningTokens()).isEqualTo(0); + } + + @Test + void whenCompletionTokenDetailsIsPresent() { + OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, + new OpenAiApi.Usage.CompletionTokenDetails(50)); + OpenAiUsage usage = OpenAiUsage.from(openAiUsage); + assertThat(usage.getReasoningTokens()).isEqualTo(50); + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/RelevancyEvaluator.java b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/RelevancyEvaluator.java index 2f26e97689..9dd7505371 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/RelevancyEvaluator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/RelevancyEvaluator.java @@ -5,8 +5,8 @@ import 
java.util.Collections; import java.util.List; -import java.util.Objects; import java.util.stream.Collectors; +import org.springframework.util.StringUtils; public class RelevancyEvaluator implements Evaluator { @@ -57,9 +57,7 @@ protected String doGetSupportingData(EvaluationRequest evaluationRequest) { List data = evaluationRequest.getDataList(); return data.stream() .map(Content::getContent) - .filter(Objects::nonNull) - .filter(c -> c instanceof String) - .map(Object::toString) + .filter(StringUtils::hasText) .collect(Collectors.joining(System.lineSeparator())); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonReader.java b/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonReader.java index 895feab7ba..b00eace87f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonReader.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonReader.java @@ -42,19 +42,19 @@ */ public class JsonReader implements DocumentReader { - private Resource resource; + private final Resource resource; - private JsonMetadataGenerator jsonMetadataGenerator; + private final JsonMetadataGenerator jsonMetadataGenerator; private final ObjectMapper objectMapper = new ObjectMapper(); /** * The key from the JSON that we will use as the text to parse into the Document text */ - private List jsonKeysToUse; + private final List jsonKeysToUse; public JsonReader(Resource resource) { - this(resource, new ArrayList<>().toArray(new String[0])); + this(resource, new String[0]); } public JsonReader(Resource resource, String... 
jsonKeysToUse) { @@ -92,9 +92,9 @@ public List get() { private Document parseJsonNode(JsonNode jsonNode, ObjectMapper objectMapper) { Map item = objectMapper.convertValue(jsonNode, new TypeReference>() { }); - StringBuffer sb = new StringBuffer(); + var sb = new StringBuilder(); - jsonKeysToUse.parallelStream().filter(item::containsKey).forEach(key -> { + jsonKeysToUse.stream().filter(item::containsKey).forEach(key -> { sb.append(key).append(": ").append(item.get(key)).append(System.lineSeparator()); }); diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mongodb.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mongodb.adoc index 895e5eda84..359160ad1a 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mongodb.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mongodb.adoc @@ -256,5 +256,8 @@ List results = vectorStore.similaritySearch( ); ---- +== Tutorials and Code Examples +To get started with Spring AI and MongoDB: -If you would like to try out Spring AI with MongoDB, see https://www.mongodb.com/docs/atlas/atlas-vector-search/ai-integrations/spring-ai/#std-label-spring-ai[Get Started with the Spring AI Integration]. +* See the https://www.mongodb.com/docs/atlas/atlas-vector-search/ai-integrations/spring-ai/#std-label-spring-ai[Getting Started guide for Spring AI Integration]. +* For a comprehensive code example demonstrating Retrieval Augmented Generation (RAG) with Spring AI and MongoDB, refer to this https://www.mongodb.com/developer/languages/java/retrieval-augmented-generation-spring-ai/[detailed tutorial]. 
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java index 1b5a62507a..ec4d76e074 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java @@ -71,6 +71,7 @@ public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embed .withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .withSearchObservationConvention(customObservationConvention.getIfAvailable(() -> null)) .withBatchingStrategy(batchingStrategy) + .withMaxDocumentBatchSize(properties.getMaxDocumentBatchSize()) .build(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java index b455417461..47a12c36d3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java @@ -24,6 +24,7 @@ /** * @author Christian Tzolov * @author Muthukumaran Navaneethakrishnan + * @author Soby Chacko */ @ConfigurationProperties(PgVectorStoreProperties.CONFIG_PREFIX) public class PgVectorStoreProperties extends CommonVectorStoreProperties { @@ -45,6 +46,8 @@ public class PgVectorStoreProperties extends CommonVectorStoreProperties { 
private boolean schemaValidation = PgVectorStore.DEFAULT_SCHEMA_VALIDATION; + private int maxDocumentBatchSize = PgVectorStore.MAX_DOCUMENT_BATCH_SIZE; + public int getDimensions() { return dimensions; } @@ -101,4 +104,12 @@ public void setSchemaValidation(boolean schemaValidation) { this.schemaValidation = schemaValidation; } + public int getMaxDocumentBatchSize() { + return this.maxDocumentBatchSize; + } + + public void setMaxDocumentBatchSize(int maxDocumentBatchSize) { + this.maxDocumentBatchSize = maxDocumentBatchSize; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java index 11cd9409fb..c0a7245659 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.springframework.ai.autoconfigure.azure; import com.azure.ai.openai.OpenAIClient; @@ -95,22 +96,29 @@ void chatCompletion() { } @Test - void httpRequestContainsUserAgentHeader() { - contextRunner.run(context -> { - OpenAIClient openAIClient = context.getBean(OpenAIClient.class); - Field serviceClientField = ReflectionUtils.findField(OpenAIClient.class, "serviceClient"); - assertThat(serviceClientField).isNotNull(); - ReflectionUtils.makeAccessible(serviceClientField); - OpenAIClientImpl oaci = (OpenAIClientImpl) ReflectionUtils.getField(serviceClientField, openAIClient); - assertThat(oaci).isNotNull(); - HttpPipeline httpPipeline = oaci.getHttpPipeline(); - HttpResponse httpResponse = httpPipeline - .send(new HttpRequest(HttpMethod.POST, new URI(System.getenv("AZURE_OPENAI_ENDPOINT")).toURL())) - .block(); - assertThat(httpResponse).isNotNull(); - HttpHeader httpHeader = httpResponse.getRequest().getHeaders().get(HttpHeaderName.USER_AGENT); - assertThat(httpHeader.getValue().startsWith("spring-ai azsdk-java-azure-ai-openai/")).isTrue(); - }); + void httpRequestContainsUserAgentAndCustomHeaders() { + contextRunner + .withPropertyValues("spring.ai.azure.openai.custom-headers.foo=bar", + "spring.ai.azure.openai.custom-headers.fizz=buzz") + .run(context -> { + OpenAIClient openAIClient = context.getBean(OpenAIClient.class); + Field serviceClientField = ReflectionUtils.findField(OpenAIClient.class, "serviceClient"); + assertThat(serviceClientField).isNotNull(); + ReflectionUtils.makeAccessible(serviceClientField); + OpenAIClientImpl oaci = (OpenAIClientImpl) ReflectionUtils.getField(serviceClientField, openAIClient); + assertThat(oaci).isNotNull(); + HttpPipeline httpPipeline = oaci.getHttpPipeline(); + HttpResponse httpResponse = httpPipeline + .send(new HttpRequest(HttpMethod.POST, new URI(System.getenv("AZURE_OPENAI_ENDPOINT")).toURL())) + .block(); + assertThat(httpResponse).isNotNull(); + HttpHeader httpHeader = 
httpResponse.getRequest().getHeaders().get(HttpHeaderName.USER_AGENT); + assertThat(httpHeader.getValue().startsWith("spring-ai azsdk-java-azure-ai-openai/")).isTrue(); + HttpHeader customHeader1 = httpResponse.getRequest().getHeaders().get("foo"); + assertThat(customHeader1.getValue()).isEqualTo("bar"); + HttpHeader customHeader2 = httpResponse.getRequest().getHeaders().get("fizz"); + assertThat(customHeader2.getValue()).isEqualTo("buzz"); + }); } @Test diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactory.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactory.java index c863b1ec48..b861bea6ad 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactory.java @@ -60,7 +60,7 @@ static class ChromaDockerComposeConnectionDetails extends DockerComposeConnectio @Override public String getHost() { - return this.host; + return "http://%s".formatted(this.host); } @Override diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java index 6ce419d914..9932b1fe89 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java @@ -31,7 +31,9 @@ import org.springframework.http.client.support.BasicAuthenticationInterceptor; import org.springframework.util.CollectionUtils; import 
org.springframework.util.StringUtils; +import org.springframework.web.client.HttpClientErrorException; import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.client.HttpStatusCodeException; import org.springframework.web.client.RestClient; import com.fasterxml.jackson.annotation.JsonProperty; @@ -49,6 +51,9 @@ public class ChromaApi { // Regular expression pattern that looks for a message inside the ValueError(...). private static Pattern VALUE_ERROR_PATTERN = Pattern.compile("ValueError\\('([^']*)'\\)"); + // Regular expression pattern that looks for a message. + private static Pattern MESSAGE_ERROR_PATTERN = Pattern.compile("\"message\":\"(.*?)\""); + private RestClient restClient; private final ObjectMapper objectMapper; @@ -316,8 +321,8 @@ public Collection getCollection(String collectionName) { .toEntity(Collection.class) .getBody(); } - catch (HttpServerErrorException e) { - String msg = this.getValueErrorMessage(e.getMessage()); + catch (HttpServerErrorException | HttpClientErrorException e) { + String msg = this.getErrorMessage(e); if (String.format("Collection %s does not exist.", collectionName).equals(msg)) { return null; } @@ -413,12 +418,28 @@ private void httpHeaders(HttpHeaders headers) { } } - private String getValueErrorMessage(String logString) { - if (!StringUtils.hasText(logString)) { + private String getErrorMessage(HttpStatusCodeException e) { + var errorMessage = e.getMessage(); + + // If the error message is empty or null, return an empty string + if (!StringUtils.hasText(errorMessage)) { return ""; } - Matcher m = VALUE_ERROR_PATTERN.matcher(logString); - return (m.find()) ? 
m.group(1) : ""; + + // If the exception is an HttpServerErrorException, use the VALUE_ERROR_PATTERN + Matcher valueErrorMatcher = VALUE_ERROR_PATTERN.matcher(errorMessage); + if (e instanceof HttpServerErrorException && valueErrorMatcher.find()) { + return valueErrorMatcher.group(1); + } + + // Otherwise, use the MESSAGE_ERROR_PATTERN for other cases + Matcher messageErrorMatcher = MESSAGE_ERROR_PATTERN.matcher(errorMessage); + if (messageErrorMatcher.find()) { + return messageErrorMatcher.group(1); + } + + // If no pattern matches, return an empty string + return ""; } } diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java index 8ab0546fc5..697960f15d 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java @@ -15,14 +15,10 @@ */ package org.springframework.ai.vectorstore; -import java.sql.PreparedStatement; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; - +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.pgvector.PGvector; +import io.micrometer.observation.ObservationRegistry; import org.postgresql.util.PGobject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,11 +42,14 @@ import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.pgvector.PGvector; - -import io.micrometer.observation.ObservationRegistry; +import java.sql.PreparedStatement; +import java.sql.ResultSet; 
+import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; /** * Uses the "vector_store" table to store the Spring AI vector data. The table and the @@ -81,6 +80,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini public final FilterExpressionConverter filterExpressionConverter = new PgVectorFilterExpressionConverter(); + public static final int MAX_DOCUMENT_BATCH_SIZE = 10_000; + private final String vectorTableName; private final String vectorIndexName; @@ -109,6 +110,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini private final BatchingStrategy batchingStrategy; + private final int maxDocumentBatchSize; + public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { this(jdbcTemplate, embeddingModel, INVALID_EMBEDDING_DIMENSION, PgDistanceType.COSINE_DISTANCE, false, PgIndexType.NONE, false); @@ -132,7 +135,6 @@ public PgVectorStore(String vectorTableName, JdbcTemplate jdbcTemplate, Embeddin this(DEFAULT_SCHEMA_NAME, vectorTableName, DEFAULT_SCHEMA_VALIDATION, jdbcTemplate, embeddingModel, dimensions, distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema); - } private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled, @@ -141,14 +143,14 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT this(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate, embeddingModel, dimensions, distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema, - ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy()); + ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy(), MAX_DOCUMENT_BATCH_SIZE); } private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled, JdbcTemplate jdbcTemplate, 
EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType, boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, - BatchingStrategy batchingStrategy) { + BatchingStrategy batchingStrategy, int maxDocumentBatchSize) { super(observationRegistry, customObservationConvention); @@ -172,6 +174,7 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT this.initializeSchema = initializeSchema; this.schemaValidator = new PgVectorSchemaValidator(jdbcTemplate); this.batchingStrategy = batchingStrategy; + this.maxDocumentBatchSize = maxDocumentBatchSize; } public PgDistanceType getDistanceType() { @@ -180,40 +183,50 @@ public PgDistanceType getDistanceType() { @Override public void doAdd(List documents) { + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); - int size = documents.size(); + List> batchedDocuments = batchDocuments(documents); + batchedDocuments.forEach(this::insertOrUpdateBatch); + } - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + private List> batchDocuments(List documents) { + List> batches = new ArrayList<>(); + for (int i = 0; i < documents.size(); i += this.maxDocumentBatchSize) { + batches.add(documents.subList(i, Math.min(i + this.maxDocumentBatchSize, documents.size()))); + } + return batches; + } - this.jdbcTemplate.batchUpdate( - "INSERT INTO " + getFullyQualifiedTableName() - + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO " - + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? 
", - new BatchPreparedStatementSetter() { - @Override - public void setValues(PreparedStatement ps, int i) throws SQLException { - - var document = documents.get(i); - var content = document.getContent(); - var json = toJson(document.getMetadata()); - var embedding = document.getEmbedding(); - var pGvector = new PGvector(embedding); - - StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN, - UUID.fromString(document.getId())); - StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content); - StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json); - StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector); - StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content); - StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json); - StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector); - } + private void insertOrUpdateBatch(List batch) { + String sql = "INSERT INTO " + getFullyQualifiedTableName() + + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO " + + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? 
"; + + this.jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() { + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + + var document = batch.get(i); + var content = document.getContent(); + var json = toJson(document.getMetadata()); + var embedding = document.getEmbedding(); + var pGvector = new PGvector(embedding); + + StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN, + UUID.fromString(document.getId())); + StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content); + StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json); + StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector); + StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content); + StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json); + StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector); + } - @Override - public int getBatchSize() { - return size; - } - }); + @Override + public int getBatchSize() { + return batch.size(); + } + }); } private String toJson(Map map) { @@ -285,7 +298,7 @@ private String comparisonOperator() { // Initialize // --------------------------------------------------------------------------------- @Override - public void afterPropertiesSet() throws Exception { + public void afterPropertiesSet() { logger.info("Initializing PGVectorStore schema for table: {} in schema: {}", this.getVectorTableName(), this.getSchemaName()); @@ -390,7 +403,7 @@ public enum PgIndexType { * speed-recall tradeoff). There’s no training step like IVFFlat, so the index can * be created without any data in the table. 
*/ - HNSW; + HNSW } @@ -443,7 +456,7 @@ private static class DocumentRowMapper implements RowMapper { private static final String COLUMN_DISTANCE = "distance"; - private ObjectMapper objectMapper; + private final ObjectMapper objectMapper; public DocumentRowMapper(ObjectMapper objectMapper) { this.objectMapper = objectMapper; @@ -509,6 +522,8 @@ public static class Builder { private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy(); + private int maxDocumentBatchSize = MAX_DOCUMENT_BATCH_SIZE; + @Nullable private VectorStoreObservationConvention searchObservationConvention; @@ -576,11 +591,17 @@ public Builder withBatchingStrategy(BatchingStrategy batchingStrategy) { return this; } + public Builder withMaxDocumentBatchSize(int maxDocumentBatchSize) { + this.maxDocumentBatchSize = maxDocumentBatchSize; + return this; + } + public PgVectorStore build() { return new PgVectorStore(this.schemaName, this.vectorTableName, this.vectorTableValidationsEnabled, this.jdbcTemplate, this.embeddingModel, this.dimensions, this.distanceType, this.removeExistingVectorStoreTable, this.indexType, this.initializeSchema, - this.observationRegistry, this.searchObservationConvention, this.batchingStrategy); + this.observationRegistry, this.searchObservationConvention, this.batchingStrategy, + this.maxDocumentBatchSize); } } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java index f5b69a922c..488dbd3f73 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java @@ -15,15 +15,31 @@ */ package org.springframework.ai.vectorstore; +import org.junit.jupiter.api.Test; import 
org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; +import org.mockito.ArgumentCaptor; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.only; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.Collections; + +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.jdbc.core.BatchPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcTemplate; /** * @author Muthukumaran Navaneethakrishnan + * @author Soby Chacko */ - public class PgVectorStoreTests { @ParameterizedTest(name = "{0} - Verifies valid Table name") @@ -53,8 +69,39 @@ public class PgVectorStoreTests { // 64 // characters }) - public void isValidTable(String tableName, Boolean expected) { + void isValidTable(String tableName, Boolean expected) { assertThat(PgVectorSchemaValidator.isValidNameForDatabaseObject(tableName)).isEqualTo(expected); } + @Test + void shouldAddDocumentsInBatchesAndEmbedOnce() { + // Given + var jdbcTemplate = mock(JdbcTemplate.class); + var embeddingModel = mock(EmbeddingModel.class); + var pgVectorStore = new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withMaxDocumentBatchSize(1000) + .build(); + + // Testing with 9989 documents + var documents = Collections.nCopies(9989, new Document("foo")); + + // When + pgVectorStore.doAdd(documents); + + // Then + verify(embeddingModel, only()).embed(eq(documents), any(), any()); + + var batchUpdateCaptor = ArgumentCaptor.forClass(BatchPreparedStatementSetter.class); + verify(jdbcTemplate, times(10)).batchUpdate(anyString(), batchUpdateCaptor.capture()); + + assertThat(batchUpdateCaptor.getAllValues()).hasSize(10) + 
.allSatisfy(BatchPreparedStatementSetter::getBatchSize) + .satisfies(batches -> { + for (int i = 0; i < 9; i++) { + assertThat(batches.get(i).getBatchSize()).as("Batch at index %d should have size 1000", i) + .isEqualTo(1000); + } + assertThat(batches.get(9).getBatchSize()).as("Last batch should have size 989").isEqualTo(989); + }); + } + }