From d5e85f2c30dbcf88a2526fbd95cb71d56728ec88 Mon Sep 17 00:00:00 2001 From: luaneko Date: Sat, 28 Jun 2025 19:52:28 +1000 Subject: [PATCH] Fix race condition in task state table unref --- crates/lb/src/task.rs | 35 ++++++++++++++------------- crates/luaffi/src/lib.lua | 11 ++++----- tests/main.lua | 50 +++++++++++++++++++++++++++++---------- 3 files changed, 61 insertions(+), 35 deletions(-) diff --git a/crates/lb/src/task.rs b/crates/lb/src/task.rs index 4909d1c..3aa017e 100644 --- a/crates/lb/src/task.rs +++ b/crates/lb/src/task.rs @@ -12,7 +12,7 @@ use luaffi::{ marker::{function, many}, metatype, }; -use luajit::{LUA_MULTRET, Type}; +use luajit::LUA_MULTRET; use std::{cell::RefCell, ffi::c_int, time::Duration}; use tokio::{task::JoinHandle, time::sleep}; @@ -42,38 +42,41 @@ impl lb_tasklib { pub extern "Lua" fn spawn(f: function, ...) -> lb_task { // pack the function and its arguments into a table and pass its ref to rust. // - // this table is used from rust-side to call the function with its args, and it's also - // reused to store its return values that the task handle can return when awaited. the ref - // is owned by the task handle and unref'ed when it's gc'ed. + // this "state" table is used from rust-side to call the function with its args, and it's + // also reused to store its return values that the task handle can return when awaited. the + // ref is owned by the task handle and unref'ed when it's gc'ed. assert( r#type(f) == "function", concat!("function expected in argument 'f', got ", r#type(f)), ); - Self::__spawn(__ref(__tpack(f, variadic!()))) + // we need two refs: one for the spawn call, and the other for the task handle. this is to + // ensure the task handle isn't gc'ed and the state table unref'ed before the spawn callback + // runs and puts the state table on the stack. + let state = __tpack(f, variadic!()); + Self::__spawn(__ref(state), __ref(state)) } - extern "Lua-C" fn __spawn(key: c_int) -> lb_task { + extern "Lua-C" fn __spawn(spawn_ref: c_int, handle_ref: c_int) -> lb_task { let handle = spawn(async move |cx| { - // SAFETY: key is always unique, created by __ref above. - let arg = unsafe { cx.new_ref_unchecked(key) }; + // SAFETY: handle_ref is always unique, created in Self::spawn above. + let state = unsafe { cx.new_ref_unchecked(spawn_ref) }; let mut s = cx.guard(); s.resize(0); - s.push(&arg); - let narg = s.unpack(1, 1, None) - 1; // unpack the function and its args from the table - debug_assert!(s.slot(2).type_of() == Type::Function); + s.push(state); // this drops the state table ref, but the table is still on the stack + let narg = s.unpack(1, 1, None) - 1; // unpack the function and its args from the state table match s.call_async(narg, LUA_MULTRET).await { Ok(nret) => { - s.pack(1, nret); // pack the return values back into the table + s.pack(1, nret); // pack the return values back into the state table } Err(err) => { drop(s); cx.report_error(&err); } } - let _ = arg.into_raw(); // the original ref is owned by the task handle and unref'ed there }); - lb_task::new(handle, key) + // spawn_ref is owned by the task handle and unref'ed there when the handle gets gc'ed + lb_task::new(handle, handle_ref) } } @@ -103,8 +106,8 @@ impl lb_task { async extern "Lua-C" fn __await(&self) { if let Some(handle) = self.handle.borrow_mut().take() { handle - .await - .unwrap_or_else(|err| panic!("task handler panicked: {err}")); + .await // task handler should never panic + .unwrap_or_else(|err| std::panic::resume_unwind(err.into_panic())); } } diff --git a/crates/luaffi/src/lib.lua b/crates/luaffi/src/lib.lua index 82f5834..be60daa 100644 --- a/crates/luaffi/src/lib.lua +++ b/crates/luaffi/src/lib.lua @@ -8,16 +8,15 @@ local function __ref(value) if ref ~= nil and ref ~= 0 then __registry[FREELIST_REF] = __registry[ref] else - ref = #__registry + 1 + ref = rawlen(__registry) + 1 end __registry[ref] = value return ref end local function __unref(ref) - if ref < 0 then return nil end - local value = __registry[ref] - __registry[ref] = __registry[FREELIST_REF] - __registry[FREELIST_REF] = ref - return value + if ref > 0 then + __registry[ref] = __registry[FREELIST_REF] + __registry[FREELIST_REF] = ref + end end diff --git a/tests/main.lua b/tests/main.lua index b7d16f0..ec6fa1d 100644 --- a/tests/main.lua +++ b/tests/main.lua @@ -19,7 +19,7 @@ local icons = { } local function color(name, s) - return colors[name] .. s .. colors.reset + return ("%s %s %s"):format(colors[name], s, colors.reset) end local function create_test(name, f, group) @@ -54,7 +54,7 @@ local function name_test(test) local name = test.name local group = test.group while group ~= nil do - if group.name ~= "" then name = string.format("%s %s %s", group.name, icons.chevron, name) end + if group.name ~= "" then name = ("%s %s %s"):format(group.name, icons.chevron, name) end group = group.parent end return name @@ -68,11 +68,12 @@ local function run_test(test) local ok, res = xpcall(test.f, trace, test) if ok then test.state = "pass" - print("", string.format("%s %s", color("pass", "PASS"), name_test(test))) + print("", ("%s %s"):format(color("pass", "PASS"), name_test(test))) else test.state = "fail" - print("", string.format("%s %s\n\n%s\n", color("fail", "FAIL"), name_test(test), res)) + print("", ("%s %s\n\n%s\n"):format(color("fail", "FAIL"), name_test(test), res)) end + collectgarbage() -- gc after each test to test destructors return test end @@ -86,6 +87,23 @@ local function start(cx, item) end end +local function check_unrefs() + -- ensure all refs were properly unref'ed + local registry = debug.getregistry() + local count = #registry + local ref = 0 -- FREELIST_REF + while type(registry[ref]) == "number" do + local next = registry[ref] + registry[ref], ref = nil, next + end + for i = 1, count do + local value = registry[i] + if type(value) ~= "thread" then -- ignore threads pinned by the runtime + assert(rawequal(registry[i], nil), ("ref %d not unref'ed: %s"):format(i, registry[i])) + end + end +end + local function main(item) local cx = { tasks = {} } local pass, fail = 0, 0 @@ -97,17 +115,23 @@ local function main(item) fail = fail + 1 end end + local code = 1 if fail == 0 then - print("", color("pass", string.format("%s %d tests passed", icons.check, pass))) - return 0 + print("", color("pass", ("%s %d tests passed"):format(icons.check, pass))) + code = 0 + else + print( + "", + ("%s, %s"):format( + color("pass", ("%s %d tests passed"):format(icons.check, pass)), + color("fail", ("%s %d tests failed"):format(icons.cross, fail)) + ) + ) end - print( - "", - color("pass", string.format("%s %d tests passed", icons.check, pass)) - .. ", " - .. color("fail", string.format("%s %d tests failed", icons.cross, fail)) - ) - return 1 -- report error to cargo + cx = nil + collectgarbage() + check_unrefs() + return code -- report error to cargo end return main(create_group("", function()