luby/tests/main.lua
2025-06-29 18:06:56 +10:00

158 lines
4.1 KiB
Lua

-- This file lives under tests/, so the `include("tests", "**/*.lua")` glob at the
-- bottom loads it again, and create_group calls that chunk with the group table
-- as its first argument. Bail out in that case to prevent a recursive harness
-- run. Checking the type first avoids indexing a non-table script argument.
local harness_arg = ...
if type(harness_arg) == "table" and harness_arg.type == "group" then return end

--- Require a runtime module the harness cannot work without.
-- Raises (at the caller's line) with the module name and the original
-- `require` error, instead of discarding the cause.
-- @tparam string name module name passed to `require`
-- @return the loaded module
local function need(name)
  local ok, mod = pcall(require, name)
  if not ok then
    error(("lua test harness requires '%s': %s"):format(name, tostring(mod)), 2)
  end
  return mod
end

-- Loaded for its side effects only; `spawn` (used by `start`) is not bound
-- locally here. NOTE(review): presumably lb:task provides it as a global — verify.
need("lb:task")
local time = need("lb:time")
local fs = need("lb:fs")
-- The real global table; test and group environments fall back to it.
local global = _G

--- ANSI escape sequences keyed by outcome; `reset` restores default attributes.
local color = {
  pass = "\x1b[32;1m", -- green (SGR 32) + bold (SGR 1)
  fail = "\x1b[31;1m", -- red (SGR 31) + bold (SGR 1)
  reset = "\x1b[0m",
}

--- Glyphs used in the summary line and in nested test names.
local icon = {
  chevron = "\u{203a}", -- separator between group and test names
  check = "\u{2713}",
  cross = "\u{00d7}",
}

--- Wrap `s` in the colour sequence named `name`, followed by a reset.
-- @tparam string name key into `color` ("pass" or "fail")
-- @param s value to colour; rendered with `%s`
-- @treturn string
local function style(name, s)
  local open, close = color[name], color.reset
  return string.format("%s%s%s", open, s, close)
end
--- Build a test node for `f` inside `group`.
-- `f` gets a private environment that reads through to the real globals, so
-- any globals it assigns stay local to this test.
-- @param name display name (defaults to "")
-- @tparam function f test body
-- @param group enclosing group node
-- @treturn table test node with `state = "pending"`
local function create_test(name, f, group)
  setfenv(f, setmetatable({}, { __index = global }))
  return {
    type = "test",
    name = name or "",
    group = group,
    state = "pending",
    f = f,
  }
end
--- Build a group node and immediately run its body `f`, which receives the
-- group table as its argument.
-- Inside `f` (via its environment) `describe(name, f)` registers a nested group
-- and `test(name, f)` registers a test; both append to `group.items` and return
-- the new node. Other names resolve to the real globals.
-- @param name display name (defaults to "")
-- @tparam function f group body
-- @param parent enclosing group node, or nil for the root
-- @treturn table the populated group node
local function create_group(name, f, parent)
  local group = { type = "group", name = name or "", parent = parent, items = {} }
  local env = setmetatable({}, { __index = global })
  function env.describe(child_name, child_body)
    local child = create_group(child_name, child_body, group)
    table.insert(group.items, child)
    return child
  end
  function env.test(test_name, test_body)
    local child = create_test(test_name, test_body, group)
    table.insert(group.items, child)
    return child
  end
  setfenv(f, env)
  f(group)
  return group
end
--- Full display name of a test: the names of its enclosing groups (outermost
-- first, unnamed groups skipped) joined to the test name with the chevron icon.
-- @tparam table test test node
-- @treturn string
local function name_test(test)
  local function prefixed(node, label)
    if node == nil then return label end
    if node.name ~= "" then
      label = ("%s %s %s"):format(node.name, icon.chevron, label)
    end
    return prefixed(node.parent, label)
  end
  return prefixed(test.group, test.name)
end
--- `xpcall` error handler: colour the error message red and append the stack
-- frames of the failing code.
-- Level 2 starts the traceback at trace's caller (skipping trace itself). The
-- leading "\nstack traceback:" header is cut off so the frames follow the
-- message directly.
-- @param msg error value raised by the test
-- @treturn string
local function trace(msg)
  local header = "\nstack traceback:"
  local frames = debug.traceback("", 2)
  return style("fail", msg) .. frames:sub(#header + 1)
end
--- Execute one test, record its outcome in `test.state`, and print a report.
-- The body runs under `xpcall` with `trace` as handler and receives the test
-- node as its argument. Failures print the traced error below the name.
-- @tparam table test test node
-- @treturn table the same test node
local function run_test(test)
  local passed, report = xpcall(test.f, trace, test)
  test.state = passed and "pass" or "fail"
  local line = ("%s %s"):format(style(test.state, test.state:upper()), name_test(test))
  if not passed then
    line = ("%s\n\n%s\n"):format(line, report)
  end
  print("", line)
  collectgarbage() -- gc after each test to test destructors
  return test
end
--- Walk a test tree depth-first and spawn a task running each test.
-- Spawned task handles are appended to `cx.tasks` in tree order.
-- @tparam table cx run context holding the `tasks` list
-- @tparam table item test or group node
local function start(cx, item)
  local kind = item.type
  if kind == "group" then
    for _, child in ipairs(item.items) do
      start(cx, child)
    end
  elseif kind == "test" then
    table.insert(cx.tasks, spawn(run_test, item))
  end
end
--- Assert that every registry ref taken during the run has been released.
-- Destructive: it unlinks the registry freelist, so it must be the last thing
-- that touches refs. Raises an error naming the lowest leaked ref index.
-- Threads are ignored because the runtime pins them in the registry.
local function check_refs()
  local registry = debug.getregistry()
  -- Freed slots hold the index of the next free slot (a number), chained from
  -- registry[0]; clear them so only live refs remain.
  local ref = 0 -- FREELIST_REF
  while type(registry[ref]) == "number" do
    local following = registry[ref]
    registry[ref], ref = nil, following
  end
  -- `#registry` is unspecified once the array part has holes. With 5.1's
  -- luaL_unref, a slot freed while the freelist is empty becomes nil, so `#`
  -- may stop before a leaked ref. Find the largest numeric key explicitly.
  local count = 0
  for key in pairs(registry) do
    if type(key) == "number" and key > count then count = key end
  end
  for i = 1, count do
    local value = registry[i]
    if value ~= nil and type(value) ~= "thread" then
      -- level 0: same message shape as the former `assert` (no position prefix)
      error(("ref %d not unref'ed: %s"):format(i, tostring(value)), 0)
    end
  end
end
--- Run every test under `item`, print a coloured summary with the wall time,
-- and verify that no registry refs leaked.
-- @tparam table item root group node
-- @treturn number exit code reported to cargo: 0 if every test passed, else 1
local function main(item)
  local cx = { tasks = {} }
  -- `started`, not `time`, so the lb:time module local stays unshadowed.
  local started, pass, fail = time.instant(), 0, 0
  start(cx, item)
  for _, task in ipairs(cx.tasks) do
    if task:await().state == "pass" then
      pass = pass + 1
    else
      fail = fail + 1
    end
  end
  -- Seconds, as implied by the `* 1000 -> ms` / `s` formatting below.
  -- NOTE(review): confirm against lb:time's `elapsed`.
  local elapsed = started:elapsed()

  local passed_msg = style("pass", ("%s %d tests passed"):format(icon.check, pass))
  local code
  if fail == 0 then
    print("", passed_msg)
    code = 0
  else
    local failed_msg = style("fail", ("%s %d tests failed"):format(icon.cross, fail))
    print("", ("%s, %s"):format(passed_msg, failed_msg))
    code = 1
  end

  -- Sub-second runs are shown in milliseconds. The threshold is 1 s; the old
  -- `< 1000` printed every run shorter than ~17 minutes in ms.
  if elapsed < 1 then
    print("", ("%s completed in %.2f ms"):format(icon.chevron, elapsed * 1000))
  else
    print("", ("%s completed in %.2f s"):format(icon.chevron, elapsed))
  end

  -- Drop the task handles, then collect twice. Objects reachable only from a
  -- finalized userdata survive the cycle that runs its finalizer, so their own
  -- finalizers (and any unrefs they perform) only run on the next cycle.
  cx = nil
  collectgarbage()
  collectgarbage()
  check_refs()
  return code -- report error to cargo
end
-- Root group: every matching file becomes a nested group named after its path.
-- `describe` resolves through this function's environment, which create_group
-- installs before calling it. Exits with main's code.
return main(create_group("", function()
  local function include(root, pattern)
    for entry in fs.glob_dir(root, pattern) do
      local file = entry:path()
      local chunk, err = loadfile(file)
      if not chunk then
        error(err)
      end
      describe(file, chunk)
    end
  end
  include("tests", "**/*.lua")
  include("crates", "*/tests/**/*.lua")
end))