diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java new file mode 100644 index 0000000000..65af1fb765 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java @@ -0,0 +1,328 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.openai.chat.proxy; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.converter.BeanOutputConverter; +import org.springframework.ai.converter.ListOutputConverter; +import org.springframework.ai.converter.MapOutputConverter; +import org.springframework.ai.model.function.FunctionCallbackWrapper; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.tool.MockWeatherService; +import org.springframework.ai.openai.chat.ActorsFilms; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.core.convert.support.DefaultConversionService; +import org.springframework.core.io.Resource; + +import reactor.core.publisher.Flux; + +/** + * @author Christian Tzolov + * @since 1.0.0 + */ +@SpringBootTest(classes = NvidiaWithOpenAiChatModelIT.Config.class) +@EnabledIfEnvironmentVariable(named = "NVIDIA_API_KEY", matches = ".+") +class NvidiaWithOpenAiChatModelIT { + + private static final Logger logger = LoggerFactory.getLogger(NvidiaWithOpenAiChatModelIT.class); + + private static final String NVIDIA_BASE_URL = "https://integrate.api.nvidia.com"; + + private static final String DEFAULT_NVIDIA_MODEL = "meta/llama-3.1-70b-instruct"; + + @Value("classpath:/prompts/system-message.st") + private Resource systemResource; + + @Autowired + private OpenAiChatModel chatModel; + + @Test + void roleTest() { + UserMessage userMessage = new UserMessage( + "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + ChatResponse response = chatModel.call(prompt); + assertThat(response.getResults()).hasSize(1); + assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard"); + } + + @Test + void streamRoleTest() { + UserMessage userMessage = new UserMessage( + "Tell me about 3 famous pirates from the Golden Age of Piracy and what they did."); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + Flux flux = chatModel.stream(prompt); + + List responses = flux.collectList().block(); + assertThat(responses.size()).isGreaterThan(1); + + String stitchedResponseContent = responses.stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + + assertThat(stitchedResponseContent).contains("Blackbeard"); + } + + @Test + void streamingWithTokenUsage() { + var promptOptions = OpenAiChatOptions.builder().withStreamUsage(true).withSeed(1).build(); + + var prompt = new Prompt("List two colors of the Polish flag. Be brief.", promptOptions); + + var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage(); + var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage(); + + assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0); + assertThat(streamingTokenUsage.getGenerationTokens()).isGreaterThan(0); + assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0); + + assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens()); + assertThat(streamingTokenUsage.getGenerationTokens()).isEqualTo(referenceTokenUsage.getGenerationTokens()); + assertThat(streamingTokenUsage.getTotalTokens()).isEqualTo(referenceTokenUsage.getTotalTokens()); + + } + + @Test + void listOutputConverter() { + DefaultConversionService conversionService = new DefaultConversionService(); + ListOutputConverter outputConverter = new ListOutputConverter(conversionService); + + String format = outputConverter.getFormat(); + String template = """ + List five {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "ice cream flavors", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.chatModel.call(prompt).getResult(); + + List list = outputConverter.convert(generation.getOutput().getContent()); + assertThat(list).hasSize(5); + + } + + @Test + void mapOutputConverter() { + MapOutputConverter outputConverter = new MapOutputConverter(); + + String format = outputConverter.getFormat(); + String template = """ + Provide me a List of {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "numbers from 1 to 9 under they key name 'numbers'", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + Map result = outputConverter.convert(generation.getOutput().getContent()); + assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + + } + + @Test + void beanOutputConverter() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilms.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography for a random actor. + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getContent()); + assertThat(actorsFilms.getActor()).isNotEmpty(); + } + + record ActorsFilmsRecord(String actor, List movies) { + } + + @Test + void beanOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = chatModel.call(prompt).getResult(); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getContent()); + logger.info("" + actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void beanStreamOutputConverterRecords() { + + BeanOutputConverter outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class); + + String format = outputConverter.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + + String generationTextFromStream = chatModel.stream(prompt) + .collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .filter(c -> c != null) + .collect(Collectors.joining()); + + ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream); + logger.info("" + actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void functionCallTest() { + + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = OpenAiChatOptions.builder() + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the weather in location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); + + ChatResponse response = chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + + assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15"); + } + + @Test + void streamFunctionCallTest() { + + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = OpenAiChatOptions.builder() + .withFunctionCallbacks(List.of(FunctionCallbackWrapper.builder(new MockWeatherService()) + .withName("getCurrentWeather") + .withDescription("Get the weather in location") + .withResponseConverter((response) -> "" + response.temp() + response.unit()) + .build())) + .build(); + + Flux response = chatModel.stream(new Prompt(messages, promptOptions)); + + String content = response.collectList() + .block() + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(Generation::getOutput) + .map(AssistantMessage::getContent) + .collect(Collectors.joining()); + logger.info("Response: {}", content); + + assertThat(content).contains("30", "10", "15"); + } + + @Test + void validateCallResponseMetadata() { + // @formatter:off + ChatResponse response = ChatClient.create(chatModel).prompt() + .options(OpenAiChatOptions.builder().withModel(DEFAULT_NVIDIA_MODEL).build()) + .user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did") + .call() + .chatResponse(); + // @formatter:on + + logger.info(response.toString()); + assertThat(response.getMetadata().getId()).isNotEmpty(); + assertThat(response.getMetadata().getModel()).containsIgnoringCase(DEFAULT_NVIDIA_MODEL); + assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive(); + assertThat(response.getMetadata().getUsage().getGenerationTokens()).isPositive(); + assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public OpenAiApi chatCompletionApi() { + return new OpenAiApi(NVIDIA_BASE_URL, System.getenv("NVIDIA_API_KEY")); + } + + @Bean + public OpenAiChatModel openAiClient(OpenAiApi openAiApi) { + return new OpenAiChatModel(openAiApi, + OpenAiChatOptions.builder().withMaxTokens(2048).withModel(DEFAULT_NVIDIA_MODEL).build()); + } + + } + +} \ No newline at end of file diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-nvidia-function-calling.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-nvidia-function-calling.jpg new file mode 100644 index 0000000000..85ca19fb96 Binary files /dev/null and b/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-nvidia-function-calling.jpg differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-nvidia-llm-api.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-nvidia-llm-api.jpg new file mode 100644 index 0000000000..cf2631b2b3 Binary files /dev/null and b/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-nvidia-llm-api.jpg differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-nvidia-registration.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-nvidia-registration.jpg new file mode 100644 index 0000000000..0e5af03ac6 Binary files /dev/null and b/spring-ai-docs/src/main/antora/modules/ROOT/images/spring-ai-nvidia-registration.jpg differ diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc index 0676c38d74..04ee9cb0c3 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/nav.adoc @@ -27,6 +27,7 @@ **** xref:api/chat/functions/minimax-chat-functions.adoc[Function Calling] *** xref:api/chat/moonshot-chat.adoc[Moonshot AI] //// **** xref:api/chat/functions/moonshot-chat-functions.adoc[Function Calling] +*** xref:api/chat/nvidia-chat.adoc[NVIDIA] *** xref:api/chat/ollama-chat.adoc[Ollama] **** xref:api/chat/functions/ollama-chat-functions.adoc[Function Calling] *** xref:api/chat/openai-chat.adoc[OpenAI] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc index 56d35e0733..76a733f0fd 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/groq-chat.adoc @@ -1,6 +1,6 @@ = Groq Chat -https://groq.com/[Groq] is an extreally fast, LPU™ based, AI Inference Engine that support various https://console.groq.com/docs/models[AI Models], +https://groq.com/[Groq] is an extremely fast, LPU™ based, AI Inference Engine that support various https://console.groq.com/docs/models[AI Models], supports `Tool/Function Calling` and exposes a `OpenAI API` compatible endpoint. Spring AI integrates with the https://groq.com/[Groq] by reusing the existing xref::api/chat/openai-chat.adoc[OpenAI] client. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc new file mode 100644 index 0000000000..fc1a33697c --- /dev/null +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/nvidia-chat.adoc @@ -0,0 +1,251 @@ += NVIDIA Chat + +https://docs.api.nvidia.com/nim/reference/llm-apis[NVIDIA LLM API] is a proxy AI Inference Engine offering a wide range of models from link:https://docs.api.nvidia.com/nim/reference/llm-apis#models[various providers]. + +Spring AI integrates with the NVIDIA LLM API by reusing the existing xref::api/chat/openai-chat.adoc[OpenAI] client. +For this you need to set the base-url to `https://integrate.api.nvidia.com`, select one of the provided https://docs.api.nvidia.com/nim/reference/llm-apis#model[LLM models] and get an `api-key` for it. + +image::spring-ai-nvidia-llm-api.jpg[w=800,align="center"] + +NOTE: NVIDIA LLM API requires the `max-token` parameter to be explicitly set or server error will be thrown. + +Check the https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/NvidiaWithOpenAiChatModelIT.java[NvidiaWithOpenAiChatModelIT.java] tests +for examples of using NVIDIA LLM API with Spring AI. + +== Prerequisite + +* Create link:https://build.nvidia.com/explore/discover[NVIDIA] account with sufficient credits. +* Select a LLM Model to use. For example the `meta/llama-3.1-70b-instruct` in the screenshot below. +* From the selected model's page, you can get the `api-key` for accessing this model. + +image::spring-ai-nvidia-registration.jpg[w=800,align="center"] + +== Auto-configuration + +Spring AI provides Spring Boot auto-configuration for the OpenAI Chat Client. +To enable it add the following dependency to your project's Maven `pom.xml` file: + +[source, xml] +---- + + org.springframework.ai + spring-ai-openai-spring-boot-starter + +---- + +or to your Gradle `build.gradle` build file. + +[source,groovy] +---- +dependencies { + implementation 'org.springframework.ai:spring-ai-openai-spring-boot-starter' +} +---- + +TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Management] section to add the Spring AI BOM to your build file. + +=== Chat Properties + +==== Retry Properties + +The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the OpenAI chat model. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10 +| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec. +| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5 +| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min. +| spring.ai.retry.on-client-errors | If false, throw a NonTransientAiException, and do not attempt retry for `4xx` client error codes | false +| spring.ai.retry.exclude-on-http-codes | List of HTTP status codes that should not trigger a retry (e.g. to throw NonTransientAiException). | empty +| spring.ai.retry.on-http-codes | List of HTTP status codes that should trigger a retry (e.g. to throw TransientAiException). | empty +|==== + +==== Connection Properties + +The prefix `spring.ai.openai` is used as the property prefix that lets you connect to OpenAI. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.openai.base-url | The URL to connect to. Must be set to `https://integrate.api.nvidia.com` | - +| spring.ai.openai.api-key | The NVIDIA API Key | - +|==== + +==== Configuration Properties + +The prefix `spring.ai.openai.chat` is the property prefix that lets you configure the chat model implementation for OpenAI. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.openai.chat.enabled | Enable OpenAI chat model. | true +| spring.ai.openai.chat.base-url | Optional overrides the spring.ai.openai.base-url to provide chat specific url. Must be set to `https://integrate.api.nvidia.com` | - +| spring.ai.openai.chat.api-key | Optional overrides the spring.ai.openai.api-key to provide chat specific api-key | - +| spring.ai.openai.chat.options.model | The link:https://docs.api.nvidia.com/nim/reference/llm-apis#models[NVIDIA LLM model] to use | - +| spring.ai.openai.chat.options.temperature | The sampling temperature to use that controls the apparent creativity of generated completions. Higher values will make output more random while lower values will make results more focused and deterministic. It is not recommended to modify temperature and top_p for the same completions request as the interaction of these two settings is difficult to predict. | 0.8 +| spring.ai.openai.chat.options.frequencyPenalty | Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. | 0.0f +| spring.ai.openai.chat.options.maxTokens | The maximum number of tokens to generate in the chat completion. The total length of input tokens and generated tokens is limited by the model's context length. | - +| spring.ai.openai.chat.options.n | How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs. | 1 +| spring.ai.openai.chat.options.presencePenalty | Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. | - +| spring.ai.openai.chat.options.responseFormat | An object specifying the format that the model must output. Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.| - +| spring.ai.openai.chat.options.seed | This feature is in Beta. If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same seed and parameters should return the same result. | - +| spring.ai.openai.chat.options.stop | Up to 4 sequences where the API will stop generating further tokens. | - +| spring.ai.openai.chat.options.topP | An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both. | - +| spring.ai.openai.chat.options.tools | A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. | - +| spring.ai.openai.chat.options.toolChoice | Controls which (if any) function is called by the model. none means the model will not call a function and instead generates a message. auto means the model can pick between generating a message or calling a function. Specifying a particular function via {"type: "function", "function": {"name": "my_function"}} forces the model to call that function. none is the default when no functions are present. auto is the default if functions are present. | - +| spring.ai.openai.chat.options.user | A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. | - +| spring.ai.openai.chat.options.functions | List of functions, identified by their names, to enable for function calling in a single prompt requests. Functions with those names must exist in the functionCallbacks registry. | - +| spring.ai.openai.chat.options.stream-usage | (For streaming only) Set to add an additional chunk with token usage statistics for the entire request. The `choices` field for this chunk is an empty array and all other chunks will also include a usage field, but with a null value. | false +|==== + +TIP: All properties prefixed with `spring.ai.openai.chat.options` can be overridden at runtime by adding a request specific <> to the `Prompt` call. + +== Runtime Options [[chat-options]] + +The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java[OpenAiChatOptions.java] provides model configurations, such as the model to use, the temperature, the frequency penalty, etc. + +On start-up, the default options can be configured with the `OpenAiChatModel(api, options)` constructor or the `spring.ai.openai.chat.options.*` properties. + +At run-time you can override the default options by adding new, request specific, options to the `Prompt` call. +For example to override the default model and temperature for a specific request: + +[source,java] +---- +ChatResponse response = chatModel.call( + new Prompt( + "Generate the names of 5 famous pirates.", + OpenAiChatOptions.builder() + .withModel("mixtral-8x7b-32768") + .withTemperature(0.4) + .build() + )); +---- + +TIP: In addition to the model specific https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java[OpenAiChatOptions] you can use a portable https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java[ChatOptions] instance, created with the https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptionsBuilder.java[ChatOptionsBuilder#builder()]. + +== Function Calling + +NVIDIA LLM API supports Tool/Function calling when selecting a model that supports it. + +image::spring-ai-nvidia-function-calling.jpg[w=800,align="center"] + +You can register custom Java functions with your ChatModel and have the provided model intelligently choose to output a JSON object containing arguments to call one or many of the registered functions. +This is a powerful technique to connect the LLM capabilities with external tools and APIs. + +=== Tool Example + +Here's a simple example of how to use NVIDIA LLM API function calling with Spring AI: + +[source,application.properties] +---- +spring.ai.openai.api-key=${NVIDIA_API_KEY} +spring.ai.openai.base-url=https://integrate.api.nvidia.com +spring.ai.openai.chat.options.model=meta/llama-3.1-70b-instruct +spring.ai.openai.chat.options.max-tokens=2048 +---- + +[source,java] +---- +@SpringBootApplication +public class NvidiaLlmApplication { + + public static void main(String[] args) { + SpringApplication.run(NvidiaLlmApplication.class, args); + } + + @Bean + CommandLineRunner runner(ChatClient.Builder chatClientBuilder) { + return args -> { + var chatClient = chatClientBuilder.build(); + + var response = chatClient.prompt() + .user("What is the weather in Amsterdam and Paris?") + .functions("weatherFunction") // reference by bean name. + .call() + .content(); + + System.out.println(response); + }; + } + + @Bean + @Description("Get the weather in location") + public Function weatherFunction() { + return new MockWeatherService(); + } + + public static class MockWeatherService implements Function { + + public record WeatherRequest(String location, String unit) {} + public record WeatherResponse(double temp, String unit) {} + + @Override + public WeatherResponse apply(WeatherRequest request) { + double temperature = request.location().contains("Amsterdam") ? 20 : 25; + return new WeatherResponse(temperature, request.unit); + } + } +} +---- + +In this example, when the model needs weather information, it will automatically call the `weatherFunction` bean, which can then fetch real-time weather data. +The expected response looks like this: "The weather in Amsterdam is currently 20 degrees Celsius, and the weather in Paris is currently 25 degrees Celsius." + +Read more about OpenAI link:https://docs.spring.io/spring-ai/reference/api/chat/functions/openai-chat-functions.html[Function Calling]. + + +== Sample Controller + +https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-openai-spring-boot-starter` to your pom (or gradle) dependencies. + +Add a `application.properties` file, under the `src/main/resources` directory, to enable and configure the OpenAi chat model: + +[source,application.properties] +---- +spring.ai.openai.api-key=${NVIDIA_API_KEY} +spring.ai.openai.base-url=https://integrate.api.nvidia.com +spring.ai.openai.chat.options.model=meta/llama-3.1-70b-instruct + +# The NVIDIA LLM API doesn't support embeddings, so we need to disable it. +spring.ai.openai.embedding.enabled=false + +# The NVIDIA LLM API requires this parameter to be set explicitly or server internal error will be thrown. +spring.ai.openai.chat.options.max-tokens=2048 +---- + +TIP: replace the `api-key` with your NVIDIA credentials. + +NOTE: NVIDIA LLM API requires the `max-token` parameter to be explicitly set or server error will be thrown. + + +Here is an example of a simple `@Controller` class that uses the chat model for text generations. + +[source,java] +---- +@RestController +public class ChatController { + + private final OpenAiChatModel chatModel; + + @Autowired + public ChatController(OpenAiChatModel chatModel) { + this.chatModel = chatModel; + } + + @GetMapping("/ai/generate") + public Map generate(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + return Map.of("generation", chatModel.call(message)); + } + + @GetMapping("/ai/generateStream") + public Flux generateStream(@RequestParam(value = "message", defaultValue = "Tell me a joke") String message) { + Prompt prompt = new Prompt(new UserMessage(message)); + return chatModel.stream(prompt); + } +} +----