Implement async support in metatype

This commit is contained in:
lumi 2025-06-25 01:36:43 +10:00
parent cbf786206d
commit 2352cb0225
Signed by: luaneko
GPG Key ID: 406809B8763FF07A
4 changed files with 169 additions and 85 deletions

View File

@ -1,7 +1,7 @@
use crate::{
__internal::{display, type_id},
Cdef, CdefBuilder, IntoFfi, Metatype, MetatypeBuilder, Type, TypeBuilder, TypeType,
UnsafeExternCFn,
Cdef, CdefBuilder, FfiReturnConvention, IntoFfi, Metatype, MetatypeBuilder, Type, TypeBuilder,
TypeType, UnsafeExternCFn,
};
use luaify::luaify;
use std::{
@ -168,6 +168,11 @@ unsafe impl<F: Future<Output: IntoFfi> + 'static> Metatype for lua_future<F> {
unsafe impl<F: Future<Output: IntoFfi> + 'static> IntoFfi for lua_future<F> {
type Into = lua_future<F>;
fn convention() -> FfiReturnConvention {
// futures are always returned by-value due to rust type inference limitations
FfiReturnConvention::ByValue
}
fn convert(self) -> Self::Into {
self
}

View File

@ -387,6 +387,12 @@ impl<'r> Drop for MetatypeBuilder<'r> {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FfiReturnConvention {
ByValue,
ByOutParam,
}
pub unsafe trait FromFfi: Sized {
type From: Type + Sized;
@ -404,6 +410,13 @@ pub unsafe trait FromFfi: Sized {
pub unsafe trait IntoFfi: Sized {
type Into: Type + Sized;
fn convention() -> FfiReturnConvention {
match Self::Into::ty() {
TypeType::Void | TypeType::Primitive => FfiReturnConvention::ByValue,
TypeType::Aggregate => FfiReturnConvention::ByOutParam,
}
}
fn postlude(_ret: &str) -> impl Display {
""
}
@ -414,8 +427,9 @@ pub unsafe trait IntoFfi: Sized {
#[derive(Debug)]
pub struct MetatypeMethodBuilder<'r, 'm> {
metatype: &'m mut MetatypeBuilder<'r>,
params: String, // parameters to the lua function
args: String, // arguments to the C call
lparams: String, // parameters to the lua function
cparams: String, // parameters to the lua function
cargs: String, // arguments to the C call
prelude: String, // function body prelude
postlude: String, // function body postlude
}
@ -424,8 +438,9 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> {
pub fn new(metatype: &'m mut MetatypeBuilder<'r>) -> Self {
Self {
metatype,
params: String::new(),
args: String::new(),
lparams: String::new(),
cparams: String::new(),
cargs: String::new(),
prelude: String::new(),
postlude: String::new(),
}
@ -437,18 +452,32 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> {
"cannot declare void parameter"
);
(!self.params.is_empty()).then(|| self.params.push_str(", "));
(!self.args.is_empty()).then(|| self.args.push_str(", "));
write!(self.params, "{name}").unwrap();
write!(self.args, "{name}").unwrap();
let Self {
metatype: MetatypeBuilder { registry, .. },
lparams,
cparams,
cargs,
prelude,
postlude,
..
} = self;
registry.include::<T::From>();
(!lparams.is_empty()).then(|| lparams.push_str(", "));
(!cparams.is_empty()).then(|| cparams.push_str(", "));
(!cargs.is_empty()).then(|| cargs.push_str(", "));
write!(lparams, "{name}").unwrap();
write!(cparams, "{}", T::From::cdecl(&name)).unwrap();
write!(cargs, "{name}").unwrap();
if T::require_keepalive() {
write!(self.prelude, "local __keep_{name} = {name}; ").unwrap();
write!(self.postlude, "__C.{KEEP_FN}(__keep_{name}); ").unwrap();
write!(prelude, "local __keep_{name} = {name}; ").unwrap();
write!(postlude, "__C.{KEEP_FN}(__keep_{name}); ").unwrap();
}
let name = name.to_string();
write!(self.prelude, "{}", T::prelude(&name)).unwrap();
write!(prelude, "{}", T::prelude(&name.to_string())).unwrap();
self
}
@ -458,13 +487,27 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> {
// this passes one lua `string` argument as two C `const uint8_t *ptr` and `uintptr_t len`
// arguments, bypassing the slower generic `&[u8]: FromFfi` path which constructs a
// temporary cdata to pass the string and its length in one argument
(!self.params.is_empty()).then(|| self.params.push_str(", "));
(!self.args.is_empty()).then(|| self.args.push_str(", "));
write!(self.params, "{name}").unwrap();
write!(self.args, "{name}, __{name}_len").unwrap();
write!(self.prelude, "local __{name}_len = 0; ").unwrap();
let Self {
lparams,
cparams,
cargs,
prelude,
..
} = self;
let param_ptr = <*const u8>::cdecl("ptr");
let param_len = usize::cdecl("len");
(!lparams.is_empty()).then(|| lparams.push_str(", "));
(!cparams.is_empty()).then(|| cparams.push_str(", "));
(!cargs.is_empty()).then(|| cargs.push_str(", "));
write!(lparams, "{name}").unwrap();
write!(cparams, "{param_ptr}, {param_len}",).unwrap();
write!(cargs, "{name}, __{name}_len").unwrap();
write!(prelude, "local __{name}_len = 0; ").unwrap();
write!(
self.prelude,
prelude,
r#"if {name} ~= nil then assert(type({name}) == "string", "string expected in argument '{name}', got " .. type({name})); __{name}_len = #{name}; end; "#
)
.unwrap();
@ -472,46 +515,63 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> {
}
pub fn param_ignored(&mut self) -> &mut Self {
(!self.params.is_empty()).then(|| self.params.push_str(", "));
write!(self.params, "_").unwrap();
(!self.lparams.is_empty()).then(|| self.lparams.push_str(", "));
write!(self.lparams, "_").unwrap();
self
}
pub fn call<T: IntoFfi>(&mut self, func: impl Display) {
let Self {
metatype,
params,
args,
metatype:
MetatypeBuilder {
registry,
cdef,
lua,
..
},
lparams,
cparams,
cargs,
prelude,
postlude,
..
} = self;
let lua = &mut metatype.lua;
write!(lua, "function({params}) {prelude}").unwrap();
registry.include::<T::Into>();
write!(lua, "function({lparams}) {prelude}").unwrap();
match T::Into::ty() {
TypeType::Void => {
write!(lua, "__C.{func}({args}); {postlude}end").unwrap();
match T::convention() {
FfiReturnConvention::ByValue => {
if T::Into::ty() == TypeType::Void {
write!(lua, "__C.{func}({cargs}); {postlude}end").unwrap();
} else {
let check = T::postlude("__res");
write!(lua, "local __res = __C.{func}({cargs}); ").unwrap();
write!(lua, "{check}{postlude}return __res; end").unwrap();
}
writeln!(cdef, "{};", T::Into::cdecl(display!("{func}({cparams})"))).unwrap();
}
TypeType::Primitive => {
let check = T::postlude("__res");
write!(
lua,
"local __res = __C.{func}({args}); {check}{postlude}return __res; end"
)
.unwrap();
}
TypeType::Aggregate => {
FfiReturnConvention::ByOutParam => {
let ct = T::Into::name();
let check = T::postlude("__res");
write!(lua, "local __res = __new(__ct.{ct}); __C.{func}(__res").unwrap();
if !args.is_empty() {
write!(lua, ", {args}").unwrap();
if !cargs.is_empty() {
write!(lua, ", {cargs}").unwrap();
}
write!(lua, "); {check}{postlude}return __res; end").unwrap()
write!(lua, "); {check}{postlude}return __res; end").unwrap();
write!(cdef, "void {func}({}", <*mut T::Into>::cdecl("out")).unwrap();
if !cparams.is_empty() {
write!(cdef, ", {cparams}").unwrap();
}
writeln!(cdef, ");").unwrap();
}
}
}
pub fn call_inferred<T: IntoFfi>(&mut self, func: impl Display, _infer: impl FnOnce() -> T) {
self.call::<T>(func)
}
}
//

View File

@ -3,7 +3,7 @@ use crate::utils::{
};
use proc_macro2::TokenStream;
use quote::{ToTokens, format_ident, quote, quote_spanned};
use std::{collections::HashSet, fmt};
use std::{collections::HashSet, fmt, iter};
use syn::{ext::IdentExt, punctuated::Punctuated, spanned::Spanned, *};
pub fn transform(mut imp: ItemImpl) -> Result<TokenStream> {
@ -181,6 +181,7 @@ impl ToTokens for Metamethod {
struct FfiFunction {
name: Ident,
is_async: bool,
params: Vec<PatType>,
ret: Type,
attrs: FfiFunctionAttrs,
@ -225,6 +226,7 @@ fn get_ffi_functions(imp: &mut ItemImpl) -> Result<Vec<FfiFunction>> {
funcs.push(FfiFunction {
name: func.sig.ident.clone(),
is_async: func.sig.asyncness.is_some(),
params,
ret,
attrs,
@ -267,8 +269,8 @@ fn get_ffi_param_type(_ty: &Type) -> FfiParameterType {
enum FfiReturnType {
Void,
Primitive,
Aggregate,
ByValue,
ByOutParam,
}
fn get_ffi_ret_type(ty: &Type) -> FfiReturnType {
@ -288,9 +290,9 @@ fn get_ffi_ret_type(ty: &Type) -> FfiReturnType {
if is_unit(ty) {
FfiReturnType::Void
} else if is_primitivelike(ty) {
FfiReturnType::Primitive
FfiReturnType::ByValue
} else {
FfiReturnType::Aggregate
FfiReturnType::ByOutParam
}
}
@ -316,16 +318,23 @@ fn add_ffi_function(registry: &mut FfiRegistry, func: &FfiFunction) -> Result<()
let func_name = &func.name;
let shim_name = format_ident!("__ffi_{}", func_name.unraw());
let lua_name = format!("{}", func_name.unraw());
let c_name = format!("{}_{}", ty.unraw(), func_name.unraw());
let c_name = if let Some(priv_name) = lua_name.strip_prefix("__") {
format!("__{}_{priv_name}", ty.unraw())
} else {
format!("{}_{lua_name}", ty.unraw())
};
let func_params = &func.params; // target function parameters
let func_ret = &func.ret; // target function return type
let mut func_args = vec![]; // target function arguments
let mut shim_params = vec![]; // shim function parameters
let mut shim_ret = quote_spanned!(func_ret.span() => // shim function return type
<#func_ret as #ffi::IntoFfi>::Into
);
let mut shim_ret = if func.is_async {
// shim function return type
quote_spanned!(func_ret.span() => #ffi::future::lua_future<impl ::std::future::Future<Output = #func_ret>>)
} else {
quote_spanned!(func_ret.span() => <#func_ret as #ffi::IntoFfi>::Into)
};
let mut asserts = vec![]; // compile-time builder asserts
let mut build = vec![]; // ffi builder body
@ -358,43 +367,53 @@ fn add_ffi_function(registry: &mut FfiRegistry, func: &FfiFunction) -> Result<()
}
}
let mut shim_body = quote_spanned!(func_name.span() => // shim function body
<#func_ret as #ffi::IntoFfi>::convert(Self::#func_name(#(#func_args),*))
);
match get_ffi_ret_type(func_ret) {
FfiReturnType::Void => {
asserts.push(quote_spanned!(func_ret.span() =>
<<#func_ret as #ffi::IntoFfi>::Into as #ffi::Type>::ty() == #ffi::TypeType::Void
));
}
FfiReturnType::Primitive => {
asserts.push(quote_spanned!(func_ret.span() =>
<<#func_ret as #ffi::IntoFfi>::Into as #ffi::Type>::ty() == #ffi::TypeType::Primitive
));
}
FfiReturnType::Aggregate => {
asserts.push(quote_spanned!(func_ret.span() =>
<<#func_ret as #ffi::IntoFfi>::Into as #ffi::Type>::ty() == #ffi::TypeType::Aggregate
));
shim_params.insert(0, quote!(out: *mut #shim_ret));
(shim_body, shim_ret) = (quote!(::std::ptr::write(out, #shim_body)), quote!(()));
}
// shim function body
let mut shim_body = if func.is_async {
// for async functions, wrapped the returned future in lua_future
quote_spanned!(func_name.span() => #ffi::future::lua_future::new(Self::#func_name(#(#func_args),*)))
} else {
quote_spanned!(func_name.span() => <#func_ret as #ffi::IntoFfi>::convert(Self::#func_name(#(#func_args),*)))
};
build.push(quote_spanned!(func_name.span() =>
b.call::<#func_ret>(#c_name);
));
if !func.is_async {
match get_ffi_ret_type(&func_ret) {
FfiReturnType::Void => {
asserts.push(quote_spanned!(func_ret.span() =>
<<#func_ret as #ffi::IntoFfi>::Into as #ffi::Type>::ty() == #ffi::TypeType::Void
));
}
FfiReturnType::ByValue => {
asserts.push(quote_spanned!(func_ret.span() =>
<#func_ret as #ffi::IntoFfi>::convention() == #ffi::FfiReturnConvention::ByValue
));
}
FfiReturnType::ByOutParam => {
asserts.push(quote_spanned!(func_ret.span() =>
<#func_ret as #ffi::IntoFfi>::convention() == #ffi::FfiReturnConvention::ByOutParam
));
let shim_params_ty = {
let tys: Punctuated<PatType, Token![,]> = parse_quote!(#(#shim_params),*);
tys.iter().map(|pat| (*pat.ty).clone()).collect::<Vec<_>>()
};
shim_params.insert(0, quote!(out: *mut #shim_ret));
(shim_body, shim_ret) = (quote!(::std::ptr::write(out, #shim_body)), quote!(()));
}
};
}
registry.build.push(quote!(
// build.push(quote_spanned!(func_name.span() =>
// b.call_inferred(#c_name, Self::#func_name);
// ));
build.push({
let infer_args = iter::repeat_n(quote!(::std::unreachable!()), func_params.len());
let infer = if func.is_async {
quote!(|| #ffi::future::lua_future::new(Self::#func_name(#(#infer_args),*)))
} else {
quote!(|| Self::#func_name(#(#infer_args),*))
};
quote_spanned!(func_name.span() => b.call_inferred(#c_name, #infer);)
});
registry.build.push(quote_spanned!(func_name.span() =>
#(::std::assert!(#asserts);)*
b.declare::<#ffi::UnsafeExternCFn<(#(#shim_params_ty,)*), #shim_ret>>(#c_name);
));
registry.build.push(match func.attrs.metamethod {
@ -678,7 +697,7 @@ fn inject_merged_drop(registry: &mut FfiRegistry, lua: Option<&LuaFunction>) ->
if ::std::mem::needs_drop::<Self>() {
// we only have a rust drop
b.declare::<#ffi::UnsafeExternCFn<(*mut Self,), ()>>(#c_name_str);
b.metatable_raw("gc", #luaify(|self| { __C::#c_name(self); }));
b.metatable_raw("gc", ::std::format_args!("__C.{}", #c_name_str));
}
));
}

View File

@ -18,7 +18,7 @@ fn panic_cb(panic: &panic::PanicHookInfo) {
"unknown error"
};
eprintln!(
eprint!(
"{}:\n{trace}",
format_args!(
"thread '{}' panicked at {location}: {msg}",