diff --git a/src/udp.rs b/src/udp.rs index 6f07bdf..ed29bb5 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -142,7 +142,8 @@ impl AsyncRead for UdpStream { project.pending_notification.as_mut().set(None); } - let _ = ready!(project.socket.poll_recv(cx, obuf)); + let peer = ready!(project.socket.poll_recv_from(cx, obuf))?; + debug_assert_eq!(peer, *project.peer); let notified: Notified<'static> = unsafe { std::mem::transmute(project.io.has_data_to_read.notified()) }; project.pending_notification.as_mut().set(Some(notified)); project.io.has_read_data.notify_one(); @@ -179,7 +180,14 @@ pub async fn run_server( .with_context(|| format!("Cannot create UDP server {:?}", bind))?; let udp_server = UdpServer::new(Arc::new(listener), timeout); - let stream = stream::unfold(udp_server, |mut server| async { + let stream = stream::unfold((udp_server, None), |(mut server, peer_with_data)| async move { + // New returned peer hasn't read its data yet, await for it. + if let Some(await_peer) = peer_with_data { + if let Some(peer) = server.peers.get(&await_peer) { + peer.has_read_data.notified().await; + } + }; + loop { server.clean_dead_keys(); let peer_addr = match server.listener.peek_sender().await { @@ -192,8 +200,8 @@ pub async fn run_server( match server.peers.get(&peer_addr) { Some(io) => { - io.has_read_data.notified().await; io.has_data_to_read.notify_one(); + io.has_read_data.notified().await; } None => { let (udp_client, io) = UdpStream::new( @@ -206,7 +214,7 @@ pub async fn run_server( Arc::downgrade(&server.keys_to_delete), ); server.peers.insert(peer_addr, io); - return Some((Ok(udp_client), (server))); + return Some((Ok(udp_client), (server, Some(peer_addr)))); } } } @@ -360,8 +368,7 @@ mod tests { #[tokio::test] async fn test_multiple_client() { let server_addr: SocketAddr = "[::1]:1235".parse().unwrap(); - let server = run_server(server_addr, None).await.unwrap(); - pin_mut!(server); + let mut server = Box::pin(run_server(server_addr, None).await.unwrap()); // Send some data to the server let client = UdpSocket::bind("[::1]:0").await.unwrap(); @@ -383,21 +390,28 @@ mod tests { assert!(matches!(ret, Ok(5))); assert_eq!(&buf[..6], b"aaaaa\0"); + // make the server make progress let fut2 = timeout(Duration::from_millis(100), server.next()).await; assert!(matches!(fut2, Ok(Some(Ok(_))))); let stream2 = fut2.unwrap().unwrap().unwrap(); pin_mut!(stream2); + // let the server make progress + tokio::spawn(async move { + loop { + let _ = server.next().await; + } + }); + let ret = stream2.read(&mut buf).await; assert!(matches!(ret, Ok(5))); assert_eq!(&buf[..6], b"bbbbb\0"); assert!(client.send_to(b"ccccc".as_ref(), server_addr).await.is_ok()); assert!(client2.send_to(b"ddddd".as_ref(), server_addr).await.is_ok()); - - // Server need to be polled to feed the stream with need data - let _ = timeout(Duration::from_millis(100), server.next()).await; + assert!(client2.send_to(b"eeeee".as_ref(), server_addr).await.is_ok()); + assert!(client.send_to(b"fffff".as_ref(), server_addr).await.is_ok()); let ret = stream.read(&mut buf).await; assert!(matches!(ret, Ok(5))); @@ -406,6 +420,14 @@ mod tests { let ret = stream2.read(&mut buf).await; assert!(matches!(ret, Ok(5))); assert_eq!(&buf[..6], b"ddddd\0"); + + let ret = stream2.read(&mut buf).await; + assert!(matches!(ret, Ok(5))); + assert_eq!(&buf[..6], b"eeeee\0"); + + let ret = stream.read(&mut buf).await; + assert!(matches!(ret, Ok(5))); + assert_eq!(&buf[..6], b"fffff\0"); } #[tokio::test]