Skip to content

Commit 8568e03

Browse files
authored
Sort search result by node frequency (#879)
* Sort search result by node frequency * Fix jest test
1 parent 6c4143c commit 8568e03

File tree

3 files changed

+33
-8
lines changed

3 files changed

+33
-8
lines changed

src/services/nodeSearchService.ts

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { getNodeSource } from '@/types/nodeSource'
33
import Fuse, { IFuseOptions, FuseSearchOptions } from 'fuse.js'
44
import _ from 'lodash'
55

6-
type SearchAuxScore = [number, number, number, number]
6+
export type SearchAuxScore = number[]
77

88
interface ExtraSearchOptions {
99
matchWildcards?: boolean
@@ -51,17 +51,20 @@ export class FuseSearch<T> {
5151
return aux.map((x) => x.item)
5252
}
5353

54-
public calcAuxScores(query: string, entry: T, score: number) {
54+
public calcAuxScores(query: string, entry: T, score: number): SearchAuxScore {
5555
let values: string[] = []
5656
if (!this.keys.length) values = [entry as string]
5757
else values = this.keys.map((x) => entry[x])
5858
const scores = values.map((x) => this.calcAuxSingle(query, x, score))
59-
const result = scores.sort(this.compareAux)[0]
59+
let result = scores.sort(this.compareAux)[0]
6060

6161
const deprecated = values.some((x) =>
6262
x.toLocaleLowerCase().includes('deprecated')
6363
)
6464
result[0] += deprecated && result[0] != 0 ? 5 : 0
65+
if (entry['postProcessSearchScores']) {
66+
result = entry['postProcessSearchScores'](result) as SearchAuxScore
67+
}
6568
return result
6669
}
6770

@@ -117,7 +120,12 @@ export class FuseSearch<T> {
117120
}
118121

119122
public compareAux(a: SearchAuxScore, b: SearchAuxScore) {
120-
return a[0] - b[0] || a[1] - b[1] || a[2] - b[2] || a[3] - b[3]
123+
for (let i = 0; i < Math.min(a.length, b.length); i++) {
124+
if (a[i] !== b[i]) {
125+
return a[i] - b[i]
126+
}
127+
}
128+
return a.length - b.length
121129
}
122130
}
123131

src/stores/nodeDefStore.ts

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
import { NodeSearchService } from '@/services/nodeSearchService'
1+
import {
2+
NodeSearchService,
3+
type SearchAuxScore
4+
} from '@/services/nodeSearchService'
25
import { ComfyNodeDef } from '@/types/apiTypes'
36
import { defineStore } from 'pinia'
47
import { Type, Transform, plainToClass, Expose } from 'class-transformer'
@@ -215,6 +218,12 @@ export class ComfyNodeDefImpl {
215218
get isDummyFolder(): boolean {
216219
return this.name === ''
217220
}
221+
222+
postProcessSearchScores(scores: SearchAuxScore): SearchAuxScore {
223+
const nodeFrequencyStore = useNodeFrequencyStore()
224+
const nodeFrequency = nodeFrequencyStore.getNodeFrequencyByName(this.name)
225+
return [scores[0], -nodeFrequency, ...scores.slice(1)]
226+
}
218227
}
219228

220229
export const SYSTEM_NODE_DEFS: Record<string, ComfyNodeDef> = {
@@ -362,7 +371,11 @@ export const useNodeFrequencyStore = defineStore('nodeFrequency', () => {
362371
}
363372

364373
const getNodeFrequency = (nodeDef: ComfyNodeDefImpl) => {
365-
return nodeFrequencyLookup.value[nodeDef.name] ?? 0
374+
return getNodeFrequencyByName(nodeDef.name)
375+
}
376+
377+
const getNodeFrequencyByName = (nodeName: string) => {
378+
return nodeFrequencyLookup.value[nodeName] ?? 0
366379
}
367380

368381
const nodeDefStore = useNodeDefStore()
@@ -378,6 +391,7 @@ export const useNodeFrequencyStore = defineStore('nodeFrequency', () => {
378391
topNodeDefs,
379392
isLoaded,
380393
loadNodeFrequencies,
381-
getNodeFrequency
394+
getNodeFrequency,
395+
getNodeFrequencyByName
382396
}
383397
})

tests-ui/utils/setup.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ export const mockNodeDefStore = () => {
132132
}
133133

134134
jest.mock('@/stores/nodeDefStore', () => ({
135-
useNodeDefStore: jest.fn(() => mockedNodeDefStore)
135+
useNodeDefStore: jest.fn(() => mockedNodeDefStore),
136+
useNodeFrequencyStore: jest.fn(() => ({
137+
getNodeFrequencyByName: jest.fn(() => 0)
138+
}))
136139
}))
137140
}

0 commit comments

Comments
 (0)