Correctly handle tcp disconnections

This commit is contained in:
lumi 2025-06-28 21:11:09 +10:00
parent 263ca1cf48
commit 8c406a46b3
Signed by: luaneko
GPG Key ID: 406809B8763FF07A
2 changed files with 127 additions and 23 deletions

View File

@ -13,6 +13,7 @@ use luaffi::{
metatype, metatype,
}; };
use luajit::LUA_NOREF; use luajit::LUA_NOREF;
use std::io::ErrorKind;
use std::{ use std::{
cell::{BorrowError, BorrowMutError, Ref, RefCell, RefMut}, cell::{BorrowError, BorrowMutError, Ref, RefCell, RefMut},
ffi::c_int, ffi::c_int,
@ -794,7 +795,7 @@ impl lb_tcpstream {
Some("read") => Interest::READABLE, Some("read") => Interest::READABLE,
Some("write") => Interest::WRITABLE, Some("write") => Interest::WRITABLE,
_ => Err(std::io::Error::new( _ => Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput, ErrorKind::InvalidInput,
"invalid ready interest", "invalid ready interest",
))?, ))?,
}; };
@ -803,6 +804,16 @@ impl lb_tcpstream {
Ok(()) Ok(())
} }
fn is_disc(err: ErrorKind) -> bool {
matches!(
err,
ErrorKind::ConnectionReset
| ErrorKind::BrokenPipe
| ErrorKind::UnexpectedEof
| ErrorKind::WriteZero
)
}
/// Reads exactly `len` bytes from this stream. /// Reads exactly `len` bytes from this stream.
/// ///
/// If the connection was closed, this returns `nil`. /// If the connection was closed, this returns `nil`.
@ -810,7 +821,7 @@ impl lb_tcpstream {
let mut buf = vec![0; len as usize]; let mut buf = vec![0; len as usize];
Ok(match self.read_half()?.read_exact(&mut buf).await { Ok(match self.read_half()?.read_exact(&mut buf).await {
Ok(_) => Some(buf), Ok(_) => Some(buf),
Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => None, Err(err) if Self::is_disc(err.kind()) => None,
Err(err) => return Err(err.into()), Err(err) => return Err(err.into()),
}) })
} }
@ -821,12 +832,14 @@ impl lb_tcpstream {
/// queue. If the connection was closed, this returns `nil`. /// queue. If the connection was closed, this returns `nil`.
pub async extern "Lua-C" fn read_partial(&self, len: u32) -> Result<Option<Vec<u8>>> { pub async extern "Lua-C" fn read_partial(&self, len: u32) -> Result<Option<Vec<u8>>> {
let mut buf = vec![0; len as usize]; let mut buf = vec![0; len as usize];
let n = self.read_half()?.read(&mut buf).await?; Ok(match self.read_half()?.read(&mut buf).await {
Ok(if n == 0 { Ok(0) => None,
None Ok(n) => Some({
} else { buf.truncate(n);
buf.truncate(n); buf
Some(buf) }),
Err(err) if Self::is_disc(err.kind()) => None,
Err(err) => return Err(err.into()),
}) })
} }
@ -838,23 +851,40 @@ impl lb_tcpstream {
let mut buf = vec![0; len as usize]; let mut buf = vec![0; len as usize];
Ok(match self.read_half()?.try_read(&mut buf) { Ok(match self.read_half()?.try_read(&mut buf) {
Ok(0) => None, Ok(0) => None,
Ok(n) => { Ok(n) => Some({
buf.truncate(n); buf.truncate(n);
Some(buf) buf
} }),
Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => None, Err(err) if Self::is_disc(err.kind()) || err.kind() == ErrorKind::WouldBlock => None,
Err(err) => return Err(err.into()), Err(err) => return Err(err.into()),
}) })
} }
/// Writes the given bytes to this stream. /// Writes the given bytes to this stream.
pub async extern "Lua-C" fn write(&self, buf: &[u8]) -> Result<()> { pub async extern "Lua-C" fn write(&self, buf: &[u8]) -> Result<bool> {
Ok(self.write_half()?.write_all(buf).await?) Ok(match self.write_half()?.write_all(buf).await {
Ok(()) => true,
Err(err) if Self::is_disc(err.kind()) => false,
Err(err) => return Err(err.into()),
})
} }
/// Writes the given bytes to this stream and returns the number of bytes successfully written. /// Writes the given bytes to this stream and returns the number of bytes successfully written.
pub async extern "Lua-C" fn write_partial(&self, buf: &[u8]) -> Result<u32> { pub async extern "Lua-C" fn write_partial(&self, buf: &[u8]) -> Result<Option<u32>> {
Ok(self.write_half()?.write(buf).await? as u32) Ok(match self.write_half()?.write(buf).await {
Ok(0) => None,
Ok(n) => Some(n as u32),
Err(err) if Self::is_disc(err.kind()) => None,
Err(err) => return Err(err.into()),
})
}
pub extern "Lua-C" fn try_write(&self, buf: &[u8]) -> Result<Option<u32>> {
Ok(match self.write_half()?.try_write(buf) {
Ok(n) => Some(n as u32),
Err(err) if Self::is_disc(err.kind()) || err.kind() == ErrorKind::WouldBlock => None,
Err(err) => return Err(err.into()),
})
} }
/// Peeks up to `len` bytes at incoming data without consuming it. /// Peeks up to `len` bytes at incoming data without consuming it.
@ -863,12 +893,14 @@ impl lb_tcpstream {
/// family of functions. /// family of functions.
pub async extern "Lua-C" fn peek(&self, len: u32) -> Result<Option<Vec<u8>>> { pub async extern "Lua-C" fn peek(&self, len: u32) -> Result<Option<Vec<u8>>> {
let mut buf = vec![0; len as usize]; let mut buf = vec![0; len as usize];
let n = self.read_half()?.peek(&mut buf).await?; Ok(match self.read_half()?.peek(&mut buf).await {
Ok(if n == 0 { Ok(0) => None,
None Ok(n) => Some({
} else { buf.truncate(n);
buf.truncate(n); buf
Some(buf) }),
Err(err) if Self::is_disc(err.kind()) => None,
Err(err) => return Err(err.into()),
}) })
} }

View File

@ -79,9 +79,81 @@ describe("tcp", function()
end) end)
end) end)
describe("stream", function()
test("no concurrent two reads/writes", function()
local listener = net.listen_tcp(net.localhost())
local client = net.connect_tcp(listener:local_addr())
local server = listener()
local reader = spawn(function()
assert(client:read(1) == nil) -- this should block first, then return nil from disconnection
end)
spawn(function()
assert(not pcall(client.read, client, 1)) -- this should fail, since the first task is still reading
end):await()
server:shutdown()
reader:await()
end)
test("allow concurrent read/write", function()
local listener = net.listen_tcp(net.localhost())
local client = net.connect_tcp(listener:local_addr())
local server = listener()
local reader = spawn(function()
assert(client:read(1) == nil) -- this should block first, then return nil from disconnection
end)
spawn(function()
client:write("hello") -- should be able to write while the first task is reading
end):await()
server:shutdown()
reader:await()
end)
test("stop reading from disconnected stream", function()
local listener = net.listen_tcp(net.localhost())
local client = net.connect_tcp(listener:local_addr())
local server = listener()
local reader = spawn(function()
while client:read(4) ~= nil do
end
assert(client:try_read(4) == nil)
assert(client:read_partial(4) == nil)
assert(client:read(4) == nil)
end)
for _ = 1, 10 do
assert(server:write("ping") == true)
end
sleep(100)
server:shutdown()
server = nil
collectgarbage()
reader:await()
end)
test("stop writing to disconnected stream", function()
local listener = net.listen_tcp(net.localhost())
local client = net.connect_tcp(listener:local_addr())
local server = listener()
local writer = spawn(function()
while client:write("pong") do
end
assert(client:try_write("pong") == nil)
assert(client:write_partial("pong") == nil)
assert(client:write("pong") == false)
end)
for _ = 1, 10 do
assert(server:read(4) == "pong")
end
sleep(100)
server:shutdown()
server = nil
collectgarbage()
writer:await()
end)
end)
describe("listener", function() describe("listener", function()
test("accept", function() test("accept", function()
local listener = net.listen_tcp("127.0.0.1", 0) local listener = net.listen_tcp(net.localhost())
local addr = listener:local_addr() local addr = listener:local_addr()
local accepted = false local accepted = false
local client = net.tcp() local client = net.tcp()