Make the ping frequency from the server to the client configurable (#338)

This commit is contained in:
Ramses 2024-08-20 08:30:12 +02:00 committed by GitHub
parent 4432462087
commit 08936bb5e4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 17 additions and 15 deletions

View file

@ -191,7 +191,8 @@ struct Client {
#[arg(long, value_name = "USER[:PASS]", value_parser = parse_http_credentials, verbatim_doc_comment)] #[arg(long, value_name = "USER[:PASS]", value_parser = parse_http_credentials, verbatim_doc_comment)]
http_upgrade_credentials: Option<HeaderValue>, http_upgrade_credentials: Option<HeaderValue>,
/// Frequency at which the client will send websocket ping to the server. /// Frequency at which the client will send websocket pings to the server.
/// Set to zero to disable.
#[arg(long, value_name = "seconds", default_value = "30", value_parser = parse_duration_sec, verbatim_doc_comment)] #[arg(long, value_name = "seconds", default_value = "30", value_parser = parse_duration_sec, verbatim_doc_comment)]
websocket_ping_frequency_sec: Option<Duration>, websocket_ping_frequency_sec: Option<Duration>,
@ -277,7 +278,8 @@ struct Server {
socket_so_mark: Option<u32>, socket_so_mark: Option<u32>,
/// Frequency at which the server will send websocket ping to client. /// Frequency at which the server will send websocket ping to client.
#[arg(long, value_name = "seconds", value_parser = parse_duration_sec, verbatim_doc_comment)] /// Set to zero to disable.
#[arg(long, value_name = "seconds", default_value = "30", value_parser = parse_duration_sec, verbatim_doc_comment)]
websocket_ping_frequency_sec: Option<Duration>, websocket_ping_frequency_sec: Option<Duration>,
/// Enable the masking of websocket frames. Default is false /// Enable the masking of websocket frames. Default is false
@ -806,7 +808,10 @@ async fn main() -> anyhow::Result<()> {
http_headers_file: args.http_headers_file, http_headers_file: args.http_headers_file,
http_header_host: host_header, http_header_host: host_header,
timeout_connect: Duration::from_secs(10), timeout_connect: Duration::from_secs(10),
websocket_ping_frequency: args.websocket_ping_frequency_sec.unwrap_or(Duration::from_secs(30)), websocket_ping_frequency: args
.websocket_ping_frequency_sec
.or(Some(Duration::from_secs(30)))
.filter(|d| d.as_secs() > 0),
websocket_mask_frame: args.websocket_mask_frame, websocket_mask_frame: args.websocket_mask_frame,
dns_resolver: DnsResolver::new_from_urls( dns_resolver: DnsResolver::new_from_urls(
&args.dns_resolver, &args.dns_resolver,
@ -1121,7 +1126,10 @@ async fn main() -> anyhow::Result<()> {
let server_config = WsServerConfig { let server_config = WsServerConfig {
socket_so_mark: args.socket_so_mark, socket_so_mark: args.socket_so_mark,
bind: args.remote_addr.socket_addrs(|| Some(8080)).unwrap()[0], bind: args.remote_addr.socket_addrs(|| Some(8080)).unwrap()[0],
websocket_ping_frequency: args.websocket_ping_frequency_sec, websocket_ping_frequency: args
.websocket_ping_frequency_sec
.or(Some(Duration::from_secs(30)))
.filter(|d| d.as_secs() > 0),
timeout_connect: Duration::from_secs(10), timeout_connect: Duration::from_secs(10),
websocket_mask_frame: args.websocket_mask_frame, websocket_mask_frame: args.websocket_mask_frame,
tls: tls_config, tls: tls_config,

View file

@ -85,7 +85,7 @@ impl WsClient {
// Forward local tx to websocket tx // Forward local tx to websocket tx
let ping_frequency = self.config.websocket_ping_frequency; let ping_frequency = self.config.websocket_ping_frequency;
tokio::spawn( tokio::spawn(
super::super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, Some(ping_frequency)) super::super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, ping_frequency)
.instrument(Span::current()), .instrument(Span::current()),
); );
@ -197,13 +197,8 @@ impl WsClient {
let tunnel = async move { let tunnel = async move {
let ping_frequency = client.config.websocket_ping_frequency; let ping_frequency = client.config.websocket_ping_frequency;
tokio::spawn( tokio::spawn(
super::super::transport::io::propagate_local_to_remote( super::super::transport::io::propagate_local_to_remote(local_rx, ws_tx, close_tx, ping_frequency)
local_rx, .in_current_span(),
ws_tx,
close_tx,
Some(ping_frequency),
)
.in_current_span(),
); );
// Forward websocket rx to local rx // Forward websocket rx to local rx

View file

@ -22,7 +22,7 @@ pub struct WsClientConfig {
pub http_headers_file: Option<PathBuf>, pub http_headers_file: Option<PathBuf>,
pub http_header_host: HeaderValue, pub http_header_host: HeaderValue,
pub timeout_connect: Duration, pub timeout_connect: Duration,
pub websocket_ping_frequency: Duration, pub websocket_ping_frequency: Option<Duration>,
pub websocket_mask_frame: bool, pub websocket_mask_frame: bool,
pub http_proxy: Option<Url>, pub http_proxy: Option<Url>,
pub dns_resolver: DnsResolver, pub dns_resolver: DnsResolver,

View file

@ -11,7 +11,6 @@ use hyper::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL};
use hyper::{Request, Response}; use hyper::{Request, Response};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tracing::{error, warn, Instrument, Span}; use tracing::{error, warn, Instrument, Span};
@ -69,7 +68,7 @@ pub(super) async fn ws_server_upgrade(
local_rx, local_rx,
WebsocketTunnelWrite::new(ws_tx, pending_ops), WebsocketTunnelWrite::new(ws_tx, pending_ops),
close_tx, close_tx,
Some(Duration::from_secs(30)), server.config.websocket_ping_frequency,
) )
.await; .await;
} }