From 263ca1cf4865d7beca75b055f8686de62242e293 Mon Sep 17 00:00:00 2001 From: luaneko Date: Sat, 28 Jun 2025 20:00:01 +1000 Subject: [PATCH] Add basic tcp listener test --- crates/lb/src/net.rs | 271 +++++++++++++++++++++++++++++++++++----- crates/lb/tests/net.lua | 41 ++++++ 2 files changed, 283 insertions(+), 29 deletions(-) diff --git a/crates/lb/src/net.rs b/crates/lb/src/net.rs index f19246d..899d4ce 100644 --- a/crates/lb/src/net.rs +++ b/crates/lb/src/net.rs @@ -7,9 +7,15 @@ //! //! See [`lb_netlib`] for items exported by this library. use derive_more::{From, FromStr}; -use luaffi::{cdef, marker::OneOf, metatype}; +use luaffi::{ + cdef, + marker::{OneOf, fun}, + metatype, +}; +use luajit::LUA_NOREF; use std::{ cell::{BorrowError, BorrowMutError, Ref, RefCell, RefMut}, + ffi::c_int, net::{AddrParseError, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, time::Duration, }; @@ -493,6 +499,10 @@ impl lb_ipaddr { } /// Socket address, which is an IP address with a port number. +/// +/// This represents a combination of an IP address and a port, such as `127.0.0.1:8080` or +/// `[::1]:443`. It is used to specify endpoints for network connections and listeners, and can be +/// constructed by [`socketaddr`](lb_libnet::socketaddr). #[derive(Debug, Clone, Copy, PartialEq, Eq, From, FromStr)] #[cdef] pub struct lb_socketaddr(#[opaque] SocketAddr); @@ -545,6 +555,11 @@ impl lb_socketaddr { } /// TCP socket which has not yet been converted to an [`lb_tcpstream`] or [`lb_tcplistener`]. +/// +/// This type represents a TCP socket in its initial state, before it is connected or set to listen. +/// It can be configured (e.g., socket options, bind address) before being converted to an +/// [`lb_tcpstream`] (via [`connect`](lb_tcpsocket::connect)) or [`lb_tcplistener`] (via +/// [`listen`](lb_tcpsocket::listen)), after which it can no longer be used. #[derive(Debug)] #[cdef] pub struct lb_tcpsocket(#[opaque] RefCell>); @@ -636,6 +651,7 @@ impl lb_tcpsocket { /// /// This controls how long the socket will remain open after close if unsent data is present. pub extern "Lua-C" fn set_linger(&self, secs: f64) -> Result<()> { + let secs = secs.max(0.); Ok(self .socket()? .set_linger((secs != 0.).then_some(Duration::from_secs_f64(secs)))?) @@ -669,10 +685,11 @@ impl lb_tcpsocket { /// connections. This socket object can no longer be used after this call. pub extern "Lua-C" fn listen(&self, backlog: u32) -> Result { let socket = self.0.borrow_mut().take().ok_or(Error::SocketConsumed)?; - Ok(socket.listen(backlog)?.into()) + Ok(lb_tcplistener::new(socket.listen(backlog)?)) } - /// Connects this socket to the given remote address, transitioning it to an established state. + /// Connects this socket to the given remote socket address, transitioning it to an established + /// state. /// /// This consumes the socket and returns a new [`lb_tcpstream`] that can be used to send and /// receive data. This socket object can no longer be used after this call. @@ -688,6 +705,48 @@ impl lb_tcpsocket { } /// TCP connection between a local and a remote socket. +/// +/// This represents an established TCP connection. It is created by connecting an [`lb_tcpsocket`] +/// to a remote socket address (via [`connect`](lb_tcpsocket::connect)) or accepting a connection +/// from an [`lb_tcplistener`] (via [`accept`](lb_tcplistener::accept)). It provides methods for +/// reading from and writing to the stream asynchronously. +/// +/// This type supports reading and writing data in both directions concurrently. Typically you would +/// spawn one reader and one writer task to handle incoming and outgoing data respectively. +/// Connection is closed when this object goes out of scope and gets garbage collected, or when +/// [`shutdown`](Self::shutdown) is explicitly called. +/// +/// # Example +/// +/// This examples spawns a reader task and a writer task to operate on the stream concurrently. +/// +/// ```lua +/// local task = require("lb:task") +/// local net = require("lb:net") +/// local socket = net.connect_tcp("127.0.0.1:1234") +/// +/// print("local address:", socket:local_addr()) +/// print("remote address:", socket:peer_addr()) +/// +/// local reader = spawn(function() +/// for chunk in socket, 1024 do +/// if chunk ~= nil then +/// print("received:", chunk) +/// else +/// print("read half closed") +/// end +/// done +/// end) +/// +/// local writer = spawn(function() +/// for i = 1, 10 do +/// socket:write(("message %d\n"):format(i)) +/// print("sent message", i) +/// done +/// end) +/// +/// task.join(reader, writer) +/// ``` #[derive(Debug)] #[cdef] pub struct lb_tcpstream { @@ -715,20 +774,20 @@ impl lb_tcpstream { Ok(self.write.try_borrow_mut()?) } - /// Gets the remote address of this stream. + /// Gets the remote socket address of this stream. pub extern "Lua-C" fn peer_addr(&self) -> Result { Ok(self.read_half()?.peer_addr()?.into()) } - /// Gets the local address of this stream. + /// Gets the local socket address of this stream. pub extern "Lua-C" fn local_addr(&self) -> Result { Ok(self.read_half()?.local_addr()?.into()) } /// Waits for this stream to be ready in the given half. /// - /// `half` can be `"read"` for the readable half, `"write"` for the writable half, or `nil` for - /// both. + /// The argument `half` can be `"read"` for the readable half, `"write"` for the writable half, + /// or `nil` for both. pub async extern "Lua-C" fn ready(&self, half: Option<&str>) -> Result<()> { let ty = match half { None => Interest::READABLE | Interest::WRITABLE, @@ -749,11 +808,11 @@ impl lb_tcpstream { /// If the connection was closed, this returns `nil`. pub async extern "Lua-C" fn read(&self, len: u32) -> Result>> { let mut buf = vec![0; len as usize]; - match self.read_half()?.read_exact(&mut buf).await { - Ok(_) => Ok(Some(buf)), - Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => Ok(None), - Err(err) => Err(err.into()), - } + Ok(match self.read_half()?.read_exact(&mut buf).await { + Ok(_) => Some(buf), + Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => None, + Err(err) => return Err(err.into()), + }) } /// Reads up to `len` bytes from this stream. @@ -763,12 +822,12 @@ impl lb_tcpstream { pub async extern "Lua-C" fn read_partial(&self, len: u32) -> Result>> { let mut buf = vec![0; len as usize]; let n = self.read_half()?.read(&mut buf).await?; - if n == 0 { - Ok(None) + Ok(if n == 0 { + None } else { buf.truncate(n); - Ok(Some(buf)) - } + Some(buf) + }) } /// Attempts to read up to `len` bytes from this stream without waiting. @@ -777,15 +836,15 @@ impl lb_tcpstream { /// queue. If there was no data available or the connection was closed, this returns `nil`. pub extern "Lua-C" fn try_read(&self, len: u32) -> Result>> { let mut buf = vec![0; len as usize]; - match self.read_half()?.try_read(&mut buf) { - Ok(0) => Ok(None), + Ok(match self.read_half()?.try_read(&mut buf) { + Ok(0) => None, Ok(n) => { buf.truncate(n); - Ok(Some(buf)) + Some(buf) } - Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(None), - Err(err) => Err(err.into()), - } + Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => None, + Err(err) => return Err(err.into()), + }) } /// Writes the given bytes to this stream. @@ -805,12 +864,12 @@ impl lb_tcpstream { pub async extern "Lua-C" fn peek(&self, len: u32) -> Result>> { let mut buf = vec![0; len as usize]; let n = self.read_half()?.peek(&mut buf).await?; - if n == 0 { - Ok(None) + Ok(if n == 0 { + None } else { buf.truncate(n); - Ok(Some(buf)) - } + Some(buf) + }) } /// Shuts down this connection. @@ -825,9 +884,163 @@ impl lb_tcpstream { } /// TCP socket server, listening for connections. -#[derive(Debug, From)] +/// +/// This type represents a TCP server socket that can accept incoming connections. It is created by +/// transitioning an [`lb_tcpsocket`] to the listening state via [`listen`](lb_tcpsocket::listen). +#[derive(Debug)] #[cdef] -pub struct lb_tcplistener(#[opaque] TcpListener); +pub struct lb_tcplistener { + #[opaque] + listener: TcpListener, + __on_accept_ref: c_int, +} #[metatype] -impl lb_tcplistener {} +impl lb_tcplistener { + fn new(listener: TcpListener) -> Self { + Self { + listener, + __on_accept_ref: LUA_NOREF, + } + } + + /// Returns the local socket address that this listener is bound to. + pub extern "Lua-C" fn local_addr(&self) -> Result { + Ok(self.listener.local_addr()?.into()) + } + + /// Gets the value of the `IP_TTL` option for this socket. + pub extern "Lua-C" fn ttl(&self) -> Result { + Ok(self.listener.ttl()?) + } + + /// Sets the value for the `IP_TTL` option on this socket. + pub extern "Lua-C" fn set_ttl(&self, ttl: u32) -> Result<()> { + Ok(self.listener.set_ttl(ttl)?) + } + + /// Registers a callback to be invoked with each new incoming connection before it is converted + /// to an [`lb_tcpstream`]. + /// + /// The callback receives a temporary [`lb_tcplistener_stream`] object, which can be used to + /// configure socket options (such as [`set_nodelay`](lb_tcplistener_stream), + /// [`set_linger`](lb_tcplistener_stream), etc.) before the stream is converted to an + /// [`lb_tcpstream`]. The callback is called synchronously during [`accept`](Self::accept) and + /// should complete as quickly as possible. The provided configurable object is only valid + /// within the callback and is converted to an [`lb_tcpstream`] as soon as it returns. + pub extern "Lua" fn on_accept(&self, cb: fun<(&lb_tcplistener_stream,), ()>) { + assert( + rawequal(cb, ()) || r#type(cb) == "function", + concat!("function expected in argument 'cb', got ", r#type(cb)), + ); + __unref(self.__on_accept_ref); + self.__on_accept_ref = __ref(cb); + } + + /// Accepts a new incoming TCP connection. + /// + /// If an [`on_accept`](Self::on_accept) callback is registered, it is invoked with a temporary + /// [`lb_tcplistener_stream`] object representing the new connection. This allows configuration + /// of socket options for this specific connection, before the stream is converted to an + /// [`lb_tcpstream`] and returned for the connection to be read from or written to. + #[call] + pub async extern "Lua" fn accept(&self) -> Result { + let stream = self.__accept(); + let on_accept = __registry[self.__on_accept_ref]; + if on_accept != () { + on_accept(stream); + } + stream.__convert() + } + + async extern "Lua-C" fn __accept(&self) -> Result { + let (stream, _) = self.listener.accept().await?; + Ok(lb_tcplistener_stream::new(stream)) + } + + #[gc] + extern "Lua" fn gc(&self) { + __unref(self.__on_accept_ref); + } +} + +/// TCP connection that has just been accepted by [`lb_tcplistener`]. +/// +/// This type is passed to the [`on_accept`](lb_tcplistener::on_accept) callback on +/// [`lb_tcplistener`], allowing socket options to be set before the stream is converted to an +/// [`lb_tcpstream`]. After conversion, this object can no longer be used. +#[derive(Debug)] +#[cdef] +pub struct lb_tcplistener_stream(#[opaque] RefCell>); + +#[metatype] +impl lb_tcplistener_stream { + fn new(stream: TcpStream) -> Self { + Self(RefCell::new(Some(stream))) + } + + fn stream<'s>(&'s self) -> Result> { + let socket = self.0.borrow(); + match *socket { + Some(_) => Ok(Ref::map(socket, |s| s.as_ref().unwrap())), + None => Err(Error::SocketConsumed), + } + } + + /// Returns the local socket address that the listener is bound to. + pub extern "Lua-C" fn local_addr(&self) -> Result { + Ok(self.stream()?.local_addr()?.into()) + } + + /// Returns the remote socket address of this stream. + pub extern "Lua-C" fn peer_addr(&self) -> Result { + Ok(self.stream()?.peer_addr()?.into()) + } + + /// Gets the value of the `TCP_NODELAY` option on this stream. + pub extern "Lua-C" fn nodelay(&self) -> Result { + Ok(self.stream()?.nodelay()?) + } + + /// Sets the value of the `TCP_NODELAY` option on this stream. + /// + /// This enables or disables Nagle's algorithm, which delays sending small packets. + pub extern "Lua-C" fn set_nodelay(&self, enabled: bool) -> Result<()> { + Ok(self.stream()?.set_nodelay(enabled)?) + } + + /// Gets the value of the `SO_LINGER` option on this stream, in seconds. + pub extern "Lua-C" fn linger(&self) -> Result { + Ok(self + .stream()? + .linger()? + .map(|n| n.as_secs_f64()) + .unwrap_or(0.)) + } + + /// Sets the value of the `SO_LINGER` option on this stream. + /// + /// This controls how long the stream will remain open after close if unsent data is present. + pub extern "Lua-C" fn set_linger(&self, secs: f64) -> Result<()> { + let secs = secs.max(0.); + Ok(self + .stream()? + .set_linger((secs != 0.).then_some(std::time::Duration::from_secs_f64(secs)))?) + } + + /// Gets the value of the `IP_TTL` option for this stream. + pub extern "Lua-C" fn ttl(&self) -> Result { + Ok(self.stream()?.ttl()?) + } + + /// Sets the value for the `IP_TTL` option on this stream. + pub extern "Lua-C" fn set_ttl(&self, ttl: u32) -> Result<()> { + Ok(self.stream()?.set_ttl(ttl)?) + } + + extern "Lua-C" fn __convert(&self) -> Result { + Ok(lb_tcpstream::new( + self.0.borrow_mut().take().ok_or(Error::SocketConsumed)?, + )) + } +} diff --git a/crates/lb/tests/net.lua b/crates/lb/tests/net.lua index 8505b38..f65bbbe 100644 --- a/crates/lb/tests/net.lua +++ b/crates/lb/tests/net.lua @@ -49,14 +49,20 @@ describe("tcp", function() -- sendbuf socket:set_sendbuf(4096) assert(socket:sendbuf() >= 4096) + assert(not pcall(socket.set_sendbuf, socket, 0)) + assert(not pcall(socket.set_sendbuf, socket, -1)) -- recvbuf socket:set_recvbuf(4096) assert(socket:recvbuf() >= 4096) + assert(not pcall(socket.set_recvbuf, socket, 0)) + assert(not pcall(socket.set_recvbuf, socket, -1)) -- linger socket:set_linger(0) assert(socket:linger() == 0) socket:set_linger(2) assert(math.abs(socket:linger() - 2) < 0.1) + socket:set_linger(-1) + assert(socket:linger() == 0) -- nodelay socket:set_nodelay(true) assert(socket:nodelay() == true) @@ -72,4 +78,39 @@ describe("tcp", function() assert(not pcall(socket.local_addr, socket)) end) end) + + describe("listener", function() + test("accept", function() + local listener = net.listen_tcp("127.0.0.1", 0) + local addr = listener:local_addr() + local accepted = false + local client = net.tcp() + local accepted_stream + listener:on_accept(function(stream) + accepted = true + accepted_stream = stream + -- configure stream + stream:set_nodelay(true) + assert(stream:nodelay() == true) + end) + -- connect client + local client_stream = client:connect(addr) + local server_stream = listener() + assert(accepted) + assert(accepted_stream ~= nil) + -- check addresses + assert(server_stream:local_addr() ~= nil) + assert(server_stream:peer_addr() ~= nil) + assert(client_stream:local_addr() ~= nil) + assert(client_stream:peer_addr() ~= nil) + -- test data transfer + server_stream:write("hello") + local buf = client_stream:read(5) + assert(buf ~= nil and #buf == 5) + assert(buf == "hello") + -- shutdown + server_stream:shutdown() + client_stream:shutdown() + end) + end) end)