Change Mvar to IORef + Signature
Better handling of exceptions
This commit is contained in:
parent
0340dc49f1
commit
a315f59673
1 changed files with 66 additions and 38 deletions
104
src/Lib.hs
104
src/Lib.hs
|
@ -12,26 +12,28 @@ module Lib
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import ClassyPrelude
|
import ClassyPrelude
|
||||||
import Control.Concurrent.Async (async, race_)
|
import Control.Concurrent.Async (async, asyncWithUnmask, race_)
|
||||||
import qualified Data.HashMap.Strict as H
|
import qualified Data.HashMap.Strict as H
|
||||||
import System.Timeout (timeout)
|
import System.Timeout (timeout)
|
||||||
|
|
||||||
import qualified Data.ByteString as B
|
import qualified Data.ByteString.Char8 as BC
|
||||||
import qualified Data.ByteString.Char8 as BC
|
|
||||||
|
|
||||||
import qualified Data.Streaming.Network as N
|
import qualified Data.Streaming.Network as N
|
||||||
import Network.Socket (HostName, PortNumber)
|
import Network.Socket (HostName, PortNumber)
|
||||||
import qualified Network.Socket as N hiding (recv, recvFrom, send,
|
import qualified Network.Socket as N hiding (recv, recvFrom,
|
||||||
sendTo)
|
send, sendTo)
|
||||||
import qualified Network.Socket.ByteString as N
|
import qualified Network.Socket.ByteString as N
|
||||||
|
|
||||||
import qualified Network.WebSockets as WS
|
import qualified Network.WebSockets as WS
|
||||||
import qualified Network.WebSockets.Stream as WS
|
import qualified Network.WebSockets.Connection as WS
|
||||||
|
import qualified Network.WebSockets.Stream as WS
|
||||||
|
|
||||||
import Network.Connection (Connection, ConnectionParams (..),
|
import Network.Connection (Connection,
|
||||||
TLSSettings (..), connectTo,
|
ConnectionParams (..),
|
||||||
connectionGetChunk, connectionPut,
|
TLSSettings (..), connectTo,
|
||||||
initConnectionContext)
|
connectionGetChunk,
|
||||||
|
connectionPut,
|
||||||
|
initConnectionContext)
|
||||||
|
|
||||||
|
|
||||||
instance Hashable N.SockAddr where
|
instance Hashable N.SockAddr where
|
||||||
|
@ -85,7 +87,7 @@ runUDPClient (host, port) app = do
|
||||||
runUDPServer :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO ()
|
runUDPServer :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO ()
|
||||||
runUDPServer (host, port) app = do
|
runUDPServer (host, port) app = do
|
||||||
putStrLn $ "WAIT for datagrames on " <> tshow host <> ":" <> tshow port
|
putStrLn $ "WAIT for datagrames on " <> tshow host <> ":" <> tshow port
|
||||||
clientsCtx <- newMVar mempty
|
clientsCtx <- newIORef mempty
|
||||||
void $ bracket
|
void $ bracket
|
||||||
(N.bindPortUDP (fromIntegral port) (fromString host))
|
(N.bindPortUDP (fromIntegral port) (fromString host))
|
||||||
N.close
|
N.close
|
||||||
|
@ -93,6 +95,8 @@ runUDPServer (host, port) app = do
|
||||||
putStrLn "CLOSE tunnel"
|
putStrLn "CLOSE tunnel"
|
||||||
|
|
||||||
where
|
where
|
||||||
|
addNewClient :: IORef (H.HashMap N.SockAddr UdpAppData) -> N.Socket -> N.SockAddr -> ByteString
|
||||||
|
-> IO UdpAppData
|
||||||
addNewClient clientsCtx socket addr payload = do
|
addNewClient clientsCtx socket addr payload = do
|
||||||
sem <- newMVar payload
|
sem <- newMVar payload
|
||||||
let appData = UdpAppData { appAddr = addr
|
let appData = UdpAppData { appAddr = addr
|
||||||
|
@ -100,18 +104,21 @@ runUDPServer (host, port) app = do
|
||||||
, appRead = takeMVar sem
|
, appRead = takeMVar sem
|
||||||
, appWrite = \payload' -> void $ N.sendTo socket payload' addr
|
, appWrite = \payload' -> void $ N.sendTo socket payload' addr
|
||||||
}
|
}
|
||||||
void $ withMVar clientsCtx (return . H.insert addr appData)
|
void $ atomicModifyIORef' clientsCtx (\clients -> (H.insert addr appData clients, ()))
|
||||||
return appData
|
return appData
|
||||||
|
|
||||||
|
removeClient :: IORef (H.HashMap N.SockAddr UdpAppData) -> UdpAppData -> IO ()
|
||||||
removeClient clientsCtx clientCtx = do
|
removeClient clientsCtx clientCtx = do
|
||||||
void $ withMVar clientsCtx (return . H.delete (appAddr clientCtx))
|
void $ atomicModifyIORef' clientsCtx (\clients -> (H.delete (appAddr clientCtx) clients, ()))
|
||||||
putStrLn "TIMEOUT connection"
|
putStrLn "TIMEOUT connection"
|
||||||
|
|
||||||
|
pushDataToClient :: UdpAppData -> ByteString -> IO ()
|
||||||
pushDataToClient clientCtx = putMVar (appSem clientCtx)
|
pushDataToClient clientCtx = putMVar (appSem clientCtx)
|
||||||
|
|
||||||
|
runEventLoop :: IORef (H.HashMap N.SockAddr UdpAppData) -> N.Socket -> IO ()
|
||||||
runEventLoop clientsCtx socket = forever $ do
|
runEventLoop clientsCtx socket = forever $ do
|
||||||
(payload, addr) <- N.recvFrom socket 4096
|
(payload, addr) <- N.recvFrom socket 4096
|
||||||
clientCtx <- H.lookup addr <$> readMVar clientsCtx
|
clientCtx <- H.lookup addr <$> readIORef clientsCtx
|
||||||
|
|
||||||
case clientCtx of
|
case clientCtx of
|
||||||
Just clientCtx' -> pushDataToClient clientCtx' payload
|
Just clientCtx' -> pushDataToClient clientCtx' payload
|
||||||
|
@ -128,27 +135,46 @@ runTunnelingClient proto (wsHost, wsPort) (remoteHost, remotePort) app = do
|
||||||
putStrLn $ "CLOSE connection to " <> tshow remoteHost <> ":" <> tshow remotePort
|
putStrLn $ "CLOSE connection to " <> tshow remoteHost <> ":" <> tshow remotePort
|
||||||
|
|
||||||
|
|
||||||
|
runApp :: N.Socket
|
||||||
|
-> WS.ConnectionOptions
|
||||||
|
-> WS.ServerApp
|
||||||
|
-> IO ()
|
||||||
|
runApp socket opts app =
|
||||||
|
bracket
|
||||||
|
(WS.makePendingConnection socket opts)
|
||||||
|
(WS.close . WS.pendingStream)
|
||||||
|
app
|
||||||
|
|
||||||
runTunnelingServer :: (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO ()
|
runTunnelingServer :: (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO ()
|
||||||
runTunnelingServer (host, port) isAllowed = do
|
runTunnelingServer (host, port) isAllowed = do
|
||||||
putStrLn $ "WAIT for connection on " <> tshow host <> ":" <> tshow port
|
putStrLn $ "WAIT for connection on " <> tshow host <> ":" <> tshow port
|
||||||
WS.runServer host (fromIntegral port) $ \pendingConn -> do
|
|
||||||
let path = parsePath . WS.requestPath $ WS.pendingRequest pendingConn
|
N.withSocketsDo $ bracket (WS.makeListenSocket host (fromIntegral port)) N.sClose (\sock ->
|
||||||
case path of
|
forever $ mask_ $ do
|
||||||
Nothing -> putStrLn "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information"
|
(conn, _) <- N.accept sock
|
||||||
Just (!proto, !rhost, !rport) ->
|
void $ asyncWithUnmask $ \unmask ->
|
||||||
if not $ isAllowed (rhost, rport)
|
finally (unmask $ runApp conn WS.defaultConnectionOptions runEventLoop) (N.sClose conn)
|
||||||
then do
|
)
|
||||||
putStrLn "Rejecting tunneling"
|
|
||||||
WS.rejectRequest pendingConn "Restriction is on, You cannot request this tunneling"
|
|
||||||
else do
|
|
||||||
conn <- WS.acceptRequest pendingConn
|
|
||||||
case proto of
|
|
||||||
UDP -> runUDPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn)
|
|
||||||
TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn)
|
|
||||||
|
|
||||||
putStrLn "CLOSE server"
|
putStrLn "CLOSE server"
|
||||||
|
|
||||||
where
|
where
|
||||||
|
runEventLoop pendingConn = do
|
||||||
|
let path = parsePath . WS.requestPath $ WS.pendingRequest pendingConn
|
||||||
|
case path of
|
||||||
|
Nothing -> putStrLn "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information"
|
||||||
|
Just (!proto, !rhost, !rport) ->
|
||||||
|
if not $ isAllowed (rhost, rport)
|
||||||
|
then do
|
||||||
|
putStrLn "Rejecting tunneling"
|
||||||
|
WS.rejectRequest pendingConn "Restriction is on, You cannot request this tunneling"
|
||||||
|
else do
|
||||||
|
conn <- WS.acceptRequest pendingConn
|
||||||
|
case proto of
|
||||||
|
UDP -> runUDPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn)
|
||||||
|
TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn)
|
||||||
|
|
||||||
|
|
||||||
parsePath :: ByteString -> Maybe (Proto, ByteString, Int)
|
parsePath :: ByteString -> Maybe (Proto, ByteString, Int)
|
||||||
parsePath path = let rets = BC.split '/' . BC.drop 1 $ path
|
parsePath path = let rets = BC.split '/' . BC.drop 1 $ path
|
||||||
in do
|
in do
|
||||||
|
@ -161,14 +187,16 @@ runTunnelingServer (host, port) isAllowed = do
|
||||||
|
|
||||||
propagateRW :: N.HasReadWrite a => WS.Connection -> a -> IO ()
|
propagateRW :: N.HasReadWrite a => WS.Connection -> a -> IO ()
|
||||||
propagateRW hTunnel hOther =
|
propagateRW hTunnel hOther =
|
||||||
void $ tryAny $ finally (race_ (propagateReads hTunnel hOther) (propagateWrites hTunnel hOther))
|
myTry $ race_ (propagateReads hTunnel hOther) (propagateWrites hTunnel hOther)
|
||||||
(WS.sendClose hTunnel B.empty)
|
|
||||||
|
myTry :: IO () -> IO ()
|
||||||
|
myTry f = void $ catch f (\(e :: SomeException) -> print e)
|
||||||
|
|
||||||
propagateReads :: N.HasReadWrite a => WS.Connection -> a -> IO ()
|
propagateReads :: N.HasReadWrite a => WS.Connection -> a -> IO ()
|
||||||
propagateReads hTunnel hOther = void . tryAny . forever $ WS.receiveData hTunnel >>= N.appWrite hOther
|
propagateReads hTunnel hOther = myTry (forever $ WS.receiveData hTunnel >>= N.appWrite hOther)
|
||||||
|
|
||||||
propagateWrites :: N.HasReadWrite a => WS.Connection -> a -> IO ()
|
propagateWrites :: N.HasReadWrite a => WS.Connection -> a -> IO ()
|
||||||
propagateWrites hTunnel hOther = void . tryAny $ do
|
propagateWrites hTunnel hOther = myTry $ do
|
||||||
payload <- N.appRead hOther
|
payload <- N.appRead hOther
|
||||||
unless (null payload) (WS.sendBinaryData hTunnel payload >> propagateWrites hTunnel hOther)
|
unless (null payload) (WS.sendBinaryData hTunnel payload >> propagateWrites hTunnel hOther)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue