diff --git a/soroban-sdk-macros/src/syn_ext.rs b/soroban-sdk-macros/src/syn_ext.rs index 42bcb166..64273c01 100644 --- a/soroban-sdk-macros/src/syn_ext.rs +++ b/soroban-sdk-macros/src/syn_ext.rs @@ -1,11 +1,12 @@ use proc_macro2::TokenStream; use quote::{format_ident, quote, ToTokens}; +use std::collections::HashMap; use syn::{ parse::{Parse, ParseStream}, punctuated::Punctuated, token::Comma, AngleBracketedGenericArguments, Attribute, GenericArgument, Path, PathArguments, PathSegment, - QSelf, ReturnType, Token, TypePath, + ReturnType, Token, TypePath, }; use syn::{ spanned::Spanned, token::And, Error, FnArg, Ident, ImplItem, ImplItemFn, ItemImpl, ItemTrait, @@ -216,10 +217,14 @@ fn unpack_result(typ: &Type) -> Option<(Type, Type)> { fn flatten_associated_items_in_impl_fns(imp: &mut ItemImpl) { // TODO: Flatten associated consts used in functions. // Flatten associated types used in functions. - let Some((_, trait_, _)) = &imp.trait_ else { - return; - }; - let self_ty = &*imp.self_ty; + let associated_types = imp + .items + .iter() + .filter_map(|item| match item { + ImplItem::Type(i) => Some((i.ident.clone(), i.ty.clone())), + _ => None, + }) + .collect::>(); let fn_input_types = imp .items .iter_mut() @@ -232,29 +237,19 @@ fn flatten_associated_items_in_impl_fns(imp: &mut ItemImpl) { }) .flatten(); for t in fn_input_types { - let span = t.span(); - if let Type::Path(TypePath { - qself: qself @ None, - path, - }) = t.as_mut() - { - let segments = path.segments.clone(); - if segments.first() == Some(&PathSegment::from(format_ident!("Self"))) { - *qself = Some(QSelf { - lt_token: Token![<](span), - ty: Box::new(self_ty.clone()), - // The index of the first path segment outside the <>. - // For example, would be 2 for ::AssociatedItem. - position: trait_.segments.len(), - as_token: Some(Token![as](span)), - gt_token: Token![>](span), - }); - *path = trait_.clone(); - // Add the original path segments after the trait, skipping - // the first one which is the Self. - for segment in segments.into_iter().skip(1) { - path.segments.push_punct(Token![::](span)); - path.segments.push_value(segment); + if let Type::Path(TypePath { qself: None, path }) = t.as_mut() { + let segments = &path.segments; + if segments.len() == 2 + && segments.first() == Some(&PathSegment::from(format_ident!("Self"))) + { + if let Some(PathSegment { + arguments: PathArguments::None, + ident, + }) = segments.get(1) + { + if let Some(resolved_ty) = associated_types.get(ident) { + *t.as_mut() = resolved_ty.clone(); + } } } }