From 5e8674fac02e31abe080936ab9e26944f68eb615 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Thu, 19 Sep 2024 09:54:39 +0900 Subject: [PATCH] Sort search result by node frequency --- src/services/nodeSearchService.ts | 16 ++++++++++++---- src/stores/nodeDefStore.ts | 20 +++++++++++++++++--- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/src/services/nodeSearchService.ts b/src/services/nodeSearchService.ts index b4a014b51a..af82b6b502 100644 --- a/src/services/nodeSearchService.ts +++ b/src/services/nodeSearchService.ts @@ -3,7 +3,7 @@ import { getNodeSource } from '@/types/nodeSource' import Fuse, { IFuseOptions, FuseSearchOptions } from 'fuse.js' import _ from 'lodash' -type SearchAuxScore = [number, number, number, number] +export type SearchAuxScore = number[] interface ExtraSearchOptions { matchWildcards?: boolean @@ -51,17 +51,20 @@ export class FuseSearch { return aux.map((x) => x.item) } - public calcAuxScores(query: string, entry: T, score: number) { + public calcAuxScores(query: string, entry: T, score: number): SearchAuxScore { let values: string[] = [] if (!this.keys.length) values = [entry as string] else values = this.keys.map((x) => entry[x]) const scores = values.map((x) => this.calcAuxSingle(query, x, score)) - const result = scores.sort(this.compareAux)[0] + let result = scores.sort(this.compareAux)[0] const deprecated = values.some((x) => x.toLocaleLowerCase().includes('deprecated') ) result[0] += deprecated && result[0] != 0 ? 5 : 0 + if (entry['postProcessSearchScores']) { + result = entry['postProcessSearchScores'](result) as SearchAuxScore + } return result } @@ -117,7 +120,12 @@ export class FuseSearch { } public compareAux(a: SearchAuxScore, b: SearchAuxScore) { - return a[0] - b[0] || a[1] - b[1] || a[2] - b[2] || a[3] - b[3] + for (let i = 0; i < Math.min(a.length, b.length); i++) { + if (a[i] !== b[i]) { + return a[i] - b[i] + } + } + return a.length - b.length } } diff --git a/src/stores/nodeDefStore.ts b/src/stores/nodeDefStore.ts index 00360ccd27..a57f92d60f 100644 --- a/src/stores/nodeDefStore.ts +++ b/src/stores/nodeDefStore.ts @@ -1,4 +1,7 @@ -import { NodeSearchService } from '@/services/nodeSearchService' +import { + NodeSearchService, + type SearchAuxScore +} from '@/services/nodeSearchService' import { ComfyNodeDef } from '@/types/apiTypes' import { defineStore } from 'pinia' import { Type, Transform, plainToClass, Expose } from 'class-transformer' @@ -215,6 +218,12 @@ export class ComfyNodeDefImpl { get isDummyFolder(): boolean { return this.name === '' } + + postProcessSearchScores(scores: SearchAuxScore): SearchAuxScore { + const nodeFrequencyStore = useNodeFrequencyStore() + const nodeFrequency = nodeFrequencyStore.getNodeFrequencyByName(this.name) + return [scores[0], -nodeFrequency, ...scores.slice(1)] + } } export const SYSTEM_NODE_DEFS: Record = { @@ -362,7 +371,11 @@ export const useNodeFrequencyStore = defineStore('nodeFrequency', () => { } const getNodeFrequency = (nodeDef: ComfyNodeDefImpl) => { - return nodeFrequencyLookup.value[nodeDef.name] ?? 0 + return getNodeFrequencyByName(nodeDef.name) + } + + const getNodeFrequencyByName = (nodeName: string) => { + return nodeFrequencyLookup.value[nodeName] ?? 0 } const nodeDefStore = useNodeDefStore() @@ -378,6 +391,7 @@ export const useNodeFrequencyStore = defineStore('nodeFrequency', () => { topNodeDefs, isLoaded, loadNodeFrequencies, - getNodeFrequency + getNodeFrequency, + getNodeFrequencyByName } })