diff --git a/src/core/domain/repositories/ProviderRepository.ts b/src/core/domain/repositories/ProviderRepository.ts index 6d36576..cee083b 100644 --- a/src/core/domain/repositories/ProviderRepository.ts +++ b/src/core/domain/repositories/ProviderRepository.ts @@ -22,6 +22,24 @@ const findAll = async () => { }); }; +// 查找所有提供商(包括禁用的) +const findAllIncludeDisabled = async () => { + return await prisma.provider.findMany({ + select: { + id: true, + name: true, + type: true, + api_key: true, + base_url: true, + suffix: true, + status: true, + createdAt: true, + updatedAt: true, + }, + orderBy: [{ createdAt: "desc" }], + }); +}; + // 根据ID查找提供商 const findById = async (id: number) => { return await prisma.provider.findUnique({ @@ -83,6 +101,7 @@ const updateStatus = async (id: number, status: number) => { export default { findAll, + findAllIncludeDisabled, findById, create, update, diff --git a/src/main/ipc/__tests__/handlers.test.ts b/src/main/ipc/__tests__/handlers.test.ts index 16715ab..f749673 100644 --- a/src/main/ipc/__tests__/handlers.test.ts +++ b/src/main/ipc/__tests__/handlers.test.ts @@ -3,6 +3,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest' // Mock all dependencies before imports const mockProviderDal = { findAll: vi.fn(), + findAllIncludeDisabled: vi.fn(), findById: vi.fn(), create: vi.fn(), update: vi.fn(), @@ -191,7 +192,7 @@ describe('IPC Handlers', () => { { id: 1, name: 'OpenAI', type: 'openai' }, { id: 2, name: 'Anthropic', type: 'anthropic' } ] - mockProviderDal.findAll.mockResolvedValue(mockProviders) + mockProviderDal.findAllIncludeDisabled.mockResolvedValue(mockProviders) const handler = handlers.get('provider:getAll') const result = await handler!({}, {}) @@ -200,11 +201,11 @@ describe('IPC Handlers', () => { success: true, data: mockProviders }) - expect(mockProviderDal.findAll).toHaveBeenCalled() + expect(mockProviderDal.findAllIncludeDisabled).toHaveBeenCalled() }) it('should handle errors', async () => { - mockProviderDal.findAll.mockRejectedValue(new Error('Database error')) + mockProviderDal.findAllIncludeDisabled.mockRejectedValue(new Error('Database error')) const handler = handlers.get('provider:getAll') const result = await handler!({}, {}) diff --git a/src/main/ipc/handlers/__tests__/provider.handler.test.ts b/src/main/ipc/handlers/__tests__/provider.handler.test.ts index ece3fc7..e9c054b 100644 --- a/src/main/ipc/handlers/__tests__/provider.handler.test.ts +++ b/src/main/ipc/handlers/__tests__/provider.handler.test.ts @@ -3,6 +3,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest' // Mock dependencies const mockProviderRepository = { findAll: vi.fn(), + findAllIncludeDisabled: vi.fn(), findById: vi.fn(), create: vi.fn(), update: vi.fn(), @@ -57,7 +58,7 @@ describe('Provider Handler', () => { { id: 1, name: 'OpenAI', type: 'openai' }, { id: 2, name: 'Anthropic', type: 'anthropic' } ] - mockProviderRepository.findAll.mockResolvedValue(mockProviders) + mockProviderRepository.findAllIncludeDisabled.mockResolvedValue(mockProviders) const handler = handlers.get('provider:getAll') const result = await handler!({}) @@ -66,11 +67,11 @@ describe('Provider Handler', () => { success: true, data: mockProviders }) - expect(mockProviderRepository.findAll).toHaveBeenCalled() + expect(mockProviderRepository.findAllIncludeDisabled).toHaveBeenCalled() }) it('should handle errors', async () => { - mockProviderRepository.findAll.mockRejectedValue(new Error('Database error')) + mockProviderRepository.findAllIncludeDisabled.mockRejectedValue(new Error('Database error')) const handler = handlers.get('provider:getAll') const result = await handler!({}) diff --git a/src/main/ipc/handlers/provider.handler.ts b/src/main/ipc/handlers/provider.handler.ts index 908af9c..7d8db1e 100644 --- a/src/main/ipc/handlers/provider.handler.ts +++ b/src/main/ipc/handlers/provider.handler.ts @@ -8,11 +8,11 @@ import type { IpcResponse } from "../../../shared/ipc/responses.js"; */ export function registerProviderHandlers() { /** - * Get all providers + * Get all providers (including disabled) */ ipcMain.handle(IPC_CHANNELS.PROVIDER.GET_ALL, async (): Promise => { try { - const providers = await providerRepository.findAll(); + const providers = await providerRepository.findAllIncludeDisabled(); return { success: true, data: providers }; } catch (error: any) { console.error("[IPC] provider:getAll error:", error); diff --git a/src/renderer/components/Provider.tsx b/src/renderer/components/Provider.tsx index cd1f31a..5c95724 100644 --- a/src/renderer/components/Provider.tsx +++ b/src/renderer/components/Provider.tsx @@ -15,7 +15,7 @@ import { List, App, } from "antd"; -import React, { useEffect, useState, useCallback, useRef } from "react"; +import React, { useCallback, useEffect, useState } from "react"; import { useTranslation } from "react-i18next"; interface ProviderProps { @@ -51,9 +51,6 @@ const Provider: React.FC = ({ // 添加测试状态 const [testingModelId, setTestingModelId] = useState(""); - // 使用 ref 存储 fetchModels 函数,避免初始化顺序问题 - const fetchModelsRef = useRef<() => Promise>(() => Promise.resolve()); - useEffect(() => { // 如果有providerId,则获取该服务商的详细信息 if (providerId) { @@ -113,13 +110,10 @@ const Provider: React.FC = ({ }; fetchProviderDetails(); - - // 获取该服务商下的所有模型 - fetchModelsRef.current(); } }, [providerId, t]); - // 获取服务商下的模型列表 + // 获取该服务商下的所有模型 const fetchModels = useCallback(async () => { if (!providerId) return; @@ -140,10 +134,12 @@ const Provider: React.FC = ({ } }, [providerId, message, t]); - // 更新 ref + // 初始加载时获取模型列表 useEffect(() => { - fetchModelsRef.current = fetchModels; - }, [fetchModels]); + if (providerId) { + fetchModels(); + } + }, [providerId, fetchModels]); // 删除模型 const deleteModel = async (modelId: string) => { @@ -158,7 +154,7 @@ const Provider: React.FC = ({ message.success(t('messages.delete_model_success')); // 刷新模型列表 - fetchModelsRef.current(); + fetchModels(); } catch (error) { console.error("Failed to delete model:", error); message.error( @@ -191,7 +187,7 @@ const Provider: React.FC = ({ setNewModelId(""); message.success(t('messages.add_model_success')); // 刷新模型列表 - fetchModelsRef.current(); + fetchModels(); } catch (error) { console.error("Failed to add model:", error); message.error(