-
Notifications
You must be signed in to change notification settings - Fork 736
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add OCI GenAI embedding model support
This commit introduces support for Oracle Cloud Infrastructure (OCI) GenAI embedding models in Spring AI. It includes: * New OCIEmbeddingModel class for interacting with OCI GenAI API * Auto-configuration for easy setup and integration * Properties for configuring OCI connection and embedding options * Documentation updates explaining usage and configuration * Integration tests to verify functionality Signed-off-by: Anders Swanson <anders.swanson@oracle.com>
- Loading branch information
1 parent
5da44c4
commit ccf190c
Showing
20 changed files
with
1,158 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
[Oracle Cloud Infrastructure GenAI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
<?xml version="1.0" encoding="UTF-8"?> | ||
<project xmlns="http://maven.apache.org/POM/4.0.0" | ||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> | ||
<modelVersion>4.0.0</modelVersion> | ||
<parent> | ||
<groupId>org.springframework.ai</groupId> | ||
<artifactId>spring-ai</artifactId> | ||
<version>1.0.0-SNAPSHOT</version> | ||
<relativePath>../../pom.xml</relativePath> | ||
</parent> | ||
<artifactId>spring-ai-oci-genai</artifactId> | ||
<packaging>jar</packaging> | ||
<name>Spring AI Model - OCI GenAI</name> | ||
<description>OCI GenAI models support</description> | ||
<url>https://github.com/spring-projects/spring-ai</url> | ||
|
||
<scm> | ||
<url>https://github.com/spring-projects/spring-ai</url> | ||
<connection>git://github.com/spring-projects/spring-ai.git</connection> | ||
<developerConnection>git@github.com:spring-projects/spring-ai.git</developerConnection> | ||
</scm> | ||
|
||
<dependencies> | ||
|
||
<!-- production dependencies --> | ||
<dependency> | ||
<groupId>org.springframework.ai</groupId> | ||
<artifactId>spring-ai-core</artifactId> | ||
<version>${project.parent.version}</version> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>com.oracle.oci.sdk</groupId> | ||
<artifactId>oci-java-sdk-shaded-full</artifactId> | ||
<version>${oci-sdk-version}</version> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>com.oracle.oci.sdk</groupId> | ||
<artifactId>oci-java-sdk-addons-oke-workload-identity</artifactId> | ||
<version>${oci-sdk-version}</version> | ||
</dependency> | ||
|
||
<!-- NOTE: Required only by the @ConstructorBinding. --> | ||
<dependency> | ||
<groupId>org.springframework.boot</groupId> | ||
<artifactId>spring-boot</artifactId> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.springframework</groupId> | ||
<artifactId>spring-context-support</artifactId> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.springframework.boot</groupId> | ||
<artifactId>spring-boot-starter-logging</artifactId> | ||
</dependency> | ||
|
||
<!-- test dependencies --> | ||
<dependency> | ||
<groupId>org.springframework.ai</groupId> | ||
<artifactId>spring-ai-test</artifactId> | ||
<version>${project.version}</version> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
</dependencies> | ||
|
||
</project> |
177 changes: 177 additions & 0 deletions
177
models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
/* | ||
* Copyright 2024 the original author or authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* https://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.springframework.ai.oci; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Objects; | ||
import java.util.concurrent.atomic.AtomicInteger; | ||
|
||
import com.oracle.bmc.generativeaiinference.GenerativeAiInference; | ||
import com.oracle.bmc.generativeaiinference.model.DedicatedServingMode; | ||
import com.oracle.bmc.generativeaiinference.model.EmbedTextDetails; | ||
import com.oracle.bmc.generativeaiinference.model.EmbedTextResult; | ||
import com.oracle.bmc.generativeaiinference.model.OnDemandServingMode; | ||
import com.oracle.bmc.generativeaiinference.model.ServingMode; | ||
import com.oracle.bmc.generativeaiinference.requests.EmbedTextRequest; | ||
import io.micrometer.observation.ObservationRegistry; | ||
import org.springframework.ai.chat.metadata.EmptyUsage; | ||
import org.springframework.ai.document.Document; | ||
import org.springframework.ai.embedding.AbstractEmbeddingModel; | ||
import org.springframework.ai.embedding.Embedding; | ||
import org.springframework.ai.embedding.EmbeddingOptions; | ||
import org.springframework.ai.embedding.EmbeddingRequest; | ||
import org.springframework.ai.embedding.EmbeddingResponse; | ||
import org.springframework.ai.embedding.EmbeddingResponseMetadata; | ||
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; | ||
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext; | ||
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; | ||
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation; | ||
import org.springframework.ai.model.ModelOptionsUtils; | ||
import org.springframework.ai.observation.conventions.AiProvider; | ||
import org.springframework.util.Assert; | ||
|
||
/** | ||
* {@link org.springframework.ai.embedding.EmbeddingModel} implementation that uses the | ||
* OCI GenAI Embedding API. | ||
* | ||
* @author Anders Swanson | ||
* @since 1.0.0 | ||
*/ | ||
public class OCIEmbeddingModel extends AbstractEmbeddingModel { | ||
|
||
// The OCI GenAI API has a batch size of 96 for embed text requests. | ||
private static final int EMBEDTEXT_BATCH_SIZE = 96; | ||
|
||
private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); | ||
|
||
private final GenerativeAiInference genAi; | ||
|
||
private final OCIEmbeddingOptions options; | ||
|
||
private final ObservationRegistry observationRegistry; | ||
|
||
private final EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; | ||
|
||
public OCIEmbeddingModel(GenerativeAiInference genAi, OCIEmbeddingOptions options) { | ||
this(genAi, options, ObservationRegistry.NOOP); | ||
} | ||
|
||
public OCIEmbeddingModel(GenerativeAiInference genAi, OCIEmbeddingOptions options, | ||
ObservationRegistry observationRegistry) { | ||
Assert.notNull(genAi, "com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient must not be null"); | ||
Assert.notNull(options, "options must not be null"); | ||
Assert.notNull(observationRegistry, "observationRegistry must not be null"); | ||
this.genAi = genAi; | ||
this.options = options; | ||
this.observationRegistry = observationRegistry; | ||
} | ||
|
||
@Override | ||
public EmbeddingResponse call(EmbeddingRequest request) { | ||
Assert.notEmpty(request.getInstructions(), "At least one text is required!"); | ||
OCIEmbeddingOptions runtimeOptions = mergeOptions(request.getOptions(), options); | ||
List<EmbedTextRequest> embedTextRequests = createRequests(request.getInstructions(), runtimeOptions); | ||
|
||
EmbeddingModelObservationContext context = EmbeddingModelObservationContext.builder() | ||
.embeddingRequest(request) | ||
.provider(AiProvider.OCI_GENAI.value()) | ||
.requestOptions(runtimeOptions) | ||
.build(); | ||
|
||
return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION | ||
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> context, | ||
this.observationRegistry) | ||
.observe(() -> embedAllWithContext(embedTextRequests, context)); | ||
} | ||
|
||
@Override | ||
public float[] embed(Document document) { | ||
return embed(document.getContent()); | ||
} | ||
|
||
private EmbeddingResponse embedAllWithContext(List<EmbedTextRequest> embedTextRequests, | ||
EmbeddingModelObservationContext context) { | ||
String modelId = null; | ||
AtomicInteger index = new AtomicInteger(0); | ||
List<Embedding> embeddings = new ArrayList<>(); | ||
for (EmbedTextRequest embedTextRequest : embedTextRequests) { | ||
EmbedTextResult embedTextResult = genAi.embedText(embedTextRequest).getEmbedTextResult(); | ||
if (modelId == null) { | ||
modelId = embedTextResult.getModelId(); | ||
} | ||
for (List<Float> e : embedTextResult.getEmbeddings()) { | ||
float[] data = toFloats(e); | ||
embeddings.add(new Embedding(data, index.getAndIncrement())); | ||
} | ||
} | ||
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata(); | ||
metadata.setModel(modelId); | ||
metadata.setUsage(new EmptyUsage()); | ||
EmbeddingResponse embeddingResponse = new EmbeddingResponse(embeddings, metadata); | ||
context.setResponse(embeddingResponse); | ||
return embeddingResponse; | ||
} | ||
|
||
private ServingMode servingMode(OCIEmbeddingOptions embeddingOptions) { | ||
return switch (embeddingOptions.getServingMode()) { | ||
case "dedicated" -> DedicatedServingMode.builder().endpointId(embeddingOptions.getModel()).build(); | ||
case "on-demand" -> OnDemandServingMode.builder().modelId(embeddingOptions.getModel()).build(); | ||
default -> throw new IllegalArgumentException( | ||
"unknown serving mode for OCI embedding model: " + embeddingOptions.getServingMode()); | ||
}; | ||
} | ||
|
||
private List<EmbedTextRequest> createRequests(List<String> inputs, OCIEmbeddingOptions embeddingOptions) { | ||
int size = inputs.size(); | ||
List<EmbedTextRequest> requests = new ArrayList<>(); | ||
for (int i = 0; i < inputs.size(); i += EMBEDTEXT_BATCH_SIZE) { | ||
List<String> batch = inputs.subList(i, Math.min(i + EMBEDTEXT_BATCH_SIZE, size)); | ||
requests.add(createRequest(batch, embeddingOptions)); | ||
} | ||
return requests; | ||
} | ||
|
||
private EmbedTextRequest createRequest(List<String> inputs, OCIEmbeddingOptions embeddingOptions) { | ||
EmbedTextDetails embedTextDetails = EmbedTextDetails.builder() | ||
.servingMode(servingMode(embeddingOptions)) | ||
.compartmentId(embeddingOptions.getCompartment()) | ||
.inputs(inputs) | ||
.truncate(Objects.requireNonNullElse(embeddingOptions.getTruncate(), EmbedTextDetails.Truncate.End)) | ||
.build(); | ||
return EmbedTextRequest.builder().embedTextDetails(embedTextDetails).build(); | ||
} | ||
|
||
private OCIEmbeddingOptions mergeOptions(EmbeddingOptions embeddingOptions, OCIEmbeddingOptions defaultOptions) { | ||
if (embeddingOptions instanceof OCIEmbeddingOptions) { | ||
OCIEmbeddingOptions dynamicOptions = ModelOptionsUtils.merge(embeddingOptions, defaultOptions, | ||
OCIEmbeddingOptions.class); | ||
if (dynamicOptions != null) { | ||
return dynamicOptions; | ||
} | ||
} | ||
return defaultOptions; | ||
} | ||
|
||
private float[] toFloats(List<Float> embedding) { | ||
float[] floats = new float[embedding.size()]; | ||
for (int i = 0; i < embedding.size(); i++) { | ||
floats[i] = embedding.get(i); | ||
} | ||
return floats; | ||
} | ||
|
||
} |
114 changes: 114 additions & 0 deletions
114
models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/OCIEmbeddingOptions.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/* | ||
* Copyright 2024 the original author or authors. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* https://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.springframework.ai.oci; | ||
|
||
import com.fasterxml.jackson.annotation.JsonInclude; | ||
import com.fasterxml.jackson.annotation.JsonProperty; | ||
import com.oracle.bmc.generativeaiinference.model.EmbedTextDetails; | ||
import org.springframework.ai.embedding.EmbeddingOptions; | ||
|
||
/** | ||
* The configuration information for OCI embedding requests | ||
* | ||
* @author Anders Swanson | ||
*/ | ||
@JsonInclude(JsonInclude.Include.NON_NULL) | ||
public class OCIEmbeddingOptions implements EmbeddingOptions { | ||
|
||
private @JsonProperty("model") String model; | ||
|
||
private @JsonProperty("compartment") String compartment; | ||
|
||
private @JsonProperty("servingMode") String servingMode; | ||
|
||
private @JsonProperty("truncate") EmbedTextDetails.Truncate truncate; | ||
|
||
public static Builder builder() { | ||
return new Builder(); | ||
} | ||
|
||
public static class Builder { | ||
|
||
private final OCIEmbeddingOptions options = new OCIEmbeddingOptions(); | ||
|
||
public Builder withModel(String model) { | ||
this.options.setModel(model); | ||
return this; | ||
} | ||
|
||
public Builder withCompartment(String compartment) { | ||
this.options.setCompartment(compartment); | ||
return this; | ||
} | ||
|
||
public Builder withServingMode(String servingMode) { | ||
this.options.setServingMode(servingMode); | ||
return this; | ||
} | ||
|
||
public Builder withTruncate(EmbedTextDetails.Truncate truncate) { | ||
this.options.truncate = truncate; | ||
return this; | ||
} | ||
|
||
public OCIEmbeddingOptions build() { | ||
return this.options; | ||
} | ||
|
||
} | ||
|
||
public String getModel() { | ||
return this.model; | ||
} | ||
|
||
/** | ||
* Not used by OCI GenAI. | ||
* @return null | ||
*/ | ||
@Override | ||
public Integer getDimensions() { | ||
return null; | ||
} | ||
|
||
public void setModel(String model) { | ||
this.model = model; | ||
} | ||
|
||
public String getCompartment() { | ||
return compartment; | ||
} | ||
|
||
public void setCompartment(String compartment) { | ||
this.compartment = compartment; | ||
} | ||
|
||
public String getServingMode() { | ||
return servingMode; | ||
} | ||
|
||
public void setServingMode(String servingMode) { | ||
this.servingMode = servingMode; | ||
} | ||
|
||
public EmbedTextDetails.Truncate getTruncate() { | ||
return truncate; | ||
} | ||
|
||
public void setTruncate(EmbedTextDetails.Truncate truncate) { | ||
this.truncate = truncate; | ||
} | ||
|
||
} |
Oops, something went wrong.