cleanup transport addr and scheme
This commit is contained in:
parent
3eef03d8c4
commit
ebd7591b34
8 changed files with 177 additions and 107 deletions
|
@ -15,7 +15,7 @@ pub async fn propagate_local_to_remote(
|
|||
ping_frequency: Option<Duration>,
|
||||
) -> anyhow::Result<()> {
|
||||
let _guard = scopeguard::guard((), |_| {
|
||||
info!("Closing local ==>> remote tunnel");
|
||||
info!("Closing local => remote tunnel");
|
||||
});
|
||||
|
||||
static MAX_PACKET_LENGTH: usize = 64 * 1024;
|
||||
|
@ -86,12 +86,12 @@ pub async fn propagate_local_to_remote(
|
|||
}
|
||||
|
||||
pub async fn propagate_remote_to_local(
|
||||
local_tx: impl AsyncWrite,
|
||||
local_tx: impl AsyncWrite + Send,
|
||||
mut ws_rx: impl TunnelRead,
|
||||
mut close_rx: oneshot::Receiver<()>,
|
||||
) -> anyhow::Result<()> {
|
||||
let _guard = scopeguard::guard((), |_| {
|
||||
info!("Closing local <<== remote tunnel");
|
||||
info!("Closing local <= remote tunnel");
|
||||
});
|
||||
|
||||
pin_mut!(local_tx);
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
use std::future::Future;
|
||||
use tokio::io::AsyncWrite;
|
||||
|
||||
pub mod io;
|
||||
pub mod websocket;
|
||||
|
||||
pub trait TunnelWrite {
|
||||
async fn write(&mut self, buf: &[u8]) -> anyhow::Result<()>;
|
||||
async fn ping(&mut self) -> anyhow::Result<()>;
|
||||
async fn close(&mut self) -> anyhow::Result<()>;
|
||||
pub trait TunnelWrite: Send + 'static {
|
||||
fn write(&mut self, buf: &[u8]) -> impl Future<Output = anyhow::Result<()>> + Send;
|
||||
fn ping(&mut self) -> impl Future<Output = anyhow::Result<()>> + Send;
|
||||
fn close(&mut self) -> impl Future<Output = anyhow::Result<()>> + Send;
|
||||
}
|
||||
|
||||
pub trait TunnelRead {
|
||||
async fn copy(&mut self, writer: impl AsyncWrite + Unpin) -> anyhow::Result<()>;
|
||||
pub trait TunnelRead: Send + 'static {
|
||||
fn copy(&mut self, writer: impl AsyncWrite + Unpin + Send) -> impl Future<Output = anyhow::Result<()>> + Send;
|
||||
}
|
||||
|
|
|
@ -1,11 +1,22 @@
|
|||
use crate::tunnel::transport::{TunnelRead, TunnelWrite};
|
||||
use crate::tunnel::{tunnel_to_jwt_token, RemoteAddr, JWT_HEADER_PREFIX};
|
||||
use crate::WsClientConfig;
|
||||
use anyhow::{anyhow, Context};
|
||||
use bytes::Bytes;
|
||||
use fastwebsockets::{Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};
|
||||
use http_body_util::Empty;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE};
|
||||
use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY};
|
||||
use hyper::upgrade::Upgraded;
|
||||
use hyper::{Request, Response};
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use log::debug;
|
||||
use std::ops::DerefMut;
|
||||
use tokio::io::{AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
|
||||
use tracing::trace;
|
||||
use uuid::Uuid;
|
||||
|
||||
impl TunnelWrite for WebSocketWrite<WriteHalf<TokioIo<Upgraded>>> {
|
||||
async fn write(&mut self, buf: &[u8]) -> anyhow::Result<()> {
|
||||
|
@ -32,7 +43,7 @@ fn frame_reader(x: Frame<'_>) -> futures_util::future::Ready<anyhow::Result<()>>
|
|||
futures_util::future::ready(anyhow::Ok(()))
|
||||
}
|
||||
impl TunnelRead for WebSocketRead<ReadHalf<TokioIo<Upgraded>>> {
|
||||
async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin) -> anyhow::Result<()> {
|
||||
async fn copy(&mut self, mut writer: impl AsyncWrite + Unpin + Send) -> anyhow::Result<()> {
|
||||
loop {
|
||||
let msg = self
|
||||
.read_frame(&mut frame_reader)
|
||||
|
@ -52,3 +63,51 @@ impl TunnelRead for WebSocketRead<ReadHalf<TokioIo<Upgraded>>> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn connect(
|
||||
request_id: Uuid,
|
||||
client_cfg: &WsClientConfig,
|
||||
dest_addr: &RemoteAddr,
|
||||
) -> anyhow::Result<((impl TunnelRead, impl TunnelWrite), Response<Incoming>)> {
|
||||
let mut pooled_cnx = match client_cfg.cnx_pool().get().await {
|
||||
Ok(cnx) => Ok(cnx),
|
||||
Err(err) => Err(anyhow!("failed to get a connection to the server from the pool: {err:?}")),
|
||||
}?;
|
||||
|
||||
let mut req = Request::builder()
|
||||
.method("GET")
|
||||
.uri(format!("/{}/events", &client_cfg.http_upgrade_path_prefix,))
|
||||
.header(HOST, &client_cfg.http_header_host)
|
||||
.header(UPGRADE, "websocket")
|
||||
.header(CONNECTION, "upgrade")
|
||||
.header(SEC_WEBSOCKET_KEY, fastwebsockets::handshake::generate_key())
|
||||
.header(SEC_WEBSOCKET_VERSION, "13")
|
||||
.header(
|
||||
SEC_WEBSOCKET_PROTOCOL,
|
||||
format!("v1, {}{}", JWT_HEADER_PREFIX, tunnel_to_jwt_token(request_id, dest_addr)),
|
||||
)
|
||||
.version(hyper::Version::HTTP_11);
|
||||
|
||||
for (k, v) in &client_cfg.http_headers {
|
||||
req = req.header(k, v);
|
||||
}
|
||||
if let Some(auth) = &client_cfg.http_upgrade_credentials {
|
||||
req = req.header(AUTHORIZATION, auth);
|
||||
}
|
||||
|
||||
let req = req.body(Empty::<Bytes>::new()).with_context(|| {
|
||||
format!(
|
||||
"failed to build HTTP request to contact the server {:?}",
|
||||
client_cfg.remote_addr
|
||||
)
|
||||
})?;
|
||||
debug!("with HTTP upgrade request {:?}", req);
|
||||
let transport = pooled_cnx.deref_mut().take().unwrap();
|
||||
let (mut ws, response) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, transport)
|
||||
.await
|
||||
.with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?;
|
||||
|
||||
ws.set_auto_apply_mask(client_cfg.websocket_mask_frame);
|
||||
|
||||
Ok((ws.split(tokio::io::split), response))
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue