diff --git a/crates/lb/src/net.rs b/crates/lb/src/net.rs index 899d4ce..5783bcc 100644 --- a/crates/lb/src/net.rs +++ b/crates/lb/src/net.rs @@ -13,6 +13,7 @@ use luaffi::{ metatype, }; use luajit::LUA_NOREF; +use std::io::ErrorKind; use std::{ cell::{BorrowError, BorrowMutError, Ref, RefCell, RefMut}, ffi::c_int, @@ -794,7 +795,7 @@ impl lb_tcpstream { Some("read") => Interest::READABLE, Some("write") => Interest::WRITABLE, _ => Err(std::io::Error::new( - std::io::ErrorKind::InvalidInput, + ErrorKind::InvalidInput, "invalid ready interest", ))?, }; @@ -803,6 +804,16 @@ impl lb_tcpstream { Ok(()) } + fn is_disc(err: ErrorKind) -> bool { + matches!( + err, + ErrorKind::ConnectionReset + | ErrorKind::BrokenPipe + | ErrorKind::UnexpectedEof + | ErrorKind::WriteZero + ) + } + /// Reads exactly `len` bytes from this stream. /// /// If the connection was closed, this returns `nil`. @@ -810,7 +821,7 @@ impl lb_tcpstream { let mut buf = vec![0; len as usize]; Ok(match self.read_half()?.read_exact(&mut buf).await { Ok(_) => Some(buf), - Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => None, + Err(err) if Self::is_disc(err.kind()) => None, Err(err) => return Err(err.into()), }) } @@ -821,12 +832,14 @@ impl lb_tcpstream { /// queue. If the connection was closed, this returns `nil`. pub async extern "Lua-C" fn read_partial(&self, len: u32) -> Result>> { let mut buf = vec![0; len as usize]; - let n = self.read_half()?.read(&mut buf).await?; - Ok(if n == 0 { - None - } else { - buf.truncate(n); - Some(buf) + Ok(match self.read_half()?.read(&mut buf).await { + Ok(0) => None, + Ok(n) => Some({ + buf.truncate(n); + buf + }), + Err(err) if Self::is_disc(err.kind()) => None, + Err(err) => return Err(err.into()), }) } @@ -838,23 +851,40 @@ impl lb_tcpstream { let mut buf = vec![0; len as usize]; Ok(match self.read_half()?.try_read(&mut buf) { Ok(0) => None, - Ok(n) => { + Ok(n) => Some({ buf.truncate(n); - Some(buf) - } - Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => None, + buf + }), + Err(err) if Self::is_disc(err.kind()) || err.kind() == ErrorKind::WouldBlock => None, Err(err) => return Err(err.into()), }) } /// Writes the given bytes to this stream. - pub async extern "Lua-C" fn write(&self, buf: &[u8]) -> Result<()> { - Ok(self.write_half()?.write_all(buf).await?) + pub async extern "Lua-C" fn write(&self, buf: &[u8]) -> Result { + Ok(match self.write_half()?.write_all(buf).await { + Ok(()) => true, + Err(err) if Self::is_disc(err.kind()) => false, + Err(err) => return Err(err.into()), + }) } /// Writes the given bytes to this stream and returns the number of bytes successfully written. - pub async extern "Lua-C" fn write_partial(&self, buf: &[u8]) -> Result { - Ok(self.write_half()?.write(buf).await? as u32) + pub async extern "Lua-C" fn write_partial(&self, buf: &[u8]) -> Result> { + Ok(match self.write_half()?.write(buf).await { + Ok(0) => None, + Ok(n) => Some(n as u32), + Err(err) if Self::is_disc(err.kind()) => None, + Err(err) => return Err(err.into()), + }) + } + + pub extern "Lua-C" fn try_write(&self, buf: &[u8]) -> Result> { + Ok(match self.write_half()?.try_write(buf) { + Ok(n) => Some(n as u32), + Err(err) if Self::is_disc(err.kind()) || err.kind() == ErrorKind::WouldBlock => None, + Err(err) => return Err(err.into()), + }) } /// Peeks up to `len` bytes at incoming data without consuming it. @@ -863,12 +893,14 @@ impl lb_tcpstream { /// family of functions. pub async extern "Lua-C" fn peek(&self, len: u32) -> Result>> { let mut buf = vec![0; len as usize]; - let n = self.read_half()?.peek(&mut buf).await?; - Ok(if n == 0 { - None - } else { - buf.truncate(n); - Some(buf) + Ok(match self.read_half()?.peek(&mut buf).await { + Ok(0) => None, + Ok(n) => Some({ + buf.truncate(n); + buf + }), + Err(err) if Self::is_disc(err.kind()) => None, + Err(err) => return Err(err.into()), }) } diff --git a/crates/lb/tests/net.lua b/crates/lb/tests/net.lua index f65bbbe..d157f90 100644 --- a/crates/lb/tests/net.lua +++ b/crates/lb/tests/net.lua @@ -79,9 +79,81 @@ describe("tcp", function() end) end) + describe("stream", function() + test("no concurrent two reads/writes", function() + local listener = net.listen_tcp(net.localhost()) + local client = net.connect_tcp(listener:local_addr()) + local server = listener() + local reader = spawn(function() + assert(client:read(1) == nil) -- this should block first, then return nil from disconnection + end) + spawn(function() + assert(not pcall(client.read, client, 1)) -- this should fail, since the first task is still reading + end):await() + server:shutdown() + reader:await() + end) + + test("allow concurrent read/write", function() + local listener = net.listen_tcp(net.localhost()) + local client = net.connect_tcp(listener:local_addr()) + local server = listener() + local reader = spawn(function() + assert(client:read(1) == nil) -- this should block first, then return nil from disconnection + end) + spawn(function() + client:write("hello") -- should be able to write while the first task is reading + end):await() + server:shutdown() + reader:await() + end) + + test("stop reading from disconnected stream", function() + local listener = net.listen_tcp(net.localhost()) + local client = net.connect_tcp(listener:local_addr()) + local server = listener() + local reader = spawn(function() + while client:read(4) ~= nil do + end + assert(client:try_read(4) == nil) + assert(client:read_partial(4) == nil) + assert(client:read(4) == nil) + end) + for _ = 1, 10 do + assert(server:write("ping") == true) + end + sleep(100) + server:shutdown() + server = nil + collectgarbage() + reader:await() + end) + + test("stop writing to disconnected stream", function() + local listener = net.listen_tcp(net.localhost()) + local client = net.connect_tcp(listener:local_addr()) + local server = listener() + local writer = spawn(function() + while client:write("pong") do + end + assert(client:try_write("pong") == nil) + assert(client:write_partial("pong") == nil) + assert(client:write("pong") == false) + end) + for _ = 1, 10 do + assert(server:read(4) == "pong") + end + sleep(100) + server:shutdown() + server = nil + collectgarbage() + writer:await() + end) + end) + describe("listener", function() test("accept", function() - local listener = net.listen_tcp("127.0.0.1", 0) + local listener = net.listen_tcp(net.localhost()) local addr = listener:local_addr() local accepted = false local client = net.tcp()