From a315f59673c61c6834ce5961e9ece0de06fc593b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Er=C3=A8be?= Date: Tue, 17 May 2016 16:01:03 +0200 Subject: [PATCH] Change Mvar to IORef + Signature Better handling of exceptions --- src/Lib.hs | 104 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 66 insertions(+), 38 deletions(-) diff --git a/src/Lib.hs b/src/Lib.hs index 39c90b8..5398fa1 100644 --- a/src/Lib.hs +++ b/src/Lib.hs @@ -12,26 +12,28 @@ module Lib ) where import ClassyPrelude -import Control.Concurrent.Async (async, race_) -import qualified Data.HashMap.Strict as H -import System.Timeout (timeout) +import Control.Concurrent.Async (async, asyncWithUnmask, race_) +import qualified Data.HashMap.Strict as H +import System.Timeout (timeout) -import qualified Data.ByteString as B -import qualified Data.ByteString.Char8 as BC +import qualified Data.ByteString.Char8 as BC -import qualified Data.Streaming.Network as N -import Network.Socket (HostName, PortNumber) -import qualified Network.Socket as N hiding (recv, recvFrom, send, - sendTo) -import qualified Network.Socket.ByteString as N +import qualified Data.Streaming.Network as N +import Network.Socket (HostName, PortNumber) +import qualified Network.Socket as N hiding (recv, recvFrom, + send, sendTo) +import qualified Network.Socket.ByteString as N -import qualified Network.WebSockets as WS -import qualified Network.WebSockets.Stream as WS +import qualified Network.WebSockets as WS +import qualified Network.WebSockets.Connection as WS +import qualified Network.WebSockets.Stream as WS -import Network.Connection (Connection, ConnectionParams (..), - TLSSettings (..), connectTo, - connectionGetChunk, connectionPut, - initConnectionContext) +import Network.Connection (Connection, + ConnectionParams (..), + TLSSettings (..), connectTo, + connectionGetChunk, + connectionPut, + initConnectionContext) instance Hashable N.SockAddr where @@ -85,7 +87,7 @@ runUDPClient (host, port) app = do runUDPServer :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO () runUDPServer (host, port) app = do putStrLn $ "WAIT for datagrames on " <> tshow host <> ":" <> tshow port - clientsCtx <- newMVar mempty + clientsCtx <- newIORef mempty void $ bracket (N.bindPortUDP (fromIntegral port) (fromString host)) N.close @@ -93,6 +95,8 @@ runUDPServer (host, port) app = do putStrLn "CLOSE tunnel" where + addNewClient :: IORef (H.HashMap N.SockAddr UdpAppData) -> N.Socket -> N.SockAddr -> ByteString + -> IO UdpAppData addNewClient clientsCtx socket addr payload = do sem <- newMVar payload let appData = UdpAppData { appAddr = addr @@ -100,18 +104,21 @@ runUDPServer (host, port) app = do , appRead = takeMVar sem , appWrite = \payload' -> void $ N.sendTo socket payload' addr } - void $ withMVar clientsCtx (return . H.insert addr appData) + void $ atomicModifyIORef' clientsCtx (\clients -> (H.insert addr appData clients, ())) return appData + removeClient :: IORef (H.HashMap N.SockAddr UdpAppData) -> UdpAppData -> IO () removeClient clientsCtx clientCtx = do - void $ withMVar clientsCtx (return . H.delete (appAddr clientCtx)) + void $ atomicModifyIORef' clientsCtx (\clients -> (H.delete (appAddr clientCtx) clients, ())) putStrLn "TIMEOUT connection" + pushDataToClient :: UdpAppData -> ByteString -> IO () pushDataToClient clientCtx = putMVar (appSem clientCtx) + runEventLoop :: IORef (H.HashMap N.SockAddr UdpAppData) -> N.Socket -> IO () runEventLoop clientsCtx socket = forever $ do (payload, addr) <- N.recvFrom socket 4096 - clientCtx <- H.lookup addr <$> readMVar clientsCtx + clientCtx <- H.lookup addr <$> readIORef clientsCtx case clientCtx of Just clientCtx' -> pushDataToClient clientCtx' payload @@ -128,27 +135,46 @@ runTunnelingClient proto (wsHost, wsPort) (remoteHost, remotePort) app = do putStrLn $ "CLOSE connection to " <> tshow remoteHost <> ":" <> tshow remotePort +runApp :: N.Socket + -> WS.ConnectionOptions + -> WS.ServerApp + -> IO () +runApp socket opts app = + bracket + (WS.makePendingConnection socket opts) + (WS.close . WS.pendingStream) + app + runTunnelingServer :: (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO () runTunnelingServer (host, port) isAllowed = do putStrLn $ "WAIT for connection on " <> tshow host <> ":" <> tshow port - WS.runServer host (fromIntegral port) $ \pendingConn -> do - let path = parsePath . WS.requestPath $ WS.pendingRequest pendingConn - case path of - Nothing -> putStrLn "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information" - Just (!proto, !rhost, !rport) -> - if not $ isAllowed (rhost, rport) - then do - putStrLn "Rejecting tunneling" - WS.rejectRequest pendingConn "Restriction is on, You cannot request this tunneling" - else do - conn <- WS.acceptRequest pendingConn - case proto of - UDP -> runUDPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn) - TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn) + + N.withSocketsDo $ bracket (WS.makeListenSocket host (fromIntegral port)) N.sClose (\sock -> + forever $ mask_ $ do + (conn, _) <- N.accept sock + void $ asyncWithUnmask $ \unmask -> + finally (unmask $ runApp conn WS.defaultConnectionOptions runEventLoop) (N.sClose conn) + ) putStrLn "CLOSE server" where + runEventLoop pendingConn = do + let path = parsePath . WS.requestPath $ WS.pendingRequest pendingConn + case path of + Nothing -> putStrLn "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information" + Just (!proto, !rhost, !rport) -> + if not $ isAllowed (rhost, rport) + then do + putStrLn "Rejecting tunneling" + WS.rejectRequest pendingConn "Restriction is on, You cannot request this tunneling" + else do + conn <- WS.acceptRequest pendingConn + case proto of + UDP -> runUDPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn) + TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn) + + parsePath :: ByteString -> Maybe (Proto, ByteString, Int) parsePath path = let rets = BC.split '/' . BC.drop 1 $ path in do @@ -161,14 +187,16 @@ runTunnelingServer (host, port) isAllowed = do propagateRW :: N.HasReadWrite a => WS.Connection -> a -> IO () propagateRW hTunnel hOther = - void $ tryAny $ finally (race_ (propagateReads hTunnel hOther) (propagateWrites hTunnel hOther)) - (WS.sendClose hTunnel B.empty) + myTry $ race_ (propagateReads hTunnel hOther) (propagateWrites hTunnel hOther) + +myTry :: IO () -> IO () +myTry f = void $ catch f (\(e :: SomeException) -> print e) propagateReads :: N.HasReadWrite a => WS.Connection -> a -> IO () -propagateReads hTunnel hOther = void . tryAny . forever $ WS.receiveData hTunnel >>= N.appWrite hOther +propagateReads hTunnel hOther = myTry (forever $ WS.receiveData hTunnel >>= N.appWrite hOther) propagateWrites :: N.HasReadWrite a => WS.Connection -> a -> IO () -propagateWrites hTunnel hOther = void . tryAny $ do +propagateWrites hTunnel hOther = myTry $ do payload <- N.appRead hOther unless (null payload) (WS.sendBinaryData hTunnel payload >> propagateWrites hTunnel hOther)