Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Ollama embedding model implementation #1159

Merged
merged 1 commit into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,26 @@
*/
package org.springframework.ai.ollama;

import java.util.ArrayList;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingRequest;
import org.springframework.ai.ollama.api.OllamaApi.EmbeddingsResponse;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
Expand Down Expand Up @@ -71,70 +75,43 @@ public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions) {
this.defaultOptions = defaultOptions;
}

/**
* Sets the model name used by this embedding model and returns this instance
* for fluent chaining.
* @param model the Ollama model name to apply to the default options
* @return this {@code OllamaEmbeddingModel} instance
* @deprecated Use {@link OllamaOptions#setModel} instead.
*/
@Deprecated
public OllamaEmbeddingModel withModel(String model) {
this.defaultOptions.setModel(model);
return this;
}

/**
* Replaces the default options of this embedding model and returns this
* instance for fluent chaining.
* @param options the {@link OllamaOptions} to use as defaults for all requests
* @return this {@code OllamaEmbeddingModel} instance
* @deprecated Use {@link OllamaOptions} constructor instead.
*/
@Deprecated
public OllamaEmbeddingModel withDefaultOptions(OllamaOptions options) {
this.defaultOptions = options;
return this;
}

/**
* Embeds the text content of the given document by delegating to the
* single-string embed method.
* @param document the document whose content is embedded
* @return the embedding vector for the document's content
*/
@Override
public List<Double> embed(Document document) {
return embed(document.getContent());
}

@Override
public EmbeddingResponse call(org.springframework.ai.embedding.EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
if (request.getInstructions().size() != 1) {
logger.warn(
"Ollama Embedding does not support batch embedding. Will make multiple API calls to embed(Document)");
}
public EmbeddingResponse call(EmbeddingRequest request) {

List<List<Double>> embeddingList = new ArrayList<>();
for (String inputContent : request.getInstructions()) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");

EmbeddingRequest ollamaEmbeddingRequest = ollamaEmbeddingRequest(inputContent, request.getOptions());
OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest = ollamaEmbeddingRequest(request.getInstructions(),
request.getOptions());

OllamaApi.EmbeddingResponse response = this.ollamaApi.embeddings(ollamaEmbeddingRequest);
EmbeddingsResponse response = this.ollamaApi.embed(ollamaEmbeddingRequest);

embeddingList.add(response.embedding());
}
AtomicInteger indexCounter = new AtomicInteger(0);

List<Embedding> embeddings = embeddingList.stream()
List<Embedding> embeddings = response.embeddings()
.stream()
.map(e -> new Embedding(e, indexCounter.getAndIncrement()))
.toList();
return new EmbeddingResponse(embeddings);

EmbeddingResponseMetadata embeddingResponseMetadata = new EmbeddingResponseMetadata(response.model(),
new EmptyUsage());

return new EmbeddingResponse(embeddings, embeddingResponseMetadata);
}

/**
* Package access for testing.
*/
OllamaApi.EmbeddingRequest ollamaEmbeddingRequest(String inputContent, EmbeddingOptions options) {
OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> inputContent, EmbeddingOptions options) {

// runtime options
OllamaOptions runtimeOptions = null;
if (options != null) {
if (options instanceof OllamaOptions ollamaOptions) {
runtimeOptions = ollamaOptions;
}
else {
// currently EmbeddingOptions does not have any portable options to be
// merged.
runtimeOptions = null;
}
if (options != null && options instanceof OllamaOptions ollamaOptions) {
runtimeOptions = ollamaOptions;
}

OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);
Expand All @@ -144,8 +121,40 @@ OllamaApi.EmbeddingRequest ollamaEmbeddingRequest(String inputContent, Embedding
throw new IllegalArgumentException("Model is not set!");
}
String model = mergedOptions.getModel();
return new EmbeddingRequest(model, inputContent, null,
OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()));

