diff --git a/crates/luaify/src/generate.rs b/crates/luaify/src/generate.rs index 18d528b..3bee6a3 100644 --- a/crates/luaify/src/generate.rs +++ b/crates/luaify/src/generate.rs @@ -1,30 +1,38 @@ use crate::utils::{syn_assert, syn_error, wrap_expr_block}; +use proc_macro2::TokenStream; +use quote::quote; use std::fmt::Display; use syn::{ext::*, punctuated::*, spanned::*, *}; -pub fn generate(expr: &Expr) -> Result { +pub fn generate(expr: &Expr) -> Result { let mut f = Formatter::default(); - generate_expr(&mut f, expr, Context::expr(true))?; - f.done() + match expr { + Expr::Block(block) => generate_block(&mut f, &block.block, Context::stmt(true))?, + _ => generate_expr(&mut f, expr, Context::expr(true))?, + } + + Ok(f.done()) } #[derive(Debug, Default)] struct Formatter { buf: String, space: bool, + format_args: Vec, } impl Formatter { fn write(&mut self, s: impl Display) -> &mut Self { fn sep(c: char) -> bool { match c { - '(' | ')' | '[' | ']' | '+' | '-' | '*' | '/' | '%' | '^' | '#' | '=' | '~' - | '<' | '>' | ':' | ';' | '.' | ',' | '\'' | '"' | ' ' => true, + '(' | ')' | '[' | ']' | '{' | '}' | '+' | '-' | '*' | '/' | '%' | '^' | '#' + | '=' | '~' | '<' | '>' | ':' | ';' | '.' | ',' | '\'' | '"' | ' ' => true, _ => false, } } - let s = format!("{s}"); + // we are inside a format string, so we need to escape braces + let s = format!("{s}").replace("{", "{{").replace("}", "}}"); if !s.is_empty() { if self.space && !sep(s.chars().next().unwrap()) { self.buf.push(' '); @@ -32,11 +40,27 @@ impl Formatter { self.buf.push_str(&s); self.space = !sep(s.chars().last().unwrap()); } + self } - fn done(self) -> Result { - Ok(self.buf) + fn interpolate(&mut self, arg: Expr) -> &mut Self { + self.space.then(|| self.buf.push(' ')); + self.buf.push_str("{}"); + self.format_args.push(arg); + self.space = true; + self + } + + fn done(self) -> TokenStream { + let fmt = self.buf; + let args = self.format_args; + if args.is_empty() { + let fmt = fmt.replace("{{", "{").replace("}}", "}"); + quote!(#fmt) + } else { + quote!(::std::format!(#fmt, #(#args),*)) + } } } @@ -119,7 +143,7 @@ fn generate_expr_assign(f: &mut Formatter, ass: &ExprAssign, cx: Context) -> Res ass, "assignment must be in statement position" ); - generate_expr(f, &ass.left, Context::expr(false))?; + generate_expr(f, &ass.left, Context::expr(true))?; f.write("="); generate_expr(f, &ass.right, Context::expr(true))?; Ok(()) @@ -211,16 +235,16 @@ fn generate_expr_binary(f: &mut Formatter, bin: &ExprBinary, cx: Context) -> Res BinOp::DivAssign(_) => assign_bin_op!("/"), BinOp::Rem(_) => call_op!("math.fmod"), BinOp::RemAssign(_) => assign_call_op!("math.fmod"), - BinOp::BitAnd(_) => call_op!("bit.band"), - BinOp::BitAndAssign(_) => assign_call_op!("bit.band"), - BinOp::BitOr(_) => call_op!("bit.bor"), - BinOp::BitOrAssign(_) => assign_call_op!("bit.bor"), - BinOp::BitXor(_) => call_op!("bit.bxor"), - BinOp::BitXorAssign(_) => assign_call_op!("bit.bxor"), - BinOp::Shl(_) => call_op!("bit.lshift"), - BinOp::ShlAssign(_) => assign_call_op!("bit.lshift"), - BinOp::Shr(_) => call_op!("bit.arshift"), - BinOp::ShrAssign(_) => assign_call_op!("bit.arshift"), + BinOp::BitAnd(_) => call_op!("band"), + BinOp::BitAndAssign(_) => assign_call_op!("band"), + BinOp::BitOr(_) => call_op!("bor"), + BinOp::BitOrAssign(_) => assign_call_op!("bor"), + BinOp::BitXor(_) => call_op!("bxor"), + BinOp::BitXorAssign(_) => assign_call_op!("bxor"), + BinOp::Shl(_) => call_op!("lshift"), + BinOp::ShlAssign(_) => assign_call_op!("lshift"), + BinOp::Shr(_) => call_op!("arshift"), + BinOp::ShrAssign(_) => assign_call_op!("arshift"), BinOp::Eq(_) => bin_op!("=="), BinOp::Lt(_) => bin_op!("<"), BinOp::Le(_) => bin_op!("<="), @@ -239,7 +263,7 @@ fn generate_expr_block(f: &mut Formatter, block: &ExprBlock, cx: Context) -> Res assert_no_attrs!(block); syn_assert!(cx.is_stmt(), block, "block must be in statement position"); f.write("do"); - generate_block_body(f, &block.block, cx)?; + generate_block(f, &block.block, cx)?; if let Some(ref label) = block.label { generate_label_continue(f, label)?; } @@ -293,7 +317,7 @@ fn generate_expr_closure(f: &mut Formatter, clo: &ExprClosure, cx: Context) -> R f.write("function("); generate_punctuated_pat(f, &clo.inputs)?; f.write(")"); - generate_block_body(f, &wrap_expr_block(&clo.body), Context::stmt(true))?; + generate_block(f, &wrap_expr_block(&clo.body), Context::stmt(true))?; f.write("end"); Ok(()) } @@ -359,7 +383,7 @@ fn generate_expr_forloop(f: &mut Formatter, fo: &ExprForLoop, cx: Context) -> Re } } f.write("do"); - generate_block_body(f, &fo.body, Context::stmt(false))?; + generate_block(f, &fo.body, Context::stmt(false))?; if let Some(ref label) = fo.label { generate_label_continue(f, label)?; } @@ -377,7 +401,7 @@ fn generate_expr_if(f: &mut Formatter, mut xif: &ExprIf, cx: Context) -> Result< loop { generate_expr(f, &xif.cond, Context::expr(false))?; f.write("then"); - generate_block_body(f, &xif.then_branch, cx)?; + generate_block(f, &xif.then_branch, cx)?; if let Some((_, ref expr)) = xif.else_branch { match **expr { Expr::If(ref elseif) => { @@ -387,7 +411,7 @@ fn generate_expr_if(f: &mut Formatter, mut xif: &ExprIf, cx: Context) -> Result< } ref els => { f.write("else"); - generate_block_body(f, &wrap_expr_block(els), cx)?; + generate_block(f, &wrap_expr_block(els), cx)?; } } } @@ -427,7 +451,7 @@ fn generate_expr_loop(f: &mut Formatter, lo: &ExprLoop, cx: Context) -> Result<( assert_no_attrs!(lo); syn_assert!(cx.is_stmt(), lo, "loop must be in statement position"); f.write("while true do"); - generate_block_body(f, &lo.body, Context::stmt(false))?; + generate_block(f, &lo.body, Context::stmt(false))?; if let Some(ref label) = lo.label { generate_label_continue(f, label)?; } @@ -533,7 +557,7 @@ fn generate_expr_while(f: &mut Formatter, whil: &ExprWhile, cx: Context) -> Resu f.write("while"); generate_expr(f, &whil.cond, Context::expr(false))?; f.write("do"); - generate_block_body(f, &whil.body, Context::stmt(false))?; + generate_block(f, &whil.body, Context::stmt(false))?; if let Some(ref label) = whil.label { generate_label_continue(f, label)?; } @@ -665,7 +689,7 @@ fn generate_path_segment(f: &mut Formatter, seg: &PathSegment) -> Result<()> { } } -fn generate_block_body(f: &mut Formatter, block: &Block, cx: Context) -> Result<()> { +fn generate_block(f: &mut Formatter, block: &Block, cx: Context) -> Result<()> { let len = block.stmts.len(); for (i, stmt) in block.stmts.iter().enumerate() { match stmt { @@ -716,7 +740,7 @@ fn generate_item_fn(f: &mut Formatter, func: &ItemFn) -> Result<()> { }; f.write("local"); generate_signature(f, &func.sig)?; - generate_block_body(f, &func.block, Context::stmt(true))?; + generate_block(f, &func.block, Context::stmt(true))?; f.write("end"); Ok(()) } @@ -771,17 +795,17 @@ fn generate_receiver(f: &mut Formatter, recv: &Receiver) -> Result<()> { fn generate_macro(f: &mut Formatter, mac: &Macro, cx: Context) -> Result<()> { match format!("{}", mac.path.require_ident()?).as_str() { "concat" => generate_macro_concat(f, mac, cx), + "embed" => generate_macro_embed(f, mac, cx), + "raw" => generate_macro_raw(f, mac, cx), name => syn_error!(mac.path, "unknown macro '{name}'"), } } fn generate_macro_concat(f: &mut Formatter, mac: &Macro, cx: Context) -> Result<()> { - syn_assert!(cx.is_value(), mac, "must be in expression position"); + syn_assert!(cx.is_value(), mac, "concat! must be in expression position"); cx.is_ret().then(|| f.write("return")); let args = mac.parse_body_with(>::parse_terminated)?; - if args.is_empty() { - syn_error!(mac.path, "expected at least one argument") - } + syn_assert!(!args.is_empty(), mac, "expected at least one argument"); for (i, arg) in args.iter().enumerate() { (i != 0).then(|| f.write("..")); generate_expr(f, arg, Context::expr(false))?; @@ -789,6 +813,23 @@ fn generate_macro_concat(f: &mut Formatter, mac: &Macro, cx: Context) -> Result< Ok(()) } +fn generate_macro_embed(f: &mut Formatter, mac: &Macro, cx: Context) -> Result<()> { + syn_assert!(cx.is_value(), mac, "embed! must be in expression position"); + cx.is_ret().then(|| f.write("return")); + let arg = mac.parse_body::()?; + f.write("\""); + f.interpolate(parse_quote! { + str::replace(&str::replace(&::std::format!("{}", #arg), "\\", "\\\\"), "\"", "\\\"") + }); + f.write("\""); + Ok(()) +} + +fn generate_macro_raw(f: &mut Formatter, mac: &Macro, _cx: Context) -> Result<()> { + f.interpolate(mac.parse_body()?); + Ok(()) +} + #[derive(Debug, Clone, Copy)] enum PatContext { Single, @@ -808,6 +849,7 @@ impl PatContext { fn generate_pat(f: &mut Formatter, pat: &Pat, cx: PatContext) -> Result<()> { match pat { Pat::Ident(ident) => generate_pat_ident(f, ident, cx), + Pat::Macro(mac) => generate_pat_macro(f, mac, cx), Pat::Tuple(tuple) => generate_pat_tuple(f, tuple, cx), Pat::Type(typed) => generate_pat_typed(f, typed, cx), Pat::Wild(wild) => generate_pat_wild(f, wild, cx), @@ -829,6 +871,11 @@ fn generate_pat_ident(f: &mut Formatter, ident: &PatIdent, _cx: PatContext) -> R generate_ident(f, &ident.ident) } +fn generate_pat_macro(f: &mut Formatter, mac: &PatMacro, _cx: PatContext) -> Result<()> { + assert_no_attrs!(mac); + generate_macro(f, &mac.mac, Context::expr(false)) +} + fn generate_pat_tuple(f: &mut Formatter, tuple: &PatTuple, cx: PatContext) -> Result<()> { assert_no_attrs!(tuple); match tuple.elems.len() { diff --git a/crates/luaify/src/lib.rs b/crates/luaify/src/lib.rs index 03a5c41..8440456 100644 --- a/crates/luaify/src/lib.rs +++ b/crates/luaify/src/lib.rs @@ -1,6 +1,6 @@ use crate::{generate::generate, transform::transform}; use proc_macro::TokenStream as TokenStream1; -use quote::{ToTokens, quote}; +use quote::ToTokens; use syn::parse_macro_input; mod generate; @@ -11,7 +11,7 @@ mod utils; pub fn luaify(input: TokenStream1) -> TokenStream1 { let mut expr = parse_macro_input!(input); match transform(&mut expr).and_then(|()| generate(&expr)) { - Ok(s) => quote!(#s).into_token_stream(), + Ok(s) => s, Err(err) => err.into_compile_error().into_token_stream(), } .into() diff --git a/crates/luaify/src/transform.rs b/crates/luaify/src/transform.rs index 4e4b3cc..28dbbf3 100644 --- a/crates/luaify/src/transform.rs +++ b/crates/luaify/src/transform.rs @@ -1,6 +1,6 @@ -use std::mem; - use crate::utils::{LuaType, syn_error, unwrap_expr_ident, unwrap_pat_ident, wrap_expr_block}; +use quote::format_ident; +use std::mem; use syn::{spanned::*, visit_mut::*, *}; pub fn transform(expr: &mut Expr) -> Result<()> { @@ -158,7 +158,7 @@ impl Visitor { } }; - let tmp = Ident::new(&format!("_{ident}"), ident.span()); + let tmp = format_ident!("_{ident}"); let span = cast.span(); *expr = match ty { LuaType::Any => parse_quote_spanned!(span => {}), diff --git a/crates/luaify/tests/test.rs b/crates/luaify/tests/test.rs index 58a37e9..825595a 100644 --- a/crates/luaify/tests/test.rs +++ b/crates/luaify/tests/test.rs @@ -10,6 +10,12 @@ fn raw_ident() { assert_eq!(luaify!(r#mut::r#ref()), r#"mut.ref()"#); } +#[test] +fn escape() { + assert_eq!(luaify!("\nmy\tstring\x00a"), r#""\nmy\tstring\x00a""#); + assert_eq!(luaify!(r#" "raw string" "#), r#"" \"raw string\" ""#); +} + #[test] fn indexing() { assert_eq!(luaify!(table.0), r#"table[0]"#); @@ -274,18 +280,15 @@ fn ops() { assert_eq!(luaify!(|| a / b), r#"function()return a/b;end"#); assert_eq!(luaify!(|| a /= b), r#"function()a=a/b;end"#); assert_eq!(luaify!(|| a = b % c), r#"function()a=math.fmod(b,c);end"#); - assert_eq!(luaify!(|| a = b << c), r#"function()a=bit.lshift(b,c);end"#); + assert_eq!(luaify!(|| a = b << c), r#"function()a=lshift(b,c);end"#); assert_eq!( luaify!(|| a <<= b << c), - r#"function()a=bit.lshift(a,bit.lshift(b,c));end"# - ); - assert_eq!( - luaify!(|| a = b >> c), - r#"function()a=bit.arshift(b,c);end"# + r#"function()a=lshift(a,lshift(b,c));end"# ); + assert_eq!(luaify!(|| a = b >> c), r#"function()a=arshift(b,c);end"#); assert_eq!( luaify!(|| a >>= b >> c), - r#"function()a=bit.arshift(a,bit.arshift(b,c));end"# + r#"function()a=arshift(a,arshift(b,c));end"# ); assert_eq!(luaify!(|| a && b), r#"function()return a and b;end"#); assert_eq!(luaify!(|| a || b), r#"function()return a or b;end"#); @@ -307,15 +310,15 @@ fn ops() { ); assert_eq!( luaify!(|| -a || !--b && c >> d), - r#"function()return-a or not-(-b)and bit.arshift(c,d);end"# + r#"function()return-a or not-(-b)and arshift(c,d);end"# ); assert_eq!( luaify!(|| -a || !(--b && c) >> d), - r#"function()return-a or bit.arshift(not(-(-b)and c),d);end"# + r#"function()return-a or arshift(not(-(-b)and c),d);end"# ); assert_eq!( luaify!(|| a >> b << c >> d), - r#"function()return bit.arshift(bit.lshift(bit.arshift(a,b),c),d);end"# + r#"function()return arshift(lshift(arshift(a,b),c),d);end"# ); }