diff --git a/Cargo.lock b/Cargo.lock index f5a42f463e..af10656259 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10451,6 +10451,7 @@ dependencies = [ "sentry", "serde", "serde_json", + "strum 0.27.2", "thiserror 2.0.18", "tokio", "tower 0.5.3", diff --git a/apps/api/openapi.gen.json b/apps/api/openapi.gen.json index b44c3f704a..cf116fea5d 100644 --- a/apps/api/openapi.gen.json +++ b/apps/api/openapi.gen.json @@ -172,6 +172,49 @@ } } }, + "/llm/chat/completions": { + "post": { + "tags": [ + "llm" + ], + "operationId": "llm_chat_completions", + "parameters": [ + { + "name": "x-char-task", + "in": "header", + "description": "Task type for model selection", + "required": false, + "schema": { + "oneOf": [ + { + "type": "null" + }, + { + "$ref": "#/components/schemas/CharTask" + } + ] + } + } + ], + "responses": { + "200": { + "description": "Chat completion response (streaming or non-streaming)" + }, + "401": { + "description": "Unauthorized" + }, + "429": { + "description": "Rate limit exceeded" + }, + "502": { + "description": "Upstream provider failed" + }, + "504": { + "description": "Request timeout" + } + } + } + }, "/nango/connect-session": { "post": { "tags": [ @@ -881,6 +924,14 @@ } } }, + "CharTask": { + "type": "string", + "enum": [ + "chat", + "enhance", + "title" + ] + }, "ConnectSessionResponse": { "type": "object", "required": [ @@ -1759,12 +1810,6 @@ "scheme": "bearer", "bearerFormat": "JWT", "description": "Supabase JWT token" - }, - "device_fingerprint": { - "type": "apiKey", - "in": "header", - "name": "x-device-fingerprint", - "description": "Optional device fingerprint for analytics" } } }, diff --git a/apps/api/src/openapi.rs b/apps/api/src/openapi.rs index 52d0bea2dd..ea7428646d 100644 --- a/apps/api/src/openapi.rs +++ b/apps/api/src/openapi.rs @@ -1,9 +1,7 @@ use std::collections::BTreeMap; use utoipa::openapi::path::{Operation, PathItem}; -use utoipa::openapi::security::{ - ApiKey, ApiKeyValue, Http, HttpAuthScheme, SecurityRequirement, 
SecurityScheme, -}; +use utoipa::openapi::security::{Http, HttpAuthScheme, SecurityRequirement, SecurityScheme}; use utoipa::{Modify, OpenApi}; #[derive(OpenApi)] @@ -71,13 +69,6 @@ impl Modify for SecurityAddon { .build(), ), ); - components.add_security_scheme( - "device_fingerprint", - SecurityScheme::ApiKey(ApiKey::Header(ApiKeyValue::with_description( - "x-device-fingerprint", - "Optional device fingerprint for analytics", - ))), - ); } } } diff --git a/apps/desktop/src/components/chat/session.tsx b/apps/desktop/src/components/chat/session.tsx index 46e21d1f70..8c9cafe12e 100644 --- a/apps/desktop/src/components/chat/session.tsx +++ b/apps/desktop/src/components/chat/session.tsx @@ -213,7 +213,7 @@ function useTransport( systemPromptOverride?: string, ) { const registry = useToolRegistry(); - const configuredModel = useLanguageModel(); + const configuredModel = useLanguageModel("chat"); const model = modelOverride ?? configuredModel; const language = main.UI.useValue("ai_language", main.STORE_ID) ?? "en"; const [systemPrompt, setSystemPrompt] = useState(); diff --git a/apps/desktop/src/components/chat/view.tsx b/apps/desktop/src/components/chat/view.tsx index 5bc12c5619..fd666ff003 100644 --- a/apps/desktop/src/components/chat/view.tsx +++ b/apps/desktop/src/components/chat/view.tsx @@ -18,7 +18,7 @@ export function ChatView() { currentTab?.type === "sessions" ? 
currentTab.id : undefined; const stableSessionId = useStableSessionId(groupId); - const model = useLanguageModel(); + const model = useLanguageModel("chat"); const { handleSendMessage } = useChatActions({ groupId, diff --git a/apps/desktop/src/components/main/body/sessions/floating/options-menu.tsx b/apps/desktop/src/components/main/body/sessions/floating/options-menu.tsx index 45e8cb30a9..95a754e48e 100644 --- a/apps/desktop/src/components/main/body/sessions/floating/options-menu.tsx +++ b/apps/desktop/src/components/main/body/sessions/floating/options-menu.tsx @@ -62,7 +62,7 @@ export function OptionsMenu({ const { user_id } = main.UI.useValues(main.STORE_ID); const updateSessionTabState = useTabs((state) => state.updateSessionTabState); const createEnhancedNote = useCreateEnhancedNote(); - const model = useLanguageModel(); + const model = useLanguageModel("enhance"); const generate = useAITask((state) => state.generate); const selectedTemplateId = settings.UI.useValue( "selected_template_id", diff --git a/apps/desktop/src/components/main/body/sessions/note-input/enhanced/enhance-error.tsx b/apps/desktop/src/components/main/body/sessions/note-input/enhanced/enhance-error.tsx index 1286a7e57a..b1104cfdce 100644 --- a/apps/desktop/src/components/main/body/sessions/note-input/enhanced/enhance-error.tsx +++ b/apps/desktop/src/components/main/body/sessions/note-input/enhanced/enhance-error.tsx @@ -16,7 +16,7 @@ export function EnhanceError({ enhancedNoteId: string; error: Error | undefined; }) { - const model = useLanguageModel(); + const model = useLanguageModel("enhance"); const generate = useAITask((state) => state.generate); const templateId = (main.UI.useCell( diff --git a/apps/desktop/src/components/main/body/sessions/note-input/header.tsx b/apps/desktop/src/components/main/body/sessions/note-input/header.tsx index 018948cf43..14dcd2fbfe 100644 --- a/apps/desktop/src/components/main/body/sessions/note-input/header.tsx +++ 
b/apps/desktop/src/components/main/body/sessions/note-input/header.tsx @@ -301,7 +301,7 @@ function CreateOtherFormatButton({ main.STORE_ID, ); const createEnhancedNote = useCreateEnhancedNote(); - const model = useLanguageModel(); + const model = useLanguageModel("enhance"); const openNew = useTabs((state) => state.openNew); const store = main.UI.useStore(main.STORE_ID); @@ -611,7 +611,7 @@ function labelForEditorView(view: EditorView): string { } function useEnhanceLogic(sessionId: string, enhancedNoteId: string) { - const model = useLanguageModel(); + const model = useLanguageModel("enhance"); const llmStatus = useLLMConnectionStatus(); const taskId = createTaskId(enhancedNoteId, "enhance"); const [missingModelError, setMissingModelError] = useState( diff --git a/apps/desktop/src/hooks/autoEnhance/runner.ts b/apps/desktop/src/hooks/autoEnhance/runner.ts index 1c91355a4b..39a8be3283 100644 --- a/apps/desktop/src/hooks/autoEnhance/runner.ts +++ b/apps/desktop/src/hooks/autoEnhance/runner.ts @@ -30,7 +30,7 @@ export function useAutoEnhanceRunner( isEnhancing: boolean; } { const sessionId = tab.id; - const model = useLanguageModel(); + const model = useLanguageModel("enhance"); const { conn: llmConn } = useLLMConnection(); const { updateSessionTabState } = useTabs(); const createEnhancedNote = useCreateEnhancedNote(); diff --git a/apps/desktop/src/hooks/useLLMConnection.ts b/apps/desktop/src/hooks/useLLMConnection.ts index abd0abfaeb..bca8951042 100644 --- a/apps/desktop/src/hooks/useLLMConnection.ts +++ b/apps/desktop/src/hooks/useLLMConnection.ts @@ -7,6 +7,7 @@ import { fetch as tauriFetch } from "@tauri-apps/plugin-http"; import { extractReasoningMiddleware, wrapLanguageModel } from "ai"; import { useMemo } from "react"; +import type { CharTask } from "@hypr/api-client"; import type { AIProviderStorage } from "@hypr/store"; import { useAuth } from "../auth"; @@ -22,7 +23,7 @@ import { } from "../components/settings/ai/shared/eligibility"; import { env } from 
"../env"; import * as settings from "../store/tinybase/store/settings"; -import { tracedFetch } from "../utils/traced-fetch"; +import { createTracedFetch, tracedFetch } from "../utils/traced-fetch"; type LanguageModelV3 = Parameters[0]["model"]; @@ -52,9 +53,12 @@ type LLMConnectionResult = { status: LLMConnectionStatus; }; -export const useLanguageModel = (): LanguageModelV3 | null => { +export const useLanguageModel = (task?: CharTask): LanguageModelV3 | null => { const { conn } = useLLMConnection(); - return useMemo(() => (conn ? createLanguageModel(conn) : null), [conn]); + return useMemo( + () => (conn ? createLanguageModel(conn, task) : null), + [conn, task], + ); }; export const useLLMConnection = (): LLMConnectionResult => { @@ -227,11 +231,14 @@ const wrapWithThinkingMiddleware = ( }); }; -const createLanguageModel = (conn: LLMConnectionInfo): LanguageModelV3 => { +const createLanguageModel = ( + conn: LLMConnectionInfo, + task?: CharTask, +): LanguageModelV3 => { switch (conn.providerId) { case "hyprnote": { const provider = createOpenRouter({ - fetch: tracedFetch, + fetch: task ? 
createTracedFetch(task) : tracedFetch, baseURL: conn.baseUrl, apiKey: conn.apiKey, }); diff --git a/apps/desktop/src/hooks/useTitleGeneration.ts b/apps/desktop/src/hooks/useTitleGeneration.ts index 1e785c4edf..a8dfe75a7a 100644 --- a/apps/desktop/src/hooks/useTitleGeneration.ts +++ b/apps/desktop/src/hooks/useTitleGeneration.ts @@ -8,7 +8,7 @@ import { useLanguageModel } from "./useLLMConnection"; export function useTitleGeneration(tab: Extract) { const sessionId = tab.id; - const model = useLanguageModel(); + const model = useLanguageModel("title"); const titleTaskId = createTaskId(sessionId, "title"); diff --git a/apps/desktop/src/utils/index.ts b/apps/desktop/src/utils/index.ts index d9f254e7c3..f6e82b5ed4 100644 --- a/apps/desktop/src/utils/index.ts +++ b/apps/desktop/src/utils/index.ts @@ -40,3 +40,4 @@ export const buildWebAppUrl = async ( export const DEFAULT_USER_ID = "00000000-0000-0000-0000-000000000000"; export const DEVICE_FINGERPRINT_HEADER = "x-device-fingerprint"; +export const CHAR_TASK_HEADER = "x-char-task"; diff --git a/apps/desktop/src/utils/traced-fetch.ts b/apps/desktop/src/utils/traced-fetch.ts index fe3842bf08..260120bd84 100644 --- a/apps/desktop/src/utils/traced-fetch.ts +++ b/apps/desktop/src/utils/traced-fetch.ts @@ -1,9 +1,10 @@ import * as Sentry from "@sentry/react"; import { fetch as tauriFetch } from "@tauri-apps/plugin-http"; +import type { CharTask } from "@hypr/api-client"; import { commands as miscCommands } from "@hypr/plugin-misc"; -import { DEVICE_FINGERPRINT_HEADER } from "."; +import { CHAR_TASK_HEADER, DEVICE_FINGERPRINT_HEADER } from "."; let cachedFingerprint: string | null = null; @@ -57,3 +58,11 @@ export const tracedFetch: typeof fetch = async (input, init) => { }, ); }; + +export function createTracedFetch(task: CharTask): typeof fetch { + return async (input, init) => { + const headers = new Headers(init?.headers); + headers.set(CHAR_TASK_HEADER, task); + return tracedFetch(input, { ...init, headers }); + }; +} diff 
--git a/crates/api-support/src/routes/mod.rs b/crates/api-support/src/routes/mod.rs index 44acbd4f62..fa891112d4 100644 --- a/crates/api-support/src/routes/mod.rs +++ b/crates/api-support/src/routes/mod.rs @@ -13,15 +13,23 @@ use crate::state::AppState; pub use feedback::{FeedbackRequest, FeedbackResponse}; pub async fn router(config: SupportConfig) -> Router { + let resolver = hypr_llm_proxy::StaticModelResolver::default() + .with_models( + hypr_llm_proxy::MODEL_KEY_DEFAULT, + vec![ + "openai/gpt-oss-120b".into(), + "moonshotai/kimi-k2-0905".into(), + ], + ) + .with_models( + hypr_llm_proxy::MODEL_KEY_TOOL_CALLING, + vec![ + "anthropic/claude-haiku-4.5".into(), + "moonshotai/kimi-k2-0905:exacto".into(), + ], + ); let llm_config = hypr_llm_proxy::LlmProxyConfig::new(&config.openrouter) - .with_models_default(vec![ - "openai/gpt-oss-120b".into(), - "moonshotai/kimi-k2-0905".into(), - ]) - .with_models_tool_calling(vec![ - "anthropic/claude-haiku-4.5".into(), - "moonshotai/kimi-k2-0905:exacto".into(), - ]); + .with_model_resolver(std::sync::Arc::new(resolver)); let llm_router = hypr_llm_proxy::router(llm_config); let state = AppState::new(config).await; diff --git a/crates/llm-proxy/Cargo.toml b/crates/llm-proxy/Cargo.toml index 28161de4ae..486dc4fddb 100644 --- a/crates/llm-proxy/Cargo.toml +++ b/crates/llm-proxy/Cargo.toml @@ -20,6 +20,7 @@ backon = { workspace = true } bytes = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } +strum = { workspace = true, features = ["derive"] } thiserror = { workspace = true } utoipa = { workspace = true } diff --git a/crates/llm-proxy/src/config.rs b/crates/llm-proxy/src/config.rs index 79bb4afe71..fe78e1b80c 100644 --- a/crates/llm-proxy/src/config.rs +++ b/crates/llm-proxy/src/config.rs @@ -3,6 +3,7 @@ use std::time::Duration; use crate::analytics::AnalyticsReporter; use crate::env::ApiKey; +use crate::model::{ModelContext, ModelResolver, StaticModelResolver}; use 
crate::provider::{OpenRouterProvider, Provider}; const DEFAULT_TIMEOUT_MS: u64 = 120_000; @@ -28,8 +29,7 @@ impl Default for RetryConfig { pub struct LlmProxyConfig { pub api_key: String, pub timeout: Duration, - pub models_tool_calling: Vec, - pub models_default: Vec, + resolver: Arc, pub analytics: Option>, pub provider: Arc, pub retry_config: RetryConfig, @@ -40,34 +40,24 @@ impl LlmProxyConfig { Self { api_key: api_key.into().0, timeout: Duration::from_millis(DEFAULT_TIMEOUT_MS), - models_tool_calling: vec![ - "anthropic/claude-haiku-4.5".into(), - "openai/gpt-oss-120b:exacto".into(), - "moonshotai/kimi-k2-0905:exacto".into(), - ], - models_default: vec![ - "anthropic/claude-sonnet-4.5".into(), - "openai/gpt-5.2-chat".into(), - "moonshotai/kimi-k2-0905".into(), - ], + resolver: Arc::new(StaticModelResolver::default()), analytics: None, provider: Arc::new(OpenRouterProvider::default()), retry_config: RetryConfig::default(), } } - pub fn with_timeout(mut self, timeout: Duration) -> Self { - self.timeout = timeout; - self + pub fn resolve(&self, ctx: &ModelContext) -> Vec { + self.resolver.resolve(ctx) } - pub fn with_models_tool_calling(mut self, models: Vec) -> Self { - self.models_tool_calling = models; + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; self } - pub fn with_models_default(mut self, models: Vec) -> Self { - self.models_default = models; + pub fn with_model_resolver(mut self, resolver: Arc) -> Self { + self.resolver = resolver; self } diff --git a/crates/llm-proxy/src/handler/mod.rs b/crates/llm-proxy/src/handler/mod.rs index e4b69fd269..8ab4b0db1d 100644 --- a/crates/llm-proxy/src/handler/mod.rs +++ b/crates/llm-proxy/src/handler/mod.rs @@ -20,6 +20,7 @@ use reqwest::Client; use crate::analytics::{AnalyticsReporter, GenerationEvent}; use crate::config::LlmProxyConfig; +use crate::model::{CharTask, ModelContext}; use crate::types::{ChatCompletionRequest, ToolChoice}; async fn report_with_cost( @@ -170,24 
+171,31 @@ where async fn completions_handler( State(state): State, analytics_ctx: AnalyticsContext, + headers: axum::http::HeaderMap, Json(request): Json, ) -> Response { let start_time = Instant::now(); + let task = headers + .get(crate::CHAR_TASK_HEADER) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()); + let needs_tool_calling = request.tools.as_ref().is_some_and(|t| !t.is_empty()) && !matches!(&request.tool_choice, Some(ToolChoice::String(s)) if s == "none"); - let models = if needs_tool_calling { - state.config.models_tool_calling.clone() - } else { - state.config.models_default.clone() + let ctx = ModelContext { + task, + needs_tool_calling, }; + let models = state.config.resolve(&ctx); let stream = request.stream.unwrap_or(false); tracing::info!( stream = %stream, has_tools = %needs_tool_calling, + task = ?task, message_count = %request.messages.len(), model_count = %models.len(), provider = %state.config.provider.name(), @@ -203,11 +211,17 @@ async fn completions_handler( } scope.set_tag("llm.stream", stream.to_string()); scope.set_tag("llm.tool_calling", needs_tool_calling.to_string()); + if let Some(t) = &task { + scope.set_tag("llm.task", t.to_string()); + } let mut ctx = BTreeMap::new(); ctx.insert("model_count".into(), models.len().into()); ctx.insert("message_count".into(), request.messages.len().into()); ctx.insert("has_tools".into(), needs_tool_calling.into()); + if let Some(t) = &task { + ctx.insert("task".into(), serde_json::Value::String(t.to_string())); + } scope.set_context("llm_request", sentry::protocol::Context::Other(ctx)); }); diff --git a/crates/llm-proxy/src/lib.rs b/crates/llm-proxy/src/lib.rs index 8645e8fd6c..25e13bd6c0 100644 --- a/crates/llm-proxy/src/lib.rs +++ b/crates/llm-proxy/src/lib.rs @@ -2,13 +2,20 @@ mod analytics; mod config; mod env; mod handler; +pub mod model; mod openapi; pub mod provider; mod types; +pub const CHAR_TASK_HEADER: &str = "x-char-task"; + pub use analytics::{AnalyticsReporter, 
GenerationEvent}; pub use config::*; pub use env::{ApiKey, Env}; pub use handler::{chat_completions_router, router}; pub use hypr_analytics::{AuthenticatedUserId, DeviceFingerprint}; +pub use model::{ + CharTask, MODEL_KEY_DEFAULT, MODEL_KEY_TOOL_CALLING, ModelContext, ModelResolver, + StaticModelResolver, +}; pub use openapi::openapi; diff --git a/crates/llm-proxy/src/model.rs b/crates/llm-proxy/src/model.rs new file mode 100644 index 0000000000..fe87ea4191 --- /dev/null +++ b/crates/llm-proxy/src/model.rs @@ -0,0 +1,214 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; + +pub const MODEL_KEY_DEFAULT: &str = "default"; +pub const MODEL_KEY_TOOL_CALLING: &str = "tool_calling"; + +#[derive( + Debug, + Clone, + Copy, + PartialEq, + Eq, + Hash, + Serialize, + Deserialize, + ToSchema, + strum::Display, + strum::EnumString, +)] +#[serde(rename_all = "lowercase")] +#[strum(serialize_all = "lowercase")] +pub enum CharTask { + Chat, + Enhance, + Title, +} + +pub struct ModelContext { + pub task: Option<CharTask>, + pub needs_tool_calling: bool, +} + +pub trait ModelResolver: Send + Sync { + fn resolve(&self, ctx: &ModelContext) -> Vec<String>; +} + +#[derive(Clone)] +pub struct StaticModelResolver { + pub(crate) models: HashMap<String, Vec<String>>, +} + +impl Default for StaticModelResolver { + fn default() -> Self { + let mut models = HashMap::new(); + + models.insert( + CharTask::Chat.to_string(), + vec![ + "anthropic/claude-haiku-4.5".into(), + "anthropic/claude-sonnet-4.6".into(), + "z-ai/glm-5".into(), + ], + ); + models.insert( + CharTask::Title.to_string(), + vec![ + "moonshotai/kimi-k2-0905".into(), + "google/gemini-2.5-flash-lite".into(), + "z-ai/glm-4.7-flash".into(), + ], + ); + models.insert( + MODEL_KEY_TOOL_CALLING.to_owned(), + vec![ + "anthropic/claude-sonnet-4.6".into(), + "anthropic/claude-haiku-4.5".into(), + "moonshotai/kimi-k2-0905:exacto".into(), + ], + ); + models.insert( + MODEL_KEY_DEFAULT.to_owned(), + vec![ + 
"anthropic/claude-sonnet-4.6".into(), + "openai/gpt-5.2-chat".into(), + "moonshotai/kimi-k2-0905".into(), + ], + ); + + Self { models } + } +} + +impl StaticModelResolver { + pub fn with_models(mut self, key: impl Into<String>, models: Vec<String>) -> Self { + self.models.insert(key.into(), models); + self + } +} + +impl ModelResolver for StaticModelResolver { + fn resolve(&self, ctx: &ModelContext) -> Vec<String> { + if let Some(models) = ctx.task.and_then(|t| self.models.get(&t.to_string())) { + return models.clone(); + } + + let key = if ctx.needs_tool_calling { + MODEL_KEY_TOOL_CALLING + } else { + MODEL_KEY_DEFAULT + }; + self.models.get(key).cloned().unwrap_or_default() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + type ResolveTestCase = ( + &'static str, + Option<CharTask>, + bool, + Option<(&'static str, Vec<&'static str>)>, + &'static [&'static str], + ); + + fn run_resolve_test( + name: &str, + resolver: StaticModelResolver, + ctx: ModelContext, + expected: &[&str], + ) { + let models = resolver.resolve(&ctx); + let expected: Vec<String> = expected.iter().map(|s| (*s).to_string()).collect(); + assert_eq!(models, expected, "{name}"); + } + + #[test] + fn resolve() { + let cases: &[ResolveTestCase] = &[ + ( + "by_task", + Some(CharTask::Chat), + false, + None, + &[ + "anthropic/claude-haiku-4.5", + "anthropic/claude-sonnet-4.6", + "z-ai/glm-5", + ], + ), + ( + "by_tool_calling", + None, + true, + None, + &[ + "anthropic/claude-sonnet-4.6", + "anthropic/claude-haiku-4.5", + "moonshotai/kimi-k2-0905:exacto", + ], + ), + ( + "default", + None, + false, + None, + &[ + "anthropic/claude-sonnet-4.6", + "openai/gpt-5.2-chat", + "moonshotai/kimi-k2-0905", + ], + ), + ( + "task_overrides_tool_calling", + Some(CharTask::Chat), + true, + None, + &[ + "anthropic/claude-haiku-4.5", + "anthropic/claude-sonnet-4.6", + "z-ai/glm-5", + ], + ), + ( + "with_models_custom_key", + Some(CharTask::Enhance), + false, + Some(("enhance", vec!["foo/bar"])), + &["foo/bar"], + ), + ( + 
"unknown_task_falls_back_to_default", + Some(CharTask::Enhance), + false, + None, + &[ + "anthropic/claude-sonnet-4.6", + "openai/gpt-5.2-chat", + "moonshotai/kimi-k2-0905", + ], + ), + ]; + + for (name, task, needs_tool_calling, with_models, expected) in cases { + let mut resolver = StaticModelResolver::default(); + if let Some((key, models)) = with_models { + resolver = + resolver.with_models(*key, models.iter().map(|s| (*s).to_string()).collect()); + } + run_resolve_test( + name, + resolver, + ModelContext { + task: *task, + needs_tool_calling: *needs_tool_calling, + }, + expected, + ); + } + } +} diff --git a/crates/llm-proxy/src/openapi.rs b/crates/llm-proxy/src/openapi.rs index d69c6b19ea..71fa229f99 100644 --- a/crates/llm-proxy/src/openapi.rs +++ b/crates/llm-proxy/src/openapi.rs @@ -1,9 +1,29 @@ use utoipa::OpenApi; +use crate::model::CharTask; + +#[utoipa::path( + post, + path = "/llm/chat/completions", + operation_id = "llm_chat_completions", + params( + ("x-char-task" = Option<CharTask>, Header, description = "Task type for model selection"), + ), + responses( + (status = 200, description = "Chat completion response (streaming or non-streaming)"), + (status = 401, description = "Unauthorized"), + (status = 429, description = "Rate limit exceeded"), + (status = 502, description = "Upstream provider failed"), + (status = 504, description = "Request timeout"), + ), + tag = "llm", +)] +async fn _chat_completions_spec() {} + #[derive(OpenApi)] #[openapi( - paths(), - components(schemas()), + paths(_chat_completions_spec), + components(schemas(CharTask)), tags((name = "llm", description = "LLM chat completions proxy")) )] pub struct ApiDoc; diff --git a/crates/llm-proxy/tests/common/analytics.rs b/crates/llm-proxy/tests/common/analytics.rs new file mode 100644 index 0000000000..b4adfb0e53 --- /dev/null +++ b/crates/llm-proxy/tests/common/analytics.rs @@ -0,0 +1,41 @@ +use std::sync::{Arc, Mutex}; + +use llm_proxy::{AnalyticsReporter, GenerationEvent}; + 
+#[derive(Default, Clone)] +pub struct MockAnalytics { + events: Arc<Mutex<Vec<GenerationEvent>>>, +} + +impl AnalyticsReporter for MockAnalytics { + fn report_generation( + &self, + event: GenerationEvent, + ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + '_>> { + let events = self.events.clone(); + Box::pin(async move { + events.lock().unwrap().push(event); + }) + } +} + +impl MockAnalytics { + pub fn captured_events(&self) -> Vec<GenerationEvent> { + self.events.lock().unwrap().clone() + } +} + +pub fn simple_message(content: &str) -> serde_json::Value { + serde_json::json!({ + "messages": [{"role": "user", "content": content}], + "max_tokens": 10 + }) +} + +pub fn stream_request(content: &str) -> serde_json::Value { + serde_json::json!({ + "messages": [{"role": "user", "content": content}], + "stream": true, + "max_tokens": 10 + }) +} diff --git a/crates/llm-proxy/tests/utils.rs b/crates/llm-proxy/tests/common/harness.rs similarity index 81% rename from crates/llm-proxy/tests/utils.rs rename to crates/llm-proxy/tests/common/harness.rs index a301bdbe07..d9318da690 100644 --- a/crates/llm-proxy/tests/utils.rs +++ b/crates/llm-proxy/tests/common/harness.rs @@ -1,34 +1,17 @@ -use std::sync::{Arc, Mutex}; +#![allow(dead_code)] + +use std::sync::Arc; use axum::body::Body; use axum::http::Request; use llm_proxy::provider::OpenRouterProvider; -use llm_proxy::{AnalyticsReporter, GenerationEvent, LlmProxyConfig}; +use llm_proxy::{GenerationEvent, LlmProxyConfig, MODEL_KEY_DEFAULT, StaticModelResolver}; use wiremock::matchers::{header, method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; -#[derive(Default, Clone)] -pub struct MockAnalytics { - events: Arc<Mutex<Vec<GenerationEvent>>>, -} - -impl AnalyticsReporter for MockAnalytics { - fn report_generation( - &self, - event: GenerationEvent, - ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + '_>> { - let events = self.events.clone(); - Box::pin(async move { - 
events.lock().unwrap().push(event); - }) - } -} +use super::analytics::MockAnalytics; impl MockAnalytics { - pub fn captured_events(&self) -> Vec<GenerationEvent> { - 
self.events.lock().unwrap().clone() - } - pub async fn get_single_event(&self) -> GenerationEvent { let timeout = std::time::Duration::from_secs(10); let poll_interval = std::time::Duration::from_millis(50); @@ -64,16 +47,20 @@ impl TestHarness { } pub fn config(&self) -> LlmProxyConfig { + let resolver = StaticModelResolver::default() + .with_models(MODEL_KEY_DEFAULT, vec!["openai/gpt-4.1-nano".into()]); LlmProxyConfig::new("test-api-key") .with_provider(Arc::new(OpenRouterProvider::new(self.mock_server.uri()))) - .with_models_default(vec!["openai/gpt-4.1-nano".into()]) + .with_model_resolver(Arc::new(resolver)) .with_analytics(Arc::new(self.analytics.clone())) } pub fn config_no_analytics(&self) -> LlmProxyConfig { + let resolver = StaticModelResolver::default() + .with_models(MODEL_KEY_DEFAULT, vec!["openai/gpt-4.1-nano".into()]); LlmProxyConfig::new("test-api-key") .with_provider(Arc::new(OpenRouterProvider::new(self.mock_server.uri()))) - .with_models_default(vec!["openai/gpt-4.1-nano".into()]) + .with_model_resolver(Arc::new(resolver)) } pub async fn mount_json_response(&self, response: serde_json::Value) { @@ -134,21 +121,6 @@ pub fn build_request(body: serde_json::Value) -> Request { .unwrap() } -pub fn simple_message(content: &str) -> serde_json::Value { - serde_json::json!({ - "messages": [{"role": "user", "content": content}], - "max_tokens": 10 - }) -} - -pub fn stream_request(content: &str) -> serde_json::Value { - serde_json::json!({ - "messages": [{"role": "user", "content": content}], - "stream": true, - "max_tokens": 10 - }) -} - pub fn stream_chunks(id: &str) -> [String; 4] { [ format!( diff --git a/crates/llm-proxy/tests/common/mod.rs b/crates/llm-proxy/tests/common/mod.rs new file mode 100644 index 0000000000..3a0701da97 --- /dev/null +++ b/crates/llm-proxy/tests/common/mod.rs @@ -0,0 +1,2 @@ +pub mod analytics; +pub mod harness; diff --git a/crates/llm-proxy/tests/proxy_e2e.rs b/crates/llm-proxy/tests/proxy_e2e.rs index ca840f1a1b..6db1989182 
100644 --- a/crates/llm-proxy/tests/proxy_e2e.rs +++ b/crates/llm-proxy/tests/proxy_e2e.rs @@ -1,12 +1,12 @@ -mod utils; +mod common; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use axum::http::StatusCode; -use llm_proxy::{LlmProxyConfig, router}; -use utils::*; +use common::analytics::{MockAnalytics, simple_message, stream_request}; +use llm_proxy::{LlmProxyConfig, MODEL_KEY_DEFAULT, StaticModelResolver, router}; async fn start_server(config: LlmProxyConfig) -> SocketAddr { let app = router(config); @@ -23,11 +23,15 @@ async fn start_server(config: LlmProxyConfig) -> SocketAddr { fn real_config(analytics: MockAnalytics) -> LlmProxyConfig { let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY must be set"); - LlmProxyConfig::new(api_key) - .with_models_default(vec![ + let resolver = StaticModelResolver::default().with_models( + MODEL_KEY_DEFAULT, + vec![ "moonshotai/kimi-k2-0905".into(), "anthropic/claude-haiku-4.5".into(), - ]) + ], + ); + LlmProxyConfig::new(api_key) + .with_model_resolver(Arc::new(resolver)) .with_analytics(Arc::new(analytics)) } diff --git a/crates/llm-proxy/tests/replay.rs b/crates/llm-proxy/tests/replay.rs index 686e6ab847..3ae99ed8df 100644 --- a/crates/llm-proxy/tests/replay.rs +++ b/crates/llm-proxy/tests/replay.rs @@ -1,10 +1,12 @@ -mod utils; -use utils::*; +mod common; + +use common::analytics::*; +use common::harness::*; use std::sync::Arc; use axum::http::StatusCode; -use llm_proxy::{LlmProxyConfig, router}; +use llm_proxy::{LlmProxyConfig, MODEL_KEY_DEFAULT, StaticModelResolver, router}; use tower::ServiceExt; mod basic { @@ -152,8 +154,10 @@ mod e2e { fn real_config(analytics: MockAnalytics) -> LlmProxyConfig { let api_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY must be set"); + let resolver = StaticModelResolver::default() + .with_models(MODEL_KEY_DEFAULT, vec!["openai/gpt-4.1-nano".into()]); LlmProxyConfig::new(api_key) - 
.with_models_default(vec!["openai/gpt-4.1-nano".into()]) + .with_model_resolver(Arc::new(resolver)) .with_analytics(Arc::new(analytics)) } diff --git a/crates/transcribe-cactus/src/service/batch.rs b/crates/transcribe-cactus/src/service/batch.rs index c38ce514ef..bbe6340cb3 100644 --- a/crates/transcribe-cactus/src/service/batch.rs +++ b/crates/transcribe-cactus/src/service/batch.rs @@ -176,6 +176,8 @@ fn content_type_to_extension(content_type: &str) -> &'static str { mod tests { use std::path::Path; + use hypr_language::ISO639; + use super::*; #[test] @@ -301,13 +303,13 @@ mod tests { let wav_bytes = std::fs::read(hypr_data::english_1::AUDIO_PATH) .unwrap_or_else(|e| panic!("failed to read fixture wav: {e}")); - let response = transcribe_batch( - &wav_bytes, - "audio/wav", - &ListenParams::default(), - model_path, - ) - .unwrap_or_else(|e| panic!("real-model batch transcription failed: {e}")); + let params = ListenParams { + languages: vec![ISO639::En.into()], + ..Default::default() + }; + + let response = transcribe_batch(&wav_bytes, "audio/wav", &params, model_path) + .unwrap_or_else(|e| panic!("real-model batch transcription failed: {e}")); let Some(channel) = response.results.channels.first() else { panic!("expected at least one channel in response"); diff --git a/crates/transcribe-cactus/tests/batch.rs b/crates/transcribe-cactus/tests/batch.rs index e0f5cd93a2..0504f1466e 100644 --- a/crates/transcribe-cactus/tests/batch.rs +++ b/crates/transcribe-cactus/tests/batch.rs @@ -66,7 +66,10 @@ fn e2e_batch() { let wav_bytes = audio_wav_bytes(); - let url = format!("http://{}/v1/listen?channels=1&sample_rate=16000&language=en", addr); + let url = format!( + "http://{}/v1/listen?channels=1&sample_rate=16000&language=en", + addr + ); let client = reqwest::Client::new(); let response = client .post(&url) diff --git a/packages/api-client/src/generated/index.ts b/packages/api-client/src/generated/index.ts index ec7961cd67..5030bf8792 100644 --- 
a/packages/api-client/src/generated/index.ts +++ b/packages/api-client/src/generated/index.ts @@ -1,4 +1,4 @@ // This file is auto-generated by @hey-api/openapi-ts -export { canStartTrial, createConnectSession, createContact, createConversation, createEvent, getMessages, listCalendars, listConnections, listConversations, listEvents, nangoWebhook, type Options, sendMessage, startTrial, sttListenBatch, sttListenStream, sttStatus, submit } from './sdk.gen'; -export type { BatchAlternatives, BatchChannel, BatchResponse, BatchResults, BatchWord, CanStartTrialData, CanStartTrialErrors, CanStartTrialReason, CanStartTrialResponse, CanStartTrialResponse2, CanStartTrialResponses, ClientOptions, ConnectionItem, ConnectSessionResponse, ConversationSummary, CreateConnectSessionData, CreateConnectSessionErrors, CreateConnectSessionResponse, CreateConnectSessionResponses, CreateContactData, CreateContactErrors, CreateContactRequest, CreateContactResponse, CreateContactResponse2, CreateContactResponses, CreateConversationData, CreateConversationErrors, CreateConversationRequest, CreateConversationResponse, CreateConversationResponse2, CreateConversationResponses, CreateEventData, CreateEventErrors, CreateEventRequest, CreateEventResponse, CreateEventResponse2, CreateEventResponses, DeviceInfo, EventAttendee, EventDateTime, FeedbackRequest, FeedbackResponse, FeedbackType, GetMessagesData, GetMessagesErrors, GetMessagesResponse, GetMessagesResponses, Interval, ListCalendarsData, ListCalendarsErrors, ListCalendarsResponse, ListCalendarsResponse2, ListCalendarsResponses, ListConnectionsData, ListConnectionsErrors, ListConnectionsResponse, ListConnectionsResponse2, ListConnectionsResponses, ListConversationsData, ListConversationsErrors, ListConversationsQuery, ListConversationsResponse, ListConversationsResponses, ListenCallbackRequest, ListenCallbackResponse, ListEventsData, ListEventsErrors, ListEventsRequest, ListEventsResponse, ListEventsResponse2, ListEventsResponses, 
MessageResponse, NangoWebhookData, NangoWebhookErrors, NangoWebhookResponse, NangoWebhookResponses, PipelineStatus, SendMessageData, SendMessageErrors, SendMessageRequest, SendMessageResponse, SendMessageResponses, StartTrialData, StartTrialErrors, StartTrialReason, StartTrialResponse, StartTrialResponse2, StartTrialResponses, StreamAlternatives, StreamChannel, StreamMetadata, StreamModelInfo, StreamResponse, StreamWord, SttListenBatchData, SttListenBatchErrors, SttListenBatchResponse, SttListenBatchResponses, SttListenStreamData, SttListenStreamErrors, SttStatusData, SttStatusErrors, SttStatusResponse, SttStatusResponse2, SttStatusResponses, SubmitData, SubmitError, SubmitErrors, SubmitResponse, SubmitResponses, WebhookResponse } from './types.gen'; +export { canStartTrial, createConnectSession, createContact, createConversation, createEvent, getMessages, listCalendars, listConnections, listConversations, listEvents, llmChatCompletions, nangoWebhook, type Options, sendMessage, startTrial, sttListenBatch, sttListenStream, sttStatus, submit } from './sdk.gen'; +export type { BatchAlternatives, BatchChannel, BatchResponse, BatchResults, BatchWord, CanStartTrialData, CanStartTrialErrors, CanStartTrialReason, CanStartTrialResponse, CanStartTrialResponse2, CanStartTrialResponses, CharTask, ClientOptions, ConnectionItem, ConnectSessionResponse, ConversationSummary, CreateConnectSessionData, CreateConnectSessionErrors, CreateConnectSessionResponse, CreateConnectSessionResponses, CreateContactData, CreateContactErrors, CreateContactRequest, CreateContactResponse, CreateContactResponse2, CreateContactResponses, CreateConversationData, CreateConversationErrors, CreateConversationRequest, CreateConversationResponse, CreateConversationResponse2, CreateConversationResponses, CreateEventData, CreateEventErrors, CreateEventRequest, CreateEventResponse, CreateEventResponse2, CreateEventResponses, DeviceInfo, EventAttendee, EventDateTime, FeedbackRequest, FeedbackResponse, 
FeedbackType, GetMessagesData, GetMessagesErrors, GetMessagesResponse, GetMessagesResponses, Interval, ListCalendarsData, ListCalendarsErrors, ListCalendarsResponse, ListCalendarsResponse2, ListCalendarsResponses, ListConnectionsData, ListConnectionsErrors, ListConnectionsResponse, ListConnectionsResponse2, ListConnectionsResponses, ListConversationsData, ListConversationsErrors, ListConversationsQuery, ListConversationsResponse, ListConversationsResponses, ListenCallbackRequest, ListenCallbackResponse, ListEventsData, ListEventsErrors, ListEventsRequest, ListEventsResponse, ListEventsResponse2, ListEventsResponses, LlmChatCompletionsData, LlmChatCompletionsErrors, LlmChatCompletionsResponses, MessageResponse, NangoWebhookData, NangoWebhookErrors, NangoWebhookResponse, NangoWebhookResponses, PipelineStatus, SendMessageData, SendMessageErrors, SendMessageRequest, SendMessageResponse, SendMessageResponses, StartTrialData, StartTrialErrors, StartTrialReason, StartTrialResponse, StartTrialResponse2, StartTrialResponses, StreamAlternatives, StreamChannel, StreamMetadata, StreamModelInfo, StreamResponse, StreamWord, SttListenBatchData, SttListenBatchErrors, SttListenBatchResponse, SttListenBatchResponses, SttListenStreamData, SttListenStreamErrors, SttStatusData, SttStatusErrors, SttStatusResponse, SttStatusResponse2, SttStatusResponses, SubmitData, SubmitError, SubmitErrors, SubmitResponse, SubmitResponses, WebhookResponse } from './types.gen'; diff --git a/packages/api-client/src/generated/sdk.gen.ts b/packages/api-client/src/generated/sdk.gen.ts index 500d1f9bb0..ff92c45a7e 100644 --- a/packages/api-client/src/generated/sdk.gen.ts +++ b/packages/api-client/src/generated/sdk.gen.ts @@ -2,7 +2,7 @@ import type { Client, Options as Options2, TDataShape } from './client'; import { client } from './client.gen'; -import type { CanStartTrialData, CanStartTrialErrors, CanStartTrialResponses, CreateConnectSessionData, CreateConnectSessionErrors, CreateConnectSessionResponses, 
CreateContactData, CreateContactErrors, CreateContactResponses, CreateConversationData, CreateConversationErrors, CreateConversationResponses, CreateEventData, CreateEventErrors, CreateEventResponses, GetMessagesData, GetMessagesErrors, GetMessagesResponses, ListCalendarsData, ListCalendarsErrors, ListCalendarsResponses, ListConnectionsData, ListConnectionsErrors, ListConnectionsResponses, ListConversationsData, ListConversationsErrors, ListConversationsResponses, ListEventsData, ListEventsErrors, ListEventsResponses, NangoWebhookData, NangoWebhookErrors, NangoWebhookResponses, SendMessageData, SendMessageErrors, SendMessageResponses, StartTrialData, StartTrialErrors, StartTrialResponses, SttListenBatchData, SttListenBatchErrors, SttListenBatchResponses, SttListenStreamData, SttListenStreamErrors, SttStatusData, SttStatusErrors, SttStatusResponses, SubmitData, SubmitErrors, SubmitResponses } from './types.gen'; +import type { CanStartTrialData, CanStartTrialErrors, CanStartTrialResponses, CreateConnectSessionData, CreateConnectSessionErrors, CreateConnectSessionResponses, CreateContactData, CreateContactErrors, CreateContactResponses, CreateConversationData, CreateConversationErrors, CreateConversationResponses, CreateEventData, CreateEventErrors, CreateEventResponses, GetMessagesData, GetMessagesErrors, GetMessagesResponses, ListCalendarsData, ListCalendarsErrors, ListCalendarsResponses, ListConnectionsData, ListConnectionsErrors, ListConnectionsResponses, ListConversationsData, ListConversationsErrors, ListConversationsResponses, ListEventsData, ListEventsErrors, ListEventsResponses, LlmChatCompletionsData, LlmChatCompletionsErrors, LlmChatCompletionsResponses, NangoWebhookData, NangoWebhookErrors, NangoWebhookResponses, SendMessageData, SendMessageErrors, SendMessageResponses, StartTrialData, StartTrialErrors, StartTrialResponses, SttListenBatchData, SttListenBatchErrors, SttListenBatchResponses, SttListenStreamData, SttListenStreamErrors, SttStatusData, 
SttStatusErrors, SttStatusResponses, SubmitData, SubmitErrors, SubmitResponses } from './types.gen'; export type Options = Options2 & { /** @@ -53,6 +53,8 @@ export const submit = (options: Options(options?: Options) => (options?.client ?? client).post({ url: '/llm/chat/completions', ...options }); + export const createConnectSession = (options?: Options) => (options?.client ?? client).post({ security: [{ scheme: 'bearer', type: 'http' }], url: '/nango/connect-session', diff --git a/packages/api-client/src/generated/types.gen.ts b/packages/api-client/src/generated/types.gen.ts index 862858721c..4ab4681119 100644 --- a/packages/api-client/src/generated/types.gen.ts +++ b/packages/api-client/src/generated/types.gen.ts @@ -41,6 +41,8 @@ export type CanStartTrialResponse = { reason?: null | CanStartTrialReason; }; +export type CharTask = 'chat' | 'enhance' | 'title'; + export type ConnectSessionResponse = { expires_at: string; token: string; @@ -377,6 +379,45 @@ export type SubmitResponses = { export type SubmitResponse = SubmitResponses[keyof SubmitResponses]; +export type LlmChatCompletionsData = { + body?: never; + headers?: { + /** + * Task type for model selection + */ + 'x-char-task'?: null | CharTask; + }; + path?: never; + query?: never; + url: '/llm/chat/completions'; +}; + +export type LlmChatCompletionsErrors = { + /** + * Unauthorized + */ + 401: unknown; + /** + * Rate limit exceeded + */ + 429: unknown; + /** + * Upstream provider failed + */ + 502: unknown; + /** + * Request timeout + */ + 504: unknown; +}; + +export type LlmChatCompletionsResponses = { + /** + * Chat completion response (streaming or non-streaming) + */ + 200: unknown; +}; + export type CreateConnectSessionData = { body?: never; path?: never;