refacto: avoid lock in socks5 stream
This commit is contained in:
parent
0dded01b7f
commit
58c34ccc41
4 changed files with 53 additions and 20 deletions
|
@ -3,4 +3,5 @@ mod udp_server;
|
||||||
|
|
||||||
pub use tcp_server::run_server;
|
pub use tcp_server::run_server;
|
||||||
pub use tcp_server::Socks5Listener;
|
pub use tcp_server::Socks5Listener;
|
||||||
pub use tcp_server::Socks5Stream;
|
pub use tcp_server::Socks5ReadHalf;
|
||||||
|
pub use tcp_server::Socks5WriteHalf;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use super::udp_server::Socks5UdpStream;
|
use super::udp_server::{Socks5UdpStream, Socks5UdpStreamWriter};
|
||||||
use crate::LocalProtocol;
|
use crate::LocalProtocol;
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use fast_socks5::server::{Config, DenyAuthentication, SimpleUserPassword, Socks5Server};
|
use fast_socks5::server::{Config, DenyAuthentication, SimpleUserPassword, Socks5Server};
|
||||||
|
@ -11,6 +11,7 @@ use std::pin::Pin;
|
||||||
use std::task::Poll;
|
use std::task::Poll;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||||
|
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::select;
|
use tokio::select;
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
@ -21,9 +22,19 @@ pub struct Socks5Listener {
|
||||||
socks_server: Pin<Box<dyn Stream<Item = anyhow::Result<(Socks5Stream, (Host, u16))>> + Send>>,
|
socks_server: Pin<Box<dyn Stream<Item = anyhow::Result<(Socks5Stream, (Host, u16))>> + Send>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub enum Socks5ReadHalf {
|
||||||
|
Tcp(OwnedReadHalf),
|
||||||
|
Udp(Socks5UdpStream),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum Socks5WriteHalf {
|
||||||
|
Tcp(OwnedWriteHalf),
|
||||||
|
Udp(Socks5UdpStreamWriter),
|
||||||
|
}
|
||||||
|
|
||||||
pub enum Socks5Stream {
|
pub enum Socks5Stream {
|
||||||
Tcp(TcpStream),
|
Tcp(TcpStream),
|
||||||
Udp(Socks5UdpStream),
|
Udp((Socks5UdpStream, Socks5UdpStreamWriter)),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Socks5Stream {
|
impl Socks5Stream {
|
||||||
|
@ -31,10 +42,20 @@ impl Socks5Stream {
|
||||||
match self {
|
match self {
|
||||||
Self::Tcp(_) => LocalProtocol::Tcp { proxy_protocol: false }, // TODO: Implement proxy protocol
|
Self::Tcp(_) => LocalProtocol::Tcp { proxy_protocol: false }, // TODO: Implement proxy protocol
|
||||||
Self::Udp(s) => LocalProtocol::Udp {
|
Self::Udp(s) => LocalProtocol::Udp {
|
||||||
timeout: s.watchdog_deadline.as_ref().map(|x| x.period()),
|
timeout: s.0.watchdog_deadline.as_ref().map(|x| x.period()),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn into_split(self) -> (Socks5ReadHalf, Socks5WriteHalf) {
|
||||||
|
match self {
|
||||||
|
Self::Tcp(s) => {
|
||||||
|
let (r, w) = s.into_split();
|
||||||
|
(Socks5ReadHalf::Tcp(r), Socks5WriteHalf::Tcp(w))
|
||||||
|
}
|
||||||
|
Self::Udp((r, w)) => (Socks5ReadHalf::Udp(r), Socks5WriteHalf::Udp(w)),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Stream for Socks5Listener {
|
impl Stream for Socks5Listener {
|
||||||
|
@ -95,7 +116,8 @@ pub async fn run_server(
|
||||||
return match udp_conn {
|
return match udp_conn {
|
||||||
Some(Ok(stream)) => {
|
Some(Ok(stream)) => {
|
||||||
let dest = stream.destination();
|
let dest = stream.destination();
|
||||||
Some((Ok((Socks5Stream::Udp(stream), dest)), (server, udp_server)))
|
let writer = stream.writer();
|
||||||
|
Some((Ok((Socks5Stream::Udp((stream, writer)), dest)), (server, udp_server)))
|
||||||
}
|
}
|
||||||
Some(Err(err)) => {
|
Some(Err(err)) => {
|
||||||
Some((Err(anyhow::Error::new(err)), (server, udp_server)))
|
Some((Err(anyhow::Error::new(err)), (server, udp_server)))
|
||||||
|
@ -200,7 +222,7 @@ fn new_reply(error: &ReplyError, sock_addr: SocketAddr) -> Vec<u8> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Unpin for Socks5Stream {}
|
impl Unpin for Socks5Stream {}
|
||||||
impl AsyncRead for Socks5Stream {
|
impl AsyncRead for Socks5ReadHalf {
|
||||||
fn poll_read(
|
fn poll_read(
|
||||||
self: Pin<&mut Self>,
|
self: Pin<&mut Self>,
|
||||||
cx: &mut std::task::Context<'_>,
|
cx: &mut std::task::Context<'_>,
|
||||||
|
@ -213,7 +235,7 @@ impl AsyncRead for Socks5Stream {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsyncWrite for Socks5Stream {
|
impl AsyncWrite for Socks5WriteHalf {
|
||||||
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
||||||
match self.get_mut() {
|
match self.get_mut() {
|
||||||
Self::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf),
|
Self::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf),
|
||||||
|
|
|
@ -93,6 +93,12 @@ impl Socks5UdpServer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct Socks5UdpStreamWriter {
|
||||||
|
send_socket: Arc<UdpSocket>,
|
||||||
|
peer: SocketAddr,
|
||||||
|
udp_header: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
#[pin_project(PinnedDrop)]
|
#[pin_project(PinnedDrop)]
|
||||||
pub struct Socks5UdpStream {
|
pub struct Socks5UdpStream {
|
||||||
#[pin]
|
#[pin]
|
||||||
|
@ -153,6 +159,14 @@ impl Socks5UdpStream {
|
||||||
TargetAddr::Domain(h, p) => (Host::Domain(h.clone()), *p),
|
TargetAddr::Domain(h, p) => (Host::Domain(h.clone()), *p),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn writer(&self) -> Socks5UdpStreamWriter {
|
||||||
|
Socks5UdpStreamWriter {
|
||||||
|
send_socket: self.send_socket.clone(),
|
||||||
|
peer: self.peer,
|
||||||
|
udp_header: self.udp_header.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsyncRead for Socks5UdpStream {
|
impl AsyncRead for Socks5UdpStream {
|
||||||
|
@ -194,15 +208,12 @@ impl AsyncRead for Socks5UdpStream {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsyncWrite for Socks5UdpStream {
|
impl AsyncWrite for Socks5UdpStreamWriter {
|
||||||
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
fn poll_write(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
||||||
let this = self.project();
|
let header_len = self.udp_header.len();
|
||||||
let header_len = this.udp_header.len();
|
self.udp_header.extend_from_slice(buf);
|
||||||
this.udp_header.extend_from_slice(buf);
|
let ret = self.send_socket.poll_send_to(cx, self.udp_header.as_slice(), self.peer);
|
||||||
let ret = this
|
self.udp_header.truncate(header_len);
|
||||||
.send_socket
|
|
||||||
.poll_send_to(cx, this.udp_header.as_slice(), *this.peer);
|
|
||||||
this.udp_header.truncate(header_len);
|
|
||||||
ret.map(|r| r.map(|write_len| write_len - header_len))
|
ret.map(|r| r.map(|write_len| write_len - header_len))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,11 @@
|
||||||
use crate::protocols::socks5;
|
use crate::protocols::socks5;
|
||||||
use crate::protocols::socks5::{Socks5Listener, Socks5Stream};
|
use crate::protocols::socks5::{Socks5Listener, Socks5ReadHalf, Socks5WriteHalf};
|
||||||
use crate::tunnel::RemoteAddr;
|
use crate::tunnel::RemoteAddr;
|
||||||
use anyhow::{anyhow, Context};
|
use anyhow::{anyhow, Context};
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::task::{ready, Poll};
|
use std::task::{ready, Poll};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::io::{ReadHalf, WriteHalf};
|
|
||||||
use tokio_stream::Stream;
|
use tokio_stream::Stream;
|
||||||
|
|
||||||
pub struct Socks5TunnelListener {
|
pub struct Socks5TunnelListener {
|
||||||
|
@ -28,7 +27,7 @@ impl Socks5TunnelListener {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Stream for Socks5TunnelListener {
|
impl Stream for Socks5TunnelListener {
|
||||||
type Item = anyhow::Result<((ReadHalf<Socks5Stream>, WriteHalf<Socks5Stream>), RemoteAddr)>;
|
type Item = anyhow::Result<((Socks5ReadHalf, Socks5WriteHalf), RemoteAddr)>;
|
||||||
|
|
||||||
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
|
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
let this = self.get_mut();
|
let this = self.get_mut();
|
||||||
|
@ -37,7 +36,7 @@ impl Stream for Socks5TunnelListener {
|
||||||
let ret = match ret {
|
let ret = match ret {
|
||||||
Some(Ok((stream, (host, port)))) => {
|
Some(Ok((stream, (host, port)))) => {
|
||||||
let protocol = stream.local_protocol();
|
let protocol = stream.local_protocol();
|
||||||
Some(anyhow::Ok((tokio::io::split(stream), RemoteAddr { protocol, host, port })))
|
Some(anyhow::Ok((stream.into_split(), RemoteAddr { protocol, host, port })))
|
||||||
}
|
}
|
||||||
Some(Err(err)) => Some(Err(err)),
|
Some(Err(err)) => Some(Err(err)),
|
||||||
None => None,
|
None => None,
|
||||||
|
|
Loading…
Reference in a new issue