use clap::Parser; use mimalloc::MiMalloc; use owo_colors::OwoColorize; use std::{ backtrace::Backtrace, fmt::Display, net::SocketAddr, num::NonZero, panic, process, thread, }; use sysexits::ExitCode; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; fn panic_cb(panic: &panic::PanicHookInfo) { let trace = Backtrace::force_capture(); let location = panic.location().unwrap(); let payload = panic.payload(); let msg = if let Some(s) = payload.downcast_ref::<&'static str>() { s } else if let Some(s) = payload.downcast_ref::() { s.as_str() } else { "unknown error" }; eprint!( "{}\n{trace}", format_args!( "thread '{}' panicked at {location}: {msg}", thread::current().name().unwrap_or("") ) .red() .bold() ); eprintln!( "{}", format_args!( "luby should never panic. Please kindly report this bug at {}.", env!("CARGO_PKG_REPOSITORY") ) .yellow() .bold() ); } #[derive(Debug, Parser)] struct Args { /// Paths to scripts to execute. #[clap(value_name = "SCRIPTS")] path: Vec, /// Strings to execute. #[clap(long, short = 'e', value_name = "CHUNK")] eval: Vec, /// Libraries to require. #[clap(long, short = 'l', value_name = "NAME")] lib: Vec, /// Console log level. #[clap(long, value_name = "LEVEL", default_value = "debug")] log: tracing::Level, /// LuaJIT control commands. #[clap(long, short = 'j', help_heading = "Runtime", value_name = "CMD=FLAGS")] jit: Vec, /// Number of worker threads. #[clap( long, short = 'T', help_heading = "Runtime", value_name = "THREADS", default_value_t = Self::threads() )] threads: NonZero, /// Number of blocking threads. #[clap( long, help_heading = "Runtime", value_name = "THREADS", default_value_t = Self::blocking_threads() )] blocking_threads: NonZero, /// Enable tokio-console integration. #[clap(long, help_heading = "Debugging")] enable_console: bool, /// tokio-console publish address. #[clap( long, help_heading = "Debugging", value_name = "ADDRESS", default_value = "127.0.0.1:6669", requires = "enable_console" )] console_addr: SocketAddr, /// Dump internal data. #[clap( long, help_heading = "Debugging", value_name = "DATA", value_parser = ["cdef"] )] dump: Vec, /// Print version. #[clap(long, short = 'V')] version: bool, } impl Args { fn threads() -> NonZero { thread::available_parallelism().unwrap_or(NonZero::new(1).unwrap()) } fn blocking_threads() -> NonZero { NonZero::new(1024).unwrap() } } fn exit_err(code: ExitCode) -> impl FnOnce(E) -> T { move |err| { eprintln!("{}", err.red().bold()); code.exit() } } fn main() { panic::set_hook(Box::new(panic_cb)); let args = Args::parse(); if args.version { return print_version(); } init_logger(&args); let tokio = init_tokio(&args); let lua = init_lua(&args); let main = lua.spawn(async |s| main_async(args, s).await); tokio.block_on(async { lua.await; main.await.unwrap() }) } fn print_version() { println!("luby {}", env!("VERGEN_GIT_DESCRIBE")); println!("{}\n", env!("CARGO_PKG_HOMEPAGE")); println!("Compiled with {} -- {}", luajit::version(), luajit::url()); println!( "Compiled with rustc {} on {} for {}", env!("VERGEN_RUSTC_SEMVER"), env!("VERGEN_RUSTC_HOST_TRIPLE"), env!("VERGEN_CARGO_TARGET_TRIPLE"), ); } fn init_logger(args: &Args) { use tracing::level_filters::LevelFilter; use tracing_subscriber::{Layer, util::*}; let log = tracing_subscriber::fmt() .compact() .with_env_filter( tracing_subscriber::EnvFilter::builder() .with_default_directive(LevelFilter::from(args.log).into()) .from_env_lossy(), ) .with_file(false) .with_line_number(false) .with_target(false) .finish(); if args.enable_console { console_subscriber::ConsoleLayer::builder() .with_default_env() .server_addr(args.console_addr) .spawn() .with_subscriber(log) .init() } else { log.init() } } fn init_tokio(args: &Args) -> tokio::runtime::Runtime { let mut rt = match args.threads.get() { 1 => tokio::runtime::Builder::new_current_thread(), n => { let mut rt = tokio::runtime::Builder::new_multi_thread(); rt.worker_threads(n - 1); rt } }; rt.enable_all() .thread_name("luby") .max_blocking_threads(args.blocking_threads.get()) .build() .unwrap_or_else(exit_err(ExitCode::OsErr)) } fn init_lua(args: &Args) -> lb::runtime::Runtime { let rt = lb::runtime::Builder::new(); if args.dump.iter().find(|s| *s == "cdef").is_some() { print!("{}", rt.registry()); } let mut rt = rt.build().unwrap_or_else(exit_err(ExitCode::Software)); for arg in args.jit.iter() { let mut s = rt.guard(); if let Some((cmd, flags)) = parse_jitlib_cmd(arg) && let Ok(_) = s.require(format!("jit.{cmd}"), 1) { (s.push("start"), s.get(-2), s.push(flags)); s.call(1, 0) } else { s.require("jit", 1).unwrap(); match arg.as_str() { cmd @ ("on" | "off" | "flush") => { (s.push(cmd), s.get(-2)); s.call(0, 0) } arg => { (s.push("opt"), s.get(-2)); (s.push("start"), s.get(-2), s.push(arg)); s.call(1, 0) } } } .unwrap_or_else(exit_err(ExitCode::Usage)); } rt } fn parse_jitlib_cmd(s: &str) -> Option<(&str, &str)> { match s { "p" => Some(("p", "Flspv10")), "v" => Some(("v", "-")), "dump" => Some(("dump", "tirs")), _ => s.split_once('='), } } async fn main_async(args: Args, state: &mut luajit::State) { for ref path in args.path { let mut s = state.guard(); let chunk = match std::fs::read(path) { Ok(chunk) => chunk, Err(err) => { eprintln!("{}", format_args!("{path}: {err}").red().bold()); ExitCode::NoInput.exit(); } }; s.load(&luajit::Chunk::new(chunk).path(path)) .unwrap_or_else(exit_err(ExitCode::NoInput)); if let Err(err) = s.call_async(0, 0).await { match err.trace() { Some(trace) => eprintln!("{}\n{trace}", err.red().bold()), None => eprintln!("{}", err.red().bold()), } process::exit(1); } } }