From 2852fdaed49ab997ba8d5ba6db1e18bfbc32d990 Mon Sep 17 00:00:00 2001 From: mxyhi Date: Fri, 30 Jan 2026 13:53:59 +0800 Subject: [PATCH 01/10] feat(agents): centralize agent configs to switch AI tool setups easily MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Allow managing per-tool profiles in one place so switching tools doesn’t require re-setup. Streamline the setup view and avoid redundant model-mapping hints. --- messages/en.json | 25 ++- messages/zh.json | 25 ++- src/features/agents/AgentCard.tsx | 108 +++++++++++ src/features/agents/AgentDeleteDialog.tsx | 42 ++++ src/features/agents/AgentEditorDialog.tsx | 101 ++++++++++ src/features/agents/AgentsPanel.tsx | 124 ++++++++++++ src/features/agents/agent-store.ts | 169 ++++++++++++++++ src/features/agents/constants.ts | 29 +++ src/features/agents/index.ts | 7 + src/features/agents/types.ts | 41 ++++ src/features/config/AppView.tsx | 98 ++++++++-- .../config/cards/client-setup-card.tsx | 108 +++-------- src/features/config/cards/client-setup-ui.tsx | 183 +++--------------- .../dashboard/RecentRequestsTable.tsx | 4 +- 14 files changed, 801 insertions(+), 263 deletions(-) create mode 100644 src/features/agents/AgentCard.tsx create mode 100644 src/features/agents/AgentDeleteDialog.tsx create mode 100644 src/features/agents/AgentEditorDialog.tsx create mode 100644 src/features/agents/AgentsPanel.tsx create mode 100644 src/features/agents/agent-store.ts create mode 100644 src/features/agents/constants.ts create mode 100644 src/features/agents/index.ts create mode 100644 src/features/agents/types.ts diff --git a/messages/en.json b/messages/en.json index 9464669..e3cf664 100644 --- a/messages/en.json +++ b/messages/en.json @@ -432,5 +432,28 @@ "providers_status_summary": "Active {active} · Expired {expired}", "providers_quota_loading": "Loading quota...", "providers_account_delete_title": "Delete account?", - "providers_account_delete_description": "This will delete account {label}." 
+ "providers_account_delete_description": "This will delete account {label}.", + "common_edit": "Edit", + "common_add": "Add", + "agents_title": "Agents", + "agents_desc": "Manage agent configurations for different AI tools.", + "agents_active": "Active", + "agents_switch": "Switch to this agent", + "agents_add": "Add Agent", + "agents_empty": "No agents configured yet.", + "agents_empty_hint": "Add an agent to get started.", + "agents_delete_title": "Delete agent?", + "agents_delete_description": "This will delete agent \"{name}\".", + "agents_editor_title_add": "Add Agent", + "agents_editor_title_edit": "Edit Agent", + "agents_editor_name_label": "Name", + "agents_editor_name_placeholder": "My Claude Config", + "agents_editor_type_label": "Tool Type", + "agents_editor_type_placeholder": "Select tool type", + "agents_tool_claude": "Claude Code", + "agents_tool_claude_desc": "Anthropic's Claude Code CLI tool", + "agents_tool_codex": "Codex CLI", + "agents_tool_codex_desc": "OpenAI's Codex CLI tool", + "agents_tool_opencode": "OpenCode", + "agents_tool_opencode_desc": "Open source AI coding assistant" } diff --git a/messages/zh.json b/messages/zh.json index 231e833..9d9c9da 100644 --- a/messages/zh.json +++ b/messages/zh.json @@ -432,5 +432,28 @@ "providers_status_summary": "正常 {active} · 过期 {expired}", "providers_quota_loading": "正在加载额度…", "providers_account_delete_title": "删除账户?", - "providers_account_delete_description": "将删除账户 {label}。" + "providers_account_delete_description": "将删除账户 {label}。", + "common_edit": "编辑", + "common_add": "添加", + "agents_title": "Agents", + "agents_desc": "管理不同 AI 工具的 Agent 配置。", + "agents_active": "激活", + "agents_switch": "切换到此 Agent", + "agents_add": "添加 Agent", + "agents_empty": "暂无 Agent 配置。", + "agents_empty_hint": "添加一个 Agent 开始使用。", + "agents_delete_title": "删除 Agent?", + "agents_delete_description": "将删除 Agent「{name}」。", + "agents_editor_title_add": "添加 Agent", + "agents_editor_title_edit": "编辑 Agent", + "agents_editor_name_label": "名称", + "agents_editor_name_placeholder": "我的 Claude 配置", + "agents_editor_type_label": "工具类型", + "agents_editor_type_placeholder": "选择工具类型", + "agents_tool_claude": "Claude Code", + "agents_tool_claude_desc": "Anthropic 的 Claude Code CLI 工具", + "agents_tool_codex": "Codex CLI", + "agents_tool_codex_desc": "OpenAI 的 Codex CLI 工具", + "agents_tool_opencode": "OpenCode", + "agents_tool_opencode_desc": "开源 AI 编程助手" } diff --git a/src/features/agents/AgentCard.tsx b/src/features/agents/AgentCard.tsx new file mode 100644 index 0000000..7ee0e9e --- /dev/null +++ b/src/features/agents/AgentCard.tsx @@ -0,0 +1,108 @@ +import { MoreHorizontal, Power, Pencil, Trash2, GripVertical } from "lucide-react"; + +import { Badge } from "@/components/ui/badge"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent } from "@/components/ui/card"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuSeparator, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; +import { m } from "@/paraglide/messages.js"; + +import type { AgentConfig } from "./types"; +import { AGENT_TOOL_META } from "./constants"; + +type AgentCardProps = { + agent: AgentConfig; + onSwitch: (id: string) => void; + onEdit: (agent: AgentConfig) => void; + onDelete: (id: string) => void; + dragHandleProps?: React.HTMLAttributes; +}; + +export function AgentCard({ + agent, + onSwitch, + onEdit, + onDelete, + dragHandleProps, +}: AgentCardProps) { + const toolMeta = AGENT_TOOL_META[agent.type]; + + return ( + + + {/* 
Drag handle */}
+ {dragHandleProps && (
+
+ )}
+
+ {/* Main content */}
+
+
+ {agent.name} + {agent.isActive && ( + + {m.agents_active()} + + )} +
+

{toolMeta.label()}

+
+
+ {/* Action buttons */}
+
+ {/* Switch-active button */}
+ {!agent.isActive && (
+
+ )}
+
+ {/* More actions */}
+
+
+
+
+
+ onEdit(agent)}>
+
+ {m.common_edit()}
+
+
+ onDelete(agent.id)}
+ >
+
+ {m.common_delete()}
+
+
+
+
+
+ ); +} diff --git a/src/features/agents/AgentDeleteDialog.tsx b/src/features/agents/AgentDeleteDialog.tsx new file mode 100644 index 0000000..5b32e23 --- /dev/null +++ b/src/features/agents/AgentDeleteDialog.tsx @@ -0,0 +1,42 @@ +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { m } from "@/paraglide/messages.js"; + +type AgentDeleteDialogProps = { + open: boolean; + onOpenChange: (open: boolean) => void; + agentName: string; + onConfirm: () => void; +}; + +export function AgentDeleteDialog({ + open, + onOpenChange, + agentName, + onConfirm, +}: AgentDeleteDialogProps) { + return ( + + + + {m.agents_delete_title()} + + {m.agents_delete_description({ name: agentName })} + + + + {m.common_cancel()} + {m.common_delete()} + + + + ); +} diff --git a/src/features/agents/AgentEditorDialog.tsx b/src/features/agents/AgentEditorDialog.tsx new file mode 100644 index 0000000..e5301ea --- /dev/null +++ b/src/features/agents/AgentEditorDialog.tsx @@ -0,0 +1,101 @@ +import { useCallback, useEffect, useState } from "react"; + +import { Button } from "@/components/ui/button"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { m } from "@/paraglide/messages.js"; + +import type { AgentConfig, AgentConfigForm, AgentToolType } from "./types"; +import { AGENT_TOOL_META } from "./constants"; + +type AgentEditorDialogProps = { + open: boolean; + onOpenChange: (open: boolean) => void; + /** 编辑模式时传入现有配置,添加模式时为 null */ + editingAgent: AgentConfig | null; + /** 添加模式时的默认工具类型 */ + defaultType: AgentToolType; + onSave: (form: AgentConfigForm) => void; +}; + +export function AgentEditorDialog({ + open, + onOpenChange, + editingAgent, + defaultType, + onSave, +}: AgentEditorDialogProps) { + const [name, setName] = useState(""); + + const isEditing = editingAgent !== null; + const title = isEditing ? m.agents_editor_title_edit() : m.agents_editor_title_add(); + const currentType = isEditing ? editingAgent.type : defaultType; + const toolMeta = AGENT_TOOL_META[currentType]; + + // 打开对话框时初始化表单 + useEffect(() => { + if (open) { + if (editingAgent) { + setName(editingAgent.name); + } else { + setName(""); + } + } + }, [open, editingAgent]); + + const handleNameChange = useCallback((e: React.ChangeEvent) => { + setName(e.target.value); + }, []); + + const handleSave = useCallback(() => { + if (!name.trim()) return; + onSave({ name: name.trim(), type: currentType }); + onOpenChange(false); + }, [name, currentType, onSave, onOpenChange]); + + const canSave = name.trim().length > 0; + + return ( + + + + {title} + + {toolMeta.label()} - {toolMeta.description()} + + + +
+ {/* Name */}
+
+ + +
+
+ + + + + +
+
+ ); +} diff --git a/src/features/agents/AgentsPanel.tsx b/src/features/agents/AgentsPanel.tsx new file mode 100644 index 0000000..a1a5a0c --- /dev/null +++ b/src/features/agents/AgentsPanel.tsx @@ -0,0 +1,124 @@ +import { useCallback, useEffect, useState } from "react"; +import { Bot } from "lucide-react"; + +import { m } from "@/paraglide/messages.js"; + +import type { AgentConfig, AgentConfigForm, AgentToolType } from "./types"; +import { useAgentStore } from "./agent-store"; +import { AgentCard } from "./AgentCard"; +import { AgentEditorDialog } from "./AgentEditorDialog"; +import { AgentDeleteDialog } from "./AgentDeleteDialog"; + +type AgentsPanelProps = { + /** 当前选中的工具类型 */ + selectedTool: AgentToolType; + /** 添加对话框触发器(每次变化时打开添加对话框) */ + addTrigger: number; +}; + +export function AgentsPanel({ selectedTool, addTrigger }: AgentsPanelProps) { + const store = useAgentStore(); + + // 编辑对话框状态 + const [editorOpen, setEditorOpen] = useState(false); + const [editingAgent, setEditingAgent] = useState(null); + + // 删除对话框状态 + const [deleteOpen, setDeleteOpen] = useState(false); + const [deletingAgent, setDeletingAgent] = useState(null); + + // 根据 selectedTool 过滤 agents + const filteredAgents = store.agents.filter((agent) => agent.type === selectedTool); + + // 监听 addTrigger 变化,打开添加对话框 + useEffect(() => { + if (addTrigger > 0) { + setEditingAgent(null); + setEditorOpen(true); + } + }, [addTrigger]); + + // 打开编辑对话框 + const handleEdit = useCallback((agent: AgentConfig) => { + setEditingAgent(agent); + setEditorOpen(true); + }, []); + + // 保存(添加或更新) + const handleSave = useCallback( + (form: AgentConfigForm) => { + if (editingAgent) { + store.updateAgent(editingAgent.id, form); + } else { + // 添加时使用当前选中的工具类型 + store.addAgent({ ...form, type: selectedTool }); + } + }, + [editingAgent, store, selectedTool] + ); + + // 打开删除确认对话框 + const handleDeleteClick = useCallback((id: string) => { + const agent = store.agents.find((a) => a.id === id); + if (agent) { + setDeletingAgent(agent); + setDeleteOpen(true); + } + }, [store.agents]); + + // 确认删除 + const handleDeleteConfirm = useCallback(() => { + if (deletingAgent) { + store.deleteAgent(deletingAgent.id); + setDeleteOpen(false); + setDeletingAgent(null); + } + }, [deletingAgent, store]); + + const isEmpty = filteredAgents.length === 0; + + return ( +
+ {isEmpty ? (
+ // Empty state
+
+
+ +
+

{m.agents_empty()}

+

{m.agents_empty_hint()}

+
+ ) : (
+ // Agent list
+
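+ {/* Cards follow the store's sortIndex ordering */}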
+ {filteredAgents.map((agent) => ( + + ))} +
+ )} + + {/* 编辑对话框 */} + + + {/* 删除确认对话框 */} + +
+ ); +} diff --git a/src/features/agents/agent-store.ts b/src/features/agents/agent-store.ts new file mode 100644 index 0000000..0431bf2 --- /dev/null +++ b/src/features/agents/agent-store.ts @@ -0,0 +1,169 @@ +import { useCallback, useMemo, useState } from "react"; + +import type { AgentConfig, AgentConfigForm, AgentToolType } from "./types"; + +/** + * 生成唯一 ID + */ +function generateId() { + return `agent_${Date.now()}_${Math.random().toString(36).slice(2, 9)}`; +} + +/** + * 创建新的 Agent 配置 + */ +function createAgentConfig(form: AgentConfigForm, sortIndex: number): AgentConfig { + const now = new Date().toISOString(); + return { + id: generateId(), + name: form.name, + type: form.type, + isActive: false, + sortIndex, + createdAt: now, + updatedAt: now, + }; +} + +// Mock 初始数据 +const INITIAL_AGENTS: AgentConfig[] = [ + { + id: "agent_default_claude", + name: "Claude Code (Default)", + type: "claude", + isActive: true, + sortIndex: 0, + createdAt: "2025-01-01T00:00:00.000Z", + updatedAt: "2025-01-01T00:00:00.000Z", + }, + { + id: "agent_default_codex", + name: "Codex CLI", + type: "codex", + isActive: false, + sortIndex: 1, + createdAt: "2025-01-01T00:00:00.000Z", + updatedAt: "2025-01-01T00:00:00.000Z", + }, +]; + +/** + * Agent 配置状态管理 Hook + */ +export function useAgentStore() { + const [agents, setAgents] = useState(INITIAL_AGENTS); + + // 按 sortIndex 排序的 agents + const sortedAgents = useMemo( + () => [...agents].sort((a, b) => a.sortIndex - b.sortIndex), + [agents] + ); + + // 当前激活的 agent + const activeAgent = useMemo( + () => agents.find((agent) => agent.isActive) ?? null, + [agents] + ); + + // 按类型分组 + const agentsByType = useMemo(() => { + const grouped: Record = { + claude: [], + codex: [], + opencode: [], + }; + for (const agent of sortedAgents) { + grouped[agent.type].push(agent); + } + return grouped; + }, [sortedAgents]); + + // 添加 agent + const addAgent = useCallback((form: AgentConfigForm) => { + setAgents((prev) => { + const maxSortIndex = prev.reduce((max, a) => Math.max(max, a.sortIndex), -1); + const newAgent = createAgentConfig(form, maxSortIndex + 1); + // 如果是第一个该类型的 agent,自动激活 + const hasActiveOfType = prev.some((a) => a.type === form.type && a.isActive); + if (!hasActiveOfType) { + newAgent.isActive = true; + } + return [...prev, newAgent]; + }); + }, []); + + // 更新 agent + const updateAgent = useCallback((id: string, form: AgentConfigForm) => { + setAgents((prev) => + prev.map((agent) => + agent.id === id + ? { ...agent, name: form.name, type: form.type, updatedAt: new Date().toISOString() } + : agent + ) + ); + }, []); + + // 删除 agent + const deleteAgent = useCallback((id: string) => { + setAgents((prev) => { + const target = prev.find((a) => a.id === id); + if (!target) return prev; + + const remaining = prev.filter((a) => a.id !== id); + + // 如果删除的是激活的 agent,激活同类型的第一个 + if (target.isActive) { + const sameType = remaining.filter((a) => a.type === target.type); + if (sameType.length > 0) { + const firstOfType = sameType.sort((a, b) => a.sortIndex - b.sortIndex)[0]; + return remaining.map((a) => + a.id === firstOfType.id ? 
{ ...a, isActive: true } : a
+ );
+ }
+ }
+
+ return remaining;
+ });
+ }, []);
+
+ // Switch the active agent (only one active per type)
+ const switchAgent = useCallback((id: string) => {
+ setAgents((prev) => {
+ const target = prev.find((a) => a.id === id);
+ if (!target || target.isActive) return prev;
+
+ return prev.map((agent) => {
+ // Deactivate other agents of the same type
+ if (agent.type === target.type) {
+ return { ...agent, isActive: agent.id === id };
+ }
+ return agent;
+ });
+ });
+ }, []);
+
+ // Update ordering
+ const reorderAgents = useCallback((reorderedIds: string[]) => {
+ setAgents((prev) => {
+ const idToAgent = new Map(prev.map((a) => [a.id, a]));
+ return reorderedIds.map((id, index) => {
+ const agent = idToAgent.get(id);
+ if (!agent) throw new Error(`Agent not found: ${id}`);
+ return { ...agent, sortIndex: index };
+ });
+ });
+ }, []);
+
+ return {
+ agents: sortedAgents,
+ activeAgent,
+ agentsByType,
+ addAgent,
+ updateAgent,
+ deleteAgent,
+ switchAgent,
+ reorderAgents,
+ };
+}
+
+export type AgentStore = ReturnType<typeof useAgentStore>;
diff --git a/src/features/agents/constants.ts b/src/features/agents/constants.ts
new file mode 100644
index 0000000..392fe22
--- /dev/null
+++ b/src/features/agents/constants.ts
@@ -0,0 +1,29 @@
+import { m } from "@/paraglide/messages.js";
+
+import type { AgentToolMeta, AgentToolType } from "./types";
+
+/**
+ * Agent tool metadata
+ */
+export const AGENT_TOOL_META: Record<AgentToolType, AgentToolMeta> = {
+ claude: {
+ type: "claude",
+ label: () => m.agents_tool_claude(),
+ description: () => m.agents_tool_claude_desc(),
+ },
+ codex: {
+ type: "codex",
+ label: () => m.agents_tool_codex(),
+ description: () => m.agents_tool_codex_desc(),
+ },
+ opencode: {
+ type: "opencode",
+ label: () => m.agents_tool_opencode(),
+ description: () => m.agents_tool_opencode_desc(),
+ },
+};
+
+/**
+ * List of agent tool types (for select dropdowns)
+ */
+export const AGENT_TOOL_OPTIONS: AgentToolType[] = ["claude", "codex", "opencode"];
diff --git a/src/features/agents/index.ts b/src/features/agents/index.ts
new file mode 100644
index 0000000..bb634f8
--- /dev/null
+++ b/src/features/agents/index.ts
@@ -0,0 +1,7 @@
+export { AgentsPanel } from "./AgentsPanel";
+export { AgentCard } from "./AgentCard";
+export { AgentEditorDialog } from "./AgentEditorDialog";
+export { AgentDeleteDialog } from "./AgentDeleteDialog";
+export { useAgentStore } from "./agent-store";
+export { AGENT_TOOL_META, AGENT_TOOL_OPTIONS } from "./constants";
+export type { AgentConfig, AgentConfigForm, AgentToolType, AgentToolMeta } from "./types";
diff --git a/src/features/agents/types.ts b/src/features/agents/types.ts
new file mode 100644
index 0000000..73b9eef
--- /dev/null
+++ b/src/features/agents/types.ts
@@ -0,0 +1,41 @@
+/**
+ * Agent tool type
+ */
+export type AgentToolType = "claude" | "codex" | "opencode";
+
+/**
+ * Agent config entry
+ */
+export type AgentConfig = {
+ /** Unique identifier */
+ id: string;
+ /** Config name */
+ name: string;
+ /** Agent tool type */
+ type: AgentToolType;
+ /** Whether this config is currently active */
+ isActive: boolean;
+ /** Sort index (lower comes first) */
+ sortIndex: number;
+ /** Creation time as an ISO string */
+ createdAt: string;
+ /** Last-update time as an ISO string */
+ updatedAt: string;
+};
+
+/**
+ * Agent config form
+ */
+export type AgentConfigForm = {
+ name: string;
+ type: AgentToolType;
+};
+
+/**
+ * Agent tool metadata
+ */
+export type AgentToolMeta = {
+ type: AgentToolType;
+ label: () => string;
+ description: () => string;
+};
diff --git a/src/features/config/AppView.tsx b/src/features/config/AppView.tsx
index 0adcc93..4727fab 100644
--- a/src/features/config/AppView.tsx
+++ b/src/features/config/AppView.tsx
@@ -1,9 +1,16 @@
-import { AlertCircle, Loader2, 
RefreshCw } from "lucide-react"; -import { useMemo, type CSSProperties } from "react"; +import { AlertCircle, Loader2, Plus, RefreshCw } from "lucide-react"; +import { useCallback, useMemo, useState, type CSSProperties } from "react"; import { AppSidebar } from "@/components/app-sidebar"; import { SiteHeader } from "@/components/site-header"; import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; import { AlertDialog, AlertDialogAction, @@ -19,7 +26,6 @@ import { Button } from "@/components/ui/button"; import { ScrollArea } from "@/components/ui/scroll-area"; import { SidebarInset, SidebarProvider } from "@/components/ui/sidebar"; import { - ClientSetupCard, ConfigFileCard, AutoStartCard, ProjectLinksCard, @@ -41,11 +47,15 @@ import type { ProxyServiceRequestState, ProxyServiceStatus, } from "@/features/config/types"; +import { AgentsPanel } from "@/features/agents"; import { DashboardPanel } from "@/features/dashboard/DashboardPanel"; import { LogsPanel } from "@/features/logs/LogsPanel"; import { ProvidersPanel } from "@/features/providers/ProvidersPanel"; import { m } from "@/paraglide/messages.js"; +// Agent 工具类型(用于 agents 页面工具筛选与创建) +export type AgentToolId = "claude" | "codex" | "opencode"; + type AppViewProps = { activeSectionId: ConfigSectionId; form: ConfigForm; @@ -93,6 +103,10 @@ type ConfigToolbarProps = { isDirty: boolean; onReload: () => void; onSave: () => void; + // agents 页面专用 + selectedTool?: AgentToolId; + onToolChange?: (tool: AgentToolId) => void; + onAddAgent?: () => void; }; function ConfigToolbar({ @@ -102,24 +116,42 @@ function ConfigToolbar({ isDirty, onReload, onSave, + selectedTool, + onToolChange, + onAddAgent, }: ConfigToolbarProps) { const isLoading = status === "loading"; const isSaving = status === "saving"; const canReload = !isLoading && !isSaving; + const isAgentsSection = section.id === "agents"; return (
-
-

- {section.label()} -

-

- {section.description()} -

-
+ {/* The agents page shows the tool selector; other pages show the title */}
+ {isAgentsSection && selectedTool && onToolChange ? (
+
+ ) : (
+
+

+ {section.label()} +

+

+ {section.description()} +

+
+ )}
{isDirty ? ( @@ -162,13 +194,21 @@ function ConfigToolbar({ {m.common_refresh()} )} - + {/* agents 页面显示添加按钮,其他页面显示保存按钮 */} + {isAgentsSection && onAddAgent ? ( + + ) : ( + + )}
); @@ -198,11 +238,16 @@ type ConfigSectionContentProps = Omit & { proxyService: ProxyServiceViewProps; }; -type ConfigSectionBodyProps = ConfigSectionContentProps; +type ConfigSectionBodyProps = ConfigSectionContentProps & { + selectedTool: AgentToolId; + agentEditorTrigger: number; +}; function ConfigSectionBody({ activeSectionId, proxyService, + selectedTool, + agentEditorTrigger, ...props }: ConfigSectionBodyProps) { switch (activeSectionId) { @@ -259,7 +304,7 @@ function ConfigSectionBody({ case "agents": return (
- +
); default: @@ -272,6 +317,14 @@ function ConfigSectionContent({ proxyService, ...props }: ConfigSectionContentProps) { + // agents 页面的工具选择状态 + const [selectedTool, setSelectedTool] = useState("claude"); + // agents 页面的添加对话框触发器 + const [agentEditorTrigger, setAgentEditorTrigger] = useState(0); + const handleAddAgent = useCallback(() => { + setAgentEditorTrigger((prev) => prev + 1); + }, []); + if (activeSectionId === "dashboard") { return ; } @@ -291,12 +344,17 @@ function ConfigSectionContent({ isDirty={props.isDirty} onReload={props.onReload} onSave={props.onSave} + selectedTool={selectedTool} + onToolChange={setSelectedTool} + onAddAgent={handleAddAgent} /> ); diff --git a/src/features/config/cards/client-setup-card.tsx b/src/features/config/cards/client-setup-card.tsx index ab82a88..2a95e81 100644 --- a/src/features/config/cards/client-setup-card.tsx +++ b/src/features/config/cards/client-setup-card.tsx @@ -1,13 +1,11 @@ import type { ReactNode } from "react"; -import { m } from "@/paraglide/messages.js"; +import type { AgentToolId } from "@/features/config/AppView"; import { - ClientSetupOverviewCard, PlaintextWarning, - SummaryItem, ToolDetailsFallback, - ToolSetupDialog, + ToolSetupPanel, } from "./client-setup-ui"; import { useClientSetupPreview, @@ -25,13 +23,11 @@ import { type ClientSetupCardProps = { savedAt: string; isDirty: boolean; + selectedTool: AgentToolId; }; -type ToolListItem = { +type ToolPanelItem = { id: string; - title: string; - description: string; - summary: ReactNode; content: ReactNode; action: ActionState; canApply: boolean; @@ -60,17 +56,9 @@ function buildClaudeTool({ isWorking, action, onApply, -}: ToolBuildBaseArgs & ToolBuildActionArgs) { +}: ToolBuildBaseArgs & ToolBuildActionArgs): ToolPanelItem { return { id: "claude", - title: m.client_setup_claude_title(), - description: m.client_setup_claude_desc(), - summary: ( - - ), content: setup ? ( ) : ( @@ -80,7 +68,7 @@ function buildClaudeTool({ canApply: Boolean(setup) && canApply, isWorking, onApply, - } satisfies ToolListItem; + }; } function buildCodexTool({ @@ -91,17 +79,9 @@ function buildCodexTool({ isWorking, action, onApply, -}: ToolBuildBaseArgs & ToolBuildActionArgs) { +}: ToolBuildBaseArgs & ToolBuildActionArgs): ToolPanelItem { return { id: "codex", - title: m.client_setup_codex_title(), - description: m.client_setup_codex_desc(), - summary: ( - - ), content: setup ? ( ) : ( @@ -111,7 +91,7 @@ function buildCodexTool({ canApply: Boolean(setup) && canApply, isWorking, onApply, - } satisfies ToolListItem; + }; } type OpenCodeToolArgs = ToolBuildBaseArgs & ToolBuildActionArgs & { @@ -124,27 +104,12 @@ function buildOpenCodeTool({ previewState, previewMessage, canApplyOpenCode, - openCodeModelCount, isWorking, action, onApply, -}: OpenCodeToolArgs) { +}: OpenCodeToolArgs): ToolPanelItem { return { id: "opencode", - title: m.client_setup_opencode_title(), - description: m.client_setup_opencode_desc(), - summary: ( -
- - -
- ), content: setup ? ( ) : ( @@ -154,31 +119,10 @@ function buildOpenCodeTool({ canApply: Boolean(setup) && canApplyOpenCode, isWorking, onApply, - } satisfies ToolListItem; -} - -function ToolCards({ tools }: { tools: readonly ToolListItem[] }) { - return ( - <> - {tools.map((tool) => ( - - {tool.content} - - ))} - - ); + }; } -export function ClientSetupCard({ savedAt, isDirty }: ClientSetupCardProps) { +export function ClientSetupCard({ savedAt, isDirty, selectedTool }: ClientSetupCardProps) { const canApply = !isDirty; const { previewState, previewMessage, setup, loadPreview } = useClientSetupPreview(savedAt); @@ -203,29 +147,31 @@ export function ClientSetupCard({ savedAt, isDirty }: ClientSetupCardProps) { isWorking, }; - const tools: ToolListItem[] = [ - buildClaudeTool({ ...baseArgs, action: claude.action, onApply: claude.apply }), - buildCodexTool({ ...baseArgs, action: codex.action, onApply: codex.apply }), - buildOpenCodeTool({ + // 根据 selectedTool 构建对应的工具面板 + const toolBuilders: Record ToolPanelItem> = { + claude: () => buildClaudeTool({ ...baseArgs, action: claude.action, onApply: claude.apply }), + codex: () => buildCodexTool({ ...baseArgs, action: codex.action, onApply: codex.apply }), + opencode: () => buildOpenCodeTool({ ...baseArgs, action: opencode.action, onApply: opencode.apply, openCodeModelCount, canApplyOpenCode, }), - ]; + }; + + const selectedToolItem = toolBuilders[selectedTool](); return ( <> - - + + {selectedToolItem.content} + ); diff --git a/src/features/config/cards/client-setup-ui.tsx b/src/features/config/cards/client-setup-ui.tsx index 7c5ea08..eb38e2f 100644 --- a/src/features/config/cards/client-setup-ui.tsx +++ b/src/features/config/cards/client-setup-ui.tsx @@ -2,26 +2,13 @@ import type { ReactNode } from "react"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; -import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; -import { - Dialog, - DialogBody, - DialogClose, - DialogContent, - DialogDescription, - DialogFooter, - DialogHeader, - DialogTitle, - DialogTrigger, -} from "@/components/ui/dialog"; +import { Card, CardContent } from "@/components/ui/card"; import { m } from "@/paraglide/messages.js"; -import type { ActionState, ClientSetupInfo, RequestState } from "./client-setup-state"; +import type { ActionState, RequestState } from "./client-setup-state"; -type ToolSetupDialogProps = { - title: string; - description: string; - summary: ReactNode; +// 内联展示工具配置的 props +type ToolSetupPanelProps = { action: ActionState; canApply: boolean; isWorking: boolean; @@ -106,105 +93,43 @@ export function CodeBlock({ lines }: { lines: readonly string[] }) { ); } -type ToolSetupCardProps = Pick; - -function ToolSetupCard({ title, description, summary, action }: ToolSetupCardProps) { - return ( - - -
-
- {title} - {description} -
- {shouldShowBadge(action.state) ? ( - {toBadgeLabel(action.state)} - ) : null} -
-
- - {summary} - - - - -
- ); -} - -type ToolSetupModalProps = Omit; - -function ToolSetupModal({ - title, - description, +/** 内联展示工具配置面板(无弹窗) */ +export function ToolSetupPanel({ action, canApply, isWorking, onApply, children, -}: ToolSetupModalProps) { +}: ToolSetupPanelProps) { return ( - - -
-
- {title} - {description} -
- {shouldShowBadge(action.state) ? ( - {toBadgeLabel(action.state)} - ) : null} -
-
- - + + + {/* 详细配置内容 */} {children} + {/* 操作状态消息 */} {action.message ? (
{action.message}
) : null} + {/* 备份提示 */}

{m.client_setup_backup_hint()}

-
- - - - - - -
- ); -} - -export function ToolSetupDialog(props: ToolSetupDialogProps) { - return ( - - - - {props.children} - - + + + ); } @@ -250,63 +175,3 @@ export function PlaintextWarning() { ); } - -type ClientSetupOverviewCardProps = { - previewState: RequestState; - previewMessage: string; - setup: ClientSetupInfo | null; - isDirty: boolean; - isWorking: boolean; - onRefresh: () => void; -}; - -export function ClientSetupOverviewCard({ - previewState, - previewMessage, - setup, - isDirty, - isWorking, - onRefresh, -}: ClientSetupOverviewCardProps) { - return ( - - - {m.client_setup_title()} - {m.client_setup_desc()} - - -
- {shouldShowBadge(previewState) ? ( - {toBadgeLabel(previewState)} - ) : null} - -
- - {previewMessage ? ( -
- {previewMessage} -
- ) : null} - - {isDirty ? ( -
- {m.client_setup_dirty_notice()} -
- ) : null} - - {setup ? ( -
-

- {m.client_setup_proxy_base_url_label()} -

-

- {setup.proxy_http_base_url} -

-
- ) : null} -
-
- ); -} diff --git a/src/features/dashboard/RecentRequestsTable.tsx b/src/features/dashboard/RecentRequestsTable.tsx index be16893..eb232a8 100644 --- a/src/features/dashboard/RecentRequestsTable.tsx +++ b/src/features/dashboard/RecentRequestsTable.tsx @@ -122,7 +122,9 @@ function modelColumn(): ColumnDef { header: m.dashboard_table_model(), cell: ({ row }) => { const primary = row.original.model?.trim() ? row.original.model : CELL_PLACEHOLDER; - const mapped = row.original.mappedModel?.trim() ? row.original.mappedModel : null; + const rawMapped = row.original.mappedModel?.trim() ? row.original.mappedModel : null; + // 只有当 mappedModel 存在且与 model 不同时才显示映射 + const mapped = rawMapped && rawMapped !== row.original.model ? rawMapped : null; const tooltipText = mapped ? `${primary}\n${mapped}` : primary; return ( From 2db2177730f21976bad8560c8cd48b64ea206df3 Mon Sep 17 00:00:00 2001 From: mxyhi Date: Fri, 30 Jan 2026 13:54:42 +0800 Subject: [PATCH 02/10] feat(agents): let each agent target a different API and model --- messages/en.json | 16 +- messages/zh.json | 16 +- src/features/agents/AgentEditorDialog.tsx | 408 ++++++++++++++++++++-- src/features/agents/agent-store.ts | 61 +++- src/features/agents/index.ts | 16 +- src/features/agents/types.ts | 98 +++++- 6 files changed, 581 insertions(+), 34 deletions(-) diff --git a/messages/en.json b/messages/en.json index e3cf664..1042b0d 100644 --- a/messages/en.json +++ b/messages/en.json @@ -455,5 +455,19 @@ "agents_tool_codex": "Codex CLI", "agents_tool_codex_desc": "OpenAI's Codex CLI tool", "agents_tool_opencode": "OpenCode", - "agents_tool_opencode_desc": "Open source AI coding assistant" + "agents_tool_opencode_desc": "Open source AI coding assistant", + "agents_editor_connection": "Connection", + "agents_editor_api_key": "API Key", + "agents_editor_api_key_placeholder": "sk-...", + "agents_editor_base_url": "Base URL", + "agents_editor_base_url_placeholder": "https://api.example.com", + "agents_editor_model_config": "Model Configuration", + "agents_editor_claude_model": "Default Model", + "agents_editor_claude_haiku": "Haiku Model", + "agents_editor_claude_sonnet": "Sonnet Model", + "agents_editor_claude_opus": "Opus Model", + "agents_editor_codex_model": "Model", + "agents_editor_codex_reasoning": "Reasoning Effort", + "agents_editor_opencode_provider": "Provider", + "agents_editor_opencode_model": "Model" } diff --git a/messages/zh.json b/messages/zh.json index 9d9c9da..76de73c 100644 --- a/messages/zh.json +++ b/messages/zh.json @@ -455,5 +455,19 @@ "agents_tool_codex": "Codex CLI", "agents_tool_codex_desc": "OpenAI 的 Codex CLI 工具", "agents_tool_opencode": "OpenCode", - "agents_tool_opencode_desc": "开源 AI 编程助手" + "agents_tool_opencode_desc": "开源 AI 编程助手", + "agents_editor_connection": "连接配置", + "agents_editor_api_key": "API Key", + "agents_editor_api_key_placeholder": "sk-...", + "agents_editor_base_url": "Base URL", + "agents_editor_base_url_placeholder": "https://api.example.com", + "agents_editor_model_config": "模型配置", + "agents_editor_claude_model": "默认模型", + "agents_editor_claude_haiku": "Haiku 模型", + "agents_editor_claude_sonnet": "Sonnet 模型", + "agents_editor_claude_opus": "Opus 模型", + "agents_editor_codex_model": "模型", + "agents_editor_codex_reasoning": "推理强度", + "agents_editor_opencode_provider": "Provider", + "agents_editor_opencode_model": "模型" } diff --git a/src/features/agents/AgentEditorDialog.tsx b/src/features/agents/AgentEditorDialog.tsx index e5301ea..58f0d58 100644 --- a/src/features/agents/AgentEditorDialog.tsx +++ 
b/src/features/agents/AgentEditorDialog.tsx @@ -1,4 +1,4 @@ -import { useCallback, useEffect, useState } from "react"; +import { useCallback, useEffect, useMemo, useState } from "react"; import { Button } from "@/components/ui/button"; import { @@ -11,10 +11,27 @@ import { } from "@/components/ui/dialog"; import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; +import { PasswordInput } from "@/components/ui/password-input"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Separator } from "@/components/ui/separator"; import { m } from "@/paraglide/messages.js"; -import type { AgentConfig, AgentConfigForm, AgentToolType } from "./types"; +import type { + AgentConfig, + AgentConfigForm, + AgentSettingsMap, + AgentToolType, + CodexReasoningEffort, + OpenCodeProvider, +} from "./types"; import { AGENT_TOOL_META } from "./constants"; +import { createDefaultSettings } from "./agent-store"; type AgentEditorDialogProps = { open: boolean; @@ -26,6 +43,26 @@ type AgentEditorDialogProps = { onSave: (form: AgentConfigForm) => void; }; +/** + * Codex 推理强度选项 + */ +const CODEX_REASONING_OPTIONS: { value: CodexReasoningEffort; label: string }[] = [ + { value: "low", label: "Low" }, + { value: "medium", label: "Medium" }, + { value: "high", label: "High" }, +]; + +/** + * OpenCode Provider 选项 + */ +const OPENCODE_PROVIDER_OPTIONS: { value: OpenCodeProvider; label: string }[] = [ + { value: "openai", label: "OpenAI" }, + { value: "anthropic", label: "Anthropic" }, + { value: "gemini", label: "Google Gemini" }, + { value: "openrouter", label: "OpenRouter" }, + { value: "custom", label: "Custom" }, +]; + export function AgentEditorDialog({ open, onOpenChange, @@ -33,8 +70,28 @@ export function AgentEditorDialog({ defaultType, onSave, }: AgentEditorDialogProps) { + // 基础字段 const [name, setName] = useState(""); + // Settings 字段 + const [apiKey, setApiKey] = useState(""); + const [apiKeyVisible, setApiKeyVisible] = useState(false); + const [baseUrl, setBaseUrl] = useState(""); + + // Claude 特有字段 + const [claudeModel, setClaudeModel] = useState(""); + const [haikuModel, setHaikuModel] = useState(""); + const [sonnetModel, setSonnetModel] = useState(""); + const [opusModel, setOpusModel] = useState(""); + + // Codex 特有字段 + const [codexModel, setCodexModel] = useState(""); + const [reasoningEffort, setReasoningEffort] = useState("medium"); + + // OpenCode 特有字段 + const [openCodeProvider, setOpenCodeProvider] = useState("openai"); + const [openCodeModel, setOpenCodeModel] = useState(""); + const isEditing = editingAgent !== null; const title = isEditing ? m.agents_editor_title_edit() : m.agents_editor_title_add(); const currentType = isEditing ? editingAgent.type : defaultType; @@ -43,29 +100,116 @@ export function AgentEditorDialog({ // 打开对话框时初始化表单 useEffect(() => { if (open) { + setApiKeyVisible(false); + if (editingAgent) { + // 编辑模式:从现有配置初始化 setName(editingAgent.name); + const settings = editingAgent.settings; + setApiKey(settings.apiKey); + setBaseUrl(settings.baseUrl); + + // 根据类型初始化特有字段 + if (editingAgent.type === "claude") { + const s = settings as AgentSettingsMap["claude"]; + setClaudeModel(s.model ?? ""); + setHaikuModel(s.haikuModel ?? ""); + setSonnetModel(s.sonnetModel ?? ""); + setOpusModel(s.opusModel ?? ""); + } else if (editingAgent.type === "codex") { + const s = settings as AgentSettingsMap["codex"]; + setCodexModel(s.model ?? ""); + setReasoningEffort(s.reasoningEffort ?? 
"medium"); + } else if (editingAgent.type === "opencode") { + const s = settings as AgentSettingsMap["opencode"]; + setOpenCodeProvider(s.provider ?? "openai"); + setOpenCodeModel(s.model ?? ""); + } } else { + // 添加模式:使用默认值 setName(""); + const defaultSettings = createDefaultSettings(defaultType); + setApiKey(defaultSettings.apiKey); + setBaseUrl(defaultSettings.baseUrl); + + // 重置特有字段 + setClaudeModel(""); + setHaikuModel(""); + setSonnetModel(""); + setOpusModel(""); + setCodexModel(""); + setReasoningEffort("medium"); + setOpenCodeProvider("openai"); + setOpenCodeModel(""); } } - }, [open, editingAgent]); + }, [open, editingAgent, defaultType]); - const handleNameChange = useCallback((e: React.ChangeEvent) => { - setName(e.target.value); - }, []); + // 构建 settings 对象 + const buildSettings = useCallback((): AgentSettingsMap[typeof currentType] => { + const base = { apiKey, baseUrl }; + + switch (currentType) { + case "claude": + return { + ...base, + ...(claudeModel && { model: claudeModel }), + ...(haikuModel && { haikuModel }), + ...(sonnetModel && { sonnetModel }), + ...(opusModel && { opusModel }), + } as AgentSettingsMap["claude"]; + case "codex": + return { + ...base, + ...(codexModel && { model: codexModel }), + reasoningEffort, + } as AgentSettingsMap["codex"]; + case "opencode": + return { + ...base, + provider: openCodeProvider, + ...(openCodeModel && { model: openCodeModel }), + } as AgentSettingsMap["opencode"]; + default: + return base as AgentSettingsMap[typeof currentType]; + } + }, [ + currentType, + apiKey, + baseUrl, + claudeModel, + haikuModel, + sonnetModel, + opusModel, + codexModel, + reasoningEffort, + openCodeProvider, + openCodeModel, + ]); const handleSave = useCallback(() => { if (!name.trim()) return; - onSave({ name: name.trim(), type: currentType }); + onSave({ + name: name.trim(), + type: currentType, + settings: buildSettings(), + }); onOpenChange(false); - }, [name, currentType, onSave, onOpenChange]); + }, [name, currentType, buildSettings, onSave, onOpenChange]); - const canSave = name.trim().length > 0; + const canSave = useMemo(() => { + // 名称必填 + if (!name.trim()) return false; + // API Key 必填 + if (!apiKey.trim()) return false; + // Base URL 必填 + if (!baseUrl.trim()) return false; + return true; + }, [name, apiKey, baseUrl]); return ( - + {title} @@ -73,17 +217,86 @@ export function AgentEditorDialog({ -
- {/* Name */}
-
- - +
+ {/* Basic info */}
+
+
+ + setName(e.target.value)} + placeholder={m.agents_editor_name_placeholder()} + autoFocus + /> +
+
+
+
+
+ {/* Connection settings */}
+
+

{m.agents_editor_connection()}

+ +
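+ {/* canSave requires both the API key and base URL to be non-empty */}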
+ + setApiKey(e.target.value)} + visible={apiKeyVisible} + onVisibilityChange={() => setApiKeyVisible((v) => !v)} + placeholder={m.agents_editor_api_key_placeholder()} + /> +
+ +
+ + setBaseUrl(e.target.value)} + placeholder={m.agents_editor_base_url_placeholder()} + /> +
+
+
+
+
+ {/* Model configuration - rendered per tool type */}
+
+

{m.agents_editor_model_config()}

+ + {currentType === "claude" && ( + + )} + + {currentType === "codex" && ( + + )} + + {currentType === "opencode" && ( + + )}
@@ -99,3 +312,156 @@ export function AgentEditorDialog({
); } + +// ============================================================================ +// 子组件:各工具类型的模型配置字段 +// ============================================================================ + +type ClaudeModelFieldsProps = { + model: string; + onModelChange: (value: string) => void; + haikuModel: string; + onHaikuModelChange: (value: string) => void; + sonnetModel: string; + onSonnetModelChange: (value: string) => void; + opusModel: string; + onOpusModelChange: (value: string) => void; +}; + +function ClaudeModelFields({ + model, + onModelChange, + haikuModel, + onHaikuModelChange, + sonnetModel, + onSonnetModelChange, + opusModel, + onOpusModelChange, +}: ClaudeModelFieldsProps) { + return ( +
+
+ + onModelChange(e.target.value)} + placeholder="claude-sonnet-4-20250514" + /> +
+
+
+ onHaikuModelChange(e.target.value)}
+ placeholder="claude-3-5-haiku-20241022"
+ />
+
+
+ + onSonnetModelChange(e.target.value)} + placeholder="claude-sonnet-4-20250514" + /> +
+
+ + onOpusModelChange(e.target.value)} + placeholder="claude-opus-4-20250514" + /> +
+
+ ); +} + +type CodexModelFieldsProps = { + model: string; + onModelChange: (value: string) => void; + reasoningEffort: CodexReasoningEffort; + onReasoningEffortChange: (value: CodexReasoningEffort) => void; +}; + +function CodexModelFields({ + model, + onModelChange, + reasoningEffort, + onReasoningEffortChange, +}: CodexModelFieldsProps) { + return ( +
+
+ + onModelChange(e.target.value)} + placeholder="o3" + /> +
+
+ + +
+
+ ); +} + +type OpenCodeModelFieldsProps = { + provider: OpenCodeProvider; + onProviderChange: (value: OpenCodeProvider) => void; + model: string; + onModelChange: (value: string) => void; +}; + +function OpenCodeModelFields({ + provider, + onProviderChange, + model, + onModelChange, +}: OpenCodeModelFieldsProps) { + return ( +
+
+ + +
+
+ + onModelChange(e.target.value)} + placeholder="gpt-4o" + /> +
+
+ ); +} diff --git a/src/features/agents/agent-store.ts b/src/features/agents/agent-store.ts index 0431bf2..5c3b362 100644 --- a/src/features/agents/agent-store.ts +++ b/src/features/agents/agent-store.ts @@ -1,6 +1,13 @@ import { useCallback, useMemo, useState } from "react"; -import type { AgentConfig, AgentConfigForm, AgentToolType } from "./types"; +import type { + AgentConfig, + AgentConfigForm, + AgentSettingsMap, + AgentToolType, + ClaudeSettings, + CodexSettings, +} from "./types"; /** * 生成唯一 ID @@ -9,15 +16,46 @@ function generateId() { return `agent_${Date.now()}_${Math.random().toString(36).slice(2, 9)}`; } +/** + * 各工具类型的默认 Base URL + */ +const DEFAULT_BASE_URLS: Record = { + claude: "https://api.anthropic.com", + codex: "https://api.openai.com/v1", + opencode: "https://api.openai.com/v1", +}; + +/** + * 创建默认的 Agent Settings(导出供外部使用) + */ +export function createDefaultSettings(type: T): AgentSettingsMap[T] { + const base = { apiKey: "", baseUrl: DEFAULT_BASE_URLS[type] }; + + switch (type) { + case "claude": + return { ...base } as AgentSettingsMap[T]; + case "codex": + return { ...base, reasoningEffort: "medium" } as AgentSettingsMap[T]; + case "opencode": + return { ...base, provider: "openai" } as AgentSettingsMap[T]; + default: + return base as AgentSettingsMap[T]; + } +} + /** * 创建新的 Agent 配置 */ -function createAgentConfig(form: AgentConfigForm, sortIndex: number): AgentConfig { +function createAgentConfig( + form: AgentConfigForm, + sortIndex: number +): AgentConfig { const now = new Date().toISOString(); return { id: generateId(), name: form.name, type: form.type, + settings: form.settings, isActive: false, sortIndex, createdAt: now, @@ -31,6 +69,10 @@ const INITIAL_AGENTS: AgentConfig[] = [ id: "agent_default_claude", name: "Claude Code (Default)", type: "claude", + settings: { + apiKey: "", + baseUrl: "https://api.anthropic.com", + } as ClaudeSettings, isActive: true, sortIndex: 0, createdAt: "2025-01-01T00:00:00.000Z", @@ -40,6 +82,11 @@ const INITIAL_AGENTS: AgentConfig[] = [ id: "agent_default_codex", name: "Codex CLI", type: "codex", + settings: { + apiKey: "", + baseUrl: "https://api.openai.com/v1", + reasoningEffort: "medium", + } as CodexSettings, isActive: false, sortIndex: 1, createdAt: "2025-01-01T00:00:00.000Z", @@ -92,12 +139,18 @@ export function useAgentStore() { }); }, []); - // 更新 agent + // 更新 agent(包括 settings) const updateAgent = useCallback((id: string, form: AgentConfigForm) => { setAgents((prev) => prev.map((agent) => agent.id === id - ? { ...agent, name: form.name, type: form.type, updatedAt: new Date().toISOString() } + ? 
{ + ...agent, + name: form.name, + type: form.type, + settings: form.settings, + updatedAt: new Date().toISOString(), + } : agent ) ); diff --git a/src/features/agents/index.ts b/src/features/agents/index.ts index bb634f8..9672f1c 100644 --- a/src/features/agents/index.ts +++ b/src/features/agents/index.ts @@ -2,6 +2,18 @@ export { AgentsPanel } from "./AgentsPanel"; export { AgentCard } from "./AgentCard"; export { AgentEditorDialog } from "./AgentEditorDialog"; export { AgentDeleteDialog } from "./AgentDeleteDialog"; -export { useAgentStore } from "./agent-store"; +export { useAgentStore, createDefaultSettings } from "./agent-store"; export { AGENT_TOOL_META, AGENT_TOOL_OPTIONS } from "./constants"; -export type { AgentConfig, AgentConfigForm, AgentToolType, AgentToolMeta } from "./types"; +export type { + AgentConfig, + AgentConfigForm, + AgentToolType, + AgentToolMeta, + AgentSettings, + AgentSettingsMap, + ClaudeSettings, + CodexSettings, + OpenCodeSettings, + CodexReasoningEffort, + OpenCodeProvider, +} from "./types"; diff --git a/src/features/agents/types.ts b/src/features/agents/types.ts index 73b9eef..e07e484 100644 --- a/src/features/agents/types.ts +++ b/src/features/agents/types.ts @@ -3,16 +3,94 @@ */ export type AgentToolType = "claude" | "codex" | "opencode"; +// ============================================================================ +// Agent Settings 类型定义 +// ============================================================================ + +/** + * 通用连接配置 + */ +type BaseAgentSettings = { + /** API 密钥 */ + apiKey: string; + /** API 基础 URL */ + baseUrl: string; +}; + +/** + * Claude 特有配置 + */ +export type ClaudeSettings = BaseAgentSettings & { + /** 主模型(用于一般任务) */ + model?: string; + /** Haiku 模型(快速轻量任务) */ + haikuModel?: string; + /** Sonnet 模型(平衡任务) */ + sonnetModel?: string; + /** Opus 模型(复杂任务) */ + opusModel?: string; +}; + +/** + * Codex 推理强度 + */ +export type CodexReasoningEffort = "low" | "medium" | "high"; + +/** + * Codex 特有配置 + */ +export type CodexSettings = BaseAgentSettings & { + /** 模型名称 */ + model?: string; + /** 推理强度 */ + reasoningEffort?: CodexReasoningEffort; +}; + +/** + * OpenCode Provider 类型 + */ +export type OpenCodeProvider = "openai" | "anthropic" | "gemini" | "openrouter" | "custom"; + +/** + * OpenCode 特有配置 + */ +export type OpenCodeSettings = BaseAgentSettings & { + /** Provider 类型 */ + provider?: OpenCodeProvider; + /** 模型名称 */ + model?: string; +}; + +/** + * Agent 工具类型到配置的映射 + */ +export type AgentSettingsMap = { + claude: ClaudeSettings; + codex: CodexSettings; + opencode: OpenCodeSettings; +}; + +/** + * 所有 Agent Settings 的联合类型 + */ +export type AgentSettings = AgentSettingsMap[AgentToolType]; + +// ============================================================================ +// Agent Config 类型定义 +// ============================================================================ + /** * Agent 配置项 */ -export type AgentConfig = { +export type AgentConfig = { /** 唯一标识符 */ id: string; /** 配置名称 */ name: string; /** Agent 工具类型 */ - type: AgentToolType; + type: T; + /** 工具特定配置 */ + settings: AgentSettingsMap[T]; /** 是否为当前激活配置 */ isActive: boolean; /** 排序索引(越小越靠前) */ @@ -24,11 +102,12 @@ export type AgentConfig = { }; /** - * Agent 配置表单 + * Agent 配置表单(用于添加/编辑) */ -export type AgentConfigForm = { +export type AgentConfigForm = { name: string; - type: AgentToolType; + type: T; + settings: AgentSettingsMap[T]; }; /** @@ -39,3 +118,12 @@ export type AgentToolMeta = { label: () => string; description: () => string; }; + +// 
============================================================================ +// 工具函数类型 +// ============================================================================ + +/** + * 创建默认 settings 的工厂函数类型 + */ +export type CreateDefaultSettings = (type: T) => AgentSettingsMap[T]; From 49ba2219ddf3ad486a73814cf3f48c6dd738ea1e Mon Sep 17 00:00:00 2001 From: mxyhi Date: Fri, 30 Jan 2026 14:43:09 +0800 Subject: [PATCH 03/10] refactor: shift config UI to client setup flow --- messages/en.json | 39 +- messages/zh.json | 39 +- src/features/agents/AgentCard.tsx | 108 ---- src/features/agents/AgentDeleteDialog.tsx | 42 -- src/features/agents/AgentEditorDialog.tsx | 467 ------------------ src/features/agents/AgentsPanel.tsx | 124 ----- src/features/agents/agent-store.ts | 222 --------- src/features/agents/constants.ts | 29 -- src/features/agents/index.ts | 19 - src/features/agents/types.ts | 129 ----- src/features/config/AppView.tsx | 98 +--- .../config/cards/client-setup-card.tsx | 108 +++- src/features/config/cards/client-setup-ui.tsx | 183 ++++++- src/features/config/cards/upstreams-card.tsx | 264 ++-------- .../upstreams/upstream-editor-helpers.test.ts | 111 +++++ .../upstreams/upstream-editor-helpers.ts | 263 ++++++++++ .../dashboard/RecentRequestsTable.tsx | 4 +- 17 files changed, 666 insertions(+), 1583 deletions(-) delete mode 100644 src/features/agents/AgentCard.tsx delete mode 100644 src/features/agents/AgentDeleteDialog.tsx delete mode 100644 src/features/agents/AgentEditorDialog.tsx delete mode 100644 src/features/agents/AgentsPanel.tsx delete mode 100644 src/features/agents/agent-store.ts delete mode 100644 src/features/agents/constants.ts delete mode 100644 src/features/agents/index.ts delete mode 100644 src/features/agents/types.ts create mode 100644 src/features/config/cards/upstreams/upstream-editor-helpers.test.ts create mode 100644 src/features/config/cards/upstreams/upstream-editor-helpers.ts diff --git a/messages/en.json b/messages/en.json index 1042b0d..9464669 100644 --- a/messages/en.json +++ b/messages/en.json @@ -432,42 +432,5 @@ "providers_status_summary": "Active {active} · Expired {expired}", "providers_quota_loading": "Loading quota...", "providers_account_delete_title": "Delete account?", - "providers_account_delete_description": "This will delete account {label}.", - "common_edit": "Edit", - "common_add": "Add", - "agents_title": "Agents", - "agents_desc": "Manage agent configurations for different AI tools.", - "agents_active": "Active", - "agents_switch": "Switch to this agent", - "agents_add": "Add Agent", - "agents_empty": "No agents configured yet.", - "agents_empty_hint": "Add an agent to get started.", - "agents_delete_title": "Delete agent?", - "agents_delete_description": "This will delete agent \"{name}\".", - "agents_editor_title_add": "Add Agent", - "agents_editor_title_edit": "Edit Agent", - "agents_editor_name_label": "Name", - "agents_editor_name_placeholder": "My Claude Config", - "agents_editor_type_label": "Tool Type", - "agents_editor_type_placeholder": "Select tool type", - "agents_tool_claude": "Claude Code", - "agents_tool_claude_desc": "Anthropic's Claude Code CLI tool", - "agents_tool_codex": "Codex CLI", - "agents_tool_codex_desc": "OpenAI's Codex CLI tool", - "agents_tool_opencode": "OpenCode", - "agents_tool_opencode_desc": "Open source AI coding assistant", - "agents_editor_connection": "Connection", - "agents_editor_api_key": "API Key", - "agents_editor_api_key_placeholder": "sk-...", - "agents_editor_base_url": "Base URL", - 
"agents_editor_base_url_placeholder": "https://api.example.com", - "agents_editor_model_config": "Model Configuration", - "agents_editor_claude_model": "Default Model", - "agents_editor_claude_haiku": "Haiku Model", - "agents_editor_claude_sonnet": "Sonnet Model", - "agents_editor_claude_opus": "Opus Model", - "agents_editor_codex_model": "Model", - "agents_editor_codex_reasoning": "Reasoning Effort", - "agents_editor_opencode_provider": "Provider", - "agents_editor_opencode_model": "Model" + "providers_account_delete_description": "This will delete account {label}." } diff --git a/messages/zh.json b/messages/zh.json index 76de73c..231e833 100644 --- a/messages/zh.json +++ b/messages/zh.json @@ -432,42 +432,5 @@ "providers_status_summary": "正常 {active} · 过期 {expired}", "providers_quota_loading": "正在加载额度…", "providers_account_delete_title": "删除账户?", - "providers_account_delete_description": "将删除账户 {label}。", - "common_edit": "编辑", - "common_add": "添加", - "agents_title": "Agents", - "agents_desc": "管理不同 AI 工具的 Agent 配置。", - "agents_active": "激活", - "agents_switch": "切换到此 Agent", - "agents_add": "添加 Agent", - "agents_empty": "暂无 Agent 配置。", - "agents_empty_hint": "添加一个 Agent 开始使用。", - "agents_delete_title": "删除 Agent?", - "agents_delete_description": "将删除 Agent「{name}」。", - "agents_editor_title_add": "添加 Agent", - "agents_editor_title_edit": "编辑 Agent", - "agents_editor_name_label": "名称", - "agents_editor_name_placeholder": "我的 Claude 配置", - "agents_editor_type_label": "工具类型", - "agents_editor_type_placeholder": "选择工具类型", - "agents_tool_claude": "Claude Code", - "agents_tool_claude_desc": "Anthropic 的 Claude Code CLI 工具", - "agents_tool_codex": "Codex CLI", - "agents_tool_codex_desc": "OpenAI 的 Codex CLI 工具", - "agents_tool_opencode": "OpenCode", - "agents_tool_opencode_desc": "开源 AI 编程助手", - "agents_editor_connection": "连接配置", - "agents_editor_api_key": "API Key", - "agents_editor_api_key_placeholder": "sk-...", - "agents_editor_base_url": "Base URL", - "agents_editor_base_url_placeholder": "https://api.example.com", - "agents_editor_model_config": "模型配置", - "agents_editor_claude_model": "默认模型", - "agents_editor_claude_haiku": "Haiku 模型", - "agents_editor_claude_sonnet": "Sonnet 模型", - "agents_editor_claude_opus": "Opus 模型", - "agents_editor_codex_model": "模型", - "agents_editor_codex_reasoning": "推理强度", - "agents_editor_opencode_provider": "Provider", - "agents_editor_opencode_model": "模型" + "providers_account_delete_description": "将删除账户 {label}。" } diff --git a/src/features/agents/AgentCard.tsx b/src/features/agents/AgentCard.tsx deleted file mode 100644 index 7ee0e9e..0000000 --- a/src/features/agents/AgentCard.tsx +++ /dev/null @@ -1,108 +0,0 @@ -import { MoreHorizontal, Power, Pencil, Trash2, GripVertical } from "lucide-react"; - -import { Badge } from "@/components/ui/badge"; -import { Button } from "@/components/ui/button"; -import { Card, CardContent } from "@/components/ui/card"; -import { - DropdownMenu, - DropdownMenuContent, - DropdownMenuItem, - DropdownMenuSeparator, - DropdownMenuTrigger, -} from "@/components/ui/dropdown-menu"; -import { m } from "@/paraglide/messages.js"; - -import type { AgentConfig } from "./types"; -import { AGENT_TOOL_META } from "./constants"; - -type AgentCardProps = { - agent: AgentConfig; - onSwitch: (id: string) => void; - onEdit: (agent: AgentConfig) => void; - onDelete: (id: string) => void; - dragHandleProps?: React.HTMLAttributes; -}; - -export function AgentCard({ - agent, - onSwitch, - onEdit, - onDelete, - dragHandleProps, -}: AgentCardProps) 
{ - const toolMeta = AGENT_TOOL_META[agent.type]; - - return ( - - - {/* 拖拽手柄 */} - {dragHandleProps && ( - - )} - - {/* 主要内容 */} -
-
- {agent.name} - {agent.isActive && ( - - {m.agents_active()} - - )} -
-

{toolMeta.label()}

-
- - {/* 操作按钮 */} -
- {/* 切换激活按钮 */} - {!agent.isActive && ( - - )} - - {/* 更多操作 */} - - - - - - onEdit(agent)}> - - {m.common_edit()} - - - onDelete(agent.id)} - > - - {m.common_delete()} - - - -
-
-
- ); -} diff --git a/src/features/agents/AgentDeleteDialog.tsx b/src/features/agents/AgentDeleteDialog.tsx deleted file mode 100644 index 5b32e23..0000000 --- a/src/features/agents/AgentDeleteDialog.tsx +++ /dev/null @@ -1,42 +0,0 @@ -import { - AlertDialog, - AlertDialogAction, - AlertDialogCancel, - AlertDialogContent, - AlertDialogDescription, - AlertDialogFooter, - AlertDialogHeader, - AlertDialogTitle, -} from "@/components/ui/alert-dialog"; -import { m } from "@/paraglide/messages.js"; - -type AgentDeleteDialogProps = { - open: boolean; - onOpenChange: (open: boolean) => void; - agentName: string; - onConfirm: () => void; -}; - -export function AgentDeleteDialog({ - open, - onOpenChange, - agentName, - onConfirm, -}: AgentDeleteDialogProps) { - return ( - - - - {m.agents_delete_title()} - - {m.agents_delete_description({ name: agentName })} - - - - {m.common_cancel()} - {m.common_delete()} - - - - ); -} diff --git a/src/features/agents/AgentEditorDialog.tsx b/src/features/agents/AgentEditorDialog.tsx deleted file mode 100644 index 58f0d58..0000000 --- a/src/features/agents/AgentEditorDialog.tsx +++ /dev/null @@ -1,467 +0,0 @@ -import { useCallback, useEffect, useMemo, useState } from "react"; - -import { Button } from "@/components/ui/button"; -import { - Dialog, - DialogContent, - DialogDescription, - DialogFooter, - DialogHeader, - DialogTitle, -} from "@/components/ui/dialog"; -import { Input } from "@/components/ui/input"; -import { Label } from "@/components/ui/label"; -import { PasswordInput } from "@/components/ui/password-input"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { Separator } from "@/components/ui/separator"; -import { m } from "@/paraglide/messages.js"; - -import type { - AgentConfig, - AgentConfigForm, - AgentSettingsMap, - AgentToolType, - CodexReasoningEffort, - OpenCodeProvider, -} from "./types"; -import { AGENT_TOOL_META } from "./constants"; -import { createDefaultSettings } from "./agent-store"; - -type AgentEditorDialogProps = { - open: boolean; - onOpenChange: (open: boolean) => void; - /** 编辑模式时传入现有配置,添加模式时为 null */ - editingAgent: AgentConfig | null; - /** 添加模式时的默认工具类型 */ - defaultType: AgentToolType; - onSave: (form: AgentConfigForm) => void; -}; - -/** - * Codex 推理强度选项 - */ -const CODEX_REASONING_OPTIONS: { value: CodexReasoningEffort; label: string }[] = [ - { value: "low", label: "Low" }, - { value: "medium", label: "Medium" }, - { value: "high", label: "High" }, -]; - -/** - * OpenCode Provider 选项 - */ -const OPENCODE_PROVIDER_OPTIONS: { value: OpenCodeProvider; label: string }[] = [ - { value: "openai", label: "OpenAI" }, - { value: "anthropic", label: "Anthropic" }, - { value: "gemini", label: "Google Gemini" }, - { value: "openrouter", label: "OpenRouter" }, - { value: "custom", label: "Custom" }, -]; - -export function AgentEditorDialog({ - open, - onOpenChange, - editingAgent, - defaultType, - onSave, -}: AgentEditorDialogProps) { - // 基础字段 - const [name, setName] = useState(""); - - // Settings 字段 - const [apiKey, setApiKey] = useState(""); - const [apiKeyVisible, setApiKeyVisible] = useState(false); - const [baseUrl, setBaseUrl] = useState(""); - - // Claude 特有字段 - const [claudeModel, setClaudeModel] = useState(""); - const [haikuModel, setHaikuModel] = useState(""); - const [sonnetModel, setSonnetModel] = useState(""); - const [opusModel, setOpusModel] = useState(""); - - // Codex 特有字段 - const [codexModel, setCodexModel] = useState(""); - const 
[reasoningEffort, setReasoningEffort] = useState("medium"); - - // OpenCode 特有字段 - const [openCodeProvider, setOpenCodeProvider] = useState("openai"); - const [openCodeModel, setOpenCodeModel] = useState(""); - - const isEditing = editingAgent !== null; - const title = isEditing ? m.agents_editor_title_edit() : m.agents_editor_title_add(); - const currentType = isEditing ? editingAgent.type : defaultType; - const toolMeta = AGENT_TOOL_META[currentType]; - - // 打开对话框时初始化表单 - useEffect(() => { - if (open) { - setApiKeyVisible(false); - - if (editingAgent) { - // 编辑模式:从现有配置初始化 - setName(editingAgent.name); - const settings = editingAgent.settings; - setApiKey(settings.apiKey); - setBaseUrl(settings.baseUrl); - - // 根据类型初始化特有字段 - if (editingAgent.type === "claude") { - const s = settings as AgentSettingsMap["claude"]; - setClaudeModel(s.model ?? ""); - setHaikuModel(s.haikuModel ?? ""); - setSonnetModel(s.sonnetModel ?? ""); - setOpusModel(s.opusModel ?? ""); - } else if (editingAgent.type === "codex") { - const s = settings as AgentSettingsMap["codex"]; - setCodexModel(s.model ?? ""); - setReasoningEffort(s.reasoningEffort ?? "medium"); - } else if (editingAgent.type === "opencode") { - const s = settings as AgentSettingsMap["opencode"]; - setOpenCodeProvider(s.provider ?? "openai"); - setOpenCodeModel(s.model ?? ""); - } - } else { - // 添加模式:使用默认值 - setName(""); - const defaultSettings = createDefaultSettings(defaultType); - setApiKey(defaultSettings.apiKey); - setBaseUrl(defaultSettings.baseUrl); - - // 重置特有字段 - setClaudeModel(""); - setHaikuModel(""); - setSonnetModel(""); - setOpusModel(""); - setCodexModel(""); - setReasoningEffort("medium"); - setOpenCodeProvider("openai"); - setOpenCodeModel(""); - } - } - }, [open, editingAgent, defaultType]); - - // 构建 settings 对象 - const buildSettings = useCallback((): AgentSettingsMap[typeof currentType] => { - const base = { apiKey, baseUrl }; - - switch (currentType) { - case "claude": - return { - ...base, - ...(claudeModel && { model: claudeModel }), - ...(haikuModel && { haikuModel }), - ...(sonnetModel && { sonnetModel }), - ...(opusModel && { opusModel }), - } as AgentSettingsMap["claude"]; - case "codex": - return { - ...base, - ...(codexModel && { model: codexModel }), - reasoningEffort, - } as AgentSettingsMap["codex"]; - case "opencode": - return { - ...base, - provider: openCodeProvider, - ...(openCodeModel && { model: openCodeModel }), - } as AgentSettingsMap["opencode"]; - default: - return base as AgentSettingsMap[typeof currentType]; - } - }, [ - currentType, - apiKey, - baseUrl, - claudeModel, - haikuModel, - sonnetModel, - opusModel, - codexModel, - reasoningEffort, - openCodeProvider, - openCodeModel, - ]); - - const handleSave = useCallback(() => { - if (!name.trim()) return; - onSave({ - name: name.trim(), - type: currentType, - settings: buildSettings(), - }); - onOpenChange(false); - }, [name, currentType, buildSettings, onSave, onOpenChange]); - - const canSave = useMemo(() => { - // 名称必填 - if (!name.trim()) return false; - // API Key 必填 - if (!apiKey.trim()) return false; - // Base URL 必填 - if (!baseUrl.trim()) return false; - return true; - }, [name, apiKey, baseUrl]); - - return ( - - - - {title} - - {toolMeta.label()} - {toolMeta.description()} - - - -
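- {/* Note on buildSettings above: optional fields are dropped via conditional spreads, so an
-    empty input contributes no key at all. A self-contained sketch of the pattern (hypothetical
-    helper, not part of this patch):
-
-      function buildClaudeSettings(apiKey: string, baseUrl: string, model: string) {
-        // "" is falsy, so `...(model && { model })` spreads nothing when model is empty.
-        return { apiKey, baseUrl, ...(model && { model }) };
-      }
- */}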
- {/* 基础信息 */} -
-
- - setName(e.target.value)} - placeholder={m.agents_editor_name_placeholder()} - autoFocus - /> -
-
- - - - {/* 连接配置 */} -
-

{m.agents_editor_connection()}

- -
- - setApiKey(e.target.value)} - visible={apiKeyVisible} - onVisibilityChange={() => setApiKeyVisible((v) => !v)} - placeholder={m.agents_editor_api_key_placeholder()} - /> -
- -
- - setBaseUrl(e.target.value)} - placeholder={m.agents_editor_base_url_placeholder()} - /> -
-
- - - - {/* 模型配置 - 根据工具类型动态渲染 */} -
-

{m.agents_editor_model_config()}

- - {currentType === "claude" && ( - - )} - - {currentType === "codex" && ( - - )} - - {currentType === "opencode" && ( - - )} -
-
- - - - - -
-
- ); -} - -// ============================================================================ -// 子组件:各工具类型的模型配置字段 -// ============================================================================ - -type ClaudeModelFieldsProps = { - model: string; - onModelChange: (value: string) => void; - haikuModel: string; - onHaikuModelChange: (value: string) => void; - sonnetModel: string; - onSonnetModelChange: (value: string) => void; - opusModel: string; - onOpusModelChange: (value: string) => void; -}; - -function ClaudeModelFields({ - model, - onModelChange, - haikuModel, - onHaikuModelChange, - sonnetModel, - onSonnetModelChange, - opusModel, - onOpusModelChange, -}: ClaudeModelFieldsProps) { - return ( -
-
- - onModelChange(e.target.value)} - placeholder="claude-sonnet-4-20250514" - /> -
-
- - onHaikuModelChange(e.target.value)} - placeholder="claude-haiku-3-5-20241022" - /> -
-
- - onSonnetModelChange(e.target.value)} - placeholder="claude-sonnet-4-20250514" - /> -
-
- - onOpusModelChange(e.target.value)} - placeholder="claude-opus-4-20250514" - /> -
-
- ); -} - -type CodexModelFieldsProps = { - model: string; - onModelChange: (value: string) => void; - reasoningEffort: CodexReasoningEffort; - onReasoningEffortChange: (value: CodexReasoningEffort) => void; -}; - -function CodexModelFields({ - model, - onModelChange, - reasoningEffort, - onReasoningEffortChange, -}: CodexModelFieldsProps) { - return ( -
-
- - onModelChange(e.target.value)} - placeholder="o3" - /> -
-
- - -
-
- ); -} - -type OpenCodeModelFieldsProps = { - provider: OpenCodeProvider; - onProviderChange: (value: OpenCodeProvider) => void; - model: string; - onModelChange: (value: string) => void; -}; - -function OpenCodeModelFields({ - provider, - onProviderChange, - model, - onModelChange, -}: OpenCodeModelFieldsProps) { - return ( -
-
- - -
-
- - onModelChange(e.target.value)} - placeholder="gpt-4o" - /> -
-
- ); -} diff --git a/src/features/agents/AgentsPanel.tsx b/src/features/agents/AgentsPanel.tsx deleted file mode 100644 index a1a5a0c..0000000 --- a/src/features/agents/AgentsPanel.tsx +++ /dev/null @@ -1,124 +0,0 @@ -import { useCallback, useEffect, useState } from "react"; -import { Bot } from "lucide-react"; - -import { m } from "@/paraglide/messages.js"; - -import type { AgentConfig, AgentConfigForm, AgentToolType } from "./types"; -import { useAgentStore } from "./agent-store"; -import { AgentCard } from "./AgentCard"; -import { AgentEditorDialog } from "./AgentEditorDialog"; -import { AgentDeleteDialog } from "./AgentDeleteDialog"; - -type AgentsPanelProps = { - /** 当前选中的工具类型 */ - selectedTool: AgentToolType; - /** 添加对话框触发器(每次变化时打开添加对话框) */ - addTrigger: number; -}; - -export function AgentsPanel({ selectedTool, addTrigger }: AgentsPanelProps) { - const store = useAgentStore(); - - // 编辑对话框状态 - const [editorOpen, setEditorOpen] = useState(false); - const [editingAgent, setEditingAgent] = useState(null); - - // 删除对话框状态 - const [deleteOpen, setDeleteOpen] = useState(false); - const [deletingAgent, setDeletingAgent] = useState(null); - - // 根据 selectedTool 过滤 agents - const filteredAgents = store.agents.filter((agent) => agent.type === selectedTool); - - // 监听 addTrigger 变化,打开添加对话框 - useEffect(() => { - if (addTrigger > 0) { - setEditingAgent(null); - setEditorOpen(true); - } - }, [addTrigger]); - - // 打开编辑对话框 - const handleEdit = useCallback((agent: AgentConfig) => { - setEditingAgent(agent); - setEditorOpen(true); - }, []); - - // 保存(添加或更新) - const handleSave = useCallback( - (form: AgentConfigForm) => { - if (editingAgent) { - store.updateAgent(editingAgent.id, form); - } else { - // 添加时使用当前选中的工具类型 - store.addAgent({ ...form, type: selectedTool }); - } - }, - [editingAgent, store, selectedTool] - ); - - // 打开删除确认对话框 - const handleDeleteClick = useCallback((id: string) => { - const agent = store.agents.find((a) => a.id === id); - if (agent) { - setDeletingAgent(agent); - setDeleteOpen(true); - } - }, [store.agents]); - - // 确认删除 - const handleDeleteConfirm = useCallback(() => { - if (deletingAgent) { - store.deleteAgent(deletingAgent.id); - setDeleteOpen(false); - setDeletingAgent(null); - } - }, [deletingAgent, store]); - - const isEmpty = filteredAgents.length === 0; - - return ( -
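- {/* The addTrigger prop above is a counter used as a one-shot event channel: the parent
-    increments it, and the effect opens the add dialog on every change (0 means "never fired").
-    A minimal sketch of the pattern, assuming standard React semantics (hypothetical helper):
-
-      import { useEffect } from "react";
-
-      function useOpenOnTrigger(trigger: number, open: () => void) {
-        useEffect(() => {
-          if (trigger > 0) open();
-        }, [trigger, open]);
-      }
- */}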
- {isEmpty ? ( - // Empty state -
-
- -
-

{m.agents_empty()}

-

{m.agents_empty_hint()}

-
- ) : ( - // Agent list -
- {filteredAgents.map((agent) => ( - - ))} -
- )} - - {/* Editor dialog */} - - - {/* Delete confirmation dialog */} - -
- ); -} diff --git a/src/features/agents/agent-store.ts b/src/features/agents/agent-store.ts deleted file mode 100644 index 5c3b362..0000000 --- a/src/features/agents/agent-store.ts +++ /dev/null @@ -1,222 +0,0 @@ -import { useCallback, useMemo, useState } from "react"; - -import type { - AgentConfig, - AgentConfigForm, - AgentSettingsMap, - AgentToolType, - ClaudeSettings, - CodexSettings, -} from "./types"; - -/** - * 生成唯一 ID - */ -function generateId() { - return `agent_${Date.now()}_${Math.random().toString(36).slice(2, 9)}`; -} - -/** - * 各工具类型的默认 Base URL - */ -const DEFAULT_BASE_URLS: Record = { - claude: "https://api.anthropic.com", - codex: "https://api.openai.com/v1", - opencode: "https://api.openai.com/v1", -}; - -/** - * 创建默认的 Agent Settings(导出供外部使用) - */ -export function createDefaultSettings(type: T): AgentSettingsMap[T] { - const base = { apiKey: "", baseUrl: DEFAULT_BASE_URLS[type] }; - - switch (type) { - case "claude": - return { ...base } as AgentSettingsMap[T]; - case "codex": - return { ...base, reasoningEffort: "medium" } as AgentSettingsMap[T]; - case "opencode": - return { ...base, provider: "openai" } as AgentSettingsMap[T]; - default: - return base as AgentSettingsMap[T]; - } -} - -/** - * 创建新的 Agent 配置 - */ -function createAgentConfig( - form: AgentConfigForm, - sortIndex: number -): AgentConfig { - const now = new Date().toISOString(); - return { - id: generateId(), - name: form.name, - type: form.type, - settings: form.settings, - isActive: false, - sortIndex, - createdAt: now, - updatedAt: now, - }; -} - -// Mock 初始数据 -const INITIAL_AGENTS: AgentConfig[] = [ - { - id: "agent_default_claude", - name: "Claude Code (Default)", - type: "claude", - settings: { - apiKey: "", - baseUrl: "https://api.anthropic.com", - } as ClaudeSettings, - isActive: true, - sortIndex: 0, - createdAt: "2025-01-01T00:00:00.000Z", - updatedAt: "2025-01-01T00:00:00.000Z", - }, - { - id: "agent_default_codex", - name: "Codex CLI", - type: "codex", - settings: { - apiKey: "", - baseUrl: "https://api.openai.com/v1", - reasoningEffort: "medium", - } as CodexSettings, - isActive: false, - sortIndex: 1, - createdAt: "2025-01-01T00:00:00.000Z", - updatedAt: "2025-01-01T00:00:00.000Z", - }, -]; - -/** - * Agent 配置状态管理 Hook - */ -export function useAgentStore() { - const [agents, setAgents] = useState(INITIAL_AGENTS); - - // 按 sortIndex 排序的 agents - const sortedAgents = useMemo( - () => [...agents].sort((a, b) => a.sortIndex - b.sortIndex), - [agents] - ); - - // 当前激活的 agent - const activeAgent = useMemo( - () => agents.find((agent) => agent.isActive) ?? null, - [agents] - ); - - // 按类型分组 - const agentsByType = useMemo(() => { - const grouped: Record = { - claude: [], - codex: [], - opencode: [], - }; - for (const agent of sortedAgents) { - grouped[agent.type].push(agent); - } - return grouped; - }, [sortedAgents]); - - // 添加 agent - const addAgent = useCallback((form: AgentConfigForm) => { - setAgents((prev) => { - const maxSortIndex = prev.reduce((max, a) => Math.max(max, a.sortIndex), -1); - const newAgent = createAgentConfig(form, maxSortIndex + 1); - // 如果是第一个该类型的 agent,自动激活 - const hasActiveOfType = prev.some((a) => a.type === form.type && a.isActive); - if (!hasActiveOfType) { - newAgent.isActive = true; - } - return [...prev, newAgent]; - }); - }, []); - - // 更新 agent(包括 settings) - const updateAgent = useCallback((id: string, form: AgentConfigForm) => { - setAgents((prev) => - prev.map((agent) => - agent.id === id - ? 
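- // Store invariant worth noting: at most one agent per tool type is active. addAgent (above)
- // auto-activates the first agent of a type, and switchAgent (below) re-activates within the
- // same type only, i.e.:
- //   agents.map((a) => (a.type === target.type ? { ...a, isActive: a.id === id } : a))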
{ - ...agent, - name: form.name, - type: form.type, - settings: form.settings, - updatedAt: new Date().toISOString(), - } - : agent - ) - ); - }, []); - - // 删除 agent - const deleteAgent = useCallback((id: string) => { - setAgents((prev) => { - const target = prev.find((a) => a.id === id); - if (!target) return prev; - - const remaining = prev.filter((a) => a.id !== id); - - // 如果删除的是激活的 agent,激活同类型的第一个 - if (target.isActive) { - const sameType = remaining.filter((a) => a.type === target.type); - if (sameType.length > 0) { - const firstOfType = sameType.sort((a, b) => a.sortIndex - b.sortIndex)[0]; - return remaining.map((a) => - a.id === firstOfType.id ? { ...a, isActive: true } : a - ); - } - } - - return remaining; - }); - }, []); - - // 切换激活状态(同类型只能有一个激活) - const switchAgent = useCallback((id: string) => { - setAgents((prev) => { - const target = prev.find((a) => a.id === id); - if (!target || target.isActive) return prev; - - return prev.map((agent) => { - // 同类型的其他 agent 取消激活 - if (agent.type === target.type) { - return { ...agent, isActive: agent.id === id }; - } - return agent; - }); - }); - }, []); - - // 更新排序 - const reorderAgents = useCallback((reorderedIds: string[]) => { - setAgents((prev) => { - const idToAgent = new Map(prev.map((a) => [a.id, a])); - return reorderedIds.map((id, index) => { - const agent = idToAgent.get(id); - if (!agent) throw new Error(`Agent not found: ${id}`); - return { ...agent, sortIndex: index }; - }); - }); - }, []); - - return { - agents: sortedAgents, - activeAgent, - agentsByType, - addAgent, - updateAgent, - deleteAgent, - switchAgent, - reorderAgents, - }; -} - -export type AgentStore = ReturnType; diff --git a/src/features/agents/constants.ts b/src/features/agents/constants.ts deleted file mode 100644 index 392fe22..0000000 --- a/src/features/agents/constants.ts +++ /dev/null @@ -1,29 +0,0 @@ -import { m } from "@/paraglide/messages.js"; - -import type { AgentToolMeta, AgentToolType } from "./types"; - -/** - * Agent 工具元数据 - */ -export const AGENT_TOOL_META: Record = { - claude: { - type: "claude", - label: () => m.agents_tool_claude(), - description: () => m.agents_tool_claude_desc(), - }, - codex: { - type: "codex", - label: () => m.agents_tool_codex(), - description: () => m.agents_tool_codex_desc(), - }, - opencode: { - type: "opencode", - label: () => m.agents_tool_opencode(), - description: () => m.agents_tool_opencode_desc(), - }, -}; - -/** - * Agent 工具类型列表(用于下拉选择) - */ -export const AGENT_TOOL_OPTIONS: AgentToolType[] = ["claude", "codex", "opencode"]; diff --git a/src/features/agents/index.ts b/src/features/agents/index.ts deleted file mode 100644 index 9672f1c..0000000 --- a/src/features/agents/index.ts +++ /dev/null @@ -1,19 +0,0 @@ -export { AgentsPanel } from "./AgentsPanel"; -export { AgentCard } from "./AgentCard"; -export { AgentEditorDialog } from "./AgentEditorDialog"; -export { AgentDeleteDialog } from "./AgentDeleteDialog"; -export { useAgentStore, createDefaultSettings } from "./agent-store"; -export { AGENT_TOOL_META, AGENT_TOOL_OPTIONS } from "./constants"; -export type { - AgentConfig, - AgentConfigForm, - AgentToolType, - AgentToolMeta, - AgentSettings, - AgentSettingsMap, - ClaudeSettings, - CodexSettings, - OpenCodeSettings, - CodexReasoningEffort, - OpenCodeProvider, -} from "./types"; diff --git a/src/features/agents/types.ts b/src/features/agents/types.ts deleted file mode 100644 index e07e484..0000000 --- a/src/features/agents/types.ts +++ /dev/null @@ -1,129 +0,0 @@ -/** - * Agent 工具类型 - */ -export type 
AgentToolType = "claude" | "codex" | "opencode"; - -// ============================================================================ -// Agent Settings 类型定义 -// ============================================================================ - -/** - * 通用连接配置 - */ -type BaseAgentSettings = { - /** API 密钥 */ - apiKey: string; - /** API 基础 URL */ - baseUrl: string; -}; - -/** - * Claude 特有配置 - */ -export type ClaudeSettings = BaseAgentSettings & { - /** 主模型(用于一般任务) */ - model?: string; - /** Haiku 模型(快速轻量任务) */ - haikuModel?: string; - /** Sonnet 模型(平衡任务) */ - sonnetModel?: string; - /** Opus 模型(复杂任务) */ - opusModel?: string; -}; - -/** - * Codex 推理强度 - */ -export type CodexReasoningEffort = "low" | "medium" | "high"; - -/** - * Codex 特有配置 - */ -export type CodexSettings = BaseAgentSettings & { - /** 模型名称 */ - model?: string; - /** 推理强度 */ - reasoningEffort?: CodexReasoningEffort; -}; - -/** - * OpenCode Provider 类型 - */ -export type OpenCodeProvider = "openai" | "anthropic" | "gemini" | "openrouter" | "custom"; - -/** - * OpenCode 特有配置 - */ -export type OpenCodeSettings = BaseAgentSettings & { - /** Provider 类型 */ - provider?: OpenCodeProvider; - /** 模型名称 */ - model?: string; -}; - -/** - * Agent 工具类型到配置的映射 - */ -export type AgentSettingsMap = { - claude: ClaudeSettings; - codex: CodexSettings; - opencode: OpenCodeSettings; -}; - -/** - * 所有 Agent Settings 的联合类型 - */ -export type AgentSettings = AgentSettingsMap[AgentToolType]; - -// ============================================================================ -// Agent Config 类型定义 -// ============================================================================ - -/** - * Agent 配置项 - */ -export type AgentConfig = { - /** 唯一标识符 */ - id: string; - /** 配置名称 */ - name: string; - /** Agent 工具类型 */ - type: T; - /** 工具特定配置 */ - settings: AgentSettingsMap[T]; - /** 是否为当前激活配置 */ - isActive: boolean; - /** 排序索引(越小越靠前) */ - sortIndex: number; - /** 创建时间 ISO 字符串 */ - createdAt: string; - /** 更新时间 ISO 字符串 */ - updatedAt: string; -}; - -/** - * Agent 配置表单(用于添加/编辑) - */ -export type AgentConfigForm = { - name: string; - type: T; - settings: AgentSettingsMap[T]; -}; - -/** - * Agent 工具元数据 - */ -export type AgentToolMeta = { - type: AgentToolType; - label: () => string; - description: () => string; -}; - -// ============================================================================ -// 工具函数类型 -// ============================================================================ - -/** - * 创建默认 settings 的工厂函数类型 - */ -export type CreateDefaultSettings = (type: T) => AgentSettingsMap[T]; diff --git a/src/features/config/AppView.tsx b/src/features/config/AppView.tsx index 4727fab..0adcc93 100644 --- a/src/features/config/AppView.tsx +++ b/src/features/config/AppView.tsx @@ -1,16 +1,9 @@ -import { AlertCircle, Loader2, Plus, RefreshCw } from "lucide-react"; -import { useCallback, useMemo, useState, type CSSProperties } from "react"; +import { AlertCircle, Loader2, RefreshCw } from "lucide-react"; +import { useMemo, type CSSProperties } from "react"; import { AppSidebar } from "@/components/app-sidebar"; import { SiteHeader } from "@/components/site-header"; import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; import { AlertDialog, AlertDialogAction, @@ -26,6 +19,7 @@ import { Button } from "@/components/ui/button"; import { ScrollArea } from "@/components/ui/scroll-area"; import { SidebarInset, SidebarProvider } from 
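// For reference, the deleted types above keyed settings by tool type (AgentSettingsMap), so a
// consumer could narrow per tool. A sketch against those (now removed) types, for illustration only:
//
//   function summarize(cfg: AgentConfig): string {
//     switch (cfg.type) {
//       case "claude":
//         return (cfg.settings as AgentSettingsMap["claude"]).model ?? "claude default";
//       case "codex": {
//         const s = cfg.settings as AgentSettingsMap["codex"];
//         return `${s.model ?? "default"} (${s.reasoningEffort ?? "medium"})`;
//       }
//       default:
//         return cfg.settings.baseUrl;
//     }
//   }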
"@/components/ui/sidebar"; import { + ClientSetupCard, ConfigFileCard, AutoStartCard, ProjectLinksCard, @@ -47,15 +41,11 @@ import type { ProxyServiceRequestState, ProxyServiceStatus, } from "@/features/config/types"; -import { AgentsPanel } from "@/features/agents"; import { DashboardPanel } from "@/features/dashboard/DashboardPanel"; import { LogsPanel } from "@/features/logs/LogsPanel"; import { ProvidersPanel } from "@/features/providers/ProvidersPanel"; import { m } from "@/paraglide/messages.js"; -// Agent 工具类型(用于 agents 页面工具筛选与创建) -export type AgentToolId = "claude" | "codex" | "opencode"; - type AppViewProps = { activeSectionId: ConfigSectionId; form: ConfigForm; @@ -103,10 +93,6 @@ type ConfigToolbarProps = { isDirty: boolean; onReload: () => void; onSave: () => void; - // agents 页面专用 - selectedTool?: AgentToolId; - onToolChange?: (tool: AgentToolId) => void; - onAddAgent?: () => void; }; function ConfigToolbar({ @@ -116,42 +102,24 @@ function ConfigToolbar({ isDirty, onReload, onSave, - selectedTool, - onToolChange, - onAddAgent, }: ConfigToolbarProps) { const isLoading = status === "loading"; const isSaving = status === "saving"; const canReload = !isLoading && !isSaving; - const isAgentsSection = section.id === "agents"; return (
- {/* agents 页面显示工具选择器,其他页面显示标题 */} - {isAgentsSection && selectedTool && onToolChange ? ( - - ) : ( -
-

- {section.label()} -

-

- {section.description()} -

-
- )} +
+

+ {section.label()} +

+

+ {section.description()} +

+
{isDirty ? ( @@ -194,21 +162,13 @@ function ConfigToolbar({ {m.common_refresh()} )} - {/* agents 页面显示添加按钮,其他页面显示保存按钮 */} - {isAgentsSection && onAddAgent ? ( - - ) : ( - - )} +
); @@ -238,16 +198,11 @@ type ConfigSectionContentProps = Omit & { proxyService: ProxyServiceViewProps; }; -type ConfigSectionBodyProps = ConfigSectionContentProps & { - selectedTool: AgentToolId; - agentEditorTrigger: number; -}; +type ConfigSectionBodyProps = ConfigSectionContentProps; function ConfigSectionBody({ activeSectionId, proxyService, - selectedTool, - agentEditorTrigger, ...props }: ConfigSectionBodyProps) { switch (activeSectionId) { @@ -304,7 +259,7 @@ function ConfigSectionBody({ case "agents": return (
- +
); default: @@ -317,14 +272,6 @@ function ConfigSectionContent({ proxyService, ...props }: ConfigSectionContentProps) { - // agents 页面的工具选择状态 - const [selectedTool, setSelectedTool] = useState("claude"); - // agents 页面的添加对话框触发器 - const [agentEditorTrigger, setAgentEditorTrigger] = useState(0); - const handleAddAgent = useCallback(() => { - setAgentEditorTrigger((prev) => prev + 1); - }, []); - if (activeSectionId === "dashboard") { return ; } @@ -344,17 +291,12 @@ function ConfigSectionContent({ isDirty={props.isDirty} onReload={props.onReload} onSave={props.onSave} - selectedTool={selectedTool} - onToolChange={setSelectedTool} - onAddAgent={handleAddAgent} /> ); diff --git a/src/features/config/cards/client-setup-card.tsx b/src/features/config/cards/client-setup-card.tsx index 2a95e81..ab82a88 100644 --- a/src/features/config/cards/client-setup-card.tsx +++ b/src/features/config/cards/client-setup-card.tsx @@ -1,11 +1,13 @@ import type { ReactNode } from "react"; -import type { AgentToolId } from "@/features/config/AppView"; +import { m } from "@/paraglide/messages.js"; import { + ClientSetupOverviewCard, PlaintextWarning, + SummaryItem, ToolDetailsFallback, - ToolSetupPanel, + ToolSetupDialog, } from "./client-setup-ui"; import { useClientSetupPreview, @@ -23,11 +25,13 @@ import { type ClientSetupCardProps = { savedAt: string; isDirty: boolean; - selectedTool: AgentToolId; }; -type ToolPanelItem = { +type ToolListItem = { id: string; + title: string; + description: string; + summary: ReactNode; content: ReactNode; action: ActionState; canApply: boolean; @@ -56,9 +60,17 @@ function buildClaudeTool({ isWorking, action, onApply, -}: ToolBuildBaseArgs & ToolBuildActionArgs): ToolPanelItem { +}: ToolBuildBaseArgs & ToolBuildActionArgs) { return { id: "claude", + title: m.client_setup_claude_title(), + description: m.client_setup_claude_desc(), + summary: ( + + ), content: setup ? ( ) : ( @@ -68,7 +80,7 @@ function buildClaudeTool({ canApply: Boolean(setup) && canApply, isWorking, onApply, - }; + } satisfies ToolListItem; } function buildCodexTool({ @@ -79,9 +91,17 @@ function buildCodexTool({ isWorking, action, onApply, -}: ToolBuildBaseArgs & ToolBuildActionArgs): ToolPanelItem { +}: ToolBuildBaseArgs & ToolBuildActionArgs) { return { id: "codex", + title: m.client_setup_codex_title(), + description: m.client_setup_codex_desc(), + summary: ( + + ), content: setup ? ( ) : ( @@ -91,7 +111,7 @@ function buildCodexTool({ canApply: Boolean(setup) && canApply, isWorking, onApply, - }; + } satisfies ToolListItem; } type OpenCodeToolArgs = ToolBuildBaseArgs & ToolBuildActionArgs & { @@ -104,12 +124,27 @@ function buildOpenCodeTool({ previewState, previewMessage, canApplyOpenCode, + openCodeModelCount, isWorking, action, onApply, -}: OpenCodeToolArgs): ToolPanelItem { +}: OpenCodeToolArgs) { return { id: "opencode", + title: m.client_setup_opencode_title(), + description: m.client_setup_opencode_desc(), + summary: ( +
+ + +
+ ), content: setup ? ( ) : ( @@ -119,10 +154,31 @@ function buildOpenCodeTool({ canApply: Boolean(setup) && canApplyOpenCode, isWorking, onApply, - }; + } satisfies ToolListItem; +} + +function ToolCards({ tools }: { tools: readonly ToolListItem[] }) { + return ( + <> + {tools.map((tool) => ( + + {tool.content} + + ))} + + ); } -export function ClientSetupCard({ savedAt, isDirty, selectedTool }: ClientSetupCardProps) { +export function ClientSetupCard({ savedAt, isDirty }: ClientSetupCardProps) { const canApply = !isDirty; const { previewState, previewMessage, setup, loadPreview } = useClientSetupPreview(savedAt); @@ -147,31 +203,29 @@ export function ClientSetupCard({ savedAt, isDirty, selectedTool }: ClientSetupC isWorking, }; - // 根据 selectedTool 构建对应的工具面板 - const toolBuilders: Record ToolPanelItem> = { - claude: () => buildClaudeTool({ ...baseArgs, action: claude.action, onApply: claude.apply }), - codex: () => buildCodexTool({ ...baseArgs, action: codex.action, onApply: codex.apply }), - opencode: () => buildOpenCodeTool({ + const tools: ToolListItem[] = [ + buildClaudeTool({ ...baseArgs, action: claude.action, onApply: claude.apply }), + buildCodexTool({ ...baseArgs, action: codex.action, onApply: codex.apply }), + buildOpenCodeTool({ ...baseArgs, action: opencode.action, onApply: opencode.apply, openCodeModelCount, canApplyOpenCode, }), - }; - - const selectedToolItem = toolBuilders[selectedTool](); + ]; return ( <> - - {selectedToolItem.content} - + + ); diff --git a/src/features/config/cards/client-setup-ui.tsx b/src/features/config/cards/client-setup-ui.tsx index eb38e2f..7c5ea08 100644 --- a/src/features/config/cards/client-setup-ui.tsx +++ b/src/features/config/cards/client-setup-ui.tsx @@ -2,13 +2,26 @@ import type { ReactNode } from "react"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; -import { Card, CardContent } from "@/components/ui/card"; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; +import { + Dialog, + DialogBody, + DialogClose, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger, +} from "@/components/ui/dialog"; import { m } from "@/paraglide/messages.js"; -import type { ActionState, RequestState } from "./client-setup-state"; +import type { ActionState, ClientSetupInfo, RequestState } from "./client-setup-state"; -// 内联展示工具配置的 props -type ToolSetupPanelProps = { +type ToolSetupDialogProps = { + title: string; + description: string; + summary: ReactNode; action: ActionState; canApply: boolean; isWorking: boolean; @@ -93,43 +106,105 @@ export function CodeBlock({ lines }: { lines: readonly string[] }) { ); } -/** 内联展示工具配置面板(无弹窗) */ -export function ToolSetupPanel({ +type ToolSetupCardProps = Pick; + +function ToolSetupCard({ title, description, summary, action }: ToolSetupCardProps) { + return ( + + +
+
+ {title} + {description} +
+ {shouldShowBadge(action.state) ? ( + {toBadgeLabel(action.state)} + ) : null} +
+
+ + {summary} + + + + +
+ ); +} + +type ToolSetupModalProps = Omit; + +function ToolSetupModal({ + title, + description, action, canApply, isWorking, onApply, children, -}: ToolSetupPanelProps) { +}: ToolSetupModalProps) { return ( - - - {/* 详细配置内容 */} + + +
+
+ {title} + {description} +
+ {shouldShowBadge(action.state) ? ( + {toBadgeLabel(action.state)} + ) : null} +
+
+ + {children} - {/* 操作状态消息 */} {action.message ? (
{action.message}
) : null} - {/* 备份提示 */}

{m.client_setup_backup_hint()}

+
- {/* 底部操作栏 */} -
-
- {shouldShowBadge(action.state) ? ( - {toBadgeLabel(action.state)} - ) : null} -
- -
-
-
+ + + + + ); +} + +export function ToolSetupDialog(props: ToolSetupDialogProps) { + return ( + + + + {props.children} + + ); } @@ -175,3 +250,63 @@ export function PlaintextWarning() { ); } + +type ClientSetupOverviewCardProps = { + previewState: RequestState; + previewMessage: string; + setup: ClientSetupInfo | null; + isDirty: boolean; + isWorking: boolean; + onRefresh: () => void; +}; + +export function ClientSetupOverviewCard({ + previewState, + previewMessage, + setup, + isDirty, + isWorking, + onRefresh, +}: ClientSetupOverviewCardProps) { + return ( + + + {m.client_setup_title()} + {m.client_setup_desc()} + + +
+ {shouldShowBadge(previewState) ? ( + {toBadgeLabel(previewState)} + ) : null} + +
+ + {previewMessage ? ( +
+ {previewMessage} +
+ ) : null} + + {isDirty ? ( +
+ {m.client_setup_dirty_notice()} +
+ ) : null} + + {setup ? ( +
+

+ {m.client_setup_proxy_base_url_label()} +

+

+ {setup.proxy_http_base_url} +

+
+ ) : null} +
+
+ ); +} diff --git a/src/features/config/cards/upstreams-card.tsx b/src/features/config/cards/upstreams-card.tsx index 1cb78ae..76e7e99 100644 --- a/src/features/config/cards/upstreams-card.tsx +++ b/src/features/config/cards/upstreams-card.tsx @@ -9,6 +9,21 @@ import { mergeProviderOptions, UPSTREAM_COLUMNS, } from "@/features/config/cards/upstreams/constants"; +import { + cloneUpstreamDraft, + coerceProviderSelection, + createAutoUpstreamId, + createCopiedUpstreamId, + findIdleAntigravityAccount, + findIdleCodexAccount, + findIdleKiroAccount, + hasProvider, + normalizeProviders, + pruneConvertFromMap, + providersEqual, + resolveUpstreamIdForProviderChange, + stripJsonSuffix, +} from "@/features/config/cards/upstreams/upstream-editor-helpers"; import { ColumnsDialog } from "@/features/config/cards/upstreams/columns-dialog"; import { DeleteUpstreamDialog } from "@/features/config/cards/upstreams/delete-dialog"; import { UpstreamEditorDialog } from "@/features/config/cards/upstreams/editor-dialog"; @@ -19,13 +34,9 @@ import type { UpstreamEditorState, } from "@/features/config/cards/upstreams/types"; import { createEmptyUpstream } from "@/features/config/form"; -import { createNativeInboundFormatSet, removeInboundFormatsInSet } from "@/features/config/inbound-formats"; import { useCodexAccounts } from "@/features/codex/use-codex-accounts"; -import type { CodexAccountSummary } from "@/features/codex/types"; import { useKiroAccounts } from "@/features/kiro/use-kiro-accounts"; -import type { KiroAccountSummary } from "@/features/kiro/types"; import { useAntigravityAccounts } from "@/features/antigravity/use-antigravity-accounts"; -import type { AntigravityAccountSummary } from "@/features/antigravity/types"; import type { UpstreamForm, UpstreamStrategy } from "@/features/config/types"; import { m } from "@/paraglide/messages.js"; @@ -42,223 +53,6 @@ type UpstreamsCardProps = { onChange: (index: number, patch: Partial) => void; }; -function createCopiedUpstreamId(sourceId: string, upstreams: readonly UpstreamForm[]) { - const base = sourceId.trim() || "upstream"; - const taken = new Set( - upstreams - .map((upstream) => upstream.id.trim()) - .filter((id) => id), - ); - - const prefix = `${base}-copy`; - if (!taken.has(prefix)) { - return prefix; - } - - let suffix = 2; - while (taken.has(`${prefix}-${suffix}`)) { - suffix += 1; - } - return `${prefix}-${suffix}`; -} - -/** - * 基于 providers 自动生成唯一 ID - * - 单 provider:openai-1, openai-2 - * - 多 provider:仍以第一个 provider 作为前缀(避免 id 频繁变化) - */ -function createAutoUpstreamId( - providers: readonly string[], - upstreams: readonly UpstreamForm[], - editingIndex?: number, -) { - const base = providers[0]?.trim() || "upstream"; - const taken = new Set( - upstreams - .filter((_, index) => index !== editingIndex) - .map((upstream) => upstream.id.trim()) - .filter((id) => id), - ); - - // 先尝试 provider-1 - let suffix = 1; - while (taken.has(`${base}-${suffix}`)) { - suffix += 1; - } - return `${base}-${suffix}`; -} - -function normalizeProviders(values: readonly string[]) { - const output: string[] = []; - const seen = new Set(); - for (const value of values) { - const trimmed = value.trim(); - if (!trimmed) { - continue; - } - if (seen.has(trimmed)) { - continue; - } - seen.add(trimmed); - output.push(trimmed); - } - return output; -} - -function providersEqual(left: readonly string[], right: readonly string[]) { - if (left.length !== right.length) { - return false; - } - for (let index = 0; index < left.length; index += 1) { - if (left[index] !== 
right[index]) { - return false; - } - } - return true; -} - -function coerceProviderSelection(next: readonly string[]) { - const normalized = normalizeProviders(next); - const special = normalized.find((provider) => - provider === "kiro" || provider === "codex" || provider === "antigravity", - ); - if (!special) { - return normalized; - } - return [special]; -} - -function hasProvider(upstream: UpstreamForm, provider: string) { - return upstream.providers.some((value) => value.trim() === provider); -} - -function pruneConvertFromMap( - map: UpstreamForm["convertFromMap"], - providers: readonly string[], -) { - if (!Object.keys(map).length) { - return map; - } - const providerSet = new Set(providers); - const nativeFormatsInUpstream = createNativeInboundFormatSet(providers); - const output: UpstreamForm["convertFromMap"] = {}; - for (const [provider, formats] of Object.entries(map)) { - if (!providerSet.has(provider)) { - continue; - } - const filtered = removeInboundFormatsInSet(formats, nativeFormatsInUpstream); - if (!filtered.length) { - continue; - } - output[provider] = filtered; - } - return output; -} - -/** - * 去除 account_id 的 .json 后缀,用于生成更简洁的 upstream ID - */ -function stripJsonSuffix(accountId: string): string { - return accountId.endsWith(".json") ? accountId.slice(0, -5) : accountId; -} - -/** - * 找到第一个未被其他上游使用的空闲 kiro 账户 - * 优先返回 active 状态的账户 - */ -function findIdleKiroAccount( - accounts: KiroAccountSummary[], - upstreams: readonly UpstreamForm[], - editingIndex?: number, -): KiroAccountSummary | undefined { - // 收集已被使用的 kiro account id - const usedAccountIds = new Set( - upstreams - .filter((upstream, index) => { - if (index === editingIndex) return false; - return hasProvider(upstream, "kiro") && upstream.kiroAccountId.trim(); - }) - .map((upstream) => upstream.kiroAccountId.trim()), - ); - - // 先找 active 状态的空闲账户 - const activeIdle = accounts.find( - (account) => account.status === "active" && !usedAccountIds.has(account.account_id), - ); - if (activeIdle) return activeIdle; - - // 如果没有 active 的,找任意空闲账户 - return accounts.find((account) => !usedAccountIds.has(account.account_id)); -} - -/** - * 找到第一个未被其他上游使用的空闲 codex 账户 - * 优先返回 active 状态的账户 - */ -function findIdleCodexAccount( - accounts: CodexAccountSummary[], - upstreams: readonly UpstreamForm[], - editingIndex?: number, -): CodexAccountSummary | undefined { - const usedAccountIds = new Set( - upstreams - .filter((upstream, index) => { - if (index === editingIndex) return false; - return hasProvider(upstream, "codex") && upstream.codexAccountId.trim(); - }) - .map((upstream) => upstream.codexAccountId.trim()), - ); - - const activeIdle = accounts.find( - (account) => account.status === "active" && !usedAccountIds.has(account.account_id), - ); - if (activeIdle) return activeIdle; - - return accounts.find((account) => !usedAccountIds.has(account.account_id)); -} - -/** - * 找到第一个未被其他上游使用的空闲 antigravity 账户 - * 优先返回 active 状态的账户 - */ -function findIdleAntigravityAccount( - accounts: AntigravityAccountSummary[], - upstreams: readonly UpstreamForm[], - editingIndex?: number, -): AntigravityAccountSummary | undefined { - const usedAccountIds = new Set( - upstreams - .filter((upstream, index) => { - if (index === editingIndex) return false; - return ( - hasProvider(upstream, "antigravity") && - upstream.antigravityAccountId.trim() - ); - }) - .map((upstream) => upstream.antigravityAccountId.trim()), - ); - - const activeIdle = accounts.find( - (account) => account.status === "active" && !usedAccountIds.has(account.account_id), - 
); - if (activeIdle) return activeIdle; - - return accounts.find((account) => !usedAccountIds.has(account.account_id)); -} - -function cloneUpstreamDraft(upstream: UpstreamForm): UpstreamForm { - const providers = normalizeProviders(upstream.providers); - return { - ...upstream, - // provider 必选:编辑/复制时也保证至少有一个 provider,避免 UI 出现“看起来有默认值但实际为空”的不同步体验 - providers: providers.length ? providers : ["openai"], - modelMappings: upstream.modelMappings.map((mapping) => ({ ...mapping })), - overrides: { - header: upstream.overrides.header.map((entry) => ({ ...entry })), - }, - }; -} - export function UpstreamsCard({ upstreams, appProxyUrl, @@ -333,10 +127,11 @@ export function UpstreamsCard({ const providersChanged = patch.providers !== undefined && !providersEqual(nextProviders, currentProviders); - const currentPrimary = currentProviders[0] ?? ""; const nextPrimary = nextProviders[0] ?? ""; - // 如果 provider 变化,自动生成新 ID 并处理账户绑定 + // 如果 provider 变化,处理账户绑定 + ID 自动逻辑: + // - 新增:根据 provider 自动生成 ID + // - 编辑:普通 provider 不自动改 ID(避免统计/引用被拆分),仅 kiro/codex/antigravity 会随账户同步 if (providersChanged) { let kiroAccountId = prev.draft.kiroAccountId; let codexAccountId = prev.draft.codexAccountId; @@ -390,18 +185,17 @@ export function UpstreamsCard({ convertFromMap = pruneConvertFromMap(convertFromMap, nextProviders); - const shouldAutoId = nextPrimary !== currentPrimary && !!nextPrimary; - const autoId = shouldAutoId - ? createAutoUpstreamId(nextProviders, upstreams, editingIndex) - : prev.draft.id; - const id = - nextPrimary === "kiro" && kiroAccountId - ? stripJsonSuffix(kiroAccountId) - : nextPrimary === "codex" && codexAccountId - ? stripJsonSuffix(codexAccountId) - : nextPrimary === "antigravity" && antigravityAccountId - ? stripJsonSuffix(antigravityAccountId) - : autoId; + const id = resolveUpstreamIdForProviderChange({ + mode: prev.mode, + currentId: prev.draft.id, + currentProviders, + nextProviders, + upstreams, + editingIndex, + kiroAccountId, + codexAccountId, + antigravityAccountId, + }); return { ...prev, diff --git a/src/features/config/cards/upstreams/upstream-editor-helpers.test.ts b/src/features/config/cards/upstreams/upstream-editor-helpers.test.ts new file mode 100644 index 0000000..6ebb222 --- /dev/null +++ b/src/features/config/cards/upstreams/upstream-editor-helpers.test.ts @@ -0,0 +1,111 @@ +import { describe, expect, it } from "vitest"; + +import { createEmptyUpstream } from "@/features/config/form"; +import { resolveUpstreamIdForProviderChange } from "@/features/config/cards/upstreams/upstream-editor-helpers"; + +describe("upstreams/upstream-editor-helpers", () => { + it("keeps id stable when editing and switching non-special provider", () => { + const upstream = createEmptyUpstream(); + upstream.id = "custom-1"; + upstream.providers = ["openai"]; + + const id = resolveUpstreamIdForProviderChange({ + mode: "edit", + currentId: upstream.id, + currentProviders: ["openai"], + nextProviders: ["gemini"], + upstreams: [upstream], + editingIndex: 0, + kiroAccountId: "", + codexAccountId: "", + antigravityAccountId: "", + }); + + expect(id).toBe("custom-1"); + }); + + it("updates id to account_id when editing and switching to kiro/codex/antigravity", () => { + const upstream = createEmptyUpstream(); + upstream.id = "custom-1"; + upstream.providers = ["openai"]; + + const kiroId = resolveUpstreamIdForProviderChange({ + mode: "edit", + currentId: upstream.id, + currentProviders: ["openai"], + nextProviders: ["kiro"], + upstreams: [upstream], + editingIndex: 0, + kiroAccountId: "foo.json", + 
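+      // In brief, the behavior these tests pin down (see resolveUpstreamIdForProviderChange below):
+      //   edit + regular provider switch          -> keep currentId ("custom-1" stays "custom-1")
+      //   edit + switch to kiro/codex/antigravity -> ID syncs to the bound account_id minus ".json"
+      //   create + provider switch                -> ID auto-generates from the new provider ("gemini-1")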
codexAccountId: "", + antigravityAccountId: "", + }); + expect(kiroId).toBe("foo"); + + const codexId = resolveUpstreamIdForProviderChange({ + mode: "edit", + currentId: upstream.id, + currentProviders: ["openai"], + nextProviders: ["codex"], + upstreams: [upstream], + editingIndex: 0, + kiroAccountId: "", + codexAccountId: "bar.json", + antigravityAccountId: "", + }); + expect(codexId).toBe("bar"); + + const antigravityId = resolveUpstreamIdForProviderChange({ + mode: "edit", + currentId: upstream.id, + currentProviders: ["openai"], + nextProviders: ["antigravity"], + upstreams: [upstream], + editingIndex: 0, + kiroAccountId: "", + codexAccountId: "", + antigravityAccountId: "baz.json", + }); + expect(antigravityId).toBe("baz"); + }); + + it("keeps id when editing and switching away from special provider", () => { + const upstream = createEmptyUpstream(); + upstream.id = "foo"; + upstream.providers = ["kiro"]; + + const id = resolveUpstreamIdForProviderChange({ + mode: "edit", + currentId: upstream.id, + currentProviders: ["kiro"], + nextProviders: ["openai"], + upstreams: [upstream], + editingIndex: 0, + kiroAccountId: "", + codexAccountId: "", + antigravityAccountId: "", + }); + + expect(id).toBe("foo"); + }); + + it("auto-generates id when creating and switching provider", () => { + const upstream = createEmptyUpstream(); + upstream.id = "openai-1"; + upstream.providers = ["openai"]; + + const id = resolveUpstreamIdForProviderChange({ + mode: "create", + currentId: "openai-1", + currentProviders: ["openai"], + nextProviders: ["gemini"], + upstreams: [upstream], + kiroAccountId: "", + codexAccountId: "", + antigravityAccountId: "", + }); + + expect(id).toBe("gemini-1"); + }); +}); + diff --git a/src/features/config/cards/upstreams/upstream-editor-helpers.ts b/src/features/config/cards/upstreams/upstream-editor-helpers.ts new file mode 100644 index 0000000..a1cb12a --- /dev/null +++ b/src/features/config/cards/upstreams/upstream-editor-helpers.ts @@ -0,0 +1,263 @@ +import { createNativeInboundFormatSet, removeInboundFormatsInSet } from "@/features/config/inbound-formats"; +import type { UpstreamForm } from "@/features/config/types"; +import type { AntigravityAccountSummary } from "@/features/antigravity/types"; +import type { CodexAccountSummary } from "@/features/codex/types"; +import type { KiroAccountSummary } from "@/features/kiro/types"; + +export function createCopiedUpstreamId(sourceId: string, upstreams: readonly UpstreamForm[]) { + const base = sourceId.trim() || "upstream"; + const taken = new Set( + upstreams + .map((upstream) => upstream.id.trim()) + .filter((id) => id), + ); + + const prefix = `${base}-copy`; + if (!taken.has(prefix)) { + return prefix; + } + + let suffix = 2; + while (taken.has(`${prefix}-${suffix}`)) { + suffix += 1; + } + return `${prefix}-${suffix}`; +} + +/** + * 基于 providers 自动生成唯一 ID + * - 单 provider:openai-1, openai-2 + * - 多 provider:仍以第一个 provider 作为前缀(避免 id 频繁变化) + */ +export function createAutoUpstreamId( + providers: readonly string[], + upstreams: readonly UpstreamForm[], + editingIndex?: number, +) { + const base = providers[0]?.trim() || "upstream"; + const taken = new Set( + upstreams + .filter((_, index) => index !== editingIndex) + .map((upstream) => upstream.id.trim()) + .filter((id) => id), + ); + + // 先尝试 provider-1 + let suffix = 1; + while (taken.has(`${base}-${suffix}`)) { + suffix += 1; + } + return `${base}-${suffix}`; +} + +export function normalizeProviders(values: readonly string[]) { + const output: string[] = []; + const seen 
= new Set();
+  for (const value of values) {
+    const trimmed = value.trim();
+    if (!trimmed) {
+      continue;
+    }
+    if (seen.has(trimmed)) {
+      continue;
+    }
+    seen.add(trimmed);
+    output.push(trimmed);
+  }
+  return output;
+}
+
+export function providersEqual(left: readonly string[], right: readonly string[]) {
+  if (left.length !== right.length) {
+    return false;
+  }
+  for (let index = 0; index < left.length; index += 1) {
+    if (left[index] !== right[index]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+export function coerceProviderSelection(next: readonly string[]) {
+  const normalized = normalizeProviders(next);
+  const special = normalized.find((provider) =>
+    provider === "kiro" || provider === "codex" || provider === "antigravity",
+  );
+  if (!special) {
+    return normalized;
+  }
+  return [special];
+}
+
+export function hasProvider(upstream: UpstreamForm, provider: string) {
+  return upstream.providers.some((value) => value.trim() === provider);
+}
+
+export function pruneConvertFromMap(
+  map: UpstreamForm["convertFromMap"],
+  providers: readonly string[],
+) {
+  if (!Object.keys(map).length) {
+    return map;
+  }
+  const providerSet = new Set(providers);
+  const nativeFormatsInUpstream = createNativeInboundFormatSet(providers);
+  const output: UpstreamForm["convertFromMap"] = {};
+  for (const [provider, formats] of Object.entries(map)) {
+    if (!providerSet.has(provider)) {
+      continue;
+    }
+    const filtered = removeInboundFormatsInSet(formats, nativeFormatsInUpstream);
+    if (!filtered.length) {
+      continue;
+    }
+    output[provider] = filtered;
+  }
+  return output;
+}
+
+/**
+ * Strip the ".json" suffix from an account_id to produce a cleaner upstream ID.
+ */
+export function stripJsonSuffix(accountId: string) {
+  return accountId.endsWith(".json") ? accountId.slice(0, -5) : accountId;
+}
+
+/**
+ * Expected ID behavior when providers change:
+ * - Regular providers: switching providers does not auto-change the ID while editing (avoids "splitting" stats/references).
+ * - kiro/codex/antigravity: the ID is bound to the account and may auto-sync to account_id (minus ".json").
+ */
+export function resolveUpstreamIdForProviderChange(args: {
+  mode: "create" | "edit";
+  currentId: string;
+  currentProviders: readonly string[];
+  nextProviders: readonly string[];
+  upstreams: readonly UpstreamForm[];
+  editingIndex?: number;
+  kiroAccountId: string;
+  codexAccountId: string;
+  antigravityAccountId: string;
+}) {
+  const currentPrimary = args.currentProviders[0]?.trim() ?? "";
+  const nextPrimary = args.nextProviders[0]?.trim() ?? "";
+
+  const specialId =
+    nextPrimary === "kiro" && args.kiroAccountId.trim()
+      ? stripJsonSuffix(args.kiroAccountId.trim())
+      : nextPrimary === "codex" && args.codexAccountId.trim()
+      ? stripJsonSuffix(args.codexAccountId.trim())
+      : nextPrimary === "antigravity" && args.antigravityAccountId.trim()
+      ?
stripJsonSuffix(args.antigravityAccountId.trim())
+      : null;
+  if (specialId) {
+    return specialId;
+  }
+
+  // Only "create" may auto-derive the ID from the provider; while editing, keep the ID stable
+  // and leave any change to the user.
+  if (args.mode === "edit") {
+    return args.currentId;
+  }
+
+  const shouldAutoId = nextPrimary !== currentPrimary && !!nextPrimary;
+  if (!shouldAutoId) {
+    return args.currentId;
+  }
+  return createAutoUpstreamId(args.nextProviders, args.upstreams, args.editingIndex);
+}
+
+/**
+ * Find the first idle kiro account not used by any other upstream.
+ * Prefer accounts in "active" status.
+ */
+export function findIdleKiroAccount(
+  accounts: KiroAccountSummary[],
+  upstreams: readonly UpstreamForm[],
+  editingIndex?: number,
+): KiroAccountSummary | undefined {
+  // Collect the kiro account ids that are already in use
+  const usedAccountIds = new Set(
+    upstreams
+      .filter((upstream, index) => {
+        if (index === editingIndex) return false;
+        return hasProvider(upstream, "kiro") && upstream.kiroAccountId.trim();
+      })
+      .map((upstream) => upstream.kiroAccountId.trim()),
+  );
+
+  // Prefer an idle account in "active" status
+  const activeIdle = accounts.find(
+    (account) => account.status === "active" && !usedAccountIds.has(account.account_id),
+  );
+  if (activeIdle) return activeIdle;
+
+  // Otherwise fall back to any idle account
+  return accounts.find((account) => !usedAccountIds.has(account.account_id));
+}
+
+/**
+ * Find the first idle codex account not used by any other upstream.
+ * Prefer accounts in "active" status.
+ */
+export function findIdleCodexAccount(
+  accounts: CodexAccountSummary[],
+  upstreams: readonly UpstreamForm[],
+  editingIndex?: number,
+): CodexAccountSummary | undefined {
+  const usedAccountIds = new Set(
+    upstreams
+      .filter((upstream, index) => {
+        if (index === editingIndex) return false;
+        return hasProvider(upstream, "codex") && upstream.codexAccountId.trim();
+      })
+      .map((upstream) => upstream.codexAccountId.trim()),
+  );
+
+  const activeIdle = accounts.find(
+    (account) => account.status === "active" && !usedAccountIds.has(account.account_id),
+  );
+  if (activeIdle) return activeIdle;
+
+  return accounts.find((account) => !usedAccountIds.has(account.account_id));
+}
+
+/**
+ * Find the first idle antigravity account not used by any other upstream.
+ * Prefer accounts in "active" status.
+ */
+export function findIdleAntigravityAccount(
+  accounts: AntigravityAccountSummary[],
+  upstreams: readonly UpstreamForm[],
+  editingIndex?: number,
+): AntigravityAccountSummary | undefined {
+  const usedAccountIds = new Set(
+    upstreams
+      .filter((upstream, index) => {
+        if (index === editingIndex) return false;
+        return hasProvider(upstream, "antigravity") && upstream.antigravityAccountId.trim();
+      })
+      .map((upstream) => upstream.antigravityAccountId.trim()),
+  );
+
+  const activeIdle = accounts.find(
+    (account) => account.status === "active" && !usedAccountIds.has(account.account_id),
+  );
+  if (activeIdle) return activeIdle;
+
+  return accounts.find((account) => !usedAccountIds.has(account.account_id));
+}
+
+export function cloneUpstreamDraft(upstream: UpstreamForm) {
+  const providers = normalizeProviders(upstream.providers);
+  return {
+    ...upstream,
+    // provider is required: when editing/copying, guarantee at least one provider so the UI
+    // never shows what looks like a default value while the actual value is empty.
+    providers: providers.length ?
providers : ["openai"], + modelMappings: upstream.modelMappings.map((mapping) => ({ ...mapping })), + overrides: { + header: upstream.overrides.header.map((entry) => ({ ...entry })), + }, + }; +} + diff --git a/src/features/dashboard/RecentRequestsTable.tsx b/src/features/dashboard/RecentRequestsTable.tsx index eb232a8..be16893 100644 --- a/src/features/dashboard/RecentRequestsTable.tsx +++ b/src/features/dashboard/RecentRequestsTable.tsx @@ -122,9 +122,7 @@ function modelColumn(): ColumnDef { header: m.dashboard_table_model(), cell: ({ row }) => { const primary = row.original.model?.trim() ? row.original.model : CELL_PLACEHOLDER; - const rawMapped = row.original.mappedModel?.trim() ? row.original.mappedModel : null; - // 只有当 mappedModel 存在且与 model 不同时才显示映射 - const mapped = rawMapped && rawMapped !== row.original.model ? rawMapped : null; + const mapped = row.original.mappedModel?.trim() ? row.original.mappedModel : null; const tooltipText = mapped ? `${primary}\n${mapped}` : primary; return ( From c4b357dfbe26c7745a75ebb359a9a3551eb17bf5 Mon Sep 17 00:00:00 2001 From: mxyhi Date: Fri, 30 Jan 2026 17:59:18 +0800 Subject: [PATCH 04/10] fix: align antigravity/claude handling with CLIProxyAPIPlus to improve compatibility --- .../src/antigravity/endpoints.rs | 3 +- .../src/proxy/anthropic_compat/request.rs | 45 ++++++++++++++++- .../src/proxy/antigravity_compat.rs | 48 ++++++++++++++++++- .../src/proxy/antigravity_compat.test.rs | 36 ++++++++++++++ .../src/proxy/config/normalize.rs | 2 +- .../src/proxy/openai_compat/input.rs | 6 ++- .../src/proxy/openai_compat/message.rs | 28 +++++++++-- crates/token_proxy_core/src/proxy/upstream.rs | 11 ++++- .../src/proxy/upstream/attempt.rs | 7 ++- .../src/proxy/upstream/kiro.rs | 2 +- src-tauri/src/antigravity/endpoints.rs | 3 +- .../src/proxy/anthropic_compat/request.rs | 45 ++++++++++++++++- src-tauri/src/proxy/antigravity_compat.rs | 48 ++++++++++++++++++- .../src/proxy/antigravity_compat.test.rs | 36 ++++++++++++++ src-tauri/src/proxy/config/normalize.rs | 2 +- src-tauri/src/proxy/openai_compat/input.rs | 6 ++- src-tauri/src/proxy/openai_compat/message.rs | 28 +++++++++-- src-tauri/src/proxy/upstream.rs | 11 ++++- src-tauri/src/proxy/upstream/attempt.rs | 7 ++- src-tauri/src/proxy/upstream/kiro.rs | 2 +- 20 files changed, 348 insertions(+), 28 deletions(-) diff --git a/crates/token_proxy_core/src/antigravity/endpoints.rs b/crates/token_proxy_core/src/antigravity/endpoints.rs index b1abde8..8784e70 100644 --- a/crates/token_proxy_core/src/antigravity/endpoints.rs +++ b/crates/token_proxy_core/src/antigravity/endpoints.rs @@ -4,7 +4,8 @@ pub(crate) const BASE_URL_DAILY: &str = "https://daily-cloudcode-pa.googleapis.c pub(crate) const BASE_URL_SANDBOX: &str = "https://daily-cloudcode-pa.sandbox.googleapis.com"; pub(crate) const BASE_URL_PROD: &str = "https://cloudcode-pa.googleapis.com"; -pub(crate) const BASE_URLS: [&str; 3] = [BASE_URL_SANDBOX, BASE_URL_DAILY, BASE_URL_PROD]; +// Align with CLIProxyAPIPlus: prefer daily, then sandbox. Prod is intentionally excluded. 
+pub(crate) const BASE_URLS: [&str; 2] = [BASE_URL_DAILY, BASE_URL_SANDBOX]; const ANTIGRAVITY_VERSION: &str = "1.104.0"; diff --git a/crates/token_proxy_core/src/proxy/anthropic_compat/request.rs b/crates/token_proxy_core/src/proxy/anthropic_compat/request.rs index c2221b4..574aeab 100644 --- a/crates/token_proxy_core/src/proxy/anthropic_compat/request.rs +++ b/crates/token_proxy_core/src/proxy/anthropic_compat/request.rs @@ -491,11 +491,54 @@ fn claude_content_to_blocks(content: Option<&Value>) -> Vec { }; match content { Value::String(text) => vec![json!({ "type": "text", "text": text })], - Value::Array(items) => items.clone(), + Value::Array(items) => items + .iter() + .cloned() + .map(|mut item| { + normalize_text_block_in_place(&mut item); + item + }) + .collect(), _ => Vec::new(), } } +fn normalize_text_block_in_place(block: &mut Value) { + let Some(object) = block.as_object_mut() else { + return; + }; + let block_type = object.get("type").and_then(Value::as_str).unwrap_or(""); + if block_type != "text" { + return; + } + let text_value = object.get("text"); + let new_text = text_value.and_then(extract_text_value); + if let Some(new_text) = new_text { + object.insert("text".to_string(), Value::String(new_text)); + return; + } + // If text exists but is not convertible, coerce to empty string to satisfy schema. + if text_value.is_some() { + object.insert("text".to_string(), Value::String(String::new())); + } +} + +fn extract_text_value(value: &Value) -> Option { + match value { + Value::String(text) => Some(text.to_string()), + Value::Object(object) => { + if let Some(text) = object.get("text") { + return extract_text_value(text); + } + if let Some(text) = object.get("value") { + return extract_text_value(text); + } + None + } + _ => None, + } +} + fn push_claude_message(messages: &mut Vec, role: &str, blocks: Vec) { let content = blocks; if content.is_empty() { diff --git a/crates/token_proxy_core/src/proxy/antigravity_compat.rs b/crates/token_proxy_core/src/proxy/antigravity_compat.rs index 12c8076..e71e178 100644 --- a/crates/token_proxy_core/src/proxy/antigravity_compat.rs +++ b/crates/token_proxy_core/src/proxy/antigravity_compat.rs @@ -187,11 +187,24 @@ fn extract_model(request: &mut Map, model_hint: Option<&str>) -> .unwrap_or_else(|| DEFAULT_MODEL.to_string()) } -fn map_antigravity_model(model: &str) -> String { +pub(crate) fn map_antigravity_model(model: &str) -> String { let trimmed = model.trim(); if trimmed.is_empty() { return DEFAULT_MODEL.to_string(); } + // Align with CLIProxyAPIPlus conventions: + // - Some clients expose Claude models behind a "gemini-" prefix (e.g. gemini-claude-opus-4-5-thinking) + // while Antigravity upstream uses the stable Claude name without the prefix. + if trimmed.starts_with("gemini-claude-") { + return trimmed.trim_start_matches("gemini-").to_string(); + } + + // Claude Code / Amp CLI may request date-suffixed Claude models (e.g. claude-opus-4-5-20251101). + // Antigravity does not expose date-suffixed IDs; map them to the stable Antigravity model names. + if let Some(mapped) = map_claude_date_model_to_antigravity(trimmed) { + return mapped; + } + trimmed.to_string() } @@ -200,6 +213,39 @@ fn map_antigravity_model(model: &str) -> String { #[path = "antigravity_compat.test.rs"] mod tests; +fn map_claude_date_model_to_antigravity(model: &str) -> Option { + if !model.starts_with("claude-") { + return None; + } + + // Allow optional "-thinking" suffix (some clients encode "thinking" in the model ID). 
+ let (base, _has_thinking_suffix) = match model.strip_suffix("-thinking") { + Some(value) => (value, true), + None => (model, false), + }; + + // Detect the trailing date segment in `...-YYYYMMDD`. + let (without_date, date_suffix) = base.rsplit_once('-')?; + if date_suffix.len() != 8 || !date_suffix.chars().all(|ch| ch.is_ascii_digit()) { + return None; + } + + // Known Claude 4.5 model families: map to the Antigravity stable names. + // NOTE: Antigravity appears to expose Sonnet/Opus (and their thinking variants) but not Haiku. + if without_date.starts_with("claude-opus-4-5") { + return Some("claude-opus-4-5-thinking".to_string()); + } + if without_date.starts_with("claude-sonnet-4-5") { + return Some("claude-sonnet-4-5-thinking".to_string()); + } + if without_date.starts_with("claude-haiku-4-5") { + // Follow CLIProxyAPIPlus example mapping: route Haiku to a close Gemini alternative. + return Some("gemini-2.5-flash".to_string()); + } + + None +} + fn normalize_system_instruction(request: &mut Map) { if let Some(value) = request.remove("system_instruction") { request.insert("systemInstruction".to_string(), value); diff --git a/crates/token_proxy_core/src/proxy/antigravity_compat.test.rs b/crates/token_proxy_core/src/proxy/antigravity_compat.test.rs index f5f911c..170ec89 100644 --- a/crates/token_proxy_core/src/proxy/antigravity_compat.test.rs +++ b/crates/token_proxy_core/src/proxy/antigravity_compat.test.rs @@ -21,6 +21,42 @@ fn returns_default_on_empty_model() { assert_eq!(map_antigravity_model(""), "gemini-1.5-flash"); } +#[test] +fn strips_gemini_prefix_for_claude_aliases() { + assert_eq!( + map_antigravity_model("gemini-claude-opus-4-5-thinking"), + "claude-opus-4-5-thinking" + ); +} + +#[test] +fn maps_claude_opus_date_model_to_stable_thinking_model() { + assert_eq!( + map_antigravity_model("claude-opus-4-5-20251101"), + "claude-opus-4-5-thinking" + ); + assert_eq!( + map_antigravity_model("claude-opus-4-5-20251101-thinking"), + "claude-opus-4-5-thinking" + ); +} + +#[test] +fn maps_claude_sonnet_date_model_to_stable_thinking_model() { + assert_eq!( + map_antigravity_model("claude-sonnet-4-5-20250929"), + "claude-sonnet-4-5-thinking" + ); +} + +#[test] +fn maps_claude_haiku_date_model_to_gemini_fallback() { + assert_eq!( + map_antigravity_model("claude-haiku-4-5-20251001"), + "gemini-2.5-flash" + ); +} + #[test] fn injects_antigravity_system_instruction_for_claude() { let request = json!({ diff --git a/crates/token_proxy_core/src/proxy/config/normalize.rs b/crates/token_proxy_core/src/proxy/config/normalize.rs index 4de61e5..261cefb 100644 --- a/crates/token_proxy_core/src/proxy/config/normalize.rs +++ b/crates/token_proxy_core/src/proxy/config/normalize.rs @@ -9,7 +9,7 @@ use axum::http::header::{HeaderName, HeaderValue}; const APP_PROXY_URL_PLACEHOLDER: &str = "$app_proxy_url"; const DEFAULT_CODEX_BASE_URL: &str = "https://chatgpt.com/backend-api/codex"; -const DEFAULT_ANTIGRAVITY_BASE_URL: &str = "https://cloudcode-pa.googleapis.com"; +const DEFAULT_ANTIGRAVITY_BASE_URL: &str = "https://daily-cloudcode-pa.googleapis.com"; #[derive(Clone)] pub(super) struct NormalizedUpstream { diff --git a/crates/token_proxy_core/src/proxy/openai_compat/input.rs b/crates/token_proxy_core/src/proxy/openai_compat/input.rs index 7521144..27aa873 100644 --- a/crates/token_proxy_core/src/proxy/openai_compat/input.rs +++ b/crates/token_proxy_core/src/proxy/openai_compat/input.rs @@ -1,5 +1,7 @@ use serde_json::{json, Map, Value}; +use super::message::extract_text_from_part; + pub(super) fn 
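// Recap of the mapping added above (values taken from the tests earlier in this patch):
//   map_antigravity_model("gemini-claude-opus-4-5-thinking") -> "claude-opus-4-5-thinking"
//   map_antigravity_model("claude-opus-4-5-20251101")        -> "claude-opus-4-5-thinking"
//   map_antigravity_model("claude-sonnet-4-5-20250929")      -> "claude-sonnet-4-5-thinking"
//   map_antigravity_model("claude-haiku-4-5-20251001")       -> "gemini-2.5-flash"
// A date suffix qualifies only when the final '-'-separated segment is exactly 8 ASCII digits.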
responses_input_to_chat_messages(items: &[Value]) -> Result, String> { let mut messages = Vec::with_capacity(items.len()); for item in items { @@ -94,8 +96,8 @@ fn responses_message_content_to_chat_content(value: &Value) -> Option { let part_type = part.get("type").and_then(Value::as_str); match part_type { Some("input_text") | Some("text") | Some("output_text") => { - if let Some(text) = part.get("text").and_then(Value::as_str) { - combined.push_str(text); + if let Some(text) = extract_text_from_part(part) { + combined.push_str(&text); output_parts.push(json!({ "type": "text", "text": text })); } } diff --git a/crates/token_proxy_core/src/proxy/openai_compat/message.rs b/crates/token_proxy_core/src/proxy/openai_compat/message.rs index c2a524d..73a2b1f 100644 --- a/crates/token_proxy_core/src/proxy/openai_compat/message.rs +++ b/crates/token_proxy_core/src/proxy/openai_compat/message.rs @@ -1,4 +1,24 @@ -use serde_json::{json, Value}; +use serde_json::{json, Map, Value}; + +fn extract_text_value(value: &Value) -> Option { + match value { + Value::String(text) => Some(text.to_string()), + Value::Object(object) => { + if let Some(text) = object.get("text") { + return extract_text_value(text); + } + if let Some(text) = object.get("value") { + return extract_text_value(text); + } + None + } + _ => None, + } +} + +pub(super) fn extract_text_from_part(part: &Map) -> Option { + part.get("text").and_then(extract_text_value) +} pub(super) fn extract_text_from_chat_content(content: Option<&Value>) -> Option { let Some(content) = content else { @@ -16,8 +36,8 @@ pub(super) fn extract_text_from_chat_content(content: Option<&Value>) -> Option< if !matches!(part_type, "text" | "input_text") { continue; } - if let Some(text) = part.get("text").and_then(Value::as_str) { - combined.push_str(text); + if let Some(text) = extract_text_from_part(part) { + combined.push_str(&text); } } if combined.trim().is_empty() { @@ -49,7 +69,7 @@ pub(super) fn chat_content_to_responses_message_parts( let part_type = part.get("type").and_then(Value::as_str).unwrap_or(""); match part_type { "text" | "input_text" => { - if let Some(text) = part.get("text").and_then(Value::as_str) { + if let Some(text) = extract_text_from_part(part) { out.push(json!({ "type": text_part_type, "text": text })); } } diff --git a/crates/token_proxy_core/src/proxy/upstream.rs b/crates/token_proxy_core/src/proxy/upstream.rs index 5b94ee3..0aef95f 100644 --- a/crates/token_proxy_core/src/proxy/upstream.rs +++ b/crates/token_proxy_core/src/proxy/upstream.rs @@ -432,7 +432,7 @@ async fn prepare_upstream_request( meta: &RequestMeta, request_auth: &RequestAuth, ) -> Result { - let mapped_meta = build_mapped_meta(meta, upstream); + let mapped_meta = build_mapped_meta(meta, upstream, provider); let upstream_path_with_query = resolve_upstream_path_with_query(provider, upstream_path_with_query, &mapped_meta); let upstream_url = upstream.upstream_url(&upstream_path_with_query); @@ -679,7 +679,7 @@ async fn resolve_antigravity_upstream( }) } -fn build_mapped_meta(meta: &RequestMeta, upstream: &UpstreamRuntime) -> RequestMeta { +fn build_mapped_meta(meta: &RequestMeta, upstream: &UpstreamRuntime, provider: &str) -> RequestMeta { let mapped_model = meta .original_model .as_deref() @@ -688,6 +688,13 @@ fn build_mapped_meta(meta: &RequestMeta, upstream: &UpstreamRuntime) -> RequestM mapped_model, meta.reasoning_effort.clone(), ); + let mapped_model = mapped_model.map(|model| { + if provider == "antigravity" { + 
super::antigravity_compat::map_antigravity_model(&model) + } else { + model + } + }); RequestMeta { stream: meta.stream, original_model: meta.original_model.clone(), diff --git a/crates/token_proxy_core/src/proxy/upstream/attempt.rs b/crates/token_proxy_core/src/proxy/upstream/attempt.rs index cef0a7c..446cc90 100644 --- a/crates/token_proxy_core/src/proxy/upstream/attempt.rs +++ b/crates/token_proxy_core/src/proxy/upstream/attempt.rs @@ -344,7 +344,12 @@ async fn send_antigravity_with_fallback( .await { Ok(response) => { - if super::utils::is_retryable_status(response.status()) && idx + 1 < urls.len() { + let status = response.status(); + // Align with CLIProxyAPIPlus: Antigravity endpoints may return 404 on one base URL + // while succeeding on another; try fallbacks on 404 as well. + if (super::utils::is_retryable_status(status) || status == StatusCode::NOT_FOUND) + && idx + 1 < urls.len() + { let _ = response.bytes().await; continue; } diff --git a/crates/token_proxy_core/src/proxy/upstream/kiro.rs b/crates/token_proxy_core/src/proxy/upstream/kiro.rs index e092840..875c2e7 100644 --- a/crates/token_proxy_core/src/proxy/upstream/kiro.rs +++ b/crates/token_proxy_core/src/proxy/upstream/kiro.rs @@ -108,7 +108,7 @@ async fn prepare_kiro_context<'a>( response_transform: FormatTransform, request_detail: Option, ) -> Result, AttemptOutcome> { - let mapped_meta = super::build_mapped_meta(meta, upstream); + let mapped_meta = super::build_mapped_meta(meta, upstream, "kiro"); let request_value = read_request_json(state, body).await?; let account_id = resolve_account_id(upstream)?; let record = load_account_record(state, &account_id).await?; diff --git a/src-tauri/src/antigravity/endpoints.rs b/src-tauri/src/antigravity/endpoints.rs index b1abde8..8784e70 100644 --- a/src-tauri/src/antigravity/endpoints.rs +++ b/src-tauri/src/antigravity/endpoints.rs @@ -4,7 +4,8 @@ pub(crate) const BASE_URL_DAILY: &str = "https://daily-cloudcode-pa.googleapis.c pub(crate) const BASE_URL_SANDBOX: &str = "https://daily-cloudcode-pa.sandbox.googleapis.com"; pub(crate) const BASE_URL_PROD: &str = "https://cloudcode-pa.googleapis.com"; -pub(crate) const BASE_URLS: [&str; 3] = [BASE_URL_SANDBOX, BASE_URL_DAILY, BASE_URL_PROD]; +// Align with CLIProxyAPIPlus: prefer daily, then sandbox. Prod is intentionally excluded. 
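// Illustration (hypothetical helper; the retryable set is assumed from
// utils::is_retryable_status plus the 404 special case added above): how the fallback
// walks the reordered base URLs and only surfaces a status once no URL is left to try.
fn final_status_across_base_urls(statuses: &[u16]) -> Option<u16> {
    for (idx, status) in statuses.iter().copied().enumerate() {
        let retry = status == 404 || status == 429 || status >= 500; // assumed retryable set
        if retry && idx + 1 < statuses.len() {
            continue; // drain the body and move on to the next base URL
        }
        return Some(status);
    }
    None
}
// final_status_across_base_urls(&[404, 200]) == Some(200): daily 404s, sandbox answers.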
+pub(crate) const BASE_URLS: [&str; 2] = [BASE_URL_DAILY, BASE_URL_SANDBOX]; const ANTIGRAVITY_VERSION: &str = "1.104.0"; diff --git a/src-tauri/src/proxy/anthropic_compat/request.rs b/src-tauri/src/proxy/anthropic_compat/request.rs index c2221b4..574aeab 100644 --- a/src-tauri/src/proxy/anthropic_compat/request.rs +++ b/src-tauri/src/proxy/anthropic_compat/request.rs @@ -491,11 +491,54 @@ fn claude_content_to_blocks(content: Option<&Value>) -> Vec { }; match content { Value::String(text) => vec![json!({ "type": "text", "text": text })], - Value::Array(items) => items.clone(), + Value::Array(items) => items + .iter() + .cloned() + .map(|mut item| { + normalize_text_block_in_place(&mut item); + item + }) + .collect(), _ => Vec::new(), } } +fn normalize_text_block_in_place(block: &mut Value) { + let Some(object) = block.as_object_mut() else { + return; + }; + let block_type = object.get("type").and_then(Value::as_str).unwrap_or(""); + if block_type != "text" { + return; + } + let text_value = object.get("text"); + let new_text = text_value.and_then(extract_text_value); + if let Some(new_text) = new_text { + object.insert("text".to_string(), Value::String(new_text)); + return; + } + // If text exists but is not convertible, coerce to empty string to satisfy schema. + if text_value.is_some() { + object.insert("text".to_string(), Value::String(String::new())); + } +} + +fn extract_text_value(value: &Value) -> Option { + match value { + Value::String(text) => Some(text.to_string()), + Value::Object(object) => { + if let Some(text) = object.get("text") { + return extract_text_value(text); + } + if let Some(text) = object.get("value") { + return extract_text_value(text); + } + None + } + _ => None, + } +} + fn push_claude_message(messages: &mut Vec, role: &str, blocks: Vec) { let content = blocks; if content.is_empty() { diff --git a/src-tauri/src/proxy/antigravity_compat.rs b/src-tauri/src/proxy/antigravity_compat.rs index 12c8076..e71e178 100644 --- a/src-tauri/src/proxy/antigravity_compat.rs +++ b/src-tauri/src/proxy/antigravity_compat.rs @@ -187,11 +187,24 @@ fn extract_model(request: &mut Map, model_hint: Option<&str>) -> .unwrap_or_else(|| DEFAULT_MODEL.to_string()) } -fn map_antigravity_model(model: &str) -> String { +pub(crate) fn map_antigravity_model(model: &str) -> String { let trimmed = model.trim(); if trimmed.is_empty() { return DEFAULT_MODEL.to_string(); } + // Align with CLIProxyAPIPlus conventions: + // - Some clients expose Claude models behind a "gemini-" prefix (e.g. gemini-claude-opus-4-5-thinking) + // while Antigravity upstream uses the stable Claude name without the prefix. + if trimmed.starts_with("gemini-claude-") { + return trimmed.trim_start_matches("gemini-").to_string(); + } + + // Claude Code / Amp CLI may request date-suffixed Claude models (e.g. claude-opus-4-5-20251101). + // Antigravity does not expose date-suffixed IDs; map them to the stable Antigravity model names. + if let Some(mapped) = map_claude_date_model_to_antigravity(trimmed) { + return mapped; + } + trimmed.to_string() } @@ -200,6 +213,39 @@ fn map_antigravity_model(model: &str) -> String { #[path = "antigravity_compat.test.rs"] mod tests; +fn map_claude_date_model_to_antigravity(model: &str) -> Option { + if !model.starts_with("claude-") { + return None; + } + + // Allow optional "-thinking" suffix (some clients encode "thinking" in the model ID). 
+ let (base, _has_thinking_suffix) = match model.strip_suffix("-thinking") { + Some(value) => (value, true), + None => (model, false), + }; + + // Detect the trailing date segment in `...-YYYYMMDD`. + let (without_date, date_suffix) = base.rsplit_once('-')?; + if date_suffix.len() != 8 || !date_suffix.chars().all(|ch| ch.is_ascii_digit()) { + return None; + } + + // Known Claude 4.5 model families: map to the Antigravity stable names. + // NOTE: Antigravity appears to expose Sonnet/Opus (and their thinking variants) but not Haiku. + if without_date.starts_with("claude-opus-4-5") { + return Some("claude-opus-4-5-thinking".to_string()); + } + if without_date.starts_with("claude-sonnet-4-5") { + return Some("claude-sonnet-4-5-thinking".to_string()); + } + if without_date.starts_with("claude-haiku-4-5") { + // Follow CLIProxyAPIPlus example mapping: route Haiku to a close Gemini alternative. + return Some("gemini-2.5-flash".to_string()); + } + + None +} + fn normalize_system_instruction(request: &mut Map) { if let Some(value) = request.remove("system_instruction") { request.insert("systemInstruction".to_string(), value); diff --git a/src-tauri/src/proxy/antigravity_compat.test.rs b/src-tauri/src/proxy/antigravity_compat.test.rs index f5f911c..170ec89 100644 --- a/src-tauri/src/proxy/antigravity_compat.test.rs +++ b/src-tauri/src/proxy/antigravity_compat.test.rs @@ -21,6 +21,42 @@ fn returns_default_on_empty_model() { assert_eq!(map_antigravity_model(""), "gemini-1.5-flash"); } +#[test] +fn strips_gemini_prefix_for_claude_aliases() { + assert_eq!( + map_antigravity_model("gemini-claude-opus-4-5-thinking"), + "claude-opus-4-5-thinking" + ); +} + +#[test] +fn maps_claude_opus_date_model_to_stable_thinking_model() { + assert_eq!( + map_antigravity_model("claude-opus-4-5-20251101"), + "claude-opus-4-5-thinking" + ); + assert_eq!( + map_antigravity_model("claude-opus-4-5-20251101-thinking"), + "claude-opus-4-5-thinking" + ); +} + +#[test] +fn maps_claude_sonnet_date_model_to_stable_thinking_model() { + assert_eq!( + map_antigravity_model("claude-sonnet-4-5-20250929"), + "claude-sonnet-4-5-thinking" + ); +} + +#[test] +fn maps_claude_haiku_date_model_to_gemini_fallback() { + assert_eq!( + map_antigravity_model("claude-haiku-4-5-20251001"), + "gemini-2.5-flash" + ); +} + #[test] fn injects_antigravity_system_instruction_for_claude() { let request = json!({ diff --git a/src-tauri/src/proxy/config/normalize.rs b/src-tauri/src/proxy/config/normalize.rs index 8482c79..9d624b9 100644 --- a/src-tauri/src/proxy/config/normalize.rs +++ b/src-tauri/src/proxy/config/normalize.rs @@ -8,7 +8,7 @@ use axum::http::header::{HeaderName, HeaderValue}; const APP_PROXY_URL_PLACEHOLDER: &str = "$app_proxy_url"; const DEFAULT_CODEX_BASE_URL: &str = "https://chatgpt.com/backend-api/codex"; -const DEFAULT_ANTIGRAVITY_BASE_URL: &str = "https://cloudcode-pa.googleapis.com"; +const DEFAULT_ANTIGRAVITY_BASE_URL: &str = "https://daily-cloudcode-pa.googleapis.com"; #[derive(Clone)] pub(super) struct NormalizedUpstream { diff --git a/src-tauri/src/proxy/openai_compat/input.rs b/src-tauri/src/proxy/openai_compat/input.rs index 7521144..27aa873 100644 --- a/src-tauri/src/proxy/openai_compat/input.rs +++ b/src-tauri/src/proxy/openai_compat/input.rs @@ -1,5 +1,7 @@ use serde_json::{json, Map, Value}; +use super::message::extract_text_from_part; + pub(super) fn responses_input_to_chat_messages(items: &[Value]) -> Result, String> { let mut messages = Vec::with_capacity(items.len()); for item in items { @@ -94,8 +96,8 @@ fn 
responses_message_content_to_chat_content(value: &Value) -> Option { let part_type = part.get("type").and_then(Value::as_str); match part_type { Some("input_text") | Some("text") | Some("output_text") => { - if let Some(text) = part.get("text").and_then(Value::as_str) { - combined.push_str(text); + if let Some(text) = extract_text_from_part(part) { + combined.push_str(&text); output_parts.push(json!({ "type": "text", "text": text })); } } diff --git a/src-tauri/src/proxy/openai_compat/message.rs b/src-tauri/src/proxy/openai_compat/message.rs index c2a524d..73a2b1f 100644 --- a/src-tauri/src/proxy/openai_compat/message.rs +++ b/src-tauri/src/proxy/openai_compat/message.rs @@ -1,4 +1,24 @@ -use serde_json::{json, Value}; +use serde_json::{json, Map, Value}; + +fn extract_text_value(value: &Value) -> Option { + match value { + Value::String(text) => Some(text.to_string()), + Value::Object(object) => { + if let Some(text) = object.get("text") { + return extract_text_value(text); + } + if let Some(text) = object.get("value") { + return extract_text_value(text); + } + None + } + _ => None, + } +} + +pub(super) fn extract_text_from_part(part: &Map) -> Option { + part.get("text").and_then(extract_text_value) +} pub(super) fn extract_text_from_chat_content(content: Option<&Value>) -> Option { let Some(content) = content else { @@ -16,8 +36,8 @@ pub(super) fn extract_text_from_chat_content(content: Option<&Value>) -> Option< if !matches!(part_type, "text" | "input_text") { continue; } - if let Some(text) = part.get("text").and_then(Value::as_str) { - combined.push_str(text); + if let Some(text) = extract_text_from_part(part) { + combined.push_str(&text); } } if combined.trim().is_empty() { @@ -49,7 +69,7 @@ pub(super) fn chat_content_to_responses_message_parts( let part_type = part.get("type").and_then(Value::as_str).unwrap_or(""); match part_type { "text" | "input_text" => { - if let Some(text) = part.get("text").and_then(Value::as_str) { + if let Some(text) = extract_text_from_part(part) { out.push(json!({ "type": text_part_type, "text": text })); } } diff --git a/src-tauri/src/proxy/upstream.rs b/src-tauri/src/proxy/upstream.rs index a0bf6aa..929a681 100644 --- a/src-tauri/src/proxy/upstream.rs +++ b/src-tauri/src/proxy/upstream.rs @@ -390,7 +390,7 @@ async fn prepare_upstream_request( meta: &RequestMeta, request_auth: &RequestAuth, ) -> Result { - let mapped_meta = build_mapped_meta(meta, upstream); + let mapped_meta = build_mapped_meta(meta, upstream, provider); let upstream_path_with_query = resolve_upstream_path_with_query(provider, upstream_path_with_query, &mapped_meta); let upstream_url = upstream.upstream_url(&upstream_path_with_query); @@ -637,7 +637,7 @@ async fn resolve_antigravity_upstream( }) } -fn build_mapped_meta(meta: &RequestMeta, upstream: &UpstreamRuntime) -> RequestMeta { +fn build_mapped_meta(meta: &RequestMeta, upstream: &UpstreamRuntime, provider: &str) -> RequestMeta { let mapped_model = meta .original_model .as_deref() @@ -646,6 +646,13 @@ fn build_mapped_meta(meta: &RequestMeta, upstream: &UpstreamRuntime) -> RequestM mapped_model, meta.reasoning_effort.clone(), ); + let mapped_model = mapped_model.map(|model| { + if provider == "antigravity" { + super::antigravity_compat::map_antigravity_model(&model) + } else { + model + } + }); RequestMeta { stream: meta.stream, original_model: meta.original_model.clone(), diff --git a/src-tauri/src/proxy/upstream/attempt.rs b/src-tauri/src/proxy/upstream/attempt.rs index cef0a7c..446cc90 100644 --- 
a/src-tauri/src/proxy/upstream/attempt.rs +++ b/src-tauri/src/proxy/upstream/attempt.rs @@ -344,7 +344,12 @@ async fn send_antigravity_with_fallback( .await { Ok(response) => { - if super::utils::is_retryable_status(response.status()) && idx + 1 < urls.len() { + let status = response.status(); + // Align with CLIProxyAPIPlus: Antigravity endpoints may return 404 on one base URL + // while succeeding on another; try fallbacks on 404 as well. + if (super::utils::is_retryable_status(status) || status == StatusCode::NOT_FOUND) + && idx + 1 < urls.len() + { let _ = response.bytes().await; continue; } diff --git a/src-tauri/src/proxy/upstream/kiro.rs b/src-tauri/src/proxy/upstream/kiro.rs index 5efdb74..3303936 100644 --- a/src-tauri/src/proxy/upstream/kiro.rs +++ b/src-tauri/src/proxy/upstream/kiro.rs @@ -109,7 +109,7 @@ async fn prepare_kiro_context<'a>( response_transform: FormatTransform, request_detail: Option, ) -> Result, AttemptOutcome> { - let mapped_meta = super::build_mapped_meta(meta, upstream); + let mapped_meta = super::build_mapped_meta(meta, upstream, "kiro"); let request_value = read_request_json(state, body).await?; let account_id = resolve_account_id(upstream)?; let record = load_account_record(state, &account_id).await?; From 82b8c2c1d8adef22bcb71f7da758e4aa6366779c Mon Sep 17 00:00:00 2001 From: mxyhi Date: Fri, 30 Jan 2026 19:01:23 +0800 Subject: [PATCH 05/10] feat: improve antigravity Claude compatibility and tracing --- .../src/proxy/antigravity_compat.rs | 5 + .../src/proxy/antigravity_compat/claude.rs | 496 ++++++++++++++++++ .../proxy/antigravity_compat/claude.test.rs | 85 +++ .../antigravity_compat/signature_cache.rs | 107 ++++ .../src/proxy/config/normalize.rs | 20 +- .../src/proxy/response/dispatch/buffered.rs | 28 + .../src/proxy/response/dispatch/stream.rs | 62 +++ crates/token_proxy_core/src/proxy/server.rs | 55 +- .../src/proxy/server_helpers.rs | 117 +++-- .../src/proxy/upstream/attempt.rs | 24 + .../src/proxy/upstream/request.rs | 10 + messages/en.json | 2 +- messages/zh.json | 2 +- src-tauri/src/proxy/antigravity_compat.rs | 5 + .../src/proxy/antigravity_compat/claude.rs | 496 ++++++++++++++++++ .../antigravity_compat/signature_cache.rs | 107 ++++ src-tauri/src/proxy/config/normalize.rs | 20 +- .../src/proxy/response/dispatch/buffered.rs | 28 + .../src/proxy/response/dispatch/stream.rs | 62 +++ src-tauri/src/proxy/server.rs | 55 +- src-tauri/src/proxy/server_helpers.rs | 117 +++-- src-tauri/src/proxy/upstream/attempt.rs | 24 + src-tauri/src/proxy/upstream/request.rs | 10 + src-tauri/src/proxy/usage.rs | 5 +- src/features/config/cards/upstreams/table.tsx | 57 +- .../dashboard/RecentRequestsTable.tsx | 10 +- 26 files changed, 1899 insertions(+), 110 deletions(-) create mode 100644 crates/token_proxy_core/src/proxy/antigravity_compat/claude.rs create mode 100644 crates/token_proxy_core/src/proxy/antigravity_compat/claude.test.rs create mode 100644 crates/token_proxy_core/src/proxy/antigravity_compat/signature_cache.rs create mode 100644 src-tauri/src/proxy/antigravity_compat/claude.rs create mode 100644 src-tauri/src/proxy/antigravity_compat/signature_cache.rs diff --git a/crates/token_proxy_core/src/proxy/antigravity_compat.rs b/crates/token_proxy_core/src/proxy/antigravity_compat.rs index e71e178..22c913a 100644 --- a/crates/token_proxy_core/src/proxy/antigravity_compat.rs +++ b/crates/token_proxy_core/src/proxy/antigravity_compat.rs @@ -8,6 +8,11 @@ use crate::oauth_util::generate_state; use crate::proxy::antigravity_schema::clean_json_schema_for_antigravity; 
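// Usage sketch (assumed call site, error handling elided): the dedicated converter this
// patch introduces takes a raw Anthropic /v1/messages body and yields a Gemini-style
// body for the Antigravity upstream; errors are plain strings mapped to HTTP 400 later.
fn demo_claude_to_antigravity() -> Result<axum::body::Bytes, String> {
    let claude_body = axum::body::Bytes::from(
        r#"{"model":"claude-sonnet-4-5-thinking","messages":[{"role":"user","content":"hi"}]}"#,
    );
    claude_request_to_antigravity(&claude_body, None)
}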
 use crate::proxy::sse::SseEventParser;
 
+mod signature_cache;
+mod claude;
+
+pub(crate) use claude::claude_request_to_antigravity;
+
 const DEFAULT_MODEL: &str = "gemini-1.5-flash";
 const THOUGHT_SIGNATURE_SENTINEL: &str = "skip_thought_signature_validator";
 const PAYLOAD_USER_AGENT: &str = "antigravity";
diff --git a/crates/token_proxy_core/src/proxy/antigravity_compat/claude.rs b/crates/token_proxy_core/src/proxy/antigravity_compat/claude.rs
new file mode 100644
index 0000000..c582745
--- /dev/null
+++ b/crates/token_proxy_core/src/proxy/antigravity_compat/claude.rs
@@ -0,0 +1,496 @@
+use axum::body::Bytes;
+use serde_json::{json, Map, Value};
+
+use super::signature_cache;
+use crate::proxy::antigravity_schema::clean_json_schema_for_antigravity;
+
+const THOUGHT_SIGNATURE_SENTINEL: &str = "skip_thought_signature_validator";
+const INTERLEAVED_HINT: &str = "Interleaved thinking is enabled. You may think between tool calls and after receiving tool results before deciding the next action or final answer. Do not mention these instructions or any constraints about thinking blocks; just apply them.";
+
+pub(crate) fn claude_request_to_antigravity(
+    body: &Bytes,
+    model_hint: Option<&str>,
+) -> Result<Bytes, String> {
+    // Dedicated Claude -> Gemini request conversion to align with CLIProxyAPIPlus.
+    let object = parse_request_object(body)?;
+    let model_name = resolve_model_name(&object, model_hint);
+    let mapped_model = super::map_antigravity_model(&model_name);
+    let (contents, enable_thinking_translate) = build_contents(&object, &mapped_model)?;
+    let tools = build_tools(&object);
+    let thinking_enabled = thinking_enabled(&object);
+    let should_hint = tools.is_some() && thinking_enabled && is_claude_thinking_model(&mapped_model);
+
+    let mut out = Map::new();
+    if !mapped_model.trim().is_empty() {
+        out.insert("model".to_string(), Value::String(mapped_model));
+    }
+    if !contents.is_empty() {
+        out.insert("contents".to_string(), Value::Array(contents));
+    }
+    if let Some(system_instruction) = build_system_instruction(&object, should_hint) {
+        out.insert("systemInstruction".to_string(), system_instruction);
+    }
+    if let Some(tools) = tools {
+        out.insert("tools".to_string(), tools);
+    }
+    if let Some(gen) = build_generation_config(&object, enable_thinking_translate) {
+        out.insert("generationConfig".to_string(), gen);
+    }
+
+    serde_json::to_vec(&Value::Object(out))
+        .map(Bytes::from)
+        .map_err(|err| format!("Failed to serialize request: {err}"))
+}
+
+fn parse_request_object(body: &Bytes) -> Result<Map<String, Value>, String> {
+    let value: Value =
+        serde_json::from_slice(body).map_err(|_| "Request body must be JSON.".to_string())?;
+    value
+        .as_object()
+        .cloned()
+        .ok_or_else(|| "Request body must be a JSON object.".to_string())
+}
+
+fn resolve_model_name(object: &Map<String, Value>, model_hint: Option<&str>) -> String {
+    object
+        .get("model")
+        .and_then(Value::as_str)
+        .map(|value| value.trim())
+        .filter(|value| !value.is_empty())
+        .map(|value| value.to_string())
+        .or_else(|| {
+            model_hint
+                .map(str::trim)
+                .filter(|value| !value.is_empty())
+                .map(|value| value.to_string())
+        })
+        .unwrap_or_default()
+}
+
+fn build_system_instruction(object: &Map<String, Value>, should_hint: bool) -> Option<Value> {
+    let mut parts = system_parts(object);
+    if should_hint {
+        parts.push(json!({ "text": INTERLEAVED_HINT }));
+    }
+    if parts.is_empty() {
+        return None;
+    }
+    Some(json!({ "role": "user", "parts": parts }))
+}
+
+fn system_parts(object: &Map<String, Value>) -> Vec<Value> {
+    let Some(system) = object.get("system") else {
+        return Vec::new();
+    };
+    match system {
+        Value::String(text) => system_parts_from_text(text),
+        Value::Array(items) => items
+            .iter()
+            .filter_map(|item| item.as_object())
+            .filter(|item| item.get("type").and_then(Value::as_str) == Some("text"))
+            .filter_map(|item| item.get("text").and_then(Value::as_str))
+            .flat_map(system_parts_from_text)
+            .collect(),
+        _ => Vec::new(),
+    }
+}
+
+fn system_parts_from_text(text: &str) -> Vec<Value> {
+    let trimmed = text.trim();
+    if trimmed.is_empty() {
+        Vec::new()
+    } else {
+        vec![json!({ "text": trimmed })]
+    }
+}
+
+fn thinking_enabled(object: &Map<String, Value>) -> bool {
+    object
+        .get("thinking")
+        .and_then(Value::as_object)
+        .and_then(|thinking| thinking.get("type"))
+        .and_then(Value::as_str)
+        == Some("enabled")
+}
+
+fn is_claude_thinking_model(model_name: &str) -> bool {
+    let lower = model_name.to_lowercase();
+    lower.contains("claude") && lower.contains("thinking")
+}
+
+fn build_contents(
+    object: &Map<String, Value>,
+    model_name: &str,
+) -> Result<(Vec<Value>, bool), String> {
+    let Some(messages) = object.get("messages").and_then(Value::as_array) else {
+        return Ok((Vec::new(), true));
+    };
+    let mut contents = Vec::with_capacity(messages.len());
+    let mut enable_thinking_translate = true;
+
+    for message in messages {
+        let Some(message) = message.as_object() else {
+            continue;
+        };
+        let role = message.get("role").and_then(Value::as_str).unwrap_or("user");
+        let role = if role == "assistant" { "model" } else { role };
+        let mut parts = Vec::new();
+        let mut current_signature = String::new();
+        match message.get("content") {
+            Some(Value::Array(items)) => {
+                for item in items {
+                    let Some(item) = item.as_object() else {
+                        continue;
+                    };
+                    let block_type = item.get("type").and_then(Value::as_str).unwrap_or("");
+                    handle_block(
+                        item,
+                        block_type,
+                        model_name,
+                        &mut current_signature,
+                        &mut enable_thinking_translate,
+                        &mut parts,
+                    );
+                }
+            }
+            Some(Value::String(text)) => push_text_part(text, &mut parts),
+            _ => {}
+        }
+        reorder_thinking_parts(role, &mut parts);
+        contents.push(json!({ "role": role, "parts": parts }));
+    }
+
+    Ok((contents, enable_thinking_translate))
+}
+
+fn handle_block(
+    item: &Map<String, Value>,
+    block_type: &str,
+    model_name: &str,
+    current_signature: &mut String,
+    enable_thinking_translate: &mut bool,
+    parts: &mut Vec<Value>,
+) {
+    match block_type {
+        "thinking" => {
+            handle_thinking_block(item, model_name, current_signature, enable_thinking_translate, parts);
+        }
+        "text" => {
+            if let Some(text) = item.get("text").and_then(Value::as_str) {
+                push_text_part(text, parts);
+            }
+        }
+        "tool_use" => {
+            if let Some(part) = tool_use_to_part(item, model_name, current_signature) {
+                parts.push(part);
+            }
+        }
+        "tool_result" => {
+            if let Some(part) = tool_result_to_part(item) {
+                parts.push(part);
+            }
+        }
+        "image" => {
+            if let Some(part) = image_to_part(item) {
+                parts.push(part);
+            }
+        }
+        _ => {}
+    }
+}
+
+fn handle_thinking_block(
+    item: &Map<String, Value>,
+    model_name: &str,
+    current_signature: &mut String,
+    enable_thinking_translate: &mut bool,
+    parts: &mut Vec<Value>,
+) {
+    let thinking_text = extract_text_value(item.get("thinking")).unwrap_or_default();
+    let signature = resolve_thinking_signature(model_name, &thinking_text, item);
+    if !signature_cache::has_valid_signature(model_name, &signature) {
+        *enable_thinking_translate = false;
+        return;
+    }
+    *current_signature = signature.clone();
+    if !thinking_text.is_empty() {
+        signature_cache::cache_signature(model_name, &thinking_text, &signature);
+    }
+    let mut part = json!({ "thought": true });
+    if !thinking_text.is_empty() {
+        if let Some(part) =
part.as_object_mut() { + part.insert("text".to_string(), Value::String(thinking_text)); + } + } + if !signature.is_empty() { + if let Some(part) = part.as_object_mut() { + part.insert("thoughtSignature".to_string(), Value::String(signature)); + } + } + parts.push(part); +} + +fn resolve_thinking_signature( + model_name: &str, + thinking_text: &str, + item: &Map, +) -> String { + let cached = signature_cache::get_cached_signature(model_name, thinking_text); + if !cached.is_empty() { + return cached; + } + let signature = item.get("signature").and_then(Value::as_str).unwrap_or(""); + parse_client_signature(model_name, signature) +} + +fn parse_client_signature(model_name: &str, signature: &str) -> String { + if signature.contains('#') { + let mut parts = signature.splitn(2, '#'); + let prefix = parts.next().unwrap_or(""); + let value = parts.next().unwrap_or(""); + if prefix == model_name { + return value.to_string(); + } + } + signature.to_string() +} + +fn tool_use_to_part( + item: &Map, + model_name: &str, + current_signature: &str, +) -> Option { + let name = item.get("name").and_then(Value::as_str).unwrap_or(""); + let id = item.get("id").and_then(Value::as_str).unwrap_or(""); + let args_raw = parse_tool_use_input(item.get("input"))?; + + let mut part = json!({ + "functionCall": { + "name": name, + "args": args_raw + } + }); + if !id.is_empty() { + if let Some(call) = part.get_mut("functionCall").and_then(Value::as_object_mut) { + call.insert("id".to_string(), Value::String(id.to_string())); + } + } + + let signature = if signature_cache::has_valid_signature(model_name, current_signature) { + current_signature.to_string() + } else { + // Antigravity requires thoughtSignature for tool calls; use sentinel when missing. + THOUGHT_SIGNATURE_SENTINEL.to_string() + }; + if let Some(part) = part.as_object_mut() { + part.insert("thoughtSignature".to_string(), Value::String(signature)); + } + Some(part) +} + +fn parse_tool_use_input(input: Option<&Value>) -> Option { + match input { + Some(Value::Object(object)) => Some(Value::Object(object.clone())), + Some(Value::String(raw)) => serde_json::from_str::(raw).ok().and_then(|val| { + if val.is_object() { + Some(val) + } else { + None + } + }), + _ => None, + } +} + +fn tool_result_to_part(item: &Map) -> Option { + let tool_call_id = item.get("tool_use_id").and_then(Value::as_str).unwrap_or(""); + if tool_call_id.is_empty() { + return None; + } + let func_name = tool_call_name_from_id(tool_call_id); + let response = tool_result_response(item.get("content")); + Some(json!({ + "functionResponse": { + "id": tool_call_id, + "name": func_name, + "response": { "result": response } + } + })) +} + +fn tool_call_name_from_id(tool_call_id: &str) -> String { + let parts = tool_call_id.split('-').collect::>(); + if parts.len() <= 2 { + return tool_call_id.to_string(); + } + parts[..parts.len() - 2].join("-") +} + +fn tool_result_response(value: Option<&Value>) -> Value { + match value { + Some(Value::String(text)) => Value::String(text.to_string()), + Some(Value::Array(items)) => { + if items.len() == 1 { + items[0].clone() + } else { + Value::Array(items.clone()) + } + } + Some(Value::Object(object)) => Value::Object(object.clone()), + Some(other) => other.clone(), + None => Value::String(String::new()), + } +} + +fn image_to_part(item: &Map) -> Option { + let source = item.get("source").and_then(Value::as_object)?; + if source.get("type").and_then(Value::as_str) != Some("base64") { + return None; + } + let media_type = source + .get("media_type") + 
.and_then(Value::as_str) + .unwrap_or("image/png"); + let data = source.get("data").and_then(Value::as_str)?; + Some(json!({ + "inlineData": { + "mime_type": media_type, + "data": data + } + })) +} + +fn push_text_part(text: &str, parts: &mut Vec) { + if !text.is_empty() { + parts.push(json!({ "text": text })); + } +} + +fn reorder_thinking_parts(role: &str, parts: &mut Vec) { + if role != "model" || parts.is_empty() { + return; + } + let mut thinking = Vec::new(); + let mut others = Vec::new(); + for part in parts.iter() { + if part.get("thought").and_then(Value::as_bool) == Some(true) { + thinking.push(part.clone()); + } else { + others.push(part.clone()); + } + } + if thinking.is_empty() { + return; + } + let first_is_thinking = parts + .first() + .and_then(|part| part.get("thought").and_then(Value::as_bool)) + .unwrap_or(false); + if first_is_thinking && thinking.len() <= 1 { + return; + } + parts.clear(); + parts.extend(thinking); + parts.extend(others); +} + +fn build_tools(object: &Map) -> Option { + let tools = object.get("tools").and_then(Value::as_array)?; + let mut decls = Vec::new(); + for tool in tools { + let Some(tool) = tool.as_object() else { + continue; + }; + let input_schema = tool.get("input_schema"); + let Some(schema) = input_schema.and_then(Value::as_object) else { + continue; + }; + let mut tool_obj = Map::new(); + for (key, value) in tool.iter() { + if key == "input_schema" { + continue; + } + if is_allowed_tool_key(key) { + tool_obj.insert(key.to_string(), value.clone()); + } + } + let mut schema_value = Value::Object(schema.clone()); + clean_json_schema_for_antigravity(&mut schema_value); + tool_obj.insert("parametersJsonSchema".to_string(), schema_value); + decls.push(Value::Object(tool_obj)); + } + if decls.is_empty() { + None + } else { + Some(json!([{ "functionDeclarations": decls }])) + } +} + +fn is_allowed_tool_key(key: &str) -> bool { + matches!( + key, + "name" + | "description" + | "behavior" + | "parameters" + | "parametersJsonSchema" + | "response" + | "responseJsonSchema" + ) +} + +fn build_generation_config(object: &Map, enable_thinking: bool) -> Option { + let mut gen = Map::new(); + if enable_thinking { + if let Some(thinking) = object.get("thinking").and_then(Value::as_object) { + if thinking.get("type").and_then(Value::as_str) == Some("enabled") { + if let Some(budget) = thinking.get("budget_tokens").and_then(Value::as_i64) { + gen.insert( + "thinkingConfig".to_string(), + json!({ + "thinkingBudget": budget, + "includeThoughts": true + }), + ); + } + } + } + } + if let Some(value) = object.get("temperature").and_then(Value::as_f64) { + gen.insert("temperature".to_string(), json!(value)); + } + if let Some(value) = object.get("top_p").and_then(Value::as_f64) { + gen.insert("topP".to_string(), json!(value)); + } + if let Some(value) = object.get("top_k").and_then(Value::as_i64) { + gen.insert("topK".to_string(), json!(value)); + } + if let Some(value) = object.get("max_tokens").and_then(Value::as_i64) { + gen.insert("maxOutputTokens".to_string(), json!(value)); + } + if gen.is_empty() { + None + } else { + Some(Value::Object(gen)) + } +} + +fn extract_text_value(value: Option<&Value>) -> Option { + match value { + Some(Value::String(text)) => Some(text.to_string()), + Some(Value::Object(object)) => { + if let Some(text) = object.get("text") { + return extract_text_value(Some(text)); + } + if let Some(text) = object.get("value") { + return extract_text_value(Some(text)); + } + None + } + _ => None, + } +} + +#[cfg(test)] +#[path = 
"claude.test.rs"] +mod tests; diff --git a/crates/token_proxy_core/src/proxy/antigravity_compat/claude.test.rs b/crates/token_proxy_core/src/proxy/antigravity_compat/claude.test.rs new file mode 100644 index 0000000..4451d9c --- /dev/null +++ b/crates/token_proxy_core/src/proxy/antigravity_compat/claude.test.rs @@ -0,0 +1,85 @@ +use axum::body::Bytes; +use serde_json::Value; + +use super::claude_request_to_antigravity; + +fn parse_output(bytes: Bytes) -> Value { + serde_json::from_slice(&bytes).expect("json") +} + +#[test] +fn converts_basic_structure_and_system() { + let input = Bytes::from( + r#"{ + "model": "claude-3-5-sonnet-20240620", + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + ], + "system": [{"type": "text", "text": "You are helpful"}] + }"#, + ); + let output = parse_output(claude_request_to_antigravity(&input, None).expect("convert")); + assert_eq!(output["contents"][0]["role"], "user"); + assert_eq!( + output["systemInstruction"]["parts"][0]["text"], + "You are helpful" + ); +} + +#[test] +fn drops_non_string_text_blocks() { + let input = Bytes::from( + r#"{ + "model": "claude-3-5-sonnet-20240620", + "messages": [ + {"role": "user", "content": [{"type": "text", "text": {"text": "hi"}}]} + ] + }"#, + ); + let output = parse_output(claude_request_to_antigravity(&input, None).expect("convert")); + let parts = output["contents"][0]["parts"].as_array().unwrap(); + assert!(parts.is_empty()); +} + +#[test] +fn tool_use_adds_skip_signature() { + let input = Bytes::from( + r#"{ + "model": "claude-3-5-sonnet-20240620", + "messages": [ + { + "role": "assistant", + "content": [ + {"type": "tool_use", "id": "call_123", "name": "get_weather", "input": "{\"location\":\"Paris\"}"} + ] + } + ] + }"#, + ); + let output = parse_output(claude_request_to_antigravity(&input, None).expect("convert")); + let part = &output["contents"][0]["parts"][0]; + assert_eq!(part["functionCall"]["name"], "get_weather"); + assert_eq!(part["thoughtSignature"], "skip_thought_signature_validator"); +} + +#[test] +fn unsigned_thinking_is_removed() { + let input = Bytes::from( + r#"{ + "model": "claude-sonnet-4-5-thinking", + "messages": [ + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "Let me think"}, + {"type": "text", "text": "Answer"} + ] + } + ] + }"#, + ); + let output = parse_output(claude_request_to_antigravity(&input, None).expect("convert")); + let parts = output["contents"][0]["parts"].as_array().unwrap(); + assert_eq!(parts.len(), 1); + assert_eq!(parts[0]["text"], "Answer"); +} diff --git a/crates/token_proxy_core/src/proxy/antigravity_compat/signature_cache.rs b/crates/token_proxy_core/src/proxy/antigravity_compat/signature_cache.rs new file mode 100644 index 0000000..a9095da --- /dev/null +++ b/crates/token_proxy_core/src/proxy/antigravity_compat/signature_cache.rs @@ -0,0 +1,107 @@ +use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use std::sync::{Mutex, OnceLock}; +use std::time::{Duration, Instant}; + +const SIGNATURE_CACHE_TTL: Duration = Duration::from_secs(3 * 60 * 60); +const SIGNATURE_TEXT_HASH_LEN: usize = 16; +const MIN_VALID_SIGNATURE_LEN: usize = 50; +const GEMINI_SKIP_SENTINEL: &str = "skip_thought_signature_validator"; + +type Cache = HashMap>; + +#[derive(Clone)] +struct SignatureEntry { + signature: String, + touched: Instant, +} + +static SIGNATURE_CACHE: OnceLock> = OnceLock::new(); + +fn cache_lock() -> std::sync::MutexGuard<'static, Cache> { + SIGNATURE_CACHE + .get_or_init(|| Mutex::new(HashMap::new())) + 
.lock() + .unwrap_or_else(|err| err.into_inner()) +} + +pub(crate) fn cache_signature(model_name: &str, text: &str, signature: &str) { + if text.trim().is_empty() || signature.trim().is_empty() { + return; + } + if signature.len() < MIN_VALID_SIGNATURE_LEN { + return; + } + let group_key = model_group_key(model_name); + let text_hash = hash_text(text); + let mut cache = cache_lock(); + let group = cache.entry(group_key).or_insert_with(HashMap::new); + group.insert( + text_hash, + SignatureEntry { + signature: signature.to_string(), + touched: Instant::now(), + }, + ); +} + +pub(crate) fn get_cached_signature(model_name: &str, text: &str) -> String { + let group_key = model_group_key(model_name); + if text.trim().is_empty() { + return fallback_signature(&group_key); + } + let text_hash = hash_text(text); + let mut cache = cache_lock(); + let Some(group) = cache.get_mut(&group_key) else { + return fallback_signature(&group_key); + }; + let Some(entry) = group.get_mut(&text_hash) else { + return fallback_signature(&group_key); + }; + if entry.touched.elapsed() > SIGNATURE_CACHE_TTL { + group.remove(&text_hash); + return fallback_signature(&group_key); + } + entry.touched = Instant::now(); + entry.signature.clone() +} + +pub(crate) fn has_valid_signature(model_name: &str, signature: &str) -> bool { + if signature.trim().is_empty() { + return false; + } + if signature == GEMINI_SKIP_SENTINEL { + return model_group_key(model_name) == "gemini"; + } + signature.len() >= MIN_VALID_SIGNATURE_LEN +} + +fn fallback_signature(group_key: &str) -> String { + if group_key == "gemini" { + GEMINI_SKIP_SENTINEL.to_string() + } else { + String::new() + } +} + +fn model_group_key(model_name: &str) -> String { + let lower = model_name.to_lowercase(); + if lower.contains("gpt") { + return "gpt".to_string(); + } + if lower.contains("claude") { + return "claude".to_string(); + } + if lower.contains("gemini") { + return "gemini".to_string(); + } + model_name.trim().to_string() +} + +fn hash_text(text: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(text.as_bytes()); + let digest = hasher.finalize(); + let hex = format!("{:x}", digest); + hex.chars().take(SIGNATURE_TEXT_HASH_LEN).collect() +} diff --git a/crates/token_proxy_core/src/proxy/config/normalize.rs b/crates/token_proxy_core/src/proxy/config/normalize.rs index 261cefb..d3fa44d 100644 --- a/crates/token_proxy_core/src/proxy/config/normalize.rs +++ b/crates/token_proxy_core/src/proxy/config/normalize.rs @@ -47,16 +47,18 @@ pub(super) fn build_provider_upstreams( Ok(output) } -fn group_upstreams_by_priority(mut upstreams: Vec) -> Vec { - upstreams.sort_by(|left, right| right.priority.cmp(&left.priority)); - let mut groups: Vec = Vec::new(); +fn group_upstreams_by_priority(upstreams: Vec) -> Vec { + // Keep same-priority order stable by preserving config insertion order. 
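    // Behavior sketch (hypothetical data, simplified element type, mirroring the logic
    // below): priorities sort descending while entries sharing a priority keep their
    // original config order, e.g.
    //
    //     fn group_by_priority_demo(items: Vec<(i64, &str)>) -> Vec<(i64, Vec<&str>)> {
    //         use std::collections::HashMap;
    //         let mut grouped: HashMap<i64, Vec<&str>> = HashMap::new();
    //         for (priority, name) in items {
    //             grouped.entry(priority).or_default().push(name);
    //         }
    //         let mut priorities: Vec<i64> = grouped.keys().copied().collect();
    //         priorities.sort_by(|left, right| right.cmp(left));
    //         priorities
    //             .into_iter()
    //             .filter_map(|priority| grouped.remove(&priority).map(|names| (priority, names)))
    //             .collect()
    //     }
    //
    //     group_by_priority_demo(vec![(1, "a"), (2, "b"), (1, "c")])
    //         == vec![(2, vec!["b"]), (1, vec!["a", "c"])]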
+ let mut grouped: HashMap> = HashMap::new(); for upstream in upstreams { - match groups.last_mut() { - Some(group) if group.priority == upstream.priority => group.items.push(upstream), - _ => groups.push(UpstreamGroup { - priority: upstream.priority, - items: vec![upstream], - }), + grouped.entry(upstream.priority).or_default().push(upstream); + } + let mut priorities: Vec = grouped.keys().copied().collect(); + priorities.sort_by(|left, right| right.cmp(left)); + let mut groups = Vec::with_capacity(priorities.len()); + for priority in priorities { + if let Some(items) = grouped.remove(&priority) { + groups.push(UpstreamGroup { priority, items }); } } groups diff --git a/crates/token_proxy_core/src/proxy/response/dispatch/buffered.rs b/crates/token_proxy_core/src/proxy/response/dispatch/buffered.rs index 3dbe740..20664d0 100644 --- a/crates/token_proxy_core/src/proxy/response/dispatch/buffered.rs +++ b/crates/token_proxy_core/src/proxy/response/dispatch/buffered.rs @@ -15,12 +15,16 @@ use super::super::super::{ log::{build_log_entry, LogContext, LogWriter, UsageSnapshot}, model, openai_compat::{transform_response_body, FormatTransform}, + request_body::ReplayableBody, redact::redact_query_param_value, + server_helpers::log_debug_headers_body, token_rate::RequestTokenTracker, usage::extract_usage_from_response, UPSTREAM_NO_DATA_TIMEOUT, }; +const DEBUG_BODY_LOG_LIMIT_BYTES: usize = usize::MAX; + pub(super) async fn build_buffered_response( status: StatusCode, upstream_res: reqwest::Response, @@ -33,10 +37,18 @@ pub(super) async fn build_buffered_response( estimated_input_tokens: Option, ) -> Response { let mut context = context; + let response_headers = upstream_res.headers().clone(); let bytes = match read_upstream_bytes(upstream_res, &mut context, &log).await { Ok(bytes) => bytes, Err(response) => return response, }; + log_debug_headers_body( + "upstream.response.raw", + Some(&response_headers), + Some(&ReplayableBody::from_bytes(bytes.clone())), + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; let bytes = if context.provider == PROVIDER_ANTIGRAVITY { match antigravity_compat::unwrap_response(&bytes) { Ok(unwrapped) => unwrapped, @@ -47,6 +59,15 @@ pub(super) async fn build_buffered_response( } else { bytes }; + if context.provider == PROVIDER_ANTIGRAVITY { + log_debug_headers_body( + "upstream.response.unwrapped", + Some(&response_headers), + Some(&ReplayableBody::from_bytes(bytes.clone())), + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; + } let mut usage = extract_usage_from_response(&bytes); let response_error = response_error_for_status(status, &bytes); let request_body = context.request_body.clone(); @@ -74,6 +95,13 @@ pub(super) async fn build_buffered_response( log.clone().write_detached(entry); let output = maybe_override_response_model(output, model_override); + log_debug_headers_body( + "outbound.response", + Some(&headers), + Some(&ReplayableBody::from_bytes(output.clone())), + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; let provider_for_tokens = provider_for_tokens(response_transform, context.provider.as_str()); token_count::apply_output_tokens_from_response(&request_tracker, provider_for_tokens, &output).await; diff --git a/crates/token_proxy_core/src/proxy/response/dispatch/stream.rs b/crates/token_proxy_core/src/proxy/response/dispatch/stream.rs index 602002e..d9a4c50 100644 --- a/crates/token_proxy_core/src/proxy/response/dispatch/stream.rs +++ b/crates/token_proxy_core/src/proxy/response/dispatch/stream.rs @@ -16,6 +16,7 @@ use super::super::super::{ log::{build_log_entry, 
LogContext, LogWriter, UsageSnapshot}, openai_compat::FormatTransform, redact::redact_query_param_value, + server_helpers::log_debug_headers_body, token_rate::RequestTokenTracker, UPSTREAM_NO_DATA_TIMEOUT, }; @@ -25,6 +26,7 @@ type UpstreamBytesStream = futures_util::stream::BoxStream< Result>, >; type ResponseStream = futures_util::stream::BoxStream<'static, Result>; +const DEBUG_BODY_LOG_LIMIT_BYTES: usize = usize::MAX; pub(super) async fn build_stream_response( status: StatusCode, @@ -43,11 +45,19 @@ pub(super) async fn build_stream_response( Ok(stream) => stream, Err(response) => return response, }; + log_debug_headers_body( + "upstream.response.headers", + Some(&headers), + None, + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; let upstream = if context.provider == PROVIDER_ANTIGRAVITY { antigravity_compat::stream_antigravity_to_gemini(upstream).boxed() } else { upstream }; + let upstream = log_upstream_stream_if_debug(upstream); let stream = stream_for_transform( response_transform, @@ -58,6 +68,14 @@ pub(super) async fn build_stream_response( estimated_input_tokens, model_override, ); + log_debug_headers_body( + "outbound.response.headers", + Some(&headers), + None, + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; + let stream = log_response_stream_if_debug(stream); let body = Body::from_stream(stream); http::build_response(status, headers, body) } @@ -450,6 +468,50 @@ fn stream_error_response( http::error_response(status, message) } +fn log_upstream_stream_if_debug(upstream: UpstreamBytesStream) -> UpstreamBytesStream { + if !tracing::enabled!(tracing::Level::DEBUG) { + return upstream; + } + upstream + .map(|item| { + if let Ok(chunk) = &item { + let text = String::from_utf8_lossy(chunk); + tracing::debug!( + stage = "upstream.response.chunk", + bytes = chunk.len(), + body = %text, + "debug dump" + ); + } else if let Err(err) = &item { + tracing::debug!(stage = "upstream.response.chunk.error", error = %err, "debug dump"); + } + item + }) + .boxed() +} + +fn log_response_stream_if_debug(stream: ResponseStream) -> ResponseStream { + if !tracing::enabled!(tracing::Level::DEBUG) { + return stream; + } + stream + .map(|item| { + if let Ok(chunk) = &item { + let text = String::from_utf8_lossy(chunk); + tracing::debug!( + stage = "outbound.response.chunk", + bytes = chunk.len(), + body = %text, + "debug dump" + ); + } else if let Err(err) = &item { + tracing::debug!(stage = "outbound.response.chunk.error", error = %err, "debug dump"); + } + item + }) + .boxed() +} + fn stream_with_optional_model_override( upstream: impl futures_util::stream::Stream> + Unpin + Send + 'static, context: LogContext, diff --git a/crates/token_proxy_core/src/proxy/server.rs b/crates/token_proxy_core/src/proxy/server.rs index b08696b..5023762 100644 --- a/crates/token_proxy_core/src/proxy/server.rs +++ b/crates/token_proxy_core/src/proxy/server.rs @@ -580,19 +580,54 @@ async fn build_outbound_body_or_respond( body: ReplayableBody, request_start: Instant, ) -> Result { - let body = match maybe_transform_request_body( + let body = transform_body_or_respond( http_clients, + log, + request_detail.clone(), + path, + plan, + meta, + body, + request_start, + ) + .await?; + apply_openai_stream_options_or_respond( + log, + request_detail, + path, + plan, + meta, + body, + request_start, + ) + .await +} + +async fn transform_body_or_respond( + http_clients: &super::http_client::ProxyHttpClients, + log: &Arc, + request_detail: Option, + path: &str, + plan: &DispatchPlan, + meta: &RequestMeta, + body: ReplayableBody, + 
request_start: Instant, +) -> Result { + match maybe_transform_request_body( + http_clients, + plan.provider, + path, plan.request_transform, meta.original_model.as_deref(), body, ) .await { - Ok(body) => body, + Ok(body) => Ok(body), Err(err) => { log_request_error( log, - request_detail.clone(), + request_detail, path, plan.provider, LOCAL_UPSTREAM_ID, @@ -600,10 +635,20 @@ async fn build_outbound_body_or_respond( err.message.clone(), request_start, ); - return Err(http::error_response(err.status, err.message)); + Err(http::error_response(err.status, err.message)) } - }; + } +} +async fn apply_openai_stream_options_or_respond( + log: &Arc, + request_detail: Option, + path: &str, + plan: &DispatchPlan, + meta: &RequestMeta, + body: ReplayableBody, + request_start: Instant, +) -> Result { match maybe_force_openai_stream_options_include_usage( plan.provider, plan.outbound_path.unwrap_or(path), diff --git a/crates/token_proxy_core/src/proxy/server_helpers.rs b/crates/token_proxy_core/src/proxy/server_helpers.rs index 34b54ff..b285c51 100644 --- a/crates/token_proxy_core/src/proxy/server_helpers.rs +++ b/crates/token_proxy_core/src/proxy/server_helpers.rs @@ -5,6 +5,7 @@ use axum::{ use serde_json::{Map, Value}; use super::{ + antigravity_compat, gemini, http_client::ProxyHttpClients, openai_compat::{ @@ -18,10 +19,11 @@ use super::{ const ANTHROPIC_MESSAGES_PREFIX: &str = "/v1/messages"; const ANTHROPIC_COMPLETE_PATH: &str = "/v1/complete"; +const PROVIDER_ANTIGRAVITY: &str = "antigravity"; const REQUEST_META_LIMIT_BYTES: usize = 2 * 1024 * 1024; // Format conversion needs the full JSON body; allow up to the default max_request_body_bytes (20 MiB). const REQUEST_TRANSFORM_LIMIT_BYTES: usize = 20 * 1024 * 1024; -const DEBUG_BODY_LOG_LIMIT_BYTES: usize = 64 * 1024; +const DEBUG_BODY_LOG_LIMIT_BYTES: usize = usize::MAX; const OPENAI_REASONING_MODEL_SUFFIX_PREFIX: &str = "-reasoning-"; #[derive(Debug)] @@ -162,49 +164,62 @@ fn ensure_stream_options_include_usage(object: &mut Map) -> bool } pub(crate) async fn log_debug_request(headers: &HeaderMap, body: &ReplayableBody) { - let header_snapshot: Vec<(String, String)> = headers - .iter() - .map(|(name, value)| { - let redacted = if is_sensitive_header(name.as_str()) { - "***".to_string() - } else { - value.to_str().unwrap_or("").to_string() - }; - (name.to_string(), redacted) - }) - .collect(); + log_debug_headers_body( + "inbound.request", + Some(headers), + Some(body), + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; +} - let body_text = match body.read_bytes_if_small(DEBUG_BODY_LOG_LIMIT_BYTES).await { - Ok(Some(bytes)) => { - let text = String::from_utf8_lossy(&bytes); - Some(text.into_owned()) - } - Ok(None) => None, - Err(err) => { - tracing::debug!(error = %err, "debug body read failed"); - None +pub(crate) async fn log_debug_headers_body( + stage: &str, + headers: Option<&HeaderMap>, + body: Option<&ReplayableBody>, + max_body_bytes: usize, +) { + if !tracing::enabled!(tracing::Level::DEBUG) { + return; + } + + let header_snapshot = headers + .map(snapshot_headers_raw) + .unwrap_or_default(); + let body_text = if let Some(body) = body { + match body.read_bytes_if_small(max_body_bytes).await { + Ok(Some(bytes)) => Some(String::from_utf8_lossy(&bytes).into_owned()), + Ok(None) => Some(format!("[body omitted: larger than {max_body_bytes} bytes]")), + Err(err) => Some(format!("[body read failed: {err}]") ), } + } else { + None }; match body_text { Some(text) => { - tracing::debug!(headers = ?header_snapshot, body = %text, "incoming request debug 
dump"); + tracing::debug!(stage, headers = ?header_snapshot, body = %text, "debug dump"); } None => { - tracing::debug!(headers = ?header_snapshot, "incoming request body omitted (too large)"); + tracing::debug!(stage, headers = ?header_snapshot, "debug dump (no body)"); } } } -fn is_sensitive_header(name: &str) -> bool { - matches!( - name.to_ascii_lowercase().as_str(), - "authorization" | "proxy-authorization" | "x-api-key" - ) +fn snapshot_headers_raw(headers: &HeaderMap) -> Vec<(String, String)> { + headers + .iter() + .map(|(name, value)| { + let value = value.to_str().unwrap_or("").to_string(); + (name.to_string(), value) + }) + .collect() } pub(crate) async fn maybe_transform_request_body( http_clients: &ProxyHttpClients, + provider: &str, + path: &str, transform: FormatTransform, model_hint: Option<&str>, body: ReplayableBody, @@ -228,11 +243,39 @@ pub(crate) async fn maybe_transform_request_body( "Request body is too large to transform.", )); }; + let inbound_body = ReplayableBody::from_bytes(bytes.clone()); + log_debug_headers_body( + "transform.input", + None, + Some(&inbound_body), + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; + + let outbound_bytes = if should_use_antigravity_claude(provider, path, transform) { + antigravity_compat::claude_request_to_antigravity(&bytes, model_hint) + .map_err(|message| RequestError::new(StatusCode::BAD_REQUEST, message))? + } else { + transform_request_body(transform, &bytes, http_clients, model_hint) + .await + .map_err(|message| RequestError::new(StatusCode::BAD_REQUEST, message))? + }; + let outbound_body = ReplayableBody::from_bytes(outbound_bytes); + log_debug_headers_body( + "transform.output", + None, + Some(&outbound_body), + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; + Ok(outbound_body) +} - let outbound_bytes = transform_request_body(transform, &bytes, http_clients, model_hint) - .await - .map_err(|message| RequestError::new(StatusCode::BAD_REQUEST, message))?; - Ok(ReplayableBody::from_bytes(outbound_bytes)) +fn should_use_antigravity_claude(provider: &str, path: &str, transform: FormatTransform) -> bool { + // Align Antigravity with CLIProxyAPIPlus for Claude /v1/messages. 
+ provider == PROVIDER_ANTIGRAVITY + && path == ANTHROPIC_MESSAGES_PREFIX + && transform == FormatTransform::AnthropicToGemini } pub(crate) async fn maybe_force_openai_stream_options_include_usage( @@ -277,7 +320,15 @@ pub(crate) async fn maybe_force_openai_stream_options_include_usage( format!("Failed to serialize request: {err}"), ) })?; - Ok(ReplayableBody::from_bytes(outbound_bytes)) + let outbound_body = ReplayableBody::from_bytes(outbound_bytes); + log_debug_headers_body( + "stream_options.output", + None, + Some(&outbound_body), + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; + Ok(outbound_body) } pub(crate) async fn maybe_rewrite_openai_reasoning_effort_from_model_suffix( diff --git a/crates/token_proxy_core/src/proxy/upstream/attempt.rs b/crates/token_proxy_core/src/proxy/upstream/attempt.rs index 446cc90..3686dfb 100644 --- a/crates/token_proxy_core/src/proxy/upstream/attempt.rs +++ b/crates/token_proxy_core/src/proxy/upstream/attempt.rs @@ -16,9 +16,12 @@ use crate::proxy::http; use crate::proxy::openai_compat::FormatTransform; use crate::proxy::request_detail::RequestDetailSnapshot; use crate::proxy::request_body::ReplayableBody; +use crate::proxy::server_helpers::log_debug_headers_body; use crate::proxy::{config::UpstreamRuntime, ProxyState, RequestMeta}; use crate::proxy::{UPSTREAM_NO_DATA_TIMEOUT}; +const DEBUG_UPSTREAM_LOG_LIMIT_BYTES: usize = usize::MAX; + pub(super) async fn attempt_upstream( state: &ProxyState, method: Method, @@ -317,6 +320,13 @@ async fn send_antigravity_with_fallback( ) -> Result { let urls = antigravity_fallback_urls(&upstream.base_url, upstream_path_with_query); let request_headers = antigravity_request_headers(request_headers, meta, antigravity); + log_debug_headers_body( + "upstream.request", + Some(&request_headers), + Some(body), + DEBUG_UPSTREAM_LOG_LIMIT_BYTES, + ) + .await; let mut last_transport: Option = None; let mut saw_timeout = false; for (idx, url) in urls.iter().enumerate() { @@ -454,6 +464,13 @@ async fn send_upstream_request_once( request_detail: Option<&RequestDetailSnapshot>, start_time: Instant, ) -> Result { + log_debug_headers_body( + "upstream.request", + Some(request_headers), + Some(body), + DEBUG_UPSTREAM_LOG_LIMIT_BYTES, + ) + .await; let client = state .http_clients .client_for_proxy_url(upstream.proxy_url.as_deref()) @@ -569,6 +586,13 @@ async fn send_codex_attempt( start_time: Instant, attempt: &CodexSendAttempt, ) -> Result { + log_debug_headers_body( + "upstream.request", + Some(request_headers), + Some(body), + DEBUG_UPSTREAM_LOG_LIMIT_BYTES, + ) + .await; let client = build_codex_client(attempt.proxy_url.as_deref(), attempt.http1_only).map_err(|message| { CodexAttemptError::Fatal(AttemptOutcome::Fatal(http::error_response(StatusCode::BAD_GATEWAY, message))) })?; diff --git a/crates/token_proxy_core/src/proxy/upstream/request.rs b/crates/token_proxy_core/src/proxy/upstream/request.rs index 65f6646..a4583af 100644 --- a/crates/token_proxy_core/src/proxy/upstream/request.rs +++ b/crates/token_proxy_core/src/proxy/upstream/request.rs @@ -20,6 +20,7 @@ use super::super::{ RequestMeta, }; use super::super::http::RequestAuth; +use crate::proxy::server_helpers::log_debug_headers_body; const ANTHROPIC_VERSION_HEADER: &str = "anthropic-version"; const DEFAULT_ANTHROPIC_VERSION: &str = "2023-06-01"; @@ -28,6 +29,7 @@ const GEMINI_API_KEY_HEADER: HeaderName = HeaderName::from_static("x-goog-api-ke const OPENAI_RESPONSES_PATH: &str = "/v1/responses"; // Keep in sync with server_helpers request transform limit (20 MiB). 
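// Sketch of one way to avoid the manual "keep in sync" comment (an assumed refactor,
// not part of this diff): hoist the limit into a shared module so the two 20 MiB
// constants cannot drift apart.
pub(crate) mod limits {
    pub(crate) const REQUEST_TRANSFORM_LIMIT_BYTES: usize = 20 * 1024 * 1024;
}
// server_helpers.rs:   use crate::proxy::limits::REQUEST_TRANSFORM_LIMIT_BYTES;
// upstream/request.rs: use crate::proxy::limits::REQUEST_TRANSFORM_LIMIT_BYTES;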
const REQUEST_FILTER_LIMIT_BYTES: usize = 20 * 1024 * 1024; +const DEBUG_UPSTREAM_LOG_LIMIT_BYTES: usize = usize::MAX; pub(super) fn split_path_query(path_with_query: &str) -> (&str, Option<&str>) { match path_with_query.split_once('?') { @@ -228,6 +230,14 @@ async fn build_antigravity_body( .map_err(|message| { AttemptOutcome::Fatal(http::error_response(StatusCode::BAD_GATEWAY, message)) })?; + let wrapped_body = ReplayableBody::from_bytes(wrapped.clone()); + log_debug_headers_body( + "antigravity.wrapped", + None, + Some(&wrapped_body), + DEBUG_UPSTREAM_LOG_LIMIT_BYTES, + ) + .await; Ok(reqwest::Body::from(wrapped)) } diff --git a/messages/en.json b/messages/en.json index 9464669..1a7b286 100644 --- a/messages/en.json +++ b/messages/en.json @@ -375,7 +375,7 @@ "dashboard_hint_success_rate": "Success rate {rate}", "dashboard_hint_error_rate": "Error rate {rate}", "dashboard_tokens_hint_no_cache": "Input {input} · Output {output}", - "dashboard_tokens_hint_with_cache": "Input {input} (cached {cached}) · Output {output}", + "dashboard_tokens_hint_with_cache": "Input {input} · Output {output} · Cached {cached}", "dashboard_latency_hint": "Time to first byte (avg.)", "dashboard_providers_title": "Providers", "dashboard_providers_desc": "Sorted by tokens (Top 10)", diff --git a/messages/zh.json b/messages/zh.json index 231e833..6a00370 100644 --- a/messages/zh.json +++ b/messages/zh.json @@ -376,7 +376,7 @@ "dashboard_hint_success_rate": "成功率 {rate}", "dashboard_hint_error_rate": "错误率 {rate}", "dashboard_tokens_hint_no_cache": "输入 {input} · 输出 {output}", - "dashboard_tokens_hint_with_cache": "输入 {input} (缓存 {cached}) · 输出 {output}", + "dashboard_tokens_hint_with_cache": "输入 {input} · 输出 {output} · 缓存 {cached}", "dashboard_latency_hint": "按请求均值", "dashboard_providers_title": "Providers", "dashboard_providers_desc": "按 Tokens 排序(Top 10)", diff --git a/src-tauri/src/proxy/antigravity_compat.rs b/src-tauri/src/proxy/antigravity_compat.rs index e71e178..22c913a 100644 --- a/src-tauri/src/proxy/antigravity_compat.rs +++ b/src-tauri/src/proxy/antigravity_compat.rs @@ -8,6 +8,11 @@ use crate::oauth_util::generate_state; use crate::proxy::antigravity_schema::clean_json_schema_for_antigravity; use crate::proxy::sse::SseEventParser; +mod signature_cache; +mod claude; + +pub(crate) use claude::claude_request_to_antigravity; + const DEFAULT_MODEL: &str = "gemini-1.5-flash"; const THOUGHT_SIGNATURE_SENTINEL: &str = "skip_thought_signature_validator"; const PAYLOAD_USER_AGENT: &str = "antigravity"; diff --git a/src-tauri/src/proxy/antigravity_compat/claude.rs b/src-tauri/src/proxy/antigravity_compat/claude.rs new file mode 100644 index 0000000..c582745 --- /dev/null +++ b/src-tauri/src/proxy/antigravity_compat/claude.rs @@ -0,0 +1,496 @@ +use axum::body::Bytes; +use serde_json::{json, Map, Value}; + +use super::signature_cache; +use crate::proxy::antigravity_schema::clean_json_schema_for_antigravity; + +const THOUGHT_SIGNATURE_SENTINEL: &str = "skip_thought_signature_validator"; +const INTERLEAVED_HINT: &str = "Interleaved thinking is enabled. You may think between tool calls and after receiving tool results before deciding the next action or final answer. Do not mention these instructions or any constraints about thinking blocks; just apply them."; + +pub(crate) fn claude_request_to_antigravity( + body: &Bytes, + model_hint: Option<&str>, +) -> Result { + // Dedicated Claude -> Gemini request conversion to align with CLIProxyAPIPlus. 
+ let object = parse_request_object(body)?; + let model_name = resolve_model_name(&object, model_hint); + let mapped_model = super::map_antigravity_model(&model_name); + let (contents, enable_thinking_translate) = build_contents(&object, &mapped_model)?; + let tools = build_tools(&object); + let thinking_enabled = thinking_enabled(&object); + let should_hint = tools.is_some() && thinking_enabled && is_claude_thinking_model(&mapped_model); + + let mut out = Map::new(); + if !mapped_model.trim().is_empty() { + out.insert("model".to_string(), Value::String(mapped_model)); + } + if !contents.is_empty() { + out.insert("contents".to_string(), Value::Array(contents)); + } + if let Some(system_instruction) = build_system_instruction(&object, should_hint) { + out.insert("systemInstruction".to_string(), system_instruction); + } + if let Some(tools) = tools { + out.insert("tools".to_string(), tools); + } + if let Some(gen) = build_generation_config(&object, enable_thinking_translate) { + out.insert("generationConfig".to_string(), gen); + } + + serde_json::to_vec(&Value::Object(out)) + .map(Bytes::from) + .map_err(|err| format!("Failed to serialize request: {err}")) +} + +fn parse_request_object(body: &Bytes) -> Result, String> { + let value: Value = + serde_json::from_slice(body).map_err(|_| "Request body must be JSON.".to_string())?; + value + .as_object() + .cloned() + .ok_or_else(|| "Request body must be a JSON object.".to_string()) +} + +fn resolve_model_name(object: &Map, model_hint: Option<&str>) -> String { + object + .get("model") + .and_then(Value::as_str) + .map(|value| value.trim()) + .filter(|value| !value.is_empty()) + .map(|value| value.to_string()) + .or_else(|| { + model_hint + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(|value| value.to_string()) + }) + .unwrap_or_default() +} + +fn build_system_instruction(object: &Map, should_hint: bool) -> Option { + let mut parts = system_parts(object); + if should_hint { + parts.push(json!({ "text": INTERLEAVED_HINT })); + } + if parts.is_empty() { + return None; + } + Some(json!({ "role": "user", "parts": parts })) +} + +fn system_parts(object: &Map) -> Vec { + let Some(system) = object.get("system") else { + return Vec::new(); + }; + match system { + Value::String(text) => system_parts_from_text(text), + Value::Array(items) => items + .iter() + .filter_map(|item| item.as_object()) + .filter(|item| item.get("type").and_then(Value::as_str) == Some("text")) + .filter_map(|item| item.get("text").and_then(Value::as_str)) + .flat_map(system_parts_from_text) + .collect(), + _ => Vec::new(), + } +} + +fn system_parts_from_text(text: &str) -> Vec { + let trimmed = text.trim(); + if trimmed.is_empty() { + Vec::new() + } else { + vec![json!({ "text": trimmed })] + } +} + +fn thinking_enabled(object: &Map) -> bool { + object + .get("thinking") + .and_then(Value::as_object) + .and_then(|thinking| thinking.get("type")) + .and_then(Value::as_str) + == Some("enabled") +} + +fn is_claude_thinking_model(model_name: &str) -> bool { + let lower = model_name.to_lowercase(); + lower.contains("claude") && lower.contains("thinking") +} + +fn build_contents( + object: &Map, + model_name: &str, +) -> Result<(Vec, bool), String> { + let Some(messages) = object.get("messages").and_then(Value::as_array) else { + return Ok((Vec::new(), true)); + }; + let mut contents = Vec::with_capacity(messages.len()); + let mut enable_thinking_translate = true; + + for message in messages { + let Some(message) = message.as_object() else { + continue; + }; + let role = 
message.get("role").and_then(Value::as_str).unwrap_or("user"); + let role = if role == "assistant" { "model" } else { role }; + let mut parts = Vec::new(); + let mut current_signature = String::new(); + match message.get("content") { + Some(Value::Array(items)) => { + for item in items { + let Some(item) = item.as_object() else { + continue; + }; + let block_type = item.get("type").and_then(Value::as_str).unwrap_or(""); + handle_block( + item, + block_type, + model_name, + &mut current_signature, + &mut enable_thinking_translate, + &mut parts, + ); + } + } + Some(Value::String(text)) => push_text_part(text, &mut parts), + _ => {} + } + reorder_thinking_parts(role, &mut parts); + contents.push(json!({ "role": role, "parts": parts })); + } + + Ok((contents, enable_thinking_translate)) +} + +fn handle_block( + item: &Map, + block_type: &str, + model_name: &str, + current_signature: &mut String, + enable_thinking_translate: &mut bool, + parts: &mut Vec, +) { + match block_type { + "thinking" => { + handle_thinking_block(item, model_name, current_signature, enable_thinking_translate, parts); + } + "text" => { + if let Some(text) = item.get("text").and_then(Value::as_str) { + push_text_part(text, parts); + } + } + "tool_use" => { + if let Some(part) = tool_use_to_part(item, model_name, current_signature) { + parts.push(part); + } + } + "tool_result" => { + if let Some(part) = tool_result_to_part(item) { + parts.push(part); + } + } + "image" => { + if let Some(part) = image_to_part(item) { + parts.push(part); + } + } + _ => {} + } +} + +fn handle_thinking_block( + item: &Map, + model_name: &str, + current_signature: &mut String, + enable_thinking_translate: &mut bool, + parts: &mut Vec, +) { + let thinking_text = extract_text_value(item.get("thinking")).unwrap_or_default(); + let signature = resolve_thinking_signature(model_name, &thinking_text, item); + if !signature_cache::has_valid_signature(model_name, &signature) { + *enable_thinking_translate = false; + return; + } + *current_signature = signature.clone(); + if !thinking_text.is_empty() { + signature_cache::cache_signature(model_name, &thinking_text, &signature); + } + let mut part = json!({ "thought": true }); + if !thinking_text.is_empty() { + if let Some(part) = part.as_object_mut() { + part.insert("text".to_string(), Value::String(thinking_text)); + } + } + if !signature.is_empty() { + if let Some(part) = part.as_object_mut() { + part.insert("thoughtSignature".to_string(), Value::String(signature)); + } + } + parts.push(part); +} + +fn resolve_thinking_signature( + model_name: &str, + thinking_text: &str, + item: &Map, +) -> String { + let cached = signature_cache::get_cached_signature(model_name, thinking_text); + if !cached.is_empty() { + return cached; + } + let signature = item.get("signature").and_then(Value::as_str).unwrap_or(""); + parse_client_signature(model_name, signature) +} + +fn parse_client_signature(model_name: &str, signature: &str) -> String { + if signature.contains('#') { + let mut parts = signature.splitn(2, '#'); + let prefix = parts.next().unwrap_or(""); + let value = parts.next().unwrap_or(""); + if prefix == model_name { + return value.to_string(); + } + } + signature.to_string() +} + +fn tool_use_to_part( + item: &Map, + model_name: &str, + current_signature: &str, +) -> Option { + let name = item.get("name").and_then(Value::as_str).unwrap_or(""); + let id = item.get("id").and_then(Value::as_str).unwrap_or(""); + let args_raw = parse_tool_use_input(item.get("input"))?; + + let mut part = json!({ + 
"functionCall": { + "name": name, + "args": args_raw + } + }); + if !id.is_empty() { + if let Some(call) = part.get_mut("functionCall").and_then(Value::as_object_mut) { + call.insert("id".to_string(), Value::String(id.to_string())); + } + } + + let signature = if signature_cache::has_valid_signature(model_name, current_signature) { + current_signature.to_string() + } else { + // Antigravity requires thoughtSignature for tool calls; use sentinel when missing. + THOUGHT_SIGNATURE_SENTINEL.to_string() + }; + if let Some(part) = part.as_object_mut() { + part.insert("thoughtSignature".to_string(), Value::String(signature)); + } + Some(part) +} + +fn parse_tool_use_input(input: Option<&Value>) -> Option { + match input { + Some(Value::Object(object)) => Some(Value::Object(object.clone())), + Some(Value::String(raw)) => serde_json::from_str::(raw).ok().and_then(|val| { + if val.is_object() { + Some(val) + } else { + None + } + }), + _ => None, + } +} + +fn tool_result_to_part(item: &Map) -> Option { + let tool_call_id = item.get("tool_use_id").and_then(Value::as_str).unwrap_or(""); + if tool_call_id.is_empty() { + return None; + } + let func_name = tool_call_name_from_id(tool_call_id); + let response = tool_result_response(item.get("content")); + Some(json!({ + "functionResponse": { + "id": tool_call_id, + "name": func_name, + "response": { "result": response } + } + })) +} + +fn tool_call_name_from_id(tool_call_id: &str) -> String { + let parts = tool_call_id.split('-').collect::>(); + if parts.len() <= 2 { + return tool_call_id.to_string(); + } + parts[..parts.len() - 2].join("-") +} + +fn tool_result_response(value: Option<&Value>) -> Value { + match value { + Some(Value::String(text)) => Value::String(text.to_string()), + Some(Value::Array(items)) => { + if items.len() == 1 { + items[0].clone() + } else { + Value::Array(items.clone()) + } + } + Some(Value::Object(object)) => Value::Object(object.clone()), + Some(other) => other.clone(), + None => Value::String(String::new()), + } +} + +fn image_to_part(item: &Map) -> Option { + let source = item.get("source").and_then(Value::as_object)?; + if source.get("type").and_then(Value::as_str) != Some("base64") { + return None; + } + let media_type = source + .get("media_type") + .and_then(Value::as_str) + .unwrap_or("image/png"); + let data = source.get("data").and_then(Value::as_str)?; + Some(json!({ + "inlineData": { + "mime_type": media_type, + "data": data + } + })) +} + +fn push_text_part(text: &str, parts: &mut Vec) { + if !text.is_empty() { + parts.push(json!({ "text": text })); + } +} + +fn reorder_thinking_parts(role: &str, parts: &mut Vec) { + if role != "model" || parts.is_empty() { + return; + } + let mut thinking = Vec::new(); + let mut others = Vec::new(); + for part in parts.iter() { + if part.get("thought").and_then(Value::as_bool) == Some(true) { + thinking.push(part.clone()); + } else { + others.push(part.clone()); + } + } + if thinking.is_empty() { + return; + } + let first_is_thinking = parts + .first() + .and_then(|part| part.get("thought").and_then(Value::as_bool)) + .unwrap_or(false); + if first_is_thinking && thinking.len() <= 1 { + return; + } + parts.clear(); + parts.extend(thinking); + parts.extend(others); +} + +fn build_tools(object: &Map) -> Option { + let tools = object.get("tools").and_then(Value::as_array)?; + let mut decls = Vec::new(); + for tool in tools { + let Some(tool) = tool.as_object() else { + continue; + }; + let input_schema = tool.get("input_schema"); + let Some(schema) = 
input_schema.and_then(Value::as_object) else { + continue; + }; + let mut tool_obj = Map::new(); + for (key, value) in tool.iter() { + if key == "input_schema" { + continue; + } + if is_allowed_tool_key(key) { + tool_obj.insert(key.to_string(), value.clone()); + } + } + let mut schema_value = Value::Object(schema.clone()); + clean_json_schema_for_antigravity(&mut schema_value); + tool_obj.insert("parametersJsonSchema".to_string(), schema_value); + decls.push(Value::Object(tool_obj)); + } + if decls.is_empty() { + None + } else { + Some(json!([{ "functionDeclarations": decls }])) + } +} + +fn is_allowed_tool_key(key: &str) -> bool { + matches!( + key, + "name" + | "description" + | "behavior" + | "parameters" + | "parametersJsonSchema" + | "response" + | "responseJsonSchema" + ) +} + +fn build_generation_config(object: &Map, enable_thinking: bool) -> Option { + let mut gen = Map::new(); + if enable_thinking { + if let Some(thinking) = object.get("thinking").and_then(Value::as_object) { + if thinking.get("type").and_then(Value::as_str) == Some("enabled") { + if let Some(budget) = thinking.get("budget_tokens").and_then(Value::as_i64) { + gen.insert( + "thinkingConfig".to_string(), + json!({ + "thinkingBudget": budget, + "includeThoughts": true + }), + ); + } + } + } + } + if let Some(value) = object.get("temperature").and_then(Value::as_f64) { + gen.insert("temperature".to_string(), json!(value)); + } + if let Some(value) = object.get("top_p").and_then(Value::as_f64) { + gen.insert("topP".to_string(), json!(value)); + } + if let Some(value) = object.get("top_k").and_then(Value::as_i64) { + gen.insert("topK".to_string(), json!(value)); + } + if let Some(value) = object.get("max_tokens").and_then(Value::as_i64) { + gen.insert("maxOutputTokens".to_string(), json!(value)); + } + if gen.is_empty() { + None + } else { + Some(Value::Object(gen)) + } +} + +fn extract_text_value(value: Option<&Value>) -> Option { + match value { + Some(Value::String(text)) => Some(text.to_string()), + Some(Value::Object(object)) => { + if let Some(text) = object.get("text") { + return extract_text_value(Some(text)); + } + if let Some(text) = object.get("value") { + return extract_text_value(Some(text)); + } + None + } + _ => None, + } +} + +#[cfg(test)] +#[path = "claude.test.rs"] +mod tests; diff --git a/src-tauri/src/proxy/antigravity_compat/signature_cache.rs b/src-tauri/src/proxy/antigravity_compat/signature_cache.rs new file mode 100644 index 0000000..a9095da --- /dev/null +++ b/src-tauri/src/proxy/antigravity_compat/signature_cache.rs @@ -0,0 +1,107 @@ +use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use std::sync::{Mutex, OnceLock}; +use std::time::{Duration, Instant}; + +const SIGNATURE_CACHE_TTL: Duration = Duration::from_secs(3 * 60 * 60); +const SIGNATURE_TEXT_HASH_LEN: usize = 16; +const MIN_VALID_SIGNATURE_LEN: usize = 50; +const GEMINI_SKIP_SENTINEL: &str = "skip_thought_signature_validator"; + +type Cache = HashMap>; + +#[derive(Clone)] +struct SignatureEntry { + signature: String, + touched: Instant, +} + +static SIGNATURE_CACHE: OnceLock> = OnceLock::new(); + +fn cache_lock() -> std::sync::MutexGuard<'static, Cache> { + SIGNATURE_CACHE + .get_or_init(|| Mutex::new(HashMap::new())) + .lock() + .unwrap_or_else(|err| err.into_inner()) +} + +pub(crate) fn cache_signature(model_name: &str, text: &str, signature: &str) { + if text.trim().is_empty() || signature.trim().is_empty() { + return; + } + if signature.len() < MIN_VALID_SIGNATURE_LEN { + return; + } + let group_key = 
model_group_key(model_name); + let text_hash = hash_text(text); + let mut cache = cache_lock(); + let group = cache.entry(group_key).or_insert_with(HashMap::new); + group.insert( + text_hash, + SignatureEntry { + signature: signature.to_string(), + touched: Instant::now(), + }, + ); +} + +pub(crate) fn get_cached_signature(model_name: &str, text: &str) -> String { + let group_key = model_group_key(model_name); + if text.trim().is_empty() { + return fallback_signature(&group_key); + } + let text_hash = hash_text(text); + let mut cache = cache_lock(); + let Some(group) = cache.get_mut(&group_key) else { + return fallback_signature(&group_key); + }; + let Some(entry) = group.get_mut(&text_hash) else { + return fallback_signature(&group_key); + }; + if entry.touched.elapsed() > SIGNATURE_CACHE_TTL { + group.remove(&text_hash); + return fallback_signature(&group_key); + } + entry.touched = Instant::now(); + entry.signature.clone() +} + +pub(crate) fn has_valid_signature(model_name: &str, signature: &str) -> bool { + if signature.trim().is_empty() { + return false; + } + if signature == GEMINI_SKIP_SENTINEL { + return model_group_key(model_name) == "gemini"; + } + signature.len() >= MIN_VALID_SIGNATURE_LEN +} + +fn fallback_signature(group_key: &str) -> String { + if group_key == "gemini" { + GEMINI_SKIP_SENTINEL.to_string() + } else { + String::new() + } +} + +fn model_group_key(model_name: &str) -> String { + let lower = model_name.to_lowercase(); + if lower.contains("gpt") { + return "gpt".to_string(); + } + if lower.contains("claude") { + return "claude".to_string(); + } + if lower.contains("gemini") { + return "gemini".to_string(); + } + model_name.trim().to_string() +} + +fn hash_text(text: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(text.as_bytes()); + let digest = hasher.finalize(); + let hex = format!("{:x}", digest); + hex.chars().take(SIGNATURE_TEXT_HASH_LEN).collect() +} diff --git a/src-tauri/src/proxy/config/normalize.rs b/src-tauri/src/proxy/config/normalize.rs index 9d624b9..97af6d2 100644 --- a/src-tauri/src/proxy/config/normalize.rs +++ b/src-tauri/src/proxy/config/normalize.rs @@ -48,16 +48,18 @@ pub(super) fn build_provider_upstreams( Ok(output) } -fn group_upstreams_by_priority(mut upstreams: Vec) -> Vec { - upstreams.sort_by(|left, right| right.priority.cmp(&left.priority)); - let mut groups: Vec = Vec::new(); +fn group_upstreams_by_priority(upstreams: Vec) -> Vec { + // Keep same-priority order stable by preserving config insertion order. 
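+    // std HashMap iterates in arbitrary order, so priorities are collected and
+    // sorted explicitly below; each per-priority Vec keeps push order, which is
+    // the original config order.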
+ let mut grouped: HashMap> = HashMap::new(); for upstream in upstreams { - match groups.last_mut() { - Some(group) if group.priority == upstream.priority => group.items.push(upstream), - _ => groups.push(UpstreamGroup { - priority: upstream.priority, - items: vec![upstream], - }), + grouped.entry(upstream.priority).or_default().push(upstream); + } + let mut priorities: Vec = grouped.keys().copied().collect(); + priorities.sort_by(|left, right| right.cmp(left)); + let mut groups = Vec::with_capacity(priorities.len()); + for priority in priorities { + if let Some(items) = grouped.remove(&priority) { + groups.push(UpstreamGroup { priority, items }); } } groups diff --git a/src-tauri/src/proxy/response/dispatch/buffered.rs b/src-tauri/src/proxy/response/dispatch/buffered.rs index 016df6e..4a681c8 100644 --- a/src-tauri/src/proxy/response/dispatch/buffered.rs +++ b/src-tauri/src/proxy/response/dispatch/buffered.rs @@ -15,12 +15,16 @@ use super::super::super::{ log::{build_log_entry, LogContext, LogWriter, UsageSnapshot}, model, openai_compat::{transform_response_body, FormatTransform}, + request_body::ReplayableBody, redact::redact_query_param_value, + server_helpers::log_debug_headers_body, token_rate::RequestTokenTracker, usage::extract_usage_from_response, UPSTREAM_NO_DATA_TIMEOUT, }; +const DEBUG_BODY_LOG_LIMIT_BYTES: usize = usize::MAX; + pub(super) async fn build_buffered_response( status: StatusCode, upstream_res: reqwest::Response, @@ -33,10 +37,18 @@ pub(super) async fn build_buffered_response( estimated_input_tokens: Option, ) -> Response { let mut context = context; + let response_headers = upstream_res.headers().clone(); let bytes = match read_upstream_bytes(upstream_res, &mut context, &log).await { Ok(bytes) => bytes, Err(response) => return response, }; + log_debug_headers_body( + "upstream.response.raw", + Some(&response_headers), + Some(&ReplayableBody::from_bytes(bytes.clone())), + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; let bytes = if context.provider == PROVIDER_ANTIGRAVITY { match antigravity_compat::unwrap_response(&bytes) { Ok(unwrapped) => unwrapped, @@ -47,6 +59,15 @@ pub(super) async fn build_buffered_response( } else { bytes }; + if context.provider == PROVIDER_ANTIGRAVITY { + log_debug_headers_body( + "upstream.response.unwrapped", + Some(&response_headers), + Some(&ReplayableBody::from_bytes(bytes.clone())), + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; + } let mut usage = extract_usage_from_response(&bytes); let response_error = response_error_for_status(status, &bytes); let request_body = context.request_body.clone(); @@ -74,6 +95,13 @@ pub(super) async fn build_buffered_response( log.clone().write_detached(entry); let output = maybe_override_response_model(output, model_override); + log_debug_headers_body( + "outbound.response", + Some(&headers), + Some(&ReplayableBody::from_bytes(output.clone())), + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; let provider_for_tokens = provider_for_tokens(response_transform, context.provider.as_str()); token_count::apply_output_tokens_from_response(&request_tracker, provider_for_tokens, &output).await; diff --git a/src-tauri/src/proxy/response/dispatch/stream.rs b/src-tauri/src/proxy/response/dispatch/stream.rs index 0705c89..13c4388 100644 --- a/src-tauri/src/proxy/response/dispatch/stream.rs +++ b/src-tauri/src/proxy/response/dispatch/stream.rs @@ -16,6 +16,7 @@ use super::super::super::{ log::{build_log_entry, LogContext, LogWriter, UsageSnapshot}, openai_compat::FormatTransform, redact::redact_query_param_value, + 
server_helpers::log_debug_headers_body, token_rate::RequestTokenTracker, UPSTREAM_NO_DATA_TIMEOUT, }; @@ -25,6 +26,7 @@ type UpstreamBytesStream = futures_util::stream::BoxStream< Result>, >; type ResponseStream = futures_util::stream::BoxStream<'static, Result>; +const DEBUG_BODY_LOG_LIMIT_BYTES: usize = usize::MAX; pub(super) async fn build_stream_response( status: StatusCode, @@ -43,11 +45,19 @@ pub(super) async fn build_stream_response( Ok(stream) => stream, Err(response) => return response, }; + log_debug_headers_body( + "upstream.response.headers", + Some(&headers), + None, + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; let upstream = if context.provider == PROVIDER_ANTIGRAVITY { antigravity_compat::stream_antigravity_to_gemini(upstream).boxed() } else { upstream }; + let upstream = log_upstream_stream_if_debug(upstream); let stream = stream_for_transform( response_transform, @@ -58,6 +68,14 @@ pub(super) async fn build_stream_response( estimated_input_tokens, model_override, ); + log_debug_headers_body( + "outbound.response.headers", + Some(&headers), + None, + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; + let stream = log_response_stream_if_debug(stream); let body = Body::from_stream(stream); http::build_response(status, headers, body) } @@ -490,6 +508,50 @@ fn stream_error_response( http::error_response(status, message) } +fn log_upstream_stream_if_debug(upstream: UpstreamBytesStream) -> UpstreamBytesStream { + if !tracing::enabled!(tracing::Level::DEBUG) { + return upstream; + } + upstream + .map(|item| { + if let Ok(chunk) = &item { + let text = String::from_utf8_lossy(chunk); + tracing::debug!( + stage = "upstream.response.chunk", + bytes = chunk.len(), + body = %text, + "debug dump" + ); + } else if let Err(err) = &item { + tracing::debug!(stage = "upstream.response.chunk.error", error = %err, "debug dump"); + } + item + }) + .boxed() +} + +fn log_response_stream_if_debug(stream: ResponseStream) -> ResponseStream { + if !tracing::enabled!(tracing::Level::DEBUG) { + return stream; + } + stream + .map(|item| { + if let Ok(chunk) = &item { + let text = String::from_utf8_lossy(chunk); + tracing::debug!( + stage = "outbound.response.chunk", + bytes = chunk.len(), + body = %text, + "debug dump" + ); + } else if let Err(err) = &item { + tracing::debug!(stage = "outbound.response.chunk.error", error = %err, "debug dump"); + } + item + }) + .boxed() +} + fn stream_with_optional_model_override( upstream: impl futures_util::stream::Stream> + Unpin + Send + 'static, context: LogContext, diff --git a/src-tauri/src/proxy/server.rs b/src-tauri/src/proxy/server.rs index 505bc57..43d4eee 100644 --- a/src-tauri/src/proxy/server.rs +++ b/src-tauri/src/proxy/server.rs @@ -574,19 +574,54 @@ async fn build_outbound_body_or_respond( body: ReplayableBody, request_start: Instant, ) -> Result { - let body = match maybe_transform_request_body( + let body = transform_body_or_respond( http_clients, + log, + request_detail.clone(), + path, + plan, + meta, + body, + request_start, + ) + .await?; + apply_openai_stream_options_or_respond( + log, + request_detail, + path, + plan, + meta, + body, + request_start, + ) + .await +} + +async fn transform_body_or_respond( + http_clients: &super::http_client::ProxyHttpClients, + log: &Arc, + request_detail: Option, + path: &str, + plan: &DispatchPlan, + meta: &RequestMeta, + body: ReplayableBody, + request_start: Instant, +) -> Result { + match maybe_transform_request_body( + http_clients, + plan.provider, + path, plan.request_transform, meta.original_model.as_deref(), 
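+        // provider and path are threaded through so the transform layer can
+        // special-case Antigravity's Claude /v1/messages route (see server_helpers).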
body, ) .await { - Ok(body) => body, + Ok(body) => Ok(body), Err(err) => { log_request_error( log, - request_detail.clone(), + request_detail, path, plan.provider, LOCAL_UPSTREAM_ID, @@ -594,10 +629,20 @@ async fn build_outbound_body_or_respond( err.message.clone(), request_start, ); - return Err(http::error_response(err.status, err.message)); + Err(http::error_response(err.status, err.message)) } - }; + } +} +async fn apply_openai_stream_options_or_respond( + log: &Arc, + request_detail: Option, + path: &str, + plan: &DispatchPlan, + meta: &RequestMeta, + body: ReplayableBody, + request_start: Instant, +) -> Result { match maybe_force_openai_stream_options_include_usage( plan.provider, plan.outbound_path.unwrap_or(path), diff --git a/src-tauri/src/proxy/server_helpers.rs b/src-tauri/src/proxy/server_helpers.rs index c0130a7..eadf0ae 100644 --- a/src-tauri/src/proxy/server_helpers.rs +++ b/src-tauri/src/proxy/server_helpers.rs @@ -5,6 +5,7 @@ use axum::{ use serde_json::{Map, Value}; use super::{ + antigravity_compat, gemini, http_client::ProxyHttpClients, openai_compat::{ @@ -18,10 +19,11 @@ use super::{ const ANTHROPIC_MESSAGES_PREFIX: &str = "/v1/messages"; const ANTHROPIC_COMPLETE_PATH: &str = "/v1/complete"; +const PROVIDER_ANTIGRAVITY: &str = "antigravity"; const REQUEST_META_LIMIT_BYTES: usize = 2 * 1024 * 1024; // Format conversion needs the full JSON body; allow up to the default max_request_body_bytes (20 MiB). const REQUEST_TRANSFORM_LIMIT_BYTES: usize = 20 * 1024 * 1024; -const DEBUG_BODY_LOG_LIMIT_BYTES: usize = 64 * 1024; +const DEBUG_BODY_LOG_LIMIT_BYTES: usize = usize::MAX; const OPENAI_REASONING_MODEL_SUFFIX_PREFIX: &str = "-reasoning-"; #[derive(Debug)] @@ -162,49 +164,62 @@ fn ensure_stream_options_include_usage(object: &mut Map) -> bool } pub(crate) async fn log_debug_request(headers: &HeaderMap, body: &ReplayableBody) { - let header_snapshot: Vec<(String, String)> = headers - .iter() - .map(|(name, value)| { - let redacted = if is_sensitive_header(name.as_str()) { - "***".to_string() - } else { - value.to_str().unwrap_or("").to_string() - }; - (name.to_string(), redacted) - }) - .collect(); + log_debug_headers_body( + "inbound.request", + Some(headers), + Some(body), + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; +} - let body_text = match body.read_bytes_if_small(DEBUG_BODY_LOG_LIMIT_BYTES).await { - Ok(Some(bytes)) => { - let text = String::from_utf8_lossy(&bytes); - Some(text.into_owned()) - } - Ok(None) => None, - Err(err) => { - tracing::debug!(error = %err, "debug body read failed"); - None +pub(crate) async fn log_debug_headers_body( + stage: &str, + headers: Option<&HeaderMap>, + body: Option<&ReplayableBody>, + max_body_bytes: usize, +) { + if !tracing::enabled!(tracing::Level::DEBUG) { + return; + } + + let header_snapshot = headers + .map(snapshot_headers_raw) + .unwrap_or_default(); + let body_text = if let Some(body) = body { + match body.read_bytes_if_small(max_body_bytes).await { + Ok(Some(bytes)) => Some(String::from_utf8_lossy(&bytes).into_owned()), + Ok(None) => Some(format!("[body omitted: larger than {max_body_bytes} bytes]")), + Err(err) => Some(format!("[body read failed: {err}]")), } + } else { + None }; match body_text { Some(text) => { - tracing::debug!(headers = ?header_snapshot, body = %text, "incoming request debug dump"); + tracing::debug!(stage, headers = ?header_snapshot, body = %text, "debug dump"); } None => { - tracing::debug!(headers = ?header_snapshot, "incoming request body omitted (too large)"); + tracing::debug!(stage, headers = 
?header_snapshot, "debug dump (no body)"); } } } -fn is_sensitive_header(name: &str) -> bool { - matches!( - name.to_ascii_lowercase().as_str(), - "authorization" | "proxy-authorization" | "x-api-key" - ) +fn snapshot_headers_raw(headers: &HeaderMap) -> Vec<(String, String)> { + headers + .iter() + .map(|(name, value)| { + let value = value.to_str().unwrap_or("").to_string(); + (name.to_string(), value) + }) + .collect() } pub(crate) async fn maybe_transform_request_body( http_clients: &ProxyHttpClients, + provider: &str, + path: &str, transform: FormatTransform, model_hint: Option<&str>, body: ReplayableBody, @@ -228,11 +243,39 @@ pub(crate) async fn maybe_transform_request_body( "Request body is too large to transform.", )); }; + let inbound_body = ReplayableBody::from_bytes(bytes.clone()); + log_debug_headers_body( + "transform.input", + None, + Some(&inbound_body), + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; + + let outbound_bytes = if should_use_antigravity_claude(provider, path, transform) { + antigravity_compat::claude_request_to_antigravity(&bytes, model_hint) + .map_err(|message| RequestError::new(StatusCode::BAD_REQUEST, message))? + } else { + transform_request_body(transform, &bytes, http_clients, model_hint) + .await + .map_err(|message| RequestError::new(StatusCode::BAD_REQUEST, message))? + }; + let outbound_body = ReplayableBody::from_bytes(outbound_bytes); + log_debug_headers_body( + "transform.output", + None, + Some(&outbound_body), + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; + Ok(outbound_body) +} - let outbound_bytes = transform_request_body(transform, &bytes, http_clients, model_hint) - .await - .map_err(|message| RequestError::new(StatusCode::BAD_REQUEST, message))?; - Ok(ReplayableBody::from_bytes(outbound_bytes)) +fn should_use_antigravity_claude(provider: &str, path: &str, transform: FormatTransform) -> bool { + // Align Antigravity with CLIProxyAPIPlus for Claude /v1/messages. 
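+    // Only the Anthropic /v1/messages entry point takes the dedicated Claude ->
+    // Gemini conversion; every other provider/path keeps the generic transform.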
+ provider == PROVIDER_ANTIGRAVITY + && path == ANTHROPIC_MESSAGES_PREFIX + && transform == FormatTransform::AnthropicToGemini } pub(crate) async fn maybe_force_openai_stream_options_include_usage( @@ -277,7 +320,15 @@ pub(crate) async fn maybe_force_openai_stream_options_include_usage( format!("Failed to serialize request: {err}"), ) })?; - Ok(ReplayableBody::from_bytes(outbound_bytes)) + let outbound_body = ReplayableBody::from_bytes(outbound_bytes); + log_debug_headers_body( + "stream_options.output", + None, + Some(&outbound_body), + DEBUG_BODY_LOG_LIMIT_BYTES, + ) + .await; + Ok(outbound_body) } pub(crate) async fn maybe_rewrite_openai_reasoning_effort_from_model_suffix( diff --git a/src-tauri/src/proxy/upstream/attempt.rs b/src-tauri/src/proxy/upstream/attempt.rs index 446cc90..3686dfb 100644 --- a/src-tauri/src/proxy/upstream/attempt.rs +++ b/src-tauri/src/proxy/upstream/attempt.rs @@ -16,9 +16,12 @@ use crate::proxy::http; use crate::proxy::openai_compat::FormatTransform; use crate::proxy::request_detail::RequestDetailSnapshot; use crate::proxy::request_body::ReplayableBody; +use crate::proxy::server_helpers::log_debug_headers_body; use crate::proxy::{config::UpstreamRuntime, ProxyState, RequestMeta}; use crate::proxy::{UPSTREAM_NO_DATA_TIMEOUT}; +const DEBUG_UPSTREAM_LOG_LIMIT_BYTES: usize = usize::MAX; + pub(super) async fn attempt_upstream( state: &ProxyState, method: Method, @@ -317,6 +320,13 @@ async fn send_antigravity_with_fallback( ) -> Result { let urls = antigravity_fallback_urls(&upstream.base_url, upstream_path_with_query); let request_headers = antigravity_request_headers(request_headers, meta, antigravity); + log_debug_headers_body( + "upstream.request", + Some(&request_headers), + Some(body), + DEBUG_UPSTREAM_LOG_LIMIT_BYTES, + ) + .await; let mut last_transport: Option = None; let mut saw_timeout = false; for (idx, url) in urls.iter().enumerate() { @@ -454,6 +464,13 @@ async fn send_upstream_request_once( request_detail: Option<&RequestDetailSnapshot>, start_time: Instant, ) -> Result { + log_debug_headers_body( + "upstream.request", + Some(request_headers), + Some(body), + DEBUG_UPSTREAM_LOG_LIMIT_BYTES, + ) + .await; let client = state .http_clients .client_for_proxy_url(upstream.proxy_url.as_deref()) @@ -569,6 +586,13 @@ async fn send_codex_attempt( start_time: Instant, attempt: &CodexSendAttempt, ) -> Result { + log_debug_headers_body( + "upstream.request", + Some(request_headers), + Some(body), + DEBUG_UPSTREAM_LOG_LIMIT_BYTES, + ) + .await; let client = build_codex_client(attempt.proxy_url.as_deref(), attempt.http1_only).map_err(|message| { CodexAttemptError::Fatal(AttemptOutcome::Fatal(http::error_response(StatusCode::BAD_GATEWAY, message))) })?; diff --git a/src-tauri/src/proxy/upstream/request.rs b/src-tauri/src/proxy/upstream/request.rs index 65f6646..a4583af 100644 --- a/src-tauri/src/proxy/upstream/request.rs +++ b/src-tauri/src/proxy/upstream/request.rs @@ -20,6 +20,7 @@ use super::super::{ RequestMeta, }; use super::super::http::RequestAuth; +use crate::proxy::server_helpers::log_debug_headers_body; const ANTHROPIC_VERSION_HEADER: &str = "anthropic-version"; const DEFAULT_ANTHROPIC_VERSION: &str = "2023-06-01"; @@ -28,6 +29,7 @@ const GEMINI_API_KEY_HEADER: HeaderName = HeaderName::from_static("x-goog-api-ke const OPENAI_RESPONSES_PATH: &str = "/v1/responses"; // Keep in sync with server_helpers request transform limit (20 MiB). 
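+// As in the core crate: usize::MAX disables body truncation for these
+// debug-level upstream dumps, trading log volume for complete payloads.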
const REQUEST_FILTER_LIMIT_BYTES: usize = 20 * 1024 * 1024; +const DEBUG_UPSTREAM_LOG_LIMIT_BYTES: usize = usize::MAX; pub(super) fn split_path_query(path_with_query: &str) -> (&str, Option<&str>) { match path_with_query.split_once('?') { @@ -228,6 +230,14 @@ async fn build_antigravity_body( .map_err(|message| { AttemptOutcome::Fatal(http::error_response(StatusCode::BAD_GATEWAY, message)) })?; + let wrapped_body = ReplayableBody::from_bytes(wrapped.clone()); + log_debug_headers_body( + "antigravity.wrapped", + None, + Some(&wrapped_body), + DEBUG_UPSTREAM_LOG_LIMIT_BYTES, + ) + .await; Ok(reqwest::Body::from(wrapped)) } diff --git a/src-tauri/src/proxy/usage.rs b/src-tauri/src/proxy/usage.rs index d9c982e..2b4618b 100644 --- a/src-tauri/src/proxy/usage.rs +++ b/src-tauri/src/proxy/usage.rs @@ -176,7 +176,10 @@ fn update_usage(snapshot: &mut UsageSnapshot, data: &str) { if updated.usage_json.is_some() { snapshot.usage_json = updated.usage_json; snapshot.usage = updated.usage; - snapshot.cached_tokens = updated.cached_tokens; + if updated.cached_tokens.is_some() { + // Preserve earlier cache stats when later events omit cache fields. + snapshot.cached_tokens = updated.cached_tokens; + } } } diff --git a/src/features/config/cards/upstreams/table.tsx b/src/features/config/cards/upstreams/table.tsx index 7dc1a2f..e79d54e 100644 --- a/src/features/config/cards/upstreams/table.tsx +++ b/src/features/config/cards/upstreams/table.tsx @@ -327,7 +327,8 @@ function UpstreamRowActions({ type UpstreamsTableRowProps = { upstream: UpstreamForm; - index: number; + upstreamIndex: number; + displayIndex: number; columns: readonly UpstreamColumnDefinition[]; showApiKeys: boolean; kiroAccounts: KiroAccountMap; @@ -342,7 +343,8 @@ type UpstreamsTableRowProps = { function UpstreamsTableRow({ upstream, - index, + upstreamIndex, + displayIndex, columns, showApiKeys, kiroAccounts, @@ -354,7 +356,7 @@ function UpstreamsTableRow({ onToggleEnabled, onDelete, }: UpstreamsTableRowProps) { - const rowLabel = getUpstreamLabel(index); + const rowLabel = getUpstreamLabel(displayIndex); return ( {columns.map((column) => ( @@ -378,10 +380,10 @@ function UpstreamsTableRow({ rowLabel={rowLabel} enabled={upstream.enabled} disableDelete={disableDelete} - onEdit={() => onEdit(index)} - onCopy={() => onCopy(index)} - onToggleEnabled={() => onToggleEnabled(index)} - onDelete={() => onDelete(index)} + onEdit={() => onEdit(upstreamIndex)} + onCopy={() => onCopy(upstreamIndex)} + onToggleEnabled={() => onToggleEnabled(upstreamIndex)} + onDelete={() => onDelete(upstreamIndex)} /> ); @@ -401,6 +403,37 @@ export type UpstreamsTableProps = { onDelete: (index: number) => void; }; +type SortedUpstreamEntry = { + upstream: UpstreamForm; + upstreamIndex: number; + priority: number; +}; + +function parsePriorityValue(value: string) { + const trimmed = value.trim(); + if (!trimmed) { + return 0; + } + const number = Number.parseInt(trimmed, 10); + return Number.isFinite(number) ? number : 0; +} + +function sortUpstreamsByPriority(upstreams: UpstreamForm[]) { + // Display order follows priority descending; ties keep original list order. 
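+  // Sorting is presentation-only: row actions still receive the original
+  // upstreamIndex, so edit/copy/toggle/delete keep targeting the right entry.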
+ const entries = upstreams.map((upstream, upstreamIndex): SortedUpstreamEntry => ({ + upstream, + upstreamIndex, + priority: parsePriorityValue(upstream.priority), + })); + entries.sort((left, right) => { + if (left.priority !== right.priority) { + return right.priority - left.priority; + } + return left.upstreamIndex - right.upstreamIndex; + }); + return entries; +} + export function UpstreamsTable({ upstreams, columns, @@ -414,16 +447,18 @@ export function UpstreamsTable({ onToggleEnabled, onDelete, }: UpstreamsTableProps) { + const sortedUpstreams = sortUpstreamsByPriority(upstreams); return (
- {upstreams.map((upstream, index) => ( + {sortedUpstreams.map((entry, displayIndex) => ( { const totalText = row.original.totalTokens === null ? CELL_PLACEHOLDER : formatInteger(row.original.totalTokens); const cachedText = row.original.cachedTokens ? formatInteger(row.original.cachedTokens) : null; - const tooltipText = cachedText ? `${totalText}\n${cachedText}` : totalText; + const totalLabel = m.dashboard_chart_total_tokens(); + const cachedLabel = m.dashboard_chart_cached_tokens(); + const tooltipParts = [ + totalText === CELL_PLACEHOLDER ? null : `${totalLabel} ${totalText}`, + cachedText ? `${cachedLabel} ${cachedText}` : null, + ].filter((part): part is string => Boolean(part)); + const tooltipText = tooltipParts.length > 0 ? tooltipParts.join("\n") : CELL_PLACEHOLDER; return ( @@ -165,7 +171,7 @@ function tokensColumn(): ColumnDef { {totalText} {cachedText ? ( - {cachedText} + {cachedLabel} {cachedText} ) : null} From 12856a604823a45b647af1406d47c1a4ac68b99f Mon Sep 17 00:00:00 2001 From: mxyhi Date: Fri, 30 Jan 2026 20:08:52 +0800 Subject: [PATCH 06/10] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(proxy):=20a?= =?UTF-8?q?lign=20antigravity=20enum=20schema=20processing=20-=20unconditi?= =?UTF-8?q?onally=20stringify=20enum=20values=20in=20`antigravity=5Fschema?= =?UTF-8?q?.rs`=20to=20align=20=20=20with=20CLIProxyAPIPlus=20behavior=20-?= =?UTF-8?q?=20force=20schema=20type=20to=20"string"=20for=20enum=20fields?= =?UTF-8?q?=20to=20ensure=20compatibility=20=20=20with=20Gemini=20and=20An?= =?UTF-8?q?tigravity=20backends?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🔧 chore(proxy): clean up redundant code and add tracking - remove unused proxy implementation files in `src-tauri/src/proxy` - delete obsolete antigravity design document - initialize workflow tracking via `task_plan.md`, `findings.md`, and `progress.md` --- .gitignore | 2 + AGENTS.md | 7 +- .../src/proxy/antigravity_schema.rs | 16 +- docs/plans/2026-01-21-antigravity-design.md | 45 - src-tauri/src/proxy/anthropic_compat.rs | 37 - src-tauri/src/proxy/anthropic_compat.test.rs | 183 ---- src-tauri/src/proxy/anthropic_compat/media.rs | 186 ---- .../src/proxy/anthropic_compat/request.rs | 586 ----------- .../src/proxy/anthropic_compat/response.rs | 273 ------ src-tauri/src/proxy/anthropic_compat/tools.rs | 173 ---- src-tauri/src/proxy/antigravity_compat.rs | 426 -------- .../src/proxy/antigravity_compat.test.rs | 171 ---- .../src/proxy/antigravity_compat/claude.rs | 496 ---------- .../antigravity_compat/signature_cache.rs | 107 -- src-tauri/src/proxy/antigravity_schema.rs | 775 --------------- src-tauri/src/proxy/codex_compat.rs | 38 - src-tauri/src/proxy/codex_compat.test.rs | 122 --- src-tauri/src/proxy/codex_compat/headers.rs | 46 - src-tauri/src/proxy/codex_compat/request.rs | 447 --------- src-tauri/src/proxy/codex_compat/response.rs | 289 ------ src-tauri/src/proxy/codex_compat/stream.rs | 491 ---------- .../src/proxy/codex_compat/tool_names.rs | 96 -- src-tauri/src/proxy/compat_content.rs | 80 -- src-tauri/src/proxy/compat_reason.rs | 65 -- src-tauri/src/proxy/config/io.rs | 205 ---- src-tauri/src/proxy/config/mod.rs | 104 -- src-tauri/src/proxy/config/model_mapping.rs | 143 --- .../src/proxy/config/model_mapping.test.rs | 36 - src-tauri/src/proxy/config/normalize.rs | 253 ----- src-tauri/src/proxy/config/types.rs | 297 ------ src-tauri/src/proxy/config/types.test.rs | 156 --- src-tauri/src/proxy/dashboard.rs | 481 --------- src-tauri/src/proxy/dashboard.test.rs | 61 -- 
src-tauri/src/proxy/gemini.rs | 31 - src-tauri/src/proxy/gemini_compat/mod.rs | 11 - src-tauri/src/proxy/gemini_compat/request.rs | 613 ------------ .../src/proxy/gemini_compat/request.test.rs | 58 -- src-tauri/src/proxy/gemini_compat/response.rs | 410 -------- .../src/proxy/gemini_compat/response.test.rs | 32 - src-tauri/src/proxy/gemini_compat/stream.rs | 560 ----------- src-tauri/src/proxy/gemini_compat/tools.rs | 175 ---- src-tauri/src/proxy/http.rs | 364 ------- src-tauri/src/proxy/http.test.rs | 58 -- src-tauri/src/proxy/http_client.rs | 46 - src-tauri/src/proxy/kiro/constants.rs | 52 - src-tauri/src/proxy/kiro/endpoint.rs | 36 - src-tauri/src/proxy/kiro/event_stream.rs | 159 --- src-tauri/src/proxy/kiro/mod.rs | 20 - src-tauri/src/proxy/kiro/model.rs | 68 -- src-tauri/src/proxy/kiro/payload/claude.rs | 510 ---------- .../src/proxy/kiro/payload/claude.test.rs | 55 -- src-tauri/src/proxy/kiro/payload/inference.rs | 28 - src-tauri/src/proxy/kiro/payload/input.rs | 157 --- src-tauri/src/proxy/kiro/payload/messages.rs | 391 -------- src-tauri/src/proxy/kiro/payload/mod.rs | 201 ---- src-tauri/src/proxy/kiro/payload/system.rs | 226 ----- src-tauri/src/proxy/kiro/response.rs | 394 -------- src-tauri/src/proxy/kiro/tool_parser.rs | 403 -------- src-tauri/src/proxy/kiro/tools.rs | 219 ----- src-tauri/src/proxy/kiro/types.rs | 132 --- src-tauri/src/proxy/kiro/utils.rs | 25 - src-tauri/src/proxy/log.rs | 195 ---- src-tauri/src/proxy/logs.rs | 53 - src-tauri/src/proxy/model.rs | 29 - src-tauri/src/proxy/openai_compat.rs | 554 ----------- src-tauri/src/proxy/openai_compat.test.rs | 682 ------------- src-tauri/src/proxy/openai_compat/extract.rs | 174 ---- src-tauri/src/proxy/openai_compat/input.rs | 145 --- src-tauri/src/proxy/openai_compat/message.rs | 183 ---- src-tauri/src/proxy/openai_compat/tools.rs | 114 --- src-tauri/src/proxy/openai_compat/usage.rs | 41 - src-tauri/src/proxy/redact.rs | 25 - src-tauri/src/proxy/request_body.rs | 176 ---- src-tauri/src/proxy/request_body.test.rs | 56 -- src-tauri/src/proxy/request_detail.rs | 111 --- src-tauri/src/proxy/request_token_estimate.rs | 556 ----------- .../src/proxy/request_token_estimate.test.rs | 140 --- src-tauri/src/proxy/response.rs | 200 ---- src-tauri/src/proxy/response.test.rs | 476 --------- .../proxy/response/anthropic_to_responses.rs | 572 ----------- .../response/anthropic_to_responses/format.rs | 65 -- .../src/proxy/response/chat_to_responses.rs | 593 ------------ .../response/chat_to_responses/format.rs | 64 -- .../src/proxy/response/dispatch/buffered.rs | 421 -------- src-tauri/src/proxy/response/dispatch/mod.rs | 60 -- .../src/proxy/response/dispatch/stream.rs | 587 ----------- .../src/proxy/response/kiro_to_anthropic.rs | 5 - .../kiro_to_anthropic_helpers.rs | 111 --- .../kiro_to_anthropic_stream.rs | 198 ---- .../kiro_to_anthropic_stream_blocks.rs | 357 ------- .../kiro_to_anthropic_stream_handlers.rs | 228 ----- .../src/proxy/response/kiro_to_responses.rs | 71 -- .../response/kiro_to_responses_helpers.rs | 378 -------- .../response/kiro_to_responses_stream.rs | 624 ------------ .../proxy/response/responses_to_anthropic.rs | 692 ------------- .../src/proxy/response/responses_to_chat.rs | 600 ------------ src-tauri/src/proxy/response/streaming.rs | 291 ------ src-tauri/src/proxy/response/token_count.rs | 95 -- src-tauri/src/proxy/response/upstream_read.rs | 23 - .../src/proxy/response/upstream_stream.rs | 53 - .../proxy/response/upstream_stream.test.rs | 41 - src-tauri/src/proxy/server.rs | 875 ----------------- 
src-tauri/src/proxy/server.test.rs | 301 ------ src-tauri/src/proxy/server/bootstrap.rs | 35 - src-tauri/src/proxy/server_helpers.rs | 436 --------- src-tauri/src/proxy/server_helpers.test.rs | 163 ---- src-tauri/src/proxy/service.rs | 371 ------- src-tauri/src/proxy/sqlite.rs | 170 ---- src-tauri/src/proxy/sse.rs | 64 -- src-tauri/src/proxy/token_estimator.rs | 259 ----- src-tauri/src/proxy/token_estimator.test.rs | 8 - src-tauri/src/proxy/token_rate.rs | 428 -------- src-tauri/src/proxy/upstream.rs | 713 -------------- src-tauri/src/proxy/upstream.test.rs | 123 --- src-tauri/src/proxy/upstream/attempt.rs | 915 ------------------ src-tauri/src/proxy/upstream/kiro.rs | 555 ----------- src-tauri/src/proxy/upstream/kiro_headers.rs | 72 -- src-tauri/src/proxy/upstream/kiro_http.rs | 151 --- src-tauri/src/proxy/upstream/request.rs | 318 ------ src-tauri/src/proxy/upstream/request.test.rs | 164 ---- src-tauri/src/proxy/upstream/result.rs | 145 --- src-tauri/src/proxy/upstream/utils.rs | 88 -- src-tauri/src/proxy/usage.rs | 189 ---- src-tauri/src/proxy/usage.test.rs | 72 -- 124 files changed, 14 insertions(+), 28967 deletions(-) delete mode 100644 docs/plans/2026-01-21-antigravity-design.md delete mode 100644 src-tauri/src/proxy/anthropic_compat.rs delete mode 100644 src-tauri/src/proxy/anthropic_compat.test.rs delete mode 100644 src-tauri/src/proxy/anthropic_compat/media.rs delete mode 100644 src-tauri/src/proxy/anthropic_compat/request.rs delete mode 100644 src-tauri/src/proxy/anthropic_compat/response.rs delete mode 100644 src-tauri/src/proxy/anthropic_compat/tools.rs delete mode 100644 src-tauri/src/proxy/antigravity_compat.rs delete mode 100644 src-tauri/src/proxy/antigravity_compat.test.rs delete mode 100644 src-tauri/src/proxy/antigravity_compat/claude.rs delete mode 100644 src-tauri/src/proxy/antigravity_compat/signature_cache.rs delete mode 100644 src-tauri/src/proxy/antigravity_schema.rs delete mode 100644 src-tauri/src/proxy/codex_compat.rs delete mode 100644 src-tauri/src/proxy/codex_compat.test.rs delete mode 100644 src-tauri/src/proxy/codex_compat/headers.rs delete mode 100644 src-tauri/src/proxy/codex_compat/request.rs delete mode 100644 src-tauri/src/proxy/codex_compat/response.rs delete mode 100644 src-tauri/src/proxy/codex_compat/stream.rs delete mode 100644 src-tauri/src/proxy/codex_compat/tool_names.rs delete mode 100644 src-tauri/src/proxy/compat_content.rs delete mode 100644 src-tauri/src/proxy/compat_reason.rs delete mode 100644 src-tauri/src/proxy/config/io.rs delete mode 100644 src-tauri/src/proxy/config/mod.rs delete mode 100644 src-tauri/src/proxy/config/model_mapping.rs delete mode 100644 src-tauri/src/proxy/config/model_mapping.test.rs delete mode 100644 src-tauri/src/proxy/config/normalize.rs delete mode 100644 src-tauri/src/proxy/config/types.rs delete mode 100644 src-tauri/src/proxy/config/types.test.rs delete mode 100644 src-tauri/src/proxy/dashboard.rs delete mode 100644 src-tauri/src/proxy/dashboard.test.rs delete mode 100644 src-tauri/src/proxy/gemini.rs delete mode 100644 src-tauri/src/proxy/gemini_compat/mod.rs delete mode 100644 src-tauri/src/proxy/gemini_compat/request.rs delete mode 100644 src-tauri/src/proxy/gemini_compat/request.test.rs delete mode 100644 src-tauri/src/proxy/gemini_compat/response.rs delete mode 100644 src-tauri/src/proxy/gemini_compat/response.test.rs delete mode 100644 src-tauri/src/proxy/gemini_compat/stream.rs delete mode 100644 src-tauri/src/proxy/gemini_compat/tools.rs delete mode 100644 src-tauri/src/proxy/http.rs delete mode 
100644 src-tauri/src/proxy/http.test.rs delete mode 100644 src-tauri/src/proxy/http_client.rs delete mode 100644 src-tauri/src/proxy/kiro/constants.rs delete mode 100644 src-tauri/src/proxy/kiro/endpoint.rs delete mode 100644 src-tauri/src/proxy/kiro/event_stream.rs delete mode 100644 src-tauri/src/proxy/kiro/mod.rs delete mode 100644 src-tauri/src/proxy/kiro/model.rs delete mode 100644 src-tauri/src/proxy/kiro/payload/claude.rs delete mode 100644 src-tauri/src/proxy/kiro/payload/claude.test.rs delete mode 100644 src-tauri/src/proxy/kiro/payload/inference.rs delete mode 100644 src-tauri/src/proxy/kiro/payload/input.rs delete mode 100644 src-tauri/src/proxy/kiro/payload/messages.rs delete mode 100644 src-tauri/src/proxy/kiro/payload/mod.rs delete mode 100644 src-tauri/src/proxy/kiro/payload/system.rs delete mode 100644 src-tauri/src/proxy/kiro/response.rs delete mode 100644 src-tauri/src/proxy/kiro/tool_parser.rs delete mode 100644 src-tauri/src/proxy/kiro/tools.rs delete mode 100644 src-tauri/src/proxy/kiro/types.rs delete mode 100644 src-tauri/src/proxy/kiro/utils.rs delete mode 100644 src-tauri/src/proxy/log.rs delete mode 100644 src-tauri/src/proxy/logs.rs delete mode 100644 src-tauri/src/proxy/model.rs delete mode 100644 src-tauri/src/proxy/openai_compat.rs delete mode 100644 src-tauri/src/proxy/openai_compat.test.rs delete mode 100644 src-tauri/src/proxy/openai_compat/extract.rs delete mode 100644 src-tauri/src/proxy/openai_compat/input.rs delete mode 100644 src-tauri/src/proxy/openai_compat/message.rs delete mode 100644 src-tauri/src/proxy/openai_compat/tools.rs delete mode 100644 src-tauri/src/proxy/openai_compat/usage.rs delete mode 100644 src-tauri/src/proxy/redact.rs delete mode 100644 src-tauri/src/proxy/request_body.rs delete mode 100644 src-tauri/src/proxy/request_body.test.rs delete mode 100644 src-tauri/src/proxy/request_detail.rs delete mode 100644 src-tauri/src/proxy/request_token_estimate.rs delete mode 100644 src-tauri/src/proxy/request_token_estimate.test.rs delete mode 100644 src-tauri/src/proxy/response.rs delete mode 100644 src-tauri/src/proxy/response.test.rs delete mode 100644 src-tauri/src/proxy/response/anthropic_to_responses.rs delete mode 100644 src-tauri/src/proxy/response/anthropic_to_responses/format.rs delete mode 100644 src-tauri/src/proxy/response/chat_to_responses.rs delete mode 100644 src-tauri/src/proxy/response/chat_to_responses/format.rs delete mode 100644 src-tauri/src/proxy/response/dispatch/buffered.rs delete mode 100644 src-tauri/src/proxy/response/dispatch/mod.rs delete mode 100644 src-tauri/src/proxy/response/dispatch/stream.rs delete mode 100644 src-tauri/src/proxy/response/kiro_to_anthropic.rs delete mode 100644 src-tauri/src/proxy/response/kiro_to_anthropic/kiro_to_anthropic_helpers.rs delete mode 100644 src-tauri/src/proxy/response/kiro_to_anthropic/kiro_to_anthropic_stream.rs delete mode 100644 src-tauri/src/proxy/response/kiro_to_anthropic/kiro_to_anthropic_stream/kiro_to_anthropic_stream_blocks.rs delete mode 100644 src-tauri/src/proxy/response/kiro_to_anthropic/kiro_to_anthropic_stream/kiro_to_anthropic_stream_handlers.rs delete mode 100644 src-tauri/src/proxy/response/kiro_to_responses.rs delete mode 100644 src-tauri/src/proxy/response/kiro_to_responses_helpers.rs delete mode 100644 src-tauri/src/proxy/response/kiro_to_responses_stream.rs delete mode 100644 src-tauri/src/proxy/response/responses_to_anthropic.rs delete mode 100644 src-tauri/src/proxy/response/responses_to_chat.rs delete mode 100644 
src-tauri/src/proxy/response/streaming.rs delete mode 100644 src-tauri/src/proxy/response/token_count.rs delete mode 100644 src-tauri/src/proxy/response/upstream_read.rs delete mode 100644 src-tauri/src/proxy/response/upstream_stream.rs delete mode 100644 src-tauri/src/proxy/response/upstream_stream.test.rs delete mode 100644 src-tauri/src/proxy/server.rs delete mode 100644 src-tauri/src/proxy/server.test.rs delete mode 100644 src-tauri/src/proxy/server/bootstrap.rs delete mode 100644 src-tauri/src/proxy/server_helpers.rs delete mode 100644 src-tauri/src/proxy/server_helpers.test.rs delete mode 100644 src-tauri/src/proxy/service.rs delete mode 100644 src-tauri/src/proxy/sqlite.rs delete mode 100644 src-tauri/src/proxy/sse.rs delete mode 100644 src-tauri/src/proxy/token_estimator.rs delete mode 100644 src-tauri/src/proxy/token_estimator.test.rs delete mode 100644 src-tauri/src/proxy/token_rate.rs delete mode 100644 src-tauri/src/proxy/upstream.rs delete mode 100644 src-tauri/src/proxy/upstream.test.rs delete mode 100644 src-tauri/src/proxy/upstream/attempt.rs delete mode 100644 src-tauri/src/proxy/upstream/kiro.rs delete mode 100644 src-tauri/src/proxy/upstream/kiro_headers.rs delete mode 100644 src-tauri/src/proxy/upstream/kiro_http.rs delete mode 100644 src-tauri/src/proxy/upstream/request.rs delete mode 100644 src-tauri/src/proxy/upstream/request.test.rs delete mode 100644 src-tauri/src/proxy/upstream/result.rs delete mode 100644 src-tauri/src/proxy/upstream/utils.rs delete mode 100644 src-tauri/src/proxy/usage.rs delete mode 100644 src-tauri/src/proxy/usage.test.rs diff --git a/.gitignore b/.gitignore index 6d0bea6..6e006a6 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,5 @@ src/paraglide/ *.sw? target/ + +.reference diff --git a/AGENTS.md b/AGENTS.md index 1e60d5d..c5cb01a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -9,7 +9,12 @@ - 前端: React 19 + TypeScript + Vite + Tailwind CSS v4 + shadcn/ui(pnpm dlx shadcn@latest add xxx) - 后端: Rust (Edition 2021) + Tokio + Axum - 桌面框架: Tauri 2 -- 代理转发/转换参考: [QuantumNous/new-api](https://github.com/QuantumNous/new-api) + +## 参考项目 + +- 代理转发/转换参考[new-api](.reference/new-api) +- kiro、codex、Antigravity等2api参考[CLIProxyAPIPlus](.reference/CLIProxyAPIPlus) +- CLIProxyAPIPlus的可视化app参考[quotio](.reference/quotio) --- diff --git a/crates/token_proxy_core/src/proxy/antigravity_schema.rs b/crates/token_proxy_core/src/proxy/antigravity_schema.rs index f852b7e..8b52e03 100644 --- a/crates/token_proxy_core/src/proxy/antigravity_schema.rs +++ b/crates/token_proxy_core/src/proxy/antigravity_schema.rs @@ -91,22 +91,18 @@ fn convert_enum_values_to_strings(schema: &mut Value) { let Some(Value::Array(values)) = get_value_mut(schema, &path) else { continue; }; - let mut needs_conversion = false; - for item in values.iter() { - if !item.is_string() { - needs_conversion = true; - break; - } - } - if !needs_conversion { - continue; - } let next = values .iter() .map(value_to_string) .map(Value::String) .collect::>(); *values = next; + let Some(parent_path) = parent_path(&path) else { + continue; + }; + if let Some(parent) = get_object_mut(schema, &parent_path) { + parent.insert("type".to_string(), Value::String("string".to_string())); + } } } diff --git a/docs/plans/2026-01-21-antigravity-design.md b/docs/plans/2026-01-21-antigravity-design.md deleted file mode 100644 index 2d62ba7..0000000 --- a/docs/plans/2026-01-21-antigravity-design.md +++ /dev/null @@ -1,45 +0,0 @@ -# Antigravity 功能接入设计(token_proxy) - -## 目标 -在 token_proxy 中实现 Antigravity 全量支持:OAuth 登录/刷新、IDE 
导入与账号切换、配额/订阅、warmup、代理 provider 与格式转换、Providers UI 集成;跨平台可配置,macOS 提供默认路径。 - -## 架构概览 -后端新增 `src-tauri/src/antigravity/` 模块,负责账号体系与 IDE 交互;代理层新增 antigravity provider(以 Gemini 格式为中间层);前端新增 Providers 分组与 Upstreams 账号选择。 - -### 后端模块划分 -- `antigravity/types.rs`: 账号、登录、配额、IDE 状态、warmup 调度的数据结构。 -- `antigravity/oauth.rs`: OAuth URL 构建、token 交换与刷新、userinfo 获取。 -- `antigravity/login.rs`: 本地回调监听 + 登录会话轮询。 -- `antigravity/store.rs`: token 文件存储(config_dir/antigravity-auth),过期刷新。 -- `antigravity/ide_db.rs`: SQLite 读写、备份/回滚、WAL/SHM 清理、active email 读取。 -- `antigravity/protobuf.rs`: protobuf field 6 注入与抽取(access/refresh/expiry)。 -- `antigravity/ide.rs`: IDE 导入与账号切换流程(终止/注入/重启)。 -- `antigravity/quota.rs`: fetchAvailableModels + loadCodeAssist(project id/plan)。 -- `antigravity/warmup.rs`: 手动 warmup + 轻量调度(interval)。 - -### 代理层接入 -- 新增 provider `antigravity`。 -- 请求转换:Chat/Responses/Anthropic → Gemini(现有转换);Gemini → Antigravity wrapper(新增)。 -- 响应转换:Antigravity(Gemini 格式)→ Chat/Responses/Anthropic(复用现有 Gemini 响应转换)。 -- 上游请求:使用 OAuth access_token;遇 401 自动刷新;base URL 按 daily/sandbox/prod 回退。 - -### 配置与跨平台 -- macOS 默认 IDE DB 路径与进程名来自 quotio。 -- Windows/Linux 允许通过 `proxy_config` 的 antigravity 配置覆盖 IDE DB 路径、进程名、应用路径;未配置则 IDE 功能降级提示。 - -## 关键流程 -1. OAuth 登录:生成 state → 监听本地回调 → exchange token → userinfo(email) → 保存 token record。 -2. IDE 导入:读取 state.vscdb → protobuf 抽取 access/refresh/expiry → 保存账号记录。 -3. IDE 切换:终止 IDE → 备份 DB → 注入新 token → 重启 IDE → 清理备份;失败回滚。 -4. 代理请求:将请求转换为 Gemini → Antigravity wrapper → 上游发送 → Gemini 响应转换回客户端格式。 -5. 配额:loadCodeAssist 获取 project id 与订阅 tier → fetchAvailableModels → 组装 quota 列表。 -6. Warmup:对指定 account+model 执行 generateContent(maxOutputTokens=1);可按 interval 定时。 - -## 错误处理 -- DB 锁/超时:busy_timeout + 重试;失败回滚备份。 -- OAuth 刷新失败:账号标记过期并在 UI 提示。 -- IDE 未安装:返回明确错误,UI 仅显示导入/切换不可用。 - -## 测试策略 -- Rust:protobuf 注入/抽取单测;IDE DB 读写 mock(若需要)。 -- TS:类型检查 + Providers UI 基础交互。 diff --git a/src-tauri/src/proxy/anthropic_compat.rs b/src-tauri/src/proxy/anthropic_compat.rs deleted file mode 100644 index dfa0d02..0000000 --- a/src-tauri/src/proxy/anthropic_compat.rs +++ /dev/null @@ -1,37 +0,0 @@ -use axum::body::Bytes; - -use super::http_client::ProxyHttpClients; - -mod media; -mod request; -mod response; -mod tools; - -pub(crate) async fn responses_request_to_anthropic( - body: &Bytes, - http_clients: &ProxyHttpClients, -) -> Result { - request::responses_request_to_anthropic(body, http_clients).await -} - -pub(crate) async fn anthropic_request_to_responses( - body: &Bytes, - http_clients: &ProxyHttpClients, -) -> Result { - request::anthropic_request_to_responses(body, http_clients).await -} - -pub(crate) fn responses_response_to_anthropic( - body: &Bytes, - model_hint: Option<&str>, -) -> Result { - response::responses_response_to_anthropic(body, model_hint) -} - -pub(crate) fn anthropic_response_to_responses(body: &Bytes) -> Result { - response::anthropic_response_to_responses(body) -} - -#[cfg(test)] -#[path = "anthropic_compat.test.rs"] -mod tests; diff --git a/src-tauri/src/proxy/anthropic_compat.test.rs b/src-tauri/src/proxy/anthropic_compat.test.rs deleted file mode 100644 index af03330..0000000 --- a/src-tauri/src/proxy/anthropic_compat.test.rs +++ /dev/null @@ -1,183 +0,0 @@ -use super::*; - -use axum::body::Bytes; -use serde_json::{json, Value}; - -use crate::proxy::http_client::ProxyHttpClients; - -fn run_async(future: impl std::future::Future) -> T { - tokio::runtime::Runtime::new() - .expect("create tokio runtime") - .block_on(future) -} - -fn bytes_from_json(value: Value) -> Bytes { - 
Bytes::from(serde_json::to_vec(&value).expect("serialize JSON")) -} - -fn json_from_bytes(bytes: Bytes) -> Value { - serde_json::from_slice(&bytes).expect("parse JSON") -} - -#[test] -fn anthropic_request_to_responses_maps_tools_and_tool_blocks() { - let http_clients = ProxyHttpClients::new().expect("http clients"); - - let input = bytes_from_json(json!({ - "model": "claude-3-5-sonnet", - "max_tokens": 123, - "stream": true, - "system": "sys", - "stop_sequences": ["a", "b"], - "tools": [ - { - "name": "search", - "description": "Search something", - "input_schema": { - "type": "object", - "properties": { "q": { "type": "string" } }, - "required": ["q"] - } - } - ], - "tool_choice": { - "type": "tool", - "name": "search", - "disable_parallel_tool_use": true - }, - "messages": [ - { "role": "user", "content": [{ "type": "text", "text": "hi" }] }, - { "role": "assistant", "content": [{ "type": "tool_use", "id": "call_1", "name": "search", "input": { "q": "x" } }] }, - { "role": "user", "content": [{ "type": "tool_result", "tool_use_id": "call_1", "content": "ok" }] } - ] - })); - - let output = run_async(async { - anthropic_request_to_responses(&input, &http_clients) - .await - .expect("transform") - }); - let value = json_from_bytes(output); - - assert_eq!(value["model"], json!("claude-3-5-sonnet")); - assert_eq!(value["max_output_tokens"], json!(123)); - assert_eq!(value["stream"], json!(true)); - assert_eq!(value["instructions"], json!("sys")); - - assert_eq!(value["tools"][0]["type"], json!("function")); - assert_eq!(value["tools"][0]["name"], json!("search")); - assert_eq!(value["tools"][0]["parameters"]["required"], json!(["q"])); - - assert_eq!(value["tool_choice"]["type"], json!("function")); - assert_eq!(value["tool_choice"]["name"], json!("search")); - assert_eq!(value["parallel_tool_calls"], json!(false)); - assert_eq!(value["stop"], json!(["a", "b"])); - - let input_items = value["input"].as_array().expect("input array"); - assert_eq!(input_items[0]["type"], json!("message")); - assert_eq!(input_items[0]["role"], json!("user")); - assert_eq!(input_items[0]["content"][0]["type"], json!("input_text")); - assert_eq!(input_items[0]["content"][0]["text"], json!("hi")); - - assert_eq!(input_items[1]["type"], json!("function_call")); - assert_eq!(input_items[1]["call_id"], json!("call_1")); - assert_eq!(input_items[1]["name"], json!("search")); - assert_eq!(input_items[1]["arguments"], json!("{\"q\":\"x\"}")); - - assert_eq!(input_items[2]["type"], json!("function_call_output")); - assert_eq!(input_items[2]["call_id"], json!("call_1")); - assert_eq!(input_items[2]["output"], json!("ok")); -} - -#[test] -fn responses_request_to_anthropic_maps_tool_choice_and_tool_result() { - let http_clients = ProxyHttpClients::new().expect("http clients"); - - let input = bytes_from_json(json!({ - "model": "gpt-4.1", - "max_output_tokens": 50, - "stream": true, - "stop": ["a", "b"], - "tools": [ - { - "type": "function", - "name": "search", - "description": "Search something", - "parameters": { - "type": "object", - "properties": { "q": { "type": "string" } }, - "required": ["q"] - } - } - ], - "tool_choice": { "type": "function", "name": "search" }, - "parallel_tool_calls": false, - "input": [ - { "type": "message", "role": "user", "content": [{ "type": "input_text", "text": "hi" }] }, - { "type": "function_call", "call_id": "call_1", "name": "search", "arguments": "{\"q\":\"x\"}" }, - { "type": "function_call_output", "call_id": "call_1", "output": "ok" } - ] - })); - - let output = run_async(async { 
-        responses_request_to_anthropic(&input, &http_clients)
-            .await
-            .expect("transform")
-    });
-    let value = json_from_bytes(output);
-
-    assert_eq!(value["model"], json!("gpt-4.1"));
-    assert_eq!(value["max_tokens"], json!(50));
-    assert_eq!(value["stream"], json!(true));
-
-    assert_eq!(value["tools"][0]["name"], json!("search"));
-    assert_eq!(value["tool_choice"]["type"], json!("tool"));
-    assert_eq!(value["tool_choice"]["name"], json!("search"));
-    assert_eq!(value["tool_choice"]["disable_parallel_tool_use"], json!(true));
-    assert_eq!(value["stop_sequences"], json!(["a", "b"]));
-
-    let messages = value["messages"].as_array().expect("messages array");
-    assert_eq!(messages.len(), 3);
-    assert_eq!(messages[0]["role"], json!("user"));
-    assert_eq!(messages[0]["content"][0]["type"], json!("text"));
-    assert_eq!(messages[0]["content"][0]["text"], json!("hi"));
-
-    assert_eq!(messages[1]["role"], json!("assistant"));
-    assert_eq!(messages[1]["content"][0]["type"], json!("tool_use"));
-    assert_eq!(messages[1]["content"][0]["id"], json!("call_1"));
-    assert_eq!(messages[1]["content"][0]["name"], json!("search"));
-    assert_eq!(messages[1]["content"][0]["input"]["q"], json!("x"));
-
-    assert_eq!(messages[2]["role"], json!("user"));
-    assert_eq!(messages[2]["content"][0]["type"], json!("tool_result"));
-    assert_eq!(messages[2]["content"][0]["tool_use_id"], json!("call_1"));
-    assert_eq!(messages[2]["content"][0]["content"], json!("ok"));
-}
-
-#[test]
-fn responses_response_to_anthropic_includes_thinking_block() {
-    let input = bytes_from_json(json!({
-        "id": "resp_thinking",
-        "model": "gpt-4.1",
-        "output": [
-            {
-                "type": "message",
-                "role": "assistant",
-                "content": [
-                    { "type": "reasoning_text", "text": "think" },
-                    { "type": "output_text", "text": "ok" }
-                ]
-            }
-        ],
-        "usage": { "input_tokens": 1, "output_tokens": 2 }
-    }));
-
-    let output = responses_response_to_anthropic(&input, None).expect("transform");
-    let value = json_from_bytes(output);
-
-    assert_eq!(value["content"][0]["type"], json!("thinking"));
-    assert_eq!(value["content"][0]["thinking"], json!("think"));
-    assert!(value["content"][0]["signature"].as_str().is_some());
-    assert_eq!(value["content"][1]["type"], json!("text"));
-    assert_eq!(value["content"][1]["text"], json!("ok"));
-}
diff --git a/src-tauri/src/proxy/anthropic_compat/media.rs b/src-tauri/src/proxy/anthropic_compat/media.rs
deleted file mode 100644
index 63ad60e..0000000
--- a/src-tauri/src/proxy/anthropic_compat/media.rs
+++ /dev/null
@@ -1,186 +0,0 @@
-use base64::Engine;
-use futures_util::StreamExt;
-use reqwest::header::CONTENT_TYPE;
-use serde_json::{json, Map, Value};
-
-use super::super::http_client::ProxyHttpClients;
-
-const MAX_MEDIA_DOWNLOAD_BYTES: usize = 64 * 1024 * 1024;
-
-pub(super) async fn input_image_part_to_claude_block(
-    part: &Map<String, Value>,
-    http_clients: &ProxyHttpClients,
-) -> Result<Option<Value>, String> {
-    let url = part.get("image_url").and_then(normalize_url_value);
-    let Some(url) = url else {
-        return Ok(None);
-    };
-
-    let (media_type, data) = resolve_media_to_base64(&url, http_clients).await?;
-    Ok(Some(json!({
-        "type": "image",
-        "source": {
-            "type": "base64",
-            "media_type": media_type,
-            "data": data
-        }
-    })))
-}
-
-pub(super) async fn input_file_part_to_claude_block(
-    part: &Map<String, Value>,
-    http_clients: &ProxyHttpClients,
-) -> Result<Option<Value>, String> {
-    // Prefer file_url, but accept "file" wrappers some clients may send.
-    let url = part
-        .get("file_url")
-        .and_then(normalize_url_value)
-        .or_else(|| part.get("file").and_then(normalize_url_value));
-    let Some(url) = url else {
-        // file_id needs an OpenAI Files API call; keep it explicit instead of guessing.
-        if part.get("file_id").and_then(Value::as_str).is_some() {
-            return Err("input_file with file_id is not supported; use file_url or data: URL.".to_string());
-        }
-        return Ok(None);
-    };
-
-    let (media_type, data) = resolve_media_to_base64(&url, http_clients).await?;
-    Ok(Some(json!({
-        "type": "document",
-        "source": {
-            "type": "base64",
-            "media_type": media_type,
-            "data": data
-        }
-    })))
-}
-
-pub(super) fn claude_image_block_to_input_image_part(block: &Map<String, Value>) -> Option<Value> {
-    let source = block.get("source").and_then(Value::as_object)?;
-    if source.get("type").and_then(Value::as_str) != Some("base64") {
-        return None;
-    }
-    let media_type = source.get("media_type").and_then(Value::as_str).unwrap_or("image/png");
-    let data = source.get("data").and_then(Value::as_str)?;
-    Some(json!({
-        "type": "input_image",
-        "image_url": format!("data:{media_type};base64,{data}")
-    }))
-}
-
-pub(super) fn claude_document_block_to_input_file_part(block: &Map<String, Value>) -> Option<Value> {
-    let source = block.get("source").and_then(Value::as_object)?;
-    if source.get("type").and_then(Value::as_str) != Some("base64") {
-        return None;
-    }
-    let media_type = source
-        .get("media_type")
-        .and_then(Value::as_str)
-        .unwrap_or("application/octet-stream");
-    let data = source.get("data").and_then(Value::as_str)?;
-    Some(json!({
-        "type": "input_file",
-        "file_url": format!("data:{media_type};base64,{data}")
-    }))
-}
-
-fn normalize_url_value(value: &Value) -> Option<String> {
-    match value {
-        Value::String(url) => Some(url.to_string()),
-        Value::Object(object) => object.get("url").and_then(Value::as_str).map(|url| url.to_string()),
-        _ => None,
-    }
-}
-
-async fn resolve_media_to_base64(
-    url: &str,
-    http_clients: &ProxyHttpClients,
-) -> Result<(String, String), String> {
-    if let Some((media_type, data)) = parse_data_url(url) {
-        return Ok((media_type, data));
-    }
-    if url.starts_with("http://") || url.starts_with("https://") {
-        return download_url_as_base64(url, http_clients).await;
-    }
-    Err("Unsupported media URL; expected http(s):// or data: URL.".to_string())
-}
-
-fn parse_data_url(url: &str) -> Option<(String, String)> {
-    let (meta, data) = url.strip_prefix("data:")?.split_once(",")?;
-    let meta = meta.trim();
-    let data = data.trim();
-
-    // We only need to forward the base64 blob; avoid decoding to keep it fast.
-    let (media_type, is_base64) = match meta.split_once(";") {
-        Some((media_type, rest)) => (media_type.trim(), rest.trim() == "base64"),
-        None => (meta, false),
-    };
-    if !is_base64 {
-        return None;
-    }
-    if media_type.is_empty() {
-        return Some(("application/octet-stream".to_string(), data.to_string()));
-    }
-    Some((media_type.to_string(), data.to_string()))
-}
-
-async fn download_url_as_base64(
-    url: &str,
-    http_clients: &ProxyHttpClients,
-) -> Result<(String, String), String> {
-    let client = http_clients.client_for_proxy_url(None)?;
-    let res = client
-        .get(url)
-        .send()
-        .await
-        .map_err(|err| format!("Failed to download media: {err}"))?;
-
-    if !res.status().is_success() {
-        return Err(format!("Media download failed with status: {}", res.status()));
-    }
-
-    let header_type = res
-        .headers()
-        .get(CONTENT_TYPE)
-        .and_then(|value| value.to_str().ok())
-        .and_then(|value| value.split(';').next())
-        .map(|value| value.trim().to_string());
-
-    let mut bytes = Vec::new();
-    let mut stream = res.bytes_stream();
-    while let Some(next) = stream.next().await {
-        let chunk = next.map_err(|err| format!("Failed to download media: {err}"))?;
-        if bytes.len().saturating_add(chunk.len()) > MAX_MEDIA_DOWNLOAD_BYTES {
-            return Err(format!(
-                "Media download exceeds {MAX_MEDIA_DOWNLOAD_BYTES} bytes limit."
-            ));
-        }
-        bytes.extend_from_slice(&chunk);
-    }
-
-    let sniffed = sniff_mime_type(&bytes);
-    let media_type = header_type.unwrap_or_else(|| sniffed);
-
-    let base64_data = base64::engine::general_purpose::STANDARD.encode(&bytes);
-    Ok((media_type, base64_data))
-}
-
-fn sniff_mime_type(bytes: &[u8]) -> String {
-    if bytes.len() >= 8 && bytes[..8] == [0x89, b'P', b'N', b'G', 0x0D, 0x0A, 0x1A, 0x0A] {
-        return "image/png".to_string();
-    }
-    if bytes.len() >= 3 && bytes[..3] == [0xFF, 0xD8, 0xFF] {
-        return "image/jpeg".to_string();
-    }
-    if bytes.len() >= 6 && (&bytes[..6] == b"GIF87a" || &bytes[..6] == b"GIF89a") {
-        return "image/gif".to_string();
-    }
-    if bytes.len() >= 12 && &bytes[..4] == b"RIFF" && &bytes[8..12] == b"WEBP" {
-        return "image/webp".to_string();
-    }
-    if bytes.len() >= 4 && &bytes[..4] == b"%PDF" {
-        return "application/pdf".to_string();
-    }
-    "application/octet-stream".to_string()
-}
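
For reference, the data-URL branch removed above never base64-decodes the payload; it only splits the metadata header off and forwards the blob as-is. A self-contained sketch of that fast path under the same rules (base64-only; an empty media type falls back to application/octet-stream):

```rust
// Sketch of the data-URL fast path: split the header from the payload and
// pass the base64 blob through untouched instead of decoding it.
fn parse_data_url(url: &str) -> Option<(String, String)> {
    let (meta, data) = url.strip_prefix("data:")?.split_once(',')?;
    let (media_type, is_base64) = match meta.split_once(';') {
        Some((mt, rest)) => (mt.trim(), rest.trim() == "base64"),
        None => (meta.trim(), false),
    };
    if !is_base64 {
        return None; // percent-encoded data: URLs are not handled here
    }
    let media_type = if media_type.is_empty() {
        "application/octet-stream"
    } else {
        media_type
    };
    Some((media_type.to_string(), data.trim().to_string()))
}

fn main() {
    let (mt, data) = parse_data_url("data:image/png;base64,iVBORw0KGgo=").unwrap();
    assert_eq!(mt, "image/png");
    assert_eq!(data, "iVBORw0KGgo=");
    assert!(parse_data_url("data:text/plain,hello").is_none());
}
```
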
diff --git a/src-tauri/src/proxy/anthropic_compat/request.rs b/src-tauri/src/proxy/anthropic_compat/request.rs
deleted file mode 100644
index 574aeab..0000000
--- a/src-tauri/src/proxy/anthropic_compat/request.rs
+++ /dev/null
@@ -1,586 +0,0 @@
-use axum::body::Bytes;
-use serde_json::{json, Map, Value};
-
-use super::media;
-use super::tools;
-use super::super::http_client::ProxyHttpClients;
-
-pub(super) async fn responses_request_to_anthropic(
-    body: &Bytes,
-    http_clients: &ProxyHttpClients,
-) -> Result<Bytes, String> {
-    let value: Value =
-        serde_json::from_slice(body).map_err(|_| "Request body must be JSON.".to_string())?;
-    let Some(object) = value.as_object() else {
-        return Err("Request body must be a JSON object.".to_string());
-    };
-
-    let model = object
-        .get("model")
-        .and_then(Value::as_str)
-        .ok_or_else(|| "Request must include model.".to_string())?;
-
-    let stream = object.get("stream").and_then(Value::as_bool).unwrap_or(false);
-
-    let max_tokens = object
-        .get("max_output_tokens")
-        .or_else(|| object.get("max_tokens"))
-        .and_then(Value::as_i64)
-        .filter(|value| *value > 0)
-        .unwrap_or(4096);
-
-    let mut system_texts = Vec::new();
-    if let Some(instructions) = object.get("instructions").and_then(Value::as_str) {
-        if !instructions.trim().is_empty() {
-            system_texts.push(instructions.to_string());
-        }
-    }
-
-    let input = object.get("input").ok_or_else(|| "Request must include input.".to_string())?;
-    let mut messages = Vec::new();
-    responses_input_to_claude_messages(input, &mut system_texts, &mut messages, http_clients).await?;
-
-    let mut out = Map::new();
-    out.insert("model".to_string(), Value::String(model.to_string()));
-    out.insert("max_tokens".to_string(), Value::Number(max_tokens.into()));
-    out.insert("stream".to_string(), Value::Bool(stream));
-    out.insert("messages".to_string(), Value::Array(messages));
-
-    if let Some(system) = join_system_texts(system_texts) {
-        out.insert("system".to_string(), system_blocks_from_text(system));
-    }
-
-    if let Some(temperature) = object.get("temperature") {
-        out.insert("temperature".to_string(), temperature.clone());
-    }
-    if let Some(top_p) = object.get("top_p") {
-        out.insert("top_p".to_string(), top_p.clone());
-    }
-
-    if let Some(stop_sequences) = tools::map_openai_stop_to_anthropic_stop_sequences(object.get("stop")) {
-        out.insert("stop_sequences".to_string(), stop_sequences);
-    }
-
-    if let Some(tools_value) = object.get("tools") {
-        out.insert(
-            "tools".to_string(),
-            tools::map_responses_tools_to_anthropic(tools_value),
-        );
-    }
-
-    let parallel_tool_calls = object.get("parallel_tool_calls").and_then(Value::as_bool);
-    if let Some(tool_choice) = tools::map_responses_tool_choice_to_anthropic(
-        object.get("tool_choice"),
-        parallel_tool_calls,
-    ) {
-        out.insert("tool_choice".to_string(), tool_choice);
-    }
-
-    serde_json::to_vec(&Value::Object(out))
-        .map(Bytes::from)
-        .map_err(|err| format!("Failed to serialize request: {err}"))
-}
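
The scalar field mapping in the function above is easy to lose in the JSON plumbing. Condensed: `model` passes through, `max_output_tokens` (or legacy `max_tokens`) becomes Anthropic's `max_tokens` with a 4096 default, and `stream` defaults to false. A sketch of just that slice (tools, tool_choice, and message conversion omitted):

```rust
use serde_json::{json, Value};

// Condensed view of the scalar mapping performed above.
fn map_scalars(responses_req: &Value) -> Value {
    let max_tokens = responses_req["max_output_tokens"]
        .as_i64()
        .filter(|v| *v > 0)
        .unwrap_or(4096);
    json!({
        "model": responses_req["model"],
        "max_tokens": max_tokens,
        "stream": responses_req["stream"].as_bool().unwrap_or(false),
    })
}

fn main() {
    let req = json!({ "model": "claude-3-5-sonnet", "stream": true });
    assert_eq!(
        map_scalars(&req),
        json!({ "model": "claude-3-5-sonnet", "max_tokens": 4096, "stream": true })
    );
}
```
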
-
-pub(super) async fn anthropic_request_to_responses(
-    body: &Bytes,
-    _http_clients: &ProxyHttpClients,
-) -> Result<Bytes, String> {
-    let value: Value =
-        serde_json::from_slice(body).map_err(|_| "Request body must be JSON.".to_string())?;
-    let Some(object) = value.as_object() else {
-        return Err("Request body must be a JSON object.".to_string());
-    };
-
-    let model = object
-        .get("model")
-        .and_then(Value::as_str)
-        .ok_or_else(|| "Request must include model.".to_string())?;
-
-    let stream = object.get("stream").and_then(Value::as_bool).unwrap_or(false);
-
-    let max_output_tokens = object
-        .get("max_tokens")
-        .and_then(Value::as_i64)
-        .filter(|value| *value > 0)
-        .unwrap_or(4096);
-
-    let mut input_items = Vec::new();
-
-    let mut instructions_texts = Vec::new();
-    if let Some(system) = object.get("system") {
-        if let Some(text) = claude_system_to_text(system) {
-            if !text.trim().is_empty() {
-                instructions_texts.push(text);
-            }
-        }
-    }
-
-    let Some(messages) = object.get("messages").and_then(Value::as_array) else {
-        return Err("Request must include messages.".to_string());
-    };
-    for message in messages {
-        claude_message_to_responses_input_items(message, &mut input_items)?;
-    }
-
-    let mut out = Map::new();
-    out.insert("model".to_string(), Value::String(model.to_string()));
-    out.insert(
-        "max_output_tokens".to_string(),
-        Value::Number(max_output_tokens.into()),
-    );
-    out.insert("stream".to_string(), Value::Bool(stream));
-    out.insert("input".to_string(), Value::Array(input_items));
-
-    if let Some(instructions) = join_system_texts(instructions_texts) {
-        out.insert("instructions".to_string(), Value::String(instructions));
-    }
-
-    if let Some(temperature) = object.get("temperature") {
-        out.insert("temperature".to_string(), temperature.clone());
-    }
-    if let Some(top_p) = object.get("top_p") {
-        out.insert("top_p".to_string(), top_p.clone());
-    }
-
-    if let Some(stop) = tools::map_anthropic_stop_sequences_to_openai_stop(object.get("stop_sequences")) {
-        out.insert("stop".to_string(), stop);
-    }
-
-    if let Some(tools_value) = object.get("tools") {
-        out.insert(
-            "tools".to_string(),
-            tools::map_anthropic_tools_to_responses(tools_value),
-        );
-    }
-
-    let (tool_choice, parallel_tool_calls) =
-        tools::map_anthropic_tool_choice_to_responses(object.get("tool_choice"));
-    if let Some(tool_choice) = tool_choice {
-        out.insert("tool_choice".to_string(), tool_choice);
-    }
-    if let Some(parallel_tool_calls) = parallel_tool_calls {
-        out.insert(
-            "parallel_tool_calls".to_string(),
-            Value::Bool(parallel_tool_calls),
-        );
-    }
-
-    serde_json::to_vec(&Value::Object(out))
-        .map(Bytes::from)
-        .map_err(|err| format!("Failed to serialize request: {err}"))
-}
-
-async fn responses_input_to_claude_messages(
-    input: &Value,
-    system_texts: &mut Vec<String>,
-    messages: &mut Vec<Value>,
-    http_clients: &ProxyHttpClients,
-) -> Result<(), String> {
-    match input {
-        Value::String(text) => {
-            let content = vec![json!({ "type": "text", "text": text })];
-            messages.push(json!({ "role": "user", "content": content }));
-        }
-        Value::Array(items) => {
-            for item in items {
-                responses_input_item_to_claude_messages(item, system_texts, messages, http_clients)
-                    .await?;
-            }
-        }
-        _ => return Err("Responses input must be a string or array.".to_string()),
-    }
-    Ok(())
-}
-
-async fn responses_input_item_to_claude_messages(
-    item: &Value,
-    system_texts: &mut Vec<String>,
-    messages: &mut Vec<Value>,
-    http_clients: &ProxyHttpClients,
-) -> Result<(), String> {
-    // Accept Chat-style `{role, content}` items, as some clients send that into /v1/responses.
-    if item.get("role").and_then(Value::as_str).is_some() {
-        let role = item.get("role").and_then(Value::as_str).unwrap_or("user");
-        let content = item.get("content");
-        if role == "system" {
-            if let Some(text) = extract_text_from_any_content(content) {
-                if !text.trim().is_empty() {
-                    system_texts.push(text);
-                }
-            }
-            return Ok(());
-        }
-        let blocks = responses_message_content_to_claude_blocks(content, http_clients).await?;
-        push_claude_message(messages, role, blocks);
-        return Ok(());
-    }
-
-    let Some(object) = item.as_object() else {
-        return Ok(());
-    };
-    let item_type = object.get("type").and_then(Value::as_str).unwrap_or("");
-    match item_type {
-        "message" => {
-            let role = object.get("role").and_then(Value::as_str).unwrap_or("user");
-            let content = object.get("content");
-            if role == "system" {
-                if let Some(text) = extract_text_from_any_content(content) {
-                    if !text.trim().is_empty() {
-                        system_texts.push(text);
-                    }
-                }
-                return Ok(());
-            }
-            let blocks = responses_message_content_to_claude_blocks(content, http_clients).await?;
-            push_claude_message(messages, role, blocks);
-        }
-        "function_call" => {
-            let tool_use_id = object
-                .get("call_id")
-                .or_else(|| object.get("id"))
-                .and_then(Value::as_str)
-                .unwrap_or("tool_use_proxy");
-            let name = object.get("name").and_then(Value::as_str).unwrap_or("");
-            let arguments = object.get("arguments").and_then(Value::as_str).unwrap_or("");
-            let input = parse_tool_input_object(arguments);
-            let block = json!({
-                "type": "tool_use",
-                "id": tool_use_id,
-                "name": name,
-                "input": input
-            });
-            push_tool_use_block(messages, block);
-        }
-        "function_call_output" => {
-            let tool_use_id = object.get("call_id").and_then(Value::as_str).unwrap_or("");
-            let output = object.get("output").and_then(Value::as_str).unwrap_or("");
-            let block = json!({
-                "type": "tool_result",
-                "tool_use_id": tool_use_id,
-                "content": output
-            });
-            push_tool_result_block(messages, block);
-        }
-        _ => {}
-    }
-    Ok(())
-}
-
-async fn responses_message_content_to_claude_blocks(
-    content: Option<&Value>,
-    http_clients: &ProxyHttpClients,
-) -> Result<Vec<Value>, String> {
-    let Some(content) = content else {
-        return Ok(Vec::new());
-    };
-    match content {
-        Value::String(text) => Ok(vec![json!({ "type": "text", "text": text })]),
-        Value::Array(parts) => {
-            let mut blocks = Vec::new();
-            for part in parts {
-                let Some(part) = part.as_object() else {
-                    continue;
-                };
-                let part_type = part.get("type").and_then(Value::as_str).unwrap_or("");
-                match part_type {
-                    "input_text" | "output_text" | "text" => {
-                        if let Some(text) = part.get("text").and_then(Value::as_str) {
-                            blocks.push(json!({ "type": "text", "text": text }));
-                        }
-                    }
-                    "refusal" => {
-                        // Some OpenAI Responses payloads represent refusals as dedicated parts.
-                        let text = part
-                            .get("refusal")
-                            .or_else(|| part.get("text"))
-                            .and_then(Value::as_str)
-                            .unwrap_or("");
-                        if !text.is_empty() {
-                            blocks.push(json!({ "type": "text", "text": text }));
-                        }
-                    }
-                    "input_image" => {
-                        if let Some(block) = media::input_image_part_to_claude_block(part, http_clients).await? {
-                            blocks.push(block);
-                        }
-                    }
-                    "input_file" => {
-                        if let Some(block) = media::input_file_part_to_claude_block(part, http_clients).await? {
-                            blocks.push(block);
-                        }
-                    }
-                    _ => {}
-                }
-            }
-            Ok(blocks)
-        }
-        _ => Ok(Vec::new()),
-    }
-}
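
In the part mapping above, all of the textual part kinds collapse to Claude `text` blocks. A trimmed-down, text-only restatement (media parts and the empty-refusal guard omitted):

```rust
use serde_json::{json, Value};

// Minimal text-only version of the part mapping above.
fn part_to_text_block(part: &Value) -> Option<Value> {
    match part["type"].as_str()? {
        "input_text" | "output_text" | "text" => {
            Some(json!({ "type": "text", "text": part["text"].as_str()? }))
        }
        "refusal" => {
            let text = part["refusal"].as_str().or(part["text"].as_str())?;
            Some(json!({ "type": "text", "text": text }))
        }
        _ => None,
    }
}

fn main() {
    let part = json!({ "type": "output_text", "text": "hi" });
    assert_eq!(part_to_text_block(&part), Some(json!({ "type": "text", "text": "hi" })));
}
```
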
-
-fn claude_message_to_responses_input_items(message: &Value, input_items: &mut Vec<Value>) -> Result<(), String> {
-    let Some(message) = message.as_object() else {
-        return Ok(());
-    };
-    let role = message.get("role").and_then(Value::as_str).unwrap_or("user");
-    if role == "system" {
-        return Ok(());
-    }
-
-    let content = message.get("content");
-    let blocks = claude_content_to_blocks(content);
-
-    let mut message_parts = Vec::new();
-    let text_part_type = match role {
-        // OpenAI Responses schema expects assistant messages in `input` to use output types.
-        // This avoids errors like: "Invalid value: 'input_text'. Supported values are: 'output_text' and 'refusal'."
-        "assistant" => "output_text",
-        _ => "input_text",
-    };
-    for block in &blocks {
-        let Some(block) = block.as_object() else {
-            continue;
-        };
-        let block_type = block.get("type").and_then(Value::as_str).unwrap_or("");
-        match block_type {
-            "text" => {
-                if let Some(text) = block.get("text").and_then(Value::as_str) {
-                    message_parts.push(json!({ "type": text_part_type, "text": text }));
-                }
-            }
-            "image" => {
-                if let Some(part) = media::claude_image_block_to_input_image_part(block) {
-                    message_parts.push(part);
-                }
-            }
-            "document" => {
-                if let Some(part) = media::claude_document_block_to_input_file_part(block) {
-                    message_parts.push(part);
-                }
-            }
-            "tool_use" => {}
-            "tool_result" => {}
-            _ => {}
-        }
-    }
-    if !message_parts.is_empty() {
-        input_items.push(json!({
-            "type": "message",
-            "role": role,
-            "content": message_parts
-        }));
-    }
-
-    for block in blocks {
-        let Some(block) = block.as_object() else {
-            continue;
-        };
-        let block_type = block.get("type").and_then(Value::as_str).unwrap_or("");
-        match block_type {
-            "tool_use" => {
-                let call_id = block.get("id").and_then(Value::as_str).unwrap_or("call_proxy");
-                let name = block.get("name").and_then(Value::as_str).unwrap_or("");
-                let input = block.get("input").cloned().unwrap_or_else(|| json!({}));
-                let arguments = serde_json::to_string(&input).unwrap_or_else(|_| "{}".to_string());
-                input_items.push(json!({
-                    "type": "function_call",
-                    "call_id": call_id,
-                    "name": name,
-                    "arguments": arguments
-                }));
-            }
-            "tool_result" => {
-                let call_id = block.get("tool_use_id").and_then(Value::as_str).unwrap_or("");
-                let output_raw = block.get("content").cloned().unwrap_or_else(|| json!(""));
-                let output_text = match &output_raw {
-                    Value::String(text) => text.clone(),
-                    other => serde_json::to_string(other).unwrap_or_default(),
-                };
-                let is_error = block.get("is_error").and_then(Value::as_bool).unwrap_or(false);
-                let mut item = Map::new();
-                item.insert("type".to_string(), json!("function_call_output"));
-                item.insert("call_id".to_string(), Value::String(call_id.to_string()));
-                item.insert("output".to_string(), Value::String(output_text));
-                if is_error {
-                    item.insert("is_error".to_string(), Value::Bool(true));
-                }
-                if !matches!(output_raw, Value::String(_)) {
-                    item.insert("output_parts".to_string(), output_raw);
-                }
-                input_items.push(Value::Object(item));
-            }
-            _ => {}
-        }
-    }
-
-    Ok(())
-}
-
-fn claude_system_to_text(value: &Value) -> Option<String> {
-    match value {
-        Value::String(text) => Some(text.to_string()),
-        Value::Array(items) => {
-            let texts = items
-                .iter()
-                .filter_map(|item| item.as_object())
-                .filter(|item| item.get("type").and_then(Value::as_str) == Some("text"))
-                .filter_map(|item| item.get("text").and_then(Value::as_str))
-                .map(|text| text.to_string())
-                .collect::<Vec<String>>();
-            join_system_texts(texts)
-        }
-        _ => None,
-    }
-}
-
-fn join_system_texts(texts: Vec<String>) -> Option<String> {
-    let combined = texts
-        .into_iter()
-        .map(|t| t.trim().to_string())
-        .filter(|t| !t.is_empty())
-        .collect::<Vec<String>>()
-        .join("\n");
-    if combined.is_empty() {
-        None
-    } else {
-        Some(combined)
-    }
-}
-
-fn system_blocks_from_text(text: String) -> Value {
-    // new-api style: `system` uses array blocks for better compatibility.
-    // Keep the original newlines inside the single block (avoid splitting).
-    json!([{ "type": "text", "text": text }])
-}
-
-fn extract_text_from_any_content(value: Option<&Value>) -> Option<String> {
-    let Some(value) = value else {
-        return None;
-    };
-    match value {
-        Value::String(text) => Some(text.to_string()),
-        Value::Array(parts) => {
-            let mut combined = String::new();
-            for part in parts {
-                let Some(part) = part.as_object() else {
-                    continue;
-                };
-                if let Some(text) = part.get("text").and_then(Value::as_str) {
-                    combined.push_str(text);
-                }
-            }
-            if combined.is_empty() { None } else { Some(combined) }
-        }
-        Value::Object(object) => object.get("text").and_then(Value::as_str).map(|t| t.to_string()),
-        _ => None,
-    }
-}
-
-fn parse_tool_input_object(arguments: &str) -> Value {
-    let parsed = serde_json::from_str::<Value>(arguments).ok();
-    match parsed {
-        Some(Value::Object(object)) => Value::Object(object),
-        Some(other) => json!({ "_": other }),
-        None => json!({ "_raw": arguments }),
-    }
-}
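
Claude requires tool `input` to be a JSON object, hence the wrapping rules in `parse_tool_input_object` above. The same normalization as a tiny standalone function (key names `_` and `_raw` as in the deleted code):

```rust
use serde_json::{json, Value};

// Claude requires an object for `input`, so scalars are wrapped
// and unparseable argument strings are kept verbatim.
fn tool_input(arguments: &str) -> Value {
    match serde_json::from_str::<Value>(arguments) {
        Ok(Value::Object(o)) => Value::Object(o),
        Ok(other) => json!({ "_": other }),
        Err(_) => json!({ "_raw": arguments }),
    }
}

fn main() {
    assert_eq!(tool_input("{\"q\":\"x\"}"), json!({ "q": "x" }));
    assert_eq!(tool_input("42"), json!({ "_": 42 }));
    assert_eq!(tool_input("not json"), json!({ "_raw": "not json" }));
}
```
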
-
-fn claude_content_to_blocks(content: Option<&Value>) -> Vec<Value> {
-    let Some(content) = content else {
-        return Vec::new();
-    };
-    match content {
-        Value::String(text) => vec![json!({ "type": "text", "text": text })],
-        Value::Array(items) => items
-            .iter()
-            .cloned()
-            .map(|mut item| {
-                normalize_text_block_in_place(&mut item);
-                item
-            })
-            .collect(),
-        _ => Vec::new(),
-    }
-}
-
-fn normalize_text_block_in_place(block: &mut Value) {
-    let Some(object) = block.as_object_mut() else {
-        return;
-    };
-    let block_type = object.get("type").and_then(Value::as_str).unwrap_or("");
-    if block_type != "text" {
-        return;
-    }
-    let text_value = object.get("text");
-    let new_text = text_value.and_then(extract_text_value);
-    if let Some(new_text) = new_text {
-        object.insert("text".to_string(), Value::String(new_text));
-        return;
-    }
-    // If text exists but is not convertible, coerce to empty string to satisfy schema.
-    if text_value.is_some() {
-        object.insert("text".to_string(), Value::String(String::new()));
-    }
-}
-
-fn extract_text_value(value: &Value) -> Option<String> {
-    match value {
-        Value::String(text) => Some(text.to_string()),
-        Value::Object(object) => {
-            if let Some(text) = object.get("text") {
-                return extract_text_value(text);
-            }
-            if let Some(text) = object.get("value") {
-                return extract_text_value(text);
-            }
-            None
-        }
-        _ => None,
-    }
-}
-
-fn push_claude_message(messages: &mut Vec<Value>, role: &str, blocks: Vec<Value>) {
-    let content = blocks;
-    if content.is_empty() {
-        return;
-    }
-    messages.push(json!({ "role": role, "content": content }));
-}
-
-fn push_tool_use_block(messages: &mut Vec<Value>, block: Value) {
-    if let Some(last) = messages.last_mut().and_then(Value::as_object_mut) {
-        if last.get("role").and_then(Value::as_str) == Some("assistant") {
-            if let Some(content) = last.get_mut("content").and_then(Value::as_array_mut) {
-                content.push(block);
-                return;
-            }
-        }
-    }
-    messages.push(json!({ "role": "assistant", "content": [block] }));
-}
-
-fn push_tool_result_block(messages: &mut Vec<Value>, block: Value) {
-    if let Some(last) = messages.last_mut().and_then(Value::as_object_mut) {
-        if last.get("role").and_then(Value::as_str) == Some("user") {
-            if let Some(content) = last.get_mut("content") {
-                ensure_claude_content_array_in_place(content);
-                if let Some(content) = content.as_array_mut() {
-                    content.push(block);
-                    return;
-                }
-            }
-        }
-    }
-    messages.push(json!({ "role": "user", "content": [block] }));
-}
-
-fn ensure_claude_content_array_in_place(content: &mut Value) {
-    if content.is_array() {
-        return;
-    }
-    if let Some(text) = content.as_str() {
-        *content = Value::Array(vec![json!({ "type": "text", "text": text })]);
-        return;
-    }
-    *content = Value::Array(Vec::new());
-}
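
The push helpers above encode one merge rule worth calling out: a `tool_use` block is appended to a trailing assistant message rather than opening a second consecutive assistant turn, which keeps Anthropic's role alternation intact. A minimal restatement of that rule:

```rust
use serde_json::{json, Value};

// Same merge rule as push_tool_use_block above: extend a trailing assistant
// message instead of emitting two consecutive assistant turns.
fn push_assistant_block(messages: &mut Vec<Value>, block: Value) {
    if let Some(last) = messages.last_mut() {
        if last["role"] == "assistant" {
            if let Some(content) = last["content"].as_array_mut() {
                content.push(block);
                return;
            }
        }
    }
    messages.push(json!({ "role": "assistant", "content": [block] }));
}

fn main() {
    let mut messages =
        vec![json!({ "role": "assistant", "content": [{ "type": "text", "text": "hi" }] })];
    push_assistant_block(
        &mut messages,
        json!({ "type": "tool_use", "id": "c1", "name": "t", "input": {} }),
    );
    assert_eq!(messages.len(), 1);
    assert_eq!(messages[0]["content"].as_array().unwrap().len(), 2);
}
```
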
item.get("content").and_then(Value::as_array) { - for part in content { - let Some(part) = part.as_object() else { - continue; - }; - match part.get("type").and_then(Value::as_str) { - Some("output_text") => { - if let Some(text) = part.get("text").and_then(Value::as_str) { - combined_text.push_str(text); - } - } - Some("reasoning_text") => { - if let Some(text) = part.get("text").and_then(Value::as_str) { - thinking_text.push_str(text); - } - } - _ => {} - } - } - } - } - Some("function_call") => { - if let Some(tool_use) = responses_function_call_to_tool_use(item) { - tool_uses.push(tool_use); - } - } - _ => {} - } - } - - let mut content = Vec::new(); - if !thinking_text.trim().is_empty() { - let signature = thinking_signature(&thinking_text); - let mut block = json!({ "type": "thinking", "thinking": thinking_text }); - if let (Some(signature), Some(block)) = - (signature, block.as_object_mut()) - { - block.insert("signature".to_string(), Value::String(signature)); - } - content.push(block); - } - if !combined_text.trim().is_empty() || tool_uses.is_empty() { - content.push(json!({ "type": "text", "text": combined_text })); - } - let has_tool_uses = !tool_uses.is_empty(); - content.extend(tool_uses); - - let finish_reason = - compat_reason::chat_finish_reason_from_response_object(object, has_tool_uses); - let stop_reason = compat_reason::anthropic_stop_reason_from_chat_finish_reason(finish_reason); - - let out = json!({ - "id": id, - "type": "message", - "role": "assistant", - "model": model, - "content": content, - "stop_reason": stop_reason, - "stop_sequence": null, - "usage": usage.unwrap_or_else(|| json!({ "input_tokens": 0, "output_tokens": 0 })) - }); - - serde_json::to_vec(&out) - .map(Bytes::from) - .map_err(|err| format!("Failed to serialize response: {err}")) -} - -pub(super) fn anthropic_response_to_responses(body: &Bytes) -> Result { - let value: Value = - serde_json::from_slice(body).map_err(|_| "Upstream response must be JSON.".to_string())?; - let Some(object) = value.as_object() else { - return Err("Upstream response must be a JSON object.".to_string()); - }; - - let id = object.get("id").and_then(Value::as_str).unwrap_or("resp_proxy"); - let model = object.get("model").and_then(Value::as_str).unwrap_or("unknown"); - let created_at = now_s(); - - let usage = object - .get("usage") - .and_then(Value::as_object) - .map(map_anthropic_usage_to_openai_usage); - - let content = object - .get("content") - .and_then(Value::as_array) - .map(|items| items.as_slice()) - .unwrap_or(&[]); - let mut output = Vec::new(); - - let mut combined_text = String::new(); - let mut tool_calls = Vec::new(); - for block in content { - let Some(block) = block.as_object() else { - continue; - }; - match block.get("type").and_then(Value::as_str) { - Some("text") => { - if let Some(text) = block.get("text").and_then(Value::as_str) { - combined_text.push_str(text); - } - } - Some("tool_use") => { - if let Some(call) = tool_use_to_responses_function_call(block) { - tool_calls.push(call); - } - } - _ => {} - } - } - - let parallel_tool_calls = tool_calls.len() > 1; - - if !combined_text.trim().is_empty() || tool_calls.is_empty() { - output.push(json!({ - "type": "message", - "id": "msg_proxy", - "status": "completed", - "role": "assistant", - "content": [ - { "type": "output_text", "text": combined_text, "annotations": [] } - ] - })); - } - output.extend(tool_calls); - - let out = json!({ - "id": id, - "object": "response", - "created_at": created_at, - "status": "completed", - "error": null, - "model": 
-
-fn responses_function_call_to_tool_use(item: &Map<String, Value>) -> Option<Value> {
-    let call_id = item.get("call_id").and_then(Value::as_str).unwrap_or("");
-    let item_id = item.get("id").and_then(Value::as_str).unwrap_or("");
-    let id = if !call_id.is_empty() { call_id } else { item_id };
-    if id.is_empty() {
-        return None;
-    }
-    let name = item.get("name").and_then(Value::as_str).unwrap_or("");
-    let arguments = item.get("arguments").and_then(Value::as_str).unwrap_or("");
-    let input = serde_json::from_str::<Value>(arguments)
-        .ok()
-        .and_then(|v| v.as_object().cloned().map(Value::Object))
-        .unwrap_or_else(|| json!({ "_raw": arguments }));
-    Some(json!({
-        "type": "tool_use",
-        "id": id,
-        "name": name,
-        "input": input
-    }))
-}
-
-fn tool_use_to_responses_function_call(block: &Map<String, Value>) -> Option<Value> {
-    let call_id = block.get("id").and_then(Value::as_str).unwrap_or("call_proxy");
-    let name = block.get("name").and_then(Value::as_str).unwrap_or("");
-    let input = block.get("input").cloned().unwrap_or_else(|| json!({}));
-    let arguments = serde_json::to_string(&input).unwrap_or_else(|_| "{}".to_string());
-    Some(json!({
-        "id": format!("fc_{call_id}"),
-        "type": "function_call",
-        "status": "completed",
-        "arguments": arguments,
-        "call_id": call_id,
-        "name": name
-    }))
-}
-
-fn thinking_signature(text: &str) -> Option<String> {
-    if text.trim().is_empty() {
-        return None;
-    }
-    let mut hasher = Sha256::new();
-    hasher.update(text.as_bytes());
-    Some(STANDARD.encode(hasher.finalize()))
-}
-
-fn map_openai_usage_to_anthropic_usage(usage: &Map<String, Value>) -> Value {
-    let input_tokens = usage
-        .get("input_tokens")
-        .or_else(|| usage.get("prompt_tokens"))
-        .and_then(Value::as_u64)
-        .unwrap_or(0);
-    let output_tokens = usage
-        .get("output_tokens")
-        .or_else(|| usage.get("completion_tokens"))
-        .and_then(Value::as_u64)
-        .unwrap_or(0);
-    json!({
-        "input_tokens": input_tokens,
-        "output_tokens": output_tokens
-    })
-}
-
-fn map_anthropic_usage_to_openai_usage(usage: &Map<String, Value>) -> Value {
-    let input_tokens = usage.get("input_tokens").and_then(Value::as_u64).unwrap_or(0);
-    let output_tokens = usage.get("output_tokens").and_then(Value::as_u64).unwrap_or(0);
-    let cache_read = usage.get("cache_read_input_tokens").and_then(Value::as_u64);
-    let cache_creation = usage.get("cache_creation_input_tokens").and_then(Value::as_u64);
-    json!({
-        "input_tokens": input_tokens,
-        "output_tokens": output_tokens,
-        "total_tokens": input_tokens + output_tokens,
-        "cache_read_input_tokens": cache_read,
-        "cache_creation_input_tokens": cache_creation
-    })
-}
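
Note that `thinking_signature` above produces a deterministic placeholder, base64 of the SHA-256 of the thinking text, not a verifiable Anthropic signature; it only needs to be stable across retries. A standalone equivalent:

```rust
use base64::{engine::general_purpose::STANDARD, Engine as _};
use sha2::{Digest, Sha256};

// Deterministic placeholder signature: base64(SHA-256(text)).
fn signature(text: &str) -> Option<String> {
    if text.trim().is_empty() {
        return None;
    }
    Some(STANDARD.encode(Sha256::digest(text.as_bytes())))
}

fn main() {
    assert_eq!(signature("   "), None);
    let sig = signature("think").unwrap();
    assert_eq!(sig.len(), 44); // 32 hash bytes encode to 44 base64 chars
}
```
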
diff --git a/src-tauri/src/proxy/anthropic_compat/tools.rs b/src-tauri/src/proxy/anthropic_compat/tools.rs
deleted file mode 100644
index 7316d2c..0000000
--- a/src-tauri/src/proxy/anthropic_compat/tools.rs
+++ /dev/null
@@ -1,173 +0,0 @@
-use serde_json::{json, Map, Value};
-
-// OpenAI Responses tool_choice <-> Anthropic Messages tool_choice mapping
-// Mirrors QuantumNous/new-api semantics:
-// - "required" <-> "any"
-// - parallel_tool_calls <-> disable_parallel_tool_use (negated)
-
-pub(super) fn map_responses_tools_to_anthropic(value: &Value) -> Value {
-    let Some(tools) = value.as_array() else {
-        return Value::Array(Vec::new());
-    };
-    let mapped = tools.iter().filter_map(map_responses_tool).collect::<Vec<_>>();
-    Value::Array(mapped)
-}
-
-fn map_responses_tool(value: &Value) -> Option<Value> {
-    let tool = value.as_object()?;
-    let tool_type = tool.get("type").and_then(Value::as_str).unwrap_or("");
-    if tool_type != "function" {
-        return None;
-    }
-
-    // Accept both Responses-style ({name, description, parameters}) and Chat-style ({function:{...}}).
-    if let Some(name) = tool.get("name").and_then(Value::as_str) {
-        let mut out = Map::new();
-        out.insert("name".to_string(), Value::String(name.to_string()));
-        if let Some(description) = tool.get("description") {
-            out.insert("description".to_string(), description.clone());
-        }
-        if let Some(parameters) = tool.get("parameters") {
-            out.insert("input_schema".to_string(), parameters.clone());
-        }
-        return Some(Value::Object(out));
-    }
-
-    let function = tool.get("function").and_then(Value::as_object)?;
-    let name = function.get("name").and_then(Value::as_str)?;
-    let mut out = Map::new();
-    out.insert("name".to_string(), Value::String(name.to_string()));
-    if let Some(description) = function.get("description") {
-        out.insert("description".to_string(), description.clone());
-    }
-    if let Some(parameters) = function.get("parameters") {
-        out.insert("input_schema".to_string(), parameters.clone());
-    }
-    Some(Value::Object(out))
-}
-
-pub(super) fn map_anthropic_tools_to_responses(value: &Value) -> Value {
-    let Some(tools) = value.as_array() else {
-        return Value::Array(Vec::new());
-    };
-    let mapped = tools
-        .iter()
-        .filter_map(map_anthropic_tool)
-        .collect::<Vec<_>>();
-    Value::Array(mapped)
-}
-
-fn map_anthropic_tool(value: &Value) -> Option<Value> {
-    let tool = value.as_object()?;
-    let name = tool.get("name").and_then(Value::as_str)?;
-    let mut out = Map::new();
-    out.insert("type".to_string(), json!("function"));
-    out.insert("name".to_string(), Value::String(name.to_string()));
-    if let Some(description) = tool.get("description") {
-        out.insert("description".to_string(), description.clone());
-    }
-    if let Some(input_schema) = tool.get("input_schema") {
-        out.insert("parameters".to_string(), input_schema.clone());
-    }
-    Some(Value::Object(out))
-}
-
-pub(super) fn map_responses_tool_choice_to_anthropic(
-    tool_choice: Option<&Value>,
-    parallel_tool_calls: Option<bool>,
-) -> Option<Value> {
-    let mut out = match tool_choice {
-        None => None,
-        Some(Value::String(choice)) => match choice.as_str() {
-            "auto" => Some(json!({ "type": "auto" })),
-            "required" => Some(json!({ "type": "any" })),
-            "none" => Some(json!({ "type": "none" })),
-            _ => None,
-        },
-        Some(Value::Object(choice)) => {
-            if choice.get("type").and_then(Value::as_str) != Some("function") {
-                None
-            } else {
-                let name = choice.get("name").and_then(Value::as_str).unwrap_or("");
-                if name.is_empty() {
-                    None
-                } else {
-                    Some(json!({ "type": "tool", "name": name }))
-                }
-            }
-        }
-        _ => None,
-    };
-
-    if let Some(parallel) = parallel_tool_calls {
-        let disable_parallel = !parallel;
-        if out.is_none() {
-            out = Some(json!({ "type": "auto" }));
-        }
-        if let Some(Value::Object(object)) = out.as_mut() {
-            object.insert(
-                "disable_parallel_tool_use".to_string(),
-                Value::Bool(disable_parallel),
-            );
-        }
-    }
-
-    out
-}
tool_choice.get("name").and_then(Value::as_str).unwrap_or(""); - if name.is_empty() { - None - } else { - Some(json!({ "type": "function", "name": name })) - } - } - _ => None, - }; - - let parallel_tool_calls = tool_choice - .get("disable_parallel_tool_use") - .and_then(Value::as_bool) - .map(|disable| !disable); - - (mapped_choice, parallel_tool_calls) -} - -pub(super) fn map_openai_stop_to_anthropic_stop_sequences(stop: Option<&Value>) -> Option { - let Some(stop) = stop else { - return None; - }; - match stop { - Value::String(_) => Some(Value::Array(vec![stop.clone()])), - Value::Array(items) => Some(Value::Array(items.clone())), - _ => None, - } -} - -pub(super) fn map_anthropic_stop_sequences_to_openai_stop(stop: Option<&Value>) -> Option { - let Some(stop) = stop else { - return None; - }; - let Some(items) = stop.as_array() else { - return None; - }; - match items.len() { - 0 => None, - 1 => Some(items[0].clone()), - _ => Some(Value::Array(items.clone())), - } -} - diff --git a/src-tauri/src/proxy/antigravity_compat.rs b/src-tauri/src/proxy/antigravity_compat.rs deleted file mode 100644 index 22c913a..0000000 --- a/src-tauri/src/proxy/antigravity_compat.rs +++ /dev/null @@ -1,426 +0,0 @@ -use axum::body::Bytes; -use futures_util::{stream::unfold, StreamExt}; -use serde_json::{json, Map, Value}; -use std::collections::VecDeque; -use sha2::{Digest, Sha256}; - -use crate::oauth_util::generate_state; -use crate::proxy::antigravity_schema::clean_json_schema_for_antigravity; -use crate::proxy::sse::SseEventParser; - -mod signature_cache; -mod claude; - -pub(crate) use claude::claude_request_to_antigravity; - -const DEFAULT_MODEL: &str = "gemini-1.5-flash"; -const THOUGHT_SIGNATURE_SENTINEL: &str = "skip_thought_signature_validator"; -const PAYLOAD_USER_AGENT: &str = "antigravity"; -const ANTIGRAVITY_SYSTEM_INSTRUCTION: &str = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. 
diff --git a/src-tauri/src/proxy/antigravity_compat.rs b/src-tauri/src/proxy/antigravity_compat.rs
deleted file mode 100644
index 22c913a..0000000
--- a/src-tauri/src/proxy/antigravity_compat.rs
+++ /dev/null
@@ -1,426 +0,0 @@
-use axum::body::Bytes;
-use futures_util::{stream::unfold, StreamExt};
-use serde_json::{json, Map, Value};
-use std::collections::VecDeque;
-use sha2::{Digest, Sha256};
-
-use crate::oauth_util::generate_state;
-use crate::proxy::antigravity_schema::clean_json_schema_for_antigravity;
-use crate::proxy::sse::SseEventParser;
-
-mod signature_cache;
-mod claude;
-
-pub(crate) use claude::claude_request_to_antigravity;
-
-const DEFAULT_MODEL: &str = "gemini-1.5-flash";
-const THOUGHT_SIGNATURE_SENTINEL: &str = "skip_thought_signature_validator";
-const PAYLOAD_USER_AGENT: &str = "antigravity";
-const ANTIGRAVITY_SYSTEM_INSTRUCTION: &str = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**";
-
-pub(crate) fn wrap_gemini_request(
-    body: &Bytes,
-    model_hint: Option<&str>,
-    project_id: Option<&str>,
-    _user_agent: &str,
-) -> Result<Bytes, String> {
-    let value: Value =
-        serde_json::from_slice(body).map_err(|_| "Request body must be JSON.".to_string())?;
-    let Some(mut request) = value.as_object().cloned() else {
-        return Err("Request body must be a JSON object.".to_string());
-    };
-
-    let model = map_antigravity_model(&extract_model(&mut request, model_hint));
-    let model_lower = model.to_lowercase();
-    let should_clean_tool_schema =
-        model_lower.contains("claude") || model_lower.contains("gemini-3-pro-high");
-    normalize_system_instruction(&mut request);
-    normalize_tool_schema(&mut request, should_clean_tool_schema);
-    ensure_system_instruction(&mut request, &model);
-    remove_safety_settings(&mut request);
-    ensure_tool_thought_signature(&mut request);
-    ensure_session_id(&mut request);
-    trim_generation_config(&mut request, &model);
-
-    let project = project_id
-        .map(str::trim)
-        .filter(|value| !value.is_empty())
-        .map(|value| value.to_string())
-        .unwrap_or_else(|| generate_project_id().unwrap_or_default());
-    let request_id = generate_agent_id("agent").unwrap_or_else(|_| "agent-unknown".to_string());
-
-    let mut root = Map::new();
-    root.insert("project".to_string(), Value::String(project));
-    root.insert("request".to_string(), Value::Object(request));
-    root.insert("model".to_string(), Value::String(model));
-    root.insert("requestId".to_string(), Value::String(request_id));
-    root.insert(
-        "userAgent".to_string(),
-        Value::String(PAYLOAD_USER_AGENT.to_string()),
-    );
-    root.insert("requestType".to_string(), Value::String("agent".to_string()));
-
-    serde_json::to_vec(&Value::Object(root))
-        .map(Bytes::from)
-        .map_err(|err| format!("Failed to serialize Antigravity request: {err}"))
-}
-
-pub(crate) fn unwrap_response(bytes: &Bytes) -> Result<Bytes, String> {
-    let value: Value = match serde_json::from_slice(bytes) {
-        Ok(value) => value,
-        Err(_) => return Ok(bytes.clone()),
-    };
-    if let Some(response) = value.get("response") {
-        return serde_json::to_vec(response)
-            .map(Bytes::from)
-            .map_err(|err| format!("Failed to serialize Antigravity response: {err}"));
-    }
-    if let Some(array) = value.as_array() {
-        let mut responses = Vec::new();
-        for item in array {
-            if let Some(response) = item.get("response") {
-                responses.push(response.clone());
-            }
-        }
-        if !responses.is_empty() {
-            return serde_json::to_vec(&responses)
-                .map(Bytes::from)
-                .map_err(|err| format!("Failed to serialize Antigravity response: {err}"));
-        }
-    }
-    Ok(bytes.clone())
-}
-
-pub(crate) fn stream_antigravity_to_gemini<E>(
-    upstream: impl futures_util::stream::Stream<Item = Result<Bytes, E>> + Unpin + Send + 'static,
-) -> impl futures_util::stream::Stream<Item = Result<Bytes, E>> + Send
-where
-    E: std::error::Error + Send + Sync + 'static,
-{
-    let state = AntigravityStreamState::new(upstream);
-    unfold(state, |state| async move { state.step().await })
-}
-
-struct AntigravityStreamState<S> {
-    upstream: S,
-    parser: SseEventParser,
-    out: VecDeque<Bytes>,
-    finished: bool,
-}
-
-impl<S, E> AntigravityStreamState<S>
-where
-    S: futures_util::stream::Stream<Item = Result<Bytes, E>> + Unpin + Send + 'static,
-    E: std::error::Error + Send + Sync + 'static,
-{
-    fn new(upstream: S) -> Self {
-        Self {
-            upstream,
-            parser: SseEventParser::new(),
-            out: VecDeque::new(),
-            finished: false,
-        }
-    }
-
-    async fn step(mut self) -> Option<(Result<Bytes, E>, Self)> {
-        loop {
-            if let Some(next) = self.out.pop_front() {
-                return Some((Ok(next), self));
-            }
-            if self.finished {
-                return None;
-            }
-            match self.upstream.next().await {
-                Some(Ok(chunk)) => {
-                    let mut events = Vec::new();
-                    self.parser.push_chunk(&chunk, |data| events.push(data));
-                    for data in events {
-                        self.push_event(&data);
-                    }
-                }
-                Some(Err(err)) => {
-                    self.finished = true;
-                    return Some((Err(err), self));
-                }
-                None => {
-                    self.finished = true;
-                    let mut events = Vec::new();
-                    self.parser.finish(|data| events.push(data));
-                    for data in events {
-                        self.push_event(&data);
-                    }
-                }
-            }
-        }
-    }
-
-    fn push_event(&mut self, data: &str) {
-        if data == "[DONE]" {
-            self.out
-                .push_back(Bytes::from(format!("data: {data}\n\n")));
-            return;
-        }
-        let Ok(value) = serde_json::from_str::<Value>(data) else {
-            return;
-        };
-        if let Some(response) = value.get("response") {
-            if let Ok(json) = serde_json::to_string(response) {
-                self.out
-                    .push_back(Bytes::from(format!("data: {json}\n\n")));
-            }
-        } else if let Ok(json) = serde_json::to_string(&value) {
-            self.out
                .push_back(Bytes::from(format!("data: {json}\n\n")));
-        }
-    }
-}
-
-fn extract_model(request: &mut Map<String, Value>, model_hint: Option<&str>) -> String {
-    let from_body = request
-        .get("model")
-        .and_then(Value::as_str)
-        .map(|value| value.trim())
-        .filter(|value| !value.is_empty())
-        .map(|value| value.to_string());
-    request.remove("model");
-    let hint = model_hint
-        .map(str::trim)
-        .filter(|value| !value.is_empty())
-        .map(|value| value.to_string());
-    from_body
-        .or(hint)
-        .unwrap_or_else(|| DEFAULT_MODEL.to_string())
-}
-
-pub(crate) fn map_antigravity_model(model: &str) -> String {
-    let trimmed = model.trim();
-    if trimmed.is_empty() {
-        return DEFAULT_MODEL.to_string();
-    }
-    // Align with CLIProxyAPIPlus conventions:
-    // - Some clients expose Claude models behind a "gemini-" prefix (e.g. gemini-claude-opus-4-5-thinking)
-    //   while Antigravity upstream uses the stable Claude name without the prefix.
-    if trimmed.starts_with("gemini-claude-") {
-        return trimmed.trim_start_matches("gemini-").to_string();
-    }
-
-    // Claude Code / Amp CLI may request date-suffixed Claude models (e.g. claude-opus-4-5-20251101).
-    // Antigravity does not expose date-suffixed IDs; map them to the stable Antigravity model names.
-    if let Some(mapped) = map_claude_date_model_to_antigravity(trimmed) {
-        return mapped;
-    }
-
-    trimmed.to_string()
-}
-
-// Unit tests are split into a separate file; `#[path]` keeps the `.test.rs` naming convention.
-#[cfg(test)]
-#[path = "antigravity_compat.test.rs"]
-mod tests;
-
-fn map_claude_date_model_to_antigravity(model: &str) -> Option<String> {
-    if !model.starts_with("claude-") {
-        return None;
-    }
-
-    // Allow optional "-thinking" suffix (some clients encode "thinking" in the model ID).
-    let (base, _has_thinking_suffix) = match model.strip_suffix("-thinking") {
-        Some(value) => (value, true),
-        None => (model, false),
-    };
-
-    // Detect the trailing date segment in `...-YYYYMMDD`.
-    let (without_date, date_suffix) = base.rsplit_once('-')?;
-    if date_suffix.len() != 8 || !date_suffix.chars().all(|ch| ch.is_ascii_digit()) {
-        return None;
-    }
-
-    // Known Claude 4.5 model families: map to the Antigravity stable names.
-    // NOTE: Antigravity appears to expose Sonnet/Opus (and their thinking variants) but not Haiku.
-    if without_date.starts_with("claude-opus-4-5") {
-        return Some("claude-opus-4-5-thinking".to_string());
-    }
-    if without_date.starts_with("claude-sonnet-4-5") {
-        return Some("claude-sonnet-4-5-thinking".to_string());
-    }
-    if without_date.starts_with("claude-haiku-4-5") {
-        // Follow CLIProxyAPIPlus example mapping: route Haiku to a close Gemini alternative.
-        return Some("gemini-2.5-flash".to_string());
-    }
-
-    None
-}
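
As a quick reference, a compact restatement of the rules implemented by `map_antigravity_model` above (same behavior for these model families; an illustrative rewrite, not a drop-in replacement):

```rust
// Rules: "gemini-claude-*" aliases lose the prefix; date-suffixed Claude 4.5
// IDs collapse to stable "-thinking" names; Haiku falls back to a Gemini model.
fn map_model(model: &str) -> String {
    let m = model.trim();
    if m.is_empty() {
        return "gemini-1.5-flash".to_string();
    }
    if let Some(rest) = m.strip_prefix("gemini-claude-") {
        return format!("claude-{rest}");
    }
    let base = m.strip_suffix("-thinking").unwrap_or(m);
    if let Some((name, date)) = base.rsplit_once('-') {
        if date.len() == 8 && date.chars().all(|c| c.is_ascii_digit()) {
            if name.starts_with("claude-opus-4-5") {
                return "claude-opus-4-5-thinking".to_string();
            }
            if name.starts_with("claude-sonnet-4-5") {
                return "claude-sonnet-4-5-thinking".to_string();
            }
            if name.starts_with("claude-haiku-4-5") {
                return "gemini-2.5-flash".to_string();
            }
        }
    }
    m.to_string()
}

fn main() {
    assert_eq!(map_model("claude-opus-4-5-20251101"), "claude-opus-4-5-thinking");
    assert_eq!(map_model("gemini-claude-opus-4-5-thinking"), "claude-opus-4-5-thinking");
    assert_eq!(map_model("claude-haiku-4-5-20251001"), "gemini-2.5-flash");
}
```
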
-
-fn normalize_system_instruction(request: &mut Map<String, Value>) {
-    if let Some(value) = request.remove("system_instruction") {
-        request.insert("systemInstruction".to_string(), value);
-    }
-}
-
-fn ensure_system_instruction(request: &mut Map<String, Value>, model: &str) {
-    let lower = model.to_lowercase();
-    if !(lower.contains("claude") || lower.contains("gemini-3-pro-high")) {
-        return;
-    }
-    let existing_parts = request
-        .get("systemInstruction")
-        .and_then(|value| value.get("parts"))
-        .and_then(Value::as_array)
-        .cloned();
-    let mut parts = vec![
-        json!({ "text": ANTIGRAVITY_SYSTEM_INSTRUCTION }),
-        json!({ "text": format!("Please ignore following [ignore]{ANTIGRAVITY_SYSTEM_INSTRUCTION}[/ignore]") }),
-    ];
-    if let Some(existing_parts) = existing_parts {
-        parts.extend(existing_parts);
-    }
-    let mut system_instruction = Map::new();
-    system_instruction.insert("role".to_string(), Value::String("user".to_string()));
-    system_instruction.insert("parts".to_string(), Value::Array(parts));
-    request.insert(
-        "systemInstruction".to_string(),
-        Value::Object(system_instruction),
-    );
-}
-
-fn normalize_tool_schema(request: &mut Map<String, Value>, enabled: bool) {
-    if !enabled {
-        return;
-    }
-    let Some(tools) = request.get_mut("tools").and_then(Value::as_array_mut) else {
-        return;
-    };
-    for tool_group in tools {
-        let Some(group) = tool_group.as_object_mut() else {
-            continue;
-        };
-        if let Some(decls) = group
-            .get_mut("functionDeclarations")
-            .and_then(Value::as_array_mut)
-        {
-            for decl in decls {
-                let Some(decl) = decl.as_object_mut() else {
-                    continue;
-                };
-                if let Some(parameters) = decl.remove("parametersJsonSchema") {
-                    decl.insert("parameters".to_string(), parameters);
-                }
-                if let Some(params) = decl.get_mut("parameters").and_then(Value::as_object_mut) {
-                    params.remove("$schema");
-                }
-                if let Some(params) = decl.get_mut("parameters") {
-                    clean_json_schema_for_antigravity(params);
-                }
-            }
-        }
-        if let Some(decls) = group
-            .get_mut("function_declarations")
-            .and_then(Value::as_array_mut)
-        {
-            for decl in decls {
-                let Some(decl) = decl.as_object_mut() else {
-                    continue;
-                };
-                if let Some(parameters) = decl.remove("parametersJsonSchema") {
-                    decl.insert("parameters".to_string(), parameters);
-                }
-                if let Some(params) = decl.get_mut("parameters").and_then(Value::as_object_mut) {
-                    params.remove("$schema");
-                }
-                if let Some(params) = decl.get_mut("parameters") {
-                    clean_json_schema_for_antigravity(params);
-                }
-            }
-        }
-    }
-}
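
`normalize_tool_schema` above applies the same per-declaration fix-up under both the camelCase and snake_case declaration keys. The core of that fix-up, minus the Antigravity-specific `clean_json_schema_for_antigravity` pass:

```rust
use serde_json::{json, Value};

// Per-declaration normalization: rename parametersJsonSchema to parameters
// and strip the `$schema` marker that upstream rejects.
fn normalize_declaration(decl: &mut Value) {
    let Some(obj) = decl.as_object_mut() else { return };
    if let Some(params) = obj.remove("parametersJsonSchema") {
        obj.insert("parameters".to_string(), params);
    }
    if let Some(params) = obj.get_mut("parameters").and_then(Value::as_object_mut) {
        params.remove("$schema");
    }
}

fn main() {
    let mut decl = json!({
        "name": "t",
        "parametersJsonSchema": {
            "$schema": "http://json-schema.org/draft-07/schema#",
            "type": "object"
        }
    });
    normalize_declaration(&mut decl);
    assert_eq!(decl, json!({ "name": "t", "parameters": { "type": "object" } }));
}
```
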
-
-fn remove_safety_settings(request: &mut Map<String, Value>) {
-    request.remove("safetySettings");
-    if let Some(obj) = request.get_mut("request").and_then(Value::as_object_mut) {
-        obj.remove("safetySettings");
-    }
-}
-
-fn ensure_tool_thought_signature(request: &mut Map<String, Value>) {
-    let Some(contents) = request.get_mut("contents").and_then(Value::as_array_mut) else {
-        return;
-    };
-    for content in contents {
-        let Some(parts) = content.get_mut("parts").and_then(Value::as_array_mut) else {
-            continue;
-        };
-        for part in parts {
-            let Some(obj) = part.as_object_mut() else {
-                continue;
-            };
-            if !(obj.contains_key("functionCall") || obj.contains_key("functionResponse")) {
-                continue;
-            }
-            obj.entry("thoughtSignature".to_string())
-                .or_insert_with(|| Value::String(THOUGHT_SIGNATURE_SENTINEL.to_string()));
-        }
-    }
-}
-
-fn ensure_session_id(request: &mut Map<String, Value>) {
-    let session_present = request
-        .get("sessionId")
-        .and_then(Value::as_str)
-        .map(|value| !value.trim().is_empty())
-        .unwrap_or(false);
-    if session_present {
-        return;
-    }
-    if let Some(session_id) = stable_session_id_from_contents(request) {
-        request.insert("sessionId".to_string(), Value::String(session_id));
-        return;
-    }
-    let session_id = generate_agent_id("sess").unwrap_or_else(|_| "sess-unknown".to_string());
-    request.insert("sessionId".to_string(), Value::String(session_id));
-}
-
-fn stable_session_id_from_contents(request: &Map<String, Value>) -> Option<String> {
-    let contents = request.get("contents")?.as_array()?;
-    for content in contents {
-        let role = content.get("role").and_then(Value::as_str)?;
-        if role != "user" {
-            continue;
-        }
-        let parts = content.get("parts").and_then(Value::as_array)?;
-        let first = parts.first()?;
-        let text = first.get("text").and_then(Value::as_str)?;
-        let trimmed = text.trim();
-        if trimmed.is_empty() {
-            continue;
-        }
-        let mut hasher = Sha256::new();
-        hasher.update(trimmed.as_bytes());
-        let hash = hasher.finalize();
-        let mut bytes = [0_u8; 8];
-        bytes.copy_from_slice(&hash[..8]);
-        let value = u64::from_be_bytes(bytes) & 0x7FFF_FFFF_FFFF_FFFF;
-        return Some(format!("-{value}"));
-    }
-    None
-}
-
-fn trim_generation_config(request: &mut Map<String, Value>, model: &str) {
-    if model.to_lowercase().contains("claude") {
-        return;
-    }
-    let Some(gen) = request.get_mut("generationConfig").and_then(Value::as_object_mut) else {
-        return;
-    };
-    gen.remove("maxOutputTokens");
-}
-
-fn generate_agent_id(prefix: &str) -> Result<String, String> {
-    let state = generate_state(prefix)?;
-    Ok(state)
-}
-
-fn generate_project_id() -> Result<String, String> {
-    let state = generate_state("project")?;
-    Ok(state)
-}
diff --git a/src-tauri/src/proxy/antigravity_compat.test.rs b/src-tauri/src/proxy/antigravity_compat.test.rs
deleted file mode 100644
index 170ec89..0000000
--- a/src-tauri/src/proxy/antigravity_compat.test.rs
+++ /dev/null
@@ -1,171 +0,0 @@
-use super::map_antigravity_model;
-use super::wrap_gemini_request;
-use axum::body::Bytes;
-use serde_json::json;
-
-#[test]
-fn keeps_claude_model_unchanged() {
-    assert_eq!(
-        map_antigravity_model("claude-3-5-sonnet-20241022"),
-        "claude-3-5-sonnet-20241022"
-    );
-}
-
-#[test]
-fn trims_model_name() {
-    assert_eq!(map_antigravity_model(" gemini-1.5-pro "), "gemini-1.5-pro");
-}
-
-#[test]
-fn returns_default_on_empty_model() {
-    assert_eq!(map_antigravity_model(""), "gemini-1.5-flash");
-}
-
-#[test]
-fn strips_gemini_prefix_for_claude_aliases() {
-    assert_eq!(
-        map_antigravity_model("gemini-claude-opus-4-5-thinking"),
-        "claude-opus-4-5-thinking"
-    );
-}
-
-#[test]
-fn maps_claude_opus_date_model_to_stable_thinking_model() {
-    assert_eq!(
-        map_antigravity_model("claude-opus-4-5-20251101"),
-        "claude-opus-4-5-thinking"
-    );
-    assert_eq!(
-        map_antigravity_model("claude-opus-4-5-20251101-thinking"),
-        "claude-opus-4-5-thinking"
-    );
-}
-
-#[test]
-fn maps_claude_sonnet_date_model_to_stable_thinking_model() {
-    assert_eq!(
-        map_antigravity_model("claude-sonnet-4-5-20250929"),
-        "claude-sonnet-4-5-thinking"
-    );
-}
-
-#[test]
-fn maps_claude_haiku_date_model_to_gemini_fallback() {
-    assert_eq!(
-        map_antigravity_model("claude-haiku-4-5-20251001"),
-        "gemini-2.5-flash"
-    );
-}
-
-#[test]
-fn injects_antigravity_system_instruction_for_claude() {
-    let request = json!({
-        "model": 
"claude-3-5-sonnet-20241022", - "contents": [ - { "role": "user", "parts": [{ "text": "hello" }] } - ] - }); - let bytes = Bytes::from(request.to_string()); - let wrapped = wrap_gemini_request(&bytes, None, None, "ua").expect("wrap ok"); - let value: serde_json::Value = serde_json::from_slice(&wrapped).expect("wrapped json"); - let system = value["request"]["systemInstruction"].clone(); - assert_eq!(system["role"].as_str(), Some("user")); - let parts = system["parts"].as_array().expect("parts array"); - assert!(parts.len() >= 2); -} - -#[test] -fn cleans_schema_unsupported_fields() { - let request = json!({ - "model": "claude-3-5-sonnet-20241022", - "contents": [ - { "role": "user", "parts": [{ "text": "hello" }] } - ], - "tools": [ - { - "function_declarations": [ - { - "name": "t", - "parametersJsonSchema": { - "type": "object", - "properties": { - "count": { - "type": "number", - "exclusiveMinimum": 0 - }, - "name": { - "type": "string", - "propertyNames": { "pattern": "^[a-z]+$" } - } - } - } - } - ] - } - ] - }); - let bytes = Bytes::from(request.to_string()); - let wrapped = wrap_gemini_request(&bytes, None, None, "ua").expect("wrap ok"); - let value: serde_json::Value = serde_json::from_slice(&wrapped).expect("wrapped json"); - let schema = &value["request"]["tools"][0]["function_declarations"][0]["parameters"]; - let count = schema.get("properties").and_then(|v| v.get("count")); - let name = schema.get("properties").and_then(|v| v.get("name")); - assert!(count.and_then(|v| v.get("exclusiveMinimum")).is_none()); - assert!(name.and_then(|v| v.get("propertyNames")).is_none()); -} - -#[test] -fn cleans_schema_for_gemini_3_pro_high() { - let request = json!({ - "model": "gemini-3-pro-high", - "contents": [ - { "role": "user", "parts": [{ "text": "hello" }] } - ], - "tools": [ - { - "function_declarations": [ - { - "name": "t", - "parametersJsonSchema": { - "type": "object", - "properties": { - "count": { - "type": "number", - "exclusiveMinimum": 0 - } - } - } - } - ] - } - ] - }); - let bytes = Bytes::from(request.to_string()); - let wrapped = wrap_gemini_request(&bytes, None, None, "ua").expect("wrap ok"); - let value: serde_json::Value = serde_json::from_slice(&wrapped).expect("wrapped json"); - let schema = &value["request"]["tools"][0]["function_declarations"][0]["parameters"]; - let count = schema.get("properties").and_then(|v| v.get("count")); - assert!(count.and_then(|v| v.get("exclusiveMinimum")).is_none()); -} - -#[test] -fn keeps_existing_tool_config_mode() { - let request = json!({ - "model": "claude-3-5-sonnet-20241022", - "contents": [ - { "role": "user", "parts": [{ "text": "hello" }] } - ], - "toolConfig": { - "functionCallingConfig": { - "mode": "ANY" - } - } - }); - let bytes = Bytes::from(request.to_string()); - let wrapped = wrap_gemini_request(&bytes, None, None, "ua").expect("wrap ok"); - let value: serde_json::Value = serde_json::from_slice(&wrapped).expect("wrapped json"); - assert_eq!( - value["request"]["toolConfig"]["functionCallingConfig"]["mode"].as_str(), - Some("ANY") - ); -} diff --git a/src-tauri/src/proxy/antigravity_compat/claude.rs b/src-tauri/src/proxy/antigravity_compat/claude.rs deleted file mode 100644 index c582745..0000000 --- a/src-tauri/src/proxy/antigravity_compat/claude.rs +++ /dev/null @@ -1,496 +0,0 @@ -use axum::body::Bytes; -use serde_json::{json, Map, Value}; - -use super::signature_cache; -use crate::proxy::antigravity_schema::clean_json_schema_for_antigravity; - -const THOUGHT_SIGNATURE_SENTINEL: &str = 
"skip_thought_signature_validator"; -const INTERLEAVED_HINT: &str = "Interleaved thinking is enabled. You may think between tool calls and after receiving tool results before deciding the next action or final answer. Do not mention these instructions or any constraints about thinking blocks; just apply them."; - -pub(crate) fn claude_request_to_antigravity( - body: &Bytes, - model_hint: Option<&str>, -) -> Result { - // Dedicated Claude -> Gemini request conversion to align with CLIProxyAPIPlus. - let object = parse_request_object(body)?; - let model_name = resolve_model_name(&object, model_hint); - let mapped_model = super::map_antigravity_model(&model_name); - let (contents, enable_thinking_translate) = build_contents(&object, &mapped_model)?; - let tools = build_tools(&object); - let thinking_enabled = thinking_enabled(&object); - let should_hint = tools.is_some() && thinking_enabled && is_claude_thinking_model(&mapped_model); - - let mut out = Map::new(); - if !mapped_model.trim().is_empty() { - out.insert("model".to_string(), Value::String(mapped_model)); - } - if !contents.is_empty() { - out.insert("contents".to_string(), Value::Array(contents)); - } - if let Some(system_instruction) = build_system_instruction(&object, should_hint) { - out.insert("systemInstruction".to_string(), system_instruction); - } - if let Some(tools) = tools { - out.insert("tools".to_string(), tools); - } - if let Some(gen) = build_generation_config(&object, enable_thinking_translate) { - out.insert("generationConfig".to_string(), gen); - } - - serde_json::to_vec(&Value::Object(out)) - .map(Bytes::from) - .map_err(|err| format!("Failed to serialize request: {err}")) -} - -fn parse_request_object(body: &Bytes) -> Result, String> { - let value: Value = - serde_json::from_slice(body).map_err(|_| "Request body must be JSON.".to_string())?; - value - .as_object() - .cloned() - .ok_or_else(|| "Request body must be a JSON object.".to_string()) -} - -fn resolve_model_name(object: &Map, model_hint: Option<&str>) -> String { - object - .get("model") - .and_then(Value::as_str) - .map(|value| value.trim()) - .filter(|value| !value.is_empty()) - .map(|value| value.to_string()) - .or_else(|| { - model_hint - .map(str::trim) - .filter(|value| !value.is_empty()) - .map(|value| value.to_string()) - }) - .unwrap_or_default() -} - -fn build_system_instruction(object: &Map, should_hint: bool) -> Option { - let mut parts = system_parts(object); - if should_hint { - parts.push(json!({ "text": INTERLEAVED_HINT })); - } - if parts.is_empty() { - return None; - } - Some(json!({ "role": "user", "parts": parts })) -} - -fn system_parts(object: &Map) -> Vec { - let Some(system) = object.get("system") else { - return Vec::new(); - }; - match system { - Value::String(text) => system_parts_from_text(text), - Value::Array(items) => items - .iter() - .filter_map(|item| item.as_object()) - .filter(|item| item.get("type").and_then(Value::as_str) == Some("text")) - .filter_map(|item| item.get("text").and_then(Value::as_str)) - .flat_map(system_parts_from_text) - .collect(), - _ => Vec::new(), - } -} - -fn system_parts_from_text(text: &str) -> Vec { - let trimmed = text.trim(); - if trimmed.is_empty() { - Vec::new() - } else { - vec![json!({ "text": trimmed })] - } -} - -fn thinking_enabled(object: &Map) -> bool { - object - .get("thinking") - .and_then(Value::as_object) - .and_then(|thinking| thinking.get("type")) - .and_then(Value::as_str) - == Some("enabled") -} - -fn is_claude_thinking_model(model_name: &str) -> bool { - let lower = 
model_name.to_lowercase(); - lower.contains("claude") && lower.contains("thinking") -} - -fn build_contents( - object: &Map, - model_name: &str, -) -> Result<(Vec, bool), String> { - let Some(messages) = object.get("messages").and_then(Value::as_array) else { - return Ok((Vec::new(), true)); - }; - let mut contents = Vec::with_capacity(messages.len()); - let mut enable_thinking_translate = true; - - for message in messages { - let Some(message) = message.as_object() else { - continue; - }; - let role = message.get("role").and_then(Value::as_str).unwrap_or("user"); - let role = if role == "assistant" { "model" } else { role }; - let mut parts = Vec::new(); - let mut current_signature = String::new(); - match message.get("content") { - Some(Value::Array(items)) => { - for item in items { - let Some(item) = item.as_object() else { - continue; - }; - let block_type = item.get("type").and_then(Value::as_str).unwrap_or(""); - handle_block( - item, - block_type, - model_name, - &mut current_signature, - &mut enable_thinking_translate, - &mut parts, - ); - } - } - Some(Value::String(text)) => push_text_part(text, &mut parts), - _ => {} - } - reorder_thinking_parts(role, &mut parts); - contents.push(json!({ "role": role, "parts": parts })); - } - - Ok((contents, enable_thinking_translate)) -} - -fn handle_block( - item: &Map, - block_type: &str, - model_name: &str, - current_signature: &mut String, - enable_thinking_translate: &mut bool, - parts: &mut Vec, -) { - match block_type { - "thinking" => { - handle_thinking_block(item, model_name, current_signature, enable_thinking_translate, parts); - } - "text" => { - if let Some(text) = item.get("text").and_then(Value::as_str) { - push_text_part(text, parts); - } - } - "tool_use" => { - if let Some(part) = tool_use_to_part(item, model_name, current_signature) { - parts.push(part); - } - } - "tool_result" => { - if let Some(part) = tool_result_to_part(item) { - parts.push(part); - } - } - "image" => { - if let Some(part) = image_to_part(item) { - parts.push(part); - } - } - _ => {} - } -} - -fn handle_thinking_block( - item: &Map, - model_name: &str, - current_signature: &mut String, - enable_thinking_translate: &mut bool, - parts: &mut Vec, -) { - let thinking_text = extract_text_value(item.get("thinking")).unwrap_or_default(); - let signature = resolve_thinking_signature(model_name, &thinking_text, item); - if !signature_cache::has_valid_signature(model_name, &signature) { - *enable_thinking_translate = false; - return; - } - *current_signature = signature.clone(); - if !thinking_text.is_empty() { - signature_cache::cache_signature(model_name, &thinking_text, &signature); - } - let mut part = json!({ "thought": true }); - if !thinking_text.is_empty() { - if let Some(part) = part.as_object_mut() { - part.insert("text".to_string(), Value::String(thinking_text)); - } - } - if !signature.is_empty() { - if let Some(part) = part.as_object_mut() { - part.insert("thoughtSignature".to_string(), Value::String(signature)); - } - } - parts.push(part); -} - -fn resolve_thinking_signature( - model_name: &str, - thinking_text: &str, - item: &Map, -) -> String { - let cached = signature_cache::get_cached_signature(model_name, thinking_text); - if !cached.is_empty() { - return cached; - } - let signature = item.get("signature").and_then(Value::as_str).unwrap_or(""); - parse_client_signature(model_name, signature) -} - -fn parse_client_signature(model_name: &str, signature: &str) -> String { - if signature.contains('#') { - let mut parts = signature.splitn(2, '#'); - 
let prefix = parts.next().unwrap_or(""); - let value = parts.next().unwrap_or(""); - if prefix == model_name { - return value.to_string(); - } - } - signature.to_string() -} - -fn tool_use_to_part( - item: &Map, - model_name: &str, - current_signature: &str, -) -> Option { - let name = item.get("name").and_then(Value::as_str).unwrap_or(""); - let id = item.get("id").and_then(Value::as_str).unwrap_or(""); - let args_raw = parse_tool_use_input(item.get("input"))?; - - let mut part = json!({ - "functionCall": { - "name": name, - "args": args_raw - } - }); - if !id.is_empty() { - if let Some(call) = part.get_mut("functionCall").and_then(Value::as_object_mut) { - call.insert("id".to_string(), Value::String(id.to_string())); - } - } - - let signature = if signature_cache::has_valid_signature(model_name, current_signature) { - current_signature.to_string() - } else { - // Antigravity requires thoughtSignature for tool calls; use sentinel when missing. - THOUGHT_SIGNATURE_SENTINEL.to_string() - }; - if let Some(part) = part.as_object_mut() { - part.insert("thoughtSignature".to_string(), Value::String(signature)); - } - Some(part) -} - -fn parse_tool_use_input(input: Option<&Value>) -> Option { - match input { - Some(Value::Object(object)) => Some(Value::Object(object.clone())), - Some(Value::String(raw)) => serde_json::from_str::(raw).ok().and_then(|val| { - if val.is_object() { - Some(val) - } else { - None - } - }), - _ => None, - } -} - -fn tool_result_to_part(item: &Map) -> Option { - let tool_call_id = item.get("tool_use_id").and_then(Value::as_str).unwrap_or(""); - if tool_call_id.is_empty() { - return None; - } - let func_name = tool_call_name_from_id(tool_call_id); - let response = tool_result_response(item.get("content")); - Some(json!({ - "functionResponse": { - "id": tool_call_id, - "name": func_name, - "response": { "result": response } - } - })) -} - -fn tool_call_name_from_id(tool_call_id: &str) -> String { - let parts = tool_call_id.split('-').collect::>(); - if parts.len() <= 2 { - return tool_call_id.to_string(); - } - parts[..parts.len() - 2].join("-") -} - -fn tool_result_response(value: Option<&Value>) -> Value { - match value { - Some(Value::String(text)) => Value::String(text.to_string()), - Some(Value::Array(items)) => { - if items.len() == 1 { - items[0].clone() - } else { - Value::Array(items.clone()) - } - } - Some(Value::Object(object)) => Value::Object(object.clone()), - Some(other) => other.clone(), - None => Value::String(String::new()), - } -} - -fn image_to_part(item: &Map) -> Option { - let source = item.get("source").and_then(Value::as_object)?; - if source.get("type").and_then(Value::as_str) != Some("base64") { - return None; - } - let media_type = source - .get("media_type") - .and_then(Value::as_str) - .unwrap_or("image/png"); - let data = source.get("data").and_then(Value::as_str)?; - Some(json!({ - "inlineData": { - "mime_type": media_type, - "data": data - } - })) -} - -fn push_text_part(text: &str, parts: &mut Vec) { - if !text.is_empty() { - parts.push(json!({ "text": text })); - } -} - -fn reorder_thinking_parts(role: &str, parts: &mut Vec) { - if role != "model" || parts.is_empty() { - return; - } - let mut thinking = Vec::new(); - let mut others = Vec::new(); - for part in parts.iter() { - if part.get("thought").and_then(Value::as_bool) == Some(true) { - thinking.push(part.clone()); - } else { - others.push(part.clone()); - } - } - if thinking.is_empty() { - return; - } - let first_is_thinking = parts - .first() - .and_then(|part| 
part.get("thought").and_then(Value::as_bool)) - .unwrap_or(false); - if first_is_thinking && thinking.len() <= 1 { - return; - } - parts.clear(); - parts.extend(thinking); - parts.extend(others); -} - -fn build_tools(object: &Map) -> Option { - let tools = object.get("tools").and_then(Value::as_array)?; - let mut decls = Vec::new(); - for tool in tools { - let Some(tool) = tool.as_object() else { - continue; - }; - let input_schema = tool.get("input_schema"); - let Some(schema) = input_schema.and_then(Value::as_object) else { - continue; - }; - let mut tool_obj = Map::new(); - for (key, value) in tool.iter() { - if key == "input_schema" { - continue; - } - if is_allowed_tool_key(key) { - tool_obj.insert(key.to_string(), value.clone()); - } - } - let mut schema_value = Value::Object(schema.clone()); - clean_json_schema_for_antigravity(&mut schema_value); - tool_obj.insert("parametersJsonSchema".to_string(), schema_value); - decls.push(Value::Object(tool_obj)); - } - if decls.is_empty() { - None - } else { - Some(json!([{ "functionDeclarations": decls }])) - } -} - -fn is_allowed_tool_key(key: &str) -> bool { - matches!( - key, - "name" - | "description" - | "behavior" - | "parameters" - | "parametersJsonSchema" - | "response" - | "responseJsonSchema" - ) -} - -fn build_generation_config(object: &Map, enable_thinking: bool) -> Option { - let mut gen = Map::new(); - if enable_thinking { - if let Some(thinking) = object.get("thinking").and_then(Value::as_object) { - if thinking.get("type").and_then(Value::as_str) == Some("enabled") { - if let Some(budget) = thinking.get("budget_tokens").and_then(Value::as_i64) { - gen.insert( - "thinkingConfig".to_string(), - json!({ - "thinkingBudget": budget, - "includeThoughts": true - }), - ); - } - } - } - } - if let Some(value) = object.get("temperature").and_then(Value::as_f64) { - gen.insert("temperature".to_string(), json!(value)); - } - if let Some(value) = object.get("top_p").and_then(Value::as_f64) { - gen.insert("topP".to_string(), json!(value)); - } - if let Some(value) = object.get("top_k").and_then(Value::as_i64) { - gen.insert("topK".to_string(), json!(value)); - } - if let Some(value) = object.get("max_tokens").and_then(Value::as_i64) { - gen.insert("maxOutputTokens".to_string(), json!(value)); - } - if gen.is_empty() { - None - } else { - Some(Value::Object(gen)) - } -} - -fn extract_text_value(value: Option<&Value>) -> Option { - match value { - Some(Value::String(text)) => Some(text.to_string()), - Some(Value::Object(object)) => { - if let Some(text) = object.get("text") { - return extract_text_value(Some(text)); - } - if let Some(text) = object.get("value") { - return extract_text_value(Some(text)); - } - None - } - _ => None, - } -} - -#[cfg(test)] -#[path = "claude.test.rs"] -mod tests; diff --git a/src-tauri/src/proxy/antigravity_compat/signature_cache.rs b/src-tauri/src/proxy/antigravity_compat/signature_cache.rs deleted file mode 100644 index a9095da..0000000 --- a/src-tauri/src/proxy/antigravity_compat/signature_cache.rs +++ /dev/null @@ -1,107 +0,0 @@ -use sha2::{Digest, Sha256}; -use std::collections::HashMap; -use std::sync::{Mutex, OnceLock}; -use std::time::{Duration, Instant}; - -const SIGNATURE_CACHE_TTL: Duration = Duration::from_secs(3 * 60 * 60); -const SIGNATURE_TEXT_HASH_LEN: usize = 16; -const MIN_VALID_SIGNATURE_LEN: usize = 50; -const GEMINI_SKIP_SENTINEL: &str = "skip_thought_signature_validator"; - -type Cache = HashMap>; - -#[derive(Clone)] -struct SignatureEntry { - signature: String, - touched: Instant, -} - 
-static SIGNATURE_CACHE: OnceLock> = OnceLock::new(); - -fn cache_lock() -> std::sync::MutexGuard<'static, Cache> { - SIGNATURE_CACHE - .get_or_init(|| Mutex::new(HashMap::new())) - .lock() - .unwrap_or_else(|err| err.into_inner()) -} - -pub(crate) fn cache_signature(model_name: &str, text: &str, signature: &str) { - if text.trim().is_empty() || signature.trim().is_empty() { - return; - } - if signature.len() < MIN_VALID_SIGNATURE_LEN { - return; - } - let group_key = model_group_key(model_name); - let text_hash = hash_text(text); - let mut cache = cache_lock(); - let group = cache.entry(group_key).or_insert_with(HashMap::new); - group.insert( - text_hash, - SignatureEntry { - signature: signature.to_string(), - touched: Instant::now(), - }, - ); -} - -pub(crate) fn get_cached_signature(model_name: &str, text: &str) -> String { - let group_key = model_group_key(model_name); - if text.trim().is_empty() { - return fallback_signature(&group_key); - } - let text_hash = hash_text(text); - let mut cache = cache_lock(); - let Some(group) = cache.get_mut(&group_key) else { - return fallback_signature(&group_key); - }; - let Some(entry) = group.get_mut(&text_hash) else { - return fallback_signature(&group_key); - }; - if entry.touched.elapsed() > SIGNATURE_CACHE_TTL { - group.remove(&text_hash); - return fallback_signature(&group_key); - } - entry.touched = Instant::now(); - entry.signature.clone() -} - -pub(crate) fn has_valid_signature(model_name: &str, signature: &str) -> bool { - if signature.trim().is_empty() { - return false; - } - if signature == GEMINI_SKIP_SENTINEL { - return model_group_key(model_name) == "gemini"; - } - signature.len() >= MIN_VALID_SIGNATURE_LEN -} - -fn fallback_signature(group_key: &str) -> String { - if group_key == "gemini" { - GEMINI_SKIP_SENTINEL.to_string() - } else { - String::new() - } -} - -fn model_group_key(model_name: &str) -> String { - let lower = model_name.to_lowercase(); - if lower.contains("gpt") { - return "gpt".to_string(); - } - if lower.contains("claude") { - return "claude".to_string(); - } - if lower.contains("gemini") { - return "gemini".to_string(); - } - model_name.trim().to_string() -} - -fn hash_text(text: &str) -> String { - let mut hasher = Sha256::new(); - hasher.update(text.as_bytes()); - let digest = hasher.finalize(); - let hex = format!("{:x}", digest); - hex.chars().take(SIGNATURE_TEXT_HASH_LEN).collect() -} diff --git a/src-tauri/src/proxy/antigravity_schema.rs b/src-tauri/src/proxy/antigravity_schema.rs deleted file mode 100644 index f852b7e..0000000 --- a/src-tauri/src/proxy/antigravity_schema.rs +++ /dev/null @@ -1,775 +0,0 @@ -use serde_json::{Map, Value}; -use std::collections::HashMap; - -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -enum PathSegment { - Key(String), - Index(usize), -} - -type Path = Vec; - -const UNSUPPORTED_CONSTRAINTS: [&str; 10] = [ - "minLength", - "maxLength", - "exclusiveMinimum", - "exclusiveMaximum", - "pattern", - "minItems", - "maxItems", - "format", - "default", - "examples", -]; - -pub(crate) fn clean_json_schema_for_antigravity(schema: &mut Value) { - convert_refs_to_hints(schema); - convert_const_to_enum(schema); - convert_enum_values_to_strings(schema); - add_enum_hints(schema); - add_additional_properties_hints(schema); - move_constraints_to_description(schema); - - merge_all_of(schema); - flatten_any_of_one_of(schema); - flatten_type_arrays(schema); - - remove_unsupported_keywords(schema); - cleanup_required_fields(schema); - add_empty_schema_placeholder(schema); -} - -fn 
convert_refs_to_hints(schema: &mut Value) { - let mut paths = collect_paths(schema, "$ref"); - sort_by_depth(&mut paths); - for path in paths { - let Some(value) = get_value(schema, &path) else { - continue; - }; - let ref_val = value.as_str().unwrap_or_default(); - let def_name = ref_val.rsplit('/').next().unwrap_or(ref_val); - let mut hint = format!("See: {def_name}"); - let parent_path = match parent_path(&path) { - Some(parent) => parent, - None => continue, - }; - if let Some(existing) = get_description(schema, &parent_path) { - if !existing.is_empty() { - hint = format!("{existing} ({hint})"); - } - } - let mut replacement = Map::new(); - replacement.insert("type".to_string(), Value::String("object".to_string())); - replacement.insert("description".to_string(), Value::String(hint)); - let _ = set_value_at_path(schema, &parent_path, Value::Object(replacement)); - } -} - -fn convert_const_to_enum(schema: &mut Value) { - let paths = collect_paths(schema, "const"); - for path in paths { - let value = get_value(schema, &path).cloned(); - let Some(value) = value else { - continue; - }; - let parent_path = match parent_path(&path) { - Some(parent) => parent, - None => continue, - }; - let Some(parent) = get_object_mut(schema, &parent_path) else { - continue; - }; - if !parent.contains_key("enum") { - parent.insert("enum".to_string(), Value::Array(vec![value.clone()])); - } - } -} - -fn convert_enum_values_to_strings(schema: &mut Value) { - let paths = collect_paths(schema, "enum"); - for path in paths { - let Some(Value::Array(values)) = get_value_mut(schema, &path) else { - continue; - }; - let mut needs_conversion = false; - for item in values.iter() { - if !item.is_string() { - needs_conversion = true; - break; - } - } - if !needs_conversion { - continue; - } - let next = values - .iter() - .map(value_to_string) - .map(Value::String) - .collect::>(); - *values = next; - } -} - -fn add_enum_hints(schema: &mut Value) { - let paths = collect_paths(schema, "enum"); - for path in paths { - let Some(Value::Array(values)) = get_value(schema, &path) else { - continue; - }; - if values.len() <= 1 || values.len() > 10 { - continue; - } - let hint = values - .iter() - .map(value_to_string) - .collect::>() - .join(", "); - let parent_path = match parent_path(&path) { - Some(parent) => parent, - None => continue, - }; - append_hint(schema, &parent_path, &format!("Allowed: {hint}")); - } -} - -fn add_additional_properties_hints(schema: &mut Value) { - let paths = collect_paths(schema, "additionalProperties"); - for path in paths { - let Some(Value::Bool(false)) = get_value(schema, &path) else { - continue; - }; - let parent_path = match parent_path(&path) { - Some(parent) => parent, - None => continue, - }; - append_hint(schema, &parent_path, "No extra properties allowed"); - } -} - -fn move_constraints_to_description(schema: &mut Value) { - for key in UNSUPPORTED_CONSTRAINTS { - let paths = collect_paths(schema, key); - for path in paths { - let Some(value) = get_value(schema, &path) else { - continue; - }; - if value.is_object() || value.is_array() { - continue; - } - let parent_path = match parent_path(&path) { - Some(parent) => parent, - None => continue, - }; - if is_property_definition(&parent_path) { - continue; - } - append_hint( - schema, - &parent_path, - &format!("{key}: {}", value_to_string(value)), - ); - } - } -} - -fn merge_all_of(schema: &mut Value) { - let mut paths = collect_paths(schema, "allOf"); - sort_by_depth(&mut paths); - for path in paths { - let items = match 
get_value(schema, &path).and_then(Value::as_array) { - Some(items) => items.clone(), - None => continue, - }; - let parent_path = match parent_path(&path) { - Some(parent) => parent, - None => continue, - }; - { - let Some(parent) = get_object_mut(schema, &parent_path) else { - continue; - }; - for item in items { - if let Some(props) = item.get("properties").and_then(Value::as_object) { - let target = parent - .entry("properties".to_string()) - .or_insert_with(|| Value::Object(Map::new())); - if let Some(target) = target.as_object_mut() { - for (key, value) in props { - target.insert(key.clone(), value.clone()); - } - } - } - if let Some(req) = item.get("required").and_then(Value::as_array) { - let required = parent - .entry("required".to_string()) - .or_insert_with(|| Value::Array(Vec::new())); - let Some(required) = required.as_array_mut() else { - continue; - }; - for value in req { - if let Some(text) = value.as_str() { - if !required.iter().any(|item| item.as_str() == Some(text)) { - required.push(Value::String(text.to_string())); - } - } - } - } - } - } - let _ = delete_at_path(schema, &path); - } -} - -fn flatten_any_of_one_of(schema: &mut Value) { - for key in ["anyOf", "oneOf"] { - let mut paths = collect_paths(schema, key); - sort_by_depth(&mut paths); - for path in paths { - let Some(Value::Array(items)) = get_value(schema, &path) else { - continue; - }; - if items.is_empty() { - continue; - } - let parent_path = match parent_path(&path) { - Some(parent) => parent, - None => continue, - }; - let parent_desc = get_description(schema, &parent_path).unwrap_or_default(); - let (best_idx, all_types) = select_best(items); - let mut selected = items[best_idx].clone(); - if !parent_desc.is_empty() { - merge_description(&mut selected, &parent_desc); - } - if all_types.len() > 1 { - append_hint_raw( - &mut selected, - &format!("Accepts: {}", all_types.join(" | ")), - ); - } - let _ = set_value_at_path(schema, &parent_path, selected); - } - } -} - -fn select_best(items: &[Value]) -> (usize, Vec) { - let mut best_idx = 0; - let mut best_score = -1; - let mut types = Vec::new(); - for (idx, item) in items.iter().enumerate() { - let mut score = 0; - let mut typ = item.get("type").and_then(Value::as_str).unwrap_or("").to_string(); - if typ == "object" || item.get("properties").is_some() { - score = 3; - if typ.is_empty() { - typ = "object".to_string(); - } - } else if typ == "array" || item.get("items").is_some() { - score = 2; - if typ.is_empty() { - typ = "array".to_string(); - } - } else if !typ.is_empty() && typ != "null" { - score = 1; - } else if typ.is_empty() { - typ = "null".to_string(); - } - if !typ.is_empty() { - types.push(typ.clone()); - } - if score > best_score { - best_score = score; - best_idx = idx; - } - } - (best_idx, types) -} - -fn flatten_type_arrays(schema: &mut Value) { - let mut paths = collect_paths(schema, "type"); - sort_by_depth(&mut paths); - let mut nullable_fields: HashMap> = HashMap::new(); - - for path in paths { - let Some(info) = parse_type_array(schema, &path) else { - continue; - }; - apply_type_array(schema, &path, &info, &mut nullable_fields); - } - - apply_nullable_fields(schema, nullable_fields); -} - -struct TypeArrayInfo { - first: String, - non_null: Vec, - has_null: bool, -} - -fn parse_type_array(schema: &Value, path: &Path) -> Option { - let items = get_value(schema, path)?.as_array()?.clone(); - if items.is_empty() { - return None; - } - let mut has_null = false; - let mut non_null = Vec::new(); - for item in items.iter() { - let text = 
value_to_string(item); - if text == "null" { - has_null = true; - } else if !text.is_empty() { - non_null.push(text); - } - } - let first = non_null - .first() - .cloned() - .unwrap_or_else(|| "string".to_string()); - Some(TypeArrayInfo { - first, - non_null, - has_null, - }) -} - -fn apply_type_array( - schema: &mut Value, - path: &Path, - info: &TypeArrayInfo, - nullable_fields: &mut HashMap>, -) { - if let Some(value) = get_value_mut(schema, path) { - *value = Value::String(info.first.clone()); - } - let Some(parent_path) = parent_path(path) else { - return; - }; - if info.non_null.len() > 1 { - append_hint( - schema, - &parent_path, - &format!("Accepts: {}", info.non_null.join(" | ")), - ); - } - if !info.has_null { - return; - } - let Some((object_path, field_name)) = property_field_from_type_path(path) else { - return; - }; - let mut prop_path = object_path.clone(); - prop_path.push(PathSegment::Key("properties".to_string())); - prop_path.push(PathSegment::Key(field_name.clone())); - append_hint(schema, &prop_path, "(nullable)"); - nullable_fields - .entry(object_path) - .or_default() - .push(field_name); -} - -fn apply_nullable_fields(schema: &mut Value, nullable_fields: HashMap>) { - for (object_path, fields) in nullable_fields { - let mut req_path = object_path.clone(); - req_path.push(PathSegment::Key("required".to_string())); - let Some(Value::Array(required)) = get_value_mut(schema, &req_path) else { - continue; - }; - let filtered = required - .iter() - .filter_map(|item| item.as_str()) - .filter(|name| !fields.iter().any(|field| field == name)) - .map(|value| Value::String(value.to_string())) - .collect::>(); - if filtered.is_empty() { - let _ = delete_at_path(schema, &req_path); - } else { - *required = filtered; - } - } -} - -fn remove_unsupported_keywords(schema: &mut Value) { - let mut keywords = Vec::from(UNSUPPORTED_CONSTRAINTS); - keywords.extend([ - "$schema", - "$defs", - "definitions", - "const", - "$ref", - "additionalProperties", - "propertyNames", - ]); - for key in keywords { - let mut paths = collect_paths(schema, key); - sort_by_depth(&mut paths); - for path in paths { - let parent_path = match parent_path(&path) { - Some(parent) => parent, - None => continue, - }; - if is_property_definition(&parent_path) { - continue; - } - let _ = delete_at_path(schema, &path); - } - } -} - -fn cleanup_required_fields(schema: &mut Value) { - let mut paths = collect_paths(schema, "required"); - sort_by_depth(&mut paths); - for path in paths { - let parent_path = match parent_path(&path) { - Some(parent) => parent, - None => continue, - }; - let props_path = { - let mut next = parent_path.clone(); - next.push(PathSegment::Key("properties".to_string())); - next - }; - let Some(Value::Array(required)) = get_value(schema, &path) else { - continue; - }; - let Some(Value::Object(props)) = get_value(schema, &props_path) else { - continue; - }; - let valid = required - .iter() - .filter_map(|item| item.as_str()) - .filter(|key| props.contains_key(*key)) - .map(|value| Value::String(value.to_string())) - .collect::>(); - if valid.len() == required.len() { - continue; - } - if valid.is_empty() { - let _ = delete_at_path(schema, &path); - } else { - let _ = set_value_at_path(schema, &path, Value::Array(valid)); - } - } -} - -fn add_empty_schema_placeholder(schema: &mut Value) { - let mut paths = collect_paths(schema, "type"); - sort_by_depth(&mut paths); - for path in paths { - let Some(Value::String(value)) = get_value(schema, &path) else { - continue; - }; - if value != "object" { - 
continue; - } - let parent_path = match parent_path(&path) { - Some(parent) => parent, - None => continue, - }; - apply_schema_placeholder(schema, &parent_path); - } -} - -fn apply_schema_placeholder(schema: &mut Value, parent_path: &Path) { - let Some(parent) = get_object_mut(schema, parent_path) else { - return; - }; - let props = parent.get("properties"); - let req = parent.get("required"); - let has_required = req - .and_then(Value::as_array) - .map(|items| !items.is_empty()) - .unwrap_or(false); - let needs_placeholder = match props { - None => true, - Some(Value::Object(map)) => map.is_empty(), - _ => false, - }; - if needs_placeholder { - add_reason_placeholder(parent); - return; - } - if !has_required { - if parent_path.is_empty() { - return; - } - add_required_placeholder(parent); - } -} - -fn add_reason_placeholder(parent: &mut Map) { - let props = parent - .entry("properties".to_string()) - .or_insert_with(|| Value::Object(Map::new())); - let Some(props) = props.as_object_mut() else { - return; - }; - let reason = props - .entry("reason".to_string()) - .or_insert_with(|| Value::Object(Map::new())); - if let Some(reason) = reason.as_object_mut() { - reason.insert("type".to_string(), Value::String("string".to_string())); - reason.insert( - "description".to_string(), - Value::String("Brief explanation of why you are calling this tool".to_string()), - ); - } - parent.insert( - "required".to_string(), - Value::Array(vec![Value::String("reason".to_string())]), - ); -} - -fn add_required_placeholder(parent: &mut Map) { - let props = parent - .entry("properties".to_string()) - .or_insert_with(|| Value::Object(Map::new())); - let Some(props) = props.as_object_mut() else { - return; - }; - if !props.contains_key("_") { - let mut placeholder = Map::new(); - placeholder.insert("type".to_string(), Value::String("boolean".to_string())); - props.insert("_".to_string(), Value::Object(placeholder)); - } - parent.insert( - "required".to_string(), - Value::Array(vec![Value::String("_".to_string())]), - ); -} - -fn collect_paths(schema: &Value, field: &str) -> Vec { - let mut paths = Vec::new(); - let mut current = Vec::new(); - walk(schema, field, &mut current, &mut paths); - paths -} - -fn walk(value: &Value, field: &str, path: &mut Path, out: &mut Vec) { - match value { - Value::Object(map) => { - for (key, val) in map { - path.push(PathSegment::Key(key.clone())); - if key == field { - out.push(path.clone()); - } - walk(val, field, path, out); - path.pop(); - } - } - Value::Array(items) => { - for (idx, item) in items.iter().enumerate() { - path.push(PathSegment::Index(idx)); - walk(item, field, path, out); - path.pop(); - } - } - _ => {} - } -} - -fn sort_by_depth(paths: &mut Vec) { - paths.sort_by(|a, b| b.len().cmp(&a.len())); -} - -fn parent_path(path: &Path) -> Option { - if path.is_empty() { - return None; - } - let mut parent = path.clone(); - parent.pop(); - Some(parent) -} - -fn get_value<'a>(root: &'a Value, path: &[PathSegment]) -> Option<&'a Value> { - let mut current = root; - for segment in path { - match segment { - PathSegment::Key(key) => { - current = current.get(key)?; - } - PathSegment::Index(index) => { - current = current.get(*index)?; - } - } - } - Some(current) -} - -fn get_value_mut<'a>(root: &'a mut Value, path: &[PathSegment]) -> Option<&'a mut Value> { - let mut current = root; - for segment in path { - match segment { - PathSegment::Key(key) => { - current = current.get_mut(key)?; - } - PathSegment::Index(index) => { - current = current.get_mut(*index)?; - } - } - } 
- Some(current) -} - -fn set_value_at_path(root: &mut Value, path: &[PathSegment], value: Value) -> bool { - if path.is_empty() { - *root = value; - return true; - } - let (parent, last) = match split_parent(path) { - Some(split) => split, - None => return false, - }; - let Some(parent) = get_value_mut(root, parent) else { - return false; - }; - match last { - PathSegment::Key(key) => { - let Some(obj) = parent.as_object_mut() else { - return false; - }; - obj.insert(key.clone(), value); - true - } - PathSegment::Index(index) => { - let Some(arr) = parent.as_array_mut() else { - return false; - }; - if *index >= arr.len() { - return false; - } - arr[*index] = value; - true - } - } -} - -fn delete_at_path(root: &mut Value, path: &[PathSegment]) -> bool { - let (parent, last) = match split_parent(path) { - Some(split) => split, - None => return false, - }; - let Some(parent) = get_value_mut(root, parent) else { - return false; - }; - match last { - PathSegment::Key(key) => { - let Some(obj) = parent.as_object_mut() else { - return false; - }; - obj.remove(key).is_some() - } - PathSegment::Index(index) => { - let Some(arr) = parent.as_array_mut() else { - return false; - }; - if *index >= arr.len() { - return false; - } - arr.remove(*index); - true - } - } -} - -fn split_parent(path: &[PathSegment]) -> Option<(&[PathSegment], &PathSegment)> { - let len = path.len(); - if len == 0 { - return None; - } - Some((&path[..len - 1], &path[len - 1])) -} - -fn get_object_mut<'a>(root: &'a mut Value, path: &[PathSegment]) -> Option<&'a mut Map> { - get_value_mut(root, path)?.as_object_mut() -} - -fn append_hint(root: &mut Value, path: &[PathSegment], hint: &str) { - let Some(obj) = get_object_mut(root, path) else { - return; - }; - let existing = obj - .get("description") - .and_then(Value::as_str) - .unwrap_or(""); - let next = if existing.is_empty() { - hint.to_string() - } else { - format!("{existing} ({hint})") - }; - obj.insert("description".to_string(), Value::String(next)); -} - -fn append_hint_raw(schema: &mut Value, hint: &str) { - let Some(obj) = schema.as_object_mut() else { - return; - }; - let existing = obj - .get("description") - .and_then(Value::as_str) - .unwrap_or(""); - let next = if existing.is_empty() { - hint.to_string() - } else { - format!("{existing} ({hint})") - }; - obj.insert("description".to_string(), Value::String(next)); -} - -fn merge_description(schema: &mut Value, parent_desc: &str) { - let Some(obj) = schema.as_object_mut() else { - return; - }; - let child_desc = obj - .get("description") - .and_then(Value::as_str) - .unwrap_or(""); - if child_desc.is_empty() { - obj.insert("description".to_string(), Value::String(parent_desc.to_string())); - return; - } - if child_desc == parent_desc { - return; - } - obj.insert( - "description".to_string(), - Value::String(format!("{parent_desc} ({child_desc})")), - ); -} - -fn is_property_definition(path: &[PathSegment]) -> bool { - match path.last() { - Some(PathSegment::Key(key)) if key == "properties" => true, - _ => path.len() == 1 && matches!(path[0], PathSegment::Key(ref key) if key == "properties"), - } -} - -fn get_description(root: &Value, path: &[PathSegment]) -> Option { - let obj = get_value(root, path)?.as_object()?; - let desc = obj.get("description")?.as_str()?.trim().to_string(); - Some(desc) -} - -fn value_to_string(value: &Value) -> String { - match value { - Value::String(value) => value.clone(), - Value::Number(value) => value.to_string(), - Value::Bool(value) => value.to_string(), - Value::Null => 
"null".to_string(), - other => other.to_string(), - } -} - -fn property_field_from_type_path(path: &[PathSegment]) -> Option<(Path, String)> { - if path.len() < 3 { - return None; - } - let len = path.len(); - if !matches!(path.get(len - 3), Some(PathSegment::Key(key)) if key == "properties") { - return None; - } - let field = match path.get(len - 2) { - Some(PathSegment::Key(key)) => key.clone(), - _ => return None, - }; - Some((path[..len - 3].to_vec(), field)) -} diff --git a/src-tauri/src/proxy/codex_compat.rs b/src-tauri/src/proxy/codex_compat.rs deleted file mode 100644 index 1fccdd7..0000000 --- a/src-tauri/src/proxy/codex_compat.rs +++ /dev/null @@ -1,38 +0,0 @@ -use axum::body::Bytes; -use axum::http::HeaderMap; - -mod headers; -mod request; -mod response; -mod stream; -mod tool_names; - -pub(crate) use headers::apply_codex_headers; -pub(crate) use request::{chat_request_to_codex, responses_request_to_codex}; -pub(crate) use response::{codex_response_to_chat, codex_response_to_responses}; -pub(crate) use stream::{stream_codex_to_chat, stream_codex_to_responses}; - -pub(crate) fn extract_tool_name_map_from_request_body( - body: Option<&str>, -) -> std::collections::HashMap { - let Some(body) = body else { - return std::collections::HashMap::new(); - }; - let bytes = Bytes::copy_from_slice(body.as_bytes()); - request::extract_tool_name_map(&bytes).unwrap_or_default() -} - -pub(crate) fn apply_codex_headers_if_needed( - provider: &str, - headers: &mut HeaderMap, - inbound: &HeaderMap, -) { - if provider != "codex" { - return; - } - apply_codex_headers(headers, inbound); -} - -#[cfg(test)] -#[path = "codex_compat.test.rs"] -mod tests; diff --git a/src-tauri/src/proxy/codex_compat.test.rs b/src-tauri/src/proxy/codex_compat.test.rs deleted file mode 100644 index 9d05f09..0000000 --- a/src-tauri/src/proxy/codex_compat.test.rs +++ /dev/null @@ -1,122 +0,0 @@ -use axum::body::Bytes; -use serde_json::json; - -use super::{chat_request_to_codex, codex_response_to_chat, responses_request_to_codex}; -use super::tool_names::shorten_name_if_needed; - -#[test] -fn chat_request_to_codex_sets_model_and_stream() { - let input = json!({ - "model": "gpt-5", - "stream": true, - "messages": [ - { "role": "user", "content": "hi" } - ] - }); - let bytes = Bytes::from(input.to_string()); - let output = chat_request_to_codex(&bytes, Some("gpt-5-codex")).expect("convert"); - let value: serde_json::Value = serde_json::from_slice(&output).expect("json"); - assert_eq!(value["model"], "gpt-5-codex"); - assert_eq!(value["stream"], true); - assert_eq!(value["input"][0]["type"], "message"); -} - -#[test] -fn codex_response_to_chat_restores_tool_name() { - let original = "mcp__very_long_tool_name_for_codex_restoration_check_v1_tool_extra_long_suffix"; - let short = shorten_name_if_needed(original); - assert!(short.len() <= 64); - assert_ne!(short, original); - - let request_body = json!({ - "tools": [ - { "type": "function", "function": { "name": original } } - ] - }) - .to_string(); - - let response = json!({ - "type": "response.completed", - "response": { - "id": "resp_1", - "created_at": 123, - "model": "gpt-5", - "status": "completed", - "output": [ - { "type": "function_call", "call_id": "call_1", "name": short, "arguments": "{}" } - ], - "usage": { "input_tokens": 1, "output_tokens": 2, "total_tokens": 3 } - } - }); - let bytes = Bytes::from(response.to_string()); - let output = codex_response_to_chat(&bytes, Some(&request_body)).expect("convert"); - let value: serde_json::Value = 
serde_json::from_slice(&output).expect("json"); - let name = value["choices"][0]["message"]["tool_calls"][0]["function"]["name"] - .as_str() - .expect("tool name"); - assert_eq!(name, original); -} - -#[test] -fn chat_request_to_codex_skips_missing_tool_names() { - let input = json!({ - "model": "gpt-5", - "messages": [ - { "role": "user", "content": "hi" } - ], - "tools": [ - { "type": "function", "function": { "description": "noop", "parameters": {} } } - ], - "tool_choice": { "type": "function", "function": {} } - }); - let bytes = Bytes::from(input.to_string()); - let output = chat_request_to_codex(&bytes, Some("gpt-5-codex")).expect("convert"); - let value: serde_json::Value = serde_json::from_slice(&output).expect("json"); - let tools = value["tools"].as_array().expect("tools array"); - assert_eq!(tools.len(), 1); - assert!(tools[0].get("name").is_none()); - let tool_choice = value["tool_choice"].as_object().expect("tool_choice"); - assert_eq!(tool_choice.get("type").and_then(serde_json::Value::as_str), Some("function")); - assert!(tool_choice.get("name").is_none()); -} - -#[test] -fn responses_request_to_codex_uses_top_level_tool_name() { - let input = json!({ - "model": "gpt-5", - "input": "hi", - "tools": [ - { "type": "function", "name": "demo_tool", "description": "noop", "parameters": {} } - ], - "tool_choice": { "type": "function", "name": "demo_tool" } - }); - let bytes = Bytes::from(input.to_string()); - let output = responses_request_to_codex(&bytes, Some("gpt-5-codex")).expect("convert"); - let value: serde_json::Value = serde_json::from_slice(&output).expect("json"); - let tools = value["tools"].as_array().expect("tools array"); - assert_eq!(tools.len(), 1); - assert_eq!(tools[0]["name"], "demo_tool"); - assert_eq!(tools[0]["description"], "noop"); - assert!(tools[0]["parameters"].is_object()); - assert_eq!( - value["tool_choice"].get("name").and_then(serde_json::Value::as_str), - Some("demo_tool") - ); -} - -#[test] -fn responses_request_to_codex_strips_prompt_cache_retention() { - let input = json!({ - "model": "gpt-5", - "input": "hi", - "prompt_cache_retention": "24h", - "previous_response_id": "resp_123", - "safety_identifier": "sid_1" - }); - let bytes = Bytes::from(input.to_string()); - let output = responses_request_to_codex(&bytes, Some("gpt-5-codex")).expect("convert"); - let value: serde_json::Value = serde_json::from_slice(&output).expect("json"); - assert!(value.get("prompt_cache_retention").is_none()); - assert!(value.get("previous_response_id").is_none()); - assert!(value.get("safety_identifier").is_none()); -} diff --git a/src-tauri/src/proxy/codex_compat/headers.rs b/src-tauri/src/proxy/codex_compat/headers.rs deleted file mode 100644 index ab1db4a..0000000 --- a/src-tauri/src/proxy/codex_compat/headers.rs +++ /dev/null @@ -1,46 +0,0 @@ -use axum::http::header::{HeaderName, HeaderValue}; -use axum::http::HeaderMap; -const HEADER_VERSION: &str = "0.21.0"; -const HEADER_OPENAI_BETA: &str = "responses=experimental"; -const DEFAULT_USER_AGENT: &str = "codex_cli_rs/0.50.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"; - -const HEADER_VERSION_NAME: HeaderName = HeaderName::from_static("version"); -const HEADER_OPENAI_BETA_NAME: HeaderName = HeaderName::from_static("openai-beta"); -const HEADER_SESSION_ID_NAME: HeaderName = HeaderName::from_static("session_id"); -const HEADER_USER_AGENT_NAME: HeaderName = HeaderName::from_static("user-agent"); -const HEADER_ACCEPT_NAME: HeaderName = HeaderName::from_static("accept"); -const HEADER_CONNECTION_NAME: HeaderName = 
HeaderName::from_static("connection");
-const HEADER_ORIGINATOR_NAME: HeaderName = HeaderName::from_static("originator");
-const HEADER_ORIGINATOR: &str = "codex_cli_rs";
-
-pub(crate) fn apply_codex_headers(headers: &mut HeaderMap, inbound: &HeaderMap) {
-    ensure_header(headers, inbound, &HEADER_VERSION_NAME, HEADER_VERSION);
-    ensure_header(headers, inbound, &HEADER_OPENAI_BETA_NAME, HEADER_OPENAI_BETA);
-    if !headers.contains_key(&HEADER_SESSION_ID_NAME) {
-        if let Ok(value) = HeaderValue::from_str(&generate_session_id()) {
-            headers.insert(HEADER_SESSION_ID_NAME, value);
-        }
-    }
-    ensure_header(headers, inbound, &HEADER_USER_AGENT_NAME, DEFAULT_USER_AGENT);
-    ensure_header(headers, inbound, &HEADER_ORIGINATOR_NAME, HEADER_ORIGINATOR);
-    ensure_header(headers, inbound, &HEADER_ACCEPT_NAME, "text/event-stream");
-    ensure_header(headers, inbound, &HEADER_CONNECTION_NAME, "Keep-Alive");
-}
-
-fn ensure_header(headers: &mut HeaderMap, inbound: &HeaderMap, name: &HeaderName, fallback: &str) {
-    if headers.contains_key(name) {
-        return;
-    }
-    if let Some(value) = inbound.get(name) {
-        headers.insert(name.clone(), value.clone());
-        return;
-    }
-    if let Ok(value) = HeaderValue::from_str(fallback) {
-        headers.insert(name.clone(), value);
-    }
-}
-
-fn generate_session_id() -> String {
-    let bytes: [u8; 16] = rand::random();
-    bytes.iter().map(|b| format!("{b:02x}")).collect()
-}
diff --git a/src-tauri/src/proxy/codex_compat/request.rs b/src-tauri/src/proxy/codex_compat/request.rs
deleted file mode 100644
index 689e93e..0000000
--- a/src-tauri/src/proxy/codex_compat/request.rs
+++ /dev/null
@@ -1,447 +0,0 @@
-use axum::body::Bytes;
-use serde_json::{json, Map, Value};
-use std::collections::HashMap;
-
-use super::tool_names::ToolNameMap;
-
-pub(crate) fn extract_tool_name_map(body: &Bytes) -> Option<HashMap<String, String>> {
-    let value: Value = serde_json::from_slice(body).ok()?;
-    let object = value.as_object()?;
-    let tools = object.get("tools")?;
-    let names = collect_function_tool_names(tools);
-    if names.is_empty() {
-        return None;
-    }
-    let map = ToolNameMap::from_names(&names);
-    Some(map.original_by_short)
-}
-
-pub(crate) fn chat_request_to_codex(
-    body: &Bytes,
-    model_hint: Option<&str>,
-) -> Result<Bytes, String> {
-    let object = parse_object(body)?;
-    let model = resolve_model(&object, model_hint);
-    let stream = object.get("stream").and_then(Value::as_bool).unwrap_or(false);
-    let effort = resolve_reasoning_effort(&object, Some(&model));
-    let tool_map = build_tool_name_map(&object);
-    let messages = object
-        .get("messages")
-        .and_then(Value::as_array)
-        .ok_or_else(|| "Chat request must include messages.".to_string())?;
-
-    let mut output = Map::new();
-    output.insert("stream".to_string(), Value::Bool(stream));
-    output.insert("model".to_string(), Value::String(model));
-    output.insert("instructions".to_string(), Value::String(String::new()));
-    output.insert("parallel_tool_calls".to_string(), Value::Bool(true));
-    output.insert("include".to_string(), json!(["reasoning.encrypted_content"]));
-    output.insert(
-        "reasoning".to_string(),
-        json!({ "effort": effort, "summary": "auto" }),
-    );
-
-    let input = map_chat_messages_to_input(messages, &tool_map);
-    output.insert("input".to_string(), Value::Array(input));
-
-    if let Some(tools) = object.get("tools") {
-        output.insert("tools".to_string(), map_tools(tools, &tool_map));
-    }
-    if let Some(tool_choice) = object.get("tool_choice") {
-        output.insert(
-            "tool_choice".to_string(),
-            map_tool_choice(tool_choice, &tool_map),
-        );
-    }
-
-    apply_text_format(object.get("response_format"), object.get("text"), &mut output);
-
-    output.insert("store".to_string(), Value::Bool(false));
-
-    serde_json::to_vec(&Value::Object(output))
-        .map(Bytes::from)
-        .map_err(|err| format!("Failed to serialize request: {err}"))
-}
-
-pub(crate) fn responses_request_to_codex(
-    body: &Bytes,
-    model_hint: Option<&str>,
-) -> Result<Bytes, String> {
-    let mut object = parse_object(body)?;
-    normalize_responses_payload(&mut object, model_hint);
-    let tool_map = build_tool_name_map(&object);
-
-    if let Some(tools) = object.get("tools").cloned() {
-        object.insert("tools".to_string(), map_tools(&tools, &tool_map));
-    }
-    if let Some(tool_choice) = object.get("tool_choice").cloned() {
-        object.insert(
-            "tool_choice".to_string(),
-            map_tool_choice(&tool_choice, &tool_map),
-        );
-    }
-    if let Some(input) = object.get_mut("input") {
-        rewrite_input_function_names(input, &tool_map);
-    }
-
-    serde_json::to_vec(&Value::Object(object))
-        .map(Bytes::from)
-        .map_err(|err| format!("Failed to serialize request: {err}"))
-}
-
-fn parse_object(body: &Bytes) -> Result<Map<String, Value>, String> {
-    let value: Value = serde_json::from_slice(body)
-        .map_err(|_| "Request body must be JSON.".to_string())?;
-    value
-        .as_object()
-        .cloned()
-        .ok_or_else(|| "Request body must be a JSON object.".to_string())
-}
-
-fn resolve_model(object: &Map<String, Value>, model_hint: Option<&str>) -> String {
-    if let Some(model) = object.get("model").and_then(Value::as_str) {
-        return model_hint.unwrap_or(model).to_string();
-    }
-    model_hint.unwrap_or_default().to_string()
-}
-
-fn resolve_reasoning_effort(object: &Map<String, Value>, model: Option<&str>) -> String {
-    if let Some(value) = object.get("reasoning_effort").and_then(Value::as_str) {
-        return value.to_string();
-    }
-    if let Some(model) = object.get("model").and_then(Value::as_str) {
-        if let Some(effort) = parse_effort_suffix(model) {
-            return effort;
-        }
-    }
-    if let Some(model) = model {
-        if let Some(effort) = parse_effort_suffix(model) {
-            return effort;
-        }
-    }
-    "medium".to_string()
-}
-
-fn parse_effort_suffix(model: &str) -> Option<String> {
-    let (base, effort) = model.rsplit_once("-reasoning-")?;
-    if base.trim().is_empty() {
-        return None;
-    }
-    let effort = effort.trim().to_ascii_lowercase();
-    if effort.is_empty() {
-        return None;
-    }
-    Some(effort)
-}
-
-fn build_tool_name_map(object: &Map<String, Value>) -> ToolNameMap {
-    let names = object
-        .get("tools")
-        .map(collect_function_tool_names)
-        .unwrap_or_default();
-    ToolNameMap::from_names(&names)
-}
-
-fn collect_function_tool_names(value: &Value) -> Vec<String> {
-    let mut names = Vec::new();
-    let Some(items) = value.as_array() else {
-        return names;
-    };
-    for tool in items {
-        if tool.get("type").and_then(Value::as_str) != Some("function") {
-            continue;
-        }
-        let name = tool
-            .get("function")
-            .and_then(|value| value.get("name"))
-            .and_then(Value::as_str)
-            .or_else(|| tool.get("name").and_then(Value::as_str));
-        if let Some(name) = name {
-            if !name.is_empty() {
-                names.push(name.to_string());
-            }
-        }
-    }
-    names
-}
-
-fn map_chat_messages_to_input(messages: &[Value], tool_map: &ToolNameMap) -> Vec<Value> {
-    let mut input = Vec::new();
-    for message in messages {
-        let Some(role) = message.get("role").and_then(Value::as_str) else {
-            continue;
-        };
-        if role == "tool" {
-            if let Some(item) = map_tool_message(message) {
-                input.push(item);
-            }
-            continue;
-        }
-        if let Some(item) = map_regular_message(message, role) {
-            input.push(item);
-        }
-        if role == "assistant" {
-            map_tool_calls(message, tool_map, &mut input);
-        }
-    }
-    input
-}
-
-fn map_tool_message(message: &Value) -> Option<Value> {
-    let call_id = message.get("tool_call_id").and_then(Value::as_str)?;
-    let empty = Value::String(String::new());
-    let content = message.get("content").unwrap_or(&empty);
-    Some(json!({
-        "type": "function_call_output",
-        "call_id": call_id,
-        "output": value_to_string(content),
-    }))
-}
-
-fn map_regular_message(message: &Value, role: &str) -> Option<Value> {
-    let content = message.get("content")?;
-    let parts = map_message_content(role, content);
-    let target_role = if role == "system" { "developer" } else { role };
-    Some(json!({
-        "type": "message",
-        "role": target_role,
-        "content": parts,
-    }))
-}
-
-fn map_message_content(role: &str, content: &Value) -> Vec<Value> {
-    let mut parts = Vec::new();
-    match content {
-        Value::String(text) => {
-            push_text_part(&mut parts, role, text);
-        }
-        Value::Array(items) => {
-            for item in items {
-                if let Some(text) = item.get("text").and_then(Value::as_str) {
-                    push_text_part(&mut parts, role, text);
-                    continue;
-                }
-                if item.get("type").and_then(Value::as_str) == Some("image_url")
-                    && role == "user"
-                {
-                    if let Some(url) = item.get("image_url").and_then(|value| value.get("url")).and_then(Value::as_str) {
-                        parts.push(json!({ "type": "input_image", "image_url": url }));
-                    }
-                }
-                if let Some(text) = item.as_str() {
-                    push_text_part(&mut parts, role, text);
-                }
-            }
-        }
-        _ => {}
-    }
-    parts
-}
-
-fn push_text_part(parts: &mut Vec<Value>, role: &str, text: &str) {
-    let part_type = if role == "assistant" { "output_text" } else { "input_text" };
-    parts.push(json!({ "type": part_type, "text": text }));
-}
-
-fn map_tool_calls(message: &Value, tool_map: &ToolNameMap, input: &mut Vec<Value>) {
-    let Some(tool_calls) = message.get("tool_calls").and_then(Value::as_array) else {
-        return;
-    };
-    for call in tool_calls {
-        if call.get("type").and_then(Value::as_str) != Some("function") {
-            continue;
-        }
-        let call_id = call.get("id").and_then(Value::as_str).unwrap_or_default();
-        let name = call
-            .get("function")
-            .and_then(|value| value.get("name"))
-            .and_then(Value::as_str)
-            .unwrap_or_default();
-        let arguments = call
-            .get("function")
-            .and_then(|value| value.get("arguments"))
-            .and_then(Value::as_str)
-            .unwrap_or_default();
-        input.push(json!({
-            "type": "function_call",
-            "call_id": call_id,
-            "name": tool_map.shorten(name),
-            "arguments": arguments,
-        }));
-    }
-}
-
-fn map_tools(tools: &Value, tool_map: &ToolNameMap) -> Value {
-    let Some(items) = tools.as_array() else {
-        return Value::Array(Vec::new());
-    };
-    let mut output = Vec::new();
-    for tool in items {
-        let tool_type = tool.get("type").and_then(Value::as_str).unwrap_or_default();
-        if tool_type != "function" {
-            if tool.is_object() {
-                output.push(tool.clone());
-            }
-            continue;
-        }
-        let function = tool.get("function").unwrap_or(&Value::Null);
-        let mut item = Map::new();
-        item.insert("type".to_string(), Value::String("function".to_string()));
-        let name = function
-            .get("name")
-            .and_then(Value::as_str)
-            .or_else(|| tool.get("name").and_then(Value::as_str));
-        if let Some(name) = name {
-            if !name.is_empty() {
-                item.insert("name".to_string(), Value::String(tool_map.shorten(name)));
-            }
-        }
-        if let Some(desc) = function.get("description").or_else(|| tool.get("description")) {
-            item.insert("description".to_string(), desc.clone());
-        }
-        if let Some(params) = function.get("parameters").or_else(|| tool.get("parameters")) {
-            item.insert("parameters".to_string(), params.clone());
-        }
-        if let Some(strict) = function.get("strict").or_else(|| tool.get("strict")) {
-            item.insert("strict".to_string(), strict.clone());
-        }
-        output.push(Value::Object(item));
-    }
-    Value::Array(output)
-}
-
-fn map_tool_choice(choice: &Value, tool_map: &ToolNameMap) -> Value {
-    if let Some(value) = choice.as_str() {
-        return Value::String(value.to_string());
-    }
-    let Some(object) = choice.as_object() else {
-        return choice.clone();
-    };
-    let Some(choice_type) = object.get("type").and_then(Value::as_str) else {
-        return choice.clone();
-    };
-    if choice_type != "function" {
-        return choice.clone();
-    }
-    let name = object
-        .get("function")
-        .and_then(|value| value.get("name"))
-        .and_then(Value::as_str)
-        .or_else(|| object.get("name").and_then(Value::as_str));
-    let mut output = Map::new();
-    output.insert("type".to_string(), Value::String("function".to_string()));
-    if let Some(name) = name {
-        if !name.is_empty() {
-            output.insert("name".to_string(), Value::String(tool_map.shorten(name)));
-        }
-    }
-    Value::Object(output)
-}
-
-fn apply_text_format(response_format: Option<&Value>, text: Option<&Value>, output: &mut Map<String, Value>) {
-    if let Some(rf) = response_format {
-        let rf_type = rf.get("type").and_then(Value::as_str).unwrap_or_default();
-        let mut text_obj = Map::new();
-        match rf_type {
-            "text" => {
-                text_obj.insert("format".to_string(), json!({ "type": "text" }));
-            }
-            "json_schema" => {
-                let mut format_obj = Map::new();
-                format_obj.insert("type".to_string(), Value::String("json_schema".to_string()));
-                if let Some(schema) = rf.get("json_schema") {
-                    if let Some(name) = schema.get("name") {
-                        format_obj.insert("name".to_string(), name.clone());
-                    }
-                    if let Some(strict) = schema.get("strict") {
-                        format_obj.insert("strict".to_string(), strict.clone());
-                    }
-                    if let Some(schema_value) = schema.get("schema") {
-                        format_obj.insert("schema".to_string(), schema_value.clone());
-                    }
-                }
-                text_obj.insert("format".to_string(), Value::Object(format_obj));
-            }
-            _ => {}
-        }
-        output.insert("text".to_string(), Value::Object(text_obj));
-    }
-
-    if let Some(text) = text {
-        if let Some(verbosity) = text.get("verbosity") {
-            let entry = output.entry("text".to_string()).or_insert_with(|| json!({}));
-            if let Value::Object(obj) = entry {
-                obj.insert("verbosity".to_string(), verbosity.clone());
-            }
-        }
-    }
-}
-
-fn normalize_responses_payload(object: &mut Map<String, Value>, model_hint: Option<&str>) {
-    let model = object
-        .get("model")
-        .and_then(Value::as_str)
-        .or(model_hint)
-        .unwrap_or_default();
-    object.insert("model".to_string(), Value::String(model.to_string()));
-    object.insert("stream".to_string(), Value::Bool(true));
-    object.insert("store".to_string(), Value::Bool(false));
-    object.insert("parallel_tool_calls".to_string(), Value::Bool(true));
-    object.insert(
-        "include".to_string(),
-        json!(["reasoning.encrypted_content"]),
-    );
-    for key in [
-        "max_output_tokens",
-        "max_completion_tokens",
-        "temperature",
-        "top_p",
-        "service_tier",
-        "previous_response_id",
-        "prompt_cache_retention",
-        "safety_identifier",
-    ] {
-        object.remove(key);
-    }
-
-    if !object.contains_key("instructions") {
-        object.insert("instructions".to_string(), Value::String(String::new()));
-    }
-
-    let input = match object.get("input") {
-        Some(Value::String(text)) => vec![json!({
-            "type": "message",
-            "role": "user",
-            "content": [json!({"type":"input_text","text": text})]
-        })],
-        Some(Value::Array(items)) => items.clone(),
-        _ => Vec::new(),
-    };
-    object.insert("input".to_string(), Value::Array(input));
-}
-
-fn rewrite_input_function_names(input: &mut Value, tool_map: &ToolNameMap) {
-    let Some(items) = input.as_array_mut() else {
-        return;
-    };
-    for item in items {
-        let Some(item_type) = item.get("type").and_then(Value::as_str) else {
-            continue;
-        };
-        if item_type != "function_call" {
-            continue;
-        }
-        if let Some(name) = item.get("name").and_then(Value::as_str) {
-            let short = tool_map.shorten(name);
-            if let Some(object) = item.as_object_mut() {
-                object.insert("name".to_string(), Value::String(short));
-            }
-        }
-    }
-}
-
-fn value_to_string(value: &Value) -> String {
-    if let Some(text) = value.as_str() {
-        return text.to_string();
-    }
-    value.to_string()
-}
diff --git a/src-tauri/src/proxy/codex_compat/response.rs b/src-tauri/src/proxy/codex_compat/response.rs
deleted file mode 100644
index fd2e517..0000000
--- a/src-tauri/src/proxy/codex_compat/response.rs
+++ /dev/null
@@ -1,289 +0,0 @@
-use axum::body::Bytes;
-use serde_json::{json, Map, Value};
-use std::collections::HashMap;
-use std::time::{SystemTime, UNIX_EPOCH};
-
-use super::extract_tool_name_map_from_request_body;
-
-pub(crate) fn codex_response_to_chat(
-    bytes: &Bytes,
-    request_body: Option<&str>,
-) -> Result<Bytes, String> {
-    let value: Value = serde_json::from_slice(bytes)
-        .map_err(|_| "Response body must be JSON.".to_string())?;
-    let Some(response) = extract_response_object(&value) else {
-        return Err("Codex response missing response object.".to_string());
-    };
-    let tool_name_map = extract_tool_name_map_from_request_body(request_body);
-    let output = build_chat_completion_value(&response, &tool_name_map);
-
-    serde_json::to_vec(&output)
-        .map(Bytes::from)
-        .map_err(|err| format!("Failed to serialize response: {err}"))
-}
-
-pub(crate) fn codex_response_to_responses(
-    bytes: &Bytes,
-    request_body: Option<&str>,
-) -> Result<Bytes, String> {
-    let value: Value = serde_json::from_slice(bytes)
-        .map_err(|_| "Response body must be JSON.".to_string())?;
-    let Some(mut response) = extract_response_object(&value) else {
-        return Err("Codex response missing response object.".to_string());
-    };
-    let tool_name_map = extract_tool_name_map_from_request_body(request_body);
-    restore_tool_names_in_response(&mut response, &tool_name_map);
-
-    serde_json::to_vec(&Value::Object(response))
-        .map(Bytes::from)
-        .map_err(|err| format!("Failed to serialize response: {err}"))
-}
-
-fn extract_response_object(value: &Value) -> Option<Map<String, Value>> {
-    if value.get("type").and_then(Value::as_str) == Some("response.completed") {
-        return value.get("response").and_then(Value::as_object).cloned();
-    }
-    if let Some(response) = value.get("response").and_then(Value::as_object) {
-        return Some(response.clone());
-    }
-    value.as_object().cloned()
-}
-
-fn build_chat_completion_value(
-    response: &Map<String, Value>,
-    tool_name_map: &HashMap<String, String>,
-) -> Value {
-    let (content_text, reasoning_text, tool_calls) =
-        extract_response_output(response, tool_name_map);
-    let id = response
-        .get("id")
-        .and_then(Value::as_str)
-        .unwrap_or("chatcmpl_proxy")
-        .to_string();
-    let created = response
-        .get("created_at")
-        .and_then(Value::as_i64)
-        .unwrap_or_else(now_unix_seconds);
-    let model = response
-        .get("model")
-        .and_then(Value::as_str)
-        .unwrap_or("unknown")
-        .to_string();
-    let finish_reason = resolve_finish_reason(response, !tool_calls.is_empty());
-    let message = build_chat_message(&content_text, &reasoning_text, tool_calls);
-
-    let mut output = Map::new();
-    output.insert("id".to_string(), Value::String(id));
-    output.insert(
-        "object".to_string(),
-        Value::String("chat.completion".to_string()),
-    );
-    output.insert("created".to_string(), Value::Number(created.into()));
-    output.insert("model".to_string(), Value::String(model));
-    output.insert(
-        "choices".to_string(),
-        Value::Array(vec![json!({
-            "index": 0,
-            "message": message,
-            "finish_reason": finish_reason
-        })]),
-    );
-    if let Some(usage) = map_usage(response) {
-        output.insert("usage".to_string(), usage);
-    }
-    Value::Object(output)
-}
-
-fn build_chat_message(content: &str, reasoning: &str, tool_calls: Vec<Value>) -> Map<String, Value> {
-    let mut message = Map::new();
-    message.insert("role".to_string(), Value::String("assistant".to_string()));
-    message.insert("content".to_string(), optional_text_value(content));
-    message.insert("reasoning_content".to_string(), optional_text_value(reasoning));
-    if tool_calls.is_empty() {
-        message.insert("tool_calls".to_string(), Value::Null);
-    } else {
-        message.insert("tool_calls".to_string(), Value::Array(tool_calls));
-    }
-    message
-}
-
-fn extract_response_output(
-    response: &Map<String, Value>,
-    tool_name_map: &HashMap<String, String>,
-) -> (String, String, Vec<Value>) {
-    let mut content_text = String::new();
-    let mut reasoning_text = String::new();
-    let mut tool_calls = Vec::new();
-
-    let Some(output) = response.get("output").and_then(Value::as_array) else {
-        return (content_text, reasoning_text, tool_calls);
-    };
-
-    for item in output {
-        let Some(item) = item.as_object() else {
-            continue;
-        };
-        match item.get("type").and_then(Value::as_str) {
-            Some("reasoning") => {
-                if reasoning_text.is_empty() {
-                    reasoning_text = extract_reasoning_summary(item);
-                }
-            }
-            Some("message") => {
-                if content_text.is_empty() {
-                    content_text = extract_output_text(item);
-                }
-            }
-            Some("function_call") => {
-                if let Some(tool_call) = build_tool_call(item, tool_name_map) {
-                    tool_calls.push(tool_call);
-                }
-            }
-            _ => {}
-        }
-    }
-
-    (content_text, reasoning_text, tool_calls)
-}
-
-fn extract_reasoning_summary(item: &Map<String, Value>) -> String {
-    let Some(summary) = item.get("summary").and_then(Value::as_array) else {
-        return String::new();
-    };
-    for part in summary {
-        let Some(part) = part.as_object() else {
-            continue;
-        };
-        if part.get("type").and_then(Value::as_str) != Some("summary_text") {
-            continue;
-        }
-        if let Some(text) = part.get("text").and_then(Value::as_str) {
-            return text.to_string();
-        }
-    }
-    String::new()
-}
-
-fn extract_output_text(item: &Map<String, Value>) -> String {
-    let Some(content) = item.get("content").and_then(Value::as_array) else {
-        return String::new();
-    };
-    for part in content {
-        let Some(part) = part.as_object() else {
-            continue;
-        };
-        if part.get("type").and_then(Value::as_str) != Some("output_text") {
-            continue;
-        }
-        if let Some(text) = part.get("text").and_then(Value::as_str) {
-            return text.to_string();
-        }
-    }
-    String::new()
-}
-
-fn build_tool_call(
-    item: &Map<String, Value>,
-    tool_name_map: &HashMap<String, String>,
-) -> Option<Value> {
-    let call_id = item.get("call_id").and_then(Value::as_str).unwrap_or("");
-    let name = item.get("name").and_then(Value::as_str).unwrap_or("");
-    let arguments = item
-        .get("arguments")
-        .and_then(Value::as_str)
-        .unwrap_or("");
-    let restored_name = tool_name_map.get(name).map(String::as_str).unwrap_or(name);
-
-    Some(json!({
-        "id": call_id,
-        "type": "function",
-        "function": {
-            "name": restored_name,
-            "arguments": arguments
-        }
-    }))
-}
-
-fn restore_tool_names_in_response(
-    response: &mut Map<String, Value>,
-    tool_name_map: &HashMap<String, String>,
-) {
-    if tool_name_map.is_empty() {
-        return;
-    }
-    let Some(output) = response.get_mut("output").and_then(Value::as_array_mut) else {
-        return;
-    };
-    for item in output {
-        let Some(item) = item.as_object_mut() else {
-
continue; - }; - if item.get("type").and_then(Value::as_str) != Some("function_call") { - continue; - } - let Some(name) = item.get("name").and_then(Value::as_str) else { - continue; - }; - let Some(restored) = tool_name_map.get(name) else { - continue; - }; - item.insert("name".to_string(), Value::String(restored.clone())); - } -} - -fn resolve_finish_reason(response: &Map, has_tool_calls: bool) -> Value { - let status = response - .get("status") - .and_then(Value::as_str) - .unwrap_or("completed"); - if status == "completed" { - Value::String(if has_tool_calls { "tool_calls" } else { "stop" }.to_string()) - } else { - Value::Null - } -} - -fn map_usage(response: &Map) -> Option { - let usage = response.get("usage")?; - let input_tokens = usage.get("input_tokens").and_then(Value::as_i64); - let output_tokens = usage.get("output_tokens").and_then(Value::as_i64); - let total_tokens = usage.get("total_tokens").and_then(Value::as_i64); - if input_tokens.is_none() && output_tokens.is_none() && total_tokens.is_none() { - return None; - } - let mut mapped = Map::new(); - if let Some(value) = input_tokens { - mapped.insert("prompt_tokens".to_string(), Value::Number(value.into())); - } - if let Some(value) = output_tokens { - mapped.insert("completion_tokens".to_string(), Value::Number(value.into())); - } - if let Some(value) = total_tokens { - mapped.insert("total_tokens".to_string(), Value::Number(value.into())); - } - if let Some(reasoning) = usage - .get("output_tokens_details") - .and_then(|details| details.get("reasoning_tokens")) - .and_then(Value::as_i64) - { - let mut details = Map::new(); - details.insert("reasoning_tokens".to_string(), Value::Number(reasoning.into())); - mapped.insert("completion_tokens_details".to_string(), Value::Object(details)); - } - Some(Value::Object(mapped)) -} - -fn optional_text_value(value: &str) -> Value { - if value.is_empty() { - Value::Null - } else { - Value::String(value.to_string()) - } -} - -fn now_unix_seconds() -> i64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_secs() as i64 -} diff --git a/src-tauri/src/proxy/codex_compat/stream.rs b/src-tauri/src/proxy/codex_compat/stream.rs deleted file mode 100644 index cabe0cd..0000000 --- a/src-tauri/src/proxy/codex_compat/stream.rs +++ /dev/null @@ -1,491 +0,0 @@ -use axum::body::Bytes; -use futures_util::{stream::try_unfold, StreamExt}; -use serde_json::{json, Map, Value}; -use std::collections::{HashMap, VecDeque}; -use std::sync::Arc; -use std::time::{SystemTime, UNIX_EPOCH}; - -use super::extract_tool_name_map_from_request_body; -use super::super::log::{build_log_entry, LogContext, LogWriter}; -use super::super::sse::SseEventParser; -use super::super::token_rate::RequestTokenTracker; -use super::super::usage::SseUsageCollector; - -pub(crate) fn stream_codex_to_chat( - upstream: impl futures_util::stream::Stream> - + Unpin - + Send - + 'static, - context: LogContext, - log: Arc, - token_tracker: RequestTokenTracker, -) -> impl futures_util::stream::Stream> + Send -where - E: std::error::Error + Send + Sync + 'static, -{ - let state = CodexToChatState::new(upstream, context, log, token_tracker); - try_unfold(state, |state| async move { state.step().await }) -} - -struct CodexToChatState { - upstream: S, - parser: SseEventParser, - collector: SseUsageCollector, - log: Arc, - context: LogContext, - token_tracker: RequestTokenTracker, - out: VecDeque, - response_id: String, - created: i64, - model: String, - function_call_index: i64, - finish_reason: Option<&'static str>, - 
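// [Editor's annotation, not part of the original patch]
// The remaining fields are converter bookkeeping: `sent_done` guards against
// emitting the terminal chunk twice, `logged` ensures the usage log entry is
// written exactly once, `upstream_ended` records upstream EOF, and
// `tool_name_map` maps shortened tool names back to their originals.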
sent_done: bool, - logged: bool, - upstream_ended: bool, - tool_name_map: HashMap, -} - -impl CodexToChatState -where - S: futures_util::stream::Stream> + Unpin + Send + 'static, - E: std::error::Error + Send + Sync + 'static, -{ - fn new( - upstream: S, - mut context: LogContext, - log: Arc, - token_tracker: RequestTokenTracker, - ) -> Self { - let now_ms = now_unix_seconds(); - let response_id = format!("chatcmpl_proxy_{now_ms}"); - let model = context - .model - .clone() - .unwrap_or_else(|| "unknown".to_string()); - let tool_name_map = extract_tool_name_map_from_request_body(context.request_body.as_deref()); - context.request_body = None; - - Self { - upstream, - parser: SseEventParser::new(), - collector: SseUsageCollector::new(), - log, - token_tracker, - context, - out: VecDeque::new(), - response_id, - created: now_ms, - model, - function_call_index: -1, - finish_reason: None, - sent_done: false, - logged: false, - upstream_ended: false, - tool_name_map, - } - } - - async fn step(mut self) -> Result, std::io::Error> { - loop { - if let Some(next) = self.out.pop_front() { - return Ok(Some((next, self))); - } - if self.upstream_ended { - return Ok(None); - } - - match self.upstream.next().await { - Some(Ok(chunk)) => { - if self.context.ttfb_ms.is_none() { - self.context.ttfb_ms = Some(self.context.start.elapsed().as_millis()); - } - self.collector.push_chunk(&chunk); - let mut events = Vec::new(); - self.parser.push_chunk(&chunk, |data| events.push(data)); - let mut texts = Vec::new(); - for data in events { - self.handle_event(&data, &mut texts); - } - for text in texts { - self.token_tracker.add_output_text(&text).await; - } - } - Some(Err(err)) => { - self.log_usage_once(); - return Err(std::io::Error::new(std::io::ErrorKind::Other, err)); - } - None => { - self.upstream_ended = true; - let mut events = Vec::new(); - self.parser.finish(|data| events.push(data)); - let mut texts = Vec::new(); - for data in events { - self.handle_event(&data, &mut texts); - } - for text in texts { - self.token_tracker.add_output_text(&text).await; - } - if !self.sent_done { - self.push_done(); - } - self.log_usage_once(); - if self.out.is_empty() { - return Ok(None); - } - } - } - } - } - - fn handle_event(&mut self, data: &str, token_texts: &mut Vec) { - if self.sent_done { - return; - } - if data == "[DONE]" { - self.push_done(); - return; - } - let Ok(value) = serde_json::from_str::(data) else { - return; - }; - let Some(event_type) = value.get("type").and_then(Value::as_str) else { - return; - }; - - match event_type { - "response.created" => { - self.update_from_created(&value); - } - "response.output_text.delta" => { - if let Some(delta) = value.get("delta").and_then(Value::as_str) { - token_texts.push(delta.to_string()); - self.push_chunk(json!({ "role": "assistant", "content": delta })); - } - } - "response.reasoning_summary_text.delta" => { - if let Some(delta) = value.get("delta").and_then(Value::as_str) { - token_texts.push(delta.to_string()); - self.push_chunk(json!({ "role": "assistant", "reasoning_content": delta })); - } - } - "response.reasoning_summary_text.done" => { - self.push_chunk(json!({ "role": "assistant", "reasoning_content": "\n\n" })); - } - "response.output_item.done" => { - self.handle_function_call_item(&value); - } - "response.completed" => { - self.finish_reason = Some(self.resolve_finish_reason()); - } - _ => {} - } - } - - fn update_from_created(&mut self, value: &Value) { - if let Some(response) = value.get("response").and_then(Value::as_object) { - if let 
Some(id) = response.get("id").and_then(Value::as_str) { - if !id.is_empty() { - self.response_id = id.to_string(); - } - } - if let Some(created) = response.get("created_at").and_then(Value::as_i64) { - self.created = created; - } - if let Some(model) = response.get("model").and_then(Value::as_str) { - if !model.is_empty() { - self.model = model.to_string(); - } - } - } - } - - fn handle_function_call_item(&mut self, value: &Value) { - let Some(item) = value.get("item").and_then(Value::as_object) else { - return; - }; - if item.get("type").and_then(Value::as_str) != Some("function_call") { - return; - } - let name = item.get("name").and_then(Value::as_str).unwrap_or(""); - let restored = self - .tool_name_map - .get(name) - .map(String::as_str) - .unwrap_or(name); - let arguments = item.get("arguments").and_then(Value::as_str).unwrap_or(""); - let id = item - .get("call_id") - .and_then(Value::as_str) - .or_else(|| item.get("id").and_then(Value::as_str)) - .unwrap_or("call_proxy"); - self.function_call_index += 1; - let tool_call = json!({ - "index": self.function_call_index, - "id": id, - "type": "function", - "function": { "name": restored, "arguments": arguments } - }); - self.push_chunk(json!({ "role": "assistant", "tool_calls": [tool_call] })); - } - - fn push_chunk(&mut self, delta: Value) { - let chunk = chat_chunk_sse(&self.response_id, self.created, &self.model, delta, None); - self.out.push_back(chunk); - } - - fn push_done(&mut self) { - if self.sent_done { - return; - } - let finish = self.finish_reason.unwrap_or_else(|| self.resolve_finish_reason()); - let done = chat_chunk_sse( - &self.response_id, - self.created, - &self.model, - json!({}), - Some(finish), - ); - self.out.push_back(done); - self.out.push_back(Bytes::from("data: [DONE]\n\n")); - self.sent_done = true; - } - - fn resolve_finish_reason(&self) -> &'static str { - if self.function_call_index >= 0 { - "tool_calls" - } else { - "stop" - } - } - - fn log_usage_once(&mut self) { - if self.logged { - return; - } - self.logged = true; - let entry = build_log_entry(&self.context, self.collector.finish(), None); - self.log.clone().write_detached(entry); - } -} - -pub(crate) fn stream_codex_to_responses( - upstream: impl futures_util::stream::Stream> - + Unpin - + Send - + 'static, - context: LogContext, - log: Arc, - token_tracker: RequestTokenTracker, -) -> impl futures_util::stream::Stream> + Send -where - E: std::error::Error + Send + Sync + 'static, -{ - let state = CodexToResponsesState::new(upstream, context, log, token_tracker); - try_unfold(state, |state| async move { state.step().await }) -} - -struct CodexToResponsesState { - upstream: S, - parser: SseEventParser, - collector: SseUsageCollector, - log: Arc, - context: LogContext, - token_tracker: RequestTokenTracker, - out: VecDeque, - sent_done: bool, - logged: bool, - upstream_ended: bool, - tool_name_map: HashMap, -} - -impl CodexToResponsesState -where - S: futures_util::stream::Stream> + Unpin + Send + 'static, - E: std::error::Error + Send + Sync + 'static, -{ - fn new( - upstream: S, - mut context: LogContext, - log: Arc, - token_tracker: RequestTokenTracker, - ) -> Self { - let tool_name_map = extract_tool_name_map_from_request_body(context.request_body.as_deref()); - context.request_body = None; - Self { - upstream, - parser: SseEventParser::new(), - collector: SseUsageCollector::new(), - log, - token_tracker, - context, - out: VecDeque::new(), - sent_done: false, - logged: false, - upstream_ended: false, - tool_name_map, - } - } - - async fn step(mut 
self) -> Result, std::io::Error> { - loop { - if let Some(next) = self.out.pop_front() { - return Ok(Some((next, self))); - } - if self.upstream_ended { - return Ok(None); - } - - match self.upstream.next().await { - Some(Ok(chunk)) => { - if self.context.ttfb_ms.is_none() { - self.context.ttfb_ms = Some(self.context.start.elapsed().as_millis()); - } - self.collector.push_chunk(&chunk); - let mut events = Vec::new(); - self.parser.push_chunk(&chunk, |data| events.push(data)); - let mut texts = Vec::new(); - for data in events { - self.handle_event(&data, &mut texts); - } - for text in texts { - self.token_tracker.add_output_text(&text).await; - } - } - Some(Err(err)) => { - self.log_usage_once(); - return Err(std::io::Error::new(std::io::ErrorKind::Other, err)); - } - None => { - self.upstream_ended = true; - let mut events = Vec::new(); - self.parser.finish(|data| events.push(data)); - let mut texts = Vec::new(); - for data in events { - self.handle_event(&data, &mut texts); - } - for text in texts { - self.token_tracker.add_output_text(&text).await; - } - if !self.sent_done { - self.out.push_back(Bytes::from("data: [DONE]\n\n")); - self.sent_done = true; - } - self.log_usage_once(); - if self.out.is_empty() { - return Ok(None); - } - } - } - } - } - - fn handle_event(&mut self, data: &str, token_texts: &mut Vec) { - if self.sent_done { - return; - } - if data == "[DONE]" { - self.out.push_back(Bytes::from("data: [DONE]\n\n")); - self.sent_done = true; - return; - } - let Ok(mut value) = serde_json::from_str::(data) else { - return; - }; - restore_tool_names_in_event(&mut value, &self.tool_name_map); - if let Some(delta) = extract_output_text_delta(&value) { - token_texts.push(delta.to_string()); - } - self.out - .push_back(Bytes::from(format!("data: {}\n\n", value.to_string()))); - } - - fn log_usage_once(&mut self) { - if self.logged { - return; - } - self.logged = true; - let entry = build_log_entry(&self.context, self.collector.finish(), None); - self.log.clone().write_detached(entry); - } -} - -fn restore_tool_names_in_event(value: &mut Value, tool_name_map: &HashMap) { - if tool_name_map.is_empty() { - return; - } - if let Some(item) = value.get_mut("item").and_then(Value::as_object_mut) { - restore_tool_names_in_item(item, tool_name_map); - } - if let Some(response) = value.get_mut("response").and_then(Value::as_object_mut) { - restore_tool_names_in_response(response, tool_name_map); - } - if let Some(response) = value.as_object_mut() { - restore_tool_names_in_response(response, tool_name_map); - } -} - -fn restore_tool_names_in_response( - response: &mut Map, - tool_name_map: &HashMap, -) { - let Some(output) = response.get_mut("output").and_then(Value::as_array_mut) else { - return; - }; - for item in output { - let Some(item) = item.as_object_mut() else { - continue; - }; - restore_tool_names_in_item(item, tool_name_map); - } -} - -fn restore_tool_names_in_item(item: &mut Map, tool_name_map: &HashMap) { - if item.get("type").and_then(Value::as_str) != Some("function_call") { - return; - } - let Some(name) = item.get("name").and_then(Value::as_str) else { - return; - }; - let Some(restored) = tool_name_map.get(name) else { - return; - }; - item.insert("name".to_string(), Value::String(restored.clone())); -} - -fn extract_output_text_delta(value: &Value) -> Option<&str> { - if value.get("type").and_then(Value::as_str) != Some("response.output_text.delta") { - return None; - } - value.get("delta").and_then(Value::as_str) -} - -fn chat_chunk_sse( - id: &str, - created: i64, - model: 
&str, - delta: Value, - finish_reason: Option<&str>, -) -> Bytes { - let chunk = json!({ - "id": id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [ - { - "index": 0, - "delta": delta, - "finish_reason": finish_reason - } - ] - }); - Bytes::from(format!("data: {}\n\n", chunk.to_string())) -} - -fn now_unix_seconds() -> i64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_secs() as i64 -} diff --git a/src-tauri/src/proxy/codex_compat/tool_names.rs b/src-tauri/src/proxy/codex_compat/tool_names.rs deleted file mode 100644 index a918f4e..0000000 --- a/src-tauri/src/proxy/codex_compat/tool_names.rs +++ /dev/null @@ -1,96 +0,0 @@ -use std::collections::{HashMap, HashSet}; - -const TOOL_NAME_LIMIT: usize = 64; -const MCP_PREFIX: &str = "mcp__"; - -#[derive(Clone, Debug, Default)] -pub(crate) struct ToolNameMap { - pub(crate) short_by_original: HashMap, - pub(crate) original_by_short: HashMap, -} - -impl ToolNameMap { - pub(crate) fn from_names(names: &[String]) -> Self { - let mut used = HashSet::new(); - let mut short_by_original = HashMap::new(); - let mut original_by_short = HashMap::new(); - - for name in names { - let candidate = base_candidate(name); - let short = make_unique(&candidate, &mut used); - short_by_original.insert(name.clone(), short.clone()); - original_by_short.insert(short, name.clone()); - } - - Self { - short_by_original, - original_by_short, - } - } - - pub(crate) fn shorten(&self, name: &str) -> String { - self.short_by_original - .get(name) - .cloned() - .unwrap_or_else(|| shorten_name_if_needed(name)) - } -} - -pub(crate) fn shorten_name_if_needed(name: &str) -> String { - if name.len() <= TOOL_NAME_LIMIT { - return name.to_string(); - } - base_candidate(name) -} - -fn base_candidate(name: &str) -> String { - if name.len() <= TOOL_NAME_LIMIT { - return name.to_string(); - } - if let Some(candidate) = shorten_mcp_name(name) { - return candidate; - } - truncate_name(name, TOOL_NAME_LIMIT) -} - -fn shorten_mcp_name(name: &str) -> Option { - if !name.starts_with(MCP_PREFIX) { - return None; - } - let idx = name.rfind("__")?; - if idx <= MCP_PREFIX.len() { - return None; - } - let mut candidate = format!("{MCP_PREFIX}{}", &name[idx + 2..]); - if candidate.len() > TOOL_NAME_LIMIT { - candidate.truncate(TOOL_NAME_LIMIT); - } - Some(candidate) -} - -fn truncate_name(name: &str, limit: usize) -> String { - let mut out = name.to_string(); - if out.len() > limit { - out.truncate(limit); - } - out -} - -fn make_unique(candidate: &str, used: &mut HashSet) -> String { - if used.insert(candidate.to_string()) { - return candidate.to_string(); - } - for index in 1.. 
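// [Editor's annotation, not part of the original patch]
// Collision-handling sketch for the loop below: with TOOL_NAME_LIMIT = 64, a
// second tool that shortens to an already-used candidate gets "_1" appended
// (truncating the base so candidate + suffix still fits the limit), then
// "_2", and so on until an unused name is found.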
{ - let suffix = format!("_{index}"); - let allowed = TOOL_NAME_LIMIT.saturating_sub(suffix.len()); - let mut base = candidate.to_string(); - if base.len() > allowed { - base.truncate(allowed); - } - let next = format!("{base}{suffix}"); - if used.insert(next.clone()) { - return next; - } - } - candidate.to_string() -} diff --git a/src-tauri/src/proxy/compat_content.rs b/src-tauri/src/proxy/compat_content.rs deleted file mode 100644 index ffd2d14..0000000 --- a/src-tauri/src/proxy/compat_content.rs +++ /dev/null @@ -1,80 +0,0 @@ -use serde_json::{json, Value}; - -pub(crate) fn chat_message_content_from_responses_parts(parts: &[Value]) -> Value { - let mut output_parts = Vec::new(); - let mut combined_text = String::new(); - let mut text_only = true; - - for part in parts { - let Some(part) = part.as_object() else { - continue; - }; - let part_type = part.get("type").and_then(Value::as_str); - match part_type { - Some("output_text") => { - if let Some(text) = part.get("text").and_then(Value::as_str) { - combined_text.push_str(text); - output_parts.push(json!({ "type": "text", "text": text })); - } - } - Some("reasoning_text") => {} - Some("refusal") => { - let text = part - .get("refusal") - .or_else(|| part.get("text")) - .and_then(Value::as_str) - .unwrap_or(""); - if !text.is_empty() { - combined_text.push_str(text); - output_parts.push(json!({ "type": "text", "text": text })); - } - } - Some("output_image") => { - if let Some(image_url) = part.get("image_url") { - text_only = false; - output_parts.push(json!({ "type": "image_url", "image_url": image_url.clone() })); - } - } - Some("input_image") => { - if let Some(image_url) = part.get("image_url") { - text_only = false; - output_parts.push(json!({ "type": "image_url", "image_url": image_url.clone() })); - } - } - Some("input_text") | Some("text") => { - if let Some(text) = part.get("text").and_then(Value::as_str) { - combined_text.push_str(text); - output_parts.push(json!({ "type": "text", "text": text })); - } - } - _ => { - text_only = false; - } - } - } - - if text_only { - Value::String(combined_text) - } else { - Value::Array(output_parts) - } -} - -pub(crate) fn chat_message_non_text_parts_from_responses(parts: &[Value]) -> Vec { - let mut output_parts = Vec::new(); - for part in parts { - let Some(part) = part.as_object() else { - continue; - }; - let part_type = part.get("type").and_then(Value::as_str); - match part_type { - Some("output_image") | Some("input_image") => { - if let Some(image_url) = part.get("image_url") { - output_parts.push(json!({ "type": "image_url", "image_url": image_url.clone() })); - } - } - _ => {} - } - } - output_parts -} diff --git a/src-tauri/src/proxy/compat_reason.rs b/src-tauri/src/proxy/compat_reason.rs deleted file mode 100644 index c086342..0000000 --- a/src-tauri/src/proxy/compat_reason.rs +++ /dev/null @@ -1,65 +0,0 @@ -use serde_json::Value; - -pub(crate) fn chat_finish_reason_from_responses( - status: Option<&str>, - incomplete_reason: Option<&str>, - has_tool_calls: bool, -) -> &'static str { - // Prefer explicit incomplete reason, then status, then tool calls. 
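// [Editor's annotation, not part of the original patch]
// Worked examples of the precedence implemented below:
//   incomplete_reason = Some("max_output_tokens")        -> "length"
//   status = Some("incomplete"), no incomplete_reason    -> "length"
//   status = Some("completed"), has_tool_calls = true    -> "tool_calls"
//   status = Some("completed"), has_tool_calls = false   -> "stop"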
- if let Some(reason) = incomplete_reason { - return map_responses_reason_to_chat_finish_reason(reason); - } - if matches!(status, Some("incomplete")) { - return "length"; - } - if has_tool_calls { - return "tool_calls"; - } - "stop" -} - -pub(crate) fn chat_finish_reason_from_response_object( - response: &serde_json::Map, - has_tool_calls: bool, -) -> &'static str { - let status = response.get("status").and_then(Value::as_str); - let incomplete_reason = response - .get("incomplete_details") - .and_then(Value::as_object) - .and_then(|details| details.get("reason")) - .and_then(Value::as_str); - chat_finish_reason_from_responses(status, incomplete_reason, has_tool_calls) -} - -pub(crate) fn responses_status_from_chat_finish_reason( - finish_reason: Option<&str>, -) -> (Option<&'static str>, Option<&'static str>) { - let Some(reason) = finish_reason else { - return (None, None); - }; - match reason { - "length" => (Some("incomplete"), Some("max_tokens")), - "content_filter" => (Some("incomplete"), Some("content_filter")), - _ => (None, None), - } -} - -pub(crate) fn anthropic_stop_reason_from_chat_finish_reason(reason: &str) -> &'static str { - match reason { - "stop" => "end_turn", - "length" => "max_tokens", - "tool_calls" => "tool_use", - "content_filter" => "refusal", - _ => "end_turn", - } -} - -fn map_responses_reason_to_chat_finish_reason(reason: &str) -> &'static str { - match reason { - "max_output_tokens" | "max_tokens" => "length", - "content_filter" => "content_filter", - "tool_calls" | "tool_use" => "tool_calls", - "stop" | "stop_sequence" | "end_turn" => "stop", - _ => "stop", - } -} diff --git a/src-tauri/src/proxy/config/io.rs b/src-tauri/src/proxy/config/io.rs deleted file mode 100644 index 4876a5a..0000000 --- a/src-tauri/src/proxy/config/io.rs +++ /dev/null @@ -1,205 +0,0 @@ -use std::path::{Path, PathBuf}; -use std::time::Instant; - -use tauri::{AppHandle, Manager}; - -use super::ProxyConfigFile; - -const CONFIG_FILE_NAME: &str = "config.jsonc"; -const DEFAULT_CONFIG_HEADER: &str = concat!( - "// Token Proxy config (JSONC). Comments and trailing commas are supported.\n", - "// log_level (optional): silent|error|warn|info|debug|trace. Default: silent.\n", - "// app_proxy_url (optional): http(s)://... | socks5(h)://... 
(used for app updates and upstream proxy reuse).\n", - "// upstreams[].proxy_url (optional): empty => direct; \"$app_proxy_url\" => use app_proxy_url; or an explicit proxy URL.\n" -); - -pub(super) async fn load_config_file(app: &AppHandle) -> Result { - let path = config_file_path(app)?; - tracing::debug!(path = %path.display(), "load_config_file start"); - let start = Instant::now(); - match tokio::fs::read_to_string(&path).await { - Ok(contents) => { - tracing::debug!( - path = %path.display(), - bytes = contents.len(), - elapsed_ms = start.elapsed().as_millis(), - "load_config_file read" - ); - parse_config_file(&contents, &path) - } - Err(err) if err.kind() == std::io::ErrorKind::NotFound => { - tracing::debug!( - path = %path.display(), - elapsed_ms = start.elapsed().as_millis(), - "load_config_file missing, creating default" - ); - let config = ProxyConfigFile::default(); - save_config_file(app, &config).await?; - Ok(config) - } - Err(err) => { - tracing::error!( - path = %path.display(), - elapsed_ms = start.elapsed().as_millis(), - error = %err, - "load_config_file read failed" - ); - Err(format!("Failed to read config file: {err}")) - } - } -} - -pub(super) async fn save_config_file( - app: &AppHandle, - config: &ProxyConfigFile, -) -> Result<(), String> { - let path = config_file_path(app)?; - tracing::debug!(path = %path.display(), "save_config_file start"); - let start = Instant::now(); - ensure_parent_dir(&path).await?; - tracing::debug!( - path = %path.display(), - elapsed_ms = start.elapsed().as_millis(), - "save_config_file ensured dir" - ); - let data = serde_json::to_string_pretty(config) - .map_err(|err| format!("Failed to serialize config: {err}"))?; - let header = read_existing_header(&path) - .await - .unwrap_or_else(default_config_header); - tracing::debug!( - path = %path.display(), - elapsed_ms = start.elapsed().as_millis(), - "save_config_file header ready" - ); - let output = merge_header_and_body(header, data); - tokio::fs::write(&path, output) - .await - .map_err(|err| format!("Failed to write config file: {err}"))?; - tracing::debug!( - path = %path.display(), - elapsed_ms = start.elapsed().as_millis(), - "save_config_file wrote" - ); - Ok(()) -} - -/// Config directory: BaseDirectory::AppConfig -pub(crate) fn config_dir_path(app: &AppHandle) -> Result { - app.path() - .app_config_dir() - .map_err(|err| format!("Failed to resolve config dir: {err}")) -} - -/// Config file path: based on the config directory -pub(super) fn config_file_path(app: &AppHandle) -> Result { - Ok(config_dir_path(app)?.join(CONFIG_FILE_NAME)) -} - -fn parse_config_file(contents: &str, path: &Path) -> Result { - let sanitized = crate::jsonc::sanitize_jsonc(contents); - serde_json::from_str(&sanitized) - .map_err(|err| format!("Failed to parse config file {}: {err}", path.display())) -} - -async fn read_existing_header(path: &Path) -> Option { - tracing::debug!(path = %path.display(), "read_existing_header start"); - let start = Instant::now(); - let contents = tokio::fs::read_to_string(path).await.ok()?; - tracing::debug!( - path = %path.display(), - bytes = contents.len(), - elapsed_ms = start.elapsed().as_millis(), - "read_existing_header read" - ); - let header = extract_leading_jsonc_comments(&contents); - if header.trim().is_empty() { - None - } else { - Some(header) - } -} - -fn extract_leading_jsonc_comments(contents: &str) -> String { - let bytes = contents.as_bytes(); - let mut output = Vec::new(); - let mut index = 0; - - while index < bytes.len() { - let byte = bytes[index]; 
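// [Editor's annotation, not part of the original patch]
// This scanner copies leading whitespace plus any run of `//` line comments
// and `/* ... */` block comments verbatim, stopping at the first byte that
// starts neither; the preserved prefix is re-attached above the rewritten
// JSON body so user-edited header comments survive a config save.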
- if byte == b' ' || byte == b'\t' || byte == b'\r' || byte == b'\n' { - output.push(byte); - index += 1; - continue; - } - - if byte == b'/' && index + 1 < bytes.len() { - let next = bytes[index + 1]; - if next == b'/' { - output.push(byte); - output.push(next); - index += 2; - while index < bytes.len() { - let current = bytes[index]; - output.push(current); - index += 1; - if current == b'\n' { - break; - } - } - continue; - } - if next == b'*' { - output.push(byte); - output.push(next); - index += 2; - while index < bytes.len() { - let current = bytes[index]; - output.push(current); - index += 1; - if current == b'*' && index < bytes.len() && bytes[index] == b'/' { - output.push(b'/'); - index += 1; - break; - } - } - continue; - } - } - - break; - } - - String::from_utf8(output).unwrap_or_default() -} - -fn default_config_header() -> String { - DEFAULT_CONFIG_HEADER.to_string() -} - -fn merge_header_and_body(header: String, body: String) -> String { - if header.is_empty() { - format!("{body}\n") - } else if header.ends_with('\n') { - format!("{header}{body}\n") - } else { - format!("{header}\n{body}\n") - } -} - -async fn ensure_parent_dir(path: &Path) -> Result<(), String> { - let Some(parent) = path.parent() else { - return Ok(()); - }; - tracing::debug!(path = %parent.display(), "ensure_parent_dir start"); - let start = Instant::now(); - tokio::fs::create_dir_all(parent) - .await - .map_err(|err| format!("Failed to create config directory: {err}"))?; - tracing::debug!( - path = %parent.display(), - elapsed_ms = start.elapsed().as_millis(), - "ensure_parent_dir done" - ); - Ok(()) -} diff --git a/src-tauri/src/proxy/config/mod.rs b/src-tauri/src/proxy/config/mod.rs deleted file mode 100644 index 0e33753..0000000 --- a/src-tauri/src/proxy/config/mod.rs +++ /dev/null @@ -1,104 +0,0 @@ -mod io; -mod model_mapping; -mod normalize; -mod types; - -use tauri::AppHandle; - -const DEFAULT_MAX_REQUEST_BODY_BYTES: u64 = 20 * 1024 * 1024; - -pub(crate) use io::config_dir_path; -pub(crate) use types::{ - ConfigResponse, - KiroPreferredEndpoint, - ProxyConfig, - ProxyConfigFile, - ProviderUpstreams, - TrayTokenRateConfig, - TrayTokenRateFormat, - UpstreamConfig, - UpstreamOverrides, - UpstreamGroup, - HeaderOverride, - UpstreamRuntime, - UpstreamStrategy, -}; - -pub(crate) async fn read_config(app: AppHandle) -> Result { - let config = io::load_config_file(&app).await?; - let path = io::config_file_path(&app)?; - Ok(ConfigResponse { - path: path.to_string_lossy().to_string(), - config, - }) -} - -pub(crate) fn app_proxy_url_from_config(config: &ProxyConfigFile) -> Result, String> { - normalize_app_proxy_url(config.app_proxy_url.as_deref()) -} - -pub(crate) async fn write_config( - app: AppHandle, - config: ProxyConfigFile, -) -> Result<(), String> { - build_runtime_config(config.clone())?; - io::save_config_file(&app, &config).await -} - -impl ProxyConfig { - pub(crate) fn addr(&self) -> String { - format!("{}:{}", self.host, self.port) - } - - pub(crate) async fn load(app: &AppHandle) -> Result { - let config = io::load_config_file(app).await?; - build_runtime_config(config) - } - - pub(crate) fn provider_upstreams(&self, provider: &str) -> Option<&ProviderUpstreams> { - self.upstreams.get(provider) - } -} - -fn build_runtime_config(config: ProxyConfigFile) -> Result { - let log_level = config.log_level; - let max_request_body_bytes = resolve_max_request_body_bytes(config.max_request_body_bytes); - let app_proxy_url = normalize_app_proxy_url(config.app_proxy_url.as_deref())?; - let 
normalized_upstreams =
-        normalize::normalize_upstreams(&config.upstreams, app_proxy_url.as_deref())?;
-    let upstreams = normalize::build_provider_upstreams(normalized_upstreams)?;
-    Ok(ProxyConfig {
-        host: config.host,
-        port: config.port,
-        local_api_key: config.local_api_key,
-        log_level,
-        max_request_body_bytes,
-        enable_api_format_conversion: config.enable_api_format_conversion,
-        upstream_strategy: config.upstream_strategy,
-        upstreams,
-        kiro_preferred_endpoint: config.kiro_preferred_endpoint,
-        antigravity_user_agent: config.antigravity_user_agent,
-    })
-}
-
-fn resolve_max_request_body_bytes(value: Option<u64>) -> usize {
-    let value = value.unwrap_or(DEFAULT_MAX_REQUEST_BODY_BYTES);
-    let value = if value == 0 {
-        DEFAULT_MAX_REQUEST_BODY_BYTES
-    } else {
-        value
-    };
-    usize::try_from(value).unwrap_or(usize::MAX)
-}
-
-fn normalize_app_proxy_url(value: Option<&str>) -> Result<Option<String>, String> {
-    let value = value.unwrap_or_default().trim();
-    if value.is_empty() {
-        return Ok(None);
-    }
-    let parsed = url::Url::parse(value).map_err(|_| "app_proxy_url is not a valid URL.".to_string())?;
-    match parsed.scheme() {
-        "http" | "https" | "socks5" | "socks5h" => Ok(Some(value.to_string())),
-        scheme => Err(format!("app_proxy_url scheme is not supported: {scheme}.")),
-    }
-}
diff --git a/src-tauri/src/proxy/config/model_mapping.rs b/src-tauri/src/proxy/config/model_mapping.rs
deleted file mode 100644
index 75d402c..0000000
--- a/src-tauri/src/proxy/config/model_mapping.rs
+++ /dev/null
@@ -1,143 +0,0 @@
-use std::collections::{HashMap, HashSet};
-
-#[derive(Clone, Debug)]
-pub(crate) struct ModelMappingRules {
-    exact: HashMap<String, String>,
-    prefix: Vec<PrefixRule>,
-    wildcard: Option<String>,
-}
-
-#[derive(Clone, Debug)]
-struct PrefixRule {
-    prefix: String,
-    target: String,
-}
-
-impl ModelMappingRules {
-    pub(crate) fn map_model(&self, model: &str) -> Option<&str> {
-        if let Some(target) = self.exact.get(model) {
-            return Some(target.as_str());
-        }
-        for rule in &self.prefix {
-            if model.starts_with(&rule.prefix) {
-                return Some(rule.target.as_str());
-            }
-        }
-        self.wildcard.as_deref()
-    }
-}
-
-pub(crate) fn compile_model_mappings(
-    upstream_id: &str,
-    mappings: &HashMap<String, String>,
-) -> Result<Option<ModelMappingRules>, String> {
-    if mappings.is_empty() {
-        return Ok(None);
-    }
-    let mut builder = ModelMappingBuilder::new(upstream_id);
-    for (pattern, target) in mappings {
-        builder.push(pattern, target)?;
-    }
-    Ok(Some(builder.finish()))
-}
-
-struct ModelMappingBuilder<'a> {
-    upstream_id: &'a str,
-    exact: HashMap<String, String>,
-    prefix: Vec<PrefixRule>,
-    wildcard: Option<String>,
-    seen_patterns: HashSet<String>,
-}
-
-impl<'a> ModelMappingBuilder<'a> {
-    fn new(upstream_id: &'a str) -> Self {
-        Self {
-            upstream_id,
-            exact: HashMap::new(),
-            prefix: Vec::new(),
-            wildcard: None,
-            seen_patterns: HashSet::new(),
-        }
-    }
-
-    fn push(&mut self, pattern_raw: &str, target_raw: &str) -> Result<(), String> {
-        let pattern = pattern_raw.trim();
-        let target = target_raw.trim();
-        if pattern.is_empty() {
-            return Err(self.error("model mapping pattern cannot be empty"));
-        }
-        if target.is_empty() {
-            return Err(format!(
-                "Upstream {} model mapping target for \"{}\" cannot be empty.",
-                self.upstream_id, pattern
-            ));
-        }
-        if pattern == "*" {
-            if self.wildcard.is_some() {
-                return Err(format!(
-                    "Upstream {} model mapping wildcard \"*\" can only be defined once.",
-                    self.upstream_id
-                ));
-            }
-            self.wildcard = Some(target.to_string());
-            return Ok(());
-        }
-        if !self.seen_patterns.insert(pattern.to_string()) {
-            return Err(format!(
-                "Upstream {} model mapping pattern is duplicated: {}.",
-                self.upstream_id, pattern
-            ));
-        }
-        if pattern.ends_with('*') {
-            // Prefix pattern: only a trailing wildcard is allowed, and the prefix must not be empty.
-            let prefix_value = pattern.trim_end_matches('*');
-            if prefix_value.is_empty() {
-                return Err(self.error("model mapping prefix cannot be empty"));
-            }
-            if prefix_value.contains('*') {
-                return Err(self.invalid_pattern(pattern));
-            }
-            self.prefix.push(PrefixRule {
-                prefix: prefix_value.to_string(),
-                target: target.to_string(),
-            });
-            return Ok(());
-        }
-        if pattern.contains('*') {
-            return Err(self.invalid_pattern(pattern));
-        }
-        self.exact.insert(pattern.to_string(), target.to_string());
-        Ok(())
-    }
-
-    fn finish(mut self) -> ModelMappingRules {
-        self.prefix.sort_by(|left, right| {
-            right
-                .prefix
-                .len()
-                .cmp(&left.prefix.len())
-                .then_with(|| left.prefix.cmp(&right.prefix))
-        });
-        ModelMappingRules {
-            exact: self.exact,
-            prefix: self.prefix,
-            wildcard: self.wildcard,
-        }
-    }
-
-    fn error(&self, message: &str) -> String {
-        format!("Upstream {} {}.", self.upstream_id, message)
-    }
-
-    fn invalid_pattern(&self, pattern: &str) -> String {
-        format!(
-            "Upstream {} model mapping pattern is invalid: {}.",
-            self.upstream_id, pattern
-        )
-    }
-}
-
-// Unit tests live in a separate file; `#[path]` keeps the `.test.rs` naming convention.
-#[cfg(test)]
-#[path = "model_mapping.test.rs"]
-mod tests;
diff --git a/src-tauri/src/proxy/config/model_mapping.test.rs b/src-tauri/src/proxy/config/model_mapping.test.rs
deleted file mode 100644
index 7061dc5..0000000
--- a/src-tauri/src/proxy/config/model_mapping.test.rs
+++ /dev/null
@@ -1,36 +0,0 @@
-use super::*;
-
-#[test]
-fn model_mapping_prefers_exact_then_prefix_then_wildcard() {
-    let mut mappings = HashMap::new();
-    mappings.insert("gpt-4".to_string(), "gpt-4.1".to_string());
-    mappings.insert("gpt-4*".to_string(), "gpt-4.1-mini".to_string());
-    mappings.insert("*".to_string(), "gpt-default".to_string());
-    let rules = compile_model_mappings("demo", &mappings)
-        .expect("compile")
-        .expect("rules");
-
-    assert_eq!(rules.map_model("gpt-4"), Some("gpt-4.1"));
-    assert_eq!(rules.map_model("gpt-4-vision"), Some("gpt-4.1-mini"));
-    assert_eq!(rules.map_model("other"), Some("gpt-default"));
-}
-
-#[test]
-fn model_mapping_prefix_prefers_longer_prefix() {
-    let mut mappings = HashMap::new();
-    mappings.insert("gpt-4*".to_string(), "wide".to_string());
-    mappings.insert("gpt-4.1*".to_string(), "narrow".to_string());
-    let rules = compile_model_mappings("demo", &mappings)
-        .expect("compile")
-        .expect("rules");
-    assert_eq!(rules.map_model("gpt-4.1-mini"), Some("narrow"));
}
-
-#[test]
-fn model_mapping_rejects_multiple_wildcards() {
-    let mut mappings = HashMap::new();
-    mappings.insert("*".to_string(), "a".to_string());
-    mappings.insert(" * ".to_string(), "b".to_string());
-    let err = compile_model_mappings("demo", &mappings).unwrap_err();
-    assert!(err.contains("wildcard"));
-}
diff --git a/src-tauri/src/proxy/config/normalize.rs b/src-tauri/src/proxy/config/normalize.rs
deleted file mode 100644
index 97af6d2..0000000
--- a/src-tauri/src/proxy/config/normalize.rs
+++ /dev/null
@@ -1,253 +0,0 @@
-use std::collections::{HashMap, HashSet};
-
-use super::{
-    model_mapping::compile_model_mappings, HeaderOverride, ProviderUpstreams, UpstreamConfig,
-    UpstreamGroup, UpstreamOverrides, UpstreamRuntime,
-};
-use axum::http::header::{HeaderName, HeaderValue};
-
-const APP_PROXY_URL_PLACEHOLDER: &str = "$app_proxy_url";
-const DEFAULT_CODEX_BASE_URL: &str = "https://chatgpt.com/backend-api/codex";
-const DEFAULT_ANTIGRAVITY_BASE_URL: &str = "https://daily-cloudcode-pa.googleapis.com";
-
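// [Editor's sketch, not part of the original patch] A minimal illustration of
// how the "$app_proxy_url" placeholder resolves, assuming the private
// normalize_upstream_proxy_url helper defined later in this file; the URL and
// upstream id below are illustrative values only.
#[cfg(test)]
mod app_proxy_url_placeholder_sketch {
    use super::*;

    #[test]
    fn placeholder_expands_to_app_proxy_url() {
        // "$app_proxy_url" defers to the app-level proxy when one is set...
        let resolved = normalize_upstream_proxy_url(
            Some("$app_proxy_url"),
            Some("http://127.0.0.1:7890"),
            "demo",
        )
        .expect("a set app_proxy_url should resolve");
        assert_eq!(resolved.as_deref(), Some("http://127.0.0.1:7890"));

        // ...and is rejected with an explanatory error when it is not.
        let err = normalize_upstream_proxy_url(Some("$app_proxy_url"), None, "demo")
            .unwrap_err();
        assert!(err.contains("app_proxy_url is empty"));
    }
}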
-#[derive(Clone)]
-pub(super) struct NormalizedUpstream {
-    pub(super) provider: String,
-    pub(super) runtime: UpstreamRuntime,
-}
-
-pub(super) fn normalize_upstreams(
-    upstreams: &[UpstreamConfig],
-    app_proxy_url: Option<&str>,
-) -> Result<Vec<NormalizedUpstream>, String> {
-    validate_upstream_ids(upstreams)?;
-    let mut normalized = Vec::with_capacity(upstreams.len());
-    for upstream in upstreams {
-        if let Some(entry) = normalize_single_upstream(upstream, app_proxy_url)? {
-            normalized.push(entry);
-        }
-    }
-    Ok(normalized)
-}
-
-pub(super) fn build_provider_upstreams(
-    upstreams: Vec<NormalizedUpstream>,
-) -> Result<HashMap<String, ProviderUpstreams>, String> {
-    let mut grouped: HashMap<String, Vec<UpstreamRuntime>> = HashMap::new();
-    for upstream in upstreams {
-        grouped
-            .entry(upstream.provider)
-            .or_default()
-            .push(upstream.runtime);
-    }
-    let mut output = HashMap::new();
-    for (provider, upstreams) in grouped {
-        let groups = group_upstreams_by_priority(upstreams);
-        output.insert(provider, ProviderUpstreams { groups });
-    }
-    Ok(output)
-}
-
-fn group_upstreams_by_priority(upstreams: Vec<UpstreamRuntime>) -> Vec<UpstreamGroup> {
-    // Keep same-priority order stable by preserving config insertion order.
-    let mut grouped: HashMap<i32, Vec<UpstreamRuntime>> = HashMap::new();
-    for upstream in upstreams {
-        grouped.entry(upstream.priority).or_default().push(upstream);
-    }
-    let mut priorities: Vec<i32> = grouped.keys().copied().collect();
-    priorities.sort_by(|left, right| right.cmp(left));
-    let mut groups = Vec::with_capacity(priorities.len());
-    for priority in priorities {
-        if let Some(items) = grouped.remove(&priority) {
-            groups.push(UpstreamGroup { priority, items });
-        }
-    }
-    groups
-}
-
-fn validate_upstream_ids(upstreams: &[UpstreamConfig]) -> Result<(), String> {
-    let mut seen_ids = HashSet::new();
-    for upstream in upstreams {
-        let id = upstream.id.trim();
-        if id.is_empty() {
-            return Err("Upstream id cannot be empty.".to_string());
-        }
-        if !seen_ids.insert(id.to_string()) {
-            return Err(format!("Upstream id already exists: {id}."));
-        }
-    }
-    Ok(())
-}
-
-fn normalize_single_upstream(
-    upstream: &UpstreamConfig,
-    app_proxy_url: Option<&str>,
-) -> Result<Option<NormalizedUpstream>, String> {
-    if !upstream.enabled {
-        return Ok(None);
-    }
-    let provider = upstream.provider.trim();
-    if provider.is_empty() {
-        return Err(format!(
-            "Upstream {} provider cannot be empty.",
-            upstream.id
-        ));
-    }
-    let base_url = upstream.base_url.trim();
-    let base_url = if base_url.is_empty() {
-        if provider == "codex" {
-            DEFAULT_CODEX_BASE_URL.to_string()
-        } else if provider == "antigravity" {
-            DEFAULT_ANTIGRAVITY_BASE_URL.to_string()
-        } else if provider == "kiro" {
-            String::new()
-        } else {
-            return Err(format!(
-                "Upstream {} base_url cannot be empty.",
-                upstream.id
-            ));
-        }
-    } else {
-        base_url.to_string()
-    };
-    let api_key = upstream
-        .api_key
-        .as_ref()
-        .map(|value| value.trim())
-        .filter(|value| !value.is_empty())
-        .map(|value| value.to_string());
-    let kiro_account_id = upstream
-        .kiro_account_id
-        .as_ref()
-        .map(|value| value.trim())
-        .filter(|value| !value.is_empty())
-        .map(|value| value.to_string());
-    let codex_account_id = upstream
-        .codex_account_id
-        .as_ref()
-        .map(|value| value.trim())
-        .filter(|value| !value.is_empty())
-        .map(|value| value.to_string());
-    let antigravity_account_id = upstream
-        .antigravity_account_id
-        .as_ref()
-        .map(|value| value.trim())
-        .filter(|value| !value.is_empty())
-        .map(|value| value.to_string());
-    if provider == "kiro" && kiro_account_id.is_none() {
-        return Err(format!(
-            "Upstream {} requires a Kiro account binding.",
-            upstream.id
-        ));
-    }
-    if provider == "codex" && codex_account_id.is_none() {
-        return Err(format!(
-            "Upstream {} requires a Codex account binding.",
-            upstream.id
-        ));
-    }
-    if provider == "antigravity" && antigravity_account_id.is_none() {
-        return Err(format!(
-            "Upstream {} requires an Antigravity account binding.",
-            upstream.id
-        ));
-    }
-    let proxy_url = normalize_upstream_proxy_url(
-        upstream.proxy_url.as_deref(),
-        app_proxy_url,
-        &upstream.id,
-    )?;
-    let model_mappings = compile_model_mappings(&upstream.id, &upstream.model_mappings)?;
-    let header_overrides = normalize_header_overrides(upstream.overrides.as_ref())?;
-    let runtime = UpstreamRuntime {
-        id: upstream.id.trim().to_string(),
-        base_url,
-        api_key,
-        filter_prompt_cache_retention: upstream.filter_prompt_cache_retention,
-        filter_safety_identifier: upstream.filter_safety_identifier,
-        kiro_account_id,
-        codex_account_id,
-        antigravity_account_id,
-        kiro_preferred_endpoint: upstream.preferred_endpoint.clone(),
-        proxy_url,
-        priority: upstream.priority.unwrap_or(0),
-        model_mappings,
-        header_overrides,
-    };
-    Ok(Some(NormalizedUpstream {
-        provider: provider.to_string(),
-        runtime,
-    }))
-}
-
-fn normalize_header_overrides(
-    overrides: Option<&UpstreamOverrides>,
-) -> Result<Option<Vec<HeaderOverride>>, String> {
-    let Some(overrides) = overrides else {
-        return Ok(None);
-    };
-    if overrides.header.is_empty() {
-        return Ok(None);
-    }
-
-    let mut normalized = Vec::with_capacity(overrides.header.len());
-    for (raw_name, raw_value) in &overrides.header {
-        let trimmed = raw_name.trim();
-        let name = HeaderName::from_bytes(trimmed.as_bytes())
-            .map_err(|_| format!("Invalid header name in overrides: {raw_name}"))?;
-
-        let value: Option<HeaderValue> = match raw_value {
-            Some(value) => {
-                if value.is_empty() {
-                    // An empty string is allowed; it means "set the header to an empty value".
-                    Some(HeaderValue::from_str("").map_err(|_| {
-                        format!("Invalid header value for {raw_name}")
-                    })?)
-                } else {
-                    Some(HeaderValue::from_str(value).map_err(|_| {
-                        format!("Invalid header value for {raw_name}")
-                    })?)
-                }
-            }
-            None => None,
-        };
-
-        normalized.push(HeaderOverride { name, value });
-    }
-
-    // Preserve the user's casing for mixed-case input; the override policy is applied later.
-    Ok(Some(normalized))
-}
-
-fn normalize_upstream_proxy_url(
-    proxy_url: Option<&str>,
-    app_proxy_url: Option<&str>,
-    upstream_id: &str,
-) -> Result<Option<String>, String> {
-    let value = proxy_url.unwrap_or_default().trim();
-    if value.is_empty() {
-        return Ok(None);
-    }
-    if value == APP_PROXY_URL_PLACEHOLDER {
-        let app_proxy_url = app_proxy_url.unwrap_or_default().trim();
-        if app_proxy_url.is_empty() {
-            return Err(format!(
-                "Upstream {upstream_id} proxy_url is set to {APP_PROXY_URL_PLACEHOLDER}, but app_proxy_url is empty."
-            ));
-        }
-        return Ok(Some(validate_proxy_url(app_proxy_url, upstream_id)?.to_string()));
-    }
-    Ok(Some(validate_proxy_url(value, upstream_id)?.to_string()))
-}
-
-fn validate_proxy_url<'a>(value: &'a str, upstream_id: &str) -> Result<&'a str, String> {
-    let parsed = url::Url::parse(value).map_err(|_| {
-        format!("Upstream {upstream_id} proxy_url is not a valid URL.")
-    })?;
-    match parsed.scheme() {
-        "http" | "https" | "socks5" | "socks5h" => Ok(value),
-        scheme => Err(format!(
-            "Upstream {upstream_id} proxy_url scheme is not supported: {scheme}."
-        )),
-    }
-}
diff --git a/src-tauri/src/proxy/config/types.rs b/src-tauri/src/proxy/config/types.rs
deleted file mode 100644
index e8ac084..0000000
--- a/src-tauri/src/proxy/config/types.rs
+++ /dev/null
@@ -1,297 +0,0 @@
-use axum::http::header::{HeaderName, HeaderValue};
-use serde::{de, Deserialize, Deserializer, Serialize};
-use std::collections::HashMap;
-
-use super::model_mapping::ModelMappingRules;
-use crate::logging::LogLevel;
-
-fn default_enabled() -> bool {
-    true
-}
-
-fn is_false(value: &bool) -> bool {
-    !*value
-}
-
-fn default_proxy_port() -> u16 {
-    // Dev and installed builds must be able to run in parallel; debug defaults to a different port to avoid conflicting with the release/installed build.
-    if cfg!(debug_assertions) {
-        19208
-    } else {
-        9208
-    }
-}
-
-fn default_tray_token_rate_enabled() -> bool {
-    true
-}
-
-fn default_log_level() -> LogLevel {
-    LogLevel::Silent
-}
-
-#[derive(Clone, Serialize, Deserialize)]
-#[serde(rename_all = "snake_case")]
-pub(crate) enum UpstreamStrategy {
-    PriorityRoundRobin,
-    PriorityFillFirst,
-}
-
-impl Default for UpstreamStrategy {
-    fn default() -> Self {
-        Self::PriorityFillFirst
-    }
-}
-
-#[derive(Clone, Serialize, Deserialize)]
-#[serde(rename_all = "snake_case")]
-pub(crate) enum TrayTokenRateFormat {
-    Combined,
-    Split,
-    Both,
-}
-
-impl Default for TrayTokenRateFormat {
-    fn default() -> Self {
-        Self::Split
-    }
-}
-
-#[derive(Clone, Serialize, Deserialize)]
-#[serde(rename_all = "snake_case")]
-pub(crate) enum KiroPreferredEndpoint {
-    Ide,
-    Cli,
-}
-
-#[derive(Clone, Serialize, Deserialize)]
-pub(crate) struct TrayTokenRateConfig {
-    #[serde(default = "default_tray_token_rate_enabled")]
-    pub(crate) enabled: bool,
-    #[serde(default)]
-    pub(crate) format: TrayTokenRateFormat,
-}
-
-impl Default for TrayTokenRateConfig {
-    fn default() -> Self {
-        Self {
-            enabled: default_tray_token_rate_enabled(),
-            format: TrayTokenRateFormat::default(),
-        }
-    }
-}
-
-#[derive(Clone, Serialize, Deserialize)]
-pub(crate) struct UpstreamConfig {
-    pub(crate) id: String,
-    pub(crate) provider: String,
-    pub(crate) base_url: String,
-    pub(crate) api_key: Option<String>,
-    /// Only meaningful for provider "openai-response": strip `prompt_cache_retention` from /v1/responses requests.
-    #[serde(default, skip_serializing_if = "is_false")]
-    pub(crate) filter_prompt_cache_retention: bool,
-    /// Only meaningful for provider "openai-response": strip `safety_identifier` from /v1/responses requests.
-    #[serde(default, skip_serializing_if = "is_false")]
-    pub(crate) filter_safety_identifier: bool,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub(crate) kiro_account_id: Option<String>,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub(crate) codex_account_id: Option<String>,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub(crate) antigravity_account_id: Option<String>,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub(crate) preferred_endpoint: Option<KiroPreferredEndpoint>,
-    pub(crate) proxy_url: Option<String>,
-    pub(crate) priority: Option<i32>,
-    #[serde(default = "default_enabled")]
-    pub(crate) enabled: bool,
-    #[serde(default)]
-    pub(crate) model_mappings: HashMap<String, String>,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub(crate) overrides: Option<UpstreamOverrides>,
-}
-
-#[derive(Clone, Serialize, Deserialize, Default)]
-pub(crate) struct UpstreamOverrides {
-    #[serde(default)]
-    pub(crate) header: HashMap<String, Option<String>>,
-}
-
-#[derive(Clone, Serialize, Deserialize)]
-pub(crate) struct ProxyConfigFile {
-    pub(crate) host: String,
-    pub(crate) port: u16,
-    pub(crate) local_api_key: Option<String>,
-    pub(crate) app_proxy_url: Option<String>,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub(crate) kiro_preferred_endpoint: Option<KiroPreferredEndpoint>,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub(crate) antigravity_ide_db_path: Option<String>,
-    #[serde(default, skip_serializing_if = "Vec::is_empty")]
-    pub(crate) antigravity_app_paths: Vec<String>,
-    #[serde(default, skip_serializing_if = "Vec::is_empty")]
-    pub(crate) antigravity_process_names: Vec<String>,
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub(crate) antigravity_user_agent: Option<String>,
-    #[serde(default = "default_log_level", deserialize_with = "deserialize_log_level")]
-    pub(crate) log_level: LogLevel,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub(crate) max_request_body_bytes: Option<u64>,
-    #[serde(default)]
-    pub(crate) tray_token_rate: TrayTokenRateConfig,
-    /// Whether to allow automatic conversion between API formats (e.g. OpenAI Chat↔Responses, Claude Messages↔OpenAI Responses).
-    /// Enabled by default; when disabled, requests are routed strictly by provider with no format conversion.
-    #[serde(default)]
-    pub(crate) enable_api_format_conversion: bool,
-    #[serde(default)]
-    pub(crate) upstream_strategy: UpstreamStrategy,
-    #[serde(default)]
-    pub(crate) upstreams: Vec<UpstreamConfig>,
-}
-
-impl Default for ProxyConfigFile {
-    fn default() -> Self {
-        Self {
-            host: "127.0.0.1".to_string(),
-            port: default_proxy_port(),
-            local_api_key: None,
-            app_proxy_url: None,
-            kiro_preferred_endpoint: None,
-            antigravity_ide_db_path: None,
-            antigravity_app_paths: Vec::new(),
-            antigravity_process_names: Vec::new(),
-            antigravity_user_agent: None,
-            log_level: LogLevel::default(),
-            max_request_body_bytes: None,
-            tray_token_rate: TrayTokenRateConfig::default(),
-            enable_api_format_conversion: true,
-            upstream_strategy: UpstreamStrategy::PriorityFillFirst,
-            upstreams: Vec::new(),
-        }
-    }
-}
-
-#[derive(Clone)]
-pub(crate) struct ProxyConfig {
-    pub(crate) host: String,
-    pub(crate) port: u16,
-    pub(crate) local_api_key: Option<String>,
-    pub(crate) log_level: LogLevel,
-    pub(crate) max_request_body_bytes: usize,
-    pub(crate) enable_api_format_conversion: bool,
-    pub(crate) upstream_strategy: UpstreamStrategy,
-    pub(crate) upstreams: HashMap<String, ProviderUpstreams>,
-    pub(crate) kiro_preferred_endpoint: Option<KiroPreferredEndpoint>,
-    pub(crate) antigravity_user_agent: Option<String>,
-}
-
-fn deserialize_log_level<'de, D>(deserializer: D) -> Result<LogLevel, D::Error>
-where
-    D: Deserializer<'de>,
-{
-    let raw = Option::<String>::deserialize(deserializer)?;
-    let value = raw.unwrap_or_default().trim().to_ascii_lowercase();
-    if value.is_empty() {
-        return Ok(LogLevel::Silent);
-    }
-    match value.as_str() {
-        "silent" => Ok(LogLevel::Silent),
-        "error" => Ok(LogLevel::Error),
-        "warn" | "warning" => Ok(LogLevel::Warn),
-        "info" => Ok(LogLevel::Info),
-        "debug" => Ok(LogLevel::Debug),
-        "trace" => Ok(LogLevel::Trace),
-        other => Err(de::Error::custom(format!("invalid log_level: {other}"))),
-    }
-}
-
-#[derive(Clone)]
-pub(crate) struct ProviderUpstreams {
-    pub(crate) groups: Vec<UpstreamGroup>,
-}
-
-#[derive(Clone)]
-pub(crate) struct UpstreamGroup {
-    pub(crate) priority: i32,
-    pub(crate) items: Vec<UpstreamRuntime>,
-}
-
-#[derive(Clone)]
-pub(crate) struct UpstreamRuntime {
-    pub(crate) id: String,
-    pub(crate) base_url: String,
-    pub(crate) api_key: Option<String>,
-    pub(crate) filter_prompt_cache_retention: bool,
-    pub(crate) filter_safety_identifier: bool,
-    pub(crate) kiro_account_id: Option<String>,
-    pub(crate) codex_account_id: Option<String>,
-    pub(crate) antigravity_account_id: Option<String>,
-    pub(crate) kiro_preferred_endpoint: Option<KiroPreferredEndpoint>,
-    pub(crate) proxy_url: Option<String>,
-    pub(crate) priority: i32,
-    pub(crate) model_mappings: Option<ModelMappingRules>,
-    pub(crate) header_overrides: Option<Vec<HeaderOverride>>,
-}
-
-#[derive(Clone)]
-pub(crate) struct HeaderOverride {
-    pub(crate) name: HeaderName,
-    pub(crate) value: Option<HeaderValue>,
-}
-
-impl UpstreamRuntime {
-    /// Build the upstream request URL, handling path overlap between base_url and the request path.
-    /// e.g. base_url = "https://example.com/openai/v1", path = "/v1/chat/completions"
-    /// result: https://example.com/openai/v1/chat/completions (the duplicated /v1 is dropped)
-    pub(crate) fn upstream_url(&self, path: &str) -> String {
-        let base = self.base_url.trim_end_matches('/');
-        let effective_path = strip_overlapping_prefix(base, path);
-        format!("{base}{effective_path}")
-    }
-
-    pub(crate) fn map_model(&self, model: &str) -> Option<String> {
-        self.model_mappings
-            .as_ref()
-            .and_then(|rules| rules.map_model(model))
-            .map(|value| value.to_string())
-    }
-}
-
-#[derive(Serialize)]
-pub(crate) struct ConfigResponse {
-    pub(crate) path: String,
-    pub(crate) config: ProxyConfigFile,
-}
-
-/// Strip the prefix of `path` that overlaps with the path portion of base_url.
-/// base_url: "https://example.com/openai/v1" -> base_path: "/openai/v1"
-/// If `path` starts with some suffix of base_path (such as "/v1"), drop that overlapping part.
-pub(crate) fn strip_overlapping_prefix<'a>(base_url: &str, path: &'a str) -> &'a str {
-    let Some(base_path) = url::Url::parse(base_url)
-        .ok()
-        .map(|url| url.path().to_string())
-    else {
-        return path;
-    };
-    // Check whether each suffix of base_path overlaps with the start of path,
-    // e.g. for base_path = "/openai/v1", check "/openai/v1", then "/v1".
-    let base_path = base_path.trim_end_matches('/');
-    for (idx, ch) in base_path.char_indices() {
-        if ch == '/' {
-            let suffix = &base_path[idx..];
-            if path.starts_with(suffix) {
-                return &path[suffix.len()..];
-            }
-        }
-    }
-    // Full-match check (base_path itself).
-    if path.starts_with(base_path) {
-        return &path[base_path.len()..];
-    }
-    path
-}
-
-// Unit tests live in a separate file; `#[path]` keeps the `.test.rs` naming convention.
-#[cfg(test)]
-#[path = "types.test.rs"]
-mod tests;
diff --git a/src-tauri/src/proxy/config/types.test.rs b/src-tauri/src/proxy/config/types.test.rs
deleted file mode 100644
index b99a97f..0000000
--- a/src-tauri/src/proxy/config/types.test.rs
+++ /dev/null
@@ -1,156 +0,0 @@
-use super::*;
-
-#[test]
-fn test_strip_overlapping_prefix() {
-    // Standard OpenAI-compatible layout: base_url includes /v1
-    assert_eq!(
-        strip_overlapping_prefix("https://api.example.com/openai/v1", "/v1/chat/completions"),
-        "/chat/completions"
-    );
-    assert_eq!(
-        strip_overlapping_prefix("https://api.example.com/v1", "/v1/chat/completions"),
-        "/chat/completions"
-    );
-
-    // No overlap: base_url carries no path
-    assert_eq!(
strip_overlapping_prefix("https://api.openai.com", "/v1/chat/completions"), - "/v1/chat/completions" - ); - - // 无重叠情况:base_url 路径与请求路径无公共后缀 - assert_eq!( - strip_overlapping_prefix("https://api.example.com/openai/", "/v1/chat/completions"), - "/v1/chat/completions" - ); - assert_eq!( - strip_overlapping_prefix("https://api.example.com/openai", "/v1/chat/completions"), - "/v1/chat/completions" - ); - - // 多层路径重叠 - assert_eq!( - strip_overlapping_prefix("https://example.com/api/openai/v1", "/v1/models"), - "/models" - ); - - // 完整路径重叠 - assert_eq!( - strip_overlapping_prefix("https://example.com/openai/v1", "/openai/v1/completions"), - "/completions" - ); - - // 带尾斜杠的 base_url - assert_eq!( - strip_overlapping_prefix("https://example.com/v1/", "/v1/chat/completions"), - "/chat/completions" - ); - - // 无效 URL 回退 - assert_eq!( - strip_overlapping_prefix("not-a-valid-url", "/v1/chat/completions"), - "/v1/chat/completions" - ); -} - -#[test] -fn test_upstream_url() { - // openai provider: /v1/chat/completions - let upstream = UpstreamRuntime { - id: "test".to_string(), - base_url: "https://api.example.com/openai/v1".to_string(), - api_key: None, - filter_prompt_cache_retention: false, - filter_safety_identifier: false, - kiro_account_id: None, - codex_account_id: None, - antigravity_account_id: None, - kiro_preferred_endpoint: None, - proxy_url: None, - priority: 0, - model_mappings: None, - header_overrides: None, - }; - assert_eq!( - upstream.upstream_url("/v1/chat/completions"), - "https://api.example.com/openai/v1/chat/completions" - ); - - // openai-response provider: /v1/responses - let upstream_responses = UpstreamRuntime { - id: "test".to_string(), - base_url: "https://api.example.com/openai/v1".to_string(), - api_key: None, - filter_prompt_cache_retention: false, - filter_safety_identifier: false, - kiro_account_id: None, - codex_account_id: None, - antigravity_account_id: None, - kiro_preferred_endpoint: None, - proxy_url: None, - priority: 0, - model_mappings: None, - header_overrides: None, - }; - assert_eq!( - upstream_responses.upstream_url("/v1/responses"), - "https://api.example.com/openai/v1/responses" - ); - - // 无路径前缀的 base_url - let upstream_no_path = UpstreamRuntime { - id: "test".to_string(), - base_url: "https://api.openai.com".to_string(), - api_key: None, - filter_prompt_cache_retention: false, - filter_safety_identifier: false, - kiro_account_id: None, - codex_account_id: None, - antigravity_account_id: None, - kiro_preferred_endpoint: None, - proxy_url: None, - priority: 0, - model_mappings: None, - header_overrides: None, - }; - assert_eq!( - upstream_no_path.upstream_url("/v1/chat/completions"), - "https://api.openai.com/v1/chat/completions" - ); - assert_eq!( - upstream_no_path.upstream_url("/v1/responses"), - "https://api.openai.com/v1/responses" - ); - - // 带尾斜杠的 base_url - let upstream_trailing_slash = UpstreamRuntime { - id: "test".to_string(), - base_url: "https://api.example.com/openai/v1/".to_string(), - api_key: None, - filter_prompt_cache_retention: false, - filter_safety_identifier: false, - kiro_account_id: None, - codex_account_id: None, - antigravity_account_id: None, - kiro_preferred_endpoint: None, - proxy_url: None, - priority: 0, - model_mappings: None, - header_overrides: None, - }; - // openai: /v1/chat/completions - assert_eq!( - upstream_trailing_slash.upstream_url("/v1/chat/completions"), - "https://api.example.com/openai/v1/chat/completions" - ); - // openai-response: /v1/responses - assert_eq!( - 
-        upstream_trailing_slash.upstream_url("/v1/responses"),
-        "https://api.example.com/openai/v1/responses"
-    );
-    // anthropic: /v1/messages
-    assert_eq!(
-        upstream_trailing_slash.upstream_url("/v1/messages"),
-        "https://api.example.com/openai/v1/messages"
-    );
-}
diff --git a/src-tauri/src/proxy/dashboard.rs b/src-tauri/src/proxy/dashboard.rs
deleted file mode 100644
index 842acfb..0000000
--- a/src-tauri/src/proxy/dashboard.rs
+++ /dev/null
@@ -1,481 +0,0 @@
-use serde::{Deserialize, Serialize};
-use sqlx::Row;
-use tauri::AppHandle;
-use std::collections::HashMap;
-
-use super::sqlite;
-
-const RECENT_PAGE_SIZE: u32 = 50;
-
-#[derive(Debug, Clone, Deserialize)]
-#[serde(rename_all = "camelCase")]
-pub(crate) struct DashboardRange {
-    pub(crate) from_ts_ms: Option<u64>,
-    pub(crate) to_ts_ms: Option<u64>,
-}
-
-#[derive(Debug, Clone, Serialize)]
-#[serde(rename_all = "camelCase")]
-pub(crate) struct DashboardSummary {
-    pub(crate) total_requests: u64,
-    pub(crate) success_requests: u64,
-    pub(crate) error_requests: u64,
-    pub(crate) total_tokens: u64,
-    pub(crate) input_tokens: u64,
-    pub(crate) output_tokens: u64,
-    pub(crate) cached_tokens: u64,
-    pub(crate) avg_latency_ms: u64,
-}
-
-#[derive(Debug, Clone, Serialize)]
-#[serde(rename_all = "camelCase")]
-pub(crate) struct DashboardProviderStat {
-    pub(crate) provider: String,
-    pub(crate) requests: u64,
-    pub(crate) total_tokens: u64,
-    pub(crate) cached_tokens: u64,
-}
-
-#[derive(Debug, Clone, Serialize)]
-#[serde(rename_all = "camelCase")]
-pub(crate) struct DashboardSeriesPoint {
-    pub(crate) ts_ms: u64,
-    pub(crate) total_requests: u64,
-    pub(crate) error_requests: u64,
-    pub(crate) input_tokens: u64,
-    pub(crate) output_tokens: u64,
-    pub(crate) cached_tokens: u64,
-    pub(crate) total_tokens: u64,
-}
-
-#[derive(Debug, Clone, Serialize)]
-#[serde(rename_all = "camelCase")]
-pub(crate) struct DashboardRequestItem {
-    pub(crate) id: u64,
-    pub(crate) ts_ms: u64,
-    pub(crate) path: String,
-    pub(crate) provider: String,
-    pub(crate) upstream_id: String,
-    pub(crate) model: Option<String>,
-    pub(crate) mapped_model: Option<String>,
-    pub(crate) stream: bool,
-    pub(crate) status: u16,
-    pub(crate) total_tokens: Option<u64>,
-    pub(crate) cached_tokens: Option<u64>,
-    pub(crate) latency_ms: u64,
-    pub(crate) upstream_request_id: Option<String>,
-}
-
-#[derive(Debug, Clone, Serialize)]
-#[serde(rename_all = "camelCase")]
-pub(crate) struct DashboardSnapshot {
-    pub(crate) summary: DashboardSummary,
-    pub(crate) providers: Vec<DashboardProviderStat>,
-    pub(crate) series: Vec<DashboardSeriesPoint>,
-    pub(crate) recent: Vec<DashboardRequestItem>,
-    /// Whether stats were computed only from the tail fragment of the log file (Step 1: true; after the Step 2 SQLite migration this should be false).
-    pub(crate) truncated: bool,
-}
-
-pub(crate) async fn read_snapshot(
-    app: AppHandle,
-    range: DashboardRange,
-    offset: Option<u32>,
-) -> Result<DashboardSnapshot, String> {
-    let offset = offset.unwrap_or(0);
-
-    let pool = sqlite::open_read_pool(&app).await?;
-    let from_ts_ms = range.from_ts_ms.map(|value| value as i64);
-    let to_ts_ms = range.to_ts_ms.map(|value| value as i64);
-    let bucket_ms = resolve_bucket_ms(&pool, from_ts_ms, to_ts_ms).await?;
-
-    let summary = query_summary(&pool, from_ts_ms, to_ts_ms).await?;
-    let providers = query_providers(&pool, from_ts_ms, to_ts_ms).await?;
-    let series = query_series(&pool, from_ts_ms, to_ts_ms, bucket_ms).await?;
-    let series = fill_series_buckets(series, from_ts_ms, to_ts_ms, bucket_ms);
-    let recent = query_recent(&pool, from_ts_ms, to_ts_ms, offset).await?;
-
-    Ok(DashboardSnapshot {
-        summary,
-        providers,
-        series,
-        recent,
-        truncated: false,
-    })
-}
-
-async fn query_summary(
-    pool: &sqlx::SqlitePool,
-    from_ts_ms: Option<i64>,
-    to_ts_ms: Option<i64>,
-) -> Result<DashboardSummary, String> {
-    let row = sqlx::query(
-        r#"
-SELECT
-  COUNT(*) AS total_requests,
-  COALESCE(SUM(CASE WHEN status BETWEEN 200 AND 299 THEN 1 ELSE 0 END), 0) AS success_requests,
-  COALESCE(SUM(CASE WHEN status >= 400 THEN 1 ELSE 0 END), 0) AS error_requests,
-  COALESCE(SUM(CASE
-    WHEN total_tokens IS NOT NULL THEN total_tokens
-    WHEN input_tokens IS NOT NULL OR output_tokens IS NOT NULL THEN COALESCE(input_tokens, 0) + COALESCE(output_tokens, 0)
-    ELSE 0
-  END), 0) AS total_tokens,
-  COALESCE(SUM(COALESCE(input_tokens, 0)), 0) AS input_tokens,
-  COALESCE(SUM(COALESCE(output_tokens, 0)), 0) AS output_tokens,
-  COALESCE(SUM(COALESCE(cached_tokens, 0)), 0) AS cached_tokens,
-  COALESCE(SUM(latency_ms), 0) AS latency_sum_ms
-FROM request_logs
-WHERE (?1 IS NULL OR ts_ms >= ?1) AND (?2 IS NULL OR ts_ms <= ?2);
-"#,
-    )
-    .bind(from_ts_ms)
-    .bind(to_ts_ms)
-    .fetch_one(pool)
-    .await
-    .map_err(|err| format!("Failed to query dashboard summary: {err}"))?;
-
-    let total_requests = i64_to_u64(row.try_get("total_requests").unwrap_or(0));
-    let success_requests = i64_to_u64(row.try_get("success_requests").unwrap_or(0));
-    let error_requests = i64_to_u64(row.try_get("error_requests").unwrap_or(0));
-    let total_tokens = i64_to_u64(row.try_get("total_tokens").unwrap_or(0));
-    let input_tokens = i64_to_u64(row.try_get("input_tokens").unwrap_or(0));
-    let output_tokens = i64_to_u64(row.try_get("output_tokens").unwrap_or(0));
-    let cached_tokens = i64_to_u64(row.try_get("cached_tokens").unwrap_or(0));
-    let latency_sum_ms = i64_to_u64(row.try_get("latency_sum_ms").unwrap_or(0));
-
-    let avg_latency_ms = if total_requests == 0 {
-        0
-    } else {
-        latency_sum_ms / total_requests
-    };
-
-    Ok(DashboardSummary {
-        total_requests,
-        success_requests,
-        error_requests,
-        total_tokens,
-        input_tokens,
-        output_tokens,
-        cached_tokens,
-        avg_latency_ms,
-    })
-}
-
-async fn query_providers(
-    pool: &sqlx::SqlitePool,
-    from_ts_ms: Option<i64>,
-    to_ts_ms: Option<i64>,
-) -> Result<Vec<DashboardProviderStat>, String> {
-    let providers = sqlx::query(
-        r#"
-SELECT
-  provider,
-  COUNT(*) AS requests,
-  COALESCE(SUM(CASE
-    WHEN total_tokens IS NOT NULL THEN total_tokens
-    WHEN input_tokens IS NOT NULL OR output_tokens IS NOT NULL THEN COALESCE(input_tokens, 0) + COALESCE(output_tokens, 0)
-    ELSE 0
-  END), 0) AS total_tokens,
-  COALESCE(SUM(COALESCE(cached_tokens, 0)), 0) AS cached_tokens
-FROM request_logs
-WHERE (?1 IS NULL OR ts_ms >= ?1) AND (?2 IS NULL OR ts_ms <= ?2)
-GROUP BY provider
-ORDER BY total_tokens DESC;
-"#,
-    )
-    .bind(from_ts_ms)
-    .bind(to_ts_ms)
-    .fetch_all(pool)
-    .await
-    .map_err(|err| format!("Failed to query provider stats: {err}"))?
-    .into_iter()
-    .filter_map(|row| {
-        let provider: String = row.try_get("provider").ok()?;
-        let requests: i64 = row.try_get("requests").ok()?;
-        let total_tokens: i64 = row.try_get("total_tokens").ok()?;
-        let cached_tokens: i64 = row.try_get("cached_tokens").ok()?;
-        Some(DashboardProviderStat {
-            provider,
-            requests: i64_to_u64(requests),
-            total_tokens: i64_to_u64(total_tokens),
-            cached_tokens: i64_to_u64(cached_tokens),
-        })
-    })
-    .collect::<Vec<_>>();
-
-    Ok(providers)
-}
-
-async fn query_series(
-    pool: &sqlx::SqlitePool,
-    from_ts_ms: Option<i64>,
-    to_ts_ms: Option<i64>,
-    bucket_ms: u64,
-) -> Result<Vec<DashboardSeriesPoint>, String> {
-    let series = sqlx::query(
-        r#"
-SELECT
-  (ts_ms / ?3) * ?3 AS bucket_ts_ms,
-  COUNT(*) AS total_requests,
-  COALESCE(SUM(CASE WHEN status >= 400 THEN 1 ELSE 0 END), 0) AS error_requests,
-  COALESCE(SUM(COALESCE(input_tokens, 0)), 0) AS input_tokens,
-  COALESCE(SUM(COALESCE(output_tokens, 0)), 0) AS output_tokens,
-  COALESCE(SUM(COALESCE(cached_tokens, 0)), 0) AS cached_tokens,
-  COALESCE(SUM(CASE
-    WHEN total_tokens IS NOT NULL THEN total_tokens
-    WHEN input_tokens IS NOT NULL OR output_tokens IS NOT NULL THEN COALESCE(input_tokens, 0) + COALESCE(output_tokens, 0)
-    ELSE 0
-  END), 0) AS total_tokens
-FROM request_logs
-WHERE (?1 IS NULL OR ts_ms >= ?1) AND (?2 IS NULL OR ts_ms <= ?2)
-GROUP BY bucket_ts_ms
-ORDER BY bucket_ts_ms ASC;
-"#,
-    )
-    .bind(from_ts_ms)
-    .bind(to_ts_ms)
-    .bind(i64::try_from(bucket_ms).unwrap_or(i64::MAX))
-    .fetch_all(pool)
-    .await
-    .map_err(|err| format!("Failed to query dashboard series: {err}"))?
-    .into_iter()
-    .filter_map(|row| {
-        let ts_ms: i64 = row.try_get("bucket_ts_ms").ok()?;
-        let total_requests: i64 = row.try_get("total_requests").ok()?;
-        let error_requests: i64 = row.try_get("error_requests").ok()?;
-        let input_tokens: i64 = row.try_get("input_tokens").ok()?;
-        let output_tokens: i64 = row.try_get("output_tokens").ok()?;
-        let cached_tokens: i64 = row.try_get("cached_tokens").ok()?;
-        let total_tokens: i64 = row.try_get("total_tokens").ok()?;
-        Some(DashboardSeriesPoint {
-            ts_ms: i64_to_u64(ts_ms),
-            total_requests: i64_to_u64(total_requests),
-            error_requests: i64_to_u64(error_requests),
-            input_tokens: i64_to_u64(input_tokens),
-            output_tokens: i64_to_u64(output_tokens),
-            cached_tokens: i64_to_u64(cached_tokens),
-            total_tokens: i64_to_u64(total_tokens),
-        })
-    })
-    .collect::<Vec<_>>();
-
-    Ok(series)
-}
-
-fn fill_series_buckets(
-    series: Vec<DashboardSeriesPoint>,
-    from_ts_ms: Option<i64>,
-    to_ts_ms: Option<i64>,
-    bucket_ms: u64,
-) -> Vec<DashboardSeriesPoint> {
-    if bucket_ms == 0 {
-        return series;
-    }
-
-    let resolved_from_ts_ms = from_ts_ms.or_else(|| {
-        series
-            .first()
-            .and_then(|point| i64::try_from(point.ts_ms).ok())
-    });
-    let resolved_to_ts_ms = to_ts_ms.or_else(|| {
-        series
-            .last()
-            .and_then(|point| i64::try_from(point.ts_ms).ok())
-    });
-
-    // With range=all and no data at all, let the frontend provide the fallback (a zero line for the last 7 days).
-    let (resolved_from_ts_ms, resolved_to_ts_ms) = match (resolved_from_ts_ms, resolved_to_ts_ms) {
-        (Some(from), Some(to)) => (from, to),
-        _ => return series,
-    };
-
-    let start_bucket_ts_ms = align_down_bucket_ts_ms(resolved_from_ts_ms, bucket_ms);
-    let end_bucket_ts_ms = align_down_bucket_ts_ms(resolved_to_ts_ms, bucket_ms);
-
-    let (start_bucket_ts_ms, end_bucket_ts_ms) = if end_bucket_ts_ms < start_bucket_ts_ms {
-        (start_bucket_ts_ms, start_bucket_ts_ms)
-    } else {
-        (start_bucket_ts_ms, end_bucket_ts_ms)
-    };
-
-    let by_bucket: HashMap<u64, DashboardSeriesPoint> = series
-        .into_iter()
-        .map(|point| (point.ts_ms, point))
-        .collect();
-
-    let expected_len = ((end_bucket_ts_ms - start_bucket_ts_ms) / bucket_ms).saturating_add(1);
-    let mut filled = Vec::with_capacity(usize::try_from(expected_len).unwrap_or(usize::MAX));
-
-    let mut cursor_ts_ms = start_bucket_ts_ms;
-    while cursor_ts_ms <= end_bucket_ts_ms {
-        if let Some(point) = by_bucket.get(&cursor_ts_ms) {
-            filled.push(point.clone());
-        } else {
-            filled.push(DashboardSeriesPoint {
-                ts_ms: cursor_ts_ms,
-                total_requests: 0,
-                error_requests: 0,
-                input_tokens: 0,
-                output_tokens: 0,
-                cached_tokens: 0,
-                total_tokens: 0,
-            });
-        }
-
-        match cursor_ts_ms.checked_add(bucket_ms) {
-            Some(next) => cursor_ts_ms = next,
-            None => break,
-        }
-    }
-
-    filled
-}
-
-fn align_down_bucket_ts_ms(ts_ms: i64, bucket_ms: u64) -> u64 {
-    let ts_ms = i64_to_u64(ts_ms);
-    if bucket_ms == 0 {
-        return ts_ms;
-    }
-    (ts_ms / bucket_ms) * bucket_ms
-}
-
-async fn query_recent(
-    pool: &sqlx::SqlitePool,
-    from_ts_ms: Option<i64>,
-    to_ts_ms: Option<i64>,
-    offset: u32,
-) -> Result<Vec<DashboardRequestItem>, String> {
-    let recent = sqlx::query(
-        r#"
-SELECT
-  id,
-  ts_ms,
-  path,
-  provider,
-  upstream_id,
-  model,
-  mapped_model,
-  stream,
-  status,
-  CASE
-    WHEN total_tokens IS NOT NULL THEN total_tokens
-    WHEN input_tokens IS NOT NULL OR output_tokens IS NOT NULL THEN COALESCE(input_tokens, 0) + COALESCE(output_tokens, 0)
-    ELSE NULL
-  END AS total_tokens,
-  cached_tokens,
-  latency_ms,
-  upstream_request_id
-FROM request_logs
-WHERE (?1 IS NULL OR ts_ms >= ?1) AND (?2 IS NULL OR ts_ms <= ?2)
-ORDER BY ts_ms DESC
-LIMIT ?3 OFFSET ?4;
-"#,
-    )
-    .bind(from_ts_ms)
-    .bind(to_ts_ms)
-    .bind(i64::from(RECENT_PAGE_SIZE))
-    .bind(i64::from(offset))
-    .fetch_all(pool)
-    .await
-    .map_err(|err| format!("Failed to query recent requests: {err}"))?
-    .into_iter()
-    .filter_map(|row| {
-        let id: i64 = row.try_get("id").ok()?;
-        let ts_ms: i64 = row.try_get("ts_ms").ok()?;
-        let path: String = row.try_get("path").ok()?;
-        let provider: String = row.try_get("provider").ok()?;
-        let upstream_id: String = row.try_get("upstream_id").ok()?;
-        let model: Option<String> = row.try_get("model").ok()?;
-        let mapped_model: Option<String> = row.try_get("mapped_model").ok()?;
-        let stream: bool = row.try_get("stream").unwrap_or(false);
-        let status: i64 = row.try_get("status").unwrap_or(0);
-        let total_tokens: Option<i64> = row.try_get("total_tokens").ok()?;
-        let cached_tokens: Option<i64> = row.try_get("cached_tokens").ok()?;
-        let latency_ms: i64 = row.try_get("latency_ms").unwrap_or(0);
-        let upstream_request_id: Option<String> = row.try_get("upstream_request_id").ok()?;
-        Some(DashboardRequestItem {
-            id: i64_to_u64(id),
-            ts_ms: i64_to_u64(ts_ms),
-            path,
-            provider,
-            upstream_id,
-            model,
-            mapped_model,
-            stream,
-            status: i64_to_u16(status),
-            total_tokens: total_tokens.map(i64_to_u64),
-            cached_tokens: cached_tokens.map(i64_to_u64),
-            latency_ms: i64_to_u64(latency_ms),
-            upstream_request_id,
-        })
-    })
-    .collect::<Vec<_>>();
-
-    Ok(recent)
-}
-
-async fn resolve_bucket_ms(
-    pool: &sqlx::SqlitePool,
-    from_ts_ms: Option<i64>,
-    to_ts_ms: Option<i64>,
-) -> Result<u64, String> {
-    if let (Some(from), Some(to)) = (from_ts_ms, to_ts_ms) {
-        let span_ms = (to - from).max(0) as u64;
-        return Ok(select_bucket_ms(span_ms));
-    }
-
-    let row = sqlx::query(
-        r#"
-SELECT
-  MIN(ts_ms) AS min_ts,
-  MAX(ts_ms) AS max_ts
-FROM request_logs
-WHERE (?1 IS NULL OR ts_ms >= ?1) AND (?2 IS NULL OR ts_ms <= ?2);
-"#,
-    )
-    .bind(from_ts_ms)
-    .bind(to_ts_ms)
-    .fetch_one(pool)
-    .await
-    .map_err(|err| format!("Failed to query dashboard range: {err}"))?;
-
-    let min_ts: Option<i64> = row.try_get("min_ts").ok();
-    let max_ts: Option<i64> = row.try_get("max_ts").ok();
-    let start = from_ts_ms.or(min_ts).unwrap_or(0);
-    let end = to_ts_ms.or(max_ts).unwrap_or(start);
-    let span_ms = (end - start).max(0) as u64;
-    Ok(select_bucket_ms(span_ms))
-}
-
-fn select_bucket_ms(span_ms: u64) -> u64 {
-    // Pick a bucket size that matches the span, so the series has neither too many nor too few points.
-    if span_ms <= 60 * 60 * 1000 {
-        return 5 * 60 * 1000;
-    }
-    if span_ms <= 6 * 60 * 60 * 1000 {
-        return 15 * 60 * 1000;
-    }
-    if span_ms <= 24 * 60 * 60 * 1000 {
-        return 30 * 60 * 1000;
-    }
-    if span_ms <= 7 * 24 * 60 * 60 * 1000 {
-        return 2 * 60 * 60 * 1000;
-    }
-    if span_ms <= 31 * 24 * 60 * 60 * 1000 {
-        return 24 * 60 * 60 * 1000;
-    }
-    7 * 24 * 60 * 60 * 1000
-}
-
-fn i64_to_u64(value: i64) -> u64 {
-    value.max(0) as u64
-}
-
-fn i64_to_u16(value: i64) -> u16 {
-    value.clamp(0, u16::MAX as i64) as u16
-}
-
-// Unit tests are split into a separate file; `#[path]` keeps the `.test.rs` naming convention.
-#[cfg(test)]
-#[path = "dashboard.test.rs"]
-mod tests;
diff --git a/src-tauri/src/proxy/dashboard.test.rs b/src-tauri/src/proxy/dashboard.test.rs
deleted file mode 100644
index f0386d0..0000000
--- a/src-tauri/src/proxy/dashboard.test.rs
+++ /dev/null
@@ -1,61 +0,0 @@
-use super::*;
-
-fn series_point(ts_ms: u64, total_requests: u64) -> DashboardSeriesPoint {
-    DashboardSeriesPoint {
-        ts_ms,
-        total_requests,
-        error_requests: 0,
-        input_tokens: total_requests,
-        output_tokens: 0,
-        cached_tokens: 0,
-        total_tokens: total_requests,
-    }
-}
-
-#[test]
-fn fill_series_buckets_inserts_missing_points() {
-    let bucket_ms = 60_000;
-    let series = vec![series_point(0, 1), series_point(120_000, 2)];
-    let filled = fill_series_buckets(series, Some(0), Some(120_000), bucket_ms);
-    assert_eq!(filled.len(), 3);
-    assert_eq!(filled[0].ts_ms, 0);
-    assert_eq!(filled[0].total_requests, 1);
-    assert_eq!(filled[1].ts_ms, 60_000);
-    assert_eq!(filled[1].total_requests, 0);
-    assert_eq!(filled[2].ts_ms, 120_000);
-    assert_eq!(filled[2].total_requests, 2);
-}
-
-#[test]
-fn fill_series_buckets_pads_start_and_end_of_range() {
-    let bucket_ms = 60_000;
-    let series = vec![series_point(120_000, 3)];
-    let filled = fill_series_buckets(series, Some(0), Some(180_000), bucket_ms);
-    assert_eq!(filled.len(), 4);
-    assert_eq!(filled[0].ts_ms, 0);
-    assert_eq!(filled[0].total_requests, 0);
-    assert_eq!(filled[1].ts_ms, 60_000);
-    assert_eq!(filled[1].total_requests, 0);
-    assert_eq!(filled[2].ts_ms, 120_000);
-    assert_eq!(filled[2].total_requests, 3);
-    assert_eq!(filled[3].ts_ms, 180_000);
-    assert_eq!(filled[3].total_requests, 0);
-}
-
-#[test]
-fn fill_series_buckets_handles_empty_series_with_explicit_range() {
-    let bucket_ms = 60_000;
-    let filled = fill_series_buckets(Vec::new(), Some(0), Some(120_000), bucket_ms);
-    assert_eq!(filled.len(), 3);
-    assert_eq!(filled[0].ts_ms, 0);
-    assert_eq!(filled[1].ts_ms, 60_000);
-    assert_eq!(filled[2].ts_ms, 120_000);
-    assert!(filled.iter().all(|point| point.total_requests == 0));
-}
-
-#[test]
-fn fill_series_buckets_returns_original_when_range_unknown_and_empty() {
-    let bucket_ms = 60_000;
-    let filled = fill_series_buckets(Vec::new(), None, None, bucket_ms);
-    assert!(filled.is_empty());
-}
diff --git a/src-tauri/src/proxy/gemini.rs b/src-tauri/src/proxy/gemini.rs
deleted file mode 100644
index 8512c1c..0000000
--- a/src-tauri/src/proxy/gemini.rs
+++ /dev/null
@@ -1,31 +0,0 @@
-pub(crate) const GEMINI_MODELS_PREFIX: &str = "/v1beta/models/";
-const GEMINI_GENERATE_SUFFIX: &str = ":generateContent";
-const GEMINI_STREAM_SUFFIX: &str = ":streamGenerateContent";
-
-pub(crate) fn is_gemini_path(path: &str) -> bool {
-    if !path.starts_with(GEMINI_MODELS_PREFIX) {
-        return false;
-    }
-    path.ends_with(GEMINI_GENERATE_SUFFIX) || path.ends_with(GEMINI_STREAM_SUFFIX)
-}
-
-pub(crate) fn is_gemini_stream_path(path: &str) -> bool {
-    path.starts_with(GEMINI_MODELS_PREFIX) && path.ends_with(GEMINI_STREAM_SUFFIX)
-}
-
-pub(crate) fn parse_gemini_model_from_path(path: &str) -> Option<String> {
-    let rest = path.strip_prefix(GEMINI_MODELS_PREFIX)?;
-    let (model, _) = rest.split_once(':')?;
-    let model = model.trim();
-    if model.is_empty() {
-        None
-    } else {
-        Some(model.to_string())
-    }
-}
-
-pub(crate) fn replace_gemini_model_in_path(path: &str, model: &str) -> Option<String> {
-    let rest = path.strip_prefix(GEMINI_MODELS_PREFIX)?;
-    let (_, suffix) = rest.split_once(':')?;
-    Some(format!("{GEMINI_MODELS_PREFIX}{model}:{suffix}"))
-}
diff --git a/src-tauri/src/proxy/gemini_compat/mod.rs b/src-tauri/src/proxy/gemini_compat/mod.rs
deleted file mode 100644
index 9f04179..0000000
--- a/src-tauri/src/proxy/gemini_compat/mod.rs
+++ /dev/null
@@ -1,11 +0,0 @@
-mod request;
-mod response;
-mod stream;
-mod tools;
-
-pub(crate) use request::chat_request_to_gemini;
-pub(crate) use request::gemini_request_to_chat;
-pub(crate) use response::chat_response_to_gemini;
-pub(crate) use response::gemini_response_to_chat;
-pub(crate) use stream::stream_chat_to_gemini;
-pub(crate) use stream::stream_gemini_to_chat;
diff --git a/src-tauri/src/proxy/gemini_compat/request.rs b/src-tauri/src/proxy/gemini_compat/request.rs
deleted file mode 100644
index 45de412..0000000
--- a/src-tauri/src/proxy/gemini_compat/request.rs
+++ /dev/null
@@ -1,613 +0,0 @@
-//! OpenAI Chat request → Gemini request conversion
-
-use axum::body::Bytes;
-use serde_json::{json, Map, Value};
-
-use super::tools::{
-    map_chat_tool_choice_to_gemini, map_chat_tools_to_gemini, map_gemini_tool_config_to_chat,
-    map_gemini_tools_to_chat, gemini_function_call_to_chat_tool_call,
-};
-
-/// Convert an OpenAI Chat request to Gemini format.
-pub(crate) fn chat_request_to_gemini(body: &Bytes) -> Result<Bytes, String> {
-    let value: Value =
-        serde_json::from_slice(body).map_err(|_| "Request body must be JSON.".to_string())?;
-    let Some(object) = value.as_object() else {
-        return Err("Request body must be a JSON object.".to_string());
-    };
-
-    let Some(messages) = object.get("messages").and_then(Value::as_array) else {
-        return Err("Chat request must include messages.".to_string());
-    };
-
-    let (contents, system_instruction) = chat_messages_to_gemini_contents(messages)?;
-
-    let mut out = Map::new();
-    out.insert("contents".to_string(), Value::Array(contents));
-
-    // System instruction
-    if let Some(system) = system_instruction {
-        out.insert(
-            "systemInstruction".to_string(),
-            json!({ "parts": [{ "text": system }] }),
-        );
-    }
-
-    // Generation parameters
-    let mut gen_config = Map::new();
-    if let Some(temperature) = object.get("temperature").and_then(Value::as_f64) {
-        gen_config.insert("temperature".to_string(), json!(temperature));
-    }
-    if let Some(top_p) = object.get("top_p").and_then(Value::as_f64) {
-        gen_config.insert("topP".to_string(), json!(top_p));
-    }
-    if let Some(max_tokens) = object
-        .get("max_completion_tokens")
-        .or_else(|| object.get("max_tokens"))
-        .and_then(Value::as_i64)
-    {
-        gen_config.insert("maxOutputTokens".to_string(), json!(max_tokens));
-    }
-    if let Some(stop) = object.get("stop") {
-        let stop_sequences = match stop {
-            Value::String(s) => vec![s.clone()],
-            Value::Array(arr) => arr
-                .iter()
-                .filter_map(Value::as_str)
-                .map(|s| s.to_string())
-                .collect(),
-            _ => vec![],
-        };
-        if !stop_sequences.is_empty() {
-            gen_config.insert(
-                "stopSequences".to_string(),
-                Value::Array(stop_sequences.into_iter().map(Value::String).collect()),
-            );
-        }
-    }
-    if let Some(seed) = object.get("seed").and_then(Value::as_i64) {
-        gen_config.insert("seed".to_string(), json!(seed));
-    }
-    // Response format
-    if let Some(response_format) = object.get("response_format").and_then(Value::as_object) {
-        if let Some(format_type) = response_format.get("type").and_then(Value::as_str) {
-            if format_type == "json_object" || format_type == "json_schema" {
-                gen_config.insert(
-                    "responseMimeType".to_string(),
-                    json!("application/json"),
-                );
-                // If a json_schema is provided, copy its schema
-                if format_type == "json_schema" {
-                    if let Some(schema) = response_format
-                        .get("json_schema")
-                        .and_then(Value::as_object)
-                        .and_then(|js| js.get("schema"))
-                    {
-                        gen_config.insert("responseSchema".to_string(), schema.clone());
-                    }
-                }
-            }
-        }
-    }
-    if !gen_config.is_empty() {
-        out.insert("generationConfig".to_string(), Value::Object(gen_config));
-    }
-
-    // Tools
-    if let Some(tools) = object.get("tools") {
-        let gemini_tools = map_chat_tools_to_gemini(tools);
-        if let Some(arr) = gemini_tools.as_array() {
-            if !arr.is_empty() {
-                out.insert("tools".to_string(), gemini_tools);
-            }
-        }
-    }
-    if let Some(tool_choice) = object.get("tool_choice") {
-        if let Some(tool_config) = map_chat_tool_choice_to_gemini(tool_choice) {
-            out.insert("toolConfig".to_string(), tool_config);
-        }
-    }
-
-    // Default safety settings (following new-api: disable content filtering so replies are not cut off)
-    out.insert(
-        "safetySettings".to_string(),
-        json!([
-            { "category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE" },
-            { "category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE" },
-            { "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE" },
-            { "category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE" }
-        ]),
-    );
-
-    serde_json::to_vec(&Value::Object(out))
-        .map(Bytes::from)
-        .map_err(|err| format!("Failed to serialize Gemini request: {err}"))
-}
-
-/// Convert a Gemini request to OpenAI Chat format.
-pub(crate) fn gemini_request_to_chat(
-    body: &Bytes,
-    model_hint: Option<&str>,
-) -> Result<Bytes, String> {
-    let value: Value =
-        serde_json::from_slice(body).map_err(|_| "Request body must be JSON.".to_string())?;
-    let Some(object) = value.as_object() else {
-        return Err("Request body must be a JSON object.".to_string());
-    };
-
-    let Some(contents) = object.get("contents").and_then(Value::as_array) else {
-        return Err("Gemini request must include contents.".to_string());
-    };
-
-    let mut messages = gemini_contents_to_chat_messages(contents)?;
-    if let Some(system) = extract_system_instruction(object.get("systemInstruction")) {
-        messages.insert(0, json!({ "role": "system", "content": system }));
-    }
-
-    let mut out = Map::new();
-    if let Some(model) = object.get("model").and_then(Value::as_str).or(model_hint) {
-        out.insert("model".to_string(), Value::String(model.to_string()));
-    }
-    out.insert("messages".to_string(), Value::Array(messages));
-
-    if let Some(gen_config) = object.get("generationConfig").and_then(Value::as_object) {
-        map_generation_config_to_chat(gen_config, &mut out);
-    }
-
-    if let Some(tools) = object.get("tools") {
-        let tools = map_gemini_tools_to_chat(tools);
-        if tools.as_array().is_some_and(|arr| !arr.is_empty()) {
-            out.insert("tools".to_string(), tools);
-        }
-    }
-    if let Some(tool_config) = object.get("toolConfig") {
-        if let Some(tool_choice) = map_gemini_tool_config_to_chat(tool_config) {
-            out.insert("tool_choice".to_string(), tool_choice);
-        }
-    }
-
-    serde_json::to_vec(&Value::Object(out))
-        .map(Bytes::from)
-        .map_err(|err| format!("Failed to serialize Chat request: {err}"))
format!("Failed to serialize Chat request: {err}")) -} - -/// 将 Chat messages 转换为 Gemini contents,并提取系统指令 -fn chat_messages_to_gemini_contents( - messages: &[Value], -) -> Result<(Vec, Option), String> { - let mut system_texts = Vec::new(); - let mut contents = Vec::new(); - - for message in messages { - let Some(message) = message.as_object() else { - continue; - }; - - let role = message.get("role").and_then(Value::as_str).unwrap_or("user"); - match role { - "system" | "developer" => { - if let Some(text) = extract_text_from_content(message.get("content")) { - system_texts.push(text); - } - } - "user" => { - let parts = chat_content_to_gemini_parts(message.get("content"))?; - if !parts.is_empty() { - contents.push(json!({ "role": "user", "parts": parts })); - } - } - "assistant" => { - let mut parts = chat_content_to_gemini_parts(message.get("content"))?; - // 处理 tool_calls - if let Some(tool_calls) = message.get("tool_calls").and_then(Value::as_array) { - for tool_call in tool_calls { - if let Some(fc) = chat_tool_call_to_gemini_function_call(tool_call) { - parts.push(fc); - } - } - } - // 处理旧版 function_call - if let Some(function_call) = - message.get("function_call").and_then(Value::as_object) - { - if let Some(fc) = legacy_function_call_to_gemini(function_call) { - parts.push(fc); - } - } - if !parts.is_empty() { - contents.push(json!({ "role": "model", "parts": parts })); - } - } - "tool" | "function" => { - // 工具结果 → functionResponse - let name = message - .get("name") - .or_else(|| message.get("tool_call_id")) - .and_then(Value::as_str) - .unwrap_or("function"); - let response_content = message.get("content"); - let response = parse_tool_response_content(response_content); - contents.push(json!({ - "role": "user", - "parts": [{ - "functionResponse": { - "name": name, - "response": response - } - }] - })); - } - _ => {} - } - } - - let system_instruction = if system_texts.is_empty() { - None - } else { - Some( - system_texts - .into_iter() - .filter(|t| !t.trim().is_empty()) - .collect::>() - .join("\n"), - ) - }; - - Ok((contents, system_instruction)) -} - -fn gemini_contents_to_chat_messages(contents: &[Value]) -> Result, String> { - let mut messages = Vec::new(); - for content in contents { - let Some(content) = content.as_object() else { - continue; - }; - let mut converted = gemini_content_to_chat_messages(content)?; - messages.append(&mut converted); - } - Ok(messages) -} - -fn gemini_content_to_chat_messages(content: &serde_json::Map) -> Result, String> { - let role = content.get("role").and_then(Value::as_str).unwrap_or("user"); - let role = if role == "model" { "assistant" } else { role }; - let parts = content - .get("parts") - .and_then(Value::as_array) - .map(|value| value.as_slice()) - .unwrap_or(&[]); - - let mut messages = Vec::new(); - let mut content_parts: Vec = Vec::new(); - let mut tool_calls: Vec = Vec::new(); - - for part in parts { - let Some(part) = part.as_object() else { - continue; - }; - if let Some(tool_message) = function_response_to_chat_message(part) { - if !content_parts.is_empty() || !tool_calls.is_empty() { - messages.push(build_chat_message(role, &content_parts, &tool_calls)); - content_parts.clear(); - tool_calls.clear(); - } - messages.push(tool_message); - continue; - } - if let Some(function_call) = part.get("functionCall").and_then(Value::as_object) { - let tool_call = - gemini_function_call_to_chat_tool_call(function_call, tool_calls.len()); - tool_calls.push(tool_call); - continue; - } - if let Some(content_part) = 
-            content_parts.push(content_part);
-        }
-    }
-
-    if !content_parts.is_empty() || !tool_calls.is_empty() {
-        messages.push(build_chat_message(role, &content_parts, &tool_calls));
-    }
-
-    Ok(messages)
-}
-
-fn build_chat_message(role: &str, content_parts: &[Value], tool_calls: &[Value]) -> Value {
-    let content = build_chat_content(content_parts);
-    let mut message = json!({ "role": role, "content": content });
-    if !tool_calls.is_empty() {
-        if let Some(message) = message.as_object_mut() {
-            message.insert("tool_calls".to_string(), Value::Array(tool_calls.to_vec()));
-        }
-    }
-    message
-}
-
-fn build_chat_content(parts: &[Value]) -> Value {
-    if parts.is_empty() {
-        return Value::String(String::new());
-    }
-    let mut combined = String::new();
-    let mut text_only = true;
-    for part in parts {
-        let Some(part) = part.as_object() else {
-            continue;
-        };
-        if part.get("type").and_then(Value::as_str) != Some("text") {
-            text_only = false;
-        }
-        if let Some(text) = part.get("text").and_then(Value::as_str) {
-            combined.push_str(text);
-        }
-    }
-    if text_only {
-        Value::String(combined)
-    } else {
-        Value::Array(parts.to_vec())
-    }
-}
-
-fn gemini_part_to_chat_content_part(part: &serde_json::Map<String, Value>) -> Option<Value> {
-    if let Some(text) = part.get("text").and_then(Value::as_str) {
-        return Some(json!({ "type": "text", "text": text }));
-    }
-    if let Some(inline) = part.get("inlineData").and_then(Value::as_object) {
-        return gemini_inline_data_to_image_url(inline);
-    }
-    if let Some(file_data) = part.get("fileData").and_then(Value::as_object) {
-        return gemini_file_data_to_image_url(file_data);
-    }
-    None
-}
-
-fn gemini_inline_data_to_image_url(data: &serde_json::Map<String, Value>) -> Option<Value> {
-    let mime = data
-        .get("mimeType")
-        .and_then(Value::as_str)
-        .unwrap_or("application/octet-stream");
-    let payload = data.get("data").and_then(Value::as_str)?;
-    let url = format!("data:{mime};base64,{payload}");
-    Some(json!({ "type": "image_url", "image_url": { "url": url } }))
-}
-
-fn gemini_file_data_to_image_url(data: &serde_json::Map<String, Value>) -> Option<Value> {
-    let uri = data.get("fileUri").and_then(Value::as_str)?;
-    Some(json!({ "type": "image_url", "image_url": { "url": uri } }))
-}
-
-fn function_response_to_chat_message(part: &serde_json::Map<String, Value>) -> Option<Value> {
-    let response = part.get("functionResponse")?.as_object()?;
-    let name = response.get("name").and_then(Value::as_str).unwrap_or("");
-    let payload = response.get("response").cloned().unwrap_or_else(|| json!({}));
-    let content = match payload {
-        Value::String(text) => text,
-        other => serde_json::to_string(&other).unwrap_or_else(|_| "{}".to_string()),
-    };
-    let mut message = json!({ "role": "tool", "content": content });
-    if !name.is_empty() {
-        if let Some(message) = message.as_object_mut() {
-            message.insert("name".to_string(), Value::String(name.to_string()));
-        }
-    }
-    Some(message)
-}
-
-fn extract_system_instruction(value: Option<&Value>) -> Option<String> {
-    let Some(value) = value else {
-        return None;
-    };
-    let parts = value.get("parts").and_then(Value::as_array)?;
-    let mut texts = Vec::new();
-    for part in parts {
-        let Some(text) = part.get("text").and_then(Value::as_str) else {
-            continue;
-        };
-        if !text.trim().is_empty() {
-            texts.push(text.to_string());
-        }
-    }
-    if texts.is_empty() {
-        None
-    } else {
-        Some(texts.join("\n"))
-    }
-}
-
-fn map_generation_config_to_chat(
-    gen_config: &serde_json::Map<String, Value>,
-    out: &mut Map<String, Value>,
-) {
-    if let Some(temperature) = gen_config.get("temperature").and_then(Value::as_f64) {
out.insert("temperature".to_string(), json!(temperature)); - } - if let Some(top_p) = gen_config.get("topP").and_then(Value::as_f64) { - out.insert("top_p".to_string(), json!(top_p)); - } - if let Some(max_tokens) = gen_config.get("maxOutputTokens").and_then(Value::as_i64) { - out.insert( - "max_completion_tokens".to_string(), - Value::Number(max_tokens.into()), - ); - } - if let Some(stop) = map_stop_sequences(gen_config.get("stopSequences")) { - out.insert("stop".to_string(), stop); - } - if let Some(seed) = gen_config.get("seed").and_then(Value::as_i64) { - out.insert("seed".to_string(), json!(seed)); - } - if let Some(response_format) = map_gemini_response_format(gen_config) { - out.insert("response_format".to_string(), response_format); - } -} - -fn map_stop_sequences(value: Option<&Value>) -> Option { - let Some(sequences) = value.and_then(Value::as_array) else { - return None; - }; - let items = sequences - .iter() - .filter_map(Value::as_str) - .map(|item| Value::String(item.to_string())) - .collect::>(); - if items.is_empty() { - None - } else if items.len() == 1 { - items.first().cloned() - } else { - Some(Value::Array(items)) - } -} - -fn map_gemini_response_format(gen_config: &serde_json::Map) -> Option { - if let Some(schema) = gen_config.get("responseSchema") { - return Some(json!({ - "type": "json_schema", - "json_schema": { "schema": schema.clone() } - })); - } - let mime = gen_config - .get("responseMimeType") - .and_then(Value::as_str) - .unwrap_or(""); - if mime.contains("json") { - return Some(json!({ "type": "json_object" })); - } - None -} - -/// 从 content 中提取纯文本 -fn extract_text_from_content(content: Option<&Value>) -> Option { - let Some(content) = content else { - return None; - }; - match content { - Value::String(text) => Some(text.to_string()), - Value::Array(parts) => { - let mut combined = String::new(); - for part in parts { - let Some(part) = part.as_object() else { - continue; - }; - if let Some(text) = part.get("text").and_then(Value::as_str) { - combined.push_str(text); - } - } - if combined.is_empty() { - None - } else { - Some(combined) - } - } - _ => None, - } -} - -/// 将 Chat content 转换为 Gemini parts -fn chat_content_to_gemini_parts(content: Option<&Value>) -> Result, String> { - let Some(content) = content else { - return Ok(Vec::new()); - }; - match content { - Value::String(text) => Ok(vec![json!({ "text": text })]), - Value::Null => Ok(vec![]), - Value::Array(parts) => { - let mut out = Vec::new(); - for part in parts { - let Some(part) = part.as_object() else { - continue; - }; - let part_type = part.get("type").and_then(Value::as_str).unwrap_or(""); - match part_type { - "text" => { - if let Some(text) = part.get("text").and_then(Value::as_str) { - out.push(json!({ "text": text })); - } - } - "image_url" => { - if let Some(inline_data) = image_url_to_gemini_inline_data(part) { - out.push(json!({ "inlineData": inline_data })); - } - } - _ => {} - } - } - Ok(out) - } - _ => Ok(Vec::new()), - } -} - -/// 将 image_url 转换为 Gemini inlineData -fn image_url_to_gemini_inline_data(part: &serde_json::Map) -> Option { - let image_url = part.get("image_url")?; - let url = match image_url { - Value::String(s) => s.as_str(), - Value::Object(obj) => obj.get("url").and_then(Value::as_str)?, - _ => return None, - }; - - // 处理 data URI(base64) - if let Some(rest) = url.strip_prefix("data:") { - if let Some((mime_type, data)) = rest.split_once(";base64,") { - return Some(json!({ - "mimeType": mime_type, - "data": data - })); - } - } - - // HTTP(S) URL → Gemini 支持 
-    // but for simplicity only base64 inline data is supported here; external URLs would have to be downloaded first.
-    // For external URLs, return None (this could be extended to use fileData).
-    None
-}
-
-/// Convert an OpenAI tool_call to a Gemini functionCall.
-fn chat_tool_call_to_gemini_function_call(tool_call: &Value) -> Option<Value> {
-    let tool_call = tool_call.as_object()?;
-    let function = tool_call.get("function")?.as_object()?;
-    let name = function.get("name").and_then(Value::as_str)?;
-    let arguments = function.get("arguments").and_then(Value::as_str).unwrap_or("{}");
-    let args: Value = serde_json::from_str(arguments).unwrap_or_else(|_| json!({}));
-    Some(json!({
-        "functionCall": {
-            "name": name,
-            "args": args
-        }
-    }))
-}
-
-/// Convert a legacy function_call to a Gemini functionCall.
-fn legacy_function_call_to_gemini(function_call: &serde_json::Map<String, Value>) -> Option<Value> {
-    let name = function_call.get("name").and_then(Value::as_str)?;
-    let arguments = function_call
-        .get("arguments")
-        .and_then(Value::as_str)
-        .unwrap_or("{}");
-    let args: Value = serde_json::from_str(arguments).unwrap_or_else(|_| json!({}));
-    Some(json!({
-        "functionCall": {
-            "name": name,
-            "args": args
-        }
-    }))
-}
-
-/// Parse the content of a tool response.
-fn parse_tool_response_content(content: Option<&Value>) -> Value {
-    let Some(content) = content else {
-        return json!({});
-    };
-    match content {
-        Value::String(s) => {
-            // Try to parse it as JSON first
-            serde_json::from_str(s).unwrap_or_else(|_| json!({ "result": s }))
-        }
-        other => other.clone(),
-    }
-}
-
-// Unit tests are split into a separate file; `#[path]` keeps the `.test.rs` naming convention.
-#[cfg(test)]
-#[path = "request.test.rs"]
-mod tests;
diff --git a/src-tauri/src/proxy/gemini_compat/request.test.rs b/src-tauri/src/proxy/gemini_compat/request.test.rs
deleted file mode 100644
index 4d49824..0000000
--- a/src-tauri/src/proxy/gemini_compat/request.test.rs
+++ /dev/null
@@ -1,58 +0,0 @@
-use super::*;
-use serde_json::json;
-
-#[test]
-fn gemini_request_to_chat_maps_system_tools_and_format() {
-    let input = json!({
-        "systemInstruction": { "parts": [{ "text": "sys" }] },
-        "contents": [
-            { "role": "user", "parts": [{ "text": "hi" }] }
-        ],
-        "generationConfig": {
-            "temperature": 0.2,
-            "topP": 0.8,
-            "maxOutputTokens": 12,
-            "responseMimeType": "application/json"
-        },
-        "tools": [{
-            "functionDeclarations": [
-                { "name": "getFoo", "description": "x", "parameters": { "type": "object" } }
-            ]
-        }],
-        "toolConfig": { "functionCallingConfig": { "mode": "ANY", "allowedFunctionNames": ["getFoo"] } }
-    });
-
-    let output = gemini_request_to_chat(
-        &Bytes::from(serde_json::to_vec(&input).unwrap()),
-        Some("gemini-1.5-flash"),
-    )
-    .expect("convert");
-    let value: Value = serde_json::from_slice(&output).expect("json");
-    assert_eq!(value["model"], json!("gemini-1.5-flash"));
-    assert_eq!(value["messages"][0]["role"], json!("system"));
-    assert_eq!(value["messages"][1]["role"], json!("user"));
-    assert_eq!(value["messages"][1]["content"], json!("hi"));
-    assert_eq!(value["tools"][0]["function"]["name"], json!("getFoo"));
-    assert_eq!(value["tool_choice"]["function"]["name"], json!("getFoo"));
-    assert_eq!(value["response_format"]["type"], json!("json_object"));
-    assert_eq!(value["max_completion_tokens"], json!(12));
-}
-
-#[test]
-fn gemini_request_to_chat_maps_function_response() {
-    let input = json!({
-        "contents": [
-            {
-                "role": "user",
-                "parts": [
-                    { "functionResponse": { "name": "getFoo", "response": { "ok": true } } }
-                ]
-            }
-        ]
-    });
-    let output = gemini_request_to_chat(&Bytes::from(serde_json::to_vec(&input).unwrap()), None)
-        .expect("convert");
-    let value: Value = serde_json::from_slice(&output).expect("json");
assert_eq!(value["messages"][0]["role"], json!("tool")); - assert_eq!(value["messages"][0]["name"], json!("getFoo")); -} diff --git a/src-tauri/src/proxy/gemini_compat/response.rs b/src-tauri/src/proxy/gemini_compat/response.rs deleted file mode 100644 index 67f6beb..0000000 --- a/src-tauri/src/proxy/gemini_compat/response.rs +++ /dev/null @@ -1,410 +0,0 @@ -//! Gemini 响应 → OpenAI Chat 响应转换 - -use axum::body::Bytes; -use serde_json::{json, Map, Value}; - -use super::tools::gemini_function_call_to_chat_tool_call; - -/// 将 OpenAI Chat 响应转换为 Gemini 格式 -pub(crate) fn chat_response_to_gemini( - bytes: &Bytes, - _model_hint: Option<&str>, -) -> Result { - let value: Value = - serde_json::from_slice(bytes).map_err(|_| "Upstream response must be JSON.".to_string())?; - let Some(object) = value.as_object() else { - return Err("Upstream response must be a JSON object.".to_string()); - }; - - let choices = object - .get("choices") - .and_then(Value::as_array) - .map(|arr| arr.as_slice()) - .unwrap_or(&[]); - - let mut candidates = Vec::new(); - for (index, choice) in choices.iter().enumerate() { - if let Some(candidate) = chat_choice_to_gemini_candidate(choice, index) { - candidates.push(candidate); - } - } - if candidates.is_empty() { - candidates.push(json!({ - "index": 0, - "content": { "role": "model", "parts": [] }, - "finishReason": "STOP" - })); - } - - let usage = object - .get("usage") - .and_then(Value::as_object) - .and_then(map_chat_usage_to_gemini_usage); - - let mut output = json!({ - "candidates": candidates - }); - if let Some(usage) = usage { - if let Some(obj) = output.as_object_mut() { - obj.insert("usageMetadata".to_string(), usage); - } - } - - serde_json::to_vec(&output) - .map(Bytes::from) - .map_err(|err| format!("Failed to serialize Gemini response: {err}")) -} - -/// 将 Gemini 响应转换为 OpenAI Chat 格式 -pub(crate) fn gemini_response_to_chat(bytes: &Bytes, model_hint: Option<&str>) -> Result { - let value: Value = - serde_json::from_slice(bytes).map_err(|_| "Upstream response must be JSON.".to_string())?; - let Some(object) = value.as_object() else { - return Err("Upstream response must be a JSON object.".to_string()); - }; - - // 检查是否有 error 字段(Gemini 错误响应) - if let Some(error) = object.get("error") { - return handle_gemini_error(error, model_hint); - } - - let candidates = object - .get("candidates") - .and_then(Value::as_array) - .map(|arr| arr.as_slice()) - .unwrap_or(&[]); - - let model = model_hint.unwrap_or("gemini"); - let now_ms = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_millis(); - let id = format!("chatcmpl_gemini_{now_ms}"); - let created = (now_ms / 1000) as i64; - - let mut choices = Vec::new(); - for (index, candidate) in candidates.iter().enumerate() { - if let Some(choice) = gemini_candidate_to_chat_choice(candidate, index) { - choices.push(choice); - } - } - - // 如果没有候选结果,创建一个空的选择 - if choices.is_empty() { - choices.push(json!({ - "index": 0, - "message": { - "role": "assistant", - "content": "" - }, - "finish_reason": "stop" - })); - } - - let usage = object - .get("usageMetadata") - .and_then(Value::as_object) - .map(gemini_usage_to_chat_usage); - - let out = json!({ - "id": id, - "object": "chat.completion", - "created": created, - "model": model, - "choices": choices, - "usage": usage - }); - - serde_json::to_vec(&out) - .map(Bytes::from) - .map_err(|err| format!("Failed to serialize Chat response: {err}")) -} - -/// 处理 Gemini 错误响应 -fn handle_gemini_error(error: &Value, model_hint: Option<&str>) -> 
-    let message = error
-        .get("message")
-        .and_then(Value::as_str)
-        .unwrap_or("Unknown error from Gemini");
-    let code = error
-        .get("code")
-        .and_then(Value::as_i64)
-        .unwrap_or(500);
-
-    let model = model_hint.unwrap_or("gemini");
-    let now_ms = std::time::SystemTime::now()
-        .duration_since(std::time::UNIX_EPOCH)
-        .unwrap_or_default()
-        .as_millis();
-
-    let out = json!({
-        "id": format!("chatcmpl_gemini_{now_ms}"),
-        "object": "chat.completion",
-        "created": (now_ms / 1000) as i64,
-        "model": model,
-        "choices": [{
-            "index": 0,
-            "message": {
-                "role": "assistant",
-                "content": format!("Error from Gemini (code {}): {}", code, message)
-            },
-            "finish_reason": "stop"
-        }],
-        "usage": null
-    });
-
-    serde_json::to_vec(&out)
-        .map(Bytes::from)
-        .map_err(|err| format!("Failed to serialize error response: {err}"))
-}
-
-/// Convert a Gemini candidate to a Chat choice.
-fn gemini_candidate_to_chat_choice(candidate: &Value, index: usize) -> Option<Value> {
-    let candidate = candidate.as_object()?;
-    let content = candidate.get("content")?.as_object()?;
-    let parts = content.get("parts").and_then(Value::as_array)?;
-
-    let mut text_content = String::new();
-    let mut tool_calls = Vec::new();
-    let mut tool_call_index = 0;
-
-    for part in parts {
-        let Some(part) = part.as_object() else {
-            continue;
-        };
-
-        // Text content
-        if let Some(text) = part.get("text").and_then(Value::as_str) {
-            text_content.push_str(text);
-        }
-
-        // Function call
-        if let Some(function_call) = part.get("functionCall").and_then(Value::as_object) {
-            let tool_call = gemini_function_call_to_chat_tool_call(function_call, tool_call_index);
-            tool_calls.push(tool_call);
-            tool_call_index += 1;
-        }
-    }
-
-    let finish_reason = gemini_finish_reason_to_chat(
-        candidate.get("finishReason").and_then(Value::as_str),
-        !tool_calls.is_empty(),
-    );
-
-    let mut message = json!({
-        "role": "assistant",
-        "content": if text_content.is_empty() { Value::Null } else { Value::String(text_content) }
-    });
-
-    if !tool_calls.is_empty() {
-        if let Some(msg) = message.as_object_mut() {
-            msg.insert("tool_calls".to_string(), Value::Array(tool_calls));
-        }
-    }
-
-    Some(json!({
-        "index": index,
-        "message": message,
-        "finish_reason": finish_reason
-    }))
-}
-
-/// Map a Gemini finishReason to a Chat finish_reason.
-fn gemini_finish_reason_to_chat(reason: Option<&str>, has_tool_calls: bool) -> &'static str {
-    if has_tool_calls {
-        return "tool_calls";
-    }
-    match reason {
-        Some("STOP") => "stop",
-        Some("MAX_TOKENS") => "length",
-        Some("SAFETY") => "content_filter",
-        Some("RECITATION") => "content_filter",
-        Some("OTHER") => "stop",
-        Some("BLOCKLIST") => "content_filter",
-        Some("PROHIBITED_CONTENT") => "content_filter",
-        Some("SPII") => "content_filter",
-        _ => "stop",
-    }
-}
-
-/// Convert Gemini usageMetadata to Chat usage.
-fn gemini_usage_to_chat_usage(usage: &Map<String, Value>) -> Value {
-    let prompt_tokens = usage
-        .get("promptTokenCount")
-        .and_then(Value::as_u64)
-        .unwrap_or(0);
-    let completion_tokens = usage
-        .get("candidatesTokenCount")
-        .and_then(Value::as_u64)
-        .unwrap_or(0);
-    let total_tokens = usage
-        .get("totalTokenCount")
-        .and_then(Value::as_u64)
-        .unwrap_or(prompt_tokens + completion_tokens);
-    let cached_tokens = usage
-        .get("cachedContentTokenCount")
-        .and_then(Value::as_u64);
-
-    let mut result = json!({
-        "prompt_tokens": prompt_tokens,
-        "completion_tokens": completion_tokens,
-        "total_tokens": total_tokens
-    });
-
-    if let Some(cached) = cached_tokens {
-        if let Some(obj) = result.as_object_mut() {
-            obj.insert("cached_tokens".to_string(), json!(cached));
-        }
-    }
-
-    result
-}
-
-fn chat_choice_to_gemini_candidate(choice: &Value, index: usize) -> Option<Value> {
-    let choice = choice.as_object()?;
-    let message = choice.get("message").and_then(Value::as_object)?;
-
-    let content_parts = message.get("content_parts").and_then(Value::as_array);
-    let content = if let Some(parts) = content_parts {
-        map_chat_content_parts_to_gemini_parts(parts)
-    } else {
-        map_chat_content_to_gemini_parts(message.get("content"))
-    };
-
-    let tool_calls = message
-        .get("tool_calls")
-        .and_then(Value::as_array)
-        .map(|calls| map_chat_tool_calls_to_gemini_parts(calls))
-        .unwrap_or_default();
-
-    let mut parts = Vec::new();
-    parts.extend(content);
-    parts.extend(tool_calls);
-
-    let finish_reason = choice
-        .get("finish_reason")
-        .and_then(Value::as_str)
-        .map(chat_finish_reason_to_gemini);
-
-    let mut candidate = json!({
-        "index": index,
-        "content": { "role": "model", "parts": parts }
-    });
-    if let Some(reason) = finish_reason {
-        if let Some(obj) = candidate.as_object_mut() {
-            obj.insert("finishReason".to_string(), Value::String(reason.to_string()));
-        }
-    }
-    Some(candidate)
-}
-
-fn map_chat_content_to_gemini_parts(content: Option<&Value>) -> Vec<Value> {
-    let Some(content) = content else {
-        return Vec::new();
-    };
-    match content {
-        Value::String(text) => vec![json!({ "text": text })],
-        Value::Array(parts) => map_chat_content_parts_to_gemini_parts(parts),
-        _ => Vec::new(),
-    }
-}
-
-fn map_chat_content_parts_to_gemini_parts(parts: &[Value]) -> Vec<Value> {
-    let mut output = Vec::new();
-    for part in parts {
-        let Some(part) = part.as_object() else {
-            continue;
-        };
-        let part_type = part.get("type").and_then(Value::as_str).unwrap_or("");
-        match part_type {
-            "text" | "input_text" | "output_text" => {
-                if let Some(text) = part.get("text").and_then(Value::as_str) {
-                    output.push(json!({ "text": text }));
-                }
-            }
-            "image_url" => {
-                if let Some(url) = extract_image_url(part.get("image_url")) {
-                    output.push(url);
-                }
-            }
-            "input_image" | "output_image" => {
-                if let Some(url) = extract_image_url(part.get("image_url")) {
-                    output.push(url);
-                }
-            }
-            _ => {}
-        }
-    }
-    output
-}
-
-fn extract_image_url(value: Option<&Value>) -> Option<Value> {
-    let url = match value {
-        Some(Value::String(url)) => Some(url.as_str()),
-        Some(Value::Object(obj)) => obj.get("url").and_then(Value::as_str),
-        _ => None,
-    }?;
-    if let Some(rest) = url.strip_prefix("data:") {
-        if let Some((mime_type, data)) = rest.split_once(";base64,") {
-            return Some(json!({ "inlineData": { "mimeType": mime_type, "data": data } }));
-        }
-    }
-    Some(json!({ "fileData": { "fileUri": url } }))
-}
-
-fn map_chat_tool_calls_to_gemini_parts(tool_calls: &[Value]) -> Vec<Value> {
-    let mut output = Vec::new();
-    for call in tool_calls {
-        let Some(call) = call.as_object() else {
-            continue;
-        };
-        let function = call.get("function").and_then(Value::as_object);
-        let name = function
-            .and_then(|function| function.get("name"))
-            .and_then(Value::as_str)
-            .unwrap_or("");
-        let arguments = function
-            .and_then(|function| function.get("arguments"))
-            .and_then(Value::as_str)
-            .unwrap_or("{}");
-        if name.is_empty() {
-            continue;
-        }
-        let args: Value = serde_json::from_str(arguments).unwrap_or_else(|_| json!({}));
-        output.push(json!({
-            "functionCall": {
-                "name": name,
-                "args": args
-            }
-        }));
-    }
-    output
-}
-
-fn map_chat_usage_to_gemini_usage(usage: &Map<String, Value>) -> Option<Value> {
-    let prompt_tokens = usage.get("prompt_tokens").and_then(Value::as_u64);
-    let completion_tokens = usage.get("completion_tokens").and_then(Value::as_u64);
usage.get("completion_tokens").and_then(Value::as_u64); - let total_tokens = usage.get("total_tokens").and_then(Value::as_u64); - if prompt_tokens.is_none() && completion_tokens.is_none() && total_tokens.is_none() { - return None; - } - Some(json!({ - "promptTokenCount": prompt_tokens.unwrap_or(0), - "candidatesTokenCount": completion_tokens.unwrap_or(0), - "totalTokenCount": total_tokens.unwrap_or_else(|| prompt_tokens.unwrap_or(0) + completion_tokens.unwrap_or(0)) - })) -} - -fn chat_finish_reason_to_gemini(reason: &str) -> &'static str { - match reason { - "stop" => "STOP", - "length" => "MAX_TOKENS", - "content_filter" => "SAFETY", - _ => "STOP", - } -} - -// 单元测试拆到独立文件,使用 `#[path]` 以保持 `.test.rs` 命名约定。 -#[cfg(test)] -#[path = "response.test.rs"] -mod tests; diff --git a/src-tauri/src/proxy/gemini_compat/response.test.rs b/src-tauri/src/proxy/gemini_compat/response.test.rs deleted file mode 100644 index f8a673a..0000000 --- a/src-tauri/src/proxy/gemini_compat/response.test.rs +++ /dev/null @@ -1,32 +0,0 @@ -use super::*; - -#[test] -fn chat_response_to_gemini_maps_tool_calls_and_text() { - let input = json!({ - "id": "chatcmpl_x", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": "hello", - "tool_calls": [{ - "id": "call_1", - "type": "function", - "function": { "name": "getFoo", "arguments": "{\"a\":1}" } - }] - }, - "finish_reason": "stop" - }], - "usage": { "prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3 } - }); - - let output = chat_response_to_gemini(&Bytes::from(serde_json::to_vec(&input).unwrap()), None) - .expect("convert"); - let value: Value = serde_json::from_slice(&output).expect("json"); - assert_eq!(value["candidates"][0]["content"]["parts"][0]["text"], json!("hello")); - assert_eq!( - value["candidates"][0]["content"]["parts"][1]["functionCall"]["name"], - json!("getFoo") - ); - assert_eq!(value["usageMetadata"]["totalTokenCount"], json!(3)); -} diff --git a/src-tauri/src/proxy/gemini_compat/stream.rs b/src-tauri/src/proxy/gemini_compat/stream.rs deleted file mode 100644 index be31faf..0000000 --- a/src-tauri/src/proxy/gemini_compat/stream.rs +++ /dev/null @@ -1,560 +0,0 @@ -//! 
-
-use axum::body::Bytes;
-use futures_util::{stream::try_unfold, StreamExt};
-use serde_json::{json, Value};
-use std::{collections::VecDeque, sync::Arc};
-
-use crate::proxy::log::{build_log_entry, LogContext, LogWriter};
-use crate::proxy::sse::SseEventParser;
-use crate::proxy::token_rate::RequestTokenTracker;
-use crate::proxy::usage::SseUsageCollector;
-
-use super::tools::gemini_function_call_to_chat_tool_call;
-
-/// Convert a Gemini streaming response to an OpenAI Chat streaming response.
-pub(crate) fn stream_gemini_to_chat<E>(
-    upstream: impl futures_util::stream::Stream<Item = Result<Bytes, E>> + Unpin + Send + 'static,
-    context: LogContext,
-    log: Arc<LogWriter>,
-    token_tracker: RequestTokenTracker,
-) -> impl futures_util::stream::Stream<Item = Result<Bytes, std::io::Error>> + Send
-where
-    E: std::error::Error + Send + Sync + 'static,
-{
-    let state = GeminiToChatState::new(upstream, context, log, token_tracker);
-    try_unfold(state, |state| async move { state.step().await })
-}
-
-/// Convert an OpenAI Chat streaming response to a Gemini streaming response.
-pub(crate) fn stream_chat_to_gemini<E>(
-    upstream: impl futures_util::stream::Stream<Item = Result<Bytes, E>> + Unpin + Send + 'static,
-    context: LogContext,
-    log: Arc<LogWriter>,
-    token_tracker: RequestTokenTracker,
-) -> impl futures_util::stream::Stream<Item = Result<Bytes, std::io::Error>> + Send
-where
-    E: std::error::Error + Send + Sync + 'static,
-{
-    let state = ChatToGeminiState::new(upstream, context, log, token_tracker);
-    try_unfold(state, |state| async move { state.step().await })
-}
-
-struct GeminiToChatState<S> {
-    upstream: S,
-    parser: SseEventParser,
-    collector: SseUsageCollector,
-    log: Arc<LogWriter>,
-    context: LogContext,
-    token_tracker: RequestTokenTracker,
-    out: VecDeque<Bytes>,
-    chat_id: String,
-    created: i64,
-    model: String,
-    sent_role: bool,
-    sent_done: bool,
-    logged: bool,
-    upstream_ended: bool,
-    tool_call_index: usize,
-}
-
-struct ToolCallState {
-    name: String,
-    arguments: String,
-}
-
-struct ChatToGeminiState<S> {
-    upstream: S,
-    parser: SseEventParser,
-    collector: SseUsageCollector,
-    log: Arc<LogWriter>,
-    context: LogContext,
-    token_tracker: RequestTokenTracker,
-    out: VecDeque<Bytes>,
-    sent_done: bool,
-    logged: bool,
-    upstream_ended: bool,
-    tool_calls: Vec<Option<ToolCallState>>,
-}
-
-impl<S, E> GeminiToChatState<S>
-where
-    S: futures_util::stream::Stream<Item = Result<Bytes, E>> + Unpin + Send + 'static,
-    E: std::error::Error + Send + Sync + 'static,
-{
-    fn new(
-        upstream: S,
-        context: LogContext,
-        log: Arc<LogWriter>,
-        token_tracker: RequestTokenTracker,
-    ) -> Self {
-        let now_ms = std::time::SystemTime::now()
-            .duration_since(std::time::UNIX_EPOCH)
-            .unwrap_or_default()
-            .as_millis();
-        Self {
-            upstream,
-            parser: SseEventParser::new(),
-            collector: SseUsageCollector::new(),
-            log,
-            model: context.model.clone().unwrap_or_else(|| "gemini".to_string()),
-            context,
-            token_tracker,
-            out: VecDeque::new(),
-            chat_id: format!("chatcmpl_gemini_{now_ms}"),
-            created: (now_ms / 1000) as i64,
-            sent_role: false,
-            sent_done: false,
-            logged: false,
-            upstream_ended: false,
-            tool_call_index: 0,
-        }
-    }
-
-    async fn step(mut self) -> Result<Option<(Bytes, Self)>, std::io::Error> {
-        loop {
-            if let Some(next) = self.out.pop_front() {
-                return Ok(Some((next, self)));
-            }
-
-            if self.upstream_ended {
-                return Ok(None);
-            }
-
-            match self.upstream.next().await {
-                Some(Ok(chunk)) => {
-                    if self.context.ttfb_ms.is_none() {
-                        self.context.ttfb_ms = Some(self.context.start.elapsed().as_millis());
-                    }
-                    self.collector.push_chunk(&chunk);
-                    let mut events = Vec::new();
-                    self.parser.push_chunk(&chunk, |data| events.push(data));
-                    let mut texts = Vec::new();
-                    for data in events {
-                        self.handle_event(&data, &mut texts);
-                    }
-                    for text in texts {
self.token_tracker.add_output_text(&text).await; - } - } - Some(Err(err)) => { - self.log_usage_once(); - return Err(std::io::Error::new(std::io::ErrorKind::Other, err)); - } - None => { - self.upstream_ended = true; - let mut events = Vec::new(); - self.parser.finish(|data| events.push(data)); - let mut texts = Vec::new(); - for data in events { - self.handle_event(&data, &mut texts); - } - for text in texts { - self.token_tracker.add_output_text(&text).await; - } - if !self.sent_done { - self.push_done("stop"); - } - self.log_usage_once(); - if self.out.is_empty() { - return Ok(None); - } - } - } - } - } - - fn handle_event(&mut self, data: &str, token_texts: &mut Vec) { - if self.sent_done { - return; - } - if data == "[DONE]" { - self.push_done("stop"); - return; - } - let Ok(value) = serde_json::from_str::(data) else { - return; - }; - - // 处理 Gemini 响应格式 - let Some(candidates) = value.get("candidates").and_then(Value::as_array) else { - return; - }; - - for candidate in candidates { - self.handle_candidate(candidate, token_texts); - } - } - - fn handle_candidate(&mut self, candidate: &Value, token_texts: &mut Vec) { - let Some(candidate) = candidate.as_object() else { - return; - }; - - // 检查 finishReason - let finish_reason = candidate.get("finishReason").and_then(Value::as_str); - - let Some(content) = candidate.get("content").and_then(Value::as_object) else { - // 如果有 finishReason 但没有 content,发送完成信号 - if finish_reason.is_some() { - self.push_done(gemini_finish_reason_to_chat(finish_reason, false)); - } - return; - }; - - let Some(parts) = content.get("parts").and_then(Value::as_array) else { - return; - }; - - let mut has_tool_calls = false; - - for part in parts { - let Some(part) = part.as_object() else { - continue; - }; - - // 文本内容 - if let Some(text) = part.get("text").and_then(Value::as_str) { - if !text.is_empty() { - token_texts.push(text.to_string()); - self.ensure_role_sent(); - self.out.push_back(chat_chunk_sse( - &self.chat_id, - self.created, - &self.model, - json!({ "content": text }), - None, - )); - } - } - - // 函数调用 - if let Some(function_call) = part.get("functionCall").and_then(Value::as_object) { - has_tool_calls = true; - self.ensure_role_sent(); - let tool_call = - gemini_function_call_to_chat_tool_call(function_call, self.tool_call_index); - self.tool_call_index += 1; - - // 发送工具调用 delta - self.out.push_back(chat_chunk_sse( - &self.chat_id, - self.created, - &self.model, - json!({ "tool_calls": [tool_call] }), - None, - )); - } - } - - // 处理完成原因 - if let Some(reason) = finish_reason { - let chat_reason = gemini_finish_reason_to_chat(Some(reason), has_tool_calls); - self.push_done(chat_reason); - } - } - - fn ensure_role_sent(&mut self) { - if self.sent_role { - return; - } - self.sent_role = true; - self.out.push_back(chat_chunk_sse( - &self.chat_id, - self.created, - &self.model, - json!({ "role": "assistant", "content": "" }), - None, - )); - } - - fn push_done(&mut self, finish_reason: &str) { - if self.sent_done { - return; - } - self.sent_done = true; - self.out.push_back(chat_chunk_sse( - &self.chat_id, - self.created, - &self.model, - json!({}), - Some(finish_reason), - )); - self.out.push_back(Bytes::from("data: [DONE]\n\n")); - } - - fn log_usage_once(&mut self) { - if self.logged { - return; - } - self.logged = true; - let entry = build_log_entry(&self.context, self.collector.finish(), None); - self.log.clone().write_detached(entry); - } -} - -impl ChatToGeminiState -where - S: futures_util::stream::Stream> + Unpin + Send + 'static, - E: 
-{
-    fn new(
-        upstream: S,
-        context: LogContext,
-        log: Arc<LogWriter>,
-        token_tracker: RequestTokenTracker,
-    ) -> Self {
-        Self {
-            upstream,
-            parser: SseEventParser::new(),
-            collector: SseUsageCollector::new(),
-            log,
-            context,
-            token_tracker,
-            out: VecDeque::new(),
-            sent_done: false,
-            logged: false,
-            upstream_ended: false,
-            tool_calls: Vec::new(),
-        }
-    }
-
-    async fn step(mut self) -> Result<Option<(Bytes, Self)>, std::io::Error> {
-        loop {
-            if let Some(next) = self.out.pop_front() {
-                return Ok(Some((next, self)));
-            }
-
-            if self.upstream_ended {
-                return Ok(None);
-            }
-
-            match self.upstream.next().await {
-                Some(Ok(chunk)) => {
-                    if self.context.ttfb_ms.is_none() {
-                        self.context.ttfb_ms = Some(self.context.start.elapsed().as_millis());
-                    }
-                    self.collector.push_chunk(&chunk);
-                    let mut events = Vec::new();
-                    self.parser.push_chunk(&chunk, |data| events.push(data));
-                    let mut texts = Vec::new();
-                    for data in events {
-                        self.handle_event(&data, &mut texts);
-                    }
-                    for text in texts {
-                        self.token_tracker.add_output_text(&text).await;
-                    }
-                }
-                Some(Err(err)) => {
-                    self.log_usage_once();
-                    return Err(std::io::Error::new(std::io::ErrorKind::Other, err));
-                }
-                None => {
-                    self.upstream_ended = true;
-                    let mut events = Vec::new();
-                    self.parser.finish(|data| events.push(data));
-                    let mut texts = Vec::new();
-                    for data in events {
-                        self.handle_event(&data, &mut texts);
-                    }
-                    for text in texts {
-                        self.token_tracker.add_output_text(&text).await;
-                    }
-                    if !self.sent_done {
-                        self.push_finish_reason("STOP", None);
-                    }
-                    self.log_usage_once();
-                    if self.out.is_empty() {
-                        return Ok(None);
-                    }
-                }
-            }
-        }
-    }
-
-    fn handle_event(&mut self, data: &str, token_texts: &mut Vec<String>) {
-        if self.sent_done {
-            return;
-        }
-        if data == "[DONE]" {
-            self.sent_done = true;
-            return;
-        }
-        let Ok(value) = serde_json::from_str::<Value>(data) else {
-            return;
-        };
-
-        let usage = value.get("usage").and_then(chat_usage_to_gemini_usage);
-        let Some(choice) = value
-            .get("choices")
-            .and_then(Value::as_array)
-            .and_then(|choices| choices.first())
-        else {
-            return;
-        };
-        let delta = choice.get("delta").and_then(Value::as_object).cloned();
-        let finish_reason = choice.get("finish_reason").and_then(Value::as_str);
-
-        if let Some(delta) = delta.as_ref() {
-            if let Some(content) = delta.get("content").and_then(Value::as_str) {
-                token_texts.push(content.to_string());
-                self.push_text_delta(content, usage.clone());
-            }
-            if let Some(tool_calls) = delta.get("tool_calls").and_then(Value::as_array) {
-                for tool_call in tool_calls {
-                    if let Some(part) = self.update_tool_call(tool_call) {
-                        self.push_candidate_part(part, usage.clone());
-                    }
-                }
-            }
-        }
-
-        if let Some(reason) = finish_reason {
-            let mapped = chat_finish_reason_to_gemini(reason);
-            self.push_finish_reason(mapped, usage);
-        }
-    }
-
-    fn push_text_delta(&mut self, text: &str, usage: Option<Value>) {
-        let candidate = json!({
-            "index": 0,
-            "content": { "role": "model", "parts": [{ "text": text }] }
-        });
-        self.out.push_back(gemini_chunk_sse(candidate, usage));
-    }
-
-    fn push_candidate_part(&mut self, part: Value, usage: Option<Value>) {
-        let candidate = json!({
-            "index": 0,
-            "content": { "role": "model", "parts": [part] }
-        });
-        self.out.push_back(gemini_chunk_sse(candidate, usage));
-    }
-
-    fn push_finish_reason(&mut self, reason: &str, usage: Option<Value>) {
-        if self.sent_done {
-            return;
-        }
-        self.sent_done = true;
-        let candidate = json!({
-            "index": 0,
-            "content": { "role": "model", "parts": [] },
-            "finishReason": reason
-        });
-        self.out.push_back(gemini_chunk_sse(candidate, usage));
-    }
-
-    fn update_tool_call(&mut self, tool_call: &Value) -> Option<Value> {
-        let Some(tool_call) = tool_call.as_object() else {
-            return None;
-        };
-        let index = tool_call.get("index").and_then(Value::as_u64).unwrap_or(0) as usize;
-        let function = tool_call.get("function").and_then(Value::as_object)?;
-        let name = function.get("name").and_then(Value::as_str).unwrap_or("");
-        let delta = function
-            .get("arguments")
-            .and_then(Value::as_str)
-            .unwrap_or("");
-
-        if self.tool_calls.len() <= index {
-            self.tool_calls.resize_with(index + 1, || None);
-        }
-        let state = self.tool_calls[index].get_or_insert(ToolCallState {
-            name: String::new(),
-            arguments: String::new(),
-        });
-        if !name.is_empty() {
-            state.name = name.to_string();
-        }
-        if !delta.is_empty() {
-            state.arguments.push_str(delta);
-        }
-        if state.name.is_empty() {
-            return None;
-        }
-        let args: Value = serde_json::from_str(&state.arguments).unwrap_or_else(|_| json!({}));
-        Some(json!({
-            "functionCall": { "name": state.name, "args": args }
-        }))
-    }
-
-    fn log_usage_once(&mut self) {
-        if self.logged {
-            return;
-        }
-        self.logged = true;
-        let entry = build_log_entry(&self.context, self.collector.finish(), None);
-        self.log.clone().write_detached(entry);
-    }
-}
-
-fn chat_chunk_sse(
-    id: &str,
-    created: i64,
-    model: &str,
-    delta: Value,
-    finish_reason: Option<&str>,
-) -> Bytes {
-    let chunk = json!({
-        "id": id,
-        "object": "chat.completion.chunk",
-        "created": created,
-        "model": model,
-        "choices": [{
-            "index": 0,
-            "delta": delta,
-            "finish_reason": finish_reason
-        }]
-    });
-    Bytes::from(format!("data: {}\n\n", chunk))
-}
-
-fn gemini_finish_reason_to_chat(reason: Option<&str>, has_tool_calls: bool) -> &'static str {
-    if has_tool_calls {
-        return "tool_calls";
-    }
-    match reason {
-        Some("STOP") => "stop",
-        Some("MAX_TOKENS") => "length",
-        Some("SAFETY") => "content_filter",
-        Some("RECITATION") => "content_filter",
-        Some("OTHER") => "stop",
-        Some("BLOCKLIST") => "content_filter",
-        Some("PROHIBITED_CONTENT") => "content_filter",
-        Some("SPII") => "content_filter",
-        _ => "stop",
-    }
-}
-
-fn gemini_chunk_sse(candidate: Value, usage: Option<Value>) -> Bytes {
-    let mut payload = json!({ "candidates": [candidate] });
-    if let Some(usage) = usage {
-        if let Some(obj) = payload.as_object_mut() {
-            obj.insert("usageMetadata".to_string(), usage);
-        }
-    }
-    Bytes::from(format!("data: {}\n\n", payload))
-}
-
-fn chat_usage_to_gemini_usage(usage: &Value) -> Option<Value> {
-    let prompt_tokens = usage.get("prompt_tokens").and_then(Value::as_u64);
-    let completion_tokens = usage.get("completion_tokens").and_then(Value::as_u64);
-    let total_tokens = usage.get("total_tokens").and_then(Value::as_u64);
-    if prompt_tokens.is_none() && completion_tokens.is_none() && total_tokens.is_none() {
-        return None;
-    }
-    Some(json!({
-        "promptTokenCount": prompt_tokens.unwrap_or(0),
-        "candidatesTokenCount": completion_tokens.unwrap_or(0),
-        "totalTokenCount": total_tokens.unwrap_or_else(|| prompt_tokens.unwrap_or(0) + completion_tokens.unwrap_or(0))
-    }))
-}
-
-fn chat_finish_reason_to_gemini(reason: &str) -> &'static str {
-    match reason {
-        "stop" => "STOP",
-        "length" => "MAX_TOKENS",
-        "content_filter" => "SAFETY",
-        _ => "STOP",
-    }
-}
diff --git a/src-tauri/src/proxy/gemini_compat/tools.rs b/src-tauri/src/proxy/gemini_compat/tools.rs
deleted file mode 100644
index 57bb912..0000000
--- a/src-tauri/src/proxy/gemini_compat/tools.rs
+++ /dev/null
@@ -1,175 +0,0 @@
-//! OpenAI Chat ↔ Gemini tool definition conversion
-
-use serde_json::{json, Value};
-
-/// Convert OpenAI Chat-format tools into Gemini-format functionDeclarations
-pub(super) fn map_chat_tools_to_gemini(tools: &Value) -> Value {
-    let Some(tools) = tools.as_array() else {
-        return json!([]);
-    };
-
-    let declarations: Vec<Value> = tools
-        .iter()
-        .filter_map(|tool| {
-            let tool = tool.as_object()?;
-            if tool.get("type").and_then(Value::as_str) != Some("function") {
-                return None;
-            }
-            let function = tool.get("function")?.as_object()?;
-            let name = function.get("name").and_then(Value::as_str)?;
-            let description = function.get("description").and_then(Value::as_str).unwrap_or("");
-            let parameters = function.get("parameters").cloned().unwrap_or_else(|| json!({}));
-            Some(json!({
-                "name": name,
-                "description": description,
-                "parameters": parameters
-            }))
-        })
-        .collect();
-
-    json!([{
-        "functionDeclarations": declarations
-    }])
-}
-
-/// Convert an OpenAI Chat-format tool_choice into a Gemini-format toolConfig
-pub(super) fn map_chat_tool_choice_to_gemini(tool_choice: &Value) -> Option<Value> {
-    match tool_choice {
-        Value::String(s) => match s.as_str() {
-            "none" => Some(json!({ "functionCallingConfig": { "mode": "NONE" } })),
-            "auto" => Some(json!({ "functionCallingConfig": { "mode": "AUTO" } })),
-            "required" => Some(json!({ "functionCallingConfig": { "mode": "ANY" } })),
-            _ => None,
-        },
-        Value::Object(obj) => {
-            // { "type": "function", "function": { "name": "..." } }
-            if obj.get("type").and_then(Value::as_str) == Some("function") {
-                if let Some(function) = obj.get("function").and_then(Value::as_object) {
-                    if let Some(name) = function.get("name").and_then(Value::as_str) {
-                        return Some(json!({
-                            "functionCallingConfig": {
-                                "mode": "ANY",
-                                "allowedFunctionNames": [name]
-                            }
-                        }));
-                    }
-                }
-            }
-            None
-        }
-        _ => None,
-    }
-}
-
-/// Convert Gemini-format tools into OpenAI Chat-format tools
-pub(super) fn map_gemini_tools_to_chat(value: &Value) -> Value {
-    let Some(groups) = value.as_array() else {
-        return json!([]);
-    };
-
-    let mut tools = Vec::new();
-    for group in groups {
-        let Some(group) = group.as_object() else {
-            continue;
-        };
-        let Some(declarations) = group.get("functionDeclarations").and_then(Value::as_array) else {
-            continue;
-        };
-        for declaration in declarations {
-            let Some(declaration) = declaration.as_object() else {
-                continue;
-            };
-            let name = declaration.get("name").and_then(Value::as_str).unwrap_or("");
-            if name.is_empty() {
-                continue;
-            }
-            let description = declaration
-                .get("description")
-                .and_then(Value::as_str)
-                .unwrap_or("");
-            let parameters = declaration
-                .get("parameters")
-                .cloned()
-                .unwrap_or_else(|| json!({}));
-            tools.push(json!({
-                "type": "function",
-                "function": {
-                    "name": name,
-                    "description": description,
-                    "parameters": parameters
-                }
-            }));
-        }
-    }
-
-    Value::Array(tools)
-}
-
-/// Convert a Gemini-format toolConfig into an OpenAI Chat-format tool_choice
-pub(super) fn map_gemini_tool_config_to_chat(value: &Value) -> Option<Value> {
-    let Some(tool_config) = value.as_object() else {
-        return None;
-    };
-    let Some(config) = tool_config
-        .get("functionCallingConfig")
-        .and_then(Value::as_object)
-    else {
-        return None;
-    };
-
-    let mode = config.get("mode").and_then(Value::as_str).unwrap_or("");
-    match mode {
-        "NONE" => Some(Value::String("none".to_string())),
-        "AUTO" => Some(Value::String("auto".to_string())),
-        "ANY" => {
-            let allowed = config
-                .get("allowedFunctionNames")
-                .and_then(Value::as_array)
-                .cloned()
-                .unwrap_or_default();
-            if allowed.len() == 1 {
-                let name = allowed
-                    .first()
-                    .and_then(Value::as_str)
-                    .unwrap_or("")
-                    .to_string();
-                if !name.is_empty() {
-                    return Some(json!({
-                        "type": "function",
-                        "function": { "name": name }
-                    }));
-                }
-            }
-            Some(Value::String("required".to_string()))
-        }
-        _ => None,
-    }
-}
-
-/// Convert a Gemini-format functionCall into an OpenAI Chat-format tool_call
-pub(super) fn gemini_function_call_to_chat_tool_call(
-    function_call: &serde_json::Map<String, Value>,
-    index: usize,
-) -> Value {
-    let name = function_call
-        .get("name")
-        .and_then(Value::as_str)
-        .unwrap_or("");
-    let args = function_call
-        .get("args")
-        .cloned()
-        .unwrap_or_else(|| json!({}));
-    let arguments = match args {
-        Value::String(s) => s,
-        other => serde_json::to_string(&other).unwrap_or_else(|_| "{}".to_string()),
-    };
-
-    json!({
-        "id": format!("call_gemini_{index}"),
-        "type": "function",
-        "function": {
-            "name": name,
-            "arguments": arguments
-        }
-    })
-}
diff --git a/src-tauri/src/proxy/http.rs b/src-tauri/src/proxy/http.rs
deleted file mode 100644
index 7f5454f..0000000
--- a/src-tauri/src/proxy/http.rs
+++ /dev/null
@@ -1,364 +0,0 @@
-use axum::{
-    body::Body,
-    http::{
-        header::{
-            HeaderName, HeaderValue, AUTHORIZATION, CONNECTION, CONTENT_LENGTH, HOST,
-            PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TE, TRAILER, TRANSFER_ENCODING, UPGRADE,
-        },
-        HeaderMap, StatusCode,
-    },
-    response::Response,
-};
-use reqwest::header::HeaderMap as ReqwestHeaderMap;
-use serde_json::json;
-
-use super::{
-    config::{ProxyConfig, UpstreamRuntime},
-    gemini,
-    server_helpers::is_anthropic_path,
-};
-use url::form_urlencoded;
-
-const KEEP_ALIVE: HeaderName = HeaderName::from_static("keep-alive");
-const X_OPENAI_API_KEY: &str = "x-openai-api-key";
-const X_API_KEY: &str = "x-api-key";
-const X_ANTHROPIC_API_KEY: &str = "x-anthropic-api-key";
-const X_GOOG_API_KEY: &str = "x-goog-api-key";
-
-pub(crate) fn ensure_local_auth(
-    config: &ProxyConfig,
-    headers: &HeaderMap,
-    path: &str,
-    query: Option<&str>,
-) -> Result<(), String> {
-    let Some(expected) = config.local_api_key.as_ref() else {
-        tracing::debug!("no local_api_key configured, skipping local auth");
-        return Ok(());
-    };
-    tracing::debug!(path = %path, "local auth required, resolving local key");
-    let Some(provided) = resolve_local_auth_token(headers, path, query)? else {
-        tracing::warn!(path = %path, "missing local access key");
-        return Err("Missing local access key.".to_string());
-    };
-    if provided != expected.as_str() {
-        tracing::warn!(
-            path = %path,
-            got = %mask_key(&provided),
-            expected = %mask_key(expected),
-            "local auth mismatch"
-        );
-        return Err("Local access key is invalid.".to_string());
-    }
-    tracing::debug!(path = %path, "local auth passed");
-    Ok(())
-}
-
-/// Mask a sensitive key, showing only the first 8 characters
-fn mask_key(key: &str) -> String {
-    if key.len() <= 8 {
-        return key.to_string();
-    }
-    format!("{}...", &key[..8])
-}
-
-fn resolve_local_auth_token(
-    headers: &HeaderMap,
-    path: &str,
-    query: Option<&str>,
-) -> Result<Option<String>, String> {
-    // Local auth follows request format: Anthropic -> x-api-key (or Authorization), Gemini -> x-goog-api-key/?key, others -> Authorization.
-    if is_anthropic_path(path) {
-        if let Some(value) = parse_raw_header(headers, X_API_KEY)? {
-            return Ok(Some(value));
-        }
-        if let Some(value) = parse_raw_header(headers, X_ANTHROPIC_API_KEY)? {
-            return Ok(Some(value));
-        }
-        return parse_bearer_header(headers);
-    }
-
-    if gemini::is_gemini_path(path) {
-        if let Some(value) = parse_raw_header(headers, X_GOOG_API_KEY)? {
-            return Ok(Some(value));
-        }
-        return parse_query_key(query);
-    }
-
-    parse_bearer_header(headers)
-}
-
-fn parse_raw_header(headers: &HeaderMap, name: &str) -> Result<Option<String>, String> {
-    let Some(header) = headers.get(name) else {
-        return Ok(None);
-    };
-    let Ok(value) = header.to_str() else {
-        return Err("Local access key is invalid.".to_string());
-    };
-    let value = value.trim();
-    if value.is_empty() {
-        return Err("Local access key is invalid.".to_string());
-    }
-    Ok(Some(value.to_string()))
-}
-
-fn parse_bearer_header(headers: &HeaderMap) -> Result<Option<String>, String> {
-    let Some(header) = headers.get(AUTHORIZATION) else {
-        return Ok(None);
-    };
-    let Ok(value) = header.to_str() else {
-        return Err("Local access key is invalid.".to_string());
-    };
-    let value = value.trim();
-    let Some(token) = value.strip_prefix("Bearer ") else {
-        return Err("Local access key is invalid.".to_string());
-    };
-    let token = token.trim();
-    if token.is_empty() {
-        return Err("Local access key is invalid.".to_string());
-    }
-    Ok(Some(token.to_string()))
-}
-
-fn parse_query_key(query: Option<&str>) -> Result<Option<String>, String> {
-    let Some(query) = query else {
-        return Ok(None);
-    };
-    for (key, value) in form_urlencoded::parse(query.as_bytes()) {
-        if key != "key" {
-            continue;
-        }
-        let value = value.trim();
-        if value.is_empty() {
-            return Err("Local access key is invalid.".to_string());
-        }
-        return Ok(Some(value.to_string()));
-    }
-    Ok(None)
-}
-
-#[derive(Clone, Default)]
-pub(crate) struct RequestAuth {
-    pub(crate) openai_bearer: Option<HeaderValue>,
-    pub(crate) anthropic_api_key: Option<HeaderValue>,
-    pub(crate) gemini_api_key: Option<String>,
-    pub(crate) authorization_fallback: Option<HeaderValue>,
-}
-
-pub(crate) struct UpstreamAuthHeader {
-    pub(crate) name: HeaderName,
-    pub(crate) value: HeaderValue,
-}
-
-pub(crate) fn resolve_request_auth(
-    config: &ProxyConfig,
-    headers: &HeaderMap,
-) -> Result<RequestAuth, String> {
-    let mut auth = RequestAuth::default();
-    // When local auth is enabled, request auth headers are reserved for local access and not used upstream.
-    if config.local_api_key.is_none() {
-        if let Some(value) = headers.get(X_OPENAI_API_KEY) {
-            let Ok(value) = value.to_str() else {
-                return Err("Upstream API key is invalid.".to_string());
-            };
-            auth.openai_bearer = Some(bearer_header(value).ok_or_else(|| {
-                "Upstream API key contains invalid characters.".to_string()
-            })?);
-        }
-
-        // Anthropic uses `x-api-key`; allow explicit overrides as well.
-        if let Some(value) = headers
-            .get(X_API_KEY)
-            .or_else(|| headers.get(X_ANTHROPIC_API_KEY))
-        {
-            let Ok(_) = value.to_str() else {
-                return Err("Upstream API key is invalid.".to_string());
-            };
-            auth.anthropic_api_key = Some(value.clone());
-        }
-
-        if let Some(value) = headers.get(AUTHORIZATION) {
-            auth.authorization_fallback = Some(value.clone());
-        }
-
-        if let Some(value) = headers.get(X_GOOG_API_KEY) {
-            let Ok(value) = value.to_str() else {
-                return Err("Upstream API key is invalid.".to_string());
-            };
-            let value = value.trim();
-            if !value.is_empty() {
-                auth.gemini_api_key = Some(value.to_string());
-            }
-        }
-    }
-    Ok(auth)
-}
-
-pub(crate) fn resolve_upstream_auth(
-    provider: &str,
-    upstream: &UpstreamRuntime,
-    request_auth: &RequestAuth,
-) -> Result<Option<UpstreamAuthHeader>, Response> {
-    tracing::debug!(
-        provider = %provider,
-        upstream_id = %upstream.id,
-        has_upstream_key = upstream.api_key.is_some(),
-        has_openai_bearer = request_auth.openai_bearer.is_some(),
-        has_anthropic_key = request_auth.anthropic_api_key.is_some(),
-        has_auth_fallback = request_auth.authorization_fallback.is_some(),
-        "resolving upstream auth"
-    );
-
-    match provider {
-        "anthropic" => {
-            let value = match upstream.api_key.as_ref() {
-                Some(key) => {
-                    tracing::debug!("using upstream.api_key for Anthropic");
-                    HeaderValue::from_str(key).map_err(|_| {
-                        error_response(
-                            StatusCode::UNAUTHORIZED,
-                            "Upstream API key contains invalid characters.",
-                        )
-                    })?
-                }
-                None => {
-                    let Some(value) = request_auth.anthropic_api_key.clone() else {
-                        tracing::warn!("no API key for Anthropic");
-                        return Ok(None);
-                    };
-                    tracing::debug!("using request_auth.anthropic_api_key for Anthropic");
-                    value
-                }
-            };
-
-            Ok(Some(UpstreamAuthHeader {
-                name: HeaderName::from_static(X_API_KEY),
-                value,
-            }))
-        }
-        _ => {
-            if let Some(key) = upstream.api_key.as_ref() {
-                tracing::debug!(provider = %provider, "using upstream.api_key");
-                let value = bearer_header(key).ok_or_else(|| {
-                    error_response(
-                        StatusCode::UNAUTHORIZED,
-                        "Upstream API key contains invalid characters.",
-                    )
-                })?;
-                return Ok(Some(UpstreamAuthHeader {
-                    name: AUTHORIZATION,
-                    value,
-                }));
-            }
-
-            if let Some(value) = request_auth.openai_bearer.clone() {
-                tracing::debug!(provider = %provider, "using request_auth.openai_bearer");
-                return Ok(Some(UpstreamAuthHeader {
-                    name: AUTHORIZATION,
-                    value,
-                }));
-            }
-
-            if let Some(value) = request_auth.authorization_fallback.clone() {
-                tracing::debug!(provider = %provider, "using request_auth.authorization_fallback");
-                return Ok(Some(UpstreamAuthHeader {
-                    name: AUTHORIZATION,
-                    value,
-                }));
-            }
-
-            tracing::warn!(provider = %provider, "no API key found");
-            Ok(None)
-        }
-    }
-}
-
-pub(crate) fn bearer_header(value: &str) -> Option<HeaderValue> {
-    let header = format!("Bearer {value}");
-    HeaderValue::from_str(&header).ok()
-}
-
-pub(crate) fn build_upstream_headers(
-    headers: &HeaderMap,
-    auth: UpstreamAuthHeader,
-) -> ReqwestHeaderMap {
-    let mut output = ReqwestHeaderMap::new();
-    for (name, value) in headers.iter() {
-        if should_skip_request_header(name) {
-            continue;
-        }
-        if name == AUTHORIZATION
-            || name == &auth.name
-            || name.as_str().eq_ignore_ascii_case(X_OPENAI_API_KEY)
-            || name.as_str().eq_ignore_ascii_case(X_API_KEY)
-            || name.as_str().eq_ignore_ascii_case(X_ANTHROPIC_API_KEY)
-            || name.as_str().eq_ignore_ascii_case(X_GOOG_API_KEY)
-        {
-            continue;
-        }
-        output.append(name.clone(), value.clone());
-    }
-    output.insert(auth.name, auth.value);
-    output
-}
-
-fn should_skip_request_header(name: &HeaderName) -> bool {
-    is_hop_header(name) || name == HOST || name == CONTENT_LENGTH
-}
-
-pub(crate) fn is_hop_header(name: &HeaderName) -> bool {
-    name == CONNECTION
-        || name == KEEP_ALIVE
-        || name == PROXY_AUTHENTICATE
-        || name == PROXY_AUTHORIZATION
-        || name == TE
-        || name == TRAILER
-        || name == TRANSFER_ENCODING
-        || name == UPGRADE
-}
-
-pub(crate) fn filter_response_headers(headers: &ReqwestHeaderMap) -> HeaderMap {
-    let mut output = HeaderMap::new();
-    for (name, value) in headers.iter() {
-        if is_hop_header(name) {
-            continue;
-        }
-        output.append(name.clone(), value.clone());
-    }
-    output
-}
-
-pub(crate) fn build_response(status: StatusCode, headers: HeaderMap, body: Body) -> Response {
-    let mut response = Response::new(body);
-    *response.status_mut() = status;
-    *response.headers_mut() = headers;
-    response
-}
-
-pub(crate) fn error_response(status: StatusCode, message: impl AsRef<str>) -> Response {
-    let body = json!({
-        "error": {
-            "message": message.as_ref(),
-            "type": "proxy_error"
-        }
-    });
-    let mut response = Response::new(Body::from(body.to_string()));
-    *response.status_mut() = status;
-    response.headers_mut().insert(
-        axum::http::header::CONTENT_TYPE,
-        HeaderValue::from_static("application/json"),
-    );
-    response
-}
-
-pub(crate) fn extract_request_id(headers: &ReqwestHeaderMap) -> Option<String> {
-    headers
-        .get("x-request-id")
-        .or_else(|| headers.get("openai-request-id"))
-        .and_then(|value| value.to_str().ok())
-        .map(|value| value.to_string())
-}
-
-// Unit tests are split into a separate file; `#[path]` keeps the `.test.rs` naming convention.
-#[cfg(test)]
-#[path = "http.test.rs"]
-mod tests;
diff --git a/src-tauri/src/proxy/http.test.rs b/src-tauri/src/proxy/http.test.rs
deleted file mode 100644
index 3e3da55..0000000
--- a/src-tauri/src/proxy/http.test.rs
+++ /dev/null
@@ -1,58 +0,0 @@
-use super::*;
-use crate::logging::LogLevel;
-use std::collections::HashMap;
-
-fn config_with_local(key: &str) -> ProxyConfig {
-    ProxyConfig {
-        host: "127.0.0.1".to_string(),
-        port: 9208,
-        local_api_key: Some(key.to_string()),
-        log_level: LogLevel::Silent,
-        max_request_body_bytes: 1024,
-        enable_api_format_conversion: false,
-        upstream_strategy: crate::proxy::config::UpstreamStrategy::PriorityFillFirst,
-        upstreams: HashMap::new(),
-        kiro_preferred_endpoint: None,
-        antigravity_user_agent: None,
-    }
-}
-
-#[test]
-fn local_auth_accepts_anthropic_headers() {
-    let config = config_with_local("local-key");
-    let mut headers = HeaderMap::new();
-    headers.insert("x-api-key", HeaderValue::from_static("local-key"));
-    let result = ensure_local_auth(&config, &headers, "/v1/messages", None);
-    assert!(result.is_ok());
-}
-
-#[test]
-fn local_auth_accepts_anthropic_authorization_only() {
-    let config = config_with_local("local-key");
-    let mut headers = HeaderMap::new();
-    headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer local-key"));
-    let result = ensure_local_auth(&config, &headers, "/v1/messages", None);
-    assert!(result.is_ok());
-}
-
-#[test]
-fn local_auth_accepts_gemini_query_key() {
-    let config = config_with_local("local-key");
-    let headers = HeaderMap::new();
-    let result = ensure_local_auth(
-        &config,
-        &headers,
-        "/v1beta/models/gemini-1.5-flash:generateContent",
-        Some("key=local-key"),
-    );
-    assert!(result.is_ok());
-}
-
-#[test]
-fn local_auth_accepts_openai_authorization() {
-    let config = config_with_local("local-key");
-    let mut headers = HeaderMap::new();
-    headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer local-key"));
-    let result = ensure_local_auth(&config, &headers, "/v1/chat/completions", None);
-    assert!(result.is_ok());
-}
diff --git a/src-tauri/src/proxy/http_client.rs b/src-tauri/src/proxy/http_client.rs
deleted file mode 100644
index 2137289..0000000
--- a/src-tauri/src/proxy/http_client.rs
+++ /dev/null
@@ -1,46 +0,0 @@
-use std::{collections::HashMap, sync::Mutex};
-
-use reqwest::{Client, ClientBuilder, Proxy};
-
-pub(crate) struct ProxyHttpClients {
-    direct: Client,
-    by_proxy_url: Mutex<HashMap<String, Client>>,
-}
-
-impl ProxyHttpClients {
-    pub(crate) fn new() -> Result<Self, String> {
-        let direct = ClientBuilder::new()
-            // Skip the system proxy by default; only go through a proxy when the user explicitly configures proxy_url.
-            .no_proxy()
-            .build()
-            .map_err(|err| format!("Failed to build direct HTTP client: {err}"))?;
-        Ok(Self {
-            direct,
-            by_proxy_url: Mutex::new(HashMap::new()),
-        })
-    }
-
-    pub(crate) fn client_for_proxy_url(&self, proxy_url: Option<&str>) -> Result<Client, String> {
-        let Some(proxy_url) = proxy_url.map(|value| value.trim()).filter(|value| !value.is_empty())
-        else {
-            return Ok(self.direct.clone());
-        };
-
-        let mut guard = self
-            .by_proxy_url
-            .lock()
-            .map_err(|_| "HTTP client pool is poisoned.".to_string())?;
-        if let Some(existing) = guard.get(proxy_url) {
-            return Ok(existing.clone());
-        }
-
-        let proxy = Proxy::all(proxy_url)
-            .map_err(|_| "proxy_url is invalid or not supported.".to_string())?;
-        let client = ClientBuilder::new()
-            .proxy(proxy)
-            .build()
-            .map_err(|err| format!("Failed to build proxied HTTP client: {err}"))?;
-        guard.insert(proxy_url.to_string(), client.clone());
-        Ok(client)
-    }
-}
diff --git a/src-tauri/src/proxy/kiro/constants.rs b/src-tauri/src/proxy/kiro/constants.rs
deleted file mode 100644
index ec891f2..0000000
--- a/src-tauri/src/proxy/kiro/constants.rs
+++ /dev/null
@@ -1,52 +0,0 @@
-pub(crate) const KIRO_MAX_OUTPUT_TOKENS: i64 = 32_000;
-
-pub(crate) const KIRO_AGENTIC_SYSTEM_PROMPT: &str = r#"
-# CRITICAL: CHUNKED WRITE PROTOCOL (MANDATORY)
-
-You MUST follow these rules for ALL file operations. Violation causes server timeouts and task failure.
-
-## ABSOLUTE LIMITS
-- **MAXIMUM 350 LINES** per single write/edit operation - NO EXCEPTIONS
-- **RECOMMENDED 300 LINES** or less for optimal performance
-- **NEVER** write entire files in one operation if >300 lines
-
-## MANDATORY CHUNKED WRITE STRATEGY
-
-### For NEW FILES (>300 lines total):
-1. FIRST: Write initial chunk (first 250-300 lines) using write_to_file/fsWrite
-2. THEN: Append remaining content in 250-300 line chunks using file append operations
-3. REPEAT: Continue appending until complete
-
-### For EDITING EXISTING FILES:
-1. Use surgical edits (apply_diff/targeted edits) - change ONLY what's needed
-2. NEVER rewrite entire files - use incremental modifications
-3. Split large refactors into multiple small, focused edits
-
-### For LARGE CODE GENERATION:
-1. Generate in logical sections (imports, types, functions separately)
-2. Write each section as a separate operation
-3. Use append operations for subsequent sections
-
-## EXAMPLES OF CORRECT BEHAVIOR
-
-✅ CORRECT: Writing a 600-line file
-- Operation 1: Write lines 1-300 (initial file creation)
-- Operation 2: Append lines 301-600
-
-✅ CORRECT: Editing multiple functions
-- Operation 1: Edit function A
-- Operation 2: Edit function B
-- Operation 3: Edit function C
-
-❌ WRONG: Writing 500 lines in single operation → TIMEOUT
-❌ WRONG: Rewriting entire file to change 5 lines → TIMEOUT
-❌ WRONG: Generating massive code blocks without chunking → TIMEOUT
-
-## WHY THIS MATTERS
-- Server has 2-3 minute timeout for operations
-- Large writes exceed timeout and FAIL completely
-- Chunked writes are FASTER and more RELIABLE
-- Failed writes waste time and require retry
-
-REMEMBER: When in doubt, write LESS per operation. Multiple small operations > one large operation.
-"#;
diff --git a/src-tauri/src/proxy/kiro/endpoint.rs b/src-tauri/src/proxy/kiro/endpoint.rs
deleted file mode 100644
index 7485588..0000000
--- a/src-tauri/src/proxy/kiro/endpoint.rs
+++ /dev/null
@@ -1,36 +0,0 @@
-use crate::proxy::config::KiroPreferredEndpoint;
-
-#[derive(Clone, Copy, Debug)]
-pub(crate) struct KiroEndpointConfig {
-    pub(crate) url: &'static str,
-    pub(crate) origin: &'static str,
-    pub(crate) amz_target: &'static str,
-}
-
-const CODEWHISPERER_ENDPOINT: KiroEndpointConfig = KiroEndpointConfig {
-    url: "https://codewhisperer.us-east-1.amazonaws.com/generateAssistantResponse",
-    origin: "AI_EDITOR",
-    amz_target: "AmazonCodeWhispererStreamingService.GenerateAssistantResponse",
-};
-
-const AMAZON_Q_ENDPOINT: KiroEndpointConfig = KiroEndpointConfig {
-    url: "https://q.us-east-1.amazonaws.com/generateAssistantResponse",
-    origin: "CLI",
-    amz_target: "AmazonQDeveloperStreamingService.SendMessage",
-};
-
-pub(crate) fn select_endpoints(
-    preferred: Option<KiroPreferredEndpoint>,
-    is_idc: bool,
-) -> Vec<KiroEndpointConfig> {
-    // IDC auth must use CodeWhisperer origin/endpoint pairing.
-    if is_idc {
-        return vec![CODEWHISPERER_ENDPOINT];
-    }
-
-    match preferred {
-        Some(KiroPreferredEndpoint::Ide) => vec![CODEWHISPERER_ENDPOINT, AMAZON_Q_ENDPOINT],
-        Some(KiroPreferredEndpoint::Cli) => vec![AMAZON_Q_ENDPOINT, CODEWHISPERER_ENDPOINT],
-        None => vec![CODEWHISPERER_ENDPOINT, AMAZON_Q_ENDPOINT],
-    }
-}
diff --git a/src-tauri/src/proxy/kiro/event_stream.rs b/src-tauri/src/proxy/kiro/event_stream.rs
deleted file mode 100644
index f54b64e..0000000
--- a/src-tauri/src/proxy/kiro/event_stream.rs
+++ /dev/null
@@ -1,159 +0,0 @@
-use std::io;
-
-const MIN_FRAME_SIZE: usize = 16;
-const MAX_FRAME_SIZE: usize = 10 << 20;
-
-#[derive(Debug)]
-pub(crate) struct EventStreamError {
-    pub(crate) message: String,
-}
-
-#[derive(Debug, Clone)]
-pub(crate) struct EventStreamMessage {
-    pub(crate) event_type: String,
-    pub(crate) payload: Vec<u8>,
-}
-
-pub(crate) struct EventStreamDecoder {
-    buffer: Vec<u8>,
-}
-
-impl EventStreamDecoder {
-    pub(crate) fn new() -> Self {
-        Self { buffer: Vec::new() }
-    }
-
-    pub(crate) fn push(&mut self, chunk: &[u8]) -> Result<Vec<EventStreamMessage>, EventStreamError> {
-        self.buffer.extend_from_slice(chunk);
-        self.decode_available()
-    }
-
-    pub(crate) fn finish(&mut self) -> Result<Vec<EventStreamMessage>, EventStreamError> {
-        self.decode_available()
-    }
-
-    fn decode_available(&mut self) -> Result<Vec<EventStreamMessage>, EventStreamError> {
-        let mut out = Vec::new();
-        loop {
-            if self.buffer.len() < MIN_FRAME_SIZE {
-                break;
-            }
-            let total_len = read_u32(&self.buffer[0..4]) as usize;
-            let headers_len = read_u32(&self.buffer[4..8]) as usize;
-
-            if total_len < MIN_FRAME_SIZE {
-                return Err(EventStreamError {
-                    message: "EventStream frame too small".to_string(),
-                });
-            }
-            if total_len > MAX_FRAME_SIZE {
-                return Err(EventStreamError {
-                    message: "EventStream frame too large".to_string(),
-                });
-            }
-            if self.buffer.len() < total_len {
-                break;
-            }
-
-            let headers_start = 12;
-            let headers_end = headers_start + headers_len;
-            if headers_end > total_len {
-                return Err(EventStreamError {
-                    message: "EventStream header length invalid".to_string(),
-                });
-            }
-            let payload_start = headers_end;
-            if payload_start + 4 > total_len {
-                return Err(EventStreamError {
-                    message: "EventStream payload length invalid".to_string(),
-                });
-            }
-            let payload_end = total_len - 4; // last 4 bytes are message CRC
-            let headers = &self.buffer[headers_start..headers_end];
-            let payload = self.buffer[payload_start..payload_end].to_vec();
-
-            let event_type = parse_event_type(headers).unwrap_or_default();
-            out.push(EventStreamMessage { event_type, payload });
-
-            self.buffer.drain(0..total_len);
-        }
-        Ok(out)
-    }
-}
-
-fn read_u32(bytes: &[u8]) -> u32 {
-    u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
-}
-
-fn parse_event_type(headers: &[u8]) -> Option<String> {
-    let mut cursor = 0;
-    while cursor < headers.len() {
-        let name_len = *headers.get(cursor)? as usize;
-        cursor += 1;
-        if cursor + name_len > headers.len() {
-            return None;
-        }
-        let name = std::str::from_utf8(&headers[cursor..cursor + name_len]).ok()?;
-        cursor += name_len;
-        let header_type = *headers.get(cursor)?;
-        cursor += 1;
-
-        let value = match header_type {
-            0 | 1 => None,
-            2 => {
-                cursor += 1;
-                None
-            }
-            3 => {
-                cursor += 2;
-                None
-            }
-            4 => {
-                cursor += 4;
-                None
-            }
-            5 | 8 => {
-                cursor += 8;
-                None
-            }
-            6 | 7 => {
-                if cursor + 2 > headers.len() {
-                    return None;
-                }
-                let len = u16::from_be_bytes([
-                    headers[cursor],
-                    headers[cursor + 1],
-                ]) as usize;
-                cursor += 2;
-                if cursor + len > headers.len() {
-                    return None;
-                }
-                let bytes = &headers[cursor..cursor + len];
-                cursor += len;
-                if header_type == 7 {
-                    Some(String::from_utf8_lossy(bytes).to_string())
-                } else {
-                    None
-                }
-            }
-            9 => {
-                cursor += 16;
-                None
-            }
-            _ => return None,
-        };
-
-        if name == ":event-type" {
-            return value;
-        }
-    }
-    None
-}
-
-impl From<io::Error> for EventStreamError {
-    fn from(err: io::Error) -> Self {
-        EventStreamError {
-            message: err.to_string(),
-        }
-    }
-}
diff --git a/src-tauri/src/proxy/kiro/mod.rs b/src-tauri/src/proxy/kiro/mod.rs
deleted file mode 100644
index 8037797..0000000
--- a/src-tauri/src/proxy/kiro/mod.rs
+++ /dev/null
@@ -1,20 +0,0 @@
-pub(crate) mod endpoint;
-pub(crate) mod constants;
-pub(crate) mod event_stream;
-pub(crate) mod model;
-pub(crate) mod payload;
-pub(crate) mod response;
-pub(crate) mod tools;
-pub(crate) mod tool_parser;
-pub(crate) mod types;
-pub(crate) mod utils;
-
-pub(crate) use endpoint::{select_endpoints, KiroEndpointConfig};
-pub(crate) use event_stream::EventStreamDecoder;
-pub(crate) use model::{determine_agentic_mode, map_model_to_kiro};
-pub(crate) use payload::{
-    build_payload_from_chat, build_payload_from_claude, build_payload_from_responses,
-    BuildPayloadResult,
-};
-pub(crate) use response::{parse_event_stream, KiroUsage};
-pub(crate) use types::KiroToolUse;
diff --git a/src-tauri/src/proxy/kiro/model.rs b/src-tauri/src/proxy/kiro/model.rs
deleted file mode 100644
index 2ea67fe..0000000
--- a/src-tauri/src/proxy/kiro/model.rs
+++ /dev/null
@@ -1,68 +0,0 @@
-pub(crate) fn determine_agentic_mode(model: &str) -> (bool, bool) {
-    let trimmed = model.trim();
-    let is_agentic = trimmed.ends_with("-agentic");
-    let is_chat_only = trimmed.ends_with("-chat");
-    (is_agentic, is_chat_only)
-}
-
-pub(crate) fn map_model_to_kiro(model: &str) -> String {
-    let normalized = model.trim();
-    match normalized {
-        // Amazon Q prefix
-        "amazonq-auto" => "auto",
-        "amazonq-claude-opus-4-5" => "claude-opus-4.5",
-        "amazonq-claude-sonnet-4-5" => "claude-sonnet-4.5",
-        "amazonq-claude-sonnet-4-5-20250929" => "claude-sonnet-4.5",
-        "amazonq-claude-sonnet-4" => "claude-sonnet-4",
-        "amazonq-claude-sonnet-4-20250514" => "claude-sonnet-4",
-        "amazonq-claude-haiku-4-5" => "claude-haiku-4.5",
-        // Kiro prefix
-        "kiro-claude-opus-4-5" => "claude-opus-4.5",
-        "kiro-claude-sonnet-4-5" => "claude-sonnet-4.5",
-        "kiro-claude-sonnet-4-5-20250929" => "claude-sonnet-4.5",
-        "kiro-claude-sonnet-4" => "claude-sonnet-4",
-        "kiro-claude-sonnet-4-20250514" => "claude-sonnet-4",
-        "kiro-claude-haiku-4-5" => "claude-haiku-4.5",
-        "kiro-auto" => "auto",
-        // Native format
-        "claude-opus-4-5" => "claude-opus-4.5",
-        "claude-opus-4.5" => "claude-opus-4.5",
-        "claude-haiku-4-5" => "claude-haiku-4.5",
-        "claude-haiku-4.5" => "claude-haiku-4.5",
-        "claude-sonnet-4-5" => "claude-sonnet-4.5",
-        "claude-sonnet-4-5-20250929" => "claude-sonnet-4.5",
-        "claude-sonnet-4.5" => "claude-sonnet-4.5",
-        "claude-sonnet-4" => "claude-sonnet-4",
-        "claude-sonnet-4-20250514" => "claude-sonnet-4",
-        "auto" => "auto",
-        // Agentic variants
-        "claude-opus-4.5-agentic" => "claude-opus-4.5",
-        "claude-sonnet-4.5-agentic" => "claude-sonnet-4.5",
-        "claude-sonnet-4-agentic" => "claude-sonnet-4",
-        "claude-haiku-4.5-agentic" => "claude-haiku-4.5",
-        "kiro-claude-opus-4-5-agentic" => "claude-opus-4.5",
-        "kiro-claude-sonnet-4-5-agentic" => "claude-sonnet-4.5",
-        "kiro-claude-sonnet-4-agentic" => "claude-sonnet-4",
-        "kiro-claude-haiku-4-5-agentic" => "claude-haiku-4.5",
-        _ => {
-            let lower = normalized.to_ascii_lowercase();
-            if lower.contains("haiku") {
-                return "claude-haiku-4.5".to_string();
-            }
-            if lower.contains("sonnet") {
-                if lower.contains("3-7") || lower.contains("3.7") {
-                    return "claude-3-7-sonnet-20250219".to_string();
-                }
-                if lower.contains("4-5") || lower.contains("4.5") {
-                    return "claude-sonnet-4.5".to_string();
-                }
-                return "claude-sonnet-4".to_string();
-            }
-            if lower.contains("opus") {
-                return "claude-opus-4.5".to_string();
-            }
-            return "claude-sonnet-4.5".to_string();
-        }
-    }
-    .to_string()
-}
diff --git a/src-tauri/src/proxy/kiro/payload/claude.rs b/src-tauri/src/proxy/kiro/payload/claude.rs
deleted file mode 100644
index 00c5d01..0000000
--- a/src-tauri/src/proxy/kiro/payload/claude.rs
+++ /dev/null
@@ -1,510 +0,0 @@
-use std::collections::HashSet;
-
-use axum::http::HeaderMap;
-use serde_json::{json, Map, Value};
-
-use super::super::constants::KIRO_AGENTIC_SYSTEM_PROMPT;
-use super::super::types::{
-    KiroAssistantResponseMessage, KiroConversationState, KiroHistoryMessage, KiroImage,
-    KiroImageSource, KiroPayload, KiroTextContent, KiroToolResult, KiroToolUse,
-    KiroUserInputMessage, KiroUserInputMessageContext,
-};
-use super::inference::build_inference_config;
-use super::system::{
-    extract_tool_choice_hint, has_thinking_tags, inject_hint, inject_timestamp, is_thinking_enabled,
-};
-use super::{BuildPayloadResult, THINKING_HINT};
-use super::super::utils::random_uuid;
-
-pub(crate) fn build_payload_from_claude(
-    request: &Value,
-    model_id: &str,
-    profile_arn: Option<&str>,
-    origin: &str,
-    is_agentic: bool,
-    is_chat_only: bool,
-    headers: &HeaderMap,
-) -> Result<BuildPayloadResult, String> {
-    let object = request
-        .as_object()
-        .ok_or_else(|| "Request body must be a JSON object.".to_string())?;
-    let messages = extract_messages(object)?;
-    let merged_messages = merge_adjacent_messages(&messages);
-    let system_prompt = build_system_prompt(object, headers, is_agentic);
-
-    let (history, current_user, current_tool_results) =
-        process_claude_messages(&merged_messages, model_id, origin);
-    let current_message = super::build_current_message(
-        &history,
-        current_user,
-        current_tool_results,
-        model_id,
-        origin,
-        &system_prompt,
-        object,
-        is_chat_only,
-    );
-
-    let payload = KiroPayload {
-        conversation_state: KiroConversationState {
-            chat_trigger_type: "MANUAL".to_string(),
-            conversation_id: random_uuid(),
-            current_message,
-            history,
-        },
-        profile_arn: profile_arn.map(|value| value.to_string()),
-        inference_config: build_inference_config(object),
-    };
-
-    let payload_bytes = serde_json::to_vec(&payload)
-        .map_err(|err| format!("Failed to serialize request payload: {err}"))?;
-
-    Ok(BuildPayloadResult {
-        payload: payload_bytes,
-    })
-}
-
-fn extract_messages(object: &Map<String, Value>) -> Result<Vec<Value>, String> {
-    object
-        .get("messages")
-        .and_then(Value::as_array)
-        .map(|items| items.clone())
-        .ok_or_else(|| "Request must include messages.".to_string())
-}
-
-fn build_system_prompt(object: &Map<String, Value>, headers: &HeaderMap, is_agentic: bool) -> String {
-    let base_system_prompt = extract_claude_system(object);
-    let thinking_enabled = is_thinking_enabled(object, headers, &base_system_prompt);
-
-    let mut system_prompt = inject_timestamp(base_system_prompt);
-    if is_agentic {
-        system_prompt = inject_hint(system_prompt, KIRO_AGENTIC_SYSTEM_PROMPT.trim());
-    }
-    if let Some(tool_choice_hint) = extract_tool_choice_hint(object) {
-        system_prompt = inject_hint(system_prompt, &tool_choice_hint);
-    }
-
-    if thinking_enabled && !has_thinking_tags(&system_prompt) {
-        system_prompt = prepend_hint(system_prompt, THINKING_HINT);
-    }
-
-    system_prompt
-}
-
-fn prepend_hint(system_prompt: String, hint: &str) -> String {
-    if hint.trim().is_empty() {
-        return system_prompt;
-    }
-    if system_prompt.trim().is_empty() {
-        return hint.trim().to_string();
-    }
-    format!("{}\n\n{}", hint.trim(), system_prompt)
-}
-
-fn extract_claude_system(object: &Map<String, Value>) -> String {
-    let Some(system) = object.get("system") else {
-        return String::new();
-    };
-    match system {
-        Value::String(text) => text.to_string(),
-        Value::Array(items) => {
-            let mut output = String::new();
-            for item in items {
-                if let Some(text) = item.get("text").and_then(Value::as_str) {
-                    output.push_str(text);
-                } else if let Some(text) = item.as_str() {
-                    output.push_str(text);
-                }
-            }
-            output
-        }
-        _ => String::new(),
-    }
-}
-
-fn merge_adjacent_messages(messages: &[Value]) -> Vec<Value> {
-    let mut merged: Vec<Value> = Vec::new();
-    for message in messages {
-        let Some(message) = message.as_object() else {
-            continue;
-        };
-        let role = message.get("role").and_then(Value::as_str).unwrap_or("");
-        let content = message.get("content").unwrap_or(&Value::Null);
-        let blocks = normalize_blocks(content);
-
-        if let Some(last) = merged.last_mut().and_then(Value::as_object_mut) {
-            if last.get("role").and_then(Value::as_str) == Some(role) {
-                let merged_blocks = merge_blocks(last.get("content"), blocks);
-                last.insert("content".to_string(), Value::Array(merged_blocks));
-                continue;
-            }
-        }
-
-        merged.push(json!({
-            "role": role,
-            "content": Value::Array(blocks),
-        }));
-    }
-    merged
-}
-
-fn normalize_blocks(content: &Value) -> Vec<Value> {
-    match content {
-        Value::String(text) => vec![json!({ "type": "text", "text": text })],
-        Value::Array(items) => items.clone(),
-        _ => Vec::new(),
-    }
-}
-
-fn merge_blocks(existing: Option<&Value>, mut next: Vec<Value>) -> Vec<Value> {
-    let mut merged = match existing {
-        Some(Value::Array(items)) => items.clone(),
-        Some(Value::String(text)) => vec![json!({ "type": "text", "text": text })],
-        _ => Vec::new(),
-    };
-
-    merge_text_blocks(&mut merged, &mut next);
-    merged.extend(next);
-    merged
-}
-
-fn merge_text_blocks(existing: &mut Vec<Value>, next: &mut Vec<Value>) {
-    let Some(Value::Object(last)) = existing.last_mut() else {
-        return;
-    };
-    if last.get("type").and_then(Value::as_str) != Some("text") {
-        return;
-    }
-    let Some(Value::Object(first)) = next.first_mut() else {
-        return;
-    };
-    if first.get("type").and_then(Value::as_str) != Some("text") {
-        return;
-    }
-    let last_text = last.get("text").and_then(Value::as_str).unwrap_or("").to_string();
-    let first_text = first.get("text").and_then(Value::as_str).unwrap_or("").to_string();
-    if last_text.is_empty() && first_text.is_empty() {
-        return;
-    }
-    last.insert("text".to_string(), Value::String(format!("{last_text}\n{first_text}")));
-    next.remove(0);
-}
-
-fn process_claude_messages(
-    messages: &[Value],
-    model_id: &str,
-    origin: &str,
-) -> (Vec<KiroHistoryMessage>, Option<KiroUserInputMessage>, Vec<KiroToolResult>) {
-    let mut history = Vec::new();
-    let mut current_user = None;
-    let mut current_tool_results = Vec::new();
-
-    for (index, message) in messages.iter().enumerate() {
-        let Some(message) = message.as_object() else {
-            continue;
-        };
-        let role = message.get("role").and_then(Value::as_str).unwrap_or("");
-        let is_last = index == messages.len().saturating_sub(1);
-
-        match role {
-            "user" => {
-                let (mut user_msg, tool_results) = build_user_message(message, model_id, origin);
-                if is_last {
-                    current_user = Some(user_msg);
-                    current_tool_results = tool_results;
-                    continue;
-                }
-                if user_msg.content.trim().is_empty() {
-                    user_msg.content = if tool_results.is_empty() {
-                        "Continue".to_string()
-                    } else {
-                        "Tool results provided.".to_string()
-                    };
-                }
-                if !tool_results.is_empty() {
-                    user_msg.user_input_message_context = Some(KiroUserInputMessageContext {
-                        tool_results,
-                        tools: Vec::new(),
-                    });
-                }
-                history.push(KiroHistoryMessage {
-                    user_input_message: Some(user_msg),
-                    assistant_response_message: None,
-                });
-            }
-            "assistant" => {
-                let assistant_msg = build_assistant_message(message);
-                history.push(KiroHistoryMessage {
-                    user_input_message: None,
-                    assistant_response_message: Some(assistant_msg),
-                });
-                if is_last {
-                    current_user = Some(KiroUserInputMessage {
-                        content: "Continue".to_string(),
-                        model_id: model_id.to_string(),
-                        origin: origin.to_string(),
-                        images: Vec::new(),
-                        user_input_message_context: None,
-                    });
-                }
-            }
-            _ => {}
-        }
-    }
-
-    (history, current_user, current_tool_results)
-}
-
-fn build_user_message(
-    message: &Map<String, Value>,
-    model_id: &str,
-    origin: &str,
-) -> (KiroUserInputMessage, Vec<KiroToolResult>) {
-    let mut content = String::new();
-    let mut images = Vec::new();
-    let mut tool_results = Vec::new();
-    let mut seen_tool_use_ids = HashSet::new();
-
-    if let Some(value) = message.get("content") {
-        match value {
-            Value::String(text) => {
-                content.push_str(text);
-            }
-            Value::Array(parts) => {
-                for part in parts {
-                    let Some(part) = part.as_object() else {
-                        continue;
-                    };
-                    let part_type = part.get("type").and_then(Value::as_str).unwrap_or("");
-                    match part_type {
-                        "text" => {
-                            if let Some(text) = part.get("text").and_then(Value::as_str) {
-                                content.push_str(text);
-                            }
-                        }
-                        "image" => {
-                            if let Some(image) = parse_image_block(part) {
-                                images.push(image);
-                            }
-                        }
-                        "tool_result" => {
-                            if let Some(result) =
-                                parse_tool_result_block(part, &mut seen_tool_use_ids)
-                            {
-                                tool_results.push(result);
-                            }
-                        }
-                        _ => {}
-                    }
-                    if part_type.is_empty()
-                        && part.get("tool_use_id").is_some()
-                        && part.get("content").is_some()
-                    {
-                        if let Some(result) =
-                            parse_tool_result_block(part, &mut seen_tool_use_ids)
-                        {
-                            tool_results.push(result);
-                        }
-                    }
-                }
-            }
-            _ => {}
-        }
-    }
-
-    let user_msg = KiroUserInputMessage {
-        content,
-        model_id: model_id.to_string(),
-        origin: origin.to_string(),
-        images,
-        user_input_message_context: None,
-    };
-
-    (user_msg, tool_results)
-}
-
-fn parse_image_block(block: &Map<String, Value>) -> Option<KiroImage> {
-    let source = block.get("source").and_then(Value::as_object)?;
-    let media_type = source.get("media_type").and_then(Value::as_str)?;
-    let data = source.get("data").and_then(Value::as_str)?;
-    if data.is_empty() {
-        return None;
-    }
-    let format = media_type.split('/').last().unwrap_or("");
-    if format.is_empty() {
-        return None;
-    }
-    Some(KiroImage {
-        format: format.to_string(),
-        source: KiroImageSource {
-            bytes: data.to_string(),
-        },
-    })
-}
-
-fn parse_tool_result_block(
-    block: &Map<String, Value>,
-    seen_tool_use_ids: &mut HashSet<String>,
-) -> Option<KiroToolResult> {
-    let tool_use_id = block
-        .get("tool_use_id")
-        .or_else(|| block.get("toolUseId"))
-        .or_else(|| block.get("tool_call_id"))
-        .and_then(Value::as_str)
-        .unwrap_or("");
-    if tool_use_id.is_empty() {
-        return None;
-    }
-    if seen_tool_use_ids.contains(tool_use_id) {
-        return None;
-    }
-    seen_tool_use_ids.insert(tool_use_id.to_string());
-
-    let is_error = block.get("is_error").and_then(Value::as_bool).unwrap_or(false);
-    let content = block
-        .get("content")
-        .or_else(|| block.get("output"));
-    let mut contents = parse_tool_result_contents(content);
-    if contents.is_empty() {
-        contents = vec![KiroTextContent {
-            text: "Tool use was cancelled by the user".to_string(),
-        }];
-    }
-
-    Some(KiroToolResult {
-        content: contents,
-        status: if is_error {
-            "error".to_string()
-        } else {
-            "success".to_string()
-        },
-        tool_use_id: tool_use_id.to_string(),
-    })
-}
-
-fn parse_tool_result_contents(content: Option<&Value>) -> Vec<KiroTextContent> {
-    let Some(content) = content else {
-        return Vec::new();
-    };
-    match content {
-        Value::String(text) => {
-            if text.is_empty() {
-                Vec::new()
-            } else {
-                vec![KiroTextContent {
-                    text: text.to_string(),
-                }]
-            }
-        }
-        Value::Object(item) => {
-            if item.get("type").and_then(Value::as_str) == Some("text") {
-                if let Some(text) = item.get("text").and_then(Value::as_str) {
-                    if !text.is_empty() {
-                        return vec![KiroTextContent {
-                            text: text.to_string(),
-                        }];
-                    }
-                }
-            }
-            if let Some(text) = item.get("text").and_then(Value::as_str) {
-                if !text.is_empty() {
-                    return vec![KiroTextContent {
-                        text: text.to_string(),
-                    }];
-                }
-            }
-            Vec::new()
-        }
-        Value::Array(items) => items
-            .iter()
-            .filter_map(|item| {
-                if let Some(text) = item.as_str() {
-                    if !text.is_empty() {
-                        return Some(KiroTextContent {
-                            text: text.to_string(),
-                        });
-                    }
-                }
-                let Some(item) = item.as_object() else {
-                    return None;
-                };
-                if item.get("type").and_then(Value::as_str) == Some("text") {
-                    let text = item.get("text").and_then(Value::as_str)?;
-                    if !text.is_empty() {
-                        return Some(KiroTextContent {
-                            text: text.to_string(),
-                        });
-                    }
-                }
-                if let Some(text) = item.get("text").and_then(Value::as_str) {
-                    if !text.is_empty() {
-                        return Some(KiroTextContent {
-                            text: text.to_string(),
-                        });
-                    }
-                }
-                None
-            })
-            .collect(),
-        _ => Vec::new(),
-    }
-}
-
-fn build_assistant_message(message: &Map<String, Value>) -> KiroAssistantResponseMessage {
-    let mut content = String::new();
-    let mut tool_uses = Vec::new();
-
-    if let Some(value) = message.get("content") {
-        match value {
-            Value::String(text) => content.push_str(text),
-            Value::Array(parts) => {
-                for part in parts {
-                    let Some(part) = part.as_object() else {
-                        continue;
-                    };
-                    let part_type = part.get("type").and_then(Value::as_str).unwrap_or("");
-                    match part_type {
-                        "text" => {
-                            if let Some(text) = part.get("text").and_then(Value::as_str) {
-                                content.push_str(text);
-                            }
-                        }
-                        "tool_use" => {
-                            if let Some(tool_use) = parse_tool_use_block(part) {
-                                tool_uses.push(tool_use);
-                            }
-                        }
-                        _ => {}
-                    }
-                }
-            }
-            _ => {}
-        }
-    }
-
-    KiroAssistantResponseMessage { content, tool_uses }
-}
-
-fn parse_tool_use_block(block: &Map<String, Value>) -> Option<KiroToolUse> {
-    let tool_use_id = block.get("id").and_then(Value::as_str).unwrap_or("");
-    let name = block.get("name").and_then(Value::as_str).unwrap_or("");
-    if tool_use_id.is_empty() || name.is_empty() {
-        return None;
-    }
-    let input_value = block.get("input");
-    let input = input_value
-        .and_then(Value::as_object)
-        .map(|object| object.clone())
-        .unwrap_or_default();
-
-    Some(KiroToolUse {
-        tool_use_id: tool_use_id.to_string(),
-        name: name.to_string(),
-        input,
-    })
-}
-
-// Unit tests are split into a separate file; `#[path]` keeps the `.test.rs` naming convention.
-#[cfg(test)]
-#[path = "claude.test.rs"]
-mod tests;
diff --git a/src-tauri/src/proxy/kiro/payload/claude.test.rs b/src-tauri/src/proxy/kiro/payload/claude.test.rs
deleted file mode 100644
index e654367..0000000
--- a/src-tauri/src/proxy/kiro/payload/claude.test.rs
+++ /dev/null
@@ -1,55 +0,0 @@
-use super::*;
-
-#[test]
-fn build_payload_from_claude_includes_tools_and_results() {
-    let request = json!({
-        "model": "claude-sonnet-4.5",
-        "tool_choice": { "type": "tool", "name": "mcp__demo__ping" },
-        "tools": [
-            {
-                "name": "mcp__demo__ping",
-                "description": "",
-                "input_schema": { "type": "object", "properties": {} }
-            }
-        ],
-        "messages": [
-            {
-                "role": "user",
-                "content": [
-                    { "type": "text", "text": "hi" },
-                    { "type": "tool_result", "tool_use_id": "toolu_1", "is_error": true, "content": [{ "type": "text", "text": "fail" }] }
-                ]
-            }
-        ]
-    });
-    let headers = HeaderMap::new();
-    let result = build_payload_from_claude(
-        &request,
-        "claude-sonnet-4.5",
-        None,
-        "CLI",
-        false,
-        false,
-        &headers,
-    )
-    .expect("payload");
-    let payload: Value = serde_json::from_slice(&result.payload).expect("json");
-    let context = payload
-        .get("conversationState")
-        .and_then(Value::as_object)
-        .and_then(|state| state.get("currentMessage"))
-        .and_then(Value::as_object)
-        .and_then(|msg| msg.get("userInputMessage"))
-        .and_then(Value::as_object)
-        .and_then(|msg| msg.get("userInputMessageContext"))
-        .and_then(Value::as_object)
-        .expect("context");
-    assert!(context.get("tools").and_then(Value::as_array).is_some());
-    assert_eq!(
-        context
-            .get("toolResults")
-            .and_then(Value::as_array)
-            .map(|items| items.len()),
-        Some(1)
-    );
-}
diff --git a/src-tauri/src/proxy/kiro/payload/inference.rs b/src-tauri/src/proxy/kiro/payload/inference.rs
deleted file mode 100644
index 24acbc2..0000000
--- a/src-tauri/src/proxy/kiro/payload/inference.rs
+++ /dev/null
@@ -1,28 +0,0 @@
-use serde_json::{Map, Value};
-
-use super::super::constants::KIRO_MAX_OUTPUT_TOKENS;
-use super::super::types::KiroInferenceConfig;
-
-pub(super) fn build_inference_config(object: &Map<String, Value>) -> Option<KiroInferenceConfig> {
-    let mut max_tokens = object
-        .get("max_output_tokens")
-        .or_else(|| object.get("max_tokens"))
-        .and_then(Value::as_i64);
-    if let Some(value) = max_tokens {
-        if value == -1 {
-            max_tokens = Some(KIRO_MAX_OUTPUT_TOKENS);
-        }
-    }
-    let temperature = object.get("temperature").and_then(Value::as_f64);
-    let top_p = object.get("top_p").and_then(Value::as_f64);
-
-    if max_tokens.is_none() && temperature.is_none() && top_p.is_none() {
-        return None;
-    }
-
-    Some(KiroInferenceConfig {
-        max_tokens,
-        temperature,
-        top_p,
-    })
-}
diff --git a/src-tauri/src/proxy/kiro/payload/input.rs b/src-tauri/src/proxy/kiro/payload/input.rs
deleted file mode 100644
index b923fac..0000000
--- a/src-tauri/src/proxy/kiro/payload/input.rs
+++ /dev/null
@@ -1,157 +0,0 @@
-use serde_json::{Map, Value};
-
-pub(super) fn extract_input_messages(object: &Map<String, Value>) -> Result<Vec<Value>, String> {
-    let input = object.get("input");
-    match input {
-        Some(Value::String(text)) => Ok(vec![serde_json::json!({ "role": "user", "content": text })]),
-        Some(Value::Array(items)) => responses_input_to_chat_messages(items),
-        Some(Value::Null) | None => Ok(Vec::new()),
-        _ => Err("Responses input must be a string or array.".to_string()),
-    }
-}
-
-fn responses_input_to_chat_messages(items: &[Value]) -> Result<Vec<Value>, String> {
-    let mut messages = Vec::with_capacity(items.len());
-    for item in items {
-        messages.push(responses_input_item_to_chat_message(item)?);
-    }
-    Ok(messages)
-}
-
-fn responses_input_item_to_chat_message(item: &Value) -> Result<Value, String> {
-    let Some(item) = item.as_object() else {
-        return Err("Responses input item must be an object.".to_string());
-    };
-
-    if item.get("role").and_then(Value::as_str).is_some() {
-        let mut output = item.clone();
-        if let Some(content) = item
-            .get("content")
-            .and_then(responses_message_content_to_chat_content)
-        {
-            output.insert("content".to_string(), content);
-        }
-        return Ok(Value::Object(output));
-    }
-
-    let Some(item_type) = item.get("type").and_then(Value::as_str) else {
-        return Err("Responses input item must include role or type.".to_string());
-    };
-
-    match item_type {
-        "message" => responses_message_item_to_chat_message(item),
-        "function_call_output" => responses_function_call_output_item_to_chat_message(item),
-        "function_call" => responses_function_call_item_to_chat_message(item),
-        other => Err(format!("Unsupported Responses input item type: {other}")),
-    }
-}
-
-fn responses_message_item_to_chat_message(item: &Map<String, Value>) -> Result<Value, String> {
-    let role = item
-        .get("role")
-        .and_then(Value::as_str)
-        .ok_or_else(|| "Responses message item must include role.".to_string())?;
-    let content = item
-        .get("content")
-        .and_then(responses_message_content_to_chat_content)
-        .unwrap_or_else(|| Value::String(String::new()));
-    Ok(serde_json::json!({ "role": role, "content": content }))
-}
-
-fn responses_function_call_output_item_to_chat_message(
-    item: &Map<String, Value>,
-) -> Result<Value, String> {
-    let call_id = item
-        .get("call_id")
-        .and_then(Value::as_str)
-        .ok_or_else(|| "function_call_output must include call_id.".to_string())?;
-    let output = item.get("output").and_then(Value::as_str).unwrap_or("");
-    let mut message = Map::new();
-    message.insert("role".to_string(), Value::String("tool".to_string()));
-    message.insert("tool_call_id".to_string(), Value::String(call_id.to_string()));
-    message.insert("content".to_string(), Value::String(output.to_string()));
-    if let Some(is_error) = item.get("is_error").and_then(Value::as_bool) {
-        if is_error {
-            message.insert("is_error".to_string(), Value::Bool(true));
-        }
-    }
-    if let Some(parts) = item.get("output_parts") {
-        message.insert("content_parts".to_string(), parts.clone());
-    }
-    Ok(Value::Object(message))
-}
-
-fn responses_function_call_item_to_chat_message(item: &Map<String, Value>) -> Result<Value, String> {
-    let call_id = item
-        .get("call_id")
-        .and_then(Value::as_str)
-        .ok_or_else(|| "function_call must include call_id.".to_string())?;
-    let name = item.get("name").and_then(Value::as_str).unwrap_or("");
-    let arguments = item.get("arguments").and_then(Value::as_str).unwrap_or("");
-    Ok(serde_json::json!({
-        "role": "assistant",
-        "content": "",
-        "tool_calls": [
-            {
-                "id": call_id,
-                "type": "function",
-                "function": { "name": name, "arguments": arguments }
-            }
-        ]
-    }))
-}
-
-fn responses_message_content_to_chat_content(value: &Value) -> Option<Value> {
-    match value {
-        Value::String(text) => Some(Value::String(text.to_string())),
-        Value::Array(parts) => {
-            let mut output_parts = Vec::new();
-            let mut combined = String::new();
-            let mut text_only = true;
-            for part in parts {
-                let Some(part) = part.as_object() else {
-                    continue;
-                };
-                let part_type = part.get("type").and_then(Value::as_str);
-                match part_type {
-                    Some("input_text") | Some("text") | Some("output_text") => {
-                        if let Some(text) = part.get("text").and_then(Value::as_str) {
-                            combined.push_str(text);
-                            output_parts.push(serde_json::json!({ "type": "text", "text": text }));
-                        }
-                    }
-                    Some("refusal") => {
-                        let text = part
-                            .get("refusal")
-                            .or_else(|| part.get("text"))
-                            .and_then(Value::as_str)
-                            .unwrap_or("");
-                        if !text.is_empty() {
-                            combined.push_str(text);
-                            output_parts.push(serde_json::json!({ "type": "text", "text": text }));
-                        }
-                    }
-                    Some("input_image") | Some("output_image") => {
-                        if let Some(image_url) = part.get("image_url") {
-                            text_only = false;
-                            output_parts.push(serde_json::json!({
-                                "type": "image_url",
-                                "image_url": image_url
-                            }));
-                        }
-                    }
-                    _ => {
-                        text_only = false;
-                    }
-                }
-            }
-            if text_only {
-                Some(Value::String(combined))
-            } else {
-                Some(Value::Array(output_parts))
-            }
-        }
-        Value::Null => None,
-        _ => Some(Value::String(String::new())),
-    }
-}
diff --git a/src-tauri/src/proxy/kiro/payload/messages.rs b/src-tauri/src/proxy/kiro/payload/messages.rs
deleted file mode 100644
index 99a699d..0000000
--- a/src-tauri/src/proxy/kiro/payload/messages.rs
+++ /dev/null
@@ -1,391 +0,0 @@
-use serde_json::{Map, Value};
-
-use super::super::types::{
-    KiroAssistantResponseMessage, KiroHistoryMessage, KiroImage, KiroImageSource, KiroTextContent,
-    KiroToolResult, KiroToolUse, KiroUserInputMessage, KiroUserInputMessageContext,
-};
-
-pub(super) fn process_messages(
-    messages: &[Value],
-    model_id: &str,
-    origin: &str,
-) -> (Vec<KiroHistoryMessage>, Option<KiroUserInputMessage>, Vec<KiroToolResult>) {
-    let mut state = MessageState::new();
-    for (index, message) in messages.iter().enumerate() {
-        let Some(message) = message.as_object() else {
-            continue;
-        };
-        let role = message.get("role").and_then(Value::as_str).unwrap_or("");
-        let is_last = index == messages.len().saturating_sub(1);
-        match role {
-            "system" => continue,
-            "user" => handle_user_message(message, model_id, origin, is_last, &mut state),
-            "assistant" => handle_assistant_message(message, model_id, origin, is_last, &mut state),
-            "tool" => handle_tool_message(message, &mut state),
-            _ => {}
-        }
-    }
-    finalize_pending_tools(model_id, origin, &mut state);
-    (state.history, state.current_user, state.current_tool_results)
-}
-
-struct MessageState {
-    history: Vec<KiroHistoryMessage>,
-    current_user: Option<KiroUserInputMessage>,
-    current_tool_results: Vec<KiroToolResult>,
-    pending_tool_results: Vec<KiroToolResult>,
-}
-
-impl MessageState {
-    fn new() -> Self {
-        Self {
-            history: Vec::new(),
-            current_user: None,
-            current_tool_results: Vec::new(),
-            pending_tool_results: Vec::new(),
-        }
-    }
-}
-
-fn handle_user_message(
-    message: &Map<String, Value>,
-    model_id: &str,
-    origin: &str,
-    is_last: bool,
-    state: &mut MessageState,
-) {
-    let (mut user_msg, tool_results) = build_user_message(message, model_id, origin);
-    let mut tool_results = state
-        .pending_tool_results
-        .drain(..)
-        .chain(tool_results)
-        .collect::<Vec<_>>();
-
-    if is_last {
-        state.current_user = Some(user_msg);
-        state.current_tool_results = tool_results;
-        return;
-    }
-
-    if user_msg.content.trim().is_empty() {
-        user_msg.content = if tool_results.is_empty() {
-            "Continue".to_string()
-        } else {
-            "Tool results provided.".to_string()
-        };
-    }
-    if !tool_results.is_empty() {
-        user_msg.user_input_message_context = Some(KiroUserInputMessageContext {
-            tool_results: tool_results.drain(..).collect(),
-            tools: Vec::new(),
-        });
-    }
-    state.history.push(KiroHistoryMessage {
-        user_input_message: Some(user_msg),
-        assistant_response_message: None,
-    });
-}
-
-fn handle_assistant_message(
-    message: &Map<String, Value>,
-    model_id: &str,
-    origin: &str,
-    is_last: bool,
-    state: &mut MessageState,
-) {
-    let assistant_msg = build_assistant_message(message);
-    if !state.pending_tool_results.is_empty() {
-        let synthetic = KiroUserInputMessage {
-            content: "Tool results provided.".to_string(),
-            model_id: model_id.to_string(),
-            origin: origin.to_string(),
-            images: Vec::new(),
-            user_input_message_context: Some(KiroUserInputMessageContext {
-                tool_results: state.pending_tool_results.drain(..).collect(),
-                tools: Vec::new(),
-            }),
-        };
-        state.history.push(KiroHistoryMessage {
-            user_input_message: Some(synthetic),
-            assistant_response_message: None,
-        });
-    }
-
-    state.history.push(KiroHistoryMessage {
-        user_input_message: None,
-        assistant_response_message: Some(assistant_msg),
-    });
-
-    if is_last {
-        state.current_user = Some(KiroUserInputMessage {
-            content: "Continue".to_string(),
-            model_id: model_id.to_string(),
-            origin: origin.to_string(),
-            images: Vec::new(),
-            user_input_message_context: None,
-        });
-    }
-}
-
-fn handle_tool_message(message: &Map<String, Value>, state: &mut MessageState) {
-    if let Some(tool_result) = build_tool_result(message) {
-        state.pending_tool_results.push(tool_result);
-    }
-}
-
-fn finalize_pending_tools(model_id: &str, origin: &str, state: &mut MessageState) {
-    if state.pending_tool_results.is_empty() {
-        return;
-    }
-    state
-        .current_tool_results
-        .extend(state.pending_tool_results.drain(..));
-    if state.current_user.is_some() {
-        return;
-    }
-    state.current_user = Some(KiroUserInputMessage {
-        content: "Tool results provided.".to_string(),
-        model_id: model_id.to_string(),
-        origin: origin.to_string(),
-        images: Vec::new(),
-        user_input_message_context: None,
-    });
-}
-
-fn build_user_message(
-    message: &Map<String, Value>,
-    model_id: &str,
-    origin: &str,
-) -> (KiroUserInputMessage, Vec<KiroToolResult>) {
-    let mut content = String::new();
-    let mut images = Vec::new();
-
-    if let Some(value) = message.get("content") {
-        match value {
-            Value::String(text) => {
-                content.push_str(text);
-            }
-            Value::Array(parts) => {
-                for part in parts {
-                    let Some(part) = part.as_object() else {
-                        continue;
-                    };
-                    let part_type = part.get("type").and_then(Value::as_str).unwrap_or("text");
-                    match part_type {
-                        "text" | "input_text" | "output_text" => {
-                            if let Some(text) = part.get("text").and_then(Value::as_str) {
-                                content.push_str(text);
-                            }
-                        }
-                        "image_url" | "input_image" => {
-                            if let Some(image) = parse_image_url(part.get("image_url")) {
-                                images.push(image);
-                            }
-                        }
-                        _ => {}
-                    }
-                }
-            }
-            _ => {}
-        }
-    }
-
-    let user_msg = KiroUserInputMessage {
-        content,
-        model_id: model_id.to_string(),
-        origin: origin.to_string(),
-        images,
-        user_input_message_context: None,
-    };
-
-    (user_msg, Vec::new())
-}
-
-fn build_assistant_message(message: &Map<String, Value>) -> KiroAssistantResponseMessage {
-    let mut content = String::new();
-    if let 
Some(value) = message.get("content") { - match value { - Value::String(text) => content.push_str(text), - Value::Array(parts) => { - for part in parts { - if part.get("type").and_then(Value::as_str) == Some("text") { - if let Some(text) = part.get("text").and_then(Value::as_str) { - content.push_str(text); - } - } - } - } - _ => {} - } - } - - let mut tool_uses = Vec::new(); - if let Some(tool_calls) = message.get("tool_calls").and_then(Value::as_array) { - for tool_call in tool_calls { - let Some(tool_call) = tool_call.as_object() else { - continue; - }; - if tool_call.get("type").and_then(Value::as_str) != Some("function") { - continue; - } - let tool_use_id = tool_call.get("id").and_then(Value::as_str).unwrap_or(""); - let name = tool_call - .get("function") - .and_then(Value::as_object) - .and_then(|function| function.get("name")) - .and_then(Value::as_str) - .unwrap_or(""); - let arguments = tool_call - .get("function") - .and_then(Value::as_object) - .and_then(|function| function.get("arguments")) - .and_then(Value::as_str) - .unwrap_or(""); - let input = serde_json::from_str::>(arguments).unwrap_or_default(); - if !tool_use_id.is_empty() && !name.is_empty() { - tool_uses.push(KiroToolUse { - tool_use_id: tool_use_id.to_string(), - name: name.to_string(), - input, - }); - } - } - } - - KiroAssistantResponseMessage { content, tool_uses } -} - -fn build_tool_result(message: &Map) -> Option { - let tool_use_id = message - .get("tool_call_id") - .and_then(Value::as_str) - .unwrap_or(""); - if tool_use_id.is_empty() { - return None; - } - let is_error = message - .get("is_error") - .and_then(Value::as_bool) - .unwrap_or(false); - let mut contents = extract_tool_result_contents(message); - if contents.is_empty() { - contents = vec![KiroTextContent { - text: "Tool use was cancelled by the user".to_string(), - }]; - } - Some(KiroToolResult { - content: contents, - status: if is_error { - "error".to_string() - } else { - "success".to_string() - }, - tool_use_id: tool_use_id.to_string(), - }) -} - -fn extract_tool_result_contents(message: &Map) -> Vec { - if let Some(parts) = message.get("content_parts").and_then(Value::as_array) { - let mut out = Vec::new(); - for part in parts { - match part { - Value::String(text) => { - if !text.is_empty() { - out.push(KiroTextContent { text: text.clone() }); - } - } - Value::Object(obj) => { - if obj.get("type").and_then(Value::as_str) == Some("text") { - if let Some(text) = obj.get("text").and_then(Value::as_str) { - if !text.is_empty() { - out.push(KiroTextContent { - text: text.to_string(), - }); - } - } - } - } - _ => {} - } - } - if !out.is_empty() { - return out; - } - } - - let content = message - .get("content") - .and_then(Value::as_str) - .unwrap_or(""); - if content.is_empty() { - return Vec::new(); - } - vec![KiroTextContent { - text: content.to_string(), - }] -} - -fn parse_image_url(value: Option<&Value>) -> Option { - let url = match value { - Some(Value::String(url)) => url.as_str(), - Some(Value::Object(obj)) => obj.get("url").and_then(Value::as_str)?, - _ => return None, - }; - if !url.starts_with("data:") { - return None; - } - let parts = url.splitn(2, ";base64,").collect::>(); - if parts.len() != 2 { - return None; - } - let media_type = parts[0].trim_start_matches("data:"); - let data = parts[1].trim(); - if data.is_empty() { - return None; - } - let format = media_type.split('/').last().unwrap_or("").to_string(); - if format.is_empty() { - return None; - } - Some(KiroImage { - format, - source: KiroImageSource { - bytes: 
data.to_string(), - }, - }) -} - -pub(super) fn build_final_content( - content: &str, - system_prompt: &str, - tool_results: &[KiroToolResult], -) -> String { - let mut output = String::new(); - if !system_prompt.trim().is_empty() { - output.push_str("--- SYSTEM PROMPT ---\n"); - output.push_str(system_prompt.trim()); - output.push_str("\n--- END SYSTEM PROMPT ---\n\n"); - } - output.push_str(content); - - if output.trim().is_empty() { - if tool_results.is_empty() { - return "Continue".to_string(); - } - return "Tool results provided.".to_string(); - } - - output -} - -pub(super) fn deduplicate_tool_results(results: Vec) -> Vec { - let mut seen = std::collections::HashSet::new(); - let mut output = Vec::new(); - for result in results { - if !seen.insert(result.tool_use_id.clone()) { - continue; - } - output.push(result); - } - output -} diff --git a/src-tauri/src/proxy/kiro/payload/mod.rs b/src-tauri/src/proxy/kiro/payload/mod.rs deleted file mode 100644 index 42d61fe..0000000 --- a/src-tauri/src/proxy/kiro/payload/mod.rs +++ /dev/null @@ -1,201 +0,0 @@ -use axum::http::HeaderMap; -use serde_json::{Map, Value}; - -use super::constants::KIRO_AGENTIC_SYSTEM_PROMPT; -use super::tools::convert_openai_tools; -use super::types::{ - KiroConversationState, KiroCurrentMessage, KiroPayload, KiroUserInputMessage, - KiroUserInputMessageContext, -}; -use super::utils::random_uuid; -use inference::build_inference_config; -use input::extract_input_messages; -use messages::{build_final_content, deduplicate_tool_results, process_messages}; -use system::{ - extract_response_format_hint, extract_system_prompt, extract_tool_choice_hint, - inject_hint, inject_timestamp, is_thinking_enabled, -}; - -mod inference; -mod input; -mod messages; -mod system; -mod claude; - -const THINKING_HINT: &str = - "enabled\n200000"; - -pub(crate) struct BuildPayloadResult { - pub(crate) payload: Vec, -} - -pub(crate) use claude::build_payload_from_claude; - -pub(crate) fn build_payload_from_responses( - request: &Value, - model_id: &str, - profile_arn: Option<&str>, - origin: &str, - is_agentic: bool, - is_chat_only: bool, - headers: &HeaderMap, -) -> Result { - let object = request - .as_object() - .ok_or_else(|| "Request body must be a JSON object.".to_string())?; - - let messages = extract_input_messages(object)?; - let system_prompt = prepare_system_prompt(object, &messages, headers, is_agentic); - - let (history, current_user, current_tool_results) = - process_messages(&messages, model_id, origin); - let current_message = build_current_message( - &history, - current_user, - current_tool_results, - model_id, - origin, - &system_prompt, - object, - is_chat_only, - ); - - let payload = KiroPayload { - conversation_state: KiroConversationState { - chat_trigger_type: "MANUAL".to_string(), - conversation_id: random_uuid(), - current_message, - history, - }, - profile_arn: profile_arn.map(|value| value.to_string()), - inference_config: build_inference_config(object), - }; - - let payload_bytes = serde_json::to_vec(&payload) - .map_err(|err| format!("Failed to serialize request payload: {err}"))?; - - Ok(BuildPayloadResult { - payload: payload_bytes, - }) -} - -pub(crate) fn build_payload_from_chat( - request: &Value, - model_id: &str, - profile_arn: Option<&str>, - origin: &str, - is_agentic: bool, - is_chat_only: bool, - headers: &HeaderMap, -) -> Result { - let object = request - .as_object() - .ok_or_else(|| "Request body must be a JSON object.".to_string())?; - let messages = object - .get("messages") - 
.and_then(Value::as_array) - .ok_or_else(|| "Chat request must include messages.".to_string())?; - let messages = messages.clone(); - - let system_prompt = prepare_system_prompt(object, &messages, headers, is_agentic); - let (history, current_user, current_tool_results) = - process_messages(&messages, model_id, origin); - let current_message = build_current_message( - &history, - current_user, - current_tool_results, - model_id, - origin, - &system_prompt, - object, - is_chat_only, - ); - - let payload = KiroPayload { - conversation_state: KiroConversationState { - chat_trigger_type: "MANUAL".to_string(), - conversation_id: random_uuid(), - current_message, - history, - }, - profile_arn: profile_arn.map(|value| value.to_string()), - inference_config: build_inference_config(object), - }; - - let payload_bytes = serde_json::to_vec(&payload) - .map_err(|err| format!("Failed to serialize request payload: {err}"))?; - - Ok(BuildPayloadResult { - payload: payload_bytes, - }) -} - -fn prepare_system_prompt( - object: &Map, - messages: &[Value], - headers: &HeaderMap, - is_agentic: bool, -) -> String { - let mut system_prompt = extract_system_prompt(object, messages); - let thinking_enabled = is_thinking_enabled(object, headers, &system_prompt); - if thinking_enabled && !system::has_thinking_tags(&system_prompt) { - system_prompt = inject_hint(system_prompt, THINKING_HINT); - } - system_prompt = inject_timestamp(system_prompt); - if is_agentic { - system_prompt = inject_hint(system_prompt, KIRO_AGENTIC_SYSTEM_PROMPT.trim()); - } - - if let Some(tool_choice_hint) = extract_tool_choice_hint(object) { - system_prompt = inject_hint(system_prompt, &tool_choice_hint); - } - if let Some(response_format_hint) = extract_response_format_hint(object) { - system_prompt = inject_hint(system_prompt, &response_format_hint); - } - - system_prompt -} - -fn build_current_message( - history: &[super::types::KiroHistoryMessage], - current_user: Option, - mut tool_results: Vec, - model_id: &str, - origin: &str, - system_prompt: &str, - object: &Map, - is_chat_only: bool, -) -> KiroCurrentMessage { - if let Some(mut user) = current_user { - let prompt = if history.is_empty() { system_prompt } else { "" }; - user.content = build_final_content(&user.content, prompt, &tool_results); - tool_results = deduplicate_tool_results(tool_results); - let tools = convert_openai_tools(object.get("tools"), is_chat_only); - if !tools.is_empty() || !tool_results.is_empty() { - user.user_input_message_context = Some(KiroUserInputMessageContext { - tool_results, - tools, - }); - } - return KiroCurrentMessage { - user_input_message: user, - }; - } - - let fallback = if system_prompt.trim().is_empty() { - "Continue".to_string() - } else { - format!( - "--- SYSTEM PROMPT ---\n{system_prompt}\n--- END SYSTEM PROMPT ---\n" - ) - }; - KiroCurrentMessage { - user_input_message: KiroUserInputMessage { - content: fallback, - model_id: model_id.to_string(), - origin: origin.to_string(), - images: Vec::new(), - user_input_message_context: None, - }, - } -} diff --git a/src-tauri/src/proxy/kiro/payload/system.rs b/src-tauri/src/proxy/kiro/payload/system.rs deleted file mode 100644 index c38b28c..0000000 --- a/src-tauri/src/proxy/kiro/payload/system.rs +++ /dev/null @@ -1,226 +0,0 @@ -use axum::http::HeaderMap; -use serde_json::{Map, Value}; -use time::OffsetDateTime; - -pub(super) fn extract_system_prompt(object: &Map, messages: &[Value]) -> String { - let mut parts = Vec::new(); - if let Some(Value::String(instructions)) = object.get("instructions") 
{ - if !instructions.trim().is_empty() { - parts.push(instructions.trim().to_string()); - } - } - - for message in messages { - let Some(message) = message.as_object() else { - continue; - }; - let role = message.get("role").and_then(Value::as_str); - if role != Some("system") { - continue; - } - if let Some(content) = message.get("content") { - match content { - Value::String(text) => { - if !text.trim().is_empty() { - parts.push(text.trim().to_string()); - } - } - Value::Array(items) => { - for item in items { - if let Some(text) = item.get("text").and_then(Value::as_str) { - if !text.trim().is_empty() { - parts.push(text.trim().to_string()); - } - } - } - } - _ => {} - } - } - } - - parts.join("\n") -} - -pub(super) fn is_thinking_enabled( - object: &Map, - headers: &HeaderMap, - system_prompt: &str, -) -> bool { - if thinking_enabled_from_header(headers) { - return true; - } - if thinking_enabled_from_claude(object) { - return true; - } - if thinking_enabled_from_reasoning_effort(object) { - return true; - } - if thinking_enabled_from_system_prompt(system_prompt) { - return true; - } - if thinking_enabled_from_model_hint(object) { - return true; - } - false -} - -pub(super) fn inject_hint(mut system_prompt: String, hint: &str) -> String { - if hint.trim().is_empty() { - return system_prompt; - } - if system_prompt.trim().is_empty() { - return hint.trim().to_string(); - } - system_prompt.push('\n'); - system_prompt.push_str(hint.trim()); - system_prompt -} - -pub(super) fn inject_timestamp(system_prompt: String) -> String { - let timestamp = format_timestamp(); - let context = format!("[Context: Current time is {timestamp}]"); - if system_prompt.trim().is_empty() { - return context; - } - format!("{context}\n\n{system_prompt}") -} - -pub(super) fn extract_tool_choice_hint(object: &Map) -> Option { - let tool_choice = object.get("tool_choice")?; - if let Some(choice) = tool_choice.as_str() { - return match choice { - "none" => Some( - "[INSTRUCTION: Do NOT use any tools. Respond with text only.]".to_string(), - ), - "required" | "any" => Some("[INSTRUCTION: You MUST use at least one of the available tools to respond. Do not respond with text only - always make a tool call.]".to_string()), - "auto" => None, - _ => None, - }; - } - if let Some(choice) = tool_choice.as_object() { - let choice_type = choice.get("type").and_then(Value::as_str).unwrap_or(""); - let name = match choice_type { - "function" => choice - .get("function") - .and_then(Value::as_object) - .and_then(|function| function.get("name")) - .and_then(Value::as_str) - .or_else(|| choice.get("name").and_then(Value::as_str)) - .unwrap_or(""), - "tool" => choice.get("name").and_then(Value::as_str).unwrap_or(""), - "any" => "", - _ => "", - }; - if choice_type == "any" { - return Some("[INSTRUCTION: You MUST use at least one of the available tools to respond. Do not respond with text only - always make a tool call.]".to_string()); - } - if !name.trim().is_empty() { - return Some(format!("[INSTRUCTION: You MUST use the tool named '{name}' to respond. 
Do not use any other tool or respond with text only.]")); - } - } - None -} - -pub(super) fn extract_response_format_hint(object: &Map) -> Option { - let mut format_value = object.get("response_format"); - if format_value.is_none() { - format_value = object - .get("text") - .and_then(Value::as_object) - .and_then(|text| text.get("format")); - } - let format_value = format_value?; - let format_type = format_value.get("type").and_then(Value::as_str); - match format_type { - Some("json_object") => Some("[INSTRUCTION: You MUST respond with valid JSON only. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]".to_string()), - Some("json_schema") => { - let schema = format_value - .get("json_schema") - .and_then(Value::as_object) - .and_then(|schema| schema.get("schema")); - if let Some(schema) = schema { - let mut schema_str = schema.to_string(); - if schema_str.len() > 500 { - schema_str.truncate(500); - schema_str.push_str("..."); - } - return Some(format!("[INSTRUCTION: You MUST respond with valid JSON that matches this schema: {schema_str}. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]")); - } - Some("[INSTRUCTION: You MUST respond with valid JSON only. Do not include any text before or after the JSON. Do not wrap the JSON in markdown code blocks. Output raw JSON directly.]".to_string()) - } - Some("text") | _ => None, - } -} - -fn format_timestamp() -> String { - let format = time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second] UTC"); - if let Ok(format) = format { - if let Ok(value) = OffsetDateTime::now_utc().format(&format) { - return value; - } - } - OffsetDateTime::now_utc() - .format(&time::format_description::well_known::Rfc3339) - .unwrap_or_else(|_| "unknown".to_string()) -} - -fn thinking_enabled_from_header(headers: &HeaderMap) -> bool { - let beta = headers.get("anthropic-beta").or_else(|| headers.get("Anthropic-Beta")); - let Some(beta) = beta else { - return false; - }; - beta.to_str() - .ok() - .is_some_and(|value| value.contains("interleaved-thinking")) -} - -fn thinking_enabled_from_claude(object: &Map) -> bool { - let Some(thinking) = object.get("thinking").and_then(Value::as_object) else { - return false; - }; - if thinking.get("type").and_then(Value::as_str) != Some("enabled") { - return false; - } - if let Some(budget) = thinking.get("budget_tokens").and_then(Value::as_i64) { - return budget > 0; - } - true -} - -fn thinking_enabled_from_reasoning_effort(object: &Map) -> bool { - let Some(reasoning) = object.get("reasoning_effort").and_then(Value::as_str) else { - return false; - }; - !reasoning.trim().is_empty() && reasoning != "none" -} - -fn thinking_enabled_from_system_prompt(system_prompt: &str) -> bool { - extract_thinking_mode(system_prompt) - .is_some_and(|value| matches!(value.trim(), "interleaved" | "enabled")) -} - -fn thinking_enabled_from_model_hint(object: &Map) -> bool { - if object.get("max_completion_tokens").is_none() { - return false; - } - let Some(model) = object.get("model").and_then(Value::as_str) else { - return false; - }; - let lower = model.to_ascii_lowercase(); - lower.contains("thinking") || lower.contains("reason") -} - -pub(super) fn has_thinking_tags(system_prompt: &str) -> bool { - system_prompt.contains("") || system_prompt.contains("") -} - -fn extract_thinking_mode(system_prompt: &str) -> Option { - let start = system_prompt.find("")?; - let end = 
system_prompt.find("")?; - if end <= start { - return None; - } - let value_start = start + "".len(); - Some(system_prompt[value_start..end].to_string()) -} diff --git a/src-tauri/src/proxy/kiro/response.rs b/src-tauri/src/proxy/kiro/response.rs deleted file mode 100644 index 85d63fa..0000000 --- a/src-tauri/src/proxy/kiro/response.rs +++ /dev/null @@ -1,394 +0,0 @@ -use std::collections::HashSet; - -use serde_json::{Map, Value}; - -use super::event_stream::EventStreamDecoder; -use super::tool_parser::{ - deduplicate_tool_uses, parse_embedded_tool_calls, process_tool_use_event, ToolUseState, -}; -use super::types::KiroToolUse; - -#[derive(Clone, Debug, Default)] -pub(crate) struct KiroUsage { - pub(crate) input_tokens: Option, - pub(crate) output_tokens: Option, - pub(crate) total_tokens: Option, - pub(crate) context_usage_percentage: Option, -} - -#[derive(Clone, Debug)] -pub(crate) struct KiroParsedResponse { - pub(crate) content: String, - pub(crate) reasoning: String, - pub(crate) tool_uses: Vec, - pub(crate) usage: KiroUsage, - pub(crate) stop_reason: Option, -} - -pub(crate) fn parse_event_stream(bytes: &[u8]) -> Result { - let mut decoder = EventStreamDecoder::new(); - let messages = decoder - .push(bytes) - .map_err(|err| format!("EventStream parse error: {}", err.message))?; - - let mut content = String::new(); - let mut tool_uses: Vec = Vec::new(); - let mut reasoning = String::new(); - let mut usage = KiroUsage::default(); - let mut stop_reason: Option = None; - let mut processed_tool_keys: HashSet = HashSet::new(); - let mut tool_state: Option = None; - let mut saw_invalid_state = false; - - for message in messages { - if message.payload.is_empty() { - continue; - } - let Ok(event) = serde_json::from_slice::(&message.payload) else { - continue; - }; - let Some(event_obj) = event.as_object() else { - continue; - }; - - if let Some(error) = extract_error(event_obj) { - if error == "invalidStateEvent" { - saw_invalid_state = true; - } else { - return Err(error); - } - } - - update_stop_reason(event_obj, &mut stop_reason); - update_usage(event_obj, &mut usage); - - let event_type = if !message.event_type.is_empty() { - message.event_type.as_str() - } else { - detect_event_type(event_obj) - }; - - match event_type { - "followupPromptEvent" => {} - "assistantResponseEvent" => { - if let Some(Value::Object(assistant)) = event_obj.get("assistantResponseEvent") { - if let Some(text) = assistant.get("content").and_then(Value::as_str) { - content.push_str(text); - } - if let Some(tool_items) = assistant.get("toolUses").and_then(Value::as_array) { - extract_tool_uses(tool_items, &mut tool_uses, &mut processed_tool_keys); - } - update_stop_reason(assistant, &mut stop_reason); - } - if let Some(text) = event_obj.get("content").and_then(Value::as_str) { - content.push_str(text); - } - if let Some(tool_items) = event_obj.get("toolUses").and_then(Value::as_array) { - extract_tool_uses(tool_items, &mut tool_uses, &mut processed_tool_keys); - } - } - "toolUseEvent" => { - let (completed, next_state) = - process_tool_use_event(event_obj, tool_state.take(), &mut processed_tool_keys); - tool_uses.extend(completed); - tool_state = next_state; - } - "reasoningContentEvent" => { - if let Some(Value::Object(reasoning_event)) = event_obj.get("reasoningContentEvent") { - if let Some(text) = reasoning_event.get("thinkingText").and_then(Value::as_str) { - reasoning.push_str(text); - } - if let Some(text) = reasoning_event.get("text").and_then(Value::as_str) { - reasoning.push_str(text); - } - } - } - 
"messageStopEvent" | "message_stop" => { - update_stop_reason(event_obj, &mut stop_reason); - } - _ => {} - } - } - - if saw_invalid_state { - // Ignore invalidStateEvent and continue parsing. - } - - let (cleaned_content, extracted_reasoning) = extract_thinking_from_content(&content); - content = cleaned_content; - if !extracted_reasoning.trim().is_empty() { - if !reasoning.is_empty() && !reasoning.ends_with('\n') { - reasoning.push('\n'); - } - reasoning.push_str(extracted_reasoning.trim()); - } - - let (cleaned, embedded_tool_uses) = - parse_embedded_tool_calls(&content, &mut processed_tool_keys); - content = cleaned; - tool_uses.extend(embedded_tool_uses); - tool_uses = deduplicate_tool_uses(tool_uses); - - if stop_reason.is_none() { - if !tool_uses.is_empty() { - stop_reason = Some("tool_use".to_string()); - } else { - stop_reason = Some("end_turn".to_string()); - } - } - - Ok(KiroParsedResponse { - content, - reasoning, - tool_uses, - usage, - stop_reason, - }) -} - -fn extract_thinking_from_content(content: &str) -> (String, String) { - const START: &str = ""; - const END: &str = ""; - - if !content.contains(START) { - return (content.to_string(), String::new()); - } - - let mut cleaned = String::new(); - let mut reasoning = String::new(); - let mut remaining = content; - - loop { - let Some(start_idx) = remaining.find(START) else { - cleaned.push_str(remaining); - break; - }; - let (before, after_start) = remaining.split_at(start_idx); - cleaned.push_str(before); - let after_start = &after_start[START.len()..]; - - let Some(end_idx) = after_start.find(END) else { - reasoning.push_str(after_start); - break; - }; - let (think_block, rest) = after_start.split_at(end_idx); - reasoning.push_str(think_block); - remaining = &rest[END.len()..]; - } - - (cleaned, reasoning) -} - -fn detect_event_type(event: &Map) -> &str { - for key in [ - "assistantResponseEvent", - "toolUseEvent", - "reasoningContentEvent", - "messageStopEvent", - "message_stop", - "messageMetadataEvent", - "metadataEvent", - "usageEvent", - "usage", - "metricsEvent", - "meteringEvent", - "supplementaryWebLinksEvent", - "error", - "exception", - "internalServerException", - "invalidStateEvent", - ] { - if event.contains_key(key) { - return key; - } - } - "" -} - -fn extract_error(event: &Map) -> Option { - if let Some(Value::String(err_type)) = event.get("_type") { - let message = event - .get("message") - .and_then(Value::as_str) - .unwrap_or(""); - return Some(format!("Kiro error: {err_type} {message}")); - } - if let Some(Value::String(kind)) = event.get("type") { - if matches!( - kind.as_str(), - "error" | "exception" | "internalServerException" - ) { - let message = event - .get("message") - .and_then(Value::as_str) - .unwrap_or(""); - if message.is_empty() { - if let Some(Value::Object(err_obj)) = event.get("error") { - if let Some(text) = err_obj.get("message").and_then(Value::as_str) { - return Some(format!("Kiro error: {text}")); - } - } - } - return Some(format!("Kiro error: {message}")); - } - } - if event.contains_key("invalidStateEvent") - || event - .get("eventType") - .and_then(Value::as_str) - .is_some_and(|value| value == "invalidStateEvent") - { - return Some("invalidStateEvent".to_string()); - } - None -} - -fn update_stop_reason(event: &Map, stop_reason: &mut Option) { - if let Some(reason) = event.get("stop_reason").and_then(Value::as_str) { - *stop_reason = Some(reason.to_string()); - } - if let Some(reason) = event.get("stopReason").and_then(Value::as_str) { - *stop_reason = 
Some(reason.to_string()); - } -} - -fn update_usage(event: &Map, usage: &mut KiroUsage) { - if let Some(context_pct) = event.get("contextUsagePercentage").and_then(Value::as_f64) { - usage.context_usage_percentage = Some(context_pct); - } - if let Some(tokens) = event.get("inputTokens").and_then(Value::as_u64) { - usage.input_tokens = Some(tokens); - } - if let Some(tokens) = event.get("outputTokens").and_then(Value::as_u64) { - usage.output_tokens = Some(tokens); - } - if let Some(tokens) = event.get("totalTokens").and_then(Value::as_u64) { - usage.total_tokens = Some(tokens); - } - - if let Some(metadata) = event.get("messageMetadataEvent").and_then(Value::as_object) { - update_usage_from_metadata(metadata, usage); - } else if let Some(metadata) = event.get("metadataEvent").and_then(Value::as_object) { - update_usage_from_metadata(metadata, usage); - } - - if let Some(usage_obj) = event.get("usage").and_then(Value::as_object) { - update_usage_from_usage_obj(usage_obj, usage); - } - if let Some(usage_obj) = event.get("usageEvent").and_then(Value::as_object) { - update_usage_from_usage_obj(usage_obj, usage); - } - - if let Some(links) = event.get("supplementaryWebLinksEvent").and_then(Value::as_object) { - if let Some(tokens) = links.get("inputTokens").and_then(Value::as_u64) { - usage.input_tokens = Some(tokens); - } - if let Some(tokens) = links.get("outputTokens").and_then(Value::as_u64) { - usage.output_tokens = Some(tokens); - } - } - - if let Some(metrics) = event.get("metricsEvent").and_then(Value::as_object) { - if let Some(tokens) = metrics.get("inputTokens").and_then(Value::as_u64) { - usage.input_tokens = Some(tokens); - } - if let Some(tokens) = metrics.get("outputTokens").and_then(Value::as_u64) { - usage.output_tokens = Some(tokens); - } - } -} - -fn update_usage_from_metadata(metadata: &Map, usage: &mut KiroUsage) { - if let Some(token_usage) = metadata.get("tokenUsage").and_then(Value::as_object) { - if let Some(tokens) = token_usage.get("outputTokens").and_then(Value::as_u64) { - usage.output_tokens = Some(tokens); - } - if let Some(tokens) = token_usage.get("totalTokens").and_then(Value::as_u64) { - usage.total_tokens = Some(tokens); - } - if let Some(tokens) = token_usage.get("uncachedInputTokens").and_then(Value::as_u64) { - usage.input_tokens = Some(tokens); - } - if let Some(tokens) = token_usage.get("cacheReadInputTokens").and_then(Value::as_u64) { - let current = usage.input_tokens.unwrap_or(0); - usage.input_tokens = Some(current + tokens); - } - if let Some(context_pct) = token_usage - .get("contextUsagePercentage") - .and_then(Value::as_f64) - { - usage.context_usage_percentage = Some(context_pct); - } - } - - if usage.input_tokens.is_none() { - if let Some(tokens) = metadata.get("inputTokens").and_then(Value::as_u64) { - usage.input_tokens = Some(tokens); - } - } - if usage.output_tokens.is_none() { - if let Some(tokens) = metadata.get("outputTokens").and_then(Value::as_u64) { - usage.output_tokens = Some(tokens); - } - } - if usage.total_tokens.is_none() { - if let Some(tokens) = metadata.get("totalTokens").and_then(Value::as_u64) { - usage.total_tokens = Some(tokens); - } - } -} - -fn update_usage_from_usage_obj(usage_obj: &Map, usage: &mut KiroUsage) { - let input_tokens = usage_obj - .get("input_tokens") - .or_else(|| usage_obj.get("prompt_tokens")) - .and_then(Value::as_u64); - let output_tokens = usage_obj - .get("output_tokens") - .or_else(|| usage_obj.get("completion_tokens")) - .and_then(Value::as_u64); - let total_tokens = 
usage_obj.get("total_tokens").and_then(Value::as_u64); - - if input_tokens.is_some() { - usage.input_tokens = input_tokens; - } - if output_tokens.is_some() { - usage.output_tokens = output_tokens; - } - if total_tokens.is_some() { - usage.total_tokens = total_tokens; - } -} - -fn extract_tool_uses( - tool_items: &[Value], - output: &mut Vec, - processed: &mut HashSet, -) { - for item in tool_items { - let Some(tool) = item.as_object() else { - continue; - }; - let tool_use_id = tool - .get("toolUseId") - .or_else(|| tool.get("tool_use_id")) - .and_then(Value::as_str) - .unwrap_or(""); - let dedupe_key = format!("id:{tool_use_id}"); - if tool_use_id.is_empty() || processed.contains(&dedupe_key) { - continue; - } - let name = tool.get("name").and_then(Value::as_str).unwrap_or(""); - let input = tool - .get("input") - .and_then(Value::as_object) - .cloned() - .unwrap_or_default(); - processed.insert(dedupe_key); - output.push(KiroToolUse { - tool_use_id: tool_use_id.to_string(), - name: name.to_string(), - input, - }); - } -} diff --git a/src-tauri/src/proxy/kiro/tool_parser.rs b/src-tauri/src/proxy/kiro/tool_parser.rs deleted file mode 100644 index 3f7dc21..0000000 --- a/src-tauri/src/proxy/kiro/tool_parser.rs +++ /dev/null @@ -1,403 +0,0 @@ -use serde_json::{Map, Value}; -use std::collections::HashSet; - -use super::types::KiroToolUse; -use super::utils::random_uuid; - -pub(crate) struct ToolUseState { - id: String, - name: String, - input_buffer: String, -} - -pub(crate) fn parse_embedded_tool_calls( - text: &str, - processed: &mut HashSet, -) -> (String, Vec) { - if !text.contains("[Called") { - return (text.to_string(), Vec::new()); - } - - let mut matches = Vec::new(); - let mut cursor = 0usize; - while let Some(offset) = text[cursor..].find("[Called") { - let start = cursor + offset; - let mut idx = start + "[Called".len(); - idx = skip_whitespace(text, idx); - - let name_start = idx; - while idx < text.len() && is_tool_name_char(text.as_bytes()[idx]) { - idx += 1; - } - if name_start == idx { - cursor = start + 1; - continue; - } - let tool_name = &text[name_start..idx]; - - idx = skip_whitespace(text, idx); - if !text[idx..].starts_with("with") { - cursor = start + 1; - continue; - } - idx += "with".len(); - - idx = skip_whitespace(text, idx); - if !text[idx..].starts_with("args:") { - cursor = start + 1; - continue; - } - idx += "args:".len(); - idx = skip_whitespace(text, idx); - - if idx >= text.len() || text.as_bytes()[idx] != b'{' { - cursor = start + 1; - continue; - } - let json_start = idx; - let json_end = match find_matching_bracket(text, json_start) { - Some(end) => end, - None => { - cursor = start + 1; - continue; - } - }; - - let mut closing = json_end + 1; - while closing < text.len() && text.as_bytes()[closing] != b']' { - closing += 1; - } - if closing >= text.len() { - cursor = start + 1; - continue; - } - let match_end = closing + 1; - matches.push(( - start, - match_end, - tool_name.to_string(), - text[json_start..=json_end].to_string(), - )); - - cursor = match_end; - } - - if matches.is_empty() { - return (text.to_string(), Vec::new()); - } - - let mut clean_text = text.to_string(); - let mut tool_uses = Vec::new(); - for (start, end, name, json_str) in matches.into_iter().rev() { - let repaired = repair_json(&json_str); - let input = match serde_json::from_str::>(&repaired) { - Ok(map) => map, - Err(_) => continue, - }; - - let dedupe_key = format!("content:{name}:{repaired}"); - if processed.contains(&dedupe_key) { - if clean_text.is_char_boundary(start) && 
clean_text.is_char_boundary(end) { - clean_text.replace_range(start..end, ""); - } - continue; - } - processed.insert(dedupe_key); - - let tool_use_id = generate_tool_use_id(); - tool_uses.push(KiroToolUse { - tool_use_id, - name, - input, - }); - - if clean_text.is_char_boundary(start) && clean_text.is_char_boundary(end) { - clean_text.replace_range(start..end, ""); - } - } - - (clean_text, tool_uses) -} - -pub(crate) fn process_tool_use_event( - event: &Map, - current: Option, - processed: &mut HashSet, -) -> (Vec, Option) { - let mut tool_uses = Vec::new(); - let mut state = current; - - let source = event - .get("toolUseEvent") - .and_then(Value::as_object) - .unwrap_or(event); - - let tool_use_id = tool_use_id(source); - let name = source.get("name").and_then(Value::as_str).unwrap_or(""); - let stop = source.get("stop").and_then(Value::as_bool).unwrap_or(false); - - if let (Some(tool_use_id), true) = (tool_use_id, !name.is_empty()) { - let dedupe_key = format!("id:{tool_use_id}"); - if let Some(current_state) = &state { - if current_state.id != tool_use_id { - if !processed.contains(&format!("id:{}", current_state.id)) { - let input = parse_tool_input(¤t_state.input_buffer); - tool_uses.push(KiroToolUse { - tool_use_id: current_state.id.clone(), - name: current_state.name.clone(), - input, - }); - processed.insert(format!("id:{}", current_state.id)); - } - state = None; - } - } - - if state.is_none() && !processed.contains(&dedupe_key) { - state = Some(ToolUseState { - id: tool_use_id.to_string(), - name: name.to_string(), - input_buffer: String::new(), - }); - } - } - - if let Some(current_state) = &mut state { - if let Some(Value::String(fragment)) = source.get("input") { - current_state.input_buffer.push_str(fragment); - } else if let Some(Value::Object(input)) = source.get("input") { - let serialized = serde_json::to_string(input).unwrap_or_default(); - current_state.input_buffer = serialized; - } - } - - if stop { - if let Some(current_state) = state.take() { - let input = parse_tool_input(¤t_state.input_buffer); - let dedupe_key = format!("id:{}", current_state.id); - if !processed.contains(&dedupe_key) { - processed.insert(dedupe_key); - tool_uses.push(KiroToolUse { - tool_use_id: current_state.id, - name: current_state.name, - input, - }); - } - } - } - - (tool_uses, state) -} - -pub(crate) fn deduplicate_tool_uses(tool_uses: Vec) -> Vec { - let mut seen_ids = HashSet::new(); - let mut seen_content = HashSet::new(); - let mut output = Vec::new(); - - for tool_use in tool_uses { - if !seen_ids.insert(tool_use.tool_use_id.clone()) { - continue; - } - let input_json = serde_json::to_string(&tool_use.input).unwrap_or_default(); - let content_key = format!("{}:{}", tool_use.name, input_json); - if !seen_content.insert(content_key) { - continue; - } - output.push(tool_use); - } - output -} - -fn tool_use_id(source: &Map) -> Option<&str> { - source - .get("toolUseId") - .or_else(|| source.get("tool_use_id")) - .and_then(Value::as_str) -} - -fn parse_tool_input(raw: &str) -> Map { - if raw.trim().is_empty() { - return Map::new(); - } - let repaired = repair_json(raw); - serde_json::from_str::>(&repaired).unwrap_or_default() -} - -fn generate_tool_use_id() -> String { - let raw = random_uuid().replace('-', ""); - let suffix = if raw.len() >= 12 { &raw[..12] } else { raw.as_str() }; - format!("toolu_{suffix}") -} - -fn skip_whitespace(text: &str, mut idx: usize) -> usize { - while idx < text.len() && text.as_bytes()[idx].is_ascii_whitespace() { - idx += 1; - } - idx -} - -fn 
is_tool_name_char(byte: u8) -> bool { - matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'.' | b'-') -} - -fn find_matching_bracket(text: &str, start: usize) -> Option { - let bytes = text.as_bytes(); - let open = *bytes.get(start)?; - let close = match open { - b'{' => b'}', - b'[' => b']', - _ => return None, - }; - let mut depth = 1usize; - let mut in_string = false; - let mut escape_next = false; - - for idx in (start + 1)..bytes.len() { - let ch = bytes[idx]; - if escape_next { - escape_next = false; - continue; - } - if ch == b'\\' && in_string { - escape_next = true; - continue; - } - if ch == b'"' { - in_string = !in_string; - continue; - } - if in_string { - continue; - } - if ch == open { - depth += 1; - } else if ch == close { - depth = depth.saturating_sub(1); - if depth == 0 { - return Some(idx); - } - } - } - None -} - -fn repair_json(raw: &str) -> String { - let mut value = raw.trim().to_string(); - if value.is_empty() { - return "{}".to_string(); - } - - if serde_json::from_str::(&value).is_ok() { - return value; - } - - let original = value.clone(); - value = escape_newlines_in_strings(&value); - value = remove_trailing_commas(&value); - - let mut brace_count = 0i32; - let mut bracket_count = 0i32; - let mut in_string = false; - let mut escape_next = false; - let mut last_valid_index: Option = None; - - for (idx, ch) in value.bytes().enumerate() { - if escape_next { - escape_next = false; - continue; - } - if ch == b'\\' { - escape_next = true; - continue; - } - if ch == b'"' { - in_string = !in_string; - continue; - } - if in_string { - continue; - } - match ch { - b'{' => brace_count += 1, - b'}' => brace_count -= 1, - b'[' => bracket_count += 1, - b']' => bracket_count -= 1, - _ => {} - } - if brace_count >= 0 && bracket_count >= 0 { - last_valid_index = Some(idx); - } - } - - if brace_count > 0 || bracket_count > 0 { - if let Some(last) = last_valid_index { - if last + 1 < value.len() { - value.truncate(last + 1); - } - } - while brace_count > 0 { - value.push('}'); - brace_count -= 1; - } - while bracket_count > 0 { - value.push(']'); - bracket_count -= 1; - } - } - - if serde_json::from_str::(&value).is_ok() { - value - } else { - original - } -} - -fn escape_newlines_in_strings(raw: &str) -> String { - let mut out = String::with_capacity(raw.len() + 64); - let mut in_string = false; - let mut escaped = false; - for ch in raw.chars() { - if escaped { - out.push(ch); - escaped = false; - continue; - } - if ch == '\\' && in_string { - out.push(ch); - escaped = true; - continue; - } - if ch == '"' { - in_string = !in_string; - out.push(ch); - continue; - } - if in_string { - match ch { - '\n' => out.push_str("\\n"), - '\r' => out.push_str("\\r"), - '\t' => out.push_str("\\t"), - _ => out.push(ch), - } - } else { - out.push(ch); - } - } - out -} - -fn remove_trailing_commas(raw: &str) -> String { - let mut out = String::with_capacity(raw.len()); - let mut iter = raw.chars().peekable(); - while let Some(ch) = iter.next() { - if ch == ',' { - if let Some(next) = iter.peek() { - if *next == '}' || *next == ']' { - continue; - } - } - } - out.push(ch); - } - out -} diff --git a/src-tauri/src/proxy/kiro/tools.rs b/src-tauri/src/proxy/kiro/tools.rs deleted file mode 100644 index 0b2a689..0000000 --- a/src-tauri/src/proxy/kiro/tools.rs +++ /dev/null @@ -1,219 +0,0 @@ -use serde_json::{Map, Value}; - -use super::types::{KiroInputSchema, KiroToolSpecification, KiroToolWrapper}; - -const KIRO_MAX_TOOL_DESC_LEN: usize = 10237; -const TOOL_COMPRESSION_TARGET_SIZE: 
usize = 20 * 1024; -const MIN_TOOL_DESCRIPTION_LENGTH: usize = 50; - -pub(crate) fn convert_openai_tools(tools: Option<&Value>, is_chat_only: bool) -> Vec { - if is_chat_only { - return Vec::new(); - } - let Some(Value::Array(items)) = tools else { - return Vec::new(); - }; - - let mut output = Vec::new(); - for item in items { - let Some(tool) = item.as_object() else { - continue; - }; - let tool_type = tool.get("type").and_then(Value::as_str).unwrap_or("function"); - if tool_type != "function" { - continue; - } - let (name, description, parameters) = match extract_tool_fields(tool) { - Some(fields) => fields, - None => continue, - }; - if name.is_empty() { - continue; - } - let name = shorten_tool_name(name); - let mut description = description.to_string(); - if description.trim().is_empty() { - description = format!("Tool: {name}"); - } - if description.len() > KIRO_MAX_TOOL_DESC_LEN { - description = truncate_utf8(&description, KIRO_MAX_TOOL_DESC_LEN - 30) - + "... (description truncated)"; - } - let parameters = parameters.unwrap_or_else(|| Value::Object(Map::new())); - - output.push(KiroToolWrapper { - tool_specification: KiroToolSpecification { - name: name.to_string(), - description, - input_schema: KiroInputSchema { json: parameters }, - }, - }); - } - - compress_tools_if_needed(output) -} - -fn extract_tool_fields(tool: &Map) -> Option<(&str, &str, Option)> { - if let Some(function) = tool.get("function").and_then(Value::as_object) { - let name = function.get("name").and_then(Value::as_str).unwrap_or(""); - let description = function - .get("description") - .and_then(Value::as_str) - .unwrap_or(""); - let parameters = function.get("parameters").cloned(); - return Some((name, description, parameters)); - } - - let name = tool.get("name").and_then(Value::as_str).unwrap_or(""); - let description = tool - .get("description") - .and_then(Value::as_str) - .unwrap_or(""); - let parameters = tool - .get("parameters") - .cloned() - .or_else(|| tool.get("input_schema").cloned()); - Some((name, description, parameters)) -} - -fn shorten_tool_name(name: &str) -> String { - const LIMIT: usize = 64; - if name.len() <= LIMIT { - return name.to_string(); - } - if let Some(stripped) = name.strip_prefix("mcp__") { - if let Some(idx) = stripped.rfind("__") { - let suffix = &stripped[idx + 2..]; - let candidate = format!("mcp__{suffix}"); - if candidate.len() <= LIMIT { - return candidate; - } - return candidate.chars().take(LIMIT).collect(); - } - } - name.chars().take(LIMIT).collect() -} - -fn truncate_utf8(value: &str, max_len: usize) -> String { - if value.len() <= max_len { - return value.to_string(); - } - let mut end = max_len; - while !value.is_char_boundary(end) { - end -= 1; - } - value[..end].to_string() -} - -fn compress_tools_if_needed(tools: Vec) -> Vec { - if tools.is_empty() { - return tools; - } - let original_size = calculate_tools_size(&tools); - if original_size <= TOOL_COMPRESSION_TARGET_SIZE { - return tools; - } - - let mut compressed = tools - .into_iter() - .map(|tool| KiroToolWrapper { - tool_specification: KiroToolSpecification { - name: tool.tool_specification.name, - description: tool.tool_specification.description, - input_schema: KiroInputSchema { - json: tool.tool_specification.input_schema.json, - }, - }, - }) - .collect::>(); - - for tool in &mut compressed { - tool.tool_specification.input_schema.json = - simplify_schema(&tool.tool_specification.input_schema.json); - } - - let size_after_schema = calculate_tools_size(&compressed); - if size_after_schema <= 
TOOL_COMPRESSION_TARGET_SIZE { - return compressed; - } - - let size_to_reduce = (size_after_schema - TOOL_COMPRESSION_TARGET_SIZE) as f64; - let total_desc_len: f64 = compressed - .iter() - .map(|tool| tool.tool_specification.description.len() as f64) - .sum(); - - if total_desc_len > 0.0 { - let mut keep_ratio = 1.0 - (size_to_reduce / total_desc_len); - if keep_ratio > 1.0 { - keep_ratio = 1.0; - } - if keep_ratio < 0.0 { - keep_ratio = 0.0; - } - for tool in &mut compressed { - let desc = tool.tool_specification.description.clone(); - let target_len = (desc.len() as f64 * keep_ratio) as usize; - tool.tool_specification.description = compress_description(&desc, target_len); - } - } - - compressed -} - -fn compress_description(description: &str, target_len: usize) -> String { - let mut target = target_len; - if target < MIN_TOOL_DESCRIPTION_LENGTH { - target = MIN_TOOL_DESCRIPTION_LENGTH; - } - if description.len() <= target { - return description.to_string(); - } - let trimmed = truncate_utf8(description, target.saturating_sub(3)); - format!("{trimmed}...") -} - -fn calculate_tools_size(tools: &[KiroToolWrapper]) -> usize { - serde_json::to_vec(tools).map(|data| data.len()).unwrap_or(0) -} - -fn simplify_schema(value: &Value) -> Value { - let Some(object) = value.as_object() else { - return value.clone(); - }; - let mut simplified = Map::new(); - - for key in ["type", "enum", "required"] { - if let Some(val) = object.get(key) { - simplified.insert(key.to_string(), val.clone()); - } - } - - if let Some(properties) = object.get("properties").and_then(Value::as_object) { - let mut simplified_props = Map::new(); - for (key, val) in properties { - simplified_props.insert(key.clone(), simplify_schema(val)); - } - simplified.insert("properties".to_string(), Value::Object(simplified_props)); - } - - if let Some(items) = object.get("items") { - simplified.insert("items".to_string(), simplify_schema(items)); - } - - if let Some(additional) = object.get("additionalProperties") { - simplified.insert( - "additionalProperties".to_string(), - simplify_schema(additional), - ); - } - - for key in ["anyOf", "oneOf", "allOf"] { - if let Some(Value::Array(values)) = object.get(key) { - let simplified_values = values.iter().map(simplify_schema).collect::>(); - simplified.insert(key.to_string(), Value::Array(simplified_values)); - } - } - - Value::Object(simplified) -} diff --git a/src-tauri/src/proxy/kiro/types.rs b/src-tauri/src/proxy/kiro/types.rs deleted file mode 100644 index b01a3ce..0000000 --- a/src-tauri/src/proxy/kiro/types.rs +++ /dev/null @@ -1,132 +0,0 @@ -use serde::Serialize; -use serde_json::{Map, Value}; - -#[derive(Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct KiroPayload { - pub(crate) conversation_state: KiroConversationState, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) profile_arn: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) inference_config: Option, -} - -#[derive(Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct KiroInferenceConfig { - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) max_tokens: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) top_p: Option, -} - -#[derive(Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct KiroConversationState { - pub(crate) chat_trigger_type: String, - pub(crate) conversation_id: String, - pub(crate) 
current_message: KiroCurrentMessage, - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub(crate) history: Vec, -} - -#[derive(Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct KiroCurrentMessage { - pub(crate) user_input_message: KiroUserInputMessage, -} - -#[derive(Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct KiroHistoryMessage { - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) user_input_message: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) assistant_response_message: Option, -} - -#[derive(Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct KiroImage { - pub(crate) format: String, - pub(crate) source: KiroImageSource, -} - -#[derive(Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct KiroImageSource { - pub(crate) bytes: String, -} - -#[derive(Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct KiroUserInputMessage { - pub(crate) content: String, - pub(crate) model_id: String, - pub(crate) origin: String, - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub(crate) images: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub(crate) user_input_message_context: Option, -} - -#[derive(Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct KiroUserInputMessageContext { - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub(crate) tool_results: Vec, - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub(crate) tools: Vec, -} - -#[derive(Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct KiroToolResult { - pub(crate) content: Vec, - pub(crate) status: String, - pub(crate) tool_use_id: String, -} - -#[derive(Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct KiroTextContent { - pub(crate) text: String, -} - -#[derive(Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct KiroToolWrapper { - pub(crate) tool_specification: KiroToolSpecification, -} - -#[derive(Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct KiroToolSpecification { - pub(crate) name: String, - pub(crate) description: String, - pub(crate) input_schema: KiroInputSchema, -} - -#[derive(Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct KiroInputSchema { - pub(crate) json: Value, -} - -#[derive(Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct KiroAssistantResponseMessage { - pub(crate) content: String, - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub(crate) tool_uses: Vec, -} - -#[derive(Clone, Debug, Serialize)] -#[serde(rename_all = "camelCase")] -pub(crate) struct KiroToolUse { - pub(crate) tool_use_id: String, - pub(crate) name: String, - pub(crate) input: Map, -} diff --git a/src-tauri/src/proxy/kiro/utils.rs b/src-tauri/src/proxy/kiro/utils.rs deleted file mode 100644 index 4e19e75..0000000 --- a/src-tauri/src/proxy/kiro/utils.rs +++ /dev/null @@ -1,25 +0,0 @@ -use rand::RngCore; - -pub(crate) fn random_uuid() -> String { - let mut bytes = [0u8; 16]; - rand::rng().fill_bytes(&mut bytes); - format!( - "{:02x}{:02x}{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}", - bytes[0], - bytes[1], - bytes[2], - bytes[3], - bytes[4], - bytes[5], - bytes[6], - bytes[7], - bytes[8], - bytes[9], - bytes[10], - bytes[11], - bytes[12], - bytes[13], - bytes[14], - bytes[15] - ) -} diff --git a/src-tauri/src/proxy/log.rs 
b/src-tauri/src/proxy/log.rs deleted file mode 100644 index 2aee033..0000000 --- a/src-tauri/src/proxy/log.rs +++ /dev/null @@ -1,195 +0,0 @@ -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use sqlx::SqlitePool; -use std::{ - sync::Arc, - time::{Instant, SystemTime, UNIX_EPOCH}, -}; - -#[cfg(debug_assertions)] -macro_rules! debug_log_error { - ($($arg:tt)*) => { - eprintln!($($arg)*); - }; -} - -#[cfg(not(debug_assertions))] -macro_rules! debug_log_error { - ($($arg:tt)*) => {}; -} - -#[derive(Clone, Serialize, Deserialize)] -pub(crate) struct TokenUsage { - pub(crate) input_tokens: Option, - pub(crate) output_tokens: Option, - pub(crate) total_tokens: Option, -} - -#[derive(Clone, Serialize, Deserialize)] -pub(crate) struct UsageSnapshot { - pub(crate) usage: Option, - pub(crate) cached_tokens: Option, - pub(crate) usage_json: Option, -} - -#[derive(Serialize, Deserialize)] -pub(crate) struct LogEntry { - pub(crate) ts_ms: u128, - pub(crate) path: String, - pub(crate) provider: String, - pub(crate) upstream_id: String, - pub(crate) model: Option, - pub(crate) mapped_model: Option, - pub(crate) stream: bool, - pub(crate) status: u16, - pub(crate) usage: Option, - pub(crate) cached_tokens: Option, - pub(crate) usage_json: Option, - pub(crate) upstream_request_id: Option, - pub(crate) request_headers: Option, - pub(crate) request_body: Option, - pub(crate) response_error: Option, - pub(crate) latency_ms: u128, -} - -#[derive(Clone)] -pub(crate) struct LogContext { - pub(crate) path: String, - pub(crate) provider: String, - pub(crate) upstream_id: String, - pub(crate) model: Option, - pub(crate) mapped_model: Option, - pub(crate) stream: bool, - pub(crate) status: u16, - pub(crate) upstream_request_id: Option, - pub(crate) request_headers: Option, - pub(crate) request_body: Option, - // Time-to-first-byte (TTFB) measured from `start`. - // For streaming responses, this is recorded when we receive the first upstream chunk. - pub(crate) ttfb_ms: Option, - pub(crate) start: Instant, -} - -pub(crate) struct LogWriter { - sqlite: Option, -} - -impl LogWriter { - pub(crate) fn new(sqlite: Option) -> Self { - Self { sqlite } - } - - // Fire-and-forget logging to avoid blocking the request path. 
-    pub(crate) fn write_detached(self: Arc<Self>, entry: LogEntry) {
-        tokio::spawn(async move {
-            self.write(&entry).await;
-        });
-    }
-
-    pub(crate) async fn write(&self, entry: &LogEntry) {
-        let Some(pool) = self.sqlite.as_ref() else {
-            return;
-        };
-        if let Err(_err) = insert_log_entry(pool, entry).await {
-            debug_log_error!("proxy sqlite write failed: {_err}");
-        }
-    }
-}
-
-pub(crate) fn build_log_entry(
-    context: &LogContext,
-    usage: UsageSnapshot,
-    response_error: Option<String>,
-) -> LogEntry {
-    LogEntry {
-        ts_ms: now_ms(),
-        path: context.path.clone(),
-        provider: context.provider.clone(),
-        upstream_id: context.upstream_id.clone(),
-        model: context.model.clone(),
-        mapped_model: context.mapped_model.clone(),
-        stream: context.stream,
-        status: context.status,
-        usage: usage.usage,
-        cached_tokens: usage.cached_tokens,
-        usage_json: usage.usage_json,
-        upstream_request_id: context.upstream_request_id.clone(),
-        request_headers: context.request_headers.clone(),
-        request_body: context.request_body.clone(),
-        response_error,
-        latency_ms: context
-            .ttfb_ms
-            .unwrap_or_else(|| context.start.elapsed().as_millis()),
-    }
-}
-
-fn now_ms() -> u128 {
-    SystemTime::now()
-        .duration_since(UNIX_EPOCH)
-        .unwrap_or_default()
-        .as_millis()
-}
-
-async fn insert_log_entry(pool: &SqlitePool, entry: &LogEntry) -> Result<(), sqlx::Error> {
-    let usage = entry.usage.as_ref();
-    let input_tokens = usage.and_then(|usage| usage.input_tokens).map(to_i64_u64);
-    let output_tokens = usage.and_then(|usage| usage.output_tokens).map(to_i64_u64);
-    let total_tokens = usage.and_then(|usage| usage.total_tokens).map(to_i64_u64);
-    let cached_tokens = entry.cached_tokens.map(to_i64_u64);
-    let usage_json = entry.usage_json.as_ref().map(Value::to_string);
-
-    sqlx::query(
-        r#"
-INSERT INTO request_logs (
-    ts_ms,
-    path,
-    provider,
-    upstream_id,
-    model,
-    mapped_model,
-    stream,
-    status,
-    input_tokens,
-    output_tokens,
-    total_tokens,
-    cached_tokens,
-    usage_json,
-    upstream_request_id,
-    request_headers,
-    request_body,
-    response_error,
-    latency_ms
-) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
-"#,
-    )
-    .bind(to_i64_u128(entry.ts_ms))
-    .bind(entry.path.as_str())
-    .bind(entry.provider.as_str())
-    .bind(entry.upstream_id.as_str())
-    .bind(entry.model.as_deref())
-    .bind(entry.mapped_model.as_deref())
-    .bind(entry.stream)
-    .bind(i64::from(entry.status))
-    .bind(input_tokens)
-    .bind(output_tokens)
-    .bind(total_tokens)
-    .bind(cached_tokens)
-    .bind(usage_json.as_deref())
-    .bind(entry.upstream_request_id.as_deref())
-    .bind(entry.request_headers.as_deref())
-    .bind(entry.request_body.as_deref())
-    .bind(entry.response_error.as_deref())
-    .bind(to_i64_u128(entry.latency_ms))
-    .execute(pool)
-    .await?;
-
-    Ok(())
-}
-
-fn to_i64_u128(value: u128) -> i64 {
-    value.min(i64::MAX as u128) as i64
-}
-
-fn to_i64_u64(value: u64) -> i64 {
-    value.min(i64::MAX as u64) as i64
-}
diff --git a/src-tauri/src/proxy/logs.rs b/src-tauri/src/proxy/logs.rs
deleted file mode 100644
index aa20847..0000000
--- a/src-tauri/src/proxy/logs.rs
+++ /dev/null
@@ -1,53 +0,0 @@
-use serde::Serialize;
-use sqlx::Row;
-use tauri::AppHandle;
-
-use super::sqlite;
-
-#[derive(Debug, Clone, Serialize)]
-#[serde(rename_all = "camelCase")]
-pub(crate) struct RequestLogDetail {
-    pub(crate) id: u64,
-    pub(crate) request_headers: Option<String>,
-    pub(crate) request_body: Option<String>,
-    pub(crate) response_error: Option<String>,
-}
-
-pub(crate) async fn read_request_log_detail(
-    app: AppHandle,
-    id: u64,
-) -> Result<RequestLogDetail, String> {
-    let pool = sqlite::open_read_pool(&app).await?;
-    let row = sqlx::query(
-        r#"
-SELECT
-    id,
-    request_headers,
-    request_body,
-    response_error
-FROM request_logs
-WHERE id = ?
-LIMIT 1;
-"#,
-    )
-    .bind(id as i64)
-    .fetch_optional(&pool)
-    .await
-    .map_err(|err| format!("Failed to query request log detail: {err}"))?;
-
-    let Some(row) = row else {
-        return Err("Request log not found.".to_string());
-    };
-
-    let id = row.try_get::<i64, _>("id").unwrap_or_default();
-    let request_headers = row.try_get::<Option<String>, _>("request_headers").ok().flatten();
-    let request_body = row.try_get::<Option<String>, _>("request_body").ok().flatten();
-    let response_error = row.try_get::<Option<String>, _>("response_error").ok().flatten();
-
-    Ok(RequestLogDetail {
-        id: id.max(0) as u64,
-        request_headers,
-        request_body,
-        response_error,
-    })
-}
diff --git a/src-tauri/src/proxy/model.rs b/src-tauri/src/proxy/model.rs
deleted file mode 100644
index 670f4ef..0000000
--- a/src-tauri/src/proxy/model.rs
+++ /dev/null
@@ -1,29 +0,0 @@
-use axum::body::Bytes;
-use serde_json::Value;
-
-pub(crate) fn rewrite_request_model(bytes: &Bytes, model: &str) -> Option<Bytes> {
-    let mut value: Value = serde_json::from_slice(bytes).ok()?;
-    let object = value.as_object_mut()?;
-    if !object.contains_key("model") {
-        return None;
-    }
-    object.insert("model".to_string(), Value::String(model.to_string()));
-    serde_json::to_vec(&value).ok().map(Bytes::from)
-}
-
-pub(crate) fn rewrite_response_model(bytes: &Bytes, model: &str) -> Option<Bytes> {
-    let mut value: Value = serde_json::from_slice(bytes).ok()?;
-    let object = value.as_object_mut()?;
-    if object.contains_key("model") {
-        object.insert("model".to_string(), Value::String(model.to_string()));
-        return serde_json::to_vec(&value).ok().map(Bytes::from);
-    }
-    let Some(response) = object.get_mut("response").and_then(Value::as_object_mut) else {
-        return None;
-    };
-    if !response.contains_key("model") {
-        return None;
-    }
-    response.insert("model".to_string(), Value::String(model.to_string()));
-    serde_json::to_vec(&value).ok().map(Bytes::from)
-}
diff --git a/src-tauri/src/proxy/openai_compat.rs b/src-tauri/src/proxy/openai_compat.rs
deleted file mode 100644
index 4586596..0000000
--- a/src-tauri/src/proxy/openai_compat.rs
+++ /dev/null
@@ -1,554 +0,0 @@
-use axum::body::Bytes;
-use serde_json::{json, Map, Value};
-
-use super::{
-    anthropic_compat,
-    compat_content,
-    compat_reason,
-    codex_compat,
-    gemini_compat,
-    http_client::ProxyHttpClients,
-};
-
-mod extract;
-mod input;
-mod message;
-mod tools;
-mod usage;
-
-pub(crate) const CHAT_PATH: &str = "/v1/chat/completions";
-pub(crate) const RESPONSES_PATH: &str = "/v1/responses";
-
-pub(crate) const PROVIDER_CHAT: &str = "openai";
-pub(crate) const PROVIDER_RESPONSES: &str = "openai-response";
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
-pub(crate) enum ApiFormat {
-    ChatCompletions,
-    Responses,
-}
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
-pub(crate) enum FormatTransform {
-    None,
-    ChatToResponses,
-    ResponsesToChat,
-    ResponsesToAnthropic,
-    AnthropicToResponses,
-    ChatToAnthropic,
-    AnthropicToChat,
-    GeminiToAnthropic,
-    AnthropicToGemini,
-    ChatToGemini,
-    GeminiToChat,
-    ResponsesToGemini,
-    GeminiToResponses,
-    KiroToResponses,
-    KiroToChat,
-    KiroToAnthropic,
-    ChatToCodex,
-    ResponsesToCodex,
-    CodexToChat,
-    CodexToResponses,
-}
-
-pub(crate) fn inbound_format(path: &str) -> Option<ApiFormat> {
-    match path {
-        CHAT_PATH => Some(ApiFormat::ChatCompletions),
-        RESPONSES_PATH => Some(ApiFormat::Responses),
-        _ => None,
-    }
-}
-
-pub(crate) async fn transform_request_body(
-    transform: FormatTransform,
-    body: &Bytes,
-    http_clients: &ProxyHttpClients,
-    model_hint: Option<&str>,
-) -> Result<Bytes, String> {
-    match transform {
-        FormatTransform::None => Ok(body.clone()),
-        FormatTransform::ChatToResponses => chat_request_to_responses(body),
-        FormatTransform::ResponsesToChat => responses_request_to_chat(body),
-        FormatTransform::ResponsesToAnthropic => {
-            anthropic_compat::responses_request_to_anthropic(body, http_clients).await
-        }
-        FormatTransform::AnthropicToResponses => {
-            anthropic_compat::anthropic_request_to_responses(body, http_clients).await
-        }
-        FormatTransform::ChatToAnthropic => {
-            let intermediate = chat_request_to_responses(body)?;
-            anthropic_compat::responses_request_to_anthropic(&intermediate, http_clients).await
-        }
-        FormatTransform::AnthropicToChat => {
-            let intermediate = anthropic_compat::anthropic_request_to_responses(body, http_clients).await?;
-            responses_request_to_chat(&intermediate)
-        }
-        FormatTransform::GeminiToAnthropic => {
-            gemini_request_to_anthropic(body, http_clients, model_hint).await
-        }
-        FormatTransform::AnthropicToGemini => anthropic_request_to_gemini(body, http_clients).await,
-        FormatTransform::ChatToGemini => gemini_compat::chat_request_to_gemini(body),
-        FormatTransform::GeminiToChat => gemini_compat::gemini_request_to_chat(body, model_hint),
-        FormatTransform::ResponsesToGemini => responses_request_to_gemini(body),
-        FormatTransform::GeminiToResponses => gemini_request_to_responses(body, model_hint),
-        FormatTransform::KiroToResponses
-        | FormatTransform::KiroToChat
-        | FormatTransform::KiroToAnthropic => Ok(body.clone()),
-        FormatTransform::ChatToCodex => codex_compat::chat_request_to_codex(body, model_hint),
-        FormatTransform::ResponsesToCodex => codex_compat::responses_request_to_codex(body, model_hint),
-        FormatTransform::CodexToChat | FormatTransform::CodexToResponses => Ok(body.clone()),
-    }
-}
-
-pub(crate) fn transform_response_body(
-    transform: FormatTransform,
-    bytes: &Bytes,
-    model_hint: Option<&str>,
-) -> Result<Bytes, String> {
-    match transform {
-        FormatTransform::None => Ok(bytes.clone()),
-        FormatTransform::ChatToResponses => chat_response_to_responses(bytes),
-        FormatTransform::ResponsesToChat => responses_response_to_chat(bytes, model_hint),
-        FormatTransform::ResponsesToAnthropic => {
-            anthropic_compat::responses_response_to_anthropic(bytes, model_hint)
-        }
-        FormatTransform::AnthropicToResponses => anthropic_compat::anthropic_response_to_responses(bytes),
-        FormatTransform::ChatToAnthropic => {
-            let intermediate = chat_response_to_responses(bytes)?;
-            anthropic_compat::responses_response_to_anthropic(&intermediate, model_hint)
-        }
-        FormatTransform::AnthropicToChat => {
-            let intermediate = anthropic_compat::anthropic_response_to_responses(bytes)?;
-            responses_response_to_chat(&intermediate, model_hint)
-        }
-        FormatTransform::GeminiToAnthropic => gemini_response_to_anthropic(bytes, model_hint),
-        FormatTransform::AnthropicToGemini => anthropic_response_to_gemini(bytes, model_hint),
-        FormatTransform::ChatToGemini => gemini_compat::chat_response_to_gemini(bytes, model_hint),
-        FormatTransform::GeminiToChat => gemini_compat::gemini_response_to_chat(bytes, model_hint),
-        FormatTransform::ResponsesToGemini => responses_response_to_gemini(bytes, model_hint),
-        FormatTransform::GeminiToResponses => gemini_response_to_responses(bytes, model_hint),
-        FormatTransform::KiroToResponses
-        | FormatTransform::KiroToChat
-        | FormatTransform::KiroToAnthropic => Err("Kiro response conversion is handled upstream.".to_string()),
-        FormatTransform::CodexToChat | FormatTransform::CodexToResponses => {
-            Err("Codex response conversion is handled upstream.".to_string())
-        }
-        FormatTransform::ChatToCodex | FormatTransform::ResponsesToCodex => {
-            Err("Codex response conversion is handled upstream.".to_string())
-        }
-    }
-}
-
-fn chat_request_to_responses(body: &Bytes) -> Result<Bytes, String> {
-    let value: Value = serde_json::from_slice(body)
-        .map_err(|_| "Request body must be JSON.".to_string())?;
-    let Some(object) = value.as_object() else {
-        return Err("Request body must be a JSON object.".to_string());
-    };
-
-    let Some(messages) = object.get("messages").and_then(Value::as_array) else {
-        return Err("Chat request must include messages.".to_string());
-    };
-
-    let (input, instructions) = chat_messages_to_responses_input(messages)?;
-
-    // Responses API uses `input` (string or structured items).
-    let mut output = Map::new();
-    copy_key(object, &mut output, "model");
-    output.insert("input".to_string(), Value::Array(input));
-    if let Some(instructions) = instructions {
-        output.insert("instructions".to_string(), Value::String(instructions));
-    }
-    copy_key(object, &mut output, "stream");
-    copy_key(object, &mut output, "temperature");
-    copy_key(object, &mut output, "top_p");
-    copy_key(object, &mut output, "stop");
-    copy_key(object, &mut output, "metadata");
-    copy_key(object, &mut output, "user");
-    copy_key(object, &mut output, "seed");
-    copy_key(object, &mut output, "parallel_tool_calls");
-    copy_key(object, &mut output, "modalities");
-    copy_key(object, &mut output, "audio");
-
-    if let Some(max_output_tokens) = object
-        .get("max_completion_tokens")
-        .or_else(|| object.get("max_tokens"))
-        .and_then(Value::as_i64)
-    {
-        output.insert("max_output_tokens".to_string(), Value::Number(max_output_tokens.into()));
-    }
-
-    if let Some(tools) = object.get("tools") {
-        output.insert("tools".to_string(), tools::map_chat_tools_to_responses(tools));
-    }
-    if let Some(tool_choice) = object.get("tool_choice") {
-        output.insert(
-            "tool_choice".to_string(),
-            tools::map_chat_tool_choice_to_responses(tool_choice),
-        );
-    }
-    if let Some(response_format) = object.get("response_format") {
-        let mut text_obj = Map::new();
-        text_obj.insert("format".to_string(), response_format.clone());
-        output.insert("text".to_string(), Value::Object(text_obj));
-    }
-
-    serde_json::to_vec(&Value::Object(output))
-        .map(Bytes::from)
-        .map_err(|err| format!("Failed to serialize request: {err}"))
-}
-
-fn responses_request_to_chat(body: &Bytes) -> Result<Bytes, String> {
-    let value: Value = serde_json::from_slice(body)
-        .map_err(|_| "Request body must be JSON.".to_string())?;
-    let Some(object) = value.as_object() else {
-        return Err("Request body must be a JSON object.".to_string());
-    };
-
-    let mut messages = match object.get("input") {
-        Some(Value::String(text)) => vec![json!({ "role": "user", "content": text })],
-        Some(Value::Array(items)) => input::responses_input_to_chat_messages(items)?,
-        _ => return Err("Responses request must include input.".to_string()),
-    };
-
-    // Responses API supports `instructions`; translate it to a system message.
-    if let Some(instructions) = object.get("instructions").and_then(Value::as_str) {
-        if !instructions.trim().is_empty() {
-            messages.insert(0, json!({ "role": "system", "content": instructions }));
-        }
-    }
-
-    let mut output = Map::new();
-    copy_key(object, &mut output, "model");
-    output.insert("messages".to_string(), Value::Array(messages));
-    copy_key(object, &mut output, "stream");
-    copy_key(object, &mut output, "temperature");
-    copy_key(object, &mut output, "top_p");
-    copy_key(object, &mut output, "stop");
-    copy_key(object, &mut output, "metadata");
-    copy_key(object, &mut output, "user");
-    copy_key(object, &mut output, "seed");
-    copy_key(object, &mut output, "parallel_tool_calls");
-    copy_key(object, &mut output, "modalities");
-    copy_key(object, &mut output, "audio");
-
-    if let Some(max_output_tokens) = object.get("max_output_tokens").and_then(Value::as_i64) {
-        // Prefer the modern chat parameter.
-        output.insert(
-            "max_completion_tokens".to_string(),
-            Value::Number(max_output_tokens.into()),
-        );
-    }
-
-    if let Some(tools) = object.get("tools") {
-        output.insert("tools".to_string(), tools::map_responses_tools_to_chat(tools));
-    }
-    if let Some(tool_choice) = object.get("tool_choice") {
-        output.insert(
-            "tool_choice".to_string(),
-            tools::map_responses_tool_choice_to_chat(tool_choice),
-        );
-    }
-    if let Some(text_format) = object
-        .get("text")
-        .and_then(Value::as_object)
-        .and_then(|text| text.get("format"))
-    {
-        output.insert("response_format".to_string(), text_format.clone());
-    }
-
-    serde_json::to_vec(&Value::Object(output))
-        .map(Bytes::from)
-        .map_err(|err| format!("Failed to serialize request: {err}"))
-}
-
-fn responses_request_to_gemini(body: &Bytes) -> Result<Bytes, String> {
-    let intermediate = responses_request_to_chat(body)?;
-    gemini_compat::chat_request_to_gemini(&intermediate)
-}
-
-fn gemini_request_to_responses(body: &Bytes, model_hint: Option<&str>) -> Result<Bytes, String> {
-    let intermediate = gemini_compat::gemini_request_to_chat(body, model_hint)?;
-    chat_request_to_responses(&intermediate)
-}
-
-fn responses_response_to_gemini(bytes: &Bytes, model_hint: Option<&str>) -> Result<Bytes, String> {
-    let intermediate = responses_response_to_chat(bytes, model_hint)?;
-    gemini_compat::chat_response_to_gemini(&intermediate, model_hint)
-}
-
-fn gemini_response_to_responses(bytes: &Bytes, model_hint: Option<&str>) -> Result<Bytes, String> {
-    let intermediate = gemini_compat::gemini_response_to_chat(bytes, model_hint)?;
-    chat_response_to_responses(&intermediate)
-}
-
-async fn gemini_request_to_anthropic(
-    body: &Bytes,
-    http_clients: &ProxyHttpClients,
-    model_hint: Option<&str>,
-) -> Result<Bytes, String> {
-    let intermediate = gemini_compat::gemini_request_to_chat(body, model_hint)?;
-    let intermediate = chat_request_to_responses(&intermediate)?;
-    anthropic_compat::responses_request_to_anthropic(&intermediate, http_clients).await
-}
-
-async fn anthropic_request_to_gemini(
-    body: &Bytes,
-    http_clients: &ProxyHttpClients,
-) -> Result<Bytes, String> {
-    let intermediate = anthropic_compat::anthropic_request_to_responses(body, http_clients).await?;
-    let intermediate = responses_request_to_chat(&intermediate)?;
-    gemini_compat::chat_request_to_gemini(&intermediate)
-}
-
-fn gemini_response_to_anthropic(
-    bytes: &Bytes,
-    model_hint: Option<&str>,
-) -> Result<Bytes, String> {
-    let intermediate = gemini_compat::gemini_response_to_chat(bytes, model_hint)?;
-    let intermediate = chat_response_to_responses(&intermediate)?;
-    anthropic_compat::responses_response_to_anthropic(&intermediate, model_hint)
-}
-
-fn anthropic_response_to_gemini(
-    bytes: &Bytes,
-    model_hint: Option<&str>,
-) -> Result<Bytes, String> {
-    let intermediate = anthropic_compat::anthropic_response_to_responses(bytes)?;
-    let intermediate = responses_response_to_chat(&intermediate, model_hint)?;
-    gemini_compat::chat_response_to_gemini(&intermediate, model_hint)
-}
-
-fn chat_messages_to_responses_input(
-    messages: &[Value],
-) -> Result<(Vec<Value>, Option<String>), String> {
-    let mut system_texts = Vec::new();
-    let mut input = Vec::new();
-    let mut has_user_message = false;
-
-    for message in messages {
-        let Some(message) = message.as_object() else {
-            continue;
-        };
-
-        let role = message.get("role").and_then(Value::as_str).unwrap_or("user");
-        match role {
-            "system" => push_chat_system_message(&mut system_texts, message),
-            "user" => push_chat_user_message(&mut input, &mut has_user_message, message)?,
-            "assistant" => push_chat_assistant_message(&mut input, &mut has_user_message, message)?,
-            "tool" => push_chat_tool_message(&mut input, message),
-            _ => {}
-        }
-    }
-
-    let instructions = message::join_non_empty_lines(system_texts);
-    Ok((input, instructions))
-}
-
-fn push_chat_system_message(system_texts: &mut Vec<String>, message: &Map<String, Value>) {
-    if let Some(text) = message::extract_text_from_chat_content(message.get("content")) {
-        system_texts.push(text);
-    }
-}
-
-fn push_chat_user_message(
-    input: &mut Vec<Value>,
-    has_user_message: &mut bool,
-    message: &Map<String, Value>,
-) -> Result<(), String> {
-    let parts = message::chat_content_to_responses_message_parts(message.get("content"), "input_text")?;
-    if parts.is_empty() {
-        return Ok(());
-    }
-    input.push(json!({ "type": "message", "role": "user", "content": parts }));
-    *has_user_message = true;
-    Ok(())
-}
-
-fn push_chat_assistant_message(
-    input: &mut Vec<Value>,
-    has_user_message: &mut bool,
-    message: &Map<String, Value>,
-) -> Result<(), String> {
-    // Responses API expects assistant message content parts to use output types.
-    // This matches OpenAI's schema and avoids errors like: "supported values are output_text/refusal".
-    let parts = message::chat_content_to_responses_message_parts(message.get("content"), "output_text")?;
-    let tool_calls = message::chat_tool_calls_to_responses_items(message.get("tool_calls"));
-    let legacy_call = message::chat_function_call_to_responses_item(message.get("function_call"));
-
-    let has_payload =
-        !parts.is_empty() || !tool_calls.is_empty() || legacy_call.is_some();
-    if has_payload && !*has_user_message {
-        input.push(message::user_placeholder_item());
-        *has_user_message = true;
-    }
-
-    if !parts.is_empty() {
-        input.push(json!({ "type": "message", "role": "assistant", "content": parts }));
-    }
-    input.extend(tool_calls);
-    if let Some(item) = legacy_call {
-        input.push(item);
-    }
-    Ok(())
-}
-
-fn push_chat_tool_message(input: &mut Vec<Value>, message: &Map<String, Value>) {
-    let call_id = message.get("tool_call_id").and_then(Value::as_str).unwrap_or("");
-    let output = message::stringify_any_json(message.get("content"));
-    input.push(json!({
-        "type": "function_call_output",
-        "call_id": call_id,
-        "output": output
-    }));
-}
-
-fn responses_response_to_chat(bytes: &Bytes, model_hint: Option<&str>) -> Result<Bytes, String> {
-    let value: Value = serde_json::from_slice(bytes)
-        .map_err(|_| "Upstream response must be JSON.".to_string())?;
-    let Some(object) = value.as_object() else {
-        return Err("Upstream response must be a JSON object.".to_string());
-    };
-
-    let extracted = extract::extract_responses_output(&value);
-    let id = object
-        .get("id")
-        .and_then(Value::as_str)
-        .unwrap_or("chatcmpl-proxy");
-    let created = object.get("created_at").and_then(Value::as_i64).unwrap_or(0);
-    let model = object
-        .get("model")
-        .and_then(Value::as_str)
-        .or(model_hint)
-        .unwrap_or("unknown");
-
-    let usage = object
-        .get("usage")
-        .and_then(|usage| usage::map_usage_responses_to_chat(usage));
-
-    let finish_reason =
-        compat_reason::chat_finish_reason_from_response_object(object, !extracted.tool_calls.is_empty());
-
-    let reasoning_text = extracted.reasoning_text.clone();
-    let mut message = json!({
-        "role": "assistant",
-        "content": compat_content::chat_message_content_from_responses_parts(
-            &extracted.content_parts,
-        )
-    });
-    if let Some(message) = message.as_object_mut() {
-        if !reasoning_text.trim().is_empty() {
-            message.insert(
-                "reasoning_content".to_string(),
-                Value::String(reasoning_text),
-            );
-        }
-    }
-    if !extracted.tool_calls.is_empty() {
-        if let Some(message) = message.as_object_mut() {
-            message.insert("tool_calls".to_string(), Value::Array(extracted.tool_calls));
-        }
-    }
-
-    let output = json!({
-        "id": id,
-        "object": "chat.completion",
-        "created": created,
-        "model": model,
-        "choices": [
-            {
-                "index": 0,
-                "message": message,
-                "finish_reason": finish_reason
-            }
-        ],
-        "usage": usage
-    });
-
-    serde_json::to_vec(&output)
-        .map(Bytes::from)
-        .map_err(|err| format!("Failed to serialize response: {err}"))
-}
-
-fn chat_response_to_responses(bytes: &Bytes) -> Result<Bytes, String> {
-    let value: Value = serde_json::from_slice(bytes)
-        .map_err(|_| "Upstream response must be JSON.".to_string())?;
-    let Some(object) = value.as_object() else {
-        return Err("Upstream response must be a JSON object.".to_string());
-    };
-
-    let content = extract::extract_chat_choice_text(&value).unwrap_or_default();
-    let tool_calls = extract::extract_chat_tool_calls(&value);
-    let parallel_tool_calls = tool_calls.len() > 1;
-    let id = object.get("id").and_then(Value::as_str).unwrap_or("resp-proxy");
-    let created = object.get("created").and_then(Value::as_i64).unwrap_or(0);
-    let model = object.get("model").and_then(Value::as_str).unwrap_or("unknown");
-    let finish_reason = object
-        .get("choices")
-        .and_then(Value::as_array)
-        .and_then(|choices| choices.first())
-        .and_then(Value::as_object)
-        .and_then(|choice| choice.get("finish_reason"))
-        .and_then(Value::as_str);
-    let (status, incomplete_reason) =
-        compat_reason::responses_status_from_chat_finish_reason(finish_reason);
-    let status = status.unwrap_or("completed");
-    let incomplete_details = incomplete_reason
-        .map(|reason| json!({ "reason": reason }))
-        .unwrap_or(Value::Null);
-
-    let usage = object
-        .get("usage")
-        .and_then(|usage| usage::map_usage_chat_to_responses(usage));
-
-    let mut output = Vec::new();
-    if !content.trim().is_empty() || tool_calls.is_empty() {
-        output.push(json!({
-            "type": "message",
-            "id": "msg_proxy",
-            "status": "completed",
-            "role": "assistant",
-            "content": [
-                { "type": "output_text", "text": content, "annotations": [] }
-            ]
-        }));
-    }
-    for call in tool_calls {
-        output.push(json!({
-            "id": call.item_id,
-            "type": "function_call",
-            "status": "completed",
-            "arguments": call.arguments,
-            "call_id": call.call_id,
-            "name": call.name
-        }));
-    }
-
-    let output = json!({
-        "id": id,
-        "object": "response",
-        "created_at": created,
-        "status": status,
-        "error": null,
-        "incomplete_details": incomplete_details,
-        "model": model,
-        "parallel_tool_calls": parallel_tool_calls,
-        "output": output,
-        "usage": usage
-    });
-
-    serde_json::to_vec(&output)
-        .map(Bytes::from)
-        .map_err(|err| format!("Failed to serialize response: {err}"))
-}
-
-fn copy_key(source: &serde_json::Map<String, Value>, target: &mut Map<String, Value>, key: &str) {
-    if let Some(value) = source.get(key) {
-        target.insert(key.to_string(), value.clone());
-    }
-}
-
-#[cfg(test)]
-#[path = "openai_compat.test.rs"]
-mod tests;
diff --git a/src-tauri/src/proxy/openai_compat.test.rs b/src-tauri/src/proxy/openai_compat.test.rs
deleted file mode 100644
index 2f58003..0000000
--- a/src-tauri/src/proxy/openai_compat.test.rs
+++ /dev/null
@@ -1,682 +0,0 @@
-use super::*;
-use axum::body::Bytes;
-use serde_json::{json, Value};
-use crate::proxy::http_client::ProxyHttpClients;
-
-fn run_async<T>(future: impl std::future::Future<Output = T>) -> T {
-    tokio::runtime::Runtime::new()
-        .expect("create tokio runtime")
-        .block_on(future)
-}
-
-fn bytes_from_json(value: Value) -> Bytes {
-    Bytes::from(serde_json::to_vec(&value).expect("serialize JSON"))
-}
-
-fn json_from_bytes(bytes: Bytes) -> Value {
-    serde_json::from_slice(&bytes).expect("parse JSON")
-}
-fn transform_request_value(
-    transform: FormatTransform,
-    input: Value,
-    http_clients: &ProxyHttpClients,
-    model_hint: Option<&str>,
-) -> Value {
-    let bytes = bytes_from_json(input);
-    let output = run_async(async {
-        transform_request_body(transform, &bytes, http_clients, model_hint)
-            .await
-            .expect("transform")
-    });
-    json_from_bytes(output)
-}
-fn transform_response_value(transform: FormatTransform, input: Value, model_hint: Option<&str>) -> Value {
-    let bytes = bytes_from_json(input);
-    let output = transform_response_body(transform, &bytes, model_hint).expect("transform");
-    json_from_bytes(output)
-}
-#[test]
-fn chat_request_to_responses_maps_common_fields() {
-    let http_clients = ProxyHttpClients::new().expect("http clients");
-    let chat_messages = json!([
-        { "role": "user", "content": "hi" },
-        { "role": "assistant", "content": "hello" }
-    ]);
-    let input = bytes_from_json(json!({
-        "model": "gpt-4.1",
-        "messages": chat_messages,
-        "stream": true,
-        "temperature": 0.7,
-        "top_p": 0.9,
-        // Prefer `max_completion_tokens` over `max_tokens`.
-        "max_tokens": 111,
-        "max_completion_tokens": 222
-    }));
-
-    let output = run_async(async {
-        transform_request_body(FormatTransform::ChatToResponses, &input, &http_clients, None)
-            .await
-            .expect("transform")
-    });
-    let value = json_from_bytes(output);
-
-    let expected_input = json!([
-        {
-            "type": "message",
-            "role": "user",
-            "content": [{ "type": "input_text", "text": "hi" }]
-        },
-        {
-            "type": "message",
-            "role": "assistant",
-            "content": [{ "type": "output_text", "text": "hello" }]
-        }
-    ]);
-
-    assert_eq!(value["model"], json!("gpt-4.1"));
-    assert_eq!(value["input"], expected_input);
-    assert_eq!(value["stream"], json!(true));
-    assert_eq!(value["temperature"], json!(0.7));
-    assert_eq!(value["top_p"], json!(0.9));
-    assert_eq!(value["max_output_tokens"], json!(222));
-    assert!(value.get("messages").is_none());
-}
-
-#[test]
-fn responses_request_to_chat_maps_tools_and_tool_choice() {
-    let http_clients = ProxyHttpClients::new().expect("http clients");
-    let parameters = json!({
-        "type": "object",
-        "properties": { "q": { "type": "string" } },
-        "required": ["q"]
-    });
-    let input = bytes_from_json(json!({
-        "model": "gpt-4.1",
-        "input": "hello",
-        "tools": [
-            {
-                "type": "function",
-                "name": "search",
-                "description": "Search something",
-                "parameters": parameters
-            }
-        ],
-        "tool_choice": { "type": "function", "name": "search" },
-        "stream": false
-    }));
-
-    let output = run_async(async {
-        transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None)
-            .await
-            .expect("transform")
-    });
-    let value = json_from_bytes(output);
-
-    assert_eq!(value["tools"][0]["type"], json!("function"));
-    assert_eq!(value["tools"][0]["function"]["name"], json!("search"));
-    assert_eq!(value["tools"][0]["function"]["description"], json!("Search something"));
-    assert_eq!(value["tools"][0]["function"]["parameters"], parameters);
-    assert_eq!(value["tool_choice"]["type"], json!("function"));
-    assert_eq!(value["tool_choice"]["function"]["name"], json!("search"));
-}
-
-#[test]
-fn chat_request_to_responses_maps_tools_and_tool_choice() {
-    let http_clients = ProxyHttpClients::new().expect("http clients");
-    let parameters = json!({
-        "type": "object",
-        "properties": { "q": { "type": "string" } },
-        "required": ["q"]
-    });
-    let input = bytes_from_json(json!({
-        "model": "gpt-4.1",
-        "messages": [{ "role": "user", "content": "hi" }],
-        "tools": [
-            {
-                "type": "function",
-                "function": {
-                    "name": "search",
-                    "description": "Search something",
-                    "parameters": parameters
-                }
-            }
-        ],
-        "tool_choice": { "type": "function", "function": { "name": "search" } },
-        "stream": false
-    }));
-
-    let output = run_async(async {
-        transform_request_body(FormatTransform::ChatToResponses, &input, &http_clients, None)
-            .await
-            .expect("transform")
-    });
-    let value = json_from_bytes(output);
-
-    assert_eq!(value["tools"][0]["type"], json!("function"));
-    assert_eq!(value["tools"][0]["name"], json!("search"));
-    assert_eq!(value["tools"][0]["description"], json!("Search something"));
-    assert_eq!(value["tools"][0]["parameters"], parameters);
-    assert_eq!(value["tool_choice"]["type"], json!("function"));
-    assert_eq!(value["tool_choice"]["name"], json!("search"));
-}
-
-#[test]
-fn responses_request_to_chat_instructions_becomes_system_message() {
-    let http_clients = ProxyHttpClients::new().expect("http clients");
-    let input = bytes_from_json(json!({
-        "model": "gpt-4.1",
-        "input": "hello",
-        "instructions": "be concise",
-        "stream": false,
-        "max_output_tokens": 99
-    }));
-
-    let output = run_async(async {
-        transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None)
-            .await
-            .expect("transform")
-    });
-    let value = json_from_bytes(output);
-    let messages = value["messages"].as_array().expect("messages array");
-
-    assert_eq!(value["model"], json!("gpt-4.1"));
-    assert_eq!(value["stream"], json!(false));
-    assert_eq!(value["max_completion_tokens"], json!(99));
-    assert_eq!(messages.len(), 2);
-    assert_eq!(messages[0]["role"], json!("system"));
-    assert_eq!(messages[0]["content"], json!("be concise"));
-    assert_eq!(messages[1]["role"], json!("user"));
-    assert_eq!(messages[1]["content"], json!("hello"));
-}
-
-#[test]
-fn responses_request_to_chat_accepts_message_array_input() {
-    let http_clients = ProxyHttpClients::new().expect("http clients");
-    let input_messages = json!([{ "role": "user", "content": "hi" }]);
-    let input = bytes_from_json(json!({
-        "model": "gpt-4.1",
-        "input": input_messages,
-        "stream": true
-    }));
-
-    let output = run_async(async {
-        transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None)
-            .await
-            .expect("transform")
-    });
-    let value = json_from_bytes(output);
-
-    assert_eq!(value["model"], json!("gpt-4.1"));
-    assert_eq!(value["stream"], json!(true));
-    assert_eq!(value["messages"], input_messages);
-}
-
-#[test]
-fn responses_request_to_chat_converts_input_text_content_parts_to_string() {
-    let http_clients = ProxyHttpClients::new().expect("http clients");
-    let input_messages = json!([{
-        "role": "user",
-        "content": [
-            { "type": "input_text", "text": "分析项目的逻辑缺陷和性能缺陷" }
-        ]
-    }]);
-    let input = bytes_from_json(json!({
-        "model": "gpt-4.1",
-        "input": input_messages,
-        "stream": false
-    }));
-
-    let output = run_async(async {
-        transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None)
-            .await
-            .expect("transform")
-    });
-    let value = json_from_bytes(output);
-
-    assert_eq!(value["messages"][0]["role"], json!("user"));
-    assert_eq!(
-        value["messages"][0]["content"],
-        json!("分析项目的逻辑缺陷和性能缺陷")
-    );
-}
-
-#[test]
-fn chat_request_to_responses_maps_response_format() {
-    let http_clients = ProxyHttpClients::new().expect("http clients");
-    let input = bytes_from_json(json!({
-        "model": "gpt-4.1",
-        "messages": [{ "role": "user", "content": "hi" }],
-        "response_format": {
-            "type": "json_schema",
-            "json_schema": {
-                "name": "example",
-                "schema": { "type": "object", "properties": { "ok": { "type": "boolean" } } }
-            }
-        }
-    }));
-
-    let output = run_async(async {
-        transform_request_body(FormatTransform::ChatToResponses, &input, &http_clients, None)
-            .await
-            .expect("transform")
-    });
-    let value = json_from_bytes(output);
-
-    assert_eq!(value["text"]["format"]["type"], json!("json_schema"));
-    assert_eq!(value["text"]["format"]["json_schema"]["name"], json!("example"));
-}
-
-#[test]
-fn responses_request_to_chat_maps_text_format_to_response_format() {
-    let http_clients = ProxyHttpClients::new().expect("http clients");
-    let input = bytes_from_json(json!({
-        "model": "gpt-4.1",
-        "input": "hi",
-        "text": { "format": { "type": "json_object" } }
-    }));
-
-    let output = run_async(async {
-        transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None)
-            .await
-            .expect("transform")
-    });
-    let value = json_from_bytes(output);
-
-    assert_eq!(value["response_format"]["type"], json!("json_object"));
-}
-
-#[test]
-fn responses_response_to_chat_extracts_output_text_and_maps_usage() {
-    let input = bytes_from_json(json!({
-        "id": "resp_123",
-        "created_at": 1700000000,
-        "model": "gpt-4.1",
-        "output": [
-            {
-                "type": "message",
-                "role": "assistant",
-                "content": [
-                    { "type": "output_text", "text": "Hello", "annotations": [] },
-                    { "type": "output_text", "text": " world", "annotations": [] }
-                ]
-            }
-        ],
-        "usage": { "input_tokens": 1, "output_tokens": 2, "total_tokens": 3 }
-    }));
-
-    let output = transform_response_body(FormatTransform::ResponsesToChat, &input, None).expect("transform");
-    let value = json_from_bytes(output);
-
-    assert_eq!(value["id"], json!("resp_123"));
-    assert_eq!(value["object"], json!("chat.completion"));
-    assert_eq!(value["created"], json!(1700000000));
-    assert_eq!(value["model"], json!("gpt-4.1"));
-    assert_eq!(value["choices"][0]["message"]["role"], json!("assistant"));
-    assert_eq!(value["choices"][0]["message"]["content"], json!("Hello world"));
-    assert_eq!(value["choices"][0]["finish_reason"], json!("stop"));
-    assert_eq!(value["usage"]["prompt_tokens"], json!(1));
-    assert_eq!(value["usage"]["completion_tokens"], json!(2));
-    assert_eq!(value["usage"]["total_tokens"], json!(3));
-}
-
-#[test]
-fn responses_response_to_chat_maps_reasoning_content() {
-    let input = bytes_from_json(json!({
-        "id": "resp_reason",
-        "created_at": 1700000002,
-        "model": "gpt-4.1",
-        "output": [
-            {
-                "type": "message",
-                "role": "assistant",
-                "content": [
-                    { "type": "reasoning_text", "text": "think", "annotations": [] },
-                    { "type": "output_text", "text": "ok", "annotations": [] }
-                ]
-            }
-        ]
-    }));
-
-    let output = transform_response_body(FormatTransform::ResponsesToChat, &input, None).expect("transform");
-    let value = json_from_bytes(output);
-
-    let message = &value["choices"][0]["message"];
-    assert_eq!(message["content"], json!("ok"));
-    assert_eq!(message["reasoning_content"], json!("think"));
-}
-
-#[test]
-fn responses_response_to_chat_includes_tool_calls_and_multimodal_content() {
-    let input = bytes_from_json(json!({
-        "id": "resp_456",
-        "created_at": 1700000001,
-        "model": "gpt-4.1",
-        "output": [
-            {
-                "type": "message",
-                "role": "assistant",
-                "content": [
-                    { "type": "output_text", "text": "Hello", "annotations": [] },
-                    { "type": "output_image", "image_url": { "url": "https://example.com/a.png" } }
-                ]
-            },
-            {
-                "type": "function_call",
-                "call_id": "call_foo",
-                "name": "doThing",
-                "arguments": "{\"a\":1}"
-            }
-        ],
-        "usage": { "input_tokens": 1, "output_tokens": 2, "total_tokens": 3 }
-    }));
-
-    let output = transform_response_body(FormatTransform::ResponsesToChat, &input, None).expect("transform");
-    let value = json_from_bytes(output);
-
-    let message = &value["choices"][0]["message"];
-    assert_eq!(message["role"], json!("assistant"));
-    assert_eq!(message["content"][0]["type"], json!("text"));
-    assert_eq!(message["content"][0]["text"], json!("Hello"));
-    assert_eq!(message["content"][1]["type"], json!("image_url"));
-    assert_eq!(
-        message["content"][1]["image_url"]["url"],
-        json!("https://example.com/a.png")
-    );
-    assert_eq!(message["tool_calls"][0]["id"], json!("call_foo"));
-    assert_eq!(message["tool_calls"][0]["function"]["name"], json!("doThing"));
-    assert_eq!(message["tool_calls"][0]["function"]["arguments"], json!("{\"a\":1}"));
-    assert_eq!(value["choices"][0]["finish_reason"], json!("tool_calls"));
-}
-
-#[test]
-fn chat_response_to_responses_extracts_choice_text_and_maps_usage() {
-    let input = bytes_from_json(json!({
-        "id": "chatcmpl_123",
-        "created": 1700000000,
-        "model": "gpt-4.1",
-        "choices": [
-            { "index": 0, "message": { "role": "assistant", "content": "Hello" } }
-        ],
-        "usage": { "prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3 }
-    }));
-
-    let output = transform_response_body(FormatTransform::ChatToResponses, &input, None).expect("transform");
-    let value = json_from_bytes(output);
-
-    assert_eq!(value["id"], json!("chatcmpl_123"));
-    assert_eq!(value["object"], json!("response"));
-    assert_eq!(value["created_at"], json!(1700000000));
-    assert_eq!(value["model"], json!("gpt-4.1"));
-    assert_eq!(value["output"][0]["type"], json!("message"));
-    assert_eq!(value["output"][0]["role"], json!("assistant"));
-    assert_eq!(value["output"][0]["content"][0]["type"], json!("output_text"));
-    assert_eq!(value["output"][0]["content"][0]["text"], json!("Hello"));
-    assert_eq!(value["usage"]["input_tokens"], json!(1));
-    assert_eq!(value["usage"]["output_tokens"], json!(2));
-    assert_eq!(value["usage"]["total_tokens"], json!(3));
-}
-
-#[test]
-fn chat_response_to_responses_maps_finish_reason_to_incomplete_details() {
-    let input = bytes_from_json(json!({
-        "id": "chatcmpl_456",
-        "created": 1700000002,
-        "model": "gpt-4.1",
-        "choices": [
-            { "index": 0, "message": { "role": "assistant", "content": "Hello" }, "finish_reason": "length" }
-        ],
-        "usage": { "prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3 }
-    }));
-
-    let output = transform_response_body(FormatTransform::ChatToResponses, &input, None).expect("transform");
-    let value = json_from_bytes(output);
-
-    assert_eq!(value["status"], json!("incomplete"));
-    assert_eq!(value["incomplete_details"]["reason"], json!("max_tokens"));
-}
-
-#[test]
-fn responses_request_to_chat_converts_function_call_output_to_tool_message() {
-    let http_clients = ProxyHttpClients::new().expect("http clients");
-    let input = bytes_from_json(json!({
-        "model": "gpt-4.1",
-        "input": [
-            { "type": "function_call_output", "call_id": "call_123", "output": "ok" }
-        ],
-        "stream": false
-    }));
-
-    let output = run_async(async {
-        transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None)
-            .await
-            .expect("transform")
-    });
-    let value = json_from_bytes(output);
-    let messages = value["messages"].as_array().expect("messages array");
-
-    assert_eq!(messages.len(), 1);
-    assert_eq!(messages[0]["role"], json!("tool"));
-    assert_eq!(messages[0]["tool_call_id"], json!("call_123"));
-    assert_eq!(messages[0]["content"], json!("ok"));
-}
-
-#[test]
-fn chat_response_to_responses_maps_tool_calls_into_output() {
-    let input = bytes_from_json(json!({
-        "id": "chatcmpl_123",
-        "created": 1700000000,
-        "model": "gpt-4.1",
-        "choices": [
-            {
-                "index": 0,
-                "message": {
-                    "role": "assistant",
-                    "content": "",
-                    "tool_calls": [
-                        {
-                            "id": "call_foo",
-                            "type": "function",
-                            "function": {
-                                "name": "getRandomNumber",
-                                "arguments": "{\"a\":\"0\"}"
-                            }
-                        }
-                    ]
-                }
-            }
-        ],
-        "usage": { "prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3 }
-    }));
-
-    let output = transform_response_body(FormatTransform::ChatToResponses, &input, None).expect("transform");
-    let value = json_from_bytes(output);
-
-    assert_eq!(value["id"], json!("chatcmpl_123"));
-    assert_eq!(value["object"], json!("response"));
-    assert_eq!(value["created_at"], json!(1700000000));
-    assert_eq!(value["model"], json!("gpt-4.1"));
-    assert_eq!(value["output"][0]["type"], json!("function_call"));
-    assert_eq!(value["output"][0]["call_id"], json!("call_foo"));
-    assert_eq!(value["output"][0]["name"], json!("getRandomNumber"));
-    assert_eq!(value["output"][0]["arguments"], json!("{\"a\":\"0\"}"));
-    assert_eq!(value["usage"]["input_tokens"], json!(1));
-    assert_eq!(value["usage"]["output_tokens"], json!(2));
-    assert_eq!(value["usage"]["total_tokens"], json!(3));
-}
-
-#[test]
-fn chat_request_to_responses_rejects_missing_messages() {
-    let http_clients = ProxyHttpClients::new().expect("http clients");
-    let input = bytes_from_json(json!({ "model": "gpt-4.1" }));
-    let err = run_async(async {
-        transform_request_body(FormatTransform::ChatToResponses, &input, &http_clients, None)
-            .await
-            .expect_err("should fail")
-    });
-    assert!(err.contains("messages"));
-}
-
-#[test]
-fn transform_request_body_rejects_non_json() {
-    let http_clients = ProxyHttpClients::new().expect("http clients");
-    let input = Bytes::from_static(b"not-json");
-    let err = run_async(async {
-        transform_request_body(FormatTransform::ChatToResponses, &input, &http_clients, None)
-            .await
-            .expect_err("should fail")
-    });
-    assert!(err.contains("JSON"));
-}
-
-#[test]
-fn responses_and_gemini_request_conversions() {
-    let http_clients = ProxyHttpClients::new().expect("http clients");
-    let responses_value = transform_request_value(
-        FormatTransform::ResponsesToGemini,
-        json!({
-            "model": "gpt-4.1",
-            "input": "hi",
-            "instructions": "sys",
-            "temperature": 0.5,
-            "top_p": 0.9,
-            "max_output_tokens": 128,
-            "stop": ["a", "b"],
-            "seed": 7
-        }),
-        &http_clients,
-        None,
-    );
-    assert_eq!(responses_value["contents"][0]["parts"][0]["text"], json!("hi"));
-    assert_eq!(responses_value["systemInstruction"]["parts"][0]["text"], json!("sys"));
-    assert_eq!(responses_value["generationConfig"]["maxOutputTokens"], json!(128));
-    assert_eq!(responses_value["generationConfig"]["stopSequences"], json!(["a", "b"]));
-    assert_eq!(responses_value["generationConfig"]["seed"], json!(7));
-    let gemini_value = transform_request_value(
-        FormatTransform::GeminiToResponses,
-        json!({
-            "model": "gemini-1.5-flash",
-            "contents": [{ "role": "user", "parts": [{ "text": "hello" }] }],
-            "systemInstruction": { "parts": [{ "text": "rules" }] },
-            "generationConfig": { "maxOutputTokens": 64, "topP": 0.8 }
-        }),
-        &http_clients,
-        None,
-    );
-    assert_eq!(gemini_value["model"], json!("gemini-1.5-flash"));
-    assert_eq!(gemini_value["instructions"], json!("rules"));
-    assert_eq!(gemini_value["input"][0]["content"][0]["text"], json!("hello"));
-    assert_eq!(gemini_value["max_output_tokens"], json!(64));
-    assert_eq!(gemini_value["top_p"], json!(0.8));
-}
-#[test]
-fn gemini_and_anthropic_request_conversions() {
-    let http_clients = ProxyHttpClients::new().expect("http clients");
-    let gemini_value = transform_request_value(
-        FormatTransform::GeminiToAnthropic,
-        json!({
-            "contents": [{ "role": "user", "parts": [{ "text": "ping" }] }],
-            "systemInstruction": { "parts": [{ "text": "sys" }] },
-            "generationConfig": { "maxOutputTokens": 42 }
-        }),
-        &http_clients,
-        Some("claude-3-5-sonnet"),
-    );
-    assert_eq!(gemini_value["model"], json!("claude-3-5-sonnet"));
-    assert_eq!(gemini_value["system"][0]["text"], json!("sys"));
-    assert_eq!(gemini_value["messages"][0]["content"][0]["text"], json!("ping"));
-    assert_eq!(gemini_value["max_tokens"], json!(42));
-    let anthropic_value = transform_request_value(
-        FormatTransform::AnthropicToGemini,
-        json!({
-            "model": "claude-3-5-sonnet",
-            "max_tokens": 321,
-            "system": "guard",
-            "stop_sequences": ["x"],
-            "messages": [{ "role": "user", "content": [{ "type": "text", "text": "yo" }] }]
-        }),
-        &http_clients,
-        None,
-    );
-    assert_eq!(anthropic_value["systemInstruction"]["parts"][0]["text"], json!("guard"));
-    assert_eq!(anthropic_value["contents"][0]["parts"][0]["text"], json!("yo"));
-    assert_eq!(anthropic_value["generationConfig"]["maxOutputTokens"], json!(321));
-    assert_eq!(anthropic_value["generationConfig"]["stopSequences"], json!(["x"]));
-}
-#[test]
-fn responses_and_gemini_response_conversions() {
-    let responses_value = transform_response_value(
-        FormatTransform::ResponsesToGemini,
-        json!({
-            "id": "resp_1",
-            "created_at": 1700000000,
-            "model": "gpt-4.1",
-            "output": [
-                {
-                    "type": "message",
-                    "role": "assistant",
-                    "content": [{ "type": "output_text", "text": "Hello", "annotations": [] }]
-                }
-            ],
-            "usage": { "input_tokens": 2, "output_tokens": 3, "total_tokens": 5 }
-        }),
-        None,
-    );
-    assert_eq!(responses_value["candidates"][0]["content"]["parts"][0]["text"], json!("Hello"));
-    assert_eq!(responses_value["usageMetadata"]["promptTokenCount"], json!(2));
-    assert_eq!(responses_value["usageMetadata"]["candidatesTokenCount"], json!(3));
-    assert_eq!(responses_value["usageMetadata"]["totalTokenCount"], json!(5));
-    let gemini_value = transform_response_value(
-        FormatTransform::GeminiToResponses,
-        json!({
-            "candidates": [
-                { "content": { "role": "model", "parts": [{ "text": "Hi" }] }, "finishReason": "STOP" }
-            ],
-            "usageMetadata": {
-                "promptTokenCount": 4,
-                "candidatesTokenCount": 6,
-                "totalTokenCount": 10
-            }
-        }),
-        Some("gemini-1.5-pro"),
-    );
-    assert_eq!(gemini_value["output"][0]["content"][0]["text"], json!("Hi"));
-    assert_eq!(gemini_value["usage"]["input_tokens"], json!(4));
-    assert_eq!(gemini_value["usage"]["output_tokens"], json!(6));
-    assert_eq!(gemini_value["usage"]["total_tokens"], json!(10));
-}
-#[test]
-fn gemini_and_anthropic_response_conversions() {
-    let gemini_value = transform_response_value(
-        FormatTransform::GeminiToAnthropic,
-        json!({
-            "candidates": [
-                { "content": { "role": "model", "parts": [{ "text": "Howdy" }] }, "finishReason": "STOP" }
-            ],
-            "usageMetadata": {
-                "promptTokenCount": 1,
-                "candidatesTokenCount": 2,
-                "totalTokenCount": 3
-            }
-        }),
-        Some("claude-3-5-sonnet"),
-    );
-    assert_eq!(gemini_value["model"], json!("claude-3-5-sonnet"));
-    assert_eq!(gemini_value["content"][0]["text"], json!("Howdy"));
-    assert_eq!(gemini_value["usage"]["input_tokens"], json!(1));
-    assert_eq!(gemini_value["usage"]["output_tokens"], json!(2));
-    assert_eq!(gemini_value["stop_reason"], json!("end_turn"));
-    let anthropic_value = transform_response_value(
-        FormatTransform::AnthropicToGemini,
-        json!({
-            "id": "msg_1",
-            "model": "claude-3-5-sonnet",
-            "content": [{ "type": "text", "text": "Yo" }],
-            "usage": { "input_tokens": 4, "output_tokens": 6 }
-        }),
-        None,
-    );
-    assert_eq!(anthropic_value["candidates"][0]["content"]["parts"][0]["text"], json!("Yo"));
-    assert_eq!(anthropic_value["usageMetadata"]["promptTokenCount"], json!(4));
-    assert_eq!(anthropic_value["usageMetadata"]["candidatesTokenCount"], json!(6));
-    assert_eq!(anthropic_value["usageMetadata"]["totalTokenCount"], json!(10));
-}
diff --git a/src-tauri/src/proxy/openai_compat/extract.rs b/src-tauri/src/proxy/openai_compat/extract.rs
deleted file mode 100644
index cab2fbb..0000000
--- a/src-tauri/src/proxy/openai_compat/extract.rs
+++ /dev/null
@@ -1,174 +0,0 @@
-use serde_json::{json, Map, Value};
-
-pub(super) struct ChatToolCall {
-    pub(super) item_id: String,
-    pub(super) call_id: String,
-    pub(super) name: String,
-    pub(super) arguments: String,
-}
-
-pub(super) struct ResponsesOutput {
-    pub(super) content_parts: Vec<Value>,
-    pub(super) reasoning_text: String,
-    pub(super) tool_calls: Vec<Value>,
-}
-
-pub(super) fn extract_chat_choice_text(value: &Value) -> Option<String> {
-    let choices = value.get("choices")?.as_array()?;
-    let first = choices.first()?.as_object()?;
-    let message = first.get("message")?.as_object()?;
-    message.get("content")?.as_str().map(|text| text.to_string())
-}
-
-pub(super) fn extract_chat_tool_calls(value: &Value) -> Vec<ChatToolCall> {
-    let Some(message) = extract_chat_first_message(value) else {
-        return Vec::new();
-    };
-    let tool_calls = message
-        .get("tool_calls")
-        .and_then(Value::as_array)
-        .map(|tool_calls| extract_chat_message_tool_calls(tool_calls))
-        .unwrap_or_default();
-    if !tool_calls.is_empty() {
-        return tool_calls;
-    }
-
-    message
-        .get("function_call")
-        .and_then(Value::as_object)
-        .and_then(extract_chat_message_legacy_function_call)
-        .into_iter()
-        .collect()
-}
-
-fn extract_chat_first_message(value: &Value) -> Option<&Map<String, Value>> {
-    let choices = value.get("choices")?.as_array()?;
-    let first = choices.first()?.as_object()?;
-    first.get("message")?.as_object()
-}
-
-fn extract_chat_message_tool_calls(tool_calls: &[Value]) -> Vec<ChatToolCall> {
-    let mut output = Vec::new();
-    for call in tool_calls {
-        let Some(call) = call.as_object() else {
-            continue;
-        };
-        let call_id = call.get("id").and_then(Value::as_str).unwrap_or("");
-        if call_id.is_empty() {
-            continue;
-        }
-
-        let function = call.get("function").and_then(Value::as_object);
-        let name = function
-            .and_then(|function| function.get("name"))
-            .and_then(Value::as_str)
-            .unwrap_or("");
-        let arguments = function
-            .and_then(|function| function.get("arguments"))
-            .and_then(Value::as_str)
-            .unwrap_or("");
-
-        output.push(ChatToolCall {
-            item_id: format!("fc_{call_id}"),
-            call_id: call_id.to_string(),
-            name: name.to_string(),
-            arguments: arguments.to_string(),
-        });
-    }
-    output
-}
-
-fn extract_chat_message_legacy_function_call(function_call: &Map<String, Value>) -> Option<ChatToolCall> {
-    let name = function_call.get("name").and_then(Value::as_str).unwrap_or("");
-    let arguments = function_call
-        .get("arguments")
-        .and_then(Value::as_str)
-        .unwrap_or("");
-    if name.is_empty() && arguments.is_empty() {
-        return None;
-    }
-    Some(ChatToolCall {
-        item_id: "fc_call_proxy".to_string(),
-        call_id: "call_proxy".to_string(),
-        name: name.to_string(),
-        arguments: arguments.to_string(),
-    })
-}
-
-pub(super) fn extract_responses_output(value: &Value) -> ResponsesOutput {
-    let Some(output) = value.get("output").and_then(Value::as_array) else {
-        return ResponsesOutput {
-            content_parts: Vec::new(),
-            reasoning_text: String::new(),
-            tool_calls: Vec::new(),
-        };
-    };
-
-    let mut content_parts = Vec::new();
-    let mut reasoning_text = String::new();
-    let mut tool_calls = Vec::new();
-
-    for item in output {
-        let Some(item) = item.as_object() else {
-            continue;
-        };
-        match item.get("type").and_then(Value::as_str) {
-            Some("message") => {
-                if item.get("role").and_then(Value::as_str) != Some("assistant") {
-                    continue;
-                }
-                let Some(content) = item.get("content").and_then(Value::as_array) else {
-                    continue;
-                };
-                for part in content {
-                    if let Some(part_obj) = part.as_object() {
-                        let part_type = part_obj.get("type").and_then(Value::as_str);
-                        if part_type == Some("reasoning_text") {
-                            if let Some(text) = part_obj.get("text").and_then(Value::as_str) {
-                                reasoning_text.push_str(text);
-                            }
-                        }
-                    }
-                    content_parts.push(part.clone());
-                }
-            }
-            Some("function_call") => {
-                if let Some(tool_call) = extract_responses_tool_call(item) {
-                    tool_calls.push(tool_call);
-                }
-            }
-            _ => {}
-        }
-    }
-
-    ResponsesOutput {
-        content_parts,
-        reasoning_text,
-        tool_calls,
-    }
-}
-
-fn extract_responses_tool_call(item: &Map<String, Value>) -> Option<Value> {
-    let call_id = item.get("call_id").and_then(Value::as_str).unwrap_or("");
-    let item_id = item.get("id").and_then(Value::as_str).unwrap_or("");
-    let name = item.get("name").and_then(Value::as_str).unwrap_or("");
-    let arguments = item.get("arguments").and_then(Value::as_str).unwrap_or("");
-    let id = if !call_id.is_empty() {
-        call_id.to_string()
-    } else if !item_id.is_empty() {
-        item_id.to_string()
-    } else {
-        "call_proxy".to_string()
-    };
-    if name.is_empty() && arguments.is_empty() && id == "call_proxy" {
-        return None;
-    }
-    Some(json!({
-        "id": id,
-        "type": "function",
-        "function": {
-            "name": name,
-            "arguments": arguments
-        }
-    }))
-}
diff --git a/src-tauri/src/proxy/openai_compat/input.rs b/src-tauri/src/proxy/openai_compat/input.rs
deleted file mode 100644
index 27aa873..0000000
--- a/src-tauri/src/proxy/openai_compat/input.rs
+++ /dev/null
@@ -1,145 +0,0 @@
-use serde_json::{json, Map, Value};
-
-use super::message::extract_text_from_part;
-
-pub(super) fn responses_input_to_chat_messages(items: &[Value]) -> Result<Vec<Value>, String> {
-    let mut messages = Vec::with_capacity(items.len());
-    for item in items {
-        messages.push(responses_input_item_to_chat_message(item)?);
-    }
-    Ok(messages)
-}
-
-fn responses_input_item_to_chat_message(item: &Value) -> Result<Value, String> {
-    let Some(item) = item.as_object() else {
-        return Err("Responses input item must be an object.".to_string());
-    };
-
-    // Cherry Studio / Codex CLI may pass `[{ role, content:[{type,text}...] }]` directly,
-    // so normalize the content parts into the string / multimodal array the Chat API expects.
-    if item.get("role").and_then(Value::as_str).is_some() {
-        let mut output = item.clone();
-        if let Some(content) = item.get("content").and_then(responses_message_content_to_chat_content) {
-            output.insert("content".to_string(), content);
-        }
-        return Ok(Value::Object(output));
-    }
-
-    let Some(item_type) = item.get("type").and_then(Value::as_str) else {
-        return Err("Responses input item must include role or type.".to_string());
-    };
-
-    match item_type {
-        "message" => responses_message_item_to_chat_message(item),
-        "function_call_output" => responses_function_call_output_item_to_chat_message(item),
-        "function_call" => responses_function_call_item_to_chat_message(item),
-        _ => Err(format!("Unsupported Responses input item type: {item_type}")),
-    }
-}
-
-fn responses_message_item_to_chat_message(item: &Map<String, Value>) -> Result<Value, String> {
-    let role = item
-        .get("role")
-        .and_then(Value::as_str)
-        .ok_or_else(|| "Responses message item must include role.".to_string())?;
-    let content = item
-        .get("content")
-        .and_then(responses_message_content_to_chat_content)
-        .unwrap_or_else(|| Value::String(String::new()));
-    Ok(json!({ "role": role, "content": content }))
-}
-
-fn responses_function_call_output_item_to_chat_message(item: &Map<String, Value>) -> Result<Value, String> {
-    let call_id = item
-        .get("call_id")
-        .and_then(Value::as_str)
-        .ok_or_else(|| "function_call_output must include call_id.".to_string())?;
-    let output = item.get("output").and_then(Value::as_str).unwrap_or("");
-    Ok(json!({
-        "role": "tool",
-        "tool_call_id": call_id,
-        "content": output
-    }))
-}
-
-fn responses_function_call_item_to_chat_message(item: &Map<String, Value>) -> Result<Value, String> {
-    let call_id = item
-        .get("call_id")
-        .and_then(Value::as_str)
-        .ok_or_else(|| "function_call must include call_id.".to_string())?;
-    let name = item.get("name").and_then(Value::as_str).unwrap_or("");
-    let arguments = item.get("arguments").and_then(Value::as_str).unwrap_or("");
-    Ok(json!({
-        "role": "assistant",
-        "content": "",
-        "tool_calls": [
-            {
-                "id": call_id,
-                "type": "function",
-                "function": { "name": name, "arguments": arguments }
-            }
-        ]
-    }))
-}
-
-fn responses_message_content_to_chat_content(value: &Value) -> Option<Value> {
-    match value {
-        Value::String(text) => Some(Value::String(text.to_string())),
-        Value::Array(parts) => {
-            let mut output_parts = Vec::new();
-            let mut combined = String::new();
-            let mut text_only = true;
-            for part in parts {
-                let Some(part) = part.as_object() else {
-                    continue;
-                };
-                let part_type = part.get("type").and_then(Value::as_str);
-                match part_type {
-                    Some("input_text") | Some("text") | Some("output_text") => {
-                        if let Some(text) = extract_text_from_part(part) {
-                            combined.push_str(&text);
-                            output_parts.push(json!({ "type": "text", "text": text }));
-                        }
-                    }
-                    Some("refusal") => {
-                        // Responses may represent refusals as a dedicated content part.
-                        let text = part
-                            .get("refusal")
-                            .or_else(|| part.get("text"))
-                            .and_then(Value::as_str)
-                            .unwrap_or("");
-                        if text.is_empty() {
-                            continue;
-                        }
-                        combined.push_str(text);
-                        output_parts.push(json!({ "type": "text", "text": text }));
-                    }
-                    Some("input_image") => {
-                        // Chat Completions expects `{type:"image_url", image_url:{url:"..."}}`.
-                        let url = match part.get("image_url") {
-                            Some(Value::String(url)) => Some(json!({ "url": url })),
-                            Some(Value::Object(object)) => object.get("url").and_then(Value::as_str).map(|url| json!({ "url": url })),
-                            _ => None,
-                        };
-                        let Some(image_url) = url else {
-                            continue;
-                        };
-                        text_only = false;
-                        output_parts.push(json!({ "type": "image_url", "image_url": image_url }));
-                    }
-                    Some("input_file") => {
-                        // Chat Completions doesn't have a standardized file/document part; skip for now.
- text_only = false; - } - _ => continue, - } - } - if text_only { - Some(Value::String(combined)) - } else { - Some(Value::Array(output_parts)) - } - } - _ => Some(Value::String(String::new())), - } -} diff --git a/src-tauri/src/proxy/openai_compat/message.rs b/src-tauri/src/proxy/openai_compat/message.rs deleted file mode 100644 index 73a2b1f..0000000 --- a/src-tauri/src/proxy/openai_compat/message.rs +++ /dev/null @@ -1,183 +0,0 @@ -use serde_json::{json, Map, Value}; - -fn extract_text_value(value: &Value) -> Option { - match value { - Value::String(text) => Some(text.to_string()), - Value::Object(object) => { - if let Some(text) = object.get("text") { - return extract_text_value(text); - } - if let Some(text) = object.get("value") { - return extract_text_value(text); - } - None - } - _ => None, - } -} - -pub(super) fn extract_text_from_part(part: &Map) -> Option { - part.get("text").and_then(extract_text_value) -} - -pub(super) fn extract_text_from_chat_content(content: Option<&Value>) -> Option { - let Some(content) = content else { - return None; - }; - match content { - Value::String(text) => Some(text.to_string()), - Value::Array(parts) => { - let mut combined = String::new(); - for part in parts { - let Some(part) = part.as_object() else { - continue; - }; - let part_type = part.get("type").and_then(Value::as_str).unwrap_or(""); - if !matches!(part_type, "text" | "input_text") { - continue; - } - if let Some(text) = extract_text_from_part(part) { - combined.push_str(&text); - } - } - if combined.trim().is_empty() { - None - } else { - Some(combined) - } - } - Value::Object(object) => object.get("text").and_then(Value::as_str).map(|t| t.to_string()), - _ => None, - } -} - -pub(super) fn chat_content_to_responses_message_parts( - content: Option<&Value>, - text_part_type: &str, -) -> Result, String> { - let Some(content) = content else { - return Ok(Vec::new()); - }; - match content { - Value::String(text) => Ok(vec![json!({ "type": text_part_type, "text": text })]), - Value::Array(parts) => { - let mut out = Vec::new(); - for part in parts { - let Some(part) = part.as_object() else { - continue; - }; - let part_type = part.get("type").and_then(Value::as_str).unwrap_or(""); - match part_type { - "text" | "input_text" => { - if let Some(text) = extract_text_from_part(part) { - out.push(json!({ "type": text_part_type, "text": text })); - } - } - "image_url" => { - let url = match part.get("image_url") { - Some(Value::String(url)) => Some(json!({ "url": url })), - Some(Value::Object(object)) => object - .get("url") - .and_then(Value::as_str) - .map(|url| json!({ "url": url })), - _ => None, - }; - if let Some(image_url) = url { - out.push(json!({ "type": "input_image", "image_url": image_url })); - } - } - "input_image" => { - if let Some(image_url) = part.get("image_url") { - out.push(json!({ "type": "input_image", "image_url": image_url.clone() })); - } - } - _ => {} - } - } - Ok(out) - } - _ => Ok(Vec::new()), - } -} - -pub(super) fn chat_tool_calls_to_responses_items(value: Option<&Value>) -> Vec { - let Some(tool_calls) = value.and_then(Value::as_array) else { - return Vec::new(); - }; - - tool_calls - .iter() - .enumerate() - .filter_map(|(idx, call)| chat_tool_call_to_responses_item(call, idx)) - .collect() -} - -fn chat_tool_call_to_responses_item(value: &Value, idx: usize) -> Option { - let call = value.as_object()?; - let call_id = call - .get("id") - .and_then(Value::as_str) - .filter(|v| !v.is_empty()) - .map(|v| v.to_string()) - .unwrap_or_else(|| 
format!("call_proxy_{idx}")); - let function = call.get("function").and_then(Value::as_object)?; - let name = function.get("name").and_then(Value::as_str).unwrap_or(""); - let arguments = stringify_any_json(function.get("arguments")); - - Some(json!({ - "type": "function_call", - "call_id": call_id, - "name": name, - "arguments": arguments - })) -} - -pub(super) fn chat_function_call_to_responses_item(value: Option<&Value>) -> Option { - let Some(value) = value else { - return None; - }; - let Some(function) = value.as_object() else { - return None; - }; - let name = function.get("name").and_then(Value::as_str).unwrap_or(""); - if name.is_empty() { - return None; - } - let arguments = stringify_any_json(function.get("arguments")); - Some(json!({ - "type": "function_call", - "call_id": "call_legacy", - "name": name, - "arguments": arguments - })) -} - -pub(super) fn stringify_any_json(value: Option<&Value>) -> String { - match value { - None => String::new(), - Some(Value::String(text)) => text.to_string(), - Some(other) => serde_json::to_string(other).unwrap_or_default(), - } -} - -pub(super) fn user_placeholder_item() -> Value { - json!({ - "type": "message", - "role": "user", - "content": [{ "type": "input_text", "text": "..." }] - }) -} - -pub(super) fn join_non_empty_lines(texts: Vec) -> Option { - let combined = texts - .into_iter() - .map(|t| t.trim().to_string()) - .filter(|t| !t.is_empty()) - .collect::>() - .join("\n"); - if combined.is_empty() { - None - } else { - Some(combined) - } -} diff --git a/src-tauri/src/proxy/openai_compat/tools.rs b/src-tauri/src/proxy/openai_compat/tools.rs deleted file mode 100644 index 21d99f2..0000000 --- a/src-tauri/src/proxy/openai_compat/tools.rs +++ /dev/null @@ -1,114 +0,0 @@ -use serde_json::{json, Map, Value}; - -pub(super) fn map_responses_tools_to_chat(value: &Value) -> Value { - let Some(tools) = value.as_array() else { - return value.clone(); - }; - let mapped = tools.iter().map(map_responses_tool_to_chat).collect::>(); - Value::Array(mapped) -} - -fn map_responses_tool_to_chat(value: &Value) -> Value { - let Some(tool) = value.as_object() else { - return value.clone(); - }; - - if tool.get("function").and_then(Value::as_object).is_some() { - return value.clone(); - } - if tool.get("type").and_then(Value::as_str) != Some("function") { - return value.clone(); - } - - let mut function = Map::new(); - if let Some(name) = tool.get("name") { - function.insert("name".to_string(), name.clone()); - } - if let Some(description) = tool.get("description") { - function.insert("description".to_string(), description.clone()); - } - if let Some(parameters) = tool.get("parameters") { - function.insert("parameters".to_string(), parameters.clone()); - } - - json!({ - "type": "function", - "function": Value::Object(function) - }) -} - -pub(super) fn map_chat_tools_to_responses(value: &Value) -> Value { - let Some(tools) = value.as_array() else { - return value.clone(); - }; - let mapped = tools.iter().map(map_chat_tool_to_responses).collect::>(); - Value::Array(mapped) -} - -fn map_chat_tool_to_responses(value: &Value) -> Value { - let Some(tool) = value.as_object() else { - return value.clone(); - }; - - if tool.get("type").and_then(Value::as_str) != Some("function") { - return value.clone(); - } - if tool.get("name").and_then(Value::as_str).is_some() { - return value.clone(); - } - let Some(function) = tool.get("function").and_then(Value::as_object) else { - return value.clone(); - }; - - let mut output = Map::new(); - output.insert("type".to_string(), 
json!("function")); - if let Some(name) = function.get("name") { - output.insert("name".to_string(), name.clone()); - } - if let Some(description) = function.get("description") { - output.insert("description".to_string(), description.clone()); - } - if let Some(parameters) = function.get("parameters") { - output.insert("parameters".to_string(), parameters.clone()); - } - Value::Object(output) -} - -pub(super) fn map_responses_tool_choice_to_chat(value: &Value) -> Value { - let Some(choice) = value.as_object() else { - return value.clone(); - }; - if choice.get("function").and_then(Value::as_object).is_some() { - return value.clone(); - } - if choice.get("type").and_then(Value::as_str) != Some("function") { - return value.clone(); - } - let name = choice.get("name").and_then(Value::as_str).unwrap_or(""); - json!({ - "type": "function", - "function": { "name": name } - }) -} - -pub(super) fn map_chat_tool_choice_to_responses(value: &Value) -> Value { - let Some(choice) = value.as_object() else { - return value.clone(); - }; - if choice.get("name").and_then(Value::as_str).is_some() { - return value.clone(); - } - if choice.get("type").and_then(Value::as_str) != Some("function") { - return value.clone(); - } - let name = choice - .get("function") - .and_then(|function| function.get("name")) - .and_then(Value::as_str) - .unwrap_or(""); - json!({ - "type": "function", - "name": name - }) -} - diff --git a/src-tauri/src/proxy/openai_compat/usage.rs b/src-tauri/src/proxy/openai_compat/usage.rs deleted file mode 100644 index 6a5c377..0000000 --- a/src-tauri/src/proxy/openai_compat/usage.rs +++ /dev/null @@ -1,41 +0,0 @@ -use serde_json::{json, Value}; - -pub(super) fn map_usage_responses_to_chat(usage: &Value) -> Option { - let usage = usage.as_object()?; - let input = usage.get("input_tokens").and_then(Value::as_u64); - let output = usage.get("output_tokens").and_then(Value::as_u64); - let total = usage - .get("total_tokens") - .and_then(Value::as_u64) - .or_else(|| match (input, output) { - (Some(input), Some(output)) => input.checked_add(output), - _ => None, - }); - if input.is_none() && output.is_none() && total.is_none() { - return None; - } - Some(json!({ - "prompt_tokens": input, - "completion_tokens": output, - "total_tokens": total - })) -} - -pub(super) fn map_usage_chat_to_responses(usage: &Value) -> Option { - let usage = usage.as_object()?; - let prompt = usage.get("prompt_tokens").and_then(Value::as_u64); - let completion = usage.get("completion_tokens").and_then(Value::as_u64); - let total = usage.get("total_tokens").and_then(Value::as_u64).or_else(|| match (prompt, completion) { - (Some(prompt), Some(completion)) => prompt.checked_add(completion), - _ => None, - }); - if prompt.is_none() && completion.is_none() && total.is_none() { - return None; - } - Some(json!({ - "input_tokens": prompt, - "output_tokens": completion, - "total_tokens": total - })) -} - diff --git a/src-tauri/src/proxy/redact.rs b/src-tauri/src/proxy/redact.rs deleted file mode 100644 index 3dd68f2..0000000 --- a/src-tauri/src/proxy/redact.rs +++ /dev/null @@ -1,25 +0,0 @@ -pub(crate) fn redact_query_param_value(message: &str, name: &str) -> String { - let needle = format!("{name}="); - let mut output = String::with_capacity(message.len()); - let mut rest = message; - - while let Some(pos) = rest.find(&needle) { - let (before, after) = rest.split_at(pos); - output.push_str(before); - output.push_str(&needle); - output.push_str("***"); - - let after = &after[needle.len()..]; - let mut end = after.len(); - for 
(idx, ch) in after.char_indices() {
- if matches!(ch, '&' | ')' | ' ' | '\n' | '\r' | '\t' | '"' | '\'') {
- end = idx;
- break;
- }
- }
- rest = &after[end..];
- }
-
- output.push_str(rest);
- output
-}
diff --git a/src-tauri/src/proxy/request_body.rs b/src-tauri/src/proxy/request_body.rs
deleted file mode 100644
index da71b9f..0000000
--- a/src-tauri/src/proxy/request_body.rs
+++ /dev/null
@@ -1,176 +0,0 @@
-use axum::body::{Body, Bytes};
-use futures_util::StreamExt;
-use std::{
- path::PathBuf,
- sync::atomic::{AtomicUsize, Ordering},
- time::{SystemTime, UNIX_EPOCH},
-};
-use tokio::io::{AsyncReadExt, AsyncWriteExt};
-
-const IN_MEMORY_LIMIT_BYTES: usize = 512 * 1024;
-const TEMP_FILE_PREFIX: &str = "token_proxy_body";
-const FILE_READ_CHUNK_BYTES: usize = 64 * 1024;
-
-static TEMP_FILE_COUNTER: AtomicUsize = AtomicUsize::new(0);
-
-// Cache the inbound request body in a "replayable" form: small bodies stay in memory; anything over the threshold is spooled to a temp file.
-// This lets us resend the exact same body on upstream retry/fallback (at the cost of reading the body in full first).
-pub(crate) struct ReplayableBody {
- inner: ReplayableBodyInner,
- len: u64,
-}
-
-enum ReplayableBodyInner {
- InMemory(Bytes),
- TempFile { path: PathBuf },
-}
-
-impl ReplayableBody {
- pub(crate) fn from_bytes(bytes: Bytes) -> Self {
- Self {
- len: bytes.len() as u64,
- inner: ReplayableBodyInner::InMemory(bytes),
- }
- }
-
- pub(crate) async fn from_body(body: Body) -> Result<Self, std::io::Error> {
- let mut stream = body.into_data_stream();
- let mut len = 0_u64;
- let mut buffer: Vec<u8> = Vec::new();
- let mut temp: Option<(PathBuf, tokio::fs::File)> = None;
-
- while let Some(next) = stream.next().await {
- let chunk = next.map_err(|err| {
- std::io::Error::new(std::io::ErrorKind::Other, format!("Read request body failed: {err}"))
- })?;
- len = len.saturating_add(chunk.len() as u64);
-
- if let Some((_, file)) = temp.as_mut() {
- file.write_all(&chunk).await?;
- continue;
- }
-
- if buffer.len().saturating_add(chunk.len()) <= IN_MEMORY_LIMIT_BYTES {
- buffer.extend_from_slice(&chunk);
- continue;
- }
-
- let (path, mut file) = create_temp_file().await?;
- if let Err(err) = file.write_all(&buffer).await {
- cleanup_temp_file(&path);
- return Err(err);
- }
- buffer.clear();
- if let Err(err) = file.write_all(&chunk).await {
- cleanup_temp_file(&path);
- return Err(err);
- }
- temp = Some((path, file));
- }
-
- if let Some((path, mut file)) = temp {
- if let Err(err) = file.flush().await {
- cleanup_temp_file(&path);
- return Err(err);
- }
- return Ok(Self {
- inner: ReplayableBodyInner::TempFile { path },
- len,
- });
- }
-
- Ok(Self {
- inner: ReplayableBodyInner::InMemory(Bytes::from(buffer)),
- len,
- })
- }
-
- pub(crate) async fn read_bytes_if_small(
- &self,
- limit: usize,
- ) -> Result<Option<Bytes>, std::io::Error> {
- let Some(len) = usize::try_from(self.len).ok() else {
- return Ok(None);
- };
- if len > limit {
- return Ok(None);
- }
-
- match &self.inner {
- ReplayableBodyInner::InMemory(bytes) => Ok(Some(bytes.clone())),
- ReplayableBodyInner::TempFile { path } => {
- let mut file = tokio::fs::File::open(path).await?;
- let mut output = Vec::with_capacity(len);
- let mut chunk = vec![0_u8; FILE_READ_CHUNK_BYTES];
- loop {
- let read = file.read(&mut chunk).await?;
- if read == 0 {
- break;
- }
- output.extend_from_slice(&chunk[..read]);
- }
- Ok(Some(Bytes::from(output)))
- }
- }
- }
-
- pub(crate) async fn to_reqwest_body(&self) -> Result<reqwest::Body, std::io::Error> {
- match &self.inner {
- ReplayableBodyInner::InMemory(bytes) => Ok(reqwest::Body::from(bytes.clone())),
- ReplayableBodyInner::TempFile { path } => {
- let file = tokio::fs::File::open(path).await?;
- Ok(reqwest::Body::from(file))
- }
- }
- }
-}
-
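For context, `ReplayableBody` (deleted above) is meant to be consumed roughly as follows — a minimal sketch, assuming a caller that retries an upstream request; `send_once` is a hypothetical helper, not part of this diff:

    // Sketch only: `send_once` is hypothetical; `from_body` / `to_reqwest_body`
    // are the real methods from the deleted module above.
    async fn send_once(_body: reqwest::Body) -> Result<(), std::io::Error> {
        // Hypothetical: issue the upstream request here.
        Ok(())
    }

    async fn forward_with_retry(inbound: axum::body::Body) -> Result<(), std::io::Error> {
        // Spool once: small bodies stay in memory, large ones go to a temp file
        // that is removed when `replayable` is dropped.
        let replayable = ReplayableBody::from_body(inbound).await?;
        for _attempt in 0..2 {
            // Each attempt gets a fresh reqwest::Body backed by the same spooled bytes.
            let outbound = replayable.to_reqwest_body().await?;
            if send_once(outbound).await.is_ok() {
                return Ok(());
            }
        }
        Ok(())
    }
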
-impl Drop for ReplayableBody {
- fn drop(&mut self) {
- if let ReplayableBodyInner::TempFile { path } = &self.inner {
- cleanup_temp_file(path);
- }
- }
-}
-
-async fn create_temp_file() -> Result<(PathBuf, tokio::fs::File), std::io::Error> {
- let path = next_temp_path();
- let file = tokio::fs::OpenOptions::new()
- .create_new(true)
- .write(true)
- .open(&path)
- .await?;
- Ok((path, file))
-}
-
-fn next_temp_path() -> PathBuf {
- let now_ns = SystemTime::now()
- .duration_since(UNIX_EPOCH)
- .unwrap_or_default()
- .as_nanos();
- let counter = TEMP_FILE_COUNTER.fetch_add(1, Ordering::Relaxed);
- let name = format!("{TEMP_FILE_PREFIX}_{now_ns}_{counter}");
- std::env::temp_dir().join(name)
-}
-
-fn cleanup_temp_file(path: &PathBuf) {
- let _ = std::fs::remove_file(path);
-}
-
-#[cfg(test)]
-impl ReplayableBody {
- fn is_temp_file(&self) -> bool {
- matches!(self.inner, ReplayableBodyInner::TempFile { .. })
- }
-
- fn temp_path(&self) -> Option<PathBuf> {
- match &self.inner {
- ReplayableBodyInner::TempFile { path } => Some(path.clone()),
- ReplayableBodyInner::InMemory(_) => None,
- }
- }
-}
-
-#[cfg(test)]
-#[path = "request_body.test.rs"]
-mod tests;
diff --git a/src-tauri/src/proxy/request_body.test.rs b/src-tauri/src/proxy/request_body.test.rs
deleted file mode 100644
index a49467b..0000000
--- a/src-tauri/src/proxy/request_body.test.rs
+++ /dev/null
@@ -1,56 +0,0 @@
-use super::*;
-use axum::body::Body;
-
-fn run_async<T>(future: impl std::future::Future<Output = T>) -> T {
- tokio::runtime::Runtime::new()
- .expect("create tokio runtime")
- .block_on(future)
-}
-
-#[test]
-fn replayable_body_small_stays_in_memory() {
- run_async(async {
- let input = vec![b'a'; 16];
- let body = ReplayableBody::from_body(Body::from(input.clone()))
- .await
- .expect("spool body");
-
- assert!(!body.is_temp_file());
- let bytes = body
- .read_bytes_if_small(1024)
- .await
- .expect("read bytes")
- .expect("bytes present");
- assert_eq!(bytes.as_ref(), input.as_slice());
- });
-}
-
-#[test]
-fn replayable_body_large_spools_to_temp_file_and_cleans_up() {
- run_async(async {
- let input = vec![b'b'; IN_MEMORY_LIMIT_BYTES + 1];
- let body = ReplayableBody::from_body(Body::from(input.clone()))
- .await
- .expect("spool body");
-
- assert!(body.is_temp_file());
- let path = body.temp_path().expect("temp path");
- assert!(
- std::fs::metadata(&path).is_ok(),
- "temp file should exist: {path:?}"
- );
-
- let bytes = body
- .read_bytes_if_small(IN_MEMORY_LIMIT_BYTES + 32)
- .await
- .expect("read bytes")
- .expect("bytes present");
- assert_eq!(bytes.as_ref(), input.as_slice());
-
- drop(body);
- assert!(
- std::fs::metadata(&path).is_err(),
- "temp file should be removed on drop: {path:?}"
- );
- });
-}
diff --git a/src-tauri/src/proxy/request_detail.rs b/src-tauri/src/proxy/request_detail.rs
deleted file mode 100644
index 9baac9c..0000000
--- a/src-tauri/src/proxy/request_detail.rs
+++ /dev/null
@@ -1,111 +0,0 @@
-use axum::http::HeaderMap;
-use serde::Serialize;
-use std::sync::atomic::{AtomicBool, Ordering};
-use tauri::{AppHandle, Emitter};
-
-use super::request_body::ReplayableBody;
-
-const BODY_TOO_LARGE_MESSAGE: &str = "[body omitted: too large]";
-
-#[derive(Clone, Default)]
-pub(crate) struct RequestDetailSnapshot {
- pub(crate) request_headers: Option<String>,
- pub(crate) request_body: Option<String>,
-}
-
-const REQUEST_DETAIL_CAPTURE_EVENT: &str = "request-detail-capture-changed";
-
-pub(crate) struct RequestDetailCapture {
- armed: AtomicBool,
- app: Option<AppHandle>,
-}
-
-impl RequestDetailCapture {
- pub(crate) fn new(app: AppHandle) ->
Self {
- Self {
- armed: AtomicBool::new(false),
- app: Some(app),
- }
- }
-
- pub(crate) fn arm(&self) {
- self.armed.store(true, Ordering::SeqCst);
- self.emit(true);
- }
-
- pub(crate) fn disarm(&self) {
- self.armed.store(false, Ordering::SeqCst);
- self.emit(false);
- }
-
- pub(crate) fn take(&self) -> bool {
- let was_armed = self.armed.swap(false, Ordering::SeqCst);
- if was_armed {
- self.emit(false);
- }
- was_armed
- }
-
- pub(crate) fn is_armed(&self) -> bool {
- self.armed.load(Ordering::SeqCst)
- }
-
- fn emit(&self, enabled: bool) {
- let Some(app) = self.app.as_ref() else {
- return;
- };
- let _ = app.emit(
- REQUEST_DETAIL_CAPTURE_EVENT,
- RequestDetailCaptureEvent { enabled },
- );
- }
-}
-
-impl Default for RequestDetailCapture {
- fn default() -> Self {
- Self {
- armed: AtomicBool::new(false),
- app: None,
- }
- }
-}
-
-pub(crate) fn serialize_request_headers(headers: &HeaderMap) -> Option<String> {
- let items: Vec<HeaderEntry> = headers
- .iter()
- .map(|(name, value)| HeaderEntry {
- name: name.to_string(),
- value: value.to_str().unwrap_or("").to_string(),
- })
- .collect();
- serde_json::to_string(&items).ok()
-}
-
-pub(crate) async fn capture_request_detail(
- headers: &HeaderMap,
- body: &ReplayableBody,
- max_body_bytes: usize,
-) -> RequestDetailSnapshot {
- let request_headers = serialize_request_headers(headers);
- let request_body = match body.read_bytes_if_small(max_body_bytes).await {
- Ok(Some(bytes)) => Some(String::from_utf8_lossy(&bytes).to_string()),
- Ok(None) => Some(BODY_TOO_LARGE_MESSAGE.to_string()),
- Err(err) => Some(format!("Failed to read request body: {err}")),
- };
-
- RequestDetailSnapshot {
- request_headers,
- request_body,
- }
-}
-
-#[derive(Serialize)]
-struct HeaderEntry {
- name: String,
- value: String,
-}
-
-#[derive(Clone, Serialize)]
-struct RequestDetailCaptureEvent {
- enabled: bool,
-}
diff --git a/src-tauri/src/proxy/request_token_estimate.rs b/src-tauri/src/proxy/request_token_estimate.rs
deleted file mode 100644
index 6c082bc..0000000
--- a/src-tauri/src/proxy/request_token_estimate.rs
+++ /dev/null
@@ -1,556 +0,0 @@
-use base64::Engine;
-use serde_json::Value;
-
-use super::token_estimator::{self, TokenProvider};
-
-const OPENAI_MESSAGE_OVERHEAD: u64 = 3;
-const OPENAI_NAME_OVERHEAD: u64 = 3;
-const OPENAI_TOOL_OVERHEAD: u64 = 8;
-const OPENAI_FIXED_OVERHEAD: u64 = 3;
-
-const DEFAULT_IMAGE_TOKENS: u64 = 520;
-const DEFAULT_AUDIO_TOKENS: u64 = 256;
-const DEFAULT_VIDEO_TOKENS: u64 = 4096 * 2;
-const DEFAULT_FILE_TOKENS: u64 = 4096;
-
-const IMAGE_DECODE_LIMIT_BYTES: usize = 64 * 1024;
-
-pub(crate) fn estimate_request_input_tokens(value: &Value, model: Option<&str>) -> Option<u64> {
- let message_stats = sum_message_stats(value, model);
- let root_tokens = sum_root_tokens(value, model);
- let tool_count = count_tools(value);
- let total = message_stats
- .tokens
- .saturating_add(root_tokens)
- .saturating_add(openai_overhead_tokens(
- message_stats.message_count,
- message_stats.name_count,
- tool_count,
- ));
-
- if total == 0 {
- None
- } else {
- Some(total)
- }
-}
-
-#[derive(Default)]
-struct MessageStats {
- tokens: u64,
- message_count: u64,
- name_count: u64,
-}
-
-fn sum_message_stats(value: &Value, model: Option<&str>) -> MessageStats {
- let Some(messages) = value.get("messages").and_then(Value::as_array) else {
- return MessageStats::default();
- };
-
- let mut stats = MessageStats::default();
- for message in messages {
- stats.message_count = stats.message_count.saturating_add(1);
- if
message.get("name").and_then(Value::as_str).is_some() { - stats.name_count = stats.name_count.saturating_add(1); - } - stats.tokens = stats - .tokens - .saturating_add(sum_message_text_tokens(message, model)) - .saturating_add(sum_message_media_tokens(message, model)); - } - stats -} - -fn sum_root_tokens(value: &Value, model: Option<&str>) -> u64 { - let mut total = 0u64; - - if let Some(prompt) = value.get("prompt") { - total = total.saturating_add(sum_text_value(prompt, model)); - } - - if let Some(input) = value.get("input") { - total = total.saturating_add(sum_input_text_tokens(input, model)); - total = total.saturating_add(sum_input_media_tokens(input, model)); - } - - if let Some(system) = value.get("system") { - total = total.saturating_add(sum_text_value(system, model)); - } - - if let Some(system_instruction) = value.get("system_instruction") { - total = total.saturating_add(sum_text_value(system_instruction, model)); - } - - if let Some(system_instruction) = value.get("systemInstruction") { - total = total.saturating_add(sum_text_value(system_instruction, model)); - } - - if let Some(instructions) = value.get("instructions") { - total = total.saturating_add(sum_text_value(instructions, model)); - } - - if let Some(contents) = value.get("contents") { - total = total.saturating_add(sum_gemini_contents_text_tokens(contents, model)); - total = total.saturating_add(sum_gemini_contents_media_tokens(contents, model)); - } - - total -} - -fn count_tools(value: &Value) -> u64 { - value - .get("tools") - .and_then(Value::as_array) - .map(|items| items.len() as u64) - .unwrap_or(0) -} - -fn openai_overhead_tokens(message_count: u64, name_count: u64, tool_count: u64) -> u64 { - if message_count == 0 && name_count == 0 && tool_count == 0 { - return 0; - } - message_count - .saturating_mul(OPENAI_MESSAGE_OVERHEAD) - .saturating_add(name_count.saturating_mul(OPENAI_NAME_OVERHEAD)) - .saturating_add(tool_count.saturating_mul(OPENAI_TOOL_OVERHEAD)) - .saturating_add(OPENAI_FIXED_OVERHEAD) -} - -fn sum_message_text_tokens(message: &Value, model: Option<&str>) -> u64 { - let Some(content) = message.get("content") else { - return 0; - }; - sum_content_text_tokens(content, model) -} - -fn sum_message_media_tokens(message: &Value, model: Option<&str>) -> u64 { - let Some(content) = message.get("content") else { - return 0; - }; - sum_content_media_tokens(content, model) -} - -fn sum_input_text_tokens(input: &Value, model: Option<&str>) -> u64 { - match input { - Value::String(_) => sum_text_value(input, model), - Value::Array(items) => items.iter().fold(0u64, |acc, item| { - let mut total = acc; - if item.is_string() { - total = total.saturating_add(sum_text_value(item, model)); - } else if let Some(content) = item.get("content") { - total = total.saturating_add(sum_content_text_tokens(content, model)); - } else if let Some(text) = item.get("text") { - total = total.saturating_add(sum_text_value(text, model)); - } - total - }), - Value::Object(object) => object - .get("content") - .map(|content| sum_content_text_tokens(content, model)) - .or_else(|| object.get("text").map(|text| sum_text_value(text, model))) - .unwrap_or(0), - _ => 0, - } -} - -fn sum_input_media_tokens(input: &Value, model: Option<&str>) -> u64 { - match input { - Value::Array(items) => items.iter().fold(0u64, |acc, item| { - let mut total = acc; - if let Some(content) = item.get("content") { - total = total.saturating_add(sum_content_media_tokens(content, model)); - } - total - }), - Value::Object(object) => object - 
.get("content") - .map(|content| sum_content_media_tokens(content, model)) - .unwrap_or(0), - _ => 0, - } -} - -fn sum_gemini_contents_text_tokens(contents: &Value, model: Option<&str>) -> u64 { - let Some(contents) = contents.as_array() else { - return 0; - }; - contents.iter().fold(0u64, |acc, content| { - let mut total = acc; - if let Some(parts) = content.get("parts").and_then(Value::as_array) { - for part in parts { - if let Some(text) = part.get("text") { - total = total.saturating_add(sum_text_value(text, model)); - } - } - } - total - }) -} - -fn sum_gemini_contents_media_tokens(contents: &Value, model: Option<&str>) -> u64 { - let Some(contents) = contents.as_array() else { - return 0; - }; - contents.iter().fold(0u64, |acc, content| { - let mut total = acc; - if let Some(parts) = content.get("parts").and_then(Value::as_array) { - for part in parts { - total = total.saturating_add(sum_gemini_part_media_tokens(part, model)); - } - } - total - }) -} - -fn sum_gemini_part_media_tokens(part: &Value, model: Option<&str>) -> u64 { - if let Some(inline) = part.get("inlineData") { - let mime = inline.get("mimeType").and_then(Value::as_str); - let data = inline.get("data").and_then(Value::as_str); - return estimate_media_tokens(model, mime, data, None); - } - if let Some(file_data) = part.get("fileData") { - let mime = file_data.get("mimeType").and_then(Value::as_str); - return estimate_media_tokens(model, mime, None, None); - } - 0 -} - -fn sum_content_text_tokens(content: &Value, model: Option<&str>) -> u64 { - match content { - Value::String(_) => sum_text_value(content, model), - Value::Array(items) => items.iter().fold(0u64, |acc, item| { - let mut total = acc; - if let Some(text) = item.get("text") { - total = total.saturating_add(sum_text_value(text, model)); - } else if item.is_string() { - total = total.saturating_add(sum_text_value(item, model)); - } - total - }), - _ => 0, - } -} - -fn sum_content_media_tokens(content: &Value, model: Option<&str>) -> u64 { - let Some(items) = content.as_array() else { - return 0; - }; - items.iter().fold(0u64, |acc, item| { - acc.saturating_add(sum_openai_part_media_tokens(item, model)) - .saturating_add(sum_anthropic_part_media_tokens(item, model)) - }) -} - -fn sum_openai_part_media_tokens(part: &Value, model: Option<&str>) -> u64 { - let Some(part_type) = part.get("type").and_then(Value::as_str) else { - return 0; - }; - match part_type { - "image_url" => { - let image = part.get("image_url"); - let (url, detail) = match image { - Some(Value::String(url)) => (Some(url.as_str()), None), - Some(Value::Object(obj)) => ( - obj.get("url").and_then(Value::as_str), - obj.get("detail").and_then(Value::as_str), - ), - _ => (None, None), - }; - estimate_media_tokens(model, Some("image/*"), url, detail) - } - "input_audio" => { - let audio = part.get("input_audio").and_then(Value::as_object); - let data = audio.and_then(|obj| obj.get("data").and_then(Value::as_str)); - estimate_media_tokens(model, Some("audio/*"), data, None) - } - _ => 0, - } -} - -fn sum_anthropic_part_media_tokens(part: &Value, model: Option<&str>) -> u64 { - let Some(part_type) = part.get("type").and_then(Value::as_str) else { - return 0; - }; - if part_type != "image" { - return 0; - } - let source = part.get("source"); - let mime = source - .and_then(Value::as_object) - .and_then(|obj| obj.get("media_type")) - .and_then(Value::as_str); - let data = source - .and_then(Value::as_object) - .and_then(|obj| obj.get("data")) - .and_then(Value::as_str); - estimate_media_tokens(model, 
mime, data, None)
-}
-
-fn sum_text_value(value: &Value, model: Option<&str>) -> u64 {
- match value {
- Value::String(text) => token_estimator::estimate_text_tokens(model, text),
- Value::Array(items) => items.iter().fold(0u64, |acc, item| {
- acc.saturating_add(sum_text_value(item, model))
- }),
- Value::Object(object) => object
- .get("text")
- .and_then(Value::as_str)
- .map(|text| token_estimator::estimate_text_tokens(model, text))
- .unwrap_or(0),
- _ => 0,
- }
-}
-
-fn estimate_media_tokens(
- model: Option<&str>,
- mime: Option<&str>,
- data: Option<&str>,
- detail: Option<&str>,
-) -> u64 {
- let kind = media_kind_from_mime(mime);
- match kind {
- MediaKind::Image => estimate_image_tokens(model, data, detail),
- MediaKind::Audio => DEFAULT_AUDIO_TOKENS,
- MediaKind::Video => DEFAULT_VIDEO_TOKENS,
- MediaKind::Other => DEFAULT_FILE_TOKENS,
- }
-}
-
-fn estimate_image_tokens(model: Option<&str>, data: Option<&str>, detail: Option<&str>) -> u64 {
- let provider = token_estimator::provider_for_model(model);
- if provider != TokenProvider::OpenAI {
- return DEFAULT_IMAGE_TOKENS;
- }
-
- let normalized = model.unwrap_or("").trim().to_ascii_lowercase();
-
- if normalized.contains("glm-4") {
- return 1047;
- }
-
- if let Some(multiplier) = patch_multiplier(&normalized) {
- if let Some((width, height)) = decode_image_dimensions(data) {
- return estimate_patch_tokens(width, height, multiplier);
- }
- return base_tile_tokens(&normalized).0;
- }
-
- let (base_tokens, tile_tokens) = base_tile_tokens(&normalized);
- if detail == Some("low") {
- return base_tokens;
- }
-
- if let Some((width, height)) = decode_image_dimensions(data) {
- return estimate_tile_tokens(width, height, base_tokens, tile_tokens);
- }
-
- base_tokens
-}
-
-fn media_kind_from_mime(mime: Option<&str>) -> MediaKind {
- let Some(mime) = mime else {
- return MediaKind::Other;
- };
- let normalized = mime.to_ascii_lowercase();
- if normalized.starts_with("image/") {
- return MediaKind::Image;
- }
- if normalized.starts_with("audio/") {
- return MediaKind::Audio;
- }
- if normalized.starts_with("video/") {
- return MediaKind::Video;
- }
- MediaKind::Other
-}
-
-fn patch_multiplier(model: &str) -> Option<f64> {
- if model.contains("gpt-4.1-mini") || model.contains("gpt-5-mini") {
- return Some(1.62);
- }
- if model.contains("gpt-4.1-nano") || model.contains("gpt-5-nano") {
- return Some(2.46);
- }
- if model.contains("o4-mini") {
- return Some(1.72);
- }
- None
-}
-
-fn base_tile_tokens(model: &str) -> (u64, u64) {
- if model.contains("gpt-4o-mini") {
- return (2833, 5667);
- }
- if model.contains("gpt-5-chat-latest")
- || (model.contains("gpt-5") && !model.contains("mini") && !model.contains("nano"))
- {
- return (70, 140);
- }
- if model.starts_with("o1") || model.starts_with("o3") || model.contains("o1-pro") {
- return (75, 150);
- }
- if model.contains("computer-use-preview") {
- return (65, 129);
- }
- if model.contains("4.1") || model.contains("4o") || model.contains("4.5") {
- return (85, 170);
- }
- (85, 170)
-}
-
-// Tile rule: cap the longest side at 2048, scale the shortest side to 768, then split into 512px tiles.
-fn estimate_tile_tokens(width: u32, height: u32, base: u64, tile: u64) -> u64 {
- if width == 0 || height == 0 {
- return base;
- }
- let (mut w, mut h) = (width as f64, height as f64);
- let max_side = w.max(h);
- if max_side > 2048.0 {
- let ratio = 2048.0 / max_side;
- w *= ratio;
- h *= ratio;
- }
- let min_side = w.min(h);
- if min_side > 0.0 {
- let ratio = 768.0 / min_side;
- w *= ratio;
- h *= ratio;
- }
- let tiles_w = (w / 512.0).ceil() as u64;
- let tiles_h =
(h / 512.0).ceil() as u64;
- base.saturating_add(tiles_w.saturating_mul(tiles_h).saturating_mul(tile))
-}
-
-// Patch rule: 32x32 patches, capped at 1536, estimated with the per-model multiplier.
-fn estimate_patch_tokens(width: u32, height: u32, multiplier: f64) -> u64 {
- if width == 0 || height == 0 {
- return 0;
- }
- let (mut w, mut h) = (width as f64, height as f64);
- let mut patches = (w / 32.0).ceil() * (h / 32.0).ceil();
- if patches > 1536.0 {
- let ratio = (1536.0 / patches).sqrt();
- w *= ratio;
- h *= ratio;
- patches = (w / 32.0).ceil() * (h / 32.0).ceil();
- }
- (patches * multiplier).ceil() as u64
-}
-
-fn decode_image_dimensions(data: Option<&str>) -> Option<(u32, u32)> {
- let Some(data) = data else {
- return None;
- };
-
- if let Some((mime, payload)) = parse_data_uri(data) {
- let bytes = decode_base64_prefix(payload, IMAGE_DECODE_LIMIT_BYTES)?;
- return decode_dimensions_from_bytes(mime, &bytes);
- }
-
- let bytes = decode_base64_prefix(data, IMAGE_DECODE_LIMIT_BYTES)?;
- decode_dimensions_from_bytes(None, &bytes)
-}
-
-fn parse_data_uri(data: &str) -> Option<(Option<&str>, &str)> {
- let data = data.strip_prefix("data:")?;
- let (meta, payload) = data.split_once(',')?;
- let (mime, encoding) = meta.split_once(';')?;
- if encoding.trim() != "base64" {
- return None;
- }
- Some((Some(mime.trim()), payload))
-}
-
-fn decode_base64_prefix(data: &str, max_bytes: usize) -> Option<Vec<u8>> {
- let max_chars = ((max_bytes + 2) / 3) * 4;
- let mut slice_len = data.len().min(max_chars);
- slice_len -= slice_len % 4;
- if slice_len == 0 {
- return None;
- }
- let prefix = &data[..slice_len];
- base64::engine::general_purpose::STANDARD
- .decode(prefix)
- .ok()
-}
-
-fn decode_dimensions_from_bytes(mime: Option<&str>, bytes: &[u8]) -> Option<(u32, u32)> {
- if let Some(mime) = mime {
- let normalized = mime.to_ascii_lowercase();
- if normalized.contains("png") {
- return png_dimensions(bytes);
- }
- if normalized.contains("jpeg") || normalized.contains("jpg") {
- return jpeg_dimensions(bytes);
- }
- }
- png_dimensions(bytes).or_else(|| jpeg_dimensions(bytes))
-}
-
-fn png_dimensions(bytes: &[u8]) -> Option<(u32, u32)> {
- const PNG_SIGNATURE: [u8; 8] = [0x89, b'P', b'N', b'G', 0x0D, 0x0A, 0x1A, 0x0A];
- if bytes.len() < 24 || bytes[..8] != PNG_SIGNATURE {
- return None;
- }
- if &bytes[12..16] != b"IHDR" {
- return None;
- }
- let width = u32::from_be_bytes(bytes[16..20].try_into().ok()?);
- let height = u32::from_be_bytes(bytes[20..24].try_into().ok()?);
- Some((width, height))
-}
-
-fn jpeg_dimensions(bytes: &[u8]) -> Option<(u32, u32)> {
- if bytes.len() < 4 || bytes[0] != 0xFF || bytes[1] != 0xD8 {
- return None;
- }
- let mut index = 2usize;
- while index + 3 < bytes.len() {
- if bytes[index] != 0xFF {
- index += 1;
- continue;
- }
- let marker = bytes[index + 1];
- if marker == 0xD8 || marker == 0xD9 {
- index += 2;
- continue;
- }
- if index + 3 >= bytes.len() {
- break;
- }
- let length = u16::from_be_bytes([bytes[index + 2], bytes[index + 3]]) as usize;
- if length < 2 || index + 2 + length > bytes.len() {
- break;
- }
- if is_sof_marker(marker) {
- if length >= 7 {
- let height = u16::from_be_bytes([bytes[index + 5], bytes[index + 6]]);
- let width = u16::from_be_bytes([bytes[index + 7], bytes[index + 8]]);
- return Some((width as u32, height as u32));
- }
- return None;
- }
- index += 2 + length;
- }
- None
-}
-
-fn is_sof_marker(marker: u8) -> bool {
- matches!(
- marker,
- 0xC0 | 0xC1 | 0xC2 | 0xC3 | 0xC5 | 0xC6 | 0xC7 | 0xC9 | 0xCA | 0xCB | 0xCD
- | 0xCE | 0xCF
- )
-}
-
-#[derive(Clone, Copy, Debug,
PartialEq, Eq)] -enum MediaKind { - Image, - Audio, - Video, - Other, -} - -#[cfg(test)] -#[path = "request_token_estimate.test.rs"] -mod tests; diff --git a/src-tauri/src/proxy/request_token_estimate.test.rs b/src-tauri/src/proxy/request_token_estimate.test.rs deleted file mode 100644 index 633e66d..0000000 --- a/src-tauri/src/proxy/request_token_estimate.test.rs +++ /dev/null @@ -1,140 +0,0 @@ -use base64::Engine; -use serde_json::json; - -use super::estimate_request_input_tokens; - -const PNG_SIGNATURE: [u8; 8] = [0x89, b'P', b'N', b'G', 0x0D, 0x0A, 0x1A, 0x0A]; - -fn encode_png_base64(width: u32, height: u32) -> String { - let mut bytes = vec![0u8; 24]; - bytes[..8].copy_from_slice(&PNG_SIGNATURE); - bytes[8..12].copy_from_slice(&13u32.to_be_bytes()); - bytes[12..16].copy_from_slice(b"IHDR"); - bytes[16..20].copy_from_slice(&width.to_be_bytes()); - bytes[20..24].copy_from_slice(&height.to_be_bytes()); - base64::engine::general_purpose::STANDARD.encode(bytes) -} - -#[test] -fn estimates_openai_overhead_tokens() { - let value = json!({ - "model": "gpt-4o", - "messages": [ - {"role": "user", "content": ""}, - {"role": "assistant", "name": "bot", "content": ""} - ], - "tools": [{}, {}, {}] - }); - let tokens = estimate_request_input_tokens(&value, Some("gpt-4o")).unwrap(); - // 2 messages *3 + 1 name *3 + 3 tools *8 + 3 fixed - assert_eq!(tokens, 36); -} - -#[test] -fn estimates_low_detail_image_tokens_for_openai() { - let value = json!({ - "model": "gpt-4o", - "messages": [ - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": "https://example.com/a.png", "detail": "low"}} - ] - } - ] - }); - let tokens = estimate_request_input_tokens(&value, Some("gpt-4o")).unwrap(); - // base 85 + overhead (1 message + fixed) - assert_eq!(tokens, 91); -} - -#[test] -fn estimates_image_tokens_for_non_openai_model() { - let value = json!({ - "model": "claude-3-opus", - "messages": [ - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}} - ] - } - ] - }); - let tokens = estimate_request_input_tokens(&value, Some("claude-3-opus")).unwrap(); - // 520 default image + overhead 6 - assert_eq!(tokens, 526); -} - -#[test] -fn estimates_patch_image_tokens_for_gpt_4_1_mini() { - let data = encode_png_base64(32, 32); - let value = json!({ - "model": "gpt-4.1-mini", - "messages": [ - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": format!("data:image/png;base64,{}", data)}} - ] - } - ] - }); - let tokens = estimate_request_input_tokens(&value, Some("gpt-4.1-mini")).unwrap(); - // patch 1 * 1.62 => ceil 2, overhead 6 - assert_eq!(tokens, 8); -} - -#[test] -fn estimates_tile_image_tokens_for_gpt_4o() { - let data = encode_png_base64(1024, 1024); - let value = json!({ - "model": "gpt-4o", - "messages": [ - { - "role": "user", - "content": [ - {"type": "image_url", "image_url": {"url": format!("data:image/png;base64,{}", data)}} - ] - } - ] - }); - let tokens = estimate_request_input_tokens(&value, Some("gpt-4o")).unwrap(); - // base 85 + tiles 4 * 170 + overhead 6 - assert_eq!(tokens, 771); -} - -#[test] -fn estimates_gemini_inline_image_tokens() { - let data = encode_png_base64(64, 64); - let value = json!({ - "model": "gemini-1.5-flash", - "contents": [ - { - "parts": [ - {"inlineData": {"mimeType": "image/png", "data": data}} - ] - } - ] - }); - let tokens = estimate_request_input_tokens(&value, Some("gemini-1.5-flash")).unwrap(); - assert_eq!(tokens, 520); -} - -#[test] -fn 
estimates_openai_input_audio_tokens() {
- let value = json!({
- "model": "gpt-4o",
- "messages": [
- {
- "role": "user",
- "content": [
- {"type": "input_audio", "input_audio": {"data": "dGVzdA=="}}
- ]
- }
- ]
- });
- let tokens = estimate_request_input_tokens(&value, Some("gpt-4o")).unwrap();
- // audio 256 + overhead 6
- assert_eq!(tokens, 262);
-}
diff --git a/src-tauri/src/proxy/response.rs b/src-tauri/src/proxy/response.rs
deleted file mode 100644
index 7827e3a..0000000
--- a/src-tauri/src/proxy/response.rs
+++ /dev/null
@@ -1,200 +0,0 @@
-use axum::{body::Bytes, response::Response};
-use serde_json::Value;
-use std::{
- sync::Arc,
- time::{Instant, SystemTime, UNIX_EPOCH},
-};
-
-use super::{
- http,
- log::{LogContext, LogWriter},
- openai_compat::FormatTransform,
- token_rate::TokenRateTracker,
- request_detail::RequestDetailSnapshot,
- RequestMeta,
-};
-
-const PROVIDER_OPENAI: &str = "openai";
-const PROVIDER_OPENAI_RESPONSES: &str = "openai-response";
-const PROVIDER_ANTHROPIC: &str = "anthropic";
-const PROVIDER_ANTIGRAVITY: &str = "antigravity";
-const PROVIDER_GEMINI: &str = "gemini";
-const PROVIDER_CODEX: &str = "codex";
-const RESPONSE_ERROR_LIMIT_BYTES: usize = 256 * 1024;
-
-pub(super) async fn build_proxy_response(
- meta: &RequestMeta,
- provider: &str,
- upstream_id: &str,
- inbound_path: &str,
- upstream_res: reqwest::Response,
- log: Arc<LogWriter>,
- token_rate: Arc<TokenRateTracker>,
- start: Instant,
- response_transform: FormatTransform,
- request_detail: Option<RequestDetailSnapshot>,
-) -> Response {
- let status = upstream_res.status();
- let mut response_headers = http::filter_response_headers(upstream_res.headers());
- let (request_headers, request_body) = request_detail
- .map(|detail| (detail.request_headers, detail.request_body))
- .unwrap_or((None, None));
- let context = LogContext {
- path: inbound_path.to_string(),
- provider: provider.to_string(),
- upstream_id: upstream_id.to_string(),
- model: meta.original_model.clone(),
- mapped_model: meta.mapped_model.clone(),
- stream: meta.stream,
- status: status.as_u16(),
- upstream_request_id: http::extract_request_id(upstream_res.headers()),
- request_headers,
- request_body,
- ttfb_ms: None,
- start,
- };
- let model_override = meta.model_override();
- if response_transform != FormatTransform::None {
- // The body will change; let hyper recalculate the content length.
- response_headers.remove(axum::http::header::CONTENT_LENGTH);
- }
- let model_for_tokens = meta
- .mapped_model
- .as_deref()
- .or(meta.original_model.as_deref())
- .map(|value| value.to_string());
- let request_tracker = token_rate
- .register(model_for_tokens, meta.estimated_input_tokens)
- .await;
- let should_stream = meta.stream && !status.is_client_error() && !status.is_server_error();
- if should_stream {
- dispatch::build_stream_response(
- status,
- upstream_res,
- response_headers,
- context,
- log,
- request_tracker,
- response_transform,
- model_override,
- meta.estimated_input_tokens,
- )
- .await
- } else {
- dispatch::build_buffered_response(
- status,
- upstream_res,
- response_headers,
- context,
- log,
- request_tracker,
- response_transform,
- model_override,
- meta.estimated_input_tokens,
- )
- .await
- }
-}
-
-pub(super) async fn build_proxy_response_buffered(
- meta: &RequestMeta,
- provider: &str,
- upstream_id: &str,
- inbound_path: &str,
- upstream_res: reqwest::Response,
- log: Arc<LogWriter>,
- token_rate: Arc<TokenRateTracker>,
- start: Instant,
- response_transform: FormatTransform,
- request_detail: Option<RequestDetailSnapshot>,
-) -> Response {
- let status = upstream_res.status();
- let mut response_headers = http::filter_response_headers(upstream_res.headers());
- let (request_headers, request_body) = request_detail
- .map(|detail| (detail.request_headers, detail.request_body))
- .unwrap_or((None, None));
- let context = LogContext {
- path: inbound_path.to_string(),
- provider: provider.to_string(),
- upstream_id: upstream_id.to_string(),
- model: meta.original_model.clone(),
- mapped_model: meta.mapped_model.clone(),
- stream: meta.stream,
- status: status.as_u16(),
- upstream_request_id: http::extract_request_id(upstream_res.headers()),
- request_headers,
- request_body,
- ttfb_ms: None,
- start,
- };
- let model_override = meta.model_override();
- if response_transform != FormatTransform::None {
- response_headers.remove(axum::http::header::CONTENT_LENGTH);
- }
- let model_for_tokens = meta
- .mapped_model
- .as_deref()
- .or(meta.original_model.as_deref())
- .map(|value| value.to_string());
- let request_tracker = token_rate
- .register(model_for_tokens, meta.estimated_input_tokens)
- .await;
- dispatch::build_buffered_response(
- status,
- upstream_res,
- response_headers,
- context,
- log,
- request_tracker,
- response_transform,
- model_override,
- meta.estimated_input_tokens,
- )
- .await
-}
-
-#[cfg(test)]
-fn stream_chat_to_responses<E>(
- upstream: impl futures_util::stream::Stream<Item = Result<Bytes, E>> + Unpin + Send + 'static,
- context: LogContext,
- log: Arc<LogWriter>,
- token_tracker: super::token_rate::RequestTokenTracker,
-) -> impl futures_util::stream::Stream<Item = Result<Bytes, std::io::Error>> + Send
-where
- E: std::error::Error + Send + Sync + 'static,
-{
- chat_to_responses::stream_chat_to_responses(upstream, context, log, token_tracker)
-}
-
-fn responses_event_sse(event: Value) -> Bytes {
- Bytes::from(format!("data: {}\n\n", event.to_string()))
-}
-
-fn anthropic_event_sse(event_type: &str, event: Value) -> Bytes {
- Bytes::from(format!("event: {event_type}\ndata: {}\n\n", event.to_string()))
-}
-
-fn now_ms() -> u64 {
- SystemTime::now()
- .duration_since(UNIX_EPOCH)
- .unwrap_or_default()
- .as_millis() as u64
-}
-
-mod chat_to_responses;
-mod anthropic_to_responses;
-mod responses_to_chat;
-mod responses_to_anthropic;
-mod kiro_to_anthropic;
-mod kiro_to_responses;
-mod kiro_to_responses_helpers;
-mod kiro_to_responses_stream;
-mod dispatch;
-mod streaming;
-mod token_count;
-mod upstream_read;
-mod upstream_stream;
-
-#[cfg(test)]
-#[path =
"response.test.rs"] -mod tests; diff --git a/src-tauri/src/proxy/response.test.rs b/src-tauri/src/proxy/response.test.rs deleted file mode 100644 index f0615c5..0000000 --- a/src-tauri/src/proxy/response.test.rs +++ /dev/null @@ -1,476 +0,0 @@ -use super::*; -use axum::body::Bytes; -use futures_util::StreamExt; -use serde_json::{json, Value}; -use sqlx::{sqlite::SqlitePoolOptions, Row, SqlitePool}; -use std::{ - sync::Arc, - time::{Duration, Instant}, -}; - -use super::super::log::{LogContext, LogWriter}; -use tokio::time::{sleep, Instant as TokioInstant}; - -fn run_async(future: impl std::future::Future) -> T { - tokio::runtime::Runtime::new() - .expect("create tokio runtime") - .block_on(future) -} - -async fn create_test_sqlite_pool() -> SqlitePool { - let pool = SqlitePoolOptions::new() - .max_connections(1) - .connect("sqlite::memory:") - .await - .expect("connect sqlite"); - crate::proxy::sqlite::init_schema(&pool) - .await - .expect("init sqlite schema"); - pool -} - -fn parse_sse_json(bytes: &Bytes) -> Option { - let text = String::from_utf8_lossy(bytes); - let Some(data) = text.strip_prefix("data: ") else { - panic!("unexpected SSE chunk: {text:?}"); - }; - let data = data.trim(); - if data == "[DONE]" { - return None; - } - Some(serde_json::from_str::(data).expect("parse SSE JSON")) -} - -async fn setup_responses_stream() -> (Arc, LogContext, SqlitePool) { - let sqlite_pool = create_test_sqlite_pool().await; - let log = Arc::new(LogWriter::new(Some(sqlite_pool.clone()))); - let context = LogContext { - path: "/v1/responses".to_string(), - provider: "openai-response".to_string(), - upstream_id: "unit-test".to_string(), - model: Some("unit-model".to_string()), - mapped_model: Some("unit-model".to_string()), - stream: true, - status: 200, - upstream_request_id: None, - request_headers: None, - request_body: None, - ttfb_ms: None, - start: Instant::now(), - }; - (log, context, sqlite_pool) -} - -async fn collect_responses_to_chat_chunks( - upstream: impl futures_util::stream::Stream> - + Unpin - + Send - + 'static, - context: LogContext, - log: Arc, -) -> Vec { - let token_tracker = super::super::token_rate::TokenRateTracker::new() - .register(None, None) - .await; - super::responses_to_chat::stream_responses_to_chat(upstream, context, log, token_tracker) - .map(|item| item.expect("stream item")) - .collect() - .await -} - -async fn read_first_usage_tokens( - pool: &SqlitePool, -) -> (Option, Option, Option) { - let deadline = TokioInstant::now() + std::time::Duration::from_secs(2); - loop { - let row = sqlx::query( - "SELECT input_tokens, output_tokens, total_tokens FROM request_logs ORDER BY id LIMIT 1", - ) - .fetch_optional(pool) - .await - .ok() - .flatten(); - if let Some(row) = row { - let input_tokens = row.try_get::, _>("input_tokens").unwrap_or_default(); - let output_tokens = row - .try_get::, _>("output_tokens") - .unwrap_or_default(); - let total_tokens = row.try_get::, _>("total_tokens").unwrap_or_default(); - return (input_tokens, output_tokens, total_tokens); - } - if TokioInstant::now() >= deadline { - panic!("log entry"); - } - sleep(Duration::from_millis(10)).await; - } -} - -#[test] -fn stream_responses_to_chat_emits_role_delta_and_done_and_logs_usage() { - run_async(async { - let (log, context, sqlite_pool) = setup_responses_stream().await; - - let upstream = futures_util::stream::iter(vec![ - Ok(Bytes::from( - "data: {\"type\":\"response.output_text.delta\",\"delta\":\"Hello\"}\n\n", - )), - Ok(Bytes::from( - "data: 
{\"type\":\"response.output_text.delta\",\"delta\":\" world\"}\n\n", - )), - // Usage can appear on a different event; collector should still pick it up. - Ok(Bytes::from( - "data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"total_tokens\":3}}}\n\n", - )), - Ok(Bytes::from("data: [DONE]\n\n")), - ]); - - let chunks = collect_responses_to_chat_chunks(upstream, context, log.clone()).await; - - assert_eq!(chunks.len(), 5); - - let first = parse_sse_json(&chunks[0]).expect("json"); - let id = first["id"].as_str().expect("id"); - assert!(id.starts_with("chatcmpl_proxy_")); - assert_eq!(first["model"], json!("unit-model")); - assert_eq!(first["choices"][0]["delta"]["role"], json!("assistant")); - assert_eq!(first["choices"][0]["delta"]["content"], json!("")); - - let second = parse_sse_json(&chunks[1]).expect("json"); - assert_eq!(second["id"], json!(id)); - assert_eq!(second["choices"][0]["delta"]["content"], json!("Hello")); - - let third = parse_sse_json(&chunks[2]).expect("json"); - assert_eq!(third["id"], json!(id)); - assert_eq!(third["choices"][0]["delta"]["content"], json!(" world")); - - let done = parse_sse_json(&chunks[3]).expect("json"); - assert_eq!(done["id"], json!(id)); - assert_eq!(done["choices"][0]["finish_reason"], json!("stop")); - - assert_eq!(String::from_utf8_lossy(&chunks[4]), "data: [DONE]\n\n"); - - let (input_tokens, output_tokens, total_tokens) = - read_first_usage_tokens(&sqlite_pool).await; - assert_eq!(input_tokens, Some(1)); - assert_eq!(output_tokens, Some(2)); - assert_eq!(total_tokens, Some(3)); - }); -} - -#[test] -fn stream_responses_to_chat_emits_tool_call_deltas_and_finish_reason() { - run_async(async { - let (log, context, _sqlite_pool) = setup_responses_stream().await; - - let upstream = futures_util::stream::iter(vec![ - Ok(Bytes::from( - "data: {\"type\":\"response.output_item.added\",\"output_index\":0,\"item\":{\"id\":\"fc_1\",\"type\":\"function_call\",\"status\":\"in_progress\",\"call_id\":\"call_foo\",\"name\":\"getRandomNumber\",\"arguments\":\"\"}}\n\n", - )), - Ok(Bytes::from( - "data: {\"type\":\"response.function_call_arguments.delta\",\"item_id\":\"fc_1\",\"output_index\":0,\"delta\":\"{\\\"a\\\":\\\"0\\\"\"}\n\n", - )), - Ok(Bytes::from( - "data: {\"type\":\"response.function_call_arguments.delta\",\"item_id\":\"fc_1\",\"output_index\":0,\"delta\":\",\\\"b\\\":\\\"100\\\"}\"}\n\n", - )), - Ok(Bytes::from("data: [DONE]\n\n")), - ]); - - let chunks = collect_responses_to_chat_chunks(upstream, context, log.clone()).await; - - assert_eq!(chunks.len(), 6); - - let first = parse_sse_json(&chunks[0]).expect("json"); - assert_eq!(first["choices"][0]["delta"]["role"], json!("assistant")); - - let initial = parse_sse_json(&chunks[1]).expect("json"); - assert_eq!(initial["choices"][0]["delta"]["tool_calls"][0]["id"], json!("call_foo")); - assert_eq!( - initial["choices"][0]["delta"]["tool_calls"][0]["function"]["name"], - json!("getRandomNumber") - ); - assert_eq!( - initial["choices"][0]["delta"]["tool_calls"][0]["function"]["arguments"], - json!("") - ); - - let delta_1 = parse_sse_json(&chunks[2]).expect("json"); - assert_eq!( - delta_1["choices"][0]["delta"]["tool_calls"][0]["function"]["arguments"], - json!("{\"a\":\"0\"") - ); - - let delta_2 = parse_sse_json(&chunks[3]).expect("json"); - assert_eq!( - delta_2["choices"][0]["delta"]["tool_calls"][0]["function"]["arguments"], - json!(",\"b\":\"100\"}") - ); - - let done = parse_sse_json(&chunks[4]).expect("json"); - 
assert_eq!(done["choices"][0]["finish_reason"], json!("tool_calls")); - - assert_eq!(String::from_utf8_lossy(&chunks[5]), "data: [DONE]\n\n"); - }); -} - -#[test] -fn stream_responses_to_chat_emits_content_parts_for_non_text() { - run_async(async { - let (log, context, _sqlite_pool) = setup_responses_stream().await; - - let upstream = futures_util::stream::iter(vec![ - Ok(Bytes::from( - "data: {\"type\":\"response.output_text.delta\",\"delta\":\"Hello\"}\n\n", - )), - Ok(Bytes::from( - "data: {\"type\":\"response.output_item.done\",\"output_index\":0,\"item\":{\"id\":\"msg_1\",\"type\":\"message\",\"status\":\"completed\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"Hello\",\"annotations\":[]},{\"type\":\"output_image\",\"image_url\":{\"url\":\"https://example.com/a.png\"}}]}}\n\n", - )), - Ok(Bytes::from("data: [DONE]\n\n")), - ]); - - let chunks = collect_responses_to_chat_chunks(upstream, context, log.clone()).await; - - assert_eq!(chunks.len(), 5); - - let first = parse_sse_json(&chunks[0]).expect("json"); - assert_eq!(first["choices"][0]["delta"]["role"], json!("assistant")); - - let text_delta = parse_sse_json(&chunks[1]).expect("json"); - assert_eq!(text_delta["choices"][0]["delta"]["content"], json!("Hello")); - - let parts_delta = parse_sse_json(&chunks[2]).expect("json"); - assert_eq!(parts_delta["choices"][0]["delta"]["content"][0]["type"], json!("image_url")); - assert_eq!( - parts_delta["choices"][0]["delta"]["content"][0]["image_url"]["url"], - json!("https://example.com/a.png") - ); - - let done = parse_sse_json(&chunks[3]).expect("json"); - assert_eq!(done["choices"][0]["finish_reason"], json!("stop")); - - assert_eq!(String::from_utf8_lossy(&chunks[4]), "data: [DONE]\n\n"); - }); -} - -#[test] -fn stream_chat_to_responses_handles_chunk_boundaries_and_emits_created_delta_done_and_logs_usage() { - run_async(async { - let sqlite_pool = create_test_sqlite_pool().await; - let log = Arc::new(LogWriter::new(Some(sqlite_pool.clone()))); - let context = LogContext { - path: "/v1/chat/completions".to_string(), - provider: "openai".to_string(), - upstream_id: "unit-test".to_string(), - model: Some("unit-model".to_string()), - mapped_model: Some("unit-model".to_string()), - stream: true, - status: 200, - upstream_request_id: None, - request_headers: None, - request_body: None, - ttfb_ms: None, - start: Instant::now(), - }; - - let first_event = "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n"; - let (first_a, first_b) = first_event.split_at(12); - - let upstream = futures_util::stream::iter(vec![ - Ok::(Bytes::from(first_a.to_string())), - Ok(Bytes::from(first_b.to_string())), - Ok(Bytes::from( - "data: {\"choices\":[{\"delta\":{\"content\":\"!\"}}]}\n\n", - )), - // Chat usage format. 
- Ok(Bytes::from(
- "data: {\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3}}\n\n",
- )),
- Ok(Bytes::from("data: [DONE]\n\n")),
- ]);
-
- let token_tracker = super::super::token_rate::TokenRateTracker::new()
- .register(None, None)
- .await;
- let chunks: Vec<Bytes> =
- stream_chat_to_responses(upstream, context, log.clone(), token_tracker)
- .map(|item| item.expect("stream item"))
- .collect()
- .await;
-
- assert_eq!(chunks.len(), 10);
-
- let created = parse_sse_json(&chunks[0]).expect("json");
- assert_eq!(created["type"], json!("response.created"));
- let response_id = created["response"]["id"].as_str().expect("response.id");
- assert!(response_id.starts_with("resp_"));
-
- let output_item_added = parse_sse_json(&chunks[1]).expect("json");
- assert_eq!(output_item_added["type"], json!("response.output_item.added"));
- assert_eq!(output_item_added["output_index"], json!(0));
- let item_id = output_item_added["item"]["id"].as_str().expect("item.id");
- assert!(item_id.starts_with("msg_"));
-
- let content_part_added = parse_sse_json(&chunks[2]).expect("json");
- assert_eq!(content_part_added["type"], json!("response.content_part.added"));
- assert_eq!(content_part_added["item_id"], json!(item_id));
- assert_eq!(content_part_added["output_index"], json!(0));
- assert_eq!(content_part_added["content_index"], json!(0));
- assert_eq!(content_part_added["part"]["type"], json!("output_text"));
- assert_eq!(content_part_added["part"]["text"], json!(""));
-
- let delta_1 = parse_sse_json(&chunks[3]).expect("json");
- assert_eq!(delta_1["type"], json!("response.output_text.delta"));
- assert_eq!(delta_1["item_id"], json!(item_id));
- assert_eq!(delta_1["delta"], json!("Hi"));
- assert_eq!(delta_1["sequence_number"], json!(3));
-
- let delta_2 = parse_sse_json(&chunks[4]).expect("json");
- assert_eq!(delta_2["type"], json!("response.output_text.delta"));
- assert_eq!(delta_2["item_id"], json!(item_id));
- assert_eq!(delta_2["delta"], json!("!"));
- assert_eq!(delta_2["sequence_number"], json!(4));
-
- let output_text_done = parse_sse_json(&chunks[5]).expect("json");
- assert_eq!(output_text_done["type"], json!("response.output_text.done"));
- assert_eq!(output_text_done["item_id"], json!(item_id));
- assert_eq!(output_text_done["text"], json!("Hi!"));
-
- let content_part_done = parse_sse_json(&chunks[6]).expect("json");
- assert_eq!(content_part_done["type"], json!("response.content_part.done"));
- assert_eq!(content_part_done["item_id"], json!(item_id));
- assert_eq!(content_part_done["part"]["text"], json!("Hi!"));
-
- let output_item_done = parse_sse_json(&chunks[7]).expect("json");
- assert_eq!(output_item_done["type"], json!("response.output_item.done"));
- assert_eq!(output_item_done["output_index"], json!(0));
- assert_eq!(output_item_done["item"]["id"], json!(item_id));
- assert_eq!(output_item_done["item"]["content"][0]["type"], json!("output_text"));
- assert_eq!(output_item_done["item"]["content"][0]["text"], json!("Hi!"));
-
- let completed = parse_sse_json(&chunks[8]).expect("json");
- assert_eq!(completed["type"], json!("response.completed"));
- assert_eq!(completed["response"]["id"], json!(response_id));
- assert_eq!(completed["response"]["output"][0]["id"], json!(item_id));
- assert_eq!(completed["response"]["output"][0]["content"][0]["text"], json!("Hi!"));
- assert_eq!(completed["response"]["usage"]["input_tokens"], json!(1));
- assert_eq!(completed["response"]["usage"]["output_tokens"], json!(2));
- assert_eq!(completed["response"]["usage"]["total_tokens"],
json!(3));
-
- assert_eq!(String::from_utf8_lossy(&chunks[9]), "data: [DONE]\n\n");
-
- let (input_tokens, output_tokens, total_tokens) =
- read_first_usage_tokens(&sqlite_pool).await;
- assert_eq!(input_tokens, Some(1));
- assert_eq!(output_tokens, Some(2));
- assert_eq!(total_tokens, Some(3));
- });
-}
-
-#[test]
-fn stream_chat_to_responses_emits_function_call_events_and_includes_them_in_completed_response() {
- run_async(async {
- let sqlite_pool = create_test_sqlite_pool().await;
- let log = Arc::new(LogWriter::new(Some(sqlite_pool.clone())));
- let context = LogContext {
- path: "/v1/chat/completions".to_string(),
- provider: "openai".to_string(),
- upstream_id: "unit-test".to_string(),
- model: Some("unit-model".to_string()),
- mapped_model: Some("unit-model".to_string()),
- stream: true,
- status: 200,
- upstream_request_id: None,
- request_headers: None,
- request_body: None,
- ttfb_ms: None,
- start: Instant::now(),
- };
-
- let upstream = futures_util::stream::iter(vec![
- Ok::<Bytes, std::io::Error>(Bytes::from(
- "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_foo\",\"type\":\"function\",\"function\":{\"name\":\"getRandomNumber\",\"arguments\":\"{\\\"a\\\":\\\"0\\\"\"}}]}}]}\n\n",
- )),
- Ok(Bytes::from(
- "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\",\\\"b\\\":\\\"100\\\"}\"}}]}}]}\n\n",
- )),
- // Chat usage format.
- Ok(Bytes::from(
- "data: {\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3}}\n\n",
- )),
- Ok(Bytes::from("data: [DONE]\n\n")),
- ]);
-
- let token_tracker = super::super::token_rate::TokenRateTracker::new()
- .register(None, None)
- .await;
- let chunks: Vec<Bytes> =
- stream_chat_to_responses(upstream, context, log.clone(), token_tracker)
- .map(|item| item.expect("stream item"))
- .collect()
- .await;
-
- assert_eq!(chunks.len(), 8);
-
- let created = parse_sse_json(&chunks[0]).expect("json");
- assert_eq!(created["type"], json!("response.created"));
- let response_id = created["response"]["id"].as_str().expect("response.id");
- assert!(response_id.starts_with("resp_"));
-
- let output_item_added = parse_sse_json(&chunks[1]).expect("json");
- assert_eq!(output_item_added["type"], json!("response.output_item.added"));
- assert_eq!(output_item_added["output_index"], json!(0));
- assert_eq!(output_item_added["item"]["type"], json!("function_call"));
- assert_eq!(output_item_added["item"]["call_id"], json!("call_foo"));
- assert_eq!(output_item_added["item"]["name"], json!("getRandomNumber"));
- let item_id = output_item_added["item"]["id"].as_str().expect("item.id");
- assert!(item_id.starts_with("fc_"));
-
- let delta_1 = parse_sse_json(&chunks[2]).expect("json");
- assert_eq!(delta_1["type"], json!("response.function_call_arguments.delta"));
- assert_eq!(delta_1["item_id"], json!(item_id));
- assert_eq!(delta_1["output_index"], json!(0));
- assert_eq!(delta_1["delta"], json!("{\"a\":\"0\""));
-
- let delta_2 = parse_sse_json(&chunks[3]).expect("json");
- assert_eq!(delta_2["type"], json!("response.function_call_arguments.delta"));
- assert_eq!(delta_2["item_id"], json!(item_id));
- assert_eq!(delta_2["output_index"], json!(0));
- assert_eq!(delta_2["delta"], json!(",\"b\":\"100\"}"));
-
- let args_done = parse_sse_json(&chunks[4]).expect("json");
- assert_eq!(args_done["type"], json!("response.function_call_arguments.done"));
- assert_eq!(args_done["item_id"], json!(item_id));
- assert_eq!(args_done["name"], json!("getRandomNumber"));
- assert_eq!(args_done["arguments"],
json!("{\"a\":\"0\",\"b\":\"100\"}")); - - let item_done = parse_sse_json(&chunks[5]).expect("json"); - assert_eq!(item_done["type"], json!("response.output_item.done")); - assert_eq!(item_done["item"]["id"], json!(item_id)); - assert_eq!(item_done["item"]["status"], json!("completed")); - assert_eq!(item_done["item"]["type"], json!("function_call")); - assert_eq!(item_done["item"]["call_id"], json!("call_foo")); - assert_eq!(item_done["item"]["name"], json!("getRandomNumber")); - assert_eq!(item_done["item"]["arguments"], json!("{\"a\":\"0\",\"b\":\"100\"}")); - - let completed = parse_sse_json(&chunks[6]).expect("json"); - assert_eq!(completed["type"], json!("response.completed")); - assert_eq!(completed["response"]["id"], json!(response_id)); - assert_eq!(completed["response"]["output"][0]["type"], json!("function_call")); - assert_eq!(completed["response"]["output"][0]["call_id"], json!("call_foo")); - assert_eq!(completed["response"]["output"][0]["name"], json!("getRandomNumber")); - assert_eq!( - completed["response"]["output"][0]["arguments"], - json!("{\"a\":\"0\",\"b\":\"100\"}") - ); - assert_eq!(completed["response"]["usage"]["input_tokens"], json!(1)); - assert_eq!(completed["response"]["usage"]["output_tokens"], json!(2)); - assert_eq!(completed["response"]["usage"]["total_tokens"], json!(3)); - - assert_eq!(String::from_utf8_lossy(&chunks[7]), "data: [DONE]\n\n"); - - let (input_tokens, output_tokens, total_tokens) = - read_first_usage_tokens(&sqlite_pool).await; - assert_eq!(input_tokens, Some(1)); - assert_eq!(output_tokens, Some(2)); - assert_eq!(total_tokens, Some(3)); - }); -} diff --git a/src-tauri/src/proxy/response/anthropic_to_responses.rs b/src-tauri/src/proxy/response/anthropic_to_responses.rs deleted file mode 100644 index f0a5924..0000000 --- a/src-tauri/src/proxy/response/anthropic_to_responses.rs +++ /dev/null @@ -1,572 +0,0 @@ -use axum::body::Bytes; -use futures_util::StreamExt; -use serde_json::{json, Value}; -use std::{collections::HashMap, collections::VecDeque, sync::Arc}; - -use super::super::log::{build_log_entry, LogContext, LogWriter}; -use super::super::sse::SseEventParser; -use super::super::token_rate::RequestTokenTracker; -use super::super::usage::SseUsageCollector; -use format::{snapshot_to_output_item, usage_to_value, OutputItemSnapshot}; - -mod format; - -pub(super) fn stream_anthropic_to_responses( - upstream: impl futures_util::stream::Stream> - + Unpin - + Send - + 'static, - context: LogContext, - log: Arc, - token_tracker: RequestTokenTracker, -) -> impl futures_util::stream::Stream> + Send -where - E: std::error::Error + Send + Sync + 'static, -{ - let state = AnthropicToResponsesState::new(upstream, context, log, token_tracker); - futures_util::stream::try_unfold(state, |state| async move { state.step().await }) -} - -struct MessageOutput { - id: String, - output_index: u64, - text: String, -} - -struct FunctionCallOutput { - id: String, - output_index: u64, - call_id: String, - name: String, - arguments: String, -} - -struct AnthropicToResponsesState { - upstream: S, - parser: SseEventParser, - collector: SseUsageCollector, - log: Arc, - context: LogContext, - token_tracker: RequestTokenTracker, - out: VecDeque, - id_seed: u64, - response_id: String, - created_at: i64, - model: String, - next_output_index: u64, - message: Option, - function_calls: Vec>, - // Claude stream uses block index; map it to our function_call slot. 
- tool_call_by_block_index: HashMap, - sequence: u64, - sent_done: bool, - logged: bool, - upstream_ended: bool, -} - -impl AnthropicToResponsesState -where - S: futures_util::stream::Stream> + Unpin + Send + 'static, - E: std::error::Error + Send + Sync + 'static, -{ - fn new( - upstream: S, - context: LogContext, - log: Arc, - token_tracker: RequestTokenTracker, - ) -> Self { - let now_ms = super::now_ms(); - let created_at = (now_ms / 1000) as i64; - let model = context - .model - .clone() - .unwrap_or_else(|| "unknown".to_string()); - - let mut state = Self { - upstream, - parser: SseEventParser::new(), - collector: SseUsageCollector::new(), - log, - context, - token_tracker, - out: VecDeque::new(), - id_seed: now_ms, - response_id: format!("resp_{now_ms}"), - created_at, - model, - next_output_index: 0, - message: None, - function_calls: Vec::new(), - tool_call_by_block_index: HashMap::new(), - sequence: 0, - sent_done: false, - logged: false, - upstream_ended: false, - }; - state.push_response_created(); - state - } - - async fn step(mut self) -> Result, std::io::Error> { - loop { - if let Some(next) = self.out.pop_front() { - return Ok(Some((next, self))); - } - - if self.upstream_ended { - return Ok(None); - } - - match self.upstream.next().await { - Some(Ok(chunk)) => { - self.collector.push_chunk(&chunk); - let mut events = Vec::new(); - self.parser.push_chunk(&chunk, |data| events.push(data)); - let mut texts = Vec::new(); - for data in events { - self.handle_event(&data, &mut texts); - } - for text in texts { - self.token_tracker.add_output_text(&text).await; - } - } - Some(Err(err)) => { - self.log_usage_once(); - return Err(std::io::Error::new(std::io::ErrorKind::Other, err)); - } - None => { - self.upstream_ended = true; - let mut events = Vec::new(); - self.parser.finish(|data| events.push(data)); - let mut texts = Vec::new(); - for data in events { - self.handle_event(&data, &mut texts); - } - for text in texts { - self.token_tracker.add_output_text(&text).await; - } - if !self.sent_done { - self.push_done(); - } - self.log_usage_once(); - if self.out.is_empty() { - return Ok(None); - } - } - } - } - } - - fn handle_event(&mut self, data: &str, token_texts: &mut Vec) { - if self.sent_done { - return; - } - // Claude stream may include event: lines; parser only yields data: payload. - let Ok(value) = serde_json::from_str::(data) else { - return; - }; - let Some(event_type) = value.get("type").and_then(Value::as_str) else { - return; - }; - - match event_type { - "message_start" => { - // Preserve the original requested model alias if present (consistent with other - // format conversions); only fall back to upstream model when we have no hint. 
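- // "unknown" is the placeholder set in new() when the request carried no model hint.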
- if self.model == "unknown" { - if let Some(model) = value - .get("message") - .and_then(|m| m.get("model")) - .and_then(Value::as_str) - { - if !model.is_empty() { - self.model = model.to_string(); - } - } - } - } - "content_block_start" => self.handle_content_block_start(&value), - "content_block_delta" => self.handle_content_block_delta(&value, token_texts), - "message_stop" => { - self.push_done(); - } - _ => {} - } - } - - fn handle_content_block_start(&mut self, value: &Value) { - let index = value.get("index").and_then(Value::as_u64).unwrap_or(0) as usize; - let Some(block) = value.get("content_block").and_then(Value::as_object) else { - return; - }; - let block_type = block.get("type").and_then(Value::as_str).unwrap_or(""); - match block_type { - "text" => { - self.ensure_message_output(); - } - "tool_use" => { - let call_id = block.get("id").and_then(Value::as_str).unwrap_or(""); - let name = block.get("name").and_then(Value::as_str).unwrap_or(""); - let tool_index = self.ensure_function_call_output(index, Some(call_id), Some(name)); - self.tool_call_by_block_index.insert(index, tool_index); - } - _ => {} - } - } - - fn handle_content_block_delta(&mut self, value: &Value, token_texts: &mut Vec) { - let index = value.get("index").and_then(Value::as_u64).unwrap_or(0) as usize; - let Some(delta) = value.get("delta").and_then(Value::as_object) else { - return; - }; - let delta_type = delta.get("type").and_then(Value::as_str).unwrap_or(""); - match delta_type { - "text_delta" => { - let Some(text) = delta.get("text").and_then(Value::as_str) else { - return; - }; - self.ensure_message_output(); - let (item_id, output_index) = { - let message = self.message.as_mut().expect("message output exists"); - message.text.push_str(text); - (message.id.clone(), message.output_index) - }; - token_texts.push(text.to_string()); - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.output_text.delta", - "item_id": item_id, - "output_index": output_index, - "content_index": 0, - "delta": text, - "sequence_number": sequence_number - }))); - } - "input_json_delta" => { - let Some(partial_json) = delta.get("partial_json").and_then(Value::as_str) else { - return; - }; - let call_index = match self.tool_call_by_block_index.get(&index) { - Some(idx) => *idx, - None => { - let tool_index = self.ensure_function_call_output(index, None, None); - self.tool_call_by_block_index.insert(index, tool_index); - tool_index - } - }; - let (item_id, output_index) = { - let state = self - .function_calls - .get_mut(call_index) - .and_then(Option::as_mut) - .expect("call output exists"); - state.arguments.push_str(partial_json); - (state.id.clone(), state.output_index) - }; - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.function_call_arguments.delta", - "item_id": item_id, - "output_index": output_index, - "delta": partial_json, - "sequence_number": sequence_number - }))); - } - _ => {} - } - } - - fn ensure_message_output(&mut self) { - if self.message.is_some() { - return; - } - let output_index = self.next_output_index; - self.next_output_index += 1; - let message_id = format!("msg_{}", self.id_seed); - self.push_message_item_added(&message_id, output_index); - self.push_message_content_part_added(&message_id, output_index); - self.message = Some(MessageOutput { - id: message_id, - output_index, - text: String::new(), - }); - } - - fn 
ensure_function_call_output(&mut self, block_index: usize, call_id: Option<&str>, name: Option<&str>) -> usize { - // Allocate one function_call per Claude content block index. - let call_index = self.function_calls.len(); - let output_index = self.next_output_index; - self.next_output_index += 1; - - let item_id = format!("fc_{}_{}", self.id_seed, block_index); - let call_id = call_id - .map(|v| v.to_string()) - .unwrap_or_else(|| format!("call_{}_{}", self.id_seed, block_index)); - let name = name.unwrap_or("").to_string(); - - self.push_function_call_item_added(&item_id, output_index, &call_id, &name); - self.function_calls.push(Some(FunctionCallOutput { - id: item_id, - output_index, - call_id, - name, - arguments: String::new(), - })); - call_index - } - - fn push_response_created(&mut self) { - let response = self.build_response_object("in_progress", Vec::new(), None, None); - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.created", - "response": response, - "sequence_number": sequence_number - }))); - } - - fn push_message_item_added(&mut self, item_id: &str, output_index: u64) { - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.output_item.added", - "output_index": output_index, - "item": { - "id": item_id, - "type": "message", - "status": "in_progress", - "role": "assistant", - "content": [] - }, - "sequence_number": sequence_number - }))); - } - - fn push_message_content_part_added(&mut self, item_id: &str, output_index: u64) { - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.content_part.added", - "item_id": item_id, - "output_index": output_index, - "content_index": 0, - "part": { - "type": "output_text", - "text": "", - "annotations": [] - }, - "sequence_number": sequence_number - }))); - } - - fn push_function_call_item_added( - &mut self, - item_id: &str, - output_index: u64, - call_id: &str, - name: &str, - ) { - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.output_item.added", - "output_index": output_index, - "item": { - "id": item_id, - "type": "function_call", - "status": "in_progress", - "arguments": "", - "call_id": call_id, - "name": name - }, - "sequence_number": sequence_number - }))); - } - - fn push_done(&mut self) { - if self.sent_done { - return; - } - self.sent_done = true; - - let completed_at = (super::now_ms() / 1000) as i64; - let usage_snapshot = self.collector.finish(); - let usage = usage_snapshot - .usage - .clone() - .map(|usage| usage_to_value(usage, usage_snapshot.cached_tokens)); - - let mut snapshots = Vec::new(); - if let Some(message) = &self.message { - snapshots.push(OutputItemSnapshot::Message { - id: message.id.clone(), - output_index: message.output_index, - text: message.text.clone(), - }); - } - for call in &self.function_calls { - let Some(call) = call else { - continue; - }; - snapshots.push(OutputItemSnapshot::FunctionCall { - id: call.id.clone(), - output_index: call.output_index, - call_id: call.call_id.clone(), - name: call.name.clone(), - arguments: call.arguments.clone(), - }); - } - snapshots.sort_by_key(|item| match item { - OutputItemSnapshot::Message { output_index, .. } => *output_index, - OutputItemSnapshot::FunctionCall { output_index, .. 
} => *output_index, - }); - - let output = snapshots - .iter() - .map(snapshot_to_output_item) - .collect::>(); - for snapshot in &snapshots { - self.push_item_done_events(snapshot); - } - - let response = self.build_response_object("completed", output, usage, Some(completed_at)); - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.completed", - "response": response, - "sequence_number": sequence_number - }))); - self.out.push_back(Bytes::from("data: [DONE]\n\n")); - } - - fn push_item_done_events(&mut self, snapshot: &OutputItemSnapshot) { - match snapshot { - OutputItemSnapshot::Message { - id, - output_index, - text, - } => self.push_message_done_events(id, *output_index, text), - OutputItemSnapshot::FunctionCall { - id, - output_index, - call_id, - name, - arguments, - } => self.push_function_call_done_events(id, *output_index, call_id, name, arguments), - } - } - - fn push_message_done_events(&mut self, item_id: &str, output_index: u64, text: &str) { - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.output_text.done", - "item_id": item_id, - "output_index": output_index, - "content_index": 0, - "text": text, - "sequence_number": sequence_number - }))); - - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.content_part.done", - "item_id": item_id, - "output_index": output_index, - "content_index": 0, - "part": { - "type": "output_text", - "text": text, - "annotations": [] - }, - "sequence_number": sequence_number - }))); - - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.output_item.done", - "output_index": output_index, - "item": { - "id": item_id, - "type": "message", - "status": "completed", - "role": "assistant", - "content": [ - { "type": "output_text", "text": text, "annotations": [] } - ] - }, - "sequence_number": sequence_number - }))); - } - - fn push_function_call_done_events( - &mut self, - item_id: &str, - output_index: u64, - call_id: &str, - name: &str, - arguments: &str, - ) { - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.function_call_arguments.done", - "item_id": item_id, - "output_index": output_index, - "arguments": arguments, - "sequence_number": sequence_number, - "name": name - }))); - - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.output_item.done", - "output_index": output_index, - "item": { - "id": item_id, - "type": "function_call", - "status": "completed", - "arguments": arguments, - "call_id": call_id, - "name": name - }, - "sequence_number": sequence_number - }))); - } - - fn build_response_object( - &self, - status: &str, - output: Vec, - usage: Option, - completed_at: Option, - ) -> Value { - json!({ - "id": self.response_id.as_str(), - "object": "response", - "created_at": self.created_at, - "model": self.model.as_str(), - "status": status, - "output": output, - "parallel_tool_calls": self.parallel_tool_calls(), - "completed_at": completed_at, - "usage": usage, - "error": null, - "metadata": {} - }) - } - - fn parallel_tool_calls(&self) -> bool { - self.function_calls.iter().filter(|call| call.is_some()).count() > 1 - } - - fn log_usage_once(&mut self) { - if self.logged { - return; - 
} - self.logged = true; - let entry = build_log_entry(&self.context, self.collector.finish(), None); - self.log.clone().write_detached(entry); - } - - fn next_sequence_number(&mut self) -> u64 { - let current = self.sequence; - self.sequence += 1; - current - } -} diff --git a/src-tauri/src/proxy/response/anthropic_to_responses/format.rs b/src-tauri/src/proxy/response/anthropic_to_responses/format.rs deleted file mode 100644 index 5b979e7..0000000 --- a/src-tauri/src/proxy/response/anthropic_to_responses/format.rs +++ /dev/null @@ -1,65 +0,0 @@ -use serde_json::{json, Value}; - -use super::super::super::log::TokenUsage; - -pub(super) enum OutputItemSnapshot { - Message { - id: String, - output_index: u64, - text: String, - }, - FunctionCall { - id: String, - output_index: u64, - call_id: String, - name: String, - arguments: String, - }, -} - -pub(super) fn usage_to_value(usage: TokenUsage, cached_tokens: Option) -> Value { - let input_tokens = usage.input_tokens.unwrap_or(0); - let output_tokens = usage.output_tokens.unwrap_or(0); - let total_tokens = usage - .total_tokens - .or_else(|| input_tokens.checked_add(output_tokens)) - .unwrap_or(0); - let cached_tokens = cached_tokens.unwrap_or(0); - - json!({ - "input_tokens": input_tokens, - "input_tokens_details": { "cached_tokens": cached_tokens }, - "output_tokens": output_tokens, - "output_tokens_details": { "reasoning_tokens": 0 }, - "total_tokens": total_tokens - }) -} - -pub(super) fn snapshot_to_output_item(snapshot: &OutputItemSnapshot) -> Value { - match snapshot { - OutputItemSnapshot::Message { id, text, .. } => json!({ - "id": id, - "type": "message", - "status": "completed", - "role": "assistant", - "content": [ - { "type": "output_text", "text": text, "annotations": [] } - ] - }), - OutputItemSnapshot::FunctionCall { - id, - call_id, - name, - arguments, - .. 
- } => json!({ - "id": id, - "type": "function_call", - "status": "completed", - "call_id": call_id, - "name": name, - "arguments": arguments - }), - } -} - diff --git a/src-tauri/src/proxy/response/chat_to_responses.rs b/src-tauri/src/proxy/response/chat_to_responses.rs deleted file mode 100644 index 1b71082..0000000 --- a/src-tauri/src/proxy/response/chat_to_responses.rs +++ /dev/null @@ -1,593 +0,0 @@ -use axum::body::Bytes; -use futures_util::StreamExt; -use serde_json::{json, Value}; -use std::{collections::VecDeque, sync::Arc}; - -use super::super::log::{build_log_entry, LogContext, LogWriter}; -use super::super::sse::SseEventParser; -use super::super::token_rate::RequestTokenTracker; -use super::super::usage::SseUsageCollector; -use format::{snapshot_to_output_item, usage_to_value, OutputItemSnapshot}; - -mod format; - -pub(super) fn stream_chat_to_responses( - upstream: impl futures_util::stream::Stream> - + Unpin - + Send - + 'static, - context: LogContext, - log: Arc, - token_tracker: RequestTokenTracker, -) -> impl futures_util::stream::Stream> + Send -where - E: std::error::Error + Send + Sync + 'static, -{ - let state = ChatToResponsesState::new(upstream, context, log, token_tracker); - futures_util::stream::try_unfold(state, |state| async move { state.step().await }) -} - -struct MessageOutput { - id: String, - output_index: u64, - text: String, -} - -struct FunctionCallOutput { - id: String, - output_index: u64, - call_id: String, - name: String, - arguments: String, -} - -struct ChatToResponsesState { - upstream: S, - parser: SseEventParser, - collector: SseUsageCollector, - log: Arc, - context: LogContext, - token_tracker: RequestTokenTracker, - out: VecDeque, - id_seed: u64, - response_id: String, - created_at: i64, - model: String, - next_output_index: u64, - message: Option, - function_calls: Vec>, - sequence: u64, - sent_done: bool, - logged: bool, - upstream_ended: bool, -} - -impl ChatToResponsesState -where - S: futures_util::stream::Stream> + Unpin + Send + 'static, - E: std::error::Error + Send + Sync + 'static, -{ - fn new( - upstream: S, - context: LogContext, - log: Arc, - token_tracker: RequestTokenTracker, - ) -> Self { - let now_ms = super::now_ms(); - let created_at = (now_ms / 1000) as i64; - let model = context - .model - .clone() - .unwrap_or_else(|| "unknown".to_string()); - - let mut state = Self { - upstream, - parser: SseEventParser::new(), - collector: SseUsageCollector::new(), - log, - context, - token_tracker, - out: VecDeque::new(), - id_seed: now_ms, - response_id: format!("resp_{now_ms}"), - created_at, - model, - next_output_index: 0, - message: None, - function_calls: Vec::new(), - sequence: 0, - sent_done: false, - logged: false, - upstream_ended: false, - }; - state.push_response_created(); - state - } - - async fn step(mut self) -> Result, std::io::Error> { - loop { - if let Some(next) = self.out.pop_front() { - return Ok(Some((next, self))); - } - - if self.upstream_ended { - return Ok(None); - } - - match self.upstream.next().await { - Some(Ok(chunk)) => { - if self.context.ttfb_ms.is_none() { - self.context.ttfb_ms = Some(self.context.start.elapsed().as_millis()); - } - self.collector.push_chunk(&chunk); - let mut events = Vec::new(); - self.parser.push_chunk(&chunk, |data| events.push(data)); - let mut texts = Vec::new(); - for data in events { - self.handle_event(&data, &mut texts); - } - for text in texts { - self.token_tracker.add_output_text(&text).await; - } - } - Some(Err(err)) => { - self.log_usage_once(); - return 
Err(std::io::Error::new(std::io::ErrorKind::Other, err)); - } - None => { - self.upstream_ended = true; - let mut events = Vec::new(); - self.parser.finish(|data| events.push(data)); - let mut texts = Vec::new(); - for data in events { - self.handle_event(&data, &mut texts); - } - for text in texts { - self.token_tracker.add_output_text(&text).await; - } - if !self.sent_done { - self.push_done(); - } - self.log_usage_once(); - if self.out.is_empty() { - return Ok(None); - } - } - } - } - } - - fn handle_event(&mut self, data: &str, token_texts: &mut Vec) { - if self.sent_done { - return; - } - if data == "[DONE]" { - self.push_done(); - return; - } - let Ok(value) = serde_json::from_str::(data) else { - return; - }; - - let Some(delta) = value - .get("choices") - .and_then(Value::as_array) - .and_then(|choices| choices.first()) - .and_then(|choice| choice.get("delta")) - else { - return; - }; - - if let Some(content) = delta.get("content").and_then(Value::as_str) { - self.handle_text_delta(content, token_texts); - } - if let Some(tool_calls) = delta.get("tool_calls").and_then(Value::as_array) { - for tool_call in tool_calls { - self.handle_tool_call_delta(tool_call); - } - } - if let Some(function_call) = delta.get("function_call") { - self.handle_legacy_function_call_delta(function_call); - } - } - - fn handle_text_delta(&mut self, delta: &str, token_texts: &mut Vec) { - self.ensure_message_output(); - let (item_id, output_index) = { - let message = self.message.as_mut().expect("message output must exist"); - message.text.push_str(delta); - (message.id.clone(), message.output_index) - }; - token_texts.push(delta.to_string()); - - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.output_text.delta", - "item_id": item_id.as_str(), - "output_index": output_index, - "content_index": 0, - "delta": delta, - "sequence_number": sequence_number - }))); - } - - fn handle_tool_call_delta(&mut self, tool_call: &Value) { - let Some(tool_call) = tool_call.as_object() else { - return; - }; - let call_index = tool_call.get("index").and_then(Value::as_u64).unwrap_or(0) as usize; - - let call_id = tool_call.get("id").and_then(Value::as_str); - let function = tool_call.get("function").and_then(Value::as_object); - let name = function - .and_then(|function| function.get("name")) - .and_then(Value::as_str); - let arguments_delta = function - .and_then(|function| function.get("arguments")) - .and_then(Value::as_str); - - if let Some(arguments_delta) = arguments_delta { - let (item_id, output_index) = { - let state = self.ensure_function_call_output(call_index, call_id, name); - state.arguments.push_str(arguments_delta); - (state.id.clone(), state.output_index) - }; - self.push_function_call_arguments_delta(&item_id, output_index, arguments_delta); - } else { - self.ensure_function_call_output(call_index, call_id, name); - } - } - - fn handle_legacy_function_call_delta(&mut self, function_call: &Value) { - let Some(function_call) = function_call.as_object() else { - return; - }; - let name = function_call.get("name").and_then(Value::as_str); - let arguments_delta = function_call.get("arguments").and_then(Value::as_str); - - if let Some(arguments_delta) = arguments_delta { - let (item_id, output_index) = { - let state = self.ensure_function_call_output(0, None, name); - state.arguments.push_str(arguments_delta); - (state.id.clone(), state.output_index) - }; - self.push_function_call_arguments_delta(&item_id, output_index, 
arguments_delta); - } else { - self.ensure_function_call_output(0, None, name); - } - } - - fn ensure_message_output(&mut self) { - if self.message.is_none() { - let output_index = self.next_output_index; - self.next_output_index += 1; - let message_id = format!("msg_{}", self.id_seed); - self.push_message_item_added(&message_id, output_index); - self.push_message_content_part_added(&message_id, output_index); - self.message = Some(MessageOutput { - id: message_id, - output_index, - text: String::new(), - }); - } - } - - fn ensure_function_call_output( - &mut self, - call_index: usize, - call_id: Option<&str>, - name: Option<&str>, - ) -> &mut FunctionCallOutput { - if self.function_calls.len() <= call_index { - self.function_calls.resize_with(call_index + 1, || None); - } - - if self.function_calls[call_index].is_none() { - let output_index = self.next_output_index; - self.next_output_index += 1; - let item_id = format!("fc_{}_{}", self.id_seed, call_index); - let call_id = call_id - .map(|value| value.to_string()) - .unwrap_or_else(|| format!("call_{}_{}", self.id_seed, call_index)); - let name = name.unwrap_or("").to_string(); - - self.push_function_call_item_added(&item_id, output_index, &call_id, &name); - self.function_calls[call_index] = Some(FunctionCallOutput { - id: item_id, - output_index, - call_id, - name, - arguments: String::new(), - }); - } else { - let state = self.function_calls[call_index] - .as_mut() - .expect("call output must exist"); - if let Some(call_id) = call_id { - if state.call_id.is_empty() { - state.call_id = call_id.to_string(); - } - } - if let Some(name) = name { - if state.name.is_empty() { - state.name = name.to_string(); - } - } - } - - self.function_calls[call_index] - .as_mut() - .expect("call output must exist") - } - - fn push_response_created(&mut self) { - let response = self.build_response_object("in_progress", Vec::new(), None, None); - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.created", - "response": response, - "sequence_number": sequence_number - }))); - } - - fn push_message_item_added(&mut self, item_id: &str, output_index: u64) { - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.output_item.added", - "output_index": output_index, - "item": { - "id": item_id, - "type": "message", - "status": "in_progress", - "role": "assistant", - "content": [] - }, - "sequence_number": sequence_number - }))); - } - - fn push_message_content_part_added(&mut self, item_id: &str, output_index: u64) { - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.content_part.added", - "item_id": item_id, - "output_index": output_index, - "content_index": 0, - "part": { - "type": "output_text", - "text": "", - "annotations": [] - }, - "sequence_number": sequence_number - }))); - } - - fn push_function_call_item_added( - &mut self, - item_id: &str, - output_index: u64, - call_id: &str, - name: &str, - ) { - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.output_item.added", - "output_index": output_index, - "item": { - "id": item_id, - "type": "function_call", - "status": "in_progress", - "arguments": "", - "call_id": call_id, - "name": name - }, - "sequence_number": sequence_number - }))); - } - - fn push_function_call_arguments_delta( - &mut self, - 
item_id: &str, - output_index: u64, - delta: &str, - ) { - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.function_call_arguments.delta", - "item_id": item_id, - "output_index": output_index, - "delta": delta, - "sequence_number": sequence_number - }))); - } - - fn push_done(&mut self) { - if self.sent_done { - return; - } - self.sent_done = true; - - let completed_at = (super::now_ms() / 1000) as i64; - let usage = self.collector.finish().usage.map(usage_to_value); - - let mut snapshots = Vec::new(); - if let Some(message) = &self.message { - snapshots.push(OutputItemSnapshot::Message { - id: message.id.clone(), - output_index: message.output_index, - text: message.text.clone(), - }); - } - for call in &self.function_calls { - let Some(call) = call else { - continue; - }; - snapshots.push(OutputItemSnapshot::FunctionCall { - id: call.id.clone(), - output_index: call.output_index, - call_id: call.call_id.clone(), - name: call.name.clone(), - arguments: call.arguments.clone(), - }); - } - snapshots.sort_by_key(|item| match item { - OutputItemSnapshot::Message { output_index, .. } => *output_index, - OutputItemSnapshot::FunctionCall { output_index, .. } => *output_index, - }); - - let output = snapshots - .iter() - .map(snapshot_to_output_item) - .collect::>(); - for snapshot in &snapshots { - self.push_item_done_events(snapshot); - } - - let response = self.build_response_object("completed", output, usage, Some(completed_at)); - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.completed", - "response": response, - "sequence_number": sequence_number - }))); - self.out.push_back(Bytes::from("data: [DONE]\n\n")); - } - - fn push_item_done_events(&mut self, snapshot: &OutputItemSnapshot) { - match snapshot { - OutputItemSnapshot::Message { - id, - output_index, - text, - } => self.push_message_done_events(id, *output_index, text), - OutputItemSnapshot::FunctionCall { - id, - output_index, - call_id, - name, - arguments, - } => self.push_function_call_done_events(id, *output_index, call_id, name, arguments), - } - } - - fn push_message_done_events(&mut self, item_id: &str, output_index: u64, text: &str) { - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.output_text.done", - "item_id": item_id, - "output_index": output_index, - "content_index": 0, - "text": text, - "sequence_number": sequence_number - }))); - - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.content_part.done", - "item_id": item_id, - "output_index": output_index, - "content_index": 0, - "part": { - "type": "output_text", - "text": text, - "annotations": [] - }, - "sequence_number": sequence_number - }))); - - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.output_item.done", - "output_index": output_index, - "item": { - "id": item_id, - "type": "message", - "status": "completed", - "role": "assistant", - "content": [ - { - "type": "output_text", - "text": text, - "annotations": [] - } - ] - }, - "sequence_number": sequence_number - }))); - } - - fn push_function_call_done_events( - &mut self, - item_id: &str, - output_index: u64, - call_id: &str, - name: &str, - arguments: &str, - ) { - let sequence_number = self.next_sequence_number(); - 
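- // arguments.done goes out before output_item.done, so the final argument string is complete first.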
self.out.push_back(super::responses_event_sse(json!({ - "type": "response.function_call_arguments.done", - "item_id": item_id, - "output_index": output_index, - "name": name, - "arguments": arguments, - "sequence_number": sequence_number - }))); - - let sequence_number = self.next_sequence_number(); - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.output_item.done", - "output_index": output_index, - "item": { - "id": item_id, - "type": "function_call", - "status": "completed", - "call_id": call_id, - "name": name, - "arguments": arguments - }, - "sequence_number": sequence_number - }))); - } - - fn build_response_object( - &self, - status: &str, - output: Vec, - usage: Option, - completed_at: Option, - ) -> Value { - json!({ - "id": self.response_id.as_str(), - "object": "response", - "created_at": self.created_at, - "model": self.model.as_str(), - "status": status, - "output": output, - "parallel_tool_calls": self.parallel_tool_calls(), - "completed_at": completed_at, - "usage": usage, - "error": null, - "metadata": {} - }) - } - - fn parallel_tool_calls(&self) -> bool { - self.function_calls.iter().filter(|call| call.is_some()).count() > 1 - } - - fn log_usage_once(&mut self) { - if self.logged { - return; - } - self.logged = true; - let entry = build_log_entry(&self.context, self.collector.finish(), None); - self.log.clone().write_detached(entry); - } - - fn next_sequence_number(&mut self) -> u64 { - let current = self.sequence; - self.sequence += 1; - current - } -} diff --git a/src-tauri/src/proxy/response/chat_to_responses/format.rs b/src-tauri/src/proxy/response/chat_to_responses/format.rs deleted file mode 100644 index 1afb8ee..0000000 --- a/src-tauri/src/proxy/response/chat_to_responses/format.rs +++ /dev/null @@ -1,64 +0,0 @@ -use serde_json::{json, Value}; - -use super::super::super::log::TokenUsage; - -pub(super) enum OutputItemSnapshot { - Message { - id: String, - output_index: u64, - text: String, - }, - FunctionCall { - id: String, - output_index: u64, - call_id: String, - name: String, - arguments: String, - }, -} - -pub(super) fn usage_to_value(usage: TokenUsage) -> Value { - let input_tokens = usage.input_tokens.unwrap_or(0); - let output_tokens = usage.output_tokens.unwrap_or(0); - let total_tokens = usage - .total_tokens - .or_else(|| input_tokens.checked_add(output_tokens)) - .unwrap_or(0); - - json!({ - "input_tokens": input_tokens, - "input_tokens_details": { "cached_tokens": 0 }, - "output_tokens": output_tokens, - "output_tokens_details": { "reasoning_tokens": 0 }, - "total_tokens": total_tokens - }) -} - -pub(super) fn snapshot_to_output_item(snapshot: &OutputItemSnapshot) -> Value { - match snapshot { - OutputItemSnapshot::Message { id, text, .. } => json!({ - "id": id, - "type": "message", - "status": "completed", - "role": "assistant", - "content": [ - { "type": "output_text", "text": text, "annotations": [] } - ] - }), - OutputItemSnapshot::FunctionCall { - id, - call_id, - name, - arguments, - .. 
- } => json!({ - "id": id, - "type": "function_call", - "status": "completed", - "call_id": call_id, - "name": name, - "arguments": arguments - }), - } -} - diff --git a/src-tauri/src/proxy/response/dispatch/buffered.rs b/src-tauri/src/proxy/response/dispatch/buffered.rs deleted file mode 100644 index 4a681c8..0000000 --- a/src-tauri/src/proxy/response/dispatch/buffered.rs +++ /dev/null @@ -1,421 +0,0 @@ -use axum::{ - body::{Body, Bytes}, - http::{HeaderMap, StatusCode}, - response::Response, -}; -use std::sync::Arc; - -use super::super::{ - kiro_to_anthropic, kiro_to_responses, token_count, upstream_read, upstream_stream, - PROVIDER_ANTIGRAVITY, PROVIDER_GEMINI, RESPONSE_ERROR_LIMIT_BYTES, -}; -use super::super::super::{ - antigravity_compat, codex_compat, - http, - log::{build_log_entry, LogContext, LogWriter, UsageSnapshot}, - model, - openai_compat::{transform_response_body, FormatTransform}, - request_body::ReplayableBody, - redact::redact_query_param_value, - server_helpers::log_debug_headers_body, - token_rate::RequestTokenTracker, - usage::extract_usage_from_response, - UPSTREAM_NO_DATA_TIMEOUT, -}; - -const DEBUG_BODY_LOG_LIMIT_BYTES: usize = usize::MAX; - -pub(super) async fn build_buffered_response( - status: StatusCode, - upstream_res: reqwest::Response, - headers: HeaderMap, - context: LogContext, - log: Arc, - request_tracker: RequestTokenTracker, - response_transform: FormatTransform, - model_override: Option<&str>, - estimated_input_tokens: Option, -) -> Response { - let mut context = context; - let response_headers = upstream_res.headers().clone(); - let bytes = match read_upstream_bytes(upstream_res, &mut context, &log).await { - Ok(bytes) => bytes, - Err(response) => return response, - }; - log_debug_headers_body( - "upstream.response.raw", - Some(&response_headers), - Some(&ReplayableBody::from_bytes(bytes.clone())), - DEBUG_BODY_LOG_LIMIT_BYTES, - ) - .await; - let bytes = if context.provider == PROVIDER_ANTIGRAVITY { - match antigravity_compat::unwrap_response(&bytes) { - Ok(unwrapped) => unwrapped, - Err(message) => { - return http::error_response(StatusCode::BAD_GATEWAY, message); - } - } - } else { - bytes - }; - if context.provider == PROVIDER_ANTIGRAVITY { - log_debug_headers_body( - "upstream.response.unwrapped", - Some(&response_headers), - Some(&ReplayableBody::from_bytes(bytes.clone())), - DEBUG_BODY_LOG_LIMIT_BYTES, - ) - .await; - } - let mut usage = extract_usage_from_response(&bytes); - let response_error = response_error_for_status(status, &bytes); - let request_body = context.request_body.clone(); - let output = if status.is_success() { - match convert_success_body( - response_transform, - &bytes, - &mut context, - usage, - log.clone(), - estimated_input_tokens, - request_body.as_deref(), - ) { - Ok(converted) => { - usage = converted.usage; - converted.output - } - Err(response) => return response, - } - } else { - bytes - }; - - let entry = build_log_entry(&context, usage, response_error); - log.clone().write_detached(entry); - - let output = maybe_override_response_model(output, model_override); - log_debug_headers_body( - "outbound.response", - Some(&headers), - Some(&ReplayableBody::from_bytes(output.clone())), - DEBUG_BODY_LOG_LIMIT_BYTES, - ) - .await; - let provider_for_tokens = provider_for_tokens(response_transform, context.provider.as_str()); - token_count::apply_output_tokens_from_response(&request_tracker, provider_for_tokens, &output).await; - - http::build_response(status, headers, Body::from(output)) -} - -struct ConvertedBody { - 
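- // Body after format conversion, plus the usage snapshot resolved for it.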
output: Bytes, - usage: UsageSnapshot, -} - -fn convert_success_body( - transform: FormatTransform, - bytes: &Bytes, - context: &mut LogContext, - usage: UsageSnapshot, - log: Arc, - estimated_input_tokens: Option, - request_body: Option<&str>, -) -> Result { - match transform { - FormatTransform::KiroToResponses => convert_kiro_to_responses_body( - bytes, - context, - usage, - log, - estimated_input_tokens, - ), - FormatTransform::KiroToChat => convert_kiro_to_chat_body( - bytes, - context, - usage, - log, - estimated_input_tokens, - ), - FormatTransform::KiroToAnthropic => convert_kiro_to_anthropic_body( - bytes, - context, - usage, - log, - estimated_input_tokens, - ), - FormatTransform::CodexToChat => { - convert_codex_to_chat_body(bytes, context, usage, log, request_body) - } - FormatTransform::CodexToResponses => { - convert_codex_to_responses_body(bytes, context, usage, log, request_body) - } - _ if transform != FormatTransform::None => { - convert_generic_body(transform, bytes, context, usage, log) - } - _ => Ok(ConvertedBody { - output: bytes.clone(), - usage, - }), - } -} - -fn convert_kiro_to_responses_body( - bytes: &Bytes, - context: &mut LogContext, - usage: UsageSnapshot, - log: Arc, - estimated_input_tokens: Option, -) -> Result { - let converted = match kiro_to_responses::convert_kiro_response( - bytes, - context.model.as_deref(), - estimated_input_tokens, - ) { - Ok(converted) => converted, - Err(message) => { - return Err(respond_transform_error(context, usage, log, message)); - } - }; - let usage = resolve_kiro_usage( - bytes, - &converted, - context.model.as_deref(), - estimated_input_tokens, - ); - Ok(ConvertedBody { - output: converted, - usage, - }) -} - -fn convert_kiro_to_chat_body( - bytes: &Bytes, - context: &mut LogContext, - usage: UsageSnapshot, - log: Arc, - estimated_input_tokens: Option, -) -> Result { - let responses = match kiro_to_responses::convert_kiro_response( - bytes, - context.model.as_deref(), - estimated_input_tokens, - ) { - Ok(converted) => converted, - Err(message) => { - return Err(respond_transform_error(context, usage, log, message)); - } - }; - let usage = resolve_kiro_usage( - bytes, - &responses, - context.model.as_deref(), - estimated_input_tokens, - ); - let converted = - match transform_response_body(FormatTransform::ResponsesToChat, &responses, context.model.as_deref()) { - Ok(converted) => converted, - Err(message) => { - return Err(respond_transform_error(context, usage, log, message)); - } - }; - Ok(ConvertedBody { - output: converted, - usage, - }) -} - -fn convert_kiro_to_anthropic_body( - bytes: &Bytes, - context: &mut LogContext, - usage: UsageSnapshot, - log: Arc, - estimated_input_tokens: Option, -) -> Result { - let converted = match kiro_to_anthropic::convert_kiro_response( - bytes, - context.model.as_deref(), - estimated_input_tokens, - ) { - Ok(converted) => converted, - Err(message) => { - return Err(respond_transform_error(context, usage, log, message)); - } - }; - let usage = resolve_kiro_usage( - bytes, - &converted, - context.model.as_deref(), - estimated_input_tokens, - ); - Ok(ConvertedBody { - output: converted, - usage, - }) -} - -fn convert_codex_to_chat_body( - bytes: &Bytes, - context: &mut LogContext, - usage: UsageSnapshot, - log: Arc, - request_body: Option<&str>, -) -> Result { - let converted = match codex_compat::codex_response_to_chat(bytes, request_body) { - Ok(converted) => converted, - Err(message) => { - return Err(respond_transform_error(context, usage, log, message)); - } - }; - Ok(ConvertedBody 
{ - output: converted, - usage, - }) -} - -fn convert_codex_to_responses_body( - bytes: &Bytes, - context: &mut LogContext, - usage: UsageSnapshot, - log: Arc, - request_body: Option<&str>, -) -> Result { - let converted = match codex_compat::codex_response_to_responses(bytes, request_body) { - Ok(converted) => converted, - Err(message) => { - return Err(respond_transform_error(context, usage, log, message)); - } - }; - Ok(ConvertedBody { - output: converted, - usage, - }) -} - -fn convert_generic_body( - transform: FormatTransform, - bytes: &Bytes, - context: &mut LogContext, - usage: UsageSnapshot, - log: Arc, -) -> Result { - let converted = match transform_response_body(transform, bytes, context.model.as_deref()) { - Ok(converted) => converted, - Err(message) => { - return Err(respond_transform_error(context, usage, log, message)); - } - }; - Ok(ConvertedBody { - output: converted, - usage, - }) -} - -async fn read_upstream_bytes( - upstream_res: reqwest::Response, - context: &mut LogContext, - log: &Arc, -) -> Result { - let bytes = match upstream_read::read_upstream_bytes_with_ttfb(upstream_res, context).await { - Ok(bytes) => bytes, - Err(err) => { - let (status, message) = match err { - upstream_stream::UpstreamStreamError::IdleTimeout(_) => ( - StatusCode::GATEWAY_TIMEOUT, - format!( - "Upstream response timed out after {}s.", - UPSTREAM_NO_DATA_TIMEOUT.as_secs() - ), - ), - upstream_stream::UpstreamStreamError::Upstream(err) => { - let raw = err.to_string(); - let message = if context.provider == PROVIDER_GEMINI { - redact_query_param_value(&raw, "key") - } else { - raw - }; - ( - StatusCode::BAD_GATEWAY, - format!("Failed to read upstream response: {message}"), - ) - } - }; - context.status = status.as_u16(); - let empty_usage = UsageSnapshot { - usage: None, - cached_tokens: None, - usage_json: None, - }; - let entry = build_log_entry(context, empty_usage, Some(message.clone())); - log.clone().write_detached(entry); - return Err(http::error_response(status, message)); - } - }; - Ok(bytes) -} - -fn respond_transform_error( - context: &mut LogContext, - usage: UsageSnapshot, - log: Arc, - message: String, -) -> Response { - let error_message = format!("Failed to transform upstream response: {message}"); - context.status = StatusCode::BAD_GATEWAY.as_u16(); - let entry = build_log_entry(context, usage, Some(error_message.clone())); - log.clone().write_detached(entry); - http::error_response(StatusCode::BAD_GATEWAY, error_message) -} - -fn resolve_kiro_usage( - raw_bytes: &Bytes, - responses_bytes: &Bytes, - model: Option<&str>, - estimated_input_tokens: Option, -) -> UsageSnapshot { - let usage = extract_usage_from_response(responses_bytes); - if usage.usage.is_none() && usage.cached_tokens.is_none() && usage.usage_json.is_none() { - if let Some(fallback) = - kiro_to_responses::extract_kiro_usage_snapshot(raw_bytes, model, estimated_input_tokens) - { - return fallback; - } - } - usage -} - -fn maybe_override_response_model(bytes: Bytes, model_override: Option<&str>) -> Bytes { - let Some(model_override) = model_override else { - return bytes; - }; - model::rewrite_response_model(&bytes, model_override).unwrap_or(bytes) -} - -fn response_error_text(bytes: &Bytes) -> String { - let slice = bytes.as_ref(); - if slice.len() <= RESPONSE_ERROR_LIMIT_BYTES { - return String::from_utf8_lossy(slice).to_string(); - } - let truncated = &slice[..RESPONSE_ERROR_LIMIT_BYTES]; - format!("{}... 
(truncated)", String::from_utf8_lossy(truncated)) -} - -fn response_error_for_status(status: StatusCode, bytes: &Bytes) -> Option { - if status.is_client_error() || status.is_server_error() { - Some(response_error_text(bytes)) - } else { - None - } -} - -fn provider_for_tokens(transform: FormatTransform, provider: &str) -> &str { - match transform { - FormatTransform::KiroToResponses => "openai-response", - FormatTransform::KiroToChat => "openai", - FormatTransform::KiroToAnthropic => "anthropic", - FormatTransform::CodexToChat => "openai", - FormatTransform::CodexToResponses => "openai-response", - _ if provider == PROVIDER_ANTIGRAVITY => PROVIDER_GEMINI, - _ => provider, - } -} diff --git a/src-tauri/src/proxy/response/dispatch/mod.rs b/src-tauri/src/proxy/response/dispatch/mod.rs deleted file mode 100644 index 1a398fe..0000000 --- a/src-tauri/src/proxy/response/dispatch/mod.rs +++ /dev/null @@ -1,60 +0,0 @@ -mod buffered; -mod stream; - -use axum::http::{HeaderMap, StatusCode}; -use axum::response::Response; -use std::sync::Arc; - -use super::super::log::{LogContext, LogWriter}; -use super::super::openai_compat::FormatTransform; -use super::super::token_rate::RequestTokenTracker; - -pub(super) async fn build_stream_response( - status: StatusCode, - upstream_res: reqwest::Response, - headers: HeaderMap, - context: LogContext, - log: Arc, - request_tracker: RequestTokenTracker, - response_transform: FormatTransform, - model_override: Option<&str>, - estimated_input_tokens: Option, -) -> Response { - stream::build_stream_response( - status, - upstream_res, - headers, - context, - log, - request_tracker, - response_transform, - model_override, - estimated_input_tokens, - ) - .await -} - -pub(super) async fn build_buffered_response( - status: StatusCode, - upstream_res: reqwest::Response, - headers: HeaderMap, - context: LogContext, - log: Arc, - request_tracker: RequestTokenTracker, - response_transform: FormatTransform, - model_override: Option<&str>, - estimated_input_tokens: Option, -) -> Response { - buffered::build_buffered_response( - status, - upstream_res, - headers, - context, - log, - request_tracker, - response_transform, - model_override, - estimated_input_tokens, - ) - .await -} diff --git a/src-tauri/src/proxy/response/dispatch/stream.rs b/src-tauri/src/proxy/response/dispatch/stream.rs deleted file mode 100644 index 13c4388..0000000 --- a/src-tauri/src/proxy/response/dispatch/stream.rs +++ /dev/null @@ -1,587 +0,0 @@ -use axum::{ - body::{Body, Bytes}, - http::{HeaderMap, StatusCode}, - response::Response, -}; -use futures_util::StreamExt; -use std::sync::Arc; - -use super::super::{ - anthropic_to_responses, chat_to_responses, kiro_to_anthropic, kiro_to_responses, - responses_to_anthropic, responses_to_chat, streaming, upstream_stream, PROVIDER_CODEX, - PROVIDER_ANTIGRAVITY, PROVIDER_GEMINI, PROVIDER_OPENAI, PROVIDER_OPENAI_RESPONSES, -}; -use super::super::super::{ - antigravity_compat, codex_compat, gemini_compat, http, - log::{build_log_entry, LogContext, LogWriter, UsageSnapshot}, - openai_compat::FormatTransform, - redact::redact_query_param_value, - server_helpers::log_debug_headers_body, - token_rate::RequestTokenTracker, - UPSTREAM_NO_DATA_TIMEOUT, -}; - -type UpstreamBytesStream = futures_util::stream::BoxStream< - 'static, - Result>, ->; -type ResponseStream = futures_util::stream::BoxStream<'static, Result>; -const DEBUG_BODY_LOG_LIMIT_BYTES: usize = usize::MAX; - -pub(super) async fn build_stream_response( - status: StatusCode, - upstream_res: reqwest::Response, - 
headers: HeaderMap, - context: LogContext, - log: Arc, - request_tracker: RequestTokenTracker, - response_transform: FormatTransform, - model_override: Option<&str>, - estimated_input_tokens: Option, -) -> Response { - let mut context = context; - let upstream = - match prepare_upstream_stream(status, &headers, upstream_res, &mut context, &log).await { - Ok(stream) => stream, - Err(response) => return response, - }; - log_debug_headers_body( - "upstream.response.headers", - Some(&headers), - None, - DEBUG_BODY_LOG_LIMIT_BYTES, - ) - .await; - let upstream = if context.provider == PROVIDER_ANTIGRAVITY { - antigravity_compat::stream_antigravity_to_gemini(upstream).boxed() - } else { - upstream - }; - let upstream = log_upstream_stream_if_debug(upstream); - - let stream = stream_for_transform( - response_transform, - upstream, - context, - log, - request_tracker, - estimated_input_tokens, - model_override, - ); - log_debug_headers_body( - "outbound.response.headers", - Some(&headers), - None, - DEBUG_BODY_LOG_LIMIT_BYTES, - ) - .await; - let stream = log_response_stream_if_debug(stream); - let body = Body::from_stream(stream); - http::build_response(status, headers, body) -} - -fn stream_for_transform( - transform: FormatTransform, - upstream: UpstreamBytesStream, - context: LogContext, - log: Arc, - request_tracker: RequestTokenTracker, - estimated_input_tokens: Option, - model_override: Option<&str>, -) -> ResponseStream { - if is_simple_transform(transform) { - return stream_for_simple_transform( - transform, - upstream, - context, - log, - request_tracker, - model_override, - estimated_input_tokens, - ); - } - stream_for_composed_transform( - transform, - upstream, - context, - log, - request_tracker, - estimated_input_tokens, - ) -} - -fn is_simple_transform(transform: FormatTransform) -> bool { - matches!( - transform, - FormatTransform::None - | FormatTransform::ResponsesToChat - | FormatTransform::ChatToResponses - | FormatTransform::ResponsesToAnthropic - | FormatTransform::AnthropicToResponses - | FormatTransform::GeminiToChat - | FormatTransform::ChatToGemini - | FormatTransform::KiroToResponses - | FormatTransform::KiroToAnthropic - | FormatTransform::CodexToChat - | FormatTransform::CodexToResponses - | FormatTransform::ChatToCodex - | FormatTransform::ResponsesToCodex - ) -} - -fn stream_for_simple_transform( - transform: FormatTransform, - upstream: UpstreamBytesStream, - context: LogContext, - log: Arc, - request_tracker: RequestTokenTracker, - model_override: Option<&str>, - estimated_input_tokens: Option, -) -> ResponseStream { - match transform { - FormatTransform::None - | FormatTransform::ResponsesToChat - | FormatTransform::ChatToResponses - | FormatTransform::ResponsesToAnthropic - | FormatTransform::AnthropicToResponses => stream_for_basic_transform( - transform, - upstream, - context, - log, - request_tracker, - model_override, - ), - _ => stream_for_simple_extended( - transform, - upstream, - context, - log, - request_tracker, - estimated_input_tokens, - ), - } -} - -fn stream_for_basic_transform( - transform: FormatTransform, - upstream: UpstreamBytesStream, - context: LogContext, - log: Arc, - request_tracker: RequestTokenTracker, - model_override: Option<&str>, -) -> ResponseStream { - match transform { - FormatTransform::None => stream_with_optional_model_override( - upstream, - context, - log, - request_tracker, - model_override, - ), - FormatTransform::ResponsesToChat => { - responses_to_chat::stream_responses_to_chat(upstream, context, log, 
request_tracker).boxed() - } - FormatTransform::ChatToResponses => { - chat_to_responses::stream_chat_to_responses(upstream, context, log, request_tracker).boxed() - } - FormatTransform::ResponsesToAnthropic => { - responses_to_anthropic::stream_responses_to_anthropic(upstream, context, log, request_tracker) - .boxed() - } - FormatTransform::AnthropicToResponses => { - anthropic_to_responses::stream_anthropic_to_responses(upstream, context, log, request_tracker) - .boxed() - } - _ => streaming::stream_with_logging(upstream, context, log, request_tracker).boxed(), - } -} - -fn stream_for_simple_extended( - transform: FormatTransform, - upstream: UpstreamBytesStream, - context: LogContext, - log: Arc, - request_tracker: RequestTokenTracker, - estimated_input_tokens: Option, -) -> ResponseStream { - match transform { - FormatTransform::GeminiToChat => { - gemini_compat::stream_gemini_to_chat(upstream, context, log, request_tracker).boxed() - } - FormatTransform::ChatToGemini => { - gemini_compat::stream_chat_to_gemini(upstream, context, log, request_tracker).boxed() - } - FormatTransform::KiroToResponses => kiro_to_responses::stream_kiro_to_responses( - upstream, - context, - log, - request_tracker, - estimated_input_tokens, - ) - .boxed(), - FormatTransform::KiroToAnthropic => kiro_to_anthropic::stream_kiro_to_anthropic( - upstream, - context, - log, - request_tracker, - estimated_input_tokens, - ) - .boxed(), - FormatTransform::CodexToChat => { - codex_compat::stream_codex_to_chat(upstream, context, log, request_tracker).boxed() - } - FormatTransform::CodexToResponses => { - codex_compat::stream_codex_to_responses(upstream, context, log, request_tracker).boxed() - } - FormatTransform::ChatToCodex | FormatTransform::ResponsesToCodex => { - streaming::stream_with_logging(upstream, context, log, request_tracker).boxed() - } - _ => streaming::stream_with_logging(upstream, context, log, request_tracker).boxed(), - } -} - -fn stream_for_composed_transform( - transform: FormatTransform, - upstream: UpstreamBytesStream, - context: LogContext, - log: Arc, - request_tracker: RequestTokenTracker, - estimated_input_tokens: Option, -) -> ResponseStream { - match transform { - FormatTransform::ChatToAnthropic => stream_chat_to_anthropic(upstream, context, log, request_tracker), - FormatTransform::AnthropicToChat => stream_anthropic_to_chat(upstream, context, log, request_tracker), - FormatTransform::GeminiToAnthropic => stream_gemini_to_anthropic(upstream, context, log, request_tracker), - FormatTransform::AnthropicToGemini => stream_anthropic_to_gemini(upstream, context, log, request_tracker), - FormatTransform::ResponsesToGemini => stream_responses_to_gemini(upstream, context, log, request_tracker), - FormatTransform::GeminiToResponses => stream_gemini_to_responses(upstream, context, log, request_tracker), - FormatTransform::KiroToChat => { - stream_kiro_to_chat(upstream, context, log, request_tracker, estimated_input_tokens) - } - _ => streaming::stream_with_logging(upstream, context, log, request_tracker).boxed(), - } -} - -fn stream_chat_to_anthropic( - upstream: UpstreamBytesStream, - context: LogContext, - log: Arc, - request_tracker: RequestTokenTracker, -) -> ResponseStream { - let intermediate_log = Arc::new(LogWriter::new(None)); - let intermediate_tracker = RequestTokenTracker::disabled(); - let responses_stream = chat_to_responses::stream_chat_to_responses( - upstream, - context.clone(), - intermediate_log, - intermediate_tracker, - ) - .boxed(); - 
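- // Only this outer stage logs and tracks tokens; the intermediate hop above is muted
- // via LogWriter::new(None) and RequestTokenTracker::disabled().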
responses_to_anthropic::stream_responses_to_anthropic( - responses_stream, - context, - log, - request_tracker, - ) - .boxed() -} - -fn stream_anthropic_to_chat( - upstream: UpstreamBytesStream, - context: LogContext, - log: Arc, - request_tracker: RequestTokenTracker, -) -> ResponseStream { - let intermediate_log = Arc::new(LogWriter::new(None)); - let intermediate_tracker = RequestTokenTracker::disabled(); - let responses_stream = anthropic_to_responses::stream_anthropic_to_responses( - upstream, - context.clone(), - intermediate_log, - intermediate_tracker, - ) - .boxed(); - responses_to_chat::stream_responses_to_chat( - responses_stream, - context, - log, - request_tracker, - ) - .boxed() -} - -fn stream_gemini_to_anthropic( - upstream: UpstreamBytesStream, - context: LogContext, - log: Arc, - request_tracker: RequestTokenTracker, -) -> ResponseStream { - let first_log = Arc::new(LogWriter::new(None)); - let first_tracker = RequestTokenTracker::disabled(); - let chat_stream = gemini_compat::stream_gemini_to_chat( - upstream, - context.clone(), - first_log, - first_tracker, - ) - .boxed(); - let second_log = Arc::new(LogWriter::new(None)); - let second_tracker = RequestTokenTracker::disabled(); - let responses_stream = chat_to_responses::stream_chat_to_responses( - chat_stream, - context.clone(), - second_log, - second_tracker, - ) - .boxed(); - responses_to_anthropic::stream_responses_to_anthropic( - responses_stream, - context, - log, - request_tracker, - ) - .boxed() -} - -fn stream_anthropic_to_gemini( - upstream: UpstreamBytesStream, - context: LogContext, - log: Arc, - request_tracker: RequestTokenTracker, -) -> ResponseStream { - let first_log = Arc::new(LogWriter::new(None)); - let first_tracker = RequestTokenTracker::disabled(); - let responses_stream = anthropic_to_responses::stream_anthropic_to_responses( - upstream, - context.clone(), - first_log, - first_tracker, - ) - .boxed(); - let second_log = Arc::new(LogWriter::new(None)); - let second_tracker = RequestTokenTracker::disabled(); - let chat_stream = responses_to_chat::stream_responses_to_chat( - responses_stream, - context.clone(), - second_log, - second_tracker, - ) - .boxed(); - gemini_compat::stream_chat_to_gemini(chat_stream, context, log, request_tracker).boxed() -} - -fn stream_responses_to_gemini( - upstream: UpstreamBytesStream, - context: LogContext, - log: Arc, - request_tracker: RequestTokenTracker, -) -> ResponseStream { - let intermediate_log = Arc::new(LogWriter::new(None)); - let intermediate_tracker = RequestTokenTracker::disabled(); - let chat_stream = responses_to_chat::stream_responses_to_chat( - upstream, - context.clone(), - intermediate_log, - intermediate_tracker, - ) - .boxed(); - gemini_compat::stream_chat_to_gemini(chat_stream, context, log, request_tracker).boxed() -} - -fn stream_gemini_to_responses( - upstream: UpstreamBytesStream, - context: LogContext, - log: Arc, - request_tracker: RequestTokenTracker, -) -> ResponseStream { - let intermediate_log = Arc::new(LogWriter::new(None)); - let intermediate_tracker = RequestTokenTracker::disabled(); - let chat_stream = gemini_compat::stream_gemini_to_chat( - upstream, - context.clone(), - intermediate_log, - intermediate_tracker, - ) - .boxed(); - chat_to_responses::stream_chat_to_responses(chat_stream, context, log, request_tracker).boxed() -} - -fn stream_kiro_to_chat( - upstream: UpstreamBytesStream, - context: LogContext, - log: Arc, - request_tracker: RequestTokenTracker, - estimated_input_tokens: Option, -) -> ResponseStream { - let 
intermediate_log = Arc::new(LogWriter::new(None)); - let intermediate_tracker = RequestTokenTracker::disabled(); - let responses_stream = kiro_to_responses::stream_kiro_to_responses( - upstream, - context.clone(), - intermediate_log, - intermediate_tracker, - estimated_input_tokens, - ) - .boxed(); - responses_to_chat::stream_responses_to_chat( - responses_stream, - context, - log, - request_tracker, - ) - .boxed() -} - -async fn prepare_upstream_stream( - status: StatusCode, - headers: &HeaderMap, - upstream_res: reqwest::Response, - context: &mut LogContext, - log: &Arc, -) -> Result< - futures_util::stream::BoxStream< - 'static, - Result>, - >, - Response, -> { - let mut upstream = upstream_stream::with_idle_timeout(upstream_res.bytes_stream()); - let first = upstream.next().await; - match first { - Some(Ok(chunk)) => Ok(chain_first_chunk(chunk, upstream, context)), - Some(Err(err)) => Err(stream_error_response(err, context, log)), - None => Err(http::build_response(status, headers.clone(), Body::empty())), - } -} - -fn chain_first_chunk( - chunk: Bytes, - upstream: UpstreamBytesStream, - context: &mut LogContext, -) -> UpstreamBytesStream { - if context.ttfb_ms.is_none() { - context.ttfb_ms = Some(context.start.elapsed().as_millis()); - } - futures_util::stream::iter(vec![Ok::< - Bytes, - upstream_stream::UpstreamStreamError, - >(chunk)]) - .chain(upstream) - .boxed() -} - -fn stream_error_response( - err: upstream_stream::UpstreamStreamError, - context: &mut LogContext, - log: &Arc, -) -> Response { - let (status, message) = match err { - upstream_stream::UpstreamStreamError::IdleTimeout(_) => ( - StatusCode::GATEWAY_TIMEOUT, - format!( - "Upstream response timed out after {}s.", - UPSTREAM_NO_DATA_TIMEOUT.as_secs() - ), - ), - upstream_stream::UpstreamStreamError::Upstream(err) => { - let raw = err.to_string(); - let message = if context.provider == PROVIDER_GEMINI { - redact_query_param_value(&raw, "key") - } else { - raw - }; - ( - StatusCode::BAD_GATEWAY, - format!("Failed to read upstream response: {message}"), - ) - } - }; - - context.status = status.as_u16(); - let empty_usage = UsageSnapshot { - usage: None, - cached_tokens: None, - usage_json: None, - }; - let entry = build_log_entry(context, empty_usage, Some(message.clone())); - log.clone().write_detached(entry); - http::error_response(status, message) -} - -fn log_upstream_stream_if_debug(upstream: UpstreamBytesStream) -> UpstreamBytesStream { - if !tracing::enabled!(tracing::Level::DEBUG) { - return upstream; - } - upstream - .map(|item| { - if let Ok(chunk) = &item { - let text = String::from_utf8_lossy(chunk); - tracing::debug!( - stage = "upstream.response.chunk", - bytes = chunk.len(), - body = %text, - "debug dump" - ); - } else if let Err(err) = &item { - tracing::debug!(stage = "upstream.response.chunk.error", error = %err, "debug dump"); - } - item - }) - .boxed() -} - -fn log_response_stream_if_debug(stream: ResponseStream) -> ResponseStream { - if !tracing::enabled!(tracing::Level::DEBUG) { - return stream; - } - stream - .map(|item| { - if let Ok(chunk) = &item { - let text = String::from_utf8_lossy(chunk); - tracing::debug!( - stage = "outbound.response.chunk", - bytes = chunk.len(), - body = %text, - "debug dump" - ); - } else if let Err(err) = &item { - tracing::debug!(stage = "outbound.response.chunk.error", error = %err, "debug dump"); - } - item - }) - .boxed() -} - -fn stream_with_optional_model_override( - upstream: impl futures_util::stream::Stream> + Unpin + Send + 'static, - context: LogContext, - 
log: Arc, - request_tracker: RequestTokenTracker, - model_override: Option<&str>, -) -> futures_util::stream::BoxStream<'static, Result> -where - E: std::error::Error + Send + Sync + 'static, -{ - if let Some(model_override) = model_override { - if should_rewrite_sse_model(&context.provider) { - return streaming::stream_with_logging_and_model_override( - upstream, - context, - log, - model_override.to_string(), - request_tracker, - ) - .boxed(); - } - } - streaming::stream_with_logging(upstream, context, log, request_tracker).boxed() -} - -// 只对 data-only SSE 的提供商做行级重写,避免破坏带 event: 行的流。 -fn should_rewrite_sse_model(provider: &str) -> bool { - provider == PROVIDER_OPENAI - || provider == PROVIDER_OPENAI_RESPONSES - || provider == PROVIDER_GEMINI - || provider == PROVIDER_ANTIGRAVITY - || provider == PROVIDER_CODEX -} diff --git a/src-tauri/src/proxy/response/kiro_to_anthropic.rs b/src-tauri/src/proxy/response/kiro_to_anthropic.rs deleted file mode 100644 index d4ffa93..0000000 --- a/src-tauri/src/proxy/response/kiro_to_anthropic.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod kiro_to_anthropic_helpers; -mod kiro_to_anthropic_stream; - -pub(super) use kiro_to_anthropic_helpers::convert_kiro_response; -pub(super) use kiro_to_anthropic_stream::stream_kiro_to_anthropic; diff --git a/src-tauri/src/proxy/response/kiro_to_anthropic/kiro_to_anthropic_helpers.rs b/src-tauri/src/proxy/response/kiro_to_anthropic/kiro_to_anthropic_helpers.rs deleted file mode 100644 index 736a22b..0000000 --- a/src-tauri/src/proxy/response/kiro_to_anthropic/kiro_to_anthropic_helpers.rs +++ /dev/null @@ -1,111 +0,0 @@ -use axum::body::Bytes; -use base64::{engine::general_purpose::STANDARD, Engine as _}; -use serde_json::{json, Value}; -use sha2::{Digest, Sha256}; - -use crate::proxy::kiro::{parse_event_stream, utils::random_uuid, KiroToolUse, KiroUsage}; -use super::super::kiro_to_responses_helpers::{apply_usage_fallback, usage_json_from_kiro}; - -pub(crate) fn convert_kiro_response( - bytes: &Bytes, - model: Option<&str>, - estimated_input_tokens: Option, -) -> Result { - let parsed = parse_event_stream(bytes) - .map_err(|message| format!("Failed to parse Kiro response: {message}"))?; - let mut usage = parsed.usage.clone(); - apply_usage_fallback( - &mut usage, - model, - estimated_input_tokens, - &parsed.content, - &parsed.reasoning, - ); - let response = build_claude_response( - parsed.content, - parsed.reasoning, - parsed.tool_uses, - usage, - parsed.stop_reason.as_deref(), - model.unwrap_or("unknown"), - ); - serde_json::to_vec(&response) - .map(Bytes::from) - .map_err(|err| format!("Failed to serialize response: {err}")) -} - -pub(super) fn split_partial_tag(segment: &str, tag: &str) -> (String, String) { - if tag.len() <= 1 || segment.is_empty() { - return (segment.to_string(), String::new()); - } - let max_len = std::cmp::min(segment.len(), tag.len() - 1); - for len in (1..=max_len).rev() { - if segment.ends_with(&tag[..len]) { - let emit_end = segment.len() - len; - return (segment[..emit_end].to_string(), segment[emit_end..].to_string()); - } - } - (segment.to_string(), String::new()) -} - -fn build_claude_response( - content: String, - reasoning: String, - tool_uses: Vec, - usage: KiroUsage, - stop_reason: Option<&str>, - model: &str, -) -> Value { - let mut blocks = Vec::new(); - if !reasoning.trim().is_empty() { - blocks.push(json!({ - "type": "thinking", - "thinking": reasoning, - "signature": thinking_signature(&reasoning) - })); - } - if !content.trim().is_empty() { - blocks.push(json!({ "type": "text", "text": 
content })); - } - for tool_use in tool_uses.iter() { - blocks.push(json!({ - "type": "tool_use", - "id": tool_use.tool_use_id, - "name": tool_use.name, - "input": tool_use.input - })); - } - if blocks.is_empty() { - blocks.push(json!({ "type": "text", "text": "" })); - } - let stop_reason = stop_reason.unwrap_or_else(|| { - if tool_uses.is_empty() { - "end_turn" - } else { - "tool_use" - } - }); - let usage_value = usage_json_from_kiro(&usage).unwrap_or_else(|| json!({ - "input_tokens": 0, - "output_tokens": 0 - })); - json!({ - "id": format!("msg_{}", random_uuid()), - "type": "message", - "role": "assistant", - "model": model, - "content": blocks, - "stop_reason": stop_reason, - "stop_sequence": null, - "usage": usage_value - }) -} - -fn thinking_signature(text: &str) -> String { - if text.is_empty() { - return String::new(); - } - let mut hasher = Sha256::new(); - hasher.update(text.as_bytes()); - STANDARD.encode(hasher.finalize()) -} diff --git a/src-tauri/src/proxy/response/kiro_to_anthropic/kiro_to_anthropic_stream.rs b/src-tauri/src/proxy/response/kiro_to_anthropic/kiro_to_anthropic_stream.rs deleted file mode 100644 index 3a8209b..0000000 --- a/src-tauri/src/proxy/response/kiro_to_anthropic/kiro_to_anthropic_stream.rs +++ /dev/null @@ -1,198 +0,0 @@ -use axum::body::Bytes; -use futures_util::StreamExt; -use std::collections::{HashMap, HashSet, VecDeque}; -use std::sync::Arc; -use std::time::{Duration, Instant}; - -use crate::proxy::kiro::{EventStreamDecoder, KiroUsage}; -use crate::proxy::log::{LogContext, LogWriter}; -use crate::proxy::token_rate::RequestTokenTracker; - -pub(super) const USAGE_UPDATE_CHAR_THRESHOLD: usize = 5000; -pub(super) const USAGE_UPDATE_TIME_INTERVAL: Duration = Duration::from_secs(15); -pub(super) const USAGE_UPDATE_TOKEN_DELTA: u64 = 10; - -pub(crate) fn stream_kiro_to_anthropic( - upstream: impl futures_util::stream::Stream> - + Unpin - + Send - + 'static, - context: LogContext, - log: Arc, - token_tracker: RequestTokenTracker, - estimated_input_tokens: Option, -) -> impl futures_util::stream::Stream> + Send -where - E: std::error::Error + Send + Sync + 'static, -{ - let state = KiroToAnthropicState::new( - upstream, - context, - log, - token_tracker, - estimated_input_tokens, - ); - futures_util::stream::try_unfold(state, |state| async move { state.step().await }) -} - -enum ActiveBlock { - Text { index: usize }, - Thinking { index: usize }, - ToolUse { id: String }, -} - -struct ToolUseState { - index: usize, - name: String, - sent_start: bool, - sent_stop: bool, - sent_input: bool, -} - -struct ThinkingStreamState { - in_thinking: bool, - pending: String, -} - -struct KiroToAnthropicState { - upstream: S, - decoder: EventStreamDecoder, - log: Arc, - context: LogContext, - token_tracker: RequestTokenTracker, - estimated_input_tokens: Option, - out: VecDeque, - message_id: String, - model: String, - sent_message_start: bool, - sent_message_stop: bool, - active_block: Option, - next_block_index: usize, - tool_uses: HashMap, - processed_tool_keys: HashSet, - tool_state: Option, - usage: KiroUsage, - stop_reason: Option, - thinking_state: ThinkingStreamState, - raw_content: String, - content: String, - reasoning: String, - saw_tool_use: bool, - logged: bool, - upstream_ended: bool, - last_ping_len: usize, - last_ping_time: Instant, - last_reported_output_tokens: u64, -} - -impl KiroToAnthropicState -where - S: futures_util::stream::Stream> + Unpin + Send + 'static, - E: std::error::Error + Send + Sync + 'static, -{ - fn new( - upstream: S, - context: 
LogContext, - log: Arc, - token_tracker: RequestTokenTracker, - estimated_input_tokens: Option, - ) -> Self { - let now_ms = super::super::now_ms(); - let model = context - .model - .clone() - .unwrap_or_else(|| "unknown".to_string()); - Self { - upstream, - decoder: EventStreamDecoder::new(), - log, - context, - token_tracker, - estimated_input_tokens, - out: VecDeque::new(), - message_id: format!("msg_proxy_{now_ms}"), - model, - sent_message_start: false, - sent_message_stop: false, - active_block: None, - next_block_index: 0, - tool_uses: HashMap::new(), - processed_tool_keys: HashSet::new(), - tool_state: None, - usage: KiroUsage::default(), - stop_reason: None, - thinking_state: ThinkingStreamState { - in_thinking: false, - pending: String::new(), - }, - raw_content: String::new(), - content: String::new(), - reasoning: String::new(), - saw_tool_use: false, - logged: false, - upstream_ended: false, - last_ping_len: 0, - last_ping_time: Instant::now(), - last_reported_output_tokens: 0, - } - } - - async fn step(mut self) -> Result, std::io::Error> { - loop { - if let Some(next) = self.out.pop_front() { - return Ok(Some((next, self))); - } - - if self.upstream_ended { - return Ok(None); - } - - match self.upstream.next().await { - Some(Ok(chunk)) => { - self.handle_chunk(&chunk).await?; - } - Some(Err(err)) => { - self.log_usage_once(); - return Err(std::io::Error::new(std::io::ErrorKind::Other, err)); - } - None => { - self.upstream_ended = true; - self.finish_stream().await?; - if self.out.is_empty() { - return Ok(None); - } - } - } - } - } - - async fn handle_chunk(&mut self, chunk: &Bytes) -> Result<(), std::io::Error> { - let messages = self - .decoder - .push(chunk) - .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err.message))?; - for message in messages { - self.handle_message(&message.payload, &message.event_type) - .await; - } - Ok(()) - } - - async fn finish_stream(&mut self) -> Result<(), std::io::Error> { - let messages = self - .decoder - .finish() - .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err.message))?; - for message in messages { - self.handle_message(&message.payload, &message.event_type) - .await; - } - self.flush_thinking_pending().await; - self.finish_message_if_needed(); - self.log_usage_once(); - Ok(()) - } -} - -mod kiro_to_anthropic_stream_blocks; -mod kiro_to_anthropic_stream_handlers; diff --git a/src-tauri/src/proxy/response/kiro_to_anthropic/kiro_to_anthropic_stream/kiro_to_anthropic_stream_blocks.rs b/src-tauri/src/proxy/response/kiro_to_anthropic/kiro_to_anthropic_stream/kiro_to_anthropic_stream_blocks.rs deleted file mode 100644 index 989ab73..0000000 --- a/src-tauri/src/proxy/response/kiro_to_anthropic/kiro_to_anthropic_stream/kiro_to_anthropic_stream_blocks.rs +++ /dev/null @@ -1,357 +0,0 @@ -use axum::body::Bytes; -use serde_json::{json, Map, Value}; -use std::time::Instant; - -use super::{ActiveBlock, KiroToAnthropicState, ToolUseState, USAGE_UPDATE_CHAR_THRESHOLD, USAGE_UPDATE_TIME_INTERVAL, USAGE_UPDATE_TOKEN_DELTA}; -use crate::proxy::log::{build_log_entry, UsageSnapshot}; -use crate::proxy::token_estimator; -use super::super::super::kiro_to_responses_helpers::{ - apply_usage_fallback, - usage_from_kiro, - usage_json_from_kiro, -}; - -impl KiroToAnthropicState -where - S: futures_util::stream::Stream> + Unpin + Send + 'static, - E: std::error::Error + Send + Sync + 'static, -{ - pub(super) async fn emit_text_delta(&mut self, delta: &str) { - if delta.is_empty() { - return; - } - self.ensure_message_start(); - 
let index = self.ensure_text_block(); - self.content.push_str(delta); - self.token_tracker.add_output_text(delta).await; - self.out.push_back(super::super::super::anthropic_event_sse( - "content_block_delta", - json!({ - "type": "content_block_delta", - "index": index, - "delta": { "type": "text_delta", "text": delta } - }), - )); - } - - pub(super) async fn emit_thinking_delta(&mut self, delta: &str) { - if delta.is_empty() { - return; - } - self.ensure_message_start(); - let index = self.ensure_thinking_block(); - self.reasoning.push_str(delta); - self.token_tracker.add_output_text(delta).await; - self.out.push_back(super::super::super::anthropic_event_sse( - "content_block_delta", - json!({ - "type": "content_block_delta", - "index": index, - "delta": { "type": "thinking_delta", "thinking": delta } - }), - )); - } - - pub(super) fn ensure_message_start(&mut self) { - if self.sent_message_start { - return; - } - self.sent_message_start = true; - let usage = usage_json_from_kiro(&self.usage).unwrap_or_else(|| json!({ - "input_tokens": 0, - "output_tokens": 0 - })); - let message = json!({ - "id": self.message_id.as_str(), - "type": "message", - "role": "assistant", - "model": self.model.as_str(), - "content": [], - "stop_reason": null, - "stop_sequence": null, - "usage": usage - }); - self.out.push_back(super::super::super::anthropic_event_sse( - "message_start", - json!({ "type": "message_start", "message": message }), - )); - } - - fn ensure_text_block(&mut self) -> usize { - if let Some(ActiveBlock::Text { index }) = self.active_block { - return index; - } - self.stop_active_block(); - let index = self.next_block_index; - self.next_block_index += 1; - self.active_block = Some(ActiveBlock::Text { index }); - self.out.push_back(super::super::super::anthropic_event_sse( - "content_block_start", - json!({ - "type": "content_block_start", - "index": index, - "content_block": { "type": "text", "text": "" } - }), - )); - index - } - - fn ensure_thinking_block(&mut self) -> usize { - if let Some(ActiveBlock::Thinking { index }) = self.active_block { - return index; - } - self.stop_active_block(); - let index = self.next_block_index; - self.next_block_index += 1; - self.active_block = Some(ActiveBlock::Thinking { index }); - self.out.push_back(super::super::super::anthropic_event_sse( - "content_block_start", - json!({ - "type": "content_block_start", - "index": index, - "content_block": { "type": "thinking", "thinking": "" } - }), - )); - index - } - - pub(super) fn ensure_tool_use_block(&mut self, tool_use_id: &str, name: &str) { - self.ensure_message_start(); - if !self.tool_uses.contains_key(tool_use_id) { - let index = self.next_block_index; - self.next_block_index += 1; - self.tool_uses.insert(tool_use_id.to_string(), ToolUseState { - index, - name: name.to_string(), - sent_start: false, - sent_stop: false, - sent_input: false, - }); - } - if let Some(state) = self.tool_uses.get_mut(tool_use_id) { - if state.name.is_empty() { - state.name = name.to_string(); - } - } - if !self.tool_uses.get(tool_use_id).is_some_and(|state| state.sent_start) { - self.start_tool_use_block(tool_use_id); - } - } - - fn start_tool_use_block(&mut self, tool_use_id: &str) { - let Some((index, name, sent_start)) = self.tool_uses.get(tool_use_id).map(|state| { - (state.index, state.name.clone(), state.sent_start) - }) else { - return; - }; - if sent_start { - return; - } - self.stop_active_block(); - if let Some(state) = self.tool_uses.get_mut(tool_use_id) { - state.sent_start = true; - } - self.saw_tool_use = 
true; - self.active_block = Some(ActiveBlock::ToolUse { - id: tool_use_id.to_string(), - }); - self.out.push_back(super::super::super::anthropic_event_sse( - "content_block_start", - json!({ - "type": "content_block_start", - "index": index, - "content_block": { - "type": "tool_use", - "id": tool_use_id, - "name": name, - "input": {} - } - }), - )); - } - - pub(super) fn emit_tool_use_input(&mut self, tool_use_id: &str, value: &Value) { - let Some((index, sent_input)) = self - .tool_uses - .get(tool_use_id) - .map(|state| (state.index, state.sent_input)) - else { - return; - }; - let input = match value { - Value::String(text) => text.clone(), - Value::Object(obj) => serde_json::to_string(obj).unwrap_or_default(), - Value::Null => String::new(), - other => other.to_string(), - }; - if input.trim().is_empty() { - return; - } - if sent_input && !value.is_string() { - return; - } - self.set_active_tool_use(tool_use_id); - self.out.push_back(super::super::super::anthropic_event_sse( - "content_block_delta", - json!({ - "type": "content_block_delta", - "index": index, - "delta": { "type": "input_json_delta", "partial_json": input } - }), - )); - if let Some(state) = self.tool_uses.get_mut(tool_use_id) { - if !value.is_string() { - state.sent_input = true; - } - } - } - - fn set_active_tool_use(&mut self, tool_use_id: &str) { - if !self.tool_uses.contains_key(tool_use_id) { - return; - } - match &self.active_block { - Some(ActiveBlock::ToolUse { id }) if id == tool_use_id => {} - _ => { - self.stop_active_block(); - self.active_block = Some(ActiveBlock::ToolUse { - id: tool_use_id.to_string(), - }); - } - } - } - - pub(super) fn stop_tool_use_block(&mut self, tool_use_id: &str) { - let Some(state) = self.tool_uses.get_mut(tool_use_id) else { - return; - }; - if state.sent_stop { - return; - } - state.sent_stop = true; - let index = state.index; - self.out.push_back(super::super::super::anthropic_event_sse( - "content_block_stop", - json!({ "type": "content_block_stop", "index": index }), - )); - if matches!(&self.active_block, Some(ActiveBlock::ToolUse { id }) if id == tool_use_id) { - self.active_block = None; - } - } - - fn stop_active_block(&mut self) { - let Some(active) = self.active_block.take() else { - return; - }; - match active { - ActiveBlock::Text { index } | ActiveBlock::Thinking { index } => { - self.out.push_back(super::super::super::anthropic_event_sse( - "content_block_stop", - json!({ "type": "content_block_stop", "index": index }), - )); - } - ActiveBlock::ToolUse { id } => { - self.stop_tool_use_block(&id); - } - } - } - - pub(super) fn finish_message_if_needed(&mut self) { - if self.sent_message_stop { - return; - } - self.ensure_message_start(); - self.stop_active_block(); - - let stop_reason = self.stop_reason.clone().unwrap_or_else(|| { - if self.saw_tool_use { - "tool_use".to_string() - } else { - "end_turn".to_string() - } - }); - apply_usage_fallback( - &mut self.usage, - Some(&self.model), - self.estimated_input_tokens, - &self.content, - &self.reasoning, - ); - let input_tokens = self.usage.input_tokens.unwrap_or(0); - let output_tokens = self.usage.output_tokens.unwrap_or(0); - let mut usage_obj = Map::new(); - usage_obj.insert("input_tokens".to_string(), json!(input_tokens)); - usage_obj.insert("output_tokens".to_string(), json!(output_tokens)); - if let Some(cached) = usage_json_from_kiro(&self.usage) - .and_then(|value| value.get("cache_read_input_tokens").cloned()) - { - usage_obj.insert("cache_read_input_tokens".to_string(), cached); - } - - 
self.out.push_back(super::super::super::anthropic_event_sse( - "message_delta", - json!({ - "type": "message_delta", - "delta": { "stop_reason": stop_reason, "stop_sequence": null }, - "usage": Value::Object(usage_obj) - }), - )); - self.out.push_back(super::super::super::anthropic_event_sse( - "message_stop", - json!({ "type": "message_stop" }), - )); - self.sent_message_stop = true; - } - - pub(super) fn log_usage_once(&mut self) { - if self.logged { - return; - } - self.logged = true; - apply_usage_fallback( - &mut self.usage, - Some(&self.model), - self.estimated_input_tokens, - &self.content, - &self.reasoning, - ); - let usage_snapshot = UsageSnapshot { - usage: usage_from_kiro(&self.usage), - cached_tokens: None, - usage_json: usage_json_from_kiro(&self.usage), - }; - let entry = build_log_entry(&self.context, usage_snapshot, None); - self.log.clone().write_detached(entry); - } - - pub(super) fn maybe_emit_usage_ping(&mut self) { - let len = self.raw_content.len(); - let should_send = len.saturating_sub(self.last_ping_len) >= USAGE_UPDATE_CHAR_THRESHOLD - || (self.last_ping_time.elapsed() >= USAGE_UPDATE_TIME_INTERVAL && len > self.last_ping_len); - if !should_send { - return; - } - - let output_tokens = token_estimator::estimate_text_tokens(Some(&self.model), &self.raw_content); - if output_tokens > self.last_reported_output_tokens + USAGE_UPDATE_TOKEN_DELTA { - self.ensure_message_start(); - let input_tokens = self.usage.input_tokens.unwrap_or(0); - self.out.push_back(super::super::super::anthropic_event_sse( - "ping", - json!({ - "type": "ping", - "usage": { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - "total_tokens": input_tokens.saturating_add(output_tokens), - "estimated": true - } - }), - )); - self.last_reported_output_tokens = output_tokens; - } - - self.last_ping_len = len; - self.last_ping_time = Instant::now(); - } -} diff --git a/src-tauri/src/proxy/response/kiro_to_anthropic/kiro_to_anthropic_stream/kiro_to_anthropic_stream_handlers.rs b/src-tauri/src/proxy/response/kiro_to_anthropic/kiro_to_anthropic_stream/kiro_to_anthropic_stream_handlers.rs deleted file mode 100644 index 902e1e7..0000000 --- a/src-tauri/src/proxy/response/kiro_to_anthropic/kiro_to_anthropic_stream/kiro_to_anthropic_stream_handlers.rs +++ /dev/null @@ -1,228 +0,0 @@ -use serde_json::{json, Map, Value}; - -use super::KiroToAnthropicState; -use crate::proxy::kiro::tool_parser::process_tool_use_event; -use super::super::kiro_to_anthropic_helpers::split_partial_tag; -use super::super::super::kiro_to_responses_helpers::{ - detect_event_type, - extract_error, - update_stop_reason, - update_usage, -}; - -impl KiroToAnthropicState -where - S: futures_util::stream::Stream> + Unpin + Send + 'static, - E: std::error::Error + Send + Sync + 'static, -{ - pub(super) async fn handle_message(&mut self, payload: &[u8], event_type: &str) { - if self.sent_message_stop || payload.is_empty() { - return; - } - let Ok(event) = serde_json::from_slice::(payload) else { - return; - }; - let Some(event_obj) = event.as_object() else { - return; - }; - if let Some(error) = extract_error(event_obj) { - if error != "invalidStateEvent" { - self.finish_message_if_needed(); - } - return; - } - if !self.sent_message_start { - self.ensure_message_start(); - } - - update_stop_reason(event_obj, &mut self.stop_reason); - update_usage(event_obj, &mut self.usage); - - let event_type = if !event_type.is_empty() { - event_type - } else { - detect_event_type(event_obj) - }; - - match event_type { - 
"assistantResponseEvent" => self.handle_assistant_response(event_obj).await, - "toolUseEvent" => self.handle_tool_use_event(event_obj).await, - "reasoningContentEvent" => self.handle_reasoning_content(event_obj).await, - "messageStopEvent" | "message_stop" => { - update_stop_reason(event_obj, &mut self.stop_reason); - } - _ => {} - } - } - - async fn handle_assistant_response(&mut self, event: &Map) { - if let Some(Value::Object(assistant)) = event.get("assistantResponseEvent") { - if let Some(text) = assistant.get("content").and_then(Value::as_str) { - self.handle_text_delta(text).await; - } - if let Some(items) = assistant.get("toolUses").and_then(Value::as_array) { - self.handle_tool_uses(items); - } - update_stop_reason(assistant, &mut self.stop_reason); - } - if let Some(text) = event.get("content").and_then(Value::as_str) { - self.handle_text_delta(text).await; - } - if let Some(items) = event.get("toolUses").and_then(Value::as_array) { - self.handle_tool_uses(items); - } - } - - async fn handle_reasoning_content(&mut self, event: &Map) { - if let Some(Value::Object(reasoning)) = event.get("reasoningContentEvent") { - if let Some(text) = reasoning.get("thinkingText").and_then(Value::as_str) { - self.emit_thinking_delta(text).await; - } - if let Some(text) = reasoning.get("text").and_then(Value::as_str) { - self.emit_thinking_delta(text).await; - } - return; - } - - if let Some(text) = event.get("text").and_then(Value::as_str) { - self.emit_thinking_delta(text).await; - } - } - - async fn handle_text_delta(&mut self, delta: &str) { - if delta.is_empty() { - return; - } - self.raw_content.push_str(delta); - - let mut combined = String::new(); - if !self.thinking_state.pending.is_empty() { - combined.push_str(&self.thinking_state.pending); - self.thinking_state.pending.clear(); - } - combined.push_str(delta); - self.process_thinking_delta(&combined).await; - self.maybe_emit_usage_ping(); - } - - async fn process_thinking_delta(&mut self, input: &str) { - const START: &str = ""; - const END: &str = ""; - - let mut cursor = 0; - while cursor < input.len() { - if self.thinking_state.in_thinking { - if let Some(pos) = input[cursor..].find(END) { - let end = cursor + pos; - if end > cursor { - self.emit_thinking_delta(&input[cursor..end]).await; - } - cursor = end + END.len(); - self.thinking_state.in_thinking = false; - continue; - } - let (emit, pending) = split_partial_tag(&input[cursor..], END); - if !emit.is_empty() { - self.emit_thinking_delta(&emit).await; - } - self.thinking_state.pending = pending; - break; - } - - if let Some(pos) = input[cursor..].find(START) { - let end = cursor + pos; - if end > cursor { - self.emit_text_delta(&input[cursor..end]).await; - } - cursor = end + START.len(); - self.thinking_state.in_thinking = true; - continue; - } - let (emit, pending) = split_partial_tag(&input[cursor..], START); - if !emit.is_empty() { - self.emit_text_delta(&emit).await; - } - self.thinking_state.pending = pending; - break; - } - } - - pub(super) async fn flush_thinking_pending(&mut self) { - if self.thinking_state.pending.is_empty() { - return; - } - let pending = std::mem::take(&mut self.thinking_state.pending); - if self.thinking_state.in_thinking { - self.emit_thinking_delta(&pending).await; - } else { - self.emit_text_delta(&pending).await; - } - } - - async fn handle_tool_use_event(&mut self, event: &Map) { - let (completed, next_state) = - process_tool_use_event(event, self.tool_state.take(), &mut self.processed_tool_keys); - self.tool_state = next_state; - - let source = 
event - .get("toolUseEvent") - .and_then(Value::as_object) - .unwrap_or(event); - let tool_use_id = tool_use_id(source); - let name = source.get("name").and_then(Value::as_str).unwrap_or(""); - let stop = source.get("stop").and_then(Value::as_bool).unwrap_or(false); - let input_value = source.get("input"); - - if let Some(tool_use_id) = tool_use_id { - if !name.is_empty() { - self.ensure_tool_use_block(tool_use_id, name); - } - if let Some(input_value) = input_value { - self.emit_tool_use_input(tool_use_id, input_value); - } - if stop { - self.stop_tool_use_block(tool_use_id); - } - } - - for tool_use in completed { - self.ensure_tool_use_block(&tool_use.tool_use_id, &tool_use.name); - self.emit_tool_use_input( - &tool_use.tool_use_id, - &Value::Object(tool_use.input.clone()), - ); - self.stop_tool_use_block(&tool_use.tool_use_id); - } - } - - fn handle_tool_uses(&mut self, items: &[Value]) { - for item in items { - let Some(tool) = item.as_object() else { - continue; - }; - let tool_use_id = tool_use_id(tool); - let name = tool.get("name").and_then(Value::as_str).unwrap_or(""); - let input = tool.get("input").cloned().unwrap_or_else(|| json!({})); - let Some(tool_use_id) = tool_use_id else { - continue; - }; - let dedupe_key = format!("id:{tool_use_id}"); - if self.processed_tool_keys.contains(&dedupe_key) { - continue; - } - self.processed_tool_keys.insert(dedupe_key); - if !name.is_empty() { - self.ensure_tool_use_block(tool_use_id, name); - } - self.emit_tool_use_input(tool_use_id, &input); - self.stop_tool_use_block(tool_use_id); - } - } -} - -fn tool_use_id(source: &Map) -> Option<&str> { - source - .get("toolUseId") - .or_else(|| source.get("tool_use_id")) - .and_then(Value::as_str) -} diff --git a/src-tauri/src/proxy/response/kiro_to_responses.rs b/src-tauri/src/proxy/response/kiro_to_responses.rs deleted file mode 100644 index b21fe1a..0000000 --- a/src-tauri/src/proxy/response/kiro_to_responses.rs +++ /dev/null @@ -1,71 +0,0 @@ -use axum::body::Bytes; -use super::super::log::UsageSnapshot; -use super::kiro_to_responses_helpers::{ - apply_usage_fallback, - build_response_object, - usage_from_kiro, - usage_json_from_kiro, -}; - -pub(super) use super::kiro_to_responses_stream::stream_kiro_to_responses; - -pub(super) fn convert_kiro_response( - bytes: &Bytes, - model: Option<&str>, - estimated_input_tokens: Option, -) -> Result { - let parsed = crate::proxy::kiro::parse_event_stream(bytes) - .map_err(|message| format!("Failed to parse Kiro response: {message}"))?; - let mut usage = parsed.usage.clone(); - apply_usage_fallback( - &mut usage, - model, - estimated_input_tokens, - &parsed.content, - &parsed.reasoning, - ); - let now_ms = super::now_ms(); - let response_id = format!("resp_{now_ms}"); - let created_at = (now_ms / 1000) as i64; - let response = build_response_object( - parsed.content, - parsed.reasoning, - parsed.tool_uses, - usage, - parsed.stop_reason.as_deref(), - model, - response_id, - created_at, - ); - serde_json::to_vec(&response) - .map(Bytes::from) - .map_err(|err| format!("Failed to serialize response: {err}")) -} - -pub(super) fn extract_kiro_usage_snapshot( - bytes: &Bytes, - model: Option<&str>, - estimated_input_tokens: Option, -) -> Option { - let parsed = crate::proxy::kiro::parse_event_stream(bytes).ok()?; - let mut usage = parsed.usage.clone(); - apply_usage_fallback( - &mut usage, - model, - estimated_input_tokens, - &parsed.content, - &parsed.reasoning, - ); - let usage_snapshot = UsageSnapshot { - usage: usage_from_kiro(&usage), - cached_tokens: None, 
- usage_json: usage_json_from_kiro(&usage), - }; - if usage_snapshot.usage.is_none() - && usage_snapshot.usage_json.is_none() - && usage_snapshot.cached_tokens.is_none() - { - return None; - } - Some(usage_snapshot) -} diff --git a/src-tauri/src/proxy/response/kiro_to_responses_helpers.rs b/src-tauri/src/proxy/response/kiro_to_responses_helpers.rs deleted file mode 100644 index d0859d1..0000000 --- a/src-tauri/src/proxy/response/kiro_to_responses_helpers.rs +++ /dev/null @@ -1,378 +0,0 @@ -use serde_json::{json, Map, Value}; - -use super::super::compat_reason; -use super::super::kiro::{KiroToolUse, KiroUsage}; -use super::super::token_estimator; - -pub(super) struct FunctionCallOutput { - pub(super) id: String, - pub(super) output_index: u64, - pub(super) call_id: String, - pub(super) name: String, - pub(super) arguments: String, -} - -pub(super) fn usage_from_kiro(usage: &KiroUsage) -> Option { - if usage.input_tokens.is_none() - && usage.output_tokens.is_none() - && usage.total_tokens.is_none() - { - return None; - } - let total_tokens = usage - .total_tokens - .or_else(|| match (usage.input_tokens, usage.output_tokens) { - (Some(input), Some(output)) => Some(input.saturating_add(output)), - _ => None, - }); - Some(super::super::log::TokenUsage { - input_tokens: usage.input_tokens, - output_tokens: usage.output_tokens, - total_tokens, - }) -} - -pub(super) fn usage_json_from_kiro(usage: &KiroUsage) -> Option { - let input_tokens = usage.input_tokens?; - let output_tokens = usage.output_tokens.unwrap_or(0); - let total_tokens = usage - .total_tokens - .or_else(|| input_tokens.checked_add(output_tokens)) - .unwrap_or(input_tokens); - Some(json!({ - "input_tokens": input_tokens, - "input_tokens_details": { "cached_tokens": 0 }, - "output_tokens": output_tokens, - "output_tokens_details": { "reasoning_tokens": 0 }, - "total_tokens": total_tokens - })) -} - -pub(super) fn apply_usage_fallback( - usage: &mut KiroUsage, - model: Option<&str>, - estimated_input_tokens: Option, - content: &str, - reasoning: &str, -) { - if usage.input_tokens.is_none() { - if let Some(pct) = usage.context_usage_percentage { - let input = ((pct * 200000.0) / 100.0).round() as u64; - if input > 0 { - usage.input_tokens = Some(input); - } - } else if let Some(estimate) = estimated_input_tokens { - usage.input_tokens = Some(estimate); - } - } - - if usage.output_tokens.is_none() { - if let (Some(total), Some(input)) = (usage.total_tokens, usage.input_tokens) { - if total >= input { - usage.output_tokens = Some(total - input); - } - } - } - - if usage.output_tokens.is_none() { - let mut output_text = String::new(); - output_text.push_str(content); - if !reasoning.trim().is_empty() { - output_text.push_str(reasoning); - } - if output_text.trim().is_empty() { - return; - } - let estimated = token_estimator::estimate_text_tokens(model, &output_text); - if estimated > 0 { - usage.output_tokens = Some(estimated); - } - } - - if usage.total_tokens.is_none() { - if let (Some(input), Some(output)) = (usage.input_tokens, usage.output_tokens) { - usage.total_tokens = Some(input.saturating_add(output)); - } - } -} - -pub(super) fn build_response_object( - content: String, - reasoning: String, - tool_uses: Vec, - usage: KiroUsage, - stop_reason: Option<&str>, - model: Option<&str>, - response_id: String, - created_at: i64, -) -> Value { - let (status, incomplete_reason) = - compat_reason::responses_status_from_chat_finish_reason(map_stop_reason(stop_reason)); - let status = status.unwrap_or("completed"); - let incomplete_details = 
incomplete_reason - .map(|reason| json!({ "reason": reason })) - .unwrap_or(Value::Null); - - let usage_value = usage_json_from_kiro(&usage); - let usage_json = usage_value.unwrap_or(Value::Null); - let parallel_tool_calls = tool_uses.len() > 1; - - let mut output = Vec::new(); - if !content.trim().is_empty() || !reasoning.trim().is_empty() || tool_uses.is_empty() { - let mut parts = Vec::new(); - if !reasoning.trim().is_empty() { - parts.push(json!({ "type": "reasoning_text", "text": reasoning })); - } - parts.push(json!({ - "type": "output_text", - "text": content, - "annotations": [] - })); - output.push(json!({ - "type": "message", - "id": "msg_0", - "status": "completed", - "role": "assistant", - "content": parts - })); - } - for (index, tool_use) in tool_uses.iter().enumerate() { - let arguments = serde_json::to_string(&tool_use.input).unwrap_or_default(); - output.push(json!({ - "id": format!("fc_{index}"), - "type": "function_call", - "status": "completed", - "arguments": arguments, - "call_id": tool_use.tool_use_id, - "name": tool_use.name - })); - } - - json!({ - "id": response_id, - "object": "response", - "created_at": created_at, - "status": status, - "error": null, - "incomplete_details": incomplete_details, - "model": model.unwrap_or("unknown"), - "parallel_tool_calls": parallel_tool_calls, - "output": output, - "usage": usage_json - }) -} - -pub(super) fn collect_tool_uses(function_calls: &[Option]) -> Vec { - let mut output = Vec::new(); - for call in function_calls { - let Some(call) = call else { - continue; - }; - let input = - serde_json::from_str::>(&call.arguments).unwrap_or_default(); - output.push(KiroToolUse { - tool_use_id: call.call_id.clone(), - name: call.name.clone(), - input, - }); - } - output -} - -pub(super) fn detect_event_type(event: &Map) -> &str { - for key in [ - "assistantResponseEvent", - "toolUseEvent", - "reasoningContentEvent", - "messageStopEvent", - "message_stop", - "messageMetadataEvent", - "metadataEvent", - "usageEvent", - "usage", - "metricsEvent", - "meteringEvent", - "supplementaryWebLinksEvent", - ] { - if event.contains_key(key) { - return key; - } - } - "" -} - -pub(super) fn extract_error(event: &Map) -> Option { - if let Some(Value::String(err_type)) = event.get("_type") { - let message = event.get("message").and_then(Value::as_str).unwrap_or(""); - return Some(format!("Kiro error: {err_type} {message}")); - } - if let Some(Value::String(kind)) = event.get("type") { - if matches!( - kind.as_str(), - "error" | "exception" | "internalServerException" - ) { - let message = event.get("message").and_then(Value::as_str).unwrap_or(""); - if message.is_empty() { - if let Some(Value::Object(err_obj)) = event.get("error") { - if let Some(text) = err_obj.get("message").and_then(Value::as_str) { - return Some(format!("Kiro error: {text}")); - } - } - } - return Some(format!("Kiro error: {message}")); - } - } - if event.contains_key("invalidStateEvent") - || event - .get("eventType") - .and_then(Value::as_str) - .is_some_and(|value| value == "invalidStateEvent") - { - return Some("invalidStateEvent".to_string()); - } - None -} - -pub(super) fn update_stop_reason(event: &Map, stop_reason: &mut Option) { - if let Some(reason) = event.get("stop_reason").and_then(Value::as_str) { - *stop_reason = Some(reason.to_string()); - } - if let Some(reason) = event.get("stopReason").and_then(Value::as_str) { - *stop_reason = Some(reason.to_string()); - } -} - -pub(super) fn update_usage(event: &Map, usage: &mut KiroUsage) { - if let Some(context_pct) = 
event.get("contextUsagePercentage").and_then(Value::as_f64) { - usage.context_usage_percentage = Some(context_pct); - } - if let Some(tokens) = event.get("inputTokens").and_then(Value::as_u64) { - usage.input_tokens = Some(tokens); - } - if let Some(tokens) = event.get("outputTokens").and_then(Value::as_u64) { - usage.output_tokens = Some(tokens); - } - if let Some(tokens) = event.get("totalTokens").and_then(Value::as_u64) { - usage.total_tokens = Some(tokens); - } - - if let Some(metadata) = event.get("messageMetadataEvent").and_then(Value::as_object) { - update_usage_from_metadata(metadata, usage); - } else if let Some(metadata) = event.get("metadataEvent").and_then(Value::as_object) { - update_usage_from_metadata(metadata, usage); - } - - if let Some(usage_obj) = event.get("usage").and_then(Value::as_object) { - update_usage_from_usage_obj(usage_obj, usage); - } - if let Some(usage_obj) = event.get("usageEvent").and_then(Value::as_object) { - update_usage_from_usage_obj(usage_obj, usage); - } - - if let Some(links) = event - .get("supplementaryWebLinksEvent") - .and_then(Value::as_object) - { - if let Some(tokens) = links.get("inputTokens").and_then(Value::as_u64) { - usage.input_tokens = Some(tokens); - } - if let Some(tokens) = links.get("outputTokens").and_then(Value::as_u64) { - usage.output_tokens = Some(tokens); - } - } - - if let Some(metrics) = event.get("metricsEvent").and_then(Value::as_object) { - if let Some(tokens) = metrics.get("inputTokens").and_then(Value::as_u64) { - usage.input_tokens = Some(tokens); - } - if let Some(tokens) = metrics.get("outputTokens").and_then(Value::as_u64) { - usage.output_tokens = Some(tokens); - } - } - - if let Some(metering) = event.get("meteringEvent").and_then(Value::as_object) { - if let Some(tokens) = metering.get("inputTokens").and_then(Value::as_u64) { - usage.input_tokens = Some(tokens); - } - if let Some(tokens) = metering.get("outputTokens").and_then(Value::as_u64) { - usage.output_tokens = Some(tokens); - } - if let Some(tokens) = metering.get("totalTokens").and_then(Value::as_u64) { - usage.total_tokens = Some(tokens); - } - } -} - -fn map_stop_reason(stop_reason: Option<&str>) -> Option<&'static str> { - match stop_reason { - Some("max_tokens") => Some("length"), - Some("content_filtered") => Some("content_filter"), - Some("tool_use") => Some("tool_calls"), - Some("stop_sequence") | Some("end_turn") => Some("stop"), - Some(other) if other.is_empty() => None, - Some(_) => Some("stop"), - None => None, - } -} - -fn update_usage_from_metadata(metadata: &Map, usage: &mut KiroUsage) { - if let Some(token_usage) = metadata.get("tokenUsage").and_then(Value::as_object) { - if let Some(tokens) = token_usage.get("outputTokens").and_then(Value::as_u64) { - usage.output_tokens = Some(tokens); - } - if let Some(tokens) = token_usage.get("totalTokens").and_then(Value::as_u64) { - usage.total_tokens = Some(tokens); - } - if let Some(tokens) = token_usage.get("uncachedInputTokens").and_then(Value::as_u64) { - usage.input_tokens = Some(tokens); - } - if let Some(tokens) = token_usage.get("cacheReadInputTokens").and_then(Value::as_u64) { - let current = usage.input_tokens.unwrap_or(0); - usage.input_tokens = Some(current + tokens); - } - if let Some(context_pct) = token_usage - .get("contextUsagePercentage") - .and_then(Value::as_f64) - { - usage.context_usage_percentage = Some(context_pct); - } - } - - if usage.input_tokens.is_none() { - if let Some(tokens) = metadata.get("inputTokens").and_then(Value::as_u64) { - usage.input_tokens = 
Some(tokens); - } - } - if usage.output_tokens.is_none() { - if let Some(tokens) = metadata.get("outputTokens").and_then(Value::as_u64) { - usage.output_tokens = Some(tokens); - } - } - if usage.total_tokens.is_none() { - if let Some(tokens) = metadata.get("totalTokens").and_then(Value::as_u64) { - usage.total_tokens = Some(tokens); - } - } -} - -fn update_usage_from_usage_obj(usage_obj: &Map, usage: &mut KiroUsage) { - let input_tokens = usage_obj - .get("input_tokens") - .or_else(|| usage_obj.get("prompt_tokens")) - .and_then(Value::as_u64); - let output_tokens = usage_obj - .get("output_tokens") - .or_else(|| usage_obj.get("completion_tokens")) - .and_then(Value::as_u64); - let total_tokens = usage_obj.get("total_tokens").and_then(Value::as_u64); - - if input_tokens.is_some() { - usage.input_tokens = input_tokens; - } - if output_tokens.is_some() { - usage.output_tokens = output_tokens; - } - if total_tokens.is_some() { - usage.total_tokens = total_tokens; - } -} diff --git a/src-tauri/src/proxy/response/kiro_to_responses_stream.rs b/src-tauri/src/proxy/response/kiro_to_responses_stream.rs deleted file mode 100644 index 362b744..0000000 --- a/src-tauri/src/proxy/response/kiro_to_responses_stream.rs +++ /dev/null @@ -1,624 +0,0 @@ -use axum::body::Bytes; -use futures_util::StreamExt; -use serde_json::{json, Map, Value}; -use std::{collections::{HashMap, HashSet, VecDeque}, sync::Arc}; - -use super::super::kiro::{EventStreamDecoder, KiroUsage, KiroToolUse}; -use super::super::kiro::tool_parser::{process_tool_use_event, ToolUseState}; -use super::super::log::{build_log_entry, LogContext, LogWriter, UsageSnapshot}; -use super::super::token_rate::RequestTokenTracker; -use super::kiro_to_responses_helpers::{ - apply_usage_fallback, - build_response_object, - collect_tool_uses, - detect_event_type, - extract_error, - update_stop_reason, - update_usage, - usage_from_kiro, - usage_json_from_kiro, - FunctionCallOutput, -}; - -pub(super) fn stream_kiro_to_responses( - upstream: impl futures_util::stream::Stream> - + Unpin - + Send - + 'static, - context: LogContext, - log: Arc, - token_tracker: RequestTokenTracker, - estimated_input_tokens: Option, -) -> impl futures_util::stream::Stream> + Send -where - E: std::error::Error + Send + Sync + 'static, -{ - let state = KiroToResponsesState::new( - upstream, - context, - log, - token_tracker, - estimated_input_tokens, - ); - futures_util::stream::try_unfold(state, |state| async move { state.step().await }) -} - -struct MessageOutput { - id: String, - output_index: u64, - text: String, -} - -struct ThinkingStreamState { - in_thinking: bool, - pending: String, -} - -struct KiroToResponsesState { - upstream: S, - decoder: EventStreamDecoder, - log: Arc, - context: LogContext, - token_tracker: RequestTokenTracker, - estimated_input_tokens: Option, - out: VecDeque, - response_id: String, - created_at: i64, - model: String, - next_output_index: u64, - message: Option, - reasoning: String, - thinking_state: ThinkingStreamState, - function_calls: Vec>, - tool_call_by_id: HashMap, - processed_tool_keys: HashSet, - tool_state: Option, - usage: KiroUsage, - stop_reason: Option, - sequence: u64, - sent_done: bool, - logged: bool, - upstream_ended: bool, -} - -impl KiroToResponsesState -where - S: futures_util::stream::Stream> + Unpin + Send + 'static, - E: std::error::Error + Send + Sync + 'static, -{ - fn new( - upstream: S, - context: LogContext, - log: Arc, - token_tracker: RequestTokenTracker, - estimated_input_tokens: Option, - ) -> Self { - let now_ms = 
super::now_ms(); - let created_at = (now_ms / 1000) as i64; - let model = context - .model - .clone() - .unwrap_or_else(|| "unknown".to_string()); - let mut state = Self { - upstream, - decoder: EventStreamDecoder::new(), - log, - context, - token_tracker, - estimated_input_tokens, - out: VecDeque::new(), - response_id: format!("resp_{now_ms}"), - created_at, - model, - next_output_index: 0, - message: None, - reasoning: String::new(), - thinking_state: ThinkingStreamState { - in_thinking: false, - pending: String::new(), - }, - function_calls: Vec::new(), - tool_call_by_id: HashMap::new(), - processed_tool_keys: HashSet::new(), - tool_state: None, - usage: KiroUsage::default(), - stop_reason: None, - sequence: 0, - sent_done: false, - logged: false, - upstream_ended: false, - }; - state.push_response_created(); - state - } - - async fn step(mut self) -> Result, std::io::Error> { - loop { - if let Some(next) = self.out.pop_front() { - return Ok(Some((next, self))); - } - - if self.upstream_ended { - return Ok(None); - } - - match self.upstream.next().await { - Some(Ok(chunk)) => { - self.handle_chunk(&chunk).await?; - } - Some(Err(err)) => { - self.log_usage_once(); - return Err(std::io::Error::new(std::io::ErrorKind::Other, err)); - } - None => { - self.upstream_ended = true; - self.finish_stream().await?; - if self.out.is_empty() { - return Ok(None); - } - } - } - } - } - - async fn handle_chunk(&mut self, chunk: &Bytes) -> Result<(), std::io::Error> { - let messages = self - .decoder - .push(chunk) - .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err.message))?; - for message in messages { - self.handle_message(&message.payload, &message.event_type) - .await; - } - Ok(()) - } - - async fn finish_stream(&mut self) -> Result<(), std::io::Error> { - let messages = self - .decoder - .finish() - .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err.message))?; - for message in messages { - self.handle_message(&message.payload, &message.event_type) - .await; - } - self.flush_thinking_pending().await; - if !self.sent_done { - self.push_done(); - } - self.log_usage_once(); - Ok(()) - } - - async fn handle_message(&mut self, payload: &[u8], event_type: &str) { - if self.sent_done || payload.is_empty() { - return; - } - let Ok(event) = serde_json::from_slice::(payload) else { - return; - }; - let Some(event_obj) = event.as_object() else { - return; - }; - if let Some(error) = extract_error(event_obj) { - if error != "invalidStateEvent" { - self.push_error(error); - } - return; - } - - update_stop_reason(event_obj, &mut self.stop_reason); - update_usage(event_obj, &mut self.usage); - - let event_type = if !event_type.is_empty() { - event_type - } else { - detect_event_type(event_obj) - }; - - match event_type { - "followupPromptEvent" => {} - "assistantResponseEvent" => self.handle_assistant_response(event_obj).await, - "toolUseEvent" => self.handle_tool_use_event(event_obj).await, - "reasoningContentEvent" => self.handle_reasoning_content(event_obj).await, - "messageStopEvent" | "message_stop" => { - update_stop_reason(event_obj, &mut self.stop_reason); - } - _ => {} - } - } - - async fn handle_assistant_response(&mut self, event: &Map) { - if let Some(Value::Object(assistant)) = event.get("assistantResponseEvent") { - if let Some(text) = assistant.get("content").and_then(Value::as_str) { - self.handle_text_delta(text).await; - } - if let Some(items) = assistant.get("toolUses").and_then(Value::as_array) { - self.handle_tool_uses(items); - } - 
update_stop_reason(assistant, &mut self.stop_reason); - } - if let Some(text) = event.get("content").and_then(Value::as_str) { - self.handle_text_delta(text).await; - } - if let Some(items) = event.get("toolUses").and_then(Value::as_array) { - self.handle_tool_uses(items); - } - } - - async fn handle_reasoning_content(&mut self, event: &Map) { - if let Some(Value::Object(reasoning)) = event.get("reasoningContentEvent") { - if let Some(text) = reasoning.get("thinkingText").and_then(Value::as_str) { - self.emit_reasoning_delta(text).await; - } - if let Some(text) = reasoning.get("text").and_then(Value::as_str) { - self.emit_reasoning_delta(text).await; - } - } - } - - async fn handle_text_delta(&mut self, delta: &str) { - if delta.is_empty() { - return; - } - let mut combined = String::new(); - if !self.thinking_state.pending.is_empty() { - combined.push_str(&self.thinking_state.pending); - self.thinking_state.pending.clear(); - } - combined.push_str(delta); - self.process_thinking_delta(&combined).await; - } - - async fn emit_text_delta(&mut self, delta: &str) { - if delta.is_empty() { - return; - } - self.ensure_message_output(); - if let Some(message) = self.message.as_mut() { - message.text.push_str(delta); - } - self.token_tracker.add_output_text(delta).await; - let item_id = self.message.as_ref().map(|m| m.id.clone()).unwrap_or_default(); - let output_index = self.message.as_ref().map(|m| m.output_index).unwrap_or(0); - self.push_event(json!({ - "type": "response.output_text.delta", - "item_id": item_id, - "output_index": output_index, - "content_index": 0, - "delta": delta - })); - } - - async fn emit_reasoning_delta(&mut self, delta: &str) { - if delta.is_empty() { - return; - } - self.ensure_message_output(); - self.reasoning.push_str(delta); - self.token_tracker.add_output_text(delta).await; - let item_id = self.message.as_ref().map(|m| m.id.clone()).unwrap_or_default(); - let output_index = self.message.as_ref().map(|m| m.output_index).unwrap_or(0); - self.push_event(json!({ - "type": "response.reasoning_text.delta", - "item_id": item_id, - "output_index": output_index, - "content_index": 0, - "delta": delta - })); - } - - async fn process_thinking_delta(&mut self, input: &str) { - const START: &str = ""; - const END: &str = ""; - - let mut cursor = 0; - // Parse tags incrementally so reasoning never leaks into output_text. 
- while cursor < input.len() { - if self.thinking_state.in_thinking { - if let Some(pos) = input[cursor..].find(END) { - let end = cursor + pos; - if end > cursor { - self.emit_reasoning_delta(&input[cursor..end]).await; - } - cursor = end + END.len(); - self.thinking_state.in_thinking = false; - continue; - } - let (emit, pending) = split_partial_tag(&input[cursor..], END); - if !emit.is_empty() { - self.emit_reasoning_delta(&emit).await; - } - self.thinking_state.pending = pending; - break; - } - - if let Some(pos) = input[cursor..].find(START) { - let end = cursor + pos; - if end > cursor { - self.emit_text_delta(&input[cursor..end]).await; - } - cursor = end + START.len(); - self.thinking_state.in_thinking = true; - continue; - } - let (emit, pending) = split_partial_tag(&input[cursor..], START); - if !emit.is_empty() { - self.emit_text_delta(&emit).await; - } - self.thinking_state.pending = pending; - break; - } - } - - async fn flush_thinking_pending(&mut self) { - if self.thinking_state.pending.is_empty() { - return; - } - let pending = std::mem::take(&mut self.thinking_state.pending); - if self.thinking_state.in_thinking { - self.emit_reasoning_delta(&pending).await; - } else { - self.emit_text_delta(&pending).await; - } - } - - async fn handle_tool_use_event(&mut self, event: &Map) { - let (completed, next_state) = - process_tool_use_event(event, self.tool_state.take(), &mut self.processed_tool_keys); - self.tool_state = next_state; - for tool_use in completed { - self.ensure_function_call_output(&tool_use); - self.finalize_function_call(&tool_use); - } - } - - fn handle_tool_uses(&mut self, items: &[Value]) { - for item in items { - let Some(tool) = item.as_object() else { - continue; - }; - let tool_use_id = tool - .get("toolUseId") - .or_else(|| tool.get("tool_use_id")) - .and_then(Value::as_str) - .unwrap_or(""); - let dedupe_key = format!("id:{tool_use_id}"); - if tool_use_id.is_empty() || self.processed_tool_keys.contains(&dedupe_key) { - continue; - } - let name = tool.get("name").and_then(Value::as_str).unwrap_or(""); - let input = tool - .get("input") - .and_then(Value::as_object) - .cloned() - .unwrap_or_default(); - self.processed_tool_keys.insert(dedupe_key); - let tool_use = KiroToolUse { - tool_use_id: tool_use_id.to_string(), - name: name.to_string(), - input, - }; - self.ensure_function_call_output(&tool_use); - self.finalize_function_call(&tool_use); - } - } - - fn ensure_message_output(&mut self) { - if self.message.is_some() { - return; - } - let output_index = self.next_output_index; - self.next_output_index += 1; - let message_id = format!("msg_{}", self.response_id); - self.push_event(json!({ - "type": "response.output_item.added", - "output_index": output_index, - "item": { - "id": message_id, - "type": "message", - "status": "in_progress", - "role": "assistant", - "content": [] - } - })); - self.message = Some(MessageOutput { - id: message_id, - output_index, - text: String::new(), - }); - } - - fn ensure_function_call_output(&mut self, tool_use: &KiroToolUse) { - let index = if let Some(index) = self.tool_call_by_id.get(&tool_use.tool_use_id) { - *index - } else { - let index = self.function_calls.len(); - self.tool_call_by_id - .insert(tool_use.tool_use_id.clone(), index); - index - }; - if self.function_calls.len() <= index { - self.function_calls.resize_with(index + 1, || None); - } - - if self.function_calls[index].is_none() { - let output_index = self.next_output_index; - self.next_output_index += 1; - let item_id = format!("fc_{}", 
tool_use.tool_use_id); - let call_id = tool_use.tool_use_id.clone(); - let name = tool_use.name.clone(); - self.push_event(json!({ - "type": "response.output_item.added", - "output_index": output_index, - "item": { - "id": item_id, - "type": "function_call", - "status": "in_progress", - "call_id": call_id, - "name": name, - "arguments": "" - } - })); - self.function_calls[index] = Some(FunctionCallOutput { - id: format!("fc_{}", tool_use.tool_use_id), - output_index, - call_id: tool_use.tool_use_id.clone(), - name: tool_use.name.clone(), - arguments: String::new(), - }); - } - } - - fn finalize_function_call(&mut self, tool_use: &KiroToolUse) { - let Some(index) = self.tool_call_by_id.get(&tool_use.tool_use_id).copied() else { - return; - }; - let Some(state) = self.function_calls.get_mut(index).and_then(Option::as_mut) else { - return; - }; - if state.arguments.is_empty() { - state.arguments = serde_json::to_string(&tool_use.input).unwrap_or_default(); - } - let item_id = state.id.clone(); - let output_index = state.output_index; - let name = state.name.clone(); - let call_id = state.call_id.clone(); - let arguments = state.arguments.clone(); - self.push_event(json!({ - "type": "response.function_call_arguments.done", - "item_id": item_id, - "output_index": output_index, - "name": name, - "arguments": arguments - })); - self.push_event(json!({ - "type": "response.output_item.done", - "output_index": output_index, - "item": { - "id": item_id, - "type": "function_call", - "status": "completed", - "call_id": call_id, - "name": name, - "arguments": arguments - } - })); - } - - fn push_response_created(&mut self) { - self.push_event(json!({ - "type": "response.created", - "response": { - "id": self.response_id, - "object": "response", - "created_at": self.created_at, - "status": "in_progress", - "model": self.model - } - })); - } - - fn push_done(&mut self) { - if self.sent_done { - return; - } - self.sent_done = true; - self.finalize_usage(); - if self.stop_reason.is_none() { - self.stop_reason = Some(if self.function_calls.iter().any(|call| call.is_some()) { - "tool_use".to_string() - } else { - "end_turn".to_string() - }); - } - let tool_uses = collect_tool_uses(&self.function_calls); - let response = build_response_object( - self.message - .as_ref() - .map(|message| message.text.clone()) - .unwrap_or_default(), - self.reasoning.clone(), - tool_uses, - self.usage.clone(), - self.stop_reason.as_deref(), - Some(&self.model), - self.response_id.clone(), - self.created_at, - ); - self.push_event(json!({ - "type": "response.completed", - "response": response - })); - self.out.push_back(Bytes::from("data: [DONE]\n\n")); - } - - fn push_event(&mut self, mut event: Value) { - if let Some(obj) = event.as_object_mut() { - let sequence_number = self.next_sequence(); - obj.insert("sequence_number".to_string(), Value::Number(sequence_number.into())); - } - self.out.push_back(super::responses_event_sse(event)); - } - - fn push_error(&mut self, message: String) { - if self.sent_done { - return; - } - self.sent_done = true; - self.out.push_back(super::responses_event_sse(json!({ - "type": "response.failed", - "error": { "message": message } - }))); - self.out.push_back(Bytes::from("data: [DONE]\n\n")); - } - - fn next_sequence(&mut self) -> u64 { - self.sequence += 1; - self.sequence - } - - fn log_usage_once(&mut self) { - if self.logged { - return; - } - self.logged = true; - self.finalize_usage(); - let usage_snapshot = UsageSnapshot { - usage: usage_from_kiro(&self.usage), - cached_tokens: None, - 
usage_json: usage_json_from_kiro(&self.usage), - }; - let entry = build_log_entry(&self.context, usage_snapshot, None); - self.log.clone().write_detached(entry); - } - - fn finalize_usage(&mut self) { - let content = self - .message - .as_ref() - .map(|message| message.text.as_str()) - .unwrap_or(""); - apply_usage_fallback( - &mut self.usage, - Some(&self.model), - self.estimated_input_tokens, - content, - &self.reasoning, - ); - } -} - -fn split_partial_tag(segment: &str, tag: &str) -> (String, String) { - if tag.len() <= 1 || segment.len() < 1 { - return (segment.to_string(), String::new()); - } - let max_len = std::cmp::min(segment.len(), tag.len() - 1); - for len in (1..=max_len).rev() { - if segment.ends_with(&tag[..len]) { - let emit_end = segment.len() - len; - return (segment[..emit_end].to_string(), segment[emit_end..].to_string()); - } - } - (segment.to_string(), String::new()) -} diff --git a/src-tauri/src/proxy/response/responses_to_anthropic.rs b/src-tauri/src/proxy/response/responses_to_anthropic.rs deleted file mode 100644 index 663ef11..0000000 --- a/src-tauri/src/proxy/response/responses_to_anthropic.rs +++ /dev/null @@ -1,692 +0,0 @@ -use axum::body::Bytes; -use futures_util::{stream::try_unfold, StreamExt}; -use serde_json::{json, Map, Value}; -use std::{ - collections::{HashMap, VecDeque}, - sync::Arc, -}; - -use super::super::compat_reason; -use super::super::log::{build_log_entry, LogContext, LogWriter}; -use super::super::sse::SseEventParser; -use super::super::token_rate::RequestTokenTracker; -use super::super::usage::SseUsageCollector; - -pub(super) fn stream_responses_to_anthropic( - upstream: impl futures_util::stream::Stream> - + Unpin - + Send - + 'static, - context: LogContext, - log: Arc, - token_tracker: RequestTokenTracker, -) -> impl futures_util::stream::Stream> + Send -where - E: std::error::Error + Send + Sync + 'static, -{ - let state = ResponsesToAnthropicState::new(upstream, context, log, token_tracker); - try_unfold(state, |state| async move { state.step().await }) -} - -enum ActiveBlock { - Text { index: usize }, - Thinking { index: usize }, - ToolUse { item_id: String }, -} - -struct ToolUseState { - index: usize, - tool_use_id: String, - name: String, - sent_start: bool, - sent_stop: bool, - sent_input: bool, -} - -struct ResponsesToAnthropicState { - upstream: S, - parser: SseEventParser, - collector: SseUsageCollector, - log: Arc, - context: LogContext, - token_tracker: RequestTokenTracker, - out: VecDeque, - message_id: String, - model: String, - sent_message_start: bool, - sent_message_stop: bool, - logged: bool, - upstream_ended: bool, - active_block: Option, - next_block_index: usize, - tool_uses: HashMap, - saw_tool_use: bool, - stop_reason_override: Option<&'static str>, - saw_reasoning_delta: bool, -} - -impl ResponsesToAnthropicState -where - S: futures_util::stream::Stream> + Unpin + Send + 'static, - E: std::error::Error + Send + Sync + 'static, -{ - fn new( - upstream: S, - context: LogContext, - log: Arc, - token_tracker: RequestTokenTracker, - ) -> Self { - let now_ms = super::now_ms(); - let model = context - .model - .clone() - .unwrap_or_else(|| "unknown".to_string()); - Self { - upstream, - parser: SseEventParser::new(), - collector: SseUsageCollector::new(), - log, - context, - token_tracker, - out: VecDeque::new(), - message_id: format!("msg_proxy_{now_ms}"), - model, - sent_message_start: false, - sent_message_stop: false, - logged: false, - upstream_ended: false, - active_block: None, - next_block_index: 0, - tool_uses: 
HashMap::new(), - saw_tool_use: false, - stop_reason_override: None, - saw_reasoning_delta: false, - } - } - - async fn step(mut self) -> Result, std::io::Error> { - loop { - if let Some(next) = self.out.pop_front() { - return Ok(Some((next, self))); - } - - if self.upstream_ended { - self.log_usage_once(); - return Ok(None); - } - - match self.upstream.next().await { - Some(Ok(chunk)) => { - self.collector.push_chunk(&chunk); - let mut events = Vec::new(); - self.parser.push_chunk(&chunk, |data| events.push(data)); - let mut texts = Vec::new(); - for data in events { - self.handle_event(&data, &mut texts); - } - for text in texts { - self.token_tracker.add_output_text(&text).await; - } - } - Some(Err(err)) => { - self.log_usage_once(); - return Err(std::io::Error::new(std::io::ErrorKind::Other, err)); - } - None => { - self.upstream_ended = true; - let mut events = Vec::new(); - self.parser.finish(|data| events.push(data)); - let mut texts = Vec::new(); - for data in events { - self.handle_event(&data, &mut texts); - } - for text in texts { - self.token_tracker.add_output_text(&text).await; - } - self.finish_message_if_needed(); - if self.out.is_empty() { - self.log_usage_once(); - return Ok(None); - } - } - } - } - } - - fn handle_event(&mut self, data: &str, token_texts: &mut Vec) { - if self.sent_message_stop { - return; - } - if data == "[DONE]" { - self.finish_message_if_needed(); - return; - } - let Ok(value) = serde_json::from_str::(data) else { - return; - }; - let Some(event_type) = value.get("type").and_then(Value::as_str) else { - return; - }; - - if event_type.ends_with("output_text.delta") { - self.handle_output_text_delta(&value, token_texts); - return; - } - if event_type.ends_with("reasoning_text.delta") { - self.handle_reasoning_text_delta(&value, token_texts); - return; - } - if event_type.ends_with("output_item.added") { - self.handle_output_item_added(&value); - return; - } - if event_type.ends_with("function_call_arguments.delta") { - self.handle_function_call_arguments_delta(&value); - return; - } - if event_type.ends_with("function_call_arguments.done") { - self.handle_function_call_arguments_done(&value); - return; - } - if event_type.ends_with("output_item.done") { - self.handle_output_item_done(&value); - return; - } - if event_type.ends_with("response.completed") { - self.handle_response_completed(&value); - return; - } - if event_type.ends_with("response.incomplete") { - self.handle_response_incomplete(&value); - return; - } - } - - fn handle_output_text_delta(&mut self, value: &Value, token_texts: &mut Vec) { - let Some(delta) = value.get("delta").and_then(Value::as_str) else { - return; - }; - token_texts.push(delta.to_string()); - self.ensure_message_start(); - let index = self.ensure_text_block(); - self.out.push_back(super::anthropic_event_sse( - "content_block_delta", - json!({ - "type": "content_block_delta", - "index": index, - "delta": { "type": "text_delta", "text": delta } - }), - )); - } - - fn handle_reasoning_text_delta(&mut self, value: &Value, token_texts: &mut Vec) { - let Some(delta) = value.get("delta").and_then(Value::as_str) else { - return; - }; - self.saw_reasoning_delta = true; - token_texts.push(delta.to_string()); - self.ensure_message_start(); - let index = self.ensure_thinking_block(); - self.out.push_back(super::anthropic_event_sse( - "content_block_delta", - json!({ - "type": "content_block_delta", - "index": index, - "delta": { "type": "thinking_delta", "thinking": delta } - }), - )); - } - - fn handle_output_item_added(&mut self, 
value: &Value) { - let Some(item) = value.get("item").and_then(Value::as_object) else { - return; - }; - if item.get("type").and_then(Value::as_str) != Some("function_call") { - return; - } - let item_id = item.get("id").and_then(Value::as_str).unwrap_or(""); - let call_id = item.get("call_id").and_then(Value::as_str).unwrap_or(""); - let name = item.get("name").and_then(Value::as_str).unwrap_or(""); - - let tool_use_id = if !call_id.is_empty() { - call_id.to_string() - } else if !item_id.is_empty() { - item_id.to_string() - } else { - "tool_use_proxy".to_string() - }; - - self.ensure_message_start(); - self.ensure_tool_use_block(item_id, &tool_use_id, name); - } - - fn handle_function_call_arguments_delta(&mut self, value: &Value) { - let Some(item_id) = value.get("item_id").and_then(Value::as_str) else { - return; - }; - let Some(delta) = value.get("delta").and_then(Value::as_str) else { - return; - }; - self.ensure_message_start(); - self.ensure_tool_use_state(item_id); - if !self.tool_uses.get(item_id).is_some_and(|state| state.sent_start) { - self.start_tool_use_block(item_id); - } - self.set_active_tool_use(item_id); - let Some(index) = self.tool_uses.get(item_id).map(|state| state.index) else { - return; - }; - self.out.push_back(super::anthropic_event_sse( - "content_block_delta", - json!({ - "type": "content_block_delta", - "index": index, - "delta": { "type": "input_json_delta", "partial_json": delta } - }), - )); - } - - fn handle_function_call_arguments_done(&mut self, value: &Value) { - let Some(item_id) = value.get("item_id").and_then(Value::as_str) else { - return; - }; - let arguments = value.get("arguments").and_then(Value::as_str).unwrap_or(""); - self.ensure_message_start(); - self.ensure_tool_use_state(item_id); - self.emit_tool_use_arguments(item_id, arguments); - self.stop_tool_use_block(item_id); - } - - fn handle_output_item_done(&mut self, value: &Value) { - let Some(item) = value.get("item").and_then(Value::as_object) else { - return; - }; - if item.get("type").and_then(Value::as_str) != Some("function_call") { - return; - } - let Some(item_id) = item.get("id").and_then(Value::as_str) else { - return; - }; - self.ensure_message_start(); - self.ensure_tool_use_state(item_id); - self.stop_tool_use_block(item_id); - } - - fn handle_response_completed(&mut self, value: &Value) { - let Some(response) = value.get("response").and_then(Value::as_object) else { - return; - }; - self.handle_response_output_items(response); - self.stop_reason_override = Some(compat_reason::anthropic_stop_reason_from_chat_finish_reason( - compat_reason::chat_finish_reason_from_response_object(response, self.saw_tool_use), - )); - } - - fn handle_response_incomplete(&mut self, value: &Value) { - let Some(response) = value.get("response").and_then(Value::as_object) else { - return; - }; - self.handle_response_output_items(response); - self.stop_reason_override = Some(compat_reason::anthropic_stop_reason_from_chat_finish_reason( - compat_reason::chat_finish_reason_from_response_object(response, self.saw_tool_use), - )); - } - - fn handle_response_output_items(&mut self, response: &Map) { - let Some(output) = response.get("output").and_then(Value::as_array) else { - return; - }; - let mut reasoning_snapshot = String::new(); - for item in output { - let Some(item) = item.as_object() else { - continue; - }; - match item.get("type").and_then(Value::as_str) { - Some("function_call") => { - if let Some(item_id) = item.get("id").and_then(Value::as_str) { - let call_id = 
item.get("call_id").and_then(Value::as_str).unwrap_or(""); - let name = item.get("name").and_then(Value::as_str).unwrap_or(""); - let arguments = item.get("arguments").and_then(Value::as_str).unwrap_or(""); - let tool_use_id = if !call_id.is_empty() { - call_id.to_string() - } else { - item_id.to_string() - }; - self.ensure_tool_use_block(item_id, &tool_use_id, name); - self.emit_tool_use_arguments(item_id, arguments); - self.stop_tool_use_block(item_id); - } - } - Some("message") => { - if item.get("role").and_then(Value::as_str) != Some("assistant") { - continue; - } - let Some(content) = item.get("content").and_then(Value::as_array) else { - continue; - }; - if reasoning_snapshot.is_empty() { - reasoning_snapshot = extract_reasoning_text(content); - } - } - _ => {} - } - } - self.emit_reasoning_snapshot(&reasoning_snapshot); - } - - fn ensure_message_start(&mut self) { - if self.sent_message_start { - return; - } - self.sent_message_start = true; - - // Usage is best-effort: OpenAI responses stream may not expose input tokens early. - let message = json!({ - "id": self.message_id.as_str(), - "type": "message", - "role": "assistant", - "model": self.model.as_str(), - "content": [], - "stop_reason": null, - "stop_sequence": null, - "usage": { "input_tokens": 0, "output_tokens": 0 } - }); - self.out.push_back(super::anthropic_event_sse( - "message_start", - json!({ "type": "message_start", "message": message }), - )); - } - - fn ensure_text_block(&mut self) -> usize { - if let Some(ActiveBlock::Text { index }) = self.active_block { - return index; - } - - self.stop_active_block(); - let index = self.next_block_index; - self.next_block_index += 1; - self.active_block = Some(ActiveBlock::Text { index }); - self.out.push_back(super::anthropic_event_sse( - "content_block_start", - json!({ - "type": "content_block_start", - "index": index, - "content_block": { "type": "text", "text": "" } - }), - )); - index - } - - fn ensure_thinking_block(&mut self) -> usize { - if let Some(ActiveBlock::Thinking { index }) = self.active_block { - return index; - } - - self.stop_active_block(); - let index = self.next_block_index; - self.next_block_index += 1; - self.active_block = Some(ActiveBlock::Thinking { index }); - self.out.push_back(super::anthropic_event_sse( - "content_block_start", - json!({ - "type": "content_block_start", - "index": index, - "content_block": { "type": "thinking", "thinking": "" } - }), - )); - index - } - - fn emit_reasoning_snapshot(&mut self, text: &str) { - if self.saw_reasoning_delta || text.trim().is_empty() { - return; - } - self.saw_reasoning_delta = true; - self.ensure_message_start(); - let index = self.ensure_thinking_block(); - self.out.push_back(super::anthropic_event_sse( - "content_block_delta", - json!({ - "type": "content_block_delta", - "index": index, - "delta": { "type": "thinking_delta", "thinking": text } - }), - )); - self.stop_active_block(); - } - - fn ensure_tool_use_block(&mut self, item_id: &str, tool_use_id: &str, name: &str) { - if !self.tool_uses.contains_key(item_id) { - let index = self.next_block_index; - self.next_block_index += 1; - self.tool_uses.insert(item_id.to_string(), ToolUseState { - index, - tool_use_id: tool_use_id.to_string(), - name: name.to_string(), - sent_start: false, - sent_stop: false, - sent_input: false, - }); - } - - if let Some(state) = self.tool_uses.get_mut(item_id) { - if state.tool_use_id.is_empty() { - state.tool_use_id = tool_use_id.to_string(); - } - if state.name.is_empty() { - state.name = name.to_string(); - } - } 
- - if !self.tool_uses.get(item_id).is_some_and(|state| state.sent_start) { - self.start_tool_use_block(item_id); - } - } - - fn ensure_tool_use_state(&mut self, item_id: &str) -> &mut ToolUseState { - self.tool_uses.entry(item_id.to_string()).or_insert_with(|| { - let index = self.next_block_index; - self.next_block_index += 1; - ToolUseState { - index, - tool_use_id: item_id.to_string(), - name: String::new(), - sent_start: false, - sent_stop: false, - sent_input: false, - } - }) - } - - fn start_tool_use_block(&mut self, item_id: &str) { - let Some((index, tool_use_id, name, sent_start)) = self.tool_uses.get(item_id).map(|state| { - ( - state.index, - state.tool_use_id.clone(), - state.name.clone(), - state.sent_start, - ) - }) else { - return; - }; - if sent_start { - return; - } - - self.stop_active_block(); - if let Some(state) = self.tool_uses.get_mut(item_id) { - state.sent_start = true; - } - self.saw_tool_use = true; - self.active_block = Some(ActiveBlock::ToolUse { - item_id: item_id.to_string(), - }); - self.out.push_back(super::anthropic_event_sse( - "content_block_start", - json!({ - "type": "content_block_start", - "index": index, - "content_block": { - "type": "tool_use", - "id": tool_use_id, - "name": name, - "input": {} - } - }), - )); - } - - fn emit_tool_use_arguments(&mut self, item_id: &str, arguments: &str) { - if arguments.trim().is_empty() { - return; - } - let state = self.ensure_tool_use_state(item_id); - if state.sent_input { - return; - } - if !state.sent_start { - self.start_tool_use_block(item_id); - } - self.set_active_tool_use(item_id); - let Some(index) = self.tool_uses.get(item_id).map(|state| state.index) else { - return; - }; - self.out.push_back(super::anthropic_event_sse( - "content_block_delta", - json!({ - "type": "content_block_delta", - "index": index, - "delta": { "type": "input_json_delta", "partial_json": arguments } - }), - )); - if let Some(state) = self.tool_uses.get_mut(item_id) { - state.sent_input = true; - } - } - - fn set_active_tool_use(&mut self, item_id: &str) { - if !self.tool_uses.contains_key(item_id) { - return; - }; - match &self.active_block { - Some(ActiveBlock::ToolUse { item_id: active }) if active == item_id => {} - _ => { - self.stop_active_block(); - self.active_block = Some(ActiveBlock::ToolUse { - item_id: item_id.to_string(), - }); - } - } - } - - fn stop_tool_use_block(&mut self, item_id: &str) { - let Some(state) = self.tool_uses.get_mut(item_id) else { - return; - }; - if state.sent_stop { - return; - } - state.sent_stop = true; - if matches!( - &self.active_block, - Some(ActiveBlock::ToolUse { item_id: active }) if active == item_id - ) { - self.active_block = None; - } - self.out.push_back(super::anthropic_event_sse( - "content_block_stop", - json!({ "type": "content_block_stop", "index": state.index }), - )); - } - - fn stop_active_block(&mut self) { - let Some(active) = self.active_block.take() else { - return; - }; - match active { - ActiveBlock::Text { index } => { - self.out.push_back(super::anthropic_event_sse( - "content_block_stop", - json!({ "type": "content_block_stop", "index": index }), - )); - } - ActiveBlock::Thinking { index } => { - self.out.push_back(super::anthropic_event_sse( - "content_block_stop", - json!({ "type": "content_block_stop", "index": index }), - )); - } - ActiveBlock::ToolUse { item_id } => { - self.stop_tool_use_block(&item_id); - } - } - } - - fn finish_message_if_needed(&mut self) { - if self.sent_message_stop { - return; - } - self.ensure_message_start(); - 
self.stop_active_block(); - - let stop_reason = self.stop_reason_override.unwrap_or_else(|| { - if self.saw_tool_use { - "tool_use" - } else { - "end_turn" - } - }); - let usage = self.collector.finish(); - let (input_tokens, output_tokens) = usage - .usage - .as_ref() - .map(|u| (u.input_tokens.unwrap_or(0), u.output_tokens.unwrap_or(0))) - .unwrap_or((0, 0)); - let mut usage_obj = Map::new(); - usage_obj.insert("input_tokens".to_string(), json!(input_tokens)); - usage_obj.insert("output_tokens".to_string(), json!(output_tokens)); - if let Some(cached) = usage.cached_tokens { - // Best-effort mapping: treat cached tokens as "cache_read_input_tokens". - usage_obj.insert("cache_read_input_tokens".to_string(), json!(cached)); - } - - self.out.push_back(super::anthropic_event_sse( - "message_delta", - json!({ - "type": "message_delta", - "delta": { "stop_reason": stop_reason, "stop_sequence": null }, - "usage": Value::Object(usage_obj) - }), - )); - self.out.push_back(super::anthropic_event_sse( - "message_stop", - json!({ "type": "message_stop" }), - )); - self.sent_message_stop = true; - } - - fn log_usage_once(&mut self) { - if self.logged { - return; - } - let entry = build_log_entry(&self.context, self.collector.finish(), None); - self.log.clone().write_detached(entry); - self.logged = true; - } -} - -fn extract_reasoning_text(parts: &[Value]) -> String { - let mut reasoning = String::new(); - for part in parts { - let Some(part) = part.as_object() else { - continue; - }; - if part.get("type").and_then(Value::as_str) != Some("reasoning_text") { - continue; - } - if let Some(text) = part.get("text").and_then(Value::as_str) { - reasoning.push_str(text); - } - } - reasoning -} diff --git a/src-tauri/src/proxy/response/responses_to_chat.rs b/src-tauri/src/proxy/response/responses_to_chat.rs deleted file mode 100644 index 52ca7ad..0000000 --- a/src-tauri/src/proxy/response/responses_to_chat.rs +++ /dev/null @@ -1,600 +0,0 @@ -use axum::body::Bytes; -use futures_util::{stream::try_unfold, StreamExt}; -use serde_json::{json, Map, Value}; -use std::{ - collections::{HashMap, VecDeque}, - sync::Arc, -}; - -use super::super::compat_content; -use super::super::compat_reason; -use super::super::log::{build_log_entry, LogContext, LogWriter}; -use super::super::sse::SseEventParser; -use super::super::token_rate::RequestTokenTracker; -use super::super::usage::SseUsageCollector; - -pub(super) fn stream_responses_to_chat( - upstream: impl futures_util::stream::Stream> - + Unpin - + Send - + 'static, - context: LogContext, - log: Arc, - token_tracker: RequestTokenTracker, -) -> impl futures_util::stream::Stream> + Send -where - E: std::error::Error + Send + Sync + 'static, -{ - let state = ResponsesToChatState::new(upstream, context, log, token_tracker); - try_unfold(state, |state| async move { state.step().await }) -} - -struct ResponsesToChatState { - upstream: S, - parser: SseEventParser, - collector: SseUsageCollector, - log: Arc, - context: LogContext, - token_tracker: RequestTokenTracker, - out: VecDeque, - chat_id: String, - created: i64, - model: String, - sent_role: bool, - sent_done: bool, - logged: bool, - upstream_ended: bool, - tool_calls: Vec, - tool_calls_by_item_id: HashMap, - // 非文本输出只透传一次,避免重复注入。 - content_parts_sent: bool, - finish_reason_override: Option<&'static str>, - saw_text_delta: bool, - saw_reasoning_delta: bool, -} - -struct ToolCallState { - index: usize, - call_id: String, - name: String, - arguments: String, - sent_initial: bool, - sent_arguments: bool, -} - -impl 
ResponsesToChatState -where - S: futures_util::stream::Stream> + Unpin + Send + 'static, - E: std::error::Error + Send + Sync + 'static, -{ - fn new( - upstream: S, - context: LogContext, - log: Arc, - token_tracker: RequestTokenTracker, - ) -> Self { - let now_ms = super::now_ms(); - Self { - upstream, - parser: SseEventParser::new(), - collector: SseUsageCollector::new(), - log, - token_tracker, - model: context - .model - .clone() - .unwrap_or_else(|| "unknown".to_string()), - context, - out: VecDeque::new(), - chat_id: format!("chatcmpl_proxy_{now_ms}"), - created: (now_ms / 1000) as i64, - sent_role: false, - sent_done: false, - logged: false, - upstream_ended: false, - tool_calls: Vec::new(), - tool_calls_by_item_id: HashMap::new(), - content_parts_sent: false, - finish_reason_override: None, - saw_text_delta: false, - saw_reasoning_delta: false, - } - } - - async fn step(mut self) -> Result, std::io::Error> { - loop { - if let Some(next) = self.out.pop_front() { - return Ok(Some((next, self))); - } - - if self.upstream_ended { - return Ok(None); - } - - match self.upstream.next().await { - Some(Ok(chunk)) => { - if self.context.ttfb_ms.is_none() { - self.context.ttfb_ms = Some(self.context.start.elapsed().as_millis()); - } - self.collector.push_chunk(&chunk); - let mut events = Vec::new(); - self.parser.push_chunk(&chunk, |data| events.push(data)); - let mut texts = Vec::new(); - for data in events { - self.handle_event(&data, &mut texts); - } - for text in texts { - self.token_tracker.add_output_text(&text).await; - } - } - Some(Err(err)) => { - self.log_usage_once(); - return Err(std::io::Error::new(std::io::ErrorKind::Other, err)); - } - None => { - self.upstream_ended = true; - let mut events = Vec::new(); - self.parser.finish(|data| events.push(data)); - let mut texts = Vec::new(); - for data in events { - self.handle_event(&data, &mut texts); - } - for text in texts { - self.token_tracker.add_output_text(&text).await; - } - if !self.sent_done { - self.push_done(); - } - self.log_usage_once(); - if self.out.is_empty() { - return Ok(None); - } - } - } - } - } - - fn handle_event(&mut self, data: &str, token_texts: &mut Vec) { - if self.sent_done { - return; - } - if data == "[DONE]" { - self.push_done(); - return; - } - let Ok(value) = serde_json::from_str::(data) else { - return; - }; - let Some(event_type) = value.get("type").and_then(Value::as_str) else { - return; - }; - if event_type.ends_with("output_text.delta") { - self.handle_output_text_delta(&value, token_texts); - return; - } - if event_type.ends_with("reasoning_text.delta") { - self.handle_reasoning_text_delta(&value, token_texts); - return; - } - if event_type.ends_with("function_call_arguments.delta") { - self.handle_function_call_arguments_delta(&value); - return; - } - if event_type.ends_with("function_call_arguments.done") { - self.handle_function_call_arguments_done(&value); - return; - } - if event_type.ends_with("output_item.added") { - self.handle_output_item_added(&value); - return; - } - if event_type.ends_with("output_item.done") { - self.handle_output_item_done(&value); - return; - } - if event_type.ends_with("response.completed") { - self.handle_response_completed(&value); - return; - } - if event_type.ends_with("response.incomplete") { - self.handle_response_incomplete(&value); - } - } - - fn handle_output_text_delta(&mut self, value: &Value, token_texts: &mut Vec) { - let Some(delta) = value.get("delta").and_then(Value::as_str) else { - return; - }; - self.saw_text_delta = true; - 
token_texts.push(delta.to_string()); - self.ensure_role_sent(); - self.out.push_back(chat_chunk_sse( - &self.chat_id, - self.created, - &self.model, - json!({ "content": delta }), - None, - )); - } - - fn handle_reasoning_text_delta(&mut self, value: &Value, token_texts: &mut Vec) { - let Some(delta) = value.get("delta").and_then(Value::as_str) else { - return; - }; - self.saw_reasoning_delta = true; - token_texts.push(delta.to_string()); - self.ensure_role_sent(); - self.out.push_back(chat_chunk_sse( - &self.chat_id, - self.created, - &self.model, - json!({ "reasoning_content": delta }), - None, - )); - } - - fn handle_output_item_added(&mut self, value: &Value) { - let Some(item) = value.get("item").and_then(Value::as_object) else { - return; - }; - let Some(item_type) = item.get("type").and_then(Value::as_str) else { - return; - }; - if item_type == "function_call" { - self.handle_function_call_item_added(item); - } - } - - fn handle_function_call_item_added(&mut self, item: &Map) { - let Some(item_id) = item.get("id").and_then(Value::as_str) else { - return; - }; - let call_id = item.get("call_id").and_then(Value::as_str); - let name = item.get("name").and_then(Value::as_str); - - let (index, call_id, name, should_emit) = { - let state = self.ensure_tool_call_state(item_id, call_id, name); - let should_emit = !state.sent_initial; - state.sent_initial = true; - ( - state.index, - state.call_id.clone(), - state.name.clone(), - should_emit, - ) - }; - if should_emit { - let id = tool_call_id(&call_id, item_id); - self.push_tool_call_delta(index, &id, &name, ""); - } - } - - fn handle_function_call_arguments_delta(&mut self, value: &Value) { - let Some(item_id) = value.get("item_id").and_then(Value::as_str) else { - return; - }; - let Some(delta) = value.get("delta").and_then(Value::as_str) else { - return; - }; - let (index, call_id, name) = { - let state = self.ensure_tool_call_state(item_id, None, None); - state.arguments.push_str(delta); - state.sent_initial = true; - state.sent_arguments = true; - (state.index, state.call_id.clone(), state.name.clone()) - }; - let id = tool_call_id(&call_id, item_id); - self.push_tool_call_delta(index, &id, &name, delta); - } - - fn handle_function_call_arguments_done(&mut self, value: &Value) { - let Some(item_id) = value.get("item_id").and_then(Value::as_str) else { - return; - }; - let arguments = value.get("arguments").and_then(Value::as_str).unwrap_or(""); - let name = value.get("name").and_then(Value::as_str); - - let (index, call_id, name, should_emit) = { - let state = self.ensure_tool_call_state(item_id, None, name); - if !arguments.is_empty() { - state.arguments = arguments.to_string(); - } - let should_emit = !arguments.is_empty() && !state.sent_arguments; - state.sent_initial = true; - if should_emit { - state.sent_arguments = true; - } - (state.index, state.call_id.clone(), state.name.clone(), should_emit) - }; - if should_emit { - let id = tool_call_id(&call_id, item_id); - self.push_tool_call_delta(index, &id, &name, arguments); - } - } - - fn handle_output_item_done(&mut self, value: &Value) { - let Some(item) = value.get("item").and_then(Value::as_object) else { - return; - }; - let Some(item_type) = item.get("type").and_then(Value::as_str) else { - return; - }; - match item_type { - "function_call" => self.handle_function_call_item_snapshot(item), - "message" => self.handle_message_item_snapshot(item), - _ => {} - } - } - - fn handle_response_completed(&mut self, value: &Value) { - let Some(response) = 
value.get("response").and_then(Value::as_object) else { - return; - }; - self.handle_response_output_items(response); - self.finish_reason_override = Some(compat_reason::chat_finish_reason_from_response_object( - response, - !self.tool_calls.is_empty(), - )); - } - - fn handle_response_incomplete(&mut self, value: &Value) { - let Some(response) = value.get("response").and_then(Value::as_object) else { - return; - }; - self.handle_response_output_items(response); - self.finish_reason_override = Some(compat_reason::chat_finish_reason_from_response_object( - response, - !self.tool_calls.is_empty(), - )); - } - - fn handle_response_output_items(&mut self, response: &Map) { - let Some(output) = response.get("output").and_then(Value::as_array) else { - return; - }; - for item in output { - let Some(item) = item.as_object() else { - continue; - }; - match item.get("type").and_then(Value::as_str) { - Some("function_call") => self.handle_function_call_item_snapshot(item), - Some("message") => self.handle_message_item_snapshot(item), - _ => {} - } - } - } - - fn handle_function_call_item_snapshot(&mut self, item: &Map) { - let Some(item_id) = item.get("id").and_then(Value::as_str) else { - return; - }; - let call_id = item.get("call_id").and_then(Value::as_str); - let name = item.get("name").and_then(Value::as_str); - let arguments = item.get("arguments").and_then(Value::as_str).unwrap_or(""); - - let (index, call_id, name, should_emit) = { - let state = self.ensure_tool_call_state(item_id, call_id, name); - if !arguments.is_empty() { - state.arguments = arguments.to_string(); - } - let should_emit = !arguments.is_empty() && !state.sent_arguments; - state.sent_initial = true; - if should_emit { - state.sent_arguments = true; - } - (state.index, state.call_id.clone(), state.name.clone(), should_emit) - }; - if should_emit { - let id = tool_call_id(&call_id, item_id); - self.push_tool_call_delta(index, &id, &name, arguments); - } - } - - fn handle_message_item_snapshot(&mut self, item: &Map) { - if item.get("role").and_then(Value::as_str) != Some("assistant") { - return; - } - let Some(content) = item.get("content").and_then(Value::as_array) else { - return; - }; - self.maybe_emit_content_parts(content); - } - - fn ensure_tool_call_state( - &mut self, - item_id: &str, - call_id: Option<&str>, - name: Option<&str>, - ) -> &mut ToolCallState { - let index = if let Some(index) = self.tool_calls_by_item_id.get(item_id) { - *index - } else { - let index = self.tool_calls.len(); - self.tool_calls_by_item_id - .insert(item_id.to_string(), index); - self.tool_calls.push(ToolCallState { - index, - call_id: String::new(), - name: String::new(), - arguments: String::new(), - sent_initial: false, - sent_arguments: false, - }); - index - }; - - let state = self.tool_calls.get_mut(index).expect("tool call state"); - if let Some(call_id) = call_id { - if state.call_id.is_empty() { - state.call_id = call_id.to_string(); - } - } - if let Some(name) = name { - if state.name.is_empty() { - state.name = name.to_string(); - } - } - state - } - - fn maybe_emit_content_parts(&mut self, parts: &[Value]) { - if !self.saw_reasoning_delta { - let reasoning_text = extract_reasoning_text(parts); - if !reasoning_text.trim().is_empty() { - self.saw_reasoning_delta = true; - self.ensure_role_sent(); - self.out.push_back(chat_chunk_sse( - &self.chat_id, - self.created, - &self.model, - json!({ "reasoning_content": reasoning_text }), - None, - )); - } - } - - if self.content_parts_sent { - return; - } - let non_text_parts = 
compat_content::chat_message_non_text_parts_from_responses(parts); - let delta_content = if !non_text_parts.is_empty() { - Value::Array(non_text_parts) - } else if !self.saw_text_delta { - compat_content::chat_message_content_from_responses_parts(parts) - } else { - return; - }; - self.ensure_role_sent(); - self.out.push_back(chat_chunk_sse( - &self.chat_id, - self.created, - &self.model, - json!({ "content": delta_content }), - None, - )); - self.content_parts_sent = true; - } - - fn ensure_role_sent(&mut self) { - if self.sent_role { - return; - } - self.sent_role = true; - self.out.push_back(chat_chunk_sse( - &self.chat_id, - self.created, - &self.model, - json!({ "role": "assistant", "content": "" }), - None, - )); - } - - fn push_tool_call_delta(&mut self, index: usize, id: &str, name: &str, arguments: &str) { - self.ensure_role_sent(); - let mut function = Map::new(); - if !name.is_empty() { - function.insert("name".to_string(), Value::String(name.to_string())); - } - function.insert( - "arguments".to_string(), - Value::String(arguments.to_string()), - ); - let tool_call = json!({ - "index": index, - "id": id, - "type": "function", - "function": Value::Object(function) - }); - self.out.push_back(chat_chunk_sse( - &self.chat_id, - self.created, - &self.model, - json!({ "tool_calls": [tool_call] }), - None, - )); - } - - fn push_done(&mut self) { - if self.sent_done { - return; - } - self.sent_done = true; - self.out.push_back(chat_chunk_sse( - &self.chat_id, - self.created, - &self.model, - json!({}), - Some(self.finish_reason()), - )); - self.out.push_back(Bytes::from("data: [DONE]\n\n")); - } - - fn log_usage_once(&mut self) { - if self.logged { - return; - } - self.logged = true; - let entry = build_log_entry(&self.context, self.collector.finish(), None); - self.log.clone().write_detached(entry); - } - - fn finish_reason(&self) -> &'static str { - if let Some(reason) = self.finish_reason_override { - return reason; - } - if self.tool_calls.is_empty() { - "stop" - } else { - "tool_calls" - } - } -} - -fn chat_chunk_sse( - id: &str, - created: i64, - model: &str, - delta: Value, - finish_reason: Option<&str>, -) -> Bytes { - let chunk = json!({ - "id": id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [ - { - "index": 0, - "delta": delta, - "finish_reason": finish_reason - } - ] - }); - Bytes::from(format!("data: {}\n\n", chunk.to_string())) -} - -fn extract_reasoning_text(parts: &[Value]) -> String { - let mut reasoning = String::new(); - for part in parts { - let Some(part) = part.as_object() else { - continue; - }; - if part.get("type").and_then(Value::as_str) != Some("reasoning_text") { - continue; - } - if let Some(text) = part.get("text").and_then(Value::as_str) { - reasoning.push_str(text); - } - } - reasoning -} - -fn tool_call_id(call_id: &str, item_id: &str) -> String { - if !call_id.is_empty() { - call_id.to_string() - } else if !item_id.is_empty() { - item_id.to_string() - } else { - "call_proxy".to_string() - } -} diff --git a/src-tauri/src/proxy/response/streaming.rs b/src-tauri/src/proxy/response/streaming.rs deleted file mode 100644 index b214f7a..0000000 --- a/src-tauri/src/proxy/response/streaming.rs +++ /dev/null @@ -1,291 +0,0 @@ -use axum::body::Bytes; -use futures_util::{stream::try_unfold, StreamExt}; -use serde_json::Value; -use std::{collections::VecDeque, sync::Arc}; - -use super::{ - PROVIDER_ANTHROPIC, PROVIDER_ANTIGRAVITY, PROVIDER_CODEX, PROVIDER_GEMINI, PROVIDER_OPENAI, - PROVIDER_OPENAI_RESPONSES, -}; -use 
super::super::log::{build_log_entry, LogContext, LogWriter}; -use super::super::model; -use super::super::sse::SseEventParser; -use super::super::token_rate::RequestTokenTracker; -use super::super::usage::SseUsageCollector; - -pub(super) fn stream_with_logging( - upstream: impl futures_util::stream::Stream> + Unpin + Send + 'static, - context: LogContext, - log: Arc, - token_tracker: RequestTokenTracker, -) -> impl futures_util::stream::Stream> + Send -where - E: std::error::Error + Send + Sync + 'static, -{ - let collector = SseUsageCollector::new(); - let parser = SseEventParser::new(); - try_unfold( - (upstream, collector, parser, log, context, token_tracker), - |(mut upstream, mut collector, mut parser, log, mut context, token_tracker)| async move { - match upstream.next().await { - Some(Ok(chunk)) => { - if context.ttfb_ms.is_none() { - context.ttfb_ms = Some(context.start.elapsed().as_millis()); - } - collector.push_chunk(&chunk); - let provider = context.provider.as_str(); - let mut texts = Vec::new(); - parser.push_chunk(&chunk, |data| { - if let Some(text) = extract_stream_text(provider, &data) { - texts.push(text); - } - }); - for text in texts { - token_tracker.add_output_text(&text).await; - } - Ok(Some((chunk, (upstream, collector, parser, log, context, token_tracker)))) - } - Some(Err(err)) => { - let provider = context.provider.as_str(); - let mut texts = Vec::new(); - parser.finish(|data| { - if let Some(text) = extract_stream_text(provider, &data) { - texts.push(text); - } - }); - for text in texts { - token_tracker.add_output_text(&text).await; - } - let entry = build_log_entry(&context, collector.finish(), None); - log.clone().write_detached(entry); - Err(std::io::Error::new(std::io::ErrorKind::Other, err)) - } - None => { - let provider = context.provider.as_str(); - let mut texts = Vec::new(); - parser.finish(|data| { - if let Some(text) = extract_stream_text(provider, &data) { - texts.push(text); - } - }); - for text in texts { - token_tracker.add_output_text(&text).await; - } - let entry = build_log_entry(&context, collector.finish(), None); - log.clone().write_detached(entry); - Ok(None) - } - } - }, - ) -} - -pub(super) fn stream_with_logging_and_model_override( - upstream: impl futures_util::stream::Stream> + Unpin + Send + 'static, - context: LogContext, - log: Arc, - model_override: String, - token_tracker: RequestTokenTracker, -) -> impl futures_util::stream::Stream> + Send -where - E: std::error::Error + Send + Sync + 'static, -{ - let state = ModelOverrideStreamState::new(upstream, context, log, model_override, token_tracker); - try_unfold(state, |state| async move { state.step().await }) -} - -struct ModelOverrideStreamState { - upstream: S, - parser: SseEventParser, - collector: SseUsageCollector, - log: Arc, - context: LogContext, - token_tracker: RequestTokenTracker, - out: VecDeque, - model_override: String, - upstream_ended: bool, - logged: bool, -} - -impl ModelOverrideStreamState -where - S: futures_util::stream::Stream> + Unpin + Send + 'static, - E: std::error::Error + Send + Sync + 'static, -{ - fn new( - upstream: S, - context: LogContext, - log: Arc, - model_override: String, - token_tracker: RequestTokenTracker, - ) -> Self { - Self { - upstream, - parser: SseEventParser::new(), - collector: SseUsageCollector::new(), - log, - context, - token_tracker, - out: VecDeque::new(), - model_override, - upstream_ended: false, - logged: false, - } - } - - async fn step(mut self) -> Result, std::io::Error> { - loop { - if let Some(next) = 
self.out.pop_front() { - return Ok(Some((next, self))); - } - if self.upstream_ended { - self.log_usage_once(); - return Ok(None); - } - - match self.upstream.next().await { - Some(Ok(chunk)) => { - if self.context.ttfb_ms.is_none() { - self.context.ttfb_ms = Some(self.context.start.elapsed().as_millis()); - } - self.collector.push_chunk(&chunk); - let mut events = Vec::new(); - self.parser.push_chunk(&chunk, |data| events.push(data)); - let mut texts = Vec::new(); - for data in events { - if let Some(text) = extract_stream_text(&self.context.provider, &data) { - texts.push(text); - } - self.push_event_output(&data); - } - for text in texts { - self.token_tracker.add_output_text(&text).await; - } - } - Some(Err(err)) => { - self.log_usage_once(); - return Err(std::io::Error::new(std::io::ErrorKind::Other, err)); - } - None => { - self.upstream_ended = true; - let mut events = Vec::new(); - self.parser.finish(|data| events.push(data)); - let mut texts = Vec::new(); - for data in events { - if let Some(text) = extract_stream_text(&self.context.provider, &data) { - texts.push(text); - } - self.push_event_output(&data); - } - for text in texts { - self.token_tracker.add_output_text(&text).await; - } - } - } - } - } - - fn push_event_output(&mut self, data: &str) { - let output = rewrite_sse_data(data, &self.model_override); - self.out.push_back(Bytes::from(format!("data: {output}\n\n"))); - } - - fn log_usage_once(&mut self) { - if self.logged { - return; - } - let entry = build_log_entry(&self.context, self.collector.finish(), None); - self.log.clone().write_detached(entry); - self.logged = true; - } -} - -fn rewrite_sse_data(data: &str, model_override: &str) -> String { - if data == "[DONE]" { - return data.to_string(); - } - let bytes = Bytes::copy_from_slice(data.as_bytes()); - model::rewrite_response_model(&bytes, model_override) - .and_then(|bytes| String::from_utf8(bytes.to_vec()).ok()) - .unwrap_or_else(|| data.to_string()) -} - -fn extract_stream_text(provider: &str, data: &str) -> Option { - if data == "[DONE]" { - return None; - } - let Ok(value) = serde_json::from_str::(data) else { - return None; - }; - - match provider { - PROVIDER_OPENAI | PROVIDER_OPENAI_RESPONSES | PROVIDER_CODEX => { - extract_openai_stream_text(&value) - } - PROVIDER_ANTHROPIC => extract_anthropic_stream_text(&value), - PROVIDER_GEMINI | PROVIDER_ANTIGRAVITY => extract_gemini_stream_text(&value), - _ => None, - } - .or_else(|| extract_fallback_stream_text(&value)) -} - -fn extract_openai_stream_text(value: &Value) -> Option { - let delta = value - .get("choices") - .and_then(Value::as_array) - .and_then(|choices| choices.first()) - .and_then(|choice| choice.get("delta")) - .and_then(|delta| delta.get("content")) - .and_then(Value::as_str); - if let Some(delta) = delta { - return Some(delta.to_string()); - } - let event_type = value.get("type").and_then(Value::as_str)?; - if event_type.ends_with("output_text.delta") { - return value - .get("delta") - .and_then(Value::as_str) - .map(|text| text.to_string()); - } - None -} - -fn extract_anthropic_stream_text(value: &Value) -> Option { - if let Some(delta) = value.get("delta") { - if let Some(text) = delta.get("text").and_then(Value::as_str) { - return Some(text.to_string()); - } - if let Some(text) = delta.as_str() { - return Some(text.to_string()); - } - } - value - .get("content_block") - .and_then(|block| block.get("text")) - .and_then(Value::as_str) - .map(|text| text.to_string()) -} - -fn extract_gemini_stream_text(value: &Value) -> Option { - let 
candidates = value.get("candidates").and_then(Value::as_array)?; - for candidate in candidates { - if let Some(content) = candidate.get("content") { - if let Some(parts) = content.get("parts").and_then(Value::as_array) { - for part in parts { - if let Some(text) = part.get("text").and_then(Value::as_str) { - return Some(text.to_string()); - } - } - } - } - } - None -} - -fn extract_fallback_stream_text(value: &Value) -> Option { - value - .get("delta") - .and_then(Value::as_str) - .or_else(|| value.get("text").and_then(Value::as_str)) - .map(|text| text.to_string()) -} diff --git a/src-tauri/src/proxy/response/token_count.rs b/src-tauri/src/proxy/response/token_count.rs deleted file mode 100644 index 70cb054..0000000 --- a/src-tauri/src/proxy/response/token_count.rs +++ /dev/null @@ -1,95 +0,0 @@ -use axum::body::Bytes; -use serde_json::Value; - -use super::super::token_rate::RequestTokenTracker; - -pub(super) async fn apply_output_tokens_from_response( - tracker: &RequestTokenTracker, - provider: &str, - bytes: &Bytes, -) { - let Ok(value) = serde_json::from_slice::(bytes) else { - return; - }; - let mut texts = Vec::new(); - - match provider { - "openai" | "openai-response" | "codex" => { - if let Some(choices) = value.get("choices").and_then(Value::as_array) { - for choice in choices { - if let Some(content) = choice - .get("message") - .and_then(|message| message.get("content")) - { - if let Some(text) = content.as_str() { - texts.push(text.to_string()); - } else if let Some(parts) = content.as_array() { - for part in parts { - if let Some(text) = part.get("text").and_then(Value::as_str) { - texts.push(text.to_string()); - } - } - } - } - if let Some(text) = choice.get("text").and_then(Value::as_str) { - texts.push(text.to_string()); - } - } - } - if texts.is_empty() { - if let Some(output) = value.get("output").and_then(Value::as_array) { - collect_responses_output(output, &mut texts); - } - } - } - "anthropic" => { - if let Some(content) = value.get("content").and_then(Value::as_array) { - for item in content { - if let Some(text) = item.get("text").and_then(Value::as_str) { - texts.push(text.to_string()); - } - } - } - } - "gemini" => { - if let Some(candidates) = value.get("candidates").and_then(Value::as_array) { - collect_gemini_output(candidates, &mut texts); - } - } - _ => {} - } - - if texts.is_empty() { - return; - } - - for text in texts { - tracker.add_output_text(&text).await; - } -} - -fn collect_responses_output(output: &[Value], texts: &mut Vec) { - for item in output { - if let Some(content) = item.get("content").and_then(Value::as_array) { - for part in content { - if let Some(text) = part.get("text").and_then(Value::as_str) { - texts.push(text.to_string()); - } - } - } - } -} - -fn collect_gemini_output(candidates: &[Value], texts: &mut Vec) { - for candidate in candidates { - if let Some(content) = candidate.get("content") { - if let Some(parts) = content.get("parts").and_then(Value::as_array) { - for part in parts { - if let Some(text) = part.get("text").and_then(Value::as_str) { - texts.push(text.to_string()); - } - } - } - } - } -} diff --git a/src-tauri/src/proxy/response/upstream_read.rs b/src-tauri/src/proxy/response/upstream_read.rs deleted file mode 100644 index 6a30462..0000000 --- a/src-tauri/src/proxy/response/upstream_read.rs +++ /dev/null @@ -1,23 +0,0 @@ -use axum::body::Bytes; -use futures_util::StreamExt; - -use super::upstream_stream::{self, UpstreamStreamError}; -use super::super::log::LogContext; - -pub(super) async fn 
read_upstream_bytes_with_ttfb( - upstream_res: reqwest::Response, - context: &mut LogContext, -) -> Result> { - let mut upstream = upstream_stream::with_idle_timeout(upstream_res.bytes_stream()); - let mut out = Vec::new(); - - while let Some(item) = upstream.next().await { - let chunk = item?; - if context.ttfb_ms.is_none() { - context.ttfb_ms = Some(context.start.elapsed().as_millis()); - } - out.extend_from_slice(chunk.as_ref()); - } - - Ok(Bytes::from(out)) -} diff --git a/src-tauri/src/proxy/response/upstream_stream.rs b/src-tauri/src/proxy/response/upstream_stream.rs deleted file mode 100644 index 9c00712..0000000 --- a/src-tauri/src/proxy/response/upstream_stream.rs +++ /dev/null @@ -1,53 +0,0 @@ -use axum::body::Bytes; -use futures_util::{stream::try_unfold, StreamExt}; -use std::{error::Error, fmt, time::Duration}; - -use crate::proxy::UPSTREAM_NO_DATA_TIMEOUT; - -#[derive(Debug)] -pub(crate) enum UpstreamStreamError { - IdleTimeout(Duration), - Upstream(E), -} - -impl fmt::Display for UpstreamStreamError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::IdleTimeout(duration) => { - write!(f, "Upstream stream idle timeout after {}s.", duration.as_secs()) - } - Self::Upstream(err) => write!(f, "{err}"), - } - } -} - -impl Error for UpstreamStreamError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - match self { - Self::IdleTimeout(_) => None, - Self::Upstream(err) => Some(err), - } - } -} - -pub(super) fn with_idle_timeout( - upstream: impl futures_util::stream::Stream> + Unpin + Send + 'static, -) -> futures_util::stream::BoxStream<'static, Result>> -where - E: Error + Send + Sync + 'static, -{ - try_unfold(upstream, |mut upstream| async move { - match tokio::time::timeout(UPSTREAM_NO_DATA_TIMEOUT, upstream.next()).await { - Ok(Some(Ok(chunk))) => Ok(Some((chunk, upstream))), - Ok(Some(Err(err))) => Err(UpstreamStreamError::Upstream(err)), - Ok(None) => Ok(None), - Err(_) => Err(UpstreamStreamError::IdleTimeout(UPSTREAM_NO_DATA_TIMEOUT)), - } - }) - .boxed() -} - -// 单元测试拆到独立文件,使用 `#[path]` 以保持 `.test.rs` 命名约定。 -#[cfg(test)] -#[path = "upstream_stream.test.rs"] -mod tests; diff --git a/src-tauri/src/proxy/response/upstream_stream.test.rs b/src-tauri/src/proxy/response/upstream_stream.test.rs deleted file mode 100644 index 2942e49..0000000 --- a/src-tauri/src/proxy/response/upstream_stream.test.rs +++ /dev/null @@ -1,41 +0,0 @@ -use super::*; -use futures_util::StreamExt; - -#[tokio::test] -async fn idle_timeout_returns_error() { - let upstream = futures_util::stream::pending::>(); - let mut stream = with_idle_timeout(upstream); - - let item = tokio::time::timeout(Duration::from_secs(1), stream.next()) - .await - .expect("test timeout") - .expect("item") - .expect_err("timeout error"); - - assert!(matches!(item, UpstreamStreamError::IdleTimeout(_))); -} - -#[tokio::test] -async fn passes_through_success_chunks() { - let upstream = futures_util::stream::iter(vec![Ok::( - Bytes::from_static(b"hello"), - )]); - let mut stream = with_idle_timeout(upstream); - - let first = stream.next().await.expect("first").expect("ok"); - assert_eq!(first, Bytes::from_static(b"hello")); - - assert!(stream.next().await.is_none()); -} - -#[tokio::test] -async fn propagates_upstream_errors() { - let upstream = futures_util::stream::iter(vec![Err(std::io::Error::new( - std::io::ErrorKind::Other, - "boom", - ))]); - let mut stream = with_idle_timeout(upstream); - - let err = stream.next().await.expect("first").expect_err("err"); - assert!(matches!(err, 
UpstreamStreamError::Upstream(_))); -} diff --git a/src-tauri/src/proxy/server.rs b/src-tauri/src/proxy/server.rs deleted file mode 100644 index 43d4eee..0000000 --- a/src-tauri/src/proxy/server.rs +++ /dev/null @@ -1,875 +0,0 @@ -use axum::{ - body::Body, - extract::State, - http::{HeaderMap, Method, StatusCode, Uri}, - response::Response, -}; -use std::{ - sync::Arc, - time::Instant, -}; -use tokio::sync::RwLock; - -use super::{ - config::ProxyConfig, - gemini, - http, - log::{build_log_entry, LogContext, LogWriter, UsageSnapshot}, - openai_compat::{ - inbound_format, ApiFormat, FormatTransform, CHAT_PATH, PROVIDER_CHAT, PROVIDER_RESPONSES, - RESPONSES_PATH, - }, - request_detail::{capture_request_detail, serialize_request_headers, RequestDetailSnapshot}, - request_body::ReplayableBody, - server_helpers::{ - extract_request_path, is_anthropic_path, log_debug_request, - maybe_force_openai_stream_options_include_usage, maybe_transform_request_body, - parse_request_meta_best_effort, - }, - upstream::forward_upstream_request, - ProxyState, RequestMeta, -}; -use crate::logging::LogLevel; - -const PROVIDER_ANTHROPIC: &str = "anthropic"; -const PROVIDER_ANTIGRAVITY: &str = "antigravity"; -const PROVIDER_GEMINI: &str = "gemini"; -const PROVIDER_KIRO: &str = "kiro"; -const PROVIDER_CODEX: &str = "codex"; -const PROVIDER_PROXY: &str = "proxy"; -const LOCAL_UPSTREAM_ID: &str = "local"; -const CODEX_RESPONSES_PATH: &str = "/responses"; - -type ProxyStateHandle = Arc>>; - -mod bootstrap; -pub(crate) use bootstrap::{build_router, build_upstream_cursors}; - -struct DispatchPlan { - provider: &'static str, - outbound_path: Option<&'static str>, - request_transform: FormatTransform, - response_transform: FormatTransform, -} - -struct PreparedRequest { - path: String, - outbound_path_with_query: String, - plan: DispatchPlan, - meta: RequestMeta, - request_detail: Option, - outbound_body: ReplayableBody, - request_auth: http::RequestAuth, -} - -struct InboundRequest { - path: String, - plan: DispatchPlan, - body: ReplayableBody, - meta: RequestMeta, - request_detail: Option, -} - -const ERROR_NO_UPSTREAM: &str = "No available upstream configured."; -const ERROR_CHAT_CONVERSION_DISABLED: &str = - "API format conversion is disabled (enable_api_format_conversion=false). Configure provider \"openai\" for /v1/chat/completions or enable conversion."; -const ERROR_RESPONSES_CONVERSION_DISABLED: &str = - "API format conversion is disabled (enable_api_format_conversion=false). Configure provider \"openai-response\" for /v1/responses or enable conversion."; -const ERROR_ANTHROPIC_CONVERSION_DISABLED: &str = - "API format conversion is disabled (enable_api_format_conversion=false). Configure provider \"anthropic\" for /v1/messages or enable conversion."; -const ERROR_GEMINI_CONVERSION_DISABLED: &str = - "API format conversion is disabled (enable_api_format_conversion=false). 
Configure provider \"gemini\" for Gemini paths or enable conversion."; - -fn base_plan(provider: &'static str) -> DispatchPlan { - DispatchPlan { - provider, - outbound_path: None, - request_transform: FormatTransform::None, - response_transform: FormatTransform::None, - } -} - -struct ProviderRank { - priority: i32, - min_id: String, -} - -fn provider_rank(config: &ProxyConfig, provider: &str) -> Option { - let upstreams = config.provider_upstreams(provider)?; - let (priority, min_id) = match upstreams.groups.first() { - Some(group) => { - let min_id = group - .items - .iter() - .map(|item| item.id.as_str()) - .min() - .unwrap_or(provider); - (group.priority, min_id) - } - None => (0, provider), - }; - Some(ProviderRank { - priority, - min_id: min_id.to_string(), - }) -} - -fn choose_provider_by_priority(config: &ProxyConfig, candidates: &[&'static str]) -> Option<&'static str> { - let mut selected: Option<(&'static str, ProviderRank)> = None; - for candidate in candidates { - let Some(rank) = provider_rank(config, candidate) else { - continue; - }; - match &selected { - None => selected = Some((*candidate, rank)), - Some((_, best)) => { - if rank.priority > best.priority - || (rank.priority == best.priority && rank.min_id < best.min_id) - { - selected = Some((*candidate, rank)); - } - } - } - } - selected.map(|(provider, _)| provider) -} - -fn resolve_gemini_plan(config: &ProxyConfig, path: &str) -> Option> { - if !gemini::is_gemini_path(path) { - return None; - } - if let Some(selected) = - choose_provider_by_priority(config, &[PROVIDER_GEMINI, PROVIDER_ANTIGRAVITY]) - { - return Some(Ok(base_plan(selected))); - } - let fallback = choose_provider_by_priority( - config, - &[PROVIDER_RESPONSES, PROVIDER_CHAT, PROVIDER_ANTHROPIC], - ); - let Some(fallback) = fallback else { - return Some(Err(ERROR_NO_UPSTREAM.to_string())); - }; - if !config.enable_api_format_conversion { - return Some(Err(ERROR_GEMINI_CONVERSION_DISABLED.to_string())); - } - Some(Ok(match fallback { - PROVIDER_RESPONSES => DispatchPlan { - provider: PROVIDER_RESPONSES, - outbound_path: Some(RESPONSES_PATH), - request_transform: FormatTransform::GeminiToResponses, - response_transform: FormatTransform::ResponsesToGemini, - }, - PROVIDER_CHAT => DispatchPlan { - provider: PROVIDER_CHAT, - outbound_path: Some(CHAT_PATH), - request_transform: FormatTransform::GeminiToChat, - response_transform: FormatTransform::ChatToGemini, - }, - PROVIDER_ANTHROPIC => DispatchPlan { - provider: PROVIDER_ANTHROPIC, - outbound_path: Some("/v1/messages"), - request_transform: FormatTransform::GeminiToAnthropic, - response_transform: FormatTransform::AnthropicToGemini, - }, - _ => base_plan(PROVIDER_RESPONSES), - })) -} - -fn resolve_anthropic_plan( - config: &ProxyConfig, - path: &str, -) -> Option> { - if !is_anthropic_path(path) { - return None; - } - if path == "/v1/messages" { - // Claude Code uses /v1/messages. Prefer native providers (Anthropic/Kiro) by priority. 
-        if let Some(selected) =
-            choose_provider_by_priority(config, &[PROVIDER_ANTHROPIC, PROVIDER_KIRO])
-        {
-            return Some(Ok(match selected {
-                PROVIDER_ANTHROPIC => base_plan(PROVIDER_ANTHROPIC),
-                PROVIDER_KIRO => DispatchPlan {
-                    provider: PROVIDER_KIRO,
-                    outbound_path: Some(RESPONSES_PATH),
-                    request_transform: FormatTransform::None,
-                    response_transform: FormatTransform::KiroToAnthropic,
-                },
-                _ => base_plan(PROVIDER_ANTHROPIC),
-            }));
-        }
-        if !config.enable_api_format_conversion {
-            if config.provider_upstreams(PROVIDER_ANTIGRAVITY).is_some() {
-                return Some(Ok(DispatchPlan {
-                    provider: PROVIDER_ANTIGRAVITY,
-                    outbound_path: None,
-                    request_transform: FormatTransform::AnthropicToGemini,
-                    response_transform: FormatTransform::GeminiToAnthropic,
-                }));
-            }
-            return Some(Err(ERROR_ANTHROPIC_CONVERSION_DISABLED.to_string()));
-        }
-        // If native providers are missing, fall back to other formats when enabled (new-api style).
-        let fallback = choose_provider_by_priority(
-            config,
-            &[
-                PROVIDER_RESPONSES,
-                PROVIDER_CHAT,
-                PROVIDER_GEMINI,
-                PROVIDER_ANTIGRAVITY,
-            ],
-        );
-        let Some(fallback) = fallback else {
-            return Some(Err(ERROR_NO_UPSTREAM.to_string()));
-        };
-        return Some(Ok(match fallback {
-            PROVIDER_RESPONSES => DispatchPlan {
-                provider: PROVIDER_RESPONSES,
-                outbound_path: Some(RESPONSES_PATH),
-                request_transform: FormatTransform::AnthropicToResponses,
-                response_transform: FormatTransform::ResponsesToAnthropic,
-            },
-            PROVIDER_CHAT => DispatchPlan {
-                provider: PROVIDER_CHAT,
-                outbound_path: Some(CHAT_PATH),
-                request_transform: FormatTransform::AnthropicToChat,
-                response_transform: FormatTransform::ChatToAnthropic,
-            },
-            PROVIDER_GEMINI => DispatchPlan {
-                provider: PROVIDER_GEMINI,
-                outbound_path: None,
-                request_transform: FormatTransform::AnthropicToGemini,
-                response_transform: FormatTransform::GeminiToAnthropic,
-            },
-            PROVIDER_ANTIGRAVITY => DispatchPlan {
-                provider: PROVIDER_ANTIGRAVITY,
-                outbound_path: None,
-                request_transform: FormatTransform::AnthropicToGemini,
-                response_transform: FormatTransform::GeminiToAnthropic,
-            },
-            _ => base_plan(PROVIDER_RESPONSES),
-        }));
-    }
-    if config.provider_upstreams(PROVIDER_ANTHROPIC).is_some() {
-        return Some(Ok(base_plan(PROVIDER_ANTHROPIC)));
-    }
-    Some(Err(ERROR_NO_UPSTREAM.to_string()))
-}
-
-fn resolve_formatless_plan(config: &ProxyConfig) -> Result<DispatchPlan, String> {
-    let provider = choose_provider_by_priority(
-        config,
-        &[PROVIDER_CHAT, PROVIDER_RESPONSES, PROVIDER_ANTHROPIC],
-    )
-    .ok_or_else(|| ERROR_NO_UPSTREAM.to_string())?;
-    Ok(base_plan(provider))
-}
-
-fn resolve_chat_plan(config: &ProxyConfig) -> Result<DispatchPlan, String> {
-    if config.provider_upstreams(PROVIDER_CHAT).is_some() {
-        return Ok(base_plan(PROVIDER_CHAT));
-    }
-    let selected = choose_provider_by_priority(
-        config,
-        &[
-            PROVIDER_RESPONSES,
-            PROVIDER_CODEX,
-            PROVIDER_ANTHROPIC,
-            PROVIDER_GEMINI,
-            PROVIDER_ANTIGRAVITY,
-            PROVIDER_KIRO,
-        ],
-    )
-    .ok_or_else(|| ERROR_NO_UPSTREAM.to_string())?;
-    if !config.enable_api_format_conversion {
-        return Err(ERROR_CHAT_CONVERSION_DISABLED.to_string());
-    }
-
-    Ok(match selected {
-        PROVIDER_RESPONSES => DispatchPlan {
-            provider: PROVIDER_RESPONSES,
-            outbound_path: Some(RESPONSES_PATH),
-            request_transform: FormatTransform::ChatToResponses,
-            response_transform: FormatTransform::ResponsesToChat,
-        },
-        PROVIDER_ANTHROPIC => DispatchPlan {
-            provider: PROVIDER_ANTHROPIC,
-            outbound_path: Some("/v1/messages"),
-            request_transform: FormatTransform::ChatToAnthropic,
-            response_transform: FormatTransform::AnthropicToChat,
-        },
-        PROVIDER_CODEX => DispatchPlan {
-            provider: PROVIDER_CODEX,
-            outbound_path: Some(CODEX_RESPONSES_PATH),
-            request_transform: FormatTransform::ChatToCodex,
-            response_transform: FormatTransform::CodexToChat,
-        },
-        PROVIDER_GEMINI => DispatchPlan {
-            provider: PROVIDER_GEMINI,
-            outbound_path: None, // Gemini paths must be built dynamically from the model at the upstream layer
-            request_transform: FormatTransform::ChatToGemini,
-            response_transform: FormatTransform::GeminiToChat,
-        },
-        PROVIDER_ANTIGRAVITY => DispatchPlan {
-            provider: PROVIDER_ANTIGRAVITY,
-            outbound_path: None,
-            request_transform: FormatTransform::ChatToGemini,
-            response_transform: FormatTransform::GeminiToChat,
-        },
-        PROVIDER_KIRO => DispatchPlan {
-            provider: PROVIDER_KIRO,
-            outbound_path: Some(RESPONSES_PATH),
-            request_transform: FormatTransform::None,
-            response_transform: FormatTransform::KiroToChat,
-        },
-        _ => base_plan(PROVIDER_RESPONSES),
-    })
-}
-
-fn resolve_responses_plan(config: &ProxyConfig) -> Result<DispatchPlan, String> {
-    if let Some(selected) =
-        choose_provider_by_priority(config, &[PROVIDER_RESPONSES, PROVIDER_CODEX, PROVIDER_KIRO])
-    {
-        if selected == PROVIDER_RESPONSES {
-            return Ok(base_plan(PROVIDER_RESPONSES));
-        }
-        if selected == PROVIDER_CODEX {
-            return Ok(DispatchPlan {
-                provider: PROVIDER_CODEX,
-                outbound_path: Some(CODEX_RESPONSES_PATH),
-                request_transform: FormatTransform::ResponsesToCodex,
-                response_transform: FormatTransform::CodexToResponses,
-            });
-        }
-        return Ok(DispatchPlan {
-            provider: PROVIDER_KIRO,
-            outbound_path: Some(RESPONSES_PATH),
-            request_transform: FormatTransform::None,
-            response_transform: FormatTransform::KiroToResponses,
-        });
-    }
-    if !config.enable_api_format_conversion {
-        return Err(ERROR_RESPONSES_CONVERSION_DISABLED.to_string());
-    }
-
-    let selected = choose_provider_by_priority(
-        config,
-        &[
-            PROVIDER_CHAT,
-            PROVIDER_ANTHROPIC,
-            PROVIDER_GEMINI,
-            PROVIDER_ANTIGRAVITY,
-        ],
-    )
-    .ok_or_else(|| ERROR_NO_UPSTREAM.to_string())?;
-    Ok(match selected {
-        PROVIDER_CHAT => DispatchPlan {
-            provider: PROVIDER_CHAT,
-            outbound_path: Some(CHAT_PATH),
-            request_transform: FormatTransform::ResponsesToChat,
-            response_transform: FormatTransform::ChatToResponses,
-        },
-        PROVIDER_ANTHROPIC => DispatchPlan {
-            provider: PROVIDER_ANTHROPIC,
-            outbound_path: Some("/v1/messages"),
-            request_transform: FormatTransform::ResponsesToAnthropic,
-            response_transform: FormatTransform::AnthropicToResponses,
-        },
-        PROVIDER_GEMINI => DispatchPlan {
-            provider: PROVIDER_GEMINI,
-            outbound_path: None,
-            request_transform: FormatTransform::ResponsesToGemini,
-            response_transform: FormatTransform::GeminiToResponses,
-        },
-        PROVIDER_ANTIGRAVITY => DispatchPlan {
-            provider: PROVIDER_ANTIGRAVITY,
-            outbound_path: None,
-            request_transform: FormatTransform::ResponsesToGemini,
-            response_transform: FormatTransform::GeminiToResponses,
-        },
-        _ => base_plan(PROVIDER_CHAT),
-    })
-}
-
-fn resolve_dispatch_plan(config: &ProxyConfig, path: &str) -> Result<DispatchPlan, String> {
-    if let Some(plan) = resolve_gemini_plan(config, path) {
-        return plan;
-    }
-    if let Some(plan) = resolve_anthropic_plan(config, path) {
-        return plan;
-    }
-
-    let Some(format) = inbound_format(path) else {
-        return resolve_formatless_plan(config);
-    };
-
-    match format {
-        ApiFormat::ChatCompletions => resolve_chat_plan(config),
-        ApiFormat::Responses => resolve_responses_plan(config),
-    }
-}
-
-async fn capture_detail_from_body(
-    headers: &HeaderMap,
-    body: Body,
-    max_body_bytes: usize,
-) -> RequestDetailSnapshot {
-    match ReplayableBody::from_body(body).await {
-        Ok(replayable) => capture_request_detail(headers, &replayable, max_body_bytes).await,
-        Err(err) => RequestDetailSnapshot {
-            request_headers: serialize_request_headers(headers),
-            request_body: Some(format!("Failed to read request body: {err}")),
-        },
-    }
-}
-
-fn log_request_error(
-    log: &Arc<LogWriter>,
-    detail: Option<RequestDetailSnapshot>,
-    path: &str,
-    provider: &str,
-    upstream_id: &str,
-    status: StatusCode,
-    response_error: String,
-    start: Instant,
-) {
-    let (request_headers, request_body) =
-        detail.map(|detail| (detail.request_headers, detail.request_body)).unwrap_or((None, None));
-    let context = LogContext {
-        path: path.to_string(),
-        provider: provider.to_string(),
-        upstream_id: upstream_id.to_string(),
-        model: None,
-        mapped_model: None,
-        stream: false,
-        status: status.as_u16(),
-        upstream_request_id: None,
-        request_headers,
-        request_body,
-        ttfb_ms: None,
-        start,
-    };
-    let usage = UsageSnapshot {
-        usage: None,
-        cached_tokens: None,
-        usage_json: None,
-    };
-    let entry = build_log_entry(&context, usage, Some(response_error));
-    log.clone().write_detached(entry);
-}
-
-async fn ensure_local_auth_or_respond(
-    config: &ProxyConfig,
-    log: &Arc<LogWriter>,
-    headers: &HeaderMap,
-    body: Body,
-    capture_next: bool,
-    path: &str,
-    query: Option<&str>,
-    request_start: Instant,
-    max_body_bytes: usize,
-) -> Result<Body, Response> {
-    if let Err(message) = http::ensure_local_auth(config, headers, path, query) {
-        tracing::warn!("local auth failed");
-        let detail = if capture_next {
-            Some(capture_detail_from_body(headers, body, max_body_bytes).await)
-        } else {
-            None
-        };
-        log_request_error(
-            log,
-            detail,
-            path,
-            PROVIDER_PROXY,
-            LOCAL_UPSTREAM_ID,
-            StatusCode::UNAUTHORIZED,
-            message.clone(),
-            request_start,
-        );
-        return Err(http::error_response(StatusCode::UNAUTHORIZED, message));
-    }
-    Ok(body)
-}
-
-async fn resolve_plan_or_respond(
-    config: &ProxyConfig,
-    log: &Arc<LogWriter>,
-    headers: &HeaderMap,
-    body: Body,
-    capture_next: bool,
-    path: &str,
-    request_start: Instant,
-    max_body_bytes: usize,
-) -> Result<(DispatchPlan, Body), Response> {
-    match resolve_dispatch_plan(config, path) {
-        Ok(plan) => {
-            tracing::debug!(provider = %plan.provider, "dispatch plan resolved");
-            Ok((plan, body))
-        }
-        Err(message) => {
-            tracing::warn!("no dispatch plan found");
-            let detail = if capture_next {
-                Some(capture_detail_from_body(headers, body, max_body_bytes).await)
-            } else {
-                None
-            };
-            log_request_error(
-                log,
-                detail,
-                path,
-                PROVIDER_PROXY,
-                LOCAL_UPSTREAM_ID,
-                StatusCode::BAD_GATEWAY,
-                message.clone(),
-                request_start,
-            );
-            Err(http::error_response(StatusCode::BAD_GATEWAY, message))
-        }
-    }
-}
-
-async fn read_body_or_respond(
-    log: &Arc<LogWriter>,
-    headers: &HeaderMap,
-    body: Body,
-    capture_next: bool,
-    path: &str,
-    request_start: Instant,
-) -> Result<ReplayableBody, Response> {
-    match ReplayableBody::from_body(body).await {
-        Ok(body) => Ok(body),
-        Err(err) => {
-            let message = format!("Failed to read request body: {err}");
-            let detail = if capture_next {
-                Some(RequestDetailSnapshot {
-                    request_headers: serialize_request_headers(headers),
-                    request_body: Some(message.clone()),
-                })
-            } else {
-                None
-            };
-            log_request_error(
-                log,
-                detail,
-                path,
-                PROVIDER_PROXY,
-                LOCAL_UPSTREAM_ID,
-                StatusCode::BAD_REQUEST,
-                message.clone(),
-                request_start,
-            );
-            Err(http::error_response(StatusCode::BAD_REQUEST, message))
-        }
-    }
-}
-
-async fn build_outbound_body_or_respond(
-    http_clients: &super::http_client::ProxyHttpClients,
-    log: &Arc<LogWriter>,
-    request_detail: Option<RequestDetailSnapshot>,
-    path: &str,
-    plan: &DispatchPlan,
-    meta: &RequestMeta,
-    body: ReplayableBody,
-    request_start: Instant,
-) -> Result<ReplayableBody, Response> {
-    let body = transform_body_or_respond(
-        http_clients,
-        log,
-        request_detail.clone(),
-        path,
-        plan,
-        meta,
-        body,
-        request_start,
-    )
-    .await?;
-    apply_openai_stream_options_or_respond(
-        log,
-        request_detail,
-        path,
-        plan,
-        meta,
-        body,
-        request_start,
-    )
-    .await
-}
-
-async fn transform_body_or_respond(
-    http_clients: &super::http_client::ProxyHttpClients,
-    log: &Arc<LogWriter>,
-    request_detail: Option<RequestDetailSnapshot>,
-    path: &str,
-    plan: &DispatchPlan,
-    meta: &RequestMeta,
-    body: ReplayableBody,
-    request_start: Instant,
-) -> Result<ReplayableBody, Response> {
-    match maybe_transform_request_body(
-        http_clients,
-        plan.provider,
-        path,
-        plan.request_transform,
-        meta.original_model.as_deref(),
-        body,
-    )
-    .await
-    {
-        Ok(body) => Ok(body),
-        Err(err) => {
-            log_request_error(
-                log,
-                request_detail,
-                path,
-                plan.provider,
-                LOCAL_UPSTREAM_ID,
-                err.status,
-                err.message.clone(),
-                request_start,
-            );
-            Err(http::error_response(err.status, err.message))
-        }
-    }
-}
-
-async fn apply_openai_stream_options_or_respond(
-    log: &Arc<LogWriter>,
-    request_detail: Option<RequestDetailSnapshot>,
-    path: &str,
-    plan: &DispatchPlan,
-    meta: &RequestMeta,
-    body: ReplayableBody,
-    request_start: Instant,
-) -> Result<ReplayableBody, Response> {
-    match maybe_force_openai_stream_options_include_usage(
-        plan.provider,
-        plan.outbound_path.unwrap_or(path),
-        meta,
-        body,
-    )
-    .await
-    {
-        Ok(body) => Ok(body),
-        Err(err) => {
-            log_request_error(
-                log,
-                request_detail,
-                path,
-                plan.provider,
-                LOCAL_UPSTREAM_ID,
-                err.status,
-                err.message.clone(),
-                request_start,
-            );
-            Err(http::error_response(err.status, err.message))
-        }
-    }
-}
-
-fn resolve_request_auth_or_respond(
-    config: &ProxyConfig,
-    headers: &HeaderMap,
-    log: &Arc<LogWriter>,
-    request_detail: Option<RequestDetailSnapshot>,
-    path: &str,
-    provider: &str,
-    request_start: Instant,
-) -> Result<http::RequestAuth, Response> {
-    match http::resolve_request_auth(config, headers) {
-        Ok(auth) => Ok(auth),
-        Err(message) => {
-            log_request_error(
-                log,
-                request_detail,
-                path,
-                provider,
-                LOCAL_UPSTREAM_ID,
-                StatusCode::UNAUTHORIZED,
-                message.clone(),
-                request_start,
-            );
-            Err(http::error_response(StatusCode::UNAUTHORIZED, message))
-        }
-    }
-}
-
-fn build_outbound_path_with_query(outbound_path: &str, uri: &Uri) -> String {
-    uri.query()
-        .map(|query| format!("{outbound_path}?{query}"))
-        .unwrap_or_else(|| outbound_path.to_string())
-}
-
-async fn prepare_inbound_request(
-    state: &ProxyState,
-    headers: &HeaderMap,
-    path: String,
-    query: Option<String>,
-    body: Body,
-    capture_next: bool,
-    request_start: Instant,
-    is_debug_log: bool,
-) -> Result<InboundRequest, Response> {
-    let body = ensure_local_auth_or_respond(
-        &state.config,
-        &state.log,
-        headers,
-        body,
-        capture_next,
-        &path,
-        query.as_deref(),
-        request_start,
-        state.config.max_request_body_bytes,
-    )
-    .await?;
-    let (plan, body) = resolve_plan_or_respond(
-        &state.config,
-        &state.log,
-        headers,
-        body,
-        capture_next,
-        &path,
-        request_start,
-        state.config.max_request_body_bytes,
-    )
-    .await?;
-    let body = read_body_or_respond(&state.log, headers, body, capture_next, &path, request_start)
-        .await?;
-    if is_debug_log {
-        log_debug_request(headers, &body).await;
-    }
-    let meta = parse_request_meta_best_effort(&path, &body).await;
-    let request_detail = if capture_next {
-        Some(
-            capture_request_detail(headers, &body, state.config.max_request_body_bytes).await,
-        )
-    } else {
-        None
-    };
-    Ok(InboundRequest {
-        path,
-        plan,
-        meta,
-        request_detail,
-        body,
-    })
-}
-
-async fn finalize_prepared_request(
-    state: &ProxyState,
-    headers: &HeaderMap,
-    uri: &Uri,
-    inbound: InboundRequest,
-    request_start: Instant,
-) -> Result<PreparedRequest, Response> {
-    // For the ChatToGemini conversion, the Gemini path must be built dynamically from the model.
-    let outbound_path = match (inbound.plan.outbound_path, inbound.plan.provider) {
-        (Some(path), _) => path.to_string(),
-        (None, PROVIDER_GEMINI) if inbound.plan.request_transform != FormatTransform::None => {
-            // Take the model from meta and build the Gemini API path.
-            let model = inbound
-                .meta
-                .mapped_model
-                .as_deref()
-                .or(inbound.meta.original_model.as_deref())
-                .unwrap_or("gemini-1.5-flash");
-            let suffix = if inbound.meta.stream {
-                ":streamGenerateContent"
-            } else {
-                ":generateContent"
-            };
-            format!("{}{}{}", gemini::GEMINI_MODELS_PREFIX, model, suffix)
-        }
-        (None, _) => inbound.path.clone(),
-    };
-    let outbound_path_with_query = build_outbound_path_with_query(&outbound_path, uri);
-    let outbound_body = build_outbound_body_or_respond(
-        &state.http_clients,
-        &state.log,
-        inbound.request_detail.clone(),
-        &inbound.path,
-        &inbound.plan,
-        &inbound.meta,
-        inbound.body,
-        request_start,
-    )
-    .await?;
-    let request_auth = resolve_request_auth_or_respond(
-        &state.config,
-        headers,
-        &state.log,
-        inbound.request_detail.clone(),
-        &inbound.path,
-        inbound.plan.provider,
-        request_start,
-    )?;
-    Ok(PreparedRequest {
-        path: inbound.path,
-        outbound_path_with_query,
-        plan: inbound.plan,
-        meta: inbound.meta,
-        request_detail: inbound.request_detail,
-        outbound_body,
-        request_auth,
-    })
-}
-
-async fn proxy_request(
-    State(state): State<ProxyStateHandle>,
-    method: Method,
-    uri: Uri,
-    headers: HeaderMap,
-    body: Body,
-) -> Response {
-    // Hold the read lock only briefly here so concurrent requests are not slowed down.
-    let state = { state.read().await.clone() };
-    let request_start = Instant::now();
-    let capture_next = state.request_detail.take();
-    let is_debug_log = cfg!(debug_assertions)
-        && matches!(state.config.log_level, LogLevel::Debug | LogLevel::Trace);
-    let (path, _) = extract_request_path(&uri);
-    let query = uri.query().map(|value| value.to_string());
-    tracing::info!(method = %method, path = %path, "incoming request");
-    tracing::debug!(headers = ?headers.keys().collect::<Vec<_>>(), "request headers");
-
-    let inbound = match prepare_inbound_request(
-        &state,
-        &headers,
-        path,
-        query,
-        body,
-        capture_next,
-        request_start,
-        is_debug_log,
-    )
-    .await
-    {
-        Ok(inbound) => inbound,
-        Err(response) => return response,
-    };
-    let prepared = match finalize_prepared_request(&state, &headers, &uri, inbound, request_start)
-        .await
-    {
-        Ok(prepared) => prepared,
-        Err(response) => return response,
-    };
-    forward_upstream_request(
-        state,
-        method,
-        prepared.plan.provider,
-        &prepared.path,
-        &prepared.outbound_path_with_query,
-        headers,
-        prepared.outbound_body,
-        prepared.meta,
-        prepared.request_auth,
-        prepared.plan.response_transform,
-        prepared.request_detail,
-    )
-    .await
-}
-
-#[cfg(test)]
-#[path = "server.test.rs"]
-mod tests;
diff --git a/src-tauri/src/proxy/server.test.rs b/src-tauri/src/proxy/server.test.rs
deleted file mode 100644
index 06ca00a..0000000
--- a/src-tauri/src/proxy/server.test.rs
+++ /dev/null
@@ -1,301 +0,0 @@
-use super::*;
-
-use std::collections::HashMap;
-
-use crate::logging::LogLevel;
-use crate::proxy::config::{ProviderUpstreams, ProxyConfig, UpstreamGroup, UpstreamRuntime, UpstreamStrategy};
-
-fn config_with_providers(
-    providers: &[&'static str],
-    enable_api_format_conversion: bool,
-) -> ProxyConfig {
-    let mut upstreams = HashMap::new();
-    for provider in providers {
-        upstreams.insert((*provider).to_string(), ProviderUpstreams { groups: Vec::new() });
-    }
-    ProxyConfig {
"127.0.0.1".to_string(), - port: 9208, - local_api_key: None, - log_level: LogLevel::Silent, - max_request_body_bytes: 20 * 1024 * 1024, - enable_api_format_conversion, - upstream_strategy: UpstreamStrategy::PriorityRoundRobin, - upstreams, - kiro_preferred_endpoint: None, - antigravity_user_agent: None, - } -} - -fn config_with_upstreams( - upstreams: &[(&'static str, i32, &'static str)], - enable_api_format_conversion: bool, -) -> ProxyConfig { - let mut provider_map: HashMap = HashMap::new(); - for (provider, priority, id) in upstreams { - let runtime = UpstreamRuntime { - id: (*id).to_string(), - base_url: "https://example.com".to_string(), - api_key: None, - filter_prompt_cache_retention: false, - filter_safety_identifier: false, - kiro_account_id: None, - codex_account_id: None, - antigravity_account_id: None, - kiro_preferred_endpoint: None, - proxy_url: None, - priority: *priority, - model_mappings: None, - header_overrides: None, - }; - let entry = provider_map - .entry((*provider).to_string()) - .or_insert_with(|| ProviderUpstreams { groups: Vec::new() }); - if let Some(group) = entry.groups.iter_mut().find(|group| group.priority == *priority) { - group.items.push(runtime); - } else { - entry.groups.push(UpstreamGroup { - priority: *priority, - items: vec![runtime], - }); - } - } - for upstreams in provider_map.values_mut() { - upstreams.groups.sort_by(|left, right| right.priority.cmp(&left.priority)); - } - ProxyConfig { - host: "127.0.0.1".to_string(), - port: 9208, - local_api_key: None, - log_level: LogLevel::Silent, - max_request_body_bytes: 20 * 1024 * 1024, - enable_api_format_conversion, - upstream_strategy: UpstreamStrategy::PriorityRoundRobin, - upstreams: provider_map, - kiro_preferred_endpoint: None, - antigravity_user_agent: None, - } -} - -#[test] -fn chat_fallback_requires_format_conversion_enabled() { - let config = config_with_providers(&[PROVIDER_RESPONSES], false); - let error = resolve_dispatch_plan(&config, CHAT_PATH) - .err() - .expect("should reject"); - assert!(error.contains("format conversion is disabled")); - - let config = config_with_providers(&[PROVIDER_RESPONSES], true); - let plan = resolve_dispatch_plan(&config, CHAT_PATH).expect("should fallback"); - assert_eq!(plan.provider, PROVIDER_RESPONSES); - assert_eq!(plan.outbound_path, Some(RESPONSES_PATH)); - assert_eq!(plan.request_transform, FormatTransform::ChatToResponses); - assert_eq!(plan.response_transform, FormatTransform::ResponsesToChat); -} - -#[test] -fn responses_fallback_requires_format_conversion_enabled() { - let config = config_with_providers(&[PROVIDER_CHAT], false); - let error = resolve_dispatch_plan(&config, RESPONSES_PATH) - .err() - .expect("should reject"); - assert!(error.contains("format conversion is disabled")); - - let config = config_with_providers(&[PROVIDER_CHAT], true); - let plan = resolve_dispatch_plan(&config, RESPONSES_PATH).expect("should fallback"); - assert_eq!(plan.provider, PROVIDER_CHAT); - assert_eq!(plan.outbound_path, Some(CHAT_PATH)); - assert_eq!(plan.request_transform, FormatTransform::ResponsesToChat); - assert_eq!(plan.response_transform, FormatTransform::ChatToResponses); -} - -#[test] -fn chat_to_codex_requires_format_conversion_enabled() { - let config = config_with_providers(&[PROVIDER_CODEX], false); - let error = resolve_dispatch_plan(&config, CHAT_PATH) - .err() - .expect("should reject"); - assert!(error.contains("format conversion is disabled")); - - let config = config_with_providers(&[PROVIDER_CODEX], true); - let plan = 
resolve_dispatch_plan(&config, CHAT_PATH).expect("should dispatch"); - assert_eq!(plan.provider, PROVIDER_CODEX); - assert_eq!(plan.outbound_path, Some(CODEX_RESPONSES_PATH)); - assert_eq!(plan.request_transform, FormatTransform::ChatToCodex); - assert_eq!(plan.response_transform, FormatTransform::CodexToChat); -} - -#[test] -fn responses_prefers_codex_without_conversion() { - let config = config_with_providers(&[PROVIDER_CODEX], false); - let plan = resolve_dispatch_plan(&config, RESPONSES_PATH).expect("should dispatch"); - assert_eq!(plan.provider, PROVIDER_CODEX); - assert_eq!(plan.outbound_path, Some(CODEX_RESPONSES_PATH)); - assert_eq!(plan.request_transform, FormatTransform::ResponsesToCodex); - assert_eq!(plan.response_transform, FormatTransform::CodexToResponses); -} - -#[test] -fn responses_same_protocol_preferred_over_priority() { - let config = config_with_upstreams( - &[(PROVIDER_RESPONSES, 0, "resp"), (PROVIDER_CHAT, 10, "chat")], - false, - ); - let plan = resolve_dispatch_plan(&config, RESPONSES_PATH).expect("should dispatch"); - assert_eq!(plan.provider, PROVIDER_RESPONSES); - assert_eq!(plan.request_transform, FormatTransform::None); - assert_eq!(plan.response_transform, FormatTransform::None); -} - -#[test] -fn responses_same_protocol_tiebreaks_by_id() { - let config = config_with_upstreams( - &[(PROVIDER_RESPONSES, 5, "b-resp"), (PROVIDER_KIRO, 5, "a-kiro")], - false, - ); - let plan = resolve_dispatch_plan(&config, RESPONSES_PATH).expect("should dispatch"); - assert_eq!(plan.provider, PROVIDER_KIRO); - assert_eq!(plan.response_transform, FormatTransform::KiroToResponses); -} - -#[test] -fn anthropic_messages_fallback_requires_format_conversion_enabled() { - let config = config_with_providers(&[PROVIDER_RESPONSES], false); - let error = resolve_dispatch_plan(&config, "/v1/messages") - .err() - .expect("should reject"); - assert!(error.contains("format conversion is disabled")); - - let config = config_with_providers(&[PROVIDER_RESPONSES], true); - let plan = resolve_dispatch_plan(&config, "/v1/messages").expect("should fallback"); - assert_eq!(plan.provider, PROVIDER_RESPONSES); - assert_eq!(plan.outbound_path, Some(RESPONSES_PATH)); - assert_eq!(plan.request_transform, FormatTransform::AnthropicToResponses); - assert_eq!(plan.response_transform, FormatTransform::ResponsesToAnthropic); -} - -#[test] -fn anthropic_messages_fallbacks_to_kiro_without_conversion() { - let config = config_with_providers(&[PROVIDER_KIRO], false); - let plan = resolve_dispatch_plan(&config, "/v1/messages").expect("should fallback"); - assert_eq!(plan.provider, PROVIDER_KIRO); - assert_eq!(plan.outbound_path, Some(RESPONSES_PATH)); - assert_eq!(plan.request_transform, FormatTransform::None); - assert_eq!(plan.response_transform, FormatTransform::KiroToAnthropic); -} - -#[test] -fn anthropic_messages_allows_antigravity_without_conversion() { - let config = config_with_providers(&[PROVIDER_ANTIGRAVITY], false); - let plan = resolve_dispatch_plan(&config, "/v1/messages").expect("should fallback"); - assert_eq!(plan.provider, PROVIDER_ANTIGRAVITY); - assert_eq!(plan.outbound_path, None); - assert_eq!(plan.request_transform, FormatTransform::AnthropicToGemini); - assert_eq!(plan.response_transform, FormatTransform::GeminiToAnthropic); -} - -#[test] -fn anthropic_messages_prefers_kiro_without_conversion() { - let config = config_with_upstreams( - &[(PROVIDER_RESPONSES, 10, "resp"), (PROVIDER_KIRO, 0, "kiro")], - false, - ); - let plan = resolve_dispatch_plan(&config, "/v1/messages").expect("should 
fallback"); - assert_eq!(plan.provider, PROVIDER_KIRO); - assert_eq!(plan.outbound_path, Some(RESPONSES_PATH)); - assert_eq!(plan.request_transform, FormatTransform::None); - assert_eq!(plan.response_transform, FormatTransform::KiroToAnthropic); -} - -#[test] -fn anthropic_messages_prefers_anthropic_when_priority_higher() { - let config = config_with_upstreams( - &[(PROVIDER_ANTHROPIC, 5, "anthro"), (PROVIDER_KIRO, 1, "kiro")], - false, - ); - let plan = resolve_dispatch_plan(&config, "/v1/messages").expect("should dispatch"); - assert_eq!(plan.provider, PROVIDER_ANTHROPIC); - assert_eq!(plan.outbound_path, None); - assert_eq!(plan.request_transform, FormatTransform::None); - assert_eq!(plan.response_transform, FormatTransform::None); -} - -#[test] -fn anthropic_messages_tiebreaks_by_id_between_anthropic_and_kiro() { - let config = config_with_upstreams( - &[(PROVIDER_ANTHROPIC, 5, "b-anthro"), (PROVIDER_KIRO, 5, "a-kiro")], - false, - ); - let plan = resolve_dispatch_plan(&config, "/v1/messages").expect("should dispatch"); - assert_eq!(plan.provider, PROVIDER_KIRO); - assert_eq!(plan.outbound_path, Some(RESPONSES_PATH)); - assert_eq!(plan.request_transform, FormatTransform::None); - assert_eq!(plan.response_transform, FormatTransform::KiroToAnthropic); -} - -#[test] -fn responses_fallback_to_anthropic_requires_format_conversion_enabled() { - let config = config_with_providers(&[PROVIDER_ANTHROPIC], false); - let error = resolve_dispatch_plan(&config, RESPONSES_PATH) - .err() - .expect("should reject"); - assert!(error.contains("format conversion is disabled")); - - let config = config_with_providers(&[PROVIDER_ANTHROPIC], true); - let plan = resolve_dispatch_plan(&config, RESPONSES_PATH).expect("should fallback"); - assert_eq!(plan.provider, PROVIDER_ANTHROPIC); - assert_eq!(plan.outbound_path, Some("/v1/messages")); - assert_eq!(plan.request_transform, FormatTransform::ResponsesToAnthropic); - assert_eq!(plan.response_transform, FormatTransform::AnthropicToResponses); -} - -#[test] -fn gemini_route_requires_format_conversion_for_fallback() { - let config = config_with_providers(&[PROVIDER_CHAT], false); - let error = resolve_dispatch_plan(&config, "/v1beta/models/gemini-1.5-flash:generateContent") - .err() - .expect("should reject"); - assert!(error.contains("format conversion is disabled")); -} - -#[test] -fn gemini_route_fallbacks_to_chat() { - let config = config_with_providers(&[PROVIDER_CHAT], true); - let plan = resolve_dispatch_plan(&config, "/v1beta/models/gemini-1.5-flash:generateContent") - .expect("should fallback"); - assert_eq!(plan.provider, PROVIDER_CHAT); - assert_eq!(plan.outbound_path, Some(CHAT_PATH)); - assert_eq!(plan.request_transform, FormatTransform::GeminiToChat); - assert_eq!(plan.response_transform, FormatTransform::ChatToGemini); -} - -#[test] -fn gemini_route_fallbacks_to_anthropic() { - let config = config_with_providers(&[PROVIDER_ANTHROPIC], true); - let plan = resolve_dispatch_plan(&config, "/v1beta/models/gemini-1.5-flash:generateContent") - .expect("should fallback"); - assert_eq!(plan.provider, PROVIDER_ANTHROPIC); - assert_eq!(plan.outbound_path, Some("/v1/messages")); - assert_eq!(plan.request_transform, FormatTransform::GeminiToAnthropic); - assert_eq!(plan.response_transform, FormatTransform::AnthropicToGemini); -} - -#[test] -fn anthropic_messages_fallbacks_to_gemini() { - let config = config_with_providers(&[PROVIDER_GEMINI], true); - let plan = resolve_dispatch_plan(&config, "/v1/messages").expect("should fallback"); - assert_eq!(plan.provider, 
PROVIDER_GEMINI);
-    assert_eq!(plan.outbound_path, None);
-    assert_eq!(plan.request_transform, FormatTransform::AnthropicToGemini);
-    assert_eq!(plan.response_transform, FormatTransform::GeminiToAnthropic);
-}
-
-#[test]
-fn gemini_route_dispatches_to_gemini() {
-    let config = config_with_providers(&[PROVIDER_GEMINI], false);
-    let plan = resolve_dispatch_plan(&config, "/v1beta/models/gemini-1.5-flash:generateContent")
-        .expect("should dispatch");
-    assert_eq!(plan.provider, PROVIDER_GEMINI);
-    assert_eq!(plan.request_transform, FormatTransform::None);
-    assert_eq!(plan.response_transform, FormatTransform::None);
-}
diff --git a/src-tauri/src/proxy/server/bootstrap.rs b/src-tauri/src/proxy/server/bootstrap.rs
deleted file mode 100644
index 9390136..0000000
--- a/src-tauri/src/proxy/server/bootstrap.rs
+++ /dev/null
@@ -1,35 +0,0 @@
-use axum::{
-    extract::DefaultBodyLimit,
-    routing::any,
-    Router,
-};
-use std::{collections::HashMap, sync::atomic::AtomicUsize};
-
-use crate::proxy::config::ProxyConfig;
-use super::{proxy_request, ProxyStateHandle};
-
-pub(crate) fn build_upstream_cursors(
-    config: &ProxyConfig,
-) -> HashMap<String, Vec<AtomicUsize>> {
-    let mut cursors: HashMap<String, Vec<AtomicUsize>> = HashMap::new();
-    for (provider, upstreams) in &config.upstreams {
-        let group_cursors = upstreams
-            .groups
-            .iter()
-            .map(|_| AtomicUsize::new(0))
-            .collect();
-        cursors.insert(provider.clone(), group_cursors);
-    }
-    cursors
-}
-
-pub(crate) fn build_router(
-    state: ProxyStateHandle,
-    max_request_body_bytes: usize,
-) -> Router {
-    Router::new()
-        .route("/{*path}", any(proxy_request))
-        // Cap the inbound request body so oversized requests don't hog memory/temp disk and slow time-to-first-byte.
-        .layer(DefaultBodyLimit::max(max_request_body_bytes))
-        .with_state(state)
-}
diff --git a/src-tauri/src/proxy/server_helpers.rs b/src-tauri/src/proxy/server_helpers.rs
deleted file mode 100644
index eadf0ae..0000000
--- a/src-tauri/src/proxy/server_helpers.rs
+++ /dev/null
@@ -1,436 +0,0 @@
-use axum::{
-    body::Bytes,
-    http::{HeaderMap, StatusCode, Uri},
-};
-use serde_json::{Map, Value};
-
-use super::{
-    antigravity_compat,
-    gemini,
-    http_client::ProxyHttpClients,
-    openai_compat::{
-        transform_request_body, FormatTransform, CHAT_PATH, PROVIDER_CHAT, PROVIDER_RESPONSES,
-        RESPONSES_PATH,
-    },
-    request_token_estimate,
-    request_body::ReplayableBody,
-    RequestMeta,
-};
-
-const ANTHROPIC_MESSAGES_PREFIX: &str = "/v1/messages";
-const ANTHROPIC_COMPLETE_PATH: &str = "/v1/complete";
-const PROVIDER_ANTIGRAVITY: &str = "antigravity";
-const REQUEST_META_LIMIT_BYTES: usize = 2 * 1024 * 1024;
-// Format conversion needs the full JSON body; allow up to the default max_request_body_bytes (20 MiB).
-const REQUEST_TRANSFORM_LIMIT_BYTES: usize = 20 * 1024 * 1024;
-const DEBUG_BODY_LOG_LIMIT_BYTES: usize = usize::MAX;
-const OPENAI_REASONING_MODEL_SUFFIX_PREFIX: &str = "-reasoning-";
-
-#[derive(Debug)]
-pub(crate) struct RequestError {
-    pub(crate) status: StatusCode,
-    pub(crate) message: String,
-}
-
-impl RequestError {
-    pub(crate) fn new(status: StatusCode, message: impl Into<String>) -> Self {
-        Self {
-            status,
-            message: message.into(),
-        }
-    }
-}
-
-pub(crate) fn extract_request_path(uri: &Uri) -> (String, String) {
-    let path = uri.path().to_string();
-    let path_with_query = uri
-        .query()
-        .map(|query| format!("{path}?{query}"))
-        .unwrap_or_else(|| path.clone());
-    (path, path_with_query)
-}
-
-pub(crate) fn is_anthropic_path(path: &str) -> bool {
-    if path == ANTHROPIC_COMPLETE_PATH || path == ANTHROPIC_MESSAGES_PREFIX {
-        return true;
-    }
-    if !path.starts_with(ANTHROPIC_MESSAGES_PREFIX) {
-        return false;
-    }
-    path.as_bytes()
-        .get(ANTHROPIC_MESSAGES_PREFIX.len())
-        .is_some_and(|byte| *byte == b'/')
-}
-
-pub(crate) async fn parse_request_meta_best_effort(
-    path: &str,
-    body: &ReplayableBody,
-) -> RequestMeta {
-    let stream_from_path = gemini::is_gemini_stream_path(path);
-    let model_from_path = gemini::parse_gemini_model_from_path(path);
-    let fallback_meta = RequestMeta {
-        stream: stream_from_path,
-        original_model: model_from_path.clone(),
-        mapped_model: None,
-        reasoning_effort: None,
-        estimated_input_tokens: None,
-    };
-
-    let Some(bytes) = body
-        .read_bytes_if_small(REQUEST_META_LIMIT_BYTES)
-        .await
-        .unwrap_or(None)
-    else {
-        return fallback_meta;
-    };
-    let value: Value = match serde_json::from_slice(&bytes) {
-        Ok(value) => value,
-        Err(_) => return fallback_meta,
-    };
-    let stream = value
-        .get("stream")
-        .and_then(Value::as_bool)
-        .unwrap_or(false)
-        || stream_from_path;
-    let mut original_model = value
-        .get("model")
-        .and_then(Value::as_str)
-        .map(|value| value.to_string())
-        .or(model_from_path);
-
-    // KISS: only support the explicit `-reasoning-` suffix to avoid ambiguity.
-    // This mirrors new-api behavior: strip the suffix from `model` and translate it into
-    // OpenAI reasoning parameters when dispatching to OpenAI providers.
-    let mut reasoning_effort = None;
-    if let Some(model) = original_model.as_deref() {
-        if let Some((base_model, effort)) = parse_openai_reasoning_effort_from_model_suffix(model) {
-            original_model = Some(base_model);
-            reasoning_effort = Some(effort);
-        }
-    }
-
-    let estimated_input_tokens =
-        request_token_estimate::estimate_request_input_tokens(&value, original_model.as_deref());
-    RequestMeta {
-        stream,
-        original_model,
-        mapped_model: None,
-        reasoning_effort,
-        estimated_input_tokens,
-    }
-}
-
-pub(crate) fn parse_openai_reasoning_effort_from_model_suffix(
-    model: &str,
-) -> Option<(String, String)> {
-    let (base, effort_raw) = model.rsplit_once(OPENAI_REASONING_MODEL_SUFFIX_PREFIX)?;
-    let base = base.trim();
-    let effort = effort_raw.trim().to_ascii_lowercase();
-    if base.is_empty() || effort.is_empty() {
-        return None;
-    }
-
-    match effort.as_str() {
-        "low" | "medium" | "high" | "minimal" | "none" | "xhigh" => {
-            Some((base.to_string(), effort))
-        }
-        _ => None,
-    }
-}
-
-fn ensure_stream_options_include_usage(object: &mut Map<String, Value>) -> bool {
-    let include_usage = object
-        .get("stream_options")
-        .and_then(Value::as_object)
-        .and_then(|options| options.get("include_usage"))
-        .and_then(Value::as_bool)
-        .unwrap_or(false);
-    if include_usage {
-        return false;
-    }
-
-    let options = match object.get_mut("stream_options") {
-        Some(Value::Object(options)) => options,
-        _ => {
-            object.insert("stream_options".to_string(), Value::Object(Map::new()));
-            object
-                .get_mut("stream_options")
-                .and_then(Value::as_object_mut)
-                .expect("stream_options must be object")
-        }
-    };
-    options.insert("include_usage".to_string(), Value::Bool(true));
-    true
-}
-
-pub(crate) async fn log_debug_request(headers: &HeaderMap, body: &ReplayableBody) {
-    log_debug_headers_body(
-        "inbound.request",
-        Some(headers),
-        Some(body),
-        DEBUG_BODY_LOG_LIMIT_BYTES,
-    )
-    .await;
-}
-
-pub(crate) async fn log_debug_headers_body(
-    stage: &str,
-    headers: Option<&HeaderMap>,
-    body: Option<&ReplayableBody>,
-    max_body_bytes: usize,
-) {
-    if !tracing::enabled!(tracing::Level::DEBUG) {
-        return;
-    }
-
-    let header_snapshot = headers
-        .map(snapshot_headers_raw)
-        .unwrap_or_default();
-    let body_text = if let Some(body) = body {
-        match body.read_bytes_if_small(max_body_bytes).await {
-            Ok(Some(bytes)) => Some(String::from_utf8_lossy(&bytes).into_owned()),
-            Ok(None) => Some(format!("[body omitted: larger than {max_body_bytes} bytes]")),
-            Err(err) => Some(format!("[body read failed: {err}]")),
-        }
-    } else {
-        None
-    };
-
-    match body_text {
-        Some(text) => {
-            tracing::debug!(stage, headers = ?header_snapshot, body = %text, "debug dump");
-        }
-        None => {
-            tracing::debug!(stage, headers = ?header_snapshot, "debug dump (no body)");
-        }
-    }
-}
-
-fn snapshot_headers_raw(headers: &HeaderMap) -> Vec<(String, String)> {
-    headers
-        .iter()
-        .map(|(name, value)| {
-            let value = value.to_str().unwrap_or("").to_string();
-            (name.to_string(), value)
-        })
-        .collect()
-}
-
-pub(crate) async fn maybe_transform_request_body(
-    http_clients: &ProxyHttpClients,
-    provider: &str,
-    path: &str,
-    transform: FormatTransform,
-    model_hint: Option<&str>,
-    body: ReplayableBody,
-) -> Result<ReplayableBody, RequestError> {
-    if transform == FormatTransform::None {
-        return Ok(body);
-    }
-
-    let Some(bytes) = body
-        .read_bytes_if_small(REQUEST_TRANSFORM_LIMIT_BYTES)
-        .await
-        .map_err(|err| {
-            RequestError::new(
-                StatusCode::BAD_REQUEST,
-                format!("Failed to read request body: {err}"),
-            )
-        })?
-    else {
-        return Err(RequestError::new(
-            StatusCode::PAYLOAD_TOO_LARGE,
-            "Request body is too large to transform.",
-        ));
-    };
-    let inbound_body = ReplayableBody::from_bytes(bytes.clone());
-    log_debug_headers_body(
-        "transform.input",
-        None,
-        Some(&inbound_body),
-        DEBUG_BODY_LOG_LIMIT_BYTES,
-    )
-    .await;
-
-    let outbound_bytes = if should_use_antigravity_claude(provider, path, transform) {
-        antigravity_compat::claude_request_to_antigravity(&bytes, model_hint)
-            .map_err(|message| RequestError::new(StatusCode::BAD_REQUEST, message))?
-    } else {
-        transform_request_body(transform, &bytes, http_clients, model_hint)
-            .await
-            .map_err(|message| RequestError::new(StatusCode::BAD_REQUEST, message))?
-    };
-    let outbound_body = ReplayableBody::from_bytes(outbound_bytes);
-    log_debug_headers_body(
-        "transform.output",
-        None,
-        Some(&outbound_body),
-        DEBUG_BODY_LOG_LIMIT_BYTES,
-    )
-    .await;
-    Ok(outbound_body)
-}
-
-fn should_use_antigravity_claude(provider: &str, path: &str, transform: FormatTransform) -> bool {
-    // Align Antigravity with CLIProxyAPIPlus for Claude /v1/messages.
-    provider == PROVIDER_ANTIGRAVITY
-        && path == ANTHROPIC_MESSAGES_PREFIX
-        && transform == FormatTransform::AnthropicToGemini
-}
-
-pub(crate) async fn maybe_force_openai_stream_options_include_usage(
-    provider: &str,
-    outbound_path: &str,
-    meta: &RequestMeta,
-    body: ReplayableBody,
-) -> Result<ReplayableBody, RequestError> {
-    if provider != PROVIDER_CHAT || outbound_path != CHAT_PATH || !meta.stream {
-        return Ok(body);
-    }
-
-    let Some(bytes) = body
-        .read_bytes_if_small(REQUEST_TRANSFORM_LIMIT_BYTES)
-        .await
-        .map_err(|err| {
-            RequestError::new(
-                StatusCode::BAD_REQUEST,
-                format!("Failed to read request body: {err}"),
-            )
-        })?
-    else {
-        // Best-effort: request body too large, keep original.
-        return Ok(body);
-    };
-
-    let Ok(mut value) = serde_json::from_slice::<Value>(&bytes) else {
-        return Ok(body);
-    };
-    let Some(object) = value.as_object_mut() else {
-        return Ok(body);
-    };
-    if !ensure_stream_options_include_usage(object) {
-        return Ok(body);
-    }
-
-    let outbound_bytes = serde_json::to_vec(&value)
-        .map(Bytes::from)
-        .map_err(|err| {
-            RequestError::new(
-                StatusCode::BAD_REQUEST,
-                format!("Failed to serialize request: {err}"),
-            )
-        })?;
-    let outbound_body = ReplayableBody::from_bytes(outbound_bytes);
-    log_debug_headers_body(
-        "stream_options.output",
-        None,
-        Some(&outbound_body),
-        DEBUG_BODY_LOG_LIMIT_BYTES,
-    )
-    .await;
-    Ok(outbound_body)
-}
-
-pub(crate) async fn maybe_rewrite_openai_reasoning_effort_from_model_suffix(
-    provider: &str,
-    outbound_path: &str,
-    meta: &RequestMeta,
-    body: &ReplayableBody,
-) -> Result<Option<ReplayableBody>, RequestError> {
-    let Some(effort) = meta.reasoning_effort.as_deref() else {
-        return Ok(None);
-    };
-    if !should_apply_openai_reasoning_effort(provider, outbound_path) {
-        return Ok(None);
-    }
-
-    let Some(bytes) = body
-        .read_bytes_if_small(REQUEST_TRANSFORM_LIMIT_BYTES)
-        .await
-        .map_err(|err| {
-            RequestError::new(
-                StatusCode::BAD_REQUEST,
-                format!("Failed to read request body: {err}"),
-            )
-        })?
-    else {
-        // Best-effort: request body too large, keep original.
-        return Ok(None);
-    };
-
-    let Ok(mut value) = serde_json::from_slice::<Value>(&bytes) else {
-        return Ok(None);
-    };
-    let Some(object) = value.as_object_mut() else {
-        return Ok(None);
-    };
-
-    let model_for_upstream = meta
-        .mapped_model
-        .as_deref()
-        .or(meta.original_model.as_deref());
-    apply_openai_reasoning_effort_to_body(
-        provider,
-        outbound_path,
-        model_for_upstream,
-        effort,
-        object,
-    );
-
-    let outbound_bytes = serde_json::to_vec(&value)
-        .map(Bytes::from)
-        .map_err(|err| {
-            RequestError::new(
-                StatusCode::BAD_REQUEST,
-                format!("Failed to serialize request: {err}"),
-            )
-        })?;
-    Ok(Some(ReplayableBody::from_bytes(outbound_bytes)))
-}
-
-fn should_apply_openai_reasoning_effort(provider: &str, outbound_path: &str) -> bool {
-    (provider == PROVIDER_CHAT && outbound_path == CHAT_PATH)
-        || (provider == PROVIDER_RESPONSES && outbound_path == RESPONSES_PATH)
-}
-
-fn apply_openai_reasoning_effort_to_body(
-    provider: &str,
-    outbound_path: &str,
-    normalized_model: Option<&str>,
-    effort: &str,
-    object: &mut Map<String, Value>,
-) {
-    // Ensure the upstream sees the normalized base model (without the `-reasoning-...` suffix).
-    if let Some(model) = normalized_model {
-        object.insert("model".to_string(), Value::String(model.to_string()));
-    }
-
-    if provider == PROVIDER_CHAT && outbound_path == CHAT_PATH {
-        object.insert(
-            "reasoning_effort".to_string(),
-            Value::String(effort.to_string()),
-        );
-        return;
-    }
-    if provider == PROVIDER_RESPONSES && outbound_path == RESPONSES_PATH {
-        let reasoning = ensure_json_object_field(object, "reasoning");
-        reasoning.insert("effort".to_string(), Value::String(effort.to_string()));
-    }
-}
-
-fn ensure_json_object_field<'a>(
-    object: &'a mut Map<String, Value>,
-    key: &str,
-) -> &'a mut Map<String, Value> {
-    if !matches!(object.get(key), Some(Value::Object(_))) {
-        object.insert(key.to_string(), Value::Object(Map::new()));
-    }
-    object
-        .get_mut(key)
-        .and_then(Value::as_object_mut)
-        .expect("inserted value must be object")
-}
-
-#[cfg(test)]
-#[path = "server_helpers.test.rs"]
-mod tests;
diff --git a/src-tauri/src/proxy/server_helpers.test.rs b/src-tauri/src/proxy/server_helpers.test.rs
deleted file mode 100644
index 1329e37..0000000
--- a/src-tauri/src/proxy/server_helpers.test.rs
+++ /dev/null
@@ -1,163 +0,0 @@
-use super::*;
-
-use axum::body::Bytes;
-
-#[test]
-fn force_openai_chat_stream_usage_inserts_stream_options_include_usage() {
-    let rt = tokio::runtime::Runtime::new().expect("runtime");
-    rt.block_on(async {
-        let input = Bytes::from_static(br#"{"stream":true,"messages":[]}"#);
-        let meta = RequestMeta {
-            stream: true,
-            original_model: None,
-            mapped_model: None,
-            reasoning_effort: None,
-            estimated_input_tokens: None,
-        };
-        let body = ReplayableBody::from_bytes(input);
-        let output =
-            maybe_force_openai_stream_options_include_usage(PROVIDER_CHAT, CHAT_PATH, &meta, body)
-                .await
-                .expect("ok");
-        let bytes = output
-            .read_bytes_if_small(1024)
-            .await
-            .expect("read")
-            .expect("bytes");
-        let value: Value = serde_json::from_slice(&bytes).expect("json");
-        assert_eq!(value["stream_options"]["include_usage"], Value::Bool(true));
-    });
-}
-
-#[test]
-fn gemini_meta_prefers_path_for_stream_and_model() {
-    let rt = tokio::runtime::Runtime::new().expect("runtime");
-    rt.block_on(async {
-        let body = ReplayableBody::from_bytes(Bytes::from_static(b"{}"));
-        let meta = parse_request_meta_best_effort(
-            "/v1beta/models/gemini-1.5-flash:streamGenerateContent",
-            &body,
-        )
-        .await;
-        assert!(meta.stream);
-        assert_eq!(meta.original_model.as_deref(),
Some("gemini-1.5-flash")); - }); -} - -#[test] -fn meta_parses_reasoning_suffix_and_strips_model() { - let rt = tokio::runtime::Runtime::new().expect("runtime"); - rt.block_on(async { - let body = ReplayableBody::from_bytes(Bytes::from_static( - br#"{"model":"gpt-4.1-reasoning-high","messages":[]}"#, - )); - let meta = parse_request_meta_best_effort(CHAT_PATH, &body).await; - assert_eq!(meta.original_model.as_deref(), Some("gpt-4.1")); - assert_eq!(meta.reasoning_effort.as_deref(), Some("high")); - }); -} - -#[test] -fn apply_reasoning_suffix_for_chat_sets_reasoning_effort_and_model() { - let rt = tokio::runtime::Runtime::new().expect("runtime"); - rt.block_on(async { - let meta = RequestMeta { - stream: false, - original_model: Some("gpt-4.1".to_string()), - mapped_model: None, - reasoning_effort: Some("high".to_string()), - estimated_input_tokens: None, - }; - let body = ReplayableBody::from_bytes(Bytes::from_static( - br#"{"model":"gpt-4.1-reasoning-high","messages":[]}"#, - )); - let rewritten = maybe_rewrite_openai_reasoning_effort_from_model_suffix( - PROVIDER_CHAT, - CHAT_PATH, - &meta, - &body, - ) - .await - .expect("ok") - .expect("should rewrite"); - let bytes = rewritten - .read_bytes_if_small(1024) - .await - .expect("read") - .expect("bytes"); - let value: Value = serde_json::from_slice(&bytes).expect("json"); - assert_eq!(value["model"], Value::String("gpt-4.1".to_string())); - assert_eq!(value["reasoning_effort"], Value::String("high".to_string())); - }); -} - -#[test] -fn apply_reasoning_suffix_for_responses_sets_reasoning_object_and_model() { - let rt = tokio::runtime::Runtime::new().expect("runtime"); - rt.block_on(async { - let meta = RequestMeta { - stream: false, - original_model: Some("gpt-4.1".to_string()), - mapped_model: None, - reasoning_effort: Some("high".to_string()), - estimated_input_tokens: None, - }; - let body = ReplayableBody::from_bytes(Bytes::from_static( - br#"{"model":"gpt-4.1-reasoning-high","input":"hi"}"#, - )); - let rewritten = maybe_rewrite_openai_reasoning_effort_from_model_suffix( - PROVIDER_RESPONSES, - RESPONSES_PATH, - &meta, - &body, - ) - .await - .expect("ok") - .expect("should rewrite"); - let bytes = rewritten - .read_bytes_if_small(1024) - .await - .expect("read") - .expect("bytes"); - let value: Value = serde_json::from_slice(&bytes).expect("json"); - assert_eq!(value["model"], Value::String("gpt-4.1".to_string())); - assert_eq!( - value["reasoning"]["effort"], - Value::String("high".to_string()) - ); - }); -} - -#[test] -fn apply_reasoning_suffix_prefers_mapped_model_as_upstream_model() { - let rt = tokio::runtime::Runtime::new().expect("runtime"); - rt.block_on(async { - let meta = RequestMeta { - stream: false, - original_model: Some("gpt-4.1".to_string()), - mapped_model: Some("o3-mini".to_string()), - reasoning_effort: Some("high".to_string()), - estimated_input_tokens: None, - }; - let body = ReplayableBody::from_bytes(Bytes::from_static( - br#"{"model":"gpt-4.1-reasoning-high","messages":[]}"#, - )); - let rewritten = maybe_rewrite_openai_reasoning_effort_from_model_suffix( - PROVIDER_CHAT, - CHAT_PATH, - &meta, - &body, - ) - .await - .expect("ok") - .expect("should rewrite"); - let bytes = rewritten - .read_bytes_if_small(1024) - .await - .expect("read") - .expect("bytes"); - let value: Value = serde_json::from_slice(&bytes).expect("json"); - assert_eq!(value["model"], Value::String("o3-mini".to_string())); - assert_eq!(value["reasoning_effort"], Value::String("high".to_string())); - }); -} diff --git 
a/src-tauri/src/proxy/service.rs b/src-tauri/src/proxy/service.rs
deleted file mode 100644
index fd1d673..0000000
--- a/src-tauri/src/proxy/service.rs
+++ /dev/null
@@ -1,371 +0,0 @@
-use serde::Serialize;
-use sqlx::SqlitePool;
-use std::future::IntoFuture;
-use std::sync::Arc;
-use std::time::{Duration, Instant};
-use tauri::{AppHandle, Manager};
-use tokio::sync::{Mutex, RwLock};
-use tokio::task::JoinHandle;
-use tokio::time::timeout;
-
-use super::config::ProxyConfig;
-use super::log::LogWriter;
-use super::request_detail::RequestDetailCapture;
-use super::sqlite;
-use super::server;
-use super::ProxyState;
-use crate::logging::LoggingState;
-
-/// Default graceful-shutdown wait; once it elapses, the server task is force-aborted.
-const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
-
-type ProxyStateHandle = Arc<RwLock<Arc<ProxyState>>>;
-type ProxyRouter = axum::Router;
-
-#[derive(Clone)]
-pub(crate) struct ProxyServiceHandle {
-    inner: Arc<ProxyService>,
-}
-
-impl ProxyServiceHandle {
-    pub(crate) fn new() -> Self {
-        Self {
-            inner: Arc::new(ProxyService::new()),
-        }
-    }
-
-    pub(crate) async fn status(&self) -> ProxyServiceStatus {
-        self.inner.status().await
-    }
-
-    pub(crate) async fn start(&self, app: AppHandle) -> Result<ProxyServiceStatus, String> {
-        self.inner.start(app).await
-    }
-
-    pub(crate) async fn stop(&self) -> Result<ProxyServiceStatus, String> {
-        self.inner.stop().await
-    }
-
-    pub(crate) async fn restart(&self, app: AppHandle) -> Result<ProxyServiceStatus, String> {
-        self.inner.restart(app).await
-    }
-
-    pub(crate) async fn reload(&self, app: AppHandle) -> Result<ProxyServiceStatus, String> {
-        self.inner.reload(app).await
-    }
-}
-
-#[derive(Clone, Serialize, Debug)]
-#[serde(rename_all = "snake_case")]
-pub(crate) enum ProxyServiceState {
-    Running,
-    Stopped,
-}
-
-#[derive(Clone, Serialize)]
-pub(crate) struct ProxyServiceStatus {
-    pub(crate) state: ProxyServiceState,
-    pub(crate) addr: Option<String>,
-    pub(crate) last_error: Option<String>,
-}
-
-impl ProxyServiceStatus {
-    fn stopped(last_error: Option<String>) -> Self {
-        Self {
-            state: ProxyServiceState::Stopped,
-            addr: None,
-            last_error,
-        }
-    }
-
-    fn running(addr: String, last_error: Option<String>) -> Self {
-        Self {
-            state: ProxyServiceState::Running,
-            addr: Some(addr),
-            last_error,
-        }
-    }
-}
-
-struct ProxyService {
-    inner: Mutex<ProxyServiceInner>,
-}
-
-impl ProxyService {
-    fn new() -> Self {
-        Self {
-            inner: Mutex::new(ProxyServiceInner::new()),
-        }
-    }
-
-    async fn status(&self) -> ProxyServiceStatus {
-        let mut inner = self.inner.lock().await;
-        inner.refresh_if_finished().await;
-        inner.status()
-    }
-
-    async fn start(&self, app: AppHandle) -> Result<ProxyServiceStatus, String> {
-        let mut inner = self.inner.lock().await;
-        inner.refresh_if_finished().await;
-        inner.start(app).await?;
-        Ok(inner.status())
-    }
-
-    async fn stop(&self) -> Result<ProxyServiceStatus, String> {
-        let mut inner = self.inner.lock().await;
-        inner.refresh_if_finished().await;
-        inner.stop().await?;
-        Ok(inner.status())
-    }
-
-    async fn restart(&self, app: AppHandle) -> Result<ProxyServiceStatus, String> {
-        let mut inner = self.inner.lock().await;
-        inner.refresh_if_finished().await;
-        inner.restart(app).await?;
-        Ok(inner.status())
-    }
-
-    async fn reload(&self, app: AppHandle) -> Result<ProxyServiceStatus, String> {
-        let mut inner = self.inner.lock().await;
-        inner.refresh_if_finished().await;
-        inner.reload(app).await?;
-        Ok(inner.status())
-    }
-}
-
-struct ProxyServiceInner {
-    running: Option<RunningProxy>,
-    sqlite_pool: Option<SqlitePool>,
-    last_error: Option<String>,
-}
-
-impl ProxyServiceInner {
-    fn new() -> Self {
-        Self {
-            running: None,
-            sqlite_pool: None,
-            last_error: None,
-        }
-    }
-
-    fn status(&self) -> ProxyServiceStatus {
-        match &self.running {
-            Some(running) => ProxyServiceStatus::running(running.addr.clone(), self.last_error.clone()),
-            None => ProxyServiceStatus::stopped(self.last_error.clone()),
-        }
-    }
-
-    async fn refresh_if_finished(&mut self) {
-        let Some(running) = self.running.as_mut() else {
-            return;
-        };
-        let Some(task) = running.task.as_ref() else {
-            return;
-        };
-        if !task.is_finished() {
-            return;
-        }
-        let running = self.running.take().expect("running must exist");
-        self.finish_task(running).await;
-    }
-
-    async fn start(&mut self, app: AppHandle) -> Result<(), String> {
-        if self.running.is_some() {
-            return Ok(());
-        }
-        if self.sqlite_pool.is_none() {
-            self.sqlite_pool = sqlite::open_write_pool(&app).await.ok();
-        }
-        let sqlite_pool = self.sqlite_pool.clone();
-        let loaded_config = ProxyConfig::load(&app).await?;
-        let addr = loaded_config.addr();
-
-        let (state_handle, router) = build_router_state(&app, loaded_config, sqlite_pool).await?;
-        let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
-        let listener = tokio::net::TcpListener::bind(&addr)
-            .await
-            .map_err(|err| format!("Failed to bind {addr}: {err}"))?;
-        tracing::info!(addr = %addr, "proxy listening");
-
-        let task = tokio::spawn(async move {
-            axum::serve(listener, router)
-                .with_graceful_shutdown(async move {
-                    let _ = shutdown_rx.await;
-                })
-                .into_future()
-                .await
-                .map_err(|err| format!("Proxy server failed: {err}"))
-        });
-
-        self.running = Some(RunningProxy {
-            addr,
-            state_handle,
-            shutdown_tx: Some(shutdown_tx),
-            task: Some(task),
-            shutdown_timeout: DEFAULT_SHUTDOWN_TIMEOUT,
-        });
-        self.last_error = None;
-        Ok(())
-    }
-
-    async fn stop(&mut self) -> Result<(), String> {
-        let Some(running) = self.running.take() else {
-            return Ok(());
-        };
-        self.finish_task(running).await;
-        Ok(())
-    }
-
-    async fn restart(&mut self, app: AppHandle) -> Result<(), String> {
-        self.stop().await?;
-        self.start(app).await
-    }
-
-    async fn reload(&mut self, app: AppHandle) -> Result<(), String> {
-        tracing::debug!("proxy reload start");
-        let start = Instant::now();
-        if self.running.is_none() {
-            tracing::debug!("proxy reload: not running, start instead");
-            return self.start(app).await;
-        }
-        let loaded_config = ProxyConfig::load(&app).await?;
-        let addr = loaded_config.addr();
-        let current_addr = self
-            .running
-            .as_ref()
-            .map(|running| running.addr.as_str())
-            .unwrap_or_default()
-            .to_string();
-
-        tracing::debug!(addr = %addr, current_addr = %current_addr, "proxy reload config loaded");
-        if addr != current_addr {
-            // A host/port change cannot hot-swap the listen address; fall back to a safe restart.
-            tracing::info!(
-                addr = %addr,
-                current_addr = %current_addr,
-                "proxy reload detected addr change, restarting"
-            );
-            return self.restart(app).await;
-        }
-        let current_max_request_body_bytes = if let Some(running) = self.running.as_ref() {
-            let guard = running.state_handle.read().await;
-            guard.config.max_request_body_bytes
-        } else {
-            loaded_config.max_request_body_bytes
-        };
-        if loaded_config.max_request_body_bytes != current_max_request_body_bytes {
-            tracing::info!(
-                new_max_request_body_bytes = loaded_config.max_request_body_bytes,
-                current_max_request_body_bytes = current_max_request_body_bytes,
-                "proxy reload detected body limit change, restarting"
-            );
-            return self.restart(app).await;
-        }
-
-        let sqlite_pool = self.sqlite_pool.clone();
-        let new_state = build_proxy_state(&app, loaded_config, sqlite_pool).await?;
-        let Some(running) = self.running.as_ref() else {
-            tracing::debug!("proxy reload: running cleared before swap");
-            return Ok(());
-        };
-        {
-            let mut guard = running.state_handle.write().await;
-            *guard = new_state;
-        }
-        tracing::debug!(elapsed_ms = start.elapsed().as_millis(), "proxy reload applied");
-        Ok(())
-    }
-
-    async fn finish_task(&mut self, mut running: RunningProxy) {
-        if let Some(tx) = running.shutdown_tx.take() {
-            let _ = tx.send(());
-        }
-        if let Some(task) = running.task.take() {
-            self.await_stop(task, running.shutdown_timeout).await;
-        }
-    }
-
-    async fn await_stop(&mut self, task: JoinHandle<Result<(), String>>, timeout_duration: Duration) {
-        let mut task = task;
-        match timeout(timeout_duration, &mut task).await {
-            Ok(Ok(Ok(()))) => {}
-            Ok(Ok(Err(message))) => {
-                self.last_error = Some(message);
-            }
-            Ok(Err(err)) => {
-                self.last_error = Some(format!("Proxy task join failed: {err}"));
-            }
-            Err(_) => {
-                task.abort();
-                self.last_error = Some("Proxy stop timed out; aborted.".to_string());
-            }
-        }
-    }
-}
-
-struct RunningProxy {
-    addr: String,
-    state_handle: ProxyStateHandle,
-    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
-    task: Option<JoinHandle<Result<(), String>>>,
-    shutdown_timeout: Duration,
-}
-
-async fn build_router_state(
-    app: &AppHandle,
-    config: ProxyConfig,
-    sqlite_pool: Option<SqlitePool>,
-) -> Result<(ProxyStateHandle, ProxyRouter), String> {
-    let state = build_proxy_state(app, config, sqlite_pool).await?;
-    let max_request_body_bytes = state.config.max_request_body_bytes;
-    let state_handle = Arc::new(RwLock::new(state));
-    let router =
-        server::build_router(state_handle.clone(), max_request_body_bytes).with_state::<()>(
-            state_handle.clone(),
-        );
-    Ok((state_handle.clone(), router))
-}
-
-async fn build_proxy_state(
-    app: &AppHandle,
-    config: ProxyConfig,
-    sqlite_pool: Option<SqlitePool>,
-) -> Result<Arc<ProxyState>, String> {
-    if let Some(logging_state) = app.try_state::<LoggingState>() {
-        logging_state.apply_level(config.log_level);
-    }
-    let log = Arc::new(LogWriter::new(sqlite_pool));
-    let http_clients = super::http_client::ProxyHttpClients::new()?;
-    let cursors = server::build_upstream_cursors(&config);
-    let request_detail = app
-        .state::<Arc<RequestDetailCapture>>()
-        .inner()
-        .clone();
-    let token_rate = app
-        .state::>()
-        .inner()
-        .clone();
-    let kiro_accounts = app
-        .state::>()
-        .inner()
-        .clone();
-    let codex_accounts = app
-        .state::>()
-        .inner()
-        .clone();
-    let antigravity_accounts = app
-        .state::>()
-        .inner()
-        .clone();
-    Ok(Arc::new(ProxyState {
-        config,
-        http_clients,
-        log,
-        cursors,
-        request_detail,
-        token_rate,
-        kiro_accounts,
-        codex_accounts,
-        antigravity_accounts,
-    }))
-}
diff --git a/src-tauri/src/proxy/sqlite.rs b/src-tauri/src/proxy/sqlite.rs
deleted file mode 100644
index cd2f272..0000000
--- a/src-tauri/src/proxy/sqlite.rs
+++ /dev/null
@@ -1,170 +0,0 @@
-use sqlx::Row;
-use sqlx::{
-    sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions, SqliteSynchronous},
-    SqlitePool,
-};
-use std::path::PathBuf;
-use std::time::Duration;
-use tauri::AppHandle;
-use tokio::sync::OnceCell;
-
-use super::config;
-
-const DB_FILE_NAME: &str = "data.db";
-
-struct SqlitePools {
-    read: SqlitePool,
-    write: SqlitePool,
-}
-
-// Initialize only once so each refresh does not rebuild the pools or rerun schema/index checks.
-static SQLITE_POOLS: OnceCell<SqlitePools> = OnceCell::const_new();
-
-pub(crate) async fn open_read_pool(app: &AppHandle) -> Result<SqlitePool, String> {
-    let pools = open_pools(app).await?;
-    Ok(pools.read.clone())
-}
-
-pub(crate) async fn open_write_pool(app: &AppHandle) -> Result<SqlitePool, String> {
-    let pools = open_pools(app).await?;
-    Ok(pools.write.clone())
-}
-
-async fn open_pools(app: &AppHandle) -> Result<&'static SqlitePools, String> {
-    let app = app.clone();
-    SQLITE_POOLS
-        .get_or_try_init(|| async move {
-            let path = usage_db_path(&app)?;
-            if let Some(parent) = path.parent() {
-                tokio::fs::create_dir_all(parent)
-                    .await
-                    .map_err(|err| format!("Failed to create db directory: {err}"))?;
-            }
-            let read = connect_pool(&path).await?;
-            init_schema(&read).await?;
-            let write = connect_pool(&path).await?;
-            init_schema(&write).await?;
-            Ok(SqlitePools { read, write })
-        })
-        .await
-}
-
-fn usage_db_path(app: &AppHandle) -> Result<PathBuf, String> {
-    Ok(config::config_dir_path(app)?.join(DB_FILE_NAME))
-}
-
-async fn connect_pool(path: &PathBuf) -> Result<SqlitePool, String> {
-    let options = SqliteConnectOptions::new()
-        .filename(path)
-        .create_if_missing(true)
-        .journal_mode(SqliteJournalMode::Wal)
-        .synchronous(SqliteSynchronous::Normal)
-        .busy_timeout(Duration::from_secs(5));
-
-    SqlitePoolOptions::new()
-        .max_connections(1)
-        .connect_with(options)
-        .await
-        .map_err(|err| format!("Failed to connect sqlite: {err}"))
-}
-
-pub(crate) async fn init_schema(pool: &SqlitePool) -> Result<(), String> {
-    sqlx::query(
-        r#"
-CREATE TABLE IF NOT EXISTS request_logs (
-    id INTEGER PRIMARY KEY AUTOINCREMENT,
-    ts_ms INTEGER NOT NULL,
-    path TEXT NOT NULL,
-    provider TEXT NOT NULL,
-    upstream_id TEXT NOT NULL,
-    model TEXT,
-    mapped_model TEXT,
-    stream INTEGER NOT NULL,
-    status INTEGER NOT NULL,
-    input_tokens INTEGER,
-    output_tokens INTEGER,
-    total_tokens INTEGER,
-    cached_tokens INTEGER,
-    usage_json TEXT,
-    upstream_request_id TEXT,
-    request_headers TEXT,
-    request_body TEXT,
-    response_error TEXT,
-    latency_ms INTEGER NOT NULL
-);
-"#,
-    )
-    .execute(pool)
-    .await
-    .map_err(|err| format!("Failed to create request_logs table: {err}"))?;
-
-    ensure_request_logs_columns(pool).await?;
-
-    sqlx::query("CREATE INDEX IF NOT EXISTS idx_request_logs_ts_ms ON request_logs(ts_ms);")
-        .execute(pool)
-        .await
-        .map_err(|err| format!("Failed to create idx_request_logs_ts_ms: {err}"))?;
-
-    sqlx::query(
-        "CREATE INDEX IF NOT EXISTS idx_request_logs_provider_ts_ms ON request_logs(provider, ts_ms);",
-    )
-    .execute(pool)
-    .await
-    .map_err(|err| format!("Failed to create idx_request_logs_provider_ts_ms: {err}"))?;
-
-    Ok(())
-}
-
-async fn ensure_request_logs_columns(pool: &SqlitePool) -> Result<(), String> {
-    let columns = sqlx::query("PRAGMA table_info(request_logs);")
-        .fetch_all(pool)
-        .await
-        .map_err(|err| format!("Failed to read request_logs schema: {err}"))?
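-        // PRAGMA table_info lists the columns that already exist; anything
-        // missing is added below with ALTER TABLE, keeping the migration
-        // additive and idempotent across startups.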
-        .into_iter()
-        .filter_map(|row| row.try_get::<String, _>("name").ok())
-        .collect::<std::collections::HashSet<String>>();
-
-    if !columns.contains("cached_tokens") {
-        sqlx::query("ALTER TABLE request_logs ADD COLUMN cached_tokens INTEGER;")
-            .execute(pool)
-            .await
-            .map_err(|err| format!("Failed to add cached_tokens column: {err}"))?;
-    }
-
-    if !columns.contains("mapped_model") {
-        sqlx::query("ALTER TABLE request_logs ADD COLUMN mapped_model TEXT;")
-            .execute(pool)
-            .await
-            .map_err(|err| format!("Failed to add mapped_model column: {err}"))?;
-    }
-
-    if !columns.contains("usage_json") {
-        sqlx::query("ALTER TABLE request_logs ADD COLUMN usage_json TEXT;")
-            .execute(pool)
-            .await
-            .map_err(|err| format!("Failed to add usage_json column: {err}"))?;
-    }
-
-    if !columns.contains("request_headers") {
-        sqlx::query("ALTER TABLE request_logs ADD COLUMN request_headers TEXT;")
-            .execute(pool)
-            .await
-            .map_err(|err| format!("Failed to add request_headers column: {err}"))?;
-    }
-
-    if !columns.contains("request_body") {
-        sqlx::query("ALTER TABLE request_logs ADD COLUMN request_body TEXT;")
-            .execute(pool)
-            .await
-            .map_err(|err| format!("Failed to add request_body column: {err}"))?;
-    }
-
-    if !columns.contains("response_error") {
-        sqlx::query("ALTER TABLE request_logs ADD COLUMN response_error TEXT;")
-            .execute(pool)
-            .await
-            .map_err(|err| format!("Failed to add response_error column: {err}"))?;
-    }
-
-    Ok(())
-}
diff --git a/src-tauri/src/proxy/sse.rs b/src-tauri/src/proxy/sse.rs
deleted file mode 100644
index 2e2726a..0000000
--- a/src-tauri/src/proxy/sse.rs
+++ /dev/null
@@ -1,64 +0,0 @@
-pub(crate) struct SseEventParser {
-    buffer: String,
-    current_data: String,
-}
-
-impl SseEventParser {
-    pub(crate) fn new() -> Self {
-        Self {
-            buffer: String::new(),
-            current_data: String::new(),
-        }
-    }
-
-    pub(crate) fn push_chunk<F: FnMut(String)>(&mut self, chunk: &[u8], mut on_event: F) {
-        let text = String::from_utf8_lossy(chunk);
-        self.buffer.push_str(&text);
-        while let Some(pos) = self.buffer.find('\n') {
-            let mut line = self.buffer[..pos].to_string();
-            self.buffer.drain(..=pos);
-            if line.ends_with('\r') {
-                line.pop();
-            }
-            self.process_line(&line, &mut on_event);
-        }
-    }
-
-    pub(crate) fn finish<F: FnMut(String)>(&mut self, mut on_event: F) {
-        if !self.buffer.is_empty() {
-            let mut buffer = std::mem::take(&mut self.buffer);
-            if buffer.ends_with('\r') {
-                buffer.pop();
-            }
-            self.process_line(&buffer, &mut on_event);
-        }
-        self.flush_event(&mut on_event);
-    }
-
-    fn process_line<F: FnMut(String)>(&mut self, line: &str, on_event: &mut F) {
-        if line.is_empty() {
-            self.flush_event(on_event);
-            return;
-        }
-        if let Some(data) = line.strip_prefix("data:") {
-            let data = data.trim_start();
-            if !self.current_data.is_empty() {
-                self.current_data.push('\n');
-            }
-            self.current_data.push_str(data);
-        }
-    }
-
-    fn flush_event<F: FnMut(String)>(&mut self, on_event: &mut F) {
-        if self.current_data.is_empty() {
-            return;
-        }
-        let data = std::mem::take(&mut self.current_data);
-        let data = data.trim();
-        if data.is_empty() {
-            return;
-        }
-        on_event(data.to_string());
-    }
-}
-
diff --git a/src-tauri/src/proxy/token_estimator.rs b/src-tauri/src/proxy/token_estimator.rs
deleted file mode 100644
index cf7d442..0000000
--- a/src-tauri/src/proxy/token_estimator.rs
+++ /dev/null
@@ -1,259 +0,0 @@
-use std::sync::OnceLock;
-
-use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
-pub(crate) enum TokenProvider {
-    OpenAI,
-    Gemini,
-    Claude,
-}
-
-#[derive(Clone, Copy)]
-struct Multipliers {
-    word: f64,
-    number: f64,
-    cjk: f64,
-    symbol: f64,
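-    // Per-class token costs: `word`/`number` are charged once per contiguous
-    // run, the other classes once per character (see
-    // estimate_text_tokens_by_provider below). Illustrative arithmetic with
-    // the Claude weights: "Hi there" = 2 word runs x 1.13 + 1 space x 0.39
-    // = 2.65, ceil -> 3 tokens.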
-    math_symbol: f64,
-    url_delim: f64,
-    at_sign: f64,
-    emoji: f64,
-    newline: f64,
-    space: f64,
-    base_pad: f64,
-}
-
-const MULTIPLIERS_OPENAI: Multipliers = Multipliers {
-    word: 1.02,
-    number: 1.55,
-    cjk: 0.85,
-    symbol: 0.4,
-    math_symbol: 2.68,
-    url_delim: 1.0,
-    at_sign: 2.0,
-    emoji: 2.12,
-    newline: 0.5,
-    space: 0.42,
-    base_pad: 0.0,
-};
-
-const MULTIPLIERS_GEMINI: Multipliers = Multipliers {
-    word: 1.15,
-    number: 2.8,
-    cjk: 0.68,
-    symbol: 0.38,
-    math_symbol: 1.05,
-    url_delim: 1.2,
-    at_sign: 2.5,
-    emoji: 1.08,
-    newline: 1.15,
-    space: 0.2,
-    base_pad: 0.0,
-};
-
-const MULTIPLIERS_CLAUDE: Multipliers = Multipliers {
-    word: 1.13,
-    number: 1.63,
-    cjk: 1.21,
-    symbol: 0.4,
-    math_symbol: 4.52,
-    url_delim: 1.26,
-    at_sign: 2.82,
-    emoji: 2.6,
-    newline: 0.89,
-    space: 0.39,
-    base_pad: 0.0,
-};
-
-pub(crate) fn provider_for_model(model: Option<&str>) -> TokenProvider {
-    let Some(model) = model else {
-        return TokenProvider::OpenAI;
-    };
-    let normalized = model.trim().to_ascii_lowercase();
-    if normalized.contains("gemini") {
-        return TokenProvider::Gemini;
-    }
-    if normalized.contains("claude") {
-        return TokenProvider::Claude;
-    }
-    TokenProvider::OpenAI
-}
-
-pub(crate) fn estimate_text_tokens(model: Option<&str>, text: &str) -> u64 {
-    if text.is_empty() {
-        return 0;
-    }
-    let provider = provider_for_model(model);
-    if provider == TokenProvider::OpenAI {
-        return estimate_text_tokens_openai(model, text);
-    }
-    estimate_text_tokens_by_provider(provider, text)
-}
-
-fn estimate_text_tokens_openai(model: Option<&str>, text: &str) -> u64 {
-    let bpe = bpe_for_model(model);
-    bpe.encode_with_special_tokens(text).len() as u64
-}
-
-fn estimate_text_tokens_by_provider(provider: TokenProvider, text: &str) -> u64 {
-    let multipliers = match provider {
-        TokenProvider::OpenAI => MULTIPLIERS_OPENAI,
-        TokenProvider::Gemini => MULTIPLIERS_GEMINI,
-        TokenProvider::Claude => MULTIPLIERS_CLAUDE,
-    };
-
-    // Estimate the token count from character classes, mirroring new-api's heuristic logic.
-    let mut count = 0.0f64;
-    let mut current_word_type: Option<WordType> = None;
-
-    for ch in text.chars() {
-        if ch.is_whitespace() {
-            current_word_type = None;
-            if ch == '\n' || ch == '\t' {
-                count += multipliers.newline;
-            } else {
-                count += multipliers.space;
-            }
-            continue;
-        }
-
-        if is_cjk(ch) {
-            current_word_type = None;
-            count += multipliers.cjk;
-            continue;
-        }
-
-        if is_emoji(ch) {
-            current_word_type = None;
-            count += multipliers.emoji;
-            continue;
-        }
-
-        if is_latin_or_number(ch) {
-            let new_type = if ch.is_ascii_digit() || ch.is_numeric() {
-                WordType::Number
-            } else {
-                WordType::Latin
-            };
-
-            if current_word_type.is_none() || current_word_type != Some(new_type) {
-                count += match new_type {
-                    WordType::Latin => multipliers.word,
-                    WordType::Number => multipliers.number,
-                };
-                current_word_type = Some(new_type);
-            }
-            continue;
-        }
-
-        current_word_type = None;
-        if is_math_symbol(ch) {
-            count += multipliers.math_symbol;
-        } else if ch == '@' {
-            count += multipliers.at_sign;
-        } else if is_url_delim(ch) {
-            count += multipliers.url_delim;
-        } else {
-            count += multipliers.symbol;
-        }
-    }
-
-    let total = count.ceil() + multipliers.base_pad;
-    total as u64
-}
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
-enum WordType {
-    Latin,
-    Number,
-}
-
-fn bpe_for_model(model: Option<&str>) -> &'static CoreBPE {
-    if matches_o200k(model) {
-        static O200K: OnceLock<CoreBPE> = OnceLock::new();
-        return O200K.get_or_init(|| {
-            o200k_base().unwrap_or_else(|_| cl100k_base().expect("cl100k_base"))
-        });
-    }
-
-    static CL100K: OnceLock<CoreBPE> = OnceLock::new();
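-    // Both encoders are built lazily, once per process, and then reused.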
-    CL100K.get_or_init(|| cl100k_base().expect("cl100k_base"))
-}
-
-fn matches_o200k(model: Option<&str>) -> bool {
-    let Some(model) = model else {
-        return false;
-    };
-    let normalized = model.trim().to_ascii_lowercase();
-    if normalized.is_empty() {
-        return false;
-    }
-    normalized.starts_with("o1")
-        || normalized.starts_with("o3")
-        || normalized.starts_with("o4")
-        || normalized.starts_with("gpt-4o")
-        || normalized.starts_with("gpt-4.1")
-}
-
-fn is_latin_or_number(ch: char) -> bool {
-    if ch.is_ascii_alphanumeric() {
-        return true;
-    }
-    ch.is_alphanumeric()
-}
-
-fn is_cjk(ch: char) -> bool {
-    let code = ch as u32;
-    matches!(
-        code,
-        0x3400..=0x4DBF
-            | 0x4E00..=0x9FFF
-            | 0xF900..=0xFAFF
-            | 0x20000..=0x2A6DF
-            | 0x2A700..=0x2B73F
-            | 0x2B740..=0x2B81F
-            | 0x2B820..=0x2CEAF
-            | 0x2CEB0..=0x2EBEF
-            | 0x30000..=0x3134F
-    )
-}
-
-fn is_emoji(ch: char) -> bool {
-    let code = ch as u32;
-    matches!(
-        code,
-        0x1F300..=0x1F5FF
-            | 0x1F600..=0x1F64F
-            | 0x1F680..=0x1F6FF
-            | 0x1F700..=0x1F77F
-            | 0x1F780..=0x1F7FF
-            | 0x1F800..=0x1F8FF
-            | 0x1F900..=0x1F9FF
-            | 0x1FA00..=0x1FAFF
-            | 0x2600..=0x26FF
-            | 0x2700..=0x27BF
-    )
-}
-
-fn is_math_symbol(ch: char) -> bool {
-    let code = ch as u32;
-    matches!(
-        code,
-        0x2200..=0x22FF | 0x27C0..=0x27EF | 0x2980..=0x29FF | 0x2A00..=0x2AFF
-            | 0x2190..=0x21FF | 0x2B00..=0x2BFF
-    ) || matches!(ch, '+' | '-' | '*' | '/' | '=' | '^' | '%')
-}
-
-fn is_url_delim(ch: char) -> bool {
-    matches!(
-        ch,
-        ':' | '/' | '?' | '#' | '[' | ']' | '!' | '$' | '&' | '\''
-            | '(' | ')' | '*' | '+' | ',' | ';' | '='
-    )
-}
-
-// Unit tests are split into a separate file; `#[path]` keeps the `.test.rs` naming convention.
-#[cfg(test)]
-#[path = "token_estimator.test.rs"]
-mod tests;
diff --git a/src-tauri/src/proxy/token_estimator.test.rs b/src-tauri/src/proxy/token_estimator.test.rs
deleted file mode 100644
index 47e33c1..0000000
--- a/src-tauri/src/proxy/token_estimator.test.rs
+++ /dev/null
@@ -1,8 +0,0 @@
-use super::*;
-
-#[test]
-fn estimate_tokens_for_claude_uses_heuristic() {
-    let tokens = estimate_text_tokens(Some("claude-3-opus"), "a");
-    // Claude word multiplier 1.13 -> ceil => 2
-    assert_eq!(tokens, 2);
-}
diff --git a/src-tauri/src/proxy/token_rate.rs b/src-tauri/src/proxy/token_rate.rs
deleted file mode 100644
index a6de195..0000000
--- a/src-tauri/src/proxy/token_rate.rs
+++ /dev/null
@@ -1,428 +0,0 @@
-use std::collections::{HashMap, VecDeque};
-use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
-use std::sync::Arc;
-use std::time::{Duration, Instant};
-
-use tokio::{
-    sync::{watch, Mutex, RwLock},
-    time::{interval, MissedTickBehavior},
-};
-
-const RATE_WINDOW: Duration = Duration::from_secs(1);
-const CLEANUP_INTERVAL: Duration = Duration::from_secs(60);
-// A request window that has not recorded tokens for longer than this is treated as expired, keeping the HashMap from growing without bound.
-const REQUEST_TTL: Duration = Duration::from_secs(300);
-
-#[derive(Clone)]
-pub(crate) struct TokenRateTracker {
-    inner: Arc<TrackerInner>,
-    activity_tx: watch::Sender<u64>,
-}
-
-struct TrackerInner {
-    next_id: AtomicU64,
-    active: AtomicUsize,
-    enabled: AtomicBool,
-    generation: AtomicU64,
-    cleanup_started: AtomicBool,
-    last_cleanup: Mutex<Instant>,
-    requests: RwLock<HashMap<u64, Arc<Mutex<RequestWindow>>>>,
-}
-
-struct RequestWindow {
-    events: VecDeque<TokenEvent>,
-    last_seen: Instant,
-}
-
-struct TokenEvent {
-    ts: Instant,
-    input: u64,
-    output: u64,
-}
-
-#[derive(Debug, Clone, Copy)]
-pub(crate) struct TokenRateSnapshot {
-    pub(crate) input: u64,
-    pub(crate) output: u64,
-    pub(crate) total: u64,
-    pub(crate) connections: u64,
-}
-
-pub(crate) struct RequestTokenTracker {
-    id: Option<u64>,
-    window: Option<Arc<Mutex<RequestWindow>>>,
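-    // `id`/`window` stay None when tracking is disabled (or when the enable
-    // toggle raced this request's registration); such a tracker records nothing.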
-    tracker: TokenRateTracker,
-    model: Option<String>,
-    generation: Option<u64>,
-}
-
-impl TokenRateTracker {
-    pub(crate) fn new() -> Arc<Self> {
-        let (activity_tx, _activity_rx) = watch::channel(0u64);
-        let tracker = Arc::new(Self {
-            inner: Arc::new(TrackerInner {
-                next_id: AtomicU64::new(1),
-                active: AtomicUsize::new(0),
-                enabled: AtomicBool::new(true),
-                generation: AtomicU64::new(1),
-                cleanup_started: AtomicBool::new(false),
-                last_cleanup: Mutex::new(Instant::now()),
-                requests: RwLock::new(HashMap::new()),
-            }),
-            activity_tx,
-        });
-        tracker.try_start_cleanup();
-        tracker
-    }
-
-    pub(crate) fn subscribe_activity(&self) -> watch::Receiver<u64> {
-        self.activity_tx.subscribe()
-    }
-
-    pub(crate) fn notify_activity(&self) {
-        let next = self.activity_tx.borrow().wrapping_add(1);
-        let _ = self.activity_tx.send(next);
-    }
-
-    pub(crate) async fn set_enabled(&self, enabled: bool) {
-        self.try_start_cleanup();
-        tracing::debug!(enabled, "token_rate set_enabled start");
-        let previous = self.inner.enabled.swap(enabled, Ordering::SeqCst);
-        if previous == enabled {
-            tracing::debug!(enabled, "token_rate set_enabled noop");
-            return;
-        }
-        // Bump the generation on every toggle so requests started before the switch cannot keep counting after tracking is re-enabled.
-        self.inner.generation.fetch_add(1, Ordering::SeqCst);
-        if !enabled {
-            tracing::debug!("token_rate set_enabled clearing requests start");
-            let mut guard = self.inner.requests.write().await;
-            guard.clear();
-            self.inner.active.store(0, Ordering::SeqCst);
-            tracing::debug!("token_rate set_enabled clearing requests done");
-        }
-        tracing::debug!(enabled, "token_rate set_enabled done");
-    }
-
-    pub(crate) async fn register(
-        &self,
-        model: Option<String>,
-        input_tokens: Option<u64>,
-    ) -> RequestTokenTracker {
-        self.try_start_cleanup();
-        self.maybe_cleanup(Instant::now()).await;
-        let enabled = self.inner.enabled.load(Ordering::SeqCst);
-        let generation = self.inner.generation.load(Ordering::SeqCst);
-        let (mut id, mut window) = if enabled {
-            let id = self.inner.next_id.fetch_add(1, Ordering::SeqCst);
-            let window = Arc::new(Mutex::new(RequestWindow::new()));
-            let mut guard = self.inner.requests.write().await;
-            guard.insert(id, window.clone());
-            self.inner.active.fetch_add(1, Ordering::SeqCst);
-            (Some(id), Some(window))
-        } else {
-            (None, None)
-        };
-        let mut effective_generation = if enabled { Some(generation) } else { None };
-        if let Some(current_id) = id {
-            let still_enabled = self.inner.enabled.load(Ordering::SeqCst);
-            let current_generation = self.inner.generation.load(Ordering::SeqCst);
-            if !still_enabled || current_generation != generation {
-                // The toggle flipped after registration started: stop tracking this request so it cannot keep counting once tracking is re-enabled.
-                self.unregister(current_id).await;
-                id = None;
-                window = None;
-                effective_generation = None;
-            }
-        }
-
-        let tracker = RequestTokenTracker {
-            id,
-            window,
-            tracker: self.clone(),
-            model,
-            generation: effective_generation,
-        };
-        if let Some(tokens) = input_tokens {
-            tracker.add_input_tokens(tokens).await;
-        }
-        if enabled {
-            self.notify_activity();
-        }
-        tracker
-    }
-
-    pub(crate) async fn snapshot(&self) -> TokenRateSnapshot {
-        self.try_start_cleanup();
-        if !self.inner.enabled.load(Ordering::SeqCst) {
-            return TokenRateSnapshot {
-                input: 0,
-                output: 0,
-                total: 0,
-                connections: 0,
-            };
-        }
-        self.maybe_cleanup(Instant::now()).await;
-        let now = Instant::now();
-        let windows: Vec<Arc<Mutex<RequestWindow>>> = self
-            .inner
-            .requests
-            .read()
-            .await
-            .values()
-            .cloned()
-            .collect();
-        let mut input = 0u64;
-        let mut output = 0u64;
-        for window in windows {
-            let mut guard = window.lock().await;
-            guard.prune(now);
-            let (i, o) = guard.sum();
-            input = input.saturating_add(i);
-            output = output.saturating_add(o);
-        }
-        TokenRateSnapshot {
-            input,
-            output,
-            total: input.saturating_add(output),
-            connections: self.inner.active.load(Ordering::SeqCst) as u64,
-        }
-    }
-
-    pub(crate) fn has_active_requests(&self) -> bool {
-        if !self.inner.enabled.load(Ordering::SeqCst) {
-            return false;
-        }
-        self.inner.active.load(Ordering::SeqCst) > 0
-    }
-
-    async fn record(&self, window: &Arc<Mutex<RequestWindow>>, input: u64, output: u64) {
-        if input == 0 && output == 0 {
-            return;
-        }
-        let now = Instant::now();
-        {
-            let mut guard = window.lock().await;
-            guard.push(TokenEvent {
-                ts: now,
-                input,
-                output,
-            });
-        }
-        self.maybe_cleanup(now).await;
-    }
-
-    async fn unregister(&self, id: u64) {
-        let removed = self
-            .inner
-            .requests
-            .write()
-            .await
-            .remove(&id)
-            .is_some();
-        if removed {
-            self.inner.active.fetch_sub(1, Ordering::SeqCst);
-        }
-    }
-
-    // Start the cleanup task only when a Tokio runtime is available, so running without a reactor cannot crash.
-    fn try_start_cleanup(&self) {
-        if self.inner.cleanup_started.load(Ordering::SeqCst) {
-            return;
-        }
-        let Ok(handle) = tokio::runtime::Handle::try_current() else {
-            return;
-        };
-        if self
-            .inner
-            .cleanup_started
-            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
-            .is_err()
-        {
-            return;
-        }
-        let weak_inner = Arc::downgrade(&self.inner);
-        handle.spawn(async move {
-            let mut ticker = interval(CLEANUP_INTERVAL);
-            ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
-            loop {
-                ticker.tick().await;
-                let Some(inner) = weak_inner.upgrade() else {
-                    break;
-                };
-                if !inner.enabled.load(Ordering::SeqCst) {
-                    continue;
-                }
-                cleanup_expired_inner(&inner, Instant::now()).await;
-            }
-        });
-    }
-
-    // Lazy cleanup: triggered at an interval whenever traffic flows, reducing reliance on a dedicated background task.
-    async fn maybe_cleanup(&self, now: Instant) {
-        if !self.inner.enabled.load(Ordering::SeqCst) {
-            return;
-        }
-        if !self.should_cleanup(now).await {
-            return;
-        }
-        self.cleanup_expired(now).await;
-    }
-
-    async fn should_cleanup(&self, now: Instant) -> bool {
-        let mut guard = self.inner.last_cleanup.lock().await;
-        if now.duration_since(*guard) < CLEANUP_INTERVAL {
-            return false;
-        }
-        *guard = now;
-        true
-    }
-
-    async fn cleanup_expired(&self, now: Instant) {
-        cleanup_expired_inner(&self.inner, now).await;
-    }
-}
-
-async fn cleanup_expired_inner(inner: &TrackerInner, now: Instant) {
-    let windows: Vec<(u64, Arc<Mutex<RequestWindow>>)> = inner
-        .requests
-        .read()
-        .await
-        .iter()
-        .map(|(id, window)| (*id, window.clone()))
-        .collect();
-    if windows.is_empty() {
-        return;
-    }
-    let mut expired = Vec::new();
-    for (id, window) in windows {
-        let guard = window.lock().await;
-        if guard.is_expired(now) {
-            expired.push(id);
-        }
-    }
-    if expired.is_empty() {
-        return;
-    }
-    let mut guard = inner.requests.write().await;
-    let mut removed = 0usize;
-    for id in expired {
-        if guard.remove(&id).is_some() {
-            removed += 1;
-        }
-    }
-    if removed > 0 {
-        inner.active.fetch_sub(removed, Ordering::SeqCst);
-    }
-}
-
-impl RequestWindow {
-    fn new() -> Self {
-        Self {
-            events: VecDeque::new(),
-            last_seen: Instant::now(),
-        }
-    }
-
-    fn push(&mut self, event: TokenEvent) {
-        let now = event.ts;
-        self.events.push_back(event);
-        self.last_seen = now;
-        self.prune(now);
-    }
-
-    fn prune(&mut self, now: Instant) {
-        while let Some(front) = self.events.front() {
-            if now.duration_since(front.ts) <= RATE_WINDOW {
-                break;
-            }
-            self.events.pop_front();
-        }
-    }
-
-    fn sum(&self) -> (u64, u64) {
-        let mut input = 0u64;
-        let mut output = 0u64;
-        for event in &self.events {
-            input = input.saturating_add(event.input);
-            output = output.saturating_add(event.output);
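-            // Saturating adds: even a pathological event window cannot wrap u64.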
-        }
-        (input, output)
-    }
-
-    fn is_expired(&self, now: Instant) -> bool {
-        now.saturating_duration_since(self.last_seen) > REQUEST_TTL
-    }
-}
-
-impl RequestTokenTracker {
-    pub(crate) fn disabled() -> Self {
-        // `generation=None` makes `can_record()` return false, so this tracker is a no-op.
-        // We still need a TokenRateTracker instance to satisfy the struct layout.
-        let tracker = TokenRateTracker::new();
-        Self {
-            id: None,
-            window: None,
-            tracker: tracker.as_ref().clone(),
-            model: None,
-            generation: None,
-        }
-    }
-
-    pub(crate) async fn add_input_tokens(&self, tokens: u64) {
-        if !self.can_record() {
-            return;
-        }
-        let Some(window) = self.window.as_ref() else {
-            return;
-        };
-        self.tracker.record(window, tokens, 0).await;
-    }
-
-    pub(crate) async fn add_output_text(&self, text: &str) {
-        if !self.can_record() {
-            return;
-        }
-        let tokens = estimate_text_tokens(self.model.as_deref(), text);
-        let Some(window) = self.window.as_ref() else {
-            return;
-        };
-        self.tracker.record(window, 0, tokens).await;
-    }
-
-    fn can_record(&self) -> bool {
-        let Some(generation) = self.generation else {
-            return false;
-        };
-        if !self.tracker.inner.enabled.load(Ordering::SeqCst) {
-            return false;
-        }
-        // A generation mismatch means the switch has been toggled; requests from before it no longer count.
-        self.tracker.inner.generation.load(Ordering::SeqCst) == generation
-    }
-}
-
-impl Drop for RequestTokenTracker {
-    fn drop(&mut self) {
-        let Some(id) = self.id else {
-            return;
-        };
-        if let Ok(mut guard) = self.tracker.inner.requests.try_write() {
-            if guard.remove(&id).is_some() {
-                self.tracker.inner.active.fetch_sub(1, Ordering::SeqCst);
-            }
-            return;
-        }
-        // Avoid blocking the async runtime inside Drop; fall back to best-effort asynchronous cleanup.
-        if let Ok(handle) = tokio::runtime::Handle::try_current() {
-            let tracker = self.tracker.clone();
-            handle.spawn(async move {
-                tracker.unregister(id).await;
-            });
-        }
-    }
-}
-
-pub(crate) fn estimate_text_tokens(model: Option<&str>, text: &str) -> u64 {
-    super::token_estimator::estimate_text_tokens(model, text)
-}
diff --git a/src-tauri/src/proxy/upstream.rs b/src-tauri/src/proxy/upstream.rs
deleted file mode 100644
index 929a681..0000000
--- a/src-tauri/src/proxy/upstream.rs
+++ /dev/null
@@ -1,713 +0,0 @@
-use axum::{
-    http::{
-        header::{AUTHORIZATION, USER_AGENT},
-        HeaderMap, HeaderValue, Method, StatusCode,
-    },
-    response::Response,
-};
-use crate::antigravity::endpoints as antigravity_endpoints;
-use crate::antigravity::project as antigravity_project;
-use std::{
-    sync::{
-        Arc,
-    },
-    time::Instant,
-};
-
-const GEMINI_API_KEY_QUERY: &str = "key";
-const LOCAL_UPSTREAM_ID: &str = "local";
-const ANTIGRAVITY_GENERATE_PATH: &str = "/v1internal:generateContent";
-const ANTIGRAVITY_STREAM_PATH: &str = "/v1internal:streamGenerateContent";
-
-mod request;
-mod attempt;
-mod result;
-mod utils;
-mod kiro;
-mod kiro_headers;
-mod kiro_http;
-
-use utils::{
-    build_group_order, resolve_group_start,
-};
-
-#[cfg(test)]
-use crate::proxy::redact::redact_query_param_value;
-
-use super::{
-    config::{ProviderUpstreams, UpstreamRuntime},
-    gemini,
-    http,
-    http::RequestAuth,
-    openai_compat::FormatTransform,
-    request_detail::RequestDetailSnapshot,
-    request_body::ReplayableBody,
-    ProxyState,
-    RequestMeta,
-};
-
-const REQUEST_MODEL_MAPPING_LIMIT_BYTES: usize = 4 * 1024 * 1024;
-
-pub(super) async fn forward_upstream_request(
-    state: Arc<ProxyState>,
-    method: Method,
-    provider: &str,
-    inbound_path: &str,
-    upstream_path_with_query: &str,
-    headers: HeaderMap,
-    body: ReplayableBody,
-    meta: RequestMeta,
-    request_auth: RequestAuth,
-
response_transform: FormatTransform, - request_detail: Option, -) -> Response { - let upstreams = match resolve_provider_upstreams( - &state, - provider, - inbound_path, - &meta, - request_detail.as_ref(), - ) { - Ok(upstreams) => upstreams, - Err(response) => return response, - }; - let summary = run_upstream_groups( - &state, - method, - provider, - inbound_path, - upstream_path_with_query, - &headers, - &body, - &meta, - &request_auth, - response_transform, - request_detail.clone(), - upstreams, - ) - .await; - if let Some(response) = summary.response { - return response; - } - finalize_forward_response( - &state, - provider, - inbound_path, - &meta, - request_detail.as_ref(), - summary, - ) -} - -struct GroupAttemptResult { - response: Option, - attempted: usize, - missing_auth: bool, - last_timeout_error: Option, - last_retry_error: Option, - last_retry_response: Option, -} - -impl GroupAttemptResult { - fn new() -> Self { - Self { - response: None, - attempted: 0, - missing_auth: false, - last_timeout_error: None, - last_retry_error: None, - last_retry_response: None, - } - } -} - -struct ForwardAttemptState { - response: Option, - attempted: usize, - missing_auth: bool, - last_timeout_error: Option, - last_retry_error: Option, - last_retry_response: Option, -} - -impl ForwardAttemptState { - fn new() -> Self { - Self { - response: None, - attempted: 0, - missing_auth: false, - last_timeout_error: None, - last_retry_error: None, - last_retry_response: None, - } - } -} - -enum AttemptOutcome { - Success(Response), - Retryable { - message: String, - response: Option, - is_timeout: bool, - }, - Fatal(Response), - SkippedAuth, -} - -fn apply_attempt_outcome( - result: &mut GroupAttemptResult, - outcome: AttemptOutcome, -) -> bool { - match outcome { - AttemptOutcome::Success(response) | AttemptOutcome::Fatal(response) => { - result.response = Some(response); - true - } - AttemptOutcome::Retryable { - message, - response, - is_timeout, - } => { - if is_timeout { - result.last_timeout_error = Some(message.clone()); - } else { - result.last_retry_error = Some(message.clone()); - } - if response.is_some() { - result.last_retry_response = response; - } - false - } - AttemptOutcome::SkippedAuth => { - result.missing_auth = true; - false - } - } -} - -fn merge_group_result(state: &mut ForwardAttemptState, result: GroupAttemptResult) -> bool { - state.attempted += result.attempted; - state.missing_auth |= result.missing_auth; - if let Some(response) = result.response { - state.response = Some(response); - return true; - } - if result.last_timeout_error.is_some() { - state.last_timeout_error = result.last_timeout_error; - } - if result.last_retry_error.is_some() { - state.last_retry_error = result.last_retry_error; - } - if let Some(response) = result.last_retry_response { - state.last_retry_response = Some(response); - } - false -} - -pub(super) struct PreparedUpstreamRequest { - upstream_path_with_query: String, - upstream_url: String, - request_headers: HeaderMap, - meta: RequestMeta, - antigravity: Option, -} - -struct ResolvedUpstreamAuth { - upstream_url: String, - auth: http::UpstreamAuthHeader, - extra_headers: Option, - antigravity: Option, -} - -#[derive(Clone)] -pub(super) struct AntigravityRequestInfo { - project_id: Option, - user_agent: String, -} - -fn resolve_provider_upstreams<'a>( - state: &'a ProxyState, - provider: &str, - inbound_path: &str, - meta: &RequestMeta, - request_detail: Option<&RequestDetailSnapshot>, -) -> Result<&'a ProviderUpstreams, Response> { - match 
state.config.provider_upstreams(provider) { - Some(upstreams) => Ok(upstreams), - None => { - result::log_upstream_error_if_needed( - &state.log, - request_detail, - meta, - provider, - LOCAL_UPSTREAM_ID, - inbound_path, - StatusCode::BAD_GATEWAY, - "No available upstream configured.".to_string(), - Instant::now(), - ); - Err(http::error_response( - StatusCode::BAD_GATEWAY, - "No available upstream configured.", - )) - } - } -} - -async fn run_upstream_groups( - state: &ProxyState, - method: Method, - provider: &str, - inbound_path: &str, - upstream_path_with_query: &str, - headers: &HeaderMap, - body: &ReplayableBody, - meta: &RequestMeta, - request_auth: &RequestAuth, - response_transform: FormatTransform, - request_detail: Option, - upstreams: &ProviderUpstreams, -) -> ForwardAttemptState { - let mut summary = ForwardAttemptState::new(); - for (group_index, group) in upstreams.groups.iter().enumerate() { - // Only rotate within the highest priority group; retry network failures before degrading. - if group.items.is_empty() { - continue; - } - let result = try_group_upstreams( - state, - method.clone(), - provider, - group_index, - &group.items, - inbound_path, - upstream_path_with_query, - headers, - body, - meta, - request_auth, - response_transform, - request_detail.clone(), - ) - .await; - if merge_group_result(&mut summary, result) { - break; - } - } - summary -} - -fn finalize_forward_response( - state: &ProxyState, - provider: &str, - inbound_path: &str, - meta: &RequestMeta, - request_detail: Option<&RequestDetailSnapshot>, - summary: ForwardAttemptState, -) -> Response { - if summary.attempted == 0 && summary.missing_auth { - result::log_upstream_error_if_needed( - &state.log, - request_detail, - meta, - provider, - LOCAL_UPSTREAM_ID, - inbound_path, - StatusCode::UNAUTHORIZED, - "Missing upstream API key.".to_string(), - Instant::now(), - ); - return http::error_response(StatusCode::UNAUTHORIZED, "Missing upstream API key."); - } - if let Some(response) = summary.last_retry_response { - return response; - } - if let Some(err) = summary.last_timeout_error { - return http::error_response(StatusCode::GATEWAY_TIMEOUT, err); - } - if let Some(err) = summary.last_retry_error { - return http::error_response( - StatusCode::BAD_GATEWAY, - format!("Upstream request failed: {err}"), - ); - } - http::error_response( - StatusCode::BAD_GATEWAY, - "No available upstream configured.", - ) -} - -async fn try_group_upstreams( - state: &ProxyState, - method: Method, - provider: &str, - group_index: usize, - items: &[UpstreamRuntime], - inbound_path: &str, - upstream_path_with_query: &str, - headers: &HeaderMap, - body: &ReplayableBody, - meta: &RequestMeta, - request_auth: &RequestAuth, - response_transform: FormatTransform, - request_detail: Option, -) -> GroupAttemptResult { - let mut result = GroupAttemptResult::new(); - let start = resolve_group_start(state, provider, group_index, items.len()); - for item_index in build_group_order(items.len(), start) { - let upstream = &items[item_index]; - let outcome = attempt::attempt_upstream( - state, - method.clone(), - provider, - upstream, - inbound_path, - upstream_path_with_query, - headers, - body, - meta, - request_auth, - response_transform, - request_detail.clone(), - ) - .await; - if !matches!(outcome, AttemptOutcome::SkippedAuth) { - result.attempted += 1; - } - if apply_attempt_outcome(&mut result, outcome) { - return result; - } - } - result -} - -async fn prepare_upstream_request( - state: &ProxyState, - provider: &str, - upstream: 
&UpstreamRuntime, - upstream_path_with_query: &str, - headers: &HeaderMap, - meta: &RequestMeta, - request_auth: &RequestAuth, -) -> Result { - let mapped_meta = build_mapped_meta(meta, upstream, provider); - let upstream_path_with_query = - resolve_upstream_path_with_query(provider, upstream_path_with_query, &mapped_meta); - let upstream_url = upstream.upstream_url(&upstream_path_with_query); - let resolved = resolve_upstream_auth( - state, - provider, - upstream, - request_auth, - &upstream_path_with_query, - &upstream_url, - ) - .await?; - let ResolvedUpstreamAuth { - upstream_url, - auth, - extra_headers, - antigravity, - } = resolved; - let request_headers = request::build_request_headers( - provider, - headers, - auth, - extra_headers.as_ref(), - upstream.header_overrides.as_deref(), - ); - Ok(PreparedUpstreamRequest { - upstream_path_with_query, - upstream_url, - request_headers, - meta: mapped_meta, - antigravity, - }) -} - -async fn resolve_upstream_auth( - state: &ProxyState, - provider: &str, - upstream: &UpstreamRuntime, - request_auth: &RequestAuth, - upstream_path_with_query: &str, - upstream_url: &str, -) -> Result { - if provider == "gemini" { - let (upstream_url, auth) = request::resolve_gemini_upstream( - upstream, - request_auth, - upstream_path_with_query, - upstream_url, - )?; - return Ok(ResolvedUpstreamAuth { - upstream_url, - auth, - extra_headers: None, - antigravity: None, - }); - } - if provider == "kiro" { - return resolve_kiro_upstream(state, upstream, upstream_url).await; - } - if provider == "codex" { - return resolve_codex_upstream(state, upstream, upstream_url).await; - } - if provider == "antigravity" { - return resolve_antigravity_upstream(state, upstream, upstream_url).await; - } - let auth = match http::resolve_upstream_auth(provider, upstream, request_auth) { - Ok(Some(auth)) => auth, - Ok(None) => return Err(AttemptOutcome::SkippedAuth), - Err(response) => return Err(AttemptOutcome::Fatal(response)), - }; - Ok(ResolvedUpstreamAuth { - upstream_url: upstream_url.to_string(), - auth, - extra_headers: None, - antigravity: None, - }) -} - -async fn resolve_kiro_upstream( - state: &ProxyState, - upstream: &UpstreamRuntime, - upstream_url: &str, -) -> Result { - let Some(account_id) = upstream.kiro_account_id.as_deref() else { - return Err(AttemptOutcome::Fatal(http::error_response( - StatusCode::UNAUTHORIZED, - "Kiro account is not configured.", - ))); - }; - let token = state - .kiro_accounts - .get_access_token(account_id) - .await - .map_err(|err| AttemptOutcome::Fatal(http::error_response(StatusCode::UNAUTHORIZED, err)))?; - let value = http::bearer_header(&token).ok_or_else(|| { - AttemptOutcome::Fatal(http::error_response( - StatusCode::UNAUTHORIZED, - "Upstream access token contains invalid characters.", - )) - })?; - Ok(ResolvedUpstreamAuth { - upstream_url: upstream_url.to_string(), - auth: http::UpstreamAuthHeader { - name: AUTHORIZATION, - value, - }, - extra_headers: None, - antigravity: None, - }) -} - -async fn resolve_codex_upstream( - state: &ProxyState, - upstream: &UpstreamRuntime, - upstream_url: &str, -) -> Result { - let Some(account_id) = upstream.codex_account_id.as_deref() else { - return Err(AttemptOutcome::Fatal(http::error_response( - StatusCode::UNAUTHORIZED, - "Codex account is not configured.", - ))); - }; - let record = state - .codex_accounts - .get_account_record(account_id) - .await - .map_err(|err| AttemptOutcome::Fatal(http::error_response(StatusCode::UNAUTHORIZED, err)))?; - let value = 
http::bearer_header(&record.access_token).ok_or_else(|| { - AttemptOutcome::Fatal(http::error_response( - StatusCode::UNAUTHORIZED, - "Upstream access token contains invalid characters.", - )) - })?; - let mut extra_headers = HeaderMap::new(); - if let Some(account_id) = record.account_id.as_deref() { - if let Ok(value) = axum::http::HeaderValue::from_str(account_id) { - extra_headers.insert(axum::http::HeaderName::from_static("chatgpt-account-id"), value); - } - } - let extra_headers = if extra_headers.is_empty() { - None - } else { - Some(extra_headers) - }; - Ok(ResolvedUpstreamAuth { - upstream_url: upstream_url.to_string(), - auth: http::UpstreamAuthHeader { - name: AUTHORIZATION, - value, - }, - extra_headers, - antigravity: None, - }) -} - -async fn resolve_antigravity_upstream( - state: &ProxyState, - upstream: &UpstreamRuntime, - upstream_url: &str, -) -> Result { - let Some(account_id) = upstream.antigravity_account_id.as_deref() else { - return Err(AttemptOutcome::Fatal(http::error_response( - StatusCode::UNAUTHORIZED, - "Antigravity account is not configured.", - ))); - }; - let mut record = state - .antigravity_accounts - .get_account_record(account_id) - .await - .map_err(|err| AttemptOutcome::Fatal(http::error_response(StatusCode::UNAUTHORIZED, err)))?; - if record - .project_id - .as_deref() - .map(str::trim) - .filter(|value| !value.is_empty()) - .is_none() - { - let proxy_url = state.antigravity_accounts.app_proxy_url().await; - match antigravity_project::load_code_assist(&record.access_token, proxy_url.as_deref()).await - { - Ok(info) => { - if let Some(value) = info.project_id.clone() { - let _ = state - .antigravity_accounts - .update_project_id(account_id, value.clone()) - .await; - record.project_id = Some(value); - } else if let Some(tier_id) = info.plan_type.as_deref() { - if let Ok(Some(value)) = antigravity_project::onboard_user( - &record.access_token, - proxy_url.as_deref(), - tier_id, - ) - .await - { - let _ = state - .antigravity_accounts - .update_project_id(account_id, value.clone()) - .await; - record.project_id = Some(value); - } - } - } - Err(err) => { - tracing::warn!(error = %err, "antigravity loadCodeAssist failed in proxy"); - } - } - } - let value = http::bearer_header(&record.access_token).ok_or_else(|| { - AttemptOutcome::Fatal(http::error_response( - StatusCode::UNAUTHORIZED, - "Upstream access token contains invalid characters.", - )) - })?; - let user_agent = state - .config - .antigravity_user_agent - .as_deref() - .map(str::trim) - .filter(|value| !value.is_empty()) - .map(|value| value.to_string()) - .unwrap_or_else(antigravity_endpoints::default_user_agent); - let mut extra_headers = HeaderMap::new(); - if let Ok(value) = HeaderValue::from_str(&user_agent) { - extra_headers.insert(USER_AGENT, value); - } - let extra_headers = if extra_headers.is_empty() { - None - } else { - Some(extra_headers) - }; - Ok(ResolvedUpstreamAuth { - upstream_url: upstream_url.to_string(), - auth: http::UpstreamAuthHeader { - name: AUTHORIZATION, - value, - }, - extra_headers, - antigravity: Some(AntigravityRequestInfo { - project_id: record.project_id.clone(), - user_agent, - }), - }) -} - -fn build_mapped_meta(meta: &RequestMeta, upstream: &UpstreamRuntime, provider: &str) -> RequestMeta { - let mapped_model = meta - .original_model - .as_deref() - .map(|original| upstream.map_model(original).unwrap_or_else(|| original.to_string())); - let (mapped_model, reasoning_effort) = normalize_mapped_model_reasoning_suffix( - mapped_model, - 
meta.reasoning_effort.clone(), - ); - let mapped_model = mapped_model.map(|model| { - if provider == "antigravity" { - super::antigravity_compat::map_antigravity_model(&model) - } else { - model - } - }); - RequestMeta { - stream: meta.stream, - original_model: meta.original_model.clone(), - mapped_model, - reasoning_effort, - estimated_input_tokens: meta.estimated_input_tokens, - } -} - -fn normalize_mapped_model_reasoning_suffix( - mapped_model: Option, - reasoning_effort: Option, -) -> (Option, Option) { - let Some(mapped_model) = mapped_model else { - return (None, reasoning_effort); - }; - let Some((base_model, mapped_effort)) = - super::server_helpers::parse_openai_reasoning_effort_from_model_suffix(&mapped_model) - else { - return (Some(mapped_model), reasoning_effort); - }; - - // If the user already specified an explicit effort in the incoming `model`, keep it. - let reasoning_effort = reasoning_effort.or(Some(mapped_effort)); - (Some(base_model), reasoning_effort) -} - -fn resolve_upstream_path_with_query( - provider: &str, - upstream_path_with_query: &str, - meta: &RequestMeta, -) -> String { - if provider == "antigravity" { - return if meta.stream { - // Align with CLIProxyAPIPlus: Antigravity streaming defaults to SSE via `alt=sse`. - format!("{ANTIGRAVITY_STREAM_PATH}?alt=sse") - } else { - ANTIGRAVITY_GENERATE_PATH.to_string() - }; - } - if provider != "gemini" || meta.model_override().is_none() { - return upstream_path_with_query.to_string(); - } - let Some(mapped_model) = meta.mapped_model.as_deref() else { - return upstream_path_with_query.to_string(); - }; - let (path, query) = request::split_path_query(upstream_path_with_query); - let replaced = gemini::replace_gemini_model_in_path(path, mapped_model) - .unwrap_or_else(|| path.to_string()); - match query { - Some(query) => format!("{replaced}?{query}"), - None => replaced, - } -} - -#[cfg(test)] -#[path = "upstream.test.rs"] -mod tests; diff --git a/src-tauri/src/proxy/upstream.test.rs b/src-tauri/src/proxy/upstream.test.rs deleted file mode 100644 index cc00f30..0000000 --- a/src-tauri/src/proxy/upstream.test.rs +++ /dev/null @@ -1,123 +0,0 @@ -use super::*; -use super::utils::is_retryable_status; - -#[test] -fn retryable_status_matches_proxy_policy() { - assert!(is_retryable_status(StatusCode::BAD_REQUEST)); - assert!(is_retryable_status(StatusCode::FORBIDDEN)); - assert!(is_retryable_status(StatusCode::TOO_MANY_REQUESTS)); - assert!(is_retryable_status(StatusCode::TEMPORARY_REDIRECT)); - assert!(is_retryable_status(StatusCode::INTERNAL_SERVER_ERROR)); - - // Exclude 504/524 timeouts from retries. 
- assert!(!is_retryable_status(StatusCode::GATEWAY_TIMEOUT)); - assert!(!is_retryable_status(StatusCode::from_u16(524).expect("524"))); - - assert!(!is_retryable_status(StatusCode::UNAUTHORIZED)); -} - -#[test] -fn extract_query_param_reads_key_value() { - let value = utils::extract_query_param("/v1beta/models/x:generateContent?key=abc&foo=bar", "key"); - assert_eq!(value.as_deref(), Some("abc")); -} - -#[test] -fn ensure_query_param_overrides_existing_value() { - let url = "https://example.com/v1beta/models/x:generateContent?foo=bar&key=old"; - let updated = utils::ensure_query_param(url, "key", "new").expect("updated url"); - assert!(updated.contains("foo=bar")); - assert!(updated.contains("key=new")); - assert!(!updated.contains("key=old")); -} - -#[test] -fn redact_query_param_value_hides_secret() { - let message = "error sending request for url (https://example.com/path?key=SECRET&foo=bar)"; - let redacted = redact_query_param_value(message, "key"); - assert!(redacted.contains("key=***")); - assert!(!redacted.contains("SECRET")); - assert!(redacted.contains("foo=bar")); -} - -#[test] -fn apply_header_overrides_sets_and_removes() { - use axum::http::header::{AUTHORIZATION, CONTENT_LENGTH, HOST}; - use axum::http::{HeaderMap, HeaderName, HeaderValue}; - - let mut headers = HeaderMap::new(); - headers.insert(HeaderName::from_static("x-remove"), HeaderValue::from_static("value")); - headers.insert(AUTHORIZATION, HeaderValue::from_static("Bearer original")); - headers.insert(HeaderName::from_static("x-keep"), HeaderValue::from_static("old")); - - let overrides = vec![ - super::super::config::HeaderOverride { - name: HeaderName::from_static("x-custom"), - value: Some(HeaderValue::from_static("new")), - }, - super::super::config::HeaderOverride { - name: AUTHORIZATION, - value: Some(HeaderValue::from_static("Bearer override")), - }, - super::super::config::HeaderOverride { - name: HeaderName::from_static("x-remove"), - value: None, - }, - super::super::config::HeaderOverride { - name: HOST, - value: Some(HeaderValue::from_static("skip.example.com")), - }, - super::super::config::HeaderOverride { - name: CONTENT_LENGTH, - value: Some(HeaderValue::from_static("123")), - }, - ]; - - request::apply_header_overrides(&mut headers, &overrides); - - assert_eq!( - headers.get("x-custom").and_then(|v| v.to_str().ok()), - Some("new") - ); - assert_eq!( - headers.get(AUTHORIZATION).and_then(|v| v.to_str().ok()), - Some("Bearer override") - ); - assert!(!headers.contains_key("x-remove")); - // hop-by-hop/host/content-length must stay untouched/removed - assert!(!headers.contains_key(HOST)); - assert!(!headers.contains_key(CONTENT_LENGTH)); -} - -#[test] -fn mapped_model_reasoning_suffix_is_stripped_and_becomes_effort() { - let (model, effort) = normalize_mapped_model_reasoning_suffix( - Some("gpt-4.1-reasoning-high".to_string()), - None, - ); - assert_eq!(model.as_deref(), Some("gpt-4.1")); - assert_eq!(effort.as_deref(), Some("high")); -} - -#[test] -fn mapped_model_reasoning_suffix_does_not_override_existing_effort() { - let (model, effort) = normalize_mapped_model_reasoning_suffix( - Some("gpt-4.1-reasoning-high".to_string()), - Some("low".to_string()), - ); - assert_eq!(model.as_deref(), Some("gpt-4.1")); - assert_eq!(effort.as_deref(), Some("low")); -} - -#[test] -fn antigravity_stream_path_defaults_to_alt_sse() { - let meta = RequestMeta { - stream: true, - original_model: None, - mapped_model: None, - reasoning_effort: None, - estimated_input_tokens: None, - }; - let path = 
resolve_upstream_path_with_query("antigravity", "/v1/chat/completions", &meta); - assert_eq!(path, format!("{ANTIGRAVITY_STREAM_PATH}?alt=sse")); -} diff --git a/src-tauri/src/proxy/upstream/attempt.rs b/src-tauri/src/proxy/upstream/attempt.rs deleted file mode 100644 index 3686dfb..0000000 --- a/src-tauri/src/proxy/upstream/attempt.rs +++ /dev/null @@ -1,915 +0,0 @@ -use std::time::Instant; - -use axum::http::{ - header::{ACCEPT, ACCEPT_ENCODING, CONTENT_TYPE, USER_AGENT}, - HeaderMap, HeaderValue, Method, StatusCode, -}; -use reqwest::{Client, Proxy}; -use tokio::time::timeout; - -use super::result; -use super::request; -use super::utils::{is_retryable_error, sanitize_upstream_error}; -use super::{AttemptOutcome, PreparedUpstreamRequest}; -use crate::antigravity::endpoints as antigravity_endpoints; -use crate::proxy::http; -use crate::proxy::openai_compat::FormatTransform; -use crate::proxy::request_detail::RequestDetailSnapshot; -use crate::proxy::request_body::ReplayableBody; -use crate::proxy::server_helpers::log_debug_headers_body; -use crate::proxy::{config::UpstreamRuntime, ProxyState, RequestMeta}; -use crate::proxy::{UPSTREAM_NO_DATA_TIMEOUT}; - -const DEBUG_UPSTREAM_LOG_LIMIT_BYTES: usize = usize::MAX; - -pub(super) async fn attempt_upstream( - state: &ProxyState, - method: Method, - provider: &str, - upstream: &UpstreamRuntime, - inbound_path: &str, - upstream_path_with_query: &str, - headers: &HeaderMap, - body: &ReplayableBody, - meta: &RequestMeta, - request_auth: &crate::proxy::http::RequestAuth, - response_transform: FormatTransform, - request_detail: Option, -) -> AttemptOutcome { - if provider == "kiro" { - return super::kiro::attempt_kiro_upstream( - state, - method, - upstream, - inbound_path, - headers, - body, - meta, - response_transform, - request_detail, - ) - .await; - } - let first = match attempt_send( - state, - method.clone(), - provider, - upstream, - inbound_path, - upstream_path_with_query, - headers, - body, - meta, - request_auth, - request_detail.as_ref(), - ) - .await - { - Ok(attempt) => attempt, - Err(outcome) => return outcome, - }; - if let Some(outcome) = retry_after_kiro_refresh( - state, - method, - provider, - upstream, - inbound_path, - upstream_path_with_query, - headers, - body, - meta, - request_auth, - response_transform, - request_detail.clone(), - &first, - ) - .await - { - return outcome; - } - finalize_attempt( - state, - provider, - upstream, - inbound_path, - response_transform, - request_detail, - first, - ) - .await -} - -struct UpstreamAttempt { - response: reqwest::Response, - meta: RequestMeta, - start_time: Instant, -} - -async fn retry_after_kiro_refresh( - state: &ProxyState, - method: Method, - provider: &str, - upstream: &UpstreamRuntime, - inbound_path: &str, - upstream_path_with_query: &str, - headers: &HeaderMap, - body: &ReplayableBody, - meta: &RequestMeta, - request_auth: &crate::proxy::http::RequestAuth, - response_transform: FormatTransform, - request_detail: Option, - first: &UpstreamAttempt, -) -> Option { - if !should_refresh_kiro(provider, &first.response) { - return None; - } - if let Err(outcome) = refresh_kiro_account(state, upstream).await { - return Some(outcome); - } - let retry = match attempt_send( - state, - method, - provider, - upstream, - inbound_path, - upstream_path_with_query, - headers, - body, - meta, - request_auth, - request_detail.as_ref(), - ) - .await - { - Ok(attempt) => attempt, - Err(outcome) => return Some(outcome), - }; - Some( - finalize_attempt( - state, - provider, - upstream, - 
inbound_path, - response_transform, - request_detail, - retry, - ) - .await, - ) -} - -async fn finalize_attempt( - state: &ProxyState, - provider: &str, - upstream: &UpstreamRuntime, - inbound_path: &str, - response_transform: FormatTransform, - request_detail: Option, - attempt: UpstreamAttempt, -) -> AttemptOutcome { - result::handle_upstream_result( - Ok(attempt.response), - &attempt.meta, - provider, - &upstream.id, - inbound_path, - state.log.clone(), - state.token_rate.clone(), - attempt.start_time, - response_transform, - request_detail, - ) - .await -} - -async fn attempt_send( - state: &ProxyState, - method: Method, - provider: &str, - upstream: &UpstreamRuntime, - inbound_path: &str, - upstream_path_with_query: &str, - headers: &HeaderMap, - body: &ReplayableBody, - meta: &RequestMeta, - request_auth: &crate::proxy::http::RequestAuth, - request_detail: Option<&RequestDetailSnapshot>, -) -> Result { - let prepared = super::prepare_upstream_request( - state, - provider, - upstream, - upstream_path_with_query, - headers, - meta, - request_auth, - ) - .await?; - let PreparedUpstreamRequest { - upstream_path_with_query, - upstream_url, - request_headers, - meta, - antigravity, - } = prepared; - let start_time = Instant::now(); - let response = send_upstream_request( - state, - method, - provider, - upstream, - inbound_path, - &upstream_path_with_query, - &upstream_url, - &request_headers, - body, - &meta, - antigravity.as_ref(), - request_detail, - start_time, - ) - .await?; - Ok(UpstreamAttempt { - response, - meta, - start_time, - }) -} - -async fn send_upstream_request( - state: &ProxyState, - method: Method, - provider: &str, - upstream: &UpstreamRuntime, - inbound_path: &str, - upstream_path_with_query: &str, - upstream_url: &str, - request_headers: &HeaderMap, - body: &ReplayableBody, - meta: &RequestMeta, - antigravity: Option<&super::AntigravityRequestInfo>, - request_detail: Option<&RequestDetailSnapshot>, - start_time: Instant, -) -> Result { - if provider == "codex" { - return send_codex_request( - state, - method, - provider, - upstream, - inbound_path, - upstream_path_with_query, - upstream_url, - request_headers, - body, - meta, - antigravity, - request_detail, - start_time, - ) - .await; - } - if provider == "antigravity" { - return send_antigravity_with_fallback( - state, - method, - provider, - upstream, - inbound_path, - upstream_path_with_query, - request_headers, - body, - meta, - antigravity, - request_detail, - start_time, - ) - .await; - } - send_upstream_request_once( - state, - method, - provider, - upstream, - inbound_path, - upstream_path_with_query, - upstream_url, - request_headers, - body, - meta, - antigravity, - request_detail, - start_time, - ) - .await -} - -async fn send_antigravity_with_fallback( - state: &ProxyState, - method: Method, - provider: &str, - upstream: &UpstreamRuntime, - inbound_path: &str, - upstream_path_with_query: &str, - request_headers: &HeaderMap, - body: &ReplayableBody, - meta: &RequestMeta, - antigravity: Option<&super::AntigravityRequestInfo>, - request_detail: Option<&RequestDetailSnapshot>, - start_time: Instant, -) -> Result { - let urls = antigravity_fallback_urls(&upstream.base_url, upstream_path_with_query); - let request_headers = antigravity_request_headers(request_headers, meta, antigravity); - log_debug_headers_body( - "upstream.request", - Some(&request_headers), - Some(body), - DEBUG_UPSTREAM_LOG_LIMIT_BYTES, - ) - .await; - let mut last_transport: Option = None; - let mut saw_timeout = false; - for (idx, url) 
in urls.iter().enumerate() { - let upstream_body = request::build_upstream_body( - provider, - upstream, - upstream_path_with_query, - body, - meta, - antigravity, - ) - .await?; - match send_request_once( - state - .http_clients - .client_for_proxy_url(upstream.proxy_url.as_deref()) - .map_err(|message| { - AttemptOutcome::Fatal(http::error_response(StatusCode::BAD_GATEWAY, message)) - })?, - &method, - url, - &request_headers, - upstream_body, - ) - .await - { - Ok(response) => { - let status = response.status(); - // Align with CLIProxyAPIPlus: Antigravity endpoints may return 404 on one base URL - // while succeeding on another; try fallbacks on 404 as well. - if (super::utils::is_retryable_status(status) || status == StatusCode::NOT_FOUND) - && idx + 1 < urls.len() - { - let _ = response.bytes().await; - continue; - } - return Ok(response); - } - Err(SendFailure::Timeout) => saw_timeout = true, - Err(SendFailure::Transport(err)) => last_transport = Some(err), - } - } - if saw_timeout { - return Err(handle_upstream_timeout( - state, - provider, - upstream, - inbound_path, - meta, - request_detail, - start_time, - )); - } - if let Some(err) = last_transport { - return Err(map_upstream_error( - state, - provider, - upstream, - inbound_path, - meta, - request_detail, - err, - start_time, - )); - } - Err(AttemptOutcome::Fatal(http::error_response( - StatusCode::BAD_GATEWAY, - "Antigravity upstream request failed.".to_string(), - ))) -} - -async fn send_codex_request( - state: &ProxyState, - method: Method, - provider: &str, - upstream: &UpstreamRuntime, - inbound_path: &str, - upstream_path_with_query: &str, - upstream_url: &str, - request_headers: &HeaderMap, - body: &ReplayableBody, - meta: &RequestMeta, - antigravity: Option<&super::AntigravityRequestInfo>, - request_detail: Option<&RequestDetailSnapshot>, - start_time: Instant, -) -> Result { - let Some(proxy_url) = upstream.proxy_url.as_deref() else { - return send_upstream_request_once( - state, - method, - provider, - upstream, - inbound_path, - upstream_path_with_query, - upstream_url, - request_headers, - body, - meta, - antigravity, - request_detail, - start_time, - ) - .await; - }; - send_codex_with_fallback( - state, - method, - provider, - upstream, - inbound_path, - upstream_path_with_query, - upstream_url, - request_headers, - body, - meta, - request_detail, - start_time, - proxy_url, - ) - .await -} - -async fn send_upstream_request_once( - state: &ProxyState, - method: Method, - provider: &str, - upstream: &UpstreamRuntime, - inbound_path: &str, - upstream_path_with_query: &str, - upstream_url: &str, - request_headers: &HeaderMap, - body: &ReplayableBody, - meta: &RequestMeta, - antigravity: Option<&super::AntigravityRequestInfo>, - request_detail: Option<&RequestDetailSnapshot>, - start_time: Instant, -) -> Result { - log_debug_headers_body( - "upstream.request", - Some(request_headers), - Some(body), - DEBUG_UPSTREAM_LOG_LIMIT_BYTES, - ) - .await; - let client = state - .http_clients - .client_for_proxy_url(upstream.proxy_url.as_deref()) - .map_err(|message| { - AttemptOutcome::Fatal(http::error_response(StatusCode::BAD_GATEWAY, message)) - })?; - let upstream_body = request::build_upstream_body( - provider, - upstream, - upstream_path_with_query, - body, - meta, - antigravity, - ) - .await?; - match send_request_once( - client, - &method, - upstream_url, - request_headers, - upstream_body, - ) - .await - { - Ok(result) => Ok(result), - Err(SendFailure::Transport(err)) => Err(map_upstream_error( - state, - provider, - 
upstream,
-            inbound_path,
-            meta,
-            request_detail,
-            err,
-            start_time,
-        )),
-        Err(SendFailure::Timeout) => Err(handle_upstream_timeout(
-            state,
-            provider,
-            upstream,
-            inbound_path,
-            meta,
-            request_detail,
-            start_time,
-        )),
-    }
-}
-
-async fn send_codex_with_fallback(
-    state: &ProxyState,
-    method: Method,
-    provider: &str,
-    upstream: &UpstreamRuntime,
-    inbound_path: &str,
-    upstream_path_with_query: &str,
-    upstream_url: &str,
-    request_headers: &HeaderMap,
-    body: &ReplayableBody,
-    meta: &RequestMeta,
-    request_detail: Option<&RequestDetailSnapshot>,
-    start_time: Instant,
-    proxy_url: &str,
-) -> Result<reqwest::Response, AttemptOutcome> {
-    // Codex proxy fallback: retry with socks5h / http1_only variants to work around DNS/ALPN/TLS compatibility issues.
-    let attempts = build_codex_send_attempts(proxy_url);
-    let mut last_error: Option<reqwest::Error> = None;
-    for attempt in attempts {
-        match send_codex_attempt(
-            state,
-            &method,
-            provider,
-            upstream,
-            inbound_path,
-            upstream_path_with_query,
-            upstream_url,
-            request_headers,
-            body,
-            meta,
-            request_detail,
-            start_time,
-            &attempt,
-        )
-        .await
-        {
-            Ok(result) => return Ok(result),
-            Err(CodexAttemptError::Retry(err)) => last_error = Some(err),
-            Err(CodexAttemptError::Fatal(outcome)) => return Err(outcome),
-        }
-    }
-    Err(finalize_codex_fallback(
-        state,
-        provider,
-        upstream,
-        inbound_path,
-        meta,
-        request_detail,
-        start_time,
-        last_error,
-    ))
-}
-
-async fn send_codex_attempt(
-    state: &ProxyState,
-    method: &Method,
-    provider: &str,
-    upstream: &UpstreamRuntime,
-    inbound_path: &str,
-    upstream_path_with_query: &str,
-    upstream_url: &str,
-    request_headers: &HeaderMap,
-    body: &ReplayableBody,
-    meta: &RequestMeta,
-    request_detail: Option<&RequestDetailSnapshot>,
-    start_time: Instant,
-    attempt: &CodexSendAttempt,
-) -> Result<reqwest::Response, CodexAttemptError> {
-    log_debug_headers_body(
-        "upstream.request",
-        Some(request_headers),
-        Some(body),
-        DEBUG_UPSTREAM_LOG_LIMIT_BYTES,
-    )
-    .await;
-    let client = build_codex_client(attempt.proxy_url.as_deref(), attempt.http1_only).map_err(|message| {
-        CodexAttemptError::Fatal(AttemptOutcome::Fatal(http::error_response(StatusCode::BAD_GATEWAY, message)))
-    })?;
-    let upstream_body = request::build_upstream_body(
-        provider,
-        upstream,
-        upstream_path_with_query,
-        body,
-        meta,
-        None,
-    )
-    .await
-    .map_err(CodexAttemptError::Fatal)?;
-    match send_request_once(
-        client,
-        method,
-        upstream_url,
-        request_headers,
-        upstream_body,
-    )
-    .await
-    {
-        Ok(result) => Ok(result),
-        Err(SendFailure::Timeout) => Err(CodexAttemptError::Fatal(handle_upstream_timeout(
-            state,
-            provider,
-            upstream,
-            inbound_path,
-            meta,
-            request_detail,
-            start_time,
-        ))),
-        Err(SendFailure::Transport(err)) => {
-            if should_retry_codex_send(&err) {
-                return Err(CodexAttemptError::Retry(err));
-            }
-            Err(CodexAttemptError::Fatal(map_upstream_error(
-                state,
-                provider,
-                upstream,
-                inbound_path,
-                meta,
-                request_detail,
-                err,
-                start_time,
-            )))
-        }
-    }
-}
-
-fn finalize_codex_fallback(
-    state: &ProxyState,
-    provider: &str,
-    upstream: &UpstreamRuntime,
-    inbound_path: &str,
-    meta: &RequestMeta,
-    request_detail: Option<&RequestDetailSnapshot>,
-    start_time: Instant,
-    last_error: Option<reqwest::Error>,
-) -> AttemptOutcome {
-    let Some(err) = last_error else {
-        return AttemptOutcome::Fatal(http::error_response(
-            StatusCode::BAD_GATEWAY,
-            "Codex upstream request failed.".to_string(),
-        ));
-    };
-    map_upstream_error(
-        state,
-        provider,
-        upstream,
-        inbound_path,
-        meta,
-        request_detail,
-        err,
-        start_time,
-    )
-}
-
-async fn send_request_once(
-    client: Client,
-    method: &Method,
-    upstream_url: &str,
-
request_headers: &HeaderMap, - upstream_body: reqwest::Body, -) -> Result { - let upstream_res = timeout( - UPSTREAM_NO_DATA_TIMEOUT, - client - .request(method.clone(), upstream_url) - .headers(request_headers.clone()) - .body(upstream_body) - .send(), - ) - .await; - match upstream_res { - Ok(Ok(result)) => Ok(result), - Ok(Err(err)) => Err(SendFailure::Transport(err)), - Err(_) => Err(SendFailure::Timeout), - } -} - -struct CodexSendAttempt { - proxy_url: Option, - http1_only: bool, -} - -enum SendFailure { - Transport(reqwest::Error), - Timeout, -} - -enum CodexAttemptError { - Retry(reqwest::Error), - Fatal(AttemptOutcome), -} - -fn build_codex_send_attempts(proxy_url: &str) -> Vec { - let mut attempts = Vec::new(); - attempts.push(CodexSendAttempt { - proxy_url: Some(proxy_url.to_string()), - http1_only: false, - }); - if let Some(upgraded) = upgrade_socks5(proxy_url) { - attempts.push(CodexSendAttempt { - proxy_url: Some(upgraded), - http1_only: false, - }); - } - attempts.push(CodexSendAttempt { - proxy_url: Some(proxy_url.to_string()), - http1_only: true, - }); - attempts -} - -fn upgrade_socks5(proxy_url: &str) -> Option { - let value = proxy_url.trim(); - if value.starts_with("socks5h://") { - return None; - } - if value.starts_with("socks5://") { - return Some(value.replacen("socks5://", "socks5h://", 1)); - } - None -} - -fn antigravity_fallback_urls(base_url: &str, path: &str) -> Vec { - let mut urls = Vec::new(); - let path = if path.starts_with('/') { - path.to_string() - } else { - format!("/{path}") - }; - let base_url = base_url.trim_end_matches('/'); - let bases = if base_url.is_empty() || base_url == antigravity_endpoints::BASE_URL_PROD { - antigravity_endpoints::BASE_URLS - .iter() - .map(|base| base.to_string()) - .collect::>() - } else { - antigravity_endpoints::build_base_url_list(base_url) - }; - for base in bases { - let base = base.trim_end_matches('/'); - urls.push(format!("{base}{path}")); - } - urls -} - -fn antigravity_request_headers( - base: &HeaderMap, - meta: &RequestMeta, - antigravity: Option<&super::AntigravityRequestInfo>, -) -> HeaderMap { - let mut headers = base.clone(); - let user_agent = antigravity - .map(|info| info.user_agent.clone()) - .unwrap_or_else(antigravity_endpoints::default_user_agent); - if let Ok(value) = HeaderValue::from_str(&user_agent) { - headers.insert(USER_AGENT, value); - } - headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); - headers.insert(ACCEPT_ENCODING, HeaderValue::from_static("identity")); - let accept = if meta.stream { - "text/event-stream" - } else { - "application/json" - }; - headers.insert(ACCEPT, HeaderValue::from_static(accept)); - headers -} - -fn build_codex_client(proxy_url: Option<&str>, http1_only: bool) -> Result { - let mut builder = Client::builder(); - if let Some(proxy_url) = proxy_url.map(str::trim).filter(|value| !value.is_empty()) { - let proxy = Proxy::all(proxy_url) - .map_err(|_| "proxy_url is invalid or not supported.".to_string())?; - builder = builder.proxy(proxy); - } else { - builder = builder.no_proxy(); - } - if http1_only { - builder = builder.http1_only(); - } - builder - .build() - .map_err(|err| format!("Failed to build Codex upstream client: {err}")) -} - -fn should_retry_codex_send(err: &reqwest::Error) -> bool { - err.is_connect() || err.is_request() -} - -fn handle_upstream_timeout( - state: &ProxyState, - provider: &str, - upstream: &UpstreamRuntime, - inbound_path: &str, - meta: &RequestMeta, - request_detail: Option<&RequestDetailSnapshot>, - 
start_time: Instant, -) -> AttemptOutcome { - let message = format!( - "Upstream did not respond within {}s.", - UPSTREAM_NO_DATA_TIMEOUT.as_secs() - ); - result::log_upstream_error_if_needed( - &state.log, - request_detail, - meta, - provider, - &upstream.id, - inbound_path, - StatusCode::GATEWAY_TIMEOUT, - message.clone(), - start_time, - ); - AttemptOutcome::Retryable { - message, - response: None, - is_timeout: true, - } -} - -fn map_upstream_error( - state: &ProxyState, - provider: &str, - upstream: &UpstreamRuntime, - inbound_path: &str, - meta: &RequestMeta, - request_detail: Option<&RequestDetailSnapshot>, - err: reqwest::Error, - start_time: Instant, -) -> AttemptOutcome { - let message = sanitize_upstream_error(provider, &err); - if is_retryable_error(&err) { - let status = if err.is_timeout() { - StatusCode::GATEWAY_TIMEOUT - } else { - StatusCode::BAD_GATEWAY - }; - result::log_upstream_error_if_needed( - &state.log, - request_detail, - meta, - provider, - &upstream.id, - inbound_path, - status, - message.clone(), - start_time, - ); - return AttemptOutcome::Retryable { - message, - response: None, - is_timeout: err.is_timeout(), - }; - } - let error_message = format!("Upstream request failed: {message}"); - result::log_upstream_error_if_needed( - &state.log, - request_detail, - meta, - provider, - &upstream.id, - inbound_path, - StatusCode::BAD_GATEWAY, - error_message.clone(), - start_time, - ); - AttemptOutcome::Fatal(http::error_response( - StatusCode::BAD_GATEWAY, - error_message, - )) -} - -fn should_refresh_kiro(provider: &str, response: &reqwest::Response) -> bool { - provider == "kiro" - && (response.status() == StatusCode::UNAUTHORIZED - || response.status() == StatusCode::FORBIDDEN) -} - -async fn refresh_kiro_account( - state: &ProxyState, - upstream: &UpstreamRuntime, -) -> Result<(), AttemptOutcome> { - let Some(account_id) = upstream.kiro_account_id.as_deref() else { - return Err(AttemptOutcome::Fatal(http::error_response( - StatusCode::UNAUTHORIZED, - "Kiro account is not configured.", - ))); - }; - state - .kiro_accounts - .refresh_account(account_id) - .await - .map_err(|err| { - AttemptOutcome::Fatal(http::error_response(StatusCode::UNAUTHORIZED, err)) - }) -} diff --git a/src-tauri/src/proxy/upstream/kiro.rs b/src-tauri/src/proxy/upstream/kiro.rs deleted file mode 100644 index 3303936..0000000 --- a/src-tauri/src/proxy/upstream/kiro.rs +++ /dev/null @@ -1,555 +0,0 @@ -use axum::body::Bytes; -use axum::body::Body; -use axum::http::{HeaderMap, Method, StatusCode}; -use serde_json::Value; -use std::time::{Duration, Instant}; -use super::{result, AttemptOutcome}; -use super::kiro_http::{ - build_client, - handle_send_error, - read_request_json, - refresh_kiro_account, - send_kiro_request, -}; -use crate::proxy::http; -use crate::proxy::kiro::{ - build_payload_from_chat, build_payload_from_claude, build_payload_from_responses, - determine_agentic_mode, map_model_to_kiro, select_endpoints, BuildPayloadResult, - KiroEndpointConfig, -}; -use crate::proxy::openai_compat::FormatTransform; -use crate::proxy::request_body::ReplayableBody; -use crate::proxy::{ProxyState, RequestMeta}; -use crate::proxy::{config::UpstreamRuntime, request_detail::RequestDetailSnapshot}; -use crate::kiro::KiroTokenRecord; - -const MAX_KIRO_RETRIES: usize = 2; -const MAX_KIRO_BACKOFF_SECS: u64 = 30; - -pub(super) async fn attempt_kiro_upstream( - state: &ProxyState, - method: Method, - upstream: &UpstreamRuntime, - inbound_path: &str, - headers: &HeaderMap, - body: &ReplayableBody, - meta: 
&RequestMeta, - response_transform: FormatTransform, - request_detail: Option, -) -> AttemptOutcome { - let mut context = match prepare_kiro_context( - state, - upstream, - body, - meta, - headers, - method, - inbound_path, - response_transform, - request_detail, - ) - .await - { - Ok(context) => context, - Err(outcome) => return outcome, - }; - run_kiro_endpoints(&mut context).await -} - -struct KiroContext<'a> { - state: &'a ProxyState, - method: Method, - upstream: &'a UpstreamRuntime, - inbound_path: &'a str, - headers: &'a HeaderMap, - response_transform: FormatTransform, - request_detail: Option, - mapped_meta: RequestMeta, - request_value: Value, - account_id: String, - record: KiroTokenRecord, - profile_arn: Option, - endpoints: Vec, - is_idc: bool, - model_id: String, - is_agentic: bool, - is_chat_only: bool, - source_format: KiroSourceFormat, - client: reqwest::Client, -} - -#[derive(Clone, Copy, Debug)] -enum KiroSourceFormat { - OpenAIChat, - Responses, - Anthropic, -} - -enum EndpointOutcome { - Continue, - Done(AttemptOutcome), -} - -enum ResponseAction { - RetryAfter(Duration), - RefreshAndRetry, - NextEndpoint, - Finalize(reqwest::Response, Instant), - Return(AttemptOutcome), -} - -async fn prepare_kiro_context<'a>( - state: &'a ProxyState, - upstream: &'a UpstreamRuntime, - body: &ReplayableBody, - meta: &RequestMeta, - headers: &'a HeaderMap, - method: Method, - inbound_path: &'a str, - response_transform: FormatTransform, - request_detail: Option, -) -> Result, AttemptOutcome> { - let mapped_meta = super::build_mapped_meta(meta, upstream, "kiro"); - let request_value = read_request_json(state, body).await?; - let account_id = resolve_account_id(upstream)?; - let record = load_account_record(state, &account_id).await?; - let is_idc = record.auth_method.trim().eq_ignore_ascii_case("idc"); - let profile_arn = resolve_profile_arn(&record); - let endpoints = resolve_endpoints(state, upstream, is_idc); - let (model_id, is_agentic, is_chat_only) = resolve_model(&mapped_meta); - let source_format = resolve_source_format(response_transform); - let client = build_client(state, upstream)?; - - Ok(KiroContext { - state, - method, - upstream, - inbound_path, - headers, - response_transform, - request_detail, - mapped_meta, - request_value, - account_id, - record, - profile_arn, - endpoints, - is_idc, - model_id, - is_agentic, - is_chat_only, - source_format, - client, - }) -} - -async fn run_kiro_endpoints(context: &mut KiroContext<'_>) -> AttemptOutcome { - let endpoints = context.endpoints.clone(); - let total = endpoints.len(); - for (index, endpoint) in endpoints.iter().enumerate() { - let is_last = index + 1 >= total; - match attempt_endpoint(context, endpoint, is_last).await { - EndpointOutcome::Continue => continue, - EndpointOutcome::Done(outcome) => return outcome, - } - } - - AttemptOutcome::Fatal(http::error_response( - StatusCode::BAD_GATEWAY, - "Kiro upstream request failed.", - )) -} - -async fn attempt_endpoint( - context: &mut KiroContext<'_>, - endpoint: &KiroEndpointConfig, - is_last: bool, -) -> EndpointOutcome { - let mut payload = match build_endpoint_payload(context, endpoint).await { - Ok(payload) => payload, - Err(outcome) => return EndpointOutcome::Done(outcome), - }; - - for attempt in 0..=MAX_KIRO_RETRIES { - let (response, start_time) = - match send_endpoint_request(context, endpoint, &payload.payload).await { - Ok(result) => result, - Err(outcome) => return EndpointOutcome::Done(outcome), - }; - - match handle_response_action(context, response, 
start_time, attempt, is_last) - .await - { - ResponseAction::RetryAfter(delay) => { - tokio::time::sleep(delay).await; - continue; - } - ResponseAction::RefreshAndRetry => { - match refresh_and_rebuild_payload(context, endpoint).await { - Ok(updated) => payload = updated, - Err(outcome) => return EndpointOutcome::Done(outcome), - } - continue; - } - ResponseAction::NextEndpoint => return EndpointOutcome::Continue, - ResponseAction::Finalize(response, start_time) => { - return EndpointOutcome::Done( - finalize_response( - context.state, - &context.mapped_meta, - context.upstream, - context.inbound_path, - context.response_transform, - context.request_detail.clone(), - response, - false, - start_time, - ) - .await, - ); - } - ResponseAction::Return(outcome) => return EndpointOutcome::Done(outcome), - } - } - - EndpointOutcome::Done(AttemptOutcome::Fatal(http::error_response( - StatusCode::BAD_GATEWAY, - "Kiro upstream request failed.", - ))) -} - -async fn build_endpoint_payload( - context: &KiroContext<'_>, - endpoint: &KiroEndpointConfig, -) -> Result { - let payload = match context.source_format { - KiroSourceFormat::OpenAIChat => build_payload_from_chat( - &context.request_value, - &context.model_id, - context.profile_arn.as_deref(), - endpoint.origin, - context.is_agentic, - context.is_chat_only, - context.headers, - ), - KiroSourceFormat::Anthropic => { - build_payload_from_anthropic(context, endpoint.origin).await - } - KiroSourceFormat::Responses => build_payload_from_responses( - &context.request_value, - &context.model_id, - context.profile_arn.as_deref(), - endpoint.origin, - context.is_agentic, - context.is_chat_only, - context.headers, - ), - }; - payload.map_err(|message| { - AttemptOutcome::Fatal(http::error_response(StatusCode::BAD_REQUEST, message)) - }) -} - -async fn handle_response_action( - context: &mut KiroContext<'_>, - response: reqwest::Response, - start_time: Instant, - attempt: usize, - is_last: bool, -) -> ResponseAction { - let status = response.status(); - // Kiro-specific retry/fallback: 5xx backoff, 401 refresh, 403 token-only refresh, 429 endpoint switch. 
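// A minimal illustrative sketch, not part of the original file: the 5xx
// backoff ladder implemented by `backoff_delay` below, assuming
// MAX_KIRO_BACKOFF_SECS = 30. With MAX_KIRO_RETRIES = 2, the observed
// delays are 1s after the first failed attempt and 2s after the second.
fn example_backoff_secs(attempt: usize) -> u64 {
    // attempt 0 -> 1s, attempt 1 -> 2s, attempt 2 -> 4s, ..., capped at 30s
    (1u64 << attempt).min(30)
}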
-    if status == StatusCode::TOO_MANY_REQUESTS {
-        return if is_last {
-            ResponseAction::Finalize(response, start_time)
-        } else {
-            ResponseAction::NextEndpoint
-        };
-    }
-    if status.is_server_error() {
-        if attempt < MAX_KIRO_RETRIES {
-            return ResponseAction::RetryAfter(backoff_delay(attempt));
-        }
-        return ResponseAction::Finalize(response, start_time);
-    }
-    if status == StatusCode::UNAUTHORIZED {
-        if attempt < MAX_KIRO_RETRIES {
-            return ResponseAction::RefreshAndRetry;
-        }
-        return ResponseAction::Finalize(response, start_time);
-    }
-    if status == StatusCode::FORBIDDEN {
-        return handle_forbidden_response(context, response, start_time, attempt).await;
-    }
-    if status == StatusCode::PAYMENT_REQUIRED {
-        return ResponseAction::Finalize(response, start_time);
-    }
-
-    ResponseAction::Finalize(response, start_time)
-}
-
-async fn handle_forbidden_response(
-    context: &mut KiroContext<'_>,
-    response: reqwest::Response,
-    start_time: Instant,
-    attempt: usize,
-) -> ResponseAction {
-    let status = response.status();
-    let headers = response.headers().clone();
-    let body = match response.bytes().await {
-        Ok(bytes) => bytes,
-        Err(err) => {
-            let message = format!("Failed to read upstream response: {err}");
-            return ResponseAction::Return(AttemptOutcome::Fatal(http::error_response(
-                StatusCode::BAD_GATEWAY,
-                message,
-            )));
-        }
-    };
-    let body_text = String::from_utf8_lossy(&body);
-
-    if contains_suspended_flag(&body_text) {
-        let outcome = build_error_outcome(context, status, &headers, body, start_time);
-        return ResponseAction::Return(outcome);
-    }
-
-    if contains_token_error(&body_text) && attempt < MAX_KIRO_RETRIES {
-        return ResponseAction::RefreshAndRetry;
-    }
-
-    let outcome = build_error_outcome(context, status, &headers, body, start_time);
-    ResponseAction::Return(outcome)
-}
-
-async fn refresh_and_rebuild_payload(
-    context: &mut KiroContext<'_>,
-    endpoint: &KiroEndpointConfig,
-) -> Result<BuildPayloadResult, AttemptOutcome> {
-    refresh_kiro_account(context.state, &context.account_id).await?;
-    context.record = load_account_record(context.state, &context.account_id).await?;
-    let was_idc = context.is_idc;
-    context.is_idc = context
-        .record
-        .auth_method
-        .trim()
-        .eq_ignore_ascii_case("idc");
-    if context.is_idc != was_idc {
-        context.endpoints = resolve_endpoints(context.state, context.upstream, context.is_idc);
-    }
-    build_endpoint_payload(context, endpoint).await
-}
-
-fn backoff_delay(attempt: usize) -> Duration {
-    let exp = 1u64 << attempt;
-    Duration::from_secs(exp.min(MAX_KIRO_BACKOFF_SECS))
-}
-
-fn contains_suspended_flag(body: &str) -> bool {
-    let upper = body.to_ascii_uppercase();
-    upper.contains("SUSPENDED") || upper.contains("TEMPORARILY_SUSPENDED")
-}
-
-fn contains_token_error(body: &str) -> bool {
-    let lower = body.to_ascii_lowercase();
-    lower.contains("token")
-        || lower.contains("expired")
-        || lower.contains("invalid")
-        || lower.contains("unauthorized")
-}
-
-fn build_error_outcome(
-    context: &KiroContext<'_>,
-    status: StatusCode,
-    headers: &reqwest::header::HeaderMap,
-    body: Bytes,
-    start_time: Instant,
-) -> AttemptOutcome {
-    let message = summarize_error_body(&body);
-    result::log_upstream_error_if_needed(
-        &context.state.log,
-        context.request_detail.as_ref(),
-        &context.mapped_meta,
-        "kiro",
-        &context.upstream.id,
-        context.inbound_path,
-        status,
-        message,
-        start_time,
-    );
-    AttemptOutcome::Success(build_passthrough_response(status, headers, body))
-}
-
-fn build_passthrough_response(
-    status: StatusCode,
-    headers: &reqwest::header::HeaderMap,
-    body: Bytes,
-)
-> axum::response::Response { - let filtered = http::filter_response_headers(headers); - http::build_response(status, filtered, Body::from(body)) -} - -fn summarize_error_body(body: &Bytes) -> String { - const LIMIT: usize = 2048; - let text = String::from_utf8_lossy(body); - if text.len() > LIMIT { - format!("{}…", &text[..LIMIT]) - } else { - text.to_string() - } -} - -async fn build_payload_from_anthropic( - context: &KiroContext<'_>, - origin: &str, -) -> Result { - build_payload_from_claude( - &context.request_value, - &context.model_id, - context.profile_arn.as_deref(), - origin, - context.is_agentic, - context.is_chat_only, - context.headers, - ) -} - -async fn send_endpoint_request( - context: &KiroContext<'_>, - endpoint: &KiroEndpointConfig, - payload: &[u8], -) -> Result<(reqwest::Response, Instant), AttemptOutcome> { - let start_time = Instant::now(); - let response = match send_kiro_request( - &context.client, - context.method.clone(), - endpoint.url, - &context.record.access_token, - endpoint.amz_target, - context.is_idc, - payload, - context.upstream.header_overrides.as_deref(), - ) - .await - { - Ok(response) => response, - Err(err) => { - let outcome = handle_send_error( - context.state, - &context.mapped_meta, - context.upstream, - context.inbound_path, - context.response_transform, - context.request_detail.clone(), - err, - start_time, - ) - .await; - return Err(outcome); - } - }; - Ok((response, start_time)) -} - -fn resolve_account_id(upstream: &UpstreamRuntime) -> Result { - upstream - .kiro_account_id - .as_ref() - .map(|value| value.to_string()) - .ok_or_else(|| { - AttemptOutcome::Fatal(http::error_response( - StatusCode::UNAUTHORIZED, - "Kiro account is not configured.", - )) - }) -} - -fn resolve_profile_arn(record: &KiroTokenRecord) -> Option { - match record.auth_method.as_str() { - "builder-id" | "idc" => None, - _ => record.profile_arn.clone(), - } -} - -async fn load_account_record( - state: &ProxyState, - account_id: &str, -) -> Result { - state - .kiro_accounts - .get_account_record(account_id) - .await - .map_err(|err| AttemptOutcome::Fatal(http::error_response(StatusCode::UNAUTHORIZED, err))) -} - -fn resolve_endpoints( - state: &ProxyState, - upstream: &UpstreamRuntime, - is_idc: bool, -) -> Vec { - let preferred = upstream - .kiro_preferred_endpoint - .clone() - .or(state.config.kiro_preferred_endpoint.clone()); - select_endpoints(preferred, is_idc) -} - -fn resolve_model(meta: &RequestMeta) -> (String, bool, bool) { - let model_source = meta - .mapped_model - .as_deref() - .or(meta.original_model.as_deref()) - .unwrap_or("claude-sonnet-4.5"); - let (is_agentic, is_chat_only) = - determine_agentic_mode(meta.original_model.as_deref().unwrap_or(model_source)); - (map_model_to_kiro(model_source), is_agentic, is_chat_only) -} - -fn resolve_source_format(transform: FormatTransform) -> KiroSourceFormat { - match transform { - FormatTransform::KiroToChat => KiroSourceFormat::OpenAIChat, - FormatTransform::KiroToAnthropic => KiroSourceFormat::Anthropic, - _ => KiroSourceFormat::Responses, - } -} - -async fn finalize_response( - state: &ProxyState, - meta: &RequestMeta, - upstream: &UpstreamRuntime, - inbound_path: &str, - response_transform: FormatTransform, - request_detail: Option, - response: reqwest::Response, - force_success: bool, - start_time: Instant, -) -> AttemptOutcome { - if force_success { - let output = crate::proxy::response::build_proxy_response( - meta, - "kiro", - &upstream.id, - inbound_path, - response, - state.log.clone(), - 
state.token_rate.clone(), - start_time, - response_transform, - request_detail, - ) - .await; - return AttemptOutcome::Success(output); - } - result::handle_upstream_result( - Ok(response), - meta, - "kiro", - &upstream.id, - inbound_path, - state.log.clone(), - state.token_rate.clone(), - start_time, - response_transform, - request_detail, - ) - .await -} diff --git a/src-tauri/src/proxy/upstream/kiro_headers.rs b/src-tauri/src/proxy/upstream/kiro_headers.rs deleted file mode 100644 index 376d383..0000000 --- a/src-tauri/src/proxy/upstream/kiro_headers.rs +++ /dev/null @@ -1,72 +0,0 @@ -use axum::http::{ - header::{ACCEPT, CONTENT_TYPE, USER_AGENT}, - HeaderMap, HeaderName, HeaderValue, -}; - -use crate::proxy::http; - -const KIRO_REQUEST_CONTENT_TYPE: &str = "application/x-amz-json-1.0"; -const KIRO_REQUEST_ACCEPT: &str = "*/*"; -const KIRO_AGENT_MODE_IDC: &str = "spec"; -const KIRO_AGENT_MODE_DEFAULT: &str = "vibe"; -const KIRO_OPT_OUT: &str = "true"; -const KIRO_SDK_REQUEST: &str = "attempt=1; max=3"; -const KIRO_USER_AGENT_IDC: &str = "aws-sdk-js/1.0.18 ua/2.1 os/darwin#25.0.0 lang/js md/nodejs#20.16.0 api/codewhispererstreaming#1.0.18 m/E KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1"; -const KIRO_USER_AGENT_IDC_AMZ: &str = - "aws-sdk-js/1.0.18 KiroIDE-0.2.13-66c23a8c5d15afabec89ef9954ef52a119f10d369df04d548fc6c1eac694b0d1"; -const KIRO_USER_AGENT_DEFAULT: &str = "aws-sdk-rust/1.3.9 os/macos lang/rust/1.87.0"; -const KIRO_USER_AGENT_DEFAULT_AMZ: &str = - "aws-sdk-rust/1.3.9 ua/2.1 api/ssooidc/1.88.0 os/macos lang/rust/1.87.0 m/E app/AmazonQ-For-CLI"; - -const HEADER_AMZ_TARGET: HeaderName = HeaderName::from_static("x-amz-target"); -const HEADER_AMZ_USER_AGENT: HeaderName = HeaderName::from_static("x-amz-user-agent"); -const HEADER_AMZ_SDK_REQUEST: HeaderName = HeaderName::from_static("amz-sdk-request"); -const HEADER_AMZ_SDK_INVOCATION_ID: HeaderName = HeaderName::from_static("amz-sdk-invocation-id"); -const HEADER_KIRO_AGENT_MODE: HeaderName = HeaderName::from_static("x-amzn-kiro-agent-mode"); -const HEADER_KIRO_OPTOUT: HeaderName = HeaderName::from_static("x-amzn-codewhisperer-optout"); - -pub(super) fn build_kiro_headers( - access_token: &str, - amz_target: &str, - is_idc: bool, -) -> HeaderMap { - let mut headers = HeaderMap::new(); - headers.insert(CONTENT_TYPE, HeaderValue::from_static(KIRO_REQUEST_CONTENT_TYPE)); - headers.insert(ACCEPT, HeaderValue::from_static(KIRO_REQUEST_ACCEPT)); - if let Ok(value) = HeaderValue::from_str(amz_target) { - headers.insert(HEADER_AMZ_TARGET, value); - } - headers.insert(HEADER_AMZ_SDK_REQUEST, HeaderValue::from_static(KIRO_SDK_REQUEST)); - if let Ok(value) = HeaderValue::from_str(&crate::proxy::kiro::utils::random_uuid()) { - headers.insert(HEADER_AMZ_SDK_INVOCATION_ID, value); - } - headers.insert( - HEADER_KIRO_AGENT_MODE, - HeaderValue::from_static(if is_idc { - KIRO_AGENT_MODE_IDC - } else { - KIRO_AGENT_MODE_DEFAULT - }), - ); - headers.insert(HEADER_KIRO_OPTOUT, HeaderValue::from_static(KIRO_OPT_OUT)); - headers.insert( - USER_AGENT, - HeaderValue::from_static(if is_idc { - KIRO_USER_AGENT_IDC - } else { - KIRO_USER_AGENT_DEFAULT - }), - ); - headers.insert( - HEADER_AMZ_USER_AGENT, - HeaderValue::from_static(if is_idc { - KIRO_USER_AGENT_IDC_AMZ - } else { - KIRO_USER_AGENT_DEFAULT_AMZ - }), - ); - if let Some(auth) = http::bearer_header(access_token) { - headers.insert(axum::http::header::AUTHORIZATION, auth); - } - headers -} diff --git a/src-tauri/src/proxy/upstream/kiro_http.rs 
b/src-tauri/src/proxy/upstream/kiro_http.rs deleted file mode 100644 index c7c50d3..0000000 --- a/src-tauri/src/proxy/upstream/kiro_http.rs +++ /dev/null @@ -1,151 +0,0 @@ -use axum::http::{Method, StatusCode}; -use tokio::time::timeout; - -use super::request; -use super::{result, AttemptOutcome}; -use super::kiro_headers::build_kiro_headers; -use crate::proxy::http; -use crate::proxy::request_body::ReplayableBody; -use crate::proxy::request_detail::RequestDetailSnapshot; -use crate::proxy::{ProxyState, RequestMeta, UPSTREAM_NO_DATA_TIMEOUT}; -use crate::proxy::{config::UpstreamRuntime}; -use crate::proxy::openai_compat::FormatTransform; - -pub(super) enum KiroSendError { - Timeout, - Upstream(reqwest::Error), -} - -pub(super) fn build_client( - state: &ProxyState, - upstream: &UpstreamRuntime, -) -> Result { - state - .http_clients - .client_for_proxy_url(upstream.proxy_url.as_deref()) - .map_err(|message| { - AttemptOutcome::Fatal(http::error_response(StatusCode::BAD_GATEWAY, message)) - }) -} - -pub(super) async fn read_request_json( - state: &ProxyState, - body: &ReplayableBody, -) -> Result { - let Some(bytes) = body - .read_bytes_if_small(state.config.max_request_body_bytes) - .await - .map_err(|err| { - AttemptOutcome::Fatal(http::error_response( - StatusCode::BAD_REQUEST, - format!("Failed to read request body: {err}"), - )) - })? - else { - return Err(AttemptOutcome::Fatal(http::error_response( - StatusCode::PAYLOAD_TOO_LARGE, - "Request body is too large to transform.", - ))); - }; - serde_json::from_slice::(&bytes).map_err(|_| { - AttemptOutcome::Fatal(http::error_response( - StatusCode::BAD_REQUEST, - "Request body must be JSON.", - )) - }) -} - -pub(super) async fn send_kiro_request( - client: &reqwest::Client, - method: Method, - url: &str, - access_token: &str, - amz_target: &str, - is_idc: bool, - payload: &[u8], - overrides: Option<&[crate::proxy::config::HeaderOverride]>, -) -> Result { - let mut request_headers = build_kiro_headers(access_token, amz_target, is_idc); - if let Some(overrides) = overrides { - request::apply_header_overrides(&mut request_headers, overrides); - } - - let result = timeout( - UPSTREAM_NO_DATA_TIMEOUT, - client - .request(method, url) - .headers(request_headers) - .body(payload.to_vec()) - .send(), - ) - .await; - match result { - Ok(Ok(response)) => Ok(response), - Ok(Err(err)) => Err(KiroSendError::Upstream(err)), - Err(_) => Err(KiroSendError::Timeout), - } -} - -pub(super) async fn refresh_kiro_account( - state: &ProxyState, - account_id: &str, -) -> Result<(), AttemptOutcome> { - state - .kiro_accounts - .refresh_account(account_id) - .await - .map_err(|err| { - AttemptOutcome::Fatal(http::error_response(StatusCode::UNAUTHORIZED, err)) - }) -} - -pub(super) async fn handle_send_error( - state: &ProxyState, - meta: &RequestMeta, - upstream: &UpstreamRuntime, - inbound_path: &str, - response_transform: FormatTransform, - request_detail: Option, - err: KiroSendError, - start_time: std::time::Instant, -) -> AttemptOutcome { - match err { - KiroSendError::Upstream(err) => { - result::handle_upstream_result( - Err(err), - meta, - "kiro", - &upstream.id, - inbound_path, - state.log.clone(), - state.token_rate.clone(), - start_time, - response_transform, - request_detail, - ) - .await - } - KiroSendError::Timeout => { - let message = format!( - "Upstream did not respond within {}s.", - UPSTREAM_NO_DATA_TIMEOUT.as_secs() - ); - result::log_upstream_error_if_needed( - &state.log, - request_detail.as_ref(), - meta, - "kiro", - &upstream.id, - 
inbound_path, - StatusCode::GATEWAY_TIMEOUT, - message.clone(), - start_time, - ); - AttemptOutcome::Retryable { - message, - response: None, - is_timeout: true, - } - } - } -} diff --git a/src-tauri/src/proxy/upstream/request.rs b/src-tauri/src/proxy/upstream/request.rs deleted file mode 100644 index a4583af..0000000 --- a/src-tauri/src/proxy/upstream/request.rs +++ /dev/null @@ -1,318 +0,0 @@ -use axum::{ - body::Bytes, - http::{ - header::{HeaderName, HeaderValue, CONTENT_LENGTH, HOST}, - HeaderMap, StatusCode, - }, -}; -use serde_json::Value; - -use super::{ - utils::{ensure_query_param, extract_query_param}, - AttemptOutcome, -}; -use super::super::{ - codex_compat, - config::{HeaderOverride, UpstreamRuntime}, - http, - model, - request_body::ReplayableBody, - RequestMeta, -}; -use super::super::http::RequestAuth; -use crate::proxy::server_helpers::log_debug_headers_body; - -const ANTHROPIC_VERSION_HEADER: &str = "anthropic-version"; -const DEFAULT_ANTHROPIC_VERSION: &str = "2023-06-01"; -const GEMINI_API_KEY_QUERY: &str = "key"; -const GEMINI_API_KEY_HEADER: HeaderName = HeaderName::from_static("x-goog-api-key"); -const OPENAI_RESPONSES_PATH: &str = "/v1/responses"; -// Keep in sync with server_helpers request transform limit (20 MiB). -const REQUEST_FILTER_LIMIT_BYTES: usize = 20 * 1024 * 1024; -const DEBUG_UPSTREAM_LOG_LIMIT_BYTES: usize = usize::MAX; - -pub(super) fn split_path_query(path_with_query: &str) -> (&str, Option<&str>) { - match path_with_query.split_once('?') { - Some((path, query)) => (path, Some(query)), - None => (path_with_query, None), - } -} - -pub(super) fn build_request_headers( - provider: &str, - headers: &HeaderMap, - auth: http::UpstreamAuthHeader, - extra_headers: Option<&HeaderMap>, - header_overrides: Option<&[HeaderOverride]>, -) -> HeaderMap { - let mut request_headers = http::build_upstream_headers(headers, auth); - if provider == "anthropic" && !request_headers.contains_key(ANTHROPIC_VERSION_HEADER) { - // Anthropic 官方 API 需要 `anthropic-version`;缺省时补一个稳定默认值,允许客户端覆盖。 - request_headers.insert( - ANTHROPIC_VERSION_HEADER, - HeaderValue::from_static(DEFAULT_ANTHROPIC_VERSION), - ); - } - codex_compat::apply_codex_headers_if_needed(provider, &mut request_headers, headers); - - if let Some(extra_headers) = extra_headers { - for (name, value) in extra_headers.iter() { - request_headers.insert(name.clone(), value.clone()); - } - } - - if let Some(overrides) = header_overrides { - apply_header_overrides(&mut request_headers, overrides); - } - request_headers -} - -pub(super) fn apply_header_overrides(request_headers: &mut HeaderMap, overrides: &[HeaderOverride]) { - for override_item in overrides { - // 屏蔽 hop-by-hop / Host / Content-Length,无论配置为何。 - if crate::proxy::http::is_hop_header(&override_item.name) - || override_item.name == HOST - || override_item.name == CONTENT_LENGTH - { - continue; - } - - match &override_item.value { - Some(value) => { - request_headers.insert(override_item.name.clone(), value.clone()); - } - None => { - request_headers.remove(&override_item.name); - } - } - } -} - -pub(super) async fn build_upstream_body( - provider: &str, - upstream: &UpstreamRuntime, - upstream_path_with_query: &str, - body: &ReplayableBody, - meta: &RequestMeta, - antigravity: Option<&super::AntigravityRequestInfo>, -) -> Result { - if provider == "antigravity" { - return build_antigravity_body(body, meta, antigravity).await; - } - let mapped_body = maybe_rewrite_request_body_model(body, meta).await?; - let mapped_source = 
mapped_body.as_ref().unwrap_or(body); - let upstream_path = split_path_query(upstream_path_with_query).0; - let reasoning_body = match super::super::server_helpers::maybe_rewrite_openai_reasoning_effort_from_model_suffix( - provider, - upstream_path, - meta, - mapped_source, - ) - .await - { - Ok(body) => body, - Err(err) => { - return Err(AttemptOutcome::Fatal(http::error_response(err.status, err.message))) - } - }; - let source = reasoning_body - .as_ref() - .or(mapped_body.as_ref()) - .unwrap_or(body); - let filtered = - maybe_filter_openai_responses_request_fields(provider, upstream, upstream_path_with_query, source) - .await?; - let final_source = filtered.as_ref().unwrap_or(source); - final_source - .to_reqwest_body() - .await - .map_err(|err| { - AttemptOutcome::Fatal(http::error_response( - StatusCode::BAD_GATEWAY, - format!("Failed to read cached request body: {err}"), - )) - }) -} - -async fn maybe_filter_openai_responses_request_fields( - provider: &str, - upstream: &UpstreamRuntime, - upstream_path_with_query: &str, - body: &ReplayableBody, -) -> Result, AttemptOutcome> { - let should_filter_prompt_cache_retention = upstream.filter_prompt_cache_retention; - let should_filter_safety_identifier = upstream.filter_safety_identifier; - if provider != "openai-response" - || (!should_filter_prompt_cache_retention && !should_filter_safety_identifier) - { - return Ok(None); - } - let upstream_path = split_path_query(upstream_path_with_query).0; - if upstream_path != OPENAI_RESPONSES_PATH { - return Ok(None); - } - - let Some(bytes) = body - .read_bytes_if_small(REQUEST_FILTER_LIMIT_BYTES) - .await - .map_err(|err| { - AttemptOutcome::Fatal(http::error_response( - StatusCode::BAD_GATEWAY, - format!("Failed to read cached request body: {err}"), - )) - })? - else { - // Best-effort: request body too large to rewrite. - return Ok(None); - }; - - let Ok(mut value) = serde_json::from_slice::(&bytes) else { - return Ok(None); - }; - let Some(object) = value.as_object_mut() else { - return Ok(None); - }; - let mut changed = false; - if should_filter_prompt_cache_retention { - changed = changed || object.remove("prompt_cache_retention").is_some(); - } - if should_filter_safety_identifier { - changed = changed || object.remove("safety_identifier").is_some(); - } - if !changed { - return Ok(None); - } - - let outbound_bytes = serde_json::to_vec(&value) - .map(Bytes::from) - .map_err(|err| { - AttemptOutcome::Fatal(http::error_response( - StatusCode::BAD_GATEWAY, - format!("Failed to serialize request: {err}"), - )) - })?; - Ok(Some(ReplayableBody::from_bytes(outbound_bytes))) -} - -async fn build_antigravity_body( - body: &ReplayableBody, - meta: &RequestMeta, - antigravity: Option<&super::AntigravityRequestInfo>, -) -> Result { - let Some(info) = antigravity else { - return Err(AttemptOutcome::Fatal(http::error_response( - StatusCode::UNAUTHORIZED, - "Antigravity account is not configured.", - ))); - }; - let Some(bytes) = body - .read_bytes_if_small(super::REQUEST_MODEL_MAPPING_LIMIT_BYTES) - .await - .map_err(|err| { - AttemptOutcome::Fatal(http::error_response( - StatusCode::BAD_GATEWAY, - format!("Failed to read request body: {err}"), - )) - })? 
- else { - return Err(AttemptOutcome::Fatal(http::error_response( - StatusCode::BAD_GATEWAY, - "Antigravity request body is too large.", - ))); - }; - let model = meta.mapped_model.as_deref().or(meta.original_model.as_deref()); - let wrapped = super::super::antigravity_compat::wrap_gemini_request( - &bytes, - model, - info.project_id.as_deref(), - &info.user_agent, - ) - .map_err(|message| { - AttemptOutcome::Fatal(http::error_response(StatusCode::BAD_GATEWAY, message)) - })?; - let wrapped_body = ReplayableBody::from_bytes(wrapped.clone()); - log_debug_headers_body( - "antigravity.wrapped", - None, - Some(&wrapped_body), - DEBUG_UPSTREAM_LOG_LIMIT_BYTES, - ) - .await; - Ok(reqwest::Body::from(wrapped)) -} - -async fn maybe_rewrite_request_body_model( - body: &ReplayableBody, - meta: &RequestMeta, -) -> Result, AttemptOutcome> { - if meta.model_override().is_none() { - return Ok(None); - } - let Some(mapped_model) = meta.mapped_model.as_deref() else { - return Ok(None); - }; - let Some(bytes) = body - .read_bytes_if_small(super::REQUEST_MODEL_MAPPING_LIMIT_BYTES) - .await - .map_err(|err| { - AttemptOutcome::Fatal(http::error_response( - StatusCode::BAD_GATEWAY, - format!("Failed to read request body: {err}"), - )) - })? - else { - return Ok(None); - }; - let Some(rewritten) = model::rewrite_request_model(&bytes, mapped_model) else { - return Ok(None); - }; - Ok(Some(ReplayableBody::from_bytes(rewritten))) -} - -// 单元测试拆到独立文件,使用 `#[path]` 以保持 `.test.rs` 命名约定。 -#[cfg(test)] -#[path = "request.test.rs"] -mod tests; - -pub(super) fn resolve_gemini_upstream( - upstream: &UpstreamRuntime, - request_auth: &RequestAuth, - upstream_path_with_query: &str, - upstream_url: &str, -) -> Result<(String, http::UpstreamAuthHeader), AttemptOutcome> { - let query_key = extract_query_param(upstream_path_with_query, GEMINI_API_KEY_QUERY); - let selected = upstream - .api_key - .as_deref() - .or_else(|| request_auth.gemini_api_key.as_deref()) - .or_else(|| query_key.as_deref()); - - let Some(api_key) = selected else { - return Err(AttemptOutcome::SkippedAuth); - }; - - let upstream_url = match ensure_query_param(upstream_url, GEMINI_API_KEY_QUERY, api_key) { - Ok(url) => url, - Err(message) => { - return Err(AttemptOutcome::Fatal(http::error_response( - StatusCode::BAD_GATEWAY, - format!("Failed to build upstream URL: {message}"), - ))) - } - }; - - let value = HeaderValue::from_str(api_key).map_err(|_| { - AttemptOutcome::Fatal(http::error_response( - StatusCode::UNAUTHORIZED, - "Upstream API key contains invalid characters.", - )) - })?; - - Ok(( - upstream_url, - http::UpstreamAuthHeader { - name: GEMINI_API_KEY_HEADER, - value, - }, - )) -} diff --git a/src-tauri/src/proxy/upstream/request.test.rs b/src-tauri/src/proxy/upstream/request.test.rs deleted file mode 100644 index 86122ee..0000000 --- a/src-tauri/src/proxy/upstream/request.test.rs +++ /dev/null @@ -1,164 +0,0 @@ -use super::*; - -#[tokio::test] -async fn filters_prompt_cache_retention_for_openai_responses_upstream() { - let upstream = UpstreamRuntime { - id: "test".to_string(), - base_url: "https://api.openai.com".to_string(), - api_key: None, - filter_prompt_cache_retention: true, - filter_safety_identifier: false, - kiro_account_id: None, - codex_account_id: None, - antigravity_account_id: None, - kiro_preferred_endpoint: None, - proxy_url: None, - priority: 0, - model_mappings: None, - header_overrides: None, - }; - let body = ReplayableBody::from_bytes(Bytes::from_static( - br#"{"model":"gpt-4o","prompt_cache_retention":"24h","input":"hi"}"#, 
- )); - - let rewritten = maybe_filter_openai_responses_request_fields( - "openai-response", - &upstream, - "/v1/responses?foo=bar", - &body, - ) - .await; - let rewritten = match rewritten { - Ok(value) => value, - Err(_) => panic!("rewrite result"), - }; - - let rewritten = rewritten.expect("should rewrite"); - let bytes = rewritten - .read_bytes_if_small(1024) - .await - .expect("read rewritten bytes") - .expect("rewritten body exists"); - let value: Value = serde_json::from_slice(&bytes).expect("json"); - - assert!(value.get("prompt_cache_retention").is_none()); - assert_eq!(value.get("model").and_then(Value::as_str), Some("gpt-4o")); -} - -#[tokio::test] -async fn filter_prompt_cache_retention_is_noop_when_disabled() { - let upstream = UpstreamRuntime { - id: "test".to_string(), - base_url: "https://api.openai.com".to_string(), - api_key: None, - filter_prompt_cache_retention: false, - filter_safety_identifier: false, - kiro_account_id: None, - codex_account_id: None, - antigravity_account_id: None, - kiro_preferred_endpoint: None, - proxy_url: None, - priority: 0, - model_mappings: None, - header_overrides: None, - }; - let body = ReplayableBody::from_bytes(Bytes::from_static( - br#"{"model":"gpt-4o","prompt_cache_retention":"24h","input":"hi"}"#, - )); - - let rewritten = maybe_filter_openai_responses_request_fields( - "openai-response", - &upstream, - "/v1/responses", - &body, - ) - .await; - let rewritten = match rewritten { - Ok(value) => value, - Err(_) => panic!("rewrite result"), - }; - - assert!(rewritten.is_none()); -} - -#[tokio::test] -async fn filters_safety_identifier_for_openai_responses_upstream() { - let upstream = UpstreamRuntime { - id: "test".to_string(), - base_url: "https://api.openai.com".to_string(), - api_key: None, - filter_prompt_cache_retention: false, - filter_safety_identifier: true, - kiro_account_id: None, - codex_account_id: None, - antigravity_account_id: None, - kiro_preferred_endpoint: None, - proxy_url: None, - priority: 0, - model_mappings: None, - header_overrides: None, - }; - let body = ReplayableBody::from_bytes(Bytes::from_static( - br#"{"model":"gpt-4o","safety_identifier":"sid_1","input":"hi"}"#, - )); - - let rewritten = maybe_filter_openai_responses_request_fields( - "openai-response", - &upstream, - "/v1/responses", - &body, - ) - .await; - let rewritten = match rewritten { - Ok(value) => value, - Err(_) => panic!("rewrite result"), - }; - - let rewritten = rewritten.expect("should rewrite"); - let bytes = rewritten - .read_bytes_if_small(1024) - .await - .expect("read rewritten bytes") - .expect("rewritten body exists"); - let value: Value = serde_json::from_slice(&bytes).expect("json"); - - assert!(value.get("safety_identifier").is_none()); - assert_eq!(value.get("prompt_cache_retention"), None); - assert_eq!(value.get("model").and_then(Value::as_str), Some("gpt-4o")); -} - -#[tokio::test] -async fn filter_safety_identifier_is_noop_when_disabled() { - let upstream = UpstreamRuntime { - id: "test".to_string(), - base_url: "https://api.openai.com".to_string(), - api_key: None, - filter_prompt_cache_retention: false, - filter_safety_identifier: false, - kiro_account_id: None, - codex_account_id: None, - antigravity_account_id: None, - kiro_preferred_endpoint: None, - proxy_url: None, - priority: 0, - model_mappings: None, - header_overrides: None, - }; - let body = ReplayableBody::from_bytes(Bytes::from_static( - br#"{"model":"gpt-4o","safety_identifier":"sid_1","input":"hi"}"#, - )); - - let rewritten = 
maybe_filter_openai_responses_request_fields( - "openai-response", - &upstream, - "/v1/responses", - &body, - ) - .await; - let rewritten = match rewritten { - Ok(value) => value, - Err(_) => panic!("rewrite result"), - }; - - assert!(rewritten.is_none()); -} diff --git a/src-tauri/src/proxy/upstream/result.rs b/src-tauri/src/proxy/upstream/result.rs deleted file mode 100644 index 450f2b3..0000000 --- a/src-tauri/src/proxy/upstream/result.rs +++ /dev/null @@ -1,145 +0,0 @@ -use std::sync::Arc; -use std::time::Instant; - -use axum::http::StatusCode; - -use super::utils::{is_retryable_error, is_retryable_status, sanitize_upstream_error}; -use super::AttemptOutcome; -use crate::proxy::http; -use crate::proxy::log::{build_log_entry, LogContext, LogWriter, UsageSnapshot}; -use crate::proxy::openai_compat::FormatTransform; -use crate::proxy::request_detail::RequestDetailSnapshot; -use crate::proxy::response::{build_proxy_response, build_proxy_response_buffered}; -use crate::proxy::token_rate::TokenRateTracker; -use crate::proxy::RequestMeta; - -pub(super) async fn handle_upstream_result( - upstream_res: Result, - meta: &RequestMeta, - provider: &str, - upstream_id: &str, - inbound_path: &str, - log: Arc, - token_rate: Arc, - start_time: Instant, - response_transform: FormatTransform, - request_detail: Option, -) -> AttemptOutcome { - match upstream_res { - Ok(res) if is_retryable_status(res.status()) => { - let response = build_proxy_response_buffered( - meta, - provider, - upstream_id, - inbound_path, - res, - log, - token_rate, - start_time, - response_transform, - request_detail.clone(), - ) - .await; - AttemptOutcome::Retryable { - message: format!("Upstream responded with {}", response.status()), - response: Some(response), - is_timeout: false, - } - } - Ok(res) => { - let response = build_proxy_response( - meta, - provider, - upstream_id, - inbound_path, - res, - log, - token_rate, - start_time, - response_transform, - request_detail.clone(), - ) - .await; - AttemptOutcome::Success(response) - } - Err(err) if is_retryable_error(&err) => { - let message = sanitize_upstream_error(provider, &err); - let status = if err.is_timeout() { - StatusCode::GATEWAY_TIMEOUT - } else { - StatusCode::BAD_GATEWAY - }; - log_upstream_error_if_needed( - &log, - request_detail.as_ref(), - meta, - provider, - upstream_id, - inbound_path, - status, - message.clone(), - start_time, - ); - AttemptOutcome::Retryable { - message, - response: None, - is_timeout: err.is_timeout(), - } - } - Err(err) => { - let message = sanitize_upstream_error(provider, &err); - log_upstream_error_if_needed( - &log, - request_detail.as_ref(), - meta, - provider, - upstream_id, - inbound_path, - StatusCode::BAD_GATEWAY, - format!("Upstream request failed: {message}"), - start_time, - ); - AttemptOutcome::Fatal(http::error_response( - StatusCode::BAD_GATEWAY, - format!("Upstream request failed: {message}"), - )) - } - } -} - -pub(super) fn log_upstream_error_if_needed( - log: &Arc, - request_detail: Option<&RequestDetailSnapshot>, - meta: &RequestMeta, - provider: &str, - upstream_id: &str, - inbound_path: &str, - status: StatusCode, - response_error: String, - start_time: Instant, -) { - let (request_headers, request_body) = request_detail - .map(|detail| (detail.request_headers.clone(), detail.request_body.clone())) - .unwrap_or((None, None)); - let context = LogContext { - path: inbound_path.to_string(), - provider: provider.to_string(), - upstream_id: upstream_id.to_string(), - model: meta.original_model.clone(), - mapped_model: 
meta.mapped_model.clone(), - stream: meta.stream, - status: status.as_u16(), - upstream_request_id: None, - request_headers, - request_body, - ttfb_ms: None, - start: start_time, - }; - let usage = UsageSnapshot { - usage: None, - cached_tokens: None, - usage_json: None, - }; - let entry = build_log_entry(&context, usage, Some(response_error)); - log.clone().write_detached(entry); -} diff --git a/src-tauri/src/proxy/upstream/utils.rs b/src-tauri/src/proxy/upstream/utils.rs deleted file mode 100644 index 728b25d..0000000 --- a/src-tauri/src/proxy/upstream/utils.rs +++ /dev/null @@ -1,88 +0,0 @@ -use axum::http::StatusCode; -use std::sync::atomic::Ordering; - -use super::super::{config::UpstreamStrategy, ProxyState}; -use crate::proxy::redact::redact_query_param_value; - -pub(super) fn extract_query_param(path_with_query: &str, name: &str) -> Option { - let url = url::Url::parse(&format!("http://localhost{path_with_query}")).ok()?; - url.query_pairs() - .find(|(key, _)| key == name) - .map(|(_, value)| value.into_owned()) -} - -pub(super) fn ensure_query_param(url: &str, name: &str, value: &str) -> Result { - let mut parsed = url::Url::parse(url).map_err(|err| err.to_string())?; - let pairs: Vec<(String, String)> = parsed - .query_pairs() - .map(|(key, value)| (key.into_owned(), value.into_owned())) - .collect(); - - { - let mut writer = parsed.query_pairs_mut(); - writer.clear(); - for (key, existing) in pairs { - if key == name { - continue; - } - writer.append_pair(&key, &existing); - } - writer.append_pair(name, value); - } - - Ok(parsed.to_string()) -} - -pub(super) fn sanitize_upstream_error(provider: &str, err: &reqwest::Error) -> String { - let message = err.to_string(); - if provider == "gemini" { - return redact_query_param_value(&message, super::GEMINI_API_KEY_QUERY); - } - message -} - -pub(super) fn resolve_group_start( - state: &ProxyState, - provider: &str, - group_index: usize, - group_len: usize, -) -> usize { - match state.config.upstream_strategy { - UpstreamStrategy::PriorityFillFirst => 0, - UpstreamStrategy::PriorityRoundRobin => state - .cursors - .get(provider) - .and_then(|cursors| cursors.get(group_index)) - .map(|cursor| cursor.fetch_add(1, Ordering::Relaxed) % group_len) - .unwrap_or(0), - } -} - -pub(super) fn build_group_order(group_len: usize, start: usize) -> Vec { - (0..group_len) - .map(|offset| (start + offset) % group_len) - .collect() -} - -pub(super) fn is_retryable_error(err: &reqwest::Error) -> bool { - err.is_timeout() || err.is_connect() -} - -pub(super) fn is_retryable_status(status: StatusCode) -> bool { - // 基于 new-api 的重试策略:400/429/307/5xx(排除 504/524);额外允许 403 触发 fallback。 - if status == StatusCode::BAD_REQUEST - || status == StatusCode::FORBIDDEN - || status == StatusCode::TOO_MANY_REQUESTS - || status == StatusCode::TEMPORARY_REDIRECT - { - return true; - } - if status == StatusCode::GATEWAY_TIMEOUT { - return false; - } - if status.as_u16() == 524 { - // Cloudflare timeout. 
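// Worked example, not from the original file: under this policy 400, 403,
// 429, 307 and generic 5xx responses (e.g. 500, 502) are retryable, while
// 504 and Cloudflare's 524 are treated as final timeouts:
//
//     assert!(is_retryable_status(StatusCode::BAD_REQUEST));             // 400
//     assert!(is_retryable_status(StatusCode::BAD_GATEWAY));             // 502
//     assert!(!is_retryable_status(StatusCode::GATEWAY_TIMEOUT));        // 504
//     assert!(!is_retryable_status(StatusCode::from_u16(524).unwrap())); // 524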
-        return false;
-    }
-    status.is_server_error()
-}
diff --git a/src-tauri/src/proxy/usage.rs b/src-tauri/src/proxy/usage.rs
deleted file mode 100644
index 2b4618b..0000000
--- a/src-tauri/src/proxy/usage.rs
+++ /dev/null
@@ -1,189 +0,0 @@
-use axum::body::Bytes;
-use serde_json::Value;
-
-use super::sse::SseEventParser;
-use super::log::{TokenUsage, UsageSnapshot};
-
-pub(crate) struct SseUsageCollector {
-    parser: SseEventParser,
-    snapshot: UsageSnapshot,
-}
-
-impl SseUsageCollector {
-    pub(crate) fn new() -> Self {
-        Self {
-            parser: SseEventParser::new(),
-            snapshot: UsageSnapshot {
-                usage: None,
-                cached_tokens: None,
-                usage_json: None,
-            },
-        }
-    }
-
-    pub(crate) fn push_chunk(&mut self, chunk: &[u8]) {
-        let snapshot = &mut self.snapshot;
-        self.parser
-            .push_chunk(chunk, |data| update_usage(snapshot, &data));
-    }
-
-    pub(crate) fn finish(&mut self) -> UsageSnapshot {
-        let snapshot = &mut self.snapshot;
-        self.parser.finish(|data| update_usage(snapshot, &data));
-        self.snapshot.clone()
-    }
-}
-
-pub(crate) fn extract_usage_from_response(bytes: &Bytes) -> UsageSnapshot {
-    let Ok(value) = serde_json::from_slice::<Value>(bytes) else {
-        return UsageSnapshot {
-            usage: None,
-            cached_tokens: None,
-            usage_json: None,
-        };
-    };
-
-    if let Some(usage) = value.get("usage") {
-        return snapshot_from_usage_value(usage);
-    }
-
-    value
-        .get("usageMetadata")
-        .map(snapshot_from_usage_metadata_value)
-        .unwrap_or(UsageSnapshot {
-            usage: None,
-            cached_tokens: None,
-            usage_json: None,
-        })
-}
-
-fn extract_usage_from_event(value: &Value) -> Option<UsageSnapshot> {
-    if let Some(usage) = value.get("usage") {
-        return Some(snapshot_from_usage_value(usage));
-    }
-
-    if let Some(usage) = value.get("message").and_then(|message| message.get("usage")) {
-        return Some(snapshot_from_usage_value(usage));
-    }
-
-    if let Some(usage) = value.get("response").and_then(|response| response.get("usage")) {
-        return Some(snapshot_from_usage_value(usage));
-    }
-
-    if let Some(metadata) = value.get("usageMetadata") {
-        return Some(snapshot_from_usage_metadata_value(metadata));
-    }
-
-    value
-        .get("response")
-        .and_then(|response| response.get("usageMetadata"))
-        .map(snapshot_from_usage_metadata_value)
-}
-
-fn usage_from_value(value: &Value) -> Option<TokenUsage> {
-    // Normalize both OpenAI Responses usage (`input_tokens`/`output_tokens`) and
-    // Chat Completions usage (`prompt_tokens`/`completion_tokens`) into a single shape.
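// Example payloads, not from the original file: both shapes normalize to the
// same `TokenUsage { input_tokens: Some(1), output_tokens: Some(2), total_tokens: Some(3) }`:
//
//     {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}      // Responses API
//     {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3} // Chat Completions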
-    let input_tokens = value
-        .get("input_tokens")
-        .and_then(Value::as_u64)
-        .or_else(|| value.get("prompt_tokens").and_then(Value::as_u64));
-    let output_tokens = value
-        .get("output_tokens")
-        .and_then(Value::as_u64)
-        .or_else(|| value.get("completion_tokens").and_then(Value::as_u64));
-    let total_tokens = value.get("total_tokens").and_then(Value::as_u64);
-    if input_tokens.is_some() || output_tokens.is_some() || total_tokens.is_some() {
-        return Some(TokenUsage {
-            input_tokens,
-            output_tokens,
-            total_tokens,
-        });
-    }
-    None
-}
-
-fn gemini_usage_from_value(value: &Value) -> Option<TokenUsage> {
-    // The Gemini API reports `usageMetadata`: prompt/candidates/total token counts.
-    let input_tokens = value.get("promptTokenCount").and_then(Value::as_u64);
-    let output_tokens = value
-        .get("candidatesTokenCount")
-        .and_then(Value::as_u64);
-    let total_tokens = value.get("totalTokenCount").and_then(Value::as_u64);
-
-    if input_tokens.is_some() || output_tokens.is_some() || total_tokens.is_some() {
-        return Some(TokenUsage {
-            input_tokens,
-            output_tokens,
-            total_tokens,
-        });
-    }
-    None
-}
-
-fn snapshot_from_usage_value(value: &Value) -> UsageSnapshot {
-    UsageSnapshot {
-        usage: usage_from_value(value),
-        cached_tokens: cached_tokens_from_usage_value(value),
-        usage_json: Some(value.clone()),
-    }
-}
-
-fn snapshot_from_usage_metadata_value(value: &Value) -> UsageSnapshot {
-    UsageSnapshot {
-        usage: gemini_usage_from_value(value),
-        cached_tokens: None,
-        usage_json: Some(value.clone()),
-    }
-}
-
-fn cached_tokens_from_usage_value(value: &Value) -> Option<u64> {
-    let cache_read = value.get("cache_read_input_tokens").and_then(Value::as_u64);
-    let cache_creation = value
-        .get("cache_creation_input_tokens")
-        .and_then(Value::as_u64);
-    if cache_read.is_some() || cache_creation.is_some() {
-        return match (cache_read, cache_creation) {
-            (Some(left), Some(right)) => left.checked_add(right),
-            (Some(left), None) => Some(left),
-            (None, Some(right)) => Some(right),
-            (None, None) => None,
-        };
-    }
-
-    value
-        .get("input_tokens_details")
-        .and_then(|details| details.get("cached_tokens"))
-        .and_then(Value::as_u64)
-        .or_else(|| {
-            value
-                .get("prompt_tokens_details")
-                .and_then(|details| details.get("cached_tokens"))
-                .and_then(Value::as_u64)
-        })
-        .or_else(|| value.get("cached_tokens").and_then(Value::as_u64))
-}
-
-fn update_usage(snapshot: &mut UsageSnapshot, data: &str) {
-    if data == "[DONE]" {
-        return;
-    }
-    let Ok(value) = serde_json::from_str::<Value>(data) else {
-        return;
-    };
-    let Some(updated) = extract_usage_from_event(&value) else {
-        return;
-    };
-    if updated.usage_json.is_some() {
-        snapshot.usage_json = updated.usage_json;
-        snapshot.usage = updated.usage;
-        if updated.cached_tokens.is_some() {
-            // Preserve earlier cache stats when later events omit cache fields.
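// Example, not from the original file: an Anthropic SSE stream typically
// reports cache stats only once, on `message_start`, while later
// `message_delta` events omit them:
//
//     data: {"type":"message_start","message":{"usage":{"input_tokens":1,"cache_read_input_tokens":4}}}
//     data: {"type":"message_delta","usage":{"output_tokens":2}}
//
// Overwriting unconditionally would drop `cached_tokens: Some(4)` here, so
// the field is replaced only when the newer event actually provides one.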
- snapshot.cached_tokens = updated.cached_tokens; - } - } -} - -// 单元测试拆到独立文件,使用 `#[path]` 以保持 `.test.rs` 命名约定。 -#[cfg(test)] -#[path = "usage.test.rs"] -mod tests; diff --git a/src-tauri/src/proxy/usage.test.rs b/src-tauri/src/proxy/usage.test.rs deleted file mode 100644 index 8da87b7..0000000 --- a/src-tauri/src/proxy/usage.test.rs +++ /dev/null @@ -1,72 +0,0 @@ -use super::*; -use serde_json::json; - -#[test] -fn extract_usage_from_gemini_usage_metadata() { - let bytes = Bytes::from_static( - br#"{"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":2,"totalTokenCount":3}}"#, - ); - let usage = extract_usage_from_response(&bytes).usage.expect("usage"); - assert_eq!(usage.input_tokens, Some(1)); - assert_eq!(usage.output_tokens, Some(2)); - assert_eq!(usage.total_tokens, Some(3)); -} - -#[test] -fn sse_usage_collector_extracts_gemini_usage_metadata() { - let mut collector = SseUsageCollector::new(); - collector.push_chunk( - b"data: {\"usageMetadata\":{\"promptTokenCount\":1,\"candidatesTokenCount\":2,\"totalTokenCount\":3}}\n\n", - ); - let usage = collector.finish().usage.expect("usage"); - assert_eq!(usage.input_tokens, Some(1)); - assert_eq!(usage.output_tokens, Some(2)); - assert_eq!(usage.total_tokens, Some(3)); -} - -#[test] -fn extract_cached_tokens_from_openai_input_tokens_details() { - let bytes = Bytes::from_static( - br#"{"usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3,"input_tokens_details":{"cached_tokens":4}}}"#, - ); - let snapshot = extract_usage_from_response(&bytes); - assert_eq!(snapshot.cached_tokens, Some(4)); - assert_eq!(snapshot.usage_json.expect("usage_json")["input_tokens"], json!(1)); -} - -#[test] -fn extract_cached_tokens_from_openai_prompt_tokens_details() { - let bytes = Bytes::from_static( - br#"{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4}}}"#, - ); - let snapshot = extract_usage_from_response(&bytes); - assert_eq!(snapshot.cached_tokens, Some(4)); - assert_eq!(snapshot.usage_json.expect("usage_json")["prompt_tokens"], json!(1)); -} - -#[test] -fn extract_cached_tokens_from_anthropic_cache_fields() { - let bytes = Bytes::from_static( - br#"{"usage":{"input_tokens":1,"output_tokens":2,"cache_read_input_tokens":4,"cache_creation_input_tokens":5}}"#, - ); - let snapshot = extract_usage_from_response(&bytes); - assert_eq!(snapshot.cached_tokens, Some(9)); - assert_eq!( - snapshot.usage_json.expect("usage_json")["cache_read_input_tokens"], - json!(4) - ); -} - -#[test] -fn sse_usage_collector_extracts_anthropic_message_usage_and_cache_tokens() { - let mut collector = SseUsageCollector::new(); - collector.push_chunk( - b"data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"cache_read_input_tokens\":4,\"cache_creation_input_tokens\":5}}}\n\n", - ); - let snapshot = collector.finish(); - assert_eq!(snapshot.cached_tokens, Some(9)); - let usage = snapshot.usage.expect("usage"); - assert_eq!(usage.input_tokens, Some(1)); - assert_eq!(usage.output_tokens, Some(2)); - assert_eq!(usage.total_tokens, None); -} From 49f4031c20ddfd303f2816f0807b4dfd059c99c5 Mon Sep 17 00:00:00 2001 From: mxyhi Date: Fri, 30 Jan 2026 20:11:40 +0800 Subject: [PATCH 07/10] docs: normalize project name casing for consistency --- AGENTS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AGENTS.md b/AGENTS.md index c5cb01a..1462a83 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -13,7 +13,7 @@ ## 参考项目 - 代理转发/转换参考[new-api](.reference/new-api) -- 
kiro、codex、Antigravity等2api参考[CLIProxyAPIPlus](.reference/CLIProxyAPIPlus) +- kiro、codex、antigravity等2api参考[CLIProxyAPIPlus](.reference/CLIProxyAPIPlus) - CLIProxyAPIPlus的可视化app参考[quotio](.reference/quotio) --- From 799f94a11f6adb3aba7e5ff3d792bb4f42a68757 Mon Sep 17 00:00:00 2001 From: mxyhi Date: Sat, 31 Jan 2026 00:21:05 +0800 Subject: [PATCH 08/10] chore: reduce noisy tray title logging --- src-tauri/src/tray.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src-tauri/src/tray.rs b/src-tauri/src/tray.rs index 6bf762c..27666e0 100644 --- a/src-tauri/src/tray.rs +++ b/src-tauri/src/tray.rs @@ -140,32 +140,27 @@ impl TrayState { .clone() }; if !config.enabled { - tracing::debug!("tray update_token_rate_title disabled -> clear"); self.clear_title(); return; } // 启用后始终显示速率;无 token 时展示并发请求数。 let snapshot = self.inner.token_rate.snapshot().await; let title = format_rate_title(snapshot, config.format); - tracing::debug!(title = %title, "tray update_token_rate_title set"); self.set_title(Some(title)); } #[cfg(target_os = "macos")] fn set_title(&self, title: Option) { - tracing::debug!(title = ?title, "tray set_title called"); let mut last_title = self .inner .last_title .write() .expect("tray title lock poisoned"); if *last_title == title { - tracing::debug!(title = ?title, "tray set_title skipped (unchanged)"); return; } let _ = self.inner.tray.set_title(title.as_deref()); *last_title = title; - tracing::debug!(title = ?*last_title, "tray set_title applied"); } #[cfg(target_os = "macos")] From 7e1575a24c2672f1a195249fd497f90d2c1c0435 Mon Sep 17 00:00:00 2001 From: mxyhi Date: Sat, 31 Jan 2026 14:30:01 +0800 Subject: [PATCH 09/10] fix: mirror CLIProxyAPIPlus behavior to avoid upstream rejects --- README.md | 4 +- README.zh-CN.md | 4 +- .../src/proxy/antigravity_compat.rs | 152 +++- .../src/proxy/antigravity_compat.test.rs | 166 ++++- .../proxy/antigravity_compat/gemini_fixups.rs | 195 ++++++ .../src/proxy/antigravity_schema.rs | 452 ++++-------- .../src/proxy/antigravity_schema/ops.rs | 330 +++++++++ .../src/proxy/config/normalize.rs | 11 +- .../src/proxy/openai_compat.rs | 1 + .../src/proxy/openai_compat.test.part1.rs | 512 ++++++++++++++ .../src/proxy/openai_compat.test.part2.rs | 154 +++++ .../src/proxy/openai_compat.test.rs | 651 +----------------- .../src/proxy/openai_compat/usage.rs | 66 +- .../src/proxy/response.test.part2.rs | 125 ++++ .../src/proxy/response.test.rs | 38 +- .../src/proxy/response/chat_to_responses.rs | 25 +- .../response/chat_to_responses/state_types.rs | 16 + .../src/proxy/response/dispatch/buffered.rs | 19 +- .../proxy/response/responses_to_anthropic.rs | 5 + .../src/proxy/server_helpers.rs | 9 + .../src/proxy/token_estimator.rs | 47 +- .../src/proxy/token_estimator.test.rs | 32 + .../src/proxy/upstream/attempt.rs | 10 +- .../src/proxy/upstream/request.rs | 14 +- src/features/config/inbound-formats.ts | 4 +- 25 files changed, 1986 insertions(+), 1056 deletions(-) create mode 100644 crates/token_proxy_core/src/proxy/antigravity_compat/gemini_fixups.rs create mode 100644 crates/token_proxy_core/src/proxy/antigravity_schema/ops.rs create mode 100644 crates/token_proxy_core/src/proxy/openai_compat.test.part1.rs create mode 100644 crates/token_proxy_core/src/proxy/openai_compat.test.part2.rs create mode 100644 crates/token_proxy_core/src/proxy/response.test.part2.rs create mode 100644 crates/token_proxy_core/src/proxy/response/chat_to_responses/state_types.rs diff --git a/README.md b/README.md index d8bfb28..c7a1890 100644 --- a/README.md +++ b/README.md @@ -125,7 
+125,9 @@ Notes: - Cross-format fallback/conversion is controlled by `upstreams[].convert_from_map` (no global switch). If a provider has no eligible upstream for the inbound format, it won't be selected. - If `openai` is missing for `/v1/chat/completions`: fallback can be `openai-response`, `anthropic`, or `gemini` (priority-based; tie-break prefers `openai-response`). - For `/v1/messages`: choose between `anthropic` and `kiro` by priority; tie-break uses upstream id. If the chosen provider returns a retryable error, the proxy will fall back to the other native provider (Anthropic ↔ Kiro) when configured. -- If neither `anthropic` nor `kiro` exists for `/v1/messages`: fallback can be `openai-response`, `openai`, or `gemini` when the target provider is allowed for `anthropic_messages` via `convert_from_map`. +- If neither `anthropic` nor `kiro` exists for `/v1/messages`: + - `antigravity` is supported by default (no `convert_from_map` needed; aligned with CLIProxyAPIPlus Antigravity/Claude Code behavior). + - Other providers can be selected only when allowed for `anthropic_messages` via `convert_from_map` (e.g. `openai-response`, `openai`, `gemini`). - If `openai-response` is missing for `/v1/responses`: fallback can be `openai`, `anthropic`, or `gemini` (priority-based; tie-break prefers `openai`). - If `gemini` is missing for `/v1beta/models/*:generateContent`: fallback can be `openai-response`, `openai`, or `anthropic` (priority-based; tie-break prefers `openai-response`). diff --git a/README.zh-CN.md b/README.zh-CN.md index 787cf61..0cdbe6c 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -125,7 +125,9 @@ pnpm exec tsc --noEmit - 跨格式 fallback/转换由 `upstreams[].convert_from_map` 控制(不再有全局开关);若某个 provider 在该入站格式下没有任何可用 upstream,则不会被选中。 - `/v1/chat/completions` 缺少 `openai`:可 fallback 到 `openai-response` / `anthropic` / `gemini`(按优先级选择,平级优先 `openai-response`) - `/v1/messages`:在 `anthropic` 与 `kiro` 间按优先级选择;平级按 upstream id 排序。若命中 provider 返回“可重试错误”,且另一个 native provider 已配置,则会自动 fallback(Anthropic ↔ Kiro) -- 当 `/v1/messages` 缺少 `anthropic` 且 `kiro` 也不存在时:若目标 provider 在 `convert_from_map` 中允许 `anthropic_messages`,则可 fallback 到 `openai-response` / `openai` / `gemini`(按优先级选择,平级优先 `openai-response`) +- 当 `/v1/messages` 缺少 `anthropic` 且 `kiro` 也不存在时: + - `antigravity`:默认支持(无需 `convert_from_map`,对齐 CLIProxyAPIPlus 的 Antigravity/Claude Code 体验) + - 其它 provider:若在 `convert_from_map` 中允许 `anthropic_messages`,则可 fallback 到 `openai-response` / `openai` / `gemini`(按优先级选择,平级优先 `openai-response`) - `/v1/responses` 缺少 `openai-response`:可 fallback 到 `openai` / `anthropic` / `gemini`(按优先级选择,平级优先 `openai`) - `/v1beta/models/*:generateContent` 缺少 `gemini`:可 fallback 到 `openai-response` / `openai` / `anthropic`(按优先级选择,平级优先 `openai-response`) diff --git a/crates/token_proxy_core/src/proxy/antigravity_compat.rs b/crates/token_proxy_core/src/proxy/antigravity_compat.rs index 22c913a..6c42e3e 100644 --- a/crates/token_proxy_core/src/proxy/antigravity_compat.rs +++ b/crates/token_proxy_core/src/proxy/antigravity_compat.rs @@ -4,12 +4,12 @@ use serde_json::{json, Map, Value}; use std::collections::VecDeque; use sha2::{Digest, Sha256}; -use crate::oauth_util::generate_state; -use crate::proxy::antigravity_schema::clean_json_schema_for_antigravity; +use crate::proxy::antigravity_schema::{clean_json_schema_for_antigravity, clean_json_schema_for_gemini}; use crate::proxy::sse::SseEventParser; mod signature_cache; mod claude; +mod gemini_fixups; pub(crate) use claude::claude_request_to_antigravity; @@ -32,22 +32,28 @@ 
pub(crate) fn wrap_gemini_request( let model = map_antigravity_model(&extract_model(&mut request, model_hint)); let model_lower = model.to_lowercase(); - let should_clean_tool_schema = + let use_antigravity_schema_cleaner = model_lower.contains("claude") || model_lower.contains("gemini-3-pro-high"); + // Align with CLIProxyAPIPlus: fix CLI tool response grouping and normalize roles + // before applying Antigravity-specific wrappers and schema transforms. + gemini_fixups::fix_cli_tool_response(&mut request); + gemini_fixups::normalize_contents_roles(&mut request); normalize_system_instruction(&mut request); - normalize_tool_schema(&mut request, should_clean_tool_schema); + normalize_tool_schema(&mut request, use_antigravity_schema_cleaner); ensure_system_instruction(&mut request, &model); remove_safety_settings(&mut request); ensure_tool_thought_signature(&mut request); + ensure_thinking_signature_for_gemini(&mut request, &model_lower); ensure_session_id(&mut request); + ensure_tool_config_mode(&mut request, &model_lower); trim_generation_config(&mut request, &model); let project = project_id .map(str::trim) .filter(|value| !value.is_empty()) .map(|value| value.to_string()) - .unwrap_or_else(|| generate_project_id().unwrap_or_default()); - let request_id = generate_agent_id("agent").unwrap_or_else(|_| "agent-unknown".to_string()); + .unwrap_or_else(generate_project_id); + let request_id = generate_request_id(); let mut root = Map::new(); root.insert("project".to_string(), Value::String(project)); @@ -284,9 +290,11 @@ fn ensure_system_instruction(request: &mut Map, model: &str) { } fn normalize_tool_schema(request: &mut Map, enabled: bool) { - if !enabled { - return; - } + // Align with CLIProxyAPIPlus: + // - Always rename `parametersJsonSchema` -> `parameters` for Antigravity upstream. + // - Use different schema cleaners based on model family: + // - Claude / gemini-3-pro-high: Antigravity cleaner (+ placeholders) + // - Others: Gemini cleaner (no placeholders) let Some(tools) = request.get_mut("tools").and_then(Value::as_array_mut) else { return; }; @@ -309,7 +317,11 @@ fn normalize_tool_schema(request: &mut Map, enabled: bool) { params.remove("$schema"); } if let Some(params) = decl.get_mut("parameters") { - clean_json_schema_for_antigravity(params); + if enabled { + clean_json_schema_for_antigravity(params); + } else { + clean_json_schema_for_gemini(params); + } } } } @@ -328,7 +340,11 @@ fn normalize_tool_schema(request: &mut Map, enabled: bool) { params.remove("$schema"); } if let Some(params) = decl.get_mut("parameters") { - clean_json_schema_for_antigravity(params); + if enabled { + clean_json_schema_for_antigravity(params); + } else { + clean_json_schema_for_gemini(params); + } } } } @@ -354,29 +370,65 @@ fn ensure_tool_thought_signature(request: &mut Map) { let Some(obj) = part.as_object_mut() else { continue; }; - if !(obj.contains_key("functionCall") || obj.contains_key("functionResponse")) { + // Align with CLIProxyAPIPlus: only functionCall requires thoughtSignature sentinels. 
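+        // (Previously functionResponse parts were stamped as well; CLIProxyAPIPlus leaves
+        // responses untouched and only keeps call signatures that already look real, i.e.
+        // at least 50 chars long — see the length check below.)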
+        if !obj.contains_key("functionCall") {
             continue;
         }
-        obj.entry("thoughtSignature".to_string())
-            .or_insert_with(|| Value::String(THOUGHT_SIGNATURE_SENTINEL.to_string()));
+        let set_sentinel = match obj.get("thoughtSignature").and_then(Value::as_str) {
+            Some(value) if value.len() >= 50 => false,
+            _ => true,
+        };
+        if set_sentinel {
+            obj.insert(
+                "thoughtSignature".to_string(),
+                Value::String(THOUGHT_SIGNATURE_SENTINEL.to_string()),
+            );
+        }
         }
     }
 }
 
-fn ensure_session_id(request: &mut Map<String, Value>) {
-    let session_present = request
-        .get("sessionId")
-        .and_then(Value::as_str)
-        .map(|value| !value.trim().is_empty())
-        .unwrap_or(false);
-    if session_present {
+fn ensure_thinking_signature_for_gemini(request: &mut Map<String, Value>, model_lower: &str) {
+    // Align with CLIProxyAPIPlus gemini->antigravity behavior:
+    // Gemini (non-Claude) models may produce thinking blocks without signatures; mark them
+    // with the skip-sentinel so upstream bypasses signature validation.
+    if model_lower.contains("claude") {
         return;
     }
+    let Some(contents) = request.get_mut("contents").and_then(Value::as_array_mut) else {
+        return;
+    };
+    for content in contents {
+        if content.get("role").and_then(Value::as_str) != Some("model") {
+            continue;
+        }
+        let Some(parts) = content.get_mut("parts").and_then(Value::as_array_mut) else {
+            continue;
+        };
+        for part in parts {
+            let Some(obj) = part.as_object_mut() else {
+                continue;
+            };
+            if obj.get("thought").and_then(Value::as_bool) != Some(true) {
+                continue;
+            }
+            // Align with CLIProxyAPIPlus: always force skip sentinel on Gemini thinking blocks.
+            obj.insert(
+                "thoughtSignature".to_string(),
+                Value::String(THOUGHT_SIGNATURE_SENTINEL.to_string()),
+            );
+        }
+    }
+}
+
+fn ensure_session_id(request: &mut Map<String, Value>) {
+    // Align with CLIProxyAPIPlus: always overwrite sessionId with a stable dash-decimal id
+    // (some clients send UUID-like values which Antigravity rejects).
     if let Some(session_id) = stable_session_id_from_contents(request) {
         request.insert("sessionId".to_string(), Value::String(session_id));
         return;
     }
-    let session_id = generate_agent_id("sess").unwrap_or_else(|_| "sess-unknown".to_string());
+    let session_id = generate_session_id();
     request.insert("sessionId".to_string(), Value::String(session_id));
 }
 
@@ -390,12 +442,12 @@ fn stable_session_id_from_contents(request: &Map<String, Value>) -> Option<String>
 fn trim_generation_config(request: &mut Map<String, Value>, model: &str) {
     gen.remove("maxOutputTokens");
 }
 
-fn generate_agent_id(prefix: &str) -> Result<String> {
-    let state = generate_state(prefix)?;
-    Ok(state)
+fn ensure_tool_config_mode(request: &mut Map<String, Value>, model_lower: &str) {
+    // CLIProxyAPIPlus forces VALIDATED mode for Claude in Antigravity.
+    if !model_lower.contains("claude") {
+        return;
+    }
+    let tool_config = request
+        .entry("toolConfig".to_string())
+        .or_insert_with(|| Value::Object(Map::new()));
+    let Some(tool_config) = tool_config.as_object_mut() else {
+        return;
+    };
+    let calling = tool_config
+        .entry("functionCallingConfig".to_string())
+        .or_insert_with(|| Value::Object(Map::new()));
+    let Some(calling) = calling.as_object_mut() else {
+        return;
+    };
+    calling.insert(
+        "mode".to_string(),
+        Value::String("VALIDATED".to_string()),
+    );
+}
+
+fn generate_request_id() -> String {
+    // Align with CLIProxyAPIPlus: "agent-" + UUID.
+    format!("agent-{}", crate::proxy::kiro::utils::random_uuid())
+}
+
+fn generate_session_id() -> String {
+    // Align with CLIProxyAPIPlus: "-" + a random decimal below 9e18 (legacy format).
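+    // e.g. "-4203381852186291741" (illustrative value); the dash is a literal id prefix,
+    // not a minus sign.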
+    let n = rand::random::<u64>() % 9_000_000_000_000_000_000u64;
+    format!("-{n}")
 }
 
-fn generate_project_id() -> Result<String> {
-    let state = generate_state("project")?;
-    Ok(state)
+fn generate_project_id() -> String {
+    // Align with CLIProxyAPIPlus generateProjectID():
+    // adjectives/nouns + "-" + first 5 chars of uuid.
+    const ADJECTIVES: [&str; 5] = ["useful", "bright", "swift", "calm", "bold"];
+    const NOUNS: [&str; 5] = ["fuze", "wave", "spark", "flow", "core"];
+
+    let adj = ADJECTIVES[(rand::random::<u32>() as usize) % ADJECTIVES.len()];
+    let noun = NOUNS[(rand::random::<u32>() as usize) % NOUNS.len()];
+    let uuid = crate::proxy::kiro::utils::random_uuid();
+    let random_part = uuid.replace('-', "");
+    let random_part = random_part.chars().take(5).collect::<String>().to_ascii_lowercase();
+    format!("{adj}-{noun}-{random_part}")
 }
diff --git a/crates/token_proxy_core/src/proxy/antigravity_compat.test.rs b/crates/token_proxy_core/src/proxy/antigravity_compat.test.rs
index 170ec89..410cf8e 100644
--- a/crates/token_proxy_core/src/proxy/antigravity_compat.test.rs
+++ b/crates/token_proxy_core/src/proxy/antigravity_compat.test.rs
@@ -166,6 +166,170 @@ fn keeps_existing_tool_config_mode() {
     let value: serde_json::Value = serde_json::from_slice(&wrapped).expect("wrapped json");
     assert_eq!(
         value["request"]["toolConfig"]["functionCallingConfig"]["mode"].as_str(),
-        Some("ANY")
+        Some("VALIDATED")
     );
 }
+
+#[test]
+fn gemini_schema_cleaner_renames_and_does_not_add_placeholders() {
+    let request = json!({
+        "model": "gemini-2.5-pro",
+        "contents": [
+            { "role": "user", "parts": [{ "text": "hello" }] }
+        ],
+        "tools": [
+            {
+                "function_declarations": [
+                    {
+                        "name": "t",
+                        "parametersJsonSchema": {
+                            "type": "object",
+                            "properties": {}
+                        }
+                    }
+                ]
+            }
+        ]
+    });
+    let bytes = Bytes::from(request.to_string());
+    let wrapped = wrap_gemini_request(&bytes, None, None, "ua").expect("wrap ok");
+    let value: serde_json::Value = serde_json::from_slice(&wrapped).expect("wrapped json");
+    let schema = &value["request"]["tools"][0]["function_declarations"][0]["parameters"];
+    assert!(schema.get("properties").and_then(|v| v.get("reason")).is_none());
+    assert!(schema.get("required").is_none());
+}
+
+#[test]
+fn session_id_fallback_is_dash_decimal() {
+    let request = json!({
+        "model": "gemini-2.5-pro",
+        "contents": []
+    });
+    let bytes = Bytes::from(request.to_string());
+    let wrapped = wrap_gemini_request(&bytes, None, None, "ua").expect("wrap ok");
+    let value: serde_json::Value = serde_json::from_slice(&wrapped).expect("wrapped json");
+    let session = value["request"]["sessionId"].as_str().expect("sessionId");
+    assert!(session.starts_with('-'));
+    assert!(session[1..].chars().all(|ch| ch.is_ascii_digit()));
+}
+
+#[test]
+fn request_id_and_project_match_reference_shapes() {
+    let request = json!({
+        "model": "gemini-2.5-pro",
+        "contents": [
+            { "role": "user", "parts": [{ "text": "hello" }] }
+        ]
+    });
+    let bytes = Bytes::from(request.to_string());
+    let wrapped = wrap_gemini_request(&bytes, None, None, "ua").expect("wrap ok");
+    let value: serde_json::Value = serde_json::from_slice(&wrapped).expect("wrapped json");
+
+    let request_id = value["requestId"].as_str().expect("requestId");
+    assert!(request_id.starts_with("agent-"));
+    assert!(is_uuid_like(&request_id["agent-".len()..]));
+
+    let project = value["project"].as_str().expect("project");
+    let parts: Vec<&str> = project.split('-').collect();
+    assert_eq!(parts.len(), 3);
+    assert!(matches!(parts[0], "useful" | "bright" | "swift" | "calm" | "bold"));
+    assert!(matches!(parts[1], "fuze" | "wave" | "spark" | "flow" | "core"));
+    assert_eq!(parts[2].len(), 5);
+    assert!(parts[2].chars().all(|ch| ch.is_ascii_hexdigit()));
+}
+
+#[test]
+fn overwrites_session_id_even_when_provided() {
+    let request = json!({
+        "model": "gemini-2.5-pro",
+        "sessionId": "not-stable",
+        "contents": [
+            { "role": "user", "parts": [{ "text": "hello" }] }
+        ]
+    });
+    let bytes = Bytes::from(request.to_string());
+    let wrapped = wrap_gemini_request(&bytes, None, None, "ua").expect("wrap ok");
+    let value: serde_json::Value = serde_json::from_slice(&wrapped).expect("wrapped json");
+    let session = value["request"]["sessionId"].as_str().expect("sessionId");
+    assert_ne!(session, "not-stable");
+    assert!(session.starts_with('-'));
+    assert!(session[1..].chars().all(|ch| ch.is_ascii_digit()));
+}
+
+#[test]
+fn merges_function_responses_for_parallel_calls_and_does_not_set_response_signatures() {
+    let request = json!({
+        "model": "gemini-2.5-pro",
+        "contents": [
+            {
+                "role": "model",
+                "parts": [
+                    { "functionCall": { "name": "tool_one", "args": { "a": "1" } } },
+                    { "functionCall": { "name": "tool_two", "args": { "b": "2" } } }
+                ]
+            },
+            {
+                "role": "user",
+                "parts": [
+                    { "functionResponse": { "name": "tool_one", "response": { "result": "ok1" } } }
+                ]
+            },
+            {
+                "role": "user",
+                "parts": [
+                    { "functionResponse": { "name": "tool_two", "response": { "result": "ok2" } } }
+                ]
+            }
+        ]
+    });
+    let bytes = Bytes::from(request.to_string());
+    let wrapped = wrap_gemini_request(&bytes, None, None, "ua").expect("wrap ok");
+    let value: serde_json::Value = serde_json::from_slice(&wrapped).expect("wrapped json");
+    let contents = value["request"]["contents"].as_array().expect("contents array");
+    assert_eq!(contents.len(), 2);
+    assert_eq!(contents[0]["role"].as_str(), Some("model"));
+    assert_eq!(contents[1]["role"].as_str(), Some("user"));
+
+    let merged_parts = contents[1]["parts"].as_array().expect("parts array");
+    assert_eq!(merged_parts.len(), 2);
+    for part in merged_parts {
+        assert!(part.get("functionResponse").is_some());
+        assert!(part.get("thoughtSignature").is_none());
+    }
+}
+
+#[test]
+fn normalizes_invalid_roles_in_contents() {
+    let request = json!({
+        "model": "gemini-2.5-pro",
+        "contents": [
+            { "role": "assistant", "parts": [{ "text": "a" }] },
+            { "role": "assistant", "parts": [{ "text": "b" }] }
+        ]
+    });
+    let bytes = Bytes::from(request.to_string());
+    let wrapped = wrap_gemini_request(&bytes, None, None, "ua").expect("wrap ok");
+    let value: serde_json::Value = serde_json::from_slice(&wrapped).expect("wrapped json");
+    let contents = value["request"]["contents"].as_array().expect("contents array");
+    assert_eq!(contents[0]["role"].as_str(), Some("user"));
+    assert_eq!(contents[1]["role"].as_str(), Some("model"));
+}
+
+fn is_uuid_like(value: &str) -> bool {
+    if value.len() != 36 {
+        return false;
+    }
+    let bytes: Vec<char> = value.chars().collect();
+    for &idx in &[8_usize, 13, 18, 23] {
+        if bytes.get(idx) != Some(&'-') {
+            return false;
+        }
+    }
+    value
+        .chars()
+        .enumerate()
+        .all(|(idx, ch)| match idx {
+            8 | 13 | 18 | 23 => ch == '-',
+            _ => ch.is_ascii_hexdigit(),
+        })
+}
diff --git a/crates/token_proxy_core/src/proxy/antigravity_compat/gemini_fixups.rs b/crates/token_proxy_core/src/proxy/antigravity_compat/gemini_fixups.rs
new file mode 100644
index 0000000..5935db6
--- /dev/null
+++ b/crates/token_proxy_core/src/proxy/antigravity_compat/gemini_fixups.rs
@@ -0,0 +1,195 @@
+use serde_json::{json, Map, Value};
+use std::collections::VecDeque;
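+
+// Illustrative example (not part of the original change): a model turn carrying two
+// functionCalls followed by two single-part functionResponse turns,
+//     model: [functionCall tool_one, functionCall tool_two]
+//     user:  [functionResponse tool_one]
+//     user:  [functionResponse tool_two]
+// is regrouped by the fixer below into one merged turn,
+//     model:    [functionCall tool_one, functionCall tool_two]
+//     function: [functionResponse tool_one, functionResponse tool_two]
+// and normalize_contents_roles() then rewrites the non-standard "function" role to "user".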
+
+struct CliToolResponseFixer {
+    out: Vec<Value>,
+    pending_groups: Vec<usize>,
+    collected_responses: VecDeque<Value>,
+}
+
+impl CliToolResponseFixer {
+    fn new(capacity: usize) -> Self {
+        Self {
+            out: Vec::with_capacity(capacity),
+            pending_groups: Vec::new(),
+            collected_responses: VecDeque::new(),
+        }
+    }
+
+    fn push_content(&mut self, content: Value) {
+        let Some(obj) = content.as_object() else {
+            return;
+        };
+
+        let parts = parts_slice(obj);
+        if let Some(response_parts) = response_only_parts(parts) {
+            self.push_responses_and_maybe_merge(response_parts);
+            return;
+        }
+
+        if is_model_content(obj) {
+            let function_call_count = count_function_calls(parts);
+            if function_call_count > 0 {
+                self.out.push(content);
+                self.pending_groups.push(function_call_count);
+                return;
+            }
+        }
+
+        self.out.push(content);
+    }
+
+    fn push_responses_and_maybe_merge(&mut self, response_parts: Vec<Value>) {
+        for part in response_parts {
+            self.collected_responses.push_back(part);
+        }
+        self.try_satisfy_latest_group();
+    }
+
+    fn try_satisfy_latest_group(&mut self) {
+        for idx in (0..self.pending_groups.len()).rev() {
+            let needed = self.pending_groups[idx];
+            if self.collected_responses.len() < needed {
+                continue;
+            }
+            let merged_parts = self.take_merged_parts(needed);
+            if !merged_parts.is_empty() {
+                self.out.push(json!({ "role": "function", "parts": merged_parts }));
+            }
+            self.pending_groups.remove(idx);
+            break;
+        }
+    }
+
+    fn take_merged_parts(&mut self, needed: usize) -> Vec<Value> {
+        let mut merged_parts = Vec::with_capacity(needed);
+        for _ in 0..needed {
+            let Some(next) = self.collected_responses.pop_front() else {
+                break;
+            };
+            if next.is_object() {
+                merged_parts.push(next);
+            } else {
+                merged_parts.push(fallback_function_response_part(&next));
+            }
+        }
+        merged_parts
+    }
+
+    fn flush_remaining(mut self) -> Vec<Value> {
+        for needed in std::mem::take(&mut self.pending_groups) {
+            if self.collected_responses.len() < needed {
+                break;
+            }
+            let merged_parts = self.take_merged_parts(needed);
+            if !merged_parts.is_empty() {
+                self.out.push(json!({ "role": "function", "parts": merged_parts }));
+            }
+        }
+        self.out
+    }
+}
+
+/// Align with CLIProxyAPIPlus `fixCLIToolResponse()`:
+/// - Collect standalone `functionResponse` contents.
+/// - For each preceding `model` content that contains N `functionCall` parts,
+///   merge the next N `functionResponse` parts into a single content entry.
+///
+/// NOTE: This intentionally assumes that a "tool response content" contains ONLY
+/// `functionResponse` parts; if mixed parts exist, we keep the content unchanged.
+pub(super) fn fix_cli_tool_response(request: &mut Map<String, Value>) {
+    let Some(contents) = request.get_mut("contents").and_then(Value::as_array_mut) else {
+        return;
+    };
+
+    let original = std::mem::take(contents);
+    let mut fixer = CliToolResponseFixer::new(original.len());
+    for content in original {
+        fixer.push_content(content);
+    }
+    *contents = fixer.flush_remaining();
+}
+
+pub(super) fn normalize_contents_roles(request: &mut Map<String, Value>) {
+    let Some(contents) = request.get_mut("contents").and_then(Value::as_array_mut) else {
+        return;
+    };
+    let mut prev_role = String::new();
+    for content in contents {
+        let Some(obj) = content.as_object_mut() else {
+            continue;
+        };
+        let role = obj
+            .get("role")
+            .and_then(Value::as_str)
+            .unwrap_or("")
+            .to_string();
+        let valid = role == "user" || role == "model";
+        let role = if valid {
+            role
+        } else {
+            let next = if prev_role.is_empty() {
+                "user"
+            } else if prev_role == "user" {
+                "model"
+            } else {
+                "user"
+            };
+            obj.insert("role".to_string(), Value::String(next.to_string()));
+            next.to_string()
+        };
+        prev_role = role;
+    }
+}
+
+fn fallback_function_response_part(value: &Value) -> Value {
+    // Best-effort fallback; should be rare in practice.
+    json!({
+        "functionResponse": {
+            "name": "unknown",
+            "response": { "result": value_to_string(value) }
+        }
+    })
+}
+
+fn value_to_string(value: &Value) -> String {
+    match value {
+        Value::String(text) => text.clone(),
+        other => other.to_string(),
+    }
+}
+
+fn parts_slice(content: &Map<String, Value>) -> &[Value] {
+    content
+        .get("parts")
+        .and_then(Value::as_array)
+        .map(|value| value.as_slice())
+        .unwrap_or(&[])
+}
+
+fn response_only_parts(parts: &[Value]) -> Option<Vec<Value>> {
+    let mut responses = Vec::new();
+    for part in parts {
+        if part.get("functionResponse").is_some() {
+            responses.push(part.clone());
+        } else {
+            return None;
+        }
+    }
+    if responses.is_empty() {
+        None
+    } else {
+        Some(responses)
+    }
+}
+
+fn count_function_calls(parts: &[Value]) -> usize {
+    parts
+        .iter()
+        .filter(|part| part.get("functionCall").is_some())
+        .count()
+}
+
+fn is_model_content(content: &Map<String, Value>) -> bool {
+    content.get("role").and_then(Value::as_str) == Some("model")
+}
diff --git a/crates/token_proxy_core/src/proxy/antigravity_schema.rs b/crates/token_proxy_core/src/proxy/antigravity_schema.rs
index 8b52e03..12d708e 100644
--- a/crates/token_proxy_core/src/proxy/antigravity_schema.rs
+++ b/crates/token_proxy_core/src/proxy/antigravity_schema.rs
@@ -1,6 +1,9 @@
 use serde_json::{Map, Value};
 use std::collections::HashMap;
 
+mod ops;
+use ops::*;
+
 #[derive(Clone, Debug, PartialEq, Eq, Hash)]
 enum PathSegment {
     Key(String),
@@ -23,6 +26,14 @@ const UNSUPPORTED_CONSTRAINTS: [&str; 10] = [
 ];
 
 pub(crate) fn clean_json_schema_for_antigravity(schema: &mut Value) {
+    clean_json_schema(schema, true);
+}
+
+pub(crate) fn clean_json_schema_for_gemini(schema: &mut Value) {
+    clean_json_schema(schema, false);
+}
+
+fn clean_json_schema(schema: &mut Value, add_placeholder: bool) {
     convert_refs_to_hints(schema);
     convert_const_to_enum(schema);
     convert_enum_values_to_strings(schema);
@@ -35,8 +46,16 @@ pub(crate) fn clean_json_schema_for_antigravity(schema: &mut Value) {
     flatten_type_arrays(schema);
     remove_unsupported_keywords(schema);
 
+    if !add_placeholder {
+        // Gemini schema cleanup: remove nullable/title and placeholder-only fields.
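+        // ("placeholder-only" = the `reason`/`_` properties that the Antigravity cleaner
+        // injects; stripping them keeps Gemini requests free of Antigravity artifacts.)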
+ remove_keywords(schema, &["nullable", "title"]); + remove_placeholder_fields(schema); + } cleanup_required_fields(schema); - add_empty_schema_placeholder(schema); + // Antigravity/Claude VALIDATED-mode: object schemas cannot be empty; inject placeholders. + if add_placeholder { + add_empty_schema_placeholder(schema); + } } fn convert_refs_to_hints(schema: &mut Value) { @@ -65,6 +84,114 @@ fn convert_refs_to_hints(schema: &mut Value) { } } +fn remove_keywords(schema: &mut Value, keywords: &[&str]) { + for key in keywords { + let mut paths = collect_paths(schema, key); + sort_by_depth(&mut paths); + for path in paths { + let Some(parent_path) = parent_path(&path) else { + continue; + }; + if is_property_definition(&parent_path) { + continue; + } + let _ = delete_at_path(schema, &path); + } + } +} + +fn remove_placeholder_fields(schema: &mut Value) { + remove_placeholder_property(schema, "_", None); + remove_placeholder_reason(schema); +} + +fn remove_placeholder_property(schema: &mut Value, key: &str, required_key: Option<&str>) { + let mut paths = collect_paths(schema, key); + sort_by_depth(&mut paths); + for path in paths { + if !ends_with_properties_key(&path, key) { + continue; + } + let _ = delete_at_path(schema, &path); + let Some(parent_path) = trim_properties_key_suffix(&path, key) else { + continue; + }; + remove_required_entry(schema, &parent_path, required_key.unwrap_or(key)); + } +} + +fn remove_placeholder_reason(schema: &mut Value) { + let mut paths = collect_paths(schema, "reason"); + sort_by_depth(&mut paths); + for path in paths { + if !ends_with_properties_key(&path, "reason") { + continue; + } + let Some(parent_path) = trim_properties_key_suffix(&path, "reason") else { + continue; + }; + // Only remove when it's the only property and matches our placeholder description. 
+        let props_path = join_path(&parent_path, "properties");
+        let Some(Value::Object(props)) = get_value(schema, &props_path) else {
+            continue;
+        };
+        if props.len() != 1 {
+            continue;
+        }
+        let desc_path = join_path(&path, "description");
+        let desc = get_value(schema, &desc_path)
+            .and_then(Value::as_str)
+            .unwrap_or("");
+        if desc != "Brief explanation of why you are calling this tool" {
+            continue;
+        }
+        let _ = delete_at_path(schema, &path);
+        remove_required_entry(schema, &parent_path, "reason");
+    }
+}
+
+fn ends_with_properties_key(path: &Path, key: &str) -> bool {
+    if path.len() < 2 {
+        return false;
+    }
+    matches!(path.get(path.len() - 2), Some(PathSegment::Key(k)) if k == "properties")
+        && matches!(path.last(), Some(PathSegment::Key(k)) if k == key)
+}
+
+fn trim_properties_key_suffix(path: &Path, key: &str) -> Option<Path> {
+    if !ends_with_properties_key(path, key) {
+        return None;
+    }
+    let mut parent = path.clone();
+    parent.pop(); // key
+    parent.pop(); // properties
+    Some(parent)
+}
+
+fn join_path(parent: &Path, key: &str) -> Path {
+    let mut next = parent.clone();
+    next.push(PathSegment::Key(key.to_string()));
+    next
+}
+
+fn remove_required_entry(schema: &mut Value, parent_path: &Path, key: &str) {
+    let req_path = join_path(parent_path, "required");
+    let Some(Value::Array(required)) = get_value_mut(schema, &req_path) else {
+        return;
+    };
+    let next = required
+        .iter()
+        .filter_map(|item| item.as_str())
+        .filter(|item| *item != key)
+        .map(|item| Value::String(item.to_string()))
+        .collect::<Vec<_>>();
+    if next.is_empty() {
+        let _ = delete_at_path(schema, &req_path);
+    } else {
+        *required = next;
+    }
+}
+
 fn convert_const_to_enum(schema: &mut Value) {
     let paths = collect_paths(schema, "const");
     for path in paths {
@@ -446,326 +573,3 @@ fn cleanup_required_fields(schema: &mut Value) {
         }
     }
 }
-
-fn add_empty_schema_placeholder(schema: &mut Value) {
-    let mut paths = collect_paths(schema, "type");
-    sort_by_depth(&mut paths);
-    for path in paths {
-        let Some(Value::String(value)) = get_value(schema, &path) else {
-            continue;
-        };
-        if value != "object" {
-            continue;
-        }
-        let parent_path = match parent_path(&path) {
-            Some(parent) => parent,
-            None => continue,
-        };
-        apply_schema_placeholder(schema, &parent_path);
-    }
-}
-
-fn apply_schema_placeholder(schema: &mut Value, parent_path: &Path) {
-    let Some(parent) = get_object_mut(schema, parent_path) else {
-        return;
-    };
-    let props = parent.get("properties");
-    let req = parent.get("required");
-    let has_required = req
-        .and_then(Value::as_array)
-        .map(|items| !items.is_empty())
-        .unwrap_or(false);
-    let needs_placeholder = match props {
-        None => true,
-        Some(Value::Object(map)) => map.is_empty(),
-        _ => false,
-    };
-    if needs_placeholder {
-        add_reason_placeholder(parent);
-        return;
-    }
-    if !has_required {
-        if parent_path.is_empty() {
-            return;
-        }
-        add_required_placeholder(parent);
-    }
-}
-
-fn add_reason_placeholder(parent: &mut Map<String, Value>) {
-    let props = parent
-        .entry("properties".to_string())
-        .or_insert_with(|| Value::Object(Map::new()));
-    let Some(props) = props.as_object_mut() else {
-        return;
-    };
-    let reason = props
-        .entry("reason".to_string())
-        .or_insert_with(|| Value::Object(Map::new()));
-    if let Some(reason) = reason.as_object_mut() {
-        reason.insert("type".to_string(), Value::String("string".to_string()));
-        reason.insert(
-            "description".to_string(),
-            Value::String("Brief explanation of why you are calling this tool".to_string()),
-        );
-    }
-    parent.insert(
-        "required".to_string(),
-
Value::Array(vec![Value::String("reason".to_string())]), - ); -} - -fn add_required_placeholder(parent: &mut Map) { - let props = parent - .entry("properties".to_string()) - .or_insert_with(|| Value::Object(Map::new())); - let Some(props) = props.as_object_mut() else { - return; - }; - if !props.contains_key("_") { - let mut placeholder = Map::new(); - placeholder.insert("type".to_string(), Value::String("boolean".to_string())); - props.insert("_".to_string(), Value::Object(placeholder)); - } - parent.insert( - "required".to_string(), - Value::Array(vec![Value::String("_".to_string())]), - ); -} - -fn collect_paths(schema: &Value, field: &str) -> Vec { - let mut paths = Vec::new(); - let mut current = Vec::new(); - walk(schema, field, &mut current, &mut paths); - paths -} - -fn walk(value: &Value, field: &str, path: &mut Path, out: &mut Vec) { - match value { - Value::Object(map) => { - for (key, val) in map { - path.push(PathSegment::Key(key.clone())); - if key == field { - out.push(path.clone()); - } - walk(val, field, path, out); - path.pop(); - } - } - Value::Array(items) => { - for (idx, item) in items.iter().enumerate() { - path.push(PathSegment::Index(idx)); - walk(item, field, path, out); - path.pop(); - } - } - _ => {} - } -} - -fn sort_by_depth(paths: &mut Vec) { - paths.sort_by(|a, b| b.len().cmp(&a.len())); -} - -fn parent_path(path: &Path) -> Option { - if path.is_empty() { - return None; - } - let mut parent = path.clone(); - parent.pop(); - Some(parent) -} - -fn get_value<'a>(root: &'a Value, path: &[PathSegment]) -> Option<&'a Value> { - let mut current = root; - for segment in path { - match segment { - PathSegment::Key(key) => { - current = current.get(key)?; - } - PathSegment::Index(index) => { - current = current.get(*index)?; - } - } - } - Some(current) -} - -fn get_value_mut<'a>(root: &'a mut Value, path: &[PathSegment]) -> Option<&'a mut Value> { - let mut current = root; - for segment in path { - match segment { - PathSegment::Key(key) => { - current = current.get_mut(key)?; - } - PathSegment::Index(index) => { - current = current.get_mut(*index)?; - } - } - } - Some(current) -} - -fn set_value_at_path(root: &mut Value, path: &[PathSegment], value: Value) -> bool { - if path.is_empty() { - *root = value; - return true; - } - let (parent, last) = match split_parent(path) { - Some(split) => split, - None => return false, - }; - let Some(parent) = get_value_mut(root, parent) else { - return false; - }; - match last { - PathSegment::Key(key) => { - let Some(obj) = parent.as_object_mut() else { - return false; - }; - obj.insert(key.clone(), value); - true - } - PathSegment::Index(index) => { - let Some(arr) = parent.as_array_mut() else { - return false; - }; - if *index >= arr.len() { - return false; - } - arr[*index] = value; - true - } - } -} - -fn delete_at_path(root: &mut Value, path: &[PathSegment]) -> bool { - let (parent, last) = match split_parent(path) { - Some(split) => split, - None => return false, - }; - let Some(parent) = get_value_mut(root, parent) else { - return false; - }; - match last { - PathSegment::Key(key) => { - let Some(obj) = parent.as_object_mut() else { - return false; - }; - obj.remove(key).is_some() - } - PathSegment::Index(index) => { - let Some(arr) = parent.as_array_mut() else { - return false; - }; - if *index >= arr.len() { - return false; - } - arr.remove(*index); - true - } - } -} - -fn split_parent(path: &[PathSegment]) -> Option<(&[PathSegment], &PathSegment)> { - let len = path.len(); - if len == 0 { - return None; - } - 
Some((&path[..len - 1], &path[len - 1])) -} - -fn get_object_mut<'a>(root: &'a mut Value, path: &[PathSegment]) -> Option<&'a mut Map> { - get_value_mut(root, path)?.as_object_mut() -} - -fn append_hint(root: &mut Value, path: &[PathSegment], hint: &str) { - let Some(obj) = get_object_mut(root, path) else { - return; - }; - let existing = obj - .get("description") - .and_then(Value::as_str) - .unwrap_or(""); - let next = if existing.is_empty() { - hint.to_string() - } else { - format!("{existing} ({hint})") - }; - obj.insert("description".to_string(), Value::String(next)); -} - -fn append_hint_raw(schema: &mut Value, hint: &str) { - let Some(obj) = schema.as_object_mut() else { - return; - }; - let existing = obj - .get("description") - .and_then(Value::as_str) - .unwrap_or(""); - let next = if existing.is_empty() { - hint.to_string() - } else { - format!("{existing} ({hint})") - }; - obj.insert("description".to_string(), Value::String(next)); -} - -fn merge_description(schema: &mut Value, parent_desc: &str) { - let Some(obj) = schema.as_object_mut() else { - return; - }; - let child_desc = obj - .get("description") - .and_then(Value::as_str) - .unwrap_or(""); - if child_desc.is_empty() { - obj.insert("description".to_string(), Value::String(parent_desc.to_string())); - return; - } - if child_desc == parent_desc { - return; - } - obj.insert( - "description".to_string(), - Value::String(format!("{parent_desc} ({child_desc})")), - ); -} - -fn is_property_definition(path: &[PathSegment]) -> bool { - match path.last() { - Some(PathSegment::Key(key)) if key == "properties" => true, - _ => path.len() == 1 && matches!(path[0], PathSegment::Key(ref key) if key == "properties"), - } -} - -fn get_description(root: &Value, path: &[PathSegment]) -> Option { - let obj = get_value(root, path)?.as_object()?; - let desc = obj.get("description")?.as_str()?.trim().to_string(); - Some(desc) -} - -fn value_to_string(value: &Value) -> String { - match value { - Value::String(value) => value.clone(), - Value::Number(value) => value.to_string(), - Value::Bool(value) => value.to_string(), - Value::Null => "null".to_string(), - other => other.to_string(), - } -} - -fn property_field_from_type_path(path: &[PathSegment]) -> Option<(Path, String)> { - if path.len() < 3 { - return None; - } - let len = path.len(); - if !matches!(path.get(len - 3), Some(PathSegment::Key(key)) if key == "properties") { - return None; - } - let field = match path.get(len - 2) { - Some(PathSegment::Key(key)) => key.clone(), - _ => return None, - }; - Some((path[..len - 3].to_vec(), field)) -} diff --git a/crates/token_proxy_core/src/proxy/antigravity_schema/ops.rs b/crates/token_proxy_core/src/proxy/antigravity_schema/ops.rs new file mode 100644 index 0000000..399dc8c --- /dev/null +++ b/crates/token_proxy_core/src/proxy/antigravity_schema/ops.rs @@ -0,0 +1,330 @@ +use serde_json::{Map, Value}; + +use super::{Path, PathSegment}; + +pub(super) fn add_empty_schema_placeholder(schema: &mut Value) { + let mut paths = collect_paths(schema, "type"); + sort_by_depth(&mut paths); + for path in paths { + let Some(Value::String(value)) = get_value(schema, &path) else { + continue; + }; + if value != "object" { + continue; + } + let parent_path = match parent_path(&path) { + Some(parent) => parent, + None => continue, + }; + apply_schema_placeholder(schema, &parent_path); + } +} + +fn apply_schema_placeholder(schema: &mut Value, parent_path: &Path) { + let Some(parent) = get_object_mut(schema, parent_path) else { + return; + }; + let props = 
parent.get("properties"); + let req = parent.get("required"); + let has_required = req + .and_then(Value::as_array) + .map(|items| !items.is_empty()) + .unwrap_or(false); + let needs_placeholder = match props { + None => true, + Some(Value::Object(map)) => map.is_empty(), + _ => false, + }; + if needs_placeholder { + add_reason_placeholder(parent); + return; + } + if !has_required { + if parent_path.is_empty() { + return; + } + add_required_placeholder(parent); + } +} + +fn add_reason_placeholder(parent: &mut Map) { + let props = parent + .entry("properties".to_string()) + .or_insert_with(|| Value::Object(Map::new())); + let Some(props) = props.as_object_mut() else { + return; + }; + let reason = props + .entry("reason".to_string()) + .or_insert_with(|| Value::Object(Map::new())); + if let Some(reason) = reason.as_object_mut() { + reason.insert("type".to_string(), Value::String("string".to_string())); + reason.insert( + "description".to_string(), + Value::String("Brief explanation of why you are calling this tool".to_string()), + ); + } + parent.insert( + "required".to_string(), + Value::Array(vec![Value::String("reason".to_string())]), + ); +} + +fn add_required_placeholder(parent: &mut Map) { + let props = parent + .entry("properties".to_string()) + .or_insert_with(|| Value::Object(Map::new())); + let Some(props) = props.as_object_mut() else { + return; + }; + if !props.contains_key("_") { + let mut placeholder = Map::new(); + placeholder.insert("type".to_string(), Value::String("boolean".to_string())); + props.insert("_".to_string(), Value::Object(placeholder)); + } + parent.insert( + "required".to_string(), + Value::Array(vec![Value::String("_".to_string())]), + ); +} + +pub(super) fn collect_paths(schema: &Value, field: &str) -> Vec { + let mut paths = Vec::new(); + let mut current = Vec::new(); + walk(schema, field, &mut current, &mut paths); + paths +} + +fn walk(value: &Value, field: &str, path: &mut Path, out: &mut Vec) { + match value { + Value::Object(map) => { + for (key, val) in map { + path.push(PathSegment::Key(key.clone())); + if key == field { + out.push(path.clone()); + } + walk(val, field, path, out); + path.pop(); + } + } + Value::Array(items) => { + for (idx, item) in items.iter().enumerate() { + path.push(PathSegment::Index(idx)); + walk(item, field, path, out); + path.pop(); + } + } + _ => {} + } +} + +pub(super) fn sort_by_depth(paths: &mut Vec) { + paths.sort_by(|a, b| b.len().cmp(&a.len())); +} + +pub(super) fn parent_path(path: &Path) -> Option { + if path.is_empty() { + return None; + } + let mut parent = path.clone(); + parent.pop(); + Some(parent) +} + +pub(super) fn get_value<'a>(root: &'a Value, path: &[PathSegment]) -> Option<&'a Value> { + let mut current = root; + for segment in path { + match segment { + PathSegment::Key(key) => { + current = current.get(key)?; + } + PathSegment::Index(index) => { + current = current.get(*index)?; + } + } + } + Some(current) +} + +pub(super) fn get_value_mut<'a>(root: &'a mut Value, path: &[PathSegment]) -> Option<&'a mut Value> { + let mut current = root; + for segment in path { + match segment { + PathSegment::Key(key) => { + current = current.get_mut(key)?; + } + PathSegment::Index(index) => { + current = current.get_mut(*index)?; + } + } + } + Some(current) +} + +pub(super) fn set_value_at_path(root: &mut Value, path: &[PathSegment], value: Value) -> bool { + if path.is_empty() { + *root = value; + return true; + } + let (parent, last) = match split_parent(path) { + Some(split) => split, + None => return false, + }; 
+ let Some(parent) = get_value_mut(root, parent) else { + return false; + }; + match last { + PathSegment::Key(key) => { + let Some(obj) = parent.as_object_mut() else { + return false; + }; + obj.insert(key.clone(), value); + true + } + PathSegment::Index(index) => { + let Some(arr) = parent.as_array_mut() else { + return false; + }; + if *index >= arr.len() { + return false; + } + arr[*index] = value; + true + } + } +} + +pub(super) fn delete_at_path(root: &mut Value, path: &[PathSegment]) -> bool { + let (parent, last) = match split_parent(path) { + Some(split) => split, + None => return false, + }; + let Some(parent) = get_value_mut(root, parent) else { + return false; + }; + match last { + PathSegment::Key(key) => { + let Some(obj) = parent.as_object_mut() else { + return false; + }; + obj.remove(key).is_some() + } + PathSegment::Index(index) => { + let Some(arr) = parent.as_array_mut() else { + return false; + }; + if *index >= arr.len() { + return false; + } + arr.remove(*index); + true + } + } +} + +fn split_parent(path: &[PathSegment]) -> Option<(&[PathSegment], &PathSegment)> { + let len = path.len(); + if len == 0 { + return None; + } + Some((&path[..len - 1], &path[len - 1])) +} + +pub(super) fn get_object_mut<'a>( + root: &'a mut Value, + path: &[PathSegment], +) -> Option<&'a mut Map> { + get_value_mut(root, path)?.as_object_mut() +} + +pub(super) fn append_hint(root: &mut Value, path: &[PathSegment], hint: &str) { + let Some(obj) = get_object_mut(root, path) else { + return; + }; + let existing = obj + .get("description") + .and_then(Value::as_str) + .unwrap_or(""); + let next = if existing.is_empty() { + hint.to_string() + } else { + format!("{existing} ({hint})") + }; + obj.insert("description".to_string(), Value::String(next)); +} + +pub(super) fn append_hint_raw(schema: &mut Value, hint: &str) { + let Some(obj) = schema.as_object_mut() else { + return; + }; + let existing = obj + .get("description") + .and_then(Value::as_str) + .unwrap_or(""); + let next = if existing.is_empty() { + hint.to_string() + } else { + format!("{existing} ({hint})") + }; + obj.insert("description".to_string(), Value::String(next)); +} + +pub(super) fn merge_description(schema: &mut Value, parent_desc: &str) { + let Some(obj) = schema.as_object_mut() else { + return; + }; + let child_desc = obj + .get("description") + .and_then(Value::as_str) + .unwrap_or(""); + if child_desc.is_empty() { + obj.insert("description".to_string(), Value::String(parent_desc.to_string())); + return; + } + if child_desc == parent_desc { + return; + } + obj.insert( + "description".to_string(), + Value::String(format!("{parent_desc} ({child_desc})")), + ); +} + +pub(super) fn is_property_definition(path: &[PathSegment]) -> bool { + match path.last() { + Some(PathSegment::Key(key)) if key == "properties" => true, + _ => path.len() == 1 && matches!(path[0], PathSegment::Key(ref key) if key == "properties"), + } +} + +pub(super) fn get_description(root: &Value, path: &[PathSegment]) -> Option { + let obj = get_value(root, path)?.as_object()?; + let desc = obj.get("description")?.as_str()?.trim().to_string(); + Some(desc) +} + +pub(super) fn value_to_string(value: &Value) -> String { + match value { + Value::String(value) => value.clone(), + Value::Number(value) => value.to_string(), + Value::Bool(value) => value.to_string(), + Value::Null => "null".to_string(), + other => other.to_string(), + } +} + +pub(super) fn property_field_from_type_path(path: &[PathSegment]) -> Option<(Path, String)> { + if path.len() < 3 { + return 
None; + } + let len = path.len(); + if !matches!(path.get(len - 3), Some(PathSegment::Key(key)) if key == "properties") { + return None; + } + let field = match path.get(len - 2) { + Some(PathSegment::Key(key)) => key.clone(), + _ => return None, + }; + Some((path[..len - 3].to_vec(), field)) +} + diff --git a/crates/token_proxy_core/src/proxy/config/normalize.rs b/crates/token_proxy_core/src/proxy/config/normalize.rs index d3fa44d..bd16c9a 100644 --- a/crates/token_proxy_core/src/proxy/config/normalize.rs +++ b/crates/token_proxy_core/src/proxy/config/normalize.rs @@ -293,8 +293,15 @@ fn native_inbound_formats_for_provider(provider: &str) -> InboundApiFormatMask { "kiro" => mask.insert(InboundApiFormat::AnthropicMessages), // Codex 的“native”更接近 OpenAI Responses;Chat 通常需要显式允许转换。 "codex" => mask.insert(InboundApiFormat::OpenaiResponses), - // Antigravity 原生处理 Gemini 路径;其它格式需显式允许转换后再走 Gemini 兼容层。 - "antigravity" => mask.insert(InboundApiFormat::Gemini), + // Align with CLIProxyAPIPlus: + // - Antigravity natively supports Gemini routes. + // - It also supports Claude Code /v1/messages (Anthropic format) out-of-the-box via + // Anthropic->Gemini request conversion + Antigravity wrapping. Do not gate this behind + // convert_from_map, otherwise users must "enable conversion" manually. + "antigravity" => { + mask.insert(InboundApiFormat::Gemini); + mask.insert(InboundApiFormat::AnthropicMessages); + } _ => {} } mask diff --git a/crates/token_proxy_core/src/proxy/openai_compat.rs b/crates/token_proxy_core/src/proxy/openai_compat.rs index e31451e..06ef812 100644 --- a/crates/token_proxy_core/src/proxy/openai_compat.rs +++ b/crates/token_proxy_core/src/proxy/openai_compat.rs @@ -15,6 +15,7 @@ mod input; mod message; mod tools; mod usage; +pub(crate) use usage::map_usage_chat_to_responses; pub(crate) const CHAT_PATH: &str = "/v1/chat/completions"; pub(crate) const RESPONSES_PATH: &str = "/v1/responses"; diff --git a/crates/token_proxy_core/src/proxy/openai_compat.test.part1.rs b/crates/token_proxy_core/src/proxy/openai_compat.test.part1.rs new file mode 100644 index 0000000..7a7716b --- /dev/null +++ b/crates/token_proxy_core/src/proxy/openai_compat.test.part1.rs @@ -0,0 +1,512 @@ +use super::*; + +#[test] +fn chat_request_to_responses_maps_common_fields() { + let http_clients = ProxyHttpClients::new().expect("http clients"); + let chat_messages = json!([ + { "role": "user", "content": "hi" }, + { "role": "assistant", "content": "hello" } + ]); + let input = bytes_from_json(json!({ + "model": "gpt-4.1", + "messages": chat_messages, + "stream": true, + "temperature": 0.7, + "top_p": 0.9, + // Prefer `max_completion_tokens` over `max_tokens`. 
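+        // Both are provided below; the transform should keep 222 (asserted after the call).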
+ "max_tokens": 111, + "max_completion_tokens": 222 + })); + + let output = run_async(async { + transform_request_body(FormatTransform::ChatToResponses, &input, &http_clients, None) + .await + .expect("transform") + }); + let value = json_from_bytes(output); + + let expected_input = json!([ + { + "type": "message", + "role": "user", + "content": [{ "type": "input_text", "text": "hi" }] + }, + { + "type": "message", + "role": "assistant", + "content": [{ "type": "output_text", "text": "hello" }] + } + ]); + + assert_eq!(value["model"], json!("gpt-4.1")); + assert_eq!(value["input"], expected_input); + assert_eq!(value["stream"], json!(true)); + assert_eq!(value["temperature"], json!(0.7)); + assert_eq!(value["top_p"], json!(0.9)); + assert_eq!(value["max_output_tokens"], json!(222)); + assert!(value.get("messages").is_none()); +} + +#[test] +fn responses_request_to_chat_maps_tools_and_tool_choice() { + let http_clients = ProxyHttpClients::new().expect("http clients"); + let parameters = json!({ + "type": "object", + "properties": { "q": { "type": "string" } }, + "required": ["q"] + }); + let input = bytes_from_json(json!({ + "model": "gpt-4.1", + "input": "hello", + "tools": [ + { + "type": "function", + "name": "search", + "description": "Search something", + "parameters": parameters + } + ], + "tool_choice": { "type": "function", "name": "search" }, + "stream": false + })); + + let output = run_async(async { + transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None) + .await + .expect("transform") + }); + let value = json_from_bytes(output); + + assert_eq!(value["tools"][0]["type"], json!("function")); + assert_eq!(value["tools"][0]["function"]["name"], json!("search")); + assert_eq!(value["tools"][0]["function"]["description"], json!("Search something")); + assert_eq!(value["tools"][0]["function"]["parameters"], parameters); + assert_eq!(value["tool_choice"]["type"], json!("function")); + assert_eq!(value["tool_choice"]["function"]["name"], json!("search")); +} + +#[test] +fn chat_request_to_responses_maps_tools_and_tool_choice() { + let http_clients = ProxyHttpClients::new().expect("http clients"); + let parameters = json!({ + "type": "object", + "properties": { "q": { "type": "string" } }, + "required": ["q"] + }); + let input = bytes_from_json(json!({ + "model": "gpt-4.1", + "messages": [{ "role": "user", "content": "hi" }], + "tools": [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search something", + "parameters": parameters + } + } + ], + "tool_choice": { "type": "function", "function": { "name": "search" } }, + "stream": false + })); + + let output = run_async(async { + transform_request_body(FormatTransform::ChatToResponses, &input, &http_clients, None) + .await + .expect("transform") + }); + let value = json_from_bytes(output); + + assert_eq!(value["tools"][0]["type"], json!("function")); + assert_eq!(value["tools"][0]["name"], json!("search")); + assert_eq!(value["tools"][0]["description"], json!("Search something")); + assert_eq!(value["tools"][0]["parameters"], parameters); + assert_eq!(value["tool_choice"]["type"], json!("function")); + assert_eq!(value["tool_choice"]["name"], json!("search")); +} + +#[test] +fn responses_request_to_chat_instructions_becomes_system_message() { + let http_clients = ProxyHttpClients::new().expect("http clients"); + let input = bytes_from_json(json!({ + "model": "gpt-4.1", + "input": "hello", + "instructions": "be concise", + "stream": false, + "max_output_tokens": 99 + })); + + 
let output = run_async(async { + transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None) + .await + .expect("transform") + }); + let value = json_from_bytes(output); + let messages = value["messages"].as_array().expect("messages array"); + + assert_eq!(value["model"], json!("gpt-4.1")); + assert_eq!(value["stream"], json!(false)); + assert_eq!(value["max_completion_tokens"], json!(99)); + assert_eq!(messages.len(), 2); + assert_eq!(messages[0]["role"], json!("system")); + assert_eq!(messages[0]["content"], json!("be concise")); + assert_eq!(messages[1]["role"], json!("user")); + assert_eq!(messages[1]["content"], json!("hello")); +} + +#[test] +fn responses_request_to_chat_accepts_message_array_input() { + let http_clients = ProxyHttpClients::new().expect("http clients"); + let input_messages = json!([{ "role": "user", "content": "hi" }]); + let input = bytes_from_json(json!({ + "model": "gpt-4.1", + "input": input_messages, + "stream": true + })); + + let output = run_async(async { + transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None) + .await + .expect("transform") + }); + let value = json_from_bytes(output); + + assert_eq!(value["model"], json!("gpt-4.1")); + assert_eq!(value["stream"], json!(true)); + assert_eq!(value["messages"], input_messages); +} + +#[test] +fn responses_request_to_chat_converts_input_text_content_parts_to_string() { + let http_clients = ProxyHttpClients::new().expect("http clients"); + let input_messages = json!([{ + "role": "user", + "content": [ + { "type": "input_text", "text": "分析项目的逻辑缺陷和性能缺陷" } + ] + }]); + let input = bytes_from_json(json!({ + "model": "gpt-4.1", + "input": input_messages, + "stream": false + })); + + let output = run_async(async { + transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None) + .await + .expect("transform") + }); + let value = json_from_bytes(output); + + assert_eq!(value["messages"][0]["role"], json!("user")); + assert_eq!( + value["messages"][0]["content"], + json!("分析项目的逻辑缺陷和性能缺陷") + ); +} + +#[test] +fn chat_request_to_responses_maps_response_format() { + let http_clients = ProxyHttpClients::new().expect("http clients"); + let input = bytes_from_json(json!({ + "model": "gpt-4.1", + "messages": [{ "role": "user", "content": "hi" }], + "response_format": { + "type": "json_schema", + "json_schema": { + "name": "example", + "schema": { "type": "object", "properties": { "ok": { "type": "boolean" } } } + } + } + })); + + let output = run_async(async { + transform_request_body(FormatTransform::ChatToResponses, &input, &http_clients, None) + .await + .expect("transform") + }); + let value = json_from_bytes(output); + + assert_eq!(value["text"]["format"]["type"], json!("json_schema")); + assert_eq!(value["text"]["format"]["json_schema"]["name"], json!("example")); +} + +#[test] +fn responses_request_to_chat_maps_text_format_to_response_format() { + let http_clients = ProxyHttpClients::new().expect("http clients"); + let input = bytes_from_json(json!({ + "model": "gpt-4.1", + "input": "hi", + "text": { "format": { "type": "json_object" } } + })); + + let output = run_async(async { + transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None) + .await + .expect("transform") + }); + let value = json_from_bytes(output); + + assert_eq!(value["response_format"]["type"], json!("json_object")); +} + +#[test] +fn responses_response_to_chat_extracts_output_text_and_maps_usage() { + let input = bytes_from_json(json!({ + "id": 
"resp_123", + "created_at": 1700000000, + "model": "gpt-4.1", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [ + { "type": "output_text", "text": "Hello", "annotations": [] }, + { "type": "output_text", "text": " world", "annotations": [] } + ] + } + ], + "usage": { + "input_tokens": 1, + "output_tokens": 2, + "total_tokens": 3, + "output_tokens_details": { "reasoning_tokens": 7 } + } + })); + + let output = transform_response_body(FormatTransform::ResponsesToChat, &input, None).expect("transform"); + let value = json_from_bytes(output); + + assert_eq!(value["id"], json!("resp_123")); + assert_eq!(value["object"], json!("chat.completion")); + assert_eq!(value["created"], json!(1700000000)); + assert_eq!(value["model"], json!("gpt-4.1")); + assert_eq!(value["choices"][0]["message"]["role"], json!("assistant")); + assert_eq!(value["choices"][0]["message"]["content"], json!("Hello world")); + assert_eq!(value["choices"][0]["finish_reason"], json!("stop")); + assert_eq!(value["usage"]["prompt_tokens"], json!(1)); + assert_eq!(value["usage"]["completion_tokens"], json!(2)); + assert_eq!(value["usage"]["total_tokens"], json!(3)); + assert_eq!( + value["usage"]["completion_tokens_details"]["reasoning_tokens"], + json!(7) + ); +} + +#[test] +fn responses_response_to_chat_maps_reasoning_content() { + let input = bytes_from_json(json!({ + "id": "resp_reason", + "created_at": 1700000002, + "model": "gpt-4.1", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [ + { "type": "reasoning_text", "text": "think", "annotations": [] }, + { "type": "output_text", "text": "ok", "annotations": [] } + ] + } + ] + })); + + let output = transform_response_body(FormatTransform::ResponsesToChat, &input, None).expect("transform"); + let value = json_from_bytes(output); + + let message = &value["choices"][0]["message"]; + assert_eq!(message["content"], json!("ok")); + assert_eq!(message["reasoning_content"], json!("think")); +} + +#[test] +fn responses_response_to_chat_includes_tool_calls_and_multimodal_content() { + let input = bytes_from_json(json!({ + "id": "resp_456", + "created_at": 1700000001, + "model": "gpt-4.1", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [ + { "type": "output_text", "text": "Hello", "annotations": [] }, + { "type": "output_image", "image_url": { "url": "https://example.com/a.png" } } + ] + }, + { + "type": "function_call", + "call_id": "call_foo", + "name": "doThing", + "arguments": "{\"a\":1}" + } + ], + "usage": { "input_tokens": 1, "output_tokens": 2, "total_tokens": 3 } + })); + + let output = transform_response_body(FormatTransform::ResponsesToChat, &input, None).expect("transform"); + let value = json_from_bytes(output); + + let message = &value["choices"][0]["message"]; + assert_eq!(message["role"], json!("assistant")); + assert_eq!(message["content"][0]["type"], json!("text")); + assert_eq!(message["content"][0]["text"], json!("Hello")); + assert_eq!(message["content"][1]["type"], json!("image_url")); + assert_eq!( + message["content"][1]["image_url"]["url"], + json!("https://example.com/a.png") + ); + assert_eq!(message["tool_calls"][0]["id"], json!("call_foo")); + assert_eq!(message["tool_calls"][0]["function"]["name"], json!("doThing")); + assert_eq!(message["tool_calls"][0]["function"]["arguments"], json!("{\"a\":1}")); + assert_eq!(value["choices"][0]["finish_reason"], json!("tool_calls")); +} + +#[test] +fn chat_response_to_responses_extracts_choice_text_and_maps_usage() { + let input = 
bytes_from_json(json!({ + "id": "chatcmpl_123", + "created": 1700000000, + "model": "gpt-4.1", + "choices": [ + { "index": 0, "message": { "role": "assistant", "content": "Hello" } } + ], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 2, + "total_tokens": 3, + "completion_tokens_details": { "reasoning_tokens": 5 } + } + })); + + let output = transform_response_body(FormatTransform::ChatToResponses, &input, None).expect("transform"); + let value = json_from_bytes(output); + + assert_eq!(value["id"], json!("chatcmpl_123")); + assert_eq!(value["object"], json!("response")); + assert_eq!(value["created_at"], json!(1700000000)); + assert_eq!(value["model"], json!("gpt-4.1")); + assert_eq!(value["output"][0]["type"], json!("message")); + assert_eq!(value["output"][0]["role"], json!("assistant")); + assert_eq!(value["output"][0]["content"][0]["type"], json!("output_text")); + assert_eq!(value["output"][0]["content"][0]["text"], json!("Hello")); + assert_eq!(value["usage"]["input_tokens"], json!(1)); + assert_eq!(value["usage"]["output_tokens"], json!(2)); + assert_eq!(value["usage"]["total_tokens"], json!(3)); + assert_eq!( + value["usage"]["output_tokens_details"]["reasoning_tokens"], + json!(5) + ); +} + +#[test] +fn chat_response_to_responses_maps_finish_reason_to_incomplete_details() { + let input = bytes_from_json(json!({ + "id": "chatcmpl_456", + "created": 1700000002, + "model": "gpt-4.1", + "choices": [ + { "index": 0, "message": { "role": "assistant", "content": "Hello" }, "finish_reason": "length" } + ], + "usage": { "prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3 } + })); + + let output = transform_response_body(FormatTransform::ChatToResponses, &input, None).expect("transform"); + let value = json_from_bytes(output); + + assert_eq!(value["status"], json!("incomplete")); + assert_eq!(value["incomplete_details"]["reason"], json!("max_tokens")); +} + +#[test] +fn responses_request_to_chat_converts_function_call_output_to_tool_message() { + let http_clients = ProxyHttpClients::new().expect("http clients"); + let input = bytes_from_json(json!({ + "model": "gpt-4.1", + "input": [ + { "type": "function_call_output", "call_id": "call_123", "output": "ok" } + ], + "stream": false + })); + + let output = run_async(async { + transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None) + .await + .expect("transform") + }); + let value = json_from_bytes(output); + let messages = value["messages"].as_array().expect("messages array"); + + assert_eq!(messages.len(), 1); + assert_eq!(messages[0]["role"], json!("tool")); + assert_eq!(messages[0]["tool_call_id"], json!("call_123")); + assert_eq!(messages[0]["content"], json!("ok")); +} + +#[test] +fn chat_response_to_responses_maps_tool_calls_into_output() { + let input = bytes_from_json(json!({ + "id": "chatcmpl_123", + "created": 1700000000, + "model": "gpt-4.1", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_foo", + "type": "function", + "function": { + "name": "getRandomNumber", + "arguments": "{\"a\":\"0\"}" + } + } + ] + } + } + ], + "usage": { "prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3 } + })); + + let output = transform_response_body(FormatTransform::ChatToResponses, &input, None).expect("transform"); + let value = json_from_bytes(output); + + assert_eq!(value["id"], json!("chatcmpl_123")); + assert_eq!(value["object"], json!("response")); + assert_eq!(value["created_at"], json!(1700000000)); + 
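    // Each Chat tool_call should surface as a top-level function_call output item:
+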
assert_eq!(value["model"], json!("gpt-4.1")); + assert_eq!(value["output"][0]["type"], json!("function_call")); + assert_eq!(value["output"][0]["call_id"], json!("call_foo")); + assert_eq!(value["output"][0]["name"], json!("getRandomNumber")); + assert_eq!(value["output"][0]["arguments"], json!("{\"a\":\"0\"}")); + assert_eq!(value["usage"]["input_tokens"], json!(1)); + assert_eq!(value["usage"]["output_tokens"], json!(2)); + assert_eq!(value["usage"]["total_tokens"], json!(3)); +} + +#[test] +fn chat_request_to_responses_rejects_missing_messages() { + let http_clients = ProxyHttpClients::new().expect("http clients"); + let input = bytes_from_json(json!({ "model": "gpt-4.1" })); + let err = run_async(async { + transform_request_body(FormatTransform::ChatToResponses, &input, &http_clients, None) + .await + .expect_err("should fail") + }); + assert!(err.contains("messages")); +} + +#[test] +fn transform_request_body_rejects_non_json() { + let http_clients = ProxyHttpClients::new().expect("http clients"); + let input = Bytes::from_static(b"not-json"); + let err = run_async(async { + transform_request_body(FormatTransform::ChatToResponses, &input, &http_clients, None) + .await + .expect_err("should fail") + }); + assert!(err.contains("JSON")); +} diff --git a/crates/token_proxy_core/src/proxy/openai_compat.test.part2.rs b/crates/token_proxy_core/src/proxy/openai_compat.test.part2.rs new file mode 100644 index 0000000..2dee692 --- /dev/null +++ b/crates/token_proxy_core/src/proxy/openai_compat.test.part2.rs @@ -0,0 +1,154 @@ +use super::*; + +#[test] +fn responses_and_gemini_request_conversions() { + let http_clients = ProxyHttpClients::new().expect("http clients"); + let responses_value = transform_request_value( + FormatTransform::ResponsesToGemini, + json!({ + "model": "gpt-4.1", + "input": "hi", + "instructions": "sys", + "temperature": 0.5, + "top_p": 0.9, + "max_output_tokens": 128, + "stop": ["a", "b"], + "seed": 7 + }), + &http_clients, + None, + ); + assert_eq!(responses_value["contents"][0]["parts"][0]["text"], json!("hi")); + assert_eq!(responses_value["systemInstruction"]["parts"][0]["text"], json!("sys")); + assert_eq!(responses_value["generationConfig"]["maxOutputTokens"], json!(128)); + assert_eq!(responses_value["generationConfig"]["stopSequences"], json!(["a", "b"])); + assert_eq!(responses_value["generationConfig"]["seed"], json!(7)); + let gemini_value = transform_request_value( + FormatTransform::GeminiToResponses, + json!({ + "model": "gemini-1.5-flash", + "contents": [{ "role": "user", "parts": [{ "text": "hello" }] }], + "systemInstruction": { "parts": [{ "text": "rules" }] }, + "generationConfig": { "maxOutputTokens": 64, "topP": 0.8 } + }), + &http_clients, + None, + ); + assert_eq!(gemini_value["model"], json!("gemini-1.5-flash")); + assert_eq!(gemini_value["instructions"], json!("rules")); + assert_eq!(gemini_value["input"][0]["content"][0]["text"], json!("hello")); + assert_eq!(gemini_value["max_output_tokens"], json!(64)); + assert_eq!(gemini_value["top_p"], json!(0.8)); +} +#[test] +fn gemini_and_anthropic_request_conversions() { + let http_clients = ProxyHttpClients::new().expect("http clients"); + let gemini_value = transform_request_value( + FormatTransform::GeminiToAnthropic, + json!({ + "contents": [{ "role": "user", "parts": [{ "text": "ping" }] }], + "systemInstruction": { "parts": [{ "text": "sys" }] }, + "generationConfig": { "maxOutputTokens": 42 } + }), + &http_clients, + Some("claude-3-5-sonnet"), + ); + assert_eq!(gemini_value["model"], 
json!("claude-3-5-sonnet")); + assert_eq!(gemini_value["system"][0]["text"], json!("sys")); + assert_eq!(gemini_value["messages"][0]["content"][0]["text"], json!("ping")); + assert_eq!(gemini_value["max_tokens"], json!(42)); + let anthropic_value = transform_request_value( + FormatTransform::AnthropicToGemini, + json!({ + "model": "claude-3-5-sonnet", + "max_tokens": 321, + "system": "guard", + "stop_sequences": ["x"], + "messages": [{ "role": "user", "content": [{ "type": "text", "text": "yo" }] }] + }), + &http_clients, + None, + ); + assert_eq!(anthropic_value["systemInstruction"]["parts"][0]["text"], json!("guard")); + assert_eq!(anthropic_value["contents"][0]["parts"][0]["text"], json!("yo")); + assert_eq!(anthropic_value["generationConfig"]["maxOutputTokens"], json!(321)); + assert_eq!(anthropic_value["generationConfig"]["stopSequences"], json!(["x"])); +} +#[test] +fn responses_and_gemini_response_conversions() { + let responses_value = transform_response_value( + FormatTransform::ResponsesToGemini, + json!({ + "id": "resp_1", + "created_at": 1700000000, + "model": "gpt-4.1", + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{ "type": "output_text", "text": "Hello", "annotations": [] }] + } + ], + "usage": { "input_tokens": 2, "output_tokens": 3, "total_tokens": 5 } + }), + None, + ); + assert_eq!(responses_value["candidates"][0]["content"]["parts"][0]["text"], json!("Hello")); + assert_eq!(responses_value["usageMetadata"]["promptTokenCount"], json!(2)); + assert_eq!(responses_value["usageMetadata"]["candidatesTokenCount"], json!(3)); + assert_eq!(responses_value["usageMetadata"]["totalTokenCount"], json!(5)); + let gemini_value = transform_response_value( + FormatTransform::GeminiToResponses, + json!({ + "candidates": [ + { "content": { "role": "model", "parts": [{ "text": "Hi" }] }, "finishReason": "STOP" } + ], + "usageMetadata": { + "promptTokenCount": 4, + "candidatesTokenCount": 6, + "totalTokenCount": 10 + } + }), + Some("gemini-1.5-pro"), + ); + assert_eq!(gemini_value["output"][0]["content"][0]["text"], json!("Hi")); + assert_eq!(gemini_value["usage"]["input_tokens"], json!(4)); + assert_eq!(gemini_value["usage"]["output_tokens"], json!(6)); + assert_eq!(gemini_value["usage"]["total_tokens"], json!(10)); +} +#[test] +fn gemini_and_anthropic_response_conversions() { + let gemini_value = transform_response_value( + FormatTransform::GeminiToAnthropic, + json!({ + "candidates": [ + { "content": { "role": "model", "parts": [{ "text": "Howdy" }] }, "finishReason": "STOP" } + ], + "usageMetadata": { + "promptTokenCount": 1, + "candidatesTokenCount": 2, + "totalTokenCount": 3 + } + }), + Some("claude-3-5-sonnet"), + ); + assert_eq!(gemini_value["model"], json!("claude-3-5-sonnet")); + assert_eq!(gemini_value["content"][0]["text"], json!("Howdy")); + assert_eq!(gemini_value["usage"]["input_tokens"], json!(1)); + assert_eq!(gemini_value["usage"]["output_tokens"], json!(2)); + assert_eq!(gemini_value["stop_reason"], json!("end_turn")); + let anthropic_value = transform_response_value( + FormatTransform::AnthropicToGemini, + json!({ + "id": "msg_1", + "model": "claude-3-5-sonnet", + "content": [{ "type": "text", "text": "Yo" }], + "usage": { "input_tokens": 4, "output_tokens": 6 } + }), + None, + ); + assert_eq!(anthropic_value["candidates"][0]["content"]["parts"][0]["text"], json!("Yo")); + assert_eq!(anthropic_value["usageMetadata"]["promptTokenCount"], json!(4)); + assert_eq!(anthropic_value["usageMetadata"]["candidatesTokenCount"], json!(6)); + 
assert_eq!(anthropic_value["usageMetadata"]["totalTokenCount"], json!(10)); +} diff --git a/crates/token_proxy_core/src/proxy/openai_compat.test.rs b/crates/token_proxy_core/src/proxy/openai_compat.test.rs index 2f58003..2eac4f8 100644 --- a/crates/token_proxy_core/src/proxy/openai_compat.test.rs +++ b/crates/token_proxy_core/src/proxy/openai_compat.test.rs @@ -1,6 +1,7 @@ use super::*; use axum::body::Bytes; use serde_json::{json, Value}; + use crate::proxy::http_client::ProxyHttpClients; fn run_async(future: impl std::future::Future) -> T { @@ -16,6 +17,7 @@ fn bytes_from_json(value: Value) -> Bytes { fn json_from_bytes(bytes: Bytes) -> Value { serde_json::from_slice(&bytes).expect("parse JSON") } + fn transform_request_value( transform: FormatTransform, input: Value, @@ -30,653 +32,16 @@ fn transform_request_value( }); json_from_bytes(output) } + fn transform_response_value(transform: FormatTransform, input: Value, model_hint: Option<&str>) -> Value { let bytes = bytes_from_json(input); let output = transform_response_body(transform, &bytes, model_hint).expect("transform"); json_from_bytes(output) } -#[test] -fn chat_request_to_responses_maps_common_fields() { - let http_clients = ProxyHttpClients::new().expect("http clients"); - let chat_messages = json!([ - { "role": "user", "content": "hi" }, - { "role": "assistant", "content": "hello" } - ]); - let input = bytes_from_json(json!({ - "model": "gpt-4.1", - "messages": chat_messages, - "stream": true, - "temperature": 0.7, - "top_p": 0.9, - // Prefer `max_completion_tokens` over `max_tokens`. - "max_tokens": 111, - "max_completion_tokens": 222 - })); - - let output = run_async(async { - transform_request_body(FormatTransform::ChatToResponses, &input, &http_clients, None) - .await - .expect("transform") - }); - let value = json_from_bytes(output); - - let expected_input = json!([ - { - "type": "message", - "role": "user", - "content": [{ "type": "input_text", "text": "hi" }] - }, - { - "type": "message", - "role": "assistant", - "content": [{ "type": "output_text", "text": "hello" }] - } - ]); - - assert_eq!(value["model"], json!("gpt-4.1")); - assert_eq!(value["input"], expected_input); - assert_eq!(value["stream"], json!(true)); - assert_eq!(value["temperature"], json!(0.7)); - assert_eq!(value["top_p"], json!(0.9)); - assert_eq!(value["max_output_tokens"], json!(222)); - assert!(value.get("messages").is_none()); -} - -#[test] -fn responses_request_to_chat_maps_tools_and_tool_choice() { - let http_clients = ProxyHttpClients::new().expect("http clients"); - let parameters = json!({ - "type": "object", - "properties": { "q": { "type": "string" } }, - "required": ["q"] - }); - let input = bytes_from_json(json!({ - "model": "gpt-4.1", - "input": "hello", - "tools": [ - { - "type": "function", - "name": "search", - "description": "Search something", - "parameters": parameters - } - ], - "tool_choice": { "type": "function", "name": "search" }, - "stream": false - })); - - let output = run_async(async { - transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None) - .await - .expect("transform") - }); - let value = json_from_bytes(output); - - assert_eq!(value["tools"][0]["type"], json!("function")); - assert_eq!(value["tools"][0]["function"]["name"], json!("search")); - assert_eq!(value["tools"][0]["function"]["description"], json!("Search something")); - assert_eq!(value["tools"][0]["function"]["parameters"], parameters); - assert_eq!(value["tool_choice"]["type"], json!("function")); - 
assert_eq!(value["tool_choice"]["function"]["name"], json!("search")); -} - -#[test] -fn chat_request_to_responses_maps_tools_and_tool_choice() { - let http_clients = ProxyHttpClients::new().expect("http clients"); - let parameters = json!({ - "type": "object", - "properties": { "q": { "type": "string" } }, - "required": ["q"] - }); - let input = bytes_from_json(json!({ - "model": "gpt-4.1", - "messages": [{ "role": "user", "content": "hi" }], - "tools": [ - { - "type": "function", - "function": { - "name": "search", - "description": "Search something", - "parameters": parameters - } - } - ], - "tool_choice": { "type": "function", "function": { "name": "search" } }, - "stream": false - })); - - let output = run_async(async { - transform_request_body(FormatTransform::ChatToResponses, &input, &http_clients, None) - .await - .expect("transform") - }); - let value = json_from_bytes(output); - - assert_eq!(value["tools"][0]["type"], json!("function")); - assert_eq!(value["tools"][0]["name"], json!("search")); - assert_eq!(value["tools"][0]["description"], json!("Search something")); - assert_eq!(value["tools"][0]["parameters"], parameters); - assert_eq!(value["tool_choice"]["type"], json!("function")); - assert_eq!(value["tool_choice"]["name"], json!("search")); -} - -#[test] -fn responses_request_to_chat_instructions_becomes_system_message() { - let http_clients = ProxyHttpClients::new().expect("http clients"); - let input = bytes_from_json(json!({ - "model": "gpt-4.1", - "input": "hello", - "instructions": "be concise", - "stream": false, - "max_output_tokens": 99 - })); - - let output = run_async(async { - transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None) - .await - .expect("transform") - }); - let value = json_from_bytes(output); - let messages = value["messages"].as_array().expect("messages array"); - - assert_eq!(value["model"], json!("gpt-4.1")); - assert_eq!(value["stream"], json!(false)); - assert_eq!(value["max_completion_tokens"], json!(99)); - assert_eq!(messages.len(), 2); - assert_eq!(messages[0]["role"], json!("system")); - assert_eq!(messages[0]["content"], json!("be concise")); - assert_eq!(messages[1]["role"], json!("user")); - assert_eq!(messages[1]["content"], json!("hello")); -} - -#[test] -fn responses_request_to_chat_accepts_message_array_input() { - let http_clients = ProxyHttpClients::new().expect("http clients"); - let input_messages = json!([{ "role": "user", "content": "hi" }]); - let input = bytes_from_json(json!({ - "model": "gpt-4.1", - "input": input_messages, - "stream": true - })); - - let output = run_async(async { - transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None) - .await - .expect("transform") - }); - let value = json_from_bytes(output); - - assert_eq!(value["model"], json!("gpt-4.1")); - assert_eq!(value["stream"], json!(true)); - assert_eq!(value["messages"], input_messages); -} - -#[test] -fn responses_request_to_chat_converts_input_text_content_parts_to_string() { - let http_clients = ProxyHttpClients::new().expect("http clients"); - let input_messages = json!([{ - "role": "user", - "content": [ - { "type": "input_text", "text": "分析项目的逻辑缺陷和性能缺陷" } - ] - }]); - let input = bytes_from_json(json!({ - "model": "gpt-4.1", - "input": input_messages, - "stream": false - })); - - let output = run_async(async { - transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None) - .await - .expect("transform") - }); - let value = json_from_bytes(output); - - 
assert_eq!(value["messages"][0]["role"], json!("user")); - assert_eq!( - value["messages"][0]["content"], - json!("分析项目的逻辑缺陷和性能缺陷") - ); -} - -#[test] -fn chat_request_to_responses_maps_response_format() { - let http_clients = ProxyHttpClients::new().expect("http clients"); - let input = bytes_from_json(json!({ - "model": "gpt-4.1", - "messages": [{ "role": "user", "content": "hi" }], - "response_format": { - "type": "json_schema", - "json_schema": { - "name": "example", - "schema": { "type": "object", "properties": { "ok": { "type": "boolean" } } } - } - } - })); - - let output = run_async(async { - transform_request_body(FormatTransform::ChatToResponses, &input, &http_clients, None) - .await - .expect("transform") - }); - let value = json_from_bytes(output); - - assert_eq!(value["text"]["format"]["type"], json!("json_schema")); - assert_eq!(value["text"]["format"]["json_schema"]["name"], json!("example")); -} - -#[test] -fn responses_request_to_chat_maps_text_format_to_response_format() { - let http_clients = ProxyHttpClients::new().expect("http clients"); - let input = bytes_from_json(json!({ - "model": "gpt-4.1", - "input": "hi", - "text": { "format": { "type": "json_object" } } - })); - - let output = run_async(async { - transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None) - .await - .expect("transform") - }); - let value = json_from_bytes(output); - - assert_eq!(value["response_format"]["type"], json!("json_object")); -} - -#[test] -fn responses_response_to_chat_extracts_output_text_and_maps_usage() { - let input = bytes_from_json(json!({ - "id": "resp_123", - "created_at": 1700000000, - "model": "gpt-4.1", - "output": [ - { - "type": "message", - "role": "assistant", - "content": [ - { "type": "output_text", "text": "Hello", "annotations": [] }, - { "type": "output_text", "text": " world", "annotations": [] } - ] - } - ], - "usage": { "input_tokens": 1, "output_tokens": 2, "total_tokens": 3 } - })); - - let output = transform_response_body(FormatTransform::ResponsesToChat, &input, None).expect("transform"); - let value = json_from_bytes(output); - - assert_eq!(value["id"], json!("resp_123")); - assert_eq!(value["object"], json!("chat.completion")); - assert_eq!(value["created"], json!(1700000000)); - assert_eq!(value["model"], json!("gpt-4.1")); - assert_eq!(value["choices"][0]["message"]["role"], json!("assistant")); - assert_eq!(value["choices"][0]["message"]["content"], json!("Hello world")); - assert_eq!(value["choices"][0]["finish_reason"], json!("stop")); - assert_eq!(value["usage"]["prompt_tokens"], json!(1)); - assert_eq!(value["usage"]["completion_tokens"], json!(2)); - assert_eq!(value["usage"]["total_tokens"], json!(3)); -} - -#[test] -fn responses_response_to_chat_maps_reasoning_content() { - let input = bytes_from_json(json!({ - "id": "resp_reason", - "created_at": 1700000002, - "model": "gpt-4.1", - "output": [ - { - "type": "message", - "role": "assistant", - "content": [ - { "type": "reasoning_text", "text": "think", "annotations": [] }, - { "type": "output_text", "text": "ok", "annotations": [] } - ] - } - ] - })); - - let output = transform_response_body(FormatTransform::ResponsesToChat, &input, None).expect("transform"); - let value = json_from_bytes(output); - - let message = &value["choices"][0]["message"]; - assert_eq!(message["content"], json!("ok")); - assert_eq!(message["reasoning_content"], json!("think")); -} - -#[test] -fn responses_response_to_chat_includes_tool_calls_and_multimodal_content() { - let input = 
bytes_from_json(json!({ - "id": "resp_456", - "created_at": 1700000001, - "model": "gpt-4.1", - "output": [ - { - "type": "message", - "role": "assistant", - "content": [ - { "type": "output_text", "text": "Hello", "annotations": [] }, - { "type": "output_image", "image_url": { "url": "https://example.com/a.png" } } - ] - }, - { - "type": "function_call", - "call_id": "call_foo", - "name": "doThing", - "arguments": "{\"a\":1}" - } - ], - "usage": { "input_tokens": 1, "output_tokens": 2, "total_tokens": 3 } - })); - - let output = transform_response_body(FormatTransform::ResponsesToChat, &input, None).expect("transform"); - let value = json_from_bytes(output); - - let message = &value["choices"][0]["message"]; - assert_eq!(message["role"], json!("assistant")); - assert_eq!(message["content"][0]["type"], json!("text")); - assert_eq!(message["content"][0]["text"], json!("Hello")); - assert_eq!(message["content"][1]["type"], json!("image_url")); - assert_eq!( - message["content"][1]["image_url"]["url"], - json!("https://example.com/a.png") - ); - assert_eq!(message["tool_calls"][0]["id"], json!("call_foo")); - assert_eq!(message["tool_calls"][0]["function"]["name"], json!("doThing")); - assert_eq!(message["tool_calls"][0]["function"]["arguments"], json!("{\"a\":1}")); - assert_eq!(value["choices"][0]["finish_reason"], json!("tool_calls")); -} - -#[test] -fn chat_response_to_responses_extracts_choice_text_and_maps_usage() { - let input = bytes_from_json(json!({ - "id": "chatcmpl_123", - "created": 1700000000, - "model": "gpt-4.1", - "choices": [ - { "index": 0, "message": { "role": "assistant", "content": "Hello" } } - ], - "usage": { "prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3 } - })); - - let output = transform_response_body(FormatTransform::ChatToResponses, &input, None).expect("transform"); - let value = json_from_bytes(output); - - assert_eq!(value["id"], json!("chatcmpl_123")); - assert_eq!(value["object"], json!("response")); - assert_eq!(value["created_at"], json!(1700000000)); - assert_eq!(value["model"], json!("gpt-4.1")); - assert_eq!(value["output"][0]["type"], json!("message")); - assert_eq!(value["output"][0]["role"], json!("assistant")); - assert_eq!(value["output"][0]["content"][0]["type"], json!("output_text")); - assert_eq!(value["output"][0]["content"][0]["text"], json!("Hello")); - assert_eq!(value["usage"]["input_tokens"], json!(1)); - assert_eq!(value["usage"]["output_tokens"], json!(2)); - assert_eq!(value["usage"]["total_tokens"], json!(3)); -} - -#[test] -fn chat_response_to_responses_maps_finish_reason_to_incomplete_details() { - let input = bytes_from_json(json!({ - "id": "chatcmpl_456", - "created": 1700000002, - "model": "gpt-4.1", - "choices": [ - { "index": 0, "message": { "role": "assistant", "content": "Hello" }, "finish_reason": "length" } - ], - "usage": { "prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3 } - })); - - let output = transform_response_body(FormatTransform::ChatToResponses, &input, None).expect("transform"); - let value = json_from_bytes(output); - - assert_eq!(value["status"], json!("incomplete")); - assert_eq!(value["incomplete_details"]["reason"], json!("max_tokens")); -} - -#[test] -fn responses_request_to_chat_converts_function_call_output_to_tool_message() { - let http_clients = ProxyHttpClients::new().expect("http clients"); - let input = bytes_from_json(json!({ - "model": "gpt-4.1", - "input": [ - { "type": "function_call_output", "call_id": "call_123", "output": "ok" } - ], - "stream": false - })); - - let 
output = run_async(async { - transform_request_body(FormatTransform::ResponsesToChat, &input, &http_clients, None) - .await - .expect("transform") - }); - let value = json_from_bytes(output); - let messages = value["messages"].as_array().expect("messages array"); - - assert_eq!(messages.len(), 1); - assert_eq!(messages[0]["role"], json!("tool")); - assert_eq!(messages[0]["tool_call_id"], json!("call_123")); - assert_eq!(messages[0]["content"], json!("ok")); -} - -#[test] -fn chat_response_to_responses_maps_tool_calls_into_output() { - let input = bytes_from_json(json!({ - "id": "chatcmpl_123", - "created": 1700000000, - "model": "gpt-4.1", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "call_foo", - "type": "function", - "function": { - "name": "getRandomNumber", - "arguments": "{\"a\":\"0\"}" - } - } - ] - } - } - ], - "usage": { "prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3 } - })); - - let output = transform_response_body(FormatTransform::ChatToResponses, &input, None).expect("transform"); - let value = json_from_bytes(output); - - assert_eq!(value["id"], json!("chatcmpl_123")); - assert_eq!(value["object"], json!("response")); - assert_eq!(value["created_at"], json!(1700000000)); - assert_eq!(value["model"], json!("gpt-4.1")); - assert_eq!(value["output"][0]["type"], json!("function_call")); - assert_eq!(value["output"][0]["call_id"], json!("call_foo")); - assert_eq!(value["output"][0]["name"], json!("getRandomNumber")); - assert_eq!(value["output"][0]["arguments"], json!("{\"a\":\"0\"}")); - assert_eq!(value["usage"]["input_tokens"], json!(1)); - assert_eq!(value["usage"]["output_tokens"], json!(2)); - assert_eq!(value["usage"]["total_tokens"], json!(3)); -} - -#[test] -fn chat_request_to_responses_rejects_missing_messages() { - let http_clients = ProxyHttpClients::new().expect("http clients"); - let input = bytes_from_json(json!({ "model": "gpt-4.1" })); - let err = run_async(async { - transform_request_body(FormatTransform::ChatToResponses, &input, &http_clients, None) - .await - .expect_err("should fail") - }); - assert!(err.contains("messages")); -} -#[test] -fn transform_request_body_rejects_non_json() { - let http_clients = ProxyHttpClients::new().expect("http clients"); - let input = Bytes::from_static(b"not-json"); - let err = run_async(async { - transform_request_body(FormatTransform::ChatToResponses, &input, &http_clients, None) - .await - .expect_err("should fail") - }); - assert!(err.contains("JSON")); -} +// Split the test suite to keep each file below the project's line limit. 
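+// With `#[path]`, the split files compile as inline child modules of this test module, so
+// their `use super::*;` still reaches the shared helpers defined above (for example
+// `transform_request_value` and `transform_response_value`).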
+#[path = "openai_compat.test.part1.rs"] +mod part1; +#[path = "openai_compat.test.part2.rs"] +mod part2; -#[test] -fn responses_and_gemini_request_conversions() { - let http_clients = ProxyHttpClients::new().expect("http clients"); - let responses_value = transform_request_value( - FormatTransform::ResponsesToGemini, - json!({ - "model": "gpt-4.1", - "input": "hi", - "instructions": "sys", - "temperature": 0.5, - "top_p": 0.9, - "max_output_tokens": 128, - "stop": ["a", "b"], - "seed": 7 - }), - &http_clients, - None, - ); - assert_eq!(responses_value["contents"][0]["parts"][0]["text"], json!("hi")); - assert_eq!(responses_value["systemInstruction"]["parts"][0]["text"], json!("sys")); - assert_eq!(responses_value["generationConfig"]["maxOutputTokens"], json!(128)); - assert_eq!(responses_value["generationConfig"]["stopSequences"], json!(["a", "b"])); - assert_eq!(responses_value["generationConfig"]["seed"], json!(7)); - let gemini_value = transform_request_value( - FormatTransform::GeminiToResponses, - json!({ - "model": "gemini-1.5-flash", - "contents": [{ "role": "user", "parts": [{ "text": "hello" }] }], - "systemInstruction": { "parts": [{ "text": "rules" }] }, - "generationConfig": { "maxOutputTokens": 64, "topP": 0.8 } - }), - &http_clients, - None, - ); - assert_eq!(gemini_value["model"], json!("gemini-1.5-flash")); - assert_eq!(gemini_value["instructions"], json!("rules")); - assert_eq!(gemini_value["input"][0]["content"][0]["text"], json!("hello")); - assert_eq!(gemini_value["max_output_tokens"], json!(64)); - assert_eq!(gemini_value["top_p"], json!(0.8)); -} -#[test] -fn gemini_and_anthropic_request_conversions() { - let http_clients = ProxyHttpClients::new().expect("http clients"); - let gemini_value = transform_request_value( - FormatTransform::GeminiToAnthropic, - json!({ - "contents": [{ "role": "user", "parts": [{ "text": "ping" }] }], - "systemInstruction": { "parts": [{ "text": "sys" }] }, - "generationConfig": { "maxOutputTokens": 42 } - }), - &http_clients, - Some("claude-3-5-sonnet"), - ); - assert_eq!(gemini_value["model"], json!("claude-3-5-sonnet")); - assert_eq!(gemini_value["system"][0]["text"], json!("sys")); - assert_eq!(gemini_value["messages"][0]["content"][0]["text"], json!("ping")); - assert_eq!(gemini_value["max_tokens"], json!(42)); - let anthropic_value = transform_request_value( - FormatTransform::AnthropicToGemini, - json!({ - "model": "claude-3-5-sonnet", - "max_tokens": 321, - "system": "guard", - "stop_sequences": ["x"], - "messages": [{ "role": "user", "content": [{ "type": "text", "text": "yo" }] }] - }), - &http_clients, - None, - ); - assert_eq!(anthropic_value["systemInstruction"]["parts"][0]["text"], json!("guard")); - assert_eq!(anthropic_value["contents"][0]["parts"][0]["text"], json!("yo")); - assert_eq!(anthropic_value["generationConfig"]["maxOutputTokens"], json!(321)); - assert_eq!(anthropic_value["generationConfig"]["stopSequences"], json!(["x"])); -} -#[test] -fn responses_and_gemini_response_conversions() { - let responses_value = transform_response_value( - FormatTransform::ResponsesToGemini, - json!({ - "id": "resp_1", - "created_at": 1700000000, - "model": "gpt-4.1", - "output": [ - { - "type": "message", - "role": "assistant", - "content": [{ "type": "output_text", "text": "Hello", "annotations": [] }] - } - ], - "usage": { "input_tokens": 2, "output_tokens": 3, "total_tokens": 5 } - }), - None, - ); - assert_eq!(responses_value["candidates"][0]["content"]["parts"][0]["text"], json!("Hello")); - 
assert_eq!(responses_value["usageMetadata"]["promptTokenCount"], json!(2)); - assert_eq!(responses_value["usageMetadata"]["candidatesTokenCount"], json!(3)); - assert_eq!(responses_value["usageMetadata"]["totalTokenCount"], json!(5)); - let gemini_value = transform_response_value( - FormatTransform::GeminiToResponses, - json!({ - "candidates": [ - { "content": { "role": "model", "parts": [{ "text": "Hi" }] }, "finishReason": "STOP" } - ], - "usageMetadata": { - "promptTokenCount": 4, - "candidatesTokenCount": 6, - "totalTokenCount": 10 - } - }), - Some("gemini-1.5-pro"), - ); - assert_eq!(gemini_value["output"][0]["content"][0]["text"], json!("Hi")); - assert_eq!(gemini_value["usage"]["input_tokens"], json!(4)); - assert_eq!(gemini_value["usage"]["output_tokens"], json!(6)); - assert_eq!(gemini_value["usage"]["total_tokens"], json!(10)); -} -#[test] -fn gemini_and_anthropic_response_conversions() { - let gemini_value = transform_response_value( - FormatTransform::GeminiToAnthropic, - json!({ - "candidates": [ - { "content": { "role": "model", "parts": [{ "text": "Howdy" }] }, "finishReason": "STOP" } - ], - "usageMetadata": { - "promptTokenCount": 1, - "candidatesTokenCount": 2, - "totalTokenCount": 3 - } - }), - Some("claude-3-5-sonnet"), - ); - assert_eq!(gemini_value["model"], json!("claude-3-5-sonnet")); - assert_eq!(gemini_value["content"][0]["text"], json!("Howdy")); - assert_eq!(gemini_value["usage"]["input_tokens"], json!(1)); - assert_eq!(gemini_value["usage"]["output_tokens"], json!(2)); - assert_eq!(gemini_value["stop_reason"], json!("end_turn")); - let anthropic_value = transform_response_value( - FormatTransform::AnthropicToGemini, - json!({ - "id": "msg_1", - "model": "claude-3-5-sonnet", - "content": [{ "type": "text", "text": "Yo" }], - "usage": { "input_tokens": 4, "output_tokens": 6 } - }), - None, - ); - assert_eq!(anthropic_value["candidates"][0]["content"]["parts"][0]["text"], json!("Yo")); - assert_eq!(anthropic_value["usageMetadata"]["promptTokenCount"], json!(4)); - assert_eq!(anthropic_value["usageMetadata"]["candidatesTokenCount"], json!(6)); - assert_eq!(anthropic_value["usageMetadata"]["totalTokenCount"], json!(10)); -} diff --git a/crates/token_proxy_core/src/proxy/openai_compat/usage.rs b/crates/token_proxy_core/src/proxy/openai_compat/usage.rs index 6a5c377..6c3c884 100644 --- a/crates/token_proxy_core/src/proxy/openai_compat/usage.rs +++ b/crates/token_proxy_core/src/proxy/openai_compat/usage.rs @@ -1,6 +1,6 @@ -use serde_json::{json, Value}; +use serde_json::{json, Map, Value}; -pub(super) fn map_usage_responses_to_chat(usage: &Value) -> Option { +pub(crate) fn map_usage_responses_to_chat(usage: &Value) -> Option { let usage = usage.as_object()?; let input = usage.get("input_tokens").and_then(Value::as_u64); let output = usage.get("output_tokens").and_then(Value::as_u64); @@ -14,28 +14,58 @@ pub(super) fn map_usage_responses_to_chat(usage: &Value) -> Option { if input.is_none() && output.is_none() && total.is_none() { return None; } - Some(json!({ - "prompt_tokens": input, - "completion_tokens": output, - "total_tokens": total - })) + + let mut mapped = Map::new(); + mapped.insert("prompt_tokens".to_string(), json!(input)); + mapped.insert("completion_tokens".to_string(), json!(output)); + mapped.insert("total_tokens".to_string(), json!(total)); + + // Preserve reasoning token details when converting Responses -> Chat. 
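+    // e.g. a Responses usage of { "output_tokens_details": { "reasoning_tokens": 7 } }
+    // becomes { "completion_tokens_details": { "reasoning_tokens": 7 } }; when the upstream
+    // omits the detail object entirely, the count defaults to 0.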
+ let reasoning_tokens = usage + .get("output_tokens_details") + .and_then(Value::as_object) + .and_then(|details| details.get("reasoning_tokens")) + .and_then(Value::as_u64) + .unwrap_or(0); + mapped.insert( + "completion_tokens_details".to_string(), + json!({ "reasoning_tokens": reasoning_tokens }), + ); + + Some(Value::Object(mapped)) } -pub(super) fn map_usage_chat_to_responses(usage: &Value) -> Option { +pub(crate) fn map_usage_chat_to_responses(usage: &Value) -> Option { let usage = usage.as_object()?; let prompt = usage.get("prompt_tokens").and_then(Value::as_u64); let completion = usage.get("completion_tokens").and_then(Value::as_u64); - let total = usage.get("total_tokens").and_then(Value::as_u64).or_else(|| match (prompt, completion) { - (Some(prompt), Some(completion)) => prompt.checked_add(completion), - _ => None, - }); + let total = usage + .get("total_tokens") + .and_then(Value::as_u64) + .or_else(|| match (prompt, completion) { + (Some(prompt), Some(completion)) => prompt.checked_add(completion), + _ => None, + }); if prompt.is_none() && completion.is_none() && total.is_none() { return None; } - Some(json!({ - "input_tokens": prompt, - "output_tokens": completion, - "total_tokens": total - })) -} + let mut mapped = Map::new(); + mapped.insert("input_tokens".to_string(), json!(prompt)); + mapped.insert("output_tokens".to_string(), json!(completion)); + mapped.insert("total_tokens".to_string(), json!(total)); + + // Preserve reasoning token details when converting Chat -> Responses. + let reasoning_tokens = usage + .get("completion_tokens_details") + .and_then(Value::as_object) + .and_then(|details| details.get("reasoning_tokens")) + .and_then(Value::as_u64) + .unwrap_or(0); + mapped.insert( + "output_tokens_details".to_string(), + json!({ "reasoning_tokens": reasoning_tokens }), + ); + + Some(Value::Object(mapped)) +} diff --git a/crates/token_proxy_core/src/proxy/response.test.part2.rs b/crates/token_proxy_core/src/proxy/response.test.part2.rs new file mode 100644 index 0000000..146109c --- /dev/null +++ b/crates/token_proxy_core/src/proxy/response.test.part2.rs @@ -0,0 +1,125 @@ +use axum::body::Bytes; +use futures_util::StreamExt; +use serde_json::{json, Value}; +use std::{sync::Arc, time::Instant}; + +use crate::proxy::log::{LogContext, LogWriter}; + +#[test] +fn stream_gemini_to_anthropic_emits_single_input_json_delta_for_tool_calls() { + super::run_async(async { + let context = LogContext { + path: "/v1/messages".to_string(), + provider: "antigravity".to_string(), + upstream_id: "unit-test".to_string(), + model: Some("unit-model".to_string()), + mapped_model: Some("unit-model".to_string()), + stream: true, + status: 200, + upstream_request_id: None, + request_headers: None, + request_body: None, + ttfb_ms: None, + start: Instant::now(), + }; + + let gemini_event = json!({ + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "Task", + "args": { + "description": "explore", + "prompt": "scan repo", + "subagent_type": "Explore" + } + } + } + ] + }, + "finishReason": "STOP" + } + ] + }); + let upstream = futures_util::stream::iter(vec![ + Ok::(Bytes::from(format!( + "data: {}\n\n", + gemini_event.to_string() + ))), + Ok::(Bytes::from("data: [DONE]\n\n")), + ]); + + let token_tracker_1 = crate::proxy::token_rate::TokenRateTracker::new() + .register(None, None) + .await; + let chat_stream = crate::proxy::gemini_compat::stream_gemini_to_chat( + upstream, + context.clone(), + Arc::new(LogWriter::new(None)), + token_tracker_1, + ) + .boxed(); + + 
let token_tracker_2 = crate::proxy::token_rate::TokenRateTracker::new() + .register(None, None) + .await; + let responses_stream = super::super::chat_to_responses::stream_chat_to_responses( + chat_stream, + context.clone(), + Arc::new(LogWriter::new(None)), + token_tracker_2, + ) + .boxed(); + + let token_tracker_3 = crate::proxy::token_rate::TokenRateTracker::new() + .register(None, None) + .await; + let anthropic_stream = super::super::responses_to_anthropic::stream_responses_to_anthropic( + responses_stream, + context, + Arc::new(LogWriter::new(None)), + token_tracker_3, + ); + + let chunks: Vec = anthropic_stream + .map(|item| item.expect("stream item")) + .collect() + .await; + + let mut input_json_deltas: Vec = Vec::new(); + for chunk in &chunks { + let Some((event_type, data)) = super::parse_anthropic_sse(chunk) else { + continue; + }; + if event_type != "content_block_delta" { + continue; + } + if data + .get("delta") + .and_then(|value| value.get("type")) + .and_then(Value::as_str) + != Some("input_json_delta") + { + continue; + } + let Some(partial) = data + .get("delta") + .and_then(|value| value.get("partial_json")) + .and_then(Value::as_str) + else { + continue; + }; + input_json_deltas.push(partial.to_string()); + } + + // If we emit both `.delta` fragments and the final `.done` full arguments, clients will + // concatenate them and end up with invalid JSON (tool input becomes `{}`). + assert_eq!(input_json_deltas.len(), 1); + assert!(input_json_deltas[0].contains("\"description\"")); + assert!(input_json_deltas[0].contains("\"prompt\"")); + assert!(input_json_deltas[0].contains("\"subagent_type\"")); + }); +} diff --git a/crates/token_proxy_core/src/proxy/response.test.rs b/crates/token_proxy_core/src/proxy/response.test.rs index f0615c5..9acf6fc 100644 --- a/crates/token_proxy_core/src/proxy/response.test.rs +++ b/crates/token_proxy_core/src/proxy/response.test.rs @@ -41,6 +41,32 @@ fn parse_sse_json(bytes: &Bytes) -> Option { Some(serde_json::from_str::(data).expect("parse SSE JSON")) } +fn parse_anthropic_sse(bytes: &Bytes) -> Option<(String, Value)> { + let text = String::from_utf8_lossy(bytes); + let mut event_type: Option<&str> = None; + let mut data_line: Option<&str> = None; + for line in text.lines() { + if let Some(value) = line.strip_prefix("event: ") { + event_type = Some(value.trim()); + } + if let Some(value) = line.strip_prefix("data: ") { + data_line = Some(value.trim()); + } + } + let Some(event_type) = event_type else { + return None; + }; + let Some(data_line) = data_line else { + return None; + }; + let data = serde_json::from_str::(data_line).expect("parse anthropic SSE JSON"); + Some((event_type.to_string(), data)) +} + +// Split the test suite to keep each file below the project's line limit. +#[path = "response.test.part2.rs"] +mod part2; + async fn setup_responses_stream() -> (Arc, LogContext, SqlitePool) { let sqlite_pool = create_test_sqlite_pool().await; let log = Arc::new(LogWriter::new(Some(sqlite_pool.clone()))); @@ -282,7 +308,7 @@ fn stream_chat_to_responses_handles_chunk_boundaries_and_emits_created_delta_don )), // Chat usage format. 
Ok(Bytes::from( - "data: {\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3}}\n\n", + "data: {\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3,\"completion_tokens_details\":{\"reasoning_tokens\":9}}}\n\n", )), Ok(Bytes::from("data: [DONE]\n\n")), ]); @@ -354,6 +380,10 @@ fn stream_chat_to_responses_handles_chunk_boundaries_and_emits_created_delta_don assert_eq!(completed["response"]["usage"]["input_tokens"], json!(1)); assert_eq!(completed["response"]["usage"]["output_tokens"], json!(2)); assert_eq!(completed["response"]["usage"]["total_tokens"], json!(3)); + assert_eq!( + completed["response"]["usage"]["output_tokens_details"]["reasoning_tokens"], + json!(9) + ); assert_eq!(String::from_utf8_lossy(&chunks[9]), "data: [DONE]\n\n"); @@ -394,7 +424,7 @@ fn stream_chat_to_responses_emits_function_call_events_and_includes_them_in_comp )), // Chat usage format. Ok(Bytes::from( - "data: {\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3}}\n\n", + "data: {\"usage\":{\"prompt_tokens\":1,\"completion_tokens\":2,\"total_tokens\":3,\"completion_tokens_details\":{\"reasoning_tokens\":4}}}\n\n", )), Ok(Bytes::from("data: [DONE]\n\n")), ]); @@ -464,6 +494,10 @@ fn stream_chat_to_responses_emits_function_call_events_and_includes_them_in_comp assert_eq!(completed["response"]["usage"]["input_tokens"], json!(1)); assert_eq!(completed["response"]["usage"]["output_tokens"], json!(2)); assert_eq!(completed["response"]["usage"]["total_tokens"], json!(3)); + assert_eq!( + completed["response"]["usage"]["output_tokens_details"]["reasoning_tokens"], + json!(4) + ); assert_eq!(String::from_utf8_lossy(&chunks[7]), "data: [DONE]\n\n"); diff --git a/crates/token_proxy_core/src/proxy/response/chat_to_responses.rs b/crates/token_proxy_core/src/proxy/response/chat_to_responses.rs index 1b71082..027b1de 100644 --- a/crates/token_proxy_core/src/proxy/response/chat_to_responses.rs +++ b/crates/token_proxy_core/src/proxy/response/chat_to_responses.rs @@ -8,8 +8,10 @@ use super::super::sse::SseEventParser; use super::super::token_rate::RequestTokenTracker; use super::super::usage::SseUsageCollector; use format::{snapshot_to_output_item, usage_to_value, OutputItemSnapshot}; +use state_types::{FunctionCallOutput, MessageOutput}; mod format; +mod state_types; pub(super) fn stream_chat_to_responses( upstream: impl futures_util::stream::Stream> @@ -27,20 +29,6 @@ where futures_util::stream::try_unfold(state, |state| async move { state.step().await }) } -struct MessageOutput { - id: String, - output_index: u64, - text: String, -} - -struct FunctionCallOutput { - id: String, - output_index: u64, - call_id: String, - name: String, - arguments: String, -} - struct ChatToResponsesState { upstream: S, parser: SseEventParser, @@ -408,7 +396,14 @@ where self.sent_done = true; let completed_at = (super::now_ms() / 1000) as i64; - let usage = self.collector.finish().usage.map(usage_to_value); + let usage_snapshot = self.collector.finish(); + // Prefer upstream `usage` JSON to preserve breakdown fields (e.g. reasoning tokens). + // Fallback to the normalized TokenUsage counters when upstream did not provide usage. 
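+        // e.g. a final chat chunk whose usage carries
+        // {"completion_tokens_details":{"reasoning_tokens":9}} surfaces on the completed
+        // event as response.usage.output_tokens_details.reasoning_tokens = 9.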
+        let usage = usage_snapshot
+            .usage_json
+            .as_ref()
+            .and_then(super::super::openai_compat::map_usage_chat_to_responses)
+            .or_else(|| usage_snapshot.usage.map(usage_to_value));
 
         let mut snapshots = Vec::new();
         if let Some(message) = &self.message {
diff --git a/crates/token_proxy_core/src/proxy/response/chat_to_responses/state_types.rs b/crates/token_proxy_core/src/proxy/response/chat_to_responses/state_types.rs
new file mode 100644
index 0000000..ee1f703
--- /dev/null
+++ b/crates/token_proxy_core/src/proxy/response/chat_to_responses/state_types.rs
@@ -0,0 +1,16 @@
+// Small helper types extracted to keep `chat_to_responses.rs` under the project's line limit.
+
+pub(super) struct MessageOutput {
+    pub(super) id: String,
+    pub(super) output_index: u64,
+    pub(super) text: String,
+}
+
+pub(super) struct FunctionCallOutput {
+    pub(super) id: String,
+    pub(super) output_index: u64,
+    pub(super) call_id: String,
+    pub(super) name: String,
+    pub(super) arguments: String,
+}
+
diff --git a/crates/token_proxy_core/src/proxy/response/dispatch/buffered.rs b/crates/token_proxy_core/src/proxy/response/dispatch/buffered.rs
index 20664d0..2b9bb3b 100644
--- a/crates/token_proxy_core/src/proxy/response/dispatch/buffered.rs
+++ b/crates/token_proxy_core/src/proxy/response/dispatch/buffered.rs
@@ -17,13 +17,14 @@ use super::super::super::{
     openai_compat::{transform_response_body, FormatTransform},
     request_body::ReplayableBody,
     redact::redact_query_param_value,
-    server_helpers::log_debug_headers_body,
+    server_helpers::{log_debug_headers_body, truncate_for_log},
     token_rate::RequestTokenTracker,
     usage::extract_usage_from_response,
     UPSTREAM_NO_DATA_TIMEOUT,
 };
 
 const DEBUG_BODY_LOG_LIMIT_BYTES: usize = usize::MAX;
+const ANTIGRAVITY_ERROR_LOG_LIMIT_BYTES: usize = 8 * 1024;
 
 pub(super) async fn build_buffered_response(
     status: StatusCode,
@@ -49,7 +50,10 @@ pub(super) async fn build_buffered_response(
         DEBUG_BODY_LOG_LIMIT_BYTES,
     )
     .await;
-    let bytes = if context.provider == PROVIDER_ANTIGRAVITY {
+    if context.provider == PROVIDER_ANTIGRAVITY && !status.is_success() {
+        log_antigravity_error_body(status, &bytes);
+    }
+    let bytes = if context.provider == PROVIDER_ANTIGRAVITY && status.is_success() {
         match antigravity_compat::unwrap_response(&bytes) {
             Ok(unwrapped) => unwrapped,
             Err(message) => {
@@ -146,6 +150,17 @@ fn convert_success_body(
     }
 }
 
+fn log_antigravity_error_body(status: StatusCode, bytes: &Bytes) {
+    let body_text = String::from_utf8_lossy(bytes);
+    let truncated = truncate_for_log(&body_text, ANTIGRAVITY_ERROR_LOG_LIMIT_BYTES);
+    // Log only for error responses, to avoid log noise and overhead on the success path.
+    tracing::warn!(
+        status = %status,
+        body = %truncated,
+        "antigravity upstream error body"
+    );
+}
+
 fn convert_kiro_to_anthropic_body(
     bytes: &Bytes,
     context: &mut LogContext,
diff --git a/crates/token_proxy_core/src/proxy/response/responses_to_anthropic.rs b/crates/token_proxy_core/src/proxy/response/responses_to_anthropic.rs
index 663ef11..5937fbf 100644
--- a/crates/token_proxy_core/src/proxy/response/responses_to_anthropic.rs
+++ b/crates/token_proxy_core/src/proxy/response/responses_to_anthropic.rs
@@ -284,6 +284,11 @@ where
                 "delta": { "type": "input_json_delta", "partial_json": delta }
             }),
         ));
+        // Claude Code concatenates the input_json_delta partial_json fragments into the final JSON.
+        // Re-sending the full arguments at arguments.done would duplicate the concatenation and yield invalid JSON (the final tool input becomes {}).
+        if let Some(state) = self.tool_uses.get_mut(item_id) {
+            state.sent_input = true;
+        }
     }
 
     fn handle_function_call_arguments_done(&mut self, value: &Value) {
diff --git a/crates/token_proxy_core/src/proxy/server_helpers.rs b/crates/token_proxy_core/src/proxy/server_helpers.rs
index b285c51..f6b57a8 100644
--- a/crates/token_proxy_core/src/proxy/server_helpers.rs
+++ b/crates/token_proxy_core/src/proxy/server_helpers.rs
@@ -206,6 +206,20 @@ pub(crate) async fn log_debug_headers_body(
     }
 }
 
+pub(crate) fn truncate_for_log(value: &str, max_bytes: usize) -> String {
+    if value.len() <= max_bytes {
+        return value.to_string();
+    }
+    // Truncate on a char boundary at or below the byte limit so multi-byte UTF-8
+    // text is never split mid-character.
+    let mut end = max_bytes;
+    while !value.is_char_boundary(end) {
+        end -= 1;
+    }
+    let mut out = value[..end].to_string();
+    out.push_str("...[truncated]");
+    out
+}
+
 fn snapshot_headers_raw(headers: &HeaderMap) -> Vec<(String, String)> {
     headers
         .iter()
diff --git a/crates/token_proxy_core/src/proxy/token_estimator.rs b/crates/token_proxy_core/src/proxy/token_estimator.rs
index cf7d442..9dc1941 100644
--- a/crates/token_proxy_core/src/proxy/token_estimator.rs
+++ b/crates/token_proxy_core/src/proxy/token_estimator.rs
@@ -1,4 +1,4 @@
-use std::sync::OnceLock;
+use std::{collections::HashSet, sync::OnceLock};
 
 use tiktoken_rs::{cl100k_base, o200k_base, CoreBPE};
 
@@ -207,7 +207,20 @@ fn is_cjk(ch: char) -> bool {
     let code = ch as u32;
     matches!(
         code,
-        0x3400..=0x4DBF
+        // Japanese kana (Hiragana + Katakana).
+        // new-api's range: 0x3040-0x30FF
+        0x3040..=0x30FF
+        // Korean (Hangul syllables).
+        // new-api's range: 0xAC00-0xD7A3
+        | 0xAC00..=0xD7A3
+        // CJK radicals (unicode.Han includes 0x2E80-0x2FDF, but does NOT include 0x2FF0-0x2FFF).
+        | 0x2E80..=0x2FDF
+        // Special Han-script characters in CJK Symbols and Punctuation.
+        // Verified via Go `unicode.Is(unicode.Han, r)`.
+        | 0x3005
+        | 0x3007
+        | 0x303B
+        | 0x3400..=0x4DBF
         | 0x4E00..=0x9FFF
         | 0xF900..=0xFAFF
         | 0x20000..=0x2A6DF
@@ -215,6 +228,8 @@
         | 0x2B740..=0x2B81F
         | 0x2B820..=0x2CEAF
        | 0x2CEB0..=0x2EBEF
+        // CJK Compatibility Ideographs Supplement.
+        | 0x2F800..=0x2FA1F
         | 0x30000..=0x3134F
     )
 }
@@ -238,19 +253,27 @@ fn is_emoji(ch: char) -> bool {
 
 fn is_math_symbol(ch: char) -> bool {
     let code = ch as u32;
-    matches!(
-        code,
-        0x2200..=0x22FF | 0x27C0..=0x27EF | 0x2980..=0x29FF | 0x2A00..=0x2AFF
-            | 0x2190..=0x21FF | 0x2B00..=0x2BFF
-    ) || matches!(ch, '+' | '-' | '*' | '/' | '=' | '^' | '%')
+    // Mirror new-api:
+    // - explicit symbol list (covers degrees, primes, super/sub-scripts, etc.)
+    // - Mathematical Operators (U+2200–U+22FF)
+    // - Supplemental Mathematical Operators (U+2A00–U+2AFF)
+    // - Mathematical Alphanumeric Symbols (U+1D400–U+1D7FF)
+    matches!(code, 0x2200..=0x22FF | 0x2A00..=0x2AFF | 0x1D400..=0x1D7FF)
+        || math_symbol_set().contains(&ch)
 }
 
 fn is_url_delim(ch: char) -> bool {
-    matches!(
-        ch,
-        ':' | '/' | '?' | '#' | '[' | ']' | '!' | '$' | '&' | '\''
-            | '(' | ')' | '*' | '+' | ',' | ';' | '='
-    )
+    // Mirror new-api: "/:?&=;#%"
+    matches!(ch, '/' | ':' | '?' | '&' | '=' | ';' | '#' | '%')
+}
+
+fn math_symbol_set() -> &'static HashSet<char> {
+    static SYMBOLS: OnceLock<HashSet<char>> = OnceLock::new();
+    SYMBOLS.get_or_init(|| {
+        // Keep this list identical to `.reference/new-api/service/token_estimator.go` to avoid drift.
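+        // Built once behind the OnceLock and shared thereafter, so the per-character
+        // membership check in is_math_symbol stays O(1).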
+        const MATH_SYMBOLS: &str = "∑∫∂√∞≤≥≠≈±×÷∈∉∋∌⊂⊃⊆⊇∪∩∧∨¬∀∃∄∅∆∇∝∟∠∡∢°′″‴⁺⁻⁼⁽⁾ⁿ₀₁₂₃₄₅₆₇₈₉₊₋₌₍₎²³¹⁴⁵⁶⁷⁸⁹⁰";
+        MATH_SYMBOLS.chars().collect()
+    })
+}
 
 // Unit tests are split into a separate file, included via `#[path]` to keep the `.test.rs` naming convention.
diff --git a/crates/token_proxy_core/src/proxy/token_estimator.test.rs b/crates/token_proxy_core/src/proxy/token_estimator.test.rs
index 47e33c1..4a6d6d3 100644
--- a/crates/token_proxy_core/src/proxy/token_estimator.test.rs
+++ b/crates/token_proxy_core/src/proxy/token_estimator.test.rs
@@ -6,3 +6,35 @@ fn estimate_tokens_for_claude_uses_heuristic() {
     // Claude word multiplier 1.13 -> ceil => 2
     assert_eq!(tokens, 2);
 }
+
+#[test]
+fn estimate_tokens_counts_japanese_as_cjk_per_char() {
+    // new-api's isCJK includes 0x3040..=0x30FF (kana); this project matches it.
+    let tokens = estimate_text_tokens(Some("claude-3-opus"), "あいうえお");
+    // 5 chars * 1.21 => 6.05 -> ceil => 7
+    assert_eq!(tokens, 7);
+}
+
+#[test]
+fn estimate_tokens_counts_korean_as_cjk_per_char() {
+    // new-api's isCJK includes 0xAC00..=0xD7A3 (Hangul syllables); this project matches it.
+    let tokens = estimate_text_tokens(Some("gemini-1.5-flash"), "가나다");
+    // 3 chars * 0.68 => 2.04 -> ceil => 3
+    assert_eq!(tokens, 3);
+}
+
+#[test]
+fn estimate_tokens_treats_percent_as_url_delim() {
+    // new-api's URLDelim set includes '%'.
+    let tokens = estimate_text_tokens(Some("claude-3-opus"), "%");
+    // URLDelim 1.26 -> ceil => 2
+    assert_eq!(tokens, 2);
+}
+
+#[test]
+fn estimate_tokens_treats_plus_as_symbol() {
+    // new-api's mathSymbols does not include '+', so it is billed as a plain Symbol.
+    let tokens = estimate_text_tokens(Some("claude-3-opus"), "+");
+    // Symbol 0.4 -> ceil => 1
+    assert_eq!(tokens, 1);
+}
diff --git a/crates/token_proxy_core/src/proxy/upstream/attempt.rs b/crates/token_proxy_core/src/proxy/upstream/attempt.rs
index 3686dfb..8974ff1 100644
--- a/crates/token_proxy_core/src/proxy/upstream/attempt.rs
+++ b/crates/token_proxy_core/src/proxy/upstream/attempt.rs
@@ -1,7 +1,7 @@
 use std::time::Instant;
 
 use axum::http::{
-    header::{ACCEPT, ACCEPT_ENCODING, CONTENT_TYPE, USER_AGENT},
+    header::{ACCEPT, ACCEPT_ENCODING, AUTHORIZATION, CONTENT_TYPE, USER_AGENT},
     HeaderMap, HeaderValue, Method, StatusCode,
 };
 use reqwest::{Client, Proxy};
@@ -767,7 +767,13 @@ fn antigravity_request_headers(
     meta: &RequestMeta,
     antigravity: Option<&super::AntigravityRequestInfo>,
 ) -> HeaderMap {
-    let mut headers = base.clone();
+    // Align with CLIProxyAPIPlus: only forward essential headers to Antigravity.
+    // Do NOT pass through inbound headers (e.g. anthropic-beta/x-stainless), as they can
+    // trigger upstream validation errors.
+    let mut headers = HeaderMap::new();
+    if let Some(value) = base.get(AUTHORIZATION).cloned() {
+        headers.insert(AUTHORIZATION, value);
+    }
     let user_agent = antigravity
         .map(|info| info.user_agent.clone())
         .unwrap_or_else(antigravity_endpoints::default_user_agent);
diff --git a/crates/token_proxy_core/src/proxy/upstream/request.rs b/crates/token_proxy_core/src/proxy/upstream/request.rs
index a4583af..a3004d8 100644
--- a/crates/token_proxy_core/src/proxy/upstream/request.rs
+++ b/crates/token_proxy_core/src/proxy/upstream/request.rs
@@ -20,7 +20,7 @@ use super::super::{
     RequestMeta,
 };
 use super::super::http::RequestAuth;
-use crate::proxy::server_helpers::log_debug_headers_body;
+use crate::proxy::server_helpers::{log_debug_headers_body, truncate_for_log};
 
 const ANTHROPIC_VERSION_HEADER: &str = "anthropic-version";
 const DEFAULT_ANTHROPIC_VERSION: &str = "2023-06-01";
@@ -30,6 +30,7 @@ const OPENAI_RESPONSES_PATH: &str = "/v1/responses";
 // Keep in sync with server_helpers request transform limit (20 MiB).
 const REQUEST_FILTER_LIMIT_BYTES: usize = 20 * 1024 * 1024;
 const DEBUG_UPSTREAM_LOG_LIMIT_BYTES: usize = usize::MAX;
+const ANTIGRAVITY_WRAPPED_LOG_LIMIT_BYTES: usize = 8 * 1024;
 
 pub(super) fn split_path_query(path_with_query: &str) -> (&str, Option<&str>) {
     match path_with_query.split_once('?') {
@@ -238,9 +239,20 @@ async fn build_antigravity_body(
         DEBUG_UPSTREAM_LOG_LIMIT_BYTES,
     )
     .await;
+    log_antigravity_wrapped_body(&wrapped);
     Ok(reqwest::Body::from(wrapped))
 }
 
+fn log_antigravity_wrapped_body(bytes: &[u8]) {
+    if !tracing::enabled!(tracing::Level::WARN) {
+        return;
+    }
+    let body_text = String::from_utf8_lossy(bytes);
+    let truncated = truncate_for_log(&body_text, ANTIGRAVITY_WRAPPED_LOG_LIMIT_BYTES);
+    // Logged only during the antigravity request phase, to make upstream validation errors easy to reproduce.
+    tracing::warn!(body = %truncated, "antigravity wrapped request payload");
+}
+
 async fn maybe_rewrite_request_body_model(
     body: &ReplayableBody,
     meta: &RequestMeta,
diff --git a/src/features/config/inbound-formats.ts b/src/features/config/inbound-formats.ts
index 1953922..10e025a 100644
--- a/src/features/config/inbound-formats.ts
+++ b/src/features/config/inbound-formats.ts
@@ -30,7 +30,9 @@ const PROVIDER_NATIVE_INBOUND_FORMATS: Readonly<
   gemini: ["gemini"],
   kiro: ["anthropic_messages"],
   codex: ["openai_responses"],
-  antigravity: ["gemini"],
+  // Align with backend `native_inbound_formats_for_provider()`:
+  // Antigravity supports Gemini routes + Claude Code (/v1/messages) out-of-the-box.
+  antigravity: ["gemini", "anthropic_messages"],
 };
 
 export function getProviderNativeInboundFormats(provider: string) {

From 57a9878112a4ee395ac0be6384c072ee8cec891b Mon Sep 17 00:00:00 2001
From: mxyhi
Date: Sat, 31 Jan 2026 19:10:57 +0800
Subject: [PATCH 10/10] fix(proxy): defer date-suffixed model validation to Antigravity

---
 .../src/proxy/antigravity_compat.rs           | 70 ++++++-------
 .../src/proxy/antigravity_compat.test.rs      | 50 +++++++++++--
 .../src/proxy/antigravity_compat/claude.rs    | 18 ++---
 .../proxy/antigravity_compat/claude.test.rs   | 16 +++++
 4 files changed, 90 insertions(+), 64 deletions(-)

diff --git a/crates/token_proxy_core/src/proxy/antigravity_compat.rs b/crates/token_proxy_core/src/proxy/antigravity_compat.rs
index 6c42e3e..78b1505 100644
--- a/crates/token_proxy_core/src/proxy/antigravity_compat.rs
+++ b/crates/token_proxy_core/src/proxy/antigravity_compat.rs
@@ -182,6 +182,14 @@ where
 }
 
 fn extract_model(request: &mut Map<String, Value>, model_hint: Option<&str>) -> String {
+    // Align with CLIProxyAPIPlus model-mapping behavior:
+    // - If upstream/model-mapping produced a model_hint, it MUST override whatever the client put
+    //   in the request body (e.g. Claude Code may send a Claude model that Antigravity doesn't have).
+    // - Always remove request["model"] so the inner request stays Gemini-shaped.
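+    // e.g. a body of {"model":"claude-haiku-4-5-20251001"} wrapped with a "gemini-2.5-flash"
+    // hint resolves to "gemini-2.5-flash" (see `model_hint_overrides_body_model_in_wrap_request`).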
+    let hint = model_hint
+        .map(str::trim)
+        .filter(|value| !value.is_empty())
+        .map(|value| value.to_string());
     let from_body = request
         .get("model")
         .and_then(Value::as_str)
         .map(|value| value.trim())
         .filter(|value| !value.is_empty())
         .map(|value| value.to_string());
     request.remove("model");
-    let hint = model_hint
-        .map(str::trim)
-        .filter(|value| !value.is_empty())
-        .map(|value| value.to_string());
-    from_body
-        .or(hint)
+    hint.or(from_body)
         .unwrap_or_else(|| DEFAULT_MODEL.to_string())
 }
 
@@ -203,17 +206,21 @@ pub(crate) fn map_antigravity_model(model: &str) -> String {
     if trimmed.is_empty() {
         return DEFAULT_MODEL.to_string();
     }
-    // Align with CLIProxyAPIPlus conventions:
-    // - Some clients expose Claude models behind a "gemini-" prefix (e.g. gemini-claude-opus-4-5-thinking)
-    //   while Antigravity upstream uses the stable Claude name without the prefix.
-    if trimmed.starts_with("gemini-claude-") {
-        return trimmed.trim_start_matches("gemini-").to_string();
+    // Strict alignment with CLIProxyAPIPlus:
+    // - Do NOT remap date-suffixed model IDs. Let Antigravity upstream validate/support them.
+    // - Only normalize legacy/alias model IDs that CLIProxy migrates for antigravity.
+    // - Keep "gemini-claude-*" aliases compatible by stripping the "gemini-" prefix.
+    match trimmed {
+        // Legacy Antigravity aliases used by older configs/clients.
+        "gemini-2.5-computer-use-preview-10-2025" => return "rev19-uic3-1p".to_string(),
+        "gemini-3-pro-image-preview" => return "gemini-3-pro-image".to_string(),
+        "gemini-3-pro-preview" => return "gemini-3-pro-high".to_string(),
+        "gemini-3-flash-preview" => return "gemini-3-flash".to_string(),
+        _ => {}
     }
-    // Claude Code / Amp CLI may request date-suffixed Claude models (e.g. claude-opus-4-5-20251101).
-    // Antigravity does not expose date-suffixed IDs; map them to the stable Antigravity model names.
-    if let Some(mapped) = map_claude_date_model_to_antigravity(trimmed) {
-        return mapped;
+
+    if trimmed.starts_with("gemini-claude-") {
+        return trimmed.trim_start_matches("gemini-").to_string();
     }
 
     trimmed.to_string()
@@ -224,39 +231,6 @@
 #[path = "antigravity_compat.test.rs"]
 mod tests;
 
-fn map_claude_date_model_to_antigravity(model: &str) -> Option<String> {
-    if !model.starts_with("claude-") {
-        return None;
-    }
-
-    // Allow optional "-thinking" suffix (some clients encode "thinking" in the model ID).
-    let (base, _has_thinking_suffix) = match model.strip_suffix("-thinking") {
-        Some(value) => (value, true),
-        None => (model, false),
-    };
-
-    // Detect the trailing date segment in `...-YYYYMMDD`.
-    let (without_date, date_suffix) = base.rsplit_once('-')?;
-    if date_suffix.len() != 8 || !date_suffix.chars().all(|ch| ch.is_ascii_digit()) {
-        return None;
-    }
-
-    // Known Claude 4.5 model families: map to the Antigravity stable names.
-    // NOTE: Antigravity appears to expose Sonnet/Opus (and their thinking variants) but not Haiku.
-    if without_date.starts_with("claude-opus-4-5") {
-        return Some("claude-opus-4-5-thinking".to_string());
-    }
-    if without_date.starts_with("claude-sonnet-4-5") {
-        return Some("claude-sonnet-4-5-thinking".to_string());
-    }
-    if without_date.starts_with("claude-haiku-4-5") {
-        // Follow CLIProxyAPIPlus example mapping: route Haiku to a close Gemini alternative.
-        return Some("gemini-2.5-flash".to_string());
-    }
-
-    None
-}
-
 fn normalize_system_instruction(request: &mut Map<String, Value>) {
     if let Some(value) = request.remove("system_instruction") {
         request.insert("systemInstruction".to_string(), value);
diff --git a/crates/token_proxy_core/src/proxy/antigravity_compat.test.rs b/crates/token_proxy_core/src/proxy/antigravity_compat.test.rs
index 410cf8e..942f121 100644
--- a/crates/token_proxy_core/src/proxy/antigravity_compat.test.rs
+++ b/crates/token_proxy_core/src/proxy/antigravity_compat.test.rs
@@ -30,31 +30,67 @@ fn strips_gemini_prefix_for_claude_aliases() {
 }
 
 #[test]
-fn maps_claude_opus_date_model_to_stable_thinking_model() {
+fn keeps_claude_opus_date_model_unchanged() {
     assert_eq!(
         map_antigravity_model("claude-opus-4-5-20251101"),
-        "claude-opus-4-5-thinking"
+        "claude-opus-4-5-20251101"
     );
     assert_eq!(
         map_antigravity_model("claude-opus-4-5-20251101-thinking"),
-        "claude-opus-4-5-thinking"
+        "claude-opus-4-5-20251101-thinking"
     );
 }
 
 #[test]
-fn maps_claude_sonnet_date_model_to_stable_thinking_model() {
+fn keeps_claude_sonnet_date_model_unchanged() {
     assert_eq!(
         map_antigravity_model("claude-sonnet-4-5-20250929"),
-        "claude-sonnet-4-5-thinking"
+        "claude-sonnet-4-5-20250929"
     );
 }
 
 #[test]
-fn maps_claude_haiku_date_model_to_gemini_fallback() {
+fn keeps_claude_haiku_date_model_unchanged() {
     assert_eq!(
         map_antigravity_model("claude-haiku-4-5-20251001"),
-        "gemini-2.5-flash"
+        "claude-haiku-4-5-20251001"
+    );
+}
+
+#[test]
+fn maps_legacy_antigravity_aliases_to_canonical_ids() {
+    assert_eq!(
+        map_antigravity_model("gemini-2.5-computer-use-preview-10-2025"),
+        "rev19-uic3-1p"
+    );
+    assert_eq!(
+        map_antigravity_model("gemini-3-pro-image-preview"),
+        "gemini-3-pro-image"
     );
+    assert_eq!(
+        map_antigravity_model("gemini-3-pro-preview"),
+        "gemini-3-pro-high"
+    );
+    assert_eq!(
+        map_antigravity_model("gemini-3-flash-preview"),
+        "gemini-3-flash"
+    );
+}
+
+#[test]
+fn model_hint_overrides_body_model_in_wrap_request() {
+    let request = json!({
+        "model": "claude-haiku-4-5-20251001",
+        "contents": [
+            { "role": "user", "parts": [{ "text": "hello" }] }
+        ]
+    });
+    let bytes = Bytes::from(request.to_string());
+    let wrapped =
+        wrap_gemini_request(&bytes, Some("gemini-2.5-flash"), None, "ua").expect("wrap ok");
+    let value: serde_json::Value = serde_json::from_slice(&wrapped).expect("wrapped json");
+    assert_eq!(value["model"].as_str(), Some("gemini-2.5-flash"));
+    assert!(value["request"].get("systemInstruction").is_none());
 }
 
 #[test]
diff --git a/crates/token_proxy_core/src/proxy/antigravity_compat/claude.rs b/crates/token_proxy_core/src/proxy/antigravity_compat/claude.rs
index c582745..0859a49 100644
--- a/crates/token_proxy_core/src/proxy/antigravity_compat/claude.rs
+++ b/crates/token_proxy_core/src/proxy/antigravity_compat/claude.rs
@@ -52,19 +52,21 @@ fn parse_request_object(body: &Bytes) -> Result<Map<String, Value>, String> {
 }
 
 fn resolve_model_name(object: &Map<String, Value>, model_hint: Option<&str>) -> String {
-    object
+    // Model mapping must override client-provided model when routing Claude Code -> Antigravity.
+    // This matches CLIProxyAPIPlus behavior where the translator receives the mapped model name.
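+    // With neither a hint nor a body model, this resolves to "" via unwrap_or_default,
+    // preserving the previous fallback behavior.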
+    let hint = model_hint
+        .map(str::trim)
+        .filter(|value| !value.is_empty())
+        .map(|value| value.to_string());
+    let from_body = object
         .get("model")
         .and_then(Value::as_str)
         .map(|value| value.trim())
         .filter(|value| !value.is_empty())
-        .map(|value| value.to_string())
-        .or_else(|| {
-            model_hint
-                .map(str::trim)
-                .filter(|value| !value.is_empty())
-                .map(|value| value.to_string())
-        })
-        .unwrap_or_default()
+        .map(|value| value.to_string());
+    hint.or(from_body).unwrap_or_default()
 }
 
 fn build_system_instruction(object: &Map<String, Value>, should_hint: bool) -> Option<Value> {
diff --git a/crates/token_proxy_core/src/proxy/antigravity_compat/claude.test.rs b/crates/token_proxy_core/src/proxy/antigravity_compat/claude.test.rs
index 4451d9c..3f28b6d 100644
--- a/crates/token_proxy_core/src/proxy/antigravity_compat/claude.test.rs
+++ b/crates/token_proxy_core/src/proxy/antigravity_compat/claude.test.rs
@@ -83,3 +83,38 @@ fn unsigned_thinking_is_removed() {
     assert_eq!(parts.len(), 1);
     assert_eq!(parts[0]["text"], "Answer");
 }
+
+#[test]
+fn model_hint_overrides_request_model() {
+    let input = Bytes::from(
+        r#"{
+        "model": "claude-haiku-4-5-20251001",
+        "messages": [
+            {"role": "user", "content": [{"type": "text", "text": "Hello"}]}
+        ]
+    }"#,
+    );
+    let output = parse_output(
+        claude_request_to_antigravity(&input, Some("gemini-2.5-flash")).expect("convert"),
+    );
+    assert_eq!(output["model"], "gemini-2.5-flash");
+}
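+
+// Companion case (a sketch, assuming `claude_request_to_antigravity` applies no
+// further remapping when no hint is given): without a model_hint, the body model,
+// including its date suffix, should now pass through to Antigravity unchanged.
+#[test]
+fn body_model_is_kept_when_no_hint_is_given() {
+    let input = Bytes::from(
+        r#"{
+        "model": "claude-sonnet-4-5-20250929",
+        "messages": [
+            {"role": "user", "content": [{"type": "text", "text": "Hello"}]}
+        ]
+    }"#,
+    );
+    let output = parse_output(
+        claude_request_to_antigravity(&input, None).expect("convert"),
+    );
+    assert_eq!(output["model"], "claude-sonnet-4-5-20250929");
+}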