Skip to content

Commit

Permalink
fix(oapi-macros): Support the generic parameters. salvo-rs#945
Browse files Browse the repository at this point in the history
  • Loading branch information
andeya committed Oct 10, 2024
1 parent 8183989 commit a7fc162
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 11 deletions.
29 changes: 19 additions & 10 deletions crates/oapi-macros/src/endpoint/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens};
use syn::{Expr, Ident, ImplItem, Item, Pat, ReturnType, Signature, Type};
use syn::{Expr, Generics, Ident, ImplItem, Item, Pat, ReturnType, Signature, Type};

use crate::doc_comment::CommentAttributes;
use crate::{omit_type_path_lifetimes, parse_input_type, Array, DiagResult, InputType, Operation};
Expand All @@ -14,6 +14,7 @@ fn metadata(
attr: EndpointAttr,
name: &Ident,
mut modifiers: Vec<TokenStream>,
generics: &Generics,
) -> DiagResult<TokenStream> {
let tfn = Ident::new(
&format!("__macro_gen_oapi_endpoint_type_id_{}", name),
Expand Down Expand Up @@ -46,17 +47,19 @@ fn metadata(
#(#modifiers)*
}})
};
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let stream = quote! {
fn #tfn() -> ::std::any::TypeId {
::std::any::TypeId::of::<#name>()
fn #tfn #impl_generics() -> ::std::any::TypeId #where_clause {
::std::any::TypeId::of::<#name #ty_generics>()
}
fn #cfn() -> #oapi::oapi::Endpoint {
fn #cfn #impl_generics() -> #oapi::oapi::Endpoint #where_clause {
let mut components = #oapi::oapi::Components::new();
let status_codes: &[#salvo::http::StatusCode] = &#status_codes;
let mut operation = #oapi::oapi::Operation::new();
#modifiers
if operation.operation_id.is_none() {
operation.operation_id = Some(#oapi::oapi::naming::assign_name::<#name>(#oapi::oapi::naming::NameRule::Auto));
operation.operation_id = Some(#oapi::oapi::naming::assign_name::<#name #ty_generics>(#oapi::oapi::naming::NameRule::Auto));
}
if !status_codes.is_empty() {
let responses = std::ops::DerefMut::deref_mut(&mut operation.responses);
Expand Down Expand Up @@ -117,7 +120,7 @@ pub(crate) fn generate(mut attr: EndpointAttr, input: Item) -> syn::Result<Token
};

let (hfn, modifiers) = handle_fn(&salvo, &oapi, sig)?;
let meta = metadata(&salvo, &oapi, attr, name, modifiers)?;
let meta = metadata(&salvo, &oapi, attr, name, modifiers, &sig.generics)?;
Ok(quote! {
#sdef
#[#salvo::async_trait]
Expand Down Expand Up @@ -153,10 +156,16 @@ pub(crate) fn generate(mut attr: EndpointAttr, input: Item) -> syn::Result<Token
};
let (hfn, modifiers) = handle_fn(&salvo, &oapi, &hmtd.sig)?;
let ty = &item_impl.self_ty;
let (impl_generics, _, where_clause) = &item_impl.generics.split_for_impl();
let name = Ident::new(&ty.to_token_stream().to_string(), Span::call_site());
let meta = metadata(&salvo, &oapi, attr, &name, modifiers)?;

let (impl_generics, ty_generics, where_clause) = &item_impl.generics.split_for_impl();
let name = ty
.to_token_stream()
.to_string()
.to_owned()
.trim_end_matches(&ty_generics.to_token_stream().to_string())
.trim()
.to_owned();
let name = Ident::new(&name, Span::call_site());
let meta = metadata(&salvo, &oapi, attr, &name, modifiers, &item_impl.generics)?;
Ok(quote! {
#item_impl
#[#salvo::async_trait]
Expand Down
7 changes: 6 additions & 1 deletion crates/oapi-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ pub fn endpoint(attr: TokenStream, input: TokenStream) -> TokenStream {
let attr = syn::parse_macro_input!(attr as EndpointAttr);
let item = parse_macro_input!(input as Item);
match endpoint::generate(attr, item) {
Ok(stream) => stream.into(),
Ok(stream) => {
// Temporary debug printing.
println!("{}", stream.to_token_stream());

stream.into()
}
Err(e) => e.to_compile_error().into(),
}
}
Expand Down
50 changes: 50 additions & 0 deletions crates/oapi-macros/tests/endpoint_tests.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::fmt::Debug;

use assert_json_diff::assert_json_eq;
use salvo::oapi::extract::*;
use salvo::prelude::*;
Expand Down Expand Up @@ -43,3 +45,51 @@ fn test_endpoint_hello() {
})
);
}

#[test]
fn test_endpoint_generic() {
pub struct Generic<T: Sized>(T);

#[endpoint]
impl<T: Sized> Generic<T>
where
T: Send + Sync + 'static,
{
async fn handle(&self, _req: &mut Request) -> String {
String::new()
}
}

let router = Router::new().push(Router::with_path("generic").get(Generic(())));

let doc = OpenApi::new("test api", "0.0.1").merge_router(&router);
assert_json_eq!(
doc,
json!({
"openapi":"3.1.0",
"info":{
"title":"test api",
"version":"0.0.1"
},
"paths":{
"/generic":{
"get":{
"operationId":"endpoint_tests.test_endpoint_generic.generic",
"parameters":[{
"name":"name",
"in":"query",
"description":"Get parameter `name` from request url query.",
"required":false,"schema":{"type":"string"}
}],
"responses":{
"200":{
"description":"Ok",
"content":{"text/plain":{"schema":{"type":"string"}}}
}
}
}
}
}
})
);
}

0 comments on commit a7fc162

Please sign in to comment.