diff --git a/Cargo.lock b/Cargo.lock index 2990505..567b57e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1625,6 +1625,7 @@ dependencies = [ "libc", "log", "once_cell", + "parking_lot", "pin-project", "rustls-native-certs", "rustls-pemfile", diff --git a/Cargo.toml b/Cargo.toml index 8779c3c..8d859de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ uuid = { version = "1.5.0", features = ["v7", "serde"] } jsonwebtoken = { version = "9.1.0", default-features = false } rustls-pemfile = { version = "1.0.3", features = [] } bytes = { version = "1.5.0", features = [] } +parking_lot = "0.12.1" rustls-native-certs = { version = "0.6.3", features = [] } tokio = { version = "1.33.0", features = ["full"] } diff --git a/src/udp.rs b/src/udp.rs index 7be195b..7e67f5c 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -1,6 +1,7 @@ use anyhow::Context; use bytes::{Buf, BytesMut}; use futures_util::{stream, Stream}; +use parking_lot::{Mutex, RwLock}; use pin_project::{pin_project, pinned_drop}; use std::collections::hash_map::Entry; use std::collections::{HashMap, VecDeque}; @@ -10,7 +11,7 @@ use std::io::{Error, ErrorKind}; use std::net::SocketAddr; use std::ops::DerefMut; use std::pin::{pin, Pin}; -use std::sync::{Arc, Mutex, RwLock, Weak}; +use std::sync::{Arc, Weak}; use std::task::{Poll, Waker}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; @@ -38,13 +39,13 @@ impl UdpServer { } } fn clean_dead_keys(&mut self) { - let nb_key_to_delete = self.keys_to_delete.read().unwrap().len(); + let nb_key_to_delete = self.keys_to_delete.read().len(); if nb_key_to_delete == 0 { return; } debug!("Cleaning {} dead udp peers", nb_key_to_delete); - let mut keys_to_delete = self.keys_to_delete.write().unwrap(); + let mut keys_to_delete = self.keys_to_delete.write(); for key in keys_to_delete.iter() { self.peers.remove(key); } @@ -69,7 +70,7 @@ pub struct UdpStream { impl PinnedDrop for UdpStream { fn drop(self: Pin<&mut Self>) { if let Some(keys_to_delete) = self.keys_to_delete.upgrade() { - keys_to_delete.write().unwrap().push(self.peer); + keys_to_delete.write().push(self.peer); } } } @@ -90,7 +91,7 @@ impl AsyncRead for UdpStream { } } - let mut guard = project.io.lock().unwrap(); + let mut guard = project.io.lock(); let (ibuf, waker, read_lens) = guard.deref_mut(); if let Some(read_len) = read_lens.pop_front() { if read_len > obuf.remaining() { @@ -161,7 +162,7 @@ pub async fn run_server( match server.peers.entry(peer_addr) { Entry::Occupied(mut peer) => { - let mut guard = peer.get_mut().lock().unwrap(); + let mut guard = peer.get_mut().lock(); let (buf, waker, read_lens) = guard.deref_mut(); // As we have done a peek_sender before, we are sure that there is pending read data // and we don't want to wait to avoid holding the lock across await point @@ -176,7 +177,7 @@ pub async fn run_server( } Err(_) => { drop(guard); - server.keys_to_delete.write().unwrap().push(peer_addr); + server.keys_to_delete.write().push(peer_addr); } } } @@ -299,10 +300,13 @@ mod tests { assert!(client.send_to(b"world".as_ref(), server_addr).await.is_ok()); assert!(client.send_to(b" test".as_ref(), server_addr).await.is_ok()); - // Server need to be polled to feed the stream with need data + // Server need to be polled to feed the stream with needed data let _ = timeout(Duration::from_millis(100), server.next()).await; + // Udp Server should respect framing from the client and not merge the two packets let ret = timeout(Duration::from_millis(100), stream.read(&mut buf[5..])).await; - assert!(matches!(ret, Ok(Ok(10)))); + assert!(matches!(ret, Ok(Ok(5)))); + let ret = timeout(Duration::from_millis(100), stream.read(&mut buf[10..])).await; + assert!(matches!(ret, Ok(Ok(5)))); assert_eq!(&buf[..16], b"helloworld test\0"); }