diff --git a/crates/lb/src/runtime.rs b/crates/lb/src/runtime.rs index 2b7f4ac..0d1c857 100644 --- a/crates/lb/src/runtime.rs +++ b/crates/lb/src/runtime.rs @@ -1,20 +1,27 @@ use derive_more::{Deref, DerefMut}; use luaffi::{Module, Registry}; use luajit::{Chunk, State}; +use std::rc::Rc; use tokio::{ task::{JoinHandle, LocalSet, futures::TaskLocalFuture, spawn_local}, task_local, }; -#[derive(Debug, Default)] +pub type ErrorFn = dyn Fn(&luajit::Error); + pub struct Builder { registry: Registry, + report_err: Rc, } impl Builder { pub fn new() -> Self { Self { registry: Registry::new(), + report_err: Rc::new(|err| match err.trace() { + Some(trace) => eprintln!("unhandled lua error: {err}\n{trace}"), + None => eprintln!("unhandled lua error: {err}"), + }), } } @@ -22,6 +29,11 @@ impl Builder { &self.registry } + pub fn unhandled_error(&mut self, handler: impl Fn(&luajit::Error) + 'static) -> &mut Self { + self.report_err = Rc::new(handler); + self + } + pub fn module(&mut self) -> &mut Self { self.registry.preload::(); self @@ -29,50 +41,74 @@ impl Builder { pub fn build(&self) -> luajit::Result { Ok(Runtime { - state: { - let mut s = State::new()?; - s.eval(Chunk::new(self.registry.build()).path("[luby]"), 0, 0)?; - s + cx: Context { + state: { + let mut s = State::new()?; + s.eval(Chunk::new(self.registry.build()).path("[luby]"), 0, 0)?; + s + }, + report_err: self.report_err.clone(), }, tasks: LocalSet::new(), }) } } -#[derive(Debug, Deref, DerefMut)] +#[derive(Deref, DerefMut)] pub struct Runtime { #[deref] #[deref_mut] - state: State, + cx: Context, tasks: LocalSet, } -task_local! { - static STATE: State; -} - impl Runtime { pub fn spawn( &self, - f: impl AsyncFnOnce(&mut State) -> T + 'static, + f: impl AsyncFnOnce(&mut Context) -> T + 'static, ) -> JoinHandle { self.tasks - .spawn_local(async move { f(&mut STATE.with(|s| s.new_thread())).await }) + .spawn_local(async move { f(&mut CURRENT.with(|s| s.new_thread())).await }) } } -pub fn spawn(f: impl AsyncFnOnce(&mut State) -> T + 'static) -> JoinHandle { - // SAFETY: `new_thread` must be called inside `spawn_local` because this free-standing spawn - // function may be called via ffi from lua, and it is not safe to access the lua state within - // ffi calls. - spawn_local(async move { f(&mut STATE.with(|s| s.new_thread())).await }) -} - impl IntoFuture for Runtime { type Output = (); - type IntoFuture = TaskLocalFuture; + type IntoFuture = TaskLocalFuture; fn into_future(self) -> Self::IntoFuture { - STATE.scope(self.state, self.tasks) + CURRENT.scope(self.cx, self.tasks) } } + +task_local! { + static CURRENT: Context; +} + +#[derive(Deref, DerefMut)] +pub struct Context { + #[deref] + #[deref_mut] + state: State, + report_err: Rc, +} + +impl Context { + pub fn new_thread(&self) -> Self { + Self { + state: self.state.new_thread(), + report_err: self.report_err.clone(), + } + } + + pub fn report_error(&self, err: &luajit::Error) { + (self.report_err)(&err); + } +} + +pub fn spawn(f: impl AsyncFnOnce(&mut Context) -> T + 'static) -> JoinHandle { + // SAFETY: `new_thread` must be called inside `spawn_local` because this free-standing spawn + // function may be called via ffi from lua, and it is not safe to access the lua state within + // ffi calls. + spawn_local(async move { f(&mut CURRENT.with(|s| s.new_thread())).await }) +} diff --git a/crates/lb/src/task.rs b/crates/lb/src/task.rs index c7f1134..ea7919b 100644 --- a/crates/lb/src/task.rs +++ b/crates/lb/src/task.rs @@ -1,7 +1,7 @@ use crate::runtime::spawn; use luaffi::{cdef, metatype}; -use std::{ffi::c_int, process}; -use tokio::task::JoinHandle; +use std::{ffi::c_int, time::Duration}; +use tokio::{task::JoinHandle, time::sleep}; #[cdef(module = "lb:task")] pub struct lb_tasklib; @@ -20,17 +20,17 @@ impl lb_tasklib { } extern "Lua-C" fn __spawn(&self, key: c_int) -> lb_task { - let handle = spawn(async move |s| { + let handle = spawn(async move |cx| { // SAFETY: key is always unique, created by __ref above - let arg = unsafe { s.new_ref_unchecked(key) }; + let arg = unsafe { cx.new_ref_unchecked(key) }; + let mut s = cx.guard(); s.resize(0); s.push(arg); - let narg = s.unpack(1, 1, None) - 1; - println!("{s:?}"); - if let Err(_err) = s.call_async(narg, 0).await { - process::exit(1) + let narg = s.unpack(1, 1, None) - 1; // unpack the table containing the function to call and its args + if let Err(err) = s.call_async(narg, 0).await { + drop(s); + cx.report_error(&err); } - println!("{s:?}"); }); lb_task { handle } diff --git a/src/main.rs b/src/main.rs index eda23b1..2f17ad4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -40,6 +40,22 @@ fn panic_cb(panic: &panic::PanicHookInfo) { ); } +fn error_cb(err: &luajit::Error) { + match err.trace() { + Some(trace) => eprintln!("{}\n{trace}", err.red().bold()), + None => eprintln!("{}", err.red().bold()), + } + + process::exit(1); +} + +fn unwrap_exit(code: ExitCode) -> impl FnOnce(E) -> T { + move |err| { + eprintln!("{}", err.red().bold()); + code.exit() + } +} + #[derive(Debug, Parser)] struct Args { /// Paths to scripts to execute. @@ -121,12 +137,12 @@ impl Args { } } -fn main() { +fn main() -> Result<(), ExitCode> { panic::set_hook(Box::new(panic_cb)); let args = Args::parse(); if args.version { - return print_version(); + return Ok(print_version()); } init_logger(&args); @@ -153,13 +169,6 @@ fn print_version() { ); } -fn unwrap_exit(code: ExitCode) -> impl FnOnce(E) -> T { - move |err| { - eprintln!("{}", err.red().bold()); - code.exit() - } -} - fn init_logger(args: &Args) { use tracing::level_filters::LevelFilter; use tracing_subscriber::util::*; @@ -192,54 +201,62 @@ fn init_logger(args: &Args) { } fn init_tokio(args: &Args) -> tokio::runtime::Runtime { - let mut rt = match args.threads.get() { + match args.threads.get() { 1 => tokio::runtime::Builder::new_current_thread(), n => { let mut rt = tokio::runtime::Builder::new_multi_thread(); rt.worker_threads(n - 1); rt } - }; - - rt.enable_all() - .thread_name("luby") - .max_blocking_threads(args.blocking_threads.get()) - .build() - .unwrap_or_else(unwrap_exit(ExitCode::OsErr)) + } + .enable_all() + .thread_name("luby") + .max_blocking_threads(args.blocking_threads.get()) + .build() + .unwrap_or_else(unwrap_exit(ExitCode::OsErr)) } fn init_lua(args: &Args) -> lb::runtime::Runtime { - let mut rt = lb::runtime::Builder::new(); - luby::open(&mut rt); + let mut rt = { + let mut rt = lb::runtime::Builder::new(); + rt.unhandled_error(error_cb); + luby::open(&mut rt); - if args.dump.iter().find(|s| *s == "cdef").is_some() { - print!("{}", rt.registry()); + if args.dump.iter().find(|s| *s == "cdef").is_some() { + print!("{}", rt.registry()); // for debugging + } + + rt } - - let mut rt = rt.build().unwrap(); + .build() + .unwrap(); for arg in args.jit.iter() { let mut s = rt.guard(); - if let Some((cmd, flags)) = parse_jitlib_cmd(arg) + let res = if let Some((cmd, flags)) = parse_jitlib_cmd(arg) && let Ok(_) = s.require(format!("jit.{cmd}"), 1) { (s.push("start"), s.get(-2), s.push(flags)); - s.call(1, 0) + s.call(1, 0) // require("jit.{cmd}").start(flags) } else { s.require("jit", 1).unwrap(); match arg.as_str() { cmd @ ("on" | "off" | "flush") => { (s.push(cmd), s.get(-2)); - s.call(0, 0) + s.call(0, 0) // require("jit").[on/off/flush]() } - arg => { + flags => { (s.push("opt"), s.get(-2)); - (s.push("start"), s.get(-2), s.push(arg)); - s.call(1, 0) + (s.push("start"), s.get(-2), s.push(flags)); + s.call(1, 0) // require("jit").opt.start(flags) } } + }; + + if let Err(err) = res { + drop(s); + rt.report_error(&err); } - .unwrap_or_else(unwrap_exit(ExitCode::Usage)); } rt @@ -254,27 +271,22 @@ fn parse_jitlib_cmd(s: &str) -> Option<(&str, &str)> { } } -async fn main_async(args: Args, state: &mut luajit::State) { +async fn main_async(args: Args, cx: &mut lb::runtime::Context) -> Result<(), ExitCode> { for ref path in args.path { - let mut s = state.guard(); let chunk = match std::fs::read(path) { Ok(chunk) => chunk, Err(err) => { eprintln!("{}", format_args!("{path}: {err}").red().bold()); - ExitCode::NoInput.exit(); + return Err(ExitCode::NoInput); } }; - s.load(&luajit::Chunk::new(chunk).path(path)) - .unwrap_or_else(unwrap_exit(ExitCode::NoInput)); - - if let Err(err) = s.call_async(0, 0).await { - match err.trace() { - Some(trace) => eprintln!("{}\n{trace}", err.red().bold()), - None => eprintln!("{}", err.red().bold()), - } - - process::exit(1); + if let Err(err) = cx.load(&luajit::Chunk::new(chunk).path(path)) { + cx.report_error(&err); + } else if let Err(err) = cx.call_async(0, 0).await { + cx.report_error(&err); } } + + Ok(()) }