Skip to content

Commit

Permalink
refactor: différentiante db and coll selection
Browse files Browse the repository at this point in the history
  • Loading branch information
alenakhineika committed Sep 14, 2024
1 parent a1a538b commit 5001e26
Showing 1 changed file with 43 additions and 30 deletions.
73 changes: 43 additions & 30 deletions src/participant/participant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ const COL_NAME_REGEX = `${COL_NAME_ID}: (.*)\n?`;

const MAX_MARKDOWN_LIST_LENGTH = 10;

export enum SELECT_NAMESPACE {
DATABASE = 'database',
COLLECTION = 'collection',
}

export function parseForDatabaseAndCollectionName(text: string): {
databaseName?: string;
collectionName?: string;
Expand All @@ -67,6 +72,11 @@ export default class ParticipantController {
_storageController: StorageController;
_chatResult?: ChatResult;

// This state exists only within a single request and is reset immediately
// when requestHandler() is invoked. As a result, the state is not shared
// across different chat conversations.
_isSelectingNamespace?: SELECT_NAMESPACE;

constructor({
connectionController,
storageController,
Expand Down Expand Up @@ -295,6 +305,7 @@ export default class ParticipantController {
if (!selectedName) {
return false;
}
this._isSelectingNamespace = SELECT_NAMESPACE.DATABASE;
return vscode.commands.executeCommand('workbench.action.chat.open', {
query: `@MongoDB /query ${selectedName || ''}`,
});
Expand Down Expand Up @@ -343,6 +354,7 @@ export default class ParticipantController {
if (!selectedName) {
return false;
}
this._isSelectingNamespace = SELECT_NAMESPACE.COLLECTION;
return vscode.commands.executeCommand('workbench.action.chat.open', {
query: `@MongoDB /query ${selectedName || ''}`,
});
Expand Down Expand Up @@ -454,8 +466,11 @@ export default class ParticipantController {
stream: vscode.ChatResponseStream,
token: vscode.CancellationToken
): Promise<{
databaseName?: string;
collectionName?: string;
namespace: {
databaseName?: string;
collectionName?: string;
};
namespaceHasChanged: boolean;
}> {
const historyWithNamespace = context.history
.filter((historyItem) => {
Expand All @@ -477,29 +492,22 @@ export default class ParticipantController {
collectionName: undefined,
};

const dataService = this._connectionController.getActiveDataService();
if (dataService) {
try {
const databases = await dataService.listDatabases();
const newDatabaseName = databases.find(
(db) => db.name === request.prompt
)?.name;
if (newDatabaseName) {
namespace.databaseName = newDatabaseName;
} else if (namespace.databaseName) {
const collections = await dataService.listCollections(
namespace.databaseName
);
const newCollectionName = collections.find(
(db) => db.name === request.prompt
)?.name;
if (newCollectionName) {
namespace.collectionName = newCollectionName;
}
}
} catch (error) {
// Do nothing.
}
if (this._isSelectingNamespace === SELECT_NAMESPACE.DATABASE) {
return {
namespace: {
databaseName: request.prompt,
collectionName: undefined,
},
namespaceHasChanged: true,
};
} else if (this._isSelectingNamespace === SELECT_NAMESPACE.COLLECTION) {
return {
namespace: {
databaseName: namespace.databaseName,
collectionName: request.prompt,
},
namespaceHasChanged: true,
};
}

if (!namespace.databaseName || !namespace.collectionName) {
Expand All @@ -515,11 +523,16 @@ export default class ParticipantController {
const namespaceFromPrompt = parseForDatabaseAndCollectionName(
responseContentWithNamespace
);
namespace.databaseName = namespaceFromPrompt.databaseName;
namespace.collectionName = namespaceFromPrompt.collectionName;
return {
namespace: namespaceFromPrompt,
namespaceHasChanged: true,
};
}

return namespace;
return {
namespace,
namespaceHasChanged: false,
};
}

async _askForNamespace(
Expand Down Expand Up @@ -615,7 +628,7 @@ export default class ParticipantController {
return { metadata: {} };
}

const namespace = await this._findNamespace(
const { namespace, namespaceHasChanged } = await this._findNamespace(
request,
context,
stream,
Expand All @@ -633,7 +646,7 @@ export default class ParticipantController {
});

let sampleDocuments = this._findSampleDocuments(context);
if (!sampleDocuments) {
if (namespaceHasChanged || !sampleDocuments) {
sampleDocuments = await this._fetchSampleDocuments(
namespace,
abortController.signal
Expand Down

0 comments on commit 5001e26

Please sign in to comment.