diff --git a/crates/common/src/routing.rs b/crates/common/src/routing.rs index 918435a64..1f167a671 100644 --- a/crates/common/src/routing.rs +++ b/crates/common/src/routing.rs @@ -40,15 +40,70 @@ pub fn get_llm_provider( let mut rng = thread_rng(); llm_providers .iter() - .filter(|(_, provider)| { - provider - .model - .as_ref() - .map(|m| !m.starts_with("Arch")) - .unwrap_or(true) - }) + .filter(|(_, provider)| provider.internal != Some(true)) .choose(&mut rng) - .expect("There should always be at least one non-Arch llm provider") + .expect("There should always be at least one non-internal llm provider") .1 .clone() } + +#[cfg(test)] +mod tests { + use super::*; + use crate::configuration::LlmProviderType; + + #[test] + fn test_get_llm_provider_excludes_internal_providers() { + let providers = vec![ + LlmProvider { + name: "openai-gpt4".to_string(), + provider_interface: LlmProviderType::OpenAI, + model: Some("gpt-4".to_string()), + internal: None, + default: None, + ..Default::default() + }, + LlmProvider { + name: "anthropic-claude".to_string(), + provider_interface: LlmProviderType::Anthropic, + model: Some("claude-3".to_string()), + internal: Some(false), + default: None, + ..Default::default() + }, + LlmProvider { + name: "arch-router".to_string(), + provider_interface: LlmProviderType::Arch, + model: Some("Arch-Router".to_string()), + internal: Some(true), + default: None, + ..Default::default() + }, + LlmProvider { + name: "plano-orchestrator".to_string(), + provider_interface: LlmProviderType::Arch, + model: Some("Plano-Orchestrator".to_string()), + internal: Some(true), + default: None, + ..Default::default() + }, + ]; + + let llm_providers = LlmProviders::try_from(providers).unwrap(); + + // Test multiple times to account for randomness + for _ in 0..10 { + let selected = get_llm_provider(&llm_providers, None); + + // Verify the selected provider is never internal + assert_ne!(selected.internal, Some(true)); + + // Verify it's one of the non-internal providers + assert!( + selected.name == "openai-gpt4" || selected.name == "anthropic-claude", + "Selected provider '{}' should be one of the non-internal providers", + selected.name + ); + } + } +}