refacto: avoid lock in socks5 stream
This commit is contained in:
parent
0dded01b7f
commit
58c34ccc41
4 changed files with 53 additions and 20 deletions
|
@ -3,4 +3,5 @@ mod udp_server;
|
|||
|
||||
pub use tcp_server::run_server;
|
||||
pub use tcp_server::Socks5Listener;
|
||||
pub use tcp_server::Socks5Stream;
|
||||
pub use tcp_server::Socks5ReadHalf;
|
||||
pub use tcp_server::Socks5WriteHalf;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::udp_server::Socks5UdpStream;
|
||||
use super::udp_server::{Socks5UdpStream, Socks5UdpStreamWriter};
|
||||
use crate::LocalProtocol;
|
||||
use anyhow::Context;
|
||||
use fast_socks5::server::{Config, DenyAuthentication, SimpleUserPassword, Socks5Server};
|
||||
|
@ -11,6 +11,7 @@ use std::pin::Pin;
|
|||
use std::task::Poll;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::select;
|
||||
use tracing::{info, warn};
|
||||
|
@ -21,9 +22,19 @@ pub struct Socks5Listener {
|
|||
socks_server: Pin<Box<dyn Stream<Item = anyhow::Result<(Socks5Stream, (Host, u16))>> + Send>>,
|
||||
}
|
||||
|
||||
pub enum Socks5ReadHalf {
|
||||
Tcp(OwnedReadHalf),
|
||||
Udp(Socks5UdpStream),
|
||||
}
|
||||
|
||||
pub enum Socks5WriteHalf {
|
||||
Tcp(OwnedWriteHalf),
|
||||
Udp(Socks5UdpStreamWriter),
|
||||
}
|
||||
|
||||
pub enum Socks5Stream {
|
||||
Tcp(TcpStream),
|
||||
Udp(Socks5UdpStream),
|
||||
Udp((Socks5UdpStream, Socks5UdpStreamWriter)),
|
||||
}
|
||||
|
||||
impl Socks5Stream {
|
||||
|
@ -31,10 +42,20 @@ impl Socks5Stream {
|
|||
match self {
|
||||
Self::Tcp(_) => LocalProtocol::Tcp { proxy_protocol: false }, // TODO: Implement proxy protocol
|
||||
Self::Udp(s) => LocalProtocol::Udp {
|
||||
timeout: s.watchdog_deadline.as_ref().map(|x| x.period()),
|
||||
timeout: s.0.watchdog_deadline.as_ref().map(|x| x.period()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_split(self) -> (Socks5ReadHalf, Socks5WriteHalf) {
|
||||
match self {
|
||||
Self::Tcp(s) => {
|
||||
let (r, w) = s.into_split();
|
||||
(Socks5ReadHalf::Tcp(r), Socks5WriteHalf::Tcp(w))
|
||||
}
|
||||
Self::Udp((r, w)) => (Socks5ReadHalf::Udp(r), Socks5WriteHalf::Udp(w)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for Socks5Listener {
|
||||
|
@ -95,7 +116,8 @@ pub async fn run_server(
|
|||
return match udp_conn {
|
||||
Some(Ok(stream)) => {
|
||||
let dest = stream.destination();
|
||||
Some((Ok((Socks5Stream::Udp(stream), dest)), (server, udp_server)))
|
||||
let writer = stream.writer();
|
||||
Some((Ok((Socks5Stream::Udp((stream, writer)), dest)), (server, udp_server)))
|
||||
}
|
||||
Some(Err(err)) => {
|
||||
Some((Err(anyhow::Error::new(err)), (server, udp_server)))
|
||||
|
@ -200,7 +222,7 @@ fn new_reply(error: &ReplyError, sock_addr: SocketAddr) -> Vec<u8> {
|
|||
}
|
||||
|
||||
impl Unpin for Socks5Stream {}
|
||||
impl AsyncRead for Socks5Stream {
|
||||
impl AsyncRead for Socks5ReadHalf {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
|
@ -213,7 +235,7 @@ impl AsyncRead for Socks5Stream {
|
|||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for Socks5Stream {
|
||||
impl AsyncWrite for Socks5WriteHalf {
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
||||
match self.get_mut() {
|
||||
Self::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf),
|
||||
|
|
|
@ -93,6 +93,12 @@ impl Socks5UdpServer {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct Socks5UdpStreamWriter {
|
||||
send_socket: Arc<UdpSocket>,
|
||||
peer: SocketAddr,
|
||||
udp_header: Vec<u8>,
|
||||
}
|
||||
|
||||
#[pin_project(PinnedDrop)]
|
||||
pub struct Socks5UdpStream {
|
||||
#[pin]
|
||||
|
@ -153,6 +159,14 @@ impl Socks5UdpStream {
|
|||
TargetAddr::Domain(h, p) => (Host::Domain(h.clone()), *p),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn writer(&self) -> Socks5UdpStreamWriter {
|
||||
Socks5UdpStreamWriter {
|
||||
send_socket: self.send_socket.clone(),
|
||||
peer: self.peer,
|
||||
udp_header: self.udp_header.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for Socks5UdpStream {
|
||||
|
@ -194,15 +208,12 @@ impl AsyncRead for Socks5UdpStream {
|
|||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for Socks5UdpStream {
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
||||
let this = self.project();
|
||||
let header_len = this.udp_header.len();
|
||||
this.udp_header.extend_from_slice(buf);
|
||||
let ret = this
|
||||
.send_socket
|
||||
.poll_send_to(cx, this.udp_header.as_slice(), *this.peer);
|
||||
this.udp_header.truncate(header_len);
|
||||
impl AsyncWrite for Socks5UdpStreamWriter {
|
||||
fn poll_write(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
||||
let header_len = self.udp_header.len();
|
||||
self.udp_header.extend_from_slice(buf);
|
||||
let ret = self.send_socket.poll_send_to(cx, self.udp_header.as_slice(), self.peer);
|
||||
self.udp_header.truncate(header_len);
|
||||
ret.map(|r| r.map(|write_len| write_len - header_len))
|
||||
}
|
||||
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
use crate::protocols::socks5;
|
||||
use crate::protocols::socks5::{Socks5Listener, Socks5Stream};
|
||||
use crate::protocols::socks5::{Socks5Listener, Socks5ReadHalf, Socks5WriteHalf};
|
||||
use crate::tunnel::RemoteAddr;
|
||||
use anyhow::{anyhow, Context};
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::task::{ready, Poll};
|
||||
use std::time::Duration;
|
||||
use tokio::io::{ReadHalf, WriteHalf};
|
||||
use tokio_stream::Stream;
|
||||
|
||||
pub struct Socks5TunnelListener {
|
||||
|
@ -28,7 +27,7 @@ impl Socks5TunnelListener {
|
|||
}
|
||||
|
||||
impl Stream for Socks5TunnelListener {
|
||||
type Item = anyhow::Result<((ReadHalf<Socks5Stream>, WriteHalf<Socks5Stream>), RemoteAddr)>;
|
||||
type Item = anyhow::Result<((Socks5ReadHalf, Socks5WriteHalf), RemoteAddr)>;
|
||||
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = self.get_mut();
|
||||
|
@ -37,7 +36,7 @@ impl Stream for Socks5TunnelListener {
|
|||
let ret = match ret {
|
||||
Some(Ok((stream, (host, port)))) => {
|
||||
let protocol = stream.local_protocol();
|
||||
Some(anyhow::Ok((tokio::io::split(stream), RemoteAddr { protocol, host, port })))
|
||||
Some(anyhow::Ok((stream.into_split(), RemoteAddr { protocol, host, port })))
|
||||
}
|
||||
Some(Err(err)) => Some(Err(err)),
|
||||
None => None,
|
||||
|
|
Loading…
Reference in a new issue