diff --git a/crates/luaffi/Cargo.toml b/crates/luaffi/Cargo.toml index ddec740..1f61a4b 100644 --- a/crates/luaffi/Cargo.toml +++ b/crates/luaffi/Cargo.toml @@ -9,4 +9,3 @@ luaffi_impl = { version = "0.1.0", path = "../luaffi_impl" } luaify = { version = "0.1.0", path = "../luaify" } rustc-hash = "2.1.1" simdutf8 = "0.1.5" -static_assertions = "1.1.0" diff --git a/crates/luaffi/src/future.rs b/crates/luaffi/src/future.rs index 64c36f7..e580ba8 100644 --- a/crates/luaffi/src/future.rs +++ b/crates/luaffi/src/future.rs @@ -1,6 +1,6 @@ use crate::{ __internal::{display, type_id}, - CDef, CDefBuilder, Metatype, MetatypeBuilder, ToFfi, Type, TypeBuilder, + CDef, CDefBuilder, FfiReturnConvention, Metatype, MetatypeBuilder, ToFfi, Type, TypeBuilder, }; use luaify::luaify; use std::{ @@ -94,12 +94,9 @@ impl> lua_future { } unsafe extern "C" fn take(&mut self) -> ::To { - // `fut:__take()` returns the fulfilled value by-value because it is the lowest common - // denominator for supported return conventions (all `ToFfi` impls support return by-value; - // primitives e.g. don't support return by out-param because they get boxed in cdata). - // - // Plus, if we preallocate a cdata for out-param and the thread for some reason gets dropped - // and never resumed, GC could call the destructor on an uninitialised cdata. + // `fut:__take()` returns the fulfilled value by-value (not by out-param) because if we + // preallocate a cdata for the out-param and the thread for some reason gets dropped and + // never resumed, the GC could call the destructor on an uninitialised cdata. match self.state { State::Fulfilled(_) => match mem::replace(&mut self.state, State::Complete) { State::Fulfilled(value) => value.convert(), @@ -170,7 +167,7 @@ unsafe impl + 'static> ToFfi for lua_future { self } - fn postlude(ret: &str) -> impl Display { + fn postlude(ret: &str, _conv: FfiReturnConvention) -> impl Display { // When returning a future from Rust to Lua, yield it immediately to the runtime which will // poll it to completion in the background, then take the fulfilled value once the thread // gets resumed. Lua user code should never to worry about awaiting futures. @@ -181,7 +178,7 @@ unsafe impl + 'static> ToFfi for lua_future { // `coroutine.yield` is cached as `yield` and `ffi.gc` as `gc` in locals (see lib.rs) display!( "yield({ret}); {ret} = gc({ret}, nil):__take(); {}", - ::postlude(ret) + ::postlude(ret, FfiReturnConvention::ByValue) ) } } diff --git a/crates/luaffi/src/internal.rs b/crates/luaffi/src/internal.rs index d5c748d..b67242d 100644 --- a/crates/luaffi/src/internal.rs +++ b/crates/luaffi/src/internal.rs @@ -1,6 +1,5 @@ pub use luaify::*; use rustc_hash::FxHasher; -pub use static_assertions::*; use std::{ any::TypeId, fmt::{self, Display, Formatter}, diff --git a/crates/luaffi/src/lib.rs b/crates/luaffi/src/lib.rs index 0521407..09d5ead 100644 --- a/crates/luaffi/src/lib.rs +++ b/crates/luaffi/src/lib.rs @@ -19,8 +19,8 @@ const KEEP_FN: &str = "luaffi_keep"; const IS_UTF8_FN: &str = "luaffi_is_utf8"; // Dummy function to ensure that strings passed to Rust via wrapper objects will not be -// garbage-collected until the end of the function. -// This shall exist until LuaJIT one day implements something like `ffi.keep(obj)`. +// garbage-collected until the end of the function. This shall exist until LuaJIT one day implements +// something like `ffi.keep(obj)`. // // https://github.com/LuaJIT/LuaJIT/issues/1167 #[unsafe(export_name = "luaffi_keep")] @@ -32,7 +32,6 @@ unsafe extern "C" fn __is_utf8(ptr: *const u8, len: usize) -> bool { } const CACHE_LIBS: &[(&str, &str)] = &[ - // libs in global ("table", "table"), ("string", "string"), ("math", "math"), @@ -43,8 +42,8 @@ const CACHE_LIBS: &[(&str, &str)] = &[ // require ("bit", r#"require("bit")"#), ("ffi", r#"require("ffi")"#), - ("new", r#"require("table.new")"#), - ("clear", r#"require("table.clear")"#), + ("__tnew", r#"require("table.new")"#), + ("__tclear", r#"require("table.clear")"#), ]; // https://www.lua.org/manual/5.1/manual.html#5.1 @@ -71,44 +70,48 @@ const CACHE_GLOBALS: &[(&str, &str)] = &[ ("tostring", "tostring"), ("require", "require"), // table - ("concat", "table.concat"), - ("insert", "table.insert"), - ("maxn", "table.maxn"), - ("remove", "table.remove"), - ("sort", "table.sort"), + ("__tconcat", "table.concat"), + ("__tinsert", "table.insert"), + ("__tmaxn", "table.maxn"), + ("__tremove", "table.remove"), + ("__tsort", "table.sort"), // string - ("strlen", "string.len"), - ("format", "string.format"), - ("strsub", "string.sub"), - ("gsub", "string.gsub"), - ("gmatch", "string.gmatch"), - ("dump", "string.dump"), + ("__slen", "string.len"), + ("__sformat", "string.format"), + ("__ssub", "string.sub"), + ("__sgsub", "string.gsub"), + ("__sgmatch", "string.gmatch"), + ("__sdump", "string.dump"), // math - ("random", "math.random"), + ("__fmod", "math.fmod"), // coroutine - ("yield", "coroutine.yield"), + ("__yield", "coroutine.yield"), // debug - ("traceback", "debug.traceback"), + ("__traceback", "debug.traceback"), // ffi - ("C", "ffi.C"), - ("cdef", "ffi.cdef"), - ("typeof", "ffi.typeof"), - ("metatype", "ffi.metatype"), - ("cast", "ffi.cast"), - ("gc", "ffi.gc"), + ("__C", "ffi.C"), + ("__cdef", "ffi.cdef"), + ("__cnew", "ffi.new"), + ("__ctype", "ffi.typeof"), + ("__ctypes", "{}"), + ("__istype", "ffi.istype"), + ("__metatype", "ffi.metatype"), + ("__cast", "ffi.cast"), + ("__gc", "ffi.gc"), + ("__sizeof", "ffi.sizeof"), + ("__alignof", "ffi.alignof"), + ("__intern", "ffi.string"), // bit - ("tobit", "bit.tobit"), - ("tohex", "bit.tohex"), - ("bnot", "bit.bnot"), - ("band", "bit.band"), - ("bor", "bit.bor"), - ("bxor", "bit.bxor"), - ("lshift", "bit.lshift"), - ("rshift", "bit.rshift"), - ("arshift", "bit.arshift"), - ("rol", "bit.rol"), - ("ror", "bit.ror"), - ("bswap", "bit.bswap"), + ("__bnot", "bit.bnot"), + ("__band", "bit.band"), + ("__bor", "bit.bor"), + ("__bxor", "bit.bxor"), + ("__blshift", "bit.lshift"), + ("__brshift", "bit.rshift"), + ("__barshift", "bit.arshift"), + ("__brol", "bit.rol"), + ("__bror", "bit.ror"), + ("__bswap", "bit.bswap"), ]; fn cache_local(f: &mut Formatter, list: &[(&str, &str)]) -> fmt::Result { @@ -135,11 +138,6 @@ impl Registry { s } - pub fn preload(&mut self, _name: impl Display) -> &mut Self { - self.include::(); - self - } - pub fn include(&mut self) -> &mut Self { self.types .insert(T::name().to_string()) @@ -164,10 +162,10 @@ impl Display for Registry { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let name = env!("CARGO_PKG_NAME"); let version = env!("CARGO_PKG_VERSION"); - writeln!(f, "-- automatically generated by {name} {version}")?; + writeln!(f, "--- automatically generated by {name} {version}")?; cache_local(f, CACHE_LIBS)?; cache_local(f, CACHE_GLOBALS)?; - writeln!(f, "cdef [[{}]];", self.cdef)?; + writeln!(f, "__cdef [[{}]];", self.cdef)?; write!(f, "{}", self.lua) } } @@ -222,14 +220,6 @@ pub struct CDefBuilder<'r> { impl<'r> CDefBuilder<'r> { fn new(registry: &'r mut Registry) -> Self { - writeln!( - registry.lua, - r#"local {} = typeof("{}");"#, - T::name(), - T::cdecl("") - ) - .unwrap(); - Self { registry, cdef: format!("{} {{ ", T::cdecl("")), @@ -293,6 +283,7 @@ pub unsafe trait Metatype { pub struct MetatypeBuilder<'r> { registry: &'r mut Registry, name: String, + cdecl: String, cdef: String, lua: String, } @@ -302,8 +293,9 @@ impl<'r> MetatypeBuilder<'r> { Self { registry, name: T::Target::name().to_string(), + cdecl: T::Target::cdecl("").to_string(), cdef: String::new(), - lua: format!(r#"do local __mt, __idx = {{}}, {{}}; __mt.__index = __idx; "#), + lua: r#"do local __mt, __idx = {}, {}; __mt.__index = __idx; "#.into(), } } @@ -333,7 +325,7 @@ impl<'r> MetatypeBuilder<'r> { name: impl Display, f: impl FnOnce(&mut MetatypeMethodBuilder), ) -> &mut Self { - write!(self.lua, "__idx.{name} = ").unwrap(); + write!(self.lua, "__mt.{name} = ").unwrap(); f(&mut MetatypeMethodBuilder::new(self)); write!(self.lua, "; ").unwrap(); self @@ -350,6 +342,7 @@ impl<'r> Drop for MetatypeBuilder<'r> { let Self { registry, name, + cdecl, cdef, lua, .. @@ -357,37 +350,45 @@ impl<'r> Drop for MetatypeBuilder<'r> { registry.cdef.push_str(cdef); registry.lua.push_str(lua); - writeln!(registry.lua, "metatype({name}, __mt); end;").unwrap(); + + writeln!( + registry.lua, + r#"__ctypes.{name} = __metatype("{cdecl}", __mt); end;"# + ) + .unwrap(); } } pub unsafe trait FromFfi: Sized { type From: Type + Sized; - type FromValue: Type + Sized; + type FromArg: Type + Sized; - const ARG_KEEPALIVE: bool = false; + fn require_keepalive() -> bool { + false + } fn prelude(_arg: &str) -> impl Display { "" } fn convert(from: Self::From) -> Self; - fn convert_value(from: Self::FromValue) -> Self; + fn convert_arg(from: Self::FromArg) -> Self; } pub unsafe trait ToFfi: Sized { type To: Type + Sized; - fn postlude(_ret: &str) -> impl Display { + fn postlude(_ret: &str, _conv: FfiReturnConvention) -> impl Display { "" } fn convert(self) -> Self::To; } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)] pub enum FfiReturnConvention { Void, + #[default] ByValue, ByOutParam, } @@ -420,9 +421,9 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> { write!(self.params, "{name}").unwrap(); write!(self.args, "{name}").unwrap(); - if T::ARG_KEEPALIVE { + if T::require_keepalive() { write!(self.prelude, "local __keep_{name} = {name}; ").unwrap(); - write!(self.postlude, "C.{KEEP_FN}(__keep_{name}); ").unwrap(); + write!(self.postlude, "__C.{KEEP_FN}(__keep_{name}); ").unwrap(); } let name = name.to_string(); @@ -432,6 +433,10 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> { pub fn param_str(&mut self, name: impl Display) -> &mut Self { // fast-path for &str and &[u8]-like parameters + // + // this passes one lua `string` argument as two C `const uint8_t *ptr` and `uintptr_t len` + // arguments, bypassing the slower generic `&[u8]: FromFfi` path which constructs a + // temporary cdata to pass the string and its length in one argument (!self.params.is_empty()).then(|| self.params.push_str(", ")); (!self.args.is_empty()).then(|| self.args.push_str(", ")); write!(self.params, "{name}").unwrap(); @@ -445,7 +450,7 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> { self } - pub fn call(&mut self, func: impl Display, conv: FfiReturnConvention) { + pub fn call(&mut self, func: impl Display, ret: FfiReturnConvention) { let Self { metatype, params, @@ -459,22 +464,22 @@ impl<'r, 'm> MetatypeMethodBuilder<'r, 'm> { let lua = &mut metatype.lua; write!(lua, "function({params}) {prelude}").unwrap(); - match conv { + match ret { FfiReturnConvention::Void => { - write!(lua, "C.{func}({args}); {postlude}return nil; end").unwrap(); + write!(lua, "__C.{func}({args}); {postlude}end").unwrap(); } FfiReturnConvention::ByValue => { - let check = T::postlude("res"); + let check = T::postlude("res", ret); write!( lua, - "local res = C.{func}({args}); {check}{postlude}return res; end" + "local res = __C.{func}({args}); {check}{postlude}return res; end" ) .unwrap(); } FfiReturnConvention::ByOutParam => { let ct = T::To::name(); - let check = T::postlude("res"); - write!(lua, "local res = {ct}(); C.{func}(res").unwrap(); + let check = T::postlude("res", ret); + write!(lua, "local res = __cnew(__ctypes.{ct}); __C.{func}(res").unwrap(); if !args.is_empty() { write!(lua, ", {args}").unwrap(); } @@ -505,16 +510,25 @@ impl_primitive!(c_void, "void"); unsafe impl ToFfi for () { // - // SAFETY: Unit type return maps to a C void return, which is a nil return in lua. - // There is no equivalent to passing a unit type as an argument in C. - // `c_void` cannot be returned from rust so it should return the unit type instead. + // SAFETY: Unit type return maps to a C void return, which is a nil return in lua. There is no + // equivalent to passing a unit type as an argument in C. `c_void` cannot be returned from rust + // so it should return the unit type instead. // type To = (); fn convert(self) -> Self::To {} + + fn postlude(_ret: &str, conv: FfiReturnConvention) -> impl Display { + assert!( + conv == FfiReturnConvention::Void, + "void type cannot be instantiated" + ); + + "" + } } macro_rules! impl_copy_primitive { - ($rtype:ty, $ctype:expr, $ltype:expr) => { + ($rtype:ty, $ctype:expr, $ltype:expr $(, $unwrap:expr)?) => { impl_primitive!($rtype, $ctype); // @@ -522,7 +536,7 @@ macro_rules! impl_copy_primitive { // unsafe impl FromFfi for $rtype { type From = $rtype; - type FromValue = $rtype; + type FromArg = $rtype; fn prelude(arg: &str) -> impl Display { display!(r#"assert(type({arg}) == "{0}", "{0} expected in argument '{arg}', got " .. type({arg})); "#, $ltype) @@ -532,7 +546,7 @@ macro_rules! impl_copy_primitive { from } - fn convert_value(from: Self::FromValue) -> Self { + fn convert_arg(from: Self::FromArg) -> Self { from } } @@ -543,23 +557,38 @@ macro_rules! impl_copy_primitive { fn convert(self) -> Self::To { self } + + #[allow(unused)] + fn postlude(ret: &str, conv: FfiReturnConvention) -> impl Display { + disp(move |f| { + match conv { + FfiReturnConvention::Void => unreachable!(), + FfiReturnConvention::ByValue => {}, + // if a primitive type for some reason gets returned by out-param, unwrap + // the cdata containing the value and convert it to the equivalent lua value + FfiReturnConvention::ByOutParam => { $(write!(f, "{ret} = {}; ", $unwrap(ret))?;)? }, + } + + Ok(()) + }) + } } }; } -impl_copy_primitive!(bool, "bool", "boolean"); -impl_copy_primitive!(u8, "uint8_t", "number"); -impl_copy_primitive!(u16, "uint16_t", "number"); -impl_copy_primitive!(u32, "uint32_t", "number"); +impl_copy_primitive!(bool, "bool", "boolean", |n| display!("{n} ~= 0")); +impl_copy_primitive!(u8, "uint8_t", "number", |n| display!("tonumber({n})")); +impl_copy_primitive!(u16, "uint16_t", "number", |n| display!("tonumber({n})")); +impl_copy_primitive!(u32, "uint32_t", "number", |n| display!("tonumber({n})")); impl_copy_primitive!(u64, "uint64_t", "number"); impl_copy_primitive!(usize, "uintptr_t", "number"); -impl_copy_primitive!(i8, "int8_t", "number"); -impl_copy_primitive!(i16, "int16_t", "number"); -impl_copy_primitive!(i32, "int32_t", "number"); +impl_copy_primitive!(i8, "int8_t", "number", |n| display!("tonumber({n})")); +impl_copy_primitive!(i16, "int16_t", "number", |n| display!("tonumber({n})")); +impl_copy_primitive!(i32, "int32_t", "number", |n| display!("tonumber({n})")); impl_copy_primitive!(i64, "int64_t", "number"); impl_copy_primitive!(isize, "intptr_t", "number"); -impl_copy_primitive!(c_float, "float", "number"); -impl_copy_primitive!(c_double, "double", "number"); +impl_copy_primitive!(c_float, "float", "number", |n| display!("tonumber({n})")); +impl_copy_primitive!(c_double, "double", "number", |n| display!("tonumber({n})")); unsafe impl Type for *const T { fn name() -> impl Display { @@ -596,34 +625,34 @@ unsafe impl Type for *mut T { // unsafe impl FromFfi for *const T { type From = *const T; - type FromValue = *const T; + type FromArg = *const T; fn convert(from: Self::From) -> Self { from } - fn convert_value(from: Self::FromValue) -> Self { + fn convert_arg(from: Self::FromArg) -> Self { from } } unsafe impl FromFfi for *mut T { type From = *mut T; - type FromValue = *mut T; + type FromArg = *mut T; fn convert(from: Self::From) -> Self { from } - fn convert_value(from: Self::FromValue) -> Self { + fn convert_arg(from: Self::FromArg) -> Self { from } } // -// SAFETY: Return by value for pointers, which maps to a `cdata` return in lua containing the pointer (`T *`). -// We also map null pointers to `nil` for convenience (otherwise it's still a cdata value containing -// a null pointer) +// SAFETY: Return by value for pointers, which maps to a `cdata` return in lua containing the +// pointer (`T *`). We also map null pointers to `nil` for convenience (otherwise it's still a cdata +// value containing a null pointer) // unsafe impl ToFfi for *const T { type To = *const T; @@ -632,7 +661,7 @@ unsafe impl ToFfi for *const T { self } - fn postlude(ret: &str) -> impl Display { + fn postlude(ret: &str, _conv: FfiReturnConvention) -> impl Display { display!("if {ret} == nil then {ret} = nil; end; ") } } @@ -644,14 +673,14 @@ unsafe impl ToFfi for *mut T { self } - fn postlude(ret: &str) -> impl Display { + fn postlude(ret: &str, _conv: FfiReturnConvention) -> impl Display { display!("if {ret} == nil then {ret} = nil; end; ") } } // -// SAFETY: No `ToFfi` for references because we can't guarantee that the returned reference converted -// to a pointer will not outlive the pointee. +// SAFETY: No `ToFfi` for references because we can't guarantee that the returned reference +// converted to a pointer will not outlive the pointee. // unsafe impl Type for &T { fn name() -> impl Display { @@ -682,22 +711,18 @@ unsafe impl Type for &mut T { } // -// SAFETY: Pass by value for references, which have the same semantics as pointers (see above). -// Must ensure that the pointer is not nil before being converted to a reference. +// SAFETY: Pass by value for references, which have the same semantics as pointers (see above). Must +// ensure that the pointer is not nil before being converted to a reference. // unsafe impl FromFfi for &T { type From = *const T; - type FromValue = *const T; + type FromArg = *const T; fn prelude(arg: &str) -> impl Display { display!(r#"assert({arg} ~= nil, "argument '{arg}' cannot be nil"); "#) } fn convert(from: Self::From) -> Self { - Self::convert_value(from) - } - - fn convert_value(from: Self::FromValue) -> Self { debug_assert!( !from.is_null(), "<&T>::convert() called on a null pointer when it was checked to be non-null" @@ -705,36 +730,37 @@ unsafe impl FromFfi for &T { unsafe { &*from } } + + fn convert_arg(from: Self::FromArg) -> Self { + Self::convert(from) + } } unsafe impl FromFfi for &mut T { // // SAFETY: `FromFfi` for *mutable* references is safe because it is guaranteed that no two Rust - // code called via FFI can be running at the same time on the same OS thread (no Lua reentrancy). + // code called via FFI can be running at the same time on the same OS thread (no Lua + // reentrancy). // // i.e. The call stack will always look something like this: // - // * Runtime (LuaJIT/Rust) -> Lua (via C) -> Rust (via FFI): - // This is SAFE and the only use case we support. All references (mutable or not) to Rust - // user objects will be dropped before returning to Lua. + // * Runtime (LuaJIT/Rust) -> Lua (via C) -> Rust (via FFI): This is SAFE and the only use case + // we support. All references (mutable or not) to Rust user objects will be dropped before + // returning to Lua. // - // * Runtime (LuaJIT/Rust) -> Lua (via C) -> Rust (via FFI) -> Lua (via callback): - // This is UNSAFE because we cannot prevent the Lua callback from calling back into Rust code - // via FFI which could violate exclusive borrow semantics. This is prevented by not - // implementing `FromFfi` for function pointers (see below). + // * Runtime (LuaJIT/Rust) -> Lua (via C) -> Rust (via FFI) -> Lua (via callback): This is + // UNSAFE because we cannot prevent the Lua callback from calling back into Rust code via + // FFI which could violate exclusive borrow semantics. This is prevented by not implementing + // `FromFfi` for function pointers (see below). // type From = *mut T; - type FromValue = *mut T; + type FromArg = *mut T; fn prelude(arg: &str) -> impl Display { display!(r#"assert({arg} ~= nil, "argument '{arg}' cannot be nil"); "#) } fn convert(from: Self::From) -> Self { - Self::convert_value(from) - } - - fn convert_value(from: Self::FromValue) -> Self { debug_assert!( !from.is_null(), "<&mut T>::convert() called on a null pointer when it was checked to be non-null" @@ -742,11 +768,15 @@ unsafe impl FromFfi for &mut T { unsafe { &mut *from } } + + fn convert_arg(from: Self::FromArg) -> Self { + Self::convert(from) + } } // -// SAFETY: No `FromFfi` and `ToFfi` for arrays because passing or returning them by value is not -// a thing in C (they are just pointers). +// SAFETY: No `FromFfi` and `ToFfi` for arrays because passing or returning them by value is not a +// thing in C (they are just pointers). // // TODO: we could automatically convert them to tables and vice-versa // @@ -786,7 +816,8 @@ macro_rules! impl_function { (($($type:tt)+), fn($($arg:tt),*) -> $ret:tt) => { // - // SAFETY: No `FromFfi` for function pointers because of borrow safety invariants (see above in `&mut T`). + // SAFETY: No `FromFfi` for function pointers because of borrow safety invariants (see above + // in `&mut T`). // // We also can't implement `ToFfi` because we can't call `FromFfi` and `ToFfi` for the // function's respective argument and return values. diff --git a/crates/luaffi/src/option.rs b/crates/luaffi/src/option.rs index 93c4d9a..6802e2f 100644 --- a/crates/luaffi/src/option.rs +++ b/crates/luaffi/src/option.rs @@ -1,4 +1,4 @@ -use crate::{CDef, CDefBuilder, FromFfi, ToFfi, Type, TypeBuilder, display}; +use crate::{CDef, CDefBuilder, FfiReturnConvention, FromFfi, ToFfi, Type, TypeBuilder, display}; use std::{ffi::c_int, fmt::Display, ptr}; #[repr(C)] @@ -29,34 +29,32 @@ unsafe impl CDef for lua_option { } unsafe impl FromFfi for Option { - type From = *mut Self::FromValue; // pass by-ref - type FromValue = lua_option; + type From = lua_option; + type FromArg = *mut Self::From; // pass by-ref - const ARG_KEEPALIVE: bool = T::ARG_KEEPALIVE; + fn require_keepalive() -> bool { + T::require_keepalive() + } fn prelude(arg: &str) -> impl Display { - let ct = Self::FromValue::name(); + let ct = Self::From::name(); display!( - "if {arg} == nil then {arg} = {ct}(); else {}{arg} = {ct}(1, {arg}); end; ", + "if {arg} == nil then {arg} = __cnew(__ctypes.{ct}); else {}{arg} = __cnew(__ctypes.{ct}, 1, {arg}); end; ", T::prelude(arg) ) } fn convert(from: Self::From) -> Self { - debug_assert!( - !from.is_null(), - "Option::convert() called on a null lua_option" - ); - - Self::convert_value(unsafe { ptr::replace(from, lua_option::None) }) - } - - fn convert_value(from: Self::FromValue) -> Self { match from { - lua_option::Some(value) => Some(T::convert_value(value)), + lua_option::Some(value) => Some(T::convert(value)), lua_option::None => None, } } + + fn convert_arg(from: Self::FromArg) -> Self { + debug_assert!(!from.is_null()); + Self::convert(unsafe { ptr::replace(from, lua_option::None) }) + } } unsafe impl ToFfi for Option { @@ -69,12 +67,12 @@ unsafe impl ToFfi for Option { } } - fn postlude(ret: &str) -> impl Display { + fn postlude(ret: &str, _conv: FfiReturnConvention) -> impl Display { // if we don't have a value, return nil. otherwise copy out the inner value immediately, // forget the option cdata, then call postlude on the inner value. display!( "if {ret}.__tag == 0 then {ret} = nil; else {ret} = {ret}.__value; {}end; ", - T::postlude(ret) + T::postlude(ret, FfiReturnConvention::ByValue) ) } } diff --git a/crates/luaffi/src/string.rs b/crates/luaffi/src/string.rs index 6edfe6b..a8de81a 100644 --- a/crates/luaffi/src/string.rs +++ b/crates/luaffi/src/string.rs @@ -2,80 +2,80 @@ use crate::{__internal::disp, FromFfi, IS_UTF8_FN, Type}; use luaffi_impl::{cdef, metatype}; use std::{fmt, ptr, slice}; -#[cdef] #[derive(Debug, Clone, Copy)] +#[cdef] pub struct lua_buf { __ptr: *mut u8, __len: usize, } #[metatype] -impl lua_buf {} +impl lua_buf { + #[new] + extern "Lua-C" fn new() -> u32 { + todo!() + } +} unsafe impl FromFfi for *const [u8] { - type From = *const Self::FromValue; // pass by-ref - type FromValue = lua_buf; + type From = lua_buf; + type FromArg = *const Self::From; - const ARG_KEEPALIVE: bool = true; + fn require_keepalive() -> bool { + true + } fn prelude(arg: &str) -> impl fmt::Display { // this converts string arguments to a `lua_buf` with a pointer to the string and its length disp(move |f| { - let ct = lua_buf::name(); + let ct = Self::From::name(); write!( f, r#"if {arg} ~= nil then assert(type({arg}) == "string", "string expected in argument '{arg}', got " .. type({arg})); "# )?; - write!(f, "{arg} = {ct}({arg}, #{arg}); end; ") + write!(f, "{arg} = __cnew(__ctypes.{ct}, {arg}, #{arg}); end; ") }) } fn convert(from: Self::From) -> Self { + ptr::slice_from_raw_parts(from.__ptr, from.__len) + } + + fn convert_arg(from: Self::FromArg) -> Self { if from.is_null() { ptr::slice_from_raw_parts(ptr::null(), 0) } else { - // SAFETY: this is safe because lua_buf is copyable - unsafe { Self::convert_value(*from) } + Self::convert(unsafe { *from }) } } - - fn convert_value(from: Self::FromValue) -> Self { - ptr::slice_from_raw_parts(from.__ptr, from.__len) - } } unsafe impl FromFfi for &str { - type From = *const Self::FromValue; // pass by-ref - type FromValue = lua_buf; + type From = lua_buf; + type FromArg = *const Self::From; - const ARG_KEEPALIVE: bool = true; + fn require_keepalive() -> bool { + true + } fn prelude(arg: &str) -> impl fmt::Display { disp(move |f| { - let ct = lua_buf::name(); + let ct = Self::From::name(); write!( f, r#"assert(type({arg}) == "string", "string expected in argument '{arg}', got " .. type({arg})); "# )?; write!( f, - r#"assert(C.{IS_UTF8_FN}({arg}, #{arg}), "argument '{arg}' must be a valid utf8 string"); "# + r#"assert(__C.{IS_UTF8_FN}({arg}, #{arg}), "argument '{arg}' must be a valid utf-8 string"); "# )?; - write!(f, "{arg} = {ct}({arg}, #{arg}); ") + write!(f, "{arg} = __cnew(__ctypes.{ct}, {arg}, #{arg}); ") }) } fn convert(from: Self::From) -> Self { - debug_assert!( - !from.is_null(), - "<&str>::convert() called on a null lua_buf" - ); - - // SAFETY: this is safe because lua_buf is copyable - unsafe { Self::convert_value(*from) } - } - - fn convert_value(from: Self::FromValue) -> Self { + // SAFETY: we already checked that the string is nonnull and valid utf8 from the lua side + debug_assert!(!from.__ptr.is_null()); let s = unsafe { slice::from_raw_parts(from.__ptr, from.__len) }; debug_assert!( @@ -83,7 +83,11 @@ unsafe impl FromFfi for &str { "<&str>::convert() called on an invalid utf8 string when it was checked to be valid" ); - // SAFETY: we already checked that the string is valid utf8 from the lua side unsafe { std::str::from_utf8_unchecked(s) } } + + fn convert_arg(from: Self::FromArg) -> Self { + debug_assert!(!from.is_null()); + unsafe { Self::convert(*from) } + } } diff --git a/crates/luaffi_impl/src/cdef.rs b/crates/luaffi_impl/src/cdef.rs index 87ab870..c763f4d 100644 --- a/crates/luaffi_impl/src/cdef.rs +++ b/crates/luaffi_impl/src/cdef.rs @@ -22,7 +22,7 @@ pub fn transform(_args: Args, mut item: Item) -> Result { _ => syn_error!(item, "expected struct or enum"), }; - let mod_name = format_ident!("__cdef__{name}"); + let mod_name = format_ident!("__{name}_cdef"); Ok(quote! { #[repr(C)] @@ -42,23 +42,28 @@ pub fn transform(_args: Args, mut item: Item) -> Result { fn generate_type(ty: &Ident) -> Result { let ffi = ffi_crate(); let fmt = quote!(::std::format!); - let name_fmt = LitStr::new(&format!("{ty}"), ty.span()); - let cdecl_fmt = LitStr::new(&format!("struct {ty} {{name}}"), ty.span()); + let name = LitStr::new(&format!("{ty}"), ty.span()); + let cdecl_fmt = LitStr::new(&format!("struct {ty} {{}}"), ty.span()); Ok(quote! { unsafe impl #ffi::Type for #ty { - fn name() -> ::std::string::String { - #fmt(#name_fmt) + fn name() -> impl ::std::fmt::Display { + #name } - fn cdecl(name: impl ::std::fmt::Display) -> ::std::string::String { - #fmt(#cdecl_fmt) + fn cdecl(name: impl ::std::fmt::Display) -> impl ::std::fmt::Display { + #fmt(#cdecl_fmt, name) } fn build(b: &mut #ffi::TypeBuilder) { b.cdef::().metatype::(); } } + + unsafe impl #ffi::ToFfi for #ty { + type To = Self; + fn convert(self) -> Self::To { self } + } }) } @@ -113,6 +118,11 @@ struct CField { attrs: CFieldAttrs, } +#[derive(Default)] +struct CFieldAttrs { + opaque: bool, +} + fn to_cfields(fields: &mut Fields) -> Result> { match fields { Fields::Named(fields) => fields.named.iter_mut(), @@ -133,11 +143,6 @@ fn to_cfields(fields: &mut Fields) -> Result> { .collect() } -#[derive(Default)] -struct CFieldAttrs { - opaque: bool, -} - fn parse_attrs(attrs: &mut Vec) -> Result { let mut parsed = CFieldAttrs::default(); let mut i = 0; diff --git a/crates/luaffi_impl/src/metatype.rs b/crates/luaffi_impl/src/metatype.rs index cbc29dc..bfbfb8d 100644 --- a/crates/luaffi_impl/src/metatype.rs +++ b/crates/luaffi_impl/src/metatype.rs @@ -11,7 +11,7 @@ pub fn transform(mut imp: ItemImpl) -> Result { ); let impls = generate_impls(&mut imp)?; - let mod_name = format_ident!("__metatype__{}", ty_name(&imp.self_ty)?); + let mod_name = format_ident!("__{}_metatype", ty_name(&imp.self_ty)?); Ok(quote! { #imp @@ -31,25 +31,30 @@ fn generate_impls(imp: &mut ItemImpl) -> Result { let ffi = ffi_crate(); let ffi_funcs = get_ffi_functions(imp)?; + + // wrapper extern "C" functions that call the actual implementation let ffi_wrappers: Vec<_> = ffi_funcs .iter() .map(generate_ffi_wrapper) .collect::>()?; + // ffi function registration code let ffi_register: Vec<_> = ffi_funcs .iter() .map(generate_ffi_register) .collect::>()?; - let ffi_drop_fn = format_ident!("__ffi_drop"); - let ffi_drop_name = format!("{ty_name}_drop"); + let ffi_drop_rname = format_ident!("__ffi_drop"); + let ffi_drop_cname = format!("{ty_name}_drop"); + // ffi function symbol export code let ffi_exports = { - let mut names = vec![&ffi_drop_fn]; + let mut names = vec![&ffi_drop_rname]; names.extend(ffi_funcs.iter().map(|f| &f.rust_name)); generate_ffi_exports(&ty, names.into_iter())? }; + // lua function registration code let lua_funcs = get_lua_functions(imp)?; let lua_register: Vec<_> = lua_funcs .iter() @@ -57,6 +62,15 @@ fn generate_impls(imp: &mut ItemImpl) -> Result { .collect::>()?; Ok(quote! { + impl #ty { + #(#ffi_wrappers)* + + #[unsafe(export_name = #ffi_drop_cname)] + unsafe extern "C" fn #ffi_drop_rname(ptr: *mut Self) { + unsafe { ::std::ptr::drop_in_place(ptr) } + } + } + unsafe impl #ffi::Metatype for #ty { type Target = Self; @@ -64,17 +78,8 @@ fn generate_impls(imp: &mut ItemImpl) -> Result { #(#ffi_register)* #(#lua_register)* - b.declare::(#ffi_drop_name); - b.metatable_raw("gc", ::std::format_args!("C.{}", #ffi_drop_name)); - } - } - - impl #ty { - #(#ffi_wrappers)* - - #[unsafe(export_name = #ffi_drop_name)] - unsafe extern "C" fn #ffi_drop_fn(&mut self) { - unsafe { ::std::ptr::drop_in_place(self) } + b.declare::(#ffi_drop_cname); + b.metatable_raw("gc", ::std::format_args!("__C.{}", #ffi_drop_cname)); } } @@ -89,7 +94,13 @@ struct FfiFunction { c_name: String, params: Vec, ret: Type, - ret_out: bool, + ret_by_out: bool, + attrs: FfiFunctionAttrs, +} + +#[derive(Default)] +struct FfiFunctionAttrs { + metatable: Option, } fn get_ffi_functions(imp: &mut ItemImpl) -> Result> { @@ -103,6 +114,7 @@ fn get_ffi_functions(imp: &mut ItemImpl) -> Result> { { func.sig.abi = None; + // normalise inputs to PatType let params = func .sig .inputs @@ -118,13 +130,14 @@ fn get_ffi_functions(imp: &mut ItemImpl) -> Result> { }) .collect::>()?; + // normalise output to Type let ret = match func.sig.output { ReturnType::Default => parse_quote!(()), ReturnType::Type(_, ref ty) => (**ty).clone(), }; // whether to use out-param for return values - let ret_out = !is_primitive(&ret); + let ret_by_out = !is_primitive(&ret); funcs.push(FfiFunction { name: func.sig.ident.clone(), @@ -133,7 +146,8 @@ fn get_ffi_functions(imp: &mut ItemImpl) -> Result> { c_name: format!("{}_{}", ty_name(&imp.self_ty)?, func.sig.ident), params, ret, - ret_out, + ret_by_out, + attrs: parse_ffi_function_attrs(&mut func.attrs)?, }); } } @@ -141,6 +155,27 @@ fn get_ffi_functions(imp: &mut ItemImpl) -> Result> { Ok(funcs) } +fn parse_ffi_function_attrs(attrs: &mut Vec) -> Result { + let mut parsed = FfiFunctionAttrs::default(); + let mut i = 0; + while let Some(attr) = attrs.get(i) { + if let Some(name) = attr.path().get_ident() { + if name == "metatable" { + parsed.metatable = attr.parse_args()?; + attrs.remove(i); + continue; + } else if name == "new" { + parsed.metatable = parse_quote!("new"); + attrs.remove(i); + continue; + } + } + i += 1; + } + + Ok(parsed) +} + #[derive(Debug)] enum FfiArgType { Default, @@ -150,14 +185,6 @@ fn get_ffi_arg_type(_ty: &Type) -> FfiArgType { FfiArgType::Default } -fn escape_self(name: &Ident) -> Ident { - if name == "self" { - format_ident!("__self") - } else { - name.clone() - } -} - fn generate_ffi_wrapper(func: &FfiFunction) -> Result { let ffi = ffi_crate(); let name = &func.name; @@ -166,36 +193,35 @@ fn generate_ffi_wrapper(func: &FfiFunction) -> Result { let mut params = vec![]; let mut args = vec![]; - for param in func.params.iter() { - let name = escape_self(pat_ident(¶m.pat)?); + for (i, param) in func.params.iter().enumerate() { + let name = format_ident!("__arg{i}"); let ty = ¶m.ty; match get_ffi_arg_type(ty) { FfiArgType::Default => { - params.push(quote! { #name: <#ty as #ffi::FromFfi>::From }); - args.push(quote! { <#ty as #ffi::FromFfi>::convert(#name) }); + params.push(quote! { #name: <#ty as #ffi::FromFfi>::FromArg }); + args.push(quote! { <#ty as #ffi::FromFfi>::convert_arg(#name) }); } } } - // make return by out-param the first parameter - let (ret, do_ret) = if func.ret_out { + let (ret, call) = if func.ret_by_out { + // make return by out-param the first parameter let ret = &func.ret; - params.insert(0, quote! { __ret_out: *mut #ret }); + params.insert(0, quote! { __out: *mut #ret }); ( - quote! { () }, - quote! { unsafe { ::std::ptr::write(__ret_out, __ret) }; }, + quote!(()), + quote! { ::std::ptr::write(__out, Self::#name(#(#args),*)) }, ) } else { let ret = &func.ret; - (quote! { #ret }, quote! { return __ret; }) + (quote! { #ret }, quote! { Self::#name(#(#args),*) }) }; Ok(quote! { #[unsafe(export_name = #c_name)] unsafe extern "C" fn #rust_name(#(#params),*) -> #ret { - let __ret = Self::#name(#(#args),*); - #do_ret + unsafe { #call } } }) } @@ -204,35 +230,53 @@ fn generate_ffi_register(func: &FfiFunction) -> Result { let ffi = ffi_crate(); let lua_name = &func.lua_name; let c_name = &func.c_name; + let mut params = vec![]; - let mut asserts = vec![]; + let mut register = vec![]; for param in func.params.iter() { let name = format!("{}", pat_ident(¶m.pat)?); let ty = ¶m.ty; - params.push(match get_ffi_arg_type(ty) { - FfiArgType::Default => quote! { b.param::<#ty>(#name); }, - }); + match get_ffi_arg_type(ty) { + FfiArgType::Default => { + params.push(quote! { <#ty as #ffi::FromFfi>::FromArg }); + register.push(quote! { b.param::<#ty>(#name); }) + } + }; } let ret = &func.ret; let ret_conv = if is_unit(ret) { quote! { #ffi::FfiReturnConvention::Void } - } else if func.ret_out { - asserts.push(quote! { #ffi::__internal::assert_type_ne_all!(#ret, ()); }); - quote! { #ffi::FfiReturnConvention::OutParam } + } else if func.ret_by_out { + quote! { #ffi::FfiReturnConvention::ByOutParam } } else { - asserts.push(quote! { #ffi::__internal::assert_type_ne_all!(#ret, ()); }); quote! { #ffi::FfiReturnConvention::ByValue } }; + let declare = quote! { + b.declare:: #ret>(#c_name); + }; + + let register = match func.attrs.metatable { + Some(ref mt) => quote! { + b.metatable(#mt, |b| { + #(#register)* + b.call::<#ret>(#c_name, #ret_conv); + }); + }, + None => quote! { + b.index(#lua_name, |b| { + #(#register)* + b.call::<#ret>(#c_name, #ret_conv); + }); + }, + }; + Ok(quote! { - b.index(#lua_name, |b| { - #(#asserts)* - #(#params)* - b.call::<#ret>(#c_name, #ret_conv); - }); + #declare + #register }) } @@ -241,7 +285,8 @@ fn generate_ffi_exports<'a>( names: impl Iterator, ) -> Result { Ok(quote! { - // hack to prevent ffi functions from being dead code-eliminated + // this ensures ffi function symbol exports are actually present in the resulting binary, + // otherwise they may get dead code-eliminated before it reaches the linker #[used] static __FFI_EXPORTS: &[fn()] = unsafe { &[#(::std::mem::transmute(#ty::#names as *const ())),*] @@ -265,6 +310,7 @@ fn get_lua_functions(imp: &mut ItemImpl) -> Result> { && let Some(ref abi) = abi.name && abi.value() == "Lua" { + // normalise inputs to PatType let params = func .sig .inputs @@ -281,6 +327,13 @@ fn get_lua_functions(imp: &mut ItemImpl) -> Result> { }) .collect::>()?; + // shouldn't specify an output type + syn_assert!( + matches!(func.sig.output, ReturnType::Default), + func.sig.output, + "cannot have return type" + ); + funcs.push(LuaFunction { name: format!("{}", func.sig.ident), body: func.block.clone(), diff --git a/crates/luaffi_impl/src/utils.rs b/crates/luaffi_impl/src/utils.rs index 13254ac..377dacd 100644 --- a/crates/luaffi_impl/src/utils.rs +++ b/crates/luaffi_impl/src/utils.rs @@ -1,5 +1,5 @@ use std::env; -use syn::*; +use syn::{spanned::Spanned, *}; macro_rules! syn_error { ($src:expr, $($fmt:expr),+) => {{ @@ -35,11 +35,15 @@ pub fn ty_name(ty: &Type) -> Result<&Ident> { } } -pub fn pat_ident(pat: &Pat) -> Result<&Ident> { - match pat { - Pat::Ident(ident) => Ok(&ident.ident), +pub fn pat_ident(pat: &Pat) -> Result { + Ok(match pat { + Pat::Ident(ident) => match ident.subpat { + Some((_, ref subpat)) => syn_error!(subpat, "unexpected subpattern"), + None => ident.ident.clone(), + }, + Pat::Wild(wild) => Ident::new("_", wild.span()), _ => syn_error!(pat, "expected ident"), - } + }) } pub fn is_unit(ty: &Type) -> bool { diff --git a/crates/luaify/src/generate.rs b/crates/luaify/src/generate.rs index d8a067d..0e8bb41 100644 --- a/crates/luaify/src/generate.rs +++ b/crates/luaify/src/generate.rs @@ -163,9 +163,9 @@ fn generate_expr_binary(f: &mut Formatter, bin: &ExprBinary, cx: Context) -> Res | BinOp::Shl(_) | BinOp::Shr(_) | BinOp::Eq(_) + | BinOp::Ne(_) | BinOp::Lt(_) | BinOp::Le(_) - | BinOp::Ne(_) | BinOp::Ge(_) | BinOp::Gt(_) | BinOp::And(_) @@ -233,22 +233,22 @@ fn generate_expr_binary(f: &mut Formatter, bin: &ExprBinary, cx: Context) -> Res BinOp::MulAssign(_) => assign_bin_op!("*"), BinOp::Div(_) => bin_op!("/"), BinOp::DivAssign(_) => assign_bin_op!("/"), - BinOp::Rem(_) => call_op!("math.fmod"), - BinOp::RemAssign(_) => assign_call_op!("math.fmod"), - BinOp::BitAnd(_) => call_op!("band"), - BinOp::BitAndAssign(_) => assign_call_op!("band"), - BinOp::BitOr(_) => call_op!("bor"), - BinOp::BitOrAssign(_) => assign_call_op!("bor"), - BinOp::BitXor(_) => call_op!("bxor"), - BinOp::BitXorAssign(_) => assign_call_op!("bxor"), - BinOp::Shl(_) => call_op!("lshift"), - BinOp::ShlAssign(_) => assign_call_op!("lshift"), - BinOp::Shr(_) => call_op!("arshift"), - BinOp::ShrAssign(_) => assign_call_op!("arshift"), + BinOp::Rem(_) => call_op!("__fmod"), + BinOp::RemAssign(_) => assign_call_op!("__fmod"), + BinOp::BitAnd(_) => call_op!("__band"), + BinOp::BitAndAssign(_) => assign_call_op!("__band"), + BinOp::BitOr(_) => call_op!("__bor"), + BinOp::BitOrAssign(_) => assign_call_op!("__bor"), + BinOp::BitXor(_) => call_op!("__bxor"), + BinOp::BitXorAssign(_) => assign_call_op!("__bxor"), + BinOp::Shl(_) => call_op!("__blshift"), + BinOp::ShlAssign(_) => assign_call_op!("__blshift"), + BinOp::Shr(_) => call_op!("__barshift"), + BinOp::ShrAssign(_) => assign_call_op!("__barshift"), BinOp::Eq(_) => bin_op!("=="), + BinOp::Ne(_) => bin_op!("~="), BinOp::Lt(_) => bin_op!("<"), BinOp::Le(_) => bin_op!("<="), - BinOp::Ne(_) => bin_op!("~="), BinOp::Ge(_) => bin_op!(">="), BinOp::Gt(_) => bin_op!(">"), BinOp::And(_) => bin_op!("and"), diff --git a/crates/luaify/src/transform.rs b/crates/luaify/src/transform.rs index 43fefcd..6873aa6 100644 --- a/crates/luaify/src/transform.rs +++ b/crates/luaify/src/transform.rs @@ -1,4 +1,4 @@ -use crate::utils::{LuaType, syn_error, unwrap_expr_ident, unwrap_pat_ident, wrap_expr_block}; +use crate::utils::{LuaType, expr_ident, pat_ident, syn_error, wrap_expr_block}; use quote::format_ident; use std::mem; use syn::{spanned::*, visit_mut::*, *}; @@ -73,7 +73,7 @@ impl Visitor { match input { Pat::Ident(_) => {} Pat::Type(typed) => { - let ident = unwrap_pat_ident(&typed.pat)?; + let ident = pat_ident(&typed.pat)?; let ty = mem::replace(&mut typed.ty, parse_quote!(_)); match (&*ty).try_into()? { LuaType::Any => {} @@ -112,7 +112,7 @@ impl Visitor { Some((Ident::new("self", recv.self_token.span()), ty)) } FnArg::Typed(typed) => { - let ident = unwrap_pat_ident(&typed.pat)?; + let ident = pat_ident(&typed.pat)?; let ty = mem::replace(&mut typed.ty, parse_quote!(_)); Some((ident, ty)) } @@ -149,9 +149,9 @@ impl Visitor { let mut prelude: Option = None; let ty: LuaType = (&*cast.ty).try_into()?; let ty_str = format!("{ty}"); - let (ident, msg) = match unwrap_expr_ident(&arg).ok() { - Some(ident) => (ident.clone(), format!("{ty} expected in '{ident}', got ")), - None => { + let (ident, msg) = match expr_ident(&arg) { + Ok(ident) => (ident.clone(), format!("{ty} expected in '{ident}', got ")), + Err(_) => { let ident = Ident::new("_", arg.span()); prelude = Some(parse_quote! { let #ident = #arg; }); (ident, format!("{ty} expected, got ")) diff --git a/crates/luaify/src/utils.rs b/crates/luaify/src/utils.rs index bd0a422..d452d4d 100644 --- a/crates/luaify/src/utils.rs +++ b/crates/luaify/src/utils.rs @@ -25,14 +25,14 @@ pub fn wrap_expr_block(expr: &Expr) -> Block { } } -pub fn unwrap_expr_ident(expr: &Expr) -> Result<&Ident> { +pub fn expr_ident(expr: &Expr) -> Result<&Ident> { match expr { Expr::Path(path) => path.path.require_ident(), _ => syn_error!(expr, "expected ident"), } } -pub fn unwrap_pat_ident(pat: &Pat) -> Result { +pub fn pat_ident(pat: &Pat) -> Result { Ok(match pat { Pat::Ident(ident) => match ident.subpat { Some((_, ref subpat)) => syn_error!(subpat, "unexpected subpattern"), diff --git a/crates/luajit/src/lib.rs b/crates/luajit/src/lib.rs index ea48090..ab45213 100644 --- a/crates/luajit/src/lib.rs +++ b/crates/luajit/src/lib.rs @@ -82,6 +82,9 @@ impl Error { } } +/// Lua result. +pub type Result = ::std::result::Result; + /// Lua type. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] pub enum Type { @@ -306,28 +309,37 @@ impl Default for DumpMode { pub struct Chunk { name: BString, content: BString, + mode: LoadMode, } impl Chunk { - /// Creates a named [`Chunk`] with the given content. - pub fn named(name: impl Into, content: impl Into) -> Self { - Self { - name: name.into(), - content: content.into(), - } - } - - /// Creates an unnamed [`Chunk`] with the given content. - pub fn unnamed(content: impl Into) -> Self { + /// Creates a new [`Chunk`] with the given content. + pub fn new(content: impl Into) -> Self { Self { name: "?".into(), content: content.into(), + mode: LoadMode::AUTO, } } - /// Name of this chunk. - pub fn name(&self) -> &BStr { - self.name.as_ref() + /// Sets the name of this chunk as `name`. + pub fn name(&mut self, name: impl AsRef<[u8]>) -> &mut Self { + self.name = name.as_ref().into(); + self + } + + /// Sets the name of this chunk as the path `path`. + pub fn path(&mut self, path: impl AsRef<[u8]>) -> &mut Self { + let mut name = BString::from(b"@"); + name.extend_from_slice(path.as_ref()); + self.name = name; + self + } + + /// Sets the mode flag for loading this chunk. + pub fn mode(&mut self, mode: LoadMode) -> &mut Self { + self.mode = mode; + self } } @@ -345,13 +357,19 @@ impl DerefMut for Chunk { } } +impl> From for Chunk { + fn from(value: T) -> Self { + Self::new(value) + } +} + #[derive(Debug)] struct GlobalState { ptr: NonNull, } impl GlobalState { - pub fn new() -> Result { + pub fn new() -> Result { unsafe { // SAFETY: lua_newstate may return a null pointer if allocation fails let ptr = NonNull::new(lua_newstate(Some(Self::alloc_cb), ptr::null_mut())) @@ -455,7 +473,7 @@ impl Drop for Ref { /// A state instance can be manipulated using the [`Stack`] object that it mutably dereferences to. #[derive(Debug)] pub struct State { - thread_ref: Ref, + thread: Ref, stack: Stack, } @@ -465,11 +483,11 @@ impl State { /// All built-in libraries are opened by default. /// /// This may return an error if allocation or library initialisation fails. - pub fn new() -> Result { + pub fn new() -> Result { let state = Rc::new(GlobalState::new()?); let mut state = Self { stack: unsafe { Stack::new_unchecked(state.as_ptr()) }, - thread_ref: Ref { + thread: Ref { state, key: LUA_NOREF, }, @@ -492,12 +510,25 @@ impl State { Self { // SAFETY: lua_newthread never returns null, but may panic on oom stack: unsafe { Stack::new_unchecked(lua_newthread(self.as_ptr())) }, - thread_ref: Ref { - state: Rc::clone(&self.thread_ref.state), + thread: Ref { + state: Rc::clone(&self.thread.state), key: unsafe { luaL_ref(self.as_ptr(), LUA_REGISTRYINDEX) }, }, } } + + /// Creates a new [`Ref`] with the given key. + /// + /// # Safety + /// + /// The caller must ensure that the given ref key is unique and not already used by any other + /// instances of [`Ref`]. + pub unsafe fn new_ref_unchecked(&self, key: c_int) -> Ref { + Ref { + state: Rc::clone(&self.thread.state), + key, + } + } } impl Deref for State { @@ -616,6 +647,23 @@ impl Stack { unsafe { lua_pop(self.as_ptr(), n) } } + /// Pops the value at the top of the stack and inserts it at index `idx` by shifting up existing + /// values. + /// + /// Index `idx` cannot be a pseudo-index. + /// + /// Equivalent to [`lua_insert`]. + /// + /// # Panic + /// + /// Panics if the stack is empty or the index `idx` is invalid. + pub fn pop_insert(&mut self, idx: c_int) { + assert!(self.size() >= 1, "cannot pop 1: {self:?}"); + let idx = self.slot(idx).index(); + assert!(idx > 0, "cannot insert into pseudo-index {idx}"); + unsafe { lua_insert(self.as_ptr(), idx) } + } + /// Pops the value at the top of the stack and replaces the value at index `idx` with it. /// /// If the index `idx` points to the top of the stack, this still pops the value and is @@ -800,11 +848,11 @@ impl Stack { /// Pushes the given chunk as a function at the top of the stack. /// /// Equivalent to [`lua_loadx`]. - pub fn load(&mut self, chunk: &Chunk, mode: LoadMode) -> Result<(), Error> { + pub fn load(&mut self, chunk: &Chunk) -> Result<()> { type State<'s> = Option<&'s [u8]>; let mut state: State = Some(chunk.content.as_ref()); let name = CString::new(chunk.name.to_vec()).map_err(Error::BadChunkName)?; - let mode = mode.to_mode_str(); + let mode = chunk.mode.to_mode_str(); unsafe extern "C" fn reader_cb( _L: *mut lua_State, @@ -860,7 +908,7 @@ impl Stack { /// # Panic /// /// Panics if the value at index `idx` is not a function. - pub fn dump(&self, idx: c_int, mode: DumpMode) -> Result { + pub fn dump(&self, idx: c_int, mode: DumpMode) -> Result { let func = self.slot(idx); assert!( func.type_of() == Type::Function, @@ -881,6 +929,24 @@ impl Stack { } } + /// Evaluates the given chunk on the stack synchronously with `narg` values at the top of the + /// stack as arguments. + /// + /// Equivalent to calling [`load`](Self::load) on the chunk and then [`call`](Self::call) on the + /// loaded function. + /// + /// # Panic + /// + /// Panics if there are not enough values on the stack or thread status is invalid. + pub fn eval(&mut self, chunk: &Chunk, narg: c_int, nret: c_int) -> Result { + assert!(0 <= narg && (0 <= nret || nret == LUA_MULTRET)); + let base = self.size() - narg; + assert!(base >= 0, "expected {narg} values: {self:?}"); + self.load(chunk)?; + self.pop_insert(base + 1); + self.call(narg, nret) + } + /// Calls a function on the stack synchronously with `narg` values at the top of the stack as /// arguments. /// @@ -888,8 +954,8 @@ impl Stack { /// the index `top - narg` (i.e. the function is pushed first and then `narg` values as /// arguments). All arguments and the function are popped from the stack and then any return /// values are pushed. If `nret` is not [`LUA_MULTRET`], then the number of return values pushed - /// will be exactly `nret`, filling with nils if necessary. Finally, the number of return values - /// pushed to the stack is returned. + /// will be exactly `nret`, filling with nils if necessary. Finally, the number of returned + /// values pushed to the stack is returned. /// /// The current thread status must not be suspended or dead. /// @@ -899,7 +965,7 @@ impl Stack { /// /// Panics if there are not enough values on the stack, the function to call is not on the /// stack, or thread status is invalid. - pub fn call(&mut self, narg: c_int, nret: c_int) -> Result { + pub fn call(&mut self, narg: c_int, nret: c_int) -> Result { assert!(0 <= narg && (0 <= nret || nret == LUA_MULTRET)); let top = self.size(); @@ -945,8 +1011,8 @@ impl Stack { /// the index `top - narg` (i.e. the function is pushed first and then `narg` values as /// arguments). All arguments and the function are popped from the stack and then any return /// values are pushed. If `nret` is not [`LUA_MULTRET`], then the number of return values pushed - /// will be exactly `nret`, filling with nils if necessary. Finally, the number of return values - /// pushed to the stack is returned. + /// will be exactly `nret`, filling with nils if necessary. Finally, the number of returned + /// values pushed to the stack is returned. /// /// If the thread yields a Rust [`Future`] value, then it will be polled to completion before /// the thread is resumed with the output of the [`Future`] as the argument. If the thread @@ -961,7 +1027,7 @@ impl Stack { /// /// Panics if there are not enough values on the stack, the function to call is not on the /// stack, or thread status is invalid. - pub async fn call_async(&mut self, mut narg: c_int, nret: c_int) -> Result { + pub async fn call_async(&mut self, mut narg: c_int, nret: c_int) -> Result { assert!(0 <= narg && (0 <= nret || nret == LUA_MULTRET)); let top = self.size(); @@ -1015,7 +1081,7 @@ impl Stack { /// pushed first and then `narg` values as arguments). If the current thread status is /// suspended, then there must be `narg` values at the top of the stack. All arguments and the /// function are popped from the stack and then any yielded values are pushed. Finally, the new - /// status of the thread is returned. + /// status of the thread indicating whether the thread had completed or suspended is returned. /// /// The current thread status must not be dead. /// @@ -1025,7 +1091,7 @@ impl Stack { /// /// Panics if there are not enough values on the stack, the function to call is not on the /// stack, or thread status is invalid. - pub fn resume(&mut self, narg: c_int) -> Result { + pub fn resume(&mut self, narg: c_int) -> Result { assert!(0 <= narg); let status = self.status(); let need = match status { @@ -1237,7 +1303,7 @@ impl<'s> Slot<'s> { } /// Parses the value in this slot as a `T`. - pub fn parse>(&self) -> Result { + pub fn parse>(&self) -> Result { T::parse(self) } @@ -1264,6 +1330,9 @@ impl<'s> Slot<'s> { /// Parses the value in this slot as a [`lua_Number`]. /// + /// If the value is not a `number` or a `string` that can be parsed as a number, then this + /// returns [`None`]. + /// /// Equivalent to [`lua_tonumberx`]. pub fn number(&self) -> Option { self.parse().ok() @@ -1271,6 +1340,9 @@ impl<'s> Slot<'s> { /// Parses the value in this slot as a [`lua_Integer`]. /// + /// If the value is not a `number` or a `string` that can be parsed as an integer, then this + /// returns [`None`]. + /// /// Equivalent to [`lua_tointegerx`]. pub fn integer(&self) -> Option { self.parse().ok() @@ -1278,6 +1350,11 @@ impl<'s> Slot<'s> { /// Parses the value in this slot as a binary string. /// + /// If the value is a `number`, then it is converted in-place into a `string` representation of + /// the value first. + /// + /// If the value is not a `string`, then this returns [`None`]. + /// /// Equivalent to [`lua_tolstring`]. pub fn string(&self) -> Option<&'s BStr> { self.parse().ok() @@ -1285,12 +1362,19 @@ impl<'s> Slot<'s> { /// Parses the value in this slot as a UTF-8 string. /// + /// If the value is a `number`, then it is converted in-place into a `string` representation + /// first. + /// /// Equivalent to [`lua_tolstring`]. pub fn string_utf8(&self) -> Option<&'s str> { self.parse().ok() } /// Parses the value in this slot as a [`lua_CFunction`]. + /// + /// If the value is not a C function, then this returns [`None`]. + /// + /// Equivalent to [`lua_tocfunction`]. pub fn function_raw(&self) -> lua_CFunction { unsafe { lua_tocfunction(self.stack.as_ptr(), self.idx) } } @@ -1299,6 +1383,8 @@ impl<'s> Slot<'s> { /// /// If the value is a `cdata`, then the returned pointer is the address of the base of the cdata /// payload. Otherwise this returns a null pointer. + /// + /// Equivalent to [`lua_topointer`]. pub fn cdata(&self) -> *const T { (self.type_of() == Type::Cdata) .then(|| self.pointer().cast()) @@ -1306,9 +1392,28 @@ impl<'s> Slot<'s> { } /// Parses the value in this slot as a generic pointer. + /// + /// If the value is not a GC-managed object that can be represented by a pointer, then this + /// returns a null pointer. + /// + /// Equivalent to [`lua_topointer`]. pub fn pointer(&self) -> *const c_void { unsafe { lua_topointer(self.stack.as_ptr(), self.idx).cast() } } + + /// Returns the length of the value in this slot. + /// + /// For strings, this is the byte-length of the contents of the string. For tables, this is the + /// length of the table defined by the Lua `#` operator. For userdata, this is the size of its + /// payload in bytes. For numbers, this is equivalent to converting the value in-place into a + /// `string` representation before calculating its length. Otherwise, this returns 0. + /// + /// This function does not invoke the `__len` metamethod. + /// + /// Equivalent to [`lua_objlen`]. + pub fn length(&self) -> usize { + unsafe { lua_objlen(self.stack.as_ptr(), self.idx) } + } } /// Pushes a value onto a [`Stack`]. @@ -1509,11 +1614,11 @@ impl Push for CurrentThread { /// [`Slot::parse`]. pub trait Parse<'s>: Sized { /// Parses the value in the given slot. - fn parse(slot: &Slot<'s>) -> Result; + fn parse(slot: &Slot<'s>) -> Result; } impl Parse<'_> for () { - fn parse(slot: &Slot) -> Result { + fn parse(slot: &Slot) -> Result { match slot.type_of() { Type::Nil => Ok(()), ty => Err(Error::InvalidType("nil", ty.name())), @@ -1522,7 +1627,7 @@ impl Parse<'_> for () { } impl Parse<'_> for bool { - fn parse(slot: &Slot) -> Result { + fn parse(slot: &Slot) -> Result { Ok(unsafe { lua_toboolean(slot.stack.as_ptr(), slot.index()) != 0 }) } } @@ -1530,7 +1635,7 @@ impl Parse<'_> for bool { macro_rules! impl_parse_ptr { ($type:ty) => { impl Parse<'_> for $type { - fn parse(slot: &Slot) -> Result { + fn parse(slot: &Slot) -> Result { let ptr = unsafe { lua_touserdata(slot.stack.as_ptr(), slot.idx) }; if !ptr.is_null() { Ok(ptr as $type) @@ -1548,7 +1653,7 @@ impl_parse_ptr!(*const T); macro_rules! impl_parse_num { ($type:ty) => { impl Parse<'_> for $type { - fn parse(slot: &Slot) -> Result { + fn parse(slot: &Slot) -> Result { let mut isnum = 0; let n = unsafe { lua_tonumberx(slot.stack.as_ptr(), slot.idx, &raw mut isnum) }; if isnum != 0 { @@ -1567,7 +1672,7 @@ impl_parse_num!(f64); macro_rules! impl_parse_int { ($type:ty) => { impl Parse<'_> for $type { - fn parse(slot: &Slot) -> Result { + fn parse(slot: &Slot) -> Result { let mut isnum = 0; let n = unsafe { lua_tointegerx(slot.stack.as_ptr(), slot.idx, &raw mut isnum) }; if isnum != 0 { @@ -1594,7 +1699,7 @@ impl_parse_int!(isize); macro_rules! impl_parse_str { ($type:ty) => { impl<'s> Parse<'s> for $type { - fn parse(slot: &Slot<'s>) -> Result { + fn parse(slot: &Slot<'s>) -> Result { let mut len = 0; let ptr = unsafe { lua_tolstring(slot.stack.as_ptr(), slot.idx, &mut len) }; if !ptr.is_null() { @@ -1610,7 +1715,7 @@ macro_rules! impl_parse_str { macro_rules! impl_parse_str_utf8 { ($type:ty) => { impl<'s> Parse<'s> for $type { - fn parse(slot: &Slot<'s>) -> Result { + fn parse(slot: &Slot<'s>) -> Result { Ok(std::str::from_utf8(Parse::parse(slot)?)?.into()) } } diff --git a/src/main.rs b/src/main.rs index c523af8..267f729 100644 --- a/src/main.rs +++ b/src/main.rs @@ -141,10 +141,7 @@ fn init_vm(_args: &Args) -> luajit::State { println!("{registry}"); state - .load( - &luajit::Chunk::named("@[luby]", registry.done()), - luajit::LoadMode::TEXT, - ) + .load(&luajit::Chunk::new(registry.done()).name("@[luby]")) .and_then(|()| state.call(0, 0)) .unwrap_or_else(|err| panic!("failed to load modules: {err}")); @@ -159,16 +156,13 @@ async fn run(args: Args) { Err(err) => return eprintln!("{}", format!("{path}: {err}").red()), }; - if let Err(err) = state.load( - &luajit::Chunk::named(format!("@{path}"), chunk), - Default::default(), - ) { + if let Err(err) = state.load(&luajit::Chunk::new(chunk).path(path)) { return eprintln!("{}", err.red()); } - state - .call_async(0) - .await - .unwrap_or_else(GlobalState::uncaught_error); + match state.call_async(0, 0).await { + Ok(_) => {} + Err(err) => GlobalState::uncaught_error(err), + } } }