diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 18ca8541e..2e9ec2f4a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,5 +1,6 @@ # Contribution +We would love feedback on our [Roadmap](https://github.com/orgs/katanemo/projects/1) and we welcome contributions to **Plano**! We would love feedback on our [Roadmap](https://github.com/orgs/katanemo/projects/1) and we welcome contributions to **Plano**! Whether you're fixing bugs, adding new features, improving documentation, or creating tutorials, your help is much appreciated. @@ -7,8 +8,10 @@ Whether you're fixing bugs, adding new features, improving documentation, or cre ### 1. Fork the Repository +Fork the repository to create your own version of **Plano**: Fork the repository to create your own version of **Plano**: +- Navigate to the [Plano GitHub repository](https://github.com/katanemo/plano). - Navigate to the [Plano GitHub repository](https://github.com/katanemo/plano). - Click the "Fork" button in the upper right corner. - This will create a copy of the repository under your GitHub account. @@ -75,7 +78,8 @@ This creates a virtual environment in `.venv` and installs all dependencies. Optionally, install planoai globally in editable mode: ```bash -$ uv tool install --editable . +$ git clone https://github.com/katanemo/plano.git +$ cd plano ``` Now you can use `planoai` commands from anywhere, or use `uv run planoai` from the `cli` directory. @@ -123,19 +127,13 @@ $ uv run pytest -v Before committing, you can run all pre-commit checks manually: ```bash -$ pre-commit run --all-files +cd plano +cargo test ``` -This ensures your code passes all checks before you commit. - -### 9. Push Changes and Create a Pull Request - -Once your changes are tested and committed: - -```bash -$ git push origin -``` +### 6. Push changes, and create a Pull request +Go back to the original Plano repository, and you should see a "Compare & pull request" button. Click that to submit a Pull Request (PR). In your PR description, clearly explain the changes you made and why they are necessary. Go back to the original Plano repository, and you should see a "Compare & pull request" button. Click that to submit a Pull Request (PR). In your PR description, clearly explain the changes you made and why they are necessary. We will review your pull request and provide feedback. Once approved, your contribution will be merged into the main repository! diff --git a/cli/planoai/config_generator.py b/cli/planoai/config_generator.py index 636e2d222..d6dd37f1b 100644 --- a/cli/planoai/config_generator.py +++ b/cli/planoai/config_generator.py @@ -187,11 +187,21 @@ def validate_and_render_schema(): model_name = model_provider.get("model") print("Processing model_provider: ", model_provider) - if model_name in model_name_keys: + + # Check if this is a wildcard model (provider/*) + is_wildcard = False + if "/" in model_name: + model_name_tokens = model_name.split("/") + if len(model_name_tokens) >= 2 and model_name_tokens[-1] == "*": + is_wildcard = True + + if model_name in model_name_keys and not is_wildcard: raise Exception( f"Duplicate model name {model_name}, please provide unique model name for each model_provider" ) - model_name_keys.add(model_name) + + if not is_wildcard: + model_name_keys.add(model_name) if model_provider.get("name") is None: model_provider["name"] = model_name @@ -202,7 +212,21 @@ def validate_and_render_schema(): raise Exception( f"Invalid model name {model_name}. Please provide model name in the format /." 
) - provider = model_name_tokens[0] + provider = model_name_tokens[0].strip() + + # Check if this is a wildcard (provider/*) + is_wildcard = model_name_tokens[-1].strip() == "*" + + # Validate wildcard constraints + if is_wildcard: + if model_provider.get("default", False): + raise Exception( + f"Model {model_name} is configured as default but uses wildcard (*). Default models cannot be wildcards." + ) + if model_provider.get("routing_preferences"): + raise Exception( + f"Model {model_name} has routing_preferences but uses wildcard (*). Models with routing preferences cannot be wildcards." + ) # Validate azure_openai and ollama provider requires base_url if (provider in SUPPORTED_PROVIDERS_WITH_BASE_URL) and model_provider.get( @@ -213,7 +237,9 @@ def validate_and_render_schema(): ) model_id = "/".join(model_name_tokens[1:]) - if provider not in SUPPORTED_PROVIDERS: + + # For wildcard providers, allow any provider name + if not is_wildcard and provider not in SUPPORTED_PROVIDERS: if ( model_provider.get("base_url", None) is None or model_provider.get("provider_interface", None) is None @@ -227,11 +253,13 @@ def validate_and_render_schema(): f"Please provide provider interface as part of model name {model_name} using the format /. For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' " ) - if model_id in model_name_keys: - raise Exception( - f"Duplicate model_id {model_id}, please provide unique model_id for each model_provider" - ) - model_name_keys.add(model_id) + # For wildcard models, don't add model_id to the keys since it's "*" + if not is_wildcard: + if model_id in model_name_keys: + raise Exception( + f"Duplicate model_id {model_id}, please provide unique model_id for each model_provider" + ) + model_name_keys.add(model_id) for routing_preference in model_provider.get("routing_preferences", []): if routing_preference.get("name") in model_usage_name_keys: diff --git a/cli/planoai/core.py b/cli/planoai/core.py index e680b4770..574095659 100644 --- a/cli/planoai/core.py +++ b/cli/planoai/core.py @@ -145,7 +145,7 @@ def stop_docker_container(service=PLANO_DOCKER_NAME): def start_cli_agent(arch_config_file=None, settings_json="{}"): - """Start a CLI client connected to Arch.""" + """Start a CLI client connected to Plano.""" with open(arch_config_file, "r") as file: arch_config = file.read() diff --git a/cli/planoai/main.py b/cli/planoai/main.py index a27378833..9157373ca 100644 --- a/cli/planoai/main.py +++ b/cli/planoai/main.py @@ -74,7 +74,7 @@ def main(ctx, version): log.info(f"Starting plano cli version: {get_version()}") if ctx.invoked_subcommand is None: - click.echo("""Arch (The Intelligent Prompt Gateway) CLI""") + click.echo("""Plano (AI-native proxy and dataplane for agentic apps) CLI""") click.echo(logo) click.echo(ctx.get_help()) @@ -121,16 +121,16 @@ def build(): @click.command() @click.argument("file", required=False) # Optional file argument @click.option( - "--path", default=".", help="Path to the directory containing arch_config.yaml" + "--path", default=".", help="Path to the directory containing config.yaml" ) @click.option( "--foreground", default=False, - help="Run Arch in the foreground. Default is False", + help="Run Plano in the foreground. 
Default is False", is_flag=True, ) def up(file, path, foreground): - """Starts Arch.""" + """Starts Plano.""" # Use the utility function to find config file arch_config_file = find_config_file(path, file) @@ -270,7 +270,7 @@ def logs(debug, follow): help="Additional settings as JSON string for the CLI agent.", ) def cli_agent(type, file, path, settings): - """Start a CLI agent connected to Arch. + """Start a CLI agent connected to Plano. CLI_AGENT: The type of CLI agent to start (currently only 'claude' is supported) """ @@ -278,7 +278,7 @@ def cli_agent(type, file, path, settings): # Check if plano docker container is running archgw_status = docker_container_status(PLANO_DOCKER_NAME) if archgw_status != "running": - log.error(f"archgw docker container is not running (status: {archgw_status})") + log.error(f"plano docker container is not running (status: {archgw_status})") log.error("Please start plano using the 'planoai up' command.") sys.exit(1) diff --git a/config/README.md b/config/README.md index 49c62797c..5eb59fac2 100644 --- a/config/README.md +++ b/config/README.md @@ -18,7 +18,7 @@ $ cargo test ``` ## Local development -- Build docker image for arch gateway. Note this needs to be built once. +- Build docker image for Plano. Note this needs to be built once. ``` $ sh build_filter_image.sh ``` @@ -27,9 +27,9 @@ $ cargo test ``` $ cargo build --target wasm32-wasip1 --release ``` -- Start envoy with arch_config.yaml and test, +- Start envoy with config.yaml and test, ``` - $ docker compose -f docker-compose.dev.yaml up archgw + $ docker compose -f docker-compose.dev.yaml up plano ``` - dev version of docker-compose file uses following files that are mounted inside the container. That means no docker rebuild is needed if any of these files change. 
Just restart the container and chagne will be picked up, - envoy.template.yaml diff --git a/crates/Cargo.lock b/crates/Cargo.lock index c1ac44976..f182439a9 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -459,6 +459,35 @@ dependencies = [ "urlencoding", ] +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "percent-encoding", + "time", + "version_check", +] + +[[package]] +name = "cookie_store" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fc4bff745c9b4c7fb1e97b25d13153da2bc7796260141df62378998d070207f" +dependencies = [ + "cookie", + "document-features", + "idna", + "indexmap 2.9.0", + "log", + "serde", + "serde_derive", + "serde_json", + "time", + "url", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -628,6 +657,15 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "document-features" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4b8a88685455ed29a21542a33abd9cb6510b6b129abadabdcef0f4c55bc8f61" +dependencies = [ + "litrs", +] + [[package]] name = "duration-string" version = "0.3.0" @@ -999,11 +1037,13 @@ version = "0.1.0" dependencies = [ "aws-smithy-eventstream", "bytes", + "chrono", "log", "serde", "serde_json", "serde_with", "thiserror 2.0.12", + "ureq", "uuid", ] @@ -1479,6 +1519,12 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" +[[package]] +name = "litrs" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" + [[package]] name = "llm_gateway" version = "0.1.0" @@ -2417,6 +2463,7 @@ version = "0.23.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321" dependencies = [ + "log", "once_cell", "ring", "rustls-pki-types", @@ -3385,6 +3432,38 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "ureq" +version = "3.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d39cb1dbab692d82a977c0392ffac19e188bd9186a9f32806f0aaa859d75585a" +dependencies = [ + "base64 0.22.1", + "cookie_store", + "flate2", + "log", + "percent-encoding", + "rustls 0.23.27", + "rustls-pki-types", + "serde", + "serde_json", + "ureq-proto", + "utf-8", + "webpki-roots", +] + +[[package]] +name = "ureq-proto" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f" +dependencies = [ + "base64 0.22.1", + "http 1.3.1", + "httparse", + "log", +] + [[package]] name = "url" version = "2.5.4" @@ -3402,6 +3481,12 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -3578,6 +3663,15 @@ dependencies = 
[ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12bed680863276c63889429bfd6cab3b99943659923822de1c8a39c49e4d722c" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "whoami" version = "1.6.1" diff --git a/crates/brightstaff/src/handlers/llm.rs b/crates/brightstaff/src/handlers/llm.rs index 6e78f8b00..bd77f9482 100644 --- a/crates/brightstaff/src/handlers/llm.rs +++ b/crates/brightstaff/src/handlers/llm.rs @@ -123,6 +123,14 @@ pub async fn llm_chat( let is_streaming_request = client_request.is_streaming(); let resolved_model = resolve_model_alias(&model_from_request, &model_aliases); + // Handle provider/model slug format (e.g., "openai/gpt-4") + // Extract just the model name for upstream (providers don't understand the slug) + let model_name_only = if let Some((_, model)) = resolved_model.split_once('/') { + model.to_string() + } else { + resolved_model.clone() + }; + // Extract tool names and user message preview for span attributes let tool_names = client_request.get_tool_names(); let user_message_preview = client_request @@ -132,7 +140,9 @@ pub async fn llm_chat( // Extract messages for signal analysis (clone before moving client_request) let messages_for_signals = client_request.get_messages(); - client_request.set_model(resolved_model.clone()); + // Set the model to just the model name (without provider prefix) + // This ensures upstream receives "gpt-4" not "openai/gpt-4" + client_request.set_model(model_name_only.clone()); if client_request.remove_metadata_key("archgw_preference_config") { debug!( "[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata", @@ -240,16 +250,22 @@ pub async fn llm_chat( } }; + // Use the resolved model (could be "gpt-4" or "openai/gpt-4") as the provider hint + // The routing layer will use llm_providers.get() which handles both formats: + // - "gpt-4" → looks up by model name + // - "openai/gpt-4" → looks up by provider/model slug + // If router doesn't find anything, it will use routing_result.model_name + let provider_hint_value = resolved_model.clone(); let model_name = routing_result.model_name; debug!( - "[PLANO_REQ_ID:{}] | ARCH_ROUTER URL | {}, Resolved Model: {}", - request_id, full_qualified_llm_provider_url, model_name + "[PLANO_REQ_ID:{}] | ARCH_ROUTER URL | {}, Provider Hint: {}, Model for upstream: {}", + request_id, full_qualified_llm_provider_url, provider_hint_value, model_name_only ); request_headers.insert( ARCH_PROVIDER_HINT_HEADER, - header::HeaderValue::from_str(&model_name).unwrap(), + header::HeaderValue::from_str(&provider_hint_value).unwrap(), ); request_headers.insert( diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 60fd20d04..c600ed5dc 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -255,7 +255,8 @@ impl LlmProviderType { /// Get the ProviderId for this LlmProviderType /// Used with the new function-based hermesllm API pub fn to_provider_id(&self) -> hermesllm::ProviderId { - hermesllm::ProviderId::from(self.to_string().as_str()) + hermesllm::ProviderId::try_from(self.to_string().as_str()) + .expect("LlmProviderType should always map to a valid ProviderId") } } diff --git a/crates/common/src/llm_providers.rs b/crates/common/src/llm_providers.rs index a3ae93a07..8a26f377a 100644 --- a/crates/common/src/llm_providers.rs +++ b/crates/common/src/llm_providers.rs @@ -1,4 +1,5 @@ use crate::configuration::LlmProvider; 
+use hermesllm::providers::ProviderId; use std::collections::HashMap; use std::rc::Rc; @@ -6,6 +7,9 @@ pub struct LlmProviders { providers: HashMap<String, Rc<LlmProvider>>, default: Option<Rc<LlmProvider>>, + /// Wildcard providers: maps provider prefix to base provider config + /// e.g., "openai" -> LlmProvider for "openai/*" + wildcard_providers: HashMap<String, Rc<LlmProvider>>, } impl LlmProviders { @@ -18,7 +22,36 @@ } pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> { - self.providers.get(name).cloned() + // First try exact match + if let Some(provider) = self.providers.get(name).cloned() { + return Some(provider); + } + + // If name contains '/', it could be: + // 1. A full model ID like "openai/gpt-4" that we need to lookup + // 2. A provider/model slug that should match a wildcard provider + if let Some((provider_prefix, model_name)) = name.split_once('/') { + // Try to find the expanded model entry (e.g., "openai/gpt-4") + let full_model_id = format!("{}/{}", provider_prefix, model_name); + if let Some(provider) = self.providers.get(&full_model_id).cloned() { + return Some(provider); + } + + // Try to find just the model name (for expanded wildcard entries) + if let Some(provider) = self.providers.get(model_name).cloned() { + return Some(provider); + } + + // Fall back to wildcard match (e.g., "openai/*") + if let Some(wildcard_provider) = self.wildcard_providers.get(provider_prefix) { + // Create a new provider with the specific model from the slug + let mut specific_provider = (**wildcard_provider).clone(); + specific_provider.model = Some(model_name.to_string()); + return Some(Rc::new(specific_provider)); + } + } + + None } } @@ -43,10 +76,12 @@ impl TryFrom<Vec<LlmProvider>> for LlmProviders { let mut llm_providers = LlmProviders { providers: HashMap::new(), default: None, + wildcard_providers: HashMap::new(), }; for llm_provider in llm_providers_config { let llm_provider: Rc<LlmProvider> = Rc::new(llm_provider); + if llm_provider.default.unwrap_or_default() { match llm_providers.default { Some(_) => return Err(LlmProvidersNewError::MoreThanOneDefault), @@ -54,27 +89,168 @@ } } - // Insert and check that there is no other provider with the same name. let name = llm_provider.name.clone(); - if llm_providers - .providers - .insert(name.clone(), Rc::clone(&llm_provider)) - .is_some() - { - return Err(LlmProvidersNewError::DuplicateName(name)); - } - // also add model_id as key for provider lookup - if let Some(model) = llm_provider.model.clone() { + // Check if this is a wildcard provider (model is "*" or ends with "/*") + let is_wildcard = llm_provider + .model + .as_ref() + .map(|m| m == "*" || m.ends_with("/*")) + .unwrap_or(false); + + if is_wildcard { + // Extract provider prefix from name + // e.g., "openai/*" -> "openai" + let provider_prefix = name.trim_end_matches("/*").trim_end_matches('*'); + + // For wildcard providers, we: + // 1. Store the base config in wildcard_providers for runtime matching + // 2.
Optionally expand to all known models if available + + llm_providers + .wildcard_providers + .insert(provider_prefix.to_string(), Rc::clone(&llm_provider)); + + // Try to expand wildcard using ProviderId models + if let Ok(provider_id) = ProviderId::try_from(provider_prefix) { + let models = provider_id.models(); + if !models.is_empty() { + log::info!( + "Expanding wildcard provider '{}' to {} models", + provider_prefix, + models.len() + ); + + // Create a provider entry for each model + for model_name in models { + let full_model_id = format!("{}/{}", provider_prefix, model_name); + + // Create a new provider with the specific model + let mut expanded_provider = (*llm_provider).clone(); + expanded_provider.model = Some(model_name.clone()); + expanded_provider.name = full_model_id.clone(); + + let expanded_rc = Rc::new(expanded_provider); + + // Insert with full model ID as key + llm_providers + .providers + .insert(full_model_id.clone(), Rc::clone(&expanded_rc)); + + // Also insert with just model name for backward compatibility + llm_providers.providers.insert(model_name, expanded_rc); + } + } + } else { + log::warn!( + "Wildcard provider '{}' specified but no models found in registry. \ + Will match dynamically at runtime.", + provider_prefix + ); + } + } else { + // Non-wildcard provider - original behavior if llm_providers .providers - .insert(model, llm_provider) + .insert(name.clone(), Rc::clone(&llm_provider)) .is_some() { return Err(LlmProvidersNewError::DuplicateName(name)); } + + // also add model_id as key for provider lookup + if let Some(model) = llm_provider.model.clone() { + if llm_providers + .providers + .insert(model, llm_provider) + .is_some() + { + return Err(LlmProvidersNewError::DuplicateName(name)); + } + } } } + Ok(llm_providers) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::configuration::LlmProviderType; + + fn create_test_provider(name: &str, model: Option) -> LlmProvider { + LlmProvider { + name: name.to_string(), + model, + access_key: None, + endpoint: None, + cluster_name: None, + provider_interface: LlmProviderType::OpenAI, + default: None, + base_url_path_prefix: None, + port: None, + rate_limits: None, + usage: None, + routing_preferences: None, + internal: None, + stream: None, + passthrough_auth: None, + } + } + + #[test] + fn test_static_provider_lookup() { + // Test 1: Statically defined provider - should be findable by model or provider name + let providers = vec![create_test_provider("my-openai", Some("gpt-4".to_string()))]; + let llm_providers = LlmProviders::try_from(providers).unwrap(); + + // Should find by model name + let result = llm_providers.get("gpt-4"); + assert!(result.is_some()); + assert_eq!(result.unwrap().name, "my-openai"); + + // Should also find by provider name + let result = llm_providers.get("my-openai"); + assert!(result.is_some()); + assert_eq!(result.unwrap().name, "my-openai"); + } + + #[test] + fn test_wildcard_provider_with_known_model() { + // Test 2: Wildcard provider that expands to OpenAI models + let providers = vec![create_test_provider("openai/*", Some("*".to_string()))]; + let llm_providers = LlmProviders::try_from(providers).unwrap(); + + // Should find via expanded wildcard entry + let result = llm_providers.get("openai/gpt-4"); + let provider = result.unwrap(); + assert_eq!(provider.name, "openai/gpt-4"); + assert_eq!(provider.model.as_ref().unwrap(), "gpt-4"); + + // Should also be able to find by just model name (from expansion) + let result = llm_providers.get("gpt-4"); + 
assert_eq!(result.unwrap().model.as_ref().unwrap(), "gpt-4"); + } + + #[test] + fn test_custom_wildcard_provider_with_full_slug() { + // Test 3: Custom wildcard provider with full slug offered + let providers = vec![create_test_provider( + "custom-provider/*", + Some("*".to_string()), + )]; + let llm_providers = LlmProviders::try_from(providers).unwrap(); + + // Should match via wildcard fallback and extract model name from slug + let result = llm_providers.get("custom-provider/custom-model"); + let provider = result.unwrap(); + assert_eq!(provider.model.as_ref().unwrap(), "custom-model"); + + // Wildcard should be stored + assert!(llm_providers + .wildcard_providers + .contains_key("custom-provider")); + } +} diff --git a/crates/common/src/routing.rs b/crates/common/src/routing.rs index 918435a64..2007c579c 100644 --- a/crates/common/src/routing.rs +++ b/crates/common/src/routing.rs @@ -2,9 +2,8 @@ use std::rc::Rc; use crate::{configuration, llm_providers::LlmProviders}; use configuration::LlmProvider; -use rand::{seq::IteratorRandom, thread_rng}; -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum ProviderHint { Default, Name(String), @@ -22,33 +21,19 @@ impl From<String> for ProviderHint { pub fn get_llm_provider( llm_providers: &LlmProviders, provider_hint: Option<ProviderHint>, -) -> Rc<LlmProvider> { - let maybe_provider = provider_hint.and_then(|hint| match hint { - ProviderHint::Default => llm_providers.default(), - // FIXME: should a non-existent name in the hint be more explicit? i.e, return a BAD_REQUEST? - ProviderHint::Name(name) => llm_providers.get(&name), - }); - - if let Some(provider) = maybe_provider { - return provider; - } - - if llm_providers.default().is_some() { - return llm_providers.default().unwrap(); +) -> Result<Rc<LlmProvider>, String> { + match provider_hint { + Some(ProviderHint::Default) => llm_providers + .default() + .ok_or_else(|| "No default provider configured".to_string()), + Some(ProviderHint::Name(name)) => llm_providers + .get(&name) + .ok_or_else(|| format!("Model '{}' not found in configured providers", name)), + None => { + // No hint provided - must have a default configured + llm_providers + .default() + .ok_or_else(|| "No model specified and no default provider configured".to_string()) + } } - - let mut rng = thread_rng(); - llm_providers - .iter() - .filter(|(_, provider)| { - provider - .model - .as_ref() - .map(|m| !m.starts_with("Arch")) - .unwrap_or(true) - }) - .choose(&mut rng) - .expect("There should always be at least one non-Arch llm provider") - .1 - .clone() } diff --git a/crates/hermesllm/Cargo.toml b/crates/hermesllm/Cargo.toml index d877fc00d..7bce647f2 100644 --- a/crates/hermesllm/Cargo.toml +++ b/crates/hermesllm/Cargo.toml @@ -3,6 +3,11 @@ name = "hermesllm" version = "0.1.0" edition = "2021" +[[bin]] +name = "fetch_models" +path = "src/bin/fetch_models.rs" +required-features = ["model-fetch"] + [dependencies] serde = {version = "1.0.219", features = ["derive"]} serde_json = "1.0.140" @@ -12,3 +17,9 @@ aws-smithy-eventstream = "0.60" bytes = "1.10" uuid = { version = "1.11", features = ["v4"] } log = "0.4" +chrono = { version = "0.4", optional = true } +ureq = { version = "3.1", features = ["json"], optional = true } + +[features] +default = [] +model-fetch = ["ureq", "chrono"] diff --git a/crates/hermesllm/src/bin/fetch_models.rs b/crates/hermesllm/src/bin/fetch_models.rs new file mode 100644 index 000000000..73f5e0d98 --- /dev/null +++ b/crates/hermesllm/src/bin/fetch_models.rs @@ -0,0 +1,167 @@ +// Fetch latest provider models from OpenRouter and update
provider_models.json +// Usage: OPENROUTER_API_KEY=xxx cargo run --bin fetch_models + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +fn main() { + // Default to writing in the same directory as this source file + let default_path = std::path::Path::new(file!()) + .parent() + .unwrap() + .join("provider_models.json"); + + let output_path = std::env::args() + .nth(1) + .unwrap_or_else(|| default_path.to_string_lossy().to_string()); + + println!("Fetching latest models from OpenRouter..."); + + match fetch_openrouter_models() { + Ok(models) => { + let json = serde_json::to_string_pretty(&models).expect("Failed to serialize models"); + + std::fs::write(&output_path, json).expect("Failed to write provider_models.json"); + + println!( + "✓ Successfully updated {} providers ({} models) to {}", + models.metadata.total_providers, models.metadata.total_models, output_path + ); + } + Err(e) => { + eprintln!("Error fetching models: {}", e); + eprintln!("\nMake sure OPENROUTER_API_KEY is set:"); + eprintln!(" export OPENROUTER_API_KEY=your-key-here"); + eprintln!(" cargo run --bin fetch_models"); + std::process::exit(1); + } + } +} + +#[derive(Debug, Deserialize)] +struct OpenRouterModel { + id: String, + architecture: Option<Architecture>, +} + +#[derive(Debug, Deserialize)] +struct Architecture { + modality: Option<String>, +} + +#[derive(Debug, Deserialize)] +struct OpenRouterResponse { + data: Vec<OpenRouterModel>, +} + +#[derive(Debug, Serialize)] +struct ProviderModels { + version: String, + source: String, + providers: HashMap<String, Vec<String>>, + metadata: Metadata, +} + +#[derive(Debug, Serialize)] +struct Metadata { + total_providers: usize, + total_models: usize, + last_updated: String, +} + +fn fetch_openrouter_models() -> Result<ProviderModels, Box<dyn std::error::Error>> { + let api_key = std::env::var("OPENROUTER_API_KEY") + .map_err(|_| "OPENROUTER_API_KEY environment variable not set")?; + + let response_body = ureq::get("https://openrouter.ai/api/v1/models") + .header("Authorization", &format!("Bearer {}", api_key)) + .call()?
+ .body_mut() + .read_to_string()?; + + let openrouter_response: OpenRouterResponse = serde_json::from_str(&response_body)?; + + // Supported providers to include + let supported_providers = [ + "openai", + "anthropic", + "mistralai", + "deepseek", + "google", + "x-ai", + "moonshotai", + "qwen", + "amazon", + "z-ai", + ]; + + let mut providers: HashMap> = HashMap::new(); + let mut total_models = 0; + let mut filtered_modality: Vec<(String, String)> = Vec::new(); + let mut filtered_provider: Vec<(String, Option)> = Vec::new(); + + for model in openrouter_response.data { + let modality = model + .architecture + .as_ref() + .and_then(|arch| arch.modality.clone()); + + // Only include text->text and text+image->text models + if let Some(ref mod_str) = modality { + if mod_str != "text->text" && mod_str != "text" && mod_str != "text+image->text" { + filtered_modality.push((model.id.clone(), mod_str.clone())); + continue; + } + } + + // Extract provider from model ID (e.g., "openai/gpt-4" -> "openai") + if let Some(provider_name) = model.id.split('/').next() { + if supported_providers.contains(&provider_name) { + providers + .entry(provider_name.to_string()) + .or_default() + .push(model.id.clone()); + total_models += 1; + } else { + filtered_provider.push((model.id.clone(), modality)); + } + } + } + + println!("✅ Loaded models from {} providers:", providers.len()); + let mut sorted_providers: Vec<_> = providers.iter().collect(); + sorted_providers.sort_by_key(|(name, _)| *name); + for (provider, models) in sorted_providers { + println!(" • {}: {} models", provider, models.len()); + } + + // Group filtered providers to get counts + let mut filtered_by_provider: HashMap = HashMap::new(); + for (model_id, _modality) in &filtered_provider { + if let Some(provider_name) = model_id.split('/').next() { + *filtered_by_provider + .entry(provider_name.to_string()) + .or_insert(0) += 1; + } + } + + println!( + "\n⏭️ Skipped {} providers ({} models total)", + filtered_by_provider.len(), + filtered_provider.len() + ); + println!(); + + let total_providers = providers.len(); + + Ok(ProviderModels { + version: "1.0".to_string(), + source: "openrouter".to_string(), + providers, + metadata: Metadata { + total_providers, + total_models, + last_updated: chrono::Utc::now().to_rfc3339(), + }, + }) +} diff --git a/crates/hermesllm/src/bin/provider_models.json b/crates/hermesllm/src/bin/provider_models.json new file mode 100644 index 000000000..509293bf8 --- /dev/null +++ b/crates/hermesllm/src/bin/provider_models.json @@ -0,0 +1,236 @@ +{ + "version": "1.0", + "source": "openrouter", + "providers": { + "openai": [ + "openai/gpt-5.2-codex", + "openai/gpt-5.2-chat", + "openai/gpt-5.2-pro", + "openai/gpt-5.2", + "openai/gpt-5.1-codex-max", + "openai/gpt-5.1", + "openai/gpt-5.1-chat", + "openai/gpt-5.1-codex", + "openai/gpt-5.1-codex-mini", + "openai/gpt-oss-safeguard-20b", + "openai/o3-deep-research", + "openai/o4-mini-deep-research", + "openai/gpt-5-pro", + "openai/gpt-5-codex", + "openai/gpt-4o-audio-preview", + "openai/gpt-5-chat", + "openai/gpt-5", + "openai/gpt-5-mini", + "openai/gpt-5-nano", + "openai/gpt-oss-120b:free", + "openai/gpt-oss-120b", + "openai/gpt-oss-120b:exacto", + "openai/gpt-oss-20b:free", + "openai/gpt-oss-20b", + "openai/o3-pro", + "openai/o4-mini-high", + "openai/o3", + "openai/o4-mini", + "openai/gpt-4.1", + "openai/gpt-4.1-mini", + "openai/gpt-4.1-nano", + "openai/o1-pro", + "openai/gpt-4o-mini-search-preview", + "openai/gpt-4o-search-preview", + "openai/o3-mini-high", + "openai/o3-mini", + 
"openai/o1", + "openai/gpt-4o-2024-11-20", + "openai/chatgpt-4o-latest", + "openai/gpt-4o-2024-08-06", + "openai/gpt-4o-mini-2024-07-18", + "openai/gpt-4o-mini", + "openai/gpt-4o-2024-05-13", + "openai/gpt-4o", + "openai/gpt-4o:extended", + "openai/gpt-4-turbo", + "openai/gpt-3.5-turbo-0613", + "openai/gpt-4-turbo-preview", + "openai/gpt-4-1106-preview", + "openai/gpt-3.5-turbo-instruct", + "openai/gpt-3.5-turbo-16k", + "openai/gpt-4-0314", + "openai/gpt-4", + "openai/gpt-3.5-turbo" + ], + "mistralai": [ + "mistralai/mistral-small-creative", + "mistralai/devstral-2512:free", + "mistralai/devstral-2512", + "mistralai/ministral-14b-2512", + "mistralai/ministral-8b-2512", + "mistralai/ministral-3b-2512", + "mistralai/mistral-large-2512", + "mistralai/voxtral-small-24b-2507", + "mistralai/mistral-medium-3.1", + "mistralai/codestral-2508", + "mistralai/devstral-medium", + "mistralai/devstral-small", + "mistralai/mistral-small-3.2-24b-instruct", + "mistralai/mistral-medium-3", + "mistralai/mistral-small-3.1-24b-instruct:free", + "mistralai/mistral-small-3.1-24b-instruct", + "mistralai/mistral-saba", + "mistralai/mistral-small-24b-instruct-2501", + "mistralai/mistral-large-2411", + "mistralai/mistral-large-2407", + "mistralai/pixtral-large-2411", + "mistralai/ministral-8b", + "mistralai/ministral-3b", + "mistralai/pixtral-12b", + "mistralai/mistral-nemo", + "mistralai/mistral-7b-instruct", + "mistralai/mistral-7b-instruct-v0.3", + "mistralai/mixtral-8x22b-instruct", + "mistralai/mistral-large", + "mistralai/mistral-tiny", + "mistralai/mistral-7b-instruct-v0.2", + "mistralai/mixtral-8x7b-instruct", + "mistralai/mistral-7b-instruct-v0.1" + ], + "qwen": [ + "qwen/qwen3-vl-32b-instruct", + "qwen/qwen3-vl-8b-thinking", + "qwen/qwen3-vl-8b-instruct", + "qwen/qwen3-vl-30b-a3b-thinking", + "qwen/qwen3-vl-30b-a3b-instruct", + "qwen/qwen3-vl-235b-a22b-thinking", + "qwen/qwen3-vl-235b-a22b-instruct", + "qwen/qwen3-max", + "qwen/qwen3-coder-plus", + "qwen/qwen3-coder-flash", + "qwen/qwen3-next-80b-a3b-thinking", + "qwen/qwen3-next-80b-a3b-instruct:free", + "qwen/qwen3-next-80b-a3b-instruct", + "qwen/qwen-plus-2025-07-28", + "qwen/qwen-plus-2025-07-28:thinking", + "qwen/qwen3-30b-a3b-thinking-2507", + "qwen/qwen3-coder-30b-a3b-instruct", + "qwen/qwen3-30b-a3b-instruct-2507", + "qwen/qwen3-235b-a22b-thinking-2507", + "qwen/qwen3-coder:free", + "qwen/qwen3-coder", + "qwen/qwen3-coder:exacto", + "qwen/qwen3-235b-a22b-2507", + "qwen/qwen3-4b:free", + "qwen/qwen3-30b-a3b", + "qwen/qwen3-8b", + "qwen/qwen3-14b", + "qwen/qwen3-32b", + "qwen/qwen3-235b-a22b", + "qwen/qwen2.5-coder-7b-instruct", + "qwen/qwen2.5-vl-32b-instruct", + "qwen/qwq-32b", + "qwen/qwen-vl-plus", + "qwen/qwen-vl-max", + "qwen/qwen-turbo", + "qwen/qwen2.5-vl-72b-instruct", + "qwen/qwen-plus", + "qwen/qwen-max", + "qwen/qwen-2.5-coder-32b-instruct", + "qwen/qwen-2.5-7b-instruct", + "qwen/qwen-2.5-72b-instruct", + "qwen/qwen-2.5-vl-7b-instruct:free", + "qwen/qwen-2.5-vl-7b-instruct" + ], + "z-ai": [ + "z-ai/glm-4.7", + "z-ai/glm-4.6v", + "z-ai/glm-4.6", + "z-ai/glm-4.6:exacto", + "z-ai/glm-4.5v", + "z-ai/glm-4.5", + "z-ai/glm-4.5-air:free", + "z-ai/glm-4.5-air", + "z-ai/glm-4-32b" + ], + "moonshotai": [ + "moonshotai/kimi-k2-thinking", + "moonshotai/kimi-k2-0905", + "moonshotai/kimi-k2-0905:exacto", + "moonshotai/kimi-k2:free", + "moonshotai/kimi-k2", + "moonshotai/kimi-dev-72b" + ], + "anthropic": [ + "anthropic/claude-opus-4.5", + "anthropic/claude-haiku-4.5", + "anthropic/claude-sonnet-4.5", + "anthropic/claude-opus-4.1", + 
"anthropic/claude-opus-4", + "anthropic/claude-sonnet-4", + "anthropic/claude-3.7-sonnet:thinking", + "anthropic/claude-3.7-sonnet", + "anthropic/claude-3.5-haiku", + "anthropic/claude-3.5-sonnet", + "anthropic/claude-3-haiku" + ], + "google": [ + "google/gemini-3-flash-preview", + "google/gemini-3-pro-preview", + "google/gemini-2.5-flash-preview-09-2025", + "google/gemini-2.5-flash-lite-preview-09-2025", + "google/gemini-2.5-flash-lite", + "google/gemma-3n-e2b-it:free", + "google/gemini-2.5-flash", + "google/gemini-2.5-pro", + "google/gemini-2.5-pro-preview", + "google/gemma-3n-e4b-it:free", + "google/gemma-3n-e4b-it", + "google/gemini-2.5-pro-preview-05-06", + "google/gemma-3-4b-it:free", + "google/gemma-3-4b-it", + "google/gemma-3-12b-it:free", + "google/gemma-3-12b-it", + "google/gemma-3-27b-it:free", + "google/gemma-3-27b-it", + "google/gemini-2.0-flash-lite-001", + "google/gemini-2.0-flash-001", + "google/gemini-2.0-flash-exp:free", + "google/gemma-2-27b-it", + "google/gemma-2-9b-it" + ], + "amazon": [ + "amazon/nova-2-lite-v1", + "amazon/nova-premier-v1", + "amazon/nova-lite-v1", + "amazon/nova-micro-v1", + "amazon/nova-pro-v1" + ], + "deepseek": [ + "deepseek/deepseek-v3.2-speciale", + "deepseek/deepseek-v3.2", + "deepseek/deepseek-v3.2-exp", + "deepseek/deepseek-v3.1-terminus:exacto", + "deepseek/deepseek-v3.1-terminus", + "deepseek/deepseek-chat-v3.1", + "deepseek/deepseek-r1-0528:free", + "deepseek/deepseek-r1-0528", + "deepseek/deepseek-chat-v3-0324", + "deepseek/deepseek-r1-distill-qwen-32b", + "deepseek/deepseek-r1-distill-llama-70b", + "deepseek/deepseek-r1", + "deepseek/deepseek-chat" + ], + "x-ai": [ + "x-ai/grok-4.1-fast", + "x-ai/grok-4-fast", + "x-ai/grok-code-fast-1", + "x-ai/grok-4", + "x-ai/grok-3-mini", + "x-ai/grok-3", + "x-ai/grok-3-mini-beta", + "x-ai/grok-3-beta" + ] + }, + "metadata": { + "total_providers": 10, + "total_models": 205, + "last_updated": "2026-01-16T20:30:00.806165+00:00" + } +} diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index 3f8324e96..997fc72a4 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -29,10 +29,27 @@ mod tests { #[test] fn test_provider_id_conversion() { - assert_eq!(ProviderId::from("openai"), ProviderId::OpenAI); - assert_eq!(ProviderId::from("mistral"), ProviderId::Mistral); - assert_eq!(ProviderId::from("groq"), ProviderId::Groq); - assert_eq!(ProviderId::from("arch"), ProviderId::Arch); + assert_eq!(ProviderId::try_from("openai").unwrap(), ProviderId::OpenAI); + assert_eq!( + ProviderId::try_from("mistral").unwrap(), + ProviderId::Mistral + ); + assert_eq!(ProviderId::try_from("groq").unwrap(), ProviderId::Groq); + assert_eq!(ProviderId::try_from("arch").unwrap(), ProviderId::Arch); + + // Test aliases + assert_eq!(ProviderId::try_from("google").unwrap(), ProviderId::Gemini); + assert_eq!( + ProviderId::try_from("together").unwrap(), + ProviderId::TogetherAI + ); + assert_eq!( + ProviderId::try_from("amazon").unwrap(), + ProviderId::AmazonBedrock + ); + + // Test error case + assert!(ProviderId::try_from("unknown_provider").is_err()); } #[test] diff --git a/crates/hermesllm/src/providers/id.rs b/crates/hermesllm/src/providers/id.rs index afaabea6d..902239635 100644 --- a/crates/hermesllm/src/providers/id.rs +++ b/crates/hermesllm/src/providers/id.rs @@ -1,6 +1,36 @@ use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi}; use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; +use std::collections::HashMap; use std::fmt::Display; +use 
std::sync::OnceLock; + +static PROVIDER_MODELS_JSON: &str = include_str!("../bin/provider_models.json"); + +fn load_provider_models() -> &'static HashMap<String, Vec<String>> { + static MODELS: OnceLock<HashMap<String, Vec<String>>> = OnceLock::new(); + MODELS.get_or_init(|| { + let data: serde_json::Value = serde_json::from_str(PROVIDER_MODELS_JSON) + .expect("Failed to parse provider_models.json"); + + let providers = data + .get("providers") + .expect("Missing 'providers' key") + .as_object() + .expect("'providers' must be an object"); + + let mut result = HashMap::new(); + for (provider, models) in providers { + let model_list: Vec<String> = models + .as_array() + .expect("Models must be an array") + .iter() + .map(|m| m.as_str().expect("Model must be a string").to_string()) + .collect(); + result.insert(provider.clone(), model_list); + } + result + }) +} /// Provider identifier enum - simple enum for identifying providers #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -23,31 +53,70 @@ pub enum ProviderId { AmazonBedrock, } -impl From<&str> for ProviderId { - fn from(value: &str) -> Self { +impl TryFrom<&str> for ProviderId { + type Error = String; + + fn try_from(value: &str) -> Result<Self, Self::Error> { match value.to_lowercase().as_str() { - "openai" => ProviderId::OpenAI, - "mistral" => ProviderId::Mistral, - "deepseek" => ProviderId::Deepseek, - "groq" => ProviderId::Groq, - "gemini" => ProviderId::Gemini, - "anthropic" => ProviderId::Anthropic, - "github" => ProviderId::GitHub, - "arch" => ProviderId::Arch, - "azure_openai" => ProviderId::AzureOpenAI, - "xai" => ProviderId::XAI, - "together_ai" => ProviderId::TogetherAI, - "ollama" => ProviderId::Ollama, - "moonshotai" => ProviderId::Moonshotai, - "zhipu" => ProviderId::Zhipu, - "qwen" => ProviderId::Qwen, // alias for Qwen - "amazon_bedrock" => ProviderId::AmazonBedrock, - _ => panic!("Unknown provider: {}", value), + "openai" => Ok(ProviderId::OpenAI), + "mistral" => Ok(ProviderId::Mistral), + "deepseek" => Ok(ProviderId::Deepseek), + "groq" => Ok(ProviderId::Groq), + "gemini" => Ok(ProviderId::Gemini), + "google" => Ok(ProviderId::Gemini), // alias + "anthropic" => Ok(ProviderId::Anthropic), + "github" => Ok(ProviderId::GitHub), + "arch" => Ok(ProviderId::Arch), + "azure_openai" => Ok(ProviderId::AzureOpenAI), + "xai" => Ok(ProviderId::XAI), + "together_ai" => Ok(ProviderId::TogetherAI), + "together" => Ok(ProviderId::TogetherAI), // alias + "ollama" => Ok(ProviderId::Ollama), + "moonshotai" => Ok(ProviderId::Moonshotai), + "zhipu" => Ok(ProviderId::Zhipu), + "qwen" => Ok(ProviderId::Qwen), + "amazon_bedrock" => Ok(ProviderId::AmazonBedrock), + "amazon" => Ok(ProviderId::AmazonBedrock), // alias + _ => Err(format!("Unknown provider: {}", value)), } } } impl ProviderId { + /// Get all available models for this provider + /// Returns model names without the provider prefix (e.g., "gpt-4" not "openai/gpt-4") + pub fn models(&self) -> Vec<String> { + let provider_key = match self { + ProviderId::AmazonBedrock => "amazon", + ProviderId::AzureOpenAI => "openai", + ProviderId::TogetherAI => "together", + ProviderId::Gemini => "google", + ProviderId::OpenAI => "openai", + ProviderId::Anthropic => "anthropic", + ProviderId::Mistral => "mistralai", + ProviderId::Deepseek => "deepseek", + ProviderId::Groq => "groq", + ProviderId::XAI => "x-ai", + ProviderId::Moonshotai => "moonshotai", + ProviderId::Zhipu => "z-ai", + ProviderId::Qwen => "qwen", + _ => return Vec::new(), + }; + + load_provider_models() + .get(provider_key) + .map(|models| { + models + .iter() + .filter_map(|model| { + // Strip provider prefix
(e.g., "openai/gpt-4" -> "gpt-4") + model.split_once('/').map(|(_, name)| name.to_string()) + }) + .collect() + }) + .unwrap_or_default() + } + /// Given a client API, return the compatible upstream API for this provider pub fn compatible_api_for_client( &self, @@ -169,3 +238,102 @@ impl Display for ProviderId { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_models_loaded_from_json() { + // Test that we can load models for each supported provider + let openai_models = ProviderId::OpenAI.models(); + assert!(!openai_models.is_empty(), "OpenAI should have models"); + + let anthropic_models = ProviderId::Anthropic.models(); + assert!(!anthropic_models.is_empty(), "Anthropic should have models"); + + let mistral_models = ProviderId::Mistral.models(); + assert!(!mistral_models.is_empty(), "Mistral should have models"); + + let deepseek_models = ProviderId::Deepseek.models(); + assert!(!deepseek_models.is_empty(), "Deepseek should have models"); + + let gemini_models = ProviderId::Gemini.models(); + assert!(!gemini_models.is_empty(), "Gemini should have models"); + } + + #[test] + fn test_model_names_without_provider_prefix() { + // Test that model names don't include the provider/ prefix + let openai_models = ProviderId::OpenAI.models(); + for model in &openai_models { + assert!( + !model.contains('/'), + "Model name '{}' should not contain provider prefix", + model + ); + } + + let anthropic_models = ProviderId::Anthropic.models(); + for model in &anthropic_models { + assert!( + !model.contains('/'), + "Model name '{}' should not contain provider prefix", + model + ); + } + } + + #[test] + fn test_specific_models_exist() { + // Test that specific well-known models are present + let openai_models = ProviderId::OpenAI.models(); + let has_gpt4 = openai_models.iter().any(|m| m.contains("gpt-4")); + assert!(has_gpt4, "OpenAI models should include GPT-4 variants"); + + let anthropic_models = ProviderId::Anthropic.models(); + let has_claude = anthropic_models.iter().any(|m| m.contains("claude")); + assert!( + has_claude, + "Anthropic models should include Claude variants" + ); + } + + #[test] + fn test_unsupported_providers_return_empty() { + // Providers without models should return empty vec + let github_models = ProviderId::GitHub.models(); + assert!( + github_models.is_empty(), + "GitHub should return empty models list" + ); + + let ollama_models = ProviderId::Ollama.models(); + assert!( + ollama_models.is_empty(), + "Ollama should return empty models list" + ); + } + + #[test] + fn test_provider_name_mapping() { + // Test that provider key mappings work correctly + let xai_models = ProviderId::XAI.models(); + assert!( + !xai_models.is_empty(), + "XAI should have models (mapped to x-ai)" + ); + + let zhipu_models = ProviderId::Zhipu.models(); + assert!( + !zhipu_models.is_empty(), + "Zhipu should have models (mapped to z-ai)" + ); + + let amazon_models = ProviderId::AmazonBedrock.models(); + assert!( + !amazon_models.is_empty(), + "AmazonBedrock should have models (mapped to amazon)" + ); + } +} diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 8da0f92af..9992cd6e6 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -1,6 +1,6 @@ use hermesllm::clients::endpoints::SupportedUpstreamAPIs; use http::StatusCode; -use log::{debug, info, warn}; +use log::{debug, error, info, warn}; use proxy_wasm::hostcalls::get_current_time; use proxy_wasm::traits::*; use 
proxy_wasm::types::*; @@ -128,16 +128,23 @@ impl StreamContext { } } - fn select_llm_provider(&mut self) { + fn select_llm_provider(&mut self) -> Result<(), String> { let provider_hint = self .get_http_request_header(ARCH_PROVIDER_HINT_HEADER) .map(|llm_name| llm_name.into()); // info!("llm_providers: {:?}", self.llm_providers); - self.llm_provider = Some(routing::get_llm_provider( - &self.llm_providers, - provider_hint, - )); + let provider = + routing::get_llm_provider(&self.llm_providers, provider_hint).map_err(|err| { + error!( + "[PLANO_REQ_ID:{}] PROVIDER_SELECTION_FAILED: Hint='None' Error='{}'", + self.request_identifier(), + err + ); + err + })?; + + self.llm_provider = Some(provider); info!( "[PLANO_REQ_ID:{}] PROVIDER_SELECTION: Hint='{}' -> Selected='{}'", @@ -146,6 +153,8 @@ impl StreamContext { .unwrap_or("none".to_string()), self.llm_provider.as_ref().unwrap().name ); + + Ok(()) } fn modify_auth_headers(&mut self) -> Result<(), ServerError> { @@ -764,7 +773,15 @@ impl HttpContext for StreamContext { // let routing_header_value = self.get_http_request_header(ARCH_ROUTING_HEADER); - self.select_llm_provider(); + if let Err(err) = self.select_llm_provider() { + self.send_http_response( + 400, + vec![], + Some(format!(r#"{{"error": "{}"}}"#, err).as_bytes()), + ); + return Action::Continue; + } + // Check if this is a supported API endpoint if SupportedAPIsFromClient::from_endpoint(&request_path).is_none() { self.send_http_response(404, vec![], Some(b"Unsupported endpoint")); diff --git a/demos/samples_java/weather_forcecast_service/run_demo.sh b/demos/samples_java/weather_forcecast_service/run_demo.sh index 992bb260b..6623dee5e 100644 --- a/demos/samples_java/weather_forcecast_service/run_demo.sh +++ b/demos/samples_java/weather_forcecast_service/run_demo.sh @@ -18,8 +18,8 @@ start_demo() { echo ".env file created with OPENAI_API_KEY." fi - # Step 3: Start Arch - echo "Starting Arch with config.yaml..." + # Step 3: Start Plano + echo "Starting Plano with config.yaml..." planoai up config.yaml # Step 4: Start developer services @@ -33,8 +33,8 @@ stop_demo() { echo "Stopping Network Agent using Docker Compose..." docker compose down - # Step 2: Stop Arch - echo "Stopping Arch..." + # Step 2: Stop Plano + echo "Stopping Plano..." planoai down } diff --git a/demos/samples_python/currency_exchange/run_demo.sh b/demos/samples_python/currency_exchange/run_demo.sh index 992bb260b..6623dee5e 100644 --- a/demos/samples_python/currency_exchange/run_demo.sh +++ b/demos/samples_python/currency_exchange/run_demo.sh @@ -18,8 +18,8 @@ start_demo() { echo ".env file created with OPENAI_API_KEY." fi - # Step 3: Start Arch - echo "Starting Arch with config.yaml..." + # Step 3: Start Plano + echo "Starting Plano with config.yaml..." planoai up config.yaml # Step 4: Start developer services @@ -33,8 +33,8 @@ stop_demo() { echo "Stopping Network Agent using Docker Compose..." docker compose down - # Step 2: Stop Arch - echo "Stopping Arch..." + # Step 2: Stop Plano + echo "Stopping Plano..." 
planoai down } diff --git a/demos/use_cases/wildcard_providers/config.yaml b/demos/use_cases/wildcard_providers/config.yaml new file mode 100644 index 000000000..8ddd04698 --- /dev/null +++ b/demos/use_cases/wildcard_providers/config.yaml @@ -0,0 +1,26 @@ +version: v0.3.0 + +listeners: + # Model listener for direct LLM access + - type: model + name: llms + address: 0.0.0.0 + port: 12000 + +model_providers: + # OpenAI - support all models via wildcard + - model: openai/* + access_key: $OPENAI_API_KEY + + # Anthropic - support all Claude models + - model: anthropic/* + access_key: $ANTHROPIC_API_KEY + + - model: xai/* + access_key: $GROK_API_KEY + + + # Custom internal LLM provider + # Note: Requires base_url and provider_interface for unknown providers + - model: ollama/* + base_url: https://llm.internal.company.com
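
For reference, here is a minimal client sketch against the wildcard demo above (assumptions: the model listener from `demos/use_cases/wildcard_providers/config.yaml` is running locally on port 12000 and exposes an OpenAI-compatible `/v1/chat/completions` endpoint, as in other Plano demos; the `openai/gpt-4o` slug is illustrative). The request uses the `provider/model` slug form, which the gateway matches against the `openai/*` wildcard provider and forwards upstream with just the bare model name, per the `llm_providers.rs` and `llm.rs` changes above.

```python
# Hypothetical client for the wildcard_providers demo (not part of this PR).
# Assumes the Plano model listener from config.yaml is reachable on localhost:12000
# and speaks the OpenAI-compatible chat completions API.
import json
import urllib.request

payload = {
    # provider/model slug: matched against the "openai/*" wildcard provider;
    # the gateway strips the prefix and sends just "gpt-4o" upstream.
    "model": "openai/gpt-4o",
    "messages": [{"role": "user", "content": "Hello from the wildcard demo"}],
}

req = urllib.request.Request(
    "http://localhost:12000/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)

with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read().decode("utf-8")))
```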