From d7fad90b4d14f7e274bc7920a0b495fb75c52dcc Mon Sep 17 00:00:00 2001
From: Dan Selman
Date: Tue, 28 May 2024 14:05:13 +0100
Subject: [PATCH] feat: expose similarity search and fulltext query as tools

Add getVectorIndexes() and getFullTextIndexes() so that index definitions
are computed in one place, carry their index name and node type, and are
reused by createVectorIndexes() and createFullTextIndexes().

Register one `fulltext_<type>` tool per full text index and one
`similarity_<type>_<property>` tool per vector index, next to the existing
tools. Vector index metadata is resolved against the graph node being
iterated (as index creation did before this change) rather than against
the class that declares the property.

Signed-off-by: Dan Selman
---
 src/demo/index.ts |   4 +-
 src/graphmodel.ts | 155 ++++++++++++++++++++++++++++++++++++++--------
 src/types.ts      |   4 ++
 3 files changed, 136 insertions(+), 27 deletions(-)

diff --git a/src/demo/index.ts b/src/demo/index.ts
index e13c269..346b216 100644
--- a/src/demo/index.ts
+++ b/src/demo/index.ts
@@ -174,9 +174,9 @@ async function run() {
         const convo = new Conversation(graphModel);
         let result = await convo.appendUserMessage('Tell me a joke about actors');
         logger.success(result);
-        result = await convo.appendUserMessage('Which actor acted in Fear and Loathing in Las Vegas?');
+        result = await convo.appendUserMessage('Which actor is related to Fear and Loathing in Las Vegas?');
         logger.success(result);
-        result = await convo.appendUserMessage('Who directed that movie?');
+        result = await convo.appendUserMessage('Which director directed that movie?');
         logger.success(result);
         result = await convo.appendUserMessage('How many movies do we have?');
         logger.success(result);
diff --git a/src/graphmodel.ts b/src/graphmodel.ts
index b574cb2..4e9c461 100644
--- a/src/graphmodel.ts
+++ b/src/graphmodel.ts
@@ -1,4 +1,4 @@
-import { ClassDeclaration, Factory, Introspector, ModelManager, ModelUtil, RelationshipDeclaration, Serializer } from "@accordproject/concerto-core";
+import { ClassDeclaration, Factory, Introspector, ModelManager, ModelUtil, Property, RelationshipDeclaration, Serializer } from "@accordproject/concerto-core";
 import neo4j, { DateTime, Driver, ManagedTransaction } from 'neo4j-driver';
 import { Context, EmbeddingCacheNode, FullTextIndex, GraphModelOptions, PropertyBag, SimilarityResult, VectorIndex } from "./types";
 import { ROOT_MODEL, ROOT_NAMESPACE } from "./model";
@@ -96,7 +96,12 @@ export class GraphModel {
         const properties = decl.getProperties()
             .filter(p => p.getDecorator('fulltext_index'))
             .map(p => p.getName());
-        return properties.length > 0 ? { properties } : undefined;
+        return properties.length > 0 ?
+            {
+                properties,
+                indexName: this.getFullTextIndexName(decl),
+                type: decl.getName()
+            } : undefined;
     }
 
     private getPropertyVectorIndex(property): VectorIndex {
@@ -134,13 +139,15 @@ export class GraphModel {
             throw new Error(`@vector_index decorator on property ${property.getFullyQualifiedName()} is invalid. It references the property ${propertyName} but the property is not Double[].`);
         }
         return {
+            type: property.getParent().getName(),
             property: propertyName,
             size: property.getDecorator('vector_index').getArguments()[1] as unknown as number,
-            type: property.getDecorator('vector_index').getArguments()[2] as unknown as string,
+            indexType: property.getDecorator('vector_index').getArguments()[2] as unknown as string,
+            indexName: this.getPropertyVectorIndexName(property.getParent(), property)
         }
     }
 
-    private getPropertyVectorIndexName(decl, vectorProperty) {
+    private getPropertyVectorIndexName(decl: ClassDeclaration, vectorProperty: Property) {
         return `${decl.getName()}_${vectorProperty.getName()}`.toLowerCase();
     }
 
@@ -203,6 +210,41 @@ export class GraphModel {
         this.options.logger?.info('Create constraints completed');
     }
 
+    /**
+     * Get the vector index definition of every @vector_index property on the model's graph nodes
+     */
+    getVectorIndexes(): Array<VectorIndex> {
+        const result: Array<VectorIndex> = [];
+        const graphNodes = this.getGraphNodeDeclarations();
+        for (let n = 0; n < graphNodes.length; n++) {
+            const graphNode = graphNodes[n];
+            const vectorProperties = graphNode.getProperties().filter(p => p.getDecorator('vector_index'));
+            for (let i = 0; i < vectorProperties.length; i++) {
+                const vectorProperty = vectorProperties[i];
+                const indexName = this.getPropertyVectorIndexName(graphNode, vectorProperty);
+                result.push({ ...this.getPropertyVectorIndex(vectorProperty), type: graphNode.getName(), indexName });
+            }
+        }
+        return result;
+    }
+
+    /**
+     * Get the full text index definition of every graph node that has @fulltext_index properties
+     */
+    getFullTextIndexes(): Array<FullTextIndex> {
+        const result: Array<FullTextIndex> = [];
+        const graphNodes = this.getGraphNodeDeclarations();
+        for (let n = 0; n < graphNodes.length; n++) {
+            const graphNode = graphNodes[n];
+            const fullTextIndex = this.getFullTextIndex(graphNode);
+            if (fullTextIndex) {
+                result.push(fullTextIndex);
+            }
+        }
+        return result;
+    }
+
+
     /**
      * Create vector indexes for the model
      */
@@ -210,16 +252,11 @@ export class GraphModel {
         this.options.logger?.info('Creating vector indexes...');
         const { session } = await this.openSession();
         await session.executeWrite(async tx => {
-            const graphNodes = this.getGraphNodeDeclarations();
-            for (let n = 0; n < graphNodes.length; n++) {
-                const graphNode = graphNodes[n];
-                const vectorProperties = graphNode.getProperties().filter(p => p.getDecorator('vector_index'));
-                for (let i = 0; i < vectorProperties.length; i++) {
-                    const vectorProperty = vectorProperties[i];
-                    const vectorIndex = this.getPropertyVectorIndex(vectorProperty);
-                    const indexName = this.getPropertyVectorIndexName(graphNode, vectorProperty);
-                    await tx.run(`CALL db.index.vector.createNodeIndex("${indexName}", "${graphNode.getName()}", "${vectorIndex.property}", ${vectorIndex.size}, "${vectorIndex.type}")`);
-                }
+            const indexes = this.getVectorIndexes();
+            for (let n = 0; n < indexes.length; n++) {
+                const index = indexes[n];
+                this.options.logger?.info(`Creating vector index ${index.indexName}`);
+                await tx.run(`CALL db.index.vector.createNodeIndex("${index.indexName}", "${index.type}", "${index.property}", ${index.size}, "${index.indexType}")`);
             }
         })
         await session.close();
@@ -233,15 +270,11 @@ export class GraphModel {
         this.options.logger?.info('Creating full text indexes...');
         const { session } = await this.openSession();
         await session.executeWrite(async tx => {
-            const graphNodes = this.getGraphNodeDeclarations();
-            for (let n = 0; n < graphNodes.length; n++) {
-                const graphNode = graphNodes[n];
-                const fullTextIndex = this.getFullTextIndex(graphNode);
-                if (fullTextIndex) {
-                    const indexName = this.getFullTextIndexName(graphNode);
-                    const props = fullTextIndex.properties.map(p => `n.${p}`);
-                    await tx.run(`CREATE FULLTEXT INDEX ${indexName} FOR (n:${graphNode.getName()}) ON EACH [${props.join(',')}];`);
-                }
+            const indexes = this.getFullTextIndexes();
+            for (let n = 0; n < indexes.length; n++) {
+                const index = indexes[n];
+                const props = index.properties.map(p => `n.${p}`);
+                await tx.run(`CREATE FULLTEXT INDEX ${index.indexName} FOR (n:${index.type}) ON EACH [${props.join(',')}];`);
             }
         })
         await session.close();
@@ -431,6 +464,7 @@ export class GraphModel {
             if (!textContentNode.embedding) {
                 throw new Error(`Internal error. Failed to get embedding for ${searchText}`);
             }
+            this.options.logger?.info(`Similarity query of '${typeName}.${propertyName}' for '${searchText}'`);
             return this.similarityQueryFromEmbedding(typeName, propertyName, textContentNode.embedding, count);
         }
         catch (err) {
@@ -495,6 +529,7 @@ export class GraphModel {
         if (!fullTextIndex) {
             throw new Error(`No full text index for properties of ${typeName}`);
         }
+        this.options.logger?.info(`Fulltext search of '${typeName}' for '${searchText}'`);
         const indexName = this.getFullTextIndexName(graphNode);
         const props = fullTextIndex.properties.map(p => `node.${p}`);
         props.push('node.identifier');
@@ -596,7 +631,7 @@ export class GraphModel {
                         try {
                             return await this.query(`MATCH (n:${node.getName()} WHERE n.identifier='${name}') RETURN n;`);
                         }
-                        catch(err) {
+                        catch (err) {
                             return `An error occurred: ${err}`;
                         }
                     }),
@@ -613,6 +648,76 @@ export class GraphModel {
                 }
             })
         }
+        // full text search
+        const fullTextIndexes = this.getFullTextIndexes();
+        for (let n = 0; n < fullTextIndexes.length; n++) {
+            const index = fullTextIndexes[n];
+            result.push({
+                type: "function",
+                function: {
+                    description: `Full-text search over ${index.type}`,
+                    name: `fulltext_${index.type.toLowerCase()}`,
+                    function: (async (args: { search: string, count?: number }) => {
+                        const { search, count } = args;
+                        try {
+                            return await this.fullTextQuery(index.type, search, count ? count : 10);
+                        }
+                        catch (err) {
+                            return `An error occurred: ${err}`;
+                        }
+                    }),
+                    parse: JSON.parse,
+                    parameters: {
+                        "type": "object",
+                        "properties": {
+                            "search": {
+                                "type": "string",
+                            },
+                            "count": {
+                                "type": "number",
+                            }
+                        },
+                        "required": ["search"]
+                    }
+                }
+            })
+        }
+
+        // similarity search
+        const vectorIndexes = this.getVectorIndexes();
+        for (let n = 0; n < vectorIndexes.length; n++) {
+            const index = vectorIndexes[n];
+            result.push({
+                type: "function",
+                function: {
+                    description: `Similarity/conceptual search over ${index.type}.${index.property}`,
+                    name: `similarity_${index.type.toLowerCase()}_${index.property.toLowerCase()}`,
+                    function: (async (args: { query: string, count?: number }) => {
+                        const { query, count } = args;
+                        try {
+                            return await this.similarityQuery(index.type, index.property, query, count ? count : 10);
+                        }
+                        catch (err) {
+                            return `An error occurred: ${err}`;
+                        }
+                    }),
+                    parse: JSON.parse,
+                    parameters: {
+                        "type": "object",
+                        "properties": {
+                            "query": {
+                                "type": "string",
+                            },
+                            "count": {
+                                "type": "number",
+                            }
+                        },
+                        "required": ["query"]
+                    }
+                }
+            })
+        }
+
         // generic: chat with data...
         result.push({
             type: "function",
@@ -624,7 +729,7 @@ export class GraphModel {
                     try {
                         return await this.chatWithData(query);
                     }
-                    catch(err) {
+                    catch (err) {
                         return `An error occurred: ${err}`;
                     }
                 }),
diff --git a/src/types.ts b/src/types.ts
index deaea76..6ec715c 100644
--- a/src/types.ts
+++ b/src/types.ts
@@ -15,13 +15,17 @@ export type VectorIndex = {
     property: string;
     size: number;
     type: string;
+    indexType: string;
+    indexName: string;
 }
 
 /**
  * Definition of a full text index over some properties
  */
 export type FullTextIndex = {
+    type: string;
     properties: Array<string>;
+    indexName: string;
 }
 
 /**