luby/src/main.rs

281 lines
7.3 KiB
Rust

use clap::Parser;
use mimalloc::MiMalloc;
use owo_colors::OwoColorize;
use std::{backtrace::Backtrace, fmt::Display, num::NonZero, panic, process, thread};
use sysexits::ExitCode;
// Route every Rust heap allocation in this process through mimalloc instead
// of the system allocator.
#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;
/// Process-wide panic hook (installed at the top of `main`).
///
/// Prints `thread '<name>' panicked at <location>: <message>` in bold red,
/// followed by a force-captured backtrace, and then asks the user to report
/// the bug at the crate's repository URL (`CARGO_PKG_REPOSITORY`).
fn panic_cb(panic: &panic::PanicHookInfo) {
    // Capture unconditionally, regardless of RUST_BACKTRACE.
    let trace = Backtrace::force_capture();
    // `PanicHookInfo::location` currently always returns `Some`, but std
    // documents that this may change. Unwrapping here would panic inside the
    // panic hook and abort without printing the original message.
    let location = match panic.location() {
        Some(location) => location.to_string(),
        None => String::from("<unknown location>"),
    };
    // Panic payloads are typically `&'static str` (literal messages) or
    // `String` (formatted messages); anything else has no printable text.
    let payload = panic.payload();
    let msg = match payload.downcast_ref::<&'static str>() {
        Some(s) => *s,
        None => payload
            .downcast_ref::<String>()
            .map_or("unknown error", String::as_str),
    };
    eprint!(
        "{}\n{trace}",
        format_args!(
            "thread '{}' panicked at {location}: {msg}",
            thread::current().name().unwrap_or("<unnamed>")
        )
        .red()
        .bold()
    );
    eprintln!(
        "{}",
        format_args!(
            "luby should never panic. Please kindly report this bug at {}.",
            env!("CARGO_PKG_REPOSITORY")
        )
        .yellow()
        .bold()
    );
}
// Command-line arguments, parsed with clap's derive API.
//
// The `///` comments on the fields are turned into `--help` text by clap, so
// they are runtime output. Maintainer notes in this struct therefore use
// plain `//` comments only.
#[derive(Debug, Parser)]
struct Args {
    /// Paths to scripts to execute.
    // Positional arguments; `main_async` iterates over them in order.
    #[clap(value_name = "SCRIPTS")]
    path: Vec<String>,
    /// Strings to execute.
    // NOTE(review): not read anywhere in this file (`main_async` only
    // consumes `path`) — confirm whether `-e` is handled elsewhere.
    #[clap(long, short = 'e', value_name = "CHUNK")]
    eval: Vec<String>,
    /// Libraries to require.
    // NOTE(review): likewise not read anywhere in this file — confirm.
    #[clap(long, short = 'l', value_name = "NAME")]
    lib: Vec<String>,
    /// Console log level.
    // Used by `init_logger` as the default directive of its `EnvFilter`.
    #[clap(long, value_name = "LEVEL", default_value = "debug")]
    log: tracing::Level,
    /// LuaJIT control commands.
    // Applied in order by `init_lua`; `parse_jitlib_cmd` splits the
    // `CMD=FLAGS` form and expands the `p` / `v` / `dump` shorthands.
    #[clap(long, short = 'j', help_heading = "Runtime", value_name = "CMD=FLAGS")]
    jit: Vec<String>,
    /// Number of worker threads.
    // Defaults to `Args::threads()`. `init_tokio` builds a current-thread
    // runtime for 1 and a multi-thread runtime with N - 1 workers otherwise.
    #[clap(
        long,
        short = 'T',
        help_heading = "Runtime",
        value_name = "THREADS",
        default_value_t = Self::threads()
    )]
    threads: NonZero<usize>,
    /// Number of blocking threads.
    // Passed to tokio's `max_blocking_threads` in `init_tokio`.
    #[clap(
        long,
        help_heading = "Runtime",
        value_name = "THREADS",
        default_value_t = Self::blocking_threads()
    )]
    blocking_threads: NonZero<usize>,
    /// Enable tokio-console integration.
    // Only present when built with the `tokio-console` feature.
    #[cfg(feature = "tokio-console")]
    #[clap(long, help_heading = "Debugging")]
    enable_console: bool,
    /// tokio-console publish address.
    // Only present with the `tokio-console` feature; `requires` ties it to
    // `--enable-console`. Consumed by `init_logger`.
    #[cfg(feature = "tokio-console")]
    #[clap(
        long,
        help_heading = "Debugging",
        value_name = "ADDRESS",
        default_value = "127.0.0.1:6669",
        requires = "enable_console"
    )]
    console_addr: std::net::SocketAddr,
    /// Dump internal data.
    // `value_parser` restricts values to `cdef`; for it, `init_lua` prints
    // the runtime builder's registry to stdout.
    #[clap(
        long,
        help_heading = "Debugging",
        value_name = "DATA",
        value_parser = ["cdef"]
    )]
    dump: Vec<String>,
    /// Print version.
    // No `#[command(version)]` is set on this struct, so this is a plain flag
    // that `main` checks before initializing anything else.
    #[clap(long, short = 'V')]
    version: bool,
}
impl Args {
    /// Default for `--threads`: the parallelism reported by the OS, or a
    /// single thread when it cannot be determined.
    fn threads() -> NonZero<usize> {
        match thread::available_parallelism() {
            Ok(parallelism) => parallelism,
            Err(_) => NonZero::<usize>::MIN,
        }
    }

