From 98100d02fa542d7df09ade434166622b09c3c4e2 Mon Sep 17 00:00:00 2001 From: luaneko Date: Wed, 25 Jun 2025 18:42:09 +1000 Subject: [PATCH] Implement string parameter specialisation --- crates/luaffi/src/lib.rs | 45 +++++--- crates/luaffi_impl/src/metatype.rs | 169 ++++++++++++++++++++--------- crates/luaffi_impl/src/utils.rs | 73 +++++++++++-- crates/luajit/src/lib.rs | 3 +- 4 files changed, 207 insertions(+), 83 deletions(-) diff --git a/crates/luaffi/src/lib.rs b/crates/luaffi/src/lib.rs index a1518ba..baaab24 100644 --- a/crates/luaffi/src/lib.rs +++ b/crates/luaffi/src/lib.rs @@ -485,12 +485,17 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> { self } - pub fn param_str(&mut self, name: impl Display) -> &mut Self { - // fast-path for &str and &[u8]-like parameters + pub fn param_str( + &mut self, + name: impl Display, + allow_nil: bool, + check_utf8: bool, + ) -> &mut Self { + // fast-path for &[u8] and &str-like parameters // // this passes one lua `string` argument as two C `const uint8_t *ptr` and `uintptr_t len` // arguments, bypassing the slower generic `&[u8]: FromFfi` path which constructs a - // temporary cdata to pass the string and its length in one argument + // temporary cdata to pass the string and its length in one argument. let Self { lparams, cparams, @@ -499,22 +504,26 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> { .. 
} = self; - let param_ptr = <*const u8>::cdecl("ptr"); - let param_len = usize::cdecl("len"); + let param_ptr = <*const u8>::cdecl(&name); + let param_len = usize::cdecl(format!("{name}_len")); (!lparams.is_empty()).then(|| lparams.push_str(", ")); (!cparams.is_empty()).then(|| cparams.push_str(", ")); (!cargs.is_empty()).then(|| cargs.push_str(", ")); write!(lparams, "{name}").unwrap(); - write!(cparams, "{param_ptr}, {param_len}",).unwrap(); + write!(cparams, "{param_ptr}, {param_len}").unwrap(); write!(cargs, "{name}, __{name}_len").unwrap(); - write!(prelude, "local __{name}_len = 0; ").unwrap(); - write!( - prelude, - r#"if {name} ~= nil then assert(type({name}) == "string", "string expected in argument '{name}', got " .. type({name})); __{name}_len = #{name}; end; "# - ) - .unwrap(); + write!(prelude, "local __{name}_len = 0; if {name} ~= nil then ").unwrap(); + write!(prelude, r#"assert(type({name}) == "string", "string expected in argument '{name}', got " .. type({name})); "#).unwrap(); + write!(prelude, r#"__{name}_len = #{name}; "#).unwrap(); + if check_utf8 { + write!(prelude, r#"assert(__C.{IS_UTF8_FN}({name}, __{name}_len), "argument '{name}' must be a valid utf-8 string"); "#).unwrap(); + } + if !allow_nil { + write!(prelude, r#"else return error("string expected in argument '{name}', got " .. 
type({name})); "#).unwrap(); + } + write!(prelude, r#"end; "#).unwrap(); self } @@ -549,21 +558,21 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> { if T::Into::ty() == TypeType::Void { write!(lua, "__C.{func}({cargs}); {postlude}end").unwrap(); } else { - let check = T::postlude("__res"); - write!(lua, "local __res = __C.{func}({cargs}); ").unwrap(); - write!(lua, "{check}{postlude}return __res; end").unwrap(); + let check = T::postlude("__ret"); + write!(lua, "local __ret = __C.{func}({cargs}); ").unwrap(); + write!(lua, "{check}{postlude}return __ret; end").unwrap(); } writeln!(cdef, "{};", T::Into::cdecl(display!("{func}({cparams})"))).unwrap(); } FfiReturnConvention::ByOutParam => { let ct = T::Into::name(); - let check = T::postlude("__res"); - write!(lua, "local __res = __new(__ct.{ct}); __C.{func}(__res").unwrap(); + let check = T::postlude("__out"); + write!(lua, "local __out = __new(__ct.{ct}); __C.{func}(__out").unwrap(); if !cargs.is_empty() { write!(lua, ", {cargs}").unwrap(); } - write!(lua, "); {check}{postlude}return __res; end").unwrap(); + write!(lua, "); {check}{postlude}return __out; end").unwrap(); write!(cdef, "void {func}({}", <*mut T::Into>::cdecl("out")).unwrap(); if !cparams.is_empty() { write!(cdef, ", {cparams}").unwrap(); diff --git a/crates/luaffi_impl/src/metatype.rs b/crates/luaffi_impl/src/metatype.rs index 104c9f8..bf574e4 100644 --- a/crates/luaffi_impl/src/metatype.rs +++ b/crates/luaffi_impl/src/metatype.rs @@ -1,5 +1,6 @@ use crate::utils::{ - ffi_crate, is_primitivelike, is_unit, pat_ident, syn_assert, syn_error, ty_name, + StringLike, ffi_crate, is_optionlike, is_primitivelike, is_stringlike, is_unit, pat_ident, + syn_assert, syn_error, ty_name, }; use proc_macro2::TokenStream; use quote::{ToTokens, format_ident, quote, quote_spanned}; @@ -29,12 +30,27 @@ pub fn transform(mut imp: ItemImpl) -> Result { )) } +struct Registry { + ty: Ident, + shims: Vec, + build: Vec, +} + +impl Registry { + fn new(ty: Ident) -> Self { + Self { + 
ty, + shims: vec![], + build: vec![], + } + } +} + fn generate_impls(imp: &mut ItemImpl) -> Result { let ffi = ffi_crate(); let ty = imp.self_ty.clone(); let ty_name = ty_name(&ty)?; - let mut ffi_funcs = FfiRegistry::new(ty_name.clone()); - let mut lua_funcs = LuaRegistry::new(ty_name.clone()); + let mut registry = Registry::new(ty_name.clone()); let mut mms = HashSet::new(); let mut lua_drop = None; @@ -47,7 +63,7 @@ fn generate_impls(imp: &mut ItemImpl) -> Result { ); } - add_ffi_function(&mut ffi_funcs, &func)?; + add_ffi_function(&mut registry, &func)?; } for func in get_lua_functions(imp)? { @@ -59,37 +75,32 @@ fn generate_impls(imp: &mut ItemImpl) -> Result { ); } - if func.attrs.metamethod == Some(Metamethod::Gc) { + if let Some(Metamethod::Gc) = func.attrs.metamethod { lua_drop = Some(func); } else { - add_lua_function(&mut lua_funcs, &func)?; + add_lua_function(&mut registry, &func)?; } } if !mms.contains(&Metamethod::New) { - inject_fallback_new(&mut lua_funcs)?; + inject_fallback_new(&mut registry)?; } - inject_merged_drop(&mut ffi_funcs, lua_drop.as_ref())?; + inject_merged_drop(&mut registry, lua_drop.as_ref())?; - let ffi_shims = &ffi_funcs.shims; - let ffi_build = &ffi_funcs.build; - let lua_build = &lua_funcs.build; - let ffi_exports = generate_ffi_exports(&ffi_funcs)?; + let shims = &registry.shims; + let build = &registry.build; + let exports = generate_ffi_exports(&registry)?; Ok(quote_spanned!(ty.span() => - impl #ty { #(#ffi_shims)* } + impl #ty { #(#shims)* } unsafe impl #ffi::Metatype for #ty { type Target = Self; - - fn build(b: &mut #ffi::MetatypeBuilder) { - #(#ffi_build)* - #(#lua_build)* - } + fn build(b: &mut #ffi::MetatypeBuilder) { #(#build)* } } - #ffi_exports + #exports )) } @@ -201,9 +212,15 @@ fn get_ffi_functions(imp: &mut ItemImpl) -> Result> { && let Some(ref abi) = abi.name && abi.value() == "Lua-C" { + syn_assert!( + func.sig.generics.params.len() == 0, + func.sig.generics, + "cannot be generic" + ); + func.sig.abi = None; - let
params = func + let params: Vec<_> = func .sig .inputs .iter() @@ -221,6 +238,24 @@ fn get_ffi_functions(imp: &mut ItemImpl) -> Result> { ReturnType::Type(_, ref ty) => (**ty).clone(), }; + for param in params.iter() { + // double underscores are reserved for generated glue code + syn_assert!( + !pat_ident(&param.pat)?.to_string().starts_with("__"), + param.pat, + "parameter names should not start with `__`" + ); + + // lifetime should be determined by the caller (lua) + if let Type::Reference(ref ty) = *param.ty { + syn_assert!( + ty.lifetime.is_none(), + ty.lifetime, + "lifetime should be determined by the caller" + ); + } + } + let attrs = parse_ffi_function_attrs(&mut func.attrs)?; attrs.metamethod.map(|mm| document_metamethod(func, mm)); @@ -261,14 +296,26 @@ fn parse_ffi_function_attrs(attrs: &mut Vec) -> Result FfiParameterType { - FfiParameterType::Default +fn get_ffi_param_type(ty: &Type) -> FfiParameterType { + if let Some(str) = is_stringlike(ty) { + FfiParameterType::StringLike(str) + } else if let Some(arg) = is_optionlike(ty) + && let Some(str) = is_stringlike(arg) + { + FfiParameterType::OptionStringLike(str) + } else { + FfiParameterType::Default + } } +#[derive(Debug, Clone, Copy)] enum FfiReturnType { Void, ByValue, @@ -298,23 +345,7 @@ fn get_ffi_ret_type(ty: &Type) -> FfiReturnType { } } -struct FfiRegistry { - ty: Ident, - shims: Vec, - build: Vec, -} - -impl FfiRegistry { - fn new(ty: Ident) -> Self { - Self { - ty, - shims: vec![], - build: vec![], - } - } -} - -fn add_ffi_function(registry: &mut FfiRegistry, func: &FfiFunction) -> Result<()> { +fn add_ffi_function(registry: &mut Registry, func: &FfiFunction) -> Result<()> { let ffi = ffi_crate(); let ty = &registry.ty; let func_name = &func.name; @@ -366,6 +397,41 @@ fn add_ffi_function(registry: &mut FfiRegistry, func: &FfiFunction) -> Result<() b.param::<#func_param>(#name); )); } + ty @ (FfiParameterType::StringLike(str) | FfiParameterType::OptionStringLike(str)) => { + let shim_param_len =
format_ident!("arg{i}_len"); + shim_params.push(quote_spanned!(func_param.span() => + #shim_param: ::std::option::Option<&::std::primitive::u8>, + #shim_param_len: ::std::primitive::usize + )); + let allow_nil = matches!(ty, FfiParameterType::OptionStringLike(_)); + let check_utf8 = matches!(str, StringLike::Str); + let mut func_arg = quote_spanned!(func_param.span() => + #shim_param.map(|s| ::std::slice::from_raw_parts(s, #shim_param_len)) + ); + func_arg = match str { + StringLike::SliceU8 => func_arg, + StringLike::BStr => { + quote_spanned!(func_param.span() => #func_arg.map(::bstr::BStr::new)) + } + StringLike::Str => { + quote_spanned!(func_param.span() => #func_arg.map(|s| { + ::std::debug_assert!(::std::str::from_utf8(s).is_ok()); + ::std::str::from_utf8_unchecked(s) + })) + } + }; + if !allow_nil { + func_arg = quote_spanned!(func_param.span() => { + let arg = #func_arg; + ::std::debug_assert!(arg.is_some()); + arg.unwrap_unchecked() + }); + } + func_args.push(func_arg); + build.push(quote_spanned!(param.pat.span() => + b.param_str(#name, #allow_nil, #check_utf8); + )); + } } } @@ -439,7 +505,7 @@ fn add_ffi_function(registry: &mut FfiRegistry, func: &FfiFunction) -> Result<() Ok(()) } -fn generate_ffi_exports(registry: &FfiRegistry) -> Result { +fn generate_ffi_exports(registry: &Registry) -> Result { let ty = &registry.ty; let names = registry.shims.iter().map(|f| &f.sig.ident); @@ -474,6 +540,12 @@ fn get_lua_functions(imp: &mut ItemImpl) -> Result> { && let Some(ref abi) = abi.name && abi.value() == "Lua" { + syn_assert!( + func.sig.generics.params.len() == 0, + func.sig.generics, + "cannot be generic" + ); + let mut params: Vec<_> = func .sig .inputs @@ -597,18 +669,7 @@ fn parse_lua_function_attrs(attrs: &mut Vec) -> Result, -} - -impl LuaRegistry { - fn new(ty: Ident) -> Self { - Self { ty, build: vec![] } - } -} - -fn add_lua_function(registry: &mut LuaRegistry, func: &LuaFunction) -> Result<()> { +fn add_lua_function(registry: &mut Registry, func:
&LuaFunction) -> Result<()> { let ffi = ffi_crate(); let luaify = quote!(#ffi::__internal::luaify!); let func_name = &func.name; @@ -628,7 +689,7 @@ fn add_lua_function(registry: &mut LuaRegistry, func: &LuaFunction) -> Result<() Ok(()) } -fn inject_fallback_new(registry: &mut LuaRegistry) -> Result<()> { +fn inject_fallback_new(registry: &mut Registry) -> Result<()> { let ty = &registry.ty; let lua = format!( r#"function() error("type '{}' has no constructor"); end"#, @@ -642,7 +703,7 @@ fn inject_fallback_new(registry: &mut LuaRegistry) -> Result<()> { Ok(()) } -fn inject_merged_drop(registry: &mut FfiRegistry, lua: Option<&LuaFunction>) -> Result<()> { +fn inject_merged_drop(registry: &mut Registry, lua: Option<&LuaFunction>) -> Result<()> { let ffi = ffi_crate(); let luaify = quote!(#ffi::__internal::luaify!); let ty = &registry.ty; diff --git a/crates/luaffi_impl/src/utils.rs b/crates/luaffi_impl/src/utils.rs index 07c7f6d..356fac9 100644 --- a/crates/luaffi_impl/src/utils.rs +++ b/crates/luaffi_impl/src/utils.rs @@ -57,13 +57,13 @@ pub fn is_unit(ty: &Type) -> bool { pub fn is_primitivelike(ty: &Type) -> bool { match ty { - Type::Tuple(tuple) if tuple.elems.is_empty() => true, // unit type - Type::Reference(_) | Type::Ptr(_) => true, - Type::Paren(paren) => is_primitivelike(&paren.elem), + Type::Tuple(tuple) if tuple.elems.is_empty() => return true, // unit type + Type::Reference(_) | Type::Ptr(_) => return true, + Type::Paren(paren) => return is_primitivelike(&paren.elem), Type::Path(path) => { if let Some(name) = path.path.get_ident() { - matches!( - format!("{name}").as_str(), + return matches!( + name.to_string().as_str(), "bool" | "u8" | "u16" @@ -94,11 +94,66 @@ pub fn is_primitivelike(ty: &Type) -> bool { | "c_size_t" | "c_ssize_t" | "c_ptrdiff_t" - ) - } else { - false + ); } } - _ => false, + _ => {} + } + + false +} + +#[derive(Debug, Clone, Copy)] +pub enum StringLike { + SliceU8, + Str, + BStr, +} + +pub fn is_stringlike(ty: &Type) -> Option { + if let
Type::Reference(ty) = ty + && ty.mutability.is_none() + && ty.lifetime.is_none() + { + match *ty.elem { + Type::Slice(ref slice) => { + // match &[u8] + if let Type::Path(ref path) = *slice.elem + && let Some(name) = path.path.get_ident() + && name == "u8" + { + return Some(StringLike::SliceU8); + } + } + Type::Path(ref path) => { + // match &str or &BStr + if let Some(name) = path.path.get_ident() { + match name.to_string().as_str() { + "str" => return Some(StringLike::Str), + "BStr" => return Some(StringLike::BStr), + _ => {} + } + } + } + _ => {} + } + } + + None +} + +pub fn is_optionlike(ty: &Type) -> Option<&Type> { + if let Type::Path(path) = ty + && path.path.leading_colon.is_none() + && path.path.segments.len() == 1 + && let Some(segment) = path.path.segments.get(0) + && segment.ident == "Option" + && let PathArguments::AngleBracketed(ref angle) = segment.arguments + && angle.args.len() == 1 + && let Some(GenericArgument::Type(ty)) = angle.args.get(0) + { + Some(ty) + } else { + None } } diff --git a/crates/luajit/src/lib.rs b/crates/luajit/src/lib.rs index 939e813..6f07784 100644 --- a/crates/luajit/src/lib.rs +++ b/crates/luajit/src/lib.rs @@ -1333,8 +1333,7 @@ impl<'s> DerefMut for StackGuard<'s> { impl<'s> Drop for StackGuard<'s> { fn drop(&mut self) { - #[cfg(debug_assertions)] - if self.check_overpop { + if cfg!(debug_assertions) && self.check_overpop { let new_size = self.stack.size(); assert!( self.size <= new_size,