Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
pgerhard committed May 29, 2024
1 parent ac91302 commit fe827f9
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,11 @@

package org.springframework.ai.openai.chat;

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.AdvisedRequest;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.RequestResponseAdvisor;
Expand All @@ -46,6 +39,12 @@
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Description;
import org.springframework.core.ParameterizedTypeReference;
import reactor.core.publisher.Flux;

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.assertj.core.api.Assertions.assertThat;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.Function;

import org.springframework.ai.chat.client.resolver.BeanNameResolver;
import org.springframework.ai.chat.client.resolver.SimpleNameResolver;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.context.support.GenericApplicationContext;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.messages.Media;
Expand Down Expand Up @@ -64,21 +69,27 @@
*/
public interface ChatClient {

static ChatClient create(ChatModel chatModel) {
return builder(chatModel).build();
static ChatClient create(ChatModel chatModel, BeanNameResolver beanNameResolver) {
return builder(chatModel, beanNameResolver).build();
}

// QUESTION: Should this constructor be removed in favour of the one requiring the
// beanNameResolver?
static Builder builder(ChatModel chatModel) {
return new Builder(chatModel);
return new Builder(chatModel, new SimpleNameResolver());
}

static Builder builder(ChatModel chatModel, BeanNameResolver beanNameResolver) {
return new Builder(chatModel, beanNameResolver);
}

ChatClientRequest prompt();

ChatClientPromptRequest prompt(Prompt prompt);

/**
* Return a {@link ChatClient.Builder} to create a new {@link ChatClient} whose
* settings are replicated from the default {@link ChatClientRequest} of this client.
* Return a {@link Builder} to create a new {@link ChatClient} whose settings are
* replicated from the default {@link ChatClientRequest} of this client.
*/
Builder mutate();

Expand Down Expand Up @@ -263,16 +274,20 @@ class ChatClientRequest {

private final Map<String, Object> advisorParams = new HashMap<>();

private final BeanNameResolver beanNameResolver;

/* copy constructor */
ChatClientRequest(ChatClientRequest ccr) {
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.systemText, ccr.systemParams, ccr.functionCallbacks,
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams);
ccr.messages, ccr.functionNames, ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams,
ccr.beanNameResolver);
}

public ChatClientRequest(ChatModel chatModel, String userText, Map<String, Object> userParams,
String systemText, Map<String, Object> systemParams, List<FunctionCallback> functionCallbacks,
List<Message> messages, List<String> functionNames, List<Media> media, ChatOptions chatOptions,
List<RequestResponseAdvisor> advisors, Map<String, Object> advisorParams) {
List<RequestResponseAdvisor> advisors, Map<String, Object> advisorParams,
BeanNameResolver beanNameResolver) {

this.chatModel = chatModel;
this.chatOptions = chatOptions != null ? chatOptions : chatModel.getDefaultOptions();
Expand All @@ -288,14 +303,15 @@ public ChatClientRequest(ChatModel chatModel, String userText, Map<String, Objec
this.media.addAll(media);
this.advisors.addAll(advisors);
this.advisorParams.putAll(advisorParams);
this.beanNameResolver = beanNameResolver;
}

/**
* Return a {@code ChatClient.Builder} to create a new {@code ChatClient} whose
* settings are replicated from this {@code ChatClientRequest}.
*/
public Builder mutate() {
Builder builder = ChatClient.builder(chatModel)
Builder builder = ChatClient.builder(chatModel, beanNameResolver)
.defaultSystem(s -> s.text(this.systemText).params(this.systemParams))
.defaultUser(u -> u.text(this.userText)
.params(this.userParams)
Expand Down Expand Up @@ -365,6 +381,11 @@ public <I, O> ChatClientRequest function(String name, String description,
return this;
}

public <T> ChatClientRequest functionBean(Class<T> functionBeanType) {
this.functionNames.add(beanNameResolver.resolveName(functionBeanType));
return this;
}

public ChatClientRequest functions(String... functionBeanNames) {
Assert.notNull(functionBeanNames, "the functionBeanNames must be non-null");
this.functionNames.addAll(List.of(functionBeanNames));
Expand Down Expand Up @@ -528,7 +549,7 @@ private static ChatClientRequest adviseOnRequest(ChatClientRequest inputRequest,
adviseRequest.userParams(), adviseRequest.systemText(), adviseRequest.systemParams(),
adviseRequest.functionCallbacks(), adviseRequest.messages(), adviseRequest.functionNames(),
adviseRequest.media(), adviseRequest.chatOptions(), adviseRequest.advisors(),
adviseRequest.advisorParams());
adviseRequest.advisorParams(), inputRequest.beanNameResolver);
}

return advisedRequest;
Expand Down Expand Up @@ -735,11 +756,11 @@ class Builder {

private final ChatModel chatModel;

Builder(ChatModel chatModel) {
Builder(ChatModel chatModel, BeanNameResolver beanNameResolver) {
Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null");
this.chatModel = chatModel;
this.defaultRequest = new ChatClientRequest(chatModel, "", Map.of(), "", Map.of(), List.of(), List.of(),
List.of(), List.of(), null, List.of(), Map.of());
List.of(), List.of(), null, List.of(), Map.of(), beanNameResolver);
}

public Builder defaultAdvisors(RequestResponseAdvisor... advisor) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.springframework.ai.chat.client.resolver;

import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.util.Assert;

public class ApplicationContextBeanNameResolver implements BeanNameResolver {

private final GenericApplicationContext context;

public ApplicationContextBeanNameResolver(GenericApplicationContext context) {
this.context = context;
}

@Override
public <T> String resolveName(Class<T> beanType) {
String[] namesForType = context.getBeanNamesForType(beanType);
Assert.isTrue(namesForType.length == 1, "A bean must have a unique definiton");

/*
* The following snippet could be used in other places to resolve the description
* from a function registered as a bean BeanDefinition beanDefinition =
* context.getBeanDefinition(namesForType[0]); String description =
* beanDefinition.getDescription();
*/

return namesForType[0];
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package org.springframework.ai.chat.client.resolver;

public interface BeanNameResolver {

<T> String resolveName(Class<T> beanType);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package org.springframework.ai.chat.client.resolver;

public class SimpleNameResolver implements BeanNameResolver {

@Override
public <T> String resolveName(Class<T> beanType) {
return beanType.getSimpleName().substring(0, 1).toLowerCase() + beanType.getSimpleName().substring(1);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.autoconfigure.chat.client;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.resolver.ApplicationContextBeanNameResolver;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.client.ChatClientCustomizer;
import org.springframework.beans.factory.ObjectProvider;
Expand All @@ -28,6 +29,7 @@
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Scope;
import org.springframework.context.support.GenericApplicationContext;

/**
* {@link EnableAutoConfiguration Auto-configuration} for {@link ChatClient}.
Expand Down Expand Up @@ -60,8 +62,9 @@ ChatClientBuilderConfigurer chatClientBuilderConfigurer(ObjectProvider<ChatClien
@Bean
@Scope("prototype")
@ConditionalOnMissingBean
ChatClient.Builder chatClientBuilder(ChatClientBuilderConfigurer chatClientBuilderConfigurer, ChatModel chatModel) {
ChatClient.Builder builder = ChatClient.builder(chatModel);
ChatClient.Builder chatClientBuilder(ChatClientBuilderConfigurer chatClientBuilderConfigurer, ChatModel chatModel,
GenericApplicationContext context) {
ChatClient.Builder builder = ChatClient.builder(chatModel, new ApplicationContextBeanNameResolver(context));
return chatClientBuilderConfigurer.configure(builder);
}

Expand Down

0 comments on commit fe827f9

Please sign in to comment.