diff --git a/src/Lib.hs b/src/Lib.hs index 5398fa1..47d2e6f 100644 --- a/src/Lib.hs +++ b/src/Lib.hs @@ -12,8 +12,9 @@ module Lib ) where import ClassyPrelude -import Control.Concurrent.Async (async, asyncWithUnmask, race_) +import Control.Concurrent.Async (async, race_) import qualified Data.HashMap.Strict as H +import Data.Maybe (fromJust) import System.Timeout (timeout) import qualified Data.ByteString.Char8 as BC @@ -76,9 +77,9 @@ runUDPClient (host, port) app = do putStrLn $ "CONNECTING to " <> tshow host <> ":" <> tshow port bracket (N.getSocketUDP host (fromIntegral port)) (N.close . fst) $ \(socket, addrInfo) -> do sem <- newEmptyMVar - app UdpAppData { appAddr = N.addrAddress addrInfo - , appSem = sem - , appRead = fst <$> N.recvFrom socket 4096 + app UdpAppData { appAddr = N.addrAddress addrInfo + , appSem = sem + , appRead = fst <$> N.recvFrom socket 4096 , appWrite = \payload -> void $ N.sendTo socket payload (N.addrAddress addrInfo) } @@ -88,10 +89,7 @@ runUDPServer :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO () runUDPServer (host, port) app = do putStrLn $ "WAIT for datagrames on " <> tshow host <> ":" <> tshow port clientsCtx <- newIORef mempty - void $ bracket - (N.bindPortUDP (fromIntegral port) (fromString host)) - N.close - (runEventLoop clientsCtx) + void $ bracket (N.bindPortUDP (fromIntegral port) (fromString host)) N.close (runEventLoop clientsCtx) putStrLn "CLOSE tunnel" where @@ -100,10 +98,10 @@ runUDPServer (host, port) app = do addNewClient clientsCtx socket addr payload = do sem <- newMVar payload let appData = UdpAppData { appAddr = addr - , appSem = sem - , appRead = takeMVar sem - , appWrite = \payload' -> void $ N.sendTo socket payload' addr - } + , appSem = sem + , appRead = takeMVar sem + , appWrite = \payload' -> void $ N.sendTo socket payload' addr + } void $ atomicModifyIORef' clientsCtx (\clients -> (H.insert addr appData clients, ())) return appData @@ -122,7 +120,7 @@ runUDPServer (host, port) app = do case clientCtx of Just clientCtx' -> pushDataToClient clientCtx' payload - _ -> void $ async $ bracket + _ -> void . async $ bracket (addNewClient clientsCtx socket addr payload) (removeClient clientsCtx) (timeout (30 * 10^(6 :: Int)) . app) @@ -135,32 +133,23 @@ runTunnelingClient proto (wsHost, wsPort) (remoteHost, remotePort) app = do putStrLn $ "CLOSE connection to " <> tshow remoteHost <> ":" <> tshow remotePort -runApp :: N.Socket - -> WS.ConnectionOptions - -> WS.ServerApp - -> IO () -runApp socket opts app = - bracket - (WS.makePendingConnection socket opts) - (WS.close . WS.pendingStream) - app runTunnelingServer :: (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO () runTunnelingServer (host, port) isAllowed = do putStrLn $ "WAIT for connection on " <> tshow host <> ":" <> tshow port - N.withSocketsDo $ bracket (WS.makeListenSocket host (fromIntegral port)) N.sClose (\sock -> - forever $ mask_ $ do - (conn, _) <- N.accept sock - void $ asyncWithUnmask $ \unmask -> - finally (unmask $ runApp conn WS.defaultConnectionOptions runEventLoop) (N.sClose conn) - ) + void $ N.runTCPServer (N.serverSettingsTCP (fromIntegral port) (fromString host)) $ \sClient -> + runApp (fromJust $ N.appRawSocket sClient) WS.defaultConnectionOptions runEventLoop putStrLn "CLOSE server" where + runApp :: N.Socket -> WS.ConnectionOptions -> WS.ServerApp -> IO () + runApp socket opts = bracket (WS.makePendingConnection socket opts) + (\conn -> catch (WS.close $ WS.pendingStream conn) (\(_ :: SomeException) -> return ())) + runEventLoop pendingConn = do - let path = parsePath . WS.requestPath $ WS.pendingRequest pendingConn + let path = fromPath . WS.requestPath $ WS.pendingRequest pendingConn case path of Nothing -> putStrLn "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information" Just (!proto, !rhost, !rport) -> @@ -175,14 +164,6 @@ runTunnelingServer (host, port) isAllowed = do TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn) - parsePath :: ByteString -> Maybe (Proto, ByteString, Int) - parsePath path = let rets = BC.split '/' . BC.drop 1 $ path - in do - guard (length rets == 3) - let [protocol, h, prt] = rets - prt' <- readMay . BC.unpack $ prt :: Maybe Int - proto <- readMay . toUpper . BC.unpack $ protocol :: Maybe Proto - return (proto, h, prt') propagateRW :: N.HasReadWrite a => WS.Connection -> a -> IO () @@ -190,7 +171,7 @@ propagateRW hTunnel hOther = myTry $ race_ (propagateReads hTunnel hOther) (propagateWrites hTunnel hOther) myTry :: IO () -> IO () -myTry f = void $ catch f (\(e :: SomeException) -> print e) +myTry f = void $ catch f (\(_ :: SomeException) -> return ()) propagateReads :: N.HasReadWrite a => WS.Connection -> a -> IO () propagateReads hTunnel hOther = myTry (forever $ WS.receiveData hTunnel >>= N.appWrite hOther) @@ -244,5 +225,15 @@ reader connection = fmap Just (connectionGetChunk connection) writer :: Connection -> Maybe LByteString -> IO () writer connection = maybe (return ()) (connectionPut connection . toStrict) + toPath :: Proto -> HostName -> PortNumber -> String toPath proto remoteHost remotePort = "/" <> toLower (show proto) <> "/" <> remoteHost <> "/" <> show remotePort + +fromPath :: ByteString -> Maybe (Proto, ByteString, Int) +fromPath path = let rets = BC.split '/' . BC.drop 1 $ path + in do + guard (length rets == 3) + let [protocol, h, prt] = rets + prt' <- readMay . BC.unpack $ prt :: Maybe Int + proto <- readMay . toUpper . BC.unpack $ protocol :: Maybe Proto + return (proto, h, prt')