Skip to content

Commit

Permalink
feat: expose similarity search and fulltext query as tools
Browse files Browse the repository at this point in the history
Signed-off-by: Dan Selman <danscode@selman.org>
  • Loading branch information
dselman committed May 28, 2024
1 parent 7aaff8f commit d7fad90
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 27 deletions.
4 changes: 2 additions & 2 deletions src/demo/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ async function run() {
const convo = new Conversation(graphModel);
let result = await convo.appendUserMessage('Tell me a joke about actors');
logger.success(result);
result = await convo.appendUserMessage('Which actor acted in Fear and Loathing in Las Vegas?');
result = await convo.appendUserMessage('Which actor is related to Fear and Loathing in Las Vegas?');
logger.success(result);
result = await convo.appendUserMessage('Who directed that movie?');
result = await convo.appendUserMessage('Which director directed that movie?');
logger.success(result);
result = await convo.appendUserMessage('How many movies do we have?');
logger.success(result);
Expand Down
155 changes: 130 additions & 25 deletions src/graphmodel.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { ClassDeclaration, Factory, Introspector, ModelManager, ModelUtil, RelationshipDeclaration, Serializer } from "@accordproject/concerto-core";
import { ClassDeclaration, Factory, Introspector, ModelManager, ModelUtil, Property, RelationshipDeclaration, Serializer } from "@accordproject/concerto-core";
import neo4j, { DateTime, Driver, ManagedTransaction } from 'neo4j-driver';
import { Context, EmbeddingCacheNode, FullTextIndex, GraphModelOptions, PropertyBag, SimilarityResult, VectorIndex } from "./types";
import { ROOT_MODEL, ROOT_NAMESPACE } from "./model";
Expand Down Expand Up @@ -96,7 +96,12 @@ export class GraphModel {
const properties = decl.getProperties()
.filter(p => p.getDecorator('fulltext_index'))
.map(p => p.getName());
return properties.length > 0 ? { properties } : undefined;
return properties.length > 0 ?
{
properties,
indexName: this.getFullTextIndexName(decl),
type: decl.getName()
} : undefined;
}

private getPropertyVectorIndex(property): VectorIndex {
Expand Down Expand Up @@ -134,13 +139,15 @@ export class GraphModel {
throw new Error(`@vector_index decorator on property ${property.getFullyQualifiedName()} is invalid. It references the property ${propertyName} but the property is not Double[].`);
}
return {
type: property.getParent().getName(),
property: propertyName,
size: property.getDecorator('vector_index').getArguments()[1] as unknown as number,
type: property.getDecorator('vector_index').getArguments()[2] as unknown as string,
indexType: property.getDecorator('vector_index').getArguments()[2] as unknown as string,
indexName: this.getPropertyVectorIndexName(property.getParent(), property)
}
}

private getPropertyVectorIndexName(decl, vectorProperty) {
private getPropertyVectorIndexName(decl:ClassDeclaration, vectorProperty: Property) {
return `${decl.getName()}_${vectorProperty.getName()}`.toLowerCase();
}

Expand Down Expand Up @@ -203,23 +210,53 @@ export class GraphModel {
this.options.logger?.info('Create constraints completed');
}

/**
* Get all the vector indexes for the model
*/
getVectorIndexes(): Array<VectorIndex> {
const result: Array<VectorIndex> = [];
const graphNodes = this.getGraphNodeDeclarations();
for (let n = 0; n < graphNodes.length; n++) {
const graphNode = graphNodes[n];
const vectorProperties = graphNode.getProperties().filter(p => p.getDecorator('vector_index'));
for (let i = 0; i < vectorProperties.length; i++) {
const vectorProperty = vectorProperties[i];
const vectorIndex = this.getPropertyVectorIndex(vectorProperty);
result.push(vectorIndex);
}
}
return result;
}

/**
* Get all the full text indexes for the model
*/
getFullTextIndexes(): Array<FullTextIndex> {
const result: Array<FullTextIndex> = [];
const graphNodes = this.getGraphNodeDeclarations();
for (let n = 0; n < graphNodes.length; n++) {
const graphNode = graphNodes[n];
const fullTextIndex = this.getFullTextIndex(graphNode);
if (fullTextIndex) {
result.push(fullTextIndex)
}
}
return result;
}


/**
* Create vector indexes for the model
*/
async createVectorIndexes() {
this.options.logger?.info('Creating vector indexes...');
const { session } = await this.openSession();
await session.executeWrite(async tx => {
const graphNodes = this.getGraphNodeDeclarations();
for (let n = 0; n < graphNodes.length; n++) {
const graphNode = graphNodes[n];
const vectorProperties = graphNode.getProperties().filter(p => p.getDecorator('vector_index'));
for (let i = 0; i < vectorProperties.length; i++) {
const vectorProperty = vectorProperties[i];
const vectorIndex = this.getPropertyVectorIndex(vectorProperty);
const indexName = this.getPropertyVectorIndexName(graphNode, vectorProperty);
await tx.run(`CALL db.index.vector.createNodeIndex("${indexName}", "${graphNode.getName()}", "${vectorIndex.property}", ${vectorIndex.size}, "${vectorIndex.type}")`);
}
const indexes = this.getVectorIndexes();
for (let n = 0; n < indexes.length; n++) {
const index = indexes[n];
console.log(JSON.stringify(index, null, 2));
await tx.run(`CALL db.index.vector.createNodeIndex("${index.indexName}", "${index.type}", "${index.property}", ${index.size}, "${index.indexType}")`);
}
})
await session.close();
Expand All @@ -233,15 +270,11 @@ export class GraphModel {
this.options.logger?.info('Creating full text indexes...');
const { session } = await this.openSession();
await session.executeWrite(async tx => {
const graphNodes = this.getGraphNodeDeclarations();
for (let n = 0; n < graphNodes.length; n++) {
const graphNode = graphNodes[n];
const fullTextIndex = this.getFullTextIndex(graphNode);
if (fullTextIndex) {
const indexName = this.getFullTextIndexName(graphNode);
const props = fullTextIndex.properties.map(p => `n.${p}`);
await tx.run(`CREATE FULLTEXT INDEX ${indexName} FOR (n:${graphNode.getName()}) ON EACH [${props.join(',')}];`);
}
const indexes = this.getFullTextIndexes();
for (let n = 0; n < indexes.length; n++) {
const index = indexes[n];
const props = index.properties.map(p => `n.${p}`);
await tx.run(`CREATE FULLTEXT INDEX ${index.indexName} FOR (n:${index.type}) ON EACH [${props.join(',')}];`);
}
})
await session.close();
Expand Down Expand Up @@ -431,6 +464,7 @@ export class GraphModel {
if (!textContentNode.embedding) {
throw new Error(`Internal error. Failed to get embedding for ${searchText}`);
}
this.options.logger?.info(`Similarity query of '${typeName}.${propertyName}' for '${searchText}'`);
return this.similarityQueryFromEmbedding(typeName, propertyName, textContentNode.embedding, count);
}
catch (err) {
Expand Down Expand Up @@ -495,6 +529,7 @@ export class GraphModel {
if (!fullTextIndex) {
throw new Error(`No full text index for properties of ${typeName}`);
}
this.options.logger?.info(`Fulltext search of '${typeName}' for '${searchText}'`);
const indexName = this.getFullTextIndexName(graphNode);
const props = fullTextIndex.properties.map(p => `node.${p}`);
props.push('node.identifier');
Expand Down Expand Up @@ -596,7 +631,7 @@ export class GraphModel {
try {
return await this.query(`MATCH (n:${node.getName()} WHERE n.identifier='${name}') RETURN n;`);
}
catch(err) {
catch (err) {
return `An error occurred: ${err}`;
}
}),
Expand All @@ -613,6 +648,76 @@ export class GraphModel {
}
})
}
// full text search
const fullTextIndexes = this.getFullTextIndexes();
for (let n = 0; n < fullTextIndexes.length; n++) {
const index = fullTextIndexes[n];
result.push({
type: "function",
function: {
description: `Full-text search over ${index.type}`,
name: `fulltext_${index.type.toLowerCase()}`,
function: (async (args: { search: string, count?: number }) => {
const { search, count } = args;
try {
return await this.fullTextQuery(index.type, search, count ? count : 10);
}
catch (err) {
return `An error occurred: ${err}`;
}
}),
parse: JSON.parse,
parameters: {
"type": "object",
"properties": {
"search": {
"type": "string",
},
"count": {
"type": "number",
}
},
"required": ["search"]
}
}
})
}

// similarity search
const vectorIndexes = this.getVectorIndexes();
for (let n = 0; n < vectorIndexes.length; n++) {
const index = vectorIndexes[n];
result.push({
type: "function",
function: {
description: `Similiarity/conceptual search over ${index.type}.${index.property}`,
name: `similarity_${index.type.toLowerCase()}_${index.property.toLowerCase()}`,
function: (async (args: { query: string, property: string, count?: number }) => {
const { query, count } = args;
try {
return await this.similarityQuery(index.type, index.property, query, count ? count : 10);
}
catch (err) {
return `An error occurred: ${err}`;
}
}),
parse: JSON.parse,
parameters: {
"type": "object",
"properties": {
"query": {
"type": "string",
},
"count": {
"type": "number",
}
},
"required": ["query"]
}
}
})
}

// generic: chat with data...
result.push({
type: "function",
Expand All @@ -624,7 +729,7 @@ export class GraphModel {
try {
return await this.chatWithData(query);
}
catch(err) {
catch (err) {
return `An error occurred: ${err}`;
}
}),
Expand Down
4 changes: 4 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@ export type VectorIndex = {
property: string;
size: number;
type: string;
indexType: string;
indexName: string;
}

/**
* Definition of a full text index over some properties
*/
export type FullTextIndex = {
type: string;
properties: Array<string>;
indexName: string;
}

/**
Expand Down

0 comments on commit d7fad90

Please sign in to comment.