Skip to content

Commit

Permalink
Unit tests and handling of malformatted responses
Browse files Browse the repository at this point in the history
  • Loading branch information
chedim committed Mar 19, 2024
1 parent 8af8a7b commit 14e8f45
Show file tree
Hide file tree
Showing 10 changed files with 212 additions and 196 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,15 @@ public Consumer<Subscription> onSubscribe(ChatMessageEvent.Initiating event) {
public Action onComplete(ConversationContext ctx) {
return () -> {
SwingUtilities.invokeLater(() -> {
final List<ChatMessage> assistantMessages = toMessages(partialResponseChoices);
if (!assistantMessages.isEmpty()) {
ctx.addChatMessage(assistantMessages.get(0));
try {
final List<ChatMessage> assistantMessages = toMessages(partialResponseChoices);
if (!assistantMessages.isEmpty()) {
ctx.addChatMessage(assistantMessages.get(0));
}
listener.responseCompleted(event.responseArrived(assistantMessages));
} catch (Exception e) {

}
listener.responseCompleted(event.responseArrived(assistantMessages));
});
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
import java.util.stream.Collectors;

public class IntentProcessor {
protected static final String PROMPT_CONNECT_TO_DB = "Respond once by only telling the user that, in order to fulfill their request, they first need to connect their IDE plugin to Couchbase cluster using the 'Explorer' tab. Do not provide any other suggestions how to perform requested action in this res ...";
protected static final String PROMPT_ANSWER_USER = "Now answer the user's question in plain text given this additional information and instructions and then continue working according to the original system prompt: ";
private static String secondaryPrompt;
private static String intentPrompt;
private static Map<String, ActionInterface> loadedActions = new HashMap<>();
Expand All @@ -63,37 +65,69 @@ private ActionInterface getAction(String name) {
}

public Disposable process(ChatPanel chat, ChatMessage userMessage, JsonObject intents) {
final ChatLink link = chat.getChatLink();
var application = ApplicationManager.getApplication();
var chatCompletionRequestProvider = application.getService(ChatCompletionRequestProvider.class);
var chatCompletionRequest = chatCompletionRequestProvider.chatCompletionRequest(chat.getChatLink().getConversationContext(), userMessage)
.build();

// replace previous system prompts
chatCompletionRequest.getMessages().removeAll(chatCompletionRequest.getMessages().stream()
.skip(1)
.filter(message -> ChatMessageRole.SYSTEM.value().equals(message.getRole()))
.collect(Collectors.toList()));

String intentPrompt = generateIntentReturnPrompt(chat, userMessage, intents);
chatCompletionRequest.getMessages().add(new ChatMessage(ChatMessageRole.SYSTEM.value(), intentPrompt));

chat.getQuestion().addIntentPrompt(intentPrompt);

ChatMessageEvent.Starting event = ChatMessageEvent.starting(chat.getChatLink(), userMessage);
return application.getService(ChatGptHandler.class)
.handle(chat.getChatLink().getConversationContext(), event.initiating(chatCompletionRequest), chat.getChatLink().getListener())
.subscribeOn(Schedulers.io())
.subscribe();
}

public String generateIntentReturnPrompt(ChatPanel chat, ChatMessage userMessage, JsonObject intents) {
StringBuilder intentPrompt = new StringBuilder();

ActiveCluster activeCluster = ActiveCluster.getInstance();
QueryContext windowContext = IQWindowContent.getInstance().map(IQWindowContent::getClusterContext).orElse(null);
if (activeCluster == null || activeCluster.getCluster() == null) {
intentPrompt.append("Respond once by only telling the user that, in order to fulfill their request, they first need to connect their IDE plugin to Couchbase cluster using the 'Explorer' tab. Do not provide any other suggestions how to perform requested action in this response.");
} else if (intents.containsKey("actions")) {
return PROMPT_CONNECT_TO_DB;
}

intentPrompt.append(PROMPT_ANSWER_USER);
if (intents.containsKey("actions")) {
JsonArray detectedActions = intents.getArray("actions");
for (int i = 0; i < detectedActions.size(); i++) {
JsonObject intent = null;
if (detectedActions.get(i) instanceof String) {
intent = JsonObject.create();
intent.put("action", detectedActions.getString(i));
} else {
intent = detectedActions.getObject(i);
}
ActionInterface action = getAction(intent.getString("action"));
if (action != null) {
String bucketName = null, scopeName = null;
if (intent.containsKey("bucketName")) {
bucketName = intent.getString("bucketName");
if (intent.containsKey("scopeName")) {
scopeName = intent.getString("scopeName");
if (detectedActions != null) {
for (int i = 0; i < detectedActions.size(); i++) {
JsonObject intent = null;
if (detectedActions.get(i) instanceof String) {
intent = JsonObject.create();
intent.put("action", detectedActions.getString(i));
} else {
intent = detectedActions.getObject(i);
if (intent.size() == 0) {
continue;
}
} else if (windowContext != null) {
bucketName = windowContext.getBucket();
scopeName = windowContext.getScope();
}
String prompt = action.fire(chat.getProject(), bucketName, scopeName, intents, intent);
if (prompt != null) {
intentPrompt.append(prompt);
ActionInterface action = getAction(intent.getString("action"));
if (action != null) {
String bucketName = null, scopeName = null;
if (intent.containsKey("bucketName")) {
bucketName = intent.getString("bucketName");
if (intent.containsKey("scopeName")) {
scopeName = intent.getString("scopeName");
}
} else if (windowContext != null) {
bucketName = windowContext.getBucket();
scopeName = windowContext.getScope();
}
String prompt = action.fire(chat.getProject(), bucketName, scopeName, intents, intent);
if (prompt != null) {
intentPrompt.append(prompt);
}
}
}
}
Expand All @@ -108,27 +142,8 @@ public Disposable process(ChatPanel chat, ChatMessage userMessage, JsonObject in
}
}

intentPrompt.append("Now answer the user's question given this additional information and instructions and then continue working according to the original system prompt");
var application = ApplicationManager.getApplication();
var chatCompletionRequestProvider = application.getService(ChatCompletionRequestProvider.class);
var chatCompletionRequest = chatCompletionRequestProvider.chatCompletionRequest(link.getConversationContext(), userMessage)
.build();

// replace previous system prompts
chatCompletionRequest.getMessages().removeAll(chatCompletionRequest.getMessages().stream()
.skip(1)
.filter(message -> ChatMessageRole.SYSTEM.value().equals(message.getRole()))
.collect(Collectors.toList()));
chatCompletionRequest.getMessages().add(new ChatMessage(ChatMessageRole.SYSTEM.value(), intentPrompt.toString()));

Log.info(String.format("IQ intent prompt: %s", intentPrompt.toString()));
chat.getQuestion().addIntentPrompt(intentPrompt.toString());

ChatMessageEvent.Starting event = ChatMessageEvent.starting(link, userMessage);
return application.getService(ChatGptHandler.class)
.handle(link.getConversationContext(), event.initiating(chatCompletionRequest), link.getListener())
.subscribeOn(Schedulers.io())
.subscribe();
return intentPrompt.toString();
}

private void appendCollectionIndexes(@NotNull JsonArray collections, @NotNull StringBuilder intentPrompt) {
Expand All @@ -145,11 +160,11 @@ private void appendCollectionIndexes(@NotNull JsonArray collections, @NotNull St
scope.getChildren().stream()
.filter(collection -> collectionName.equalsIgnoreCase(collection.getName()))
.forEach(collection -> {
Collection c = cluster.bucket(bucket.getName()).scope(scope.getName()).collection(collectionName);
c.queryIndexes().getAllIndexes().forEach(index -> {
collecitonIndexes.put(index.name(), index.indexKey());
});
intentPrompt.append(String.format("Indexes on collection '%s.%s.%s': %s\n", bucket.getName(), scope.getName(), collectionName, collecitonIndexes.toString()));
Collection c = cluster.bucket(bucket.getName()).scope(scope.getName()).collection(collectionName);
c.queryIndexes().getAllIndexes().forEach(index -> {
collecitonIndexes.put(index.name(), index.indexKey());
});
intentPrompt.append(String.format("Indexes on collection '%s.%s.%s': %s\n", bucket.getName(), scope.getName(), collectionName, collecitonIndexes.toString()));
});
});
});
Expand Down
43 changes: 19 additions & 24 deletions src/main/java/com/couchbase/intellij/tree/iq/ui/ChatPanel.java
Original file line number Diff line number Diff line change
Expand Up @@ -257,15 +257,19 @@ public void responseArriving(ChatMessageEvent.ResponseArriving event) {
public void responseArrived(ChatMessageEvent.ResponseArrived event) {
}

public boolean isJsonResponse(ChatMessage message) {
return message.getContent().startsWith("{");
}

@Override
public void responseCompleted(ChatMessageEvent.ResponseArrived event) {
messageRetryCount = 0;
List<ChatMessage> response = event.getResponseChoices();
response.forEach(message -> Log.info(String.format("IQ response message: %s", message.toString())));
if (response.size() == 1 && response.get(0).getContent().startsWith("{")) {
if (response.size() == 1 && isJsonResponse(response.get(0))) {
JsonObject intents = JsonObject.fromJson(response.get(0).getContent());
if (isEmptyResponse(intents) || containsNoneIntent(intents)) {
response = Arrays.asList(new ChatMessage("assistant", FAKE_CONFUSION) );
if (isEmptyResponse(intents)) {
response = Arrays.asList(new ChatMessage("assistant", FAKE_CONFUSION));
} else {
IntentProcessor intentProcessor = ApplicationManager.getApplication().getService(IntentProcessor.class);
getQuestion().addIntentResponse(intents);
Expand All @@ -280,32 +284,23 @@ public void responseCompleted(ChatMessageEvent.ResponseArrived event) {
});
}

private boolean containsNoneIntent(JsonObject intents) {
JsonArray actions = intents.getArray("actions");
for (int i = 0; i < actions.size(); i++) {
JsonObject action = actions.getObject(i);

if (!action.containsKey("action") || action.getString("action").equalsIgnoreCase("none")) {
return true;
}
}

return false;
}

private boolean isEmptyResponse(JsonObject intents) {
return intents.size() == 0
|| !intents.containsKey("actions")
|| !(intents.get("actions") instanceof JsonArray)
|| intents.getArray("actions").size() == 0;
return intents == null || !(
intents.size() > 0 ||
intents.containsKey("collections") || (
intents.containsKey("actions") &&
(intents.get("actions") instanceof JsonArray) &&
intents.getArray("actions").size() > 0
)
);
}

public void setContent(List<ChatMessage> content) {
content.forEach(message -> message.setContent(message.getContent()
.replaceAll("/```\n\s*SELECT/gmi", "```sql\nSELECT")
.replaceAll("```\nUPDATE", "```sql\nUPDATE")
.replaceAll("```\nDELETE", "```sql\nDELETE")
.replaceAll("```\nCREATE", "```sql\nCREATE")
.replaceAll("/```\n\s*SELECT/gmi", "```sql\nSELECT")
.replaceAll("```\nUPDATE", "```sql\nUPDATE")
.replaceAll("```\nDELETE", "```sql\nDELETE")
.replaceAll("```\nCREATE", "```sql\nCREATE")
// .replaceAll("```sql", "```sqlpp")
));
TextFragment parseResult = ChatCompletionParser.parseGPT35TurboWithStream(content);
Expand Down
2 changes: 1 addition & 1 deletion src/main/resources/iq/intent_prompt.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ If you extracted any entities or identified actions or intents listed above, res
]
}

You have access to user cluster via the plugin. Respond with the JSON to query necessary information from the plugin.
You have access to user cluster via the plugin.
If user asks a question about their cluster, always respond in the JSON format listed above.
If any additional information is needed, respond in the JSON format listed above.
Do not EVER add any non-json text to your response with JSON.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import com.couchbase.intellij.tree.iq.chat.ConfigurationPage;
import com.couchbase.intellij.tree.iq.core.IQCredentials;
import com.couchbase.intellij.tree.iq.settings.OpenAISettingsState;
import com.couchbase.intellij.tree.iq.ui.ChatPanel;
import com.couchbase.intellij.workbench.Log;
import org.mockito.Mockito;

public abstract class AbstractCapellaTest extends AbstractIQTest {
@Override
Expand All @@ -30,7 +32,8 @@ protected void setUp() throws Exception {
Log.setLevel(3);
Log.setPrinter(new Log.StdoutPrinter());

link = new ChatLinkService(getProject(), null, cp);
ChatLinkService link = new ChatLinkService(getProject(), null, cp);
panel = new ChatPanel(getProject(), iqGptConfig, Mockito.mock(CapellaOrganizationList.class), Mockito.mock(CapellaOrganization.class), Mockito.mock(ChatPanel.LogoutListener.class), Mockito.mock(ChatPanel.OrganizationListener.class));
ctx = new ChatLinkState(cp);
}

Expand Down
Loading

0 comments on commit 14e8f45

Please sign in to comment.