Implement proper error handlign in spawned tasks

This commit is contained in:
lumi 2025-06-26 18:47:31 +10:00
parent 24c5e9edc2
commit 9b7dbcc141
Signed by: luaneko
GPG Key ID: 406809B8763FF07A
3 changed files with 122 additions and 74 deletions

View File

@ -1,20 +1,27 @@
use derive_more::{Deref, DerefMut}; use derive_more::{Deref, DerefMut};
use luaffi::{Module, Registry}; use luaffi::{Module, Registry};
use luajit::{Chunk, State}; use luajit::{Chunk, State};
use std::rc::Rc;
use tokio::{ use tokio::{
task::{JoinHandle, LocalSet, futures::TaskLocalFuture, spawn_local}, task::{JoinHandle, LocalSet, futures::TaskLocalFuture, spawn_local},
task_local, task_local,
}; };
#[derive(Debug, Default)] pub type ErrorFn = dyn Fn(&luajit::Error);
pub struct Builder { pub struct Builder {
registry: Registry, registry: Registry,
report_err: Rc<ErrorFn>,
} }
impl Builder { impl Builder {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
registry: Registry::new(), registry: Registry::new(),
report_err: Rc::new(|err| match err.trace() {
Some(trace) => eprintln!("unhandled lua error: {err}\n{trace}"),
None => eprintln!("unhandled lua error: {err}"),
}),
} }
} }
@ -22,6 +29,11 @@ impl Builder {
&self.registry &self.registry
} }
pub fn unhandled_error(&mut self, handler: impl Fn(&luajit::Error) + 'static) -> &mut Self {
self.report_err = Rc::new(handler);
self
}
pub fn module<T: Module>(&mut self) -> &mut Self { pub fn module<T: Module>(&mut self) -> &mut Self {
self.registry.preload::<T>(); self.registry.preload::<T>();
self self
@ -29,50 +41,74 @@ impl Builder {
pub fn build(&self) -> luajit::Result<Runtime> { pub fn build(&self) -> luajit::Result<Runtime> {
Ok(Runtime { Ok(Runtime {
state: { cx: Context {
let mut s = State::new()?; state: {
s.eval(Chunk::new(self.registry.build()).path("[luby]"), 0, 0)?; let mut s = State::new()?;
s s.eval(Chunk::new(self.registry.build()).path("[luby]"), 0, 0)?;
s
},
report_err: self.report_err.clone(),
}, },
tasks: LocalSet::new(), tasks: LocalSet::new(),
}) })
} }
} }
#[derive(Debug, Deref, DerefMut)] #[derive(Deref, DerefMut)]
pub struct Runtime { pub struct Runtime {
#[deref] #[deref]
#[deref_mut] #[deref_mut]
state: State, cx: Context,
tasks: LocalSet, tasks: LocalSet,
} }
task_local! {
static STATE: State;
}
impl Runtime { impl Runtime {
pub fn spawn<T: 'static>( pub fn spawn<T: 'static>(
&self, &self,
f: impl AsyncFnOnce(&mut State) -> T + 'static, f: impl AsyncFnOnce(&mut Context) -> T + 'static,
) -> JoinHandle<T> { ) -> JoinHandle<T> {
self.tasks self.tasks
.spawn_local(async move { f(&mut STATE.with(|s| s.new_thread())).await }) .spawn_local(async move { f(&mut CURRENT.with(|s| s.new_thread())).await })
} }
} }
pub fn spawn<T: 'static>(f: impl AsyncFnOnce(&mut State) -> T + 'static) -> JoinHandle<T> {
// SAFETY: `new_thread` must be called inside `spawn_local` because this free-standing spawn
// function may be called via ffi from lua, and it is not safe to access the lua state within
// ffi calls.
spawn_local(async move { f(&mut STATE.with(|s| s.new_thread())).await })
}
impl IntoFuture for Runtime { impl IntoFuture for Runtime {
type Output = (); type Output = ();
type IntoFuture = TaskLocalFuture<State, LocalSet>; type IntoFuture = TaskLocalFuture<Context, LocalSet>;
fn into_future(self) -> Self::IntoFuture { fn into_future(self) -> Self::IntoFuture {
STATE.scope(self.state, self.tasks) CURRENT.scope(self.cx, self.tasks)
} }
} }
task_local! {
static CURRENT: Context;
}
#[derive(Deref, DerefMut)]
pub struct Context {
#[deref]
#[deref_mut]
state: State,
report_err: Rc<ErrorFn>,
}
impl Context {
pub fn new_thread(&self) -> Self {
Self {
state: self.state.new_thread(),
report_err: self.report_err.clone(),
}
}
pub fn report_error(&self, err: &luajit::Error) {
(self.report_err)(&err);
}
}
pub fn spawn<T: 'static>(f: impl AsyncFnOnce(&mut Context) -> T + 'static) -> JoinHandle<T> {
// SAFETY: `new_thread` must be called inside `spawn_local` because this free-standing spawn
// function may be called via ffi from lua, and it is not safe to access the lua state within
// ffi calls.
spawn_local(async move { f(&mut CURRENT.with(|s| s.new_thread())).await })
}

