use BytesMut instead of vec

This commit is contained in:
Σrebe - Romain GERARD 2023-10-27 09:15:15 +02:00
parent f813d925d6
commit b70d547370
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
5 changed files with 43 additions and 41 deletions

View file

@ -451,7 +451,7 @@ pub struct WsClientConfig {
pub http_upgrade_path_prefix: String, pub http_upgrade_path_prefix: String,
pub http_upgrade_credentials: Option<HeaderValue>, pub http_upgrade_credentials: Option<HeaderValue>,
pub http_headers: HashMap<HeaderName, HeaderValue>, pub http_headers: HashMap<HeaderName, HeaderValue>,
pub host_http_header: HeaderValue, pub http_header_host: HeaderValue,
pub timeout_connect: Duration, pub timeout_connect: Duration,
pub websocket_ping_frequency: Duration, pub websocket_ping_frequency: Duration,
pub websocket_mask_frame: bool, pub websocket_mask_frame: bool,
@ -551,7 +551,7 @@ async fn main() {
.into_iter() .into_iter()
.filter(|(k, _)| k != HOST) .filter(|(k, _)| k != HOST)
.collect(), .collect(),
host_http_header: host_header, http_header_host: host_header,
timeout_connect: Duration::from_secs(10), timeout_connect: Duration::from_secs(10),
websocket_ping_frequency: args websocket_ping_frequency: args
.websocket_ping_frequency_sec .websocket_ping_frequency_sec

View file

@ -2,6 +2,7 @@ use anyhow::{anyhow, Context};
use std::{io, vec}; use std::{io, vec};
use base64::Engine; use base64::Engine;
use bytes::BytesMut;
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
@ -127,23 +128,17 @@ pub async fn connect_with_http_proxy(
let connect_request = let connect_request =
format!("CONNECT {host}:{port} HTTP/1.0\r\nHost: {host}:{port}\r\n{authorization}\r\n"); format!("CONNECT {host}:{port} HTTP/1.0\r\nHost: {host}:{port}\r\n{authorization}\r\n");
socket socket.write_all(connect_request.as_bytes()).await?;
.write_all(connect_request.trim_start().as_bytes())
.await?;
let mut buf = [0u8; 8096]; let mut buf = BytesMut::with_capacity(1024);
let mut needle = 0;
loop { loop {
let nb_bytes = tokio::time::timeout(connect_timeout, socket.read(&mut buf[needle..])).await; let nb_bytes = tokio::time::timeout(connect_timeout, socket.read_buf(&mut buf)).await;
let nb_bytes = match nb_bytes { match nb_bytes {
Ok(Ok(nb_bytes)) => { Ok(Ok(0)) => {
if nb_bytes == 0 {
return Err(anyhow!( return Err(anyhow!(
"Cannot connect to http proxy. Proxy closed the connection without returning any response")); "Cannot connect to http proxy. Proxy closed the connection without returning any response"));
} else {
nb_bytes
}
} }
Ok(Ok(_)) => {}
Ok(Err(err)) => { Ok(Err(err)) => {
return Err(anyhow!("Cannot connect to http proxy. {err}")); return Err(anyhow!("Cannot connect to http proxy. {err}"));
} }
@ -154,20 +149,24 @@ pub async fn connect_with_http_proxy(
} }
}; };
needle += nb_bytes; static END_HTTP_RESPONSE: &[u8; 4] = b"\r\n\r\n"; // It is reversed from \r\n\r\n as we reverse scan the buffer
if buf[..needle].windows(4).any(|window| window == b"\r\n\r\n") { if buf.len() > 50 * 1024
|| buf
.windows(END_HTTP_RESPONSE.len())
.any(|window| window == END_HTTP_RESPONSE)
{
break; break;
} }
} }
let ok_response = b"HTTP/1.0 200"; static OK_RESPONSE: &[u8; 12] = b"HTTP/1.0 200";
if !buf if !buf
.windows(ok_response.len()) .windows(OK_RESPONSE.len())
.any(|window| window == ok_response) .any(|window| window == OK_RESPONSE)
{ {
return Err(anyhow!( return Err(anyhow!(
"Cannot connect to http proxy. Proxy returned an invalid response: {}", "Cannot connect to http proxy. Proxy returned an invalid response: {}",
String::from_utf8_lossy(&buf[..needle]) String::from_utf8_lossy(&buf)
)); ));
} }

View file

@ -58,7 +58,7 @@ pub async fn connect(
&client_cfg.http_upgrade_path_prefix, &client_cfg.http_upgrade_path_prefix,
jsonwebtoken::encode(alg, &data, secret).unwrap_or_default(), jsonwebtoken::encode(alg, &data, secret).unwrap_or_default(),
)) ))
.header(HOST, &client_cfg.host_http_header) .header(HOST, &client_cfg.http_header_host)
.header(UPGRADE, "websocket") .header(UPGRADE, "websocket")
.header(CONNECTION, "upgrade") .header(CONNECTION, "upgrade")
.header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key()) .header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key())

