Fix race condition in task state table unref

This commit is contained in:
lumi 2025-06-28 19:52:28 +10:00
parent cea9bc0813
commit d5e85f2c30
Signed by: luaneko
GPG Key ID: 406809B8763FF07A
3 changed files with 61 additions and 35 deletions

View File

@ -12,7 +12,7 @@ use luaffi::{
marker::{function, many}, marker::{function, many},
metatype, metatype,
}; };
use luajit::{LUA_MULTRET, Type}; use luajit::LUA_MULTRET;
use std::{cell::RefCell, ffi::c_int, time::Duration}; use std::{cell::RefCell, ffi::c_int, time::Duration};
use tokio::{task::JoinHandle, time::sleep}; use tokio::{task::JoinHandle, time::sleep};
@ -42,38 +42,41 @@ impl lb_tasklib {
pub extern "Lua" fn spawn(f: function, ...) -> lb_task { pub extern "Lua" fn spawn(f: function, ...) -> lb_task {
// pack the function and its arguments into a table and pass its ref to rust. // pack the function and its arguments into a table and pass its ref to rust.
// //
// this table is used from rust-side to call the function with its args, and it's also // this "state" table is used from rust-side to call the function with its args, and it's
// reused to store its return values that the task handle can return when awaited. the ref // also reused to store its return values that the task handle can return when awaited. the
// is owned by the task handle and unref'ed when it's gc'ed. // ref is owned by the task handle and unref'ed when it's gc'ed.
assert( assert(
r#type(f) == "function", r#type(f) == "function",
concat!("function expected in argument 'f', got ", r#type(f)), concat!("function expected in argument 'f', got ", r#type(f)),
); );
Self::__spawn(__ref(__tpack(f, variadic!()))) // we need two refs: one for the spawn call, and the other for the task handle. this is to
// ensure the task handle isn't gc'ed and the state table unref'ed before the spawn callback
// runs and puts the state table on the stack.
let state = __tpack(f, variadic!());
Self::__spawn(__ref(state), __ref(state))
} }
extern "Lua-C" fn __spawn(key: c_int) -> lb_task { extern "Lua-C" fn __spawn(spawn_ref: c_int, handle_ref: c_int) -> lb_task {
let handle = spawn(async move |cx| { let handle = spawn(async move |cx| {
// SAFETY: key is always unique, created by __ref above. // SAFETY: handle_ref is always unique, created in Self::spawn above.
let arg = unsafe { cx.new_ref_unchecked(key) }; let state = unsafe { cx.new_ref_unchecked(spawn_ref) };
let mut s = cx.guard(); let mut s = cx.guard();
s.resize(0); s.resize(0);
s.push(&arg); s.push(state); // this drops the state table ref, but the table is still on the stack
let narg = s.unpack(1, 1, None) - 1; // unpack the function and its args from the table let narg = s.unpack(1, 1, None) - 1; // unpack the function and its args from the state table
debug_assert!(s.slot(2).type_of() == Type::Function);
match s.call_async(narg, LUA_MULTRET).await { match s.call_async(narg, LUA_MULTRET).await {
Ok(nret) => { Ok(nret) => {
s.pack(1, nret); // pack the return values back into the table s.pack(1, nret); // pack the return values back into the state table
} }
Err(err) => { Err(err) => {
drop(s); drop(s);
cx.report_error(&err); cx.report_error(&err);
} }
} }
let _ = arg.into_raw(); // the original ref is owned by the task handle and unref'ed there
}); });
lb_task::new(handle, key) // spawn_ref is owned by the task handle and unref'ed there when the handle gets gc'ed
lb_task::new(handle, handle_ref)
} }
} }
@ -103,8 +106,8 @@ impl lb_task {
async extern "Lua-C" fn __await(&self) { async extern "Lua-C" fn __await(&self) {
if let Some(handle) = self.handle.borrow_mut().take() { if let Some(handle) = self.handle.borrow_mut().take() {
handle handle
.await .await // task handler should never panic
.unwrap_or_else(|err| panic!("task handler panicked: {err}")); .unwrap_or_else(|err| std::panic::resume_unwind(err.into_panic()));
} }
} }

View File

@ -8,16 +8,15 @@ local function __ref(value)
if ref ~= nil and ref ~= 0 then if ref ~= nil and ref ~= 0 then
__registry[FREELIST_REF] = __registry[ref] __registry[FREELIST_REF] = __registry[ref]
else else
ref = #__registry + 1 ref = rawlen(__registry) + 1
end end
__registry[ref] = value __registry[ref] = value
return ref return ref
end end
local function __unref(ref) local function __unref(ref)
if ref < 0 then return nil end if ref > 0 then
local value = __registry[ref] __registry[ref] = __registry[FREELIST_REF]
__registry[ref] = __registry[FREELIST_REF] __registry[FREELIST_REF] = ref
__registry[FREELIST_REF] = ref end
return value
end end

View File

@ -19,7 +19,7 @@ local icons = {
} }
local function color(name, s) local function color(name, s)
return colors[name] .. s .. colors.reset return ("%s %s %s"):format(colors[name], s, colors.reset)
end end
local function create_test(name, f, group) local function create_test(name, f, group)
@ -54,7 +54,7 @@ local function name_test(test)
local name = test.name local name = test.name
local group = test.group local group = test.group
while group ~= nil do while group ~= nil do
if group.name ~= "" then name = string.format("%s %s %s", group.name, icons.chevron, name) end if group.name ~= "" then name = ("%s %s %s"):format(group.name, icons.chevron, name) end
group = group.parent group = group.parent
end end
return name return name
@ -68,11 +68,12 @@ local function run_test(test)
local ok, res = xpcall(test.f, trace, test) local ok, res = xpcall(test.f, trace, test)
if ok then if ok then
test.state = "pass" test.state = "pass"
print("", string.format("%s %s", color("pass", "PASS"), name_test(test))) print("", ("%s %s"):format(color("pass", "PASS"), name_test(test)))
else else
test.state = "fail" test.state = "fail"
print("", string.format("%s %s\n\n%s\n", color("fail", "FAIL"), name_test(test), res)) print("", ("%s %s\n\n%s\n"):format(color("fail", "FAIL"), name_test(test), res))
end end
collectgarbage() -- gc after each test to test destructors
return test return test
end end
@ -86,6 +87,23 @@ local function start(cx, item)
end end
end end
local function check_unrefs()
-- ensure all refs were properly unref'ed
local registry = debug.getregistry()
local count = #registry
local ref = 0 -- FREELIST_REF
while type(registry[ref]) == "number" do
local next = registry[ref]
registry[ref], ref = nil, next
end
for i = 1, count do
local value = registry[i]
if type(value) ~= "thread" then -- ignore threads pinned by the runtime
assert(rawequal(registry[i], nil), ("ref %d not unref'ed: %s"):format(i, registry[i]))
end
end
end
local function main(item) local function main(item)
local cx = { tasks = {} } local cx = { tasks = {} }
local pass, fail = 0, 0 local pass, fail = 0, 0
@ -97,17 +115,23 @@ local function main(item)
fail = fail + 1 fail = fail + 1
end end
end end
local code = 1
if fail == 0 then if fail == 0 then
print("", color("pass", string.format("%s %d tests passed", icons.check, pass))) print("", color("pass", ("%s %d tests passed"):format(icons.check, pass)))
return 0 code = 0
else
print(
"",
("%s, %s"):format(
color("pass", ("%s %d tests passed"):format(icons.check, pass)),
color("fail", ("%s %d tests failed"):format(icons.cross, fail))
)
)
end end
print( cx = nil
"", collectgarbage()
color("pass", string.format("%s %d tests passed", icons.check, pass)) check_unrefs()
.. ", " return code -- report error to cargo
.. color("fail", string.format("%s %d tests failed", icons.cross, fail))
)
return 1 -- report error to cargo
end end
return main(create_group("", function() return main(create_group("", function()