Skip to content

Commit

Permalink
Add OCI GenAI embedding model support
Browse files Browse the repository at this point in the history
This commit introduces support for Oracle Cloud Infrastructure (OCI)
GenAI embedding models in Spring AI. It includes:

* New OCIEmbeddingModel class for interacting with OCI GenAI API
* Auto-configuration for easy setup and integration
* Properties for configuring OCI connection and embedding options
* Documentation updates explaining usage and configuration
* Integration tests to verify functionality

Signed-off-by: Anders Swanson <anders.swanson@oracle.com>
  • Loading branch information
anders-swanson authored and Mark Pollack committed Sep 26, 2024
1 parent 5da44c4 commit ccf190c
Show file tree
Hide file tree
Showing 20 changed files with 1,158 additions and 2 deletions.
1 change: 1 addition & 0 deletions models/spring-ai-oci-genai/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[Oracle Cloud Infrastructure GenAI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm)
69 changes: 69 additions & 0 deletions models/spring-ai-oci-genai/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai</artifactId>
<version>1.0.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
</parent>
<artifactId>spring-ai-oci-genai</artifactId>
<packaging>jar</packaging>
<name>Spring AI Model - OCI GenAI</name>
<description>OCI GenAI models support</description>
<url>https://github.com/spring-projects/spring-ai</url>

<scm>
<url>https://github.com/spring-projects/spring-ai</url>
<connection>git://github.com/spring-projects/spring-ai.git</connection>
<developerConnection>git@github.com:spring-projects/spring-ai.git</developerConnection>
</scm>

<dependencies>

<!-- production dependencies -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-core</artifactId>
<version>${project.parent.version}</version>
</dependency>

<dependency>
<groupId>com.oracle.oci.sdk</groupId>
<artifactId>oci-java-sdk-shaded-full</artifactId>
<version>${oci-sdk-version}</version>
</dependency>

<dependency>
<groupId>com.oracle.oci.sdk</groupId>
<artifactId>oci-java-sdk-addons-oke-workload-identity</artifactId>
<version>${oci-sdk-version}</version>
</dependency>

<!-- NOTE: Required only by the @ConstructorBinding. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot</artifactId>
</dependency>

<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-context-support</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-logging</artifactId>
</dependency>

<!-- test dependencies -->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-test</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>

</dependencies>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.oci;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;

import com.oracle.bmc.generativeaiinference.GenerativeAiInference;
import com.oracle.bmc.generativeaiinference.model.DedicatedServingMode;
import com.oracle.bmc.generativeaiinference.model.EmbedTextDetails;
import com.oracle.bmc.generativeaiinference.model.EmbedTextResult;
import com.oracle.bmc.generativeaiinference.model.OnDemandServingMode;
import com.oracle.bmc.generativeaiinference.model.ServingMode;
import com.oracle.bmc.generativeaiinference.requests.EmbedTextRequest;
import io.micrometer.observation.ObservationRegistry;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.util.Assert;

/**
* {@link org.springframework.ai.embedding.EmbeddingModel} implementation that uses the
* OCI GenAI Embedding API.
*
* @author Anders Swanson
* @since 1.0.0
*/
public class OCIEmbeddingModel extends AbstractEmbeddingModel {

// The OCI GenAI API has a batch size of 96 for embed text requests.
private static final int EMBEDTEXT_BATCH_SIZE = 96;

private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();

private final GenerativeAiInference genAi;

private final OCIEmbeddingOptions options;

private final ObservationRegistry observationRegistry;

private final EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

public OCIEmbeddingModel(GenerativeAiInference genAi, OCIEmbeddingOptions options) {
this(genAi, options, ObservationRegistry.NOOP);
}

public OCIEmbeddingModel(GenerativeAiInference genAi, OCIEmbeddingOptions options,
ObservationRegistry observationRegistry) {
Assert.notNull(genAi, "com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient must not be null");
Assert.notNull(options, "options must not be null");
Assert.notNull(observationRegistry, "observationRegistry must not be null");
this.genAi = genAi;
this.options = options;
this.observationRegistry = observationRegistry;
}

@Override
public EmbeddingResponse call(EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
OCIEmbeddingOptions runtimeOptions = mergeOptions(request.getOptions(), options);
List<EmbedTextRequest> embedTextRequests = createRequests(request.getInstructions(), runtimeOptions);

EmbeddingModelObservationContext context = EmbeddingModelObservationContext.builder()
.embeddingRequest(request)
.provider(AiProvider.OCI_GENAI.value())
.requestOptions(runtimeOptions)
.build();

return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> context,
this.observationRegistry)
.observe(() -> embedAllWithContext(embedTextRequests, context));
}

@Override
public float[] embed(Document document) {
return embed(document.getContent());
}