View file

@ -1,6 +1,9 @@
use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite}; use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite};
use futures_util::pin_mut; use futures_util::pin_mut;
use hyper::upgrade::Upgraded; use hyper::upgrade::Upgraded;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::select; use tokio::select;
@ -22,21 +25,21 @@ pub(super) async fn propagate_read(
let mut buffer = vec![0u8; 8 * 1024]; let mut buffer = vec![0u8; 8 * 1024];
pin_mut!(local_rx); pin_mut!(local_rx);
loop { loop {
let read = select! { let read_len = select! {
biased; biased;
read_len = local_rx.read(buffer.as_mut_slice()) => read_len, read_len = local_rx.read(&mut buffer) => read_len,
_ = close_tx.closed() => break, _ = close_tx.closed() => break,
_ = timeout(ping_frequency, futures_util::future::pending::<()>()) => { _ = timeout(ping_frequency, futures_util::future::pending::<()>()) => {
debug!("sending ping to keep websocket connection alive"); debug!("sending ping to keep websocket connection alive");
ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::Borrowed(&[]))).await?; ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?;
continue; continue;
} }
}; };
let read_len = match read { let read_len = match read_len {
Ok(read_len) if read_len > 0 => read_len, Ok(read_len) if read_len > 0 => read_len,
Ok(_) => break, Ok(_) => break,
Err(err) => { Err(err) => {
@ -50,7 +53,7 @@ pub(super) async fn propagate_read(
trace!("read {} bytes", read_len); trace!("read {} bytes", read_len);
match ws_tx match ws_tx
.write_frame(Frame::binary(Payload::Borrowed(&buffer[..read_len]))) .write_frame(Frame::binary(Payload::BorrowedMut(&mut buffer[..read_len])))
.await .await
{ {
Ok(_) => {} Ok(_) => {}
@ -60,8 +63,10 @@ pub(super) async fn propagate_read(
} }
} }
if read_len == buffer.len() { if buffer.capacity() == read_len {
buffer.resize(read_len * 2, 0); buffer.clear();
info!("capa: {} read:{}", buffer.capacity(), read_len);
buffer.resize(buffer.capacity() * 2, 0);
} }
} }
@ -85,14 +90,14 @@ pub(super) async fn propagate_write(
pin_mut!(local_tx); pin_mut!(local_tx);
loop { loop {
let ret = select! { let msg = select! {
biased; biased;
ret = ws_rx.read_frame(&mut x) => ret, msg = ws_rx.read_frame(&mut x) => msg,
_ = &mut close_rx => break, _ = &mut close_rx => break,
}; };
let msg = match ret { let msg = match msg {
Ok(msg) => msg, Ok(msg) => msg,
Err(err) => { Err(err) => {
error!("error while reading from websocket rx {}", err); error!("error while reading from websocket rx {}", err);

View file

@ -37,7 +37,6 @@ impl UdpServer {
cnx_timeout: timeout, cnx_timeout: timeout,
} }
} }
fn clean_dead_keys(&mut self) { fn clean_dead_keys(&mut self) {
let nb_key_to_delete = self.keys_to_delete.read().unwrap().len(); let nb_key_to_delete = self.keys_to_delete.read().unwrap().len();
if nb_key_to_delete == 0 { if nb_key_to_delete == 0 {
@ -51,7 +50,6 @@ impl UdpServer {
} }
keys_to_delete.clear(); keys_to_delete.clear();
} }
fn clone_socket(&self) -> Arc<UdpSocket> { fn clone_socket(&self) -> Arc<UdpSocket> {
self.listener.clone() self.listener.clone()
} }
@ -80,7 +78,7 @@ impl AsyncRead for UdpStream {
fn poll_read( fn poll_read(
self: Pin<&mut Self>, self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
buf: &mut ReadBuf<'_>, obuf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> { ) -> Poll<io::Result<()>> {
let project = self.project(); let project = self.project();
if let Some(deadline) = project.deadline.as_pin_mut() { if let Some(deadline) = project.deadline.as_pin_mut() {
@ -93,11 +91,11 @@ impl AsyncRead for UdpStream {
} }
let mut guard = project.io.lock().unwrap(); let mut guard = project.io.lock().unwrap();
let (inner, waker) = guard.deref_mut(); let (ibuf, waker) = guard.deref_mut();
if inner.has_remaining() { if ibuf.has_remaining() {
let max = inner.remaining().min(buf.remaining()); let max = ibuf.remaining().min(obuf.remaining());
buf.put_slice(&inner[..max]); obuf.put_slice(&ibuf[..max]);
inner.advance(max); ibuf.advance(max);
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} else { } else {
waker.replace(cx.waker().clone()); waker.replace(cx.waker().clone());