return new OllamaApi.EmbeddingsRequest(model, inputContent, DurationParser.parse(mergedOptions.getKeepAlive()),
OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()), mergedOptions.getTruncate());
}

/**
 * Parses Ollama keep-alive duration expressions (e.g. "500ms", "30s", "5m",
 * "2h") into a {@link Duration}.
 */
public static class DurationParser {

	// Compiled once and reused: a positive integer followed by a supported unit.
	private static final Pattern PATTERN = Pattern.compile("(\\d+)(ms|s|m|h)");

	/**
	 * Parses the given duration expression.
	 * @param input duration text such as "5m"; may be null or blank
	 * @return the parsed {@link Duration}, or null when the input is empty
	 * @throws IllegalArgumentException if the input does not match
	 * {@code <number><unit>} with unit one of ms, s, m, h
	 */
	public static Duration parse(String input) {

		if (!StringUtils.hasText(input)) {
			return null;
		}

		Matcher matcher = PATTERN.matcher(input);

		if (!matcher.matches()) {
			throw new IllegalArgumentException("Invalid duration format: " + input);
		}

		long value = Long.parseLong(matcher.group(1));
		String unit = matcher.group(2);

		return switch (unit) {
			case "ms" -> Duration.ofMillis(value);
			case "s" -> Duration.ofSeconds(value);
			case "m" -> Duration.ofMinutes(value);
			case "h" -> Duration.ofHours(value);
			default -> throw new IllegalArgumentException("Unsupported time unit: " + unit);
		};
	}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ public record GenerateResponse(
@JsonProperty("eval_duration") Duration evalDuration) {
}

/**
* Generate a completion for the given prompt.
* @param completionRequest Completion request.
* @return Completion response.
Expand Down Expand Up @@ -691,11 +691,40 @@ public Flux<ChatResponse> streamingChat(ChatRequest chatRequest) {
* Generate embeddings from a model.
*
* @param model The name of model to generate embeddings from.
* @param input The text or list of text to generate embeddings for.
* @param keepAlive Controls how long the model will stay loaded into memory following the request (default: 5m).
* @param options Additional model parameters listed in the documentation for the
* @param truncate Truncates the end of each input to fit within context length.
* Returns error if false and context length is exceeded. Defaults to true.
*/
@JsonInclude(Include.NON_NULL)
public record EmbeddingsRequest(
@JsonProperty("model") String model,
@JsonProperty("input") List<String> input,
@JsonProperty("keep_alive") Duration keepAlive,
@JsonProperty("options") Map<String, Object> options,
@JsonProperty("truncate") Boolean truncate) {

/**
* Shortcut constructor to create an EmbeddingsRequest for a single input text
* without keep-alive, options, or truncate settings.
* @param model The name of model to generate embeddings from.
* @param input The text or list of text to generate embeddings for.
*/
public EmbeddingsRequest(String model, String input) {
this(model, List.of(input), null, null, null);
}
}

/**
* Generate embeddings from a model.
*
* @param model The name of model to generate embeddings from.
* @param prompt The text to generate embeddings for
* @param keepAlive Controls how long the model will stay loaded into memory following the request (default: 5m).
* @param options Additional model parameters listed in the documentation for the
* Model file such as temperature.
* @deprecated Use {@link EmbeddingsRequest} instead.
*/
@Deprecated(since = "1.0.0-M2", forRemoval = true)
@JsonInclude(Include.NON_NULL)
public record EmbeddingRequest(
@JsonProperty("model") String model,
Expand All @@ -717,17 +746,49 @@ public EmbeddingRequest(String model, String prompt) {
* The response object returned from the /embedding endpoint.
*
* @param embedding The embedding generated from the model.
* @deprecated Use {@link EmbeddingsResponse} instead.
*/
@Deprecated(since = "1.0.0-M2", forRemoval = true)
@JsonInclude(Include.NON_NULL)
public record EmbeddingResponse(
@JsonProperty("embedding") List<Double> embedding) {
}


/**
* The response object returned from the /embed endpoint.
* @param model The model used for generating the embeddings.
* @param embeddings The list of embeddings generated from the model.
* Each embedding (list of doubles) corresponds to a single input text.
*/
@JsonInclude(Include.NON_NULL)
public record EmbeddingsResponse(
@JsonProperty("model") String model,
@JsonProperty("embeddings") List<List<Double>> embeddings) {
}

/**
* Generate embeddings from a model.
* @param embeddingsRequest Embedding request.
* @return Embeddings response.
*/
public EmbeddingsResponse embed(EmbeddingsRequest embeddingsRequest) {
	Assert.notNull(embeddingsRequest, REQUEST_BODY_NULL_ERROR);

	// POST the batch request to the /api/embed endpoint and decode the body.
	var requestSpec = this.restClient.post().uri("/api/embed").body(embeddingsRequest);

	return requestSpec.retrieve().onStatus(this.responseErrorHandler).body(EmbeddingsResponse.class);
}
/**
* Generate embeddings from a model.
* @param embeddingRequest Embedding request.
* @return Embedding response.
* @deprecated Use {@link #embed(EmbeddingsRequest)} instead.
*/
@Deprecated(since = "1.0.0-M2", forRemoval = true)
public EmbeddingResponse embeddings(EmbeddingRequest embeddingRequest) {
Assert.notNull(embeddingRequest, REQUEST_BODY_NULL_ERROR);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed

public static final String DEFAULT_MODEL = OllamaModel.MISTRAL.id();

private static final List<String> NON_SUPPORTED_FIELDS = List.of("model", "format", "keep_alive");
private static final List<String> NON_SUPPORTED_FIELDS = List.of("model", "format", "keep_alive", "truncate");

// Following fields are options which must be set when the model is loaded into
// memory.
Expand Down Expand Up @@ -267,6 +267,13 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed
* Part of Chat completion <a href="https://github.com/ollama/ollama/blob/main/docs/api.md#parameters-1">advanced parameters</a>.
*/
@JsonProperty("keep_alive") private String keepAlive;


/**
* Truncates the end of each input to fit within context length. Returns error if false and context length is exceeded.
* Defaults to true.
*/
@JsonProperty("truncate") private Boolean truncate;

/**
* Tool Function Callbacks to register with the ChatModel.
Expand Down Expand Up @@ -312,14 +319,6 @@ public OllamaOptions withModel(OllamaModel model) {
return this;
}

public String getModel() {
return model;
}

public void setModel(String model) {
this.model = model;
}

public OllamaOptions withFormat(String format) {
this.format = format;
return this;
Expand All @@ -330,6 +329,11 @@ public OllamaOptions withKeepAlive(String keepAlive) {
return this;
}

/**
* Sets whether each input is truncated to fit within the model context length
* and returns this instance for fluent chaining.
* @param truncate true to truncate oversized inputs; false makes the server
* return an error when the context length is exceeded
* @return this {@code OllamaOptions} instance
*/
public OllamaOptions withTruncate(Boolean truncate) {
this.truncate = truncate;
return this;
}

public OllamaOptions withUseNUMA(Boolean useNUMA) {
this.useNUMA = useNUMA;
return this;
Expand Down Expand Up @@ -491,6 +495,17 @@ public OllamaOptions withFunction(String functionName) {
return this;
}

// -------------------
// Getters and Setters
// -------------------
// Returns the configured model name (may be null when unset).
public String getModel() {
return model;
}

// Sets the model name used for requests built from these options.
public void setModel(String model) {
this.model = model;
}

// Returns the configured response format (may be null when unset).
public String getFormat() {
return this.format;
}
Expand Down Expand Up @@ -739,6 +754,14 @@ public void setStop(List<String> stop) {
this.stop = stop;
}

// Returns the truncate setting (null means the server default, which is true).
public Boolean getTruncate() {
return this.truncate;
}

// Sets whether oversized inputs are truncated to the model context length.
public void setTruncate(Boolean truncate) {
this.truncate = truncate;
}

@Override
public List<FunctionCallback> getFunctionCallbacks() {
return this.functionCallbacks;
Expand Down Expand Up @@ -797,6 +820,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
.withModel(fromOptions.getModel())
.withFormat(fromOptions.getFormat())
.withKeepAlive(fromOptions.getKeepAlive())
.withTruncate(fromOptions.getTruncate())
.withUseNUMA(fromOptions.getUseNUMA())
.withNumCtx(fromOptions.getNumCtx())
.withNumBatch(fromOptions.getNumBatch())
Expand Down Expand Up @@ -839,15 +863,16 @@ public boolean equals(Object o) {
return false;
OllamaOptions that = (OllamaOptions) o;
return Objects.equals(model, that.model) && Objects.equals(format, that.format)
&& Objects.equals(keepAlive, that.keepAlive) && Objects.equals(useNUMA, that.useNUMA)
&& Objects.equals(numCtx, that.numCtx) && Objects.equals(numBatch, that.numBatch)
&& Objects.equals(numGPU, that.numGPU) && Objects.equals(mainGPU, that.mainGPU)
&& Objects.equals(lowVRAM, that.lowVRAM) && Objects.equals(f16KV, that.f16KV)
&& Objects.equals(logitsAll, that.logitsAll) && Objects.equals(vocabOnly, that.vocabOnly)
&& Objects.equals(useMMap, that.useMMap) && Objects.equals(useMLock, that.useMLock)
&& Objects.equals(numThread, that.numThread) && Objects.equals(numKeep, that.numKeep)
&& Objects.equals(seed, that.seed) && Objects.equals(numPredict, that.numPredict)
&& Objects.equals(topK, that.topK) && Objects.equals(topP, that.topP) && Objects.equals(tfsZ, that.tfsZ)
&& Objects.equals(keepAlive, that.keepAlive) && Objects.equals(truncate, that.truncate)
&& Objects.equals(useNUMA, that.useNUMA) && Objects.equals(numCtx, that.numCtx)
&& Objects.equals(numBatch, that.numBatch) && Objects.equals(numGPU, that.numGPU)
&& Objects.equals(mainGPU, that.mainGPU) && Objects.equals(lowVRAM, that.lowVRAM)
&& Objects.equals(f16KV, that.f16KV) && Objects.equals(logitsAll, that.logitsAll)
&& Objects.equals(vocabOnly, that.vocabOnly) && Objects.equals(useMMap, that.useMMap)
&& Objects.equals(useMLock, that.useMLock) && Objects.equals(numThread, that.numThread)
&& Objects.equals(numKeep, that.numKeep) && Objects.equals(seed, that.seed)
&& Objects.equals(numPredict, that.numPredict) && Objects.equals(topK, that.topK)
&& Objects.equals(topP, that.topP) && Objects.equals(tfsZ, that.tfsZ)
&& Objects.equals(typicalP, that.typicalP) && Objects.equals(repeatLastN, that.repeatLastN)
&& Objects.equals(temperature, that.temperature) && Objects.equals(repeatPenalty, that.repeatPenalty)
&& Objects.equals(presencePenalty, that.presencePenalty)
Expand All @@ -860,12 +885,12 @@ public boolean equals(Object o) {

/**
 * Hash code over every option field, kept in sync (same fields, same order)
 * with {@code equals}.
 */
@Override
public int hashCode() {
	return Objects.hash(this.model, this.format, this.keepAlive, this.truncate, this.useNUMA, this.numCtx,
			this.numBatch, this.numGPU, this.mainGPU, this.lowVRAM, this.f16KV, this.logitsAll, this.vocabOnly,
			this.useMMap, this.useMLock, this.numThread, this.numKeep, this.seed, this.numPredict, this.topK,
			this.topP, this.tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty,
			this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta,
			this.penalizeNewline, this.stop, this.functionCallbacks, this.functions);
}

}
Loading
Loading