Implement string parameter specialisation

This commit is contained in:
lumi 2025-06-25 18:42:09 +10:00
parent 681dd332ab
commit 98100d02fa
Signed by: luaneko
GPG Key ID: 406809B8763FF07A
4 changed files with 207 additions and 83 deletions

View File

@ -485,12 +485,17 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> {
self
}
pub fn param_str(&mut self, name: impl Display) -> &mut Self {
// fast-path for &str and &[u8]-like parameters
pub fn param_str(
&mut self,
name: impl Display,
allow_nil: bool,
check_utf8: bool,
) -> &mut Self {
// fast-path for &[u8] and &str-like parameters
//
// this passes one lua `string` argument as two C `const uint8_t *ptr` and `uintptr_t len`
// arguments, bypassing the slower generic `&[u8]: FromFfi` path which constructs a
// temporary cdata to pass the string and its length in one argument
// temporary cdata to pass the string and its length in one argument.
let Self {
lparams,
cparams,
@ -499,22 +504,26 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> {
..
} = self;
let param_ptr = <*const u8>::cdecl("ptr");
let param_len = usize::cdecl("len");
let param_ptr = <*const u8>::cdecl(&name);
let param_len = usize::cdecl(format!("{name}_len"));
(!lparams.is_empty()).then(|| lparams.push_str(", "));
(!cparams.is_empty()).then(|| cparams.push_str(", "));
(!cargs.is_empty()).then(|| cargs.push_str(", "));
write!(lparams, "{name}").unwrap();
write!(cparams, "{param_ptr}, {param_len}",).unwrap();
write!(cparams, "{param_ptr}, {param_len}").unwrap();
write!(cargs, "{name}, __{name}_len").unwrap();
write!(prelude, "local __{name}_len = 0; ").unwrap();
write!(
prelude,
r#"if {name} ~= nil then assert(type({name}) == "string", "string expected in argument '{name}', got " .. type({name})); __{name}_len = #{name}; end; "#
)
.unwrap();
write!(prelude, "local __{name}_len = 0; if {name} ~= nil then ").unwrap();
write!(prelude, r#"assert(type({name}) == "string", "string expected in argument '{name}', got " .. type({name})); "#).unwrap();
write!(prelude, r#"__{name}_len = #{name}; "#).unwrap();
if check_utf8 {
write!(prelude, r#"assert(__C.{IS_UTF8_FN}({name}, __{name}_len), "argument '{name}' must be a valid utf-8 string"); "#).unwrap();
}
if !allow_nil {
write!(prelude, r#"else return error("string expected in argument '{name}', got " .. type({name})); "#).unwrap();
}
write!(prelude, r#"end; "#).unwrap();
self
}
@ -549,21 +558,21 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> {
if T::Into::ty() == TypeType::Void {
write!(lua, "__C.{func}({cargs}); {postlude}end").unwrap();
} else {
let check = T::postlude("__res");
write!(lua, "local __res = __C.{func}({cargs}); ").unwrap();
write!(lua, "{check}{postlude}return __res; end").unwrap();
let check = T::postlude("__ret");
write!(lua, "local __ret = __C.{func}({cargs}); ").unwrap();
write!(lua, "{check}{postlude}return __ret; end").unwrap();
}
writeln!(cdef, "{};", T::Into::cdecl(display!("{func}({cparams})"))).unwrap();
}
FfiReturnConvention::ByOutParam => {
let ct = T::Into::name();
let check = T::postlude("__res");
write!(lua, "local __res = __new(__ct.{ct}); __C.{func}(__res").unwrap();
let check = T::postlude("__out");
write!(lua, "local __out = __new(__ct.{ct}); __C.{func}(__out").unwrap();
if !cargs.is_empty() {
write!(lua, ", {cargs}").unwrap();
}
write!(lua, "); {check}{postlude}return __res; end").unwrap();
write!(lua, "); {check}{postlude}return __out; end").unwrap();
write!(cdef, "void {func}({}", <*mut T::Into>::cdecl("out")).unwrap();
if !cparams.is_empty() {
write!(cdef, ", {cparams}").unwrap();

View File

@ -1,5 +1,6 @@
use crate::utils::{
ffi_crate, is_primitivelike, is_unit, pat_ident, syn_assert, syn_error, ty_name,
StringLike, ffi_crate, is_optionlike, is_primitivelike, is_stringlike, is_unit, pat_ident,
syn_assert, syn_error, ty_name,
};
use proc_macro2::TokenStream;
use quote::{ToTokens, format_ident, quote, quote_spanned};
@ -29,12 +30,27 @@ pub fn transform(mut imp: ItemImpl) -> Result<TokenStream> {
))
}
struct Registry {
ty: Ident,
shims: Vec<ImplItemFn>,
build: Vec<TokenStream>,
}
impl Registry {
fn new(ty: Ident) -> Self {
Self {
ty,
shims: vec![],
build: vec![],
}
}
}
fn generate_impls(imp: &mut ItemImpl) -> Result<TokenStream> {
let ffi = ffi_crate();
let ty = imp.self_ty.clone();
let ty_name = ty_name(&ty)?;
let mut ffi_funcs = FfiRegistry::new(ty_name.clone());
let mut lua_funcs = LuaRegistry::new(ty_name.clone());
let mut registry = Registry::new(ty_name.clone());
let mut mms = HashSet::new();
let mut lua_drop = None;
@ -47,7 +63,7 @@ fn generate_impls(imp: &mut ItemImpl) -> Result<TokenStream> {
);
}
add_ffi_function(&mut ffi_funcs, &func)?;
add_ffi_function(&mut registry, &func)?;
}
for func in get_lua_functions(imp)? {
@ -59,37 +75,32 @@ fn generate_impls(imp: &mut ItemImpl) -> Result<TokenStream> {
);
}
if func.attrs.metamethod == Some(Metamethod::Gc) {
if let Some(Metamethod::Gc) = func.attrs.metamethod {
lua_drop = Some(func);
} else {
add_lua_function(&mut lua_funcs, &func)?;
add_lua_function(&mut registry, &func)?;
}
}
if !mms.contains(&Metamethod::New) {
inject_fallback_new(&mut lua_funcs)?;
inject_fallback_new(&mut registry)?;
}
inject_merged_drop(&mut ffi_funcs, lua_drop.as_ref())?;
inject_merged_drop(&mut registry, lua_drop.as_ref())?;
let ffi_shims = &ffi_funcs.shims;
let ffi_build = &ffi_funcs.build;
let lua_build = &lua_funcs.build;
let ffi_exports = generate_ffi_exports(&ffi_funcs)?;
let shims = &registry.shims;
let build = &registry.build;
let exports = generate_ffi_exports(&registry)?;
Ok(quote_spanned!(ty.span() =>
impl #ty { #(#ffi_shims)* }
impl #ty { #(#shims)* }
unsafe impl #ffi::Metatype for #ty {
type Target = Self;
fn build(b: &mut #ffi::MetatypeBuilder) {
#(#ffi_build)*
#(#lua_build)*
}
fn build(b: &mut #ffi::MetatypeBuilder) { #(#build)* }
}
#ffi_exports
#exports
))
}
@ -201,9 +212,15 @@ fn get_ffi_functions(imp: &mut ItemImpl) -> Result<Vec<FfiFunction>> {
&& let Some(ref abi) = abi.name
&& abi.value() == "Lua-C"
{
syn_assert!(
func.sig.generics.params.len() == 0,
func.sig.generics,
"cannot be generic"
);
func.sig.abi = None;
let params = func
let params: Vec<_> = func
.sig
.inputs
.iter()
@ -221,6 +238,24 @@ fn get_ffi_functions(imp: &mut ItemImpl) -> Result<Vec<FfiFunction>> {
ReturnType::Type(_, ref ty) => (**ty).clone(),
};
for param in params.iter() {
// double underscores are reserved for generated glue code
syn_assert!(
!pat_ident(&param.pat)?.to_string().starts_with("__"),
param.pat,
"parameter names should not start with `__`"
);
// lifetime should be determined by the caller (lua)
if let Type::Reference(ref ty) = *param.ty {
syn_assert!(
ty.lifetime.is_none(),
ty.lifetime,
"lifetime should be determined by the caller"
);
}
}
let attrs = parse_ffi_function_attrs(&mut func.attrs)?;
attrs.metamethod.map(|mm| document_metamethod(func, mm));
@ -261,14 +296,26 @@ fn parse_ffi_function_attrs(attrs: &mut Vec<Attribute>) -> Result<FfiFunctionAtt
Ok(parsed)
}
#[derive(Debug, Clone, Copy)]
enum FfiParameterType {
Default,
StringLike(StringLike),
OptionStringLike(StringLike),
}
fn get_ffi_param_type(_ty: &Type) -> FfiParameterType {
FfiParameterType::Default
fn get_ffi_param_type(ty: &Type) -> FfiParameterType {
if let Some(str) = is_stringlike(ty) {
FfiParameterType::StringLike(str)
} else if let Some(arg) = is_optionlike(ty)
&& let Some(str) = is_stringlike(arg)
{
FfiParameterType::OptionStringLike(str)
} else {
FfiParameterType::Default
}
}
#[derive(Debug, Clone, Copy)]
enum FfiReturnType {
Void,
ByValue,
@ -298,23 +345,7 @@ fn get_ffi_ret_type(ty: &Type) -> FfiReturnType {
}
}
struct FfiRegistry {
ty: Ident,
shims: Vec<ImplItemFn>,
build: Vec<TokenStream>,
}
impl FfiRegistry {
fn new(ty: Ident) -> Self {
Self {
ty,
shims: vec![],
build: vec![],
}
}
}
fn add_ffi_function(registry: &mut FfiRegistry, func: &FfiFunction) -> Result<()> {
fn add_ffi_function(registry: &mut Registry, func: &FfiFunction) -> Result<()> {
let ffi = ffi_crate();
let ty = &registry.ty;
let func_name = &func.name;
@ -366,6 +397,41 @@ fn add_ffi_function(registry: &mut FfiRegistry, func: &FfiFunction) -> Result<()
b.param::<#func_param>(#name);
));
}
ty @ (FfiParameterType::StringLike(str) | FfiParameterType::OptionStringLike(str)) => {
let shim_param_len = format_ident!("arg{i}_len");
shim_params.push(quote_spanned!(func_param.span() =>
#shim_param: ::std::option::Option<&::std::primitive::u8>,
#shim_param_len: ::std::primitive::usize
));
let allow_nil = matches!(ty, FfiParameterType::OptionStringLike(_));
let check_utf8 = matches!(str, StringLike::Str);
let mut func_arg = quote_spanned!(func_param.span() =>
#shim_param.map(|s| ::std::slice::from_raw_parts(s, #shim_param_len))
);
func_arg = match str {
StringLike::SliceU8 => func_arg,
StringLike::BStr => {
quote_spanned!(func_param.span() => #func_arg.map(::bstr::BStr::new))
}
StringLike::Str => {
quote_spanned!(func_param.span() => #func_arg.map(|s| {
::std::debug_assert!(::std::str::from_utf8(s).is_ok());
::std::str::from_utf8_unchecked(s)
}))
}
};
if !allow_nil {
func_arg = quote_spanned!(func_param.span() => {
let arg = #func_arg;
::std::debug_assert!(arg.is_some());
arg.unwrap_unchecked()
});
}
func_args.push(func_arg);
build.push(quote_spanned!(param.pat.span() =>
b.param_str(#name, #allow_nil, #check_utf8);
));
}
}
}
@ -439,7 +505,7 @@ fn add_ffi_function(registry: &mut FfiRegistry, func: &FfiFunction) -> Result<()
Ok(())
}
fn generate_ffi_exports(registry: &FfiRegistry) -> Result<TokenStream> {
fn generate_ffi_exports(registry: &Registry) -> Result<TokenStream> {
let ty = &registry.ty;
let names = registry.shims.iter().map(|f| &f.sig.ident);
@ -474,6 +540,12 @@ fn get_lua_functions(imp: &mut ItemImpl) -> Result<Vec<LuaFunction>> {
&& let Some(ref abi) = abi.name
&& abi.value() == "Lua"
{
syn_assert!(
func.sig.generics.params.len() == 0,
func.sig.generics,
"cannot be generic"
);
let mut params: Vec<_> = func
.sig
.inputs
@ -597,18 +669,7 @@ fn parse_lua_function_attrs(attrs: &mut Vec<Attribute>) -> Result<LuaFunctionAtt
Ok(parsed)
}
struct LuaRegistry {
ty: Ident,
build: Vec<TokenStream>,
}
impl LuaRegistry {
fn new(ty: Ident) -> Self {
Self { ty, build: vec![] }
}
}
fn add_lua_function(registry: &mut LuaRegistry, func: &LuaFunction) -> Result<()> {
fn add_lua_function(registry: &mut Registry, func: &LuaFunction) -> Result<()> {
let ffi = ffi_crate();
let luaify = quote!(#ffi::__internal::luaify!);
let func_name = &func.name;
@ -628,7 +689,7 @@ fn add_lua_function(registry: &mut LuaRegistry, func: &LuaFunction) -> Result<()
Ok(())
}
fn inject_fallback_new(registry: &mut LuaRegistry) -> Result<()> {
fn inject_fallback_new(registry: &mut Registry) -> Result<()> {
let ty = &registry.ty;
let lua = format!(
r#"function() error("type '{}' has no constructor"); end"#,
@ -642,7 +703,7 @@ fn inject_fallback_new(registry: &mut LuaRegistry) -> Result<()> {
Ok(())
}
fn inject_merged_drop(registry: &mut FfiRegistry, lua: Option<&LuaFunction>) -> Result<()> {
fn inject_merged_drop(registry: &mut Registry, lua: Option<&LuaFunction>) -> Result<()> {
let ffi = ffi_crate();
let luaify = quote!(#ffi::__internal::luaify!);
let ty = &registry.ty;

View File

@ -57,13 +57,13 @@ pub fn is_unit(ty: &Type) -> bool {
pub fn is_primitivelike(ty: &Type) -> bool {
match ty {
Type::Tuple(tuple) if tuple.elems.is_empty() => true, // unit type
Type::Reference(_) | Type::Ptr(_) => true,
Type::Paren(paren) => is_primitivelike(&paren.elem),
Type::Tuple(tuple) if tuple.elems.is_empty() => return true, // unit type
Type::Reference(_) | Type::Ptr(_) => return true,
Type::Paren(paren) => return is_primitivelike(&paren.elem),
Type::Path(path) => {
if let Some(name) = path.path.get_ident() {
matches!(
format!("{name}").as_str(),
return matches!(
name.to_string().as_str(),
"bool"
| "u8"
| "u16"
@ -94,11 +94,66 @@ pub fn is_primitivelike(ty: &Type) -> bool {
| "c_size_t"
| "c_ssize_t"
| "c_ptrdiff_t"
)
} else {
false
);
}
}
_ => false,
_ => {}
}
false
}
#[derive(Debug, Clone, Copy)]
pub enum StringLike {
SliceU8,
Str,
BStr,
}
pub fn is_stringlike(ty: &Type) -> Option<StringLike> {
if let Type::Reference(ty) = ty
&& ty.mutability.is_none()
&& ty.lifetime.is_none()
{
match *ty.elem {
Type::Slice(ref slice) => {
// match &[u8]
if let Type::Path(ref path) = *slice.elem
&& let Some(name) = path.path.get_ident()
&& name == "u8"
{
return Some(StringLike::SliceU8);
}
}
Type::Path(ref path) => {
// match &str or &BStr
if let Some(name) = path.path.get_ident() {
match name.to_string().as_str() {
"str" => return Some(StringLike::Str),
"BStr" => return Some(StringLike::BStr),
_ => {}
}
}
}
_ => {}
}
}
None
}
pub fn is_optionlike(ty: &Type) -> Option<&Type> {
if let Type::Path(path) = ty
&& path.path.leading_colon.is_none()
&& path.path.segments.len() == 1
&& let Some(segment) = path.path.segments.get(0)
&& segment.ident == "Option"
&& let PathArguments::AngleBracketed(ref angle) = segment.arguments
&& angle.args.len() == 1
&& let Some(GenericArgument::Type(ty)) = angle.args.get(0)
{
Some(ty)
} else {
None
}
}

View File

@ -1333,8 +1333,7 @@ impl<'s> DerefMut for StackGuard<'s> {
impl<'s> Drop for StackGuard<'s> {
fn drop(&mut self) {
#[cfg(debug_assertions)]
if self.check_overpop {
if cfg!(debug_assertions) && self.check_overpop {
let new_size = self.stack.size();
assert!(
self.size <= new_size,