diff --git a/app/Main.hs b/app/Main.hs index 876334e..c6c2d36 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -33,6 +33,7 @@ data TunnelInfo = TunnelInfo , remotePort :: !Int } deriving (Show) + cmdLine :: WsTunnel cmdLine = WsTunnel { localToRemote = def &= explicit &= name "L" &= name "localToRemote" &= typ "[BIND:]PORT:HOST:PORT" @@ -76,8 +77,6 @@ parseTunnelInfo str = mk $ BC.unpack <$> BC.split ':' (BC.pack str) mk _ = error $ "Invalid tunneling information `" ++ str ++ "`, please use format [BIND:]PORT:HOST:PORT" - - main :: IO () main = do args <- getArgs @@ -85,13 +84,14 @@ main = do let serverInfo = parseServerInfo (WsServerInfo False "" 0) (wsTunnelServer cfg) + if serverMode cfg then putStrLn ("Starting server with opts " ++ show serverInfo ) - >> runServer (host serverInfo, port serverInfo) + >> runServer (host serverInfo, fromIntegral $ port serverInfo) else if not $ null (localToRemote cfg) then let (TunnelInfo lHost lPort rHost rPort) = parseTunnelInfo (localToRemote cfg) - in runClient (if udpMode cfg then UDP else TCP) (lHost, lPort) - (host serverInfo, port serverInfo) (rHost, rPort) + in runClient (if udpMode cfg then UDP else TCP) (lHost, (fromIntegral lPort)) + (host serverInfo, fromIntegral $ port serverInfo) (rHost, (fromIntegral rPort)) else return () diff --git a/src/Lib.hs b/src/Lib.hs index a56d56e..d6edb72 100644 --- a/src/Lib.hs +++ b/src/Lib.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE NoImplicitPrelude #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -25,6 +25,16 @@ import qualified Network.Socket.ByteString as N import qualified Network.WebSockets as WS +import qualified Data.ByteString.Lazy as BL +import Network.Connection (Connection, ConnectionParams (..), + TLSSettings (..), connectTo, + connectionGetChunk, connectionPut, + initConnectionContext) +import Network.Socket (HostName, PortNumber) +import Network.WebSockets (ClientApp, ConnectionOptions, + Headers, defaultConnectionOptions, + runClientWithStream) +import Network.WebSockets.Stream (makeStream) instance Hashable N.SockAddr where @@ -49,34 +59,40 @@ instance N.HasReadWrite UdpAppData where -runTCPServer :: (String, Int) -> (N.AppData -> IO ()) -> IO () +runTCPServer :: (HostName, PortNumber) -> (N.AppData -> IO ()) -> IO () runTCPServer (host, port) app = do putStrLn $ "WAIT for connection on " <> tshow host <> ":" <> tshow port - _ <- N.runTCPServer (N.serverSettingsTCP port (fromString host)) app + _ <- N.runTCPServer (N.serverSettingsTCP (fromIntegral port) (fromString host)) app putStrLn "CLOSE tunnel" -runTCPClient :: (String, Int) -> (N.AppData -> IO ()) -> IO () +runTCPClient :: (HostName, PortNumber) -> (N.AppData -> IO ()) -> IO () runTCPClient (host, port) app = do putStrLn $ "CONNECTING to " <> tshow host <> ":" <> tshow port - void $ N.runTCPClient (N.clientSettingsTCP port (BC.pack host)) app + void $ N.runTCPClient (N.clientSettingsTCP (fromIntegral port) (BC.pack host)) app putStrLn $ "CLOSE connection to " <> tshow host <> ":" <> tshow port -runUDPClient :: (String, Int) -> (UdpAppData -> IO ()) -> IO () +runUDPClient :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO () runUDPClient (host, port) app = do putStrLn $ "CONNECTING to " <> tshow host <> ":" <> tshow port - (socket, addrInfo) <- N.getSocketUDP host port - sem <- newEmptyMVar - let appData = UdpAppData (N.addrAddress addrInfo) sem (fst <$> N.recvFrom socket 4096) (\payload -> void $ N.sendTo socket payload (N.addrAddress addrInfo)) - app appData + bracket + (N.getSocketUDP host (fromIntegral port)) + (N.close . fst) + (\(socket, addrInfo) -> do + sem <- newEmptyMVar + app UdpAppData { appAddr = N.addrAddress addrInfo + , appSem = sem + , appRead = fst <$> N.recvFrom socket 4096 + , appWrite = \payload -> void $ N.sendTo socket payload (N.addrAddress addrInfo) + }) + putStrLn $ "CLOSE connection to " <> tshow host <> ":" <> tshow port -runUDPServer :: (String, Int) -> (UdpAppData -> IO ()) -> IO () +runUDPServer :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO () runUDPServer (host, port) app = do putStrLn $ "WAIT for datagrames on " <> tshow host <> ":" <> tshow port - sock <- N.bindPortUDP port (fromString host) notebook <- newMVar mempty - runEventLoop notebook sock + bracket (N.bindPortUDP (fromIntegral port) (fromString host)) N.close (runEventLoop notebook) putStrLn "CLOSE tunnel" where @@ -97,32 +113,32 @@ runUDPServer (host, port) app = do putMVar clientMapM (H.delete (appAddr appData') m) putStrLn "TIMEOUT connection" ) - (timeout (5 * 10^(6 :: Int)) . app) + (timeout (30 * 10^(6 :: Int)) . app) void $ async action runEventLoop clientMapM socket -runTunnelingClient :: Proto -> (String, Int) -> (String, Int) -> (WS.Connection -> IO ()) -> IO () +runTunnelingClient :: Proto -> (HostName, PortNumber) -> (HostName, PortNumber) -> (WS.Connection -> IO ()) -> IO () runTunnelingClient proto (wsHost, wsPort) (remoteHost, remotePort) app = do putStrLn $ "OPEN connection to " <> tshow remoteHost <> ":" <> tshow remotePort - void $ WS.runClient wsHost wsPort ("/" <> toLower (show proto) <> "/" <> remoteHost <> "/" <> show remotePort) app + void $ WS.runClient wsHost (fromIntegral wsPort) ("/" <> toLower (show proto) <> "/" <> remoteHost <> "/" <> show remotePort) app putStrLn $ "CLOSE connection to " <> tshow remoteHost <> ":" <> tshow remotePort -runTunnelingServer :: (String, Int) -> IO () +runTunnelingServer :: (HostName, PortNumber) -> IO () runTunnelingServer (host, port) = do putStrLn $ "WAIT for connection on " <> tshow host <> ":" <> tshow port - WS.runServer host port $ \pendingConn -> do + WS.runServer host (fromIntegral port) $ \pendingConn -> do let path = parsePath . WS.requestPath $ WS.pendingRequest pendingConn case path of Nothing -> putStrLn "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information" Just (!proto, !rhost, !rport) -> do conn <- WS.acceptRequest pendingConn case proto of - UDP -> runUDPClient (BC.unpack rhost, rport) (propagateRW conn) - TCP -> runTCPClient (BC.unpack rhost, rport) (propagateRW conn) + UDP -> runUDPClient (BC.unpack rhost, (fromIntegral rport)) (propagateRW conn) + TCP -> runTCPClient (BC.unpack rhost, (fromIntegral rport)) (propagateRW conn) putStrLn "CLOSE server" @@ -151,12 +167,51 @@ propagateWrites hTunnel hOther = void . tryAny $ do unless (null payload) (WS.sendBinaryData hTunnel payload >> propagateWrites hTunnel hOther) -runClient :: Proto -> (String, Int) -> (String, Int) -> (String, Int) -> IO () +runClient :: Proto -> (HostName, PortNumber) -> (HostName, PortNumber) -> (HostName, PortNumber) -> IO () runClient proto local wsServer remote = do - let out = runTunnelingClient proto wsServer remote + let out = runSecureClient proto wsServer remote case proto of UDP -> runUDPServer local (\hOther -> out (`propagateRW` hOther)) TCP -> runTCPServer local (\hOther -> out (`propagateRW` hOther)) -runServer :: (String, Int) -> IO () + +runServer :: (HostName, PortNumber) -> IO () runServer = runTunnelingServer + + +runSecureClient :: Proto -> (HostName, PortNumber) -> (HostName, PortNumber) -> (WS.Connection -> IO ()) -> IO () +runSecureClient proto (wsHost, wsPort) (remoteHost, remotePort) app = + let options = defaultConnectionOptions + headers = [] + in runSecureClientWith wsHost (fromIntegral wsPort) + ("/" <> toLower (show proto) <> "/" <> remoteHost <> "/" <> show remotePort) + options headers app + + +runSecureClientWith :: HostName -> PortNumber -> String -> ConnectionOptions -> Headers -> ClientApp a -> IO a +runSecureClientWith host port path options headers app = do + context <- initConnectionContext + connection <- connectTo context (connectionParams host port) + stream <- makeStream (reader connection) (writer connection) + runClientWithStream stream host path options headers app + +connectionParams :: HostName -> PortNumber -> ConnectionParams +connectionParams host port = ConnectionParams + { connectionHostname = host + , connectionPort = port + , connectionUseSecure = Just tlsSettings + , connectionUseSocks = Nothing + } + +tlsSettings :: TLSSettings +tlsSettings = TLSSettingsSimple + { settingDisableCertificateValidation = True + , settingDisableSession = False + , settingUseServerName = False + } + +reader :: Connection -> IO (Maybe ByteString) +reader connection = fmap Just (connectionGetChunk connection) + +writer :: Connection -> Maybe BL.ByteString -> IO () +writer connection = maybe (return ()) (connectionPut connection . toStrict) diff --git a/wstunnel.cabal b/wstunnel.cabal index 01d602e..db510e5 100644 --- a/wstunnel.cabal +++ b/wstunnel.cabal @@ -17,13 +17,14 @@ library hs-source-dirs: src exposed-modules: Lib build-depends: base >= 4.7 && < 5 - , websockets , classy-prelude , bytestring - , streaming-commons >= 0.1.3 - , network , async , unordered-containers + , network + , streaming-commons >= 0.1.3 + , connection >= 0.2 + , websockets default-language: Haskell2010