    /// Default for `--blocking-threads`: a fixed limit of 1024.
    fn blocking_threads() -> NonZero<usize> {
        const LIMIT: NonZero<usize> = NonZero::new(1024).unwrap();
        LIMIT
    }
}
/// Program entry point.
///
/// Installs the panic hook and parses the command line. With `--version` it
/// only prints build information. Otherwise it sets up logging, the tokio
/// runtime and the Lua runtime, spawns `main_async` on the Lua runtime, and
/// blocks on tokio until the Lua runtime and that task have finished.
fn main() {
    panic::set_hook(Box::new(panic_cb));
    let args = Args::parse();
    if args.version {
        print_version();
        return;
    }
    init_logger(&args);
    let runtime = init_tokio(&args);
    let lua = init_lua(&args);
    let task = lua.spawn(async |state| main_async(args, state).await);
    runtime.block_on(async {
        lua.await;
        task.await.unwrap()
    })
}
/// Prints luby's version, homepage, the LuaJIT it was built against and the
/// Rust toolchain / target triples to stdout.
fn print_version() {
    // Version line plus homepage; the trailing "\n" together with `println!`
    // yields the blank separator line.
    const HEADER: &str = concat!(
        "luby ",
        env!("VERGEN_GIT_DESCRIBE"),
        "\n",
        env!("CARGO_PKG_HOMEPAGE"),
        "\n"
    );
    const TOOLCHAIN: &str = concat!(
        "Compiled with rustc ",
        env!("VERGEN_RUSTC_SEMVER"),
        " on ",
        env!("VERGEN_RUSTC_HOST_TRIPLE"),
        " for ",
        env!("VERGEN_CARGO_TARGET_TRIPLE")
    );
    println!("{HEADER}");
    println!("Compiled with {} -- {}", luajit::version(), luajit::url());
    println!("{TOOLCHAIN}");
}
fn unwrap_exit<T, E: Display>(code: ExitCode) -> impl FnOnce(E) -> T {
move |err| {
eprintln!("{}", err.red().bold());
code.exit()
}
}
/// Installs the global `tracing` subscriber.
///
/// Log output goes through a compact `fmt` subscriber without file, line or
/// target information. Its filter is read from the environment by
/// `EnvFilter::builder().from_env_lossy()` (`RUST_LOG` unless configured
/// otherwise). `--log` supplies the default directive.
///
/// When built with the `tokio-console` feature and `--enable-console` is
/// passed, a `console_subscriber` layer serving on `--console-addr` is stacked
/// on top of that subscriber. Without the flag, no console server is started.
fn init_logger(args: &Args) {
    use tracing::level_filters::LevelFilter;
    use tracing_subscriber::util::*;
    let log = tracing_subscriber::fmt()
        .compact()
        .with_env_filter(
            tracing_subscriber::EnvFilter::builder()
                .with_default_directive(LevelFilter::from(args.log).into())
                .from_env_lossy(),
        )
        .with_file(false)
        .with_line_number(false)
        .with_target(false)
        .finish();
    #[cfg(feature = "tokio-console")]
    {
        // Start the console server only when it was explicitly requested.
        if args.enable_console {
            use tracing_subscriber::Layer;
            // NOTE(review): the EnvFilter lives inside `log`, so it likely also
            // gates what the console layer sees. tokio's instrumentation is
            // emitted at TRACE, so `RUST_LOG=tokio=trace,runtime=trace` (or a
            // per-layer filter on the fmt layer) may be needed — confirm.
            console_subscriber::ConsoleLayer::builder()
                .with_default_env()
                .server_addr(args.console_addr)
                .spawn()
                .with_subscriber(log)
                .init();
            return;
        }
    }
    log.init();
}
/// Builds the tokio runtime described by `--threads` and `--blocking-threads`.
///
/// A thread count of 1 selects a current-thread runtime. N > 1 selects a
/// multi-thread runtime with N - 1 worker threads (presumably so that,
/// together with the thread calling `block_on`, N threads run work — confirm).
/// All drivers are enabled and threads are named `luby`. If the runtime
/// cannot be built, the error is printed and the process exits with
/// `EX_OSERR`.
fn init_tokio(args: &Args) -> tokio::runtime::Runtime {
    use tokio::runtime::Builder;

    let threads = args.threads.get();
    let mut builder = if threads == 1 {
        Builder::new_current_thread()
    } else {
        let mut multi = Builder::new_multi_thread();
        multi.worker_threads(threads - 1);
        multi
    };
    builder.enable_all();
    builder.thread_name("luby");
    builder.max_blocking_threads(args.blocking_threads.get());
    builder.build().unwrap_or_else(unwrap_exit(ExitCode::OsErr))
}
/// Builds the Lua runtime with luby's libraries and applies the `-j` options.
///
/// If `--dump cdef` was given, the builder's registry is printed to stdout
/// before the runtime is built. Each `-j` argument is then applied in order:
///
/// * `CMD=FLAGS`, or one of the shorthands expanded by [`parse_jitlib_cmd`]:
///   if the `jit.CMD` module can be required, its `start` function is called
///   with FLAGS (roughly `require("jit." .. CMD).start(FLAGS)`).
/// * `on`, `off` and `flush`: the same-named function of the `jit` module is
///   called with no arguments.
/// * anything else (including `CMD=FLAGS` whose module fails to load) is
///   passed to `jit.opt.start`.
///
/// If any of these calls fails, the error is printed and the process exits
/// with `EX_USAGE`.
///
/// # Panics
///
/// Panics if building the runtime fails or if the `jit` module cannot be
/// required.
fn init_lua(args: &Args) -> lb::runtime::Runtime {
    let mut rt = lb::runtime::Builder::new();
    luby::open(&mut rt);
    if args.dump.iter().any(|data| data == "cdef") {
        print!("{}", rt.registry());
    }
    let mut rt = rt.build().unwrap();
    for arg in &args.jit {
        // Per-command guard, so each command starts from a clean stack.
        let mut s = rt.guard();
        if let Some((cmd, flags)) = parse_jitlib_cmd(arg)
            && let Ok(_) = s.require(format!("jit.{cmd}"), 1)
        {
            // Module table is on the stack: fetch its `start` and call it.
            (s.push("start"), s.get(-2), s.push(flags));
            s.call(1, 0)
        } else {
            s.require("jit", 1).unwrap();
            match arg.as_str() {
                cmd @ ("on" | "off" | "flush") => {
                    (s.push(cmd), s.get(-2));
                    s.call(0, 0)
                }
                arg => {
                    // `jit.opt`, then `jit.opt.start(arg)`.
                    (s.push("opt"), s.get(-2));
                    (s.push("start"), s.get(-2), s.push(arg));
                    s.call(1, 0)
                }
            }
        }
        .unwrap_or_else(unwrap_exit(ExitCode::Usage));
    }
    rt
}
/// Splits a `-j` argument into `(module, flags)` for a `jit.<module>` library.
///
/// The bare names `p`, `v` and `dump` expand to their default flag strings.
/// Any other argument is split at its first `=`. `None` means the argument
/// has no `CMD=FLAGS` form.
fn parse_jitlib_cmd(s: &str) -> Option<(&str, &str)> {
    const DEFAULTS: [(&str, &str); 3] = [
        ("p", "Flspv10"), // default -jp flags
        ("v", "-"),       // default -jv flags
        ("dump", "tirs"), // default -jdump flags
    ];
    if let Some(&(name, flags)) = DEFAULTS.iter().find(|&&(name, _)| name == s) {
        return Some((name, flags));
    }
    s.split_once('=')
}
async fn main_async(args: Args, state: &mut luajit::State) {
for ref path in args.path {
let mut s = state.guard();
let chunk = match std::fs::read(path) {
Ok(chunk) => chunk,
Err(err) => {
eprintln!("{}", format_args!("{path}: {err}").red().bold());
ExitCode::NoInput.exit();
}
};
s.load(&luajit::Chunk::new(chunk).path(path))
.unwrap_or_else(unwrap_exit(ExitCode::NoInput));
if let Err(err) = s.call_async(0, 0).await {
match err.trace() {
Some(trace) => eprintln!("{}\n{trace}", err.red().bold()),
None => eprintln!("{}", err.red().bold()),
}
process::exit(1);
}
}
}