Change Mvar to IORef + Signature

Better handling of exceptions
This commit is contained in:
Erèbe 2016-05-17 16:01:03 +02:00
parent 0340dc49f1
commit a315f59673

View file

@ -12,26 +12,28 @@ module Lib
) where
import ClassyPrelude
import Control.Concurrent.Async (async, race_)
import qualified Data.HashMap.Strict as H
import System.Timeout (timeout)
import Control.Concurrent.Async (async, asyncWithUnmask, race_)
import qualified Data.HashMap.Strict as H
import System.Timeout (timeout)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Char8 as BC
import qualified Data.Streaming.Network as N
import Network.Socket (HostName, PortNumber)
import qualified Network.Socket as N hiding (recv, recvFrom, send,
sendTo)
import qualified Network.Socket.ByteString as N
import qualified Data.Streaming.Network as N
import Network.Socket (HostName, PortNumber)
import qualified Network.Socket as N hiding (recv, recvFrom,
send, sendTo)
import qualified Network.Socket.ByteString as N
import qualified Network.WebSockets as WS
import qualified Network.WebSockets.Stream as WS
import qualified Network.WebSockets as WS
import qualified Network.WebSockets.Connection as WS
import qualified Network.WebSockets.Stream as WS
import Network.Connection (Connection, ConnectionParams (..),
TLSSettings (..), connectTo,
connectionGetChunk, connectionPut,
initConnectionContext)
import Network.Connection (Connection,
ConnectionParams (..),
TLSSettings (..), connectTo,
connectionGetChunk,
connectionPut,
initConnectionContext)
instance Hashable N.SockAddr where
@ -85,7 +87,7 @@ runUDPClient (host, port) app = do
runUDPServer :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO ()
runUDPServer (host, port) app = do
putStrLn $ "WAIT for datagrames on " <> tshow host <> ":" <> tshow port
clientsCtx <- newMVar mempty
clientsCtx <- newIORef mempty
void $ bracket
(N.bindPortUDP (fromIntegral port) (fromString host))
N.close
@ -93,6 +95,8 @@ runUDPServer (host, port) app = do
putStrLn "CLOSE tunnel"
where
addNewClient :: IORef (H.HashMap N.SockAddr UdpAppData) -> N.Socket -> N.SockAddr -> ByteString
-> IO UdpAppData
addNewClient clientsCtx socket addr payload = do
sem <- newMVar payload
let appData = UdpAppData { appAddr = addr
@ -100,18 +104,21 @@ runUDPServer (host, port) app = do
, appRead = takeMVar sem
, appWrite = \payload' -> void $ N.sendTo socket payload' addr
}
void $ withMVar clientsCtx (return . H.insert addr appData)
void $ atomicModifyIORef' clientsCtx (\clients -> (H.insert addr appData clients, ()))
return appData
removeClient :: IORef (H.HashMap N.SockAddr UdpAppData) -> UdpAppData -> IO ()
removeClient clientsCtx clientCtx = do
void $ withMVar clientsCtx (return . H.delete (appAddr clientCtx))
void $ atomicModifyIORef' clientsCtx (\clients -> (H.delete (appAddr clientCtx) clients, ()))
putStrLn "TIMEOUT connection"
pushDataToClient :: UdpAppData -> ByteString -> IO ()
pushDataToClient clientCtx = putMVar (appSem clientCtx)
runEventLoop :: IORef (H.HashMap N.SockAddr UdpAppData) -> N.Socket -> IO ()
runEventLoop clientsCtx socket = forever $ do
(payload, addr) <- N.recvFrom socket 4096
clientCtx <- H.lookup addr <$> readMVar clientsCtx
clientCtx <- H.lookup addr <$> readIORef clientsCtx
case clientCtx of
Just clientCtx' -> pushDataToClient clientCtx' payload
@ -128,27 +135,46 @@ runTunnelingClient proto (wsHost, wsPort) (remoteHost, remotePort) app = do
putStrLn $ "CLOSE connection to " <> tshow remoteHost <> ":" <> tshow remotePort
runApp :: N.Socket
-> WS.ConnectionOptions
-> WS.ServerApp
-> IO ()
runApp socket opts app =
bracket
(WS.makePendingConnection socket opts)
(WS.close . WS.pendingStream)
app
runTunnelingServer :: (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO ()
runTunnelingServer (host, port) isAllowed = do
putStrLn $ "WAIT for connection on " <> tshow host <> ":" <> tshow port
WS.runServer host (fromIntegral port) $ \pendingConn -> do
let path = parsePath . WS.requestPath $ WS.pendingRequest pendingConn
case path of
Nothing -> putStrLn "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information"
Just (!proto, !rhost, !rport) ->
if not $ isAllowed (rhost, rport)
then do
putStrLn "Rejecting tunneling"
WS.rejectRequest pendingConn "Restriction is on, You cannot request this tunneling"
else do
conn <- WS.acceptRequest pendingConn
case proto of
UDP -> runUDPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn)
TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn)
N.withSocketsDo $ bracket (WS.makeListenSocket host (fromIntegral port)) N.sClose (\sock ->
forever $ mask_ $ do
(conn, _) <- N.accept sock
void $ asyncWithUnmask $ \unmask ->
finally (unmask $ runApp conn WS.defaultConnectionOptions runEventLoop) (N.sClose conn)
)
putStrLn "CLOSE server"
where
runEventLoop pendingConn = do
let path = parsePath . WS.requestPath $ WS.pendingRequest pendingConn
case path of
Nothing -> putStrLn "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information"
Just (!proto, !rhost, !rport) ->
if not $ isAllowed (rhost, rport)
then do
putStrLn "Rejecting tunneling"
WS.rejectRequest pendingConn "Restriction is on, You cannot request this tunneling"
else do
conn <- WS.acceptRequest pendingConn
case proto of
UDP -> runUDPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn)
TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn)
parsePath :: ByteString -> Maybe (Proto, ByteString, Int)
parsePath path = let rets = BC.split '/' . BC.drop 1 $ path
in do
@ -161,14 +187,16 @@ runTunnelingServer (host, port) isAllowed = do
propagateRW :: N.HasReadWrite a => WS.Connection -> a -> IO ()
propagateRW hTunnel hOther =
void $ tryAny $ finally (race_ (propagateReads hTunnel hOther) (propagateWrites hTunnel hOther))
(WS.sendClose hTunnel B.empty)
myTry $ race_ (propagateReads hTunnel hOther) (propagateWrites hTunnel hOther)
myTry :: IO () -> IO ()
myTry f = void $ catch f (\(e :: SomeException) -> print e)
propagateReads :: N.HasReadWrite a => WS.Connection -> a -> IO ()
propagateReads hTunnel hOther = void . tryAny . forever $ WS.receiveData hTunnel >>= N.appWrite hOther
propagateReads hTunnel hOther = myTry (forever $ WS.receiveData hTunnel >>= N.appWrite hOther)
propagateWrites :: N.HasReadWrite a => WS.Connection -> a -> IO ()
propagateWrites hTunnel hOther = void . tryAny $ do
propagateWrites hTunnel hOther = myTry $ do
payload <- N.appRead hOther
unless (null payload) (WS.sendBinaryData hTunnel payload >> propagateWrites hTunnel hOther)