Skip to content

Commit

Permalink
Add _update()
Browse files Browse the repository at this point in the history
Signed-off-by: Jay Wang <jay@zijie.wang>
  • Loading branch information
xiaohk committed Jan 31, 2024
1 parent f0e61b3 commit 09ae482
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 44 deletions.
150 changes: 128 additions & 22 deletions src/mememo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -198,17 +198,14 @@ export class HNSW<T = string> {

// Randomly determine the max level of this node
const level = maxLevel === undefined ? this._getRandomLevel() : maxLevel;
console.log('random level:', level);
// console.log('random level:', level);

// Add this node to the node index first
this.nodes.set(key, new Node(key, value));

if (this.entryPointKey !== null) {
// (1): Search closest point from layers above
const entryPointInfo = this.nodes.get(this.entryPointKey);
if (entryPointInfo === undefined) {
throw Error(`Can't find node info of ${this.entryPointKey}`);
}
const entryPointInfo = this._getNodeInfo(this.entryPointKey);

// Start with the entry point
let minDistance = this.distanceFunction(value, entryPointInfo.value);
Expand Down Expand Up @@ -304,6 +301,116 @@ export class HNSW<T = string> {
}
}

/**
* Update an element in the index
* @param key Key of the element.
* @param value The new embedding of the element
*/
_update(key: T, value: number[]) {
if (!this.nodes.has(key)) {
throw Error(`The node with key ${key} does not exist.`);
}

this.nodes.set(key, new Node(key, value));

if (this.entryPointKey === key && this.nodes.size === 1) {
return;
}

// Re-index all the neighbors of this node in all layers
for (let l = 0; l < this.graphLayers.length; l++) {
const curGraphLayer = this.graphLayers[l];
// Layer 0 could have a different neighbor size constraint
const levelM = l === 0 ? this.mMax0 : this.m;

// If the current layer doesn't have this node, then the upper layers
// won't have it either
if (!curGraphLayer.graph.has(key)) {
break;
}
const curNode = curGraphLayer.graph.get(key)!;

// For each neighbor, we use the entire second-degree neighborhood of the
// updating node as new connection candidates
const secondDegreeNeighborhood: Set<T> = new Set([key]);

// Find the second-degree neighborhood
for (const firstDegreeNeighbor of curNode.keys()) {
secondDegreeNeighborhood.add(firstDegreeNeighbor);

const firstDegreeNeighborNode =
curGraphLayer.graph.get(firstDegreeNeighbor);
if (firstDegreeNeighborNode === undefined) {
throw Error(`Can't find node with key ${firstDegreeNeighbor}`);
}

for (const secondDegreeNeighbor of firstDegreeNeighborNode.keys()) {
secondDegreeNeighborhood.add(secondDegreeNeighbor);
}
}

// Update the first-degree neighbor's connections
const nodeCompare: IGetCompareValue<SearchNodeCandidate<T>> = (
candidate: SearchNodeCandidate<T>
) => candidate.distance;

for (const firstDegreeNeighbor of curNode.keys()) {
// (1) Find `efConstruction` number of candidates
const candidateMaxHeap = new MaxHeap(nodeCompare);
const firstDegreeNeighborInfo = this._getNodeInfo(firstDegreeNeighbor);

for (const secondDegreeNeighbor of secondDegreeNeighborhood) {
if (secondDegreeNeighbor === firstDegreeNeighbor) {
continue;
}

const secondDegreeNeighborInfo =
this._getNodeInfo(secondDegreeNeighbor);

const distance = this.distanceFunction(
firstDegreeNeighborInfo.value,
secondDegreeNeighborInfo.value
);

if (candidateMaxHeap.size() < this.efConstruction) {
// Add to the candidates if we still have open slots
candidateMaxHeap.push({ key: secondDegreeNeighbor, distance });
} else {
// Add to the candidates if the distance is better than the worst
// added candidate, by replacing the worst added candidate
if (distance < candidateMaxHeap.top()!.distance) {
candidateMaxHeap.pop();
candidateMaxHeap.push({ key: secondDegreeNeighbor, distance });
}
}
}

// (2) Select `levelM` number candidates out of the candidates
const candidates = candidateMaxHeap.toArray();
const selectedCandidates = this._selectNeighborsHeuristic(
candidates,
levelM
);

// (3) Update the neighbor's neighborhood
const newNeighborNode = new Map<T, number>();
for (const neighborNeighbor of selectedCandidates) {
newNeighborNode.set(neighborNeighbor.key, neighborNeighbor.distance);
}
curGraphLayer.graph.set(firstDegreeNeighbor, newNeighborNode);
}
}

// After re-indexing the neighbors of the updating node, we also need to
// update the outgoing edges of the updating node in all layers. This is
// similar to the initial indexing procedure in insert()
this._reIndexNode(key, value);
}

_reIndexNode(key: T, value: number[]) {
// Pass
}

/**
* Greedy search the closest neighbor in a layer.
* @param queryValue The embedding value of the query
Expand Down Expand Up @@ -345,10 +452,7 @@ export class HNSW<T = string> {
if (!visitedNodes.has(key)) {
visitedNodes.add(key);
// Compute the distance between the node and query
const curNodeInfo = this.nodes.get(key);
if (curNodeInfo === undefined) {
throw Error(`Cannot find node info with key ${key}`);
}
const curNodeInfo = this._getNodeInfo(key);
const distance = this.distanceFunction(curNodeInfo.value, queryValue);

// Continue explore the node's neighbors if the distance is improving
Expand Down Expand Up @@ -416,10 +520,7 @@ export class HNSW<T = string> {
visitedNodes.add(neighborKey);

// Compute the distance of the neighbor and query
const neighborInfo = this.nodes.get(neighborKey);
if (neighborInfo === undefined) {
throw Error(`Cannot find node with key ${neighborKey}`);
}
const neighborInfo = this._getNodeInfo(neighborKey);
const distance = this.distanceFunction(
queryValue,
neighborInfo.value
Expand Down Expand Up @@ -493,15 +594,8 @@ export class HNSW<T = string> {

// Iterate selected neighbors to see if the candidate is further away
for (const selectedNeighbor of selectedNeighbors) {
const candidateInfo = this.nodes.get(candidate.key);
if (candidateInfo === undefined) {
throw Error(`Can't find node with key ${candidate.key}`);
}

const neighborInfo = this.nodes.get(selectedNeighbor.key);
if (neighborInfo === undefined) {
throw Error(`Can't find node with key ${selectedNeighbor.key}`);
}
const candidateInfo = this._getNodeInfo(candidate.key);
const neighborInfo = this._getNodeInfo(selectedNeighbor.key);

const distanceCandidateToNeighbor = this.distanceFunction(
candidateInfo.value,
Expand Down Expand Up @@ -531,4 +625,16 @@ export class HNSW<T = string> {
_getRandomLevel() {
return Math.floor(-Math.log(this.rng()) * this.ml);
}

/**
* Helper function to get the node in the global index
* @param key Node key
*/
_getNodeInfo(key: T) {
const node = this.nodes.get(key);
if (node === undefined) {
throw Error(`Can't find node with key ${key}`);
}
return node;
}
}
44 changes: 22 additions & 22 deletions test/mememo.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -252,28 +252,28 @@ describe('insert()', () => {
}
});

