From ab5ac82c948f24a6256e3d9fe5d74ecbadf0c5ed Mon Sep 17 00:00:00 2001 From: Russell Greene Date: Mon, 27 Nov 2023 22:53:01 -0700 Subject: [PATCH] cxx function --- gen/src/mod.rs | 2 +- gen/src/write.rs | 57 +++++++++++++++++++++++++++- macro/Cargo.toml | 3 +- macro/src/expand.rs | 86 +++++++++++++++++++++++++++++++++++++----- src/cxx_function.rs | 37 ++++++++++++++++++ src/lib.rs | 3 ++ syntax/check.rs | 37 +++++++++++++++++- syntax/impls.rs | 2 + syntax/improper.rs | 8 ++-- syntax/instantiate.rs | 88 +++++++++++++++++++++++++++++++++++++++++-- syntax/map.rs | 6 +-- syntax/mod.rs | 1 + syntax/names.rs | 2 +- syntax/namespace.rs | 2 +- syntax/parse.rs | 24 +++++++++--- syntax/pod.rs | 1 + syntax/set.rs | 1 + syntax/tokens.rs | 7 ++++ syntax/types.rs | 5 ++- syntax/visit.rs | 1 + tests/ffi/lib.rs | 10 +++++ tests/ffi/tests.cc | 18 +++++++++ 22 files changed, 369 insertions(+), 32 deletions(-) create mode 100644 src/cxx_function.rs diff --git a/gen/src/mod.rs b/gen/src/mod.rs index f24846a7e..f852cc97b 100644 --- a/gen/src/mod.rs +++ b/gen/src/mod.rs @@ -163,7 +163,7 @@ pub(super) fn generate(syntax: File, opt: &Opt) -> Result { cfg::strip(errors, cfg_errors, opt.cfg_evaluator.as_ref(), apis); errors.propagate()?; - let ref types = Types::collect(errors, apis); + let ref types = Types::collect(errors, &apis); check::precheck(errors, apis, opt); errors.propagate()?; diff --git a/gen/src/write.rs b/gen/src/write.rs index 8eef0a76b..cdf831ed6 100644 --- a/gen/src/write.rs +++ b/gen/src/write.rs @@ -3,7 +3,7 @@ use crate::gen::nested::NamespaceEntries; use crate::gen::out::OutFile; use crate::gen::{builtin, include, Opt}; use crate::syntax::atom::Atom::{self, *}; -use crate::syntax::instantiate::{ImplKey, NamedImplKey}; +use crate::syntax::instantiate::{FunctionImplKey, ImplKey, NamedImplKey}; use crate::syntax::map::UnorderedMap as Map; use crate::syntax::set::UnorderedSet; use crate::syntax::symbol::{self, Symbol}; @@ -218,6 +218,7 @@ fn pick_includes_and_builtins(out: &mut OutFile, apis: &[Api]) { Type::SharedPtr(_) | Type::WeakPtr(_) => out.include.memory = true, Type::Str(_) => out.builtin.rust_str = true, Type::CxxVector(_) => out.include.vector = true, + Type::CxxFunction(_) => out.include.functional = true, Type::Fn(_) => out.builtin.rust_fn = true, Type::SliceRef(_) => out.builtin.rust_slice = true, Type::Array(_) => out.include.array = true, @@ -1255,6 +1256,25 @@ fn write_type(out: &mut OutFile, ty: &Type) { write_type(out, &ty.inner); write!(out, ">"); } + Type::CxxFunction(ty) => { + if let Type::Fn(f) = &ty.inner { + write!(out, "::std::function<"); + match &f.ret { + Some(ret) => write_type(out, ret), + None => write!(out, "void"), + } + write!(out, "("); + for (i, arg) in f.args.iter().enumerate() { + if i > 0 { + write!(out, ", "); + } + write_type(out, &arg.ty); + } + write!(out, ")>"); + } else { + unreachable!("should not have produced a CxxFunction with non-Fn inner") + } + } Type::Ref(r) => { write_type_space(out, &r.inner); if !r.mutable { @@ -1339,6 +1359,7 @@ fn write_space_after_type(out: &mut OutFile, ty: &Type) { | Type::WeakPtr(_) | Type::Str(_) | Type::CxxVector(_) + | Type::CxxFunction(_) | Type::RustVec(_) | Type::SliceRef(_) | Type::Fn(_) @@ -1413,6 +1434,7 @@ fn write_generic_instantiations(out: &mut OutFile) { ImplKey::SharedPtr(ident) => write_shared_ptr(out, ident), ImplKey::WeakPtr(ident) => write_weak_ptr(out, ident), ImplKey::CxxVector(ident) => write_cxx_vector(out, ident), + ImplKey::CxxFunction(ref key) => write_cxx_function(out, key), } } out.end_block(Block::ExternC); @@ -1956,3 +1978,36 @@ fn write_cxx_vector(out: &mut OutFile, key: NamedImplKey) { out.include.memory = true; write_unique_ptr_common(out, UniquePtr::CxxVector(element)); } + +fn write_cxx_function(out: &mut OutFile, key: &FunctionImplKey) { + begin_function_definition(out); + let link_invoke = key.link_name_invoke(out.types); + + if let Type::CxxFunction(ty) = key.ty { + if let Type::Fn(sig) = &ty.inner { + if let Some(ret) = &sig.ret { + write_type(out, &ret); + } else { + write!(out, "void") + } + write!(out, " {link_invoke}("); + write_type(out, &key.ty); + write!(out, " f"); + for (id, a) in sig.args.iter().enumerate() { + write!(out, ", "); + write_type(out, &a.ty); + write!(out, " a_{id}"); + } + writeln!(out, ") {{"); + write!(out, " return f("); + for id in 0..sig.args.len() { + write!(out, "a_{id}"); + if id != sig.args.len() - 1 { + write!(out, ", "); + } + } + writeln!(out, ");"); + writeln!(out, "}}"); + } + } +} diff --git a/macro/Cargo.toml b/macro/Cargo.toml index d6ed78e6f..e8637c67b 100644 --- a/macro/Cargo.toml +++ b/macro/Cargo.toml @@ -23,7 +23,7 @@ experimental-enum-variants-from-header = ["clang-ast", "flate2", "memmap", "serd [dependencies] proc-macro2 = "1.0.63" quote = "1.0.29" -syn = { version = "2.0.23", features = ["full"] } +syn = { version = "2.0.23", features = ["full", "extra-traits"] } # optional dependencies: clang-ast = { version = "0.1.18", optional = true } @@ -32,6 +32,7 @@ memmap = { version = "0.7", optional = true } serde = { version = "1.0.166", optional = true } serde_derive = { version = "1.0.166", optional = true } serde_json = { version = "1.0.100", optional = true } +libc = "0.2.150" [dev-dependencies] cxx = { version = "1.0", path = ".." } diff --git a/macro/src/expand.rs b/macro/src/expand.rs index 7715ba8fa..1370d64c4 100644 --- a/macro/src/expand.rs +++ b/macro/src/expand.rs @@ -2,13 +2,13 @@ use crate::syntax::atom::Atom::*; use crate::syntax::attrs::{self, OtherAttrs}; use crate::syntax::cfg::CfgExpr; use crate::syntax::file::Module; -use crate::syntax::instantiate::{ImplKey, NamedImplKey}; +use crate::syntax::instantiate::{FunctionImplKey, ImplKey, NamedImplKey}; use crate::syntax::qualified::QualifiedName; use crate::syntax::report::Errors; use crate::syntax::symbol::Symbol; use crate::syntax::{ - self, check, mangle, Api, Doc, Enum, ExternFn, ExternType, Impl, Lifetimes, Pair, Signature, - Struct, Trait, Type, TypeAlias, Types, + self, check, mangle, Api, Doc, Enum, ExternFn, ExternType, Impl, Lifetimes, Pair, + Signature, Struct, Trait, Type, TypeAlias, Types, }; use crate::type_id::Crate; use crate::{derive, generics}; @@ -91,25 +91,35 @@ fn expand(ffi: Module, doc: Doc, attrs: OtherAttrs, apis: &[Api], types: &Types) } } + let mut first_cxx_function = true; for (impl_key, &explicit_impl) in &types.impls { - match *impl_key { + match impl_key { ImplKey::RustBox(ident) => { - hidden.extend(expand_rust_box(ident, types, explicit_impl)); + hidden.extend(expand_rust_box(*ident, types, explicit_impl)); } ImplKey::RustVec(ident) => { - hidden.extend(expand_rust_vec(ident, types, explicit_impl)); + hidden.extend(expand_rust_vec(*ident, types, explicit_impl)); } ImplKey::UniquePtr(ident) => { - expanded.extend(expand_unique_ptr(ident, types, explicit_impl)); + expanded.extend(expand_unique_ptr(*ident, types, explicit_impl)); } ImplKey::SharedPtr(ident) => { - expanded.extend(expand_shared_ptr(ident, types, explicit_impl)); + expanded.extend(expand_shared_ptr(*ident, types, explicit_impl)); } ImplKey::WeakPtr(ident) => { - expanded.extend(expand_weak_ptr(ident, types, explicit_impl)); + expanded.extend(expand_weak_ptr(*ident, types, explicit_impl)); } ImplKey::CxxVector(ident) => { - expanded.extend(expand_cxx_vector(ident, explicit_impl, types)); + expanded.extend(expand_cxx_vector(*ident, explicit_impl, types)); + } + ImplKey::CxxFunction(sig) => { + expanded.extend(expand_cxx_function( + &sig, + explicit_impl, + types, + first_cxx_function, + )); + first_cxx_function = false; } } } @@ -1674,6 +1684,7 @@ fn expand_cxx_vector( let elem = key.rust; let name = elem.to_string(); let resolve = types.resolve(elem); + let prefix = format!("cxxbridge1$std$vector${}$", resolve.name.to_symbol()); let link_new = format!("{}new", prefix); let link_size = format!("{}size", prefix); @@ -1819,6 +1830,61 @@ fn expand_cxx_vector( } } +fn expand_cxx_function( + key: &FunctionImplKey, + _explicit_impl: Option<&Impl>, + types: &Types, + first_cxx_function: bool, +) -> TokenStream { + let link_invoke = key.link_name_invoke(types); + + if let Type::CxxFunction(ty) = key.ty { + if let Type::Fn(sig) = &ty.inner { + let only_first_time = if first_cxx_function { + quote! { + pub struct CrateFnImpls; + pub type CxxFunction = ::cxx::CxxFunction; + } + } else { + Default::default() + }; + + let ret = if let Some(Type::Ident(r)) = &sig.ret { + quote! { #r } + } else { + quote! { () } + }; + + let mut arg_types = quote!{}; + let mut arg_names = quote!{}; + let mut arg_type_names = quote!{}; + for (id, a) in sig.args.iter().enumerate() { + if let Type::Ident(ident) = &a.ty { + let arg_name = format_ident!("a_{id}"); + arg_types = quote!{ #arg_types #ident, }; + arg_names = quote!{ #arg_names #arg_name, }; + arg_type_names = quote!{ #arg_type_names #arg_name: #ident, }; + } + } + + return quote_spanned! {ty.name.span() => + #only_first_time + + unsafe impl ::cxx::private::CxxFunctionImpl<#ret, (#arg_types)> for CrateFnImpls { + fn __invoke(this: *mut ::cxx::core::ffi::c_void, (#arg_names): (#arg_types)) -> #ret { + extern "C" { + #[link_name = #link_invoke] + fn __function_invoke(this: *mut ::cxx::core::ffi::c_void, #arg_type_names) -> #ret; + } + unsafe { __function_invoke(this, #arg_names) } + } + } + }; + } + } + panic!("bad type"); +} + fn expand_return_type(ret: &Option) -> TokenStream { match ret { Some(ret) => quote!(-> #ret), diff --git a/src/cxx_function.rs b/src/cxx_function.rs new file mode 100644 index 000000000..196f489d9 --- /dev/null +++ b/src/cxx_function.rs @@ -0,0 +1,37 @@ +use core::{ffi::c_void, marker::PhantomData}; + +/// +pub unsafe trait CxxFunctionImpl { + #[doc(hidden)] + fn __invoke(f: *mut c_void, a: Args) -> Ret; +} + +/// Bridge to std::function +#[repr(C)] +pub struct CxxFunction { + // A thing, because repr(C) structs are not allowed to consist exclusively + // of PhantomData fields. + _void: [c_void; 0], + _impl: PhantomData, + _fn: PhantomData, +} + +impl CxxFunction Out> +where + I: CxxFunctionImpl, +{ + /// Run the std::function + pub fn invoke(&mut self) -> Out { + >::__invoke(self as *mut _ as _, ()) + } +} + +impl CxxFunction Out> +where + I: CxxFunctionImpl, +{ + /// Run the std::function + pub fn invoke(&mut self, a: In) -> Out { + >::__invoke(self as *mut _ as _, (a,)) + } +} diff --git a/src/lib.rs b/src/lib.rs index 9c2281108..cdfe29391 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -449,6 +449,7 @@ compile_error! { mod macros; mod c_char; +mod cxx_function; mod cxx_vector; mod exception; mod extern_type; @@ -475,6 +476,7 @@ mod unwind; pub mod vector; mod weak_ptr; +pub use crate::cxx_function::CxxFunction; pub use crate::cxx_vector::CxxVector; #[cfg(feature = "alloc")] pub use crate::exception::Exception; @@ -503,6 +505,7 @@ pub type Vector = CxxVector; #[doc(hidden)] pub mod private { pub use crate::c_char::c_char; + pub use crate::cxx_function::CxxFunctionImpl; pub use crate::cxx_vector::VectorElement; pub use crate::extern_type::{verify_extern_kind, verify_extern_type}; pub use crate::function::FatFunction; diff --git a/syntax/check.rs b/syntax/check.rs index b5fd45e11..d406561d9 100644 --- a/syntax/check.rs +++ b/syntax/check.rs @@ -51,6 +51,7 @@ fn do_typecheck(cx: &mut Check) { Type::SharedPtr(ptr) => check_type_shared_ptr(cx, ptr), Type::WeakPtr(ptr) => check_type_weak_ptr(cx, ptr), Type::CxxVector(ptr) => check_type_cxx_vector(cx, ptr), + Type::CxxFunction(ptr) => check_type_cxx_function(cx, ptr), Type::Ref(ty) => check_type_ref(cx, ty), Type::Ptr(ty) => check_type_ptr(cx, ty), Type::Array(array) => check_type_array(cx, array), @@ -225,6 +226,34 @@ fn check_type_cxx_vector(cx: &mut Check, ptr: &Ty1) { cx.error(ptr, "unsupported vector element type"); } +fn check_type_cxx_function(cx: &mut Check, ptr: &Ty1) { + if let Type::Fn(ptr) = &ptr.inner { + if ptr.asyncness.is_some() { + cx.error(ptr, "CxxFunction signature must not be async"); + } + if ptr.unsafety.is_some() { + cx.error( + ptr, + "unimplemented: CxxFunction signature must not be unsafe", + ); + } + if ptr.generics.lt_token.is_some() { + cx.error(ptr, "CxxFunction signature must not be generic"); + } + if ptr.receiver.is_some() { + cx.error(ptr, "CxxFunction signature must not have a receiver"); + } + if ptr.throws { + cx.error(ptr, "CxxFunction signature must not throw"); + } + // if let Type::Fn(_sig) = &ptr.inner { + // return; + // } + return; + } + cx.error(ptr, "CxxFunction must have fn(a, b) -> c generic"); +} + fn check_type_ref(cx: &mut Check, ty: &Ref) { if ty.mutable && !ty.pinned { if let Some(requires_pin) = match &ty.inner { @@ -526,7 +555,9 @@ fn check_api_impl(cx: &mut Check, imp: &Impl) { | Type::UniquePtr(ty) | Type::SharedPtr(ty) | Type::WeakPtr(ty) - | Type::CxxVector(ty) => { + | Type::CxxVector(ty) + | Type::CxxFunction(ty) // ????? + => { if let Type::Ident(inner) = &ty.inner { if Atom::from(&inner.rust).is_none() { return; @@ -609,6 +640,7 @@ fn check_reserved_name(cx: &mut Check, ident: &Ident) { || ident == "WeakPtr" || ident == "Vec" || ident == "CxxVector" + || ident == "CxxFunction" || ident == "str" || Atom::from(ident).is_some() { @@ -648,7 +680,7 @@ fn is_unsized(cx: &mut Check, ty: &Type) -> bool { ident == CxxString || is_opaque_cxx(cx, ident) || cx.types.rust.contains(ident) } Type::Array(array) => is_unsized(cx, &array.inner), - Type::CxxVector(_) | Type::Fn(_) | Type::Void(_) => true, + Type::CxxVector(_) | Type::CxxFunction(_) | Type::Fn(_) | Type::Void(_) => true, // NOTE: maybe actually sized? idk Type::RustBox(_) | Type::RustVec(_) | Type::UniquePtr(_) @@ -732,6 +764,7 @@ fn describe(cx: &mut Check, ty: &Type) -> String { Type::Ptr(_) => "raw pointer".to_owned(), Type::Str(_) => "&str".to_owned(), Type::CxxVector(_) => "C++ vector".to_owned(), + Type::CxxFunction(_) => "C++ function".to_owned(), Type::SliceRef(_) => "slice".to_owned(), Type::Fn(_) => "function pointer".to_owned(), Type::Void(_) => "()".to_owned(), diff --git a/syntax/impls.rs b/syntax/impls.rs index 36e1f322a..28b8e347b 100644 --- a/syntax/impls.rs +++ b/syntax/impls.rs @@ -53,6 +53,7 @@ impl Hash for Type { Type::Str(t) => t.hash(state), Type::RustVec(t) => t.hash(state), Type::CxxVector(t) => t.hash(state), + Type::CxxFunction(t) => t.hash(state), Type::Fn(t) => t.hash(state), Type::SliceRef(t) => t.hash(state), Type::Array(t) => t.hash(state), @@ -75,6 +76,7 @@ impl PartialEq for Type { (Type::Str(lhs), Type::Str(rhs)) => lhs == rhs, (Type::RustVec(lhs), Type::RustVec(rhs)) => lhs == rhs, (Type::CxxVector(lhs), Type::CxxVector(rhs)) => lhs == rhs, + (Type::CxxFunction(lhs), Type::CxxFunction(rhs)) => lhs == rhs, (Type::Fn(lhs), Type::Fn(rhs)) => lhs == rhs, (Type::SliceRef(lhs), Type::SliceRef(rhs)) => lhs == rhs, (Type::Void(_), Type::Void(_)) => true, diff --git a/syntax/improper.rs b/syntax/improper.rs index a19f5b7d6..3d635c9e1 100644 --- a/syntax/improper.rs +++ b/syntax/improper.rs @@ -28,9 +28,11 @@ impl<'a> Types<'a> { | Type::Fn(_) | Type::Void(_) | Type::SliceRef(_) => Definite(true), - Type::UniquePtr(_) | Type::SharedPtr(_) | Type::WeakPtr(_) | Type::CxxVector(_) => { - Definite(false) - } + Type::UniquePtr(_) + | Type::SharedPtr(_) + | Type::WeakPtr(_) + | Type::CxxVector(_) + | Type::CxxFunction(_) => Definite(false), Type::Ref(ty) => self.determine_improper_ctype(&ty.inner), Type::Ptr(ty) => self.determine_improper_ctype(&ty.inner), Type::Array(ty) => self.determine_improper_ctype(&ty.inner), diff --git a/syntax/instantiate.rs b/syntax/instantiate.rs index dda306982..7e654cc2d 100644 --- a/syntax/instantiate.rs +++ b/syntax/instantiate.rs @@ -1,9 +1,14 @@ -use crate::syntax::{NamedType, Ty1, Type}; +use crate::syntax::{Atom, NamedType, Ty1, Type}; use proc_macro2::{Ident, Span}; -use std::hash::{Hash, Hasher}; +use std::{ + fmt::Write, + hash::{Hash, Hasher}, +}; use syn::Token; -#[derive(Copy, Clone, PartialEq, Eq, Hash)] +use super::Types; + +#[derive(Clone, PartialEq, Eq, Hash)] pub(crate) enum ImplKey<'a> { RustBox(NamedImplKey<'a>), RustVec(NamedImplKey<'a>), @@ -11,6 +16,15 @@ pub(crate) enum ImplKey<'a> { SharedPtr(NamedImplKey<'a>), WeakPtr(NamedImplKey<'a>), CxxVector(NamedImplKey<'a>), + CxxFunction(FunctionImplKey<'a>), +} + +#[derive(Clone)] +pub(crate) struct FunctionImplKey<'a> { + pub ret: Option<&'a Ident>, + pub args: Vec<&'a Ident>, + + pub ty: &'a Type, } #[derive(Copy, Clone)] @@ -52,6 +66,29 @@ impl Type { if let Type::Ident(ident) = &ty.inner { return Some(ImplKey::CxxVector(NamedImplKey::new(ty, ident))); } + } else if let Type::CxxFunction(ty) = self { + if let Type::Fn(sig) = &ty.inner { + let ret = match &sig.ret { + Some(Type::Ident(ret_id)) => Some(&ret_id.rust), + Some(_) => return None, // non-ident return, weird + None => None, + }; + + let args = sig + .args + .iter() + .map(|a| { + if let Type::Ident(a) = &a.ty { + Some(&a.rust) + } else { + None + } + }) + .flatten() + .collect(); + + return Some(ImplKey::CxxFunction(FunctionImplKey::new(&self, ret, args))); + } } None } @@ -82,3 +119,48 @@ impl<'a> NamedImplKey<'a> { } } } + +impl<'a> FunctionImplKey<'a> { + fn new(ty: &'a Type, ret: Option<&'a Ident>, args: Vec<&'a Ident>) -> FunctionImplKey<'a> { + FunctionImplKey { ret, args, ty } + } + + pub fn link_name_invoke(&self, types: &Types) -> String { + let ret_str = if let Some(ret) = self.ret { + if let Some(atom) = Atom::from(ret) { + atom.to_string() + } else { + types.resolve(ret).name.to_symbol().to_string() + } + } else { + "".to_string() + }; + + let mut prefix = format!("cxxbridge1$std$function${}", ret_str); + for arg in &self.args { + if let Some(atom) = Atom::from(arg) { + write!(&mut prefix, "${}", atom).unwrap(); + } else { + write!(&mut prefix, "${}", types.resolve(*arg).name.to_symbol()).unwrap(); + } + } + prefix.push('$'); + + format!("{}invoke", prefix) + } +} + +impl<'a> PartialEq for FunctionImplKey<'a> { + fn eq(&self, other: &Self) -> bool { + self.ret == other.ret && self.args == other.args + } +} + +impl<'a> Eq for FunctionImplKey<'a> {} + +impl<'a> Hash for FunctionImplKey<'a> { + fn hash(&self, hasher: &mut H) { + self.ret.hash(hasher); + self.args.hash(hasher); + } +} diff --git a/syntax/map.rs b/syntax/map.rs index 4a2db0b83..42b615e87 100644 --- a/syntax/map.rs +++ b/syntax/map.rs @@ -38,17 +38,17 @@ mod ordered { impl OrderedMap where - K: Copy + Hash + Eq, + K: Hash + Eq + Clone, { pub(crate) fn insert(&mut self, key: K, value: V) -> Option { - match self.map.entry(key) { + match self.map.entry(key.clone()) { Entry::Occupied(entry) => { let i = &mut self.vec[*entry.get()]; Some(mem::replace(&mut i.1, value)) } Entry::Vacant(entry) => { entry.insert(self.vec.len()); - self.vec.push((key, value)); + self.vec.push((key.clone(), value)); None } } diff --git a/syntax/mod.rs b/syntax/mod.rs index 5ff343b4d..f1139ea7d 100644 --- a/syntax/mod.rs +++ b/syntax/mod.rs @@ -270,6 +270,7 @@ pub(crate) enum Type { Ptr(Box), Str(Box), CxxVector(Box), + CxxFunction(Box), Fn(Box), Void(Span), SliceRef(Box), diff --git a/syntax/names.rs b/syntax/names.rs index 7afa5a9e3..9b6be5337 100644 --- a/syntax/names.rs +++ b/syntax/names.rs @@ -7,7 +7,7 @@ use syn::ext::IdentExt; use syn::parse::{Error, Parser, Result}; use syn::punctuated::Punctuated; -#[derive(Clone)] +#[derive(Clone, Debug)] pub(crate) struct ForeignName { text: String, } diff --git a/syntax/namespace.rs b/syntax/namespace.rs index 417fb34f1..6956b340b 100644 --- a/syntax/namespace.rs +++ b/syntax/namespace.rs @@ -9,7 +9,7 @@ mod kw { syn::custom_keyword!(namespace); } -#[derive(Clone, Default)] +#[derive(Clone, Default, Debug)] pub(crate) struct Namespace { segments: Vec, } diff --git a/syntax/parse.rs b/syntax/parse.rs index 850dcc8d1..ebd13d7c9 100644 --- a/syntax/parse.rs +++ b/syntax/parse.rs @@ -1066,7 +1066,7 @@ fn parse_impl(cx: &mut Errors, imp: ItemImpl) -> Result { | Type::UniquePtr(ty) | Type::SharedPtr(ty) | Type::WeakPtr(ty) - | Type::CxxVector(ty) => match &ty.inner { + | Type::CxxVector(ty) => match &ty.inner { Type::Ident(ident) => ident.generics.clone(), _ => Lifetimes::default(), }, @@ -1074,10 +1074,11 @@ fn parse_impl(cx: &mut Errors, imp: ItemImpl) -> Result { | Type::Ref(_) | Type::Ptr(_) | Type::Str(_) - | Type::Fn(_) + | Type::Fn(_) // | Type::CxxFunction(_) // TODO: CxxFunction probably needs lifetime | Type::Void(_) | Type::SliceRef(_) | Type::Array(_) => Lifetimes::default(), + Type::CxxFunction(_) => panic!() }; let negative = negative_token.is_some(); @@ -1261,6 +1262,15 @@ fn parse_type_path(ty: &TypePath) -> Result { rangle: generic.gt_token, }))); } + } else if ident == "CxxFunction" && generic.args.len() == 1 { + if let GenericArgument::Type(syn::Type::BareFn(f)) = &generic.args[0] { + return Ok(Type::CxxFunction(Box::new(Ty1 { + name: ident, + langle: generic.lt_token, + inner: Type::Fn(parse_signature(f)?), + rangle: generic.gt_token, + }))); + } } else if ident == "Box" && generic.args.len() == 1 { if let GenericArgument::Type(arg) = &generic.args[0] { let inner = parse_type(arg)?; @@ -1361,7 +1371,7 @@ fn parse_type_array(ty: &TypeArray) -> Result { }))) } -fn parse_type_fn(ty: &TypeBareFn) -> Result { +fn parse_signature(ty: &TypeBareFn) -> Result> { if ty.lifetimes.is_some() { return Err(Error::new_spanned( ty, @@ -1419,7 +1429,7 @@ fn parse_type_fn(ty: &TypeBareFn) -> Result { let receiver = None; let paren_token = ty.paren_token; - Ok(Type::Fn(Box::new(Signature { + Ok(Box::new(Signature { asyncness, unsafety, fn_token, @@ -1430,7 +1440,11 @@ fn parse_type_fn(ty: &TypeBareFn) -> Result { throws, paren_token, throws_tokens, - }))) + })) +} + +fn parse_type_fn(ty: &TypeBareFn) -> Result { + Ok(Type::Fn(parse_signature(ty)?)) } fn parse_return_type( diff --git a/syntax/pod.rs b/syntax/pod.rs index 506e53cb5..4320b6c99 100644 --- a/syntax/pod.rs +++ b/syntax/pod.rs @@ -28,6 +28,7 @@ impl<'a> Types<'a> { | Type::SharedPtr(_) | Type::WeakPtr(_) | Type::CxxVector(_) + | Type::CxxFunction(_) | Type::Void(_) => false, Type::Ref(_) | Type::Str(_) | Type::Fn(_) | Type::SliceRef(_) | Type::Ptr(_) => true, Type::Array(array) => self.is_guaranteed_pod(&array.inner), diff --git a/syntax/set.rs b/syntax/set.rs index 0907834b5..ab4887c39 100644 --- a/syntax/set.rs +++ b/syntax/set.rs @@ -59,6 +59,7 @@ mod unordered { // Wrapper prohibits accidentally introducing iteration over the set, which // could lead to nondeterministic generated code. + #[derive(Debug)] pub(crate) struct UnorderedSet(HashSet); impl UnorderedSet diff --git a/syntax/tokens.rs b/syntax/tokens.rs index 05eddc703..8becc82d9 100644 --- a/syntax/tokens.rs +++ b/syntax/tokens.rs @@ -28,6 +28,7 @@ impl ToTokens for Type { | Type::SharedPtr(ty) | Type::WeakPtr(ty) | Type::CxxVector(ty) + | Type::CxxFunction(ty) | Type::RustVec(ty) => ty.to_tokens(tokens), Type::Ref(r) | Type::Str(r) => r.to_tokens(tokens), Type::Ptr(p) => p.to_tokens(tokens), @@ -75,6 +76,12 @@ impl ToTokens for Ty1 { "Vec" => { tokens.extend(quote_spanned!(span=> ::cxx::alloc::vec::)); } + "CxxFunction" => { + tokens.extend(quote_spanned!(span=> ::cxx::CxxFunction {} } name.to_tokens(tokens); diff --git a/syntax/types.rs b/syntax/types.rs index 623a8b8d6..7b009c7d7 100644 --- a/syntax/types.rs +++ b/syntax/types.rs @@ -175,7 +175,7 @@ impl<'a> Types<'a> { Some(impl_key) => impl_key, None => continue, }; - let implicit_impl = match impl_key { + let implicit_impl = match &impl_key { ImplKey::RustBox(ident) | ImplKey::RustVec(ident) | ImplKey::UniquePtr(ident) @@ -184,6 +184,9 @@ impl<'a> Types<'a> { | ImplKey::CxxVector(ident) => { Atom::from(ident.rust).is_none() && !aliases.contains_key(ident.rust) } + ImplKey::CxxFunction(_) => { + true // always need an impl + } }; if implicit_impl && !impls.contains_key(&impl_key) { impls.insert(impl_key, None); diff --git a/syntax/visit.rs b/syntax/visit.rs index e31b8c41b..22a7b574b 100644 --- a/syntax/visit.rs +++ b/syntax/visit.rs @@ -17,6 +17,7 @@ where | Type::SharedPtr(ty) | Type::WeakPtr(ty) | Type::CxxVector(ty) + | Type::CxxFunction(ty) | Type::RustVec(ty) => visitor.visit_type(&ty.inner), Type::Ref(r) => visitor.visit_type(&r.inner), Type::Ptr(p) => visitor.visit_type(&p.inner), diff --git a/tests/ffi/lib.rs b/tests/ffi/lib.rs index ef8d5b371..4592ab3ef 100644 --- a/tests/ffi/lib.rs +++ b/tests/ffi/lib.rs @@ -300,6 +300,8 @@ pub mod ffi { fn r_take_ref_rust_vec(v: &Vec); fn r_take_ref_rust_vec_string(v: &Vec); fn r_take_enum(e: Enum); + fn r_call_ref_function(v: &mut CxxFunction); + fn r_call_ref_function_arg_ret(v: &mut CxxFunction i32>) -> i32; fn r_try_return_void() -> Result<()>; fn r_try_return_primitive() -> Result; @@ -622,6 +624,14 @@ fn r_take_enum(e: ffi::Enum) { let _ = e; } +fn r_call_ref_function(f: &mut ffi::CxxFunction) { + f.invoke(); +} + +fn r_call_ref_function_arg_ret(f: &mut ffi::CxxFunction i32>) -> i32 { + f.invoke(32) +} + fn r_try_return_void() -> Result<(), Error> { Ok(()) } diff --git a/tests/ffi/tests.cc b/tests/ffi/tests.cc index 8cf74bebb..e2256b2b4 100644 --- a/tests/ffi/tests.cc +++ b/tests/ffi/tests.cc @@ -8,6 +8,7 @@ #include #include #include +#include extern "C" void cxx_test_suite_set_correct() noexcept; extern "C" tests::R *cxx_test_suite_get_box() noexcept; @@ -898,6 +899,23 @@ extern "C" const char *cxx_run_test() noexcept { (void)rust::Vec(); (void)rust::Vec(); + bool called = false; + std::function f = [&] { + called = true; + }; + + ASSERT(!called); + r_call_ref_function(f); + ASSERT(called); + + // i32 pre_call = -1; + // std::function f = [&](int32_t a) { + // pre_call = a; + // return 45; + // }; + // ASSERT(r_call_ref_function_arg_ret(f) == 45); + // ASSERT(pre_call == 32); + cxx_test_suite_set_correct(); return nullptr; }