From 14e8f452f1693d544f7be6403520fb17b6e743d2 Mon Sep 17 00:00:00 2001 From: Dmitrii Chechetkin Date: Tue, 19 Mar 2024 09:54:15 -0400 Subject: [PATCH] Unit tests and handling of malformatted responses --- .../intellij/tree/iq/chat/ChatGptHandler.java | 12 +- .../tree/iq/intents/IntentProcessor.java | 113 ++++++++------ .../intellij/tree/iq/ui/ChatPanel.java | 43 +++--- src/main/resources/iq/intent_prompt.txt | 2 +- .../intellij/tree/iq/AbstractCapellaTest.java | 5 +- .../intellij/tree/iq/AbstractIQTest.java | 140 +++++++++--------- .../tree/iq/AbstractMockedIQTest.java | 4 +- .../intents/CapellaIntentProcessorTest.java | 6 +- .../iq/intents/MockedIntentProcessorTest.java | 42 +++--- .../intents/actions/CreateCollectionTest.java | 41 +++-- 10 files changed, 212 insertions(+), 196 deletions(-) diff --git a/src/main/java/com/couchbase/intellij/tree/iq/chat/ChatGptHandler.java b/src/main/java/com/couchbase/intellij/tree/iq/chat/ChatGptHandler.java index db7f97d7..d2cd37ab 100644 --- a/src/main/java/com/couchbase/intellij/tree/iq/chat/ChatGptHandler.java +++ b/src/main/java/com/couchbase/intellij/tree/iq/chat/ChatGptHandler.java @@ -52,11 +52,15 @@ public Consumer onSubscribe(ChatMessageEvent.Initiating event) { public Action onComplete(ConversationContext ctx) { return () -> { SwingUtilities.invokeLater(() -> { - final List assistantMessages = toMessages(partialResponseChoices); - if (!assistantMessages.isEmpty()) { - ctx.addChatMessage(assistantMessages.get(0)); + try { + final List assistantMessages = toMessages(partialResponseChoices); + if (!assistantMessages.isEmpty()) { + ctx.addChatMessage(assistantMessages.get(0)); + } + listener.responseCompleted(event.responseArrived(assistantMessages)); + } catch (Exception e) { + } - listener.responseCompleted(event.responseArrived(assistantMessages)); }); }; } diff --git a/src/main/java/com/couchbase/intellij/tree/iq/intents/IntentProcessor.java b/src/main/java/com/couchbase/intellij/tree/iq/intents/IntentProcessor.java index c6f856d0..de15c966 100644 --- a/src/main/java/com/couchbase/intellij/tree/iq/intents/IntentProcessor.java +++ b/src/main/java/com/couchbase/intellij/tree/iq/intents/IntentProcessor.java @@ -39,6 +39,8 @@ import java.util.stream.Collectors; public class IntentProcessor { + protected static final String PROMPT_CONNECT_TO_DB = "Respond once by only telling the user that, in order to fulfill their request, they first need to connect their IDE plugin to Couchbase cluster using the 'Explorer' tab. Do not provide any other suggestions how to perform requested action in this res ..."; + protected static final String PROMPT_ANSWER_USER = "Now answer the user's question in plain text given this additional information and instructions and then continue working according to the original system prompt: "; private static String secondaryPrompt; private static String intentPrompt; private static Map loadedActions = new HashMap<>(); @@ -63,37 +65,69 @@ private ActionInterface getAction(String name) { } public Disposable process(ChatPanel chat, ChatMessage userMessage, JsonObject intents) { - final ChatLink link = chat.getChatLink(); + var application = ApplicationManager.getApplication(); + var chatCompletionRequestProvider = application.getService(ChatCompletionRequestProvider.class); + var chatCompletionRequest = chatCompletionRequestProvider.chatCompletionRequest(chat.getChatLink().getConversationContext(), userMessage) + .build(); + + // replace previous system prompts + chatCompletionRequest.getMessages().removeAll(chatCompletionRequest.getMessages().stream() + .skip(1) + .filter(message -> ChatMessageRole.SYSTEM.value().equals(message.getRole())) + .collect(Collectors.toList())); + + String intentPrompt = generateIntentReturnPrompt(chat, userMessage, intents); + chatCompletionRequest.getMessages().add(new ChatMessage(ChatMessageRole.SYSTEM.value(), intentPrompt)); + + chat.getQuestion().addIntentPrompt(intentPrompt); + + ChatMessageEvent.Starting event = ChatMessageEvent.starting(chat.getChatLink(), userMessage); + return application.getService(ChatGptHandler.class) + .handle(chat.getChatLink().getConversationContext(), event.initiating(chatCompletionRequest), chat.getChatLink().getListener()) + .subscribeOn(Schedulers.io()) + .subscribe(); + } + + public String generateIntentReturnPrompt(ChatPanel chat, ChatMessage userMessage, JsonObject intents) { StringBuilder intentPrompt = new StringBuilder(); + ActiveCluster activeCluster = ActiveCluster.getInstance(); QueryContext windowContext = IQWindowContent.getInstance().map(IQWindowContent::getClusterContext).orElse(null); if (activeCluster == null || activeCluster.getCluster() == null) { - intentPrompt.append("Respond once by only telling the user that, in order to fulfill their request, they first need to connect their IDE plugin to Couchbase cluster using the 'Explorer' tab. Do not provide any other suggestions how to perform requested action in this response."); - } else if (intents.containsKey("actions")) { + return PROMPT_CONNECT_TO_DB; + } + + intentPrompt.append(PROMPT_ANSWER_USER); + if (intents.containsKey("actions")) { JsonArray detectedActions = intents.getArray("actions"); - for (int i = 0; i < detectedActions.size(); i++) { - JsonObject intent = null; - if (detectedActions.get(i) instanceof String) { - intent = JsonObject.create(); - intent.put("action", detectedActions.getString(i)); - } else { - intent = detectedActions.getObject(i); - } - ActionInterface action = getAction(intent.getString("action")); - if (action != null) { - String bucketName = null, scopeName = null; - if (intent.containsKey("bucketName")) { - bucketName = intent.getString("bucketName"); - if (intent.containsKey("scopeName")) { - scopeName = intent.getString("scopeName"); + if (detectedActions != null) { + for (int i = 0; i < detectedActions.size(); i++) { + JsonObject intent = null; + if (detectedActions.get(i) instanceof String) { + intent = JsonObject.create(); + intent.put("action", detectedActions.getString(i)); + } else { + intent = detectedActions.getObject(i); + if (intent.size() == 0) { + continue; } - } else if (windowContext != null) { - bucketName = windowContext.getBucket(); - scopeName = windowContext.getScope(); } - String prompt = action.fire(chat.getProject(), bucketName, scopeName, intents, intent); - if (prompt != null) { - intentPrompt.append(prompt); + ActionInterface action = getAction(intent.getString("action")); + if (action != null) { + String bucketName = null, scopeName = null; + if (intent.containsKey("bucketName")) { + bucketName = intent.getString("bucketName"); + if (intent.containsKey("scopeName")) { + scopeName = intent.getString("scopeName"); + } + } else if (windowContext != null) { + bucketName = windowContext.getBucket(); + scopeName = windowContext.getScope(); + } + String prompt = action.fire(chat.getProject(), bucketName, scopeName, intents, intent); + if (prompt != null) { + intentPrompt.append(prompt); + } } } } @@ -108,27 +142,8 @@ public Disposable process(ChatPanel chat, ChatMessage userMessage, JsonObject in } } - intentPrompt.append("Now answer the user's question given this additional information and instructions and then continue working according to the original system prompt"); - var application = ApplicationManager.getApplication(); - var chatCompletionRequestProvider = application.getService(ChatCompletionRequestProvider.class); - var chatCompletionRequest = chatCompletionRequestProvider.chatCompletionRequest(link.getConversationContext(), userMessage) - .build(); - - // replace previous system prompts - chatCompletionRequest.getMessages().removeAll(chatCompletionRequest.getMessages().stream() - .skip(1) - .filter(message -> ChatMessageRole.SYSTEM.value().equals(message.getRole())) - .collect(Collectors.toList())); - chatCompletionRequest.getMessages().add(new ChatMessage(ChatMessageRole.SYSTEM.value(), intentPrompt.toString())); - Log.info(String.format("IQ intent prompt: %s", intentPrompt.toString())); - chat.getQuestion().addIntentPrompt(intentPrompt.toString()); - - ChatMessageEvent.Starting event = ChatMessageEvent.starting(link, userMessage); - return application.getService(ChatGptHandler.class) - .handle(link.getConversationContext(), event.initiating(chatCompletionRequest), link.getListener()) - .subscribeOn(Schedulers.io()) - .subscribe(); + return intentPrompt.toString(); } private void appendCollectionIndexes(@NotNull JsonArray collections, @NotNull StringBuilder intentPrompt) { @@ -145,11 +160,11 @@ private void appendCollectionIndexes(@NotNull JsonArray collections, @NotNull St scope.getChildren().stream() .filter(collection -> collectionName.equalsIgnoreCase(collection.getName())) .forEach(collection -> { - Collection c = cluster.bucket(bucket.getName()).scope(scope.getName()).collection(collectionName); - c.queryIndexes().getAllIndexes().forEach(index -> { - collecitonIndexes.put(index.name(), index.indexKey()); - }); - intentPrompt.append(String.format("Indexes on collection '%s.%s.%s': %s\n", bucket.getName(), scope.getName(), collectionName, collecitonIndexes.toString())); + Collection c = cluster.bucket(bucket.getName()).scope(scope.getName()).collection(collectionName); + c.queryIndexes().getAllIndexes().forEach(index -> { + collecitonIndexes.put(index.name(), index.indexKey()); + }); + intentPrompt.append(String.format("Indexes on collection '%s.%s.%s': %s\n", bucket.getName(), scope.getName(), collectionName, collecitonIndexes.toString())); }); }); }); diff --git a/src/main/java/com/couchbase/intellij/tree/iq/ui/ChatPanel.java b/src/main/java/com/couchbase/intellij/tree/iq/ui/ChatPanel.java index 33432af8..dd761305 100644 --- a/src/main/java/com/couchbase/intellij/tree/iq/ui/ChatPanel.java +++ b/src/main/java/com/couchbase/intellij/tree/iq/ui/ChatPanel.java @@ -257,15 +257,19 @@ public void responseArriving(ChatMessageEvent.ResponseArriving event) { public void responseArrived(ChatMessageEvent.ResponseArrived event) { } + public boolean isJsonResponse(ChatMessage message) { + return message.getContent().startsWith("{"); + } + @Override public void responseCompleted(ChatMessageEvent.ResponseArrived event) { messageRetryCount = 0; List response = event.getResponseChoices(); response.forEach(message -> Log.info(String.format("IQ response message: %s", message.toString()))); - if (response.size() == 1 && response.get(0).getContent().startsWith("{")) { + if (response.size() == 1 && isJsonResponse(response.get(0))) { JsonObject intents = JsonObject.fromJson(response.get(0).getContent()); - if (isEmptyResponse(intents) || containsNoneIntent(intents)) { - response = Arrays.asList(new ChatMessage("assistant", FAKE_CONFUSION) ); + if (isEmptyResponse(intents)) { + response = Arrays.asList(new ChatMessage("assistant", FAKE_CONFUSION)); } else { IntentProcessor intentProcessor = ApplicationManager.getApplication().getService(IntentProcessor.class); getQuestion().addIntentResponse(intents); @@ -280,32 +284,23 @@ public void responseCompleted(ChatMessageEvent.ResponseArrived event) { }); } - private boolean containsNoneIntent(JsonObject intents) { - JsonArray actions = intents.getArray("actions"); - for (int i = 0; i < actions.size(); i++) { - JsonObject action = actions.getObject(i); - - if (!action.containsKey("action") || action.getString("action").equalsIgnoreCase("none")) { - return true; - } - } - - return false; - } - private boolean isEmptyResponse(JsonObject intents) { - return intents.size() == 0 - || !intents.containsKey("actions") - || !(intents.get("actions") instanceof JsonArray) - || intents.getArray("actions").size() == 0; + return intents == null || !( + intents.size() > 0 || + intents.containsKey("collections") || ( + intents.containsKey("actions") && + (intents.get("actions") instanceof JsonArray) && + intents.getArray("actions").size() > 0 + ) + ); } public void setContent(List content) { content.forEach(message -> message.setContent(message.getContent() - .replaceAll("/```\n\s*SELECT/gmi", "```sql\nSELECT") - .replaceAll("```\nUPDATE", "```sql\nUPDATE") - .replaceAll("```\nDELETE", "```sql\nDELETE") - .replaceAll("```\nCREATE", "```sql\nCREATE") + .replaceAll("/```\n\s*SELECT/gmi", "```sql\nSELECT") + .replaceAll("```\nUPDATE", "```sql\nUPDATE") + .replaceAll("```\nDELETE", "```sql\nDELETE") + .replaceAll("```\nCREATE", "```sql\nCREATE") // .replaceAll("```sql", "```sqlpp") )); TextFragment parseResult = ChatCompletionParser.parseGPT35TurboWithStream(content); diff --git a/src/main/resources/iq/intent_prompt.txt b/src/main/resources/iq/intent_prompt.txt index b7760c4a..379cce82 100644 --- a/src/main/resources/iq/intent_prompt.txt +++ b/src/main/resources/iq/intent_prompt.txt @@ -33,7 +33,7 @@ If you extracted any entities or identified actions or intents listed above, res ] } -You have access to user cluster via the plugin. Respond with the JSON to query necessary information from the plugin. +You have access to user cluster via the plugin. If user asks a question about their cluster, always respond in the JSON format listed above. If any additional information is needed, respond in the JSON format listed above. Do not EVER add any non-json text to your response with JSON. diff --git a/src/test/java/com/couchbase/intellij/tree/iq/AbstractCapellaTest.java b/src/test/java/com/couchbase/intellij/tree/iq/AbstractCapellaTest.java index 31f0a445..808452fb 100644 --- a/src/test/java/com/couchbase/intellij/tree/iq/AbstractCapellaTest.java +++ b/src/test/java/com/couchbase/intellij/tree/iq/AbstractCapellaTest.java @@ -5,7 +5,9 @@ import com.couchbase.intellij.tree.iq.chat.ConfigurationPage; import com.couchbase.intellij.tree.iq.core.IQCredentials; import com.couchbase.intellij.tree.iq.settings.OpenAISettingsState; +import com.couchbase.intellij.tree.iq.ui.ChatPanel; import com.couchbase.intellij.workbench.Log; +import org.mockito.Mockito; public abstract class AbstractCapellaTest extends AbstractIQTest { @Override @@ -30,7 +32,8 @@ protected void setUp() throws Exception { Log.setLevel(3); Log.setPrinter(new Log.StdoutPrinter()); - link = new ChatLinkService(getProject(), null, cp); + ChatLinkService link = new ChatLinkService(getProject(), null, cp); + panel = new ChatPanel(getProject(), iqGptConfig, Mockito.mock(CapellaOrganizationList.class), Mockito.mock(CapellaOrganization.class), Mockito.mock(ChatPanel.LogoutListener.class), Mockito.mock(ChatPanel.OrganizationListener.class)); ctx = new ChatLinkState(cp); } diff --git a/src/test/java/com/couchbase/intellij/tree/iq/AbstractIQTest.java b/src/test/java/com/couchbase/intellij/tree/iq/AbstractIQTest.java index 150f5b8d..d2ed30b7 100644 --- a/src/test/java/com/couchbase/intellij/tree/iq/AbstractIQTest.java +++ b/src/test/java/com/couchbase/intellij/tree/iq/AbstractIQTest.java @@ -1,47 +1,79 @@ package com.couchbase.intellij.tree.iq; +import com.couchbase.client.java.Cluster; import com.couchbase.client.java.json.JsonArray; import com.couchbase.client.java.json.JsonObject; -import com.couchbase.intellij.tree.iq.chat.ChatExchangeAbortException; -import com.couchbase.intellij.tree.iq.chat.ChatGptHandler; -import com.couchbase.intellij.tree.iq.chat.ChatLink; -import com.couchbase.intellij.tree.iq.chat.ChatLinkService; -import com.couchbase.intellij.tree.iq.chat.ChatLinkState; -import com.couchbase.intellij.tree.iq.chat.ChatMessageEvent; -import com.couchbase.intellij.tree.iq.chat.ChatMessageListener; -import com.couchbase.intellij.tree.iq.chat.ConfigurationPage; -import com.couchbase.intellij.tree.iq.chat.ConversationContext; -import com.couchbase.intellij.tree.iq.core.IQCredentials; +import com.couchbase.intellij.database.ActiveCluster; +import com.couchbase.intellij.database.QueryContext; +import com.couchbase.intellij.tree.iq.chat.*; +import com.couchbase.intellij.tree.iq.intents.IntentProcessor; import com.couchbase.intellij.tree.iq.intents.actions.ActionInterface; -import com.couchbase.intellij.tree.iq.settings.OpenAISettingsState; -import com.couchbase.intellij.workbench.Log; +import com.couchbase.intellij.tree.iq.ui.ChatPanel; +import com.couchbase.intellij.utils.Subscribable; import com.intellij.testFramework.fixtures.BasePlatformTestCase; -import com.theokanning.openai.completion.chat.ChatCompletionRequest; -import com.theokanning.openai.completion.chat.ChatMessage; -import com.theokanning.openai.completion.chat.ChatMessageRole; +import com.theokanning.openai.completion.chat.*; +import org.mockito.Mock; +import org.mockito.Mockito; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; import java.util.function.Consumer; public abstract class AbstractIQTest extends BasePlatformTestCase { protected static final String IQ_URL = System.getenv("CAPELLA_DOMAIN") + "/v2/organizations/%s/integrations/iq/"; private static final ChatGptHandler handler = new ChatGptHandler(); protected static ConversationContext ctx; - protected static ChatLink link; + protected static ChatPanel panel; + protected static IntentProcessor processor = new IntentProcessor(); + protected static CapellaOrganization organization; + + protected static CapellaOrganizationList organizationList; + + protected static void mockCluster() { + Cluster cluster = Mockito.mock(Cluster.class); + ActiveCluster activeCluster = Mockito.mock(ActiveCluster.class); + QueryContext context = Mockito.mock(QueryContext.class); + Mockito.when(context.getBucket()).thenReturn("default"); + Mockito.when(context.getScope()).thenReturn("_default"); + Mockito.when(activeCluster.getQueryContext()).thenReturn(new Subscribable<>(context)); + Mockito.when(activeCluster.getCluster()).thenReturn(cluster); + ActiveCluster.setInstance(activeCluster); + } + + @Override + protected void setUp() throws Exception { + super.setUp(); + organization = new CapellaOrganization(); + organization.setId("orgid"); + organization.setName("test"); + CapellaOrganization.IQ iq = new CapellaOrganization.IQ(); + iq.setEnabled(true); + CapellaOrganization.Other other = new CapellaOrganization.Other(); + other.setIsTermsAcceptedForOrg(true); + iq.setOther(other); + organization.setIq(iq); + CapellaOrganizationList.Entry entry = new CapellaOrganizationList.Entry(); + entry.setData(organization); + organizationList = new CapellaOrganizationList(); + organizationList.setData(Arrays.asList(entry)); + } - protected void send(String message, Consumer listener) { - send(message, false, listener); + protected ChatCompletionResult send(String message) { + return send(message, false); } - protected void send(String message, boolean isSystem, Consumer listener) { + protected ChatCompletionResult send(String message, boolean isSystem) { + CompletableFuture future = new CompletableFuture<>(); assertNotNull(ctx); - assertNotNull(link); + assertNotNull(panel); ChatMessage chatMessage = new ChatMessage( isSystem ? ChatMessageRole.SYSTEM.value() : ChatMessageRole.USER.value(), message ); - ChatMessageEvent.Starting event = ChatMessageEvent.starting(AbstractIQTest.link, chatMessage); + ChatMessageEvent.Starting event = ChatMessageEvent.starting(AbstractIQTest.panel.getChatLink(), chatMessage); ctx.addChatMessage(chatMessage); List messages = ctx.getChatMessages(ctx.getModelType(), chatMessage); if (isSystem) { @@ -50,63 +82,28 @@ protected void send(String message, boolean isSystem, Consumer getIntents(ChatMessageEvent.ResponseArrived response, Class action) { + protected List getIntents(ChatCompletionResult response, Class action) { List results = new ArrayList<>(); JsonObject json = getJson(response); assertInstanceOf(json.get("actions"), JsonArray.class); @@ -135,9 +132,16 @@ public void assertCauseMessage(Throwable e, String msg) { throw new AssertionError("Exception was not caused by an error with message '" + msg + "'"); } - protected void assertResponseTextEquals(ChatMessageEvent.ResponseArrived responseArrived, String fakeConfusion) { - assertTrue(responseArrived.getResponseChoices().stream() - .filter(Objects::nonNull) - .anyMatch(choice -> choice.getContent() != null && choice.getContent().equals(fakeConfusion))); + protected void assertResponseTextEquals(ChatCompletionResult actual, String expected) { + assertEquals(expected, getResponse(actual)); + } + + protected String intentFeedback(ChatCompletionResult result) { + String response = getResponse(result); + JsonObject intents = JsonObject.create(); + if (response != "") { + intents = JsonObject.fromJson(response); + } + return processor.generateIntentReturnPrompt(panel, result.getChoices().get(0).getMessage(), intents); } } diff --git a/src/test/java/com/couchbase/intellij/tree/iq/AbstractMockedIQTest.java b/src/test/java/com/couchbase/intellij/tree/iq/AbstractMockedIQTest.java index 2714d75b..78bd9fd5 100644 --- a/src/test/java/com/couchbase/intellij/tree/iq/AbstractMockedIQTest.java +++ b/src/test/java/com/couchbase/intellij/tree/iq/AbstractMockedIQTest.java @@ -8,6 +8,7 @@ import com.couchbase.intellij.tree.iq.core.CapellaAuth; import com.couchbase.intellij.tree.iq.core.IQCredentials; import com.couchbase.intellij.tree.iq.settings.OpenAISettingsState; +import com.couchbase.intellij.tree.iq.ui.ChatPanel; import com.couchbase.intellij.workbench.Log; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; @@ -56,7 +57,7 @@ protected void setUp() throws Exception { Log.setLevel(3); Log.setPrinter(new Log.StdoutPrinter()); - link = new ChatLinkService(getProject(), null, cp); + panel = new ChatPanel(getProject(), iqGptConfig, organizationList, organization, Mockito.mock(ChatPanel.LogoutListener.class), Mockito.mock(ChatPanel.OrganizationListener.class)); ctx = new ChatLinkState(cp); } @@ -73,6 +74,7 @@ protected void enqueueResponse(String text) { packet.put("model", GPT_MODEL); JsonArray choices = JsonArray.create(); JsonObject choice = JsonObject.create(); + choices.add(choice); choice.put("index", 0); JsonObject message = JsonObject.create(); message.put("role", "assistant"); diff --git a/src/test/java/com/couchbase/intellij/tree/iq/intents/CapellaIntentProcessorTest.java b/src/test/java/com/couchbase/intellij/tree/iq/intents/CapellaIntentProcessorTest.java index 73fbcaa0..fd0c15b2 100644 --- a/src/test/java/com/couchbase/intellij/tree/iq/intents/CapellaIntentProcessorTest.java +++ b/src/test/java/com/couchbase/intellij/tree/iq/intents/CapellaIntentProcessorTest.java @@ -8,8 +8,8 @@ public class CapellaIntentProcessorTest extends AbstractCapellaTest { @Test public void testRespondsWithJson() throws Exception { - send("list all boolean fields in the airport collection", this::assertJsonResponse); - send("create a collection named 'test'", this::assertJsonResponse); - send("open an airport with id 'UULL'", this::assertJsonResponse); + assertJsonResponse(send("list all boolean fields in the airport collection")); + assertJsonResponse(send("create a collection named 'test'")); + assertJsonResponse(send("open an airport with id 'UULL'")); } } \ No newline at end of file diff --git a/src/test/java/com/couchbase/intellij/tree/iq/intents/MockedIntentProcessorTest.java b/src/test/java/com/couchbase/intellij/tree/iq/intents/MockedIntentProcessorTest.java index bdd29e69..4b55b206 100644 --- a/src/test/java/com/couchbase/intellij/tree/iq/intents/MockedIntentProcessorTest.java +++ b/src/test/java/com/couchbase/intellij/tree/iq/intents/MockedIntentProcessorTest.java @@ -3,15 +3,14 @@ import com.couchbase.client.java.json.JsonArray; import com.couchbase.client.java.json.JsonObject; import com.couchbase.intellij.tree.iq.AbstractMockedIQTest; -import com.couchbase.intellij.tree.iq.ui.ChatPanel; +import com.theokanning.openai.completion.chat.ChatCompletionResult; public class MockedIntentProcessorTest extends AbstractMockedIQTest { public void testErrorHandling() throws Exception { enqueueError(500, 500, "test error"); try { - send("boop", responseArrived -> { - }); + send("boop"); } catch (Throwable e) { assertCauseMessage(e, "test error"); return; @@ -20,45 +19,40 @@ public void testErrorHandling() throws Exception { } public void testEmptyJsonResponse() throws Exception { + mockCluster(); enqueueResponse(""); - send("boop", responseArrived -> { - assertResponseTextEquals(responseArrived, ChatPanel.FAKE_CONFUSION); - }); + assertEquals(IntentProcessor.PROMPT_ANSWER_USER, intentFeedback(send("boop"))); enqueueResponse("{ }"); - send("boop", responseArrived -> { - assertResponseTextEquals(responseArrived, ChatPanel.FAKE_CONFUSION); - }); + assertEquals(IntentProcessor.PROMPT_ANSWER_USER, intentFeedback(send("boop"))); enqueueResponse("{\n}"); - send("boop", responseArrived -> { - assertResponseTextEquals(responseArrived, ChatPanel.FAKE_CONFUSION); - }); + assertEquals(IntentProcessor.PROMPT_ANSWER_USER, intentFeedback(send("boop"))); JsonObject response = JsonObject.create(); response.putNull("actions"); enqueueResponse(response); - send("boop me!", responseArrived -> { - assertResponseTextEquals(responseArrived, ChatPanel.FAKE_CONFUSION); - }); + assertEquals(IntentProcessor.PROMPT_ANSWER_USER, intentFeedback(send("boop"))); JsonArray actions = JsonArray.create(); response.put("actions", actions); enqueueResponse(response); - send("boop me!", responseArrived -> { - assertResponseTextEquals(responseArrived, ChatPanel.FAKE_CONFUSION); - }); + assertEquals(IntentProcessor.PROMPT_ANSWER_USER, intentFeedback(send("boop"))); JsonObject action = JsonObject.create(); actions.add(action); enqueueResponse(response); - send("boop me!", responseArrived -> { - assertResponseTextEquals(responseArrived, ChatPanel.FAKE_CONFUSION); - }); + assertEquals(IntentProcessor.PROMPT_ANSWER_USER, intentFeedback(send("boop"))); + } + public void testRegularResponses() throws Exception { // validate that regular responses are passed correctly enqueueResponse("boop"); - send("beep", responseArrived -> { - assertResponseTextEquals(responseArrived, "boop"); - }); + ChatCompletionResult response = send("beep"); + assertResponseTextEquals(response, "boop"); } + public void testQueryResponses() throws Exception { + mockCluster(); + enqueueResponse("{\"ids\": [],\"query\": \"SELECT * FROM `airport` WHERE city = 'Washington, DC'\",\"fields\": [],\"collections\": [\"airport\"]}"); + assertEquals(IntentProcessor.PROMPT_ANSWER_USER, intentFeedback(send("beep"))); + } } diff --git a/src/test/java/com/couchbase/intellij/tree/iq/intents/actions/CreateCollectionTest.java b/src/test/java/com/couchbase/intellij/tree/iq/intents/actions/CreateCollectionTest.java index c5c3d182..eeadbc6c 100644 --- a/src/test/java/com/couchbase/intellij/tree/iq/intents/actions/CreateCollectionTest.java +++ b/src/test/java/com/couchbase/intellij/tree/iq/intents/actions/CreateCollectionTest.java @@ -6,6 +6,7 @@ import com.couchbase.intellij.testutil.TestActiveCluster; import com.couchbase.intellij.tree.iq.AbstractCapellaTest; import com.couchbase.intellij.tree.iq.AbstractIQTest; +import com.theokanning.openai.completion.chat.ChatCompletionResult; import org.junit.Test; import org.mockito.Mockito; @@ -22,27 +23,25 @@ protected void setUp() throws Exception { @Test public void test() throws Exception { - send("create a collection named 'test'", response -> { - JsonObject jsonResponse = getJson(response); - assertNotNull(jsonResponse); - List intents = getIntents(response, CreateCollection.class); - assertSize(1, intents); - CreateCollection action = new CreateCollection(); - String prompt = action.fire(getProject(), null, null, jsonResponse, intents.get(0)); - assertNotNull(prompt); - assertTrue(prompt.startsWith("ask the user in which bucket and scope")); - send(prompt, true, response2 -> { - assertNotJson(response2); - assertTrue(getResponse(response2).contains("bucket")); - assertTrue(getResponse(response2).contains("scope")); - send("travel-sample bucket and scope inventoery", response3 -> { - JsonObject jsonResponse2 = getJson(response3); - assertNotNull(jsonResponse2); - List intents2 = getIntents(response, CreateCollection.class); - assertSize(1, intents2); - }); - }); - }); + ChatCompletionResult response = send("create a collection named 'test'"); + + JsonObject jsonResponse = getJson(response); + assertNotNull(jsonResponse); + List intents = getIntents(response, CreateCollection.class); + assertSize(1, intents); + CreateCollection action = new CreateCollection(); + String prompt = action.fire(getProject(), null, null, jsonResponse, intents.get(0)); + assertNotNull(prompt); + assertTrue(prompt.startsWith("ask the user in which bucket and scope")); + ChatCompletionResult response2 = send(prompt, true); + assertNotJson(response2); + assertTrue(getResponse(response2).contains("bucket")); + assertTrue(getResponse(response2).contains("scope")); + ChatCompletionResult response3 = send("travel-sample bucket and scope inventory"); + JsonObject jsonResponse2 = getJson(response3); + assertNotNull(jsonResponse2); + List intents2 = getIntents(response, CreateCollection.class); + assertSize(1, intents2); } } \ No newline at end of file