Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/core/domain/repositories/ProviderRepository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,24 @@ const findAll = async () => {
});
};

// 查找所有提供商(包括禁用的)
const findAllIncludeDisabled = async () => {
return await prisma.provider.findMany({
select: {
id: true,
name: true,
type: true,
api_key: true,
base_url: true,
suffix: true,
status: true,
createdAt: true,
updatedAt: true,
},
orderBy: [{ createdAt: "desc" }],
});
};

// 根据ID查找提供商
const findById = async (id: number) => {
return await prisma.provider.findUnique({
Expand Down Expand Up @@ -83,6 +101,7 @@ const updateStatus = async (id: number, status: number) => {

export default {
findAll,
findAllIncludeDisabled,
findById,
create,
update,
Expand Down
7 changes: 4 additions & 3 deletions src/main/ipc/__tests__/handlers.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'
// Mock all dependencies before imports
const mockProviderDal = {
findAll: vi.fn(),
findAllIncludeDisabled: vi.fn(),
findById: vi.fn(),
create: vi.fn(),
update: vi.fn(),
Expand Down Expand Up @@ -191,7 +192,7 @@ describe('IPC Handlers', () => {
{ id: 1, name: 'OpenAI', type: 'openai' },
{ id: 2, name: 'Anthropic', type: 'anthropic' }
]
mockProviderDal.findAll.mockResolvedValue(mockProviders)
mockProviderDal.findAllIncludeDisabled.mockResolvedValue(mockProviders)

const handler = handlers.get('provider:getAll')
const result = await handler!({}, {})
Expand All @@ -200,11 +201,11 @@ describe('IPC Handlers', () => {
success: true,
data: mockProviders
})
expect(mockProviderDal.findAll).toHaveBeenCalled()
expect(mockProviderDal.findAllIncludeDisabled).toHaveBeenCalled()
})

it('should handle errors', async () => {
mockProviderDal.findAll.mockRejectedValue(new Error('Database error'))
mockProviderDal.findAllIncludeDisabled.mockRejectedValue(new Error('Database error'))

const handler = handlers.get('provider:getAll')
const result = await handler!({}, {})
Expand Down
7 changes: 4 additions & 3 deletions src/main/ipc/handlers/__tests__/provider.handler.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'
// Mock dependencies
const mockProviderRepository = {
findAll: vi.fn(),
findAllIncludeDisabled: vi.fn(),
findById: vi.fn(),
create: vi.fn(),
update: vi.fn(),
Expand Down Expand Up @@ -57,7 +58,7 @@ describe('Provider Handler', () => {
{ id: 1, name: 'OpenAI', type: 'openai' },
{ id: 2, name: 'Anthropic', type: 'anthropic' }
]
mockProviderRepository.findAll.mockResolvedValue(mockProviders)
mockProviderRepository.findAllIncludeDisabled.mockResolvedValue(mockProviders)

const handler = handlers.get('provider:getAll')
const result = await handler!({})
Expand All @@ -66,11 +67,11 @@ describe('Provider Handler', () => {
success: true,
data: mockProviders
})
expect(mockProviderRepository.findAll).toHaveBeenCalled()
expect(mockProviderRepository.findAllIncludeDisabled).toHaveBeenCalled()
})

it('should handle errors', async () => {
mockProviderRepository.findAll.mockRejectedValue(new Error('Database error'))
mockProviderRepository.findAllIncludeDisabled.mockRejectedValue(new Error('Database error'))

const handler = handlers.get('provider:getAll')
const result = await handler!({})
Expand Down
4 changes: 2 additions & 2 deletions src/main/ipc/handlers/provider.handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ import type { IpcResponse } from "../../../shared/ipc/responses.js";
*/
export function registerProviderHandlers() {
/**
* Get all providers
* Get all providers (including disabled)
*/
ipcMain.handle(IPC_CHANNELS.PROVIDER.GET_ALL, async (): Promise<IpcResponse> => {
try {
const providers = await providerRepository.findAll();
const providers = await providerRepository.findAllIncludeDisabled();
return { success: true, data: providers };
} catch (error: any) {
console.error("[IPC] provider:getAll error:", error);
Expand Down
22 changes: 9 additions & 13 deletions src/renderer/components/Provider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import {
List,
App,
} from "antd";
import React, { useEffect, useState, useCallback, useRef } from "react";
import React, { useCallback, useEffect, useState } from "react";
import { useTranslation } from "react-i18next";

interface ProviderProps {
Expand Down Expand Up @@ -51,9 +51,6 @@ const Provider: React.FC<ProviderProps> = ({
// 添加测试状态
const [testingModelId, setTestingModelId] = useState<string>("");

// 使用 ref 存储 fetchModels 函数,避免初始化顺序问题
const fetchModelsRef = useRef<() => Promise<void>>(() => Promise.resolve());

useEffect(() => {
// 如果有providerId,则获取该服务商的详细信息
if (providerId) {
Expand Down Expand Up @@ -113,13 +110,10 @@ const Provider: React.FC<ProviderProps> = ({
};

fetchProviderDetails();

// 获取该服务商下的所有模型
fetchModelsRef.current();
}
}, [providerId, t]);

// 获取服务商下的模型列表
// 获取该服务商下的所有模型
const fetchModels = useCallback(async () => {
if (!providerId) return;

Expand All @@ -140,10 +134,12 @@ const Provider: React.FC<ProviderProps> = ({
}
}, [providerId, message, t]);

// 更新 ref
// 初始加载时获取模型列表
useEffect(() => {
fetchModelsRef.current = fetchModels;
}, [fetchModels]);
if (providerId) {
fetchModels();
}
}, [providerId, fetchModels]);

// 删除模型
const deleteModel = async (modelId: string) => {
Expand All @@ -158,7 +154,7 @@ const Provider: React.FC<ProviderProps> = ({

message.success(t('messages.delete_model_success'));
// 刷新模型列表
fetchModelsRef.current();
fetchModels();
} catch (error) {
console.error("Failed to delete model:", error);
message.error(
Expand Down Expand Up @@ -191,7 +187,7 @@ const Provider: React.FC<ProviderProps> = ({
setNewModelId("");
message.success(t('messages.add_model_success'));
// 刷新模型列表
fetchModelsRef.current();
fetchModels();
} catch (error) {
console.error("Failed to add model:", error);
message.error(
Expand Down