refacto: avoid lock in socks5 stream

This commit is contained in:
Σrebe - Romain GERARD 2024-07-31 08:35:45 +02:00
parent 0dded01b7f
commit 58c34ccc41
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
4 changed files with 53 additions and 20 deletions

View file

@ -3,4 +3,5 @@ mod udp_server;
pub use tcp_server::run_server; pub use tcp_server::run_server;
pub use tcp_server::Socks5Listener; pub use tcp_server::Socks5Listener;
pub use tcp_server::Socks5Stream; pub use tcp_server::Socks5ReadHalf;
pub use tcp_server::Socks5WriteHalf;

View file

@ -1,4 +1,4 @@
use super::udp_server::Socks5UdpStream; use super::udp_server::{Socks5UdpStream, Socks5UdpStreamWriter};
use crate::LocalProtocol; use crate::LocalProtocol;
use anyhow::Context; use anyhow::Context;
use fast_socks5::server::{Config, DenyAuthentication, SimpleUserPassword, Socks5Server}; use fast_socks5::server::{Config, DenyAuthentication, SimpleUserPassword, Socks5Server};
@ -11,6 +11,7 @@ use std::pin::Pin;
use std::task::Poll; use std::task::Poll;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::select; use tokio::select;
use tracing::{info, warn}; use tracing::{info, warn};
@ -21,9 +22,19 @@ pub struct Socks5Listener {
socks_server: Pin<Box<dyn Stream<Item = anyhow::Result<(Socks5Stream, (Host, u16))>> + Send>>, socks_server: Pin<Box<dyn Stream<Item = anyhow::Result<(Socks5Stream, (Host, u16))>> + Send>>,
} }
pub enum Socks5ReadHalf {
Tcp(OwnedReadHalf),
Udp(Socks5UdpStream),
}
pub enum Socks5WriteHalf {
Tcp(OwnedWriteHalf),
Udp(Socks5UdpStreamWriter),
}
pub enum Socks5Stream { pub enum Socks5Stream {
Tcp(TcpStream), Tcp(TcpStream),
Udp(Socks5UdpStream), Udp((Socks5UdpStream, Socks5UdpStreamWriter)),
} }
impl Socks5Stream { impl Socks5Stream {
@ -31,10 +42,20 @@ impl Socks5Stream {
match self { match self {
Self::Tcp(_) => LocalProtocol::Tcp { proxy_protocol: false }, // TODO: Implement proxy protocol Self::Tcp(_) => LocalProtocol::Tcp { proxy_protocol: false }, // TODO: Implement proxy protocol
Self::Udp(s) => LocalProtocol::Udp { Self::Udp(s) => LocalProtocol::Udp {
timeout: s.watchdog_deadline.as_ref().map(|x| x.period()), timeout: s.0.watchdog_deadline.as_ref().map(|x| x.period()),
}, },
} }
} }
pub fn into_split(self) -> (Socks5ReadHalf, Socks5WriteHalf) {
match self {
Self::Tcp(s) => {
let (r, w) = s.into_split();
(Socks5ReadHalf::Tcp(r), Socks5WriteHalf::Tcp(w))
}
Self::Udp((r, w)) => (Socks5ReadHalf::Udp(r), Socks5WriteHalf::Udp(w)),
}
}
} }
impl Stream for Socks5Listener { impl Stream for Socks5Listener {
@ -95,7 +116,8 @@ pub async fn run_server(
return match udp_conn { return match udp_conn {
Some(Ok(stream)) => { Some(Ok(stream)) => {
let dest = stream.destination(); let dest = stream.destination();
Some((Ok((Socks5Stream::Udp(stream), dest)), (server, udp_server))) let writer = stream.writer();
Some((Ok((Socks5Stream::Udp((stream, writer)), dest)), (server, udp_server)))
} }
Some(Err(err)) => { Some(Err(err)) => {
Some((Err(anyhow::Error::new(err)), (server, udp_server))) Some((Err(anyhow::Error::new(err)), (server, udp_server)))
@ -200,7 +222,7 @@ fn new_reply(error: &ReplyError, sock_addr: SocketAddr) -> Vec<u8> {
} }
impl Unpin for Socks5Stream {} impl Unpin for Socks5Stream {}
impl AsyncRead for Socks5Stream { impl AsyncRead for Socks5ReadHalf {
fn poll_read( fn poll_read(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
@ -213,7 +235,7 @@ impl AsyncRead for Socks5Stream {
} }
} }
impl AsyncWrite for Socks5Stream { impl AsyncWrite for Socks5WriteHalf {
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> { fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
match self.get_mut() { match self.get_mut() {
Self::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf), Self::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf),

View file

@ -93,6 +93,12 @@ impl Socks5UdpServer {
} }
} }
pub struct Socks5UdpStreamWriter {
send_socket: Arc<UdpSocket>,
peer: SocketAddr,
udp_header: Vec<u8>,
}
#[pin_project(PinnedDrop)] #[pin_project(PinnedDrop)]
pub struct Socks5UdpStream { pub struct Socks5UdpStream {
#[pin] #[pin]
@ -153,6 +159,14 @@ impl Socks5UdpStream {
TargetAddr::Domain(h, p) => (Host::Domain(h.clone()), *p), TargetAddr::Domain(h, p) => (Host::Domain(h.clone()), *p),
} }
} }
pub fn writer(&self) -> Socks5UdpStreamWriter {
Socks5UdpStreamWriter {
send_socket: self.send_socket.clone(),
peer: self.peer,
udp_header: self.udp_header.clone(),
}
}
} }
impl AsyncRead for Socks5UdpStream { impl AsyncRead for Socks5UdpStream {
@ -194,15 +208,12 @@ impl AsyncRead for Socks5UdpStream {
} }
} }
impl AsyncWrite for Socks5UdpStream { impl AsyncWrite for Socks5UdpStreamWriter {
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> { fn poll_write(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
let this = self.project(); let header_len = self.udp_header.len();
let header_len = this.udp_header.len(); self.udp_header.extend_from_slice(buf);
this.udp_header.extend_from_slice(buf); let ret = self.send_socket.poll_send_to(cx, self.udp_header.as_slice(), self.peer);
let ret = this self.udp_header.truncate(header_len);
.send_socket
.poll_send_to(cx, this.udp_header.as_slice(), *this.peer);
this.udp_header.truncate(header_len);
ret.map(|r| r.map(|write_len| write_len - header_len)) ret.map(|r| r.map(|write_len| write_len - header_len))
} }

View file

@ -1,12 +1,11 @@
use crate::protocols::socks5; use crate::protocols::socks5;
use crate::protocols::socks5::{Socks5Listener, Socks5Stream}; use crate::protocols::socks5::{Socks5Listener, Socks5ReadHalf, Socks5WriteHalf};
use crate::tunnel::RemoteAddr; use crate::tunnel::RemoteAddr;
use anyhow::{anyhow, Context}; use anyhow::{anyhow, Context};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::pin::Pin; use std::pin::Pin;
use std::task::{ready, Poll}; use std::task::{ready, Poll};
use std::time::Duration; use std::time::Duration;
use tokio::io::{ReadHalf, WriteHalf};
use tokio_stream::Stream; use tokio_stream::Stream;
pub struct Socks5TunnelListener { pub struct Socks5TunnelListener {
@ -28,7 +27,7 @@ impl Socks5TunnelListener {
} }
impl Stream for Socks5TunnelListener { impl Stream for Socks5TunnelListener {
type Item = anyhow::Result<((ReadHalf<Socks5Stream>, WriteHalf<Socks5Stream>), RemoteAddr)>; type Item = anyhow::Result<((Socks5ReadHalf, Socks5WriteHalf), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut(); let this = self.get_mut();
@ -37,7 +36,7 @@ impl Stream for Socks5TunnelListener {
let ret = match ret { let ret = match ret {
Some(Ok((stream, (host, port)))) => { Some(Ok((stream, (host, port)))) => {
let protocol = stream.local_protocol(); let protocol = stream.local_protocol();
Some(anyhow::Ok((tokio::io::split(stream), RemoteAddr { protocol, host, port }))) Some(anyhow::Ok((stream.into_split(), RemoteAddr { protocol, host, port })))
} }
Some(Err(err)) => Some(Err(err)), Some(Err(err)) => Some(Err(err)),
None => None, None => None,