diff --git a/frontend/src/api/setting.ts b/frontend/src/api/setting.ts index 517366826..fc179de5a 100644 --- a/frontend/src/api/setting.ts +++ b/frontend/src/api/setting.ts @@ -8,6 +8,8 @@ import { API_QUERY_KEYS } from "@/constants/api"; import type { ApiResponse } from "@/lib/api-client"; import { apiClient } from "@/lib/api-client"; import type { + CheckModelRequest, + CheckModelResult, MemoryItem, ModelProvider, ProviderDetail, @@ -170,6 +172,18 @@ export const useSetDefaultProviderModel = () => { }); }; +/** + * Check model availability for a provider/model pair. + * Sends a minimal live request to the provider's API to verify that the + * configured API key, base URL, and model id are actually usable. + */ +export const useCheckModelAvailability = () => { + return useMutation({ + mutationFn: (params: CheckModelRequest) => + apiClient.post>("/models/check", params), + }); +}; + /** * Hook to get model providers sorted by API key availability. * Providers with API keys configured appear first. 
diff --git a/frontend/src/app/setting/components/models/model-detail.tsx b/frontend/src/app/setting/components/models/model-detail.tsx index ffaa253e2..f7eec1dc0 100644 --- a/frontend/src/app/setting/components/models/model-detail.tsx +++ b/frontend/src/app/setting/components/models/model-detail.tsx @@ -5,6 +5,7 @@ import { useEffect, useState } from "react"; import { z } from "zod"; import { useAddProviderModel, + useCheckModelAvailability, useDeleteProviderModel, useGetModelProviderDetail, useSetDefaultProvider, @@ -65,6 +66,12 @@ export function ModelDetail({ provider }: ModelDetailProps) { useSetDefaultProviderModel(); const { mutate: setDefaultProvider, isPending: settingDefaultProvider } = useSetDefaultProvider(); + const { + data: checkResult, + mutateAsync: checkAvailability, + isPending: checkingAvailability, + reset: resetCheckResult, + } = useCheckModelAvailability(); const [isAddDialogOpen, setIsAddDialogOpen] = useState(false); const [showApiKey, setShowApiKey] = useState(false); @@ -95,8 +102,11 @@ export function ModelDetail({ provider }: ModelDetailProps) { }, [providerDetail, configForm.setFieldValue]); useEffect(() => { - if (provider) setShowApiKey(false); - }, [provider]); + if (provider) { + setShowApiKey(false); + resetCheckResult(); + } + }, [provider, resetCheckResult]); const addModelForm = useForm({ defaultValues: { @@ -133,7 +143,8 @@ export function ModelDetail({ provider }: ModelDetailProps) { addingModel || deletingModel || settingDefaultModel || - settingDefaultProvider; + settingDefaultProvider || + checkingAvailability; if (detailLoading) { return ( @@ -173,39 +184,79 @@ export function ModelDetail({ provider }: ModelDetailProps) { > API key - - field.handleChange(e.target.value)} - onBlur={() => configForm.handleSubmit()} - onKeyDown={(e) => { - if (e.key === "Enter") { - e.preventDefault(); - e.currentTarget.blur(); - } +
+ + field.handleChange(e.target.value)} + onBlur={() => configForm.handleSubmit()} + onKeyDown={(e) => { + if (e.key === "Enter") { + e.preventDefault(); + e.currentTarget.blur(); + } + }} + /> + + setShowApiKey(!showApiKey)} + aria-label={ + showApiKey ? "Hide password" : "Show password" + } + > + {showApiKey ? ( + + ) : ( + + )} + + + + + +
+ {checkResult?.data && ( +
+ {checkResult.data.ok ? ( + + Available + {checkResult.data.status + ? ` (${checkResult.data.status})` + : ""} + + ) : ( + + Unavailable + {checkResult.data.status + ? ` (${checkResult.data.status})` + : ""} + {checkResult.data.error + ? `: ${checkResult.data.error}` + : ""} + + )} +
+ )} SuccessResponse[CheckModelResponse]: + try: + manager = get_config_manager() + provider = payload.provider or manager.primary_provider + cfg = manager.get_provider_config(provider) + if cfg is None: + raise HTTPException( + status_code=404, detail=f"Provider '{provider}' not found" + ) + + model_id = payload.model_id or cfg.default_model + if not model_id: + raise HTTPException( + status_code=400, + detail="Model id not specified and provider has no default", + ) + + # Perform a minimal live request (ping) without configuration validation + result = CheckModelResponse( + ok=False, + provider=provider, + model_id=model_id, + status=None, + error=None, + ) + try: + import asyncio + + import httpx + except Exception as e: + result.ok = False + result.status = "runtime_missing" + result.error = f"Runtime dependency missing: {e}" + return SuccessResponse.create(data=result, msg="Live check failed") + + # Prefer a direct minimal request for OpenAI-compatible providers. + # This avoids hidden fallbacks and validates API key/auth. + api_key = (payload.api_key or cfg.api_key or "").strip() + base_url = (getattr(cfg, "base_url", None) or "").strip() + # Use direct request timeout only (no agent fallback) + direct_timeout_s = 5.0 + if provider == "google": + direct_timeout_s = 30.0 + + def _normalize_model_id_for_provider(provider_name: str, mid: str) -> str: + """Normalize model id for specific providers to avoid 404s. + + - Google Gemini: sometimes configs use vendor-prefixed ids like + "google/gemini-1.5-flash"; the REST path expects just the model + name segment (e.g., "gemini-1.5-flash"). + - Other providers: return as-is. 
+ """ + if provider_name == "google" and "/" in mid: + return mid.split("/")[-1] + return mid + + normalized_model_id = _normalize_model_id_for_provider(provider, model_id) + + async def _direct_openai_like_ping(endpoint: str) -> bool: + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + json_body = { + "model": model_id, + "messages": [{"role": "user", "content": "ping"}], + "max_tokens": 1, + "temperature": 0, + } + async with httpx.AsyncClient(timeout=direct_timeout_s) as client: + resp = await client.post(endpoint, headers=headers, json=json_body) + # Handle auth failures explicitly + if resp.status_code in (401, 403): + try: + err_json = resp.json() + msg = err_json.get("error", {}).get("message") or str(err_json) + except Exception: + msg = resp.text + result.ok = False + result.status = "auth_failed" + result.error = msg or "Unauthorized" + return False + if resp.status_code >= 400: + # Other request failures + try: + err_json = resp.json() + msg = err_json.get("error", {}).get("message") or str(err_json) + except Exception: + msg = resp.text + result.ok = False + result.status = "request_failed" + result.error = msg or f"HTTP {resp.status_code}" + return False + # Success path: verify minimal structure + try: + data = resp.json() + except Exception: + data = None + if not data or "choices" not in data: + result.ok = False + result.status = "request_failed" + result.error = "Unexpected response structure" + return False + result.status = "reachable" + result.ok = True + return True + + async def _direct_google_ping(endpoint: str) -> bool: + # Gemini REST uses api key via query param `key`, but we also + # set header to be safe. 
+ headers = { + "Content-Type": "application/json", + "x-goog-api-key": api_key, + } + json_body = { + "contents": [ + { + "role": "user", + "parts": [{"text": "ping"}], + } + ] + } + async with httpx.AsyncClient(timeout=direct_timeout_s) as client: + resp = await client.post( + endpoint, + headers=headers, + params={"key": api_key} if api_key else None, + json=json_body, + ) + + if resp.status_code in (401, 403): + try: + err_json = resp.json() + msg = err_json.get("error", {}).get("message") or str(err_json) + except Exception: + msg = resp.text + result.ok = False + result.status = "auth_failed" + result.error = msg or "Unauthorized" + return False + if resp.status_code >= 400: + try: + err_json = resp.json() + msg = err_json.get("error", {}).get("message") or str(err_json) + except Exception: + msg = resp.text + result.ok = False + result.status = "request_failed" + # Preserve HTTP code in error to enable v1/v1beta fallback on 404 + if msg: + result.error = f"HTTP {resp.status_code}: {msg}" + else: + result.error = f"HTTP {resp.status_code}" + return False + # Minimal success: presence of candidates + try: + data = resp.json() + except Exception: + data = None + if not data or "candidates" not in data: + result.ok = False + result.status = "request_failed" + result.error = "Unexpected response structure" + return False + result.status = "reachable" + result.ok = True + return True + + def _normalize_base_url(url: str) -> str: + return (url or "").strip().rstrip("/") + + def _resolve_endpoint() -> tuple[str | None, str]: + """Return (endpoint, style) where style in {"openai_like", "google", "azure"}. + + Priority: if base_url provided, derive from host; else fall back to known provider mappings. 
+ """ + bu = _normalize_base_url(base_url) + # Host-driven detection + if bu: + lower = bu.lower() + if ( + "generativelanguage.googleapis.com" in lower + or "googleapis.com" in lower + ): + # Construct Google endpoint for fast direct ping + # Handle cases where base_url already includes '/models' or full ':generateContent' path + if ":generatecontent" in lower: + # Treat as full endpoint + return bu, "google" + if "/models/" in lower: + # If base_url already includes '/models', avoid duplicating + if lower.endswith("/models"): + endpoint = f"{bu}/{normalized_model_id}:generateContent" + else: + # base_url might be '/models/{model}', append ':generateContent' if missing + endpoint = ( + f"{bu}:generateContent" + if not lower.endswith(":generatecontent") + else bu + ) + return endpoint, "google" + # If base_url already includes version segment, do not repeat it + if lower.endswith("/v1beta") or "/v1beta/" in lower: + endpoint = ( + f"{bu}/models/{normalized_model_id}:generateContent" + ) + elif lower.endswith("/v1") or "/v1/" in lower: + endpoint = ( + f"{bu}/models/{normalized_model_id}:generateContent" + ) + else: + endpoint = f"{bu}/v1beta/models/{normalized_model_id}:generateContent" + return endpoint, "google" + if "openai.azure.com" in lower or "/openai/deployments" in lower: + # If user pasted a deployments URL, keep it; otherwise construct from base_url + # Azure requires api_version + api_version = ( + getattr(cfg, "extra_config", {}).get("api_version") + if hasattr(cfg, "extra_config") + else None + ) + if not api_version: + return None, "azure" + endpoint = f"{bu}/openai/deployments/{model_id}/chat/completions?api-version={api_version}" + return endpoint, "azure" + if "openrouter.ai" in lower: + return f"{bu}/api/v1/chat/completions" if not lower.endswith( + "/api/v1" + ) else f"{bu}/chat/completions", "openai_like" + if "openai.com" in lower: + return f"{bu}/v1/chat/completions" if not lower.endswith( + "/v1" + ) else f"{bu}/chat/completions", 
"openai_like" + if "deepseek.com" in lower: + return f"{bu}/v1/chat/completions" if not lower.endswith( + "/v1" + ) else f"{bu}/chat/completions", "openai_like" + if "siliconflow" in lower: + return f"{bu}/v1/chat/completions" if not lower.endswith( + "/v1" + ) else f"{bu}/chat/completions", "openai_like" + if "dashscope.aliyuncs.com" in lower or "dashscope.com" in lower: + # DashScope OpenAI-compatible endpoint lives under compatible-mode + if lower.endswith("/compatible-mode/v1"): + return f"{bu}/chat/completions", "openai_like" + return ( + f"{bu}/compatible-mode/v1/chat/completions", + "openai_like", + ) + # If base_url provided but host is unrecognized: + # - For openai-compatible, treat as generic OpenAI-like + # - For Google/Azure, ignore base_url and fall through to provider fallback + # - For other providers, fall through to provider fallback to use official endpoints + if provider == "openai-compatible": + return f"{bu}/v1/chat/completions", "openai_like" + + # Provider-driven fallback + if provider == "google": + # Official Google endpoint for direct ping (v1beta by default) + return ( + f"https://generativelanguage.googleapis.com/v1beta/models/{normalized_model_id}:generateContent", + "google", + ) + if provider == "azure": + api_version = ( + getattr(cfg, "extra_config", {}).get("api_version") + if hasattr(cfg, "extra_config") + else None + ) + if base_url and api_version: + endpoint = f"{base_url}/openai/deployments/{model_id}/chat/completions?api-version={api_version}" + return endpoint, "azure" + return None, "azure" + if provider == "openai": + return "https://api.openai.com/v1/chat/completions", "openai_like" + if provider == "openrouter": + return ( + "https://openrouter.ai/api/v1/chat/completions", + "openai_like", + ) + if provider == "deepseek": + return "https://api.deepseek.com/v1/chat/completions", "openai_like" + if provider == "siliconflow": + return ( + "https://api.siliconflow.cn/v1/chat/completions", + "openai_like", + ) + if provider 
== "dashscope": + return ( + "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions", + "openai_like", + ) + if provider == "openai-compatible": + if base_url: + bu = _normalize_base_url(base_url) + if bu.endswith("/v1"): + return f"{bu}/chat/completions", "openai_like" + return f"{bu}/v1/chat/completions", "openai_like" + return None, "openai_like" + + # Decide endpoint for known OpenAI-compatible providers + completed_via_direct = False + try: + if not api_key: + # Missing API key: fail fast for providers requiring auth + if provider in { + "openai", + "openrouter", + "deepseek", + "siliconflow", + "azure", + "google", + }: + result.ok = False + result.status = "auth_failed" + result.error = "API key is missing" + return SuccessResponse.create(data=result, msg="Auth failed") + + endpoint, style = _resolve_endpoint() + + if endpoint: + # Perform direct ping with timeout + if style == "google": + completed_via_direct = await asyncio.wait_for( + _direct_google_ping(endpoint), timeout=direct_timeout_s + ) + # If 404 from v1beta, try v1 (or vice versa) + if ( + not completed_via_direct + and (result.error or "").find("404") != -1 + ): + alt_endpoint = None + if "/v1beta/" in endpoint: + alt_endpoint = endpoint.replace("/v1beta/", "/v1/") + elif "/v1/" in endpoint: + alt_endpoint = endpoint.replace("/v1/", "/v1beta/") + if alt_endpoint: + # Reset status/error before retry + result.status = None + result.error = None + completed_via_direct = await asyncio.wait_for( + _direct_google_ping(alt_endpoint), + timeout=direct_timeout_s, + ) + else: + completed_via_direct = await asyncio.wait_for( + _direct_openai_like_ping(endpoint), timeout=direct_timeout_s + ) + if completed_via_direct: + return SuccessResponse.create( + data=result, msg="Model reachable" + ) + else: + return SuccessResponse.create( + data=result, msg=result.status or "Request failed" + ) + else: + # No endpoint available for direct probe + result.ok = False + result.status = "probe_unavailable" 
+ if style == "azure": + result.error = "Azure requires API Host (base_url) and api_version for direct probe" + elif provider == "openai-compatible" and not base_url: + result.error = "OpenAI-compatible provider requires API Host to run direct probe" + else: + result.error = "Direct probe endpoint not resolved" + return SuccessResponse.create(data=result, msg="Probe unavailable") + except asyncio.TimeoutError: + result.ok = False + result.status = "timeout" + result.error = f"Timed out after {int(direct_timeout_s * 1000)} ms" + return SuccessResponse.create(data=result, msg="Timeout") + except httpx.TimeoutException: + result.ok = False + result.status = "timeout" + result.error = f"Timed out after {int(direct_timeout_s * 1000)} ms" + return SuccessResponse.create(data=result, msg="Timeout") + except Exception as e: + # Direct probe threw an unexpected error; report and do not fall back to agent + result.ok = False + result.status = "request_failed" + result.error = str(e) + return SuccessResponse.create(data=result, msg="Request failed") + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to check model: {e}") + return router diff --git a/python/valuecell/server/api/schemas/__init__.py b/python/valuecell/server/api/schemas/__init__.py index 58e8f300e..d80acf94a 100644 --- a/python/valuecell/server/api/schemas/__init__.py +++ b/python/valuecell/server/api/schemas/__init__.py @@ -35,7 +35,7 @@ UserI18nSettingsData, UserI18nSettingsRequest, ) -from .model import LLMModelConfigData +from .model import CheckModelRequest, CheckModelResponse, LLMModelConfigData from .task import TaskCancelData from .user_profile import ( CreateUserProfileRequest, @@ -117,4 +117,6 @@ "TaskCancelData", # Model schemas "LLMModelConfigData", + "CheckModelRequest", + "CheckModelResponse", ] diff --git a/python/valuecell/server/api/schemas/model.py b/python/valuecell/server/api/schemas/model.py index b641c3211..6a72b426b 100644 --- 
a/python/valuecell/server/api/schemas/model.py +++ b/python/valuecell/server/api/schemas/model.py @@ -83,3 +83,32 @@ class SetDefaultModelRequest(BaseModel): None, description="Optional display name; added/updated in models list if provided", ) + + +# --- Model availability check --- +class CheckModelRequest(BaseModel): + """Request payload to check if a provider+model is usable.""" + + provider: Optional[str] = Field( + None, description="Provider to check; defaults to current primary provider" + ) + model_id: Optional[str] = Field( + None, description="Model id to check; defaults to provider's default model" + ) + api_key: Optional[str] = Field( + None, description="Temporary API key to use for this check (optional)" + ) + # strict flag removed; the endpoint now always performs a live direct probe. + + +class CheckModelResponse(BaseModel): + """Response payload describing the model availability check result.""" + + ok: bool = Field(..., description="Whether the provider+model is usable") + provider: str = Field(..., description="Provider under test") + model_id: str = Field(..., description="Model id under test") + status: Optional[str] = Field( + None, + description="Status label like 'reachable', 'auth_failed', 'request_failed', 'timeout', 'probe_unavailable'", + ) + error: Optional[str] = Field(None, description="Error message if any")