diff --git a/crates/rmcp-macros/src/tool.rs b/crates/rmcp-macros/src/tool.rs index fe389be0..a57082fe 100644 --- a/crates/rmcp-macros/src/tool.rs +++ b/crates/rmcp-macros/src/tool.rs @@ -27,7 +27,14 @@ fn extract_schema_from_return_type(ret_type: &syn::Type) -> Option { // First, try direct Json if let Some(inner_type) = extract_json_inner_type(ret_type) { return syn::parse2::(quote! { - rmcp::handler::server::tool::cached_schema_for_type::<#inner_type>() + rmcp::handler::server::tool::cached_schema_for_output::<#inner_type>() + .unwrap_or_else(|e| { + panic!( + "Invalid output schema for Json<{}>: {}", + std::any::type_name::<#inner_type>(), + e + ) + }) }) .ok(); } @@ -57,7 +64,14 @@ fn extract_schema_from_return_type(ret_type: &syn::Type) -> Option { let inner_type = extract_json_inner_type(ok_type)?; syn::parse2::(quote! { - rmcp::handler::server::tool::cached_schema_for_type::<#inner_type>() + rmcp::handler::server::tool::cached_schema_for_output::<#inner_type>() + .unwrap_or_else(|e| { + panic!( + "Invalid output schema for Result, E>: {}", + std::any::type_name::<#inner_type>(), + e + ) + }) }) .ok() } diff --git a/crates/rmcp/src/handler/server/common.rs b/crates/rmcp/src/handler/server/common.rs index ad144ae3..0f60670a 100644 --- a/crates/rmcp/src/handler/server/common.rs +++ b/crates/rmcp/src/handler/server/common.rs @@ -50,6 +50,50 @@ pub fn cached_schema_for_type() -> Arc() -> Result { + let schema = schema_for_type::(); + + match schema.get("type") { + Some(serde_json::Value::String(t)) if t == "object" => Ok(schema), + Some(serde_json::Value::String(t)) => Err(format!( + "MCP specification requires tool outputSchema to have root type 'object', but found '{}'.", + t + )), + None => Err( + "Schema is missing 'type' field. MCP specification requires outputSchema to have root type 'object'.".to_string() + ), + Some(other) => Err(format!( + "Schema 'type' field has unexpected format: {:?}. Expected \"object\".", + other + )), + } +} + +/// Call [`schema_for_output`] with a cache. +pub fn cached_schema_for_output() -> Result, String> +{ + thread_local! { + static CACHE_FOR_OUTPUT: std::sync::RwLock, String>>> = Default::default(); + }; + CACHE_FOR_OUTPUT.with(|cache| { + if let Some(result) = cache + .read() + .expect("output schema cache lock poisoned") + .get(&TypeId::of::()) + { + result.clone() + } else { + let result = schema_for_output::().map(Arc::new); + cache + .write() + .expect("output schema cache lock poisoned") + .insert(TypeId::of::(), result.clone()); + result + } + }) +} + /// Trait for extracting parts from a context, unifying tool and prompt extraction pub trait FromContextPart: Sized { fn from_context_part(context: &mut C) -> Result; @@ -143,3 +187,25 @@ pub trait AsRequestContext { fn as_request_context(&self) -> &RequestContext; fn as_request_context_mut(&mut self) -> &mut RequestContext; } + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(serde::Serialize, serde::Deserialize, JsonSchema)] + struct TestObject { + value: i32, + } + + #[test] + fn test_schema_for_output_rejects_primitive() { + let result = schema_for_output::(); + assert!(result.is_err(),); + } + + #[test] + fn test_schema_for_output_accepts_object() { + let result = schema_for_output::(); + assert!(result.is_ok(),); + } +} diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index cf842679..98a54e98 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -9,7 +9,9 @@ use serde::de::DeserializeOwned; use super::common::{AsRequestContext, FromContextPart}; pub use super::{ - common::{Extension, RequestId, cached_schema_for_type, schema_for_type}, + common::{ + Extension, RequestId, cached_schema_for_output, cached_schema_for_type, schema_for_type, + }, router::tool::{ToolRoute, ToolRouter}, }; use crate::{ diff --git a/crates/rmcp/src/model/tool.rs b/crates/rmcp/src/model/tool.rs index c7309bbe..55a47849 100644 --- a/crates/rmcp/src/model/tool.rs +++ b/crates/rmcp/src/model/tool.rs @@ -165,8 +165,14 @@ impl Tool { } /// Set the output schema using a type that implements JsonSchema + /// + /// # Panics + /// + /// Panics if the generated schema does not have root type "object" as required by MCP specification. pub fn with_output_schema(mut self) -> Self { - self.output_schema = Some(crate::handler::server::tool::cached_schema_for_type::()); + let schema = crate::handler::server::tool::cached_schema_for_output::() + .unwrap_or_else(|e| panic!("Invalid output schema for tool '{}': {}", self.name, e)); + self.output_schema = Some(schema); self } diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 8bf97f2a..bbadfa67 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -105,3 +105,7 @@ path = "src/completion_stdio.rs" [[example]] name = "servers_progress_demo" path = "src/progress_demo.rs" + +[[example]] +name = "servers_calculator_stdio" +path = "src/calculator_stdio.rs" diff --git a/examples/servers/src/calculator_stdio.rs b/examples/servers/src/calculator_stdio.rs new file mode 100644 index 00000000..6af82042 --- /dev/null +++ b/examples/servers/src/calculator_stdio.rs @@ -0,0 +1,26 @@ +use anyhow::Result; +use common::calculator::Calculator; +use rmcp::{ServiceExt, transport::stdio}; +use tracing_subscriber::{self, EnvFilter}; +mod common; + +/// npx @modelcontextprotocol/inspector cargo run -p mcp-server-examples --example servers_calculator_stdio +#[tokio::main] +async fn main() -> Result<()> { + // Initialize the tracing subscriber with file and stdout logging + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::DEBUG.into())) + .with_writer(std::io::stderr) + .with_ansi(false) + .init(); + + tracing::info!("Starting Calculator MCP server"); + + // Create an instance of our calculator router + let service = Calculator::new().serve(stdio()).await.inspect_err(|e| { + tracing::error!("serving error: {:?}", e); + })?; + + service.waiting().await?; + Ok(()) +} diff --git a/examples/servers/src/common/calculator.rs b/examples/servers/src/common/calculator.rs index de89c5fb..e6f97ce0 100644 --- a/examples/servers/src/common/calculator.rs +++ b/examples/servers/src/common/calculator.rs @@ -2,10 +2,7 @@ use rmcp::{ ServerHandler, - handler::server::{ - router::tool::ToolRouter, - wrapper::{Json, Parameters}, - }, + handler::server::{router::tool::ToolRouter, wrapper::Parameters}, model::{ServerCapabilities, ServerInfo}, schemars, tool, tool_handler, tool_router, }; @@ -44,8 +41,8 @@ impl Calculator { } #[tool(description = "Calculate the difference of two numbers")] - fn sub(&self, Parameters(SubRequest { a, b }): Parameters) -> Json { - Json(a - b) + fn sub(&self, Parameters(SubRequest { a, b }): Parameters) -> String { + (a - b).to_string() } }