diff --git a/app/Main.hs b/app/Main.hs index c6c2d36..72c8826 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -90,7 +90,7 @@ main = do >> runServer (host serverInfo, fromIntegral $ port serverInfo) else if not $ null (localToRemote cfg) then let (TunnelInfo lHost lPort rHost rPort) = parseTunnelInfo (localToRemote cfg) - in runClient (if udpMode cfg then UDP else TCP) (lHost, (fromIntegral lPort)) + in runClient (useTls serverInfo) (if udpMode cfg then UDP else TCP) (lHost, (fromIntegral lPort)) (host serverInfo, fromIntegral $ port serverInfo) (rHost, (fromIntegral rPort)) else return () diff --git a/src/Lib.hs b/src/Lib.hs index d6edb72..7bb655c 100644 --- a/src/Lib.hs +++ b/src/Lib.hs @@ -123,7 +123,7 @@ runUDPServer (host, port) app = do runTunnelingClient :: Proto -> (HostName, PortNumber) -> (HostName, PortNumber) -> (WS.Connection -> IO ()) -> IO () runTunnelingClient proto (wsHost, wsPort) (remoteHost, remotePort) app = do putStrLn $ "OPEN connection to " <> tshow remoteHost <> ":" <> tshow remotePort - void $ WS.runClient wsHost (fromIntegral wsPort) ("/" <> toLower (show proto) <> "/" <> remoteHost <> "/" <> show remotePort) app + void $ WS.runClient wsHost (fromIntegral wsPort) (toPath proto remoteHost remotePort) app putStrLn $ "CLOSE connection to " <> tshow remoteHost <> ":" <> tshow remotePort @@ -137,8 +137,8 @@ runTunnelingServer (host, port) = do Just (!proto, !rhost, !rport) -> do conn <- WS.acceptRequest pendingConn case proto of - UDP -> runUDPClient (BC.unpack rhost, (fromIntegral rport)) (propagateRW conn) - TCP -> runTCPClient (BC.unpack rhost, (fromIntegral rport)) (propagateRW conn) + UDP -> runUDPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn) + TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn) putStrLn "CLOSE server" @@ -167,9 +167,9 @@ propagateWrites hTunnel hOther = void . tryAny $ do unless (null payload) (WS.sendBinaryData hTunnel payload >> propagateWrites hTunnel hOther) -runClient :: Proto -> (HostName, PortNumber) -> (HostName, PortNumber) -> (HostName, PortNumber) -> IO () -runClient proto local wsServer remote = do - let out = runSecureClient proto wsServer remote +runClient :: Bool -> Proto -> (HostName, PortNumber) -> (HostName, PortNumber) -> (HostName, PortNumber) -> IO () +runClient useTls proto local wsServer remote = do + let out = (if useTls then runTlsTunnelingClient else runTunnelingClient) proto wsServer remote case proto of UDP -> runUDPServer local (\hOther -> out (`propagateRW` hOther)) TCP -> runTCPServer local (\hOther -> out (`propagateRW` hOther)) @@ -179,39 +179,34 @@ runServer :: (HostName, PortNumber) -> IO () runServer = runTunnelingServer -runSecureClient :: Proto -> (HostName, PortNumber) -> (HostName, PortNumber) -> (WS.Connection -> IO ()) -> IO () -runSecureClient proto (wsHost, wsPort) (remoteHost, remotePort) app = - let options = defaultConnectionOptions - headers = [] - in runSecureClientWith wsHost (fromIntegral wsPort) - ("/" <> toLower (show proto) <> "/" <> remoteHost <> "/" <> show remotePort) - options headers app +runTlsTunnelingClient :: Proto -> (HostName, PortNumber) -> (HostName, PortNumber) -> (WS.Connection -> IO ()) -> IO () +runTlsTunnelingClient proto (wsHost, wsPort) (remoteHost, remotePort) app = do + context <- initConnectionContext + connection <- connectTo context (connectionParams wsHost (fromIntegral wsPort)) + stream <- makeStream (reader connection) (writer connection) + runClientWithStream stream wsHost (toPath proto remoteHost remotePort) defaultConnectionOptions [] app -runSecureClientWith :: HostName -> PortNumber -> String -> ConnectionOptions -> Headers -> ClientApp a -> IO a -runSecureClientWith host port path options headers app = do - context <- initConnectionContext - connection <- connectTo context (connectionParams host port) - stream <- makeStream (reader connection) (writer connection) - runClientWithStream stream host path options headers app - connectionParams :: HostName -> PortNumber -> ConnectionParams connectionParams host port = ConnectionParams - { connectionHostname = host - , connectionPort = port - , connectionUseSecure = Just tlsSettings - , connectionUseSocks = Nothing - } + { connectionHostname = host + , connectionPort = port + , connectionUseSecure = Just tlsSettings + , connectionUseSocks = Nothing + } tlsSettings :: TLSSettings tlsSettings = TLSSettingsSimple - { settingDisableCertificateValidation = True - , settingDisableSession = False - , settingUseServerName = False - } + { settingDisableCertificateValidation = True + , settingDisableSession = False + , settingUseServerName = False + } reader :: Connection -> IO (Maybe ByteString) reader connection = fmap Just (connectionGetChunk connection) writer :: Connection -> Maybe BL.ByteString -> IO () writer connection = maybe (return ()) (connectionPut connection . toStrict) + +toPath :: Proto -> HostName -> PortNumber -> String +toPath proto remoteHost remotePort = "/" <> toLower (show proto) <> "/" <> remoteHost <> "/" <> show remotePort