use BytesMut instead of vec

This commit is contained in:
Σrebe - Romain GERARD 2023-10-27 09:15:15 +02:00
parent f813d925d6
commit b70d547370
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
5 changed files with 43 additions and 41 deletions

View file

@ -451,7 +451,7 @@ pub struct WsClientConfig {
pub http_upgrade_path_prefix: String,
pub http_upgrade_credentials: Option<HeaderValue>,
pub http_headers: HashMap<HeaderName, HeaderValue>,
pub host_http_header: HeaderValue,
pub http_header_host: HeaderValue,
pub timeout_connect: Duration,
pub websocket_ping_frequency: Duration,
pub websocket_mask_frame: bool,
@ -551,7 +551,7 @@ async fn main() {
.into_iter()
.filter(|(k, _)| k != HOST)
.collect(),
host_http_header: host_header,
http_header_host: host_header,
timeout_connect: Duration::from_secs(10),
websocket_ping_frequency: args
.websocket_ping_frequency_sec

View file

@ -2,6 +2,7 @@ use anyhow::{anyhow, Context};
use std::{io, vec};
use base64::Engine;
use bytes::BytesMut;
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
@ -127,23 +128,17 @@ pub async fn connect_with_http_proxy(
let connect_request =
format!("CONNECT {host}:{port} HTTP/1.0\r\nHost: {host}:{port}\r\n{authorization}\r\n");
socket
.write_all(connect_request.trim_start().as_bytes())
.await?;
socket.write_all(connect_request.as_bytes()).await?;
let mut buf = [0u8; 8096];
let mut needle = 0;
let mut buf = BytesMut::with_capacity(1024);
loop {
let nb_bytes = tokio::time::timeout(connect_timeout, socket.read(&mut buf[needle..])).await;
let nb_bytes = match nb_bytes {
Ok(Ok(nb_bytes)) => {
if nb_bytes == 0 {
return Err(anyhow!(
"Cannot connect to http proxy. Proxy closed the connection without returning any response" ));
} else {
nb_bytes
}
let nb_bytes = tokio::time::timeout(connect_timeout, socket.read_buf(&mut buf)).await;
match nb_bytes {
Ok(Ok(0)) => {
return Err(anyhow!(
"Cannot connect to http proxy. Proxy closed the connection without returning any response"));
}
Ok(Ok(_)) => {}
Ok(Err(err)) => {
return Err(anyhow!("Cannot connect to http proxy. {err}"));
}
@ -154,20 +149,24 @@ pub async fn connect_with_http_proxy(
}
};
needle += nb_bytes;
if buf[..needle].windows(4).any(|window| window == b"\r\n\r\n") {
static END_HTTP_RESPONSE: &[u8; 4] = b"\r\n\r\n"; // It is reversed from \r\n\r\n as we reverse scan the buffer
if buf.len() > 50 * 1024
|| buf
.windows(END_HTTP_RESPONSE.len())
.any(|window| window == END_HTTP_RESPONSE)
{
break;
}
}
let ok_response = b"HTTP/1.0 200";
static OK_RESPONSE: &[u8; 12] = b"HTTP/1.0 200";
if !buf
.windows(ok_response.len())
.any(|window| window == ok_response)
.windows(OK_RESPONSE.len())
.any(|window| window == OK_RESPONSE)
{
return Err(anyhow!(
"Cannot connect to http proxy. Proxy returned an invalid response: {}",
String::from_utf8_lossy(&buf[..needle])
String::from_utf8_lossy(&buf)
));
}

View file

@ -58,7 +58,7 @@ pub async fn connect(
&client_cfg.http_upgrade_path_prefix,
jsonwebtoken::encode(alg, &data, secret).unwrap_or_default(),
))
.header(HOST, &client_cfg.host_http_header)
.header(HOST, &client_cfg.http_header_host)
.header(UPGRADE, "websocket")
.header(CONNECTION, "upgrade")
.header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key())

View file

@ -1,6 +1,9 @@
use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite};
use futures_util::pin_mut;
use hyper::upgrade::Upgraded;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::select;
@ -22,21 +25,21 @@ pub(super) async fn propagate_read(
let mut buffer = vec![0u8; 8 * 1024];
pin_mut!(local_rx);
loop {
let read = select! {
let read_len = select! {
biased;
read_len = local_rx.read(buffer.as_mut_slice()) => read_len,
read_len = local_rx.read(&mut buffer) => read_len,
_ = close_tx.closed() => break,
_ = timeout(ping_frequency, futures_util::future::pending::<()>()) => {
debug!("sending ping to keep websocket connection alive");
ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::Borrowed(&[]))).await?;
ws_tx.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut []))).await?;
continue;
}
};
let read_len = match read {
let read_len = match read_len {
Ok(read_len) if read_len > 0 => read_len,
Ok(_) => break,
Err(err) => {
@ -50,7 +53,7 @@ pub(super) async fn propagate_read(
trace!("read {} bytes", read_len);
match ws_tx
.write_frame(Frame::binary(Payload::Borrowed(&buffer[..read_len])))
.write_frame(Frame::binary(Payload::BorrowedMut(&mut buffer[..read_len])))
.await
{
Ok(_) => {}
@ -60,8 +63,10 @@ pub(super) async fn propagate_read(
}
}
if read_len == buffer.len() {
buffer.resize(read_len * 2, 0);
if buffer.capacity() == read_len {
buffer.clear();
info!("capa: {} read:{}", buffer.capacity(), read_len);
buffer.resize(buffer.capacity() * 2, 0);
}
}
@ -85,14 +90,14 @@ pub(super) async fn propagate_write(
pin_mut!(local_tx);
loop {
let ret = select! {
let msg = select! {
biased;
ret = ws_rx.read_frame(&mut x) => ret,
msg = ws_rx.read_frame(&mut x) => msg,
_ = &mut close_rx => break,
};
let msg = match ret {
let msg = match msg {
Ok(msg) => msg,
Err(err) => {
error!("error while reading from websocket rx {}", err);

View file

@ -37,7 +37,6 @@ impl UdpServer {
cnx_timeout: timeout,
}
}
fn clean_dead_keys(&mut self) {
let nb_key_to_delete = self.keys_to_delete.read().unwrap().len();
if nb_key_to_delete == 0 {
@ -51,7 +50,6 @@ impl UdpServer {
}
keys_to_delete.clear();
}
fn clone_socket(&self) -> Arc<UdpSocket> {
self.listener.clone()
}
@ -80,7 +78,7 @@ impl AsyncRead for UdpStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut ReadBuf<'_>,
obuf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let project = self.project();
if let Some(deadline) = project.deadline.as_pin_mut() {
@ -93,11 +91,11 @@ impl AsyncRead for UdpStream {
}
let mut guard = project.io.lock().unwrap();
let (inner, waker) = guard.deref_mut();
if inner.has_remaining() {
let max = inner.remaining().min(buf.remaining());
buf.put_slice(&inner[..max]);
inner.advance(max);
let (ibuf, waker) = guard.deref_mut();
if ibuf.has_remaining() {
let max = ibuf.remaining().min(obuf.remaining());
obuf.put_slice(&ibuf[..max]);
ibuf.advance(max);
Poll::Ready(Ok(()))
} else {
waker.replace(cx.waker().clone());