View File

@ -1,7 +1,7 @@
use crate::runtime::spawn; use crate::runtime::spawn;
use luaffi::{cdef, metatype}; use luaffi::{cdef, metatype};
use std::{ffi::c_int, process}; use std::{ffi::c_int, time::Duration};
use tokio::task::JoinHandle; use tokio::{task::JoinHandle, time::sleep};
#[cdef(module = "lb:task")] #[cdef(module = "lb:task")]
pub struct lb_tasklib; pub struct lb_tasklib;
@ -20,17 +20,17 @@ impl lb_tasklib {
} }
extern "Lua-C" fn __spawn(&self, key: c_int) -> lb_task { extern "Lua-C" fn __spawn(&self, key: c_int) -> lb_task {
let handle = spawn(async move |s| { let handle = spawn(async move |cx| {
// SAFETY: key is always unique, created by __ref above // SAFETY: key is always unique, created by __ref above
let arg = unsafe { s.new_ref_unchecked(key) }; let arg = unsafe { cx.new_ref_unchecked(key) };
let mut s = cx.guard();
s.resize(0); s.resize(0);
s.push(arg); s.push(arg);
let narg = s.unpack(1, 1, None) - 1; let narg = s.unpack(1, 1, None) - 1; // unpack the table containing the function to call and its args
println!("{s:?}"); if let Err(err) = s.call_async(narg, 0).await {
if let Err(_err) = s.call_async(narg, 0).await { drop(s);
process::exit(1) cx.report_error(&err);
} }
println!("{s:?}");
}); });
lb_task { handle } lb_task { handle }

View File

@ -40,6 +40,22 @@ fn panic_cb(panic: &panic::PanicHookInfo) {
); );
} }
fn error_cb(err: &luajit::Error) {
match err.trace() {
Some(trace) => eprintln!("{}\n{trace}", err.red().bold()),
None => eprintln!("{}", err.red().bold()),
}
process::exit(1);
}
fn unwrap_exit<T, E: Display>(code: ExitCode) -> impl FnOnce(E) -> T {
move |err| {
eprintln!("{}", err.red().bold());
code.exit()
}
}
#[derive(Debug, Parser)] #[derive(Debug, Parser)]
struct Args { struct Args {
/// Paths to scripts to execute. /// Paths to scripts to execute.
@ -121,12 +137,12 @@ impl Args {
} }
} }
fn main() { fn main() -> Result<(), ExitCode> {
panic::set_hook(Box::new(panic_cb)); panic::set_hook(Box::new(panic_cb));
let args = Args::parse(); let args = Args::parse();
if args.version { if args.version {
return print_version(); return Ok(print_version());
} }
init_logger(&args); init_logger(&args);
@ -153,13 +169,6 @@ fn print_version() {
); );
} }
fn unwrap_exit<T, E: Display>(code: ExitCode) -> impl FnOnce(E) -> T {
move |err| {
eprintln!("{}", err.red().bold());
code.exit()
}
}
fn init_logger(args: &Args) { fn init_logger(args: &Args) {
use tracing::level_filters::LevelFilter; use tracing::level_filters::LevelFilter;
use tracing_subscriber::util::*; use tracing_subscriber::util::*;
@ -192,54 +201,62 @@ fn init_logger(args: &Args) {
} }
fn init_tokio(args: &Args) -> tokio::runtime::Runtime { fn init_tokio(args: &Args) -> tokio::runtime::Runtime {
let mut rt = match args.threads.get() { match args.threads.get() {
1 => tokio::runtime::Builder::new_current_thread(), 1 => tokio::runtime::Builder::new_current_thread(),
n => { n => {
let mut rt = tokio::runtime::Builder::new_multi_thread(); let mut rt = tokio::runtime::Builder::new_multi_thread();
rt.worker_threads(n - 1); rt.worker_threads(n - 1);
rt rt
} }
}; }
.enable_all()
rt.enable_all() .thread_name("luby")
.thread_name("luby") .max_blocking_threads(args.blocking_threads.get())
.max_blocking_threads(args.blocking_threads.get()) .build()
.build() .unwrap_or_else(unwrap_exit(ExitCode::OsErr))
.unwrap_or_else(unwrap_exit(ExitCode::OsErr))
} }
fn init_lua(args: &Args) -> lb::runtime::Runtime { fn init_lua(args: &Args) -> lb::runtime::Runtime {
let mut rt = lb::runtime::Builder::new(); let mut rt = {
luby::open(&mut rt); let mut rt = lb::runtime::Builder::new();
rt.unhandled_error(error_cb);
luby::open(&mut rt);
if args.dump.iter().find(|s| *s == "cdef").is_some() { if args.dump.iter().find(|s| *s == "cdef").is_some() {
print!("{}", rt.registry()); print!("{}", rt.registry()); // for debugging
}
rt
} }
.build()
let mut rt = rt.build().unwrap(); .unwrap();
for arg in args.jit.iter() { for arg in args.jit.iter() {
let mut s = rt.guard(); let mut s = rt.guard();
if let Some((cmd, flags)) = parse_jitlib_cmd(arg) let res = if let Some((cmd, flags)) = parse_jitlib_cmd(arg)
&& let Ok(_) = s.require(format!("jit.{cmd}"), 1) && let Ok(_) = s.require(format!("jit.{cmd}"), 1)
{ {
(s.push("start"), s.get(-2), s.push(flags)); (s.push("start"), s.get(-2), s.push(flags));
s.call(1, 0) s.call(1, 0) // require("jit.{cmd}").start(flags)
} else { } else {
s.require("jit", 1).unwrap(); s.require("jit", 1).unwrap();
match arg.as_str() { match arg.as_str() {
cmd @ ("on" | "off" | "flush") => { cmd @ ("on" | "off" | "flush") => {
(s.push(cmd), s.get(-2)); (s.push(cmd), s.get(-2));
s.call(0, 0) s.call(0, 0) // require("jit").[on/off/flush]()
} }
arg => { flags => {
(s.push("opt"), s.get(-2)); (s.push("opt"), s.get(-2));
(s.push("start"), s.get(-2), s.push(arg)); (s.push("start"), s.get(-2), s.push(flags));
s.call(1, 0) s.call(1, 0) // require("jit").opt.start(flags)
} }
} }
};
if let Err(err) = res {
drop(s);
rt.report_error(&err);
} }
.unwrap_or_else(unwrap_exit(ExitCode::Usage));
} }
rt rt
@ -254,27 +271,22 @@ fn parse_jitlib_cmd(s: &str) -> Option<(&str, &str)> {
} }
} }
async fn main_async(args: Args, state: &mut luajit::State) { async fn main_async(args: Args, cx: &mut lb::runtime::Context) -> Result<(), ExitCode> {
for ref path in args.path { for ref path in args.path {
let mut s = state.guard();
let chunk = match std::fs::read(path) { let chunk = match std::fs::read(path) {
Ok(chunk) => chunk, Ok(chunk) => chunk,
Err(err) => { Err(err) => {
eprintln!("{}", format_args!("{path}: {err}").red().bold()); eprintln!("{}", format_args!("{path}: {err}").red().bold());
ExitCode::NoInput.exit(); return Err(ExitCode::NoInput);
} }
}; };
s.load(&luajit::Chunk::new(chunk).path(path)) if let Err(err) = cx.load(&luajit::Chunk::new(chunk).path(path)) {
.unwrap_or_else(unwrap_exit(ExitCode::NoInput)); cx.report_error(&err);
} else if let Err(err) = cx.call_async(0, 0).await {
if let Err(err) = s.call_async(0, 0).await { cx.report_error(&err);
match err.trace() {
Some(trace) => eprintln!("{}\n{trace}", err.red().bold()),
None => eprintln!("{}", err.red().bold()),
}
process::exit(1);
} }
} }
Ok(())
} }