luby/crates/lb/src/net/tcp.rs
2025-06-30 20:07:13 +10:00

620 lines
22 KiB
Rust

use super::*;
use luaffi::{cdef, marker::fun, metatype};
use luajit::LUA_NOREF;
use std::io::ErrorKind;
use std::{
cell::{Ref, RefCell, RefMut},
ffi::c_int,
time::Duration,
};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt, Interest},
net::{TcpListener, TcpSocket, TcpStream},
};
/// TCP socket which has not yet been converted to an [`lb_tcpstream`] or [`lb_tcplistener`].
///
/// This type represents a TCP socket in its initial state, before it is connected or set to listen.
/// It can be configured (e.g., socket options, bind address) before being converted to an
/// [`lb_tcpstream`] (via [`connect`](lb_tcpsocket::connect)) or [`lb_tcplistener`] (via
/// [`listen`](lb_tcpsocket::listen)), after which it can no longer be used: every method then
/// fails with a "socket consumed" error (`Error::SocketConsumed`).
///
/// Methods on this type may fail if the operating system does not support the requested operation.
#[derive(Debug)]
#[cdef]
// The inner `Option` is `Some` until `listen`/`connect` moves the socket out.
pub struct lb_tcpsocket(#[opaque] RefCell<Option<TcpSocket>>);
#[metatype]
impl lb_tcpsocket {
pub(super) fn new(socket: TcpSocket) -> Self {
Self(RefCell::new(Some(socket)))
}
fn socket<'s>(&'s self) -> Result<Ref<'s, TcpSocket>> {
let socket = self.0.borrow();
match *socket {
Some(_) => Ok(Ref::map(socket, |s| s.as_ref().unwrap())),
None => Err(Error::SocketConsumed),
}
}
/// Gets the value of the `SO_KEEPALIVE` option on this socket.
pub extern "Lua-C" fn keepalive(&self) -> Result<bool> {
Ok(self.socket()?.keepalive()?)
}
/// Sets value for the `SO_KEEPALIVE` option on this socket.
///
/// This enables or disables periodic keepalive messages on the connection.
pub extern "Lua-C" fn set_keepalive(&self, enabled: bool) -> Result<()> {
Ok(self.socket()?.set_keepalive(enabled)?)
}
/// Gets the value of the `SO_REUSEADDR` option on this socket.
pub extern "Lua-C" fn reuseaddr(&self) -> Result<bool> {
Ok(self.socket()?.reuseaddr()?)
}
/// Sets value for the `SO_REUSEADDR` option on this socket.
///
/// This allows the socket to bind to an address that is already in use.
pub extern "Lua-C" fn set_reuseaddr(&self, enabled: bool) -> Result<()> {
Ok(self.socket()?.set_reuseaddr(enabled)?)
}
/// Gets the value of the `SO_REUSEPORT` option on this socket.
pub extern "Lua-C" fn reuseport(&self) -> Result<bool> {
Ok(self.socket()?.reuseport()?)
}
/// Sets value for the `SO_REUSEPORT` option on this socket.
///
/// This allows multiple sockets to bind to the same port.
pub extern "Lua-C" fn set_reuseport(&self, enabled: bool) -> Result<()> {
Ok(self.socket()?.set_reuseport(enabled)?)
}
/// Gets the value of the `SO_SNDBUF` option on this socket.
pub extern "Lua-C" fn sendbuf(&self) -> Result<u32> {
Ok(self.socket()?.send_buffer_size()?)
}
/// Sets value for the `SO_SNDBUF` option on this socket.
///
/// This sets the size of the send buffer in bytes.
pub extern "Lua-C" fn set_sendbuf(&self, size: u32) -> Result<()> {
Ok(self.socket()?.set_send_buffer_size(size)?)
}
/// Gets the value of the `SO_RCVBUF` option on this socket.
pub extern "Lua-C" fn recvbuf(&self) -> Result<u32> {
Ok(self.socket()?.recv_buffer_size()?)
}
/// Sets value for the `SO_RCVBUF` option on this socket.
///
/// This sets the size of the receive buffer in bytes.
pub extern "Lua-C" fn set_recvbuf(&self, size: u32) -> Result<()> {
Ok(self.socket()?.set_recv_buffer_size(size)?)
}
/// Gets the value of the `SO_LINGER` option on this socket, in seconds.
pub extern "Lua-C" fn linger(&self) -> Result<f64> {
Ok(self
.socket()?
.linger()?
.map(|n| n.as_secs_f64())
.unwrap_or(0.))
}
/// Sets the value of the `SO_LINGER` option on this socket.
///
/// This controls how long the socket will remain open after close if unsent data is present.
pub extern "Lua-C" fn set_linger(&self, secs: f64) -> Result<()> {
let secs = secs.max(0.);
Ok(self
.socket()?
.set_linger((secs != 0.).then_some(Duration::from_secs_f64(secs)))?)
}
/// Gets the value of the `TCP_NODELAY` option on this socket.
pub extern "Lua-C" fn nodelay(&self) -> Result<bool> {
Ok(self.socket()?.nodelay()?)
}
/// Sets the value of the `TCP_NODELAY` option on this socket.
///
/// This enables or disables Nagle's algorithm, which delays sending small packets.
pub extern "Lua-C" fn set_nodelay(&self, enabled: bool) -> Result<()> {
Ok(self.socket()?.set_nodelay(enabled)?)
}
/// Gets the local address that this socket is bound to.
pub extern "Lua-C" fn local_addr(&self) -> Result<lb_socketaddr> {
Ok(self.socket()?.local_addr()?.into())
}
/// Binds this socket to the given local address.
pub extern "Lua-C" fn bind(&self, addr: &lb_socketaddr) -> Result<()> {
Ok(self.socket()?.bind(addr.0)?)
}
/// Transitions this socket to the listening state.
///
/// This consumes the socket and returns a new [`lb_tcplistener`] that can accept incoming
/// connections. This socket object can no longer be used after this call.
pub extern "Lua-C" fn listen(&self, backlog: u32) -> Result<lb_tcplistener> {
let socket = self.0.borrow_mut().take().ok_or(Error::SocketConsumed)?;
Ok(lb_tcplistener::new(socket.listen(backlog)?))
}
/// Connects this socket to the given remote socket address, transitioning it to an established
/// state.
///
/// This consumes the socket and returns a new [`lb_tcpstream`] that can be used to send and
/// receive data. This socket object can no longer be used after this call.
///
/// # Errors
///
/// This function may throw an error if connection could not be established to the given remote
/// address.
pub async extern "Lua-C" fn connect(&self, addr: &lb_socketaddr) -> Result<lb_tcpstream> {
let socket = self.0.borrow_mut().take().ok_or(Error::SocketConsumed)?;
Ok(lb_tcpstream::new(socket.connect(addr.0).await?))
}
}
/// TCP connection between a local and a remote socket.
///
/// This represents an established TCP connection. It is created by connecting an [`lb_tcpsocket`]
/// to a remote socket (via [`connect`](lb_tcpsocket::connect)) or accepting a connection from an
/// [`lb_tcplistener`] (via [`accept`](lb_tcplistener::accept)). It provides methods for reading
/// from and writing to the stream asynchronously.
///
/// The stream supports reading and writing data in both directions concurrently. Typically you
/// would spawn one reader task and one writer task to handle incoming and outgoing data
/// respectively. Connection is closed when this object goes out of scope and gets garbage
/// collected, or when [`close`](Self::close) is explicitly called.
///
/// Methods on this type may fail if the operating system does not support the requested operation.
///
/// # Example
///
/// This example spawns a reader task and a writer task to operate on the stream concurrently.
///
/// ```lua
/// local task = require("lb:task")
/// local net = require("lb:net")
/// local socket = net.connect_tcp("127.0.0.1:1234")
///
/// print("local address: ", socket:local_addr())
/// print("remote address: ", socket:peer_addr())
///
/// local reader = spawn(function()
///   for chunk in socket, 1024 do
///     print("received: ", chunk)
///   end
///
///   print("done reading")
/// end)
///
/// local writer = spawn(function()
///   for i = 1, 10 do
///     local msg = ("message %d\n"):format(i)
///     socket:write(msg)
///     print("sent: ", msg)
///   end
///
///   print("done writing")
/// end)
///
/// task.join(reader, writer)
/// ```
///
/// The above example uses the socket as an iterator in a generic `for` loop to read data in chunks
/// of up to 1024 bytes. It is equivalent to the following:
///
/// ```lua
/// while true do
///   local chunk = socket:read_partial(1024)
///   if chunk == nil then break end
///   print("received: ", chunk)
/// end
/// ```
#[derive(Debug)]
#[cdef]
pub struct lb_tcpstream {
    // Read half of the split stream; `None` once closed via `close("read")` or `close()`.
    #[opaque]
    read: RefCell<Option<OwnedReadHalf>>,
    // Write half of the split stream; `None` once closed via `close("write")` or `close()`.
    #[opaque]
    write: RefCell<Option<OwnedWriteHalf>>,
}
#[metatype]
impl lb_tcpstream {
pub(super) fn new(stream: TcpStream) -> Self {
let (read, write) = stream.into_split();
Self {
read: RefCell::new(Some(read)),
write: RefCell::new(Some(write)),
}
}
fn read_half<'s>(&'s self) -> Result<RefMut<'s, OwnedReadHalf>> {
let read = self.read.borrow_mut();
match *read {
Some(_) => Ok(RefMut::map(read, |s| s.as_mut().unwrap())),
None => Err(Error::SocketClosed),
}
}
fn write_half<'s>(&'s self) -> Result<RefMut<'s, OwnedWriteHalf>> {
let write = self.write.borrow_mut();
match *write {
Some(_) => Ok(RefMut::map(write, |s| s.as_mut().unwrap())),
None => Err(Error::SocketClosed),
}
}
/// The local socket address that this stream is bound to.
pub extern "Lua-C" fn local_addr(&self) -> Result<lb_socketaddr> {
Ok(self.read_half()?.local_addr()?.into())
}
/// The remote socket address that this stream is connected to.
pub extern "Lua-C" fn peer_addr(&self) -> Result<lb_socketaddr> {
Ok(self.read_half()?.peer_addr()?.into())
}
/// Waits for this stream to be ready in the given half.
///
/// The argument `half` can be `"read"` for the readable half, `"write"` for the writable half,
/// or `nil` for both.
pub async extern "Lua-C" fn ready(&self, half: Option<&str>) -> Result<()> {
self.read_half()?
.ready(match half {
Some("read") => Interest::READABLE,
Some("write") => Interest::WRITABLE,
None => Interest::READABLE | Interest::WRITABLE,
_ => Err(Error::InvalidSocketHalf)?,
})
.await?;
Ok(())
}
/// Closes this stream in the given half.
///
/// The argument `half` can be `"read"` for the readable half, `"write"` for the writable half,
/// or `nil` for both.
///
/// Once the half is closed, it can no longer be used for reading or writing for that half. Once
/// both halves are closed, the stream is fully shut down.
pub extern "Lua-C" fn close(&self, half: Option<&str>) -> Result<()> {
Ok(match half {
Some("read") => drop(self.read.try_borrow_mut()?.take()),
Some("write") => drop(self.write.try_borrow_mut()?.take()),
None => drop((
self.read.try_borrow_mut()?.take(),
self.write.try_borrow_mut()?.take(),
)),
_ => Err(Error::InvalidSocketHalf)?,
})
}
fn is_disc(err: ErrorKind) -> bool {
matches!(
err,
ErrorKind::ConnectionReset // graceful shutdown
| ErrorKind::BrokenPipe // abrupt shutdown
| ErrorKind::UnexpectedEof // could not read requested amount of data
| ErrorKind::WriteZero // could not write requested amount of data
)
}
/// Reads exactly `len` bytes from this stream.
///
/// If the connection was closed, this returns `nil`.
pub async extern "Lua-C" fn read(&self, len: u32) -> Result<Option<Vec<u8>>> {
let mut buf = vec![0; len as usize];
Ok(match self.read_half()?.read_exact(&mut buf).await {
Ok(_) => Some(buf),
Err(err) if Self::is_disc(err.kind()) => None,
Err(err) => return Err(err.into()),
})
}
/// Reads up to `len` bytes from this stream.
///
/// The returned bytes may be less than `len` in length if the stream had less data available in
/// queue. If there was no data available or the connection was closed, this returns `nil`.
pub async extern "Lua-C" fn read_partial(&self, len: u32) -> Result<Option<Vec<u8>>> {
let mut buf = vec![0; len as usize];
Ok(match self.read_half()?.read(&mut buf).await {
Ok(0) => None,
Ok(n) => Some({
buf.truncate(n);
buf
}),
Err(err) if Self::is_disc(err.kind()) => None,
Err(err) => return Err(err.into()),
})
}
/// Attempts to read up to `len` bytes from this stream without waiting.
///
/// The returned bytes may be less than `len` in length if the stream had less data available in
/// queue. If there was no data available or the connection was closed, this returns `nil`.
pub extern "Lua-C" fn try_read(&self, len: u32) -> Result<Option<Vec<u8>>> {
let mut buf = vec![0; len as usize];
Ok(match self.read_half()?.try_read(&mut buf) {
Ok(0) => None,
Ok(n) => Some({
buf.truncate(n);
buf
}),
Err(err) if Self::is_disc(err.kind()) || err.kind() == ErrorKind::WouldBlock => None,
Err(err) => return Err(err.into()),
})
}
/// Writes exactly the given bytes to this stream.
///
/// If the connection was closed, this returns `false`.
pub async extern "Lua-C" fn write(&self, buf: &[u8]) -> Result<bool> {
Ok(match self.write_half()?.write_all(buf).await {
Ok(()) => true,
Err(err) if Self::is_disc(err.kind()) => false,
Err(err) => return Err(err.into()),
})
}
/// Writes the given bytes to this stream, and returns the number of bytes successfully written.
///
/// The returned number may be less than the length of `buf` if there was not enough space in
/// queue. If the connection was closed, this returns `nil`.
pub async extern "Lua-C" fn write_partial(&self, buf: &[u8]) -> Result<Option<u32>> {
Ok(match self.write_half()?.write(buf).await {
Ok(n) => Some(n as u32),
Err(err) if Self::is_disc(err.kind()) => None,
Err(err) => return Err(err.into()),
})
}
/// Attempts to write the given bytes to this stream without waiting, and returns the number of
/// bytes successfully written.
///
/// The returned number may be less than the length of `buf` if there was not enough space in
/// queue. If the connection was closed, this returns `nil`.
pub extern "Lua-C" fn try_write(&self, buf: &[u8]) -> Result<Option<u32>> {
Ok(match self.write_half()?.try_write(buf) {
Ok(n) => Some(n as u32),
Err(err) if Self::is_disc(err.kind()) || err.kind() == ErrorKind::WouldBlock => None,
Err(err) => return Err(err.into()),
})
}
/// Peeks up to `len` bytes at incoming data without consuming it.
///
/// Successive calls will return the same data until it is consumed by the [`read*`](Self::read)
/// family of functions.
pub async extern "Lua-C" fn peek(&self, len: u32) -> Result<Option<Vec<u8>>> {
let mut buf = vec![0; len as usize];
Ok(match self.read_half()?.peek(&mut buf).await {
Ok(0) => None,
Ok(n) => Some({
buf.truncate(n);
buf
}),
Err(err) if Self::is_disc(err.kind()) => None,
Err(err) => return Err(err.into()),
})
}
/// Alias for [`read_partial`](Self::read_partial).
#[call]
pub async extern "Lua" fn call(&self, len: u32) -> Result<Option<Vec<u8>>> {
self.read_partial(len)
}
}
/// TCP socket server, listening for connections.
///
/// This type represents a TCP server socket that can accept incoming connections. It is created by
/// transitioning an [`lb_tcpsocket`] to the listening state via [`listen`](lb_tcpsocket::listen).
///
/// Methods on this type may fail if the operating system does not support the requested operation.
///
/// # Example
///
/// The listener can be used as an iterator in a generic `for` loop to accept incoming connections:
///
/// ```lua
/// local net = require("lb:net")
/// local listener = net.listen_tcp("127.0.0.1")
///
/// print("listening on: ", listener:local_addr())
///
/// for stream in listener do
///   print("accepted connection from: ", stream:peer_addr())
///   print("local address: ", stream:local_addr())
///
///   spawn(function()
///     stream:write("hello from server\n")
///     stream:close()
///   end)
/// end
/// ```
#[derive(Debug)]
#[cdef]
pub struct lb_tcplistener {
    // The bound, listening tokio socket.
    #[opaque]
    listener: TcpListener,
    // Lua registry reference to the `on_accept` callback; `LUA_NOREF` until one is registered
    // (see `new`, `on_accept` and `accept`). Left non-opaque, presumably so the `extern "Lua"`
    // methods can read and update it from Lua; confirm against luaffi's `#[opaque]` semantics.
    __on_accept_ref: c_int,
}
#[metatype]
impl lb_tcplistener {
    // Wraps a listening tokio `TcpListener` with no `on_accept` callback registered.
    pub(super) fn new(listener: TcpListener) -> Self {
        Self {
            listener,
            __on_accept_ref: LUA_NOREF,
        }
    }
    /// The local socket address that this listener is bound to.
    pub extern "Lua-C" fn local_addr(&self) -> Result<lb_socketaddr> {
        Ok(self.listener.local_addr()?.into())
    }
    /// Registers a callback to be invoked with each new incoming connection before it is converted
    /// to an [`lb_tcpstream`].
    ///
    /// The callback receives a temporary [`lb_tcplistener_stream`] object, which can be used to log
    /// incoming connections or configure socket options (such as
    /// [`set_nodelay`](lb_tcplistener_stream::set_nodelay),
    /// [`set_linger`](lb_tcplistener_stream::set_linger), etc.) before it is converted to an
    /// [`lb_tcpstream`]. The callback is called synchronously during [`accept`](Self::accept) and
    /// should complete as quickly as possible. The provided configurable object is only valid
    /// within the callback and is converted to an [`lb_tcpstream`] as soon as it returns.
    ///
    /// If a callback already exists, it is replaced with the new one. `cb` may also be `nil`.
    ///
    /// # Example
    ///
    /// ```lua
    /// local net = require("lb:net")
    /// local listener = net.listen_tcp("127.0.0.1")
    ///
    /// listener:on_accept(function(stream)
    ///   print("accepted connection from: ", stream:peer_addr())
    ///   print("local address: ", stream:local_addr())
    ///
    ///   stream:set_nodelay(true)
    /// end)
    /// ```
    pub extern "Lua" fn on_accept(&self, cb: fun<(&lb_tcplistener_stream,), ()>) {
        // Accept either nil or a function; anything else raises a Lua error.
        assert(
            rawequal(cb, ()) || r#type(cb) == "function",
            concat!("function expected in argument 'cb', got ", r#type(cb)),
        );
        // Release the previous callback's registry slot before storing the new reference.
        // NOTE(review): passing nil presumably clears the callback (`__ref(nil)` is expected to
        // yield a reference that resolves to nil in `accept`). Confirm against `__ref`.
        __unref(self.__on_accept_ref);
        self.__on_accept_ref = __ref(cb);
    }
    /// Accepts a new incoming TCP connection.
    ///
    /// If an [`on_accept`](Self::on_accept) callback is registered, it is invoked with a temporary
    /// [`lb_tcplistener_stream`] object representing the new connection. This allows configuration
    /// of socket options for this specific connection, before the stream is converted to an
    /// [`lb_tcpstream`] and returned for the connection to be read from or written to.
    pub async extern "Lua" fn accept(&self) -> Result<lb_tcpstream> {
        let stream = self.__accept();
        // Look up the callback registered by `on_accept`; nil when none is registered.
        let on_accept = __registry[self.__on_accept_ref];
        if !rawequal(on_accept, ()) {
            on_accept(stream);
        }
        // Move the socket out of the temporary object; later use of it fails.
        stream.__convert()
    }
    // Waits for the next incoming connection and wraps it in a temporary `lb_tcplistener_stream`.
    // The peer address returned by tokio is discarded; it remains available via `peer_addr`.
    async extern "Lua-C" fn __accept(&self) -> Result<lb_tcplistener_stream> {
        let (stream, _) = self.listener.accept().await?;
        Ok(lb_tcplistener_stream::new(stream))
    }
    /// Alias for [`accept`](Self::accept).
    #[call]
    pub async extern "Lua" fn call(&self) -> Result<lb_tcpstream> {
        self.accept()
    }
    // Garbage-collection hook: releases the registry reference held for the `on_accept` callback.
    #[gc]
    extern "Lua" fn gc(&self) {
        __unref(self.__on_accept_ref);
    }
}
/// TCP connection that has just been accepted by [`lb_tcplistener`].
///
/// This type is passed to the [`on_accept`](lb_tcplistener::on_accept) callback on
/// [`lb_tcplistener`], allowing socket options to be set before the stream is converted to an
/// [`lb_tcpstream`]. After conversion, this object can no longer be used: every method then fails
/// with a "socket consumed" error (`Error::SocketConsumed`).
///
/// Methods on this type may fail if the operating system does not support the requested operation.
#[derive(Debug)]
#[cdef]
// The inner `Option` is `Some` until `__convert` moves the stream out.
pub struct lb_tcplistener_stream(#[opaque] RefCell<Option<TcpStream>>);
#[metatype]
impl lb_tcplistener_stream {
fn new(stream: TcpStream) -> Self {
Self(RefCell::new(Some(stream)))
}
fn stream<'s>(&'s self) -> Result<Ref<'s, TcpStream>> {
let socket = self.0.borrow();
match *socket {
Some(_) => Ok(Ref::map(socket, |s| s.as_ref().unwrap())),
None => Err(Error::SocketConsumed),
}
}
/// The local socket address that the listener is bound to.
pub extern "Lua-C" fn local_addr(&self) -> Result<lb_socketaddr> {
Ok(self.stream()?.local_addr()?.into())
}
/// The remote socket address that this stream is connected to.
pub extern "Lua-C" fn peer_addr(&self) -> Result<lb_socketaddr> {
Ok(self.stream()?.peer_addr()?.into())
}
/// Gets the value of the `TCP_NODELAY` option on this stream.
pub extern "Lua-C" fn nodelay(&self) -> Result<bool> {
Ok(self.stream()?.nodelay()?)
}
/// Sets the value of the `TCP_NODELAY` option on this stream.
///
/// This enables or disables Nagle's algorithm, which delays sending small packets.
pub extern "Lua-C" fn set_nodelay(&self, enabled: bool) -> Result<()> {
Ok(self.stream()?.set_nodelay(enabled)?)
}
/// Gets the value of the `SO_LINGER` option on this stream, in seconds.
pub extern "Lua-C" fn linger(&self) -> Result<f64> {
Ok(self
.stream()?
.linger()?
.map(|n| n.as_secs_f64())
.unwrap_or(0.))
}
/// Sets the value of the `SO_LINGER` option on this stream.
///
/// This controls how long the stream will remain open after close if unsent data is present.
pub extern "Lua-C" fn set_linger(&self, secs: f64) -> Result<()> {
let secs = secs.max(0.);
Ok(self
.stream()?
.set_linger((secs != 0.).then_some(std::time::Duration::from_secs_f64(secs)))?)
}
/// Gets the value of the `IP_TTL` option for this stream.
pub extern "Lua-C" fn ttl(&self) -> Result<u32> {
Ok(self.stream()?.ttl()?)
}
/// Sets the value for the `IP_TTL` option on this stream.
pub extern "Lua-C" fn set_ttl(&self, ttl: u32) -> Result<()> {
Ok(self.stream()?.set_ttl(ttl)?)
}
extern "Lua-C" fn __convert(&self) -> Result<lb_tcpstream> {
Ok(lb_tcpstream::new(
self.0.borrow_mut().take().ok_or(Error::SocketConsumed)?,
))
}
}