Fix race condition in task state table unref
This commit is contained in:
parent
cea9bc0813
commit
d5e85f2c30
@ -12,7 +12,7 @@ use luaffi::{
|
|||||||
marker::{function, many},
|
marker::{function, many},
|
||||||
metatype,
|
metatype,
|
||||||
};
|
};
|
||||||
use luajit::{LUA_MULTRET, Type};
|
use luajit::LUA_MULTRET;
|
||||||
use std::{cell::RefCell, ffi::c_int, time::Duration};
|
use std::{cell::RefCell, ffi::c_int, time::Duration};
|
||||||
use tokio::{task::JoinHandle, time::sleep};
|
use tokio::{task::JoinHandle, time::sleep};
|
||||||
|
|
||||||
@ -42,38 +42,41 @@ impl lb_tasklib {
|
|||||||
pub extern "Lua" fn spawn(f: function, ...) -> lb_task {
|
pub extern "Lua" fn spawn(f: function, ...) -> lb_task {
|
||||||
// pack the function and its arguments into a table and pass its ref to rust.
|
// pack the function and its arguments into a table and pass its ref to rust.
|
||||||
//
|
//
|
||||||
// this table is used from rust-side to call the function with its args, and it's also
|
// this "state" table is used from rust-side to call the function with its args, and it's
|
||||||
// reused to store its return values that the task handle can return when awaited. the ref
|
// also reused to store its return values that the task handle can return when awaited. the
|
||||||
// is owned by the task handle and unref'ed when it's gc'ed.
|
// ref is owned by the task handle and unref'ed when it's gc'ed.
|
||||||
assert(
|
assert(
|
||||||
r#type(f) == "function",
|
r#type(f) == "function",
|
||||||
concat!("function expected in argument 'f', got ", r#type(f)),
|
concat!("function expected in argument 'f', got ", r#type(f)),
|
||||||
);
|
);
|
||||||
Self::__spawn(__ref(__tpack(f, variadic!())))
|
// we need two refs: one for the spawn call, and the other for the task handle. this is to
|
||||||
|
// ensure the task handle isn't gc'ed and the state table unref'ed before the spawn callback
|
||||||
|
// runs and puts the state table on the stack.
|
||||||
|
let state = __tpack(f, variadic!());
|
||||||
|
Self::__spawn(__ref(state), __ref(state))
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "Lua-C" fn __spawn(key: c_int) -> lb_task {
|
extern "Lua-C" fn __spawn(spawn_ref: c_int, handle_ref: c_int) -> lb_task {
|
||||||
let handle = spawn(async move |cx| {
|
let handle = spawn(async move |cx| {
|
||||||
// SAFETY: key is always unique, created by __ref above.
|
// SAFETY: handle_ref is always unique, created in Self::spawn above.
|
||||||
let arg = unsafe { cx.new_ref_unchecked(key) };
|
let state = unsafe { cx.new_ref_unchecked(spawn_ref) };
|
||||||
let mut s = cx.guard();
|
let mut s = cx.guard();
|
||||||
s.resize(0);
|
s.resize(0);
|
||||||
s.push(&arg);
|
s.push(state); // this drops the state table ref, but the table is still on the stack
|
||||||
let narg = s.unpack(1, 1, None) - 1; // unpack the function and its args from the table
|
let narg = s.unpack(1, 1, None) - 1; // unpack the function and its args from the state table
|
||||||
debug_assert!(s.slot(2).type_of() == Type::Function);
|
|
||||||
match s.call_async(narg, LUA_MULTRET).await {
|
match s.call_async(narg, LUA_MULTRET).await {
|
||||||
Ok(nret) => {
|
Ok(nret) => {
|
||||||
s.pack(1, nret); // pack the return values back into the table
|
s.pack(1, nret); // pack the return values back into the state table
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
drop(s);
|
drop(s);
|
||||||
cx.report_error(&err);
|
cx.report_error(&err);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let _ = arg.into_raw(); // the original ref is owned by the task handle and unref'ed there
|
|
||||||
});
|
});
|
||||||
|
|
||||||
lb_task::new(handle, key)
|
// spawn_ref is owned by the task handle and unref'ed there when the handle gets gc'ed
|
||||||
|
lb_task::new(handle, handle_ref)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -103,8 +106,8 @@ impl lb_task {
|
|||||||
async extern "Lua-C" fn __await(&self) {
|
async extern "Lua-C" fn __await(&self) {
|
||||||
if let Some(handle) = self.handle.borrow_mut().take() {
|
if let Some(handle) = self.handle.borrow_mut().take() {
|
||||||
handle
|
handle
|
||||||
.await
|
.await // task handler should never panic
|
||||||
.unwrap_or_else(|err| panic!("task handler panicked: {err}"));
|
.unwrap_or_else(|err| std::panic::resume_unwind(err.into_panic()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,16 +8,15 @@ local function __ref(value)
|
|||||||
if ref ~= nil and ref ~= 0 then
|
if ref ~= nil and ref ~= 0 then
|
||||||
__registry[FREELIST_REF] = __registry[ref]
|
__registry[FREELIST_REF] = __registry[ref]
|
||||||
else
|
else
|
||||||
ref = #__registry + 1
|
ref = rawlen(__registry) + 1
|
||||||
end
|
end
|
||||||
__registry[ref] = value
|
__registry[ref] = value
|
||||||
return ref
|
return ref
|
||||||
end
|
end
|
||||||
|
|
||||||
local function __unref(ref)
|
local function __unref(ref)
|
||||||
if ref < 0 then return nil end
|
if ref > 0 then
|
||||||
local value = __registry[ref]
|
__registry[ref] = __registry[FREELIST_REF]
|
||||||
__registry[ref] = __registry[FREELIST_REF]
|
__registry[FREELIST_REF] = ref
|
||||||
__registry[FREELIST_REF] = ref
|
end
|
||||||
return value
|
|
||||||
end
|
end
|
||||||
|
@ -19,7 +19,7 @@ local icons = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
local function color(name, s)
|
local function color(name, s)
|
||||||
return colors[name] .. s .. colors.reset
|
return ("%s %s %s"):format(colors[name], s, colors.reset)
|
||||||
end
|
end
|
||||||
|
|
||||||
local function create_test(name, f, group)
|
local function create_test(name, f, group)
|
||||||
@ -54,7 +54,7 @@ local function name_test(test)
|
|||||||
local name = test.name
|
local name = test.name
|
||||||
local group = test.group
|
local group = test.group
|
||||||
while group ~= nil do
|
while group ~= nil do
|
||||||
if group.name ~= "" then name = string.format("%s %s %s", group.name, icons.chevron, name) end
|
if group.name ~= "" then name = ("%s %s %s"):format(group.name, icons.chevron, name) end
|
||||||
group = group.parent
|
group = group.parent
|
||||||
end
|
end
|
||||||
return name
|
return name
|
||||||
@ -68,11 +68,12 @@ local function run_test(test)
|
|||||||
local ok, res = xpcall(test.f, trace, test)
|
local ok, res = xpcall(test.f, trace, test)
|
||||||
if ok then
|
if ok then
|
||||||
test.state = "pass"
|
test.state = "pass"
|
||||||
print("", string.format("%s %s", color("pass", "PASS"), name_test(test)))
|
print("", ("%s %s"):format(color("pass", "PASS"), name_test(test)))
|
||||||
else
|
else
|
||||||
test.state = "fail"
|
test.state = "fail"
|
||||||
print("", string.format("%s %s\n\n%s\n", color("fail", "FAIL"), name_test(test), res))
|
print("", ("%s %s\n\n%s\n"):format(color("fail", "FAIL"), name_test(test), res))
|
||||||
end
|
end
|
||||||
|
collectgarbage() -- gc after each test to test destructors
|
||||||
return test
|
return test
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -86,6 +87,23 @@ local function start(cx, item)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
local function check_unrefs()
|
||||||
|
-- ensure all refs were properly unref'ed
|
||||||
|
local registry = debug.getregistry()
|
||||||
|
local count = #registry
|
||||||
|
local ref = 0 -- FREELIST_REF
|
||||||
|
while type(registry[ref]) == "number" do
|
||||||
|
local next = registry[ref]
|
||||||
|
registry[ref], ref = nil, next
|
||||||
|
end
|
||||||
|
for i = 1, count do
|
||||||
|
local value = registry[i]
|
||||||
|
if type(value) ~= "thread" then -- ignore threads pinned by the runtime
|
||||||
|
assert(rawequal(registry[i], nil), ("ref %d not unref'ed: %s"):format(i, registry[i]))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
local function main(item)
|
local function main(item)
|
||||||
local cx = { tasks = {} }
|
local cx = { tasks = {} }
|
||||||
local pass, fail = 0, 0
|
local pass, fail = 0, 0
|
||||||
@ -97,17 +115,23 @@ local function main(item)
|
|||||||
fail = fail + 1
|
fail = fail + 1
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
local code = 1
|
||||||
if fail == 0 then
|
if fail == 0 then
|
||||||
print("", color("pass", string.format("%s %d tests passed", icons.check, pass)))
|
print("", color("pass", ("%s %d tests passed"):format(icons.check, pass)))
|
||||||
return 0
|
code = 0
|
||||||
|
else
|
||||||
|
print(
|
||||||
|
"",
|
||||||
|
("%s, %s"):format(
|
||||||
|
color("pass", ("%s %d tests passed"):format(icons.check, pass)),
|
||||||
|
color("fail", ("%s %d tests failed"):format(icons.cross, fail))
|
||||||
|
)
|
||||||
|
)
|
||||||
end
|
end
|
||||||
print(
|
cx = nil
|
||||||
"",
|
collectgarbage()
|
||||||
color("pass", string.format("%s %d tests passed", icons.check, pass))
|
check_unrefs()
|
||||||
.. ", "
|
return code -- report error to cargo
|
||||||
.. color("fail", string.format("%s %d tests failed", icons.cross, fail))
|
|
||||||
)
|
|
||||||
return 1 -- report error to cargo
|
|
||||||
end
|
end
|
||||||
|
|
||||||
return main(create_group("", function()
|
return main(create_group("", function()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user