Skip to content

Commit

Permalink
cxx function
Browse files Browse the repository at this point in the history
  • Loading branch information
russelltg committed Nov 28, 2023
1 parent 6432e9e commit ab5ac82
Show file tree
Hide file tree
Showing 22 changed files with 369 additions and 32 deletions.
2 changes: 1 addition & 1 deletion gen/src/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ pub(super) fn generate(syntax: File, opt: &Opt) -> Result<GeneratedCode> {
cfg::strip(errors, cfg_errors, opt.cfg_evaluator.as_ref(), apis);
errors.propagate()?;

let ref types = Types::collect(errors, apis);
let ref types = Types::collect(errors, &apis);
check::precheck(errors, apis, opt);
errors.propagate()?;

Expand Down
57 changes: 56 additions & 1 deletion gen/src/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::gen::nested::NamespaceEntries;
use crate::gen::out::OutFile;
use crate::gen::{builtin, include, Opt};
use crate::syntax::atom::Atom::{self, *};
use crate::syntax::instantiate::{ImplKey, NamedImplKey};
use crate::syntax::instantiate::{FunctionImplKey, ImplKey, NamedImplKey};
use crate::syntax::map::UnorderedMap as Map;
use crate::syntax::set::UnorderedSet;
use crate::syntax::symbol::{self, Symbol};
Expand Down Expand Up @@ -218,6 +218,7 @@ fn pick_includes_and_builtins(out: &mut OutFile, apis: &[Api]) {
Type::SharedPtr(_) | Type::WeakPtr(_) => out.include.memory = true,
Type::Str(_) => out.builtin.rust_str = true,
Type::CxxVector(_) => out.include.vector = true,
Type::CxxFunction(_) => out.include.functional = true,
Type::Fn(_) => out.builtin.rust_fn = true,
Type::SliceRef(_) => out.builtin.rust_slice = true,
Type::Array(_) => out.include.array = true,
Expand Down Expand Up @@ -1255,6 +1256,25 @@ fn write_type(out: &mut OutFile, ty: &Type) {
write_type(out, &ty.inner);
write!(out, ">");
}
Type::CxxFunction(ty) => {
if let Type::Fn(f) = &ty.inner {
write!(out, "::std::function<");
match &f.ret {
Some(ret) => write_type(out, ret),
None => write!(out, "void"),
}
write!(out, "(");
for (i, arg) in f.args.iter().enumerate() {
if i > 0 {
write!(out, ", ");
}
write_type(out, &arg.ty);
}
write!(out, ")>");
} else {
unreachable!("should not have produced a CxxFunction with non-Fn inner")
}
}
Type::Ref(r) => {
write_type_space(out, &r.inner);
if !r.mutable {
Expand Down Expand Up @@ -1339,6 +1359,7 @@ fn write_space_after_type(out: &mut OutFile, ty: &Type) {
| Type::WeakPtr(_)
| Type::Str(_)
| Type::CxxVector(_)
| Type::CxxFunction(_)
| Type::RustVec(_)
| Type::SliceRef(_)
| Type::Fn(_)
Expand Down Expand Up @@ -1413,6 +1434,7 @@ fn write_generic_instantiations(out: &mut OutFile) {
ImplKey::SharedPtr(ident) => write_shared_ptr(out, ident),
ImplKey::WeakPtr(ident) => write_weak_ptr(out, ident),
ImplKey::CxxVector(ident) => write_cxx_vector(out, ident),
ImplKey::CxxFunction(ref key) => write_cxx_function(out, key),
}
}
out.end_block(Block::ExternC);
Expand Down Expand Up @@ -1956,3 +1978,36 @@ fn write_cxx_vector(out: &mut OutFile, key: NamedImplKey) {
out.include.memory = true;
write_unique_ptr_common(out, UniquePtr::CxxVector(element));
}

fn write_cxx_function(out: &mut OutFile, key: &FunctionImplKey) {
begin_function_definition(out);
let link_invoke = key.link_name_invoke(out.types);

if let Type::CxxFunction(ty) = key.ty {
if let Type::Fn(sig) = &ty.inner {
if let Some(ret) = &sig.ret {
write_type(out, &ret);
} else {
write!(out, "void")
}
write!(out, " {link_invoke}(");
write_type(out, &key.ty);
write!(out, " f");
for (id, a) in sig.args.iter().enumerate() {
write!(out, ", ");
write_type(out, &a.ty);
write!(out, " a_{id}");
}
writeln!(out, ") {{");
write!(out, " return f(");
for id in 0..sig.args.len() {
write!(out, "a_{id}");
if id != sig.args.len() - 1 {
write!(out, ", ");
}
}
writeln!(out, ");");
writeln!(out, "}}");
}
}
}
3 changes: 2 additions & 1 deletion macro/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ experimental-enum-variants-from-header = ["clang-ast", "flate2", "memmap", "serd
[dependencies]
proc-macro2 = "1.0.63"
quote = "1.0.29"
syn = { version = "2.0.23", features = ["full"] }
syn = { version = "2.0.23", features = ["full", "extra-traits"] }

# optional dependencies:
clang-ast = { version = "0.1.18", optional = true }
Expand All @@ -32,6 +32,7 @@ memmap = { version = "0.7", optional = true }
serde = { version = "1.0.166", optional = true }
serde_derive = { version = "1.0.166", optional = true }
serde_json = { version = "1.0.100", optional = true }
libc = "0.2.150"

[dev-dependencies]
cxx = { version = "1.0", path = ".." }
Expand Down
86 changes: 76 additions & 10 deletions macro/src/expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use crate::syntax::atom::Atom::*;
use crate::syntax::attrs::{self, OtherAttrs};
use crate::syntax::cfg::CfgExpr;
use crate::syntax::file::Module;
use crate::syntax::instantiate::{ImplKey, NamedImplKey};
use crate::syntax::instantiate::{FunctionImplKey, ImplKey, NamedImplKey};
use crate::syntax::qualified::QualifiedName;
use crate::syntax::report::Errors;
use crate::syntax::symbol::Symbol;
use crate::syntax::{
self, check, mangle, Api, Doc, Enum, ExternFn, ExternType, Impl, Lifetimes, Pair, Signature,
Struct, Trait, Type, TypeAlias, Types,
self, check, mangle, Api, Doc, Enum, ExternFn, ExternType, Impl, Lifetimes, Pair,
Signature, Struct, Trait, Type, TypeAlias, Types,
};
use crate::type_id::Crate;
use crate::{derive, generics};
Expand Down Expand Up @@ -91,25 +91,35 @@ fn expand(ffi: Module, doc: Doc, attrs: OtherAttrs, apis: &[Api], types: &Types)
}
}

let mut first_cxx_function = true;
for (impl_key, &explicit_impl) in &types.impls {
match *impl_key {
match impl_key {
ImplKey::RustBox(ident) => {
hidden.extend(expand_rust_box(ident, types, explicit_impl));
hidden.extend(expand_rust_box(*ident, types, explicit_impl));
}
ImplKey::RustVec(ident) => {
hidden.extend(expand_rust_vec(ident, types, explicit_impl));
hidden.extend(expand_rust_vec(*ident, types, explicit_impl));
}
ImplKey::UniquePtr(ident) => {
expanded.extend(expand_unique_ptr(ident, types, explicit_impl));
expanded.extend(expand_unique_ptr(*ident, types, explicit_impl));
}
ImplKey::SharedPtr(ident) => {
expanded.extend(expand_shared_ptr(ident, types, explicit_impl));
expanded.extend(expand_shared_ptr(*ident, types, explicit_impl));
}
ImplKey::WeakPtr(ident) => {
expanded.extend(expand_weak_ptr(ident, types, explicit_impl));
expanded.extend(expand_weak_ptr(*ident, types, explicit_impl));
}
ImplKey::CxxVector(ident) => {
expanded.extend(expand_cxx_vector(ident, explicit_impl, types));
expanded.extend(expand_cxx_vector(*ident, explicit_impl, types));
}
ImplKey::CxxFunction(sig) => {
expanded.extend(expand_cxx_function(
&sig,
explicit_impl,
types,
first_cxx_function,
));
first_cxx_function = false;
}
}
}
Expand Down Expand Up @@ -1674,6 +1684,7 @@ fn expand_cxx_vector(
let elem = key.rust;
let name = elem.to_string();
let resolve = types.resolve(elem);

let prefix = format!("cxxbridge1$std$vector${}$", resolve.name.to_symbol());
let link_new = format!("{}new", prefix);
let link_size = format!("{}size", prefix);
Expand Down Expand Up @@ -1819,6 +1830,61 @@ fn expand_cxx_vector(
}
}

fn expand_cxx_function(
key: &FunctionImplKey,
_explicit_impl: Option<&Impl>,
types: &Types,
first_cxx_function: bool,
) -> TokenStream {
let link_invoke = key.link_name_invoke(types);

if let Type::CxxFunction(ty) = key.ty {
if let Type::Fn(sig) = &ty.inner {
let only_first_time = if first_cxx_function {
quote! {
pub struct CrateFnImpls;
pub type CxxFunction<F> = ::cxx::CxxFunction<CrateFnImpls, F>;
}
} else {
Default::default()
};

let ret = if let Some(Type::Ident(r)) = &sig.ret {
quote! { #r }
} else {
quote! { () }
};

let mut arg_types = quote!{};
let mut arg_names = quote!{};
let mut arg_type_names = quote!{};
for (id, a) in sig.args.iter().enumerate() {
if let Type::Ident(ident) = &a.ty {
let arg_name = format_ident!("a_{id}");
arg_types = quote!{ #arg_types #ident, };
arg_names = quote!{ #arg_names #arg_name, };
arg_type_names = quote!{ #arg_type_names #arg_name: #ident, };
}
}

return quote_spanned! {ty.name.span() =>
#only_first_time

unsafe impl ::cxx::private::CxxFunctionImpl<#ret, (#arg_types)> for CrateFnImpls {
fn __invoke(this: *mut ::cxx::core::ffi::c_void, (#arg_names): (#arg_types)) -> #ret {
extern "C" {
#[link_name = #link_invoke]
fn __function_invoke(this: *mut ::cxx::core::ffi::c_void, #arg_type_names) -> #ret;
}
unsafe { __function_invoke(this, #arg_names) }
}
}
};
}
}
panic!("bad type");
}

fn expand_return_type(ret: &Option<Type>) -> TokenStream {
match ret {
Some(ret) => quote!(-> #ret),
Expand Down
37 changes: 37 additions & 0 deletions src/cxx_function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use core::{ffi::c_void, marker::PhantomData};

///
pub unsafe trait CxxFunctionImpl<Ret, Args> {
#[doc(hidden)]
fn __invoke(f: *mut c_void, a: Args) -> Ret;
}

/// Bridge to std::function<Fn>
#[repr(C)]
pub struct CxxFunction<I: ?Sized, F> {
// A thing, because repr(C) structs are not allowed to consist exclusively
// of PhantomData fields.
_void: [c_void; 0],
_impl: PhantomData<I>,
_fn: PhantomData<F>,
}

impl<I, Out> CxxFunction<I, fn() -> Out>
where
I: CxxFunctionImpl<Out, ()>,
{
/// Run the std::function
pub fn invoke(&mut self) -> Out {
<I as CxxFunctionImpl<Out, ()>>::__invoke(self as *mut _ as _, ())
}
}

impl<I, In, Out> CxxFunction<I, fn(In) -> Out>
where
I: CxxFunctionImpl<Out, (In,)>,
{
/// Run the std::function
pub fn invoke(&mut self, a: In) -> Out {
<I as CxxFunctionImpl<Out, (In,)>>::__invoke(self as *mut _ as _, (a,))
}
}
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ compile_error! {
mod macros;

mod c_char;
mod cxx_function;
mod cxx_vector;
mod exception;
mod extern_type;
Expand All @@ -475,6 +476,7 @@ mod unwind;
pub mod vector;
mod weak_ptr;

pub use crate::cxx_function::CxxFunction;
pub use crate::cxx_vector::CxxVector;
#[cfg(feature = "alloc")]
pub use crate::exception::Exception;
Expand Down Expand Up @@ -503,6 +505,7 @@ pub type Vector<T> = CxxVector<T>;
#[doc(hidden)]
pub mod private {
pub use crate::c_char::c_char;
pub use crate::cxx_function::CxxFunctionImpl;
pub use crate::cxx_vector::VectorElement;
pub use crate::extern_type::{verify_extern_kind, verify_extern_type};
pub use crate::function::FatFunction;
Expand Down
37 changes: 35 additions & 2 deletions syntax/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ fn do_typecheck(cx: &mut Check) {
Type::SharedPtr(ptr) => check_type_shared_ptr(cx, ptr),
Type::WeakPtr(ptr) => check_type_weak_ptr(cx, ptr),
Type::CxxVector(ptr) => check_type_cxx_vector(cx, ptr),
Type::CxxFunction(ptr) => check_type_cxx_function(cx, ptr),
Type::Ref(ty) => check_type_ref(cx, ty),
Type::Ptr(ty) => check_type_ptr(cx, ty),
Type::Array(array) => check_type_array(cx, array),
Expand Down Expand Up @@ -225,6 +226,34 @@ fn check_type_cxx_vector(cx: &mut Check, ptr: &Ty1) {
cx.error(ptr, "unsupported vector element type");
}

fn check_type_cxx_function(cx: &mut Check, ptr: &Ty1) {
if let Type::Fn(ptr) = &ptr.inner {
if ptr.asyncness.is_some() {
cx.error(ptr, "CxxFunction signature must not be async");
}
if ptr.unsafety.is_some() {
cx.error(
ptr,
"unimplemented: CxxFunction signature must not be unsafe",
);
}
if ptr.generics.lt_token.is_some() {
cx.error(ptr, "CxxFunction signature must not be generic");
}
if ptr.receiver.is_some() {
cx.error(ptr, "CxxFunction signature must not have a receiver");
}
if ptr.throws {
cx.error(ptr, "CxxFunction signature must not throw");
}
// if let Type::Fn(_sig) = &ptr.inner {
// return;
// }
return;
}
cx.error(ptr, "CxxFunction must have fn(a, b) -> c generic");
}

fn check_type_ref(cx: &mut Check, ty: &Ref) {
if ty.mutable && !ty.pinned {
if let Some(requires_pin) = match &ty.inner {
Expand Down Expand Up @@ -526,7 +555,9 @@ fn check_api_impl(cx: &mut Check, imp: &Impl) {
| Type::UniquePtr(ty)
| Type::SharedPtr(ty)
| Type::WeakPtr(ty)
| Type::CxxVector(ty) => {
| Type::CxxVector(ty)
| Type::CxxFunction(ty) // ?????
=> {
if let Type::Ident(inner) = &ty.inner {
if Atom::from(&inner.rust).is_none() {
return;
Expand Down Expand Up @@ -609,6 +640,7 @@ fn check_reserved_name(cx: &mut Check, ident: &Ident) {
|| ident == "WeakPtr"
|| ident == "Vec"
|| ident == "CxxVector"
|| ident == "CxxFunction"
|| ident == "str"
|| Atom::from(ident).is_some()
{
Expand Down Expand Up @@ -648,7 +680,7 @@ fn is_unsized(cx: &mut Check, ty: &Type) -> bool {
ident == CxxString || is_opaque_cxx(cx, ident) || cx.types.rust.contains(ident)
}
Type::Array(array) => is_unsized(cx, &array.inner),
Type::CxxVector(_) | Type::Fn(_) | Type::Void(_) => true,
Type::CxxVector(_) | Type::CxxFunction(_) | Type::Fn(_) | Type::Void(_) => true, // NOTE: maybe actually sized? idk
Type::RustBox(_)
| Type::RustVec(_)
| Type::UniquePtr(_)
Expand Down Expand Up @@ -732,6 +764,7 @@ fn describe(cx: &mut Check, ty: &Type) -> String {
Type::Ptr(_) => "raw pointer".to_owned(),
Type::Str(_) => "&str".to_owned(),
Type::CxxVector(_) => "C++ vector".to_owned(),
Type::CxxFunction(_) => "C++ function".to_owned(),
Type::SliceRef(_) => "slice".to_owned(),
Type::Fn(_) => "function pointer".to_owned(),
Type::Void(_) => "()".to_owned(),
Expand Down
Loading

0 comments on commit ab5ac82

Please sign in to comment.