diff --git a/Cargo.lock b/Cargo.lock index c751684..043a84d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1053,6 +1053,7 @@ dependencies = [ "derive_more", "globset", "luaffi", + "luaify", "luajit", "sysexits", "tempfile", @@ -1168,7 +1169,6 @@ dependencies = [ "bstr", "luaffi", "luajit-sys", - "thiserror", ] [[package]] diff --git a/crates/lb/Cargo.toml b/crates/lb/Cargo.toml index 6d1837c..83cfefe 100644 --- a/crates/lb/Cargo.toml +++ b/crates/lb/Cargo.toml @@ -18,6 +18,7 @@ net = ["tokio/net", "tokio/io-util"] derive_more = { version = "2.0.1", features = ["full"] } globset = { version = "0.4.16", optional = true } luaffi = { path = "../luaffi" } +luaify = { path = "../luaify" } luajit = { path = "../luajit" } sysexits = "0.9.0" tempfile = { version = "3.20.0", optional = true } diff --git a/crates/lb/src/runtime.rs b/crates/lb/src/runtime.rs index 39bd6e1..a154a1a 100644 --- a/crates/lb/src/runtime.rs +++ b/crates/lb/src/runtime.rs @@ -1,7 +1,8 @@ #![doc(hidden)] use derive_more::{Deref, DerefMut}; use luaffi::{Module, Registry}; -use luajit::{Chunk, State}; +use luaify::luaify_chunk; +use luajit::{Chunk, Index, NewTable, State}; use std::rc::Rc; use tokio::{ task::{JoinHandle, LocalSet, futures::TaskLocalFuture, spawn_local}, @@ -13,6 +14,7 @@ pub type ErrorFn = dyn Fn(&luajit::Error); pub struct Builder { registry: Registry, report_err: Rc, + prohibit_globals: bool, } impl Builder { @@ -23,6 +25,7 @@ impl Builder { Some(trace) => eprintln!("unhandled lua error: {err}\n{trace}"), None => eprintln!("unhandled lua error: {err}"), }), + prohibit_globals: false, } } @@ -35,19 +38,42 @@ impl Builder { self } + pub fn prohibit_globals(&mut self, enabled: bool) -> &mut Self { + self.prohibit_globals = enabled; + self + } + pub fn module(&mut self) -> &mut Self { self.registry.preload::(); self } pub fn build(&self) -> luajit::Result { + let mut state = State::new()?; + let chunk = Chunk::new(self.registry.build()).with_path("[luby]"); + state.eval(&chunk, 0, Some(0))?; + + if self.prohibit_globals { + let mut s = state.guard(); + s.eval( + &Chunk::new(luaify_chunk!({ + return |self, key, value| { + error(("undeclared local variable '%s'").format(key), 2); + }; + })), + 0, + Some(1), + ) + .unwrap(); + s.push(NewTable::new()); + (s.push("__index"), s.push_idx(-3), s.set(-3)); + (s.push("__newindex"), s.push_idx(-3), s.set(-3)); + s.set_metatable(Index::globals()); + } + Ok(Runtime { cx: Context { - state: { - let mut s = State::new()?; - s.eval(Chunk::new(self.registry.build()).path("[luby]"), 0, 0)?; - s - }, + state, report_err: self.report_err.clone(), }, tasks: LocalSet::new(), @@ -97,7 +123,7 @@ pub struct Context { impl Context { pub fn new_thread(&self) -> Self { Self { - state: self.state.new_thread(), + state: State::new_thread(&self.state), report_err: self.report_err.clone(), } } diff --git a/crates/lb/src/task.rs b/crates/lb/src/task.rs index 38e1ec9..93816d6 100644 --- a/crates/lb/src/task.rs +++ b/crates/lb/src/task.rs @@ -12,7 +12,6 @@ use luaffi::{ marker::{function, many}, metatype, }; -use luajit::LUA_MULTRET; use std::{cell::RefCell, ffi::c_int, time::Duration}; use tokio::{task::JoinHandle, time::sleep}; @@ -59,12 +58,12 @@ impl lb_tasklib { extern "Lua-C" fn __spawn(spawn_ref: c_int, handle_ref: c_int) -> lb_task { let handle = spawn(async move |cx| { // SAFETY: handle_ref is always unique, created in Self::spawn above. - let state = unsafe { cx.new_ref_unchecked(spawn_ref) }; + let state = unsafe { luajit::Ref::from_raw(cx, spawn_ref) }; let mut s = cx.guard(); s.resize(0); s.push(state); // this drops the state table ref, but the table is still on the stack let narg = s.unpack(1, 1, None) - 1; // unpack the function and its args from the state table - match s.call_async(narg, LUA_MULTRET).await { + match s.call_async(narg, None).await { Ok(nret) => { s.pack(1, nret); // pack the return values back into the state table } diff --git a/crates/luaify/src/lib.rs b/crates/luaify/src/lib.rs index 5e1ed7c..33cb5ca 100644 --- a/crates/luaify/src/lib.rs +++ b/crates/luaify/src/lib.rs @@ -1,3 +1,6 @@ +//! # luaify +//! +//! A Rust for generating Lua code from Rust syntax. use crate::{ generate::{generate, generate_chunk}, transform::{transform, transform_chunk}, diff --git a/crates/luajit/Cargo.toml b/crates/luajit/Cargo.toml index 769b068..e1360a3 100644 --- a/crates/luajit/Cargo.toml +++ b/crates/luajit/Cargo.toml @@ -16,4 +16,3 @@ bitflags = { version = "2.9.1", features = ["std"] } bstr = "1.12.0" luaffi = { path = "../luaffi" } luajit-sys = { path = "../luajit-sys" } -thiserror = "2.0.12" diff --git a/crates/luajit/src/lib.rs b/crates/luajit/src/lib.rs index f252276..b63ed49 100644 --- a/crates/luajit/src/lib.rs +++ b/crates/luajit/src/lib.rs @@ -5,21 +5,20 @@ use luaffi::future::lua_pollable; use luajit_sys::*; use std::{ alloc::{Layout, alloc, dealloc, realloc}, - borrow::Cow, - ffi::{CStr, CString, NulError}, + cell::UnsafeCell, + ffi::CString, fmt, marker::PhantomData, - mem::ManuallyDrop, - ops::{Deref, DerefMut}, + mem::{self, ManuallyDrop}, + num::NonZero, + ops::{self, Deref, DerefMut}, os::raw::{c_char, c_int, c_void}, pin::Pin, process, ptr::{self, NonNull}, rc::Rc, slice, - str::Utf8Error, }; -use thiserror::Error; /// LuaJIT version string. pub fn version() -> &'static str { @@ -41,65 +40,165 @@ pub use luajit_sys::{ LUA_ENVIRONINDEX, LUA_GLOBALSINDEX, LUA_MULTRET, LUA_NOREF, LUA_REFNIL, LUA_REGISTRYINDEX, }; +/// Lua result. +pub type Result = std::result::Result; + /// Lua error. -#[derive(Debug, Error)] -#[non_exhaustive] -pub enum Error { - /// Out of memory error. - #[error("out of memory")] - OutOfMemory, - /// Lua syntax error returned by [`Stack::load`]. - #[error("{msg}")] +#[derive(Debug)] +pub struct Error(ErrorInner); + +#[derive(Debug)] +enum ErrorInner { + Memory, Syntax { - /// Content of the chunk which had errors. - chunk: BString, - /// Lua error message. msg: BString, + _chunk: Chunk, }, - /// Lua chunk name error returned by [`Stack::load`]. - #[error("bad chunk name: {0}")] - BadChunkName(NulError), - /// Lua error returned by [`Stack::call`]. - #[error("{msg}")] Call { - /// Lua error message. msg: BString, - }, - /// Lua error returned by [`Stack::resume`]. - #[error("{msg}")] - Resume { - /// Lua error message. - msg: BString, - /// Lua stack trace. + kind: ErrorKind, trace: Option, }, - /// Type mismatch type error. - #[error("{0} expected, got {1}")] - InvalidType( - /// The expected type. - &'static str, - /// The actual type. - &'static str, - ), - /// Invalid UTF-8 string error. - #[error("{0}")] - InvalidUtf8(#[from] Utf8Error), + _Argument { + index: u32, + func: BString, + err: Box, + }, + _Slot { + index: i32, + err: Box, + }, + Type { + expected: &'static str, + got: &'static str, + }, + Other(Box), } impl Error { - /// Lua stack trace, if it was collected. - /// - /// Currently this is only available for [`Error::Resume`]. + pub fn new(error: impl Into>) -> Self { + Self(ErrorInner::Other(error.into())) + } + + pub fn invalid_type(expected: &'static str, got: &'static str) -> Self { + Self(ErrorInner::Type { expected, got }) + } + + fn memory() -> Self { + Self(ErrorInner::Memory) + } + + fn syntax(msg: impl AsRef<[u8]>, chunk: Chunk) -> Self { + Self(ErrorInner::Syntax { + msg: msg.as_ref().into(), + _chunk: chunk, + }) + } + + fn call(kind: ErrorKind, msg: impl AsRef<[u8]>, trace: Option) -> Self { + Self(ErrorInner::Call { + kind, + msg: msg.as_ref().into(), + trace, + }) + } + + /// Kind of this error. + pub fn kind(&self) -> ErrorKind { + match self.0 { + ErrorInner::Memory => ErrorKind::Memory, + ErrorInner::Syntax { .. } => ErrorKind::Syntax, + ErrorInner::Call { kind, .. } => kind, + _ => std::error::Error::source(self) + .and_then(|err| err.downcast_ref::()) + .map(|err| err.kind()) + .unwrap_or(ErrorKind::Runtime), + } + } + + /// Stack trace, or [`None`] if it was not collected. pub fn trace(&self) -> Option<&BStr> { - match self { - Self::Resume { trace, .. } => trace.as_ref().map(|s| s.as_ref()), + match self.0 { + ErrorInner::Call { ref trace, .. } => trace.as_ref().map(|s| s.as_ref()), _ => None, } } } -/// Lua result. -pub type Result = ::std::result::Result; +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match &self.0 { + ErrorInner::Syntax { msg, .. } | ErrorInner::Call { msg, .. } => write!(f, "{msg}"), + ErrorInner::_Argument { index, func, err } => { + write!(f, "bad argument #{index} to '{func}': {err}") + } + ErrorInner::_Slot { index, err } => write!(f, "bad stack slot #{index}: {err}"), + ErrorInner::Type { expected, got } => { + write!(f, "{expected} expected, got {got}") + } + _ => match std::error::Error::source(self) { + Some(err) => write!(f, "{err}"), + None => write!(f, "{}", self.kind()), + }, + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match &self.0 { + ErrorInner::_Argument { err, .. } | ErrorInner::_Slot { err, .. } => Some(err), + ErrorInner::Other(err) => Some(err.as_ref()), + _ => None, + } + } +} + +/// Lua error kind. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ErrorKind { + /// Runtime error. + Runtime, + /// Syntax error. + Syntax, + /// Memory allocation error. + Memory, + /// Error while running the error handler function. + Error, +} + +impl ErrorKind { + fn from_raw(code: c_int) -> Option { + Some(match code { + LUA_ERRRUN => Self::Runtime, + LUA_ERRSYNTAX => Self::Syntax, + LUA_ERRMEM => Self::Memory, + LUA_ERRERR => Self::Error, + _ => return None, + }) + } + + /// Name of this error kind, like `"runtime"` or `"syntax"`. + pub fn name(self) -> &'static str { + match self { + Self::Runtime => "runtime", + Self::Syntax => "syntax", + Self::Memory => "memory", + Self::Error => "error", + } + } +} + +impl fmt::Display for ErrorKind { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ErrorKind::Runtime => write!(f, "runtime error"), + ErrorKind::Syntax => write!(f, "syntax error"), + ErrorKind::Memory => write!(f, "memory allocation error"), + ErrorKind::Error => write!(f, "error while running the error handler function"), + } + } +} /// Lua type. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] @@ -128,8 +227,7 @@ pub enum Type { } impl Type { - /// Converts a raw Lua type code to [`Type`], returning [`None`] if the value is invalid. - pub fn from_code(code: c_int) -> Option { + fn from_raw(code: c_int) -> Option { Some(match code { LUA_TNIL => Self::Nil, LUA_TBOOLEAN => Self::Boolean, @@ -175,19 +273,20 @@ pub enum Status { #[default] Normal, /// Thread terminated with an error and can no longer be used. - Dead, + Dead( + /// Why this thread is dead. + ErrorKind, + ), /// Thread suspended with `coroutine.yield(...)` and is awaiting resume. Suspended, } impl Status { - /// Converts a raw Lua status code to [`Status`], returning [`None`] if the value is invalid. - pub fn from_code(code: c_int) -> Option { + fn from_raw(code: c_int) -> Option { Some(match code { LUA_OK => Self::Normal, LUA_YIELD => Self::Suspended, - LUA_ERRRUN | LUA_ERRSYNTAX | LUA_ERRMEM | LUA_ERRERR => Self::Dead, - _ => return None, + _ => Self::Dead(ErrorKind::from_raw(code)?), }) } @@ -195,7 +294,7 @@ impl Status { pub fn name(&self) -> &'static str { match self { Status::Normal => "normal", - Status::Dead => "dead", + Status::Dead(_) => "dead", Status::Suspended => "suspended", } } @@ -203,14 +302,18 @@ impl Status { impl fmt::Display for Status { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.name()) + match self { + Status::Normal => write!(f, "normal"), + Status::Dead(why) => write!(f, "dead ({why})"), + Status::Suspended => write!(f, "suspended"), + } } } -/// Result of [`Stack::resume`]. +/// Result of [`resume`](Stack::resume). #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum ResumeStatus { - /// Thread returned successfully and can run another function. + /// Thread completed successfully and can run another function. #[default] Ok, /// Thread suspended with `coroutine.yield(...)` and is awaiting resume. @@ -253,31 +356,23 @@ bitflags! { } } -impl LoadMode { - fn to_mode_str(&self) -> Cow<'static, CStr> { - Cow::Borrowed(match *self { - Self::NONE => c"", - Self::AUTO => c"bt", - Self::TEXT => c"t", - Self::BINARY => c"b", - _ => { - let mut s = String::new(); - self.contains(Self::TEXT).then(|| s.push_str("t")); - self.contains(Self::BINARY).then(|| s.push_str("b")); - self.contains(Self::GC32).then(|| s.push_str("W")); - self.contains(Self::GC64).then(|| s.push_str("X")); - return Cow::Owned(CString::new(s).unwrap()); - } - }) - } -} - impl Default for LoadMode { fn default() -> Self { Self::AUTO } } +impl LoadMode { + fn mode_str(&self) -> CString { + let mut s = String::new(); + self.contains(Self::TEXT).then(|| s.push_str("t")); + self.contains(Self::BINARY).then(|| s.push_str("b")); + self.contains(Self::GC32).then(|| s.push_str("W")); + self.contains(Self::GC64).then(|| s.push_str("X")); + CString::new(s).unwrap() + } +} + bitflags! { /// Flags for [`Stack::dump`]. #[repr(transparent)] @@ -296,30 +391,23 @@ bitflags! { } } -impl DumpMode { - fn to_mode_str(&self) -> Cow<'static, CStr> { - Cow::Borrowed(match *self { - Self::NONE => c"", - Self::STRIP => c"s", - Self::DETERMINISTIC => c"d", - _ => { - let mut s = String::new(); - self.contains(Self::STRIP).then(|| s.push_str("s")); - self.contains(Self::DETERMINISTIC).then(|| s.push_str("d")); - self.contains(Self::GC32).then(|| s.push_str("W")); - self.contains(Self::GC64).then(|| s.push_str("X")); - return Cow::Owned(CString::new(s).unwrap()); - } - }) - } -} - impl Default for DumpMode { fn default() -> Self { Self::DETERMINISTIC } } +impl DumpMode { + fn mode_str(&self) -> CString { + let mut s = String::new(); + self.contains(Self::STRIP).then(|| s.push_str("s")); + self.contains(Self::DETERMINISTIC).then(|| s.push_str("d")); + self.contains(Self::GC32).then(|| s.push_str("W")); + self.contains(Self::GC64).then(|| s.push_str("X")); + CString::new(s).unwrap() + } +} + /// Lua chunk data. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Chunk { @@ -329,7 +417,7 @@ pub struct Chunk { } impl Chunk { - /// Creates a new [`Chunk`] with the given content. + /// Creates a new chunk with the given content. pub fn new(content: impl Into) -> Self { Self { name: "?".into(), @@ -338,24 +426,41 @@ impl Chunk { } } - /// Sets the name of this chunk as `name`. - pub fn name(&mut self, name: impl AsRef<[u8]>) -> &mut Self { - self.name = name.as_ref().into(); - self + /// Name of this chunk, like `?` or `"@path/to/file.lua"`. + pub fn name(&self) -> &BStr { + self.name.as_ref() } - /// Sets the name of this chunk as the path `path`. - pub fn path(&mut self, path: impl AsRef<[u8]>) -> &mut Self { + /// Path of this chunk, if the name of this chunk starts with `@`. + pub fn path(&self) -> Option<&BStr> { + self.name.strip_prefix(b"@").map(|s| s.as_ref()) + } + + /// Mode flag for loading this chunk. + pub fn mode(&self) -> LoadMode { + self.mode + } + + /// Assigns a name to this chunk. + pub fn with_name(self, name: impl AsRef<[u8]>) -> Self { + Self { + name: name.as_ref().into(), + ..self + } + } + + /// Assigns a name to this chunk as a path. + /// + /// This sets the name of this chunk to `@path`, where `path` is the given path. + pub fn with_path(self, path: impl AsRef<[u8]>) -> Self { let mut name = BString::from(b"@"); name.extend_from_slice(path.as_ref()); - self.name = name; - self + Self { name, ..self } } - /// Sets the mode flag for loading this chunk. - pub fn mode(&mut self, mode: LoadMode) -> &mut Self { - self.mode = mode; - self + /// Assigns a mode flag for loading this chunk. + pub fn with_mode(self, mode: LoadMode) -> Self { + Self { mode, ..self } } } @@ -385,20 +490,123 @@ impl> From for Chunk { } } +macro_rules! assert_slots { + ($stack:expr, $n:expr) => {{ + let size = $stack.size(); + let n = $n; + assert!( + n <= size, + "stack underflow: expected at least {n} values, got {size}: {stack:?}", + stack = $stack, + ); + let _base = size - n; + _base + }}; +} + +macro_rules! assert_type { + ($stack:expr, $slot:expr, $expected:expr, $pat:pat) => {{ + let slot = $stack.slot($slot); + let expected = $expected; + let idx = slot.index(); + let got = slot.ty(); + assert!( + matches!(got, $pat), + "expected {expected} at index {idx}, got {got}: {stack:?}", + stack = $stack, + ); + slot + }}; +} + #[derive(Debug)] struct GlobalState { ptr: NonNull, + _alloc: Box>, +} + +#[derive(Debug, Default)] +struct AllocatorState { + alloc: usize, + dealloc: usize, } impl GlobalState { pub fn new() -> Result { unsafe { // SAFETY: lua_newstate may return a null pointer if allocation fails - let ptr = NonNull::new(lua_newstate(Some(Self::alloc_cb), ptr::null_mut())) - .ok_or(Error::OutOfMemory)?; - + let mut alloc = Box::new(UnsafeCell::new(AllocatorState::default())); + let ud = alloc.get_mut() as *mut _ as *mut c_void; + let ptr = NonNull::new(lua_newstate(Some(Self::alloc_cb), ud)).ok_or(Error::memory())?; lua_atpanic(ptr.as_ptr(), Some(Self::panic_cb)); - Ok(Self { ptr }) + Ok(Self { ptr, _alloc: alloc }) + } + } + + fn as_ptr(&self) -> *mut lua_State { + self.ptr.as_ptr() + } + + unsafe extern "C" fn alloc_cb( + ud: *mut c_void, + ptr: *mut c_void, + osize: usize, + nsize: usize, + ) -> *mut c_void { + unsafe { + // https://github.com/tikv/jemallocator/blob/main/jemallocator/src/lib.rs + #[cfg(any(target_arch = "arm", target_arch = "mips", target_arch = "powerpc"))] + const ALIGNOF_MAX_ALIGN_T: usize = 8; + #[cfg(any( + target_arch = "x86", + target_arch = "x86_64", + target_arch = "aarch64", + target_arch = "powerpc64", + target_arch = "loongarch64", + target_arch = "mips64", + target_arch = "riscv64", + target_arch = "s390x", + target_arch = "sparc64" + ))] + const ALIGNOF_MAX_ALIGN_T: usize = 16; + + let old = Layout::from_size_align(osize, ALIGNOF_MAX_ALIGN_T) + .expect("lua alloc error: requested osize is too large"); + let new = Layout::from_size_align(nsize, ALIGNOF_MAX_ALIGN_T) + .expect("lua alloc error: requested nsize is too large"); + + let state = &mut *(ud as *mut AllocatorState); + let ptr = ptr as *mut u8; + + // SAFETY: from lua documentation: + // + // When nsize is zero, the allocator must return NULL; if osize is not zero, it should + // free the block pointed to by ptr. When nsize is not zero, the allocator returns + // NULL if and only if it cannot fill the request. When nsize is not zero and osize is + // zero, the allocator should behave like malloc. When nsize and osize are not zero, + // the allocator behaves like realloc. Lua assumes that the allocator never fails when + // osize >= nsize. + (if new.size() == 0 { + if old.size() != 0 { + dealloc(ptr, old); + state.dealloc += old.size(); + } + + ptr::null_mut() + } else { + let ptr = if old.size() == 0 { + alloc(new) + } else { + realloc(ptr, old, new.size()) + }; + + if !ptr.is_null() { + state.alloc += new.size(); + state.dealloc += old.size(); + } + + ptr + }) as *mut c_void } } @@ -408,64 +616,13 @@ impl GlobalState { let stack = unsafe { Stack::new_unchecked(L) }; let msg = stack - .slot(-1) - .string() - .unwrap_or(b"unknown lua panic".into()); + .top() + .and_then(|s| s.string()) + .unwrap_or(b"unknown error".into()); eprintln!("lua panicked: {msg}"); process::abort() } - - unsafe extern "C" fn alloc_cb( - _ud: *mut c_void, - ptr: *mut c_void, - osize: usize, - nsize: usize, - ) -> *mut c_void { - // https://github.com/tikv/jemallocator/blob/main/jemallocator/src/lib.rs - #[cfg(any(target_arch = "arm", target_arch = "mips", target_arch = "powerpc"))] - const ALIGNOF_MAX_ALIGN_T: usize = 8; - #[cfg(any( - target_arch = "x86", - target_arch = "x86_64", - target_arch = "aarch64", - target_arch = "powerpc64", - target_arch = "loongarch64", - target_arch = "mips64", - target_arch = "riscv64", - target_arch = "s390x", - target_arch = "sparc64" - ))] - const ALIGNOF_MAX_ALIGN_T: usize = 16; - - let old_layout = Layout::from_size_align(osize, ALIGNOF_MAX_ALIGN_T) - .expect("lua alloc error: requested osize is too large"); - let new_layout = Layout::from_size_align(nsize, ALIGNOF_MAX_ALIGN_T) - .expect("lua alloc error: requested nsize is too large"); - - // SAFETY: from lua documentation: - // When nsize is zero, the allocator must return NULL; if osize is not zero, it should - // free the block pointed to by ptr. When nsize is not zero, the allocator returns NULL if - // and only if it cannot fill the request. When nsize is not zero and osize is zero, the - // allocator should behave like malloc. When nsize and osize are not zero, the allocator - // behaves like realloc. Lua assumes that the allocator never fails when osize >= nsize. - if nsize == 0 { - if osize != 0 { - unsafe { dealloc(ptr as *mut u8, old_layout) } - } - ptr::null_mut() - } else { - if osize == 0 { - unsafe { alloc(new_layout) as *mut c_void } - } else { - unsafe { realloc(ptr as *mut u8, old_layout, nsize) as *mut c_void } - } - } - } - - fn as_ptr(&self) -> *mut lua_State { - self.ptr.as_ptr() - } } impl Drop for GlobalState { @@ -484,6 +641,18 @@ pub struct Ref { } impl Ref { + /// Creates a new ref by popping the value at the top of the given stack. + pub fn new(stack: &mut State) -> Self { + // SAFETY: luaL_ref always returns a unique key + assert_slots!(stack, 1); + unsafe { Self::from_raw(stack, luaL_ref(stack.as_ptr(), LUA_REGISTRYINDEX)) } + } + + pub unsafe fn from_raw(stack: &State, key: c_int) -> Self { + let state = Rc::clone(&stack.thread_ref.state); + Self { state, key } + } + /// Consumes this ref and returns the original key used to create the ref. /// /// This key can be used to index into the registry table ([`LUA_REGISTRYINDEX`]) to retrieve @@ -496,9 +665,22 @@ impl Ref { } } +impl Clone for Ref { + fn clone(&self) -> Self { + let state = Rc::clone(&self.state); + let key = unsafe { + // SAFETY: luaL_ref always returns a unique key + let mut stack = Stack::new_unchecked(state.as_ptr()); + stack.geti(PseudoIndex::Registry, self.key as isize); + luaL_ref(stack.as_ptr(), LUA_REGISTRYINDEX) + }; + + Self { state, key } + } +} + impl Drop for Ref { fn drop(&mut self) { - // SAFETY: luaL_unref is guaranteed to not fail unsafe { luaL_unref(self.state.as_ptr(), LUA_REGISTRYINDEX, self.key) } } } @@ -508,7 +690,7 @@ impl Drop for Ref { /// A state instance can be manipulated using the [`Stack`] object that it mutably dereferences to. #[derive(Debug)] pub struct State { - thread: Ref, + thread_ref: Ref, stack: Stack, } @@ -522,48 +704,39 @@ impl State { let state = Rc::new(GlobalState::new()?); let mut state = Self { stack: unsafe { Stack::new_unchecked(state.as_ptr()) }, - thread: Ref { + thread_ref: Ref { + // SAFETY: the main thread doesn't need a ref, it doesn't get garbage collected state, key: LUA_NOREF, }, }; - state.push_function_raw(Some(Self::open_cb), 0); - state.call(0, 0)?; + state.push(Function::Bare(Self::open_cb)); + state.call(0, Some(0))?; Ok(state) } - unsafe extern "C" fn open_cb(L: *mut lua_State) -> c_int { - unsafe { - luaL_openlibs(L); - luaJIT_openlibs(L); // luajit-sys extension to open jitlibs - 0 - } - } - - /// Creates a new empty thread (coroutine) associated with this state. - pub fn new_thread(&self) -> Self { - self.ensure(1); + /// Creates a new empty thread (coroutine) associated with the given state. + pub fn new_thread(state: &State) -> Self { Self { - // SAFETY: lua_newthread never returns null, but may panic on oom - stack: unsafe { Stack::new_unchecked(lua_newthread(self.as_ptr())) }, - thread: Ref { - state: Rc::clone(&self.thread.state), - key: unsafe { luaL_ref(self.as_ptr(), LUA_REGISTRYINDEX) }, + // SAFETY: must call lua_newthread first before ref'ing it + stack: unsafe { + state.ensure(1); + Stack::new_unchecked(lua_newthread(state.as_ptr())) + }, + thread_ref: Ref { + state: Rc::clone(&state.thread_ref.state), + key: unsafe { luaL_ref(state.as_ptr(), LUA_REGISTRYINDEX) }, }, } } - /// Creates a new [`Ref`] with the given key. - /// - /// # Safety - /// - /// The caller must ensure that the given ref key is unique and not already used by any other - /// instances of [`Ref`]. - pub unsafe fn new_ref_unchecked(&self, key: c_int) -> Ref { - Ref { - state: Rc::clone(&self.thread.state), - key, + fn open_cb(stack: &mut Stack) -> c_int { + unsafe { + // SAFETY: this is only ever called once on state initialisation + luaL_openlibs(stack.as_ptr()); // lua base libraries + luaJIT_openlibs(stack.as_ptr()); // luajit-sys extension to open jitlibs + 0 } } } @@ -610,50 +783,47 @@ impl DerefMut for State { /// It is guaranteed that the lifetime of an `&str` will not outlive the lifetime of the original /// `string` while the stack is immutable and the original string cannot be removed from the stack, /// preventing it from being garbage-collected by Lua. +#[derive(PartialEq, Eq, Hash)] #[repr(transparent)] pub struct Stack(NonNull); impl Stack { - /// Creates a new [`Stack`] from a raw pointer. - /// - /// # Safety - /// - /// The pointer must not be null. See also mutability guarantees on [`Stack`](Stack#mutability). - pub unsafe fn new_unchecked(ptr: *mut lua_State) -> Self { - assert!(!ptr.is_null(), "attempt to create Stack with null pointer"); + unsafe fn new_unchecked(ptr: *mut lua_State) -> Self { + debug_assert!(!ptr.is_null(), "attempt to create stack with null pointer"); Self(unsafe { NonNull::new_unchecked(ptr) }) } - /// Pointer to the [`lua_State`]. + /// Pointer to the underlying [`lua_State`]. pub fn as_ptr(&self) -> *mut lua_State { self.0.as_ptr() } /// Size of the stack. /// - /// This is the number of values on the stack and points to the value at the top of the stack - /// when used as an index. - /// /// Equivalent to [`lua_gettop`]. - pub fn size(&self) -> c_int { - unsafe { lua_gettop(self.as_ptr()) } + pub fn size(&self) -> u32 { + unsafe { lua_gettop(self.as_ptr()) as u32 } } /// Resizes the stack to fit exactly `n` values, reallocating the stack and popping any - /// extraneous values or pushing nils to fill the space as necessary. + /// extraneous values or pushing `nil`s to fill the space as necessary. /// /// Equivalent to [`lua_settop`]. /// /// # Panics /// - /// Panics if `n` is negative. - pub fn resize(&mut self, n: c_int) { - // SAFETY: lua_settop can throw on oom (doesn't cpgrowstack) when growing, so we call ensure - // first - assert!(0 <= n, "cannot resize to size {n}"); + /// Panics if the stack could not be resized. + pub fn resize(&mut self, n: u32) { + assert!( + n <= (LUAI_MAXCSTACK as u32), + "stack overflow: cannot resize stack to size > {LUAI_MAXCSTACK}" + ); + + // SAFETY: lua_settop can throw on oom (calls growstack not cpgrowstack) when growing, so we + // need to call ensure first if we might reallocate let size = self.size(); - (n > size).then(|| self.ensure(n - size)); - unsafe { lua_settop(self.as_ptr(), n) } + n.checked_sub(size).map(|n| self.ensure(n)); + unsafe { lua_settop(self.as_ptr(), n as c_int) } } /// Reallocates the stack to fit `n` more values, if necessary. @@ -665,11 +835,20 @@ impl Stack { /// /// # Panics /// - /// Panics if `n` is negative or reallocation fails. - pub fn ensure(&self, n: c_int) { - // lua_checkstack throws on oom in puc lua 5.1, but it is fine in luajit - assert!(n >= 0, "ensure called with a negative value"); - unsafe { assert!(lua_checkstack(self.as_ptr(), n) != 0, "stack out of memory") } + /// Panics if reallocation fails. + pub fn ensure(&self, n: u32) { + if n != 0 { + // lua_checkstack throws on oom in PUC lua 5.1, but it is fine in luajit + assert!( + n <= (LUAI_MAXCSTACK as u32), + "stack overflow: cannot reallocate stack to size > {LUAI_MAXCSTACK}" + ); + assert!( + unsafe { lua_checkstack(self.as_ptr(), n as c_int) != 0 }, + "stack overflow: failed to reallocate stack to size {len}", + len = self.size() + n, + ) + } } /// Pops `n` values at the top of the stack. @@ -679,13 +858,18 @@ impl Stack { /// # Panics /// /// Panics if there are less than `n` values on the stack. - pub fn pop(&mut self, n: c_int) { - assert!(0 <= n && n <= self.size(), "cannot pop {n}: {self:?}"); - unsafe { lua_pop(self.as_ptr(), n) } + pub fn pop(&mut self, n: u32) { + if n != 0 { + assert_slots!(self, n); + unsafe { lua_pop(self.as_ptr(), n as c_int) } + } } /// Pops the value at the top of the stack and inserts it at index `idx` by shifting up existing - /// values. + /// values, and returns the slot for that index. + /// + /// If the index `idx` points to the top of the stack, this does not pop anything and keeps the + /// stack untouched. /// /// Index `idx` cannot be a pseudo-index. /// @@ -694,33 +878,42 @@ impl Stack { /// # Panics /// /// Panics if the stack is empty or the index `idx` is invalid. - pub fn pop_insert(&mut self, idx: c_int) { - assert!(self.size() >= 1, "cannot pop 1: {self:?}"); - let idx = self.slot(idx).index(); - assert!(idx > 0, "cannot insert into pseudo-index {idx}"); - unsafe { lua_insert(self.as_ptr(), idx) } + pub fn pop_insert<'s>(&'s mut self, idx: impl ToSlot) -> Slot<'s> { + let top = self.slot(-1).index(); + let slot = self.slot(idx); + let idx = slot.index(); + if let Index::Stack(_) = idx { + unsafe { (idx != top).then(|| lua_insert(self.as_ptr(), idx.into_raw())) }; + } else { + panic!("cannot insert into pseudo-index {idx}"); + } + slot } - /// Pops the value at the top of the stack and replaces the value at index `idx` with it. + /// Pops the value at the top of the stack, replaces the value at index `idx` with it, and + /// returns the slot for that index. /// - /// If the index `idx` points to the top of the stack, this still pops the value and is - /// functionally equivalent to [`Stack::pop`] in that case. + /// If the index `idx` points to the top of the stack, this does not pop anything and keeps the + /// stack untouched. /// /// Equivalent to [`lua_replace`]. /// /// # Panics /// /// Panics if the stack is empty or the index `idx` is invalid. - pub fn pop_replace(&mut self, idx: c_int) { - assert!(self.size() >= 1, "cannot pop 1: {self:?}"); - unsafe { lua_replace(self.as_ptr(), self.slot(idx).index()) } + pub fn pop_replace<'s>(&'s mut self, idx: impl ToSlot) -> Slot<'s> { + let top = self.slot(-1).index(); + let slot = self.slot(idx); + let idx = slot.index(); + unsafe { (idx != top).then(|| lua_replace(self.as_ptr(), idx.into_raw())) }; + slot } /// Status of the current thread. /// /// Equivalent to [`lua_status`]. pub fn status(&self) -> Status { - Status::from_code(unsafe { lua_status(self.as_ptr()) }).unwrap() + Status::from_raw(unsafe { lua_status(self.as_ptr()) }).unwrap() } /// Iterator over all values on the stack. @@ -749,57 +942,59 @@ impl Stack { StackGuard::new(self, true) } + /// Slot for the value at the top of the stack, or [`None`] if the stack is empty. + pub fn top<'s>(&'s self) -> Option> { + let size = self.size(); + (size != 0).then(|| unsafe { self.slot_unchecked(Index::stack(size)) }) + } + /// Handle for the value at index `idx`. /// /// # Panics /// /// Panics if the index `idx` is invalid. - pub fn slot<'s>(&'s self, idx: c_int) -> Slot<'s> { - self.try_slot(idx) - .unwrap_or_else(|| panic!("invalid index {idx}: {self:?}")) + pub fn slot<'s>(&'s self, idx: impl ToSlot) -> Slot<'s> { + idx.to_slot(self) } - /// Handle for the value at index `idx`, or [`None`] if there is no value at that index. - pub fn try_slot<'s>(&'s self, idx: c_int) -> Option> { - self.absindex(idx) - .map(|idx| unsafe { Slot::new_unchecked(self, idx) }) + pub unsafe fn slot_unchecked<'s>(&'s self, idx: Index) -> Slot<'s> { + // SAFETY: the caller must ensure that the index is valid + unsafe { Slot::new_unchecked(self, idx) } } - fn absindex(&self, idx: c_int) -> Option { - if LUA_REGISTRYINDEX < idx && idx <= 0 { - // SAFETY: must check any relative index that gets passed to index2adr in lj_api.c - // because luajit doesn't check for out-of-bounds access for relative indices with - // assertions disabled - let top = self.size(); - let idx = top + idx + 1; - (0 < idx && idx <= top).then_some(idx) - } else { - unsafe { lua_type(self.as_ptr(), idx) != LUA_TNONE }.then_some(idx) - } + /// Slot for the upvalue of the current function at index `idx`. + pub fn upvalue<'s>(&'s self, idx: u32) -> Slot<'s> { + self.slot(Index::upvalue(idx)) + } + + /// Slot for the registry table. + pub fn registry<'s>(&'s self) -> Slot<'s> { + self.slot(PseudoIndex::Registry) + } + + /// Slot for the environment table of the current function. + pub fn environment<'s>(&'s self) -> Slot<'s> { + self.slot(PseudoIndex::Environment) + } + + /// Slot for the global environment table of the current thread. + pub fn globals<'s>(&'s self) -> Slot<'s> { + self.slot(PseudoIndex::Globals) } /// Pushes the given value at the top of the stack. /// /// Equivalent to the `lua_push*` family of functions depending on the type of `T`. - pub fn push(&mut self, value: T) { - value.push(self) + pub fn push<'s, T: Push>(&'s mut self, value: T) { + value.push(self); } - /// Pushes the given C function at the top of the stack. - /// - /// Equivalent to [`lua_pushcclosure`]. - /// - /// # Panics - /// - /// Panics if the given function pointer is null. - pub fn push_function_raw(&mut self, f: lua_CFunction, upvals: c_int) { - assert!(f.is_some(), "function must not be null"); - self.ensure(1); - unsafe { lua_pushcclosure(self.as_ptr(), f, upvals) } + pub fn push_idx(&mut self, idx: impl ToSlot) { + self.push(self.slot(idx).index()); } /// Gets a field of the table at index `idx` using the value at the top of the stack as the key, - /// and replaces the key with the retrieved value. + /// replaces the key with the field value, and returns the slot for it. /// /// This function does not invoke the `__index` metamethod. /// @@ -808,18 +1003,15 @@ impl Stack { /// # Panics /// /// Panics if the value at index `idx` is not a table. - pub fn get(&mut self, idx: c_int) { - assert!(self.size() >= 1, "expected 1 value: {self:?}"); - let table = self.slot(idx); - assert!( - table.type_of() == Type::Table, - "expected table at index {idx}: {self:?}" - ); - unsafe { lua_rawget(self.as_ptr(), table.index()) } + pub fn get<'s>(&'s mut self, idx: impl ToSlot) -> Slot<'s> { + let slot = assert_type!(self, idx, "table", Type::Table); + assert_slots!(self, 1); + unsafe { lua_rawget(self.as_ptr(), slot.index().into_raw()) } + slot } - /// Gets a field of the table at index `idx` using `n` as the key, and pushes the retrieved - /// value at the top of the stack. + /// Gets a field of the table at index `idx` using `n` as the key, pushes the field value at the + /// top of the stack, and returns the slot for it. /// /// This function does not invoke the `__index` metamethod. /// @@ -828,14 +1020,17 @@ impl Stack { /// # Panics /// /// Panics if the value at index `idx` is not a table. - pub fn geti(&mut self, idx: c_int, n: c_int) { - let table = self.slot(idx); - assert!( - table.type_of() == Type::Table, - "expected table at index {idx}: {self:?}" - ); + pub fn geti<'s>(&'s mut self, idx: impl ToSlot, n: isize) -> Slot<'s> { + let slot = assert_type!(self, idx, "table", Type::Table); self.ensure(1); - unsafe { lua_rawgeti(self.as_ptr(), table.index(), n) } + match n.try_into() { + Ok(n) => unsafe { lua_rawgeti(self.as_ptr(), slot.index().into_raw(), n) }, + Err(_) => unsafe { + lua_pushinteger(self.as_ptr(), n as lua_Integer); + lua_rawget(self.as_ptr(), slot.index().into_raw()); + }, + } + slot } /// Sets a field of the table at index `idx` using two values at the top of the stack as the key @@ -850,14 +1045,10 @@ impl Stack { /// # Panics /// /// Panics if the value at index `idx` is not a table. - pub fn set(&mut self, idx: c_int) { - assert!(self.size() >= 2, "expected 2 values: {self:?}"); - let table = self.slot(idx); - assert!( - table.type_of() == Type::Table, - "expected table at index {idx}: {self:?}" - ); - unsafe { lua_rawset(self.as_ptr(), table.index()) } + pub fn set(&mut self, idx: impl ToSlot) { + let slot = assert_type!(self, idx, "table", Type::Table); + assert_slots!(self, 2); + unsafe { lua_rawset(self.as_ptr(), slot.index().into_raw()) } } /// Sets a field of the table at index `idx` using `n` is the key and the value at the top of @@ -872,14 +1063,49 @@ impl Stack { /// # Panics /// /// Panics if the value at index `idx` is not a table. - pub fn seti(&mut self, idx: c_int, n: c_int) { - assert!(self.size() >= 1, "expected 1 value: {self:?}"); - let table = self.slot(idx); - assert!( - table.type_of() == Type::Table, - "expected table at index {idx}: {self:?}" - ); - unsafe { lua_rawseti(self.as_ptr(), table.index(), n) } + pub fn seti(&mut self, idx: impl ToSlot, n: isize) { + let slot = assert_type!(self, idx, "table", Type::Table); + assert_slots!(self, 1); + match n.try_into() { + Ok(n) => unsafe { lua_rawseti(self.as_ptr(), slot.index().into_raw(), n) }, + Err(_) => unsafe { + self.ensure(1); + lua_pushinteger(self.as_ptr(), n as lua_Integer); + lua_insert(self.as_ptr(), -2); + lua_rawset(self.as_ptr(), slot.index().into_raw()); + }, + } + } + + /// Gets the metatable of the value at index `idx`, pushes it onto the stack, and returns the + /// slot for that metatable. + /// + /// If the value at index `idx` does not have a metatable, then nothing is pushed onto the stack + /// and [`None`] is returned. + /// + /// Metatables can be attached to non-table values including primitives like numbers, strings, + /// and booleans. + pub fn get_metatable<'s>(&'s mut self, idx: impl ToSlot) -> Option> { + let slot = self.slot(idx); + self.ensure(1); + unsafe { lua_getmetatable(self.as_ptr(), slot.index().into_raw()) != 0 } + .then(|| self.slot(-1)) + } + + /// Sets the metatable of the value at index `idx` to the value at the top of the stack, and + /// returns the slot for that value. + /// + /// The value at the top of the stack can be `nil` to remove the metatable, or a table to set it + /// as the new metatable. This value is popped from the stack and set as the metatable of the + /// value at index `idx`. + /// + /// Metatables can be attached to non-table values including primitives like numbers, strings, + /// and booleans. + pub fn set_metatable<'s>(&'s mut self, idx: impl ToSlot) -> Slot<'s> { + let slot = self.slot(idx); + assert_type!(self, -1, "table or nil", Type::Table | Type::Nil); + unsafe { lua_setmetatable(self.as_ptr(), slot.index().into_raw()) }; // always returns 1 + slot } /// Packs the array-part of the table at index `idx` from the stack. @@ -898,27 +1124,30 @@ impl Stack { /// /// Panics if `n` is negative, there are not enough values on the stack, or the value at index /// `idx` is not a table. - pub fn pack(&mut self, idx: c_int, n: c_int) -> c_int { - assert!(n >= 0, "n must be nonnegative"); - let size = self.size(); - let table = self.slot(idx); - assert!( - table.type_of() == Type::Table, - "expected table at index {idx}: {self:?}" - ); - assert!(n <= size, "expected {n} values: {self:?}"); - assert!(idx <= size - n, "cannot pack a table into itself: {self:?}"); + pub fn pack(&mut self, idx: impl ToSlot, n: u32) -> u32 { + let slot = assert_type!(self, idx, "table", Type::Table); + let base = assert_slots!(self, n); + if let Index::Stack(idx) = slot.index() { + assert!( + idx.get() <= base, + "cannot pack a table into itself: {self:?}" + ); + } + unsafe { - (0..n).for_each(|i| lua_rawseti(self.as_ptr(), table.index(), n - i)); + for i in 0..n { + lua_rawseti(self.as_ptr(), slot.index().into_raw(), (n - i) as c_int); + } self.ensure(2); lua_pushliteral(self.as_ptr(), "n"); lua_pushinteger(self.as_ptr(), n as lua_Integer); - lua_rawset(self.as_ptr(), table.index()); + lua_rawset(self.as_ptr(), slot.index().into_raw()); n } } - /// Unpacks the array-part of the table at index `idx` onto the stack. + /// Unpacks the array-part of the table at index `idx` onto the stack, and returns the number of + /// values pushed. /// /// If `j` is [`None`], then it is set to be the value of the field `"n"` interpreted as an /// integer. If this field does not exist, then it is set to the be length of the table as @@ -934,31 +1163,41 @@ impl Stack { /// # Panics /// /// Panics if the value at index `idx` is not a table. - pub fn unpack(&mut self, idx: c_int, i: c_int, j: Option) -> c_int { - let table = self.slot(idx); - assert!( - table.type_of() == Type::Table, - "expected table at index {idx}: {self:?}" - ); + pub fn unpack(&mut self, idx: impl ToSlot, i: isize, j: Option) -> u32 { + let slot = assert_type!(self, idx, "table", Type::Table); let j = match j { Some(j) => j, None => unsafe { self.ensure(1); lua_pushliteral(self.as_ptr(), "n"); - lua_rawget(self.as_ptr(), table.index()); + lua_rawget(self.as_ptr(), slot.index().into_raw()); let mut isnum = 0; let n = lua_tointegerx(self.as_ptr(), -1, &raw mut isnum); lua_pop(self.as_ptr(), 1); (isnum != 0) - .then_some(n as c_int) - .unwrap_or_else(|| lua_objlen(self.as_ptr(), table.index()) as c_int) + .then_some(n as isize) + .unwrap_or_else(|| slot.length() as isize) }, }; - let n = (j - i + 1).max(0); + + let n = (j - i + 1) + .max(0) + .try_into() + .expect("too many values to unpack"); + if n > 0 { self.ensure(n); - (0..n).for_each(|n| unsafe { lua_rawgeti(self.as_ptr(), table.index(), i + n) }); + for k in i..=j { + match k.try_into() { + Ok(n) => unsafe { lua_rawgeti(self.as_ptr(), slot.index().into_raw(), n) }, + Err(_) => unsafe { + lua_pushinteger(self.as_ptr(), n as lua_Integer); + lua_rawget(self.as_ptr(), slot.index().into_raw()); + }, + } + } } + n } @@ -970,23 +1209,22 @@ impl Stack { /// stack is returned. /// /// Equivalent to `require(name)`. - pub fn require(&mut self, name: impl AsRef<[u8]>, nret: c_int) -> Result { + pub fn require(&mut self, name: impl AsRef<[u8]>, nret: Option) -> Result { self.push("require"); - self.get(LUA_GLOBALSINDEX); + self.get(PseudoIndex::Globals); self.push(name.as_ref()); self.call(1, nret) } - /// Pushes the given chunk as a function at the top of the stack. + /// Pushes the given chunk as a function at the top of the stack and returns the slot for the + /// new function. /// /// Equivalent to [`lua_loadx`]. pub fn load(&mut self, chunk: &Chunk) -> Result<()> { type State<'s> = Option<&'s [u8]>; let mut state: State = Some(chunk.content.as_ref()); - let name = CString::new(chunk.name.to_vec()).map_err(Error::BadChunkName)?; - let mode = chunk.mode.to_mode_str(); - unsafe extern "C" fn reader_cb( + unsafe extern "C" fn read_cb( _L: *mut lua_State, state: *mut c_void, size: *mut usize, @@ -1002,12 +1240,17 @@ impl Stack { } } + let name = CString::new(chunk.name.to_vec()) + .map_err(|err| Error::syntax(format!("invalid chunk name: {err}"), chunk.clone()))?; + + let mode = chunk.mode.mode_str(); + self.ensure(1); match unsafe { lua_loadx( self.as_ptr(), - Some(reader_cb), + Some(read_cb), &raw mut state as *mut c_void, name.as_ptr(), mode.as_ptr(), @@ -1016,18 +1259,17 @@ impl Stack { LUA_OK => Ok(()), LUA_ERRMEM => { self.pop(1); - Err(Error::OutOfMemory) + Err(Error::memory()) } LUA_ERRSYNTAX => { - let chunk = name.into_bytes().into(); let msg = self .slot(-1) .string() .unwrap_or(b"unknown error".into()) - .into(); + .to_owned(); self.pop(1); - Err(Error::Syntax { chunk, msg }) + Err(Error::syntax(msg, chunk.clone())) } _ => unreachable!(), } @@ -1040,23 +1282,18 @@ impl Stack { /// # Panics /// /// Panics if the value at index `idx` is not a function. - pub fn dump(&self, idx: c_int, mode: DumpMode) -> Result { - let func = self.slot(idx); - assert!( - func.type_of() == Type::Function, - "expected function at index {idx}: {self:?}" - ); + pub fn dump(&self, idx: impl ToSlot, mode: DumpMode) -> Result { + let idx = assert_type!(self, idx, "function", Type::Function).index(); unsafe { - let idx = func.index(); let mut s = self.guard_unchecked(); s.push("string"); - s.get(LUA_GLOBALSINDEX); + s.get(PseudoIndex::Globals); s.push("dump"); s.get(-2); // local dump = string.dump - s.push(Index(idx)); - s.push(mode.to_mode_str().to_bytes()); - s.call(2, 1)?; + s.push(idx); + s.push(mode.mode_str().to_bytes()); + s.call(2, Some(1))?; s.slot(-1).parse() // return dump(idx, mode) } } @@ -1070,10 +1307,8 @@ impl Stack { /// # Panics /// /// Panics if there are not enough values on the stack or thread status is invalid. - pub fn eval(&mut self, chunk: &Chunk, narg: c_int, nret: c_int) -> Result { - assert!(0 <= narg && (0 <= nret || nret == LUA_MULTRET)); - let base = self.size() - narg; - assert!(base >= 0, "expected {narg} values: {self:?}"); + pub fn eval(&mut self, chunk: &Chunk, narg: u32, nret: Option) -> Result { + let base = assert_slots!(self, narg); self.load(chunk)?; self.pop_insert(base + 1); self.call(narg, nret) @@ -1085,9 +1320,9 @@ impl Stack { /// There must be `narg + 1` values at the top of the stack, including the function to call at /// the index `top - narg` (i.e. the function is pushed first and then `narg` values as /// arguments). All arguments and the function are popped from the stack and then any return - /// values are pushed. If `nret` is not [`LUA_MULTRET`], then the number of return values pushed - /// will be exactly `nret`, filling with nils if necessary. Finally, the number of values pushed - /// to the stack is returned. + /// values are pushed. If `nret` is not [`None`], then the number of return values pushed will + /// be exactly `nret`, filling with nils if necessary. Finally, the number of values pushed to + /// the stack is returned. /// /// The current thread status must not be suspended or dead. /// @@ -1097,42 +1332,43 @@ impl Stack { /// /// Panics if there are not enough values on the stack, the function to call is not on the /// stack, or thread status is invalid. - pub fn call(&mut self, narg: c_int, nret: c_int) -> Result { - assert!(0 <= narg && (0 <= nret || nret == LUA_MULTRET)); - - let top = self.size(); - let need = narg + 1; // need the function on the stack - let base = top - need; - - assert!(base >= 0, "expected {need} values: {self:?}"); - assert!( - self.slot(base + 1).type_of() == Type::Function, - "expected function at index {}: {self:?}", - base + 1 - ); - assert!( - self.status() == Status::Normal, - "thread {self:p} called in wrong state" - ); + pub fn call(&mut self, narg: u32, nret: Option) -> Result { + let base = match self.status() { + Status::Normal => { + // need the function on the stack + let base = assert_slots!(self, narg + 1); + assert_type!(self, base + 1, "function", Type::Function); + base + } + status => panic!("thread {self:p} called in wrong state: {status}"), + }; // TODO: use error handler to collect backtrace - match unsafe { lua_pcall(self.as_ptr(), narg, nret, 0) } { + match unsafe { + lua_pcall( + self.as_ptr(), + narg as c_int, + nret.map_or(LUA_MULTRET, |n| { + n.try_into().expect("too many return values") + }), + 0, + ) + } { LUA_OK => Ok(self.size() - base), LUA_ERRMEM => { self.pop(1); - Err(Error::OutOfMemory) + Err(Error::memory()) } - LUA_ERRRUN | LUA_ERRERR => { + code => { let msg = self .slot(-1) .string() .unwrap_or(b"unknown error".into()) - .into(); + .to_owned(); self.pop(1); - Err(Error::Call { msg }) + Err(Error::call(ErrorKind::from_raw(code).unwrap(), msg, None)) } - _ => unreachable!(), } } @@ -1159,49 +1395,44 @@ impl Stack { /// /// Panics if there are not enough values on the stack, the function to call is not on the /// stack, or thread status is invalid. - pub async fn call_async(&mut self, mut narg: c_int, nret: c_int) -> Result { - assert!(0 <= narg && (0 <= nret || nret == LUA_MULTRET)); - - let top = self.size(); - let need = narg + 1; // need the function on the stack - let base = top - need; - - assert!(base >= 0, "expected {need} values: {self:?}"); - assert!( - self.slot(base + 1).type_of() == Type::Function, - "expected function at index {}: {self:?}", - base + 1 - ); - assert!( - self.status() == Status::Normal, - "thread {self:p} called in wrong state" - ); + pub async fn call_async(&mut self, mut narg: u32, nret: Option) -> Result { + let base = match self.status() { + Status::Normal => { + // need the function on the stack + let base = assert_slots!(self, narg + 1); + assert_type!(self, base + 1, "function", Type::Function); + base + } + status => panic!("thread {self:p} called in wrong state: {status}"), + }; loop { match self.resume(narg)? { ResumeStatus::Ok => { - if nret == LUA_MULTRET { - break Ok(self.size() - base); - } else { - self.resize(base + nret); - break Ok(nret); - } + break Ok(match nret { + Some(n) => { + self.resize(base + n); + n + } + None => self.size() - base, + }); } - ResumeStatus::Suspended => { + ResumeStatus::Suspended => unsafe { narg = 0; self.resize(1); - let ptr = self.slot(1).cdata::().cast_mut(); - if !ptr.is_null() { - // SAFETY: rust futures boxed in cdata payloads are never moved by the GC so - // we can safely make a Pin out of this pointer. see also comments in - // `luaffi::future::lua_future`. - let fut = unsafe { Pin::new_unchecked(&mut *ptr) }; - if fut.is_valid() { - fut.await; - } + // SAFETY: Rust futures boxed in cdata payloads are never moved by the GC so we + // can safely make a Pin out of this pointer. See also comments in + // `luaffi::future::lua_future`. + if let Some(fut) = self + .slot(1) + .cdata_mut::() + .map(|ptr| Pin::new_unchecked(ptr)) + && fut.is_valid() + { + fut.await; } - } + }, } } } @@ -1223,41 +1454,40 @@ impl Stack { /// /// Panics if there are not enough values on the stack, the function to call is not on the /// stack, or thread status is invalid. - pub fn resume(&mut self, narg: c_int) -> Result { - assert!(0 <= narg); - let status = self.status(); - let need = match status { - Status::Normal => narg + 1, // need the function on the stack - Status::Suspended => narg, - Status::Dead => panic!("thread {self:p} resumed in wrong state"), + pub fn resume(&mut self, narg: u32) -> Result { + match self.status() { + Status::Normal => { + // need the function on the stack + let base = assert_slots!(self, narg + 1); + assert_type!(self, base + 1, "function", Type::Function); + base + } + Status::Suspended => assert_slots!(self, narg), + status => panic!("thread {self:p} resumed in wrong state: {status}"), }; - let base = self.size() - need; - assert!(base >= 0, "expected {need} values: {self:?}"); - assert!( - status == Status::Suspended || self.slot(base + 1).type_of() == Type::Function, - "expected function at index {}: {self:?}", - base + 1 - ); - match unsafe { lua_resume(self.as_ptr(), narg) } { + match unsafe { lua_resume(self.as_ptr(), narg as c_int) } { LUA_OK => Ok(ResumeStatus::Ok), LUA_YIELD => Ok(ResumeStatus::Suspended), LUA_ERRMEM => { self.pop(1); - Err(Error::OutOfMemory) + Err(Error::memory()) } - LUA_ERRRUN | LUA_ERRERR => { + code => { let msg = self .slot(-1) .string() .unwrap_or(b"unknown error".into()) - .into(); + .to_owned(); self.pop(1); - let trace = self.backtrace(0); - Err(Error::Resume { msg, trace }) + + Err(Error::call( + ErrorKind::from_raw(code).unwrap(), + msg, + self.backtrace(0), + )) } - _ => unreachable!(), } } @@ -1269,11 +1499,16 @@ impl Stack { /// which automatically provides the backtrace with [`Error::trace`] if it is available. /// /// Equivalent to [`luaL_traceback`]. - pub fn backtrace(&self, level: c_int) -> Option { - assert!(level >= 0, "level must be nonnegative"); - self.ensure(LUA_MINSTACK); + pub fn backtrace(&self, level: u32) -> Option { unsafe { - luaL_traceback(self.as_ptr(), self.as_ptr(), ptr::null(), 0); + self.ensure(LUA_MINSTACK as u32); + + luaL_traceback( + self.as_ptr(), // thread to push the trace onto + self.as_ptr(), // thread to trace + ptr::null(), // message prefixed to trace string + level.try_into().expect("invalid trace level"), + ); // SAFETY: must clone the trace string here before popping it off the stack let trace = self @@ -1281,7 +1516,7 @@ impl Stack { .string() .map(|s| s.strip_prefix(b"stack traceback:\n").unwrap_or(s).as_bstr()) .filter(|s| !s.is_empty()) - .map(|s| s.into()); + .map(|s| s.to_owned()); lua_pop(self.as_ptr(), 1); trace @@ -1291,33 +1526,17 @@ impl Stack { impl fmt::Debug for Stack { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - struct PointerValue(&'static str, *const c_void); - impl<'s> fmt::Debug for PointerValue { + struct Slots<'s>(&'s Stack); + impl<'s> fmt::Debug for Slots<'s> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} {:p}", self.0, self.1) - } - } - - struct Values<'s>(&'s Stack); - impl<'s> fmt::Debug for Values<'s> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let mut list = f.debug_list(); - for value in self.0.iter() { - match value.type_of() { - Type::Nil => list.entry(&"nil"), - Type::Boolean => list.entry(&value.boolean()), - Type::Number => list.entry(&value.number()), - Type::String => list.entry(&value.string().unwrap()), - ty => list.entry(&PointerValue(ty.name(), value.pointer())), - }; - } - list.finish() + f.debug_list().entries(self.0.iter()).finish() } } f.debug_struct("Stack") .field("ptr", &self.0) - .field("values", &Values(self)) + .field("size", &self.size()) + .field("slots", &Slots(self)) .finish() } } @@ -1332,7 +1551,7 @@ impl fmt::Debug for Stack { pub struct StackGuard<'s> { parent: PhantomData<&'s mut Stack>, stack: Stack, - size: c_int, + base: u32, check_overpop: bool, } @@ -1341,7 +1560,7 @@ impl<'s> StackGuard<'s> { Self { parent: PhantomData, stack: unsafe { Stack::new_unchecked(stack.as_ptr()) }, // SAFETY: stack.as_ptr() is never null - size: stack.size(), + base: stack.size(), check_overpop, } } @@ -1366,13 +1585,13 @@ impl<'s> Drop for StackGuard<'s> { if cfg!(debug_assertions) && self.check_overpop { let new_size = self.stack.size(); assert!( - self.size <= new_size, - "StackGuard detected over-popping by {} values (this is UB!)", - self.size - new_size + self.base <= new_size, + "StackGuard detected over-popping by {n} values (this is UB!!)", + n = self.base - new_size, ); } - self.stack.resize(self.size); + self.stack.resize(self.base); } } @@ -1382,14 +1601,17 @@ impl<'s> Drop for StackGuard<'s> { #[derive(Debug)] pub struct StackIter<'s> { stack: &'s Stack, - idx: c_int, - top: c_int, + idx: u32, + size: u32, } impl<'s> StackIter<'s> { fn new(stack: &'s Stack) -> Self { - let top = stack.size(); - Self { stack, idx: 0, top } + Self { + stack, + idx: 0, + size: stack.size(), + } } } @@ -1397,50 +1619,181 @@ impl<'s> Iterator for StackIter<'s> { type Item = Slot<'s>; fn next(&mut self) -> Option { - (self.idx < self.top).then(|| { + (self.idx < self.size).then(|| { self.idx += 1; - unsafe { Slot::new_unchecked(self.stack, self.idx) } + unsafe { self.stack.slot_unchecked(Index::stack(self.idx)) } }) } } +/// Index of a value in a [`Stack`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Index { + Stack(NonZero), + Upvalue(NonZero), + Pseudo(PseudoIndex), +} + +impl Index { + pub fn stack(idx: u32) -> Self { + Self::Stack(NonZero::new(idx).expect("stack index cannot be zero")) + } + + pub fn upvalue(idx: u32) -> Self { + Self::Upvalue(NonZero::new(idx).expect("upvalue index cannot be zero")) + } + + pub fn registry() -> Self { + Self::Pseudo(PseudoIndex::Registry) + } + + pub fn environment() -> Self { + Self::Pseudo(PseudoIndex::Environment) + } + + pub fn globals() -> Self { + Self::Pseudo(PseudoIndex::Globals) + } + + fn _from_raw(idx: c_int) -> Self { + match idx { + 0 => panic!("index cannot be zero"), + idx if idx > 0 => Self::stack(idx as u32), + idx if idx < LUA_GLOBALSINDEX => Self::upvalue((LUA_GLOBALSINDEX - idx) as u32), + _ => match PseudoIndex::_from_raw(idx) { + Some(idx) => Self::Pseudo(idx), + None => panic!("invalid pseudo-index {idx}"), + }, + } + } + + fn into_raw(self) -> c_int { + match self { + Self::Stack(idx) => idx.get().try_into().expect("stack index overflow"), + Self::Upvalue(idx) => idx + .get() + .try_into() + .ok() + .and_then(|idx| LUA_GLOBALSINDEX.checked_sub(idx)) + .expect("upvalue index overflow"), + Self::Pseudo(idx) => idx.into_raw(), + } + } +} + +impl ops::Add for Index { + type Output = Index; + + fn add(self, rhs: u32) -> Self::Output { + match self { + Self::Stack(idx) => { + Self::stack(idx.get().checked_add(rhs).expect("stack index overflow")) + } + Self::Upvalue(idx) => { + Self::upvalue(idx.get().checked_add(rhs).expect("upvalue index overflow")) + } + Self::Pseudo(idx) => panic!("cannot add offset to {idx} pseudo-index"), + } + } +} + +impl ops::Sub for Index { + type Output = Index; + + fn sub(self, rhs: u32) -> Self::Output { + match self { + Self::Stack(idx) => { + Self::stack(idx.get().checked_sub(rhs).expect("stack index underflow")) + } + Self::Upvalue(idx) => { + Self::upvalue(idx.get().checked_sub(rhs).expect("upvalue index underflow")) + } + Self::Pseudo(idx) => panic!("cannot subtract offset from {idx} pseudo-index"), + } + } +} + +impl fmt::Display for Index { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Stack(idx) => write!(f, "stack#{idx}"), + Self::Upvalue(idx) => write!(f, "upvalue#{idx}"), + Self::Pseudo(idx) => write!(f, "{idx}"), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum PseudoIndex { + Registry, + Environment, + Globals, +} + +impl PseudoIndex { + fn _from_raw(idx: c_int) -> Option { + Some(match idx { + LUA_REGISTRYINDEX => Self::Registry, + LUA_ENVIRONINDEX => Self::Environment, + LUA_GLOBALSINDEX => Self::Globals, + _ => return None, + }) + } + + fn into_raw(self) -> c_int { + match self { + Self::Registry => LUA_REGISTRYINDEX, + Self::Environment => LUA_ENVIRONINDEX, + Self::Globals => LUA_GLOBALSINDEX, + } + } +} + +impl fmt::Display for PseudoIndex { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Registry => write!(f, "registry"), + Self::Environment => write!(f, "environment"), + Self::Globals => write!(f, "globals"), + } + } +} + /// Lua value handle into the stack. +#[derive(Clone, Copy)] pub struct Slot<'s> { stack: &'s Stack, - idx: c_int, + idx: Index, } impl<'s> Slot<'s> { - /// Creates a new [`Slot`] for given index into the stack. - /// - /// # Safety - /// - /// Index `idx` must be a valid absolute index. - pub unsafe fn new_unchecked(stack: &'s Stack, idx: c_int) -> Self { + unsafe fn new_unchecked(stack: &'s Stack, idx: Index) -> Self { + debug_assert!( + unsafe { lua_type(stack.as_ptr(), idx.into_raw()) != LUA_TNONE }, + "invalid stack index {idx}: {stack:?}" + ); + Self { stack, idx } } - /// Index of this slot within the stack. - /// - /// This value is always a valid absolute positive index or a negative pseudo-index. It is never - /// a negative relative index. - pub fn index(&self) -> c_int { + pub fn index(&self) -> Index { self.idx } /// Type of the value in this slot. - pub fn type_of(&self) -> Type { - Type::from_code(unsafe { lua_type(self.stack.as_ptr(), self.idx) }).unwrap_or(Type::Nil) + pub fn ty(&self) -> Type { + Type::from_raw(unsafe { lua_type(self.stack.as_ptr(), self.idx.into_raw()) }) + .unwrap_or(Type::Nil) } - /// Parses the value in this slot as a `T`. + /// Parses the value in this slot as the type `T`. pub fn parse>(&self) -> Result { T::parse(self) } /// Parses the value in this slot as a [`bool`]. /// - /// If the value is not a `boolean`, then this returns false. + /// If the value is not a `boolean`, then this returns false. `nil` is always considered false. /// /// Equivalent to [`lua_toboolean`]. pub fn boolean(&self) -> bool { @@ -1453,7 +1806,7 @@ impl<'s> Slot<'s> { /// /// Equivalent to [`lua_touserdata`]. pub fn lightuserdata(&self) -> *mut T { - (self.type_of() == Type::Lightuserdata) + (self.ty() == Type::Lightuserdata) .then(|| self.parse().ok()) .flatten() .unwrap_or(ptr::null_mut()) @@ -1503,47 +1856,195 @@ impl<'s> Slot<'s> { /// Parses the value in this slot as a [`lua_CFunction`]. /// - /// If the value is not a C function, then this returns [`None`]. + /// If the value is not a raw C function, then this returns [`None`]. /// /// Equivalent to [`lua_tocfunction`]. pub fn function_raw(&self) -> lua_CFunction { - unsafe { lua_tocfunction(self.stack.as_ptr(), self.idx) } + unsafe { lua_tocfunction(self.stack.as_ptr(), self.idx.into_raw()) } } - /// Parses the value in this slot as a `cdata` pointer. + /// Parses the value in this slot as a `cdata` and returns an immutable reference to its + /// payload. /// - /// If the value is a `cdata`, then the returned pointer is the address of the base of the cdata - /// payload. Otherwise this returns a null pointer. + /// See [`cdata_ptr`](Self::cdata_ptr) regarding safety. + pub unsafe fn cdata(&self) -> Option<&T> { + let ptr = self.cdata_ptr::(); + (!ptr.is_null()).then(|| unsafe { &*ptr }) + } + + /// Parses the value in this slot as a `cdata` and returns a mutable reference to its payload. + /// + /// See [`cdata_ptr_mut`](Self::cdata_ptr_mut) regarding safety. + pub unsafe fn cdata_mut(&self) -> Option<&mut T> { + let ptr = self.cdata_ptr_mut::(); + (!ptr.is_null()).then(|| unsafe { &mut *ptr }) + } + + /// Parses the value in this slot as a `cdata` and returns a mutable pointer to its payload. + /// + /// If the value is a `cdata`, then the returned pointer is the address of the base of the its + /// payload. Otherwise, this returns a null pointer. + /// + /// Nothing is done to ensure that the payload is of the type `T`. If only the pointer is needed + /// and not its payload, then it is recommended for `T` to be [`c_void`]. + /// + /// Refer to LuaJIT's [FFI semantics](https://luajit.org/ext_ffi_semantics.html) more + /// documentation regarding cdata objects. /// /// Equivalent to [`lua_topointer`]. - pub fn cdata(&self) -> *const T { - (self.type_of() == Type::Cdata) + pub fn cdata_ptr(&self) -> *const T { + (self.ty() == Type::Cdata) .then(|| self.pointer().cast()) .unwrap_or(ptr::null_mut()) } + /// Parses the value in this slot as a `cdata` and returns a mutable pointer to its payload. + /// + /// If the value is a `cdata`, then the returned pointer is the address of the base of the its + /// payload. Otherwise, this returns a null pointer. + /// + /// Nothing is done to ensure that the payload is of the type `T`. If only the pointer is needed + /// and not its payload, then it is recommended for `T` to be [`c_void`]. + /// + /// Refer to LuaJIT's [FFI semantics](https://luajit.org/ext_ffi_semantics.html) more + /// documentation regarding cdata objects. In particular, reference cdata objects are immutable + /// after initialisation and must not be modified ("no re-seating of references"). + /// + /// Equivalent to [`lua_topointer`]. + pub fn cdata_ptr_mut(&self) -> *mut T { + self.cdata_ptr::().cast_mut() + } + /// Parses the value in this slot as a generic pointer. /// - /// If the value is not a GC-managed object that can be represented by a pointer, then this - /// returns a null pointer. + /// If the value is not a garbage-collected object that can be represented by a pointer (i.e. a + /// non-`nil` primitive value), then this returns a null pointer. /// /// Equivalent to [`lua_topointer`]. pub fn pointer(&self) -> *const c_void { - unsafe { lua_topointer(self.stack.as_ptr(), self.idx).cast() } + unsafe { lua_topointer(self.stack.as_ptr(), self.idx.into_raw()).cast() } } /// Returns the length of the value in this slot. /// - /// For strings, this is the byte-length of the contents of the string. For tables, this is the - /// length of the table defined by the Lua `#` operator. For userdata, this is the size of its - /// payload in bytes. For numbers, this is equivalent to converting the value in-place into a - /// `string` representation before calculating its length. Otherwise, this returns 0. + /// For strings, this is the byte-length of the string. For tables, this is the length of the + /// table defined by the Lua `#` operator. For userdata, this is the size of its payload in + /// bytes. Otherwise, this returns 0. /// /// This function does not invoke the `__len` metamethod. /// /// Equivalent to [`lua_objlen`]. pub fn length(&self) -> usize { - unsafe { lua_objlen(self.stack.as_ptr(), self.idx) } + matches!(self.ty(), Type::String | Type::Table | Type::Userdata) + .then(|| unsafe { lua_objlen(self.stack.as_ptr(), self.idx.into_raw()) }) + .unwrap_or(0) + } +} + +impl<'s> fmt::Debug for Slot<'s> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.ty() { + Type::Nil => write!(f, "nil"), + Type::Boolean => write!(f, "{}", self.boolean()), + Type::Number => write!(f, "{}", self.number().unwrap()), + Type::String => fmt::Debug::fmt(self.string().unwrap(), f), + ty => write!(f, "{ty} {:p}", self.pointer()), + } + } +} + +pub trait ToSlot { + fn to_slot<'s>(&self, stack: &'s Stack) -> Slot<'s>; +} + +impl ToSlot for &T +where + T: ToSlot, +{ + fn to_slot<'s>(&self, stack: &'s Stack) -> Slot<'s> { + (*self).to_slot(stack) + } +} + +impl ToSlot for Slot<'_> { + fn to_slot<'s>(&self, stack: &'s Stack) -> Slot<'s> { + assert!(self.stack == stack); + unsafe { stack.slot_unchecked(self.idx) } + } +} + +impl ToSlot for Index { + fn to_slot<'s>(&self, stack: &'s Stack) -> Slot<'s> { + match self { + Self::Pseudo(_) => {} + Self::Stack(idx) => { + assert!( + idx.get() <= stack.size(), + "stack underflow: expected at least {idx} values: {stack:?}" + ); + } + Self::Upvalue(idx) => { + assert!( + unsafe { lua_type(stack.as_ptr(), self.into_raw()) != LUA_TNONE }, + "stack underflow: expected at least {idx} upvalues" + ); + } + } + + unsafe { stack.slot_unchecked(*self) } + } +} + +impl ToSlot for PseudoIndex { + fn to_slot<'s>(&self, stack: &'s Stack) -> Slot<'s> { + unsafe { stack.slot_unchecked(Index::Pseudo(*self)) } + } +} + +impl ToSlot for u32 { + fn to_slot<'s>(&self, stack: &'s Stack) -> Slot<'s> { + stack.slot(Index::stack(*self)) + } +} + +impl ToSlot for i32 { + fn to_slot<'s>(&self, stack: &'s Stack) -> Slot<'s> { + let idx = *self; + if idx >= 0 { + return stack.slot(idx as u32); + } + let size = stack.size(); + let offset = (-idx) as u32; + assert!( + offset <= size, + "stack underflow: expected at least {offset} values: {stack:?}" + ); + unsafe { stack.slot_unchecked(Index::stack(size - offset + 1)) } + } +} + +impl<'s> ops::Add for Slot<'s> { + type Output = Slot<'s>; + + fn add(self, n: u32) -> Self::Output { + if n == 0 { + self + } else { + self.stack.slot(self.idx + n) + } + } +} + +impl<'s> ops::Sub for Slot<'s> { + type Output = Slot<'s>; + + fn sub(self, n: u32) -> Self::Output { + if n == 0 { + self + } else { + // SAFETY: subtracted index is guaranteed to be valid + unsafe { self.stack.slot_unchecked(self.idx - n) } + } } } @@ -1666,19 +2167,40 @@ impl Push for Ref { } } -/// [`Push`]es a copy of the value at an index onto a [`Stack`]. -/// -/// Equivalent to [`lua_pushvalue`]. -#[derive(Debug, Default, Clone, Copy, Hash)] -pub struct Index( - /// Index of the value to copy. - pub c_int, -); - impl Push for Index { fn push(&self, stack: &mut Stack) { + match self { + Self::Pseudo(_) => {} + Self::Stack(idx) => { + assert!( + idx.get() <= stack.size(), + "stack underflow: expected at least {idx} values: {stack:?}" + ); + } + Self::Upvalue(idx) => { + assert!( + unsafe { lua_type(stack.as_ptr(), self.into_raw()) != LUA_TNONE }, + "stack underflow: expected at least {idx} upvalues" + ) + } + } + stack.ensure(1); - unsafe { lua_pushvalue(stack.as_ptr(), stack.slot(self.0).index()) } + unsafe { lua_pushvalue(stack.as_ptr(), self.into_raw()) }; + } +} + +impl Push for PseudoIndex { + fn push(&self, stack: &mut Stack) { + Index::Pseudo(*self).push(stack); + } +} + +impl<'s> Push for Slot<'s> { + fn push(&self, stack: &mut Stack) { + assert!(self.stack == stack); // TODO: check global_State are equal and xmove instead + stack.ensure(1); + unsafe { lua_pushvalue(stack.as_ptr(), self.idx.into_raw()) }; } } @@ -1691,30 +2213,30 @@ impl Push for Index { #[derive(Debug, Default, Clone, Copy, Hash)] pub struct NewTable { /// Size of the preallocated array part. - pub narr: c_int, + pub narr: u32, /// Size of the preallocated hash part. - pub nrec: c_int, + pub nrec: u32, } impl NewTable { /// Creates a new [`NewTable`] with no preallocations defined. pub fn new() -> Self { - Self::new_sized(0, 0) + Self::sized(0, 0) } /// Creates a new [`NewTable`] with the array part set to preallocate `size`. - pub fn new_array(size: c_int) -> Self { - Self::new_sized(size, 0) + pub fn array(size: u32) -> Self { + Self::sized(size, 0) } /// Creates a new [`NewTable`] with the hash part set to preallocate `size`. - pub fn new_record(size: c_int) -> Self { - Self::new_sized(0, size) + pub fn record(size: u32) -> Self { + Self::sized(0, size) } /// Creates a new [`NewTable`] with the array and hash parts set to preallocate `narr` and /// `rec`. - pub fn new_sized(narr: c_int, nrec: c_int) -> Self { + pub fn sized(narr: u32, nrec: u32) -> Self { Self { narr, nrec } } } @@ -1722,10 +2244,55 @@ impl NewTable { impl Push for NewTable { fn push(&self, stack: &mut Stack) { let Self { narr, nrec } = *self; - assert!(0 <= narr, "narr must be nonnegative"); - assert!(0 <= nrec, "nrec must be nonnegative"); stack.ensure(1); - unsafe { lua_createtable(stack.as_ptr(), narr, nrec) } + unsafe { + lua_createtable( + stack.as_ptr(), + narr.try_into().expect("table narr too big"), + nrec.try_into().expect("table nrec too big"), + ) + } + } +} + +pub type BareFunction = fn(&mut Stack) -> c_int; + +#[derive(Debug, Clone, Copy, Hash)] +pub enum Function { + Bare(BareFunction), + Raw(lua_CFunction), +} + +impl Push for Function { + fn push(&self, stack: &mut Stack) { + match *self { + Function::Bare(f) => unsafe { + unsafe extern "C" fn cb(L: *mut lua_State) -> c_int { + unsafe { + let mut stack = Stack::new_unchecked(L); + let f = mem::transmute::<*mut c_void, BareFunction>( + stack.slot(Index::upvalue(1)).lightuserdata(), + ); + f(&mut stack) + } + } + + stack.ensure(2); + lua_pushlightuserdata(stack.as_ptr(), f as *mut c_void); + lua_pushcclosure(stack.as_ptr(), Some(cb), 1); + }, + Function::Raw(f) => unsafe { + assert!(f.is_some(), "raw function cannot be null"); + stack.ensure(1); + lua_pushcfunction(stack.as_ptr(), f); + }, + } + } +} + +impl Push for BareFunction { + fn push(&self, stack: &mut Stack) { + Function::Bare(*self).push(stack); } } @@ -1757,16 +2324,16 @@ pub trait Parse<'s>: Sized { impl Parse<'_> for () { fn parse(slot: &Slot) -> Result { - match slot.type_of() { + match slot.ty() { Type::Nil => Ok(()), - ty => Err(Error::InvalidType("nil", ty.name())), + ty => Err(Error::invalid_type("nil", ty.name())), } } } impl Parse<'_> for bool { fn parse(slot: &Slot) -> Result { - Ok(unsafe { lua_toboolean(slot.stack.as_ptr(), slot.index()) != 0 }) + Ok(unsafe { lua_toboolean(slot.stack.as_ptr(), slot.index().into_raw()) != 0 }) } } @@ -1774,11 +2341,11 @@ macro_rules! impl_parse_ptr { ($type:ty) => { impl Parse<'_> for $type { fn parse(slot: &Slot) -> Result { - let ptr = unsafe { lua_touserdata(slot.stack.as_ptr(), slot.idx) }; + let ptr = unsafe { lua_touserdata(slot.stack.as_ptr(), slot.index().into_raw()) }; if !ptr.is_null() { Ok(ptr as $type) } else { - Err(Error::InvalidType("userdata", slot.type_of().name())) + Err(Error::invalid_type("userdata", slot.ty().name())) } } } @@ -1793,11 +2360,13 @@ macro_rules! impl_parse_num { impl Parse<'_> for $type { fn parse(slot: &Slot) -> Result { let mut isnum = 0; - let n = unsafe { lua_tonumberx(slot.stack.as_ptr(), slot.idx, &raw mut isnum) }; + let n = unsafe { + lua_tonumberx(slot.stack.as_ptr(), slot.index().into_raw(), &raw mut isnum) + }; if isnum != 0 { Ok(n as $type) } else { - Err(Error::InvalidType("number", slot.type_of().name())) + Err(Error::invalid_type("number", slot.ty().name())) } } } @@ -1812,11 +2381,13 @@ macro_rules! impl_parse_int { impl Parse<'_> for $type { fn parse(slot: &Slot) -> Result { let mut isnum = 0; - let n = unsafe { lua_tointegerx(slot.stack.as_ptr(), slot.idx, &raw mut isnum) }; + let n = unsafe { + lua_tointegerx(slot.stack.as_ptr(), slot.index().into_raw(), &raw mut isnum) + }; if isnum != 0 { Ok(n as $type) } else { - Err(Error::InvalidType("number", slot.type_of().name())) + Err(Error::invalid_type("number", slot.ty().name())) } } } @@ -1839,11 +2410,13 @@ macro_rules! impl_parse_str { impl<'s> Parse<'s> for $type { fn parse(slot: &Slot<'s>) -> Result { let mut len = 0; - let ptr = unsafe { lua_tolstring(slot.stack.as_ptr(), slot.idx, &mut len) }; + let ptr = unsafe { + lua_tolstring(slot.stack.as_ptr(), slot.index().into_raw(), &mut len) + }; if !ptr.is_null() { Ok(unsafe { slice::from_raw_parts(ptr.cast(), len).into() }) } else { - Err(Error::InvalidType("string", slot.type_of().name())) + Err(Error::invalid_type("string", slot.ty().name())) } } } @@ -1854,7 +2427,9 @@ macro_rules! impl_parse_str_utf8 { ($type:ty) => { impl<'s> Parse<'s> for $type { fn parse(slot: &Slot<'s>) -> Result { - Ok(std::str::from_utf8(Parse::parse(slot)?)?.into()) + Ok(std::str::from_utf8(Parse::parse(slot)?) + .map_err(Error::new)? + .into()) } } }; diff --git a/src/main.rs b/src/main.rs index 32a49cd..85d2ad1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,5 @@ use clap::Parser; +use luajit::Chunk; use mimalloc::MiMalloc; use owo_colors::OwoColorize; use std::{backtrace::Backtrace, fmt::Display, num::NonZero, panic, process, thread}; @@ -78,12 +79,21 @@ struct Args { #[clap(long, short = 'j', help_heading = "Runtime", value_name = "CMD=FLAGS")] jit: Vec, + /// Allow global variables. + #[clap( + long, + help_heading = "Runtime", + value_name = "ENABLED", + default_value_t = true + )] + allow_globals: bool, + /// Number of worker threads. #[clap( long, short = 'T', help_heading = "Runtime", - value_name = "THREADS", + value_name = "COUNT", default_value_t = Self::threads() )] threads: NonZero, @@ -92,14 +102,14 @@ struct Args { #[clap( long, help_heading = "Runtime", - value_name = "THREADS", + value_name = "COUNT", default_value_t = Self::blocking_threads() )] blocking_threads: NonZero, /// Enable tokio-console integration. #[cfg(feature = "tokio-console")] - #[clap(long, help_heading = "Debugging")] + #[clap(long, help_heading = "Debugging", value_name = "ENABLED")] enable_console: bool, /// tokio-console publish address. @@ -229,27 +239,32 @@ fn init_lua(args: &Args) -> lb::runtime::Runtime { print!("{}", rt.registry()); // for cdef debugging } - rt.unhandled_error(error_cb).build().unwrap() + rt.unhandled_error(error_cb) + .prohibit_globals(!args.allow_globals) + .build() + .unwrap() }; for arg in args.jit.iter() { let mut s = rt.guard(); let res = if let Some((cmd, flags)) = parse_jitlib_cmd(arg) - && let Ok(_) = s.require(format!("jit.{cmd}"), 1) + && let Ok(_) = s.require(format!("jit.{cmd}"), Some(1)) { - (s.push("start"), s.get(-2), s.push(flags)); - s.call(1, 0) // require("jit.{cmd}").start(flags) + (s.push("start"), s.get(-2)); + s.push(flags); + s.call(1, Some(0)) // require("jit.{cmd}").start(flags) } else { - s.require("jit", 1).unwrap(); + s.require("jit", Some(1)).unwrap(); match arg.as_str() { cmd @ ("on" | "off" | "flush") => { (s.push(cmd), s.get(-2)); - s.call(0, 0) // require("jit").[on/off/flush]() + s.call(0, Some(0)) // require("jit").[on/off/flush]() } flags => { (s.push("opt"), s.get(-2)); - (s.push("start"), s.get(-2), s.push(flags)); - s.call(1, 0) // require("jit").opt.start(flags) + (s.push("start"), s.get(-2)); + s.push(flags); + s.call(1, Some(0)) // require("jit").opt.start(flags) } } }; @@ -282,9 +297,9 @@ async fn main_async(args: Args, cx: &mut lb::runtime::Context) -> ExitCode { } }; - if let Err(ref err) = cx.load(&luajit::Chunk::new(chunk).path(path)) { + if let Err(ref err) = cx.load(&Chunk::new(chunk).with_path(path)) { cx.report_error(err); - } else if let Err(ref err) = cx.call_async(0, 0).await { + } else if let Err(ref err) = cx.call_async(0, Some(0)).await { cx.report_error(err); } } diff --git a/tests/main.lua b/tests/main.lua index 8fc1db9..2f3e3ed 100644 --- a/tests/main.lua +++ b/tests/main.lua @@ -26,7 +26,7 @@ end local function create_test(name, f, group) local test = { type = "test", name = name or "", group = group, state = "pending", f = f } - local fenv = setmetatable({}, { __index = global }) + local fenv = setmetatable({}, { __index = global, __newindex = global }) setfenv(f, fenv) return test end @@ -45,7 +45,7 @@ local function create_group(name, f, parent) table.insert(group.items, item) return item end, - }, { __index = global }) + }, { __index = global, __newindex = global }) setfenv(f, fenv) f(group) diff --git a/tests/main.rs b/tests/main.rs index 8ae2a2e..4c6e2b3 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -10,14 +10,17 @@ fn main() -> ExitCode { let lua = { let mut rt = lb::runtime::Builder::new(); luby::open(&mut rt); - rt.unhandled_error(error_cb).build().unwrap() + rt.unhandled_error(error_cb) + .prohibit_globals(true) + .build() + .unwrap() }; let path = "tests/main.lua"; let main = lua.spawn(async move |s| { - if let Err(ref err) = s.load(Chunk::new(fs::read(path).unwrap()).path(path)) { + if let Err(ref err) = s.load(&Chunk::new(fs::read(path).unwrap()).with_path(path)) { s.report_error(err); - } else if let Err(ref err) = s.call_async(0, 1).await { + } else if let Err(ref err) = s.call_async(0, Some(1)).await { s.report_error(err); }