This commit is contained in:
Σrebe - Romain GERARD 2023-10-30 08:13:38 +01:00
parent b478288848
commit 466cb425bc
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
11 changed files with 159 additions and 320 deletions

3
rustfmt.toml Normal file
View file

@ -0,0 +1,3 @@
edition = "2021"
max_width = 120
fn_call_width = 80

View file

@ -3,14 +3,13 @@ use tokio_rustls::rustls::{Certificate, PrivateKey};
pub static TLS_PRIVATE_KEY: Lazy<PrivateKey> = Lazy::new(|| { pub static TLS_PRIVATE_KEY: Lazy<PrivateKey> = Lazy::new(|| {
let key = include_bytes!("../certs/key.pem"); let key = include_bytes!("../certs/key.pem");
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut key.as_slice()) let mut keys =
.expect("failed to load embedded tls private key"); rustls_pemfile::pkcs8_private_keys(&mut key.as_slice()).expect("failed to load embedded tls private key");
PrivateKey(keys.remove(0)) PrivateKey(keys.remove(0))
}); });
pub static TLS_CERTIFICATE: Lazy<Vec<Certificate>> = Lazy::new(|| { pub static TLS_CERTIFICATE: Lazy<Vec<Certificate>> = Lazy::new(|| {
let cert = include_bytes!("../certs/cert.pem"); let cert = include_bytes!("../certs/cert.pem");
let certs = rustls_pemfile::certs(&mut cert.as_slice()) let certs = rustls_pemfile::certs(&mut cert.as_slice()).expect("failed to load embedded tls certificate");
.expect("failed to load embedded tls certificate");
certs.into_iter().map(Certificate).collect() certs.into_iter().map(Certificate).collect()
}); });

View file