it.skip('Find random seeds', () => {
// Find random seed that give a nice level sequence to test
const size = 100;
for (let i = 1; i < 100000; i++) {
const rng = randomLcg(i);
const curLevels: number[] = [];
const ml = 1 / Math.log(16);

for (let j = 0; j < size; j++) {
const level = Math.floor(-Math.log(rng()) * ml);
curLevels.push(level);
}

if (Math.max(...curLevels) < 4) {
const levelSum = curLevels.reduce((sum, value) => sum + value, 0);
if (levelSum > 20) {
console.log('Good seed: ', i);
break;
}
}
}
});
// it('Find random seeds', () => {
// // Find random seed that give a nice level sequence to test
// const size = 100;
// for (let i = 1; i < 100000; i++) {
// const rng = randomLcg(i);
// const curLevels: number[] = [];
// const ml = 1 / Math.log(16);

// for (let j = 0; j < size; j++) {
// const level = Math.floor(-Math.log(rng()) * ml);
// curLevels.push(level);
// }

// if (Math.max(...curLevels) < 4) {
// const levelSum = curLevels.reduce((sum, value) => sum + value, 0);
// if (levelSum > 20) {
// console.log('Good seed: ', i);
// break;
// }
// }
// }
// });
});

//==========================================================================||
Expand Down

0 comments on commit 09ae482

Please sign in to comment.