diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..df3318f --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,3 @@ +edition = "2021" +max_width = 120 +fn_call_width = 80 diff --git a/src/embedded_certificate.rs b/src/embedded_certificate.rs index 49bcfbe..325b1f2 100644 --- a/src/embedded_certificate.rs +++ b/src/embedded_certificate.rs @@ -3,14 +3,13 @@ use tokio_rustls::rustls::{Certificate, PrivateKey}; pub static TLS_PRIVATE_KEY: Lazy = Lazy::new(|| { let key = include_bytes!("../certs/key.pem"); - let mut keys = rustls_pemfile::pkcs8_private_keys(&mut key.as_slice()) - .expect("failed to load embedded tls private key"); + let mut keys = + rustls_pemfile::pkcs8_private_keys(&mut key.as_slice()).expect("failed to load embedded tls private key"); PrivateKey(keys.remove(0)) }); pub static TLS_CERTIFICATE: Lazy> = Lazy::new(|| { let cert = include_bytes!("../certs/cert.pem"); - let certs = rustls_pemfile::certs(&mut cert.as_slice()) - .expect("failed to load embedded tls certificate"); + let certs = rustls_pemfile::certs(&mut cert.as_slice()).expect("failed to load embedded tls certificate"); certs.into_iter().map(Certificate).collect() }); diff --git a/src/main.rs b/src/main.rs index 1626714..138a900 100644 --- a/src/main.rs +++ b/src/main.rs @@ -67,13 +67,7 @@ struct Client { /// This option set the maximum number of connection that will be kept open. /// This is useful if you plan to create/destroy a lot of tunnel (i.e: with socks5 to navigate with a browser) /// It will avoid the latency of doing tcp + tls handshake with the server - #[arg( - short = 'c', - long, - value_name = "INT", - default_value = "0", - verbatim_doc_comment - )] + #[arg(short = 'c', long, value_name = "INT", default_value = "0", verbatim_doc_comment)] connection_min_idle: u32, /// Domain name that will be use as SNI during TLS handshake @@ -88,12 +82,7 @@ struct Client { tls_verify_certificate: bool, /// If set, will use this http proxy to connect to the server - #[arg( - short = 'p', - long, - value_name = "http://USER:PASS@HOST:PORT", - verbatim_doc_comment - )] + #[arg(short = 'p', long, value_name = "http://USER:PASS@HOST:PORT", verbatim_doc_comment)] http_proxy: Option, /// Use a specific prefix that will show up in the http path during the upgrade request. @@ -241,9 +230,7 @@ fn parse_local_bind(arg: &str) -> Result<(SocketAddr, &str), io::Error> { } #[allow(clippy::type_complexity)] -fn parse_tunnel_dest( - remaining: &str, -) -> Result<(Host, u16, BTreeMap), io::Error> { +fn parse_tunnel_dest(remaining: &str) -> Result<(Host, u16, BTreeMap), io::Error> { use std::io::Error; let Ok(remote) = Url::parse(&format!("fake://{}", remaining)) else { @@ -290,13 +277,7 @@ fn parse_tunnel_arg(arg: &str) -> Result { let timeout = options .get("timeout_sec") .and_then(|x| x.parse::().ok()) - .map(|d| { - if d == 0 { - None - } else { - Some(Duration::from_secs(d)) - } - }) + .map(|d| if d == 0 { None } else { Some(Duration::from_secs(d)) }) .unwrap_or(Some(Duration::from_secs(30))); Ok(LocalToRemote { @@ -355,10 +336,7 @@ fn parse_http_headers(arg: &str) -> Result<(HeaderName, HeaderValue), io::Error> Err(err) => { return Err(io::Error::new( ErrorKind::InvalidInput, - format!( - "cannot parse http header value from {} due to {:?}", - value, err - ), + format!("cannot parse http header value from {} due to {:?}", value, err), )) } }; @@ -394,10 +372,7 @@ fn parse_server_url(arg: &str) -> Result { } if url.host().is_none() { - return Err(io::Error::new( - ErrorKind::InvalidInput, - format!("invalid server host {}", arg), - )); + return Err(io::Error::new(ErrorKind::InvalidInput, format!("invalid server host {}", arg))); } Ok(url) @@ -474,15 +449,9 @@ impl WsClientConfig { } pub fn tls_server_name(&self) -> ServerName { - match self - .tls - .as_ref() - .and_then(|tls| tls.tls_sni_override.as_ref()) - { + match self.tls.as_ref().and_then(|tls| tls.tls_sni_override.as_ref()) { None => match &self.remote_addr.0 { - Host::Domain(domain) => { - ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap()) - } + Host::Domain(domain) => ServerName::DnsName(DnsName::try_from(domain.clone()).unwrap()), Host::Ipv4(ip) => ServerName::IpAddress(IpAddr::V4(*ip)), Host::Ipv6(ip) => ServerName::IpAddress(IpAddr::V6(*ip)), }, @@ -529,12 +498,11 @@ async fn main() { }; // Extract host header from http_headers - let host_header = - if let Some((_, host_val)) = args.http_headers.iter().find(|(h, _)| *h == HOST) { - host_val.clone() - } else { - HeaderValue::from_str(&args.remote_addr.host().unwrap().to_string()).unwrap() - }; + let host_header = if let Some((_, host_val)) = args.http_headers.iter().find(|(h, _)| *h == HOST) { + host_val.clone() + } else { + HeaderValue::from_str(&args.remote_addr.host().unwrap().to_string()).unwrap() + }; let mut client_config = WsClientConfig { remote_addr: ( args.remote_addr.host().unwrap().to_owned(), @@ -544,16 +512,10 @@ async fn main() { tls, http_upgrade_path_prefix: args.http_upgrade_path_prefix, http_upgrade_credentials: args.http_upgrade_credentials, - http_headers: args - .http_headers - .into_iter() - .filter(|(k, _)| k != HOST) - .collect(), + http_headers: args.http_headers.into_iter().filter(|(k, _)| k != HOST).collect(), http_header_host: host_header, timeout_connect: Duration::from_secs(10), - websocket_ping_frequency: args - .websocket_ping_frequency_sec - .unwrap_or(Duration::from_secs(30)), + websocket_ping_frequency: args.websocket_ping_frequency_sec.unwrap_or(Duration::from_secs(30)), websocket_mask_frame: args.websocket_mask_frame, http_proxy: args.http_proxy, cnx_pool: None, @@ -579,16 +541,12 @@ async fn main() { let remote = tunnel.remote.clone(); let server = tcp::run_server(tunnel.local) .await - .unwrap_or_else(|err| { - panic!("Cannot start TCP server on {}: {}", tunnel.local, err) - }) + .unwrap_or_else(|err| panic!("Cannot start TCP server on {}: {}", tunnel.local, err)) .map_err(anyhow::Error::new) .map_ok(move |stream| (stream.into_split(), remote.clone())); tokio::spawn(async move { - if let Err(err) = - tunnel::client::run_tunnel(client_config, tunnel, server).await - { + if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await { error!("{:?}", err); } }); @@ -597,16 +555,12 @@ async fn main() { let remote = tunnel.remote.clone(); let server = udp::run_server(tunnel.local, *timeout) .await - .unwrap_or_else(|err| { - panic!("Cannot start UDP server on {}: {}", tunnel.local, err) - }) + .unwrap_or_else(|err| panic!("Cannot start UDP server on {}: {}", tunnel.local, err)) .map_err(anyhow::Error::new) .map_ok(move |stream| (tokio::io::split(stream), remote.clone())); tokio::spawn(async move { - if let Err(err) = - tunnel::client::run_tunnel(client_config, tunnel, server).await - { + if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await { error!("{:?}", err); } }); @@ -614,15 +568,11 @@ async fn main() { LocalProtocol::Socks5 => { let server = socks5::run_server(tunnel.local) .await - .unwrap_or_else(|err| { - panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err) - }) + .unwrap_or_else(|err| panic!("Cannot start Socks5 server on {}: {}", tunnel.local, err)) .map_ok(|(stream, remote_dest)| (stream.into_split(), remote_dest)); tokio::spawn(async move { - if let Err(err) = - tunnel::client::run_tunnel(client_config, tunnel, server).await - { + if let Err(err) = tunnel::client::run_tunnel(client_config, tunnel, server).await { error!("{:?}", err); } }); @@ -656,8 +606,7 @@ async fn main() { Commands::Server(args) => { let tls_config = if args.remote_addr.scheme() == "wss" { let tls_certificate = if let Some(cert_path) = args.tls_certificate { - tls::load_certificates_from_pem(&cert_path) - .expect("Cannot load tls certificate") + tls::load_certificates_from_pem(&cert_path).expect("Cannot load tls certificate") } else { embedded_certificate::TLS_CERTIFICATE.clone() }; diff --git a/src/socks5.rs b/src/socks5.rs index b7cb796..44213ae 100644 --- a/src/socks5.rs +++ b/src/socks5.rs @@ -19,10 +19,7 @@ pub struct Socks5Listener { impl Stream for Socks5Listener { type Item = anyhow::Result<(TcpStream, (Host, u16))>; - fn poll_next( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { unsafe { self.map_unchecked_mut(|x| &mut x.stream) }.poll_next(cx) } } diff --git a/src/tcp.rs b/src/tcp.rs index b2aa675..3403a92 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -14,12 +14,9 @@ use tracing::log::info; use url::{Host, Url}; fn configure_socket(socket: &mut TcpSocket, so_mark: &Option) -> Result<(), anyhow::Error> { - socket.set_nodelay(true).with_context(|| { - format!( - "cannot set no_delay on socket: {}", - io::Error::last_os_error() - ) - })?; + socket + .set_nodelay(true) + .with_context(|| format!("cannot set no_delay on socket: {}", io::Error::last_os_error()))?; #[cfg(target_os = "linux")] if let Some(so_mark) = so_mark { @@ -35,10 +32,7 @@ fn configure_socket(socket: &mut TcpSocket, so_mark: &Option) -> Result<(), ); if ret != 0 { - return Err(anyhow!( - "Cannot set SO_MARK on the connection {:?}", - io::Error::last_os_error() - )); + return Err(anyhow!("Cannot set SO_MARK on the connection {:?}", io::Error::last_os_error())); } } } @@ -117,17 +111,14 @@ pub async fn connect_with_http_proxy( let mut socket = connect(&proxy_host, proxy_port, so_mark, connect_timeout).await?; info!("Connected to http proxy {}:{}", proxy_host, proxy_port); - let authorization = - if let Some((user, password)) = proxy.password().map(|p| (proxy.username(), p)) { - let creds = - base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", user, password)); - format!("Proxy-Authorization: Basic {}\r\n", creds) - } else { - "".to_string() - }; + let authorization = if let Some((user, password)) = proxy.password().map(|p| (proxy.username(), p)) { + let creds = base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", user, password)); + format!("Proxy-Authorization: Basic {}\r\n", creds) + } else { + "".to_string() + }; - let connect_request = - format!("CONNECT {host}:{port} HTTP/1.0\r\nHost: {host}:{port}\r\n{authorization}\r\n"); + let connect_request = format!("CONNECT {host}:{port} HTTP/1.0\r\nHost: {host}:{port}\r\n{authorization}\r\n"); socket.write_all(connect_request.as_bytes()).await?; let mut buf = BytesMut::with_capacity(1024); @@ -136,16 +127,15 @@ pub async fn connect_with_http_proxy( match nb_bytes { Ok(Ok(0)) => { return Err(anyhow!( - "Cannot connect to http proxy. Proxy closed the connection without returning any response")); + "Cannot connect to http proxy. Proxy closed the connection without returning any response" + )); } Ok(Ok(_)) => {} Ok(Err(err)) => { return Err(anyhow!("Cannot connect to http proxy. {err}")); } Err(_) => { - return Err(anyhow!( - "Cannot connect to http proxy. Proxy took too long to connect" - )); + return Err(anyhow!("Cannot connect to http proxy. Proxy took too long to connect")); } }; @@ -225,8 +215,7 @@ mod tests { let server = TcpListener::bind(server_addr).await.unwrap(); let docker = testcontainers::clients::Cli::default(); - let mitm_proxy: RunnableImage = - RunnableImage::from(MitmProxy {}).with_network("host".to_string()); + let mitm_proxy: RunnableImage = RunnableImage::from(MitmProxy {}).with_network("host".to_string()); let _node = docker.run(mitm_proxy); let mut client = connect_with_http_proxy( @@ -239,10 +228,7 @@ mod tests { .await .unwrap(); - client - .write_all(b"GET / HTTP/1.1\r\n\r\n".as_slice()) - .await - .unwrap(); + client.write_all(b"GET / HTTP/1.1\r\n\r\n".as_slice()).await.unwrap(); let client_srv = server.accept().await.unwrap().0; pin_mut!(client_srv); diff --git a/src/tls.rs b/src/tls.rs index fe35de7..b700d25 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -45,21 +45,15 @@ pub fn load_private_key_from_file(path: &Path) -> anyhow::Result { match keys.len() { 0 => Err(anyhow!("No PKCS8-encoded private key found in {path:?}")), 1 => Ok(PrivateKey(keys.remove(0))), - _ => Err(anyhow!( - "More than one PKCS8-encoded private key found in {path:?}" - )), + _ => Err(anyhow!("More than one PKCS8-encoded private key found in {path:?}")), } } -fn tls_connector( - tls_cfg: &TlsClientConfig, - alpn_protocols: Option>>, -) -> anyhow::Result { +fn tls_connector(tls_cfg: &TlsClientConfig, alpn_protocols: Option>>) -> anyhow::Result { let mut root_store = rustls::RootCertStore::empty(); // Load system certificates and add them to the root store - let certs = rustls_native_certs::load_native_certs() - .with_context(|| "Cannot load system certificates")?; + let certs = rustls_native_certs::load_native_certs().with_context(|| "Cannot load system certificates")?; for cert in certs { root_store.add(&Certificate(cert.0))?; } @@ -71,9 +65,7 @@ fn tls_connector( // To bypass certificate verification if !tls_cfg.tls_verify_certificate { - config - .dangerous() - .set_certificate_verifier(Arc::new(NullVerifier)); + config.dangerous().set_certificate_verifier(Arc::new(NullVerifier)); } if let Some(alpn_protocols) = alpn_protocols { @@ -83,10 +75,7 @@ fn tls_connector( Ok(tls_connector) } -pub fn tls_acceptor( - tls_cfg: &TlsServerConfig, - alpn_protocols: Option>>, -) -> anyhow::Result { +pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option>>) -> anyhow::Result { let mut config = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() @@ -114,12 +103,7 @@ pub async fn connect( let tls_stream = tls_connector .connect(sni, tcp_stream) .await - .with_context(|| { - format!( - "failed to do TLS handshake with the server {:?}", - client_cfg.remote_addr - ) - })?; + .with_context(|| format!("failed to do TLS handshake with the server {:?}", client_cfg.remote_addr))?; Ok(tls_stream) } diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index 415152d..6710f04 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -44,9 +44,7 @@ pub async fn connect( ) -> anyhow::Result> { let mut pooled_cnx = match client_cfg.cnx_pool().get().await { Ok(tcp_stream) => tcp_stream, - Err(err) => Err(anyhow!( - "failed to get a connection to the server from the pool: {err:?}" - ))?, + Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}"))?, }; let mut req = Request::builder() @@ -80,12 +78,7 @@ pub async fn connect( let transport = pooled_cnx.deref_mut().take().unwrap(); let (ws, _) = fastwebsockets::handshake::client(&SpawnExecutor, req, transport) .await - .with_context(|| { - format!( - "failed to do websocket handshake with the server {:?}", - client_cfg.remote_addr - ) - })?; + .with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?; Ok(ws) } @@ -109,10 +102,7 @@ where // Forward local tx to websocket tx let ping_frequency = client_cfg.websocket_ping_frequency; - tokio::spawn( - super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency) - .instrument(Span::current()), - ); + tokio::spawn(super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency).instrument(Span::current())); // Forward websocket rx to local rx let _ = super::io::propagate_write(local_tx, ws_rx, close_rx).await; diff --git a/src/tunnel/io.rs b/src/tunnel/io.rs index abf5bca..0caa459 100644 --- a/src/tunnel/io.rs +++ b/src/tunnel/io.rs @@ -1,12 +1,12 @@ use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite}; use futures_util::pin_mut; use hyper::upgrade::Upgraded; +use std::pin::Pin; use std::time::Duration; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; use tokio::select; use tokio::sync::oneshot; -use tokio::time::timeout; use tracing::log::debug; use tracing::{error, info, trace, warn}; @@ -20,7 +20,14 @@ pub(super) async fn propagate_read( info!("Closing local tx ==> websocket tx tunnel"); }); - let mut buffer = vec![0u8; 8 * 1024]; + static JUMBO_FRAME_SIZE: usize = 9 * 1024; // enough for a jumbo frame + let mut buffer = vec![0u8; JUMBO_FRAME_SIZE]; + + // We do our own pin_mut! to avoid shadowing timeout and be able to reset it, on next loop iteration + // We reuse the future to avoid creating a timer in the tight loop + let mut timeout_unpin = tokio::time::sleep(ping_frequency); + let mut timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) }; + pin_mut!(local_rx); loop { let read_len = select! { @@ -30,9 +37,12 @@ pub(super) async fn propagate_read( _ = close_tx.closed() => break, - _ = timeout(ping_frequency, futures_util::future::pending::<()>()) => { + _ = &mut timeout => { debug!("sending ping to keep websocket connection alive"); ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?; + + timeout_unpin = tokio::time::sleep(ping_frequency); + timeout = unsafe { Pin::new_unchecked(&mut timeout_unpin) }; continue; } }; @@ -41,32 +51,30 @@ pub(super) async fn propagate_read( Ok(0) => break, Ok(read_len) => read_len, Err(err) => { - warn!( - "error while reading incoming bytes from local tx tunnel {}", - err - ); + warn!("error while reading incoming bytes from local tx tunnel {}", err); break; } }; trace!("read {} bytes", read_len); - match ws_tx + if let Err(err) = ws_tx .write_frame(Frame::binary(Payload::BorrowedMut(&mut buffer[..read_len]))) .await { - Ok(_) => {} - Err(err) => { - warn!("error while writing to websocket tx tunnel {}", err); - break; - } + warn!("error while writing to websocket tx tunnel {}", err); + break; } + // If the buffer has been completely filled with previous read, Double it ! + // For the buffer to not be a bottleneck when the TCP window scale + // For udp, the buffer will never grows. if buffer.capacity() == read_len { buffer.clear(); buffer.resize(buffer.capacity() * 2, 0); } } + // Send normal close let _ = ws_tx.write_frame(Frame::close(1000, &[])).await; Ok(()) @@ -104,20 +112,15 @@ pub(super) async fn propagate_write( trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload); let ret = match msg.opcode { - OpCode::Continuation | OpCode::Text | OpCode::Binary => { - local_tx.write_all(msg.payload.as_ref()).await - } + OpCode::Continuation | OpCode::Text | OpCode::Binary => local_tx.write_all(msg.payload.as_ref()).await, OpCode::Close => break, OpCode::Ping => Ok(()), OpCode::Pong => Ok(()), }; - match ret { - Ok(_) => {} - Err(err) => { - error!("error while writing bytes to local for rx tunnel {}", err); - break; - } + if let Err(err) = ret { + error!("error while writing bytes to local for rx tunnel {}", err); + break; } } diff --git a/src/tunnel/mod.rs b/src/tunnel/mod.rs index 0709d1d..3b9b0a7 100644 --- a/src/tunnel/mod.rs +++ b/src/tunnel/mod.rs @@ -42,12 +42,8 @@ impl JwtTunnelConfig { } static JWT_SECRET: &[u8; 15] = b"champignonfrais"; -static JWT_KEY: Lazy<(Header, EncodingKey)> = Lazy::new(|| { - ( - Header::new(Algorithm::HS256), - EncodingKey::from_secret(JWT_SECRET), - ) -}); +static JWT_KEY: Lazy<(Header, EncodingKey)> = + Lazy::new(|| (Header::new(Algorithm::HS256), EncodingKey::from_secret(JWT_SECRET))); static JWT_DECODE: Lazy<(Validation, DecodingKey)> = Lazy::new(|| { let mut validation = Validation::new(Algorithm::HS256); @@ -61,11 +57,7 @@ pub enum TransportStream { } impl AsyncRead for TransportStream { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { match self.get_mut() { TransportStream::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf), TransportStream::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf), @@ -74,11 +66,7 @@ impl AsyncRead for TransportStream { } impl AsyncWrite for TransportStream { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { match self.get_mut() { TransportStream::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf), TransportStream::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf), diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index 6376f7f..8235d9e 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -46,15 +46,8 @@ async fn from_query( Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp)); if let Some(allowed_dests) = &server_config.restrict_to { let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp); - if allowed_dests - .iter() - .any(|dest| dest == &requested_dest) - .not() - { - warn!( - "Rejecting connection with not allowed destination: {}", - requested_dest - ); + if allowed_dests.iter().any(|dest| dest == &requested_dest).not() { + warn!("Rejecting connection with not allowed destination: {}", requested_dest); return Err(anyhow::anyhow!("Invalid upgrade request")); } } @@ -75,14 +68,9 @@ async fn from_query( LocalProtocol::Tcp { .. } => { let host = Host::parse(&jwt.claims.r)?; let port = jwt.claims.rp; - let (rx, tx) = tcp::connect( - &host, - port, - &server_config.socket_so_mark, - Duration::from_secs(10), - ) - .await? - .into_split(); + let (rx, tx) = tcp::connect(&host, port, &server_config.socket_so_mark, Duration::from_secs(10)) + .await? + .into_split(); Ok((jwt.claims.p, host, port, Box::pin(rx), Box::pin(tx))) } @@ -100,10 +88,7 @@ async fn server_upgrade( } if !req.uri().path().ends_with("/events") { - warn!( - "Rejecting connection with bad upgrade request: {}", - req.uri() - ); + warn!("Rejecting connection with bad upgrade request: {}", req.uri()); return Ok(http::Response::builder() .status(StatusCode::BAD_REQUEST) .body(Body::from("Invalid upgrade request")) @@ -118,10 +103,7 @@ async fn server_upgrade( || &path[min_len..max_len] != path_prefix.as_str() || !path[max_len..].starts_with('/') { - warn!( - "Rejecting connection with bad path prefix in upgrade request: {}", - req.uri() - ); + warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri()); return Ok(http::Response::builder() .status(StatusCode::BAD_REQUEST) .body(Body::from("Invalid upgrade request")) @@ -133,11 +115,7 @@ async fn server_upgrade( match from_query(&server_config, req.uri().query().unwrap_or_default()).await { Ok(ret) => ret, Err(err) => { - warn!( - "Rejecting connection with bad upgrade request: {} {}", - err, - req.uri() - ); + warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); return Ok(http::Response::builder() .status(StatusCode::BAD_REQUEST) .body(Body::from(format!("Invalid upgrade request: {:?}", err))) @@ -149,11 +127,7 @@ async fn server_upgrade( let (response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) { Ok(ret) => ret, Err(err) => { - warn!( - "Rejecting connection with bad upgrade request: {} {}", - err, - req.uri() - ); + warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri()); return Ok(http::Response::builder() .status(StatusCode::BAD_REQUEST) .body(Body::from(format!("Invalid upgrade request: {:?}", err))) @@ -171,14 +145,10 @@ async fn server_upgrade( } }; let (close_tx, close_rx) = oneshot::channel::<()>(); - let ping_frequency = server_config - .websocket_ping_frequency - .unwrap_or(Duration::MAX); + let ping_frequency = server_config.websocket_ping_frequency.unwrap_or(Duration::MAX); ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame); - tokio::task::spawn( - super::io::propagate_write(local_tx, ws_rx, close_rx).instrument(Span::current()), - ); + tokio::task::spawn(super::io::propagate_write(local_tx, ws_rx, close_rx).instrument(Span::current())); let _ = super::io::propagate_read(local_rx, ws_tx, close_tx, ping_frequency).await; } @@ -189,10 +159,7 @@ async fn server_upgrade( } pub async fn run_server(server_config: Arc) -> anyhow::Result<()> { - info!( - "Starting wstunnel server listening on {}", - server_config.bind - ); + info!("Starting wstunnel server listening on {}", server_config.bind); let config = server_config.clone(); let upgrade_fn = move |req: Request| server_upgrade(config.clone(), req); diff --git a/src/udp.rs b/src/udp.rs index 739e698..617cc62 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -1,18 +1,17 @@ use anyhow::Context; use futures_util::{stream, Stream}; -use parking_lot::{Mutex, RwLock}; +use parking_lot::RwLock; use pin_project::{pin_project, pinned_drop}; -use std::collections::hash_map::Entry; use std::collections::HashMap; use std::future::Future; use std::io; use std::io::{Error, ErrorKind}; use std::net::SocketAddr; -use std::ops::DerefMut; + use std::pin::{pin, Pin}; use std::sync::{Arc, Weak}; -use std::task::{ready, Poll, Waker}; +use std::task::{ready, Poll}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::UdpSocket; @@ -23,8 +22,7 @@ use tokio::time::Sleep; use tracing::{debug, error, info}; struct IoInner { - has_data_to_read: &'static Notify, - waker: Mutex>, + has_data_to_read: Notify, has_read_data: Notify, } struct UdpServer { @@ -43,6 +41,7 @@ impl UdpServer { cnx_timeout: timeout, } } + #[inline] fn clean_dead_keys(&mut self) { let nb_key_to_delete = self.keys_to_delete.read().len(); if nb_key_to_delete == 0 { @@ -52,16 +51,7 @@ impl UdpServer { debug!("Cleaning {} dead udp peers", nb_key_to_delete); let mut keys_to_delete = self.keys_to_delete.write(); for key in keys_to_delete.iter() { - let Some(peer) = self.peers.remove(key) else { - continue; - }; - - #[allow(mutable_transmutes)] - unsafe { - let _ = Box::from_raw(std::mem::transmute::<&Notify, &mut Notify>( - peer.has_data_to_read, - )); - } + self.peers.remove(key); } keys_to_delete.clear(); } @@ -90,7 +80,42 @@ impl PinnedDrop for UdpStream { keys_to_delete.write().push(self.peer); } - self.io.has_read_data.notify_one(); + // safety: we are dropping the notification as we extend its lifetime to 'static unsafely + // So it must be gone before we drop its parent. It should never happen but in case + let mut project = self.project(); + project.pending_notification.as_mut().set(None); + project.io.has_read_data.notify_one(); + } +} + +impl UdpStream { + fn new( + socket: Arc, + peer: SocketAddr, + deadline: Option, + keys_to_delete: Weak>>, + ) -> (Self, Arc) { + let has_data_to_read = Notify::new(); + let has_read_data = Notify::new(); + has_data_to_read.notify_one(); + let io = Arc::new(IoInner { + has_data_to_read, + has_read_data, + }); + let mut s = Self { + socket, + peer, + deadline, + has_been_notified: false, + pending_notification: None, + io: io.clone(), + keys_to_delete, + }; + + let pending_notification = unsafe { std::mem::transmute(s.io.has_data_to_read.notified()) }; + s.pending_notification = Some(pending_notification); + + (s, io) } } @@ -111,43 +136,28 @@ impl AsyncRead for UdpStream { } if let Some(notified) = project.pending_notification.as_mut().as_pin_mut() { - if !notified.poll(cx).is_ready() { - project.io.waker.lock().replace(cx.waker().clone()); - return Poll::Pending; - } + ready!(notified.poll(cx)); project.pending_notification.as_mut().set(None); } let _ = ready!(project.socket.poll_recv(cx, obuf)); - project - .pending_notification - .as_mut() - .set(Some(project.io.has_data_to_read.notified())); + let notified: Notified<'static> = unsafe { std::mem::transmute(project.io.has_data_to_read.notified()) }; + project.pending_notification.as_mut().set(Some(notified)); project.io.has_read_data.notify_one(); Poll::Ready(Ok(())) } } impl AsyncWrite for UdpStream { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> Poll> { + fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll> { self.socket.poll_send_to(cx, buf, self.peer) } - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { self.socket.poll_send_ready(cx) } - fn poll_shutdown( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll> { Poll::Ready(Ok(())) } } @@ -178,39 +188,22 @@ pub async fn run_server( } }; - match server.peers.entry(peer_addr) { - Entry::Occupied(mut peer) => { - let io = peer.get_mut(); + match server.peers.get(&peer_addr) { + Some(io) => { io.has_read_data.notified().await; io.has_data_to_read.notify_one(); - let waker = io.waker.lock().deref_mut().take(); - if let Some(waker) = waker { - waker.wake(); - } } - Entry::Vacant(peer) => { - let has_data_to_read: &'static Notify = Box::leak(Box::new(Notify::new())); - let pending_notification = has_data_to_read.notified(); - let has_read_data = Notify::new(); - has_data_to_read.notify_one(); - let io = Arc::new(IoInner { - has_data_to_read, - waker: Mutex::new(None), - has_read_data, - }); - peer.insert(io.clone()); - let udp_client = UdpStream { - socket: server.clone_socket(), - peer: peer_addr, - deadline: server + None => { + let (udp_client, io) = UdpStream::new( + server.clone_socket(), + peer_addr, + server .cnx_timeout .and_then(|timeout| tokio::time::Instant::now().checked_add(timeout)) .map(tokio::time::sleep_until), - keys_to_delete: Arc::downgrade(&server.keys_to_delete), - has_been_notified: false, - pending_notification: Some(pending_notification), - io, - }; + Arc::downgrade(&server.keys_to_delete), + ); + server.peers.insert(peer_addr, io); return Some((Ok(udp_client), (server))); } } @@ -231,11 +224,7 @@ impl MyUdpSocket { } impl AsyncRead for MyUdpSocket { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { + fn poll_read(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { unsafe { self.map_unchecked_mut(|x| &mut x.socket) } .poll_recv_from(cx, buf) .map(|x| x.map(|_| ())) @@ -243,25 +232,15 @@ impl AsyncRead for MyUdpSocket { } impl AsyncWrite for MyUdpSocket { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> Poll> { + fn poll_write(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll> { unsafe { self.map_unchecked_mut(|x| &mut x.socket) }.poll_send(cx, buf) } - fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn poll_shutdown( - self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut std::task::Context<'_>) -> Poll> { Poll::Ready(Ok(())) } } @@ -331,10 +310,7 @@ mod tests { assert!(client.send_to(b"aaaaa".as_ref(), server_addr).await.is_ok()); let client2 = UdpSocket::bind("[::1]:0").await.unwrap(); - assert!(client2 - .send_to(b"bbbbb".as_ref(), server_addr) - .await - .is_ok()); + assert!(client2.send_to(b"bbbbb".as_ref(), server_addr).await.is_ok()); // Should have a new connection let fut = timeout(Duration::from_millis(100), server.next()).await; @@ -360,10 +336,7 @@ mod tests { assert_eq!(&buf[..6], b"bbbbb\0"); assert!(client.send_to(b"ccccc".as_ref(), server_addr).await.is_ok()); - assert!(client2 - .send_to(b"ddddd".as_ref(), server_addr) - .await - .is_ok()); + assert!(client2.send_to(b"ddddd".as_ref(), server_addr).await.is_ok()); // Server need to be polled to feed the stream with need data let _ = timeout(Duration::from_millis(100), server.next()).await;