From 35e6113233824d061bfb8bf5059510aeca529787 Mon Sep 17 00:00:00 2001 From: Soby Chacko Date: Fri, 20 Sep 2024 14:38:58 -0400 Subject: [PATCH 1/9] Adding integration test for Azure custom headers --- .../azure/AzureOpenAiAutoConfigurationIT.java | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java index 11cd9409fb..c0a7245659 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.springframework.ai.autoconfigure.azure; import com.azure.ai.openai.OpenAIClient; @@ -95,22 +96,29 @@ void chatCompletion() { } @Test - void httpRequestContainsUserAgentHeader() { - contextRunner.run(context -> { - OpenAIClient openAIClient = context.getBean(OpenAIClient.class); - Field serviceClientField = ReflectionUtils.findField(OpenAIClient.class, "serviceClient"); - assertThat(serviceClientField).isNotNull(); - ReflectionUtils.makeAccessible(serviceClientField); - OpenAIClientImpl oaci = (OpenAIClientImpl) ReflectionUtils.getField(serviceClientField, openAIClient); - assertThat(oaci).isNotNull(); - HttpPipeline httpPipeline = oaci.getHttpPipeline(); - HttpResponse httpResponse = httpPipeline - .send(new HttpRequest(HttpMethod.POST, new URI(System.getenv("AZURE_OPENAI_ENDPOINT")).toURL())) - .block(); - assertThat(httpResponse).isNotNull(); - HttpHeader httpHeader = httpResponse.getRequest().getHeaders().get(HttpHeaderName.USER_AGENT); - assertThat(httpHeader.getValue().startsWith("spring-ai azsdk-java-azure-ai-openai/")).isTrue(); - }); + void httpRequestContainsUserAgentAndCustomHeaders() { + contextRunner + .withPropertyValues("spring.ai.azure.openai.custom-headers.foo=bar", + "spring.ai.azure.openai.custom-headers.fizz=buzz") + .run(context -> { + OpenAIClient openAIClient = context.getBean(OpenAIClient.class); + Field serviceClientField = ReflectionUtils.findField(OpenAIClient.class, "serviceClient"); + assertThat(serviceClientField).isNotNull(); + ReflectionUtils.makeAccessible(serviceClientField); + OpenAIClientImpl oaci = (OpenAIClientImpl) ReflectionUtils.getField(serviceClientField, openAIClient); + assertThat(oaci).isNotNull(); + HttpPipeline httpPipeline = oaci.getHttpPipeline(); + HttpResponse httpResponse = httpPipeline + .send(new HttpRequest(HttpMethod.POST, new URI(System.getenv("AZURE_OPENAI_ENDPOINT")).toURL())) + .block(); + assertThat(httpResponse).isNotNull(); + HttpHeader httpHeader = 
httpResponse.getRequest().getHeaders().get(HttpHeaderName.USER_AGENT); + assertThat(httpHeader.getValue().startsWith("spring-ai azsdk-java-azure-ai-openai/")).isTrue(); + HttpHeader customHeader1 = httpResponse.getRequest().getHeaders().get("foo"); + assertThat(customHeader1.getValue()).isEqualTo("bar"); + HttpHeader customHeader2 = httpResponse.getRequest().getHeaders().get("fizz"); + assertThat(customHeader2.getValue()).isEqualTo("buzz"); + }); } @Test From 42dcb45f32373d7c978d4cb7baad0b708a91eb67 Mon Sep 17 00:00:00 2001 From: Soby Chacko Date: Fri, 20 Sep 2024 17:09:28 -0400 Subject: [PATCH 2/9] Align AzureOpenAiChatOptions with Azure ChatCompletionsOptions Add missing options from Azure ChatCompletionsOptions to Spring AI AzureOpenAiChatOptions. The following fields have been added: - seed - logprobs - topLogprobs - enhancements This change ensures better alignment between the two option sets, improving compatibility and feature parity. Resolves https://github.com/spring-projects/spring-ai/issues/889 --- .../ai/azure/openai/AzureOpenAiChatModel.java | 45 ++++++++- .../azure/openai/AzureOpenAiChatOptions.java | 96 ++++++++++++++++++- .../AzureChatCompletionsOptionsTests.java | 28 +++++- 3 files changed, 166 insertions(+), 3 deletions(-) diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index bffdbe2c74..5829d4e18a 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; import java.util.ArrayList; @@ -92,6 +93,7 @@ * @author Thomas Vitale * @author luocongqiu * @author timostark + * @author Soby Chacko * @see ChatModel * @see com.azure.ai.openai.OpenAIClient */ @@ -456,6 +458,18 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromAzureOptions, mergedAzureOptions.setModel(fromAzureOptions.getModel() != null ? fromAzureOptions.getModel() : toSpringAiOptions.getDeploymentName()); + mergedAzureOptions + .setSeed(fromAzureOptions.getSeed() != null ? fromAzureOptions.getSeed() : toSpringAiOptions.getSeed()); + + mergedAzureOptions.setLogprobs((fromAzureOptions.isLogprobs() != null && fromAzureOptions.isLogprobs()) + || (toSpringAiOptions.isLogprobs() != null && toSpringAiOptions.isLogprobs())); + + mergedAzureOptions.setTopLogprobs(fromAzureOptions.getTopLogprobs() != null ? fromAzureOptions.getTopLogprobs() + : toSpringAiOptions.getTopLogProbs()); + + mergedAzureOptions.setEnhancements(fromAzureOptions.getEnhancements() != null + ? 
fromAzureOptions.getEnhancements() : toSpringAiOptions.getEnhancements()); + return mergedAzureOptions; } @@ -520,6 +534,22 @@ private ChatCompletionsOptions merge(AzureOpenAiChatOptions fromSpringAiOptions, mergedAzureOptions.setResponseFormat(toAzureResponseFormat(fromSpringAiOptions.getResponseFormat())); } + if (fromSpringAiOptions.getSeed() != null) { + mergedAzureOptions.setSeed(fromSpringAiOptions.getSeed()); + } + + if (fromSpringAiOptions.isLogprobs() != null) { + mergedAzureOptions.setLogprobs(fromSpringAiOptions.isLogprobs()); + } + + if (fromSpringAiOptions.getTopLogProbs() != null) { + mergedAzureOptions.setTopLogprobs(fromSpringAiOptions.getTopLogProbs()); + } + + if (fromSpringAiOptions.getEnhancements() != null) { + mergedAzureOptions.setEnhancements(fromSpringAiOptions.getEnhancements()); + } + return mergedAzureOptions; } @@ -566,6 +596,19 @@ private ChatCompletionsOptions copy(ChatCompletionsOptions fromOptions) { if (fromOptions.getResponseFormat() != null) { copyOptions.setResponseFormat(fromOptions.getResponseFormat()); } + if (fromOptions.getSeed() != null) { + copyOptions.setSeed(fromOptions.getSeed()); + } + + copyOptions.setLogprobs(fromOptions.isLogprobs()); + + if (fromOptions.getTopLogprobs() != null) { + copyOptions.setTopLogprobs(fromOptions.getTopLogprobs()); + } + + if (fromOptions.getEnhancements() != null) { + copyOptions.setEnhancements(fromOptions.getEnhancements()); + } return copyOptions; } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index 6b85eeb966..5faa64ebb1 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the 
original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.azure.openai; import java.util.ArrayList; @@ -21,6 +22,7 @@ import java.util.Map; import java.util.Set; +import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -40,6 +42,7 @@ * * @author Christian Tzolov * @author Thomas Vitale + * @author Soby Chacko */ @JsonInclude(Include.NON_NULL) public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptions { @@ -165,6 +168,37 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions, ChatOptio @JsonIgnore private Boolean proxyToolCalls; + /** + * Seed value for deterministic sampling such that the same seed and parameters return + * the same result. + */ + @JsonProperty(value = "seed") + private Long seed; + + /** + * Whether to return log probabilities of the output tokens or not. If true, returns + * the log probabilities of each output token returned in the `content` of `message`. + * This option is currently not available on the `gpt-4-vision-preview` model. + */ + @JsonProperty(value = "log_probs") + private Boolean logprobs; + + /* + * An integer between 0 and 5 specifying the number of most likely tokens to return at + * each token position, each with an associated log probability. `logprobs` must be + * set to `true` if this parameter is used. + */ + @JsonProperty(value = "top_log_probs") + private Integer topLogProbs; + + /* + * If provided, the configuration options for available Azure OpenAI chat + * enhancements. 
+ */ + @NestedConfigurationProperty + @JsonIgnore + private AzureChatEnhancementConfiguration enhancements; + public static Builder builder() { return new Builder(); } @@ -259,6 +293,30 @@ public Builder withProxyToolCalls(Boolean proxyToolCalls) { return this; } + public Builder withSeed(Long seed) { + Assert.notNull(seed, "seed must not be null"); + this.options.seed = seed; + return this; + } + + public Builder withLogprobs(Boolean logprobs) { + Assert.notNull(logprobs, "logprobs must not be null"); + this.options.logprobs = logprobs; + return this; + } + + public Builder withTopLogprobs(Integer topLogprobs) { + Assert.notNull(topLogprobs, "topLogprobs must not be null"); + this.options.topLogProbs = topLogprobs; + return this; + } + + public Builder withEnhancements(AzureChatEnhancementConfiguration enhancements) { + Assert.notNull(enhancements, "enhancements must not be null"); + this.options.enhancements = enhancements; + return this; + } + public AzureOpenAiChatOptions build() { return this.options; } @@ -404,6 +462,38 @@ public Integer getTopK() { return null; } + public Long getSeed() { + return this.seed; + } + + public void setSeed(Long seed) { + this.seed = seed; + } + + public Boolean isLogprobs() { + return this.logprobs; + } + + public void setLogprobs(Boolean logprobs) { + this.logprobs = logprobs; + } + + public Integer getTopLogProbs() { + return this.topLogProbs; + } + + public void setTopLogProbs(Integer topLogProbs) { + this.topLogProbs = topLogProbs; + } + + public AzureChatEnhancementConfiguration getEnhancements() { + return this.enhancements; + } + + public void setEnhancements(AzureChatEnhancementConfiguration enhancements) { + this.enhancements = enhancements; + } + @Override public Boolean getProxyToolCalls() { return this.proxyToolCalls; @@ -432,6 +522,10 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti .withFunctionCallbacks(fromOptions.getFunctionCallbacks()) 
.withFunctions(fromOptions.getFunctions()) .withResponseFormat(fromOptions.getResponseFormat()) + .withSeed(fromOptions.getSeed()) + .withLogprobs(fromOptions.isLogprobs()) + .withTopLogprobs(fromOptions.getTopLogProbs()) + .withEnhancements(fromOptions.getEnhancements()) .build(); } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java index 6e8d8bd531..f7edea989b 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureChatCompletionsOptionsTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,9 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.springframework.ai.azure.openai; import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; +import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration; import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat; import org.junit.jupiter.api.Test; @@ -34,6 +37,7 @@ /** * @author Christian Tzolov + * @author Soby Chacko */ public class AzureChatCompletionsOptionsTests { @@ -42,6 +46,9 @@ public void createRequestWithChatOptions() { OpenAIClient mockClient = Mockito.mock(OpenAIClient.class); + AzureChatEnhancementConfiguration mockAzureChatEnhancementConfiguration = Mockito + .mock(AzureChatEnhancementConfiguration.class); + var defaultOptions = AzureOpenAiChatOptions.builder() .withDeploymentName("DEFAULT_MODEL") .withTemperature(66.6) @@ -53,6 +60,10 @@ public void createRequestWithChatOptions() { .withStop(List.of("foo", "bar")) .withTopP(0.69) .withUser("user") + .withSeed(123L) + .withLogprobs(true) + .withTopLogprobs(5) + .withEnhancements(mockAzureChatEnhancementConfiguration) .withResponseFormat(AzureOpenAiResponseFormat.TEXT) .build(); @@ -72,8 +83,15 @@ public void createRequestWithChatOptions() { assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar")); assertThat(requestOptions.getTopP()).isEqualTo(0.69); assertThat(requestOptions.getUser()).isEqualTo("user"); + assertThat(requestOptions.getSeed()).isEqualTo(123L); + assertThat(requestOptions.isLogprobs()).isTrue(); + assertThat(requestOptions.getTopLogprobs()).isEqualTo(5); + assertThat(requestOptions.getEnhancements()).isEqualTo(mockAzureChatEnhancementConfiguration); assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsTextResponseFormat.class); + AzureChatEnhancementConfiguration anotherMockAzureChatEnhancementConfiguration = Mockito + .mock(AzureChatEnhancementConfiguration.class); + var runtimeOptions = 
AzureOpenAiChatOptions.builder() .withDeploymentName("PROMPT_MODEL") .withTemperature(99.9) @@ -85,6 +103,10 @@ public void createRequestWithChatOptions() { .withStop(List.of("foo", "bar")) .withTopP(0.111) .withUser("user2") + .withSeed(1234L) + .withLogprobs(true) + .withTopLogprobs(4) + .withEnhancements(anotherMockAzureChatEnhancementConfiguration) .withResponseFormat(AzureOpenAiResponseFormat.JSON) .build(); @@ -102,6 +124,10 @@ public void createRequestWithChatOptions() { assertThat(requestOptions.getStop()).isEqualTo(List.of("foo", "bar")); assertThat(requestOptions.getTopP()).isEqualTo(0.111); assertThat(requestOptions.getUser()).isEqualTo("user2"); + assertThat(requestOptions.getSeed()).isEqualTo(1234L); + assertThat(requestOptions.isLogprobs()).isTrue(); + assertThat(requestOptions.getTopLogprobs()).isEqualTo(4); + assertThat(requestOptions.getEnhancements()).isEqualTo(anotherMockAzureChatEnhancementConfiguration); assertThat(requestOptions.getResponseFormat()).isInstanceOf(ChatCompletionsJsonResponseFormat.class); } From 202148d45bf9c226a04768f7ff9836a89e0bee9c Mon Sep 17 00:00:00 2001 From: Soby Chacko Date: Mon, 23 Sep 2024 21:37:04 -0400 Subject: [PATCH 3/9] Prevent timeouts with configurable batching for PgVectorStore inserts Resolves https://github.com/spring-projects/spring-ai/issues/1199 - Implement configurable maxDocumentBatchSize to prevent insert timeouts when adding large numbers of documents - Update PgVectorStore to process document inserts in controlled batches - Add maxDocumentBatchSize property to PgVectorStoreProperties - Update PgVectorStoreAutoConfiguration to use the new batching property - Add tests to verify batching behavior and performance This change addresses the issue of PgVectorStore inserts timing out due to large document volumes. By introducing configurable batching, users can now control the insert process to avoid timeouts while maintaining performance and reducing memory overhead for large-scale document additions. 
--- .../PgVectorStoreAutoConfiguration.java | 1 + .../pgvector/PgVectorStoreProperties.java | 11 ++ .../ai/vectorstore/PgVectorStore.java | 121 ++++++++++-------- .../ai/vectorstore/PgVectorStoreTests.java | 51 +++++++- 4 files changed, 132 insertions(+), 52 deletions(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java index 1b5a62507a..ec4d76e074 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreAutoConfiguration.java @@ -71,6 +71,7 @@ public PgVectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embed .withObservationRegistry(observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)) .withSearchObservationConvention(customObservationConvention.getIfAvailable(() -> null)) .withBatchingStrategy(batchingStrategy) + .withMaxDocumentBatchSize(properties.getMaxDocumentBatchSize()) .build(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java index b455417461..47a12c36d3 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/pgvector/PgVectorStoreProperties.java @@ -24,6 +24,7 @@ /** * @author Christian Tzolov * @author Muthukumaran 
Navaneethakrishnan + * @author Soby Chacko */ @ConfigurationProperties(PgVectorStoreProperties.CONFIG_PREFIX) public class PgVectorStoreProperties extends CommonVectorStoreProperties { @@ -45,6 +46,8 @@ public class PgVectorStoreProperties extends CommonVectorStoreProperties { private boolean schemaValidation = PgVectorStore.DEFAULT_SCHEMA_VALIDATION; + private int maxDocumentBatchSize = PgVectorStore.MAX_DOCUMENT_BATCH_SIZE; + public int getDimensions() { return dimensions; } @@ -101,4 +104,12 @@ public void setSchemaValidation(boolean schemaValidation) { this.schemaValidation = schemaValidation; } + public int getMaxDocumentBatchSize() { + return this.maxDocumentBatchSize; + } + + public void setMaxDocumentBatchSize(int maxDocumentBatchSize) { + this.maxDocumentBatchSize = maxDocumentBatchSize; + } + } diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java index 8ab0546fc5..697960f15d 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/vectorstore/PgVectorStore.java @@ -15,14 +15,10 @@ */ package org.springframework.ai.vectorstore; -import java.sql.PreparedStatement; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; - +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.pgvector.PGvector; +import io.micrometer.observation.ObservationRegistry; import org.postgresql.util.PGobject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,11 +42,14 @@ import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; -import 
com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.pgvector.PGvector; - -import io.micrometer.observation.ObservationRegistry; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; /** * Uses the "vector_store" table to store the Spring AI vector data. The table and the @@ -81,6 +80,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini public final FilterExpressionConverter filterExpressionConverter = new PgVectorFilterExpressionConverter(); + public static final int MAX_DOCUMENT_BATCH_SIZE = 10_000; + private final String vectorTableName; private final String vectorIndexName; @@ -109,6 +110,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini private final BatchingStrategy batchingStrategy; + private final int maxDocumentBatchSize; + public PgVectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) { this(jdbcTemplate, embeddingModel, INVALID_EMBEDDING_DIMENSION, PgDistanceType.COSINE_DISTANCE, false, PgIndexType.NONE, false); @@ -132,7 +135,6 @@ public PgVectorStore(String vectorTableName, JdbcTemplate jdbcTemplate, Embeddin this(DEFAULT_SCHEMA_NAME, vectorTableName, DEFAULT_SCHEMA_VALIDATION, jdbcTemplate, embeddingModel, dimensions, distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema); - } private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled, @@ -141,14 +143,14 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT this(schemaName, vectorTableName, vectorTableValidationsEnabled, jdbcTemplate, embeddingModel, dimensions, distanceType, removeExistingVectorStoreTable, createIndexMethod, initializeSchema, - ObservationRegistry.NOOP, null, new 
TokenCountBatchingStrategy()); + ObservationRegistry.NOOP, null, new TokenCountBatchingStrategy(), MAX_DOCUMENT_BATCH_SIZE); } private PgVectorStore(String schemaName, String vectorTableName, boolean vectorTableValidationsEnabled, JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel, int dimensions, PgDistanceType distanceType, boolean removeExistingVectorStoreTable, PgIndexType createIndexMethod, boolean initializeSchema, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, - BatchingStrategy batchingStrategy) { + BatchingStrategy batchingStrategy, int maxDocumentBatchSize) { super(observationRegistry, customObservationConvention); @@ -172,6 +174,7 @@ private PgVectorStore(String schemaName, String vectorTableName, boolean vectorT this.initializeSchema = initializeSchema; this.schemaValidator = new PgVectorSchemaValidator(jdbcTemplate); this.batchingStrategy = batchingStrategy; + this.maxDocumentBatchSize = maxDocumentBatchSize; } public PgDistanceType getDistanceType() { @@ -180,40 +183,50 @@ public PgDistanceType getDistanceType() { @Override public void doAdd(List documents) { + this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); - int size = documents.size(); + List> batchedDocuments = batchDocuments(documents); + batchedDocuments.forEach(this::insertOrUpdateBatch); + } - this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy); + private List> batchDocuments(List documents) { + List> batches = new ArrayList<>(); + for (int i = 0; i < documents.size(); i += this.maxDocumentBatchSize) { + batches.add(documents.subList(i, Math.min(i + this.maxDocumentBatchSize, documents.size()))); + } + return batches; + } - this.jdbcTemplate.batchUpdate( - "INSERT INTO " + getFullyQualifiedTableName() - + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO " - + "UPDATE SET content = ? 
, metadata = ?::jsonb , embedding = ? ", - new BatchPreparedStatementSetter() { - @Override - public void setValues(PreparedStatement ps, int i) throws SQLException { - - var document = documents.get(i); - var content = document.getContent(); - var json = toJson(document.getMetadata()); - var embedding = document.getEmbedding(); - var pGvector = new PGvector(embedding); - - StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN, - UUID.fromString(document.getId())); - StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content); - StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json); - StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector); - StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content); - StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json); - StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector); - } + private void insertOrUpdateBatch(List batch) { + String sql = "INSERT INTO " + getFullyQualifiedTableName() + + " (id, content, metadata, embedding) VALUES (?, ?, ?::jsonb, ?) " + "ON CONFLICT (id) DO " + + "UPDATE SET content = ? , metadata = ?::jsonb , embedding = ? 
"; + + this.jdbcTemplate.batchUpdate(sql, new BatchPreparedStatementSetter() { + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + + var document = batch.get(i); + var content = document.getContent(); + var json = toJson(document.getMetadata()); + var embedding = document.getEmbedding(); + var pGvector = new PGvector(embedding); + + StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN, + UUID.fromString(document.getId())); + StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content); + StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json); + StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector); + StatementCreatorUtils.setParameterValue(ps, 5, SqlTypeValue.TYPE_UNKNOWN, content); + StatementCreatorUtils.setParameterValue(ps, 6, SqlTypeValue.TYPE_UNKNOWN, json); + StatementCreatorUtils.setParameterValue(ps, 7, SqlTypeValue.TYPE_UNKNOWN, pGvector); + } - @Override - public int getBatchSize() { - return size; - } - }); + @Override + public int getBatchSize() { + return batch.size(); + } + }); } private String toJson(Map map) { @@ -285,7 +298,7 @@ private String comparisonOperator() { // Initialize // --------------------------------------------------------------------------------- @Override - public void afterPropertiesSet() throws Exception { + public void afterPropertiesSet() { logger.info("Initializing PGVectorStore schema for table: {} in schema: {}", this.getVectorTableName(), this.getSchemaName()); @@ -390,7 +403,7 @@ public enum PgIndexType { * speed-recall tradeoff). There’s no training step like IVFFlat, so the index can * be created without any data in the table. 
*/ - HNSW; + HNSW } @@ -443,7 +456,7 @@ private static class DocumentRowMapper implements RowMapper { private static final String COLUMN_DISTANCE = "distance"; - private ObjectMapper objectMapper; + private final ObjectMapper objectMapper; public DocumentRowMapper(ObjectMapper objectMapper) { this.objectMapper = objectMapper; @@ -509,6 +522,8 @@ public static class Builder { private BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy(); + private int maxDocumentBatchSize = MAX_DOCUMENT_BATCH_SIZE; + @Nullable private VectorStoreObservationConvention searchObservationConvention; @@ -576,11 +591,17 @@ public Builder withBatchingStrategy(BatchingStrategy batchingStrategy) { return this; } + public Builder withMaxDocumentBatchSize(int maxDocumentBatchSize) { + this.maxDocumentBatchSize = maxDocumentBatchSize; + return this; + } + public PgVectorStore build() { return new PgVectorStore(this.schemaName, this.vectorTableName, this.vectorTableValidationsEnabled, this.jdbcTemplate, this.embeddingModel, this.dimensions, this.distanceType, this.removeExistingVectorStoreTable, this.indexType, this.initializeSchema, - this.observationRegistry, this.searchObservationConvention, this.batchingStrategy); + this.observationRegistry, this.searchObservationConvention, this.batchingStrategy, + this.maxDocumentBatchSize); } } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java index f5b69a922c..488dbd3f73 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreTests.java @@ -15,15 +15,31 @@ */ package org.springframework.ai.vectorstore; +import org.junit.jupiter.api.Test; import 
org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; +import org.mockito.ArgumentCaptor; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.only; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.Collections; + +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.jdbc.core.BatchPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcTemplate; /** * @author Muthukumaran Navaneethakrishnan + * @author Soby Chacko */ - public class PgVectorStoreTests { @ParameterizedTest(name = "{0} - Verifies valid Table name") @@ -53,8 +69,39 @@ public class PgVectorStoreTests { // 64 // characters }) - public void isValidTable(String tableName, Boolean expected) { + void isValidTable(String tableName, Boolean expected) { assertThat(PgVectorSchemaValidator.isValidNameForDatabaseObject(tableName)).isEqualTo(expected); } + @Test + void shouldAddDocumentsInBatchesAndEmbedOnce() { + // Given + var jdbcTemplate = mock(JdbcTemplate.class); + var embeddingModel = mock(EmbeddingModel.class); + var pgVectorStore = new PgVectorStore.Builder(jdbcTemplate, embeddingModel).withMaxDocumentBatchSize(1000) + .build(); + + // Testing with 9989 documents + var documents = Collections.nCopies(9989, new Document("foo")); + + // When + pgVectorStore.doAdd(documents); + + // Then + verify(embeddingModel, only()).embed(eq(documents), any(), any()); + + var batchUpdateCaptor = ArgumentCaptor.forClass(BatchPreparedStatementSetter.class); + verify(jdbcTemplate, times(10)).batchUpdate(anyString(), batchUpdateCaptor.capture()); + + assertThat(batchUpdateCaptor.getAllValues()).hasSize(10) + 
.allSatisfy(BatchPreparedStatementSetter::getBatchSize) + .satisfies(batches -> { + for (int i = 0; i < 9; i++) { + assertThat(batches.get(i).getBatchSize()).as("Batch at index %d should have size 1000", i) + .isEqualTo(1000); + } + assertThat(batches.get(9).getBatchSize()).as("Last batch should have size 989").isEqualTo(989); + }); + } + } From 2ecffc10c73404e7d1512c12d7fdaa30443280b5 Mon Sep 17 00:00:00 2001 From: Fu Cheng Date: Sat, 21 Sep 2024 14:14:56 +1200 Subject: [PATCH 4/9] Refactor data filtering in RelevancyEvaluator Replace Objects::nonNull and instanceof checks with StringUtils::hasText for more efficient and cleaner content filtering. This change simplifies the stream operation in the doGetSupportingData method, improving readability and potentially performance. --- .../springframework/ai/evaluation/RelevancyEvaluator.java | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/RelevancyEvaluator.java b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/RelevancyEvaluator.java index 2f26e97689..9dd7505371 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/evaluation/RelevancyEvaluator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/evaluation/RelevancyEvaluator.java @@ -5,8 +5,8 @@ import java.util.Collections; import java.util.List; -import java.util.Objects; import java.util.stream.Collectors; +import org.springframework.util.StringUtils; public class RelevancyEvaluator implements Evaluator { @@ -57,9 +57,7 @@ protected String doGetSupportingData(EvaluationRequest evaluationRequest) { List data = evaluationRequest.getDataList(); return data.stream() .map(Content::getContent) - .filter(Objects::nonNull) - .filter(c -> c instanceof String) - .map(Object::toString) + .filter(StringUtils::hasText) .collect(Collectors.joining(System.lineSeparator())); } From 835450761e7d961b51181a36d6cf6bd26fa61f0e Mon Sep 17 00:00:00 2001 From: Ricken Bazolo 
Date: Sun, 22 Sep 2024 03:14:33 +0200 Subject: [PATCH 5/9] handle http client error exception on getCollection method --- .../springframework/ai/chroma/ChromaApi.java | 33 +++++++++++++++---- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java index 6ce419d914..9932b1fe89 100644 --- a/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java +++ b/vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/chroma/ChromaApi.java @@ -31,7 +31,9 @@ import org.springframework.http.client.support.BasicAuthenticationInterceptor; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; +import org.springframework.web.client.HttpClientErrorException; import org.springframework.web.client.HttpServerErrorException; +import org.springframework.web.client.HttpStatusCodeException; import org.springframework.web.client.RestClient; import com.fasterxml.jackson.annotation.JsonProperty; @@ -49,6 +51,9 @@ public class ChromaApi { // Regular expression pattern that looks for a message inside the ValueError(...). private static Pattern VALUE_ERROR_PATTERN = Pattern.compile("ValueError\\('([^']*)'\\)"); + // Regular expression pattern that looks for a message. 
+ private static Pattern MESSAGE_ERROR_PATTERN = Pattern.compile("\"message\":\"(.*?)\""); + private RestClient restClient; private final ObjectMapper objectMapper; @@ -316,8 +321,8 @@ public Collection getCollection(String collectionName) { .toEntity(Collection.class) .getBody(); } - catch (HttpServerErrorException e) { - String msg = this.getValueErrorMessage(e.getMessage()); + catch (HttpServerErrorException | HttpClientErrorException e) { + String msg = this.getErrorMessage(e); if (String.format("Collection %s does not exist.", collectionName).equals(msg)) { return null; } @@ -413,12 +418,28 @@ private void httpHeaders(HttpHeaders headers) { } } - private String getValueErrorMessage(String logString) { - if (!StringUtils.hasText(logString)) { + private String getErrorMessage(HttpStatusCodeException e) { + var errorMessage = e.getMessage(); + + // If the error message is empty or null, return an empty string + if (!StringUtils.hasText(errorMessage)) { return ""; } - Matcher m = VALUE_ERROR_PATTERN.matcher(logString); - return (m.find()) ? m.group(1) : ""; + + // If the exception is an HttpServerErrorException, use the VALUE_ERROR_PATTERN + Matcher valueErrorMatcher = VALUE_ERROR_PATTERN.matcher(errorMessage); + if (e instanceof HttpServerErrorException && valueErrorMatcher.find()) { + return valueErrorMatcher.group(1); + } + + // Otherwise, use the MESSAGE_ERROR_PATTERN for other cases + Matcher messageErrorMatcher = MESSAGE_ERROR_PATTERN.matcher(errorMessage); + if (messageErrorMatcher.find()) { + return messageErrorMatcher.group(1); + } + + // If no pattern matches, return an empty string + return ""; } } From 1673907db09a6cff72439ce62f87450d4deac977 Mon Sep 17 00:00:00 2001 From: dafriz Date: Sun, 22 Sep 2024 21:53:20 +1000 Subject: [PATCH 6/9] Add support for reasoning tokens in OpenAI usage data This change introduces a new field for tracking reasoning tokens in the OpenAI API response. 
It extends the Usage record to include CompletionTokenDetails, allowing for more granular token usage reporting. The OpenAiUsage class is updated to expose this new data, and corresponding unit tests are added to verify the behavior. This enhancement provides more detailed insights into token usage, particularly for advanced AI models that separate reasoning from other generation processes. --- .../ai/openai/api/OpenAiApi.java | 18 +++++++++++++- .../ai/openai/metadata/OpenAiUsage.java | 6 +++++ .../ai/openai/metadata/OpenAiUsageTests.java | 24 +++++++++++++++++++ 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index b88f795063..4eb00c374e 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -936,12 +936,28 @@ public record TopLogProbs(// @formatter:off * @param promptTokens Number of tokens in the prompt. * @param totalTokens Total number of tokens used in the request (prompt + * completion). 
+ * @param completionTokenDetails Breakdown of tokens used in a completion */ @JsonInclude(Include.NON_NULL) public record Usage(// @formatter:off @JsonProperty("completion_tokens") Integer completionTokens, @JsonProperty("prompt_tokens") Integer promptTokens, - @JsonProperty("total_tokens") Integer totalTokens) {// @formatter:on + @JsonProperty("total_tokens") Integer totalTokens, + @JsonProperty("completion_tokens_details") CompletionTokenDetails completionTokenDetails) {// @formatter:on + + public Usage(Integer completionTokens, Integer promptTokens, Integer totalTokens) { + this(completionTokens, promptTokens, totalTokens, null); + } + + /** + * Breakdown of tokens used in a completion + * + * @param reasoningTokens Number of tokens generated by the model for reasoning. + */ + @JsonInclude(Include.NON_NULL) + public record CompletionTokenDetails(// @formatter:off + @JsonProperty("reasoning_tokens") Integer reasoningTokens) {// @formatter:on + } } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java index 821a0325d0..add5d896b5 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/metadata/OpenAiUsage.java @@ -58,6 +58,12 @@ public Long getGenerationTokens() { return generationTokens != null ? generationTokens.longValue() : 0; } + public Long getReasoningTokens() { + OpenAiApi.Usage.CompletionTokenDetails completionTokenDetails = getUsage().completionTokenDetails(); + Integer reasoningTokens = completionTokenDetails != null ? completionTokenDetails.reasoningTokens() : null; + return reasoningTokens != null ? 
reasoningTokens.longValue() : 0; + } + @Override public Long getTotalTokens() { Integer totalTokens = getUsage().totalTokens(); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java index 58c378f35b..b9215b4c3d 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/metadata/OpenAiUsageTests.java @@ -54,4 +54,28 @@ void whenTotalTokensIsNull() { assertThat(usage.getTotalTokens()).isEqualTo(300); } + @Test + void whenCompletionTokenDetailsIsNull() { + OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, null); + OpenAiUsage usage = OpenAiUsage.from(openAiUsage); + assertThat(usage.getTotalTokens()).isEqualTo(300); + assertThat(usage.getReasoningTokens()).isEqualTo(0); + } + + @Test + void whenReasoningTokensIsNull() { + OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, + new OpenAiApi.Usage.CompletionTokenDetails(null)); + OpenAiUsage usage = OpenAiUsage.from(openAiUsage); + assertThat(usage.getReasoningTokens()).isEqualTo(0); + } + + @Test + void whenCompletionTokenDetailsIsPresent() { + OpenAiApi.Usage openAiUsage = new OpenAiApi.Usage(100, 200, 300, + new OpenAiApi.Usage.CompletionTokenDetails(50)); + OpenAiUsage usage = OpenAiUsage.from(openAiUsage); + assertThat(usage.getReasoningTokens()).isEqualTo(50); + } + } From c205c7d5cad5d6278be9c56d3babe6ed9af87ae9 Mon Sep 17 00:00:00 2001 From: Fu Cheng Date: Mon, 23 Sep 2024 13:17:58 +1200 Subject: [PATCH 7/9] Fix interleaved output in JsonReader's parseJsonNode method Replace parallelStream with stream to prevent thread-unsafe appends to the shared StringBuilder. This fixes the issue of intermingled key-value pairs in the generated Document content. 
Also, replace StringBuffer with StringBuilder for better performance in single-threaded context. The change ensures correct ordering of extracted JSON keys and their values in the resulting Document, improving the reliability and readability of the parsed output. --- .../org/springframework/ai/reader/JsonReader.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonReader.java b/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonReader.java index 895feab7ba..b00eace87f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonReader.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/reader/JsonReader.java @@ -42,19 +42,19 @@ */ public class JsonReader implements DocumentReader { - private Resource resource; + private final Resource resource; - private JsonMetadataGenerator jsonMetadataGenerator; + private final JsonMetadataGenerator jsonMetadataGenerator; private final ObjectMapper objectMapper = new ObjectMapper(); /** * The key from the JSON that we will use as the text to parse into the Document text */ - private List jsonKeysToUse; + private final List jsonKeysToUse; public JsonReader(Resource resource) { - this(resource, new ArrayList<>().toArray(new String[0])); + this(resource, new String[0]); } public JsonReader(Resource resource, String... 
jsonKeysToUse) { @@ -92,9 +92,9 @@ public List get() { private Document parseJsonNode(JsonNode jsonNode, ObjectMapper objectMapper) { Map item = objectMapper.convertValue(jsonNode, new TypeReference>() { }); - StringBuffer sb = new StringBuffer(); + var sb = new StringBuilder(); - jsonKeysToUse.parallelStream().filter(item::containsKey).forEach(key -> { + jsonKeysToUse.stream().filter(item::containsKey).forEach(key -> { sb.append(key).append(": ").append(item.get(key)).append(System.lineSeparator()); }); From ee00c620c62fb5f86d6ba418b35869c260fced25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edd=C3=BA=20Mel=C3=A9ndez?= Date: Mon, 23 Sep 2024 13:47:51 -0600 Subject: [PATCH 8/9] Prepend 'http://' to host in ChromaConnectionDetails Modify getHost method to return a properly formatted URL string. This ensures that the Chroma client can correctly connect to the service when using Docker Compose. Fixes #1395 --- .../chroma/ChromaDockerComposeConnectionDetailsFactory.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactory.java b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactory.java index c863b1ec48..b861bea6ad 100644 --- a/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactory.java +++ b/spring-ai-spring-boot-docker-compose/src/main/java/org/springframework/ai/docker/compose/service/connection/chroma/ChromaDockerComposeConnectionDetailsFactory.java @@ -60,7 +60,7 @@ static class ChromaDockerComposeConnectionDetails extends DockerComposeConnectio @Override public String getHost() { - return this.host; + return "http://%s".formatted(this.host); } @Override From 
c5f07e540299cf4dbd314027416dd5ad9966a398 Mon Sep 17 00:00:00 2001 From: ashni <105304831+ashni-mongodb@users.noreply.github.com> Date: Tue, 24 Sep 2024 11:49:26 -0400 Subject: [PATCH 9/9] Enhance MongoDB docs with additional tutorials Update mongodb.adoc to include links to both beginner and intermediate content for Spring AI and MongoDB integration. Add a new section "Tutorials and Code Examples" with: - Link to the Getting Started guide for basic integration - Link to a detailed RAG tutorial for more advanced usage This change provides users with a clear path from initial setup to more complex implementations using the MongoDB Atlas Vector Store. --- .../antora/modules/ROOT/pages/api/vectordbs/mongodb.adoc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mongodb.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mongodb.adoc index 895e5eda84..359160ad1a 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mongodb.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/mongodb.adoc @@ -256,5 +256,8 @@ List results = vectorStore.similaritySearch( ); ---- +== Tutorials and Code Examples +To get started with Spring AI and MongoDB: -If you would like to try out Spring AI with MongoDB, see https://www.mongodb.com/docs/atlas/atlas-vector-search/ai-integrations/spring-ai/#std-label-spring-ai[Get Started with the Spring AI Integration]. +* See the https://www.mongodb.com/docs/atlas/atlas-vector-search/ai-integrations/spring-ai/#std-label-spring-ai[Getting Started guide for Spring AI Integration]. +* For a comprehensive code example demonstrating Retrieval Augmented Generation (RAG) with Spring AI and MongoDB, refer to this https://www.mongodb.com/developer/languages/java/retrieval-augmented-generation-spring-ai/[detailed tutorial].