From b70d547370b125cce377ec25e9db7d00bacfdd7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=CE=A3rebe=20-=20Romain=20GERARD?= Date: Fri, 27 Oct 2023 09:15:15 +0200 Subject: [PATCH] use BytesMut instead of vec --- src/main.rs | 4 ++-- src/tcp.rs | 39 +++++++++++++++++++-------------------- src/tunnel/client.rs | 2 +- src/tunnel/io.rs | 25 +++++++++++++++---------- src/udp.rs | 14 ++++++-------- 5 files changed, 43 insertions(+), 41 deletions(-) diff --git a/src/main.rs b/src/main.rs index f02f165..3c198f1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -451,7 +451,7 @@ pub struct WsClientConfig { pub http_upgrade_path_prefix: String, pub http_upgrade_credentials: Option, pub http_headers: HashMap, - pub host_http_header: HeaderValue, + pub http_header_host: HeaderValue, pub timeout_connect: Duration, pub websocket_ping_frequency: Duration, pub websocket_mask_frame: bool, @@ -551,7 +551,7 @@ async fn main() { .into_iter() .filter(|(k, _)| k != HOST) .collect(), - host_http_header: host_header, + http_header_host: host_header, timeout_connect: Duration::from_secs(10), websocket_ping_frequency: args .websocket_ping_frequency_sec diff --git a/src/tcp.rs b/src/tcp.rs index ec055b1..b091935 100644 --- a/src/tcp.rs +++ b/src/tcp.rs @@ -2,6 +2,7 @@ use anyhow::{anyhow, Context}; use std::{io, vec}; use base64::Engine; +use bytes::BytesMut; use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -127,23 +128,17 @@ pub async fn connect_with_http_proxy( let connect_request = format!("CONNECT {host}:{port} HTTP/1.0\r\nHost: {host}:{port}\r\n{authorization}\r\n"); - socket - .write_all(connect_request.trim_start().as_bytes()) - .await?; + socket.write_all(connect_request.as_bytes()).await?; - let mut buf = [0u8; 8096]; - let mut needle = 0; + let mut buf = BytesMut::with_capacity(1024); loop { - let nb_bytes = tokio::time::timeout(connect_timeout, socket.read(&mut buf[needle..])).await; - let nb_bytes = match nb_bytes { - Ok(Ok(nb_bytes)) => { - if nb_bytes == 0 { - return Err(anyhow!( - "Cannot connect to http proxy. Proxy closed the connection without returning any response" )); - } else { - nb_bytes - } + let nb_bytes = tokio::time::timeout(connect_timeout, socket.read_buf(&mut buf)).await; + match nb_bytes { + Ok(Ok(0)) => { + return Err(anyhow!( + "Cannot connect to http proxy. Proxy closed the connection without returning any response")); } + Ok(Ok(_)) => {} Ok(Err(err)) => { return Err(anyhow!("Cannot connect to http proxy. {err}")); } @@ -154,20 +149,24 @@ pub async fn connect_with_http_proxy( } }; - needle += nb_bytes; - if buf[..needle].windows(4).any(|window| window == b"\r\n\r\n") { + static END_HTTP_RESPONSE: &[u8; 4] = b"\r\n\r\n"; // It is reversed from \r\n\r\n as we reverse scan the buffer + if buf.len() > 50 * 1024 + || buf + .windows(END_HTTP_RESPONSE.len()) + .any(|window| window == END_HTTP_RESPONSE) + { break; } } - let ok_response = b"HTTP/1.0 200"; + static OK_RESPONSE: &[u8; 12] = b"HTTP/1.0 200"; if !buf - .windows(ok_response.len()) - .any(|window| window == ok_response) + .windows(OK_RESPONSE.len()) + .any(|window| window == OK_RESPONSE) { return Err(anyhow!( "Cannot connect to http proxy. Proxy returned an invalid response: {}", - String::from_utf8_lossy(&buf[..needle]) + String::from_utf8_lossy(&buf) )); } diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index 37e66eb..95ac7cf 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -58,7 +58,7 @@ pub async fn connect( &client_cfg.http_upgrade_path_prefix, jsonwebtoken::encode(alg, &data, secret).unwrap_or_default(), )) - .header(HOST, &client_cfg.host_http_header) + .header(HOST, &client_cfg.http_header_host) .header(UPGRADE, "websocket") .header(CONNECTION, "upgrade") .header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key()) diff --git a/src/tunnel/io.rs b/src/tunnel/io.rs index edf5f7d..7397c87 100644 --- a/src/tunnel/io.rs +++ b/src/tunnel/io.rs @@ -1,6 +1,9 @@ + use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite}; use futures_util::pin_mut; use hyper::upgrade::Upgraded; + + use std::time::Duration; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; use tokio::select; @@ -22,21 +25,21 @@ pub(super) async fn propagate_read( let mut buffer = vec![0u8; 8 * 1024]; pin_mut!(local_rx); loop { - let read = select! { + let read_len = select! { biased; - read_len = local_rx.read(buffer.as_mut_slice()) => read_len, + read_len = local_rx.read(&mut buffer) => read_len, _ = close_tx.closed() => break, _ = timeout(ping_frequency, futures_util::future::pending::<()>()) => { debug!("sending ping to keep websocket connection alive"); - ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::Borrowed(&[]))).await?; + ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?; continue; } }; - let read_len = match read { + let read_len = match read_len { Ok(read_len) if read_len > 0 => read_len, Ok(_) => break, Err(err) => { @@ -50,7 +53,7 @@ pub(super) async fn propagate_read( trace!("read {} bytes", read_len); match ws_tx - .write_frame(Frame::binary(Payload::Borrowed(&buffer[..read_len]))) + .write_frame(Frame::binary(Payload::BorrowedMut(&mut buffer[..read_len]))) .await { Ok(_) => {} @@ -60,8 +63,10 @@ pub(super) async fn propagate_read( } } - if read_len == buffer.len() { - buffer.resize(read_len * 2, 0); + if buffer.capacity() == read_len { + buffer.clear(); + info!("capa: {} read:{}", buffer.capacity(), read_len); + buffer.resize(buffer.capacity() * 2, 0); } } @@ -85,14 +90,14 @@ pub(super) async fn propagate_write( pin_mut!(local_tx); loop { - let ret = select! { + let msg = select! { biased; - ret = ws_rx.read_frame(&mut x) => ret, + msg = ws_rx.read_frame(&mut x) => msg, _ = &mut close_rx => break, }; - let msg = match ret { + let msg = match msg { Ok(msg) => msg, Err(err) => { error!("error while reading from websocket rx {}", err); diff --git a/src/udp.rs b/src/udp.rs index 47741c7..86edc82 100644 --- a/src/udp.rs +++ b/src/udp.rs @@ -37,7 +37,6 @@ impl UdpServer { cnx_timeout: timeout, } } - fn clean_dead_keys(&mut self) { let nb_key_to_delete = self.keys_to_delete.read().unwrap().len(); if nb_key_to_delete == 0 { @@ -51,7 +50,6 @@ impl UdpServer { } keys_to_delete.clear(); } - fn clone_socket(&self) -> Arc { self.listener.clone() } @@ -80,7 +78,7 @@ impl AsyncRead for UdpStream { fn poll_read( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, - buf: &mut ReadBuf<'_>, + obuf: &mut ReadBuf<'_>, ) -> Poll> { let project = self.project(); if let Some(deadline) = project.deadline.as_pin_mut() { @@ -93,11 +91,11 @@ impl AsyncRead for UdpStream { } let mut guard = project.io.lock().unwrap(); - let (inner, waker) = guard.deref_mut(); - if inner.has_remaining() { - let max = inner.remaining().min(buf.remaining()); - buf.put_slice(&inner[..max]); - inner.advance(max); + let (ibuf, waker) = guard.deref_mut(); + if ibuf.has_remaining() { + let max = ibuf.remaining().min(obuf.remaining()); + obuf.put_slice(&ibuf[..max]); + ibuf.advance(max); Poll::Ready(Ok(())) } else { waker.replace(cx.waker().clone());