diff --git a/src/protocols/socks5/mod.rs b/src/protocols/socks5/mod.rs index 3214c0b..762f84e 100644 --- a/src/protocols/socks5/mod.rs +++ b/src/protocols/socks5/mod.rs @@ -3,4 +3,5 @@ mod udp_server; pub use tcp_server::run_server; pub use tcp_server::Socks5Listener; -pub use tcp_server::Socks5Stream; +pub use tcp_server::Socks5ReadHalf; +pub use tcp_server::Socks5WriteHalf; diff --git a/src/protocols/socks5/tcp_server.rs b/src/protocols/socks5/tcp_server.rs index 158fca8..6722a2d 100644 --- a/src/protocols/socks5/tcp_server.rs +++ b/src/protocols/socks5/tcp_server.rs @@ -1,4 +1,4 @@ -use super::udp_server::Socks5UdpStream; +use super::udp_server::{Socks5UdpStream, Socks5UdpStreamWriter}; use crate::LocalProtocol; use anyhow::Context; use fast_socks5::server::{Config, DenyAuthentication, SimpleUserPassword, Socks5Server}; @@ -11,6 +11,7 @@ use std::pin::Pin; use std::task::Poll; use std::time::Duration; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; use tokio::select; use tracing::{info, warn}; @@ -21,9 +22,19 @@ pub struct Socks5Listener { socks_server: Pin> + Send>>, } +pub enum Socks5ReadHalf { + Tcp(OwnedReadHalf), + Udp(Socks5UdpStream), +} + +pub enum Socks5WriteHalf { + Tcp(OwnedWriteHalf), + Udp(Socks5UdpStreamWriter), +} + pub enum Socks5Stream { Tcp(TcpStream), - Udp(Socks5UdpStream), + Udp((Socks5UdpStream, Socks5UdpStreamWriter)), } impl Socks5Stream { @@ -31,10 +42,20 @@ impl Socks5Stream { match self { Self::Tcp(_) => LocalProtocol::Tcp { proxy_protocol: false }, // TODO: Implement proxy protocol Self::Udp(s) => LocalProtocol::Udp { - timeout: s.watchdog_deadline.as_ref().map(|x| x.period()), + timeout: s.0.watchdog_deadline.as_ref().map(|x| x.period()), }, } } + + pub fn into_split(self) -> (Socks5ReadHalf, Socks5WriteHalf) { + match self { + Self::Tcp(s) => { + let (r, w) = s.into_split(); + (Socks5ReadHalf::Tcp(r), Socks5WriteHalf::Tcp(w)) + } + Self::Udp((r, w)) => (Socks5ReadHalf::Udp(r), Socks5WriteHalf::Udp(w)), + } + } } impl Stream for Socks5Listener { @@ -95,7 +116,8 @@ pub async fn run_server( return match udp_conn { Some(Ok(stream)) => { let dest = stream.destination(); - Some((Ok((Socks5Stream::Udp(stream), dest)), (server, udp_server))) + let writer = stream.writer(); + Some((Ok((Socks5Stream::Udp((stream, writer)), dest)), (server, udp_server))) } Some(Err(err)) => { Some((Err(anyhow::Error::new(err)), (server, udp_server))) @@ -200,7 +222,7 @@ fn new_reply(error: &ReplyError, sock_addr: SocketAddr) -> Vec { } impl Unpin for Socks5Stream {} -impl AsyncRead for Socks5Stream { +impl AsyncRead for Socks5ReadHalf { fn poll_read( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -213,7 +235,7 @@ impl AsyncRead for Socks5Stream { } } -impl AsyncWrite for Socks5Stream { +impl AsyncWrite for Socks5WriteHalf { fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll> { match self.get_mut() { Self::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf), diff --git a/src/protocols/socks5/udp_server.rs b/src/protocols/socks5/udp_server.rs index c3d98e3..d848a35 100644 --- a/src/protocols/socks5/udp_server.rs +++ b/src/protocols/socks5/udp_server.rs @@ -93,6 +93,12 @@ impl Socks5UdpServer { } } +pub struct Socks5UdpStreamWriter { + send_socket: Arc, + peer: SocketAddr, + udp_header: Vec, +} + #[pin_project(PinnedDrop)] pub struct Socks5UdpStream { #[pin] @@ -153,6 +159,14 @@ impl Socks5UdpStream { TargetAddr::Domain(h, p) => (Host::Domain(h.clone()), *p), } } + + pub fn writer(&self) -> Socks5UdpStreamWriter { + Socks5UdpStreamWriter { + send_socket: self.send_socket.clone(), + peer: self.peer, + udp_header: self.udp_header.clone(), + } + } } impl AsyncRead for Socks5UdpStream { @@ -194,15 +208,12 @@ impl AsyncRead for Socks5UdpStream { } } -impl AsyncWrite for Socks5UdpStream { - fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll> { - let this = self.project(); - let header_len = this.udp_header.len(); - this.udp_header.extend_from_slice(buf); - let ret = this - .send_socket - .poll_send_to(cx, this.udp_header.as_slice(), *this.peer); - this.udp_header.truncate(header_len); +impl AsyncWrite for Socks5UdpStreamWriter { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll> { + let header_len = self.udp_header.len(); + self.udp_header.extend_from_slice(buf); + let ret = self.send_socket.poll_send_to(cx, self.udp_header.as_slice(), self.peer); + self.udp_header.truncate(header_len); ret.map(|r| r.map(|write_len| write_len - header_len)) } diff --git a/src/tunnel/listeners/socks5.rs b/src/tunnel/listeners/socks5.rs index 9c30fde..57ca1a2 100644 --- a/src/tunnel/listeners/socks5.rs +++ b/src/tunnel/listeners/socks5.rs @@ -1,12 +1,11 @@ use crate::protocols::socks5; -use crate::protocols::socks5::{Socks5Listener, Socks5Stream}; +use crate::protocols::socks5::{Socks5Listener, Socks5ReadHalf, Socks5WriteHalf}; use crate::tunnel::RemoteAddr; use anyhow::{anyhow, Context}; use std::net::SocketAddr; use std::pin::Pin; use std::task::{ready, Poll}; use std::time::Duration; -use tokio::io::{ReadHalf, WriteHalf}; use tokio_stream::Stream; pub struct Socks5TunnelListener { @@ -28,7 +27,7 @@ impl Socks5TunnelListener { } impl Stream for Socks5TunnelListener { - type Item = anyhow::Result<((ReadHalf, WriteHalf), RemoteAddr)>; + type Item = anyhow::Result<((Socks5ReadHalf, Socks5WriteHalf), RemoteAddr)>; fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { let this = self.get_mut(); @@ -37,7 +36,7 @@ impl Stream for Socks5TunnelListener { let ret = match ret { Some(Ok((stream, (host, port)))) => { let protocol = stream.local_protocol(); - Some(anyhow::Ok((tokio::io::split(stream), RemoteAddr { protocol, host, port }))) + Some(anyhow::Ok((stream.into_split(), RemoteAddr { protocol, host, port }))) } Some(Err(err)) => Some(Err(err)), None => None,