Add suport for http2 as transport for tunnel
This commit is contained in:
parent
cf3500dffb
commit
459a0667b1
9 changed files with 606 additions and 168 deletions
22
Cargo.lock
generated
22
Cargo.lock
generated
|
@ -613,6 +613,25 @@ dependencies = [
|
|||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "991910e35c615d8cab86b5ab04be67e6ad24d2bf5f4f11fdbbed26da999bbeab"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fnv",
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
"futures-util",
|
||||
"http 1.0.0",
|
||||
"indexmap",
|
||||
"slab",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.14.3"
|
||||
|
@ -651,7 +670,7 @@ dependencies = [
|
|||
"futures-channel",
|
||||
"futures-io",
|
||||
"futures-util",
|
||||
"h2",
|
||||
"h2 0.3.22",
|
||||
"http 0.2.11",
|
||||
"idna 0.4.0",
|
||||
"ipnet",
|
||||
|
@ -776,6 +795,7 @@ dependencies = [
|
|||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"h2 0.4.1",
|
||||
"http 1.0.0",
|
||||
"http-body",
|
||||
"httparse",
|
||||
|
|
|
@ -20,8 +20,8 @@ futures-util = { version = "0.3.30" }
|
|||
hickory-resolver = { version = "0.24.0", features = ["tokio", "dns-over-https-rustls", "dns-over-rustls"] }
|
||||
ppp = { version = "2.2.0", features = [] }
|
||||
|
||||
hyper = { version = "1.1.0", features = ["client", "http1"] }
|
||||
hyper-util = { version = "0.1.2", features = ["tokio"] }
|
||||
hyper = { version = "1.1.0", features = ["client", "http1", "http2"] }
|
||||
hyper-util = { version = "0.1.2", features = ["tokio", "server", "server-auto"] }
|
||||
http-body-util = { version = "0.1.0" }
|
||||
jsonwebtoken = { version = "9.2.0", default-features = false }
|
||||
log = "0.4.20"
|
||||
|
|
15
src/main.rs
15
src/main.rs
|
@ -72,7 +72,7 @@ struct Wstunnel {
|
|||
nb_worker_threads: Option<u32>,
|
||||
|
||||
/// Control the log verbosity. i.e: TRACE, DEBUG, INFO, WARN, ERROR, OFF
|
||||
/// for more details: https://docs.rs/env_logger/0.10.1/env_logger/#enabling-logging
|
||||
/// for more details: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#example-syntax
|
||||
#[arg(
|
||||
long,
|
||||
global = true,
|
||||
|
@ -81,7 +81,7 @@ struct Wstunnel {
|
|||
env = "RUST_LOG",
|
||||
default_value = "INFO"
|
||||
)]
|
||||
log_lvl: Directive,
|
||||
log_lvl: String,
|
||||
}
|
||||
|
||||
#[derive(clap::Subcommand, Debug)]
|
||||
|
@ -645,13 +645,14 @@ async fn main() {
|
|||
.count()
|
||||
> 0 => {}
|
||||
_ => {
|
||||
let mut env_filter = EnvFilter::builder().parse(&args.log_lvl).expect("Invalid log level");
|
||||
if !(args.log_lvl.contains("h2::") || args.log_lvl.contains("h2=")) {
|
||||
env_filter =
|
||||
env_filter.add_directive(Directive::from_str("h2::codec=off").expect("Invalid log directive"));
|
||||
}
|
||||
tracing_subscriber::fmt()
|
||||
.with_ansi(args.no_color.is_none())
|
||||
.with_env_filter(
|
||||
EnvFilter::builder()
|
||||
.with_default_directive(args.log_lvl)
|
||||
.from_env_lossy(),
|
||||
)
|
||||
.with_env_filter(env_filter)
|
||||
.init();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use super::{JwtTunnelConfig, RemoteAddr, JWT_DECODE};
|
||||
use super::{JwtTunnelConfig, RemoteAddr, TransportScheme, JWT_DECODE};
|
||||
use crate::tunnel::transport::{TunnelReader, TunnelWriter};
|
||||
use crate::{tunnel, WsClientConfig};
|
||||
use futures_util::pin_mut;
|
||||
use hyper::header::COOKIE;
|
||||
|
@ -13,52 +14,6 @@ use tracing::{error, span, Instrument, Level, Span};
|
|||
use url::Host;
|
||||
use uuid::Uuid;
|
||||
|
||||
//async fn connect_http2(
|
||||
// request_id: Uuid,
|
||||
// client_cfg: &WsClientConfig,
|
||||
// dest_addr: &RemoteAddr,
|
||||
//) -> anyhow::Result<BodyStream<Incoming>> {
|
||||
// let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
|
||||
// Ok(cnx) => Ok(cnx),
|
||||
// Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
|
||||
// }?;
|
||||
//
|
||||
// let mut req = Request::builder()
|
||||
// .method("GET")
|
||||
// .uri(format!("/{}/events", &client_cfg.http_upgrade_path_prefix))
|
||||
// .header(HOST, &client_cfg.http_header_host)
|
||||
// .header(COOKIE, tunnel_to_jwt_token(request_id, dest_addr))
|
||||
// .version(hyper::Version::HTTP_2);
|
||||
//
|
||||
// for (k, v) in &client_cfg.http_headers {
|
||||
// req = req.header(k, v);
|
||||
// }
|
||||
// if let Some(auth) = &client_cfg.http_upgrade_credentials {
|
||||
// req = req.header(AUTHORIZATION, auth);
|
||||
// }
|
||||
//
|
||||
// let x: Vec<u8> = vec![];
|
||||
// //let bosy = StreamBody::new(stream::iter(vec![anyhow::Result::Ok(hyper::body::Frame::data(x.as_slice()))]));
|
||||
// let req = req.body(Empty::<Bytes>::new()).with_context(|| {
|
||||
// format!(
|
||||
// "failed to build HTTP request to contact the server {:?}",
|
||||
// client_cfg.remote_addr
|
||||
// )
|
||||
// })?;
|
||||
// debug!("with HTTP upgrade request {:?}", req);
|
||||
// let transport = pooled_cnx.deref_mut().take().unwrap();
|
||||
// let (mut request_sender, cnx) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()).handshake(TokioIo::new(transport)).await
|
||||
// .with_context(|| format!("failed to do http2 handshake with the server {:?}", client_cfg.remote_addr))?;
|
||||
// tokio::spawn(cnx);
|
||||
//
|
||||
// let response = request_sender.send_request(req)
|
||||
// .await
|
||||
// .with_context(|| format!("failed to send http2 request with the server {:?}", client_cfg.remote_addr))?;
|
||||
//
|
||||
// // TODO: verify response is ok
|
||||
// Ok(BodyStream::new(response.into_body()))
|
||||
//}
|
||||
|
||||
async fn connect_to_server<R, W>(
|
||||
request_id: Uuid,
|
||||
client_cfg: &WsClientConfig,
|
||||
|
@ -69,7 +24,20 @@ where
|
|||
R: AsyncRead + Send + 'static,
|
||||
W: AsyncWrite + Send + 'static,
|
||||
{
|
||||
let ((ws_rx, ws_tx), _) = tunnel::transport::websocket::connect(request_id, client_cfg, remote_cfg).await?;
|
||||
// Connect to server with the correct protocol
|
||||
let (ws_rx, ws_tx) = match client_cfg.remote_addr.scheme() {
|
||||
TransportScheme::Ws | TransportScheme::Wss => {
|
||||
tunnel::transport::websocket::connect(request_id, client_cfg, remote_cfg)
|
||||
.await
|
||||
.map(|(r, w, _response)| (TunnelReader::Websocket(r), TunnelWriter::Websocket(w)))?
|
||||
}
|
||||
TransportScheme::Http | TransportScheme::Https => {
|
||||
tunnel::transport::http2::connect(request_id, client_cfg, remote_cfg)
|
||||
.await
|
||||
.map(|(r, w, _response)| (TunnelReader::Http2(r), TunnelWriter::Http2(w)))?
|
||||
}
|
||||
};
|
||||
|
||||
let (local_rx, local_tx) = duplex_stream;
|
||||
let (close_tx, close_rx) = oneshot::channel::<()>();
|
||||
|
||||
|
@ -117,7 +85,7 @@ where
|
|||
}
|
||||
|
||||
pub async fn run_reverse_tunnel<F, Fut, T>(
|
||||
client_config: Arc<WsClientConfig>,
|
||||
client_cfg: Arc<WsClientConfig>,
|
||||
remote_addr: RemoteAddr,
|
||||
connect_to_dest: F,
|
||||
) -> anyhow::Result<()>
|
||||
|
@ -127,7 +95,7 @@ where
|
|||
T: AsyncRead + AsyncWrite + Send + 'static,
|
||||
{
|
||||
loop {
|
||||
let client_config = client_config.clone();
|
||||
let client_config = client_cfg.clone();
|
||||
let request_id = Uuid::now_v7();
|
||||
let span = span!(
|
||||
Level::INFO,
|
||||
|
@ -136,16 +104,25 @@ where
|
|||
remote = format!("{}:{}", remote_addr.host, remote_addr.port)
|
||||
);
|
||||
let _span = span.enter();
|
||||
|
||||
// Correctly configure tunnel cfg
|
||||
let ((ws_rx, ws_tx), response) =
|
||||
tunnel::transport::websocket::connect(request_id, &client_config, &remote_addr)
|
||||
let (ws_rx, ws_tx, response) = match client_cfg.remote_addr.scheme() {
|
||||
TransportScheme::Ws | TransportScheme::Wss => {
|
||||
tunnel::transport::websocket::connect(request_id, &client_cfg, &remote_addr)
|
||||
.instrument(span.clone())
|
||||
.await?;
|
||||
.await
|
||||
.map(|(r, w, response)| (TunnelReader::Websocket(r), TunnelWriter::Websocket(w), response))?
|
||||
}
|
||||
TransportScheme::Http | TransportScheme::Https => {
|
||||
tunnel::transport::http2::connect(request_id, &client_cfg, &remote_addr)
|
||||
.instrument(span.clone())
|
||||
.await
|
||||
.map(|(r, w, response)| (TunnelReader::Http2(r), TunnelWriter::Http2(w), response))?
|
||||
}
|
||||
};
|
||||
|
||||
// Connect to endpoint
|
||||
let remote = response
|
||||
.headers()
|
||||
.headers
|
||||
.get(COOKIE)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.and_then(|h| {
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
use ahash::{HashMap, HashMapExt};
|
||||
use anyhow::anyhow;
|
||||
use bytes::Bytes;
|
||||
use futures_util::{pin_mut, FutureExt, Stream, StreamExt};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyStream, Either, StreamBody};
|
||||
use std::cmp::min;
|
||||
use std::fmt::Debug;
|
||||
use std::future::Future;
|
||||
|
@ -12,24 +15,28 @@ use std::time::Duration;
|
|||
|
||||
use super::{tunnel_to_jwt_token, JwtTunnelConfig, RemoteAddr, JWT_DECODE, JWT_HEADER_PREFIX};
|
||||
use crate::{socks5, tcp, tls, udp, LocalProtocol, TlsServerConfig, WsServerConfig};
|
||||
use hyper::body::Incoming;
|
||||
use hyper::body::{Frame, Incoming};
|
||||
use hyper::header::{COOKIE, SEC_WEBSOCKET_PROTOCOL};
|
||||
use hyper::http::HeaderValue;
|
||||
use hyper::server::conn::http1;
|
||||
use hyper::server::conn::{http1, http2};
|
||||
use hyper::service::service_fn;
|
||||
use hyper::{http, Request, Response, StatusCode};
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use jsonwebtoken::TokenData;
|
||||
use once_cell::sync::Lazy;
|
||||
use parking_lot::Mutex;
|
||||
|
||||
use crate::socks5::Socks5Stream;
|
||||
use crate::tunnel::tls_reloader::TlsReloader;
|
||||
use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite};
|
||||
use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite};
|
||||
use crate::udp::UdpStream;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::select;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tracing::{error, info, span, warn, Instrument, Level, Span};
|
||||
use url::Host;
|
||||
use uuid::Uuid;
|
||||
|
@ -284,6 +291,7 @@ fn extract_tunnel_info(req: &Request<Incoming>) -> Result<TokenData<JwtTunnelCon
|
|||
.and_then(|header| header.to_str().ok())
|
||||
.and_then(|header| header.split_once(JWT_HEADER_PREFIX))
|
||||
.map(|(_prefix, jwt)| jwt)
|
||||
.or_else(|| req.headers().get(COOKIE).and_then(|header| header.to_str().ok()))
|
||||
.unwrap_or_default();
|
||||
|
||||
let (validation, decode_key) = JWT_DECODE.deref();
|
||||
|
@ -327,7 +335,7 @@ fn validate_destination(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn server_upgrade(
|
||||
async fn ws_server_upgrade(
|
||||
server_config: Arc<WsServerConfig>,
|
||||
mut client_addr: SocketAddr,
|
||||
mut req: Request<Incoming>,
|
||||
|
@ -404,10 +412,17 @@ async fn server_upgrade(
|
|||
ws_tx.set_auto_apply_mask(server_config.websocket_mask_frame);
|
||||
|
||||
tokio::task::spawn(
|
||||
super::transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).instrument(Span::current()),
|
||||
super::transport::io::propagate_remote_to_local(local_tx, WebsocketTunnelRead::new(ws_rx), close_rx)
|
||||
.instrument(Span::current()),
|
||||
);
|
||||
|
||||
let _ = super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, None).await;
|
||||
let _ = super::transport::io::propagate_local_to_remote(
|
||||
local_rx,
|
||||
WebsocketTunnelWrite::new(ws_tx),
|
||||
close_tx,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
.instrument(Span::current()),
|
||||
);
|
||||
|
@ -429,6 +444,92 @@ async fn server_upgrade(
|
|||
Response::from_parts(response.into_parts().0, "".to_string())
|
||||
}
|
||||
|
||||
async fn http_server_upgrade(
|
||||
server_config: Arc<WsServerConfig>,
|
||||
mut client_addr: SocketAddr,
|
||||
req: Request<Incoming>,
|
||||
) -> Response<Either<String, BoxBody<Bytes, anyhow::Error>>> {
|
||||
match extract_x_forwarded_for(&req) {
|
||||
Ok(Some((x_forward_for, x_forward_for_str))) => {
|
||||
info!("Request X-Forwarded-For: {:?}", x_forward_for);
|
||||
Span::current().record("forwarded_for", x_forward_for_str);
|
||||
client_addr.set_ip(x_forward_for);
|
||||
}
|
||||
Ok(_) => {}
|
||||
Err(err) => return err.map(Either::Left),
|
||||
};
|
||||
|
||||
if let Err(err) = validate_url(&req, &server_config.restrict_http_upgrade_path_prefix) {
|
||||
return err.map(Either::Left);
|
||||
}
|
||||
|
||||
let jwt = match extract_tunnel_info(&req) {
|
||||
Ok(jwt) => jwt,
|
||||
Err(err) => return err.map(Either::Left),
|
||||
};
|
||||
|
||||
Span::current().record("id", &jwt.claims.id);
|
||||
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
|
||||
|
||||
if let Err(err) = validate_destination(&req, &jwt, &server_config.restrict_to) {
|
||||
return err.map(Either::Left);
|
||||
}
|
||||
|
||||
let req_protocol = jwt.claims.p.clone();
|
||||
let tunnel = match run_tunnel(&server_config, jwt, client_addr).await {
|
||||
Ok(ret) => ret,
|
||||
Err(err) => {
|
||||
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
||||
return http::Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(Either::Left("Invalid upgrade request".to_string()))
|
||||
.unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
let (remote_addr, local_rx, local_tx) = tunnel;
|
||||
info!("connected to {:?} {}:{}", req_protocol, remote_addr.host, remote_addr.port);
|
||||
|
||||
let ws_rx = BodyStream::new(req.into_body());
|
||||
let (ws_tx, rx) = mpsc::channel::<Bytes>(1024);
|
||||
let body = BoxBody::new(StreamBody::new(
|
||||
ReceiverStream::new(rx).map(|s| -> anyhow::Result<Frame<Bytes>> { Ok(Frame::data(s)) }),
|
||||
));
|
||||
|
||||
let mut response = Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.body(Either::Right(body))
|
||||
.expect("bug: failed to build response");
|
||||
|
||||
tokio::spawn(
|
||||
async move {
|
||||
let (close_tx, close_rx) = oneshot::channel::<()>();
|
||||
tokio::task::spawn(
|
||||
super::transport::io::propagate_remote_to_local(local_tx, Http2TunnelRead::new(ws_rx), close_rx)
|
||||
.instrument(Span::current()),
|
||||
);
|
||||
|
||||
let _ =
|
||||
super::transport::io::propagate_local_to_remote(local_rx, Http2TunnelWrite::new(ws_tx), close_tx, None)
|
||||
.await;
|
||||
}
|
||||
.instrument(Span::current()),
|
||||
);
|
||||
|
||||
if req_protocol == LocalProtocol::ReverseSocks5 {
|
||||
let Ok(header_val) = HeaderValue::from_str(&tunnel_to_jwt_token(Uuid::from_u128(0), &remote_addr)) else {
|
||||
error!("Bad header value for reverse socks5: {} {}", remote_addr.host, remote_addr.port);
|
||||
return http::Response::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(Either::Left("Invalid upgrade request".to_string()))
|
||||
.unwrap();
|
||||
};
|
||||
response.headers_mut().insert(COOKIE, header_val);
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
struct TlsContext<'a> {
|
||||
tls_acceptor: Arc<TlsAcceptor>,
|
||||
tls_reloader: TlsReloader,
|
||||
|
@ -452,10 +553,32 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
|
|||
info!("Starting wstunnel server listening on {}", server_config.bind);
|
||||
|
||||
// setup upgrade request handler
|
||||
// FIXME: Avoid double clone of the arc for each request
|
||||
let mk_upgrade_fn = |server_config: Arc<WsServerConfig>, client_addr: SocketAddr| {
|
||||
let mk_websocket_upgrade_fn = |server_config: Arc<WsServerConfig>, client_addr: SocketAddr| {
|
||||
move |req: Request<Incoming>| {
|
||||
server_upgrade(server_config.clone(), client_addr, req).map::<anyhow::Result<_>, _>(Ok)
|
||||
ws_server_upgrade(server_config.clone(), client_addr, req).map::<anyhow::Result<_>, _>(Ok)
|
||||
}
|
||||
};
|
||||
|
||||
let mk_http_upgrade_fn = |server_config: Arc<WsServerConfig>, client_addr: SocketAddr| {
|
||||
move |req: Request<Incoming>| {
|
||||
http_server_upgrade(server_config.clone(), client_addr, req).map::<anyhow::Result<_>, _>(Ok)
|
||||
}
|
||||
};
|
||||
|
||||
let mk_auto_upgrade_fn = |server_config: Arc<WsServerConfig>, client_addr: SocketAddr| {
|
||||
move |req: Request<Incoming>| {
|
||||
let server_config = server_config.clone();
|
||||
async move {
|
||||
if !fastwebsockets::upgrade::is_upgrade_request(&req) {
|
||||
http_server_upgrade(server_config.clone(), client_addr, req)
|
||||
.map::<anyhow::Result<_>, _>(Ok)
|
||||
.await
|
||||
} else {
|
||||
ws_server_upgrade(server_config.clone(), client_addr, req)
|
||||
.map(|response| Ok::<_, anyhow::Error>(response.map(Either::Left)))
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -493,9 +616,11 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
|
|||
);
|
||||
|
||||
info!("Accepting connection");
|
||||
let upgrade_fn = mk_upgrade_fn(server_config.clone(), peer_addr);
|
||||
// TLS
|
||||
if let Some(tls) = tls_context.as_mut() {
|
||||
let server_config = server_config.clone();
|
||||
|
||||
// Check if we need to enable TLS or not
|
||||
match tls_context.as_mut() {
|
||||
Some(tls) => {
|
||||
// Reload TLS certificate if needed
|
||||
let tls_acceptor = tls.tls_acceptor().clone();
|
||||
let fut = async move {
|
||||
|
@ -508,32 +633,58 @@ pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()
|
|||
}
|
||||
};
|
||||
|
||||
match tls_stream.inner().get_ref().1.alpn_protocol() {
|
||||
// http2
|
||||
Some(b"h2") => {
|
||||
let mut conn_builder = http2::Builder::new(TokioExecutor::new());
|
||||
if let Some(ping) = server_config.websocket_ping_frequency {
|
||||
conn_builder.keep_alive_interval(ping);
|
||||
}
|
||||
|
||||
let http_upgrade_fn = mk_http_upgrade_fn(server_config, peer_addr);
|
||||
let con_fut = conn_builder.serve_connection(tls_stream, service_fn(http_upgrade_fn));
|
||||
if let Err(e) = con_fut.await {
|
||||
error!("Error while upgrading cnx to http: {:?}", e);
|
||||
}
|
||||
}
|
||||
// websocket
|
||||
_ => {
|
||||
let websocket_upgrade_fn = mk_websocket_upgrade_fn(server_config, peer_addr);
|
||||
let conn_fut = http1::Builder::new()
|
||||
.serve_connection(tls_stream, service_fn(upgrade_fn))
|
||||
.serve_connection(tls_stream, service_fn(websocket_upgrade_fn))
|
||||
.with_upgrades();
|
||||
|
||||
if let Err(e) = conn_fut.await {
|
||||
error!("Error while upgrading cnx to websocket: {:?}", e);
|
||||
error!("Error while upgrading cnx: {:?}", e);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
.instrument(span);
|
||||
|
||||
tokio::spawn(fut);
|
||||
// Normal
|
||||
} else {
|
||||
let stream = hyper_util::rt::TokioIo::new(stream);
|
||||
let conn_fut = http1::Builder::new()
|
||||
.serve_connection(stream, service_fn(upgrade_fn))
|
||||
.with_upgrades();
|
||||
|
||||
}
|
||||
// HTTP without TLS
|
||||
None => {
|
||||
let fut = async move {
|
||||
if let Err(e) = conn_fut.await {
|
||||
let stream = hyper_util::rt::TokioIo::new(stream);
|
||||
let mut conn_fut = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
|
||||
if let Some(ping) = server_config.websocket_ping_frequency {
|
||||
conn_fut.http2().keep_alive_interval(ping);
|
||||
}
|
||||
|
||||
let websocket_upgrade_fn = mk_auto_upgrade_fn(server_config, peer_addr);
|
||||
let upgradable = conn_fut.serve_connection_with_upgrades(stream, service_fn(websocket_upgrade_fn));
|
||||
|
||||
if let Err(e) = upgradable.await {
|
||||
error!("Error while upgrading cnx to websocket: {:?}", e);
|
||||
}
|
||||
}
|
||||
.instrument(span);
|
||||
|
||||
tokio::spawn(fut);
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
168
src/tunnel/transport/http2.rs
Normal file
168
src/tunnel/transport/http2.rs
Normal file
|
@ -0,0 +1,168 @@
|
|||
use crate::tunnel::transport::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
|
||||
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr};
|
||||
use crate::WsClientConfig;
|
||||
use anyhow::{anyhow, Context};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use http_body_util::{BodyExt, BodyStream, StreamBody};
|
||||
use hyper::body::{Frame, Incoming};
|
||||
use hyper::header::{AUTHORIZATION, COOKIE, HOST};
|
||||
use hyper::http::response::Parts;
|
||||
use hyper::Request;
|
||||
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
|
||||
use log::{debug, error, warn};
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use std::ops::DerefMut;
|
||||
use tokio::io::{AsyncWrite, AsyncWriteExt};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub struct Http2TunnelRead {
|
||||
inner: BodyStream<Incoming>,
|
||||
}
|
||||
|
||||
impl Http2TunnelRead {
|
||||
pub fn new(inner: BodyStream<Incoming>) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl TunnelRead for Http2TunnelRead {
|
||||
async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin + Send) -> Result<(), io::Error> {
|
||||
loop {
|
||||
match self.inner.next().await {
|
||||
Some(Ok(frame)) => match frame.into_data() {
|
||||
Ok(data) => {
|
||||
return match writer.write_all(data.as_ref()).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(io::Error::new(ErrorKind::ConnectionAborted, err)),
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("{:?}", err);
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Some(Err(err)) => {
|
||||
return Err(io::Error::new(ErrorKind::ConnectionAborted, err));
|
||||
}
|
||||
None => return Err(io::Error::new(ErrorKind::BrokenPipe, "closed")),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Http2TunnelWrite {
|
||||
inner: mpsc::Sender<Bytes>,
|
||||
buf: BytesMut,
|
||||
}
|
||||
|
||||
impl Http2TunnelWrite {
|
||||
pub fn new(inner: mpsc::Sender<Bytes>) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
buf: BytesMut::with_capacity(MAX_PACKET_LENGTH * 20), // ~ 1Mb
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TunnelWrite for Http2TunnelWrite {
|
||||
fn buf_mut(&mut self) -> &mut BytesMut {
|
||||
&mut self.buf
|
||||
}
|
||||
|
||||
async fn write(&mut self) -> Result<(), io::Error> {
|
||||
let data = self.buf.split().freeze();
|
||||
let ret = match self.inner.send(data).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(io::Error::new(ErrorKind::ConnectionAborted, err)),
|
||||
};
|
||||
|
||||
if self.buf.capacity() < MAX_PACKET_LENGTH {
|
||||
//info!("read {} Kb {} Kb", self.buf.capacity() / 1024, old_capa / 1024);
|
||||
self.buf.reserve(MAX_PACKET_LENGTH * 4)
|
||||
}
|
||||
|
||||
ret
|
||||
}
|
||||
|
||||
async fn ping(&mut self) -> Result<(), io::Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<(), io::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
request_id: Uuid,
|
||||
client_cfg: &WsClientConfig,
|
||||
dest_addr: &RemoteAddr,
|
||||
) -> anyhow::Result<(Http2TunnelRead, Http2TunnelWrite, Parts)> {
|
||||
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
|
||||
Ok(cnx) => Ok(cnx),
|
||||
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
|
||||
}?;
|
||||
|
||||
let mut req = Request::builder()
|
||||
.method("POST")
|
||||
.uri(format!(
|
||||
"{}://{}:{}/{}/events",
|
||||
client_cfg.remote_addr.scheme(),
|
||||
client_cfg.remote_addr.host(),
|
||||
client_cfg.remote_addr.port(),
|
||||
&client_cfg.http_upgrade_path_prefix
|
||||
))
|
||||
.header(HOST, &client_cfg.http_header_host)
|
||||
.header(COOKIE, tunnel_to_jwt_token(request_id, dest_addr))
|
||||
.version(hyper::Version::HTTP_2);
|
||||
|
||||
for (k, v) in &client_cfg.http_headers {
|
||||
req = req.header(k, v);
|
||||
}
|
||||
if let Some(auth) = &client_cfg.http_upgrade_credentials {
|
||||
req = req.header(AUTHORIZATION, auth);
|
||||
}
|
||||
|
||||
let (tx, rx) = mpsc::channel::<Bytes>(1024);
|
||||
let body = StreamBody::new(ReceiverStream::new(rx).map(|s| -> anyhow::Result<Frame<Bytes>> { Ok(Frame::data(s)) }));
|
||||
let req = req.body(body).with_context(|| {
|
||||
format!(
|
||||
"failed to build HTTP request to contact the server {:?}",
|
||||
client_cfg.remote_addr
|
||||
)
|
||||
})?;
|
||||
debug!("with HTTP upgrade request {:?}", req);
|
||||
let transport = pooled_cnx.deref_mut().take().unwrap();
|
||||
let (mut request_sender, cnx) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
|
||||
.timer(TokioTimer::new())
|
||||
.keep_alive_interval(client_cfg.websocket_ping_frequency)
|
||||
.keep_alive_while_idle(false)
|
||||
.handshake(TokioIo::new(transport))
|
||||
.await
|
||||
.with_context(|| format!("failed to do http2 handshake with the server {:?}", client_cfg.remote_addr))?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = cnx.await {
|
||||
error!("{:?}", err)
|
||||
}
|
||||
});
|
||||
|
||||
let response = request_sender
|
||||
.send_request(req)
|
||||
.await
|
||||
.with_context(|| format!("failed to send http2 request with the server {:?}", client_cfg.remote_addr))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(anyhow!(
|
||||
"Http2 server rejected the connection: {:?}: {:?}",
|
||||
response.status(),
|
||||
String::from_utf8(response.into_body().collect().await?.to_bytes().to_vec()).unwrap_or_default()
|
||||
));
|
||||
}
|
||||
|
||||
let (parts, body) = response.into_parts();
|
||||
Ok((Http2TunnelRead::new(BodyStream::new(body)), Http2TunnelWrite::new(tx), parts))
|
||||
}
|
|
@ -1,4 +1,5 @@
|
|||
use crate::tunnel::transport::{TunnelRead, TunnelWrite};
|
||||
use bytes::BufMut;
|
||||
use futures_util::{pin_mut, FutureExt};
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||
|
@ -6,7 +7,7 @@ use tokio::select;
|
|||
use tokio::sync::oneshot;
|
||||
use tokio::time::Instant;
|
||||
use tracing::log::debug;
|
||||
use tracing::{error, info, trace, warn};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
pub async fn propagate_local_to_remote(
|
||||
local_rx: impl AsyncRead,
|
||||
|
@ -19,7 +20,6 @@ pub async fn propagate_local_to_remote(
|
|||
});
|
||||
|
||||
static MAX_PACKET_LENGTH: usize = 64 * 1024;
|
||||
let mut buffer = vec![0u8; MAX_PACKET_LENGTH];
|
||||
|
||||
// We do our own pin_mut! to avoid shadowing timeout and be able to reset it, on next loop iteration
|
||||
// We reuse the future to avoid creating a timer in the tight loop
|
||||
|
@ -32,21 +32,26 @@ pub async fn propagate_local_to_remote(
|
|||
pin_mut!(should_close);
|
||||
pin_mut!(local_rx);
|
||||
loop {
|
||||
debug_assert!(
|
||||
ws_tx.buf_mut().chunk_mut().len() >= MAX_PACKET_LENGTH,
|
||||
"buffer must be large enough to receive a whole packet length"
|
||||
);
|
||||
|
||||
let read_len = select! {
|
||||
biased;
|
||||
|
||||
read_len = local_rx.read(&mut buffer) => read_len,
|
||||
read_len = local_rx.read_buf(ws_tx.buf_mut()) => read_len,
|
||||
|
||||
_ = &mut should_close => break,
|
||||
|
||||
_ = timeout.tick(), if ping_frequency.is_some() => {
|
||||
debug!("sending ping to keep websocket connection alive");
|
||||
debug!("sending ping to keep connection alive");
|
||||
ws_tx.ping().await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let read_len = match read_len {
|
||||
let _read_len = match read_len {
|
||||
Ok(0) => break,
|
||||
Ok(read_len) => read_len,
|
||||
Err(err) => {
|
||||
|
@ -56,27 +61,10 @@ pub async fn propagate_local_to_remote(
|
|||
};
|
||||
|
||||
//debug!("read {} wasted {}% usable {} capa {}", read_len, 100 - (read_len * 100 / buffer.capacity()), buffer.as_slice().len(), buffer.capacity());
|
||||
if let Err(err) = ws_tx.write(&buffer[..read_len]).await {
|
||||
warn!("error while writing to websocket tx tunnel {}", err);
|
||||
if let Err(err) = ws_tx.write().await {
|
||||
warn!("error while writing to tx tunnel {}", err);
|
||||
break;
|
||||
}
|
||||
|
||||
// If the buffer has been completely filled with previous read, Double it !
|
||||
// For the buffer to not be a bottleneck when the TCP window scale
|
||||
// For udp, the buffer will never grows.
|
||||
if buffer.capacity() == read_len {
|
||||
buffer.clear();
|
||||
let new_size = buffer.capacity() + (buffer.capacity() / 4); // grow buffer by 1.25 %
|
||||
buffer.reserve_exact(new_size);
|
||||
buffer.resize(buffer.capacity(), 0);
|
||||
trace!(
|
||||
"Buffer {} Mb {} {} {}",
|
||||
buffer.capacity() as f64 / 1024.0 / 1024.0,
|
||||
new_size,
|
||||
buffer.as_slice().len(),
|
||||
buffer.capacity()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Send normal close
|
||||
|
@ -103,7 +91,7 @@ pub async fn propagate_remote_to_local(
|
|||
};
|
||||
|
||||
if let Err(err) = msg {
|
||||
error!("error while reading from websocket rx {}", err);
|
||||
error!("error while reading from tunnel rx {}", err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,15 +1,74 @@
|
|||
use crate::tunnel::transport::http2::{Http2TunnelRead, Http2TunnelWrite};
|
||||
use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite};
|
||||
use bytes::BytesMut;
|
||||
use std::future::Future;
|
||||
use tokio::io::AsyncWrite;
|
||||
|
||||
pub mod http2;
|
||||
pub mod io;
|
||||
pub mod websocket;
|
||||
|
||||
static MAX_PACKET_LENGTH: usize = 64 * 1024;
|
||||
|
||||
pub trait TunnelWrite: Send + 'static {
|
||||
fn write(&mut self, buf: &[u8]) -> impl Future<Output = anyhow::Result<()>> + Send;
|
||||
fn ping(&mut self) -> impl Future<Output = anyhow::Result<()>> + Send;
|
||||
fn close(&mut self) -> impl Future<Output = anyhow::Result<()>> + Send;
|
||||
fn buf_mut(&mut self) -> &mut BytesMut;
|
||||
fn write(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
|
||||
fn ping(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
|
||||
fn close(&mut self) -> impl Future<Output = Result<(), std::io::Error>> + Send;
|
||||
}
|
||||
|
||||
pub trait TunnelRead: Send + 'static {
|
||||
fn copy(&mut self, writer: impl AsyncWrite + Unpin + Send) -> impl Future<Output = anyhow::Result<()>> + Send;
|
||||
fn copy(
|
||||
&mut self,
|
||||
writer: impl AsyncWrite + Unpin + Send,
|
||||
) -> impl Future<Output = Result<(), std::io::Error>> + Send;
|
||||
}
|
||||
|
||||
pub enum TunnelReader {
|
||||
Websocket(WebsocketTunnelRead),
|
||||
Http2(Http2TunnelRead),
|
||||
}
|
||||
|
||||
impl TunnelRead for TunnelReader {
|
||||
async fn copy(&mut self, writer: impl AsyncWrite + Unpin + Send) -> Result<(), std::io::Error> {
|
||||
match self {
|
||||
TunnelReader::Websocket(s) => s.copy(writer).await,
|
||||
TunnelReader::Http2(s) => s.copy(writer).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum TunnelWriter {
|
||||
Websocket(WebsocketTunnelWrite),
|
||||
Http2(Http2TunnelWrite),
|
||||
}
|
||||
|
||||
impl TunnelWrite for TunnelWriter {
|
||||
fn buf_mut(&mut self) -> &mut BytesMut {
|
||||
match self {
|
||||
TunnelWriter::Websocket(s) => s.buf_mut(),
|
||||
TunnelWriter::Http2(s) => s.buf_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn write(&mut self) -> Result<(), std::io::Error> {
|
||||
match self {
|
||||
TunnelWriter::Websocket(s) => s.write().await,
|
||||
TunnelWriter::Http2(s) => s.write().await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn ping(&mut self) -> Result<(), std::io::Error> {
|
||||
match self {
|
||||
TunnelWriter::Websocket(s) => s.ping().await,
|
||||
TunnelWriter::Http2(s) => s.ping().await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<(), std::io::Error> {
|
||||
match self {
|
||||
TunnelWriter::Websocket(s) => s.close().await,
|
||||
TunnelWriter::Http2(s) => s.close().await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,40 +1,105 @@
|
|||
use crate::tunnel::transport::{TunnelRead, TunnelWrite};
|
||||
use crate::tunnel::transport::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
|
||||
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, JWT_HEADER_PREFIX};
|
||||
use crate::WsClientConfig;
|
||||
use anyhow::{anyhow, Context};
|
||||
use bytes::Bytes;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};
|
||||
use http_body_util::Empty;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE};
|
||||
use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY};
|
||||
use hyper::http::response::Parts;
|
||||
use hyper::upgrade::Upgraded;
|
||||
use hyper::{Request, Response};
|
||||
use hyper::Request;
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use log::debug;
|
||||
use std::io;
|
||||
use std::io::ErrorKind;
|
||||
use std::ops::DerefMut;
|
||||
use tokio::io::{AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
|
||||
use tracing::trace;
|
||||
use uuid::Uuid;
|
||||
|
||||
impl TunnelWrite for WebSocketWrite<WriteHalf<TokioIo<Upgraded>>> {
|
||||
async fn write(&mut self, buf: &[u8]) -> anyhow::Result<()> {
|
||||
self.write_frame(Frame::binary(Payload::Borrowed(buf)))
|
||||
.await
|
||||
.with_context(|| "cannot send ws frame")
|
||||
pub struct WebsocketTunnelWrite {
|
||||
inner: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>,
|
||||
buf: BytesMut,
|
||||
}
|
||||
|
||||
impl WebsocketTunnelWrite {
|
||||
pub fn new(ws: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>) -> Self {
|
||||
Self {
|
||||
inner: ws,
|
||||
buf: BytesMut::with_capacity(MAX_PACKET_LENGTH),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TunnelWrite for WebsocketTunnelWrite {
|
||||
fn buf_mut(&mut self) -> &mut BytesMut {
|
||||
&mut self.buf
|
||||
}
|
||||
|
||||
async fn ping(&mut self) -> anyhow::Result<()> {
|
||||
self.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut [])))
|
||||
.await
|
||||
.with_context(|| "cannot send ws ping")
|
||||
async fn write(&mut self) -> Result<(), io::Error> {
|
||||
let read_len = self.buf.len();
|
||||
let buf = &mut self.buf;
|
||||
|
||||
let ret = self
|
||||
.inner
|
||||
.write_frame(Frame::binary(Payload::BorrowedMut(&mut buf[..read_len])))
|
||||
.await;
|
||||
|
||||
if let Err(err) = ret {
|
||||
return Err(io::Error::new(ErrorKind::ConnectionAborted, err));
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> anyhow::Result<()> {
|
||||
self.write_frame(Frame::close(1000, &[]))
|
||||
// If the buffer has been completely filled with previous read, Grows it !
|
||||
// For the buffer to not be a bottleneck when the TCP window scale
|
||||
// For udp, the buffer will never grows.
|
||||
buf.clear();
|
||||
if buf.capacity() == read_len {
|
||||
let new_size = buf.capacity() + (buf.capacity() / 4); // grow buffer by 1.25 %
|
||||
buf.reserve(new_size);
|
||||
buf.resize(buf.capacity(), 0);
|
||||
trace!(
|
||||
"Buffer {} Mb {} {} {}",
|
||||
buf.capacity() as f64 / 1024.0 / 1024.0,
|
||||
new_size,
|
||||
buf.len(),
|
||||
buf.capacity()
|
||||
)
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ping(&mut self) -> Result<(), io::Error> {
|
||||
if let Err(err) = self
|
||||
.inner
|
||||
.write_frame(Frame::new(true, OpCode::Ping, None, Payload::BorrowedMut(&mut [])))
|
||||
.await
|
||||
.with_context(|| "cannot close websocket cnx")
|
||||
{
|
||||
return Err(io::Error::new(ErrorKind::BrokenPipe, err));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> Result<(), io::Error> {
|
||||
if let Err(err) = self.inner.write_frame(Frame::close(1000, &[])).await {
|
||||
return Err(io::Error::new(ErrorKind::BrokenPipe, err));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct WebsocketTunnelRead {
|
||||
inner: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>,
|
||||
}
|
||||
|
||||
impl WebsocketTunnelRead {
|
||||
pub fn new(ws: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>) -> Self {
|
||||
Self { inner: ws }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -42,21 +107,24 @@ fn frame_reader(x: Frame<'_>) -> futures_util::future::Ready<anyhow::Result<()>>
|
|||
debug!("frame {:?} {:?}", x.opcode, x.payload);
|
||||
futures_util::future::ready(anyhow::Ok(()))
|
||||
}
|
||||
impl TunnelRead for WebSocketRead<ReadHalf<TokioIo<Upgraded>>> {
|
||||
async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin + Send) -> anyhow::Result<()> {
|
||||
|
||||
impl TunnelRead for WebsocketTunnelRead {
|
||||
async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin + Send) -> Result<(), io::Error> {
|
||||
loop {
|
||||
let msg = self
|
||||
.read_frame(&mut frame_reader)
|
||||
.await
|
||||
.with_context(|| "error while reading from websocket")?;
|
||||
let msg = match self.inner.read_frame(&mut frame_reader).await {
|
||||
Ok(msg) => msg,
|
||||
Err(err) => return Err(io::Error::new(ErrorKind::ConnectionAborted, err)),
|
||||
};
|
||||
|
||||
trace!("receive ws frame {:?} {:?}", msg.opcode, msg.payload);
|
||||
match msg.opcode {
|
||||
OpCode::Continuation | OpCode::Text | OpCode::Binary => {
|
||||
writer.write_all(msg.payload.as_ref()).await.with_context(|| "")?;
|
||||
return Ok(());
|
||||
return match writer.write_all(msg.payload.as_ref()).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(io::Error::new(ErrorKind::ConnectionAborted, err)),
|
||||
}
|
||||
OpCode::Close => return Err(anyhow!("websocket close")),
|
||||
}
|
||||
OpCode::Close => return Err(io::Error::new(ErrorKind::NotConnected, "websocket close")),
|
||||
OpCode::Ping => continue,
|
||||
OpCode::Pong => continue,
|
||||
};
|
||||
|
@ -68,7 +136,7 @@ pub async fn connect(
|
|||
request_id: Uuid,
|
||||
client_cfg: &WsClientConfig,
|
||||
dest_addr: &RemoteAddr,
|
||||
) -> anyhow::Result<((impl TunnelRead, impl TunnelWrite), Response<Incoming>)> {
|
||||
) -> anyhow::Result<(WebsocketTunnelRead, WebsocketTunnelWrite, Parts)> {
|
||||
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
|
||||
Ok(cnx) => Ok(cnx),
|
||||
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
|
||||
|
@ -76,7 +144,7 @@ pub async fn connect(
|
|||
|
||||
let mut req = Request::builder()
|
||||
.method("GET")
|
||||
.uri(format!("/{}/events", &client_cfg.http_upgrade_path_prefix,))
|
||||
.uri(format!("/{}/events", &client_cfg.http_upgrade_path_prefix))
|
||||
.header(HOST, &client_cfg.http_header_host)
|
||||
.header(UPGRADE, "websocket")
|
||||
.header(CONNECTION, "upgrade")
|
||||
|
@ -109,5 +177,11 @@ pub async fn connect(
|
|||
|
||||
ws.set_auto_apply_mask(client_cfg.websocket_mask_frame);
|
||||
|
||||
Ok((ws.split(tokio::io::split), response))
|
||||
let (ws_rx, ws_tx) = ws.split(tokio::io::split);
|
||||
|
||||
Ok((
|
||||
WebsocketTunnelRead::new(ws_rx),
|
||||
WebsocketTunnelWrite::new(ws_tx),
|
||||
response.into_parts().0,
|
||||
))
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue