Cleanup
This commit is contained in:
parent
b478288848
commit
466cb425bc
11 changed files with 159 additions and 320 deletions
3
rustfmt.toml
Normal file
3
rustfmt.toml
Normal file
|
@ -0,0 +1,3 @@
|
|||
edition = "2021"
|
||||
max_width = 120
|
||||
fn_call_width = 80
|
|
@ -3,14 +3,13 @@ use tokio_rustls::rustls::{Certificate, PrivateKey};
|
|||
|
||||
pub static TLS_PRIVATE_KEY: Lazy<PrivateKey> = Lazy::new(|| {
|
||||
let key = include_bytes!("../certs/key.pem");
|
||||
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut key.as_slice())
|
||||
.expect("failed to load embedded tls private key");
|
||||
let mut keys =
|
||||
rustls_pemfile::pkcs8_private_keys(&mut key.as_slice()).expect("failed to load embedded tls private key");
|
||||
PrivateKey(keys.remove(0))
|
||||
});
|
||||
pub static TLS_CERTIFICATE: Lazy<Vec<Certificate>> = Lazy::new(|| {
|
||||
let cert = include_bytes!("../certs/cert.pem");
|
||||
let certs = rustls_pemfile::certs(&mut cert.as_slice())
|
||||
.expect("failed to load embedded tls certificate");
|
||||
let certs = rustls_pemfile::certs(&mut cert.as_slice()).expect("failed to load embedded tls certificate");
|
||||
|
||||
certs.into_iter().map(Certificate).collect()
|
||||
});
|
||||
|
|
95
src/main.rs
95
src/main.rs
|
@ -67,13 +67,7 @@ struct Client {
|
|||
/// This option set the maximum number of connection that will be kept open.
|
||||
/// This is useful if you plan to create/destroy a lot of tunnel (i.e: with socks5 to navigate with a browser)
|
||||
/// It will avoid the latency of doing tcp + tls handshake with the server
|
||||
#[arg(
|
||||
short = 'c',
|
||||
long,
|
||||
value_name = "INT",
|
||||
default_value = "0",
|
||||
verbatim_doc_comment
|
||||
)]
|
||||
#[arg(short = 'c', long, value_name = "INT", default_value = "0", verbatim_doc_comment)]
|
||||
connection_min_idle: u32,
|
||||
|
||||
/// Domain name that will be use as SNI during TLS handshake
|
||||
|
@ -88,12 +82,7 @@ struct Client {
|
|||
tls_verify_certificate: bool,
|
||||
|
||||
/// If set, will use this http proxy to connect to the server
|
||||
#[arg(
|
||||
short = 'p',
|
||||
long,
|
||||
value_name = "http://USER:PASS@HOST:PORT",
|
||||
verbatim_doc_comment
|
||||
)]
|
||||
#[arg(short = 'p', long, value_name = "http://USER:PASS@HOST:PORT", verbatim_doc_comment)]
|
||||
http_proxy: Option<Url>,
|
||||
|
||||
/// Use a specific prefix that will show up in the http path during the upgrade request.
|
||||
|
@ -241,9 +230,7 @@ fn parse_local_bind(arg: &str) -> Result<(SocketAddr, &str), io::Error> {
|
|||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn parse_tunnel_dest(
|
||||
remaining: &str,
|
||||
) -> Result<(Host<String>, u16, BTreeMap<String, String>), io::Error> {
|
||||
fn parse_tunnel_dest(remaining: &str) -> Result<(Host<String>, u16, BTreeMap<String, String>), io::Error> {
|
||||
use std::io::Error;
|
||||
|
||||
let Ok(remote) = Url::parse(&format!("fake://{}", remaining)) else {
|
||||
|
@ -290,13 +277,7 @@ fn parse_tunnel_arg(arg: &str) -> Result<LocalToRemote, io::Error> {
|
|||
let timeout = options
|
||||
.get("timeout_sec")
|
||||
.and_then(|x| x.parse::<u64>().ok())
|
||||
.map(|d| {
|
||||
if d == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(Duration::from_secs(d))
|
||||
}
|
||||
})
|
||||
.map(|d| if d == 0 { None } else { Some(Duration::from_secs(d)) })
|
||||
.unwrap_or(Some(Duration::from_secs(30)));
|
||||
|
||||
Ok(LocalToRemote {
|
||||
|
@ -355,10 +336,7 @@ fn parse_http_headers(arg: &str) -> Result<(HeaderName, HeaderValue), io::Error>
|
|||
Err(err) => {
|
||||
return Err(io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!(
|
||||
"cannot parse http header value from {} due to {:?}",
|
||||
value, err
|
||||
),
|
||||
format!("cannot parse http header value from {} due to {:?}", value, err),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
@ -394,10 +372,7 @@ fn parse_server_url(arg: &str) -> Result<Url, io::Error> {
|
|||
}
|
||||
|
||||
if url.host().is_none() {
|
||||
return Err(io::Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("invalid server host {}", arg),
|
||||
));
|
||||
return Err(io::Error::new(ErrorKind::InvalidInput, format!("invalid server host {}", arg)));
|
||||
}
|
||||
|
||||
Ok(url)
|
||||
|
@ -474,15 +449,9 @@ impl WsClientConfig {
|
|||
}
|
||||
|
||||
pub fn tls_server_name(&self) -> ServerName {
|
||||
match self
|
||||
.tls
|
||||
.as_ref()
|
||||
.and_then(|tls| tls.tls_sni_override.as_ref())
|
||||
{
|
||||
match self.tls.as_ref().and_then(|tls| tls.tls_sni_override.as_ref()) {
|
||||
None => match &self.remote_addr.0 {
|
||||
Host::Domain(domain) => {
|
||||
ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap())
|
||||
}
|
||||
Host::Domain(domain) => ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap()),
|
||||
Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip)),
|
||||
Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip)),
|
||||
},
|
||||
|
@ -529,12 +498,11 @@ async fn main() {
|
|||
};
|
||||
|
||||
// Extract host header from http_headers
|
||||
let host_header =
|
||||
if let Some((_, host_val)) = args.http_headers.iter().find(|(h, _)| *h == HOST) {
|
||||
host_val.clone()
|
||||
} else {
|
||||
HeaderValue::from_str(&args.remote_addr.host().unwrap().to_string()).unwrap()
|
||||
};
|
||||
let host_header = if let Some((_, host_val)) = args.http_headers.iter().find(|(h, _)| *h == HOST) {
|
||||
host_val.clone()
|
||||
} else {
|
||||
HeaderValue::from_str(&args.remote_addr.host().unwrap().to_string()).unwrap()
|
||||
};
|
||||
let mut client_config = WsClientConfig {
|
||||
remote_addr: (
|
||||
args.remote_addr.host().unwrap().to_owned(),
|
||||
|
@ -544,16 +512,10 @@ async fn main() {
|
|||
tls,
|
||||
http_upgrade_path_prefix: args.http_upgrade_path_prefix,
|
||||
http_upgrade_credentials: args.http_upgrade_credentials,
|
||||
http_headers: args
|
||||
.http_headers
|
||||
.into_iter()
|
||||
.filter(|(k, _)| k != HOST)
|
||||
.collect(),
|
||||
http_headers: args.http_headers.into_iter().filter(|(k, _)| k != HOST).collect(),
|
||||
http_header_host: host_header,
|
||||
timeout_connect: Duration::from_secs(10),
|
||||
websocket_ping_frequency: args
|
||||
.websocket_ping_frequency_sec
|
||||
.unwrap_or(Duration::from_secs(30)),
|
||||
websocket_ping_frequency: args.websocket_ping_frequency_sec.unwrap_or(Duration::from_secs(30)),
|
||||
websocket_mask_frame: args.websocket_mask_frame,
|
||||
http_proxy: args.http_proxy,
|
||||
cnx_pool: None,
|
||||
|
@ -579,16 +541,12 @@ async fn main() {
|
|||
let remote = tunnel.remote.clone();
|
||||
let server = tcp::run_server(tunnel.local)
|
||||
.await
|
||||
.unwrap_or_else(|err| {
|
||||
panic!("Cannot start TCP server on {}: {}", tunnel.local, err)
|
||||
})
|
||||
.unwrap_or_else(|err| panic!("Cannot start TCP server on {}: {}", tunnel.local, err))
|
||||
.map_err(anyhow::Error::new)
|
||||
.map_ok(move |stream| (stream.into_split(), remote.clone()));
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) =
|
||||
tunnel::client::run_tunnel(client_config, tunnel, server).await
|
||||
{
|
||||
if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await {
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -597,16 +555,12 @@ async fn main() {
|
|||
let remote = tunnel.remote.clone();
|
||||
let server = udp::run_server(tunnel.local, *timeout)
|
||||
.await
|
||||
.unwrap_or_else(|err| {
|
||||
panic!("Cannot start UDP server on {}: {}", tunnel.local, err)
|
||||
})
|
||||
.unwrap_or_else(|err| panic!("Cannot start UDP server on {}: {}", tunnel.local, err))
|
||||
.map_err(anyhow::Error::new)
|
||||
.map_ok(move |stream| (tokio::io::split(stream), remote.clone()));
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) =
|
||||
tunnel::client::run_tunnel(client_config, tunnel, server).await
|
||||
{
|
||||
if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await {
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -614,15 +568,11 @@ async fn main() {
|
|||
LocalProtocol::Socks5 => {
|
||||
let server = socks5::run_server(tunnel.local)
|
||||
.await
|
||||
.unwrap_or_else(|err| {
|
||||
panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err)
|
||||
})
|
||||
.unwrap_or_else(|err| panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err))
|
||||
.map_ok(|(stream, remote_dest)| (stream.into_split(), remote_dest));
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) =
|
||||
tunnel::client::run_tunnel(client_config, tunnel, server).await
|
||||
{
|
||||
if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await {
|
||||
error!("{:?}", err);
|
||||
}
|
||||
});
|
||||
|
@ -656,8 +606,7 @@ async fn main() {
|
|||
Commands::Server(args) => {
|
||||
let tls_config = if args.remote_addr.scheme() == "wss" {
|
||||
let tls_certificate = if let Some(cert_path) = args.tls_certificate {
|
||||
tls::load_certificates_from_pem(&cert_path)
|
||||
.expect("Cannot load tls certificate")
|
||||
tls::load_certificates_from_pem(&cert_path).expect("Cannot load tls certificate")
|
||||
} else {
|
||||
embedded_certificate::TLS_CERTIFICATE.clone()
|
||||
};
|
||||
|
|
|
@ -19,10 +19,7 @@ pub struct Socks5Listener {
|
|||
impl Stream for Socks5Listener {
|
||||
type Item = anyhow::Result<(TcpStream, (Host, u16))>;
|
||||
|
||||
fn poll_next(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Option<Self::Item>> {
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
unsafe { self.map_unchecked_mut(|x| &mut x.stream) }.poll_next(cx)
|
||||
}
|
||||
}
|
||||
|
|
46
src/tcp.rs
46
src/tcp.rs
|
@ -14,12 +14,9 @@ use tracing::log::info;
|
|||
use url::{Host, Url};
|
||||
|
||||
fn configure_socket(socket: &mut TcpSocket, so_mark: &Option<i32>) -> Result<(), anyhow::Error> {
|
||||
socket.set_nodelay(true).with_context(|| {
|
||||
format!(
|
||||
"cannot set no_delay on socket: {}",
|
||||
io::Error::last_os_error()
|
||||
)
|
||||
})?;
|
||||
socket
|
||||
.set_nodelay(true)
|
||||
.with_context(|| format!("cannot set no_delay on socket: {}", io::Error::last_os_error()))?;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
if let Some(so_mark) = so_mark {
|
||||
|
@ -35,10 +32,7 @@ fn configure_socket(socket: &mut TcpSocket, so_mark: &Option<i32>) -> Result<(),
|
|||
);
|
||||
|
||||
if ret != 0 {
|
||||
return Err(anyhow!(
|
||||
"Cannot set SO_MARK on the connection {:?}",
|
||||
io::Error::last_os_error()
|
||||
));
|
||||
return Err(anyhow!("Cannot set SO_MARK on the connection {:?}", io::Error::last_os_error()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -117,17 +111,14 @@ pub async fn connect_with_http_proxy(
|
|||
let mut socket = connect(&proxy_host, proxy_port, so_mark, connect_timeout).await?;
|
||||
info!("Connected to http proxy {}:{}", proxy_host, proxy_port);
|
||||
|
||||
let authorization =
|
||||
if let Some((user, password)) = proxy.password().map(|p| (proxy.username(), p)) {
|
||||
let creds =
|
||||
base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", user, password));
|
||||
format!("Proxy-Authorization: Basic {}\r\n", creds)
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
let authorization = if let Some((user, password)) = proxy.password().map(|p| (proxy.username(), p)) {
|
||||
let creds = base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", user, password));
|
||||
format!("Proxy-Authorization: Basic {}\r\n", creds)
|
||||
} else {
|
||||
"".to_string()
|
||||
};
|
||||
|
||||
let connect_request =
|
||||
format!("CONNECT {host}:{port} HTTP/1.0\r\nHost: {host}:{port}\r\n{authorization}\r\n");
|
||||
let connect_request = format!("CONNECT {host}:{port} HTTP/1.0\r\nHost: {host}:{port}\r\n{authorization}\r\n");
|
||||
socket.write_all(connect_request.as_bytes()).await?;
|
||||
|
||||
let mut buf = BytesMut::with_capacity(1024);
|
||||
|
@ -136,16 +127,15 @@ pub async fn connect_with_http_proxy(
|
|||
match nb_bytes {
|
||||
Ok(Ok(0)) => {
|
||||
return Err(anyhow!(
|
||||
"Cannot connect to http proxy. Proxy closed the connection without returning any response"));
|
||||
"Cannot connect to http proxy. Proxy closed the connection without returning any response"
|
||||
));
|
||||
}
|
||||
Ok(Ok(_)) => {}
|
||||
Ok(Err(err)) => {
|
||||
return Err(anyhow!("Cannot connect to http proxy. {err}"));
|
||||
}
|
||||
Err(_) => {
|
||||
return Err(anyhow!(
|
||||
"Cannot connect to http proxy. Proxy took too long to connect"
|
||||
));
|
||||
return Err(anyhow!("Cannot connect to http proxy. Proxy took too long to connect"));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -225,8 +215,7 @@ mod tests {
|
|||
let server = TcpListener::bind(server_addr).await.unwrap();
|
||||
|
||||
let docker = testcontainers::clients::Cli::default();
|
||||
let mitm_proxy: RunnableImage<MitmProxy> =
|
||||
RunnableImage::from(MitmProxy {}).with_network("host".to_string());
|
||||
let mitm_proxy: RunnableImage<MitmProxy> = RunnableImage::from(MitmProxy {}).with_network("host".to_string());
|
||||
let _node = docker.run(mitm_proxy);
|
||||
|
||||
let mut client = connect_with_http_proxy(
|
||||
|
@ -239,10 +228,7 @@ mod tests {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
client
|
||||
.write_all(b"GET / HTTP/1.1\r\n\r\n".as_slice())
|
||||
.await
|
||||
.unwrap();
|
||||
client.write_all(b"GET / HTTP/1.1\r\n\r\n".as_slice()).await.unwrap();
|
||||
let client_srv = server.accept().await.unwrap().0;
|
||||
pin_mut!(client_srv);
|
||||
|
||||
|
|
28
src/tls.rs
28
src/tls.rs
|
@ -45,21 +45,15 @@ pub fn load_private_key_from_file(path: &Path) -> anyhow::Result<PrivateKey> {
|
|||
match keys.len() {
|
||||
0 => Err(anyhow!("No PKCS8-encoded private key found in {path:?}")),
|
||||
1 => Ok(PrivateKey(keys.remove(0))),
|
||||
_ => Err(anyhow!(
|
||||
"More than one PKCS8-encoded private key found in {path:?}"
|
||||
)),
|
||||
_ => Err(anyhow!("More than one PKCS8-encoded private key found in {path:?}")),
|
||||
}
|
||||
}
|
||||
|
||||
fn tls_connector(
|
||||
tls_cfg: &TlsClientConfig,
|
||||
alpn_protocols: Option<Vec<Vec<u8>>>,
|
||||
) -> anyhow::Result<TlsConnector> {
|
||||
fn tls_connector(tls_cfg: &TlsClientConfig, alpn_protocols: Option<Vec<Vec<u8>>>) -> anyhow::Result<TlsConnector> {
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
|
||||
// Load system certificates and add them to the root store
|
||||
let certs = rustls_native_certs::load_native_certs()
|
||||
.with_context(|| "Cannot load system certificates")?;
|
||||
let certs = rustls_native_certs::load_native_certs().with_context(|| "Cannot load system certificates")?;
|
||||
for cert in certs {
|
||||
root_store.add(&Certificate(cert.0))?;
|
||||
}
|
||||
|
@ -71,9 +65,7 @@ fn tls_connector(
|
|||
|
||||
// To bypass certificate verification
|
||||
if !tls_cfg.tls_verify_certificate {
|
||||
config
|
||||
.dangerous()
|
||||
.set_certificate_verifier(Arc::new(NullVerifier));
|
||||
config.dangerous().set_certificate_verifier(Arc::new(NullVerifier));
|
||||
}
|
||||
|
||||
if let Some(alpn_protocols) = alpn_protocols {
|
||||
|
@ -83,10 +75,7 @@ fn tls_connector(
|
|||
Ok(tls_connector)
|
||||
}
|
||||
|
||||
pub fn tls_acceptor(
|
||||
tls_cfg: &TlsServerConfig,
|
||||
alpn_protocols: Option<Vec<Vec<u8>>>,
|
||||
) -> anyhow::Result<TlsAcceptor> {
|
||||
pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option<Vec<Vec<u8>>>) -> anyhow::Result<TlsAcceptor> {
|
||||
let mut config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
|
@ -114,12 +103,7 @@ pub async fn connect(
|
|||
let tls_stream = tls_connector
|
||||
.connect(sni, tcp_stream)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to do TLS handshake with the server {:?}",
|
||||
client_cfg.remote_addr
|
||||
)
|
||||
})?;
|
||||
.with_context(|| format!("failed to do TLS handshake with the server {:?}", client_cfg.remote_addr))?;
|
||||
|
||||
Ok(tls_stream)
|
||||
}
|
||||
|
|
|
@ -44,9 +44,7 @@ pub async fn connect(
|
|||
) -> anyhow::Result<WebSocket<Upgraded>> {
|
||||
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
|
||||
Ok(tcp_stream) => tcp_stream,
|
||||
Err(err) => Err(anyhow!(
|
||||
"failed to get a connection to the server from the pool: {err:?}"
|
||||
))?,
|
||||
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}"))?,
|
||||
};
|
||||
|
||||
let mut req = Request::builder()
|
||||
|
@ -80,12 +78,7 @@ pub async fn connect(
|
|||
let transport = pooled_cnx.deref_mut().take().unwrap();
|
||||
let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, transport)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to do websocket handshake with the server {:?}",
|
||||
client_cfg.remote_addr
|
||||
)
|
||||
})?;
|
||||
.with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?;
|
||||
|
||||
Ok(ws)
|
||||
}
|
||||
|
@ -109,10 +102,7 @@ where
|
|||
|
||||
// Forward local tx to websocket tx
|
||||
let ping_frequency = client_cfg.websocket_ping_frequency;
|
||||
tokio::spawn(
|
||||
super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency)
|
||||
.instrument(Span::current()),
|
||||
);
|
||||
tokio::spawn(super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency).instrument(Span::current()));
|
||||
|
||||
// Forward websocket rx to local rx
|
||||
let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await;
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite};
|
||||
use futures_util::pin_mut;
|
||||
use hyper::upgrade::Upgraded;
|
||||
use std::pin::Pin;
|
||||
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
|
||||
use tokio::select;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::time::timeout;
|
||||
use tracing::log::debug;
|
||||
use tracing::{error, info, trace, warn};
|
||||
|
||||
|
@ -20,7 +20,14 @@ pub(super) async fn propagate_read(
|
|||
info!("Closing local tx ==> websocket tx tunnel");
|
||||
});
|
||||
|
||||
let mut buffer = vec![0u8; 8 * 1024];
|
||||
static JUMBO_FRAME_SIZE: usize = 9 * 1024; // enough for a jumbo frame
|
||||
let mut buffer = vec![0u8; JUMBO_FRAME_SIZE];
|
||||
|
||||
// We do our own pin_mut! to avoid shadowing timeout and be able to reset it, on next loop iteration
|
||||
// We reuse the future to avoid creating a timer in the tight loop
|
||||
let mut timeout_unpin = tokio::time::sleep(ping_frequency);
|
||||
let mut timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) };
|
||||
|
||||
pin_mut!(local_rx);
|
||||
loop {
|
||||
let read_len = select! {
|
||||
|
@ -30,9 +37,12 @@ pub(super) async fn propagate_read(
|
|||
|
||||
_ = close_tx.closed() => break,
|
||||
|
||||
_ = timeout(ping_frequency, futures_util::future::pending::<()>()) => {
|
||||
_ = &mut timeout => {
|
||||
debug!("sending ping to keep websocket connection alive");
|
||||
ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?;
|
||||
|
||||
timeout_unpin = tokio::time::sleep(ping_frequency);
|
||||
timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) };
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
@ -41,32 +51,30 @@ pub(super) async fn propagate_read(
|
|||
Ok(0) => break,
|
||||
Ok(read_len) => read_len,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"error while reading incoming bytes from local tx tunnel {}",
|
||||
err
|
||||
);
|
||||
warn!("error while reading incoming bytes from local tx tunnel {}", err);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
trace!("read {} bytes", read_len);
|
||||
match ws_tx
|
||||
if let Err(err) = ws_tx
|
||||
.write_frame(Frame::binary(Payload::BorrowedMut(&mut buffer[..read_len])))
|
||||
.await
|
||||
{
|
||||
Ok(_) => {}
|
||||
Err(err) => {
|
||||
warn!("error while writing to websocket tx tunnel {}", err);
|
||||
break;
|
||||
}
|
||||
warn!("error while writing to websocket tx tunnel {}", err);
|
||||
break;
|
||||
}
|
||||
|
||||
// If the buffer has been completely filled with previous read, Double it !
|
||||
// For the buffer to not be a bottleneck when the TCP window scale
|
||||
// For udp, the buffer will never grows.
|
||||
if buffer.capacity() == read_len {
|
||||
buffer.clear();
|
||||
buffer.resize(buffer.capacity() * 2, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Send normal close
|
||||
let _ = ws_tx.write_frame(Frame::close(1000, &[])).await;
|
||||
|
||||
Ok(())
|
||||
|
@ -104,20 +112,15 @@ pub(super) async fn propagate_write(
|
|||
|
||||
trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload);
|
||||
let ret = match msg.opcode {
|
||||
OpCode::Continuation | OpCode::Text | OpCode::Binary => {
|
||||
local_tx.write_all(msg.payload.as_ref()).await
|
||||
}
|
||||
OpCode::Continuation | OpCode::Text | OpCode::Binary => local_tx.write_all(msg.payload.as_ref()).await,
|
||||
OpCode::Close => break,
|
||||
OpCode::Ping => Ok(()),
|
||||
OpCode::Pong => Ok(()),
|
||||
};
|
||||
|
||||
match ret {
|
||||
Ok(_) => {}
|
||||
Err(err) => {
|
||||
error!("error while writing bytes to local for rx tunnel {}", err);
|
||||
break;
|
||||
}
|
||||
if let Err(err) = ret {
|
||||
error!("error while writing bytes to local for rx tunnel {}", err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -42,12 +42,8 @@ impl JwtTunnelConfig {
|
|||
}
|
||||
|
||||
static JWT_SECRET: &[u8; 15] = b"champignonfrais";
|
||||
static JWT_KEY: Lazy<(Header, EncodingKey)> = Lazy::new(|| {
|
||||
(
|
||||
Header::new(Algorithm::HS256),
|
||||
EncodingKey::from_secret(JWT_SECRET),
|
||||
)
|
||||
});
|
||||
static JWT_KEY: Lazy<(Header, EncodingKey)> =
|
||||
Lazy::new(|| (Header::new(Algorithm::HS256), EncodingKey::from_secret(JWT_SECRET)));
|
||||
|
||||
static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| {
|
||||
let mut validation = Validation::new(Algorithm::HS256);
|
||||
|
@ -61,11 +57,7 @@ pub enum TransportStream {
|
|||
}
|
||||
|
||||
impl AsyncRead for TransportStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
TransportStream::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf),
|
||||
TransportStream::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf),
|
||||
|
@ -74,11 +66,7 @@ impl AsyncRead for TransportStream {
|
|||
}
|
||||
|
||||
impl AsyncWrite for TransportStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, Error>> {
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
||||
match self.get_mut() {
|
||||
TransportStream::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf),
|
||||
TransportStream::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf),
|
||||
|
|
|
@ -46,15 +46,8 @@ async fn from_query(
|
|||
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
|
||||
if let Some(allowed_dests) = &server_config.restrict_to {
|
||||
let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp);
|
||||
if allowed_dests
|
||||
.iter()
|
||||
.any(|dest| dest == &requested_dest)
|
||||
.not()
|
||||
{
|
||||
warn!(
|
||||
"Rejecting connection with not allowed destination: {}",
|
||||
requested_dest
|
||||
);
|
||||
if allowed_dests.iter().any(|dest| dest == &requested_dest).not() {
|
||||
warn!("Rejecting connection with not allowed destination: {}", requested_dest);
|
||||
return Err(anyhow::anyhow!("Invalid upgrade request"));
|
||||
}
|
||||
}
|
||||
|
@ -75,14 +68,9 @@ async fn from_query(
|
|||
LocalProtocol::Tcp { .. } => {
|
||||
let host = Host::parse(&jwt.claims.r)?;
|
||||
let port = jwt.claims.rp;
|
||||
let (rx, tx) = tcp::connect(
|
||||
&host,
|
||||
port,
|
||||
&server_config.socket_so_mark,
|
||||
Duration::from_secs(10),
|
||||
)
|
||||
.await?
|
||||
.into_split();
|
||||
let (rx, tx) = tcp::connect(&host, port, &server_config.socket_so_mark, Duration::from_secs(10))
|
||||
.await?
|
||||
.into_split();
|
||||
|
||||
Ok((jwt.claims.p, host, port, Box::pin(rx), Box::pin(tx)))
|
||||
}
|
||||
|
@ -100,10 +88,7 @@ async fn server_upgrade(
|
|||
}
|
||||
|
||||
if !req.uri().path().ends_with("/events") {
|
||||
warn!(
|
||||
"Rejecting connection with bad upgrade request: {}",
|
||||
req.uri()
|
||||
);
|
||||
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
|
||||
return Ok(http::Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(Body::from("Invalid upgrade request"))
|
||||
|
@ -118,10 +103,7 @@ async fn server_upgrade(
|
|||
|| &path[min_len..max_len] != path_prefix.as_str()
|
||||
|| !path[max_len..].starts_with('/')
|
||||
{
|
||||
warn!(
|
||||
"Rejecting connection with bad path prefix in upgrade request: {}",
|
||||
req.uri()
|
||||
);
|
||||
warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri());
|
||||
return Ok(http::Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(Body::from("Invalid upgrade request"))
|
||||
|
@ -133,11 +115,7 @@ async fn server_upgrade(
|
|||
match from_query(&server_config, req.uri().query().unwrap_or_default()).await {
|
||||
Ok(ret) => ret,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"Rejecting connection with bad upgrade request: {} {}",
|
||||
err,
|
||||
req.uri()
|
||||
);
|
||||
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
||||
return Ok(http::Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(Body::from(format!("Invalid upgrade request: {:?}", err)))
|
||||
|
@ -149,11 +127,7 @@ async fn server_upgrade(
|
|||
let (response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) {
|
||||
Ok(ret) => ret,
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"Rejecting connection with bad upgrade request: {} {}",
|
||||
err,
|
||||
req.uri()
|
||||
);
|
||||
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
||||
return Ok(http::Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(Body::from(format!("Invalid upgrade request: {:?}", err)))
|
||||
|
@ -171,14 +145,10 @@ async fn server_upgrade(
|
|||
}
|
||||
};
|
||||
let (close_tx, close_rx) = oneshot::channel::<()>();
|
||||
let ping_frequency = server_config
|
||||
.websocket_ping_frequency
|
||||
.unwrap_or(Duration::MAX);
|
||||
let ping_frequency = server_config.websocket_ping_frequency.unwrap_or(Duration::MAX);
|
||||
ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame);
|
||||
|
||||
tokio::task::spawn(
|
||||
super::io::propagate_write(local_tx, ws_rx, close_rx).instrument(Span::current()),
|
||||
);
|
||||
tokio::task::spawn(super::io::propagate_write(local_tx, ws_rx, close_rx).instrument(Span::current()));
|
||||
|
||||
let _ = super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency).await;
|
||||
}
|
||||
|
@ -189,10 +159,7 @@ async fn server_upgrade(
|
|||
}
|
||||
|
||||
pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> {
|
||||
info!(
|
||||
"Starting wstunnel server listening on {}",
|
||||
server_config.bind
|
||||
);
|
||||
info!("Starting wstunnel server listening on {}", server_config.bind);
|
||||
|
||||
let config = server_config.clone();
|
||||
let upgrade_fn = move |req: Request<Body>| server_upgrade(config.clone(), req);
|
||||
|
|
155
src/udp.rs
155
src/udp.rs
|
@ -1,18 +1,17 @@
|
|||
use anyhow::Context;
|
||||
use futures_util::{stream, Stream};
|
||||
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use parking_lot::RwLock;
|
||||
use pin_project::{pin_project, pinned_drop};
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use std::io::{Error, ErrorKind};
|
||||
use std::net::SocketAddr;
|
||||
use std::ops::DerefMut;
|
||||
|
||||
use std::pin::{pin, Pin};
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::task::{ready, Poll, Waker};
|
||||
use std::task::{ready, Poll};
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::net::UdpSocket;
|
||||
|
@ -23,8 +22,7 @@ use tokio::time::Sleep;
|
|||
use tracing::{debug, error, info};
|
||||
|
||||
struct IoInner {
|
||||
has_data_to_read: &'static Notify,
|
||||
waker: Mutex<Option<Waker>>,
|
||||
has_data_to_read: Notify,
|
||||
has_read_data: Notify,
|
||||
}
|
||||
struct UdpServer {
|
||||
|
@ -43,6 +41,7 @@ impl UdpServer {
|
|||
cnx_timeout: timeout,
|
||||
}
|
||||
}
|
||||
#[inline]
|
||||
fn clean_dead_keys(&mut self) {
|
||||
let nb_key_to_delete = self.keys_to_delete.read().len();
|
||||
if nb_key_to_delete == 0 {
|
||||
|
@ -52,16 +51,7 @@ impl UdpServer {
|
|||
debug!("Cleaning {} dead udp peers", nb_key_to_delete);
|
||||
let mut keys_to_delete = self.keys_to_delete.write();
|
||||
for key in keys_to_delete.iter() {
|
||||
let Some(peer) = self.peers.remove(key) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
#[allow(mutable_transmutes)]
|
||||
unsafe {
|
||||
let _ = Box::from_raw(std::mem::transmute::<&Notify, &mut Notify>(
|
||||
peer.has_data_to_read,
|
||||
));
|
||||
}
|
||||
self.peers.remove(key);
|
||||
}
|
||||
keys_to_delete.clear();
|
||||
}
|
||||
|
@ -90,7 +80,42 @@ impl PinnedDrop for UdpStream {
|
|||
keys_to_delete.write().push(self.peer);
|
||||
}
|
||||
|
||||
self.io.has_read_data.notify_one();
|
||||
// safety: we are dropping the notification as we extend its lifetime to 'static unsafely
|
||||
// So it must be gone before we drop its parent. It should never happen but in case
|
||||
let mut project = self.project();
|
||||
project.pending_notification.as_mut().set(None);
|
||||
project.io.has_read_data.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpStream {
|
||||
fn new(
|
||||
socket: Arc<UdpSocket>,
|
||||
peer: SocketAddr,
|
||||
deadline: Option<Sleep>,
|
||||
keys_to_delete: Weak<RwLock<Vec<SocketAddr>>>,
|
||||
) -> (Self, Arc<IoInner>) {
|
||||
let has_data_to_read = Notify::new();
|
||||
let has_read_data = Notify::new();
|
||||
has_data_to_read.notify_one();
|
||||
let io = Arc::new(IoInner {
|
||||
has_data_to_read,
|
||||
has_read_data,
|
||||
});
|
||||
let mut s = Self {
|
||||
socket,
|
||||
peer,
|
||||
deadline,
|
||||
has_been_notified: false,
|
||||
pending_notification: None,
|
||||
io: io.clone(),
|
||||
keys_to_delete,
|
||||
};
|
||||
|
||||
let pending_notification = unsafe { std::mem::transmute(s.io.has_data_to_read.notified()) };
|
||||
s.pending_notification = Some(pending_notification);
|
||||
|
||||
(s, io)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -111,43 +136,28 @@ impl AsyncRead for UdpStream {
|
|||
}
|
||||
|
||||
if let Some(notified) = project.pending_notification.as_mut().as_pin_mut() {
|
||||
if !notified.poll(cx).is_ready() {
|
||||
project.io.waker.lock().replace(cx.waker().clone());
|
||||
return Poll::Pending;
|
||||
}
|
||||
ready!(notified.poll(cx));
|
||||
project.pending_notification.as_mut().set(None);
|
||||
}
|
||||
|
||||
let _ = ready!(project.socket.poll_recv(cx, obuf));
|
||||
project
|
||||
.pending_notification
|
||||
.as_mut()
|
||||
.set(Some(project.io.has_data_to_read.notified()));
|
||||
let notified: Notified<'static> = unsafe { std::mem::transmute(project.io.has_data_to_read.notified()) };
|
||||
project.pending_notification.as_mut().set(Some(notified));
|
||||
project.io.has_read_data.notify_one();
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for UdpStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, Error>> {
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
||||
self.socket.poll_send_to(cx, buf, self.peer)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), Error>> {
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
|
||||
self.socket.poll_send_ready(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), Error>> {
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
@ -178,39 +188,22 @@ pub async fn run_server(
|
|||
}
|
||||
};
|
||||
|
||||
match server.peers.entry(peer_addr) {
|
||||
Entry::Occupied(mut peer) => {
|
||||
let io = peer.get_mut();
|
||||
match server.peers.get(&peer_addr) {
|
||||
Some(io) => {
|
||||
io.has_read_data.notified().await;
|
||||
io.has_data_to_read.notify_one();
|
||||
let waker = io.waker.lock().deref_mut().take();
|
||||
if let Some(waker) = waker {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
Entry::Vacant(peer) => {
|
||||
let has_data_to_read: &'static Notify = Box::leak(Box::new(Notify::new()));
|
||||
let pending_notification = has_data_to_read.notified();
|
||||
let has_read_data = Notify::new();
|
||||
has_data_to_read.notify_one();
|
||||
let io = Arc::new(IoInner {
|
||||
has_data_to_read,
|
||||
waker: Mutex::new(None),
|
||||
has_read_data,
|
||||
});
|
||||
peer.insert(io.clone());
|
||||
let udp_client = UdpStream {
|
||||
socket: server.clone_socket(),
|
||||
peer: peer_addr,
|
||||
deadline: server
|
||||
None => {
|
||||
let (udp_client, io) = UdpStream::new(
|
||||
server.clone_socket(),
|
||||
peer_addr,
|
||||
server
|
||||
.cnx_timeout
|
||||
.and_then(|timeout| tokio::time::Instant::now().checked_add(timeout))
|
||||
.map(tokio::time::sleep_until),
|
||||
keys_to_delete: Arc::downgrade(&server.keys_to_delete),
|
||||
has_been_notified: false,
|
||||
pending_notification: Some(pending_notification),
|
||||
io,
|
||||
};
|
||||
Arc::downgrade(&server.keys_to_delete),
|
||||
);
|
||||
server.peers.insert(peer_addr, io);
|
||||
return Some((Ok(udp_client), (server)));
|
||||
}
|
||||
}
|
||||
|
@ -231,11 +224,7 @@ impl MyUdpSocket {
|
|||
}
|
||||
|
||||
impl AsyncRead for MyUdpSocket {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
fn poll_read(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
|
||||
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }
|
||||
.poll_recv_from(cx, buf)
|
||||
.map(|x| x.map(|_| ()))
|
||||
|
@ -243,25 +232,15 @@ impl AsyncRead for MyUdpSocket {
|
|||
}
|
||||
|
||||
impl AsyncWrite for MyUdpSocket {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, Error>> {
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
||||
unsafe { self.map_unchecked_mut(|x| &mut x.socket) }.poll_send(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), Error>> {
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), Error>> {
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
@ -331,10 +310,7 @@ mod tests {
|
|||
assert!(client.send_to(b"aaaaa".as_ref(), server_addr).await.is_ok());
|
||||
|
||||
let client2 = UdpSocket::bind("[::1]:0").await.unwrap();
|
||||
assert!(client2
|
||||
.send_to(b"bbbbb".as_ref(), server_addr)
|
||||
.await
|
||||
.is_ok());
|
||||
assert!(client2.send_to(b"bbbbb".as_ref(), server_addr).await.is_ok());
|
||||
|
||||
// Should have a new connection
|
||||
let fut = timeout(Duration::from_millis(100), server.next()).await;
|
||||
|
@ -360,10 +336,7 @@ mod tests {
|
|||
assert_eq!(&buf[..6], b"bbbbb\0");
|
||||
|
||||
assert!(client.send_to(b"ccccc".as_ref(), server_addr).await.is_ok());
|
||||
assert!(client2
|
||||
.send_to(b"ddddd".as_ref(), server_addr)
|
||||
.await
|
||||
.is_ok());
|
||||
assert!(client2.send_to(b"ddddd".as_ref(), server_addr).await.is_ok());
|
||||
|
||||
// Server need to be polled to feed the stream with need data
|
||||
let _ = timeout(Duration::from_millis(100), server.next()).await;
|
||||
|
|
Loading…
Reference in a new issue