From 2352cb02258c239ab6f491cc2a4c74109f6b2916 Mon Sep 17 00:00:00 2001 From: luaneko Date: Wed, 25 Jun 2025 01:36:43 +1000 Subject: [PATCH] Implement async support in metatype --- crates/luaffi/src/future.rs | 9 +- crates/luaffi/src/lib.rs | 140 ++++++++++++++++++++--------- crates/luaffi_impl/src/metatype.rs | 103 ++++++++++++--------- src/main.rs | 2 +- 4 files changed, 169 insertions(+), 85 deletions(-) diff --git a/crates/luaffi/src/future.rs b/crates/luaffi/src/future.rs index d898b33..9740929 100644 --- a/crates/luaffi/src/future.rs +++ b/crates/luaffi/src/future.rs @@ -1,7 +1,7 @@ use crate::{ __internal::{display, type_id}, - Cdef, CdefBuilder, IntoFfi, Metatype, MetatypeBuilder, Type, TypeBuilder, TypeType, - UnsafeExternCFn, + Cdef, CdefBuilder, FfiReturnConvention, IntoFfi, Metatype, MetatypeBuilder, Type, TypeBuilder, + TypeType, UnsafeExternCFn, }; use luaify::luaify; use std::{ @@ -168,6 +168,11 @@ unsafe impl + 'static> Metatype for lua_future { unsafe impl + 'static> IntoFfi for lua_future { type Into = lua_future; + fn convention() -> FfiReturnConvention { + // futures are always returned by-value due to rust type inference limitations + FfiReturnConvention::ByValue + } + fn convert(self) -> Self::Into { self } diff --git a/crates/luaffi/src/lib.rs b/crates/luaffi/src/lib.rs index 72e097b..a371877 100644 --- a/crates/luaffi/src/lib.rs +++ b/crates/luaffi/src/lib.rs @@ -387,6 +387,12 @@ impl<'r> Drop for MetatypeBuilder<'r> { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FfiReturnConvention { + ByValue, + ByOutParam, +} + pub unsafe trait FromFfi: Sized { type From: Type + Sized; @@ -404,6 +410,13 @@ pub unsafe trait FromFfi: Sized { pub unsafe trait IntoFfi: Sized { type Into: Type + Sized; + fn convention() -> FfiReturnConvention { + match Self::Into::ty() { + TypeType::Void | TypeType::Primitive => FfiReturnConvention::ByValue, + TypeType::Aggregate => FfiReturnConvention::ByOutParam, + } + } + fn postlude(_ret: &str) -> impl Display { "" } @@ -414,8 +427,9 @@ pub unsafe trait IntoFfi: Sized { #[derive(Debug)] pub struct MetatypeMethodBuilder<'r, 'm> { metatype: &'m mut MetatypeBuilder<'r>, - params: String, // parameters to the lua function - args: String, // arguments to the C call + lparams: String, // parameters to the lua function + cparams: String, // parameters to the lua function + cargs: String, // arguments to the C call prelude: String, // function body prelude postlude: String, // function body postlude } @@ -424,8 +438,9 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> { pub fn new(metatype: &'m mut MetatypeBuilder<'r>) -> Self { Self { metatype, - params: String::new(), - args: String::new(), + lparams: String::new(), + cparams: String::new(), + cargs: String::new(), prelude: String::new(), postlude: String::new(), } @@ -437,18 +452,32 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> { "cannot declare void parameter" ); - (!self.params.is_empty()).then(|| self.params.push_str(", ")); - (!self.args.is_empty()).then(|| self.args.push_str(", ")); - write!(self.params, "{name}").unwrap(); - write!(self.args, "{name}").unwrap(); + let Self { + metatype: MetatypeBuilder { registry, .. }, + lparams, + cparams, + cargs, + prelude, + postlude, + .. + } = self; + + registry.include::(); + + (!lparams.is_empty()).then(|| lparams.push_str(", ")); + (!cparams.is_empty()).then(|| cparams.push_str(", ")); + (!cargs.is_empty()).then(|| cargs.push_str(", ")); + + write!(lparams, "{name}").unwrap(); + write!(cparams, "{}", T::From::cdecl(&name)).unwrap(); + write!(cargs, "{name}").unwrap(); if T::require_keepalive() { - write!(self.prelude, "local __keep_{name} = {name}; ").unwrap(); - write!(self.postlude, "__C.{KEEP_FN}(__keep_{name}); ").unwrap(); + write!(prelude, "local __keep_{name} = {name}; ").unwrap(); + write!(postlude, "__C.{KEEP_FN}(__keep_{name}); ").unwrap(); } - let name = name.to_string(); - write!(self.prelude, "{}", T::prelude(&name)).unwrap(); + write!(prelude, "{}", T::prelude(&name.to_string())).unwrap(); self } @@ -458,13 +487,27 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> { // this passes one lua `string` argument as two C `const uint8_t *ptr` and `uintptr_t len` // arguments, bypassing the slower generic `&[u8]: FromFfi` path which constructs a // temporary cdata to pass the string and its length in one argument - (!self.params.is_empty()).then(|| self.params.push_str(", ")); - (!self.args.is_empty()).then(|| self.args.push_str(", ")); - write!(self.params, "{name}").unwrap(); - write!(self.args, "{name}, __{name}_len").unwrap(); - write!(self.prelude, "local __{name}_len = 0; ").unwrap(); + let Self { + lparams, + cparams, + cargs, + prelude, + .. + } = self; + + let param_ptr = <*const u8>::cdecl("ptr"); + let param_len = usize::cdecl("len"); + + (!lparams.is_empty()).then(|| lparams.push_str(", ")); + (!cparams.is_empty()).then(|| cparams.push_str(", ")); + (!cargs.is_empty()).then(|| cargs.push_str(", ")); + + write!(lparams, "{name}").unwrap(); + write!(cparams, "{param_ptr}, {param_len}",).unwrap(); + write!(cargs, "{name}, __{name}_len").unwrap(); + write!(prelude, "local __{name}_len = 0; ").unwrap(); write!( - self.prelude, + prelude, r#"if {name} ~= nil then assert(type({name}) == "string", "string expected in argument '{name}', got " .. type({name})); __{name}_len = #{name}; end; "# ) .unwrap(); @@ -472,46 +515,63 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> { } pub fn param_ignored(&mut self) -> &mut Self { - (!self.params.is_empty()).then(|| self.params.push_str(", ")); - write!(self.params, "_").unwrap(); + (!self.lparams.is_empty()).then(|| self.lparams.push_str(", ")); + write!(self.lparams, "_").unwrap(); self } pub fn call(&mut self, func: impl Display) { let Self { - metatype, - params, - args, + metatype: + MetatypeBuilder { + registry, + cdef, + lua, + .. + }, + lparams, + cparams, + cargs, prelude, postlude, + .. } = self; - let lua = &mut metatype.lua; - write!(lua, "function({params}) {prelude}").unwrap(); + registry.include::(); + write!(lua, "function({lparams}) {prelude}").unwrap(); - match T::Into::ty() { - TypeType::Void => { - write!(lua, "__C.{func}({args}); {postlude}end").unwrap(); + match T::convention() { + FfiReturnConvention::ByValue => { + if T::Into::ty() == TypeType::Void { + write!(lua, "__C.{func}({cargs}); {postlude}end").unwrap(); + } else { + let check = T::postlude("__res"); + write!(lua, "local __res = __C.{func}({cargs}); ").unwrap(); + write!(lua, "{check}{postlude}return __res; end").unwrap(); + } + + writeln!(cdef, "{};", T::Into::cdecl(display!("{func}({cparams})"))).unwrap(); } - TypeType::Primitive => { - let check = T::postlude("__res"); - write!( - lua, - "local __res = __C.{func}({args}); {check}{postlude}return __res; end" - ) - .unwrap(); - } - TypeType::Aggregate => { + FfiReturnConvention::ByOutParam => { let ct = T::Into::name(); let check = T::postlude("__res"); write!(lua, "local __res = __new(__ct.{ct}); __C.{func}(__res").unwrap(); - if !args.is_empty() { - write!(lua, ", {args}").unwrap(); + if !cargs.is_empty() { + write!(lua, ", {cargs}").unwrap(); } - write!(lua, "); {check}{postlude}return __res; end").unwrap() + write!(lua, "); {check}{postlude}return __res; end").unwrap(); + write!(cdef, "void {func}({}", <*mut T::Into>::cdecl("out")).unwrap(); + if !cparams.is_empty() { + write!(cdef, ", {cparams}").unwrap(); + } + writeln!(cdef, ");").unwrap(); } } } + + pub fn call_inferred(&mut self, func: impl Display, _infer: impl FnOnce() -> T) { + self.call::(func) + } } // diff --git a/crates/luaffi_impl/src/metatype.rs b/crates/luaffi_impl/src/metatype.rs index 6755fc7..ddab6a6 100644 --- a/crates/luaffi_impl/src/metatype.rs +++ b/crates/luaffi_impl/src/metatype.rs @@ -3,7 +3,7 @@ use crate::utils::{ }; use proc_macro2::TokenStream; use quote::{ToTokens, format_ident, quote, quote_spanned}; -use std::{collections::HashSet, fmt}; +use std::{collections::HashSet, fmt, iter}; use syn::{ext::IdentExt, punctuated::Punctuated, spanned::Spanned, *}; pub fn transform(mut imp: ItemImpl) -> Result { @@ -181,6 +181,7 @@ impl ToTokens for Metamethod { struct FfiFunction { name: Ident, + is_async: bool, params: Vec, ret: Type, attrs: FfiFunctionAttrs, @@ -225,6 +226,7 @@ fn get_ffi_functions(imp: &mut ItemImpl) -> Result> { funcs.push(FfiFunction { name: func.sig.ident.clone(), + is_async: func.sig.asyncness.is_some(), params, ret, attrs, @@ -267,8 +269,8 @@ fn get_ffi_param_type(_ty: &Type) -> FfiParameterType { enum FfiReturnType { Void, - Primitive, - Aggregate, + ByValue, + ByOutParam, } fn get_ffi_ret_type(ty: &Type) -> FfiReturnType { @@ -288,9 +290,9 @@ fn get_ffi_ret_type(ty: &Type) -> FfiReturnType { if is_unit(ty) { FfiReturnType::Void } else if is_primitivelike(ty) { - FfiReturnType::Primitive + FfiReturnType::ByValue } else { - FfiReturnType::Aggregate + FfiReturnType::ByOutParam } } @@ -316,16 +318,23 @@ fn add_ffi_function(registry: &mut FfiRegistry, func: &FfiFunction) -> Result<() let func_name = &func.name; let shim_name = format_ident!("__ffi_{}", func_name.unraw()); let lua_name = format!("{}", func_name.unraw()); - let c_name = format!("{}_{}", ty.unraw(), func_name.unraw()); + let c_name = if let Some(priv_name) = lua_name.strip_prefix("__") { + format!("__{}_{priv_name}", ty.unraw()) + } else { + format!("{}_{lua_name}", ty.unraw()) + }; let func_params = &func.params; // target function parameters let func_ret = &func.ret; // target function return type let mut func_args = vec![]; // target function arguments let mut shim_params = vec![]; // shim function parameters - let mut shim_ret = quote_spanned!(func_ret.span() => // shim function return type - <#func_ret as #ffi::IntoFfi>::Into - ); + let mut shim_ret = if func.is_async { + // shim function return type + quote_spanned!(func_ret.span() => #ffi::future::lua_future>) + } else { + quote_spanned!(func_ret.span() => <#func_ret as #ffi::IntoFfi>::Into) + }; let mut asserts = vec![]; // compile-time builder asserts let mut build = vec![]; // ffi builder body @@ -358,43 +367,53 @@ fn add_ffi_function(registry: &mut FfiRegistry, func: &FfiFunction) -> Result<() } } - let mut shim_body = quote_spanned!(func_name.span() => // shim function body - <#func_ret as #ffi::IntoFfi>::convert(Self::#func_name(#(#func_args),*)) - ); - - match get_ffi_ret_type(func_ret) { - FfiReturnType::Void => { - asserts.push(quote_spanned!(func_ret.span() => - <<#func_ret as #ffi::IntoFfi>::Into as #ffi::Type>::ty() == #ffi::TypeType::Void - )); - } - FfiReturnType::Primitive => { - asserts.push(quote_spanned!(func_ret.span() => - <<#func_ret as #ffi::IntoFfi>::Into as #ffi::Type>::ty() == #ffi::TypeType::Primitive - )); - } - FfiReturnType::Aggregate => { - asserts.push(quote_spanned!(func_ret.span() => - <<#func_ret as #ffi::IntoFfi>::Into as #ffi::Type>::ty() == #ffi::TypeType::Aggregate - )); - - shim_params.insert(0, quote!(out: *mut #shim_ret)); - (shim_body, shim_ret) = (quote!(::std::ptr::write(out, #shim_body)), quote!(())); - } + // shim function body + let mut shim_body = if func.is_async { + // for async functions, wrapped the returned future in lua_future + quote_spanned!(func_name.span() => #ffi::future::lua_future::new(Self::#func_name(#(#func_args),*))) + } else { + quote_spanned!(func_name.span() => <#func_ret as #ffi::IntoFfi>::convert(Self::#func_name(#(#func_args),*))) }; - build.push(quote_spanned!(func_name.span() => - b.call::<#func_ret>(#c_name); - )); + if !func.is_async { + match get_ffi_ret_type(&func_ret) { + FfiReturnType::Void => { + asserts.push(quote_spanned!(func_ret.span() => + <<#func_ret as #ffi::IntoFfi>::Into as #ffi::Type>::ty() == #ffi::TypeType::Void + )); + } + FfiReturnType::ByValue => { + asserts.push(quote_spanned!(func_ret.span() => + <#func_ret as #ffi::IntoFfi>::convention() == #ffi::FfiReturnConvention::ByValue + )); + } + FfiReturnType::ByOutParam => { + asserts.push(quote_spanned!(func_ret.span() => + <#func_ret as #ffi::IntoFfi>::convention() == #ffi::FfiReturnConvention::ByOutParam + )); - let shim_params_ty = { - let tys: Punctuated = parse_quote!(#(#shim_params),*); - tys.iter().map(|pat| (*pat.ty).clone()).collect::>() - }; + shim_params.insert(0, quote!(out: *mut #shim_ret)); + (shim_body, shim_ret) = (quote!(::std::ptr::write(out, #shim_body)), quote!(())); + } + }; + } - registry.build.push(quote!( + // build.push(quote_spanned!(func_name.span() => + // b.call_inferred(#c_name, Self::#func_name); + // )); + + build.push({ + let infer_args = iter::repeat_n(quote!(::std::unreachable!()), func_params.len()); + let infer = if func.is_async { + quote!(|| #ffi::future::lua_future::new(Self::#func_name(#(#infer_args),*))) + } else { + quote!(|| Self::#func_name(#(#infer_args),*)) + }; + quote_spanned!(func_name.span() => b.call_inferred(#c_name, #infer);) + }); + + registry.build.push(quote_spanned!(func_name.span() => #(::std::assert!(#asserts);)* - b.declare::<#ffi::UnsafeExternCFn<(#(#shim_params_ty,)*), #shim_ret>>(#c_name); )); registry.build.push(match func.attrs.metamethod { @@ -678,7 +697,7 @@ fn inject_merged_drop(registry: &mut FfiRegistry, lua: Option<&LuaFunction>) -> if ::std::mem::needs_drop::() { // we only have a rust drop b.declare::<#ffi::UnsafeExternCFn<(*mut Self,), ()>>(#c_name_str); - b.metatable_raw("gc", #luaify(|self| { __C::#c_name(self); })); + b.metatable_raw("gc", ::std::format_args!("__C.{}", #c_name_str)); } )); } diff --git a/src/main.rs b/src/main.rs index ab78713..c4c0247 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,7 +18,7 @@ fn panic_cb(panic: &panic::PanicHookInfo) { "unknown error" }; - eprintln!( + eprint!( "{}:\n{trace}", format_args!( "thread '{}' panicked at {location}: {msg}",