Change Mvar to IORef + Signature

Better handling of exceptions
This commit is contained in:
Erèbe 2016-05-17 16:01:03 +02:00
parent 0340dc49f1
commit a315f59673

View file

@ -12,26 +12,28 @@ module Lib
) where ) where
import ClassyPrelude import ClassyPrelude
import Control.Concurrent.Async (async, race_) import Control.Concurrent.Async (async, asyncWithUnmask, race_)
import qualified Data.HashMap.Strict as H import qualified Data.HashMap.Strict as H
import System.Timeout (timeout) import System.Timeout (timeout)
import qualified Data.ByteString as B import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Char8 as BC
import qualified Data.Streaming.Network as N import qualified Data.Streaming.Network as N
import Network.Socket (HostName, PortNumber) import Network.Socket (HostName, PortNumber)
import qualified Network.Socket as N hiding (recv, recvFrom, send, import qualified Network.Socket as N hiding (recv, recvFrom,
sendTo) send, sendTo)
import qualified Network.Socket.ByteString as N import qualified Network.Socket.ByteString as N
import qualified Network.WebSockets as WS import qualified Network.WebSockets as WS
import qualified Network.WebSockets.Stream as WS import qualified Network.WebSockets.Connection as WS
import qualified Network.WebSockets.Stream as WS
import Network.Connection (Connection, ConnectionParams (..), import Network.Connection (Connection,
TLSSettings (..), connectTo, ConnectionParams (..),
connectionGetChunk, connectionPut, TLSSettings (..), connectTo,
initConnectionContext) connectionGetChunk,
connectionPut,
initConnectionContext)
instance Hashable N.SockAddr where instance Hashable N.SockAddr where
@ -85,7 +87,7 @@ runUDPClient (host, port) app = do
runUDPServer :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO () runUDPServer :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO ()
runUDPServer (host, port) app = do runUDPServer (host, port) app = do
putStrLn $ "WAIT for datagrames on " <> tshow host <> ":" <> tshow port putStrLn $ "WAIT for datagrames on " <> tshow host <> ":" <> tshow port
clientsCtx <- newMVar mempty clientsCtx <- newIORef mempty
void $ bracket void $ bracket
(N.bindPortUDP (fromIntegral port) (fromString host)) (N.bindPortUDP (fromIntegral port) (fromString host))
N.close N.close
@ -93,6 +95,8 @@ runUDPServer (host, port) app = do
putStrLn "CLOSE tunnel" putStrLn "CLOSE tunnel"
where where
addNewClient :: IORef (H.HashMap N.SockAddr UdpAppData) -> N.Socket -> N.SockAddr -> ByteString
-> IO UdpAppData
addNewClient clientsCtx socket addr payload = do addNewClient clientsCtx socket addr payload = do
sem <- newMVar payload sem <- newMVar payload
let appData = UdpAppData { appAddr = addr let appData = UdpAppData { appAddr = addr
@ -100,18 +104,21 @@ runUDPServer (host, port) app = do
, appRead = takeMVar sem , appRead = takeMVar sem
, appWrite = \payload' -> void $ N.sendTo socket payload' addr , appWrite = \payload' -> void $ N.sendTo socket payload' addr
} }
void $ withMVar clientsCtx (return . H.insert addr appData) void $ atomicModifyIORef' clientsCtx (\clients -> (H.insert addr appData clients, ()))
return appData return appData
removeClient :: IORef (H.HashMap N.SockAddr UdpAppData) -> UdpAppData -> IO ()
removeClient clientsCtx clientCtx = do removeClient clientsCtx clientCtx = do
void $ withMVar clientsCtx (return . H.delete (appAddr clientCtx)) void $ atomicModifyIORef' clientsCtx (\clients -> (H.delete (appAddr clientCtx) clients, ()))
putStrLn "TIMEOUT connection" putStrLn "TIMEOUT connection"
pushDataToClient :: UdpAppData -> ByteString -> IO ()
pushDataToClient clientCtx = putMVar (appSem clientCtx) pushDataToClient clientCtx = putMVar (appSem clientCtx)
runEventLoop :: IORef (H.HashMap N.SockAddr UdpAppData) -> N.Socket -> IO ()
runEventLoop clientsCtx socket = forever $ do runEventLoop clientsCtx socket = forever $ do
(payload, addr) <- N.recvFrom socket 4096 (payload, addr) <- N.recvFrom socket 4096
clientCtx <- H.lookup addr <$> readMVar clientsCtx clientCtx <- H.lookup addr <$> readIORef clientsCtx
case clientCtx of case clientCtx of
Just clientCtx' -> pushDataToClient clientCtx' payload Just clientCtx' -> pushDataToClient clientCtx' payload
@ -128,27 +135,46 @@ runTunnelingClient proto (wsHost, wsPort) (remoteHost, remotePort) app = do
putStrLn $ "CLOSE connection to " <> tshow remoteHost <> ":" <> tshow remotePort putStrLn $ "CLOSE connection to " <> tshow remoteHost <> ":" <> tshow remotePort
runApp :: N.Socket
-> WS.ConnectionOptions
-> WS.ServerApp
-> IO ()
runApp socket opts app =
bracket
(WS.makePendingConnection socket opts)
(WS.close . WS.pendingStream)
app
runTunnelingServer :: (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO () runTunnelingServer :: (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO ()
runTunnelingServer (host, port) isAllowed = do runTunnelingServer (host, port) isAllowed = do
putStrLn $ "WAIT for connection on " <> tshow host <> ":" <> tshow port putStrLn $ "WAIT for connection on " <> tshow host <> ":" <> tshow port
WS.runServer host (fromIntegral port) $ \pendingConn -> do
let path = parsePath . WS.requestPath $ WS.pendingRequest pendingConn N.withSocketsDo $ bracket (WS.makeListenSocket host (fromIntegral port)) N.sClose (\sock ->
case path of forever $ mask_ $ do
Nothing -> putStrLn "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information" (conn, _) <- N.accept sock
Just (!proto, !rhost, !rport) -> void $ asyncWithUnmask $ \unmask ->
if not $ isAllowed (rhost, rport) finally (unmask $ runApp conn WS.defaultConnectionOptions runEventLoop) (N.sClose conn)
then do )
putStrLn "Rejecting tunneling"
WS.rejectRequest pendingConn "Restriction is on, You cannot request this tunneling"
else do
conn <- WS.acceptRequest pendingConn
case proto of
UDP -> runUDPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn)
TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn)
putStrLn "CLOSE server" putStrLn "CLOSE server"
where where
runEventLoop pendingConn = do
let path = parsePath . WS.requestPath $ WS.pendingRequest pendingConn
case path of
Nothing -> putStrLn "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information"
Just (!proto, !rhost, !rport) ->
if not $ isAllowed (rhost, rport)
then do
putStrLn "Rejecting tunneling"
WS.rejectRequest pendingConn "Restriction is on, You cannot request this tunneling"
else do
conn <- WS.acceptRequest pendingConn
case proto of
UDP -> runUDPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn)
TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn)
parsePath :: ByteString -> Maybe (Proto, ByteString, Int) parsePath :: ByteString -> Maybe (Proto, ByteString, Int)
parsePath path = let rets = BC.split '/' . BC.drop 1 $ path parsePath path = let rets = BC.split '/' . BC.drop 1 $ path
in do in do
@ -161,14 +187,16 @@ runTunnelingServer (host, port) isAllowed = do
propagateRW :: N.HasReadWrite a => WS.Connection -> a -> IO () propagateRW :: N.HasReadWrite a => WS.Connection -> a -> IO ()
propagateRW hTunnel hOther = propagateRW hTunnel hOther =
void $ tryAny $ finally (race_ (propagateReads hTunnel hOther) (propagateWrites hTunnel hOther)) myTry $ race_ (propagateReads hTunnel hOther) (propagateWrites hTunnel hOther)
(WS.sendClose hTunnel B.empty)
myTry :: IO () -> IO ()
myTry f = void $ catch f (\(e :: SomeException) -> print e)
propagateReads :: N.HasReadWrite a => WS.Connection -> a -> IO () propagateReads :: N.HasReadWrite a => WS.Connection -> a -> IO ()
propagateReads hTunnel hOther = void . tryAny . forever $ WS.receiveData hTunnel >>= N.appWrite hOther propagateReads hTunnel hOther = myTry (forever $ WS.receiveData hTunnel >>= N.appWrite hOther)
propagateWrites :: N.HasReadWrite a => WS.Connection -> a -> IO () propagateWrites :: N.HasReadWrite a => WS.Connection -> a -> IO ()
propagateWrites hTunnel hOther = void . tryAny $ do propagateWrites hTunnel hOther = myTry $ do
payload <- N.appRead hOther payload <- N.appRead hOther
unless (null payload) (WS.sendBinaryData hTunnel payload >> propagateWrites hTunnel hOther) unless (null payload) (WS.sendBinaryData hTunnel payload >> propagateWrites hTunnel hOther)