Cleanup
This commit is contained in:
parent
b478288848
commit
466cb425bc
11 changed files with 159 additions and 320 deletions
3
rustfmt.toml
Normal file
3
rustfmt.toml
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
edition = "2021"
|
||||||
|
max_width = 120
|
||||||
|
fn_call_width = 80
|
|
@ -3,14 +3,13 @@ use tokio_rustls::rustls::{Certificate, PrivateKey};
|
||||||
|
|
||||||
pub static TLS_PRIVATE_KEY: Lazy<PrivateKey> = Lazy::new(|| {
|
pub static TLS_PRIVATE_KEY: Lazy<PrivateKey> = Lazy::new(|| {
|
||||||
let key = include_bytes!("../certs/key.pem");
|
let key = include_bytes!("../certs/key.pem");
|
||||||
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut key.as_slice())
|
let mut keys =
|
||||||
.expect("failed to load embedded tls private key");
|
rustls_pemfile::pkcs8_private_keys(&mut key.as_slice()).expect("failed to load embedded tls private key");
|
||||||
PrivateKey(keys.remove(0))
|
PrivateKey(keys.remove(0))
|
||||||
});
|
});
|
||||||
pub static TLS_CERTIFICATE: Lazy<Vec<Certificate>> = Lazy::new(|| {
|
pub static TLS_CERTIFICATE: Lazy<Vec<Certificate>> = Lazy::new(|| {
|
||||||
let cert = include_bytes!("../certs/cert.pem");
|
let cert = include_bytes!("../certs/cert.pem");
|
||||||
let certs = rustls_pemfile::certs(&mut cert.as_slice())
|
let certs = rustls_pemfile::certs(&mut cert.as_slice()).expect("failed to load embedded tls certificate");
|
||||||
.expect("failed to load embedded tls certificate");
|
|
||||||
|
|
||||||
certs.into_iter().map(Certificate).collect()
|
certs.into_iter().map(Certificate).collect()
|
||||||
});
|
});
|
||||||
|
|
87
src/main.rs
87
src/main.rs
|
@ -67,13 +67,7 @@ struct Client {
|
||||||
/// This option set the maximum number of connection that will be kept open.
|
/// This option set the maximum number of connection that will be kept open.
|
||||||
/// This is useful if you plan to create/destroy a lot of tunnel (i.e: with socks5 to navigate with a browser)
|
/// This is useful if you plan to create/destroy a lot of tunnel (i.e: with socks5 to navigate with a browser)
|
||||||
/// It will avoid the latency of doing tcp + tls handshake with the server
|
/// It will avoid the latency of doing tcp + tls handshake with the server
|
||||||
#[arg(
|
#[arg(short = 'c', long, value_name = "INT", default_value = "0", verbatim_doc_comment)]
|
||||||
short = 'c',
|
|
||||||
long,
|
|
||||||
value_name = "INT",
|
|
||||||
default_value = "0",
|
|
||||||
verbatim_doc_comment
|
|
||||||
)]
|
|
||||||
connection_min_idle: u32,
|
connection_min_idle: u32,
|
||||||
|
|
||||||
/// Domain name that will be use as SNI during TLS handshake
|
/// Domain name that will be use as SNI during TLS handshake
|
||||||
|
@ -88,12 +82,7 @@ struct Client {
|
||||||
tls_verify_certificate: bool,
|
tls_verify_certificate: bool,
|
||||||
|
|
||||||
/// If set, will use this http proxy to connect to the server
|
/// If set, will use this http proxy to connect to the server
|
||||||
#[arg(
|
#[arg(short = 'p', long, value_name = "http://USER:PASS@HOST:PORT", verbatim_doc_comment)]
|
||||||
short = 'p',
|
|
||||||
long,
|
|
||||||
value_name = "http://USER:PASS@HOST:PORT",
|
|
||||||
verbatim_doc_comment
|
|
||||||
)]
|
|
||||||
http_proxy: Option<Url>,
|
http_proxy: Option<Url>,
|
||||||
|
|
||||||
/// Use a specific prefix that will show up in the http path during the upgrade request.
|
/// Use a specific prefix that will show up in the http path during the upgrade request.
|
||||||
|
@ -241,9 +230,7 @@ fn parse_local_bind(arg: &str) -> Result<(SocketAddr, &str), io::Error> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
fn parse_tunnel_dest(
|
fn parse_tunnel_dest(remaining: &str) -> Result<(Host<String>, u16, BTreeMap<String, String>), io::Error> {
|
||||||
remaining: &str,
|
|
||||||
) -> Result<(Host<String>, u16, BTreeMap<String, String>), io::Error> {
|
|
||||||
use std::io::Error;
|
use std::io::Error;
|
||||||
|
|
||||||
let Ok(remote) = Url::parse(&format!("fake://{}", remaining)) else {
|
let Ok(remote) = Url::parse(&format!("fake://{}", remaining)) else {
|
||||||
|
@ -290,13 +277,7 @@ fn parse_tunnel_arg(arg: &str) -> Result<LocalToRemote, io::Error> {
|
||||||
let timeout = options
|
let timeout = options
|
||||||
.get("timeout_sec")
|
.get("timeout_sec")
|
||||||
.and_then(|x| x.parse::<u64>().ok())
|
.and_then(|x| x.parse::<u64>().ok())
|
||||||
.map(|d| {
|
.map(|d| if d == 0 { None } else { Some(Duration::from_secs(d)) })
|
||||||
if d == 0 {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(Duration::from_secs(d))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.unwrap_or(Some(Duration::from_secs(30)));
|
.unwrap_or(Some(Duration::from_secs(30)));
|
||||||
|
|
||||||
Ok(LocalToRemote {
|
Ok(LocalToRemote {
|
||||||
|
@ -355,10 +336,7 @@ fn parse_http_headers(arg: &str) -> Result<(HeaderName, HeaderValue), io::Error>
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
return Err(io::Error::new(
|
return Err(io::Error::new(
|
||||||
ErrorKind::InvalidInput,
|
ErrorKind::InvalidInput,
|
||||||
format!(
|
format!("cannot parse http header value from {} due to {:?}", value, err),
|
||||||
"cannot parse http header value from {} due to {:?}",
|
|
||||||
value, err
|
|
||||||
),
|
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -394,10 +372,7 @@ fn parse_server_url(arg: &str) -> Result<Url, io::Error> {
|
||||||
}
|
}
|
||||||
|
|
||||||
if url.host().is_none() {
|
if url.host().is_none() {
|
||||||
return Err(io::Error::new(
|
return Err(io::Error::new(ErrorKind::InvalidInput, format!("invalid server host {}", arg)));
|
||||||
ErrorKind::InvalidInput,
|
|
||||||
format!("invalid server host {}", arg),
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(url)
|
Ok(url)
|
||||||
|
@ -474,15 +449,9 @@ impl WsClientConfig {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tls_server_name(&self) -> ServerName {
|
pub fn tls_server_name(&self) -> ServerName {
|
||||||
match self
|
match self.tls.as_ref().and_then(|tls| tls.tls_sni_override.as_ref()) {
|
||||||
.tls
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|tls| tls.tls_sni_override.as_ref())
|
|
||||||
{
|
|
||||||
None => match &self.remote_addr.0 {
|
None => match &self.remote_addr.0 {
|
||||||
Host::Domain(domain) => {
|
Host::Domain(domain) => ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap()),
|
||||||
ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap())
|
|
||||||
}
|
|
||||||
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip)),
|
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip)),
|
||||||
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip)),
|
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip)),
|
||||||
},
|
},
|
||||||
|
@ -529,8 +498,7 @@ async fn main() {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Extract host header from http_headers
|
// Extract host header from http_headers
|
||||||
let host_header =
|
let host_header = if let Some((_, host_val)) = args.http_headers.iter().find(|(h, _)| *h == HOST) {
|
||||||
if let Some((_, host_val)) = args.http_headers.iter().find(|(h, _)| *h == HOST) {
|
|
||||||
host_val.clone()
|
host_val.clone()
|
||||||
} else {
|
} else {
|
||||||
HeaderValue::from_str(&args.remote_addr.host().unwrap().to_string()).unwrap()
|
HeaderValue::from_str(&args.remote_addr.host().unwrap().to_string()).unwrap()
|
||||||
|
@ -544,16 +512,10 @@ async fn main() {
|
||||||
tls,
|
tls,
|
||||||
http_upgrade_path_prefix: args.http_upgrade_path_prefix,
|
http_upgrade_path_prefix: args.http_upgrade_path_prefix,
|
||||||
http_upgrade_credentials: args.http_upgrade_credentials,
|
http_upgrade_credentials: args.http_upgrade_credentials,
|
||||||
http_headers: args
|
http_headers: args.http_headers.into_iter().filter(|(k, _)| k != HOST).collect(),
|
||||||
.http_headers
|
|
||||||
.into_iter()
|
|
||||||
.filter(|(k, _)| k != HOST)
|
|
||||||
.collect(),
|
|
||||||
http_header_host: host_header,
|
http_header_host: host_header,
|
||||||
timeout_connect: Duration::from_secs(10),
|
timeout_connect: Duration::from_secs(10),
|
||||||
websocket_ping_frequency: args
|
websocket_ping_frequency: args.websocket_ping_frequency_sec.unwrap_or(Duration::from_secs(30)),
|
||||||
.websocket_ping_frequency_sec
|
|
||||||
.unwrap_or(Duration::from_secs(30)),
|
|
||||||
websocket_mask_frame: args.websocket_mask_frame,
|
websocket_mask_frame: args.websocket_mask_frame,
|
||||||
http_proxy: args.http_proxy,
|
http_proxy: args.http_proxy,
|
||||||
cnx_pool: None,
|
cnx_pool: None,
|
||||||
|
@ -579,16 +541,12 @@ async fn main() {
|
||||||
let remote = tunnel.remote.clone();
|
let remote = tunnel.remote.clone();
|
||||||
let server = tcp::run_server(tunnel.local)
|
let server = tcp::run_server(tunnel.local)
|
||||||
.await
|
.await
|
||||||
.unwrap_or_else(|err| {
|
.unwrap_or_else(|err| panic!("Cannot start TCP server on {}: {}", tunnel.local, err))
|
||||||
panic!("Cannot start TCP server on {}: {}", tunnel.local, err)
|
|
||||||
})
|
|
||||||
.map_err(anyhow::Error::new)
|
.map_err(anyhow::Error::new)
|
||||||
.map_ok(move |stream| (stream.into_split(), remote.clone()));
|
.map_ok(move |stream| (stream.into_split(), remote.clone()));
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(err) =
|
if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await {
|
||||||
tunnel::client::run_tunnel(client_config, tunnel, server).await
|
|
||||||
{
|
|
||||||
error!("{:?}", err);
|
error!("{:?}", err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -597,16 +555,12 @@ async fn main() {
|
||||||
let remote = tunnel.remote.clone();
|
let remote = tunnel.remote.clone();
|
||||||
let server = udp::run_server(tunnel.local, *timeout)
|
let server = udp::run_server(tunnel.local, *timeout)
|
||||||
.await
|
.await
|
||||||
.unwrap_or_else(|err| {
|
.unwrap_or_else(|err| panic!("Cannot start UDP server on {}: {}", tunnel.local, err))
|
||||||
panic!("Cannot start UDP server on {}: {}", tunnel.local, err)
|
|
||||||
})
|
|
||||||
.map_err(anyhow::Error::new)
|
.map_err(anyhow::Error::new)
|
||||||
.map_ok(move |stream| (tokio::io::split(stream), remote.clone()));
|
.map_ok(move |stream| (tokio::io::split(stream), remote.clone()));
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(err) =
|
if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await {
|
||||||
tunnel::client::run_tunnel(client_config, tunnel, server).await
|
|
||||||
{
|
|
||||||
error!("{:?}", err);
|
error!("{:?}", err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -614,15 +568,11 @@ async fn main() {
|
||||||
LocalProtocol::Socks5 => {
|
LocalProtocol::Socks5 => {
|
||||||
let server = socks5::run_server(tunnel.local)
|
let server = socks5::run_server(tunnel.local)
|
||||||
.await
|
.await
|
||||||
.unwrap_or_else(|err| {
|
.unwrap_or_else(|err| panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err))
|
||||||
panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err)
|
|
||||||
})
|
|
||||||
.map_ok(|(stream, remote_dest)| (stream.into_split(), remote_dest));
|
.map_ok(|(stream, remote_dest)| (stream.into_split(), remote_dest));
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(err) =
|
if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await {
|
||||||
tunnel::client::run_tunnel(client_config, tunnel, server).await
|
|
||||||
{
|
|
||||||
error!("{:?}", err);
|
error!("{:?}", err);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -656,8 +606,7 @@ async fn main() {
|
||||||
Commands::Server(args) => {
|
Commands::Server(args) => {
|
||||||
let tls_config = if args.remote_addr.scheme() == "wss" {
|
let tls_config = if args.remote_addr.scheme() == "wss" {
|
||||||
let tls_certificate = if let Some(cert_path) = args.tls_certificate {
|
let tls_certificate = if let Some(cert_path) = args.tls_certificate {
|
||||||
tls::load_certificates_from_pem(&cert_path)
|
tls::load_certificates_from_pem(&cert_path).expect("Cannot load tls certificate")
|
||||||
.expect("Cannot load tls certificate")
|
|
||||||
} else {
|
} else {
|
||||||
embedded_certificate::TLS_CERTIFICATE.clone()
|
embedded_certificate::TLS_CERTIFICATE.clone()
|
||||||
};
|
};
|
||||||
|
|
|
@ -19,10 +19,7 @@ pub struct Socks5Listener {
|
||||||
impl Stream for Socks5Listener {
|
impl Stream for Socks5Listener {
|
||||||
type Item = anyhow::Result<(TcpStream, (Host, u16))>;
|
type Item = anyhow::Result<(TcpStream, (Host, u16))>;
|
||||||
|
|
||||||
fn poll_next(
|
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> Poll<Option<Self::Item>> {
|
|
||||||
unsafe { self.map_unchecked_mut(|x| &mut x.stream) }.poll_next(cx)
|
unsafe { self.map_unchecked_mut(|x| &mut x.stream) }.poll_next(cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
38
src/tcp.rs
38
src/tcp.rs
|
@ -14,12 +14,9 @@ use tracing::log::info;
|
||||||
use url::{Host, Url};
|
use url::{Host, Url};
|
||||||
|
|
||||||
fn configure_socket(socket: &mut TcpSocket, so_mark: &Option<i32>) -> Result<(), anyhow::Error> {
|
fn configure_socket(socket: &mut TcpSocket, so_mark: &Option<i32>) -> Result<(), anyhow::Error> {
|
||||||
socket.set_nodelay(true).with_context(|| {
|
socket
|
||||||
format!(
|
.set_nodelay(true)
|
||||||
"cannot set no_delay on socket: {}",
|
.with_context(|| format!("cannot set no_delay on socket: {}", io::Error::last_os_error()))?;
|
||||||
io::Error::last_os_error()
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
#[cfg(target_os = "linux")]
|
#[cfg(target_os = "linux")]
|
||||||
if let Some(so_mark) = so_mark {
|
if let Some(so_mark) = so_mark {
|
||||||
|
@ -35,10 +32,7 @@ fn configure_socket(socket: &mut TcpSocket, so_mark: &Option<i32>) -> Result<(),
|
||||||
);
|
);
|
||||||
|
|
||||||
if ret != 0 {
|
if ret != 0 {
|
||||||
return Err(anyhow!(
|
return Err(anyhow!("Cannot set SO_MARK on the connection {:?}", io::Error::last_os_error()));
|
||||||
"Cannot set SO_MARK on the connection {:?}",
|
|
||||||
io::Error::last_os_error()
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -117,17 +111,14 @@ pub async fn connect_with_http_proxy(
|
||||||
let mut socket = connect(&proxy_host, proxy_port, so_mark, connect_timeout).await?;
|
let mut socket = connect(&proxy_host, proxy_port, so_mark, connect_timeout).await?;
|
||||||
info!("Connected to http proxy {}:{}", proxy_host, proxy_port);
|
info!("Connected to http proxy {}:{}", proxy_host, proxy_port);
|
||||||
|
|
||||||
let authorization =
|
let authorization = if let Some((user, password)) = proxy.password().map(|p| (proxy.username(), p)) {
|
||||||
if let Some((user, password)) = proxy.password().map(|p| (proxy.username(), p)) {
|
let creds = base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", user, password));
|
||||||
let creds =
|
|
||||||
base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", user, password));
|
|
||||||
format!("Proxy-Authorization: Basic {}\r\n", creds)
|
format!("Proxy-Authorization: Basic {}\r\n", creds)
|
||||||
} else {
|
} else {
|
||||||
"".to_string()
|
"".to_string()
|
||||||
};
|
};
|
||||||
|
|
||||||
let connect_request =
|
let connect_request = format!("CONNECT {host}:{port} HTTP/1.0\r\nHost: {host}:{port}\r\n{authorization}\r\n");
|
||||||
format!("CONNECT {host}:{port} HTTP/1.0\r\nHost: {host}:{port}\r\n{authorization}\r\n");
|
|
||||||
socket.write_all(connect_request.as_bytes()).await?;
|
socket.write_all(connect_request.as_bytes()).await?;
|
||||||
|
|
||||||
let mut buf = BytesMut::with_capacity(1024);
|
let mut buf = BytesMut::with_capacity(1024);
|
||||||
|
@ -136,16 +127,15 @@ pub async fn connect_with_http_proxy(
|
||||||
match nb_bytes {
|
match nb_bytes {
|
||||||
Ok(Ok(0)) => {
|
Ok(Ok(0)) => {
|
||||||
return Err(anyhow!(
|
return Err(anyhow!(
|
||||||
"Cannot connect to http proxy. Proxy closed the connection without returning any response"));
|
"Cannot connect to http proxy. Proxy closed the connection without returning any response"
|
||||||
|
));
|
||||||
}
|
}
|
||||||
Ok(Ok(_)) => {}
|
Ok(Ok(_)) => {}
|
||||||
Ok(Err(err)) => {
|
Ok(Err(err)) => {
|
||||||
return Err(anyhow!("Cannot connect to http proxy. {err}"));
|
return Err(anyhow!("Cannot connect to http proxy. {err}"));
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
return Err(anyhow!(
|
return Err(anyhow!("Cannot connect to http proxy. Proxy took too long to connect"));
|
||||||
"Cannot connect to http proxy. Proxy took too long to connect"
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -225,8 +215,7 @@ mod tests {
|
||||||
let server = TcpListener::bind(server_addr).await.unwrap();
|
let server = TcpListener::bind(server_addr).await.unwrap();
|
||||||
|
|
||||||
let docker = testcontainers::clients::Cli::default();
|
let docker = testcontainers::clients::Cli::default();
|
||||||
let mitm_proxy: RunnableImage<MitmProxy> =
|
let mitm_proxy: RunnableImage<MitmProxy> = RunnableImage::from(MitmProxy {}).with_network("host".to_string());
|
||||||
RunnableImage::from(MitmProxy {}).with_network("host".to_string());
|
|
||||||
let _node = docker.run(mitm_proxy);
|
let _node = docker.run(mitm_proxy);
|
||||||
|
|
||||||
let mut client = connect_with_http_proxy(
|
let mut client = connect_with_http_proxy(
|
||||||
|
@ -239,10 +228,7 @@ mod tests {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
client
|
client.write_all(b"GET / HTTP/1.1\r\n\r\n".as_slice()).await.unwrap();
|
||||||
.write_all(b"GET / HTTP/1.1\r\n\r\n".as_slice())
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let client_srv = server.accept().await.unwrap().0;
|
let client_srv = server.accept().await.unwrap().0;
|
||||||
pin_mut!(client_srv);
|
pin_mut!(client_srv);
|
||||||
|
|
||||||
|
|
28
src/tls.rs
28
src/tls.rs
|
@ -45,21 +45,15 @@ pub fn load_private_key_from_file(path: &Path) -> anyhow::Result<PrivateKey> {
|
||||||
match keys.len() {
|
match keys.len() {
|
||||||
0 => Err(anyhow!("No PKCS8-encoded private key found in {path:?}")),
|
0 => Err(anyhow!("No PKCS8-encoded private key found in {path:?}")),
|
||||||
1 => Ok(PrivateKey(keys.remove(0))),
|
1 => Ok(PrivateKey(keys.remove(0))),
|
||||||
_ => Err(anyhow!(
|
_ => Err(anyhow!("More than one PKCS8-encoded private key found in {path:?}")),
|
||||||
"More than one PKCS8-encoded private key found in {path:?}"
|
|
||||||
)),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tls_connector(
|
fn tls_connector(tls_cfg: &TlsClientConfig, alpn_protocols: Option<Vec<Vec<u8>>>) -> anyhow::Result<TlsConnector> {
|
||||||
tls_cfg: &TlsClientConfig,
|
|
||||||
alpn_protocols: Option<Vec<Vec<u8>>>,
|
|
||||||
) -> anyhow::Result<TlsConnector> {
|
|
||||||
let mut root_store = rustls::RootCertStore::empty();
|
let mut root_store = rustls::RootCertStore::empty();
|
||||||
|
|
||||||
// Load system certificates and add them to the root store
|
// Load system certificates and add them to the root store
|
||||||
let certs = rustls_native_certs::load_native_certs()
|
let certs = rustls_native_certs::load_native_certs().with_context(|| "Cannot load system certificates")?;
|
||||||
.with_context(|| "Cannot load system certificates")?;
|
|
||||||
for cert in certs {
|
for cert in certs {
|
||||||
root_store.add(&Certificate(cert.0))?;
|
root_store.add(&Certificate(cert.0))?;
|
||||||
}
|
}
|
||||||
|
@ -71,9 +65,7 @@ fn tls_connector(
|
||||||
|
|
||||||
// To bypass certificate verification
|
// To bypass certificate verification
|
||||||
if !tls_cfg.tls_verify_certificate {
|
if !tls_cfg.tls_verify_certificate {
|
||||||
config
|
config.dangerous().set_certificate_verifier(Arc::new(NullVerifier));
|
||||||
.dangerous()
|
|
||||||
.set_certificate_verifier(Arc::new(NullVerifier));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(alpn_protocols) = alpn_protocols {
|
if let Some(alpn_protocols) = alpn_protocols {
|
||||||
|
@ -83,10 +75,7 @@ fn tls_connector(
|
||||||
Ok(tls_connector)
|
Ok(tls_connector)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tls_acceptor(
|
pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option<Vec<Vec<u8>>>) -> anyhow::Result<TlsAcceptor> {
|
||||||
tls_cfg: &TlsServerConfig,
|
|
||||||
alpn_protocols: Option<Vec<Vec<u8>>>,
|
|
||||||
) -> anyhow::Result<TlsAcceptor> {
|
|
||||||
let mut config = rustls::ServerConfig::builder()
|
let mut config = rustls::ServerConfig::builder()
|
||||||
.with_safe_defaults()
|
.with_safe_defaults()
|
||||||
.with_no_client_auth()
|
.with_no_client_auth()
|
||||||
|
@ -114,12 +103,7 @@ pub async fn connect(
|
||||||
let tls_stream = tls_connector
|
let tls_stream = tls_connector
|
||||||
.connect(sni, tcp_stream)
|
.connect(sni, tcp_stream)
|
||||||
.await
|
.await
|
||||||
.with_context(|| {
|
.with_context(|| format!("failed to do TLS handshake with the server {:?}", client_cfg.remote_addr))?;
|
||||||
format!(
|
|
||||||
"failed to do TLS handshake with the server {:?}",
|
|
||||||
client_cfg.remote_addr
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(tls_stream)
|
Ok(tls_stream)
|
||||||
}
|
}
|
||||||
|
|
|
@ -44,9 +44,7 @@ pub async fn connect(
|
||||||
) -> anyhow::Result<WebSocket<Upgraded>> {
|
) -> anyhow::Result<WebSocket<Upgraded>> {
|
||||||
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
|
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
|
||||||
Ok(tcp_stream) => tcp_stream,
|
Ok(tcp_stream) => tcp_stream,
|
||||||
Err(err) => Err(anyhow!(
|
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}"))?,
|
||||||
"failed to get a connection to the server from the pool: {err:?}"
|
|
||||||
))?,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut req = Request::builder()
|
let mut req = Request::builder()
|
||||||
|
@ -80,12 +78,7 @@ pub async fn connect(
|
||||||
let transport = pooled_cnx.deref_mut().take().unwrap();
|
let transport = pooled_cnx.deref_mut().take().unwrap();
|
||||||
let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, transport)
|
let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, transport)
|
||||||
.await
|
.await
|
||||||
.with_context(|| {
|
.with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?;
|
||||||
format!(
|
|
||||||
"failed to do websocket handshake with the server {:?}",
|
|
||||||
client_cfg.remote_addr
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(ws)
|
Ok(ws)
|
||||||
}
|
}
|
||||||
|
@ -109,10 +102,7 @@ where
|
||||||
|
|
||||||
// Forward local tx to websocket tx
|
// Forward local tx to websocket tx
|
||||||
let ping_frequency = client_cfg.websocket_ping_frequency;
|
let ping_frequency = client_cfg.websocket_ping_frequency;
|
||||||
tokio::spawn(
|
tokio::spawn(super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency).instrument(Span::current()));
|
||||||
super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency)
|
|
||||||
.instrument(Span::current()),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Forward websocket rx to local rx
|
// Forward websocket rx to local rx
|
||||||
let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await;
|
let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await;
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite};
|
use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite};
|
||||||
use futures_util::pin_mut;
|
use futures_util::pin_mut;
|
||||||
use hyper::upgrade::Upgraded;
|
use hyper::upgrade::Upgraded;
|
||||||
|
use std::pin::Pin;
|
||||||
|
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
|
||||||
use tokio::select;
|
use tokio::select;
|
||||||
use tokio::sync::oneshot;
|
use tokio::sync::oneshot;
|
||||||
use tokio::time::timeout;
|
|
||||||
use tracing::log::debug;
|
use tracing::log::debug;
|
||||||
use tracing::{error, info, trace, warn};
|
use tracing::{error, info, trace, warn};
|
||||||
|
|
||||||
|
@ -20,7 +20,14 @@ pub(super) async fn propagate_read(
|
||||||
info!("Closing local tx ==> websocket tx tunnel");
|
info!("Closing local tx ==> websocket tx tunnel");
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut buffer = vec![0u8; 8 * 1024];
|
static JUMBO_FRAME_SIZE: usize = 9 * 1024; // enough for a jumbo frame
|
||||||
|
let mut buffer = vec![0u8; JUMBO_FRAME_SIZE];
|
||||||
|
|
||||||
|
// We do our own pin_mut! to avoid shadowing timeout and be able to reset it, on next loop iteration
|
||||||
|
// We reuse the future to avoid creating a timer in the tight loop
|
||||||
|
let mut timeout_unpin = tokio::time::sleep(ping_frequency);
|
||||||
|
let mut timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) };
|
||||||
|
|
||||||
pin_mut!(local_rx);
|
pin_mut!(local_rx);
|
||||||
loop {
|
loop {
|
||||||
let read_len = select! {
|
let read_len = select! {
|
||||||
|
@ -30,9 +37,12 @@ pub(super) async fn propagate_read(
|
||||||
|
|
||||||
_ = close_tx.closed() => break,
|
_ = close_tx.closed() => break,
|
||||||
|
|
||||||
_ = timeout(ping_frequency, futures_util::future::pending::<()>()) => {
|
_ = &mut timeout => {
|
||||||
debug!("sending ping to keep websocket connection alive");
|
debug!("sending ping to keep websocket connection alive");
|
||||||
ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?;
|
ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?;
|
||||||
|
|
||||||
|
timeout_unpin = tokio::time::sleep(ping_frequency);
|
||||||
|
timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) };
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -41,32 +51,30 @@ pub(super) async fn propagate_read(
|
||||||
Ok(0) => break,
|
Ok(0) => break,
|
||||||
Ok(read_len) => read_len,
|
Ok(read_len) => read_len,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
warn!(
|
warn!("error while reading incoming bytes from local tx tunnel {}", err);
|
||||||
"error while reading incoming bytes from local tx tunnel {}",
|
|
||||||
err
|
|
||||||
);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
trace!("read {} bytes", read_len);
|
trace!("read {} bytes", read_len);
|
||||||
match ws_tx
|
if let Err(err) = ws_tx
|
||||||
.write_frame(Frame::binary(Payload::BorrowedMut(&mut buffer[..read_len])))
|
.write_frame(Frame::binary(Payload::BorrowedMut(&mut buffer[..read_len])))
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(_) => {}
|
|
||||||
Err(err) => {
|
|
||||||
warn!("error while writing to websocket tx tunnel {}", err);
|
warn!("error while writing to websocket tx tunnel {}", err);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
// If the buffer has been completely filled with previous read, Double it !
|
||||||
|
// For the buffer to not be a bottleneck when the TCP window scale
|
||||||
|
// For udp, the buffer will never grows.
|
||||||
if buffer.capacity() == read_len {
|
if buffer.capacity() == read_len {
|
||||||
buffer.clear();
|
buffer.clear();
|
||||||
buffer.resize(buffer.capacity() * 2, 0);
|
buffer.resize(buffer.capacity() * 2, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Send normal close
|
||||||
let _ = ws_tx.write_frame(Frame::close(1000, &[])).await;
|
let _ = ws_tx.write_frame(Frame::close(1000, &[])).await;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -104,22 +112,17 @@ pub(super) async fn propagate_write(
|
||||||
|
|
||||||
trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload);
|
trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload);
|
||||||
let ret = match msg.opcode {
|
let ret = match msg.opcode {
|
||||||
OpCode::Continuation | OpCode::Text | OpCode::Binary => {
|
OpCode::Continuation | OpCode::Text | OpCode::Binary => local_tx.write_all(msg.payload.as_ref()).await,
|
||||||
local_tx.write_all(msg.payload.as_ref()).await
|
|
||||||
}
|
|
||||||
OpCode::Close => break,
|
OpCode::Close => break,
|
||||||
OpCode::Ping => Ok(()),
|
OpCode::Ping => Ok(()),
|
||||||
OpCode::Pong => Ok(()),
|
OpCode::Pong => Ok(()),
|
||||||
};
|
};
|
||||||
|
|
||||||
match ret {
|
if let Err(err) = ret {
|
||||||
Ok(_) => {}
|
|
||||||
Err(err) => {
|
|
||||||
error!("error while writing bytes to local for rx tunnel {}", err);
|
error!("error while writing bytes to local for rx tunnel {}", err);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,12 +42,8 @@ impl JwtTunnelConfig {
|
||||||
}
|
}
|
||||||
|
|
||||||
static JWT_SECRET: &[u8; 15] = b"champignonfrais";
|
static JWT_SECRET: &[u8; 15] = b"champignonfrais";
|
||||||
static JWT_KEY: Lazy<(Header, EncodingKey)> = Lazy::new(|| {
|
static JWT_KEY: Lazy<(Header, EncodingKey)> =
|
||||||
(
|
Lazy::new(|| (Header::new(Algorithm::HS256), EncodingKey::from_secret(JWT_SECRET)));
|
||||||
Header::new(Algorithm::HS256),
|
|
||||||
EncodingKey::from_secret(JWT_SECRET),
|
|
||||||
)
|
|
||||||
});
|
|
||||||
|
|
||||||
static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| {
|
static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| {
|
||||||
let mut validation = Validation::new(Algorithm::HS256);
|
let mut validation = Validation::new(Algorithm::HS256);
|
||||||
|
@ -61,11 +57,7 @@ pub enum TransportStream {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsyncRead for TransportStream {
|
impl AsyncRead for TransportStream {
|
||||||
fn poll_read(
|
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &mut ReadBuf<'_>,
|
|
||||||
) -> Poll<std::io::Result<()>> {
|
|
||||||
match self.get_mut() {
|
match self.get_mut() {
|
||||||
TransportStream::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf),
|
TransportStream::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf),
|
||||||
TransportStream::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf),
|
TransportStream::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf),
|
||||||
|
@ -74,11 +66,7 @@ impl AsyncRead for TransportStream {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsyncWrite for TransportStream {
|
impl AsyncWrite for TransportStream {
|
||||||
fn poll_write(
|
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &[u8],
|
|
||||||
) -> Poll<Result<usize, Error>> {
|
|
||||||
match self.get_mut() {
|
match self.get_mut() {
|
||||||
TransportStream::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf),
|
TransportStream::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf),
|
||||||
TransportStream::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf),
|
TransportStream::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf),
|
||||||
|
|
|
@ -46,15 +46,8 @@ async fn from_query(
|
||||||
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
|
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
|
||||||
if let Some(allowed_dests) = &server_config.restrict_to {
|
if let Some(allowed_dests) = &server_config.restrict_to {
|
||||||
let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp);
|
let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp);
|
||||||
if allowed_dests
|
if allowed_dests.iter().any(|dest| dest == &requested_dest).not() {
|
||||||
.iter()
|
warn!("Rejecting connection with not allowed destination: {}", requested_dest);
|
||||||
.any(|dest| dest == &requested_dest)
|
|
||||||
.not()
|
|
||||||
{
|
|
||||||
warn!(
|
|
||||||
"Rejecting connection with not allowed destination: {}",
|
|
||||||
requested_dest
|
|
||||||
);
|
|
||||||
return Err(anyhow::anyhow!("Invalid upgrade request"));
|
return Err(anyhow::anyhow!("Invalid upgrade request"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -75,12 +68,7 @@ async fn from_query(
|
||||||
LocalProtocol::Tcp { .. } => {
|
LocalProtocol::Tcp { .. } => {
|
||||||
let host = Host::parse(&jwt.claims.r)?;
|
let host = Host::parse(&jwt.claims.r)?;
|
||||||
let port = jwt.claims.rp;
|
let port = jwt.claims.rp;
|
||||||
let (rx, tx) = tcp::connect(
|
let (rx, tx) = tcp::connect(&host, port, &server_config.socket_so_mark, Duration::from_secs(10))
|
||||||
&host,
|
|
||||||
port,
|
|
||||||
&server_config.socket_so_mark,
|
|
||||||
Duration::from_secs(10),
|
|
||||||
)
|
|
||||||
.await?
|
.await?
|
||||||
.into_split();
|
.into_split();
|
||||||
|
|
||||||
|
@ -100,10 +88,7 @@ async fn server_upgrade(
|
||||||
}
|
}
|
||||||
|
|
||||||
if !req.uri().path().ends_with("/events") {
|
if !req.uri().path().ends_with("/events") {
|
||||||
warn!(
|
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
|
||||||
"Rejecting connection with bad upgrade request: {}",
|
|
||||||
req.uri()
|
|
||||||
);
|
|
||||||
return Ok(http::Response::builder()
|
return Ok(http::Response::builder()
|
||||||
.status(StatusCode::BAD_REQUEST)
|
.status(StatusCode::BAD_REQUEST)
|
||||||
.body(Body::from("Invalid upgrade request"))
|
.body(Body::from("Invalid upgrade request"))
|
||||||
|
@ -118,10 +103,7 @@ async fn server_upgrade(
|
||||||
|| &path[min_len..max_len] != path_prefix.as_str()
|
|| &path[min_len..max_len] != path_prefix.as_str()
|
||||||
|| !path[max_len..].starts_with('/')
|
|| !path[max_len..].starts_with('/')
|
||||||
{
|
{
|
||||||
warn!(
|
warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri());
|
||||||
"Rejecting connection with bad path prefix in upgrade request: {}",
|
|
||||||
req.uri()
|
|
||||||
);
|
|
||||||
return Ok(http::Response::builder()
|
return Ok(http::Response::builder()
|
||||||
.status(StatusCode::BAD_REQUEST)
|
.status(StatusCode::BAD_REQUEST)
|
||||||
.body(Body::from("Invalid upgrade request"))
|
.body(Body::from("Invalid upgrade request"))
|
||||||
|
@ -133,11 +115,7 @@ async fn server_upgrade(
|
||||||
match from_query(&server_config, req.uri().query().unwrap_or_default()).await {
|
match from_query(&server_config, req.uri().query().unwrap_or_default()).await {
|
||||||
Ok(ret) => ret,
|
Ok(ret) => ret,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
warn!(
|
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
||||||
"Rejecting connection with bad upgrade request: {} {}",
|
|
||||||
err,
|
|
||||||
req.uri()
|
|
||||||
);
|
|
||||||
return Ok(http::Response::builder()
|
return Ok(http::Response::builder()
|
||||||
.status(StatusCode::BAD_REQUEST)
|
.status(StatusCode::BAD_REQUEST)
|
||||||
.body(Body::from(format!("Invalid upgrade request: {:?}", err)))
|
.body(Body::from(format!("Invalid upgrade request: {:?}", err)))
|
||||||
|
@ -149,11 +127,7 @@ async fn server_upgrade(
|
||||||
let (response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) {
|
let (response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) {
|
||||||
Ok(ret) => ret,
|
Ok(ret) => ret,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
warn!(
|
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
||||||
"Rejecting connection with bad upgrade request: {} {}",
|
|
||||||
err,
|
|
||||||
req.uri()
|
|
||||||
);
|
|
||||||
return Ok(http::Response::builder()
|
return Ok(http::Response::builder()
|
||||||
.status(StatusCode::BAD_REQUEST)
|
.status(StatusCode::BAD_REQUEST)
|
||||||
.body(Body::from(format!("Invalid upgrade request: {:?}", err)))
|
.body(Body::from(format!("Invalid upgrade request: {:?}", err)))
|
||||||
|
@ -171,14 +145,10 @@ async fn server_upgrade(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let (close_tx, close_rx) = oneshot::channel::<()>();
|
let (close_tx, close_rx) = oneshot::channel::<()>();
|
||||||
let ping_frequency = server_config
|
let ping_frequency = server_config.websocket_ping_frequency.unwrap_or(Duration::MAX);
|
||||||
.websocket_ping_frequency
|
|
||||||
.unwrap_or(Duration::MAX);
|
|
||||||
ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame);
|
ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame);
|
||||||
|
|
||||||
tokio::task::spawn(
|
tokio::task::spawn(super::io::propagate_write(local_tx, ws_rx, close_rx).instrument(Span::current()));
|
||||||
super::io::propagate_write(local_tx, ws_rx, close_rx).instrument(Span::current()),
|
|
||||||
);
|
|
||||||
|
|
||||||
let _ = super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency).await;
|
let _ = super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency).await;
|
||||||
}
|
}
|
||||||
|
@ -189,10 +159,7 @@ async fn server_upgrade(
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> {
|
pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> {
|
||||||
info!(
|
info!("Starting wstunnel server listening on {}", server_config.bind);
|
||||||
"Starting wstunnel server listening on {}",
|
|
||||||
server_config.bind
|
|
||||||
);
|
|
||||||
|
|
||||||
let config = server_config.clone();
|
let config = server_config.clone();
|
||||||
let upgrade_fn = move |req: Request<Body>| server_upgrade(config.clone(), req);
|
let upgrade_fn = move |req: Request<Body>| server_upgrade(config.clone(), req);
|
||||||
|
|
155
src/udp.rs
155
src/udp.rs
|
@ -1,18 +1,17 @@
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use futures_util::{stream, Stream};
|
use futures_util::{stream, Stream};
|
||||||
|
|
||||||
use parking_lot::{Mutex, RwLock};
|
use parking_lot::RwLock;
|
||||||
use pin_project::{pin_project, pinned_drop};
|
use pin_project::{pin_project, pinned_drop};
|
||||||
use std::collections::hash_map::Entry;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::io::{Error, ErrorKind};
|
use std::io::{Error, ErrorKind};
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::ops::DerefMut;
|
|
||||||
use std::pin::{pin, Pin};
|
use std::pin::{pin, Pin};
|
||||||
use std::sync::{Arc, Weak};
|
use std::sync::{Arc, Weak};
|
||||||
use std::task::{ready, Poll, Waker};
|
use std::task::{ready, Poll};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
use tokio::net::UdpSocket;
|
use tokio::net::UdpSocket;
|
||||||
|
@ -23,8 +22,7 @@ use tokio::time::Sleep;
|
||||||
use tracing::{debug, error, info};
|
use tracing::{debug, error, info};
|
||||||
|
|
||||||
struct IoInner {
|
struct IoInner {
|
||||||
has_data_to_read: &'static Notify,
|
has_data_to_read: Notify,
|
||||||
waker: Mutex<Option<Waker>>,
|
|
||||||
has_read_data: Notify,
|
has_read_data: Notify,
|
||||||
}
|
}
|
||||||
struct UdpServer {
|
struct UdpServer {
|
||||||
|
@ -43,6 +41,7 @@ impl UdpServer {
|
||||||
cnx_timeout: timeout,
|
cnx_timeout: timeout,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#[inline]
|
||||||
fn clean_dead_keys(&mut self) {
|
fn clean_dead_keys(&mut self) {
|
||||||
let nb_key_to_delete = self.keys_to_delete.read().len();
|
let nb_key_to_delete = self.keys_to_delete.read().len();
|
||||||
if nb_key_to_delete == 0 {
|
if nb_key_to_delete == 0 {
|
||||||
|
@ -52,16 +51,7 @@ impl UdpServer {
|
||||||
debug!("Cleaning {} dead udp peers", nb_key_to_delete);
|
debug!("Cleaning {} dead udp peers", nb_key_to_delete);
|
||||||
let mut keys_to_delete = self.keys_to_delete.write();
|
let mut keys_to_delete = self.keys_to_delete.write();
|
||||||
for key in keys_to_delete.iter() {
|
for key in keys_to_delete.iter() {
|
||||||
let Some(peer) = self.peers.remove(key) else {
|
self.peers.remove(key);
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
#[allow(mutable_transmutes)]
|
|
||||||
unsafe {
|
|
||||||
let _ = Box::from_raw(std::mem::transmute::<&Notify, &mut Notify>(
|
|
||||||
peer.has_data_to_read,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
keys_to_delete.clear();
|
keys_to_delete.clear();
|
||||||
}
|
}
|
||||||
|
@ -90,7 +80,42 @@ impl PinnedDrop for UdpStream {
|
||||||
keys_to_delete.write().push(self.peer);
|
keys_to_delete.write().push(self.peer);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.io.has_read_data.notify_one();
|
// safety: we are dropping the notification as we extend its lifetime to 'static unsafely
|
||||||
|
// So it must be gone before we drop its parent. It should never happen but in case
|
||||||
|
let mut project = self.project();
|
||||||
|
project.pending_notification.as_mut().set(None);
|
||||||
|
project.io.has_read_data.notify_one();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UdpStream {
|
||||||
|
fn new(
|
||||||
|
socket: Arc<UdpSocket>,
|
||||||
|
peer: SocketAddr,
|
||||||
|
deadline: Option<Sleep>,
|
||||||
|
keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>,
|
||||||
|
) -> (Self, Arc<IoInner>) {
|
||||||
|
let has_data_to_read = Notify::new();
|
||||||
|
let has_read_data = Notify::new();
|
||||||
|
has_data_to_read.notify_one();
|
||||||
|
let io = Arc::new(IoInner {
|
||||||
|
has_data_to_read,
|
||||||
|
has_read_data,
|
||||||
|
});
|
||||||
|
let mut s = Self {
|
||||||
|
socket,
|
||||||
|
peer,
|
||||||
|
deadline,
|
||||||
|
has_been_notified: false,
|
||||||
|
pending_notification: None,
|
||||||
|
io: io.clone(),
|
||||||
|
keys_to_delete,
|
||||||
|
};
|
||||||
|
|
||||||
|
let pending_notification = unsafe { std::mem::transmute(s.io.has_data_to_read.notified()) };
|
||||||
|
s.pending_notification = Some(pending_notification);
|
||||||
|
|
||||||
|
(s, io)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -111,43 +136,28 @@ impl AsyncRead for UdpStream {
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(notified) = project.pending_notification.as_mut().as_pin_mut() {
|
if let Some(notified) = project.pending_notification.as_mut().as_pin_mut() {
|
||||||
if !notified.poll(cx).is_ready() {
|
ready!(notified.poll(cx));
|
||||||
project.io.waker.lock().replace(cx.waker().clone());
|
|
||||||
return Poll::Pending;
|
|
||||||
}
|
|
||||||
project.pending_notification.as_mut().set(None);
|
project.pending_notification.as_mut().set(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
let _ = ready!(project.socket.poll_recv(cx, obuf));
|
let _ = ready!(project.socket.poll_recv(cx, obuf));
|
||||||
project
|
let notified: Notified<'static> = unsafe { std::mem::transmute(project.io.has_data_to_read.notified()) };
|
||||||
.pending_notification
|
project.pending_notification.as_mut().set(Some(notified));
|
||||||
.as_mut()
|
|
||||||
.set(Some(project.io.has_data_to_read.notified()));
|
|
||||||
project.io.has_read_data.notify_one();
|
project.io.has_read_data.notify_one();
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsyncWrite for UdpStream {
|
impl AsyncWrite for UdpStream {
|
||||||
fn poll_write(
|
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
buf: &[u8],
|
|
||||||
) -> Poll<Result<usize, Error>> {
|
|
||||||
self.socket.poll_send_to(cx, buf, self.peer)
|
self.socket.poll_send_to(cx, buf, self.peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_flush(
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
) -> Poll<Result<(), Error>> {
|
|
||||||
self.socket.poll_send_ready(cx)
|
self.socket.poll_send_ready(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_shutdown(
|
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
|
||||||
self: Pin<&mut Self>,
|
|
||||||
_cx: &mut std::task::Context<'_>,
|
|
||||||
) -> Poll<Result<(), Error>> {
|
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -178,39 +188,22 @@ pub async fn run_server(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match server.peers.entry(peer_addr) {
|
match server.peers.get(&peer_addr) {
|
||||||
Entry::Occupied(mut peer) => {
|
Some(io) => {
|
||||||
let io = peer.get_mut();
|
|
||||||
io.has_read_data.notified().await;
|
io.has_read_data.notified().await;
|
||||||
io.has_data_to_read.notify_one();
|
io.has_data_to_read.notify_one();
|
||||||
let waker = io.waker.lock().deref_mut().take();
|
|
||||||
if let Some(waker) = waker {
|
|
||||||
waker.wake();
|
|
||||||
}
|
}
|
||||||
}
|
None => {
|
||||||
Entry::Vacant(peer) => {
|
let (udp_client, io) = UdpStream::new(
|
||||||
let has_data_to_read: &'static Notify = Box::leak(Box::new(Notify::new()));
|
server.clone_socket(),
|
||||||
let pending_notification = has_data_to_read.notified();
|
peer_addr,
|
||||||
let has_read_data = Notify::new();
|
server
|
||||||
has_data_to_read.notify_one();
|
|
||||||
let io = Arc::new(IoInner {
|
|
||||||
has_data_to_read,
|
|
||||||
waker: Mutex::new(None),
|
|
||||||
has_read_data,
|
|
||||||
});
|
|
||||||
peer.insert(io.clone());
|
|
||||||
let udp_client = UdpStream {
|
|
||||||
socket: server.clone_socket(),
|
|
||||||
peer: peer_addr,
|
|
||||||
deadline: server
|
|
||||||
.cnx_timeout
|
.cnx_timeout
|
||||||
.and_then(|timeout| tokio::time::Instant::now().checked_add(timeout))
|
.and_then(|timeout| tokio::time::Instant::now().checked_add(timeout))
|
||||||
.map(tokio::time::sleep_until),
|
.map(tokio::time::sleep_until),
|
||||||
keys_to_delete: Arc::downgrade(&server.keys_to_delete),
|
Arc::downgrade(&server.keys_to_delete),
|
||||||
has_been_notified: false,
|
);
|
||||||
pending_notification: Some(pending_notification),
|
server.peers.insert(peer_addr, io);
|
||||||
io,
|
|
||||||
};
|
|
||||||
return Some((Ok(udp_client), (server)));
|
return Some((Ok(udp_client), (server)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -231,11 +224,7 @@ impl MyUdpSocket {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsyncRead for MyUdpSocket {
|
impl AsyncRead for MyUdpSocket {
|
||||||
fn poll_read(
|
fn poll_read(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
buf: &mut ReadBuf<'_>,
|
|
||||||
) -> Poll<io::Result<()>> {
|
|
||||||
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }
|
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }
|
||||||
.poll_recv_from(cx, buf)
|
.poll_recv_from(cx, buf)
|
||||||
.map(|x| x.map(|_| ()))
|
.map(|x| x.map(|_| ()))
|
||||||
|
@ -243,25 +232,15 @@ impl AsyncRead for MyUdpSocket {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsyncWrite for MyUdpSocket {
|
impl AsyncWrite for MyUdpSocket {
|
||||||
fn poll_write(
|
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut std::task::Context<'_>,
|
|
||||||
buf: &[u8],
|
|
||||||
) -> Poll<Result<usize, Error>> {
|
|
||||||
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }.poll_send(cx, buf)
|
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }.poll_send(cx, buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_flush(
|
fn poll_flush(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
|
||||||
self: Pin<&mut Self>,
|
|
||||||
_cx: &mut std::task::Context<'_>,
|
|
||||||
) -> Poll<Result<(), Error>> {
|
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_shutdown(
|
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
|
||||||
self: Pin<&mut Self>,
|
|
||||||
_cx: &mut std::task::Context<'_>,
|
|
||||||
) -> Poll<Result<(), Error>> {
|
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -331,10 +310,7 @@ mod tests {
|
||||||
assert!(client.send_to(b"aaaaa".as_ref(), server_addr).await.is_ok());
|
assert!(client.send_to(b"aaaaa".as_ref(), server_addr).await.is_ok());
|
||||||
|
|
||||||
let client2 = UdpSocket::bind("[::1]:0").await.unwrap();
|
let client2 = UdpSocket::bind("[::1]:0").await.unwrap();
|
||||||
assert!(client2
|
assert!(client2.send_to(b"bbbbb".as_ref(), server_addr).await.is_ok());
|
||||||
.send_to(b"bbbbb".as_ref(), server_addr)
|
|
||||||
.await
|
|
||||||
.is_ok());
|
|
||||||
|
|
||||||
// Should have a new connection
|
// Should have a new connection
|
||||||
let fut = timeout(Duration::from_millis(100), server.next()).await;
|
let fut = timeout(Duration::from_millis(100), server.next()).await;
|
||||||
|
@ -360,10 +336,7 @@ mod tests {
|
||||||
assert_eq!(&buf[..6], b"bbbbb\0");
|
assert_eq!(&buf[..6], b"bbbbb\0");
|
||||||
|
|
||||||
assert!(client.send_to(b"ccccc".as_ref(), server_addr).await.is_ok());
|
assert!(client.send_to(b"ccccc".as_ref(), server_addr).await.is_ok());
|
||||||
assert!(client2
|
assert!(client2.send_to(b"ddddd".as_ref(), server_addr).await.is_ok());
|
||||||
.send_to(b"ddddd".as_ref(), server_addr)
|
|
||||||
.await
|
|
||||||
.is_ok());
|
|
||||||
|
|
||||||
// Server need to be polled to feed the stream with need data
|
// Server need to be polled to feed the stream with need data
|
||||||
let _ = timeout(Duration::from_millis(100), server.next()).await;
|
let _ = timeout(Duration::from_millis(100), server.next()).await;
|
||||||
|
|
Loading…
Reference in a new issue