Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions crates/rmcp-macros/src/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ fn extract_schema_from_return_type(ret_type: &syn::Type) -> Option<Expr> {
// First, try direct Json<T>
if let Some(inner_type) = extract_json_inner_type(ret_type) {
return syn::parse2::<Expr>(quote! {
rmcp::handler::server::tool::cached_schema_for_type::<#inner_type>()
rmcp::handler::server::tool::cached_schema_for_output::<#inner_type>()
.unwrap_or_else(|e| {
panic!(
"Invalid output schema for Json<{}>: {}",
std::any::type_name::<#inner_type>(),
e
)
})
})
.ok();
}
Expand Down Expand Up @@ -57,7 +64,14 @@ fn extract_schema_from_return_type(ret_type: &syn::Type) -> Option<Expr> {
let inner_type = extract_json_inner_type(ok_type)?;

syn::parse2::<Expr>(quote! {
rmcp::handler::server::tool::cached_schema_for_type::<#inner_type>()
rmcp::handler::server::tool::cached_schema_for_output::<#inner_type>()
.unwrap_or_else(|e| {
panic!(
"Invalid output schema for Result<Json<{}>, E>: {}",
std::any::type_name::<#inner_type>(),
e
)
})
})
.ok()
}
Expand Down
66 changes: 66 additions & 0 deletions crates/rmcp/src/handler/server/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,50 @@ pub fn cached_schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject
})
}

/// Generate and validate a JSON schema for outputSchema (must have root type "object").
pub fn schema_for_output<T: JsonSchema>() -> Result<JsonObject, String> {
let schema = schema_for_type::<T>();

match schema.get("type") {
Some(serde_json::Value::String(t)) if t == "object" => Ok(schema),
Some(serde_json::Value::String(t)) => Err(format!(
"MCP specification requires tool outputSchema to have root type 'object', but found '{}'.",
t
)),
None => Err(
"Schema is missing 'type' field. MCP specification requires outputSchema to have root type 'object'.".to_string()
),
Some(other) => Err(format!(
"Schema 'type' field has unexpected format: {:?}. Expected \"object\".",
other
)),
}
}

/// Call [`schema_for_output`] with a cache.
pub fn cached_schema_for_output<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObject>, String>
{
thread_local! {
static CACHE_FOR_OUTPUT: std::sync::RwLock<HashMap<TypeId, Result<Arc<JsonObject>, String>>> = Default::default();
};
CACHE_FOR_OUTPUT.with(|cache| {
if let Some(result) = cache
.read()
.expect("output schema cache lock poisoned")
.get(&TypeId::of::<T>())
{
result.clone()
} else {
let result = schema_for_output::<T>().map(Arc::new);
cache
.write()
.expect("output schema cache lock poisoned")
.insert(TypeId::of::<T>(), result.clone());
result
}
})
}

/// Trait for extracting parts from a context, unifying tool and prompt extraction
pub trait FromContextPart<C>: Sized {
fn from_context_part(context: &mut C) -> Result<Self, crate::ErrorData>;
Expand Down Expand Up @@ -143,3 +187,25 @@ pub trait AsRequestContext {
fn as_request_context(&self) -> &RequestContext<RoleServer>;
fn as_request_context_mut(&mut self) -> &mut RequestContext<RoleServer>;
}

#[cfg(test)]
mod tests {
use super::*;

#[derive(serde::Serialize, serde::Deserialize, JsonSchema)]
struct TestObject {
value: i32,
}

#[test]
fn test_schema_for_output_rejects_primitive() {
let result = schema_for_output::<i32>();
assert!(result.is_err(),);
}

#[test]
fn test_schema_for_output_accepts_object() {
let result = schema_for_output::<TestObject>();
assert!(result.is_ok(),);
}
}
4 changes: 3 additions & 1 deletion crates/rmcp/src/handler/server/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use serde::de::DeserializeOwned;

use super::common::{AsRequestContext, FromContextPart};
pub use super::{
common::{Extension, RequestId, cached_schema_for_type, schema_for_type},
common::{
Extension, RequestId, cached_schema_for_output, cached_schema_for_type, schema_for_type,
},
router::tool::{ToolRoute, ToolRouter},
};
use crate::{
Expand Down
8 changes: 7 additions & 1 deletion crates/rmcp/src/model/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,14 @@ impl Tool {
}

/// Set the output schema using a type that implements JsonSchema
///
/// # Panics
///
/// Panics if the generated schema does not have root type "object" as required by MCP specification.
pub fn with_output_schema<T: JsonSchema + 'static>(mut self) -> Self {
self.output_schema = Some(crate::handler::server::tool::cached_schema_for_type::<T>());
let schema = crate::handler::server::tool::cached_schema_for_output::<T>()
.unwrap_or_else(|e| panic!("Invalid output schema for tool '{}': {}", self.name, e));
self.output_schema = Some(schema);
self
}

Expand Down
4 changes: 4 additions & 0 deletions examples/servers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,7 @@ path = "src/completion_stdio.rs"
[[example]]
name = "servers_progress_demo"
path = "src/progress_demo.rs"

[[example]]
name = "servers_calculator_stdio"
path = "src/calculator_stdio.rs"
26 changes: 26 additions & 0 deletions examples/servers/src/calculator_stdio.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use anyhow::Result;
use common::calculator::Calculator;
use rmcp::{ServiceExt, transport::stdio};
use tracing_subscriber::{self, EnvFilter};
mod common;

/// npx @modelcontextprotocol/inspector cargo run -p mcp-server-examples --example servers_calculator_stdio
#[tokio::main]
async fn main() -> Result<()> {
// Initialize the tracing subscriber with file and stdout logging
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::DEBUG.into()))
.with_writer(std::io::stderr)
.with_ansi(false)
.init();

tracing::info!("Starting Calculator MCP server");

// Create an instance of our calculator router
let service = Calculator::new().serve(stdio()).await.inspect_err(|e| {
tracing::error!("serving error: {:?}", e);
})?;

service.waiting().await?;
Ok(())
}
9 changes: 3 additions & 6 deletions examples/servers/src/common/calculator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@

use rmcp::{
ServerHandler,
handler::server::{
router::tool::ToolRouter,
wrapper::{Json, Parameters},
},
handler::server::{router::tool::ToolRouter, wrapper::Parameters},
model::{ServerCapabilities, ServerInfo},
schemars, tool, tool_handler, tool_router,
};
Expand Down Expand Up @@ -44,8 +41,8 @@ impl Calculator {
}

#[tool(description = "Calculate the difference of two numbers")]
fn sub(&self, Parameters(SubRequest { a, b }): Parameters<SubRequest>) -> Json<i32> {
Json(a - b)
fn sub(&self, Parameters(SubRequest { a, b }): Parameters<SubRequest>) -> String {
(a - b).to_string()
}
}

Expand Down