private EmbeddingResponse embedAllWithContext(List<EmbedTextRequest> embedTextRequests,
EmbeddingModelObservationContext context) {
String modelId = null;
AtomicInteger index = new AtomicInteger(0);
List<Embedding> embeddings = new ArrayList<>();
for (EmbedTextRequest embedTextRequest : embedTextRequests) {
EmbedTextResult embedTextResult = genAi.embedText(embedTextRequest).getEmbedTextResult();
if (modelId == null) {
modelId = embedTextResult.getModelId();
}
for (List<Float> e : embedTextResult.getEmbeddings()) {
float[] data = toFloats(e);
embeddings.add(new Embedding(data, index.getAndIncrement()));
}
}
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
metadata.setModel(modelId);
metadata.setUsage(new EmptyUsage());
EmbeddingResponse embeddingResponse = new EmbeddingResponse(embeddings, metadata);
context.setResponse(embeddingResponse);
return embeddingResponse;
}

private ServingMode servingMode(OCIEmbeddingOptions embeddingOptions) {
return switch (embeddingOptions.getServingMode()) {
case "dedicated" -> DedicatedServingMode.builder().endpointId(embeddingOptions.getModel()).build();
case "on-demand" -> OnDemandServingMode.builder().modelId(embeddingOptions.getModel()).build();
default -> throw new IllegalArgumentException(
"unknown serving mode for OCI embedding model: " + embeddingOptions.getServingMode());
};
}

private List<EmbedTextRequest> createRequests(List<String> inputs, OCIEmbeddingOptions embeddingOptions) {
int size = inputs.size();
List<EmbedTextRequest> requests = new ArrayList<>();
for (int i = 0; i < inputs.size(); i += EMBEDTEXT_BATCH_SIZE) {
List<String> batch = inputs.subList(i, Math.min(i + EMBEDTEXT_BATCH_SIZE, size));
requests.add(createRequest(batch, embeddingOptions));
}
return requests;
}

private EmbedTextRequest createRequest(List<String> inputs, OCIEmbeddingOptions embeddingOptions) {
EmbedTextDetails embedTextDetails = EmbedTextDetails.builder()
.servingMode(servingMode(embeddingOptions))
.compartmentId(embeddingOptions.getCompartment())
.inputs(inputs)
.truncate(Objects.requireNonNullElse(embeddingOptions.getTruncate(), EmbedTextDetails.Truncate.End))
.build();
return EmbedTextRequest.builder().embedTextDetails(embedTextDetails).build();
}

private OCIEmbeddingOptions mergeOptions(EmbeddingOptions embeddingOptions, OCIEmbeddingOptions defaultOptions) {
if (embeddingOptions instanceof OCIEmbeddingOptions) {
OCIEmbeddingOptions dynamicOptions = ModelOptionsUtils.merge(embeddingOptions, defaultOptions,
OCIEmbeddingOptions.class);
if (dynamicOptions != null) {
return dynamicOptions;
}
}
return defaultOptions;
}

private float[] toFloats(List<Float> embedding) {
float[] floats = new float[embedding.size()];
for (int i = 0; i < embedding.size(); i++) {
floats[i] = embedding.get(i);
}
return floats;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.oci;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.oracle.bmc.generativeaiinference.model.EmbedTextDetails;
import org.springframework.ai.embedding.EmbeddingOptions;

/**
* The configuration information for OCI embedding requests
*
* @author Anders Swanson
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
public class OCIEmbeddingOptions implements EmbeddingOptions {

private @JsonProperty("model") String model;

private @JsonProperty("compartment") String compartment;

private @JsonProperty("servingMode") String servingMode;

private @JsonProperty("truncate") EmbedTextDetails.Truncate truncate;

public static Builder builder() {
return new Builder();
}

public static class Builder {

private final OCIEmbeddingOptions options = new OCIEmbeddingOptions();

public Builder withModel(String model) {
this.options.setModel(model);
return this;
}

public Builder withCompartment(String compartment) {
this.options.setCompartment(compartment);
return this;
}

public Builder withServingMode(String servingMode) {
this.options.setServingMode(servingMode);
return this;
}

public Builder withTruncate(EmbedTextDetails.Truncate truncate) {
this.options.truncate = truncate;
return this;
}

public OCIEmbeddingOptions build() {
return this.options;
}

}

public String getModel() {
return this.model;
}

/**
* Not used by OCI GenAI.
* @return null
*/
@Override
public Integer getDimensions() {
return null;
}

public void setModel(String model) {
this.model = model;
}

public String getCompartment() {
return compartment;
}

public void setCompartment(String compartment) {
this.compartment = compartment;
}

public String getServingMode() {
return servingMode;
}

public void setServingMode(String servingMode) {
this.servingMode = servingMode;
}

public EmbedTextDetails.Truncate getTruncate() {
return truncate;
}

public void setTruncate(EmbedTextDetails.Truncate truncate) {
this.truncate = truncate;
}

}
Loading

0 comments on commit ccf190c

Please sign in to comment.