cleanup code
This commit is contained in:
parent
dfb20ff55f
commit
f19efa37f1
1 changed files with 131 additions and 60 deletions
|
@ -1,7 +1,7 @@
|
||||||
use ahash::{HashMap, HashMapExt};
|
use ahash::{HashMap, HashMapExt};
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
use futures_util::{pin_mut, Stream, StreamExt};
|
use futures_util::{pin_mut, FutureExt, Stream, StreamExt};
|
||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
|
@ -30,9 +30,9 @@ use tokio::sync::{mpsc, oneshot};
|
||||||
use tracing::{error, info, span, warn, Instrument, Level, Span};
|
use tracing::{error, info, span, warn, Instrument, Level, Span};
|
||||||
use url::Host;
|
use url::Host;
|
||||||
|
|
||||||
async fn from_query(
|
async fn run_tunnel(
|
||||||
server_config: &WsServerConfig,
|
server_config: &WsServerConfig,
|
||||||
query: &str,
|
jwt: TokenData<JwtTunnelConfig>,
|
||||||
) -> anyhow::Result<(
|
) -> anyhow::Result<(
|
||||||
LocalProtocol,
|
LocalProtocol,
|
||||||
Host,
|
Host,
|
||||||
|
@ -40,30 +40,6 @@ async fn from_query(
|
||||||
Pin<Box<dyn AsyncRead + Send>>,
|
Pin<Box<dyn AsyncRead + Send>>,
|
||||||
Pin<Box<dyn AsyncWrite + Send>>,
|
Pin<Box<dyn AsyncWrite + Send>>,
|
||||||
)> {
|
)> {
|
||||||
let jwt: TokenData<JwtTunnelConfig> = match query.split_once('=') {
|
|
||||||
Some(("bearer", jwt)) => {
|
|
||||||
let (validation, decode_key) = JWT_DECODE.deref();
|
|
||||||
match jsonwebtoken::decode(jwt, decode_key, validation) {
|
|
||||||
Ok(jwt) => jwt,
|
|
||||||
err => {
|
|
||||||
error!("error while decoding jwt for tunnel info {:?}", err);
|
|
||||||
return Err(anyhow::anyhow!("Invalid upgrade request"));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_err => return Err(anyhow::anyhow!("Invalid upgrade request")),
|
|
||||||
};
|
|
||||||
|
|
||||||
Span::current().record("id", &jwt.claims.id);
|
|
||||||
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
|
|
||||||
if let Some(allowed_dests) = &server_config.restrict_to {
|
|
||||||
let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp);
|
|
||||||
if allowed_dests.iter().any(|dest| dest == &requested_dest).not() {
|
|
||||||
warn!("Rejecting connection with not allowed destination: {}", requested_dest);
|
|
||||||
return Err(anyhow::anyhow!("Invalid upgrade request"));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
match jwt.claims.p {
|
match jwt.claims.p {
|
||||||
LocalProtocol::Udp { timeout, .. } => {
|
LocalProtocol::Udp { timeout, .. } => {
|
||||||
let host = Host::parse(&jwt.claims.r)?;
|
let host = Host::parse(&jwt.claims.r)?;
|
||||||
|
@ -199,24 +175,29 @@ where
|
||||||
Ok(cnx)
|
Ok(cnx)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn server_upgrade(
|
#[inline]
|
||||||
server_config: Arc<WsServerConfig>,
|
fn extract_x_forwarded_for(req: &Request<Incoming>) -> Result<Option<&str>, Response<String>> {
|
||||||
mut req: Request<Incoming>,
|
let Some(x_forward_for) = req.headers().get("X-Forwarded-For") else {
|
||||||
) -> Result<Response<String>, anyhow::Error> {
|
return Ok(None);
|
||||||
if let Some(x) = req.headers().get("X-Forwarded-For") {
|
};
|
||||||
info!("Request X-Forwarded-For: {:?}", x);
|
|
||||||
Span::current().record("forwarded_for", x.to_str().unwrap_or_default());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
Ok(Some(x_forward_for.to_str().unwrap_or_default()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn validate_url(
|
||||||
|
req: &Request<Incoming>,
|
||||||
|
path_restriction_prefix: &Option<Vec<String>>,
|
||||||
|
) -> Result<(), Response<String>> {
|
||||||
if !req.uri().path().ends_with("/events") {
|
if !req.uri().path().ends_with("/events") {
|
||||||
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
|
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
|
||||||
return Ok(http::Response::builder()
|
return Err(http::Response::builder()
|
||||||
.status(StatusCode::BAD_REQUEST)
|
.status(StatusCode::BAD_REQUEST)
|
||||||
.body("Invalid upgrade request".into())
|
.body("Invalid upgrade request".into())
|
||||||
.unwrap());
|
.unwrap());
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(paths_prefix) = &server_config.restrict_http_upgrade_path_prefix {
|
if let Some(paths_prefix) = &path_restriction_prefix {
|
||||||
let path = req.uri().path();
|
let path = req.uri().path();
|
||||||
let min_len = min(path.len(), 1);
|
let min_len = min(path.len(), 1);
|
||||||
let mut max_len = 0;
|
let mut max_len = 0;
|
||||||
|
@ -228,34 +209,121 @@ async fn server_upgrade(
|
||||||
|| !path[max_len..].starts_with('/')
|
|| !path[max_len..].starts_with('/')
|
||||||
{
|
{
|
||||||
warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri());
|
warn!("Rejecting connection with bad path prefix in upgrade request: {}", req.uri());
|
||||||
return Ok(http::Response::builder()
|
return Err(http::Response::builder()
|
||||||
.status(StatusCode::BAD_REQUEST)
|
.status(StatusCode::BAD_REQUEST)
|
||||||
.body("Invalid upgrade request".to_string())
|
.body("Invalid upgrade request".to_string())
|
||||||
.unwrap());
|
.unwrap());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let (protocol, dest, port, local_rx, local_tx) =
|
Ok(())
|
||||||
match from_query(&server_config, req.uri().query().unwrap_or_default()).await {
|
}
|
||||||
Ok(ret) => ret,
|
|
||||||
Err(err) => {
|
|
||||||
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
|
||||||
return Ok(http::Response::builder()
|
|
||||||
.status(StatusCode::BAD_REQUEST)
|
|
||||||
.body("Invalid upgrade request".to_string())
|
|
||||||
.unwrap());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn extract_tunnel_info(req: &Request<Incoming>) -> Result<TokenData<JwtTunnelConfig>, Response<String>> {
|
||||||
|
let jwt: TokenData<JwtTunnelConfig> = match req.uri().query().unwrap_or_default().split_once('=') {
|
||||||
|
Some(("bearer", jwt)) => {
|
||||||
|
let (validation, decode_key) = JWT_DECODE.deref();
|
||||||
|
match jsonwebtoken::decode(jwt, decode_key, validation) {
|
||||||
|
Ok(jwt) => jwt,
|
||||||
|
err => {
|
||||||
|
error!("error while decoding jwt for tunnel info {:?}", err);
|
||||||
|
return Err(http::Response::builder()
|
||||||
|
.status(StatusCode::BAD_REQUEST)
|
||||||
|
.body("Invalid upgrade request".to_string())
|
||||||
|
.unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err => {
|
||||||
|
error!("Missing jwt tunnel config from request {:?}", err);
|
||||||
|
return Err(http::Response::builder()
|
||||||
|
.status(StatusCode::BAD_REQUEST)
|
||||||
|
.body("Invalid upgrade request".to_string())
|
||||||
|
.unwrap());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(jwt)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
fn validate_destination(
|
||||||
|
_req: &Request<Incoming>,
|
||||||
|
jwt: &TokenData<JwtTunnelConfig>,
|
||||||
|
destination_restriction: &Option<Vec<String>>,
|
||||||
|
) -> Result<(), Response<String>> {
|
||||||
|
let Some(allowed_dests) = &destination_restriction else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
let requested_dest = format!("{}:{}", jwt.claims.r, jwt.claims.rp);
|
||||||
|
if allowed_dests.iter().any(|dest| dest == &requested_dest).not() {
|
||||||
|
warn!("Rejecting connection with not allowed destination: {}", requested_dest);
|
||||||
|
return Err(http::Response::builder()
|
||||||
|
.status(StatusCode::BAD_REQUEST)
|
||||||
|
.body("Invalid upgrade request".to_string())
|
||||||
|
.unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn server_upgrade(server_config: Arc<WsServerConfig>, mut req: Request<Incoming>) -> Response<String> {
|
||||||
|
if !fastwebsockets::upgrade::is_upgrade_request(&req) {
|
||||||
|
warn!("Rejecting connection with bad upgrade request: {}", req.uri());
|
||||||
|
return http::Response::builder()
|
||||||
|
.status(StatusCode::BAD_REQUEST)
|
||||||
|
.body("Invalid upgrade request".to_string())
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
match extract_x_forwarded_for(&req) {
|
||||||
|
Ok(Some(x_forward_for)) => {
|
||||||
|
info!("Request X-Forwarded-For: {:?}", x_forward_for);
|
||||||
|
Span::current().record("forwarded_for", x_forward_for);
|
||||||
|
}
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(err) => return err,
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(err) = validate_url(&req, &server_config.restrict_http_upgrade_path_prefix) {
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
let jwt = match extract_tunnel_info(&req) {
|
||||||
|
Ok(jwt) => jwt,
|
||||||
|
Err(err) => return err,
|
||||||
|
};
|
||||||
|
|
||||||
|
Span::current().record("id", &jwt.claims.id);
|
||||||
|
Span::current().record("remote", format!("{}:{}", jwt.claims.r, jwt.claims.rp));
|
||||||
|
|
||||||
|
if let Err(err) = validate_destination(&req, &jwt, &server_config.restrict_to) {
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
let tunnel = match run_tunnel(&server_config, jwt).await {
|
||||||
|
Ok(ret) => ret,
|
||||||
|
Err(err) => {
|
||||||
|
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
||||||
|
return http::Response::builder()
|
||||||
|
.status(StatusCode::BAD_REQUEST)
|
||||||
|
.body("Invalid upgrade request".to_string())
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let (protocol, dest, port, local_rx, local_tx) = tunnel;
|
||||||
info!("connected to {:?} {:?} {:?}", protocol, dest, port);
|
info!("connected to {:?} {:?} {:?}", protocol, dest, port);
|
||||||
let (mut response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) {
|
let (mut response, fut) = match fastwebsockets::upgrade::upgrade(&mut req) {
|
||||||
Ok(ret) => ret,
|
Ok(ret) => ret,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
warn!("Rejecting connection with bad upgrade request: {} {}", err, req.uri());
|
||||||
return Ok(http::Response::builder()
|
return http::Response::builder()
|
||||||
.status(StatusCode::BAD_REQUEST)
|
.status(StatusCode::BAD_REQUEST)
|
||||||
.body(format!("Invalid upgrade request: {:?}", err))
|
.body(format!("Invalid upgrade request: {:?}", err))
|
||||||
.unwrap());
|
.unwrap();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -279,23 +347,26 @@ async fn server_upgrade(
|
||||||
);
|
);
|
||||||
|
|
||||||
if protocol == LocalProtocol::ReverseSocks5 {
|
if protocol == LocalProtocol::ReverseSocks5 {
|
||||||
response.headers_mut().insert(
|
let Ok(header_val) = HeaderValue::from_str(
|
||||||
COOKIE,
|
&base64::engine::general_purpose::STANDARD.encode(format!("fake://{}:{}", dest, port)),
|
||||||
HeaderValue::from_str(
|
) else {
|
||||||
&base64::engine::general_purpose::STANDARD.encode(format!("fake://{}:{}", dest, port)),
|
error!("Bad headervalue for reverse socks5: {} {}", dest, port);
|
||||||
)?,
|
return http::Response::builder()
|
||||||
);
|
.status(StatusCode::BAD_REQUEST)
|
||||||
|
.body("Invalid upgrade request".to_string())
|
||||||
|
.unwrap();
|
||||||
|
};
|
||||||
|
response.headers_mut().insert(COOKIE, header_val);
|
||||||
}
|
}
|
||||||
|
|
||||||
let response = Response::from_parts(response.into_parts().0, "".to_string());
|
Response::from_parts(response.into_parts().0, "".to_string())
|
||||||
Ok(response)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> {
|
pub async fn run_server(server_config: Arc<WsServerConfig>) -> anyhow::Result<()> {
|
||||||
info!("Starting wstunnel server listening on {}", server_config.bind);
|
info!("Starting wstunnel server listening on {}", server_config.bind);
|
||||||
|
|
||||||
let config = server_config.clone();
|
let config = server_config.clone();
|
||||||
let upgrade_fn = move |req: Request<Incoming>| server_upgrade(config.clone(), req);
|
let upgrade_fn = move |req: Request<Incoming>| server_upgrade(config.clone(), req).map::<anyhow::Result<_>, _>(Ok);
|
||||||
|
|
||||||
let listener = TcpListener::bind(&server_config.bind).await?;
|
let listener = TcpListener::bind(&server_config.bind).await?;
|
||||||
let tls_acceptor = if let Some(tls) = &server_config.tls {
|
let tls_acceptor = if let Some(tls) = &server_config.tls {
|
||||||
|
|
Loading…
Reference in a new issue