Skip to content

Commit

Permalink
Consolidate Ollama auto-pull logic
Browse files Browse the repository at this point in the history
Consolidate the Ollama auto-pull logic at startup time, supporting the auto-pull for the default models specified via configuration properties and for optional models specified for initialization.

Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
  • Loading branch information
ThomasVitale authored and tzolov committed Oct 21, 2024
1 parent 1cadc49 commit d5bc9c9
Show file tree
Hide file tree
Showing 13 changed files with 140 additions and 155 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
this.defaultOptions = defaultOptions;
this.observationRegistry = observationRegistry;
this.modelManager = new OllamaModelManager(chatApi, modelManagementOptions);
initializeModelIfEnabled(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
}

public static Builder builder() {
Expand Down Expand Up @@ -302,11 +302,6 @@ else if (message instanceof ToolResponseMessage toolMessage) {
}
OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);

mergedOptions.setPullModelStrategy(this.defaultOptions.getPullModelStrategy());
if (runtimeOptions != null && runtimeOptions.getPullModelStrategy() != null) {
mergedOptions.setPullModelStrategy(runtimeOptions.getPullModelStrategy());
}

// Override the model.
if (!StringUtils.hasText(mergedOptions.getModel())) {
throw new IllegalArgumentException("Model is not set!");
Expand All @@ -331,8 +326,6 @@ else if (message instanceof ToolResponseMessage toolMessage) {
requestBuilder.withTools(this.getFunctionTools(functionsForThisRequest));
}

initializeModelIfEnabled(mergedOptions.getModel(), mergedOptions.getPullModelStrategy());

return requestBuilder.build();
}

