diff --git a/crates/lb/Cargo.toml b/crates/lb/Cargo.toml index 1d2e451..2a0aabb 100644 --- a/crates/lb/Cargo.toml +++ b/crates/lb/Cargo.toml @@ -11,7 +11,7 @@ repository.workspace = true runtime = ["tokio/rt"] task = ["tokio/rt", "tokio/time"] fs = ["tokio/fs", "dep:walkdir", "dep:globset", "dep:tempfile"] -net = ["tokio/net"] +net = ["tokio/net", "tokio/io-util"] [dependencies] derive_more = { version = "2.0.1", features = ["full"] } diff --git a/crates/lb/src/net.rs b/crates/lb/src/net.rs index 0194685..8c40c6b 100644 --- a/crates/lb/src/net.rs +++ b/crates/lb/src/net.rs @@ -9,11 +9,15 @@ use derive_more::{From, FromStr}; use luaffi::{cdef, marker::OneOf, metatype}; use std::{ + cell::{Ref, RefCell, RefMut}, net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, time::Duration, }; use thiserror::Error; -use tokio::net::{TcpListener, TcpSocket, TcpStream}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt, Interest}, + net::{TcpListener, TcpSocket, TcpStream}, +}; /// Errors that can be thrown by this library. /// @@ -174,20 +178,24 @@ impl lb_netlib { /// Creates a new TCP socket configured for IPv4. /// + /// This calls `socket(2)` with `AF_INET` and `SOCK_STREAM`. + /// /// # Errors /// /// This function may throw an error if the socket could not be created. pub extern "Lua-C" fn tcp() -> Result { - Ok(Some(TcpSocket::new_v4()?).into()) + Ok(lb_tcpsocket::new(TcpSocket::new_v4()?)) } /// Creates a new TCP socket configured for IPv6. /// + /// This calls `socket(2)` with `AF_INET6` and `SOCK_STREAM`. + /// /// # Errors /// /// This function may throw an error if the socket could not be created. pub extern "Lua-C" fn tcp_v6() -> Result { - Ok(Some(TcpSocket::new_v6()?).into()) + Ok(lb_tcpsocket::new(TcpSocket::new_v6()?)) } /// Creates a new TCP socket bound to the given address and port. @@ -490,44 +498,35 @@ impl lb_socketaddr { self.0.ip().into() } - /// Sets the IP part of this address. - /// - /// This function accepts the same arguments as [`ipaddr`](lb_netlib::ipaddr). - pub extern "Lua" fn set_ip( - &mut self, - addr: OneOf<(&str, &lb_ipaddr, &lb_socketaddr)>, - ) -> &mut Self { - if __istype(__ct.lb_ipaddr, addr) { - self.__set_ip(addr); - } else if __istype(__ct.lb_socketaddr, addr) { - self.__set_ip(addr.ip()); - } else { - self.__set_ip_parse(addr); - } - self - } - - extern "Lua-C" fn __set_ip(&mut self, ip: &lb_ipaddr) { - self.0.set_ip(ip.0); - } - - extern "Lua-C" fn __set_ip_parse(&mut self, addr: &str) -> Result<()> { - Ok(self.0.set_ip(addr.parse()?)) - } - /// Returns the port part of this address. pub extern "Lua-C" fn port(&self) -> u16 { self.0.port() } - /// Sets the port part of this address. - pub extern "Lua" fn set_port(&mut self, port: u16) -> &mut Self { - self.__set_port(port); - self + pub extern "Lua-C" fn with_ip(&self, ip: &lb_ipaddr) -> Self { + SocketAddr::new(ip.0, self.port()).into() } - extern "Lua-C" fn __set_port(&mut self, port: u16) { - self.0.set_port(port) + pub extern "Lua-C" fn with_port(&self, port: u16) -> Self { + SocketAddr::new(self.ip().0, port).into() + } + + /// Returns `true` if the given addresses are equal. + #[eq] + pub extern "Lua-C" fn equals(left: &Self, right: &Self) -> bool { + left.0 == right.0 + } + + /// Returns `true` if the left address is less than the right address. + #[lt] + pub extern "Lua-C" fn less_than(left: &Self, right: &Self) -> bool { + left.0 < right.0 + } + + /// Returns `true` if the left address is less than or equal to the right address. + #[le] + pub extern "Lua-C" fn less_than_or_equals(left: &Self, right: &Self) -> bool { + left.0 <= right.0 } /// Returns the string representation of this address. @@ -538,67 +537,85 @@ impl lb_socketaddr { } /// TCP socket which has not yet been converted to an [`lb_tcpstream`] or [`lb_tcplistener`]. -#[derive(Debug, From)] +#[derive(Debug)] #[cdef] -pub struct lb_tcpsocket(#[opaque] Option); +pub struct lb_tcpsocket(#[opaque] RefCell>); #[metatype] impl lb_tcpsocket { - fn socket(&self) -> Result<&TcpSocket> { - self.0.as_ref().ok_or(Error::SocketConsumed) + fn new(socket: TcpSocket) -> Self { + Self(RefCell::new(Some(socket))) } - /// See [`TcpSocket::keepalive`]. + fn socket(&self) -> Result> { + let socket = self.0.borrow(); + match *socket { + Some(_) => Ok(Ref::map(socket, |s| s.as_ref().unwrap())), + None => Err(Error::SocketConsumed), + } + } + + /// Gets the value of the `SO_KEEPALIVE` option on this socket. pub extern "Lua-C" fn keepalive(&self) -> Result { Ok(self.socket()?.keepalive()?) } - /// See [`TcpSocket::set_keepalive`]. + /// Sets value for the `SO_KEEPALIVE` option on this socket. + /// + /// This enables or disables periodic keepalive messages on the connection. pub extern "Lua-C" fn set_keepalive(&self, enabled: bool) -> Result<()> { Ok(self.socket()?.set_keepalive(enabled)?) } - /// See [`TcpSocket::reuseaddr`]. + /// Gets the value of the `SO_REUSEADDR` option on this socket. pub extern "Lua-C" fn reuseaddr(&self) -> Result { Ok(self.socket()?.reuseaddr()?) } - /// See [`TcpSocket::set_reuseaddr`]. + /// Sets value for the `SO_REUSEADDR` option on this socket. + /// + /// This allows the socket to bind to an address that is already in use. pub extern "Lua-C" fn set_reuseaddr(&self, enabled: bool) -> Result<()> { Ok(self.socket()?.set_reuseaddr(enabled)?) } - /// See [`TcpSocket::reuseport`]. + /// Gets the value of the `SO_REUSEPORT` option on this socket. pub extern "Lua-C" fn reuseport(&self) -> Result { Ok(self.socket()?.reuseport()?) } - /// See [`TcpSocket::set_reuseport`]. + /// Sets value for the `SO_REUSEPORT` option on this socket. + /// + /// This allows multiple sockets to bind to the same port. pub extern "Lua-C" fn set_reuseport(&self, enabled: bool) -> Result<()> { Ok(self.socket()?.set_reuseport(enabled)?) } - /// See [`TcpSocket::send_buffer_size`]. + /// Gets the value of the `SO_SNDBUF` option on this socket. pub extern "Lua-C" fn sendbuf(&self) -> Result { Ok(self.socket()?.send_buffer_size()?) } - /// See [`TcpSocket::set_send_buffer_size`]. + /// Sets value for the `SO_SNDBUF` option on this socket. + /// + /// This sets the size of the send buffer in bytes. pub extern "Lua-C" fn set_sendbuf(&self, size: u32) -> Result<()> { Ok(self.socket()?.set_send_buffer_size(size)?) } - /// See [`TcpSocket::recv_buffer_size`]. + /// Gets the value of the `SO_RCVBUF` option on this socket. pub extern "Lua-C" fn recvbuf(&self) -> Result { Ok(self.socket()?.recv_buffer_size()?) } - /// See [`TcpSocket::set_recv_buffer_size`]. + /// Sets value for the `SO_RCVBUF` option on this socket. + /// + /// This sets the size of the receive buffer in bytes. pub extern "Lua-C" fn set_recvbuf(&self, size: u32) -> Result<()> { Ok(self.socket()?.set_recv_buffer_size(size)?) } - /// See [`TcpSocket::linger`]. + /// Gets the value of the `SO_LINGER` option on this socket, in seconds. pub extern "Lua-C" fn linger(&self) -> Result { Ok(self .socket()? @@ -607,63 +624,212 @@ impl lb_tcpsocket { .unwrap_or(0.)) } - /// See [`TcpSocket::set_linger`]. + /// Sets the value of the `SO_LINGER` option on this socket. + /// + /// This controls how long the socket will remain open after close if unsent data is present. pub extern "Lua-C" fn set_linger(&self, secs: f64) -> Result<()> { Ok(self .socket()? .set_linger((secs != 0.).then_some(Duration::from_secs_f64(secs)))?) } - /// See [`TcpSocket::nodelay`]. + /// Gets the value of the `TCP_NODELAY` option on this socket. pub extern "Lua-C" fn nodelay(&self) -> Result { Ok(self.socket()?.nodelay()?) } - /// See [`TcpSocket::set_nodelay`]. + /// Sets the value of the `TCP_NODELAY` option on this socket. + /// + /// This enables or disables Nagle's algorithm which delays sending small packets. pub extern "Lua-C" fn set_nodelay(&self, enabled: bool) -> Result<()> { Ok(self.socket()?.set_nodelay(enabled)?) } - /// See [`TcpSocket::tos`]. - pub extern "Lua-C" fn tos(&self) -> Result { - Ok(self.socket()?.tos()?) - } - - /// See [`TcpSocket::set_tos`]. - pub extern "Lua-C" fn set_tos(&self, tos: u32) -> Result<()> { - Ok(self.socket()?.set_tos(tos)?) - } - - /// See [`TcpSocket::local_addr`]. + /// Gets the local address that this socket is bound to. pub extern "Lua-C" fn local_addr(&self) -> Result { Ok(self.socket()?.local_addr()?.into()) } - /// See [`TcpSocket::bind`]. + /// Binds this socket to the given local address. pub extern "Lua-C" fn bind(&self, addr: &lb_socketaddr) -> Result<()> { Ok(self.socket()?.bind(addr.0)?) } - /// See [`TcpSocket::listen`]. - pub extern "Lua-C" fn listen(&mut self, backlog: u32) -> Result { - let socket = self.0.take().ok_or(Error::SocketConsumed)?; + /// Transitions this socket to the listening state. + /// + /// This consumes the socket and returns a new [`lb_tcplistener`] that can accept incoming + /// connections. This socket object can no longer be used after this call. + pub extern "Lua-C" fn listen(&self, backlog: u32) -> Result { + let socket = self.0.borrow_mut().take().ok_or(Error::SocketConsumed)?; Ok(socket.listen(backlog)?.into()) } - /// See [`TcpSocket::connect`]. - pub async extern "Lua-C" fn connect(&mut self, addr: &lb_socketaddr) -> Result { - let socket = self.0.take().ok_or(Error::SocketConsumed)?; - Ok(socket.connect(addr.0).await?.into()) + /// Connects this socket to the given remote address, transitioning it to an established state. + /// + /// This consumes the socket and returns a new [`lb_tcpstream`] that can be used to send and + /// receive data. This socket object can no longer be used after this call. + /// + /// # Errors + /// + /// This function may throw an error if connection could not be established to the given remote + /// address. + pub async extern "Lua-C" fn connect(&self, addr: &lb_socketaddr) -> Result { + let socket = self.0.borrow_mut().take().ok_or(Error::SocketConsumed)?; + Ok(lb_tcpstream::new(socket.connect(addr.0).await?)) } } /// TCP connection between a local and a remote socket. -#[derive(Debug, From)] +#[derive(Debug)] #[cdef] -pub struct lb_tcpstream(#[opaque] TcpStream); +pub struct lb_tcpstream(#[opaque] RefCell); #[metatype] -impl lb_tcpstream {} +impl lb_tcpstream { + fn new(stream: TcpStream) -> Self { + Self(RefCell::new(stream)) + } + + fn stream(&self) -> Ref { + self.0.borrow() + } + + fn stream_mut(&self) -> RefMut { + self.0.borrow_mut() + } + + /// Gets the remote address of this stream. + pub extern "Lua-C" fn peer_addr(&self) -> Result { + Ok(self.stream().peer_addr()?.into()) + } + + /// Gets the local address of this stream. + pub extern "Lua-C" fn local_addr(&self) -> Result { + Ok(self.stream().local_addr()?.into()) + } + + /// Gets the value of the `SO_LINGER` option on this stream, in seconds. + pub extern "Lua-C" fn linger(&self) -> Result { + Ok(self + .stream() + .linger()? + .map(|n| n.as_secs_f64()) + .unwrap_or(0.)) + } + + /// Sets the value of the `SO_LINGER` option on this stream. + pub extern "Lua-C" fn set_linger(&self, secs: f64) -> Result<()> { + Ok(self + .stream() + .set_linger((secs != 0.).then_some(Duration::from_secs_f64(secs)))?) + } + + /// Gets the value of the `TCP_NODELAY` option on this stream. + pub extern "Lua-C" fn nodelay(&self) -> Result { + Ok(self.stream().nodelay()?) + } + + /// Sets the value of the `TCP_NODELAY` option on this stream. + /// + /// This enables or disables Nagle's algorithm which delays sending small packets. + pub extern "Lua-C" fn set_nodelay(&self, enabled: bool) -> Result<()> { + Ok(self.stream().set_nodelay(enabled)?) + } + + /// Gets the value of the `IP_TTL` option on this stream. + pub extern "Lua-C" fn ttl(&self) -> Result { + Ok(self.stream().ttl()?) + } + + /// Sets the value of the `IP_TTL` option on this stream. + pub extern "Lua-C" fn set_ttl(&self, ttl: u32) -> Result<()> { + Ok(self.stream().set_ttl(ttl)?) + } + + /// Waits for the stream to be ready for the given half (`"read"`, `"write"`, or `nil` for both halves). + pub async extern "Lua-C" fn ready(&self, half: Option<&str>) -> Result<()> { + let ty = match half { + None => Interest::READABLE | Interest::WRITABLE, + Some("read") => Interest::READABLE, + Some("write") => Interest::WRITABLE, + _ => Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "invalid ready interest", + ))?, + }; + + self.stream().ready(ty).await?; + Ok(()) + } + + /// Reads exactly `len` bytes from the stream. Returns None on EOF. + pub async extern "Lua-C" fn read(&self, len: u32) -> Result>> { + let mut buf = vec![0; len as usize]; + match self.stream_mut().read_exact(&mut buf).await { + Ok(_) => Ok(Some(buf)), + Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => Ok(None), + Err(err) => Err(err.into()), + } + } + + /// Writes the given bytes to the stream. + pub async extern "Lua-C" fn write(&self, buf: &[u8]) -> Result<()> { + Ok(self.stream_mut().write_all(buf).await?) + } + + /// Reads up to `len` bytes from the stream. Returns None on EOF. + pub async extern "Lua-C" fn read_partial(&self, len: u32) -> Result>> { + let mut buf = vec![0; len as usize]; + let n = self.stream_mut().read(&mut buf).await?; + if n == 0 { + Ok(None) + } else { + buf.truncate(n); + Ok(Some(buf)) + } + } + + /// Writes the given bytes to the stream and returns the number of bytes successfully written. + pub async extern "Lua-C" fn write_partial(&self, buf: &[u8]) -> Result { + Ok(self.stream_mut().write(buf).await? as u32) + } + + /// Attempts to read up to `len` bytes from the stream without waiting. Returns None on EOF. + pub extern "Lua-C" fn try_read(&self, len: u32) -> Result>> { + let mut buf = vec![0u8; len as usize]; + match self.stream_mut().try_read(&mut buf) { + Ok(0) => Ok(None), + Ok(n) => { + buf.truncate(n); + Ok(Some(buf)) + } + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(None), + Err(err) => Err(err.into()), + } + } + + /// Peeks up to `len` bytes at incoming data without consuming it. Returns None on EOF. + pub async extern "Lua-C" fn peek(&self, len: u32) -> Result>> { + let mut buf = vec![0u8; len as usize]; + let n = self.stream_mut().peek(&mut buf).await?; + if n == 0 { + Ok(None) + } else { + buf.truncate(n); + Ok(Some(buf)) + } + } + + /// Shuts down this connection. + pub async extern "Lua-C" fn shutdown(&self) -> Result<()> { + Ok(self.stream_mut().shutdown().await?) + } + + #[call] + pub async extern "Lua" fn __call(&self, len: u32) -> Result>> { + self.read_partial(len) + } +} /// TCP socket server, listening for connections. #[derive(Debug, From)] diff --git a/crates/lb/tests/net.lua b/crates/lb/tests/net.lua index a800206..0cd21c0 100644 --- a/crates/lb/tests/net.lua +++ b/crates/lb/tests/net.lua @@ -1,6 +1,27 @@ local ok, net = pcall(require, "lb:net") if not ok then return end +describe("ipaddr", function() + test("invalid ipaddr throws", function() + assert(not pcall(net.ipaddr, "invalid ip")) + end) + + test("comparison", function() + local a = net.ipaddr("10.0.0.1") + local b = net.ipaddr("10.0.0.1") + local c = net.ipaddr("10.0.0.2") + assert(a ~= nil and a ~= {} and a ~= "10.0.0.1" and a ~= 167772161) + assert(a == a and a == b and a ~= c and b ~= c and c == c and c ~= a) + assert(a <= b and a < c and a <= c and b < c and b <= c and a <= a and c <= c) + assert(not (a < b or a > b or a > c or b > c or a >= c or b >= c)) + end) + + test("tostring", function() + local ip = net.ipaddr("10.0.0.1") + assert(tostring(ip) == "10.0.0.1") + end) +end) + describe("tcp", function() describe("socket", function() test("bind", function()