Skip to content

Commit

Permalink
Add googleSearchRetrieval as an option for VertexAIGeminiChatOptions
Browse files Browse the repository at this point in the history
- Add test
  • Loading branch information
markpollack committed Aug 22, 2024
1 parent bc55bc7 commit 793052c
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Map;
import java.util.Set;

import com.google.cloud.vertexai.api.GoogleSearchRetrieval;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
Expand Down Expand Up @@ -315,8 +316,19 @@ GeminiRequest createGeminiRequest(Prompt prompt) {
}

// Add the enabled functions definitions to the request's tools parameter.
List<Tool> tools = new ArrayList<>();
if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
List<Tool> tools = this.getFunctionTools(functionsForThisRequest);
tools.addAll(this.getFunctionTools(functionsForThisRequest));
}

if (prompt.getOptions() instanceof VertexAiGeminiChatOptions options && options.getGoogleSearchRetrieval()) {
final var googleSearchRetrieval = GoogleSearchRetrieval.newBuilder().getDefaultInstanceForType();
final var googleSearchRetrievalTool = Tool.newBuilder()
.setGoogleSearchRetrieval(googleSearchRetrieval)
.build();
tools.add(googleSearchRetrievalTool);
}
if (!CollectionUtils.isEmpty(tools)) {
generativeModelBuilder.setTools(tools);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

import com.fasterxml.jackson.annotation.JsonIgnore;
Expand All @@ -35,7 +36,8 @@
/**
* @author Christian Tzolov
* @author Thomas Vitale
* @since 0.8.1
* @author Grogdunn
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
public class VertexAiGeminiChatOptions implements FunctionCallingOptions, ChatOptions {
Expand Down Expand Up @@ -107,6 +109,13 @@ public enum TransportType {
@JsonIgnore
private Set<String> functions = new HashSet<>();

/**
* Use Google search Grounding feature
*/
@JsonIgnore
private boolean googleSearchRetrieval = false;


// @formatter:on

public static Builder builder() {
Expand Down Expand Up @@ -180,6 +189,11 @@ public Builder withFunction(String functionName) {
return this;
}

public Builder withGoogleSearchRetrieval(boolean googleSearch) {
this.options.googleSearchRetrieval = googleSearch;
return this;
}

public VertexAiGeminiChatOptions build() {
return this.options;
}
Expand Down Expand Up @@ -299,107 +313,42 @@ public Float getPresencePenalty() {
return null;
}

@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + ((stopSequences == null) ? 0 : stopSequences.hashCode());
result = prime * result + ((temperature == null) ? 0 : temperature.hashCode());
result = prime * result + ((topP == null) ? 0 : topP.hashCode());
result = prime * result + ((topK == null) ? 0 : topK.hashCode());
result = prime * result + ((candidateCount == null) ? 0 : candidateCount.hashCode());
result = prime * result + ((maxOutputTokens == null) ? 0 : maxOutputTokens.hashCode());
result = prime * result + ((model == null) ? 0 : model.hashCode());
result = prime * result + ((responseMimeType == null) ? 0 : responseMimeType.hashCode());
result = prime * result + ((functionCallbacks == null) ? 0 : functionCallbacks.hashCode());
result = prime * result + ((functions == null) ? 0 : functions.hashCode());
return result;
public boolean getGoogleSearchRetrieval() {
return this.googleSearchRetrieval;
}

public void setGoogleSearchRetrieval(boolean googleSearchRetrieval) {
this.googleSearchRetrieval = googleSearchRetrieval;
}

@Override
public boolean equals(Object obj) {
if (this == obj)
public boolean equals(Object o) {
if (this == o)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
VertexAiGeminiChatOptions other = (VertexAiGeminiChatOptions) obj;
if (stopSequences == null) {
if (other.stopSequences != null)
return false;
}
else if (!stopSequences.equals(other.stopSequences))
if (!(o instanceof VertexAiGeminiChatOptions that))
return false;
if (temperature == null) {
if (other.temperature != null)
return false;
}
else if (!temperature.equals(other.temperature))
return false;
if (topP == null) {
if (other.topP != null)
return false;
}
else if (!topP.equals(other.topP))
return false;
if (topK == null) {
if (other.topK != null)
return false;
}
else if (!topK.equals(other.topK))
return false;
if (candidateCount == null) {
if (other.candidateCount != null)
return false;
}
else if (!candidateCount.equals(other.candidateCount))
return false;
if (maxOutputTokens == null) {
if (other.maxOutputTokens != null)
return false;
}
else if (!maxOutputTokens.equals(other.maxOutputTokens))
return false;
if (model == null) {
if (other.model != null)
return false;
}
else if (!model.equals(other.model))
return false;
if (responseMimeType == null) {
if (other.responseMimeType != null)
return false;
}
else if (!responseMimeType.equals(other.responseMimeType)) {
return false;
}
if (functionCallbacks == null) {
if (other.functionCallbacks != null)
return false;
}
else if (!functionCallbacks.equals(other.functionCallbacks))
return false;
if (functions == null) {
if (other.functions != null)
return false;
}
else if (!functions.equals(other.functions))
return false;
return true;
return googleSearchRetrieval == that.googleSearchRetrieval && Objects.equals(stopSequences, that.stopSequences)
&& Objects.equals(temperature, that.temperature) && Objects.equals(topP, that.topP)
&& Objects.equals(topK, that.topK) && Objects.equals(candidateCount, that.candidateCount)
&& Objects.equals(maxOutputTokens, that.maxOutputTokens) && Objects.equals(model, that.model)
&& Objects.equals(responseMimeType, that.responseMimeType)
&& Objects.equals(functionCallbacks, that.functionCallbacks)
&& Objects.equals(functions, that.functions);
}

@Override
public int hashCode() {
return Objects.hash(stopSequences, temperature, topP, topK, candidateCount, maxOutputTokens, model,
responseMimeType, functionCallbacks, functions, googleSearchRetrieval);
}

@Override
public String toString() {
return "VertexAiGeminiChatOptions [stopSequences=" + stopSequences + ", temperature=" + temperature + ", topP="
+ topP + ", topK=" + topK + ", candidateCount=" + candidateCount + ", maxOutputTokens="
+ maxOutputTokens + ", model=" + model + ", responseMimeType=" + responseMimeType
+ ", functionCallbacks=" + functionCallbacks + ", functions=" + functions + ", getClass()=" + getClass()
+ ", getStopSequences()=" + getStopSequences() + ", getTemperature()=" + getTemperature()
+ ", getTopP()=" + getTopP() + ", getTopK()=" + getTopK() + ", getCandidateCount()="
+ getCandidateCount() + ", getMaxOutputTokens()=" + getMaxOutputTokens() + ", getModel()=" + getModel()
+ ", getFunctionCallbacks()=" + getFunctionCallbacks() + ", getFunctions()=" + getFunctions()
+ ", hashCode()=" + hashCode() + ", toString()=" + super.toString() + "]";
return "VertexAiGeminiChatOptions{" + "stopSequences=" + stopSequences + ", temperature=" + temperature
+ ", topP=" + topP + ", topK=" + topK + ", candidateCount=" + candidateCount + ", maxOutputTokens="
+ maxOutputTokens + ", model='" + model + '\'' + ", responseMimeType='" + responseMimeType + '\''
+ ", functionCallbacks=" + functionCallbacks + ", functions=" + functions + ", googleSearchRetrieval="
+ googleSearchRetrieval + '}';
}

@Override
Expand All @@ -419,6 +368,8 @@ public static VertexAiGeminiChatOptions fromOptions(VertexAiGeminiChatOptions fr
options.setFunctionCallbacks(fromOptions.getFunctionCallbacks());
options.setResponseMimeType(fromOptions.getResponseMimeType());
options.setFunctions(fromOptions.getFunctions());
options.setResponseMimeType(fromOptions.getResponseMimeType());
options.setGoogleSearchRetrieval(fromOptions.getGoogleSearchRetrieval());
return options;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import com.google.cloud.vertexai.Transport;
import com.google.cloud.vertexai.VertexAI;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

Expand Down Expand Up @@ -63,15 +64,28 @@ class VertexAiGeminiChatModelIT {

@Test
void roleTest() {
Prompt prompt = createPrompt(VertexAiGeminiChatOptions.builder().build());
ChatResponse response = chatModel.call(prompt);
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew");
}

@Test
void googleSearchTool() {
Prompt prompt = createPrompt(VertexAiGeminiChatOptions.builder().withGoogleSearchRetrieval(true).build());
ChatResponse response = chatModel.call(prompt);
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew");
}

@NotNull
private Prompt createPrompt(VertexAiGeminiChatOptions chatOptions) {
String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did.";
String name = "Bob";
String voice = "pirate";
UserMessage userMessage = new UserMessage(request);
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
ChatResponse response = chatModel.call(prompt);
assertThat(response.getResult().getOutput().getContent()).containsAnyOf("Blackbeard", "Bartholomew");
Prompt prompt = new Prompt(List.of(userMessage, systemMessage), chatOptions);
return prompt;
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ The prefix `spring.ai.vertex.ai.gemini.chat` is the property prefix that lets yo

| spring.ai.vertex.ai.gemini.chat.options.model | Supported https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini[Vertex AI Gemini Chat model] to use include the (1.0 ) `gemini-pro`, `gemini-pro-vision` (deprecated) and the new `gemini-1.5-pro-001`, `gemini-1.5-flash-001` models. | gemini-1.5-pro-001
| spring.ai.vertex.ai.gemini.chat.options.responseMimeType | Output response mimetype of the generated candidate text. | `text/plain`: (default) Text output or `application/json`: JSON response.
| spring.ai.vertex.ai.gemini.chat.options.googleSearchRetrieval | Use Google search Grounding feature | `true` or `false`, default `false`.
| spring.ai.vertex.ai.gemini.chat.options.temperature | Controls the randomness of the output. Values can range over [0.0,1.0], inclusive. A value closer to 1.0 will produce responses that are more varied, while a value closer to 0.0 will typically result in less surprising responses from the generative. This value specifies default to be used by the backend while making the call to the generative. | 0.8
| spring.ai.vertex.ai.gemini.chat.options.topK | The maximum number of tokens to consider when sampling. The generative uses combined Top-k and nucleus sampling. Top-k sampling considers the set of topK most probable tokens. | -
| spring.ai.vertex.ai.gemini.chat.options.topP | The maximum cumulative probability of tokens to consider when sampling. The generative uses combined Top-k and nucleus sampling. Nucleus sampling considers the smallest set of tokens whose probability sum is at least topP. | -
Expand Down

0 comments on commit 793052c

Please sign in to comment.