Expand Down Expand Up @@ -379,7 +372,7 @@ public ChatOptions getDefaultOptions() {
/**
* Pull the given model into Ollama based on the specified strategy.
*/
private void initializeModelIfEnabled(String model, PullModelStrategy pullModelStrategy) {
private void initializeModel(String model, PullModelStrategy pullModelStrategy) {
if (pullModelStrategy != null && !PullModelStrategy.NEVER.equals(pullModelStrategy)) {
this.modelManager.pullModel(model, pullModelStrategy);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public OllamaEmbeddingModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,
this.observationRegistry = observationRegistry;
this.modelManager = new OllamaModelManager(ollamaApi, modelManagementOptions);

initializeModelIfEnabled(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
}

public static Builder builder() {
Expand Down Expand Up @@ -139,19 +139,12 @@ OllamaApi.EmbeddingsRequest ollamaEmbeddingRequest(List<String> inputContent, Em

OllamaOptions mergedOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class);

mergedOptions.setPullModelStrategy(this.defaultOptions.getPullModelStrategy());
if (runtimeOptions != null && runtimeOptions.getPullModelStrategy() != null) {
mergedOptions.setPullModelStrategy(runtimeOptions.getPullModelStrategy());
}

// Override the model.
if (!StringUtils.hasText(mergedOptions.getModel())) {
throw new IllegalArgumentException("Model is not set!");
}
String model = mergedOptions.getModel();

initializeModelIfEnabled(mergedOptions.getModel(), mergedOptions.getPullModelStrategy());

return new OllamaApi.EmbeddingsRequest(model, inputContent, DurationParser.parse(mergedOptions.getKeepAlive()),
OllamaOptions.filterNonSupportedFields(mergedOptions.toMap()), mergedOptions.getTruncate());
}
Expand All @@ -163,7 +156,7 @@ private EmbeddingOptions buildRequestOptions(OllamaApi.EmbeddingsRequest request
/**
* Pull the given model into Ollama based on the specified strategy.
*/
private void initializeModelIfEnabled(String model, PullModelStrategy pullModelStrategy) {
private void initializeModel(String model, PullModelStrategy pullModelStrategy) {
if (pullModelStrategy != null && !PullModelStrategy.NEVER.equals(pullModelStrategy)) {
this.modelManager.pullModel(model, pullModelStrategy);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.ollama.management.PullModelStrategy;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import org.springframework.util.Assert;

Expand Down Expand Up @@ -303,12 +302,6 @@ public class OllamaOptions implements FunctionCallingOptions, ChatOptions, Embed
@JsonIgnore
private Map<String, Object> toolContext;

/**
* Strategy for pulling models at run-time.
*/
@JsonIgnore
private PullModelStrategy pullModelStrategy;

public static OllamaOptions builder() {
return new OllamaOptions();
}
Expand Down Expand Up @@ -521,11 +514,6 @@ public OllamaOptions withToolContext(Map<String, Object> toolContext) {
return this;
}

public OllamaOptions withPullModelStrategy(PullModelStrategy pullModelStrategy) {
this.pullModelStrategy = pullModelStrategy;
return this;
}

// -------------------
// Getters and Setters
// -------------------
Expand Down Expand Up @@ -866,14 +854,6 @@ public void setToolContext(Map<String, Object> toolContext) {
this.toolContext = toolContext;
}

public PullModelStrategy getPullModelStrategy() {
return this.pullModelStrategy;
}

public void setPullModelStrategy(PullModelStrategy pullModelStrategy) {
this.pullModelStrategy = pullModelStrategy;
}

/**
* Convert the {@link OllamaOptions} object to a {@link Map} of key/value pairs.
* @return The {@link Map} of key/value pairs.
Expand Down Expand Up @@ -944,8 +924,7 @@ public static OllamaOptions fromOptions(OllamaOptions fromOptions) {
.withFunctions(fromOptions.getFunctions())
.withProxyToolCalls(fromOptions.getProxyToolCalls())
.withFunctionCallbacks(fromOptions.getFunctionCallbacks())
.withToolContext(fromOptions.getToolContext())
.withPullModelStrategy(fromOptions.getPullModelStrategy());
.withToolContext(fromOptions.getToolContext());
}
// @formatter:on

Expand Down Expand Up @@ -975,8 +954,7 @@ public boolean equals(Object o) {
&& Objects.equals(penalizeNewline, that.penalizeNewline) && Objects.equals(stop, that.stop)
&& Objects.equals(functionCallbacks, that.functionCallbacks)
&& Objects.equals(proxyToolCalls, that.proxyToolCalls) && Objects.equals(functions, that.functions)
&& Objects.equals(toolContext, that.toolContext)
&& Objects.equals(pullModelStrategy, that.pullModelStrategy);
&& Objects.equals(toolContext, that.toolContext);
}

@Override
Expand All @@ -987,7 +965,7 @@ public int hashCode() {
this.topP, tfsZ, this.typicalP, this.repeatLastN, this.temperature, this.repeatPenalty,
this.presencePenalty, this.frequencyPenalty, this.mirostat, this.mirostatTau, this.mirostatEta,
this.penalizeNewline, this.stop, this.functionCallbacks, this.functions, this.proxyToolCalls,
this.toolContext, this.pullModelStrategy);
this.toolContext);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,49 @@
*/
public record ModelManagementOptions(PullModelStrategy pullModelStrategy, List<String> additionalModels,
        Duration timeout, Integer maxRetries) {

    /**
     * Create options holding the values a fresh {@link Builder} starts with:
     * {@link PullModelStrategy#NEVER}, an empty list of additional models, a
     * five-minute timeout and zero retries.
     * @return a new instance populated with the default values
     */
    public static ModelManagementOptions defaults() {
        // A new Builder already carries every default, so building it straight away
        // keeps the default values defined in a single place.
        return builder().build();
    }

    /**
     * Create a {@link Builder} pre-populated with the default values.
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Fluent builder for {@link ModelManagementOptions}. Every property has a default,
     * so only the values that differ from the defaults need to be set.
     */
    public static class Builder {

        private PullModelStrategy pullModelStrategy = PullModelStrategy.NEVER;

        private List<String> additionalModels = List.of();

        private Duration timeout = Duration.ofMinutes(5);

        private Integer maxRetries = 0;

        /**
         * Set the strategy used to decide whether models are pulled.
         * @param pullModelStrategy the strategy to use
         * @return this builder
         */
        public Builder withPullModelStrategy(PullModelStrategy pullModelStrategy) {
            this.pullModelStrategy = pullModelStrategy;
            return this;
        }

        /**
         * Set the additional models. The list is stored as given, without a defensive
         * copy. NOTE(review): presumably these are models to pull in addition to the
         * default one - confirm against the model manager.
         * @param additionalModels the names of the additional models
         * @return this builder
         */
        public Builder withAdditionalModels(List<String> additionalModels) {
            this.additionalModels = additionalModels;
            return this;
        }

        /**
         * Set the timeout. NOTE(review): presumably the maximum time to wait for a
         * model pull to complete - confirm against the model manager.
         * @param timeout the timeout to use
         * @return this builder
         */
        public Builder withTimeout(Duration timeout) {
            this.timeout = timeout;
            return this;
        }

        /**
         * Set the maximum number of retries. NOTE(review): presumably applies to
         * failed pull attempts - confirm against the model manager.
         * @param maxRetries the maximum number of retries
         * @return this builder
         */
        public Builder withMaxRetries(Integer maxRetries) {
            this.maxRetries = maxRetries;
            return this;
        }

        /**
         * Create the {@link ModelManagementOptions} from the values collected so far.
         * @return a new options instance
         */
        public ModelManagementOptions build() {
            return new ModelManagementOptions(this.pullModelStrategy, this.additionalModels, this.timeout,
                    this.maxRetries);
        }

    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.springframework.ai.converter.MapOutputConverter;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.management.PullModelStrategy;
Expand All @@ -56,6 +57,8 @@ class OllamaChatModelIT extends BaseOllamaIT {

private static final String MODEL = OllamaModel.LLAMA3_2.getName();

private static final String ADDITIONAL_MODEL = "tinyllama";

@Autowired
private OllamaChatModel chatModel;

Expand All @@ -65,23 +68,17 @@ class OllamaChatModelIT extends BaseOllamaIT {
@Test
void autoPullModelTest() {
var modelManager = new OllamaModelManager(ollamaApi);
var model = "tinyllama";
modelManager.deleteModel(model);
assertThat(modelManager.isModelAvailable(model)).isFalse();
assertThat(modelManager.isModelAvailable(ADDITIONAL_MODEL)).isTrue();

String joke = ChatClient.create(chatModel)
.prompt("Tell me a joke")
.options(OllamaOptions.builder()
.withModel(model)
.withPullModelStrategy(PullModelStrategy.WHEN_MISSING)
.build())
.options(OllamaOptions.builder().withModel(ADDITIONAL_MODEL).build())
.call()
.content();

assertThat(joke).isNotEmpty();
assertThat(modelManager.isModelAvailable(model)).isTrue();

modelManager.deleteModel(model);
modelManager.deleteModel(ADDITIONAL_MODEL);
}

@Test
Expand Down Expand Up @@ -249,6 +246,10 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
return OllamaChatModel.builder()
.withOllamaApi(ollamaApi)
.withDefaultOptions(OllamaOptions.create().withModel(MODEL).withTemperature(0.9))
.withModelManagementOptions(ModelManagementOptions.builder()
.withPullModelStrategy(PullModelStrategy.WHEN_MISSING)
.withAdditionalModels(List.of(ADDITIONAL_MODEL))
.build())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaModel;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.ai.ollama.management.OllamaModelManager;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.management.PullModelStrategy;
Expand All @@ -41,6 +42,8 @@ class OllamaEmbeddingModelIT extends BaseOllamaIT {

private static final String MODEL = OllamaModel.NOMIC_EMBED_TEXT.getName();

private static final String ADDITIONAL_MODEL = "all-minilm";

@Autowired
private OllamaEmbeddingModel embeddingModel;

Expand All @@ -65,36 +68,29 @@ void embeddings() {
}

@Test
void autoPullModel() {
void autoPullModelAtStartupTime() {
var model = "all-minilm";
assertThat(embeddingModel).isNotNull();

var modelManager = new OllamaModelManager(ollamaApi);
modelManager.deleteModel(model);
assertThat(modelManager.isModelAvailable(model)).isFalse();
assertThat(modelManager.isModelAvailable(ADDITIONAL_MODEL)).isTrue();

EmbeddingResponse embeddingResponse = embeddingModel
.call(new EmbeddingRequest(List.of("Hello World", "Something else"),
OllamaOptions.builder()
.withModel(model)
.withPullModelStrategy(PullModelStrategy.WHEN_MISSING)
.withTruncate(false)
.build()));

assertThat(modelManager.isModelAvailable(model)).isTrue();
OllamaOptions.builder().withModel(model).withTruncate(false).build()));

assertThat(embeddingResponse.getResults()).hasSize(2);
assertThat(embeddingResponse.getResults().get(0).getIndex()).isEqualTo(0);
assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotEmpty();
assertThat(embeddingResponse.getResults().get(1).getIndex()).isEqualTo(1);
assertThat(embeddingResponse.getResults().get(1).getOutput()).isNotEmpty();
assertThat(embeddingResponse.getMetadata().getModel()).contains(model);
assertThat(embeddingResponse.getMetadata().getModel()).contains(ADDITIONAL_MODEL);
assertThat(embeddingResponse.getMetadata().getUsage().getPromptTokens()).isEqualTo(4);
assertThat(embeddingResponse.getMetadata().getUsage().getTotalTokens()).isEqualTo(4);

assertThat(embeddingModel.dimensions()).isEqualTo(768);

modelManager.deleteModel(model);
modelManager.deleteModel(ADDITIONAL_MODEL);
}

@SpringBootConfiguration
Expand All @@ -110,6 +106,10 @@ public OllamaEmbeddingModel ollamaEmbedding(OllamaApi ollamaApi) {
return OllamaEmbeddingModel.builder()
.withOllamaApi(ollamaApi)
.withDefaultOptions(OllamaOptions.create().withModel(MODEL))
.withModelManagementOptions(ModelManagementOptions.builder()
.withPullModelStrategy(PullModelStrategy.WHEN_MISSING)
.withAdditionalModels(List.of(ADDITIONAL_MODEL))
.build())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@
*/
public class OllamaImage {

public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.3.13");
public static final DockerImageName DEFAULT_IMAGE = DockerImageName.parse("ollama/ollama:0.3.14");

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@
*/
package org.springframework.ai.ollama.api;

import static org.assertj.core.api.Assertions.assertThat;

import java.io.IOException;
import java.time.Duration;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledIf;
import org.springframework.ai.ollama.BaseOllamaIT;
import org.springframework.http.HttpStatus;
import org.testcontainers.junit.jupiter.Testcontainers;

import java.io.IOException;
import java.time.Duration;

import static org.assertj.core.api.Assertions.assertThat;

/**
* Integration tests for the Ollama APIs to manage models.
*
Expand All @@ -36,7 +36,7 @@
@DisabledIf("isDisabled")
public class OllamaApiModelsIT extends BaseOllamaIT {

private static final String MODEL = OllamaModel.NOMIC_EMBED_TEXT.getName();
private static final String MODEL = "all-minilm";

static OllamaApi ollamaApi;

Expand All @@ -60,7 +60,7 @@ public void showModel() {
var showModelResponse = ollamaApi.showModel(showModelRequest);

assertThat(showModelResponse).isNotNull();
assertThat(showModelResponse.details().family()).isEqualTo("nomic-bert");
assertThat(showModelResponse.details().family()).isEqualTo("bert");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

// :YES: image::yes.svg[width=16]
// :NO: image::no.svg[width=12]
// [%autowidth]


This table compares various Chat Models supported by Spring AI, detailing their capabilities:

Expand Down Expand Up @@ -39,6 +39,5 @@ This table compares various Chat Models supported by Spring AI, detailing their
| xref::api/chat/bedrock/bedrock-llama.adoc[Amazon Bedrock/Llama] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12]
| xref::api/chat/bedrock/bedrock-titan.adoc[Amazon Bedrock/Titan] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12]
| xref::api/chat/bedrock/bedrock-anthropic3.adoc[Amazon Bedrock/Anthropic 3] | text ^a| image::no.svg[width=12] ^a| image::yes.svg[width=16] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12] ^a| image::no.svg[width=12]

|====

Loading

0 comments on commit d5bc9c9

Please sign in to comment.