Correctly handle tcp disconnections

This commit is contained in:
2025-06-28 21:11:09 +10:00
parent 263ca1cf48
commit 8c406a46b3
2 changed files with 127 additions and 23 deletions

View File

@@ -13,6 +13,7 @@ use luaffi::{
metatype,
};
use luajit::LUA_NOREF;
use std::io::ErrorKind;
use std::{
cell::{BorrowError, BorrowMutError, Ref, RefCell, RefMut},
ffi::c_int,
@@ -794,7 +795,7 @@ impl lb_tcpstream {
Some("read") => Interest::READABLE,
Some("write") => Interest::WRITABLE,
_ => Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
ErrorKind::InvalidInput,
"invalid ready interest",
))?,
};
@@ -803,6 +804,16 @@ impl lb_tcpstream {
Ok(())
}
fn is_disc(err: ErrorKind) -> bool {
matches!(
err,
ErrorKind::ConnectionReset
| ErrorKind::BrokenPipe
| ErrorKind::UnexpectedEof
| ErrorKind::WriteZero
)
}
/// Reads exactly `len` bytes from this stream.
///
/// If the connection was closed, this returns `nil`.
@@ -810,7 +821,7 @@ impl lb_tcpstream {
let mut buf = vec![0; len as usize];
Ok(match self.read_half()?.read_exact(&mut buf).await {
Ok(_) => Some(buf),
Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => None,
Err(err) if Self::is_disc(err.kind()) => None,
Err(err) => return Err(err.into()),
})
}
@@ -821,12 +832,14 @@ impl lb_tcpstream {
/// queue. If the connection was closed, this returns `nil`.
pub async extern "Lua-C" fn read_partial(&self, len: u32) -> Result<Option<Vec<u8>>> {
let mut buf = vec![0; len as usize];
let n = self.read_half()?.read(&mut buf).await?;
Ok(if n == 0 {
None
} else {
buf.truncate(n);
Some(buf)
Ok(match self.read_half()?.read(&mut buf).await {
Ok(0) => None,
Ok(n) => Some({
buf.truncate(n);
buf
}),
Err(err) if Self::is_disc(err.kind()) => None,
Err(err) => return Err(err.into()),
})
}
@@ -838,23 +851,40 @@ impl lb_tcpstream {
let mut buf = vec![0; len as usize];
Ok(match self.read_half()?.try_read(&mut buf) {
Ok(0) => None,
Ok(n) => {
Ok(n) => Some({
buf.truncate(n);
Some(buf)
}
Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => None,
buf
}),
Err(err) if Self::is_disc(err.kind()) || err.kind() == ErrorKind::WouldBlock => None,
Err(err) => return Err(err.into()),
})
}
/// Writes the given bytes to this stream.
pub async extern "Lua-C" fn write(&self, buf: &[u8]) -> Result<()> {
Ok(self.write_half()?.write_all(buf).await?)
pub async extern "Lua-C" fn write(&self, buf: &[u8]) -> Result<bool> {
Ok(match self.write_half()?.write_all(buf).await {
Ok(()) => true,
Err(err) if Self::is_disc(err.kind()) => false,
Err(err) => return Err(err.into()),
})
}
/// Writes the given bytes to this stream and returns the number of bytes successfully written.
pub async extern "Lua-C" fn write_partial(&self, buf: &[u8]) -> Result<u32> {
Ok(self.write_half()?.write(buf).await? as u32)
pub async extern "Lua-C" fn write_partial(&self, buf: &[u8]) -> Result<Option<u32>> {
Ok(match self.write_half()?.write(buf).await {
Ok(0) => None,
Ok(n) => Some(n as u32),
Err(err) if Self::is_disc(err.kind()) => None,
Err(err) => return Err(err.into()),
})
}
pub extern "Lua-C" fn try_write(&self, buf: &[u8]) -> Result<Option<u32>> {
Ok(match self.write_half()?.try_write(buf) {
Ok(n) => Some(n as u32),
Err(err) if Self::is_disc(err.kind()) || err.kind() == ErrorKind::WouldBlock => None,
Err(err) => return Err(err.into()),
})
}
/// Peeks up to `len` bytes at incoming data without consuming it.
@@ -863,12 +893,14 @@ impl lb_tcpstream {
/// family of functions.
pub async extern "Lua-C" fn peek(&self, len: u32) -> Result<Option<Vec<u8>>> {
let mut buf = vec![0; len as usize];
let n = self.read_half()?.peek(&mut buf).await?;
Ok(if n == 0 {
None
} else {
buf.truncate(n);
Some(buf)
Ok(match self.read_half()?.peek(&mut buf).await {
Ok(0) => None,
Ok(n) => Some({
buf.truncate(n);
buf
}),
Err(err) if Self::is_disc(err.kind()) => None,
Err(err) => return Err(err.into()),
})
}