@ -67,13 +67,7 @@ struct Client {
/// This option set the maximum number of connection that will be kept open. /// This option set the maximum number of connection that will be kept open.
/// This is useful if you plan to create/destroy a lot of tunnel (i.e: with socks5 to navigate with a browser) /// This is useful if you plan to create/destroy a lot of tunnel (i.e: with socks5 to navigate with a browser)
/// It will avoid the latency of doing tcp + tls handshake with the server /// It will avoid the latency of doing tcp + tls handshake with the server
#[arg( #[arg(short = 'c', long, value_name = "INT", default_value = "0", verbatim_doc_comment)]
short = 'c',
long,
value_name = "INT",
default_value = "0",
verbatim_doc_comment
)]
connection_min_idle: u32, connection_min_idle: u32,
/// Domain name that will be use as SNI during TLS handshake /// Domain name that will be use as SNI during TLS handshake
@ -88,12 +82,7 @@ struct Client {
tls_verify_certificate: bool, tls_verify_certificate: bool,
/// If set, will use this http proxy to connect to the server /// If set, will use this http proxy to connect to the server
#[arg( #[arg(short = 'p', long, value_name = "http://USER:PASS@HOST:PORT", verbatim_doc_comment)]
short = 'p',
long,
value_name = "http://USER:PASS@HOST:PORT",
verbatim_doc_comment
)]
http_proxy: Option<Url>, http_proxy: Option<Url>,
/// Use a specific prefix that will show up in the http path during the upgrade request. /// Use a specific prefix that will show up in the http path during the upgrade request.
@ -241,9 +230,7 @@ fn parse_local_bind(arg: &str) -> Result<(SocketAddr, &str), io::Error> {
} }
#[allow(clippy::type_complexity)] #[allow(clippy::type_complexity)]
fn parse_tunnel_dest( fn parse_tunnel_dest(remaining: &str) -> Result<(Host<String>, u16, BTreeMap<String, String>), io::Error> {
remaining: &str,
) -> Result<(Host<String>, u16, BTreeMap<String, String>), io::Error> {
use std::io::Error; use std::io::Error;
let Ok(remote) = Url::parse(&format!("fake://{}", remaining)) else { let Ok(remote) = Url::parse(&format!("fake://{}", remaining)) else {
@ -290,13 +277,7 @@ fn parse_tunnel_arg(arg: &str) -> Result<LocalToRemote, io::Error> {
let timeout = options let timeout = options
.get("timeout_sec") .get("timeout_sec")
.and_then(|x| x.parse::<u64>().ok()) .and_then(|x| x.parse::<u64>().ok())
.map(|d| { .map(|d| if d == 0 { None } else { Some(Duration::from_secs(d)) })
if d == 0 {
None
} else {
Some(Duration::from_secs(d))
}
})
.unwrap_or(Some(Duration::from_secs(30))); .unwrap_or(Some(Duration::from_secs(30)));
Ok(LocalToRemote { Ok(LocalToRemote {
@ -355,10 +336,7 @@ fn parse_http_headers(arg: &str) -> Result<(HeaderName, HeaderValue), io::Error>
Err(err) => { Err(err) => {
return Err(io::Error::new( return Err(io::Error::new(
ErrorKind::InvalidInput, ErrorKind::InvalidInput,
format!( format!("cannot parse http header value from {} due to {:?}", value, err),
"cannot parse http header value from {} due to {:?}",
value, err
),
)) ))
} }
}; };
@ -394,10 +372,7 @@ fn parse_server_url(arg: &str) -> Result<Url, io::Error> {
} }
if url.host().is_none() { if url.host().is_none() {
return Err(io::Error::new( return Err(io::Error::new(ErrorKind::InvalidInput, format!("invalid server host {}", arg)));
ErrorKind::InvalidInput,
format!("invalid server host {}", arg),
));
} }
Ok(url) Ok(url)
@ -474,15 +449,9 @@ impl WsClientConfig {
} }
pub fn tls_server_name(&self) -> ServerName { pub fn tls_server_name(&self) -> ServerName {
match self match self.tls.as_ref().and_then(|tls| tls.tls_sni_override.as_ref()) {
.tls
.as_ref()
.and_then(|tls| tls.tls_sni_override.as_ref())
{
None => match &self.remote_addr.0 { None => match &self.remote_addr.0 {
Host::Domain(domain) => { Host::Domain(domain) => ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap()),
ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap())
}
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip)), Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip)),
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip)), Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip)),
}, },
@ -529,12 +498,11 @@ async fn main() {
}; };
// Extract host header from http_headers // Extract host header from http_headers
let host_header = let host_header = if let Some((_, host_val)) = args.http_headers.iter().find(|(h, _)| *h == HOST) {
if let Some((_, host_val)) = args.http_headers.iter().find(|(h, _)| *h == HOST) { host_val.clone()
host_val.clone() } else {
} else { HeaderValue::from_str(&args.remote_addr.host().unwrap().to_string()).unwrap()
HeaderValue::from_str(&args.remote_addr.host().unwrap().to_string()).unwrap() };
};
let mut client_config = WsClientConfig { let mut client_config = WsClientConfig {
remote_addr: ( remote_addr: (
args.remote_addr.host().unwrap().to_owned(), args.remote_addr.host().unwrap().to_owned(),
@ -544,16 +512,10 @@ async fn main() {
tls, tls,
http_upgrade_path_prefix: args.http_upgrade_path_prefix, http_upgrade_path_prefix: args.http_upgrade_path_prefix,
http_upgrade_credentials: args.http_upgrade_credentials, http_upgrade_credentials: args.http_upgrade_credentials,
http_headers: args http_headers: args.http_headers.into_iter().filter(|(k, _)| k != HOST).collect(),
.http_headers
.into_iter()
.filter(|(k, _)| k != HOST)
.collect(),
http_header_host: host_header, http_header_host: host_header,
timeout_connect: Duration::from_secs(10), timeout_connect: Duration::from_secs(10),
websocket_ping_frequency: args websocket_ping_frequency: args.websocket_ping_frequency_sec.unwrap_or(Duration::from_secs(30)),
.websocket_ping_frequency_sec
.unwrap_or(Duration::from_secs(30)),
websocket_mask_frame: args.websocket_mask_frame, websocket_mask_frame: args.websocket_mask_frame,
http_proxy: args.http_proxy, http_proxy: args.http_proxy,
cnx_pool: None, cnx_pool: None,
@ -579,16 +541,12 @@ async fn main() {
let remote = tunnel.remote.clone(); let remote = tunnel.remote.clone();
let server = tcp::run_server(tunnel.local) let server = tcp::run_server(tunnel.local)
.await .await
.unwrap_or_else(|err| { .unwrap_or_else(|err| panic!("Cannot start TCP server on {}: {}", tunnel.local, err))
panic!("Cannot start TCP server on {}: {}", tunnel.local, err)
})
.map_err(anyhow::Error::new) .map_err(anyhow::Error::new)
.map_ok(move |stream| (stream.into_split(), remote.clone())); .map_ok(move |stream| (stream.into_split(), remote.clone()));
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await {
tunnel::client::run_tunnel(client_config, tunnel, server).await
{
error!("{:?}", err); error!("{:?}", err);
} }
}); });
@ -597,16 +555,12 @@ async fn main() {
let remote = tunnel.remote.clone(); let remote = tunnel.remote.clone();
let server = udp::run_server(tunnel.local, *timeout) let server = udp::run_server(tunnel.local, *timeout)
.await .await
.unwrap_or_else(|err| { .unwrap_or_else(|err| panic!("Cannot start UDP server on {}: {}", tunnel.local, err))
panic!("Cannot start UDP server on {}: {}", tunnel.local, err)
})
.map_err(anyhow::Error::new) .map_err(anyhow::Error::new)
.map_ok(move |stream| (tokio::io::split(stream), remote.clone())); .map_ok(move |stream| (tokio::io::split(stream), remote.clone()));
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await {
tunnel::client::run_tunnel(client_config, tunnel, server).await
{
error!("{:?}", err); error!("{:?}", err);
} }
}); });
@ -614,15 +568,11 @@ async fn main() {
LocalProtocol::Socks5 => { LocalProtocol::Socks5 => {
let server = socks5::run_server(tunnel.local) let server = socks5::run_server(tunnel.local)
.await .await
.unwrap_or_else(|err| { .unwrap_or_else(|err| panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err))
panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err)
})
.map_ok(|(stream, remote_dest)| (stream.into_split(), remote_dest)); .map_ok(|(stream, remote_dest)| (stream.into_split(), remote_dest));
tokio::spawn(async move { tokio::spawn(async move {
if let Err(err) = if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await {
tunnel::client::run_tunnel(client_config, tunnel, server).await
{
error!("{:?}", err); error!("{:?}", err);
} }
}); });
@ -656,8 +606,7 @@ async fn main() {
Commands::Server(args) => { Commands::Server(args) => {
let tls_config = if args.remote_addr.scheme() == "wss" { let tls_config = if args.remote_addr.scheme() == "wss" {
let tls_certificate = if let Some(cert_path) = args.tls_certificate { let tls_certificate = if let Some(cert_path) = args.tls_certificate {
tls::load_certificates_from_pem(&cert_path) tls::load_certificates_from_pem(&cert_path).expect("Cannot load tls certificate")
.expect("Cannot load tls certificate")
} else { } else {
embedded_certificate::TLS_CERTIFICATE.clone() embedded_certificate::TLS_CERTIFICATE.clone()
}; };

View file

@ -19,10 +19,7 @@ pub struct Socks5Listener {
impl Stream for Socks5Listener { impl Stream for Socks5Listener {
type Item = anyhow::Result<(TcpStream, (Host, u16))>; type Item = anyhow::Result<(TcpStream, (Host, u16))>;
fn poll_next( fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
unsafe { self.map_unchecked_mut(|x| &mut x.stream) }.poll_next(cx) unsafe { self.map_unchecked_mut(|x| &mut x.stream) }.poll_next(cx)
} }
} }

View file

@ -14,12 +14,9 @@ use tracing::log::info;
use url::{Host, Url}; use url::{Host, Url};
fn configure_socket(socket: &mut TcpSocket, so_mark: &Option<i32>) -> Result<(), anyhow::Error> { fn configure_socket(socket: &mut TcpSocket, so_mark: &Option<i32>) -> Result<(), anyhow::Error> {
socket.set_nodelay(true).with_context(|| { socket
format!( .set_nodelay(true)
"cannot set no_delay on socket: {}", .with_context(|| format!("cannot set no_delay on socket: {}", io::Error::last_os_error()))?;
io::Error::last_os_error()
)
})?;
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
if let Some(so_mark) = so_mark { if let Some(so_mark) = so_mark {
@ -35,10 +32,7 @@ fn configure_socket(socket: &mut TcpSocket, so_mark: &Option<i32>) -> Result<(),
); );
if ret != 0 { if ret != 0 {
return Err(anyhow!( return Err(anyhow!("Cannot set SO_MARK on the connection {:?}", io::Error::last_os_error()));
"Cannot set SO_MARK on the connection {:?}",
io::Error::last_os_error()
));
} }
} }
} }
@ -117,17 +111,14 @@ pub async fn connect_with_http_proxy(
let mut socket = connect(&proxy_host, proxy_port, so_mark, connect_timeout).await?; let mut socket = connect(&proxy_host, proxy_port, so_mark, connect_timeout).await?;
info!("Connected to http proxy {}:{}", proxy_host, proxy_port); info!("Connected to http proxy {}:{}", proxy_host, proxy_port);
let authorization = let authorization = if let Some((user, password)) = proxy.password().map(|p| (proxy.username(), p)) {
if let Some((user, password)) = proxy.password().map(|p| (proxy.username(), p)) { let creds = base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", user, password));
let creds = format!("Proxy-Authorization: Basic {}\r\n", creds)
base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", user, password)); } else {
format!("Proxy-Authorization: Basic {}\r\n", creds) "".to_string()
} else { };
"".to_string()
};
let connect_request = let connect_request = format!("CONNECT {host}:{port} HTTP/1.0\r\nHost: {host}:{port}\r\n{authorization}\r\n");
format!("CONNECT {host}:{port} HTTP/1.0\r\nHost: {host}:{port}\r\n{authorization}\r\n");
socket.write_all(connect_request.as_bytes()).await?; socket.write_all(connect_request.as_bytes()).await?;
let mut buf = BytesMut::with_capacity(1024); let mut buf = BytesMut::with_capacity(1024);
@ -136,16 +127,15 @@ pub async fn connect_with_http_proxy(
match nb_bytes { match nb_bytes {
Ok(Ok(0)) => { Ok(Ok(0)) => {
return Err(anyhow!( return Err(anyhow!(
"Cannot connect to http proxy. Proxy closed the connection without returning any response")); "Cannot connect to http proxy. Proxy closed the connection without returning any response"
));
} }
Ok(Ok(_)) => {} Ok(Ok(_)) => {}
Ok(Err(err)) => { Ok(Err(err)) => {
return Err(anyhow!("Cannot connect to http proxy. {err}")); return Err(anyhow!("Cannot connect to http proxy. {err}"));
} }
Err(_) => { Err(_) => {
return Err(anyhow!( return Err(anyhow!("Cannot connect to http proxy. Proxy took too long to connect"));
"Cannot connect to http proxy. Proxy took too long to connect"
));
} }
}; };
@ -225,8 +215,7 @@ mod tests {
let server = TcpListener::bind(server_addr).await.unwrap(); let server = TcpListener::bind(server_addr).await.unwrap();
let docker = testcontainers::clients::Cli::default(); let docker = testcontainers::clients::Cli::default();
let mitm_proxy: RunnableImage<MitmProxy> = let mitm_proxy: RunnableImage<MitmProxy> = RunnableImage::from(MitmProxy {}).with_network("host".to_string());
RunnableImage::from(MitmProxy {}).with_network("host".to_string());
let _node = docker.run(mitm_proxy); let _node = docker.run(mitm_proxy);
let mut client = connect_with_http_proxy( let mut client = connect_with_http_proxy(
@ -239,10 +228,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
client client.write_all(b"GET / HTTP/1.1\r\n\r\n".as_slice()).await.unwrap();
.write_all(b"GET / HTTP/1.1\r\n\r\n".as_slice())
.await
.unwrap();
let client_srv = server.accept().await.unwrap().0; let client_srv = server.accept().await.unwrap().0;
pin_mut!(client_srv); pin_mut!(client_srv);

View file

@ -45,21 +45,15 @@ pub fn load_private_key_from_file(path: &Path) -> anyhow::Result<PrivateKey> {
match keys.len() { match keys.len() {
0 => Err(anyhow!("No PKCS8-encoded private key found in {path:?}")), 0 => Err(anyhow!("No PKCS8-encoded private key found in {path:?}")),
1 => Ok(PrivateKey(keys.remove(0))), 1 => Ok(PrivateKey(keys.remove(0))),
_ => Err(anyhow!( _ => Err(anyhow!("More than one PKCS8-encoded private key found in {path:?}")),
"More than one PKCS8-encoded private key found in {path:?}"
)),
} }
} }
fn tls_connector( fn tls_connector(tls_cfg: &TlsClientConfig, alpn_protocols: Option<Vec<Vec<u8>>>) -> anyhow::Result<TlsConnector> {
tls_cfg: &TlsClientConfig,
alpn_protocols: Option<Vec<Vec<u8>>>,
) -> anyhow::Result<TlsConnector> {
let mut root_store = rustls::RootCertStore::empty(); let mut root_store = rustls::RootCertStore::empty();
// Load system certificates and add them to the root store // Load system certificates and add them to the root store
let certs = rustls_native_certs::load_native_certs() let certs = rustls_native_certs::load_native_certs().with_context(|| "Cannot load system certificates")?;
.with_context(|| "Cannot load system certificates")?;
for cert in certs { for cert in certs {
root_store.add(&Certificate(cert.0))?; root_store.add(&Certificate(cert.0))?;
} }
@ -71,9 +65,7 @@ fn tls_connector(
// To bypass certificate verification // To bypass certificate verification
if !tls_cfg.tls_verify_certificate { if !tls_cfg.tls_verify_certificate {
config config.dangerous().set_certificate_verifier(Arc::new(NullVerifier));
.dangerous()
.set_certificate_verifier(Arc::new(NullVerifier));
} }
if let Some(alpn_protocols) = alpn_protocols { if let Some(alpn_protocols) = alpn_protocols {
@ -83,10 +75,7 @@ fn tls_connector(
Ok(tls_connector) Ok(tls_connector)
} }
pub fn tls_acceptor( pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option<Vec<Vec<u8>>>) -> anyhow::Result<TlsAcceptor> {
tls_cfg: &TlsServerConfig,
alpn_protocols: Option<Vec<Vec<u8>>>,
) -> anyhow::Result<TlsAcceptor> {
let mut config = rustls::ServerConfig::builder() let mut config = rustls::ServerConfig::builder()
.with_safe_defaults() .with_safe_defaults()
.with_no_client_auth() .with_no_client_auth()
@ -114,12 +103,7 @@ pub async fn connect(
let tls_stream = tls_connector let tls_stream = tls_connector
.connect(sni, tcp_stream) .connect(sni, tcp_stream)
.await .await
.with_context(|| { .with_context(|| format!("failed to do TLS handshake with the server {:?}", client_cfg.remote_addr))?;
format!(
"failed to do TLS handshake with the server {:?}",
client_cfg.remote_addr
)
})?;
Ok(tls_stream) Ok(tls_stream)
} }

View file

@ -44,9 +44,7 @@ pub async fn connect(
) -> anyhow::Result<WebSocket<Upgraded>> { ) -> anyhow::Result<WebSocket<Upgraded>> {
let mut pooled_cnx = match client_cfg.cnx_pool().get().await { let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
Ok(tcp_stream) => tcp_stream, Ok(tcp_stream) => tcp_stream,
Err(err) => Err(anyhow!( Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}"))?,
"failed to get a connection to the server from the pool: {err:?}"
))?,
}; };
let mut req = Request::builder() let mut req = Request::builder()
@ -80,12 +78,7 @@ pub async fn connect(
let transport = pooled_cnx.deref_mut().take().unwrap(); let transport = pooled_cnx.deref_mut().take().unwrap();
let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, transport) let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, transport)
.await .await
.with_context(|| { .with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?;
format!(
"failed to do websocket handshake with the server {:?}",
client_cfg.remote_addr
)
})?;
Ok(ws) Ok(ws)
} }
@ -109,10 +102,7 @@ where
// Forward local tx to websocket tx // Forward local tx to websocket tx
let ping_frequency = client_cfg.websocket_ping_frequency; let ping_frequency = client_cfg.websocket_ping_frequency;
tokio::spawn( tokio::spawn(super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency).instrument(Span::current()));
super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency)
.instrument(Span::current()),
);
// Forward websocket rx to local rx // Forward websocket rx to local rx
let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await; let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await;

View file

@ -1,12 +1,12 @@
use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite}; use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite};
use futures_util::pin_mut; use futures_util::pin_mut;
use hyper::upgrade::Upgraded; use hyper::upgrade::Upgraded;
use std::pin::Pin;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::select; use tokio::select;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tokio::time::timeout;
use tracing::log::debug; use tracing::log::debug;
use tracing::{error, info, trace, warn}; use tracing::{error, info, trace, warn};
@ -20,7 +20,14 @@ pub(super) async fn propagate_read(
info!("Closing local tx ==> websocket tx tunnel"); info!("Closing local tx ==> websocket tx tunnel");
}); });
let mut buffer = vec![0u8; 8 * 1024]; static JUMBO_FRAME_SIZE: usize = 9 * 1024; // enough for a jumbo frame
let mut buffer = vec![0u8; JUMBO_FRAME_SIZE];
// We do our own pin_mut! to avoid shadowing timeout and be able to reset it, on next loop iteration
// We reuse the future to avoid creating a timer in the tight loop
let mut timeout_unpin = tokio::time::sleep(ping_frequency);
let mut timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) };
pin_mut!(local_rx); pin_mut!(local_rx);
loop { loop {
let read_len = select! { let read_len = select! {
@ -30,9 +37,12 @@ pub(super) async fn propagate_read(
_ = close_tx.closed() => break, _ = close_tx.closed() => break,
_ = timeout(ping_frequency, futures_util::future::pending::<()>()) => { _ = &mut timeout => {
debug!("sending ping to keep websocket connection alive"); debug!("sending ping to keep websocket connection alive");
ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?; ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?;
timeout_unpin = tokio::time::sleep(ping_frequency);
timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) };
continue; continue;
} }
}; };
@ -41,32 +51,30 @@ pub(super) async fn propagate_read(
Ok(0) => break, Ok(0) => break,
Ok(read_len) => read_len, Ok(read_len) => read_len,
Err(err) => { Err(err) => {
warn!( warn!("error while reading incoming bytes from local tx tunnel {}", err);
"error while reading incoming bytes from local tx tunnel {}",
err
);
break; break;
} }
}; };
trace!("read {} bytes", read_len); trace!("read {} bytes", read_len);
match ws_tx if let Err(err) = ws_tx
.write_frame(Frame::binary(Payload::BorrowedMut(&mut buffer[..read_len]))) .write_frame(Frame::binary(Payload::BorrowedMut(&mut buffer[..read_len])))
.await .await
{ {
Ok(_) => {} warn!("error while writing to websocket tx tunnel {}", err);
Err(err) => { break;
warn!("error while writing to websocket tx tunnel {}", err);
break;
}
} }
// If the buffer has been completely filled with previous read, Double it !
// For the buffer to not be a bottleneck when the TCP window scale
// For udp, the buffer will never grows.
if buffer.capacity() == read_len { if buffer.capacity() == read_len {
buffer.clear(); buffer.clear();
buffer.resize(buffer.capacity() * 2, 0); buffer.resize(buffer.capacity() * 2, 0);
} }
} }
// Send normal close
let _ = ws_tx.write_frame(Frame::close(1000, &[])).await; let _ = ws_tx.write_frame(Frame::close(1000, &[])).await;
Ok(()) Ok(())
@ -104,20 +112,15 @@ pub(super) async fn propagate_write(
trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload); trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload);
let ret = match msg.opcode { let ret = match msg.opcode {
OpCode::Continuation | OpCode::Text | OpCode::Binary => { OpCode::Continuation | OpCode::Text | OpCode::Binary => local_tx.write_all(msg.payload.as_ref()).await,
local_tx.write_all(msg.payload.as_ref()).await
}
OpCode::Close => break, OpCode::Close => break,
OpCode::Ping => Ok(()), OpCode::Ping => Ok(()),
OpCode::Pong => Ok(()), OpCode::Pong => Ok(()),
}; };
match ret { if let Err(err) = ret {
Ok(_) => {} error!("error while writing bytes to local for rx tunnel {}", err);
Err(err) => { break;
error!("error while writing bytes to local for rx tunnel {}", err);
break;
}
} }
} }

