Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
pgerhard committed May 29, 2024
1 parent ac91302 commit 6fa6f9f
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,11 @@

package org.springframework.ai.openai.chat;

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.AdvisedRequest;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.RequestResponseAdvisor;
Expand All @@ -46,6 +39,12 @@
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Description;
import org.springframework.core.ParameterizedTypeReference;
import reactor.core.publisher.Flux;

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.assertj.core.api.Assertions.assertThat;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.Function;

import org.springframework.ai.chat.client.resolver.BeanNameResolver;
import org.springframework.ai.chat.client.resolver.SimpleNameResolver;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.context.support.GenericApplicationContext;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.messages.Media;
Expand Down Expand Up @@ -64,20 +69,25 @@
*/
public interface ChatClient {

static ChatClient create(ChatModel chatModel) {
return builder(chatModel).build();
static ChatClient create(ChatModel chatModel, BeanNameResolver beanNameResolver) {
return builder(chatModel, beanNameResolver).build();
}

// QUESTION: Should this constructor be removed in favour of the one requiring the beanNameResolver?
static Builder builder(ChatModel chatModel) {
return new Builder(chatModel);
return new Builder(chatModel, new SimpleNameResolver());
}

static Builder builder(ChatModel chatModel, BeanNameResolver beanNameResolver) {
return new Builder(chatModel, beanNameResolver);
}

ChatClientRequest prompt();

ChatClientPromptRequest prompt(Prompt prompt);

/**
* Return a {@link ChatClient.Builder} to create a new {@link ChatClient} whose
* Return a {@link Builder} to create a new {@link ChatClient} whose
* settings are replicated from the default {@link ChatClientRequest} of this client.
*/
Builder mutate();
Expand Down Expand Up @@ -263,16 +273,18 @@ class ChatClientRequest {

private final Map<String, Object> advisorParams = new HashMap<>();

private final BeanNameResolver beanNameResolver;

/* copy constructor */
ChatClientRequest(ChatClientRequest ccr) {
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks,
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams);
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams, ccr.beanNameResolver);
}

public ChatClientRequest(ChatModel chatModel, String userText, Map<String, Object> userParams,
String systemText, Map<String, Object> systemParams, List<FunctionCallback> functionCallbacks,
List<Message> messages, List<String> functionNames, List<Media> media, ChatOptions chatOptions,
List<RequestResponseAdvisor> advisors, Map<String, Object> advisorParams) {
List<RequestResponseAdvisor> advisors, Map<String, Object> advisorParams, BeanNameResolver beanNameResolver) {

this.chatModel = chatModel;
this.chatOptions = chatOptions != null ? chatOptions : chatModel.getDefaultOptions();
Expand All @@ -288,14 +300,15 @@ public ChatClientRequest(ChatModel chatModel, String userText, Map<String, Objec
this.media.addAll(media);
this.advisors.addAll(advisors);
this.advisorParams.putAll(advisorParams);
this.beanNameResolver = beanNameResolver;
}

/**
* Return a {@code ChatClient.Builder} to create a new {@code ChatClient} whose
* settings are replicated from this {@code ChatClientRequest}.
*/
public Builder mutate() {
Builder builder = ChatClient.builder(chatModel)
Builder builder = ChatClient.builder(chatModel, beanNameResolver)
.defaultSystem(s -> s.text(this.systemText).params(this.systemParams))
.defaultUser(u -> u.text(this.userText)
.params(this.userParams)
Expand Down Expand Up @@ -365,6 +378,11 @@ public <I, O> ChatClientRequest function(String name, String description,
return this;
}

public <T> ChatClientRequest functionBean(Class<T> functionBeanType) {
this.functionNames.add(beanNameResolver.resolveName(functionBeanType));
return this;
}

public ChatClientRequest functions(String... functionBeanNames) {
Assert.notNull(functionBeanNames, "the functionBeanNames must be non-null");
this.functionNames.addAll(List.of(functionBeanNames));
Expand Down Expand Up @@ -528,7 +546,7 @@ private static ChatClientRequest adviseOnRequest(ChatClientRequest inputRequest,
adviseRequest.userParams(), adviseRequest.systemText(), adviseRequest.systemParams(),
adviseRequest.functionCallbacks(), adviseRequest.messages(), adviseRequest.functionNames(),
adviseRequest.media(), adviseRequest.chatOptions(), adviseRequest.advisors(),
adviseRequest.advisorParams());
adviseRequest.advisorParams(), inputRequest.beanNameResolver);
}

return advisedRequest;
Expand Down Expand Up @@ -735,12 +753,12 @@ class Builder {

private final ChatModel chatModel;

Builder(ChatModel chatModel) {
Builder(ChatModel chatModel, BeanNameResolver beanNameResolver) {
Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null");
this.chatModel = chatModel;
this.defaultRequest = new ChatClientRequest(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(),
List.of(), List.of(), null, List.of(), Map.of());
}
List.of(), List.of(), null, List.of(), Map.of(), beanNameResolver);
}

public Builder defaultAdvisors(RequestResponseAdvisor... advisor) {
this.defaultRequest.advisors(advisor);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package org.springframework.ai.chat.client.resolver;

import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.util.Assert;

public class ApplicationContextBeanNameResolver implements BeanNameResolver {
private final GenericApplicationContext context;

public ApplicationContextBeanNameResolver(GenericApplicationContext context) {
this.context = context;
}

@Override
public <T> String resolveName(Class<T> beanType) {
String[] namesForType = context.getBeanNamesForType(beanType);
Assert.isTrue(namesForType.length == 1, "A bean must have a unique definiton");

/* The following snippet could be used in other places to resolve the description from a function registered as a bean
BeanDefinition beanDefinition = context.getBeanDefinition(namesForType[0]);
String description = beanDefinition.getDescription();
*/

return namesForType[0];
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package org.springframework.ai.chat.client.resolver;

public interface BeanNameResolver {

<T> String resolveName(Class<T> beanType);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package org.springframework.ai.chat.client.resolver;

public class SimpleNameResolver implements BeanNameResolver {
@Override
public <T> String resolveName(Class<T> beanType) {
return beanType.getSimpleName().substring(0, 1).toLowerCase()
+ beanType.getSimpleName().substring(1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.autoconfigure.chat.client;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.resolver.ApplicationContextBeanNameResolver;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.client.ChatClientCustomizer;
import org.springframework.beans.factory.ObjectProvider;
Expand All @@ -28,6 +29,7 @@
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Scope;
import org.springframework.context.support.GenericApplicationContext;

/**
* {@link EnableAutoConfiguration Auto-configuration} for {@link ChatClient}.
Expand Down Expand Up @@ -60,8 +62,8 @@ ChatClientBuilderConfigurer chatClientBuilderConfigurer(ObjectProvider<ChatClien
@Bean
@Scope("prototype")
@ConditionalOnMissingBean
ChatClient.Builder chatClientBuilder(ChatClientBuilderConfigurer chatClientBuilderConfigurer, ChatModel chatModel) {
ChatClient.Builder builder = ChatClient.builder(chatModel);
ChatClient.Builder chatClientBuilder(ChatClientBuilderConfigurer chatClientBuilderConfigurer, ChatModel chatModel, GenericApplicationContext context) {
ChatClient.Builder builder = ChatClient.builder(chatModel, new ApplicationContextBeanNameResolver(context));
return chatClientBuilderConfigurer.configure(builder);
}

Expand Down

0 comments on commit 6fa6f9f

Please sign in to comment.