diff --git a/crates/oapi-macros/src/endpoint/mod.rs b/crates/oapi-macros/src/endpoint/mod.rs index a83947a4b..8c08f6c16 100644 --- a/crates/oapi-macros/src/endpoint/mod.rs +++ b/crates/oapi-macros/src/endpoint/mod.rs @@ -1,6 +1,6 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, ToTokens}; -use syn::{Expr, Ident, ImplItem, Item, Pat, ReturnType, Signature, Type}; +use syn::{Expr, Generics, Ident, ImplItem, Item, Pat, ReturnType, Signature, Type}; use crate::doc_comment::CommentAttributes; use crate::{omit_type_path_lifetimes, parse_input_type, Array, DiagResult, InputType, Operation}; @@ -14,6 +14,7 @@ fn metadata( attr: EndpointAttr, name: &Ident, mut modifiers: Vec, + generics: &Generics, ) -> DiagResult { let tfn = Ident::new( &format!("__macro_gen_oapi_endpoint_type_id_{}", name), @@ -46,17 +47,19 @@ fn metadata( #(#modifiers)* }}) }; + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let stream = quote! { - fn #tfn() -> ::std::any::TypeId { - ::std::any::TypeId::of::<#name>() + fn #tfn #impl_generics() -> ::std::any::TypeId #where_clause { + ::std::any::TypeId::of::<#name #ty_generics>() } - fn #cfn() -> #oapi::oapi::Endpoint { + fn #cfn #impl_generics() -> #oapi::oapi::Endpoint #where_clause { let mut components = #oapi::oapi::Components::new(); let status_codes: &[#salvo::http::StatusCode] = &#status_codes; let mut operation = #oapi::oapi::Operation::new(); #modifiers if operation.operation_id.is_none() { - operation.operation_id = Some(#oapi::oapi::naming::assign_name::<#name>(#oapi::oapi::naming::NameRule::Auto)); + operation.operation_id = Some(#oapi::oapi::naming::assign_name::<#name #ty_generics>(#oapi::oapi::naming::NameRule::Auto)); } if !status_codes.is_empty() { let responses = std::ops::DerefMut::deref_mut(&mut operation.responses); @@ -117,7 +120,7 @@ pub(crate) fn generate(mut attr: EndpointAttr, input: Item) -> syn::Result syn::Result TokenStream { let attr = syn::parse_macro_input!(attr as EndpointAttr); let item = parse_macro_input!(input as Item); match endpoint::generate(attr, item) { - Ok(stream) => stream.into(), + Ok(stream) => { + // Temporary debug printing. + println!("{}", stream.to_token_stream()); + + stream.into() + } Err(e) => e.to_compile_error().into(), } } diff --git a/crates/oapi-macros/tests/endpoint_tests.rs b/crates/oapi-macros/tests/endpoint_tests.rs index 7de21dbec..819893796 100644 --- a/crates/oapi-macros/tests/endpoint_tests.rs +++ b/crates/oapi-macros/tests/endpoint_tests.rs @@ -1,3 +1,5 @@ +use std::fmt::Debug; + use assert_json_diff::assert_json_eq; use salvo::oapi::extract::*; use salvo::prelude::*; @@ -43,3 +45,51 @@ fn test_endpoint_hello() { }) ); } + +#[test] +fn test_endpoint_generic() { + struct Generic(T); + + #[endpoint] + impl Generic + where + T: Send + Sync + 'static, + { + async fn handle(name: QueryParam) -> String { + format!("[Generic] Hello, {}!", name.as_deref().unwrap_or("World")) + } + } + + let router = Router::new().push(Router::with_path("generic").get(Generic(()))); + + let doc = OpenApi::new("test api", "0.0.1").merge_router(&router); + assert_json_eq!( + doc, + json!({ + "openapi":"3.1.0", + "info":{ + "title":"test api", + "version":"0.0.1" + }, + "paths":{ + "/generic":{ + "get":{ + "operationId":"endpoint_tests.test_endpoint_generic.Generic", + "parameters":[{ + "name":"name", + "in":"query", + "description":"Get parameter `name` from request url query.", + "required":false,"schema":{"type":"string"} + }], + "responses":{ + "200":{ + "description":"Ok", + "content":{"text/plain":{"schema":{"type":"string"}}} + } + } + } + } + } + }) + ); +}