From c67442d605006ea75e7ebf45400dfa5b6547464f Mon Sep 17 00:00:00 2001 From: Soby Chacko Date: Thu, 19 Sep 2024 15:38:50 -0400 Subject: [PATCH] Add custom header support for Azure OpenAI - Adds configuration properties to allow custom header specification - Implements mechanism to apply custom headers to Azure OpenAI requests - Enhances flexibility for users to customize API interactions These changes allow users to add necessary headers for authentication, tracking, or other purposes when interacting with Azure OpenAI services. Resolves https://github.com/spring-projects/spring-ai/issues/1284 --- .../openai/AzureOpenAiAutoConfiguration.java | 12 ++++++++++-- .../openai/AzureOpenAiConnectionProperties.java | 16 +++++++++++++++- .../azure/AzureOpenAiAutoConfigurationIT.java | 12 +++--------- 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java index 407efb82ef..6f43eceae8 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java @@ -16,6 +16,8 @@ package org.springframework.ai.autoconfigure.azure.openai; import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionModel; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; @@ -40,6 +42,7 @@ import com.azure.core.credential.KeyCredential; import com.azure.core.credential.TokenCredential; import com.azure.core.util.ClientOptions; +import com.azure.core.util.Header; /** * @author Piotr Olaszewski @@ -57,14 +60,19 @@ public class AzureOpenAiAutoConfiguration { @Bean @ConditionalOnMissingBean({ OpenAIClient.class, TokenCredential.class }) public OpenAIClient openAIClient(AzureOpenAiConnectionProperties connectionProperties) { - if (StringUtils.hasText(connectionProperties.getApiKey())) { Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty"); + Map customHeaders = connectionProperties.getCustomHeaders(); + List
headers = customHeaders.entrySet() + .stream() + .map(entry -> new Header(entry.getKey(), entry.getValue())) + .collect(Collectors.toList()); + ClientOptions clientOptions = new ClientOptions().setApplicationId(APPLICATION_ID).setHeaders(headers); return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint()) .credential(new AzureKeyCredential(connectionProperties.getApiKey())) - .clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID)) + .clientOptions(clientOptions) .buildClient(); } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiConnectionProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiConnectionProperties.java index cabd6b2e75..16a128260e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiConnectionProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiConnectionProperties.java @@ -1,5 +1,5 @@ /* - * Copyright 2023 - 2024 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,8 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.azure.openai; +import java.util.HashMap; +import java.util.Map; + import org.springframework.boot.context.properties.ConfigurationProperties; @ConfigurationProperties(AzureOpenAiConnectionProperties.CONFIG_PREFIX) @@ -40,6 +44,8 @@ public class AzureOpenAiConnectionProperties { */ private String endpoint; + private Map customHeaders = new HashMap<>(); + public String getEndpoint() { return this.endpoint; } @@ -64,4 +70,12 @@ public void setOpenAiApiKey(String openAiApiKey) { this.openAiApiKey = openAiApiKey; } + public Map getCustomHeaders() { + return customHeaders; + } + + public void setCustomHeaders(Map customHeaders) { + this.customHeaders = customHeaders; + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java index cce0e9b8d4..11cd9409fb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java @@ -15,6 +15,9 @@ */ package org.springframework.ai.autoconfigure.azure; +import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.implementation.OpenAIClientImpl; +import com.azure.core.http.*; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; @@ -34,15 +37,6 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.util.ReflectionUtils; - -import com.azure.ai.openai.OpenAIClient; -import com.azure.ai.openai.implementation.OpenAIClientImpl; -import com.azure.core.http.HttpHeader; -import com.azure.core.http.HttpHeaderName; -import com.azure.core.http.HttpMethod; -import com.azure.core.http.HttpPipeline; -import com.azure.core.http.HttpRequest; -import com.azure.core.http.HttpResponse; import reactor.core.publisher.Flux; import java.lang.reflect.Field;