From bdba18d642777c89205b71ee69a31481a745078d Mon Sep 17 00:00:00 2001 From: Arvin Xu Date: Sat, 13 Apr 2024 11:21:21 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20perf:=20fix=20performance?= =?UTF-8?q?=20issue=20with=20model=20list=20(#2012)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ⚡️ perf: improve performance * ⚡️ perf: improve performance * ⚡️ perf: improve performance --- src/config/modelProviders/index.ts | 18 ++++ .../AgentSetting/AgentConfig/ModelSelect.tsx | 1 + src/store/global/slices/common/action.ts | 9 ++ .../slices/settings/actions/llm.test.ts | 96 ++++++++++++++++++- .../global/slices/settings/actions/llm.ts | 96 +++++++++++++++++++ .../global/slices/settings/initialState.ts | 6 ++ .../settings/selectors/modelProvider.test.ts | 66 +------------ .../settings/selectors/modelProvider.ts | 89 +++-------------- 8 files changed, 233 insertions(+), 148 deletions(-) diff --git a/src/config/modelProviders/index.ts b/src/config/modelProviders/index.ts index b4c89c736fdf2..dfbb5caf05237 100644 --- a/src/config/modelProviders/index.ts +++ b/src/config/modelProviders/index.ts @@ -1,6 +1,7 @@ import { ChatModelCard, ModelProviderCard } from '@/types/llm'; import AnthropicProvider from './anthropic'; +import AzureProvider from './azure'; import BedrockProvider from './bedrock'; import GoogleProvider from './google'; import GroqProvider from './groq'; @@ -30,6 +31,23 @@ export const LOBE_DEFAULT_MODEL_LIST: ChatModelCard[] = [ ZeroOneProvider.chatModels, ].flat(); +export const DEFAULT_MODEL_PROVIDER_LIST = [ + OpenAIProvider, + { ...AzureProvider, chatModels: [] }, + OllamaProvider, + AnthropicProvider, + GoogleProvider, + OpenRouterProvider, + TogetherAIProvider, + BedrockProvider, + PerplexityProvider, + MistralProvider, + GroqProvider, + MoonshotProvider, + ZeroOneProvider, + ZhiPuProvider, +]; + export const filterEnabledModels = (provider: ModelProviderCard) => { return provider.chatModels.filter((v) => v.enabled).map((m) => m.id); }; diff --git a/src/features/AgentSetting/AgentConfig/ModelSelect.tsx b/src/features/AgentSetting/AgentConfig/ModelSelect.tsx index f966530c58ab4..aa61124fac237 100644 --- a/src/features/AgentSetting/AgentConfig/ModelSelect.tsx +++ b/src/features/AgentSetting/AgentConfig/ModelSelect.tsx @@ -29,6 +29,7 @@ const ModelSelect = memo(() => { modelProviderSelectors.modelProviderListForModelSelect, isEqual, ); + const { styles } = useStyles(); const options = useMemo(() => { diff --git a/src/store/global/slices/common/action.ts b/src/store/global/slices/common/action.ts index 22d89cb3566bf..aeeb71b4e2ad4 100644 --- a/src/store/global/slices/common/action.ts +++ b/src/store/global/slices/common/action.ts @@ -62,6 +62,9 @@ export const createCommonSlice: StateCreator< refreshUserConfig: async () => { await mutate([USER_CONFIG_FETCH_KEY, true]); + + // when get the user config ,refresh the model provider list to the latest + get().refreshModelProviderList(); }, switchBackToChat: (sessionId) => { @@ -159,7 +162,10 @@ export const createCommonSlice: StateCreator< }; const defaultSettings = merge(get().defaultSettings, serverSettings); + set({ defaultSettings, serverConfig: data }, false, n('initGlobalConfig')); + + get().refreshDefaultModelProviderList(); } }, revalidateOnFocus: false, @@ -181,6 +187,9 @@ export const createCommonSlice: StateCreator< n('fetchUserConfig', data), ); + // when get the user config ,refresh the model provider list to the latest + get().refreshModelProviderList(); + const { language } = settingsSelectors.currentSettings(get()); if (language === 'auto') { switchLang('auto'); diff --git a/src/store/global/slices/settings/actions/llm.test.ts b/src/store/global/slices/settings/actions/llm.test.ts index df7c3b5ddd79a..9ecab760ea111 100644 --- a/src/store/global/slices/settings/actions/llm.test.ts +++ b/src/store/global/slices/settings/actions/llm.test.ts @@ -2,9 +2,18 @@ import { act, renderHook } from '@testing-library/react'; import { describe, expect, it, vi } from 'vitest'; import { userService } from '@/services/user'; -import { useGlobalStore } from '@/store/global'; -import { modelConfigSelectors, settingsSelectors } from '@/store/global/slices/settings/selectors'; +import { GlobalStore, useGlobalStore } from '@/store/global'; +import { + GlobalSettingsState, + initialSettingsState, +} from '@/store/global/slices/settings/initialState'; +import { + modelConfigSelectors, + modelProviderSelectors, + settingsSelectors, +} from '@/store/global/slices/settings/selectors'; import { GeneralModelProviderConfig } from '@/types/settings'; +import { merge } from '@/utils/merge'; import { CustomModelCardDispatch, customModelCardsReducer } from '../reducers/customModelCard'; @@ -15,9 +24,6 @@ vi.mock('@/services/user', () => ({ resetUserSettings: vi.fn(), }, })); -vi.mock('../reducers/customModelCard', () => ({ - customModelCardsReducer: vi.fn().mockReturnValue([]), -})); describe('LLMSettingsSliceAction', () => { describe('setModelProviderConfig', () => { @@ -57,4 +63,84 @@ describe('LLMSettingsSliceAction', () => { expect(result.current.setModelProviderConfig).not.toHaveBeenCalled(); }); }); + + describe('refreshDefaultModelProviderList', () => { + it('default', async () => { + const { result } = renderHook(() => useGlobalStore()); + + act(() => { + useGlobalStore.setState({ + serverConfig: { + languageModel: { + azure: { serverModelCards: [{ id: 'abc', deploymentName: 'abc' }] }, + }, + telemetry: {}, + }, + }); + }); + + act(() => { + result.current.refreshDefaultModelProviderList(); + }); + + // Assert that setModelProviderConfig was not called + const azure = result.current.defaultModelProviderList.find((m) => m.id === 'azure'); + expect(azure?.chatModels).toEqual([{ id: 'abc', deploymentName: 'abc' }]); + }); + }); + + describe('refreshModelProviderList', () => { + it('visible', async () => { + const { result } = renderHook(() => useGlobalStore()); + act(() => { + useGlobalStore.setState({ + settings: { + languageModel: { + ollama: { enabledModels: ['llava'] }, + }, + }, + }); + }); + + act(() => { + result.current.refreshModelProviderList(); + }); + + const ollamaList = result.current.modelProviderList.find((r) => r.id === 'ollama'); + // Assert that setModelProviderConfig was not called + expect(ollamaList?.chatModels.find((c) => c.id === 'llava')).toEqual({ + displayName: 'LLaVA 7B', + functionCall: false, + enabled: true, + id: 'llava', + tokens: 4000, + vision: true, + }); + }); + + it('modelProviderListForModelSelect should return only enabled providers', () => { + const { result } = renderHook(() => useGlobalStore()); + + act(() => { + useGlobalStore.setState({ + settings: { + languageModel: { + perplexity: { enabled: true }, + azure: { enabled: false }, + }, + }, + }); + }); + + act(() => { + result.current.refreshModelProviderList(); + }); + + const enabledProviders = modelProviderSelectors.modelProviderListForModelSelect( + result.current, + ); + expect(enabledProviders).toHaveLength(2); + expect(enabledProviders[1].id).toBe('perplexity'); + }); + }); }); diff --git a/src/store/global/slices/settings/actions/llm.ts b/src/store/global/slices/settings/actions/llm.ts index 0c4711e8776b9..02faa096afd1e 100644 --- a/src/store/global/slices/settings/actions/llm.ts +++ b/src/store/global/slices/settings/actions/llm.ts @@ -1,11 +1,28 @@ import useSWR, { SWRResponse } from 'swr'; import type { StateCreator } from 'zustand/vanilla'; +import { + AnthropicProviderCard, + AzureProviderCard, + BedrockProviderCard, + GoogleProviderCard, + GroqProviderCard, + MistralProviderCard, + MoonshotProviderCard, + OllamaProviderCard, + OpenAIProviderCard, + OpenRouterProviderCard, + PerplexityProviderCard, + TogetherAIProviderCard, + ZeroOneProviderCard, + ZhiPuProviderCard, +} from '@/config/modelProviders'; import { GlobalStore } from '@/store/global'; import { ChatModelCard } from '@/types/llm'; import { GlobalLLMConfig, GlobalLLMProviderKey } from '@/types/settings'; import { CustomModelCardDispatch, customModelCardsReducer } from '../reducers/customModelCard'; +import { modelProviderSelectors } from '../selectors/modelProvider'; import { settingsSelectors } from '../selectors/settings'; /** @@ -16,12 +33,18 @@ export interface LLMSettingsAction { provider: GlobalLLMProviderKey, payload: CustomModelCardDispatch, ) => Promise; + /** + * make sure the default model provider list is sync to latest state + */ + refreshDefaultModelProviderList: () => void; + refreshModelProviderList: () => void; removeEnabledModels: (provider: GlobalLLMProviderKey, model: string) => Promise; setModelProviderConfig: ( provider: T, config: Partial, ) => Promise; toggleEditingCustomModelCard: (params?: { id: string; provider: GlobalLLMProviderKey }) => void; + toggleProviderEnabled: (provider: GlobalLLMProviderKey, enabled: boolean) => Promise; useFetchProviderModelList: ( @@ -46,6 +69,76 @@ export const llmSettingsSlice: StateCreator< await get().setModelProviderConfig(provider, { customModelCards: nextState }); }, + refreshDefaultModelProviderList: () => { + /** + * Because we have several model cards sources, we need to merge the model cards + * the priority is below: + * 1 - server side model cards + * 2 - remote model cards + * 3 - default model cards + */ + + // eslint-disable-next-line unicorn/consistent-function-scoping + const mergeModels = (provider: GlobalLLMProviderKey, defaultChatModels: ChatModelCard[]) => { + // if the chat model is config in the server side, use the server side model cards + const serverChatModels = modelProviderSelectors.serverProviderModelCards(provider)(get()); + const remoteChatModels = modelProviderSelectors.remoteProviderModelCards(provider)(get()); + + return serverChatModels ?? remoteChatModels ?? defaultChatModels; + }; + + const defaultModelProviderList = [ + { + ...OpenAIProviderCard, + chatModels: mergeModels('openai', OpenAIProviderCard.chatModels), + }, + { ...AzureProviderCard, chatModels: mergeModels('azure', []) }, + { ...OllamaProviderCard, chatModels: mergeModels('ollama', OllamaProviderCard.chatModels) }, + AnthropicProviderCard, + GoogleProviderCard, + { + ...OpenRouterProviderCard, + chatModels: mergeModels('openrouter', OpenRouterProviderCard.chatModels), + }, + { + ...TogetherAIProviderCard, + chatModels: mergeModels('togetherai', TogetherAIProviderCard.chatModels), + }, + BedrockProviderCard, + PerplexityProviderCard, + MistralProviderCard, + GroqProviderCard, + MoonshotProviderCard, + ZeroOneProviderCard, + ZhiPuProviderCard, + ]; + + set({ defaultModelProviderList }, false, 'refreshDefaultModelProviderList'); + + get().refreshModelProviderList(); + }, + + refreshModelProviderList: () => { + const modelProviderList = get().defaultModelProviderList.map((list) => ({ + ...list, + chatModels: modelProviderSelectors + .getModelCardsById(list.id)(get()) + ?.map((model) => { + const models = modelProviderSelectors.getEnableModelsById(list.id)(get()); + + if (!models) return model; + + return { + ...model, + enabled: models?.some((m) => m === model.id), + }; + }), + enabled: modelProviderSelectors.isProviderEnabled(list.id as any)(get()), + })); + + set({ modelProviderList }, false, 'refreshModelProviderList'); + }, + removeEnabledModels: async (provider, model) => { const config = settingsSelectors.providerConfig(provider)(get()); @@ -60,6 +153,7 @@ export const llmSettingsSlice: StateCreator< toggleEditingCustomModelCard: (params) => { set({ editingCustomCardModel: params }, false, 'toggleEditingCustomModelCard'); }, + toggleProviderEnabled: async (provider, enabled) => { await get().setSettings({ languageModel: { [provider]: { enabled } } }); }, @@ -79,6 +173,8 @@ export const llmSettingsSlice: StateCreator< latestFetchTime: Date.now(), remoteModelCards: data, }); + + get().refreshDefaultModelProviderList(); } }, revalidateOnFocus: false, diff --git a/src/store/global/slices/settings/initialState.ts b/src/store/global/slices/settings/initialState.ts index ee7b0506e101c..db49008ad70b5 100644 --- a/src/store/global/slices/settings/initialState.ts +++ b/src/store/global/slices/settings/initialState.ts @@ -1,20 +1,26 @@ import { DeepPartial } from 'utility-types'; +import { DEFAULT_MODEL_PROVIDER_LIST } from '@/config/modelProviders'; import { DEFAULT_SETTINGS } from '@/const/settings'; +import { ModelProviderCard } from '@/types/llm'; import { GlobalServerConfig } from '@/types/serverConfig'; import { GlobalSettings } from '@/types/settings'; export interface GlobalSettingsState { avatar?: string; + defaultModelProviderList: ModelProviderCard[]; defaultSettings: GlobalSettings; editingCustomCardModel?: { id: string; provider: string } | undefined; + modelProviderList: ModelProviderCard[]; serverConfig: GlobalServerConfig; settings: DeepPartial; userId?: string; } export const initialSettingsState: GlobalSettingsState = { + defaultModelProviderList: DEFAULT_MODEL_PROVIDER_LIST, defaultSettings: DEFAULT_SETTINGS, + modelProviderList: DEFAULT_MODEL_PROVIDER_LIST, serverConfig: { telemetry: {}, }, diff --git a/src/store/global/slices/settings/selectors/modelProvider.test.ts b/src/store/global/slices/settings/selectors/modelProvider.test.ts index 48bb592beaba0..ce1074b59bcc2 100644 --- a/src/store/global/slices/settings/selectors/modelProvider.test.ts +++ b/src/store/global/slices/settings/selectors/modelProvider.test.ts @@ -7,71 +7,7 @@ import { GlobalSettingsState, initialSettingsState } from '../initialState'; import { getDefaultModeProviderById, modelProviderSelectors } from './modelProvider'; describe('modelProviderSelectors', () => { - describe('providerListWithConfig', () => { - it('visible', () => { - const s = merge(initialSettingsState, { - settings: { - languageModel: { - ollama: { - enabledModels: ['llava'], - }, - }, - }, - } as GlobalSettingsState) as unknown as GlobalStore; - - const ollamaList = modelProviderSelectors.modelProviderList(s).find((r) => r.id === 'ollama'); - - expect(ollamaList?.chatModels.find((c) => c.id === 'llava')).toEqual({ - displayName: 'LLaVA 7B', - functionCall: false, - enabled: true, - id: 'llava', - tokens: 4000, - vision: true, - }); - }); - it('with user custom models', () => { - const s = merge(initialSettingsState, { - settings: { - languageModel: { - perplexity: { - customModelCards: [{ id: 'sonar-online', displayName: 'Sonar Online' }], - }, - }, - }, - } as GlobalSettingsState) as unknown as GlobalStore; - - const providerList = modelProviderSelectors - .modelProviderList(s) - .find((r) => r.id === 'perplexity'); - - expect(providerList?.chatModels.find((c) => c.id === 'sonar-online')).toEqual({ - id: 'sonar-online', - displayName: 'Sonar Online', - enabled: false, - isCustom: true, - }); - }); - }); - - describe('providerListForModelSelect', () => { - it('should return only enabled providers', () => { - const s = merge(initialSettingsState, { - settings: { - languageModel: { - perplexity: { enabled: true }, - azure: { enabled: false }, - }, - }, - } as GlobalSettingsState) as unknown as GlobalStore; - - const enabledProviders = modelProviderSelectors.modelProviderListForModelSelect(s); - expect(enabledProviders).toHaveLength(2); - expect(enabledProviders[1].id).toBe('perplexity'); - }); - }); - - describe('providerCard', () => { + describe('getDefaultModeProviderById', () => { it('should return the correct ModelProviderCard when provider ID matches', () => { const s = merge(initialSettingsState, {}) as unknown as GlobalStore; diff --git a/src/store/global/slices/settings/selectors/modelProvider.ts b/src/store/global/slices/settings/selectors/modelProvider.ts index 90fafaae6296c..f70a7cd8ae2a9 100644 --- a/src/store/global/slices/settings/selectors/modelProvider.ts +++ b/src/store/global/slices/settings/selectors/modelProvider.ts @@ -1,22 +1,6 @@ import { uniqBy } from 'lodash-es'; -import { - AnthropicProviderCard, - AzureProviderCard, - BedrockProviderCard, - GoogleProviderCard, - GroqProviderCard, - MistralProviderCard, - MoonshotProviderCard, - OllamaProviderCard, - OpenAIProviderCard, - OpenRouterProviderCard, - PerplexityProviderCard, - TogetherAIProviderCard, - ZeroOneProviderCard, - ZhiPuProviderCard, - filterEnabledModels, -} from '@/config/modelProviders'; +import { filterEnabledModels } from '@/config/modelProviders'; import { ChatModelCard, ModelProviderCard } from '@/types/llm'; import { ServerModelProviderConfig } from '@/types/serverConfig'; import { GlobalLLMProviderKey } from '@/types/settings'; @@ -59,49 +43,8 @@ const isProviderEnabled = (provider: GlobalLLMProviderKey) => (s: GlobalStore) = /** * define all the model list of providers */ -const defaultModelProviderList = (s: GlobalStore): ModelProviderCard[] => { - /** - * Because we have several model cards sources, we need to merge the model cards - * the priority is below: - * 1 - server side model cards - * 2 - remote model cards - * 3 - default model cards - */ - - const mergeModels = (provider: GlobalLLMProviderKey, defaultChatModels: ChatModelCard[]) => { - // if the chat model is config in the server side, use the server side model cards - const serverChatModels = serverProviderModelCards(provider)(s); - const remoteChatModels = remoteProviderModelCards(provider)(s); - - return serverChatModels ?? remoteChatModels ?? defaultChatModels; - }; - - return [ - { - ...OpenAIProviderCard, - chatModels: mergeModels('openai', OpenAIProviderCard.chatModels), - }, - { ...AzureProviderCard, chatModels: mergeModels('azure', []) }, - { ...OllamaProviderCard, chatModels: mergeModels('ollama', OllamaProviderCard.chatModels) }, - AnthropicProviderCard, - GoogleProviderCard, - { - ...OpenRouterProviderCard, - chatModels: mergeModels('openrouter', OpenRouterProviderCard.chatModels), - }, - { - ...TogetherAIProviderCard, - chatModels: mergeModels('togetherai', TogetherAIProviderCard.chatModels), - }, - BedrockProviderCard, - PerplexityProviderCard, - MistralProviderCard, - GroqProviderCard, - MoonshotProviderCard, - ZeroOneProviderCard, - ZhiPuProviderCard, - ]; -}; +const defaultModelProviderList = (s: GlobalStore): ModelProviderCard[] => + s.defaultModelProviderList; export const getDefaultModeProviderById = (provider: string) => (s: GlobalStore) => defaultModelProviderList(s).find((s) => s.id === provider); @@ -146,21 +89,7 @@ const getEnableModelsById = (provider: string) => (s: GlobalStore) => { return getProviderConfigById(provider)(s)?.enabledModels?.filter(Boolean); }; -const modelProviderList = (s: GlobalStore): ModelProviderCard[] => - defaultModelProviderList(s).map((list) => ({ - ...list, - chatModels: getModelCardsById(list.id)(s)?.map((model) => { - const models = getEnableModelsById(list.id)(s); - - if (!models) return model; - - return { - ...model, - enabled: models?.some((m) => m === model.id), - }; - }), - enabled: isProviderEnabled(list.id as any)(s), - })); +const modelProviderList = (s: GlobalStore): ModelProviderCard[] => s.modelProviderList; const modelProviderListForModelSelect = (s: GlobalStore): ModelProviderCard[] => modelProviderList(s) @@ -196,22 +125,26 @@ const modelMaxToken = (id: string) => (s: GlobalStore) => getModelCardById(id)(s export const modelProviderSelectors = { defaultModelProviderList, - getDefaultEnabledModelsById, getDefaultModelCardById, getEnableModelsById, getModelCardById, - getModelCardsById, + getModelCardsById, isModelEnabledFiles, isModelEnabledFunctionCall, isModelEnabledUpload, isModelEnabledVision, isModelHasMaxToken, - modelMaxToken, + isProviderEnabled, + modelMaxToken, modelProviderList, + modelProviderListForModelSelect, + + remoteProviderModelCards, + serverProviderModelCards, };