From 6322107fb322cd57ddb33bdf111c1d70f04cb718 Mon Sep 17 00:00:00 2001 From: Rhys Howell Date: Tue, 17 Sep 2024 17:32:45 -0400 Subject: [PATCH 1/2] feat(chat): add /schema command handler --- package.json | 13 + src/commands/index.ts | 1 + src/mdbExtensionController.ts | 32 +- src/participant/markdown.ts | 2 + src/participant/participant.ts | 385 ++++++++++++++---- src/participant/prompts/generic.ts | 4 + src/participant/prompts/history.ts | 13 + src/participant/prompts/schema.ts | 85 ++++ src/participant/schema.ts | 8 +- .../suite/participant/participant.test.ts | 6 +- 10 files changed, 454 insertions(+), 95 deletions(-) create mode 100644 src/participant/prompts/schema.ts diff --git a/package.json b/package.json index 7b742b4a..1e442b80 100644 --- a/package.json +++ b/package.json @@ -94,6 +94,11 @@ "name": "query", "isSticky": true, "description": "Ask how to write MongoDB queries or pipelines. For example, you can ask: \"Show me all the documents where the address contains the word street\"." + }, + { + "name": "schema", + "isSticky": true, + "description": "Analyze a collection's schema." } ] } @@ -172,6 +177,10 @@ "command": "mdb.selectCollectionWithParticipant", "title": "MongoDB: Select Collection with Participant" }, + { + "command": "mdb.participantViewRawSchemaOutput", + "title": "MongoDB: View Raw Schema JSON Output" + }, { "command": "mdb.connectWithParticipant", "title": "MongoDB: Change Active Connection with Participant" @@ -742,6 +751,10 @@ "command": "mdb.selectCollectionWithParticipant", "when": "false" }, + { + "command": "mdb.participantViewRawSchemaOutput", + "when": "false" + }, { "command": "mdb.connectWithParticipant", "when": "false" diff --git a/src/commands/index.ts b/src/commands/index.ts index 2a11bf67..767b99ff 100644 --- a/src/commands/index.ts +++ b/src/commands/index.ts @@ -82,6 +82,7 @@ enum EXTENSION_COMMANDS { CONNECT_WITH_PARTICIPANT = 'mdb.connectWithParticipant', SELECT_DATABASE_WITH_PARTICIPANT = 'mdb.selectDatabaseWithParticipant', SELECT_COLLECTION_WITH_PARTICIPANT = 'mdb.selectCollectionWithParticipant', + PARTICIPANT_OPEN_RAW_SCHEMA_OUTPUT = 'mdb.participantViewRawSchemaOutput', } export default EXTENSION_COMMANDS; diff --git a/src/mdbExtensionController.ts b/src/mdbExtensionController.ts index 17cda055..918b3dea 100644 --- a/src/mdbExtensionController.ts +++ b/src/mdbExtensionController.ts @@ -44,6 +44,7 @@ import { ConnectionStorage } from './storage/connectionStorage'; import type StreamProcessorTreeItem from './explorer/streamProcessorTreeItem'; import type { RunParticipantQueryCommandArgs } from './participant/participant'; import ParticipantController from './participant/participant'; +import type { OpenSchemaCommandArgs } from './participant/prompts/schema'; // This class is the top-level controller for our extension. // Commands which the extensions handles are defined in the function `activate`. @@ -308,30 +309,37 @@ export default class MDBExtensionController implements vscode.Disposable { ); this.registerCommand( EXTENSION_COMMANDS.CONNECT_WITH_PARTICIPANT, - (id?: string) => - this._participantController.connectWithParticipant( - id ? decodeURIComponent(id) : id - ) + (_data: string) => { + const data = JSON.parse(decodeURIComponent(_data)); + return this._participantController.connectWithParticipant(data); + } ); this.registerCommand( EXTENSION_COMMANDS.SELECT_DATABASE_WITH_PARTICIPANT, (_data: string) => { const data = JSON.parse(decodeURIComponent(_data)); - return this._participantController.selectDatabaseWithParticipant({ - chatId: data.chatId, - databaseName: data.databaseName, - }); + return this._participantController.selectDatabaseWithParticipant(data); } ); this.registerCommand( EXTENSION_COMMANDS.SELECT_COLLECTION_WITH_PARTICIPANT, (_data: string) => { const data = JSON.parse(decodeURIComponent(_data)); - return this._participantController.selectCollectionWithParticipant({ - chatId: data.chatId, - databaseName: data.databaseName, - collectionName: data.collectionName, + return this._participantController.selectCollectionWithParticipant( + data + ); + } + ); + this.registerCommand( + EXTENSION_COMMANDS.PARTICIPANT_OPEN_RAW_SCHEMA_OUTPUT, + async ({ schema }: OpenSchemaCommandArgs) => { + const document = await vscode.workspace.openTextDocument({ + language: 'json', + content: schema, }); + await vscode.window.showTextDocument(document, { preview: true }); + + return !!document; } ); }; diff --git a/src/participant/markdown.ts b/src/participant/markdown.ts index 75ca9ed4..ec959dcb 100644 --- a/src/participant/markdown.ts +++ b/src/participant/markdown.ts @@ -6,6 +6,8 @@ export function createMarkdownLink({ name, }: { commandId: string; + // TODO: Create types for this data so we can also then use them on the extension + // controller when we parse the result. data?: | { [field: string]: any; diff --git a/src/participant/participant.ts b/src/participant/participant.ts index 97ec53f4..59e4d682 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -1,5 +1,5 @@ import * as vscode from 'vscode'; -import { getSimplifiedSchema } from 'mongodb-schema'; +import { getSimplifiedSchema, parseSchema } from 'mongodb-schema'; import type { Document } from 'bson'; import { createLogger } from '../logging'; @@ -8,7 +8,7 @@ import type { LoadedConnection } from '../storage/connectionStorage'; import EXTENSION_COMMANDS from '../commands'; import type { StorageController } from '../storage'; import { StorageVariables } from '../storage'; -import { GenericPrompt } from './prompts/generic'; +import { GenericPrompt, isPromptEmpty } from './prompts/generic'; import { AskToConnectChatResult, CHAT_PARTICIPANT_ID, @@ -22,6 +22,12 @@ import { getSimplifiedSampleDocuments } from './sampleDocuments'; import { getCopilotModel } from './model'; import { createMarkdownLink } from './markdown'; import { ChatMetadataStore } from './chatMetadata'; +import { doesLastMessageAskForNamespace } from './prompts/history'; +import { + DOCUMENTS_TO_SAMPLE_FOR_SCHEMA_PROMPT, + type OpenSchemaCommandArgs, + SchemaPrompt, +} from './prompts/schema'; const log = createLogger('participant'); @@ -39,6 +45,8 @@ export type RunParticipantQueryCommandArgs = { const DB_NAME_REGEX = `${DB_NAME_ID}: (.*)`; const COL_NAME_REGEX = `${COL_NAME_ID}: (.*)`; +type ParticipantCommand = '/query' | '/schema'; + const MAX_MARKDOWN_LIST_LENGTH = 10; export function parseForDatabaseAndCollectionName(text: string): { @@ -211,7 +219,13 @@ export default class ParticipantController { }; } - async connectWithParticipant(id?: string): Promise { + async connectWithParticipant({ + id, + command, + }: { + id?: string; + command?: string; + }): Promise { if (!id) { const didChangeActiveConnection = await this._connectionController.changeActiveConnection(); @@ -226,11 +240,11 @@ export default class ParticipantController { const connectionName = this._connectionController.getActiveConnectionName(); return this.writeChatMessageAsUser( - `/query ${connectionName}` + `${command ? `${command} ` : ''}${connectionName}` ) as Promise; } - getConnectionsTree(): vscode.MarkdownString[] { + getConnectionsTree(command: ParticipantCommand): vscode.MarkdownString[] { return [ ...this._connectionController .getSavedConnections() @@ -243,22 +257,30 @@ export default class ParticipantController { .map((conn: LoadedConnection) => createMarkdownLink({ commandId: EXTENSION_COMMANDS.CONNECT_WITH_PARTICIPANT, - data: conn.id, + data: { + id: conn.id, + command, + }, name: conn.name, }) ), createMarkdownLink({ commandId: EXTENSION_COMMANDS.CONNECT_WITH_PARTICIPANT, name: 'Show more', + data: { + command, + }, }), ]; } - async getDatabaseQuickPicks(): Promise { + async getDatabaseQuickPicks( + command: ParticipantCommand + ): Promise { const dataService = this._connectionController.getActiveDataService(); if (!dataService) { // Run a blank command to get the user to connect first. - void this.writeChatMessageAsUser('/query'); + void this.writeChatMessageAsUser(command); return []; } @@ -275,8 +297,10 @@ export default class ParticipantController { } } - async _selectDatabaseWithQuickPick(): Promise { - const databases = await this.getDatabaseQuickPicks(); + async _selectDatabaseWithQuickPick( + command: ParticipantCommand + ): Promise { + const databases = await this.getDatabaseQuickPicks(command); const selectedQuickPickItem = await vscode.window.showQuickPick(databases, { placeHolder: 'Select a database...', }); @@ -285,14 +309,16 @@ export default class ParticipantController { async selectDatabaseWithParticipant({ chatId, + command, databaseName: _databaseName, }: { chatId: string; + command: ParticipantCommand; databaseName?: string; }): Promise { let databaseName: string | undefined = _databaseName; if (!databaseName) { - databaseName = await this._selectDatabaseWithQuickPick(); + databaseName = await this._selectDatabaseWithQuickPick(command); if (!databaseName) { return false; } @@ -303,17 +329,21 @@ export default class ParticipantController { }); return this.writeChatMessageAsUser( - `/query ${databaseName}` + `${command} ${databaseName}` ) as Promise; } - async getCollectionQuickPicks( - databaseName: string - ): Promise { + async getCollectionQuickPicks({ + command, + databaseName, + }: { + command: ParticipantCommand; + databaseName: string; + }): Promise { const dataService = this._connectionController.getActiveDataService(); if (!dataService) { // Run a blank command to get the user to connect first. - void this.writeChatMessageAsUser('/query'); + void this.writeChatMessageAsUser(command); return []; } @@ -328,10 +358,17 @@ export default class ParticipantController { } } - async _selectCollectionWithQuickPick( - databaseName: string - ): Promise { - const collections = await this.getCollectionQuickPicks(databaseName); + async _selectCollectionWithQuickPick({ + command, + databaseName, + }: { + command: ParticipantCommand; + databaseName: string; + }): Promise { + const collections = await this.getCollectionQuickPicks({ + command, + databaseName, + }); const selectedQuickPickItem = await vscode.window.showQuickPick( collections, { @@ -342,17 +379,22 @@ export default class ParticipantController { } async selectCollectionWithParticipant({ + command, chatId, databaseName, collectionName: _collectionName, }: { + command: ParticipantCommand; chatId: string; databaseName: string; collectionName?: string; }): Promise { let collectionName: string | undefined = _collectionName; if (!collectionName) { - collectionName = await this._selectCollectionWithQuickPick(databaseName); + collectionName = await this._selectCollectionWithQuickPick({ + command, + databaseName, + }); if (!collectionName) { return false; } @@ -363,13 +405,17 @@ export default class ParticipantController { collectionName: collectionName, }); return this.writeChatMessageAsUser( - `/query ${collectionName}` + `${command} ${collectionName}` ) as Promise; } - async getDatabasesTree( - context: vscode.ChatContext - ): Promise { + async getDatabasesTree({ + command, + context, + }: { + command: ParticipantCommand; + context: vscode.ChatContext; + }): Promise { const dataService = this._connectionController.getActiveDataService(); if (!dataService) { return []; @@ -384,6 +430,7 @@ export default class ParticipantController { createMarkdownLink({ commandId: EXTENSION_COMMANDS.SELECT_DATABASE_WITH_PARTICIPANT, data: { + command, chatId: ChatMetadataStore.getChatIdFromHistoryOrNewChatId( context.history ), @@ -396,6 +443,7 @@ export default class ParticipantController { ? [ createMarkdownLink({ data: { + command, chatId: ChatMetadataStore.getChatIdFromHistoryOrNewChatId( context.history ), @@ -412,10 +460,15 @@ export default class ParticipantController { } } - async getCollectionTree( - databaseName: string, - context: vscode.ChatContext - ): Promise { + async getCollectionTree({ + command, + context, + databaseName, + }: { + command: ParticipantCommand; + databaseName: string; + context: vscode.ChatContext; + }): Promise { const dataService = this._connectionController.getActiveDataService(); if (!dataService) { return []; @@ -428,6 +481,7 @@ export default class ParticipantController { createMarkdownLink({ commandId: EXTENSION_COMMANDS.SELECT_COLLECTION_WITH_PARTICIPANT, data: { + command, chatId: ChatMetadataStore.getChatIdFromHistoryOrNewChatId( context.history ), @@ -443,6 +497,7 @@ export default class ParticipantController { commandId: EXTENSION_COMMANDS.SELECT_COLLECTION_WITH_PARTICIPANT, data: { + command, chatId: ChatMetadataStore.getChatIdFromHistoryOrNewChatId( context.history ), @@ -505,11 +560,13 @@ export default class ParticipantController { } async _askForNamespace({ + command, context, databaseName, collectionName, stream, }: { + command: ParticipantCommand; context: vscode.ChatContext; databaseName: string | undefined; collectionName: string | undefined; @@ -519,17 +576,26 @@ export default class ParticipantController { // we retrieve the available namespaces from the current connection. // Users can then select a value by clicking on an item in the list. if (!databaseName) { - const tree = await this.getDatabasesTree(context); + const tree = await this.getDatabasesTree({ + command, + context, + }); stream.markdown( - 'What is the name of the database you would like this query to run against?\n\n' + `What is the name of the database you would like ${ + command === '/query' ? 'this query' : '' + } to run against?\n\n` ); for (const item of tree) { stream.markdown(item); } } else if (!collectionName) { - const tree = await this.getCollectionTree(databaseName, context); + const tree = await this.getCollectionTree({ + command, + databaseName, + context, + }); stream.markdown( - `Which collection would you like to query within ${databaseName}?\n\n` + `Which collection would you like to use within ${databaseName}?\n\n` ); for (const item of tree) { stream.markdown(item); @@ -543,15 +609,20 @@ export default class ParticipantController { }); } - _askToConnect( - context: vscode.ChatContext, - stream: vscode.ChatResponseStream - ): vscode.ChatResult { - const tree = this.getConnectionsTree(); + _askToConnect({ + command, + context, + stream, + }: { + command: ParticipantCommand; + context: vscode.ChatContext; + stream: vscode.ChatResponseStream; + }): vscode.ChatResult { stream.markdown( "Looks like you aren't currently connected, first let's get you connected to the cluster we'd like to create this query to run against.\n\n" ); + const tree = this.getConnectionsTree(command); for (const item of tree) { stream.markdown(item); } @@ -559,39 +630,54 @@ export default class ParticipantController { } // The sample documents returned from this are simplified (strings and arrays shortened). + // The sample documents are only returned when a user has the setting enabled. async _fetchCollectionSchemaAndSampleDocuments({ abortSignal, databaseName, collectionName, + amountOfDocumentsToSample = NUM_DOCUMENTS_TO_SAMPLE, + schemaFormat = 'simplified', }: { abortSignal; databaseName: string; collectionName: string; + amountOfDocumentsToSample?: number; + schemaFormat?: 'simplified' | 'full'; }): Promise<{ schema?: string; sampleDocuments?: Document[]; + amountOfDocumentsSampled: number; }> { const dataService = this._connectionController.getActiveDataService(); if (!dataService) { - return {}; + return { + amountOfDocumentsSampled: 0, + }; } try { - const sampleDocuments = - (await dataService?.sample?.( - `${databaseName}.${collectionName}`, - { - query: {}, - size: NUM_DOCUMENTS_TO_SAMPLE, - }, - { promoteValues: false }, - { - abortSignal, - } - )) || []; - - const unformattedSchema = await getSimplifiedSchema(sampleDocuments); - const schema = new SchemaFormatter().format(unformattedSchema); + const sampleDocuments = await dataService.sample( + `${databaseName}.${collectionName}`, + { + query: {}, + size: amountOfDocumentsToSample, + }, + { promoteValues: false }, + { + abortSignal, + } + ); + + let schema: string; + if (schemaFormat === 'simplified') { + const unformattedSchema = await getSimplifiedSchema(sampleDocuments); + schema = new SchemaFormatter().format(unformattedSchema); + } else { + const unformattedSchema = await parseSchema(sampleDocuments, { + storeValues: false, + }); + schema = JSON.stringify(unformattedSchema, null, 2); + } const useSampleDocsInCopilot = !!vscode.workspace .getConfiguration('mdb') @@ -602,29 +688,25 @@ export default class ParticipantController { ? getSimplifiedSampleDocuments(sampleDocuments) : undefined, schema, + amountOfDocumentsSampled: sampleDocuments.length, }; } catch (err: any) { log.error('Unable to fetch schema and sample documents', err); - return {}; + throw err; } } - async handleEmptyQueryRequest({ + async handleEmptyNamespaceMessage({ + command, context, stream, }: { + command: ParticipantCommand; context: vscode.ChatContext; stream: vscode.ChatResponseStream; }): Promise { const lastMessageMetaData: vscode.ChatResponseTurn | undefined = context .history[context.history.length - 1] as vscode.ChatResponseTurn; - if ( - (lastMessageMetaData?.result as NamespaceRequestChatResult)?.metadata - ?.intent !== 'askForNamespace' - ) { - stream.markdown(GenericPrompt.getEmptyRequestResponse()); - return new EmptyRequestChatResult(context.history); - } // When the last message was asking for a database or collection name, // we re-ask the question. @@ -638,14 +720,21 @@ export default class ParticipantController { 'Please select a collection by either clicking on an item in the list or typing the name manually in the chat.' ) ); - tree = await this.getCollectionTree(databaseName, context); + tree = await this.getCollectionTree({ + command, + databaseName, + context, + }); } else { stream.markdown( vscode.l10n.t( 'Please select a database by either clicking on an item in the list or typing the name manually in the chat.' ) ); - tree = await this.getDatabasesTree(context); + tree = await this.getDatabasesTree({ + command, + context, + }); } for (const item of tree) { @@ -659,28 +748,32 @@ export default class ParticipantController { }); } - // @MongoDB /query find all documents where the "address" has the word Broadway in it. - async handleQueryRequest( + // @MongoDB /schema + async handleSchemaRequest( request: vscode.ChatRequest, context: vscode.ChatContext, stream: vscode.ChatResponseStream, token: vscode.CancellationToken ): Promise { if (!this._connectionController.getActiveDataService()) { - return this._askToConnect(context, stream); + return this._askToConnect({ + command: '/schema', + context, + stream, + }); } - if (!request.prompt || request.prompt.trim().length === 0) { - return this.handleEmptyQueryRequest({ + if ( + isPromptEmpty(request) && + doesLastMessageAskForNamespace(context.history) + ) { + return this.handleEmptyNamespaceMessage({ + command: '/schema', context, stream, }); } - // We "prompt chain" to handle the query requests. - // First we ask the model to parse for the database and collection name. - // If they exist, we can then use them in our final completion. - // When they don't exist we ask the user for them. const { databaseName, collectionName } = await this._getNamespaceFromChat({ request, context, @@ -689,6 +782,7 @@ export default class ParticipantController { }); if (!databaseName || !collectionName) { return await this._askForNamespace({ + command: '/schema', context, databaseName, collectionName, @@ -701,12 +795,149 @@ export default class ParticipantController { abortController.abort(); }); - const { schema, sampleDocuments } = - await this._fetchCollectionSchemaAndSampleDocuments({ + stream.push( + new vscode.ChatResponseProgressPart( + 'Fetching documents and analyzing schema...' + ) + ); + + let sampleDocuments: Document[] | undefined; + let amountOfDocumentsSampled: number; + let schema: string | undefined; + try { + ({ + sampleDocuments, + amountOfDocumentsSampled, // There can be fewer than the amount we attempt to sample. + schema, + } = await this._fetchCollectionSchemaAndSampleDocuments({ abortSignal: abortController.signal, databaseName, + schemaFormat: 'full', collectionName, + amountOfDocumentsToSample: DOCUMENTS_TO_SAMPLE_FOR_SCHEMA_PROMPT, + })); + + if (!schema || amountOfDocumentsSampled === 0) { + stream.markdown( + vscode.l10n.t( + 'Unable to generate a schema from the collection, no documents found.' + ) + ); + return { metadata: {} }; + } + } catch (e) { + stream.markdown( + vscode.l10n.t( + `Unable to generate a schema from the collection, an error occurred: ${e}` + ) + ); + return { metadata: {} }; + } + + const messages = SchemaPrompt.buildMessages({ + request, + context, + databaseName, + amountOfDocumentsSampled, + collectionName, + schema, + connectionNames: this._connectionController + .getSavedConnections() + .map((connection) => connection.name), + ...(sampleDocuments ? { sampleDocuments } : {}), + }); + const responseContent = await this.getChatResponseContent({ + messages, + stream, + token, + }); + stream.markdown(responseContent); + + stream.button({ + command: EXTENSION_COMMANDS.PARTICIPANT_OPEN_RAW_SCHEMA_OUTPUT, + title: vscode.l10n.t('Open JSON Output'), + arguments: [ + { + schema, + } as OpenSchemaCommandArgs, + ], + }); + + return { metadata: {} }; + } + + // @MongoDB /query find all documents where the "address" has the word Broadway in it. + async handleQueryRequest( + request: vscode.ChatRequest, + context: vscode.ChatContext, + stream: vscode.ChatResponseStream, + token: vscode.CancellationToken + ): Promise { + if (!this._connectionController.getActiveDataService()) { + return this._askToConnect({ + command: '/query', + context, + stream, + }); + } + + if (isPromptEmpty(request)) { + if (doesLastMessageAskForNamespace(context.history)) { + return this.handleEmptyNamespaceMessage({ + command: '/query', + context, + stream, + }); + } + + stream.markdown(QueryPrompt.getEmptyRequestResponse()); + return new EmptyRequestChatResult(context.history); + } + + // We "prompt chain" to handle the query requests. + // First we ask the model to parse for the database and collection name. + // If they exist, we can then use them in our final completion. + // When they don't exist we ask the user for them. + const { databaseName, collectionName } = await this._getNamespaceFromChat({ + request, + context, + stream, + token, + }); + if (!databaseName || !collectionName) { + return await this._askForNamespace({ + command: '/query', + context, + databaseName, + collectionName, + stream, }); + } + + const abortController = new AbortController(); + token.onCancellationRequested(() => { + abortController.abort(); + }); + + let schema: string | undefined; + let sampleDocuments: Document[] | undefined; + try { + ({ schema, sampleDocuments } = + await this._fetchCollectionSchemaAndSampleDocuments({ + abortSignal: abortController.signal, + databaseName, + collectionName, + })); + } catch (e) { + // When an error fetching the collection schema or sample docs occurs, + // we still want to continue as it isn't critical, however, + // we do want to notify the user. + stream.markdown( + vscode.l10n.t( + 'An error occurred while fetching the collection schema and sample documents.\nThe generated query will not be able to reference your data.' + ) + ); + } const messages = await QueryPrompt.buildMessages({ request, @@ -786,7 +1017,7 @@ export default class ParticipantController { } else if (request.command === 'docs') { // TODO(VSCODE-570): Implement this. } else if (request.command === 'schema') { - // TODO(VSCODE-571): Implement this. + return await this.handleSchemaRequest(...args); } return await this.handleGenericRequest(...args); } diff --git a/src/participant/prompts/generic.ts b/src/participant/prompts/generic.ts index d176fead..733a86fe 100644 --- a/src/participant/prompts/generic.ts +++ b/src/participant/prompts/generic.ts @@ -46,3 +46,7 @@ Respond in MongoDB shell syntax using the \`\`\`javascript code block syntax.`; return messages; } } + +export function isPromptEmpty(request: vscode.ChatRequest): boolean { + return !request.prompt || request.prompt.trim().length === 0; +} diff --git a/src/participant/prompts/history.ts b/src/participant/prompts/history.ts index fe13d2cd..726640f8 100644 --- a/src/participant/prompts/history.ts +++ b/src/participant/prompts/history.ts @@ -68,3 +68,16 @@ export function getHistoryMessages({ return messages; } + +export function doesLastMessageAskForNamespace( + history: ReadonlyArray +): boolean { + const lastMessageMetaData: vscode.ChatResponseTurn | undefined = history[ + history.length - 1 + ] as vscode.ChatResponseTurn; + + return ( + (lastMessageMetaData?.result as NamespaceRequestChatResult)?.metadata + ?.intent === 'askForNamespace' + ); +} diff --git a/src/participant/prompts/schema.ts b/src/participant/prompts/schema.ts new file mode 100644 index 00000000..ffbe6ba6 --- /dev/null +++ b/src/participant/prompts/schema.ts @@ -0,0 +1,85 @@ +import * as vscode from 'vscode'; + +import { getHistoryMessages } from './history'; + +export const DOCUMENTS_TO_SAMPLE_FOR_SCHEMA_PROMPT = 100; + +export type OpenSchemaCommandArgs = { + schema: string; +}; + +export class SchemaPrompt { + static getAssistantPrompt({ + amountOfDocumentsSampled, + }: { + amountOfDocumentsSampled: number; + }): vscode.LanguageModelChatMessage { + const prompt = `You are a senior engineer who describes the schema of documents in a MongoDB database. +The schema is generated from a sample of documents in the user's collection. +You must follows these rules. +Rule 1: Try to be as concise as possible. +Rule 2: Pay attention to the JSON schema. +Rule 3: Mention the amount of documents sampled in your response. +Amount of documents sampled: ${amountOfDocumentsSampled}.`; + + // eslint-disable-next-line new-cap + return vscode.LanguageModelChatMessage.Assistant(prompt); + } + + static getUserPrompt({ + databaseName, + collectionName, + prompt, + schema, + }: { + databaseName: string; + collectionName: string; + prompt: string; + schema: string; + }): vscode.LanguageModelChatMessage { + const userInput = `${ + prompt ? `The user provided additional information: "${prompt}"\n` : '' + }Database name: ${databaseName} +Collection name: ${collectionName} +Schema: +${schema}`; + + // eslint-disable-next-line new-cap + return vscode.LanguageModelChatMessage.User(userInput); + } + + static buildMessages({ + context, + databaseName, + collectionName, + schema, + amountOfDocumentsSampled, + request, + connectionNames, + }: { + request: { + prompt: string; + }; + databaseName: string; + collectionName: string; + schema: string; + amountOfDocumentsSampled: number; + context: vscode.ChatContext; + connectionNames: string[]; + }): vscode.LanguageModelChatMessage[] { + const messages = [ + SchemaPrompt.getAssistantPrompt({ + amountOfDocumentsSampled, + }), + ...getHistoryMessages({ context, connectionNames }), + SchemaPrompt.getUserPrompt({ + prompt: request.prompt, + databaseName, + collectionName, + schema, + }), + ]; + + return messages; + } +} diff --git a/src/participant/schema.ts b/src/participant/schema.ts index faf1e7fb..ad247eb7 100644 --- a/src/participant/schema.ts +++ b/src/participant/schema.ts @@ -22,13 +22,13 @@ export class SchemaFormatter { private processSchemaTypeList( prefix: string, pTypes: SimplifiedSchemaType[] - ) { + ): void { if (pTypes.length !== 0) { this.processSchemaType(prefix, pTypes[0]); } } - private processSchemaType(prefix: string, pType: SimplifiedSchemaType) { + private processSchemaType(prefix: string, pType: SimplifiedSchemaType): void { const bsonType = pType.bsonType; if (bsonType === 'Document') { const fields = (pType as SimplifiedSchemaDocumentType).fields; @@ -67,7 +67,7 @@ export class SchemaFormatter { this.addToFormattedSchemaString(prefix + ': ' + bsonType); } - private processDocumentType(prefix: string, pDoc: SimplifiedSchema) { + private processDocumentType(prefix: string, pDoc: SimplifiedSchema): void { if (!pDoc) { return; } @@ -93,7 +93,7 @@ export class SchemaFormatter { } } - addToFormattedSchemaString(fieldAndType: string) { + addToFormattedSchemaString(fieldAndType: string): void { if (this.schemaString.length > 0) { this.schemaString += '\n'; } diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index 37897911..5d1711a9 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -291,12 +291,14 @@ suite('Participant Controller Test Suite', function () { }); test('calls connect by id for an existing connection', async function () { - await testParticipantController.connectWithParticipant('123'); + await testParticipantController.connectWithParticipant({ + id: '123', + }); expect(connectWithConnectionIdStub).to.have.been.calledWithExactly('123'); }); test('calls connect with uri for a new connection', async function () { - await testParticipantController.connectWithParticipant(); + await testParticipantController.connectWithParticipant({}); expect(changeActiveConnectionStub).to.have.been.called; }); }); From 2e97557a70dcb8b9f29dc745ceb4614d3c00c223 Mon Sep 17 00:00:00 2001 From: Rhys Howell Date: Fri, 20 Sep 2024 14:33:28 -0400 Subject: [PATCH 2/2] add tests --- src/participant/participant.ts | 7 + src/test/suite/mdbExtensionController.test.ts | 82 +++++-- .../suite/participant/participant.test.ts | 229 ++++++++++++++---- 3 files changed, 255 insertions(+), 63 deletions(-) diff --git a/src/participant/participant.ts b/src/participant/participant.ts index c08bc484..40636588 100644 --- a/src/participant/participant.ts +++ b/src/participant/participant.ts @@ -567,6 +567,7 @@ export default class ParticipantController { chatId: ChatMetadataStore.getChatIdFromHistoryOrNewChatId( context.history ), + databaseName, }, name: 'Show more', }), @@ -734,6 +735,12 @@ export default class ParticipantController { } ); + if (!sampleDocuments) { + return { + amountOfDocumentsSampled: 0, + }; + } + let schema: string; if (schemaFormat === 'simplified') { const unformattedSchema = await getSimplifiedSchema(sampleDocuments); diff --git a/src/test/suite/mdbExtensionController.test.ts b/src/test/suite/mdbExtensionController.test.ts index e887485b..5e11b1cc 100644 --- a/src/test/suite/mdbExtensionController.test.ts +++ b/src/test/suite/mdbExtensionController.test.ts @@ -33,7 +33,7 @@ const testDatabaseURI = 'mongodb://localhost:27088'; function getTestConnectionTreeItem( options?: Partial[0]> -) { +): ConnectionTreeItem { return new ConnectionTreeItem({ connectionId: 'tasty_sandwhich', collapsibleState: vscode.TreeItemCollapsibleState.None, @@ -48,7 +48,7 @@ function getTestConnectionTreeItem( function getTestCollectionTreeItem( options?: Partial[0]> -) { +): CollectionTreeItem { return new CollectionTreeItem({ collection: { name: 'testColName', @@ -65,7 +65,7 @@ function getTestCollectionTreeItem( function getTestDatabaseTreeItem( options?: Partial[0]> -) { +): DatabaseTreeItem { return new DatabaseTreeItem({ databaseName: 'zebra', dataService: {} as DataService, @@ -78,7 +78,7 @@ function getTestDatabaseTreeItem( function getTestStreamProcessorTreeItem( options?: Partial[0]> -) { +): StreamProcessorTreeItem { return new StreamProcessorTreeItem({ streamProcessorName: 'zebra', streamProcessorState: 'CREATED', @@ -88,7 +88,7 @@ function getTestStreamProcessorTreeItem( }); } -function getTestFieldTreeItem() { +function getTestFieldTreeItem(): FieldTreeItem { return new FieldTreeItem({ field: { name: 'dolphins are sentient', @@ -101,7 +101,7 @@ function getTestFieldTreeItem() { }); } -function getTestSchemaTreeItem() { +function getTestSchemaTreeItem(): SchemaTreeItem { return new SchemaTreeItem({ databaseName: 'zebraWearwolf', collectionName: 'giraffeVampire', @@ -116,7 +116,7 @@ function getTestSchemaTreeItem() { function getTestDocumentTreeItem( options?: Partial[0]> -) { +): DocumentTreeItem { return new DocumentTreeItem({ document: {}, namespace: 'waffle.house', @@ -129,10 +129,14 @@ function getTestDocumentTreeItem( suite('MDBExtensionController Test Suite', function () { this.timeout(10000); + const sandbox = sinon.createSandbox(); + + afterEach(() => { + sandbox.restore(); + }); suite('when not connected', () => { let showErrorMessageStub: SinonSpy; - const sandbox = sinon.createSandbox(); beforeEach(() => { sandbox.stub(vscode.window, 'showInformationMessage'); @@ -145,10 +149,6 @@ suite('MDBExtensionController Test Suite', function () { ); }); - afterEach(() => { - sandbox.restore(); - }); - test('mdb.addDatabase command fails when not connected to the connection', async () => { const testTreeItem = getTestConnectionTreeItem(); const addDatabaseSucceeded = await vscode.commands.executeCommand( @@ -177,8 +177,6 @@ suite('MDBExtensionController Test Suite', function () { let fakeCreatePlaygroundFileWithContent: SinonSpy; let openExternalStub: SinonStub; - const sandbox = sinon.createSandbox(); - beforeEach(() => { showInformationMessageStub = sandbox.stub( vscode.window, @@ -206,10 +204,6 @@ suite('MDBExtensionController Test Suite', function () { ); }); - afterEach(() => { - sandbox.restore(); - }); - test('mdb.viewCollectionDocuments command should call onViewCollectionDocuments on the editor controller with the collection namespace', async () => { const textCollectionTree = getTestCollectionTreeItem(); await vscode.commands.executeCommand( @@ -1853,4 +1847,56 @@ suite('MDBExtensionController Test Suite', function () { }); }); }); + + test('mdb.participantViewRawSchemaOutput command opens a json document with the output', async () => { + const openTextDocumentStub = sandbox.stub( + vscode.workspace, + 'openTextDocument' + ); + const showTextDocumentStub = sandbox.stub( + vscode.window, + 'showTextDocument' + ); + + const schemaContent = `{ + "count": 1, + "fields": [ + { + "name": "_id", + "path": [ + "_id" + ], + "count": 1, + "type": "ObjectId", + "probability": 1, + "hasDuplicates": false, + "types": [ + { + "name": "ObjectId", + "path": [ + "_id" + ], + "count": 1, + "probability": 1, + "bsonType": "ObjectId" + } + ] + } + ] +}`; + await vscode.commands.executeCommand('mdb.participantViewRawSchemaOutput', { + schema: schemaContent, + }); + + assert(openTextDocumentStub.calledOnce); + assert.deepStrictEqual(openTextDocumentStub.firstCall.args[0], { + language: 'json', + content: schemaContent, + }); + + assert(showTextDocumentStub.calledOnce); + assert.deepStrictEqual(showTextDocumentStub.firstCall.args[1], { + preview: true, + }); + }); }); diff --git a/src/test/suite/participant/participant.test.ts b/src/test/suite/participant/participant.test.ts index 5d39960e..2410e7aa 100644 --- a/src/test/suite/participant/participant.test.ts +++ b/src/test/suite/participant/participant.test.ts @@ -41,6 +41,10 @@ const loadedConnection = { const testChatId = 'test-chat-id'; +const encodeStringify = (obj: Record): string => { + return encodeURIComponent(JSON.stringify(obj)); +}; + suite('Participant Controller Test Suite', function () { const extensionContextStub = new ExtensionContextStub(); @@ -54,6 +58,7 @@ suite('Participant Controller Test Suite', function () { let testParticipantController: ParticipantController; let chatContextStub: vscode.ChatContext; let chatStreamStub: { + push: sinon.SinonSpy; markdown: sinon.SinonSpy; button: sinon.SinonSpy; }; @@ -104,6 +109,7 @@ suite('Participant Controller Test Suite', function () { ], }; chatStreamStub = { + push: sinon.fake(), markdown: sinon.fake(), button: sinon.fake(), }; @@ -209,17 +215,15 @@ suite('Participant Controller Test Suite', function () { "Looks like you aren't currently connected, first let's get you connected to the cluster we'd like to create this query to run against." ); const listConnectionsMessage = chatStreamStub.markdown.getCall(1).args[0]; - const expectedContent = encodeURIComponent( - JSON.stringify({ id: 'id', command: '/query' }) - ); + const expectedContent = encodeStringify({ id: 'id', command: '/query' }); expect(listConnectionsMessage.value).to.include( `- localhost` ); const showMoreMessage = chatStreamStub.markdown.getCall(2).args[0]; expect(showMoreMessage.value).to.include( - `- Show more` + `- Show more` ); expect(chatResult?.metadata?.chatId.length).to.equal(testChatId.length); expect({ @@ -252,17 +256,15 @@ suite('Participant Controller Test Suite', function () { "Looks like you aren't currently connected, first let's get you connected to the cluster we'd like to create this query to run against." ); const listConnectionsMessage = chatStreamStub.markdown.getCall(1).args[0]; - const expectedContent = encodeURIComponent( - JSON.stringify({ id: 'id', command: '/query' }) - ); + const expectedContent = encodeStringify({ id: 'id0', command: '/query' }); expect(listConnectionsMessage.value).to.include( `- localhost0` ); const showMoreMessage = chatStreamStub.markdown.getCall(11).args[0]; expect(showMoreMessage.value).to.include( - `- Show more` + `- Show more` ); expect(chatStreamStub.markdown.callCount).to.be.eql(12); expect(chatResult?.metadata?.chatId.length).to.equal(testChatId.length); @@ -292,17 +294,15 @@ suite('Participant Controller Test Suite', function () { "Looks like you aren't currently connected, first let's get you connected to the cluster we'd like to create this query to run against" ); const listConnectionsMessage = chatStreamStub.markdown.getCall(4).args[0]; - const expectedContent = encodeURIComponent( - JSON.stringify({ id: 'id', command: '/query' }) - ); + const expectedContent = encodeStringify({ id: 'id', command: '/query' }); expect(listConnectionsMessage.value).to.include( `- localhost` ); const showMoreMessage = chatStreamStub.markdown.getCall(5).args[0]; expect(showMoreMessage.value).to.include( - `- Show more` + `- Show more` ); expect(chatResult?.metadata?.chatId.length).to.equal(testChatId.length); expect({ @@ -732,21 +732,19 @@ suite('Participant Controller Test Suite', function () { 'What is the name of the database you would like this query to run against?' ); const listDBsMessage = chatStreamStub.markdown.getCall(1).args[0]; - const expectedContent = encodeURIComponent( - JSON.stringify({ - chatId: testChatId, - database: 'dbOne', - command: '/query', - }) - ); + const expectedContent = encodeStringify({ + command: '/query', + chatId: testChatId, + databaseName: 'dbOne', + }); expect(listDBsMessage.value).to.include( `- dbOne` ); const showMoreDBsMessage = chatStreamStub.markdown.getCall(11).args[0]; expect(showMoreDBsMessage.value).to.include( - `- Show more` ); expect(showMoreDBsMessage.value).to.include('"'); @@ -807,22 +805,30 @@ suite('Participant Controller Test Suite', function () { const askForCollMessage = chatStreamStub.markdown.getCall(12).args[0]; expect(askForCollMessage).to.include( - 'Which collection would you like to query within dbOne?' + 'Which collection would you like to use within dbOne?' ); const listCollsMessage = chatStreamStub.markdown.getCall(13).args[0]; + const expectedCollsContent = encodeStringify({ + command: '/query', + chatId: testChatId, + databaseName: 'dbOne', + collectionName: 'collOne', + }); expect(listCollsMessage.value).to.include( - '- collOne' + `- collOne` ); const showMoreCollsMessage = chatStreamStub.markdown.getCall(23).args[0]; expect(showMoreCollsMessage.value).to.include( - '- Show more` ); - expect(showMoreCollsMessage.value).to.include('">Show more'); expect(chatStreamStub.markdown.callCount).to.be.eql(24); expect(chatResult2?.metadata?.chatId).to.equal(firstChatId); expect({ @@ -968,15 +974,23 @@ suite('Participant Controller Test Suite', function () { ); const listDBsMessage = chatStreamStub.markdown.getCall(1).args[0]; expect(listDBsMessage.value).to.include( - '- dbOne' + `- dbOne` ); const showMoreDBsMessage = chatStreamStub.markdown.getCall(11).args[0]; expect(showMoreDBsMessage.value).to.include( - '- Show more` ); expect({ ...chatResult?.metadata, @@ -1062,15 +1076,25 @@ suite('Participant Controller Test Suite', function () { ); const listCollsMessage = chatStreamStub.markdown.getCall(1).args[0]; expect(listCollsMessage.value).to.include( - '- collOne' + `- collOne` ); const showMoreCollsMessage = - chatStreamStub.markdown.getCall(1).args[0]; + chatStreamStub.markdown.getCall(11).args[0]; expect(showMoreCollsMessage.value).to.include( - '- Show more` ); expect({ ...chatResult?.metadata, @@ -1085,6 +1109,121 @@ suite('Participant Controller Test Suite', function () { }); }); + suite('schema command', function () { + suite('known namespace from running namespace LLM', function () { + beforeEach(function () { + sendRequestStub.onCall(0).resolves({ + text: ['DATABASE_NAME: dbOne\n', 'COLLECTION_NAME: collOne\n`'], + }); + }); + + test('shows a button to view the json output', async function () { + const chatRequestMock = { + prompt: '', + command: 'schema', + references: [], + }; + sampleStub.resolves([ + { + _id: new ObjectId('63ed1d522d8573fa5c203660'), + }, + ]); + await invokeChatHandler(chatRequestMock); + const expectedSchema = `{ + "count": 1, + "fields": [ + { + "name": "_id", + "path": [ + "_id" + ], + "count": 1, + "type": "ObjectId", + "probability": 1, + "hasDuplicates": false, + "types": [ + { + "name": "ObjectId", + "path": [ + "_id" + ], + "count": 1, + "probability": 1, + "bsonType": "ObjectId" + } + ] + } + ] +}`; + expect(chatStreamStub?.button.getCall(0).args[0]).to.deep.equal({ + command: 'mdb.participantViewRawSchemaOutput', + title: 'Open JSON Output', + arguments: [ + { + schema: expectedSchema, + }, + ], + }); + }); + + test("includes the collection's schema in the request", async function () { + sampleStub.resolves([ + { + _id: new ObjectId('63ed1d522d8573fa5c203660'), + field: { + stringField: + 'There was a house cat who finally got the chance to do what it had always wanted to do.', + arrayField: [new Int32('1')], + }, + }, + { + _id: new ObjectId('63ed1d522d8573fa5c203660'), + field: { + stringField: 'Pineapple.', + arrayField: [new Int32('166')], + }, + }, + ]); + const chatRequestMock = { + prompt: '', + command: 'schema', + references: [], + }; + await invokeChatHandler(chatRequestMock); + const messages = sendRequestStub.secondCall.args[0]; + expect(messages[0].content).to.include( + 'Amount of documents sampled: 2' + ); + expect(messages[1].content).to.include( + `Database name: dbOne +Collection name: collOne +Schema: +{ + "count": 2, + "fields": [` + ); + expect(messages[1].content).to.include(`"name": "arrayField", + "path": [ + "field", + "arrayField" + ],`); + }); + + test('prints a message when no documents are found', async function () { + sampleStub.resolves([]); + const chatRequestMock = { + prompt: '', + command: 'schema', + references: [], + }; + await invokeChatHandler(chatRequestMock); + expect(chatStreamStub?.markdown.getCall(0).args[0]).to.include( + 'Unable to generate a schema from the collection, no documents found.' + ); + }); + }); + }); + suite('docs command', function () { const initialFetch = global.fetch; let fetchStub;