diff --git a/justfile b/justfile index 2d184f3..a2d229c 100644 --- a/justfile +++ b/justfile @@ -15,6 +15,6 @@ make_release $VERSION $FORCE="": @just docker_release v$VERSION docker_release $TAG: - docker login -u erebe ghcr.io + #docker login -u erebe ghcr.io ~/.depot/bin/depot build --project v4z5w7md33 --platform linux/arm/v7,linux/arm64,linux/amd64 -t ghcr.io/erebe/wstunnel:$TAG -t ghcr.io/erebe/wstunnel:latest --push . diff --git a/src/main.rs b/src/main.rs index 1509076..60f283a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -794,7 +794,7 @@ async fn main() { let server = socks5::run_server(tunnel.local) .await .unwrap_or_else(|err| panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err)) - .map_ok(|(stream, remote_dest)| (stream.into_split(), remote_dest)); + .map_ok(|(stream, remote_dest)| (tokio::io::split(stream), remote_dest)); tokio::spawn(async move { if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await { diff --git a/src/socks5.rs b/src/socks5.rs index 44213ae..8d57012 100644 --- a/src/socks5.rs +++ b/src/socks5.rs @@ -1,23 +1,29 @@ +use crate::udp::UdpStream; use anyhow::Context; use fast_socks5::server::{Config, DenyAuthentication, Socks5Server}; use fast_socks5::util::target_addr::TargetAddr; use fast_socks5::{consts, ReplyError}; use futures_util::{stream, Stream, StreamExt}; +use std::io::{Error, IoSlice}; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::pin::Pin; use std::task::Poll; -use tokio::io::AsyncWriteExt; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio::net::TcpStream; use tracing::{info, warn}; use url::Host; #[allow(clippy::type_complexity)] pub struct Socks5Listener { - stream: Pin> + Send>>, + stream: Pin> + Send>>, } +pub enum Socks5Protocol { + Tcp(TcpStream), + Udp(UdpStream), +} impl Stream for Socks5Listener { - type Item = anyhow::Result<(TcpStream, (Host, u16))>; + type Item = anyhow::Result<(Socks5Protocol, (Host, u16))>; fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { unsafe { self.map_unchecked_mut(|x| &mut x.stream) }.poll_next(cx) @@ -35,6 +41,7 @@ pub async fn run_server(bind: SocketAddr) -> Result Result (Host::Ipv4(*ip.ip()), ip.port()), @@ -82,7 +93,7 @@ pub async fn run_server(bind: SocketAddr) -> Result Vec { reply } +impl Unpin for Socks5Protocol {} +impl AsyncRead for Socks5Protocol { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + Socks5Protocol::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_read(cx, buf), + Socks5Protocol::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_read(cx, buf), + } + } +} + +impl AsyncWrite for Socks5Protocol { + fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll> { + match self.get_mut() { + Socks5Protocol::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf), + Socks5Protocol::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + match self.get_mut() { + Socks5Protocol::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_flush(cx), + Socks5Protocol::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + match self.get_mut() { + Socks5Protocol::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_shutdown(cx), + Socks5Protocol::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_shutdown(cx), + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + match self.get_mut() { + Socks5Protocol::Tcp(s) => unsafe { Pin::new_unchecked(s) }.poll_write_vectored(cx, bufs), + Socks5Protocol::Udp(s) => unsafe { Pin::new_unchecked(s) }.poll_write_vectored(cx, bufs), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + Socks5Protocol::Tcp(s) => s.is_write_vectored(), + Socks5Protocol::Udp(s) => s.is_write_vectored(), + } + } +} + //#[cfg(test)] //mod test { // use super::*; diff --git a/src/tunnel/io.rs b/src/tunnel/io.rs index 607d09e..0f11037 100644 --- a/src/tunnel/io.rs +++ b/src/tunnel/io.rs @@ -1,6 +1,7 @@ use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite}; use futures_util::{pin_mut, FutureExt}; use hyper::upgrade::Upgraded; +use std::cmp::max; use hyper_util::rt::TokioIo; use std::time::Duration; diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 909fdf5..bb7664f 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -22,6 +22,7 @@ use jsonwebtoken::TokenData; use once_cell::sync::Lazy; use parking_lot::Mutex; +use crate::socks5::Socks5Protocol; use crate::tunnel::tls_reloader::TlsReloader; use crate::udp::UdpStream; use tokio::io::{AsyncRead, AsyncWrite}; @@ -104,7 +105,7 @@ async fn run_tunnel( } LocalProtocol::ReverseSocks5 => { #[allow(clippy::type_complexity)] - static SERVERS: Lazy, u16), mpsc::Receiver<(TcpStream, (Host, u16))>>>> = + static SERVERS: Lazy, u16), mpsc::Receiver<(Socks5Protocol, (Host, u16))>>>> = Lazy::new(|| Mutex::new(HashMap::with_capacity(0))); let local_srv = (Host::parse(&jwt.claims.r)?, jwt.claims.rp);