diff --git a/src/tunnel/client.rs b/src/tunnel/client.rs index dc4ba7b..fee2ab4 100644 --- a/src/tunnel/client.rs +++ b/src/tunnel/client.rs @@ -2,6 +2,7 @@ use super::{to_host_port, JwtTunnelConfig, JWT_KEY}; use crate::{LocalToRemote, WsClientConfig}; use anyhow::{anyhow, Context}; +use base64::Engine; use fastwebsockets::WebSocket; use futures_util::pin_mut; use hyper::header::{AUTHORIZATION, COOKIE, SEC_WEBSOCKET_VERSION, UPGRADE}; @@ -186,7 +187,8 @@ where .and_then(|h| { h.to_str() .ok() - .and_then(|s| Url::parse(s).ok()) + .and_then(|s| base64::engine::general_purpose::STANDARD.decode(s).ok()) + .and_then(|s| Url::parse(&String::from_utf8_lossy(&s)).ok()) .and_then(|url| match (url.host(), url.port()) { (Some(h), Some(p)) => Some((h.to_owned(), p)), _ => None, diff --git a/src/tunnel/io.rs b/src/tunnel/io.rs index 6a3ab16..088e8ec 100644 --- a/src/tunnel/io.rs +++ b/src/tunnel/io.rs @@ -1,5 +1,5 @@ use fastwebsockets::{Frame, OpCode, Payload, WebSocketError, WebSocketRead, WebSocketWrite}; -use futures_util::pin_mut; +use futures_util::{pin_mut, FutureExt}; use hyper::upgrade::Upgraded; use std::time::Duration; @@ -28,8 +28,10 @@ pub(super) async fn propagate_read( let frequency = ping_frequency.unwrap_or(Duration::from_secs(3600 * 24)); let start_at = Instant::now().checked_add(frequency).unwrap_or(Instant::now()); let timeout = tokio::time::interval_at(start_at, frequency); - pin_mut!(timeout); + let should_close = close_tx.closed().fuse(); + pin_mut!(timeout); + pin_mut!(should_close); pin_mut!(local_rx); loop { let read_len = select! { @@ -37,7 +39,7 @@ pub(super) async fn propagate_read( read_len = local_rx.read(&mut buffer) => read_len, - _ = close_tx.closed() => break, + _ = &mut should_close => break, _ = timeout.tick(), if ping_frequency.is_some() => { debug!("sending ping to keep websocket connection alive"); diff --git a/src/tunnel/server.rs b/src/tunnel/server.rs index cae18b4..653356a 100644 --- a/src/tunnel/server.rs +++ b/src/tunnel/server.rs @@ -1,5 +1,6 @@ use ahash::{HashMap, HashMapExt}; use anyhow::anyhow; +use base64::Engine; use futures_util::{pin_mut, Stream, StreamExt}; use std::cmp::min; use std::fmt::Debug; @@ -264,9 +265,14 @@ async fn server_upgrade( .instrument(Span::current()), ); - response - .headers_mut() - .insert(COOKIE, HeaderValue::from_str(&format!("fake://{}:{}", dest, port)).unwrap()); + if protocol == LocalProtocol::ReverseSocks5 { + response.headers_mut().insert( + COOKIE, + HeaderValue::from_str( + &base64::engine::general_purpose::STANDARD.encode(format!("fake://{}:{}", dest, port)), + )?, + ); + } Ok(response) }