fix(#360): data not flushed immediatly on reverse tunnel
This commit is contained in:
parent
e2d1413379
commit
2bc7c94578
4 changed files with 186 additions and 33 deletions
|
@ -4,6 +4,7 @@ use crate::tunnel::client::l4_transport_stream::TransportStream;
|
||||||
use crate::tunnel::client::WsClientConfig;
|
use crate::tunnel::client::WsClientConfig;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use bb8::ManageConnection;
|
use bb8::ManageConnection;
|
||||||
|
use bytes::Bytes;
|
||||||
use std::ops::Deref;
|
use std::ops::Deref;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tracing::instrument;
|
use tracing::instrument;
|
||||||
|
@ -58,9 +59,9 @@ impl ManageConnection for WsConnection {
|
||||||
|
|
||||||
if self.remote_addr.tls().is_some() {
|
if self.remote_addr.tls().is_some() {
|
||||||
let tls_stream = tls::connect(self, tcp_stream).await?;
|
let tls_stream = tls::connect(self, tcp_stream).await?;
|
||||||
Ok(Some(TransportStream::Tls(tls_stream)))
|
Ok(Some(TransportStream::from_client_tls(tls_stream, Bytes::default())))
|
||||||
} else {
|
} else {
|
||||||
Ok(Some(TransportStream::Plain(tcp_stream)))
|
Ok(Some(TransportStream::from_tcp(tcp_stream, Bytes::default())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,46 +1,160 @@
|
||||||
|
use bytes::{Buf, Bytes};
|
||||||
|
use std::cmp;
|
||||||
use std::io::{Error, IoSlice};
|
use std::io::{Error, IoSlice};
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, WriteHalf};
|
||||||
|
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio_rustls::client::TlsStream;
|
|
||||||
|
|
||||||
pub enum TransportStream {
|
pub struct TransportStream {
|
||||||
Plain(TcpStream),
|
read: TransportReadHalf,
|
||||||
Tls(TlsStream<TcpStream>),
|
write: TransportWriteHalf,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TransportStream {
|
||||||
|
pub fn from_tcp(tcp: TcpStream, read_buf: Bytes) -> Self {
|
||||||
|
let (read, write) = tcp.into_split();
|
||||||
|
Self {
|
||||||
|
read: TransportReadHalf::Plain(read, read_buf),
|
||||||
|
write: TransportWriteHalf::Plain(write),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_client_tls(tls: tokio_rustls::client::TlsStream<TcpStream>, read_buf: Bytes) -> Self {
|
||||||
|
let (read, write) = tokio::io::split(tls);
|
||||||
|
Self {
|
||||||
|
read: TransportReadHalf::Tls(read, read_buf),
|
||||||
|
write: TransportWriteHalf::Tls(write),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_server_tls(tls: tokio_rustls::server::TlsStream<TcpStream>, read_buf: Bytes) -> Self {
|
||||||
|
let (read, write) = tokio::io::split(tls);
|
||||||
|
Self {
|
||||||
|
read: TransportReadHalf::TlsSrv(read, read_buf),
|
||||||
|
write: TransportWriteHalf::TlsSrv(write),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from(self, read_buf: Bytes) -> Self {
|
||||||
|
let mut read = self.read;
|
||||||
|
*read.read_buf_mut() = read_buf;
|
||||||
|
Self {
|
||||||
|
read,
|
||||||
|
write: self.write,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_split(self) -> (TransportReadHalf, TransportWriteHalf) {
|
||||||
|
(self.read, self.write)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum TransportReadHalf {
|
||||||
|
Plain(OwnedReadHalf, Bytes),
|
||||||
|
Tls(ReadHalf<tokio_rustls::client::TlsStream<TcpStream>>, Bytes),
|
||||||
|
TlsSrv(ReadHalf<tokio_rustls::server::TlsStream<TcpStream>>, Bytes),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TransportReadHalf {
|
||||||
|
fn read_buf_mut(&mut self) -> &mut Bytes {
|
||||||
|
match self {
|
||||||
|
Self::Plain(_, buf) => buf,
|
||||||
|
Self::Tls(_, buf) => buf,
|
||||||
|
Self::TlsSrv(_, buf) => buf,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum TransportWriteHalf {
|
||||||
|
Plain(OwnedWriteHalf),
|
||||||
|
Tls(WriteHalf<tokio_rustls::client::TlsStream<TcpStream>>),
|
||||||
|
TlsSrv(WriteHalf<tokio_rustls::server::TlsStream<TcpStream>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsyncRead for TransportStream {
|
impl AsyncRead for TransportStream {
|
||||||
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
|
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
|
||||||
match self.get_mut() {
|
unsafe { self.map_unchecked_mut(|s| &mut s.read).poll_read(cx, buf) }
|
||||||
Self::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf),
|
|
||||||
Self::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsyncWrite for TransportStream {
|
impl AsyncWrite for TransportStream {
|
||||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
||||||
match self.get_mut() {
|
unsafe { self.map_unchecked_mut(|s| &mut s.write).poll_write(cx, buf) }
|
||||||
Self::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf),
|
}
|
||||||
Self::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf),
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||||
|
unsafe { self.map_unchecked_mut(|s| &mut s.write).poll_flush(cx) }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||||
|
unsafe { self.map_unchecked_mut(|s| &mut s.write).poll_shutdown(cx) }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_write_vectored(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
bufs: &[IoSlice<'_>],
|
||||||
|
) -> Poll<Result<usize, Error>> {
|
||||||
|
unsafe { self.map_unchecked_mut(|s| &mut s.write).poll_write_vectored(cx, bufs) }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_write_vectored(&self) -> bool {
|
||||||
|
self.write.is_write_vectored()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for TransportReadHalf {
|
||||||
|
#[inline]
|
||||||
|
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
|
||||||
|
let this = self.get_mut();
|
||||||
|
|
||||||
|
let read_buf = this.read_buf_mut();
|
||||||
|
if !read_buf.is_empty() {
|
||||||
|
let copy_len = cmp::min(read_buf.len(), buf.remaining());
|
||||||
|
buf.put_slice(&read_buf[..copy_len]);
|
||||||
|
read_buf.advance(copy_len);
|
||||||
|
return Poll::Ready(Ok(()));
|
||||||
|
}
|
||||||
|
|
||||||
|
match this {
|
||||||
|
Self::Plain(cnx, _) => Pin::new(cnx).poll_read(cx, buf),
|
||||||
|
Self::Tls(cnx, _) => Pin::new(cnx).poll_read(cx, buf),
|
||||||
|
Self::TlsSrv(cnx, _) => Pin::new(cnx).poll_read(cx, buf),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for TransportWriteHalf {
|
||||||
|
#[inline]
|
||||||
|
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
|
||||||
|
match self.get_mut() {
|
||||||
|
Self::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf),
|
||||||
|
Self::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf),
|
||||||
|
Self::TlsSrv(cnx) => Pin::new(cnx).poll_write(cx, buf),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||||
match self.get_mut() {
|
match self.get_mut() {
|
||||||
Self::Plain(cnx) => Pin::new(cnx).poll_flush(cx),
|
Self::Plain(cnx) => Pin::new(cnx).poll_flush(cx),
|
||||||
Self::Tls(cnx) => Pin::new(cnx).poll_flush(cx),
|
Self::Tls(cnx) => Pin::new(cnx).poll_flush(cx),
|
||||||
|
Self::TlsSrv(cnx) => Pin::new(cnx).poll_flush(cx),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||||
match self.get_mut() {
|
match self.get_mut() {
|
||||||
Self::Plain(cnx) => Pin::new(cnx).poll_shutdown(cx),
|
Self::Plain(cnx) => Pin::new(cnx).poll_shutdown(cx),
|
||||||
Self::Tls(cnx) => Pin::new(cnx).poll_shutdown(cx),
|
Self::Tls(cnx) => Pin::new(cnx).poll_shutdown(cx),
|
||||||
|
Self::TlsSrv(cnx) => Pin::new(cnx).poll_shutdown(cx),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
fn poll_write_vectored(
|
fn poll_write_vectored(
|
||||||
self: Pin<&mut Self>,
|
self: Pin<&mut Self>,
|
||||||
cx: &mut Context<'_>,
|
cx: &mut Context<'_>,
|
||||||
|
@ -49,13 +163,16 @@ impl AsyncWrite for TransportStream {
|
||||||
match self.get_mut() {
|
match self.get_mut() {
|
||||||
Self::Plain(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
|
Self::Plain(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
|
||||||
Self::Tls(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
|
Self::Tls(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
|
||||||
|
Self::TlsSrv(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
fn is_write_vectored(&self) -> bool {
|
fn is_write_vectored(&self) -> bool {
|
||||||
match &self {
|
match &self {
|
||||||
Self::Plain(cnx) => cnx.is_write_vectored(),
|
Self::Plain(cnx) => cnx.is_write_vectored(),
|
||||||
Self::Tls(cnx) => cnx.is_write_vectored(),
|
Self::Tls(cnx) => cnx.is_write_vectored(),
|
||||||
|
Self::TlsSrv(cnx) => cnx.is_write_vectored(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,8 +2,9 @@ use crate::restrictions::types::RestrictionsRules;
|
||||||
use crate::tunnel::server::utils::{bad_request, inject_cookie};
|
use crate::tunnel::server::utils::{bad_request, inject_cookie};
|
||||||
use crate::tunnel::server::WsServer;
|
use crate::tunnel::server::WsServer;
|
||||||
use crate::tunnel::transport;
|
use crate::tunnel::transport;
|
||||||
use crate::tunnel::transport::websocket::{WebsocketTunnelRead, WebsocketTunnelWrite};
|
use crate::tunnel::transport::websocket::mk_websocket_tunnel;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
|
use fastwebsockets::Role;
|
||||||
use http_body_util::combinators::BoxBody;
|
use http_body_util::combinators::BoxBody;
|
||||||
use http_body_util::Either;
|
use http_body_util::Either;
|
||||||
use hyper::body::Incoming;
|
use hyper::body::Incoming;
|
||||||
|
@ -46,31 +47,26 @@ pub(super) async fn ws_server_upgrade(
|
||||||
tokio::spawn(
|
tokio::spawn(
|
||||||
async move {
|
async move {
|
||||||
let (ws_rx, ws_tx) = match fut.await {
|
let (ws_rx, ws_tx) = match fut.await {
|
||||||
Ok(mut ws) => {
|
Ok(ws) => mk_websocket_tunnel(ws, Role::Server, mask_frame)?,
|
||||||
ws.set_auto_pong(false);
|
|
||||||
ws.set_auto_close(false);
|
|
||||||
ws.set_auto_apply_mask(mask_frame);
|
|
||||||
ws.split(tokio::io::split)
|
|
||||||
}
|
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
error!("Error during http upgrade request: {:?}", err);
|
error!("Error during http upgrade request: {:?}", err);
|
||||||
return;
|
return Err(anyhow::Error::from(err));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let (close_tx, close_rx) = oneshot::channel::<()>();
|
let (close_tx, close_rx) = oneshot::channel::<()>();
|
||||||
|
|
||||||
let (ws_rx, pending_ops) = WebsocketTunnelRead::new(ws_rx);
|
|
||||||
tokio::task::spawn(
|
tokio::task::spawn(
|
||||||
transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).instrument(Span::current()),
|
transport::io::propagate_remote_to_local(local_tx, ws_rx, close_rx).instrument(Span::current()),
|
||||||
);
|
);
|
||||||
|
|
||||||
let _ = transport::io::propagate_local_to_remote(
|
let _ = transport::io::propagate_local_to_remote(
|
||||||
local_rx,
|
local_rx,
|
||||||
WebsocketTunnelWrite::new(ws_tx, pending_ops),
|
ws_tx,
|
||||||
close_tx,
|
close_tx,
|
||||||
server.config.websocket_ping_frequency,
|
server.config.websocket_ping_frequency,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
.instrument(Span::current()),
|
.instrument(Span::current()),
|
||||||
);
|
);
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
use super::io::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
|
use super::io::{TunnelRead, TunnelWrite, MAX_PACKET_LENGTH};
|
||||||
|
use crate::tunnel::client::l4_transport_stream::{TransportReadHalf, TransportStream, TransportWriteHalf};
|
||||||
use crate::tunnel::client::WsClient;
|
use crate::tunnel::client::WsClient;
|
||||||
use crate::tunnel::transport::headers_from_file;
|
use crate::tunnel::transport::headers_from_file;
|
||||||
use crate::tunnel::transport::jwt::{tunnel_to_jwt_token, JWT_HEADER_PREFIX};
|
use crate::tunnel::transport::jwt::{tunnel_to_jwt_token, JWT_HEADER_PREFIX};
|
||||||
use crate::tunnel::RemoteAddr;
|
use crate::tunnel::RemoteAddr;
|
||||||
use anyhow::{anyhow, Context};
|
use anyhow::{anyhow, Context};
|
||||||
use bytes::{Bytes, BytesMut};
|
use bytes::{Bytes, BytesMut};
|
||||||
use fastwebsockets::{CloseCode, Frame, OpCode, Payload, WebSocketRead, WebSocketWrite};
|
use fastwebsockets::{CloseCode, Frame, OpCode, Payload, Role, WebSocket, WebSocketRead, WebSocketWrite};
|
||||||
use http_body_util::Empty;
|
use http_body_util::Empty;
|
||||||
use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE};
|
use hyper::header::{AUTHORIZATION, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE};
|
||||||
use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY};
|
use hyper::header::{CONNECTION, HOST, SEC_WEBSOCKET_KEY};
|
||||||
|
@ -21,14 +22,16 @@ use std::ops::DerefMut;
|
||||||
use std::sync::atomic::AtomicUsize;
|
use std::sync::atomic::AtomicUsize;
|
||||||
use std::sync::atomic::Ordering::Relaxed;
|
use std::sync::atomic::Ordering::Relaxed;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::io::{AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
|
use tokio::io::{AsyncWrite, AsyncWriteExt};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
use tokio::sync::mpsc::{Receiver, Sender};
|
use tokio::sync::mpsc::{Receiver, Sender};
|
||||||
use tokio::sync::Notify;
|
use tokio::sync::Notify;
|
||||||
|
use tokio_rustls::server::TlsStream;
|
||||||
use tracing::trace;
|
use tracing::trace;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
pub struct WebsocketTunnelWrite {
|
pub struct WebsocketTunnelWrite {
|
||||||
inner: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>,
|
inner: WebSocketWrite<TransportWriteHalf>,
|
||||||
buf: BytesMut,
|
buf: BytesMut,
|
||||||
pending_operations: Receiver<Frame<'static>>,
|
pending_operations: Receiver<Frame<'static>>,
|
||||||
pending_ops_notify: Arc<Notify>,
|
pending_ops_notify: Arc<Notify>,
|
||||||
|
@ -37,7 +40,7 @@ pub struct WebsocketTunnelWrite {
|
||||||
|
|
||||||
impl WebsocketTunnelWrite {
|
impl WebsocketTunnelWrite {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
ws: WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>,
|
ws: WebSocketWrite<TransportWriteHalf>,
|
||||||
(pending_operations, notify): (Receiver<Frame<'static>>, Arc<Notify>),
|
(pending_operations, notify): (Receiver<Frame<'static>>, Arc<Notify>),
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
@ -146,13 +149,13 @@ impl TunnelWrite for WebsocketTunnelWrite {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct WebsocketTunnelRead {
|
pub struct WebsocketTunnelRead {
|
||||||
inner: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>,
|
inner: WebSocketRead<TransportReadHalf>,
|
||||||
pending_operations: Sender<Frame<'static>>,
|
pending_operations: Sender<Frame<'static>>,
|
||||||
notify_pending_ops: Arc<Notify>,
|
notify_pending_ops: Arc<Notify>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WebsocketTunnelRead {
|
impl WebsocketTunnelRead {
|
||||||
pub fn new(ws: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>) -> (Self, (Receiver<Frame<'static>>, Arc<Notify>)) {
|
pub fn new(ws: WebSocketRead<TransportReadHalf>) -> (Self, (Receiver<Frame<'static>>, Arc<Notify>)) {
|
||||||
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
||||||
let notify = Arc::new(Notify::new());
|
let notify = Arc::new(Notify::new());
|
||||||
(
|
(
|
||||||
|
@ -278,16 +281,52 @@ pub async fn connect(
|
||||||
})?;
|
})?;
|
||||||
debug!("with HTTP upgrade request {:?}", req);
|
debug!("with HTTP upgrade request {:?}", req);
|
||||||
let transport = pooled_cnx.deref_mut().take().unwrap();
|
let transport = pooled_cnx.deref_mut().take().unwrap();
|
||||||
let (mut ws, response) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, transport)
|
let (ws, response) = fastwebsockets::handshake::client(&TokioExecutor::new(), req, transport)
|
||||||
.await
|
.await
|
||||||
.with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?;
|
.with_context(|| format!("failed to do websocket handshake with the server {:?}", client_cfg.remote_addr))?;
|
||||||
|
|
||||||
ws.set_auto_apply_mask(client_cfg.websocket_mask_frame);
|
let (ws_rx, ws_tx) = mk_websocket_tunnel(ws, Role::Client, client_cfg.websocket_mask_frame)?;
|
||||||
ws.set_auto_close(false);
|
Ok((ws_rx, ws_tx, response.into_parts().0))
|
||||||
ws.set_auto_pong(false);
|
}
|
||||||
|
|
||||||
let (ws_rx, ws_tx) = ws.split(tokio::io::split);
|
pub fn mk_websocket_tunnel(
|
||||||
|
ws: WebSocket<TokioIo<Upgraded>>,
|
||||||
|
role: Role,
|
||||||
|
mask_frame: bool,
|
||||||
|
) -> anyhow::Result<(WebsocketTunnelRead, WebsocketTunnelWrite)> {
|
||||||
|
let mut ws = match role {
|
||||||
|
Role::Client => {
|
||||||
|
let stream = ws
|
||||||
|
.into_inner()
|
||||||
|
.into_inner()
|
||||||
|
.downcast::<TokioIo<TransportStream>>()
|
||||||
|
.map_err(|_| anyhow!("cannot downcast websocket client stream"))?;
|
||||||
|
let transport = TransportStream::from(stream.io.into_inner(), stream.read_buf);
|
||||||
|
WebSocket::after_handshake(transport, role)
|
||||||
|
}
|
||||||
|
Role::Server => {
|
||||||
|
let upgraded = ws.into_inner().into_inner();
|
||||||
|
match upgraded.downcast::<TokioIo<TlsStream<TcpStream>>>() {
|
||||||
|
Ok(stream) => {
|
||||||
|
let transport = TransportStream::from_server_tls(stream.io.into_inner(), stream.read_buf);
|
||||||
|
WebSocket::after_handshake(transport, role)
|
||||||
|
}
|
||||||
|
Err(upgraded) => {
|
||||||
|
let stream = upgraded
|
||||||
|
.downcast::<TokioIo<TcpStream>>()
|
||||||
|
.map_err(|_| anyhow!("cannot downcast websocket server stream"))?;
|
||||||
|
let transport = TransportStream::from_tcp(stream.io.into_inner(), stream.read_buf);
|
||||||
|
WebSocket::after_handshake(transport, role)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ws.set_auto_pong(false);
|
||||||
|
ws.set_auto_close(false);
|
||||||
|
ws.set_auto_apply_mask(mask_frame);
|
||||||
|
let (ws_rx, ws_tx) = ws.split(|x| x.into_split());
|
||||||
|
|
||||||
let (ws_rx, pending_ops) = WebsocketTunnelRead::new(ws_rx);
|
let (ws_rx, pending_ops) = WebsocketTunnelRead::new(ws_rx);
|
||||||
Ok((ws_rx, WebsocketTunnelWrite::new(ws_tx, pending_ops), response.into_parts().0))
|
Ok((ws_rx, WebsocketTunnelWrite::new(ws_tx, pending_ops)))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue