diff --git a/app/Main.hs b/app/Main.hs index 790ef48..338c0db 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -17,6 +17,7 @@ import System.Environment (getArgs, withArgs) import qualified Logger import Tunnel import Types +import Credentials import Control.Concurrent.Async as Async data WsTunnel = WsTunnel @@ -28,8 +29,6 @@ data WsTunnel = WsTunnel , udpTimeout :: Int , proxy :: String , soMark :: Int - , serverMode :: Bool - , restrictTo :: String , verbose :: Bool , quiet :: Bool , pathPrefix :: String @@ -38,6 +37,10 @@ data WsTunnel = WsTunnel , websocketPingFrequencySec :: Int , wsTunnelCredentials :: String , customHeaders :: [String] + , serverMode :: Bool + , restrictTo :: String + , tlsCertificate :: FilePath + , tlsKey :: FilePath } deriving (Show, Data, Typeable) data WsServerInfo = WsServerInfo @@ -88,15 +91,19 @@ cmdLine = WsTunnel , serverMode = def &= explicit &= name "server" &= help "Start a server that will forward traffic for you" &= groupname "Server options" - , restrictTo = def &= explicit &= name "r" &= name "restrictTo" + , restrictTo = def &= explicit &= name "r" &= name "restrictTo" &= groupname "Server options" &= help "Accept traffic to be forwarded only to this service" &= typ "HOST:PORT" + , tlsCertificate = def &= explicit &= name "tlsCertificate" &= groupname "Server options" + &= help "[optional] provide a custom tls certificate (.crt) that the server will use instead of the embeded one" &= typFile + , tlsKey = def &= explicit &= name "tlsKey" &= groupname "Server options" + &= help "[optional] provide a custom tls key (.key) that the server will use instead of the embeded one" &= typFile , verbose = def &= groupname "Common options" &= help "Print debug information" - , quiet = def &= help "Print only errors" + , quiet = def &= help "Print only errors" &= groupname "Common options" } &= summary ( "Use the websockets protocol to tunnel {TCP,UDP} traffic\n" ++ "wsTunnelClient <---> wsTunnelServer <---> RemoteHost\n" ++ "Use secure connection (wss://) to bypass proxies" ) - &= helpArg [explicit, name "help", name "h"] + &= helpArg [explicit, name "help", name "h", groupname "Common options"] toPort :: String -> Int @@ -212,7 +219,10 @@ runApp cfg serverInfo -- server mode | serverMode cfg = do putStrLn $ "Starting server with opts " <> tshow serverInfo - runServer (Main.useTls serverInfo) (Main.host serverInfo, fromIntegral $ Main.port serverInfo) (parseRestrictTo $ restrictTo cfg) + key <- if (Main.tlsKey cfg) /= mempty then readFile (Main.tlsKey cfg) else return Credentials.key + certificate <- if (Main.tlsCertificate cfg) /= mempty then readFile (Main.tlsCertificate cfg) else return Credentials.certificate + let tls = if Main.useTls serverInfo then Just (certificate, key) else Nothing + runServer tls (Main.host serverInfo, fromIntegral $ Main.port serverInfo) (parseRestrictTo $ restrictTo cfg) -- -L localToRemote tunnels | not . null $ localToRemote cfg = do diff --git a/src/Tunnel.hs b/src/Tunnel.hs index 26ee1e5..6af014e 100644 --- a/src/Tunnel.hs +++ b/src/Tunnel.hs @@ -188,11 +188,11 @@ runClient cfg@TunnelSettings{..} = do -- -- Server -- -runTlsTunnelingServer :: (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO () -runTlsTunnelingServer endPoint@(bindTo, portNumber) isAllowed = do +runTlsTunnelingServer :: (ByteString, ByteString) -> (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO () +runTlsTunnelingServer (tlsCert, tlsKey) endPoint@(bindTo, portNumber) isAllowed = do info $ "WAIT for TLS connection on " <> toStr endPoint - N.runTCPServerTLS (N.tlsConfigBS (fromString bindTo) (fromIntegral portNumber) Credentials.certificate Credentials.key) $ \sClient -> + N.runTCPServerTLS (N.tlsConfigBS (fromString bindTo) (fromIntegral portNumber) tlsCert tlsKey) $ \sClient -> runApp sClient WS.defaultConnectionOptions (serverEventLoop (N.appSockAddr sClient) isAllowed) info "SHUTDOWN server" @@ -244,8 +244,9 @@ serverEventLoop sClient isAllowed pendingConn = do SOCKS5 -> mempty -runServer :: Bool -> (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO () -runServer useTLS = if useTLS then runTlsTunnelingServer else runTunnelingServer +runServer :: Maybe (ByteString, ByteString) -> (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO () +runServer Nothing = runTunnelingServer +runServer (Just (tlsCert, tlsKey)) = runTlsTunnelingServer (tlsCert, tlsKey)