Correctly handle tcp disconnections
This commit is contained in:
parent
263ca1cf48
commit
8c406a46b3
@ -13,6 +13,7 @@ use luaffi::{
|
||||
metatype,
|
||||
};
|
||||
use luajit::LUA_NOREF;
|
||||
use std::io::ErrorKind;
|
||||
use std::{
|
||||
cell::{BorrowError, BorrowMutError, Ref, RefCell, RefMut},
|
||||
ffi::c_int,
|
||||
@ -794,7 +795,7 @@ impl lb_tcpstream {
|
||||
Some("read") => Interest::READABLE,
|
||||
Some("write") => Interest::WRITABLE,
|
||||
_ => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidInput,
|
||||
ErrorKind::InvalidInput,
|
||||
"invalid ready interest",
|
||||
))?,
|
||||
};
|
||||
@ -803,6 +804,16 @@ impl lb_tcpstream {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_disc(err: ErrorKind) -> bool {
|
||||
matches!(
|
||||
err,
|
||||
ErrorKind::ConnectionReset
|
||||
| ErrorKind::BrokenPipe
|
||||
| ErrorKind::UnexpectedEof
|
||||
| ErrorKind::WriteZero
|
||||
)
|
||||
}
|
||||
|
||||
/// Reads exactly `len` bytes from this stream.
|
||||
///
|
||||
/// If the connection was closed, this returns `nil`.
|
||||
@ -810,7 +821,7 @@ impl lb_tcpstream {
|
||||
let mut buf = vec![0; len as usize];
|
||||
Ok(match self.read_half()?.read_exact(&mut buf).await {
|
||||
Ok(_) => Some(buf),
|
||||
Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => None,
|
||||
Err(err) if Self::is_disc(err.kind()) => None,
|
||||
Err(err) => return Err(err.into()),
|
||||
})
|
||||
}
|
||||
@ -821,12 +832,14 @@ impl lb_tcpstream {
|
||||
/// queue. If the connection was closed, this returns `nil`.
|
||||
pub async extern "Lua-C" fn read_partial(&self, len: u32) -> Result<Option<Vec<u8>>> {
|
||||
let mut buf = vec![0; len as usize];
|
||||
let n = self.read_half()?.read(&mut buf).await?;
|
||||
Ok(if n == 0 {
|
||||
None
|
||||
} else {
|
||||
buf.truncate(n);
|
||||
Some(buf)
|
||||
Ok(match self.read_half()?.read(&mut buf).await {
|
||||
Ok(0) => None,
|
||||
Ok(n) => Some({
|
||||
buf.truncate(n);
|
||||
buf
|
||||
}),
|
||||
Err(err) if Self::is_disc(err.kind()) => None,
|
||||
Err(err) => return Err(err.into()),
|
||||
})
|
||||
}
|
||||
|
||||
@ -838,23 +851,40 @@ impl lb_tcpstream {
|
||||
let mut buf = vec![0; len as usize];
|
||||
Ok(match self.read_half()?.try_read(&mut buf) {
|
||||
Ok(0) => None,
|
||||
Ok(n) => {
|
||||
Ok(n) => Some({
|
||||
buf.truncate(n);
|
||||
Some(buf)
|
||||
}
|
||||
Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => None,
|
||||
buf
|
||||
}),
|
||||
Err(err) if Self::is_disc(err.kind()) || err.kind() == ErrorKind::WouldBlock => None,
|
||||
Err(err) => return Err(err.into()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Writes the given bytes to this stream.
|
||||
pub async extern "Lua-C" fn write(&self, buf: &[u8]) -> Result<()> {
|
||||
Ok(self.write_half()?.write_all(buf).await?)
|
||||
pub async extern "Lua-C" fn write(&self, buf: &[u8]) -> Result<bool> {
|
||||
Ok(match self.write_half()?.write_all(buf).await {
|
||||
Ok(()) => true,
|
||||
Err(err) if Self::is_disc(err.kind()) => false,
|
||||
Err(err) => return Err(err.into()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Writes the given bytes to this stream and returns the number of bytes successfully written.
|
||||
pub async extern "Lua-C" fn write_partial(&self, buf: &[u8]) -> Result<u32> {
|
||||
Ok(self.write_half()?.write(buf).await? as u32)
|
||||
pub async extern "Lua-C" fn write_partial(&self, buf: &[u8]) -> Result<Option<u32>> {
|
||||
Ok(match self.write_half()?.write(buf).await {
|
||||
Ok(0) => None,
|
||||
Ok(n) => Some(n as u32),
|
||||
Err(err) if Self::is_disc(err.kind()) => None,
|
||||
Err(err) => return Err(err.into()),
|
||||
})
|
||||
}
|
||||
|
||||
pub extern "Lua-C" fn try_write(&self, buf: &[u8]) -> Result<Option<u32>> {
|
||||
Ok(match self.write_half()?.try_write(buf) {
|
||||
Ok(n) => Some(n as u32),
|
||||
Err(err) if Self::is_disc(err.kind()) || err.kind() == ErrorKind::WouldBlock => None,
|
||||
Err(err) => return Err(err.into()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Peeks up to `len` bytes at incoming data without consuming it.
|
||||
@ -863,12 +893,14 @@ impl lb_tcpstream {
|
||||
/// family of functions.
|
||||
pub async extern "Lua-C" fn peek(&self, len: u32) -> Result<Option<Vec<u8>>> {
|
||||
let mut buf = vec![0; len as usize];
|
||||
let n = self.read_half()?.peek(&mut buf).await?;
|
||||
Ok(if n == 0 {
|
||||
None
|
||||
} else {
|
||||
buf.truncate(n);
|
||||
Some(buf)
|
||||
Ok(match self.read_half()?.peek(&mut buf).await {
|
||||
Ok(0) => None,
|
||||
Ok(n) => Some({
|
||||
buf.truncate(n);
|
||||
buf
|
||||
}),
|
||||
Err(err) if Self::is_disc(err.kind()) => None,
|
||||
Err(err) => return Err(err.into()),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -79,9 +79,81 @@ describe("tcp", function()
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("stream", function()
|
||||
test("no concurrent two reads/writes", function()
|
||||
local listener = net.listen_tcp(net.localhost())
|
||||
local client = net.connect_tcp(listener:local_addr())
|
||||
local server = listener()
|
||||
local reader = spawn(function()
|
||||
assert(client:read(1) == nil) -- this should block first, then return nil from disconnection
|
||||
end)
|
||||
spawn(function()
|
||||
assert(not pcall(client.read, client, 1)) -- this should fail, since the first task is still reading
|
||||
end):await()
|
||||
server:shutdown()
|
||||
reader:await()
|
||||
end)
|
||||
|
||||
test("allow concurrent read/write", function()
|
||||
local listener = net.listen_tcp(net.localhost())
|
||||
local client = net.connect_tcp(listener:local_addr())
|
||||
local server = listener()
|
||||
local reader = spawn(function()
|
||||
assert(client:read(1) == nil) -- this should block first, then return nil from disconnection
|
||||
end)
|
||||
spawn(function()
|
||||
client:write("hello") -- should be able to write while the first task is reading
|
||||
end):await()
|
||||
server:shutdown()
|
||||
reader:await()
|
||||
end)
|
||||
|
||||
test("stop reading from disconnected stream", function()
|
||||
local listener = net.listen_tcp(net.localhost())
|
||||
local client = net.connect_tcp(listener:local_addr())
|
||||
local server = listener()
|
||||
local reader = spawn(function()
|
||||
while client:read(4) ~= nil do
|
||||
end
|
||||
assert(client:try_read(4) == nil)
|
||||
assert(client:read_partial(4) == nil)
|
||||
assert(client:read(4) == nil)
|
||||
end)
|
||||
for _ = 1, 10 do
|
||||
assert(server:write("ping") == true)
|
||||
end
|
||||
sleep(100)
|
||||
server:shutdown()
|
||||
server = nil
|
||||
collectgarbage()
|
||||
reader:await()
|
||||
end)
|
||||
|
||||
test("stop writing to disconnected stream", function()
|
||||
local listener = net.listen_tcp(net.localhost())
|
||||
local client = net.connect_tcp(listener:local_addr())
|
||||
local server = listener()
|
||||
local writer = spawn(function()
|
||||
while client:write("pong") do
|
||||
end
|
||||
assert(client:try_write("pong") == nil)
|
||||
assert(client:write_partial("pong") == nil)
|
||||
assert(client:write("pong") == false)
|
||||
end)
|
||||
for _ = 1, 10 do
|
||||
assert(server:read(4) == "pong")
|
||||
end
|
||||
sleep(100)
|
||||
server:shutdown()
|
||||
server = nil
|
||||
collectgarbage()
|
||||
writer:await()
|
||||
end)
|
||||
end)
|
||||
|
||||
describe("listener", function()
|
||||
test("accept", function()
|
||||
local listener = net.listen_tcp("127.0.0.1", 0)
|
||||
local listener = net.listen_tcp(net.localhost())
|
||||
local addr = listener:local_addr()
|
||||
local accepted = false
|
||||
local client = net.tcp()
|
||||
|
Loading…
x
Reference in New Issue
Block a user