Skip to content

Commit

Permalink
Sort search result by node frequency
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei committed Sep 19, 2024
1 parent 6c4143c commit 5e8674f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
16 changes: 12 additions & 4 deletions src/services/nodeSearchService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { getNodeSource } from '@/types/nodeSource'
import Fuse, { IFuseOptions, FuseSearchOptions } from 'fuse.js'
import _ from 'lodash'

type SearchAuxScore = [number, number, number, number]
export type SearchAuxScore = number[]

interface ExtraSearchOptions {
matchWildcards?: boolean
Expand Down Expand Up @@ -51,17 +51,20 @@ export class FuseSearch<T> {
return aux.map((x) => x.item)
}

public calcAuxScores(query: string, entry: T, score: number) {
public calcAuxScores(query: string, entry: T, score: number): SearchAuxScore {
let values: string[] = []
if (!this.keys.length) values = [entry as string]
else values = this.keys.map((x) => entry[x])
const scores = values.map((x) => this.calcAuxSingle(query, x, score))
const result = scores.sort(this.compareAux)[0]
let result = scores.sort(this.compareAux)[0]

const deprecated = values.some((x) =>
x.toLocaleLowerCase().includes('deprecated')
)
result[0] += deprecated && result[0] != 0 ? 5 : 0
if (entry['postProcessSearchScores']) {
result = entry['postProcessSearchScores'](result) as SearchAuxScore
}
return result
}

Expand Down Expand Up @@ -117,7 +120,12 @@ export class FuseSearch<T> {
}

public compareAux(a: SearchAuxScore, b: SearchAuxScore) {
return a[0] - b[0] || a[1] - b[1] || a[2] - b[2] || a[3] - b[3]
for (let i = 0; i < Math.min(a.length, b.length); i++) {
if (a[i] !== b[i]) {
return a[i] - b[i]
}
}
return a.length - b.length
}
}

Expand Down
20 changes: 17 additions & 3 deletions src/stores/nodeDefStore.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import { NodeSearchService } from '@/services/nodeSearchService'
import {
NodeSearchService,
type SearchAuxScore
} from '@/services/nodeSearchService'
import { ComfyNodeDef } from '@/types/apiTypes'
import { defineStore } from 'pinia'
import { Type, Transform, plainToClass, Expose } from 'class-transformer'
Expand Down Expand Up @@ -215,6 +218,12 @@ export class ComfyNodeDefImpl {
get isDummyFolder(): boolean {
return this.name === ''
}

postProcessSearchScores(scores: SearchAuxScore): SearchAuxScore {
const nodeFrequencyStore = useNodeFrequencyStore()
const nodeFrequency = nodeFrequencyStore.getNodeFrequencyByName(this.name)
return [scores[0], -nodeFrequency, ...scores.slice(1)]
}
}

export const SYSTEM_NODE_DEFS: Record<string, ComfyNodeDef> = {
Expand Down Expand Up @@ -362,7 +371,11 @@ export const useNodeFrequencyStore = defineStore('nodeFrequency', () => {
}

const getNodeFrequency = (nodeDef: ComfyNodeDefImpl) => {
return nodeFrequencyLookup.value[nodeDef.name] ?? 0
return getNodeFrequencyByName(nodeDef.name)
}

const getNodeFrequencyByName = (nodeName: string) => {
return nodeFrequencyLookup.value[nodeName] ?? 0
}

const nodeDefStore = useNodeDefStore()
Expand All @@ -378,6 +391,7 @@ export const useNodeFrequencyStore = defineStore('nodeFrequency', () => {
topNodeDefs,
isLoaded,
loadNodeFrequencies,
getNodeFrequency
getNodeFrequency,
getNodeFrequencyByName
}
})

0 comments on commit 5e8674f

Please sign in to comment.