From 5001e2621b43e55e6799bb6447090d9a193b04d5 Mon Sep 17 00:00:00 2001 From: Alena Khineika Date: Sat, 14 Sep 2024 14:45:35 +0200 Subject: [PATCH] =?UTF-8?q?refactor:=20diff=C3=A9rentiante=20db=20and=20co?= =?UTF-8?q?ll=20selection?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/participant/participant.ts | 73 ++++++++++++++++++++-------------- 1 file changed, 43 insertions(+), 30 deletions(-) diff --git a/src/participant/participant.ts b/src/participant/participant.ts index 2f0c46c4..ea47521e 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -42,6 +42,11 @@ const COL_NAME_REGEX = `${COL_NAME_ID}: (.*)\n?`; const MAX_MARKDOWN_LIST_LENGTH = 10; +export enum SELECT_NAMESPACE { + DATABASE = 'database', + COLLECTION = 'collection', +} + export function parseForDatabaseAndCollectionName(text: string): { databaseName?: string; collectionName?: string; @@ -67,6 +72,11 @@ export default class ParticipantController { _storageController: StorageController; _chatResult?: ChatResult; + // This state exists only within a single request and is reset immediately + // when requestHandler() is invoked. As a result, the state is not shared + // across different chat conversations. + _isSelectingNamespace?: SELECT_NAMESPACE; + constructor({ connectionController, storageController, @@ -295,6 +305,7 @@ export default class ParticipantController { if (!selectedName) { return false; } + this._isSelectingNamespace = SELECT_NAMESPACE.DATABASE; return vscode.commands.executeCommand('workbench.action.chat.open', { query: `@MongoDB /query ${selectedName || ''}`, }); @@ -343,6 +354,7 @@ export default class ParticipantController { if (!selectedName) { return false; } + this._isSelectingNamespace = SELECT_NAMESPACE.COLLECTION; return vscode.commands.executeCommand('workbench.action.chat.open', { query: `@MongoDB /query ${selectedName || ''}`, }); @@ -454,8 +466,11 @@ export default class ParticipantController { stream: vscode.ChatResponseStream, token: vscode.CancellationToken ): Promise<{ - databaseName?: string; - collectionName?: string; + namespace: { + databaseName?: string; + collectionName?: string; + }; + namespaceHasChanged: boolean; }> { const historyWithNamespace = context.history .filter((historyItem) => { @@ -477,29 +492,22 @@ export default class ParticipantController { collectionName: undefined, }; - const dataService = this._connectionController.getActiveDataService(); - if (dataService) { - try { - const databases = await dataService.listDatabases(); - const newDatabaseName = databases.find( - (db) => db.name === request.prompt - )?.name; - if (newDatabaseName) { - namespace.databaseName = newDatabaseName; - } else if (namespace.databaseName) { - const collections = await dataService.listCollections( - namespace.databaseName - ); - const newCollectionName = collections.find( - (db) => db.name === request.prompt - )?.name; - if (newCollectionName) { - namespace.collectionName = newCollectionName; - } - } - } catch (error) { - // Do nothing. - } + if (this._isSelectingNamespace === SELECT_NAMESPACE.DATABASE) { + return { + namespace: { + databaseName: request.prompt, + collectionName: undefined, + }, + namespaceHasChanged: true, + }; + } else if (this._isSelectingNamespace === SELECT_NAMESPACE.COLLECTION) { + return { + namespace: { + databaseName: namespace.databaseName, + collectionName: request.prompt, + }, + namespaceHasChanged: true, + }; } if (!namespace.databaseName || !namespace.collectionName) { @@ -515,11 +523,16 @@ export default class ParticipantController { const namespaceFromPrompt = parseForDatabaseAndCollectionName( responseContentWithNamespace ); - namespace.databaseName = namespaceFromPrompt.databaseName; - namespace.collectionName = namespaceFromPrompt.collectionName; + return { + namespace: namespaceFromPrompt, + namespaceHasChanged: true, + }; } - return namespace; + return { + namespace, + namespaceHasChanged: false, + }; } async _askForNamespace( @@ -615,7 +628,7 @@ export default class ParticipantController { return { metadata: {} }; } - const namespace = await this._findNamespace( + const { namespace, namespaceHasChanged } = await this._findNamespace( request, context, stream, @@ -633,7 +646,7 @@ export default class ParticipantController { }); let sampleDocuments = this._findSampleDocuments(context); - if (!sampleDocuments) { + if (namespaceHasChanged || !sampleDocuments) { sampleDocuments = await this._fetchSampleDocuments( namespace, abortController.signal