View file

@ -42,12 +42,8 @@ impl JwtTunnelConfig {
} }
static JWT_SECRET: &[u8; 15] = b"champignonfrais"; static JWT_SECRET: &[u8; 15] = b"champignonfrais";
static JWT_KEY: Lazy<(Header, EncodingKey)> = Lazy::new(|| { static JWT_KEY: Lazy<(Header, EncodingKey)> =
( Lazy::new(|| (Header::new(Algorithm::HS256), EncodingKey::from_secret(JWT_SECRET)));
Header::new(Algorithm::HS256),
EncodingKey::from_secret(JWT_SECRET),
)
});
static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| { static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| {
let mut validation = Validation::new(Algorithm::HS256); let mut validation = Validation::new(Algorithm::HS256);
@ -61,11 +57,7 @@ pub enum TransportStream {
} }
impl AsyncRead for TransportStream { impl AsyncRead for TransportStream {
fn poll_read( fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.get_mut() { match self.get_mut() {
TransportStream::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf), TransportStream::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf),
TransportStream::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf), TransportStream::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf),
@ -74,11 +66,7 @@ impl AsyncRead for TransportStream {
} }
impl AsyncWrite for TransportStream { impl AsyncWrite for TransportStream {
fn poll_write( fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
match self.get_mut() { match self.get_mut() {
TransportStream::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf), TransportStream::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf),
TransportStream::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf), TransportStream::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf),

View file

@ -46,15 +46,8 @@ async fn from_query(
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
if let Some(allowed_dests) = &server_config.restrict_to { if let Some(allowed_dests) = &server_config.restrict_to {
let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp); let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp);
if allowed_dests if allowed_dests.iter().any(|dest| dest == &requested_dest).not() {
.iter() warn!("Rejecting connection with not allowed destination: {}", requested_dest);
.any(|dest| dest == &requested_dest)
.not()
{
warn!(
"Rejecting connection with not allowed destination: {}",
requested_dest
);
return Err(anyhow::anyhow!("Invalid upgrade request")); return Err(anyhow::anyhow!("Invalid upgrade request"));
} }
} }
@ -75,14 +68,9 @@ async fn from_query(
LocalProtocol::Tcp { .. } => { LocalProtocol::Tcp { .. } => {
let host = Host::parse(&jwt.claims.r)?; let host = Host::parse(&jwt.claims.r)?;
let port = jwt.claims.rp; let port = jwt.claims.rp;
let (rx, tx) = tcp::connect( let (rx, tx) = tcp::connect(&host, port, &server_config.socket_so_mark, Duration::from_secs(10))
&host, .await?
port, .into_split();
&server_config.socket_so_mark,
Duration::from_secs(10),
)
.await?
.into_split();
Ok((jwt.claims.p, host, port, Box::pin(rx), Box::pin(tx))) Ok((jwt.claims.p, host, port, Box::pin(rx), Box::pin(tx)))
} }
@ -100,10 +88,7 @@ async fn server_upgrade(
} }
if !req.uri().path().ends_with("/events") { if !req.uri().path().ends_with("/events") {
warn!( warn!("Rejecting connection with bad upgrade request: {}", req.uri());
"Rejecting connection with bad upgrade request: {}",
req.uri()
);
return Ok(http::Response::builder() return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST) .status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid upgrade request")) .body(Body::from("Invalid upgrade request"))
@ -118,10 +103,7 @@ async fn server_upgrade(
|| &path[min_len..max_len] != path_prefix.as_str() || &path[min_len..max_len] != path_prefix.as_str()
|| !path[max_len..].starts_with('/') || !path[max_len..].starts_with('/')
{ {
warn!( warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri());
"Rejecting connection with bad path prefix in upgrade request: {}",
req.uri()
);
return Ok(http::Response::builder() return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST) .status(StatusCode::BAD_REQUEST)
.body(Body::from("Invalid upgrade request")) .body(Body::from("Invalid upgrade request"))
@ -133,11 +115,7 @@ async fn server_upgrade(
match from_query(&server_config, req.uri().query().unwrap_or_default()).await { match from_query(&server_config, req.uri().query().unwrap_or_default()).await {
Ok(ret) => ret, Ok(ret) => ret,
Err(err) => { Err(err) => {
warn!( warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
"Rejecting connection with bad upgrade request: {} {}",
err,
req.uri()
);
return Ok(http::Response::builder() return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST) .status(StatusCode::BAD_REQUEST)
.body(Body::from(format!("Invalid upgrade request: {:?}", err))) .body(Body::from(format!("Invalid upgrade request: {:?}", err)))
@ -149,11 +127,7 @@ async fn server_upgrade(
let (response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) { let (response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) {
Ok(ret) => ret, Ok(ret) => ret,
Err(err) => { Err(err) => {
warn!( warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
"Rejecting connection with bad upgrade request: {} {}",
err,
req.uri()
);
return Ok(http::Response::builder() return Ok(http::Response::builder()
.status(StatusCode::BAD_REQUEST) .status(StatusCode::BAD_REQUEST)
.body(Body::from(format!("Invalid upgrade request: {:?}", err))) .body(Body::from(format!("Invalid upgrade request: {:?}", err)))
@ -171,14 +145,10 @@ async fn server_upgrade(
} }
}; };
let (close_tx, close_rx) = oneshot::channel::<()>(); let (close_tx, close_rx) = oneshot::channel::<()>();
let ping_frequency = server_config let ping_frequency = server_config.websocket_ping_frequency.unwrap_or(Duration::MAX);
.websocket_ping_frequency
.unwrap_or(Duration::MAX);
ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame); ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame);
tokio::task::spawn( tokio::task::spawn(super::io::propagate_write(local_tx, ws_rx, close_rx).instrument(Span::current()));
super::io::propagate_write(local_tx, ws_rx, close_rx).instrument(Span::current()),
);
let _ = super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency).await; let _ = super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency).await;
} }
@ -189,10 +159,7 @@ async fn server_upgrade(
} }
pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> { pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> {
info!( info!("Starting wstunnel server listening on {}", server_config.bind);
"Starting wstunnel server listening on {}",
server_config.bind
);
let config = server_config.clone(); let config = server_config.clone();
let upgrade_fn = move |req: Request<Body>| server_upgrade(config.clone(), req); let upgrade_fn = move |req: Request<Body>| server_upgrade(config.clone(), req);

View file

@ -1,18 +1,17 @@
use anyhow::Context; use anyhow::Context;
use futures_util::{stream, Stream}; use futures_util::{stream, Stream};
use parking_lot::{Mutex, RwLock}; use parking_lot::RwLock;
use pin_project::{pin_project, pinned_drop}; use pin_project::{pin_project, pinned_drop};
use std::collections::hash_map::Entry;
use std::collections::HashMap; use std::collections::HashMap;
use std::future::Future; use std::future::Future;
use std::io; use std::io;
use std::io::{Error, ErrorKind}; use std::io::{Error, ErrorKind};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::ops::DerefMut;
use std::pin::{pin, Pin}; use std::pin::{pin, Pin};
use std::sync::{Arc, Weak}; use std::sync::{Arc, Weak};
use std::task::{ready, Poll, Waker}; use std::task::{ready, Poll};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
@ -23,8 +22,7 @@ use tokio::time::Sleep;
use tracing::{debug, error, info}; use tracing::{debug, error, info};
struct IoInner { struct IoInner {
has_data_to_read: &'static Notify, has_data_to_read: Notify,
waker: Mutex<Option<Waker>>,
has_read_data: Notify, has_read_data: Notify,
} }
struct UdpServer { struct UdpServer {
@ -43,6 +41,7 @@ impl UdpServer {
cnx_timeout: timeout, cnx_timeout: timeout,
} }
} }
#[inline]
fn clean_dead_keys(&mut self) { fn clean_dead_keys(&mut self) {
let nb_key_to_delete = self.keys_to_delete.read().len(); let nb_key_to_delete = self.keys_to_delete.read().len();
if nb_key_to_delete == 0 { if nb_key_to_delete == 0 {
@ -52,16 +51,7 @@ impl UdpServer {
debug!("Cleaning {} dead udp peers", nb_key_to_delete); debug!("Cleaning {} dead udp peers", nb_key_to_delete);
let mut keys_to_delete = self.keys_to_delete.write(); let mut keys_to_delete = self.keys_to_delete.write();
for key in keys_to_delete.iter() { for key in keys_to_delete.iter() {
let Some(peer) = self.peers.remove(key) else { self.peers.remove(key);
continue;
};
#[allow(mutable_transmutes)]
unsafe {
let _ = Box::from_raw(std::mem::transmute::<&Notify, &mut Notify>(
peer.has_data_to_read,
));
}
} }
keys_to_delete.clear(); keys_to_delete.clear();
} }
@ -90,7 +80,42 @@ impl PinnedDrop for UdpStream {
keys_to_delete.write().push(self.peer); keys_to_delete.write().push(self.peer);
} }
self.io.has_read_data.notify_one(); // safety: we are dropping the notification as we extend its lifetime to 'static unsafely
// So it must be gone before we drop its parent. It should never happen but in case
let mut project = self.project();
project.pending_notification.as_mut().set(None);
project.io.has_read_data.notify_one();
}
}
impl UdpStream {
fn new(
socket: Arc<UdpSocket>,
peer: SocketAddr,
deadline: Option<Sleep>,
keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>,
) -> (Self, Arc<IoInner>) {
let has_data_to_read = Notify::new();
let has_read_data = Notify::new();
has_data_to_read.notify_one();
let io = Arc::new(IoInner {
has_data_to_read,
has_read_data,
});
let mut s = Self {
socket,
peer,
deadline,
has_been_notified: false,
pending_notification: None,
io: io.clone(),
keys_to_delete,
};
let pending_notification = unsafe { std::mem::transmute(s.io.has_data_to_read.notified()) };
s.pending_notification = Some(pending_notification);
(s, io)
} }
} }
@ -111,43 +136,28 @@ impl AsyncRead for UdpStream {
} }
if let Some(notified) = project.pending_notification.as_mut().as_pin_mut() { if let Some(notified) = project.pending_notification.as_mut().as_pin_mut() {
if !notified.poll(cx).is_ready() { ready!(notified.poll(cx));
project.io.waker.lock().replace(cx.waker().clone());
return Poll::Pending;
}
project.pending_notification.as_mut().set(None); project.pending_notification.as_mut().set(None);
} }
let _ = ready!(project.socket.poll_recv(cx, obuf)); let _ = ready!(project.socket.poll_recv(cx, obuf));
project let notified: Notified<'static> = unsafe { std::mem::transmute(project.io.has_data_to_read.notified()) };
.pending_notification project.pending_notification.as_mut().set(Some(notified));
.as_mut()
.set(Some(project.io.has_data_to_read.notified()));
project.io.has_read_data.notify_one(); project.io.has_read_data.notify_one();
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
} }
impl AsyncWrite for UdpStream { impl AsyncWrite for UdpStream {
fn poll_write( fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
self.socket.poll_send_to(cx, buf, self.peer) self.socket.poll_send_to(cx, buf, self.peer)
} }
fn poll_flush( fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Error>> {
self.socket.poll_send_ready(cx) self.socket.poll_send_ready(cx)
} }
fn poll_shutdown( fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
} }
@ -178,39 +188,22 @@ pub async fn run_server(
} }
}; };
match server.peers.entry(peer_addr) { match server.peers.get(&peer_addr) {
Entry::Occupied(mut peer) => { Some(io) => {
let io = peer.get_mut();
io.has_read_data.notified().await; io.has_read_data.notified().await;
io.has_data_to_read.notify_one(); io.has_data_to_read.notify_one();
let waker = io.waker.lock().deref_mut().take();
if let Some(waker) = waker {
waker.wake();
}
} }
Entry::Vacant(peer) => { None => {
let has_data_to_read: &'static Notify = Box::leak(Box::new(Notify::new())); let (udp_client, io) = UdpStream::new(
let pending_notification = has_data_to_read.notified(); server.clone_socket(),
let has_read_data = Notify::new(); peer_addr,
has_data_to_read.notify_one(); server
let io = Arc::new(IoInner {
has_data_to_read,
waker: Mutex::new(None),
has_read_data,
});
peer.insert(io.clone());
let udp_client = UdpStream {
socket: server.clone_socket(),
peer: peer_addr,
deadline: server
.cnx_timeout .cnx_timeout
.and_then(|timeout| tokio::time::Instant::now().checked_add(timeout)) .and_then(|timeout| tokio::time::Instant::now().checked_add(timeout))
.map(tokio::time::sleep_until), .map(tokio::time::sleep_until),
keys_to_delete: Arc::downgrade(&server.keys_to_delete), Arc::downgrade(&server.keys_to_delete),
has_been_notified: false, );
pending_notification: Some(pending_notification), server.peers.insert(peer_addr, io);
io,
};
return Some((Ok(udp_client), (server))); return Some((Ok(udp_client), (server)));
} }
} }
@ -231,11 +224,7 @@ impl MyUdpSocket {
} }
impl AsyncRead for MyUdpSocket { impl AsyncRead for MyUdpSocket {
fn poll_read( fn poll_read(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
unsafe { self.map_unchecked_mut(|x| &mut x.socket) } unsafe { self.map_unchecked_mut(|x| &mut x.socket) }
.poll_recv_from(cx, buf) .poll_recv_from(cx, buf)
.map(|x| x.map(|_| ())) .map(|x| x.map(|_| ()))
@ -243,25 +232,15 @@ impl AsyncRead for MyUdpSocket {
} }
impl AsyncWrite for MyUdpSocket { impl AsyncWrite for MyUdpSocket {
fn poll_write( fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }.poll_send(cx, buf) unsafe { self.map_unchecked_mut(|x| &mut x.socket) }.poll_send(cx, buf)
} }
fn poll_flush( fn poll_flush(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn poll_shutdown( fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
} }
@ -331,10 +310,7 @@ mod tests {
assert!(client.send_to(b"aaaaa".as_ref(), server_addr).await.is_ok()); assert!(client.send_to(b"aaaaa".as_ref(), server_addr).await.is_ok());
let client2 = UdpSocket::bind("[::1]:0").await.unwrap(); let client2 = UdpSocket::bind("[::1]:0").await.unwrap();
assert!(client2 assert!(client2.send_to(b"bbbbb".as_ref(), server_addr).await.is_ok());
.send_to(b"bbbbb".as_ref(), server_addr)
.await
.is_ok());
// Should have a new connection // Should have a new connection
let fut = timeout(Duration::from_millis(100), server.next()).await; let fut = timeout(Duration::from_millis(100), server.next()).await;
@ -360,10 +336,7 @@ mod tests {
assert_eq!(&buf[..6], b"bbbbb\0"); assert_eq!(&buf[..6], b"bbbbb\0");
assert!(client.send_to(b"ccccc".as_ref(), server_addr).await.is_ok()); assert!(client.send_to(b"ccccc".as_ref(), server_addr).await.is_ok());
assert!(client2 assert!(client2.send_to(b"ddddd".as_ref(), server_addr).await.is_ok());
.send_to(b"ddddd".as_ref(), server_addr)
.await
.is_ok());
// Server need to be polled to feed the stream with need data // Server need to be polled to feed the stream with need data
let _ = timeout(Duration::from_millis(100), server.next()).await; let _ = timeout(Duration::from_millis(100), server.next()).await;