Better handling of exceptions

This commit is contained in:
Erèbe 2016-05-18 23:25:07 +02:00
parent a315f59673
commit 51fdf79ed1

View file

@ -12,8 +12,9 @@ module Lib
) where ) where
import ClassyPrelude import ClassyPrelude
import Control.Concurrent.Async (async, asyncWithUnmask, race_) import Control.Concurrent.Async (async, race_)
import qualified Data.HashMap.Strict as H import qualified Data.HashMap.Strict as H
import Data.Maybe (fromJust)
import System.Timeout (timeout) import System.Timeout (timeout)
import qualified Data.ByteString.Char8 as BC import qualified Data.ByteString.Char8 as BC
@ -76,9 +77,9 @@ runUDPClient (host, port) app = do
putStrLn $ "CONNECTING to " <> tshow host <> ":" <> tshow port putStrLn $ "CONNECTING to " <> tshow host <> ":" <> tshow port
bracket (N.getSocketUDP host (fromIntegral port)) (N.close . fst) $ \(socket, addrInfo) -> do bracket (N.getSocketUDP host (fromIntegral port)) (N.close . fst) $ \(socket, addrInfo) -> do
sem <- newEmptyMVar sem <- newEmptyMVar
app UdpAppData { appAddr = N.addrAddress addrInfo app UdpAppData { appAddr = N.addrAddress addrInfo
, appSem = sem , appSem = sem
, appRead = fst <$> N.recvFrom socket 4096 , appRead = fst <$> N.recvFrom socket 4096
, appWrite = \payload -> void $ N.sendTo socket payload (N.addrAddress addrInfo) , appWrite = \payload -> void $ N.sendTo socket payload (N.addrAddress addrInfo)
} }
@ -88,10 +89,7 @@ runUDPServer :: (HostName, PortNumber) -> (UdpAppData -> IO ()) -> IO ()
runUDPServer (host, port) app = do runUDPServer (host, port) app = do
putStrLn $ "WAIT for datagrames on " <> tshow host <> ":" <> tshow port putStrLn $ "WAIT for datagrames on " <> tshow host <> ":" <> tshow port
clientsCtx <- newIORef mempty clientsCtx <- newIORef mempty
void $ bracket void $ bracket (N.bindPortUDP (fromIntegral port) (fromString host)) N.close (runEventLoop clientsCtx)
(N.bindPortUDP (fromIntegral port) (fromString host))
N.close
(runEventLoop clientsCtx)
putStrLn "CLOSE tunnel" putStrLn "CLOSE tunnel"
where where
@ -100,10 +98,10 @@ runUDPServer (host, port) app = do
addNewClient clientsCtx socket addr payload = do addNewClient clientsCtx socket addr payload = do
sem <- newMVar payload sem <- newMVar payload
let appData = UdpAppData { appAddr = addr let appData = UdpAppData { appAddr = addr
, appSem = sem , appSem = sem
, appRead = takeMVar sem , appRead = takeMVar sem
, appWrite = \payload' -> void $ N.sendTo socket payload' addr , appWrite = \payload' -> void $ N.sendTo socket payload' addr
} }
void $ atomicModifyIORef' clientsCtx (\clients -> (H.insert addr appData clients, ())) void $ atomicModifyIORef' clientsCtx (\clients -> (H.insert addr appData clients, ()))
return appData return appData
@ -122,7 +120,7 @@ runUDPServer (host, port) app = do
case clientCtx of case clientCtx of
Just clientCtx' -> pushDataToClient clientCtx' payload Just clientCtx' -> pushDataToClient clientCtx' payload
_ -> void $ async $ bracket _ -> void . async $ bracket
(addNewClient clientsCtx socket addr payload) (addNewClient clientsCtx socket addr payload)
(removeClient clientsCtx) (removeClient clientsCtx)
(timeout (30 * 10^(6 :: Int)) . app) (timeout (30 * 10^(6 :: Int)) . app)
@ -135,32 +133,23 @@ runTunnelingClient proto (wsHost, wsPort) (remoteHost, remotePort) app = do
putStrLn $ "CLOSE connection to " <> tshow remoteHost <> ":" <> tshow remotePort putStrLn $ "CLOSE connection to " <> tshow remoteHost <> ":" <> tshow remotePort
runApp :: N.Socket
-> WS.ConnectionOptions
-> WS.ServerApp
-> IO ()
runApp socket opts app =
bracket
(WS.makePendingConnection socket opts)
(WS.close . WS.pendingStream)
app
runTunnelingServer :: (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO () runTunnelingServer :: (HostName, PortNumber) -> ((ByteString, Int) -> Bool) -> IO ()
runTunnelingServer (host, port) isAllowed = do runTunnelingServer (host, port) isAllowed = do
putStrLn $ "WAIT for connection on " <> tshow host <> ":" <> tshow port putStrLn $ "WAIT for connection on " <> tshow host <> ":" <> tshow port
N.withSocketsDo $ bracket (WS.makeListenSocket host (fromIntegral port)) N.sClose (\sock -> void $ N.runTCPServer (N.serverSettingsTCP (fromIntegral port) (fromString host)) $ \sClient ->
forever $ mask_ $ do runApp (fromJust $ N.appRawSocket sClient) WS.defaultConnectionOptions runEventLoop
(conn, _) <- N.accept sock
void $ asyncWithUnmask $ \unmask ->
finally (unmask $ runApp conn WS.defaultConnectionOptions runEventLoop) (N.sClose conn)
)
putStrLn "CLOSE server" putStrLn "CLOSE server"
where where
runApp :: N.Socket -> WS.ConnectionOptions -> WS.ServerApp -> IO ()
runApp socket opts = bracket (WS.makePendingConnection socket opts)
(\conn -> catch (WS.close $ WS.pendingStream conn) (\(_ :: SomeException) -> return ()))
runEventLoop pendingConn = do runEventLoop pendingConn = do
let path = parsePath . WS.requestPath $ WS.pendingRequest pendingConn let path = fromPath . WS.requestPath $ WS.pendingRequest pendingConn
case path of case path of
Nothing -> putStrLn "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information" Nothing -> putStrLn "Rejecting connection" >> WS.rejectRequest pendingConn "Invalid tunneling information"
Just (!proto, !rhost, !rport) -> Just (!proto, !rhost, !rport) ->
@ -175,14 +164,6 @@ runTunnelingServer (host, port) isAllowed = do
TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn) TCP -> runTCPClient (BC.unpack rhost, fromIntegral rport) (propagateRW conn)
parsePath :: ByteString -> Maybe (Proto, ByteString, Int)
parsePath path = let rets = BC.split '/' . BC.drop 1 $ path
in do
guard (length rets == 3)
let [protocol, h, prt] = rets
prt' <- readMay . BC.unpack $ prt :: Maybe Int
proto <- readMay . toUpper . BC.unpack $ protocol :: Maybe Proto
return (proto, h, prt')
propagateRW :: N.HasReadWrite a => WS.Connection -> a -> IO () propagateRW :: N.HasReadWrite a => WS.Connection -> a -> IO ()
@ -190,7 +171,7 @@ propagateRW hTunnel hOther =
myTry $ race_ (propagateReads hTunnel hOther) (propagateWrites hTunnel hOther) myTry $ race_ (propagateReads hTunnel hOther) (propagateWrites hTunnel hOther)
myTry :: IO () -> IO () myTry :: IO () -> IO ()
myTry f = void $ catch f (\(e :: SomeException) -> print e) myTry f = void $ catch f (\(_ :: SomeException) -> return ())
propagateReads :: N.HasReadWrite a => WS.Connection -> a -> IO () propagateReads :: N.HasReadWrite a => WS.Connection -> a -> IO ()
propagateReads hTunnel hOther = myTry (forever $ WS.receiveData hTunnel >>= N.appWrite hOther) propagateReads hTunnel hOther = myTry (forever $ WS.receiveData hTunnel >>= N.appWrite hOther)
@ -244,5 +225,15 @@ reader connection = fmap Just (connectionGetChunk connection)
writer :: Connection -> Maybe LByteString -> IO () writer :: Connection -> Maybe LByteString -> IO ()
writer connection = maybe (return ()) (connectionPut connection . toStrict) writer connection = maybe (return ()) (connectionPut connection . toStrict)
toPath :: Proto -> HostName -> PortNumber -> String toPath :: Proto -> HostName -> PortNumber -> String
toPath proto remoteHost remotePort = "/" <> toLower (show proto) <> "/" <> remoteHost <> "/" <> show remotePort toPath proto remoteHost remotePort = "/" <> toLower (show proto) <> "/" <> remoteHost <> "/" <> show remotePort
fromPath :: ByteString -> Maybe (Proto, ByteString, Int)
fromPath path = let rets = BC.split '/' . BC.drop 1 $ path
in do
guard (length rets == 3)
let [protocol, h, prt] = rets
prt' <- readMay . BC.unpack $ prt :: Maybe Int
proto <- readMay . toUpper . BC.unpack $ protocol :: Maybe Proto
return (proto, h, prt')