diff --git a/src/main.rs b/src/main.rs index de077b7..21bd7aa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -191,7 +191,8 @@ struct Client { #[arg(long, value_name = "USER[:PASS]", value_parser = parse_http_credentials, verbatim_doc_comment)] http_upgrade_credentials: Option, - /// Frequency at which the client will send websocket ping to the server. + /// Frequency at which the client will send websocket pings to the server. + /// Set to zero to disable. #[arg(long, value_name = "seconds", default_value = "30", value_parser = parse_duration_sec, verbatim_doc_comment)] websocket_ping_frequency_sec: Option, @@ -277,7 +278,8 @@ struct Server { socket_so_mark: Option, /// Frequency at which the server will send websocket ping to client. - #[arg(long, value_name = "seconds", value_parser = parse_duration_sec, verbatim_doc_comment)] + /// Set to zero to disable. + #[arg(long, value_name = "seconds", default_value = "30", value_parser = parse_duration_sec, verbatim_doc_comment)] websocket_ping_frequency_sec: Option, /// Enable the masking of websocket frames. Default is false @@ -806,7 +808,10 @@ async fn main() -> anyhow::Result<()> { http_headers_file: args.http_headers_file, http_header_host: host_header, timeout_connect: Duration::from_secs(10), - websocket_ping_frequency: args.websocket_ping_frequency_sec.unwrap_or(Duration::from_secs(30)), + websocket_ping_frequency: args + .websocket_ping_frequency_sec + .or(Some(Duration::from_secs(30))) + .filter(|d| d.as_secs() > 0), websocket_mask_frame: args.websocket_mask_frame, dns_resolver: DnsResolver::new_from_urls( &args.dns_resolver, @@ -1121,7 +1126,10 @@ async fn main() -> anyhow::Result<()> { let server_config = WsServerConfig { socket_so_mark: args.socket_so_mark, bind: args.remote_addr.socket_addrs(|| Some(8080)).unwrap()[0], - websocket_ping_frequency: args.websocket_ping_frequency_sec, + websocket_ping_frequency: args + .websocket_ping_frequency_sec + .or(Some(Duration::from_secs(30))) + .filter(|d| d.as_secs() > 0), timeout_connect: Duration::from_secs(10), websocket_mask_frame: args.websocket_mask_frame, tls: tls_config, diff --git a/src/tunnel/client/client.rs b/src/tunnel/client/client.rs index 09ee545..0f6172d 100644 --- a/src/tunnel/client/client.rs +++ b/src/tunnel/client/client.rs @@ -85,7 +85,7 @@ impl WsClient { // Forward local tx to websocket tx let ping_frequency = self.config.websocket_ping_frequency; tokio::spawn( - super::super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency)) + super::super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, ping_frequency) .instrument(Span::current()), ); @@ -197,13 +197,8 @@ impl WsClient { let tunnel = async move { let ping_frequency = client.config.websocket_ping_frequency; tokio::spawn( - super::super::transport::io::propagate_local_to_remote( - local_rx, - ws_tx, - close_tx, - Some(ping_frequency), - ) - .in_current_span(), + super::super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, ping_frequency) + .in_current_span(), ); // Forward websocket rx to local rx diff --git a/src/tunnel/client/config.rs b/src/tunnel/client/config.rs index 4d44c21..6ec9cb1 100644 --- a/src/tunnel/client/config.rs +++ b/src/tunnel/client/config.rs @@ -22,7 +22,7 @@ pub struct WsClientConfig { pub http_headers_file: Option, pub http_header_host: HeaderValue, pub timeout_connect: Duration, - pub websocket_ping_frequency: Duration, + pub websocket_ping_frequency: Option, pub websocket_mask_frame: bool, pub http_proxy: Option, pub dns_resolver: DnsResolver, diff --git a/src/tunnel/server/handler_websocket.rs b/src/tunnel/server/handler_websocket.rs index d3bdc27..1c749ae 100644 --- a/src/tunnel/server/handler_websocket.rs +++ b/src/tunnel/server/handler_websocket.rs @@ -11,7 +11,6 @@ use hyper::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL}; use hyper::{Request, Response}; use std::net::SocketAddr; use std::sync::Arc; -use std::time::Duration; use tokio::sync::oneshot; use tracing::{error, warn, Instrument, Span}; @@ -69,7 +68,7 @@ pub(super) async fn ws_server_upgrade( local_rx, WebsocketTunnelWrite::new(ws_tx, pending_ops), close_tx, - Some(Duration::from_secs(30)), + server.config.websocket_ping_frequency, ) .await; }