This commit is contained in:
Σrebe - Romain GERARD 2024-05-29 19:19:03 +02:00
parent 677b29bedf
commit 2dd99130fa
No known key found for this signature in database
GPG key ID: 7A42B4B97E0332F4
14 changed files with 140 additions and 161 deletions

View file

@ -88,29 +88,24 @@ pub enum TransportScheme {
}
impl TransportScheme {
pub fn values() -> &'static [TransportScheme] {
&[
TransportScheme::Ws,
TransportScheme::Wss,
TransportScheme::Http,
TransportScheme::Https,
]
pub const fn values() -> &'static [Self] {
&[Self::Ws, Self::Wss, Self::Http, Self::Https]
}
pub fn to_str(self) -> &'static str {
pub const fn to_str(self) -> &'static str {
match self {
TransportScheme::Ws => "ws",
TransportScheme::Wss => "wss",
TransportScheme::Http => "http",
TransportScheme::Https => "https",
Self::Ws => "ws",
Self::Wss => "wss",
Self::Http => "http",
Self::Https => "https",
}
}
pub fn alpn_protocols(&self) -> Vec<Vec<u8>> {
match self {
TransportScheme::Ws => vec![],
TransportScheme::Wss => vec![b"http/1.1".to_vec()],
TransportScheme::Http => vec![],
TransportScheme::Https => vec![b"h2".to_vec()],
Self::Ws => vec![],
Self::Wss => vec![b"http/1.1".to_vec()],
Self::Http => vec![],
Self::Https => vec![b"h2".to_vec()],
}
}
}
@ -119,10 +114,10 @@ impl FromStr for TransportScheme {
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"https" => Ok(TransportScheme::Https),
"http" => Ok(TransportScheme::Http),
"wss" => Ok(TransportScheme::Wss),
"ws" => Ok(TransportScheme::Ws),
"https" => Ok(Self::Https),
"http" => Ok(Self::Http),
"wss" => Ok(Self::Wss),
"ws" => Ok(Self::Ws),
_ => Err(()),
}
}
@ -169,71 +164,71 @@ impl Debug for TransportAddr {
impl TransportAddr {
pub fn new(scheme: TransportScheme, host: Host, port: u16, tls: Option<TlsClientConfig>) -> Option<Self> {
match scheme {
TransportScheme::Https => Some(TransportAddr::Https {
TransportScheme::Https => Some(Self::Https {
scheme: TransportScheme::Https,
tls: tls?,
host,
port,
}),
TransportScheme::Http => Some(TransportAddr::Http {
TransportScheme::Http => Some(Self::Http {
scheme: TransportScheme::Http,
host,
port,
}),
TransportScheme::Wss => Some(TransportAddr::Wss {
TransportScheme::Wss => Some(Self::Wss {
scheme: TransportScheme::Wss,
tls: tls?,
host,
port,
}),
TransportScheme::Ws => Some(TransportAddr::Ws {
TransportScheme::Ws => Some(Self::Ws {
scheme: TransportScheme::Ws,
host,
port,
}),
}
}
pub fn is_websocket(&self) -> bool {
matches!(self, TransportAddr::Ws { .. } | TransportAddr::Wss { .. })
pub const fn is_websocket(&self) -> bool {
matches!(self, Self::Ws { .. } | Self::Wss { .. })
}
pub fn is_http2(&self) -> bool {
matches!(self, TransportAddr::Http { .. } | TransportAddr::Https { .. })
pub const fn is_http2(&self) -> bool {
matches!(self, Self::Http { .. } | Self::Https { .. })
}
pub fn tls(&self) -> Option<&TlsClientConfig> {
pub const fn tls(&self) -> Option<&TlsClientConfig> {
match self {
TransportAddr::Wss { tls, .. } => Some(tls),
TransportAddr::Https { tls, .. } => Some(tls),
TransportAddr::Ws { .. } => None,
TransportAddr::Http { .. } => None,
Self::Wss { tls, .. } => Some(tls),
Self::Https { tls, .. } => Some(tls),
Self::Ws { .. } => None,
Self::Http { .. } => None,
}
}
pub fn host(&self) -> &Host {
pub const fn host(&self) -> &Host {
match self {
TransportAddr::Wss { host, .. } => host,
TransportAddr::Ws { host, .. } => host,
TransportAddr::Https { host, .. } => host,
TransportAddr::Http { host, .. } => host,
Self::Wss { host, .. } => host,
Self::Ws { host, .. } => host,
Self::Https { host, .. } => host,
Self::Http { host, .. } => host,
}
}
pub fn port(&self) -> u16 {
pub const fn port(&self) -> u16 {
match self {
TransportAddr::Wss { port, .. } => *port,
TransportAddr::Ws { port, .. } => *port,
TransportAddr::Https { port, .. } => *port,
TransportAddr::Http { port, .. } => *port,
Self::Wss { port, .. } => *port,
Self::Ws { port, .. } => *port,
Self::Https { port, .. } => *port,
Self::Http { port, .. } => *port,
}
}
pub fn scheme(&self) -> &TransportScheme {
pub const fn scheme(&self) -> &TransportScheme {
match self {
TransportAddr::Wss { scheme, .. } => scheme,
TransportAddr::Ws { scheme, .. } => scheme,
TransportAddr::Https { scheme, .. } => scheme,
TransportAddr::Http { scheme, .. } => scheme,
Self::Wss { scheme, .. } => scheme,
Self::Ws { scheme, .. } => scheme,
Self::Https { scheme, .. } => scheme,
Self::Http { scheme, .. } => scheme,
}
}
}
@ -257,8 +252,8 @@ pub enum TransportStream {
impl AsyncRead for TransportStream {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
match self.get_mut() {
TransportStream::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf),
TransportStream::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf),
Self::Plain(cnx) => Pin::new(cnx).poll_read(cx, buf),
Self::Tls(cnx) => Pin::new(cnx).poll_read(cx, buf),
}
}
}
@ -266,22 +261,22 @@ impl AsyncRead for TransportStream {
impl AsyncWrite for TransportStream {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
match self.get_mut() {
TransportStream::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf),
TransportStream::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf),
Self::Plain(cnx) => Pin::new(cnx).poll_write(cx, buf),
Self::Tls(cnx) => Pin::new(cnx).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match self.get_mut() {
TransportStream::Plain(cnx) => Pin::new(cnx).poll_flush(cx),
TransportStream::Tls(cnx) => Pin::new(cnx).poll_flush(cx),
Self::Plain(cnx) => Pin::new(cnx).poll_flush(cx),
Self::Tls(cnx) => Pin::new(cnx).poll_flush(cx),
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
match self.get_mut() {
TransportStream::Plain(cnx) => Pin::new(cnx).poll_shutdown(cx),
TransportStream::Tls(cnx) => Pin::new(cnx).poll_shutdown(cx),
Self::Plain(cnx) => Pin::new(cnx).poll_shutdown(cx),
Self::Tls(cnx) => Pin::new(cnx).poll_shutdown(cx),
}
}
@ -291,15 +286,15 @@ impl AsyncWrite for TransportStream {
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, Error>> {
match self.get_mut() {
TransportStream::Plain(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
TransportStream::Tls(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
Self::Plain(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
Self::Tls(cnx) => Pin::new(cnx).poll_write_vectored(cx, bufs),
}
}
fn is_write_vectored(&self) -> bool {
match &self {
TransportStream::Plain(cnx) => cnx.is_write_vectored(),
TransportStream::Tls(cnx) => cnx.is_write_vectored(),
Self::Plain(cnx) => cnx.is_write_vectored(),
Self::Tls(cnx) => cnx.is_write_vectored(),
}
}
}

View file

@ -37,7 +37,7 @@ enum TlsReloaderState {
impl TlsReloaderState {
fn fs_watcher(&self) -> &Mutex<RecommendedWatcher> {
match self {
TlsReloaderState::Empty => unreachable!(),
Self::Empty => unreachable!(),
Server(this) => &this.fs_watcher,
Client(this) => &this.fs_watcher,
}

View file

@ -24,7 +24,7 @@ pub struct Http2TunnelRead {
}
impl Http2TunnelRead {
pub fn new(inner: BodyStream<Incoming>) -> Self {
pub const fn new(inner: BodyStream<Incoming>) -> Self {
Self { inner }
}
}
@ -108,23 +108,24 @@ pub async fn connect(
}?;
// In http2 HOST header does not exist, it is explicitly set in the authority from the request uri
let (headers_file, authority) = if let Some(headers_file_path) = &client_cfg.http_headers_file {
let (host, headers) = headers_from_file(headers_file_path);
let host = if let Some((_, v)) = host {
match (client_cfg.remote_addr.scheme(), client_cfg.remote_addr.port()) {
(TransportScheme::Http, 80) | (TransportScheme::Https, 443) => {
Some(v.to_str().unwrap_or("").to_string())
let (headers_file, authority) = client_cfg
.http_headers_file
.as_ref()
.map_or((None, None), |headers_file_path| {
let (host, headers) = headers_from_file(headers_file_path);
let host = if let Some((_, v)) = host {
match (client_cfg.remote_addr.scheme(), client_cfg.remote_addr.port()) {
(TransportScheme::Http, 80) | (TransportScheme::Https, 443) => {
Some(v.to_str().unwrap_or("").to_string())
}
(_, port) => Some(format!("{}:{}", v.to_str().unwrap_or(""), port)),
}
(_, port) => Some(format!("{}:{}", v.to_str().unwrap_or(""), port)),
}
} else {
None
};
} else {
None
};
(Some(headers), host)
} else {
(None, None)
};
(Some(headers), host)
});
let mut req = Request::builder()
.method("POST")
@ -133,7 +134,7 @@ pub async fn connect(
client_cfg.remote_addr.scheme(),
authority
.as_deref()
.unwrap_or(client_cfg.http_header_host.to_str().unwrap_or("")),
.unwrap_or_else(|| client_cfg.http_header_host.to_str().unwrap_or("")),
&client_cfg.http_upgrade_path_prefix
))
.header(COOKIE, tunnel_to_jwt_token(request_id, dest_addr))

View file

@ -24,7 +24,7 @@ pub async fn propagate_local_to_remote(
// We do our own pin_mut! to avoid shadowing timeout and be able to reset it, on next loop iteration
// We reuse the future to avoid creating a timer in the tight loop
let frequency = ping_frequency.unwrap_or(Duration::from_secs(3600 * 24));
let start_at = Instant::now().checked_add(frequency).unwrap_or(Instant::now());
let start_at = Instant::now().checked_add(frequency).unwrap_or_else(Instant::now);
let timeout = tokio::time::interval_at(start_at, frequency);
let should_close = close_tx.closed().fuse();

View file

@ -38,8 +38,8 @@ pub enum TunnelReader {
impl TunnelRead for TunnelReader {
async fn copy(&mut self, writer: impl AsyncWrite + Unpin + Send) -> Result<(), std::io::Error> {
match self {
TunnelReader::Websocket(s) => s.copy(writer).await,
TunnelReader::Http2(s) => s.copy(writer).await,
Self::Websocket(s) => s.copy(writer).await,
Self::Http2(s) => s.copy(writer).await,
}
}
}
@ -52,29 +52,29 @@ pub enum TunnelWriter {
impl TunnelWrite for TunnelWriter {
fn buf_mut(&mut self) -> &mut BytesMut {
match self {
TunnelWriter::Websocket(s) => s.buf_mut(),
TunnelWriter::Http2(s) => s.buf_mut(),
Self::Websocket(s) => s.buf_mut(),
Self::Http2(s) => s.buf_mut(),
}
}
async fn write(&mut self) -> Result<(), std::io::Error> {
match self {
TunnelWriter::Websocket(s) => s.write().await,
TunnelWriter::Http2(s) => s.write().await,
Self::Websocket(s) => s.write().await,
Self::Http2(s) => s.write().await,
}
}
async fn ping(&mut self) -> Result<(), std::io::Error> {
match self {
TunnelWriter::Websocket(s) => s.ping().await,
TunnelWriter::Http2(s) => s.ping().await,
Self::Websocket(s) => s.ping().await,
Self::Http2(s) => s.ping().await,
}
}
async fn close(&mut self) -> Result<(), std::io::Error> {
match self {
TunnelWriter::Websocket(s) => s.close().await,
TunnelWriter::Http2(s) => s.close().await,
Self::Websocket(s) => s.close().await,
Self::Http2(s) => s.close().await,
}
}
}

View file

@ -97,7 +97,7 @@ pub struct WebsocketTunnelRead {
}
impl WebsocketTunnelRead {
pub fn new(ws: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>) -> Self {
pub const fn new(ws: WebSocketRead<ReadHalf<TokioIo<Upgraded>>>) -> Self {
Self { inner: ws }
}
}