use clap::Parser; use lb_core::{GlobalState, PrettyError}; use mimalloc::MiMalloc; use owo_colors::OwoColorize; use std::{backtrace::Backtrace, net::SocketAddr, num::NonZero, panic, thread}; use tokio::{runtime, task::LocalSet}; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; fn panic_cb(panic: &panic::PanicHookInfo) { let trace = Backtrace::force_capture(); let location = panic.location().unwrap(); let payload = panic.payload(); let msg = if let Some(s) = payload.downcast_ref::<&'static str>() { s } else if let Some(s) = payload.downcast_ref::() { s.as_str() } else { "unknown error" }; eprint!( "{}", PrettyError::new(msg) .with_trace(trace) .prepend(format_args!( "thread '{}' panicked at {location}", thread::current().name().unwrap_or("") )) ); } #[derive(Debug, Parser)] struct Args { /// Paths to scripts to execute. #[clap(value_name = "SCRIPTS")] paths: Vec, /// Strings to execute. #[clap(long, short = 'e', value_name = "CHUNK")] evals: Vec, /// Libraries to require on startup. #[clap(long, short = 'l', value_name = "NAME")] libs: Vec, /// Console log level. #[clap(long, value_name = "LEVEL", default_value = "debug")] log_level: tracing::Level, /// Number of runtime worker threads. #[clap(long, value_name = "THREADS", default_value_t = Self::threads())] threads: NonZero, /// Number of runtime blocking threads. #[clap(long, value_name = "THREADS", default_value_t = Self::blocking_threads())] blocking_threads: NonZero, /// Enable tokio-console integration. #[clap(long)] enable_console: bool, /// tokio-console publish address. #[clap(long, value_name = "ADDRESS", default_value = "127.0.0.1:6669")] console_addr: SocketAddr, } impl Args { fn threads() -> NonZero { thread::available_parallelism().unwrap_or(NonZero::new(1).unwrap()) } fn blocking_threads() -> NonZero { NonZero::new(1024).unwrap() } } fn main() { panic::set_hook(Box::new(panic_cb)); let args = Args::parse(); init_logger(&args); let runtime = init_runtime(&args); GlobalState::set(init_vm(&args)); let main = LocalSet::new(); main.spawn_local(run(args)); runtime.block_on(main); } fn init_logger(args: &Args) { use tracing::level_filters::LevelFilter; use tracing_subscriber::{Layer, util::*}; let console = tracing_subscriber::fmt() .compact() .with_env_filter( tracing_subscriber::EnvFilter::builder() .with_default_directive(LevelFilter::from(args.log_level).into()) .from_env_lossy(), ) .with_file(false) .with_line_number(false) .with_target(false) .finish(); if args.enable_console { console_subscriber::ConsoleLayer::builder() .with_default_env() .server_addr(args.console_addr) .spawn() .with_subscriber(console) .init() } else { console.init() } } fn init_runtime(args: &Args) -> runtime::Runtime { if args.threads.get() == 1 { runtime::Builder::new_current_thread() } else { runtime::Builder::new_multi_thread() } .enable_all() .thread_name("lb") .worker_threads(args.threads.get() - 1) .max_blocking_threads(args.blocking_threads.get()) .build() .unwrap_or_else(|err| panic!("failed to initialise runtime: {err}")) } fn init_vm(_args: &Args) -> luajit::State { let mut state = luajit::State::new().unwrap_or_else(|err| panic!("failed to initialise runtime: {err}")); let mut registry = luaffi::Registry::new(); registry.include::(); println!("{registry}"); state .load(&luajit::Chunk::new(registry.done()).name("@[luby]")) .and_then(|()| state.call(0, 0)) .unwrap_or_else(|err| panic!("failed to load modules: {err}")); state } async fn run(args: Args) { let mut state = GlobalState::new_thread(); for ref path in args.paths { let chunk = match std::fs::read(path) { Ok(chunk) => chunk, Err(err) => return eprintln!("{}", format!("{path}: {err}").red()), }; if let Err(err) = state.load(&luajit::Chunk::new(chunk).path(path)) { return eprintln!("{}", err.red()); } match state.call_async(0, 0).await { Ok(_) => {} Err(err) => GlobalState::uncaught_error(err), } } }