refacto: Add specific trait for TunnelListener

This commit is contained in:
Σrebe - Romain GERARD 2024-07-28 12:22:33 +02:00
parent 6e10c27dbb
commit 6a07201de1
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
5 changed files with 427 additions and 117 deletions

View file

@ -9,6 +9,7 @@ mod tcp;
mod tls; mod tls;
mod tls_utils; mod tls_utils;
mod tunnel; mod tunnel;
mod types;
mod udp; mod udp;
#[cfg(unix)] #[cfg(unix)]
mod unix_socket; mod unix_socket;
@ -16,7 +17,6 @@ mod unix_socket;
use anyhow::anyhow; use anyhow::anyhow;
use base64::Engine; use base64::Engine;
use clap::Parser; use clap::Parser;
use futures_util::{stream, TryStreamExt};
use hyper::header::HOST; use hyper::header::HOST;
use hyper::http::{HeaderName, HeaderValue}; use hyper::http::{HeaderName, HeaderValue};
use log::debug; use log::debug;
@ -46,6 +46,10 @@ use crate::restrictions::types::RestrictionsRules;
use crate::tls_utils::{cn_from_certificate, find_leaf_certificate}; use crate::tls_utils::{cn_from_certificate, find_leaf_certificate};
use crate::tunnel::tls_reloader::TlsReloader; use crate::tunnel::tls_reloader::TlsReloader;
use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme}; use crate::tunnel::{to_host_port, RemoteAddr, TransportAddr, TransportScheme};
use crate::types::{
HttpProxyTunnelListener, Socks5TunnelListener, StdioTunnelListener, TProxyUdpTunnelListener, TcpTunnelListener,
TproxyTcpTunnelListener, UdpTunnelListener, UnixTunnelListener,
};
use crate::udp::MyUdpSocket; use crate::udp::MyUdpSocket;
use tracing_subscriber::filter::Directive; use tracing_subscriber::filter::Directive;
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
@ -1184,21 +1188,10 @@ async fn main() {
match &tunnel.local_protocol { match &tunnel.local_protocol {
LocalProtocol::Tcp { proxy_protocol } => { LocalProtocol::Tcp { proxy_protocol } => {
let proxy_protocol = *proxy_protocol;
let remote = tunnel.remote.clone();
let server = tcp::run_server(tunnel.local, false) let server = tcp::run_server(tunnel.local, false)
.await .await
.unwrap_or_else(|err| panic!("Cannot start TCP server on {}: {}", tunnel.local, err)) .unwrap_or_else(|err| panic!("Cannot start TCP server on {}: {}", tunnel.local, err));
.map_err(anyhow::Error::new) let server = TcpTunnelListener::new(server, tunnel.remote.clone(), *proxy_protocol);
.map_ok(move |stream| {
let remote = RemoteAddr {
protocol: LocalProtocol::Tcp { proxy_protocol },
host: remote.0.clone(),
port: remote.1,
};
(stream.into_split(), remote)
});
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
error!("{:?}", err); error!("{:?}", err);
@ -1207,20 +1200,10 @@ async fn main() {
} }
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
LocalProtocol::TProxyTcp => { LocalProtocol::TProxyTcp => {
let server = tcp::run_server(tunnel.local, true) let server = tcp::run_server(tunnel.local, true).await.unwrap_or_else(|err| {
.await panic!("Cannot start TProxy TCP server on {}: {}", tunnel.local, err)
.unwrap_or_else(|err| panic!("Cannot start TProxy TCP server on {}: {}", tunnel.local, err)) });
.map_err(anyhow::Error::new) let server = TproxyTcpTunnelListener::new(server, false); // TODO: support proxy protocol
.map_ok(move |stream| {
// In TProxy mode local destination is the final ip:port destination
let (host, port) = to_host_port(stream.local_addr().unwrap());
let remote = RemoteAddr {
protocol: LocalProtocol::Tcp { proxy_protocol: false },
host,
port,
};
(stream.into_split(), remote)
});
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
@ -1230,22 +1213,11 @@ async fn main() {
} }
#[cfg(unix)] #[cfg(unix)]
LocalProtocol::Unix { path } => { LocalProtocol::Unix { path } => {
let remote = tunnel.remote.clone(); let server = unix_socket::run_server(path).await.unwrap_or_else(|err| {
let server = unix_socket::run_server(path) panic!("Cannot start Unix domain server on {}: {}", tunnel.local, err)
.await });
.unwrap_or_else(|err| {
panic!("Cannot start Unix domain server on {}: {}", tunnel.local, err)
})
.map_err(anyhow::Error::new)
.map_ok(move |stream| {
let remote = RemoteAddr {
protocol: LocalProtocol::Tcp { proxy_protocol: false },
host: remote.0.clone(),
port: remote.1,
};
(stream.into_split(), remote)
});
let server = UnixTunnelListener::new(server, tunnel.remote.clone(), false); // TODO: support proxy protocol
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
error!("{:?}", err); error!("{:?}", err);
@ -1259,25 +1231,14 @@ async fn main() {
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
LocalProtocol::TProxyUdp { timeout } => { LocalProtocol::TProxyUdp { timeout } => {
let timeout = *timeout;
let server = let server =
udp::run_server(tunnel.local, timeout, udp::configure_tproxy, udp::mk_send_socket_tproxy) udp::run_server(tunnel.local, *timeout, udp::configure_tproxy, udp::mk_send_socket_tproxy)
.await .await
.unwrap_or_else(|err| { .unwrap_or_else(|err| {
panic!("Cannot start TProxy UDP server on {}: {}", tunnel.local, err) panic!("Cannot start TProxy UDP server on {}: {}", tunnel.local, err)
})
.map_err(anyhow::Error::new)
.map_ok(move |stream| {
// In TProxy mode local destination is the final ip:port destination
let (host, port) = to_host_port(stream.local_addr().unwrap());
let remote = RemoteAddr {
protocol: LocalProtocol::Udp { timeout },
host,
port,
};
(tokio::io::split(stream), remote)
}); });
let server = TProxyUdpTunnelListener::new(server, *timeout);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
error!("{:?}", err); error!("{:?}", err);
@ -1289,20 +1250,10 @@ async fn main() {
panic!("Transparent proxy is not available for non Linux platform") panic!("Transparent proxy is not available for non Linux platform")
} }
LocalProtocol::Udp { timeout } => { LocalProtocol::Udp { timeout } => {
let (host, port) = tunnel.remote.clone(); let server = udp::run_server(tunnel.local, *timeout, |_| Ok(()), |s| Ok(s.clone()))
let timeout = *timeout;
let server = udp::run_server(tunnel.local, timeout, |_| Ok(()), |s| Ok(s.clone()))
.await .await
.unwrap_or_else(|err| panic!("Cannot start UDP server on {}: {}", tunnel.local, err)) .unwrap_or_else(|err| panic!("Cannot start UDP server on {}: {}", tunnel.local, err));
.map_err(anyhow::Error::new) let server = UdpTunnelListener::new(server, tunnel.remote.clone(), *timeout);
.map_ok(move |stream| {
let remote = RemoteAddr {
protocol: LocalProtocol::Udp { timeout },
host: host.clone(),
port,
};
(tokio::io::split(stream), remote)
});
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
@ -1313,16 +1264,9 @@ async fn main() {
LocalProtocol::Socks5 { timeout, credentials } => { LocalProtocol::Socks5 { timeout, credentials } => {
let server = socks5::run_server(tunnel.local, *timeout, credentials.clone()) let server = socks5::run_server(tunnel.local, *timeout, credentials.clone())
.await .await
.unwrap_or_else(|err| panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err)) .unwrap_or_else(|err| panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err));
.map_ok(|(stream, (host, port))| {
let remote = RemoteAddr {
protocol: stream.local_protocol(),
host,
port,
};
(tokio::io::split(stream), remote)
});
let server = Socks5TunnelListener::new(server);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
error!("{:?}", err); error!("{:?}", err);
@ -1334,19 +1278,13 @@ async fn main() {
credentials, credentials,
proxy_protocol, proxy_protocol,
} => { } => {
let proxy_protocol = *proxy_protocol;
let server = http_proxy::run_server(tunnel.local, *timeout, credentials.clone()) let server = http_proxy::run_server(tunnel.local, *timeout, credentials.clone())
.await .await
.unwrap_or_else(|err| panic!("Cannot start http proxy server on {}: {}", tunnel.local, err)) .unwrap_or_else(|err| {
.map_ok(move |(stream, (host, port))| { panic!("Cannot start http proxy server on {}: {}", tunnel.local, err)
let remote = RemoteAddr {
protocol: LocalProtocol::Tcp { proxy_protocol },
host,
port,
};
(tokio::io::split(stream), remote)
}); });
let server = HttpProxyTunnelListener::new(server, *proxy_protocol);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel(client_config, server).await { if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
error!("{:?}", err); error!("{:?}", err);
@ -1358,20 +1296,9 @@ async fn main() {
let (server, mut handle) = stdio::server::run_server().await.unwrap_or_else(|err| { let (server, mut handle) = stdio::server::run_server().await.unwrap_or_else(|err| {
panic!("Cannot start STDIO server: {}", err); panic!("Cannot start STDIO server: {}", err);
}); });
let server = StdioTunnelListener::new(server, tunnel.remote.clone(), false);
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = tunnel::client::run_tunnel( if let Err(err) = tunnel::client::run_tunnel(client_config, server).await {
client_config,
stream::once(async move {
let remote = RemoteAddr {
protocol: LocalProtocol::Tcp { proxy_protocol: false },
host: tunnel.remote.0,
port: tunnel.remote.1,
};
Ok((server, remote))
}),
)
.await
{
error!("{:?}", err); error!("{:?}", err);
} }
}); });

View file

@ -1,5 +1,6 @@
use super::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE}; use super::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE};
use crate::tunnel::transport::{TunnelReader, TunnelWriter}; use crate::tunnel::transport::{TunnelReader, TunnelWriter};
use crate::types::TunnelListener;
use crate::{tunnel, WsClientConfig}; use crate::{tunnel, WsClientConfig};
use futures_util::pin_mut; use futures_util::pin_mut;
use hyper::header::COOKIE; use hyper::header::COOKIE;
@ -10,7 +11,7 @@ use std::ops::Deref;
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tokio_stream::{Stream, StreamExt}; use tokio_stream::StreamExt;
use tracing::{error, event, span, Instrument, Level, Span}; use tracing::{error, event, span, Instrument, Level, Span};
use url::Host; use url::Host;
use uuid::Uuid; use uuid::Uuid;
@ -56,11 +57,7 @@ where
Ok(()) Ok(())
} }
pub async fn run_tunnel<T, R, W>(client_config: Arc<WsClientConfig>, incoming_cnx: T) -> anyhow::Result<()> pub async fn run_tunnel(client_config: Arc<WsClientConfig>, incoming_cnx: impl TunnelListener) -> anyhow::Result<()>
where
T: Stream<Item = anyhow::Result<((R, W), RemoteAddr)>>,
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{ {
pin_mut!(incoming_cnx); pin_mut!(incoming_cnx);
while let Some(cnx) = incoming_cnx.next().await { while let Some(cnx) = incoming_cnx.next().await {

View file

@ -121,7 +121,8 @@ async fn run_tunnel(
let listening_server = let listening_server =
udp::run_server(bind.parse()?, timeout, |_| Ok(()), |send_socket| Ok(send_socket.clone())); udp::run_server(bind.parse()?, timeout, |_| Ok(()), |send_socket| Ok(send_socket.clone()));
let udp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?; let udp = run_listening_server(&local_srv, SERVERS.deref(), listening_server).await?;
let (local_rx, local_tx) = tokio::io::split(udp); let udp_writer = udp.writer();
let (local_rx, local_tx) = (udp, udp_writer);
let remote = RemoteAddr { let remote = RemoteAddr {
protocol: remote.protocol, protocol: remote.protocol,

374
src/types.rs Normal file
View file

@ -0,0 +1,374 @@
use crate::http_proxy::HttpProxyListener;
use crate::socks5::{Socks5Listener, Socks5Stream};
use crate::tunnel::{to_host_port, RemoteAddr};
use crate::udp::{UdpStream, UdpStreamWriter};
use crate::unix_socket::UnixListenerStream;
use crate::LocalProtocol;
use std::io;
use std::pin::Pin;
use std::task::{ready, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, ReadHalf, WriteHalf};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::unix;
use tokio_stream::wrappers::TcpListenerStream;
use tokio_stream::Stream;
use url::Host;
pub trait TunnelListener: Stream<Item = anyhow::Result<((Self::Reader, Self::Writer), RemoteAddr)>> {
type Reader: AsyncRead + Send + 'static;
type Writer: AsyncWrite + Send + 'static;
}
impl<T, R, W> TunnelListener for T
where
T: Stream<Item = anyhow::Result<((R, W), RemoteAddr)>>,
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
type Reader = R;
type Writer = W;
}
pub struct TcpTunnelListener {
listener: TcpListenerStream,
dest: (Host, u16),
proxy_protocol: bool,
}
impl TcpTunnelListener {
pub fn new(listener: TcpListenerStream, dest: (Host, u16), proxy_protocol: bool) -> Self {
Self {
listener,
dest,
proxy_protocol,
}
}
}
impl Stream for TcpTunnelListener {
type Item = anyhow::Result<((OwnedReadHalf, OwnedWriteHalf), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let ret = ready!(Pin::new(&mut this.listener).poll_next(cx));
let ret = match ret {
Some(Ok(strean)) => {
let (host, port) = this.dest.clone();
Some(anyhow::Ok((
strean.into_split(),
RemoteAddr {
protocol: LocalProtocol::Tcp {
proxy_protocol: this.proxy_protocol,
},
host,
port,
},
)))
}
Some(Err(err)) => Some(Err(anyhow::Error::new(err))),
None => None,
};
Poll::Ready(ret)
}
}
// TPROXY
pub struct TproxyTcpTunnelListener {
listener: TcpListenerStream,
proxy_protocol: bool,
}
impl TproxyTcpTunnelListener {
pub fn new(listener: TcpListenerStream, proxy_protocol: bool) -> Self {
Self {
listener,
proxy_protocol,
}
}
}
impl Stream for TproxyTcpTunnelListener {
type Item = anyhow::Result<((OwnedReadHalf, OwnedWriteHalf), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let ret = ready!(Pin::new(&mut this.listener).poll_next(cx));
let ret = match ret {
Some(Ok(stream)) => {
let (host, port) = to_host_port(stream.local_addr().unwrap());
Some(anyhow::Ok((
stream.into_split(),
RemoteAddr {
protocol: LocalProtocol::Tcp {
proxy_protocol: this.proxy_protocol,
},
host,
port,
},
)))
}
Some(Err(err)) => Some(Err(anyhow::Error::new(err))),
None => None,
};
Poll::Ready(ret)
}
}
// UNIX
pub struct UnixTunnelListener {
listener: UnixListenerStream,
dest: (Host, u16),
proxy_protocol: bool,
}
impl UnixTunnelListener {
pub fn new(listener: UnixListenerStream, dest: (Host, u16), proxy_protocol: bool) -> Self {
Self {
listener,
dest,
proxy_protocol,
}
}
}
impl Stream for UnixTunnelListener {
type Item = anyhow::Result<((unix::OwnedReadHalf, unix::OwnedWriteHalf), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let ret = ready!(Pin::new(&mut this.listener).poll_next(cx));
let ret = match ret {
Some(Ok(stream)) => {
let stream = stream.into_split();
let (host, port) = this.dest.clone();
Some(anyhow::Ok((
stream,
RemoteAddr {
protocol: LocalProtocol::Tcp {
proxy_protocol: this.proxy_protocol,
},
host,
port,
},
)))
}
Some(Err(err)) => Some(Err(anyhow::Error::new(err))),
None => None,
};
Poll::Ready(ret)
}
}
// TPROXY UDP
pub struct TProxyUdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
listener: S,
timeout: Option<Duration>,
}
impl<S> TProxyUdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
pub fn new(listener: S, timeout: Option<Duration>) -> Self {
Self { listener, timeout }
}
}
impl<S> Stream for TProxyUdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
type Item = anyhow::Result<((UdpStream, UdpStreamWriter), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = unsafe { self.get_unchecked_mut() };
let ret = ready!(unsafe { Pin::new_unchecked(&mut this.listener) }.poll_next(cx));
let ret = match ret {
Some(Ok(stream)) => {
let (host, port) = to_host_port(stream.local_addr().unwrap());
let stream_writer = stream.writer();
Some(anyhow::Ok((
(stream, stream_writer),
RemoteAddr {
protocol: LocalProtocol::Udp { timeout: this.timeout },
host,
port,
},
)))
}
Some(Err(err)) => Some(Err(anyhow::Error::new(err))),
None => None,
};
Poll::Ready(ret)
}
}
pub struct UdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
listener: S,
dest: (Host, u16),
timeout: Option<Duration>,
}
impl<S> UdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
pub fn new(listener: S, dest: (Host, u16), timeout: Option<Duration>) -> Self {
Self {
listener,
dest,
timeout,
}
}
}
impl<S> Stream for UdpTunnelListener<S>
where
S: Stream<Item = io::Result<UdpStream>>,
{
type Item = anyhow::Result<((UdpStream, UdpStreamWriter), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = unsafe { self.get_unchecked_mut() };
let ret = ready!(unsafe { Pin::new_unchecked(&mut this.listener) }.poll_next(cx));
let ret = match ret {
Some(Ok(stream)) => {
let (host, port) = this.dest.clone();
let stream_writer = stream.writer();
Some(anyhow::Ok((
(stream, stream_writer),
RemoteAddr {
protocol: LocalProtocol::Udp { timeout: this.timeout },
host,
port,
},
)))
}
Some(Err(err)) => Some(Err(anyhow::Error::new(err))),
None => None,
};
Poll::Ready(ret)
}
}
pub struct Socks5TunnelListener {
listener: Socks5Listener,
}
impl Socks5TunnelListener {
pub fn new(listener: Socks5Listener) -> Self {
Self { listener }
}
}
impl Stream for Socks5TunnelListener {
type Item = anyhow::Result<((ReadHalf<Socks5Stream>, WriteHalf<Socks5Stream>), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let ret = ready!(Pin::new(&mut this.listener).poll_next(cx));
let ret = match ret {
Some(Ok((stream, (host, port)))) => {
let protocol = stream.local_protocol();
Some(anyhow::Ok((tokio::io::split(stream), RemoteAddr { protocol, host, port })))
}
Some(Err(err)) => Some(Err(err)),
None => None,
};
Poll::Ready(ret)
}
}
pub struct HttpProxyTunnelListener {
listener: HttpProxyListener,
proxy_protocol: bool,
}
impl HttpProxyTunnelListener {
pub fn new(listener: HttpProxyListener, proxy_protocol: bool) -> Self {
Self {
listener,
proxy_protocol,
}
}
}
impl Stream for HttpProxyTunnelListener {
type Item = anyhow::Result<((OwnedReadHalf, OwnedWriteHalf), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let ret = ready!(Pin::new(&mut this.listener).poll_next(cx));
let ret = match ret {
Some(Ok((stream, (host, port)))) => {
let protocol = LocalProtocol::Tcp {
proxy_protocol: this.proxy_protocol,
};
Some(anyhow::Ok((stream.into_split(), RemoteAddr { protocol, host, port })))
}
Some(Err(err)) => Some(Err(err)),
None => None,
};
Poll::Ready(ret)
}
}
pub struct StdioTunnelListener<R, W>
where
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
listener: Option<(R, W)>,
dest: (Host, u16),
proxy_protocol: bool,
}
impl<R, W> StdioTunnelListener<R, W>
where
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
pub fn new(listener: (R, W), dest: (Host, u16), proxy_protocol: bool) -> Self {
Self {
listener: Some(listener),
proxy_protocol,
dest,
}
}
}
impl<R, W> Stream for StdioTunnelListener<R, W>
where
R: AsyncRead + Send + 'static,
W: AsyncWrite + Send + 'static,
{
type Item = anyhow::Result<((R, W), RemoteAddr)>;
fn poll_next(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let this = unsafe { self.get_unchecked_mut() };
let ret = match this.listener.take() {
None => None,
Some(stream) => {
let (host, port) = this.dest.clone();
Some(Ok((
stream,
RemoteAddr {
protocol: LocalProtocol::Tcp {
proxy_protocol: this.proxy_protocol,
},
host,
port,
},
)))
}
};
Poll::Ready(ret)
}
}

View file

@ -5,7 +5,7 @@ use parking_lot::RwLock;
use pin_project::{pin_project, pinned_drop}; use pin_project::{pin_project, pinned_drop};
use std::collections::HashMap; use std::collections::HashMap;
use std::future::Future; use std::future::Future;
use std::io; use std::{io, task};
use std::io::{Error, ErrorKind}; use std::io::{Error, ErrorKind};
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use tokio::task::JoinSet; use tokio::task::JoinSet;
@ -164,12 +164,18 @@ impl UdpStream {
pub fn local_addr(&self) -> io::Result<SocketAddr> { pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.send_socket.local_addr() self.send_socket.local_addr()
} }
pub fn writer(&self) -> UdpStreamWriter {
UdpStreamWriter {
send_socket: self.send_socket.clone(),
peer: self.peer,
}
}
} }
impl AsyncRead for UdpStream { impl AsyncRead for UdpStream {
fn poll_read( fn poll_read(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut task::Context<'_>,
obuf: &mut ReadBuf<'_>, obuf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> { ) -> Poll<io::Result<()>> {
let mut project = self.project(); let mut project = self.project();
@ -209,16 +215,21 @@ impl AsyncRead for UdpStream {
} }
} }
impl AsyncWrite for UdpStream { pub struct UdpStreamWriter {
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> { send_socket: Arc<UdpSocket>,
peer: SocketAddr,
}
impl AsyncWrite for UdpStreamWriter {
fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
self.send_socket.poll_send_to(cx, buf, self.peer) self.send_socket.poll_send_to(cx, buf, self.peer)
} }
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> { fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> {
self.send_socket.poll_send_ready(cx) self.send_socket.poll_send_ready(cx)
} }
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> { fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
} }
@ -299,7 +310,7 @@ impl MyUdpSocket {
} }
impl AsyncRead for MyUdpSocket { impl AsyncRead for MyUdpSocket {
fn poll_read(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> { fn poll_read(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
unsafe { self.map_unchecked_mut(|x| &mut x.socket) } unsafe { self.map_unchecked_mut(|x| &mut x.socket) }
.poll_recv_from(cx, buf) .poll_recv_from(cx, buf)
.map(|x| x.map(|_| ())) .map(|x| x.map(|_| ()))
@ -307,15 +318,15 @@ impl AsyncRead for MyUdpSocket {
} }
impl AsyncWrite for MyUdpSocket { impl AsyncWrite for MyUdpSocket {
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> { fn poll_write(self: Pin<&mut Self>, cx: &mut task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }.poll_send(cx, buf) unsafe { self.map_unchecked_mut(|x| &mut x.socket) }.poll_send(cx, buf)
} }
fn poll_flush(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> { fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> { fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
} }